diff --git a/client/task_runner.go b/client/task_runner.go index f9d15767c..56df71f61 100644 --- a/client/task_runner.go +++ b/client/task_runner.go @@ -958,14 +958,34 @@ func (r *TaskRunner) run() { } case se := <-r.signalCh: - r.logger.Printf("[DEBUG] client: task being signalled with %v: %s", se.s, se.e.TaskSignalReason) + r.runningLock.Lock() + running := r.running + r.runningLock.Unlock() + common := fmt.Sprintf("signal %v to task %v for alloc %q", se.s, r.task.Name, r.alloc.ID) + if !running { + // Send no error + r.logger.Printf("[DEBUG] client: skipping %s", common) + se.result <- nil + continue + } + + r.logger.Printf("[DEBUG] client: sending %s", common) r.setState(structs.TaskStateRunning, se.e) res := r.handle.Signal(se.s) se.result <- res case event := <-r.restartCh: - r.logger.Printf("[DEBUG] client: task being restarted: %s", event.RestartReason) + r.runningLock.Lock() + running := r.running + r.runningLock.Unlock() + common := fmt.Sprintf("task %v for alloc %q", r.task.Name, r.alloc.ID) + if !running { + r.logger.Printf("[DEBUG] client: skipping restart of %v: task isn't running", common) + continue + } + + r.logger.Printf("[DEBUG] client: restarting %s: %v", common, event.RestartReason) r.setState(structs.TaskStateRunning, event) r.killTask(nil) @@ -1365,23 +1385,9 @@ func (r *TaskRunner) handleDestroy() (destroyed bool, err error) { // Restart will restart the task func (r *TaskRunner) Restart(source, reason string) { - reasonStr := fmt.Sprintf("%s: %s", source, reason) event := structs.NewTaskEvent(structs.TaskRestartSignal).SetRestartReason(reasonStr) - r.logger.Printf("[DEBUG] client: restarting task %v for alloc %q: %v", - r.task.Name, r.alloc.ID, reasonStr) - - r.runningLock.Lock() - running := r.running - r.runningLock.Unlock() - - // Drop the restart event - if !running { - r.logger.Printf("[DEBUG] client: skipping restart since task isn't running") - return - } - select { case r.restartCh <- event: case <-r.waitCh: @@ -1394,24 +1400,13 @@ func (r *TaskRunner) Signal(source, reason string, s os.Signal) error { reasonStr := fmt.Sprintf("%s: %s", source, reason) event := structs.NewTaskEvent(structs.TaskSignaling).SetTaskSignal(s).SetTaskSignalReason(reasonStr) - r.logger.Printf("[DEBUG] client: sending signal %v to task %v for alloc %q", s, r.task.Name, r.alloc.ID) - - r.runningLock.Lock() - running := r.running - r.runningLock.Unlock() - - // Drop the restart event - if !running { - r.logger.Printf("[DEBUG] client: skipping signal since task isn't running") - return nil - } - resCh := make(chan error) se := SignalEvent{ s: s, e: event, result: resCh, } + select { case r.signalCh <- se: case <-r.waitCh: diff --git a/client/task_runner_test.go b/client/task_runner_test.go index 2282f3921..8d2dab2bd 100644 --- a/client/task_runner_test.go +++ b/client/task_runner_test.go @@ -316,10 +316,10 @@ func TestTaskRunner_Update(t *testing.T) { return false, fmt.Errorf("Task not copied") } if ctx.tr.restartTracker.policy.Mode != newMode { - return false, fmt.Errorf("restart policy not ctx.upd.ted") + return false, fmt.Errorf("restart policy not ctx.updated") } if ctx.tr.handle.ID() == oldHandle { - return false, fmt.Errorf("handle not ctx.upd.ted") + return false, fmt.Errorf("handle not ctx.updated") } return true, nil }, func(err error) { @@ -645,6 +645,66 @@ func TestTaskRunner_RestartTask(t *testing.T) { } } +// This test is just to make sure we are resilient to failures when a restart or +// signal is triggered and the task is not running. +func TestTaskRunner_RestartSignalTask_NotRunning(t *testing.T) { + alloc := mock.Alloc() + task := alloc.Job.TaskGroups[0].Tasks[0] + task.Driver = "mock_driver" + task.Config = map[string]interface{}{ + "exit_code": "0", + "run_for": "100s", + } + + // Use vault to block the start + task.Vault = &structs.Vault{Policies: []string{"default"}} + + ctx := testTaskRunnerFromAlloc(t, true, alloc) + ctx.tr.MarkReceived() + defer ctx.Cleanup() + + // Control when we get a Vault token + token := "1234" + waitCh := make(chan struct{}) + defer close(waitCh) + handler := func(*structs.Allocation, []string) (map[string]string, error) { + <-waitCh + return map[string]string{task.Name: token}, nil + } + ctx.tr.vaultClient.(*vaultclient.MockVaultClient).DeriveTokenFn = handler + go ctx.tr.Run() + + select { + case <-ctx.tr.WaitCh(): + t.Fatalf("premature exit") + case <-time.After(1 * time.Second): + } + + // Send a signal and restart + if err := ctx.tr.Signal("test", "don't panic", syscall.SIGCHLD); err != nil { + t.Fatalf("Signalling errored: %v", err) + } + + // Send a restart + ctx.tr.Restart("test", "don't panic") + + if len(ctx.upd.events) != 2 { + t.Fatalf("should have 2 ctx.updates: %#v", ctx.upd.events) + } + + if ctx.upd.state != structs.TaskStatePending { + t.Fatalf("TaskState %v; want %v", ctx.upd.state, structs.TaskStatePending) + } + + if ctx.upd.events[0].Type != structs.TaskReceived { + t.Fatalf("First Event was %v; want %v", ctx.upd.events[0].Type, structs.TaskReceived) + } + + if ctx.upd.events[1].Type != structs.TaskSetup { + t.Fatalf("Second Event was %v; want %v", ctx.upd.events[1].Type, structs.TaskSetup) + } +} + func TestTaskRunner_KillTask(t *testing.T) { alloc := mock.Alloc() task := alloc.Job.TaskGroups[0].Tasks[0] @@ -1148,7 +1208,7 @@ func TestTaskRunner_Template_NewVaultToken(t *testing.T) { } if originalManager == ctx.tr.templateManager { - return false, fmt.Errorf("Template manager not ctx.upd.ted") + return false, fmt.Errorf("Template manager not ctx.updated") } return true, nil