From 77dde1df386a109ee79403567a239220ebeb0fe2 Mon Sep 17 00:00:00 2001 From: Daniel Nephin Date: Thu, 24 Jun 2021 14:45:52 -0400 Subject: [PATCH] tlsutil: inline verifyIncomingHTTPS This function was only used in one place, and the indirection makes it slightly harder to see what the one caller is doing. Since it's only accesing a couple fields it seems like the logic can exist in the one caller. --- tlsutil/config.go | 14 +++++----- tlsutil/config_test.go | 61 +++++++++++++++++++++++++++++++----------- 2 files changed, 52 insertions(+), 23 deletions(-) diff --git a/tlsutil/config.go b/tlsutil/config.go index d933219a6..9b7835d95 100644 --- a/tlsutil/config.go +++ b/tlsutil/config.go @@ -594,13 +594,6 @@ func (c *Configurator) domain() string { return c.base.Domain } -// This function acquires a read lock because it reads from the config. -func (c *Configurator) verifyIncomingHTTPS() bool { - c.lock.RLock() - defer c.lock.RUnlock() - return c.base.VerifyIncoming || c.base.VerifyIncomingHTTPS -} - // This function acquires a read lock because it reads from the config. func (c *Configurator) serverNameOrNodeName() string { c.lock.RLock() @@ -677,7 +670,12 @@ func (c *Configurator) IncomingInsecureRPCConfig() *tls.Config { // IncomingHTTPSConfig generates a *tls.Config for incoming HTTPS connections. func (c *Configurator) IncomingHTTPSConfig() *tls.Config { c.log("IncomingHTTPSConfig") - config := c.commonTLSConfig(c.verifyIncomingHTTPS()) + + c.lock.RLock() + verifyIncoming := c.base.VerifyIncoming || c.base.VerifyIncomingHTTPS + c.lock.RUnlock() + + config := c.commonTLSConfig(verifyIncoming) config.NextProtos = []string{"h2", "http/1.1"} config.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) { return c.IncomingHTTPSConfig(), nil diff --git a/tlsutil/config_test.go b/tlsutil/config_test.go index 8ca985186..d1586bdaa 100644 --- a/tlsutil/config_test.go +++ b/tlsutil/config_test.go @@ -850,14 +850,6 @@ func TestConfigurator_VerifyIncomingRPC(t *testing.T) { require.Equal(t, c.base.VerifyIncomingRPC, verify) } -func TestConfigurator_VerifyIncomingHTTPS(t *testing.T) { - c := Configurator{base: &Config{ - VerifyIncomingHTTPS: true, - }} - verify := c.verifyIncomingHTTPS() - require.Equal(t, c.base.VerifyIncomingHTTPS, verify) -} - func TestConfigurator_IncomingRPCConfig(t *testing.T) { c, err := NewConfigurator(Config{ VerifyIncomingRPC: true, @@ -903,8 +895,52 @@ func TestConfigurator_IncomingALPNRPCConfig(t *testing.T) { } func TestConfigurator_IncomingHTTPSConfig(t *testing.T) { - c := Configurator{base: &Config{}} - require.Equal(t, []string{"h2", "http/1.1"}, c.IncomingHTTPSConfig().NextProtos) + + // compare tls.Config.GetConfigForClient by nil/not-nil, since Go can not compare + // functions any other way. + cmpClientFunc := cmp.Comparer(func(x, y func(*tls.ClientHelloInfo) (*tls.Config, error)) bool { + return (x == nil && y == nil) || (x != nil && y != nil) + }) + + t.Run("default", func(t *testing.T) { + c, err := NewConfigurator(Config{}, nil) + require.NoError(t, err) + + cfg := c.IncomingHTTPSConfig() + + expected := &tls.Config{ + NextProtos: []string{"h2", "http/1.1"}, + InsecureSkipVerify: true, + MinVersion: tls.VersionTLS10, + GetConfigForClient: func(info *tls.ClientHelloInfo) (*tls.Config, error) { + return nil, nil + }, + } + assertDeepEqual(t, expected, cfg, cmpTLSConfig, cmpClientFunc) + }) + + t.Run("verify incoming", func(t *testing.T) { + c := Configurator{base: &Config{VerifyIncoming: true}} + + cfg := c.IncomingHTTPSConfig() + + expected := &tls.Config{ + NextProtos: []string{"h2", "http/1.1"}, + InsecureSkipVerify: true, + MinVersion: tls.VersionTLS10, + GetConfigForClient: func(info *tls.ClientHelloInfo) (*tls.Config, error) { + return nil, nil + }, + ClientAuth: tls.RequireAndVerifyClientCert, + } + assertDeepEqual(t, expected, cfg, cmpTLSConfig, cmpClientFunc) + }) + +} + +var cmpTLSConfig = cmp.Options{ + cmpopts.IgnoreFields(tls.Config{}, "GetCertificate", "GetClientCertificate"), + cmpopts.IgnoreUnexported(tls.Config{}), } func TestConfigurator_OutgoingTLSConfigForCheck(t *testing.T) { @@ -916,11 +952,6 @@ func TestConfigurator_OutgoingTLSConfigForCheck(t *testing.T) { expected *tls.Config } - cmpTLSConfig := cmp.Options{ - cmpopts.IgnoreFields(tls.Config{}, "GetCertificate", "GetClientCertificate"), - cmpopts.IgnoreUnexported(tls.Config{}), - } - run := func(t *testing.T, tc testCase) { configurator, err := tc.conf() require.NoError(t, err)