diff --git a/agent/consul/client_test.go b/agent/consul/client_test.go
index ea0250454..6fe726681 100644
--- a/agent/consul/client_test.go
+++ b/agent/consul/client_test.go
@@ -480,7 +480,7 @@ func newDefaultDeps(t *testing.T, c *Config) Deps {
 	tls, err := tlsutil.NewConfigurator(c.ToTLSUtilConfig(), logger)
 	require.NoError(t, err, "failed to create tls configuration")

-	r := router.NewRouter(logger, c.Datacenter, fmt.Sprintf("%s.%s", c.NodeName, c.Datacenter))
+	r := router.NewRouter(logger, c.Datacenter, fmt.Sprintf("%s.%s", c.NodeName, c.Datacenter), nil)

 	connPool := &pool.ConnPool{
 		Server: false,
diff --git a/agent/grpc/client.go b/agent/grpc/client.go
index d2f9f32b2..e65e95a13 100644
--- a/agent/grpc/client.go
+++ b/agent/grpc/client.go
@@ -10,61 +10,71 @@ import (
 	"github.com/hashicorp/consul/agent/metadata"
 	"github.com/hashicorp/consul/agent/pool"
-	"github.com/hashicorp/consul/tlsutil"
 )

-type ServerProvider interface {
-	Servers() []*metadata.Server
+// ClientConnPool creates and stores a connection for each datacenter.
+type ClientConnPool struct {
+	dialer    dialer
+	servers   ServerLocator
+	conns     map[string]*grpc.ClientConn
+	connsLock sync.Mutex
 }

-type Client struct {
-	serverProvider  ServerProvider
-	tlsConfigurator *tlsutil.Configurator
-	grpcConns       map[string]*grpc.ClientConn
-	grpcConnLock    sync.Mutex
+type ServerLocator interface {
+	// ServerForAddr is used to look up server metadata from an address.
+	ServerForAddr(addr string) (*metadata.Server, error)
+	// Scheme returns the url scheme to use to dial the server. This is primarily
+	// needed for testing multiple agents in parallel, because gRPC requires the
+	// resolver to be registered globally.
+	Scheme() string
 }

-func NewGRPCClient(serverProvider ServerProvider, tlsConfigurator *tlsutil.Configurator) *Client {
-	// Note we don't actually use the logger anywhere yet but I guess it was added
-	// for future compatibility...
-	return &Client{
-		serverProvider:  serverProvider,
-		tlsConfigurator: tlsConfigurator,
-		grpcConns:       make(map[string]*grpc.ClientConn),
+// TLSWrapper wraps a non-TLS connection and returns a connection with TLS
+// enabled.
+type TLSWrapper func(dc string, conn net.Conn) (net.Conn, error)
+
+type dialer func(context.Context, string) (net.Conn, error)
+
+func NewClientConnPool(servers ServerLocator, tls TLSWrapper) *ClientConnPool {
+	return &ClientConnPool{
+		dialer:  newDialer(servers, tls),
+		servers: servers,
+		conns:   make(map[string]*grpc.ClientConn),
 	}
 }

-func (c *Client) GRPCConn(datacenter string) (*grpc.ClientConn, error) {
-	c.grpcConnLock.Lock()
-	defer c.grpcConnLock.Unlock()
+// ClientConn returns a grpc.ClientConn for the datacenter. If there are no
+// existing connections in the pool, a new one will be created, stored in the pool,
+// then returned.
+func (c *ClientConnPool) ClientConn(datacenter string) (*grpc.ClientConn, error) {
+	c.connsLock.Lock()
+	defer c.connsLock.Unlock()

-	// If there's an existing ClientConn for the given DC, return it.
-	if conn, ok := c.grpcConns[datacenter]; ok {
+	if conn, ok := c.conns[datacenter]; ok {
 		return conn, nil
 	}

-	dialer := newDialer(c.serverProvider, c.tlsConfigurator.OutgoingRPCWrapper())
-	conn, err := grpc.Dial(fmt.Sprintf("%s:///server.%s", scheme, datacenter),
+	conn, err := grpc.Dial(fmt.Sprintf("%s:///server.%s", c.servers.Scheme(), datacenter),
 		// use WithInsecure mode here because we handle the TLS wrapping in the
 		// custom dialer based on logic around whether the server has TLS enabled.
 		grpc.WithInsecure(),
-		grpc.WithContextDialer(dialer),
+		grpc.WithContextDialer(c.dialer),
 		grpc.WithDisableRetry(),
-		// TODO: previously this handler was shared with the Handler. Is that necessary?
+		// TODO: previously this statsHandler was shared with the Handler. Is that necessary?
 		grpc.WithStatsHandler(&statsHandler{}),
+		// nolint:staticcheck // there is no other supported alternative to WithBalancerName
 		grpc.WithBalancerName("pick_first"))
 	if err != nil {
 		return nil, err
 	}

-	c.grpcConns[datacenter] = conn
-
+	c.conns[datacenter] = conn
 	return conn, nil
 }

 // newDialer returns a gRPC dialer function that conditionally wraps the connection
-// with TLS depending on the given useTLS value.
-func newDialer(serverProvider ServerProvider, wrapper tlsutil.DCWrapper) func(context.Context, string) (net.Conn, error) {
+// with TLS based on the Server.useTLS value.
+func newDialer(servers ServerLocator, wrapper TLSWrapper) func(context.Context, string) (net.Conn, error) {
 	return func(ctx context.Context, addr string) (net.Conn, error) {
 		d := net.Dialer{}
 		conn, err := d.DialContext(ctx, "tcp", addr)
@@ -72,17 +82,10 @@ func newDialer(serverProvider ServerProvider, wrapper tlsutil.DCWrapper) func(co
 			return nil, err
 		}

-		// Check if TLS is enabled for the server.
-		var found bool
-		var server *metadata.Server
-		for _, s := range serverProvider.Servers() {
-			if s.Addr.String() == addr {
-				found = true
-				server = s
-			}
-		}
-		if !found {
-			return nil, fmt.Errorf("could not find Consul server for address %q", addr)
+		server, err := servers.ServerForAddr(addr)
+		if err != nil {
+			// TODO: should conn be closed in this case, as it is in other error cases?
+			return nil, err
 		}

 		if server.UseTLS {
@@ -107,6 +110,7 @@ func newDialer(serverProvider ServerProvider, wrapper tlsutil.DCWrapper) func(co

 		_, err = conn.Write([]byte{pool.RPCGRPC})
 		if err != nil {
+			// TODO: should conn be closed in this case, as it is in other error cases?
 			return nil, err
 		}
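Reviewer note, not part of the patch: a minimal sketch of how these new pieces are meant to compose, assuming the resolver package introduced in the resolver.go changes below, and assuming tlsutil.Configurator.OutgoingRPCWrapper stays signature-compatible with TLSWrapper. The dialDC function and its arguments are illustrative only.

```go
package example

import (
	"context"

	agentgrpc "github.com/hashicorp/consul/agent/grpc"
	"github.com/hashicorp/consul/agent/grpc/resolver"
	"github.com/hashicorp/consul/tlsutil"
)

// dialDC wires a ServerResolverBuilder to a ClientConnPool and dials one
// datacenter. In the real agent this wiring lives in setup code.
func dialDC(nodes resolver.Nodes, tlsConfigurator *tlsutil.Configurator) error {
	builder := resolver.NewServerResolverBuilder(resolver.Config{Datacenter: "dc1"}, nodes)
	resolver.RegisterWithGRPC(builder)
	go builder.Run(context.Background())

	// tlsutil.DCWrapper has the same underlying type as TLSWrapper, so the
	// conversion below is valid (assumption based on this diff).
	pool := agentgrpc.NewClientConnPool(builder, agentgrpc.TLSWrapper(tlsConfigurator.OutgoingRPCWrapper()))

	// Dials "consul:///server.dc1"; a second call for "dc1" returns the
	// cached *grpc.ClientConn from the pool.
	conn, err := pool.ClientConn("dc1")
	if err != nil {
		return err
	}
	_ = conn
	return nil
}
```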
diff --git a/agent/grpc/resolver/resolver.go b/agent/grpc/resolver/resolver.go
index fa52af5d2..82e814ae0 100644
--- a/agent/grpc/resolver/resolver.go
+++ b/agent/grpc/resolver/resolver.go
@@ -1,7 +1,8 @@
-package grpc
+package resolver

 import (
 	"context"
+	"fmt"
 	"math/rand"
 	"strings"
 	"sync"
@@ -12,17 +13,21 @@ import (
 	"google.golang.org/grpc/resolver"
 )

-//var registerLock sync.Mutex
-//
-//// registerResolverBuilder registers our custom grpc resolver with the given scheme.
-//func registerResolverBuilder(datacenter string) *ServerResolverBuilder {
-//	registerLock.Lock()
-//	defer registerLock.Unlock()
-//	grpcResolverBuilder := NewServerResolverBuilder(datacenter)
-//	resolver.Register(grpcResolverBuilder)
-//	return grpcResolverBuilder
-//}
+var registerLock sync.Mutex
+
+// RegisterWithGRPC registers the ServerResolverBuilder as a grpc/resolver.
+// It synchronizes registrations with a lock because grpc/resolver.Register
+// expects all registration to happen at init and does not allow concurrent
+// registration; the lock makes registration safe for parallel testing.
+func RegisterWithGRPC(b *ServerResolverBuilder) {
+	registerLock.Lock()
+	defer registerLock.Unlock()
+	resolver.Register(b)
+}
+
+// Nodes provides a count of the number of nodes in the cluster. It is very
+// likely implemented by serf to return the number of LAN members.
 type Nodes interface {
 	NumNodes() int
 }
@@ -30,27 +35,52 @@ type Nodes interface {

 // ServerResolverBuilder tracks the current server list and keeps any
 // ServerResolvers updated when changes occur.
 type ServerResolverBuilder struct {
-	// Allow overriding the scheme to support parallel tests, since
-	// the resolver builder is registered globally.
-	scheme     string
+	// datacenter of the local agent.
 	datacenter string
-	servers    map[string]*metadata.Server
-	resolvers  map[resolver.ClientConn]*ServerResolver
-	nodes      Nodes
-	lock       sync.Mutex
+	// scheme used to query the server. Defaults to consul. Used to support
+	// parallel testing because gRPC registers resolvers globally.
+	scheme string
+	// servers is an index of Servers by Server.ID.
+	servers map[string]*metadata.Server
+	// resolvers is an index of connections to the serverResolver which manages
+	// addresses of servers for that connection.
+	resolvers map[resolver.ClientConn]*serverResolver
+	// nodes provides the number of nodes in the cluster.
+	nodes Nodes
+	// lock guards servers and resolvers.
+	lock sync.RWMutex
 }

-func NewServerResolverBuilder(nodes Nodes, datacenter string) *ServerResolverBuilder {
+var _ resolver.Builder = (*ServerResolverBuilder)(nil)
+
+type Config struct {
+	// Datacenter of the local agent.
+	Datacenter string
+	// Scheme used to connect to the server. Defaults to consul.
+	Scheme string
+}
+
+func NewServerResolverBuilder(cfg Config, nodes Nodes) *ServerResolverBuilder {
+	if cfg.Scheme == "" {
+		cfg.Scheme = "consul"
+	}
 	return &ServerResolverBuilder{
-		datacenter: datacenter,
+		scheme:     cfg.Scheme,
+		datacenter: cfg.Datacenter,
 		nodes:      nodes,
 		servers:    make(map[string]*metadata.Server),
-		resolvers:  make(map[resolver.ClientConn]*ServerResolver),
+		resolvers:  make(map[resolver.ClientConn]*serverResolver),
 	}
 }

-// Run periodically reshuffles the order of server addresses
-// within the resolvers to ensure the load is balanced across servers.
+// Run periodically reshuffles the order of server addresses within the
+// resolvers to ensure the load is balanced across servers.
+//
+// TODO: this looks very similar to agent/router.Manager.Start, which is the
+// only other caller of ComputeRebalanceTimer. Are the values passed to these
+// two functions different enough that we need separate goroutines to rebalance?
+// Or could a single goroutine handle the timers and call both rebalance
+// functions?
 func (s *ServerResolverBuilder) Run(ctx context.Context) {
 	// Compute the rebalance timer based on the number of local servers and nodes.
 	rebalanceDuration := router.ComputeRebalanceTimer(s.serversInDC(s.datacenter), s.nodes.NumNodes())
@@ -73,13 +103,13 @@ func (s *ServerResolverBuilder) Run(ctx context.Context) {

 // rebalanceResolvers shuffles the server list for resolvers in all datacenters.
 func (s *ServerResolverBuilder) rebalanceResolvers() {
-	s.lock.Lock()
-	defer s.lock.Unlock()
+	s.lock.RLock()
+	defer s.lock.RUnlock()

 	for _, resolver := range s.resolvers {
 		// Shuffle the list of addresses using the last list given to the resolver.
 		resolver.addrLock.Lock()
-		addrs := resolver.lastAddrs
+		addrs := resolver.addrs
 		rand.Shuffle(len(addrs), func(i, j int) {
 			addrs[i], addrs[j] = addrs[j], addrs[i]
 		})
@@ -91,8 +121,8 @@ func (s *ServerResolverBuilder) rebalanceResolvers() {

 // serversInDC returns the number of servers in the given datacenter.
 func (s *ServerResolverBuilder) serversInDC(dc string) int {
-	s.lock.Lock()
-	defer s.lock.Unlock()
+	s.lock.RLock()
+	defer s.lock.RUnlock()

 	var serverCount int
 	for _, server := range s.servers {
@@ -104,52 +134,49 @@ func (s *ServerResolverBuilder) serversInDC(dc string) int {
 	return serverCount
 }

-// Servers returns metadata for all currently known servers. This is used
-// by grpc.ClientConn through our custom dialer.
-func (s *ServerResolverBuilder) Servers() []*metadata.Server {
-	s.lock.Lock()
-	defer s.lock.Unlock()
+// ServerForAddr returns server metadata for a server with the specified address.
+func (s *ServerResolverBuilder) ServerForAddr(addr string) (*metadata.Server, error) {
+	s.lock.RLock()
+	defer s.lock.RUnlock()

-	servers := make([]*metadata.Server, 0, len(s.servers))
 	for _, server := range s.servers {
-		servers = append(servers, server)
+		if server.Addr.String() == addr {
+			return server, nil
+		}
 	}
-	return servers
+	return nil, fmt.Errorf("failed to find Consul server for address %q", addr)
 }

-// Build returns a new ServerResolver for the given ClientConn. The resolver
+// Build returns a new serverResolver for the given ClientConn. The resolver
 // will keep the ClientConn's state updated based on updates from Serf.
 func (s *ServerResolverBuilder) Build(target resolver.Target, cc resolver.ClientConn, _ resolver.BuildOption) (resolver.Resolver, error) {
 	s.lock.Lock()
 	defer s.lock.Unlock()

-	// If there's already a resolver for this datacenter, return it.
-	datacenter := strings.TrimPrefix(target.Endpoint, "server.")
+	// If there's already a resolver for this connection, return it.
+	// TODO(streaming): how would this happen since we already cache connections in ClientConnPool?
 	if resolver, ok := s.resolvers[cc]; ok {
 		return resolver, nil
 	}

 	// Make a new resolver for the dc and add it to the list of active ones.
-	resolver := &ServerResolver{
+	datacenter := strings.TrimPrefix(target.Endpoint, "server.")
+	resolver := &serverResolver{
 		datacenter: datacenter,
 		clientConn: cc,
+		close: func() {
+			s.lock.Lock()
+			defer s.lock.Unlock()
+			delete(s.resolvers, cc)
+		},
 	}
 	resolver.updateAddrs(s.getDCAddrs(datacenter))
-	resolver.closeCallback = func() {
-		s.lock.Lock()
-		defer s.lock.Unlock()
-		delete(s.resolvers, cc)
-	}

 	s.resolvers[cc] = resolver
-
 	return resolver, nil
 }

-// scheme is the URL scheme used to dial the Consul Server rpc endpoint.
-var scheme = "consul"
-
-func (s *ServerResolverBuilder) Scheme() string { return scheme }
+func (s *ServerResolverBuilder) Scheme() string { return s.scheme }

 // AddServer updates the resolvers' states to include the new server's address.
 func (s *ServerResolverBuilder) AddServer(server *metadata.Server) {
@@ -182,7 +209,7 @@ func (s *ServerResolverBuilder) RemoveServer(server *metadata.Server) {
 }

 // getDCAddrs returns a list of the server addresses for the given datacenter.
-// This method assumes the lock is held.
+// This method requires that the lock is held for reads.
 func (s *ServerResolverBuilder) getDCAddrs(dc string) []resolver.Address {
 	var addrs []resolver.Address
 	for _, server := range s.servers {
@@ -199,28 +226,39 @@ func (s *ServerResolverBuilder) getDCAddrs(dc string) []resolver.Address {
 	return addrs
 }

-// ServerResolver is a grpc Resolver that will keep a grpc.ClientConn up to date
+// serverResolver is a grpc Resolver that will keep a grpc.ClientConn up to date
 // on the list of server addresses to use.
-type ServerResolver struct {
-	datacenter    string
-	clientConn    resolver.ClientConn
-	closeCallback func()
+type serverResolver struct {
+	// datacenter that can be reached by the clientConn. Used by ServerResolverBuilder
+	// to filter resolvers for those in a specific datacenter.
+	datacenter string

-	lastAddrs []resolver.Address
-	addrLock  sync.Mutex
+	// clientConn that this resolver is providing addresses for.
+	clientConn resolver.ClientConn
+
+	// close is used by ServerResolverBuilder to remove this resolver from the
+	// index of resolvers. It is called by grpc when the connection is closed.
+	close func()
+
+	// addrs stores the list of addresses passed to updateAddrs, so that they
+	// can be rebalanced periodically by ServerResolverBuilder.
+	addrs    []resolver.Address
+	addrLock sync.Mutex
 }

-// updateAddrs updates this ServerResolver's ClientConn to use the given set of
+var _ resolver.Resolver = (*serverResolver)(nil)
+
+// updateAddrs updates this serverResolver's ClientConn to use the given set of
 // addrs.
-func (r *ServerResolver) updateAddrs(addrs []resolver.Address) {
+func (r *serverResolver) updateAddrs(addrs []resolver.Address) {
 	r.addrLock.Lock()
 	defer r.addrLock.Unlock()
 	r.updateAddrsLocked(addrs)
 }

-// updateAddrsLocked updates this ServerResolver's ClientConn to use the given
-// set of addrs. addrLock must be held by calleer.
-func (r *ServerResolver) updateAddrsLocked(addrs []resolver.Address) {
+// updateAddrsLocked updates this serverResolver's ClientConn to use the given
+// set of addrs. addrLock must be held by caller.
+func (r *serverResolver) updateAddrsLocked(addrs []resolver.Address) {
 	// Only pass the first address initially, which will cause the
 	// balancer to spin down the connection for its previous first address
 	// if it is different. If we don't do this, it will keep using the old
@@ -236,12 +274,12 @@
 	// for failover.
 	r.clientConn.UpdateState(resolver.State{Addresses: addrs})

-	r.lastAddrs = addrs
+	r.addrs = addrs
 }

-func (r *ServerResolver) Close() {
-	r.closeCallback()
+func (r *serverResolver) Close() {
+	r.close()
 }

-// Unneeded since we only update the ClientConn when our server list changes.
-func (*ServerResolver) ResolveNow(_ resolver.ResolveNowOption) {}
+// ResolveNow is not used, because the ClientConn is only updated when the
+// server list changes.
+func (*serverResolver) ResolveNow(_ resolver.ResolveNowOption) {}
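Reviewer note, hypothetical test not from this patch, illustrating the ServerLocator contract that the dialer in client.go depends on; fakeNodes is an assumed stub.

```go
package resolver_test

import (
	"net"
	"testing"

	"github.com/hashicorp/consul/agent/grpc/resolver"
	"github.com/hashicorp/consul/agent/metadata"
)

// fakeNodes is a stub Nodes implementation for tests.
type fakeNodes int

func (f fakeNodes) NumNodes() int { return int(f) }

func TestServerForAddr(t *testing.T) {
	builder := resolver.NewServerResolverBuilder(resolver.Config{Datacenter: "dc1"}, fakeNodes(3))

	builder.AddServer(&metadata.Server{
		ID:         "server-1",
		Datacenter: "dc1",
		Addr:       &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 8300},
	})

	// Known addresses resolve to the server's metadata.
	srv, err := builder.ServerForAddr("127.0.0.1:8300")
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if srv.ID != "server-1" {
		t.Fatalf("expected server-1, got %v", srv.ID)
	}

	// Unknown addresses return an error rather than a nil server.
	if _, err := builder.ServerForAddr("10.0.0.1:8300"); err == nil {
		t.Fatal("expected an error for an unknown address")
	}
}
```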
diff --git a/agent/setup.go b/agent/setup.go
index 18a0be0c3..d56419680 100644
--- a/agent/setup.go
+++ b/agent/setup.go
@@ -82,7 +82,8 @@ func NewBaseDeps(configLoader ConfigLoader, logOut io.Writer) (BaseDeps, error)
 	d.Cache = cache.New(cfg.Cache)
 	d.ConnPool = newConnPool(cfg, d.Logger, d.TLSConfigurator)

-	d.Router = router.NewRouter(d.Logger, cfg.Datacenter, fmt.Sprintf("%s.%s", cfg.NodeName, cfg.Datacenter))
+	// TODO: set grpcServerTracker; this requires serf to be set up before this point.
+	d.Router = router.NewRouter(d.Logger, cfg.Datacenter, fmt.Sprintf("%s.%s", cfg.NodeName, cfg.Datacenter), nil)

 	acConf := autoconf.Config{
 		DirectRPC: d.ConnPool,
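Reviewer note: the TODO above is deferred because serf is not yet available at this point in NewBaseDeps. Purely as a sketch of the eventual wiring (hypothetical; the real plumbing may go through the router instead), serf membership events could keep the builder's server list current like this:

```go
package example

import (
	"github.com/hashicorp/consul/agent/grpc/resolver"
	"github.com/hashicorp/consul/agent/metadata"
	"github.com/hashicorp/serf/serf"
)

// handleSerfEvent feeds LAN member changes into the ServerResolverBuilder so
// its resolvers always see the current set of servers.
func handleSerfEvent(builder *resolver.ServerResolverBuilder, e serf.MemberEvent) {
	for _, m := range e.Members {
		// Only Consul servers are tracked by the resolver.
		ok, srv := metadata.IsConsulServer(m)
		if !ok {
			continue
		}
		switch e.EventType() {
		case serf.EventMemberJoin, serf.EventMemberUpdate:
			builder.AddServer(srv)
		case serf.EventMemberLeave, serf.EventMemberFailed, serf.EventMemberReap:
			builder.RemoveServer(srv)
		}
	}
}
```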