tlsutil: inline verifyIncomingHTTPS

This function was only used in one place, and the indirection makes it slightly
harder to see what the one caller is doing. Since it's only accesing a couple fields
it seems like the logic can exist in the one caller.
This commit is contained in:
Daniel Nephin 2021-06-24 14:45:52 -04:00
parent 7342c7e977
commit 77dde1df38
2 changed files with 52 additions and 23 deletions

View File

@ -594,13 +594,6 @@ func (c *Configurator) domain() string {
return c.base.Domain return c.base.Domain
} }
// This function acquires a read lock because it reads from the config.
func (c *Configurator) verifyIncomingHTTPS() bool {
c.lock.RLock()
defer c.lock.RUnlock()
return c.base.VerifyIncoming || c.base.VerifyIncomingHTTPS
}
// This function acquires a read lock because it reads from the config. // This function acquires a read lock because it reads from the config.
func (c *Configurator) serverNameOrNodeName() string { func (c *Configurator) serverNameOrNodeName() string {
c.lock.RLock() c.lock.RLock()
@ -677,7 +670,12 @@ func (c *Configurator) IncomingInsecureRPCConfig() *tls.Config {
// IncomingHTTPSConfig generates a *tls.Config for incoming HTTPS connections. // IncomingHTTPSConfig generates a *tls.Config for incoming HTTPS connections.
func (c *Configurator) IncomingHTTPSConfig() *tls.Config { func (c *Configurator) IncomingHTTPSConfig() *tls.Config {
c.log("IncomingHTTPSConfig") c.log("IncomingHTTPSConfig")
config := c.commonTLSConfig(c.verifyIncomingHTTPS())
c.lock.RLock()
verifyIncoming := c.base.VerifyIncoming || c.base.VerifyIncomingHTTPS
c.lock.RUnlock()
config := c.commonTLSConfig(verifyIncoming)
config.NextProtos = []string{"h2", "http/1.1"} config.NextProtos = []string{"h2", "http/1.1"}
config.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) { config.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) {
return c.IncomingHTTPSConfig(), nil return c.IncomingHTTPSConfig(), nil

View File

@ -850,14 +850,6 @@ func TestConfigurator_VerifyIncomingRPC(t *testing.T) {
require.Equal(t, c.base.VerifyIncomingRPC, verify) require.Equal(t, c.base.VerifyIncomingRPC, verify)
} }
func TestConfigurator_VerifyIncomingHTTPS(t *testing.T) {
c := Configurator{base: &Config{
VerifyIncomingHTTPS: true,
}}
verify := c.verifyIncomingHTTPS()
require.Equal(t, c.base.VerifyIncomingHTTPS, verify)
}
func TestConfigurator_IncomingRPCConfig(t *testing.T) { func TestConfigurator_IncomingRPCConfig(t *testing.T) {
c, err := NewConfigurator(Config{ c, err := NewConfigurator(Config{
VerifyIncomingRPC: true, VerifyIncomingRPC: true,
@ -903,8 +895,52 @@ func TestConfigurator_IncomingALPNRPCConfig(t *testing.T) {
} }
func TestConfigurator_IncomingHTTPSConfig(t *testing.T) { func TestConfigurator_IncomingHTTPSConfig(t *testing.T) {
c := Configurator{base: &Config{}}
require.Equal(t, []string{"h2", "http/1.1"}, c.IncomingHTTPSConfig().NextProtos) // compare tls.Config.GetConfigForClient by nil/not-nil, since Go can not compare
// functions any other way.
cmpClientFunc := cmp.Comparer(func(x, y func(*tls.ClientHelloInfo) (*tls.Config, error)) bool {
return (x == nil && y == nil) || (x != nil && y != nil)
})
t.Run("default", func(t *testing.T) {
c, err := NewConfigurator(Config{}, nil)
require.NoError(t, err)
cfg := c.IncomingHTTPSConfig()
expected := &tls.Config{
NextProtos: []string{"h2", "http/1.1"},
InsecureSkipVerify: true,
MinVersion: tls.VersionTLS10,
GetConfigForClient: func(info *tls.ClientHelloInfo) (*tls.Config, error) {
return nil, nil
},
}
assertDeepEqual(t, expected, cfg, cmpTLSConfig, cmpClientFunc)
})
t.Run("verify incoming", func(t *testing.T) {
c := Configurator{base: &Config{VerifyIncoming: true}}
cfg := c.IncomingHTTPSConfig()
expected := &tls.Config{
NextProtos: []string{"h2", "http/1.1"},
InsecureSkipVerify: true,
MinVersion: tls.VersionTLS10,
GetConfigForClient: func(info *tls.ClientHelloInfo) (*tls.Config, error) {
return nil, nil
},
ClientAuth: tls.RequireAndVerifyClientCert,
}
assertDeepEqual(t, expected, cfg, cmpTLSConfig, cmpClientFunc)
})
}
var cmpTLSConfig = cmp.Options{
cmpopts.IgnoreFields(tls.Config{}, "GetCertificate", "GetClientCertificate"),
cmpopts.IgnoreUnexported(tls.Config{}),
} }
func TestConfigurator_OutgoingTLSConfigForCheck(t *testing.T) { func TestConfigurator_OutgoingTLSConfigForCheck(t *testing.T) {
@ -916,11 +952,6 @@ func TestConfigurator_OutgoingTLSConfigForCheck(t *testing.T) {
expected *tls.Config expected *tls.Config
} }
cmpTLSConfig := cmp.Options{
cmpopts.IgnoreFields(tls.Config{}, "GetCertificate", "GetClientCertificate"),
cmpopts.IgnoreUnexported(tls.Config{}),
}
run := func(t *testing.T, tc testCase) { run := func(t *testing.T, tc testCase) {
configurator, err := tc.conf() configurator, err := tc.conf()
require.NoError(t, err) require.NoError(t, err)