pool: remove useTLS and ForceTLS

In the past TLS usage was enforced with these variables, but these days
this decision is made by TLSConfigurator and there is no reason to keep
using the variables.
This commit is contained in:
Hans Hasselberg 2020-05-28 10:18:30 +02:00
parent 9ef44ec3da
commit 5cda505495
6 changed files with 45 additions and 81 deletions

View File

@ -137,7 +137,6 @@ func NewClientLogger(config *Config, logger hclog.InterceptLogger, tlsConfigurat
MaxTime: clientRPCConnMaxIdle, MaxTime: clientRPCConnMaxIdle,
MaxStreams: clientMaxStreams, MaxStreams: clientMaxStreams,
TLSConfigurator: tlsConfigurator, TLSConfigurator: tlsConfigurator,
ForceTLS: config.VerifyOutgoing,
Datacenter: config.Datacenter, Datacenter: config.Datacenter,
} }
@ -356,7 +355,7 @@ func (c *Client) SnapshotRPC(args *structs.SnapshotRequest, in io.Reader, out io
// Request the operation. // Request the operation.
var reply structs.SnapshotResponse var reply structs.SnapshotResponse
snap, err := SnapshotRPC(c.connPool, c.config.Datacenter, server.ShortName, server.Addr, server.UseTLS, args, in, &reply) snap, err := SnapshotRPC(c.connPool, c.config.Datacenter, server.ShortName, server.Addr, args, in, &reply)
if err != nil { if err != nil {
return err return err
} }

View File

@ -374,7 +374,6 @@ func NewServerLogger(config *Config, logger hclog.InterceptLogger, tokens *token
MaxTime: serverRPCCache, MaxTime: serverRPCCache,
MaxStreams: serverMaxStreams, MaxStreams: serverMaxStreams,
TLSConfigurator: tlsConfigurator, TLSConfigurator: tlsConfigurator,
ForceTLS: config.VerifyOutgoing,
Datacenter: config.Datacenter, Datacenter: config.Datacenter,
} }

View File

@ -37,7 +37,7 @@ func (s *Server) dispatchSnapshotRequest(args *structs.SnapshotRequest, in io.Re
return nil, structs.ErrNoDCPath return nil, structs.ErrNoDCPath
} }
snap, err := SnapshotRPC(s.connPool, dc, server.ShortName, server.Addr, server.UseTLS, args, in, reply) snap, err := SnapshotRPC(s.connPool, dc, server.ShortName, server.Addr, args, in, reply)
if err != nil { if err != nil {
manager.NotifyFailedServer(server) manager.NotifyFailedServer(server)
return nil, err return nil, err
@ -52,7 +52,7 @@ func (s *Server) dispatchSnapshotRequest(args *structs.SnapshotRequest, in io.Re
if server == nil { if server == nil {
return nil, structs.ErrNoLeader return nil, structs.ErrNoLeader
} }
return SnapshotRPC(s.connPool, args.Datacenter, server.ShortName, server.Addr, server.UseTLS, args, in, reply) return SnapshotRPC(s.connPool, args.Datacenter, server.ShortName, server.Addr, args, in, reply)
} }
} }
@ -194,14 +194,13 @@ func SnapshotRPC(
dc string, dc string,
nodeName string, nodeName string,
addr net.Addr, addr net.Addr,
useTLS bool,
args *structs.SnapshotRequest, args *structs.SnapshotRequest,
in io.Reader, in io.Reader,
reply *structs.SnapshotResponse, reply *structs.SnapshotResponse,
) (io.ReadCloser, error) { ) (io.ReadCloser, error) {
// Write the snapshot RPC byte to set the mode, then perform the // Write the snapshot RPC byte to set the mode, then perform the
// request. // request.
conn, hc, err := connPool.DialTimeout(dc, nodeName, addr, 10*time.Second, useTLS, pool.RPCSnapshot) conn, hc, err := connPool.DialTimeout(dc, nodeName, addr, 10*time.Second, pool.RPCSnapshot)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -46,7 +46,7 @@ func verifySnapshot(t *testing.T, s *Server, dc, token string) {
Op: structs.SnapshotSave, Op: structs.SnapshotSave,
} }
var reply structs.SnapshotResponse var reply structs.SnapshotResponse
snap, err := SnapshotRPC(s.connPool, s.config.Datacenter, s.config.NodeName, s.config.RPCAddr, false, snap, err := SnapshotRPC(s.connPool, s.config.Datacenter, s.config.NodeName, s.config.RPCAddr,
&args, bytes.NewReader([]byte("")), &reply) &args, bytes.NewReader([]byte("")), &reply)
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
@ -121,7 +121,7 @@ func verifySnapshot(t *testing.T, s *Server, dc, token string) {
// Restore the snapshot. // Restore the snapshot.
args.Op = structs.SnapshotRestore args.Op = structs.SnapshotRestore
restore, err := SnapshotRPC(s.connPool, s.config.Datacenter, s.config.NodeName, s.config.RPCAddr, false, restore, err := SnapshotRPC(s.connPool, s.config.Datacenter, s.config.NodeName, s.config.RPCAddr,
&args, snap, &reply) &args, snap, &reply)
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
@ -196,7 +196,7 @@ func TestSnapshot_LeaderState(t *testing.T) {
Op: structs.SnapshotSave, Op: structs.SnapshotSave,
} }
var reply structs.SnapshotResponse var reply structs.SnapshotResponse
snap, err := SnapshotRPC(s1.connPool, s1.config.Datacenter, s1.config.NodeName, s1.config.RPCAddr, false, snap, err := SnapshotRPC(s1.connPool, s1.config.Datacenter, s1.config.NodeName, s1.config.RPCAddr,
&args, bytes.NewReader([]byte("")), &reply) &args, bytes.NewReader([]byte("")), &reply)
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
@ -229,7 +229,7 @@ func TestSnapshot_LeaderState(t *testing.T) {
// Restore the snapshot. // Restore the snapshot.
args.Op = structs.SnapshotRestore args.Op = structs.SnapshotRestore
restore, err := SnapshotRPC(s1.connPool, s1.config.Datacenter, s1.config.NodeName, s1.config.RPCAddr, false, restore, err := SnapshotRPC(s1.connPool, s1.config.Datacenter, s1.config.NodeName, s1.config.RPCAddr,
&args, snap, &reply) &args, snap, &reply)
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
@ -268,7 +268,7 @@ func TestSnapshot_ACLDeny(t *testing.T) {
Op: structs.SnapshotSave, Op: structs.SnapshotSave,
} }
var reply structs.SnapshotResponse var reply structs.SnapshotResponse
_, err := SnapshotRPC(s1.connPool, s1.config.Datacenter, s1.config.NodeName, s1.config.RPCAddr, false, _, err := SnapshotRPC(s1.connPool, s1.config.Datacenter, s1.config.NodeName, s1.config.RPCAddr,
&args, bytes.NewReader([]byte("")), &reply) &args, bytes.NewReader([]byte("")), &reply)
if !acl.IsErrPermissionDenied(err) { if !acl.IsErrPermissionDenied(err) {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
@ -282,7 +282,7 @@ func TestSnapshot_ACLDeny(t *testing.T) {
Op: structs.SnapshotRestore, Op: structs.SnapshotRestore,
} }
var reply structs.SnapshotResponse var reply structs.SnapshotResponse
_, err := SnapshotRPC(s1.connPool, s1.config.Datacenter, s1.config.NodeName, s1.config.RPCAddr, false, _, err := SnapshotRPC(s1.connPool, s1.config.Datacenter, s1.config.NodeName, s1.config.RPCAddr,
&args, bytes.NewReader([]byte("")), &reply) &args, bytes.NewReader([]byte("")), &reply)
if !acl.IsErrPermissionDenied(err) { if !acl.IsErrPermissionDenied(err) {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
@ -391,7 +391,7 @@ func TestSnapshot_AllowStale(t *testing.T) {
Op: structs.SnapshotSave, Op: structs.SnapshotSave,
} }
var reply structs.SnapshotResponse var reply structs.SnapshotResponse
_, err := SnapshotRPC(s.connPool, s.config.Datacenter, s.config.NodeName, s.config.RPCAddr, false, _, err := SnapshotRPC(s.connPool, s.config.Datacenter, s.config.NodeName, s.config.RPCAddr,
&args, bytes.NewReader([]byte("")), &reply) &args, bytes.NewReader([]byte("")), &reply)
if err == nil || !strings.Contains(err.Error(), structs.ErrNoLeader.Error()) { if err == nil || !strings.Contains(err.Error(), structs.ErrNoLeader.Error()) {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
@ -408,7 +408,7 @@ func TestSnapshot_AllowStale(t *testing.T) {
Op: structs.SnapshotSave, Op: structs.SnapshotSave,
} }
var reply structs.SnapshotResponse var reply structs.SnapshotResponse
_, err := SnapshotRPC(s.connPool, s.config.Datacenter, s.config.NodeName, s.config.RPCAddr, false, _, err := SnapshotRPC(s.connPool, s.config.Datacenter, s.config.NodeName, s.config.RPCAddr,
&args, bytes.NewReader([]byte("")), &reply) &args, bytes.NewReader([]byte("")), &reply)
if err == nil || !strings.Contains(err.Error(), "Raft error when taking snapshot") { if err == nil || !strings.Contains(err.Error(), "Raft error when taking snapshot") {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)

View File

@ -37,20 +37,25 @@ func insecureRPCClient(s *Server, c tlsutil.Config) (rpc.ClientCodec, error) {
if wrapper == nil { if wrapper == nil {
return nil, err return nil, err
} }
conn, _, err := pool.DialTimeoutWithRPCTypeDirectly( d := &net.Dialer{Timeout: time.Second}
s.config.Datacenter, conn, err := d.Dial("tcp", addr.String())
s.config.NodeName,
addr,
nil,
time.Second,
true,
wrapper,
pool.RPCTLSInsecure,
pool.RPCTLSInsecure,
)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// Switch the connection into TLS mode
if _, err = conn.Write([]byte{byte(pool.RPCTLSInsecure)}); err != nil {
conn.Close()
return nil, err
}
// Wrap the connection in a TLS client
tlsConn, err := wrapper(s.config.Datacenter, conn)
if err != nil {
conn.Close()
return nil, err
}
conn = tlsConn
return msgpackrpc.NewCodecFromHandle(true, true, conn, structs.MsgpackHandle), nil return msgpackrpc.NewCodecFromHandle(true, true, conn, structs.MsgpackHandle), nil
} }

View File

@ -146,9 +146,6 @@ type ConnPool struct {
// Datacenter is the datacenter of the current agent. // Datacenter is the datacenter of the current agent.
Datacenter string Datacenter string
// ForceTLS is used to enforce outgoing TLS verification
ForceTLS bool
// Server should be set to true if this connection pool is configured in a // Server should be set to true if this connection pool is configured in a
// server instead of a client. // server instead of a client.
Server bool Server bool
@ -208,7 +205,7 @@ func (p *ConnPool) Shutdown() error {
// wait for an existing connection attempt to finish, if one if in progress, // wait for an existing connection attempt to finish, if one if in progress,
// and will return that one if it succeeds. If all else fails, it will return a // and will return that one if it succeeds. If all else fails, it will return a
// newly-created connection and add it to the pool. // newly-created connection and add it to the pool.
func (p *ConnPool) acquire(dc string, nodeName string, addr net.Addr, useTLS bool) (*Conn, error) { func (p *ConnPool) acquire(dc string, nodeName string, addr net.Addr) (*Conn, error) {
if nodeName == "" { if nodeName == "" {
return nil, fmt.Errorf("pool: ConnPool.acquire requires a node name") return nil, fmt.Errorf("pool: ConnPool.acquire requires a node name")
} }
@ -243,7 +240,7 @@ func (p *ConnPool) acquire(dc string, nodeName string, addr net.Addr, useTLS boo
// If we are the lead thread, make the new connection and then wake // If we are the lead thread, make the new connection and then wake
// everybody else up to see if we got it. // everybody else up to see if we got it.
if isLeadThread { if isLeadThread {
c, err := p.getNewConn(dc, nodeName, addr, useTLS) c, err := p.getNewConn(dc, nodeName, addr)
p.Lock() p.Lock()
delete(p.limiter, addrStr) delete(p.limiter, addrStr)
close(wait) close(wait)
@ -290,7 +287,6 @@ func (p *ConnPool) DialTimeout(
nodeName string, nodeName string,
addr net.Addr, addr net.Addr,
timeout time.Duration, timeout time.Duration,
useTLS bool,
actualRPCType RPCType, actualRPCType RPCType,
) (net.Conn, HalfCloser, error) { ) (net.Conn, HalfCloser, error) {
p.once.Do(p.init) p.once.Do(p.init)
@ -314,64 +310,26 @@ func (p *ConnPool) DialTimeout(
) )
} }
return DialTimeoutWithRPCTypeDirectly( return p.dial(
dc, dc,
nodeName, nodeName,
addr, addr,
p.SrcAddr,
timeout, timeout,
useTLS || p.ForceTLS,
p.TLSConfigurator.OutgoingRPCWrapper(),
actualRPCType, actualRPCType,
RPCTLS, RPCTLS,
) )
} }
// DialTimeoutInsecure is used to establish a raw connection to the given func (p *ConnPool) dial(
// server, with given connection timeout. It also writes RPCTLSInsecure as the
// first byte to indicate that the client cannot provide a certificate. This is
// so far only used for AutoEncrypt.Sign.
func (p *ConnPool) DialTimeoutInsecure(
dc string, dc string,
nodeName string, nodeName string,
addr net.Addr, addr net.Addr,
timeout time.Duration, timeout time.Duration,
wrapper tlsutil.DCWrapper,
) (net.Conn, HalfCloser, error) {
p.once.Do(p.init)
if wrapper == nil {
return nil, nil, fmt.Errorf("wrapper cannot be nil")
} else if dc != p.Datacenter {
return nil, nil, fmt.Errorf("insecure dialing prohibited between datacenters")
}
return DialTimeoutWithRPCTypeDirectly(
dc,
nodeName,
addr,
p.SrcAddr,
timeout,
true,
wrapper,
RPCTLSInsecure,
RPCTLSInsecure,
)
}
func DialTimeoutWithRPCTypeDirectly(
dc string,
nodeName string,
addr net.Addr,
src *net.TCPAddr,
timeout time.Duration,
useTLS bool,
wrapper tlsutil.DCWrapper,
actualRPCType RPCType, actualRPCType RPCType,
tlsRPCType RPCType, tlsRPCType RPCType,
) (net.Conn, HalfCloser, error) { ) (net.Conn, HalfCloser, error) {
// Try to dial the conn // Try to dial the conn
d := &net.Dialer{LocalAddr: src, Timeout: timeout} d := &net.Dialer{LocalAddr: p.SrcAddr, Timeout: timeout}
conn, err := d.Dial("tcp", addr.String()) conn, err := d.Dial("tcp", addr.String())
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
@ -388,7 +346,8 @@ func DialTimeoutWithRPCTypeDirectly(
} }
// Check if TLS is enabled // Check if TLS is enabled
if useTLS && wrapper != nil { if p.TLSConfigurator.UseTLS(dc) {
wrapper := p.TLSConfigurator.OutgoingRPCWrapper()
// Switch the connection into TLS mode // Switch the connection into TLS mode
if _, err := conn.Write([]byte{byte(tlsRPCType)}); err != nil { if _, err := conn.Write([]byte{byte(tlsRPCType)}); err != nil {
conn.Close() conn.Close()
@ -496,13 +455,13 @@ func DialTimeoutWithRPCTypeViaMeshGateway(
} }
// getNewConn is used to return a new connection // getNewConn is used to return a new connection
func (p *ConnPool) getNewConn(dc string, nodeName string, addr net.Addr, useTLS bool) (*Conn, error) { func (p *ConnPool) getNewConn(dc string, nodeName string, addr net.Addr) (*Conn, error) {
if nodeName == "" { if nodeName == "" {
return nil, fmt.Errorf("pool: ConnPool.getNewConn requires a node name") return nil, fmt.Errorf("pool: ConnPool.getNewConn requires a node name")
} }
// Get a new, raw connection and write the Consul multiplex byte to set the mode // Get a new, raw connection and write the Consul multiplex byte to set the mode
conn, _, err := p.DialTimeout(dc, nodeName, addr, defaultDialTimeout, useTLS, RPCMultiplexV2) conn, _, err := p.DialTimeout(dc, nodeName, addr, defaultDialTimeout, RPCMultiplexV2)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -560,11 +519,11 @@ func (p *ConnPool) releaseConn(conn *Conn) {
} }
// getClient is used to get a usable client for an address // getClient is used to get a usable client for an address
func (p *ConnPool) getClient(dc string, nodeName string, addr net.Addr, useTLS bool) (*Conn, *StreamClient, error) { func (p *ConnPool) getClient(dc string, nodeName string, addr net.Addr) (*Conn, *StreamClient, error) {
retries := 0 retries := 0
START: START:
// Try to get a conn first // Try to get a conn first
conn, err := p.acquire(dc, nodeName, addr, useTLS) conn, err := p.acquire(dc, nodeName, addr)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("failed to get conn: %v", err) return nil, nil, fmt.Errorf("failed to get conn: %v", err)
} }
@ -611,8 +570,12 @@ func (p *ConnPool) RPC(
// AutoEncrypt.Sign is a one-off call and it doesn't make sense to pool that // AutoEncrypt.Sign is a one-off call and it doesn't make sense to pool that
// connection if it is not being reused. // connection if it is not being reused.
func (p *ConnPool) rpcInsecure(dc string, nodeName string, addr net.Addr, method string, args interface{}, reply interface{}) error { func (p *ConnPool) rpcInsecure(dc string, nodeName string, addr net.Addr, method string, args interface{}, reply interface{}) error {
if dc != p.Datacenter {
return fmt.Errorf("insecure dialing prohibited between datacenters")
}
var codec rpc.ClientCodec var codec rpc.ClientCodec
conn, _, err := p.DialTimeoutInsecure(dc, nodeName, addr, 1*time.Second, p.TLSConfigurator.OutgoingRPCWrapper()) conn, _, err := p.dial(dc, nodeName, addr, 1*time.Second, 0, RPCTLSInsecure)
if err != nil { if err != nil {
return fmt.Errorf("rpcinsecure error establishing connection: %v", err) return fmt.Errorf("rpcinsecure error establishing connection: %v", err)
} }
@ -631,8 +594,7 @@ func (p *ConnPool) rpc(dc string, nodeName string, addr net.Addr, method string,
p.once.Do(p.init) p.once.Do(p.init)
// Get a usable client // Get a usable client
useTLS := p.TLSConfigurator.UseTLS(dc) conn, sc, err := p.getClient(dc, nodeName, addr)
conn, sc, err := p.getClient(dc, nodeName, addr, useTLS)
if err != nil { if err != nil {
return fmt.Errorf("rpc error getting client: %v", err) return fmt.Errorf("rpc error getting client: %v", err)
} }