Address review feedback

This commit is contained in:
vishalnayak 2016-08-29 12:37:39 -04:00
parent 56e42cf03d
commit 160ba48eb4
2 changed files with 71 additions and 163 deletions

View file

@ -8,14 +8,12 @@ import (
"sync"
"time"
"github.com/hashicorp/go-multierror"
clientconfig "github.com/hashicorp/nomad/client/config"
"github.com/hashicorp/nomad/client/rpcproxy"
"github.com/hashicorp/nomad/nomad"
"github.com/hashicorp/nomad/nomad/structs"
"github.com/hashicorp/nomad/nomad/structs/config"
vaultapi "github.com/hashicorp/vault/api"
"github.com/mitchellh/mapstructure"
)
// The interface which nomad client uses to interact with vault and
@ -158,11 +156,6 @@ func NewVaultClient(node *structs.Node, region string, config *config.VaultConfi
return nil, fmt.Errorf("nil vault config")
}
// Creation of a vault client requires a token
if config.Token == "" {
return nil, fmt.Errorf("vault token not set")
}
if config.TaskTokenTTL == "" {
return nil, fmt.Errorf("task_token_ttl not set")
}
@ -189,43 +182,41 @@ func NewVaultClient(node *structs.Node, region string, config *config.VaultConfi
stopCh: make(chan struct{}),
// Update channel should be a buffered channel
updateCh: make(chan struct{}, 1),
heap: NewVaultClientHeap(),
heap: newVaultClientHeap(),
logger: logger,
}
// Get the Vault API configuration
apiConf, err := config.ApiConfig()
if err != nil {
logger.Printf("[ERR] client/vaultclient: failed to create vault API config: %v", err)
logger.Printf("[ERR] client.vault: failed to create vault API config: %v", err)
return nil, err
}
// Create the Vault API client
client, err := vaultapi.NewClient(apiConf)
if err != nil {
logger.Printf("[ERR] client/vaultclient: failed to create Vault client. Not retrying: %v", err)
logger.Printf("[ERR] client.vault: failed to create Vault client. Not retrying: %v", err)
return nil, err
}
// Set the token and store the client
client.SetToken(c.config.Token)
c.client = client
return c, nil
}
// NewVaultClientHeap returns a new vault client heap with both the heap and a
// newVaultClientHeap returns a new vault client heap with both the heap and a
// map which is a secondary index for heap elements, both initialized.
func NewVaultClientHeap() *vaultClientHeap {
func newVaultClientHeap() *vaultClientHeap {
return &vaultClientHeap{
heapMap: make(map[string]*vaultClientHeapEntry),
heap: make(vaultDataHeapImp, 0),
}
}
// IsTracked returns if a given identifier is already present in the heap and
// isTracked returns if a given identifier is already present in the heap and
// hence is being renewed. Lock should be held before calling this method.
func (c *vaultClient) IsTracked(id string) bool {
func (c *vaultClient) isTracked(id string) bool {
if id == "" {
return false
}
@ -240,7 +231,7 @@ func (c *vaultClient) Start() {
return
}
c.logger.Printf("[DEBUG] client/vaultclient: establishing connection to vault")
c.logger.Printf("[DEBUG] client.vault: establishing connection to vault")
go c.establishConnection()
}
@ -269,7 +260,7 @@ OUTER:
case <-retryTimer.C:
// Ensure the API is reachable
if _, err := c.client.Sys().InitStatus(); err != nil {
c.logger.Printf("[WARN] client/vaultclient: failed to contact Vault API. Retrying in %v",
c.logger.Printf("[WARN] client.vault: failed to contact Vault API. Retrying in %v",
c.config.ConnectionRetryIntv)
retryTimer.Reset(c.config.ConnectionRetryIntv)
continue OUTER
@ -283,85 +274,9 @@ OUTER:
c.connEstablished = true
c.lock.Unlock()
// Retrieve our token, validate it and parse the lease duration
if err := c.parseSelfToken(); err != nil {
c.logger.Printf("[ERR] client/vaultclient: failed to lookup self token and not retrying: %v", err)
return
}
// Begin the renewal loop
go c.run()
c.logger.Printf("[DEBUG] client/vaultclient: started")
// If we are given a token that needs renewal, place it in the renewal
// loop.
// Root tokens can also have a TTL
if c.token.Root && c.token.TTL == 0 {
c.logger.Printf("[DEBUG] client/vaultclient: not renewing token as it is a non-expiring root token")
} else {
c.logger.Printf("[DEBUG] client/vaultclient: token lease duration is %v", time.Duration(c.token.CreationTTL)*time.Second)
// Renew the token and place it in renewal min-heap
errCh := c.RenewToken(c.config.Token, c.token.CreationTTL)
// Catch the renewal error of VaultClient's token.
go func(errCh <-chan error) {
var err error
for {
select {
case err = <-errCh:
c.logger.Printf("[ERR] client/vaultclient: error while renewing the vault client's token: %v", err)
}
}
}(errCh)
}
}
// parseSelfToken looks up the VaultClient's token in vault and parses its data
// storing it in the client. If the token is not valid for Nomads purposes an
// error is returned.
func (c *vaultClient) parseSelfToken() error {
// Get the initial lease duration
auth := c.client.Auth().Token()
self, err := auth.LookupSelf()
if err != nil {
return fmt.Errorf("failed to lookup VaultClient's token: %v", err)
}
// Read and parse the fields
var data tokenData
if err := mapstructure.WeakDecode(self.Data, &data); err != nil {
return fmt.Errorf("failed to parse Vault token's data block: %v", err)
}
root := false
for _, p := range data.Policies {
if p == "root" {
root = true
break
}
}
if !data.Renewable && !root {
return fmt.Errorf("vault token is not renewable or root")
}
if data.CreationTTL == 0 && !root {
return fmt.Errorf("invalid lease duration of zero")
}
if data.TTL == 0 && !root {
return fmt.Errorf("token TTL is zero")
}
if !root && data.Role == "" {
return fmt.Errorf("token role name must be set when not using a root token")
}
data.Root = root
c.token = &data
return nil
c.logger.Printf("[DEBUG] client.vault: started")
}
// Stops the renewal loop of vault client
@ -382,42 +297,39 @@ func (c *vaultClient) Stop() {
// The return value is a map containing all the unwrapped tokens indexed by the
// task name.
func (c *vaultClient) DeriveToken(alloc *structs.Allocation, taskNames []string) (map[string]string, error) {
var result *multierror.Error
if !c.running {
result = multierror.Append(fmt.Errorf("vault client is not running"))
return nil, result
return nil, fmt.Errorf("vault client is not running")
}
if alloc == nil {
result = multierror.Append(fmt.Errorf("nil allocation"))
return nil, result
}
if taskNames == nil || len(taskNames) == 0 {
result = multierror.Append(fmt.Errorf("missing task names"))
return nil, result
return nil, fmt.Errorf("nil allocation")
}
if taskNames == nil || len(taskNames) == 0 {
return nil, fmt.Errorf("missing task names")
}
found := false
verifiedTasks := []string{}
// Check if the given task names actually exist in the allocation
for _, taskName := range taskNames {
found = false
for _, group := range alloc.Job.TaskGroups {
for _, task := range group.Tasks {
if task.Name == taskName {
found = true
found := false
// Check if the given task names actually exist in the allocation under
// the correct group name
for _, group := range alloc.Job.TaskGroups {
// Refer only to the group belonging to the allocation
if group.Name == alloc.TaskGroup {
for _, taskName := range taskNames {
found = false
for _, task := range group.Tasks {
if task.Name == taskName {
found = true
}
}
if !found {
c.logger.Printf("[ERR] task %q not found in the allocation", taskName)
return nil, fmt.Errorf("task %q not found in the allocaition", taskName)
}
verifiedTasks = append(verifiedTasks, taskName)
}
}
if found {
verifiedTasks = append(verifiedTasks, taskName)
} else {
// Append the error for an invalid task name, but don't
// break out of the loop. Continue to process other
// tasks.
result = multierror.Append(result, fmt.Errorf("task %s not found in the allocation", taskName))
}
}
// DeriveVaultToken of nomad server can take in a set of tasks and
@ -436,38 +348,32 @@ func (c *vaultClient) DeriveToken(alloc *structs.Allocation, taskNames []string)
// Derive the tokens
var resp structs.DeriveVaultTokenResponse
if err := c.RPC("Node.DeriveVaultToken", &req, &resp); err != nil {
c.logger.Printf("[ERR] client/vaultclient: failed to derive vault tokens: %v", err)
result = multierror.Append(result, fmt.Errorf("failed to derive vault tokens: %v", err))
return nil, result
c.logger.Printf("[ERR] client.vault: failed to derive vault tokens: %v", err)
return nil, fmt.Errorf("failed to derive vault tokens: %v", err)
}
if resp.Tasks == nil {
c.logger.Printf("[ERR] client/vaultclient: failed to derive vault token: invalid response")
result = multierror.Append(result, fmt.Errorf("failed to derive vault tokens: invalid response"))
return nil, result
c.logger.Printf("[ERR] client.vault: failed to derive vault token: invalid response")
return nil, fmt.Errorf("failed to derive vault tokens: invalid response")
}
unwrappedTokens := make(map[string]string)
// Retrieve the wrapped tokens from the response and unwrap it using
// the VaultClient's token, which is cached at the API client.
// Retrieve the wrapped tokens from the response and unwrap it
for _, taskName := range verifiedTasks {
// Get the wrapped token
wrappedToken, ok := resp.Tasks[taskName]
if !ok {
c.logger.Printf("[ERR] client/vaultclient: wrapped token missing for task %q", taskName)
result = multierror.Append(result, fmt.Errorf("wrapped token missing for task %q", taskName))
return nil, result
c.logger.Printf("[ERR] client.vault: wrapped token missing for task %q", taskName)
return nil, fmt.Errorf("wrapped token missing for task %q", taskName)
}
// Unwrap the vault token
unwrapResp, err := c.client.Logical().Unwrap(wrappedToken)
if err != nil {
result = multierror.Append(result, fmt.Errorf("failed to unwrap the token for task %q: %v", taskName, err))
return nil, result
return nil, fmt.Errorf("failed to unwrap the token for task %q: %v", taskName, err)
}
if unwrapResp == nil || unwrapResp.Auth == nil || unwrapResp.Auth.ClientToken == "" {
result = multierror.Append(result, fmt.Errorf("failed to unwrap the token for task %q", taskName))
return nil, result
return nil, fmt.Errorf("failed to unwrap the token for task %q", taskName)
}
// Append the unwrapped token to the return value
@ -493,18 +399,19 @@ func (c *vaultClient) GetConsulACL(token, path string) (*vaultapi.Secret, error)
// Use the token supplied to interact with vault
c.client.SetToken(token)
// Restore the token in client to VaultClient's token
defer c.client.SetToken(c.config.Token)
// Reset the token before returning
defer c.client.SetToken("")
// Read the consul ACL token and return the secret directly
return c.client.Logical().Read(path)
}
// RenewToken renews the supplied token and adds it to the min-heap so that it
// is renewed periodically by the renewal loop. Any error returned during
// renewal will be written to a buffered channel and the channel is returned
// instead of an actual error. This helps the caller be notified of a renewal
// failure asynchronously for appropriate actions to be taken.
// RenewToken renews the supplied token for a given duration (in seconds) and
// adds it to the min-heap so that it is renewed periodically by the renewal
// loop. Any error returned during renewal will be written to a buffered
// channel and the channel is returned instead of an actual error. This helps
// the caller be notified of a renewal failure asynchronously for appropriate
// actions to be taken.
func (c *vaultClient) RenewToken(token string, increment int) <-chan error {
// Create a buffered error channel
errCh := make(chan error, 1)
@ -531,19 +438,20 @@ func (c *vaultClient) RenewToken(token string, increment int) <-chan error {
// error channel.
if err := c.renew(renewalReq); err != nil {
errCh <- err
close(errCh)
}
return errCh
}
// RenewLease renews the supplied lease identifier for a supplied duration and
// adds it to the min-heap so that it gets renewed periodically by the renewal
// loop. Any error returned during renewal will be written to a buffered
// channel and the channel is returned instead of an actual error. This helps
// the caller be notified of a renewal failure asynchronously for appropriate
// actions to be taken.
// RenewLease renews the supplied lease identifier for a supplied duration (in
// seconds) and adds it to the min-heap so that it gets renewed periodically by
// the renewal loop. Any error returned during renewal will be written to a
// buffered channel and the channel is returned instead of an actual error.
// This helps the caller be notified of a renewal failure asynchronously for
// appropriate actions to be taken.
func (c *vaultClient) RenewLease(leaseId string, increment int) <-chan error {
c.logger.Printf("[DEBUG] client/vaultclient: renewing lease %q", leaseId)
c.logger.Printf("[DEBUG] client.vault: renewing lease %q", leaseId)
// Create a buffered error channel
errCh := make(chan error, 1)
@ -567,6 +475,7 @@ func (c *vaultClient) RenewLease(leaseId string, increment int) <-chan error {
// Renew the secret and send any error to the dedicated error channel
if err := c.renew(renewalReq); err != nil {
errCh <- err
close(errCh)
}
return errCh
@ -598,9 +507,8 @@ func (c *vaultClient) renew(req *vaultClientRenewalRequest) error {
var renewalErr error
leaseDuration := req.increment
if req.isToken {
// Reset the token in the API client to that of VaultClient
// before returning
defer c.client.SetToken(c.config.Token)
// Reset the token in the API client before returning
defer c.client.SetToken("")
// Set the token in the API client to the one that needs
// renewal
@ -643,13 +551,13 @@ func (c *vaultClient) renew(req *vaultClientRenewalRequest) error {
rand.Seed(time.Now().Unix())
duration = min + rand.Intn(max-min)
}
c.logger.Printf("[DEBUG] client/vaultclient: req.increment: %d, leaseDuration: %d, duration: %d",
c.logger.Printf("[DEBUG] client.vault: req.increment: %d, leaseDuration: %d, duration: %d",
req.increment, leaseDuration, duration)
// Determine the next renewal time
next := time.Now().Add(time.Duration(duration) * time.Second)
if c.IsTracked(req.id) {
if c.isTracked(req.id) {
// If the identifier is already tracked, this indicates a
// subsequest renewal. In this case, update the existing
// element in the heap with the new renewal time.
@ -727,7 +635,7 @@ func (c *vaultClient) run() {
case <-c.updateCh:
continue
case <-c.stopCh:
c.logger.Printf("[DEBUG] client/vaultclient: stopped")
c.logger.Printf("[DEBUG] client.vault: stopped")
return
}
}
@ -751,7 +659,7 @@ func (c *vaultClient) stopRenew(id string) error {
c.lock.Lock()
defer c.lock.Unlock()
if !c.IsTracked(id) {
if !c.isTracked(id) {
return nil
}

View file

@ -161,7 +161,7 @@ func TestVaultClient_Heap(t *testing.T) {
if err := c.heap.Push(renewalReq1, now.Add(50*time.Second)); err != nil {
t.Fatal(err)
}
if !c.IsTracked("id1") {
if !c.isTracked("id1") {
t.Fatalf("id1 should have been tracked")
}
@ -173,7 +173,7 @@ func TestVaultClient_Heap(t *testing.T) {
if err := c.heap.Push(renewalReq2, now.Add(40*time.Second)); err != nil {
t.Fatal(err)
}
if !c.IsTracked("id2") {
if !c.isTracked("id2") {
t.Fatalf("id2 should have been tracked")
}
@ -185,7 +185,7 @@ func TestVaultClient_Heap(t *testing.T) {
if err := c.heap.Push(renewalReq3, now.Add(60*time.Second)); err != nil {
t.Fatal(err)
}
if !c.IsTracked("id3") {
if !c.isTracked("id3") {
t.Fatalf("id3 should have been tracked")
}
@ -226,15 +226,15 @@ func TestVaultClient_Heap(t *testing.T) {
t.Fatal(err)
}
if c.IsTracked("id1") {
if c.isTracked("id1") {
t.Fatalf("id1 should not have been tracked")
}
if c.IsTracked("id1") {
if c.isTracked("id1") {
t.Fatalf("id1 should not have been tracked")
}
if c.IsTracked("id1") {
if c.isTracked("id1") {
t.Fatalf("id1 should not have been tracked")
}