diff --git a/vault/barrier_aes_gcm.go b/vault/barrier_aes_gcm.go index 2d7cbef42..5627e7cf8 100644 --- a/vault/barrier_aes_gcm.go +++ b/vault/barrier_aes_gcm.go @@ -211,6 +211,11 @@ func (b *AESGCMBarrier) Initialize(ctx context.Context, key, sealKey []byte, rea // persistKeyring is used to write out the keyring using the // root key to encrypt it. func (b *AESGCMBarrier) persistKeyring(ctx context.Context, keyring *Keyring) error { + const ( + // The keyring is persisted before the root key. + keyringTimeout = 1 * time.Second + ) + // Create the keyring entry keyringBuf, err := keyring.Serialize() defer memzero(keyringBuf) @@ -221,13 +226,13 @@ func (b *AESGCMBarrier) persistKeyring(ctx context.Context, keyring *Keyring) er // Create the AES-GCM gcm, err := b.aeadFromKey(keyring.RootKey()) if err != nil { - return err + return fmt.Errorf("failed to retrieve AES-GCM AEAD from root key: %w", err) } // Encrypt the barrier init value value, err := b.encrypt(keyringPath, initialKeyTerm, gcm, keyringBuf) if err != nil { - return err + return fmt.Errorf("failed to encrypt barrier initial value: %w", err) } // Create the keyring physical entry @@ -235,7 +240,12 @@ func (b *AESGCMBarrier) persistKeyring(ctx context.Context, keyring *Keyring) er Key: keyringPath, Value: value, } - if err := b.backend.Put(ctx, pe); err != nil { + + // We reduce the timeout on the initial 'put' but if this succeeds we will + // allow longer later on when we try to persist the root key . 
+ ctxKeyring, cancelKeyring := context.WithTimeout(ctx, keyringTimeout) + defer cancelKeyring() + if err := b.backend.Put(ctxKeyring, pe); err != nil { return fmt.Errorf("failed to persist keyring: %w", err) } @@ -255,11 +265,11 @@ func (b *AESGCMBarrier) persistKeyring(ctx context.Context, keyring *Keyring) er activeKey := keyring.ActiveKey() aead, err := b.aeadFromKey(activeKey.Value) if err != nil { - return err + return fmt.Errorf("failed to retrieve AES-GCM AEAD from active key: %w", err) } value, err = b.encryptTracked(rootKeyPath, activeKey.Term, aead, keyBuf) if err != nil { - return err + return fmt.Errorf("failed to encrypt and track active key value: %w", err) } // Update the rootKeyPath for standby instances @@ -267,6 +277,9 @@ func (b *AESGCMBarrier) persistKeyring(ctx context.Context, keyring *Keyring) er Key: rootKeyPath, Value: value, } + + // Use the longer timeout from the original context, for the follow-up write + // to persist the root key, as the initial storage of the keyring was successful. 
if err := b.backend.Put(ctx, pe); err != nil { return fmt.Errorf("failed to persist root key: %w", err) } diff --git a/vault/barrier_aes_gcm_test.go b/vault/barrier_aes_gcm_test.go index bdc9250ff..4339be466 100644 --- a/vault/barrier_aes_gcm_test.go +++ b/vault/barrier_aes_gcm_test.go @@ -12,10 +12,12 @@ import ( "time" log "github.com/hashicorp/go-hclog" + "github.com/hashicorp/vault/helper/testhelpers/corehelpers" "github.com/hashicorp/vault/sdk/helper/logging" "github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/sdk/physical" "github.com/hashicorp/vault/sdk/physical/inmem" + "github.com/stretchr/testify/require" ) var logger = logging.NewVaultLogger(log.Trace) @@ -694,3 +696,79 @@ func TestBarrier_LegacyRotate(t *testing.T) { t.Fail() } } + +// TestBarrier_persistKeyring_Context checks that we get the right errors if +// the context is cancelled or times-out before the first part of persistKeyring +// is able to persist the keyring itself (i.e. we don't go on to try and persist +// the root key). 
+func TestBarrier_persistKeyring_Context(t *testing.T) {
+	t.Parallel()
+
+	// Each case describes the context handed to persistKeyring and the exact
+	// error text that is expected back from it.
+	tests := map[string]struct {
+		// shouldCancel cancels the context before persistKeyring is invoked.
+		shouldCancel bool
+		// isErrorExpected selects between asserting an error and asserting success.
+		isErrorExpected bool
+		// expectedErrorMessage is the complete error string, compared for equality.
+		expectedErrorMessage string
+		// contextTimeout is the deadline of the context passed to persistKeyring.
+		contextTimeout time.Duration
+		// testTimeout bounds how long the test waits for persistKeyring to return.
+		testTimeout time.Duration
+	}{
+		"cancelled": {
+			shouldCancel:         true,
+			isErrorExpected:      true,
+			expectedErrorMessage: "failed to persist keyring: context canceled",
+			contextTimeout:       8 * time.Second,
+			testTimeout:          10 * time.Second,
+		},
+		"timeout-before-keyring": {
+			isErrorExpected:      true,
+			expectedErrorMessage: "failed to persist keyring: context deadline exceeded",
+			contextTimeout:       1 * time.Nanosecond,
+			testTimeout:          5 * time.Second,
+		},
+	}
+
+	for name, tc := range tests {
+		// Per-iteration copies for the parallel subtest closure below.
+		name := name
+		tc := tc
+		t.Run(name, func(t *testing.T) {
+			t.Parallel()
+
+			// Build an initialized, unsealed barrier on an in-memory backend.
+			backend, err := inmem.NewInmem(nil, corehelpers.NewTestLogger(t))
+			require.NoError(t, err)
+			barrier, err := NewAESGCMBarrier(backend)
+			require.NoError(t, err)
+			key, err := barrier.GenerateKey(rand.Reader)
+			require.NoError(t, err)
+			err = barrier.Initialize(context.Background(), key, nil, rand.Reader)
+			require.NoError(t, err)
+			err = barrier.Unseal(context.Background(), key)
+			require.NoError(t, err)
+
+			// Zero the encryption counter of the term 1 key and back-date its
+			// install time by 366 days.
+			k := barrier.keyring.TermKey(1)
+			k.Encryptions = 0
+			k.InstallTime = time.Now().Add(-24 * 366 * time.Hour)
+
+			ctx, cancel := context.WithTimeout(context.Background(), tc.contextTimeout)
+			// Always release the context, including in the cases that never
+			// cancel explicitly. A second call to cancel does nothing.
+			defer cancel()
+
+			// Buffered by one so the goroutine can deliver its result and exit
+			// even when the select below has already stopped waiting.
+			persistChan := make(chan error, 1)
+			go func() {
+				if tc.shouldCancel {
+					cancel()
+				}
+				persistChan <- barrier.persistKeyring(ctx, barrier.keyring)
+			}()
+
+			select {
+			case err := <-persistChan:
+				if !tc.isErrorExpected {
+					require.NoError(t, err)
+					return
+				}
+				// EqualError also fails on a nil error, so no separate
+				// require.Error is needed.
+				require.EqualError(t, err, tc.expectedErrorMessage)
+			case <-time.After(tc.testTimeout):
+				t.Fatal("timeout reached")
+			}
+		})
+	}
+}