tests: port TestTaskRunner_BlockForVault from 0.8

Also fix race conditions in the mock vault client.
This commit is contained in:
Michael Schurter 2019-02-12 13:46:09 -08:00
parent f7102cd01d
commit 6743ed9fdc
2 changed files with 151 additions and 23 deletions

View File

@ -767,6 +767,83 @@ func TestTaskRunner_CheckWatcher_Restart(t *testing.T) {
require.True(t, state.Failed, pretty.Sprint(state))
}
// TestTaskRunner_BlockForVault asserts tasks do not start until a vault token
// is derived.
func TestTaskRunner_BlockForVault(t *testing.T) {
t.Parallel()
alloc := mock.BatchAlloc()
task := alloc.Job.TaskGroups[0].Tasks[0]
task.Config = map[string]interface{}{
"run_for": "0s",
}
task.Vault = &structs.Vault{Policies: []string{"default"}}
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
defer cleanup()
// Control when we get a Vault token
token := "1234"
waitCh := make(chan struct{})
handler := func(*structs.Allocation, []string) (map[string]string, error) {
<-waitCh
return map[string]string{task.Name: token}, nil
}
vaultClient := conf.Vault.(*vaultclient.MockVaultClient)
vaultClient.DeriveTokenFn = handler
tr, err := NewTaskRunner(conf)
require.NoError(t, err)
defer tr.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
go tr.Run()
// Assert TR blocks on vault token (does *not* exit)
select {
case <-tr.WaitCh():
require.Fail(t, "tr exited before vault unblocked")
case <-time.After(1 * time.Second):
}
// Assert task state is still Pending
require.Equal(t, structs.TaskStatePending, tr.TaskState().State)
// Unblock vault token
close(waitCh)
// TR should exit now that it's unblocked by vault as its a batch job
// with 0 sleeping.
select {
case <-tr.WaitCh():
case <-time.After(time.Duration(testutil.TestMultiplier()*15) * time.Second):
require.Fail(t, "timed out waiting for batch task to exit")
}
// Assert task exited successfully
finalState := tr.TaskState()
require.Equal(t, structs.TaskStateDead, finalState.State)
require.False(t, finalState.Failed)
// Check that the token is on disk
tokenPath := filepath.Join(conf.TaskDir.SecretsDir, vaultTokenFile)
data, err := ioutil.ReadFile(tokenPath)
require.NoError(t, err)
require.Equal(t, token, string(data))
// Check the token was revoked
testutil.WaitForResult(func() (bool, error) {
if len(vaultClient.StoppedTokens()) != 1 {
return false, fmt.Errorf("Expected a stopped token: %v", vaultClient.StoppedTokens())
}
if a := vaultClient.StoppedTokens()[0]; a != token {
return false, fmt.Errorf("got stopped token %q; want %q", a, token)
}
return true, nil
}, func(err error) {
require.Fail(t, err.Error())
})
}
// testWaitForTaskToStart waits for the task to be running or fails the test
func testWaitForTaskToStart(t *testing.T, tr *TaskRunner) {
testutil.WaitForResult(func() (bool, error) {

View File

@ -1,45 +1,53 @@
package vaultclient
import (
"sync"
"github.com/hashicorp/nomad/helper/uuid"
"github.com/hashicorp/nomad/nomad/structs"
vaultapi "github.com/hashicorp/vault/api"
)
// MockVaultClient is used for testing the vaultclient integration
// MockVaultClient is used for testing the vaultclient integration and is safe
// for concurrent access.
type MockVaultClient struct {
// StoppedTokens tracks the tokens that have stopped renewing
StoppedTokens []string
// stoppedTokens tracks the tokens that have stopped renewing
stoppedTokens []string
// RenewTokens are the tokens that have been renewed and their error
// renewTokens are the tokens that have been renewed and their error
// channels
RenewTokens map[string]chan error
renewTokens map[string]chan error
// RenewTokenErrors is used to return an error when the RenewToken is called
// renewTokenErrors is used to return an error when the RenewToken is called
// with the given token
RenewTokenErrors map[string]error
renewTokenErrors map[string]error
// DeriveTokenErrors maps an allocation ID and tasks to an error when the
// deriveTokenErrors maps an allocation ID and tasks to an error when the
// token is derived
DeriveTokenErrors map[string]map[string]error
deriveTokenErrors map[string]map[string]error
// DeriveTokenFn allows the caller to control the DeriveToken function. If
// not set an error is returned if found in DeriveTokenErrors and otherwise
// a token is generated and returned
DeriveTokenFn func(a *structs.Allocation, tasks []string) (map[string]string, error)
mu sync.Mutex
}
// NewMockVaultClient returns a MockVaultClient for testing
func NewMockVaultClient() *MockVaultClient { return &MockVaultClient{} }
func (vc *MockVaultClient) DeriveToken(a *structs.Allocation, tasks []string) (map[string]string, error) {
vc.mu.Lock()
defer vc.mu.Unlock()
if vc.DeriveTokenFn != nil {
return vc.DeriveTokenFn(a, tasks)
}
tokens := make(map[string]string, len(tasks))
for _, task := range tasks {
if tasks, ok := vc.DeriveTokenErrors[a.ID]; ok {
if tasks, ok := vc.deriveTokenErrors[a.ID]; ok {
if err, ok := tasks[task]; ok {
return nil, err
}
@ -52,42 +60,54 @@ func (vc *MockVaultClient) DeriveToken(a *structs.Allocation, tasks []string) (m
}
func (vc *MockVaultClient) SetDeriveTokenError(allocID string, tasks []string, err error) {
if vc.DeriveTokenErrors == nil {
vc.DeriveTokenErrors = make(map[string]map[string]error, 10)
vc.mu.Lock()
defer vc.mu.Unlock()
if vc.deriveTokenErrors == nil {
vc.deriveTokenErrors = make(map[string]map[string]error, 10)
}
if _, ok := vc.RenewTokenErrors[allocID]; !ok {
vc.DeriveTokenErrors[allocID] = make(map[string]error, 10)
if _, ok := vc.renewTokenErrors[allocID]; !ok {
vc.deriveTokenErrors[allocID] = make(map[string]error, 10)
}
for _, task := range tasks {
vc.DeriveTokenErrors[allocID][task] = err
vc.deriveTokenErrors[allocID][task] = err
}
}
func (vc *MockVaultClient) RenewToken(token string, interval int) (<-chan error, error) {
if err, ok := vc.RenewTokenErrors[token]; ok {
vc.mu.Lock()
defer vc.mu.Unlock()
if err, ok := vc.renewTokenErrors[token]; ok {
return nil, err
}
renewCh := make(chan error)
if vc.RenewTokens == nil {
vc.RenewTokens = make(map[string]chan error, 10)
if vc.renewTokens == nil {
vc.renewTokens = make(map[string]chan error, 10)
}
vc.RenewTokens[token] = renewCh
vc.renewTokens[token] = renewCh
return renewCh, nil
}
func (vc *MockVaultClient) SetRenewTokenError(token string, err error) {
if vc.RenewTokenErrors == nil {
vc.RenewTokenErrors = make(map[string]error, 10)
vc.mu.Lock()
defer vc.mu.Unlock()
if vc.renewTokenErrors == nil {
vc.renewTokenErrors = make(map[string]error, 10)
}
vc.RenewTokenErrors[token] = err
vc.renewTokenErrors[token] = err
}
func (vc *MockVaultClient) StopRenewToken(token string) error {
vc.StoppedTokens = append(vc.StoppedTokens, token)
vc.mu.Lock()
defer vc.mu.Unlock()
vc.stoppedTokens = append(vc.stoppedTokens, token)
return nil
}
@ -98,3 +118,34 @@ func (vc *MockVaultClient) StopRenewLease(leaseId string) error
func (vc *MockVaultClient) Start() {}
func (vc *MockVaultClient) Stop() {}
func (vc *MockVaultClient) GetConsulACL(string, string) (*vaultapi.Secret, error) { return nil, nil }
// StoppedTokens tracks the tokens that have stopped renewing
func (vc *MockVaultClient) StoppedTokens() []string {
vc.mu.Lock()
defer vc.mu.Unlock()
return vc.stoppedTokens
}
// RenewTokens are the tokens that have been renewed and their error
// channels
func (vc *MockVaultClient) RenewTokens() map[string]chan error {
vc.mu.Lock()
defer vc.mu.Unlock()
return vc.renewTokens
}
// RenewTokenErrors is used to return an error when the RenewToken is called
// with the given token
func (vc *MockVaultClient) RenewTokenErrors() map[string]error {
vc.mu.Lock()
defer vc.mu.Unlock()
return vc.renewTokenErrors
}
// DeriveTokenErrors maps an allocation ID and tasks to an error when the
// token is derived
func (vc *MockVaultClient) DeriveTokenErrors() map[string]map[string]error {
vc.mu.Lock()
defer vc.mu.Unlock()
return vc.deriveTokenErrors
}