grpc: ensure that streaming gRPC requests work over mesh gateway based wan federation (#10838)
Fixes #10796
This commit is contained in:
parent
6b574abc89
commit
a84f5fa25d
|
@ -0,0 +1,3 @@
|
||||||
|
```release-note:bug
|
||||||
|
grpc: ensure that streaming gRPC requests work over mesh gateway based wan federation
|
||||||
|
```
|
|
@ -371,6 +371,9 @@ func (f fakeGRPCConnPool) ClientConnLeader() (*grpc.ClientConn, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (f fakeGRPCConnPool) SetGatewayResolver(_ func(string) string) {
|
||||||
|
}
|
||||||
|
|
||||||
func TestAgent_ReconnectConfigWanDisabled(t *testing.T) {
|
func TestAgent_ReconnectConfigWanDisabled(t *testing.T) {
|
||||||
if testing.Short() {
|
if testing.Short() {
|
||||||
t.Skip("too slow for testing.Short")
|
t.Skip("too slow for testing.Short")
|
||||||
|
@ -4524,6 +4527,9 @@ LOOP:
|
||||||
}
|
}
|
||||||
|
|
||||||
// This is a mirror of a similar test in agent/consul/server_test.go
|
// This is a mirror of a similar test in agent/consul/server_test.go
|
||||||
|
//
|
||||||
|
// TODO(rb): implement something similar to this as a full containerized test suite with proper
|
||||||
|
// isolation so requests can't "cheat" and bypass the mesh gateways
|
||||||
func TestAgent_JoinWAN_viaMeshGateway(t *testing.T) {
|
func TestAgent_JoinWAN_viaMeshGateway(t *testing.T) {
|
||||||
if testing.Short() {
|
if testing.Short() {
|
||||||
t.Skip("too slow for testing.Short")
|
t.Skip("too slow for testing.Short")
|
||||||
|
@ -4771,6 +4777,9 @@ func TestAgent_JoinWAN_viaMeshGateway(t *testing.T) {
|
||||||
})
|
})
|
||||||
|
|
||||||
// Ensure we can do some trivial RPC in all directions.
|
// Ensure we can do some trivial RPC in all directions.
|
||||||
|
//
|
||||||
|
// NOTE: we explicitly make streaming and non-streaming assertions here to
|
||||||
|
// verify both rpc and grpc codepaths.
|
||||||
agents := map[string]*TestAgent{"dc1": a1, "dc2": a2, "dc3": a3}
|
agents := map[string]*TestAgent{"dc1": a1, "dc2": a2, "dc3": a3}
|
||||||
names := map[string]string{"dc1": "bob", "dc2": "betty", "dc3": "bonnie"}
|
names := map[string]string{"dc1": "bob", "dc2": "betty", "dc3": "bonnie"}
|
||||||
for _, srcDC := range []string{"dc1", "dc2", "dc3"} {
|
for _, srcDC := range []string{"dc1", "dc2", "dc3"} {
|
||||||
|
@ -4780,6 +4789,7 @@ func TestAgent_JoinWAN_viaMeshGateway(t *testing.T) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
t.Run(srcDC+" to "+dstDC, func(t *testing.T) {
|
t.Run(srcDC+" to "+dstDC, func(t *testing.T) {
|
||||||
|
t.Run("normal-rpc", func(t *testing.T) {
|
||||||
req, err := http.NewRequest("GET", "/v1/catalog/nodes?dc="+dstDC, nil)
|
req, err := http.NewRequest("GET", "/v1/catalog/nodes?dc="+dstDC, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
@ -4795,6 +4805,24 @@ func TestAgent_JoinWAN_viaMeshGateway(t *testing.T) {
|
||||||
require.Equal(t, dstDC, node.Datacenter)
|
require.Equal(t, dstDC, node.Datacenter)
|
||||||
require.Equal(t, names[dstDC], node.Node)
|
require.Equal(t, names[dstDC], node.Node)
|
||||||
})
|
})
|
||||||
|
t.Run("streaming-grpc", func(t *testing.T) {
|
||||||
|
req, err := http.NewRequest("GET", "/v1/health/service/consul?cached&dc="+dstDC, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
resp := httptest.NewRecorder()
|
||||||
|
obj, err := a.srv.HealthServiceNodes(resp, req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, obj)
|
||||||
|
|
||||||
|
csns, ok := obj.(structs.CheckServiceNodes)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Len(t, csns, 1)
|
||||||
|
|
||||||
|
csn := csns[0]
|
||||||
|
require.Equal(t, dstDC, csn.Node.Datacenter)
|
||||||
|
require.Equal(t, names[dstDC], csn.Node.Node)
|
||||||
|
})
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,10 +5,17 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/hashicorp/go-hclog"
|
||||||
|
msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc"
|
||||||
|
"github.com/hashicorp/serf/serf"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"golang.org/x/time/rate"
|
||||||
|
|
||||||
"github.com/hashicorp/consul/agent/grpc"
|
"github.com/hashicorp/consul/agent/grpc"
|
||||||
"github.com/hashicorp/consul/agent/grpc/resolver"
|
"github.com/hashicorp/consul/agent/grpc/resolver"
|
||||||
"github.com/hashicorp/consul/agent/pool"
|
"github.com/hashicorp/consul/agent/pool"
|
||||||
|
@ -20,11 +27,6 @@ import (
|
||||||
"github.com/hashicorp/consul/sdk/testutil/retry"
|
"github.com/hashicorp/consul/sdk/testutil/retry"
|
||||||
"github.com/hashicorp/consul/testrpc"
|
"github.com/hashicorp/consul/testrpc"
|
||||||
"github.com/hashicorp/consul/tlsutil"
|
"github.com/hashicorp/consul/tlsutil"
|
||||||
"github.com/hashicorp/go-hclog"
|
|
||||||
msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc"
|
|
||||||
"github.com/hashicorp/serf/serf"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
"golang.org/x/time/rate"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func testClientConfig(t *testing.T) (string, *Config) {
|
func testClientConfig(t *testing.T) (string, *Config) {
|
||||||
|
@ -490,6 +492,13 @@ func newClient(t *testing.T, config *Config) *Client {
|
||||||
return client
|
return client
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func newTestResolverConfig(t *testing.T, suffix string) resolver.Config {
|
||||||
|
n := t.Name()
|
||||||
|
s := strings.Replace(n, "/", "", -1)
|
||||||
|
s = strings.Replace(s, "_", "", -1)
|
||||||
|
return resolver.Config{Authority: strings.ToLower(s) + "-" + suffix}
|
||||||
|
}
|
||||||
|
|
||||||
func newDefaultDeps(t *testing.T, c *Config) Deps {
|
func newDefaultDeps(t *testing.T, c *Config) Deps {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
|
@ -502,7 +511,7 @@ func newDefaultDeps(t *testing.T, c *Config) Deps {
|
||||||
tls, err := tlsutil.NewConfigurator(c.TLSConfig, logger)
|
tls, err := tlsutil.NewConfigurator(c.TLSConfig, logger)
|
||||||
require.NoError(t, err, "failed to create tls configuration")
|
require.NoError(t, err, "failed to create tls configuration")
|
||||||
|
|
||||||
builder := resolver.NewServerResolverBuilder(resolver.Config{Authority: c.NodeName})
|
builder := resolver.NewServerResolverBuilder(newTestResolverConfig(t, c.NodeName+"-"+c.Datacenter))
|
||||||
r := router.NewRouter(logger, c.Datacenter, fmt.Sprintf("%s.%s", c.NodeName, c.Datacenter), builder)
|
r := router.NewRouter(logger, c.Datacenter, fmt.Sprintf("%s.%s", c.NodeName, c.Datacenter), builder)
|
||||||
resolver.Register(builder)
|
resolver.Register(builder)
|
||||||
|
|
||||||
|
@ -522,7 +531,13 @@ func newDefaultDeps(t *testing.T, c *Config) Deps {
|
||||||
Tokens: new(token.Store),
|
Tokens: new(token.Store),
|
||||||
Router: r,
|
Router: r,
|
||||||
ConnPool: connPool,
|
ConnPool: connPool,
|
||||||
GRPCConnPool: grpc.NewClientConnPool(builder, grpc.TLSWrapper(tls.OutgoingRPCWrapper()), tls.UseTLS),
|
GRPCConnPool: grpc.NewClientConnPool(grpc.ClientConnPoolConfig{
|
||||||
|
Servers: builder,
|
||||||
|
TLSWrapper: grpc.TLSWrapper(tls.OutgoingRPCWrapper()),
|
||||||
|
UseTLSForDC: tls.UseTLS,
|
||||||
|
DialingFromServer: true,
|
||||||
|
DialingFromDatacenter: c.Datacenter,
|
||||||
|
}),
|
||||||
LeaderForwarder: builder,
|
LeaderForwarder: builder,
|
||||||
EnterpriseDeps: newDefaultDepsEnterprise(t, logger, c),
|
EnterpriseDeps: newDefaultDepsEnterprise(t, logger, c),
|
||||||
}
|
}
|
||||||
|
|
|
@ -24,8 +24,10 @@ type Deps struct {
|
||||||
type GRPCClientConner interface {
|
type GRPCClientConner interface {
|
||||||
ClientConn(datacenter string) (*grpc.ClientConn, error)
|
ClientConn(datacenter string) (*grpc.ClientConn, error)
|
||||||
ClientConnLeader() (*grpc.ClientConn, error)
|
ClientConnLeader() (*grpc.ClientConn, error)
|
||||||
|
SetGatewayResolver(func(string) string)
|
||||||
}
|
}
|
||||||
|
|
||||||
type LeaderForwarder interface {
|
type LeaderForwarder interface {
|
||||||
UpdateLeaderAddr(leaderAddr string)
|
// UpdateLeaderAddr updates the leader address in the local DC's resolver.
|
||||||
|
UpdateLeaderAddr(datacenter, addr string)
|
||||||
}
|
}
|
||||||
|
|
|
@ -293,7 +293,7 @@ func (s *Server) handleNativeTLS(conn net.Conn) {
|
||||||
s.handleSnapshotConn(tlsConn)
|
s.handleSnapshotConn(tlsConn)
|
||||||
|
|
||||||
case pool.ALPN_RPCGRPC:
|
case pool.ALPN_RPCGRPC:
|
||||||
s.grpcHandler.Handle(conn)
|
s.grpcHandler.Handle(tlsConn)
|
||||||
|
|
||||||
case pool.ALPN_WANGossipPacket:
|
case pool.ALPN_WANGossipPacket:
|
||||||
if err := s.handleALPN_WANGossipPacketStream(tlsConn); err != nil && err != io.EOF {
|
if err := s.handleALPN_WANGossipPacketStream(tlsConn); err != nil && err != io.EOF {
|
||||||
|
@ -373,7 +373,7 @@ func (s *Server) handleMultiplexV2(conn net.Conn) {
|
||||||
}
|
}
|
||||||
sub = peeked
|
sub = peeked
|
||||||
switch first {
|
switch first {
|
||||||
case pool.RPCGossip:
|
case byte(pool.RPCGossip):
|
||||||
buf := make([]byte, 1)
|
buf := make([]byte, 1)
|
||||||
sub.Read(buf)
|
sub.Read(buf)
|
||||||
go func() {
|
go func() {
|
||||||
|
|
|
@ -460,7 +460,7 @@ func TestRPC_TLSHandshakeTimeout(t *testing.T) {
|
||||||
|
|
||||||
// Write TLS byte to avoid being closed by either the (outer) first byte
|
// Write TLS byte to avoid being closed by either the (outer) first byte
|
||||||
// timeout or the fact that server requires TLS
|
// timeout or the fact that server requires TLS
|
||||||
_, err = conn.Write([]byte{pool.RPCTLS})
|
_, err = conn.Write([]byte{byte(pool.RPCTLS)})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Wait for more than the timeout before we start a TLS handshake. This is
|
// Wait for more than the timeout before we start a TLS handshake. This is
|
||||||
|
|
|
@ -173,6 +173,9 @@ type Server struct {
|
||||||
// Connection pool to other consul servers
|
// Connection pool to other consul servers
|
||||||
connPool *pool.ConnPool
|
connPool *pool.ConnPool
|
||||||
|
|
||||||
|
// Connection pool to other consul servers using gRPC
|
||||||
|
grpcConnPool GRPCClientConner
|
||||||
|
|
||||||
// eventChLAN is used to receive events from the
|
// eventChLAN is used to receive events from the
|
||||||
// serf cluster in the datacenter
|
// serf cluster in the datacenter
|
||||||
eventChLAN chan serf.Event
|
eventChLAN chan serf.Event
|
||||||
|
@ -348,6 +351,7 @@ func NewServer(config *Config, flat Deps) (*Server, error) {
|
||||||
config: config,
|
config: config,
|
||||||
tokens: flat.Tokens,
|
tokens: flat.Tokens,
|
||||||
connPool: flat.ConnPool,
|
connPool: flat.ConnPool,
|
||||||
|
grpcConnPool: flat.GRPCConnPool,
|
||||||
eventChLAN: make(chan serf.Event, serfEventChSize),
|
eventChLAN: make(chan serf.Event, serfEventChSize),
|
||||||
eventChWAN: make(chan serf.Event, serfEventChSize),
|
eventChWAN: make(chan serf.Event, serfEventChSize),
|
||||||
logger: serverLogger,
|
logger: serverLogger,
|
||||||
|
@ -377,6 +381,7 @@ func NewServer(config *Config, flat Deps) (*Server, error) {
|
||||||
s.config.PrimaryDatacenter,
|
s.config.PrimaryDatacenter,
|
||||||
)
|
)
|
||||||
s.connPool.GatewayResolver = s.gatewayLocator.PickGateway
|
s.connPool.GatewayResolver = s.gatewayLocator.PickGateway
|
||||||
|
s.grpcConnPool.SetGatewayResolver(s.gatewayLocator.PickGateway)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize enterprise specific server functionality
|
// Initialize enterprise specific server functionality
|
||||||
|
@ -1461,7 +1466,7 @@ func (s *Server) trackLeaderChanges() {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
s.grpcLeaderForwarder.UpdateLeaderAddr(string(leaderObs.Leader))
|
s.grpcLeaderForwarder.UpdateLeaderAddr(s.config.Datacenter, string(leaderObs.Leader))
|
||||||
case <-s.shutdownCh:
|
case <-s.shutdownCh:
|
||||||
s.raft.DeregisterObserver(observer)
|
s.raft.DeregisterObserver(observer)
|
||||||
return
|
return
|
||||||
|
|
|
@ -25,6 +25,8 @@ import (
|
||||||
func TestSubscribeBackend_IntegrationWithServer_TLSEnabled(t *testing.T) {
|
func TestSubscribeBackend_IntegrationWithServer_TLSEnabled(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
// TODO(rb): add tests for the wanfed/alpn variations
|
||||||
|
|
||||||
_, conf1 := testServerConfig(t)
|
_, conf1 := testServerConfig(t)
|
||||||
conf1.TLSConfig.VerifyIncoming = true
|
conf1.TLSConfig.VerifyIncoming = true
|
||||||
conf1.TLSConfig.VerifyOutgoing = true
|
conf1.TLSConfig.VerifyOutgoing = true
|
||||||
|
@ -60,7 +62,13 @@ func TestSubscribeBackend_IntegrationWithServer_TLSEnabled(t *testing.T) {
|
||||||
|
|
||||||
// Start a Subscribe call to our streaming endpoint from the client.
|
// Start a Subscribe call to our streaming endpoint from the client.
|
||||||
{
|
{
|
||||||
pool := grpc.NewClientConnPool(builder, grpc.TLSWrapper(client.tlsConfigurator.OutgoingRPCWrapper()), client.tlsConfigurator.UseTLS)
|
pool := grpc.NewClientConnPool(grpc.ClientConnPoolConfig{
|
||||||
|
Servers: builder,
|
||||||
|
TLSWrapper: grpc.TLSWrapper(client.tlsConfigurator.OutgoingRPCWrapper()),
|
||||||
|
UseTLSForDC: client.tlsConfigurator.UseTLS,
|
||||||
|
DialingFromServer: true,
|
||||||
|
DialingFromDatacenter: "dc1",
|
||||||
|
})
|
||||||
conn, err := pool.ClientConn("dc1")
|
conn, err := pool.ClientConn("dc1")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
@ -91,8 +99,13 @@ func TestSubscribeBackend_IntegrationWithServer_TLSEnabled(t *testing.T) {
|
||||||
|
|
||||||
// Start a Subscribe call to our streaming endpoint from the server's loopback client.
|
// Start a Subscribe call to our streaming endpoint from the server's loopback client.
|
||||||
{
|
{
|
||||||
|
pool := grpc.NewClientConnPool(grpc.ClientConnPoolConfig{
|
||||||
pool := grpc.NewClientConnPool(builder, grpc.TLSWrapper(client.tlsConfigurator.OutgoingRPCWrapper()), client.tlsConfigurator.UseTLS)
|
Servers: builder,
|
||||||
|
TLSWrapper: grpc.TLSWrapper(client.tlsConfigurator.OutgoingRPCWrapper()),
|
||||||
|
UseTLSForDC: client.tlsConfigurator.UseTLS,
|
||||||
|
DialingFromServer: true,
|
||||||
|
DialingFromDatacenter: "dc1",
|
||||||
|
})
|
||||||
conn, err := pool.ClientConn("dc1")
|
conn, err := pool.ClientConn("dc1")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
@ -166,7 +179,13 @@ func TestSubscribeBackend_IntegrationWithServer_TLSReload(t *testing.T) {
|
||||||
// Subscribe calls should fail initially
|
// Subscribe calls should fail initially
|
||||||
joinLAN(t, client, server)
|
joinLAN(t, client, server)
|
||||||
|
|
||||||
pool := grpc.NewClientConnPool(builder, grpc.TLSWrapper(client.tlsConfigurator.OutgoingRPCWrapper()), client.tlsConfigurator.UseTLS)
|
pool := grpc.NewClientConnPool(grpc.ClientConnPoolConfig{
|
||||||
|
Servers: builder,
|
||||||
|
TLSWrapper: grpc.TLSWrapper(client.tlsConfigurator.OutgoingRPCWrapper()),
|
||||||
|
UseTLSForDC: client.tlsConfigurator.UseTLS,
|
||||||
|
DialingFromServer: true,
|
||||||
|
DialingFromDatacenter: "dc1",
|
||||||
|
})
|
||||||
conn, err := pool.ClientConn("dc1")
|
conn, err := pool.ClientConn("dc1")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
@ -294,7 +313,13 @@ func TestSubscribeBackend_IntegrationWithServer_DeliversAllMessages(t *testing.T
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
pool := grpc.NewClientConnPool(builder, grpc.TLSWrapper(client.tlsConfigurator.OutgoingRPCWrapper()), client.tlsConfigurator.UseTLS)
|
pool := grpc.NewClientConnPool(grpc.ClientConnPoolConfig{
|
||||||
|
Servers: builder,
|
||||||
|
TLSWrapper: grpc.TLSWrapper(client.tlsConfigurator.OutgoingRPCWrapper()),
|
||||||
|
UseTLSForDC: client.tlsConfigurator.UseTLS,
|
||||||
|
DialingFromServer: true,
|
||||||
|
DialingFromDatacenter: "dc1",
|
||||||
|
})
|
||||||
conn, err := pool.ClientConn("dc1")
|
conn, err := pool.ClientConn("dc1")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
@ -337,7 +362,7 @@ func TestSubscribeBackend_IntegrationWithServer_DeliversAllMessages(t *testing.T
|
||||||
}
|
}
|
||||||
|
|
||||||
func newClientWithGRPCResolver(t *testing.T, ops ...func(*Config)) (*Client, *resolver.ServerResolverBuilder) {
|
func newClientWithGRPCResolver(t *testing.T, ops ...func(*Config)) (*Client, *resolver.ServerResolverBuilder) {
|
||||||
builder := resolver.NewServerResolverBuilder(resolver.Config{Authority: t.Name()})
|
builder := resolver.NewServerResolverBuilder(newTestResolverConfig(t, "client"))
|
||||||
resolver.Register(builder)
|
resolver.Register(builder)
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
resolver.Deregister(builder.Authority())
|
resolver.Deregister(builder.Authority())
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package wanfed
|
package wanfed
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
@ -11,7 +12,6 @@ import (
|
||||||
"github.com/hashicorp/memberlist"
|
"github.com/hashicorp/memberlist"
|
||||||
|
|
||||||
"github.com/hashicorp/consul/agent/pool"
|
"github.com/hashicorp/consul/agent/pool"
|
||||||
"github.com/hashicorp/consul/agent/structs"
|
|
||||||
"github.com/hashicorp/consul/tlsutil"
|
"github.com/hashicorp/consul/tlsutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -97,13 +97,8 @@ func (t *Transport) WriteToAddress(b []byte, addr memberlist.Address) (time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
if dc != t.datacenter {
|
if dc != t.datacenter {
|
||||||
gwAddr := t.gwResolver(dc)
|
|
||||||
if gwAddr == "" {
|
|
||||||
return time.Time{}, structs.ErrDCNotAvailable
|
|
||||||
}
|
|
||||||
|
|
||||||
dialFunc := func() (net.Conn, error) {
|
dialFunc := func() (net.Conn, error) {
|
||||||
return t.dial(dc, node, pool.ALPN_WANGossipPacket, gwAddr)
|
return t.dial(dc, node, pool.ALPN_WANGossipPacket)
|
||||||
}
|
}
|
||||||
conn, err := t.pool.AcquireOrDial(addr.Name, dialFunc)
|
conn, err := t.pool.AcquireOrDial(addr.Name, dialFunc)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -136,42 +131,24 @@ func (t *Transport) DialAddressTimeout(addr memberlist.Address, timeout time.Dur
|
||||||
}
|
}
|
||||||
|
|
||||||
if dc != t.datacenter {
|
if dc != t.datacenter {
|
||||||
gwAddr := t.gwResolver(dc)
|
return t.dial(dc, node, pool.ALPN_WANGossipStream)
|
||||||
if gwAddr == "" {
|
|
||||||
return nil, structs.ErrDCNotAvailable
|
|
||||||
}
|
|
||||||
|
|
||||||
return t.dial(dc, node, pool.ALPN_WANGossipStream, gwAddr)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return t.IngestionAwareTransport.DialAddressTimeout(addr, timeout)
|
return t.IngestionAwareTransport.DialAddressTimeout(addr, timeout)
|
||||||
}
|
}
|
||||||
|
|
||||||
// NOTE: There is a close mirror of this method in agent/pool/pool.go:DialTimeoutWithRPCType
|
func (t *Transport) dial(dc, nodeName, nextProto string) (net.Conn, error) {
|
||||||
func (t *Transport) dial(dc, nodeName, nextProto, addr string) (net.Conn, error) {
|
conn, _, err := pool.DialRPCViaMeshGateway(
|
||||||
wrapper := t.tlsConfigurator.OutgoingALPNRPCWrapper()
|
context.Background(),
|
||||||
if wrapper == nil {
|
dc,
|
||||||
return nil, fmt.Errorf("wanfed: cannot dial via a mesh gateway when outgoing TLS is disabled")
|
nodeName,
|
||||||
}
|
nil, // TODO(rb): thread source address through here?
|
||||||
|
t.tlsConfigurator.OutgoingALPNRPCWrapper(),
|
||||||
dialer := &net.Dialer{Timeout: 10 * time.Second}
|
nextProto,
|
||||||
|
true,
|
||||||
rawConn, err := dialer.Dial("tcp", addr)
|
t.gwResolver,
|
||||||
if err != nil {
|
)
|
||||||
return nil, err
|
return conn, err
|
||||||
}
|
|
||||||
|
|
||||||
if tcp, ok := rawConn.(*net.TCPConn); ok {
|
|
||||||
_ = tcp.SetKeepAlive(true)
|
|
||||||
_ = tcp.SetNoDelay(true)
|
|
||||||
}
|
|
||||||
|
|
||||||
tlsConn, err := wrapper(dc, nodeName, nextProto, rawConn)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return tlsConn, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// SplitNodeName splits a node name as it would be represented in
|
// SplitNodeName splits a node name as it would be represented in
|
||||||
|
|
|
@ -12,38 +12,93 @@ import (
|
||||||
|
|
||||||
"github.com/hashicorp/consul/agent/metadata"
|
"github.com/hashicorp/consul/agent/metadata"
|
||||||
"github.com/hashicorp/consul/agent/pool"
|
"github.com/hashicorp/consul/agent/pool"
|
||||||
|
"github.com/hashicorp/consul/tlsutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ClientConnPool creates and stores a connection for each datacenter.
|
// ClientConnPool creates and stores a connection for each datacenter.
|
||||||
type ClientConnPool struct {
|
type ClientConnPool struct {
|
||||||
dialer dialer
|
dialer dialer
|
||||||
servers ServerLocator
|
servers ServerLocator
|
||||||
|
gwResolverDep gatewayResolverDep
|
||||||
conns map[string]*grpc.ClientConn
|
conns map[string]*grpc.ClientConn
|
||||||
connsLock sync.Mutex
|
connsLock sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
type ServerLocator interface {
|
type ServerLocator interface {
|
||||||
// ServerForAddr is used to look up server metadata from an address.
|
// ServerForGlobalAddr returns server metadata for a server with the specified globally unique address.
|
||||||
ServerForAddr(addr string) (*metadata.Server, error)
|
ServerForGlobalAddr(globalAddr string) (*metadata.Server, error)
|
||||||
|
|
||||||
// Authority returns the target authority to use to dial the server. This is primarily
|
// Authority returns the target authority to use to dial the server. This is primarily
|
||||||
// needed for testing multiple agents in parallel, because gRPC requires the
|
// needed for testing multiple agents in parallel, because gRPC requires the
|
||||||
// resolver to be registered globally.
|
// resolver to be registered globally.
|
||||||
Authority() string
|
Authority() string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// gatewayResolverDep is just a holder for a function pointer that can be
|
||||||
|
// updated lazily after the structs are instantiated (but before first use)
|
||||||
|
// and all structs with a reference to this struct will see the same update.
|
||||||
|
type gatewayResolverDep struct {
|
||||||
|
// GatewayResolver is a function that returns a suitable random mesh
|
||||||
|
// gateway address for dialing servers in a given DC. This is only
|
||||||
|
// needed if wan federation via mesh gateways is enabled.
|
||||||
|
GatewayResolver func(string) string
|
||||||
|
}
|
||||||
|
|
||||||
// TLSWrapper wraps a non-TLS connection and returns a connection with TLS
|
// TLSWrapper wraps a non-TLS connection and returns a connection with TLS
|
||||||
// enabled.
|
// enabled.
|
||||||
type TLSWrapper func(dc string, conn net.Conn) (net.Conn, error)
|
type TLSWrapper func(dc string, conn net.Conn) (net.Conn, error)
|
||||||
|
|
||||||
|
// ALPNWrapper is a function that is used to wrap a non-TLS connection and
|
||||||
|
// returns an appropriate TLS connection or error. This taks a datacenter and
|
||||||
|
// node name as argument to configure the desired SNI value and the desired
|
||||||
|
// next proto for configuring ALPN.
|
||||||
|
type ALPNWrapper func(dc, nodeName, alpnProto string, conn net.Conn) (net.Conn, error)
|
||||||
|
|
||||||
type dialer func(context.Context, string) (net.Conn, error)
|
type dialer func(context.Context, string) (net.Conn, error)
|
||||||
|
|
||||||
// NewClientConnPool create new GRPC client pool to connect to servers using GRPC over RPC
|
type ClientConnPoolConfig struct {
|
||||||
func NewClientConnPool(servers ServerLocator, tls TLSWrapper, useTLSForDC func(dc string) bool) *ClientConnPool {
|
// Servers is a reference for how to figure out how to dial any server.
|
||||||
return &ClientConnPool{
|
Servers ServerLocator
|
||||||
dialer: newDialer(servers, tls, useTLSForDC),
|
|
||||||
servers: servers,
|
// SrcAddr is the source address for outgoing connections.
|
||||||
|
SrcAddr *net.TCPAddr
|
||||||
|
|
||||||
|
// TLSWrapper is the specifics of wrapping a socket when doing an TYPE_BYTE+TLS
|
||||||
|
// wrapped RPC request.
|
||||||
|
TLSWrapper TLSWrapper
|
||||||
|
|
||||||
|
// ALPNWrapper is the specifics of wrapping a socket when doing an ALPN+TLS
|
||||||
|
// wrapped RPC request (typically only for wan federation via mesh
|
||||||
|
// gateways).
|
||||||
|
ALPNWrapper ALPNWrapper
|
||||||
|
|
||||||
|
// UseTLSForDC is a function to determine if dialing a given datacenter
|
||||||
|
// should use TLS.
|
||||||
|
UseTLSForDC func(dc string) bool
|
||||||
|
|
||||||
|
// DialingFromServer should be set to true if this connection pool is owned
|
||||||
|
// by a consul server instance.
|
||||||
|
DialingFromServer bool
|
||||||
|
|
||||||
|
// DialingFromDatacenter is the datacenter of the consul agent using this
|
||||||
|
// pool.
|
||||||
|
DialingFromDatacenter string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewClientConnPool create new GRPC client pool to connect to servers using
|
||||||
|
// GRPC over RPC.
|
||||||
|
func NewClientConnPool(cfg ClientConnPoolConfig) *ClientConnPool {
|
||||||
|
c := &ClientConnPool{
|
||||||
|
servers: cfg.Servers,
|
||||||
conns: make(map[string]*grpc.ClientConn),
|
conns: make(map[string]*grpc.ClientConn),
|
||||||
}
|
}
|
||||||
|
c.dialer = newDialer(cfg, &c.gwResolverDep)
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetGatewayResolver is only to be called during setup before the pool is used.
|
||||||
|
func (c *ClientConnPool) SetGatewayResolver(gatewayResolver func(string) string) {
|
||||||
|
c.gwResolverDep.GatewayResolver = gatewayResolver
|
||||||
}
|
}
|
||||||
|
|
||||||
// ClientConn returns a grpc.ClientConn for the datacenter. If there are no
|
// ClientConn returns a grpc.ClientConn for the datacenter. If there are no
|
||||||
|
@ -102,22 +157,39 @@ func (c *ClientConnPool) dial(datacenter string, serverType string) (*grpc.Clien
|
||||||
|
|
||||||
// newDialer returns a gRPC dialer function that conditionally wraps the connection
|
// newDialer returns a gRPC dialer function that conditionally wraps the connection
|
||||||
// with TLS based on the Server.useTLS value.
|
// with TLS based on the Server.useTLS value.
|
||||||
func newDialer(servers ServerLocator, wrapper TLSWrapper, useTLSForDC func(dc string) bool) func(context.Context, string) (net.Conn, error) {
|
func newDialer(cfg ClientConnPoolConfig, gwResolverDep *gatewayResolverDep) func(context.Context, string) (net.Conn, error) {
|
||||||
return func(ctx context.Context, addr string) (net.Conn, error) {
|
return func(ctx context.Context, globalAddr string) (net.Conn, error) {
|
||||||
d := net.Dialer{}
|
server, err := cfg.Servers.ServerForGlobalAddr(globalAddr)
|
||||||
conn, err := d.DialContext(ctx, "tcp", addr)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
server, err := servers.ServerForAddr(addr)
|
if cfg.DialingFromServer &&
|
||||||
|
gwResolverDep.GatewayResolver != nil &&
|
||||||
|
cfg.ALPNWrapper != nil &&
|
||||||
|
server.Datacenter != cfg.DialingFromDatacenter {
|
||||||
|
// NOTE: TLS is required on this branch.
|
||||||
|
conn, _, err := pool.DialRPCViaMeshGateway(
|
||||||
|
ctx,
|
||||||
|
server.Datacenter,
|
||||||
|
server.ShortName,
|
||||||
|
cfg.SrcAddr,
|
||||||
|
tlsutil.ALPNWrapper(cfg.ALPNWrapper),
|
||||||
|
pool.ALPN_RPCGRPC,
|
||||||
|
cfg.DialingFromServer,
|
||||||
|
gwResolverDep.GatewayResolver,
|
||||||
|
)
|
||||||
|
return conn, err
|
||||||
|
}
|
||||||
|
|
||||||
|
d := net.Dialer{LocalAddr: cfg.SrcAddr, Timeout: pool.DefaultDialTimeout}
|
||||||
|
conn, err := d.DialContext(ctx, "tcp", server.Addr.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
conn.Close()
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if server.UseTLS && useTLSForDC(server.Datacenter) {
|
if server.UseTLS && cfg.UseTLSForDC(server.Datacenter) {
|
||||||
if wrapper == nil {
|
if cfg.TLSWrapper == nil {
|
||||||
conn.Close()
|
conn.Close()
|
||||||
return nil, fmt.Errorf("TLS enabled but got nil TLS wrapper")
|
return nil, fmt.Errorf("TLS enabled but got nil TLS wrapper")
|
||||||
}
|
}
|
||||||
|
@ -129,7 +201,7 @@ func newDialer(servers ServerLocator, wrapper TLSWrapper, useTLSForDC func(dc st
|
||||||
}
|
}
|
||||||
|
|
||||||
// Wrap the connection in a TLS client
|
// Wrap the connection in a TLS client
|
||||||
tlsConn, err := wrapper(server.Datacenter, conn)
|
tlsConn, err := cfg.TLSWrapper(server.Datacenter, conn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
conn.Close()
|
conn.Close()
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -137,7 +209,7 @@ func newDialer(servers ServerLocator, wrapper TLSWrapper, useTLSForDC func(dc st
|
||||||
conn = tlsConn
|
conn = tlsConn
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = conn.Write([]byte{pool.RPCGRPC})
|
_, err = conn.Write([]byte{byte(pool.RPCGRPC)})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
conn.Close()
|
conn.Close()
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
|
@ -4,17 +4,22 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/tcpproxy"
|
||||||
"github.com/hashicorp/go-hclog"
|
"github.com/hashicorp/go-hclog"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/hashicorp/consul/agent/grpc/internal/testservice"
|
"github.com/hashicorp/consul/agent/grpc/internal/testservice"
|
||||||
"github.com/hashicorp/consul/agent/grpc/resolver"
|
"github.com/hashicorp/consul/agent/grpc/resolver"
|
||||||
"github.com/hashicorp/consul/agent/metadata"
|
"github.com/hashicorp/consul/agent/metadata"
|
||||||
|
"github.com/hashicorp/consul/ipaddr"
|
||||||
|
"github.com/hashicorp/consul/sdk/freeport"
|
||||||
"github.com/hashicorp/consul/tlsutil"
|
"github.com/hashicorp/consul/tlsutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -24,11 +29,14 @@ func useTLSForDcAlwaysTrue(_ string) bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNewDialer_WithTLSWrapper(t *testing.T) {
|
func TestNewDialer_WithTLSWrapper(t *testing.T) {
|
||||||
lis, err := net.Listen("tcp", "127.0.0.1:0")
|
ports := freeport.MustTake(1)
|
||||||
|
defer freeport.Return(ports)
|
||||||
|
|
||||||
|
lis, err := net.Listen("tcp", net.JoinHostPort("127.0.0.1", strconv.Itoa(ports[0])))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
t.Cleanup(logError(t, lis.Close))
|
t.Cleanup(logError(t, lis.Close))
|
||||||
|
|
||||||
builder := resolver.NewServerResolverBuilder(resolver.Config{})
|
builder := resolver.NewServerResolverBuilder(newConfig(t))
|
||||||
builder.AddServer(&metadata.Server{
|
builder.AddServer(&metadata.Server{
|
||||||
Name: "server-1",
|
Name: "server-1",
|
||||||
ID: "ID1",
|
ID: "ID1",
|
||||||
|
@ -42,19 +50,107 @@ func TestNewDialer_WithTLSWrapper(t *testing.T) {
|
||||||
called = true
|
called = true
|
||||||
return conn, nil
|
return conn, nil
|
||||||
}
|
}
|
||||||
dial := newDialer(builder, wrapper, useTLSForDcAlwaysTrue)
|
dial := newDialer(
|
||||||
|
ClientConnPoolConfig{
|
||||||
|
Servers: builder,
|
||||||
|
TLSWrapper: wrapper,
|
||||||
|
UseTLSForDC: useTLSForDcAlwaysTrue,
|
||||||
|
DialingFromServer: true,
|
||||||
|
DialingFromDatacenter: "dc1",
|
||||||
|
},
|
||||||
|
&gatewayResolverDep{},
|
||||||
|
)
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
conn, err := dial(ctx, lis.Addr().String())
|
conn, err := dial(ctx, resolver.DCPrefix("dc1", lis.Addr().String()))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NoError(t, conn.Close())
|
require.NoError(t, conn.Close())
|
||||||
require.True(t, called, "expected TLSWrapper to be called")
|
require.True(t, called, "expected TLSWrapper to be called")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNewDialer_WithALPNWrapper(t *testing.T) {
|
||||||
|
ports := freeport.MustTake(3)
|
||||||
|
defer freeport.Return(ports)
|
||||||
|
|
||||||
|
var (
|
||||||
|
s1addr = ipaddr.FormatAddressPort("127.0.0.1", ports[0])
|
||||||
|
s2addr = ipaddr.FormatAddressPort("127.0.0.1", ports[1])
|
||||||
|
gwAddr = ipaddr.FormatAddressPort("127.0.0.1", ports[2])
|
||||||
|
)
|
||||||
|
|
||||||
|
lis1, err := net.Listen("tcp", s1addr)
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(logError(t, lis1.Close))
|
||||||
|
|
||||||
|
lis2, err := net.Listen("tcp", s2addr)
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(logError(t, lis2.Close))
|
||||||
|
|
||||||
|
// Send all of the traffic to dc2's server
|
||||||
|
var p tcpproxy.Proxy
|
||||||
|
p.AddRoute(gwAddr, tcpproxy.To(s2addr))
|
||||||
|
p.AddStopACMESearch(gwAddr)
|
||||||
|
require.NoError(t, p.Start())
|
||||||
|
defer func() {
|
||||||
|
p.Close()
|
||||||
|
p.Wait()
|
||||||
|
}()
|
||||||
|
|
||||||
|
builder := resolver.NewServerResolverBuilder(newConfig(t))
|
||||||
|
builder.AddServer(&metadata.Server{
|
||||||
|
Name: "server-1",
|
||||||
|
ID: "ID1",
|
||||||
|
Datacenter: "dc1",
|
||||||
|
Addr: lis1.Addr(),
|
||||||
|
UseTLS: true,
|
||||||
|
})
|
||||||
|
builder.AddServer(&metadata.Server{
|
||||||
|
Name: "server-2",
|
||||||
|
ID: "ID2",
|
||||||
|
Datacenter: "dc2",
|
||||||
|
Addr: lis2.Addr(),
|
||||||
|
UseTLS: true,
|
||||||
|
})
|
||||||
|
|
||||||
|
var calledTLS bool
|
||||||
|
wrapperTLS := func(_ string, conn net.Conn) (net.Conn, error) {
|
||||||
|
calledTLS = true
|
||||||
|
return conn, nil
|
||||||
|
}
|
||||||
|
var calledALPN bool
|
||||||
|
wrapperALPN := func(_, _, _ string, conn net.Conn) (net.Conn, error) {
|
||||||
|
calledALPN = true
|
||||||
|
return conn, nil
|
||||||
|
}
|
||||||
|
gwResolverDep := &gatewayResolverDep{
|
||||||
|
GatewayResolver: func(addr string) string {
|
||||||
|
return gwAddr
|
||||||
|
},
|
||||||
|
}
|
||||||
|
dial := newDialer(
|
||||||
|
ClientConnPoolConfig{
|
||||||
|
Servers: builder,
|
||||||
|
TLSWrapper: wrapperTLS,
|
||||||
|
ALPNWrapper: wrapperALPN,
|
||||||
|
UseTLSForDC: useTLSForDcAlwaysTrue,
|
||||||
|
DialingFromServer: true,
|
||||||
|
DialingFromDatacenter: "dc1",
|
||||||
|
},
|
||||||
|
gwResolverDep,
|
||||||
|
)
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
conn, err := dial(ctx, resolver.DCPrefix("dc2", lis2.Addr().String()))
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, conn.Close())
|
||||||
|
|
||||||
|
assert.False(t, calledTLS, "expected TLSWrapper not to be called")
|
||||||
|
assert.True(t, calledALPN, "expected ALPNWrapper to be called")
|
||||||
|
}
|
||||||
|
|
||||||
func TestNewDialer_IntegrationWithTLSEnabledHandler(t *testing.T) {
|
func TestNewDialer_IntegrationWithTLSEnabledHandler(t *testing.T) {
|
||||||
res := resolver.NewServerResolverBuilder(newConfig(t))
|
res := resolver.NewServerResolverBuilder(newConfig(t))
|
||||||
registerWithGRPC(t, res)
|
registerWithGRPC(t, res)
|
||||||
|
|
||||||
srv := newTestServer(t, "server-1", "dc1")
|
|
||||||
tlsConf, err := tlsutil.NewConfigurator(tlsutil.Config{
|
tlsConf, err := tlsutil.NewConfigurator(tlsutil.Config{
|
||||||
VerifyIncoming: true,
|
VerifyIncoming: true,
|
||||||
VerifyOutgoing: true,
|
VerifyOutgoing: true,
|
||||||
|
@ -63,12 +159,20 @@ func TestNewDialer_IntegrationWithTLSEnabledHandler(t *testing.T) {
|
||||||
KeyFile: "../../test/hostname/Alice.key",
|
KeyFile: "../../test/hostname/Alice.key",
|
||||||
}, hclog.New(nil))
|
}, hclog.New(nil))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
srv.rpc.tlsConf = tlsConf
|
|
||||||
|
|
||||||
res.AddServer(srv.Metadata())
|
srv := newTestServer(t, "server-1", "dc1", tlsConf)
|
||||||
|
|
||||||
|
md := srv.Metadata()
|
||||||
|
res.AddServer(md)
|
||||||
t.Cleanup(srv.shutdown)
|
t.Cleanup(srv.shutdown)
|
||||||
|
|
||||||
pool := NewClientConnPool(res, TLSWrapper(tlsConf.OutgoingRPCWrapper()), tlsConf.UseTLS)
|
pool := NewClientConnPool(ClientConnPoolConfig{
|
||||||
|
Servers: res,
|
||||||
|
TLSWrapper: TLSWrapper(tlsConf.OutgoingRPCWrapper()),
|
||||||
|
UseTLSForDC: tlsConf.UseTLS,
|
||||||
|
DialingFromServer: true,
|
||||||
|
DialingFromDatacenter: "dc1",
|
||||||
|
})
|
||||||
|
|
||||||
conn, err := pool.ClientConn("dc1")
|
conn, err := pool.ClientConn("dc1")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
@ -81,17 +185,98 @@ func TestNewDialer_IntegrationWithTLSEnabledHandler(t *testing.T) {
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, "server-1", resp.ServerName)
|
require.Equal(t, "server-1", resp.ServerName)
|
||||||
require.True(t, atomic.LoadInt32(&srv.rpc.tlsConnEstablished) > 0)
|
require.True(t, atomic.LoadInt32(&srv.rpc.tlsConnEstablished) > 0)
|
||||||
|
require.True(t, atomic.LoadInt32(&srv.rpc.alpnConnEstablished) == 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewDialer_IntegrationWithTLSEnabledHandler_viaMeshGateway(t *testing.T) {
|
||||||
|
ports := freeport.MustTake(1)
|
||||||
|
defer freeport.Return(ports)
|
||||||
|
|
||||||
|
gwAddr := ipaddr.FormatAddressPort("127.0.0.1", ports[0])
|
||||||
|
|
||||||
|
res := resolver.NewServerResolverBuilder(newConfig(t))
|
||||||
|
registerWithGRPC(t, res)
|
||||||
|
|
||||||
|
tlsConf, err := tlsutil.NewConfigurator(tlsutil.Config{
|
||||||
|
VerifyIncoming: true,
|
||||||
|
VerifyOutgoing: true,
|
||||||
|
VerifyServerHostname: true,
|
||||||
|
CAFile: "../../test/hostname/CertAuth.crt",
|
||||||
|
CertFile: "../../test/hostname/Bob.crt",
|
||||||
|
KeyFile: "../../test/hostname/Bob.key",
|
||||||
|
Domain: "consul",
|
||||||
|
NodeName: "bob",
|
||||||
|
}, hclog.New(nil))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
srv := newTestServer(t, "bob", "dc1", tlsConf)
|
||||||
|
|
||||||
|
// Send all of the traffic to dc1's server
|
||||||
|
var p tcpproxy.Proxy
|
||||||
|
p.AddRoute(gwAddr, tcpproxy.To(srv.addr.String()))
|
||||||
|
p.AddStopACMESearch(gwAddr)
|
||||||
|
require.NoError(t, p.Start())
|
||||||
|
defer func() {
|
||||||
|
p.Close()
|
||||||
|
p.Wait()
|
||||||
|
}()
|
||||||
|
|
||||||
|
md := srv.Metadata()
|
||||||
|
res.AddServer(md)
|
||||||
|
t.Cleanup(srv.shutdown)
|
||||||
|
|
||||||
|
clientTLSConf, err := tlsutil.NewConfigurator(tlsutil.Config{
|
||||||
|
VerifyIncoming: true,
|
||||||
|
VerifyOutgoing: true,
|
||||||
|
VerifyServerHostname: true,
|
||||||
|
CAFile: "../../test/hostname/CertAuth.crt",
|
||||||
|
CertFile: "../../test/hostname/Betty.crt",
|
||||||
|
KeyFile: "../../test/hostname/Betty.key",
|
||||||
|
Domain: "consul",
|
||||||
|
NodeName: "betty",
|
||||||
|
}, hclog.New(nil))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
pool := NewClientConnPool(ClientConnPoolConfig{
|
||||||
|
Servers: res,
|
||||||
|
TLSWrapper: TLSWrapper(clientTLSConf.OutgoingRPCWrapper()),
|
||||||
|
ALPNWrapper: ALPNWrapper(clientTLSConf.OutgoingALPNRPCWrapper()),
|
||||||
|
UseTLSForDC: tlsConf.UseTLS,
|
||||||
|
DialingFromServer: true,
|
||||||
|
DialingFromDatacenter: "dc2",
|
||||||
|
})
|
||||||
|
pool.SetGatewayResolver(func(addr string) string {
|
||||||
|
return gwAddr
|
||||||
|
})
|
||||||
|
|
||||||
|
conn, err := pool.ClientConn("dc1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
client := testservice.NewSimpleClient(conn)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
|
t.Cleanup(cancel)
|
||||||
|
|
||||||
|
resp, err := client.Something(ctx, &testservice.Req{})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "bob", resp.ServerName)
|
||||||
|
require.True(t, atomic.LoadInt32(&srv.rpc.tlsConnEstablished) == 0)
|
||||||
|
require.True(t, atomic.LoadInt32(&srv.rpc.alpnConnEstablished) > 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestClientConnPool_IntegrationWithGRPCResolver_Failover(t *testing.T) {
|
func TestClientConnPool_IntegrationWithGRPCResolver_Failover(t *testing.T) {
|
||||||
count := 4
|
count := 4
|
||||||
res := resolver.NewServerResolverBuilder(newConfig(t))
|
res := resolver.NewServerResolverBuilder(newConfig(t))
|
||||||
registerWithGRPC(t, res)
|
registerWithGRPC(t, res)
|
||||||
pool := NewClientConnPool(res, nil, useTLSForDcAlwaysTrue)
|
pool := NewClientConnPool(ClientConnPoolConfig{
|
||||||
|
Servers: res,
|
||||||
|
UseTLSForDC: useTLSForDcAlwaysTrue,
|
||||||
|
DialingFromServer: true,
|
||||||
|
DialingFromDatacenter: "dc1",
|
||||||
|
})
|
||||||
|
|
||||||
for i := 0; i < count; i++ {
|
for i := 0; i < count; i++ {
|
||||||
name := fmt.Sprintf("server-%d", i)
|
name := fmt.Sprintf("server-%d", i)
|
||||||
srv := newTestServer(t, name, "dc1")
|
srv := newTestServer(t, name, "dc1", nil)
|
||||||
res.AddServer(srv.Metadata())
|
res.AddServer(srv.Metadata())
|
||||||
t.Cleanup(srv.shutdown)
|
t.Cleanup(srv.shutdown)
|
||||||
}
|
}
|
||||||
|
@ -115,22 +300,27 @@ func TestClientConnPool_IntegrationWithGRPCResolver_Failover(t *testing.T) {
|
||||||
|
|
||||||
func TestClientConnPool_ForwardToLeader_Failover(t *testing.T) {
|
func TestClientConnPool_ForwardToLeader_Failover(t *testing.T) {
|
||||||
count := 3
|
count := 3
|
||||||
conf := newConfig(t)
|
res := resolver.NewServerResolverBuilder(newConfig(t))
|
||||||
res := resolver.NewServerResolverBuilder(conf)
|
|
||||||
registerWithGRPC(t, res)
|
registerWithGRPC(t, res)
|
||||||
pool := NewClientConnPool(res, nil, useTLSForDcAlwaysTrue)
|
pool := NewClientConnPool(ClientConnPoolConfig{
|
||||||
|
Servers: res,
|
||||||
|
UseTLSForDC: useTLSForDcAlwaysTrue,
|
||||||
|
DialingFromServer: true,
|
||||||
|
DialingFromDatacenter: "dc1",
|
||||||
|
})
|
||||||
|
|
||||||
var servers []testServer
|
var servers []testServer
|
||||||
for i := 0; i < count; i++ {
|
for i := 0; i < count; i++ {
|
||||||
name := fmt.Sprintf("server-%d", i)
|
name := fmt.Sprintf("server-%d", i)
|
||||||
srv := newTestServer(t, name, "dc1")
|
srv := newTestServer(t, name, "dc1", nil)
|
||||||
res.AddServer(srv.Metadata())
|
res.AddServer(srv.Metadata())
|
||||||
servers = append(servers, srv)
|
servers = append(servers, srv)
|
||||||
t.Cleanup(srv.shutdown)
|
t.Cleanup(srv.shutdown)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set the leader address to the first server.
|
// Set the leader address to the first server.
|
||||||
res.UpdateLeaderAddr(servers[0].addr.String())
|
srv0 := servers[0].Metadata()
|
||||||
|
res.UpdateLeaderAddr(srv0.Datacenter, srv0.Addr.String())
|
||||||
|
|
||||||
conn, err := pool.ClientConnLeader()
|
conn, err := pool.ClientConnLeader()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
@ -144,7 +334,8 @@ func TestClientConnPool_ForwardToLeader_Failover(t *testing.T) {
|
||||||
require.Equal(t, first.ServerName, servers[0].name)
|
require.Equal(t, first.ServerName, servers[0].name)
|
||||||
|
|
||||||
// Update the leader address and make another request.
|
// Update the leader address and make another request.
|
||||||
res.UpdateLeaderAddr(servers[1].addr.String())
|
srv1 := servers[1].Metadata()
|
||||||
|
res.UpdateLeaderAddr(srv1.Datacenter, srv1.Addr.String())
|
||||||
|
|
||||||
resp, err := client.Something(ctx, &testservice.Req{})
|
resp, err := client.Something(ctx, &testservice.Req{})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
@ -162,11 +353,16 @@ func TestClientConnPool_IntegrationWithGRPCResolver_Rebalance(t *testing.T) {
|
||||||
count := 5
|
count := 5
|
||||||
res := resolver.NewServerResolverBuilder(newConfig(t))
|
res := resolver.NewServerResolverBuilder(newConfig(t))
|
||||||
registerWithGRPC(t, res)
|
registerWithGRPC(t, res)
|
||||||
pool := NewClientConnPool(res, nil, useTLSForDcAlwaysTrue)
|
pool := NewClientConnPool(ClientConnPoolConfig{
|
||||||
|
Servers: res,
|
||||||
|
UseTLSForDC: useTLSForDcAlwaysTrue,
|
||||||
|
DialingFromServer: true,
|
||||||
|
DialingFromDatacenter: "dc1",
|
||||||
|
})
|
||||||
|
|
||||||
for i := 0; i < count; i++ {
|
for i := 0; i < count; i++ {
|
||||||
name := fmt.Sprintf("server-%d", i)
|
name := fmt.Sprintf("server-%d", i)
|
||||||
srv := newTestServer(t, name, "dc1")
|
srv := newTestServer(t, name, "dc1", nil)
|
||||||
res.AddServer(srv.Metadata())
|
res.AddServer(srv.Metadata())
|
||||||
t.Cleanup(srv.shutdown)
|
t.Cleanup(srv.shutdown)
|
||||||
}
|
}
|
||||||
|
@ -211,11 +407,16 @@ func TestClientConnPool_IntegrationWithGRPCResolver_MultiDC(t *testing.T) {
|
||||||
|
|
||||||
res := resolver.NewServerResolverBuilder(newConfig(t))
|
res := resolver.NewServerResolverBuilder(newConfig(t))
|
||||||
registerWithGRPC(t, res)
|
registerWithGRPC(t, res)
|
||||||
pool := NewClientConnPool(res, nil, useTLSForDcAlwaysTrue)
|
pool := NewClientConnPool(ClientConnPoolConfig{
|
||||||
|
Servers: res,
|
||||||
|
UseTLSForDC: useTLSForDcAlwaysTrue,
|
||||||
|
DialingFromServer: true,
|
||||||
|
DialingFromDatacenter: "dc1",
|
||||||
|
})
|
||||||
|
|
||||||
for _, dc := range dcs {
|
for _, dc := range dcs {
|
||||||
name := "server-0-" + dc
|
name := "server-0-" + dc
|
||||||
srv := newTestServer(t, name, dc)
|
srv := newTestServer(t, name, dc, nil)
|
||||||
res.AddServer(srv.Metadata())
|
res.AddServer(srv.Metadata())
|
||||||
t.Cleanup(srv.shutdown)
|
t.Cleanup(srv.shutdown)
|
||||||
}
|
}
|
||||||
|
|
|
@ -67,17 +67,17 @@ func (s *ServerResolverBuilder) NewRebalancer(dc string) func() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ServerForAddr returns server metadata for a server with the specified address.
|
// ServerForGlobalAddr returns server metadata for a server with the specified globally unique address.
|
||||||
func (s *ServerResolverBuilder) ServerForAddr(addr string) (*metadata.Server, error) {
|
func (s *ServerResolverBuilder) ServerForGlobalAddr(globalAddr string) (*metadata.Server, error) {
|
||||||
s.lock.RLock()
|
s.lock.RLock()
|
||||||
defer s.lock.RUnlock()
|
defer s.lock.RUnlock()
|
||||||
|
|
||||||
for _, server := range s.servers {
|
for _, server := range s.servers {
|
||||||
if server.Addr.String() == addr {
|
if DCPrefix(server.Datacenter, server.Addr.String()) == globalAddr {
|
||||||
return server, nil
|
return server, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil, fmt.Errorf("failed to find Consul server for address %q", addr)
|
return nil, fmt.Errorf("failed to find Consul server for global address %q", globalAddr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Build returns a new serverResolver for the given ClientConn. The resolver
|
// Build returns a new serverResolver for the given ClientConn. The resolver
|
||||||
|
@ -161,6 +161,12 @@ func uniqueID(server *metadata.Server) string {
|
||||||
return server.Datacenter + "-" + server.ID
|
return server.Datacenter + "-" + server.ID
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DCPrefix prefixes the given string with a datacenter for use in
|
||||||
|
// disambiguation.
|
||||||
|
func DCPrefix(datacenter, suffix string) string {
|
||||||
|
return datacenter + "-" + suffix
|
||||||
|
}
|
||||||
|
|
||||||
// RemoveServer updates the resolvers' states with the given server removed.
|
// RemoveServer updates the resolvers' states with the given server removed.
|
||||||
func (s *ServerResolverBuilder) RemoveServer(server *metadata.Server) {
|
func (s *ServerResolverBuilder) RemoveServer(server *metadata.Server) {
|
||||||
s.lock.Lock()
|
s.lock.Lock()
|
||||||
|
@ -186,7 +192,8 @@ func (s *ServerResolverBuilder) getDCAddrs(dc string) []resolver.Address {
|
||||||
}
|
}
|
||||||
|
|
||||||
addrs = append(addrs, resolver.Address{
|
addrs = append(addrs, resolver.Address{
|
||||||
Addr: server.Addr.String(),
|
// NOTE: the address persisted here is only dialable using our custom dialer
|
||||||
|
Addr: DCPrefix(server.Datacenter, server.Addr.String()),
|
||||||
Type: resolver.Backend,
|
Type: resolver.Backend,
|
||||||
ServerName: server.Name,
|
ServerName: server.Name,
|
||||||
})
|
})
|
||||||
|
@ -195,11 +202,11 @@ func (s *ServerResolverBuilder) getDCAddrs(dc string) []resolver.Address {
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateLeaderAddr updates the leader address in the local DC's resolver.
|
// UpdateLeaderAddr updates the leader address in the local DC's resolver.
|
||||||
func (s *ServerResolverBuilder) UpdateLeaderAddr(leaderAddr string) {
|
func (s *ServerResolverBuilder) UpdateLeaderAddr(datacenter, addr string) {
|
||||||
s.lock.Lock()
|
s.lock.Lock()
|
||||||
defer s.lock.Unlock()
|
defer s.lock.Unlock()
|
||||||
|
|
||||||
s.leaderResolver.addr = leaderAddr
|
s.leaderResolver.globalAddr = DCPrefix(datacenter, addr)
|
||||||
s.leaderResolver.updateClientConn()
|
s.leaderResolver.updateClientConn()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -262,7 +269,7 @@ func (r *serverResolver) Close() {
|
||||||
func (*serverResolver) ResolveNow(resolver.ResolveNowOption) {}
|
func (*serverResolver) ResolveNow(resolver.ResolveNowOption) {}
|
||||||
|
|
||||||
type leaderResolver struct {
|
type leaderResolver struct {
|
||||||
addr string
|
globalAddr string
|
||||||
clientConn resolver.ClientConn
|
clientConn resolver.ClientConn
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -271,12 +278,13 @@ func (l leaderResolver) ResolveNow(resolver.ResolveNowOption) {}
|
||||||
func (l leaderResolver) Close() {}
|
func (l leaderResolver) Close() {}
|
||||||
|
|
||||||
func (l leaderResolver) updateClientConn() {
|
func (l leaderResolver) updateClientConn() {
|
||||||
if l.addr == "" || l.clientConn == nil {
|
if l.globalAddr == "" || l.clientConn == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
addrs := []resolver.Address{
|
addrs := []resolver.Address{
|
||||||
{
|
{
|
||||||
Addr: l.addr,
|
// NOTE: the address persisted here is only dialable using our custom dialer
|
||||||
|
Addr: l.globalAddr,
|
||||||
Type: resolver.Backend,
|
Type: resolver.Backend,
|
||||||
ServerName: "leader",
|
ServerName: "leader",
|
||||||
},
|
},
|
||||||
|
|
|
@ -6,6 +6,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
|
"strconv"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
@ -17,6 +18,7 @@ import (
|
||||||
"github.com/hashicorp/consul/agent/grpc/internal/testservice"
|
"github.com/hashicorp/consul/agent/grpc/internal/testservice"
|
||||||
"github.com/hashicorp/consul/agent/metadata"
|
"github.com/hashicorp/consul/agent/metadata"
|
||||||
"github.com/hashicorp/consul/agent/pool"
|
"github.com/hashicorp/consul/agent/pool"
|
||||||
|
"github.com/hashicorp/consul/sdk/freeport"
|
||||||
"github.com/hashicorp/consul/tlsutil"
|
"github.com/hashicorp/consul/tlsutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -31,22 +33,29 @@ type testServer struct {
|
||||||
func (s testServer) Metadata() *metadata.Server {
|
func (s testServer) Metadata() *metadata.Server {
|
||||||
return &metadata.Server{
|
return &metadata.Server{
|
||||||
ID: s.name,
|
ID: s.name,
|
||||||
|
Name: s.name + "." + s.dc,
|
||||||
|
ShortName: s.name,
|
||||||
Datacenter: s.dc,
|
Datacenter: s.dc,
|
||||||
Addr: s.addr,
|
Addr: s.addr,
|
||||||
UseTLS: s.rpc.tlsConf != nil,
|
UseTLS: s.rpc.tlsConf != nil,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTestServer(t *testing.T, name string, dc string) testServer {
|
func newTestServer(t *testing.T, name string, dc string, tlsConf *tlsutil.Configurator) testServer {
|
||||||
addr := &net.IPAddr{IP: net.ParseIP("127.0.0.1")}
|
addr := &net.IPAddr{IP: net.ParseIP("127.0.0.1")}
|
||||||
handler := NewHandler(addr, func(server *grpc.Server) {
|
handler := NewHandler(addr, func(server *grpc.Server) {
|
||||||
testservice.RegisterSimpleServer(server, &simple{name: name, dc: dc})
|
testservice.RegisterSimpleServer(server, &simple{name: name, dc: dc})
|
||||||
})
|
})
|
||||||
|
|
||||||
lis, err := net.Listen("tcp", "127.0.0.1:0")
|
ports := freeport.MustTake(1)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
freeport.Return(ports)
|
||||||
|
})
|
||||||
|
|
||||||
|
lis, err := net.Listen("tcp", net.JoinHostPort("127.0.0.1", strconv.Itoa(ports[0])))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
rpc := &fakeRPCListener{t: t, handler: handler}
|
rpc := &fakeRPCListener{t: t, handler: handler, tlsConf: tlsConf}
|
||||||
|
|
||||||
g := errgroup.Group{}
|
g := errgroup.Group{}
|
||||||
g.Go(func() error {
|
g.Go(func() error {
|
||||||
|
@ -112,6 +121,7 @@ type fakeRPCListener struct {
|
||||||
shutdown bool
|
shutdown bool
|
||||||
tlsConf *tlsutil.Configurator
|
tlsConf *tlsutil.Configurator
|
||||||
tlsConnEstablished int32
|
tlsConnEstablished int32
|
||||||
|
alpnConnEstablished int32
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *fakeRPCListener) listen(listener net.Listener) error {
|
func (f *fakeRPCListener) listen(listener net.Listener) error {
|
||||||
|
@ -129,6 +139,26 @@ func (f *fakeRPCListener) listen(listener net.Listener) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *fakeRPCListener) handleConn(conn net.Conn) {
|
func (f *fakeRPCListener) handleConn(conn net.Conn) {
|
||||||
|
if f.tlsConf != nil && f.tlsConf.MutualTLSCapable() {
|
||||||
|
// See if actually this is native TLS multiplexed onto the old
|
||||||
|
// "type-byte" system.
|
||||||
|
|
||||||
|
peekedConn, nativeTLS, err := pool.PeekForTLS(conn)
|
||||||
|
if err != nil {
|
||||||
|
if err != io.EOF {
|
||||||
|
fmt.Printf("ERROR: failed to read first byte: %v\n", err)
|
||||||
|
}
|
||||||
|
conn.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if nativeTLS {
|
||||||
|
f.handleNativeTLSConn(peekedConn)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
conn = peekedConn
|
||||||
|
}
|
||||||
|
|
||||||
buf := make([]byte, 1)
|
buf := make([]byte, 1)
|
||||||
|
|
||||||
if _, err := conn.Read(buf); err != nil {
|
if _, err := conn.Read(buf); err != nil {
|
||||||
|
@ -166,3 +196,32 @@ func (f *fakeRPCListener) handleConn(conn net.Conn) {
|
||||||
conn.Close()
|
conn.Close()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (f *fakeRPCListener) handleNativeTLSConn(conn net.Conn) {
|
||||||
|
tlscfg := f.tlsConf.IncomingALPNRPCConfig(pool.RPCNextProtos)
|
||||||
|
tlsConn := tls.Server(conn, tlscfg)
|
||||||
|
|
||||||
|
// Force the handshake to conclude.
|
||||||
|
if err := tlsConn.Handshake(); err != nil {
|
||||||
|
fmt.Printf("ERROR: TLS handshake failed: %v", err)
|
||||||
|
conn.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
conn.SetReadDeadline(time.Time{})
|
||||||
|
|
||||||
|
var (
|
||||||
|
cs = tlsConn.ConnectionState()
|
||||||
|
nextProto = cs.NegotiatedProtocol
|
||||||
|
)
|
||||||
|
|
||||||
|
switch nextProto {
|
||||||
|
case pool.ALPN_RPCGRPC:
|
||||||
|
atomic.AddInt32(&f.alpnConnEstablished, 1)
|
||||||
|
f.handler.Handle(tlsConn)
|
||||||
|
|
||||||
|
default:
|
||||||
|
fmt.Printf("ERROR: discarding RPC for unknown negotiated protocol %q\n", nextProto)
|
||||||
|
conn.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -20,6 +20,8 @@ func (t RPCType) ALPNString() string {
|
||||||
return ALPN_RPCGossip
|
return ALPN_RPCGossip
|
||||||
case RPCTLSInsecure:
|
case RPCTLSInsecure:
|
||||||
return "" // unsupported
|
return "" // unsupported
|
||||||
|
case RPCGRPC:
|
||||||
|
return ALPN_RPCGRPC
|
||||||
default:
|
default:
|
||||||
return "" // unsupported
|
return "" // unsupported
|
||||||
}
|
}
|
||||||
|
@ -28,19 +30,19 @@ func (t RPCType) ALPNString() string {
|
||||||
const (
|
const (
|
||||||
// keep numbers unique.
|
// keep numbers unique.
|
||||||
RPCConsul RPCType = 0
|
RPCConsul RPCType = 0
|
||||||
RPCRaft = 1
|
RPCRaft RPCType = 1
|
||||||
RPCMultiplex = 2 // Old Muxado byte, no longer supported.
|
RPCMultiplex RPCType = 2 // Old Muxado byte, no longer supported.
|
||||||
RPCTLS = 3
|
RPCTLS RPCType = 3
|
||||||
RPCMultiplexV2 = 4
|
RPCMultiplexV2 RPCType = 4
|
||||||
RPCSnapshot = 5
|
RPCSnapshot RPCType = 5
|
||||||
RPCGossip = 6
|
RPCGossip RPCType = 6
|
||||||
// RPCTLSInsecure is used to flag RPC calls that require verify
|
// RPCTLSInsecure is used to flag RPC calls that require verify
|
||||||
// incoming to be disabled, even when it is turned on in the
|
// incoming to be disabled, even when it is turned on in the
|
||||||
// configuration. At the time of writing there is only AutoEncrypt.Sign
|
// configuration. At the time of writing there is only AutoEncrypt.Sign
|
||||||
// that is supported and it might be the only one there
|
// that is supported and it might be the only one there
|
||||||
// ever is.
|
// ever is.
|
||||||
RPCTLSInsecure = 7
|
RPCTLSInsecure RPCType = 7
|
||||||
RPCGRPC = 8
|
RPCGRPC RPCType = 8
|
||||||
|
|
||||||
// RPCMaxTypeValue is the maximum rpc type byte value currently used for the
|
// RPCMaxTypeValue is the maximum rpc type byte value currently used for the
|
||||||
// various protocols riding over our "rpc" port.
|
// various protocols riding over our "rpc" port.
|
||||||
|
@ -79,6 +81,7 @@ var RPCNextProtos = []string{
|
||||||
ALPN_RPCMultiplexV2,
|
ALPN_RPCMultiplexV2,
|
||||||
ALPN_RPCSnapshot,
|
ALPN_RPCSnapshot,
|
||||||
ALPN_RPCGossip,
|
ALPN_RPCGossip,
|
||||||
|
ALPN_RPCGRPC,
|
||||||
ALPN_WANGossipPacket,
|
ALPN_WANGossipPacket,
|
||||||
ALPN_WANGossipStream,
|
ALPN_WANGossipStream,
|
||||||
}
|
}
|
||||||
|
|
|
@ -10,8 +10,9 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/hashicorp/consul/tlsutil"
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/hashicorp/consul/tlsutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestPeekForTLS_not_TLS(t *testing.T) {
|
func TestPeekForTLS_not_TLS(t *testing.T) {
|
||||||
|
@ -30,6 +31,7 @@ func TestPeekForTLS_not_TLS(t *testing.T) {
|
||||||
RPCSnapshot,
|
RPCSnapshot,
|
||||||
RPCGossip,
|
RPCGossip,
|
||||||
RPCTLSInsecure,
|
RPCTLSInsecure,
|
||||||
|
RPCGRPC,
|
||||||
} {
|
} {
|
||||||
cases = append(cases, testcase{
|
cases = append(cases, testcase{
|
||||||
name: fmt.Sprintf("tcp rpc type byte %d", rpcType),
|
name: fmt.Sprintf("tcp rpc type byte %d", rpcType),
|
||||||
|
@ -76,6 +78,7 @@ func TestPeekForTLS_actual_TLS(t *testing.T) {
|
||||||
RPCSnapshot,
|
RPCSnapshot,
|
||||||
RPCGossip,
|
RPCGossip,
|
||||||
RPCTLSInsecure,
|
RPCTLSInsecure,
|
||||||
|
RPCGRPC,
|
||||||
} {
|
} {
|
||||||
cases = append(cases, testcase{
|
cases = append(cases, testcase{
|
||||||
name: fmt.Sprintf("tcp rpc type byte %d", rpcType),
|
name: fmt.Sprintf("tcp rpc type byte %d", rpcType),
|
||||||
|
|
|
@ -2,6 +2,7 @@ package pool
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"container/list"
|
"container/list"
|
||||||
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
|
@ -11,14 +12,15 @@ import (
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc"
|
||||||
|
"github.com/hashicorp/yamux"
|
||||||
|
|
||||||
"github.com/hashicorp/consul/agent/structs"
|
"github.com/hashicorp/consul/agent/structs"
|
||||||
"github.com/hashicorp/consul/lib"
|
"github.com/hashicorp/consul/lib"
|
||||||
"github.com/hashicorp/consul/tlsutil"
|
"github.com/hashicorp/consul/tlsutil"
|
||||||
msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc"
|
|
||||||
"github.com/hashicorp/yamux"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const defaultDialTimeout = 10 * time.Second
|
const DefaultDialTimeout = 10 * time.Second
|
||||||
|
|
||||||
// muxSession is used to provide an interface for a stream multiplexer.
|
// muxSession is used to provide an interface for a stream multiplexer.
|
||||||
type muxSession interface {
|
type muxSession interface {
|
||||||
|
@ -291,21 +293,24 @@ func (p *ConnPool) DialTimeout(
|
||||||
) (net.Conn, HalfCloser, error) {
|
) (net.Conn, HalfCloser, error) {
|
||||||
p.once.Do(p.init)
|
p.once.Do(p.init)
|
||||||
|
|
||||||
if p.Server && p.GatewayResolver != nil && p.TLSConfigurator != nil && dc != p.Datacenter {
|
if p.Server &&
|
||||||
|
p.GatewayResolver != nil &&
|
||||||
|
p.TLSConfigurator != nil &&
|
||||||
|
dc != p.Datacenter {
|
||||||
// NOTE: TLS is required on this branch.
|
// NOTE: TLS is required on this branch.
|
||||||
return DialTimeoutWithRPCTypeViaMeshGateway(
|
nextProto := actualRPCType.ALPNString()
|
||||||
|
if nextProto == "" {
|
||||||
|
return nil, nil, fmt.Errorf("rpc type %d cannot be routed through a mesh gateway", actualRPCType)
|
||||||
|
}
|
||||||
|
return DialRPCViaMeshGateway(
|
||||||
|
context.Background(),
|
||||||
dc,
|
dc,
|
||||||
nodeName,
|
nodeName,
|
||||||
addr,
|
|
||||||
p.SrcAddr,
|
p.SrcAddr,
|
||||||
p.TLSConfigurator.OutgoingALPNRPCWrapper(),
|
p.TLSConfigurator.OutgoingALPNRPCWrapper(),
|
||||||
actualRPCType,
|
nextProto,
|
||||||
RPCTLS,
|
|
||||||
// gateway stuff
|
|
||||||
p.Server,
|
p.Server,
|
||||||
p.TLSConfigurator,
|
|
||||||
p.GatewayResolver,
|
p.GatewayResolver,
|
||||||
p.Datacenter,
|
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -319,7 +324,7 @@ func (p *ConnPool) dial(
|
||||||
tlsRPCType RPCType,
|
tlsRPCType RPCType,
|
||||||
) (net.Conn, HalfCloser, error) {
|
) (net.Conn, HalfCloser, error) {
|
||||||
// Try to dial the conn
|
// Try to dial the conn
|
||||||
d := &net.Dialer{LocalAddr: p.SrcAddr, Timeout: defaultDialTimeout}
|
d := &net.Dialer{LocalAddr: p.SrcAddr, Timeout: DefaultDialTimeout}
|
||||||
conn, err := d.Dial("tcp", addr.String())
|
conn, err := d.Dial("tcp", addr.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
|
@ -372,62 +377,49 @@ func (p *ConnPool) dial(
|
||||||
return conn, hc, nil
|
return conn, hc, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DialTimeoutWithRPCTypeViaMeshGateway dials the destination node and sets up
|
// DialRPCViaMeshGateway dials the destination node and sets up the connection
|
||||||
// the connection to be the correct RPC type using ALPN. This currently is
|
// to be the correct RPC type using ALPN. This currently is exclusively used to
|
||||||
// exclusively used to dial other servers in foreign datacenters via mesh
|
// dial other servers in foreign datacenters via mesh gateways.
|
||||||
// gateways.
|
func DialRPCViaMeshGateway(
|
||||||
//
|
ctx context.Context,
|
||||||
// NOTE: There is a close mirror of this method in agent/consul/wanfed/wanfed.go:dial
|
dc string, // (metadata.Server).Datacenter
|
||||||
func DialTimeoutWithRPCTypeViaMeshGateway(
|
nodeName string, // (metadata.Server).ShortName
|
||||||
dc string,
|
srcAddr *net.TCPAddr,
|
||||||
nodeName string,
|
alpnWrapper tlsutil.ALPNWrapper,
|
||||||
addr net.Addr,
|
nextProto string,
|
||||||
src *net.TCPAddr,
|
|
||||||
wrapper tlsutil.ALPNWrapper,
|
|
||||||
actualRPCType RPCType,
|
|
||||||
tlsRPCType RPCType,
|
|
||||||
// gateway stuff
|
|
||||||
dialingFromServer bool,
|
dialingFromServer bool,
|
||||||
tlsConfigurator *tlsutil.Configurator,
|
|
||||||
gatewayResolver func(string) string,
|
gatewayResolver func(string) string,
|
||||||
thisDatacenter string,
|
|
||||||
) (net.Conn, HalfCloser, error) {
|
) (net.Conn, HalfCloser, error) {
|
||||||
if !dialingFromServer {
|
if !dialingFromServer {
|
||||||
return nil, nil, fmt.Errorf("must dial via mesh gateways from a server agent")
|
return nil, nil, fmt.Errorf("must dial via mesh gateways from a server agent")
|
||||||
} else if gatewayResolver == nil {
|
} else if gatewayResolver == nil {
|
||||||
return nil, nil, fmt.Errorf("gatewayResolver is nil")
|
return nil, nil, fmt.Errorf("gatewayResolver is nil")
|
||||||
} else if tlsConfigurator == nil {
|
} else if alpnWrapper == nil {
|
||||||
return nil, nil, fmt.Errorf("tlsConfigurator is nil")
|
|
||||||
} else if dc == thisDatacenter {
|
|
||||||
return nil, nil, fmt.Errorf("cannot dial servers in the same datacenter via a mesh gateway")
|
|
||||||
} else if wrapper == nil {
|
|
||||||
return nil, nil, fmt.Errorf("cannot dial via a mesh gateway when outgoing TLS is disabled")
|
return nil, nil, fmt.Errorf("cannot dial via a mesh gateway when outgoing TLS is disabled")
|
||||||
}
|
}
|
||||||
|
|
||||||
nextProto := actualRPCType.ALPNString()
|
|
||||||
if nextProto == "" {
|
|
||||||
return nil, nil, fmt.Errorf("rpc type %d cannot be routed through a mesh gateway", actualRPCType)
|
|
||||||
}
|
|
||||||
|
|
||||||
gwAddr := gatewayResolver(dc)
|
gwAddr := gatewayResolver(dc)
|
||||||
if gwAddr == "" {
|
if gwAddr == "" {
|
||||||
return nil, nil, structs.ErrDCNotAvailable
|
return nil, nil, structs.ErrDCNotAvailable
|
||||||
}
|
}
|
||||||
|
|
||||||
dialer := &net.Dialer{LocalAddr: src, Timeout: defaultDialTimeout}
|
dialer := &net.Dialer{LocalAddr: srcAddr, Timeout: DefaultDialTimeout}
|
||||||
|
|
||||||
rawConn, err := dialer.Dial("tcp", gwAddr)
|
rawConn, err := dialer.DialContext(ctx, "tcp", gwAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if nextProto != ALPN_RPCGRPC {
|
||||||
|
// agent/grpc/client.go:dial() handles this in another way for gRPC
|
||||||
if tcp, ok := rawConn.(*net.TCPConn); ok {
|
if tcp, ok := rawConn.(*net.TCPConn); ok {
|
||||||
_ = tcp.SetKeepAlive(true)
|
_ = tcp.SetKeepAlive(true)
|
||||||
_ = tcp.SetNoDelay(true)
|
_ = tcp.SetNoDelay(true)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// NOTE: now we wrap the connection in a TLS client.
|
// NOTE: now we wrap the connection in a TLS client.
|
||||||
tlsConn, err := wrapper(dc, nodeName, nextProto, rawConn)
|
tlsConn, err := alpnWrapper(dc, nodeName, nextProto, rawConn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -106,9 +106,22 @@ func NewBaseDeps(configLoader ConfigLoader, logOut io.Writer) (BaseDeps, error)
|
||||||
d.ViewStore = submatview.NewStore(d.Logger.Named("viewstore"))
|
d.ViewStore = submatview.NewStore(d.Logger.Named("viewstore"))
|
||||||
d.ConnPool = newConnPool(cfg, d.Logger, d.TLSConfigurator)
|
d.ConnPool = newConnPool(cfg, d.Logger, d.TLSConfigurator)
|
||||||
|
|
||||||
builder := resolver.NewServerResolverBuilder(resolver.Config{})
|
builder := resolver.NewServerResolverBuilder(resolver.Config{
|
||||||
|
// Set the authority to something sufficiently unique so any usage in
|
||||||
|
// tests would be self-isolating in the global resolver map, while also
|
||||||
|
// not incurring a huge penalty for non-test code.
|
||||||
|
Authority: cfg.Datacenter + "." + string(cfg.NodeID),
|
||||||
|
})
|
||||||
resolver.Register(builder)
|
resolver.Register(builder)
|
||||||
d.GRPCConnPool = grpc.NewClientConnPool(builder, grpc.TLSWrapper(d.TLSConfigurator.OutgoingRPCWrapper()), d.TLSConfigurator.UseTLS)
|
d.GRPCConnPool = grpc.NewClientConnPool(grpc.ClientConnPoolConfig{
|
||||||
|
Servers: builder,
|
||||||
|
SrcAddr: d.ConnPool.SrcAddr,
|
||||||
|
TLSWrapper: grpc.TLSWrapper(d.TLSConfigurator.OutgoingRPCWrapper()),
|
||||||
|
ALPNWrapper: grpc.ALPNWrapper(d.TLSConfigurator.OutgoingALPNRPCWrapper()),
|
||||||
|
UseTLSForDC: d.TLSConfigurator.UseTLS,
|
||||||
|
DialingFromServer: cfg.ServerMode,
|
||||||
|
DialingFromDatacenter: cfg.Datacenter,
|
||||||
|
})
|
||||||
d.LeaderForwarder = builder
|
d.LeaderForwarder = builder
|
||||||
|
|
||||||
d.Router = router.NewRouter(d.Logger, cfg.Datacenter, fmt.Sprintf("%s.%s", cfg.NodeName, cfg.Datacenter), builder)
|
d.Router = router.NewRouter(d.Logger, cfg.Datacenter, fmt.Sprintf("%s.%s", cfg.NodeName, cfg.Datacenter), builder)
|
||||||
|
|
Loading…
Reference in New Issue