// open-vault/vault/barrier_aes_gcm_test.go (775 lines, 18 KiB, Go)

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package vault
import (
"bytes"
"context"
"crypto/rand"
"encoding/json"
"testing"
"time"
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/helper/testhelpers/corehelpers"
"github.com/hashicorp/vault/sdk/helper/logging"
"github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/sdk/physical"
"github.com/hashicorp/vault/sdk/physical/inmem"
"github.com/stretchr/testify/require"
)
// logger is the package-level trace logger that the tests in this file pass
// to inmem.NewInmem when building their in-memory physical backends.
var (
	logger = logging.NewVaultLogger(log.Trace)
)
// mockBarrier returns an in-memory physical backend, an AES-GCM security
// barrier layered on top of it that has been initialized and unsealed, and the
// root key used for both steps. Any setup failure aborts the calling test.
func mockBarrier(t testing.TB) (physical.Backend, SecurityBarrier, []byte) {
	t.Helper()

	inm, err := inmem.NewInmem(nil, logger)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	b, err := NewAESGCMBarrier(inm)
	if err != nil {
		t.Fatalf("err: %v", err)
	}

	// Initialize and unseal. Every step is checked so a setup failure shows up
	// here instead of as a confusing failure later in the caller.
	key, err := b.GenerateKey(rand.Reader)
	if err != nil {
		t.Fatalf("err generating key: %v", err)
	}
	if err := b.Initialize(context.Background(), key, nil, rand.Reader); err != nil {
		t.Fatalf("err initializing barrier: %v", err)
	}
	if err := b.Unseal(context.Background(), key); err != nil {
		t.Fatalf("err unsealing barrier: %v", err)
	}
	return inm, b, key
}
// TestAESGCMBarrier_Basic runs the shared testBarrier suite against a freshly
// constructed AES-GCM barrier (not initialized here) over an in-memory backend.
func TestAESGCMBarrier_Basic(t *testing.T) {
	backend, err := inmem.NewInmem(nil, logger)
	if err != nil {
		t.Fatalf("err: %v", err)
	}

	barrier, err := NewAESGCMBarrier(backend)
	if err != nil {
		t.Fatalf("err: %v", err)
	}

	testBarrier(t, barrier)
}
// TestAESGCMBarrier_Rotate runs the shared testBarrier_Rotate suite against a
// freshly constructed AES-GCM barrier over an in-memory backend.
func TestAESGCMBarrier_Rotate(t *testing.T) {
	backend, err := inmem.NewInmem(nil, logger)
	if err != nil {
		t.Fatalf("err: %v", err)
	}

	barrier, err := NewAESGCMBarrier(backend)
	if err != nil {
		t.Fatalf("err: %v", err)
	}

	testBarrier_Rotate(t, barrier)
}
// TestAESGCMBarrier_MissingRotateConfig verifies that when a persisted keyring
// carries a zero-value rotation config, ReloadKeyring leaves the barrier's
// keyring with a rotation config equal to defaultRotationConfig.
func TestAESGCMBarrier_MissingRotateConfig(t *testing.T) {
	ctx := context.Background()

	inm, err := inmem.NewInmem(nil, logger)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	b, err := NewAESGCMBarrier(inm)
	if err != nil {
		t.Fatalf("err: %v", err)
	}

	// Initialize and unseal
	key, err := b.GenerateKey(rand.Reader)
	if err != nil {
		t.Fatalf("err generating key: %v", err)
	}
	if err := b.Initialize(ctx, key, nil, rand.Reader); err != nil {
		t.Fatalf("err initializing barrier: %v", err)
	}
	if err := b.Unseal(ctx, key); err != nil {
		t.Fatalf("err unsealing barrier: %v", err)
	}

	// Write a keyring which lacks rotation config settings. The persist and
	// reload errors must be checked. Otherwise the reload could read back
	// whatever keyring was already stored, and the assertion below could pass
	// without exercising the missing-config path at all.
	oldKeyring := b.keyring.Clone()
	oldKeyring.rotationConfig = KeyRotationConfig{}
	if err := b.persistKeyring(ctx, oldKeyring); err != nil {
		t.Fatalf("err persisting keyring: %v", err)
	}
	if err := b.ReloadKeyring(ctx); err != nil {
		t.Fatalf("err reloading keyring: %v", err)
	}

	// At this point, the rotation config should match the default
	if !defaultRotationConfig.Equals(b.keyring.rotationConfig) {
		t.Fatalf("expected empty rotation config to recover as default config")
	}
}
// TestAESGCMBarrier_Upgrade runs the shared testBarrier_Upgrade suite with two
// separate AES-GCM barrier instances that share one in-memory backend.
func TestAESGCMBarrier_Upgrade(t *testing.T) {
	backend, err := inmem.NewInmem(nil, logger)
	if err != nil {
		t.Fatalf("err: %v", err)
	}

	first, err := NewAESGCMBarrier(backend)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	second, err := NewAESGCMBarrier(backend)
	if err != nil {
		t.Fatalf("err: %v", err)
	}

	testBarrier_Upgrade(t, first, second)
}
// TestAESGCMBarrier_Upgrade_Rekey runs the shared testBarrier_Upgrade_Rekey
// suite with two separate AES-GCM barrier instances that share one in-memory
// backend.
func TestAESGCMBarrier_Upgrade_Rekey(t *testing.T) {
	backend, err := inmem.NewInmem(nil, logger)
	if err != nil {
		t.Fatalf("err: %v", err)
	}

	first, err := NewAESGCMBarrier(backend)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	second, err := NewAESGCMBarrier(backend)
	if err != nil {
		t.Fatalf("err: %v", err)
	}

	testBarrier_Upgrade_Rekey(t, first, second)
}
// TestAESGCMBarrier_Rekey runs the shared testBarrier_Rekey suite against a
// freshly constructed AES-GCM barrier over an in-memory backend.
func TestAESGCMBarrier_Rekey(t *testing.T) {
	backend, err := inmem.NewInmem(nil, logger)
	if err != nil {
		t.Fatalf("err: %v", err)
	}

	barrier, err := NewAESGCMBarrier(backend)
	if err != nil {
		t.Fatalf("err: %v", err)
	}

	testBarrier_Rekey(t, barrier)
}
// TestAESGCMBarrier_BackwardsCompatible tests an upgrade from the old (0.1)
// barrier/init layout to the newer core/keyring layout. It hand-writes a
// barrier/init entry (protected by a root key) plus a data entry encrypted
// with the legacy key. It then checks that Unseal removes barrierInitPath,
// writes keyringPath, and can still decrypt the legacy data entry.
func TestAESGCMBarrier_BackwardsCompatible(t *testing.T) {
	ctx := context.Background()

	inm, err := inmem.NewInmem(nil, logger)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	b, err := NewAESGCMBarrier(inm)
	if err != nil {
		t.Fatalf("err: %v", err)
	}

	// Generate a barrier/init entry
	encrypt, err := b.GenerateKey(rand.Reader)
	if err != nil {
		t.Fatalf("err generating key: %v", err)
	}
	legacyInit := &barrierInit{
		Version: 1,
		Key:     encrypt,
	}
	buf, err := json.Marshal(legacyInit)
	if err != nil {
		t.Fatalf("err marshaling barrier init: %v", err)
	}

	// Protect with master key
	master, err := b.GenerateKey(rand.Reader)
	if err != nil {
		t.Fatalf("err generating key: %v", err)
	}
	gcm, err := b.aeadFromKey(master)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	value, err := b.encrypt(barrierInitPath, initialKeyTerm, gcm, buf)
	if err != nil {
		t.Fatal(err)
	}

	// Write to the physical backend
	pe := &physical.Entry{
		Key:   barrierInitPath,
		Value: value,
	}
	if err := inm.Put(ctx, pe); err != nil {
		t.Fatalf("err: %v", err)
	}

	// Create a fake key
	gcm, err = b.aeadFromKey(encrypt)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	value, err = b.encrypt("test/foo", initialKeyTerm, gcm, []byte("test"))
	if err != nil {
		t.Fatal(err)
	}
	pe = &physical.Entry{
		Key:   "test/foo",
		Value: value,
	}
	if err := inm.Put(ctx, pe); err != nil {
		t.Fatalf("err: %v", err)
	}

	// Should still be initialized
	isInit, err := b.Initialized(ctx)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	if !isInit {
		t.Fatalf("should be initialized")
	}

	// Unseal should work and migrate online
	if err := b.Unseal(ctx, master); err != nil {
		t.Fatalf("err: %v", err)
	}

	// Check for migration
	out, err := inm.Get(ctx, barrierInitPath)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	if out != nil {
		t.Fatalf("should delete old barrier init")
	}

	// Should have keyring
	out, err = inm.Get(ctx, keyringPath)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	if out == nil {
		t.Fatalf("should have keyring file")
	}

	// Attempt to read encrypted key. Guard against a nil entry so a missing
	// key fails the test instead of panicking on entry.Value.
	entry, err := b.Get(ctx, "test/foo")
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	if entry == nil || string(entry.Value) != "test" {
		t.Fatalf("bad: %#v", entry)
	}
}
// TestAESGCMBarrier_Confidential verifies that data written through the
// barrier is stored under the same key, but with bytes that differ from the
// plaintext, in the underlying physical backend.
func TestAESGCMBarrier_Confidential(t *testing.T) {
	ctx := context.Background()

	inm, err := inmem.NewInmem(nil, logger)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	b, err := NewAESGCMBarrier(inm)
	if err != nil {
		t.Fatalf("err: %v", err)
	}

	// Initialize and unseal
	key, err := b.GenerateKey(rand.Reader)
	if err != nil {
		t.Fatalf("err generating key: %v", err)
	}
	if err := b.Initialize(ctx, key, nil, rand.Reader); err != nil {
		t.Fatalf("err initializing barrier: %v", err)
	}
	if err := b.Unseal(ctx, key); err != nil {
		t.Fatalf("err unsealing barrier: %v", err)
	}

	// Put a logical entry
	entry := &logical.StorageEntry{Key: "test", Value: []byte("test")}
	if err := b.Put(ctx, entry); err != nil {
		t.Fatalf("err: %v", err)
	}

	// Check the physical entry
	pe, err := inm.Get(ctx, "test")
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	if pe == nil {
		t.Fatalf("missing physical entry")
	}
	if pe.Key != "test" {
		t.Fatalf("bad: %#v", pe)
	}
	if bytes.Equal(pe.Value, entry.Value) {
		t.Fatalf("bad: %#v", pe)
	}
}
// TestAESGCMBarrier_Integrity verifies that data sent through the barrier
// cannot be tampered with. After one byte of the stored physical value is
// altered, reading it back through the barrier must fail.
func TestAESGCMBarrier_Integrity(t *testing.T) {
	ctx := context.Background()

	inm, err := inmem.NewInmem(nil, logger)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	b, err := NewAESGCMBarrier(inm)
	if err != nil {
		t.Fatalf("err: %v", err)
	}

	// Initialize and unseal
	key, err := b.GenerateKey(rand.Reader)
	if err != nil {
		t.Fatalf("err generating key: %v", err)
	}
	if err := b.Initialize(ctx, key, nil, rand.Reader); err != nil {
		t.Fatalf("err initializing barrier: %v", err)
	}
	if err := b.Unseal(ctx, key); err != nil {
		t.Fatalf("err unsealing barrier: %v", err)
	}

	// Put a logical entry
	entry := &logical.StorageEntry{Key: "test", Value: []byte("test")}
	if err := b.Put(ctx, entry); err != nil {
		t.Fatalf("err: %v", err)
	}

	// Change a byte in the underlying physical entry. Check the read first so
	// a missing or unexpectedly short entry fails cleanly rather than panicking
	// on the index below.
	pe, err := inm.Get(ctx, "test")
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	if pe == nil || len(pe.Value) <= 15 {
		t.Fatalf("unexpected physical entry: %#v", pe)
	}
	pe.Value[15]++
	if err := inm.Put(ctx, pe); err != nil {
		t.Fatalf("err: %v", err)
	}

	// Read from the barrier
	if _, err := b.Get(ctx, "test"); err == nil {
		t.Fatalf("should fail!")
	}
}
// TestAESGCMBarrier_MoveIntegrityV1 checks that with AESGCMVersion1 a
// ciphertext copied to a different physical key can still be read through the
// barrier under the new key. Version 1 does not reject relocated entries; see
// TestAESGCMBarrier_MoveIntegrityV2 for the version that does.
func TestAESGCMBarrier_MoveIntegrityV1(t *testing.T) {
	ctx := context.Background()

	inm, err := inmem.NewInmem(nil, logger)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	b, err := NewAESGCMBarrier(inm)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	b.currentAESGCMVersionByte = AESGCMVersion1

	// Initialize and unseal
	key, err := b.GenerateKey(rand.Reader)
	if err != nil {
		t.Fatalf("err generating key: %v", err)
	}
	if err := b.Initialize(ctx, key, nil, rand.Reader); err != nil {
		t.Fatalf("err: %v", err)
	}
	if err := b.Unseal(ctx, key); err != nil {
		t.Fatalf("err: %v", err)
	}

	// Put a logical entry
	entry := &logical.StorageEntry{Key: "test", Value: []byte("test")}
	if err := b.Put(ctx, entry); err != nil {
		t.Fatalf("err: %v", err)
	}

	// Change the location of the underlying physical entry
	pe, err := inm.Get(ctx, "test")
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	if pe == nil {
		t.Fatalf("missing physical entry")
	}
	pe.Key = "moved"
	if err := inm.Put(ctx, pe); err != nil {
		t.Fatalf("err: %v", err)
	}

	// Read from the barrier
	if _, err := b.Get(ctx, "moved"); err != nil {
		t.Fatalf("should succeed with version 1: %v", err)
	}
}
// TestAESGCMBarrier_MoveIntegrityV2 checks that with AESGCMVersion2 a
// ciphertext copied to a different physical key can no longer be read through
// the barrier under the new key.
func TestAESGCMBarrier_MoveIntegrityV2(t *testing.T) {
	ctx := context.Background()

	inm, err := inmem.NewInmem(nil, logger)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	b, err := NewAESGCMBarrier(inm)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	b.currentAESGCMVersionByte = AESGCMVersion2

	// Initialize and unseal
	key, err := b.GenerateKey(rand.Reader)
	if err != nil {
		t.Fatalf("err generating key: %v", err)
	}
	if err := b.Initialize(ctx, key, nil, rand.Reader); err != nil {
		t.Fatalf("err: %v", err)
	}
	if err := b.Unseal(ctx, key); err != nil {
		t.Fatalf("err: %v", err)
	}

	// Put a logical entry
	entry := &logical.StorageEntry{Key: "test", Value: []byte("test")}
	if err := b.Put(ctx, entry); err != nil {
		t.Fatalf("err: %v", err)
	}

	// Change the location of the underlying physical entry. Without the
	// nil check a missing entry would panic instead of failing the test.
	pe, err := inm.Get(ctx, "test")
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	if pe == nil {
		t.Fatalf("missing physical entry")
	}
	pe.Key = "moved"
	if err := inm.Put(ctx, pe); err != nil {
		t.Fatalf("err: %v", err)
	}

	// Read from the barrier
	if _, err := b.Get(ctx, "moved"); err == nil {
		t.Fatalf("should fail with version 2!")
	}
}
// TestAESGCMBarrier_UpgradeV1toV2 writes an entry with a barrier configured
// for AESGCMVersion1, seals it, then reopens the same backend with a new
// barrier configured for AESGCMVersion2. It checks that the entry is still
// readable and holds its original value.
func TestAESGCMBarrier_UpgradeV1toV2(t *testing.T) {
	ctx := context.Background()

	inm, err := inmem.NewInmem(nil, logger)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	b, err := NewAESGCMBarrier(inm)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	b.currentAESGCMVersionByte = AESGCMVersion1

	// Initialize and unseal
	key, err := b.GenerateKey(rand.Reader)
	if err != nil {
		t.Fatalf("err generating key: %v", err)
	}
	if err := b.Initialize(ctx, key, nil, rand.Reader); err != nil {
		t.Fatalf("err: %v", err)
	}
	if err := b.Unseal(ctx, key); err != nil {
		t.Fatalf("err: %v", err)
	}

	// Put a logical entry
	entry := &logical.StorageEntry{Key: "test", Value: []byte("test")}
	if err := b.Put(ctx, entry); err != nil {
		t.Fatalf("err: %v", err)
	}

	// Seal
	if err := b.Seal(); err != nil {
		t.Fatalf("err: %v", err)
	}

	// Open again as version 2
	b, err = NewAESGCMBarrier(inm)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	b.currentAESGCMVersionByte = AESGCMVersion2

	// Unseal
	if err := b.Unseal(ctx, key); err != nil {
		t.Fatalf("err: %v", err)
	}

	// Check successful decryption. A nil error alone is not enough: Get can
	// report a missing key as (nil, nil), so also check the value.
	got, err := b.Get(ctx, "test")
	if err != nil {
		t.Fatalf("Upgrade unsuccessful: %v", err)
	}
	if got == nil || string(got.Value) != "test" {
		t.Fatalf("Upgrade unsuccessful: bad entry %#v", got)
	}
}
// TestEncrypt_Unique verifies that encrypting the same plaintext twice under
// the same key term and path yields different ciphertexts. If the two outputs
// were equal, the per-encryption randomness would not be taking effect.
func TestEncrypt_Unique(t *testing.T) {
	ctx := context.Background()

	inm, err := inmem.NewInmem(nil, logger)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	b, err := NewAESGCMBarrier(inm)
	if err != nil {
		t.Fatalf("err: %v", err)
	}

	key, err := b.GenerateKey(rand.Reader)
	if err != nil {
		t.Fatalf("err generating key: %v", err)
	}
	if err := b.Initialize(ctx, key, nil, rand.Reader); err != nil {
		t.Fatalf("err initializing barrier: %v", err)
	}
	if err := b.Unseal(ctx, key); err != nil {
		t.Fatalf("err unsealing barrier: %v", err)
	}
	if b.keyring == nil {
		t.Fatalf("barrier is sealed")
	}

	plaintext := []byte("test")
	term := b.keyring.ActiveTerm()
	primary, err := b.aeadForTerm(term)
	if err != nil {
		t.Fatalf("err: %v", err)
	}

	first, err := b.encrypt("test", term, primary, plaintext)
	if err != nil {
		t.Fatal(err)
	}
	second, err := b.encrypt("test", term, primary, plaintext)
	if err != nil {
		t.Fatal(err)
	}

	if bytes.Equal(first, second) {
		t.Fatalf("improper random seeding detected")
	}
}
// TestInitialize_KeyLength checks that Initialize rejects root keys of 33, 23
// and 3 bytes, trying them in that order. The test stops at the first key
// that is wrongly accepted.
func TestInitialize_KeyLength(t *testing.T) {
	backend, err := inmem.NewInmem(nil, logger)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	barrier, err := NewAESGCMBarrier(backend)
	if err != nil {
		t.Fatalf("err: %v", err)
	}

	badKeys := [][]byte{
		[]byte("ThisKeyDoesNotHaveTheRightLength!"), // 33 bytes
		[]byte("ThisIsASecretKeyAndMore"),           // 23 bytes
		[]byte("Key"),                               // 3 bytes
	}
	for _, candidate := range badKeys {
		if initErr := barrier.Initialize(context.Background(), candidate, nil, rand.Reader); initErr == nil {
			t.Fatalf("key length protection failed")
		}
	}
}
// TestEncrypt_BarrierEncryptor round-trips a plaintext through the barrier's
// Encrypt and Decrypt methods under the same path and checks that the
// original bytes come back.
func TestEncrypt_BarrierEncryptor(t *testing.T) {
	inm, err := inmem.NewInmem(nil, logger)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	b, err := NewAESGCMBarrier(inm)
	if err != nil {
		t.Fatalf("err: %v", err)
	}

	// Initialize and unseal
	key, err := b.GenerateKey(rand.Reader)
	if err != nil {
		t.Fatalf("err generating key: %v", err)
	}
	ctx := context.Background()
	if err := b.Initialize(ctx, key, nil, rand.Reader); err != nil {
		t.Fatalf("err initializing barrier: %v", err)
	}
	if err := b.Unseal(ctx, key); err != nil {
		t.Fatalf("err unsealing barrier: %v", err)
	}

	cipher, err := b.Encrypt(ctx, "foo", []byte("quick brown fox"))
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	plain, err := b.Decrypt(ctx, "foo", cipher)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	if string(plain) != "quick brown fox" {
		t.Fatalf("bad: %s", plain)
	}
}
// TestDecrypt_InvalidCipherLength ensures Decrypt returns an error (rather than
// panicking) when given a ciphertext that is nil, empty, or too short to
// contain the key term.
func TestDecrypt_InvalidCipherLength(t *testing.T) {
	inm, err := inmem.NewInmem(nil, logger)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	b, err := NewAESGCMBarrier(inm)
	if err != nil {
		t.Fatalf("err: %v", err)
	}

	key, err := b.GenerateKey(rand.Reader)
	if err != nil {
		t.Fatalf("err generating key: %v", err)
	}
	ctx := context.Background()
	// Unseal must succeed. Otherwise the errors below could come from a
	// sealed barrier rather than from the length validation under test.
	if err := b.Initialize(ctx, key, nil, rand.Reader); err != nil {
		t.Fatalf("err initializing barrier: %v", err)
	}
	if err := b.Unseal(ctx, key); err != nil {
		t.Fatalf("err unsealing barrier: %v", err)
	}

	var nilCipher []byte
	if _, err = b.Decrypt(ctx, "", nilCipher); err == nil {
		t.Fatal("expected error when given nil cipher")
	}

	emptyCipher := []byte{}
	if _, err = b.Decrypt(ctx, "", emptyCipher); err == nil {
		t.Fatal("expected error when given empty cipher")
	}

	// NOTE(review): 3 bytes is presumably one short of the encoded term
	// width; confirm against the barrier's ciphertext layout.
	badTermLengthCipher := make([]byte, 3)
	if _, err = b.Decrypt(ctx, "", badTermLengthCipher); err == nil {
		t.Fatal("expected error when given cipher with too short term")
	}
}
// TestAESGCMBarrier_ReloadKeyring verifies two cases for ReloadKeyring. First,
// it picks up a rotation (active term 2) written to storage by a second
// barrier instance. Second, it picks up a rollback to the originally stored
// term-1 keyring. The barrier's cache must be empty after each reload.
func TestAESGCMBarrier_ReloadKeyring(t *testing.T) {
	ctx := context.Background()

	inm, err := inmem.NewInmem(nil, logger)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	b, err := NewAESGCMBarrier(inm)
	if err != nil {
		t.Fatalf("err: %v", err)
	}

	// Initialize and unseal
	key, err := b.GenerateKey(rand.Reader)
	if err != nil {
		t.Fatalf("err generating key: %v", err)
	}
	if err := b.Initialize(ctx, key, nil, rand.Reader); err != nil {
		t.Fatalf("err initializing barrier: %v", err)
	}
	if err := b.Unseal(ctx, key); err != nil {
		t.Fatalf("err unsealing barrier: %v", err)
	}

	// Snapshot the stored keyring so it can be restored later. A nil snapshot
	// would make the rollback step below meaningless.
	keyringRaw, err := inm.Get(ctx, keyringPath)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	if keyringRaw == nil {
		t.Fatal("missing stored keyring")
	}

	// Encrypt something to test cache invalidation
	if _, err := b.Encrypt(ctx, "foo", []byte("quick brown fox")); err != nil {
		t.Fatalf("err: %v", err)
	}

	{
		// Create a second barrier and rotate the keyring
		b2, err := NewAESGCMBarrier(inm)
		if err != nil {
			t.Fatalf("err: %v", err)
		}
		if err = b2.Unseal(ctx, key); err != nil {
			t.Fatalf("err: %v", err)
		}
		if _, err = b2.Rotate(ctx, rand.Reader); err != nil {
			t.Fatalf("err: %v", err)
		}
	}

	// Reload the keyring on the first
	if err := b.ReloadKeyring(ctx); err != nil {
		t.Fatalf("err: %v", err)
	}
	if b.keyring.ActiveTerm() != 2 {
		t.Fatal("failed to reload keyring")
	}
	if len(b.cache) != 0 {
		t.Fatal("failed to clear cache")
	}

	// Encrypt something to test cache invalidation
	if _, err := b.Encrypt(ctx, "foo", []byte("quick brown fox")); err != nil {
		t.Fatalf("err: %v", err)
	}

	// Restore old keyring to test rolling back
	if err := inm.Put(ctx, keyringRaw); err != nil {
		t.Fatalf("err: %v", err)
	}

	// Reload the keyring on the first
	if err := b.ReloadKeyring(ctx); err != nil {
		t.Fatalf("err: %v", err)
	}
	if b.keyring.ActiveTerm() != 1 {
		t.Fatal("failed to reload keyring")
	}
	if len(b.cache) != 0 {
		t.Fatal("failed to clear cache")
	}
}
// TestBarrier_LegacyRotate backdates term 1's install time by 366 days and
// zeroes its encryption count, persists that keyring, and seals and unseals
// the barrier. It then expects CheckBarrierAutoRotate to report
// legacyRotateReason with no error.
func TestBarrier_LegacyRotate(t *testing.T) {
	ctx := context.Background()

	inm, err := inmem.NewInmem(nil, logger)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	b1, err := NewAESGCMBarrier(inm)
	if err != nil {
		t.Fatalf("err: %v", err)
	}

	// Initialize the barrier
	key, err := b1.GenerateKey(rand.Reader)
	if err != nil {
		t.Fatalf("err generating key: %v", err)
	}
	if err := b1.Initialize(ctx, key, nil, rand.Reader); err != nil {
		t.Fatalf("err initializing barrier: %v", err)
	}
	if err := b1.Unseal(ctx, key); err != nil {
		t.Fatalf("err: %v", err)
	}

	// NOTE(review): this relies on TermKey returning a pointer into the live
	// keyring, so these edits are what persistKeyring writes out; confirm.
	k1 := b1.keyring.TermKey(1)
	k1.Encryptions = 0
	k1.InstallTime = time.Now().Add(-24 * 366 * time.Hour)
	if err := b1.persistKeyring(ctx, b1.keyring); err != nil {
		t.Fatalf("err persisting keyring: %v", err)
	}
	if err := b1.Seal(); err != nil {
		t.Fatalf("err sealing barrier: %v", err)
	}
	if err := b1.Unseal(ctx, key); err != nil {
		t.Fatalf("err: %v", err)
	}

	reason, err := b1.CheckBarrierAutoRotate(ctx)
	if err != nil {
		t.Fatalf("err checking auto-rotate: %v", err)
	}
	if reason != legacyRotateReason {
		t.Fatalf("expected rotate reason %v, got %v", legacyRotateReason, reason)
	}
}
// TestBarrier_persistKeyring_Context checks that we get the right errors if
// the context is cancelled or times out before the first part of
// persistKeyring is able to persist the keyring itself (i.e. we don't go on to
// try and persist the root key).
func TestBarrier_persistKeyring_Context(t *testing.T) {
	t.Parallel()

	tests := map[string]struct {
		shouldCancel         bool
		isErrorExpected      bool
		expectedErrorMessage string
		contextTimeout       time.Duration
		testTimeout          time.Duration
	}{
		"cancelled": {
			shouldCancel:         true,
			isErrorExpected:      true,
			expectedErrorMessage: "failed to persist keyring: context canceled",
			contextTimeout:       8 * time.Second,
			testTimeout:          10 * time.Second,
		},
		"timeout-before-keyring": {
			isErrorExpected:      true,
			expectedErrorMessage: "failed to persist keyring: context deadline exceeded",
			contextTimeout:       1 * time.Nanosecond,
			testTimeout:          5 * time.Second,
		},
	}

	for name, tc := range tests {
		name := name
		tc := tc
		t.Run(name, func(t *testing.T) {
			t.Parallel()

			// Set up barrier
			backend, err := inmem.NewInmem(nil, corehelpers.NewTestLogger(t))
			require.NoError(t, err)
			barrier, err := NewAESGCMBarrier(backend)
			require.NoError(t, err)
			key, err := barrier.GenerateKey(rand.Reader)
			require.NoError(t, err)
			err = barrier.Initialize(context.Background(), key, nil, rand.Reader)
			require.NoError(t, err)
			err = barrier.Unseal(context.Background(), key)
			require.NoError(t, err)

			k := barrier.keyring.TermKey(1)
			k.Encryptions = 0
			k.InstallTime = time.Now().Add(-24 * 366 * time.Hour)

			// Persist the keyring
			ctx, cancel := context.WithTimeout(context.Background(), tc.contextTimeout)
			// Always release the context's resources. Cases that never call
			// cancel explicitly would otherwise leak it.
			defer cancel()

			// Buffered so the goroutine can always deliver its result and exit,
			// even if the select below has already given up on the timeout.
			persistChan := make(chan error, 1)
			go func() {
				if tc.shouldCancel {
					cancel()
				}
				persistChan <- barrier.persistKeyring(ctx, barrier.keyring)
			}()

			select {
			case err := <-persistChan:
				if tc.isErrorExpected {
					// EqualError also fails when err is nil.
					require.EqualError(t, err, tc.expectedErrorMessage)
				} else {
					require.NoError(t, err)
				}
			case <-time.After(tc.testTimeout):
				t.Fatal("timeout reached")
			}
		})
	}
}