From 6002e596c4c42a4dcd0509edc40a5592310323bd Mon Sep 17 00:00:00 2001 From: vishalnayak Date: Wed, 17 Aug 2016 23:28:48 -0400 Subject: [PATCH] VaultClient for Nomad Client --- client/client.go | 35 ++ client/client_test.go | 1 + client/config/config.go | 1 + client/vaultclient/vaultclient.go | 789 +++++++++++++++++++++++++ client/vaultclient/vaultclient_test.go | 221 +++++++ 5 files changed, 1047 insertions(+) create mode 100644 client/vaultclient/vaultclient.go create mode 100644 client/vaultclient/vaultclient_test.go diff --git a/client/client.go b/client/client.go index 8379fd4bf..04897d0eb 100644 --- a/client/client.go +++ b/client/client.go @@ -23,6 +23,7 @@ import ( "github.com/hashicorp/nomad/client/fingerprint" "github.com/hashicorp/nomad/client/rpcproxy" "github.com/hashicorp/nomad/client/stats" + "github.com/hashicorp/nomad/client/vaultclient" "github.com/hashicorp/nomad/command/agent/consul" "github.com/hashicorp/nomad/nomad" "github.com/hashicorp/nomad/nomad/structs" @@ -147,6 +148,9 @@ type Client struct { shutdown bool shutdownCh chan struct{} shutdownLock sync.Mutex + + // client to interact with vault for token and secret renewals + vaultClient vaultclient.VaultClient } // NewClient is used to create a new client from the given configuration @@ -213,6 +217,11 @@ func NewClient(cfg *config.Config, consulSyncer *consul.Syncer, logger *log.Logg return nil, fmt.Errorf("failed to create client Consul syncer: %v", err) } + // Setup the vault client for token and secret renewals + if err := c.setupVaultClient(); err != nil { + return nil, fmt.Errorf("failed to setup vault client: %v", err) + } + // Register and then start heartbeating to the servers. go c.registerAndHeartbeat() @@ -238,6 +247,9 @@ func NewClient(cfg *config.Config, consulSyncer *consul.Syncer, logger *log.Logg // populated by periodically polling Consul, if available. go c.rpcProxy.Run() + // Start renewing tokens and secrets + go c.vaultClient.Start() + return c, nil } @@ -319,6 +331,11 @@ func (c *Client) Shutdown() error { return nil } + // Stop renewing tokens and secrets + if c.vaultClient != nil { + c.vaultClient.Stop() + } + // Destroy all the running allocations. if c.config.DevMode { c.allocLock.Lock() @@ -1275,6 +1292,24 @@ func (c *Client) addAlloc(alloc *structs.Allocation) error { return nil } +// setupVaultClient creates an object to periodically renew tokens and secrets +// with vault. +func (c *Client) setupVaultClient() error { + if c.config.VaultConfig == nil { + return fmt.Errorf("nil vault config") + } + if c.config.VaultConfig.Token == "" { + return fmt.Errorf("vault token not set") + } + + var err error + if c.vaultClient, err = vaultclient.NewVaultClient(c.config.VaultConfig, c.logger); err != nil { + return err + } + + return nil +} + // setupConsulSyncer creates Client-mode consul.Syncer which periodically // executes callbacks on a fixed interval. // diff --git a/client/client_test.go b/client/client_test.go index 1bc984bc3..ba98dc054 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -85,6 +85,7 @@ func testServer(t *testing.T, cb func(*nomad.Config)) (*nomad.Server, string) { func testClient(t *testing.T, cb func(c *config.Config)) *Client { conf := config.DefaultConfig() + conf.VaultConfig.Enabled = false conf.DevMode = true if cb != nil { cb(conf) diff --git a/client/config/config.go b/client/config/config.go index 6ed0670fa..65cb822ab 100644 --- a/client/config/config.go +++ b/client/config/config.go @@ -149,6 +149,7 @@ func (c *Config) Copy() *Config { // DefaultConfig returns the default configuration func DefaultConfig() *Config { return &Config{ + VaultConfig: config.DefaultVaultConfig(), ConsulConfig: config.DefaultConsulConfig(), LogOutput: os.Stderr, Region: "global", diff --git a/client/vaultclient/vaultclient.go b/client/vaultclient/vaultclient.go new file mode 100644 index 000000000..32600cead --- /dev/null +++ b/client/vaultclient/vaultclient.go @@ -0,0 +1,789 @@ +package vaultclient + +import ( + "container/heap" + "fmt" + "log" + "sync" + "time" + + "github.com/hashicorp/nomad/nomad/structs/config" + vaultapi "github.com/hashicorp/vault/api" + "github.com/mitchellh/mapstructure" +) + +// The interface which nomad client uses to interact with vault and +// periodically renews the tokens and secrets. +type VaultClient interface { + // Starts the renewal loop of tokens and secrets + Start() + + // Stops the renewal loop for tokens and secrets + Stop() + + // Contacts the nomad server and fetches a wrapped token. The wrapped + // token will be unwrapped by contacting vault and returned. + DeriveToken() (string, error) + + // Fetch the Consul ACL token required for the task + GetConsulACL(string, string) (*vaultapi.Secret, error) + + // Renews a token with the given increment and adds it to the min-heap + // for periodic renewal. + RenewToken(string, int) <-chan error + + // Removes the token from the min-heap, stopping its renewal. + StopRenewToken(string) error + + // Renews a vault secret's lease and add the lease identifier to the + // min-heap for periodic renewal. + RenewLease(string, int) <-chan error + + // Removes a secret's lease id from the min-heap, stopping its renewal. + StopRenewLease(string) error +} + +// Implementation of VaultClient interface to interact with vault and perform +// token and lease renewals periodically. +type vaultClient struct { + // running indicates if the renewal loop is active or not + running bool + + // connEstablished marks whether the connection to vault was successful + // or not + connEstablished bool + + // tokenData is the data of the passed VaultClient token + token *tokenData + + // API client to interact with vault + client *vaultapi.Client + + // Channel to notify heap modifications to the renewal loop + updateCh chan struct{} + + // Channel to trigger termination of renewal loop + stopCh chan struct{} + + // Min-Heap to keep track of both tokens and leases + heap *vaultClientHeap + + // Configuration to connect to vault + config *config.VaultConfig + + lock sync.RWMutex + logger *log.Logger +} + +// tokenData holds the relevant information about the Vault token passed to the +// client. +type tokenData struct { + CreationTTL int `mapstructure:"creation_ttl"` + TTL int `mapstructure:"ttl"` + Renewable bool `mapstructure:"renewable"` + Policies []string `mapstructure:"policies"` + Role string `mapstructure:"role"` + Root bool +} + +// Request object for renewals. This can be used for both token renewals and +// secret's lease renewals. +type vaultClientRenewalRequest struct { + // Channel into which any renewal error will be sent down to + errCh chan error + + // This can either be a token or a lease identifier + id string + + // Duration for which the token or lease should be renewed for + increment int + + // Indicates whether the 'id' field is a token or not + isToken bool +} + +// Element representing an entry in the renewal heap +type vaultClientHeapEntry struct { + req *vaultClientRenewalRequest + next time.Time + index int +} + +// Wrapper around the actual heap to provide additional symantics on top of +// functions provided by the heap interface. In order to achieve that, an +// additional map is placed beside the actual heap. This map can be used to +// check if an entry is already present in the heap. +type vaultClientHeap struct { + heapMap map[string]*vaultClientHeapEntry + heap vaultDataHeapImp +} + +// Data type of the heap +type vaultDataHeapImp []*vaultClientHeapEntry + +// NewVaultClient returns a new vault client from the given config. +func NewVaultClient(config *config.VaultConfig, logger *log.Logger) (*vaultClient, error) { + if config == nil { + return nil, fmt.Errorf("nil vault config") + } + + // Creation of a vault client requires that the token is supplied via + // config. + if config.Token == "" { + return nil, fmt.Errorf("vault token not set") + } + + if config.TaskTokenTTL == "" { + return nil, fmt.Errorf("task_token_ttl not set") + } + + if logger == nil { + return nil, fmt.Errorf("nil logger") + } + + c := &vaultClient{ + config: config, + stopCh: make(chan struct{}), + updateCh: make(chan struct{}, 1), + heap: NewVaultClientHeap(), + logger: logger, + } + + if !c.config.Enabled { + return nil, nil + } + + // Get the Vault API configuration + apiConf, err := config.ApiConfig() + if err != nil { + return nil, fmt.Errorf("failed to create vault API config: %v", err) + } + + // Create the Vault API client + client, err := vaultapi.NewClient(apiConf) + if err != nil { + logger.Printf("[ERR] vault: failed to create Vault client. Not retrying: %v", err) + return nil, err + } + + // Set the token and store the client + client.SetToken(c.config.Token) + + c.client = client + + return c, nil +} + +// NewVaultClientHeap returns a new vault client heap with both the heap and a +// map which is a secondary index for heap elements, both initialized. +func NewVaultClientHeap() *vaultClientHeap { + return &vaultClientHeap{ + heapMap: make(map[string]*vaultClientHeapEntry), + heap: make(vaultDataHeapImp, 0), + } +} + +// IsTracked returns if a given identifier is already present in the heap and +// hence is being renewed. +func (c *vaultClient) IsTracked(id string) bool { + if id == "" { + return false + } + + _, ok := c.heap.heapMap[id] + return ok +} + +// Starts the renewal loop of vault client +func (c *vaultClient) Start() { + if !c.config.Enabled || c.running { + return + } + + c.logger.Printf("[INFO] vaultclient: establishing connection to vault") + go c.establishConnection() +} + +// ConnectionEstablished indicates whether VaultClient successfully established +// connection to vault or not +func (c *vaultClient) ConnectionEstablished() bool { + c.lock.RLock() + defer c.lock.RUnlock() + return c.connEstablished +} + +// establishConnection is used to make first contact with Vault. This should be +// called in a go-routine since the connection is retried til the Vault Client +// is stopped or the connection is successfully made at which point the renew +// loop is started. +func (c *vaultClient) establishConnection() { + // Create the retry timer and set initial duration to zero so it fires + // immediately + retryTimer := time.NewTimer(0) + +OUTER: + for { + select { + case <-c.stopCh: + return + case <-retryTimer.C: + // Ensure the API is reachable + if _, err := c.client.Sys().InitStatus(); err != nil { + c.logger.Printf("[WARN] vaultclient: failed to contact Vault API. Retrying in %v", + c.config.ConnectionRetryIntv) + retryTimer.Reset(c.config.ConnectionRetryIntv) + continue OUTER + } + + break OUTER + } + } + + c.lock.Lock() + c.connEstablished = true + c.lock.Unlock() + + // Retrieve our token, validate it and parse the lease duration + if err := c.parseSelfToken(); err != nil { + c.logger.Printf("[ERR] vaultclient: failed to lookup self token and not retrying: %v", err) + return + } + + // Begin the renewal loop + go c.run() + c.logger.Printf("[INFO] vaultclient: started") + + // If we are given a non-root token, start renewing it + if c.token.Renewable { + c.logger.Printf("[INFO] vaultclient: not renewing token as it is not renewable") + } else { + c.logger.Printf("[INFO] vaultclient: token lease duration is %v", time.Duration(c.token.CreationTTL)*time.Second) + + // Add the VaultClient's token to the renewal loop + errCh := c.RenewToken(c.config.Token, c.token.CreationTTL) + // Catch the renewal error of VaultClient's token. + go func(errCh <-chan error) { + var err error + for { + select { + case err = <-errCh: + c.logger.Printf("[ERR] vaultclient: error while renewing the vault client's token: %v", err) + } + } + }(errCh) + } +} + +// parseSelfToken looks up the Vault token in Vault and parses its data storing +// it in the client. If the token is not valid for Nomads purposes an error is +// returned. +func (c *vaultClient) parseSelfToken() error { + // Get the initial lease duration + auth := c.client.Auth().Token() + self, err := auth.LookupSelf() + if err != nil { + return fmt.Errorf("failed to lookup VaultClient's token: %v", err) + } + + // Read and parse the fields + var data tokenData + if err := mapstructure.WeakDecode(self.Data, &data); err != nil { + return fmt.Errorf("failed to parse Vault token's data block: %v", err) + } + + root := false + for _, p := range data.Policies { + if p == "root" { + root = true + break + } + } + + if !data.Renewable && !root { + return fmt.Errorf("vault token is not renewable or root") + } + + if data.CreationTTL == 0 && !root { + return fmt.Errorf("invalid lease duration of zero") + } + + if data.TTL == 0 && !root { + return fmt.Errorf("token TTL is zero") + } + + if !root && data.Role == "" { + return fmt.Errorf("token role name must be set when not using a root token") + } + + data.Root = root + c.token = &data + return nil +} + +// Stops the renewal loop of vault client +func (c *vaultClient) Stop() { + if !c.config.Enabled || !c.running { + return + } + + c.lock.Lock() + defer c.lock.Unlock() + + c.running = false + close(c.stopCh) +} + +// DeriveToken contacts the nomad server and fetches a wrapped token. Then it +// contacts vault to unwrap the token and returns the unwrapped token. +func (c *vaultClient) DeriveToken() (string, error) { + // TODO: Replace this code with an actual call to the nomad server. + // This is a sample code which directly fetches a wrapped token from + // vault and unwraps it for time being. + tcr := &vaultapi.TokenCreateRequest{ + Policies: []string{"foo", "bar"}, + TTL: "10s", + DisplayName: "derived-token", + Renewable: new(bool), + } + *tcr.Renewable = true + + // Set the TTL for the wrapped token + wrapLookupFunc := func(method, path string) string { + if method == "POST" && path == "auth/token/create" { + return "60s" + } + return "" + } + c.client.SetWrappingLookupFunc(wrapLookupFunc) + + // Create a wrapped token + secret, err := c.client.Auth().Token().Create(tcr) + if err != nil { + return "", fmt.Errorf("failed to create vault token: %v", err) + } + if secret == nil || secret.WrapInfo == nil || secret.WrapInfo.Token == "" || + secret.WrapInfo.WrappedAccessor == "" { + return "", fmt.Errorf("failed to derive a wrapped vault token") + } + + wrappedToken := secret.WrapInfo.Token + + // Unwrap the vault token + unwrapResp, err := c.client.Logical().Unwrap(wrappedToken) + if err != nil { + return "", fmt.Errorf("failed to unwrap the token: %v", err) + } + if unwrapResp == nil || unwrapResp.Auth == nil || unwrapResp.Auth.ClientToken == "" { + return "", fmt.Errorf("failed to unwrap the token") + } + + // Return the unwrapped token + return unwrapResp.Auth.ClientToken, nil +} + +// GetConsulACL creates a vault API client and reads from vault a consul ACL +// token used by the task. +func (c *vaultClient) GetConsulACL(token, vaultPath string) (*vaultapi.Secret, error) { + if token == "" { + return nil, fmt.Errorf("missing token") + } + if vaultPath == "" { + return nil, fmt.Errorf("missing vault path") + } + + // Use the token supplied to interact with vault + c.client.SetToken(token) + + // Read the consul ACL token and return the secret directly + return c.client.Logical().Read(vaultPath) +} + +// RenewToken renews the supplied token and adds it to the min-heap so that it +// is renewed periodically by the renewal loop. Any error returned during +// renewal will be written to a buffered channel and the channel is returned +// instead of an actual error. This helps the caller be notified of a renewal +// failure asynchronously for appropriate actions to be taken. +func (c *vaultClient) RenewToken(token string, increment int) <-chan error { + // Create a buffered error channel + errCh := make(chan error, 1) + + if token == "" { + errCh <- fmt.Errorf("missing token") + return errCh + } + + // Create a renewal request and indicate that the identifier in the + // request is a token and not a lease + renewalReq := &vaultClientRenewalRequest{ + errCh: errCh, + id: token, + isToken: true, + increment: increment, + } + + // Perform the renewal of the token and send any error to the dedicated + // error channel. + if err := c.renew(renewalReq); err != nil { + errCh <- err + } + + return errCh +} + +// RenewLease renews the supplied lease identifier for a supplied duration and +// adds it to the min-heap so that it gets renewed periodically by the renewal +// loop. Any error returned during renewal will be written to a buffered +// channel and the channel is returned instead of an actual error. This helps +// the caller be notified of a renewal failure asynchronously for appropriate +// actions to be taken. +func (c *vaultClient) RenewLease(leaseId string, leaseDuration int) <-chan error { + c.logger.Printf("[INFO] vaultclient: renewing lease %q", leaseId) + // Create a buffered error channel + errCh := make(chan error, 1) + + if leaseId == "" { + errCh <- fmt.Errorf("missing lease ID") + return errCh + } + + if leaseDuration == 0 { + errCh <- fmt.Errorf("missing lease duration") + return errCh + } + + // Create a renewal request using the supplied lease and duration + renewalReq := &vaultClientRenewalRequest{ + errCh: make(chan error, 1), + id: leaseId, + increment: leaseDuration, + } + + // Renew the secret and send any error to the dedicated error channel + if err := c.renew(renewalReq); err != nil { + errCh <- err + } + + return errCh +} + +// renew is a common method to handle renewal of both tokens and secret leases. +// It creates a vault API client and invokes either a token renewal request or +// a secret renewal request. If renewal is successful, min-heap is updated +// based on the duration after which it needs its renewal again. The duration +// is set to half the lease duration present in the renewal response. +func (c *vaultClient) renew(req *vaultClientRenewalRequest) error { + c.logger.Printf("[INFO] vaultclient: ~~~~~~~Renewing %s~~~~~~~~", req.id) + c.lock.Lock() + defer c.lock.Unlock() + + if !c.running { + return fmt.Errorf("vault client is not running") + } + + if req == nil { + return fmt.Errorf("nil renewal request") + } + if req.id == "" { + return fmt.Errorf("missing id in renewal request") + } + if req.increment == 0 { + return fmt.Errorf("missing increment in renewal request") + } + + var duration time.Duration + if req.isToken { + // Reset the token in the API client to that of VaultClient + // before returning + defer c.client.SetToken(c.config.Token) + + // Set the token in the API client to the one that needs + // renewal + c.client.SetToken(req.id) + + // Renew the token + renewResp, err := c.client.Auth().Token().RenewSelf(req.increment) + if err != nil { + return fmt.Errorf("failed to renew the vault token: %v", err) + } + if renewResp == nil || renewResp.Auth == nil { + return fmt.Errorf("failed to renew the vault token") + } + + // Set the next renewal time to half the lease duration + duration = time.Duration(renewResp.Auth.LeaseDuration) * time.Second / 2 + } else { + // Renew the secret + renewResp, err := c.client.Sys().Renew(req.id, req.increment) + if err != nil { + return fmt.Errorf("failed to renew vault secret: %v", err) + } + if renewResp == nil { + return fmt.Errorf("failed to renew vault secret") + } + + // Set the next renewal time to half the lease duration + duration = time.Duration(renewResp.LeaseDuration) * time.Second / 2 + } + + // Determine the next renewal time + next := time.Now().Add(duration) + + if c.IsTracked(req.id) { + // If the identifier is already tracked, this indicates a + // subsequest renewal. In this case, update the existing + // element in the heap with the new renewal time. + + // There is no need to signal an update to the renewal loop + // here because this case is hit from the renewal loop itself. + if err := c.heap.Update(req, next); err != nil { + return fmt.Errorf("failed to update heap entry. err: %v", err) + } + } else { + // If the identifier is not already tracked, this is a first + // renewal request. In this case, add an entry into the heap + // with the next renewal time. + if err := c.heap.Push(req, next); err != nil { + return fmt.Errorf("failed to push an entry to heap. err: %v", err) + } + + // Signal an update for the renewal loop to trigger a fresh + // computation for the next best candidate for renewal. + if c.running { + select { + case c.updateCh <- struct{}{}: + default: + } + } + } + + return nil +} + +// run is the renewal loop which performs the periodic renewals of both the +// tokens and the secret leases. +func (c *vaultClient) run() { + var renewalCh <-chan time.Time + + if !c.config.Enabled { + return + } + + c.lock.Lock() + c.running = true + c.lock.Unlock() + + for c.config.Enabled && c.running { + // Fetches the candidate for next renewal + renewalReq, renewalTime := c.nextRenewal() + if renewalTime.IsZero() { + // If the heap is empty, don't do anything + renewalCh = nil + } else { + now := time.Now() + if renewalTime.After(now) { + // Compute the duration after which the item + // needs renewal and set the renewalCh to fire + // at that time. + renewalDuration := renewalTime.Sub(time.Now()) + renewalCh = time.After(renewalDuration) + } else { + // If the renewals of multiple items are too + // close to each other and by the time the + // entry is fetched from heap it might be past + // the current time (by a small margin). In + // which case, fire immediately. + renewalCh = time.After(0) + } + } + + select { + case <-renewalCh: + if err := c.renew(renewalReq); err != nil { + renewalReq.errCh <- err + } + case <-c.updateCh: + continue + case <-c.stopCh: + c.logger.Printf("[INFO] vaultclient: stopped") + return + } + } +} + +// StopRenewToken removes the item from the heap which represents the given +// token. +func (c *vaultClient) StopRenewToken(token string) error { + return c.stopRenew(token) +} + +// StopRenewLease removes the item from the heap which represents the given +// lease identifier. +func (c *vaultClient) StopRenewLease(leaseId string) error { + return c.stopRenew(leaseId) +} + +// stopRenew removes the given identifier from the heap and signals the renewal +// loop to compute the next best candidate for renewal. +func (c *vaultClient) stopRenew(id string) error { + c.lock.Lock() + defer c.lock.Unlock() + + if !c.IsTracked(id) { + return nil + } + + // Remove the identifier from the heap + if err := c.heap.Remove(id); err != nil { + return fmt.Errorf("failed to remove heap entry: %v", err) + } + + // Delete the identifier from the map only after the it is removed from + // the heap. Heap's remove method relies on the heap map. + delete(c.heap.heapMap, id) + + // Signal an update to the renewal loop. + if c.running { + select { + case c.updateCh <- struct{}{}: + default: + } + } + + return nil +} + +// nextRenewal returns the root element of the min-heap, which represents the +// next element to be renewed and the time at which the renewal needs to be +// triggered. +func (c *vaultClient) nextRenewal() (*vaultClientRenewalRequest, time.Time) { + c.lock.RLock() + defer c.lock.RUnlock() + + if c.heap.Length() == 0 { + return nil, time.Time{} + } + + // Fetches the root element in the min-heap + nextEntry := c.heap.Peek() + if nextEntry == nil { + return nil, time.Time{} + } + + return nextEntry.req, nextEntry.next +} + +// Additional helper functions on top of interface methods + +// Length returns the number of elements in the heap +func (h *vaultClientHeap) Length() int { + return len(h.heap) +} + +// Returns the root node of the min-heap +func (h *vaultClientHeap) Peek() *vaultClientHeapEntry { + if len(h.heap) == 0 { + return nil + } + + return h.heap[0] +} + +// Push adds the secondary index and inserts an item into the heap +func (h *vaultClientHeap) Push(req *vaultClientRenewalRequest, next time.Time) error { + if req == nil { + return fmt.Errorf("nil request") + } + + if _, ok := h.heapMap[req.id]; ok { + return fmt.Errorf("entry %v already exists", req.id) + } + + heapEntry := &vaultClientHeapEntry{ + req: req, + next: next, + } + h.heapMap[req.id] = heapEntry + heap.Push(&h.heap, heapEntry) + return nil +} + +// Update will modify the existing item in the heap with the new data and the +// time, and fixes the heap. +func (h *vaultClientHeap) Update(req *vaultClientRenewalRequest, next time.Time) error { + if entry, ok := h.heapMap[req.id]; ok { + entry.req = req + entry.next = next + heap.Fix(&h.heap, entry.index) + return nil + } + + return fmt.Errorf("heap doesn't contain %v", req.id) +} + +// Remove will remove an identifier from the secondary index and deletes the +// corresponding node from the heap. +func (h *vaultClientHeap) Remove(id string) error { + if entry, ok := h.heapMap[id]; ok { + heap.Remove(&h.heap, entry.index) + delete(h.heapMap, id) + return nil + } + + return fmt.Errorf("heap doesn't contain entry for %v", id) +} + +// The heap interface requires the following methods to be implemented. +// * Push(x interface{}) // add x as element Len() +// * Pop() interface{} // remove and return element Len() - 1. +// * sort.Interface +// +// sort.Interface comprises of the following methods: +// * Len() int +// * Less(i, j int) bool +// * Swap(i, j int) + +// Part of sort.Interface +func (h vaultDataHeapImp) Len() int { return len(h) } + +// Part of sort.Interface +func (h vaultDataHeapImp) Less(i, j int) bool { + // Two zero times should return false. + // Otherwise, zero is "greater" than any other time. + // (To sort it at the end of the list.) + // Sort such that zero times are at the end of the list. + iZero, jZero := h[i].next.IsZero(), h[j].next.IsZero() + if iZero && jZero { + return false + } else if iZero { + return false + } else if jZero { + return true + } + + return h[i].next.Before(h[j].next) +} + +// Part of sort.Interface +func (h vaultDataHeapImp) Swap(i, j int) { + h[i], h[j] = h[j], h[i] + h[i].index = i + h[j].index = j +} + +// Part of heap.Interface +func (h *vaultDataHeapImp) Push(x interface{}) { + n := len(*h) + entry := x.(*vaultClientHeapEntry) + entry.index = n + *h = append(*h, entry) +} + +// Part of heap.Interface +func (h *vaultDataHeapImp) Pop() interface{} { + old := *h + n := len(old) + entry := old[n-1] + entry.index = -1 // for safety + *h = old[0 : n-1] + return entry +} diff --git a/client/vaultclient/vaultclient_test.go b/client/vaultclient/vaultclient_test.go new file mode 100644 index 000000000..d36c2e711 --- /dev/null +++ b/client/vaultclient/vaultclient_test.go @@ -0,0 +1,221 @@ +package vaultclient + +import ( + "log" + "os" + "testing" + "time" + + "github.com/hashicorp/nomad/client/config" + "github.com/hashicorp/nomad/testutil" + vaultapi "github.com/hashicorp/vault/api" +) + +func TestVaultClient_EstablishConnection(t *testing.T) { + v := testutil.NewTestVault(t) + + logger := log.New(os.Stderr, "TEST: ", log.Lshortfile|log.LstdFlags) + v.Config.ConnectionRetryIntv = 100 * time.Millisecond + v.Config.TaskTokenTTL = "10s" + + c, err := NewVaultClient(v.Config, logger) + if err != nil { + t.Fatalf("failed to build vault client: %v", err) + } + + c.Start() + defer c.Stop() + + // Sleep a little while and check that no connection has been established. + time.Sleep(100 * time.Duration(testutil.TestMultiplier()) * time.Millisecond) + + if c.ConnectionEstablished() { + t.Fatalf("ConnectionEstablished() returned true before Vault server started") + } + + // Start Vault + v.Start() + defer v.Stop() + + testutil.WaitForResult(func() (bool, error) { + return c.ConnectionEstablished(), nil + }, func(err error) { + t.Fatalf("Connection not established") + }) +} + +func TestVaultClient_TokenRenewals(t *testing.T) { + v := testutil.NewTestVault(t).Start() + defer v.Stop() + + v.Config.ConnectionRetryIntv = 100 * time.Millisecond + v.Config.TaskTokenTTL = "10s" + + logger := log.New(os.Stderr, "TEST: ", log.Lshortfile|log.LstdFlags) + c, err := NewVaultClient(v.Config, logger) + if err != nil { + t.Fatalf("failed to build vault client: %v", err) + } + + c.Start() + defer c.Stop() + + // Sleep a little while to ensure that the renewal loop is active + time.Sleep(2 * time.Second) + + tcr := &vaultapi.TokenCreateRequest{ + Policies: []string{"foo", "bar"}, + TTL: "2s", + DisplayName: "derived-for-task", + Renewable: new(bool), + } + *tcr.Renewable = true + + num := 10 + tokens := make([]string, num) + for i := 0; i < num; i++ { + c.client.SetToken(v.Config.Token) + + if err := c.client.SetAddress(v.Config.Addr); err != nil { + t.Fatal(err) + } + + secret, err := c.client.Auth().Token().Create(tcr) + if err != nil { + t.Fatalf("failed to create vault token: %v", err) + } + + if secret == nil || secret.Auth == nil || secret.Auth.ClientToken == "" { + t.Fatal("failed to derive a wrapped vault token") + } + + tokens[i] = secret.Auth.ClientToken + + errCh := c.RenewToken(tokens[i], secret.Auth.LeaseDuration) + go func(errCh <-chan error) { + var err error + for { + select { + case err = <-errCh: + t.Fatalf("error while renewing the token: %v", err) + } + } + }(errCh) + } + + if c.heap.Length() != num { + t.Fatalf("bad: heap length: expected: %d, actual: %d", num, c.heap.Length()) + } + + time.Sleep(10 * time.Second) + + for i := 0; i < num; i++ { + if err := c.StopRenewToken(tokens[i]); err != nil { + t.Fatal(err) + } + } + + if c.heap.Length() != 0 { + t.Fatal("bad: heap length: expected: 0, actual: %d", c.heap.Length()) + } +} + +func TestVaultClient_Heap(t *testing.T) { + conf := config.DefaultConfig() + conf.VaultConfig.Token = "testvaulttoken" + conf.VaultConfig.TaskTokenTTL = "10s" + + logger := log.New(os.Stderr, "TEST: ", log.Lshortfile|log.LstdFlags) + c, err := NewVaultClient(conf.VaultConfig, logger) + if err != nil { + t.Fatal(err) + } + + now := time.Now() + + renewalReq1 := &vaultClientRenewalRequest{ + errCh: make(chan error, 1), + id: "id1", + increment: 10, + } + if err := c.heap.Push(renewalReq1, now.Add(50*time.Second)); err != nil { + t.Fatal(err) + } + if !c.IsTracked("id1") { + t.Fatalf("id1 should have been tracked") + } + + renewalReq2 := &vaultClientRenewalRequest{ + errCh: make(chan error, 1), + id: "id2", + increment: 10, + } + if err := c.heap.Push(renewalReq2, now.Add(40*time.Second)); err != nil { + t.Fatal(err) + } + if !c.IsTracked("id2") { + t.Fatalf("id2 should have been tracked") + } + + renewalReq3 := &vaultClientRenewalRequest{ + errCh: make(chan error, 1), + id: "id3", + increment: 10, + } + if err := c.heap.Push(renewalReq3, now.Add(60*time.Second)); err != nil { + t.Fatal(err) + } + if !c.IsTracked("id3") { + t.Fatalf("id3 should have been tracked") + } + + // Reading elements should yield id2, id1 and id3 in order + req, _ := c.nextRenewal() + if req != renewalReq2 { + t.Fatalf("bad: expected: %#v, actual: %#v", renewalReq2, req) + } + if err := c.heap.Update(req, now.Add(70*time.Second)); err != nil { + t.Fatal(err) + } + + req, _ = c.nextRenewal() + if req != renewalReq1 { + t.Fatalf("bad: expected: %#v, actual: %#v", renewalReq1, req) + } + if err := c.heap.Update(req, now.Add(80*time.Second)); err != nil { + t.Fatal(err) + } + + req, _ = c.nextRenewal() + if req != renewalReq3 { + t.Fatalf("bad: expected: %#v, actual: %#v", renewalReq3, req) + } + if err := c.heap.Update(req, now.Add(90*time.Second)); err != nil { + t.Fatal(err) + } + + if err := c.StopRenewToken("id1"); err != nil { + t.Fatal(err) + } + + if err := c.StopRenewToken("id2"); err != nil { + t.Fatal(err) + } + + if err := c.StopRenewToken("id3"); err != nil { + t.Fatal(err) + } + + if c.IsTracked("id1") { + t.Fatalf("id1 should not have been tracked") + } + + if c.IsTracked("id1") { + t.Fatalf("id1 should not have been tracked") + } + + if c.IsTracked("id1") { + t.Fatalf("id1 should not have been tracked") + } + +}