diff --git a/nomad/fsm.go b/nomad/fsm.go index 318ad8bf3..61c52e612 100644 --- a/nomad/fsm.go +++ b/nomad/fsm.go @@ -140,6 +140,8 @@ func (n *nomadFSM) Apply(log *raft.Log) interface{} { return n.applyReconcileSummaries(buf[1:], log.Index) case structs.VaultAccessorRegisterRequestType: return n.applyUpsertVaultAccessor(buf[1:], log.Index) + case structs.VaultAccessorDegisterRequestType: + return n.applyDeregisterVaultAccessor(buf[1:], log.Index) default: if ignoreUnknown { n.logger.Printf("[WARN] nomad.fsm: ignoring unknown message type (%d), upgrade to newer version", msgType) @@ -461,7 +463,7 @@ func (n *nomadFSM) applyReconcileSummaries(buf []byte, index uint64) interface{} // and task func (n *nomadFSM) applyUpsertVaultAccessor(buf []byte, index uint64) interface{} { defer metrics.MeasureSince([]string{"nomad", "fsm", "upsert_vault_accessor"}, time.Now()) - var req structs.VaultAccessorRegisterRequest + var req structs.VaultAccessorsRequest if err := structs.Decode(buf, &req); err != nil { panic(fmt.Errorf("failed to decode request: %v", err)) } @@ -474,6 +476,23 @@ func (n *nomadFSM) applyUpsertVaultAccessor(buf []byte, index uint64) interface{ return nil } +// applyDeregisterVaultAccessor stores the Vault accessors for a given allocation +// and task +func (n *nomadFSM) applyDeregisterVaultAccessor(buf []byte, index uint64) interface{} { + defer metrics.MeasureSince([]string{"nomad", "fsm", "deregister_vault_accessor"}, time.Now()) + var req structs.VaultAccessorsRequest + if err := structs.Decode(buf, &req); err != nil { + panic(fmt.Errorf("failed to decode request: %v", err)) + } + + if err := n.state.DeleteVaultAccessors(index, req.Accessors); err != nil { + n.logger.Printf("[ERR] nomad.fsm: DeregisterVaultAccessor failed: %v", err) + return err + } + + return nil +} + func (n *nomadFSM) Snapshot() (raft.FSMSnapshot, error) { // Create a new snapshot snap, err := n.state.Snapshot() diff --git a/nomad/fsm_test.go b/nomad/fsm_test.go index 805e365c4..704637a66 100644 --- a/nomad/fsm_test.go +++ b/nomad/fsm_test.go @@ -776,7 +776,7 @@ func TestFSM_UpsertVaultAccessor(t *testing.T) { va := mock.VaultAccessor() va2 := mock.VaultAccessor() - req := structs.VaultAccessorRegisterRequest{ + req := structs.VaultAccessorsRequest{ Accessors: []*structs.VaultAccessor{va, va2}, } buf, err := structs.Encode(structs.VaultAccessorRegisterRequestType, req) @@ -818,6 +818,47 @@ func TestFSM_UpsertVaultAccessor(t *testing.T) { } } +func TestFSM_DeregisterVaultAccessor(t *testing.T) { + fsm := testFSM(t) + fsm.blockedEvals.SetEnabled(true) + + va := mock.VaultAccessor() + va2 := mock.VaultAccessor() + accessors := []*structs.VaultAccessor{va, va2} + + // Insert the accessors + if err := fsm.State().UpsertVaultAccessor(1000, accessors); err != nil { + t.Fatalf("bad: %v", err) + } + + req := structs.VaultAccessorsRequest{ + Accessors: accessors, + } + buf, err := structs.Encode(structs.VaultAccessorDegisterRequestType, req) + if err != nil { + t.Fatalf("err: %v", err) + } + + resp := fsm.Apply(makeLog(buf)) + if resp != nil { + t.Fatalf("resp: %v", resp) + } + + out1, err := fsm.State().VaultAccessor(va.Accessor) + if err != nil { + t.Fatalf("err: %v", err) + } + if out1 != nil { + t.Fatalf("not deleted!") + } + + tt := fsm.TimeTable() + index := tt.NearestIndex(time.Now().UTC()) + if index != 1 { + t.Fatalf("bad: %d", index) + } +} + func testSnapshotRestore(t *testing.T, fsm *nomadFSM) *nomadFSM { // Snapshot snap, err := fsm.Snapshot() diff --git a/nomad/leader.go b/nomad/leader.go index 
9424c4147..3307608e6 100644
--- a/nomad/leader.go
+++ b/nomad/leader.go
@@ -1,6 +1,7 @@
 package nomad
 
 import (
+	"context"
 	"errors"
 	"fmt"
 	"time"
@@ -132,6 +133,12 @@ func (s *Server) establishLeadership(stopCh chan struct{}) error {
 		return err
 	}
 
+	// Activate the vault client
+	s.vault.SetActive(true)
+	if err := s.restoreRevokingAccessors(); err != nil {
+		return err
+	}
+
 	// Enable the periodic dispatcher, since we are now the leader.
 	s.periodicDispatcher.SetEnabled(true)
 	s.periodicDispatcher.Start()
@@ -205,6 +212,57 @@ func (s *Server) restoreEvals() error {
 	return nil
 }
 
+// restoreRevokingAccessors is used to restore Vault accessors that should be
+// revoked.
+func (s *Server) restoreRevokingAccessors() error {
+	// An accessor should be revoked if its allocation or node is terminal
+	state := s.fsm.State()
+	iter, err := state.VaultAccessors()
+	if err != nil {
+		return fmt.Errorf("failed to get vault accessors: %v", err)
+	}
+
+	var revoke []*structs.VaultAccessor
+	for {
+		raw := iter.Next()
+		if raw == nil {
+			break
+		}
+
+		va := raw.(*structs.VaultAccessor)
+
+		// Check the allocation
+		alloc, err := state.AllocByID(va.AllocID)
+		if err != nil {
+			return fmt.Errorf("failed to lookup allocation %q: %v", va.AllocID, err)
+		}
+		if alloc == nil || alloc.Terminated() {
+			// No longer running and should be revoked
+			revoke = append(revoke, va)
+			continue
+		}
+
+		// Check the node
+		node, err := state.NodeByID(va.NodeID)
+		if err != nil {
+			return fmt.Errorf("failed to lookup node %q: %v", va.NodeID, err)
+		}
+		if node == nil || node.TerminalStatus() {
+			// Node is terminal so any accessor from it should be revoked
+			revoke = append(revoke, va)
+			continue
+		}
+	}
+
+	if len(revoke) != 0 {
+		if err := s.vault.RevokeTokens(context.Background(), revoke, true); err != nil {
+			return fmt.Errorf("failed to revoke tokens: %v", err)
+		}
+	}
+
+	return nil
+}
+
 // restorePeriodicDispatcher is used to restore all periodic jobs into the
 // periodic dispatcher. It also determines if a periodic job should have been
 // created during the leadership transition and force runs them. The periodic
@@ -409,6 +467,9 @@ func (s *Server) revokeLeadership() error {
 	// Disable the periodic dispatcher, since it is only useful as a leader
 	s.periodicDispatcher.SetEnabled(false)
 
+	// Disable the Vault client as it is only useful as a leader.
+	s.vault.SetActive(false)
+
 	// Clear the heartbeat timers on either shutdown or step down,
 	// since we are no longer responsible for TTL expirations.
 	if err := s.clearAllHeartbeatTimers(); err != nil {
diff --git a/nomad/leader_test.go b/nomad/leader_test.go
index b16f714a5..71b4e7878 100644
--- a/nomad/leader_test.go
+++ b/nomad/leader_test.go
@@ -544,3 +544,31 @@ func TestLeader_ReapDuplicateEval(t *testing.T) {
 		t.Fatalf("err: %v", err)
 	})
 }
+
+func TestLeader_RestoreVaultAccessors(t *testing.T) {
+	s1 := testServer(t, func(c *Config) {
+		c.NumSchedulers = 0
+	})
+	defer s1.Shutdown()
+	testutil.WaitForLeader(t, s1.RPC)
+
+	// Insert a vault accessor that should be revoked
+	state := s1.fsm.State()
+	va := mock.VaultAccessor()
+	if err := state.UpsertVaultAccessor(100, []*structs.VaultAccessor{va}); err != nil {
+		t.Fatalf("bad: %v", err)
+	}
+
+	// Swap the Vault client
+	tvc := &TestVaultClient{}
+	s1.vault = tvc
+
+	// Do a restore
+	if err := s1.restoreRevokingAccessors(); err != nil {
+		t.Fatalf("Failed to restore: %v", err)
+	}
+
+	if len(tvc.RevokedTokens) != 1 || tvc.RevokedTokens[0].Accessor != va.Accessor {
+		t.Fatalf("Bad revoked accessors: %v", tvc.RevokedTokens)
+	}
+}
diff --git a/nomad/node_endpoint.go b/nomad/node_endpoint.go
index abed2632a..82a0c7560 100644
--- a/nomad/node_endpoint.go
+++ b/nomad/node_endpoint.go
@@ -11,6 +11,7 @@ import (
 	"github.com/armon/go-metrics"
 	"github.com/hashicorp/go-memdb"
+	"github.com/hashicorp/go-multierror"
 	"github.com/hashicorp/nomad/nomad/state"
 	"github.com/hashicorp/nomad/nomad/structs"
 	"github.com/hashicorp/nomad/nomad/watch"
@@ -215,7 +216,7 @@ func (n *Node) constructNodeServerInfoResponse(snap *state.StateSnapshot, reply
 	return nil
 }
 
-// Deregister is used to remove a client from the client. If a client should
+// Deregister is used to remove a client from the cluster. If a client should
 // just be made unavailable for scheduling, a status update is preferred.
func (n *Node) Deregister(args *structs.NodeDeregisterRequest, reply *structs.NodeUpdateResponse) error { if done, err := n.srv.forward("Node.Deregister", args, args, reply); done { @@ -245,6 +246,20 @@ func (n *Node) Deregister(args *structs.NodeDeregisterRequest, reply *structs.No return err } + // Determine if there are any Vault accessors on the node + accessors, err := n.srv.State().VaultAccessorsByNode(args.NodeID) + if err != nil { + n.srv.logger.Printf("[ERR] nomad.client: looking up accessors for node %q failed: %v", args.NodeID, err) + return err + } + + if len(accessors) != 0 { + if err := n.srv.vault.RevokeTokens(context.Background(), accessors, true); err != nil { + n.srv.logger.Printf("[ERR] nomad.client: revoking accessors for node %q failed: %v", args.NodeID, err) + return err + } + } + // Setup the reply reply.EvalIDs = evalIDs reply.EvalCreateIndex = evalIndex @@ -311,7 +326,22 @@ func (n *Node) UpdateStatus(args *structs.NodeUpdateStatusRequest, reply *struct } // Check if we need to setup a heartbeat - if args.Status != structs.NodeStatusDown { + switch args.Status { + case structs.NodeStatusDown: + // Determine if there are any Vault accessors on the node + accessors, err := n.srv.State().VaultAccessorsByNode(args.NodeID) + if err != nil { + n.srv.logger.Printf("[ERR] nomad.client: looking up accessors for node %q failed: %v", args.NodeID, err) + return err + } + + if len(accessors) != 0 { + if err := n.srv.vault.RevokeTokens(context.Background(), accessors, true); err != nil { + n.srv.logger.Printf("[ERR] nomad.client: revoking accessors for node %q failed: %v", args.NodeID, err) + return err + } + } + default: ttl, err := n.srv.resetHeartbeatTimer(args.NodeID) if err != nil { n.srv.logger.Printf("[ERR] nomad.client: heartbeat reset failed: %v", err) @@ -686,13 +716,39 @@ func (n *Node) batchUpdate(future *batchFuture, updates []*structs.Allocation) { } // Commit this update via Raft + var mErr multierror.Error _, index, err := n.srv.raftApply(structs.AllocClientUpdateRequestType, batch) if err != nil { n.srv.logger.Printf("[ERR] nomad.client: alloc update failed: %v", err) + mErr.Errors = append(mErr.Errors, err) + } + + var revoke []*structs.VaultAccessor + for _, alloc := range updates { + // Skip any allocation that isn't dead on the client + if !alloc.Terminated() { + continue + } + + // Determine if there are any Vault accessors for the allocation + accessors, err := n.srv.State().VaultAccessorsByAlloc(alloc.ID) + if err != nil { + n.srv.logger.Printf("[ERR] nomad.client: looking up accessors for alloc %q failed: %v", alloc.ID, err) + mErr.Errors = append(mErr.Errors, err) + } + + revoke = append(revoke, accessors...) 
+ } + + if len(revoke) != 0 { + if err := n.srv.vault.RevokeTokens(context.Background(), revoke, true); err != nil { + n.srv.logger.Printf("[ERR] nomad.client: batched accessor revocation failed: %v", err) + mErr.Errors = append(mErr.Errors, err) + } } // Respond to the future - future.Respond(index, err) + future.Respond(index, mErr.ErrorOrNil()) } // List is used to list the available nodes @@ -1011,10 +1067,6 @@ func (n *Node) DeriveVaultToken(args *structs.DeriveVaultTokenRequest, // Wait for everything to complete or for an error err = g.Wait() - if err != nil { - // TODO Revoke any created token - return err - } // Commit to Raft before returning any of the tokens accessors := make([]*structs.VaultAccessor, 0, len(results)) @@ -1037,7 +1089,17 @@ func (n *Node) DeriveVaultToken(args *structs.DeriveVaultTokenRequest, accessors = append(accessors, accessor) } - req := structs.VaultAccessorRegisterRequest{Accessors: accessors} + // If there was an error revoke the created tokens + if err != nil { + var mErr multierror.Error + mErr.Errors = append(mErr.Errors, err) + if err := n.srv.vault.RevokeTokens(context.Background(), accessors, false); err != nil { + mErr.Errors = append(mErr.Errors, err) + } + return mErr.ErrorOrNil() + } + + req := structs.VaultAccessorsRequest{Accessors: accessors} _, index, err := n.srv.raftApply(structs.VaultAccessorRegisterRequestType, &req) if err != nil { n.srv.logger.Printf("[ERR] nomad.client: Register Vault accessors failed: %v", err) diff --git a/nomad/node_endpoint_test.go b/nomad/node_endpoint_test.go index 93531f341..099bd2ae2 100644 --- a/nomad/node_endpoint_test.go +++ b/nomad/node_endpoint_test.go @@ -170,6 +170,65 @@ func TestClientEndpoint_Deregister(t *testing.T) { } } +func TestClientEndpoint_Deregister_Vault(t *testing.T) { + s1 := testServer(t, nil) + defer s1.Shutdown() + codec := rpcClient(t, s1) + testutil.WaitForLeader(t, s1.RPC) + + // Create the register request + node := mock.Node() + reg := &structs.NodeRegisterRequest{ + Node: node, + WriteRequest: structs.WriteRequest{Region: "global"}, + } + + // Fetch the response + var resp structs.GenericResponse + if err := msgpackrpc.CallWithCodec(codec, "Node.Register", reg, &resp); err != nil { + t.Fatalf("err: %v", err) + } + + // Swap the servers Vault Client + tvc := &TestVaultClient{} + s1.vault = tvc + + // Put some Vault accessors in the state store for that node + state := s1.fsm.State() + va1 := mock.VaultAccessor() + va1.NodeID = node.ID + va2 := mock.VaultAccessor() + va2.NodeID = node.ID + state.UpsertVaultAccessor(100, []*structs.VaultAccessor{va1, va2}) + + // Deregister + dereg := &structs.NodeDeregisterRequest{ + NodeID: node.ID, + WriteRequest: structs.WriteRequest{Region: "global"}, + } + var resp2 structs.GenericResponse + if err := msgpackrpc.CallWithCodec(codec, "Node.Deregister", dereg, &resp2); err != nil { + t.Fatalf("err: %v", err) + } + if resp2.Index == 0 { + t.Fatalf("bad index: %d", resp2.Index) + } + + // Check for the node in the FSM + out, err := state.NodeByID(node.ID) + if err != nil { + t.Fatalf("err: %v", err) + } + if out != nil { + t.Fatalf("unexpected node") + } + + // Check that the endpoint revoked the tokens + if l := len(tvc.RevokedTokens); l != 2 { + t.Fatalf("Deregister revoked %d tokens; want 2", l) + } +} + func TestClientEndpoint_UpdateStatus(t *testing.T) { s1 := testServer(t, nil) defer s1.Shutdown() @@ -229,6 +288,63 @@ func TestClientEndpoint_UpdateStatus(t *testing.T) { } } +func TestClientEndpoint_UpdateStatus_Vault(t *testing.T) { + s1 
:= testServer(t, nil) + defer s1.Shutdown() + codec := rpcClient(t, s1) + testutil.WaitForLeader(t, s1.RPC) + + // Create the register request + node := mock.Node() + reg := &structs.NodeRegisterRequest{ + Node: node, + WriteRequest: structs.WriteRequest{Region: "global"}, + } + + // Fetch the response + var resp structs.NodeUpdateResponse + if err := msgpackrpc.CallWithCodec(codec, "Node.Register", reg, &resp); err != nil { + t.Fatalf("err: %v", err) + } + + // Check for heartbeat interval + ttl := resp.HeartbeatTTL + if ttl < s1.config.MinHeartbeatTTL || ttl > 2*s1.config.MinHeartbeatTTL { + t.Fatalf("bad: %#v", ttl) + } + + // Swap the servers Vault Client + tvc := &TestVaultClient{} + s1.vault = tvc + + // Put some Vault accessors in the state store for that node + state := s1.fsm.State() + va1 := mock.VaultAccessor() + va1.NodeID = node.ID + va2 := mock.VaultAccessor() + va2.NodeID = node.ID + state.UpsertVaultAccessor(100, []*structs.VaultAccessor{va1, va2}) + + // Update the status to be down + dereg := &structs.NodeUpdateStatusRequest{ + NodeID: node.ID, + Status: structs.NodeStatusDown, + WriteRequest: structs.WriteRequest{Region: "global"}, + } + var resp2 structs.NodeUpdateResponse + if err := msgpackrpc.CallWithCodec(codec, "Node.UpdateStatus", dereg, &resp2); err != nil { + t.Fatalf("err: %v", err) + } + if resp2.Index == 0 { + t.Fatalf("bad index: %d", resp2.Index) + } + + // Check that the endpoint revoked the tokens + if l := len(tvc.RevokedTokens); l != 2 { + t.Fatalf("Deregister revoked %d tokens; want 2", l) + } +} + func TestClientEndpoint_Register_GetEvals(t *testing.T) { s1 := testServer(t, nil) defer s1.Shutdown() @@ -1235,6 +1351,81 @@ func TestClientEndpoint_BatchUpdate(t *testing.T) { } } +func TestClientEndpoint_UpdateAlloc_Vault(t *testing.T) { + s1 := testServer(t, nil) + defer s1.Shutdown() + codec := rpcClient(t, s1) + testutil.WaitForLeader(t, s1.RPC) + + // Create the register request + node := mock.Node() + reg := &structs.NodeRegisterRequest{ + Node: node, + WriteRequest: structs.WriteRequest{Region: "global"}, + } + + // Fetch the response + var resp structs.GenericResponse + if err := msgpackrpc.CallWithCodec(codec, "Node.Register", reg, &resp); err != nil { + t.Fatalf("err: %v", err) + } + + // Swap the servers Vault Client + tvc := &TestVaultClient{} + s1.vault = tvc + + // Inject fake allocation and vault accessor + alloc := mock.Alloc() + alloc.NodeID = node.ID + state := s1.fsm.State() + state.UpsertJobSummary(99, mock.JobSummary(alloc.JobID)) + if err := state.UpsertAllocs(100, []*structs.Allocation{alloc}); err != nil { + t.Fatalf("err: %v", err) + } + + va := mock.VaultAccessor() + va.NodeID = node.ID + va.AllocID = alloc.ID + if err := state.UpsertVaultAccessor(101, []*structs.VaultAccessor{va}); err != nil { + t.Fatalf("err: %v", err) + } + + // Attempt update + clientAlloc := new(structs.Allocation) + *clientAlloc = *alloc + clientAlloc.ClientStatus = structs.AllocClientStatusFailed + + // Update the alloc + update := &structs.AllocUpdateRequest{ + Alloc: []*structs.Allocation{clientAlloc}, + WriteRequest: structs.WriteRequest{Region: "global"}, + } + var resp2 structs.NodeAllocsResponse + start := time.Now() + if err := msgpackrpc.CallWithCodec(codec, "Node.UpdateAlloc", update, &resp2); err != nil { + t.Fatalf("err: %v", err) + } + if resp2.Index == 0 { + t.Fatalf("Bad index: %d", resp2.Index) + } + if diff := time.Since(start); diff < batchUpdateInterval { + t.Fatalf("too fast: %v", diff) + } + + // Lookup the alloc + out, err := 
state.AllocByID(alloc.ID) + if err != nil { + t.Fatalf("err: %v", err) + } + if out.ClientStatus != structs.AllocClientStatusFailed { + t.Fatalf("Bad: %#v", out) + } + + if l := len(tvc.RevokedTokens); l != 1 { + t.Fatalf("Deregister revoked %d tokens; want 1", l) + } +} + func TestClientEndpoint_CreateNodeEvals(t *testing.T) { s1 := testServer(t, nil) defer s1.Shutdown() diff --git a/nomad/server.go b/nomad/server.go index 49d0037d9..8d1851613 100644 --- a/nomad/server.go +++ b/nomad/server.go @@ -567,7 +567,7 @@ func (s *Server) setupConsulSyncer() error { // setupVaultClient is used to set up the Vault API client. func (s *Server) setupVaultClient() error { - v, err := NewVaultClient(s.config.VaultConfig, s.logger) + v, err := NewVaultClient(s.config.VaultConfig, s.logger, s.purgeVaultAccessors) if err != nil { return err } diff --git a/nomad/state/state_store.go b/nomad/state/state_store.go index cee5935f3..9ead06192 100644 --- a/nomad/state/state_store.go +++ b/nomad/state/state_store.go @@ -1136,24 +1136,19 @@ func (s *StateStore) UpsertVaultAccessor(index uint64, accessors []*structs.Vaul return nil } -// DeleteVaultAccessor is used to delete a Vault Accessor -func (s *StateStore) DeleteVaultAccessor(index uint64, accessor string) error { +// DeleteVaultAccessors is used to delete a set of Vault Accessors +func (s *StateStore) DeleteVaultAccessors(index uint64, accessors []*structs.VaultAccessor) error { txn := s.db.Txn(true) defer txn.Abort() // Lookup the accessor - existing, err := txn.First("vault_accessors", "id", accessor) - if err != nil { - return fmt.Errorf("accessor lookup failed: %v", err) - } - if existing == nil { - return fmt.Errorf("vault_accessor not found") + for _, accessor := range accessors { + // Delete the accessor + if err := txn.Delete("vault_accessors", accessor); err != nil { + return fmt.Errorf("accessor delete failed: %v", err) + } } - // Delete the accessor - if err := txn.Delete("vault_accessors", existing); err != nil { - return fmt.Errorf("accessor delete failed: %v", err) - } if err := txn.Insert("index", &IndexEntry{"vault_accessors", index}); err != nil { return fmt.Errorf("index update failed: %v", err) } diff --git a/nomad/state/state_store_test.go b/nomad/state/state_store_test.go index f272b9fb1..ac008d15c 100644 --- a/nomad/state/state_store_test.go +++ b/nomad/state/state_store_test.go @@ -2894,27 +2894,35 @@ func TestStateStore_UpsertVaultAccessors(t *testing.T) { } } -func TestStateStore_DeleteVaultAccessor(t *testing.T) { +func TestStateStore_DeleteVaultAccessors(t *testing.T) { state := testStateStore(t) - accessor := mock.VaultAccessor() + a1 := mock.VaultAccessor() + a2 := mock.VaultAccessor() + accessors := []*structs.VaultAccessor{a1, a2} - err := state.UpsertVaultAccessor(1000, []*structs.VaultAccessor{accessor}) + err := state.UpsertVaultAccessor(1000, accessors) if err != nil { t.Fatalf("err: %v", err) } - err = state.DeleteVaultAccessor(1001, accessor.Accessor) + err = state.DeleteVaultAccessors(1001, accessors) if err != nil { t.Fatalf("err: %v", err) } - out, err := state.VaultAccessor(accessor.Accessor) + out, err := state.VaultAccessor(a1.Accessor) if err != nil { t.Fatalf("err: %v", err) } - if out != nil { - t.Fatalf("bad: %#v %#v", accessor, out) + t.Fatalf("bad: %#v %#v", a1, out) + } + out, err = state.VaultAccessor(a2.Accessor) + if err != nil { + t.Fatalf("err: %v", err) + } + if out != nil { + t.Fatalf("bad: %#v %#v", a2, out) } index, err := state.Index("vault_accessors") diff --git a/nomad/structs/structs.go 
b/nomad/structs/structs.go index 058ae16d3..54472d0a0 100644 --- a/nomad/structs/structs.go +++ b/nomad/structs/structs.go @@ -48,6 +48,7 @@ const ( AllocClientUpdateRequestType ReconcileJobSummariesRequestType VaultAccessorRegisterRequestType + VaultAccessorDegisterRequestType ) const ( @@ -365,8 +366,8 @@ type DeriveVaultTokenRequest struct { QueryOptions } -// VaultAccessorRegisterRequest is used to register a set of Vault accessors -type VaultAccessorRegisterRequest struct { +// VaultAccessorsRequest is used to operate on a set of Vault accessors +type VaultAccessorsRequest struct { Accessors []*VaultAccessor } diff --git a/nomad/vault.go b/nomad/vault.go index 1866505b4..c75bb9483 100644 --- a/nomad/vault.go +++ b/nomad/vault.go @@ -7,13 +7,17 @@ import ( "log" "math/rand" "sync" + "sync/atomic" "time" + "gopkg.in/tomb.v2" + "github.com/hashicorp/nomad/nomad/structs" "github.com/hashicorp/nomad/nomad/structs/config" vapi "github.com/hashicorp/vault/api" "github.com/mitchellh/mapstructure" + "golang.org/x/sync/errgroup" "golang.org/x/time/rate" ) @@ -33,10 +37,25 @@ const ( // requestRateLimit is the maximum number of requests per second Nomad will // make against Vault requestRateLimit rate.Limit = 500.0 + + // maxParallelRevokes is the maximum number of parallel Vault + // token revocation requests + maxParallelRevokes = 64 + + // vaultRevocationIntv is the interval at which Vault tokens that failed + // initial revocation are retried + vaultRevocationIntv = 5 * time.Minute ) // VaultClient is the Servers interface for interfacing with Vault type VaultClient interface { + // SetActive activates or de-activates the Vault client. When active, token + // creation/lookup/revocation operation are allowed. + SetActive(active bool) + + // SetConfig updates the config used by the Vault client + SetConfig(config *config.VaultConfig) error + // CreateToken takes an allocation and task and returns an appropriate Vault // Secret CreateToken(ctx context.Context, a *structs.Allocation, task string) (*vapi.Secret, error) @@ -44,10 +63,18 @@ type VaultClient interface { // LookupToken takes a token string and returns its capabilities. LookupToken(ctx context.Context, token string) (*vapi.Secret, error) - // Stop is used to stop token renewal. + // RevokeTokens takes a set of tokens accessor and revokes the tokens + RevokeTokens(ctx context.Context, accessors []*structs.VaultAccessor, committed bool) error + + // Stop is used to stop token renewal Stop() } +// PurgeVaultAccessor is called to remove VaultAccessors from the system. If +// the function returns an error, the token will still be tracked and revocation +// will retry till there is a success +type PurgeVaultAccessor func(accessors []*structs.VaultAccessor) error + // tokenData holds the relevant information about the Vault token passed to the // client. type tokenData struct { @@ -76,21 +103,27 @@ type vaultClient struct { // config is the user passed Vault config config *config.VaultConfig - // renewalRunning marks whether the renewal goroutine is running - renewalRunning bool + // connEstablished marks whether we have an established connection to Vault. 
+ // It should be accessed using a helper and updated atomically + connEstablished int32 - // establishingConn marks whether we are trying to establishe a connection to Vault - establishingConn bool - - // connEstablished marks whether we have an established connection to Vault - connEstablished bool + // token is the raw token used by the client + token string // tokenData is the data of the passed Vault token - token *tokenData + tokenData *tokenData - // enabled indicates whether the vaultClient is enabled. If it is not the - // token lookup and create methods will return errors. - enabled bool + // revoking tracks the VaultAccessors that must be revoked + revoking map[*structs.VaultAccessor]time.Time + purgeFn PurgeVaultAccessor + revLock sync.Mutex + + // active indicates whether the vaultClient is active. It should be + // accessed using a helper and updated atomically + active int32 + + // running indicates whether the vault client is started. + running bool // childTTL is the TTL for child tokens. childTTL string @@ -98,15 +131,17 @@ type vaultClient struct { // lastRenewed is the time the token was last renewed lastRenewed time.Time - shutdownCh chan struct{} - l sync.Mutex - logger *log.Logger + tomb *tomb.Tomb + logger *log.Logger + + // l is used to lock the configuration aspects of the client such that + // multiple callers can't cause conflicting config updates + l sync.Mutex } // NewVaultClient returns a Vault client from the given config. If the client -// couldn't be made an error is returned. If an error is not returned, Shutdown -// is expected to be called to clean up any created goroutine -func NewVaultClient(c *config.VaultConfig, logger *log.Logger) (*vaultClient, error) { +// couldn't be made an error is returned. +func NewVaultClient(c *config.VaultConfig, logger *log.Logger, purgeFn PurgeVaultAccessor) (*vaultClient, error) { if c == nil { return nil, fmt.Errorf("must pass valid VaultConfig") } @@ -116,69 +151,146 @@ func NewVaultClient(c *config.VaultConfig, logger *log.Logger) (*vaultClient, er } v := &vaultClient{ - enabled: c.Enabled, - config: c, - logger: logger, - limiter: rate.NewLimiter(requestRateLimit, int(requestRateLimit)), + config: c, + logger: logger, + limiter: rate.NewLimiter(requestRateLimit, int(requestRateLimit)), + revoking: make(map[*structs.VaultAccessor]time.Time), + purgeFn: purgeFn, + tomb: &tomb.Tomb{}, } - // If vault is not enabled do not configure an API client or start any token - // renewal. - if !v.enabled { - return v, nil + if v.config.Enabled { + if err := v.buildClient(); err != nil { + return nil, err + } + + // Launch the required goroutines + v.tomb.Go(wrapNilError(v.establishConnection)) + v.tomb.Go(wrapNilError(v.revokeDaemon)) + + v.running = true } + return v, nil +} + +func (v *vaultClient) Stop() { + v.l.Lock() + running := v.running + v.running = false + v.l.Unlock() + + if running { + v.tomb.Kill(nil) + v.tomb.Wait() + v.flush() + } +} + +// SetActive activates or de-activates the Vault client. When active, token +// creation/lookup/revocation operation are allowed. 
All queued revocations are
+// cancelled when set inactive as it is assumed another instance is taking over
+func (v *vaultClient) SetActive(active bool) {
+	if active {
+		atomic.StoreInt32(&v.active, 1)
+	} else {
+		atomic.StoreInt32(&v.active, 0)
+	}
+}
+
+// flush is used to reset the state of the vault client
+func (v *vaultClient) flush() {
+	v.l.Lock()
+	defer v.l.Unlock()
+
+	v.client = nil
+	v.auth = nil
+	v.connEstablished = 0
+	v.token = ""
+	v.tokenData = nil
+	v.revoking = make(map[*structs.VaultAccessor]time.Time)
+	v.childTTL = ""
+	v.tomb = &tomb.Tomb{}
+}
+
+// SetConfig is used to update the Vault config being used. A temporary outage
+// may occur after calling as it re-establishes a connection to Vault
+func (v *vaultClient) SetConfig(config *config.VaultConfig) error {
+	if config == nil {
+		return fmt.Errorf("must pass valid VaultConfig")
+	}
+
+	v.l.Lock()
+	defer v.l.Unlock()
+
+	// Store the new config
+	v.config = config
+
+	if v.config.Enabled {
+		// Stop accepting any new request
+		atomic.StoreInt32(&v.connEstablished, 0)
+
+		// Kill any background routine and create a new tomb
+		v.tomb.Kill(nil)
+		v.tomb.Wait()
+		v.tomb = &tomb.Tomb{}
+
+		// Rebuild the client
+		if err := v.buildClient(); err != nil {
+			return err
+		}
+
+		// Launch the required goroutines
+		v.tomb.Go(wrapNilError(v.establishConnection))
+		v.tomb.Go(wrapNilError(v.revokeDaemon))
+	}
+
+	return nil
+}
+
+// buildClient is used to build a Vault client based on the stored Vault config
+func (v *vaultClient) buildClient() error {
 	// Validate we have the required fields.
-	if c.Token == "" {
-		return nil, errors.New("Vault token must be set")
-	} else if c.Addr == "" {
-		return nil, errors.New("Vault address must be set")
+	if v.config.Token == "" {
+		return errors.New("Vault token must be set")
+	} else if v.config.Addr == "" {
+		return errors.New("Vault address must be set")
 	}
 
 	// Parse the TTL if it is set
-	if c.TaskTokenTTL != "" {
-		d, err := time.ParseDuration(c.TaskTokenTTL)
+	if v.config.TaskTokenTTL != "" {
+		d, err := time.ParseDuration(v.config.TaskTokenTTL)
 		if err != nil {
-			return nil, fmt.Errorf("failed to parse TaskTokenTTL %q: %v", c.TaskTokenTTL, err)
+			return fmt.Errorf("failed to parse TaskTokenTTL %q: %v", v.config.TaskTokenTTL, err)
 		}
 
 		if d.Nanoseconds() < minimumTokenTTL.Nanoseconds() {
-			return nil, fmt.Errorf("ChildTokenTTL is less than minimum allowed of %v", minimumTokenTTL)
+			return fmt.Errorf("ChildTokenTTL is less than minimum allowed of %v", minimumTokenTTL)
 		}
 
-		v.childTTL = c.TaskTokenTTL
+		v.childTTL = v.config.TaskTokenTTL
 	} else {
 		// Default the TaskTokenTTL
 		v.childTTL = defaultTokenTTL
 	}
 
 	// Get the Vault API configuration
-	apiConf, err := c.ApiConfig()
+	apiConf, err := v.config.ApiConfig()
 	if err != nil {
-		return nil, fmt.Errorf("Failed to create Vault API config: %v", err)
+		return fmt.Errorf("Failed to create Vault API config: %v", err)
 	}
 
 	// Create the Vault API client
 	client, err := vapi.NewClient(apiConf)
 	if err != nil {
 		v.logger.Printf("[ERR] vault: failed to create Vault client.
Not retrying: %v", err) - return nil, err + return err } // Set the token and store the client - client.SetToken(v.config.Token) + v.token = v.config.Token + client.SetToken(v.token) v.client = client v.auth = client.Auth().Token() - - // Prepare and launch the token renewal goroutine - v.shutdownCh = make(chan struct{}) - go v.establishConnection() - return v, nil -} - -// setLimit is used to update the rate limit -func (v *vaultClient) setLimit(l rate.Limit) { - v.limiter = rate.NewLimiter(l, int(l)) + return nil } // establishConnection is used to make first contact with Vault. This should be @@ -186,10 +298,6 @@ func (v *vaultClient) setLimit(l rate.Limit) { // is stopped or the connection is successfully made at which point the renew // loop is started. func (v *vaultClient) establishConnection() { - v.l.Lock() - v.establishingConn = true - v.l.Unlock() - // Create the retry timer and set initial duration to zero so it fires // immediately retryTimer := time.NewTimer(0) @@ -197,7 +305,7 @@ func (v *vaultClient) establishConnection() { OUTER: for { select { - case <-v.shutdownCh: + case <-v.tomb.Dying(): return case <-retryTimer.C: // Ensure the API is reachable @@ -212,10 +320,7 @@ OUTER: } } - v.l.Lock() - v.connEstablished = true - v.establishingConn = false - v.l.Unlock() + atomic.StoreInt32(&v.connEstablished, 1) // Retrieve our token, validate it and parse the lease duration if err := v.parseSelfToken(); err != nil { @@ -228,22 +333,18 @@ OUTER: v.client.SetWrappingLookupFunc(v.getWrappingFn()) // If we are given a non-root token, start renewing it - if v.token.Root { + if v.tokenData.Root { v.logger.Printf("[DEBUG] vault: not renewing token as it is root") } else { v.logger.Printf("[DEBUG] vault: token lease duration is %v", - time.Duration(v.token.CreationTTL)*time.Second) - go v.renewalLoop() + time.Duration(v.tokenData.CreationTTL)*time.Second) + v.tomb.Go(wrapNilError(v.renewalLoop)) } } // renewalLoop runs the renew loop. This should only be called if we are given a // non-root token. func (v *vaultClient) renewalLoop() { - v.l.Lock() - v.renewalRunning = true - v.l.Unlock() - // Create the renewal timer and set initial duration to zero so it fires // immediately authRenewTimer := time.NewTimer(0) @@ -254,12 +355,12 @@ func (v *vaultClient) renewalLoop() { for { select { - case <-v.shutdownCh: + case <-v.tomb.Dying(): return case <-authRenewTimer.C: // Renew the token and determine the new expiration err := v.renew() - currentExpiration := v.lastRenewed.Add(time.Duration(v.token.CreationTTL) * time.Second) + currentExpiration := v.lastRenewed.Add(time.Duration(v.tokenData.CreationTTL) * time.Second) // Successfully renewed if err == nil { @@ -305,14 +406,8 @@ func (v *vaultClient) renewalLoop() { if maxBackoff < 0 { // We have failed to renew the token past its expiration. Stop // renewing with Vault. - v.l.Lock() - defer v.l.Unlock() - v.logger.Printf("[ERR] vault: failed to renew Vault token before lease expiration. Renew loop exiting") - if v.renewalRunning { - v.renewalRunning = false - close(v.shutdownCh) - } - + v.logger.Printf("[ERR] vault: failed to renew Vault token before lease expiration. Shutting down Vault client") + atomic.StoreInt32(&v.connEstablished, 0) return } else if backoff > maxBackoff.Seconds() { @@ -331,7 +426,7 @@ func (v *vaultClient) renewalLoop() { // returned. 
This method updates the lastRenewed time func (v *vaultClient) renew() error { // Attempt to renew the token - secret, err := v.auth.RenewSelf(v.token.CreationTTL) + secret, err := v.auth.RenewSelf(v.tokenData.CreationTTL) if err != nil { return err } @@ -351,8 +446,8 @@ func (v *vaultClient) renew() error { // getWrappingFn returns an appropriate wrapping function for Nomad Servers func (v *vaultClient) getWrappingFn() func(operation, path string) string { createPath := "auth/token/create" - if !v.token.Root { - createPath = fmt.Sprintf("auth/token/create/%s", v.token.Role) + if !v.tokenData.Root { + createPath = fmt.Sprintf("auth/token/create/%s", v.tokenData.Role) } return func(operation, path string) string { @@ -407,45 +502,38 @@ func (v *vaultClient) parseSelfToken() error { } data.Root = root - v.token = &data + v.tokenData = &data return nil } -// Stop stops any goroutine that may be running, either for establishing a Vault -// connection or token renewal. -func (v *vaultClient) Stop() { - // Nothing to do - if !v.enabled { - return - } - - v.l.Lock() - defer v.l.Unlock() - if !v.renewalRunning && !v.establishingConn { - return - } - - close(v.shutdownCh) - v.renewalRunning = false - v.establishingConn = false -} - // ConnectionEstablished returns whether a connection to Vault has been // established. func (v *vaultClient) ConnectionEstablished() bool { + return atomic.LoadInt32(&v.connEstablished) == 1 +} + +func (v *vaultClient) Enabled() bool { v.l.Lock() defer v.l.Unlock() - return v.connEstablished + return v.config.Enabled +} + +// +func (v *vaultClient) Active() bool { + return atomic.LoadInt32(&v.active) == 1 } // CreateToken takes the allocation and task and returns an appropriate Vault // token. The call is rate limited and may be canceled with the passed policy func (v *vaultClient) CreateToken(ctx context.Context, a *structs.Allocation, task string) (*vapi.Secret, error) { - // Nothing to do - if !v.enabled { + if !v.Enabled() { return nil, fmt.Errorf("Vault integration disabled") } + if !v.Active() { + return nil, fmt.Errorf("Vault client not active") + } + // Check if we have established a connection with Vault if !v.ConnectionEstablished() { return nil, fmt.Errorf("Connection to Vault has not been established. Retry") @@ -486,12 +574,12 @@ func (v *vaultClient) CreateToken(ctx context.Context, a *structs.Allocation, ta // token or a role based token var secret *vapi.Secret var err error - if v.token.Root { + if v.tokenData.Root { req.Period = v.childTTL secret, err = v.auth.Create(req) } else { // Make the token using the role - secret, err = v.auth.CreateWithRole(req, v.token.Role) + secret, err = v.auth.CreateWithRole(req, v.tokenData.Role) } return secret, err @@ -500,11 +588,14 @@ func (v *vaultClient) CreateToken(ctx context.Context, a *structs.Allocation, ta // LookupToken takes a Vault token and does a lookup against Vault. The call is // rate limited and may be canceled with passed context. func (v *vaultClient) LookupToken(ctx context.Context, token string) (*vapi.Secret, error) { - // Nothing to do - if !v.enabled { + if !v.Enabled() { return nil, fmt.Errorf("Vault integration disabled") } + if !v.Active() { + return nil, fmt.Errorf("Vault client not active") + } + // Check if we have established a connection with Vault if !v.ConnectionEstablished() { return nil, fmt.Errorf("Connection to Vault has not been established. 
Retry") @@ -531,3 +622,209 @@ func PoliciesFrom(s *vapi.Secret) ([]string, error) { return data.Policies, nil } + +// RevokeTokens revokes the passed set of accessors. If committed is set, the +// purge function passed to the client is called. If there is an error purging +// either because of Vault failures or because of the purge function, the +// revocation is retried until the tokens TTL. +func (v *vaultClient) RevokeTokens(ctx context.Context, accessors []*structs.VaultAccessor, committed bool) error { + if !v.Enabled() { + return nil + } + + if !v.Active() { + return fmt.Errorf("Vault client not active") + } + + // Check if we have established a connection with Vault. If not just add it + // to the queue + if !v.ConnectionEstablished() { + // Only bother tracking it for later revocation if the accessor was + // committed + if committed { + v.storeForRevocation(accessors) + } + + return nil + } + + // Attempt to revoke immediately and if it fails, add it to the revoke queue + err := v.parallelRevoke(ctx, accessors) + if !committed { + // If it is uncommitted, it is a best effort revoke as it will shortly + // TTL within the cubbyhole and has not been leaked to any outside + // system + return nil + } + + if err != nil { + v.logger.Printf("[WARN] vault: failed to revoke tokens. Will reattempt til TTL: %v", err) + v.storeForRevocation(accessors) + return nil + } + + if err := v.purgeFn(accessors); err != nil { + v.logger.Printf("[ERR] vault: failed to purge Vault accessors: %v", err) + v.storeForRevocation(accessors) + return nil + } + + return nil +} + +// storeForRevocation stores the passed set of accessors for revocation. It +// captrues their effective TTL by storing their create TTL plus the current +// time. +func (v *vaultClient) storeForRevocation(accessors []*structs.VaultAccessor) { + v.revLock.Lock() + now := time.Now() + for _, a := range accessors { + v.revoking[a] = now.Add(time.Duration(a.CreationTTL) * time.Second) + } + v.revLock.Unlock() +} + +// parallelRevoke revokes the passed VaultAccessors in parallel. +func (v *vaultClient) parallelRevoke(ctx context.Context, accessors []*structs.VaultAccessor) error { + if !v.Enabled() { + return fmt.Errorf("Vault integration disabled") + } + + if !v.Active() { + return fmt.Errorf("Vault client not active") + } + + // Check if we have established a connection with Vault + if !v.ConnectionEstablished() { + return fmt.Errorf("Connection to Vault has not been established. 
Retry") + } + + g, pCtx := errgroup.WithContext(ctx) + + // Cap the handlers + handlers := len(accessors) + if handlers > maxParallelRevokes { + handlers = maxParallelRevokes + } + + // Create the Vault Tokens + input := make(chan *structs.VaultAccessor, handlers) + for i := 0; i < handlers; i++ { + g.Go(func() error { + for { + select { + case va, ok := <-input: + if !ok { + return nil + } + + if err := v.auth.RevokeAccessor(va.Accessor); err != nil { + return fmt.Errorf("failed to revoke token (alloc: %q, node: %q, task: %q)", va.AllocID, va.NodeID, va.Task) + } + case <-pCtx.Done(): + return nil + } + } + }) + } + + // Send the input + go func() { + defer close(input) + for _, va := range accessors { + select { + case <-pCtx.Done(): + return + case input <- va: + } + } + + }() + + // Wait for everything to complete + return g.Wait() +} + +// revokeDaemon should be called in a goroutine and is used to periodically +// revoke Vault accessors that failed the original revocation +func (v *vaultClient) revokeDaemon() { + ticker := time.NewTicker(vaultRevocationIntv) + defer ticker.Stop() + + for { + select { + case <-v.tomb.Dying(): + return + case now := <-ticker.C: + if !v.ConnectionEstablished() { + continue + } + + v.revLock.Lock() + + // Fast path + if len(v.revoking) == 0 { + v.revLock.Unlock() + continue + } + + // Build the list of allocations that need to revoked while pruning any TTL'd checks + revoking := make([]*structs.VaultAccessor, 0, len(v.revoking)) + for va, ttl := range v.revoking { + if now.After(ttl) { + delete(v.revoking, va) + } else { + revoking = append(revoking, va) + } + } + + if err := v.parallelRevoke(context.Background(), revoking); err != nil { + v.logger.Printf("[WARN] vault: background token revocation errored: %v", err) + v.revLock.Unlock() + continue + } + + // Unlock before a potentially expensive operation + v.revLock.Unlock() + + // Call the passed in token revocation function + if err := v.purgeFn(revoking); err != nil { + // Can continue since revocation is idempotent + v.logger.Printf("[ERR] vault: token revocation errored: %v", err) + continue + } + + // Can delete from the tracked list now that we have purged + v.revLock.Lock() + for _, va := range revoking { + delete(v.revoking, va) + } + v.revLock.Unlock() + } + } +} + +// purgeVaultAccessors creates a Raft transaction to remove the passed Vault +// Accessors +func (s *Server) purgeVaultAccessors(accessors []*structs.VaultAccessor) error { + // Commit this update via Raft + req := structs.VaultAccessorsRequest{Accessors: accessors} + _, _, err := s.raftApply(structs.VaultAccessorDegisterRequestType, req) + return err +} + +// wrapNilError is a helper that returns a wrapped function that returns a nil +// error +func wrapNilError(f func()) func() error { + return func() error { + f() + return nil + } +} + +// setLimit is used to update the rate limit +func (v *vaultClient) setLimit(l rate.Limit) { + v.l.Lock() + defer v.l.Unlock() + v.limiter = rate.NewLimiter(l, int(l)) +} diff --git a/nomad/vault_test.go b/nomad/vault_test.go index 4d12fd6a7..2da6ca5c9 100644 --- a/nomad/vault_test.go +++ b/nomad/vault_test.go @@ -31,24 +31,19 @@ func TestVaultClient_BadConfig(t *testing.T) { logger := log.New(os.Stderr, "", log.LstdFlags) // Should be no error since Vault is not enabled - client, err := NewVaultClient(conf, logger) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - defer client.Stop() - - if client.ConnectionEstablished() { - t.Fatalf("bad") + _, err := NewVaultClient(nil, logger, 
nil) + if err == nil || !strings.Contains(err.Error(), "valid") { + t.Fatalf("expected config error: %v", err) } conf.Enabled = true - _, err = NewVaultClient(conf, logger) + _, err = NewVaultClient(conf, logger, nil) if err == nil || !strings.Contains(err.Error(), "token must be set") { t.Fatalf("Expected token unset error: %v", err) } conf.Token = "123" - _, err = NewVaultClient(conf, logger) + _, err = NewVaultClient(conf, logger, nil) if err == nil || !strings.Contains(err.Error(), "address must be set") { t.Fatalf("Expected address unset error: %v", err) } @@ -62,7 +57,7 @@ func TestVaultClient_EstablishConnection(t *testing.T) { logger := log.New(os.Stderr, "", log.LstdFlags) v.Config.ConnectionRetryIntv = 100 * time.Millisecond - client, err := NewVaultClient(v.Config, logger) + client, err := NewVaultClient(v.Config, logger, nil) if err != nil { t.Fatalf("failed to build vault client: %v", err) } @@ -79,11 +74,69 @@ func TestVaultClient_EstablishConnection(t *testing.T) { v.Start() waitForConnection(client, t) +} - // Ensure that since we are using a root token that we haven started the - // renewal loop. - if client.renewalRunning { - t.Fatalf("No renewal loop should be running") +func TestVaultClient_SetActive(t *testing.T) { + v := testutil.NewTestVault(t).Start() + defer v.Stop() + + logger := log.New(os.Stderr, "", log.LstdFlags) + client, err := NewVaultClient(v.Config, logger, nil) + if err != nil { + t.Fatalf("failed to build vault client: %v", err) + } + defer client.Stop() + + waitForConnection(client, t) + + // Do a lookup and expect an error about not being active + _, err = client.LookupToken(context.Background(), "123") + if err == nil || !strings.Contains(err.Error(), "not active") { + t.Fatalf("Expected not-active error: %v", err) + } + + client.SetActive(true) + + // Do a lookup of ourselves + _, err = client.LookupToken(context.Background(), v.RootToken) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } +} + +// Test that we can update the config and things keep working +func TestVaultClient_SetConfig(t *testing.T) { + v := testutil.NewTestVault(t).Start() + defer v.Stop() + + v2 := testutil.NewTestVault(t).Start() + defer v2.Stop() + + // Set the configs token in a new test role + v2.Config.Token = testVaultRoleAndToken(v2, t, 20) + + logger := log.New(os.Stderr, "", log.LstdFlags) + client, err := NewVaultClient(v.Config, logger, nil) + if err != nil { + t.Fatalf("failed to build vault client: %v", err) + } + defer client.Stop() + + waitForConnection(client, t) + + if client.tokenData == nil || len(client.tokenData.Policies) != 1 { + t.Fatalf("unexpected token: %v", client.tokenData) + } + + // Update the config + if err := client.SetConfig(v2.Config); err != nil { + t.Fatalf("SetConfig failed: %v", err) + } + + waitForConnection(client, t) + + if client.tokenData == nil || len(client.tokenData.Policies) != 2 { + t.Fatalf("unexpected token: %v", client.tokenData) } } @@ -128,7 +181,7 @@ func TestVaultClient_RenewalLoop(t *testing.T) { // Start the client logger := log.New(os.Stderr, "", log.LstdFlags) - client, err := NewVaultClient(v.Config, logger) + client, err := NewVaultClient(v.Config, logger, nil) if err != nil { t.Fatalf("failed to build vault client: %v", err) } @@ -177,29 +230,19 @@ func parseTTLFromLookup(s *vapi.Secret, t *testing.T) int64 { func TestVaultClient_LookupToken_Invalid(t *testing.T) { conf := &config.VaultConfig{ - Enabled: false, - } - - logger := log.New(os.Stderr, "", log.LstdFlags) - client, err := NewVaultClient(conf, 
logger) - if err != nil { - t.Fatalf("failed to build vault client: %v", err) - } - defer client.Stop() - - _, err = client.LookupToken(context.Background(), "foo") - if err == nil || !strings.Contains(err.Error(), "disabled") { - t.Fatalf("Expected error because Vault is disabled: %v", err) + Enabled: true, + Addr: "http://foobar:12345", + Token: structs.GenerateUUID(), } // Enable vault but use a bad address so it never establishes a conn - conf.Enabled = true - conf.Addr = "http://foobar:12345" - conf.Token = structs.GenerateUUID() - client, err = NewVaultClient(conf, logger) + logger := log.New(os.Stderr, "", log.LstdFlags) + client, err := NewVaultClient(conf, logger, nil) if err != nil { t.Fatalf("failed to build vault client: %v", err) } + client.SetActive(true) + defer client.Stop() _, err = client.LookupToken(context.Background(), "foo") if err == nil || !strings.Contains(err.Error(), "established") { @@ -207,23 +250,16 @@ func TestVaultClient_LookupToken_Invalid(t *testing.T) { } } -func waitForConnection(v *vaultClient, t *testing.T) { - testutil.WaitForResult(func() (bool, error) { - return v.ConnectionEstablished(), nil - }, func(err error) { - t.Fatalf("Connection not established") - }) -} - func TestVaultClient_LookupToken(t *testing.T) { v := testutil.NewTestVault(t).Start() defer v.Stop() logger := log.New(os.Stderr, "", log.LstdFlags) - client, err := NewVaultClient(v.Config, logger) + client, err := NewVaultClient(v.Config, logger, nil) if err != nil { t.Fatalf("failed to build vault client: %v", err) } + client.SetActive(true) defer client.Stop() waitForConnection(client, t) @@ -280,10 +316,11 @@ func TestVaultClient_LookupToken_RateLimit(t *testing.T) { defer v.Stop() logger := log.New(os.Stderr, "", log.LstdFlags) - client, err := NewVaultClient(v.Config, logger) + client, err := NewVaultClient(v.Config, logger, nil) if err != nil { t.Fatalf("failed to build vault client: %v", err) } + client.SetActive(true) defer client.Stop() client.setLimit(rate.Limit(1.0)) @@ -334,10 +371,11 @@ func TestVaultClient_CreateToken_Root(t *testing.T) { defer v.Stop() logger := log.New(os.Stderr, "", log.LstdFlags) - client, err := NewVaultClient(v.Config, logger) + client, err := NewVaultClient(v.Config, logger, nil) if err != nil { t.Fatalf("failed to build vault client: %v", err) } + client.SetActive(true) defer client.Stop() waitForConnection(client, t) @@ -380,10 +418,11 @@ func TestVaultClient_CreateToken_Role(t *testing.T) { //testVaultRoleAndToken(v, t, 5) // Start the client logger := log.New(os.Stderr, "", log.LstdFlags) - client, err := NewVaultClient(v.Config, logger) + client, err := NewVaultClient(v.Config, logger, nil) if err != nil { t.Fatalf("failed to build vault client: %v", err) } + client.SetActive(true) defer client.Stop() waitForConnection(client, t) @@ -416,3 +455,110 @@ func TestVaultClient_CreateToken_Role(t *testing.T) { t.Fatalf("Bad ttl: %v", s.WrapInfo.WrappedAccessor) } } + +func TestVaultClient_RevokeTokens_PreEstablishs(t *testing.T) { + v := testutil.NewTestVault(t) + logger := log.New(os.Stderr, "", log.LstdFlags) + client, err := NewVaultClient(v.Config, logger, nil) + if err != nil { + t.Fatalf("failed to build vault client: %v", err) + } + client.SetActive(true) + defer client.Stop() + + // Create some VaultAccessors + vas := []*structs.VaultAccessor{ + mock.VaultAccessor(), + mock.VaultAccessor(), + } + + if err := client.RevokeTokens(context.Background(), vas, false); err != nil { + t.Fatalf("RevokeTokens failed: %v", err) + } + + // Wasn't 
committed + if len(client.revoking) != 0 { + t.Fatalf("didn't add to revoke loop") + } + + if err := client.RevokeTokens(context.Background(), vas, true); err != nil { + t.Fatalf("RevokeTokens failed: %v", err) + } + + // Was committed + if len(client.revoking) != 2 { + t.Fatalf("didn't add to revoke loop") + } +} + +func TestVaultClient_RevokeTokens(t *testing.T) { + v := testutil.NewTestVault(t).Start() + defer v.Stop() + + purged := 0 + purge := func(accessors []*structs.VaultAccessor) error { + purged += len(accessors) + return nil + } + + logger := log.New(os.Stderr, "", log.LstdFlags) + client, err := NewVaultClient(v.Config, logger, purge) + if err != nil { + t.Fatalf("failed to build vault client: %v", err) + } + client.SetActive(true) + defer client.Stop() + + waitForConnection(client, t) + + // Create some vault tokens + auth := v.Client.Auth().Token() + req := vapi.TokenCreateRequest{ + Policies: []string{"default"}, + } + t1, err := auth.Create(&req) + if err != nil { + t.Fatalf("Failed to create vault token: %v", err) + } + if t1 == nil || t1.Auth == nil { + t.Fatalf("bad secret response: %+v", t1) + } + t2, err := auth.Create(&req) + if err != nil { + t.Fatalf("Failed to create vault token: %v", err) + } + if t2 == nil || t2.Auth == nil { + t.Fatalf("bad secret response: %+v", t2) + } + + // Create two VaultAccessors + vas := []*structs.VaultAccessor{ + &structs.VaultAccessor{Accessor: t1.Auth.Accessor}, + &structs.VaultAccessor{Accessor: t2.Auth.Accessor}, + } + + // Issue a token revocation + if err := client.RevokeTokens(context.Background(), vas, true); err != nil { + t.Fatalf("RevokeTokens failed: %v", err) + } + + // Lookup the token and make sure we get an error + if s, err := auth.Lookup(t1.Auth.ClientToken); err == nil { + t.Fatalf("Revoked token lookup didn't fail: %+v", s) + } + if s, err := auth.Lookup(t2.Auth.ClientToken); err == nil { + t.Fatalf("Revoked token lookup didn't fail: %+v", s) + } + + if purged != 2 { + t.Fatalf("Expected purged 2; got %d", purged) + } +} + +func waitForConnection(v *vaultClient, t *testing.T) { + testutil.WaitForResult(func() (bool, error) { + return v.ConnectionEstablished(), nil + }, func(err error) { + t.Fatalf("Connection not established") + }) +} diff --git a/nomad/vault_testing.go b/nomad/vault_testing.go index 73e5efc34..ebc163f9d 100644 --- a/nomad/vault_testing.go +++ b/nomad/vault_testing.go @@ -4,6 +4,7 @@ import ( "context" "github.com/hashicorp/nomad/nomad/structs" + "github.com/hashicorp/nomad/nomad/structs/config" vapi "github.com/hashicorp/vault/api" ) @@ -26,6 +27,8 @@ type TestVaultClient struct { // CreateTokenSecret maps a token to the Vault secret that will be returned // by the CreateToken call CreateTokenSecret map[string]map[string]*vapi.Secret + + RevokedTokens []*structs.VaultAccessor } func (v *TestVaultClient) LookupToken(ctx context.Context, token string) (*vapi.Secret, error) { @@ -126,4 +129,11 @@ func (v *TestVaultClient) SetCreateTokenSecret(allocID, task string, secret *vap v.CreateTokenSecret[allocID][task] = secret } -func (v *TestVaultClient) Stop() {} +func (v *TestVaultClient) RevokeTokens(ctx context.Context, accessors []*structs.VaultAccessor, committed bool) error { + v.RevokedTokens = append(v.RevokedTokens, accessors...) 
+ return nil +} + +func (v *TestVaultClient) Stop() {} +func (v *TestVaultClient) SetActive(enabled bool) {} +func (v *TestVaultClient) SetConfig(config *config.VaultConfig) error { return nil } diff --git a/testutil/vault.go b/testutil/vault.go index 1f449c73c..9b52bf118 100644 --- a/testutil/vault.go +++ b/testutil/vault.go @@ -66,7 +66,7 @@ func NewTestVault(t *testing.T) *TestVault { t: t, Addr: bind, HTTPAddr: http, - RootToken: root, + RootToken: token, Client: client, Config: &config.VaultConfig{ Enabled: true, diff --git a/vendor/gopkg.in/tomb.v2/LICENSE b/vendor/gopkg.in/tomb.v2/LICENSE new file mode 100644 index 000000000..a4249bb31 --- /dev/null +++ b/vendor/gopkg.in/tomb.v2/LICENSE @@ -0,0 +1,29 @@ +tomb - support for clean goroutine termination in Go. + +Copyright (c) 2010-2011 - Gustavo Niemeyer + +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, + this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + * Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/vendor/gopkg.in/tomb.v2/README.md b/vendor/gopkg.in/tomb.v2/README.md new file mode 100644 index 000000000..e7f282b5a --- /dev/null +++ b/vendor/gopkg.in/tomb.v2/README.md @@ -0,0 +1,4 @@ +Installation and usage +---------------------- + +See [gopkg.in/tomb.v2](https://gopkg.in/tomb.v2) for documentation and usage details. diff --git a/vendor/gopkg.in/tomb.v2/tomb.go b/vendor/gopkg.in/tomb.v2/tomb.go new file mode 100644 index 000000000..28bc552b2 --- /dev/null +++ b/vendor/gopkg.in/tomb.v2/tomb.go @@ -0,0 +1,223 @@ +// Copyright (c) 2011 - Gustavo Niemeyer +// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// * Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. 
+// * Neither the name of the copyright holder nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// The tomb package handles clean goroutine tracking and termination. +// +// The zero value of a Tomb is ready to handle the creation of a tracked +// goroutine via its Go method, and then any tracked goroutine may call +// the Go method again to create additional tracked goroutines at +// any point. +// +// If any of the tracked goroutines returns a non-nil error, or the +// Kill or Killf method is called by any goroutine in the system (tracked +// or not), the tomb Err is set, Alive is set to false, and the Dying +// channel is closed to flag that all tracked goroutines are supposed +// to willingly terminate as soon as possible. +// +// Once all tracked goroutines terminate, the Dead channel is closed, +// and Wait unblocks and returns the first non-nil error presented +// to the tomb via a result or an explicit Kill or Killf method call, +// or nil if there were no errors. +// +// It is okay to create further goroutines via the Go method while +// the tomb is in a dying state. The final dead state is only reached +// once all tracked goroutines terminate, at which point calling +// the Go method again will cause a runtime panic. +// +// Tracked functions and methods that are still running while the tomb +// is in dying state may choose to return ErrDying as their error value. +// This preserves the well established non-nil error convention, but is +// understood by the tomb as a clean termination. The Err and Wait +// methods will still return nil if all observed errors were either +// nil or ErrDying. +// +// For background and a detailed example, see the following blog post: +// +// http://blog.labix.org/2011/10/09/death-of-goroutines-under-control +// +package tomb + +import ( + "errors" + "fmt" + "sync" +) + +// A Tomb tracks the lifecycle of one or more goroutines as alive, +// dying or dead, and the reason for their death. +// +// See the package documentation for details. +type Tomb struct { + m sync.Mutex + alive int + dying chan struct{} + dead chan struct{} + reason error +} + +var ( + ErrStillAlive = errors.New("tomb: still alive") + ErrDying = errors.New("tomb: dying") +) + +func (t *Tomb) init() { + t.m.Lock() + if t.dead == nil { + t.dead = make(chan struct{}) + t.dying = make(chan struct{}) + t.reason = ErrStillAlive + } + t.m.Unlock() +} + +// Dead returns the channel that can be used to wait until +// all goroutines have finished running. 
+func (t *Tomb) Dead() <-chan struct{} { + t.init() + return t.dead +} + +// Dying returns the channel that can be used to wait until +// t.Kill is called. +func (t *Tomb) Dying() <-chan struct{} { + t.init() + return t.dying +} + +// Wait blocks until all goroutines have finished running, and +// then returns the reason for their death. +func (t *Tomb) Wait() error { + t.init() + <-t.dead + t.m.Lock() + reason := t.reason + t.m.Unlock() + return reason +} + +// Go runs f in a new goroutine and tracks its termination. +// +// If f returns a non-nil error, t.Kill is called with that +// error as the death reason parameter. +// +// It is f's responsibility to monitor the tomb and return +// appropriately once it is in a dying state. +// +// It is safe for the f function to call the Go method again +// to create additional tracked goroutines. Once all tracked +// goroutines return, the Dead channel is closed and the +// Wait method unblocks and returns the death reason. +// +// Calling the Go method after all tracked goroutines return +// causes a runtime panic. For that reason, calling the Go +// method a second time out of a tracked goroutine is unsafe. +func (t *Tomb) Go(f func() error) { + t.init() + t.m.Lock() + defer t.m.Unlock() + select { + case <-t.dead: + panic("tomb.Go called after all goroutines terminated") + default: + } + t.alive++ + go t.run(f) +} + +func (t *Tomb) run(f func() error) { + err := f() + t.m.Lock() + defer t.m.Unlock() + t.alive-- + if t.alive == 0 || err != nil { + t.kill(err) + if t.alive == 0 { + close(t.dead) + } + } +} + +// Kill puts the tomb in a dying state for the given reason, +// closes the Dying channel, and sets Alive to false. +// +// Althoguh Kill may be called multiple times, only the first +// non-nil error is recorded as the death reason. +// +// If reason is ErrDying, the previous reason isn't replaced +// even if nil. It's a runtime error to call Kill with ErrDying +// if t is not in a dying state. +func (t *Tomb) Kill(reason error) { + t.init() + t.m.Lock() + defer t.m.Unlock() + t.kill(reason) +} + +func (t *Tomb) kill(reason error) { + if reason == ErrStillAlive { + panic("tomb: Kill with ErrStillAlive") + } + if reason == ErrDying { + if t.reason == ErrStillAlive { + panic("tomb: Kill with ErrDying while still alive") + } + return + } + if t.reason == ErrStillAlive { + t.reason = reason + close(t.dying) + return + } + if t.reason == nil { + t.reason = reason + return + } +} + +// Killf calls the Kill method with an error built providing the received +// parameters to fmt.Errorf. The generated error is also returned. +func (t *Tomb) Killf(f string, a ...interface{}) error { + err := fmt.Errorf(f, a...) + t.Kill(err) + return err +} + +// Err returns the death reason, or ErrStillAlive if the tomb +// is not in a dying or dead state. +func (t *Tomb) Err() (reason error) { + t.init() + t.m.Lock() + reason = t.reason + t.m.Unlock() + return +} + +// Alive returns true if the tomb is not in a dying or dead state. 
+func (t *Tomb) Alive() bool { + return t.Err() == ErrStillAlive +} diff --git a/vendor/vendor.json b/vendor/vendor.json index 6e0dcba03..28d7c7c04 100644 --- a/vendor/vendor.json +++ b/vendor/vendor.json @@ -866,6 +866,12 @@ "path": "gopkg.in/tomb.v1", "revision": "dd632973f1e7218eb1089048e0798ec9ae7dceb8", "revisionTime": "2014-10-24T13:56:13Z" + }, + { + "checksumSHA1": "WiyCOMvfzRdymImAJ3ME6aoYUdM=", + "path": "gopkg.in/tomb.v2", + "revision": "14b3d72120e8d10ea6e6b7f87f7175734b1faab8", + "revisionTime": "2014-06-26T14:46:23Z" } ], "rootPath": "github.com/hashicorp/nomad"
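
Note on the new gopkg.in/tomb.v2 dependency: as its package documentation above describes, tomb tracks a set of goroutines as alive, dying, or dead and records the first non-nil error as the death reason. The sketch below is only an illustration of that pattern, under the assumption that the dependency was vendored to supervise background goroutines such as a token-revocation loop; it is not Nomad's actual vaultClient code, and the revoker, newRevoker, loop, and stop names are invented for the example.

```go
package main

import (
	"fmt"
	"time"

	"gopkg.in/tomb.v2"
)

// revoker is a toy background worker: it drains accessor names from a
// channel until its tomb starts dying.
type revoker struct {
	t       tomb.Tomb // zero value is ready to use
	pending chan string
}

func newRevoker() *revoker {
	r := &revoker{pending: make(chan string, 8)}
	// Go tracks the goroutine; a non-nil return becomes the death reason.
	r.t.Go(r.loop)
	return r
}

func (r *revoker) loop() error {
	for {
		select {
		case acc := <-r.pending:
			fmt.Println("revoking accessor:", acc)
		case <-r.t.Dying():
			// Kill was called; returning nil keeps Err() and Wait() nil.
			return nil
		}
	}
}

// stop puts the tomb into the dying state and waits for loop to exit.
func (r *revoker) stop() error {
	r.t.Kill(nil)
	return r.t.Wait()
}

func main() {
	r := newRevoker()
	r.pending <- "accessor-1"
	r.pending <- "accessor-2"
	time.Sleep(100 * time.Millisecond) // give loop a moment to drain
	if err := r.stop(); err != nil {
		fmt.Println("shutdown error:", err)
	}
}
```

Calling Kill(nil) closes the Dying channel without recording an error, and Wait blocks until every goroutine started via Go has returned, so stop gives a deterministic shutdown point that tests can rely on.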