grpc: ensure that streaming gRPC requests work over mesh gateway based wan federation (#10838)

Fixes #10796
This commit is contained in:
R.B. Boyer 2021-08-24 16:28:44 -05:00 committed by GitHub
parent 6b574abc89
commit a84f5fa25d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 589 additions and 183 deletions

3
.changelog/10838.txt Normal file
View File

@ -0,0 +1,3 @@
```release-note:bug
grpc: ensure that streaming gRPC requests work over mesh gateway based wan federation
```

View File

@ -371,6 +371,9 @@ func (f fakeGRPCConnPool) ClientConnLeader() (*grpc.ClientConn, error) {
return nil, nil
}
// SetGatewayResolver is a no-op: the fake pool discards the resolver it is
// handed. NOTE(review): presumably present so the fake satisfies an interface
// that requires this method; confirm against the interface definition.
func (f fakeGRPCConnPool) SetGatewayResolver(_ func(string) string) {
}
func TestAgent_ReconnectConfigWanDisabled(t *testing.T) {
if testing.Short() {
t.Skip("too slow for testing.Short")
@ -4524,6 +4527,9 @@ LOOP:
}
// This is a mirror of a similar test in agent/consul/server_test.go
//
// TODO(rb): implement something similar to this as a full containerized test suite with proper
// isolation so requests can't "cheat" and bypass the mesh gateways
func TestAgent_JoinWAN_viaMeshGateway(t *testing.T) {
if testing.Short() {
t.Skip("too slow for testing.Short")
@ -4771,6 +4777,9 @@ func TestAgent_JoinWAN_viaMeshGateway(t *testing.T) {
})
// Ensure we can do some trivial RPC in all directions.
//
// NOTE: we explicitly make streaming and non-streaming assertions here to
// verify both rpc and grpc codepaths.
agents := map[string]*TestAgent{"dc1": a1, "dc2": a2, "dc3": a3}
names := map[string]string{"dc1": "bob", "dc2": "betty", "dc3": "bonnie"}
for _, srcDC := range []string{"dc1", "dc2", "dc3"} {
@ -4780,20 +4789,39 @@ func TestAgent_JoinWAN_viaMeshGateway(t *testing.T) {
continue
}
t.Run(srcDC+" to "+dstDC, func(t *testing.T) {
req, err := http.NewRequest("GET", "/v1/catalog/nodes?dc="+dstDC, nil)
require.NoError(t, err)
t.Run("normal-rpc", func(t *testing.T) {
req, err := http.NewRequest("GET", "/v1/catalog/nodes?dc="+dstDC, nil)
require.NoError(t, err)
resp := httptest.NewRecorder()
obj, err := a.srv.CatalogNodes(resp, req)
require.NoError(t, err)
require.NotNil(t, obj)
resp := httptest.NewRecorder()
obj, err := a.srv.CatalogNodes(resp, req)
require.NoError(t, err)
require.NotNil(t, obj)
nodes, ok := obj.(structs.Nodes)
require.True(t, ok)
require.Len(t, nodes, 1)
node := nodes[0]
require.Equal(t, dstDC, node.Datacenter)
require.Equal(t, names[dstDC], node.Node)
nodes, ok := obj.(structs.Nodes)
require.True(t, ok)
require.Len(t, nodes, 1)
node := nodes[0]
require.Equal(t, dstDC, node.Datacenter)
require.Equal(t, names[dstDC], node.Node)
})
t.Run("streaming-grpc", func(t *testing.T) {
req, err := http.NewRequest("GET", "/v1/health/service/consul?cached&dc="+dstDC, nil)
require.NoError(t, err)
resp := httptest.NewRecorder()
obj, err := a.srv.HealthServiceNodes(resp, req)
require.NoError(t, err)
require.NotNil(t, obj)
csns, ok := obj.(structs.CheckServiceNodes)
require.True(t, ok)
require.Len(t, csns, 1)
csn := csns[0]
require.Equal(t, dstDC, csn.Node.Datacenter)
require.Equal(t, names[dstDC], csn.Node.Node)
})
})
}
}

View File

@ -5,10 +5,17 @@ import (
"fmt"
"net"
"os"
"strings"
"sync"
"testing"
"time"
"github.com/hashicorp/go-hclog"
msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc"
"github.com/hashicorp/serf/serf"
"github.com/stretchr/testify/require"
"golang.org/x/time/rate"
"github.com/hashicorp/consul/agent/grpc"
"github.com/hashicorp/consul/agent/grpc/resolver"
"github.com/hashicorp/consul/agent/pool"
@ -20,11 +27,6 @@ import (
"github.com/hashicorp/consul/sdk/testutil/retry"
"github.com/hashicorp/consul/testrpc"
"github.com/hashicorp/consul/tlsutil"
"github.com/hashicorp/go-hclog"
msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc"
"github.com/hashicorp/serf/serf"
"github.com/stretchr/testify/require"
"golang.org/x/time/rate"
)
func testClientConfig(t *testing.T) (string, *Config) {
@ -490,6 +492,13 @@ func newClient(t *testing.T, config *Config) *Client {
return client
}
// newTestResolverConfig returns a resolver.Config whose Authority is built
// from the current test's name: every "/" and "_" is removed, the result is
// lowercased, and "-"+suffix is appended.
func newTestResolverConfig(t *testing.T, suffix string) resolver.Config {
	// Strip both separator characters in a single pass over the test name.
	sanitized := strings.NewReplacer("/", "", "_", "").Replace(t.Name())
	authority := strings.ToLower(sanitized) + "-" + suffix
	return resolver.Config{Authority: authority}
}
func newDefaultDeps(t *testing.T, c *Config) Deps {
t.Helper()
@ -502,7 +511,7 @@ func newDefaultDeps(t *testing.T, c *Config) Deps {
tls, err := tlsutil.NewConfigurator(c.TLSConfig, logger)
require.NoError(t, err, "failed to create tls configuration")
builder := resolver.NewServerResolverBuilder(resolver.Config{Authority: c.NodeName})
builder := resolver.NewServerResolverBuilder(newTestResolverConfig(t, c.NodeName+"-"+c.Datacenter))
r := router.NewRouter(logger, c.Datacenter, fmt.Sprintf("%s.%s", c.NodeName, c.Datacenter), builder)
resolver.Register(builder)
@ -522,7 +531,13 @@ func newDefaultDeps(t *testing.T, c *Config) Deps {
Tokens: new(token.Store),
Router: r,
ConnPool: connPool,
GRPCConnPool: grpc.NewClientConnPool(builder, grpc.TLSWrapper(tls.OutgoingRPCWrapper()), tls.UseTLS),
GRPCConnPool: grpc.NewClientConnPool(grpc.ClientConnPoolConfig{
Servers: builder,
TLSWrapper: grpc.TLSWrapper(tls.OutgoingRPCWrapper()),
UseTLSForDC: tls.UseTLS,
DialingFromServer: true,
DialingFromDatacenter: c.Datacenter,
}),
LeaderForwarder: builder,
EnterpriseDeps: newDefaultDepsEnterprise(t, logger, c),
}

View File

@ -24,8 +24,10 @@ type Deps struct {
type GRPCClientConner interface {
ClientConn(datacenter string) (*grpc.ClientConn, error)
ClientConnLeader() (*grpc.ClientConn, error)
SetGatewayResolver(func(string) string)
}
type LeaderForwarder interface {
UpdateLeaderAddr(leaderAddr string)
// UpdateLeaderAddr updates the leader address in the local DC's resolver.
UpdateLeaderAddr(datacenter, addr string)
}

View File

@ -293,7 +293,7 @@ func (s *Server) handleNativeTLS(conn net.Conn) {
s.handleSnapshotConn(tlsConn)
case pool.ALPN_RPCGRPC:
s.grpcHandler.Handle(conn)
s.grpcHandler.Handle(tlsConn)
case pool.ALPN_WANGossipPacket:
if err := s.handleALPN_WANGossipPacketStream(tlsConn); err != nil && err != io.EOF {
@ -373,7 +373,7 @@ func (s *Server) handleMultiplexV2(conn net.Conn) {
}
sub = peeked
switch first {
case pool.RPCGossip:
case byte(pool.RPCGossip):
buf := make([]byte, 1)
sub.Read(buf)
go func() {

View File

@ -460,7 +460,7 @@ func TestRPC_TLSHandshakeTimeout(t *testing.T) {
// Write TLS byte to avoid being closed by either the (outer) first byte
// timeout or the fact that server requires TLS
_, err = conn.Write([]byte{pool.RPCTLS})
_, err = conn.Write([]byte{byte(pool.RPCTLS)})
require.NoError(t, err)
// Wait for more than the timeout before we start a TLS handshake. This is

View File

@ -173,6 +173,9 @@ type Server struct {
// Connection pool to other consul servers
connPool *pool.ConnPool
// Connection pool to other consul servers using gRPC
grpcConnPool GRPCClientConner
// eventChLAN is used to receive events from the
// serf cluster in the datacenter
eventChLAN chan serf.Event
@ -348,6 +351,7 @@ func NewServer(config *Config, flat Deps) (*Server, error) {
config: config,
tokens: flat.Tokens,
connPool: flat.ConnPool,
grpcConnPool: flat.GRPCConnPool,
eventChLAN: make(chan serf.Event, serfEventChSize),
eventChWAN: make(chan serf.Event, serfEventChSize),
logger: serverLogger,
@ -377,6 +381,7 @@ func NewServer(config *Config, flat Deps) (*Server, error) {
s.config.PrimaryDatacenter,
)
s.connPool.GatewayResolver = s.gatewayLocator.PickGateway
s.grpcConnPool.SetGatewayResolver(s.gatewayLocator.PickGateway)
}
// Initialize enterprise specific server functionality
@ -1461,7 +1466,7 @@ func (s *Server) trackLeaderChanges() {
continue
}
s.grpcLeaderForwarder.UpdateLeaderAddr(string(leaderObs.Leader))
s.grpcLeaderForwarder.UpdateLeaderAddr(s.config.Datacenter, string(leaderObs.Leader))
case <-s.shutdownCh:
s.raft.DeregisterObserver(observer)
return

View File

@ -25,6 +25,8 @@ import (
func TestSubscribeBackend_IntegrationWithServer_TLSEnabled(t *testing.T) {
t.Parallel()
// TODO(rb): add tests for the wanfed/alpn variations
_, conf1 := testServerConfig(t)
conf1.TLSConfig.VerifyIncoming = true
conf1.TLSConfig.VerifyOutgoing = true
@ -60,7 +62,13 @@ func TestSubscribeBackend_IntegrationWithServer_TLSEnabled(t *testing.T) {
// Start a Subscribe call to our streaming endpoint from the client.
{
pool := grpc.NewClientConnPool(builder, grpc.TLSWrapper(client.tlsConfigurator.OutgoingRPCWrapper()), client.tlsConfigurator.UseTLS)
pool := grpc.NewClientConnPool(grpc.ClientConnPoolConfig{
Servers: builder,
TLSWrapper: grpc.TLSWrapper(client.tlsConfigurator.OutgoingRPCWrapper()),
UseTLSForDC: client.tlsConfigurator.UseTLS,
DialingFromServer: true,
DialingFromDatacenter: "dc1",
})
conn, err := pool.ClientConn("dc1")
require.NoError(t, err)
@ -91,8 +99,13 @@ func TestSubscribeBackend_IntegrationWithServer_TLSEnabled(t *testing.T) {
// Start a Subscribe call to our streaming endpoint from the server's loopback client.
{
pool := grpc.NewClientConnPool(builder, grpc.TLSWrapper(client.tlsConfigurator.OutgoingRPCWrapper()), client.tlsConfigurator.UseTLS)
pool := grpc.NewClientConnPool(grpc.ClientConnPoolConfig{
Servers: builder,
TLSWrapper: grpc.TLSWrapper(client.tlsConfigurator.OutgoingRPCWrapper()),
UseTLSForDC: client.tlsConfigurator.UseTLS,
DialingFromServer: true,
DialingFromDatacenter: "dc1",
})
conn, err := pool.ClientConn("dc1")
require.NoError(t, err)
@ -166,7 +179,13 @@ func TestSubscribeBackend_IntegrationWithServer_TLSReload(t *testing.T) {
// Subscribe calls should fail initially
joinLAN(t, client, server)
pool := grpc.NewClientConnPool(builder, grpc.TLSWrapper(client.tlsConfigurator.OutgoingRPCWrapper()), client.tlsConfigurator.UseTLS)
pool := grpc.NewClientConnPool(grpc.ClientConnPoolConfig{
Servers: builder,
TLSWrapper: grpc.TLSWrapper(client.tlsConfigurator.OutgoingRPCWrapper()),
UseTLSForDC: client.tlsConfigurator.UseTLS,
DialingFromServer: true,
DialingFromDatacenter: "dc1",
})
conn, err := pool.ClientConn("dc1")
require.NoError(t, err)
@ -294,7 +313,13 @@ func TestSubscribeBackend_IntegrationWithServer_DeliversAllMessages(t *testing.T
}
}()
pool := grpc.NewClientConnPool(builder, grpc.TLSWrapper(client.tlsConfigurator.OutgoingRPCWrapper()), client.tlsConfigurator.UseTLS)
pool := grpc.NewClientConnPool(grpc.ClientConnPoolConfig{
Servers: builder,
TLSWrapper: grpc.TLSWrapper(client.tlsConfigurator.OutgoingRPCWrapper()),
UseTLSForDC: client.tlsConfigurator.UseTLS,
DialingFromServer: true,
DialingFromDatacenter: "dc1",
})
conn, err := pool.ClientConn("dc1")
require.NoError(t, err)
@ -337,7 +362,7 @@ func TestSubscribeBackend_IntegrationWithServer_DeliversAllMessages(t *testing.T
}
func newClientWithGRPCResolver(t *testing.T, ops ...func(*Config)) (*Client, *resolver.ServerResolverBuilder) {
builder := resolver.NewServerResolverBuilder(resolver.Config{Authority: t.Name()})
builder := resolver.NewServerResolverBuilder(newTestResolverConfig(t, "client"))
resolver.Register(builder)
t.Cleanup(func() {
resolver.Deregister(builder.Authority())

View File

@ -1,6 +1,7 @@
package wanfed
import (
"context"
"encoding/binary"
"errors"
"fmt"
@ -11,7 +12,6 @@ import (
"github.com/hashicorp/memberlist"
"github.com/hashicorp/consul/agent/pool"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/tlsutil"
)
@ -97,13 +97,8 @@ func (t *Transport) WriteToAddress(b []byte, addr memberlist.Address) (time.Time
}
if dc != t.datacenter {
gwAddr := t.gwResolver(dc)
if gwAddr == "" {
return time.Time{}, structs.ErrDCNotAvailable
}
dialFunc := func() (net.Conn, error) {
return t.dial(dc, node, pool.ALPN_WANGossipPacket, gwAddr)
return t.dial(dc, node, pool.ALPN_WANGossipPacket)
}
conn, err := t.pool.AcquireOrDial(addr.Name, dialFunc)
if err != nil {
@ -136,42 +131,24 @@ func (t *Transport) DialAddressTimeout(addr memberlist.Address, timeout time.Dur
}
if dc != t.datacenter {
gwAddr := t.gwResolver(dc)
if gwAddr == "" {
return nil, structs.ErrDCNotAvailable
}
return t.dial(dc, node, pool.ALPN_WANGossipStream, gwAddr)
return t.dial(dc, node, pool.ALPN_WANGossipStream)
}
return t.IngestionAwareTransport.DialAddressTimeout(addr, timeout)
}
// NOTE: There is a close mirror of this method in agent/pool/pool.go:DialTimeoutWithRPCType
func (t *Transport) dial(dc, nodeName, nextProto, addr string) (net.Conn, error) {
wrapper := t.tlsConfigurator.OutgoingALPNRPCWrapper()
if wrapper == nil {
return nil, fmt.Errorf("wanfed: cannot dial via a mesh gateway when outgoing TLS is disabled")
}
dialer := &net.Dialer{Timeout: 10 * time.Second}
rawConn, err := dialer.Dial("tcp", addr)
if err != nil {
return nil, err
}
if tcp, ok := rawConn.(*net.TCPConn); ok {
_ = tcp.SetKeepAlive(true)
_ = tcp.SetNoDelay(true)
}
tlsConn, err := wrapper(dc, nodeName, nextProto, rawConn)
if err != nil {
return nil, err
}
return tlsConn, nil
// dial connects to nodeName in datacenter dc through a mesh gateway selected
// by t.gwResolver, wrapping the connection with the transport's outgoing ALPN
// wrapper and requesting nextProto. The middle result of
// pool.DialRPCViaMeshGateway is discarded.
func (t *Transport) dial(dc, nodeName, nextProto string) (net.Conn, error) {
	// NOTE(review): this flag looks like "dialing from a server", judging by
	// how the gRPC dialer passes its own flag in the same position; confirm
	// against pool.DialRPCViaMeshGateway.
	const dialingFromServer = true

	alpnWrapper := t.tlsConfigurator.OutgoingALPNRPCWrapper()

	conn, _, err := pool.DialRPCViaMeshGateway(
		context.Background(),
		dc,
		nodeName,
		nil, // TODO(rb): thread source address through here?
		alpnWrapper,
		nextProto,
		dialingFromServer,
		t.gwResolver,
	)
	return conn, err
}
// SplitNodeName splits a node name as it would be represented in

View File

@ -12,38 +12,93 @@ import (
"github.com/hashicorp/consul/agent/metadata"
"github.com/hashicorp/consul/agent/pool"
"github.com/hashicorp/consul/tlsutil"
)
// ClientConnPool creates and stores a connection for each datacenter.
type ClientConnPool struct {
dialer dialer
servers ServerLocator
conns map[string]*grpc.ClientConn
connsLock sync.Mutex
dialer dialer
servers ServerLocator
gwResolverDep gatewayResolverDep
conns map[string]*grpc.ClientConn
connsLock sync.Mutex
}
type ServerLocator interface {
// ServerForAddr is used to look up server metadata from an address.
ServerForAddr(addr string) (*metadata.Server, error)
// ServerForGlobalAddr returns server metadata for a server with the specified globally unique address.
ServerForGlobalAddr(globalAddr string) (*metadata.Server, error)
// Authority returns the target authority to use to dial the server. This is primarily
// needed for testing multiple agents in parallel, because gRPC requires the
// resolver to be registered globally.
Authority() string
}
// gatewayResolverDep is just a holder for a function pointer that can be
// updated lazily after the structs are instantiated (but before first use)
// and all structs with a reference to this struct will see the same update.
type gatewayResolverDep struct {
// GatewayResolver is a function that returns a suitable random mesh
// gateway address for dialing servers in a given DC. This is only
// needed if wan federation via mesh gateways is enabled.
//
// The dialer only takes the mesh gateway path when this field is non-nil.
GatewayResolver func(string) string
}
// TLSWrapper wraps a non-TLS connection and returns a connection with TLS
// enabled.
type TLSWrapper func(dc string, conn net.Conn) (net.Conn, error)
// ALPNWrapper is a function that is used to wrap a non-TLS connection and
// returns an appropriate TLS connection or error. This takes a datacenter and
// node name as argument to configure the desired SNI value and the desired
// next proto for configuring ALPN.
type ALPNWrapper func(dc, nodeName, alpnProto string, conn net.Conn) (net.Conn, error)
type dialer func(context.Context, string) (net.Conn, error)
// NewClientConnPool create new GRPC client pool to connect to servers using GRPC over RPC
func NewClientConnPool(servers ServerLocator, tls TLSWrapper, useTLSForDC func(dc string) bool) *ClientConnPool {
return &ClientConnPool{
dialer: newDialer(servers, tls, useTLSForDC),
servers: servers,
// ClientConnPoolConfig holds the settings that NewClientConnPool uses to build
// a ClientConnPool and its dialer.
type ClientConnPoolConfig struct {
// Servers is a reference for how to figure out how to dial any server.
Servers ServerLocator
// SrcAddr is the source address for outgoing connections.
SrcAddr *net.TCPAddr
// TLSWrapper is the specifics of wrapping a socket when doing a TYPE_BYTE+TLS
// wrapped RPC request.
TLSWrapper TLSWrapper
// ALPNWrapper is the specifics of wrapping a socket when doing an ALPN+TLS
// wrapped RPC request (typically only for wan federation via mesh
// gateways).
ALPNWrapper ALPNWrapper
// UseTLSForDC is a function to determine if dialing a given datacenter
// should use TLS.
UseTLSForDC func(dc string) bool
// DialingFromServer should be set to true if this connection pool is owned
// by a consul server instance.
DialingFromServer bool
// DialingFromDatacenter is the datacenter of the consul agent using this
// pool.
DialingFromDatacenter string
}
// NewClientConnPool creates a new gRPC client pool for connecting to servers
// using gRPC over RPC.
//
// The dialer is handed a pointer to the pool's own gateway resolver holder, so
// a resolver installed later through SetGatewayResolver is visible to it.
func NewClientConnPool(cfg ClientConnPoolConfig) *ClientConnPool {
	p := new(ClientConnPool)
	p.servers = cfg.Servers
	p.conns = make(map[string]*grpc.ClientConn)
	p.dialer = newDialer(cfg, &p.gwResolverDep)
	return p
}
// SetGatewayResolver installs the function used to pick a mesh gateway
// address. It is only to be called during setup before the pool is used: the
// field is assigned directly, without taking any lock.
func (c *ClientConnPool) SetGatewayResolver(gatewayResolver func(string) string) {
c.gwResolverDep.GatewayResolver = gatewayResolver
}
// ClientConn returns a grpc.ClientConn for the datacenter. If there are no
@ -102,22 +157,39 @@ func (c *ClientConnPool) dial(datacenter string, serverType string) (*grpc.Clien
// newDialer returns a gRPC dialer function that conditionally wraps the connection
// with TLS based on the Server.useTLS value.
func newDialer(servers ServerLocator, wrapper TLSWrapper, useTLSForDC func(dc string) bool) func(context.Context, string) (net.Conn, error) {
return func(ctx context.Context, addr string) (net.Conn, error) {
d := net.Dialer{}
conn, err := d.DialContext(ctx, "tcp", addr)
func newDialer(cfg ClientConnPoolConfig, gwResolverDep *gatewayResolverDep) func(context.Context, string) (net.Conn, error) {
return func(ctx context.Context, globalAddr string) (net.Conn, error) {
server, err := cfg.Servers.ServerForGlobalAddr(globalAddr)
if err != nil {
return nil, err
}
server, err := servers.ServerForAddr(addr)
if cfg.DialingFromServer &&
gwResolverDep.GatewayResolver != nil &&
cfg.ALPNWrapper != nil &&
server.Datacenter != cfg.DialingFromDatacenter {
// NOTE: TLS is required on this branch.
conn, _, err := pool.DialRPCViaMeshGateway(
ctx,
server.Datacenter,
server.ShortName,
cfg.SrcAddr,
tlsutil.ALPNWrapper(cfg.ALPNWrapper),
pool.ALPN_RPCGRPC,
cfg.DialingFromServer,
gwResolverDep.GatewayResolver,
)
return conn, err
}
d := net.Dialer{LocalAddr: cfg.SrcAddr, Timeout: pool.DefaultDialTimeout}
conn, err := d.DialContext(ctx, "tcp", server.Addr.String())
if err != nil {
conn.Close()
return nil, err
}
if server.UseTLS && useTLSForDC(server.Datacenter) {
if wrapper == nil {
if server.UseTLS && cfg.UseTLSForDC(server.Datacenter) {
if cfg.TLSWrapper == nil {
conn.Close()
return nil, fmt.Errorf("TLS enabled but got nil TLS wrapper")
}
@ -129,7 +201,7 @@ func newDialer(servers ServerLocator, wrapper TLSWrapper, useTLSForDC func(dc st
}
// Wrap the connection in a TLS client
tlsConn, err := wrapper(server.Datacenter, conn)
tlsConn, err := cfg.TLSWrapper(server.Datacenter, conn)
if err != nil {
conn.Close()
return nil, err
@ -137,7 +209,7 @@ func newDialer(servers ServerLocator, wrapper TLSWrapper, useTLSForDC func(dc st
conn = tlsConn
}
_, err = conn.Write([]byte{pool.RPCGRPC})
_, err = conn.Write([]byte{byte(pool.RPCGRPC)})
if err != nil {
conn.Close()
return nil, err

View File

@ -4,17 +4,22 @@ import (
"context"
"fmt"
"net"
"strconv"
"strings"
"sync/atomic"
"testing"
"time"
"github.com/google/tcpproxy"
"github.com/hashicorp/go-hclog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/hashicorp/consul/agent/grpc/internal/testservice"
"github.com/hashicorp/consul/agent/grpc/resolver"
"github.com/hashicorp/consul/agent/metadata"
"github.com/hashicorp/consul/ipaddr"
"github.com/hashicorp/consul/sdk/freeport"
"github.com/hashicorp/consul/tlsutil"
)
@ -24,11 +29,14 @@ func useTLSForDcAlwaysTrue(_ string) bool {
}
func TestNewDialer_WithTLSWrapper(t *testing.T) {
lis, err := net.Listen("tcp", "127.0.0.1:0")
ports := freeport.MustTake(1)
defer freeport.Return(ports)
lis, err := net.Listen("tcp", net.JoinHostPort("127.0.0.1", strconv.Itoa(ports[0])))
require.NoError(t, err)
t.Cleanup(logError(t, lis.Close))
builder := resolver.NewServerResolverBuilder(resolver.Config{})
builder := resolver.NewServerResolverBuilder(newConfig(t))
builder.AddServer(&metadata.Server{
Name: "server-1",
ID: "ID1",
@ -42,19 +50,107 @@ func TestNewDialer_WithTLSWrapper(t *testing.T) {
called = true
return conn, nil
}
dial := newDialer(builder, wrapper, useTLSForDcAlwaysTrue)
dial := newDialer(
ClientConnPoolConfig{
Servers: builder,
TLSWrapper: wrapper,
UseTLSForDC: useTLSForDcAlwaysTrue,
DialingFromServer: true,
DialingFromDatacenter: "dc1",
},
&gatewayResolverDep{},
)
ctx := context.Background()
conn, err := dial(ctx, lis.Addr().String())
conn, err := dial(ctx, resolver.DCPrefix("dc1", lis.Addr().String()))
require.NoError(t, err)
require.NoError(t, conn.Close())
require.True(t, called, "expected TLSWrapper to be called")
}
// TestNewDialer_WithALPNWrapper checks that dialing a server in a different
// datacenter from the one the pool dials from (dc2 vs dc1), with a gateway
// resolver and an ALPN wrapper configured, invokes the ALPN wrapper and not
// the plain TLS wrapper.
func TestNewDialer_WithALPNWrapper(t *testing.T) {
ports := freeport.MustTake(3)
defer freeport.Return(ports)
var (
s1addr = ipaddr.FormatAddressPort("127.0.0.1", ports[0])
s2addr = ipaddr.FormatAddressPort("127.0.0.1", ports[1])
gwAddr = ipaddr.FormatAddressPort("127.0.0.1", ports[2])
)
lis1, err := net.Listen("tcp", s1addr)
require.NoError(t, err)
t.Cleanup(logError(t, lis1.Close))
lis2, err := net.Listen("tcp", s2addr)
require.NoError(t, err)
t.Cleanup(logError(t, lis2.Close))
// Send all of the traffic to dc2's server
// (the proxy listening on gwAddr stands in for a mesh gateway).
var p tcpproxy.Proxy
p.AddRoute(gwAddr, tcpproxy.To(s2addr))
p.AddStopACMESearch(gwAddr)
require.NoError(t, p.Start())
defer func() {
p.Close()
p.Wait()
}()
builder := resolver.NewServerResolverBuilder(newConfig(t))
builder.AddServer(&metadata.Server{
Name: "server-1",
ID: "ID1",
Datacenter: "dc1",
Addr: lis1.Addr(),
UseTLS: true,
})
builder.AddServer(&metadata.Server{
Name: "server-2",
ID: "ID2",
Datacenter: "dc2",
Addr: lis2.Addr(),
UseTLS: true,
})
// Both wrappers pass the connection through untouched and only record
// that they were called.
var calledTLS bool
wrapperTLS := func(_ string, conn net.Conn) (net.Conn, error) {
calledTLS = true
return conn, nil
}
var calledALPN bool
wrapperALPN := func(_, _, _ string, conn net.Conn) (net.Conn, error) {
calledALPN = true
return conn, nil
}
// The resolver ignores its argument and always returns the proxy address.
gwResolverDep := &gatewayResolverDep{
GatewayResolver: func(addr string) string {
return gwAddr
},
}
dial := newDialer(
ClientConnPoolConfig{
Servers: builder,
TLSWrapper: wrapperTLS,
ALPNWrapper: wrapperALPN,
UseTLSForDC: useTLSForDcAlwaysTrue,
DialingFromServer: true,
DialingFromDatacenter: "dc1",
},
gwResolverDep,
)
ctx := context.Background()
// The dialer is addressed with the datacenter-prefixed form of the
// server address.
conn, err := dial(ctx, resolver.DCPrefix("dc2", lis2.Addr().String()))
require.NoError(t, err)
require.NoError(t, conn.Close())
assert.False(t, calledTLS, "expected TLSWrapper not to be called")
assert.True(t, calledALPN, "expected ALPNWrapper to be called")
}
func TestNewDialer_IntegrationWithTLSEnabledHandler(t *testing.T) {
res := resolver.NewServerResolverBuilder(newConfig(t))
registerWithGRPC(t, res)
srv := newTestServer(t, "server-1", "dc1")
tlsConf, err := tlsutil.NewConfigurator(tlsutil.Config{
VerifyIncoming: true,
VerifyOutgoing: true,
@ -63,12 +159,20 @@ func TestNewDialer_IntegrationWithTLSEnabledHandler(t *testing.T) {
KeyFile: "../../test/hostname/Alice.key",
}, hclog.New(nil))
require.NoError(t, err)
srv.rpc.tlsConf = tlsConf
res.AddServer(srv.Metadata())
srv := newTestServer(t, "server-1", "dc1", tlsConf)
md := srv.Metadata()
res.AddServer(md)
t.Cleanup(srv.shutdown)
pool := NewClientConnPool(res, TLSWrapper(tlsConf.OutgoingRPCWrapper()), tlsConf.UseTLS)
pool := NewClientConnPool(ClientConnPoolConfig{
Servers: res,
TLSWrapper: TLSWrapper(tlsConf.OutgoingRPCWrapper()),
UseTLSForDC: tlsConf.UseTLS,
DialingFromServer: true,
DialingFromDatacenter: "dc1",
})
conn, err := pool.ClientConn("dc1")
require.NoError(t, err)
@ -81,17 +185,98 @@ func TestNewDialer_IntegrationWithTLSEnabledHandler(t *testing.T) {
require.NoError(t, err)
require.Equal(t, "server-1", resp.ServerName)
require.True(t, atomic.LoadInt32(&srv.rpc.tlsConnEstablished) > 0)
require.True(t, atomic.LoadInt32(&srv.rpc.alpnConnEstablished) == 0)
}
// TestNewDialer_IntegrationWithTLSEnabledHandler_viaMeshGateway makes a gRPC
// call from a pool dialing from dc2 to a TLS-enabled test server in dc1, with
// a TCP proxy standing in for the mesh gateway. It asserts that the server
// counted an ALPN connection and no plain TLS connection.
func TestNewDialer_IntegrationWithTLSEnabledHandler_viaMeshGateway(t *testing.T) {
ports := freeport.MustTake(1)
defer freeport.Return(ports)
gwAddr := ipaddr.FormatAddressPort("127.0.0.1", ports[0])
res := resolver.NewServerResolverBuilder(newConfig(t))
registerWithGRPC(t, res)
// Server-side TLS configuration, using the "Bob" certificate.
tlsConf, err := tlsutil.NewConfigurator(tlsutil.Config{
VerifyIncoming: true,
VerifyOutgoing: true,
VerifyServerHostname: true,
CAFile: "../../test/hostname/CertAuth.crt",
CertFile: "../../test/hostname/Bob.crt",
KeyFile: "../../test/hostname/Bob.key",
Domain: "consul",
NodeName: "bob",
}, hclog.New(nil))
require.NoError(t, err)
srv := newTestServer(t, "bob", "dc1", tlsConf)
// Send all of the traffic to dc1's server
var p tcpproxy.Proxy
p.AddRoute(gwAddr, tcpproxy.To(srv.addr.String()))
p.AddStopACMESearch(gwAddr)
require.NoError(t, p.Start())
defer func() {
p.Close()
p.Wait()
}()
md := srv.Metadata()
res.AddServer(md)
t.Cleanup(srv.shutdown)
// Dialing-side TLS configuration, using the "Betty" certificate.
clientTLSConf, err := tlsutil.NewConfigurator(tlsutil.Config{
VerifyIncoming: true,
VerifyOutgoing: true,
VerifyServerHostname: true,
CAFile: "../../test/hostname/CertAuth.crt",
CertFile: "../../test/hostname/Betty.crt",
KeyFile: "../../test/hostname/Betty.key",
Domain: "consul",
NodeName: "betty",
}, hclog.New(nil))
require.NoError(t, err)
// The pool dials from dc2 while the only registered server is in dc1.
pool := NewClientConnPool(ClientConnPoolConfig{
Servers: res,
TLSWrapper: TLSWrapper(clientTLSConf.OutgoingRPCWrapper()),
ALPNWrapper: ALPNWrapper(clientTLSConf.OutgoingALPNRPCWrapper()),
UseTLSForDC: tlsConf.UseTLS,
DialingFromServer: true,
DialingFromDatacenter: "dc2",
})
// The resolver ignores its argument and always returns the proxy address.
pool.SetGatewayResolver(func(addr string) string {
return gwAddr
})
conn, err := pool.ClientConn("dc1")
require.NoError(t, err)
client := testservice.NewSimpleClient(conn)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
t.Cleanup(cancel)
resp, err := client.Something(ctx, &testservice.Req{})
require.NoError(t, err)
require.Equal(t, "bob", resp.ServerName)
require.True(t, atomic.LoadInt32(&srv.rpc.tlsConnEstablished) == 0)
require.True(t, atomic.LoadInt32(&srv.rpc.alpnConnEstablished) > 0)
}
func TestClientConnPool_IntegrationWithGRPCResolver_Failover(t *testing.T) {
count := 4
res := resolver.NewServerResolverBuilder(newConfig(t))
registerWithGRPC(t, res)
pool := NewClientConnPool(res, nil, useTLSForDcAlwaysTrue)
pool := NewClientConnPool(ClientConnPoolConfig{
Servers: res,
UseTLSForDC: useTLSForDcAlwaysTrue,
DialingFromServer: true,
DialingFromDatacenter: "dc1",
})
for i := 0; i < count; i++ {
name := fmt.Sprintf("server-%d", i)
srv := newTestServer(t, name, "dc1")
srv := newTestServer(t, name, "dc1", nil)
res.AddServer(srv.Metadata())
t.Cleanup(srv.shutdown)
}
@ -115,22 +300,27 @@ func TestClientConnPool_IntegrationWithGRPCResolver_Failover(t *testing.T) {
func TestClientConnPool_ForwardToLeader_Failover(t *testing.T) {
count := 3
conf := newConfig(t)
res := resolver.NewServerResolverBuilder(conf)
res := resolver.NewServerResolverBuilder(newConfig(t))
registerWithGRPC(t, res)
pool := NewClientConnPool(res, nil, useTLSForDcAlwaysTrue)
pool := NewClientConnPool(ClientConnPoolConfig{
Servers: res,
UseTLSForDC: useTLSForDcAlwaysTrue,
DialingFromServer: true,
DialingFromDatacenter: "dc1",
})
var servers []testServer
for i := 0; i < count; i++ {
name := fmt.Sprintf("server-%d", i)
srv := newTestServer(t, name, "dc1")
srv := newTestServer(t, name, "dc1", nil)
res.AddServer(srv.Metadata())
servers = append(servers, srv)
t.Cleanup(srv.shutdown)
}
// Set the leader address to the first server.
res.UpdateLeaderAddr(servers[0].addr.String())
srv0 := servers[0].Metadata()
res.UpdateLeaderAddr(srv0.Datacenter, srv0.Addr.String())
conn, err := pool.ClientConnLeader()
require.NoError(t, err)
@ -144,7 +334,8 @@ func TestClientConnPool_ForwardToLeader_Failover(t *testing.T) {
require.Equal(t, first.ServerName, servers[0].name)
// Update the leader address and make another request.
res.UpdateLeaderAddr(servers[1].addr.String())
srv1 := servers[1].Metadata()
res.UpdateLeaderAddr(srv1.Datacenter, srv1.Addr.String())
resp, err := client.Something(ctx, &testservice.Req{})
require.NoError(t, err)
@ -162,11 +353,16 @@ func TestClientConnPool_IntegrationWithGRPCResolver_Rebalance(t *testing.T) {
count := 5
res := resolver.NewServerResolverBuilder(newConfig(t))
registerWithGRPC(t, res)
pool := NewClientConnPool(res, nil, useTLSForDcAlwaysTrue)
pool := NewClientConnPool(ClientConnPoolConfig{
Servers: res,
UseTLSForDC: useTLSForDcAlwaysTrue,
DialingFromServer: true,
DialingFromDatacenter: "dc1",
})
for i := 0; i < count; i++ {
name := fmt.Sprintf("server-%d", i)
srv := newTestServer(t, name, "dc1")
srv := newTestServer(t, name, "dc1", nil)
res.AddServer(srv.Metadata())
t.Cleanup(srv.shutdown)
}
@ -211,11 +407,16 @@ func TestClientConnPool_IntegrationWithGRPCResolver_MultiDC(t *testing.T) {
res := resolver.NewServerResolverBuilder(newConfig(t))
registerWithGRPC(t, res)
pool := NewClientConnPool(res, nil, useTLSForDcAlwaysTrue)
pool := NewClientConnPool(ClientConnPoolConfig{
Servers: res,
UseTLSForDC: useTLSForDcAlwaysTrue,
DialingFromServer: true,
DialingFromDatacenter: "dc1",
})
for _, dc := range dcs {
name := "server-0-" + dc
srv := newTestServer(t, name, dc)
srv := newTestServer(t, name, dc, nil)
res.AddServer(srv.Metadata())
t.Cleanup(srv.shutdown)
}

View File

@ -67,17 +67,17 @@ func (s *ServerResolverBuilder) NewRebalancer(dc string) func() {
}
}
// ServerForAddr returns server metadata for a server with the specified address.
func (s *ServerResolverBuilder) ServerForAddr(addr string) (*metadata.Server, error) {
// ServerForGlobalAddr returns server metadata for a server with the specified globally unique address.
func (s *ServerResolverBuilder) ServerForGlobalAddr(globalAddr string) (*metadata.Server, error) {
s.lock.RLock()
defer s.lock.RUnlock()
for _, server := range s.servers {
if server.Addr.String() == addr {
if DCPrefix(server.Datacenter, server.Addr.String()) == globalAddr {
return server, nil
}
}
return nil, fmt.Errorf("failed to find Consul server for address %q", addr)
return nil, fmt.Errorf("failed to find Consul server for global address %q", globalAddr)
}
// Build returns a new serverResolver for the given ClientConn. The resolver
@ -161,6 +161,12 @@ func uniqueID(server *metadata.Server) string {
return server.Datacenter + "-" + server.ID
}
// DCPrefix qualifies suffix with its datacenter, returning
// "<datacenter>-<suffix>", so that otherwise identical values from different
// datacenters can be told apart.
func DCPrefix(datacenter, suffix string) string {
	const separator = "-"
	return datacenter + separator + suffix
}
// RemoveServer updates the resolvers' states with the given server removed.
func (s *ServerResolverBuilder) RemoveServer(server *metadata.Server) {
s.lock.Lock()
@ -186,7 +192,8 @@ func (s *ServerResolverBuilder) getDCAddrs(dc string) []resolver.Address {
}
addrs = append(addrs, resolver.Address{
Addr: server.Addr.String(),
// NOTE: the address persisted here is only dialable using our custom dialer
Addr: DCPrefix(server.Datacenter, server.Addr.String()),
Type: resolver.Backend,
ServerName: server.Name,
})
@ -195,11 +202,11 @@ func (s *ServerResolverBuilder) getDCAddrs(dc string) []resolver.Address {
}
// UpdateLeaderAddr updates the leader address in the local DC's resolver.
func (s *ServerResolverBuilder) UpdateLeaderAddr(leaderAddr string) {
func (s *ServerResolverBuilder) UpdateLeaderAddr(datacenter, addr string) {
s.lock.Lock()
defer s.lock.Unlock()
s.leaderResolver.addr = leaderAddr
s.leaderResolver.globalAddr = DCPrefix(datacenter, addr)
s.leaderResolver.updateClientConn()
}
@ -262,7 +269,7 @@ func (r *serverResolver) Close() {
func (*serverResolver) ResolveNow(resolver.ResolveNowOption) {}
type leaderResolver struct {
addr string
globalAddr string
clientConn resolver.ClientConn
}
@ -271,12 +278,13 @@ func (l leaderResolver) ResolveNow(resolver.ResolveNowOption) {}
func (l leaderResolver) Close() {}
func (l leaderResolver) updateClientConn() {
if l.addr == "" || l.clientConn == nil {
if l.globalAddr == "" || l.clientConn == nil {
return
}
addrs := []resolver.Address{
{
Addr: l.addr,
// NOTE: the address persisted here is only dialable using our custom dialer
Addr: l.globalAddr,
Type: resolver.Backend,
ServerName: "leader",
},

View File

@ -6,6 +6,7 @@ import (
"fmt"
"io"
"net"
"strconv"
"sync/atomic"
"testing"
"time"
@ -17,6 +18,7 @@ import (
"github.com/hashicorp/consul/agent/grpc/internal/testservice"
"github.com/hashicorp/consul/agent/metadata"
"github.com/hashicorp/consul/agent/pool"
"github.com/hashicorp/consul/sdk/freeport"
"github.com/hashicorp/consul/tlsutil"
)
@ -31,22 +33,29 @@ type testServer struct {
// Metadata builds the metadata.Server record describing this test server.
// The full name is "<name>.<dc>", and UseTLS is reported whenever the fake
// RPC listener was given a TLS configurator.
func (s testServer) Metadata() *metadata.Server {
	tlsEnabled := s.rpc.tlsConf != nil
	md := metadata.Server{
		ID:         s.name,
		Name:       s.name + "." + s.dc,
		ShortName:  s.name,
		Datacenter: s.dc,
		Addr:       s.addr,
		UseTLS:     tlsEnabled,
	}
	return &md
}
func newTestServer(t *testing.T, name string, dc string) testServer {
func newTestServer(t *testing.T, name string, dc string, tlsConf *tlsutil.Configurator) testServer {
addr := &net.IPAddr{IP: net.ParseIP("127.0.0.1")}
handler := NewHandler(addr, func(server *grpc.Server) {
testservice.RegisterSimpleServer(server, &simple{name: name, dc: dc})
})
lis, err := net.Listen("tcp", "127.0.0.1:0")
ports := freeport.MustTake(1)
t.Cleanup(func() {
freeport.Return(ports)
})
lis, err := net.Listen("tcp", net.JoinHostPort("127.0.0.1", strconv.Itoa(ports[0])))
require.NoError(t, err)
rpc := &fakeRPCListener{t: t, handler: handler}
rpc := &fakeRPCListener{t: t, handler: handler, tlsConf: tlsConf}
g := errgroup.Group{}
g.Go(func() error {
@ -107,11 +116,12 @@ func (s *simple) Something(_ context.Context, _ *testservice.Req) (*testservice.
// For now, since this logic is in agent/consul, we can't easily use Server.listen
// so we fake it.
type fakeRPCListener struct {
t *testing.T
handler *Handler
shutdown bool
tlsConf *tlsutil.Configurator
tlsConnEstablished int32
t *testing.T
handler *Handler
shutdown bool
tlsConf *tlsutil.Configurator
tlsConnEstablished int32
alpnConnEstablished int32
}
func (f *fakeRPCListener) listen(listener net.Listener) error {
@ -129,6 +139,26 @@ func (f *fakeRPCListener) listen(listener net.Listener) error {
}
func (f *fakeRPCListener) handleConn(conn net.Conn) {
if f.tlsConf != nil && f.tlsConf.MutualTLSCapable() {
// Check whether this is actually native TLS multiplexed onto the old
// "type-byte" system.
peekedConn, nativeTLS, err := pool.PeekForTLS(conn)
if err != nil {
if err != io.EOF {
fmt.Printf("ERROR: failed to read first byte: %v\n", err)
}
conn.Close()
return
}
if nativeTLS {
f.handleNativeTLSConn(peekedConn)
return
}
conn = peekedConn
}
buf := make([]byte, 1)
if _, err := conn.Read(buf); err != nil {
@ -166,3 +196,32 @@ func (f *fakeRPCListener) handleConn(conn net.Conn) {
conn.Close()
}
}
// handleNativeTLSConn serves a connection that speaks TLS natively rather
// than using the leading RPC type byte. It completes the server-side TLS
// handshake with the ALPN-aware incoming configuration and then dispatches on
// the negotiated protocol: gRPC connections are counted in
// alpnConnEstablished and handed to the gRPC handler, while any other
// protocol is logged and the connection is closed.
func (f *fakeRPCListener) handleNativeTLSConn(conn net.Conn) {
	tlscfg := f.tlsConf.IncomingALPNRPCConfig(pool.RPCNextProtos)
	tlsConn := tls.Server(conn, tlscfg)

	// Force the handshake to conclude so the negotiated protocol is known
	// before we decide how to route the connection.
	if err := tlsConn.Handshake(); err != nil {
		// Newline-terminated like the other error prints in this listener,
		// so it does not run into the next line of output.
		fmt.Printf("ERROR: TLS handshake failed: %v\n", err)
		conn.Close()
		return
	}

	// A zero time clears any read deadline left on the connection. This is
	// best-effort, so the error is deliberately ignored.
	_ = conn.SetReadDeadline(time.Time{})

	nextProto := tlsConn.ConnectionState().NegotiatedProtocol

	switch nextProto {
	case pool.ALPN_RPCGRPC:
		atomic.AddInt32(&f.alpnConnEstablished, 1)
		f.handler.Handle(tlsConn)

	default:
		fmt.Printf("ERROR: discarding RPC for unknown negotiated protocol %q\n", nextProto)
		conn.Close()
	}
}

View File

@ -20,6 +20,8 @@ func (t RPCType) ALPNString() string {
return ALPN_RPCGossip
case RPCTLSInsecure:
return "" // unsupported
case RPCGRPC:
return ALPN_RPCGRPC
default:
return "" // unsupported
}
@ -28,19 +30,19 @@ func (t RPCType) ALPNString() string {
const (
// keep numbers unique.
RPCConsul RPCType = 0
RPCRaft = 1
RPCMultiplex = 2 // Old Muxado byte, no longer supported.
RPCTLS = 3
RPCMultiplexV2 = 4
RPCSnapshot = 5
RPCGossip = 6
RPCRaft RPCType = 1
RPCMultiplex RPCType = 2 // Old Muxado byte, no longer supported.
RPCTLS RPCType = 3
RPCMultiplexV2 RPCType = 4
RPCSnapshot RPCType = 5
RPCGossip RPCType = 6
// RPCTLSInsecure is used to flag RPC calls that require verify
// incoming to be disabled, even when it is turned on in the
// configuration. At the time of writing there is only AutoEncrypt.Sign
// that is supported and it might be the only one there
// ever is.
RPCTLSInsecure = 7
RPCGRPC = 8
RPCTLSInsecure RPCType = 7
RPCGRPC RPCType = 8
// RPCMaxTypeValue is the maximum rpc type byte value currently used for the
// various protocols riding over our "rpc" port.
@ -79,6 +81,7 @@ var RPCNextProtos = []string{
ALPN_RPCMultiplexV2,
ALPN_RPCSnapshot,
ALPN_RPCGossip,
ALPN_RPCGRPC,
ALPN_WANGossipPacket,
ALPN_WANGossipStream,
}

View File

@ -10,8 +10,9 @@ import (
"testing"
"time"
"github.com/hashicorp/consul/tlsutil"
"github.com/stretchr/testify/require"
"github.com/hashicorp/consul/tlsutil"
)
func TestPeekForTLS_not_TLS(t *testing.T) {
@ -30,6 +31,7 @@ func TestPeekForTLS_not_TLS(t *testing.T) {
RPCSnapshot,
RPCGossip,
RPCTLSInsecure,
RPCGRPC,
} {
cases = append(cases, testcase{
name: fmt.Sprintf("tcp rpc type byte %d", rpcType),
@ -76,6 +78,7 @@ func TestPeekForTLS_actual_TLS(t *testing.T) {
RPCSnapshot,
RPCGossip,
RPCTLSInsecure,
RPCGRPC,
} {
cases = append(cases, testcase{
name: fmt.Sprintf("tcp rpc type byte %d", rpcType),

View File

@ -2,6 +2,7 @@ package pool
import (
"container/list"
"context"
"crypto/tls"
"fmt"
"log"
@ -11,14 +12,15 @@ import (
"sync/atomic"
"time"
msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc"
"github.com/hashicorp/yamux"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/lib"
"github.com/hashicorp/consul/tlsutil"
msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc"
"github.com/hashicorp/yamux"
)
const defaultDialTimeout = 10 * time.Second
const DefaultDialTimeout = 10 * time.Second
// muxSession is used to provide an interface for a stream multiplexer.
type muxSession interface {
@ -291,21 +293,24 @@ func (p *ConnPool) DialTimeout(
) (net.Conn, HalfCloser, error) {
p.once.Do(p.init)
if p.Server && p.GatewayResolver != nil && p.TLSConfigurator != nil && dc != p.Datacenter {
if p.Server &&
p.GatewayResolver != nil &&
p.TLSConfigurator != nil &&
dc != p.Datacenter {
// NOTE: TLS is required on this branch.
return DialTimeoutWithRPCTypeViaMeshGateway(
nextProto := actualRPCType.ALPNString()
if nextProto == "" {
return nil, nil, fmt.Errorf("rpc type %d cannot be routed through a mesh gateway", actualRPCType)
}
return DialRPCViaMeshGateway(
context.Background(),
dc,
nodeName,
addr,
p.SrcAddr,
p.TLSConfigurator.OutgoingALPNRPCWrapper(),
actualRPCType,
RPCTLS,
// gateway stuff
nextProto,
p.Server,
p.TLSConfigurator,
p.GatewayResolver,
p.Datacenter,
)
}
@ -319,7 +324,7 @@ func (p *ConnPool) dial(
tlsRPCType RPCType,
) (net.Conn, HalfCloser, error) {
// Try to dial the conn
d := &net.Dialer{LocalAddr: p.SrcAddr, Timeout: defaultDialTimeout}
d := &net.Dialer{LocalAddr: p.SrcAddr, Timeout: DefaultDialTimeout}
conn, err := d.Dial("tcp", addr.String())
if err != nil {
return nil, nil, err
@ -372,62 +377,49 @@ func (p *ConnPool) dial(
return conn, hc, nil
}
// DialTimeoutWithRPCTypeViaMeshGateway dials the destination node and sets up
// the connection to be the correct RPC type using ALPN. This currently is
// exclusively used to dial other servers in foreign datacenters via mesh
// gateways.
//
// NOTE: There is a close mirror of this method in agent/consul/wanfed/wanfed.go:dial
func DialTimeoutWithRPCTypeViaMeshGateway(
dc string,
nodeName string,
addr net.Addr,
src *net.TCPAddr,
wrapper tlsutil.ALPNWrapper,
actualRPCType RPCType,
tlsRPCType RPCType,
// gateway stuff
// DialRPCViaMeshGateway dials the destination node and sets up the connection
// to be the correct RPC type using ALPN. This currently is exclusively used to
// dial other servers in foreign datacenters via mesh gateways.
func DialRPCViaMeshGateway(
ctx context.Context,
dc string, // (metadata.Server).Datacenter
nodeName string, // (metadata.Server).ShortName
srcAddr *net.TCPAddr,
alpnWrapper tlsutil.ALPNWrapper,
nextProto string,
dialingFromServer bool,
tlsConfigurator *tlsutil.Configurator,
gatewayResolver func(string) string,
thisDatacenter string,
) (net.Conn, HalfCloser, error) {
if !dialingFromServer {
return nil, nil, fmt.Errorf("must dial via mesh gateways from a server agent")
} else if gatewayResolver == nil {
return nil, nil, fmt.Errorf("gatewayResolver is nil")
} else if tlsConfigurator == nil {
return nil, nil, fmt.Errorf("tlsConfigurator is nil")
} else if dc == thisDatacenter {
return nil, nil, fmt.Errorf("cannot dial servers in the same datacenter via a mesh gateway")
} else if wrapper == nil {
} else if alpnWrapper == nil {
return nil, nil, fmt.Errorf("cannot dial via a mesh gateway when outgoing TLS is disabled")
}
nextProto := actualRPCType.ALPNString()
if nextProto == "" {
return nil, nil, fmt.Errorf("rpc type %d cannot be routed through a mesh gateway", actualRPCType)
}
gwAddr := gatewayResolver(dc)
if gwAddr == "" {
return nil, nil, structs.ErrDCNotAvailable
}
dialer := &net.Dialer{LocalAddr: src, Timeout: defaultDialTimeout}
dialer := &net.Dialer{LocalAddr: srcAddr, Timeout: DefaultDialTimeout}
rawConn, err := dialer.Dial("tcp", gwAddr)
rawConn, err := dialer.DialContext(ctx, "tcp", gwAddr)
if err != nil {
return nil, nil, err
}
if tcp, ok := rawConn.(*net.TCPConn); ok {
_ = tcp.SetKeepAlive(true)
_ = tcp.SetNoDelay(true)
if nextProto != ALPN_RPCGRPC {
// agent/grpc/client.go:dial() handles this in another way for gRPC
if tcp, ok := rawConn.(*net.TCPConn); ok {
_ = tcp.SetKeepAlive(true)
_ = tcp.SetNoDelay(true)
}
}
// NOTE: now we wrap the connection in a TLS client.
tlsConn, err := wrapper(dc, nodeName, nextProto, rawConn)
tlsConn, err := alpnWrapper(dc, nodeName, nextProto, rawConn)
if err != nil {
return nil, nil, err
}

View File

@ -106,9 +106,22 @@ func NewBaseDeps(configLoader ConfigLoader, logOut io.Writer) (BaseDeps, error)
d.ViewStore = submatview.NewStore(d.Logger.Named("viewstore"))
d.ConnPool = newConnPool(cfg, d.Logger, d.TLSConfigurator)
builder := resolver.NewServerResolverBuilder(resolver.Config{})
builder := resolver.NewServerResolverBuilder(resolver.Config{
// Set the authority to something sufficiently unique so any usage in
// tests would be self-isolating in the global resolver map, while also
// not incurring a huge penalty for non-test code.
Authority: cfg.Datacenter + "." + string(cfg.NodeID),
})
resolver.Register(builder)
d.GRPCConnPool = grpc.NewClientConnPool(builder, grpc.TLSWrapper(d.TLSConfigurator.OutgoingRPCWrapper()), d.TLSConfigurator.UseTLS)
d.GRPCConnPool = grpc.NewClientConnPool(grpc.ClientConnPoolConfig{
Servers: builder,
SrcAddr: d.ConnPool.SrcAddr,
TLSWrapper: grpc.TLSWrapper(d.TLSConfigurator.OutgoingRPCWrapper()),
ALPNWrapper: grpc.ALPNWrapper(d.TLSConfigurator.OutgoingALPNRPCWrapper()),
UseTLSForDC: d.TLSConfigurator.UseTLS,
DialingFromServer: cfg.ServerMode,
DialingFromDatacenter: cfg.Datacenter,
})
d.LeaderForwarder = builder
d.Router = router.NewRouter(d.Logger, cfg.Datacenter, fmt.Sprintf("%s.%s", cfg.NodeName, cfg.Datacenter), builder)