diff --git a/tlsutil/config.go b/tlsutil/config.go index f3bb580b2..105934d3a 100644 --- a/tlsutil/config.go +++ b/tlsutil/config.go @@ -6,6 +6,7 @@ import ( "fmt" "io/ioutil" "net" + "strings" "time" ) @@ -157,11 +158,14 @@ func (c *Config) OutgoingTLSWrapper() (DCWrapper, error) { return nil, nil } + // Strip the trailing '.' from the domain if any + domain := strings.TrimSuffix(c.Domain, ".") + // Generate the wrapper based on hostname verification if c.VerifyServerHostname { wrapper := func(dc string, conn net.Conn) (net.Conn, error) { conf := *tlsConfig - conf.ServerName = "server." + dc + "." + c.Domain + conf.ServerName = "server." + dc + "." + domain return WrapTLSClient(conn, &conf) } return wrapper, nil diff --git a/tlsutil/config_test.go b/tlsutil/config_test.go index 65b96ffd3..cbd127ac8 100644 --- a/tlsutil/config_test.go +++ b/tlsutil/config_test.go @@ -7,6 +7,8 @@ import ( "io/ioutil" "net" "testing" + + "github.com/hashicorp/yamux" ) func TestConfig_AppendCA_None(t *testing.T) { @@ -191,8 +193,16 @@ func startTLSServer(config *Config) (net.Conn, chan error) { } client, server := net.Pipe() + + // Use yamux to buffer the reads, otherwise it's easy to deadlock + muxConf := yamux.DefaultConfig() + serverSession, _ := yamux.Server(server, muxConf) + clientSession, _ := yamux.Client(client, muxConf) + clientConn, _ := clientSession.Open() + serverConn, _ := serverSession.Accept() + go func() { - tlsServer := tls.Server(server, tlsConfigServer) + tlsServer := tls.Server(serverConn, tlsConfigServer) if err := tlsServer.Handshake(); err != nil { errc <- err } @@ -206,7 +216,107 @@ func startTLSServer(config *Config) (net.Conn, chan error) { io.Copy(ioutil.Discard, tlsServer) tlsServer.Close() }() - return client, errc + return clientConn, errc +} + +func TestConfig_outgoingWrapper_OK(t *testing.T) { + config := &Config{ + CAFile: "../test/hostname/CertAuth.crt", + CertFile: "../test/hostname/Alice.crt", + KeyFile: "../test/hostname/Alice.key", + VerifyServerHostname: true, + Domain: "consul", + } + + client, errc := startTLSServer(config) + if client == nil { + t.Fatalf("startTLSServer err: %v", <-errc) + } + + wrap, err := config.OutgoingTLSWrapper() + if err != nil { + t.Fatalf("OutgoingTLSWrapper err: %v", err) + } + + tlsClient, err := wrap("dc1", client) + if err != nil { + t.Fatalf("wrapTLS err: %v", err) + } + defer tlsClient.Close() + if err := tlsClient.(*tls.Conn).Handshake(); err != nil { + t.Fatalf("write err: %v", err) + } + + err = <-errc + if err != nil { + t.Fatalf("server: %v", err) + } +} + +func TestConfig_outgoingWrapper_BadDC(t *testing.T) { + config := &Config{ + CAFile: "../test/hostname/CertAuth.crt", + CertFile: "../test/hostname/Alice.crt", + KeyFile: "../test/hostname/Alice.key", + VerifyServerHostname: true, + Domain: "consul", + } + + client, errc := startTLSServer(config) + if client == nil { + t.Fatalf("startTLSServer err: %v", <-errc) + } + + wrap, err := config.OutgoingTLSWrapper() + if err != nil { + t.Fatalf("OutgoingTLSWrapper err: %v", err) + } + + tlsClient, err := wrap("dc2", client) + if err != nil { + t.Fatalf("wrapTLS err: %v", err) + } + defer tlsClient.Close() + err = tlsClient.(*tls.Conn).Handshake() + + if _, ok := err.(x509.HostnameError); !ok { + t.Fatalf("should get hostname err: %v", err) + } + + <-errc +} + +func TestConfig_outgoingWrapper_BadCert(t *testing.T) { + config := &Config{ + CAFile: "../test/ca/root.cer", + CertFile: "../test/key/ourdomain.cer", + KeyFile: "../test/key/ourdomain.key", + VerifyServerHostname: true, + Domain: "consul", + } + + client, errc := startTLSServer(config) + if client == nil { + t.Fatalf("startTLSServer err: %v", <-errc) + } + + wrap, err := config.OutgoingTLSWrapper() + if err != nil { + t.Fatalf("OutgoingTLSWrapper err: %v", err) + } + + tlsClient, err := wrap("dc1", client) + if err != nil { + t.Fatalf("wrapTLS err: %v", err) + } + defer tlsClient.Close() + err = tlsClient.(*tls.Conn).Handshake() + + if _, ok := err.(x509.HostnameError); !ok { + t.Fatalf("should get hostname err: %v", err) + } + + <-errc } func TestConfig_wrapTLS_OK(t *testing.T) {