diff --git a/client/pluginmanager/drivermanager/instance.go b/client/pluginmanager/drivermanager/instance.go index 40f985c4c..b9581403a 100644 --- a/client/pluginmanager/drivermanager/instance.go +++ b/client/pluginmanager/drivermanager/instance.go @@ -248,10 +248,6 @@ func (i *instanceManager) cleanup() { return } - if internalPlugin, ok := i.plugin.Plugin().(drivers.InternalDriverPlugin); ok { - internalPlugin.Shutdown() - } - if !i.plugin.Exited() { i.plugin.Kill() if err := i.storeReattach(nil); err != nil { diff --git a/devices/gpu/nvidia/cmd/main.go b/devices/gpu/nvidia/cmd/main.go index 1f48a3450..5c0bea6c4 100644 --- a/devices/gpu/nvidia/cmd/main.go +++ b/devices/gpu/nvidia/cmd/main.go @@ -1,6 +1,8 @@ package main import ( + "context" + log "github.com/hashicorp/go-hclog" "github.com/hashicorp/nomad/devices/gpu/nvidia" @@ -9,10 +11,10 @@ import ( func main() { // Serve the plugin - plugins.Serve(factory) + plugins.ServeCtx(factory) } // factory returns a new instance of the Nvidia GPU plugin -func factory(log log.Logger) interface{} { - return nvidia.NewNvidiaDevice(log) +func factory(ctx context.Context, log log.Logger) interface{} { + return nvidia.NewNvidiaDevice(ctx, log) } diff --git a/devices/gpu/nvidia/device.go b/devices/gpu/nvidia/device.go index a4fb82aac..064161cf5 100644 --- a/devices/gpu/nvidia/device.go +++ b/devices/gpu/nvidia/device.go @@ -46,7 +46,7 @@ var ( // PluginConfig is the nvidia factory function registered in the // plugin catalog. PluginConfig = &loader.InternalPluginConfig{ - Factory: func(l log.Logger) interface{} { return NewNvidiaDevice(l) }, + Factory: func(ctx context.Context, l log.Logger) interface{} { return NewNvidiaDevice(ctx, l) }, } // pluginInfo describes the plugin @@ -99,7 +99,7 @@ type NvidiaDevice struct { } // NewNvidiaDevice returns a new nvidia device plugin. -func NewNvidiaDevice(log log.Logger) *NvidiaDevice { +func NewNvidiaDevice(_ context.Context, log log.Logger) *NvidiaDevice { nvmlClient, err := nvml.NewNvmlClient() logger := log.Named(pluginName) if err != nil && err.Error() != nvml.UnavailableLib.Error() { diff --git a/drivers/docker/cmd/main.go b/drivers/docker/cmd/main.go index faa77e5f6..f830f4f86 100644 --- a/drivers/docker/cmd/main.go +++ b/drivers/docker/cmd/main.go @@ -6,6 +6,7 @@ package main import ( + "context" "os" log "github.com/hashicorp/go-hclog" @@ -42,10 +43,10 @@ func main() { } // Serve the plugin - plugins.Serve(factory) + plugins.ServeCtx(factory) } // factory returns a new instance of the docker driver plugin -func factory(log log.Logger) interface{} { - return docker.NewDockerDriver(log) +func factory(ctx context.Context, log log.Logger) interface{} { + return docker.NewDockerDriver(ctx, log) } diff --git a/drivers/docker/config.go b/drivers/docker/config.go index 612d2de68..08be91b8e 100644 --- a/drivers/docker/config.go +++ b/drivers/docker/config.go @@ -1,6 +1,7 @@ package docker import ( + "context" "fmt" "strconv" "strings" @@ -123,7 +124,7 @@ var ( // plugin catalog. PluginConfig = &loader.InternalPluginConfig{ Config: map[string]interface{}{}, - Factory: func(l hclog.Logger) interface{} { return NewDockerDriver(l) }, + Factory: func(ctx context.Context, l hclog.Logger) interface{} { return NewDockerDriver(ctx, l) }, } // pluginInfo is the response returned for the PluginInfo RPC diff --git a/drivers/docker/driver.go b/drivers/docker/driver.go index 291b36d37..35e247ae5 100644 --- a/drivers/docker/driver.go +++ b/drivers/docker/driver.go @@ -90,10 +90,6 @@ type Driver struct { // coordinate shutdown ctx context.Context - // signalShutdown is called when the driver is shutting down and cancels the - // ctx passed to any subsystems - signalShutdown context.CancelFunc - // tasks is the in memory datastore mapping taskIDs to taskHandles tasks *taskStore @@ -120,16 +116,14 @@ type Driver struct { } // NewDockerDriver returns a docker implementation of a driver plugin -func NewDockerDriver(logger hclog.Logger) drivers.DriverPlugin { - ctx, cancel := context.WithCancel(context.Background()) +func NewDockerDriver(ctx context.Context, logger hclog.Logger) drivers.DriverPlugin { logger = logger.Named(pluginName) return &Driver{ - eventer: eventer.NewEventer(ctx, logger), - config: &DriverConfig{}, - tasks: newTaskStore(), - ctx: ctx, - signalShutdown: cancel, - logger: logger, + eventer: eventer.NewEventer(ctx, logger), + config: &DriverConfig{}, + tasks: newTaskStore(), + ctx: ctx, + logger: logger, } } @@ -1622,10 +1616,6 @@ func sliceMergeUlimit(ulimitsRaw map[string]string) ([]docker.ULimit, error) { return ulimits, nil } -func (d *Driver) Shutdown() { - d.signalShutdown() -} - func isDockerTransientError(err error) bool { if err == nil { return false diff --git a/drivers/docker/driver_test.go b/drivers/docker/driver_test.go index eeee19607..16141bbeb 100644 --- a/drivers/docker/driver_test.go +++ b/drivers/docker/driver_test.go @@ -175,7 +175,9 @@ func cleanSlate(client *docker.Client, imageID string) { // A driver plugin interface and cleanup function is returned func dockerDriverHarness(t *testing.T, cfg map[string]interface{}) *dtestutil.DriverHarness { logger := testlog.HCLogger(t) - harness := dtestutil.NewDriverHarness(t, NewDockerDriver(logger)) + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(func() { cancel() }) + harness := dtestutil.NewDriverHarness(t, NewDockerDriver(ctx, logger)) if cfg == nil { cfg = map[string]interface{}{ "gc": map[string]interface{}{ @@ -190,7 +192,7 @@ func dockerDriverHarness(t *testing.T, cfg map[string]interface{}) *dtestutil.Dr InternalPlugins: map[loader.PluginID]*loader.InternalPluginConfig{ PluginID: { Config: cfg, - Factory: func(hclog.Logger) interface{} { + Factory: func(context.Context, hclog.Logger) interface{} { return harness }, }, diff --git a/drivers/docker/fingerprint_test.go b/drivers/docker/fingerprint_test.go index db39a31de..52389dfb2 100644 --- a/drivers/docker/fingerprint_test.go +++ b/drivers/docker/fingerprint_test.go @@ -1,6 +1,7 @@ package docker import ( + "context" "testing" "github.com/hashicorp/nomad/client/testutil" @@ -20,7 +21,10 @@ func TestDockerDriver_FingerprintHealth(t *testing.T) { } testutil.DockerCompatible(t) - d := NewDockerDriver(testlog.HCLogger(t)).(*Driver) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + d := NewDockerDriver(ctx, testlog.HCLogger(t)).(*Driver) fp := d.buildFingerprint() require.Equal(t, drivers.HealthStateHealthy, fp.Health) diff --git a/drivers/exec/driver.go b/drivers/exec/driver.go index 850daa01e..22094216d 100644 --- a/drivers/exec/driver.go +++ b/drivers/exec/driver.go @@ -47,7 +47,7 @@ var ( // plugin catalog. PluginConfig = &loader.InternalPluginConfig{ Config: map[string]interface{}{}, - Factory: func(l hclog.Logger) interface{} { return NewExecDriver(l) }, + Factory: func(ctx context.Context, l hclog.Logger) interface{} { return NewExecDriver(ctx, l) }, } // pluginInfo is the response returned for the PluginInfo RPC @@ -107,10 +107,6 @@ type Driver struct { // coordinate shutdown ctx context.Context - // signalShutdown is called when the driver is shutting down and cancels the - // ctx passed to any subsystems - signalShutdown context.CancelFunc - // logger will log to the Nomad agent logger hclog.Logger @@ -144,15 +140,13 @@ type TaskState struct { } // NewExecDriver returns a new DrivePlugin implementation -func NewExecDriver(logger hclog.Logger) drivers.DriverPlugin { - ctx, cancel := context.WithCancel(context.Background()) +func NewExecDriver(ctx context.Context, logger hclog.Logger) drivers.DriverPlugin { logger = logger.Named(pluginName) return &Driver{ - eventer: eventer.NewEventer(ctx, logger), - tasks: newTaskStore(), - ctx: ctx, - signalShutdown: cancel, - logger: logger, + eventer: eventer.NewEventer(ctx, logger), + tasks: newTaskStore(), + ctx: ctx, + logger: logger, } } @@ -201,10 +195,6 @@ func (d *Driver) SetConfig(cfg *base.Config) error { return nil } -func (d *Driver) Shutdown() { - d.signalShutdown() -} - func (d *Driver) TaskConfigSchema() (*hclspec.Spec, error) { return taskConfigSpec, nil } diff --git a/drivers/exec/driver_test.go b/drivers/exec/driver_test.go index de1138e32..685e63e66 100644 --- a/drivers/exec/driver_test.go +++ b/drivers/exec/driver_test.go @@ -59,7 +59,10 @@ func TestExecDriver_Fingerprint_NonLinux(t *testing.T) { t.Skip("Test only available not on Linux") } - d := NewExecDriver(testlog.HCLogger(t)) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + d := NewExecDriver(ctx, testlog.HCLogger(t)) harness := dtestutil.NewDriverHarness(t, d) fingerCh, err := harness.Fingerprint(context.Background()) @@ -78,7 +81,10 @@ func TestExecDriver_Fingerprint(t *testing.T) { ctestutils.ExecCompatible(t) - d := NewExecDriver(testlog.HCLogger(t)) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + d := NewExecDriver(ctx, testlog.HCLogger(t)) harness := dtestutil.NewDriverHarness(t, d) fingerCh, err := harness.Fingerprint(context.Background()) @@ -97,7 +103,10 @@ func TestExecDriver_StartWait(t *testing.T) { require := require.New(t) ctestutils.ExecCompatible(t) - d := NewExecDriver(testlog.HCLogger(t)) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + d := NewExecDriver(ctx, testlog.HCLogger(t)) harness := dtestutil.NewDriverHarness(t, d) task := &drivers.TaskConfig{ ID: uuid.Generate(), @@ -129,7 +138,10 @@ func TestExecDriver_StartWaitStopKill(t *testing.T) { require := require.New(t) ctestutils.ExecCompatible(t) - d := NewExecDriver(testlog.HCLogger(t)) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + d := NewExecDriver(ctx, testlog.HCLogger(t)) harness := dtestutil.NewDriverHarness(t, d) task := &drivers.TaskConfig{ ID: uuid.Generate(), @@ -190,7 +202,10 @@ func TestExecDriver_StartWaitRecover(t *testing.T) { require := require.New(t) ctestutils.ExecCompatible(t) - d := NewExecDriver(testlog.HCLogger(t)) + dctx, dcancel := context.WithCancel(context.Background()) + defer dcancel() + + d := NewExecDriver(dctx, testlog.HCLogger(t)) harness := dtestutil.NewDriverHarness(t, d) task := &drivers.TaskConfig{ ID: uuid.Generate(), @@ -262,7 +277,10 @@ func TestExecDriver_DestroyKillsAll(t *testing.T) { require := require.New(t) ctestutils.ExecCompatible(t) - d := NewExecDriver(testlog.HCLogger(t)) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + d := NewExecDriver(ctx, testlog.HCLogger(t)) harness := dtestutil.NewDriverHarness(t, d) defer harness.Kill() @@ -360,7 +378,10 @@ func TestExecDriver_Stats(t *testing.T) { require := require.New(t) ctestutils.ExecCompatible(t) - d := NewExecDriver(testlog.HCLogger(t)) + dctx, dcancel := context.WithCancel(context.Background()) + defer dcancel() + + d := NewExecDriver(dctx, testlog.HCLogger(t)) harness := dtestutil.NewDriverHarness(t, d) task := &drivers.TaskConfig{ ID: uuid.Generate(), @@ -403,7 +424,10 @@ func TestExecDriver_Start_Wait_AllocDir(t *testing.T) { require := require.New(t) ctestutils.ExecCompatible(t) - d := NewExecDriver(testlog.HCLogger(t)) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + d := NewExecDriver(ctx, testlog.HCLogger(t)) harness := dtestutil.NewDriverHarness(t, d) task := &drivers.TaskConfig{ ID: uuid.Generate(), @@ -452,7 +476,10 @@ func TestExecDriver_User(t *testing.T) { require := require.New(t) ctestutils.ExecCompatible(t) - d := NewExecDriver(testlog.HCLogger(t)) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + d := NewExecDriver(ctx, testlog.HCLogger(t)) harness := dtestutil.NewDriverHarness(t, d) task := &drivers.TaskConfig{ ID: uuid.Generate(), @@ -486,7 +513,10 @@ func TestExecDriver_HandlerExec(t *testing.T) { require := require.New(t) ctestutils.ExecCompatible(t) - d := NewExecDriver(testlog.HCLogger(t)) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + d := NewExecDriver(ctx, testlog.HCLogger(t)) harness := dtestutil.NewDriverHarness(t, d) task := &drivers.TaskConfig{ ID: uuid.Generate(), @@ -574,7 +604,10 @@ func TestExecDriver_DevicesAndMounts(t *testing.T) { err = ioutil.WriteFile(filepath.Join(tmpDir, "testfile"), []byte("from-host"), 600) require.NoError(err) - d := NewExecDriver(testlog.HCLogger(t)) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + d := NewExecDriver(ctx, testlog.HCLogger(t)) harness := dtestutil.NewDriverHarness(t, d) task := &drivers.TaskConfig{ ID: uuid.Generate(), @@ -678,7 +711,10 @@ func TestExecDriver_NoPivotRoot(t *testing.T) { require := require.New(t) ctestutils.ExecCompatible(t) - d := NewExecDriver(testlog.HCLogger(t)) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + d := NewExecDriver(ctx, testlog.HCLogger(t)) harness := dtestutil.NewDriverHarness(t, d) config := &Config{NoPivotRoot: true} diff --git a/drivers/exec/driver_unix_test.go b/drivers/exec/driver_unix_test.go index 342993a8f..5ff063745 100644 --- a/drivers/exec/driver_unix_test.go +++ b/drivers/exec/driver_unix_test.go @@ -23,7 +23,10 @@ func TestExecDriver_StartWaitStop(t *testing.T) { require := require.New(t) ctestutils.ExecCompatible(t) - d := NewExecDriver(testlog.HCLogger(t)) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + d := NewExecDriver(ctx, testlog.HCLogger(t)) harness := dtestutil.NewDriverHarness(t, d) task := &drivers.TaskConfig{ ID: uuid.Generate(), @@ -82,7 +85,10 @@ func TestExec_ExecTaskStreaming(t *testing.T) { t.Parallel() require := require.New(t) - d := NewExecDriver(testlog.HCLogger(t)) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + d := NewExecDriver(ctx, testlog.HCLogger(t)) harness := dtestutil.NewDriverHarness(t, d) defer harness.Kill() diff --git a/drivers/java/driver.go b/drivers/java/driver.go index 04a820e21..f8aa408f6 100644 --- a/drivers/java/driver.go +++ b/drivers/java/driver.go @@ -50,7 +50,7 @@ var ( // plugin catalog. PluginConfig = &loader.InternalPluginConfig{ Config: map[string]interface{}{}, - Factory: func(l hclog.Logger) interface{} { return NewDriver(l) }, + Factory: func(ctx context.Context, l hclog.Logger) interface{} { return NewDriver(ctx, l) }, } // pluginInfo is the response returned for the PluginInfo RPC @@ -135,23 +135,17 @@ type Driver struct { // nomadConf is the client agent's configuration nomadConfig *base.ClientDriverConfig - // signalShutdown is called when the driver is shutting down and cancels the - // ctx passed to any subsystems - signalShutdown context.CancelFunc - // logger will log to the Nomad agent logger hclog.Logger } -func NewDriver(logger hclog.Logger) drivers.DriverPlugin { - ctx, cancel := context.WithCancel(context.Background()) +func NewDriver(ctx context.Context, logger hclog.Logger) drivers.DriverPlugin { logger = logger.Named(pluginName) return &Driver{ - eventer: eventer.NewEventer(ctx, logger), - tasks: newTaskStore(), - ctx: ctx, - signalShutdown: cancel, - logger: logger, + eventer: eventer.NewEventer(ctx, logger), + tasks: newTaskStore(), + ctx: ctx, + logger: logger, } } @@ -589,7 +583,3 @@ func GetAbsolutePath(bin string) (string, error) { return filepath.EvalSymlinks(lp) } - -func (d *Driver) Shutdown() { - d.signalShutdown() -} diff --git a/drivers/java/driver_test.go b/drivers/java/driver_test.go index b4b3d0010..4aed991f2 100644 --- a/drivers/java/driver_test.go +++ b/drivers/java/driver_test.go @@ -38,7 +38,10 @@ func TestJavaDriver_Fingerprint(t *testing.T) { t.Parallel() } - d := NewDriver(testlog.HCLogger(t)) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + d := NewDriver(ctx, testlog.HCLogger(t)) harness := dtestutil.NewDriverHarness(t, d) fpCh, err := harness.Fingerprint(context.Background()) @@ -61,7 +64,10 @@ func TestJavaDriver_Jar_Start_Wait(t *testing.T) { } require := require.New(t) - d := NewDriver(testlog.HCLogger(t)) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + d := NewDriver(ctx, testlog.HCLogger(t)) harness := dtestutil.NewDriverHarness(t, d) tc := &TaskConfig{ @@ -101,7 +107,10 @@ func TestJavaDriver_Jar_Stop_Wait(t *testing.T) { } require := require.New(t) - d := NewDriver(testlog.HCLogger(t)) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + d := NewDriver(ctx, testlog.HCLogger(t)) harness := dtestutil.NewDriverHarness(t, d) tc := &TaskConfig{ @@ -162,7 +171,10 @@ func TestJavaDriver_Class_Start_Wait(t *testing.T) { } require := require.New(t) - d := NewDriver(testlog.HCLogger(t)) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + d := NewDriver(ctx, testlog.HCLogger(t)) harness := dtestutil.NewDriverHarness(t, d) tc := &TaskConfig{ @@ -250,7 +262,10 @@ func TestJavaDriver_ExecTaskStreaming(t *testing.T) { } require := require.New(t) - d := NewDriver(testlog.HCLogger(t)) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + d := NewDriver(ctx, testlog.HCLogger(t)) harness := dtestutil.NewDriverHarness(t, d) defer harness.Kill() diff --git a/drivers/mock/driver.go b/drivers/mock/driver.go index 80c91182a..604ac51ec 100644 --- a/drivers/mock/driver.go +++ b/drivers/mock/driver.go @@ -45,7 +45,7 @@ var ( // plugin catalog. PluginConfig = &loader.InternalPluginConfig{ Config: map[string]interface{}{}, - Factory: func(l hclog.Logger) interface{} { return NewMockDriver(l) }, + Factory: func(ctx context.Context, l hclog.Logger) interface{} { return NewMockDriver(ctx, l) }, } // pluginInfo is the response returned for the PluginInfo RPC @@ -129,10 +129,6 @@ type Driver struct { // coordinate shutdown ctx context.Context - // signalShutdown is called when the driver is shutting down and cancels the - // ctx passed to any subsystems - signalShutdown context.CancelFunc - shutdownFingerprintTime time.Time // lastDriverTaskConfig is the last *drivers.TaskConfig passed to StartTask @@ -149,8 +145,7 @@ type Driver struct { } // NewMockDriver returns a new DriverPlugin implementation -func NewMockDriver(logger hclog.Logger) drivers.DriverPlugin { - ctx, cancel := context.WithCancel(context.Background()) +func NewMockDriver(ctx context.Context, logger hclog.Logger) drivers.DriverPlugin { logger = logger.Named(pluginName) capabilities := &drivers.Capabilities{ @@ -161,13 +156,12 @@ func NewMockDriver(logger hclog.Logger) drivers.DriverPlugin { } return &Driver{ - eventer: eventer.NewEventer(ctx, logger), - capabilities: capabilities, - config: &Config{}, - tasks: newTaskStore(), - ctx: ctx, - signalShutdown: cancel, - logger: logger, + eventer: eventer.NewEventer(ctx, logger), + capabilities: capabilities, + config: &Config{}, + tasks: newTaskStore(), + ctx: ctx, + logger: logger, } } @@ -676,10 +670,6 @@ func (d *Driver) GetHandle(taskID string) *taskHandle { return h } -func (d *Driver) Shutdown() { - d.signalShutdown() -} - func (d *Driver) CreateNetwork(allocID string) (*drivers.NetworkIsolationSpec, error) { return nil, nil } diff --git a/drivers/qemu/driver.go b/drivers/qemu/driver.go index 97f81b3de..24805d0b2 100644 --- a/drivers/qemu/driver.go +++ b/drivers/qemu/driver.go @@ -61,7 +61,7 @@ var ( // plugin catalog. PluginConfig = &loader.InternalPluginConfig{ Config: map[string]interface{}{}, - Factory: func(l hclog.Logger) interface{} { return NewQemuDriver(l) }, + Factory: func(ctx context.Context, l hclog.Logger) interface{} { return NewQemuDriver(ctx, l) }, } versionRegex = regexp.MustCompile(`version (\d[\.\d+]+)`) @@ -142,23 +142,17 @@ type Driver struct { // nomadConf is the client agent's configuration nomadConfig *base.ClientDriverConfig - // signalShutdown is called when the driver is shutting down and cancels the - // ctx passed to any subsystems - signalShutdown context.CancelFunc - // logger will log to the Nomad agent logger hclog.Logger } -func NewQemuDriver(logger hclog.Logger) drivers.DriverPlugin { - ctx, cancel := context.WithCancel(context.Background()) +func NewQemuDriver(ctx context.Context, logger hclog.Logger) drivers.DriverPlugin { logger = logger.Named(pluginName) return &Driver{ - eventer: eventer.NewEventer(ctx, logger), - tasks: newTaskStore(), - ctx: ctx, - signalShutdown: cancel, - logger: logger, + eventer: eventer.NewEventer(ctx, logger), + tasks: newTaskStore(), + ctx: ctx, + logger: logger, } } @@ -653,7 +647,3 @@ func sendQemuShutdown(logger hclog.Logger, monitorPath string, userPid int) erro } return err } - -func (d *Driver) Shutdown() { - d.signalShutdown() -} diff --git a/drivers/qemu/driver_test.go b/drivers/qemu/driver_test.go index a1f11c60c..c28a22260 100644 --- a/drivers/qemu/driver_test.go +++ b/drivers/qemu/driver_test.go @@ -32,7 +32,10 @@ func TestQemuDriver_Start_Wait_Stop(t *testing.T) { } require := require.New(t) - d := NewQemuDriver(testlog.HCLogger(t)) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + d := NewQemuDriver(ctx, testlog.HCLogger(t)) harness := dtestutil.NewDriverHarness(t, d) task := &drivers.TaskConfig{ @@ -94,7 +97,10 @@ func TestQemuDriver_GetMonitorPathOldQemu(t *testing.T) { } require := require.New(t) - d := NewQemuDriver(testlog.HCLogger(t)) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + d := NewQemuDriver(ctx, testlog.HCLogger(t)) harness := dtestutil.NewDriverHarness(t, d) task := &drivers.TaskConfig{ @@ -149,7 +155,10 @@ func TestQemuDriver_GetMonitorPathNewQemu(t *testing.T) { } require := require.New(t) - d := NewQemuDriver(testlog.HCLogger(t)) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + d := NewQemuDriver(ctx, testlog.HCLogger(t)) harness := dtestutil.NewDriverHarness(t, d) task := &drivers.TaskConfig{ @@ -229,7 +238,10 @@ func TestQemuDriver_User(t *testing.T) { } require := require.New(t) - d := NewQemuDriver(testlog.HCLogger(t)) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + d := NewQemuDriver(ctx, testlog.HCLogger(t)) harness := dtestutil.NewDriverHarness(t, d) task := &drivers.TaskConfig{ @@ -286,7 +298,10 @@ func TestQemuDriver_Stats(t *testing.T) { } require := require.New(t) - d := NewQemuDriver(testlog.HCLogger(t)) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + d := NewQemuDriver(ctx, testlog.HCLogger(t)) harness := dtestutil.NewDriverHarness(t, d) task := &drivers.TaskConfig{ @@ -363,7 +378,10 @@ func TestQemuDriver_Fingerprint(t *testing.T) { t.Parallel() } - d := NewQemuDriver(testlog.HCLogger(t)) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + d := NewQemuDriver(ctx, testlog.HCLogger(t)) harness := dtestutil.NewDriverHarness(t, d) fingerCh, err := harness.Fingerprint(context.Background()) diff --git a/drivers/rawexec/driver.go b/drivers/rawexec/driver.go index 1c23e49a9..8a86c75c4 100644 --- a/drivers/rawexec/driver.go +++ b/drivers/rawexec/driver.go @@ -45,7 +45,7 @@ var ( // plugin catalog. PluginConfig = &loader.InternalPluginConfig{ Config: map[string]interface{}{}, - Factory: func(l hclog.Logger) interface{} { return NewRawExecDriver(l) }, + Factory: func(ctx context.Context, l hclog.Logger) interface{} { return NewRawExecDriver(ctx, l) }, } errDisabledDriver = fmt.Errorf("raw_exec is disabled") @@ -126,10 +126,6 @@ type Driver struct { // coordinate shutdown ctx context.Context - // signalShutdown is called when the driver is shutting down and cancels the - // ctx passed to any subsystems - signalShutdown context.CancelFunc - // logger will log to the Nomad agent logger hclog.Logger } @@ -161,16 +157,14 @@ type TaskState struct { } // NewRawExecDriver returns a new DriverPlugin implementation -func NewRawExecDriver(logger hclog.Logger) drivers.DriverPlugin { - ctx, cancel := context.WithCancel(context.Background()) +func NewRawExecDriver(ctx context.Context, logger hclog.Logger) drivers.DriverPlugin { logger = logger.Named(pluginName) return &Driver{ - eventer: eventer.NewEventer(ctx, logger), - config: &Config{}, - tasks: newTaskStore(), - ctx: ctx, - signalShutdown: cancel, - logger: logger, + eventer: eventer.NewEventer(ctx, logger), + config: &Config{}, + tasks: newTaskStore(), + ctx: ctx, + logger: logger, } } @@ -197,10 +191,6 @@ func (d *Driver) SetConfig(cfg *base.Config) error { return nil } -func (d *Driver) Shutdown() { - d.signalShutdown() -} - func (d *Driver) TaskConfigSchema() (*hclspec.Spec, error) { return taskConfigSpec, nil } diff --git a/drivers/rawexec/driver_test.go b/drivers/rawexec/driver_test.go index 8028edfd6..d6bf2f0bf 100644 --- a/drivers/rawexec/driver_test.go +++ b/drivers/rawexec/driver_test.go @@ -33,7 +33,10 @@ func TestMain(m *testing.M) { } func newEnabledRawExecDriver(t *testing.T) *Driver { - d := NewRawExecDriver(testlog.HCLogger(t)).(*Driver) + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(func() { cancel() }) + + d := NewRawExecDriver(ctx, testlog.HCLogger(t)).(*Driver) d.config.Enabled = true return d } @@ -42,7 +45,10 @@ func TestRawExecDriver_SetConfig(t *testing.T) { t.Parallel() require := require.New(t) - d := NewRawExecDriver(testlog.HCLogger(t)) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + d := NewRawExecDriver(ctx, testlog.HCLogger(t)) harness := dtestutil.NewDriverHarness(t, d) defer harness.Kill() diff --git a/helper/pluginutils/loader/init.go b/helper/pluginutils/loader/init.go index babecd50e..e9e2cc3a9 100644 --- a/helper/pluginutils/loader/init.go +++ b/helper/pluginutils/loader/init.go @@ -1,6 +1,7 @@ package loader import ( + "context" "fmt" "os" "os/exec" @@ -85,11 +86,14 @@ func (l *PluginLoader) init(config *PluginLoaderConfig) error { // initInternal initializes internal plugins. func (l *PluginLoader) initInternal(plugins map[PluginID]*InternalPluginConfig, configs map[string]*config.PluginConfig) (map[PluginID]*pluginInfo, error) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + var mErr multierror.Error fingerprinted := make(map[PluginID]*pluginInfo, len(plugins)) for k, config := range plugins { // Create an instance - raw := config.Factory(l.logger) + raw := config.Factory(ctx, l.logger) base, ok := raw.(base.BasePlugin) if !ok { multierror.Append(&mErr, fmt.Errorf("internal plugin %s doesn't meet base plugin interface", k)) diff --git a/helper/pluginutils/loader/instance.go b/helper/pluginutils/loader/instance.go index d20bff4a2..e396eb45a 100644 --- a/helper/pluginutils/loader/instance.go +++ b/helper/pluginutils/loader/instance.go @@ -31,10 +31,12 @@ type PluginInstance interface { type internalPluginInstance struct { instance interface{} apiVersion string + killFn func() } -func (p *internalPluginInstance) Internal() bool { return true } -func (p *internalPluginInstance) Kill() {} +func (p *internalPluginInstance) Internal() bool { return true } +func (p *internalPluginInstance) Kill() { p.killFn() } + func (p *internalPluginInstance) ReattachConfig() (*plugin.ReattachConfig, bool) { return nil, false } func (p *internalPluginInstance) Plugin() interface{} { return p.instance } func (p *internalPluginInstance) Exited() bool { return false } diff --git a/helper/pluginutils/loader/loader.go b/helper/pluginutils/loader/loader.go index 34ffe2e4e..01644cb12 100644 --- a/helper/pluginutils/loader/loader.go +++ b/helper/pluginutils/loader/loader.go @@ -1,6 +1,7 @@ package loader import ( + "context" "fmt" "os/exec" @@ -31,7 +32,7 @@ type PluginCatalog interface { // InternalPluginConfig is used to configure launching an internal plugin. type InternalPluginConfig struct { Config map[string]interface{} - Factory plugins.PluginFactory + Factory plugins.PluginCtxFactory } // PluginID is a tuple identifying a plugin @@ -92,7 +93,7 @@ type PluginLoader struct { // pluginInfo captures the necessary information to launch and configure a // plugin. type pluginInfo struct { - factory plugins.PluginFactory + factory plugins.PluginCtxFactory exePath string args []string @@ -153,9 +154,11 @@ func (l *PluginLoader) Dispense(name, pluginType string, config *base.AgentConfi // If the plugin is internal, launch via the factory var instance PluginInstance if pinfo.factory != nil { + ctx, cancel := context.WithCancel(context.Background()) instance = &internalPluginInstance{ - instance: pinfo.factory(logger), + instance: pinfo.factory(ctx, logger), apiVersion: pinfo.apiVersion, + killFn: cancel, } } else { var err error diff --git a/helper/pluginutils/loader/plugin_test.go b/helper/pluginutils/loader/plugin_test.go index d619dd249..69bd9fc85 100644 --- a/helper/pluginutils/loader/plugin_test.go +++ b/helper/pluginutils/loader/plugin_test.go @@ -93,8 +93,8 @@ func pluginMain(name, pluginType, version string, apiVersions []string, config b // mockFactory returns a PluginFactory method which creates the mock plugin with // the passed parameters -func mockFactory(name, ptype, version string, apiVersions []string, configSchema bool) func(log log.Logger) interface{} { - return func(log log.Logger) interface{} { +func mockFactory(name, ptype, version string, apiVersions []string, configSchema bool) func(context.Context, log.Logger) interface{} { + return func(ctx context.Context, log log.Logger) interface{} { return &mockPlugin{ name: name, ptype: ptype, diff --git a/nomad/client_csi_endpoint.go b/nomad/client_csi_endpoint.go index eb0a4262a..8dfe85e55 100644 --- a/nomad/client_csi_endpoint.go +++ b/nomad/client_csi_endpoint.go @@ -136,7 +136,7 @@ func (a *ClientCSI) nodeForController(pluginID, nodeID string) (string, error) { return nodeID, nil } else { // we'll fall-through and select a node at random - a.logger.Trace("%s could not be used for client RPC: %v", nodeID, err) + a.logger.Trace("could not be used for client RPC", "node", nodeID, "error", err) } } diff --git a/nomad/deploymentwatcher/batcher.go b/nomad/deploymentwatcher/batcher.go index 3d0f34eb9..c3b6fcfc6 100644 --- a/nomad/deploymentwatcher/batcher.go +++ b/nomad/deploymentwatcher/batcher.go @@ -26,7 +26,7 @@ type AllocUpdateBatcher struct { // NewAllocUpdateBatcher returns an AllocUpdateBatcher that uses the passed raft endpoints to // create the allocation desired transition updates and new evaluations and // exits the batcher when the passed exit channel is closed. -func NewAllocUpdateBatcher(batchDuration time.Duration, raft DeploymentRaftEndpoints, ctx context.Context) *AllocUpdateBatcher { +func NewAllocUpdateBatcher(ctx context.Context, batchDuration time.Duration, raft DeploymentRaftEndpoints) *AllocUpdateBatcher { b := &AllocUpdateBatcher{ batch: batchDuration, raft: raft, diff --git a/nomad/deploymentwatcher/deployments_watcher.go b/nomad/deploymentwatcher/deployments_watcher.go index 494d5c16d..7cfbb94ab 100644 --- a/nomad/deploymentwatcher/deployments_watcher.go +++ b/nomad/deploymentwatcher/deployments_watcher.go @@ -118,7 +118,7 @@ func (w *Watcher) SetEnabled(enabled bool, state *state.StateStore) { } // Flush the state to create the necessary objects - w.flush() + w.flush(enabled) // If we are starting now, launch the watch daemon if enabled && !wasEnabled { @@ -127,7 +127,7 @@ func (w *Watcher) SetEnabled(enabled bool, state *state.StateStore) { } // flush is used to clear the state of the watcher -func (w *Watcher) flush() { +func (w *Watcher) flush(enabled bool) { // Stop all the watchers and clear it for _, watcher := range w.watchers { watcher.StopWatch() @@ -140,7 +140,12 @@ func (w *Watcher) flush() { w.watchers = make(map[string]*deploymentWatcher, 32) w.ctx, w.exitFn = context.WithCancel(context.Background()) - w.allocUpdateBatcher = NewAllocUpdateBatcher(w.updateBatchDuration, w.raft, w.ctx) + + if enabled { + w.allocUpdateBatcher = NewAllocUpdateBatcher(w.ctx, w.updateBatchDuration, w.raft) + } else { + w.allocUpdateBatcher = nil + } } // watchDeployments is the long lived go-routine that watches for deployments to @@ -361,7 +366,11 @@ func (w *Watcher) FailDeployment(req *structs.DeploymentFailRequest, resp *struc // createUpdate commits the given allocation desired transition and evaluation // to Raft but batches the commit with other calls. func (w *Watcher) createUpdate(allocs map[string]*structs.DesiredTransition, eval *structs.Evaluation) (uint64, error) { - return w.allocUpdateBatcher.CreateUpdate(allocs, eval).Results() + b := w.allocUpdateBatcher + if b == nil { + return 0, notEnabled + } + return b.CreateUpdate(allocs, eval).Results() } // upsertJob commits the given job to Raft diff --git a/nomad/leader.go b/nomad/leader.go index 18c36ef38..0822c4e05 100644 --- a/nomad/leader.go +++ b/nomad/leader.go @@ -108,6 +108,9 @@ func (s *Server) monitorLeadership() { s.logger.Warn("cluster leadership gained and lost leadership immediately. Could indicate network issues, memory paging, or high CPU load.") } case <-s.shutdownCh: + if weAreLeaderCh != nil { + leaderStep(false) + } return } } diff --git a/nomad/volumewatcher/batcher.go b/nomad/volumewatcher/batcher.go index e9eb74abe..0728b806d 100644 --- a/nomad/volumewatcher/batcher.go +++ b/nomad/volumewatcher/batcher.go @@ -29,7 +29,7 @@ type VolumeUpdateBatcher struct { // NewVolumeUpdateBatcher returns an VolumeUpdateBatcher that uses the // passed raft endpoints to create the updates to volume claims, and // exits the batcher when the passed exit channel is closed. -func NewVolumeUpdateBatcher(batchDuration time.Duration, raft VolumeRaftEndpoints, ctx context.Context) *VolumeUpdateBatcher { +func NewVolumeUpdateBatcher(ctx context.Context, batchDuration time.Duration, raft VolumeRaftEndpoints) *VolumeUpdateBatcher { b := &VolumeUpdateBatcher{ batchDuration: batchDuration, raft: raft, diff --git a/nomad/volumewatcher/batcher_test.go b/nomad/volumewatcher/batcher_test.go index 930342cdc..fd6773c43 100644 --- a/nomad/volumewatcher/batcher_test.go +++ b/nomad/volumewatcher/batcher_test.go @@ -23,7 +23,7 @@ func TestVolumeWatch_Batcher(t *testing.T) { srv := &MockBatchingRPCServer{} srv.state = state.TestStateStore(t) - srv.volumeUpdateBatcher = NewVolumeUpdateBatcher(CrossVolumeUpdateBatchDuration, srv, ctx) + srv.volumeUpdateBatcher = NewVolumeUpdateBatcher(ctx, CrossVolumeUpdateBatchDuration, srv) srv.nextCSIControllerDetachError = fmt.Errorf("some controller plugin error") plugin := mock.CSIPlugin() diff --git a/nomad/volumewatcher/volumes_watcher.go b/nomad/volumewatcher/volumes_watcher.go index 63446c461..509f49f86 100644 --- a/nomad/volumewatcher/volumes_watcher.go +++ b/nomad/volumewatcher/volumes_watcher.go @@ -2,6 +2,7 @@ package volumewatcher import ( "context" + "errors" "sync" "time" @@ -100,7 +101,7 @@ func (w *Watcher) SetEnabled(enabled bool, state *state.StateStore) { } // Flush the state to create the necessary objects - w.flush() + w.flush(enabled) // If we are starting now, launch the watch daemon if enabled && !wasEnabled { @@ -109,7 +110,7 @@ func (w *Watcher) SetEnabled(enabled bool, state *state.StateStore) { } // flush is used to clear the state of the watcher -func (w *Watcher) flush() { +func (w *Watcher) flush(enabled bool) { // Stop all the watchers and clear it for _, watcher := range w.watchers { watcher.Stop() @@ -122,7 +123,12 @@ func (w *Watcher) flush() { w.watchers = make(map[string]*volumeWatcher, 32) w.ctx, w.exitFn = context.WithCancel(context.Background()) - w.volumeUpdateBatcher = NewVolumeUpdateBatcher(w.updateBatchDuration, w.raft, w.ctx) + + if enabled { + w.volumeUpdateBatcher = NewVolumeUpdateBatcher(w.ctx, w.updateBatchDuration, w.raft) + } else { + w.volumeUpdateBatcher = nil + } } // watchVolumes is the long lived go-routine that watches for volumes to @@ -228,5 +234,9 @@ func (w *Watcher) removeLocked(volID, namespace string) { // updatesClaims sends the claims to the batch updater and waits for // the results func (w *Watcher) updateClaims(claims []structs.CSIVolumeClaimRequest) (uint64, error) { - return w.volumeUpdateBatcher.CreateUpdate(claims).Results() + b := w.volumeUpdateBatcher + if b == nil { + return 0, errors.New("volume watcher is not enabled") + } + return b.CreateUpdate(claims).Results() } diff --git a/nomad/volumewatcher/volumes_watcher_test.go b/nomad/volumewatcher/volumes_watcher_test.go index 3d00fa13b..53f97fe30 100644 --- a/nomad/volumewatcher/volumes_watcher_test.go +++ b/nomad/volumewatcher/volumes_watcher_test.go @@ -110,7 +110,7 @@ func TestVolumeWatch_StartStop(t *testing.T) { srv.state = state.TestStateStore(t) index := uint64(100) srv.volumeUpdateBatcher = NewVolumeUpdateBatcher( - CrossVolumeUpdateBatchDuration, srv, ctx) + ctx, CrossVolumeUpdateBatchDuration, srv) watcher := NewVolumesWatcher(testlog.HCLogger(t), srv, srv, @@ -261,7 +261,7 @@ func TestVolumeWatch_RegisterDeregister(t *testing.T) { srv := &MockStatefulRPCServer{} srv.state = state.TestStateStore(t) srv.volumeUpdateBatcher = NewVolumeUpdateBatcher( - CrossVolumeUpdateBatchDuration, srv, ctx) + ctx, CrossVolumeUpdateBatchDuration, srv) index := uint64(100) diff --git a/plugins/drivers/driver.go b/plugins/drivers/driver.go index 11ca6446d..89e306172 100644 --- a/plugins/drivers/driver.go +++ b/plugins/drivers/driver.go @@ -87,14 +87,6 @@ type DriverNetworkManager interface { DestroyNetwork(allocID string, spec *NetworkIsolationSpec) error } -// InternalDriverPlugin is an interface that exposes functions that are only -// implemented by internal driver plugins. -type InternalDriverPlugin interface { - // Shutdown allows the plugin to cleanup any running state to avoid leaking - // resources. It should not block. - Shutdown() -} - // DriverSignalTaskNotSupported can be embedded by drivers which don't support // the SignalTask RPC. This satisfies the SignalTask func requirement for the // DriverPlugin interface. diff --git a/plugins/serve.go b/plugins/serve.go index f317f52c8..63dba8c39 100644 --- a/plugins/serve.go +++ b/plugins/serve.go @@ -1,6 +1,7 @@ package plugins import ( + "context" "fmt" log "github.com/hashicorp/go-hclog" @@ -11,6 +12,9 @@ import ( // PluginFactory returns a new plugin instance type PluginFactory func(log log.Logger) interface{} +// PluginFactory returns a new plugin instance, that takes in a context +type PluginCtxFactory func(ctx context.Context, log log.Logger) interface{} + // Serve is used to serve a new Nomad plugin func Serve(f PluginFactory) { logger := log.New(&log.LoggerOptions{ @@ -19,6 +23,23 @@ func Serve(f PluginFactory) { }) plugin := f(logger) + serve(plugin, logger) +} + +// Serve is used to serve a new Nomad plugin +func ServeCtx(f PluginCtxFactory) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + logger := log.New(&log.LoggerOptions{ + Level: log.Trace, + JSONFormat: true, + }) + + plugin := f(ctx, logger) + serve(plugin, logger) +} +func serve(plugin interface{}, logger log.Logger) { switch p := plugin.(type) { case device.DevicePlugin: device.Serve(p, logger)