Add unit tests

This commit is contained in:
Jeff Mitchell 2016-01-26 12:23:42 -05:00
parent 5000711a67
commit 30ffc18c19
3 changed files with 194 additions and 21 deletions

View File

@ -89,11 +89,19 @@ func pathConfigWrite(
} }
} }
// Add this as a guard here before persisting since we now require the min
// decryption version to start at 1; even if it's not explicitly set here,
// force the upgrade
if policy.MinDecryptionVersion == 0 {
policy.MinDecryptionVersion = 1
persistNeeded = true
}
if !persistNeeded { if !persistNeeded {
return nil, nil return nil, nil
} }
return resp, policy.Persist(req.Storage, name) return resp, policy.Persist(req.Storage)
} }
const pathConfigHelpSyn = `Configure a named encryption key` const pathConfigHelpSyn = `Configure a named encryption key`

View File

@ -76,23 +76,24 @@ type Policy struct {
// The latest key version in this policy // The latest key version in this policy
LatestVersion int `json:"latest_version"` LatestVersion int `json:"latest_version"`
// The latest key version in the archive. We never delete these, so this is a max. // The latest key version in the archive. We never delete these, so this is
// a max.
ArchiveVersion int `json:"archive_version"` ArchiveVersion int `json:"archive_version"`
// Whether the key is allowed to be deleted // Whether the key is allowed to be deleted
DeletionAllowed bool `json:"deletion_allowed"` DeletionAllowed bool `json:"deletion_allowed"`
} }
// ArchivedKeys stores old keys. This is used to keep the key loading time sane when // ArchivedKeys stores old keys. This is used to keep the key loading time sane
// there are huge numbers of rotations. // when there are huge numbers of rotations.
type ArchivedKeys struct { type ArchivedKeys struct {
Keys []KeyEntry `json:"keys"` Keys []KeyEntry `json:"keys"`
} }
func (p *Policy) loadArchive(storage logical.Storage, name string) (*ArchivedKeys, error) { func (p *Policy) loadArchive(storage logical.Storage) (*ArchivedKeys, error) {
archive := &ArchivedKeys{} archive := &ArchivedKeys{}
raw, err := storage.Get("policy/" + name + "/archive") raw, err := storage.Get("policy/" + p.Name + "/archive")
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -108,7 +109,7 @@ func (p *Policy) loadArchive(storage logical.Storage, name string) (*ArchivedKey
return archive, nil return archive, nil
} }
func (p *Policy) storeArchive(archive *ArchivedKeys, storage logical.Storage, name string) error { func (p *Policy) storeArchive(archive *ArchivedKeys, storage logical.Storage) error {
// Encode the policy // Encode the policy
buf, err := json.Marshal(archive) buf, err := json.Marshal(archive)
if err != nil { if err != nil {
@ -117,7 +118,7 @@ func (p *Policy) storeArchive(archive *ArchivedKeys, storage logical.Storage, na
// Write the policy into storage // Write the policy into storage
err = storage.Put(&logical.StorageEntry{ err = storage.Put(&logical.StorageEntry{
Key: "policy/" + name + "/archive", Key: "policy/" + p.Name + "/archive",
Value: buf, Value: buf,
}) })
if err != nil { if err != nil {
@ -130,13 +131,24 @@ func (p *Policy) storeArchive(archive *ArchivedKeys, storage logical.Storage, na
// handleArchiving manages the movement of keys to and from the policy archive. // handleArchiving manages the movement of keys to and from the policy archive.
// This should *ONLY* be called from Persist() since it assumes that the policy // This should *ONLY* be called from Persist() since it assumes that the policy
// will be persisted afterwards. // will be persisted afterwards.
func (p *Policy) handleArchiving(storage logical.Storage, name string) error { func (p *Policy) handleArchiving(storage logical.Storage) error {
// We need to move keys that are no longer accessible to ArchivedKeys, and keys // We need to move keys that are no longer accessible to ArchivedKeys, and keys
// that now need to be accessible back here. // that now need to be accessible back here.
// //
// For safety, because there isn't really a good reason to, we never delete // For safety, because there isn't really a good reason to, we never delete
// keys from the archive even when we move them back. // keys from the archive even when we move them back.
// Sanity checks
switch {
case p.MinDecryptionVersion < 1:
return fmt.Errorf("minimum decryption version of %d is less than 1", p.MinDecryptionVersion)
case p.LatestVersion < 1:
return fmt.Errorf("latest version of %d is less than 1", p.LatestVersion)
case p.MinDecryptionVersion > p.LatestVersion:
return fmt.Errorf("minimum decryption version of %d is greater than the latest version %d",
p.MinDecryptionVersion, p.LatestVersion)
}
// Check if we have the latest minimum version in the current set of keys // Check if we have the latest minimum version in the current set of keys
_, keysContainsMinimum := p.Keys[p.MinDecryptionVersion] _, keysContainsMinimum := p.Keys[p.MinDecryptionVersion]
@ -148,7 +160,7 @@ func (p *Policy) handleArchiving(storage logical.Storage, name string) error {
return nil return nil
} }
archive, err := p.loadArchive(storage, name) archive, err := p.loadArchive(storage)
if err != nil { if err != nil {
return err return err
} }
@ -156,11 +168,11 @@ func (p *Policy) handleArchiving(storage logical.Storage, name string) error {
if keysContainsMinimum { if keysContainsMinimum {
// Need to move keys *to* archive // Need to move keys *to* archive
if len(archive.Keys) < p.MinDecryptionVersion-1 { // We need a size that is equivalent to the minimum decryption version
// Increase the size of the archive slice. We need a size that is // minus 1, but adding one since slice numbering starts at 0 and we're
// equivalent to the minimum decryption version minus 1, but adding // indexing by key version
// one since slice numbering starts at 0 and we're indexing by key if len(archive.Keys) < p.MinDecryptionVersion {
// version // Increase the size of the archive slice
newKeys := make([]KeyEntry, p.MinDecryptionVersion) newKeys := make([]KeyEntry, p.MinDecryptionVersion)
copy(newKeys, archive.Keys) copy(newKeys, archive.Keys)
archive.Keys = newKeys archive.Keys = newKeys
@ -177,7 +189,7 @@ func (p *Policy) handleArchiving(storage logical.Storage, name string) error {
archive.Keys[i] = p.Keys[i] archive.Keys[i] = p.Keys[i]
} }
err = p.storeArchive(archive, storage, name) err = p.storeArchive(archive, storage)
if err != nil { if err != nil {
return err return err
} }
@ -231,8 +243,8 @@ func (p *Policy) handleArchiving(storage logical.Storage, name string) error {
return nil return nil
} }
func (p *Policy) Persist(storage logical.Storage, name string) error { func (p *Policy) Persist(storage logical.Storage) error {
err := p.handleArchiving(storage, name) err := p.handleArchiving(storage)
if err != nil { if err != nil {
return err return err
} }
@ -245,7 +257,7 @@ func (p *Policy) Persist(storage logical.Storage, name string) error {
// Write the policy into storage // Write the policy into storage
err = storage.Put(&logical.StorageEntry{ err = storage.Put(&logical.StorageEntry{
Key: "policy/" + name, Key: "policy/" + p.Name,
Value: buf, Value: buf,
}) })
if err != nil { if err != nil {
@ -449,7 +461,7 @@ func (p *Policy) rotate(storage logical.Storage) error {
p.MinDecryptionVersion = 1 p.MinDecryptionVersion = 1
} }
return p.Persist(storage, p.Name) return p.Persist(storage)
} }
func (p *Policy) migrateKeyToKeysMap() { func (p *Policy) migrateKeyToKeysMap() {
@ -516,7 +528,7 @@ func getPolicy(req *logical.Request, name string) (*Policy, error) {
} }
if persistNeeded { if persistNeeded {
err = p.Persist(req.Storage, name) err = p.Persist(req.Storage)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -0,0 +1,153 @@
package transit
import (
"reflect"
"testing"
"github.com/hashicorp/vault/logical"
)
var (
keysArchive = []KeyEntry{KeyEntry{}}
)
func Test_Archiving(t *testing.T) {
// First, we generate a policy and rotate it a number of times. Each time
// we'll ensure that we have the expected number of keys in the archive and
// the main keys object, which without changing the min version should be
// zero and latest, respectively
storage := &logical.InmemStorage{}
testName := "test"
policy, err := generatePolicy(storage, testName, false)
if err != nil {
t.Fatal(err)
}
if policy == nil {
t.Fatal("policy is nil")
}
// Store the initial key in the archive
keysArchive = append(keysArchive, policy.Keys[1])
checkKeys(t, policy, storage, 0, 1, 1)
for i := 2; i <= 10; i++ {
err = policy.rotate(storage)
if err != nil {
t.Fatal(err)
}
keysArchive = append(keysArchive, policy.Keys[i])
checkKeys(t, policy, storage, 0, i, i)
}
// Move the min decryption version up
for i := 1; i <= 10; i++ {
policy.MinDecryptionVersion = i
err = policy.Persist(storage)
if err != nil {
t.Fatal(err)
}
// We expect to find:
// * The keys in archive are the min decryption version - 1
// * The latest version is constant
// * The number of keys in the policy itself is from the min
// decryption version up to the latest version, so for e.g. 7 and
// 10, you'd need 7, 8, 9, and 10 -- IOW, latest version - min
// decryption version plus 1 (the min decryption version key
// itself)
checkKeys(t, policy, storage, i-1, 10, policy.LatestVersion-policy.MinDecryptionVersion+1)
}
// Move the min decryption version down
for i := 10; i >= 1; i-- {
policy.MinDecryptionVersion = i
err = policy.Persist(storage)
if err != nil {
t.Fatal(err)
}
// We expect to find:
// * The keys in archive are never removed so should be the previous
// min decryption version (10) minus 1, always
// * The latest version is constant
// * The number of keys in the policy itself is from the min
// decryption version up to the latest version, so for e.g. 7 and
// 10, you'd need 7, 8, 9, and 10 -- IOW, latest version - min
// decryption version plus 1 (the min decryption version key
// itself)
checkKeys(t, policy, storage, 9, 10, policy.LatestVersion-policy.MinDecryptionVersion+1)
}
}
func checkKeys(t *testing.T,
policy *Policy,
storage logical.Storage,
archiveVer, latestVer, keysSize int) {
// Sanity check
if len(keysArchive) != latestVer+1 {
t.Fatalf("latest expected key version is %d, expected test keys archive size is %d, "+
"but keys archive is of size %d", latestVer, latestVer+1, len(keysArchive))
}
archive, err := policy.loadArchive(storage)
if err != nil {
t.Fatal(err)
}
badArchiveVer := false
if archiveVer == 0 {
if len(archive.Keys) != 0 || policy.ArchiveVersion != 0 {
badArchiveVer = true
}
} else {
// We need to subtract one because we have the indexes match key
// versions, which start at 1. So for an archive version of 1, we
// actually have two entries -- a blank 0 entry, and the key at spot 1
if archiveVer != len(archive.Keys)-1 || archiveVer != policy.ArchiveVersion {
badArchiveVer = true
}
}
if badArchiveVer {
t.Fatalf(
"expected archive version %d, found length of archive keys %d and policy archive version %d",
archiveVer, len(archive.Keys), policy.ArchiveVersion,
)
}
if latestVer != policy.LatestVersion {
t.Fatalf(
"expected latest version %d, found %d",
latestVer, policy.LatestVersion,
)
}
if keysSize != len(policy.Keys) {
t.Fatalf(
"expected keys size %d, found %d",
keysSize, len(policy.Keys),
)
}
for i := policy.MinDecryptionVersion; i <= policy.LatestVersion; i++ {
if _, ok := policy.Keys[i]; !ok {
t.Fatalf(
"expected key %d, did not find it in policy keys", i,
)
}
}
for i := policy.MinDecryptionVersion; i <= policy.LatestVersion; i++ {
if !reflect.DeepEqual(policy.Keys[i], keysArchive[i]) {
t.Fatalf("key %d not equivalent between policy keys and test keys archive", i)
}
}
for i := 1; i < len(archive.Keys); i++ {
if !reflect.DeepEqual(archive.Keys[i].Key, keysArchive[i].Key) {
t.Fatalf("key %d not equivalent between policy archive and test keys archive", i)
}
}
}