backport of commit 3d37a2507bc1e54e2dc5e95c7cd099790543b3d1 (#23810)

Co-authored-by: Peter Wilson <peter.wilson@hashicorp.com>
This commit is contained in:
hc-github-team-secure-vault-core 2023-10-24 18:07:54 -04:00 committed by GitHub
parent 9d1caca4e5
commit 8cd78f723e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 96 additions and 5 deletions

View File

@ -211,6 +211,11 @@ func (b *AESGCMBarrier) Initialize(ctx context.Context, key, sealKey []byte, rea
// persistKeyring is used to write out the keyring using the
// root key to encrypt it.
func (b *AESGCMBarrier) persistKeyring(ctx context.Context, keyring *Keyring) error {
const (
// The keyring is persisted before the root key.
keyringTimeout = 1 * time.Second
)
// Create the keyring entry
keyringBuf, err := keyring.Serialize()
defer memzero(keyringBuf)
@ -221,13 +226,13 @@ func (b *AESGCMBarrier) persistKeyring(ctx context.Context, keyring *Keyring) er
// Create the AES-GCM
gcm, err := b.aeadFromKey(keyring.RootKey())
if err != nil {
return err
return fmt.Errorf("failed to retrieve AES-GCM AEAD from root key: %w", err)
}
// Encrypt the barrier init value
value, err := b.encrypt(keyringPath, initialKeyTerm, gcm, keyringBuf)
if err != nil {
return err
return fmt.Errorf("failed to encrypt barrier initial value: %w", err)
}
// Create the keyring physical entry
@ -235,7 +240,12 @@ func (b *AESGCMBarrier) persistKeyring(ctx context.Context, keyring *Keyring) er
Key: keyringPath,
Value: value,
}
if err := b.backend.Put(ctx, pe); err != nil {
// We reduce the timeout on the initial 'put' but if this succeeds we will
// allow longer later on when we try to persist the root key .
ctxKeyring, cancelKeyring := context.WithTimeout(ctx, keyringTimeout)
defer cancelKeyring()
if err := b.backend.Put(ctxKeyring, pe); err != nil {
return fmt.Errorf("failed to persist keyring: %w", err)
}
@ -255,11 +265,11 @@ func (b *AESGCMBarrier) persistKeyring(ctx context.Context, keyring *Keyring) er
activeKey := keyring.ActiveKey()
aead, err := b.aeadFromKey(activeKey.Value)
if err != nil {
return err
return fmt.Errorf("failed to retrieve AES-GCM AEAD from active key: %w", err)
}
value, err = b.encryptTracked(rootKeyPath, activeKey.Term, aead, keyBuf)
if err != nil {
return err
return fmt.Errorf("failed to encrypt and track active key value: %w", err)
}
// Update the rootKeyPath for standby instances
@ -267,6 +277,9 @@ func (b *AESGCMBarrier) persistKeyring(ctx context.Context, keyring *Keyring) er
Key: rootKeyPath,
Value: value,
}
// Use the longer timeout from the original context, for the follow-up write
// to persist the root key, as the initial storage of the keyring was successful.
if err := b.backend.Put(ctx, pe); err != nil {
return fmt.Errorf("failed to persist root key: %w", err)
}

View File

@ -12,10 +12,12 @@ import (
"time"
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/helper/testhelpers/corehelpers"
"github.com/hashicorp/vault/sdk/helper/logging"
"github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/sdk/physical"
"github.com/hashicorp/vault/sdk/physical/inmem"
"github.com/stretchr/testify/require"
)
var logger = logging.NewVaultLogger(log.Trace)
@ -694,3 +696,79 @@ func TestBarrier_LegacyRotate(t *testing.T) {
t.Fail()
}
}
// TestBarrier_persistKeyring_Context checks that we get the right errors if
// the context is cancelled or times-out before the first part of persistKeyring
// is able to persist the keyring itself (i.e. we don't go on to try and persist
// the root key).
func TestBarrier_persistKeyring_Context(t *testing.T) {
t.Parallel()
tests := map[string]struct {
shouldCancel bool
isErrorExpected bool
expectedErrorMessage string
contextTimeout time.Duration
testTimeout time.Duration
}{
"cancelled": {
shouldCancel: true,
isErrorExpected: true,
expectedErrorMessage: "failed to persist keyring: context canceled",
contextTimeout: 8 * time.Second,
testTimeout: 10 * time.Second,
},
"timeout-before-keyring": {
isErrorExpected: true,
expectedErrorMessage: "failed to persist keyring: context deadline exceeded",
contextTimeout: 1 * time.Nanosecond,
testTimeout: 5 * time.Second,
},
}
for name, tc := range tests {
name := name
tc := tc
t.Run(name, func(t *testing.T) {
t.Parallel()
// Set up barrier
backend, err := inmem.NewInmem(nil, corehelpers.NewTestLogger(t))
require.NoError(t, err)
barrier, err := NewAESGCMBarrier(backend)
require.NoError(t, err)
key, err := barrier.GenerateKey(rand.Reader)
require.NoError(t, err)
err = barrier.Initialize(context.Background(), key, nil, rand.Reader)
require.NoError(t, err)
err = barrier.Unseal(context.Background(), key)
require.NoError(t, err)
k := barrier.keyring.TermKey(1)
k.Encryptions = 0
k.InstallTime = time.Now().Add(-24 * 366 * time.Hour)
// Persist the keyring
ctx, cancel := context.WithTimeout(context.Background(), tc.contextTimeout)
persistChan := make(chan error)
go func() {
if tc.shouldCancel {
cancel()
}
persistChan <- barrier.persistKeyring(ctx, barrier.keyring)
}()
select {
case err := <-persistChan:
switch {
case tc.isErrorExpected:
require.Error(t, err)
require.EqualError(t, err, tc.expectedErrorMessage)
default:
require.NoError(t, err)
}
case <-time.After(tc.testTimeout):
t.Fatal("timeout reached")
}
})
}
}