Adds TLS config helper to API client.

This commit is contained in:
James Phillips 2016-03-24 11:24:18 -07:00
parent 0f23210628
commit 512cb6ebf7
2 changed files with 173 additions and 2 deletions

View File

@ -3,9 +3,11 @@ package api
import ( import (
"bytes" "bytes"
"crypto/tls" "crypto/tls"
"crypto/x509"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"log" "log"
"net" "net"
"net/http" "net/http"
@ -122,6 +124,28 @@ type Config struct {
Token string Token string
} }
// TLSConfig is used to generate a TLSClientConfig that's useful for talking to
// Consul using TLS.
type TLSConfig struct {
// Address is the optional address of the Consul server.
Address string
// CAFile is the optional path to the CA certificate used for Consul
// communication, defaults to the system bundle if not specified.
CAFile string
// CertFile is the optional path to the certificate for Consul
// communication. If this is set then you need to also set KeyFile.
CertFile string
// KeyFile is the optional path to the private key for Consul communication.
// If this is set then you need to also set CertFile.
KeyFile string
// InsecureSkipVerify if set to true will disable TLS host verification.
InsecureSkipVerify bool
}
// DefaultConfig returns a default configuration for the client. By default this // DefaultConfig returns a default configuration for the client. By default this
// will pool and reuse idle connections to Consul. If you have a long-lived // will pool and reuse idle connections to Consul. If you have a long-lived
// client object, this is the desired behavior and should make the most efficient // client object, this is the desired behavior and should make the most efficient
@ -194,10 +218,19 @@ func defaultConfig(transportFn func() *http.Transport) *Config {
} }
if !doVerify { if !doVerify {
transport := transportFn() tlsClientConfig, err := SetupTLSConfig(&TLSConfig{
transport.TLSClientConfig = &tls.Config{
InsecureSkipVerify: true, InsecureSkipVerify: true,
})
// We don't expect this to fail given that we aren't
// parsing any of the input, but we panic just in case
// since this doesn't have an error return.
if err != nil {
panic(err)
} }
transport := transportFn()
transport.TLSClientConfig = tlsClientConfig
config.HttpClient.Transport = transport config.HttpClient.Transport = transport
} }
} }
@ -205,6 +238,50 @@ func defaultConfig(transportFn func() *http.Transport) *Config {
return config return config
} }
// TLSConfig is used to generate a TLSClientConfig that's useful for talking to
// Consul using TLS.
func SetupTLSConfig(tlsConfig *TLSConfig) (*tls.Config, error) {
tlsClientConfig := &tls.Config{
InsecureSkipVerify: tlsConfig.InsecureSkipVerify,
}
if tlsConfig.Address != "" {
server := tlsConfig.Address
hasPort := strings.LastIndex(server, ":") > strings.LastIndex(server, "]")
if hasPort {
var err error
server, _, err = net.SplitHostPort(server)
if err != nil {
return nil, err
}
}
tlsClientConfig.ServerName = server
}
if tlsConfig.CertFile != "" && tlsConfig.KeyFile != "" {
tlsCert, err := tls.LoadX509KeyPair(tlsConfig.CertFile, tlsConfig.KeyFile)
if err != nil {
return nil, err
}
tlsClientConfig.Certificates = []tls.Certificate{tlsCert}
}
if tlsConfig.CAFile != "" {
data, err := ioutil.ReadFile(tlsConfig.CAFile)
if err != nil {
return nil, fmt.Errorf("failed to read CA file: %v", err)
}
caPool := x509.NewCertPool()
if !caPool.AppendCertsFromPEM(data) {
return nil, fmt.Errorf("failed to parse CA certificate")
}
tlsClientConfig.RootCAs = caPool
}
return tlsClientConfig, nil
}
// Client provides a client to the Consul API // Client provides a client to the Consul API
type Client struct { type Client struct {
config Config config Config

View File

@ -2,11 +2,13 @@ package api
import ( import (
crand "crypto/rand" crand "crypto/rand"
"crypto/tls"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"os" "os"
"path/filepath" "path/filepath"
"reflect"
"runtime" "runtime"
"testing" "testing"
"time" "time"
@ -121,6 +123,98 @@ func TestDefaultConfig_env(t *testing.T) {
} }
} }
func TestSetupTLSConfig(t *testing.T) {
// A default config should result in a clean default client config.
tlsConfig := &TLSConfig{}
cc, err := SetupTLSConfig(tlsConfig)
if err != nil {
t.Fatalf("err: %v", err)
}
expected := &tls.Config{}
if !reflect.DeepEqual(cc, expected) {
t.Fatalf("bad: %v", cc)
}
// Try some address variations with and without ports.
tlsConfig.Address = "127.0.0.1"
cc, err = SetupTLSConfig(tlsConfig)
if err != nil {
t.Fatalf("err: %v", err)
}
expected.ServerName = "127.0.0.1"
if !reflect.DeepEqual(cc, expected) {
t.Fatalf("bad: %v", cc)
}
tlsConfig.Address = "127.0.0.1:80"
cc, err = SetupTLSConfig(tlsConfig)
if err != nil {
t.Fatalf("err: %v", err)
}
expected.ServerName = "127.0.0.1"
if !reflect.DeepEqual(cc, expected) {
t.Fatalf("bad: %v", cc)
}
tlsConfig.Address = "demo.consul.io:80"
cc, err = SetupTLSConfig(tlsConfig)
if err != nil {
t.Fatalf("err: %v", err)
}
expected.ServerName = "demo.consul.io"
if !reflect.DeepEqual(cc, expected) {
t.Fatalf("bad: %v", cc)
}
tlsConfig.Address = "[2001:db8:a0b:12f0::1]"
cc, err = SetupTLSConfig(tlsConfig)
if err != nil {
t.Fatalf("err: %v", err)
}
expected.ServerName = "[2001:db8:a0b:12f0::1]"
if !reflect.DeepEqual(cc, expected) {
t.Fatalf("bad: %v", cc)
}
tlsConfig.Address = "[2001:db8:a0b:12f0::1]:80"
cc, err = SetupTLSConfig(tlsConfig)
if err != nil {
t.Fatalf("err: %v", err)
}
expected.ServerName = "2001:db8:a0b:12f0::1"
if !reflect.DeepEqual(cc, expected) {
t.Fatalf("bad: %v", cc)
}
// Skip verification.
tlsConfig.InsecureSkipVerify = true
cc, err = SetupTLSConfig(tlsConfig)
if err != nil {
t.Fatalf("err: %v", err)
}
expected.InsecureSkipVerify = true
if !reflect.DeepEqual(cc, expected) {
t.Fatalf("bad: %v", cc)
}
// Make a new config that hits all the file parsers.
tlsConfig = &TLSConfig{
CertFile: "../test/hostname/Alice.crt",
KeyFile: "../test/hostname/Alice.key",
CAFile: "../test/hostname/CertAuth.crt",
}
cc, err = SetupTLSConfig(tlsConfig)
if err != nil {
t.Fatalf("err: %v", err)
}
if len(cc.Certificates) != 1 {
t.Fatalf("missing certificate: %v", cc.Certificates)
}
if cc.RootCAs == nil {
t.Fatalf("didn't load root CAs")
}
}
func TestSetQueryOptions(t *testing.T) { func TestSetQueryOptions(t *testing.T) {
t.Parallel() t.Parallel()
c, s := makeClient(t) c, s := makeClient(t)