From 2408f99ccaec3e28bff7465396a2f926365cadaa Mon Sep 17 00:00:00 2001 From: Kyle Havlovitz Date: Wed, 12 Dec 2018 05:22:25 -0800 Subject: [PATCH] txn: add tests for RPC endpoint --- agent/consul/acl.go | 17 +- agent/consul/state/txn.go | 17 +- agent/consul/txn_endpoint.go | 2 +- agent/consul/txn_endpoint_test.go | 462 ++++++++++++++++++++++++++++-- 4 files changed, 461 insertions(+), 37 deletions(-) diff --git a/agent/consul/acl.go b/agent/consul/acl.go index 3f9d9eb14..749e0ece1 100644 --- a/agent/consul/acl.go +++ b/agent/consul/acl.go @@ -1354,6 +1354,11 @@ func vetDeregisterWithACL(rule acl.Authorizer, subj *structs.DeregisterRequest, // vetNodeTxnOp applies the given ACL policy to a node transaction operation. func vetNodeTxnOp(op *structs.TxnNodeOp, rule acl.Authorizer) error { + // Fast path if ACLs are not enabled. + if rule == nil { + return nil + } + node := op.Node // Filtering for GETs is done on the output side. @@ -1378,7 +1383,7 @@ func vetNodeTxnOp(op *structs.TxnNodeOp, rule acl.Authorizer) error { } } - if !rule.NodeWrite(node.Node, scope) { + if rule != nil && !rule.NodeWrite(node.Node, scope) { return acl.ErrPermissionDenied } @@ -1387,6 +1392,11 @@ func vetNodeTxnOp(op *structs.TxnNodeOp, rule acl.Authorizer) error { // vetServiceTxnOp applies the given ACL policy to a service transaction operation. func vetServiceTxnOp(op *structs.TxnServiceOp, rule acl.Authorizer) error { + // Fast path if ACLs are not enabled. + if rule == nil { + return nil + } + service := op.Service // Filtering for GETs is done on the output side. @@ -1416,6 +1426,11 @@ func vetServiceTxnOp(op *structs.TxnServiceOp, rule acl.Authorizer) error { // vetCheckTxnOp applies the given ACL policy to a check transaction operation. func vetCheckTxnOp(op *structs.TxnCheckOp, rule acl.Authorizer) error { + // Fast path if ACLs are not enabled. + if rule == nil { + return nil + } + // Filtering for GETs is done on the output side. if op.Verb == api.CheckGet { return nil diff --git a/agent/consul/state/txn.go b/agent/consul/state/txn.go index 264e04d15..aa6458f02 100644 --- a/agent/consul/state/txn.go +++ b/agent/consul/state/txn.go @@ -132,16 +132,19 @@ func (s *Store) txnNode(tx *memdb.Txn, idx uint64, op *structs.TxnNodeOp) (struc entry, err = getNodeIDTxn(tx, op.Node.ID) case api.NodeSet: - entry = &op.Node err = s.ensureNodeTxn(tx, idx, &op.Node) + if err == nil { + entry, err = getNodeIDTxn(tx, op.Node.ID) + } case api.NodeCAS: var ok bool - entry = &op.Node ok, err = s.ensureNodeCASTxn(tx, idx, &op.Node) if !ok && err == nil { err = fmt.Errorf("failed to set node %q, index is stale", op.Node.Node) + break } + entry, err = getNodeIDTxn(tx, op.Node.ID) case api.NodeDelete: err = s.deleteNodeTxn(tx, idx, op.Node.Node) @@ -187,8 +190,8 @@ func (s *Store) txnService(tx *memdb.Txn, idx uint64, op *structs.TxnServiceOp) entry, err = s.nodeServiceTxn(tx, op.Node, op.Service.ID) case api.ServiceSet: - entry = &op.Service err = s.ensureServiceTxn(tx, idx, op.Node, &op.Service) + entry, err = s.nodeServiceTxn(tx, op.Node, op.Service.ID) case api.ServiceCAS: var ok bool @@ -246,8 +249,10 @@ func (s *Store) txnCheck(tx *memdb.Txn, idx uint64, op *structs.TxnCheckOp) (str } case api.CheckSet: - entry = &op.Check - err = s.ensureCheckTxn(tx, idx, entry) + err = s.ensureCheckTxn(tx, idx, &op.Check) + if err == nil { + _, entry, err = s.nodeCheckTxn(tx, op.Check.Node, op.Check.CheckID) + } case api.CheckCAS: var ok bool @@ -255,7 +260,9 @@ func (s *Store) txnCheck(tx *memdb.Txn, idx uint64, op *structs.TxnCheckOp) (str ok, err = s.ensureCheckCASTxn(tx, idx, entry) if !ok && err == nil { err = fmt.Errorf("failed to set check %q on node %q, index is stale", entry.CheckID, entry.Node) + break } + _, entry, err = s.nodeCheckTxn(tx, op.Check.Node, op.Check.CheckID) case api.CheckDelete: err = s.deleteCheckTxn(tx, idx, op.Check.Node, op.Check.CheckID) diff --git a/agent/consul/txn_endpoint.go b/agent/consul/txn_endpoint.go index 27b926908..b69113b4e 100644 --- a/agent/consul/txn_endpoint.go +++ b/agent/consul/txn_endpoint.go @@ -55,7 +55,7 @@ func (t *Txn) preCheck(authorizer acl.Authorizer, ops structs.TxnOps) structs.Tx } case op.Service != nil: service := &op.Service.Service - if err := servicePreApply(service, authorizer); err != nil { + if err := servicePreApply(service, nil); err != nil { errors = append(errors, &structs.TxnError{ OpIndex: i, What: err.Error(), diff --git a/agent/consul/txn_endpoint_test.go b/agent/consul/txn_endpoint_test.go index c1447cf99..5cbfb56f2 100644 --- a/agent/consul/txn_endpoint_test.go +++ b/agent/consul/txn_endpoint_test.go @@ -12,9 +12,49 @@ import ( "github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/api" "github.com/hashicorp/consul/testrpc" + "github.com/hashicorp/consul/types" "github.com/hashicorp/net-rpc-msgpackrpc" + "github.com/pascaldekloe/goe/verify" + "github.com/stretchr/testify/require" ) +var testTxnRules = ` +key "" { + policy = "deny" +} +key "foo" { + policy = "read" +} +key "test" { + policy = "write" +} +key "test/priv" { + policy = "read" +} + +service "" { + policy = "deny" +} +service "foo-svc" { + policy = "read" +} +service "test-svc" { + policy = "write" +} + +node "" { + policy = "deny" +} +node "foo-node" { + policy = "read" +} +node "test-node" { + policy = "write" +} +` + +var testNodeID = "9749a7df-fac5-46b4-8078-32a3d96c59f3" + func TestTxn_CheckNotExists(t *testing.T) { t.Parallel() dir1, s1 := testServer(t) @@ -101,12 +141,76 @@ func TestTxn_Apply(t *testing.T) { }, }, }, + &structs.TxnOp{ + Node: &structs.TxnNodeOp{ + Verb: api.NodeSet, + Node: structs.Node{ + ID: types.NodeID(testNodeID), + Node: "foo", + Address: "127.0.0.1", + }, + }, + }, + &structs.TxnOp{ + Node: &structs.TxnNodeOp{ + Verb: api.NodeGet, + Node: structs.Node{ + ID: types.NodeID(testNodeID), + Node: "foo", + }, + }, + }, + &structs.TxnOp{ + Service: &structs.TxnServiceOp{ + Verb: api.ServiceSet, + Node: "foo", + Service: structs.NodeService{ + ID: "svc-foo", + Service: "svc-foo", + Address: "1.1.1.1", + }, + }, + }, + &structs.TxnOp{ + Service: &structs.TxnServiceOp{ + Verb: api.ServiceGet, + Node: "foo", + Service: structs.NodeService{ + ID: "svc-foo", + Service: "svc-foo", + }, + }, + }, + &structs.TxnOp{ + Check: &structs.TxnCheckOp{ + Verb: api.CheckSet, + Check: structs.HealthCheck{ + Node: "foo", + CheckID: types.CheckID("check-foo"), + Name: "test", + Status: "passing", + }, + }, + }, + &structs.TxnOp{ + Check: &structs.TxnCheckOp{ + Verb: api.CheckGet, + Check: structs.HealthCheck{ + Node: "foo", + CheckID: types.CheckID("check-foo"), + Name: "test", + }, + }, + }, }, } var out structs.TxnResponse if err := msgpackrpc.CallWithCodec(codec, "Txn.Apply", &arg, &out); err != nil { t.Fatalf("err: %v", err) } + if len(out.Errors) != 0 { + t.Fatalf("errs: %v", out.Errors) + } // Verify the state store directly. state := s1.fsm.State() @@ -122,6 +226,30 @@ func TestTxn_Apply(t *testing.T) { t.Fatalf("bad: %v", d) } + _, n, err := state.GetNode("foo") + if err != nil { + t.Fatalf("err: %v", err) + } + if n.Node != "foo" || n.Address != "127.0.0.1" { + t.Fatalf("bad: %v", err) + } + + _, s, err := state.NodeService("foo", "svc-foo") + if err != nil { + t.Fatalf("err: %v", err) + } + if s.ID != "svc-foo" || s.Address != "1.1.1.1" { + t.Fatalf("bad: %v", err) + } + + _, c, err := state.NodeCheck("foo", types.CheckID("check-foo")) + if err != nil { + t.Fatalf("err: %v", err) + } + if c.CheckID != "check-foo" || c.Status != "passing" || c.Name != "test" { + t.Fatalf("bad: %v", err) + } + // Verify the transaction's return value. expected := structs.TxnResponse{ Results: structs.TxnResults{ @@ -147,15 +275,34 @@ func TestTxn_Apply(t *testing.T) { }, }, }, + &structs.TxnResult{ + Node: n, + }, + &structs.TxnResult{ + Node: n, + }, + &structs.TxnResult{ + Service: s, + }, + &structs.TxnResult{ + Service: s, + }, + &structs.TxnResult{ + Check: c, + }, + &structs.TxnResult{ + Check: c, + }, }, } - if !reflect.DeepEqual(out, expected) { - t.Fatalf("bad %v", out) - } + verify.Values(t, "", out, expected) } func TestTxn_Apply_ACLDeny(t *testing.T) { t.Parallel() + + require := require.New(t) + dir1, s1 := testServerWithConfig(t, func(c *Config) { c.ACLDatacenter = "dc1" c.ACLsEnabled = true @@ -167,15 +314,25 @@ func TestTxn_Apply_ACLDeny(t *testing.T) { testrpc.WaitForLeader(t, s1.RPC, "dc1") - // Put in a key to read back. + // Set up some state to read back. state := s1.fsm.State() d := &structs.DirEntry{ Key: "nope", Value: []byte("hello"), } - if err := state.KVSSet(1, d); err != nil { - t.Fatalf("err: %v", err) + require.NoError(state.KVSSet(1, d)) + + node := &structs.Node{ + ID: types.NodeID(testNodeID), + Node: "nope", } + require.NoError(state.EnsureNode(2, node)) + + svc := structs.NodeService{ID: "nope", Service: "nope", Address: "127.0.0.1"} + require.NoError(state.EnsureService(3, "nope", &svc)) + + check := structs.HealthCheck{Node: "nope", CheckID: types.CheckID("nope")} + state.EnsureCheck(4, &check) // Create the ACL. var id string @@ -186,7 +343,7 @@ func TestTxn_Apply_ACLDeny(t *testing.T) { ACL: structs.ACL{ Name: "User token", Type: structs.ACLTokenTypeClient, - Rules: testListRules, + Rules: testTxnRules, }, WriteRequest: structs.WriteRequest{Token: "root"}, } @@ -296,6 +453,101 @@ func TestTxn_Apply_ACLDeny(t *testing.T) { }, }, }, + &structs.TxnOp{ + Node: &structs.TxnNodeOp{ + Verb: api.NodeGet, + Node: structs.Node{ID: node.ID, Node: node.Node}, + }, + }, + &structs.TxnOp{ + Node: &structs.TxnNodeOp{ + Verb: api.NodeSet, + Node: structs.Node{ID: node.ID, Node: node.Node}, + }, + }, + &structs.TxnOp{ + Node: &structs.TxnNodeOp{ + Verb: api.NodeCAS, + Node: structs.Node{ID: node.ID, Node: node.Node}, + }, + }, + &structs.TxnOp{ + Node: &structs.TxnNodeOp{ + Verb: api.NodeDelete, + Node: structs.Node{ID: node.ID, Node: node.Node}, + }, + }, + &structs.TxnOp{ + Node: &structs.TxnNodeOp{ + Verb: api.NodeDeleteCAS, + Node: structs.Node{ID: node.ID, Node: node.Node}, + }, + }, + &structs.TxnOp{ + Service: &structs.TxnServiceOp{ + Verb: api.ServiceGet, + Node: "foo-node", + Service: svc, + }, + }, + &structs.TxnOp{ + Service: &structs.TxnServiceOp{ + Verb: api.ServiceSet, + Node: "foo-node", + Service: svc, + }, + }, + &structs.TxnOp{ + Service: &structs.TxnServiceOp{ + Verb: api.ServiceCAS, + Node: "foo-node", + Service: svc, + }, + }, + &structs.TxnOp{ + Service: &structs.TxnServiceOp{ + Verb: api.ServiceDelete, + Node: "foo-node", + Service: svc, + }, + }, + &structs.TxnOp{ + Service: &structs.TxnServiceOp{ + Verb: api.ServiceDeleteCAS, + Node: "foo-node", + Service: svc, + }, + }, + &structs.TxnOp{ + Check: &structs.TxnCheckOp{ + Verb: api.CheckGet, + Check: check, + }, + }, + &structs.TxnOp{ + Check: &structs.TxnCheckOp{ + Verb: api.CheckSet, + Check: check, + }, + }, + &structs.TxnOp{ + Check: &structs.TxnCheckOp{ + Verb: api.CheckCAS, + Check: check, + }, + }, + &structs.TxnOp{ + Check: &structs.TxnCheckOp{ + Verb: api.CheckDelete, + Check: check, + }, + }, + &structs.TxnOp{ + Check: &structs.TxnCheckOp{ + Verb: api.CheckDeleteCAS, + Check: check, + }, + }, }, WriteRequest: structs.WriteRequest{ Token: id, @@ -309,20 +561,55 @@ func TestTxn_Apply_ACLDeny(t *testing.T) { // Verify the transaction's return value. var expected structs.TxnResponse for i, op := range arg.Ops { - switch op.KV.Verb { - case api.KVGet, api.KVGetTree: - // These get filtered but won't result in an error. + switch { + case op.KV != nil: + switch op.KV.Verb { + case api.KVGet, api.KVGetTree: + // These get filtered but won't result in an error. - default: - expected.Errors = append(expected.Errors, &structs.TxnError{ - OpIndex: i, - What: acl.ErrPermissionDenied.Error(), - }) + default: + expected.Errors = append(expected.Errors, &structs.TxnError{ + OpIndex: i, + What: acl.ErrPermissionDenied.Error(), + }) + } + case op.Node != nil: + switch op.Node.Verb { + case api.NodeGet: + // These get filtered but won't result in an error. + + default: + expected.Errors = append(expected.Errors, &structs.TxnError{ + OpIndex: i, + What: acl.ErrPermissionDenied.Error(), + }) + } + case op.Service != nil: + switch op.Service.Verb { + case api.ServiceGet: + // These get filtered but won't result in an error. + + default: + expected.Errors = append(expected.Errors, &structs.TxnError{ + OpIndex: i, + What: acl.ErrPermissionDenied.Error(), + }) + } + case op.Check != nil: + switch op.Check.Verb { + case api.CheckGet: + // These get filtered but won't result in an error. + + default: + expected.Errors = append(expected.Errors, &structs.TxnError{ + OpIndex: i, + What: acl.ErrPermissionDenied.Error(), + }) + } } } - if !reflect.DeepEqual(out, expected) { - t.Fatalf("bad %v", out) - } + + verify.Values(t, "", out, expected) } func TestTxn_Apply_LockDelay(t *testing.T) { @@ -413,6 +700,9 @@ func TestTxn_Apply_LockDelay(t *testing.T) { func TestTxn_Read(t *testing.T) { t.Parallel() + + require := require.New(t) + dir1, s1 := testServer(t) defer os.RemoveAll(dir1) defer s1.Shutdown() @@ -431,6 +721,19 @@ func TestTxn_Read(t *testing.T) { t.Fatalf("err: %v", err) } + // Put in a node/check/service to read back. + node := &structs.Node{ + ID: types.NodeID(testNodeID), + Node: "foo", + } + require.NoError(state.EnsureNode(2, node)) + + svc := structs.NodeService{ID: "svc-foo", Service: "svc-foo", Address: "127.0.0.1"} + require.NoError(state.EnsureService(3, "foo", &svc)) + + check := structs.HealthCheck{Node: "foo", CheckID: types.CheckID("check-foo")} + state.EnsureCheck(4, &check) + // Do a super basic request. The state store test covers the details so // we just need to be sure that the transaction is sent correctly and // the results are converted appropriately. @@ -445,6 +748,25 @@ func TestTxn_Read(t *testing.T) { }, }, }, + &structs.TxnOp{ + Node: &structs.TxnNodeOp{ + Verb: api.NodeGet, + Node: structs.Node{ID: node.ID, Node: node.Node}, + }, + }, + &structs.TxnOp{ + Service: &structs.TxnServiceOp{ + Verb: api.ServiceGet, + Node: "foo", + Service: svc, + }, + }, + &structs.TxnOp{ + Check: &structs.TxnCheckOp{ + Verb: api.CheckGet, + Check: check, + }, + }, }, } var out structs.TxnReadResponse @@ -453,6 +775,8 @@ func TestTxn_Read(t *testing.T) { } // Verify the transaction's return value. + svc.Weights = &structs.Weights{Passing: 1, Warning: 1} + svc.RaftIndex = structs.RaftIndex{CreateIndex: 3, ModifyIndex: 3} expected := structs.TxnReadResponse{ TxnResponse: structs.TxnResponse{ Results: structs.TxnResults{ @@ -466,19 +790,29 @@ func TestTxn_Read(t *testing.T) { }, }, }, + &structs.TxnResult{ + Node: node, + }, + &structs.TxnResult{ + Service: &svc, + }, + &structs.TxnResult{ + Check: &check, + }, }, }, QueryMeta: structs.QueryMeta{ KnownLeader: true, }, } - if !reflect.DeepEqual(out, expected) { - t.Fatalf("bad %v", out) - } + verify.Values(t, "", out, expected) } func TestTxn_Read_ACLDeny(t *testing.T) { t.Parallel() + + require := require.New(t) + dir1, s1 := testServerWithConfig(t, func(c *Config) { c.ACLDatacenter = "dc1" c.ACLsEnabled = true @@ -502,6 +836,19 @@ func TestTxn_Read_ACLDeny(t *testing.T) { t.Fatalf("err: %v", err) } + // Put in a node/check/service to read back. + node := &structs.Node{ + ID: types.NodeID(testNodeID), + Node: "nope", + } + require.NoError(state.EnsureNode(2, node)) + + svc := structs.NodeService{ID: "nope", Service: "nope", Address: "127.0.0.1"} + require.NoError(state.EnsureService(3, "nope", &svc)) + + check := structs.HealthCheck{Node: "nope", CheckID: types.CheckID("nope")} + state.EnsureCheck(4, &check) + // Create the ACL. var id string { @@ -511,7 +858,7 @@ func TestTxn_Read_ACLDeny(t *testing.T) { ACL: structs.ACL{ Name: "User token", Type: structs.ACLTokenTypeClient, - Rules: testListRules, + Rules: testTxnRules, }, WriteRequest: structs.WriteRequest{Token: "root"}, } @@ -557,6 +904,25 @@ func TestTxn_Read_ACLDeny(t *testing.T) { }, }, }, + &structs.TxnOp{ + Node: &structs.TxnNodeOp{ + Verb: api.NodeGet, + Node: structs.Node{ID: node.ID, Node: node.Node}, + }, + }, + &structs.TxnOp{ + Service: &structs.TxnServiceOp{ + Verb: api.ServiceGet, + Node: "foo", + Service: svc, + }, + }, + &structs.TxnOp{ + Check: &structs.TxnCheckOp{ + Verb: api.CheckGet, + Check: check, + }, + }, }, QueryOptions: structs.QueryOptions{ Token: id, @@ -574,15 +940,51 @@ func TestTxn_Read_ACLDeny(t *testing.T) { }, } for i, op := range arg.Ops { - switch op.KV.Verb { - case api.KVGet, api.KVGetTree: - // These get filtered but won't result in an error. + switch { + case op.KV != nil: + switch op.KV.Verb { + case api.KVGet, api.KVGetTree: + // These get filtered but won't result in an error. - default: - expected.Errors = append(expected.Errors, &structs.TxnError{ - OpIndex: i, - What: acl.ErrPermissionDenied.Error(), - }) + default: + expected.Errors = append(expected.Errors, &structs.TxnError{ + OpIndex: i, + What: acl.ErrPermissionDenied.Error(), + }) + } + case op.Node != nil: + switch op.Node.Verb { + case api.NodeGet: + // These get filtered but won't result in an error. + + default: + expected.Errors = append(expected.Errors, &structs.TxnError{ + OpIndex: i, + What: acl.ErrPermissionDenied.Error(), + }) + } + case op.Service != nil: + switch op.Service.Verb { + case api.ServiceGet: + // These get filtered but won't result in an error. + + default: + expected.Errors = append(expected.Errors, &structs.TxnError{ + OpIndex: i, + What: acl.ErrPermissionDenied.Error(), + }) + } + case op.Check != nil: + switch op.Check.Verb { + case api.CheckGet: + // These get filtered but won't result in an error. + + default: + expected.Errors = append(expected.Errors, &structs.TxnError{ + OpIndex: i, + What: acl.ErrPermissionDenied.Error(), + }) + } } } if !reflect.DeepEqual(out, expected) {