100 lines
3 KiB
Go
100 lines
3 KiB
Go
package vaultclient
|
|
|
|
import (
|
|
"github.com/hashicorp/nomad/nomad/structs"
|
|
vaultapi "github.com/hashicorp/vault/api"
|
|
)
|
|
|
|
// MockVaultClient is used for testing the vaultclient integration
|
|
type MockVaultClient struct {
|
|
// StoppedTokens tracks the tokens that have stopped renewing
|
|
StoppedTokens []string
|
|
|
|
// RenewTokens are the tokens that have been renewed and their error
|
|
// channels
|
|
RenewTokens map[string]chan error
|
|
|
|
// RenewTokenErrors is used to return an error when the RenewToken is called
|
|
// with the given token
|
|
RenewTokenErrors map[string]error
|
|
|
|
// DeriveTokenErrors maps an allocation ID and tasks to an error when the
|
|
// token is derived
|
|
DeriveTokenErrors map[string]map[string]error
|
|
|
|
// DeriveTokenFn allows the caller to control the DeriveToken function. If
|
|
// not set an error is returned if found in DeriveTokenErrors and otherwise
|
|
// a token is generated and returned
|
|
DeriveTokenFn func(a *structs.Allocation, tasks []string) (map[string]string, error)
|
|
}
|
|
|
|
// NewMockVaultClient returns a MockVaultClient for testing
|
|
func NewMockVaultClient() *MockVaultClient { return &MockVaultClient{} }
|
|
|
|
func (vc *MockVaultClient) DeriveToken(a *structs.Allocation, tasks []string) (map[string]string, error) {
|
|
if vc.DeriveTokenFn != nil {
|
|
return vc.DeriveTokenFn(a, tasks)
|
|
}
|
|
|
|
tokens := make(map[string]string, len(tasks))
|
|
for _, task := range tasks {
|
|
if tasks, ok := vc.DeriveTokenErrors[a.ID]; ok {
|
|
if err, ok := tasks[task]; ok {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
tokens[task] = structs.GenerateUUID()
|
|
}
|
|
|
|
return tokens, nil
|
|
}
|
|
|
|
func (vc *MockVaultClient) SetDeriveTokenError(allocID string, tasks []string, err error) {
|
|
if vc.DeriveTokenErrors == nil {
|
|
vc.DeriveTokenErrors = make(map[string]map[string]error, 10)
|
|
}
|
|
|
|
if _, ok := vc.RenewTokenErrors[allocID]; !ok {
|
|
vc.DeriveTokenErrors[allocID] = make(map[string]error, 10)
|
|
}
|
|
|
|
for _, task := range tasks {
|
|
vc.DeriveTokenErrors[allocID][task] = err
|
|
}
|
|
}
|
|
|
|
func (vc *MockVaultClient) RenewToken(token string, interval int) (<-chan error, error) {
|
|
if err, ok := vc.RenewTokenErrors[token]; ok {
|
|
return nil, err
|
|
}
|
|
|
|
renewCh := make(chan error)
|
|
if vc.RenewTokens == nil {
|
|
vc.RenewTokens = make(map[string]chan error, 10)
|
|
}
|
|
vc.RenewTokens[token] = renewCh
|
|
return renewCh, nil
|
|
}
|
|
|
|
func (vc *MockVaultClient) SetRenewTokenError(token string, err error) {
|
|
if vc.RenewTokenErrors == nil {
|
|
vc.RenewTokenErrors = make(map[string]error, 10)
|
|
}
|
|
|
|
vc.RenewTokenErrors[token] = err
|
|
}
|
|
|
|
func (vc *MockVaultClient) StopRenewToken(token string) error {
|
|
vc.StoppedTokens = append(vc.StoppedTokens, token)
|
|
return nil
|
|
}
|
|
|
|
func (vc *MockVaultClient) RenewLease(leaseId string, interval int) (<-chan error, error) {
|
|
return nil, nil
|
|
}
|
|
// StopRenewLease is a stub: it ignores the lease ID and always returns nil.
func (vc *MockVaultClient) StopRenewLease(leaseId string) error {
	return nil
}
|
|
// Start is a no-op on the mock.
func (vc *MockVaultClient) Start() {
}
|
|
// Stop is a no-op on the mock.
func (vc *MockVaultClient) Stop() {
}
|
|
// GetConsulACL is a stub: it ignores both string arguments and returns a nil
// secret together with a nil error.
func (vc *MockVaultClient) GetConsulACL(string, string) (*vaultapi.Secret, error) {
	var secret *vaultapi.Secret
	return secret, nil
}
|