grpc: strip local ACL tokens from RPCs during forwarding if crossing datacenters (#11099)

Fixes #11086
This commit is contained in:
R.B. Boyer 2021-09-22 13:14:26 -05:00 committed by GitHub
parent b0b88286b8
commit ba13416b57
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 399 additions and 37 deletions

6
.changelog/11099.txt Normal file
View File

@ -0,0 +1,6 @@
```release-note:bug
grpc: strip local ACL tokens from RPCs during forwarding if crossing datacenters
```
```release-note:feature
partitions: allow for partition queries to be forwarded
```

View File

@ -22,6 +22,7 @@ import (
msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc" msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc"
"github.com/hashicorp/raft" "github.com/hashicorp/raft"
"github.com/hashicorp/yamux" "github.com/hashicorp/yamux"
"google.golang.org/grpc"
"github.com/hashicorp/consul/acl" "github.com/hashicorp/consul/acl"
"github.com/hashicorp/consul/agent/consul/state" "github.com/hashicorp/consul/agent/consul/state"
@ -556,13 +557,87 @@ func canRetry(info structs.RPCInfo, err error, start time.Time, config *Config)
return info != nil && info.IsRead() && lib.IsErrEOF(err) return info != nil && info.IsRead() && lib.IsErrEOF(err)
} }
// ForwardRPC is used to forward an RPC request to a remote DC or to the local leader // ForwardRPC is used to potentially forward an RPC request to a remote DC or
// Returns a bool of if forwarding was performed, as well as any error // to the local leader depending upon the request.
//
// Returns a bool of if forwarding was performed, as well as any error. If
// false is returned (with no error) it is assumed that the current server
// should handle the request.
func (s *Server) ForwardRPC(method string, info structs.RPCInfo, reply interface{}) (bool, error) { func (s *Server) ForwardRPC(method string, info structs.RPCInfo, reply interface{}) (bool, error) {
firstCheck := time.Now() forwardToDC := func(dc string) error {
return s.forwardDC(method, dc, info, reply)
}
forwardToLeader := func(leader *metadata.Server) error {
return s.connPool.RPC(s.config.Datacenter, leader.ShortName, leader.Addr,
method, info, reply)
}
return s.forwardRPC(info, forwardToDC, forwardToLeader)
}
// ForwardGRPC is used to potentially forward an RPC request to a remote DC or
// to the local leader depending upon the request.
//
// Returns a bool of if forwarding was performed, as well as any error. If
// false is returned (with no error) it is assumed that the current server
// should handle the request.
func (s *Server) ForwardGRPC(connPool GRPCClientConner, info structs.RPCInfo, f func(*grpc.ClientConn) error) (handled bool, err error) {
forwardToDC := func(dc string) error {
conn, err := connPool.ClientConn(dc)
if err != nil {
return err
}
return f(conn)
}
forwardToLeader := func(leader *metadata.Server) error {
conn, err := connPool.ClientConnLeader()
if err != nil {
return err
}
return f(conn)
}
return s.forwardRPC(info, forwardToDC, forwardToLeader)
}
// forwardRPC is used to potentially forward an RPC request to a remote DC or
// to the local leader depending upon the request.
//
// If info.RequestDatacenter() does not match the local datacenter, then the
// request will be forwarded to the DC using forwardToDC.
//
// Stale read requests will be handled locally if the current node has an
// initialized raft database, otherwise requests will be forwarded to the local
// leader using forwardToLeader.
//
// Returns a bool of if forwarding was performed, as well as any error. If
// false is returned (with no error) it is assumed that the current server
// should handle the request.
func (s *Server) forwardRPC(
info structs.RPCInfo,
forwardToDC func(dc string) error,
forwardToLeader func(leader *metadata.Server) error,
) (handled bool, err error) {
// Forward the request to the requested datacenter.
if handled, err := s.forwardRequestToOtherDatacenter(info, forwardToDC); handled || err != nil {
return handled, err
}
// See if we should let this server handle the read request without
// shipping the request to the leader.
if s.canServeReadRequest(info) {
return false, nil
}
return s.forwardRequestToLeader(info, forwardToLeader)
}
// forwardRequestToOtherDatacenter is an implementation detail of forwardRPC.
// See the comment for forwardRPC for more details.
func (s *Server) forwardRequestToOtherDatacenter(info structs.RPCInfo, forwardToDC func(dc string) error) (handled bool, err error) {
// Handle DC forwarding // Handle DC forwarding
dc := info.RequestDatacenter() dc := info.RequestDatacenter()
if dc == "" {
dc = s.config.Datacenter
}
if dc != s.config.Datacenter { if dc != s.config.Datacenter {
// Local tokens only work within the current datacenter. Check to see // Local tokens only work within the current datacenter. Check to see
// if we are attempting to forward one to a remote datacenter and strip // if we are attempting to forward one to a remote datacenter and strip
@ -581,15 +656,23 @@ func (s *Server) ForwardRPC(method string, info structs.RPCInfo, reply interface
} }
} }
err := s.forwardDC(method, dc, info, reply) return true, forwardToDC(dc)
return true, err
} }
// Check if we can allow a stale read, ensure our local DB is initialized
if info.IsRead() && info.AllowStaleRead() && !s.raft.LastContact().IsZero() {
return false, nil return false, nil
} }
// canServeReadRequest determines if the request is a stale read request and
// the current node can safely process that request.
func (s *Server) canServeReadRequest(info structs.RPCInfo) bool {
// Check if we can allow a stale read, ensure our local DB is initialized
return info.IsRead() && info.AllowStaleRead() && !s.raft.LastContact().IsZero()
}
// forwardRequestToLeader is an implementation detail of forwardRPC.
// See the comment for forwardRPC for more details.
func (s *Server) forwardRequestToLeader(info structs.RPCInfo, forwardToLeader func(leader *metadata.Server) error) (handled bool, err error) {
firstCheck := time.Now()
CHECK_LEADER: CHECK_LEADER:
// Fail fast if we are in the process of leaving // Fail fast if we are in the process of leaving
select { select {
@ -608,8 +691,7 @@ CHECK_LEADER:
// Handle the case of a known leader // Handle the case of a known leader
if leader != nil { if leader != nil {
rpcErr = s.connPool.RPC(s.config.Datacenter, leader.ShortName, leader.Addr, rpcErr = forwardToLeader(leader)
method, info, reply)
if rpcErr == nil { if rpcErr == nil {
return true, nil return true, nil
} }

View File

@ -3,6 +3,7 @@ package consul
import ( import (
"bufio" "bufio"
"bytes" "bytes"
"context"
"crypto/x509" "crypto/x509"
"encoding/binary" "encoding/binary"
"errors" "errors"
@ -25,15 +26,17 @@ import (
"github.com/hashicorp/raft" "github.com/hashicorp/raft"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"google.golang.org/grpc"
"github.com/hashicorp/consul/agent/connect"
"github.com/hashicorp/consul/acl" "github.com/hashicorp/consul/acl"
"github.com/hashicorp/consul/agent/connect"
"github.com/hashicorp/consul/agent/consul/state" "github.com/hashicorp/consul/agent/consul/state"
agent_grpc "github.com/hashicorp/consul/agent/grpc"
"github.com/hashicorp/consul/agent/pool" "github.com/hashicorp/consul/agent/pool"
"github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/agent/structs"
tokenStore "github.com/hashicorp/consul/agent/token" tokenStore "github.com/hashicorp/consul/agent/token"
"github.com/hashicorp/consul/api" "github.com/hashicorp/consul/api"
"github.com/hashicorp/consul/proto/pbsubscribe"
"github.com/hashicorp/consul/sdk/testutil" "github.com/hashicorp/consul/sdk/testutil"
"github.com/hashicorp/consul/sdk/testutil/retry" "github.com/hashicorp/consul/sdk/testutil/retry"
"github.com/hashicorp/consul/testrpc" "github.com/hashicorp/consul/testrpc"
@ -964,6 +967,201 @@ func TestRPC_LocalTokenStrippedOnForward(t *testing.T) {
require.Equal(t, localToken2.SecretID, arg.WriteRequest.Token, "token should not be stripped") require.Equal(t, localToken2.SecretID, arg.WriteRequest.Token, "token should not be stripped")
} }
func TestRPC_LocalTokenStrippedOnForward_GRPC(t *testing.T) {
if testing.Short() {
t.Skip("too slow for testing.Short")
}
t.Parallel()
dir1, s1 := testServerWithConfig(t, func(c *Config) {
c.PrimaryDatacenter = "dc1"
c.ACLsEnabled = true
c.ACLResolverSettings.ACLDefaultPolicy = "deny"
c.ACLMasterToken = "root"
c.RPCConfig.EnableStreaming = true
})
s1.tokens.UpdateAgentToken("root", tokenStore.TokenSourceConfig)
defer os.RemoveAll(dir1)
defer s1.Shutdown()
testrpc.WaitForLeader(t, s1.RPC, "dc1")
codec := rpcClient(t, s1)
defer codec.Close()
dir2, s2 := testServerWithConfig(t, func(c *Config) {
c.Datacenter = "dc2"
c.PrimaryDatacenter = "dc1"
c.ACLsEnabled = true
c.ACLResolverSettings.ACLDefaultPolicy = "deny"
c.ACLTokenReplication = true
c.ACLReplicationRate = 100
c.ACLReplicationBurst = 100
c.ACLReplicationApplyLimit = 1000000
c.RPCConfig.EnableStreaming = true
})
s2.tokens.UpdateReplicationToken("root", tokenStore.TokenSourceConfig)
s2.tokens.UpdateAgentToken("root", tokenStore.TokenSourceConfig)
testrpc.WaitForLeader(t, s2.RPC, "dc2")
defer os.RemoveAll(dir2)
defer s2.Shutdown()
codec2 := rpcClient(t, s2)
defer codec2.Close()
// Try to join.
joinWAN(t, s2, s1)
testrpc.WaitForLeader(t, s1.RPC, "dc1")
testrpc.WaitForLeader(t, s1.RPC, "dc2")
// Wait for legacy acls to be disabled so we are clear that
// legacy replication isn't meddling.
waitForNewACLs(t, s1)
waitForNewACLs(t, s2)
waitForNewACLReplication(t, s2, structs.ACLReplicateTokens, 1, 1, 0)
// create simple service policy
policy, err := upsertTestPolicyWithRules(codec, "root", "dc1", `
node_prefix "" { policy = "read" }
service_prefix "" { policy = "read" }
`)
require.NoError(t, err)
// Wait for it to replicate
retry.Run(t, func(r *retry.R) {
_, p, err := s2.fsm.State().ACLPolicyGetByID(nil, policy.ID, &structs.EnterpriseMeta{})
require.Nil(r, err)
require.NotNil(r, p)
})
// create local token that only works in DC2
localToken2, err := upsertTestToken(codec, "root", "dc2", func(token *structs.ACLToken) {
token.Local = true
token.Policies = []structs.ACLTokenPolicyLink{
{ID: policy.ID},
}
})
require.NoError(t, err)
runStep(t, "Register a dummy node with a service", func(t *testing.T) {
req := &structs.RegisterRequest{
Node: "node1",
Address: "3.4.5.6",
Datacenter: "dc1",
Service: &structs.NodeService{
ID: "redis1",
Service: "redis",
Address: "3.4.5.6",
Port: 8080,
},
WriteRequest: structs.WriteRequest{Token: "root"},
}
var out struct{}
require.NoError(t, s1.RPC("Catalog.Register", &req, &out))
})
var conn *grpc.ClientConn
{
client, builder := newClientWithGRPCResolver(t, func(c *Config) {
c.Datacenter = "dc2"
c.PrimaryDatacenter = "dc1"
c.RPCConfig.EnableStreaming = true
})
joinLAN(t, client, s2)
testrpc.WaitForTestAgent(t, client.RPC, "dc2", testrpc.WithToken("root"))
pool := agent_grpc.NewClientConnPool(agent_grpc.ClientConnPoolConfig{
Servers: builder,
DialingFromServer: false,
DialingFromDatacenter: "dc2",
})
conn, err = pool.ClientConn("dc2")
require.NoError(t, err)
}
// Try to use it locally (it should work)
runStep(t, "token used locally should work", func(t *testing.T) {
arg := &pbsubscribe.SubscribeRequest{
Topic: pbsubscribe.Topic_ServiceHealth,
Key: "redis",
Token: localToken2.SecretID,
Datacenter: "dc2",
}
event, err := getFirstSubscribeEventOrError(conn, arg)
require.NoError(t, err)
require.NotNil(t, event)
// make sure that token restore defer works
require.Equal(t, localToken2.SecretID, arg.Token, "token should not be stripped")
})
runStep(t, "token used remotely should not work", func(t *testing.T) {
arg := &pbsubscribe.SubscribeRequest{
Topic: pbsubscribe.Topic_ServiceHealth,
Key: "redis",
Token: localToken2.SecretID,
Datacenter: "dc1",
}
event, err := getFirstSubscribeEventOrError(conn, arg)
// NOTE: the subscription endpoint is a filtering style instead of a
// hard-fail style so when the token isn't present 100% of the data is
// filtered out leading to a stream with an empty snapshot.
require.NoError(t, err)
require.IsType(t, &pbsubscribe.Event_EndOfSnapshot{}, event.Payload)
require.True(t, event.Payload.(*pbsubscribe.Event_EndOfSnapshot).EndOfSnapshot)
})
runStep(t, "update anonymous token to read services", func(t *testing.T) {
tokenUpsertReq := structs.ACLTokenSetRequest{
Datacenter: "dc1",
ACLToken: structs.ACLToken{
AccessorID: structs.ACLTokenAnonymousID,
Policies: []structs.ACLTokenPolicyLink{
{ID: policy.ID},
},
},
WriteRequest: structs.WriteRequest{Token: "root"},
}
token := structs.ACLToken{}
err = msgpackrpc.CallWithCodec(codec, "ACL.TokenSet", &tokenUpsertReq, &token)
require.NoError(t, err)
require.NotEmpty(t, token.SecretID)
})
runStep(t, "token used remotely should fallback on anonymous token now", func(t *testing.T) {
arg := &pbsubscribe.SubscribeRequest{
Topic: pbsubscribe.Topic_ServiceHealth,
Key: "redis",
Token: localToken2.SecretID,
Datacenter: "dc1",
}
event, err := getFirstSubscribeEventOrError(conn, arg)
require.NoError(t, err)
require.NotNil(t, event)
// So now that we can read data, we should get a snapshot with just instances of the "consul" service.
require.NoError(t, err)
require.IsType(t, &pbsubscribe.Event_ServiceHealth{}, event.Payload)
esh := event.Payload.(*pbsubscribe.Event_ServiceHealth)
require.Equal(t, pbsubscribe.CatalogOp_Register, esh.ServiceHealth.Op)
csn := esh.ServiceHealth.CheckServiceNode
require.NotNil(t, csn)
require.NotNil(t, csn.Node)
require.Equal(t, "node1", csn.Node.Node)
require.Equal(t, "3.4.5.6", csn.Node.Address)
require.NotNil(t, csn.Service)
require.Equal(t, "redis1", csn.Service.ID)
require.Equal(t, "redis", csn.Service.Service)
// make sure that token restore defer works
require.Equal(t, localToken2.SecretID, arg.Token, "token should not be stripped")
})
}
func TestCanRetry(t *testing.T) { func TestCanRetry(t *testing.T) {
type testCase struct { type testCase struct {
name string name string
@ -1362,3 +1560,23 @@ func isConnectionClosedError(err error) bool {
return false return false
} }
} }
func getFirstSubscribeEventOrError(conn *grpc.ClientConn, req *pbsubscribe.SubscribeRequest) (*pbsubscribe.Event, error) {
streamClient := pbsubscribe.NewStateChangeSubscriptionClient(conn)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
handle, err := streamClient.Subscribe(ctx, req)
if err != nil {
return nil, err
}
event, err := handle.Recv()
if err == io.EOF {
return nil, nil
}
if err != nil {
return nil, err
}
return event, nil
}

View File

@ -26,21 +26,8 @@ func (s subscribeBackend) ResolveTokenAndDefaultMeta(
var _ subscribe.Backend = (*subscribeBackend)(nil) var _ subscribe.Backend = (*subscribeBackend)(nil)
// Forward requests to a remote datacenter by calling f if the target dc does not func (s subscribeBackend) Forward(info structs.RPCInfo, f func(*grpc.ClientConn) error) (handled bool, err error) {
// match the config. Does nothing but return handled=false if dc is not specified, return s.srv.ForwardGRPC(s.connPool, info, f)
// or if it matches the Datacenter in config.
//
// TODO: extract this so that it can be used with other grpc services.
// TODO: rename to ForwardToDC
func (s subscribeBackend) Forward(dc string, f func(*grpc.ClientConn) error) (handled bool, err error) {
if dc == "" || dc == s.srv.config.Datacenter {
return false, nil
}
conn, err := s.connPool.ClientConn(dc)
if err != nil {
return false, err
}
return true, f(conn)
} }
func (s subscribeBackend) Subscribe(req *stream.SubscribeRequest) (*stream.Subscription, error) { func (s subscribeBackend) Subscribe(req *stream.SubscribeRequest) (*stream.Subscription, error) {

View File

@ -362,17 +362,19 @@ func TestSubscribeBackend_IntegrationWithServer_DeliversAllMessages(t *testing.T
} }
func newClientWithGRPCResolver(t *testing.T, ops ...func(*Config)) (*Client, *resolver.ServerResolverBuilder) { func newClientWithGRPCResolver(t *testing.T, ops ...func(*Config)) (*Client, *resolver.ServerResolverBuilder) {
builder := resolver.NewServerResolverBuilder(newTestResolverConfig(t, "client"))
resolver.Register(builder)
t.Cleanup(func() {
resolver.Deregister(builder.Authority())
})
_, config := testClientConfig(t) _, config := testClientConfig(t)
for _, op := range ops { for _, op := range ops {
op(config) op(config)
} }
builder := resolver.NewServerResolverBuilder(newTestResolverConfig(t,
"client."+config.Datacenter+"."+string(config.NodeID)))
resolver.Register(builder)
t.Cleanup(func() {
resolver.Deregister(builder.Authority())
})
deps := newDefaultDeps(t, config) deps := newDefaultDeps(t, config)
deps.Router = router.NewRouter( deps.Router = router.NewRouter(
deps.Logger, deps.Logger,

View File

@ -37,13 +37,13 @@ var _ pbsubscribe.StateChangeSubscriptionServer = (*Server)(nil)
type Backend interface { type Backend interface {
ResolveTokenAndDefaultMeta(token string, entMeta *structs.EnterpriseMeta, authzContext *acl.AuthorizerContext) (acl.Authorizer, error) ResolveTokenAndDefaultMeta(token string, entMeta *structs.EnterpriseMeta, authzContext *acl.AuthorizerContext) (acl.Authorizer, error)
Forward(dc string, f func(*grpc.ClientConn) error) (handled bool, err error) Forward(info structs.RPCInfo, f func(*grpc.ClientConn) error) (handled bool, err error)
Subscribe(req *stream.SubscribeRequest) (*stream.Subscription, error) Subscribe(req *stream.SubscribeRequest) (*stream.Subscription, error)
} }
func (h *Server) Subscribe(req *pbsubscribe.SubscribeRequest, serverStream pbsubscribe.StateChangeSubscription_SubscribeServer) error { func (h *Server) Subscribe(req *pbsubscribe.SubscribeRequest, serverStream pbsubscribe.StateChangeSubscription_SubscribeServer) error {
logger := newLoggerForRequest(h.Logger, req) logger := newLoggerForRequest(h.Logger, req)
handled, err := h.Backend.Forward(req.Datacenter, forwardToDC(req, serverStream, logger)) handled, err := h.Backend.Forward(req, forwardToDC(req, serverStream, logger))
if handled || err != nil { if handled || err != nil {
return err return err
} }

View File

@ -292,7 +292,7 @@ func (b testBackend) ResolveTokenAndDefaultMeta(
return b.authorizer(token, entMeta), nil return b.authorizer(token, entMeta), nil
} }
func (b testBackend) Forward(_ string, fn func(*gogrpc.ClientConn) error) (handled bool, err error) { func (b testBackend) Forward(_ structs.RPCInfo, fn func(*gogrpc.ClientConn) error) (handled bool, err error) {
if b.forwardConn != nil { if b.forwardConn != nil {
return true, fn(b.forwardConn) return true, fn(b.forwardConn)
} }

View File

@ -146,7 +146,7 @@ func (b backend) ResolveTokenAndDefaultMeta(string, *structs.EnterpriseMeta, *ac
return acl.AllowAll(), nil return acl.AllowAll(), nil
} }
func (b backend) Forward(string, func(*grpc.ClientConn) error) (handled bool, err error) { func (b backend) Forward(structs.RPCInfo, func(*grpc.ClientConn) error) (handled bool, err error) {
return false, nil return false, nil
} }

View File

@ -112,27 +112,61 @@ func (q *QueryMeta) GetBackend() structs.QueryBackend {
} }
// WriteRequest only applies to writes, always false // WriteRequest only applies to writes, always false
//
// IsRead implements structs.RPCInfo
func (w WriteRequest) IsRead() bool { func (w WriteRequest) IsRead() bool {
return false return false
} }
// SetTokenSecret implements structs.RPCInfo
func (w WriteRequest) TokenSecret() string { func (w WriteRequest) TokenSecret() string {
return w.Token return w.Token
} }
// SetTokenSecret implements structs.RPCInfo
func (w *WriteRequest) SetTokenSecret(s string) { func (w *WriteRequest) SetTokenSecret(s string) {
w.Token = s w.Token = s
} }
// AllowStaleRead returns whether a stale read should be allowed // AllowStaleRead returns whether a stale read should be allowed
//
// AllowStaleRead implements structs.RPCInfo
func (w WriteRequest) AllowStaleRead() bool { func (w WriteRequest) AllowStaleRead() bool {
return false return false
} }
// HasTimedOut implements structs.RPCInfo
func (w WriteRequest) HasTimedOut(start time.Time, rpcHoldTimeout, _, _ time.Duration) bool { func (w WriteRequest) HasTimedOut(start time.Time, rpcHoldTimeout, _, _ time.Duration) bool {
return time.Since(start) > rpcHoldTimeout return time.Since(start) > rpcHoldTimeout
} }
// IsRead implements structs.RPCInfo
func (r *ReadRequest) IsRead() bool {
return true
}
// AllowStaleRead implements structs.RPCInfo
func (r *ReadRequest) AllowStaleRead() bool {
// TODO(partitions): plumb this?
return false
}
// TokenSecret implements structs.RPCInfo
func (r *ReadRequest) TokenSecret() string {
return r.Token
}
// SetTokenSecret implements structs.RPCInfo
func (r *ReadRequest) SetTokenSecret(token string) {
r.Token = token
}
// HasTimedOut implements structs.RPCInfo
func (r *ReadRequest) HasTimedOut(start time.Time, rpcHoldTimeout, maxQueryTime, defaultQueryTime time.Duration) bool {
return time.Since(start) > rpcHoldTimeout
}
// RequestDatacenter implements structs.RPCInfo
func (td TargetDatacenter) RequestDatacenter() string { func (td TargetDatacenter) RequestDatacenter() string {
return td.Datacenter return td.Datacenter
} }

View File

@ -0,0 +1,33 @@
package pbsubscribe
import "time"
// RequestDatacenter implements structs.RPCInfo
func (req *SubscribeRequest) RequestDatacenter() string {
return req.Datacenter
}
// IsRead implements structs.RPCInfo
func (req *SubscribeRequest) IsRead() bool {
return true
}
// AllowStaleRead implements structs.RPCInfo
func (req *SubscribeRequest) AllowStaleRead() bool {
return true
}
// TokenSecret implements structs.RPCInfo
func (req *SubscribeRequest) TokenSecret() string {
return req.Token
}
// SetTokenSecret implements structs.RPCInfo
func (req *SubscribeRequest) SetTokenSecret(token string) {
req.Token = token
}
// HasTimedOut implements structs.RPCInfo
func (req *SubscribeRequest) HasTimedOut(start time.Time, rpcHoldTimeout, maxQueryTime, defaultQueryTime time.Duration) bool {
return time.Since(start) > rpcHoldTimeout
}