From 7d4824ade77be1889b0a461a77baa0d88c6a0879 Mon Sep 17 00:00:00 2001 From: Nelson Elhage Date: Mon, 26 May 2014 10:58:57 -0700 Subject: [PATCH] Allow multiple PEM-encoded certificates in the ca_file. fixes #167 --- consul/config.go | 52 +++++++++++++++++-------------------------- consul/config_test.go | 15 ++++++++----- 2 files changed, 30 insertions(+), 37 deletions(-) diff --git a/consul/config.go b/consul/config.go index e848c3b01..85d94e698 100644 --- a/consul/config.go +++ b/consul/config.go @@ -3,16 +3,16 @@ package consul import ( "crypto/tls" "crypto/x509" - "encoding/pem" "fmt" - "github.com/hashicorp/memberlist" - "github.com/hashicorp/raft" - "github.com/hashicorp/serf/serf" "io" "io/ioutil" "net" "os" "time" + + "github.com/hashicorp/memberlist" + "github.com/hashicorp/raft" + "github.com/hashicorp/serf/serf" ) const ( @@ -131,30 +131,24 @@ func (c *Config) CheckVersion() error { return nil } -// CACertificate is used to open and parse a CA file -func (c *Config) CACertificate() (*x509.Certificate, error) { +// AppendCA opens and parses the CA file and adds the certificates to +// the provided CertPool. +func (c *Config) AppendCA(pool *x509.CertPool) error { if c.CAFile == "" { - return nil, nil + return nil } // Read the file data, err := ioutil.ReadFile(c.CAFile) if err != nil { - return nil, fmt.Errorf("Failed to read CA file: %v", err) + return fmt.Errorf("Failed to read CA file: %v", err) } - // Decode from the PEM format - block, _ := pem.Decode(data) - if block == nil { - return nil, fmt.Errorf("Failed to decode CA PEM!") + if !pool.AppendCertsFromPEM(data) { + return fmt.Errorf("Failed to parse any CA certificates") } - // Parse the certificate - cert, err := x509.ParseCertificate(block.Bytes) - if err != nil { - return nil, fmt.Errorf("Failed to parse CA file: %v", err) - } - return cert, nil + return nil } // KeyPair is used to open and parse a certificate and key file @@ -177,17 +171,15 @@ func (c *Config) OutgoingTLSConfig() (*tls.Config, error) { InsecureSkipVerify: !c.VerifyOutgoing, } - // Parse the CA cert if any - ca, err := c.CACertificate() - if err != nil { - return nil, err - } else if ca != nil { - tlsConfig.RootCAs.AddCert(ca) + // Ensure we have a CA if VerifyOutgoing is set + if c.VerifyOutgoing && c.CAFile == "" { + return nil, fmt.Errorf("VerifyOutgoing set, and no CA certificate provided!") } - // Ensure we have a CA if VerifyOutgoing is set - if c.VerifyOutgoing && ca == nil { - return nil, fmt.Errorf("VerifyOutgoing set, and no CA certificate provided!") + // Parse the CA cert if any + err := c.AppendCA(tlsConfig.RootCAs) + if err != nil { + return nil, err } // Add cert/key @@ -210,11 +202,9 @@ func (c *Config) IncomingTLSConfig() (*tls.Config, error) { } // Parse the CA cert if any - ca, err := c.CACertificate() + err := c.AppendCA(tlsConfig.ClientCAs) if err != nil { return nil, err - } else if ca != nil { - tlsConfig.ClientCAs.AddCert(ca) } // Add cert/key @@ -228,7 +218,7 @@ func (c *Config) IncomingTLSConfig() (*tls.Config, error) { // Check if we require verification if c.VerifyIncoming { tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert - if ca == nil { + if c.CAFile == "" { return nil, fmt.Errorf("VerifyIncoming set, and no CA certificate provided!") } if cert == nil { diff --git a/consul/config_test.go b/consul/config_test.go index 6d4c106db..c6081603e 100644 --- a/consul/config_test.go +++ b/consul/config_test.go @@ -2,17 +2,19 @@ package consul import ( "crypto/tls" + "crypto/x509" "testing" ) -func TestConfig_CACertificate_None(t *testing.T) { +func TestConfig_AppendCA_None(t *testing.T) { conf := &Config{} - cert, err := conf.CACertificate() + pool := x509.NewCertPool() + err := conf.AppendCA(pool) if err != nil { t.Fatalf("err: %v", err) } - if cert != nil { - t.Fatalf("bad: %v", cert) + if len(pool.Subjects()) != 0 { + t.Fatalf("bad: %v", pool.Subjects()) } } @@ -20,11 +22,12 @@ func TestConfig_CACertificate_Valid(t *testing.T) { conf := &Config{ CAFile: "../test/ca/root.cer", } - cert, err := conf.CACertificate() + pool := x509.NewCertPool() + err := conf.AppendCA(pool) if err != nil { t.Fatalf("err: %v", err) } - if cert == nil { + if len(pool.Subjects()) == 0 { t.Fatalf("expected cert") } }