diff --git a/agent/agent.go b/agent/agent.go
index ec1486374..7b218174a 100644
--- a/agent/agent.go
+++ b/agent/agent.go
@@ -1599,10 +1599,7 @@ func (a *Agent) ShutdownAgent() error {
 
     a.stopLicenseManager()
 
-    // this would be cancelled anyways (by the closing of the shutdown ch) but
-    // this should help them to be stopped more quickly
-    a.baseDeps.AutoConfig.Stop()
-    a.baseDeps.MetricsConfig.Cancel()
+    a.baseDeps.Close()
 
     a.stateLock.Lock()
     defer a.stateLock.Unlock()
diff --git a/agent/consul/client_test.go b/agent/consul/client_test.go
index 15af55509..97958b486 100644
--- a/agent/consul/client_test.go
+++ b/agent/consul/client_test.go
@@ -522,9 +522,13 @@ func newDefaultDeps(t *testing.T, c *Config) Deps {
     resolverBuilder := resolver.NewServerResolverBuilder(newTestResolverConfig(t, c.NodeName+"-"+c.Datacenter))
     resolver.Register(resolverBuilder)
+    t.Cleanup(func() {
+        resolver.Deregister(resolverBuilder.Authority())
+    })
 
     balancerBuilder := balancer.NewBuilder(resolverBuilder.Authority(), testutil.Logger(t))
     balancerBuilder.Register()
+    t.Cleanup(balancerBuilder.Deregister)
 
     r := router.NewRouter(
         logger,
@@ -559,7 +563,6 @@ func newDefaultDeps(t *testing.T, c *Config) Deps {
            UseTLSForDC:           tls.UseTLS,
            DialingFromServer:     true,
            DialingFromDatacenter: c.Datacenter,
-           BalancerBuilder:       balancerBuilder,
        }),
        LeaderForwarder:        resolverBuilder,
        NewRequestRecorderFunc: middleware.NewRequestRecorder,
diff --git a/agent/consul/rpc_test.go b/agent/consul/rpc_test.go
index 0eff59b2b..fa0107dd1 100644
--- a/agent/consul/rpc_test.go
+++ b/agent/consul/rpc_test.go
@@ -1165,7 +1165,7 @@ func TestRPC_LocalTokenStrippedOnForward_GRPC(t *testing.T) {
 
     var conn *grpc.ClientConn
     {
-        client, resolverBuilder, balancerBuilder := newClientWithGRPCPlumbing(t, func(c *Config) {
+        client, resolverBuilder := newClientWithGRPCPlumbing(t, func(c *Config) {
             c.Datacenter = "dc2"
             c.PrimaryDatacenter = "dc1"
             c.RPCConfig.EnableStreaming = true
@@ -1177,7 +1177,6 @@ func TestRPC_LocalTokenStrippedOnForward_GRPC(t *testing.T) {
             Servers:               resolverBuilder,
             DialingFromServer:     false,
             DialingFromDatacenter: "dc2",
-            BalancerBuilder:       balancerBuilder,
         })
 
         conn, err = pool.ClientConn("dc2")
diff --git a/agent/consul/subscribe_backend_test.go b/agent/consul/subscribe_backend_test.go
index a0d07e648..b0ae2366e 100644
--- a/agent/consul/subscribe_backend_test.go
+++ b/agent/consul/subscribe_backend_test.go
@@ -39,7 +39,7 @@ func TestSubscribeBackend_IntegrationWithServer_TLSEnabled(t *testing.T) {
     require.NoError(t, err)
     defer server.Shutdown()
 
-    client, resolverBuilder, balancerBuilder := newClientWithGRPCPlumbing(t, configureTLS, clientConfigVerifyOutgoing)
+    client, resolverBuilder := newClientWithGRPCPlumbing(t, configureTLS, clientConfigVerifyOutgoing)
 
     // Try to join
     testrpc.WaitForLeader(t, server.RPC, "dc1")
@@ -71,7 +71,6 @@ func TestSubscribeBackend_IntegrationWithServer_TLSEnabled(t *testing.T) {
         UseTLSForDC:           client.tlsConfigurator.UseTLS,
         DialingFromServer:     true,
         DialingFromDatacenter: "dc1",
-        BalancerBuilder:       balancerBuilder,
     })
     conn, err := pool.ClientConn("dc1")
     require.NoError(t, err)
@@ -116,7 +115,6 @@ func TestSubscribeBackend_IntegrationWithServer_TLSEnabled(t *testing.T) {
         UseTLSForDC:           client.tlsConfigurator.UseTLS,
         DialingFromServer:     true,
         DialingFromDatacenter: "dc1",
-        BalancerBuilder:       balancerBuilder,
     })
     conn, err := pool.ClientConn("dc1")
     require.NoError(t, err)
@@ -191,7 +189,7 @@ func TestSubscribeBackend_IntegrationWithServer_TLSReload(t *testing.T) {
     defer server.Shutdown()
 
     // Set up a client with valid certs and verify_outgoing = true
-    client, resolverBuilder, balancerBuilder := newClientWithGRPCPlumbing(t, configureTLS, clientConfigVerifyOutgoing)
+    client, resolverBuilder := newClientWithGRPCPlumbing(t, configureTLS, clientConfigVerifyOutgoing)
 
     testrpc.WaitForLeader(t, server.RPC, "dc1")
 
@@ -204,7 +202,6 @@ func TestSubscribeBackend_IntegrationWithServer_TLSReload(t *testing.T) {
         UseTLSForDC:           client.tlsConfigurator.UseTLS,
         DialingFromServer:     true,
         DialingFromDatacenter: "dc1",
-        BalancerBuilder:       balancerBuilder,
     })
     conn, err := pool.ClientConn("dc1")
     require.NoError(t, err)
@@ -284,7 +281,7 @@ func TestSubscribeBackend_IntegrationWithServer_DeliversAllMessages(t *testing.T
     codec := rpcClient(t, server)
     defer codec.Close()
 
-    client, resolverBuilder, balancerBuilder := newClientWithGRPCPlumbing(t)
+    client, resolverBuilder := newClientWithGRPCPlumbing(t)
 
     // Try to join
     testrpc.WaitForLeader(t, server.RPC, "dc1")
@@ -346,7 +343,6 @@ func TestSubscribeBackend_IntegrationWithServer_DeliversAllMessages(t *testing.T
         UseTLSForDC:           client.tlsConfigurator.UseTLS,
         DialingFromServer:     true,
         DialingFromDatacenter: "dc1",
-        BalancerBuilder:       balancerBuilder,
     })
     conn, err := pool.ClientConn("dc1")
     require.NoError(t, err)
@@ -376,7 +372,7 @@ func TestSubscribeBackend_IntegrationWithServer_DeliversAllMessages(t *testing.T
         "at least some of the subscribers should have received non-snapshot updates")
 }
 
-func newClientWithGRPCPlumbing(t *testing.T, ops ...func(*Config)) (*Client, *resolver.ServerResolverBuilder, *balancer.Builder) {
+func newClientWithGRPCPlumbing(t *testing.T, ops ...func(*Config)) (*Client, *resolver.ServerResolverBuilder) {
     _, config := testClientConfig(t)
     for _, op := range ops {
         op(config)
@@ -392,6 +388,7 @@ func newClientWithGRPCPlumbing(t *testing.T, ops ...func(*Config)) (*Client, *re
 
     balancerBuilder := balancer.NewBuilder(resolverBuilder.Authority(), testutil.Logger(t))
     balancerBuilder.Register()
+    t.Cleanup(balancerBuilder.Deregister)
 
     deps := newDefaultDeps(t, config)
     deps.Router = router.NewRouter(
@@ -406,7 +403,7 @@ func newClientWithGRPCPlumbing(t *testing.T, ops ...func(*Config)) (*Client, *re
     t.Cleanup(func() {
         client.Shutdown()
     })
-    return client, resolverBuilder, balancerBuilder
+    return client, resolverBuilder
 }
 
 type testLogger interface {
diff --git a/agent/grpc-internal/balancer/balancer.go b/agent/grpc-internal/balancer/balancer.go
index efd349c82..64521d456 100644
--- a/agent/grpc-internal/balancer/balancer.go
+++ b/agent/grpc-internal/balancer/balancer.go
@@ -65,21 +65,25 @@ import (
     "google.golang.org/grpc/status"
 )
 
-// NewBuilder constructs a new Builder with the given name.
-func NewBuilder(name string, logger hclog.Logger) *Builder {
+// NewBuilder constructs a new Builder. Calling Register will add the Builder
+// to our global registry under the given "authority" such that it will be used
+// when dialing targets in the form "consul-internal://<authority>/...". This
+// allows us to add and remove balancers for different in-memory agents during
+// tests.
+func NewBuilder(authority string, logger hclog.Logger) *Builder {
     return &Builder{
-        name:     name,
-        logger:   logger,
-        byTarget: make(map[string]*list.List),
-        shuffler: randomShuffler(),
+        authority: authority,
+        logger:    logger,
+        byTarget:  make(map[string]*list.List),
+        shuffler:  randomShuffler(),
     }
 }
 
 // Builder implements gRPC's balancer.Builder interface to construct balancers.
 type Builder struct {
-    name     string
-    logger   hclog.Logger
-    shuffler shuffler
+    authority string
+    logger    hclog.Logger
+    shuffler  shuffler
 
     mu       sync.Mutex
     byTarget map[string]*list.List
@@ -129,19 +133,15 @@ func (b *Builder) removeBalancer(targetURL string, elem *list.Element) {
     }
 }
 
-// Name implements the gRPC Balancer interface by returning its given name.
-func (b *Builder) Name() string { return b.name }
-
-// gRPC's balancer.Register method is not thread-safe, so we guard our calls
-// with a global lock (as it may be called from parallel tests).
-var registerLock sync.Mutex
-
-// Register the Builder in gRPC's global registry using its given name.
+// Register the Builder in our global registry. Users should call Deregister
+// when finished using the Builder to clean up global state.
 func (b *Builder) Register() {
-    registerLock.Lock()
-    defer registerLock.Unlock()
+    globalRegistry.register(b.authority, b)
+}
 
-    gbalancer.Register(b)
+// Deregister the Builder from our global registry to clean up state.
+func (b *Builder) Deregister() {
+    globalRegistry.deregister(b.authority)
 }
 
 // Rebalance randomizes the priority order of servers for the given target to
diff --git a/agent/grpc-internal/balancer/balancer_test.go b/agent/grpc-internal/balancer/balancer_test.go
index 830092ab3..8406e7c7b 100644
--- a/agent/grpc-internal/balancer/balancer_test.go
+++ b/agent/grpc-internal/balancer/balancer_test.go
@@ -21,6 +21,8 @@ import (
     "google.golang.org/grpc/stats"
     "google.golang.org/grpc/status"
 
+    "github.com/hashicorp/go-uuid"
+
     "github.com/hashicorp/consul/agent/grpc-middleware/testutil/testservice"
     "github.com/hashicorp/consul/sdk/testutil"
     "github.com/hashicorp/consul/sdk/testutil/retry"
@@ -34,12 +36,13 @@ func TestBalancer(t *testing.T) {
         server1 := runServer(t, "server1")
         server2 := runServer(t, "server2")
 
-        target, _ := stubResolver(t, server1, server2)
+        target, authority, _ := stubResolver(t, server1, server2)
 
-        balancerBuilder := NewBuilder(t.Name(), testutil.Logger(t))
+        balancerBuilder := NewBuilder(authority, testutil.Logger(t))
         balancerBuilder.Register()
+        t.Cleanup(balancerBuilder.Deregister)
 
-        conn := dial(t, target, balancerBuilder)
+        conn := dial(t, target)
         client := testservice.NewSimpleClient(conn)
 
         var serverName string
@@ -78,12 +81,13 @@ func TestBalancer(t *testing.T) {
         server1 := runServer(t, "server1")
         server2 := runServer(t, "server2")
 
-        target, _ := stubResolver(t, server1, server2)
+        target, authority, _ := stubResolver(t, server1, server2)
 
-        balancerBuilder := NewBuilder(t.Name(), testutil.Logger(t))
+        balancerBuilder := NewBuilder(authority, testutil.Logger(t))
         balancerBuilder.Register()
+        t.Cleanup(balancerBuilder.Deregister)
 
-        conn := dial(t, target, balancerBuilder)
+        conn := dial(t, target)
         client := testservice.NewSimpleClient(conn)
 
         // Figure out which server we're talking to now, and which we should switch to.
@@ -123,10 +127,11 @@ func TestBalancer(t *testing.T) {
         server1 := runServer(t, "server1")
         server2 := runServer(t, "server2")
 
-        target, _ := stubResolver(t, server1, server2)
+        target, authority, _ := stubResolver(t, server1, server2)
 
-        balancerBuilder := NewBuilder(t.Name(), testutil.Logger(t))
+        balancerBuilder := NewBuilder(authority, testutil.Logger(t))
         balancerBuilder.Register()
+        t.Cleanup(balancerBuilder.Deregister)
 
         // Provide a custom prioritizer that causes Rebalance to choose whichever
         // server didn't get our first request.
@@ -137,7 +142,7 @@ func TestBalancer(t *testing.T) {
             })
         }
 
-        conn := dial(t, target, balancerBuilder)
+        conn := dial(t, target)
         client := testservice.NewSimpleClient(conn)
 
         // Figure out which server we're talking to now.
@@ -177,12 +182,13 @@ func TestBalancer(t *testing.T) {
         server1 := runServer(t, "server1")
         server2 := runServer(t, "server2")
 
-        target, res := stubResolver(t, server1, server2)
+        target, authority, res := stubResolver(t, server1, server2)
 
-        balancerBuilder := NewBuilder(t.Name(), testutil.Logger(t))
+        balancerBuilder := NewBuilder(authority, testutil.Logger(t))
         balancerBuilder.Register()
+        t.Cleanup(balancerBuilder.Deregister)
 
-        conn := dial(t, target, balancerBuilder)
+        conn := dial(t, target)
         client := testservice.NewSimpleClient(conn)
 
         // Figure out which server we're talking to now.
@@ -233,7 +239,7 @@ func TestBalancer(t *testing.T) {
     })
 }
 
-func stubResolver(t *testing.T, servers ...*server) (string, *manual.Resolver) {
+func stubResolver(t *testing.T, servers ...*server) (string, string, *manual.Resolver) {
     t.Helper()
 
     addresses := make([]resolver.Address, len(servers))
@@ -249,7 +255,10 @@ func stubResolver(t *testing.T, servers ...*server) (string, *manual.Resolver) {
     resolver.Register(r)
     t.Cleanup(func() { resolver.UnregisterForTesting(scheme) })
 
-    return fmt.Sprintf("%s://", scheme), r
+    authority, err := uuid.GenerateUUID()
+    require.NoError(t, err)
+
+    return fmt.Sprintf("%s://%s", scheme, authority), authority, r
 }
 
 func runServer(t *testing.T, name string) *server {
@@ -309,12 +318,12 @@ func (s *server) Something(context.Context, *testservice.Req) (*testservice.Resp
     return &testservice.Resp{ServerName: s.name}, nil
 }
 
-func dial(t *testing.T, target string, builder *Builder) *grpc.ClientConn {
+func dial(t *testing.T, target string) *grpc.ClientConn {
     conn, err := grpc.Dial(
         target,
         grpc.WithTransportCredentials(insecure.NewCredentials()),
         grpc.WithDefaultServiceConfig(
-            fmt.Sprintf(`{"loadBalancingConfig":[{"%s":{}}]}`, builder.Name()),
+            fmt.Sprintf(`{"loadBalancingConfig":[{"%s":{}}]}`, BuilderName),
         ),
     )
     t.Cleanup(func() {
diff --git a/agent/grpc-internal/balancer/registry.go b/agent/grpc-internal/balancer/registry.go
new file mode 100644
index 000000000..778b2c31c
--- /dev/null
+++ b/agent/grpc-internal/balancer/registry.go
@@ -0,0 +1,69 @@
+package balancer
+
+import (
+    "fmt"
+    "sync"
+
+    gbalancer "google.golang.org/grpc/balancer"
+)
+
+// BuilderName should be given in gRPC service configuration to enable our
+// custom balancer. It refers to this package's global registry, rather than
+// an instance of Builder, to enable us to add and remove builders at runtime,
+// specifically during tests.
+const BuilderName = "consul-internal"
+
+// gRPC's balancer.Register method is thread-unsafe because it mutates a global
+// map without holding a lock. As such, it's expected that you register custom
+// balancers once at the start of your program (e.g. a package init function).
+//
+// In production, this is fine. Agents register a single instance of our builder
+// and use it for the duration. Tests are where this becomes problematic, as we
+// spin up several agents in-memory and register/deregister a builder for each,
+// with its own agent-specific state, logger, etc.
+//
+// To avoid data races, we call gRPC's Register method once, on package init,
+// with a global registry struct that implements the Builder interface but
+// delegates the building to N instances of our Builder that are registered and
+// deregistered at runtime. We use the dial target's host (aka "authority"),
+// which is unique per agent, to pick the correct builder.
+func init() {
+    gbalancer.Register(globalRegistry)
+}
+
+var globalRegistry = &registry{
+    byAuthority: make(map[string]*Builder),
+}
+
+type registry struct {
+    mu          sync.RWMutex
+    byAuthority map[string]*Builder
+}
+
+func (r *registry) Build(cc gbalancer.ClientConn, opts gbalancer.BuildOptions) gbalancer.Balancer {
+    r.mu.RLock()
+    defer r.mu.RUnlock()
+
+    auth := opts.Target.URL.Host
+    builder, ok := r.byAuthority[auth]
+    if !ok {
+        panic(fmt.Sprintf("no gRPC balancer builder registered for authority: %q", auth))
+    }
+    return builder.Build(cc, opts)
+}
+
+func (r *registry) Name() string { return BuilderName }
+
+func (r *registry) register(auth string, builder *Builder) {
+    r.mu.Lock()
+    defer r.mu.Unlock()
+
+    r.byAuthority[auth] = builder
+}
+
+func (r *registry) deregister(auth string) {
+    r.mu.Lock()
+    defer r.mu.Unlock()
+
+    delete(r.byAuthority, auth)
+}
diff --git a/agent/grpc-internal/client.go b/agent/grpc-internal/client.go
index 36431f248..38010ac24 100644
--- a/agent/grpc-internal/client.go
+++ b/agent/grpc-internal/client.go
@@ -8,12 +8,12 @@ import (
     "time"
 
     "google.golang.org/grpc"
-    gbalancer "google.golang.org/grpc/balancer"
    "google.golang.org/grpc/credentials/insecure"
     "google.golang.org/grpc/keepalive"
 
     "github.com/armon/go-metrics"
 
+    "github.com/hashicorp/consul/agent/grpc-internal/balancer"
     agentmiddleware "github.com/hashicorp/consul/agent/grpc-middleware"
     "github.com/hashicorp/consul/agent/metadata"
     "github.com/hashicorp/consul/agent/pool"
@@ -22,8 +22,8 @@ import (
 
 // grpcServiceConfig is provided as the default service config.
 //
-// It configures our custom balancer (via the %s directive to interpolate its
-// name) which will automatically switch servers on error.
+// It configures our custom balancer which will automatically switch servers
+// on error.
 //
 // It also enables gRPC's built-in automatic retries for RESOURCE_EXHAUSTED
 // errors *only*, as this is the status code servers will return for an
@@ -41,7 +41,7 @@ import (
 // but we're working on generating them automatically from the protobuf files
 const grpcServiceConfig = `
 {
-    "loadBalancingConfig": [{"%s":{}}],
+    "loadBalancingConfig": [{"` + balancer.BuilderName + `":{}}],
     "methodConfig": [
         {
             "name": [{}],
@@ -131,12 +131,11 @@ const grpcServiceConfig = `
 
 // ClientConnPool creates and stores a connection for each datacenter.
 type ClientConnPool struct {
-    dialer          dialer
-    servers         ServerLocator
-    gwResolverDep   gatewayResolverDep
-    conns           map[string]*grpc.ClientConn
-    connsLock       sync.Mutex
-    balancerBuilder gbalancer.Builder
+    dialer        dialer
+    servers       ServerLocator
+    gwResolverDep gatewayResolverDep
+    conns         map[string]*grpc.ClientConn
+    connsLock     sync.Mutex
 }
 
 type ServerLocator interface {
@@ -198,21 +197,14 @@ type ClientConnPoolConfig struct {
     // DialingFromDatacenter is the datacenter of the consul agent using this
     // pool.
     DialingFromDatacenter string
-
-    // BalancerBuilder is a builder for the gRPC balancer that will be used.
-    BalancerBuilder gbalancer.Builder
 }
 
 // NewClientConnPool create new GRPC client pool to connect to servers using
 // GRPC over RPC.
 func NewClientConnPool(cfg ClientConnPoolConfig) *ClientConnPool {
-    if cfg.BalancerBuilder == nil {
-        panic("missing required BalancerBuilder")
-    }
     c := &ClientConnPool{
-        servers:         cfg.Servers,
-        conns:           make(map[string]*grpc.ClientConn),
-        balancerBuilder: cfg.BalancerBuilder,
+        servers: cfg.Servers,
+        conns:   make(map[string]*grpc.ClientConn),
     }
     c.dialer = newDialer(cfg, &c.gwResolverDep)
     return c
@@ -251,9 +243,7 @@ func (c *ClientConnPool) dial(datacenter string, serverType string) (*grpc.Clien
         grpc.WithTransportCredentials(insecure.NewCredentials()),
         grpc.WithContextDialer(c.dialer),
         grpc.WithStatsHandler(agentmiddleware.NewStatsHandler(metrics.Default(), metricsLabels)),
-        grpc.WithDefaultServiceConfig(
-            fmt.Sprintf(grpcServiceConfig, c.balancerBuilder.Name()),
-        ),
+        grpc.WithDefaultServiceConfig(grpcServiceConfig),
         // Keep alive parameters are based on the same default ones we used for
         // Yamux. These are somewhat arbitrary but we did observe in scale testing
         // that the gRPC defaults (servers send keepalives only every 2 hours,
diff --git a/agent/grpc-internal/client_test.go b/agent/grpc-internal/client_test.go
index ebd0601ad..65e08feac 100644
--- a/agent/grpc-internal/client_test.go
+++ b/agent/grpc-internal/client_test.go
@@ -13,7 +13,6 @@ import (
     "github.com/hashicorp/go-hclog"
     "github.com/stretchr/testify/assert"
     "github.com/stretchr/testify/require"
-    gbalancer "google.golang.org/grpc/balancer"
 
     "github.com/hashicorp/consul/agent/grpc-internal/balancer"
     "github.com/hashicorp/consul/agent/grpc-internal/resolver"
@@ -143,7 +142,8 @@ func TestNewDialer_IntegrationWithTLSEnabledHandler(t *testing.T) {
     // if this test is failing because of expired certificates
     // use the procedure in test/CA-GENERATION.md
     res := resolver.NewServerResolverBuilder(newConfig(t))
-    registerWithGRPC(t, res)
+    bb := balancer.NewBuilder(res.Authority(), testutil.Logger(t))
+    registerWithGRPC(t, res, bb)
 
     tlsConf, err := tlsutil.NewConfigurator(tlsutil.Config{
         InternalRPC: tlsutil.ProtocolConfig{
@@ -168,7 +168,6 @@ func TestNewDialer_IntegrationWithTLSEnabledHandler(t *testing.T) {
         UseTLSForDC:           tlsConf.UseTLS,
         DialingFromServer:     true,
         DialingFromDatacenter: "dc1",
-        BalancerBuilder:       balancerBuilder(t, res.Authority()),
     })
 
     conn, err := pool.ClientConn("dc1")
@@ -191,7 +190,8 @@ func TestNewDialer_IntegrationWithTLSEnabledHandler_viaMeshGateway(t *testing.T)
     gwAddr := ipaddr.FormatAddressPort("127.0.0.1", freeport.GetOne(t))
 
     res := resolver.NewServerResolverBuilder(newConfig(t))
-    registerWithGRPC(t, res)
+    bb := balancer.NewBuilder(res.Authority(), testutil.Logger(t))
+    registerWithGRPC(t, res, bb)
 
     tlsConf, err := tlsutil.NewConfigurator(tlsutil.Config{
         InternalRPC: tlsutil.ProtocolConfig{
@@ -244,7 +244,6 @@ func TestNewDialer_IntegrationWithTLSEnabledHandler_viaMeshGateway(t *testing.T)
         UseTLSForDC:           tlsConf.UseTLS,
         DialingFromServer:     true,
         DialingFromDatacenter: "dc2",
-        BalancerBuilder:       balancerBuilder(t, res.Authority()),
     })
     pool.SetGatewayResolver(func(addr string) string {
         return gwAddr
@@ -267,13 +266,13 @@ func TestNewDialer_IntegrationWithTLSEnabledHandler_viaMeshGateway(t *testing.T)
 func TestClientConnPool_IntegrationWithGRPCResolver_Failover(t *testing.T) {
     count := 4
     res := resolver.NewServerResolverBuilder(newConfig(t))
-    registerWithGRPC(t, res)
+    bb := balancer.NewBuilder(res.Authority(), testutil.Logger(t))
+    registerWithGRPC(t, res, bb)
     pool := NewClientConnPool(ClientConnPoolConfig{
         Servers:               res,
         UseTLSForDC:           useTLSForDcAlwaysTrue,
         DialingFromServer:     true,
         DialingFromDatacenter: "dc1",
-        BalancerBuilder:       balancerBuilder(t, res.Authority()),
     })
 
     for i := 0; i < count; i++ {
@@ -303,13 +302,13 @@ func TestClientConnPool_ForwardToLeader_Failover(t *testing.T) {
     count := 3
     res := resolver.NewServerResolverBuilder(newConfig(t))
-    registerWithGRPC(t, res)
+    bb := balancer.NewBuilder(res.Authority(), testutil.Logger(t))
+    registerWithGRPC(t, res, bb)
     pool := NewClientConnPool(ClientConnPoolConfig{
         Servers:               res,
         UseTLSForDC:           useTLSForDcAlwaysTrue,
         DialingFromServer:     true,
         DialingFromDatacenter: "dc1",
-        BalancerBuilder:       balancerBuilder(t, res.Authority()),
     })
 
     var servers []testServer
@@ -356,13 +355,13 @@ func TestClientConnPool_IntegrationWithGRPCResolver_MultiDC(t *testing.T) {
     dcs := []string{"dc1", "dc2", "dc3"}
 
     res := resolver.NewServerResolverBuilder(newConfig(t))
-    registerWithGRPC(t, res)
+    bb := balancer.NewBuilder(res.Authority(), testutil.Logger(t))
+    registerWithGRPC(t, res, bb)
     pool := NewClientConnPool(ClientConnPoolConfig{
         Servers:               res,
         UseTLSForDC:           useTLSForDcAlwaysTrue,
         DialingFromServer:     true,
         DialingFromDatacenter: "dc1",
-        BalancerBuilder:       balancerBuilder(t, res.Authority()),
     })
 
     for _, dc := range dcs {
@@ -386,18 +385,11 @@ func TestClientConnPool_IntegrationWithGRPCResolver_MultiDC(t *testing.T) {
     }
 }
 
-func registerWithGRPC(t *testing.T, b *resolver.ServerResolverBuilder) {
-    resolver.Register(b)
+func registerWithGRPC(t *testing.T, rb *resolver.ServerResolverBuilder, bb *balancer.Builder) {
+    resolver.Register(rb)
+    bb.Register()
     t.Cleanup(func() {
-        resolver.Deregister(b.Authority())
+        resolver.Deregister(rb.Authority())
+        bb.Deregister()
     })
 }
-
-func balancerBuilder(t *testing.T, name string) gbalancer.Builder {
-    t.Helper()
-
-    bb := balancer.NewBuilder(name, testutil.Logger(t))
-    bb.Register()
-
-    return bb
-}
diff --git a/agent/grpc-internal/handler_test.go b/agent/grpc-internal/handler_test.go
index 96f6f036e..080aaa938 100644
--- a/agent/grpc-internal/handler_test.go
+++ b/agent/grpc-internal/handler_test.go
@@ -6,6 +6,7 @@ import (
     "testing"
     "time"
 
+    "github.com/hashicorp/consul/sdk/testutil"
     "github.com/hashicorp/consul/types"
 
     "github.com/hashicorp/go-hclog"
@@ -13,6 +14,7 @@ import (
     "google.golang.org/grpc/codes"
     "google.golang.org/grpc/status"
 
+    "github.com/hashicorp/consul/agent/grpc-internal/balancer"
     "github.com/hashicorp/consul/agent/grpc-internal/resolver"
     "github.com/hashicorp/consul/agent/grpc-middleware/testutil/testservice"
 )
@@ -27,7 +29,8 @@ func TestHandler_PanicRecoveryInterceptor(t *testing.T) {
     })
 
     res := resolver.NewServerResolverBuilder(newConfig(t))
-    registerWithGRPC(t, res)
+    bb := balancer.NewBuilder(res.Authority(), testutil.Logger(t))
+    registerWithGRPC(t, res, bb)
 
     srv := newPanicTestServer(t, logger, "server-1", "dc1", nil)
     res.AddServer(types.AreaWAN, srv.Metadata())
@@ -38,7 +41,6 @@ func TestHandler_PanicRecoveryInterceptor(t *testing.T) {
         UseTLSForDC:           useTLSForDcAlwaysTrue,
         DialingFromServer:     true,
         DialingFromDatacenter: "dc1",
-        BalancerBuilder:       balancerBuilder(t, res.Authority()),
     })
 
     conn, err := pool.ClientConn("dc1")
diff --git a/agent/rpc/peering/service_test.go b/agent/rpc/peering/service_test.go
index ced1e286c..1699cb5ac 100644
--- a/agent/rpc/peering/service_test.go
+++ b/agent/rpc/peering/service_test.go
@@ -1693,8 +1693,9 @@ func newDefaultDeps(t *testing.T, c *consul.Config) consul.Deps {
         Datacenter: c.Datacenter,
     }
 
-    balancerBuilder := balancer.NewBuilder(t.Name(), testutil.Logger(t))
+    balancerBuilder := balancer.NewBuilder(builder.Authority(), testutil.Logger(t))
     balancerBuilder.Register()
+    t.Cleanup(balancerBuilder.Deregister)
 
     return consul.Deps{
         EventPublisher: stream.NewEventPublisher(10 * time.Second),
@@ -1709,7 +1710,6 @@ func newDefaultDeps(t *testing.T, c *consul.Config) consul.Deps {
             UseTLSForDC:           tls.UseTLS,
             DialingFromServer:     true,
             DialingFromDatacenter: c.Datacenter,
-            BalancerBuilder:       balancerBuilder,
         }),
         LeaderForwarder:        builder,
         EnterpriseDeps:         newDefaultDepsEnterprise(t, logger, c),
diff --git a/agent/setup.go b/agent/setup.go
index 8dc5e5e18..4600f4016 100644
--- a/agent/setup.go
+++ b/agent/setup.go
@@ -53,6 +53,8 @@ type BaseDeps struct {
     Cache        *cache.Cache
     ViewStore    *submatview.Store
     WatchedFiles []string
+
+    deregisterBalancer, deregisterResolver func()
 }
 
 type ConfigLoader func(source config.Source) (config.LoadResult, error)
@@ -122,14 +124,16 @@ func NewBaseDeps(configLoader ConfigLoader, logOut io.Writer, providedLogger hcl
         Authority: cfg.Datacenter + "." + string(cfg.NodeID),
     })
     resolver.Register(resolverBuilder)
+    d.deregisterResolver = func() {
+        resolver.Deregister(resolverBuilder.Authority())
+    }
 
     balancerBuilder := balancer.NewBuilder(
-        // Balancer name doesn't really matter, we set it to the resolver authority
-        // to keep it unique for tests.
         resolverBuilder.Authority(),
         d.Logger.Named("grpc.balancer"),
     )
     balancerBuilder.Register()
+    d.deregisterBalancer = balancerBuilder.Deregister
 
     d.GRPCConnPool = grpcInt.NewClientConnPool(grpcInt.ClientConnPoolConfig{
         Servers:               resolverBuilder,
@@ -139,7 +143,6 @@ func NewBaseDeps(configLoader ConfigLoader, logOut io.Writer, providedLogger hcl
         UseTLSForDC:           d.TLSConfigurator.UseTLS,
         DialingFromServer:     cfg.ServerMode,
         DialingFromDatacenter: cfg.Datacenter,
-        BalancerBuilder:       balancerBuilder,
     })
     d.LeaderForwarder = resolverBuilder
 
@@ -189,6 +192,20 @@ func NewBaseDeps(configLoader ConfigLoader, logOut io.Writer, providedLogger hcl
     return d, nil
 }
 
+// Close cleans up any state and goroutines associated with bd's members not
+// handled by something else (e.g. the agent stop channel).
+func (bd BaseDeps) Close() {
+    bd.AutoConfig.Stop()
+    bd.MetricsConfig.Cancel()
+
+    if fn := bd.deregisterBalancer; fn != nil {
+        fn()
+    }
+    if fn := bd.deregisterResolver; fn != nil {
+        fn()
+    }
+}
+
 // grpcLogInitOnce because the test suite will call NewBaseDeps in many tests and
 // causes data races when it is re-initialized.
 var grpcLogInitOnce sync.Once
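
Reviewer note: the heart of the new registry.go is "register with the framework once, delegate per key at runtime". Below is a minimal, self-contained Go sketch of that pattern; every name in it is illustrative rather than taken from this change, and a plain map stands in for gRPC's unsynchronized global builder registry.

package main

import (
    "fmt"
    "sync"
)

// frameworkBuilders stands in for gRPC's global builder map: like the map
// behind gbalancer.Register, it is only safe to mutate at init time.
var frameworkBuilders = map[string]func(key string) string{}

// registry is the single value handed to the framework at init time. Per-key
// builders come and go under its lock, so parallel tests never mutate the
// framework's global map after startup.
type registry struct {
    mu    sync.RWMutex
    byKey map[string]func() string
}

func (r *registry) build(key string) string {
    r.mu.RLock()
    defer r.mu.RUnlock()
    b, ok := r.byKey[key]
    if !ok {
        panic(fmt.Sprintf("no builder registered for key %q", key))
    }
    return b()
}

func (r *registry) register(key string, b func() string) {
    r.mu.Lock()
    defer r.mu.Unlock()
    r.byKey[key] = b
}

func (r *registry) deregister(key string) {
    r.mu.Lock()
    defer r.mu.Unlock()
    delete(r.byKey, key)
}

var global = &registry{byKey: make(map[string]func() string)}

func init() {
    // The one and only framework-level registration, performed while the
    // process is still single-threaded.
    frameworkBuilders["consul-internal"] = global.build
}

func main() {
    global.register("agent-1", func() string { return "balancer for agent-1" })
    defer global.deregister("agent-1")

    // The framework looks up "consul-internal"; the registry then picks the
    // per-agent builder by key, mirroring opts.Target.URL.Host above.
    fmt.Println(frameworkBuilders["consul-internal"]("agent-1"))
}

The framework-facing registration happens exactly once, so the data race is gone; everything that varies per agent lives behind the registry's lock, which is what makes spinning up many in-memory agents in parallel tests safe.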
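For completeness, a condensed sketch of how the pieces introduced in this diff are meant to line up when dialing. This is hedged: the scheme name "sketch" and the server address are placeholders, while balancer.NewBuilder, Register, Deregister, and BuilderName are the APIs added above.

package main

import (
    "fmt"
    "log"

    "google.golang.org/grpc"
    "google.golang.org/grpc/credentials/insecure"
    "google.golang.org/grpc/resolver"
    "google.golang.org/grpc/resolver/manual"

    "github.com/hashicorp/go-hclog"
    "github.com/hashicorp/go-uuid"

    "github.com/hashicorp/consul/agent/grpc-internal/balancer"
)

func main() {
    // A unique authority isolates this Builder from any other registered in
    // the same process (the reason the tests above switched to UUIDs).
    authority, err := uuid.GenerateUUID()
    if err != nil {
        log.Fatal(err)
    }

    // Stub resolver: hands gRPC a fixed address for our placeholder scheme.
    rb := manual.NewBuilderWithScheme("sketch")
    rb.InitialState(resolver.State{
        Addresses: []resolver.Address{{Addr: "127.0.0.1:8300"}},
    })
    resolver.Register(rb)

    // Per-authority Builder: the globally registered "consul-internal"
    // registry delegates to it when the dial target's host matches.
    bb := balancer.NewBuilder(authority, hclog.NewNullLogger())
    bb.Register()
    defer bb.Deregister()

    // The service config selects the balancer by its registry-level name;
    // the authority portion of the target selects this specific Builder.
    conn, err := grpc.Dial(
        fmt.Sprintf("sketch://%s/", authority),
        grpc.WithTransportCredentials(insecure.NewCredentials()),
        grpc.WithDefaultServiceConfig(
            fmt.Sprintf(`{"loadBalancingConfig":[{"%s":{}}]}`, balancer.BuilderName),
        ),
    )
    if err != nil {
        log.Fatal(err)
    }
    defer conn.Close()
}

Since grpc.Dial is non-blocking, the registry's Build method, and with it the per-authority lookup, only runs once the first connection attempt is made.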