open-nomad/client/vaultclient/vaultclient_testing.go

150 lines
3.8 KiB
Go
Raw Normal View History

2016-09-16 00:24:09 +00:00
package vaultclient
import (
"sync"
"github.com/hashicorp/nomad/helper/uuid"
2016-09-16 00:24:09 +00:00
"github.com/hashicorp/nomad/nomad/structs"
vaultapi "github.com/hashicorp/vault/api"
)
// MockVaultClient is used for testing the vaultclient integration and is safe
// for concurrent access.
2016-09-16 00:24:09 +00:00
type MockVaultClient struct {
// stoppedTokens tracks the tokens that have stopped renewing
stoppedTokens []string
2016-09-16 00:24:09 +00:00
// renewTokens are the tokens that have been renewed and their error
2016-09-16 00:24:09 +00:00
// channels
renewTokens map[string]chan error
2016-09-16 00:24:09 +00:00
// renewTokenErrors is used to return an error when the RenewToken is called
2016-09-16 00:24:09 +00:00
// with the given token
renewTokenErrors map[string]error
2016-09-16 00:24:09 +00:00
// deriveTokenErrors maps an allocation ID and tasks to an error when the
2016-09-16 00:24:09 +00:00
// token is derived
deriveTokenErrors map[string]map[string]error
2016-10-18 18:22:16 +00:00
2016-10-18 18:36:04 +00:00
// DeriveTokenFn allows the caller to control the DeriveToken function. If
// not set an error is returned if found in DeriveTokenErrors and otherwise
// a token is generated and returned
2016-10-18 18:22:16 +00:00
DeriveTokenFn func(a *structs.Allocation, tasks []string) (map[string]string, error)
mu sync.Mutex
2016-09-16 00:24:09 +00:00
}
// NewMockVaultClient returns a zero-valued MockVaultClient for testing. Its
// internal maps start out nil and are allocated by the methods that first
// write to them.
func NewMockVaultClient() *MockVaultClient {
	return new(MockVaultClient)
}
func (vc *MockVaultClient) DeriveToken(a *structs.Allocation, tasks []string) (map[string]string, error) {
vc.mu.Lock()
defer vc.mu.Unlock()
2016-10-18 18:22:16 +00:00
if vc.DeriveTokenFn != nil {
return vc.DeriveTokenFn(a, tasks)
}
2016-09-16 00:24:09 +00:00
tokens := make(map[string]string, len(tasks))
for _, task := range tasks {
if tasks, ok := vc.deriveTokenErrors[a.ID]; ok {
2016-09-16 00:24:09 +00:00
if err, ok := tasks[task]; ok {
return nil, err
}
}
tokens[task] = uuid.Generate()
2016-09-16 00:24:09 +00:00
}
return tokens, nil
}
func (vc *MockVaultClient) SetDeriveTokenError(allocID string, tasks []string, err error) {
vc.mu.Lock()
defer vc.mu.Unlock()
if vc.deriveTokenErrors == nil {
vc.deriveTokenErrors = make(map[string]map[string]error, 10)
2016-09-16 00:24:09 +00:00
}
if _, ok := vc.deriveTokenErrors[allocID]; !ok {
vc.deriveTokenErrors[allocID] = make(map[string]error, 10)
2016-09-16 00:24:09 +00:00
}
for _, task := range tasks {
vc.deriveTokenErrors[allocID][task] = err
2016-09-16 00:24:09 +00:00
}
}
func (vc *MockVaultClient) RenewToken(token string, interval int) (<-chan error, error) {
vc.mu.Lock()
defer vc.mu.Unlock()
if err, ok := vc.renewTokenErrors[token]; ok {
2016-09-16 00:24:09 +00:00
return nil, err
}
renewCh := make(chan error)
if vc.renewTokens == nil {
vc.renewTokens = make(map[string]chan error, 10)
2016-09-16 00:24:09 +00:00
}
vc.renewTokens[token] = renewCh
2016-09-16 00:24:09 +00:00
return renewCh, nil
}
func (vc *MockVaultClient) SetRenewTokenError(token string, err error) {
vc.mu.Lock()
defer vc.mu.Unlock()
if vc.renewTokenErrors == nil {
vc.renewTokenErrors = make(map[string]error, 10)
2016-09-16 00:24:09 +00:00
}
vc.renewTokenErrors[token] = err
2016-09-16 00:24:09 +00:00
}
func (vc *MockVaultClient) StopRenewToken(token string) error {
vc.mu.Lock()
defer vc.mu.Unlock()
vc.stoppedTokens = append(vc.stoppedTokens, token)
2016-09-16 00:24:09 +00:00
return nil
}
// Start does nothing.
// NOTE(review): presumably present so the mock offers the same method set as
// the real client — confirm against the interface it is used with.
func (vc *MockVaultClient) Start() {
}
// Stop does nothing.
// NOTE(review): presumably present so the mock offers the same method set as
// the real client — confirm against the interface it is used with.
func (vc *MockVaultClient) Stop() {
}
2016-09-16 00:24:09 +00:00
// GetConsulACL ignores both arguments and always returns a nil secret and a
// nil error.
func (vc *MockVaultClient) GetConsulACL(string, string) (*vaultapi.Secret, error) {
	var secret *vaultapi.Secret
	return secret, nil
}
// StoppedTokens returns the tokens that have stopped renewing, in the order
// StopRenewToken was called.
//
// The result is a copy taken while vc.mu is held, so the caller may sort,
// append to, or otherwise modify it without touching the mock's own record.
// Tokens stopped after the call are not reflected in a previously returned
// slice. A nil slice is returned when no token has been stopped.
func (vc *MockVaultClient) StoppedTokens() []string {
	vc.mu.Lock()
	defer vc.mu.Unlock()

	// Preserve nil-ness so callers comparing against nil see no change.
	if vc.stoppedTokens == nil {
		return nil
	}

	snapshot := make([]string, len(vc.stoppedTokens))
	copy(snapshot, vc.stoppedTokens)
	return snapshot
}
// RenewTokens returns the tokens that have been renewed and their error
// channels.
//
// The result is a snapshot copied while vc.mu is held. Returning the internal
// map directly would let the caller read it after the lock is released, at
// the same time RenewToken may be writing to it, which is an unsynchronized
// map access. The channels in the snapshot are the same channels RenewToken
// handed out, so sending on them still reaches the renewer. Tokens renewed
// after the call are not reflected in a previously returned map. A nil map is
// returned when no token has been renewed.
func (vc *MockVaultClient) RenewTokens() map[string]chan error {
	vc.mu.Lock()
	defer vc.mu.Unlock()

	// Preserve nil-ness so callers comparing against nil see no change.
	if vc.renewTokens == nil {
		return nil
	}

	snapshot := make(map[string]chan error, len(vc.renewTokens))
	for token, ch := range vc.renewTokens {
		snapshot[token] = ch
	}
	return snapshot
}
// RenewTokenErrors returns the errors that RenewToken is configured to return,
// keyed by token.
//
// The result is a snapshot copied while vc.mu is held. Returning the internal
// map directly would let the caller read it after the lock is released, at
// the same time SetRenewTokenError may be writing to it, which is an
// unsynchronized map access. Errors registered after the call are not
// reflected in a previously returned map; use SetRenewTokenError to change
// the mock's configuration. A nil map is returned when no error has been
// registered.
func (vc *MockVaultClient) RenewTokenErrors() map[string]error {
	vc.mu.Lock()
	defer vc.mu.Unlock()

	// Preserve nil-ness so callers comparing against nil see no change.
	if vc.renewTokenErrors == nil {
		return nil
	}

	snapshot := make(map[string]error, len(vc.renewTokenErrors))
	for token, err := range vc.renewTokenErrors {
		snapshot[token] = err
	}
	return snapshot
}
// DeriveTokenErrors returns the errors that DeriveToken is configured to
// return, keyed by allocation ID and then by task name.
//
// The result is a deep copy made while vc.mu is held: both the outer map and
// every per-allocation map are duplicated. Returning the internal maps
// directly would let the caller read them after the lock is released, at the
// same time SetDeriveTokenError may be writing to them, which is an
// unsynchronized map access. Errors registered after the call are not
// reflected in a previously returned map; use SetDeriveTokenError to change
// the mock's configuration. A nil map is returned when no error has been
// registered.
func (vc *MockVaultClient) DeriveTokenErrors() map[string]map[string]error {
	vc.mu.Lock()
	defer vc.mu.Unlock()

	// Preserve nil-ness so callers comparing against nil see no change.
	if vc.deriveTokenErrors == nil {
		return nil
	}

	snapshot := make(map[string]map[string]error, len(vc.deriveTokenErrors))
	for allocID, perTask := range vc.deriveTokenErrors {
		// Keep a nil inner map nil rather than turning it into an empty one.
		if perTask == nil {
			snapshot[allocID] = nil
			continue
		}
		inner := make(map[string]error, len(perTask))
		for task, err := range perTask {
			inner[task] = err
		}
		snapshot[allocID] = inner
	}
	return snapshot
}