dd0e8eec14
* copyright headers for agent folder * Ignore test data files * fix proto files and remove headers in agent/uiserver folder * ignore deep-copy files
125 lines
3.5 KiB
Go
125 lines
3.5 KiB
Go
// Copyright (c) HashiCorp, Inc.
|
|
// SPDX-License-Identifier: MPL-2.0
|
|
|
|
//go:build windows
|
|
// +build windows
|
|
|
|
package checks
|
|
|
|
import (
|
|
"fmt"
|
|
"syscall"
|
|
|
|
"golang.org/x/sys/windows"
|
|
"golang.org/x/sys/windows/svc"
|
|
"golang.org/x/sys/windows/svc/mgr"
|
|
)
|
|
|
|
var (
|
|
win windowsSystem = windowsOS{}
|
|
)
|
|
|
|
// OSServiceClient checks the health of operating-system services. It carries
// no state of its own; on Windows its Check method queries the service
// control manager.
type OSServiceClient struct{}

// NewOSServiceClient returns a ready-to-use OSServiceClient. The returned
// error is always nil in this implementation.
func NewOSServiceClient() (*OSServiceClient, error) {
	client := new(OSServiceClient)
	return client, nil
}
|
|
|
|
func (client *OSServiceClient) Check(serviceName string) (err error) {
|
|
h, err := windows.OpenSCManager(nil, nil, windows.SC_MANAGER_CONNECT)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to connect to Windows service manager: %w", err)
|
|
}
|
|
|
|
m := win.getWindowsSvcMgr(h)
|
|
defer m.Disconnect()
|
|
|
|
svcNamePtr, err := syscall.UTF16PtrFromString(serviceName)
|
|
if err != nil {
|
|
return fmt.Errorf("service name must not contain NUL bytes: %w", err)
|
|
}
|
|
svcHandle, err := win.OpenService(win.getWindowsSvcMgrHandle(m), svcNamePtr, windows.SC_MANAGER_ENUMERATE_SERVICE)
|
|
if err != nil {
|
|
return fmt.Errorf("error accessing service: %w", err)
|
|
}
|
|
service := win.getWindowsSvc(serviceName, svcHandle)
|
|
defer service.Close()
|
|
|
|
status, err := service.Query()
|
|
if err != nil {
|
|
return fmt.Errorf("error querying service status: %w", err)
|
|
}
|
|
|
|
switch status.State {
|
|
case svc.Running:
|
|
return nil
|
|
case svc.Paused, svc.Stopped:
|
|
err = fmt.Errorf("service status: %v - %w", svcStateString(status.State), ErrOSServiceStatusCritical)
|
|
default:
|
|
err = fmt.Errorf("service status: %v", svcStateString(status.State))
|
|
}
|
|
|
|
return err
|
|
}
|
|
|
|
type windowsOS struct{}
|
|
|
|
func (windowsOS) OpenSCManager(machineName *uint16, databaseName *uint16, access uint32) (handle windows.Handle, err error) {
|
|
return windows.OpenSCManager(machineName, databaseName, access)
|
|
}
|
|
func (windowsOS) OpenService(mgr windows.Handle, serviceName *uint16, access uint32) (handle windows.Handle, err error) {
|
|
return windows.OpenService(mgr, serviceName, access)
|
|
}
|
|
|
|
func (windowsOS) getWindowsSvcMgr(h windows.Handle) windowsSvcMgr { return &mgr.Mgr{Handle: h} }
|
|
func (windowsOS) getWindowsSvcMgrHandle(sm windowsSvcMgr) windows.Handle {
|
|
return sm.(*mgr.Mgr).Handle
|
|
}
|
|
|
|
func (windowsOS) getWindowsSvc(name string, h windows.Handle) windowsSvc {
|
|
return &mgr.Service{Name: name, Handle: h}
|
|
}
|
|
|
|
type windowsSystem interface {
|
|
OpenSCManager(machineName *uint16, databaseName *uint16, access uint32) (handle windows.Handle, err error)
|
|
OpenService(mgr windows.Handle, serviceName *uint16, access uint32) (handle windows.Handle, err error)
|
|
|
|
getWindowsSvcMgr(h windows.Handle) windowsSvcMgr
|
|
getWindowsSvcMgrHandle(sm windowsSvcMgr) windows.Handle
|
|
getWindowsSvc(name string, h windows.Handle) windowsSvc
|
|
}
|
|
|
|
type windowsSvcMgr interface {
|
|
Disconnect() error
|
|
}
|
|
|
|
type windowsSvc interface {
|
|
Close() error
|
|
Query() (svc.Status, error)
|
|
}
|
|
|
|
// svcStateString converts svc.State (uint32) to human readable string
|
|
//
|
|
// source: https://pkg.go.dev/golang.org/x/sys/windows/svc#pkg-constants
|
|
func svcStateString(state svc.State) string {
|
|
switch state {
|
|
case svc.State(windows.SERVICE_STOPPED):
|
|
return "Stopped"
|
|
case svc.State(windows.SERVICE_START_PENDING):
|
|
return "StartPending"
|
|
case svc.State(windows.SERVICE_STOP_PENDING):
|
|
return "StopPending"
|
|
case svc.State(windows.SERVICE_RUNNING):
|
|
return "Running"
|
|
case svc.State(windows.SERVICE_CONTINUE_PENDING):
|
|
return "ContinuePending"
|
|
case svc.State(windows.SERVICE_PAUSE_PENDING):
|
|
return "PausePending"
|
|
case svc.State(windows.SERVICE_PAUSED):
|
|
return "Paused"
|
|
default:
|
|
//if not handled we return the underlying uint32
|
|
return fmt.Sprintf("%d", state)
|
|
}
|
|
}
|