From 5cda50549573f9731df86c52457d0a23cc2eff3c Mon Sep 17 00:00:00 2001 From: Hans Hasselberg Date: Thu, 28 May 2020 10:18:30 +0200 Subject: [PATCH] pool: remove useTLS and ForceTLS In the past TLS usage was enforced with these variables, but these days this decision is made by TLSConfigurator and there is no reason to keep using the variables. --- agent/consul/client.go | 3 +- agent/consul/server.go | 1 - agent/consul/snapshot_endpoint.go | 7 ++- agent/consul/snapshot_endpoint_test.go | 16 +++--- agent/consul/status_endpoint_test.go | 27 ++++++---- agent/pool/pool.go | 72 ++++++-------------------- 6 files changed, 45 insertions(+), 81 deletions(-) diff --git a/agent/consul/client.go b/agent/consul/client.go index c7a36293b..446ba9962 100644 --- a/agent/consul/client.go +++ b/agent/consul/client.go @@ -137,7 +137,6 @@ func NewClientLogger(config *Config, logger hclog.InterceptLogger, tlsConfigurat MaxTime: clientRPCConnMaxIdle, MaxStreams: clientMaxStreams, TLSConfigurator: tlsConfigurator, - ForceTLS: config.VerifyOutgoing, Datacenter: config.Datacenter, } @@ -356,7 +355,7 @@ func (c *Client) SnapshotRPC(args *structs.SnapshotRequest, in io.Reader, out io // Request the operation. var reply structs.SnapshotResponse - snap, err := SnapshotRPC(c.connPool, c.config.Datacenter, server.ShortName, server.Addr, server.UseTLS, args, in, &reply) + snap, err := SnapshotRPC(c.connPool, c.config.Datacenter, server.ShortName, server.Addr, args, in, &reply) if err != nil { return err } diff --git a/agent/consul/server.go b/agent/consul/server.go index fec1efd6c..eb8d1cadd 100644 --- a/agent/consul/server.go +++ b/agent/consul/server.go @@ -374,7 +374,6 @@ func NewServerLogger(config *Config, logger hclog.InterceptLogger, tokens *token MaxTime: serverRPCCache, MaxStreams: serverMaxStreams, TLSConfigurator: tlsConfigurator, - ForceTLS: config.VerifyOutgoing, Datacenter: config.Datacenter, } diff --git a/agent/consul/snapshot_endpoint.go b/agent/consul/snapshot_endpoint.go index 233354c0e..5f6e3e0f8 100644 --- a/agent/consul/snapshot_endpoint.go +++ b/agent/consul/snapshot_endpoint.go @@ -37,7 +37,7 @@ func (s *Server) dispatchSnapshotRequest(args *structs.SnapshotRequest, in io.Re return nil, structs.ErrNoDCPath } - snap, err := SnapshotRPC(s.connPool, dc, server.ShortName, server.Addr, server.UseTLS, args, in, reply) + snap, err := SnapshotRPC(s.connPool, dc, server.ShortName, server.Addr, args, in, reply) if err != nil { manager.NotifyFailedServer(server) return nil, err @@ -52,7 +52,7 @@ func (s *Server) dispatchSnapshotRequest(args *structs.SnapshotRequest, in io.Re if server == nil { return nil, structs.ErrNoLeader } - return SnapshotRPC(s.connPool, args.Datacenter, server.ShortName, server.Addr, server.UseTLS, args, in, reply) + return SnapshotRPC(s.connPool, args.Datacenter, server.ShortName, server.Addr, args, in, reply) } } @@ -194,14 +194,13 @@ func SnapshotRPC( dc string, nodeName string, addr net.Addr, - useTLS bool, args *structs.SnapshotRequest, in io.Reader, reply *structs.SnapshotResponse, ) (io.ReadCloser, error) { // Write the snapshot RPC byte to set the mode, then perform the // request. - conn, hc, err := connPool.DialTimeout(dc, nodeName, addr, 10*time.Second, useTLS, pool.RPCSnapshot) + conn, hc, err := connPool.DialTimeout(dc, nodeName, addr, 10*time.Second, pool.RPCSnapshot) if err != nil { return nil, err } diff --git a/agent/consul/snapshot_endpoint_test.go b/agent/consul/snapshot_endpoint_test.go index 9073fa01e..e0cd31a1d 100644 --- a/agent/consul/snapshot_endpoint_test.go +++ b/agent/consul/snapshot_endpoint_test.go @@ -46,7 +46,7 @@ func verifySnapshot(t *testing.T, s *Server, dc, token string) { Op: structs.SnapshotSave, } var reply structs.SnapshotResponse - snap, err := SnapshotRPC(s.connPool, s.config.Datacenter, s.config.NodeName, s.config.RPCAddr, false, + snap, err := SnapshotRPC(s.connPool, s.config.Datacenter, s.config.NodeName, s.config.RPCAddr, &args, bytes.NewReader([]byte("")), &reply) if err != nil { t.Fatalf("err: %v", err) @@ -121,7 +121,7 @@ func verifySnapshot(t *testing.T, s *Server, dc, token string) { // Restore the snapshot. args.Op = structs.SnapshotRestore - restore, err := SnapshotRPC(s.connPool, s.config.Datacenter, s.config.NodeName, s.config.RPCAddr, false, + restore, err := SnapshotRPC(s.connPool, s.config.Datacenter, s.config.NodeName, s.config.RPCAddr, &args, snap, &reply) if err != nil { t.Fatalf("err: %v", err) @@ -196,7 +196,7 @@ func TestSnapshot_LeaderState(t *testing.T) { Op: structs.SnapshotSave, } var reply structs.SnapshotResponse - snap, err := SnapshotRPC(s1.connPool, s1.config.Datacenter, s1.config.NodeName, s1.config.RPCAddr, false, + snap, err := SnapshotRPC(s1.connPool, s1.config.Datacenter, s1.config.NodeName, s1.config.RPCAddr, &args, bytes.NewReader([]byte("")), &reply) if err != nil { t.Fatalf("err: %v", err) @@ -229,7 +229,7 @@ func TestSnapshot_LeaderState(t *testing.T) { // Restore the snapshot. args.Op = structs.SnapshotRestore - restore, err := SnapshotRPC(s1.connPool, s1.config.Datacenter, s1.config.NodeName, s1.config.RPCAddr, false, + restore, err := SnapshotRPC(s1.connPool, s1.config.Datacenter, s1.config.NodeName, s1.config.RPCAddr, &args, snap, &reply) if err != nil { t.Fatalf("err: %v", err) @@ -268,7 +268,7 @@ func TestSnapshot_ACLDeny(t *testing.T) { Op: structs.SnapshotSave, } var reply structs.SnapshotResponse - _, err := SnapshotRPC(s1.connPool, s1.config.Datacenter, s1.config.NodeName, s1.config.RPCAddr, false, + _, err := SnapshotRPC(s1.connPool, s1.config.Datacenter, s1.config.NodeName, s1.config.RPCAddr, &args, bytes.NewReader([]byte("")), &reply) if !acl.IsErrPermissionDenied(err) { t.Fatalf("err: %v", err) @@ -282,7 +282,7 @@ func TestSnapshot_ACLDeny(t *testing.T) { Op: structs.SnapshotRestore, } var reply structs.SnapshotResponse - _, err := SnapshotRPC(s1.connPool, s1.config.Datacenter, s1.config.NodeName, s1.config.RPCAddr, false, + _, err := SnapshotRPC(s1.connPool, s1.config.Datacenter, s1.config.NodeName, s1.config.RPCAddr, &args, bytes.NewReader([]byte("")), &reply) if !acl.IsErrPermissionDenied(err) { t.Fatalf("err: %v", err) @@ -391,7 +391,7 @@ func TestSnapshot_AllowStale(t *testing.T) { Op: structs.SnapshotSave, } var reply structs.SnapshotResponse - _, err := SnapshotRPC(s.connPool, s.config.Datacenter, s.config.NodeName, s.config.RPCAddr, false, + _, err := SnapshotRPC(s.connPool, s.config.Datacenter, s.config.NodeName, s.config.RPCAddr, &args, bytes.NewReader([]byte("")), &reply) if err == nil || !strings.Contains(err.Error(), structs.ErrNoLeader.Error()) { t.Fatalf("err: %v", err) @@ -408,7 +408,7 @@ func TestSnapshot_AllowStale(t *testing.T) { Op: structs.SnapshotSave, } var reply structs.SnapshotResponse - _, err := SnapshotRPC(s.connPool, s.config.Datacenter, s.config.NodeName, s.config.RPCAddr, false, + _, err := SnapshotRPC(s.connPool, s.config.Datacenter, s.config.NodeName, s.config.RPCAddr, &args, bytes.NewReader([]byte("")), &reply) if err == nil || !strings.Contains(err.Error(), "Raft error when taking snapshot") { t.Fatalf("err: %v", err) diff --git a/agent/consul/status_endpoint_test.go b/agent/consul/status_endpoint_test.go index a9cc158fa..4ef010830 100644 --- a/agent/consul/status_endpoint_test.go +++ b/agent/consul/status_endpoint_test.go @@ -37,20 +37,25 @@ func insecureRPCClient(s *Server, c tlsutil.Config) (rpc.ClientCodec, error) { if wrapper == nil { return nil, err } - conn, _, err := pool.DialTimeoutWithRPCTypeDirectly( - s.config.Datacenter, - s.config.NodeName, - addr, - nil, - time.Second, - true, - wrapper, - pool.RPCTLSInsecure, - pool.RPCTLSInsecure, - ) + d := &net.Dialer{Timeout: time.Second} + conn, err := d.Dial("tcp", addr.String()) if err != nil { return nil, err } + // Switch the connection into TLS mode + if _, err = conn.Write([]byte{byte(pool.RPCTLSInsecure)}); err != nil { + conn.Close() + return nil, err + } + + // Wrap the connection in a TLS client + tlsConn, err := wrapper(s.config.Datacenter, conn) + if err != nil { + conn.Close() + return nil, err + } + conn = tlsConn + return msgpackrpc.NewCodecFromHandle(true, true, conn, structs.MsgpackHandle), nil } diff --git a/agent/pool/pool.go b/agent/pool/pool.go index 16dcb7a91..6255f4f4e 100644 --- a/agent/pool/pool.go +++ b/agent/pool/pool.go @@ -146,9 +146,6 @@ type ConnPool struct { // Datacenter is the datacenter of the current agent. Datacenter string - // ForceTLS is used to enforce outgoing TLS verification - ForceTLS bool - // Server should be set to true if this connection pool is configured in a // server instead of a client. Server bool @@ -208,7 +205,7 @@ func (p *ConnPool) Shutdown() error { // wait for an existing connection attempt to finish, if one if in progress, // and will return that one if it succeeds. If all else fails, it will return a // newly-created connection and add it to the pool. -func (p *ConnPool) acquire(dc string, nodeName string, addr net.Addr, useTLS bool) (*Conn, error) { +func (p *ConnPool) acquire(dc string, nodeName string, addr net.Addr) (*Conn, error) { if nodeName == "" { return nil, fmt.Errorf("pool: ConnPool.acquire requires a node name") } @@ -243,7 +240,7 @@ func (p *ConnPool) acquire(dc string, nodeName string, addr net.Addr, useTLS boo // If we are the lead thread, make the new connection and then wake // everybody else up to see if we got it. if isLeadThread { - c, err := p.getNewConn(dc, nodeName, addr, useTLS) + c, err := p.getNewConn(dc, nodeName, addr) p.Lock() delete(p.limiter, addrStr) close(wait) @@ -290,7 +287,6 @@ func (p *ConnPool) DialTimeout( nodeName string, addr net.Addr, timeout time.Duration, - useTLS bool, actualRPCType RPCType, ) (net.Conn, HalfCloser, error) { p.once.Do(p.init) @@ -314,64 +310,26 @@ func (p *ConnPool) DialTimeout( ) } - return DialTimeoutWithRPCTypeDirectly( + return p.dial( dc, nodeName, addr, - p.SrcAddr, timeout, - useTLS || p.ForceTLS, - p.TLSConfigurator.OutgoingRPCWrapper(), actualRPCType, RPCTLS, ) } -// DialTimeoutInsecure is used to establish a raw connection to the given -// server, with given connection timeout. It also writes RPCTLSInsecure as the -// first byte to indicate that the client cannot provide a certificate. This is -// so far only used for AutoEncrypt.Sign. -func (p *ConnPool) DialTimeoutInsecure( +func (p *ConnPool) dial( dc string, nodeName string, addr net.Addr, timeout time.Duration, - wrapper tlsutil.DCWrapper, -) (net.Conn, HalfCloser, error) { - p.once.Do(p.init) - - if wrapper == nil { - return nil, nil, fmt.Errorf("wrapper cannot be nil") - } else if dc != p.Datacenter { - return nil, nil, fmt.Errorf("insecure dialing prohibited between datacenters") - } - - return DialTimeoutWithRPCTypeDirectly( - dc, - nodeName, - addr, - p.SrcAddr, - timeout, - true, - wrapper, - RPCTLSInsecure, - RPCTLSInsecure, - ) -} - -func DialTimeoutWithRPCTypeDirectly( - dc string, - nodeName string, - addr net.Addr, - src *net.TCPAddr, - timeout time.Duration, - useTLS bool, - wrapper tlsutil.DCWrapper, actualRPCType RPCType, tlsRPCType RPCType, ) (net.Conn, HalfCloser, error) { // Try to dial the conn - d := &net.Dialer{LocalAddr: src, Timeout: timeout} + d := &net.Dialer{LocalAddr: p.SrcAddr, Timeout: timeout} conn, err := d.Dial("tcp", addr.String()) if err != nil { return nil, nil, err @@ -388,7 +346,8 @@ func DialTimeoutWithRPCTypeDirectly( } // Check if TLS is enabled - if useTLS && wrapper != nil { + if p.TLSConfigurator.UseTLS(dc) { + wrapper := p.TLSConfigurator.OutgoingRPCWrapper() // Switch the connection into TLS mode if _, err := conn.Write([]byte{byte(tlsRPCType)}); err != nil { conn.Close() @@ -496,13 +455,13 @@ func DialTimeoutWithRPCTypeViaMeshGateway( } // getNewConn is used to return a new connection -func (p *ConnPool) getNewConn(dc string, nodeName string, addr net.Addr, useTLS bool) (*Conn, error) { +func (p *ConnPool) getNewConn(dc string, nodeName string, addr net.Addr) (*Conn, error) { if nodeName == "" { return nil, fmt.Errorf("pool: ConnPool.getNewConn requires a node name") } // Get a new, raw connection and write the Consul multiplex byte to set the mode - conn, _, err := p.DialTimeout(dc, nodeName, addr, defaultDialTimeout, useTLS, RPCMultiplexV2) + conn, _, err := p.DialTimeout(dc, nodeName, addr, defaultDialTimeout, RPCMultiplexV2) if err != nil { return nil, err } @@ -560,11 +519,11 @@ func (p *ConnPool) releaseConn(conn *Conn) { } // getClient is used to get a usable client for an address -func (p *ConnPool) getClient(dc string, nodeName string, addr net.Addr, useTLS bool) (*Conn, *StreamClient, error) { +func (p *ConnPool) getClient(dc string, nodeName string, addr net.Addr) (*Conn, *StreamClient, error) { retries := 0 START: // Try to get a conn first - conn, err := p.acquire(dc, nodeName, addr, useTLS) + conn, err := p.acquire(dc, nodeName, addr) if err != nil { return nil, nil, fmt.Errorf("failed to get conn: %v", err) } @@ -611,8 +570,12 @@ func (p *ConnPool) RPC( // AutoEncrypt.Sign is a one-off call and it doesn't make sense to pool that // connection if it is not being reused. func (p *ConnPool) rpcInsecure(dc string, nodeName string, addr net.Addr, method string, args interface{}, reply interface{}) error { + if dc != p.Datacenter { + return fmt.Errorf("insecure dialing prohibited between datacenters") + } + var codec rpc.ClientCodec - conn, _, err := p.DialTimeoutInsecure(dc, nodeName, addr, 1*time.Second, p.TLSConfigurator.OutgoingRPCWrapper()) + conn, _, err := p.dial(dc, nodeName, addr, 1*time.Second, 0, RPCTLSInsecure) if err != nil { return fmt.Errorf("rpcinsecure error establishing connection: %v", err) } @@ -631,8 +594,7 @@ func (p *ConnPool) rpc(dc string, nodeName string, addr net.Addr, method string, p.once.Do(p.init) // Get a usable client - useTLS := p.TLSConfigurator.UseTLS(dc) - conn, sc, err := p.getClient(dc, nodeName, addr, useTLS) + conn, sc, err := p.getClient(dc, nodeName, addr) if err != nil { return fmt.Errorf("rpc error getting client: %v", err) }