Address review feedback
This commit is contained in:
parent
56e42cf03d
commit
160ba48eb4
|
@ -8,14 +8,12 @@ import (
|
|||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
clientconfig "github.com/hashicorp/nomad/client/config"
|
||||
"github.com/hashicorp/nomad/client/rpcproxy"
|
||||
"github.com/hashicorp/nomad/nomad"
|
||||
"github.com/hashicorp/nomad/nomad/structs"
|
||||
"github.com/hashicorp/nomad/nomad/structs/config"
|
||||
vaultapi "github.com/hashicorp/vault/api"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
)
|
||||
|
||||
// The interface which nomad client uses to interact with vault and
|
||||
|
@ -158,11 +156,6 @@ func NewVaultClient(node *structs.Node, region string, config *config.VaultConfi
|
|||
return nil, fmt.Errorf("nil vault config")
|
||||
}
|
||||
|
||||
// Creation of a vault client requires a token
|
||||
if config.Token == "" {
|
||||
return nil, fmt.Errorf("vault token not set")
|
||||
}
|
||||
|
||||
if config.TaskTokenTTL == "" {
|
||||
return nil, fmt.Errorf("task_token_ttl not set")
|
||||
}
|
||||
|
@ -189,43 +182,41 @@ func NewVaultClient(node *structs.Node, region string, config *config.VaultConfi
|
|||
stopCh: make(chan struct{}),
|
||||
// Update channel should be a buffered channel
|
||||
updateCh: make(chan struct{}, 1),
|
||||
heap: NewVaultClientHeap(),
|
||||
heap: newVaultClientHeap(),
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
// Get the Vault API configuration
|
||||
apiConf, err := config.ApiConfig()
|
||||
if err != nil {
|
||||
logger.Printf("[ERR] client/vaultclient: failed to create vault API config: %v", err)
|
||||
logger.Printf("[ERR] client.vault: failed to create vault API config: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create the Vault API client
|
||||
client, err := vaultapi.NewClient(apiConf)
|
||||
if err != nil {
|
||||
logger.Printf("[ERR] client/vaultclient: failed to create Vault client. Not retrying: %v", err)
|
||||
logger.Printf("[ERR] client.vault: failed to create Vault client. Not retrying: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Set the token and store the client
|
||||
client.SetToken(c.config.Token)
|
||||
c.client = client
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// NewVaultClientHeap returns a new vault client heap with both the heap and a
|
||||
// newVaultClientHeap returns a new vault client heap with both the heap and a
|
||||
// map which is a secondary index for heap elements, both initialized.
|
||||
func NewVaultClientHeap() *vaultClientHeap {
|
||||
func newVaultClientHeap() *vaultClientHeap {
|
||||
return &vaultClientHeap{
|
||||
heapMap: make(map[string]*vaultClientHeapEntry),
|
||||
heap: make(vaultDataHeapImp, 0),
|
||||
}
|
||||
}
|
||||
|
||||
// IsTracked returns if a given identifier is already present in the heap and
|
||||
// isTracked returns if a given identifier is already present in the heap and
|
||||
// hence is being renewed. Lock should be held before calling this method.
|
||||
func (c *vaultClient) IsTracked(id string) bool {
|
||||
func (c *vaultClient) isTracked(id string) bool {
|
||||
if id == "" {
|
||||
return false
|
||||
}
|
||||
|
@ -240,7 +231,7 @@ func (c *vaultClient) Start() {
|
|||
return
|
||||
}
|
||||
|
||||
c.logger.Printf("[DEBUG] client/vaultclient: establishing connection to vault")
|
||||
c.logger.Printf("[DEBUG] client.vault: establishing connection to vault")
|
||||
go c.establishConnection()
|
||||
}
|
||||
|
||||
|
@ -269,7 +260,7 @@ OUTER:
|
|||
case <-retryTimer.C:
|
||||
// Ensure the API is reachable
|
||||
if _, err := c.client.Sys().InitStatus(); err != nil {
|
||||
c.logger.Printf("[WARN] client/vaultclient: failed to contact Vault API. Retrying in %v",
|
||||
c.logger.Printf("[WARN] client.vault: failed to contact Vault API. Retrying in %v",
|
||||
c.config.ConnectionRetryIntv)
|
||||
retryTimer.Reset(c.config.ConnectionRetryIntv)
|
||||
continue OUTER
|
||||
|
@ -283,85 +274,9 @@ OUTER:
|
|||
c.connEstablished = true
|
||||
c.lock.Unlock()
|
||||
|
||||
// Retrieve our token, validate it and parse the lease duration
|
||||
if err := c.parseSelfToken(); err != nil {
|
||||
c.logger.Printf("[ERR] client/vaultclient: failed to lookup self token and not retrying: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Begin the renewal loop
|
||||
go c.run()
|
||||
c.logger.Printf("[DEBUG] client/vaultclient: started")
|
||||
|
||||
// If we are given a token that needs renewal, place it in the renewal
|
||||
// loop.
|
||||
|
||||
// Root tokens can also have a TTL
|
||||
if c.token.Root && c.token.TTL == 0 {
|
||||
c.logger.Printf("[DEBUG] client/vaultclient: not renewing token as it is a non-expiring root token")
|
||||
} else {
|
||||
c.logger.Printf("[DEBUG] client/vaultclient: token lease duration is %v", time.Duration(c.token.CreationTTL)*time.Second)
|
||||
|
||||
// Renew the token and place it in renewal min-heap
|
||||
errCh := c.RenewToken(c.config.Token, c.token.CreationTTL)
|
||||
|
||||
// Catch the renewal error of VaultClient's token.
|
||||
go func(errCh <-chan error) {
|
||||
var err error
|
||||
for {
|
||||
select {
|
||||
case err = <-errCh:
|
||||
c.logger.Printf("[ERR] client/vaultclient: error while renewing the vault client's token: %v", err)
|
||||
}
|
||||
}
|
||||
}(errCh)
|
||||
}
|
||||
}
|
||||
|
||||
// parseSelfToken looks up the VaultClient's token in vault and parses its data
|
||||
// storing it in the client. If the token is not valid for Nomads purposes an
|
||||
// error is returned.
|
||||
func (c *vaultClient) parseSelfToken() error {
|
||||
// Get the initial lease duration
|
||||
auth := c.client.Auth().Token()
|
||||
self, err := auth.LookupSelf()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to lookup VaultClient's token: %v", err)
|
||||
}
|
||||
|
||||
// Read and parse the fields
|
||||
var data tokenData
|
||||
if err := mapstructure.WeakDecode(self.Data, &data); err != nil {
|
||||
return fmt.Errorf("failed to parse Vault token's data block: %v", err)
|
||||
}
|
||||
|
||||
root := false
|
||||
for _, p := range data.Policies {
|
||||
if p == "root" {
|
||||
root = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !data.Renewable && !root {
|
||||
return fmt.Errorf("vault token is not renewable or root")
|
||||
}
|
||||
|
||||
if data.CreationTTL == 0 && !root {
|
||||
return fmt.Errorf("invalid lease duration of zero")
|
||||
}
|
||||
|
||||
if data.TTL == 0 && !root {
|
||||
return fmt.Errorf("token TTL is zero")
|
||||
}
|
||||
|
||||
if !root && data.Role == "" {
|
||||
return fmt.Errorf("token role name must be set when not using a root token")
|
||||
}
|
||||
|
||||
data.Root = root
|
||||
c.token = &data
|
||||
return nil
|
||||
c.logger.Printf("[DEBUG] client.vault: started")
|
||||
}
|
||||
|
||||
// Stops the renewal loop of vault client
|
||||
|
@ -382,42 +297,39 @@ func (c *vaultClient) Stop() {
|
|||
// The return value is a map containing all the unwrapped tokens indexed by the
|
||||
// task name.
|
||||
func (c *vaultClient) DeriveToken(alloc *structs.Allocation, taskNames []string) (map[string]string, error) {
|
||||
var result *multierror.Error
|
||||
|
||||
if !c.running {
|
||||
result = multierror.Append(fmt.Errorf("vault client is not running"))
|
||||
return nil, result
|
||||
return nil, fmt.Errorf("vault client is not running")
|
||||
}
|
||||
|
||||
if alloc == nil {
|
||||
result = multierror.Append(fmt.Errorf("nil allocation"))
|
||||
return nil, result
|
||||
}
|
||||
if taskNames == nil || len(taskNames) == 0 {
|
||||
result = multierror.Append(fmt.Errorf("missing task names"))
|
||||
return nil, result
|
||||
return nil, fmt.Errorf("nil allocation")
|
||||
}
|
||||
|
||||
if taskNames == nil || len(taskNames) == 0 {
|
||||
return nil, fmt.Errorf("missing task names")
|
||||
}
|
||||
|
||||
found := false
|
||||
verifiedTasks := []string{}
|
||||
// Check if the given task names actually exist in the allocation
|
||||
for _, taskName := range taskNames {
|
||||
found = false
|
||||
for _, group := range alloc.Job.TaskGroups {
|
||||
for _, task := range group.Tasks {
|
||||
if task.Name == taskName {
|
||||
found = true
|
||||
found := false
|
||||
// Check if the given task names actually exist in the allocation under
|
||||
// the correct group name
|
||||
for _, group := range alloc.Job.TaskGroups {
|
||||
// Refer only to the group belonging to the allocation
|
||||
if group.Name == alloc.TaskGroup {
|
||||
for _, taskName := range taskNames {
|
||||
found = false
|
||||
for _, task := range group.Tasks {
|
||||
if task.Name == taskName {
|
||||
found = true
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
c.logger.Printf("[ERR] task %q not found in the allocation", taskName)
|
||||
return nil, fmt.Errorf("task %q not found in the allocaition", taskName)
|
||||
}
|
||||
verifiedTasks = append(verifiedTasks, taskName)
|
||||
}
|
||||
}
|
||||
if found {
|
||||
verifiedTasks = append(verifiedTasks, taskName)
|
||||
} else {
|
||||
// Append the error for an invalid task name, but don't
|
||||
// break out of the loop. Continue to process other
|
||||
// tasks.
|
||||
result = multierror.Append(result, fmt.Errorf("task %s not found in the allocation", taskName))
|
||||
}
|
||||
}
|
||||
|
||||
// DeriveVaultToken of nomad server can take in a set of tasks and
|
||||
|
@ -436,38 +348,32 @@ func (c *vaultClient) DeriveToken(alloc *structs.Allocation, taskNames []string)
|
|||
// Derive the tokens
|
||||
var resp structs.DeriveVaultTokenResponse
|
||||
if err := c.RPC("Node.DeriveVaultToken", &req, &resp); err != nil {
|
||||
c.logger.Printf("[ERR] client/vaultclient: failed to derive vault tokens: %v", err)
|
||||
result = multierror.Append(result, fmt.Errorf("failed to derive vault tokens: %v", err))
|
||||
return nil, result
|
||||
c.logger.Printf("[ERR] client.vault: failed to derive vault tokens: %v", err)
|
||||
return nil, fmt.Errorf("failed to derive vault tokens: %v", err)
|
||||
}
|
||||
if resp.Tasks == nil {
|
||||
c.logger.Printf("[ERR] client/vaultclient: failed to derive vault token: invalid response")
|
||||
result = multierror.Append(result, fmt.Errorf("failed to derive vault tokens: invalid response"))
|
||||
return nil, result
|
||||
c.logger.Printf("[ERR] client.vault: failed to derive vault token: invalid response")
|
||||
return nil, fmt.Errorf("failed to derive vault tokens: invalid response")
|
||||
}
|
||||
|
||||
unwrappedTokens := make(map[string]string)
|
||||
|
||||
// Retrieve the wrapped tokens from the response and unwrap it using
|
||||
// the VaultClient's token, which is cached at the API client.
|
||||
// Retrieve the wrapped tokens from the response and unwrap it
|
||||
for _, taskName := range verifiedTasks {
|
||||
// Get the wrapped token
|
||||
wrappedToken, ok := resp.Tasks[taskName]
|
||||
if !ok {
|
||||
c.logger.Printf("[ERR] client/vaultclient: wrapped token missing for task %q", taskName)
|
||||
result = multierror.Append(result, fmt.Errorf("wrapped token missing for task %q", taskName))
|
||||
return nil, result
|
||||
c.logger.Printf("[ERR] client.vault: wrapped token missing for task %q", taskName)
|
||||
return nil, fmt.Errorf("wrapped token missing for task %q", taskName)
|
||||
}
|
||||
|
||||
// Unwrap the vault token
|
||||
unwrapResp, err := c.client.Logical().Unwrap(wrappedToken)
|
||||
if err != nil {
|
||||
result = multierror.Append(result, fmt.Errorf("failed to unwrap the token for task %q: %v", taskName, err))
|
||||
return nil, result
|
||||
return nil, fmt.Errorf("failed to unwrap the token for task %q: %v", taskName, err)
|
||||
}
|
||||
if unwrapResp == nil || unwrapResp.Auth == nil || unwrapResp.Auth.ClientToken == "" {
|
||||
result = multierror.Append(result, fmt.Errorf("failed to unwrap the token for task %q", taskName))
|
||||
return nil, result
|
||||
return nil, fmt.Errorf("failed to unwrap the token for task %q", taskName)
|
||||
}
|
||||
|
||||
// Append the unwrapped token to the return value
|
||||
|
@ -493,18 +399,19 @@ func (c *vaultClient) GetConsulACL(token, path string) (*vaultapi.Secret, error)
|
|||
// Use the token supplied to interact with vault
|
||||
c.client.SetToken(token)
|
||||
|
||||
// Restore the token in client to VaultClient's token
|
||||
defer c.client.SetToken(c.config.Token)
|
||||
// Reset the token before returning
|
||||
defer c.client.SetToken("")
|
||||
|
||||
// Read the consul ACL token and return the secret directly
|
||||
return c.client.Logical().Read(path)
|
||||
}
|
||||
|
||||
// RenewToken renews the supplied token and adds it to the min-heap so that it
|
||||
// is renewed periodically by the renewal loop. Any error returned during
|
||||
// renewal will be written to a buffered channel and the channel is returned
|
||||
// instead of an actual error. This helps the caller be notified of a renewal
|
||||
// failure asynchronously for appropriate actions to be taken.
|
||||
// RenewToken renews the supplied token for a given duration (in seconds) and
|
||||
// adds it to the min-heap so that it is renewed periodically by the renewal
|
||||
// loop. Any error returned during renewal will be written to a buffered
|
||||
// channel and the channel is returned instead of an actual error. This helps
|
||||
// the caller be notified of a renewal failure asynchronously for appropriate
|
||||
// actions to be taken.
|
||||
func (c *vaultClient) RenewToken(token string, increment int) <-chan error {
|
||||
// Create a buffered error channel
|
||||
errCh := make(chan error, 1)
|
||||
|
@ -531,19 +438,20 @@ func (c *vaultClient) RenewToken(token string, increment int) <-chan error {
|
|||
// error channel.
|
||||
if err := c.renew(renewalReq); err != nil {
|
||||
errCh <- err
|
||||
close(errCh)
|
||||
}
|
||||
|
||||
return errCh
|
||||
}
|
||||
|
||||
// RenewLease renews the supplied lease identifier for a supplied duration and
|
||||
// adds it to the min-heap so that it gets renewed periodically by the renewal
|
||||
// loop. Any error returned during renewal will be written to a buffered
|
||||
// channel and the channel is returned instead of an actual error. This helps
|
||||
// the caller be notified of a renewal failure asynchronously for appropriate
|
||||
// actions to be taken.
|
||||
// RenewLease renews the supplied lease identifier for a supplied duration (in
|
||||
// seconds) and adds it to the min-heap so that it gets renewed periodically by
|
||||
// the renewal loop. Any error returned during renewal will be written to a
|
||||
// buffered channel and the channel is returned instead of an actual error.
|
||||
// This helps the caller be notified of a renewal failure asynchronously for
|
||||
// appropriate actions to be taken.
|
||||
func (c *vaultClient) RenewLease(leaseId string, increment int) <-chan error {
|
||||
c.logger.Printf("[DEBUG] client/vaultclient: renewing lease %q", leaseId)
|
||||
c.logger.Printf("[DEBUG] client.vault: renewing lease %q", leaseId)
|
||||
// Create a buffered error channel
|
||||
errCh := make(chan error, 1)
|
||||
|
||||
|
@ -567,6 +475,7 @@ func (c *vaultClient) RenewLease(leaseId string, increment int) <-chan error {
|
|||
// Renew the secret and send any error to the dedicated error channel
|
||||
if err := c.renew(renewalReq); err != nil {
|
||||
errCh <- err
|
||||
close(errCh)
|
||||
}
|
||||
|
||||
return errCh
|
||||
|
@ -598,9 +507,8 @@ func (c *vaultClient) renew(req *vaultClientRenewalRequest) error {
|
|||
var renewalErr error
|
||||
leaseDuration := req.increment
|
||||
if req.isToken {
|
||||
// Reset the token in the API client to that of VaultClient
|
||||
// before returning
|
||||
defer c.client.SetToken(c.config.Token)
|
||||
// Reset the token in the API client before returning
|
||||
defer c.client.SetToken("")
|
||||
|
||||
// Set the token in the API client to the one that needs
|
||||
// renewal
|
||||
|
@ -643,13 +551,13 @@ func (c *vaultClient) renew(req *vaultClientRenewalRequest) error {
|
|||
rand.Seed(time.Now().Unix())
|
||||
duration = min + rand.Intn(max-min)
|
||||
}
|
||||
c.logger.Printf("[DEBUG] client/vaultclient: req.increment: %d, leaseDuration: %d, duration: %d",
|
||||
c.logger.Printf("[DEBUG] client.vault: req.increment: %d, leaseDuration: %d, duration: %d",
|
||||
req.increment, leaseDuration, duration)
|
||||
|
||||
// Determine the next renewal time
|
||||
next := time.Now().Add(time.Duration(duration) * time.Second)
|
||||
|
||||
if c.IsTracked(req.id) {
|
||||
if c.isTracked(req.id) {
|
||||
// If the identifier is already tracked, this indicates a
|
||||
// subsequest renewal. In this case, update the existing
|
||||
// element in the heap with the new renewal time.
|
||||
|
@ -727,7 +635,7 @@ func (c *vaultClient) run() {
|
|||
case <-c.updateCh:
|
||||
continue
|
||||
case <-c.stopCh:
|
||||
c.logger.Printf("[DEBUG] client/vaultclient: stopped")
|
||||
c.logger.Printf("[DEBUG] client.vault: stopped")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
@ -751,7 +659,7 @@ func (c *vaultClient) stopRenew(id string) error {
|
|||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
|
||||
if !c.IsTracked(id) {
|
||||
if !c.isTracked(id) {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -161,7 +161,7 @@ func TestVaultClient_Heap(t *testing.T) {
|
|||
if err := c.heap.Push(renewalReq1, now.Add(50*time.Second)); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !c.IsTracked("id1") {
|
||||
if !c.isTracked("id1") {
|
||||
t.Fatalf("id1 should have been tracked")
|
||||
}
|
||||
|
||||
|
@ -173,7 +173,7 @@ func TestVaultClient_Heap(t *testing.T) {
|
|||
if err := c.heap.Push(renewalReq2, now.Add(40*time.Second)); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !c.IsTracked("id2") {
|
||||
if !c.isTracked("id2") {
|
||||
t.Fatalf("id2 should have been tracked")
|
||||
}
|
||||
|
||||
|
@ -185,7 +185,7 @@ func TestVaultClient_Heap(t *testing.T) {
|
|||
if err := c.heap.Push(renewalReq3, now.Add(60*time.Second)); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !c.IsTracked("id3") {
|
||||
if !c.isTracked("id3") {
|
||||
t.Fatalf("id3 should have been tracked")
|
||||
}
|
||||
|
||||
|
@ -226,15 +226,15 @@ func TestVaultClient_Heap(t *testing.T) {
|
|||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if c.IsTracked("id1") {
|
||||
if c.isTracked("id1") {
|
||||
t.Fatalf("id1 should not have been tracked")
|
||||
}
|
||||
|
||||
if c.IsTracked("id1") {
|
||||
if c.isTracked("id1") {
|
||||
t.Fatalf("id1 should not have been tracked")
|
||||
}
|
||||
|
||||
if c.IsTracked("id1") {
|
||||
if c.isTracked("id1") {
|
||||
t.Fatalf("id1 should not have been tracked")
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in a new issue