tests: port TestTaskRunner_BlockForVault from 0.8
Also fix race conditions in the mock vault client.
This commit is contained in:
parent
f7102cd01d
commit
6743ed9fdc
|
@ -767,6 +767,83 @@ func TestTaskRunner_CheckWatcher_Restart(t *testing.T) {
|
|||
require.True(t, state.Failed, pretty.Sprint(state))
|
||||
}
|
||||
|
||||
// TestTaskRunner_BlockForVault asserts tasks do not start until a vault token
|
||||
// is derived.
|
||||
func TestTaskRunner_BlockForVault(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
alloc := mock.BatchAlloc()
|
||||
task := alloc.Job.TaskGroups[0].Tasks[0]
|
||||
task.Config = map[string]interface{}{
|
||||
"run_for": "0s",
|
||||
}
|
||||
task.Vault = &structs.Vault{Policies: []string{"default"}}
|
||||
|
||||
conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name)
|
||||
defer cleanup()
|
||||
|
||||
// Control when we get a Vault token
|
||||
token := "1234"
|
||||
waitCh := make(chan struct{})
|
||||
handler := func(*structs.Allocation, []string) (map[string]string, error) {
|
||||
<-waitCh
|
||||
return map[string]string{task.Name: token}, nil
|
||||
}
|
||||
vaultClient := conf.Vault.(*vaultclient.MockVaultClient)
|
||||
vaultClient.DeriveTokenFn = handler
|
||||
|
||||
tr, err := NewTaskRunner(conf)
|
||||
require.NoError(t, err)
|
||||
defer tr.Kill(context.Background(), structs.NewTaskEvent("cleanup"))
|
||||
go tr.Run()
|
||||
|
||||
// Assert TR blocks on vault token (does *not* exit)
|
||||
select {
|
||||
case <-tr.WaitCh():
|
||||
require.Fail(t, "tr exited before vault unblocked")
|
||||
case <-time.After(1 * time.Second):
|
||||
}
|
||||
|
||||
// Assert task state is still Pending
|
||||
require.Equal(t, structs.TaskStatePending, tr.TaskState().State)
|
||||
|
||||
// Unblock vault token
|
||||
close(waitCh)
|
||||
|
||||
// TR should exit now that it's unblocked by vault as its a batch job
|
||||
// with 0 sleeping.
|
||||
select {
|
||||
case <-tr.WaitCh():
|
||||
case <-time.After(time.Duration(testutil.TestMultiplier()*15) * time.Second):
|
||||
require.Fail(t, "timed out waiting for batch task to exit")
|
||||
}
|
||||
|
||||
// Assert task exited successfully
|
||||
finalState := tr.TaskState()
|
||||
require.Equal(t, structs.TaskStateDead, finalState.State)
|
||||
require.False(t, finalState.Failed)
|
||||
|
||||
// Check that the token is on disk
|
||||
tokenPath := filepath.Join(conf.TaskDir.SecretsDir, vaultTokenFile)
|
||||
data, err := ioutil.ReadFile(tokenPath)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, token, string(data))
|
||||
|
||||
// Check the token was revoked
|
||||
testutil.WaitForResult(func() (bool, error) {
|
||||
if len(vaultClient.StoppedTokens()) != 1 {
|
||||
return false, fmt.Errorf("Expected a stopped token: %v", vaultClient.StoppedTokens())
|
||||
}
|
||||
|
||||
if a := vaultClient.StoppedTokens()[0]; a != token {
|
||||
return false, fmt.Errorf("got stopped token %q; want %q", a, token)
|
||||
}
|
||||
return true, nil
|
||||
}, func(err error) {
|
||||
require.Fail(t, err.Error())
|
||||
})
|
||||
}
|
||||
|
||||
// testWaitForTaskToStart waits for the task to be running or fails the test
|
||||
func testWaitForTaskToStart(t *testing.T, tr *TaskRunner) {
|
||||
testutil.WaitForResult(func() (bool, error) {
|
||||
|
|
|
@ -1,45 +1,53 @@
|
|||
package vaultclient
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/hashicorp/nomad/helper/uuid"
|
||||
"github.com/hashicorp/nomad/nomad/structs"
|
||||
vaultapi "github.com/hashicorp/vault/api"
|
||||
)
|
||||
|
||||
// MockVaultClient is used for testing the vaultclient integration
|
||||
// MockVaultClient is used for testing the vaultclient integration and is safe
|
||||
// for concurrent access.
|
||||
type MockVaultClient struct {
|
||||
// StoppedTokens tracks the tokens that have stopped renewing
|
||||
StoppedTokens []string
|
||||
// stoppedTokens tracks the tokens that have stopped renewing
|
||||
stoppedTokens []string
|
||||
|
||||
// RenewTokens are the tokens that have been renewed and their error
|
||||
// renewTokens are the tokens that have been renewed and their error
|
||||
// channels
|
||||
RenewTokens map[string]chan error
|
||||
renewTokens map[string]chan error
|
||||
|
||||
// RenewTokenErrors is used to return an error when the RenewToken is called
|
||||
// renewTokenErrors is used to return an error when the RenewToken is called
|
||||
// with the given token
|
||||
RenewTokenErrors map[string]error
|
||||
renewTokenErrors map[string]error
|
||||
|
||||
// DeriveTokenErrors maps an allocation ID and tasks to an error when the
|
||||
// deriveTokenErrors maps an allocation ID and tasks to an error when the
|
||||
// token is derived
|
||||
DeriveTokenErrors map[string]map[string]error
|
||||
deriveTokenErrors map[string]map[string]error
|
||||
|
||||
// DeriveTokenFn allows the caller to control the DeriveToken function. If
|
||||
// not set an error is returned if found in DeriveTokenErrors and otherwise
|
||||
// a token is generated and returned
|
||||
DeriveTokenFn func(a *structs.Allocation, tasks []string) (map[string]string, error)
|
||||
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// NewMockVaultClient returns a MockVaultClient for testing
|
||||
func NewMockVaultClient() *MockVaultClient { return &MockVaultClient{} }
|
||||
|
||||
func (vc *MockVaultClient) DeriveToken(a *structs.Allocation, tasks []string) (map[string]string, error) {
|
||||
vc.mu.Lock()
|
||||
defer vc.mu.Unlock()
|
||||
|
||||
if vc.DeriveTokenFn != nil {
|
||||
return vc.DeriveTokenFn(a, tasks)
|
||||
}
|
||||
|
||||
tokens := make(map[string]string, len(tasks))
|
||||
for _, task := range tasks {
|
||||
if tasks, ok := vc.DeriveTokenErrors[a.ID]; ok {
|
||||
if tasks, ok := vc.deriveTokenErrors[a.ID]; ok {
|
||||
if err, ok := tasks[task]; ok {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -52,42 +60,54 @@ func (vc *MockVaultClient) DeriveToken(a *structs.Allocation, tasks []string) (m
|
|||
}
|
||||
|
||||
func (vc *MockVaultClient) SetDeriveTokenError(allocID string, tasks []string, err error) {
|
||||
if vc.DeriveTokenErrors == nil {
|
||||
vc.DeriveTokenErrors = make(map[string]map[string]error, 10)
|
||||
vc.mu.Lock()
|
||||
defer vc.mu.Unlock()
|
||||
|
||||
if vc.deriveTokenErrors == nil {
|
||||
vc.deriveTokenErrors = make(map[string]map[string]error, 10)
|
||||
}
|
||||
|
||||
if _, ok := vc.RenewTokenErrors[allocID]; !ok {
|
||||
vc.DeriveTokenErrors[allocID] = make(map[string]error, 10)
|
||||
if _, ok := vc.renewTokenErrors[allocID]; !ok {
|
||||
vc.deriveTokenErrors[allocID] = make(map[string]error, 10)
|
||||
}
|
||||
|
||||
for _, task := range tasks {
|
||||
vc.DeriveTokenErrors[allocID][task] = err
|
||||
vc.deriveTokenErrors[allocID][task] = err
|
||||
}
|
||||
}
|
||||
|
||||
func (vc *MockVaultClient) RenewToken(token string, interval int) (<-chan error, error) {
|
||||
if err, ok := vc.RenewTokenErrors[token]; ok {
|
||||
vc.mu.Lock()
|
||||
defer vc.mu.Unlock()
|
||||
|
||||
if err, ok := vc.renewTokenErrors[token]; ok {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
renewCh := make(chan error)
|
||||
if vc.RenewTokens == nil {
|
||||
vc.RenewTokens = make(map[string]chan error, 10)
|
||||
if vc.renewTokens == nil {
|
||||
vc.renewTokens = make(map[string]chan error, 10)
|
||||
}
|
||||
vc.RenewTokens[token] = renewCh
|
||||
vc.renewTokens[token] = renewCh
|
||||
return renewCh, nil
|
||||
}
|
||||
|
||||
func (vc *MockVaultClient) SetRenewTokenError(token string, err error) {
|
||||
if vc.RenewTokenErrors == nil {
|
||||
vc.RenewTokenErrors = make(map[string]error, 10)
|
||||
vc.mu.Lock()
|
||||
defer vc.mu.Unlock()
|
||||
|
||||
if vc.renewTokenErrors == nil {
|
||||
vc.renewTokenErrors = make(map[string]error, 10)
|
||||
}
|
||||
|
||||
vc.RenewTokenErrors[token] = err
|
||||
vc.renewTokenErrors[token] = err
|
||||
}
|
||||
|
||||
func (vc *MockVaultClient) StopRenewToken(token string) error {
|
||||
vc.StoppedTokens = append(vc.StoppedTokens, token)
|
||||
vc.mu.Lock()
|
||||
defer vc.mu.Unlock()
|
||||
|
||||
vc.stoppedTokens = append(vc.stoppedTokens, token)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -98,3 +118,34 @@ func (vc *MockVaultClient) StopRenewLease(leaseId string) error
|
|||
func (vc *MockVaultClient) Start() {}
|
||||
func (vc *MockVaultClient) Stop() {}
|
||||
func (vc *MockVaultClient) GetConsulACL(string, string) (*vaultapi.Secret, error) { return nil, nil }
|
||||
|
||||
// StoppedTokens tracks the tokens that have stopped renewing
|
||||
func (vc *MockVaultClient) StoppedTokens() []string {
|
||||
vc.mu.Lock()
|
||||
defer vc.mu.Unlock()
|
||||
return vc.stoppedTokens
|
||||
}
|
||||
|
||||
// RenewTokens are the tokens that have been renewed and their error
|
||||
// channels
|
||||
func (vc *MockVaultClient) RenewTokens() map[string]chan error {
|
||||
vc.mu.Lock()
|
||||
defer vc.mu.Unlock()
|
||||
return vc.renewTokens
|
||||
}
|
||||
|
||||
// RenewTokenErrors is used to return an error when the RenewToken is called
|
||||
// with the given token
|
||||
func (vc *MockVaultClient) RenewTokenErrors() map[string]error {
|
||||
vc.mu.Lock()
|
||||
defer vc.mu.Unlock()
|
||||
return vc.renewTokenErrors
|
||||
}
|
||||
|
||||
// DeriveTokenErrors maps an allocation ID and tasks to an error when the
|
||||
// token is derived
|
||||
func (vc *MockVaultClient) DeriveTokenErrors() map[string]map[string]error {
|
||||
vc.mu.Lock()
|
||||
defer vc.mu.Unlock()
|
||||
return vc.deriveTokenErrors
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue