testing: improve fidelity of mock driver task restore (#16990)

While working on client status update improvements, I encountered problems
getting tests with the mock driver to correctly restore.

Unlike typical drivers the mock driver doesn't have an external source of truth
for whether the task is running (ex. making API calls to `dockerd` or looking
for a running PID), and so in order to make up that information, it re-parses
the original task config. But the taskrunner doesn't call the encoding step for
`RecoverTask`, only `StartTask`, so the task config the mock driver gets is
missing data.

Update the mock driver to stash the "external" state in the task state that
we'll get from the task runner, so that we don't have to try to recover from the
original `TaskConfig` anymore. This should bring the mock driver closer to the
behavior of the other drivers.
This commit is contained in:
Tim Gross 2023-04-27 11:54:10 -04:00 committed by GitHub
parent fddef4c6e1
commit 87f416943c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 233 additions and 29 deletions

View File

@ -266,6 +266,15 @@ type TaskConfig struct {
type MockTaskState struct {
StartedAt time.Time
// these are not strictly "state" but because there's no external
// reattachment we need somewhere to stash this config so we can properly
// restore mock tasks
Command Command
ExecCommand *Command
PluginExitAfter time.Duration
KillAfter time.Duration
ProcState drivers.TaskState
}
func (d *Driver) PluginInfo() (*base.PluginInfoResponse, error) {
@ -358,21 +367,39 @@ func (d *Driver) RecoverTask(handle *drivers.TaskHandle) error {
return fmt.Errorf("failed to decode task state from handle: %v", err)
}
driverCfg, err := parseDriverConfig(handle.Config)
if err != nil {
d.logger.Error("failed to parse driver config from handle", "error", err, "task_id", handle.Config.ID, "config", hclog.Fmt("%+v", handle.Config))
return fmt.Errorf("failed to parse driver config from handle: %v", err)
taskState.Command.parseDurations()
if taskState.ExecCommand != nil {
taskState.ExecCommand.parseDurations()
}
// Remove the plugin exit time if set
driverCfg.pluginExitAfterDuration = 0
// Correct the run_for time based on how long it has already been running
now := time.Now()
driverCfg.runForDuration = driverCfg.runForDuration - now.Sub(taskState.StartedAt)
if !taskState.StartedAt.IsZero() {
taskState.Command.runForDuration = taskState.Command.runForDuration - now.Sub(taskState.StartedAt)
if taskState.ExecCommand != nil {
taskState.ExecCommand.runForDuration = taskState.ExecCommand.runForDuration - now.Sub(taskState.StartedAt)
}
}
// Recreate the taskHandle. Because there's no real running process, we'll
// assume we're still running if we've recovered it at all.
killCtx, killCancel := context.WithCancel(context.Background())
h := &taskHandle{
logger: d.logger.With("task_name", handle.Config.Name),
pluginExitAfter: taskState.PluginExitAfter,
killAfter: taskState.KillAfter,
waitCh: make(chan any),
taskConfig: handle.Config,
command: taskState.Command,
execCommand: taskState.ExecCommand,
procState: drivers.TaskStateRunning,
startedAt: taskState.StartedAt,
kill: killCancel,
killCh: killCtx.Done(),
Recovered: true,
}
h := newTaskHandle(handle.Config, driverCfg, d.logger)
h.Recovered = true
d.tasks.Set(handle.Config.ID, h)
go h.run()
return nil
@ -423,23 +450,6 @@ func parseDriverConfig(cfg *drivers.TaskConfig) (*TaskConfig, error) {
return &driverConfig, nil
}
func newTaskHandle(cfg *drivers.TaskConfig, driverConfig *TaskConfig, logger hclog.Logger) *taskHandle {
killCtx, killCancel := context.WithCancel(context.Background())
h := &taskHandle{
taskConfig: cfg,
command: driverConfig.Command,
execCommand: driverConfig.ExecCommand,
pluginExitAfter: driverConfig.pluginExitAfterDuration,
killAfter: driverConfig.killAfterDuration,
logger: logger.With("task_name", cfg.Name),
waitCh: make(chan interface{}),
killCh: killCtx.Done(),
kill: killCancel,
startedAt: time.Now(),
}
return h
}
func (d *Driver) StartTask(cfg *drivers.TaskConfig) (*drivers.TaskHandle, *drivers.DriverNetwork, error) {
driverConfig, err := parseDriverConfig(cfg)
if err != nil {
@ -477,9 +487,26 @@ func (d *Driver) StartTask(cfg *drivers.TaskConfig) (*drivers.TaskHandle, *drive
net.PortMap = map[string]int{parts[0]: port}
}
h := newTaskHandle(cfg, driverConfig, d.logger)
killCtx, killCancel := context.WithCancel(context.Background())
h := &taskHandle{
taskConfig: cfg,
command: driverConfig.Command,
execCommand: driverConfig.ExecCommand,
pluginExitAfter: driverConfig.pluginExitAfterDuration,
killAfter: driverConfig.killAfterDuration,
logger: d.logger.With("task_name", cfg.Name),
waitCh: make(chan interface{}),
killCh: killCtx.Done(),
kill: killCancel,
startedAt: time.Now(),
}
driverState := MockTaskState{
StartedAt: h.startedAt,
StartedAt: h.startedAt,
Command: driverConfig.Command,
ExecCommand: driverConfig.ExecCommand,
PluginExitAfter: driverConfig.pluginExitAfterDuration,
KillAfter: driverConfig.killAfterDuration,
}
handle := drivers.NewTaskHandle(taskHandleVersion)
handle.Config = cfg

177
drivers/mock/driver_test.go Normal file
View File

@ -0,0 +1,177 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package mock
import (
"context"
"os"
"sync"
"testing"
"time"
hclog "github.com/hashicorp/go-hclog"
"github.com/shoenig/test/must"
"github.com/shoenig/test/wait"
"github.com/hashicorp/nomad/ci"
"github.com/hashicorp/nomad/client/allocdir"
"github.com/hashicorp/nomad/client/config"
"github.com/hashicorp/nomad/client/taskenv"
"github.com/hashicorp/nomad/helper/testlog"
"github.com/hashicorp/nomad/helper/testtask"
"github.com/hashicorp/nomad/helper/uuid"
"github.com/hashicorp/nomad/nomad/mock"
"github.com/hashicorp/nomad/nomad/structs"
basePlug "github.com/hashicorp/nomad/plugins/base"
"github.com/hashicorp/nomad/plugins/drivers"
dtestutil "github.com/hashicorp/nomad/plugins/drivers/testutils"
)
func TestMockDriver_StartWaitRecoverWaitStop(t *testing.T) {
ci.Parallel(t)
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
logger := testlog.HCLogger(t)
d := NewMockDriver(ctx, logger).(*Driver)
harness := dtestutil.NewDriverHarness(t, d)
defer harness.Kill()
var data []byte
must.NoError(t, basePlug.MsgPackEncode(&data, &Config{}))
bconfig := &basePlug.Config{PluginConfig: data}
must.NoError(t, harness.SetConfig(bconfig))
task := &drivers.TaskConfig{
AllocID: uuid.Generate(),
ID: uuid.Generate(),
Name: "sleep",
Env: map[string]string{},
}
tc := &TaskConfig{
Command: Command{
RunFor: "10s",
runForDuration: time.Second * 10,
},
PluginExitAfter: "30s",
pluginExitAfterDuration: time.Second * 30,
}
must.NoError(t, task.EncodeConcreteDriverConfig(&tc))
testtask.SetTaskConfigEnv(task)
cleanup := mkTestAllocDir(t, harness, logger, task)
t.Cleanup(cleanup)
handle, _, err := harness.StartTask(task)
must.NoError(t, err)
ch, err := harness.WaitTask(context.Background(), task.ID)
must.NoError(t, err)
var waitDone bool
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
result := <-ch
must.Error(t, result.Err)
waitDone = true
}()
originalStatus, err := d.InspectTask(task.ID)
must.NoError(t, err)
d.tasks.Delete(task.ID)
wg.Wait()
must.True(t, waitDone)
_, err = d.InspectTask(task.ID)
must.Eq(t, drivers.ErrTaskNotFound, err)
err = d.RecoverTask(handle)
must.NoError(t, err)
// need to make sure the task is left running and doesn't just immediately
// exit after we recover it
must.Wait(t, wait.ContinualSuccess(
wait.BoolFunc(func() bool {
status, err := d.InspectTask(task.ID)
must.NoError(t, err)
return status.State == "running"
}),
wait.Timeout(1*time.Second),
wait.Gap(100*time.Millisecond),
))
status, err := d.InspectTask(task.ID)
must.NoError(t, err)
must.Eq(t, originalStatus, status)
ch, err = harness.WaitTask(context.Background(), task.ID)
must.NoError(t, err)
wg.Add(1)
waitDone = false
go func() {
defer wg.Done()
result := <-ch
must.NoError(t, result.Err)
must.Zero(t, result.ExitCode)
waitDone = true
}()
time.Sleep(300 * time.Millisecond)
must.NoError(t, d.StopTask(task.ID, 0, "SIGKILL"))
wg.Wait()
must.NoError(t, d.DestroyTask(task.ID, false))
must.True(t, waitDone)
}
func mkTestAllocDir(t *testing.T, h *dtestutil.DriverHarness, logger hclog.Logger, tc *drivers.TaskConfig) func() {
dir, err := os.MkdirTemp("", "nomad_driver_harness-")
must.NoError(t, err)
allocDir := allocdir.NewAllocDir(logger, dir, tc.AllocID)
must.NoError(t, allocDir.Build())
tc.AllocDir = allocDir.AllocDir
taskDir := allocDir.NewTaskDir(tc.Name)
must.NoError(t, taskDir.Build(false, ci.TinyChroot))
task := &structs.Task{
Name: tc.Name,
Env: tc.Env,
}
// no logging
tc.StdoutPath = os.DevNull
tc.StderrPath = os.DevNull
// Create the mock allocation
alloc := mock.Alloc()
alloc.ID = tc.AllocID
if tc.Resources != nil {
alloc.AllocatedResources.Tasks[task.Name] = tc.Resources.NomadResources
}
taskBuilder := taskenv.NewBuilder(mock.Node(), alloc, task, "global")
dtestutil.SetEnvvars(taskBuilder, drivers.FSIsolationNone, taskDir, config.DefaultConfig())
taskEnv := taskBuilder.Build()
if tc.Env == nil {
tc.Env = taskEnv.Map()
} else {
for k, v := range taskEnv.Map() {
if _, ok := tc.Env[k]; !ok {
tc.Env[k] = v
}
}
}
return func() {
allocDir.Destroy()
}
}