grpc: ensure grpc resolver correctly uses lan/wan addresses on servers (#17270)

The grpc resolver implementation is fed from changes to the
router.Router. Within the router there is a map of various areas storing
the addressing information for servers in those areas. All map entries
are of the WAN variety except a single special entry for the LAN.

Addressing information in the LAN "area" are local addresses intended
for use when making a client-to-server or server-to-server request.

The client agent correctly updates this LAN area when receiving lan serf
events, so by extension the grpc resolver works fine in that scenario.

The server agent only initially populates a single entry in the LAN area
(for itself) on startup, and then never mutates that area map again.
For normal RPCs a different structure is used for LAN routing.

Additionally when selecting a server to contact in the local datacenter
it will randomly select addresses from either the LAN or WAN addressed
entries in the map.

Unfortunately this means that the grpc resolver stack as it exists on
server agents is either broken or only accidentally functions by having
servers dial each other over the WAN-accessible address. If the operator
disables the serf wan port completely likely this incidental functioning
would break.

This PR enforces that local requests for servers (both for stale reads
or leader forwarded requests) exclusively use the LAN "area" information
and also fixes it so that servers keep that area up to date in the
router.

A test for the grpc resolver logic was added, as well as a higher level
full-stack test to ensure the externally perceived bug does not return.
This commit is contained in:
R.B. Boyer 2023-05-11 11:08:57 -05:00 committed by GitHub
parent cb16046672
commit 0b79707beb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 552 additions and 31 deletions

3
.changelog/17270.txt Normal file
View File

@ -0,0 +1,3 @@
```release-note:bug
grpc: ensure grpc resolver correctly uses lan/wan addresses on servers
```

View File

@ -504,11 +504,15 @@ func newClient(t *testing.T, config *Config) *Client {
return client return client
} }
func newTestResolverConfig(t *testing.T, suffix string) resolver.Config { func newTestResolverConfig(t *testing.T, suffix string, dc, agentType string) resolver.Config {
n := t.Name() n := t.Name()
s := strings.Replace(n, "/", "", -1) s := strings.Replace(n, "/", "", -1)
s = strings.Replace(s, "_", "", -1) s = strings.Replace(s, "_", "", -1)
return resolver.Config{Authority: strings.ToLower(s) + "-" + suffix} return resolver.Config{
Datacenter: dc,
AgentType: agentType,
Authority: strings.ToLower(s) + "-" + suffix,
}
} }
func newDefaultDeps(t *testing.T, c *Config) Deps { func newDefaultDeps(t *testing.T, c *Config) Deps {
@ -523,7 +527,7 @@ func newDefaultDeps(t *testing.T, c *Config) Deps {
tls, err := tlsutil.NewConfigurator(c.TLSConfig, logger) tls, err := tlsutil.NewConfigurator(c.TLSConfig, logger)
require.NoError(t, err, "failed to create tls configuration") require.NoError(t, err, "failed to create tls configuration")
resolverBuilder := resolver.NewServerResolverBuilder(newTestResolverConfig(t, c.NodeName+"-"+c.Datacenter)) resolverBuilder := resolver.NewServerResolverBuilder(newTestResolverConfig(t, c.NodeName+"-"+c.Datacenter, c.Datacenter, "server"))
resolver.Register(resolverBuilder) resolver.Register(resolverBuilder)
t.Cleanup(func() { t.Cleanup(func() {
resolver.Deregister(resolverBuilder.Authority()) resolver.Deregister(resolverBuilder.Authority())

View File

@ -23,6 +23,7 @@ import (
"github.com/hashicorp/consul/lib" "github.com/hashicorp/consul/lib"
libserf "github.com/hashicorp/consul/lib/serf" libserf "github.com/hashicorp/consul/lib/serf"
"github.com/hashicorp/consul/logging" "github.com/hashicorp/consul/logging"
"github.com/hashicorp/consul/types"
) )
const ( const (
@ -359,6 +360,7 @@ func (s *Server) lanNodeJoin(me serf.MemberEvent) {
// Update server lookup // Update server lookup
s.serverLookup.AddServer(serverMeta) s.serverLookup.AddServer(serverMeta)
s.router.AddServer(types.AreaLAN, serverMeta)
// If we're still expecting to bootstrap, may need to handle this. // If we're still expecting to bootstrap, may need to handle this.
if s.config.BootstrapExpect != 0 { if s.config.BootstrapExpect != 0 {
@ -380,6 +382,7 @@ func (s *Server) lanNodeUpdate(me serf.MemberEvent) {
// Update server lookup // Update server lookup
s.serverLookup.AddServer(serverMeta) s.serverLookup.AddServer(serverMeta)
s.router.AddServer(types.AreaLAN, serverMeta)
} }
} }
@ -518,5 +521,6 @@ func (s *Server) lanNodeFailed(me serf.MemberEvent) {
// Update id to address map // Update id to address map
s.serverLookup.RemoveServer(serverMeta) s.serverLookup.RemoveServer(serverMeta)
s.router.RemoveServer(types.AreaLAN, serverMeta)
} }
} }

View File

@ -382,7 +382,10 @@ func newClientWithGRPCPlumbing(t *testing.T, ops ...func(*Config)) (*Client, *re
} }
resolverBuilder := resolver.NewServerResolverBuilder(newTestResolverConfig(t, resolverBuilder := resolver.NewServerResolverBuilder(newTestResolverConfig(t,
"client."+config.Datacenter+"."+string(config.NodeID))) "client."+config.Datacenter+"."+string(config.NodeID),
config.Datacenter,
"client",
))
resolver.Register(resolverBuilder) resolver.Register(resolverBuilder)
t.Cleanup(func() { t.Cleanup(func() {

View File

@ -38,8 +38,8 @@ func TestNewDialer_WithTLSWrapper(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
t.Cleanup(logError(t, lis.Close)) t.Cleanup(logError(t, lis.Close))
builder := resolver.NewServerResolverBuilder(newConfig(t)) builder := resolver.NewServerResolverBuilder(newConfig(t, "dc1", "server"))
builder.AddServer(types.AreaWAN, &metadata.Server{ builder.AddServer(types.AreaLAN, &metadata.Server{
Name: "server-1", Name: "server-1",
ID: "ID1", ID: "ID1",
Datacenter: "dc1", Datacenter: "dc1",
@ -89,7 +89,7 @@ func TestNewDialer_WithALPNWrapper(t *testing.T) {
p.Wait() p.Wait()
}() }()
builder := resolver.NewServerResolverBuilder(newConfig(t)) builder := resolver.NewServerResolverBuilder(newConfig(t, "dc1", "server"))
builder.AddServer(types.AreaWAN, &metadata.Server{ builder.AddServer(types.AreaWAN, &metadata.Server{
Name: "server-1", Name: "server-1",
ID: "ID1", ID: "ID1",
@ -144,7 +144,7 @@ func TestNewDialer_WithALPNWrapper(t *testing.T) {
func TestNewDialer_IntegrationWithTLSEnabledHandler(t *testing.T) { func TestNewDialer_IntegrationWithTLSEnabledHandler(t *testing.T) {
// if this test is failing because of expired certificates // if this test is failing because of expired certificates
// use the procedure in test/CA-GENERATION.md // use the procedure in test/CA-GENERATION.md
res := resolver.NewServerResolverBuilder(newConfig(t)) res := resolver.NewServerResolverBuilder(newConfig(t, "dc1", "server"))
bb := balancer.NewBuilder(res.Authority(), testutil.Logger(t)) bb := balancer.NewBuilder(res.Authority(), testutil.Logger(t))
registerWithGRPC(t, res, bb) registerWithGRPC(t, res, bb)
@ -162,9 +162,17 @@ func TestNewDialer_IntegrationWithTLSEnabledHandler(t *testing.T) {
srv := newSimpleTestServer(t, "server-1", "dc1", tlsConf) srv := newSimpleTestServer(t, "server-1", "dc1", tlsConf)
md := srv.Metadata() md := srv.Metadata()
res.AddServer(types.AreaWAN, md) res.AddServer(types.AreaLAN, md)
t.Cleanup(srv.shutdown) t.Cleanup(srv.shutdown)
{
// Put a duplicate instance of this on the WAN that will
// fail if we accidentally use it.
srv := newPanicTestServer(t, hclog.Default(), "server-1", "dc1", nil)
res.AddServer(types.AreaWAN, srv.Metadata())
t.Cleanup(srv.shutdown)
}
pool := NewClientConnPool(ClientConnPoolConfig{ pool := NewClientConnPool(ClientConnPoolConfig{
Servers: res, Servers: res,
TLSWrapper: TLSWrapper(tlsConf.OutgoingRPCWrapper()), TLSWrapper: TLSWrapper(tlsConf.OutgoingRPCWrapper()),
@ -192,7 +200,7 @@ func TestNewDialer_IntegrationWithTLSEnabledHandler_viaMeshGateway(t *testing.T)
// use the procedure in test/CA-GENERATION.md // use the procedure in test/CA-GENERATION.md
gwAddr := ipaddr.FormatAddressPort("127.0.0.1", freeport.GetOne(t)) gwAddr := ipaddr.FormatAddressPort("127.0.0.1", freeport.GetOne(t))
res := resolver.NewServerResolverBuilder(newConfig(t)) res := resolver.NewServerResolverBuilder(newConfig(t, "dc2", "server"))
bb := balancer.NewBuilder(res.Authority(), testutil.Logger(t)) bb := balancer.NewBuilder(res.Authority(), testutil.Logger(t))
registerWithGRPC(t, res, bb) registerWithGRPC(t, res, bb)
@ -268,7 +276,7 @@ func TestNewDialer_IntegrationWithTLSEnabledHandler_viaMeshGateway(t *testing.T)
func TestClientConnPool_IntegrationWithGRPCResolver_Failover(t *testing.T) { func TestClientConnPool_IntegrationWithGRPCResolver_Failover(t *testing.T) {
count := 4 count := 4
res := resolver.NewServerResolverBuilder(newConfig(t)) res := resolver.NewServerResolverBuilder(newConfig(t, "dc1", "server"))
bb := balancer.NewBuilder(res.Authority(), testutil.Logger(t)) bb := balancer.NewBuilder(res.Authority(), testutil.Logger(t))
registerWithGRPC(t, res, bb) registerWithGRPC(t, res, bb)
pool := NewClientConnPool(ClientConnPoolConfig{ pool := NewClientConnPool(ClientConnPoolConfig{
@ -280,9 +288,18 @@ func TestClientConnPool_IntegrationWithGRPCResolver_Failover(t *testing.T) {
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
name := fmt.Sprintf("server-%d", i) name := fmt.Sprintf("server-%d", i)
srv := newSimpleTestServer(t, name, "dc1", nil) {
res.AddServer(types.AreaWAN, srv.Metadata()) srv := newSimpleTestServer(t, name, "dc1", nil)
t.Cleanup(srv.shutdown) res.AddServer(types.AreaLAN, srv.Metadata())
t.Cleanup(srv.shutdown)
}
{
// Put a duplicate instance of this on the WAN that will
// fail if we accidentally use it.
srv := newPanicTestServer(t, hclog.Default(), name, "dc1", nil)
res.AddServer(types.AreaWAN, srv.Metadata())
t.Cleanup(srv.shutdown)
}
} }
conn, err := pool.ClientConn("dc1") conn, err := pool.ClientConn("dc1")
@ -295,7 +312,7 @@ func TestClientConnPool_IntegrationWithGRPCResolver_Failover(t *testing.T) {
first, err := client.Something(ctx, &testservice.Req{}) first, err := client.Something(ctx, &testservice.Req{})
require.NoError(t, err) require.NoError(t, err)
res.RemoveServer(types.AreaWAN, &metadata.Server{ID: first.ServerName, Datacenter: "dc1"}) res.RemoveServer(types.AreaLAN, &metadata.Server{ID: first.ServerName, Datacenter: "dc1"})
resp, err := client.Something(ctx, &testservice.Req{}) resp, err := client.Something(ctx, &testservice.Req{})
require.NoError(t, err) require.NoError(t, err)
@ -304,7 +321,7 @@ func TestClientConnPool_IntegrationWithGRPCResolver_Failover(t *testing.T) {
func TestClientConnPool_ForwardToLeader_Failover(t *testing.T) { func TestClientConnPool_ForwardToLeader_Failover(t *testing.T) {
count := 3 count := 3
res := resolver.NewServerResolverBuilder(newConfig(t)) res := resolver.NewServerResolverBuilder(newConfig(t, "dc1", "server"))
bb := balancer.NewBuilder(res.Authority(), testutil.Logger(t)) bb := balancer.NewBuilder(res.Authority(), testutil.Logger(t))
registerWithGRPC(t, res, bb) registerWithGRPC(t, res, bb)
pool := NewClientConnPool(ClientConnPoolConfig{ pool := NewClientConnPool(ClientConnPoolConfig{
@ -317,10 +334,19 @@ func TestClientConnPool_ForwardToLeader_Failover(t *testing.T) {
var servers []testServer var servers []testServer
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
name := fmt.Sprintf("server-%d", i) name := fmt.Sprintf("server-%d", i)
srv := newSimpleTestServer(t, name, "dc1", nil) {
res.AddServer(types.AreaWAN, srv.Metadata()) srv := newSimpleTestServer(t, name, "dc1", nil)
servers = append(servers, srv) res.AddServer(types.AreaLAN, srv.Metadata())
t.Cleanup(srv.shutdown) servers = append(servers, srv)
t.Cleanup(srv.shutdown)
}
{
// Put a duplicate instance of this on the WAN that will
// fail if we accidentally use it.
srv := newPanicTestServer(t, hclog.Default(), name, "dc1", nil)
res.AddServer(types.AreaWAN, srv.Metadata())
t.Cleanup(srv.shutdown)
}
} }
// Set the leader address to the first server. // Set the leader address to the first server.
@ -347,19 +373,24 @@ func TestClientConnPool_ForwardToLeader_Failover(t *testing.T) {
require.Equal(t, resp.ServerName, servers[1].name) require.Equal(t, resp.ServerName, servers[1].name)
} }
func newConfig(t *testing.T) resolver.Config { func newConfig(t *testing.T, dc, agentType string) resolver.Config {
n := t.Name() n := t.Name()
s := strings.Replace(n, "/", "", -1) s := strings.Replace(n, "/", "", -1)
s = strings.Replace(s, "_", "", -1) s = strings.Replace(s, "_", "", -1)
return resolver.Config{Authority: strings.ToLower(s)} return resolver.Config{
Datacenter: dc,
AgentType: agentType,
Authority: strings.ToLower(s),
}
} }
func TestClientConnPool_IntegrationWithGRPCResolver_MultiDC(t *testing.T) { func TestClientConnPool_IntegrationWithGRPCResolver_MultiDC(t *testing.T) {
dcs := []string{"dc1", "dc2", "dc3"} dcs := []string{"dc1", "dc2", "dc3"}
res := resolver.NewServerResolverBuilder(newConfig(t)) res := resolver.NewServerResolverBuilder(newConfig(t, "dc1", "server"))
bb := balancer.NewBuilder(res.Authority(), testutil.Logger(t)) bb := balancer.NewBuilder(res.Authority(), testutil.Logger(t))
registerWithGRPC(t, res, bb) registerWithGRPC(t, res, bb)
pool := NewClientConnPool(ClientConnPoolConfig{ pool := NewClientConnPool(ClientConnPoolConfig{
Servers: res, Servers: res,
UseTLSForDC: useTLSForDcAlwaysTrue, UseTLSForDC: useTLSForDcAlwaysTrue,
@ -370,7 +401,16 @@ func TestClientConnPool_IntegrationWithGRPCResolver_MultiDC(t *testing.T) {
for _, dc := range dcs { for _, dc := range dcs {
name := "server-0-" + dc name := "server-0-" + dc
srv := newSimpleTestServer(t, name, dc, nil) srv := newSimpleTestServer(t, name, dc, nil)
res.AddServer(types.AreaWAN, srv.Metadata()) if dc == "dc1" {
res.AddServer(types.AreaLAN, srv.Metadata())
// Put a duplicate instance of this on the WAN that will
// fail if we accidentally use it.
srvBad := newPanicTestServer(t, hclog.Default(), name, dc, nil)
res.AddServer(types.AreaWAN, srvBad.Metadata())
t.Cleanup(srvBad.shutdown)
} else {
res.AddServer(types.AreaWAN, srv.Metadata())
}
t.Cleanup(srv.shutdown) t.Cleanup(srv.shutdown)
} }

View File

@ -31,12 +31,12 @@ func TestHandler_PanicRecoveryInterceptor(t *testing.T) {
Output: &buf, Output: &buf,
}) })
res := resolver.NewServerResolverBuilder(newConfig(t)) res := resolver.NewServerResolverBuilder(newConfig(t, "dc1", "server"))
bb := balancer.NewBuilder(res.Authority(), testutil.Logger(t)) bb := balancer.NewBuilder(res.Authority(), testutil.Logger(t))
registerWithGRPC(t, res, bb) registerWithGRPC(t, res, bb)
srv := newPanicTestServer(t, logger, "server-1", "dc1", nil) srv := newPanicTestServer(t, logger, "server-1", "dc1", nil)
res.AddServer(types.AreaWAN, srv.Metadata()) res.AddServer(types.AreaLAN, srv.Metadata())
t.Cleanup(srv.shutdown) t.Cleanup(srv.shutdown)
pool := NewClientConnPool(ClientConnPoolConfig{ pool := NewClientConnPool(ClientConnPoolConfig{

View File

@ -18,25 +18,45 @@ import (
// ServerResolvers updated when changes occur. // ServerResolvers updated when changes occur.
type ServerResolverBuilder struct { type ServerResolverBuilder struct {
cfg Config cfg Config
// leaderResolver is used to track the address of the leader in the local DC. // leaderResolver is used to track the address of the leader in the local DC.
leaderResolver leaderResolver leaderResolver leaderResolver
// servers is an index of Servers by area and Server.ID. The map contains server IDs // servers is an index of Servers by area and Server.ID. The map contains server IDs
// for all datacenters. // for all datacenters.
servers map[types.AreaID]map[string]*metadata.Server servers map[types.AreaID]map[string]*metadata.Server
// resolvers is an index of connections to the serverResolver which manages // resolvers is an index of connections to the serverResolver which manages
// addresses of servers for that connection. // addresses of servers for that connection.
//
// this is only applicable for non-leader conn types
resolvers map[resolver.ClientConn]*serverResolver resolvers map[resolver.ClientConn]*serverResolver
// lock for all stateful fields (excludes config which is immutable). // lock for all stateful fields (excludes config which is immutable).
lock sync.RWMutex lock sync.RWMutex
} }
type Config struct { type Config struct {
// Datacenter is the datacenter of this agent.
Datacenter string
// AgentType is either 'server' or 'client' and is required.
AgentType string
// Authority used to query the server. Defaults to "". Used to support // Authority used to query the server. Defaults to "". Used to support
// parallel testing because gRPC registers resolvers globally. // parallel testing because gRPC registers resolvers globally.
Authority string Authority string
} }
func NewServerResolverBuilder(cfg Config) *ServerResolverBuilder { func NewServerResolverBuilder(cfg Config) *ServerResolverBuilder {
if cfg.Datacenter == "" {
panic("ServerResolverBuilder needs Config.Datacenter to be nonempty")
}
switch cfg.AgentType {
case "server", "client":
default:
panic("ServerResolverBuilder needs Config.AgentType to be either server or client")
}
return &ServerResolverBuilder{ return &ServerResolverBuilder{
cfg: cfg, cfg: cfg,
servers: make(map[types.AreaID]map[string]*metadata.Server), servers: make(map[types.AreaID]map[string]*metadata.Server),
@ -56,6 +76,7 @@ func (s *ServerResolverBuilder) ServerForGlobalAddr(globalAddr string) (*metadat
} }
} }
} }
return nil, fmt.Errorf("failed to find Consul server for global address %q", globalAddr) return nil, fmt.Errorf("failed to find Consul server for global address %q", globalAddr)
} }
@ -67,12 +88,12 @@ func (s *ServerResolverBuilder) Build(target resolver.Target, cc resolver.Client
// If there's already a resolver for this connection, return it. // If there's already a resolver for this connection, return it.
// TODO(streaming): how would this happen since we already cache connections in ClientConnPool? // TODO(streaming): how would this happen since we already cache connections in ClientConnPool?
if resolver, ok := s.resolvers[cc]; ok {
return resolver, nil
}
if cc == s.leaderResolver.clientConn { if cc == s.leaderResolver.clientConn {
return s.leaderResolver, nil return s.leaderResolver, nil
} }
if resolver, ok := s.resolvers[cc]; ok {
return resolver, nil
}
//nolint:staticcheck //nolint:staticcheck
serverType, datacenter, err := parseEndpoint(target.Endpoint) serverType, datacenter, err := parseEndpoint(target.Endpoint)
@ -119,6 +140,10 @@ func (s *ServerResolverBuilder) Authority() string {
// AddServer updates the resolvers' states to include the new server's address. // AddServer updates the resolvers' states to include the new server's address.
func (s *ServerResolverBuilder) AddServer(areaID types.AreaID, server *metadata.Server) { func (s *ServerResolverBuilder) AddServer(areaID types.AreaID, server *metadata.Server) {
if s.shouldIgnoreServer(areaID, server) {
return
}
s.lock.Lock() s.lock.Lock()
defer s.lock.Unlock() defer s.lock.Unlock()
@ -155,6 +180,10 @@ func DCPrefix(datacenter, suffix string) string {
// RemoveServer updates the resolvers' states with the given server removed. // RemoveServer updates the resolvers' states with the given server removed.
func (s *ServerResolverBuilder) RemoveServer(areaID types.AreaID, server *metadata.Server) { func (s *ServerResolverBuilder) RemoveServer(areaID types.AreaID, server *metadata.Server) {
if s.shouldIgnoreServer(areaID, server) {
return
}
s.lock.Lock() s.lock.Lock()
defer s.lock.Unlock() defer s.lock.Unlock()
@ -176,14 +205,48 @@ func (s *ServerResolverBuilder) RemoveServer(areaID types.AreaID, server *metada
} }
} }
// shouldIgnoreServer is used to contextually decide if a particular kind of
// server should be accepted into a given area.
//
// On client agents it's pretty easy: clients only participate in the standard
// LAN, so we only accept servers from the LAN.
//
// On server agents it's a little less obvious. This resolver is ultimately
// used to have servers dial other servers. If a server is going to cross
// between datacenters (using traditional federation) then we want to use the
// WAN addresses for them, but if a server is going to dial a sibling server in
// the same datacenter we want it to use the LAN addresses always. To achieve
// that here we simply never allow WAN servers for our current datacenter to be
// added into the resolver, letting only the LAN instances through.
func (s *ServerResolverBuilder) shouldIgnoreServer(areaID types.AreaID, server *metadata.Server) bool {
if s.cfg.AgentType == "client" && areaID != types.AreaLAN {
return true
}
if s.cfg.AgentType == "server" &&
server.Datacenter == s.cfg.Datacenter &&
areaID != types.AreaLAN {
return true
}
return false
}
// getDCAddrs returns a list of the server addresses for the given datacenter. // getDCAddrs returns a list of the server addresses for the given datacenter.
// This method requires that lock is held for reads. // This method requires that lock is held for reads.
func (s *ServerResolverBuilder) getDCAddrs(dc string) []resolver.Address { func (s *ServerResolverBuilder) getDCAddrs(dc string) []resolver.Address {
lanRequest := (s.cfg.Datacenter == dc)
var ( var (
addrs []resolver.Address addrs []resolver.Address
keptServerIDs = make(map[string]struct{}) keptServerIDs = make(map[string]struct{})
) )
for _, areaServers := range s.servers { for areaID, areaServers := range s.servers {
if (areaID == types.AreaLAN) != lanRequest {
// LAN requests only look at LAN data. WAN requests only look at
// WAN data.
continue
}
for _, server := range areaServers { for _, server := range areaServers {
if server.Datacenter != dc { if server.Datacenter != dc {
continue continue

View File

@ -0,0 +1,195 @@
package resolver
import (
"fmt"
"net"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc/resolver"
"google.golang.org/grpc/serviceconfig"
"github.com/hashicorp/consul/agent/metadata"
"github.com/hashicorp/consul/types"
)
func TestServerResolverBuilder(t *testing.T) {
const agentDC = "dc1"
type testcase struct {
name string
agentType string // server/client
serverType string // server/leader
requestDC string
expectLAN bool
}
run := func(t *testing.T, tc testcase) {
rs := NewServerResolverBuilder(newConfig(t, agentDC, tc.agentType))
endpoint := ""
if tc.serverType == "leader" {
endpoint = "leader.local"
} else {
endpoint = tc.serverType + "." + tc.requestDC
}
cc := &fakeClientConn{}
_, err := rs.Build(resolver.Target{
Scheme: "consul",
Authority: rs.Authority(),
Endpoint: endpoint,
}, cc, resolver.BuildOptions{})
require.NoError(t, err)
for i := 0; i < 3; i++ {
dc := fmt.Sprintf("dc%d", i+1)
for j := 0; j < 3; j++ {
wanIP := fmt.Sprintf("127.1.%d.%d", i+1, j+10)
name := fmt.Sprintf("%s-server-%d", dc, j+1)
wanMeta := newServerMeta(name, dc, wanIP, true)
if tc.agentType == "server" {
rs.AddServer(types.AreaWAN, wanMeta)
}
if dc == agentDC {
// register LAN/WAN pairs for the same instances
lanIP := fmt.Sprintf("127.0.%d.%d", i+1, j+10)
lanMeta := newServerMeta(name, dc, lanIP, false)
rs.AddServer(types.AreaLAN, lanMeta)
if j == 0 {
rs.UpdateLeaderAddr(dc, lanIP)
}
}
}
}
if tc.serverType == "leader" {
assert.Len(t, cc.state.Addresses, 1)
} else {
assert.Len(t, cc.state.Addresses, 3)
}
for _, addr := range cc.state.Addresses {
addrPrefix := tc.requestDC + "-"
if tc.expectLAN {
addrPrefix += "127.0."
} else {
addrPrefix += "127.1."
}
assert.True(t, strings.HasPrefix(addr.Addr, addrPrefix),
"%q does not start with %q (returned WAN for LAN request)", addr.Addr, addrPrefix)
if tc.expectLAN {
assert.False(t, strings.Contains(addr.ServerName, ".dc"),
"%q ends with datacenter suffix (returned WAN for LAN request)", addr.ServerName)
} else {
assert.True(t, strings.HasSuffix(addr.ServerName, "."+tc.requestDC),
"%q does not end with %q", addr.ServerName, "."+tc.requestDC)
}
}
}
cases := []testcase{
{
name: "server requesting local servers",
agentType: "server",
serverType: "server",
requestDC: agentDC,
expectLAN: true,
},
{
name: "server requesting remote servers in dc2",
agentType: "server",
serverType: "server",
requestDC: "dc2",
expectLAN: false,
},
{
name: "server requesting remote servers in dc3",
agentType: "server",
serverType: "server",
requestDC: "dc3",
expectLAN: false,
},
// ---------------
{
name: "server requesting local leader",
agentType: "server",
serverType: "leader",
requestDC: agentDC,
expectLAN: true,
},
// ---------------
{
name: "client requesting local server",
agentType: "client",
serverType: "server",
requestDC: agentDC,
expectLAN: true,
},
{
name: "client requesting local leader",
agentType: "client",
serverType: "leader",
requestDC: agentDC,
expectLAN: true,
},
}
for _, tc := range cases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
run(t, tc)
})
}
}
func newServerMeta(name, dc, ip string, wan bool) *metadata.Server {
fullname := name
if wan {
fullname = name + "." + dc
}
return &metadata.Server{
ID: name,
Name: fullname,
ShortName: name,
Datacenter: dc,
Addr: &net.IPAddr{IP: net.ParseIP(ip)},
UseTLS: false,
}
}
func newConfig(t *testing.T, dc, agentType string) Config {
n := t.Name()
s := strings.Replace(n, "/", "", -1)
s = strings.Replace(s, "_", "", -1)
return Config{
Datacenter: dc,
AgentType: agentType,
Authority: strings.ToLower(s),
}
}
// fakeClientConn implements resolver.ClientConn for tests
type fakeClientConn struct {
state resolver.State
}
var _ resolver.ClientConn = (*fakeClientConn)(nil)
func (f *fakeClientConn) UpdateState(state resolver.State) error {
f.state = state
return nil
}
func (*fakeClientConn) ReportError(error) {}
func (*fakeClientConn) NewAddress(addresses []resolver.Address) {}
func (*fakeClientConn) NewServiceConfig(serviceConfig string) {}
func (*fakeClientConn) ParseServiceConfig(serviceConfigJSON string) *serviceconfig.ParseResult {
return nil
}

View File

@ -4,6 +4,7 @@
package agent package agent
import ( import (
"bufio"
"bytes" "bytes"
"context" "context"
"encoding/base64" "encoding/base64"
@ -12,19 +13,208 @@ import (
"io" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"strconv"
"testing" "testing"
"time" "time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"google.golang.org/grpc"
gpeer "google.golang.org/grpc/peer"
"github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/api" "github.com/hashicorp/consul/api"
"github.com/hashicorp/consul/proto/private/pbpeering" "github.com/hashicorp/consul/proto/private/pbpeering"
"github.com/hashicorp/consul/sdk/testutil"
"github.com/hashicorp/consul/sdk/testutil/retry" "github.com/hashicorp/consul/sdk/testutil/retry"
"github.com/hashicorp/consul/testrpc" "github.com/hashicorp/consul/testrpc"
) )
func TestHTTP_Peering_Integration(t *testing.T) {
// This is a full-stack integration test of the gRPC (internal) stack. We
// use peering CRUD b/c that is one of the few endpoints exposed over gRPC
// (internal).
if testing.Short() {
t.Skip("too slow for testing.Short")
}
// We advertise a wan address we are not using, so that incidental attempts
// to use it will loudly fail.
const ip = "192.0.2.2"
connectivityConfig := `
ports { serf_wan = -1 }
bind_addr = "0.0.0.0"
client_addr = "0.0.0.0"
advertise_addr = "127.0.0.1"
advertise_addr_wan = "` + ip + `" `
var (
buf1, buf2, buf3 bytes.Buffer
testLog = testutil.NewLogBuffer(t)
log1 = io.MultiWriter(testLog, &buf1)
log2 = io.MultiWriter(testLog, &buf2)
log3 = io.MultiWriter(testLog, &buf3)
)
a1 := StartTestAgent(t, TestAgent{LogOutput: log1, HCL: `
server = true
bootstrap = false
bootstrap_expect = 3
` + connectivityConfig})
t.Cleanup(func() { a1.Shutdown() })
a2 := StartTestAgent(t, TestAgent{LogOutput: log2, HCL: `
server = true
bootstrap = false
bootstrap_expect = 3
` + connectivityConfig})
t.Cleanup(func() { a2.Shutdown() })
a3 := StartTestAgent(t, TestAgent{LogOutput: log3, HCL: `
server = true
bootstrap = false
bootstrap_expect = 3
` + connectivityConfig})
t.Cleanup(func() { a3.Shutdown() })
{ // join a2 to a1
addr := fmt.Sprintf("127.0.0.1:%d", a2.Config.SerfPortLAN)
_, err := a1.JoinLAN([]string{addr}, nil)
require.NoError(t, err)
}
{ // join a3 to a1
addr := fmt.Sprintf("127.0.0.1:%d", a3.Config.SerfPortLAN)
_, err := a1.JoinLAN([]string{addr}, nil)
require.NoError(t, err)
}
testrpc.WaitForLeader(t, a1.RPC, "dc1")
testrpc.WaitForActiveCARoot(t, a1.RPC, "dc1", nil)
testrpc.WaitForTestAgent(t, a1.RPC, "dc1")
testrpc.WaitForTestAgent(t, a2.RPC, "dc1")
testrpc.WaitForTestAgent(t, a3.RPC, "dc1")
retry.Run(t, func(r *retry.R) {
require.Len(r, a1.LANMembersInAgentPartition(), 3)
require.Len(r, a2.LANMembersInAgentPartition(), 3)
require.Len(r, a3.LANMembersInAgentPartition(), 3)
})
type testcase struct {
agent *TestAgent
peerName string
prevCount int
}
checkPeeringList := func(t *testing.T, a *TestAgent, expect int) {
req, err := http.NewRequest("GET", "/v1/peerings", nil)
require.NoError(t, err)
resp := httptest.NewRecorder()
a.srv.h.ServeHTTP(resp, req)
require.Equal(t, http.StatusOK, resp.Code)
var apiResp []*api.Peering
require.NoError(t, json.NewDecoder(resp.Body).Decode(&apiResp))
require.Len(t, apiResp, expect)
}
testConn := func(t *testing.T, conn *grpc.ClientConn, peers map[string]int) {
rpcClientPeering := pbpeering.NewPeeringServiceClient(conn)
peer := &gpeer.Peer{}
_, err := rpcClientPeering.PeeringList(
context.Background(),
&pbpeering.PeeringListRequest{},
grpc.Peer(peer),
)
require.NoError(t, err)
peers[peer.Addr.String()]++
}
var (
standardPeers = make(map[string]int)
leaderPeers = make(map[string]int)
)
runOnce := func(t *testing.T, tc testcase) {
conn, err := tc.agent.baseDeps.GRPCConnPool.ClientConn("dc1")
require.NoError(t, err)
testConn(t, conn, standardPeers)
leaderConn, err := tc.agent.baseDeps.GRPCConnPool.ClientConnLeader()
require.NoError(t, err)
testConn(t, leaderConn, leaderPeers)
checkPeeringList(t, tc.agent, tc.prevCount)
body := &pbpeering.GenerateTokenRequest{
PeerName: tc.peerName,
}
bodyBytes, err := json.Marshal(body)
require.NoError(t, err)
req, err := http.NewRequest("POST", "/v1/peering/token", bytes.NewReader(bodyBytes))
require.NoError(t, err)
resp := httptest.NewRecorder()
tc.agent.srv.h.ServeHTTP(resp, req)
require.Equal(t, http.StatusOK, resp.Code, "expected 200, got %d: %v", resp.Code, resp.Body.String())
var r pbpeering.GenerateTokenResponse
require.NoError(t, json.NewDecoder(resp.Body).Decode(&r))
checkPeeringList(t, tc.agent, tc.prevCount+1)
}
// Try the procedure on all agents to force N-1 of them to leader-forward.
cases := []testcase{
{agent: a1, peerName: "peer-1", prevCount: 0},
{agent: a2, peerName: "peer-2", prevCount: 1},
{agent: a3, peerName: "peer-3", prevCount: 2},
}
for i, tc := range cases {
tc := tc
testutil.RunStep(t, "server-"+strconv.Itoa(i+1), func(t *testing.T) {
runOnce(t, tc)
})
}
testutil.RunStep(t, "ensure we got the right mixture of responses", func(t *testing.T) {
assert.Len(t, standardPeers, 3)
// Each server talks to a single leader.
assert.Len(t, leaderPeers, 1)
for p, n := range leaderPeers {
assert.Equal(t, 3, n, "peer %q expected 3 uses", p)
}
})
testutil.RunStep(t, "no server experienced the server resolution error", func(t *testing.T) {
// Check them all for the bad error
const grpcError = `failed to find Consul server for global address`
var buf bytes.Buffer
buf.ReadFrom(&buf1)
buf.ReadFrom(&buf2)
buf.ReadFrom(&buf3)
scan := bufio.NewScanner(&buf)
for scan.Scan() {
line := scan.Text()
require.NotContains(t, line, grpcError)
}
require.NoError(t, scan.Err())
})
}
func TestHTTP_Peering_GenerateToken(t *testing.T) { func TestHTTP_Peering_GenerateToken(t *testing.T) {
if testing.Short() { if testing.Short() {
t.Skip("too slow for testing.Short") t.Skip("too slow for testing.Short")

View File

@ -11,6 +11,7 @@ import (
"net" "net"
"os" "os"
"path" "path"
"strings"
"testing" "testing"
"time" "time"
@ -1669,6 +1670,17 @@ type testingServer struct {
PublicGRPCAddr string PublicGRPCAddr string
} }
func newConfig(t *testing.T, dc, agentType string) resolver.Config {
n := t.Name()
s := strings.Replace(n, "/", "", -1)
s = strings.Replace(s, "_", "", -1)
return resolver.Config{
Datacenter: dc,
AgentType: agentType,
Authority: strings.ToLower(s),
}
}
// TODO(peering): remove duplication between this and agent/consul tests // TODO(peering): remove duplication between this and agent/consul tests
func newDefaultDeps(t *testing.T, c *consul.Config) consul.Deps { func newDefaultDeps(t *testing.T, c *consul.Config) consul.Deps {
t.Helper() t.Helper()
@ -1683,7 +1695,7 @@ func newDefaultDeps(t *testing.T, c *consul.Config) consul.Deps {
require.NoError(t, err, "failed to create tls configuration") require.NoError(t, err, "failed to create tls configuration")
r := router.NewRouter(logger, c.Datacenter, fmt.Sprintf("%s.%s", c.NodeName, c.Datacenter), nil) r := router.NewRouter(logger, c.Datacenter, fmt.Sprintf("%s.%s", c.NodeName, c.Datacenter), nil)
builder := resolver.NewServerResolverBuilder(resolver.Config{}) builder := resolver.NewServerResolverBuilder(newConfig(t, c.Datacenter, "client"))
resolver.Register(builder) resolver.Register(builder)
connPool := &pool.ConnPool{ connPool := &pool.ConnPool{

View File

@ -120,7 +120,14 @@ func NewBaseDeps(configLoader ConfigLoader, logOut io.Writer, providedLogger hcl
d.ViewStore = submatview.NewStore(d.Logger.Named("viewstore")) d.ViewStore = submatview.NewStore(d.Logger.Named("viewstore"))
d.ConnPool = newConnPool(cfg, d.Logger, d.TLSConfigurator) d.ConnPool = newConnPool(cfg, d.Logger, d.TLSConfigurator)
agentType := "client"
if cfg.ServerMode {
agentType = "server"
}
resolverBuilder := resolver.NewServerResolverBuilder(resolver.Config{ resolverBuilder := resolver.NewServerResolverBuilder(resolver.Config{
AgentType: agentType,
Datacenter: cfg.Datacenter,
// Set the authority to something sufficiently unique so any usage in // Set the authority to something sufficiently unique so any usage in
// tests would be self-isolating in the global resolver map, while also // tests would be self-isolating in the global resolver map, while also
// not incurring a huge penalty for non-test code. // not incurring a huge penalty for non-test code.