tlsutil: reduce interface provided to auto-config
Replace two methods with a single one that returns the cert. This moves more of the logic into the single caller (auto-config). tlsutil.Configurator is widely used. By keeping it smaller and focused only on storing and returning TLS config, we make the code easier to follow. These two methods were more related to auto-config than to tlsutil, so reducing the interface moves the logic closer to the feature that requires it.
This commit is contained in:
parent
1bdcd3df91
commit
86c9cb037f
|
@ -2,6 +2,7 @@ package autoconf
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/x509"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net"
|
"net"
|
||||||
|
@ -577,7 +578,9 @@ func TestGoRoutineManagement(t *testing.T) {
|
||||||
mcfg.tokens.On("Notify", token.TokenKindAgent).Return(token.Notifier{}).Times(2)
|
mcfg.tokens.On("Notify", token.TokenKindAgent).Return(token.Notifier{}).Times(2)
|
||||||
mcfg.tokens.On("StopNotify", token.Notifier{}).Times(2)
|
mcfg.tokens.On("StopNotify", token.Notifier{}).Times(2)
|
||||||
|
|
||||||
mcfg.tlsCfg.On("AutoEncryptCertNotAfter").Return(time.Now().Add(10 * time.Minute)).Times(0)
|
mcfg.tlsCfg.On("AutoEncryptCert").Return(&x509.Certificate{
|
||||||
|
NotAfter: time.Now().Add(10 * time.Minute),
|
||||||
|
}).Times(0)
|
||||||
|
|
||||||
// ensure that auto-config isn't running
|
// ensure that auto-config isn't running
|
||||||
require.False(t, ac.IsRunning())
|
require.False(t, ac.IsRunning())
|
||||||
|
@ -734,7 +737,9 @@ func startedAutoConfig(t *testing.T, autoEncrypt bool) testAutoConfig {
|
||||||
|
|
||||||
indexedRoots, cert, extraCerts := mcfg.setupInitialTLS(t, "autoconf", "dc1", originalToken)
|
indexedRoots, cert, extraCerts := mcfg.setupInitialTLS(t, "autoconf", "dc1", originalToken)
|
||||||
|
|
||||||
mcfg.tlsCfg.On("AutoEncryptCertNotAfter").Return(cert.ValidBefore).Once()
|
mcfg.tlsCfg.On("AutoEncryptCert").Return(&x509.Certificate{
|
||||||
|
NotAfter: cert.ValidBefore,
|
||||||
|
}).Once()
|
||||||
|
|
||||||
populateResponse := func(args mock.Arguments) {
|
populateResponse := func(args mock.Arguments) {
|
||||||
method := args.String(3)
|
method := args.String(3)
|
||||||
|
@ -920,7 +925,9 @@ func TestRootsUpdate(t *testing.T) {
|
||||||
})
|
})
|
||||||
|
|
||||||
// when a cache event comes in we end up recalculating the fallback timer which requires this call
|
// when a cache event comes in we end up recalculating the fallback timer which requires this call
|
||||||
testAC.mcfg.tlsCfg.On("AutoEncryptCertNotAfter").Return(time.Now().Add(10 * time.Minute)).Once()
|
testAC.mcfg.tlsCfg.On("AutoEncryptCert").Return(&x509.Certificate{
|
||||||
|
NotAfter: time.Now().Add(10 * time.Minute),
|
||||||
|
}).Once()
|
||||||
|
|
||||||
req := structs.DCSpecificRequest{Datacenter: "dc1"}
|
req := structs.DCSpecificRequest{Datacenter: "dc1"}
|
||||||
require.True(t, testAC.mcfg.cache.sendNotification(context.Background(), req.CacheInfo().Key, cache.UpdateEvent{
|
require.True(t, testAC.mcfg.cache.sendNotification(context.Background(), req.CacheInfo().Key, cache.UpdateEvent{
|
||||||
|
@ -960,7 +967,9 @@ func TestCertUpdate(t *testing.T) {
|
||||||
})
|
})
|
||||||
|
|
||||||
// when a cache event comes in we end up recalculating the fallback timer which requires this call
|
// when a cache event comes in we end up recalculating the fallback timer which requires this call
|
||||||
testAC.mcfg.tlsCfg.On("AutoEncryptCertNotAfter").Return(secondCert.ValidBefore).Once()
|
testAC.mcfg.tlsCfg.On("AutoEncryptCert").Return(&x509.Certificate{
|
||||||
|
NotAfter: secondCert.ValidBefore,
|
||||||
|
}).Once()
|
||||||
|
|
||||||
req := cachetype.ConnectCALeafRequest{
|
req := cachetype.ConnectCALeafRequest{
|
||||||
Datacenter: "dc1",
|
Datacenter: "dc1",
|
||||||
|
@ -1025,8 +1034,9 @@ func TestFallback(t *testing.T) {
|
||||||
})
|
})
|
||||||
|
|
||||||
// when a cache event comes in we end up recalculating the fallback timer which requires this call
|
// when a cache event comes in we end up recalculating the fallback timer which requires this call
|
||||||
testAC.mcfg.tlsCfg.On("AutoEncryptCertNotAfter").Return(secondCert.ValidBefore).Once()
|
testAC.mcfg.tlsCfg.On("AutoEncryptCert").Return(&x509.Certificate{
|
||||||
testAC.mcfg.tlsCfg.On("AutoEncryptCertExpired").Return(true).Once()
|
NotAfter: secondCert.ValidBefore,
|
||||||
|
}).Times(2)
|
||||||
|
|
||||||
fallbackCtx, fallbackCancel := context.WithCancel(context.Background())
|
fallbackCtx, fallbackCancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
@ -1082,7 +1092,9 @@ func TestFallback(t *testing.T) {
|
||||||
testAC.mcfg.expectInitialTLS(t, "autoconf", "dc1", testAC.originalToken, secondCA, &secondRoots, thirdCert, testAC.extraCerts)
|
testAC.mcfg.expectInitialTLS(t, "autoconf", "dc1", testAC.originalToken, secondCA, &secondRoots, thirdCert, testAC.extraCerts)
|
||||||
|
|
||||||
// after the second RPC we now will use the new certs validity period in the next run loop iteration
|
// after the second RPC we now will use the new certs validity period in the next run loop iteration
|
||||||
testAC.mcfg.tlsCfg.On("AutoEncryptCertNotAfter").Return(time.Now().Add(10 * time.Minute)).Once()
|
testAC.mcfg.tlsCfg.On("AutoEncryptCert").Return(&x509.Certificate{
|
||||||
|
NotAfter: time.Now().Add(10 * time.Minute),
|
||||||
|
}).Once()
|
||||||
|
|
||||||
// now that all the mocks are set up we can trigger the whole thing by sending the second expired cert
|
// now that all the mocks are set up we can trigger the whole thing by sending the second expired cert
|
||||||
// as a cache update event.
|
// as a cache update event.
|
||||||
|
|
|
@ -406,7 +406,9 @@ func TestAutoEncrypt_RootsUpdate(t *testing.T) {
|
||||||
})
|
})
|
||||||
|
|
||||||
// when a cache event comes in we end up recalculating the fallback timer which requires this call
|
// when a cache event comes in we end up recalculating the fallback timer which requires this call
|
||||||
testAC.mcfg.tlsCfg.On("AutoEncryptCertNotAfter").Return(time.Now().Add(10 * time.Minute)).Once()
|
testAC.mcfg.tlsCfg.On("AutoEncryptCert").Return(&x509.Certificate{
|
||||||
|
NotAfter: time.Now().Add(10 * time.Minute),
|
||||||
|
}).Once()
|
||||||
|
|
||||||
req := structs.DCSpecificRequest{Datacenter: "dc1"}
|
req := structs.DCSpecificRequest{Datacenter: "dc1"}
|
||||||
require.True(t, testAC.mcfg.cache.sendNotification(context.Background(), req.CacheInfo().Key, cache.UpdateEvent{
|
require.True(t, testAC.mcfg.cache.sendNotification(context.Background(), req.CacheInfo().Key, cache.UpdateEvent{
|
||||||
|
@ -433,7 +435,9 @@ func TestAutoEncrypt_CertUpdate(t *testing.T) {
|
||||||
})
|
})
|
||||||
|
|
||||||
// when a cache event comes in we end up recalculating the fallback timer which requires this call
|
// when a cache event comes in we end up recalculating the fallback timer which requires this call
|
||||||
testAC.mcfg.tlsCfg.On("AutoEncryptCertNotAfter").Return(secondCert.ValidBefore).Once()
|
testAC.mcfg.tlsCfg.On("AutoEncryptCert").Return(&x509.Certificate{
|
||||||
|
NotAfter: secondCert.ValidBefore,
|
||||||
|
}).Once()
|
||||||
|
|
||||||
req := cachetype.ConnectCALeafRequest{
|
req := cachetype.ConnectCALeafRequest{
|
||||||
Datacenter: "dc1",
|
Datacenter: "dc1",
|
||||||
|
@ -484,8 +488,9 @@ func TestAutoEncrypt_Fallback(t *testing.T) {
|
||||||
})
|
})
|
||||||
|
|
||||||
// when a cache event comes in we end up recalculating the fallback timer which requires this call
|
// when a cache event comes in we end up recalculating the fallback timer which requires this call
|
||||||
testAC.mcfg.tlsCfg.On("AutoEncryptCertNotAfter").Return(secondCert.ValidBefore).Once()
|
testAC.mcfg.tlsCfg.On("AutoEncryptCert").Return(&x509.Certificate{
|
||||||
testAC.mcfg.tlsCfg.On("AutoEncryptCertExpired").Return(true).Once()
|
NotAfter: secondCert.ValidBefore,
|
||||||
|
}).Times(2)
|
||||||
|
|
||||||
fallbackCtx, fallbackCancel := context.WithCancel(context.Background())
|
fallbackCtx, fallbackCancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
@ -536,7 +541,9 @@ func TestAutoEncrypt_Fallback(t *testing.T) {
|
||||||
testAC.mcfg.expectInitialTLS(t, "autoconf", "dc1", testAC.originalToken, secondCA, &secondRoots, thirdCert, testAC.extraCerts)
|
testAC.mcfg.expectInitialTLS(t, "autoconf", "dc1", testAC.originalToken, secondCA, &secondRoots, thirdCert, testAC.extraCerts)
|
||||||
|
|
||||||
// after the second RPC we now will use the new certs validity period in the next run loop iteration
|
// after the second RPC we now will use the new certs validity period in the next run loop iteration
|
||||||
testAC.mcfg.tlsCfg.On("AutoEncryptCertNotAfter").Return(time.Now().Add(10 * time.Minute)).Once()
|
testAC.mcfg.tlsCfg.On("AutoEncryptCert").Return(&x509.Certificate{
|
||||||
|
NotAfter: time.Now().Add(10 * time.Minute),
|
||||||
|
}).Once()
|
||||||
|
|
||||||
// now that all the mocks are set up we can trigger the whole thing by sending the second expired cert
|
// now that all the mocks are set up we can trigger the whole thing by sending the second expired cert
|
||||||
// as a cache update event.
|
// as a cache update event.
|
||||||
|
|
|
@ -2,6 +2,7 @@ package autoconf
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/x509"
|
||||||
"net"
|
"net"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -41,8 +42,7 @@ type TLSConfigurator interface {
|
||||||
UpdateAutoTLS(manualCAPEMs, connectCAPEMs []string, pub, priv string, verifyServerHostname bool) error
|
UpdateAutoTLS(manualCAPEMs, connectCAPEMs []string, pub, priv string, verifyServerHostname bool) error
|
||||||
UpdateAutoTLSCA([]string) error
|
UpdateAutoTLSCA([]string) error
|
||||||
UpdateAutoTLSCert(pub, priv string) error
|
UpdateAutoTLSCert(pub, priv string) error
|
||||||
AutoEncryptCertNotAfter() time.Time
|
AutoEncryptCert() *x509.Certificate
|
||||||
AutoEncryptCertExpired() bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// TokenStore is an interface of the methods we will need to use from the token.Store.
|
// TokenStore is an interface of the methods we will need to use from the token.Store.
|
||||||
|
|
|
@ -2,10 +2,12 @@ package autoconf
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/x509"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
|
||||||
|
"github.com/stretchr/testify/mock"
|
||||||
|
|
||||||
"github.com/hashicorp/consul/agent/cache"
|
"github.com/hashicorp/consul/agent/cache"
|
||||||
cachetype "github.com/hashicorp/consul/agent/cache-types"
|
cachetype "github.com/hashicorp/consul/agent/cache-types"
|
||||||
|
@ -15,7 +17,6 @@ import (
|
||||||
"github.com/hashicorp/consul/agent/token"
|
"github.com/hashicorp/consul/agent/token"
|
||||||
"github.com/hashicorp/consul/proto/pbautoconf"
|
"github.com/hashicorp/consul/proto/pbautoconf"
|
||||||
"github.com/hashicorp/consul/sdk/testutil"
|
"github.com/hashicorp/consul/sdk/testutil"
|
||||||
"github.com/stretchr/testify/mock"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type mockDirectRPC struct {
|
type mockDirectRPC struct {
|
||||||
|
@ -72,6 +73,7 @@ func (m *mockTLSConfigurator) UpdateAutoTLSCA(pems []string) error {
|
||||||
ret := m.Called(pems)
|
ret := m.Called(pems)
|
||||||
return ret.Error(0)
|
return ret.Error(0)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockTLSConfigurator) UpdateAutoTLSCert(pub, priv string) error {
|
func (m *mockTLSConfigurator) UpdateAutoTLSCert(pub, priv string) error {
|
||||||
if priv != "" {
|
if priv != "" {
|
||||||
priv = "redacted"
|
priv = "redacted"
|
||||||
|
@ -79,15 +81,11 @@ func (m *mockTLSConfigurator) UpdateAutoTLSCert(pub, priv string) error {
|
||||||
ret := m.Called(pub, priv)
|
ret := m.Called(pub, priv)
|
||||||
return ret.Error(0)
|
return ret.Error(0)
|
||||||
}
|
}
|
||||||
func (m *mockTLSConfigurator) AutoEncryptCertNotAfter() time.Time {
|
|
||||||
ret := m.Called()
|
|
||||||
ts, _ := ret.Get(0).(time.Time)
|
|
||||||
|
|
||||||
return ts
|
func (m *mockTLSConfigurator) AutoEncryptCert() *x509.Certificate {
|
||||||
}
|
|
||||||
func (m *mockTLSConfigurator) AutoEncryptCertExpired() bool {
|
|
||||||
ret := m.Called()
|
ret := m.Called()
|
||||||
return ret.Bool(0)
|
cert, _ := ret.Get(0).(*x509.Certificate)
|
||||||
|
return cert
|
||||||
}
|
}
|
||||||
|
|
||||||
type mockServerProvider struct {
|
type mockServerProvider struct {
|
||||||
|
|
|
@ -116,14 +116,17 @@ func (ac *AutoConfig) run(ctx context.Context, exit chan struct{}) {
|
||||||
// expires. The agent cache should be handling the expiration
|
// expires. The agent cache should be handling the expiration
|
||||||
// and renew it before then.
|
// and renew it before then.
|
||||||
//
|
//
|
||||||
// If there is no cert, AutoEncryptCertNotAfter returns
|
// If there is no cert, use a value which immediately triggers the
|
||||||
// a value in the past which immediately triggers the
|
|
||||||
// renew, but this case shouldn't happen because at
|
// renew, but this case shouldn't happen because at
|
||||||
// this point, auto_encrypt was just being setup
|
// this point, auto_encrypt was just being setup
|
||||||
// successfully.
|
// successfully.
|
||||||
calcFallbackInterval := func() time.Duration {
|
calcFallbackInterval := func() time.Duration {
|
||||||
certExpiry := ac.acConfig.TLSConfigurator.AutoEncryptCertNotAfter()
|
cert := ac.acConfig.TLSConfigurator.AutoEncryptCert()
|
||||||
return certExpiry.Add(ac.acConfig.FallbackLeeway).Sub(time.Now())
|
if cert == nil {
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
expiry := cert.NotAfter.Add(ac.acConfig.FallbackLeeway)
|
||||||
|
return expiry.Sub(time.Now())
|
||||||
}
|
}
|
||||||
fallbackTimer := time.NewTimer(calcFallbackInterval())
|
fallbackTimer := time.NewTimer(calcFallbackInterval())
|
||||||
|
|
||||||
|
@ -174,7 +177,8 @@ func (ac *AutoConfig) run(ctx context.Context, exit chan struct{}) {
|
||||||
// never use the AutoEncrypt.Sign endpoint.
|
// never use the AutoEncrypt.Sign endpoint.
|
||||||
|
|
||||||
// check auto encrypt client cert expiration
|
// check auto encrypt client cert expiration
|
||||||
if ac.acConfig.TLSConfigurator.AutoEncryptCertExpired() {
|
cert := ac.acConfig.TLSConfigurator.AutoEncryptCert()
|
||||||
|
if cert == nil || cert.NotAfter.Before(time.Now()) {
|
||||||
if err := ac.handleFallback(ctx); err != nil {
|
if err := ac.handleFallback(ctx); err != nil {
|
||||||
ac.logger.Error("error when handling a certificate expiry event", "error", err)
|
ac.logger.Error("error when handling a certificate expiry event", "error", err)
|
||||||
fallbackTimer = time.NewTimer(ac.acConfig.FallbackRetry)
|
fallbackTimer = time.NewTimer(ac.acConfig.FallbackRetry)
|
||||||
|
|
|
@ -796,25 +796,19 @@ func (c *Configurator) OutgoingALPNRPCWrapper() ALPNWrapper {
|
||||||
return c.wrapALPNTLSClient
|
return c.wrapALPNTLSClient
|
||||||
}
|
}
|
||||||
|
|
||||||
// AutoEncryptCertNotAfter returns NotAfter from the auto_encrypt cert. In case
|
// AutoEncryptCert returns the TLS certificate received from auto-encrypt.
|
||||||
// there is no cert, it will return a time in the past.
|
func (c *Configurator) AutoEncryptCert() *x509.Certificate {
|
||||||
func (c *Configurator) AutoEncryptCertNotAfter() time.Time {
|
|
||||||
c.lock.RLock()
|
c.lock.RLock()
|
||||||
defer c.lock.RUnlock()
|
defer c.lock.RUnlock()
|
||||||
tlsCert := c.autoTLS.cert
|
tlsCert := c.autoTLS.cert
|
||||||
if tlsCert == nil || tlsCert.Certificate == nil {
|
if tlsCert == nil || tlsCert.Certificate == nil {
|
||||||
return time.Now().AddDate(0, 0, -1)
|
return nil
|
||||||
}
|
}
|
||||||
cert, err := x509.ParseCertificate(tlsCert.Certificate[0])
|
cert, err := x509.ParseCertificate(tlsCert.Certificate[0])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return time.Now().AddDate(0, 0, -1)
|
return nil
|
||||||
}
|
}
|
||||||
return cert.NotAfter
|
return cert
|
||||||
}
|
|
||||||
|
|
||||||
// AutoEncryptCertExpired returns if the auto_encrypt cert is expired.
|
|
||||||
func (c *Configurator) AutoEncryptCertExpired() bool {
|
|
||||||
return c.AutoEncryptCertNotAfter().Before(time.Now())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Configurator) log(name string) {
|
func (c *Configurator) log(name string) {
|
||||||
|
|
|
@ -1120,19 +1120,19 @@ func TestConfigurator_VerifyServerHostname(t *testing.T) {
|
||||||
require.True(t, c.VerifyServerHostname())
|
require.True(t, c.VerifyServerHostname())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConfigurator_AutoEncrytCertExpired(t *testing.T) {
|
func TestConfigurator_AutoEncryptCert(t *testing.T) {
|
||||||
c := Configurator{base: &Config{}}
|
c := Configurator{base: &Config{}}
|
||||||
require.True(t, c.AutoEncryptCertExpired())
|
require.Nil(t, c.AutoEncryptCert())
|
||||||
|
|
||||||
cert, err := loadKeyPair("../test/key/something_expired.cer", "../test/key/something_expired.key")
|
cert, err := loadKeyPair("../test/key/something_expired.cer", "../test/key/something_expired.key")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
c.autoTLS.cert = cert
|
c.autoTLS.cert = cert
|
||||||
require.True(t, c.AutoEncryptCertExpired())
|
require.Equal(t, int64(1561561551), c.AutoEncryptCert().NotAfter.Unix())
|
||||||
|
|
||||||
cert, err = loadKeyPair("../test/key/ourdomain.cer", "../test/key/ourdomain.key")
|
cert, err = loadKeyPair("../test/key/ourdomain.cer", "../test/key/ourdomain.key")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
c.autoTLS.cert = cert
|
c.autoTLS.cert = cert
|
||||||
require.False(t, c.AutoEncryptCertExpired())
|
require.Equal(t, int64(4679716209), c.AutoEncryptCert().NotAfter.Unix())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConfig_tlsVersions(t *testing.T) {
|
func TestConfig_tlsVersions(t *testing.T) {
|
||||||
|
|
Loading…
Reference in New Issue