diff --git a/vault/barrier_aes_gcm_test.go b/vault/barrier_aes_gcm_test.go index 070e5449d..8fb7f4ff3 100644 --- a/vault/barrier_aes_gcm_test.go +++ b/vault/barrier_aes_gcm_test.go @@ -32,6 +32,15 @@ func TestAESGCMBarrier_Basic(t *testing.T) { testBarrier(t, b) } +func TestAESGCMBarrier_Rotate(t *testing.T) { + inm := physical.NewInmem() + b, err := NewAESGCMBarrier(inm) + if err != nil { + t.Fatalf("err: %v", err) + } + testBarrier_Rotate(t, b) +} + // Test an upgrade from the old (0.1) barrier/init to the new // core/keyring style func TestAESGCMBarrier_BackwardsCompatible(t *testing.T) { diff --git a/vault/barrier_test.go b/vault/barrier_test.go index ddbf0f1ac..37b82cdcd 100644 --- a/vault/barrier_test.go +++ b/vault/barrier_test.go @@ -233,3 +233,75 @@ func testBarrier(t *testing.T, b SecurityBarrier) { t.Fatalf("err: %v", err) } } + +func testBarrier_Rotate(t *testing.T, b SecurityBarrier) { + // Initialize the barrier + key, _ := b.GenerateKey() + b.Initialize(key) + err := b.Unseal(key) + if err != nil { + t.Fatalf("err: %v", err) + } + + // Write a key + e1 := &Entry{Key: "test", Value: []byte("test")} + if err := b.Put(e1); err != nil { + t.Fatalf("err: %v", err) + } + + // Rotate the encryption key + err = b.Rotate() + if err != nil { + t.Fatalf("err: %v", err) + } + + // Write another key + e2 := &Entry{Key: "foo", Value: []byte("test")} + if err := b.Put(e2); err != nil { + t.Fatalf("err: %v", err) + } + + // Reading both should work + out, err := b.Get(e1.Key) + if err != nil { + t.Fatalf("err: %v", err) + } + if out == nil { + t.Fatalf("bad: %v", out) + } + + out, err = b.Get(e2.Key) + if err != nil { + t.Fatalf("err: %v", err) + } + if out == nil { + t.Fatalf("bad: %v", out) + } + + // Seal and unseal + err = b.Seal() + if err != nil { + t.Fatalf("err: %v", err) + } + err = b.Unseal(key) + if err != nil { + t.Fatalf("err: %v", err) + } + + // Reading both should work + out, err = b.Get(e1.Key) + if err != nil { + t.Fatalf("err: %v", err) + } + if out == nil { + t.Fatalf("bad: %v", out) + } + + out, err = b.Get(e2.Key) + if err != nil { + t.Fatalf("err: %v", err) + } + if out == nil { + t.Fatalf("bad: %v", out) + } +}