diff --git a/nomad/fsm_test.go b/nomad/fsm_test.go index fe6f58813..805e365c4 100644 --- a/nomad/fsm_test.go +++ b/nomad/fsm_test.go @@ -770,6 +770,54 @@ func TestFSM_UpdateAllocFromClient(t *testing.T) { } } +func TestFSM_UpsertVaultAccessor(t *testing.T) { + fsm := testFSM(t) + fsm.blockedEvals.SetEnabled(true) + + va := mock.VaultAccessor() + va2 := mock.VaultAccessor() + req := structs.VaultAccessorRegisterRequest{ + Accessors: []*structs.VaultAccessor{va, va2}, + } + buf, err := structs.Encode(structs.VaultAccessorRegisterRequestType, req) + if err != nil { + t.Fatalf("err: %v", err) + } + + resp := fsm.Apply(makeLog(buf)) + if resp != nil { + t.Fatalf("resp: %v", resp) + } + + // Verify we are registered + out1, err := fsm.State().VaultAccessor(va.Accessor) + if err != nil { + t.Fatalf("err: %v", err) + } + if out1 == nil { + t.Fatalf("not found!") + } + if out1.CreateIndex != 1 { + t.Fatalf("bad index: %d", out1.CreateIndex) + } + out2, err := fsm.State().VaultAccessor(va2.Accessor) + if err != nil { + t.Fatalf("err: %v", err) + } + if out2 == nil { + t.Fatalf("not found!") + } + if out1.CreateIndex != 1 { + t.Fatalf("bad index: %d", out2.CreateIndex) + } + + tt := fsm.TimeTable() + index := tt.NearestIndex(time.Now().UTC()) + if index != 1 { + t.Fatalf("bad: %d", index) + } +} + func testSnapshotRestore(t *testing.T, fsm *nomadFSM) *nomadFSM { // Snapshot snap, err := fsm.Snapshot() diff --git a/nomad/node_endpoint.go b/nomad/node_endpoint.go index 4175dbfc6..da98a6916 100644 --- a/nomad/node_endpoint.go +++ b/nomad/node_endpoint.go @@ -936,6 +936,9 @@ func (n *Node) DeriveVaultToken(args *structs.DeriveVaultTokenRequest, if alloc.NodeID != args.NodeID { return fmt.Errorf("Allocation %q not running on Node %q", args.AllocID, args.NodeID) } + if alloc.TerminalStatus() { + return fmt.Errorf("Can't request Vault token for terminal allocation") + } // Check the policies policies := alloc.Job.VaultPolicies() @@ -950,7 +953,7 @@ func (n *Node) DeriveVaultToken(args *structs.DeriveVaultTokenRequest, var unneeded []string for _, task := range args.Tasks { taskVault := tg[task] - if len(taskVault.Policies) == 0 { + if taskVault == nil || len(taskVault.Policies) == 0 { unneeded = append(unneeded, task) } } diff --git a/nomad/node_endpoint_test.go b/nomad/node_endpoint_test.go index 23dabec2f..93531f341 100644 --- a/nomad/node_endpoint_test.go +++ b/nomad/node_endpoint_test.go @@ -11,6 +11,7 @@ import ( "github.com/hashicorp/nomad/nomad/mock" "github.com/hashicorp/nomad/nomad/structs" "github.com/hashicorp/nomad/testutil" + vapi "github.com/hashicorp/vault/api" ) func TestClientEndpoint_Register(t *testing.T) { @@ -1597,3 +1598,160 @@ func TestBatchFuture(t *testing.T) { t.Fatalf("bad: %d", bf.Index()) } } + +func TestClientEndpoint_DeriveVaultToken_Bad(t *testing.T) { + s1 := testServer(t, nil) + defer s1.Shutdown() + state := s1.fsm.State() + codec := rpcClient(t, s1) + testutil.WaitForLeader(t, s1.RPC) + + // Create the node + node := mock.Node() + if err := state.UpsertNode(2, node); err != nil { + t.Fatalf("err: %v", err) + } + + // Create an alloc + alloc := mock.Alloc() + task := alloc.Job.TaskGroups[0].Tasks[0] + tasks := []string{task.Name} + if err := state.UpsertAllocs(3, []*structs.Allocation{alloc}); err != nil { + t.Fatalf("err: %v", err) + } + + req := &structs.DeriveVaultTokenRequest{ + NodeID: node.ID, + SecretID: structs.GenerateUUID(), + AllocID: alloc.ID, + Tasks: tasks, + QueryOptions: structs.QueryOptions{ + Region: "global", + }, + } + + var resp structs.DeriveVaultTokenResponse + err := msgpackrpc.CallWithCodec(codec, "Node.DeriveVaultToken", req, &resp) + if err == nil || !strings.Contains(err.Error(), "SecretID mismatch") { + t.Fatalf("Expected SecretID mismatch: %v", err) + } + + // Put the correct SecretID + req.SecretID = node.SecretID + + // Now we should get an error about the allocation not running on the node + err = msgpackrpc.CallWithCodec(codec, "Node.DeriveVaultToken", req, &resp) + if err == nil || !strings.Contains(err.Error(), "not running on Node") { + t.Fatalf("Expected not running on node error: %v", err) + } + + // Update to be running on the node + alloc.NodeID = node.ID + if err := state.UpsertAllocs(4, []*structs.Allocation{alloc}); err != nil { + t.Fatalf("err: %v", err) + } + + // Now we should get an error about the job not needing any Vault secrets + err = msgpackrpc.CallWithCodec(codec, "Node.DeriveVaultToken", req, &resp) + if err == nil || !strings.Contains(err.Error(), "without defined Vault") { + t.Fatalf("Expected no policies error: %v", err) + } + + // Update to be terminal + alloc.DesiredStatus = structs.AllocDesiredStatusStop + if err := state.UpsertAllocs(5, []*structs.Allocation{alloc}); err != nil { + t.Fatalf("err: %v", err) + } + + // Now we should get an error about the job not needing any Vault secrets + err = msgpackrpc.CallWithCodec(codec, "Node.DeriveVaultToken", req, &resp) + if err == nil || !strings.Contains(err.Error(), "terminal") { + t.Fatalf("Expected terminal allocation error: %v", err) + } +} + +func TestClientEndpoint_DeriveVaultToken(t *testing.T) { + s1 := testServer(t, nil) + defer s1.Shutdown() + state := s1.fsm.State() + codec := rpcClient(t, s1) + testutil.WaitForLeader(t, s1.RPC) + + // Enable vault and allow authenticated + s1.config.VaultConfig.Enabled = true + s1.config.VaultConfig.AllowUnauthenticated = true + + // Replace the Vault Client on the server + tvc := &TestVaultClient{} + s1.vault = tvc + + // Create the node + node := mock.Node() + if err := state.UpsertNode(2, node); err != nil { + t.Fatalf("err: %v", err) + } + + // Create an alloc an allocation that has vault policies required + alloc := mock.Alloc() + alloc.NodeID = node.ID + task := alloc.Job.TaskGroups[0].Tasks[0] + tasks := []string{task.Name} + task.Vault = &structs.Vault{Policies: []string{"a", "b"}} + if err := state.UpsertAllocs(3, []*structs.Allocation{alloc}); err != nil { + t.Fatalf("err: %v", err) + } + + // Return a secret for the task + token := structs.GenerateUUID() + accessor := structs.GenerateUUID() + ttl := 10 + secret := &vapi.Secret{ + WrapInfo: &vapi.SecretWrapInfo{ + Token: token, + WrappedAccessor: accessor, + TTL: ttl, + }, + } + tvc.SetCreateTokenSecret(alloc.ID, task.Name, secret) + + req := &structs.DeriveVaultTokenRequest{ + NodeID: node.ID, + SecretID: node.SecretID, + AllocID: alloc.ID, + Tasks: tasks, + QueryOptions: structs.QueryOptions{ + Region: "global", + }, + } + + var resp structs.DeriveVaultTokenResponse + if err := msgpackrpc.CallWithCodec(codec, "Node.DeriveVaultToken", req, &resp); err != nil { + t.Fatalf("bad: %v", err) + } + + // Check the state store and ensure that we created a VaultAccessor + va, err := state.VaultAccessor(accessor) + if err != nil { + t.Fatalf("bad: %v", err) + } + if va == nil { + t.Fatalf("bad: %v", va) + } + + if va.CreateIndex == 0 { + t.Fatalf("bad: %v", va) + } + + va.CreateIndex = 0 + expected := &structs.VaultAccessor{ + AllocID: alloc.ID, + Task: task.Name, + NodeID: alloc.NodeID, + Accessor: accessor, + CreationTTL: ttl, + } + + if !reflect.DeepEqual(expected, va) { + t.Fatalf("Got %#v; want %#v", va, expected) + } +} diff --git a/nomad/vault_testing.go b/nomad/vault_testing.go index d4fa4c0b8..73e5efc34 100644 --- a/nomad/vault_testing.go +++ b/nomad/vault_testing.go @@ -18,6 +18,14 @@ type TestVaultClient struct { // LookupTokenSecret maps a token to the Vault secret that will be returned // by the LookupToken call LookupTokenSecret map[string]*vapi.Secret + + // CreateTokenErrors maps a token to an error that will be returned by the + // CreateToken call + CreateTokenErrors map[string]map[string]error + + // CreateTokenSecret maps a token to the Vault secret that will be returned + // by the CreateToken call + CreateTokenSecret map[string]map[string]*vapi.Secret } func (v *TestVaultClient) LookupToken(ctx context.Context, token string) (*vapi.Secret, error) { @@ -67,7 +75,55 @@ func (v *TestVaultClient) SetLookupTokenAllowedPolicies(token string, policies [ } func (v *TestVaultClient) CreateToken(ctx context.Context, a *structs.Allocation, task string) (*vapi.Secret, error) { - return nil, nil + var secret *vapi.Secret + var err error + + if v.CreateTokenSecret != nil { + tasks := v.CreateTokenSecret[a.ID] + if tasks != nil { + secret = tasks[task] + } + } + if v.CreateTokenErrors != nil { + tasks := v.CreateTokenErrors[a.ID] + if tasks != nil { + err = tasks[task] + } + } + + return secret, err +} + +// SetCreateTokenError sets the error that will be returned by the token +// creation +func (v *TestVaultClient) SetCreateTokenError(allocID, task string, err error) { + if v.CreateTokenErrors == nil { + v.CreateTokenErrors = make(map[string]map[string]error) + } + + tasks := v.CreateTokenErrors[allocID] + if tasks == nil { + tasks = make(map[string]error) + v.CreateTokenErrors[allocID] = tasks + } + + v.CreateTokenErrors[allocID][task] = err +} + +// SetCreateTokenSecret sets the secret that will be returned by the token +// creation +func (v *TestVaultClient) SetCreateTokenSecret(allocID, task string, secret *vapi.Secret) { + if v.CreateTokenSecret == nil { + v.CreateTokenSecret = make(map[string]map[string]*vapi.Secret) + } + + tasks := v.CreateTokenSecret[allocID] + if tasks == nil { + tasks = make(map[string]*vapi.Secret) + v.CreateTokenSecret[allocID] = tasks + } + + v.CreateTokenSecret[allocID][task] = secret } func (v *TestVaultClient) Stop() {}