Add unit tests
This commit is contained in:
parent
5000711a67
commit
30ffc18c19
|
@ -89,11 +89,19 @@ func pathConfigWrite(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Add this as a guard here before persisting since we now require the min
|
||||||
|
// decryption version to start at 1; even if it's not explicitly set here,
|
||||||
|
// force the upgrade
|
||||||
|
if policy.MinDecryptionVersion == 0 {
|
||||||
|
policy.MinDecryptionVersion = 1
|
||||||
|
persistNeeded = true
|
||||||
|
}
|
||||||
|
|
||||||
if !persistNeeded {
|
if !persistNeeded {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return resp, policy.Persist(req.Storage, name)
|
return resp, policy.Persist(req.Storage)
|
||||||
}
|
}
|
||||||
|
|
||||||
const pathConfigHelpSyn = `Configure a named encryption key`
|
const pathConfigHelpSyn = `Configure a named encryption key`
|
||||||
|
|
|
@ -76,23 +76,24 @@ type Policy struct {
|
||||||
// The latest key version in this policy
|
// The latest key version in this policy
|
||||||
LatestVersion int `json:"latest_version"`
|
LatestVersion int `json:"latest_version"`
|
||||||
|
|
||||||
// The latest key version in the archive. We never delete these, so this is a max.
|
// The latest key version in the archive. We never delete these, so this is
|
||||||
|
// a max.
|
||||||
ArchiveVersion int `json:"archive_version"`
|
ArchiveVersion int `json:"archive_version"`
|
||||||
|
|
||||||
// Whether the key is allowed to be deleted
|
// Whether the key is allowed to be deleted
|
||||||
DeletionAllowed bool `json:"deletion_allowed"`
|
DeletionAllowed bool `json:"deletion_allowed"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ArchivedKeys stores old keys. This is used to keep the key loading time sane when
|
// ArchivedKeys stores old keys. This is used to keep the key loading time sane
|
||||||
// there are huge numbers of rotations.
|
// when there are huge numbers of rotations.
|
||||||
type ArchivedKeys struct {
|
type ArchivedKeys struct {
|
||||||
Keys []KeyEntry `json:"keys"`
|
Keys []KeyEntry `json:"keys"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Policy) loadArchive(storage logical.Storage, name string) (*ArchivedKeys, error) {
|
func (p *Policy) loadArchive(storage logical.Storage) (*ArchivedKeys, error) {
|
||||||
archive := &ArchivedKeys{}
|
archive := &ArchivedKeys{}
|
||||||
|
|
||||||
raw, err := storage.Get("policy/" + name + "/archive")
|
raw, err := storage.Get("policy/" + p.Name + "/archive")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -108,7 +109,7 @@ func (p *Policy) loadArchive(storage logical.Storage, name string) (*ArchivedKey
|
||||||
return archive, nil
|
return archive, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Policy) storeArchive(archive *ArchivedKeys, storage logical.Storage, name string) error {
|
func (p *Policy) storeArchive(archive *ArchivedKeys, storage logical.Storage) error {
|
||||||
// Encode the policy
|
// Encode the policy
|
||||||
buf, err := json.Marshal(archive)
|
buf, err := json.Marshal(archive)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -117,7 +118,7 @@ func (p *Policy) storeArchive(archive *ArchivedKeys, storage logical.Storage, na
|
||||||
|
|
||||||
// Write the policy into storage
|
// Write the policy into storage
|
||||||
err = storage.Put(&logical.StorageEntry{
|
err = storage.Put(&logical.StorageEntry{
|
||||||
Key: "policy/" + name + "/archive",
|
Key: "policy/" + p.Name + "/archive",
|
||||||
Value: buf,
|
Value: buf,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -130,13 +131,24 @@ func (p *Policy) storeArchive(archive *ArchivedKeys, storage logical.Storage, na
|
||||||
// handleArchiving manages the movement of keys to and from the policy archive.
|
// handleArchiving manages the movement of keys to and from the policy archive.
|
||||||
// This should *ONLY* be called from Persist() since it assumes that the policy
|
// This should *ONLY* be called from Persist() since it assumes that the policy
|
||||||
// will be persisted afterwards.
|
// will be persisted afterwards.
|
||||||
func (p *Policy) handleArchiving(storage logical.Storage, name string) error {
|
func (p *Policy) handleArchiving(storage logical.Storage) error {
|
||||||
// We need to move keys that are no longer accessible to ArchivedKeys, and keys
|
// We need to move keys that are no longer accessible to ArchivedKeys, and keys
|
||||||
// that now need to be accessible back here.
|
// that now need to be accessible back here.
|
||||||
//
|
//
|
||||||
// For safety, because there isn't really a good reason to, we never delete
|
// For safety, because there isn't really a good reason to, we never delete
|
||||||
// keys from the archive even when we move them back.
|
// keys from the archive even when we move them back.
|
||||||
|
|
||||||
|
// Sanity checks
|
||||||
|
switch {
|
||||||
|
case p.MinDecryptionVersion < 1:
|
||||||
|
return fmt.Errorf("minimum decryption version of %d is less than 1", p.MinDecryptionVersion)
|
||||||
|
case p.LatestVersion < 1:
|
||||||
|
return fmt.Errorf("latest version of %d is less than 1", p.LatestVersion)
|
||||||
|
case p.MinDecryptionVersion > p.LatestVersion:
|
||||||
|
return fmt.Errorf("minimum decryption version of %d is greater than the latest version %d",
|
||||||
|
p.MinDecryptionVersion, p.LatestVersion)
|
||||||
|
}
|
||||||
|
|
||||||
// Check if we have the latest minimum version in the current set of keys
|
// Check if we have the latest minimum version in the current set of keys
|
||||||
_, keysContainsMinimum := p.Keys[p.MinDecryptionVersion]
|
_, keysContainsMinimum := p.Keys[p.MinDecryptionVersion]
|
||||||
|
|
||||||
|
@ -148,7 +160,7 @@ func (p *Policy) handleArchiving(storage logical.Storage, name string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
archive, err := p.loadArchive(storage, name)
|
archive, err := p.loadArchive(storage)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -156,11 +168,11 @@ func (p *Policy) handleArchiving(storage logical.Storage, name string) error {
|
||||||
if keysContainsMinimum {
|
if keysContainsMinimum {
|
||||||
// Need to move keys *to* archive
|
// Need to move keys *to* archive
|
||||||
|
|
||||||
if len(archive.Keys) < p.MinDecryptionVersion-1 {
|
// We need a size that is equivalent to the minimum decryption version
|
||||||
// Increase the size of the archive slice. We need a size that is
|
// minus 1, but adding one since slice numbering starts at 0 and we're
|
||||||
// equivalent to the minimum decryption version minus 1, but adding
|
// indexing by key version
|
||||||
// one since slice numbering starts at 0 and we're indexing by key
|
if len(archive.Keys) < p.MinDecryptionVersion {
|
||||||
// version
|
// Increase the size of the archive slice
|
||||||
newKeys := make([]KeyEntry, p.MinDecryptionVersion)
|
newKeys := make([]KeyEntry, p.MinDecryptionVersion)
|
||||||
copy(newKeys, archive.Keys)
|
copy(newKeys, archive.Keys)
|
||||||
archive.Keys = newKeys
|
archive.Keys = newKeys
|
||||||
|
@ -177,7 +189,7 @@ func (p *Policy) handleArchiving(storage logical.Storage, name string) error {
|
||||||
archive.Keys[i] = p.Keys[i]
|
archive.Keys[i] = p.Keys[i]
|
||||||
}
|
}
|
||||||
|
|
||||||
err = p.storeArchive(archive, storage, name)
|
err = p.storeArchive(archive, storage)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -231,8 +243,8 @@ func (p *Policy) handleArchiving(storage logical.Storage, name string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Policy) Persist(storage logical.Storage, name string) error {
|
func (p *Policy) Persist(storage logical.Storage) error {
|
||||||
err := p.handleArchiving(storage, name)
|
err := p.handleArchiving(storage)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -245,7 +257,7 @@ func (p *Policy) Persist(storage logical.Storage, name string) error {
|
||||||
|
|
||||||
// Write the policy into storage
|
// Write the policy into storage
|
||||||
err = storage.Put(&logical.StorageEntry{
|
err = storage.Put(&logical.StorageEntry{
|
||||||
Key: "policy/" + name,
|
Key: "policy/" + p.Name,
|
||||||
Value: buf,
|
Value: buf,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -449,7 +461,7 @@ func (p *Policy) rotate(storage logical.Storage) error {
|
||||||
p.MinDecryptionVersion = 1
|
p.MinDecryptionVersion = 1
|
||||||
}
|
}
|
||||||
|
|
||||||
return p.Persist(storage, p.Name)
|
return p.Persist(storage)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Policy) migrateKeyToKeysMap() {
|
func (p *Policy) migrateKeyToKeysMap() {
|
||||||
|
@ -516,7 +528,7 @@ func getPolicy(req *logical.Request, name string) (*Policy, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if persistNeeded {
|
if persistNeeded {
|
||||||
err = p.Persist(req.Storage, name)
|
err = p.Persist(req.Storage)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,153 @@
|
||||||
|
package transit
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/hashicorp/vault/logical"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
keysArchive = []KeyEntry{KeyEntry{}}
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_Archiving(t *testing.T) {
|
||||||
|
// First, we generate a policy and rotate it a number of times. Each time
|
||||||
|
// we'll ensure that we have the expected number of keys in the archive and
|
||||||
|
// the main keys object, which without changing the min version should be
|
||||||
|
// zero and latest, respectively
|
||||||
|
|
||||||
|
storage := &logical.InmemStorage{}
|
||||||
|
testName := "test"
|
||||||
|
|
||||||
|
policy, err := generatePolicy(storage, testName, false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if policy == nil {
|
||||||
|
t.Fatal("policy is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store the initial key in the archive
|
||||||
|
keysArchive = append(keysArchive, policy.Keys[1])
|
||||||
|
checkKeys(t, policy, storage, 0, 1, 1)
|
||||||
|
|
||||||
|
for i := 2; i <= 10; i++ {
|
||||||
|
err = policy.rotate(storage)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
keysArchive = append(keysArchive, policy.Keys[i])
|
||||||
|
checkKeys(t, policy, storage, 0, i, i)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Move the min decryption version up
|
||||||
|
for i := 1; i <= 10; i++ {
|
||||||
|
policy.MinDecryptionVersion = i
|
||||||
|
|
||||||
|
err = policy.Persist(storage)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
// We expect to find:
|
||||||
|
// * The keys in archive are the min decryption version - 1
|
||||||
|
// * The latest version is constant
|
||||||
|
// * The number of keys in the policy itself is from the min
|
||||||
|
// decryption version up to the latest version, so for e.g. 7 and
|
||||||
|
// 10, you'd need 7, 8, 9, and 10 -- IOW, latest version - min
|
||||||
|
// decryption version plus 1 (the min decryption version key
|
||||||
|
// itself)
|
||||||
|
checkKeys(t, policy, storage, i-1, 10, policy.LatestVersion-policy.MinDecryptionVersion+1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Move the min decryption version down
|
||||||
|
for i := 10; i >= 1; i-- {
|
||||||
|
policy.MinDecryptionVersion = i
|
||||||
|
|
||||||
|
err = policy.Persist(storage)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
// We expect to find:
|
||||||
|
// * The keys in archive are never removed so should be the previous
|
||||||
|
// min decryption version (10) minus 1, always
|
||||||
|
// * The latest version is constant
|
||||||
|
// * The number of keys in the policy itself is from the min
|
||||||
|
// decryption version up to the latest version, so for e.g. 7 and
|
||||||
|
// 10, you'd need 7, 8, 9, and 10 -- IOW, latest version - min
|
||||||
|
// decryption version plus 1 (the min decryption version key
|
||||||
|
// itself)
|
||||||
|
checkKeys(t, policy, storage, 9, 10, policy.LatestVersion-policy.MinDecryptionVersion+1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func checkKeys(t *testing.T,
|
||||||
|
policy *Policy,
|
||||||
|
storage logical.Storage,
|
||||||
|
archiveVer, latestVer, keysSize int) {
|
||||||
|
|
||||||
|
// Sanity check
|
||||||
|
if len(keysArchive) != latestVer+1 {
|
||||||
|
t.Fatalf("latest expected key version is %d, expected test keys archive size is %d, "+
|
||||||
|
"but keys archive is of size %d", latestVer, latestVer+1, len(keysArchive))
|
||||||
|
}
|
||||||
|
|
||||||
|
archive, err := policy.loadArchive(storage)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
badArchiveVer := false
|
||||||
|
if archiveVer == 0 {
|
||||||
|
if len(archive.Keys) != 0 || policy.ArchiveVersion != 0 {
|
||||||
|
badArchiveVer = true
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// We need to subtract one because we have the indexes match key
|
||||||
|
// versions, which start at 1. So for an archive version of 1, we
|
||||||
|
// actually have two entries -- a blank 0 entry, and the key at spot 1
|
||||||
|
if archiveVer != len(archive.Keys)-1 || archiveVer != policy.ArchiveVersion {
|
||||||
|
badArchiveVer = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if badArchiveVer {
|
||||||
|
t.Fatalf(
|
||||||
|
"expected archive version %d, found length of archive keys %d and policy archive version %d",
|
||||||
|
archiveVer, len(archive.Keys), policy.ArchiveVersion,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
if latestVer != policy.LatestVersion {
|
||||||
|
t.Fatalf(
|
||||||
|
"expected latest version %d, found %d",
|
||||||
|
latestVer, policy.LatestVersion,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
if keysSize != len(policy.Keys) {
|
||||||
|
t.Fatalf(
|
||||||
|
"expected keys size %d, found %d",
|
||||||
|
keysSize, len(policy.Keys),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := policy.MinDecryptionVersion; i <= policy.LatestVersion; i++ {
|
||||||
|
if _, ok := policy.Keys[i]; !ok {
|
||||||
|
t.Fatalf(
|
||||||
|
"expected key %d, did not find it in policy keys", i,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := policy.MinDecryptionVersion; i <= policy.LatestVersion; i++ {
|
||||||
|
if !reflect.DeepEqual(policy.Keys[i], keysArchive[i]) {
|
||||||
|
t.Fatalf("key %d not equivalent between policy keys and test keys archive", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 1; i < len(archive.Keys); i++ {
|
||||||
|
if !reflect.DeepEqual(archive.Keys[i].Key, keysArchive[i].Key) {
|
||||||
|
t.Fatalf("key %d not equivalent between policy archive and test keys archive", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue