diff --git a/agent/consul/client.go b/agent/consul/client.go index c7a36293b..446ba9962 100644 --- a/agent/consul/client.go +++ b/agent/consul/client.go @@ -137,7 +137,6 @@ func NewClientLogger(config *Config, logger hclog.InterceptLogger, tlsConfigurat MaxTime: clientRPCConnMaxIdle, MaxStreams: clientMaxStreams, TLSConfigurator: tlsConfigurator, - ForceTLS: config.VerifyOutgoing, Datacenter: config.Datacenter, } @@ -356,7 +355,7 @@ func (c *Client) SnapshotRPC(args *structs.SnapshotRequest, in io.Reader, out io // Request the operation. var reply structs.SnapshotResponse - snap, err := SnapshotRPC(c.connPool, c.config.Datacenter, server.ShortName, server.Addr, server.UseTLS, args, in, &reply) + snap, err := SnapshotRPC(c.connPool, c.config.Datacenter, server.ShortName, server.Addr, args, in, &reply) if err != nil { return err } diff --git a/agent/consul/server.go b/agent/consul/server.go index fec1efd6c..eb8d1cadd 100644 --- a/agent/consul/server.go +++ b/agent/consul/server.go @@ -374,7 +374,6 @@ func NewServerLogger(config *Config, logger hclog.InterceptLogger, tokens *token MaxTime: serverRPCCache, MaxStreams: serverMaxStreams, TLSConfigurator: tlsConfigurator, - ForceTLS: config.VerifyOutgoing, Datacenter: config.Datacenter, } diff --git a/agent/consul/snapshot_endpoint.go b/agent/consul/snapshot_endpoint.go index 233354c0e..5f6e3e0f8 100644 --- a/agent/consul/snapshot_endpoint.go +++ b/agent/consul/snapshot_endpoint.go @@ -37,7 +37,7 @@ func (s *Server) dispatchSnapshotRequest(args *structs.SnapshotRequest, in io.Re return nil, structs.ErrNoDCPath } - snap, err := SnapshotRPC(s.connPool, dc, server.ShortName, server.Addr, server.UseTLS, args, in, reply) + snap, err := SnapshotRPC(s.connPool, dc, server.ShortName, server.Addr, args, in, reply) if err != nil { manager.NotifyFailedServer(server) return nil, err @@ -52,7 +52,7 @@ func (s *Server) dispatchSnapshotRequest(args *structs.SnapshotRequest, in io.Re if server == nil { return nil, structs.ErrNoLeader } - return SnapshotRPC(s.connPool, args.Datacenter, server.ShortName, server.Addr, server.UseTLS, args, in, reply) + return SnapshotRPC(s.connPool, args.Datacenter, server.ShortName, server.Addr, args, in, reply) } } @@ -194,14 +194,13 @@ func SnapshotRPC( dc string, nodeName string, addr net.Addr, - useTLS bool, args *structs.SnapshotRequest, in io.Reader, reply *structs.SnapshotResponse, ) (io.ReadCloser, error) { // Write the snapshot RPC byte to set the mode, then perform the // request. - conn, hc, err := connPool.DialTimeout(dc, nodeName, addr, 10*time.Second, useTLS, pool.RPCSnapshot) + conn, hc, err := connPool.DialTimeout(dc, nodeName, addr, 10*time.Second, pool.RPCSnapshot) if err != nil { return nil, err } diff --git a/agent/consul/snapshot_endpoint_test.go b/agent/consul/snapshot_endpoint_test.go index 9073fa01e..e0cd31a1d 100644 --- a/agent/consul/snapshot_endpoint_test.go +++ b/agent/consul/snapshot_endpoint_test.go @@ -46,7 +46,7 @@ func verifySnapshot(t *testing.T, s *Server, dc, token string) { Op: structs.SnapshotSave, } var reply structs.SnapshotResponse - snap, err := SnapshotRPC(s.connPool, s.config.Datacenter, s.config.NodeName, s.config.RPCAddr, false, + snap, err := SnapshotRPC(s.connPool, s.config.Datacenter, s.config.NodeName, s.config.RPCAddr, &args, bytes.NewReader([]byte("")), &reply) if err != nil { t.Fatalf("err: %v", err) @@ -121,7 +121,7 @@ func verifySnapshot(t *testing.T, s *Server, dc, token string) { // Restore the snapshot. args.Op = structs.SnapshotRestore - restore, err := SnapshotRPC(s.connPool, s.config.Datacenter, s.config.NodeName, s.config.RPCAddr, false, + restore, err := SnapshotRPC(s.connPool, s.config.Datacenter, s.config.NodeName, s.config.RPCAddr, &args, snap, &reply) if err != nil { t.Fatalf("err: %v", err) @@ -196,7 +196,7 @@ func TestSnapshot_LeaderState(t *testing.T) { Op: structs.SnapshotSave, } var reply structs.SnapshotResponse - snap, err := SnapshotRPC(s1.connPool, s1.config.Datacenter, s1.config.NodeName, s1.config.RPCAddr, false, + snap, err := SnapshotRPC(s1.connPool, s1.config.Datacenter, s1.config.NodeName, s1.config.RPCAddr, &args, bytes.NewReader([]byte("")), &reply) if err != nil { t.Fatalf("err: %v", err) @@ -229,7 +229,7 @@ func TestSnapshot_LeaderState(t *testing.T) { // Restore the snapshot. args.Op = structs.SnapshotRestore - restore, err := SnapshotRPC(s1.connPool, s1.config.Datacenter, s1.config.NodeName, s1.config.RPCAddr, false, + restore, err := SnapshotRPC(s1.connPool, s1.config.Datacenter, s1.config.NodeName, s1.config.RPCAddr, &args, snap, &reply) if err != nil { t.Fatalf("err: %v", err) @@ -268,7 +268,7 @@ func TestSnapshot_ACLDeny(t *testing.T) { Op: structs.SnapshotSave, } var reply structs.SnapshotResponse - _, err := SnapshotRPC(s1.connPool, s1.config.Datacenter, s1.config.NodeName, s1.config.RPCAddr, false, + _, err := SnapshotRPC(s1.connPool, s1.config.Datacenter, s1.config.NodeName, s1.config.RPCAddr, &args, bytes.NewReader([]byte("")), &reply) if !acl.IsErrPermissionDenied(err) { t.Fatalf("err: %v", err) @@ -282,7 +282,7 @@ func TestSnapshot_ACLDeny(t *testing.T) { Op: structs.SnapshotRestore, } var reply structs.SnapshotResponse - _, err := SnapshotRPC(s1.connPool, s1.config.Datacenter, s1.config.NodeName, s1.config.RPCAddr, false, + _, err := SnapshotRPC(s1.connPool, s1.config.Datacenter, s1.config.NodeName, s1.config.RPCAddr, &args, bytes.NewReader([]byte("")), &reply) if !acl.IsErrPermissionDenied(err) { t.Fatalf("err: %v", err) @@ -391,7 +391,7 @@ func TestSnapshot_AllowStale(t *testing.T) { Op: structs.SnapshotSave, } var reply structs.SnapshotResponse - _, err := SnapshotRPC(s.connPool, s.config.Datacenter, s.config.NodeName, s.config.RPCAddr, false, + _, err := SnapshotRPC(s.connPool, s.config.Datacenter, s.config.NodeName, s.config.RPCAddr, &args, bytes.NewReader([]byte("")), &reply) if err == nil || !strings.Contains(err.Error(), structs.ErrNoLeader.Error()) { t.Fatalf("err: %v", err) @@ -408,7 +408,7 @@ func TestSnapshot_AllowStale(t *testing.T) { Op: structs.SnapshotSave, } var reply structs.SnapshotResponse - _, err := SnapshotRPC(s.connPool, s.config.Datacenter, s.config.NodeName, s.config.RPCAddr, false, + _, err := SnapshotRPC(s.connPool, s.config.Datacenter, s.config.NodeName, s.config.RPCAddr, &args, bytes.NewReader([]byte("")), &reply) if err == nil || !strings.Contains(err.Error(), "Raft error when taking snapshot") { t.Fatalf("err: %v", err) diff --git a/agent/consul/status_endpoint_test.go b/agent/consul/status_endpoint_test.go index a9cc158fa..4ef010830 100644 --- a/agent/consul/status_endpoint_test.go +++ b/agent/consul/status_endpoint_test.go @@ -37,20 +37,25 @@ func insecureRPCClient(s *Server, c tlsutil.Config) (rpc.ClientCodec, error) { if wrapper == nil { return nil, err } - conn, _, err := pool.DialTimeoutWithRPCTypeDirectly( - s.config.Datacenter, - s.config.NodeName, - addr, - nil, - time.Second, - true, - wrapper, - pool.RPCTLSInsecure, - pool.RPCTLSInsecure, - ) + d := &net.Dialer{Timeout: time.Second} + conn, err := d.Dial("tcp", addr.String()) if err != nil { return nil, err } + // Switch the connection into TLS mode + if _, err = conn.Write([]byte{byte(pool.RPCTLSInsecure)}); err != nil { + conn.Close() + return nil, err + } + + // Wrap the connection in a TLS client + tlsConn, err := wrapper(s.config.Datacenter, conn) + if err != nil { + conn.Close() + return nil, err + } + conn = tlsConn + return msgpackrpc.NewCodecFromHandle(true, true, conn, structs.MsgpackHandle), nil } diff --git a/agent/pool/pool.go b/agent/pool/pool.go index 16dcb7a91..6255f4f4e 100644 --- a/agent/pool/pool.go +++ b/agent/pool/pool.go @@ -146,9 +146,6 @@ type ConnPool struct { // Datacenter is the datacenter of the current agent. Datacenter string - // ForceTLS is used to enforce outgoing TLS verification - ForceTLS bool - // Server should be set to true if this connection pool is configured in a // server instead of a client. Server bool @@ -208,7 +205,7 @@ func (p *ConnPool) Shutdown() error { // wait for an existing connection attempt to finish, if one if in progress, // and will return that one if it succeeds. If all else fails, it will return a // newly-created connection and add it to the pool. -func (p *ConnPool) acquire(dc string, nodeName string, addr net.Addr, useTLS bool) (*Conn, error) { +func (p *ConnPool) acquire(dc string, nodeName string, addr net.Addr) (*Conn, error) { if nodeName == "" { return nil, fmt.Errorf("pool: ConnPool.acquire requires a node name") } @@ -243,7 +240,7 @@ func (p *ConnPool) acquire(dc string, nodeName string, addr net.Addr, useTLS boo // If we are the lead thread, make the new connection and then wake // everybody else up to see if we got it. if isLeadThread { - c, err := p.getNewConn(dc, nodeName, addr, useTLS) + c, err := p.getNewConn(dc, nodeName, addr) p.Lock() delete(p.limiter, addrStr) close(wait) @@ -290,7 +287,6 @@ func (p *ConnPool) DialTimeout( nodeName string, addr net.Addr, timeout time.Duration, - useTLS bool, actualRPCType RPCType, ) (net.Conn, HalfCloser, error) { p.once.Do(p.init) @@ -314,64 +310,26 @@ func (p *ConnPool) DialTimeout( ) } - return DialTimeoutWithRPCTypeDirectly( + return p.dial( dc, nodeName, addr, - p.SrcAddr, timeout, - useTLS || p.ForceTLS, - p.TLSConfigurator.OutgoingRPCWrapper(), actualRPCType, RPCTLS, ) } -// DialTimeoutInsecure is used to establish a raw connection to the given -// server, with given connection timeout. It also writes RPCTLSInsecure as the -// first byte to indicate that the client cannot provide a certificate. This is -// so far only used for AutoEncrypt.Sign. -func (p *ConnPool) DialTimeoutInsecure( +func (p *ConnPool) dial( dc string, nodeName string, addr net.Addr, timeout time.Duration, - wrapper tlsutil.DCWrapper, -) (net.Conn, HalfCloser, error) { - p.once.Do(p.init) - - if wrapper == nil { - return nil, nil, fmt.Errorf("wrapper cannot be nil") - } else if dc != p.Datacenter { - return nil, nil, fmt.Errorf("insecure dialing prohibited between datacenters") - } - - return DialTimeoutWithRPCTypeDirectly( - dc, - nodeName, - addr, - p.SrcAddr, - timeout, - true, - wrapper, - RPCTLSInsecure, - RPCTLSInsecure, - ) -} - -func DialTimeoutWithRPCTypeDirectly( - dc string, - nodeName string, - addr net.Addr, - src *net.TCPAddr, - timeout time.Duration, - useTLS bool, - wrapper tlsutil.DCWrapper, actualRPCType RPCType, tlsRPCType RPCType, ) (net.Conn, HalfCloser, error) { // Try to dial the conn - d := &net.Dialer{LocalAddr: src, Timeout: timeout} + d := &net.Dialer{LocalAddr: p.SrcAddr, Timeout: timeout} conn, err := d.Dial("tcp", addr.String()) if err != nil { return nil, nil, err @@ -388,7 +346,8 @@ func DialTimeoutWithRPCTypeDirectly( } // Check if TLS is enabled - if useTLS && wrapper != nil { + if p.TLSConfigurator.UseTLS(dc) { + wrapper := p.TLSConfigurator.OutgoingRPCWrapper() // Switch the connection into TLS mode if _, err := conn.Write([]byte{byte(tlsRPCType)}); err != nil { conn.Close() @@ -496,13 +455,13 @@ func DialTimeoutWithRPCTypeViaMeshGateway( } // getNewConn is used to return a new connection -func (p *ConnPool) getNewConn(dc string, nodeName string, addr net.Addr, useTLS bool) (*Conn, error) { +func (p *ConnPool) getNewConn(dc string, nodeName string, addr net.Addr) (*Conn, error) { if nodeName == "" { return nil, fmt.Errorf("pool: ConnPool.getNewConn requires a node name") } // Get a new, raw connection and write the Consul multiplex byte to set the mode - conn, _, err := p.DialTimeout(dc, nodeName, addr, defaultDialTimeout, useTLS, RPCMultiplexV2) + conn, _, err := p.DialTimeout(dc, nodeName, addr, defaultDialTimeout, RPCMultiplexV2) if err != nil { return nil, err } @@ -560,11 +519,11 @@ func (p *ConnPool) releaseConn(conn *Conn) { } // getClient is used to get a usable client for an address -func (p *ConnPool) getClient(dc string, nodeName string, addr net.Addr, useTLS bool) (*Conn, *StreamClient, error) { +func (p *ConnPool) getClient(dc string, nodeName string, addr net.Addr) (*Conn, *StreamClient, error) { retries := 0 START: // Try to get a conn first - conn, err := p.acquire(dc, nodeName, addr, useTLS) + conn, err := p.acquire(dc, nodeName, addr) if err != nil { return nil, nil, fmt.Errorf("failed to get conn: %v", err) } @@ -611,8 +570,12 @@ func (p *ConnPool) RPC( // AutoEncrypt.Sign is a one-off call and it doesn't make sense to pool that // connection if it is not being reused. func (p *ConnPool) rpcInsecure(dc string, nodeName string, addr net.Addr, method string, args interface{}, reply interface{}) error { + if dc != p.Datacenter { + return fmt.Errorf("insecure dialing prohibited between datacenters") + } + var codec rpc.ClientCodec - conn, _, err := p.DialTimeoutInsecure(dc, nodeName, addr, 1*time.Second, p.TLSConfigurator.OutgoingRPCWrapper()) + conn, _, err := p.dial(dc, nodeName, addr, 1*time.Second, 0, RPCTLSInsecure) if err != nil { return fmt.Errorf("rpcinsecure error establishing connection: %v", err) } @@ -631,8 +594,7 @@ func (p *ConnPool) rpc(dc string, nodeName string, addr net.Addr, method string, p.once.Do(p.init) // Get a usable client - useTLS := p.TLSConfigurator.UseTLS(dc) - conn, sc, err := p.getClient(dc, nodeName, addr, useTLS) + conn, sc, err := p.getClient(dc, nodeName, addr) if err != nil { return fmt.Errorf("rpc error getting client: %v", err) }