diff --git a/tlsutil/config_test.go b/tlsutil/config_test.go index 19b04d4a6..91d4f1aae 100644 --- a/tlsutil/config_test.go +++ b/tlsutil/config_test.go @@ -11,6 +11,8 @@ import ( "strings" "testing" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" "github.com/hashicorp/go-hclog" "github.com/hashicorp/yamux" "github.com/stretchr/testify/require" @@ -913,26 +915,86 @@ func TestConfigurator_IncomingHTTPSConfig(t *testing.T) { require.Equal(t, []string{"h2", "http/1.1"}, c.IncomingHTTPSConfig().NextProtos) } -func TestConfigurator_OutgoingTLSConfigForChecks(t *testing.T) { - c := Configurator{base: &Config{ - TLSMinVersion: "tls12", - EnableAgentTLSForChecks: false, - }} - tlsConf := c.OutgoingTLSConfigForCheck(true, "") - require.Equal(t, true, tlsConf.InsecureSkipVerify) - require.Equal(t, uint16(0), tlsConf.MinVersion) +func TestConfigurator_OutgoingTLSConfigForCheck(t *testing.T) { + type testCase struct { + name string + conf func() (*Configurator, error) + skipVerify bool + serverName string + expected *tls.Config + } - c.base.EnableAgentTLSForChecks = true - c.base.ServerName = "servername" - tlsConf = c.OutgoingTLSConfigForCheck(true, "") - require.Equal(t, true, tlsConf.InsecureSkipVerify) - require.Equal(t, tlsLookup[c.base.TLSMinVersion], tlsConf.MinVersion) - require.Equal(t, c.base.ServerName, tlsConf.ServerName) + cmpTLSConfig := cmp.Options{ + cmpopts.IgnoreFields(tls.Config{}, "GetCertificate", "GetClientCertificate"), + cmpopts.IgnoreUnexported(tls.Config{}), + } - tlsConf = c.OutgoingTLSConfigForCheck(true, "servername2") - require.Equal(t, true, tlsConf.InsecureSkipVerify) - require.Equal(t, tlsLookup[c.base.TLSMinVersion], tlsConf.MinVersion) - require.Equal(t, "servername2", tlsConf.ServerName) + run := func(t *testing.T, tc testCase) { + configurator, err := tc.conf() + require.NoError(t, err) + c := configurator.OutgoingTLSConfigForCheck(tc.skipVerify, tc.serverName) + assertDeepEqual(t, tc.expected, c, cmpTLSConfig) + } + + testCases := []testCase{ + { + name: "default tls, skip verify, no server name", + conf: func() (*Configurator, error) { + return NewConfigurator(Config{ + TLSMinVersion: "tls12", + EnableAgentTLSForChecks: false, + }, nil) + }, + skipVerify: true, + expected: &tls.Config{InsecureSkipVerify: true}, + }, + { + name: "agent tls, skip verify, default server name", + conf: func() (*Configurator, error) { + return NewConfigurator(Config{ + TLSMinVersion: "tls12", + EnableAgentTLSForChecks: true, + ServerName: "servername", + }, nil) + }, + skipVerify: true, + expected: &tls.Config{ + InsecureSkipVerify: true, + MinVersion: tls.VersionTLS12, + ServerName: "servername", + }, + }, + { + name: "agent tls, skip verify, with server name override", + conf: func() (*Configurator, error) { + return NewConfigurator(Config{ + TLSMinVersion: "tls12", + EnableAgentTLSForChecks: true, + ServerName: "servername", + }, nil) + }, + skipVerify: true, + serverName: "override", + expected: &tls.Config{ + InsecureSkipVerify: true, + MinVersion: tls.VersionTLS12, + ServerName: "override", + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + run(t, tc) + }) + } +} + +func assertDeepEqual(t *testing.T, x, y interface{}, opts ...cmp.Option) { + t.Helper() + if diff := cmp.Diff(x, y, opts...); diff != "" { + t.Fatalf("assertion failed: values are not equal\n--- expected\n+++ actual\n%v", diff) + } } func TestConfigurator_OutgoingRPCConfig(t *testing.T) {