From 6743ed9fdce584458e044d76f49b886d9c556684 Mon Sep 17 00:00:00 2001 From: Michael Schurter Date: Tue, 12 Feb 2019 13:46:09 -0800 Subject: [PATCH] tests: port TestTaskRunner_BlockForVault from 0.8 Also fix race conditions in the mock vault client. --- .../taskrunner/task_runner_test.go | 77 +++++++++++++++ client/vaultclient/vaultclient_testing.go | 97 ++++++++++++++----- 2 files changed, 151 insertions(+), 23 deletions(-) diff --git a/client/allocrunner/taskrunner/task_runner_test.go b/client/allocrunner/taskrunner/task_runner_test.go index bd37845c0..e84c3cf26 100644 --- a/client/allocrunner/taskrunner/task_runner_test.go +++ b/client/allocrunner/taskrunner/task_runner_test.go @@ -767,6 +767,83 @@ func TestTaskRunner_CheckWatcher_Restart(t *testing.T) { require.True(t, state.Failed, pretty.Sprint(state)) } +// TestTaskRunner_BlockForVault asserts tasks do not start until a vault token +// is derived. +func TestTaskRunner_BlockForVault(t *testing.T) { + t.Parallel() + + alloc := mock.BatchAlloc() + task := alloc.Job.TaskGroups[0].Tasks[0] + task.Config = map[string]interface{}{ + "run_for": "0s", + } + task.Vault = &structs.Vault{Policies: []string{"default"}} + + conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name) + defer cleanup() + + // Control when we get a Vault token + token := "1234" + waitCh := make(chan struct{}) + handler := func(*structs.Allocation, []string) (map[string]string, error) { + <-waitCh + return map[string]string{task.Name: token}, nil + } + vaultClient := conf.Vault.(*vaultclient.MockVaultClient) + vaultClient.DeriveTokenFn = handler + + tr, err := NewTaskRunner(conf) + require.NoError(t, err) + defer tr.Kill(context.Background(), structs.NewTaskEvent("cleanup")) + go tr.Run() + + // Assert TR blocks on vault token (does *not* exit) + select { + case <-tr.WaitCh(): + require.Fail(t, "tr exited before vault unblocked") + case <-time.After(1 * time.Second): + } + + // Assert task state is still Pending + require.Equal(t, structs.TaskStatePending, tr.TaskState().State) + + // Unblock vault token + close(waitCh) + + // TR should exit now that it's unblocked by vault as its a batch job + // with 0 sleeping. + select { + case <-tr.WaitCh(): + case <-time.After(time.Duration(testutil.TestMultiplier()*15) * time.Second): + require.Fail(t, "timed out waiting for batch task to exit") + } + + // Assert task exited successfully + finalState := tr.TaskState() + require.Equal(t, structs.TaskStateDead, finalState.State) + require.False(t, finalState.Failed) + + // Check that the token is on disk + tokenPath := filepath.Join(conf.TaskDir.SecretsDir, vaultTokenFile) + data, err := ioutil.ReadFile(tokenPath) + require.NoError(t, err) + require.Equal(t, token, string(data)) + + // Check the token was revoked + testutil.WaitForResult(func() (bool, error) { + if len(vaultClient.StoppedTokens()) != 1 { + return false, fmt.Errorf("Expected a stopped token: %v", vaultClient.StoppedTokens()) + } + + if a := vaultClient.StoppedTokens()[0]; a != token { + return false, fmt.Errorf("got stopped token %q; want %q", a, token) + } + return true, nil + }, func(err error) { + require.Fail(t, err.Error()) + }) +} + // testWaitForTaskToStart waits for the task to be running or fails the test func testWaitForTaskToStart(t *testing.T, tr *TaskRunner) { testutil.WaitForResult(func() (bool, error) { diff --git a/client/vaultclient/vaultclient_testing.go b/client/vaultclient/vaultclient_testing.go index 62f4acf7f..ed99f4cc4 100644 --- a/client/vaultclient/vaultclient_testing.go +++ b/client/vaultclient/vaultclient_testing.go @@ -1,45 +1,53 @@ package vaultclient import ( + "sync" + "github.com/hashicorp/nomad/helper/uuid" "github.com/hashicorp/nomad/nomad/structs" vaultapi "github.com/hashicorp/vault/api" ) -// MockVaultClient is used for testing the vaultclient integration +// MockVaultClient is used for testing the vaultclient integration and is safe +// for concurrent access. type MockVaultClient struct { - // StoppedTokens tracks the tokens that have stopped renewing - StoppedTokens []string + // stoppedTokens tracks the tokens that have stopped renewing + stoppedTokens []string - // RenewTokens are the tokens that have been renewed and their error + // renewTokens are the tokens that have been renewed and their error // channels - RenewTokens map[string]chan error + renewTokens map[string]chan error - // RenewTokenErrors is used to return an error when the RenewToken is called + // renewTokenErrors is used to return an error when the RenewToken is called // with the given token - RenewTokenErrors map[string]error + renewTokenErrors map[string]error - // DeriveTokenErrors maps an allocation ID and tasks to an error when the + // deriveTokenErrors maps an allocation ID and tasks to an error when the // token is derived - DeriveTokenErrors map[string]map[string]error + deriveTokenErrors map[string]map[string]error // DeriveTokenFn allows the caller to control the DeriveToken function. If // not set an error is returned if found in DeriveTokenErrors and otherwise // a token is generated and returned DeriveTokenFn func(a *structs.Allocation, tasks []string) (map[string]string, error) + + mu sync.Mutex } // NewMockVaultClient returns a MockVaultClient for testing func NewMockVaultClient() *MockVaultClient { return &MockVaultClient{} } func (vc *MockVaultClient) DeriveToken(a *structs.Allocation, tasks []string) (map[string]string, error) { + vc.mu.Lock() + defer vc.mu.Unlock() + if vc.DeriveTokenFn != nil { return vc.DeriveTokenFn(a, tasks) } tokens := make(map[string]string, len(tasks)) for _, task := range tasks { - if tasks, ok := vc.DeriveTokenErrors[a.ID]; ok { + if tasks, ok := vc.deriveTokenErrors[a.ID]; ok { if err, ok := tasks[task]; ok { return nil, err } @@ -52,42 +60,54 @@ func (vc *MockVaultClient) DeriveToken(a *structs.Allocation, tasks []string) (m } func (vc *MockVaultClient) SetDeriveTokenError(allocID string, tasks []string, err error) { - if vc.DeriveTokenErrors == nil { - vc.DeriveTokenErrors = make(map[string]map[string]error, 10) + vc.mu.Lock() + defer vc.mu.Unlock() + + if vc.deriveTokenErrors == nil { + vc.deriveTokenErrors = make(map[string]map[string]error, 10) } - if _, ok := vc.RenewTokenErrors[allocID]; !ok { - vc.DeriveTokenErrors[allocID] = make(map[string]error, 10) + if _, ok := vc.renewTokenErrors[allocID]; !ok { + vc.deriveTokenErrors[allocID] = make(map[string]error, 10) } for _, task := range tasks { - vc.DeriveTokenErrors[allocID][task] = err + vc.deriveTokenErrors[allocID][task] = err } } func (vc *MockVaultClient) RenewToken(token string, interval int) (<-chan error, error) { - if err, ok := vc.RenewTokenErrors[token]; ok { + vc.mu.Lock() + defer vc.mu.Unlock() + + if err, ok := vc.renewTokenErrors[token]; ok { return nil, err } renewCh := make(chan error) - if vc.RenewTokens == nil { - vc.RenewTokens = make(map[string]chan error, 10) + if vc.renewTokens == nil { + vc.renewTokens = make(map[string]chan error, 10) } - vc.RenewTokens[token] = renewCh + vc.renewTokens[token] = renewCh return renewCh, nil } func (vc *MockVaultClient) SetRenewTokenError(token string, err error) { - if vc.RenewTokenErrors == nil { - vc.RenewTokenErrors = make(map[string]error, 10) + vc.mu.Lock() + defer vc.mu.Unlock() + + if vc.renewTokenErrors == nil { + vc.renewTokenErrors = make(map[string]error, 10) } - vc.RenewTokenErrors[token] = err + vc.renewTokenErrors[token] = err } func (vc *MockVaultClient) StopRenewToken(token string) error { - vc.StoppedTokens = append(vc.StoppedTokens, token) + vc.mu.Lock() + defer vc.mu.Unlock() + + vc.stoppedTokens = append(vc.stoppedTokens, token) return nil } @@ -98,3 +118,34 @@ func (vc *MockVaultClient) StopRenewLease(leaseId string) error func (vc *MockVaultClient) Start() {} func (vc *MockVaultClient) Stop() {} func (vc *MockVaultClient) GetConsulACL(string, string) (*vaultapi.Secret, error) { return nil, nil } + +// StoppedTokens tracks the tokens that have stopped renewing +func (vc *MockVaultClient) StoppedTokens() []string { + vc.mu.Lock() + defer vc.mu.Unlock() + return vc.stoppedTokens +} + +// RenewTokens are the tokens that have been renewed and their error +// channels +func (vc *MockVaultClient) RenewTokens() map[string]chan error { + vc.mu.Lock() + defer vc.mu.Unlock() + return vc.renewTokens +} + +// RenewTokenErrors is used to return an error when the RenewToken is called +// with the given token +func (vc *MockVaultClient) RenewTokenErrors() map[string]error { + vc.mu.Lock() + defer vc.mu.Unlock() + return vc.renewTokenErrors +} + +// DeriveTokenErrors maps an allocation ID and tasks to an error when the +// token is derived +func (vc *MockVaultClient) DeriveTokenErrors() map[string]map[string]error { + vc.mu.Lock() + defer vc.mu.Unlock() + return vc.deriveTokenErrors +}