Address review feedback
This commit is contained in:
parent
e5a58109ec
commit
369d0bbad0
|
@ -74,7 +74,6 @@ func pathPolicyRead(
|
|||
"deletion_allowed": p.DeletionAllowed,
|
||||
"min_decryption_version": p.MinDecryptionVersion,
|
||||
"latest_version": p.LatestVersion,
|
||||
"archive_version": p.ArchiveVersion,
|
||||
},
|
||||
}
|
||||
if p.Derived {
|
||||
|
|
|
@ -162,43 +162,44 @@ func (p *Policy) handleArchiving(storage logical.Storage) error {
|
|||
return err
|
||||
}
|
||||
|
||||
if keysContainsMinimum {
|
||||
// Need to move keys *to* archive
|
||||
|
||||
// We need a size that is equivalent to the latest version (number of
|
||||
// keys) but adding one since slice numbering starts at 0 and we're
|
||||
// indexing by key version
|
||||
if len(archive.Keys) < p.LatestVersion+1 {
|
||||
// Increase the size of the archive slice
|
||||
newKeys := make([]KeyEntry, p.LatestVersion+1)
|
||||
copy(newKeys, archive.Keys)
|
||||
archive.Keys = newKeys
|
||||
}
|
||||
|
||||
// We are storing all keys in the archive, so we ensure that it is up
|
||||
// to date up to p.LatestVersion
|
||||
for i := p.ArchiveVersion + 1; i <= p.LatestVersion; i++ {
|
||||
archive.Keys[i] = p.Keys[i]
|
||||
p.ArchiveVersion = i
|
||||
}
|
||||
|
||||
err = p.storeArchive(archive, storage)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Perform deletion afterwards so that if there is an error saving we
|
||||
// haven't messed with the current policy
|
||||
for i := p.LatestVersion - len(p.Keys) + 1; i < p.MinDecryptionVersion; i++ {
|
||||
delete(p.Keys, i)
|
||||
}
|
||||
|
||||
} else {
|
||||
if !keysContainsMinimum {
|
||||
// Need to move keys *from* archive
|
||||
|
||||
for i := p.MinDecryptionVersion; i <= p.LatestVersion; i++ {
|
||||
p.Keys[i] = archive.Keys[i]
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Need to move keys *to* archive
|
||||
|
||||
// We need a size that is equivalent to the latest version (number of keys)
|
||||
// but adding one since slice numbering starts at 0 and we're indexing by
|
||||
// key version
|
||||
if len(archive.Keys) < p.LatestVersion+1 {
|
||||
// Increase the size of the archive slice
|
||||
newKeys := make([]KeyEntry, p.LatestVersion+1)
|
||||
copy(newKeys, archive.Keys)
|
||||
archive.Keys = newKeys
|
||||
}
|
||||
|
||||
// We are storing all keys in the archive, so we ensure that it is up to
|
||||
// date up to p.LatestVersion
|
||||
for i := p.ArchiveVersion + 1; i <= p.LatestVersion; i++ {
|
||||
archive.Keys[i] = p.Keys[i]
|
||||
p.ArchiveVersion = i
|
||||
}
|
||||
|
||||
err = p.storeArchive(archive, storage)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Perform deletion afterwards so that if there is an error saving we
|
||||
// haven't messed with the current policy
|
||||
for i := p.LatestVersion - len(p.Keys) + 1; i < p.MinDecryptionVersion; i++ {
|
||||
delete(p.Keys, i)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@ -238,10 +239,7 @@ func (p *Policy) Serialize() ([]byte, error) {
|
|||
// mode is used with the context to derive the proper key.
|
||||
func (p *Policy) DeriveKey(context []byte, ver int) ([]byte, error) {
|
||||
if p.Keys == nil || p.LatestVersion == 0 {
|
||||
if p.Key == nil || len(p.Key) == 0 {
|
||||
return nil, certutil.InternalError{Err: "unable to access the key; no key versions found"}
|
||||
}
|
||||
p.migrateKeyToKeysMap()
|
||||
return nil, certutil.InternalError{Err: "unable to access the key; no key versions found"}
|
||||
}
|
||||
|
||||
if p.LatestVersion == 0 {
|
||||
|
@ -398,7 +396,10 @@ func (p *Policy) Decrypt(context []byte, value string) (string, error) {
|
|||
|
||||
func (p *Policy) rotate(storage logical.Storage) error {
|
||||
if p.Keys == nil {
|
||||
p.migrateKeyToKeysMap()
|
||||
// This is an initial key rotation when generating a new policy. We
|
||||
// don't need to call migrate here because if we've called getPolicy to
|
||||
// get the policy in the first place it will have been run.
|
||||
p.Keys = KeyEntryMap{}
|
||||
}
|
||||
|
||||
// Generate a 256bit key
|
||||
|
@ -426,12 +427,6 @@ func (p *Policy) rotate(storage logical.Storage) error {
|
|||
}
|
||||
|
||||
func (p *Policy) migrateKeyToKeysMap() {
|
||||
if p.Key == nil || len(p.Key) == 0 {
|
||||
p.Key = nil
|
||||
p.Keys = KeyEntryMap{}
|
||||
return
|
||||
}
|
||||
|
||||
p.Keys = KeyEntryMap{
|
||||
1: KeyEntry{
|
||||
Key: p.Key,
|
||||
|
|
|
@ -11,6 +11,33 @@ var (
|
|||
keysArchive = []KeyEntry{KeyEntry{}}
|
||||
)
|
||||
|
||||
func Test_KeyUpgrade(t *testing.T) {
|
||||
storage := &logical.InmemStorage{}
|
||||
policy, err := generatePolicy(storage, "test", false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if policy == nil {
|
||||
t.Fatal("nil policy")
|
||||
}
|
||||
|
||||
testBytes := make([]byte, len(policy.Keys[1].Key))
|
||||
copy(testBytes, policy.Keys[1].Key)
|
||||
|
||||
policy.Key = policy.Keys[1].Key
|
||||
policy.Keys = nil
|
||||
policy.migrateKeyToKeysMap()
|
||||
if policy.Key != nil {
|
||||
t.Fatal("policy.Key is not nil")
|
||||
}
|
||||
if len(policy.Keys) != 1 {
|
||||
t.Fatal("policy.Keys is the wrong size")
|
||||
}
|
||||
if !reflect.DeepEqual(testBytes, policy.Keys[1].Key) {
|
||||
t.Fatal("key mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Archiving(t *testing.T) {
|
||||
// First, we generate a policy and rotate it a number of times. Each time
|
||||
// we'll ensure that we have the expected number of keys in the archive and
|
||||
|
@ -18,9 +45,8 @@ func Test_Archiving(t *testing.T) {
|
|||
// zero and latest, respectively
|
||||
|
||||
storage := &logical.InmemStorage{}
|
||||
testName := "test"
|
||||
|
||||
policy, err := generatePolicy(storage, testName, false)
|
||||
policy, err := generatePolicy(storage, "test", false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue