From 512cb6ebf7007b5ae6e3c79fcda1a283b79a1adc Mon Sep 17 00:00:00 2001 From: James Phillips Date: Thu, 24 Mar 2016 11:24:18 -0700 Subject: [PATCH] Adds TLS config helper to API client. --- api/api.go | 81 ++++++++++++++++++++++++++++++++++++++++-- api/api_test.go | 94 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 173 insertions(+), 2 deletions(-) diff --git a/api/api.go b/api/api.go index d0e0ceeae..db791e282 100644 --- a/api/api.go +++ b/api/api.go @@ -3,9 +3,11 @@ package api import ( "bytes" "crypto/tls" + "crypto/x509" "encoding/json" "fmt" "io" + "io/ioutil" "log" "net" "net/http" @@ -122,6 +124,28 @@ type Config struct { Token string } +// TLSConfig is used to generate a TLSClientConfig that's useful for talking to +// Consul using TLS. +type TLSConfig struct { + // Address is the optional address of the Consul server. + Address string + + // CAFile is the optional path to the CA certificate used for Consul + // communication, defaults to the system bundle if not specified. + CAFile string + + // CertFile is the optional path to the certificate for Consul + // communication. If this is set then you need to also set KeyFile. + CertFile string + + // KeyFile is the optional path to the private key for Consul communication. + // If this is set then you need to also set CertFile. + KeyFile string + + // InsecureSkipVerify if set to true will disable TLS host verification. + InsecureSkipVerify bool +} + // DefaultConfig returns a default configuration for the client. By default this // will pool and reuse idle connections to Consul. If you have a long-lived // client object, this is the desired behavior and should make the most efficient @@ -194,10 +218,19 @@ func defaultConfig(transportFn func() *http.Transport) *Config { } if !doVerify { - transport := transportFn() - transport.TLSClientConfig = &tls.Config{ + tlsClientConfig, err := SetupTLSConfig(&TLSConfig{ InsecureSkipVerify: true, + }) + + // We don't expect this to fail given that we aren't + // parsing any of the input, but we panic just in case + // since this doesn't have an error return. + if err != nil { + panic(err) } + + transport := transportFn() + transport.TLSClientConfig = tlsClientConfig config.HttpClient.Transport = transport } } @@ -205,6 +238,50 @@ func defaultConfig(transportFn func() *http.Transport) *Config { return config } +// TLSConfig is used to generate a TLSClientConfig that's useful for talking to +// Consul using TLS. +func SetupTLSConfig(tlsConfig *TLSConfig) (*tls.Config, error) { + tlsClientConfig := &tls.Config{ + InsecureSkipVerify: tlsConfig.InsecureSkipVerify, + } + + if tlsConfig.Address != "" { + server := tlsConfig.Address + hasPort := strings.LastIndex(server, ":") > strings.LastIndex(server, "]") + if hasPort { + var err error + server, _, err = net.SplitHostPort(server) + if err != nil { + return nil, err + } + } + tlsClientConfig.ServerName = server + } + + if tlsConfig.CertFile != "" && tlsConfig.KeyFile != "" { + tlsCert, err := tls.LoadX509KeyPair(tlsConfig.CertFile, tlsConfig.KeyFile) + if err != nil { + return nil, err + } + tlsClientConfig.Certificates = []tls.Certificate{tlsCert} + } + + if tlsConfig.CAFile != "" { + data, err := ioutil.ReadFile(tlsConfig.CAFile) + if err != nil { + return nil, fmt.Errorf("failed to read CA file: %v", err) + } + + caPool := x509.NewCertPool() + if !caPool.AppendCertsFromPEM(data) { + return nil, fmt.Errorf("failed to parse CA certificate") + } + tlsClientConfig.RootCAs = caPool + } + + return tlsClientConfig, nil +} + // Client provides a client to the Consul API type Client struct { config Config diff --git a/api/api_test.go b/api/api_test.go index 22913c8a6..da7c00550 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -2,11 +2,13 @@ package api import ( crand "crypto/rand" + "crypto/tls" "fmt" "io/ioutil" "net/http" "os" "path/filepath" + "reflect" "runtime" "testing" "time" @@ -121,6 +123,98 @@ func TestDefaultConfig_env(t *testing.T) { } } +func TestSetupTLSConfig(t *testing.T) { + // A default config should result in a clean default client config. + tlsConfig := &TLSConfig{} + cc, err := SetupTLSConfig(tlsConfig) + if err != nil { + t.Fatalf("err: %v", err) + } + expected := &tls.Config{} + if !reflect.DeepEqual(cc, expected) { + t.Fatalf("bad: %v", cc) + } + + // Try some address variations with and without ports. + tlsConfig.Address = "127.0.0.1" + cc, err = SetupTLSConfig(tlsConfig) + if err != nil { + t.Fatalf("err: %v", err) + } + expected.ServerName = "127.0.0.1" + if !reflect.DeepEqual(cc, expected) { + t.Fatalf("bad: %v", cc) + } + + tlsConfig.Address = "127.0.0.1:80" + cc, err = SetupTLSConfig(tlsConfig) + if err != nil { + t.Fatalf("err: %v", err) + } + expected.ServerName = "127.0.0.1" + if !reflect.DeepEqual(cc, expected) { + t.Fatalf("bad: %v", cc) + } + + tlsConfig.Address = "demo.consul.io:80" + cc, err = SetupTLSConfig(tlsConfig) + if err != nil { + t.Fatalf("err: %v", err) + } + expected.ServerName = "demo.consul.io" + if !reflect.DeepEqual(cc, expected) { + t.Fatalf("bad: %v", cc) + } + + tlsConfig.Address = "[2001:db8:a0b:12f0::1]" + cc, err = SetupTLSConfig(tlsConfig) + if err != nil { + t.Fatalf("err: %v", err) + } + expected.ServerName = "[2001:db8:a0b:12f0::1]" + if !reflect.DeepEqual(cc, expected) { + t.Fatalf("bad: %v", cc) + } + + tlsConfig.Address = "[2001:db8:a0b:12f0::1]:80" + cc, err = SetupTLSConfig(tlsConfig) + if err != nil { + t.Fatalf("err: %v", err) + } + expected.ServerName = "2001:db8:a0b:12f0::1" + if !reflect.DeepEqual(cc, expected) { + t.Fatalf("bad: %v", cc) + } + + // Skip verification. + tlsConfig.InsecureSkipVerify = true + cc, err = SetupTLSConfig(tlsConfig) + if err != nil { + t.Fatalf("err: %v", err) + } + expected.InsecureSkipVerify = true + if !reflect.DeepEqual(cc, expected) { + t.Fatalf("bad: %v", cc) + } + + // Make a new config that hits all the file parsers. + tlsConfig = &TLSConfig{ + CertFile: "../test/hostname/Alice.crt", + KeyFile: "../test/hostname/Alice.key", + CAFile: "../test/hostname/CertAuth.crt", + } + cc, err = SetupTLSConfig(tlsConfig) + if err != nil { + t.Fatalf("err: %v", err) + } + if len(cc.Certificates) != 1 { + t.Fatalf("missing certificate: %v", cc.Certificates) + } + if cc.RootCAs == nil { + t.Fatalf("didn't load root CAs") + } +} + func TestSetQueryOptions(t *testing.T) { t.Parallel() c, s := makeClient(t)