diff --git a/.changelog/10838.txt b/.changelog/10838.txt new file mode 100644 index 000000000..c62a86124 --- /dev/null +++ b/.changelog/10838.txt @@ -0,0 +1,3 @@ +```release-note:bug +grpc: ensure that streaming gRPC requests work over mesh gateway based wan federation +``` diff --git a/agent/agent_test.go b/agent/agent_test.go index 926c23323..6638fa35c 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -371,6 +371,9 @@ func (f fakeGRPCConnPool) ClientConnLeader() (*grpc.ClientConn, error) { return nil, nil } +func (f fakeGRPCConnPool) SetGatewayResolver(_ func(string) string) { +} + func TestAgent_ReconnectConfigWanDisabled(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") @@ -4524,6 +4527,9 @@ LOOP: } // This is a mirror of a similar test in agent/consul/server_test.go +// +// TODO(rb): implement something similar to this as a full containerized test suite with proper +// isolation so requests can't "cheat" and bypass the mesh gateways func TestAgent_JoinWAN_viaMeshGateway(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") @@ -4771,6 +4777,9 @@ func TestAgent_JoinWAN_viaMeshGateway(t *testing.T) { }) // Ensure we can do some trivial RPC in all directions. + // + // NOTE: we explicitly make streaming and non-streaming assertions here to + // verify both rpc and grpc codepaths. agents := map[string]*TestAgent{"dc1": a1, "dc2": a2, "dc3": a3} names := map[string]string{"dc1": "bob", "dc2": "betty", "dc3": "bonnie"} for _, srcDC := range []string{"dc1", "dc2", "dc3"} { @@ -4780,20 +4789,39 @@ func TestAgent_JoinWAN_viaMeshGateway(t *testing.T) { continue } t.Run(srcDC+" to "+dstDC, func(t *testing.T) { - req, err := http.NewRequest("GET", "/v1/catalog/nodes?dc="+dstDC, nil) - require.NoError(t, err) + t.Run("normal-rpc", func(t *testing.T) { + req, err := http.NewRequest("GET", "/v1/catalog/nodes?dc="+dstDC, nil) + require.NoError(t, err) - resp := httptest.NewRecorder() - obj, err := a.srv.CatalogNodes(resp, req) - require.NoError(t, err) - require.NotNil(t, obj) + resp := httptest.NewRecorder() + obj, err := a.srv.CatalogNodes(resp, req) + require.NoError(t, err) + require.NotNil(t, obj) - nodes, ok := obj.(structs.Nodes) - require.True(t, ok) - require.Len(t, nodes, 1) - node := nodes[0] - require.Equal(t, dstDC, node.Datacenter) - require.Equal(t, names[dstDC], node.Node) + nodes, ok := obj.(structs.Nodes) + require.True(t, ok) + require.Len(t, nodes, 1) + node := nodes[0] + require.Equal(t, dstDC, node.Datacenter) + require.Equal(t, names[dstDC], node.Node) + }) + t.Run("streaming-grpc", func(t *testing.T) { + req, err := http.NewRequest("GET", "/v1/health/service/consul?cached&dc="+dstDC, nil) + require.NoError(t, err) + + resp := httptest.NewRecorder() + obj, err := a.srv.HealthServiceNodes(resp, req) + require.NoError(t, err) + require.NotNil(t, obj) + + csns, ok := obj.(structs.CheckServiceNodes) + require.True(t, ok) + require.Len(t, csns, 1) + + csn := csns[0] + require.Equal(t, dstDC, csn.Node.Datacenter) + require.Equal(t, names[dstDC], csn.Node.Node) + }) }) } } diff --git a/agent/consul/client_test.go b/agent/consul/client_test.go index 6e50735cd..0e1b4f808 100644 --- a/agent/consul/client_test.go +++ b/agent/consul/client_test.go @@ -5,10 +5,17 @@ import ( "fmt" "net" "os" + "strings" "sync" "testing" "time" + "github.com/hashicorp/go-hclog" + msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc" + "github.com/hashicorp/serf/serf" + "github.com/stretchr/testify/require" + "golang.org/x/time/rate" + "github.com/hashicorp/consul/agent/grpc" "github.com/hashicorp/consul/agent/grpc/resolver" "github.com/hashicorp/consul/agent/pool" @@ -20,11 +27,6 @@ import ( "github.com/hashicorp/consul/sdk/testutil/retry" "github.com/hashicorp/consul/testrpc" "github.com/hashicorp/consul/tlsutil" - "github.com/hashicorp/go-hclog" - msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc" - "github.com/hashicorp/serf/serf" - "github.com/stretchr/testify/require" - "golang.org/x/time/rate" ) func testClientConfig(t *testing.T) (string, *Config) { @@ -490,6 +492,13 @@ func newClient(t *testing.T, config *Config) *Client { return client } +func newTestResolverConfig(t *testing.T, suffix string) resolver.Config { + n := t.Name() + s := strings.Replace(n, "/", "", -1) + s = strings.Replace(s, "_", "", -1) + return resolver.Config{Authority: strings.ToLower(s) + "-" + suffix} +} + func newDefaultDeps(t *testing.T, c *Config) Deps { t.Helper() @@ -502,7 +511,7 @@ func newDefaultDeps(t *testing.T, c *Config) Deps { tls, err := tlsutil.NewConfigurator(c.TLSConfig, logger) require.NoError(t, err, "failed to create tls configuration") - builder := resolver.NewServerResolverBuilder(resolver.Config{Authority: c.NodeName}) + builder := resolver.NewServerResolverBuilder(newTestResolverConfig(t, c.NodeName+"-"+c.Datacenter)) r := router.NewRouter(logger, c.Datacenter, fmt.Sprintf("%s.%s", c.NodeName, c.Datacenter), builder) resolver.Register(builder) @@ -522,7 +531,13 @@ func newDefaultDeps(t *testing.T, c *Config) Deps { Tokens: new(token.Store), Router: r, ConnPool: connPool, - GRPCConnPool: grpc.NewClientConnPool(builder, grpc.TLSWrapper(tls.OutgoingRPCWrapper()), tls.UseTLS), + GRPCConnPool: grpc.NewClientConnPool(grpc.ClientConnPoolConfig{ + Servers: builder, + TLSWrapper: grpc.TLSWrapper(tls.OutgoingRPCWrapper()), + UseTLSForDC: tls.UseTLS, + DialingFromServer: true, + DialingFromDatacenter: c.Datacenter, + }), LeaderForwarder: builder, EnterpriseDeps: newDefaultDepsEnterprise(t, logger, c), } diff --git a/agent/consul/options.go b/agent/consul/options.go index efcf32ab2..3440b0245 100644 --- a/agent/consul/options.go +++ b/agent/consul/options.go @@ -24,8 +24,10 @@ type Deps struct { type GRPCClientConner interface { ClientConn(datacenter string) (*grpc.ClientConn, error) ClientConnLeader() (*grpc.ClientConn, error) + SetGatewayResolver(func(string) string) } type LeaderForwarder interface { - UpdateLeaderAddr(leaderAddr string) + // UpdateLeaderAddr updates the leader address in the local DC's resolver. + UpdateLeaderAddr(datacenter, addr string) } diff --git a/agent/consul/rpc.go b/agent/consul/rpc.go index 4c9180eef..2189f4eb7 100644 --- a/agent/consul/rpc.go +++ b/agent/consul/rpc.go @@ -293,7 +293,7 @@ func (s *Server) handleNativeTLS(conn net.Conn) { s.handleSnapshotConn(tlsConn) case pool.ALPN_RPCGRPC: - s.grpcHandler.Handle(conn) + s.grpcHandler.Handle(tlsConn) case pool.ALPN_WANGossipPacket: if err := s.handleALPN_WANGossipPacketStream(tlsConn); err != nil && err != io.EOF { @@ -373,7 +373,7 @@ func (s *Server) handleMultiplexV2(conn net.Conn) { } sub = peeked switch first { - case pool.RPCGossip: + case byte(pool.RPCGossip): buf := make([]byte, 1) sub.Read(buf) go func() { diff --git a/agent/consul/rpc_test.go b/agent/consul/rpc_test.go index 99d174b1d..b771697a1 100644 --- a/agent/consul/rpc_test.go +++ b/agent/consul/rpc_test.go @@ -460,7 +460,7 @@ func TestRPC_TLSHandshakeTimeout(t *testing.T) { // Write TLS byte to avoid being closed by either the (outer) first byte // timeout or the fact that server requires TLS - _, err = conn.Write([]byte{pool.RPCTLS}) + _, err = conn.Write([]byte{byte(pool.RPCTLS)}) require.NoError(t, err) // Wait for more than the timeout before we start a TLS handshake. This is diff --git a/agent/consul/server.go b/agent/consul/server.go index 90b8a02fb..b3d01207d 100644 --- a/agent/consul/server.go +++ b/agent/consul/server.go @@ -173,6 +173,9 @@ type Server struct { // Connection pool to other consul servers connPool *pool.ConnPool + // Connection pool to other consul servers using gRPC + grpcConnPool GRPCClientConner + // eventChLAN is used to receive events from the // serf cluster in the datacenter eventChLAN chan serf.Event @@ -348,6 +351,7 @@ func NewServer(config *Config, flat Deps) (*Server, error) { config: config, tokens: flat.Tokens, connPool: flat.ConnPool, + grpcConnPool: flat.GRPCConnPool, eventChLAN: make(chan serf.Event, serfEventChSize), eventChWAN: make(chan serf.Event, serfEventChSize), logger: serverLogger, @@ -377,6 +381,7 @@ func NewServer(config *Config, flat Deps) (*Server, error) { s.config.PrimaryDatacenter, ) s.connPool.GatewayResolver = s.gatewayLocator.PickGateway + s.grpcConnPool.SetGatewayResolver(s.gatewayLocator.PickGateway) } // Initialize enterprise specific server functionality @@ -1461,7 +1466,7 @@ func (s *Server) trackLeaderChanges() { continue } - s.grpcLeaderForwarder.UpdateLeaderAddr(string(leaderObs.Leader)) + s.grpcLeaderForwarder.UpdateLeaderAddr(s.config.Datacenter, string(leaderObs.Leader)) case <-s.shutdownCh: s.raft.DeregisterObserver(observer) return diff --git a/agent/consul/subscribe_backend_test.go b/agent/consul/subscribe_backend_test.go index 92dfafe2a..e11d24b35 100644 --- a/agent/consul/subscribe_backend_test.go +++ b/agent/consul/subscribe_backend_test.go @@ -25,6 +25,8 @@ import ( func TestSubscribeBackend_IntegrationWithServer_TLSEnabled(t *testing.T) { t.Parallel() + // TODO(rb): add tests for the wanfed/alpn variations + _, conf1 := testServerConfig(t) conf1.TLSConfig.VerifyIncoming = true conf1.TLSConfig.VerifyOutgoing = true @@ -60,7 +62,13 @@ func TestSubscribeBackend_IntegrationWithServer_TLSEnabled(t *testing.T) { // Start a Subscribe call to our streaming endpoint from the client. { - pool := grpc.NewClientConnPool(builder, grpc.TLSWrapper(client.tlsConfigurator.OutgoingRPCWrapper()), client.tlsConfigurator.UseTLS) + pool := grpc.NewClientConnPool(grpc.ClientConnPoolConfig{ + Servers: builder, + TLSWrapper: grpc.TLSWrapper(client.tlsConfigurator.OutgoingRPCWrapper()), + UseTLSForDC: client.tlsConfigurator.UseTLS, + DialingFromServer: true, + DialingFromDatacenter: "dc1", + }) conn, err := pool.ClientConn("dc1") require.NoError(t, err) @@ -91,8 +99,13 @@ func TestSubscribeBackend_IntegrationWithServer_TLSEnabled(t *testing.T) { // Start a Subscribe call to our streaming endpoint from the server's loopback client. { - - pool := grpc.NewClientConnPool(builder, grpc.TLSWrapper(client.tlsConfigurator.OutgoingRPCWrapper()), client.tlsConfigurator.UseTLS) + pool := grpc.NewClientConnPool(grpc.ClientConnPoolConfig{ + Servers: builder, + TLSWrapper: grpc.TLSWrapper(client.tlsConfigurator.OutgoingRPCWrapper()), + UseTLSForDC: client.tlsConfigurator.UseTLS, + DialingFromServer: true, + DialingFromDatacenter: "dc1", + }) conn, err := pool.ClientConn("dc1") require.NoError(t, err) @@ -166,7 +179,13 @@ func TestSubscribeBackend_IntegrationWithServer_TLSReload(t *testing.T) { // Subscribe calls should fail initially joinLAN(t, client, server) - pool := grpc.NewClientConnPool(builder, grpc.TLSWrapper(client.tlsConfigurator.OutgoingRPCWrapper()), client.tlsConfigurator.UseTLS) + pool := grpc.NewClientConnPool(grpc.ClientConnPoolConfig{ + Servers: builder, + TLSWrapper: grpc.TLSWrapper(client.tlsConfigurator.OutgoingRPCWrapper()), + UseTLSForDC: client.tlsConfigurator.UseTLS, + DialingFromServer: true, + DialingFromDatacenter: "dc1", + }) conn, err := pool.ClientConn("dc1") require.NoError(t, err) @@ -294,7 +313,13 @@ func TestSubscribeBackend_IntegrationWithServer_DeliversAllMessages(t *testing.T } }() - pool := grpc.NewClientConnPool(builder, grpc.TLSWrapper(client.tlsConfigurator.OutgoingRPCWrapper()), client.tlsConfigurator.UseTLS) + pool := grpc.NewClientConnPool(grpc.ClientConnPoolConfig{ + Servers: builder, + TLSWrapper: grpc.TLSWrapper(client.tlsConfigurator.OutgoingRPCWrapper()), + UseTLSForDC: client.tlsConfigurator.UseTLS, + DialingFromServer: true, + DialingFromDatacenter: "dc1", + }) conn, err := pool.ClientConn("dc1") require.NoError(t, err) @@ -337,7 +362,7 @@ func TestSubscribeBackend_IntegrationWithServer_DeliversAllMessages(t *testing.T } func newClientWithGRPCResolver(t *testing.T, ops ...func(*Config)) (*Client, *resolver.ServerResolverBuilder) { - builder := resolver.NewServerResolverBuilder(resolver.Config{Authority: t.Name()}) + builder := resolver.NewServerResolverBuilder(newTestResolverConfig(t, "client")) resolver.Register(builder) t.Cleanup(func() { resolver.Deregister(builder.Authority()) diff --git a/agent/consul/wanfed/wanfed.go b/agent/consul/wanfed/wanfed.go index 0afc0d6d6..c528500c0 100644 --- a/agent/consul/wanfed/wanfed.go +++ b/agent/consul/wanfed/wanfed.go @@ -1,6 +1,7 @@ package wanfed import ( + "context" "encoding/binary" "errors" "fmt" @@ -11,7 +12,6 @@ import ( "github.com/hashicorp/memberlist" "github.com/hashicorp/consul/agent/pool" - "github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/tlsutil" ) @@ -97,13 +97,8 @@ func (t *Transport) WriteToAddress(b []byte, addr memberlist.Address) (time.Time } if dc != t.datacenter { - gwAddr := t.gwResolver(dc) - if gwAddr == "" { - return time.Time{}, structs.ErrDCNotAvailable - } - dialFunc := func() (net.Conn, error) { - return t.dial(dc, node, pool.ALPN_WANGossipPacket, gwAddr) + return t.dial(dc, node, pool.ALPN_WANGossipPacket) } conn, err := t.pool.AcquireOrDial(addr.Name, dialFunc) if err != nil { @@ -136,42 +131,24 @@ func (t *Transport) DialAddressTimeout(addr memberlist.Address, timeout time.Dur } if dc != t.datacenter { - gwAddr := t.gwResolver(dc) - if gwAddr == "" { - return nil, structs.ErrDCNotAvailable - } - - return t.dial(dc, node, pool.ALPN_WANGossipStream, gwAddr) + return t.dial(dc, node, pool.ALPN_WANGossipStream) } return t.IngestionAwareTransport.DialAddressTimeout(addr, timeout) } -// NOTE: There is a close mirror of this method in agent/pool/pool.go:DialTimeoutWithRPCType -func (t *Transport) dial(dc, nodeName, nextProto, addr string) (net.Conn, error) { - wrapper := t.tlsConfigurator.OutgoingALPNRPCWrapper() - if wrapper == nil { - return nil, fmt.Errorf("wanfed: cannot dial via a mesh gateway when outgoing TLS is disabled") - } - - dialer := &net.Dialer{Timeout: 10 * time.Second} - - rawConn, err := dialer.Dial("tcp", addr) - if err != nil { - return nil, err - } - - if tcp, ok := rawConn.(*net.TCPConn); ok { - _ = tcp.SetKeepAlive(true) - _ = tcp.SetNoDelay(true) - } - - tlsConn, err := wrapper(dc, nodeName, nextProto, rawConn) - if err != nil { - return nil, err - } - - return tlsConn, nil +func (t *Transport) dial(dc, nodeName, nextProto string) (net.Conn, error) { + conn, _, err := pool.DialRPCViaMeshGateway( + context.Background(), + dc, + nodeName, + nil, // TODO(rb): thread source address through here? + t.tlsConfigurator.OutgoingALPNRPCWrapper(), + nextProto, + true, + t.gwResolver, + ) + return conn, err } // SplitNodeName splits a node name as it would be represented in diff --git a/agent/grpc/client.go b/agent/grpc/client.go index d3709744a..9afd6becd 100644 --- a/agent/grpc/client.go +++ b/agent/grpc/client.go @@ -12,38 +12,93 @@ import ( "github.com/hashicorp/consul/agent/metadata" "github.com/hashicorp/consul/agent/pool" + "github.com/hashicorp/consul/tlsutil" ) // ClientConnPool creates and stores a connection for each datacenter. type ClientConnPool struct { - dialer dialer - servers ServerLocator - conns map[string]*grpc.ClientConn - connsLock sync.Mutex + dialer dialer + servers ServerLocator + gwResolverDep gatewayResolverDep + conns map[string]*grpc.ClientConn + connsLock sync.Mutex } type ServerLocator interface { - // ServerForAddr is used to look up server metadata from an address. - ServerForAddr(addr string) (*metadata.Server, error) + // ServerForGlobalAddr returns server metadata for a server with the specified globally unique address. + ServerForGlobalAddr(globalAddr string) (*metadata.Server, error) + // Authority returns the target authority to use to dial the server. This is primarily // needed for testing multiple agents in parallel, because gRPC requires the // resolver to be registered globally. Authority() string } +// gatewayResolverDep is just a holder for a function pointer that can be +// updated lazily after the structs are instantiated (but before first use) +// and all structs with a reference to this struct will see the same update. +type gatewayResolverDep struct { + // GatewayResolver is a function that returns a suitable random mesh + // gateway address for dialing servers in a given DC. This is only + // needed if wan federation via mesh gateways is enabled. + GatewayResolver func(string) string +} + // TLSWrapper wraps a non-TLS connection and returns a connection with TLS // enabled. type TLSWrapper func(dc string, conn net.Conn) (net.Conn, error) +// ALPNWrapper is a function that is used to wrap a non-TLS connection and +// returns an appropriate TLS connection or error. This taks a datacenter and +// node name as argument to configure the desired SNI value and the desired +// next proto for configuring ALPN. +type ALPNWrapper func(dc, nodeName, alpnProto string, conn net.Conn) (net.Conn, error) + type dialer func(context.Context, string) (net.Conn, error) -// NewClientConnPool create new GRPC client pool to connect to servers using GRPC over RPC -func NewClientConnPool(servers ServerLocator, tls TLSWrapper, useTLSForDC func(dc string) bool) *ClientConnPool { - return &ClientConnPool{ - dialer: newDialer(servers, tls, useTLSForDC), - servers: servers, +type ClientConnPoolConfig struct { + // Servers is a reference for how to figure out how to dial any server. + Servers ServerLocator + + // SrcAddr is the source address for outgoing connections. + SrcAddr *net.TCPAddr + + // TLSWrapper is the specifics of wrapping a socket when doing an TYPE_BYTE+TLS + // wrapped RPC request. + TLSWrapper TLSWrapper + + // ALPNWrapper is the specifics of wrapping a socket when doing an ALPN+TLS + // wrapped RPC request (typically only for wan federation via mesh + // gateways). + ALPNWrapper ALPNWrapper + + // UseTLSForDC is a function to determine if dialing a given datacenter + // should use TLS. + UseTLSForDC func(dc string) bool + + // DialingFromServer should be set to true if this connection pool is owned + // by a consul server instance. + DialingFromServer bool + + // DialingFromDatacenter is the datacenter of the consul agent using this + // pool. + DialingFromDatacenter string +} + +// NewClientConnPool create new GRPC client pool to connect to servers using +// GRPC over RPC. +func NewClientConnPool(cfg ClientConnPoolConfig) *ClientConnPool { + c := &ClientConnPool{ + servers: cfg.Servers, conns: make(map[string]*grpc.ClientConn), } + c.dialer = newDialer(cfg, &c.gwResolverDep) + return c +} + +// SetGatewayResolver is only to be called during setup before the pool is used. +func (c *ClientConnPool) SetGatewayResolver(gatewayResolver func(string) string) { + c.gwResolverDep.GatewayResolver = gatewayResolver } // ClientConn returns a grpc.ClientConn for the datacenter. If there are no @@ -102,22 +157,39 @@ func (c *ClientConnPool) dial(datacenter string, serverType string) (*grpc.Clien // newDialer returns a gRPC dialer function that conditionally wraps the connection // with TLS based on the Server.useTLS value. -func newDialer(servers ServerLocator, wrapper TLSWrapper, useTLSForDC func(dc string) bool) func(context.Context, string) (net.Conn, error) { - return func(ctx context.Context, addr string) (net.Conn, error) { - d := net.Dialer{} - conn, err := d.DialContext(ctx, "tcp", addr) +func newDialer(cfg ClientConnPoolConfig, gwResolverDep *gatewayResolverDep) func(context.Context, string) (net.Conn, error) { + return func(ctx context.Context, globalAddr string) (net.Conn, error) { + server, err := cfg.Servers.ServerForGlobalAddr(globalAddr) if err != nil { return nil, err } - server, err := servers.ServerForAddr(addr) + if cfg.DialingFromServer && + gwResolverDep.GatewayResolver != nil && + cfg.ALPNWrapper != nil && + server.Datacenter != cfg.DialingFromDatacenter { + // NOTE: TLS is required on this branch. + conn, _, err := pool.DialRPCViaMeshGateway( + ctx, + server.Datacenter, + server.ShortName, + cfg.SrcAddr, + tlsutil.ALPNWrapper(cfg.ALPNWrapper), + pool.ALPN_RPCGRPC, + cfg.DialingFromServer, + gwResolverDep.GatewayResolver, + ) + return conn, err + } + + d := net.Dialer{LocalAddr: cfg.SrcAddr, Timeout: pool.DefaultDialTimeout} + conn, err := d.DialContext(ctx, "tcp", server.Addr.String()) if err != nil { - conn.Close() return nil, err } - if server.UseTLS && useTLSForDC(server.Datacenter) { - if wrapper == nil { + if server.UseTLS && cfg.UseTLSForDC(server.Datacenter) { + if cfg.TLSWrapper == nil { conn.Close() return nil, fmt.Errorf("TLS enabled but got nil TLS wrapper") } @@ -129,7 +201,7 @@ func newDialer(servers ServerLocator, wrapper TLSWrapper, useTLSForDC func(dc st } // Wrap the connection in a TLS client - tlsConn, err := wrapper(server.Datacenter, conn) + tlsConn, err := cfg.TLSWrapper(server.Datacenter, conn) if err != nil { conn.Close() return nil, err @@ -137,7 +209,7 @@ func newDialer(servers ServerLocator, wrapper TLSWrapper, useTLSForDC func(dc st conn = tlsConn } - _, err = conn.Write([]byte{pool.RPCGRPC}) + _, err = conn.Write([]byte{byte(pool.RPCGRPC)}) if err != nil { conn.Close() return nil, err diff --git a/agent/grpc/client_test.go b/agent/grpc/client_test.go index d4789b4f0..f4d0138f3 100644 --- a/agent/grpc/client_test.go +++ b/agent/grpc/client_test.go @@ -4,17 +4,22 @@ import ( "context" "fmt" "net" + "strconv" "strings" "sync/atomic" "testing" "time" + "github.com/google/tcpproxy" "github.com/hashicorp/go-hclog" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/hashicorp/consul/agent/grpc/internal/testservice" "github.com/hashicorp/consul/agent/grpc/resolver" "github.com/hashicorp/consul/agent/metadata" + "github.com/hashicorp/consul/ipaddr" + "github.com/hashicorp/consul/sdk/freeport" "github.com/hashicorp/consul/tlsutil" ) @@ -24,11 +29,14 @@ func useTLSForDcAlwaysTrue(_ string) bool { } func TestNewDialer_WithTLSWrapper(t *testing.T) { - lis, err := net.Listen("tcp", "127.0.0.1:0") + ports := freeport.MustTake(1) + defer freeport.Return(ports) + + lis, err := net.Listen("tcp", net.JoinHostPort("127.0.0.1", strconv.Itoa(ports[0]))) require.NoError(t, err) t.Cleanup(logError(t, lis.Close)) - builder := resolver.NewServerResolverBuilder(resolver.Config{}) + builder := resolver.NewServerResolverBuilder(newConfig(t)) builder.AddServer(&metadata.Server{ Name: "server-1", ID: "ID1", @@ -42,19 +50,107 @@ func TestNewDialer_WithTLSWrapper(t *testing.T) { called = true return conn, nil } - dial := newDialer(builder, wrapper, useTLSForDcAlwaysTrue) + dial := newDialer( + ClientConnPoolConfig{ + Servers: builder, + TLSWrapper: wrapper, + UseTLSForDC: useTLSForDcAlwaysTrue, + DialingFromServer: true, + DialingFromDatacenter: "dc1", + }, + &gatewayResolverDep{}, + ) ctx := context.Background() - conn, err := dial(ctx, lis.Addr().String()) + conn, err := dial(ctx, resolver.DCPrefix("dc1", lis.Addr().String())) require.NoError(t, err) require.NoError(t, conn.Close()) require.True(t, called, "expected TLSWrapper to be called") } +func TestNewDialer_WithALPNWrapper(t *testing.T) { + ports := freeport.MustTake(3) + defer freeport.Return(ports) + + var ( + s1addr = ipaddr.FormatAddressPort("127.0.0.1", ports[0]) + s2addr = ipaddr.FormatAddressPort("127.0.0.1", ports[1]) + gwAddr = ipaddr.FormatAddressPort("127.0.0.1", ports[2]) + ) + + lis1, err := net.Listen("tcp", s1addr) + require.NoError(t, err) + t.Cleanup(logError(t, lis1.Close)) + + lis2, err := net.Listen("tcp", s2addr) + require.NoError(t, err) + t.Cleanup(logError(t, lis2.Close)) + + // Send all of the traffic to dc2's server + var p tcpproxy.Proxy + p.AddRoute(gwAddr, tcpproxy.To(s2addr)) + p.AddStopACMESearch(gwAddr) + require.NoError(t, p.Start()) + defer func() { + p.Close() + p.Wait() + }() + + builder := resolver.NewServerResolverBuilder(newConfig(t)) + builder.AddServer(&metadata.Server{ + Name: "server-1", + ID: "ID1", + Datacenter: "dc1", + Addr: lis1.Addr(), + UseTLS: true, + }) + builder.AddServer(&metadata.Server{ + Name: "server-2", + ID: "ID2", + Datacenter: "dc2", + Addr: lis2.Addr(), + UseTLS: true, + }) + + var calledTLS bool + wrapperTLS := func(_ string, conn net.Conn) (net.Conn, error) { + calledTLS = true + return conn, nil + } + var calledALPN bool + wrapperALPN := func(_, _, _ string, conn net.Conn) (net.Conn, error) { + calledALPN = true + return conn, nil + } + gwResolverDep := &gatewayResolverDep{ + GatewayResolver: func(addr string) string { + return gwAddr + }, + } + dial := newDialer( + ClientConnPoolConfig{ + Servers: builder, + TLSWrapper: wrapperTLS, + ALPNWrapper: wrapperALPN, + UseTLSForDC: useTLSForDcAlwaysTrue, + DialingFromServer: true, + DialingFromDatacenter: "dc1", + }, + gwResolverDep, + ) + + ctx := context.Background() + conn, err := dial(ctx, resolver.DCPrefix("dc2", lis2.Addr().String())) + require.NoError(t, err) + require.NoError(t, conn.Close()) + + assert.False(t, calledTLS, "expected TLSWrapper not to be called") + assert.True(t, calledALPN, "expected ALPNWrapper to be called") +} + func TestNewDialer_IntegrationWithTLSEnabledHandler(t *testing.T) { res := resolver.NewServerResolverBuilder(newConfig(t)) registerWithGRPC(t, res) - srv := newTestServer(t, "server-1", "dc1") tlsConf, err := tlsutil.NewConfigurator(tlsutil.Config{ VerifyIncoming: true, VerifyOutgoing: true, @@ -63,12 +159,20 @@ func TestNewDialer_IntegrationWithTLSEnabledHandler(t *testing.T) { KeyFile: "../../test/hostname/Alice.key", }, hclog.New(nil)) require.NoError(t, err) - srv.rpc.tlsConf = tlsConf - res.AddServer(srv.Metadata()) + srv := newTestServer(t, "server-1", "dc1", tlsConf) + + md := srv.Metadata() + res.AddServer(md) t.Cleanup(srv.shutdown) - pool := NewClientConnPool(res, TLSWrapper(tlsConf.OutgoingRPCWrapper()), tlsConf.UseTLS) + pool := NewClientConnPool(ClientConnPoolConfig{ + Servers: res, + TLSWrapper: TLSWrapper(tlsConf.OutgoingRPCWrapper()), + UseTLSForDC: tlsConf.UseTLS, + DialingFromServer: true, + DialingFromDatacenter: "dc1", + }) conn, err := pool.ClientConn("dc1") require.NoError(t, err) @@ -81,17 +185,98 @@ func TestNewDialer_IntegrationWithTLSEnabledHandler(t *testing.T) { require.NoError(t, err) require.Equal(t, "server-1", resp.ServerName) require.True(t, atomic.LoadInt32(&srv.rpc.tlsConnEstablished) > 0) + require.True(t, atomic.LoadInt32(&srv.rpc.alpnConnEstablished) == 0) +} + +func TestNewDialer_IntegrationWithTLSEnabledHandler_viaMeshGateway(t *testing.T) { + ports := freeport.MustTake(1) + defer freeport.Return(ports) + + gwAddr := ipaddr.FormatAddressPort("127.0.0.1", ports[0]) + + res := resolver.NewServerResolverBuilder(newConfig(t)) + registerWithGRPC(t, res) + + tlsConf, err := tlsutil.NewConfigurator(tlsutil.Config{ + VerifyIncoming: true, + VerifyOutgoing: true, + VerifyServerHostname: true, + CAFile: "../../test/hostname/CertAuth.crt", + CertFile: "../../test/hostname/Bob.crt", + KeyFile: "../../test/hostname/Bob.key", + Domain: "consul", + NodeName: "bob", + }, hclog.New(nil)) + require.NoError(t, err) + + srv := newTestServer(t, "bob", "dc1", tlsConf) + + // Send all of the traffic to dc1's server + var p tcpproxy.Proxy + p.AddRoute(gwAddr, tcpproxy.To(srv.addr.String())) + p.AddStopACMESearch(gwAddr) + require.NoError(t, p.Start()) + defer func() { + p.Close() + p.Wait() + }() + + md := srv.Metadata() + res.AddServer(md) + t.Cleanup(srv.shutdown) + + clientTLSConf, err := tlsutil.NewConfigurator(tlsutil.Config{ + VerifyIncoming: true, + VerifyOutgoing: true, + VerifyServerHostname: true, + CAFile: "../../test/hostname/CertAuth.crt", + CertFile: "../../test/hostname/Betty.crt", + KeyFile: "../../test/hostname/Betty.key", + Domain: "consul", + NodeName: "betty", + }, hclog.New(nil)) + require.NoError(t, err) + + pool := NewClientConnPool(ClientConnPoolConfig{ + Servers: res, + TLSWrapper: TLSWrapper(clientTLSConf.OutgoingRPCWrapper()), + ALPNWrapper: ALPNWrapper(clientTLSConf.OutgoingALPNRPCWrapper()), + UseTLSForDC: tlsConf.UseTLS, + DialingFromServer: true, + DialingFromDatacenter: "dc2", + }) + pool.SetGatewayResolver(func(addr string) string { + return gwAddr + }) + + conn, err := pool.ClientConn("dc1") + require.NoError(t, err) + client := testservice.NewSimpleClient(conn) + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + t.Cleanup(cancel) + + resp, err := client.Something(ctx, &testservice.Req{}) + require.NoError(t, err) + require.Equal(t, "bob", resp.ServerName) + require.True(t, atomic.LoadInt32(&srv.rpc.tlsConnEstablished) == 0) + require.True(t, atomic.LoadInt32(&srv.rpc.alpnConnEstablished) > 0) } func TestClientConnPool_IntegrationWithGRPCResolver_Failover(t *testing.T) { count := 4 res := resolver.NewServerResolverBuilder(newConfig(t)) registerWithGRPC(t, res) - pool := NewClientConnPool(res, nil, useTLSForDcAlwaysTrue) + pool := NewClientConnPool(ClientConnPoolConfig{ + Servers: res, + UseTLSForDC: useTLSForDcAlwaysTrue, + DialingFromServer: true, + DialingFromDatacenter: "dc1", + }) for i := 0; i < count; i++ { name := fmt.Sprintf("server-%d", i) - srv := newTestServer(t, name, "dc1") + srv := newTestServer(t, name, "dc1", nil) res.AddServer(srv.Metadata()) t.Cleanup(srv.shutdown) } @@ -115,22 +300,27 @@ func TestClientConnPool_IntegrationWithGRPCResolver_Failover(t *testing.T) { func TestClientConnPool_ForwardToLeader_Failover(t *testing.T) { count := 3 - conf := newConfig(t) - res := resolver.NewServerResolverBuilder(conf) + res := resolver.NewServerResolverBuilder(newConfig(t)) registerWithGRPC(t, res) - pool := NewClientConnPool(res, nil, useTLSForDcAlwaysTrue) + pool := NewClientConnPool(ClientConnPoolConfig{ + Servers: res, + UseTLSForDC: useTLSForDcAlwaysTrue, + DialingFromServer: true, + DialingFromDatacenter: "dc1", + }) var servers []testServer for i := 0; i < count; i++ { name := fmt.Sprintf("server-%d", i) - srv := newTestServer(t, name, "dc1") + srv := newTestServer(t, name, "dc1", nil) res.AddServer(srv.Metadata()) servers = append(servers, srv) t.Cleanup(srv.shutdown) } // Set the leader address to the first server. - res.UpdateLeaderAddr(servers[0].addr.String()) + srv0 := servers[0].Metadata() + res.UpdateLeaderAddr(srv0.Datacenter, srv0.Addr.String()) conn, err := pool.ClientConnLeader() require.NoError(t, err) @@ -144,7 +334,8 @@ func TestClientConnPool_ForwardToLeader_Failover(t *testing.T) { require.Equal(t, first.ServerName, servers[0].name) // Update the leader address and make another request. - res.UpdateLeaderAddr(servers[1].addr.String()) + srv1 := servers[1].Metadata() + res.UpdateLeaderAddr(srv1.Datacenter, srv1.Addr.String()) resp, err := client.Something(ctx, &testservice.Req{}) require.NoError(t, err) @@ -162,11 +353,16 @@ func TestClientConnPool_IntegrationWithGRPCResolver_Rebalance(t *testing.T) { count := 5 res := resolver.NewServerResolverBuilder(newConfig(t)) registerWithGRPC(t, res) - pool := NewClientConnPool(res, nil, useTLSForDcAlwaysTrue) + pool := NewClientConnPool(ClientConnPoolConfig{ + Servers: res, + UseTLSForDC: useTLSForDcAlwaysTrue, + DialingFromServer: true, + DialingFromDatacenter: "dc1", + }) for i := 0; i < count; i++ { name := fmt.Sprintf("server-%d", i) - srv := newTestServer(t, name, "dc1") + srv := newTestServer(t, name, "dc1", nil) res.AddServer(srv.Metadata()) t.Cleanup(srv.shutdown) } @@ -211,11 +407,16 @@ func TestClientConnPool_IntegrationWithGRPCResolver_MultiDC(t *testing.T) { res := resolver.NewServerResolverBuilder(newConfig(t)) registerWithGRPC(t, res) - pool := NewClientConnPool(res, nil, useTLSForDcAlwaysTrue) + pool := NewClientConnPool(ClientConnPoolConfig{ + Servers: res, + UseTLSForDC: useTLSForDcAlwaysTrue, + DialingFromServer: true, + DialingFromDatacenter: "dc1", + }) for _, dc := range dcs { name := "server-0-" + dc - srv := newTestServer(t, name, dc) + srv := newTestServer(t, name, dc, nil) res.AddServer(srv.Metadata()) t.Cleanup(srv.shutdown) } diff --git a/agent/grpc/resolver/resolver.go b/agent/grpc/resolver/resolver.go index 570016845..f6c3d7fe9 100644 --- a/agent/grpc/resolver/resolver.go +++ b/agent/grpc/resolver/resolver.go @@ -67,17 +67,17 @@ func (s *ServerResolverBuilder) NewRebalancer(dc string) func() { } } -// ServerForAddr returns server metadata for a server with the specified address. -func (s *ServerResolverBuilder) ServerForAddr(addr string) (*metadata.Server, error) { +// ServerForGlobalAddr returns server metadata for a server with the specified globally unique address. +func (s *ServerResolverBuilder) ServerForGlobalAddr(globalAddr string) (*metadata.Server, error) { s.lock.RLock() defer s.lock.RUnlock() for _, server := range s.servers { - if server.Addr.String() == addr { + if DCPrefix(server.Datacenter, server.Addr.String()) == globalAddr { return server, nil } } - return nil, fmt.Errorf("failed to find Consul server for address %q", addr) + return nil, fmt.Errorf("failed to find Consul server for global address %q", globalAddr) } // Build returns a new serverResolver for the given ClientConn. The resolver @@ -161,6 +161,12 @@ func uniqueID(server *metadata.Server) string { return server.Datacenter + "-" + server.ID } +// DCPrefix prefixes the given string with a datacenter for use in +// disambiguation. +func DCPrefix(datacenter, suffix string) string { + return datacenter + "-" + suffix +} + // RemoveServer updates the resolvers' states with the given server removed. func (s *ServerResolverBuilder) RemoveServer(server *metadata.Server) { s.lock.Lock() @@ -186,7 +192,8 @@ func (s *ServerResolverBuilder) getDCAddrs(dc string) []resolver.Address { } addrs = append(addrs, resolver.Address{ - Addr: server.Addr.String(), + // NOTE: the address persisted here is only dialable using our custom dialer + Addr: DCPrefix(server.Datacenter, server.Addr.String()), Type: resolver.Backend, ServerName: server.Name, }) @@ -195,11 +202,11 @@ func (s *ServerResolverBuilder) getDCAddrs(dc string) []resolver.Address { } // UpdateLeaderAddr updates the leader address in the local DC's resolver. -func (s *ServerResolverBuilder) UpdateLeaderAddr(leaderAddr string) { +func (s *ServerResolverBuilder) UpdateLeaderAddr(datacenter, addr string) { s.lock.Lock() defer s.lock.Unlock() - s.leaderResolver.addr = leaderAddr + s.leaderResolver.globalAddr = DCPrefix(datacenter, addr) s.leaderResolver.updateClientConn() } @@ -262,7 +269,7 @@ func (r *serverResolver) Close() { func (*serverResolver) ResolveNow(resolver.ResolveNowOption) {} type leaderResolver struct { - addr string + globalAddr string clientConn resolver.ClientConn } @@ -271,12 +278,13 @@ func (l leaderResolver) ResolveNow(resolver.ResolveNowOption) {} func (l leaderResolver) Close() {} func (l leaderResolver) updateClientConn() { - if l.addr == "" || l.clientConn == nil { + if l.globalAddr == "" || l.clientConn == nil { return } addrs := []resolver.Address{ { - Addr: l.addr, + // NOTE: the address persisted here is only dialable using our custom dialer + Addr: l.globalAddr, Type: resolver.Backend, ServerName: "leader", }, diff --git a/agent/grpc/server_test.go b/agent/grpc/server_test.go index 442b617d5..9945f1e6c 100644 --- a/agent/grpc/server_test.go +++ b/agent/grpc/server_test.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "net" + "strconv" "sync/atomic" "testing" "time" @@ -17,6 +18,7 @@ import ( "github.com/hashicorp/consul/agent/grpc/internal/testservice" "github.com/hashicorp/consul/agent/metadata" "github.com/hashicorp/consul/agent/pool" + "github.com/hashicorp/consul/sdk/freeport" "github.com/hashicorp/consul/tlsutil" ) @@ -31,22 +33,29 @@ type testServer struct { func (s testServer) Metadata() *metadata.Server { return &metadata.Server{ ID: s.name, + Name: s.name + "." + s.dc, + ShortName: s.name, Datacenter: s.dc, Addr: s.addr, UseTLS: s.rpc.tlsConf != nil, } } -func newTestServer(t *testing.T, name string, dc string) testServer { +func newTestServer(t *testing.T, name string, dc string, tlsConf *tlsutil.Configurator) testServer { addr := &net.IPAddr{IP: net.ParseIP("127.0.0.1")} handler := NewHandler(addr, func(server *grpc.Server) { testservice.RegisterSimpleServer(server, &simple{name: name, dc: dc}) }) - lis, err := net.Listen("tcp", "127.0.0.1:0") + ports := freeport.MustTake(1) + t.Cleanup(func() { + freeport.Return(ports) + }) + + lis, err := net.Listen("tcp", net.JoinHostPort("127.0.0.1", strconv.Itoa(ports[0]))) require.NoError(t, err) - rpc := &fakeRPCListener{t: t, handler: handler} + rpc := &fakeRPCListener{t: t, handler: handler, tlsConf: tlsConf} g := errgroup.Group{} g.Go(func() error { @@ -107,11 +116,12 @@ func (s *simple) Something(_ context.Context, _ *testservice.Req) (*testservice. // For now, since this logic is in agent/consul, we can't easily use Server.listen // so we fake it. type fakeRPCListener struct { - t *testing.T - handler *Handler - shutdown bool - tlsConf *tlsutil.Configurator - tlsConnEstablished int32 + t *testing.T + handler *Handler + shutdown bool + tlsConf *tlsutil.Configurator + tlsConnEstablished int32 + alpnConnEstablished int32 } func (f *fakeRPCListener) listen(listener net.Listener) error { @@ -129,6 +139,26 @@ func (f *fakeRPCListener) listen(listener net.Listener) error { } func (f *fakeRPCListener) handleConn(conn net.Conn) { + if f.tlsConf != nil && f.tlsConf.MutualTLSCapable() { + // See if actually this is native TLS multiplexed onto the old + // "type-byte" system. + + peekedConn, nativeTLS, err := pool.PeekForTLS(conn) + if err != nil { + if err != io.EOF { + fmt.Printf("ERROR: failed to read first byte: %v\n", err) + } + conn.Close() + return + } + + if nativeTLS { + f.handleNativeTLSConn(peekedConn) + return + } + conn = peekedConn + } + buf := make([]byte, 1) if _, err := conn.Read(buf); err != nil { @@ -166,3 +196,32 @@ func (f *fakeRPCListener) handleConn(conn net.Conn) { conn.Close() } } + +func (f *fakeRPCListener) handleNativeTLSConn(conn net.Conn) { + tlscfg := f.tlsConf.IncomingALPNRPCConfig(pool.RPCNextProtos) + tlsConn := tls.Server(conn, tlscfg) + + // Force the handshake to conclude. + if err := tlsConn.Handshake(); err != nil { + fmt.Printf("ERROR: TLS handshake failed: %v", err) + conn.Close() + return + } + + conn.SetReadDeadline(time.Time{}) + + var ( + cs = tlsConn.ConnectionState() + nextProto = cs.NegotiatedProtocol + ) + + switch nextProto { + case pool.ALPN_RPCGRPC: + atomic.AddInt32(&f.alpnConnEstablished, 1) + f.handler.Handle(tlsConn) + + default: + fmt.Printf("ERROR: discarding RPC for unknown negotiated protocol %q\n", nextProto) + conn.Close() + } +} diff --git a/agent/pool/conn.go b/agent/pool/conn.go index 79731953b..cc19171c9 100644 --- a/agent/pool/conn.go +++ b/agent/pool/conn.go @@ -20,6 +20,8 @@ func (t RPCType) ALPNString() string { return ALPN_RPCGossip case RPCTLSInsecure: return "" // unsupported + case RPCGRPC: + return ALPN_RPCGRPC default: return "" // unsupported } @@ -28,19 +30,19 @@ func (t RPCType) ALPNString() string { const ( // keep numbers unique. RPCConsul RPCType = 0 - RPCRaft = 1 - RPCMultiplex = 2 // Old Muxado byte, no longer supported. - RPCTLS = 3 - RPCMultiplexV2 = 4 - RPCSnapshot = 5 - RPCGossip = 6 + RPCRaft RPCType = 1 + RPCMultiplex RPCType = 2 // Old Muxado byte, no longer supported. + RPCTLS RPCType = 3 + RPCMultiplexV2 RPCType = 4 + RPCSnapshot RPCType = 5 + RPCGossip RPCType = 6 // RPCTLSInsecure is used to flag RPC calls that require verify // incoming to be disabled, even when it is turned on in the // configuration. At the time of writing there is only AutoEncrypt.Sign // that is supported and it might be the only one there // ever is. - RPCTLSInsecure = 7 - RPCGRPC = 8 + RPCTLSInsecure RPCType = 7 + RPCGRPC RPCType = 8 // RPCMaxTypeValue is the maximum rpc type byte value currently used for the // various protocols riding over our "rpc" port. @@ -79,6 +81,7 @@ var RPCNextProtos = []string{ ALPN_RPCMultiplexV2, ALPN_RPCSnapshot, ALPN_RPCGossip, + ALPN_RPCGRPC, ALPN_WANGossipPacket, ALPN_WANGossipStream, } diff --git a/agent/pool/peek_test.go b/agent/pool/peek_test.go index 8b50bb2ea..b9e74ad93 100644 --- a/agent/pool/peek_test.go +++ b/agent/pool/peek_test.go @@ -10,8 +10,9 @@ import ( "testing" "time" - "github.com/hashicorp/consul/tlsutil" "github.com/stretchr/testify/require" + + "github.com/hashicorp/consul/tlsutil" ) func TestPeekForTLS_not_TLS(t *testing.T) { @@ -30,6 +31,7 @@ func TestPeekForTLS_not_TLS(t *testing.T) { RPCSnapshot, RPCGossip, RPCTLSInsecure, + RPCGRPC, } { cases = append(cases, testcase{ name: fmt.Sprintf("tcp rpc type byte %d", rpcType), @@ -76,6 +78,7 @@ func TestPeekForTLS_actual_TLS(t *testing.T) { RPCSnapshot, RPCGossip, RPCTLSInsecure, + RPCGRPC, } { cases = append(cases, testcase{ name: fmt.Sprintf("tcp rpc type byte %d", rpcType), diff --git a/agent/pool/pool.go b/agent/pool/pool.go index 315e59b94..64d779b60 100644 --- a/agent/pool/pool.go +++ b/agent/pool/pool.go @@ -2,6 +2,7 @@ package pool import ( "container/list" + "context" "crypto/tls" "fmt" "log" @@ -11,14 +12,15 @@ import ( "sync/atomic" "time" + msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc" + "github.com/hashicorp/yamux" + "github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/lib" "github.com/hashicorp/consul/tlsutil" - msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc" - "github.com/hashicorp/yamux" ) -const defaultDialTimeout = 10 * time.Second +const DefaultDialTimeout = 10 * time.Second // muxSession is used to provide an interface for a stream multiplexer. type muxSession interface { @@ -291,21 +293,24 @@ func (p *ConnPool) DialTimeout( ) (net.Conn, HalfCloser, error) { p.once.Do(p.init) - if p.Server && p.GatewayResolver != nil && p.TLSConfigurator != nil && dc != p.Datacenter { + if p.Server && + p.GatewayResolver != nil && + p.TLSConfigurator != nil && + dc != p.Datacenter { // NOTE: TLS is required on this branch. - return DialTimeoutWithRPCTypeViaMeshGateway( + nextProto := actualRPCType.ALPNString() + if nextProto == "" { + return nil, nil, fmt.Errorf("rpc type %d cannot be routed through a mesh gateway", actualRPCType) + } + return DialRPCViaMeshGateway( + context.Background(), dc, nodeName, - addr, p.SrcAddr, p.TLSConfigurator.OutgoingALPNRPCWrapper(), - actualRPCType, - RPCTLS, - // gateway stuff + nextProto, p.Server, - p.TLSConfigurator, p.GatewayResolver, - p.Datacenter, ) } @@ -319,7 +324,7 @@ func (p *ConnPool) dial( tlsRPCType RPCType, ) (net.Conn, HalfCloser, error) { // Try to dial the conn - d := &net.Dialer{LocalAddr: p.SrcAddr, Timeout: defaultDialTimeout} + d := &net.Dialer{LocalAddr: p.SrcAddr, Timeout: DefaultDialTimeout} conn, err := d.Dial("tcp", addr.String()) if err != nil { return nil, nil, err @@ -372,62 +377,49 @@ func (p *ConnPool) dial( return conn, hc, nil } -// DialTimeoutWithRPCTypeViaMeshGateway dials the destination node and sets up -// the connection to be the correct RPC type using ALPN. This currently is -// exclusively used to dial other servers in foreign datacenters via mesh -// gateways. -// -// NOTE: There is a close mirror of this method in agent/consul/wanfed/wanfed.go:dial -func DialTimeoutWithRPCTypeViaMeshGateway( - dc string, - nodeName string, - addr net.Addr, - src *net.TCPAddr, - wrapper tlsutil.ALPNWrapper, - actualRPCType RPCType, - tlsRPCType RPCType, - // gateway stuff +// DialRPCViaMeshGateway dials the destination node and sets up the connection +// to be the correct RPC type using ALPN. This currently is exclusively used to +// dial other servers in foreign datacenters via mesh gateways. +func DialRPCViaMeshGateway( + ctx context.Context, + dc string, // (metadata.Server).Datacenter + nodeName string, // (metadata.Server).ShortName + srcAddr *net.TCPAddr, + alpnWrapper tlsutil.ALPNWrapper, + nextProto string, dialingFromServer bool, - tlsConfigurator *tlsutil.Configurator, gatewayResolver func(string) string, - thisDatacenter string, ) (net.Conn, HalfCloser, error) { if !dialingFromServer { return nil, nil, fmt.Errorf("must dial via mesh gateways from a server agent") } else if gatewayResolver == nil { return nil, nil, fmt.Errorf("gatewayResolver is nil") - } else if tlsConfigurator == nil { - return nil, nil, fmt.Errorf("tlsConfigurator is nil") - } else if dc == thisDatacenter { - return nil, nil, fmt.Errorf("cannot dial servers in the same datacenter via a mesh gateway") - } else if wrapper == nil { + } else if alpnWrapper == nil { return nil, nil, fmt.Errorf("cannot dial via a mesh gateway when outgoing TLS is disabled") } - nextProto := actualRPCType.ALPNString() - if nextProto == "" { - return nil, nil, fmt.Errorf("rpc type %d cannot be routed through a mesh gateway", actualRPCType) - } - gwAddr := gatewayResolver(dc) if gwAddr == "" { return nil, nil, structs.ErrDCNotAvailable } - dialer := &net.Dialer{LocalAddr: src, Timeout: defaultDialTimeout} + dialer := &net.Dialer{LocalAddr: srcAddr, Timeout: DefaultDialTimeout} - rawConn, err := dialer.Dial("tcp", gwAddr) + rawConn, err := dialer.DialContext(ctx, "tcp", gwAddr) if err != nil { return nil, nil, err } - if tcp, ok := rawConn.(*net.TCPConn); ok { - _ = tcp.SetKeepAlive(true) - _ = tcp.SetNoDelay(true) + if nextProto != ALPN_RPCGRPC { + // agent/grpc/client.go:dial() handles this in another way for gRPC + if tcp, ok := rawConn.(*net.TCPConn); ok { + _ = tcp.SetKeepAlive(true) + _ = tcp.SetNoDelay(true) + } } // NOTE: now we wrap the connection in a TLS client. - tlsConn, err := wrapper(dc, nodeName, nextProto, rawConn) + tlsConn, err := alpnWrapper(dc, nodeName, nextProto, rawConn) if err != nil { return nil, nil, err } diff --git a/agent/setup.go b/agent/setup.go index fdb750fde..c10a5e34f 100644 --- a/agent/setup.go +++ b/agent/setup.go @@ -106,9 +106,22 @@ func NewBaseDeps(configLoader ConfigLoader, logOut io.Writer) (BaseDeps, error) d.ViewStore = submatview.NewStore(d.Logger.Named("viewstore")) d.ConnPool = newConnPool(cfg, d.Logger, d.TLSConfigurator) - builder := resolver.NewServerResolverBuilder(resolver.Config{}) + builder := resolver.NewServerResolverBuilder(resolver.Config{ + // Set the authority to something sufficiently unique so any usage in + // tests would be self-isolating in the global resolver map, while also + // not incurring a huge penalty for non-test code. + Authority: cfg.Datacenter + "." + string(cfg.NodeID), + }) resolver.Register(builder) - d.GRPCConnPool = grpc.NewClientConnPool(builder, grpc.TLSWrapper(d.TLSConfigurator.OutgoingRPCWrapper()), d.TLSConfigurator.UseTLS) + d.GRPCConnPool = grpc.NewClientConnPool(grpc.ClientConnPoolConfig{ + Servers: builder, + SrcAddr: d.ConnPool.SrcAddr, + TLSWrapper: grpc.TLSWrapper(d.TLSConfigurator.OutgoingRPCWrapper()), + ALPNWrapper: grpc.ALPNWrapper(d.TLSConfigurator.OutgoingALPNRPCWrapper()), + UseTLSForDC: d.TLSConfigurator.UseTLS, + DialingFromServer: cfg.ServerMode, + DialingFromDatacenter: cfg.Datacenter, + }) d.LeaderForwarder = builder d.Router = router.NewRouter(d.Logger, cfg.Datacenter, fmt.Sprintf("%s.%s", cfg.NodeName, cfg.Datacenter), builder)