diff --git a/agent/grpc/client.go b/agent/grpc/client.go
index 783cbae36..6a3feaf7a 100644
--- a/agent/grpc/client.go
+++ b/agent/grpc/client.go
@@ -61,8 +61,7 @@ func (c *ClientConnPool) ClientConn(datacenter string) (*grpc.ClientConn, error)
 		grpc.WithInsecure(),
 		grpc.WithContextDialer(c.dialer),
 		grpc.WithDisableRetry(),
-		// TODO: previously this statsHandler was shared with the Handler. Is that necessary?
-		grpc.WithStatsHandler(newStatsHandler()),
+		grpc.WithStatsHandler(newStatsHandler(defaultMetrics)),
 		// nolint:staticcheck // there is no other supported alternative to WithBalancerName
 		grpc.WithBalancerName("pick_first"))
 	if err != nil {
diff --git a/agent/grpc/client_test.go b/agent/grpc/client_test.go
index 400e0a815..38ecc40aa 100644
--- a/agent/grpc/client_test.go
+++ b/agent/grpc/client_test.go
@@ -5,14 +5,17 @@ import (
 	"fmt"
 	"net"
 	"strings"
+	"sync/atomic"
 	"testing"
 	"time"

+	"github.com/hashicorp/go-hclog"
+	"github.com/stretchr/testify/require"
+
 	"github.com/hashicorp/consul/agent/grpc/internal/testservice"
 	"github.com/hashicorp/consul/agent/grpc/resolver"
 	"github.com/hashicorp/consul/agent/metadata"
-	"github.com/hashicorp/consul/sdk/testutil/retry"
-	"github.com/stretchr/testify/require"
+	"github.com/hashicorp/consul/tlsutil"
 )

 func TestNewDialer_WithTLSWrapper(t *testing.T) {
@@ -42,14 +45,43 @@ func TestNewDialer_WithTLSWrapper(t *testing.T) {
 	require.True(t, called, "expected TLSWrapper to be called")
 }

-// TODO: integration test TestNewDialer with TLS and rcp server, when the rpc
-// exists as an isolated component.
+func TestNewDialer_IntegrationWithTLSEnabledHandler(t *testing.T) {
+	res := resolver.NewServerResolverBuilder(newConfig(t))
+	registerWithGRPC(t, res)
+
+	srv := newTestServer(t, "server-1", "dc1")
+	tlsConf, err := tlsutil.NewConfigurator(tlsutil.Config{
+		VerifyIncoming: true,
+		VerifyOutgoing: true,
+		CAFile:         "../../test/hostname/CertAuth.crt",
+		CertFile:       "../../test/hostname/Alice.crt",
+		KeyFile:        "../../test/hostname/Alice.key",
+	}, hclog.New(nil))
+	require.NoError(t, err)
+	srv.rpc.tlsConf = tlsConf
+
+	res.AddServer(srv.Metadata())
+	t.Cleanup(srv.shutdown)
+
+	pool := NewClientConnPool(res, TLSWrapper(tlsConf.OutgoingRPCWrapper()))
+
+	conn, err := pool.ClientConn("dc1")
+	require.NoError(t, err)
+	client := testservice.NewSimpleClient(conn)
+
+	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+	t.Cleanup(cancel)
+
+	resp, err := client.Something(ctx, &testservice.Req{})
+	require.NoError(t, err)
+	require.Equal(t, "server-1", resp.ServerName)
+	require.True(t, atomic.LoadInt32(&srv.rpc.tlsConnEstablished) > 0)
+}

 func TestClientConnPool_IntegrationWithGRPCResolver_Failover(t *testing.T) {
 	count := 4
-	cfg := resolver.Config{Scheme: newScheme(t.Name())}
-	res := resolver.NewServerResolverBuilder(cfg)
-	resolver.RegisterWithGRPC(res)
+	res := resolver.NewServerResolverBuilder(newConfig(t))
+	registerWithGRPC(t, res)
 	pool := NewClientConnPool(res, nil)

 	for i := 0; i < count; i++ {
@@ -76,17 +108,17 @@ func TestClientConnPool_IntegrationWithGRPCResolver_Failover(t *testing.T) {
 	require.NotEqual(t, resp.ServerName, first.ServerName)
 }

-func newScheme(n string) string {
+func newConfig(t *testing.T) resolver.Config {
+	n := t.Name()
 	s := strings.Replace(n, "/", "", -1)
 	s = strings.Replace(s, "_", "", -1)
-	return strings.ToLower(s)
+	return resolver.Config{Scheme: strings.ToLower(s)}
 }

 func TestClientConnPool_IntegrationWithGRPCResolver_Rebalance(t *testing.T) {
-	count := 4
-	cfg := resolver.Config{Scheme: newScheme(t.Name())}
-	res := resolver.NewServerResolverBuilder(cfg)
-	resolver.RegisterWithGRPC(res)
+	count := 5
+	res := resolver.NewServerResolverBuilder(newConfig(t))
+	registerWithGRPC(t, res)
 	pool := NewClientConnPool(res, nil)

 	for i := 0; i < count; i++ {
@@ -117,22 +149,25 @@ func TestClientConnPool_IntegrationWithGRPCResolver_Rebalance(t *testing.T) {
 	t.Run("rebalance the dc", func(t *testing.T) {
 		// Rebalance is random, but if we repeat it a few times it should give us a
 		// new server.
-		retry.RunWith(fastRetry, t, func(r *retry.R) {
+		attempts := 100
+		for i := 0; i < attempts; i++ {
 			res.NewRebalancer("dc1")()

 			resp, err := client.Something(ctx, &testservice.Req{})
-			require.NoError(r, err)
-			require.NotEqual(r, resp.ServerName, first.ServerName)
-		})
+			require.NoError(t, err)
+			if resp.ServerName != first.ServerName {
+				return
+			}
+		}
+		t.Fatalf("server was not rebalanced after %v attempts", attempts)
 	})
 }

 func TestClientConnPool_IntegrationWithGRPCResolver_MultiDC(t *testing.T) {
 	dcs := []string{"dc1", "dc2", "dc3"}
-	cfg := resolver.Config{Scheme: newScheme(t.Name())}
-	res := resolver.NewServerResolverBuilder(cfg)
-	resolver.RegisterWithGRPC(res)
+	res := resolver.NewServerResolverBuilder(newConfig(t))
+	registerWithGRPC(t, res)
 	pool := NewClientConnPool(res, nil)

 	for _, dc := range dcs {
diff --git a/agent/grpc/handler.go b/agent/grpc/handler.go
index d70fd2b10..d381b9da8 100644
--- a/agent/grpc/handler.go
+++ b/agent/grpc/handler.go
@@ -17,12 +17,12 @@ func NewHandler(addr net.Addr, register func(server *grpc.Server)) *Handler {
 	// We don't need to pass tls.Config to the server since it's multiplexed
 	// behind the RPC listener, which already has TLS configured.
 	srv := grpc.NewServer(
-		grpc.StatsHandler(newStatsHandler()),
-		grpc.StreamInterceptor((&activeStreamCounter{}).Intercept),
+		grpc.StatsHandler(newStatsHandler(defaultMetrics)),
+		grpc.StreamInterceptor((&activeStreamCounter{metrics: defaultMetrics}).Intercept),
 	)
 	register(srv)

-	lis := &chanListener{addr: addr, conns: make(chan net.Conn)}
+	lis := &chanListener{addr: addr, conns: make(chan net.Conn), done: make(chan struct{})}
 	return &Handler{srv: srv, listener: lis}
 }

@@ -51,22 +51,22 @@ func (h *Handler) Shutdown() error {
 type chanListener struct {
 	conns chan net.Conn
 	addr  net.Addr
+	done  chan struct{}
 }

 // Accept blocks until a connection is received from Handle, and then returns the
 // connection. Accept implements part of the net.Listener interface for grpc.Server.
 func (l *chanListener) Accept() (net.Conn, error) {
 	select {
-	case c, ok := <-l.conns:
-		if !ok {
-			return nil, &net.OpError{
-				Op:   "accept",
-				Net:  l.addr.Network(),
-				Addr: l.addr,
-				Err:  fmt.Errorf("listener closed"),
-			}
-		}
+	case c := <-l.conns:
 		return c, nil
+	case <-l.done:
+		return nil, &net.OpError{
+			Op:   "accept",
+			Net:  l.addr.Network(),
+			Addr: l.addr,
+			Err:  fmt.Errorf("listener closed"),
+		}
 	}
 }

@@ -75,7 +75,7 @@ func (l *chanListener) Addr() net.Addr {
 }

 func (l *chanListener) Close() error {
-	close(l.conns)
+	close(l.done)
 	return nil
 }
diff --git a/agent/grpc/resolver/resolver.go b/agent/grpc/resolver/resolver.go
index b34aad72f..76a2188d2 100644
--- a/agent/grpc/resolver/resolver.go
+++ b/agent/grpc/resolver/resolver.go
@@ -7,23 +7,11 @@ import (
 	"sync"
 	"time"

-	"github.com/hashicorp/consul/agent/metadata"
 	"google.golang.org/grpc/resolver"
+
+	"github.com/hashicorp/consul/agent/metadata"
 )

-var registerLock sync.Mutex
-
-// RegisterWithGRPC registers the ServerResolverBuilder as a grpc/resolver.
-// This function exists to synchronize registrations with a lock.
-// grpc/resolver.Register expects all registration to happen at init and does
-// not allow for concurrent registration. This function exists to support
-// parallel testing.
-func RegisterWithGRPC(b *ServerResolverBuilder) {
-	registerLock.Lock()
-	defer registerLock.Unlock()
-	resolver.Register(b)
-}
-
 // ServerResolverBuilder tracks the current server list and keeps any
 // ServerResolvers updated when changes occur.
 type ServerResolverBuilder struct {
@@ -31,7 +19,7 @@ type ServerResolverBuilder struct {
 	// parallel testing because gRPC registers resolvers globally.
 	scheme string
 	// servers is an index of Servers by Server.ID. The map contains server IDs
-	// for all datacenters, so it assumes the ID is globally unique.
+	// for all datacenters.
 	servers map[string]*metadata.Server
 	// resolvers is an index of connections to the serverResolver which manages
 	// addresses of servers for that connection.
@@ -131,7 +119,7 @@ func (s *ServerResolverBuilder) AddServer(server *metadata.Server) {
 	s.lock.Lock()
 	defer s.lock.Unlock()

-	s.servers[server.ID] = server
+	s.servers[uniqueID(server)] = server

 	addrs := s.getDCAddrs(server.Datacenter)
 	for _, resolver := range s.resolvers {
@@ -141,12 +129,21 @@ func (s *ServerResolverBuilder) AddServer(server *metadata.Server) {
 	}
 }

+// uniqueID returns a unique identifier for the server which includes the
+// Datacenter and the ID.
+//
+// In practice it is expected that the server.ID is already a globally unique
+// UUID. This function is an extra safeguard in case that ever changes.
+func uniqueID(server *metadata.Server) string {
+	return server.Datacenter + "-" + server.ID
+}
+
 // RemoveServer updates the resolvers' states with the given server removed.
 func (s *ServerResolverBuilder) RemoveServer(server *metadata.Server) {
 	s.lock.Lock()
 	defer s.lock.Unlock()

-	delete(s.servers, server.ID)
+	delete(s.servers, uniqueID(server))

 	addrs := s.getDCAddrs(server.Datacenter)
 	for _, resolver := range s.resolvers {
diff --git a/agent/grpc/server_test.go b/agent/grpc/server_test.go
index 68417354b..b660a66a7 100644
--- a/agent/grpc/server_test.go
+++ b/agent/grpc/server_test.go
@@ -2,18 +2,23 @@ package grpc

 import (
 	"context"
+	"crypto/tls"
 	"fmt"
 	"io"
 	"net"
+	"sync/atomic"
 	"testing"
 	"time"

+	"github.com/stretchr/testify/require"
+	"golang.org/x/sync/errgroup"
+	"google.golang.org/grpc"
+	"google.golang.org/grpc/resolver"
+
 	"github.com/hashicorp/consul/agent/grpc/internal/testservice"
 	"github.com/hashicorp/consul/agent/metadata"
 	"github.com/hashicorp/consul/agent/pool"
-	"github.com/stretchr/testify/require"
-	"golang.org/x/sync/errgroup"
-	"google.golang.org/grpc"
+	"github.com/hashicorp/consul/tlsutil"
 )

 type testServer struct {
@@ -21,10 +26,16 @@ type testServer struct {
 	name     string
 	dc       string
 	shutdown func()
+	rpc      *fakeRPCListener
 }

 func (s testServer) Metadata() *metadata.Server {
-	return &metadata.Server{ID: s.name, Datacenter: s.dc, Addr: s.addr}
+	return &metadata.Server{
+		ID:         s.name,
+		Datacenter: s.dc,
+		Addr:       s.addr,
+		UseTLS:     s.rpc.tlsConf != nil,
+	}
 }

 func newTestServer(t *testing.T, name string, dc string) testServer {
@@ -40,16 +51,24 @@ func newTestServer(t *testing.T, name string, dc string) testServer {

 	g := errgroup.Group{}
 	g.Go(func() error {
-		return rpc.listen(lis)
+		if err := rpc.listen(lis); err != nil {
+			return fmt.Errorf("fake rpc listen error: %w", err)
+		}
+		return nil
 	})
 	g.Go(func() error {
-		return handler.Run()
+		if err := handler.Run(); err != nil {
+			return fmt.Errorf("grpc server error: %w", err)
+		}
+		return nil
 	})
 	return testServer{
 		addr: lis.Addr(),
 		name: name,
 		dc:   dc,
+		rpc:  rpc,
 		shutdown: func() {
+			rpc.shutdown = true
 			if err := lis.Close(); err != nil {
 				t.Logf("listener closed with error: %v", err)
 			}
@@ -57,7 +76,7 @@ func newTestServer(t *testing.T, name string, dc string) testServer {
 				t.Logf("grpc server shutdown: %v", err)
 			}
 			if err := g.Wait(); err != nil {
-				t.Logf("grpc server error: %v", err)
+				t.Log(err)
 			}
 		},
 	}
@@ -89,14 +108,20 @@ func (s *simple) Something(_ context.Context, _ *testservice.Req) (*testservice.
 // For now, since this logic is in agent/consul, we can't easily use Server.listen
 // so we fake it.
 type fakeRPCListener struct {
-	t       *testing.T
-	handler *Handler
+	t                  *testing.T
+	handler            *Handler
+	shutdown           bool
+	tlsConf            *tlsutil.Configurator
+	tlsConnEstablished int32
 }

 func (f *fakeRPCListener) listen(listener net.Listener) error {
 	for {
 		conn, err := listener.Accept()
 		if err != nil {
+			if f.shutdown {
+				return nil
+			}
 			return err
 		}

@@ -116,11 +141,36 @@ func (f *fakeRPCListener) handleConn(conn net.Conn) {
 	}
 	typ := pool.RPCType(buf[0])

-	if typ == pool.RPCGRPC {
+	switch typ {
+
+	case pool.RPCGRPC:
 		f.handler.Handle(conn)
 		return
-	}

-	fmt.Println("ERROR: unexpected byte", typ)
-	conn.Close()
+	case pool.RPCTLS:
+		// Occasionally a test client connects to an rpc listener that was
+		// created as part of another test, even though none of the tests run
+		// in parallel. This may be some gRPC behaviour that is not yet
+		// understood.
+		if f.tlsConf == nil {
+			fmt.Println("ERROR: tls is not configured")
+			conn.Close()
+			return
+		}
+
+		atomic.AddInt32(&f.tlsConnEstablished, 1)
+		conn = tls.Server(conn, f.tlsConf.IncomingRPCConfig())
+		f.handleConn(conn)
+
+	default:
+		fmt.Println("ERROR: unexpected byte", typ)
+		conn.Close()
+	}
+}
+
+func registerWithGRPC(t *testing.T, b resolver.Builder) {
+	resolver.Register(b)
+	t.Cleanup(func() {
+		resolver.UnregisterForTesting(b.Scheme())
+	})
+}
diff --git a/agent/grpc/stats.go b/agent/grpc/stats.go
index d25048110..16961e7f0 100644
--- a/agent/grpc/stats.go
+++ b/agent/grpc/stats.go
@@ -18,8 +18,8 @@ type statsHandler struct {
 	activeConns uint64 // must be 8-byte aligned for atomic access
 }

-func newStatsHandler() *statsHandler {
-	return &statsHandler{metrics: defaultMetrics}
+func newStatsHandler(m *metrics.Metrics) *statsHandler {
+	return &statsHandler{metrics: m}
 }

 // TagRPC implements grpcStats.StatsHandler
@@ -64,6 +64,7 @@ func (c *statsHandler) HandleConn(_ context.Context, s stats.ConnStats) {
 }

 type activeStreamCounter struct {
+	metrics *metrics.Metrics
 	// count of the number of open streaming RPCs on a server. It is accessed
 	// atomically.
 	count uint64
@@ -78,10 +79,10 @@ func (i *activeStreamCounter) Intercept(
 	handler grpc.StreamHandler,
 ) error {
 	count := atomic.AddUint64(&i.count, 1)
-	defaultMetrics.SetGauge([]string{"grpc", "server", "active_streams"}, float32(count))
+	i.metrics.SetGauge([]string{"grpc", "server", "active_streams"}, float32(count))
 	defer func() {
 		count := atomic.AddUint64(&i.count, ^uint64(0))
-		defaultMetrics.SetGauge([]string{"grpc", "server", "active_streams"}, float32(count))
+		i.metrics.SetGauge([]string{"grpc", "server", "active_streams"}, float32(count))
 	}()

 	return handler(srv, ss)
diff --git a/agent/grpc/stats_test.go b/agent/grpc/stats_test.go
index f98d0f3cb..ea4cb70b2 100644
--- a/agent/grpc/stats_test.go
+++ b/agent/grpc/stats_test.go
@@ -3,32 +3,33 @@ package grpc
 import (
 	"context"
 	"net"
+	"sort"
 	"sync"
 	"testing"
 	"time"

 	"github.com/armon/go-metrics"
+	"github.com/google/go-cmp/cmp"
 	"github.com/stretchr/testify/require"
 	"golang.org/x/sync/errgroup"
 	"google.golang.org/grpc"

 	"github.com/hashicorp/consul/agent/grpc/internal/testservice"
-	"github.com/hashicorp/consul/sdk/testutil/retry"
 )

 func noopRegister(*grpc.Server) {}

 func TestHandler_EmitsStats(t *testing.T) {
-	sink := patchGlobalMetrics(t)
+	sink, reset := patchGlobalMetrics(t)

 	addr := &net.IPAddr{IP: net.ParseIP("127.0.0.1")}
 	handler := NewHandler(addr, noopRegister)
+	reset()

 	testservice.RegisterSimpleServer(handler.srv, &simple{})

 	lis, err := net.Listen("tcp", "127.0.0.1:0")
 	require.NoError(t, err)
-	t.Cleanup(logError(t, lis.Close))

 	ctx, cancel := context.WithCancel(context.Background())
 	t.Cleanup(cancel)
@@ -48,7 +49,7 @@ func TestHandler_EmitsStats(t *testing.T) {

 	conn, err := grpc.DialContext(ctx, lis.Addr().String(), grpc.WithInsecure())
 	require.NoError(t, err)
-	t.Cleanup(logError(t, conn.Close))
+	t.Cleanup(func() { conn.Close() })

 	client := testservice.NewSimpleClient(conn)
 	fClient, err := client.Flow(ctx, &testservice.Req{Datacenter: "mine"})
@@ -64,23 +65,42 @@ func TestHandler_EmitsStats(t *testing.T) {
 	// Wait for the server to stop so that active_streams is predictable.
 	require.NoError(t, g.Wait())

+	// Occasionally the active_streams=0 metric may be emitted before the
+	// active_conns=0 metric. The relative order of those metrics is not
+	// important, so sort the calls into the expected order.
+	sort.Slice(sink.gaugeCalls, func(i, j int) bool {
+		if i < 2 || j < 2 {
+			return i < j
+		}
+		if len(sink.gaugeCalls[i].key) < 4 || len(sink.gaugeCalls[j].key) < 4 {
+			return i < j
+		}
+		return sink.gaugeCalls[i].key[3] < sink.gaugeCalls[j].key[3]
+	})
+
+	cmpMetricCalls := cmp.AllowUnexported(metricCall{})
 	expectedGauge := []metricCall{
 		{key: []string{"testing", "grpc", "server", "active_conns"}, val: 1},
 		{key: []string{"testing", "grpc", "server", "active_streams"}, val: 1},
 		{key: []string{"testing", "grpc", "server", "active_conns"}, val: 0},
 		{key: []string{"testing", "grpc", "server", "active_streams"}, val: 0},
 	}
-	require.Equal(t, expectedGauge, sink.gaugeCalls)
+	assertDeepEqual(t, expectedGauge, sink.gaugeCalls, cmpMetricCalls)

 	expectedCounter := []metricCall{
 		{key: []string{"testing", "grpc", "server", "request"}, val: 1},
 	}
-	require.Equal(t, expectedCounter, sink.incrCounterCalls)
+	assertDeepEqual(t, expectedCounter, sink.incrCounterCalls, cmpMetricCalls)
 }

-var fastRetry = &retry.Timer{Timeout: 7 * time.Second, Wait: 2 * time.Millisecond}
+func assertDeepEqual(t *testing.T, x, y interface{}, opts ...cmp.Option) {
+	t.Helper()
+	if diff := cmp.Diff(x, y, opts...); diff != "" {
+		t.Fatalf("assertion failed: values are not equal\n--- expected\n+++ actual\n%v", diff)
+	}
+}

-func patchGlobalMetrics(t *testing.T) *fakeMetricsSink {
+func patchGlobalMetrics(t *testing.T) (*fakeMetricsSink, func()) {
 	t.Helper()

 	sink := &fakeMetricsSink{}
@@ -93,11 +113,12 @@ func patchGlobalMetrics(t *testing.T) *fakeMetricsSink {
 	var err error
 	defaultMetrics, err = metrics.New(cfg, sink)
 	require.NoError(t, err)
-	t.Cleanup(func() {
-		_, err = metrics.NewGlobal(cfg, &metrics.BlackholeSink{})
+	reset := func() {
+		t.Helper()
+		defaultMetrics, err = metrics.New(cfg, &metrics.BlackholeSink{})
 		require.NoError(t, err, "failed to reset global metrics")
-	})
-	return sink
+	}
+	return sink, reset
 }

 type fakeMetricsSink struct {
diff --git a/agent/setup.go b/agent/setup.go
index 7c65777c9..c7fe7f523 100644
--- a/agent/setup.go
+++ b/agent/setup.go
@@ -5,10 +5,12 @@ import (
 	"io"
 	"net"
 	"net/http"
+	"sync"
 	"time"

 	"github.com/hashicorp/go-hclog"
 	"google.golang.org/grpc/grpclog"
+	grpcresolver "google.golang.org/grpc/resolver"

 	autoconf "github.com/hashicorp/consul/agent/auto-config"
 	"github.com/hashicorp/consul/agent/cache"
@@ -88,7 +90,7 @@ func NewBaseDeps(configLoader ConfigLoader, logOut io.Writer) (BaseDeps, error)
 	d.ConnPool = newConnPool(cfg, d.Logger, d.TLSConfigurator)

 	builder := resolver.NewServerResolverBuilder(resolver.Config{})
-	resolver.RegisterWithGRPC(builder)
+	registerWithGRPC(builder)
 	d.GRPCConnPool = grpc.NewClientConnPool(builder, grpc.TLSWrapper(d.TLSConfigurator.OutgoingRPCWrapper()))

 	d.Router = router.NewRouter(d.Logger, cfg.Datacenter, fmt.Sprintf("%s.%s", cfg.NodeName, cfg.Datacenter), builder)
@@ -162,3 +164,16 @@ func newConnPool(config *config.RuntimeConfig, logger hclog.Logger, tls *tlsutil
 	}
 	return pool
 }
+
+var registerLock sync.Mutex
+
+// registerWithGRPC registers the grpc/resolver.Builder as a grpc/resolver.
+// It synchronizes registrations with a lock, because grpc/resolver.Register
+// expects all registration to happen at init and does not allow for
+// concurrent registration. The lock makes it safe to register resolvers
+// from parallel tests.
+func registerWithGRPC(b grpcresolver.Builder) {
+	registerLock.Lock()
+	defer registerLock.Unlock()
+	grpcresolver.Register(b)
+}
diff --git a/tlsutil/config.go b/tlsutil/config.go
index c966ec724..b40a11b42 100644
--- a/tlsutil/config.go
+++ b/tlsutil/config.go
@@ -13,8 +13,9 @@ import (
 	"sync"
 	"time"

-	"github.com/hashicorp/consul/logging"
 	"github.com/hashicorp/go-hclog"
+
+	"github.com/hashicorp/consul/logging"
 )

 // ALPNWrapper is a function that is used to wrap a non-TLS connection and
@@ -80,9 +81,6 @@ type Config struct {
 	// cannot break existing clients.
 	VerifyServerHostname bool

-	// UseTLS is used to enable outgoing TLS connections to Consul servers.
-	UseTLS bool
-
 	// CAFile is a path to a certificate authority file. This is used with
 	// VerifyIncoming or VerifyOutgoing to verify the TLS connection.
 	CAFile string