diff --git a/.changelog/12607.txt b/.changelog/12607.txt
new file mode 100644
index 000000000..65577d1a1
--- /dev/null
+++ b/.changelog/12607.txt
@@ -0,0 +1,3 @@
+```release-note:bug
+connect/ca: cancel old Vault renewal on CA configuration. Provide a 1 - 6 second backoff on repeated token renewal requests to prevent overwhelming Vault.
+```
\ No newline at end of file
diff --git a/agent/connect/ca/provider_vault.go b/agent/connect/ca/provider_vault.go
index 91b92528c..beec649c3 100644
--- a/agent/connect/ca/provider_vault.go
+++ b/agent/connect/ca/provider_vault.go
@@ -12,13 +12,14 @@ import (
 	"strings"
 	"time"
 
+	"github.com/hashicorp/consul/lib/decode"
+	"github.com/hashicorp/consul/lib/retry"
 	"github.com/hashicorp/go-hclog"
 	vaultapi "github.com/hashicorp/vault/api"
 	"github.com/mitchellh/mapstructure"
 
 	"github.com/hashicorp/consul/agent/connect"
 	"github.com/hashicorp/consul/agent/structs"
-	"github.com/hashicorp/consul/lib/decode"
 )
 
 const (
@@ -43,6 +44,10 @@ const (
 	VaultAuthMethodTypeUserpass = "userpass"
 
 	defaultK8SServiceAccountTokenPath = "/var/run/secrets/kubernetes.io/serviceaccount/token"
+
+	retryMin    = 1 * time.Second
+	retryMax    = 5 * time.Second
+	retryJitter = 20
 )
 
 var ErrBackendNotMounted = fmt.Errorf("backend not mounted")
@@ -52,7 +57,7 @@ type VaultProvider struct {
 	config *structs.VaultCAProviderConfig
 	client *vaultapi.Client
 
-	shutdown func()
+	stopWatcher func()
 
 	isPrimary bool
 	clusterID string
@@ -63,8 +68,8 @@
 
 func NewVaultProvider(logger hclog.Logger) *VaultProvider {
 	return &VaultProvider{
-		shutdown: func() {},
-		logger:   logger,
+		stopWatcher: func() {},
+		logger:      logger,
 	}
 }
 
@@ -153,7 +158,10 @@ func (v *VaultProvider) Configure(cfg ProviderConfig) error {
 		}
 
 		ctx, cancel := context.WithCancel(context.Background())
-		v.shutdown = cancel
+		if v.stopWatcher != nil {
+			v.stopWatcher()
+		}
+		v.stopWatcher = cancel
 
 		go v.renewToken(ctx, lifetimeWatcher)
 	}
@@ -195,16 +203,33 @@ func (v *VaultProvider) renewToken(ctx context.Context, watcher *vaultapi.Lifeti
 	go watcher.Start()
 	defer watcher.Stop()
 
+	// TODO: Once we've upgraded to a later version of protobuf we can upgrade to github.com/hashicorp/vault/api@1.1.1
+	// or later and rip this out.
+	retrier := retry.Waiter{
+		MinFailures: 5,
+		MinWait:     retryMin,
+		MaxWait:     retryMax,
+		Jitter:      retry.NewJitter(retryJitter),
+	}
+
 	for {
 		select {
 		case <-ctx.Done():
 			return
 
 		case err := <-watcher.DoneCh():
+			// In the event we fail to log in to Vault or our token is no longer valid, we can overwhelm a Vault
+			// instance that has rate limiting configured. We would make these requests to Vault as fast as we
+			// possibly could and start causing all clients to receive 429 response codes. To mitigate that, we
+			// sleep at least 1 second before logging in again and restarting the lifetime watcher. Once we can
+			// upgrade to github.com/hashicorp/vault/api@v1.1.1 or later the LifetimeWatcher _should_ perform that backoff for us.
 			if err != nil {
 				v.logger.Error("Error renewing token for Vault provider", "error", err)
 			}
 
+			// wait at least 1 second after returning from the lifetime watcher
+			retrier.Wait(ctx)
+
 			// If the watcher has exited and auth method is enabled,
 			// re-authenticate using the auth method and set up a new watcher.
 			if v.config.AuthMethod != nil {
@@ -212,7 +237,7 @@ func (v *VaultProvider) renewToken(ctx context.Context, watcher *vaultapi.Lifeti
 				loginResp, err := vaultLogin(v.client, v.config.AuthMethod)
 				if err != nil {
 					v.logger.Error("Error login in to Vault with %q auth method", v.config.AuthMethod.Type)
-					// Restart the watcher.
+					// Restart the watcher
 					go watcher.Start()
 					continue
 				}
@@ -232,11 +257,12 @@ func (v *VaultProvider) renewToken(ctx context.Context, watcher *vaultapi.Lifeti
 					continue
 				}
 			}
-			// Restart the watcher.
+
 			go watcher.Start()
 
 		case <-watcher.RenewCh():
+			retrier.Reset()
 			v.logger.Info("Successfully renewed token for Vault provider")
 		}
 	}
@@ -677,7 +703,7 @@ func (v *VaultProvider) Cleanup(providerTypeChange bool, otherConfig map[string]
 
 // Stop shuts down the token renew goroutine.
 func (v *VaultProvider) Stop() {
-	v.shutdown()
+	v.stopWatcher()
 }
 
 func (v *VaultProvider) PrimaryUsesIntermediate() {}
diff --git a/agent/connect/ca/provider_vault_test.go b/agent/connect/ca/provider_vault_test.go
index 460507383..11689ae69 100644
--- a/agent/connect/ca/provider_vault_test.go
+++ b/agent/connect/ca/provider_vault_test.go
@@ -5,6 +5,7 @@ import (
 	"encoding/json"
 	"fmt"
 	"io/ioutil"
+	"sync/atomic"
 	"testing"
 	"time"
 
@@ -212,6 +213,52 @@ func TestVaultCAProvider_RenewToken(t *testing.T) {
 	})
 }
 
+func TestVaultCAProvider_RenewTokenStopWatcherOnConfigure(t *testing.T) {
+
+	SkipIfVaultNotPresent(t)
+
+	testVault, err := runTestVault(t)
+	require.NoError(t, err)
+	testVault.WaitUntilReady(t)
+
+	// Create a token with a short TTL to be renewed by the provider.
+	ttl := 1 * time.Second
+	tcr := &vaultapi.TokenCreateRequest{
+		TTL: ttl.String(),
+	}
+	secret, err := testVault.client.Auth().Token().Create(tcr)
+	require.NoError(t, err)
+	providerToken := secret.Auth.ClientToken
+
+	provider, err := createVaultProvider(t, true, testVault.Addr, providerToken, nil)
+	require.NoError(t, err)
+
+	var gotStopped = uint32(0)
+	provider.stopWatcher = func() {
+		atomic.StoreUint32(&gotStopped, 1)
+	}
+
+	// Check the last renewal time.
+	secret, err = testVault.client.Auth().Token().Lookup(providerToken)
+	require.NoError(t, err)
+	firstRenewal, err := secret.Data["last_renewal_time"].(json.Number).Int64()
+	require.NoError(t, err)
+
+	// Wait past the TTL and make sure the token has been renewed.
+	retry.Run(t, func(r *retry.R) {
+		secret, err = testVault.client.Auth().Token().Lookup(providerToken)
+		require.NoError(r, err)
+		lastRenewal, err := secret.Data["last_renewal_time"].(json.Number).Int64()
+		require.NoError(r, err)
+		require.Greater(r, lastRenewal, firstRenewal)
+	})
+
+	providerConfig := vaultProviderConfig(t, testVault.Addr, providerToken, nil)
+
+	require.NoError(t, provider.Configure(providerConfig))
+	require.Equal(t, uint32(1), atomic.LoadUint32(&gotStopped))
+}
+
 func TestVaultCAProvider_Bootstrap(t *testing.T) {
 	SkipIfVaultNotPresent(t)
 
@@ -762,27 +809,10 @@ func testVaultProviderWithConfig(t *testing.T, isPrimary bool, rawConf map[strin
 }
 
 func createVaultProvider(t *testing.T, isPrimary bool, addr, token string, rawConf map[string]interface{}) (*VaultProvider, error) {
-	conf := map[string]interface{}{
-		"Address":             addr,
-		"Token":               token,
-		"RootPKIPath":         "pki-root/",
-		"IntermediatePKIPath": "pki-intermediate/",
-		// Tests duration parsing after msgpack type mangling during raft apply.
- "LeafCertTTL": []uint8("72h"), - } - for k, v := range rawConf { - conf[k] = v - } + cfg := vaultProviderConfig(t, addr, token, rawConf) provider := NewVaultProvider(hclog.New(nil)) - cfg := ProviderConfig{ - ClusterID: connect.TestClusterID, - Datacenter: "dc1", - IsPrimary: true, - RawConfig: conf, - } - if !isPrimary { cfg.IsPrimary = false cfg.Datacenter = "dc2" @@ -799,3 +829,26 @@ func createVaultProvider(t *testing.T, isPrimary bool, addr, token string, rawCo return provider, nil } + +func vaultProviderConfig(t *testing.T, addr, token string, rawConf map[string]interface{}) ProviderConfig { + conf := map[string]interface{}{ + "Address": addr, + "Token": token, + "RootPKIPath": "pki-root/", + "IntermediatePKIPath": "pki-intermediate/", + // Tests duration parsing after msgpack type mangling during raft apply. + "LeafCertTTL": []uint8("72h"), + } + for k, v := range rawConf { + conf[k] = v + } + + cfg := ProviderConfig{ + ClusterID: connect.TestClusterID, + Datacenter: "dc1", + IsPrimary: true, + RawConfig: conf, + } + + return cfg +}