diff --git a/tlsutil/config.go b/tlsutil/config.go index 9b7835d95..6fcdb1a2e 100644 --- a/tlsutil/config.go +++ b/tlsutil/config.go @@ -537,21 +537,12 @@ func (c *Configurator) VerifyIncomingRPC() bool { } // This function acquires a read lock because it reads from the config. -func (c *Configurator) outgoingRPCTLSDisabled() bool { +func (c *Configurator) outgoingRPCTLSEnabled() bool { c.lock.RLock() defer c.lock.RUnlock() - // if AutoEncrypt enabled, always use TLS - if c.base.AutoTLS { - return false - } - - // if CAs are provided or VerifyOutgoing is set, use TLS - if c.base.VerifyOutgoing { - return false - } - - return true + // use TLS if AutoEncrypt or VerifyOutgoing are enabled. + return c.base.AutoTLS || c.base.VerifyOutgoing } // MutualTLSCapable returns true if Configurator has a CA and a local TLS @@ -716,7 +707,7 @@ func (c *Configurator) OutgoingTLSConfigForCheck(skipVerify bool, serverName str // otherwise we assume that no TLS should be used. func (c *Configurator) OutgoingRPCConfig() *tls.Config { c.log("OutgoingRPCConfig") - if c.outgoingRPCTLSDisabled() { + if !c.outgoingRPCTLSEnabled() { return nil } return c.commonTLSConfig(false) @@ -754,8 +745,10 @@ func (c *Configurator) OutgoingRPCWrapper() DCWrapper { } } +// UseTLS returns true if the outgoing RPC requests have been explicitly configured +// to use TLS (via VerifyOutgoing or AutoTLS, and the target DC supports TLS. func (c *Configurator) UseTLS(dc string) bool { - return !c.outgoingRPCTLSDisabled() && c.getAreaForPeerDatacenterUseTLS(dc) + return c.outgoingRPCTLSEnabled() && c.getAreaForPeerDatacenterUseTLS(dc) } // OutgoingALPNRPCWrapper wraps the result of outgoingALPNRPCConfig in an diff --git a/tlsutil/config_test.go b/tlsutil/config_test.go index d1586bdaa..d0b8b9d2b 100644 --- a/tlsutil/config_test.go +++ b/tlsutil/config_test.go @@ -741,22 +741,21 @@ func TestConfigurator_OutgoingRPCTLSDisabled(t *testing.T) { expected bool } variants := []variant{ - {false, false, nil, true}, - {true, false, nil, false}, - {false, true, nil, false}, - {true, true, nil, false}, + {false, false, nil, false}, + {true, false, nil, true}, + {false, true, nil, true}, + {true, true, nil, true}, - // {false, false, &x509.CertPool{}, false}, - {true, false, &x509.CertPool{}, false}, - {false, true, &x509.CertPool{}, false}, - {true, true, &x509.CertPool{}, false}, + {true, false, &x509.CertPool{}, true}, + {false, true, &x509.CertPool{}, true}, + {true, true, &x509.CertPool{}, true}, } for i, v := range variants { info := fmt.Sprintf("case %d", i) c.caPool = v.pool c.base.VerifyOutgoing = v.verify c.base.AutoTLS = v.autoEncryptTLS - require.Equal(t, v.expected, c.outgoingRPCTLSDisabled(), info) + require.Equal(t, v.expected, c.outgoingRPCTLSEnabled(), info) } }