diff --git a/client/task_runner.go b/client/task_runner.go index c578ca5ea..2a24b8fd2 100644 --- a/client/task_runner.go +++ b/client/task_runner.go @@ -1442,7 +1442,7 @@ func (r *TaskRunner) updateServices(d driver.Driver, h driver.ScriptExecutor, ol // Allow set the script executor if the driver supports it exec = h } - interpolateServices(r.getTaskEnv(), r.task) + interpolateServices(r.getTaskEnv(), new) return r.consul.UpdateTask(r.alloc.ID, old, new, exec) } diff --git a/command/agent/consul/client.go b/command/agent/consul/client.go index 2fd2542e4..38f331f20 100644 --- a/command/agent/consul/client.go +++ b/command/agent/consul/client.go @@ -221,6 +221,7 @@ func (c *ServiceClient) merge(ops *operations) { if script, ok := c.runningScripts[cid]; ok { script.cancel() delete(c.scripts, cid) + delete(c.runningScripts, cid) } delete(c.checks, cid) } @@ -673,14 +674,15 @@ func createCheckReg(serviceID, checkID string, check *structs.ServiceCheck, host switch check.Type { case structs.ServiceCheckHTTP: - if check.Protocol == "" { - check.Protocol = "http" + proto := check.Protocol + if proto == "" { + proto = "http" } if check.TLSSkipVerify { chkReg.TLSSkipVerify = true } base := url.URL{ - Scheme: check.Protocol, + Scheme: proto, Host: net.JoinHostPort(host, strconv.Itoa(port)), } relative, err := url.Parse(check.Path) diff --git a/command/agent/consul/unit_test.go b/command/agent/consul/unit_test.go index 96794e0c1..94e00a949 100644 --- a/command/agent/consul/unit_test.go +++ b/command/agent/consul/unit_test.go @@ -767,3 +767,90 @@ func TestConsul_NoTLSSkipVerifySupport(t *testing.T) { } } } + +// TestConsul_RemoveScript assert removing a script check removes all objects +// related to that check. +func TestConsul_CancelScript(t *testing.T) { + ctx := setupFake() + ctx.Task.Services[0].Checks = []*structs.ServiceCheck{ + { + Name: "scriptcheckDel", + Type: "script", + Interval: 9000 * time.Hour, + Timeout: 9000 * time.Hour, + }, + { + Name: "scriptcheckKeep", + Type: "script", + Interval: 9000 * time.Hour, + Timeout: 9000 * time.Hour, + }, + } + + if err := ctx.ServiceClient.RegisterTask("allocid", ctx.Task, ctx); err != nil { + t.Fatalf("unexpected error registering task: %v", err) + } + + if err := ctx.syncOnce(); err != nil { + t.Fatalf("unexpected error syncing task: %v", err) + } + + if len(ctx.FakeConsul.checks) != 2 { + t.Errorf("expected 2 checks but found %d", len(ctx.FakeConsul.checks)) + } + + if len(ctx.ServiceClient.scripts) != 2 && len(ctx.ServiceClient.runningScripts) != 2 { + t.Errorf("expected 2 running script but found scripts=%d runningScripts=%d", + len(ctx.ServiceClient.scripts), len(ctx.ServiceClient.runningScripts)) + } + + for i := 0; i < 2; i++ { + select { + case <-ctx.execs: + // Script ran as expected! + case <-time.After(3 * time.Second): + t.Fatalf("timed out waiting for script check to run") + } + } + + // Remove a check and update the task + origTask := ctx.Task.Copy() + ctx.Task.Services[0].Checks = []*structs.ServiceCheck{ + { + Name: "scriptcheckKeep", + Type: "script", + Interval: 9000 * time.Hour, + Timeout: 9000 * time.Hour, + }, + } + + if err := ctx.ServiceClient.UpdateTask("allocid", origTask, ctx.Task, ctx); err != nil { + t.Fatalf("unexpected error registering task: %v", err) + } + + if err := ctx.syncOnce(); err != nil { + t.Fatalf("unexpected error syncing task: %v", err) + } + + if len(ctx.FakeConsul.checks) != 1 { + t.Errorf("expected 1 check but found %d", len(ctx.FakeConsul.checks)) + } + + if len(ctx.ServiceClient.scripts) != 1 && len(ctx.ServiceClient.runningScripts) != 1 { + t.Errorf("expected 1 running script but found scripts=%d runningScripts=%d", + len(ctx.ServiceClient.scripts), len(ctx.ServiceClient.runningScripts)) + } + + // Make sure exec wasn't called again + select { + case <-ctx.execs: + t.Errorf("unexpected execution of script; was goroutine not cancelled?") + case <-time.After(100 * time.Millisecond): + // No unexpected script execs + } + + // Don't leak goroutines + for _, scriptHandle := range ctx.ServiceClient.runningScripts { + scriptHandle.cancel() + } +} diff --git a/nomad/structs/structs.go b/nomad/structs/structs.go index 49196b41b..af2553cd7 100644 --- a/nomad/structs/structs.go +++ b/nomad/structs/structs.go @@ -2088,7 +2088,6 @@ func (tg *TaskGroup) GoString() string { } const ( - // TODO add Consul TTL check ServiceCheckHTTP = "http" ServiceCheckTCP = "tcp" ServiceCheckScript = "script"