tests: added syscall mocking and tests for Check_OSService

This commit is contained in:
Alessandro De Blasis 2022-06-09 15:48:34 +01:00
parent 4592351260
commit b53bb6f70e
2 changed files with 501 additions and 11 deletions

View File

@ -0,0 +1,430 @@
//go:build windows
// +build windows
package checks
import (
"errors"
"testing"
"time"
"github.com/hashicorp/consul/agent/mock"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/api"
"github.com/hashicorp/consul/sdk/testutil"
"golang.org/x/sys/windows"
"golang.org/x/sys/windows/svc"
)
func TestOSServiceClient(t *testing.T) {
type args struct {
returnsOpenSCManagerError error
returnsOpenServiceError error
returnsServiceQueryError error
returnsServiceCloseError error
returnsSCMgrDisconnectError error
returnsServiceState svc.State
}
tests := []struct {
name string
args args
maybeHealthy *bool
}{
// healthy
{"should pass for healthy service", args{
returnsOpenSCManagerError: nil,
returnsOpenServiceError: nil,
returnsServiceQueryError: nil,
returnsServiceCloseError: nil,
returnsSCMgrDisconnectError: nil,
returnsServiceState: svc.Running,
}, boolPointer(true)},
{"should pass for healthy service even when there's an error closing the service handle", args{
returnsOpenSCManagerError: nil,
returnsOpenServiceError: nil,
returnsServiceQueryError: nil,
returnsServiceCloseError: errors.New("error while closing the service handle"),
returnsSCMgrDisconnectError: nil,
returnsServiceState: svc.Running,
}, boolPointer(true)},
{"should pass for healthy service even when there's an error disconnecting from SCManager", args{
returnsOpenSCManagerError: nil,
returnsOpenServiceError: nil,
returnsServiceQueryError: nil,
returnsServiceCloseError: nil,
returnsSCMgrDisconnectError: errors.New("error while disconnecting from service manager"),
returnsServiceState: svc.Running,
}, boolPointer(true)},
// warning
{"should be in warning state for any state that's not Running, Paused or Stopped", args{
returnsOpenSCManagerError: nil,
returnsOpenServiceError: nil,
returnsServiceQueryError: nil,
returnsServiceCloseError: nil,
returnsSCMgrDisconnectError: nil,
returnsServiceState: svc.StartPending,
}, nil},
{"should be in warning state when we cannot connect to the service manager", args{
returnsOpenSCManagerError: errors.New("cannot connect to service manager"),
returnsOpenServiceError: nil,
returnsServiceQueryError: nil,
returnsServiceCloseError: nil,
returnsSCMgrDisconnectError: nil,
returnsServiceState: svc.Running,
}, nil},
{"should be in warning state when we cannot open the service", args{
returnsOpenSCManagerError: nil,
returnsOpenServiceError: errors.New("service testService does not exist"),
returnsServiceQueryError: nil,
returnsServiceCloseError: nil,
returnsSCMgrDisconnectError: nil,
returnsServiceState: svc.Running,
}, nil},
{"should be in warning state when we cannot query the service state", args{
returnsOpenSCManagerError: nil,
returnsOpenServiceError: nil,
returnsServiceQueryError: errors.New("cannot query testService state"),
returnsServiceCloseError: nil,
returnsSCMgrDisconnectError: nil,
returnsServiceState: svc.Running,
}, nil},
{"should be in warning state for for any state that's not Running, Paused or Stopped when there's an error closing the service handle", args{
returnsOpenSCManagerError: nil,
returnsOpenServiceError: nil,
returnsServiceQueryError: nil,
returnsServiceCloseError: errors.New("error while closing the service handle"),
returnsSCMgrDisconnectError: nil,
returnsServiceState: svc.StartPending,
}, nil},
{"should be in warning state for for any state that's not Running, Paused or Stopped when there's an error disconnecting from SCManager", args{
returnsOpenSCManagerError: nil,
returnsOpenServiceError: nil,
returnsServiceQueryError: nil,
returnsServiceCloseError: nil,
returnsSCMgrDisconnectError: errors.New("error while disconnecting from service manager"),
returnsServiceState: svc.StartPending,
}, nil},
// critical
{"should fail for paused service", args{
returnsOpenSCManagerError: nil,
returnsOpenServiceError: nil,
returnsServiceQueryError: nil,
returnsServiceCloseError: nil,
returnsSCMgrDisconnectError: nil,
returnsServiceState: svc.Paused,
}, boolPointer(false)},
{"should fail for stopped service", args{
returnsOpenSCManagerError: nil,
returnsOpenServiceError: nil,
returnsServiceQueryError: nil,
returnsServiceCloseError: nil,
returnsSCMgrDisconnectError: nil,
returnsServiceState: svc.Stopped,
}, boolPointer(false)},
{"should fail for stopped service even when there's an error closing the service handle", args{
returnsOpenSCManagerError: nil,
returnsOpenServiceError: nil,
returnsServiceQueryError: nil,
returnsServiceCloseError: errors.New("error while closing the service handle"),
returnsSCMgrDisconnectError: nil,
returnsServiceState: svc.Stopped,
}, boolPointer(false)},
{"should fail for stopped service even when there's an error disconnecting from SCManager", args{
returnsOpenSCManagerError: nil,
returnsOpenServiceError: nil,
returnsServiceQueryError: nil,
returnsServiceCloseError: nil,
returnsSCMgrDisconnectError: errors.New("error while disconnecting from service manager"),
returnsServiceState: svc.Stopped,
}, boolPointer(false)},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
old := win
defer func() { win = old }()
win = fakeWindowsOS{
returnsOpenSCManagerError: tt.args.returnsOpenSCManagerError,
returnsOpenServiceError: tt.args.returnsOpenServiceError,
returnsServiceQueryError: tt.args.returnsServiceQueryError,
returnsServiceCloseError: tt.args.returnsServiceCloseError,
returnsSCMgrDisconnectError: tt.args.returnsSCMgrDisconnectError,
returnsServiceState: tt.args.returnsServiceState,
}
probe, err := NewOSServiceClient()
if (tt.args.returnsOpenSCManagerError != nil && err == nil) || (tt.args.returnsOpenSCManagerError == nil && err != nil) {
t.Errorf("FAIL: %s. Expected error on OpenSCManager %v , but err == %v", tt.name, tt.args.returnsOpenSCManagerError, err)
}
if err != nil {
return
}
actualError := probe.Check("testService")
actuallyHealthy := actualError == nil
actualErrorIsCritical := errors.Is(actualError, ErrOSServiceStatusCritical)
actualWarning := !actuallyHealthy && !actualErrorIsCritical
expectedHealthy := tt.maybeHealthy != nil && *tt.maybeHealthy
expectedWarning := tt.maybeHealthy == nil
if expectedHealthy && !actuallyHealthy {
t.Errorf("FAIL: %s. Expected healthy %t, but err == %v", tt.name, boolVal(tt.maybeHealthy), actualError)
}
if expectedWarning && !actualWarning {
t.Errorf("FAIL: %s. Expected non critical error, but err == %v", tt.name, actualError)
}
})
}
}
func TestCheck_OSService(t *testing.T) {
type args struct {
returnsOpenSCManagerError error
returnsOpenServiceError error
returnsServiceQueryError error
returnsServiceCloseError error
returnsSCMgrDisconnectError error
returnsServiceState svc.State
}
tests := []struct {
desc string
args args
state string
}{
//healthy
{"should pass for healthy service", args{
returnsOpenSCManagerError: nil,
returnsOpenServiceError: nil,
returnsServiceQueryError: nil,
returnsServiceCloseError: nil,
returnsSCMgrDisconnectError: nil,
returnsServiceState: svc.Running,
}, api.HealthPassing},
{"should pass for healthy service even when there's an error closing the service handle", args{
returnsOpenSCManagerError: nil,
returnsOpenServiceError: nil,
returnsServiceQueryError: nil,
returnsServiceCloseError: errors.New("error while closing the service handle"),
returnsSCMgrDisconnectError: nil,
returnsServiceState: svc.Running,
}, api.HealthPassing},
{"should pass for healthy service even when there's an error disconnecting from SCManager", args{
returnsOpenSCManagerError: nil,
returnsOpenServiceError: nil,
returnsServiceQueryError: nil,
returnsServiceCloseError: nil,
returnsSCMgrDisconnectError: errors.New("error while disconnecting from service manager"),
returnsServiceState: svc.Running,
}, api.HealthPassing},
// // warning
{"should be in warning state for any state that's not Running, Paused or Stopped", args{
returnsOpenSCManagerError: nil,
returnsOpenServiceError: nil,
returnsServiceQueryError: nil,
returnsServiceCloseError: nil,
returnsSCMgrDisconnectError: nil,
returnsServiceState: svc.StartPending,
}, api.HealthWarning},
{"should be in warning state when we cannot connect to the service manager", args{
returnsOpenSCManagerError: errors.New("cannot connect to service manager"),
returnsOpenServiceError: nil,
returnsServiceQueryError: nil,
returnsServiceCloseError: nil,
returnsSCMgrDisconnectError: nil,
returnsServiceState: svc.Running,
}, api.HealthWarning},
{"should be in warning state when we cannot open the service", args{
returnsOpenSCManagerError: nil,
returnsOpenServiceError: errors.New("service testService does not exist"),
returnsServiceQueryError: nil,
returnsServiceCloseError: nil,
returnsSCMgrDisconnectError: nil,
returnsServiceState: svc.Running,
}, api.HealthWarning},
{"should be in warning state when we cannot query the service state", args{
returnsOpenSCManagerError: nil,
returnsOpenServiceError: nil,
returnsServiceQueryError: errors.New("cannot query testService state"),
returnsServiceCloseError: nil,
returnsSCMgrDisconnectError: nil,
returnsServiceState: svc.Running,
}, api.HealthWarning},
{"should be in warning state for for any state that's not Running, Paused or Stopped when there's an error closing the service handle", args{
returnsOpenSCManagerError: nil,
returnsOpenServiceError: nil,
returnsServiceQueryError: nil,
returnsServiceCloseError: errors.New("error while closing the service handle"),
returnsSCMgrDisconnectError: nil,
returnsServiceState: svc.StartPending,
}, api.HealthWarning},
{"should be in warning state for for any state that's not Running, Paused or Stopped when there's an error disconnecting from SCManager", args{
returnsOpenSCManagerError: nil,
returnsOpenServiceError: nil,
returnsServiceQueryError: nil,
returnsServiceCloseError: nil,
returnsSCMgrDisconnectError: errors.New("error while disconnecting from service manager"),
returnsServiceState: svc.StartPending,
}, api.HealthWarning},
// critical
{"should fail for paused service", args{
returnsOpenSCManagerError: nil,
returnsOpenServiceError: nil,
returnsServiceQueryError: nil,
returnsServiceCloseError: nil,
returnsSCMgrDisconnectError: nil,
returnsServiceState: svc.Paused,
}, api.HealthCritical},
{"should fail for stopped service", args{
returnsOpenSCManagerError: nil,
returnsOpenServiceError: nil,
returnsServiceQueryError: nil,
returnsServiceCloseError: nil,
returnsSCMgrDisconnectError: nil,
returnsServiceState: svc.Stopped,
}, api.HealthCritical},
{"should fail for stopped service even when there's an error closing the service handle", args{
returnsOpenSCManagerError: nil,
returnsOpenServiceError: nil,
returnsServiceQueryError: nil,
returnsServiceCloseError: errors.New("error while closing the service handle"),
returnsSCMgrDisconnectError: nil,
returnsServiceState: svc.Stopped,
}, api.HealthCritical},
{"should fail for stopped service even when there's an error disconnecting from SCManager", args{
returnsOpenSCManagerError: nil,
returnsOpenServiceError: nil,
returnsServiceQueryError: nil,
returnsServiceCloseError: nil,
returnsSCMgrDisconnectError: errors.New("error while disconnecting from service manager"),
returnsServiceState: svc.Stopped,
}, api.HealthCritical},
}
for _, tt := range tests {
t.Run(tt.desc, func(t *testing.T) {
old := win
defer func() { win = old }()
win = fakeWindowsOS{
returnsOpenSCManagerError: tt.args.returnsOpenSCManagerError,
returnsOpenServiceError: tt.args.returnsOpenServiceError,
returnsServiceQueryError: tt.args.returnsServiceQueryError,
returnsServiceCloseError: tt.args.returnsServiceCloseError,
returnsSCMgrDisconnectError: tt.args.returnsSCMgrDisconnectError,
returnsServiceState: tt.args.returnsServiceState,
}
c, err := NewOSServiceClient()
if (tt.args.returnsOpenSCManagerError != nil && err == nil) || (tt.args.returnsOpenSCManagerError == nil && err != nil) {
t.Errorf("FAIL: %s. Expected error on OpenSCManager %v , but err == %v", tt.desc, tt.args.returnsOpenSCManagerError, err)
}
if err != nil {
return
}
notif, upd := mock.NewNotifyChan()
logger := testutil.Logger(t)
statusHandler := NewStatusHandler(notif, logger, 0, 0, 0)
id := structs.NewCheckID("chk", nil)
check := &CheckOSService{
CheckID: id,
OSService: "testService",
Interval: 25 * time.Millisecond,
Client: c,
Logger: logger,
StatusHandler: statusHandler,
}
check.Start()
defer check.Stop()
<-upd // wait for update
if got, want := notif.State(id), tt.state; got != want {
t.Fatalf("got status %q want %q", got, want)
}
})
}
}
const (
validSCManagerHandle = windows.Handle(1)
validOpenServiceHandle = windows.Handle(2)
)
type fakeWindowsOS struct {
returnsOpenSCManagerError error
returnsOpenServiceError error
returnsServiceQueryError error
returnsServiceCloseError error
returnsSCMgrDisconnectError error
returnsServiceState svc.State
}
func (f fakeWindowsOS) OpenSCManager(machineName *uint16, databaseName *uint16, access uint32) (handle windows.Handle, err error) {
if f.returnsOpenSCManagerError != nil {
return windows.InvalidHandle, f.returnsOpenSCManagerError
}
return validSCManagerHandle, nil
}
func (f fakeWindowsOS) OpenService(mgr windows.Handle, serviceName *uint16, access uint32) (handle windows.Handle, err error) {
if f.returnsOpenServiceError != nil {
return windows.InvalidHandle, f.returnsOpenServiceError
}
return validOpenServiceHandle, nil
}
func (f fakeWindowsOS) getWindowsSvcMgr(h windows.Handle) windowsSvcMgr {
return &fakeWindowsSvcMgr{
Handle: h,
returnsDisconnectError: f.returnsSCMgrDisconnectError,
}
}
func (fakeWindowsOS) getWindowsSvcMgrHandle(sm windowsSvcMgr) windows.Handle {
return sm.(*fakeWindowsSvcMgr).Handle
}
func (f fakeWindowsOS) getWindowsSvc(name string, h windows.Handle) windowsSvc {
return &fakeWindowsSvc{
Name: name,
Handle: h,
returnsCloseError: f.returnsServiceCloseError,
returnsServiceQueryError: f.returnsServiceQueryError,
returnsServiceState: f.returnsServiceState,
}
}
type fakeWindowsSvcMgr struct {
Handle windows.Handle
returnsDisconnectError error
}
func (f fakeWindowsSvcMgr) Disconnect() error { return f.returnsDisconnectError }
type fakeWindowsSvc struct {
Handle windows.Handle
Name string
returnsServiceQueryError error
returnsCloseError error
returnsServiceState svc.State
}
func (f fakeWindowsSvc) Close() error { return f.returnsCloseError }
func (f fakeWindowsSvc) Query() (svc.Status, error) {
if f.returnsServiceQueryError != nil {
return svc.Status{}, f.returnsServiceQueryError
}
return svc.Status{State: f.returnsServiceState}, nil
}
func boolPointer(b bool) *bool {
return &b
}
func boolVal(v *bool) bool {
if v == nil {
return false
}
return *v
}

View File

@ -12,13 +12,17 @@ import (
"golang.org/x/sys/windows/svc/mgr"
)
var (
win windowsSystem = windowsOS{}
)
type OSServiceClient struct {
scHandle windows.Handle
}
func NewOSServiceClient() (*OSServiceClient, error) {
var s *uint16
scHandle, err := windows.OpenSCManager(s, nil, windows.SC_MANAGER_CONNECT)
scHandle, err := win.OpenSCManager(s, nil, windows.SC_MANAGER_CONNECT)
if err != nil {
return nil, fmt.Errorf("error connecting to service manager: %w", err)
@ -29,15 +33,32 @@ func NewOSServiceClient() (*OSServiceClient, error) {
}, nil
}
func (client *OSServiceClient) Check(serviceName string) error {
m := &mgr.Mgr{Handle: client.scHandle}
defer m.Disconnect()
svcHandle, err := windows.OpenService(m.Handle, syscall.StringToUTF16Ptr(serviceName), windows.SC_MANAGER_ENUMERATE_SERVICE)
func (client *OSServiceClient) Check(serviceName string) (err error) {
var isHealthy bool
m := win.getWindowsSvcMgr(client.scHandle)
defer func() {
errDisconnect := m.Disconnect()
if isHealthy || errDisconnect == nil || err != nil {
return
}
//unreachable at the moment but we might want to log this error. leaving here for code-review
err = errDisconnect
}()
svcHandle, err := win.OpenService(win.getWindowsSvcMgrHandle(m), syscall.StringToUTF16Ptr(serviceName), windows.SC_MANAGER_ENUMERATE_SERVICE)
if err != nil {
return fmt.Errorf("error accessing service: %w", err)
}
service := &mgr.Service{Name: serviceName, Handle: svcHandle}
defer service.Close()
service := win.getWindowsSvc(serviceName, svcHandle)
defer func() {
errClose := service.Close()
if isHealthy || errClose == nil || err != nil {
return
}
//unreachable at the moment but we might want to log this error. leaving here for code-review
err = errClose
}()
status, err := service.Query()
if err != nil {
return fmt.Errorf("error querying service status: %w", err)
@ -45,10 +66,49 @@ func (client *OSServiceClient) Check(serviceName string) error {
switch status.State {
case svc.Running:
return nil
case svc.Stopped:
return ErrOSServiceStatusCritical
err = nil
isHealthy = true
case svc.Paused, svc.Stopped:
err = ErrOSServiceStatusCritical
default:
return fmt.Errorf("service status: %v", status.State)
err = fmt.Errorf("service status: %v", status.State)
}
return err
}
type windowsOS struct{}
func (windowsOS) OpenSCManager(machineName *uint16, databaseName *uint16, access uint32) (handle windows.Handle, err error) {
return windows.OpenSCManager(machineName, databaseName, access)
}
func (windowsOS) OpenService(mgr windows.Handle, serviceName *uint16, access uint32) (handle windows.Handle, err error) {
return windows.OpenService(mgr, serviceName, access)
}
func (windowsOS) getWindowsSvcMgr(h windows.Handle) windowsSvcMgr { return &mgr.Mgr{Handle: h} }
func (windowsOS) getWindowsSvcMgrHandle(sm windowsSvcMgr) windows.Handle {
return sm.(*mgr.Mgr).Handle
}
func (windowsOS) getWindowsSvc(name string, h windows.Handle) windowsSvc {
return &mgr.Service{Name: name, Handle: h}
}
type windowsSystem interface {
OpenSCManager(machineName *uint16, databaseName *uint16, access uint32) (handle windows.Handle, err error)
OpenService(mgr windows.Handle, serviceName *uint16, access uint32) (handle windows.Handle, err error)
getWindowsSvcMgr(h windows.Handle) windowsSvcMgr
getWindowsSvcMgrHandle(sm windowsSvcMgr) windows.Handle
getWindowsSvc(name string, h windows.Handle) windowsSvc
}
type windowsSvcMgr interface {
Disconnect() error
}
type windowsSvc interface {
Close() error
Query() (svc.Status, error)
}