tlsutil: reduce interface provided to auto-config

Replace two methods with a single one that returns the cert. This moves more
of the logic into the single caller (auto-config).

tlsutil.Configurator is widely used. By keeping it smaller and focused only on storing and
returning TLS config, we make the code easier to follow.

These two methods were more related to auto-config than to tlsutil, so reducing the interface
moves the logic closer to the feature that requires it.
This commit is contained in:
Daniel Nephin 2021-06-21 15:19:34 -04:00
parent 1bdcd3df91
commit 86c9cb037f
7 changed files with 58 additions and 43 deletions

View File

@ -2,6 +2,7 @@ package autoconf
import ( import (
"context" "context"
"crypto/x509"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"net" "net"
@ -577,7 +578,9 @@ func TestGoRoutineManagement(t *testing.T) {
mcfg.tokens.On("Notify", token.TokenKindAgent).Return(token.Notifier{}).Times(2) mcfg.tokens.On("Notify", token.TokenKindAgent).Return(token.Notifier{}).Times(2)
mcfg.tokens.On("StopNotify", token.Notifier{}).Times(2) mcfg.tokens.On("StopNotify", token.Notifier{}).Times(2)
mcfg.tlsCfg.On("AutoEncryptCertNotAfter").Return(time.Now().Add(10 * time.Minute)).Times(0) mcfg.tlsCfg.On("AutoEncryptCert").Return(&x509.Certificate{
NotAfter: time.Now().Add(10 * time.Minute),
}).Times(0)
// ensure that auto-config isn't running // ensure that auto-config isn't running
require.False(t, ac.IsRunning()) require.False(t, ac.IsRunning())
@ -734,7 +737,9 @@ func startedAutoConfig(t *testing.T, autoEncrypt bool) testAutoConfig {
indexedRoots, cert, extraCerts := mcfg.setupInitialTLS(t, "autoconf", "dc1", originalToken) indexedRoots, cert, extraCerts := mcfg.setupInitialTLS(t, "autoconf", "dc1", originalToken)
mcfg.tlsCfg.On("AutoEncryptCertNotAfter").Return(cert.ValidBefore).Once() mcfg.tlsCfg.On("AutoEncryptCert").Return(&x509.Certificate{
NotAfter: cert.ValidBefore,
}).Once()
populateResponse := func(args mock.Arguments) { populateResponse := func(args mock.Arguments) {
method := args.String(3) method := args.String(3)
@ -920,7 +925,9 @@ func TestRootsUpdate(t *testing.T) {
}) })
// when a cache event comes in we end up recalculating the fallback timer which requires this call // when a cache event comes in we end up recalculating the fallback timer which requires this call
testAC.mcfg.tlsCfg.On("AutoEncryptCertNotAfter").Return(time.Now().Add(10 * time.Minute)).Once() testAC.mcfg.tlsCfg.On("AutoEncryptCert").Return(&x509.Certificate{
NotAfter: time.Now().Add(10 * time.Minute),
}).Once()
req := structs.DCSpecificRequest{Datacenter: "dc1"} req := structs.DCSpecificRequest{Datacenter: "dc1"}
require.True(t, testAC.mcfg.cache.sendNotification(context.Background(), req.CacheInfo().Key, cache.UpdateEvent{ require.True(t, testAC.mcfg.cache.sendNotification(context.Background(), req.CacheInfo().Key, cache.UpdateEvent{
@ -960,7 +967,9 @@ func TestCertUpdate(t *testing.T) {
}) })
// when a cache event comes in we end up recalculating the fallback timer which requires this call // when a cache event comes in we end up recalculating the fallback timer which requires this call
testAC.mcfg.tlsCfg.On("AutoEncryptCertNotAfter").Return(secondCert.ValidBefore).Once() testAC.mcfg.tlsCfg.On("AutoEncryptCert").Return(&x509.Certificate{
NotAfter: secondCert.ValidBefore,
}).Once()
req := cachetype.ConnectCALeafRequest{ req := cachetype.ConnectCALeafRequest{
Datacenter: "dc1", Datacenter: "dc1",
@ -1025,8 +1034,9 @@ func TestFallback(t *testing.T) {
}) })
// when a cache event comes in we end up recalculating the fallback timer which requires this call // when a cache event comes in we end up recalculating the fallback timer which requires this call
testAC.mcfg.tlsCfg.On("AutoEncryptCertNotAfter").Return(secondCert.ValidBefore).Once() testAC.mcfg.tlsCfg.On("AutoEncryptCert").Return(&x509.Certificate{
testAC.mcfg.tlsCfg.On("AutoEncryptCertExpired").Return(true).Once() NotAfter: secondCert.ValidBefore,
}).Times(2)
fallbackCtx, fallbackCancel := context.WithCancel(context.Background()) fallbackCtx, fallbackCancel := context.WithCancel(context.Background())
@ -1082,7 +1092,9 @@ func TestFallback(t *testing.T) {
testAC.mcfg.expectInitialTLS(t, "autoconf", "dc1", testAC.originalToken, secondCA, &secondRoots, thirdCert, testAC.extraCerts) testAC.mcfg.expectInitialTLS(t, "autoconf", "dc1", testAC.originalToken, secondCA, &secondRoots, thirdCert, testAC.extraCerts)
// after the second RPC we now will use the new certs validity period in the next run loop iteration // after the second RPC we now will use the new certs validity period in the next run loop iteration
testAC.mcfg.tlsCfg.On("AutoEncryptCertNotAfter").Return(time.Now().Add(10 * time.Minute)).Once() testAC.mcfg.tlsCfg.On("AutoEncryptCert").Return(&x509.Certificate{
NotAfter: time.Now().Add(10 * time.Minute),
}).Once()
// now that all the mocks are set up we can trigger the whole thing by sending the second expired cert // now that all the mocks are set up we can trigger the whole thing by sending the second expired cert
// as a cache update event. // as a cache update event.

View File

@ -406,7 +406,9 @@ func TestAutoEncrypt_RootsUpdate(t *testing.T) {
}) })
// when a cache event comes in we end up recalculating the fallback timer which requires this call // when a cache event comes in we end up recalculating the fallback timer which requires this call
testAC.mcfg.tlsCfg.On("AutoEncryptCertNotAfter").Return(time.Now().Add(10 * time.Minute)).Once() testAC.mcfg.tlsCfg.On("AutoEncryptCert").Return(&x509.Certificate{
NotAfter: time.Now().Add(10 * time.Minute),
}).Once()
req := structs.DCSpecificRequest{Datacenter: "dc1"} req := structs.DCSpecificRequest{Datacenter: "dc1"}
require.True(t, testAC.mcfg.cache.sendNotification(context.Background(), req.CacheInfo().Key, cache.UpdateEvent{ require.True(t, testAC.mcfg.cache.sendNotification(context.Background(), req.CacheInfo().Key, cache.UpdateEvent{
@ -433,7 +435,9 @@ func TestAutoEncrypt_CertUpdate(t *testing.T) {
}) })
// when a cache event comes in we end up recalculating the fallback timer which requires this call // when a cache event comes in we end up recalculating the fallback timer which requires this call
testAC.mcfg.tlsCfg.On("AutoEncryptCertNotAfter").Return(secondCert.ValidBefore).Once() testAC.mcfg.tlsCfg.On("AutoEncryptCert").Return(&x509.Certificate{
NotAfter: secondCert.ValidBefore,
}).Once()
req := cachetype.ConnectCALeafRequest{ req := cachetype.ConnectCALeafRequest{
Datacenter: "dc1", Datacenter: "dc1",
@ -484,8 +488,9 @@ func TestAutoEncrypt_Fallback(t *testing.T) {
}) })
// when a cache event comes in we end up recalculating the fallback timer which requires this call // when a cache event comes in we end up recalculating the fallback timer which requires this call
testAC.mcfg.tlsCfg.On("AutoEncryptCertNotAfter").Return(secondCert.ValidBefore).Once() testAC.mcfg.tlsCfg.On("AutoEncryptCert").Return(&x509.Certificate{
testAC.mcfg.tlsCfg.On("AutoEncryptCertExpired").Return(true).Once() NotAfter: secondCert.ValidBefore,
}).Times(2)
fallbackCtx, fallbackCancel := context.WithCancel(context.Background()) fallbackCtx, fallbackCancel := context.WithCancel(context.Background())
@ -536,7 +541,9 @@ func TestAutoEncrypt_Fallback(t *testing.T) {
testAC.mcfg.expectInitialTLS(t, "autoconf", "dc1", testAC.originalToken, secondCA, &secondRoots, thirdCert, testAC.extraCerts) testAC.mcfg.expectInitialTLS(t, "autoconf", "dc1", testAC.originalToken, secondCA, &secondRoots, thirdCert, testAC.extraCerts)
// after the second RPC we now will use the new certs validity period in the next run loop iteration // after the second RPC we now will use the new certs validity period in the next run loop iteration
testAC.mcfg.tlsCfg.On("AutoEncryptCertNotAfter").Return(time.Now().Add(10 * time.Minute)).Once() testAC.mcfg.tlsCfg.On("AutoEncryptCert").Return(&x509.Certificate{
NotAfter: time.Now().Add(10 * time.Minute),
}).Once()
// now that all the mocks are set up we can trigger the whole thing by sending the second expired cert // now that all the mocks are set up we can trigger the whole thing by sending the second expired cert
// as a cache update event. // as a cache update event.

View File

@ -2,6 +2,7 @@ package autoconf
import ( import (
"context" "context"
"crypto/x509"
"net" "net"
"time" "time"
@ -41,8 +42,7 @@ type TLSConfigurator interface {
UpdateAutoTLS(manualCAPEMs, connectCAPEMs []string, pub, priv string, verifyServerHostname bool) error UpdateAutoTLS(manualCAPEMs, connectCAPEMs []string, pub, priv string, verifyServerHostname bool) error
UpdateAutoTLSCA([]string) error UpdateAutoTLSCA([]string) error
UpdateAutoTLSCert(pub, priv string) error UpdateAutoTLSCert(pub, priv string) error
AutoEncryptCertNotAfter() time.Time AutoEncryptCert() *x509.Certificate
AutoEncryptCertExpired() bool
} }
// TokenStore is an interface of the methods we will need to use from the token.Store. // TokenStore is an interface of the methods we will need to use from the token.Store.

View File

@ -2,10 +2,12 @@ package autoconf
import ( import (
"context" "context"
"crypto/x509"
"net" "net"
"sync" "sync"
"testing" "testing"
"time"
"github.com/stretchr/testify/mock"
"github.com/hashicorp/consul/agent/cache" "github.com/hashicorp/consul/agent/cache"
cachetype "github.com/hashicorp/consul/agent/cache-types" cachetype "github.com/hashicorp/consul/agent/cache-types"
@ -15,7 +17,6 @@ import (
"github.com/hashicorp/consul/agent/token" "github.com/hashicorp/consul/agent/token"
"github.com/hashicorp/consul/proto/pbautoconf" "github.com/hashicorp/consul/proto/pbautoconf"
"github.com/hashicorp/consul/sdk/testutil" "github.com/hashicorp/consul/sdk/testutil"
"github.com/stretchr/testify/mock"
) )
type mockDirectRPC struct { type mockDirectRPC struct {
@ -72,6 +73,7 @@ func (m *mockTLSConfigurator) UpdateAutoTLSCA(pems []string) error {
ret := m.Called(pems) ret := m.Called(pems)
return ret.Error(0) return ret.Error(0)
} }
func (m *mockTLSConfigurator) UpdateAutoTLSCert(pub, priv string) error { func (m *mockTLSConfigurator) UpdateAutoTLSCert(pub, priv string) error {
if priv != "" { if priv != "" {
priv = "redacted" priv = "redacted"
@ -79,15 +81,11 @@ func (m *mockTLSConfigurator) UpdateAutoTLSCert(pub, priv string) error {
ret := m.Called(pub, priv) ret := m.Called(pub, priv)
return ret.Error(0) return ret.Error(0)
} }
func (m *mockTLSConfigurator) AutoEncryptCertNotAfter() time.Time {
ret := m.Called()
ts, _ := ret.Get(0).(time.Time)
return ts func (m *mockTLSConfigurator) AutoEncryptCert() *x509.Certificate {
}
func (m *mockTLSConfigurator) AutoEncryptCertExpired() bool {
ret := m.Called() ret := m.Called()
return ret.Bool(0) cert, _ := ret.Get(0).(*x509.Certificate)
return cert
} }
type mockServerProvider struct { type mockServerProvider struct {

View File

@ -116,14 +116,17 @@ func (ac *AutoConfig) run(ctx context.Context, exit chan struct{}) {
// expires. The agent cache should be handling the expiration // expires. The agent cache should be handling the expiration
// and renew it before then. // and renew it before then.
// //
// If there is no cert, AutoEncryptCertNotAfter returns // If there is no cert, use a value which immediately triggers the
// a value in the past which immediately triggers the
// renew, but this case shouldn't happen because at // renew, but this case shouldn't happen because at
// this point, auto_encrypt was just being setup // this point, auto_encrypt was just being setup
// successfully. // successfully.
calcFallbackInterval := func() time.Duration { calcFallbackInterval := func() time.Duration {
certExpiry := ac.acConfig.TLSConfigurator.AutoEncryptCertNotAfter() cert := ac.acConfig.TLSConfigurator.AutoEncryptCert()
return certExpiry.Add(ac.acConfig.FallbackLeeway).Sub(time.Now()) if cert == nil {
return -1
}
expiry := cert.NotAfter.Add(ac.acConfig.FallbackLeeway)
return expiry.Sub(time.Now())
} }
fallbackTimer := time.NewTimer(calcFallbackInterval()) fallbackTimer := time.NewTimer(calcFallbackInterval())
@ -174,7 +177,8 @@ func (ac *AutoConfig) run(ctx context.Context, exit chan struct{}) {
// never use the AutoEncrypt.Sign endpoint. // never use the AutoEncrypt.Sign endpoint.
// check auto encrypt client cert expiration // check auto encrypt client cert expiration
if ac.acConfig.TLSConfigurator.AutoEncryptCertExpired() { cert := ac.acConfig.TLSConfigurator.AutoEncryptCert()
if cert == nil || cert.NotAfter.Before(time.Now()) {
if err := ac.handleFallback(ctx); err != nil { if err := ac.handleFallback(ctx); err != nil {
ac.logger.Error("error when handling a certificate expiry event", "error", err) ac.logger.Error("error when handling a certificate expiry event", "error", err)
fallbackTimer = time.NewTimer(ac.acConfig.FallbackRetry) fallbackTimer = time.NewTimer(ac.acConfig.FallbackRetry)

View File

@ -796,25 +796,19 @@ func (c *Configurator) OutgoingALPNRPCWrapper() ALPNWrapper {
return c.wrapALPNTLSClient return c.wrapALPNTLSClient
} }
// AutoEncryptCertNotAfter returns NotAfter from the auto_encrypt cert. In case // AutoEncryptCert returns the TLS certificate received from auto-encrypt.
// there is no cert, it will return a time in the past. func (c *Configurator) AutoEncryptCert() *x509.Certificate {
func (c *Configurator) AutoEncryptCertNotAfter() time.Time {
c.lock.RLock() c.lock.RLock()
defer c.lock.RUnlock() defer c.lock.RUnlock()
tlsCert := c.autoTLS.cert tlsCert := c.autoTLS.cert
if tlsCert == nil || tlsCert.Certificate == nil { if tlsCert == nil || tlsCert.Certificate == nil {
return time.Now().AddDate(0, 0, -1) return nil
} }
cert, err := x509.ParseCertificate(tlsCert.Certificate[0]) cert, err := x509.ParseCertificate(tlsCert.Certificate[0])
if err != nil { if err != nil {
return time.Now().AddDate(0, 0, -1) return nil
} }
return cert.NotAfter return cert
}
// AutoEncryptCertExpired returns if the auto_encrypt cert is expired.
func (c *Configurator) AutoEncryptCertExpired() bool {
return c.AutoEncryptCertNotAfter().Before(time.Now())
} }
func (c *Configurator) log(name string) { func (c *Configurator) log(name string) {

View File

@ -1120,19 +1120,19 @@ func TestConfigurator_VerifyServerHostname(t *testing.T) {
require.True(t, c.VerifyServerHostname()) require.True(t, c.VerifyServerHostname())
} }
func TestConfigurator_AutoEncrytCertExpired(t *testing.T) { func TestConfigurator_AutoEncryptCert(t *testing.T) {
c := Configurator{base: &Config{}} c := Configurator{base: &Config{}}
require.True(t, c.AutoEncryptCertExpired()) require.Nil(t, c.AutoEncryptCert())
cert, err := loadKeyPair("../test/key/something_expired.cer", "../test/key/something_expired.key") cert, err := loadKeyPair("../test/key/something_expired.cer", "../test/key/something_expired.key")
require.NoError(t, err) require.NoError(t, err)
c.autoTLS.cert = cert c.autoTLS.cert = cert
require.True(t, c.AutoEncryptCertExpired()) require.Equal(t, int64(1561561551), c.AutoEncryptCert().NotAfter.Unix())
cert, err = loadKeyPair("../test/key/ourdomain.cer", "../test/key/ourdomain.key") cert, err = loadKeyPair("../test/key/ourdomain.cer", "../test/key/ourdomain.key")
require.NoError(t, err) require.NoError(t, err)
c.autoTLS.cert = cert c.autoTLS.cert = cert
require.False(t, c.AutoEncryptCertExpired()) require.Equal(t, int64(4679716209), c.AutoEncryptCert().NotAfter.Unix())
} }
func TestConfig_tlsVersions(t *testing.T) { func TestConfig_tlsVersions(t *testing.T) {