Switch to lockManager
parent 08b91b776d
commit 3e5391aa9c
@@ -33,16 +33,12 @@ func Backend(conf *logical.BackendConfig) *backend {
         Secrets: []*framework.Secret{},
     }

-    if conf.System.CachingDisabled() {
-        b.policies = newSimplePolicyCRUD()
-    } else {
-        b.policies = newCachingPolicyCRUD()
-    }
+    b.lm = newLockManager(conf.System.CachingDisabled())

     return &b
 }

 type backend struct {
     *framework.Backend
-    policies policyCRUD
+    lm *lockManager
 }
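The hunk above collapses the two policyCRUD implementations into a single lockManager selected by one flag. The constructor's trick is easy to miss in the diff: disabling the cache simply leaves the cache map nil, and CacheActive() tests for that. A self-contained sketch of the same pattern (illustrative names, not part of the commit):

package main

import (
	"fmt"
	"sync"
)

// lockManagerModel mirrors the commit's constructor: a nil cache map is
// itself the "caching disabled" flag, so no separate boolean is stored.
type lockManagerModel struct {
	locks map[string]*sync.RWMutex
	cache map[string]string // stand-in for map[string]*Policy
}

func newLockManagerModel(cacheDisabled bool) *lockManagerModel {
	lm := &lockManagerModel{locks: map[string]*sync.RWMutex{}}
	if !cacheDisabled {
		lm.cache = map[string]string{}
	}
	return lm
}

func (lm *lockManagerModel) CacheActive() bool { return lm.cache != nil }

func main() {
	fmt.Println(newLockManagerModel(false).CacheActive()) // true
	fmt.Println(newLockManagerModel(true).CacheActive())  // false
}

Using the nil map itself as the mode flag avoids a second field that could drift out of sync with the map's allocation.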
@@ -604,7 +604,7 @@ func testPolicyFuzzingCommon(t *testing.T, be *backend) {

             var chosenFunc, chosenKey string

-            //t.Logf("Starting")
+            //t.Errorf("Starting %d", id)
             for {
                 // Stop after 10 seconds
                 if time.Now().Sub(startTime) > 10*time.Second {
@@ -629,7 +629,7 @@ func testPolicyFuzzingCommon(t *testing.T, be *backend) {
                 switch chosenFunc {
                 // Encrypt our plaintext and store the result
                 case "encrypt":
-                    //t.Logf("%s, %s", chosenFunc, chosenKey)
+                    //t.Errorf("%s, %s, %d", chosenFunc, chosenKey, id)
                     fd.Raw["plaintext"] = base64.StdEncoding.EncodeToString([]byte(testPlaintext))
                     fd.Schema = be.pathEncrypt().Fields
                     resp, err := be.pathEncryptWrite(req, fd)
@@ -649,7 +649,7 @@ func testPolicyFuzzingCommon(t *testing.T, be *backend) {

                 // Decrypt the ciphertext and compare the result
                 case "decrypt":
-                    //t.Logf("%s, %s", chosenFunc, chosenKey)
+                    //t.Errorf("%s, %s, %d", chosenFunc, chosenKey, id)
                     ct := latestEncryptedText[chosenKey]
                     if ct == "" {
                         continue
@@ -677,7 +677,7 @@ func testPolicyFuzzingCommon(t *testing.T, be *backend) {

                 // Change the min version, which also tests the archive functionality
                 case "change_min_version":
-                    //t.Logf("%s, %s", chosenFunc, chosenKey)
+                    //t.Errorf("%s, %s, %d", chosenFunc, chosenKey, id)
                     resp, err := be.pathPolicyRead(req, fd)
                     if err != nil {
                         t.Fatalf("got an error reading policy %s: %v", chosenKey, err)
Deleted file (the cachingPolicyCRUD implementation):

@@ -1,110 +0,0 @@
package transit

import (
    "sync"

    "github.com/hashicorp/vault/logical"
)

// cachingPolicyCRUD implements CRUD operations with a simple locking cache of
// policies in memory
type cachingPolicyCRUD struct {
    sync.RWMutex
    cache map[string]lockingPolicy
}

func newCachingPolicyCRUD() *cachingPolicyCRUD {
    return &cachingPolicyCRUD{
        cache: map[string]lockingPolicy{},
    }
}

// See general comments on the interface method
func (p *cachingPolicyCRUD) getPolicy(storage logical.Storage, name string) (lockingPolicy, error) {
    // We don't defer this since we may need to give it up and get a write lock
    p.RLock()

    // First, see if we're in the cache -- if so, return that
    if p.cache[name] != nil {
        defer p.RUnlock()
        return p.cache[name], nil
    }

    // If we didn't find anything, we'll need to write into the cache, plus possibly
    // persist the entry, so lock the cache
    p.RUnlock()
    p.Lock()
    defer p.Unlock()

    // Check one more time to ensure that another process did not write during
    // our lock switcheroo.
    if p.cache[name] != nil {
        return p.cache[name], nil
    }

    return p.refreshPolicy(storage, name)
}

// See general comments on the interface method
func (p *cachingPolicyCRUD) refreshPolicy(storage logical.Storage, name string) (lockingPolicy, error) {
    // Check once more to ensure it hasn't been added to the cache since the lock was acquired
    if p.cache[name] != nil {
        return p.cache[name], nil
    }

    // Note that we don't need to create the locking entry until the end,
    // because the policy wasn't in the cache so we don't know about it, and we
    // hold the cache lock so nothing else can be writing it in right now
    policy, err := fetchPolicyFromStorage(storage, name)
    if err != nil {
        return nil, err
    }
    if policy == nil {
        return nil, nil
    }

    lp := &mutexLockingPolicy{
        policy: policy,
        mutex:  &sync.RWMutex{},
    }
    p.cache[name] = lp

    return lp, nil
}

// See general comments on the interface method
func (p *cachingPolicyCRUD) generatePolicy(storage logical.Storage, name string, derived bool) (lockingPolicy, error) {
    policy, err := generatePolicyCommon(p, storage, name, derived)
    if err != nil {
        return nil, err
    }

    // Now we need to check again in the cache to ensure the policy wasn't
    // created since we ran generatePolicy and then got the lock. A policy
    // being created holds a write lock until it's done (starting from this
    // point), so it'll be in the cache at this point.
    if lp := p.cache[name]; lp != nil {
        return lp, nil
    }

    lp := &mutexLockingPolicy{
        policy: policy,
        mutex:  &sync.RWMutex{},
    }
    p.cache[name] = lp

    // Return the policy
    return lp, nil
}

// See general comments on the interface method
func (p *cachingPolicyCRUD) deletePolicy(storage logical.Storage, lp lockingPolicy, name string) error {
    err := deletePolicyCommon(p, lp, storage, name)
    if err != nil {
        return err
    }

    delete(p.cache, name)

    return nil
}
builtin/logical/transit/lock_manager.go (new file, 337 lines)
@@ -0,0 +1,337 @@
package transit

import (
    "encoding/json"
    "fmt"
    "sync"

    "github.com/hashicorp/vault/logical"
)

const (
    shared    = false
    exclusive = true
)

type lockManager struct {
    // A lock for each named key
    locks map[string]*sync.RWMutex

    // A mutex for the map itself
    lockMutex sync.RWMutex

    // If caching is enabled, the map of name to in-memory policy cache
    cache map[string]*Policy

    // Used for global locking, and as the cache map mutex
    globalMutex sync.RWMutex
}

func newLockManager(cacheDisabled bool) *lockManager {
    lm := &lockManager{
        locks: map[string]*sync.RWMutex{},
    }
    if !cacheDisabled {
        lm.cache = map[string]*Policy{}
    }
    return lm
}

func (lm *lockManager) CacheActive() bool {
    return lm.cache != nil
}

func (lm *lockManager) LockAll(name string) {
    lm.globalMutex.Lock()
    lm.LockPolicy(name, exclusive)
}

func (lm *lockManager) UnlockAll(name string) {
    lm.UnlockPolicy(name, exclusive)
    lm.globalMutex.Unlock()
}

func (lm *lockManager) LockPolicy(name string, writeLock bool) {
    lm.lockMutex.RLock()
    lock := lm.locks[name]
    if lock != nil {
        // We want to give this up before locking the lock, but it's safe --
        // the only time we ever write to a value in this map is the first time
        // we access the value, so it won't be changing out from under us
        lm.lockMutex.RUnlock()
        if writeLock {
            lock.Lock()
        } else {
            lock.RLock()
        }
        return
    }

    lm.lockMutex.RUnlock()
    lm.lockMutex.Lock()

    // Don't defer the unlock call because if we get a valid lock below we want
    // to release the lock mutex right away to avoid the possibility of
    // deadlock by trying to grab the second lock

    // Check to make sure it hasn't been created since
    lock = lm.locks[name]
    if lock != nil {
        lm.lockMutex.Unlock()
        if writeLock {
            lock.Lock()
        } else {
            lock.RLock()
        }
        return
    }

    lock = &sync.RWMutex{}
    lm.locks[name] = lock
    lm.lockMutex.Unlock()
    if writeLock {
        lock.Lock()
    } else {
        lock.RLock()
    }
}

func (lm *lockManager) UnlockPolicy(name string, writeLock bool) {
    lm.lockMutex.RLock()
    lock := lm.locks[name]
    lm.lockMutex.RUnlock()

    if writeLock {
        lock.Unlock()
    } else {
        lock.RUnlock()
    }
}

func (lm *lockManager) GetPolicy(storage logical.Storage, name string) (*Policy, bool, error) {
    p, lt, _, err := lm.getPolicyCommon(storage, name, false, false)
    return p, lt, err
}

func (lm *lockManager) GetPolicyUpsert(storage logical.Storage, name string, derived bool) (*Policy, bool, bool, error) {
    return lm.getPolicyCommon(storage, name, true, derived)
}

// When the function returns, a lock will be held on the policy if err == nil.
// The type of lock will be indicated by the return value. It is the caller's
// responsibility to unlock.
func (lm *lockManager) getPolicyCommon(storage logical.Storage, name string, upsert, derived bool) (p *Policy, lockType bool, upserted bool, err error) {
    // If we are using a cache, lock it now to avoid having to do really
    // complicated lock juggling as we call various functions. We'll also defer
    // the store into the cache.
    lockType = shared
    lm.LockPolicy(name, shared)

    if lm.CacheActive() {
        lm.globalMutex.RLock()
        p = lm.cache[name]
        if p != nil {
            defer lm.globalMutex.RUnlock()
            return
        }
        lm.globalMutex.RUnlock()

        // When we return, since we didn't have the policy in the cache, if
        // there was no error, write the value in.
        defer func() {
            if err == nil {
                lm.globalMutex.Lock()
                defer lm.globalMutex.Unlock()
                // Make sure a policy didn't appear
                exp := lm.cache[name]
                if exp != nil {
                    p = exp
                    return
                }

                lm.cache[name] = p
            }
        }()
    }

    p, err = lm.getStoredPolicy(storage, name)
    if err != nil {
        defer lm.UnlockPolicy(name, shared)
        return
    }

    if p == nil {
        if !upsert {
            defer lm.UnlockPolicy(name, shared)
            return
        }

        // Get an exclusive lock; on success, check again to ensure that no
        // policy exists. Note that if we are using a cache we will already be
        // serializing this entire code path and it's currently the only one
        // that generates policies, so we don't need to check the cache here;
        // simply checking the disk again is sufficient.
        lm.UnlockPolicy(name, shared)
        lockType = exclusive
        lm.LockPolicy(name, exclusive)

        p, err = lm.getStoredPolicy(storage, name)
        if err != nil {
            defer lm.UnlockPolicy(name, exclusive)
            return
        }
        if p != nil {
            return
        }

        upserted = true

        p = &Policy{
            Name:       name,
            CipherMode: "aes-gcm",
            Derived:    derived,
        }
        if derived {
            p.KDFMode = kdfMode
        }

        err = p.rotate(storage)
        if err != nil {
            defer lm.UnlockPolicy(name, exclusive)
            p = nil
        }

        // We don't need to worry about upgrading since it will be a new policy
        return
    }

    if p.needsUpgrade() {
        lm.UnlockPolicy(name, shared)
        lockType = exclusive
        lm.LockPolicy(name, exclusive)

        // Reload the policy with the write lock to ensure we still need the upgrade
        p, err = lm.getStoredPolicy(storage, name)
        if err != nil {
            defer lm.UnlockPolicy(name, exclusive)
            return
        }
        if p == nil {
            defer lm.UnlockPolicy(name, exclusive)
            err = fmt.Errorf("error reloading policy for upgrade")
            return
        }

        if !p.needsUpgrade() {
            // Already happened, return the newly loaded policy
            return
        }

        err = p.upgrade(storage)
        if err != nil {
            defer lm.UnlockPolicy(name, exclusive)
        }
    }

    return
}

func (lm *lockManager) DeletePolicy(storage logical.Storage, name string) error {
    lm.LockAll(name)
    defer lm.UnlockAll(name)

    var p *Policy
    var err error

    if lm.CacheActive() {
        p = lm.cache[name]
        if p == nil {
            return fmt.Errorf("could not delete policy; not found")
        }
    } else {
        p, err = lm.getStoredPolicy(storage, name)
        if err != nil {
            return err
        }
        if p == nil {
            return fmt.Errorf("could not delete policy; not found")
        }
    }

    if !p.DeletionAllowed {
        return fmt.Errorf("deletion is not allowed for this policy")
    }

    err = storage.Delete("policy/" + name)
    if err != nil {
        return fmt.Errorf("error deleting policy %s: %s", name, err)
    }

    err = storage.Delete("archive/" + name)
    if err != nil {
        return fmt.Errorf("error deleting archive %s: %s", name, err)
    }

    if lm.CacheActive() {
        delete(lm.cache, name)
    }

    return nil
}

// When this function returns, if err is nil an exclusive lock is held on the
// policy; it is the caller's responsibility to call UnlockPolicy.
func (lm *lockManager) RefreshPolicy(storage logical.Storage, name string) (p *Policy, err error) {
    lm.LockPolicy(name, exclusive)

    if lm.CacheActive() {
        p = lm.cache[name]
        if p != nil {
            return
        }
        err = fmt.Errorf("could not refresh policy; not found")
        defer lm.UnlockPolicy(name, exclusive)
        return
    }

    p, err = lm.getStoredPolicy(storage, name)
    if err != nil {
        defer lm.UnlockPolicy(name, exclusive)
        return
    }

    if p == nil {
        err = fmt.Errorf("could not refresh policy; not found")
        defer lm.UnlockPolicy(name, exclusive)
        return
    }

    if p.needsUpgrade() {
        err = p.upgrade(storage)
        if err != nil {
            defer lm.UnlockPolicy(name, exclusive)
        }
    }

    return
}

func (lm *lockManager) getStoredPolicy(storage logical.Storage, name string) (*Policy, error) {
    // Check if the policy already exists
    raw, err := storage.Get("policy/" + name)
    if err != nil {
        return nil, err
    }
    if raw == nil {
        return nil, nil
    }

    // Decode the policy
    policy := &Policy{
        Keys: KeyEntryMap{},
    }
    err = json.Unmarshal(raw.Value, policy)
    if err != nil {
        return nil, err
    }

    return policy, nil
}
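LockPolicy above is a double-checked pattern over the lock table: look up the per-name mutex under the table's read lock, and only if it is missing, retake the table's write lock and check again before creating it, so concurrent callers never install two different mutexes for one name. A minimal standalone model of that pattern (hypothetical types, not the commit's code):

package main

import (
	"fmt"
	"sync"
)

// lockTable hands out one RWMutex per name, creating it lazily.
type lockTable struct {
	mu    sync.RWMutex
	locks map[string]*sync.RWMutex
}

func (t *lockTable) get(name string) *sync.RWMutex {
	// Fast path: read lock only.
	t.mu.RLock()
	lock := t.locks[name]
	t.mu.RUnlock()
	if lock != nil {
		return lock
	}

	// Slow path: write lock, then re-check, since another goroutine may have
	// created the mutex between our RUnlock and Lock.
	t.mu.Lock()
	defer t.mu.Unlock()
	if lock = t.locks[name]; lock != nil {
		return lock
	}
	lock = &sync.RWMutex{}
	t.locks[name] = lock
	return lock
}

func main() {
	t := &lockTable{locks: map[string]*sync.RWMutex{}}
	var wg sync.WaitGroup
	for i := 0; i < 8; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			l := t.get("test")
			l.Lock()
			l.Unlock()
		}()
	}
	wg.Wait()
	fmt.Println("single lock instance:", t.get("test") == t.get("test"))
}

Like the commit's version, the sketch releases the table lock before the caller blocks on the per-name lock, so a slow policy operation never stalls lookups for other names.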
@@ -41,49 +41,38 @@ func (b *backend) pathConfigWrite(
     req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
     name := d.Get("name").(string)

-    // Check if the policy already exists
-    lp, err := b.policies.getPolicy(req.Storage, name)
+    // Check if the policy already exists before we lock everything
+    p, lockType, err := b.lm.GetPolicy(req.Storage, name)
     if err != nil {
         return nil, err
     }
-    if lp == nil {
+    if p == nil {
         return logical.ErrorResponse(
             fmt.Sprintf("no existing key named %s could be found", name)),
             logical.ErrInvalidRequest
     }

-    // Store some values so we can detect if the policy changed after locking
-    lp.RLock()
-    currDeletionAllowed := lp.Policy().DeletionAllowed
-    currMinDecryptionVersion := lp.Policy().MinDecryptionVersion
-    lp.RUnlock()
+    // Store some values so we can detect a change when we lock everything
+    currDeletionAllowed := p.DeletionAllowed
+    currMinDecryptionVersion := p.MinDecryptionVersion

-    // Hold both locks since we want to ensure the policy doesn't change from
-    // underneath us
-    b.policies.Lock()
-    defer b.policies.Unlock()
-    lp.Lock()
-    defer lp.Unlock()
+    b.lm.UnlockPolicy(name, lockType)

     // Refresh in case it's changed since before we grabbed the lock
-    lp, err = b.policies.refreshPolicy(req.Storage, name)
+    p, err = b.lm.RefreshPolicy(req.Storage, name)
     if err != nil {
         return nil, err
     }
-    if lp == nil {
+    if p == nil {
         return nil, fmt.Errorf("error finding key %s after locking for changes", name)
     }
-
-    // Verify if wasn't deleted before we grabbed the lock
-    if lp.Policy() == nil {
-        return nil, fmt.Errorf("no existing key named %s could be found", name)
-    }
+    defer b.lm.UnlockPolicy(name, exclusive)

     resp := &logical.Response{}

     // Check for anything to have been updated since we got the write lock
-    if currDeletionAllowed != lp.Policy().DeletionAllowed ||
-        currMinDecryptionVersion != lp.Policy().MinDecryptionVersion {
+    if currDeletionAllowed != p.DeletionAllowed ||
+        currMinDecryptionVersion != p.MinDecryptionVersion {
         resp.AddWarning("key configuration has changed since this endpoint was called, not updating")
         return resp, nil
     }
@@ -104,12 +93,12 @@ func (b *backend) pathConfigWrite(
         }

         if minDecryptionVersion > 0 &&
-            minDecryptionVersion != lp.Policy().MinDecryptionVersion {
-            if minDecryptionVersion > lp.Policy().LatestVersion {
+            minDecryptionVersion != p.MinDecryptionVersion {
+            if minDecryptionVersion > p.LatestVersion {
                 return logical.ErrorResponse(
-                    fmt.Sprintf("cannot set min decryption version of %d, latest key version is %d", minDecryptionVersion, lp.Policy().LatestVersion)), nil
+                    fmt.Sprintf("cannot set min decryption version of %d, latest key version is %d", minDecryptionVersion, p.LatestVersion)), nil
             }
-            lp.Policy().MinDecryptionVersion = minDecryptionVersion
+            p.MinDecryptionVersion = minDecryptionVersion
             persistNeeded = true
         }
     }
@@ -117,8 +106,8 @@ func (b *backend) pathConfigWrite(
     allowDeletionInt, ok := d.GetOk("deletion_allowed")
     if ok {
         allowDeletion := allowDeletionInt.(bool)
-        if allowDeletion != lp.Policy().DeletionAllowed {
-            lp.Policy().DeletionAllowed = allowDeletion
+        if allowDeletion != p.DeletionAllowed {
+            p.DeletionAllowed = allowDeletion
             persistNeeded = true
         }
     }
@@ -126,8 +115,8 @@ func (b *backend) pathConfigWrite(
     // Add this as a guard here before persisting since we now require the min
     // decryption version to start at 1; even if it's not explicitly set here,
    // force the upgrade
-    if lp.Policy().MinDecryptionVersion == 0 {
-        lp.Policy().MinDecryptionVersion = 1
+    if p.MinDecryptionVersion == 0 {
+        p.MinDecryptionVersion = 1
         persistNeeded = true
     }
@@ -135,7 +124,7 @@ func (b *backend) pathConfigWrite(
         return nil, nil
     }

-    return resp, lp.Policy().Persist(req.Storage)
+    return resp, p.Persist(req.Storage)
 }

 const pathConfigHelpSyn = `Configure a named encryption key`
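The rewritten pathConfigWrite is optimistic: it snapshots DeletionAllowed and MinDecryptionVersion under the shared lock, releases it, reacquires exclusively through RefreshPolicy, and then compares the snapshot to detect a concurrent change. A standalone sketch of that re-check (the keyConfig type is illustrative, not Vault's):

package main

import (
	"fmt"
	"sync"
)

type keyConfig struct {
	mu              sync.RWMutex
	deletionAllowed bool
}

// setDeletionAllowed reports true when the value changed between the shared
// read and the exclusive write, which is the signal the handler turns into a
// "configuration has changed" warning.
func (c *keyConfig) setDeletionAllowed(v bool) (changedUnderneath bool) {
	c.mu.RLock()
	snapshot := c.deletionAllowed
	c.mu.RUnlock()

	// Between RUnlock and Lock another writer may slip in; the snapshot
	// comparison below detects exactly that window.
	c.mu.Lock()
	defer c.mu.Unlock()
	if c.deletionAllowed != snapshot {
		return true
	}
	c.deletionAllowed = v
	return false
}

func main() {
	c := &keyConfig{}
	fmt.Println(c.setDeletionAllowed(true)) // false: no concurrent change
}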
@@ -73,23 +73,14 @@ func (b *backend) pathDatakeyWrite(
     }

-    // Get the policy
-    lp, err := b.policies.getPolicy(req.Storage, name)
+    p, lockType, err := b.lm.GetPolicy(req.Storage, name)
     if err != nil {
         return nil, err
     }

     // Error if invalid policy
-    if lp == nil {
+    if p == nil {
         return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest
     }

-    lp.RLock()
-    defer lp.RUnlock()
-
-    // Verify if wasn't deleted before we grabbed the lock
-    if lp.Policy() == nil {
-        return nil, fmt.Errorf("no existing policy named %s could be found", name)
-    }
+    defer b.lm.UnlockPolicy(name, lockType)

     newKey := make([]byte, 32)
     bits := d.Get("bits").(int)
@@ -107,7 +98,7 @@ func (b *backend) pathDatakeyWrite(
         return nil, err
     }

-    ciphertext, err := lp.Policy().Encrypt(context, base64.StdEncoding.EncodeToString(newKey))
+    ciphertext, err := p.Encrypt(context, base64.StdEncoding.EncodeToString(newKey))
     if err != nil {
         switch err.(type) {
         case certutil.UserError:
@@ -58,25 +58,17 @@ func (b *backend) pathDecryptWrite(
     }

-    // Get the policy
-    lp, err := b.policies.getPolicy(req.Storage, name)
+    p, lockType, err := b.lm.GetPolicy(req.Storage, name)
     if err != nil {
         return nil, err
     }

     // Error if invalid policy
-    if lp == nil {
+    if p == nil {
         return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest
     }

-    lp.RLock()
-    defer lp.RUnlock()
+    defer b.lm.UnlockPolicy(name, lockType)

-    // Verify if wasn't deleted before we grabbed the lock
-    if lp.Policy() == nil {
-        return nil, fmt.Errorf("no existing policy named %s could be found", name)
-    }
-
-    plaintext, err := lp.Policy().Decrypt(context, ciphertext)
+    plaintext, err := p.Decrypt(context, ciphertext)
     if err != nil {
         switch err.(type) {
         case certutil.UserError:
@@ -44,12 +44,14 @@ func (b *backend) pathEncrypt() *framework.Path {
 func (b *backend) pathEncryptExistenceCheck(
     req *logical.Request, d *framework.FieldData) (bool, error) {
     name := d.Get("name").(string)
-    lp, err := b.policies.getPolicy(req.Storage, name)
+    p, lockType, err := b.lm.GetPolicy(req.Storage, name)
     if err != nil {
         return false, err
     }

-    return lp != nil, nil
+    if p != nil {
+        defer b.lm.UnlockPolicy(name, lockType)
+    }
+    return p != nil, nil
 }

 func (b *backend) pathEncryptWrite(
@@ -63,8 +65,8 @@ func (b *backend) pathEncryptWrite(
     // Decode the context if any
     contextRaw := d.Get("context").(string)
     var context []byte
+    var err error
     if len(contextRaw) != 0 {
-        var err error
         context, err = base64.StdEncoding.DecodeString(contextRaw)
         if err != nil {
             return logical.ErrorResponse("failed to decode context as base64"), logical.ErrInvalidRequest
@@ -72,43 +74,23 @@ func (b *backend) pathEncryptWrite(
     }

-    // Get the policy
-    lp, err := b.policies.getPolicy(req.Storage, name)
+    var p *Policy
+    var lockType bool
+    var upserted bool
+    if req.Operation == logical.CreateOperation {
+        p, lockType, upserted, err = b.lm.GetPolicyUpsert(req.Storage, name, len(context) != 0)
+    } else {
+        p, lockType, err = b.lm.GetPolicy(req.Storage, name)
+    }
     if err != nil {
         return nil, err
     }

-    // Error or upsert if invalid policy
-    if lp == nil {
-        if req.Operation != logical.CreateOperation {
-            return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest
-        }
-
-        // Get a write lock
-        b.policies.Lock()
-
-        isDerived := len(context) != 0
-
-        // This also checks to make sure one hasn't been created since we grabbed the write lock
-        lp, err = b.policies.generatePolicy(req.Storage, name, isDerived)
-        // If the error is that the policy has been created in the interim we
-        // will get the policy back, so only consider it an error if err is not
-        // nil and we do not get a policy back
-        if err != nil && lp != nil {
-            b.policies.Unlock()
-            return nil, err
-        }
-        b.policies.Unlock()
+    if p == nil {
+        return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest
     }
+    defer b.lm.UnlockPolicy(name, lockType)

-    lp.RLock()
-    defer lp.RUnlock()
-
-    // Verify if wasn't deleted before we grabbed the lock
-    if lp.Policy() == nil {
-        return nil, fmt.Errorf("no existing policy named %s could be found", name)
-    }
-
-    ciphertext, err := lp.Policy().Encrypt(context, value)
+    ciphertext, err := p.Encrypt(context, value)
     if err != nil {
         switch err.(type) {
         case certutil.UserError:
@@ -130,6 +112,11 @@ func (b *backend) pathEncryptWrite(
             "ciphertext": ciphertext,
         },
     }

+    if req.Operation == logical.CreateOperation && !upserted {
+        resp.AddWarning("Attempted creation of the key during the encrypt operation, but it was created beforehand")
+    }
+
     return resp, nil
 }
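Encrypt's create path now turns on GetPolicyUpsert's upserted result: the handler only warns when a CreateOperation found the key already present instead of creating it. A standalone fetch-or-create sketch that reports whether this call performed the creation (illustrative types, not the commit's):

package main

import (
	"fmt"
	"sync"
)

type store struct {
	mu   sync.Mutex
	keys map[string]string
}

// getOrCreate returns the value for name, creating it if absent, and reports
// via upserted whether this call did the creating.
func (s *store) getOrCreate(name string) (value string, upserted bool) {
	s.mu.Lock()
	defer s.mu.Unlock()
	if v, ok := s.keys[name]; ok {
		return v, false
	}
	s.keys[name] = "key-material-for-" + name
	return s.keys[name], true
}

func main() {
	s := &store{keys: map[string]string{}}
	_, up := s.getOrCreate("foo")
	fmt.Println(up) // true: first caller created it
	_, up = s.getOrCreate("foo")
	fmt.Println(up) // false: would trigger the "created beforehand" warning
}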
@@ -36,55 +36,58 @@ func (b *backend) pathKeys() *framework.Path {

 func (b *backend) pathPolicyWrite(
     req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
-    // Grab a write lock right off the bat
-    b.policies.Lock()
-    defer b.policies.Unlock()
-
     name := d.Get("name").(string)
     derived := d.Get("derived").(bool)

-    // Generate the policy; this will also check if it exists for safety
-    _, err := b.policies.generatePolicy(req.Storage, name, derived)
-    return nil, err
+    p, lockType, upserted, err := b.lm.GetPolicyUpsert(req.Storage, name, derived)
+    if err != nil {
+        return nil, err
+    }
+    if p == nil {
+        return nil, fmt.Errorf("error generating key: returned policy was nil")
+    }
+
+    defer b.lm.UnlockPolicy(name, lockType)
+
+    resp := &logical.Response{}
+    if !upserted {
+        resp.AddWarning(fmt.Sprintf("key %s already existed", name))
+    }
+
+    return nil, nil
 }

 func (b *backend) pathPolicyRead(
     req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
     name := d.Get("name").(string)

-    lp, err := b.policies.getPolicy(req.Storage, name)
+    p, lockType, err := b.lm.GetPolicy(req.Storage, name)
     if err != nil {
         return nil, err
     }
-    if lp == nil {
+    if p == nil {
         return nil, nil
     }

-    lp.RLock()
-    defer lp.RUnlock()
-
-    // Verify if wasn't deleted before we grabbed the lock
-    if lp.Policy() == nil {
-        return nil, fmt.Errorf("no existing policy named %s could be found", name)
-    }
+    defer b.lm.UnlockPolicy(name, lockType)

     // Return the response
     resp := &logical.Response{
         Data: map[string]interface{}{
-            "name":                   lp.Policy().Name,
-            "cipher_mode":            lp.Policy().CipherMode,
-            "derived":                lp.Policy().Derived,
-            "deletion_allowed":       lp.Policy().DeletionAllowed,
-            "min_decryption_version": lp.Policy().MinDecryptionVersion,
-            "latest_version":         lp.Policy().LatestVersion,
+            "name":                   p.Name,
+            "cipher_mode":            p.CipherMode,
+            "derived":                p.Derived,
+            "deletion_allowed":       p.DeletionAllowed,
+            "min_decryption_version": p.MinDecryptionVersion,
+            "latest_version":         p.LatestVersion,
         },
     }
-    if lp.Policy().Derived {
-        resp.Data["kdf_mode"] = lp.Policy().KDFMode
+    if p.Derived {
+        resp.Data["kdf_mode"] = p.KDFMode
     }

     retKeys := map[string]int64{}
-    for k, v := range lp.Policy().Keys {
+    for k, v := range p.Keys {
         retKeys[strconv.Itoa(k)] = v.CreationTime
     }
     resp.Data["keys"] = retKeys
@@ -96,33 +99,17 @@ func (b *backend) pathPolicyDelete(
     req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
     name := d.Get("name").(string)

-    // Some sanity checking
-    lp, err := b.policies.getPolicy(req.Storage, name)
+    // Some sanity checking before we lock it all in the DeletePolicy method
+    p, lockType, err := b.lm.GetPolicy(req.Storage, name)
     if err != nil {
         return logical.ErrorResponse(fmt.Sprintf("error looking up policy %s, error is %s", name, err)), err
     }
-    if lp == nil {
+    if p == nil {
         return logical.ErrorResponse(fmt.Sprintf("no such key %s", name)), logical.ErrInvalidRequest
     }
+    b.lm.UnlockPolicy(name, lockType)

-    // Hold both locks since we'll be affecting both the cache (if it exists)
-    // and the locking policy itself
-    b.policies.Lock()
-    defer b.policies.Unlock()
-    lp.Lock()
-    defer lp.Unlock()
-
-    // Make sure that we have up-to-date values since deletePolicy will check
-    // things like whether deletion is allowed
-    lp, err = b.policies.refreshPolicy(req.Storage, name)
-    if err != nil {
-        return nil, err
-    }
-    if lp == nil {
-        return nil, fmt.Errorf("error finding key %s after locking for deletion", name)
-    }
-
-    err = b.policies.deletePolicy(req.Storage, lp, name)
+    err = b.lm.DeletePolicy(req.Storage, name)
     if err != nil {
         return logical.ErrorResponse(fmt.Sprintf("error deleting policy %s: %s", name, err)), err
     }
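DeletePolicy serializes against every other code path by taking the global mutex first and the per-name exclusive lock second (LockAll), releasing them in reverse order (UnlockAll). A standalone sketch of that fixed acquisition order (names are illustrative):

package main

import (
	"fmt"
	"sync"
)

// Always global before per-name, released in reverse. Keeping one acquisition
// order across all code paths is what rules out deadlock between a deleter
// and readers that take only the per-name lock.
var (
	global sync.RWMutex
	named  sync.RWMutex // stand-in for the per-name lock
)

func lockAll()   { global.Lock(); named.Lock() }
func unlockAll() { named.Unlock(); global.Unlock() }

func main() {
	lockAll()
	fmt.Println("both locks held; safe to delete from storage and cache")
	unlockAll()
}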
@@ -59,25 +59,18 @@ func (b *backend) pathRewrapWrite(
     }

-    // Get the policy
-    lp, err := b.policies.getPolicy(req.Storage, name)
+    p, lockType, err := b.lm.GetPolicy(req.Storage, name)
     if err != nil {
         return nil, err
     }

     // Error if invalid policy
-    if lp == nil {
+    if p == nil {
         return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest
     }

-    lp.RLock()
-    defer lp.RUnlock()
+    defer b.lm.UnlockPolicy(name, lockType)

-    // Verify if wasn't deleted before we grabbed the lock
-    if lp.Policy() == nil {
-        return nil, fmt.Errorf("no existing policy named %s could be found", name)
-    }
-
-    plaintext, err := lp.Policy().Decrypt(context, value)
+    plaintext, err := p.Decrypt(context, value)
     if err != nil {
         switch err.(type) {
         case certutil.UserError:
@@ -93,7 +86,7 @@ func (b *backend) pathRewrapWrite(
         return nil, fmt.Errorf("empty plaintext returned during rewrap")
     }

-    ciphertext, err := lp.Policy().Encrypt(context, plaintext)
+    ciphertext, err := p.Encrypt(context, plaintext)
     if err != nil {
         switch err.(type) {
         case certutil.UserError:
@@ -31,48 +31,38 @@ func (b *backend) pathRotateWrite(
     name := d.Get("name").(string)

-    // Get the policy
-    lp, err := b.policies.getPolicy(req.Storage, name)
+    p, lockType, err := b.lm.GetPolicy(req.Storage, name)
     if err != nil {
         return nil, err
     }

     // Error if invalid policy
-    if lp == nil {
+    if p == nil {
         return logical.ErrorResponse("key not found"), logical.ErrInvalidRequest
     }

     // Store so we can detect later if this has changed out from under us
-    keyVersion := lp.Policy().LatestVersion
+    keyVersion := p.LatestVersion

-    // lock the policies object so we can refresh
-    b.policies.Lock()
-    defer b.policies.Unlock()
-    lp.Lock()
-    defer lp.Unlock()
+    b.lm.UnlockPolicy(name, lockType)

     // Refresh in case it's changed since before we grabbed the lock
-    lp, err = b.policies.refreshPolicy(req.Storage, name)
+    p, err = b.lm.RefreshPolicy(req.Storage, name)
     if err != nil {
         return nil, err
     }
-    if lp == nil {
+    if p == nil {
         return nil, fmt.Errorf("error finding key %s after locking for changes", name)
     }

-    // Verify if wasn't deleted before we grabbed the lock
-    if lp.Policy() == nil {
-        return nil, fmt.Errorf("no existing key named %s could be found", name)
-    }
+    defer b.lm.UnlockPolicy(name, exclusive)

     // Make sure that the policy hasn't been rotated simultaneously
-    if keyVersion != lp.Policy().LatestVersion {
+    if keyVersion != p.LatestVersion {
         resp := &logical.Response{}
         resp.AddWarning("key has been rotated since this endpoint was called; did not perform rotation")
         return resp, nil
     }

     // Rotate the policy
-    err = lp.Policy().rotate(req.Storage)
+    err = p.rotate(req.Storage)

     return nil, err
 }
@@ -235,6 +235,67 @@ func (p *Policy) Serialize() ([]byte, error) {
     return json.Marshal(p)
 }

+func (p *Policy) needsUpgrade() bool {
+    // Ensure we've moved from Key -> Keys
+    if p.Key != nil && len(p.Key) > 0 {
+        return true
+    }
+
+    // With archiving, past assumptions about the length of the keys map are no longer valid
+    if p.LatestVersion == 0 && len(p.Keys) != 0 {
+        return true
+    }
+
+    // We disallow setting the version to 0, since they start at 1 since moving
+    // to rotate-able keys, so update if it's set to 0
+    if p.MinDecryptionVersion == 0 {
+        return true
+    }
+
+    // On first load after an upgrade, copy keys to the archive
+    if p.ArchiveVersion == 0 {
+        return true
+    }
+
+    return false
+}
+
+func (p *Policy) upgrade(storage logical.Storage) error {
+    persistNeeded := false
+    // Ensure we've moved from Key -> Keys
+    if p.Key != nil && len(p.Key) > 0 {
+        p.migrateKeyToKeysMap()
+        persistNeeded = true
+    }
+
+    // With archiving, past assumptions about the length of the keys map are no longer valid
+    if p.LatestVersion == 0 && len(p.Keys) != 0 {
+        p.LatestVersion = len(p.Keys)
+        persistNeeded = true
+    }
+
+    // We disallow setting the version to 0, since they start at 1 since moving
+    // to rotate-able keys, so update if it's set to 0
+    if p.MinDecryptionVersion == 0 {
+        p.MinDecryptionVersion = 1
+        persistNeeded = true
+    }
+
+    // On first load after an upgrade, copy keys to the archive
+    if p.ArchiveVersion == 0 {
+        persistNeeded = true
+    }
+
+    if persistNeeded {
+        err := p.Persist(storage)
+        if err != nil {
+            return err
+        }
+    }
+
+    return nil
+}
+
 // DeriveKey is used to derive the encryption key that should
 // be used depending on the policy. If derivation is disabled the
 // raw key is used and no context is required, otherwise the KDF
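needsUpgrade is deliberately a read-only predicate so it can run under the shared lock; all mutation lives in upgrade, which getPolicyCommon calls only after trading up to the exclusive lock and re-checking. A standalone sketch of that predicate/mutator split (fields are stand-ins, not the real Policy):

package main

import "fmt"

type policy struct {
	LatestVersion        int
	MinDecryptionVersion int
	keyCount             int
}

// needsUpgrade is pure: safe to evaluate while holding only a read lock.
func (p *policy) needsUpgrade() bool {
	return (p.LatestVersion == 0 && p.keyCount != 0) || p.MinDecryptionVersion == 0
}

// upgrade mutates, so callers run it only under an exclusive lock, after
// re-checking needsUpgrade in case another writer got there first.
func (p *policy) upgrade() {
	if p.LatestVersion == 0 && p.keyCount != 0 {
		p.LatestVersion = p.keyCount
	}
	if p.MinDecryptionVersion == 0 {
		p.MinDecryptionVersion = 1
	}
}

func main() {
	p := &policy{keyCount: 3}
	fmt.Println(p.needsUpgrade()) // true: pre-archive data
	p.upgrade()
	fmt.Println(p.needsUpgrade(), p.LatestVersion, p.MinDecryptionVersion) // false 3 1
}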
Deleted file (the lockingPolicy and policyCRUD interfaces and helpers):

@@ -1,188 +0,0 @@
package transit

import (
    "encoding/json"
    "fmt"
    "sync"

    "github.com/hashicorp/vault/logical"
)

type lockingPolicy interface {
    Lock()
    RLock()
    Unlock()
    RUnlock()
    Policy() *Policy
    SetPolicy(*Policy)
}

type policyCRUD interface {
    // getPolicy returns a lockingPolicy. It performs its own locking according
    // to implementation.
    getPolicy(storage logical.Storage, name string) (lockingPolicy, error)

    // refreshPolicy returns a lockingPolicy. It does not perform its own
    // locking; a write lock must be held before calling.
    refreshPolicy(storage logical.Storage, name string) (lockingPolicy, error)

    // generatePolicy generates and returns a lockingPolicy. A write lock must
    // be held before calling.
    generatePolicy(storage logical.Storage, name string, derived bool) (lockingPolicy, error)

    // deletePolicy deletes a lockingPolicy. A write lock must be held on both
    // the CRUD implementation and the lockingPolicy before calling.
    deletePolicy(storage logical.Storage, lp lockingPolicy, name string) error

    // These are generally satisfied by embedded mutexes in the implementing struct
    Lock()
    RLock()
    Unlock()
    RUnlock()
}

// The mutex is kept separate from the struct since we may set it to its own
// mutex (if the object is shared) or a shared mutex (if the object isn't
// shared and only the locking is)
type mutexLockingPolicy struct {
    mutex  *sync.RWMutex
    policy *Policy
}

func (m *mutexLockingPolicy) Lock() {
    m.mutex.Lock()
}

func (m *mutexLockingPolicy) RLock() {
    m.mutex.RLock()
}

func (m *mutexLockingPolicy) Unlock() {
    m.mutex.Unlock()
}

func (m *mutexLockingPolicy) RUnlock() {
    m.mutex.RUnlock()
}

func (m *mutexLockingPolicy) Policy() *Policy {
    return m.policy
}

func (m *mutexLockingPolicy) SetPolicy(p *Policy) {
    m.policy = p
}

// fetchPolicyFromStorage fetches the policy from backend storage. The caller
// should hold the write lock when calling this, to handle upgrades.
func fetchPolicyFromStorage(storage logical.Storage, name string) (*Policy, error) {
    // Check if the policy already exists
    raw, err := storage.Get("policy/" + name)
    if err != nil {
        return nil, err
    }
    if raw == nil {
        return nil, nil
    }

    // Decode the policy
    policy := &Policy{
        Keys: KeyEntryMap{},
    }
    err = json.Unmarshal(raw.Value, policy)
    if err != nil {
        return nil, err
    }

    persistNeeded := false
    // Ensure we've moved from Key -> Keys
    if policy.Key != nil && len(policy.Key) > 0 {
        policy.migrateKeyToKeysMap()
        persistNeeded = true
    }

    // With archiving, past assumptions about the length of the keys map are no longer valid
    if policy.LatestVersion == 0 && len(policy.Keys) != 0 {
        policy.LatestVersion = len(policy.Keys)
        persistNeeded = true
    }

    // We disallow setting the version to 0, since they start at 1 since moving
    // to rotate-able keys, so update if it's set to 0
    if policy.MinDecryptionVersion == 0 {
        policy.MinDecryptionVersion = 1
        persistNeeded = true
    }

    // On first load after an upgrade, copy keys to the archive
    if policy.ArchiveVersion == 0 {
        persistNeeded = true
    }

    if persistNeeded {
        err = policy.Persist(storage)
        if err != nil {
            return nil, err
        }
    }

    return policy, nil
}

// generatePolicyCommon is used to create a new named policy with a randomly
// generated key. The caller should have a write lock prior to calling this.
func generatePolicyCommon(p policyCRUD, storage logical.Storage, name string, derived bool) (*Policy, error) {
    // Make sure this doesn't exist in case it was created before we got the write lock
    policy, err := fetchPolicyFromStorage(storage, name)
    if err != nil {
        return nil, err
    }
    if policy != nil {
        return policy, nil
    }

    // Create the policy object
    policy = &Policy{
        Name:       name,
        CipherMode: "aes-gcm",
        Derived:    derived,
    }
    if derived {
        policy.KDFMode = kdfMode
    }

    err = policy.rotate(storage)
    if err != nil {
        return nil, err
    }

    return policy, err
}

// deletePolicyCommon deletes a policy. The caller should hold the write lock
// for both the policy and lockingPolicy prior to calling this.
func deletePolicyCommon(p policyCRUD, lp lockingPolicy, storage logical.Storage, name string) error {
    if lp.Policy() == nil {
        // This got deleted before we grabbed the lock
        return fmt.Errorf("policy already deleted")
    }

    // Verify this hasn't changed
    if !lp.Policy().DeletionAllowed {
        return fmt.Errorf("deletion not allowed for policy %s", name)
    }

    err := storage.Delete("policy/" + name)
    if err != nil {
        return fmt.Errorf("error deleting policy %s: %s", name, err)
    }

    err = storage.Delete("archive/" + name)
    if err != nil {
        return fmt.Errorf("error deleting archive %s: %s", name, err)
    }

    lp.SetPolicy(nil)

    return nil
}
@@ -16,48 +16,51 @@ func resetKeysArchive() {
 }

 func Test_KeyUpgrade(t *testing.T) {
-    testKeyUpgradeCommon(t, newSimplePolicyCRUD())
-    testKeyUpgradeCommon(t, newCachingPolicyCRUD())
+    testKeyUpgradeCommon(t, newLockManager(false))
+    testKeyUpgradeCommon(t, newLockManager(true))
 }

-func testKeyUpgradeCommon(t *testing.T, policies policyCRUD) {
+func testKeyUpgradeCommon(t *testing.T, lm *lockManager) {
     storage := &logical.InmemStorage{}
-    lp, err := policies.generatePolicy(storage, "test", false)
+    p, lockType, upserted, err := lm.GetPolicyUpsert(storage, "test", false)
     if err != nil {
         t.Fatal(err)
     }
-    if lp == nil {
-        t.Fatal("nil lockingPolicy")
+    if p == nil {
+        t.Fatal("nil policy")
     }
+    defer lm.UnlockPolicy("test", lockType)
+
+    if !upserted {
+        t.Fatal("expected an upsert")
+    }
+    if lockType != exclusive {
+        t.Fatal("expected an exclusive lock")
+    }

-    policy := lp.Policy()
-    if policy == nil {
-        t.Fatal("nil policy in lockingPolicy")
-    }
+    testBytes := make([]byte, len(p.Keys[1].Key))
+    copy(testBytes, p.Keys[1].Key)

-    testBytes := make([]byte, len(policy.Keys[1].Key))
-    copy(testBytes, policy.Keys[1].Key)
-
-    policy.Key = policy.Keys[1].Key
-    policy.Keys = nil
-    policy.migrateKeyToKeysMap()
-    if policy.Key != nil {
+    p.Key = p.Keys[1].Key
+    p.Keys = nil
+    p.migrateKeyToKeysMap()
+    if p.Key != nil {
         t.Fatal("policy.Key is not nil")
     }
-    if len(policy.Keys) != 1 {
+    if len(p.Keys) != 1 {
         t.Fatal("policy.Keys is the wrong size")
     }
-    if !reflect.DeepEqual(testBytes, policy.Keys[1].Key) {
+    if !reflect.DeepEqual(testBytes, p.Keys[1].Key) {
         t.Fatal("key mismatch")
     }
 }

 func Test_ArchivingUpgrade(t *testing.T) {
-    testArchivingUpgradeCommon(t, newSimplePolicyCRUD())
-    testArchivingUpgradeCommon(t, newCachingPolicyCRUD())
+    testArchivingUpgradeCommon(t, newLockManager(false))
+    testArchivingUpgradeCommon(t, newLockManager(true))
 }

-func testArchivingUpgradeCommon(t *testing.T, policies policyCRUD) {
+func testArchivingUpgradeCommon(t *testing.T, lm *lockManager) {
     resetKeysArchive()

     // First, we generate a policy and rotate it a number of times. Each time
@@ -67,30 +70,26 @@ func testArchivingUpgradeCommon(t *testing.T, policies policyCRUD) {

     storage := &logical.InmemStorage{}

-    lp, err := policies.generatePolicy(storage, "test", false)
+    p, lockType, _, err := lm.GetPolicyUpsert(storage, "test", false)
     if err != nil {
         t.Fatal(err)
     }
-    if lp == nil {
-        t.Fatal("nil lockingPolicy")
-    }
-
-    policy := lp.Policy()
-    if policy == nil {
-        t.Fatal("nil policy in lockingPolicy")
+    if p == nil {
+        t.Fatal("nil policy")
     }
+    lm.UnlockPolicy("test", lockType)

     // Store the initial key in the archive
-    keysArchive = append(keysArchive, policy.Keys[1])
-    checkKeys(t, policy, storage, "initial", 1, 1, 1)
+    keysArchive = append(keysArchive, p.Keys[1])
+    checkKeys(t, p, storage, "initial", 1, 1, 1)

     for i := 2; i <= 10; i++ {
-        err = policy.rotate(storage)
+        err = p.rotate(storage)
         if err != nil {
             t.Fatal(err)
         }
-        keysArchive = append(keysArchive, policy.Keys[i])
-        checkKeys(t, policy, storage, "rotate", i, i, i)
+        keysArchive = append(keysArchive, p.Keys[i])
+        checkKeys(t, p, storage, "rotate", i, i, i)
     }

     // Now, wipe the archive and set the archive version to zero
@@ -98,53 +97,49 @@ func testArchivingUpgradeCommon(t *testing.T, policies policyCRUD) {
     if err != nil {
         t.Fatal(err)
     }
-    policy.ArchiveVersion = 0
+    p.ArchiveVersion = 0

     // Store it, but without calling persist, so we don't trigger
     // handleArchiving()
-    buf, err := policy.Serialize()
+    buf, err := p.Serialize()
     if err != nil {
         t.Fatal(err)
     }

     // Write the policy into storage
     err = storage.Put(&logical.StorageEntry{
-        Key:   "policy/" + policy.Name,
+        Key:   "policy/" + p.Name,
         Value: buf,
     })
     if err != nil {
         t.Fatal(err)
     }

-    // If it's a caching CRUD, expire from the cache since we modified it
+    // If we're caching, expire from the cache since we modified it
     // under-the-hood
-    if cachingCRUD, ok := policies.(*cachingPolicyCRUD); ok {
-        delete(cachingCRUD.cache, "test")
+    if lm.CacheActive() {
+        delete(lm.cache, "test")
     }

     // Now get the policy again; the upgrade should happen automatically
-    lp, err = policies.getPolicy(storage, "test")
+    p, lockType, err = lm.GetPolicy(storage, "test")
     if err != nil {
         t.Fatal(err)
     }
-    if lp == nil {
+    if p == nil {
         t.Fatal("nil lockingPolicy")
     }
+    lm.UnlockPolicy("test", lockType)

-    policy = lp.Policy()
-    if policy == nil {
-        t.Fatal("nil policy in lockingPolicy")
-    }
-
-    checkKeys(t, policy, storage, "upgrade", 10, 10, 10)
+    checkKeys(t, p, storage, "upgrade", 10, 10, 10)
 }

 func Test_Archiving(t *testing.T) {
-    testArchivingCommon(t, newSimplePolicyCRUD())
-    testArchivingCommon(t, newCachingPolicyCRUD())
+    testArchivingCommon(t, newLockManager(false))
+    testArchivingCommon(t, newLockManager(true))
 }

-func testArchivingCommon(t *testing.T, policies policyCRUD) {
+func testArchivingCommon(t *testing.T, lm *lockManager) {
     resetKeysArchive()

     // First, we generate a policy and rotate it a number of times. Each time
@@ -154,37 +149,33 @@ func testArchivingCommon(t *testing.T, policies policyCRUD) {

     storage := &logical.InmemStorage{}

-    lp, err := policies.generatePolicy(storage, "test", false)
+    p, lockType, _, err := lm.GetPolicyUpsert(storage, "test", false)
     if err != nil {
         t.Fatal(err)
     }
-    if lp == nil {
-        t.Fatal("nil lockingPolicy")
-    }
-
-    policy := lp.Policy()
-    if policy == nil {
-        t.Fatal("nil policy in lockingPolicy")
+    if p == nil {
+        t.Fatal("nil policy")
     }
+    defer lm.UnlockPolicy("test", lockType)

     // Store the initial key in the archive
-    keysArchive = append(keysArchive, policy.Keys[1])
-    checkKeys(t, policy, storage, "initial", 1, 1, 1)
+    keysArchive = append(keysArchive, p.Keys[1])
+    checkKeys(t, p, storage, "initial", 1, 1, 1)

     for i := 2; i <= 10; i++ {
-        err = policy.rotate(storage)
+        err = p.rotate(storage)
         if err != nil {
             t.Fatal(err)
         }
-        keysArchive = append(keysArchive, policy.Keys[i])
-        checkKeys(t, policy, storage, "rotate", i, i, i)
+        keysArchive = append(keysArchive, p.Keys[i])
+        checkKeys(t, p, storage, "rotate", i, i, i)
     }

     // Move the min decryption version up
     for i := 1; i <= 10; i++ {
-        policy.MinDecryptionVersion = i
+        p.MinDecryptionVersion = i

-        err = policy.Persist(storage)
+        err = p.Persist(storage)
         if err != nil {
             t.Fatal(err)
         }
@@ -196,14 +187,14 @@ func testArchivingCommon(t *testing.T, policies policyCRUD) {
         // 10, you'd need 7, 8, 9, and 10 -- IOW, latest version - min
         // decryption version plus 1 (the min decryption version key
         // itself)
-        checkKeys(t, policy, storage, "minadd", 10, 10, policy.LatestVersion-policy.MinDecryptionVersion+1)
+        checkKeys(t, p, storage, "minadd", 10, 10, p.LatestVersion-p.MinDecryptionVersion+1)
     }

     // Move the min decryption version down
     for i := 10; i >= 1; i-- {
-        policy.MinDecryptionVersion = i
+        p.MinDecryptionVersion = i

-        err = policy.Persist(storage)
+        err = p.Persist(storage)
         if err != nil {
             t.Fatal(err)
         }
@@ -215,7 +206,7 @@ func testArchivingCommon(t *testing.T, policies policyCRUD) {
         // 10, you'd need 7, 8, 9, and 10 -- IOW, latest version - min
         // decryption version plus 1 (the min decryption version key
         // itself)
-        checkKeys(t, policy, storage, "minsub", 10, 10, policy.LatestVersion-policy.MinDecryptionVersion+1)
+        checkKeys(t, p, storage, "minsub", 10, 10, p.LatestVersion-p.MinDecryptionVersion+1)
     }
 }
Deleted file (the simplePolicyCRUD implementation):

@@ -1,82 +0,0 @@
package transit

import (
    "sync"

    "github.com/hashicorp/vault/logical"
)

// Directly implements CRUD operations without caching, mapped to the backend,
// but implements shared locking to ensure that we can't overwrite data on the
// backend from multiple operators
type simplePolicyCRUD struct {
    sync.RWMutex
    locks map[string]*sync.RWMutex
}

func newSimplePolicyCRUD() *simplePolicyCRUD {
    return &simplePolicyCRUD{
        locks: map[string]*sync.RWMutex{},
    }
}

// The write lock must be held before calling this; for this CRUD type this
// should always be the case, since the only method not requiring a write lock
// when called is getPolicy, and that itself grabs a write lock before calling
// refreshPolicy
func (p *simplePolicyCRUD) ensureLockExists(name string) {
    if p.locks[name] == nil {
        p.locks[name] = &sync.RWMutex{}
    }
}

// See general comments on the interface method
func (p *simplePolicyCRUD) getPolicy(storage logical.Storage, name string) (lockingPolicy, error) {
    // Use a write lock since fetching the policy can cause a need for upgrade persistence
    p.Lock()
    defer p.Unlock()

    return p.refreshPolicy(storage, name)
}

// See general comments on the interface method
func (p *simplePolicyCRUD) refreshPolicy(storage logical.Storage, name string) (lockingPolicy, error) {
    p.ensureLockExists(name)

    policy, err := fetchPolicyFromStorage(storage, name)
    if err != nil {
        return nil, err
    }
    if policy == nil {
        return nil, nil
    }

    lp := &mutexLockingPolicy{
        policy: policy,
        mutex:  p.locks[name],
    }

    return lp, nil
}

// See general comments on the interface method
func (p *simplePolicyCRUD) generatePolicy(storage logical.Storage, name string, derived bool) (lockingPolicy, error) {
    p.ensureLockExists(name)

    policy, err := generatePolicyCommon(p, storage, name, derived)
    if err != nil {
        return nil, err
    }

    lp := &mutexLockingPolicy{
        policy: policy,
        mutex:  p.locks[name],
    }

    return lp, nil
}

// See general comments on the interface method
func (p *simplePolicyCRUD) deletePolicy(storage logical.Storage, lp lockingPolicy, name string) error {
    return deletePolicyCommon(p, lp, storage, name)
}