From ba13416b5726e337db2830969342cf3d55b078ba Mon Sep 17 00:00:00 2001
From: "R.B. Boyer" <4903+rboyer@users.noreply.github.com>
Date: Wed, 22 Sep 2021 13:14:26 -0500
Subject: [PATCH] grpc: strip local ACL tokens from RPCs during forwarding if
 crossing datacenters (#11099)

Fixes #11086
---
 .changelog/11099.txt                       |   6 +
 agent/consul/rpc.go                        | 102 +++++++++-
 agent/consul/rpc_test.go                   | 222 ++++++++++++++++++++-
 agent/consul/subscribe_backend.go          |  17 +-
 agent/consul/subscribe_backend_test.go     |  14 +-
 agent/rpc/subscribe/subscribe.go           |   4 +-
 agent/rpc/subscribe/subscribe_test.go      |   2 +-
 agent/submatview/store_integration_test.go |   2 +-
 proto/pbcommon/common.go                   |  34 ++++
 proto/pbsubscribe/subscribe.go             |  33 +++
 10 files changed, 399 insertions(+), 37 deletions(-)
 create mode 100644 .changelog/11099.txt
 create mode 100644 proto/pbsubscribe/subscribe.go

diff --git a/.changelog/11099.txt b/.changelog/11099.txt
new file mode 100644
index 000000000..b6f68d180
--- /dev/null
+++ b/.changelog/11099.txt
@@ -0,0 +1,6 @@
+```release-note:bug
+grpc: strip local ACL tokens from RPCs during forwarding if crossing datacenters
+```
+```release-note:feature
+partitions: allow for partition queries to be forwarded
+```

diff --git a/agent/consul/rpc.go b/agent/consul/rpc.go
index e2c0a1419..c8d733a28 100644
--- a/agent/consul/rpc.go
+++ b/agent/consul/rpc.go
@@ -22,6 +22,7 @@ import (
 	msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc"
 	"github.com/hashicorp/raft"
 	"github.com/hashicorp/yamux"
+	"google.golang.org/grpc"
 
 	"github.com/hashicorp/consul/acl"
 	"github.com/hashicorp/consul/agent/consul/state"
@@ -556,13 +557,87 @@ func canRetry(info structs.RPCInfo, err error, start time.Time, config *Config)
 	return info != nil && info.IsRead() && lib.IsErrEOF(err)
 }
 
-// ForwardRPC is used to forward an RPC request to a remote DC or to the local leader
-// Returns a bool of if forwarding was performed, as well as any error
+// ForwardRPC is used to potentially forward an RPC request to a remote DC or
+// to the local leader depending upon the request.
+//
+// Returns a bool indicating whether forwarding was performed, as well as any
+// error. If false is returned (with no error) it is assumed that the current
+// server should handle the request.
 func (s *Server) ForwardRPC(method string, info structs.RPCInfo, reply interface{}) (bool, error) {
-	firstCheck := time.Now()
+	forwardToDC := func(dc string) error {
+		return s.forwardDC(method, dc, info, reply)
+	}
+	forwardToLeader := func(leader *metadata.Server) error {
+		return s.connPool.RPC(s.config.Datacenter, leader.ShortName, leader.Addr,
+			method, info, reply)
+	}
+	return s.forwardRPC(info, forwardToDC, forwardToLeader)
+}
 
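For orientation: the refactor above keeps the consumer-side contract of ForwardRPC intact. A minimal sketch of the typical endpoint call pattern, assuming the usual shape of a Consul RPC handler (the Catalog handler below is illustrative only and not part of this patch):

```go
// Sketch: typical endpoint usage of Server.ForwardRPC. When done is true
// the request was handled by forwarding (successfully or not) and err is
// the final answer; otherwise this server serves the request itself.
func (c *Catalog) ListNodes(args *structs.DCSpecificRequest, reply *structs.IndexedNodes) error {
	if done, err := c.srv.ForwardRPC("Catalog.ListNodes", args, reply); done {
		return err
	}
	// ...serve the read locally from the state store...
	return nil
}
```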
+// ForwardGRPC is used to potentially forward an RPC request to a remote DC or
+// to the local leader depending upon the request.
+//
+// Returns a bool indicating whether forwarding was performed, as well as any
+// error. If false is returned (with no error) it is assumed that the current
+// server should handle the request.
+func (s *Server) ForwardGRPC(connPool GRPCClientConner, info structs.RPCInfo, f func(*grpc.ClientConn) error) (handled bool, err error) {
+	forwardToDC := func(dc string) error {
+		conn, err := connPool.ClientConn(dc)
+		if err != nil {
+			return err
+		}
+		return f(conn)
+	}
+	forwardToLeader := func(leader *metadata.Server) error {
+		conn, err := connPool.ClientConnLeader()
+		if err != nil {
+			return err
+		}
+		return f(conn)
+	}
+	return s.forwardRPC(info, forwardToDC, forwardToLeader)
+}
+
+// forwardRPC is used to potentially forward an RPC request to a remote DC or
+// to the local leader depending upon the request.
+//
+// If info.RequestDatacenter() does not match the local datacenter, then the
+// request will be forwarded to the DC using forwardToDC.
+//
+// Stale read requests will be handled locally if the current node has an
+// initialized raft database, otherwise requests will be forwarded to the
+// local leader using forwardToLeader.
+//
+// Returns a bool indicating whether forwarding was performed, as well as any
+// error. If false is returned (with no error) it is assumed that the current
+// server should handle the request.
+func (s *Server) forwardRPC(
+	info structs.RPCInfo,
+	forwardToDC func(dc string) error,
+	forwardToLeader func(leader *metadata.Server) error,
+) (handled bool, err error) {
+	// Forward the request to the requested datacenter.
+	if handled, err := s.forwardRequestToOtherDatacenter(info, forwardToDC); handled || err != nil {
+		return handled, err
+	}
+
+	// See if we should let this server handle the read request without
+	// shipping the request to the leader.
+	if s.canServeReadRequest(info) {
+		return false, nil
+	}
+
+	return s.forwardRequestToLeader(info, forwardToLeader)
+}
+
+// forwardRequestToOtherDatacenter is an implementation detail of forwardRPC.
+// See the comment for forwardRPC for more details.
+func (s *Server) forwardRequestToOtherDatacenter(info structs.RPCInfo, forwardToDC func(dc string) error) (handled bool, err error) {
 	// Handle DC forwarding
 	dc := info.RequestDatacenter()
+	if dc == "" {
+		dc = s.config.Datacenter
+	}
 	if dc != s.config.Datacenter {
 		// Local tokens only work within the current datacenter. Check to see
 		// if we are attempting to forward one to a remote datacenter and strip
@@ -581,15 +656,23 @@ func (s *Server) ForwardRPC(method string, info structs.RPCInfo, reply interface
 		}
 	}
 
-		err := s.forwardDC(method, dc, info, reply)
-		return true, err
+		return true, forwardToDC(dc)
 	}
+	return false, nil
+}
 
+// canServeReadRequest determines if the request is a stale read request and
+// the current node can safely process that request.
+func (s *Server) canServeReadRequest(info structs.RPCInfo) bool {
 	// Check if we can allow a stale read, ensure our local DB is initialized
-	if info.IsRead() && info.AllowStaleRead() && !s.raft.LastContact().IsZero() {
-		return false, nil
-	}
+	return info.IsRead() && info.AllowStaleRead() && !s.raft.LastContact().IsZero()
+}
 
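Note that these helpers touch the request only through the structs.RPCInfo interface, which is what lets ForwardRPC and ForwardGRPC share forwardRPC. Paraphrased for reference (see agent/structs for the authoritative definition; the method set matches the implementations added to pbcommon and pbsubscribe later in this patch):

```go
// Paraphrased sketch of the structs.RPCInfo contract consulted above.
type RPCInfo interface {
	RequestDatacenter() string // target DC; forwarded when it differs from the local DC
	IsRead() bool              // reads are eligible for non-leader handling
	AllowStaleRead() bool      // stale reads skip leader forwarding
	TokenSecret() string       // ACL token; local tokens are stripped on cross-DC forwards
	SetTokenSecret(string)     // used for the strip-and-restore steps above
	HasTimedOut(start time.Time, rpcHoldTimeout, maxQueryTime, defaultQueryTime time.Duration) bool
}
```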
+// forwardRequestToLeader is an implementation detail of forwardRPC.
+// See the comment for forwardRPC for more details.
+func (s *Server) forwardRequestToLeader(info structs.RPCInfo, forwardToLeader func(leader *metadata.Server) error) (handled bool, err error) {
+	firstCheck := time.Now()
 CHECK_LEADER:
 	// Fail fast if we are in the process of leaving
 	select {
@@ -608,8 +691,7 @@ CHECK_LEADER:
 
 	// Handle the case of a known leader
 	if leader != nil {
-		rpcErr = s.connPool.RPC(s.config.Datacenter, leader.ShortName, leader.Addr,
-			method, info, reply)
+		rpcErr = forwardToLeader(leader)
 		if rpcErr == nil {
 			return true, nil
 		}
diff --git a/agent/consul/rpc_test.go b/agent/consul/rpc_test.go
index 6565d1ce7..f1b23872a 100644
--- a/agent/consul/rpc_test.go
+++ b/agent/consul/rpc_test.go
@@ -3,6 +3,7 @@ package consul
 import (
 	"bufio"
 	"bytes"
+	"context"
 	"crypto/x509"
 	"encoding/binary"
 	"errors"
@@ -25,15 +26,17 @@ import (
 	"github.com/hashicorp/raft"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
-
-	"github.com/hashicorp/consul/agent/connect"
+	"google.golang.org/grpc"
 
 	"github.com/hashicorp/consul/acl"
+	"github.com/hashicorp/consul/agent/connect"
 	"github.com/hashicorp/consul/agent/consul/state"
+	agent_grpc "github.com/hashicorp/consul/agent/grpc"
 	"github.com/hashicorp/consul/agent/pool"
 	"github.com/hashicorp/consul/agent/structs"
 	tokenStore "github.com/hashicorp/consul/agent/token"
 	"github.com/hashicorp/consul/api"
+	"github.com/hashicorp/consul/proto/pbsubscribe"
 	"github.com/hashicorp/consul/sdk/testutil"
 	"github.com/hashicorp/consul/sdk/testutil/retry"
 	"github.com/hashicorp/consul/testrpc"
@@ -964,6 +967,201 @@ func TestRPC_LocalTokenStrippedOnForward(t *testing.T) {
 	require.Equal(t, localToken2.SecretID, arg.WriteRequest.Token, "token should not be stripped")
 }
 
+func TestRPC_LocalTokenStrippedOnForward_GRPC(t *testing.T) {
+	if testing.Short() {
+		t.Skip("too slow for testing.Short")
+	}
+
+	t.Parallel()
+	dir1, s1 := testServerWithConfig(t, func(c *Config) {
+		c.PrimaryDatacenter = "dc1"
+		c.ACLsEnabled = true
+		c.ACLResolverSettings.ACLDefaultPolicy = "deny"
+		c.ACLMasterToken = "root"
+		c.RPCConfig.EnableStreaming = true
+	})
+	s1.tokens.UpdateAgentToken("root", tokenStore.TokenSourceConfig)
+	defer os.RemoveAll(dir1)
+	defer s1.Shutdown()
+	testrpc.WaitForLeader(t, s1.RPC, "dc1")
+	codec := rpcClient(t, s1)
+	defer codec.Close()
+
+	dir2, s2 := testServerWithConfig(t, func(c *Config) {
+		c.Datacenter = "dc2"
+		c.PrimaryDatacenter = "dc1"
+		c.ACLsEnabled = true
+		c.ACLResolverSettings.ACLDefaultPolicy = "deny"
+		c.ACLTokenReplication = true
+		c.ACLReplicationRate = 100
+		c.ACLReplicationBurst = 100
+		c.ACLReplicationApplyLimit = 1000000
+		c.RPCConfig.EnableStreaming = true
+	})
+	s2.tokens.UpdateReplicationToken("root", tokenStore.TokenSourceConfig)
+	s2.tokens.UpdateAgentToken("root", tokenStore.TokenSourceConfig)
+	testrpc.WaitForLeader(t, s2.RPC, "dc2")
+	defer os.RemoveAll(dir2)
+	defer s2.Shutdown()
+	codec2 := rpcClient(t, s2)
+	defer codec2.Close()
+
+	// Try to join.
+	joinWAN(t, s2, s1)
+	testrpc.WaitForLeader(t, s1.RPC, "dc1")
+	testrpc.WaitForLeader(t, s1.RPC, "dc2")
+
+	// Wait for legacy acls to be disabled so we are clear that
+	// legacy replication isn't meddling.
+	waitForNewACLs(t, s1)
+	waitForNewACLs(t, s2)
+	waitForNewACLReplication(t, s2, structs.ACLReplicateTokens, 1, 1, 0)
+
+	// create simple service policy
+	policy, err := upsertTestPolicyWithRules(codec, "root", "dc1", `
+	node_prefix "" { policy = "read" }
+	service_prefix "" { policy = "read" }
+	`)
+	require.NoError(t, err)
+
+	// Wait for it to replicate
+	retry.Run(t, func(r *retry.R) {
+		_, p, err := s2.fsm.State().ACLPolicyGetByID(nil, policy.ID, &structs.EnterpriseMeta{})
+		require.Nil(r, err)
+		require.NotNil(r, p)
+	})
+
+	// create local token that only works in DC2
+	localToken2, err := upsertTestToken(codec, "root", "dc2", func(token *structs.ACLToken) {
+		token.Local = true
+		token.Policies = []structs.ACLTokenPolicyLink{
+			{ID: policy.ID},
+		}
+	})
+	require.NoError(t, err)
+
+	runStep(t, "Register a dummy node with a service", func(t *testing.T) {
+		req := &structs.RegisterRequest{
+			Node:       "node1",
+			Address:    "3.4.5.6",
+			Datacenter: "dc1",
+			Service: &structs.NodeService{
+				ID:      "redis1",
+				Service: "redis",
+				Address: "3.4.5.6",
+				Port:    8080,
+			},
+			WriteRequest: structs.WriteRequest{Token: "root"},
+		}
+		var out struct{}
+		require.NoError(t, s1.RPC("Catalog.Register", &req, &out))
+	})
+
+	var conn *grpc.ClientConn
+	{
+		client, builder := newClientWithGRPCResolver(t, func(c *Config) {
+			c.Datacenter = "dc2"
+			c.PrimaryDatacenter = "dc1"
+			c.RPCConfig.EnableStreaming = true
+		})
+		joinLAN(t, client, s2)
+		testrpc.WaitForTestAgent(t, client.RPC, "dc2", testrpc.WithToken("root"))
+
+		pool := agent_grpc.NewClientConnPool(agent_grpc.ClientConnPoolConfig{
+			Servers:               builder,
+			DialingFromServer:     false,
+			DialingFromDatacenter: "dc2",
+		})
+
+		conn, err = pool.ClientConn("dc2")
+		require.NoError(t, err)
+	}
+
+	// Try to use it locally (it should work)
+	runStep(t, "token used locally should work", func(t *testing.T) {
+		arg := &pbsubscribe.SubscribeRequest{
+			Topic:      pbsubscribe.Topic_ServiceHealth,
+			Key:        "redis",
+			Token:      localToken2.SecretID,
+			Datacenter: "dc2",
+		}
+		event, err := getFirstSubscribeEventOrError(conn, arg)
+		require.NoError(t, err)
+		require.NotNil(t, event)
+
+		// make sure that the token-restore defer works
+		require.Equal(t, localToken2.SecretID, arg.Token, "token should not be stripped")
+	})
+
+	runStep(t, "token used remotely should not work", func(t *testing.T) {
+		arg := &pbsubscribe.SubscribeRequest{
+			Topic:      pbsubscribe.Topic_ServiceHealth,
+			Key:        "redis",
+			Token:      localToken2.SecretID,
+			Datacenter: "dc1",
+		}
+
+		event, err := getFirstSubscribeEventOrError(conn, arg)
+
+		// NOTE: the subscription endpoint is a filtering style instead of a
+		// hard-fail style, so when the token isn't present 100% of the data
+		// is filtered out, leading to a stream with an empty snapshot.
+		require.NoError(t, err)
+		require.IsType(t, &pbsubscribe.Event_EndOfSnapshot{}, event.Payload)
+		require.True(t, event.Payload.(*pbsubscribe.Event_EndOfSnapshot).EndOfSnapshot)
+	})
+
+	runStep(t, "update anonymous token to read services", func(t *testing.T) {
+		tokenUpsertReq := structs.ACLTokenSetRequest{
+			Datacenter: "dc1",
+			ACLToken: structs.ACLToken{
+				AccessorID: structs.ACLTokenAnonymousID,
+				Policies: []structs.ACLTokenPolicyLink{
+					{ID: policy.ID},
+				},
+			},
+			WriteRequest: structs.WriteRequest{Token: "root"},
+		}
+		token := structs.ACLToken{}
+		err = msgpackrpc.CallWithCodec(codec, "ACL.TokenSet", &tokenUpsertReq, &token)
+		require.NoError(t, err)
+		require.NotEmpty(t, token.SecretID)
+	})
+
+	runStep(t, "token used remotely should fall back on anonymous token now", func(t *testing.T) {
+		arg := &pbsubscribe.SubscribeRequest{
+			Topic:      pbsubscribe.Topic_ServiceHealth,
+			Key:        "redis",
+			Token:      localToken2.SecretID,
+			Datacenter: "dc1",
+		}
+
+		event, err := getFirstSubscribeEventOrError(conn, arg)
+		require.NoError(t, err)
+		require.NotNil(t, event)
+
+		// Now that the anonymous token can read data, we should get a
+		// snapshot containing the registered "redis" service instance.
+
+		require.IsType(t, &pbsubscribe.Event_ServiceHealth{}, event.Payload)
+		esh := event.Payload.(*pbsubscribe.Event_ServiceHealth)
+
+		require.Equal(t, pbsubscribe.CatalogOp_Register, esh.ServiceHealth.Op)
+		csn := esh.ServiceHealth.CheckServiceNode
+
+		require.NotNil(t, csn)
+		require.NotNil(t, csn.Node)
+		require.Equal(t, "node1", csn.Node.Node)
+		require.Equal(t, "3.4.5.6", csn.Node.Address)
+		require.NotNil(t, csn.Service)
+		require.Equal(t, "redis1", csn.Service.ID)
+		require.Equal(t, "redis", csn.Service.Service)
+
+		// make sure that the token-restore defer works
+		require.Equal(t, localToken2.SecretID, arg.Token, "token should not be stripped")
+	})
+}
+
 func TestCanRetry(t *testing.T) {
 	type testCase struct {
 		name string
@@ -1362,3 +1560,23 @@ func isConnectionClosedError(err error) bool {
 		return false
 	}
 }
+
+func getFirstSubscribeEventOrError(conn *grpc.ClientConn, req *pbsubscribe.SubscribeRequest) (*pbsubscribe.Event, error) {
+	streamClient := pbsubscribe.NewStateChangeSubscriptionClient(conn)
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	handle, err := streamClient.Subscribe(ctx, req)
+	if err != nil {
+		return nil, err
+	}
+
+	event, err := handle.Recv()
+	if err == io.EOF {
+		return nil, nil
+	}
+	if err != nil {
+		return nil, err
+	}
+	return event, nil
+}
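One subtlety in the getFirstSubscribeEventOrError helper above: it maps io.EOF to a nil event with a nil error, so callers that require a first event must check both return values. An illustrative caller, assuming conn and req as in the test (not part of this patch):

```go
// Illustrative only: distinguish a cleanly closed stream (nil, nil) from a
// real subscription error when using getFirstSubscribeEventOrError.
event, err := getFirstSubscribeEventOrError(conn, req)
switch {
case err != nil:
	t.Fatalf("subscribe failed: %v", err)
case event == nil:
	t.Fatal("stream returned io.EOF before delivering an event")
}
```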
diff --git a/agent/consul/subscribe_backend.go b/agent/consul/subscribe_backend.go
index a1ba47236..8b6eddd84 100644
--- a/agent/consul/subscribe_backend.go
+++ b/agent/consul/subscribe_backend.go
@@ -26,21 +26,8 @@ func (s subscribeBackend) ResolveTokenAndDefaultMeta(
 
 var _ subscribe.Backend = (*subscribeBackend)(nil)
 
-// Forward requests to a remote datacenter by calling f if the target dc does not
-// match the config. Does nothing but return handled=false if dc is not specified,
-// or if it matches the Datacenter in config.
-//
-// TODO: extract this so that it can be used with other grpc services.
-// TODO: rename to ForwardToDC
-func (s subscribeBackend) Forward(dc string, f func(*grpc.ClientConn) error) (handled bool, err error) {
-	if dc == "" || dc == s.srv.config.Datacenter {
-		return false, nil
-	}
-	conn, err := s.connPool.ClientConn(dc)
-	if err != nil {
-		return false, err
-	}
-	return true, f(conn)
+func (s subscribeBackend) Forward(info structs.RPCInfo, f func(*grpc.ClientConn) error) (handled bool, err error) {
+	return s.srv.ForwardGRPC(s.connPool, info, f)
 }
 
 func (s subscribeBackend) Subscribe(req *stream.SubscribeRequest) (*stream.Subscription, error) {
diff --git a/agent/consul/subscribe_backend_test.go b/agent/consul/subscribe_backend_test.go
index e11d24b35..fe6a95732 100644
--- a/agent/consul/subscribe_backend_test.go
+++ b/agent/consul/subscribe_backend_test.go
@@ -362,17 +362,19 @@ func TestSubscribeBackend_IntegrationWithServer_DeliversAllMessages(t *testing.T
 }
 
 func newClientWithGRPCResolver(t *testing.T, ops ...func(*Config)) (*Client, *resolver.ServerResolverBuilder) {
-	builder := resolver.NewServerResolverBuilder(newTestResolverConfig(t, "client"))
-	resolver.Register(builder)
-	t.Cleanup(func() {
-		resolver.Deregister(builder.Authority())
-	})
-
 	_, config := testClientConfig(t)
 	for _, op := range ops {
 		op(config)
 	}
 
+	builder := resolver.NewServerResolverBuilder(newTestResolverConfig(t,
+		"client."+config.Datacenter+"."+string(config.NodeID)))
+
+	resolver.Register(builder)
+	t.Cleanup(func() {
+		resolver.Deregister(builder.Authority())
+	})
+
 	deps := newDefaultDeps(t, config)
 	deps.Router = router.NewRouter(
 		deps.Logger,
diff --git a/agent/rpc/subscribe/subscribe.go b/agent/rpc/subscribe/subscribe.go
index 0e7bfb24d..4c7255c62 100644
--- a/agent/rpc/subscribe/subscribe.go
+++ b/agent/rpc/subscribe/subscribe.go
@@ -37,13 +37,13 @@ var _ pbsubscribe.StateChangeSubscriptionServer = (*Server)(nil)
 
 type Backend interface {
 	ResolveTokenAndDefaultMeta(token string, entMeta *structs.EnterpriseMeta, authzContext *acl.AuthorizerContext) (acl.Authorizer, error)
-	Forward(dc string, f func(*grpc.ClientConn) error) (handled bool, err error)
+	Forward(info structs.RPCInfo, f func(*grpc.ClientConn) error) (handled bool, err error)
 	Subscribe(req *stream.SubscribeRequest) (*stream.Subscription, error)
 }
 
 func (h *Server) Subscribe(req *pbsubscribe.SubscribeRequest, serverStream pbsubscribe.StateChangeSubscription_SubscribeServer) error {
 	logger := newLoggerForRequest(h.Logger, req)
-	handled, err := h.Backend.Forward(req.Datacenter, forwardToDC(req, serverStream, logger))
+	handled, err := h.Backend.Forward(req, forwardToDC(req, serverStream, logger))
 	if handled || err != nil {
 		return err
 	}
diff --git a/agent/rpc/subscribe/subscribe_test.go b/agent/rpc/subscribe/subscribe_test.go
index a41740b03..d367a3ddb 100644
--- a/agent/rpc/subscribe/subscribe_test.go
+++ b/agent/rpc/subscribe/subscribe_test.go
@@ -292,7 +292,7 @@ func (b testBackend) ResolveTokenAndDefaultMeta(
 	return b.authorizer(token, entMeta), nil
 }
 
-func (b testBackend) Forward(_ string, fn func(*gogrpc.ClientConn) error) (handled bool, err error) {
+func (b testBackend) Forward(_ structs.RPCInfo, fn func(*gogrpc.ClientConn) error) (handled bool, err error) {
 	if b.forwardConn != nil {
 		return true, fn(b.forwardConn)
 	}
diff --git a/agent/submatview/store_integration_test.go b/agent/submatview/store_integration_test.go
index 38135e0ca..c8b6c21a8 100644
--- a/agent/submatview/store_integration_test.go
+++ b/agent/submatview/store_integration_test.go
@@ -146,7 +146,7 @@ func (b backend) ResolveTokenAndDefaultMeta(string, *structs.EnterpriseMeta, *ac
 	return acl.AllowAll(), nil
 }
 
-func (b backend) Forward(string, func(*grpc.ClientConn) error) (handled bool, err error) {
+func (b backend) Forward(structs.RPCInfo, func(*grpc.ClientConn) error) (handled bool, err error) {
 	return false, nil
 }
 
diff --git a/proto/pbcommon/common.go b/proto/pbcommon/common.go
index 97241341c..8850cc796 100644
--- a/proto/pbcommon/common.go
+++ b/proto/pbcommon/common.go
@@ -112,27 +112,61 @@ func (q *QueryMeta) GetBackend() structs.QueryBackend {
 }
 
 // WriteRequest only applies to writes, always false
+//
+// IsRead implements structs.RPCInfo
 func (w WriteRequest) IsRead() bool {
 	return false
 }
 
+// TokenSecret implements structs.RPCInfo
 func (w WriteRequest) TokenSecret() string {
 	return w.Token
 }
 
+// SetTokenSecret implements structs.RPCInfo
 func (w *WriteRequest) SetTokenSecret(s string) {
 	w.Token = s
 }
 
 // AllowStaleRead returns whether a stale read should be allowed
+//
+// AllowStaleRead implements structs.RPCInfo
 func (w WriteRequest) AllowStaleRead() bool {
 	return false
 }
 
+// HasTimedOut implements structs.RPCInfo
 func (w WriteRequest) HasTimedOut(start time.Time, rpcHoldTimeout, _, _ time.Duration) bool {
 	return time.Since(start) > rpcHoldTimeout
 }
 
+// IsRead implements structs.RPCInfo
+func (r *ReadRequest) IsRead() bool {
+	return true
+}
+
+// AllowStaleRead implements structs.RPCInfo
+func (r *ReadRequest) AllowStaleRead() bool {
+	// TODO(partitions): plumb this?
+	return false
+}
+
+// TokenSecret implements structs.RPCInfo
+func (r *ReadRequest) TokenSecret() string {
+	return r.Token
+}
+
+// SetTokenSecret implements structs.RPCInfo
+func (r *ReadRequest) SetTokenSecret(token string) {
+	r.Token = token
+}
+
+// HasTimedOut implements structs.RPCInfo
+func (r *ReadRequest) HasTimedOut(start time.Time, rpcHoldTimeout, maxQueryTime, defaultQueryTime time.Duration) bool {
+	return time.Since(start) > rpcHoldTimeout
+}
+
+// RequestDatacenter implements structs.RPCInfo
 func (td TargetDatacenter) RequestDatacenter() string {
 	return td.Datacenter
 }
diff --git a/proto/pbsubscribe/subscribe.go b/proto/pbsubscribe/subscribe.go
new file mode 100644
index 000000000..66f479fe6
--- /dev/null
+++ b/proto/pbsubscribe/subscribe.go
@@ -0,0 +1,33 @@
+package pbsubscribe
+
+import "time"
+
+// RequestDatacenter implements structs.RPCInfo
+func (req *SubscribeRequest) RequestDatacenter() string {
+	return req.Datacenter
+}
+
+// IsRead implements structs.RPCInfo
+func (req *SubscribeRequest) IsRead() bool {
+	return true
+}
+
+// AllowStaleRead implements structs.RPCInfo
+func (req *SubscribeRequest) AllowStaleRead() bool {
+	return true
+}
+
+// TokenSecret implements structs.RPCInfo
+func (req *SubscribeRequest) TokenSecret() string {
+	return req.Token
+}
+
+// SetTokenSecret implements structs.RPCInfo
+func (req *SubscribeRequest) SetTokenSecret(token string) {
+	req.Token = token
+}
+
+// HasTimedOut implements structs.RPCInfo
+func (req *SubscribeRequest) HasTimedOut(start time.Time, rpcHoldTimeout, maxQueryTime, defaultQueryTime time.Duration) bool {
+	return time.Since(start) > rpcHoldTimeout
+}
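Taken together, the pbsubscribe methods above are the full recipe for making a gRPC request forwardable with token stripping. A hedged sketch of what a hypothetical future gRPC service would add to reuse this machinery (exampleBackend and its fields are invented for illustration):

```go
// Hypothetical backend for a future gRPC service: implement structs.RPCInfo
// on the request proto (exactly as SubscribeRequest does above), then
// delegate to Server.ForwardGRPC, which strips local ACL tokens whenever
// the request crosses datacenters.
func (b exampleBackend) Forward(info structs.RPCInfo, f func(*grpc.ClientConn) error) (handled bool, err error) {
	return b.srv.ForwardGRPC(b.connPool, info, f)
}
```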