Address review feedback

This commit is contained in:
Jeff Mitchell 2016-01-26 20:21:58 -05:00
parent e5a58109ec
commit 369d0bbad0
3 changed files with 66 additions and 46 deletions

View file

@ -74,7 +74,6 @@ func pathPolicyRead(
"deletion_allowed": p.DeletionAllowed, "deletion_allowed": p.DeletionAllowed,
"min_decryption_version": p.MinDecryptionVersion, "min_decryption_version": p.MinDecryptionVersion,
"latest_version": p.LatestVersion, "latest_version": p.LatestVersion,
"archive_version": p.ArchiveVersion,
}, },
} }
if p.Derived { if p.Derived {

View file

@ -162,43 +162,44 @@ func (p *Policy) handleArchiving(storage logical.Storage) error {
return err return err
} }
if keysContainsMinimum { if !keysContainsMinimum {
// Need to move keys *to* archive
// We need a size that is equivalent to the latest version (number of
// keys) but adding one since slice numbering starts at 0 and we're
// indexing by key version
if len(archive.Keys) < p.LatestVersion+1 {
// Increase the size of the archive slice
newKeys := make([]KeyEntry, p.LatestVersion+1)
copy(newKeys, archive.Keys)
archive.Keys = newKeys
}
// We are storing all keys in the archive, so we ensure that it is up
// to date up to p.LatestVersion
for i := p.ArchiveVersion + 1; i <= p.LatestVersion; i++ {
archive.Keys[i] = p.Keys[i]
p.ArchiveVersion = i
}
err = p.storeArchive(archive, storage)
if err != nil {
return err
}
// Perform deletion afterwards so that if there is an error saving we
// haven't messed with the current policy
for i := p.LatestVersion - len(p.Keys) + 1; i < p.MinDecryptionVersion; i++ {
delete(p.Keys, i)
}
} else {
// Need to move keys *from* archive // Need to move keys *from* archive
for i := p.MinDecryptionVersion; i <= p.LatestVersion; i++ { for i := p.MinDecryptionVersion; i <= p.LatestVersion; i++ {
p.Keys[i] = archive.Keys[i] p.Keys[i] = archive.Keys[i]
} }
return nil
}
// Need to move keys *to* archive
// We need a size that is equivalent to the latest version (number of keys)
// but adding one since slice numbering starts at 0 and we're indexing by
// key version
if len(archive.Keys) < p.LatestVersion+1 {
// Increase the size of the archive slice
newKeys := make([]KeyEntry, p.LatestVersion+1)
copy(newKeys, archive.Keys)
archive.Keys = newKeys
}
// We are storing all keys in the archive, so we ensure that it is up to
// date up to p.LatestVersion
for i := p.ArchiveVersion + 1; i <= p.LatestVersion; i++ {
archive.Keys[i] = p.Keys[i]
p.ArchiveVersion = i
}
err = p.storeArchive(archive, storage)
if err != nil {
return err
}
// Perform deletion afterwards so that if there is an error saving we
// haven't messed with the current policy
for i := p.LatestVersion - len(p.Keys) + 1; i < p.MinDecryptionVersion; i++ {
delete(p.Keys, i)
} }
return nil return nil
@ -238,10 +239,7 @@ func (p *Policy) Serialize() ([]byte, error) {
// mode is used with the context to derive the proper key. // mode is used with the context to derive the proper key.
func (p *Policy) DeriveKey(context []byte, ver int) ([]byte, error) { func (p *Policy) DeriveKey(context []byte, ver int) ([]byte, error) {
if p.Keys == nil || p.LatestVersion == 0 { if p.Keys == nil || p.LatestVersion == 0 {
if p.Key == nil || len(p.Key) == 0 { return nil, certutil.InternalError{Err: "unable to access the key; no key versions found"}
return nil, certutil.InternalError{Err: "unable to access the key; no key versions found"}
}
p.migrateKeyToKeysMap()
} }
if p.LatestVersion == 0 { if p.LatestVersion == 0 {
@ -398,7 +396,10 @@ func (p *Policy) Decrypt(context []byte, value string) (string, error) {
func (p *Policy) rotate(storage logical.Storage) error { func (p *Policy) rotate(storage logical.Storage) error {
if p.Keys == nil { if p.Keys == nil {
p.migrateKeyToKeysMap() // This is an initial key rotation when generating a new policy. We
// don't need to call migrate here because if we've called getPolicy to
// get the policy in the first place it will have been run.
p.Keys = KeyEntryMap{}
} }
// Generate a 256bit key // Generate a 256bit key
@ -426,12 +427,6 @@ func (p *Policy) rotate(storage logical.Storage) error {
} }
func (p *Policy) migrateKeyToKeysMap() { func (p *Policy) migrateKeyToKeysMap() {
if p.Key == nil || len(p.Key) == 0 {
p.Key = nil
p.Keys = KeyEntryMap{}
return
}
p.Keys = KeyEntryMap{ p.Keys = KeyEntryMap{
1: KeyEntry{ 1: KeyEntry{
Key: p.Key, Key: p.Key,

View file

@ -11,6 +11,33 @@ var (
keysArchive = []KeyEntry{KeyEntry{}} keysArchive = []KeyEntry{KeyEntry{}}
) )
func Test_KeyUpgrade(t *testing.T) {
storage := &logical.InmemStorage{}
policy, err := generatePolicy(storage, "test", false)
if err != nil {
t.Fatal(err)
}
if policy == nil {
t.Fatal("nil policy")
}
testBytes := make([]byte, len(policy.Keys[1].Key))
copy(testBytes, policy.Keys[1].Key)
policy.Key = policy.Keys[1].Key
policy.Keys = nil
policy.migrateKeyToKeysMap()
if policy.Key != nil {
t.Fatal("policy.Key is not nil")
}
if len(policy.Keys) != 1 {
t.Fatal("policy.Keys is the wrong size")
}
if !reflect.DeepEqual(testBytes, policy.Keys[1].Key) {
t.Fatal("key mismatch")
}
}
func Test_Archiving(t *testing.T) { func Test_Archiving(t *testing.T) {
// First, we generate a policy and rotate it a number of times. Each time // First, we generate a policy and rotate it a number of times. Each time
// we'll ensure that we have the expected number of keys in the archive and // we'll ensure that we have the expected number of keys in the archive and
@ -18,9 +45,8 @@ func Test_Archiving(t *testing.T) {
// zero and latest, respectively // zero and latest, respectively
storage := &logical.InmemStorage{} storage := &logical.InmemStorage{}
testName := "test"
policy, err := generatePolicy(storage, testName, false) policy, err := generatePolicy(storage, "test", false)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }