From 86c9cb037fe99074431c3139da502c1bfec87efd Mon Sep 17 00:00:00 2001 From: Daniel Nephin Date: Mon, 21 Jun 2021 15:19:34 -0400 Subject: [PATCH] tlsutil: reduce interface provided to auto-config Replace two methods with a single one that returns the cert. This moves more of the logic into the single caller (auto-config). tlsutil.Configurator is widely used. By keeping it smaller and focused only on storing and returning TLS config, we make the code easier to follow. These two methods were more related to auto-config than to tlsutil, so reducing the interface moves the logic closer to the feature that requires it. --- agent/auto-config/auto_config_test.go | 26 +++++++++++++++++++------- agent/auto-config/auto_encrypt_test.go | 17 ++++++++++++----- agent/auto-config/config.go | 4 ++-- agent/auto-config/mock_test.go | 16 +++++++--------- agent/auto-config/run.go | 14 +++++++++----- tlsutil/config.go | 16 +++++----------- tlsutil/config_test.go | 8 ++++---- 7 files changed, 58 insertions(+), 43 deletions(-) diff --git a/agent/auto-config/auto_config_test.go b/agent/auto-config/auto_config_test.go index c4b40606e..37e9f67ab 100644 --- a/agent/auto-config/auto_config_test.go +++ b/agent/auto-config/auto_config_test.go @@ -2,6 +2,7 @@ package autoconf import ( "context" + "crypto/x509" "fmt" "io/ioutil" "net" @@ -577,7 +578,9 @@ func TestGoRoutineManagement(t *testing.T) { mcfg.tokens.On("Notify", token.TokenKindAgent).Return(token.Notifier{}).Times(2) mcfg.tokens.On("StopNotify", token.Notifier{}).Times(2) - mcfg.tlsCfg.On("AutoEncryptCertNotAfter").Return(time.Now().Add(10 * time.Minute)).Times(0) + mcfg.tlsCfg.On("AutoEncryptCert").Return(&x509.Certificate{ + NotAfter: time.Now().Add(10 * time.Minute), + }).Times(0) // ensure that auto-config isn't running require.False(t, ac.IsRunning()) @@ -734,7 +737,9 @@ func startedAutoConfig(t *testing.T, autoEncrypt bool) testAutoConfig { indexedRoots, cert, extraCerts := mcfg.setupInitialTLS(t, "autoconf", "dc1", originalToken) - mcfg.tlsCfg.On("AutoEncryptCertNotAfter").Return(cert.ValidBefore).Once() + mcfg.tlsCfg.On("AutoEncryptCert").Return(&x509.Certificate{ + NotAfter: cert.ValidBefore, + }).Once() populateResponse := func(args mock.Arguments) { method := args.String(3) @@ -920,7 +925,9 @@ func TestRootsUpdate(t *testing.T) { }) // when a cache event comes in we end up recalculating the fallback timer which requires this call - testAC.mcfg.tlsCfg.On("AutoEncryptCertNotAfter").Return(time.Now().Add(10 * time.Minute)).Once() + testAC.mcfg.tlsCfg.On("AutoEncryptCert").Return(&x509.Certificate{ + NotAfter: time.Now().Add(10 * time.Minute), + }).Once() req := structs.DCSpecificRequest{Datacenter: "dc1"} require.True(t, testAC.mcfg.cache.sendNotification(context.Background(), req.CacheInfo().Key, cache.UpdateEvent{ @@ -960,7 +967,9 @@ func TestCertUpdate(t *testing.T) { }) // when a cache event comes in we end up recalculating the fallback timer which requires this call - testAC.mcfg.tlsCfg.On("AutoEncryptCertNotAfter").Return(secondCert.ValidBefore).Once() + testAC.mcfg.tlsCfg.On("AutoEncryptCert").Return(&x509.Certificate{ + NotAfter: secondCert.ValidBefore, + }).Once() req := cachetype.ConnectCALeafRequest{ Datacenter: "dc1", @@ -1025,8 +1034,9 @@ func TestFallback(t *testing.T) { }) // when a cache event comes in we end up recalculating the fallback timer which requires this call - testAC.mcfg.tlsCfg.On("AutoEncryptCertNotAfter").Return(secondCert.ValidBefore).Once() - testAC.mcfg.tlsCfg.On("AutoEncryptCertExpired").Return(true).Once() + testAC.mcfg.tlsCfg.On("AutoEncryptCert").Return(&x509.Certificate{ + NotAfter: secondCert.ValidBefore, + }).Times(2) fallbackCtx, fallbackCancel := context.WithCancel(context.Background()) @@ -1082,7 +1092,9 @@ func TestFallback(t *testing.T) { testAC.mcfg.expectInitialTLS(t, "autoconf", "dc1", testAC.originalToken, secondCA, &secondRoots, thirdCert, testAC.extraCerts) // after the second RPC we now will use the new certs validity period in the next run loop iteration - testAC.mcfg.tlsCfg.On("AutoEncryptCertNotAfter").Return(time.Now().Add(10 * time.Minute)).Once() + testAC.mcfg.tlsCfg.On("AutoEncryptCert").Return(&x509.Certificate{ + NotAfter: time.Now().Add(10 * time.Minute), + }).Once() // now that all the mocks are set up we can trigger the whole thing by sending the second expired cert // as a cache update event. diff --git a/agent/auto-config/auto_encrypt_test.go b/agent/auto-config/auto_encrypt_test.go index c92736ed6..2de33f68d 100644 --- a/agent/auto-config/auto_encrypt_test.go +++ b/agent/auto-config/auto_encrypt_test.go @@ -406,7 +406,9 @@ func TestAutoEncrypt_RootsUpdate(t *testing.T) { }) // when a cache event comes in we end up recalculating the fallback timer which requires this call - testAC.mcfg.tlsCfg.On("AutoEncryptCertNotAfter").Return(time.Now().Add(10 * time.Minute)).Once() + testAC.mcfg.tlsCfg.On("AutoEncryptCert").Return(&x509.Certificate{ + NotAfter: time.Now().Add(10 * time.Minute), + }).Once() req := structs.DCSpecificRequest{Datacenter: "dc1"} require.True(t, testAC.mcfg.cache.sendNotification(context.Background(), req.CacheInfo().Key, cache.UpdateEvent{ @@ -433,7 +435,9 @@ func TestAutoEncrypt_CertUpdate(t *testing.T) { }) // when a cache event comes in we end up recalculating the fallback timer which requires this call - testAC.mcfg.tlsCfg.On("AutoEncryptCertNotAfter").Return(secondCert.ValidBefore).Once() + testAC.mcfg.tlsCfg.On("AutoEncryptCert").Return(&x509.Certificate{ + NotAfter: secondCert.ValidBefore, + }).Once() req := cachetype.ConnectCALeafRequest{ Datacenter: "dc1", @@ -484,8 +488,9 @@ func TestAutoEncrypt_Fallback(t *testing.T) { }) // when a cache event comes in we end up recalculating the fallback timer which requires this call - testAC.mcfg.tlsCfg.On("AutoEncryptCertNotAfter").Return(secondCert.ValidBefore).Once() - testAC.mcfg.tlsCfg.On("AutoEncryptCertExpired").Return(true).Once() + testAC.mcfg.tlsCfg.On("AutoEncryptCert").Return(&x509.Certificate{ + NotAfter: secondCert.ValidBefore, + }).Times(2) fallbackCtx, fallbackCancel := context.WithCancel(context.Background()) @@ -536,7 +541,9 @@ func TestAutoEncrypt_Fallback(t *testing.T) { testAC.mcfg.expectInitialTLS(t, "autoconf", "dc1", testAC.originalToken, secondCA, &secondRoots, thirdCert, testAC.extraCerts) // after the second RPC we now will use the new certs validity period in the next run loop iteration - testAC.mcfg.tlsCfg.On("AutoEncryptCertNotAfter").Return(time.Now().Add(10 * time.Minute)).Once() + testAC.mcfg.tlsCfg.On("AutoEncryptCert").Return(&x509.Certificate{ + NotAfter: time.Now().Add(10 * time.Minute), + }).Once() // now that all the mocks are set up we can trigger the whole thing by sending the second expired cert // as a cache update event. diff --git a/agent/auto-config/config.go b/agent/auto-config/config.go index 090b30dcc..a20121fb9 100644 --- a/agent/auto-config/config.go +++ b/agent/auto-config/config.go @@ -2,6 +2,7 @@ package autoconf import ( "context" + "crypto/x509" "net" "time" @@ -41,8 +42,7 @@ type TLSConfigurator interface { UpdateAutoTLS(manualCAPEMs, connectCAPEMs []string, pub, priv string, verifyServerHostname bool) error UpdateAutoTLSCA([]string) error UpdateAutoTLSCert(pub, priv string) error - AutoEncryptCertNotAfter() time.Time - AutoEncryptCertExpired() bool + AutoEncryptCert() *x509.Certificate } // TokenStore is an interface of the methods we will need to use from the token.Store. diff --git a/agent/auto-config/mock_test.go b/agent/auto-config/mock_test.go index 8d63a05b4..49d3ed29e 100644 --- a/agent/auto-config/mock_test.go +++ b/agent/auto-config/mock_test.go @@ -2,10 +2,12 @@ package autoconf import ( "context" + "crypto/x509" "net" "sync" "testing" - "time" + + "github.com/stretchr/testify/mock" "github.com/hashicorp/consul/agent/cache" cachetype "github.com/hashicorp/consul/agent/cache-types" @@ -15,7 +17,6 @@ import ( "github.com/hashicorp/consul/agent/token" "github.com/hashicorp/consul/proto/pbautoconf" "github.com/hashicorp/consul/sdk/testutil" - "github.com/stretchr/testify/mock" ) type mockDirectRPC struct { @@ -72,6 +73,7 @@ func (m *mockTLSConfigurator) UpdateAutoTLSCA(pems []string) error { ret := m.Called(pems) return ret.Error(0) } + func (m *mockTLSConfigurator) UpdateAutoTLSCert(pub, priv string) error { if priv != "" { priv = "redacted" @@ -79,15 +81,11 @@ func (m *mockTLSConfigurator) UpdateAutoTLSCert(pub, priv string) error { ret := m.Called(pub, priv) return ret.Error(0) } -func (m *mockTLSConfigurator) AutoEncryptCertNotAfter() time.Time { - ret := m.Called() - ts, _ := ret.Get(0).(time.Time) - return ts -} -func (m *mockTLSConfigurator) AutoEncryptCertExpired() bool { +func (m *mockTLSConfigurator) AutoEncryptCert() *x509.Certificate { ret := m.Called() - return ret.Bool(0) + cert, _ := ret.Get(0).(*x509.Certificate) + return cert } type mockServerProvider struct { diff --git a/agent/auto-config/run.go b/agent/auto-config/run.go index 6155dc6be..136d4a8ae 100644 --- a/agent/auto-config/run.go +++ b/agent/auto-config/run.go @@ -116,14 +116,17 @@ func (ac *AutoConfig) run(ctx context.Context, exit chan struct{}) { // expires. The agent cache should be handling the expiration // and renew it before then. // - // If there is no cert, AutoEncryptCertNotAfter returns - // a value in the past which immediately triggers the + // If there is no cert, use a value which immediately triggers the // renew, but this case shouldn't happen because at // this point, auto_encrypt was just being setup // successfully. calcFallbackInterval := func() time.Duration { - certExpiry := ac.acConfig.TLSConfigurator.AutoEncryptCertNotAfter() - return certExpiry.Add(ac.acConfig.FallbackLeeway).Sub(time.Now()) + cert := ac.acConfig.TLSConfigurator.AutoEncryptCert() + if cert == nil { + return -1 + } + expiry := cert.NotAfter.Add(ac.acConfig.FallbackLeeway) + return expiry.Sub(time.Now()) } fallbackTimer := time.NewTimer(calcFallbackInterval()) @@ -174,7 +177,8 @@ func (ac *AutoConfig) run(ctx context.Context, exit chan struct{}) { // never use the AutoEncrypt.Sign endpoint. // check auto encrypt client cert expiration - if ac.acConfig.TLSConfigurator.AutoEncryptCertExpired() { + cert := ac.acConfig.TLSConfigurator.AutoEncryptCert() + if cert == nil || cert.NotAfter.Before(time.Now()) { if err := ac.handleFallback(ctx); err != nil { ac.logger.Error("error when handling a certificate expiry event", "error", err) fallbackTimer = time.NewTimer(ac.acConfig.FallbackRetry) diff --git a/tlsutil/config.go b/tlsutil/config.go index 8d66cf975..70d987f11 100644 --- a/tlsutil/config.go +++ b/tlsutil/config.go @@ -796,25 +796,19 @@ func (c *Configurator) OutgoingALPNRPCWrapper() ALPNWrapper { return c.wrapALPNTLSClient } -// AutoEncryptCertNotAfter returns NotAfter from the auto_encrypt cert. In case -// there is no cert, it will return a time in the past. -func (c *Configurator) AutoEncryptCertNotAfter() time.Time { +// AutoEncryptCert returns the TLS certificate received from auto-encrypt. +func (c *Configurator) AutoEncryptCert() *x509.Certificate { c.lock.RLock() defer c.lock.RUnlock() tlsCert := c.autoTLS.cert if tlsCert == nil || tlsCert.Certificate == nil { - return time.Now().AddDate(0, 0, -1) + return nil } cert, err := x509.ParseCertificate(tlsCert.Certificate[0]) if err != nil { - return time.Now().AddDate(0, 0, -1) + return nil } - return cert.NotAfter -} - -// AutoEncryptCertExpired returns if the auto_encrypt cert is expired. -func (c *Configurator) AutoEncryptCertExpired() bool { - return c.AutoEncryptCertNotAfter().Before(time.Now()) + return cert } func (c *Configurator) log(name string) { diff --git a/tlsutil/config_test.go b/tlsutil/config_test.go index 0811c00ac..19b04d4a6 100644 --- a/tlsutil/config_test.go +++ b/tlsutil/config_test.go @@ -1120,19 +1120,19 @@ func TestConfigurator_VerifyServerHostname(t *testing.T) { require.True(t, c.VerifyServerHostname()) } -func TestConfigurator_AutoEncrytCertExpired(t *testing.T) { +func TestConfigurator_AutoEncryptCert(t *testing.T) { c := Configurator{base: &Config{}} - require.True(t, c.AutoEncryptCertExpired()) + require.Nil(t, c.AutoEncryptCert()) cert, err := loadKeyPair("../test/key/something_expired.cer", "../test/key/something_expired.key") require.NoError(t, err) c.autoTLS.cert = cert - require.True(t, c.AutoEncryptCertExpired()) + require.Equal(t, int64(1561561551), c.AutoEncryptCert().NotAfter.Unix()) cert, err = loadKeyPair("../test/key/ourdomain.cer", "../test/key/ourdomain.key") require.NoError(t, err) c.autoTLS.cert = cert - require.False(t, c.AutoEncryptCertExpired()) + require.Equal(t, int64(4679716209), c.AutoEncryptCert().NotAfter.Unix()) } func TestConfig_tlsVersions(t *testing.T) {