grpc: ensure that streaming gRPC requests work over mesh gateway based wan federation (#10838)

Fixes #10796
This commit is contained in:
R.B. Boyer 2021-08-24 16:28:44 -05:00 committed by GitHub
parent 6b574abc89
commit a84f5fa25d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 589 additions and 183 deletions

3
.changelog/10838.txt Normal file
View File

@ -0,0 +1,3 @@
```release-note:bug
grpc: ensure that streaming gRPC requests work over mesh gateway based wan federation
```

View File

@ -371,6 +371,9 @@ func (f fakeGRPCConnPool) ClientConnLeader() (*grpc.ClientConn, error) {
return nil, nil return nil, nil
} }
func (f fakeGRPCConnPool) SetGatewayResolver(_ func(string) string) {
}
func TestAgent_ReconnectConfigWanDisabled(t *testing.T) { func TestAgent_ReconnectConfigWanDisabled(t *testing.T) {
if testing.Short() { if testing.Short() {
t.Skip("too slow for testing.Short") t.Skip("too slow for testing.Short")
@ -4524,6 +4527,9 @@ LOOP:
} }
// This is a mirror of a similar test in agent/consul/server_test.go // This is a mirror of a similar test in agent/consul/server_test.go
//
// TODO(rb): implement something similar to this as a full containerized test suite with proper
// isolation so requests can't "cheat" and bypass the mesh gateways
func TestAgent_JoinWAN_viaMeshGateway(t *testing.T) { func TestAgent_JoinWAN_viaMeshGateway(t *testing.T) {
if testing.Short() { if testing.Short() {
t.Skip("too slow for testing.Short") t.Skip("too slow for testing.Short")
@ -4771,6 +4777,9 @@ func TestAgent_JoinWAN_viaMeshGateway(t *testing.T) {
}) })
// Ensure we can do some trivial RPC in all directions. // Ensure we can do some trivial RPC in all directions.
//
// NOTE: we explicitly make streaming and non-streaming assertions here to
// verify both rpc and grpc codepaths.
agents := map[string]*TestAgent{"dc1": a1, "dc2": a2, "dc3": a3} agents := map[string]*TestAgent{"dc1": a1, "dc2": a2, "dc3": a3}
names := map[string]string{"dc1": "bob", "dc2": "betty", "dc3": "bonnie"} names := map[string]string{"dc1": "bob", "dc2": "betty", "dc3": "bonnie"}
for _, srcDC := range []string{"dc1", "dc2", "dc3"} { for _, srcDC := range []string{"dc1", "dc2", "dc3"} {
@ -4780,20 +4789,39 @@ func TestAgent_JoinWAN_viaMeshGateway(t *testing.T) {
continue continue
} }
t.Run(srcDC+" to "+dstDC, func(t *testing.T) { t.Run(srcDC+" to "+dstDC, func(t *testing.T) {
req, err := http.NewRequest("GET", "/v1/catalog/nodes?dc="+dstDC, nil) t.Run("normal-rpc", func(t *testing.T) {
require.NoError(t, err) req, err := http.NewRequest("GET", "/v1/catalog/nodes?dc="+dstDC, nil)
require.NoError(t, err)
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
obj, err := a.srv.CatalogNodes(resp, req) obj, err := a.srv.CatalogNodes(resp, req)
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, obj) require.NotNil(t, obj)
nodes, ok := obj.(structs.Nodes) nodes, ok := obj.(structs.Nodes)
require.True(t, ok) require.True(t, ok)
require.Len(t, nodes, 1) require.Len(t, nodes, 1)
node := nodes[0] node := nodes[0]
require.Equal(t, dstDC, node.Datacenter) require.Equal(t, dstDC, node.Datacenter)
require.Equal(t, names[dstDC], node.Node) require.Equal(t, names[dstDC], node.Node)
})
t.Run("streaming-grpc", func(t *testing.T) {
req, err := http.NewRequest("GET", "/v1/health/service/consul?cached&dc="+dstDC, nil)
require.NoError(t, err)
resp := httptest.NewRecorder()
obj, err := a.srv.HealthServiceNodes(resp, req)
require.NoError(t, err)
require.NotNil(t, obj)
csns, ok := obj.(structs.CheckServiceNodes)
require.True(t, ok)
require.Len(t, csns, 1)
csn := csns[0]
require.Equal(t, dstDC, csn.Node.Datacenter)
require.Equal(t, names[dstDC], csn.Node.Node)
})
}) })
} }
} }

View File

@ -5,10 +5,17 @@ import (
"fmt" "fmt"
"net" "net"
"os" "os"
"strings"
"sync" "sync"
"testing" "testing"
"time" "time"
"github.com/hashicorp/go-hclog"
msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc"
"github.com/hashicorp/serf/serf"
"github.com/stretchr/testify/require"
"golang.org/x/time/rate"
"github.com/hashicorp/consul/agent/grpc" "github.com/hashicorp/consul/agent/grpc"
"github.com/hashicorp/consul/agent/grpc/resolver" "github.com/hashicorp/consul/agent/grpc/resolver"
"github.com/hashicorp/consul/agent/pool" "github.com/hashicorp/consul/agent/pool"
@ -20,11 +27,6 @@ import (
"github.com/hashicorp/consul/sdk/testutil/retry" "github.com/hashicorp/consul/sdk/testutil/retry"
"github.com/hashicorp/consul/testrpc" "github.com/hashicorp/consul/testrpc"
"github.com/hashicorp/consul/tlsutil" "github.com/hashicorp/consul/tlsutil"
"github.com/hashicorp/go-hclog"
msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc"
"github.com/hashicorp/serf/serf"
"github.com/stretchr/testify/require"
"golang.org/x/time/rate"
) )
func testClientConfig(t *testing.T) (string, *Config) { func testClientConfig(t *testing.T) (string, *Config) {
@ -490,6 +492,13 @@ func newClient(t *testing.T, config *Config) *Client {
return client return client
} }
func newTestResolverConfig(t *testing.T, suffix string) resolver.Config {
n := t.Name()
s := strings.Replace(n, "/", "", -1)
s = strings.Replace(s, "_", "", -1)
return resolver.Config{Authority: strings.ToLower(s) + "-" + suffix}
}
func newDefaultDeps(t *testing.T, c *Config) Deps { func newDefaultDeps(t *testing.T, c *Config) Deps {
t.Helper() t.Helper()
@ -502,7 +511,7 @@ func newDefaultDeps(t *testing.T, c *Config) Deps {
tls, err := tlsutil.NewConfigurator(c.TLSConfig, logger) tls, err := tlsutil.NewConfigurator(c.TLSConfig, logger)
require.NoError(t, err, "failed to create tls configuration") require.NoError(t, err, "failed to create tls configuration")
builder := resolver.NewServerResolverBuilder(resolver.Config{Authority: c.NodeName}) builder := resolver.NewServerResolverBuilder(newTestResolverConfig(t, c.NodeName+"-"+c.Datacenter))
r := router.NewRouter(logger, c.Datacenter, fmt.Sprintf("%s.%s", c.NodeName, c.Datacenter), builder) r := router.NewRouter(logger, c.Datacenter, fmt.Sprintf("%s.%s", c.NodeName, c.Datacenter), builder)
resolver.Register(builder) resolver.Register(builder)
@ -522,7 +531,13 @@ func newDefaultDeps(t *testing.T, c *Config) Deps {
Tokens: new(token.Store), Tokens: new(token.Store),
Router: r, Router: r,
ConnPool: connPool, ConnPool: connPool,
GRPCConnPool: grpc.NewClientConnPool(builder, grpc.TLSWrapper(tls.OutgoingRPCWrapper()), tls.UseTLS), GRPCConnPool: grpc.NewClientConnPool(grpc.ClientConnPoolConfig{
Servers: builder,
TLSWrapper: grpc.TLSWrapper(tls.OutgoingRPCWrapper()),
UseTLSForDC: tls.UseTLS,
DialingFromServer: true,
DialingFromDatacenter: c.Datacenter,
}),
LeaderForwarder: builder, LeaderForwarder: builder,
EnterpriseDeps: newDefaultDepsEnterprise(t, logger, c), EnterpriseDeps: newDefaultDepsEnterprise(t, logger, c),
} }

View File

@ -24,8 +24,10 @@ type Deps struct {
type GRPCClientConner interface { type GRPCClientConner interface {
ClientConn(datacenter string) (*grpc.ClientConn, error) ClientConn(datacenter string) (*grpc.ClientConn, error)
ClientConnLeader() (*grpc.ClientConn, error) ClientConnLeader() (*grpc.ClientConn, error)
SetGatewayResolver(func(string) string)
} }
type LeaderForwarder interface { type LeaderForwarder interface {
UpdateLeaderAddr(leaderAddr string) // UpdateLeaderAddr updates the leader address in the local DC's resolver.
UpdateLeaderAddr(datacenter, addr string)
} }

View File

@ -293,7 +293,7 @@ func (s *Server) handleNativeTLS(conn net.Conn) {
s.handleSnapshotConn(tlsConn) s.handleSnapshotConn(tlsConn)
case pool.ALPN_RPCGRPC: case pool.ALPN_RPCGRPC:
s.grpcHandler.Handle(conn) s.grpcHandler.Handle(tlsConn)
case pool.ALPN_WANGossipPacket: case pool.ALPN_WANGossipPacket:
if err := s.handleALPN_WANGossipPacketStream(tlsConn); err != nil && err != io.EOF { if err := s.handleALPN_WANGossipPacketStream(tlsConn); err != nil && err != io.EOF {
@ -373,7 +373,7 @@ func (s *Server) handleMultiplexV2(conn net.Conn) {
} }
sub = peeked sub = peeked
switch first { switch first {
case pool.RPCGossip: case byte(pool.RPCGossip):
buf := make([]byte, 1) buf := make([]byte, 1)
sub.Read(buf) sub.Read(buf)
go func() { go func() {

View File

@ -460,7 +460,7 @@ func TestRPC_TLSHandshakeTimeout(t *testing.T) {
// Write TLS byte to avoid being closed by either the (outer) first byte // Write TLS byte to avoid being closed by either the (outer) first byte
// timeout or the fact that server requires TLS // timeout or the fact that server requires TLS
_, err = conn.Write([]byte{pool.RPCTLS}) _, err = conn.Write([]byte{byte(pool.RPCTLS)})
require.NoError(t, err) require.NoError(t, err)
// Wait for more than the timeout before we start a TLS handshake. This is // Wait for more than the timeout before we start a TLS handshake. This is

View File

@ -173,6 +173,9 @@ type Server struct {
// Connection pool to other consul servers // Connection pool to other consul servers
connPool *pool.ConnPool connPool *pool.ConnPool
// Connection pool to other consul servers using gRPC
grpcConnPool GRPCClientConner
// eventChLAN is used to receive events from the // eventChLAN is used to receive events from the
// serf cluster in the datacenter // serf cluster in the datacenter
eventChLAN chan serf.Event eventChLAN chan serf.Event
@ -348,6 +351,7 @@ func NewServer(config *Config, flat Deps) (*Server, error) {
config: config, config: config,
tokens: flat.Tokens, tokens: flat.Tokens,
connPool: flat.ConnPool, connPool: flat.ConnPool,
grpcConnPool: flat.GRPCConnPool,
eventChLAN: make(chan serf.Event, serfEventChSize), eventChLAN: make(chan serf.Event, serfEventChSize),
eventChWAN: make(chan serf.Event, serfEventChSize), eventChWAN: make(chan serf.Event, serfEventChSize),
logger: serverLogger, logger: serverLogger,
@ -377,6 +381,7 @@ func NewServer(config *Config, flat Deps) (*Server, error) {
s.config.PrimaryDatacenter, s.config.PrimaryDatacenter,
) )
s.connPool.GatewayResolver = s.gatewayLocator.PickGateway s.connPool.GatewayResolver = s.gatewayLocator.PickGateway
s.grpcConnPool.SetGatewayResolver(s.gatewayLocator.PickGateway)
} }
// Initialize enterprise specific server functionality // Initialize enterprise specific server functionality
@ -1461,7 +1466,7 @@ func (s *Server) trackLeaderChanges() {
continue continue
} }
s.grpcLeaderForwarder.UpdateLeaderAddr(string(leaderObs.Leader)) s.grpcLeaderForwarder.UpdateLeaderAddr(s.config.Datacenter, string(leaderObs.Leader))
case <-s.shutdownCh: case <-s.shutdownCh:
s.raft.DeregisterObserver(observer) s.raft.DeregisterObserver(observer)
return return

View File

@ -25,6 +25,8 @@ import (
func TestSubscribeBackend_IntegrationWithServer_TLSEnabled(t *testing.T) { func TestSubscribeBackend_IntegrationWithServer_TLSEnabled(t *testing.T) {
t.Parallel() t.Parallel()
// TODO(rb): add tests for the wanfed/alpn variations
_, conf1 := testServerConfig(t) _, conf1 := testServerConfig(t)
conf1.TLSConfig.VerifyIncoming = true conf1.TLSConfig.VerifyIncoming = true
conf1.TLSConfig.VerifyOutgoing = true conf1.TLSConfig.VerifyOutgoing = true
@ -60,7 +62,13 @@ func TestSubscribeBackend_IntegrationWithServer_TLSEnabled(t *testing.T) {
// Start a Subscribe call to our streaming endpoint from the client. // Start a Subscribe call to our streaming endpoint from the client.
{ {
pool := grpc.NewClientConnPool(builder, grpc.TLSWrapper(client.tlsConfigurator.OutgoingRPCWrapper()), client.tlsConfigurator.UseTLS) pool := grpc.NewClientConnPool(grpc.ClientConnPoolConfig{
Servers: builder,
TLSWrapper: grpc.TLSWrapper(client.tlsConfigurator.OutgoingRPCWrapper()),
UseTLSForDC: client.tlsConfigurator.UseTLS,
DialingFromServer: true,
DialingFromDatacenter: "dc1",
})
conn, err := pool.ClientConn("dc1") conn, err := pool.ClientConn("dc1")
require.NoError(t, err) require.NoError(t, err)
@ -91,8 +99,13 @@ func TestSubscribeBackend_IntegrationWithServer_TLSEnabled(t *testing.T) {
// Start a Subscribe call to our streaming endpoint from the server's loopback client. // Start a Subscribe call to our streaming endpoint from the server's loopback client.
{ {
pool := grpc.NewClientConnPool(grpc.ClientConnPoolConfig{
pool := grpc.NewClientConnPool(builder, grpc.TLSWrapper(client.tlsConfigurator.OutgoingRPCWrapper()), client.tlsConfigurator.UseTLS) Servers: builder,
TLSWrapper: grpc.TLSWrapper(client.tlsConfigurator.OutgoingRPCWrapper()),
UseTLSForDC: client.tlsConfigurator.UseTLS,
DialingFromServer: true,
DialingFromDatacenter: "dc1",
})
conn, err := pool.ClientConn("dc1") conn, err := pool.ClientConn("dc1")
require.NoError(t, err) require.NoError(t, err)
@ -166,7 +179,13 @@ func TestSubscribeBackend_IntegrationWithServer_TLSReload(t *testing.T) {
// Subscribe calls should fail initially // Subscribe calls should fail initially
joinLAN(t, client, server) joinLAN(t, client, server)
pool := grpc.NewClientConnPool(builder, grpc.TLSWrapper(client.tlsConfigurator.OutgoingRPCWrapper()), client.tlsConfigurator.UseTLS) pool := grpc.NewClientConnPool(grpc.ClientConnPoolConfig{
Servers: builder,
TLSWrapper: grpc.TLSWrapper(client.tlsConfigurator.OutgoingRPCWrapper()),
UseTLSForDC: client.tlsConfigurator.UseTLS,
DialingFromServer: true,
DialingFromDatacenter: "dc1",
})
conn, err := pool.ClientConn("dc1") conn, err := pool.ClientConn("dc1")
require.NoError(t, err) require.NoError(t, err)
@ -294,7 +313,13 @@ func TestSubscribeBackend_IntegrationWithServer_DeliversAllMessages(t *testing.T
} }
}() }()
pool := grpc.NewClientConnPool(builder, grpc.TLSWrapper(client.tlsConfigurator.OutgoingRPCWrapper()), client.tlsConfigurator.UseTLS) pool := grpc.NewClientConnPool(grpc.ClientConnPoolConfig{
Servers: builder,
TLSWrapper: grpc.TLSWrapper(client.tlsConfigurator.OutgoingRPCWrapper()),
UseTLSForDC: client.tlsConfigurator.UseTLS,
DialingFromServer: true,
DialingFromDatacenter: "dc1",
})
conn, err := pool.ClientConn("dc1") conn, err := pool.ClientConn("dc1")
require.NoError(t, err) require.NoError(t, err)
@ -337,7 +362,7 @@ func TestSubscribeBackend_IntegrationWithServer_DeliversAllMessages(t *testing.T
} }
func newClientWithGRPCResolver(t *testing.T, ops ...func(*Config)) (*Client, *resolver.ServerResolverBuilder) { func newClientWithGRPCResolver(t *testing.T, ops ...func(*Config)) (*Client, *resolver.ServerResolverBuilder) {
builder := resolver.NewServerResolverBuilder(resolver.Config{Authority: t.Name()}) builder := resolver.NewServerResolverBuilder(newTestResolverConfig(t, "client"))
resolver.Register(builder) resolver.Register(builder)
t.Cleanup(func() { t.Cleanup(func() {
resolver.Deregister(builder.Authority()) resolver.Deregister(builder.Authority())

View File

@ -1,6 +1,7 @@
package wanfed package wanfed
import ( import (
"context"
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt" "fmt"
@ -11,7 +12,6 @@ import (
"github.com/hashicorp/memberlist" "github.com/hashicorp/memberlist"
"github.com/hashicorp/consul/agent/pool" "github.com/hashicorp/consul/agent/pool"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/tlsutil" "github.com/hashicorp/consul/tlsutil"
) )
@ -97,13 +97,8 @@ func (t *Transport) WriteToAddress(b []byte, addr memberlist.Address) (time.Time
} }
if dc != t.datacenter { if dc != t.datacenter {
gwAddr := t.gwResolver(dc)
if gwAddr == "" {
return time.Time{}, structs.ErrDCNotAvailable
}
dialFunc := func() (net.Conn, error) { dialFunc := func() (net.Conn, error) {
return t.dial(dc, node, pool.ALPN_WANGossipPacket, gwAddr) return t.dial(dc, node, pool.ALPN_WANGossipPacket)
} }
conn, err := t.pool.AcquireOrDial(addr.Name, dialFunc) conn, err := t.pool.AcquireOrDial(addr.Name, dialFunc)
if err != nil { if err != nil {
@ -136,42 +131,24 @@ func (t *Transport) DialAddressTimeout(addr memberlist.Address, timeout time.Dur
} }
if dc != t.datacenter { if dc != t.datacenter {
gwAddr := t.gwResolver(dc) return t.dial(dc, node, pool.ALPN_WANGossipStream)
if gwAddr == "" {
return nil, structs.ErrDCNotAvailable
}
return t.dial(dc, node, pool.ALPN_WANGossipStream, gwAddr)
} }
return t.IngestionAwareTransport.DialAddressTimeout(addr, timeout) return t.IngestionAwareTransport.DialAddressTimeout(addr, timeout)
} }
// NOTE: There is a close mirror of this method in agent/pool/pool.go:DialTimeoutWithRPCType func (t *Transport) dial(dc, nodeName, nextProto string) (net.Conn, error) {
func (t *Transport) dial(dc, nodeName, nextProto, addr string) (net.Conn, error) { conn, _, err := pool.DialRPCViaMeshGateway(
wrapper := t.tlsConfigurator.OutgoingALPNRPCWrapper() context.Background(),
if wrapper == nil { dc,
return nil, fmt.Errorf("wanfed: cannot dial via a mesh gateway when outgoing TLS is disabled") nodeName,
} nil, // TODO(rb): thread source address through here?
t.tlsConfigurator.OutgoingALPNRPCWrapper(),
dialer := &net.Dialer{Timeout: 10 * time.Second} nextProto,
true,
rawConn, err := dialer.Dial("tcp", addr) t.gwResolver,
if err != nil { )
return nil, err return conn, err
}
if tcp, ok := rawConn.(*net.TCPConn); ok {
_ = tcp.SetKeepAlive(true)
_ = tcp.SetNoDelay(true)
}
tlsConn, err := wrapper(dc, nodeName, nextProto, rawConn)
if err != nil {
return nil, err
}
return tlsConn, nil
} }
// SplitNodeName splits a node name as it would be represented in // SplitNodeName splits a node name as it would be represented in

View File

@ -12,38 +12,93 @@ import (
"github.com/hashicorp/consul/agent/metadata" "github.com/hashicorp/consul/agent/metadata"
"github.com/hashicorp/consul/agent/pool" "github.com/hashicorp/consul/agent/pool"
"github.com/hashicorp/consul/tlsutil"
) )
// ClientConnPool creates and stores a connection for each datacenter. // ClientConnPool creates and stores a connection for each datacenter.
type ClientConnPool struct { type ClientConnPool struct {
dialer dialer dialer dialer
servers ServerLocator servers ServerLocator
conns map[string]*grpc.ClientConn gwResolverDep gatewayResolverDep
connsLock sync.Mutex conns map[string]*grpc.ClientConn
connsLock sync.Mutex
} }
type ServerLocator interface { type ServerLocator interface {
// ServerForAddr is used to look up server metadata from an address. // ServerForGlobalAddr returns server metadata for a server with the specified globally unique address.
ServerForAddr(addr string) (*metadata.Server, error) ServerForGlobalAddr(globalAddr string) (*metadata.Server, error)
// Authority returns the target authority to use to dial the server. This is primarily // Authority returns the target authority to use to dial the server. This is primarily
// needed for testing multiple agents in parallel, because gRPC requires the // needed for testing multiple agents in parallel, because gRPC requires the
// resolver to be registered globally. // resolver to be registered globally.
Authority() string Authority() string
} }
// gatewayResolverDep is just a holder for a function pointer that can be
// updated lazily after the structs are instantiated (but before first use)
// and all structs with a reference to this struct will see the same update.
type gatewayResolverDep struct {
// GatewayResolver is a function that returns a suitable random mesh
// gateway address for dialing servers in a given DC. This is only
// needed if wan federation via mesh gateways is enabled.
GatewayResolver func(string) string
}
// TLSWrapper wraps a non-TLS connection and returns a connection with TLS // TLSWrapper wraps a non-TLS connection and returns a connection with TLS
// enabled. // enabled.
type TLSWrapper func(dc string, conn net.Conn) (net.Conn, error) type TLSWrapper func(dc string, conn net.Conn) (net.Conn, error)
// ALPNWrapper is a function that is used to wrap a non-TLS connection and
// returns an appropriate TLS connection or error. This taks a datacenter and
// node name as argument to configure the desired SNI value and the desired
// next proto for configuring ALPN.
type ALPNWrapper func(dc, nodeName, alpnProto string, conn net.Conn) (net.Conn, error)
type dialer func(context.Context, string) (net.Conn, error) type dialer func(context.Context, string) (net.Conn, error)
// NewClientConnPool create new GRPC client pool to connect to servers using GRPC over RPC type ClientConnPoolConfig struct {
func NewClientConnPool(servers ServerLocator, tls TLSWrapper, useTLSForDC func(dc string) bool) *ClientConnPool { // Servers is a reference for how to figure out how to dial any server.
return &ClientConnPool{ Servers ServerLocator
dialer: newDialer(servers, tls, useTLSForDC),
servers: servers, // SrcAddr is the source address for outgoing connections.
SrcAddr *net.TCPAddr
// TLSWrapper is the specifics of wrapping a socket when doing an TYPE_BYTE+TLS
// wrapped RPC request.
TLSWrapper TLSWrapper
// ALPNWrapper is the specifics of wrapping a socket when doing an ALPN+TLS
// wrapped RPC request (typically only for wan federation via mesh
// gateways).
ALPNWrapper ALPNWrapper
// UseTLSForDC is a function to determine if dialing a given datacenter
// should use TLS.
UseTLSForDC func(dc string) bool
// DialingFromServer should be set to true if this connection pool is owned
// by a consul server instance.
DialingFromServer bool
// DialingFromDatacenter is the datacenter of the consul agent using this
// pool.
DialingFromDatacenter string
}
// NewClientConnPool create new GRPC client pool to connect to servers using
// GRPC over RPC.
func NewClientConnPool(cfg ClientConnPoolConfig) *ClientConnPool {
c := &ClientConnPool{
servers: cfg.Servers,
conns: make(map[string]*grpc.ClientConn), conns: make(map[string]*grpc.ClientConn),
} }
c.dialer = newDialer(cfg, &c.gwResolverDep)
return c
}
// SetGatewayResolver is only to be called during setup before the pool is used.
func (c *ClientConnPool) SetGatewayResolver(gatewayResolver func(string) string) {
c.gwResolverDep.GatewayResolver = gatewayResolver
} }
// ClientConn returns a grpc.ClientConn for the datacenter. If there are no // ClientConn returns a grpc.ClientConn for the datacenter. If there are no
@ -102,22 +157,39 @@ func (c *ClientConnPool) dial(datacenter string, serverType string) (*grpc.Clien
// newDialer returns a gRPC dialer function that conditionally wraps the connection // newDialer returns a gRPC dialer function that conditionally wraps the connection
// with TLS based on the Server.useTLS value. // with TLS based on the Server.useTLS value.
func newDialer(servers ServerLocator, wrapper TLSWrapper, useTLSForDC func(dc string) bool) func(context.Context, string) (net.Conn, error) { func newDialer(cfg ClientConnPoolConfig, gwResolverDep *gatewayResolverDep) func(context.Context, string) (net.Conn, error) {
return func(ctx context.Context, addr string) (net.Conn, error) { return func(ctx context.Context, globalAddr string) (net.Conn, error) {
d := net.Dialer{} server, err := cfg.Servers.ServerForGlobalAddr(globalAddr)
conn, err := d.DialContext(ctx, "tcp", addr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
server, err := servers.ServerForAddr(addr) if cfg.DialingFromServer &&
gwResolverDep.GatewayResolver != nil &&
cfg.ALPNWrapper != nil &&
server.Datacenter != cfg.DialingFromDatacenter {
// NOTE: TLS is required on this branch.
conn, _, err := pool.DialRPCViaMeshGateway(
ctx,
server.Datacenter,
server.ShortName,
cfg.SrcAddr,
tlsutil.ALPNWrapper(cfg.ALPNWrapper),
pool.ALPN_RPCGRPC,
cfg.DialingFromServer,
gwResolverDep.GatewayResolver,
)
return conn, err
}
d := net.Dialer{LocalAddr: cfg.SrcAddr, Timeout: pool.DefaultDialTimeout}
conn, err := d.DialContext(ctx, "tcp", server.Addr.String())
if err != nil { if err != nil {
conn.Close()
return nil, err return nil, err
} }
if server.UseTLS && useTLSForDC(server.Datacenter) { if server.UseTLS && cfg.UseTLSForDC(server.Datacenter) {
if wrapper == nil { if cfg.TLSWrapper == nil {
conn.Close() conn.Close()
return nil, fmt.Errorf("TLS enabled but got nil TLS wrapper") return nil, fmt.Errorf("TLS enabled but got nil TLS wrapper")
} }
@ -129,7 +201,7 @@ func newDialer(servers ServerLocator, wrapper TLSWrapper, useTLSForDC func(dc st
} }
// Wrap the connection in a TLS client // Wrap the connection in a TLS client
tlsConn, err := wrapper(server.Datacenter, conn) tlsConn, err := cfg.TLSWrapper(server.Datacenter, conn)
if err != nil { if err != nil {
conn.Close() conn.Close()
return nil, err return nil, err
@ -137,7 +209,7 @@ func newDialer(servers ServerLocator, wrapper TLSWrapper, useTLSForDC func(dc st
conn = tlsConn conn = tlsConn
} }
_, err = conn.Write([]byte{pool.RPCGRPC}) _, err = conn.Write([]byte{byte(pool.RPCGRPC)})
if err != nil { if err != nil {
conn.Close() conn.Close()
return nil, err return nil, err

View File

@ -4,17 +4,22 @@ import (
"context" "context"
"fmt" "fmt"
"net" "net"
"strconv"
"strings" "strings"
"sync/atomic" "sync/atomic"
"testing" "testing"
"time" "time"
"github.com/google/tcpproxy"
"github.com/hashicorp/go-hclog" "github.com/hashicorp/go-hclog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/hashicorp/consul/agent/grpc/internal/testservice" "github.com/hashicorp/consul/agent/grpc/internal/testservice"
"github.com/hashicorp/consul/agent/grpc/resolver" "github.com/hashicorp/consul/agent/grpc/resolver"
"github.com/hashicorp/consul/agent/metadata" "github.com/hashicorp/consul/agent/metadata"
"github.com/hashicorp/consul/ipaddr"
"github.com/hashicorp/consul/sdk/freeport"
"github.com/hashicorp/consul/tlsutil" "github.com/hashicorp/consul/tlsutil"
) )
@ -24,11 +29,14 @@ func useTLSForDcAlwaysTrue(_ string) bool {
} }
func TestNewDialer_WithTLSWrapper(t *testing.T) { func TestNewDialer_WithTLSWrapper(t *testing.T) {
lis, err := net.Listen("tcp", "127.0.0.1:0") ports := freeport.MustTake(1)
defer freeport.Return(ports)
lis, err := net.Listen("tcp", net.JoinHostPort("127.0.0.1", strconv.Itoa(ports[0])))
require.NoError(t, err) require.NoError(t, err)
t.Cleanup(logError(t, lis.Close)) t.Cleanup(logError(t, lis.Close))
builder := resolver.NewServerResolverBuilder(resolver.Config{}) builder := resolver.NewServerResolverBuilder(newConfig(t))
builder.AddServer(&metadata.Server{ builder.AddServer(&metadata.Server{
Name: "server-1", Name: "server-1",
ID: "ID1", ID: "ID1",
@ -42,19 +50,107 @@ func TestNewDialer_WithTLSWrapper(t *testing.T) {
called = true called = true
return conn, nil return conn, nil
} }
dial := newDialer(builder, wrapper, useTLSForDcAlwaysTrue) dial := newDialer(
ClientConnPoolConfig{
Servers: builder,
TLSWrapper: wrapper,
UseTLSForDC: useTLSForDcAlwaysTrue,
DialingFromServer: true,
DialingFromDatacenter: "dc1",
},
&gatewayResolverDep{},
)
ctx := context.Background() ctx := context.Background()
conn, err := dial(ctx, lis.Addr().String()) conn, err := dial(ctx, resolver.DCPrefix("dc1", lis.Addr().String()))
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, conn.Close()) require.NoError(t, conn.Close())
require.True(t, called, "expected TLSWrapper to be called") require.True(t, called, "expected TLSWrapper to be called")
} }
func TestNewDialer_WithALPNWrapper(t *testing.T) {
ports := freeport.MustTake(3)
defer freeport.Return(ports)
var (
s1addr = ipaddr.FormatAddressPort("127.0.0.1", ports[0])
s2addr = ipaddr.FormatAddressPort("127.0.0.1", ports[1])
gwAddr = ipaddr.FormatAddressPort("127.0.0.1", ports[2])
)
lis1, err := net.Listen("tcp", s1addr)
require.NoError(t, err)
t.Cleanup(logError(t, lis1.Close))
lis2, err := net.Listen("tcp", s2addr)
require.NoError(t, err)
t.Cleanup(logError(t, lis2.Close))
// Send all of the traffic to dc2's server
var p tcpproxy.Proxy
p.AddRoute(gwAddr, tcpproxy.To(s2addr))
p.AddStopACMESearch(gwAddr)
require.NoError(t, p.Start())
defer func() {
p.Close()
p.Wait()
}()
builder := resolver.NewServerResolverBuilder(newConfig(t))
builder.AddServer(&metadata.Server{
Name: "server-1",
ID: "ID1",
Datacenter: "dc1",
Addr: lis1.Addr(),
UseTLS: true,
})
builder.AddServer(&metadata.Server{
Name: "server-2",
ID: "ID2",
Datacenter: "dc2",
Addr: lis2.Addr(),
UseTLS: true,
})
var calledTLS bool
wrapperTLS := func(_ string, conn net.Conn) (net.Conn, error) {
calledTLS = true
return conn, nil
}
var calledALPN bool
wrapperALPN := func(_, _, _ string, conn net.Conn) (net.Conn, error) {
calledALPN = true
return conn, nil
}
gwResolverDep := &gatewayResolverDep{
GatewayResolver: func(addr string) string {
return gwAddr
},
}
dial := newDialer(
ClientConnPoolConfig{
Servers: builder,
TLSWrapper: wrapperTLS,
ALPNWrapper: wrapperALPN,
UseTLSForDC: useTLSForDcAlwaysTrue,
DialingFromServer: true,
DialingFromDatacenter: "dc1",
},
gwResolverDep,
)
ctx := context.Background()
conn, err := dial(ctx, resolver.DCPrefix("dc2", lis2.Addr().String()))
require.NoError(t, err)
require.NoError(t, conn.Close())
assert.False(t, calledTLS, "expected TLSWrapper not to be called")
assert.True(t, calledALPN, "expected ALPNWrapper to be called")
}
func TestNewDialer_IntegrationWithTLSEnabledHandler(t *testing.T) { func TestNewDialer_IntegrationWithTLSEnabledHandler(t *testing.T) {
res := resolver.NewServerResolverBuilder(newConfig(t)) res := resolver.NewServerResolverBuilder(newConfig(t))
registerWithGRPC(t, res) registerWithGRPC(t, res)
srv := newTestServer(t, "server-1", "dc1")
tlsConf, err := tlsutil.NewConfigurator(tlsutil.Config{ tlsConf, err := tlsutil.NewConfigurator(tlsutil.Config{
VerifyIncoming: true, VerifyIncoming: true,
VerifyOutgoing: true, VerifyOutgoing: true,
@ -63,12 +159,20 @@ func TestNewDialer_IntegrationWithTLSEnabledHandler(t *testing.T) {
KeyFile: "../../test/hostname/Alice.key", KeyFile: "../../test/hostname/Alice.key",
}, hclog.New(nil)) }, hclog.New(nil))
require.NoError(t, err) require.NoError(t, err)
srv.rpc.tlsConf = tlsConf
res.AddServer(srv.Metadata()) srv := newTestServer(t, "server-1", "dc1", tlsConf)
md := srv.Metadata()
res.AddServer(md)
t.Cleanup(srv.shutdown) t.Cleanup(srv.shutdown)
pool := NewClientConnPool(res, TLSWrapper(tlsConf.OutgoingRPCWrapper()), tlsConf.UseTLS) pool := NewClientConnPool(ClientConnPoolConfig{
Servers: res,
TLSWrapper: TLSWrapper(tlsConf.OutgoingRPCWrapper()),
UseTLSForDC: tlsConf.UseTLS,
DialingFromServer: true,
DialingFromDatacenter: "dc1",
})
conn, err := pool.ClientConn("dc1") conn, err := pool.ClientConn("dc1")
require.NoError(t, err) require.NoError(t, err)
@ -81,17 +185,98 @@ func TestNewDialer_IntegrationWithTLSEnabledHandler(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, "server-1", resp.ServerName) require.Equal(t, "server-1", resp.ServerName)
require.True(t, atomic.LoadInt32(&srv.rpc.tlsConnEstablished) > 0) require.True(t, atomic.LoadInt32(&srv.rpc.tlsConnEstablished) > 0)
require.True(t, atomic.LoadInt32(&srv.rpc.alpnConnEstablished) == 0)
}
func TestNewDialer_IntegrationWithTLSEnabledHandler_viaMeshGateway(t *testing.T) {
ports := freeport.MustTake(1)
defer freeport.Return(ports)
gwAddr := ipaddr.FormatAddressPort("127.0.0.1", ports[0])
res := resolver.NewServerResolverBuilder(newConfig(t))
registerWithGRPC(t, res)
tlsConf, err := tlsutil.NewConfigurator(tlsutil.Config{
VerifyIncoming: true,
VerifyOutgoing: true,
VerifyServerHostname: true,
CAFile: "../../test/hostname/CertAuth.crt",
CertFile: "../../test/hostname/Bob.crt",
KeyFile: "../../test/hostname/Bob.key",
Domain: "consul",
NodeName: "bob",
}, hclog.New(nil))
require.NoError(t, err)
srv := newTestServer(t, "bob", "dc1", tlsConf)
// Send all of the traffic to dc1's server
var p tcpproxy.Proxy
p.AddRoute(gwAddr, tcpproxy.To(srv.addr.String()))
p.AddStopACMESearch(gwAddr)
require.NoError(t, p.Start())
defer func() {
p.Close()
p.Wait()
}()
md := srv.Metadata()
res.AddServer(md)
t.Cleanup(srv.shutdown)
clientTLSConf, err := tlsutil.NewConfigurator(tlsutil.Config{
VerifyIncoming: true,
VerifyOutgoing: true,
VerifyServerHostname: true,
CAFile: "../../test/hostname/CertAuth.crt",
CertFile: "../../test/hostname/Betty.crt",
KeyFile: "../../test/hostname/Betty.key",
Domain: "consul",
NodeName: "betty",
}, hclog.New(nil))
require.NoError(t, err)
pool := NewClientConnPool(ClientConnPoolConfig{
Servers: res,
TLSWrapper: TLSWrapper(clientTLSConf.OutgoingRPCWrapper()),
ALPNWrapper: ALPNWrapper(clientTLSConf.OutgoingALPNRPCWrapper()),
UseTLSForDC: tlsConf.UseTLS,
DialingFromServer: true,
DialingFromDatacenter: "dc2",
})
pool.SetGatewayResolver(func(addr string) string {
return gwAddr
})
conn, err := pool.ClientConn("dc1")
require.NoError(t, err)
client := testservice.NewSimpleClient(conn)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
t.Cleanup(cancel)
resp, err := client.Something(ctx, &testservice.Req{})
require.NoError(t, err)
require.Equal(t, "bob", resp.ServerName)
require.True(t, atomic.LoadInt32(&srv.rpc.tlsConnEstablished) == 0)
require.True(t, atomic.LoadInt32(&srv.rpc.alpnConnEstablished) > 0)
} }
func TestClientConnPool_IntegrationWithGRPCResolver_Failover(t *testing.T) { func TestClientConnPool_IntegrationWithGRPCResolver_Failover(t *testing.T) {
count := 4 count := 4
res := resolver.NewServerResolverBuilder(newConfig(t)) res := resolver.NewServerResolverBuilder(newConfig(t))
registerWithGRPC(t, res) registerWithGRPC(t, res)
pool := NewClientConnPool(res, nil, useTLSForDcAlwaysTrue) pool := NewClientConnPool(ClientConnPoolConfig{
Servers: res,
UseTLSForDC: useTLSForDcAlwaysTrue,
DialingFromServer: true,
DialingFromDatacenter: "dc1",
})
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
name := fmt.Sprintf("server-%d", i) name := fmt.Sprintf("server-%d", i)
srv := newTestServer(t, name, "dc1") srv := newTestServer(t, name, "dc1", nil)
res.AddServer(srv.Metadata()) res.AddServer(srv.Metadata())
t.Cleanup(srv.shutdown) t.Cleanup(srv.shutdown)
} }
@ -115,22 +300,27 @@ func TestClientConnPool_IntegrationWithGRPCResolver_Failover(t *testing.T) {
func TestClientConnPool_ForwardToLeader_Failover(t *testing.T) { func TestClientConnPool_ForwardToLeader_Failover(t *testing.T) {
count := 3 count := 3
conf := newConfig(t) res := resolver.NewServerResolverBuilder(newConfig(t))
res := resolver.NewServerResolverBuilder(conf)
registerWithGRPC(t, res) registerWithGRPC(t, res)
pool := NewClientConnPool(res, nil, useTLSForDcAlwaysTrue) pool := NewClientConnPool(ClientConnPoolConfig{
Servers: res,
UseTLSForDC: useTLSForDcAlwaysTrue,
DialingFromServer: true,
DialingFromDatacenter: "dc1",
})
var servers []testServer var servers []testServer
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
name := fmt.Sprintf("server-%d", i) name := fmt.Sprintf("server-%d", i)
srv := newTestServer(t, name, "dc1") srv := newTestServer(t, name, "dc1", nil)
res.AddServer(srv.Metadata()) res.AddServer(srv.Metadata())
servers = append(servers, srv) servers = append(servers, srv)
t.Cleanup(srv.shutdown) t.Cleanup(srv.shutdown)
} }
// Set the leader address to the first server. // Set the leader address to the first server.
res.UpdateLeaderAddr(servers[0].addr.String()) srv0 := servers[0].Metadata()
res.UpdateLeaderAddr(srv0.Datacenter, srv0.Addr.String())
conn, err := pool.ClientConnLeader() conn, err := pool.ClientConnLeader()
require.NoError(t, err) require.NoError(t, err)
@ -144,7 +334,8 @@ func TestClientConnPool_ForwardToLeader_Failover(t *testing.T) {
require.Equal(t, first.ServerName, servers[0].name) require.Equal(t, first.ServerName, servers[0].name)
// Update the leader address and make another request. // Update the leader address and make another request.
res.UpdateLeaderAddr(servers[1].addr.String()) srv1 := servers[1].Metadata()
res.UpdateLeaderAddr(srv1.Datacenter, srv1.Addr.String())
resp, err := client.Something(ctx, &testservice.Req{}) resp, err := client.Something(ctx, &testservice.Req{})
require.NoError(t, err) require.NoError(t, err)
@ -162,11 +353,16 @@ func TestClientConnPool_IntegrationWithGRPCResolver_Rebalance(t *testing.T) {
count := 5 count := 5
res := resolver.NewServerResolverBuilder(newConfig(t)) res := resolver.NewServerResolverBuilder(newConfig(t))
registerWithGRPC(t, res) registerWithGRPC(t, res)
pool := NewClientConnPool(res, nil, useTLSForDcAlwaysTrue) pool := NewClientConnPool(ClientConnPoolConfig{
Servers: res,
UseTLSForDC: useTLSForDcAlwaysTrue,
DialingFromServer: true,
DialingFromDatacenter: "dc1",
})
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
name := fmt.Sprintf("server-%d", i) name := fmt.Sprintf("server-%d", i)
srv := newTestServer(t, name, "dc1") srv := newTestServer(t, name, "dc1", nil)
res.AddServer(srv.Metadata()) res.AddServer(srv.Metadata())
t.Cleanup(srv.shutdown) t.Cleanup(srv.shutdown)
} }
@ -211,11 +407,16 @@ func TestClientConnPool_IntegrationWithGRPCResolver_MultiDC(t *testing.T) {
res := resolver.NewServerResolverBuilder(newConfig(t)) res := resolver.NewServerResolverBuilder(newConfig(t))
registerWithGRPC(t, res) registerWithGRPC(t, res)
pool := NewClientConnPool(res, nil, useTLSForDcAlwaysTrue) pool := NewClientConnPool(ClientConnPoolConfig{
Servers: res,
UseTLSForDC: useTLSForDcAlwaysTrue,
DialingFromServer: true,
DialingFromDatacenter: "dc1",
})
for _, dc := range dcs { for _, dc := range dcs {
name := "server-0-" + dc name := "server-0-" + dc
srv := newTestServer(t, name, dc) srv := newTestServer(t, name, dc, nil)
res.AddServer(srv.Metadata()) res.AddServer(srv.Metadata())
t.Cleanup(srv.shutdown) t.Cleanup(srv.shutdown)
} }

View File

@ -67,17 +67,17 @@ func (s *ServerResolverBuilder) NewRebalancer(dc string) func() {
} }
} }
// ServerForAddr returns server metadata for a server with the specified address. // ServerForGlobalAddr returns server metadata for a server with the specified globally unique address.
func (s *ServerResolverBuilder) ServerForAddr(addr string) (*metadata.Server, error) { func (s *ServerResolverBuilder) ServerForGlobalAddr(globalAddr string) (*metadata.Server, error) {
s.lock.RLock() s.lock.RLock()
defer s.lock.RUnlock() defer s.lock.RUnlock()
for _, server := range s.servers { for _, server := range s.servers {
if server.Addr.String() == addr { if DCPrefix(server.Datacenter, server.Addr.String()) == globalAddr {
return server, nil return server, nil
} }
} }
return nil, fmt.Errorf("failed to find Consul server for address %q", addr) return nil, fmt.Errorf("failed to find Consul server for global address %q", globalAddr)
} }
// Build returns a new serverResolver for the given ClientConn. The resolver // Build returns a new serverResolver for the given ClientConn. The resolver
@ -161,6 +161,12 @@ func uniqueID(server *metadata.Server) string {
return server.Datacenter + "-" + server.ID return server.Datacenter + "-" + server.ID
} }
// DCPrefix prefixes the given string with a datacenter for use in
// disambiguation.
func DCPrefix(datacenter, suffix string) string {
return datacenter + "-" + suffix
}
// RemoveServer updates the resolvers' states with the given server removed. // RemoveServer updates the resolvers' states with the given server removed.
func (s *ServerResolverBuilder) RemoveServer(server *metadata.Server) { func (s *ServerResolverBuilder) RemoveServer(server *metadata.Server) {
s.lock.Lock() s.lock.Lock()
@ -186,7 +192,8 @@ func (s *ServerResolverBuilder) getDCAddrs(dc string) []resolver.Address {
} }
addrs = append(addrs, resolver.Address{ addrs = append(addrs, resolver.Address{
Addr: server.Addr.String(), // NOTE: the address persisted here is only dialable using our custom dialer
Addr: DCPrefix(server.Datacenter, server.Addr.String()),
Type: resolver.Backend, Type: resolver.Backend,
ServerName: server.Name, ServerName: server.Name,
}) })
@ -195,11 +202,11 @@ func (s *ServerResolverBuilder) getDCAddrs(dc string) []resolver.Address {
} }
// UpdateLeaderAddr updates the leader address in the local DC's resolver. // UpdateLeaderAddr updates the leader address in the local DC's resolver.
func (s *ServerResolverBuilder) UpdateLeaderAddr(leaderAddr string) { func (s *ServerResolverBuilder) UpdateLeaderAddr(datacenter, addr string) {
s.lock.Lock() s.lock.Lock()
defer s.lock.Unlock() defer s.lock.Unlock()
s.leaderResolver.addr = leaderAddr s.leaderResolver.globalAddr = DCPrefix(datacenter, addr)
s.leaderResolver.updateClientConn() s.leaderResolver.updateClientConn()
} }
@ -262,7 +269,7 @@ func (r *serverResolver) Close() {
func (*serverResolver) ResolveNow(resolver.ResolveNowOption) {} func (*serverResolver) ResolveNow(resolver.ResolveNowOption) {}
type leaderResolver struct { type leaderResolver struct {
addr string globalAddr string
clientConn resolver.ClientConn clientConn resolver.ClientConn
} }
@ -271,12 +278,13 @@ func (l leaderResolver) ResolveNow(resolver.ResolveNowOption) {}
func (l leaderResolver) Close() {} func (l leaderResolver) Close() {}
func (l leaderResolver) updateClientConn() { func (l leaderResolver) updateClientConn() {
if l.addr == "" || l.clientConn == nil { if l.globalAddr == "" || l.clientConn == nil {
return return
} }
addrs := []resolver.Address{ addrs := []resolver.Address{
{ {
Addr: l.addr, // NOTE: the address persisted here is only dialable using our custom dialer
Addr: l.globalAddr,
Type: resolver.Backend, Type: resolver.Backend,
ServerName: "leader", ServerName: "leader",
}, },

View File

@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"io" "io"
"net" "net"
"strconv"
"sync/atomic" "sync/atomic"
"testing" "testing"
"time" "time"
@ -17,6 +18,7 @@ import (
"github.com/hashicorp/consul/agent/grpc/internal/testservice" "github.com/hashicorp/consul/agent/grpc/internal/testservice"
"github.com/hashicorp/consul/agent/metadata" "github.com/hashicorp/consul/agent/metadata"
"github.com/hashicorp/consul/agent/pool" "github.com/hashicorp/consul/agent/pool"
"github.com/hashicorp/consul/sdk/freeport"
"github.com/hashicorp/consul/tlsutil" "github.com/hashicorp/consul/tlsutil"
) )
@ -31,22 +33,29 @@ type testServer struct {
func (s testServer) Metadata() *metadata.Server { func (s testServer) Metadata() *metadata.Server {
return &metadata.Server{ return &metadata.Server{
ID: s.name, ID: s.name,
Name: s.name + "." + s.dc,
ShortName: s.name,
Datacenter: s.dc, Datacenter: s.dc,
Addr: s.addr, Addr: s.addr,
UseTLS: s.rpc.tlsConf != nil, UseTLS: s.rpc.tlsConf != nil,
} }
} }
func newTestServer(t *testing.T, name string, dc string) testServer { func newTestServer(t *testing.T, name string, dc string, tlsConf *tlsutil.Configurator) testServer {
addr := &net.IPAddr{IP: net.ParseIP("127.0.0.1")} addr := &net.IPAddr{IP: net.ParseIP("127.0.0.1")}
handler := NewHandler(addr, func(server *grpc.Server) { handler := NewHandler(addr, func(server *grpc.Server) {
testservice.RegisterSimpleServer(server, &simple{name: name, dc: dc}) testservice.RegisterSimpleServer(server, &simple{name: name, dc: dc})
}) })
lis, err := net.Listen("tcp", "127.0.0.1:0") ports := freeport.MustTake(1)
t.Cleanup(func() {
freeport.Return(ports)
})
lis, err := net.Listen("tcp", net.JoinHostPort("127.0.0.1", strconv.Itoa(ports[0])))
require.NoError(t, err) require.NoError(t, err)
rpc := &fakeRPCListener{t: t, handler: handler} rpc := &fakeRPCListener{t: t, handler: handler, tlsConf: tlsConf}
g := errgroup.Group{} g := errgroup.Group{}
g.Go(func() error { g.Go(func() error {
@ -107,11 +116,12 @@ func (s *simple) Something(_ context.Context, _ *testservice.Req) (*testservice.
// For now, since this logic is in agent/consul, we can't easily use Server.listen // For now, since this logic is in agent/consul, we can't easily use Server.listen
// so we fake it. // so we fake it.
type fakeRPCListener struct { type fakeRPCListener struct {
t *testing.T t *testing.T
handler *Handler handler *Handler
shutdown bool shutdown bool
tlsConf *tlsutil.Configurator tlsConf *tlsutil.Configurator
tlsConnEstablished int32 tlsConnEstablished int32
alpnConnEstablished int32
} }
func (f *fakeRPCListener) listen(listener net.Listener) error { func (f *fakeRPCListener) listen(listener net.Listener) error {
@ -129,6 +139,26 @@ func (f *fakeRPCListener) listen(listener net.Listener) error {
} }
func (f *fakeRPCListener) handleConn(conn net.Conn) { func (f *fakeRPCListener) handleConn(conn net.Conn) {
if f.tlsConf != nil && f.tlsConf.MutualTLSCapable() {
// See if actually this is native TLS multiplexed onto the old
// "type-byte" system.
peekedConn, nativeTLS, err := pool.PeekForTLS(conn)
if err != nil {
if err != io.EOF {
fmt.Printf("ERROR: failed to read first byte: %v\n", err)
}
conn.Close()
return
}
if nativeTLS {
f.handleNativeTLSConn(peekedConn)
return
}
conn = peekedConn
}
buf := make([]byte, 1) buf := make([]byte, 1)
if _, err := conn.Read(buf); err != nil { if _, err := conn.Read(buf); err != nil {
@ -166,3 +196,32 @@ func (f *fakeRPCListener) handleConn(conn net.Conn) {
conn.Close() conn.Close()
} }
} }
func (f *fakeRPCListener) handleNativeTLSConn(conn net.Conn) {
tlscfg := f.tlsConf.IncomingALPNRPCConfig(pool.RPCNextProtos)
tlsConn := tls.Server(conn, tlscfg)
// Force the handshake to conclude.
if err := tlsConn.Handshake(); err != nil {
fmt.Printf("ERROR: TLS handshake failed: %v", err)
conn.Close()
return
}
conn.SetReadDeadline(time.Time{})
var (
cs = tlsConn.ConnectionState()
nextProto = cs.NegotiatedProtocol
)
switch nextProto {
case pool.ALPN_RPCGRPC:
atomic.AddInt32(&f.alpnConnEstablished, 1)
f.handler.Handle(tlsConn)
default:
fmt.Printf("ERROR: discarding RPC for unknown negotiated protocol %q\n", nextProto)
conn.Close()
}
}

View File

@ -20,6 +20,8 @@ func (t RPCType) ALPNString() string {
return ALPN_RPCGossip return ALPN_RPCGossip
case RPCTLSInsecure: case RPCTLSInsecure:
return "" // unsupported return "" // unsupported
case RPCGRPC:
return ALPN_RPCGRPC
default: default:
return "" // unsupported return "" // unsupported
} }
@ -28,19 +30,19 @@ func (t RPCType) ALPNString() string {
const ( const (
// keep numbers unique. // keep numbers unique.
RPCConsul RPCType = 0 RPCConsul RPCType = 0
RPCRaft = 1 RPCRaft RPCType = 1
RPCMultiplex = 2 // Old Muxado byte, no longer supported. RPCMultiplex RPCType = 2 // Old Muxado byte, no longer supported.
RPCTLS = 3 RPCTLS RPCType = 3
RPCMultiplexV2 = 4 RPCMultiplexV2 RPCType = 4
RPCSnapshot = 5 RPCSnapshot RPCType = 5
RPCGossip = 6 RPCGossip RPCType = 6
// RPCTLSInsecure is used to flag RPC calls that require verify // RPCTLSInsecure is used to flag RPC calls that require verify
// incoming to be disabled, even when it is turned on in the // incoming to be disabled, even when it is turned on in the
// configuration. At the time of writing there is only AutoEncrypt.Sign // configuration. At the time of writing there is only AutoEncrypt.Sign
// that is supported and it might be the only one there // that is supported and it might be the only one there
// ever is. // ever is.
RPCTLSInsecure = 7 RPCTLSInsecure RPCType = 7
RPCGRPC = 8 RPCGRPC RPCType = 8
// RPCMaxTypeValue is the maximum rpc type byte value currently used for the // RPCMaxTypeValue is the maximum rpc type byte value currently used for the
// various protocols riding over our "rpc" port. // various protocols riding over our "rpc" port.
@ -79,6 +81,7 @@ var RPCNextProtos = []string{
ALPN_RPCMultiplexV2, ALPN_RPCMultiplexV2,
ALPN_RPCSnapshot, ALPN_RPCSnapshot,
ALPN_RPCGossip, ALPN_RPCGossip,
ALPN_RPCGRPC,
ALPN_WANGossipPacket, ALPN_WANGossipPacket,
ALPN_WANGossipStream, ALPN_WANGossipStream,
} }

View File

@ -10,8 +10,9 @@ import (
"testing" "testing"
"time" "time"
"github.com/hashicorp/consul/tlsutil"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/hashicorp/consul/tlsutil"
) )
func TestPeekForTLS_not_TLS(t *testing.T) { func TestPeekForTLS_not_TLS(t *testing.T) {
@ -30,6 +31,7 @@ func TestPeekForTLS_not_TLS(t *testing.T) {
RPCSnapshot, RPCSnapshot,
RPCGossip, RPCGossip,
RPCTLSInsecure, RPCTLSInsecure,
RPCGRPC,
} { } {
cases = append(cases, testcase{ cases = append(cases, testcase{
name: fmt.Sprintf("tcp rpc type byte %d", rpcType), name: fmt.Sprintf("tcp rpc type byte %d", rpcType),
@ -76,6 +78,7 @@ func TestPeekForTLS_actual_TLS(t *testing.T) {
RPCSnapshot, RPCSnapshot,
RPCGossip, RPCGossip,
RPCTLSInsecure, RPCTLSInsecure,
RPCGRPC,
} { } {
cases = append(cases, testcase{ cases = append(cases, testcase{
name: fmt.Sprintf("tcp rpc type byte %d", rpcType), name: fmt.Sprintf("tcp rpc type byte %d", rpcType),

View File

@ -2,6 +2,7 @@ package pool
import ( import (
"container/list" "container/list"
"context"
"crypto/tls" "crypto/tls"
"fmt" "fmt"
"log" "log"
@ -11,14 +12,15 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc"
"github.com/hashicorp/yamux"
"github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/lib" "github.com/hashicorp/consul/lib"
"github.com/hashicorp/consul/tlsutil" "github.com/hashicorp/consul/tlsutil"
msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc"
"github.com/hashicorp/yamux"
) )
const defaultDialTimeout = 10 * time.Second const DefaultDialTimeout = 10 * time.Second
// muxSession is used to provide an interface for a stream multiplexer. // muxSession is used to provide an interface for a stream multiplexer.
type muxSession interface { type muxSession interface {
@ -291,21 +293,24 @@ func (p *ConnPool) DialTimeout(
) (net.Conn, HalfCloser, error) { ) (net.Conn, HalfCloser, error) {
p.once.Do(p.init) p.once.Do(p.init)
if p.Server && p.GatewayResolver != nil && p.TLSConfigurator != nil && dc != p.Datacenter { if p.Server &&
p.GatewayResolver != nil &&
p.TLSConfigurator != nil &&
dc != p.Datacenter {
// NOTE: TLS is required on this branch. // NOTE: TLS is required on this branch.
return DialTimeoutWithRPCTypeViaMeshGateway( nextProto := actualRPCType.ALPNString()
if nextProto == "" {
return nil, nil, fmt.Errorf("rpc type %d cannot be routed through a mesh gateway", actualRPCType)
}
return DialRPCViaMeshGateway(
context.Background(),
dc, dc,
nodeName, nodeName,
addr,
p.SrcAddr, p.SrcAddr,
p.TLSConfigurator.OutgoingALPNRPCWrapper(), p.TLSConfigurator.OutgoingALPNRPCWrapper(),
actualRPCType, nextProto,
RPCTLS,
// gateway stuff
p.Server, p.Server,
p.TLSConfigurator,
p.GatewayResolver, p.GatewayResolver,
p.Datacenter,
) )
} }
@ -319,7 +324,7 @@ func (p *ConnPool) dial(
tlsRPCType RPCType, tlsRPCType RPCType,
) (net.Conn, HalfCloser, error) { ) (net.Conn, HalfCloser, error) {
// Try to dial the conn // Try to dial the conn
d := &net.Dialer{LocalAddr: p.SrcAddr, Timeout: defaultDialTimeout} d := &net.Dialer{LocalAddr: p.SrcAddr, Timeout: DefaultDialTimeout}
conn, err := d.Dial("tcp", addr.String()) conn, err := d.Dial("tcp", addr.String())
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
@ -372,62 +377,49 @@ func (p *ConnPool) dial(
return conn, hc, nil return conn, hc, nil
} }
// DialTimeoutWithRPCTypeViaMeshGateway dials the destination node and sets up // DialRPCViaMeshGateway dials the destination node and sets up the connection
// the connection to be the correct RPC type using ALPN. This currently is // to be the correct RPC type using ALPN. This currently is exclusively used to
// exclusively used to dial other servers in foreign datacenters via mesh // dial other servers in foreign datacenters via mesh gateways.
// gateways. func DialRPCViaMeshGateway(
// ctx context.Context,
// NOTE: There is a close mirror of this method in agent/consul/wanfed/wanfed.go:dial dc string, // (metadata.Server).Datacenter
func DialTimeoutWithRPCTypeViaMeshGateway( nodeName string, // (metadata.Server).ShortName
dc string, srcAddr *net.TCPAddr,
nodeName string, alpnWrapper tlsutil.ALPNWrapper,
addr net.Addr, nextProto string,
src *net.TCPAddr,
wrapper tlsutil.ALPNWrapper,
actualRPCType RPCType,
tlsRPCType RPCType,
// gateway stuff
dialingFromServer bool, dialingFromServer bool,
tlsConfigurator *tlsutil.Configurator,
gatewayResolver func(string) string, gatewayResolver func(string) string,
thisDatacenter string,
) (net.Conn, HalfCloser, error) { ) (net.Conn, HalfCloser, error) {
if !dialingFromServer { if !dialingFromServer {
return nil, nil, fmt.Errorf("must dial via mesh gateways from a server agent") return nil, nil, fmt.Errorf("must dial via mesh gateways from a server agent")
} else if gatewayResolver == nil { } else if gatewayResolver == nil {
return nil, nil, fmt.Errorf("gatewayResolver is nil") return nil, nil, fmt.Errorf("gatewayResolver is nil")
} else if tlsConfigurator == nil { } else if alpnWrapper == nil {
return nil, nil, fmt.Errorf("tlsConfigurator is nil")
} else if dc == thisDatacenter {
return nil, nil, fmt.Errorf("cannot dial servers in the same datacenter via a mesh gateway")
} else if wrapper == nil {
return nil, nil, fmt.Errorf("cannot dial via a mesh gateway when outgoing TLS is disabled") return nil, nil, fmt.Errorf("cannot dial via a mesh gateway when outgoing TLS is disabled")
} }
nextProto := actualRPCType.ALPNString()
if nextProto == "" {
return nil, nil, fmt.Errorf("rpc type %d cannot be routed through a mesh gateway", actualRPCType)
}
gwAddr := gatewayResolver(dc) gwAddr := gatewayResolver(dc)
if gwAddr == "" { if gwAddr == "" {
return nil, nil, structs.ErrDCNotAvailable return nil, nil, structs.ErrDCNotAvailable
} }
dialer := &net.Dialer{LocalAddr: src, Timeout: defaultDialTimeout} dialer := &net.Dialer{LocalAddr: srcAddr, Timeout: DefaultDialTimeout}
rawConn, err := dialer.Dial("tcp", gwAddr) rawConn, err := dialer.DialContext(ctx, "tcp", gwAddr)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
if tcp, ok := rawConn.(*net.TCPConn); ok { if nextProto != ALPN_RPCGRPC {
_ = tcp.SetKeepAlive(true) // agent/grpc/client.go:dial() handles this in another way for gRPC
_ = tcp.SetNoDelay(true) if tcp, ok := rawConn.(*net.TCPConn); ok {
_ = tcp.SetKeepAlive(true)
_ = tcp.SetNoDelay(true)
}
} }
// NOTE: now we wrap the connection in a TLS client. // NOTE: now we wrap the connection in a TLS client.
tlsConn, err := wrapper(dc, nodeName, nextProto, rawConn) tlsConn, err := alpnWrapper(dc, nodeName, nextProto, rawConn)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }

View File

@ -106,9 +106,22 @@ func NewBaseDeps(configLoader ConfigLoader, logOut io.Writer) (BaseDeps, error)
d.ViewStore = submatview.NewStore(d.Logger.Named("viewstore")) d.ViewStore = submatview.NewStore(d.Logger.Named("viewstore"))
d.ConnPool = newConnPool(cfg, d.Logger, d.TLSConfigurator) d.ConnPool = newConnPool(cfg, d.Logger, d.TLSConfigurator)
builder := resolver.NewServerResolverBuilder(resolver.Config{}) builder := resolver.NewServerResolverBuilder(resolver.Config{
// Set the authority to something sufficiently unique so any usage in
// tests would be self-isolating in the global resolver map, while also
// not incurring a huge penalty for non-test code.
Authority: cfg.Datacenter + "." + string(cfg.NodeID),
})
resolver.Register(builder) resolver.Register(builder)
d.GRPCConnPool = grpc.NewClientConnPool(builder, grpc.TLSWrapper(d.TLSConfigurator.OutgoingRPCWrapper()), d.TLSConfigurator.UseTLS) d.GRPCConnPool = grpc.NewClientConnPool(grpc.ClientConnPoolConfig{
Servers: builder,
SrcAddr: d.ConnPool.SrcAddr,
TLSWrapper: grpc.TLSWrapper(d.TLSConfigurator.OutgoingRPCWrapper()),
ALPNWrapper: grpc.ALPNWrapper(d.TLSConfigurator.OutgoingALPNRPCWrapper()),
UseTLSForDC: d.TLSConfigurator.UseTLS,
DialingFromServer: cfg.ServerMode,
DialingFromDatacenter: cfg.Datacenter,
})
d.LeaderForwarder = builder d.LeaderForwarder = builder
d.Router = router.NewRouter(d.Logger, cfg.Datacenter, fmt.Sprintf("%s.%s", cfg.NodeName, cfg.Datacenter), builder) d.Router = router.NewRouter(d.Logger, cfg.Datacenter, fmt.Sprintf("%s.%s", cfg.NodeName, cfg.Datacenter), builder)