diff --git a/agent/connect/ca/provider_vault.go b/agent/connect/ca/provider_vault.go index dc91ecbf3..4282e8ce9 100644 --- a/agent/connect/ca/provider_vault.go +++ b/agent/connect/ca/provider_vault.go @@ -2,13 +2,13 @@ package ca import ( "bytes" + "context" "crypto/x509" "encoding/pem" "fmt" "io/ioutil" "net/http" "strings" - "sync" "github.com/hashicorp/consul/agent/connect" "github.com/hashicorp/consul/agent/structs" @@ -27,9 +27,7 @@ type VaultProvider struct { config *structs.VaultCAProviderConfig client *vaultapi.Client - shutdown bool - shutdownCh chan struct{} - shutdownLock sync.RWMutex + shutdown func() isPrimary bool clusterID string @@ -38,6 +36,10 @@ type VaultProvider struct { logger hclog.Logger } +func NewVaultProvider() *VaultProvider { + return &VaultProvider{shutdown: func() {}} +} + func vaultTLSConfig(config *structs.VaultCAProviderConfig) *vaultapi.TLSConfig { return &vaultapi.TLSConfig{ CACert: config.CAFile, @@ -74,7 +76,6 @@ func (v *VaultProvider) Configure(cfg ProviderConfig) error { v.isPrimary = cfg.IsPrimary v.clusterID = cfg.ClusterID v.spiffeID = connect.SpiffeIDSigningForCluster(&structs.CAConfiguration{ClusterID: v.clusterID}) - v.shutdownCh = make(chan struct{}, 0) // Look up the token to see if we can auto-renew its lease. secret, err := client.Auth().Token().Lookup(config.Token) @@ -99,25 +100,28 @@ func (v *VaultProvider) Configure(cfg ProviderConfig) error { LeaseDuration: secret.LeaseDuration, }, }, - Increment: int(token.TTL), + Increment: token.TTL, }) if err != nil { return fmt.Errorf("Error beginning Vault provider token renewal: %v", err) } - go v.renewToken(renewer) + + ctx, cancel := context.WithCancel(context.TODO()) + v.shutdown = cancel + go v.renewToken(ctx, renewer) } return nil } // renewToken uses a vaultapi.Renewer to repeatedly renew our token's lease. -func (v *VaultProvider) renewToken(renewer *vaultapi.Renewer) { +func (v *VaultProvider) renewToken(ctx context.Context, renewer *vaultapi.Renewer) { go renewer.Renew() + defer renewer.Stop() for { select { - case <-v.shutdownCh: - renewer.Stop() + case <-ctx.Done(): return case err := <-renewer.DoneCh(): @@ -125,6 +129,9 @@ func (v *VaultProvider) renewToken(renewer *vaultapi.Renewer) { v.logger.Error(fmt.Sprintf("Error renewing token for Vault provider: %v", err)) } + // Renewer routine has finished, so start it again. + go renewer.Renew() + case <-renewer.RenewCh(): v.logger.Error("Successfully renewed token for Vault provider") } @@ -508,13 +515,7 @@ func (v *VaultProvider) Cleanup() error { // Stop shuts down the token renew goroutine. func (v *VaultProvider) Stop() { - v.shutdownLock.Lock() - defer v.shutdownLock.Unlock() - - if !v.shutdown && v.shutdownCh != nil { - close(v.shutdownCh) - v.shutdown = true - } + v.shutdown() } func ParseVaultCAConfig(raw map[string]interface{}) (*structs.VaultCAProviderConfig, error) { diff --git a/agent/connect/ca/provider_vault_test.go b/agent/connect/ca/provider_vault_test.go index 6f432f751..3094cb092 100644 --- a/agent/connect/ca/provider_vault_test.go +++ b/agent/connect/ca/provider_vault_test.go @@ -55,14 +55,10 @@ func TestVaultCAProvider_SecondaryActiveIntermediate(t *testing.T) { func TestVaultCAProvider_RenewToken(t *testing.T) { t.Parallel() - require := require.New(t) skipIfVaultNotPresent(t) - testVault, err := runTestVault() - if err != nil { - t.Fatalf("err: %v", err) - } - + testVault, err := runTestVault(t) + require.NoError(t, err) testVault.WaitUntilReady(t) // Create a token with a short TTL to be renewed by the provider. @@ -71,26 +67,26 @@ func TestVaultCAProvider_RenewToken(t *testing.T) { TTL: ttl.String(), } secret, err := testVault.client.Auth().Token().Create(tcr) - require.NoError(err) + require.NoError(t, err) providerToken := secret.Auth.ClientToken - _, err = createVaultProvider(true, testVault.addr, providerToken, nil) - require.NoError(err) + _, err = createVaultProvider(t, true, testVault.addr, providerToken, nil) + require.NoError(t, err) // Check the last renewal time. secret, err = testVault.client.Auth().Token().Lookup(providerToken) - require.NoError(err) + require.NoError(t, err) firstRenewal, err := secret.Data["last_renewal_time"].(json.Number).Int64() - require.NoError(err) - - time.Sleep(ttl * 2) + require.NoError(t, err) // Wait past the TTL and make sure the token has been renewed. - secret, err = testVault.client.Auth().Token().Lookup(providerToken) - require.NoError(err) - lastRenewal, err := secret.Data["last_renewal_time"].(json.Number).Int64() - require.NoError(err) - require.Greater(lastRenewal, firstRenewal) + retry.Run(t, func(r *retry.R) { + secret, err = testVault.client.Auth().Token().Lookup(providerToken) + require.NoError(r, err) + lastRenewal, err := secret.Data["last_renewal_time"].(json.Number).Int64() + require.NoError(r, err) + require.Greater(r, lastRenewal, firstRenewal) + }) } func TestVaultCAProvider_Bootstrap(t *testing.T) { @@ -391,14 +387,14 @@ func testVaultProvider(t *testing.T) (*VaultProvider, *testVaultServer) { } func testVaultProviderWithConfig(t *testing.T, isPrimary bool, rawConf map[string]interface{}) (*VaultProvider, *testVaultServer) { - testVault, err := runTestVault() + testVault, err := runTestVault(t) if err != nil { t.Fatalf("err: %v", err) } testVault.WaitUntilReady(t) - provider, err := createVaultProvider(isPrimary, testVault.addr, testVault.rootToken, rawConf) + provider, err := createVaultProvider(t, isPrimary, testVault.addr, testVault.rootToken, rawConf) if err != nil { testVault.Stop() t.Fatalf("err: %v", err) @@ -406,7 +402,7 @@ func testVaultProviderWithConfig(t *testing.T, isPrimary bool, rawConf map[strin return provider, testVault } -func createVaultProvider(isPrimary bool, addr, token string, rawConf map[string]interface{}) (*VaultProvider, error) { +func createVaultProvider(t *testing.T, isPrimary bool, addr, token string, rawConf map[string]interface{}) (*VaultProvider, error) { conf := map[string]interface{}{ "Address": addr, "Token": token, @@ -419,7 +415,7 @@ func createVaultProvider(isPrimary bool, addr, token string, rawConf map[string] conf[k] = v } - provider := &VaultProvider{} + provider := NewVaultProvider() cfg := ProviderConfig{ ClusterID: connect.TestClusterID, @@ -438,16 +434,11 @@ func createVaultProvider(isPrimary bool, addr, token string, rawConf map[string] cfg.Datacenter = "dc2" } - if err := provider.Configure(cfg); err != nil { - return nil, err - } + require.NoError(t, provider.Configure(cfg)) if isPrimary { - if err := provider.GenerateRoot(); err != nil { - return nil, err - } - if _, err := provider.GenerateIntermediate(); err != nil { - return nil, err - } + require.NoError(t, provider.GenerateRoot()) + _, err := provider.GenerateIntermediate() + require.NoError(t, err) } return provider, nil @@ -469,7 +460,7 @@ func skipIfVaultNotPresent(t *testing.T) { } } -func runTestVault() (*testVaultServer, error) { +func runTestVault(t *testing.T) (*testVaultServer, error) { vaultBinaryName := os.Getenv("VAULT_BINARY_NAME") if vaultBinaryName == "" { vaultBinaryName = "vault" @@ -520,13 +511,17 @@ func runTestVault() (*testVaultServer, error) { return nil, err } - return &testVaultServer{ + testVault := &testVaultServer{ rootToken: token, addr: "http://" + clientAddr, cmd: cmd, client: client, returnPortsFn: returnPortsFn, - }, nil + } + t.Cleanup(func() { + testVault.Stop() + }) + return testVault, nil } type testVaultServer struct { diff --git a/agent/consul/connect_ca_endpoint.go b/agent/consul/connect_ca_endpoint.go index 6644030db..2198c0ec6 100644 --- a/agent/consul/connect_ca_endpoint.go +++ b/agent/consul/connect_ca_endpoint.go @@ -158,7 +158,7 @@ func (s *ConnectCA) ConfigurationSet( defer func() { if cleanupNewProvider { if err := newProvider.Cleanup(); err != nil { - s.logger.Warn("failed to clean up temporary new CA provider", "provider", newProvider) + s.logger.Warn("failed to clean up CA provider while handling startup failure", "provider", newProvider, "error", err) } } }() diff --git a/agent/consul/leader_connect.go b/agent/consul/leader_connect.go index b9049ed88..602018b57 100644 --- a/agent/consul/leader_connect.go +++ b/agent/consul/leader_connect.go @@ -116,7 +116,7 @@ func (s *Server) createCAProvider(conf *structs.CAConfiguration) (ca.Provider, e case structs.ConsulCAProvider: p = &ca.ConsulProvider{Delegate: &consulCADelegate{s}} case structs.VaultCAProvider: - p = &ca.VaultProvider{} + p = ca.NewVaultProvider() case structs.AWSCAProvider: p = &ca.AWSProvider{} default: