diff --git a/tlsutil/config.go b/tlsutil/config.go index 7d658efea..0fad12716 100644 --- a/tlsutil/config.go +++ b/tlsutil/config.go @@ -326,7 +326,7 @@ func (c *Configurator) check(config Config, pool *x509.CertPool, cert *tls.Certi if pool == nil { return fmt.Errorf("VerifyIncoming set, and no CA certificate provided!") } - if cert == nil || cert.Certificate == nil { + if cert == nil { return fmt.Errorf("VerifyIncoming set, and no Cert/Key pair provided!") } } @@ -351,7 +351,7 @@ func (c *Config) baseVerifyIncoming() bool { func loadKeyPair(certFile, keyFile string) (*tls.Certificate, error) { if certFile == "" || keyFile == "" { - return &tls.Certificate{}, nil + return nil, nil } cert, err := tls.LoadX509KeyPair(certFile, keyFile) if err != nil { @@ -428,11 +428,16 @@ func (c *Configurator) commonTLSConfig(verifyIncoming bool) *tls.Config { tlsConfig.PreferServerCipherSuites = c.base.PreferServerCipherSuites // GetCertificate is used when acting as a server and responding to - // client requests. Always return the manually configured cert, because - // on the server this is all we have. And on the client, this is the - // only sensitive option. + // client requests. Default to the manually configured cert, but allow + // autoEncrypt cert too so that a client can encrypt incoming + // connections without having a manual cert configured. tlsConfig.GetCertificate = func(*tls.ClientHelloInfo) (*tls.Certificate, error) { - return c.manual.cert, nil + cert := c.manual.cert + if cert == nil { + cert = c.autoEncrypt.cert + } + + return cert, nil } // GetClientCertificate is used when acting as a client and responding @@ -630,9 +635,9 @@ func (c *Configurator) OutgoingRPCWrapper() DCWrapper { // there is no cert, it will return a time in the past. func (c *Configurator) AutoEncryptCertNotAfter() time.Time { c.RLock() + defer c.RUnlock() tlsCert := c.autoEncrypt.cert - c.RUnlock() - if tlsCert == nil { + if tlsCert == nil || tlsCert.Certificate == nil { return time.Now().AddDate(0, 0, -1) } cert, err := x509.ParseCertificate(tlsCert.Certificate[0]) diff --git a/tlsutil/config_test.go b/tlsutil/config_test.go index 049130236..21d4a5740 100644 --- a/tlsutil/config_test.go +++ b/tlsutil/config_test.go @@ -293,17 +293,16 @@ func TestConfigurator_loadKeyPair(t *testing.T) { cert, key string shoulderr bool isnil bool - isempty bool } variants := []variant{ - {"", "", false, false, true}, - {"bogus", "", false, false, true}, - {"", "bogus", false, false, true}, - {"../test/key/ourdomain.cer", "", false, false, true}, - {"", "../test/key/ourdomain.key", false, false, true}, - {"bogus", "bogus", true, true, false}, + {"", "", false, true}, + {"bogus", "", false, true}, + {"", "bogus", false, true}, + {"../test/key/ourdomain.cer", "", false, true}, + {"", "../test/key/ourdomain.key", false, true}, + {"bogus", "bogus", true, true}, {"../test/key/ourdomain.cer", "../test/key/ourdomain.key", - false, false, false}, + false, false}, } for i, v := range variants { info := fmt.Sprintf("case %d", i) @@ -317,10 +316,6 @@ func TestConfigurator_loadKeyPair(t *testing.T) { require.NoError(t, err1, info) require.NoError(t, err2, info) } - if v.isempty { - require.Empty(t, cert1.Certificate, info) - require.Empty(t, cert2.Certificate, info) - } if v.isnil { require.Nil(t, cert1, info) require.Nil(t, cert2, info) @@ -538,17 +533,47 @@ func TestConfigurator_CommonTLSConfigGetClientCertificate(t *testing.T) { c, err := NewConfigurator(Config{}, nil) require.NoError(t, err) - cert, err := c.commonTLSConfig(false).GetCertificate(nil) + cert, err := c.commonTLSConfig(false).GetClientCertificate(nil) require.NoError(t, err) - require.Nil(t, cert.Certificate) + require.Nil(t, cert) - c.manual.cert = &tls.Certificate{} - cert, err = c.commonTLSConfig(false).GetCertificate(nil) + c1, err := loadKeyPair("../test/key/something_expired.cer", "../test/key/something_expired.key") + require.NoError(t, err) + c.manual.cert = c1 + cert, err = c.commonTLSConfig(false).GetClientCertificate(nil) require.NoError(t, err) require.Equal(t, c.manual.cert, cert) + c2, err := loadKeyPair("../test/key/ourdomain.cer", "../test/key/ourdomain.key") + require.NoError(t, err) + c.autoEncrypt.cert = c2 cert, err = c.commonTLSConfig(false).GetClientCertificate(nil) require.NoError(t, err) + require.Equal(t, c.autoEncrypt.cert, cert) +} + +func TestConfigurator_CommonTLSConfigGetCertificate(t *testing.T) { + c, err := NewConfigurator(Config{}, nil) + require.NoError(t, err) + + cert, err := c.commonTLSConfig(false).GetCertificate(nil) + require.NoError(t, err) + require.Nil(t, cert) + + // Setting a certificate as the auto-encrypt cert will return it as the regular server certificate + c1, err := loadKeyPair("../test/key/something_expired.cer", "../test/key/something_expired.key") + require.NoError(t, err) + c.autoEncrypt.cert = c1 + cert, err = c.commonTLSConfig(false).GetCertificate(nil) + require.NoError(t, err) + require.Equal(t, c.autoEncrypt.cert, cert) + + // Setting a different certificate as a manual cert will override the auto-encrypt cert and instead return the manual cert + c2, err := loadKeyPair("../test/key/ourdomain.cer", "../test/key/ourdomain.key") + require.NoError(t, err) + c.manual.cert = c2 + cert, err = c.commonTLSConfig(false).GetCertificate(nil) + require.NoError(t, err) require.Equal(t, c.manual.cert, cert) } @@ -715,7 +740,7 @@ func TestConfigurator_UpdateSetsStuff(t *testing.T) { c, err := NewConfigurator(Config{}, nil) require.NoError(t, err) require.Nil(t, c.caPool) - require.Nil(t, c.manual.cert.Certificate) + require.Nil(t, c.manual.cert) require.Equal(t, c.base, &Config{}) require.Equal(t, 1, c.version)