open-nomad/client/consul/identities_testing.go

83 lines
2.3 KiB
Go

package consul
import (
"sync"
"github.com/hashicorp/nomad/helper/uuid"
"github.com/hashicorp/nomad/nomad/structs"
)
// MockServiceIdentitiesClient is a test double for the client that manages
// Consul service identity tokens. By default it hands out a token generated
// with uuid.Generate for every requested task. It can be told to fail for
// specific allocation/task pairs via SetDeriveTokenError, or have its
// behavior replaced entirely via DeriveTokenFn.
//
// Create instances with NewMockServiceIdentitiesClient, which allocates the
// internal error table that SetDeriveTokenError writes into.
type MockServiceIdentitiesClient struct {
	// deriveTokenErrors is keyed first by allocation ID and then by task
	// name; each entry is the error DeriveSITokens reports for that task.
	deriveTokenErrors map[string]map[string]error

	// DeriveTokenFn, when non-nil, is called in place of the default
	// DeriveSITokens logic. When nil, DeriveSITokens returns a registered
	// error if one matches a requested task, and otherwise generates a token
	// per task.
	DeriveTokenFn TokenDeriverFunc

	// lock is held by the mock's methods while they read or write the
	// fields above.
	lock sync.Mutex
}

// Compile-time check that the mock satisfies ServiceIdentityAPI.
var _ ServiceIdentityAPI = (*MockServiceIdentitiesClient)(nil)
// NewMockServiceIdentitiesClient builds a ready-to-use mock with an empty
// (but allocated) per-allocation error table and no DeriveTokenFn override.
func NewMockServiceIdentitiesClient() *MockServiceIdentitiesClient {
	mock := new(MockServiceIdentitiesClient)
	mock.deriveTokenErrors = map[string]map[string]error{}
	return mock
}
// DeriveSITokens returns a token per task name in tasks, keyed by task name.
//
// If DeriveTokenFn is set, its result is returned unchanged. Otherwise the
// tasks are walked in order; the first task that has an error registered for
// alloc.ID (via SetDeriveTokenError) aborts the call with that error and a nil
// map. Tasks without a registered error get a token from uuid.Generate.
//
// The mock's lock is held for the whole call, including while DeriveTokenFn
// runs, so DeriveTokenFn must not call back into this mock's methods.
func (mtc *MockServiceIdentitiesClient) DeriveSITokens(alloc *structs.Allocation, tasks []string) (map[string]string, error) {
	mtc.lock.Lock()
	defer mtc.lock.Unlock()

	// An explicit override takes precedence over the built-in behavior.
	if fn := mtc.DeriveTokenFn; fn != nil {
		return fn(alloc, tasks)
	}

	generated := make(map[string]string, len(tasks))
	for i := range tasks {
		name := tasks[i]
		// Reading through a missing (nil) inner map yields found == false,
		// so allocations with no registered errors fall through to success.
		if err, found := mtc.deriveTokenErrors[alloc.ID][name]; found {
			return nil, err
		}
		generated[name] = uuid.Generate()
	}
	return generated, nil
}
// SetDeriveTokenError registers err as the failure DeriveSITokens should
// report for each of the given tasks of allocation allocID. Later calls
// overwrite earlier registrations for the same allocation/task pair. The
// per-allocation table is created on first use, even when tasks is empty.
func (mtc *MockServiceIdentitiesClient) SetDeriveTokenError(allocID string, tasks []string, err error) {
	mtc.lock.Lock()
	defer mtc.lock.Unlock()

	perTask, exists := mtc.deriveTokenErrors[allocID]
	if !exists {
		perTask = make(map[string]error, 10)
		mtc.deriveTokenErrors[allocID] = perTask
	}
	for _, name := range tasks {
		perTask[name] = err
	}
}
// DeriveTokenErrors returns a deep copy of the errors registered through
// SetDeriveTokenError, keyed by allocation ID and then by task name. The copy
// is built while holding the mock's lock; modifying the returned maps does not
// affect the mock.
func (mtc *MockServiceIdentitiesClient) DeriveTokenErrors() map[string]map[string]error {
	mtc.lock.Lock()
	defer mtc.lock.Unlock()

	m := make(map[string]map[string]error, len(mtc.deriveTokenErrors))
	for aID, tasks := range mtc.deriveTokenErrors {
		// The inner map must be allocated before use: assigning into the
		// nil map that m[aID] would otherwise yield panics.
		inner := make(map[string]error, len(tasks))
		for task, err := range tasks {
			inner[task] = err
		}
		m[aID] = inner
	}
	return m
}