From 369d0bbad038aad261ead51a8a0d9b30e129fc4f Mon Sep 17 00:00:00 2001 From: Jeff Mitchell Date: Tue, 26 Jan 2016 20:21:58 -0500 Subject: [PATCH] Address review feedback --- builtin/logical/transit/path_keys.go | 1 - builtin/logical/transit/policy.go | 81 ++++++++++++-------------- builtin/logical/transit/policy_test.go | 30 +++++++++- 3 files changed, 66 insertions(+), 46 deletions(-) diff --git a/builtin/logical/transit/path_keys.go b/builtin/logical/transit/path_keys.go index 8c13b0bb9..d6a85c5aa 100644 --- a/builtin/logical/transit/path_keys.go +++ b/builtin/logical/transit/path_keys.go @@ -74,7 +74,6 @@ func pathPolicyRead( "deletion_allowed": p.DeletionAllowed, "min_decryption_version": p.MinDecryptionVersion, "latest_version": p.LatestVersion, - "archive_version": p.ArchiveVersion, }, } if p.Derived { diff --git a/builtin/logical/transit/policy.go b/builtin/logical/transit/policy.go index 832d22c7a..1c334dcff 100644 --- a/builtin/logical/transit/policy.go +++ b/builtin/logical/transit/policy.go @@ -162,43 +162,44 @@ func (p *Policy) handleArchiving(storage logical.Storage) error { return err } - if keysContainsMinimum { - // Need to move keys *to* archive - - // We need a size that is equivalent to the latest version (number of - // keys) but adding one since slice numbering starts at 0 and we're - // indexing by key version - if len(archive.Keys) < p.LatestVersion+1 { - // Increase the size of the archive slice - newKeys := make([]KeyEntry, p.LatestVersion+1) - copy(newKeys, archive.Keys) - archive.Keys = newKeys - } - - // We are storing all keys in the archive, so we ensure that it is up - // to date up to p.LatestVersion - for i := p.ArchiveVersion + 1; i <= p.LatestVersion; i++ { - archive.Keys[i] = p.Keys[i] - p.ArchiveVersion = i - } - - err = p.storeArchive(archive, storage) - if err != nil { - return err - } - - // Perform deletion afterwards so that if there is an error saving we - // haven't messed with the current policy - for i := p.LatestVersion - len(p.Keys) + 1; i < p.MinDecryptionVersion; i++ { - delete(p.Keys, i) - } - - } else { + if !keysContainsMinimum { // Need to move keys *from* archive for i := p.MinDecryptionVersion; i <= p.LatestVersion; i++ { p.Keys[i] = archive.Keys[i] } + + return nil + } + + // Need to move keys *to* archive + + // We need a size that is equivalent to the latest version (number of keys) + // but adding one since slice numbering starts at 0 and we're indexing by + // key version + if len(archive.Keys) < p.LatestVersion+1 { + // Increase the size of the archive slice + newKeys := make([]KeyEntry, p.LatestVersion+1) + copy(newKeys, archive.Keys) + archive.Keys = newKeys + } + + // We are storing all keys in the archive, so we ensure that it is up to + // date up to p.LatestVersion + for i := p.ArchiveVersion + 1; i <= p.LatestVersion; i++ { + archive.Keys[i] = p.Keys[i] + p.ArchiveVersion = i + } + + err = p.storeArchive(archive, storage) + if err != nil { + return err + } + + // Perform deletion afterwards so that if there is an error saving we + // haven't messed with the current policy + for i := p.LatestVersion - len(p.Keys) + 1; i < p.MinDecryptionVersion; i++ { + delete(p.Keys, i) } return nil @@ -238,10 +239,7 @@ func (p *Policy) Serialize() ([]byte, error) { // mode is used with the context to derive the proper key. func (p *Policy) DeriveKey(context []byte, ver int) ([]byte, error) { if p.Keys == nil || p.LatestVersion == 0 { - if p.Key == nil || len(p.Key) == 0 { - return nil, certutil.InternalError{Err: "unable to access the key; no key versions found"} - } - p.migrateKeyToKeysMap() + return nil, certutil.InternalError{Err: "unable to access the key; no key versions found"} } if p.LatestVersion == 0 { @@ -398,7 +396,10 @@ func (p *Policy) Decrypt(context []byte, value string) (string, error) { func (p *Policy) rotate(storage logical.Storage) error { if p.Keys == nil { - p.migrateKeyToKeysMap() + // This is an initial key rotation when generating a new policy. We + // don't need to call migrate here because if we've called getPolicy to + // get the policy in the first place it will have been run. + p.Keys = KeyEntryMap{} } // Generate a 256bit key @@ -426,12 +427,6 @@ func (p *Policy) rotate(storage logical.Storage) error { } func (p *Policy) migrateKeyToKeysMap() { - if p.Key == nil || len(p.Key) == 0 { - p.Key = nil - p.Keys = KeyEntryMap{} - return - } - p.Keys = KeyEntryMap{ 1: KeyEntry{ Key: p.Key, diff --git a/builtin/logical/transit/policy_test.go b/builtin/logical/transit/policy_test.go index 295562f32..5a6aea853 100644 --- a/builtin/logical/transit/policy_test.go +++ b/builtin/logical/transit/policy_test.go @@ -11,6 +11,33 @@ var ( keysArchive = []KeyEntry{KeyEntry{}} ) +func Test_KeyUpgrade(t *testing.T) { + storage := &logical.InmemStorage{} + policy, err := generatePolicy(storage, "test", false) + if err != nil { + t.Fatal(err) + } + if policy == nil { + t.Fatal("nil policy") + } + + testBytes := make([]byte, len(policy.Keys[1].Key)) + copy(testBytes, policy.Keys[1].Key) + + policy.Key = policy.Keys[1].Key + policy.Keys = nil + policy.migrateKeyToKeysMap() + if policy.Key != nil { + t.Fatal("policy.Key is not nil") + } + if len(policy.Keys) != 1 { + t.Fatal("policy.Keys is the wrong size") + } + if !reflect.DeepEqual(testBytes, policy.Keys[1].Key) { + t.Fatal("key mismatch") + } +} + func Test_Archiving(t *testing.T) { // First, we generate a policy and rotate it a number of times. Each time // we'll ensure that we have the expected number of keys in the archive and @@ -18,9 +45,8 @@ func Test_Archiving(t *testing.T) { // zero and latest, respectively storage := &logical.InmemStorage{} - testName := "test" - policy, err := generatePolicy(storage, testName, false) + policy, err := generatePolicy(storage, "test", false) if err != nil { t.Fatal(err) }