// Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: MPL-2.0 package internal import ( "context" "fmt" "net" "strings" "sync/atomic" "testing" "time" "github.com/google/tcpproxy" "github.com/hashicorp/go-hclog" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/hashicorp/consul/agent/grpc-internal/balancer" "github.com/hashicorp/consul/agent/grpc-internal/resolver" "github.com/hashicorp/consul/agent/grpc-middleware/testutil/testservice" "github.com/hashicorp/consul/agent/metadata" "github.com/hashicorp/consul/ipaddr" "github.com/hashicorp/consul/sdk/freeport" "github.com/hashicorp/consul/sdk/testutil" "github.com/hashicorp/consul/tlsutil" "github.com/hashicorp/consul/types" ) // useTLSForDcAlwaysTrue tell GRPC to always return the TLS is enabled func useTLSForDcAlwaysTrue(_ string) bool { return true } func TestNewDialer_WithTLSWrapper(t *testing.T) { lis, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) t.Cleanup(logError(t, lis.Close)) builder := resolver.NewServerResolverBuilder(newConfig(t)) builder.AddServer(types.AreaWAN, &metadata.Server{ Name: "server-1", ID: "ID1", Datacenter: "dc1", Addr: lis.Addr(), UseTLS: true, }) var called bool wrapper := func(_ string, conn net.Conn) (net.Conn, error) { called = true return conn, nil } dial := newDialer( ClientConnPoolConfig{ Servers: builder, TLSWrapper: wrapper, UseTLSForDC: useTLSForDcAlwaysTrue, DialingFromServer: true, DialingFromDatacenter: "dc1", }, &gatewayResolverDep{}, ) ctx := context.Background() conn, err := dial(ctx, resolver.DCPrefix("dc1", lis.Addr().String())) require.NoError(t, err) require.NoError(t, conn.Close()) require.True(t, called, "expected TLSWrapper to be called") } func TestNewDialer_WithALPNWrapper(t *testing.T) { lis1, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) t.Cleanup(logError(t, lis1.Close)) lis2, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) t.Cleanup(logError(t, lis2.Close)) // Send all of the traffic to dc2's server var p tcpproxy.Proxy gwAddr := ipaddr.FormatAddressPort("127.0.0.1", freeport.GetOne(t)) p.AddRoute(gwAddr, tcpproxy.To(lis2.Addr().String())) p.AddStopACMESearch(gwAddr) require.NoError(t, p.Start()) defer func() { p.Close() p.Wait() }() builder := resolver.NewServerResolverBuilder(newConfig(t)) builder.AddServer(types.AreaWAN, &metadata.Server{ Name: "server-1", ID: "ID1", Datacenter: "dc1", Addr: lis1.Addr(), UseTLS: true, }) builder.AddServer(types.AreaWAN, &metadata.Server{ Name: "server-2", ID: "ID2", Datacenter: "dc2", Addr: lis2.Addr(), UseTLS: true, }) var calledTLS bool wrapperTLS := func(_ string, conn net.Conn) (net.Conn, error) { calledTLS = true return conn, nil } var calledALPN bool wrapperALPN := func(_, _, _ string, conn net.Conn) (net.Conn, error) { calledALPN = true return conn, nil } gwResolverDep := &gatewayResolverDep{ GatewayResolver: func(addr string) string { return gwAddr }, } dial := newDialer( ClientConnPoolConfig{ Servers: builder, TLSWrapper: wrapperTLS, ALPNWrapper: wrapperALPN, UseTLSForDC: useTLSForDcAlwaysTrue, DialingFromServer: true, DialingFromDatacenter: "dc1", }, gwResolverDep, ) ctx := context.Background() conn, err := dial(ctx, resolver.DCPrefix("dc2", lis2.Addr().String())) require.NoError(t, err) require.NoError(t, conn.Close()) assert.False(t, calledTLS, "expected TLSWrapper not to be called") assert.True(t, calledALPN, "expected ALPNWrapper to be called") } func TestNewDialer_IntegrationWithTLSEnabledHandler(t *testing.T) { // if this test is failing because of expired certificates // use the procedure in test/CA-GENERATION.md res := resolver.NewServerResolverBuilder(newConfig(t)) bb := balancer.NewBuilder(res.Authority(), testutil.Logger(t)) registerWithGRPC(t, res, bb) tlsConf, err := tlsutil.NewConfigurator(tlsutil.Config{ InternalRPC: tlsutil.ProtocolConfig{ VerifyIncoming: true, CAFile: "../../test/hostname/CertAuth.crt", CertFile: "../../test/hostname/Alice.crt", KeyFile: "../../test/hostname/Alice.key", VerifyOutgoing: true, }, }, hclog.New(nil)) require.NoError(t, err) srv := newSimpleTestServer(t, "server-1", "dc1", tlsConf) md := srv.Metadata() res.AddServer(types.AreaWAN, md) t.Cleanup(srv.shutdown) pool := NewClientConnPool(ClientConnPoolConfig{ Servers: res, TLSWrapper: TLSWrapper(tlsConf.OutgoingRPCWrapper()), UseTLSForDC: tlsConf.UseTLS, DialingFromServer: true, DialingFromDatacenter: "dc1", }) conn, err := pool.ClientConn("dc1") require.NoError(t, err) client := testservice.NewSimpleClient(conn) ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) t.Cleanup(cancel) resp, err := client.Something(ctx, &testservice.Req{}) require.NoError(t, err) require.Equal(t, "server-1", resp.ServerName) require.True(t, atomic.LoadInt32(&srv.rpc.tlsConnEstablished) > 0) require.True(t, atomic.LoadInt32(&srv.rpc.alpnConnEstablished) == 0) } func TestNewDialer_IntegrationWithTLSEnabledHandler_viaMeshGateway(t *testing.T) { // if this test is failing because of expired certificates // use the procedure in test/CA-GENERATION.md gwAddr := ipaddr.FormatAddressPort("127.0.0.1", freeport.GetOne(t)) res := resolver.NewServerResolverBuilder(newConfig(t)) bb := balancer.NewBuilder(res.Authority(), testutil.Logger(t)) registerWithGRPC(t, res, bb) tlsConf, err := tlsutil.NewConfigurator(tlsutil.Config{ InternalRPC: tlsutil.ProtocolConfig{ VerifyIncoming: true, CAFile: "../../test/hostname/CertAuth.crt", CertFile: "../../test/hostname/Bob.crt", KeyFile: "../../test/hostname/Bob.key", VerifyOutgoing: true, VerifyServerHostname: true, }, Domain: "consul", NodeName: "bob", }, hclog.New(nil)) require.NoError(t, err) srv := newSimpleTestServer(t, "bob", "dc1", tlsConf) // Send all of the traffic to dc1's server var p tcpproxy.Proxy p.AddRoute(gwAddr, tcpproxy.To(srv.addr.String())) p.AddStopACMESearch(gwAddr) require.NoError(t, p.Start()) defer func() { p.Close() p.Wait() }() md := srv.Metadata() res.AddServer(types.AreaWAN, md) t.Cleanup(srv.shutdown) clientTLSConf, err := tlsutil.NewConfigurator(tlsutil.Config{ InternalRPC: tlsutil.ProtocolConfig{ VerifyIncoming: true, CAFile: "../../test/hostname/CertAuth.crt", CertFile: "../../test/hostname/Betty.crt", KeyFile: "../../test/hostname/Betty.key", VerifyOutgoing: true, VerifyServerHostname: true, }, Domain: "consul", NodeName: "betty", }, hclog.New(nil)) require.NoError(t, err) pool := NewClientConnPool(ClientConnPoolConfig{ Servers: res, TLSWrapper: TLSWrapper(clientTLSConf.OutgoingRPCWrapper()), ALPNWrapper: ALPNWrapper(clientTLSConf.OutgoingALPNRPCWrapper()), UseTLSForDC: tlsConf.UseTLS, DialingFromServer: true, DialingFromDatacenter: "dc2", }) pool.SetGatewayResolver(func(addr string) string { return gwAddr }) conn, err := pool.ClientConn("dc1") require.NoError(t, err) client := testservice.NewSimpleClient(conn) ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) t.Cleanup(cancel) resp, err := client.Something(ctx, &testservice.Req{}) require.NoError(t, err) require.Equal(t, "bob", resp.ServerName) require.True(t, atomic.LoadInt32(&srv.rpc.tlsConnEstablished) == 0) require.True(t, atomic.LoadInt32(&srv.rpc.alpnConnEstablished) > 0) } func TestClientConnPool_IntegrationWithGRPCResolver_Failover(t *testing.T) { count := 4 res := resolver.NewServerResolverBuilder(newConfig(t)) bb := balancer.NewBuilder(res.Authority(), testutil.Logger(t)) registerWithGRPC(t, res, bb) pool := NewClientConnPool(ClientConnPoolConfig{ Servers: res, UseTLSForDC: useTLSForDcAlwaysTrue, DialingFromServer: true, DialingFromDatacenter: "dc1", }) for i := 0; i < count; i++ { name := fmt.Sprintf("server-%d", i) srv := newSimpleTestServer(t, name, "dc1", nil) res.AddServer(types.AreaWAN, srv.Metadata()) t.Cleanup(srv.shutdown) } conn, err := pool.ClientConn("dc1") require.NoError(t, err) client := testservice.NewSimpleClient(conn) ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) t.Cleanup(cancel) first, err := client.Something(ctx, &testservice.Req{}) require.NoError(t, err) res.RemoveServer(types.AreaWAN, &metadata.Server{ID: first.ServerName, Datacenter: "dc1"}) resp, err := client.Something(ctx, &testservice.Req{}) require.NoError(t, err) require.NotEqual(t, resp.ServerName, first.ServerName) } func TestClientConnPool_ForwardToLeader_Failover(t *testing.T) { count := 3 res := resolver.NewServerResolverBuilder(newConfig(t)) bb := balancer.NewBuilder(res.Authority(), testutil.Logger(t)) registerWithGRPC(t, res, bb) pool := NewClientConnPool(ClientConnPoolConfig{ Servers: res, UseTLSForDC: useTLSForDcAlwaysTrue, DialingFromServer: true, DialingFromDatacenter: "dc1", }) var servers []testServer for i := 0; i < count; i++ { name := fmt.Sprintf("server-%d", i) srv := newSimpleTestServer(t, name, "dc1", nil) res.AddServer(types.AreaWAN, srv.Metadata()) servers = append(servers, srv) t.Cleanup(srv.shutdown) } // Set the leader address to the first server. srv0 := servers[0].Metadata() res.UpdateLeaderAddr(srv0.Datacenter, srv0.Addr.String()) conn, err := pool.ClientConnLeader() require.NoError(t, err) client := testservice.NewSimpleClient(conn) ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) t.Cleanup(cancel) first, err := client.Something(ctx, &testservice.Req{}) require.NoError(t, err) require.Equal(t, first.ServerName, servers[0].name) // Update the leader address and make another request. srv1 := servers[1].Metadata() res.UpdateLeaderAddr(srv1.Datacenter, srv1.Addr.String()) resp, err := client.Something(ctx, &testservice.Req{}) require.NoError(t, err) require.Equal(t, resp.ServerName, servers[1].name) } func newConfig(t *testing.T) resolver.Config { n := t.Name() s := strings.Replace(n, "/", "", -1) s = strings.Replace(s, "_", "", -1) return resolver.Config{Authority: strings.ToLower(s)} } func TestClientConnPool_IntegrationWithGRPCResolver_MultiDC(t *testing.T) { dcs := []string{"dc1", "dc2", "dc3"} res := resolver.NewServerResolverBuilder(newConfig(t)) bb := balancer.NewBuilder(res.Authority(), testutil.Logger(t)) registerWithGRPC(t, res, bb) pool := NewClientConnPool(ClientConnPoolConfig{ Servers: res, UseTLSForDC: useTLSForDcAlwaysTrue, DialingFromServer: true, DialingFromDatacenter: "dc1", }) for _, dc := range dcs { name := "server-0-" + dc srv := newSimpleTestServer(t, name, dc, nil) res.AddServer(types.AreaWAN, srv.Metadata()) t.Cleanup(srv.shutdown) } ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) t.Cleanup(cancel) for _, dc := range dcs { conn, err := pool.ClientConn(dc) require.NoError(t, err) client := testservice.NewSimpleClient(conn) resp, err := client.Something(ctx, &testservice.Req{}) require.NoError(t, err) require.Equal(t, resp.Datacenter, dc) } } func registerWithGRPC(t *testing.T, rb *resolver.ServerResolverBuilder, bb *balancer.Builder) { resolver.Register(rb) bb.Register() t.Cleanup(func() { resolver.Deregister(rb.Authority()) bb.Deregister() }) }