Pulled out transit's lock manager and policy structs into a helper

This commit is contained in:
vishalnayak 2016-10-26 19:52:31 -04:00
parent c74303dd59
commit 6d1e1a3ba5
11 changed files with 200 additions and 186 deletions

View File

@ -1,6 +1,7 @@
package transit
import (
"github.com/hashicorp/vault/helper/keysutil"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
@ -39,12 +40,12 @@ func Backend(conf *logical.BackendConfig) *backend {
Secrets: []*framework.Secret{},
}
b.lm = newLockManager(conf.System.CachingDisabled())
b.lm = keysutil.NewLockManager(conf.System.CachingDisabled())
return &b
}
type backend struct {
*framework.Backend
lm *lockManager
lm *keysutil.LockManager
}

View File

@ -12,6 +12,7 @@ import (
"time"
uuid "github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/helper/keysutil"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
logicaltest "github.com/hashicorp/vault/logical/testing"
@ -289,7 +290,7 @@ func testAccStepReadPolicy(t *testing.T, name string, expectNone, derived bool)
if d.Name != name {
return fmt.Errorf("bad name: %#v", d)
}
if d.Type != KeyType(keyType_AES256_GCM96).String() {
if d.Type != keysutil.KeyType(keysutil.KeyType_AES256_GCM96).String() {
return fmt.Errorf("bad key type: %#v", d)
}
// Should NOT get a key back
@ -583,13 +584,13 @@ func testAccStepDecryptDatakey(t *testing.T, name string,
func TestKeyUpgrade(t *testing.T) {
key, _ := uuid.GenerateRandomBytes(32)
p := &policy{
p := &keysutil.Policy{
Name: "test",
Key: key,
Type: keyType_AES256_GCM96,
Type: keysutil.KeyType_AES256_GCM96,
}
p.migrateKeyToKeysMap()
p.MigrateKeyToKeysMap()
if p.Key != nil ||
p.Keys == nil ||
@ -604,18 +605,18 @@ func TestDerivedKeyUpgrade(t *testing.T) {
key, _ := uuid.GenerateRandomBytes(32)
context, _ := uuid.GenerateRandomBytes(32)
p := &policy{
p := &keysutil.Policy{
Name: "test",
Key: key,
Type: keyType_AES256_GCM96,
Type: keysutil.KeyType_AES256_GCM96,
Derived: true,
}
p.migrateKeyToKeysMap()
p.upgrade(storage) // Need to run the upgrade code to make the migration stick
p.MigrateKeyToKeysMap()
p.Upgrade(storage) // Need to run the upgrade code to make the migration stick
if p.KDF != kdf_hmac_sha256_counter {
t.Fatalf("bad KDF value by default; counter val is %d, KDF val is %d, policy is %#v", kdf_hmac_sha256_counter, p.KDF, *p)
if p.KDF != keysutil.Kdf_hmac_sha256_counter {
t.Fatalf("bad KDF value by default; counter val is %d, KDF val is %d, policy is %#v", keysutil.Kdf_hmac_sha256_counter, p.KDF, *p)
}
derBytesOld, err := p.DeriveKey(context, 1)
@ -632,8 +633,8 @@ func TestDerivedKeyUpgrade(t *testing.T) {
t.Fatal("mismatch of same context alg")
}
p.KDF = kdf_hkdf_sha256
if p.needsUpgrade() {
p.KDF = keysutil.Kdf_hkdf_sha256
if p.NeedsUpgrade() {
t.Fatal("expected no upgrade needed")
}
@ -692,15 +693,15 @@ func testConvergentEncryptionCommon(t *testing.T, ver int) {
t.Fatalf("bad: expected error response, got %#v", *resp)
}
p := &policy{
p := &keysutil.Policy{
Name: "testkey",
Type: keyType_AES256_GCM96,
Type: keysutil.KeyType_AES256_GCM96,
Derived: true,
ConvergentEncryption: true,
ConvergentVersion: ver,
}
err = p.rotate(storage)
err = p.Rotate(storage)
if err != nil {
t.Fatal(err)
}
@ -976,7 +977,7 @@ func testPolicyFuzzingCommon(t *testing.T, be *backend) {
resp, err := be.pathDecryptWrite(req, fd)
if err != nil {
// This could well happen since the min version is jumping around
if resp.Data["error"].(string) == ErrTooOld {
if resp.Data["error"].(string) == keysutil.ErrTooOld {
continue
}
t.Fatalf("got an error: %v, resp is %#v, ciphertext was %s, chosenKey is %s, id is %d", err, *resp, ct, chosenKey, id)

View File

@ -6,6 +6,7 @@ import (
"sync"
"github.com/hashicorp/vault/helper/errutil"
"github.com/hashicorp/vault/helper/keysutil"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
@ -116,7 +117,7 @@ func (b *backend) pathEncryptWrite(
}
// Get the policy
var p *policy
var p *keysutil.Policy
var lock *sync.RWMutex
var upserted bool
if req.Operation == logical.CreateOperation {
@ -125,17 +126,17 @@ func (b *backend) pathEncryptWrite(
return logical.ErrorResponse("convergent encryption requires derivation to be enabled, so context is required"), nil
}
polReq := policyRequest{
storage: req.Storage,
name: name,
derived: len(context) != 0,
convergent: convergent,
polReq := keysutil.PolicyRequest{
Storage: req.Storage,
Name: name,
Derived: len(context) != 0,
Convergent: convergent,
}
keyType := d.Get("type").(string)
switch keyType {
case "aes256-gcm96":
polReq.keyType = keyType_AES256_GCM96
polReq.KeyType = keysutil.KeyType_AES256_GCM96
case "ecdsa-p256":
return logical.ErrorResponse(fmt.Sprintf("key type %v not supported for this operation", keyType)), logical.ErrInvalidRequest
default:

View File

@ -124,7 +124,7 @@ func TestTransit_HMAC(t *testing.T) {
req.Data["input"] = "dGhlIHF1aWNrIGJyb3duIGZveA=="
// Rotate
err = p.rotate(storage)
err = p.Rotate(storage)
if err != nil {
t.Fatal(err)
}

View File

@ -5,6 +5,7 @@ import (
"fmt"
"strconv"
"github.com/hashicorp/vault/helper/keysutil"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
@ -95,17 +96,17 @@ func (b *backend) pathPolicyWrite(
return logical.ErrorResponse("convergent encryption requires derivation to be enabled"), nil
}
polReq := policyRequest{
storage: req.Storage,
name: name,
derived: derived,
convergent: convergent,
polReq := keysutil.PolicyRequest{
Storage: req.Storage,
Name: name,
Derived: derived,
Convergent: convergent,
}
switch keyType {
case "aes256-gcm96":
polReq.keyType = keyType_AES256_GCM96
polReq.KeyType = keysutil.KeyType_AES256_GCM96
case "ecdsa-p256":
polReq.keyType = keyType_ECDSA_P256
polReq.KeyType = keysutil.KeyType_ECDSA_P256
default:
return logical.ErrorResponse(fmt.Sprintf("unknown key type %v", keyType)), logical.ErrInvalidRequest
}
@ -158,10 +159,10 @@ func (b *backend) pathPolicyRead(
if p.Derived {
switch p.KDF {
case kdf_hmac_sha256_counter:
case keysutil.Kdf_hmac_sha256_counter:
resp.Data["kdf"] = "hmac-sha256-counter"
resp.Data["kdf_mode"] = "hmac-sha256-counter"
case kdf_hkdf_sha256:
case keysutil.Kdf_hkdf_sha256:
resp.Data["kdf"] = "hkdf_sha256"
}
resp.Data["convergent_encryption"] = p.ConvergentEncryption
@ -171,14 +172,14 @@ func (b *backend) pathPolicyRead(
}
switch p.Type {
case keyType_AES256_GCM96:
case keysutil.KeyType_AES256_GCM96:
retKeys := map[string]int64{}
for k, v := range p.Keys {
retKeys[strconv.Itoa(k)] = v.CreationTime
}
resp.Data["keys"] = retKeys
case keyType_ECDSA_P256:
case keysutil.KeyType_ECDSA_P256:
type ecdsaKey struct {
Name string `json:"name"`
PublicKey string `json:"public_key"`

View File

@ -41,7 +41,7 @@ func (b *backend) pathRotateWrite(
}
// Rotate the policy
err = p.rotate(req.Storage)
err = p.Rotate(req.Storage)
return nil, err
}

View File

@ -177,11 +177,11 @@ func TestTransit_SignVerify(t *testing.T) {
signRequest(req, true, "")
// Rotate and set min decryption version
err = p.rotate(storage)
err = p.Rotate(storage)
if err != nil {
t.Fatal(err)
}
err = p.rotate(storage)
err = p.Rotate(storage)
if err != nil {
t.Fatal(err)
}

View File

@ -4,28 +4,29 @@ import (
"reflect"
"testing"
"github.com/hashicorp/vault/helper/keysutil"
"github.com/hashicorp/vault/logical"
)
var (
keysArchive []keyEntry
keysArchive []keysutil.KeyEntry
)
func resetKeysArchive() {
keysArchive = []keyEntry{keyEntry{}}
keysArchive = []keysutil.KeyEntry{keysutil.KeyEntry{}}
}
func Test_KeyUpgrade(t *testing.T) {
testKeyUpgradeCommon(t, newLockManager(false))
testKeyUpgradeCommon(t, newLockManager(true))
testKeyUpgradeCommon(t, keysutil.NewLockManager(false))
testKeyUpgradeCommon(t, keysutil.NewLockManager(true))
}
func testKeyUpgradeCommon(t *testing.T, lm *lockManager) {
func testKeyUpgradeCommon(t *testing.T, lm *keysutil.LockManager) {
storage := &logical.InmemStorage{}
p, lock, upserted, err := lm.GetPolicyUpsert(policyRequest{
storage: storage,
keyType: keyType_AES256_GCM96,
name: "test",
p, lock, upserted, err := lm.GetPolicyUpsert(keysutil.PolicyRequest{
Storage: storage,
KeyType: keysutil.KeyType_AES256_GCM96,
Name: "test",
})
if lock != nil {
defer lock.RUnlock()
@ -45,7 +46,7 @@ func testKeyUpgradeCommon(t *testing.T, lm *lockManager) {
p.Key = p.Keys[1].AESKey
p.Keys = nil
p.migrateKeyToKeysMap()
p.MigrateKeyToKeysMap()
if p.Key != nil {
t.Fatal("policy.Key is not nil")
}
@ -58,11 +59,11 @@ func testKeyUpgradeCommon(t *testing.T, lm *lockManager) {
}
func Test_ArchivingUpgrade(t *testing.T) {
testArchivingUpgradeCommon(t, newLockManager(false))
testArchivingUpgradeCommon(t, newLockManager(true))
testArchivingUpgradeCommon(t, keysutil.NewLockManager(false))
testArchivingUpgradeCommon(t, keysutil.NewLockManager(true))
}
func testArchivingUpgradeCommon(t *testing.T, lm *lockManager) {
func testArchivingUpgradeCommon(t *testing.T, lm *keysutil.LockManager) {
resetKeysArchive()
// First, we generate a policy and rotate it a number of times. Each time
@ -71,10 +72,10 @@ func testArchivingUpgradeCommon(t *testing.T, lm *lockManager) {
// zero and latest, respectively
storage := &logical.InmemStorage{}
p, lock, _, err := lm.GetPolicyUpsert(policyRequest{
storage: storage,
keyType: keyType_AES256_GCM96,
name: "test",
p, lock, _, err := lm.GetPolicyUpsert(keysutil.PolicyRequest{
Storage: storage,
KeyType: keysutil.KeyType_AES256_GCM96,
Name: "test",
})
if err != nil {
t.Fatal(err)
@ -89,7 +90,7 @@ func testArchivingUpgradeCommon(t *testing.T, lm *lockManager) {
checkKeys(t, p, storage, "initial", 1, 1, 1)
for i := 2; i <= 10; i++ {
err = p.rotate(storage)
err = p.Rotate(storage)
if err != nil {
t.Fatal(err)
}
@ -123,7 +124,7 @@ func testArchivingUpgradeCommon(t *testing.T, lm *lockManager) {
// If we're caching, expire from the cache since we modified it
// under-the-hood
if lm.CacheActive() {
delete(lm.cache, "test")
lm.CacheDelete("test")
}
// Now get the policy again; the upgrade should happen automatically
@ -141,7 +142,7 @@ func testArchivingUpgradeCommon(t *testing.T, lm *lockManager) {
// Let's check some deletion logic while we're at it
// The policy should be in there
if lm.CacheActive() && lm.cache["test"] == nil {
if lm.CacheActive() && lm.Cache("test") == nil {
t.Fatal("nil policy in cache")
}
@ -152,7 +153,7 @@ func testArchivingUpgradeCommon(t *testing.T, lm *lockManager) {
}
// The policy should still be in there
if lm.CacheActive() && lm.cache["test"] == nil {
if lm.CacheActive() && lm.Cache("test") == nil {
t.Fatal("nil policy in cache")
}
@ -177,7 +178,7 @@ func testArchivingUpgradeCommon(t *testing.T, lm *lockManager) {
}
// The policy should *not* be in there
if lm.CacheActive() && lm.cache["test"] != nil {
if lm.CacheActive() && lm.Cache("test") != nil {
t.Fatal("non-nil policy in cache")
}
@ -191,11 +192,11 @@ func testArchivingUpgradeCommon(t *testing.T, lm *lockManager) {
}
func Test_Archiving(t *testing.T) {
testArchivingCommon(t, newLockManager(false))
testArchivingCommon(t, newLockManager(true))
testArchivingCommon(t, keysutil.NewLockManager(false))
testArchivingCommon(t, keysutil.NewLockManager(true))
}
func testArchivingCommon(t *testing.T, lm *lockManager) {
func testArchivingCommon(t *testing.T, lm *keysutil.LockManager) {
resetKeysArchive()
// First, we generate a policy and rotate it a number of times. Each time // we'll ensure that we have the expected number of keys in the archive and
@ -203,10 +204,10 @@ func testArchivingCommon(t *testing.T, lm *lockManager) {
// zero and latest, respectively
storage := &logical.InmemStorage{}
p, lock, _, err := lm.GetPolicyUpsert(policyRequest{
storage: storage,
keyType: keyType_AES256_GCM96,
name: "test",
p, lock, _, err := lm.GetPolicyUpsert(keysutil.PolicyRequest{
Storage: storage,
KeyType: keysutil.KeyType_AES256_GCM96,
Name: "test",
})
if lock != nil {
defer lock.RUnlock()
@ -223,7 +224,7 @@ func testArchivingCommon(t *testing.T, lm *lockManager) {
checkKeys(t, p, storage, "initial", 1, 1, 1)
for i := 2; i <= 10; i++ {
err = p.rotate(storage)
err = p.Rotate(storage)
if err != nil {
t.Fatal(err)
}
@ -271,7 +272,7 @@ func testArchivingCommon(t *testing.T, lm *lockManager) {
}
func checkKeys(t *testing.T,
p *policy,
p *keysutil.Policy,
storage logical.Storage,
action string,
archiveVer, latestVer, keysSize int) {
@ -282,7 +283,7 @@ func checkKeys(t *testing.T,
"but keys archive is of size %d", latestVer, latestVer+1, len(keysArchive))
}
archive, err := p.loadArchive(storage)
archive, err := p.LoadArchive(storage)
if err != nil {
t.Fatal(err)
}

View File

@ -1,4 +1,4 @@
package transit
package keysutil
import (
"errors"
@ -18,29 +18,29 @@ var (
errNeedExclusiveLock = errors.New("an exclusive lock is needed for this operation")
)
// policyRequest holds values used when requesting a policy. Most values are
// PolicyRequest holds values used when requesting a policy. Most values are
// only used during an upsert.
type policyRequest struct {
type PolicyRequest struct {
// The storage to use
storage logical.Storage
Storage logical.Storage
// The name of the policy
name string
Name string
// The key type
keyType KeyType
KeyType KeyType
// Whether it should be derived
derived bool
Derived bool
// Whether to enable convergent encryption
convergent bool
Convergent bool
// Whether to upsert
upsert bool
Upsert bool
}
type lockManager struct {
type LockManager struct {
// A lock for each named key
locks map[string]*sync.RWMutex
@ -48,27 +48,35 @@ type lockManager struct {
locksMutex sync.RWMutex
// If caching is enabled, the map of name to in-memory policy cache
cache map[string]*policy
cache map[string]*Policy
// Used for global locking, and as the cache map mutex
cacheMutex sync.RWMutex
}
func newLockManager(cacheDisabled bool) *lockManager {
lm := &lockManager{
func NewLockManager(cacheDisabled bool) *LockManager {
lm := &LockManager{
locks: map[string]*sync.RWMutex{},
}
if !cacheDisabled {
lm.cache = map[string]*policy{}
lm.cache = map[string]*Policy{}
}
return lm
}
func (lm *lockManager) CacheActive() bool {
func (lm *LockManager) CacheActive() bool {
return lm.cache != nil
}
func (lm *lockManager) policyLock(name string, lockType bool) *sync.RWMutex {
func (lm *LockManager) CacheDelete(name string) {
delete(lm.cache, name)
}
func (lm *LockManager) Cache(name string) *Policy {
return lm.cache[name]
}
func (lm *LockManager) policyLock(name string, lockType bool) *sync.RWMutex {
lm.locksMutex.RLock()
lock := lm.locks[name]
if lock != nil {
@ -115,7 +123,7 @@ func (lm *lockManager) policyLock(name string, lockType bool) *sync.RWMutex {
return lock
}
func (lm *lockManager) UnlockPolicy(lock *sync.RWMutex, lockType bool) {
func (lm *LockManager) UnlockPolicy(lock *sync.RWMutex, lockType bool) {
if lockType == exclusive {
lock.Unlock()
} else {
@ -126,10 +134,10 @@ func (lm *lockManager) UnlockPolicy(lock *sync.RWMutex, lockType bool) {
// Get the policy with a read lock. If we get an error saying an exclusive lock
// is needed (for instance, for an upgrade/migration), give up the read lock,
// call again with an exclusive lock, then swap back out for a read lock.
func (lm *lockManager) GetPolicyShared(storage logical.Storage, name string) (*policy, *sync.RWMutex, error) {
p, lock, _, err := lm.getPolicyCommon(policyRequest{
storage: storage,
name: name,
func (lm *LockManager) GetPolicyShared(storage logical.Storage, name string) (*Policy, *sync.RWMutex, error) {
p, lock, _, err := lm.getPolicyCommon(PolicyRequest{
Storage: storage,
Name: name,
}, shared)
if err == nil ||
(err != nil && err != errNeedExclusiveLock) {
@ -137,9 +145,9 @@ func (lm *lockManager) GetPolicyShared(storage logical.Storage, name string) (*p
}
// Try again while asking for an exlusive lock
p, lock, _, err = lm.getPolicyCommon(policyRequest{
storage: storage,
name: name,
p, lock, _, err = lm.getPolicyCommon(PolicyRequest{
Storage: storage,
Name: name,
}, exclusive)
if err != nil || p == nil || lock == nil {
return p, lock, err
@ -147,18 +155,18 @@ func (lm *lockManager) GetPolicyShared(storage logical.Storage, name string) (*p
lock.Unlock()
p, lock, _, err = lm.getPolicyCommon(policyRequest{
storage: storage,
name: name,
p, lock, _, err = lm.getPolicyCommon(PolicyRequest{
Storage: storage,
Name: name,
}, shared)
return p, lock, err
}
// Get the policy with an exclusive lock
func (lm *lockManager) GetPolicyExclusive(storage logical.Storage, name string) (*policy, *sync.RWMutex, error) {
p, lock, _, err := lm.getPolicyCommon(policyRequest{
storage: storage,
name: name,
func (lm *LockManager) GetPolicyExclusive(storage logical.Storage, name string) (*Policy, *sync.RWMutex, error) {
p, lock, _, err := lm.getPolicyCommon(PolicyRequest{
Storage: storage,
Name: name,
}, exclusive)
return p, lock, err
}
@ -166,8 +174,8 @@ func (lm *lockManager) GetPolicyExclusive(storage logical.Storage, name string)
// Get the policy with a read lock; if it returns that an exclusive lock is
// needed, retry. If successful, call one more time to get a read lock and
// return the value.
func (lm *lockManager) GetPolicyUpsert(req policyRequest) (*policy, *sync.RWMutex, bool, error) {
req.upsert = true
func (lm *LockManager) GetPolicyUpsert(req PolicyRequest) (*Policy, *sync.RWMutex, bool, error) {
req.Upsert = true
p, lock, _, err := lm.getPolicyCommon(req, shared)
if err == nil ||
@ -182,7 +190,7 @@ func (lm *lockManager) GetPolicyUpsert(req policyRequest) (*policy, *sync.RWMute
}
lock.Unlock()
req.upsert = false
req.Upsert = false
// Now get a shared lock for the return, but preserve the value of upserted
p, lock, _, err = lm.getPolicyCommon(req, shared)
@ -191,16 +199,16 @@ func (lm *lockManager) GetPolicyUpsert(req policyRequest) (*policy, *sync.RWMute
// When the function returns, a lock will be held on the policy if err == nil.
// It is the caller's responsibility to unlock.
func (lm *lockManager) getPolicyCommon(req policyRequest, lockType bool) (*policy, *sync.RWMutex, bool, error) {
lock := lm.policyLock(req.name, lockType)
func (lm *LockManager) getPolicyCommon(req PolicyRequest, lockType bool) (*Policy, *sync.RWMutex, bool, error) {
lock := lm.policyLock(req.Name, lockType)
var p *policy
var p *Policy
var err error
// Check if it's in our cache. If so, return right away.
if lm.CacheActive() {
lm.cacheMutex.RLock()
p = lm.cache[req.name]
p = lm.cache[req.Name]
if p != nil {
lm.cacheMutex.RUnlock()
return p, lock, false, nil
@ -209,7 +217,7 @@ func (lm *lockManager) getPolicyCommon(req policyRequest, lockType bool) (*polic
}
// Load it from storage
p, err = lm.getStoredPolicy(req.storage, req.name)
p, err = lm.getStoredPolicy(req.Storage, req.Name)
if err != nil {
lm.UnlockPolicy(lock, lockType)
return nil, nil, false, err
@ -218,7 +226,7 @@ func (lm *lockManager) getPolicyCommon(req policyRequest, lockType bool) (*polic
if p == nil {
// This is the only place we upsert a new policy, so if upsert is not
// specified, or the lock type is wrong, unlock before returning
if !req.upsert {
if !req.Upsert {
lm.UnlockPolicy(lock, lockType)
return nil, nil, false, nil
}
@ -228,33 +236,33 @@ func (lm *lockManager) getPolicyCommon(req policyRequest, lockType bool) (*polic
return nil, nil, false, errNeedExclusiveLock
}
switch req.keyType {
case keyType_AES256_GCM96:
if req.convergent && !req.derived {
switch req.KeyType {
case KeyType_AES256_GCM96:
if req.Convergent && !req.Derived {
return nil, nil, false, fmt.Errorf("convergent encryption requires derivation to be enabled")
}
case keyType_ECDSA_P256:
if req.derived || req.convergent {
return nil, nil, false, fmt.Errorf("key derivation and convergent encryption not supported for keys of type %s", keyType_ECDSA_P256)
case KeyType_ECDSA_P256:
if req.Derived || req.Convergent {
return nil, nil, false, fmt.Errorf("key derivation and convergent encryption not supported for keys of type %s", KeyType_ECDSA_P256)
}
default:
return nil, nil, false, fmt.Errorf("unsupported key type %v", req.keyType)
return nil, nil, false, fmt.Errorf("unsupported key type %v", req.KeyType)
}
p = &policy{
Name: req.name,
Type: req.keyType,
Derived: req.derived,
p = &Policy{
Name: req.Name,
Type: req.KeyType,
Derived: req.Derived,
}
if req.derived {
p.KDF = kdf_hkdf_sha256
p.ConvergentEncryption = req.convergent
if req.Derived {
p.KDF = Kdf_hkdf_sha256
p.ConvergentEncryption = req.Convergent
p.ConvergentVersion = 2
}
err = p.rotate(req.storage)
err = p.Rotate(req.Storage)
if err != nil {
lm.UnlockPolicy(lock, lockType)
return nil, nil, false, err
@ -267,12 +275,12 @@ func (lm *lockManager) getPolicyCommon(req policyRequest, lockType bool) (*polic
defer lm.cacheMutex.Unlock()
// Make sure a policy didn't appear. If so, it will only be set if
// there was no error, so assume it's good and return that
exp := lm.cache[req.name]
exp := lm.cache[req.Name]
if exp != nil {
return exp, lock, false, nil
}
if err == nil {
lm.cache[req.name] = p
lm.cache[req.Name] = p
}
}
@ -280,13 +288,13 @@ func (lm *lockManager) getPolicyCommon(req policyRequest, lockType bool) (*polic
return p, lock, true, nil
}
if p.needsUpgrade() {
if p.NeedsUpgrade() {
if lockType == shared {
lm.UnlockPolicy(lock, lockType)
return nil, nil, false, errNeedExclusiveLock
}
err = p.upgrade(req.storage)
err = p.Upgrade(req.Storage)
if err != nil {
lm.UnlockPolicy(lock, lockType)
return nil, nil, false, err
@ -300,25 +308,25 @@ func (lm *lockManager) getPolicyCommon(req policyRequest, lockType bool) (*polic
defer lm.cacheMutex.Unlock()
// Make sure a policy didn't appear. If so, it will only be set if
// there was no error, so assume it's good and return that
exp := lm.cache[req.name]
exp := lm.cache[req.Name]
if exp != nil {
return exp, lock, false, nil
}
if err == nil {
lm.cache[req.name] = p
lm.cache[req.Name] = p
}
}
return p, lock, false, nil
}
func (lm *lockManager) DeletePolicy(storage logical.Storage, name string) error {
func (lm *LockManager) DeletePolicy(storage logical.Storage, name string) error {
lm.cacheMutex.Lock()
lock := lm.policyLock(name, exclusive)
defer lock.Unlock()
defer lm.cacheMutex.Unlock()
var p *policy
var p *Policy
var err error
if lm.CacheActive() {
@ -355,7 +363,7 @@ func (lm *lockManager) DeletePolicy(storage logical.Storage, name string) error
return nil
}
func (lm *lockManager) getStoredPolicy(storage logical.Storage, name string) (*policy, error) {
func (lm *LockManager) getStoredPolicy(storage logical.Storage, name string) (*Policy, error) {
// Check if the policy already exists
raw, err := storage.Get("policy/" + name)
if err != nil {
@ -366,7 +374,7 @@ func (lm *lockManager) getStoredPolicy(storage logical.Storage, name string) (*p
}
// Decode the policy
policy := &policy{
policy := &Policy{
Keys: keyEntryMap{},
}
err = jsonutil.DecodeJSON(raw.Value, policy)

View File

@ -1,4 +1,4 @@
package transit
package keysutil
import (
"bytes"
@ -33,14 +33,14 @@ import (
// Careful with iota; don't put anything before it in this const block because
// we need the default of zero to be the old-style KDF
const (
kdf_hmac_sha256_counter = iota // built-in helper
kdf_hkdf_sha256 // golang.org/x/crypto/hkdf
Kdf_hmac_sha256_counter = iota // built-in helper
Kdf_hkdf_sha256 // golang.org/x/crypto/hkdf
)
// Or this one...we need the default of zero to be the original AES256-GCM96
const (
keyType_AES256_GCM96 = iota
keyType_ECDSA_P256
KeyType_AES256_GCM96 = iota
KeyType_ECDSA_P256
)
const ErrTooOld = "ciphertext or signature version is disallowed by policy (too old)"
@ -53,7 +53,7 @@ type KeyType int
func (kt KeyType) EncryptionSupported() bool {
switch kt {
case keyType_AES256_GCM96:
case KeyType_AES256_GCM96:
return true
}
return false
@ -61,7 +61,7 @@ func (kt KeyType) EncryptionSupported() bool {
func (kt KeyType) DecryptionSupported() bool {
switch kt {
case keyType_AES256_GCM96:
case KeyType_AES256_GCM96:
return true
}
return false
@ -69,7 +69,7 @@ func (kt KeyType) DecryptionSupported() bool {
func (kt KeyType) SigningSupported() bool {
switch kt {
case keyType_ECDSA_P256:
case KeyType_ECDSA_P256:
return true
}
return false
@ -77,7 +77,7 @@ func (kt KeyType) SigningSupported() bool {
func (kt KeyType) DerivationSupported() bool {
switch kt {
case keyType_AES256_GCM96:
case KeyType_AES256_GCM96:
return true
}
return false
@ -85,17 +85,17 @@ func (kt KeyType) DerivationSupported() bool {
func (kt KeyType) String() string {
switch kt {
case keyType_AES256_GCM96:
case KeyType_AES256_GCM96:
return "aes256-gcm96"
case keyType_ECDSA_P256:
case KeyType_ECDSA_P256:
return "ecdsa-p256"
}
return "[unknown]"
}
// keyEntry stores the key and metadata
type keyEntry struct {
// KeyEntry stores the key and metadata
type KeyEntry struct {
AESKey []byte `json:"key"`
HMACKey []byte `json:"hmac_key"`
CreationTime int64 `json:"creation_time"`
@ -106,11 +106,11 @@ type keyEntry struct {
}
// keyEntryMap is used to allow JSON marshal/unmarshal
type keyEntryMap map[int]keyEntry
type keyEntryMap map[int]KeyEntry
// MarshalJSON implements JSON marshaling
func (kem keyEntryMap) MarshalJSON() ([]byte, error) {
intermediate := map[string]keyEntry{}
intermediate := map[string]KeyEntry{}
for k, v := range kem {
intermediate[strconv.Itoa(k)] = v
}
@ -119,7 +119,7 @@ func (kem keyEntryMap) MarshalJSON() ([]byte, error) {
// MarshalJSON implements JSON unmarshaling
func (kem keyEntryMap) UnmarshalJSON(data []byte) error {
intermediate := map[string]keyEntry{}
intermediate := map[string]KeyEntry{}
if err := jsonutil.DecodeJSON(data, &intermediate); err != nil {
return err
}
@ -135,7 +135,7 @@ func (kem keyEntryMap) UnmarshalJSON(data []byte) error {
}
// Policy is the struct used to store metadata
type policy struct {
type Policy struct {
Name string `json:"name"`
Key []byte `json:"key,omitempty"` //DEPRECATED
Keys keyEntryMap `json:"keys"`
@ -171,10 +171,10 @@ type policy struct {
// ArchivedKeys stores old keys. This is used to keep the key loading time sane
// when there are huge numbers of rotations.
type archivedKeys struct {
Keys []keyEntry `json:"keys"`
Keys []KeyEntry `json:"keys"`
}
func (p *policy) loadArchive(storage logical.Storage) (*archivedKeys, error) {
func (p *Policy) LoadArchive(storage logical.Storage) (*archivedKeys, error) {
archive := &archivedKeys{}
raw, err := storage.Get("archive/" + p.Name)
@ -182,7 +182,7 @@ func (p *policy) loadArchive(storage logical.Storage) (*archivedKeys, error) {
return nil, err
}
if raw == nil {
archive.Keys = make([]keyEntry, 0)
archive.Keys = make([]KeyEntry, 0)
return archive, nil
}
@ -193,7 +193,7 @@ func (p *policy) loadArchive(storage logical.Storage) (*archivedKeys, error) {
return archive, nil
}
func (p *policy) storeArchive(archive *archivedKeys, storage logical.Storage) error {
func (p *Policy) storeArchive(archive *archivedKeys, storage logical.Storage) error {
// Encode the policy
buf, err := json.Marshal(archive)
if err != nil {
@ -215,7 +215,7 @@ func (p *policy) storeArchive(archive *archivedKeys, storage logical.Storage) er
// handleArchiving manages the movement of keys to and from the policy archive.
// This should *ONLY* be called from Persist() since it assumes that the policy
// will be persisted afterwards.
func (p *policy) handleArchiving(storage logical.Storage) error {
func (p *Policy) handleArchiving(storage logical.Storage) error {
// We need to move keys that are no longer accessible to archivedKeys, and keys
// that now need to be accessible back here.
//
@ -241,7 +241,7 @@ func (p *policy) handleArchiving(storage logical.Storage) error {
p.MinDecryptionVersion, p.LatestVersion)
}
archive, err := p.loadArchive(storage)
archive, err := p.LoadArchive(storage)
if err != nil {
return err
}
@ -263,7 +263,7 @@ func (p *policy) handleArchiving(storage logical.Storage) error {
// key version
if len(archive.Keys) < p.LatestVersion+1 {
// Increase the size of the archive slice
newKeys := make([]keyEntry, p.LatestVersion+1)
newKeys := make([]KeyEntry, p.LatestVersion+1)
copy(newKeys, archive.Keys)
archive.Keys = newKeys
}
@ -289,7 +289,7 @@ func (p *policy) handleArchiving(storage logical.Storage) error {
return nil
}
func (p *policy) Persist(storage logical.Storage) error {
func (p *Policy) Persist(storage logical.Storage) error {
err := p.handleArchiving(storage)
if err != nil {
return err
@ -313,11 +313,11 @@ func (p *policy) Persist(storage logical.Storage) error {
return nil
}
func (p *policy) Serialize() ([]byte, error) {
func (p *Policy) Serialize() ([]byte, error) {
return json.Marshal(p)
}
func (p *policy) needsUpgrade() bool {
func (p *Policy) NeedsUpgrade() bool {
// Ensure we've moved from Key -> Keys
if p.Key != nil && len(p.Key) > 0 {
return true
@ -352,11 +352,11 @@ func (p *policy) needsUpgrade() bool {
return false
}
func (p *policy) upgrade(storage logical.Storage) error {
func (p *Policy) Upgrade(storage logical.Storage) error {
persistNeeded := false
// Ensure we've moved from Key -> Keys
if p.Key != nil && len(p.Key) > 0 {
p.migrateKeyToKeysMap()
p.MigrateKeyToKeysMap()
persistNeeded = true
}
@ -409,7 +409,7 @@ func (p *policy) upgrade(storage logical.Storage) error {
// on the policy. If derivation is disabled the raw key is used and no context
// is required, otherwise the KDF mode is used with the context to derive the
// proper key.
func (p *policy) DeriveKey(context []byte, ver int) ([]byte, error) {
func (p *Policy) DeriveKey(context []byte, ver int) ([]byte, error) {
if !p.Type.DerivationSupported() {
return nil, errutil.UserError{Err: fmt.Sprintf("derivation not supported for key type %v", p.Type)}
}
@ -433,11 +433,11 @@ func (p *policy) DeriveKey(context []byte, ver int) ([]byte, error) {
}
switch p.KDF {
case kdf_hmac_sha256_counter:
case Kdf_hmac_sha256_counter:
prf := kdf.HMACSHA256PRF
prfLen := kdf.HMACSHA256PRFLen
return kdf.CounterMode(prf, prfLen, p.Keys[ver].AESKey, context, 256)
case kdf_hkdf_sha256:
case Kdf_hkdf_sha256:
reader := hkdf.New(sha256.New, p.Keys[ver].AESKey, nil, context)
derBytes := bytes.NewBuffer(nil)
derBytes.Grow(32)
@ -458,14 +458,14 @@ func (p *policy) DeriveKey(context []byte, ver int) ([]byte, error) {
}
}
func (p *policy) Encrypt(context, nonce []byte, value string) (string, error) {
func (p *Policy) Encrypt(context, nonce []byte, value string) (string, error) {
if !p.Type.EncryptionSupported() {
return "", errutil.UserError{Err: fmt.Sprintf("message encryption not supported for key type %v", p.Type)}
}
// Guard against a potentially invalid key type
switch p.Type {
case keyType_AES256_GCM96:
case KeyType_AES256_GCM96:
default:
return "", errutil.InternalError{Err: fmt.Sprintf("unsupported key type %v", p.Type)}
}
@ -484,7 +484,7 @@ func (p *policy) Encrypt(context, nonce []byte, value string) (string, error) {
// Guard against a potentially invalid key type
switch p.Type {
case keyType_AES256_GCM96:
case KeyType_AES256_GCM96:
default:
return "", errutil.InternalError{Err: fmt.Sprintf("unsupported key type %v", p.Type)}
}
@ -539,7 +539,7 @@ func (p *policy) Encrypt(context, nonce []byte, value string) (string, error) {
return encoded, nil
}
func (p *policy) Decrypt(context, nonce []byte, value string) (string, error) {
func (p *Policy) Decrypt(context, nonce []byte, value string) (string, error) {
if !p.Type.DecryptionSupported() {
return "", errutil.UserError{Err: fmt.Sprintf("message decryption not supported for key type %v", p.Type)}
}
@ -585,7 +585,7 @@ func (p *policy) Decrypt(context, nonce []byte, value string) (string, error) {
// Guard against a potentially invalid key type
switch p.Type {
case keyType_AES256_GCM96:
case KeyType_AES256_GCM96:
default:
return "", errutil.InternalError{Err: fmt.Sprintf("unsupported key type %v", p.Type)}
}
@ -626,7 +626,7 @@ func (p *policy) Decrypt(context, nonce []byte, value string) (string, error) {
return base64.StdEncoding.EncodeToString(plain), nil
}
func (p *policy) HMACKey(version int) ([]byte, error) {
func (p *Policy) HMACKey(version int) ([]byte, error) {
if version < p.MinDecryptionVersion {
return nil, fmt.Errorf("key version disallowed by policy (minimum is %d)", p.MinDecryptionVersion)
}
@ -642,14 +642,14 @@ func (p *policy) HMACKey(version int) ([]byte, error) {
return p.Keys[version].HMACKey, nil
}
func (p *policy) Sign(hashedInput []byte) (string, error) {
func (p *Policy) Sign(hashedInput []byte) (string, error) {
if !p.Type.SigningSupported() {
return "", fmt.Errorf("message signing not supported for key type %v", p.Type)
}
var sig []byte
switch p.Type {
case keyType_ECDSA_P256:
case KeyType_ECDSA_P256:
keyParams := p.Keys[p.LatestVersion]
key := &ecdsa.PrivateKey{
PublicKey: ecdsa.PublicKey{
@ -685,7 +685,7 @@ func (p *policy) Sign(hashedInput []byte) (string, error) {
return encoded, nil
}
func (p *policy) VerifySignature(hashedInput []byte, sig string) (bool, error) {
func (p *Policy) VerifySignature(hashedInput []byte, sig string) (bool, error) {
if !p.Type.SigningSupported() {
return false, errutil.UserError{Err: fmt.Sprintf("message verification not supported for key type %v", p.Type)}
}
@ -714,7 +714,7 @@ func (p *policy) VerifySignature(hashedInput []byte, sig string) (bool, error) {
}
switch p.Type {
case keyType_ECDSA_P256:
case KeyType_ECDSA_P256:
asn1Sig, err := base64.StdEncoding.DecodeString(splitVerSig[1])
if err != nil {
return false, errutil.UserError{Err: "invalid base64 signature value"}
@ -744,7 +744,7 @@ func (p *policy) VerifySignature(hashedInput []byte, sig string) (bool, error) {
return false, errutil.InternalError{Err: "no valid key type found"}
}
func (p *policy) rotate(storage logical.Storage) error {
func (p *Policy) Rotate(storage logical.Storage) error {
if p.Keys == nil {
// This is an initial key rotation when generating a new policy. We
// don't need to call migrate here because if we've called getPolicy to
@ -753,7 +753,7 @@ func (p *policy) rotate(storage logical.Storage) error {
}
p.LatestVersion += 1
entry := keyEntry{
entry := KeyEntry{
CreationTime: time.Now().Unix(),
}
@ -764,7 +764,7 @@ func (p *policy) rotate(storage logical.Storage) error {
entry.HMACKey = hmacKey
switch p.Type {
case keyType_AES256_GCM96:
case KeyType_AES256_GCM96:
// Generate a 256bit key
newKey, err := uuid.GenerateRandomBytes(32)
if err != nil {
@ -772,7 +772,7 @@ func (p *policy) rotate(storage logical.Storage) error {
}
entry.AESKey = newKey
case keyType_ECDSA_P256:
case KeyType_ECDSA_P256:
privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return err
@ -807,9 +807,9 @@ func (p *policy) rotate(storage logical.Storage) error {
return p.Persist(storage)
}
func (p *policy) migrateKeyToKeysMap() {
func (p *Policy) MigrateKeyToKeysMap() {
p.Keys = keyEntryMap{
1: keyEntry{
1: KeyEntry{
AESKey: p.Key,
CreationTime: time.Now().Unix(),
},

View File

@ -21,6 +21,7 @@ import (
"github.com/hashicorp/vault/api"
credCert "github.com/hashicorp/vault/builtin/credential/cert"
"github.com/hashicorp/vault/builtin/logical/transit"
"github.com/hashicorp/vault/helper/keysutil"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/vault"
)
@ -381,7 +382,7 @@ func testHTTP_Forwarding_Stress_Common(t *testing.T, rpc, parallel bool, num uin
secret, err := doResp(resp)
if err != nil {
// This could well happen since the min version is jumping around
if strings.Contains(err.Error(), transit.ErrTooOld) {
if strings.Contains(err.Error(), keysutil.ErrTooOld) {
mySuccessfulOps++
continue
}