grpc: fix data race in balancer registration (#16229)
Registering gRPC balancers is thread-unsafe because they are stored in a global map variable that is accessed without holding a lock. Therefore, it's expected that balancers are registered _once_ at the beginning of your program (e.g. in a package `init` function) and certainly not after you've started dialing connections, etc. > NOTE: this function must only be called during initialization time > (i.e. in an init() function), and is not thread-safe. While this is fine for us in production, it's challenging for tests that spin up multiple agents in-memory. We currently register a balancer per- agent which holds agent-specific state that cannot safely be shared. This commit introduces our own registry that _is_ thread-safe, and implements the Builder interface such that we can call gRPC's `Register` method once, on start-up. It uses the same pattern as our resolver registry where we use the dial target's host (aka "authority"), which is unique per-agent, to determine which builder to use.
This commit is contained in:
parent
86017c93ef
commit
118ffb1e95
|
@ -1599,10 +1599,7 @@ func (a *Agent) ShutdownAgent() error {
|
|||
|
||||
a.stopLicenseManager()
|
||||
|
||||
// this would be cancelled anyways (by the closing of the shutdown ch) but
|
||||
// this should help them to be stopped more quickly
|
||||
a.baseDeps.AutoConfig.Stop()
|
||||
a.baseDeps.MetricsConfig.Cancel()
|
||||
a.baseDeps.Close()
|
||||
|
||||
a.stateLock.Lock()
|
||||
defer a.stateLock.Unlock()
|
||||
|
|
|
@ -522,9 +522,13 @@ func newDefaultDeps(t *testing.T, c *Config) Deps {
|
|||
|
||||
resolverBuilder := resolver.NewServerResolverBuilder(newTestResolverConfig(t, c.NodeName+"-"+c.Datacenter))
|
||||
resolver.Register(resolverBuilder)
|
||||
t.Cleanup(func() {
|
||||
resolver.Deregister(resolverBuilder.Authority())
|
||||
})
|
||||
|
||||
balancerBuilder := balancer.NewBuilder(resolverBuilder.Authority(), testutil.Logger(t))
|
||||
balancerBuilder.Register()
|
||||
t.Cleanup(balancerBuilder.Deregister)
|
||||
|
||||
r := router.NewRouter(
|
||||
logger,
|
||||
|
@ -559,7 +563,6 @@ func newDefaultDeps(t *testing.T, c *Config) Deps {
|
|||
UseTLSForDC: tls.UseTLS,
|
||||
DialingFromServer: true,
|
||||
DialingFromDatacenter: c.Datacenter,
|
||||
BalancerBuilder: balancerBuilder,
|
||||
}),
|
||||
LeaderForwarder: resolverBuilder,
|
||||
NewRequestRecorderFunc: middleware.NewRequestRecorder,
|
||||
|
|
|
@ -1165,7 +1165,7 @@ func TestRPC_LocalTokenStrippedOnForward_GRPC(t *testing.T) {
|
|||
|
||||
var conn *grpc.ClientConn
|
||||
{
|
||||
client, resolverBuilder, balancerBuilder := newClientWithGRPCPlumbing(t, func(c *Config) {
|
||||
client, resolverBuilder := newClientWithGRPCPlumbing(t, func(c *Config) {
|
||||
c.Datacenter = "dc2"
|
||||
c.PrimaryDatacenter = "dc1"
|
||||
c.RPCConfig.EnableStreaming = true
|
||||
|
@ -1177,7 +1177,6 @@ func TestRPC_LocalTokenStrippedOnForward_GRPC(t *testing.T) {
|
|||
Servers: resolverBuilder,
|
||||
DialingFromServer: false,
|
||||
DialingFromDatacenter: "dc2",
|
||||
BalancerBuilder: balancerBuilder,
|
||||
})
|
||||
|
||||
conn, err = pool.ClientConn("dc2")
|
||||
|
|
|
@ -39,7 +39,7 @@ func TestSubscribeBackend_IntegrationWithServer_TLSEnabled(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
defer server.Shutdown()
|
||||
|
||||
client, resolverBuilder, balancerBuilder := newClientWithGRPCPlumbing(t, configureTLS, clientConfigVerifyOutgoing)
|
||||
client, resolverBuilder := newClientWithGRPCPlumbing(t, configureTLS, clientConfigVerifyOutgoing)
|
||||
|
||||
// Try to join
|
||||
testrpc.WaitForLeader(t, server.RPC, "dc1")
|
||||
|
@ -71,7 +71,6 @@ func TestSubscribeBackend_IntegrationWithServer_TLSEnabled(t *testing.T) {
|
|||
UseTLSForDC: client.tlsConfigurator.UseTLS,
|
||||
DialingFromServer: true,
|
||||
DialingFromDatacenter: "dc1",
|
||||
BalancerBuilder: balancerBuilder,
|
||||
})
|
||||
conn, err := pool.ClientConn("dc1")
|
||||
require.NoError(t, err)
|
||||
|
@ -116,7 +115,6 @@ func TestSubscribeBackend_IntegrationWithServer_TLSEnabled(t *testing.T) {
|
|||
UseTLSForDC: client.tlsConfigurator.UseTLS,
|
||||
DialingFromServer: true,
|
||||
DialingFromDatacenter: "dc1",
|
||||
BalancerBuilder: balancerBuilder,
|
||||
})
|
||||
conn, err := pool.ClientConn("dc1")
|
||||
require.NoError(t, err)
|
||||
|
@ -191,7 +189,7 @@ func TestSubscribeBackend_IntegrationWithServer_TLSReload(t *testing.T) {
|
|||
defer server.Shutdown()
|
||||
|
||||
// Set up a client with valid certs and verify_outgoing = true
|
||||
client, resolverBuilder, balancerBuilder := newClientWithGRPCPlumbing(t, configureTLS, clientConfigVerifyOutgoing)
|
||||
client, resolverBuilder := newClientWithGRPCPlumbing(t, configureTLS, clientConfigVerifyOutgoing)
|
||||
|
||||
testrpc.WaitForLeader(t, server.RPC, "dc1")
|
||||
|
||||
|
@ -204,7 +202,6 @@ func TestSubscribeBackend_IntegrationWithServer_TLSReload(t *testing.T) {
|
|||
UseTLSForDC: client.tlsConfigurator.UseTLS,
|
||||
DialingFromServer: true,
|
||||
DialingFromDatacenter: "dc1",
|
||||
BalancerBuilder: balancerBuilder,
|
||||
})
|
||||
conn, err := pool.ClientConn("dc1")
|
||||
require.NoError(t, err)
|
||||
|
@ -284,7 +281,7 @@ func TestSubscribeBackend_IntegrationWithServer_DeliversAllMessages(t *testing.T
|
|||
codec := rpcClient(t, server)
|
||||
defer codec.Close()
|
||||
|
||||
client, resolverBuilder, balancerBuilder := newClientWithGRPCPlumbing(t)
|
||||
client, resolverBuilder := newClientWithGRPCPlumbing(t)
|
||||
|
||||
// Try to join
|
||||
testrpc.WaitForLeader(t, server.RPC, "dc1")
|
||||
|
@ -346,7 +343,6 @@ func TestSubscribeBackend_IntegrationWithServer_DeliversAllMessages(t *testing.T
|
|||
UseTLSForDC: client.tlsConfigurator.UseTLS,
|
||||
DialingFromServer: true,
|
||||
DialingFromDatacenter: "dc1",
|
||||
BalancerBuilder: balancerBuilder,
|
||||
})
|
||||
conn, err := pool.ClientConn("dc1")
|
||||
require.NoError(t, err)
|
||||
|
@ -376,7 +372,7 @@ func TestSubscribeBackend_IntegrationWithServer_DeliversAllMessages(t *testing.T
|
|||
"at least some of the subscribers should have received non-snapshot updates")
|
||||
}
|
||||
|
||||
func newClientWithGRPCPlumbing(t *testing.T, ops ...func(*Config)) (*Client, *resolver.ServerResolverBuilder, *balancer.Builder) {
|
||||
func newClientWithGRPCPlumbing(t *testing.T, ops ...func(*Config)) (*Client, *resolver.ServerResolverBuilder) {
|
||||
_, config := testClientConfig(t)
|
||||
for _, op := range ops {
|
||||
op(config)
|
||||
|
@ -392,6 +388,7 @@ func newClientWithGRPCPlumbing(t *testing.T, ops ...func(*Config)) (*Client, *re
|
|||
|
||||
balancerBuilder := balancer.NewBuilder(resolverBuilder.Authority(), testutil.Logger(t))
|
||||
balancerBuilder.Register()
|
||||
t.Cleanup(balancerBuilder.Deregister)
|
||||
|
||||
deps := newDefaultDeps(t, config)
|
||||
deps.Router = router.NewRouter(
|
||||
|
@ -406,7 +403,7 @@ func newClientWithGRPCPlumbing(t *testing.T, ops ...func(*Config)) (*Client, *re
|
|||
t.Cleanup(func() {
|
||||
client.Shutdown()
|
||||
})
|
||||
return client, resolverBuilder, balancerBuilder
|
||||
return client, resolverBuilder
|
||||
}
|
||||
|
||||
type testLogger interface {
|
||||
|
|
|
@ -65,10 +65,14 @@ import (
|
|||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
// NewBuilder constructs a new Builder with the given name.
|
||||
func NewBuilder(name string, logger hclog.Logger) *Builder {
|
||||
// NewBuilder constructs a new Builder. Calling Register will add the Builder
|
||||
// to our global registry under the given "authority" such that it will be used
|
||||
// when dialing targets in the form "consul-internal://<authority>/...", this
|
||||
// allows us to add and remove balancers for different in-memory agents during
|
||||
// tests.
|
||||
func NewBuilder(authority string, logger hclog.Logger) *Builder {
|
||||
return &Builder{
|
||||
name: name,
|
||||
authority: authority,
|
||||
logger: logger,
|
||||
byTarget: make(map[string]*list.List),
|
||||
shuffler: randomShuffler(),
|
||||
|
@ -77,7 +81,7 @@ func NewBuilder(name string, logger hclog.Logger) *Builder {
|
|||
|
||||
// Builder implements gRPC's balancer.Builder interface to construct balancers.
|
||||
type Builder struct {
|
||||
name string
|
||||
authority string
|
||||
logger hclog.Logger
|
||||
shuffler shuffler
|
||||
|
||||
|
@ -129,19 +133,15 @@ func (b *Builder) removeBalancer(targetURL string, elem *list.Element) {
|
|||
}
|
||||
}
|
||||
|
||||
// Name implements the gRPC Balancer interface by returning its given name.
|
||||
func (b *Builder) Name() string { return b.name }
|
||||
|
||||
// gRPC's balancer.Register method is not thread-safe, so we guard our calls
|
||||
// with a global lock (as it may be called from parallel tests).
|
||||
var registerLock sync.Mutex
|
||||
|
||||
// Register the Builder in gRPC's global registry using its given name.
|
||||
// Register the Builder in our global registry. Users should call Deregister
|
||||
// when finished using the Builder to clean-up global state.
|
||||
func (b *Builder) Register() {
|
||||
registerLock.Lock()
|
||||
defer registerLock.Unlock()
|
||||
globalRegistry.register(b.authority, b)
|
||||
}
|
||||
|
||||
gbalancer.Register(b)
|
||||
// Deregister the Builder from our global registry to clean up state.
|
||||
func (b *Builder) Deregister() {
|
||||
globalRegistry.deregister(b.authority)
|
||||
}
|
||||
|
||||
// Rebalance randomizes the priority order of servers for the given target to
|
||||
|
|
|
@ -21,6 +21,8 @@ import (
|
|||
"google.golang.org/grpc/stats"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/hashicorp/go-uuid"
|
||||
|
||||
"github.com/hashicorp/consul/agent/grpc-middleware/testutil/testservice"
|
||||
"github.com/hashicorp/consul/sdk/testutil"
|
||||
"github.com/hashicorp/consul/sdk/testutil/retry"
|
||||
|
@ -34,12 +36,13 @@ func TestBalancer(t *testing.T) {
|
|||
server1 := runServer(t, "server1")
|
||||
server2 := runServer(t, "server2")
|
||||
|
||||
target, _ := stubResolver(t, server1, server2)
|
||||
target, authority, _ := stubResolver(t, server1, server2)
|
||||
|
||||
balancerBuilder := NewBuilder(t.Name(), testutil.Logger(t))
|
||||
balancerBuilder := NewBuilder(authority, testutil.Logger(t))
|
||||
balancerBuilder.Register()
|
||||
t.Cleanup(balancerBuilder.Deregister)
|
||||
|
||||
conn := dial(t, target, balancerBuilder)
|
||||
conn := dial(t, target)
|
||||
client := testservice.NewSimpleClient(conn)
|
||||
|
||||
var serverName string
|
||||
|
@ -78,12 +81,13 @@ func TestBalancer(t *testing.T) {
|
|||
server1 := runServer(t, "server1")
|
||||
server2 := runServer(t, "server2")
|
||||
|
||||
target, _ := stubResolver(t, server1, server2)
|
||||
target, authority, _ := stubResolver(t, server1, server2)
|
||||
|
||||
balancerBuilder := NewBuilder(t.Name(), testutil.Logger(t))
|
||||
balancerBuilder := NewBuilder(authority, testutil.Logger(t))
|
||||
balancerBuilder.Register()
|
||||
t.Cleanup(balancerBuilder.Deregister)
|
||||
|
||||
conn := dial(t, target, balancerBuilder)
|
||||
conn := dial(t, target)
|
||||
client := testservice.NewSimpleClient(conn)
|
||||
|
||||
// Figure out which server we're talking to now, and which we should switch to.
|
||||
|
@ -123,10 +127,11 @@ func TestBalancer(t *testing.T) {
|
|||
server1 := runServer(t, "server1")
|
||||
server2 := runServer(t, "server2")
|
||||
|
||||
target, _ := stubResolver(t, server1, server2)
|
||||
target, authority, _ := stubResolver(t, server1, server2)
|
||||
|
||||
balancerBuilder := NewBuilder(t.Name(), testutil.Logger(t))
|
||||
balancerBuilder := NewBuilder(authority, testutil.Logger(t))
|
||||
balancerBuilder.Register()
|
||||
t.Cleanup(balancerBuilder.Deregister)
|
||||
|
||||
// Provide a custom prioritizer that causes Rebalance to choose whichever
|
||||
// server didn't get our first request.
|
||||
|
@ -137,7 +142,7 @@ func TestBalancer(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
conn := dial(t, target, balancerBuilder)
|
||||
conn := dial(t, target)
|
||||
client := testservice.NewSimpleClient(conn)
|
||||
|
||||
// Figure out which server we're talking to now.
|
||||
|
@ -177,12 +182,13 @@ func TestBalancer(t *testing.T) {
|
|||
server1 := runServer(t, "server1")
|
||||
server2 := runServer(t, "server2")
|
||||
|
||||
target, res := stubResolver(t, server1, server2)
|
||||
target, authority, res := stubResolver(t, server1, server2)
|
||||
|
||||
balancerBuilder := NewBuilder(t.Name(), testutil.Logger(t))
|
||||
balancerBuilder := NewBuilder(authority, testutil.Logger(t))
|
||||
balancerBuilder.Register()
|
||||
t.Cleanup(balancerBuilder.Deregister)
|
||||
|
||||
conn := dial(t, target, balancerBuilder)
|
||||
conn := dial(t, target)
|
||||
client := testservice.NewSimpleClient(conn)
|
||||
|
||||
// Figure out which server we're talking to now.
|
||||
|
@ -233,7 +239,7 @@ func TestBalancer(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func stubResolver(t *testing.T, servers ...*server) (string, *manual.Resolver) {
|
||||
func stubResolver(t *testing.T, servers ...*server) (string, string, *manual.Resolver) {
|
||||
t.Helper()
|
||||
|
||||
addresses := make([]resolver.Address, len(servers))
|
||||
|
@ -249,7 +255,10 @@ func stubResolver(t *testing.T, servers ...*server) (string, *manual.Resolver) {
|
|||
resolver.Register(r)
|
||||
t.Cleanup(func() { resolver.UnregisterForTesting(scheme) })
|
||||
|
||||
return fmt.Sprintf("%s://", scheme), r
|
||||
authority, err := uuid.GenerateUUID()
|
||||
require.NoError(t, err)
|
||||
|
||||
return fmt.Sprintf("%s://%s", scheme, authority), authority, r
|
||||
}
|
||||
|
||||
func runServer(t *testing.T, name string) *server {
|
||||
|
@ -309,12 +318,12 @@ func (s *server) Something(context.Context, *testservice.Req) (*testservice.Resp
|
|||
return &testservice.Resp{ServerName: s.name}, nil
|
||||
}
|
||||
|
||||
func dial(t *testing.T, target string, builder *Builder) *grpc.ClientConn {
|
||||
func dial(t *testing.T, target string) *grpc.ClientConn {
|
||||
conn, err := grpc.Dial(
|
||||
target,
|
||||
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||
grpc.WithDefaultServiceConfig(
|
||||
fmt.Sprintf(`{"loadBalancingConfig":[{"%s":{}}]}`, builder.Name()),
|
||||
fmt.Sprintf(`{"loadBalancingConfig":[{"%s":{}}]}`, BuilderName),
|
||||
),
|
||||
)
|
||||
t.Cleanup(func() {
|
||||
|
|
|
@ -0,0 +1,69 @@
|
|||
package balancer
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
gbalancer "google.golang.org/grpc/balancer"
|
||||
)
|
||||
|
||||
// BuilderName should be given in gRPC service configuration to enable our
|
||||
// custom balancer. It refers to this package's global registry, rather than
|
||||
// an instance of Builder to enable us to add and remove builders at runtime,
|
||||
// specifically during tests.
|
||||
const BuilderName = "consul-internal"
|
||||
|
||||
// gRPC's balancer.Register method is thread-unsafe because it mutates a global
|
||||
// map without holding a lock. As such, it's expected that you register custom
|
||||
// balancers once at the start of your program (e.g. a package init function).
|
||||
//
|
||||
// In production, this is fine. Agents register a single instance of our builder
|
||||
// and use it for the duration. Tests are where this becomes problematic, as we
|
||||
// spin up several agents in-memory and register/deregister a builder for each,
|
||||
// with its own agent-specific state, logger, etc.
|
||||
//
|
||||
// To avoid data races, we call gRPC's Register method once, on-package init,
|
||||
// with a global registry struct that implements the Builder interface but
|
||||
// delegates the building to N instances of our Builder that are registered and
|
||||
// deregistered at runtime. We the dial target's host (aka "authority") which
|
||||
// is unique per-agent to pick the correct builder.
|
||||
func init() {
|
||||
gbalancer.Register(globalRegistry)
|
||||
}
|
||||
|
||||
var globalRegistry = ®istry{
|
||||
byAuthority: make(map[string]*Builder),
|
||||
}
|
||||
|
||||
type registry struct {
|
||||
mu sync.RWMutex
|
||||
byAuthority map[string]*Builder
|
||||
}
|
||||
|
||||
func (r *registry) Build(cc gbalancer.ClientConn, opts gbalancer.BuildOptions) gbalancer.Balancer {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
auth := opts.Target.URL.Host
|
||||
builder, ok := r.byAuthority[auth]
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("no gRPC balancer builder registered for authority: %q", auth))
|
||||
}
|
||||
return builder.Build(cc, opts)
|
||||
}
|
||||
|
||||
func (r *registry) Name() string { return BuilderName }
|
||||
|
||||
func (r *registry) register(auth string, builder *Builder) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
r.byAuthority[auth] = builder
|
||||
}
|
||||
|
||||
func (r *registry) deregister(auth string) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
delete(r.byAuthority, auth)
|
||||
}
|
|
@ -8,12 +8,12 @@ import (
|
|||
"time"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
gbalancer "google.golang.org/grpc/balancer"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
|
||||
"github.com/armon/go-metrics"
|
||||
|
||||
"github.com/hashicorp/consul/agent/grpc-internal/balancer"
|
||||
agentmiddleware "github.com/hashicorp/consul/agent/grpc-middleware"
|
||||
"github.com/hashicorp/consul/agent/metadata"
|
||||
"github.com/hashicorp/consul/agent/pool"
|
||||
|
@ -22,8 +22,8 @@ import (
|
|||
|
||||
// grpcServiceConfig is provided as the default service config.
|
||||
//
|
||||
// It configures our custom balancer (via the %s directive to interpolate its
|
||||
// name) which will automatically switch servers on error.
|
||||
// It configures our custom balancer which will automatically switch servers
|
||||
// on error.
|
||||
//
|
||||
// It also enables gRPC's built-in automatic retries for RESOURCE_EXHAUSTED
|
||||
// errors *only*, as this is the status code servers will return for an
|
||||
|
@ -41,7 +41,7 @@ import (
|
|||
// but we're working on generating them automatically from the protobuf files
|
||||
const grpcServiceConfig = `
|
||||
{
|
||||
"loadBalancingConfig": [{"%s":{}}],
|
||||
"loadBalancingConfig": [{"` + balancer.BuilderName + `":{}}],
|
||||
"methodConfig": [
|
||||
{
|
||||
"name": [{}],
|
||||
|
@ -136,7 +136,6 @@ type ClientConnPool struct {
|
|||
gwResolverDep gatewayResolverDep
|
||||
conns map[string]*grpc.ClientConn
|
||||
connsLock sync.Mutex
|
||||
balancerBuilder gbalancer.Builder
|
||||
}
|
||||
|
||||
type ServerLocator interface {
|
||||
|
@ -198,21 +197,14 @@ type ClientConnPoolConfig struct {
|
|||
// DialingFromDatacenter is the datacenter of the consul agent using this
|
||||
// pool.
|
||||
DialingFromDatacenter string
|
||||
|
||||
// BalancerBuilder is a builder for the gRPC balancer that will be used.
|
||||
BalancerBuilder gbalancer.Builder
|
||||
}
|
||||
|
||||
// NewClientConnPool create new GRPC client pool to connect to servers using
|
||||
// GRPC over RPC.
|
||||
func NewClientConnPool(cfg ClientConnPoolConfig) *ClientConnPool {
|
||||
if cfg.BalancerBuilder == nil {
|
||||
panic("missing required BalancerBuilder")
|
||||
}
|
||||
c := &ClientConnPool{
|
||||
servers: cfg.Servers,
|
||||
conns: make(map[string]*grpc.ClientConn),
|
||||
balancerBuilder: cfg.BalancerBuilder,
|
||||
}
|
||||
c.dialer = newDialer(cfg, &c.gwResolverDep)
|
||||
return c
|
||||
|
@ -251,9 +243,7 @@ func (c *ClientConnPool) dial(datacenter string, serverType string) (*grpc.Clien
|
|||
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||
grpc.WithContextDialer(c.dialer),
|
||||
grpc.WithStatsHandler(agentmiddleware.NewStatsHandler(metrics.Default(), metricsLabels)),
|
||||
grpc.WithDefaultServiceConfig(
|
||||
fmt.Sprintf(grpcServiceConfig, c.balancerBuilder.Name()),
|
||||
),
|
||||
grpc.WithDefaultServiceConfig(grpcServiceConfig),
|
||||
// Keep alive parameters are based on the same default ones we used for
|
||||
// Yamux. These are somewhat arbitrary but we did observe in scale testing
|
||||
// that the gRPC defaults (servers send keepalives only every 2 hours,
|
||||
|
|
|
@ -13,7 +13,6 @@ import (
|
|||
"github.com/hashicorp/go-hclog"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
gbalancer "google.golang.org/grpc/balancer"
|
||||
|
||||
"github.com/hashicorp/consul/agent/grpc-internal/balancer"
|
||||
"github.com/hashicorp/consul/agent/grpc-internal/resolver"
|
||||
|
@ -143,7 +142,8 @@ func TestNewDialer_IntegrationWithTLSEnabledHandler(t *testing.T) {
|
|||
// if this test is failing because of expired certificates
|
||||
// use the procedure in test/CA-GENERATION.md
|
||||
res := resolver.NewServerResolverBuilder(newConfig(t))
|
||||
registerWithGRPC(t, res)
|
||||
bb := balancer.NewBuilder(res.Authority(), testutil.Logger(t))
|
||||
registerWithGRPC(t, res, bb)
|
||||
|
||||
tlsConf, err := tlsutil.NewConfigurator(tlsutil.Config{
|
||||
InternalRPC: tlsutil.ProtocolConfig{
|
||||
|
@ -168,7 +168,6 @@ func TestNewDialer_IntegrationWithTLSEnabledHandler(t *testing.T) {
|
|||
UseTLSForDC: tlsConf.UseTLS,
|
||||
DialingFromServer: true,
|
||||
DialingFromDatacenter: "dc1",
|
||||
BalancerBuilder: balancerBuilder(t, res.Authority()),
|
||||
})
|
||||
|
||||
conn, err := pool.ClientConn("dc1")
|
||||
|
@ -191,7 +190,8 @@ func TestNewDialer_IntegrationWithTLSEnabledHandler_viaMeshGateway(t *testing.T)
|
|||
gwAddr := ipaddr.FormatAddressPort("127.0.0.1", freeport.GetOne(t))
|
||||
|
||||
res := resolver.NewServerResolverBuilder(newConfig(t))
|
||||
registerWithGRPC(t, res)
|
||||
bb := balancer.NewBuilder(res.Authority(), testutil.Logger(t))
|
||||
registerWithGRPC(t, res, bb)
|
||||
|
||||
tlsConf, err := tlsutil.NewConfigurator(tlsutil.Config{
|
||||
InternalRPC: tlsutil.ProtocolConfig{
|
||||
|
@ -244,7 +244,6 @@ func TestNewDialer_IntegrationWithTLSEnabledHandler_viaMeshGateway(t *testing.T)
|
|||
UseTLSForDC: tlsConf.UseTLS,
|
||||
DialingFromServer: true,
|
||||
DialingFromDatacenter: "dc2",
|
||||
BalancerBuilder: balancerBuilder(t, res.Authority()),
|
||||
})
|
||||
pool.SetGatewayResolver(func(addr string) string {
|
||||
return gwAddr
|
||||
|
@ -267,13 +266,13 @@ func TestNewDialer_IntegrationWithTLSEnabledHandler_viaMeshGateway(t *testing.T)
|
|||
func TestClientConnPool_IntegrationWithGRPCResolver_Failover(t *testing.T) {
|
||||
count := 4
|
||||
res := resolver.NewServerResolverBuilder(newConfig(t))
|
||||
registerWithGRPC(t, res)
|
||||
bb := balancer.NewBuilder(res.Authority(), testutil.Logger(t))
|
||||
registerWithGRPC(t, res, bb)
|
||||
pool := NewClientConnPool(ClientConnPoolConfig{
|
||||
Servers: res,
|
||||
UseTLSForDC: useTLSForDcAlwaysTrue,
|
||||
DialingFromServer: true,
|
||||
DialingFromDatacenter: "dc1",
|
||||
BalancerBuilder: balancerBuilder(t, res.Authority()),
|
||||
})
|
||||
|
||||
for i := 0; i < count; i++ {
|
||||
|
@ -303,13 +302,13 @@ func TestClientConnPool_IntegrationWithGRPCResolver_Failover(t *testing.T) {
|
|||
func TestClientConnPool_ForwardToLeader_Failover(t *testing.T) {
|
||||
count := 3
|
||||
res := resolver.NewServerResolverBuilder(newConfig(t))
|
||||
registerWithGRPC(t, res)
|
||||
bb := balancer.NewBuilder(res.Authority(), testutil.Logger(t))
|
||||
registerWithGRPC(t, res, bb)
|
||||
pool := NewClientConnPool(ClientConnPoolConfig{
|
||||
Servers: res,
|
||||
UseTLSForDC: useTLSForDcAlwaysTrue,
|
||||
DialingFromServer: true,
|
||||
DialingFromDatacenter: "dc1",
|
||||
BalancerBuilder: balancerBuilder(t, res.Authority()),
|
||||
})
|
||||
|
||||
var servers []testServer
|
||||
|
@ -356,13 +355,13 @@ func TestClientConnPool_IntegrationWithGRPCResolver_MultiDC(t *testing.T) {
|
|||
dcs := []string{"dc1", "dc2", "dc3"}
|
||||
|
||||
res := resolver.NewServerResolverBuilder(newConfig(t))
|
||||
registerWithGRPC(t, res)
|
||||
bb := balancer.NewBuilder(res.Authority(), testutil.Logger(t))
|
||||
registerWithGRPC(t, res, bb)
|
||||
pool := NewClientConnPool(ClientConnPoolConfig{
|
||||
Servers: res,
|
||||
UseTLSForDC: useTLSForDcAlwaysTrue,
|
||||
DialingFromServer: true,
|
||||
DialingFromDatacenter: "dc1",
|
||||
BalancerBuilder: balancerBuilder(t, res.Authority()),
|
||||
})
|
||||
|
||||
for _, dc := range dcs {
|
||||
|
@ -386,18 +385,11 @@ func TestClientConnPool_IntegrationWithGRPCResolver_MultiDC(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func registerWithGRPC(t *testing.T, b *resolver.ServerResolverBuilder) {
|
||||
resolver.Register(b)
|
||||
func registerWithGRPC(t *testing.T, rb *resolver.ServerResolverBuilder, bb *balancer.Builder) {
|
||||
resolver.Register(rb)
|
||||
bb.Register()
|
||||
t.Cleanup(func() {
|
||||
resolver.Deregister(b.Authority())
|
||||
resolver.Deregister(rb.Authority())
|
||||
bb.Deregister()
|
||||
})
|
||||
}
|
||||
|
||||
func balancerBuilder(t *testing.T, name string) gbalancer.Builder {
|
||||
t.Helper()
|
||||
|
||||
bb := balancer.NewBuilder(name, testutil.Logger(t))
|
||||
bb.Register()
|
||||
|
||||
return bb
|
||||
}
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/consul/sdk/testutil"
|
||||
"github.com/hashicorp/consul/types"
|
||||
|
||||
"github.com/hashicorp/go-hclog"
|
||||
|
@ -13,6 +14,7 @@ import (
|
|||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/hashicorp/consul/agent/grpc-internal/balancer"
|
||||
"github.com/hashicorp/consul/agent/grpc-internal/resolver"
|
||||
"github.com/hashicorp/consul/agent/grpc-middleware/testutil/testservice"
|
||||
)
|
||||
|
@ -27,7 +29,8 @@ func TestHandler_PanicRecoveryInterceptor(t *testing.T) {
|
|||
})
|
||||
|
||||
res := resolver.NewServerResolverBuilder(newConfig(t))
|
||||
registerWithGRPC(t, res)
|
||||
bb := balancer.NewBuilder(res.Authority(), testutil.Logger(t))
|
||||
registerWithGRPC(t, res, bb)
|
||||
|
||||
srv := newPanicTestServer(t, logger, "server-1", "dc1", nil)
|
||||
res.AddServer(types.AreaWAN, srv.Metadata())
|
||||
|
@ -38,7 +41,6 @@ func TestHandler_PanicRecoveryInterceptor(t *testing.T) {
|
|||
UseTLSForDC: useTLSForDcAlwaysTrue,
|
||||
DialingFromServer: true,
|
||||
DialingFromDatacenter: "dc1",
|
||||
BalancerBuilder: balancerBuilder(t, res.Authority()),
|
||||
})
|
||||
|
||||
conn, err := pool.ClientConn("dc1")
|
||||
|
|
|
@ -1693,8 +1693,9 @@ func newDefaultDeps(t *testing.T, c *consul.Config) consul.Deps {
|
|||
Datacenter: c.Datacenter,
|
||||
}
|
||||
|
||||
balancerBuilder := balancer.NewBuilder(t.Name(), testutil.Logger(t))
|
||||
balancerBuilder := balancer.NewBuilder(builder.Authority(), testutil.Logger(t))
|
||||
balancerBuilder.Register()
|
||||
t.Cleanup(balancerBuilder.Deregister)
|
||||
|
||||
return consul.Deps{
|
||||
EventPublisher: stream.NewEventPublisher(10 * time.Second),
|
||||
|
@ -1709,7 +1710,6 @@ func newDefaultDeps(t *testing.T, c *consul.Config) consul.Deps {
|
|||
UseTLSForDC: tls.UseTLS,
|
||||
DialingFromServer: true,
|
||||
DialingFromDatacenter: c.Datacenter,
|
||||
BalancerBuilder: balancerBuilder,
|
||||
}),
|
||||
LeaderForwarder: builder,
|
||||
EnterpriseDeps: newDefaultDepsEnterprise(t, logger, c),
|
||||
|
|
|
@ -53,6 +53,8 @@ type BaseDeps struct {
|
|||
Cache *cache.Cache
|
||||
ViewStore *submatview.Store
|
||||
WatchedFiles []string
|
||||
|
||||
deregisterBalancer, deregisterResolver func()
|
||||
}
|
||||
|
||||
type ConfigLoader func(source config.Source) (config.LoadResult, error)
|
||||
|
@ -122,14 +124,16 @@ func NewBaseDeps(configLoader ConfigLoader, logOut io.Writer, providedLogger hcl
|
|||
Authority: cfg.Datacenter + "." + string(cfg.NodeID),
|
||||
})
|
||||
resolver.Register(resolverBuilder)
|
||||
d.deregisterResolver = func() {
|
||||
resolver.Deregister(resolverBuilder.Authority())
|
||||
}
|
||||
|
||||
balancerBuilder := balancer.NewBuilder(
|
||||
// Balancer name doesn't really matter, we set it to the resolver authority
|
||||
// to keep it unique for tests.
|
||||
resolverBuilder.Authority(),
|
||||
d.Logger.Named("grpc.balancer"),
|
||||
)
|
||||
balancerBuilder.Register()
|
||||
d.deregisterBalancer = balancerBuilder.Deregister
|
||||
|
||||
d.GRPCConnPool = grpcInt.NewClientConnPool(grpcInt.ClientConnPoolConfig{
|
||||
Servers: resolverBuilder,
|
||||
|
@ -139,7 +143,6 @@ func NewBaseDeps(configLoader ConfigLoader, logOut io.Writer, providedLogger hcl
|
|||
UseTLSForDC: d.TLSConfigurator.UseTLS,
|
||||
DialingFromServer: cfg.ServerMode,
|
||||
DialingFromDatacenter: cfg.Datacenter,
|
||||
BalancerBuilder: balancerBuilder,
|
||||
})
|
||||
d.LeaderForwarder = resolverBuilder
|
||||
|
||||
|
@ -189,6 +192,20 @@ func NewBaseDeps(configLoader ConfigLoader, logOut io.Writer, providedLogger hcl
|
|||
return d, nil
|
||||
}
|
||||
|
||||
// Close cleans up any state and goroutines associated to bd's members not
|
||||
// handled by something else (e.g. the agent stop channel).
|
||||
func (bd BaseDeps) Close() {
|
||||
bd.AutoConfig.Stop()
|
||||
bd.MetricsConfig.Cancel()
|
||||
|
||||
if fn := bd.deregisterBalancer; fn != nil {
|
||||
fn()
|
||||
}
|
||||
if fn := bd.deregisterResolver; fn != nil {
|
||||
fn()
|
||||
}
|
||||
}
|
||||
|
||||
// grpcLogInitOnce because the test suite will call NewBaseDeps in many tests and
|
||||
// causes data races when it is re-initialized.
|
||||
var grpcLogInitOnce sync.Once
|
||||
|
|
Loading…
Reference in New Issue