From 30ffc18c196ce059191e3be0355f544191ef7ab1 Mon Sep 17 00:00:00 2001 From: Jeff Mitchell Date: Tue, 26 Jan 2016 12:23:42 -0500 Subject: [PATCH] Add unit tests --- builtin/logical/transit/path_config.go | 10 +- builtin/logical/transit/policy.go | 52 +++++---- builtin/logical/transit/policy_test.go | 153 +++++++++++++++++++++++++ 3 files changed, 194 insertions(+), 21 deletions(-) create mode 100644 builtin/logical/transit/policy_test.go diff --git a/builtin/logical/transit/path_config.go b/builtin/logical/transit/path_config.go index 31438a6e6..4636181a6 100644 --- a/builtin/logical/transit/path_config.go +++ b/builtin/logical/transit/path_config.go @@ -89,11 +89,19 @@ func pathConfigWrite( } } + // Add this as a guard here before persisting since we now require the min + // decryption version to start at 1; even if it's not explicitly set here, + // force the upgrade + if policy.MinDecryptionVersion == 0 { + policy.MinDecryptionVersion = 1 + persistNeeded = true + } + if !persistNeeded { return nil, nil } - return resp, policy.Persist(req.Storage, name) + return resp, policy.Persist(req.Storage) } const pathConfigHelpSyn = `Configure a named encryption key` diff --git a/builtin/logical/transit/policy.go b/builtin/logical/transit/policy.go index db42f1e8a..ea76f1424 100644 --- a/builtin/logical/transit/policy.go +++ b/builtin/logical/transit/policy.go @@ -76,23 +76,24 @@ type Policy struct { // The latest key version in this policy LatestVersion int `json:"latest_version"` - // The latest key version in the archive. We never delete these, so this is a max. + // The latest key version in the archive. We never delete these, so this is + // a max. ArchiveVersion int `json:"archive_version"` // Whether the key is allowed to be deleted DeletionAllowed bool `json:"deletion_allowed"` } -// ArchivedKeys stores old keys. This is used to keep the key loading time sane when -// there are huge numbers of rotations. +// ArchivedKeys stores old keys. This is used to keep the key loading time sane +// when there are huge numbers of rotations. type ArchivedKeys struct { Keys []KeyEntry `json:"keys"` } -func (p *Policy) loadArchive(storage logical.Storage, name string) (*ArchivedKeys, error) { +func (p *Policy) loadArchive(storage logical.Storage) (*ArchivedKeys, error) { archive := &ArchivedKeys{} - raw, err := storage.Get("policy/" + name + "/archive") + raw, err := storage.Get("policy/" + p.Name + "/archive") if err != nil { return nil, err } @@ -108,7 +109,7 @@ func (p *Policy) loadArchive(storage logical.Storage, name string) (*ArchivedKey return archive, nil } -func (p *Policy) storeArchive(archive *ArchivedKeys, storage logical.Storage, name string) error { +func (p *Policy) storeArchive(archive *ArchivedKeys, storage logical.Storage) error { // Encode the policy buf, err := json.Marshal(archive) if err != nil { @@ -117,7 +118,7 @@ func (p *Policy) storeArchive(archive *ArchivedKeys, storage logical.Storage, na // Write the policy into storage err = storage.Put(&logical.StorageEntry{ - Key: "policy/" + name + "/archive", + Key: "policy/" + p.Name + "/archive", Value: buf, }) if err != nil { @@ -130,13 +131,24 @@ func (p *Policy) storeArchive(archive *ArchivedKeys, storage logical.Storage, na // handleArchiving manages the movement of keys to and from the policy archive. // This should *ONLY* be called from Persist() since it assumes that the policy // will be persisted afterwards. -func (p *Policy) handleArchiving(storage logical.Storage, name string) error { +func (p *Policy) handleArchiving(storage logical.Storage) error { // We need to move keys that are no longer accessible to ArchivedKeys, and keys // that now need to be accessible back here. // // For safety, because there isn't really a good reason to, we never delete // keys from the archive even when we move them back. + // Sanity checks + switch { + case p.MinDecryptionVersion < 1: + return fmt.Errorf("minimum decryption version of %d is less than 1", p.MinDecryptionVersion) + case p.LatestVersion < 1: + return fmt.Errorf("latest version of %d is less than 1", p.LatestVersion) + case p.MinDecryptionVersion > p.LatestVersion: + return fmt.Errorf("minimum decryption version of %d is greater than the latest version %d", + p.MinDecryptionVersion, p.LatestVersion) + } + // Check if we have the latest minimum version in the current set of keys _, keysContainsMinimum := p.Keys[p.MinDecryptionVersion] @@ -148,7 +160,7 @@ func (p *Policy) handleArchiving(storage logical.Storage, name string) error { return nil } - archive, err := p.loadArchive(storage, name) + archive, err := p.loadArchive(storage) if err != nil { return err } @@ -156,11 +168,11 @@ func (p *Policy) handleArchiving(storage logical.Storage, name string) error { if keysContainsMinimum { // Need to move keys *to* archive - if len(archive.Keys) < p.MinDecryptionVersion-1 { - // Increase the size of the archive slice. We need a size that is - // equivalent to the minimum decryption version minus 1, but adding - // one since slice numbering starts at 0 and we're indexing by key - // version + // We need a size that is equivalent to the minimum decryption version + // minus 1, but adding one since slice numbering starts at 0 and we're + // indexing by key version + if len(archive.Keys) < p.MinDecryptionVersion { + // Increase the size of the archive slice newKeys := make([]KeyEntry, p.MinDecryptionVersion) copy(newKeys, archive.Keys) archive.Keys = newKeys @@ -177,7 +189,7 @@ func (p *Policy) handleArchiving(storage logical.Storage, name string) error { archive.Keys[i] = p.Keys[i] } - err = p.storeArchive(archive, storage, name) + err = p.storeArchive(archive, storage) if err != nil { return err } @@ -231,8 +243,8 @@ func (p *Policy) handleArchiving(storage logical.Storage, name string) error { return nil } -func (p *Policy) Persist(storage logical.Storage, name string) error { - err := p.handleArchiving(storage, name) +func (p *Policy) Persist(storage logical.Storage) error { + err := p.handleArchiving(storage) if err != nil { return err } @@ -245,7 +257,7 @@ func (p *Policy) Persist(storage logical.Storage, name string) error { // Write the policy into storage err = storage.Put(&logical.StorageEntry{ - Key: "policy/" + name, + Key: "policy/" + p.Name, Value: buf, }) if err != nil { @@ -449,7 +461,7 @@ func (p *Policy) rotate(storage logical.Storage) error { p.MinDecryptionVersion = 1 } - return p.Persist(storage, p.Name) + return p.Persist(storage) } func (p *Policy) migrateKeyToKeysMap() { @@ -516,7 +528,7 @@ func getPolicy(req *logical.Request, name string) (*Policy, error) { } if persistNeeded { - err = p.Persist(req.Storage, name) + err = p.Persist(req.Storage) if err != nil { return nil, err } diff --git a/builtin/logical/transit/policy_test.go b/builtin/logical/transit/policy_test.go new file mode 100644 index 000000000..c9f43c7a2 --- /dev/null +++ b/builtin/logical/transit/policy_test.go @@ -0,0 +1,153 @@ +package transit + +import ( + "reflect" + "testing" + + "github.com/hashicorp/vault/logical" +) + +var ( + keysArchive = []KeyEntry{KeyEntry{}} +) + +func Test_Archiving(t *testing.T) { + // First, we generate a policy and rotate it a number of times. Each time + // we'll ensure that we have the expected number of keys in the archive and + // the main keys object, which without changing the min version should be + // zero and latest, respectively + + storage := &logical.InmemStorage{} + testName := "test" + + policy, err := generatePolicy(storage, testName, false) + if err != nil { + t.Fatal(err) + } + if policy == nil { + t.Fatal("policy is nil") + } + + // Store the initial key in the archive + keysArchive = append(keysArchive, policy.Keys[1]) + checkKeys(t, policy, storage, 0, 1, 1) + + for i := 2; i <= 10; i++ { + err = policy.rotate(storage) + if err != nil { + t.Fatal(err) + } + keysArchive = append(keysArchive, policy.Keys[i]) + checkKeys(t, policy, storage, 0, i, i) + } + + // Move the min decryption version up + for i := 1; i <= 10; i++ { + policy.MinDecryptionVersion = i + + err = policy.Persist(storage) + if err != nil { + t.Fatal(err) + } + // We expect to find: + // * The keys in archive are the min decryption version - 1 + // * The latest version is constant + // * The number of keys in the policy itself is from the min + // decryption version up to the latest version, so for e.g. 7 and + // 10, you'd need 7, 8, 9, and 10 -- IOW, latest version - min + // decryption version plus 1 (the min decryption version key + // itself) + checkKeys(t, policy, storage, i-1, 10, policy.LatestVersion-policy.MinDecryptionVersion+1) + } + + // Move the min decryption version down + for i := 10; i >= 1; i-- { + policy.MinDecryptionVersion = i + + err = policy.Persist(storage) + if err != nil { + t.Fatal(err) + } + // We expect to find: + // * The keys in archive are never removed so should be the previous + // min decryption version (10) minus 1, always + // * The latest version is constant + // * The number of keys in the policy itself is from the min + // decryption version up to the latest version, so for e.g. 7 and + // 10, you'd need 7, 8, 9, and 10 -- IOW, latest version - min + // decryption version plus 1 (the min decryption version key + // itself) + checkKeys(t, policy, storage, 9, 10, policy.LatestVersion-policy.MinDecryptionVersion+1) + } +} + +func checkKeys(t *testing.T, + policy *Policy, + storage logical.Storage, + archiveVer, latestVer, keysSize int) { + + // Sanity check + if len(keysArchive) != latestVer+1 { + t.Fatalf("latest expected key version is %d, expected test keys archive size is %d, "+ + "but keys archive is of size %d", latestVer, latestVer+1, len(keysArchive)) + } + + archive, err := policy.loadArchive(storage) + if err != nil { + t.Fatal(err) + } + + badArchiveVer := false + if archiveVer == 0 { + if len(archive.Keys) != 0 || policy.ArchiveVersion != 0 { + badArchiveVer = true + } + } else { + // We need to subtract one because we have the indexes match key + // versions, which start at 1. So for an archive version of 1, we + // actually have two entries -- a blank 0 entry, and the key at spot 1 + if archiveVer != len(archive.Keys)-1 || archiveVer != policy.ArchiveVersion { + badArchiveVer = true + } + } + if badArchiveVer { + t.Fatalf( + "expected archive version %d, found length of archive keys %d and policy archive version %d", + archiveVer, len(archive.Keys), policy.ArchiveVersion, + ) + } + + if latestVer != policy.LatestVersion { + t.Fatalf( + "expected latest version %d, found %d", + latestVer, policy.LatestVersion, + ) + } + + if keysSize != len(policy.Keys) { + t.Fatalf( + "expected keys size %d, found %d", + keysSize, len(policy.Keys), + ) + } + + for i := policy.MinDecryptionVersion; i <= policy.LatestVersion; i++ { + if _, ok := policy.Keys[i]; !ok { + t.Fatalf( + "expected key %d, did not find it in policy keys", i, + ) + } + } + + for i := policy.MinDecryptionVersion; i <= policy.LatestVersion; i++ { + if !reflect.DeepEqual(policy.Keys[i], keysArchive[i]) { + t.Fatalf("key %d not equivalent between policy keys and test keys archive", i) + } + } + + for i := 1; i < len(archive.Keys); i++ { + if !reflect.DeepEqual(archive.Keys[i].Key, keysArchive[i].Key) { + t.Fatalf("key %d not equivalent between policy archive and test keys archive", i) + } + } +}