tlsutil: convert tests for OutgoingTLSConfigForCheck to a table

In preparation for adding more test cases.
This commit is contained in:
Daniel Nephin 2021-06-24 12:51:40 -04:00
parent e0a6946506
commit 2bfdd8ceed
1 changed files with 80 additions and 18 deletions

View File

@ -11,6 +11,8 @@ import (
"strings" "strings"
"testing" "testing"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/hashicorp/go-hclog" "github.com/hashicorp/go-hclog"
"github.com/hashicorp/yamux" "github.com/hashicorp/yamux"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -913,26 +915,86 @@ func TestConfigurator_IncomingHTTPSConfig(t *testing.T) {
require.Equal(t, []string{"h2", "http/1.1"}, c.IncomingHTTPSConfig().NextProtos) require.Equal(t, []string{"h2", "http/1.1"}, c.IncomingHTTPSConfig().NextProtos)
} }
func TestConfigurator_OutgoingTLSConfigForChecks(t *testing.T) { func TestConfigurator_OutgoingTLSConfigForCheck(t *testing.T) {
c := Configurator{base: &Config{ type testCase struct {
TLSMinVersion: "tls12", name string
EnableAgentTLSForChecks: false, conf func() (*Configurator, error)
}} skipVerify bool
tlsConf := c.OutgoingTLSConfigForCheck(true, "") serverName string
require.Equal(t, true, tlsConf.InsecureSkipVerify) expected *tls.Config
require.Equal(t, uint16(0), tlsConf.MinVersion) }
c.base.EnableAgentTLSForChecks = true cmpTLSConfig := cmp.Options{
c.base.ServerName = "servername" cmpopts.IgnoreFields(tls.Config{}, "GetCertificate", "GetClientCertificate"),
tlsConf = c.OutgoingTLSConfigForCheck(true, "") cmpopts.IgnoreUnexported(tls.Config{}),
require.Equal(t, true, tlsConf.InsecureSkipVerify) }
require.Equal(t, tlsLookup[c.base.TLSMinVersion], tlsConf.MinVersion)
require.Equal(t, c.base.ServerName, tlsConf.ServerName)
tlsConf = c.OutgoingTLSConfigForCheck(true, "servername2") run := func(t *testing.T, tc testCase) {
require.Equal(t, true, tlsConf.InsecureSkipVerify) configurator, err := tc.conf()
require.Equal(t, tlsLookup[c.base.TLSMinVersion], tlsConf.MinVersion) require.NoError(t, err)
require.Equal(t, "servername2", tlsConf.ServerName) c := configurator.OutgoingTLSConfigForCheck(tc.skipVerify, tc.serverName)
assertDeepEqual(t, tc.expected, c, cmpTLSConfig)
}
testCases := []testCase{
{
name: "default tls, skip verify, no server name",
conf: func() (*Configurator, error) {
return NewConfigurator(Config{
TLSMinVersion: "tls12",
EnableAgentTLSForChecks: false,
}, nil)
},
skipVerify: true,
expected: &tls.Config{InsecureSkipVerify: true},
},
{
name: "agent tls, skip verify, default server name",
conf: func() (*Configurator, error) {
return NewConfigurator(Config{
TLSMinVersion: "tls12",
EnableAgentTLSForChecks: true,
ServerName: "servername",
}, nil)
},
skipVerify: true,
expected: &tls.Config{
InsecureSkipVerify: true,
MinVersion: tls.VersionTLS12,
ServerName: "servername",
},
},
{
name: "agent tls, skip verify, with server name override",
conf: func() (*Configurator, error) {
return NewConfigurator(Config{
TLSMinVersion: "tls12",
EnableAgentTLSForChecks: true,
ServerName: "servername",
}, nil)
},
skipVerify: true,
serverName: "override",
expected: &tls.Config{
InsecureSkipVerify: true,
MinVersion: tls.VersionTLS12,
ServerName: "override",
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
run(t, tc)
})
}
}
func assertDeepEqual(t *testing.T, x, y interface{}, opts ...cmp.Option) {
t.Helper()
if diff := cmp.Diff(x, y, opts...); diff != "" {
t.Fatalf("assertion failed: values are not equal\n--- expected\n+++ actual\n%v", diff)
}
} }
func TestConfigurator_OutgoingRPCConfig(t *testing.T) { func TestConfigurator_OutgoingRPCConfig(t *testing.T) {