Add unit tests
This commit is contained in:
parent
5000711a67
commit
30ffc18c19
|
@ -89,11 +89,19 @@ func pathConfigWrite(
|
|||
}
|
||||
}
|
||||
|
||||
// Add this as a guard here before persisting since we now require the min
|
||||
// decryption version to start at 1; even if it's not explicitly set here,
|
||||
// force the upgrade
|
||||
if policy.MinDecryptionVersion == 0 {
|
||||
policy.MinDecryptionVersion = 1
|
||||
persistNeeded = true
|
||||
}
|
||||
|
||||
if !persistNeeded {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return resp, policy.Persist(req.Storage, name)
|
||||
return resp, policy.Persist(req.Storage)
|
||||
}
|
||||
|
||||
const pathConfigHelpSyn = `Configure a named encryption key`
|
||||
|
|
|
@ -76,23 +76,24 @@ type Policy struct {
|
|||
// The latest key version in this policy
|
||||
LatestVersion int `json:"latest_version"`
|
||||
|
||||
// The latest key version in the archive. We never delete these, so this is a max.
|
||||
// The latest key version in the archive. We never delete these, so this is
|
||||
// a max.
|
||||
ArchiveVersion int `json:"archive_version"`
|
||||
|
||||
// Whether the key is allowed to be deleted
|
||||
DeletionAllowed bool `json:"deletion_allowed"`
|
||||
}
|
||||
|
||||
// ArchivedKeys stores old keys. This is used to keep the key loading time sane when
|
||||
// there are huge numbers of rotations.
|
||||
// ArchivedKeys stores old keys. This is used to keep the key loading time sane
|
||||
// when there are huge numbers of rotations.
|
||||
type ArchivedKeys struct {
|
||||
Keys []KeyEntry `json:"keys"`
|
||||
}
|
||||
|
||||
func (p *Policy) loadArchive(storage logical.Storage, name string) (*ArchivedKeys, error) {
|
||||
func (p *Policy) loadArchive(storage logical.Storage) (*ArchivedKeys, error) {
|
||||
archive := &ArchivedKeys{}
|
||||
|
||||
raw, err := storage.Get("policy/" + name + "/archive")
|
||||
raw, err := storage.Get("policy/" + p.Name + "/archive")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -108,7 +109,7 @@ func (p *Policy) loadArchive(storage logical.Storage, name string) (*ArchivedKey
|
|||
return archive, nil
|
||||
}
|
||||
|
||||
func (p *Policy) storeArchive(archive *ArchivedKeys, storage logical.Storage, name string) error {
|
||||
func (p *Policy) storeArchive(archive *ArchivedKeys, storage logical.Storage) error {
|
||||
// Encode the policy
|
||||
buf, err := json.Marshal(archive)
|
||||
if err != nil {
|
||||
|
@ -117,7 +118,7 @@ func (p *Policy) storeArchive(archive *ArchivedKeys, storage logical.Storage, na
|
|||
|
||||
// Write the policy into storage
|
||||
err = storage.Put(&logical.StorageEntry{
|
||||
Key: "policy/" + name + "/archive",
|
||||
Key: "policy/" + p.Name + "/archive",
|
||||
Value: buf,
|
||||
})
|
||||
if err != nil {
|
||||
|
@ -130,13 +131,24 @@ func (p *Policy) storeArchive(archive *ArchivedKeys, storage logical.Storage, na
|
|||
// handleArchiving manages the movement of keys to and from the policy archive.
|
||||
// This should *ONLY* be called from Persist() since it assumes that the policy
|
||||
// will be persisted afterwards.
|
||||
func (p *Policy) handleArchiving(storage logical.Storage, name string) error {
|
||||
func (p *Policy) handleArchiving(storage logical.Storage) error {
|
||||
// We need to move keys that are no longer accessible to ArchivedKeys, and keys
|
||||
// that now need to be accessible back here.
|
||||
//
|
||||
// For safety, because there isn't really a good reason to, we never delete
|
||||
// keys from the archive even when we move them back.
|
||||
|
||||
// Sanity checks
|
||||
switch {
|
||||
case p.MinDecryptionVersion < 1:
|
||||
return fmt.Errorf("minimum decryption version of %d is less than 1", p.MinDecryptionVersion)
|
||||
case p.LatestVersion < 1:
|
||||
return fmt.Errorf("latest version of %d is less than 1", p.LatestVersion)
|
||||
case p.MinDecryptionVersion > p.LatestVersion:
|
||||
return fmt.Errorf("minimum decryption version of %d is greater than the latest version %d",
|
||||
p.MinDecryptionVersion, p.LatestVersion)
|
||||
}
|
||||
|
||||
// Check if we have the latest minimum version in the current set of keys
|
||||
_, keysContainsMinimum := p.Keys[p.MinDecryptionVersion]
|
||||
|
||||
|
@ -148,7 +160,7 @@ func (p *Policy) handleArchiving(storage logical.Storage, name string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
archive, err := p.loadArchive(storage, name)
|
||||
archive, err := p.loadArchive(storage)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -156,11 +168,11 @@ func (p *Policy) handleArchiving(storage logical.Storage, name string) error {
|
|||
if keysContainsMinimum {
|
||||
// Need to move keys *to* archive
|
||||
|
||||
if len(archive.Keys) < p.MinDecryptionVersion-1 {
|
||||
// Increase the size of the archive slice. We need a size that is
|
||||
// equivalent to the minimum decryption version minus 1, but adding
|
||||
// one since slice numbering starts at 0 and we're indexing by key
|
||||
// version
|
||||
// We need a size that is equivalent to the minimum decryption version
|
||||
// minus 1, but adding one since slice numbering starts at 0 and we're
|
||||
// indexing by key version
|
||||
if len(archive.Keys) < p.MinDecryptionVersion {
|
||||
// Increase the size of the archive slice
|
||||
newKeys := make([]KeyEntry, p.MinDecryptionVersion)
|
||||
copy(newKeys, archive.Keys)
|
||||
archive.Keys = newKeys
|
||||
|
@ -177,7 +189,7 @@ func (p *Policy) handleArchiving(storage logical.Storage, name string) error {
|
|||
archive.Keys[i] = p.Keys[i]
|
||||
}
|
||||
|
||||
err = p.storeArchive(archive, storage, name)
|
||||
err = p.storeArchive(archive, storage)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -231,8 +243,8 @@ func (p *Policy) handleArchiving(storage logical.Storage, name string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (p *Policy) Persist(storage logical.Storage, name string) error {
|
||||
err := p.handleArchiving(storage, name)
|
||||
func (p *Policy) Persist(storage logical.Storage) error {
|
||||
err := p.handleArchiving(storage)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -245,7 +257,7 @@ func (p *Policy) Persist(storage logical.Storage, name string) error {
|
|||
|
||||
// Write the policy into storage
|
||||
err = storage.Put(&logical.StorageEntry{
|
||||
Key: "policy/" + name,
|
||||
Key: "policy/" + p.Name,
|
||||
Value: buf,
|
||||
})
|
||||
if err != nil {
|
||||
|
@ -449,7 +461,7 @@ func (p *Policy) rotate(storage logical.Storage) error {
|
|||
p.MinDecryptionVersion = 1
|
||||
}
|
||||
|
||||
return p.Persist(storage, p.Name)
|
||||
return p.Persist(storage)
|
||||
}
|
||||
|
||||
func (p *Policy) migrateKeyToKeysMap() {
|
||||
|
@ -516,7 +528,7 @@ func getPolicy(req *logical.Request, name string) (*Policy, error) {
|
|||
}
|
||||
|
||||
if persistNeeded {
|
||||
err = p.Persist(req.Storage, name)
|
||||
err = p.Persist(req.Storage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -0,0 +1,153 @@
|
|||
package transit
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/vault/logical"
|
||||
)
|
||||
|
||||
var (
|
||||
keysArchive = []KeyEntry{KeyEntry{}}
|
||||
)
|
||||
|
||||
func Test_Archiving(t *testing.T) {
|
||||
// First, we generate a policy and rotate it a number of times. Each time
|
||||
// we'll ensure that we have the expected number of keys in the archive and
|
||||
// the main keys object, which without changing the min version should be
|
||||
// zero and latest, respectively
|
||||
|
||||
storage := &logical.InmemStorage{}
|
||||
testName := "test"
|
||||
|
||||
policy, err := generatePolicy(storage, testName, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if policy == nil {
|
||||
t.Fatal("policy is nil")
|
||||
}
|
||||
|
||||
// Store the initial key in the archive
|
||||
keysArchive = append(keysArchive, policy.Keys[1])
|
||||
checkKeys(t, policy, storage, 0, 1, 1)
|
||||
|
||||
for i := 2; i <= 10; i++ {
|
||||
err = policy.rotate(storage)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
keysArchive = append(keysArchive, policy.Keys[i])
|
||||
checkKeys(t, policy, storage, 0, i, i)
|
||||
}
|
||||
|
||||
// Move the min decryption version up
|
||||
for i := 1; i <= 10; i++ {
|
||||
policy.MinDecryptionVersion = i
|
||||
|
||||
err = policy.Persist(storage)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// We expect to find:
|
||||
// * The keys in archive are the min decryption version - 1
|
||||
// * The latest version is constant
|
||||
// * The number of keys in the policy itself is from the min
|
||||
// decryption version up to the latest version, so for e.g. 7 and
|
||||
// 10, you'd need 7, 8, 9, and 10 -- IOW, latest version - min
|
||||
// decryption version plus 1 (the min decryption version key
|
||||
// itself)
|
||||
checkKeys(t, policy, storage, i-1, 10, policy.LatestVersion-policy.MinDecryptionVersion+1)
|
||||
}
|
||||
|
||||
// Move the min decryption version down
|
||||
for i := 10; i >= 1; i-- {
|
||||
policy.MinDecryptionVersion = i
|
||||
|
||||
err = policy.Persist(storage)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// We expect to find:
|
||||
// * The keys in archive are never removed so should be the previous
|
||||
// min decryption version (10) minus 1, always
|
||||
// * The latest version is constant
|
||||
// * The number of keys in the policy itself is from the min
|
||||
// decryption version up to the latest version, so for e.g. 7 and
|
||||
// 10, you'd need 7, 8, 9, and 10 -- IOW, latest version - min
|
||||
// decryption version plus 1 (the min decryption version key
|
||||
// itself)
|
||||
checkKeys(t, policy, storage, 9, 10, policy.LatestVersion-policy.MinDecryptionVersion+1)
|
||||
}
|
||||
}
|
||||
|
||||
func checkKeys(t *testing.T,
|
||||
policy *Policy,
|
||||
storage logical.Storage,
|
||||
archiveVer, latestVer, keysSize int) {
|
||||
|
||||
// Sanity check
|
||||
if len(keysArchive) != latestVer+1 {
|
||||
t.Fatalf("latest expected key version is %d, expected test keys archive size is %d, "+
|
||||
"but keys archive is of size %d", latestVer, latestVer+1, len(keysArchive))
|
||||
}
|
||||
|
||||
archive, err := policy.loadArchive(storage)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
badArchiveVer := false
|
||||
if archiveVer == 0 {
|
||||
if len(archive.Keys) != 0 || policy.ArchiveVersion != 0 {
|
||||
badArchiveVer = true
|
||||
}
|
||||
} else {
|
||||
// We need to subtract one because we have the indexes match key
|
||||
// versions, which start at 1. So for an archive version of 1, we
|
||||
// actually have two entries -- a blank 0 entry, and the key at spot 1
|
||||
if archiveVer != len(archive.Keys)-1 || archiveVer != policy.ArchiveVersion {
|
||||
badArchiveVer = true
|
||||
}
|
||||
}
|
||||
if badArchiveVer {
|
||||
t.Fatalf(
|
||||
"expected archive version %d, found length of archive keys %d and policy archive version %d",
|
||||
archiveVer, len(archive.Keys), policy.ArchiveVersion,
|
||||
)
|
||||
}
|
||||
|
||||
if latestVer != policy.LatestVersion {
|
||||
t.Fatalf(
|
||||
"expected latest version %d, found %d",
|
||||
latestVer, policy.LatestVersion,
|
||||
)
|
||||
}
|
||||
|
||||
if keysSize != len(policy.Keys) {
|
||||
t.Fatalf(
|
||||
"expected keys size %d, found %d",
|
||||
keysSize, len(policy.Keys),
|
||||
)
|
||||
}
|
||||
|
||||
for i := policy.MinDecryptionVersion; i <= policy.LatestVersion; i++ {
|
||||
if _, ok := policy.Keys[i]; !ok {
|
||||
t.Fatalf(
|
||||
"expected key %d, did not find it in policy keys", i,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
for i := policy.MinDecryptionVersion; i <= policy.LatestVersion; i++ {
|
||||
if !reflect.DeepEqual(policy.Keys[i], keysArchive[i]) {
|
||||
t.Fatalf("key %d not equivalent between policy keys and test keys archive", i)
|
||||
}
|
||||
}
|
||||
|
||||
for i := 1; i < len(archive.Keys); i++ {
|
||||
if !reflect.DeepEqual(archive.Keys[i].Key, keysArchive[i].Key) {
|
||||
t.Fatalf("key %d not equivalent between policy archive and test keys archive", i)
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue