diff --git a/builtin/logical/transit/backend.go b/builtin/logical/transit/backend.go index bd50d8d2e..94d07cbd9 100644 --- a/builtin/logical/transit/backend.go +++ b/builtin/logical/transit/backend.go @@ -33,16 +33,12 @@ func Backend(conf *logical.BackendConfig) *backend { Secrets: []*framework.Secret{}, } - if conf.System.CachingDisabled() { - b.policies = newSimplePolicyCRUD() - } else { - b.policies = newCachingPolicyCRUD() - } + b.lm = newLockManager(conf.System.CachingDisabled()) return &b } type backend struct { *framework.Backend - policies policyCRUD + lm *lockManager } diff --git a/builtin/logical/transit/backend_test.go b/builtin/logical/transit/backend_test.go index 3b8319913..3b31cc2bb 100644 --- a/builtin/logical/transit/backend_test.go +++ b/builtin/logical/transit/backend_test.go @@ -604,7 +604,7 @@ func testPolicyFuzzingCommon(t *testing.T, be *backend) { var chosenFunc, chosenKey string - //t.Logf("Starting") + //t.Errorf("Starting %d", id) for { // Stop after 10 seconds if time.Now().Sub(startTime) > 10*time.Second { @@ -629,7 +629,7 @@ func testPolicyFuzzingCommon(t *testing.T, be *backend) { switch chosenFunc { // Encrypt our plaintext and store the result case "encrypt": - //t.Logf("%s, %s", chosenFunc, chosenKey) + //t.Errorf("%s, %s, %d", chosenFunc, chosenKey, id) fd.Raw["plaintext"] = base64.StdEncoding.EncodeToString([]byte(testPlaintext)) fd.Schema = be.pathEncrypt().Fields resp, err := be.pathEncryptWrite(req, fd) @@ -649,7 +649,7 @@ func testPolicyFuzzingCommon(t *testing.T, be *backend) { // Decrypt the ciphertext and compare the result case "decrypt": - //t.Logf("%s, %s", chosenFunc, chosenKey) + //t.Errorf("%s, %s, %d", chosenFunc, chosenKey, id) ct := latestEncryptedText[chosenKey] if ct == "" { continue @@ -677,7 +677,7 @@ func testPolicyFuzzingCommon(t *testing.T, be *backend) { // Change the min version, which also tests the archive functionality case "change_min_version": - //t.Logf("%s, %s", chosenFunc, chosenKey) + //t.Errorf("%s, %s, %d", chosenFunc, chosenKey, id) resp, err := be.pathPolicyRead(req, fd) if err != nil { t.Fatalf("got an error reading policy %s: %v", chosenKey, err) diff --git a/builtin/logical/transit/caching_crud.go b/builtin/logical/transit/caching_crud.go deleted file mode 100644 index 8759a4228..000000000 --- a/builtin/logical/transit/caching_crud.go +++ /dev/null @@ -1,110 +0,0 @@ -package transit - -import ( - "sync" - - "github.com/hashicorp/vault/logical" -) - -// cachingPolicyCRUD implements CRUD operations with a simple locking cache of -// policies in memory -type cachingPolicyCRUD struct { - sync.RWMutex - cache map[string]lockingPolicy -} - -func newCachingPolicyCRUD() *cachingPolicyCRUD { - return &cachingPolicyCRUD{ - cache: map[string]lockingPolicy{}, - } -} - -// See general comments on the interface method -func (p *cachingPolicyCRUD) getPolicy(storage logical.Storage, name string) (lockingPolicy, error) { - // We don't defer this since we may need to give it up and get a write lock - p.RLock() - - // First, see if we're in the cache -- if so, return that - if p.cache[name] != nil { - defer p.RUnlock() - return p.cache[name], nil - } - - // If we didn't find anything, we'll need to write into the cache, plus possibly - // persist the entry, so lock the cache - p.RUnlock() - p.Lock() - defer p.Unlock() - - // Check one more time to ensure that another process did not write during - // our lock switcheroo. - if p.cache[name] != nil { - return p.cache[name], nil - } - - return p.refreshPolicy(storage, name) -} - -// See general comments on the interface method -func (p *cachingPolicyCRUD) refreshPolicy(storage logical.Storage, name string) (lockingPolicy, error) { - // Check once more to ensure it hasn't been added to the cache since the lock was acquired - if p.cache[name] != nil { - return p.cache[name], nil - } - - // Note that we don't need to create the locking entry until the end, - // because the policy wasn't in the cache so we don't know about it, and we - // hold the cache lock so nothing else can be writing it in right now - policy, err := fetchPolicyFromStorage(storage, name) - if err != nil { - return nil, err - } - if policy == nil { - return nil, nil - } - - lp := &mutexLockingPolicy{ - policy: policy, - mutex: &sync.RWMutex{}, - } - p.cache[name] = lp - - return lp, nil -} - -// See general comments on the interface method -func (p *cachingPolicyCRUD) generatePolicy(storage logical.Storage, name string, derived bool) (lockingPolicy, error) { - policy, err := generatePolicyCommon(p, storage, name, derived) - if err != nil { - return nil, err - } - - // Now we need to check again in the cache to ensure the policy wasn't - // created since we ran generatePolicy and then got the lock. A policy - // being created holds a write lock until it's done (starting from this - // point), so it'll be in the cache at this point. - if lp := p.cache[name]; lp != nil { - return lp, nil - } - - lp := &mutexLockingPolicy{ - policy: policy, - mutex: &sync.RWMutex{}, - } - p.cache[name] = lp - - // Return the policy - return lp, nil -} - -// See general comments on the interface method -func (p *cachingPolicyCRUD) deletePolicy(storage logical.Storage, lp lockingPolicy, name string) error { - err := deletePolicyCommon(p, lp, storage, name) - if err != nil { - return err - } - - delete(p.cache, name) - - return nil -} diff --git a/builtin/logical/transit/lock_manager.go b/builtin/logical/transit/lock_manager.go new file mode 100644 index 000000000..6f994b4ed --- /dev/null +++ b/builtin/logical/transit/lock_manager.go @@ -0,0 +1,337 @@ +package transit + +import ( + "encoding/json" + "fmt" + "sync" + + "github.com/hashicorp/vault/logical" +) + +const ( + shared = false + exclusive = true +) + +type lockManager struct { + // A lock for each named key + locks map[string]*sync.RWMutex + + // A mutex for the map itself + lockMutex sync.RWMutex + + // If caching is enabled, the map of name to in-memory policy cache + cache map[string]*Policy + + // Used for global locking, and as the cache map mutex + globalMutex sync.RWMutex +} + +func newLockManager(cacheDisabled bool) *lockManager { + lm := &lockManager{ + locks: map[string]*sync.RWMutex{}, + } + if !cacheDisabled { + lm.cache = map[string]*Policy{} + } + return lm +} + +func (lm *lockManager) CacheActive() bool { + return lm.cache != nil +} + +func (lm *lockManager) LockAll(name string) { + lm.globalMutex.Lock() + lm.LockPolicy(name, exclusive) +} + +func (lm *lockManager) UnlockAll(name string) { + lm.UnlockPolicy(name, exclusive) + lm.globalMutex.Unlock() +} + +func (lm *lockManager) LockPolicy(name string, writeLock bool) { + lm.lockMutex.RLock() + lock := lm.locks[name] + if lock != nil { + // We want to give this up before locking the lock, but it's safe -- + // the only time we ever write to a value in this map is the first time + // we access the value, so it won't be changing out from under us + lm.lockMutex.RUnlock() + if writeLock { + lock.Lock() + } else { + lock.RLock() + } + return + } + + lm.lockMutex.RUnlock() + lm.lockMutex.Lock() + + // Don't defer the unlock call because if we get a valid lock below we want + // to release the lock mutex right away to avoid the possibility of + // deadlock by trying to grab the second lock + + // Check to make sure it hasn't been created since + lock = lm.locks[name] + if lock != nil { + lm.lockMutex.Unlock() + if writeLock { + lock.Lock() + } else { + lock.RLock() + } + return + } + + lock = &sync.RWMutex{} + lm.locks[name] = lock + lm.lockMutex.Unlock() + if writeLock { + lock.Lock() + } else { + lock.RLock() + } +} + +func (lm *lockManager) UnlockPolicy(name string, writeLock bool) { + lm.lockMutex.RLock() + lock := lm.locks[name] + lm.lockMutex.RUnlock() + + if writeLock { + lock.Unlock() + } else { + lock.RUnlock() + } +} + +func (lm *lockManager) GetPolicy(storage logical.Storage, name string) (*Policy, bool, error) { + p, lt, _, err := lm.getPolicyCommon(storage, name, false, false) + return p, lt, err +} + +func (lm *lockManager) GetPolicyUpsert(storage logical.Storage, name string, derived bool) (*Policy, bool, bool, error) { + return lm.getPolicyCommon(storage, name, true, derived) +} + +// When the function returns, a lock will be held on the policy if err == nil. +// The type of lock will be indicated by the return value. It is the caller's +// responsibility to unlock. +func (lm *lockManager) getPolicyCommon(storage logical.Storage, name string, upsert, derived bool) (p *Policy, lockType bool, upserted bool, err error) { + // If we are using a cache, lock it now to avoid having to do really + // complicated lock juggling as we call various functions. We'll also defer + // the store into the cache. + lockType = shared + lm.LockPolicy(name, shared) + + if lm.CacheActive() { + lm.globalMutex.RLock() + p = lm.cache[name] + if p != nil { + defer lm.globalMutex.RUnlock() + return + } + lm.globalMutex.RUnlock() + + // When we return, since we didn't have the policy in the cache, if + // there was no error, write the value in. + defer func() { + if err == nil { + lm.globalMutex.Lock() + defer lm.globalMutex.Unlock() + // Make sure a policy didn't appear + exp := lm.cache[name] + if exp != nil { + p = exp + return + } + + lm.cache[name] = p + } + }() + } + + p, err = lm.getStoredPolicy(storage, name) + if err != nil { + defer lm.UnlockPolicy(name, shared) + return + } + + if p == nil { + if !upsert { + defer lm.UnlockPolicy(name, shared) + return + } + + // Get an exlusive lock; on success, check again to ensure that no + // policy exists. Note that if we are using a cache we will already be + // serializing this entire code path and it's currently the only one + // that generates policies, so we don't need to check the cache here; + // simply checking the disk again is sufficient. + lm.UnlockPolicy(name, shared) + lockType = exclusive + lm.LockPolicy(name, exclusive) + + p, err = lm.getStoredPolicy(storage, name) + if err != nil { + defer lm.UnlockPolicy(name, exclusive) + return + } + if p != nil { + return + } + + upserted = true + + p = &Policy{ + Name: name, + CipherMode: "aes-gcm", + Derived: derived, + } + if derived { + p.KDFMode = kdfMode + } + + err = p.rotate(storage) + if err != nil { + defer lm.UnlockPolicy(name, exclusive) + p = nil + } + + // We don't need to worry about upgrading since it will be a new policy + return + } + + if p.needsUpgrade() { + lm.UnlockPolicy(name, shared) + lockType = exclusive + lm.LockPolicy(name, exclusive) + + // Reload the policy with the write lock to ensure we still need the upgrade + p, err = lm.getStoredPolicy(storage, name) + if err != nil { + defer lm.UnlockPolicy(name, exclusive) + return + } + if p == nil { + defer lm.UnlockPolicy(name, exclusive) + err = fmt.Errorf("error reloading policy for upgrade") + return + } + + if !p.needsUpgrade() { + // Already happened, return the newly loaded policy + return + } + + err = p.upgrade(storage) + if err != nil { + defer lm.UnlockPolicy(name, exclusive) + } + } + + return +} + +func (lm *lockManager) DeletePolicy(storage logical.Storage, name string) error { + lm.LockAll(name) + defer lm.UnlockAll(name) + + var p *Policy + var err error + + if lm.CacheActive() { + p = lm.cache[name] + if p == nil { + return fmt.Errorf("could not delete policy; not found") + } + } else { + p, err = lm.getStoredPolicy(storage, name) + if err != nil { + return err + } + if p == nil { + return fmt.Errorf("could not delete policy; not found") + } + } + + if !p.DeletionAllowed { + return fmt.Errorf("deletion is not allowed for this policy") + } + + err = storage.Delete("policy/" + name) + if err != nil { + return fmt.Errorf("error deleting policy %s: %s", name, err) + } + + err = storage.Delete("archive/" + name) + if err != nil { + return fmt.Errorf("error deleting archive %s: %s", name, err) + } + + if lm.CacheActive() { + delete(lm.cache, name) + } + + return nil +} + +// When this function returns it's the responsibility of the caller to call UnlockAll if err is not nil +func (lm *lockManager) RefreshPolicy(storage logical.Storage, name string) (p *Policy, err error) { + lm.LockPolicy(name, exclusive) + + if lm.CacheActive() { + p = lm.cache[name] + if p != nil { + return + } + err = fmt.Errorf("could not refresh policy; not found") + defer lm.UnlockPolicy(name, exclusive) + return + } + + p, err = lm.getStoredPolicy(storage, name) + if err != nil { + defer lm.UnlockPolicy(name, exclusive) + return + } + + if p == nil { + err = fmt.Errorf("could not refresh policy; not found") + defer lm.UnlockPolicy(name, exclusive) + } + + if p.needsUpgrade() { + err = p.upgrade(storage) + if err != nil { + defer lm.UnlockPolicy(name, exclusive) + } + } + + return +} + +func (lm *lockManager) getStoredPolicy(storage logical.Storage, name string) (*Policy, error) { + // Check if the policy already exists + raw, err := storage.Get("policy/" + name) + if err != nil { + return nil, err + } + if raw == nil { + return nil, nil + } + + // Decode the policy + policy := &Policy{ + Keys: KeyEntryMap{}, + } + err = json.Unmarshal(raw.Value, policy) + if err != nil { + return nil, err + } + + return policy, nil +} diff --git a/builtin/logical/transit/path_config.go b/builtin/logical/transit/path_config.go index 7395fbe10..eabb3603f 100644 --- a/builtin/logical/transit/path_config.go +++ b/builtin/logical/transit/path_config.go @@ -41,49 +41,38 @@ func (b *backend) pathConfigWrite( req *logical.Request, d *framework.FieldData) (*logical.Response, error) { name := d.Get("name").(string) - // Check if the policy already exists - lp, err := b.policies.getPolicy(req.Storage, name) + // Check if the policy already exists before we lock everything + p, lockType, err := b.lm.GetPolicy(req.Storage, name) if err != nil { return nil, err } - if lp == nil { + if p == nil { return logical.ErrorResponse( fmt.Sprintf("no existing key named %s could be found", name)), logical.ErrInvalidRequest } - // Store some values so we can detect if the policy changed after locking - lp.RLock() - currDeletionAllowed := lp.Policy().DeletionAllowed - currMinDecryptionVersion := lp.Policy().MinDecryptionVersion - lp.RUnlock() + // Store some values so we can detect a change when we lock everything + currDeletionAllowed := p.DeletionAllowed + currMinDecryptionVersion := p.MinDecryptionVersion - // Hold both locks since we want to ensure the policy doesn't change from - // underneath us - b.policies.Lock() - defer b.policies.Unlock() - lp.Lock() - defer lp.Unlock() + b.lm.UnlockPolicy(name, lockType) // Refresh in case it's changed since before we grabbed the lock - lp, err = b.policies.refreshPolicy(req.Storage, name) + p, err = b.lm.RefreshPolicy(req.Storage, name) if err != nil { return nil, err } - if lp == nil { + if p == nil { return nil, fmt.Errorf("error finding key %s after locking for changes", name) } - - // Verify if wasn't deleted before we grabbed the lock - if lp.Policy() == nil { - return nil, fmt.Errorf("no existing key named %s could be found", name) - } + defer b.lm.UnlockPolicy(name, exclusive) resp := &logical.Response{} // Check for anything to have been updated since we got the write lock - if currDeletionAllowed != lp.Policy().DeletionAllowed || - currMinDecryptionVersion != lp.Policy().MinDecryptionVersion { + if currDeletionAllowed != p.DeletionAllowed || + currMinDecryptionVersion != p.MinDecryptionVersion { resp.AddWarning("key configuration has changed since this endpoint was called, not updating") return resp, nil } @@ -104,12 +93,12 @@ func (b *backend) pathConfigWrite( } if minDecryptionVersion > 0 && - minDecryptionVersion != lp.Policy().MinDecryptionVersion { - if minDecryptionVersion > lp.Policy().LatestVersion { + minDecryptionVersion != p.MinDecryptionVersion { + if minDecryptionVersion > p.LatestVersion { return logical.ErrorResponse( - fmt.Sprintf("cannot set min decryption version of %d, latest key version is %d", minDecryptionVersion, lp.Policy().LatestVersion)), nil + fmt.Sprintf("cannot set min decryption version of %d, latest key version is %d", minDecryptionVersion, p.LatestVersion)), nil } - lp.Policy().MinDecryptionVersion = minDecryptionVersion + p.MinDecryptionVersion = minDecryptionVersion persistNeeded = true } } @@ -117,8 +106,8 @@ func (b *backend) pathConfigWrite( allowDeletionInt, ok := d.GetOk("deletion_allowed") if ok { allowDeletion := allowDeletionInt.(bool) - if allowDeletion != lp.Policy().DeletionAllowed { - lp.Policy().DeletionAllowed = allowDeletion + if allowDeletion != p.DeletionAllowed { + p.DeletionAllowed = allowDeletion persistNeeded = true } } @@ -126,8 +115,8 @@ func (b *backend) pathConfigWrite( // Add this as a guard here before persisting since we now require the min // decryption version to start at 1; even if it's not explicitly set here, // force the upgrade - if lp.Policy().MinDecryptionVersion == 0 { - lp.Policy().MinDecryptionVersion = 1 + if p.MinDecryptionVersion == 0 { + p.MinDecryptionVersion = 1 persistNeeded = true } @@ -135,7 +124,7 @@ func (b *backend) pathConfigWrite( return nil, nil } - return resp, lp.Policy().Persist(req.Storage) + return resp, p.Persist(req.Storage) } const pathConfigHelpSyn = `Configure a named encryption key` diff --git a/builtin/logical/transit/path_datakey.go b/builtin/logical/transit/path_datakey.go index 9e6483673..bd6fca651 100644 --- a/builtin/logical/transit/path_datakey.go +++ b/builtin/logical/transit/path_datakey.go @@ -73,23 +73,14 @@ func (b *backend) pathDatakeyWrite( } // Get the policy - lp, err := b.policies.getPolicy(req.Storage, name) + p, lockType, err := b.lm.GetPolicy(req.Storage, name) if err != nil { return nil, err } - - // Error if invalid policy - if lp == nil { + if p == nil { return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest } - - lp.RLock() - defer lp.RUnlock() - - // Verify if wasn't deleted before we grabbed the lock - if lp.Policy() == nil { - return nil, fmt.Errorf("no existing policy named %s could be found", name) - } + defer b.lm.UnlockPolicy(name, lockType) newKey := make([]byte, 32) bits := d.Get("bits").(int) @@ -107,7 +98,7 @@ func (b *backend) pathDatakeyWrite( return nil, err } - ciphertext, err := lp.Policy().Encrypt(context, base64.StdEncoding.EncodeToString(newKey)) + ciphertext, err := p.Encrypt(context, base64.StdEncoding.EncodeToString(newKey)) if err != nil { switch err.(type) { case certutil.UserError: diff --git a/builtin/logical/transit/path_decrypt.go b/builtin/logical/transit/path_decrypt.go index a58709858..8c5799331 100644 --- a/builtin/logical/transit/path_decrypt.go +++ b/builtin/logical/transit/path_decrypt.go @@ -58,25 +58,17 @@ func (b *backend) pathDecryptWrite( } // Get the policy - lp, err := b.policies.getPolicy(req.Storage, name) + p, lockType, err := b.lm.GetPolicy(req.Storage, name) if err != nil { return nil, err } - - // Error if invalid policy - if lp == nil { + if p == nil { return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest } - lp.RLock() - defer lp.RUnlock() + defer b.lm.UnlockPolicy(name, lockType) - // Verify if wasn't deleted before we grabbed the lock - if lp.Policy() == nil { - return nil, fmt.Errorf("no existing policy named %s could be found", name) - } - - plaintext, err := lp.Policy().Decrypt(context, ciphertext) + plaintext, err := p.Decrypt(context, ciphertext) if err != nil { switch err.(type) { case certutil.UserError: diff --git a/builtin/logical/transit/path_encrypt.go b/builtin/logical/transit/path_encrypt.go index 8c4cc91e7..bba766797 100644 --- a/builtin/logical/transit/path_encrypt.go +++ b/builtin/logical/transit/path_encrypt.go @@ -44,12 +44,14 @@ func (b *backend) pathEncrypt() *framework.Path { func (b *backend) pathEncryptExistenceCheck( req *logical.Request, d *framework.FieldData) (bool, error) { name := d.Get("name").(string) - lp, err := b.policies.getPolicy(req.Storage, name) + p, lockType, err := b.lm.GetPolicy(req.Storage, name) if err != nil { return false, err } - - return lp != nil, nil + if p != nil { + defer b.lm.UnlockPolicy(name, lockType) + } + return p != nil, nil } func (b *backend) pathEncryptWrite( @@ -63,8 +65,8 @@ func (b *backend) pathEncryptWrite( // Decode the context if any contextRaw := d.Get("context").(string) var context []byte + var err error if len(contextRaw) != 0 { - var err error context, err = base64.StdEncoding.DecodeString(contextRaw) if err != nil { return logical.ErrorResponse("failed to decode context as base64"), logical.ErrInvalidRequest @@ -72,43 +74,23 @@ func (b *backend) pathEncryptWrite( } // Get the policy - lp, err := b.policies.getPolicy(req.Storage, name) + var p *Policy + var lockType bool + var upserted bool + if req.Operation == logical.CreateOperation { + p, lockType, upserted, err = b.lm.GetPolicyUpsert(req.Storage, name, len(context) != 0) + } else { + p, lockType, err = b.lm.GetPolicy(req.Storage, name) + } if err != nil { return nil, err } - - // Error or upsert if invalid policy - if lp == nil { - if req.Operation != logical.CreateOperation { - return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest - } - - // Get a write lock - b.policies.Lock() - - isDerived := len(context) != 0 - - // This also checks to make sure one hasn't been created since we grabbed the write lock - lp, err = b.policies.generatePolicy(req.Storage, name, isDerived) - // If the error is that the policy has been created in the interim we - // will get the policy back, so only consider it an error if err is not - // nil and we do not get a policy back - if err != nil && lp != nil { - b.policies.Unlock() - return nil, err - } - b.policies.Unlock() + if p == nil { + return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest } + defer b.lm.UnlockPolicy(name, lockType) - lp.RLock() - defer lp.RUnlock() - - // Verify if wasn't deleted before we grabbed the lock - if lp.Policy() == nil { - return nil, fmt.Errorf("no existing policy named %s could be found", name) - } - - ciphertext, err := lp.Policy().Encrypt(context, value) + ciphertext, err := p.Encrypt(context, value) if err != nil { switch err.(type) { case certutil.UserError: @@ -130,6 +112,11 @@ func (b *backend) pathEncryptWrite( "ciphertext": ciphertext, }, } + + if req.Operation == logical.CreateOperation && !upserted { + resp.AddWarning("Attempted creation of the key during the encrypt operation, but it was created beforehand") + } + return resp, nil } diff --git a/builtin/logical/transit/path_keys.go b/builtin/logical/transit/path_keys.go index 32cbad485..956b295d1 100644 --- a/builtin/logical/transit/path_keys.go +++ b/builtin/logical/transit/path_keys.go @@ -36,55 +36,58 @@ func (b *backend) pathKeys() *framework.Path { func (b *backend) pathPolicyWrite( req *logical.Request, d *framework.FieldData) (*logical.Response, error) { - // Grab a write lock right off the bat - b.policies.Lock() - defer b.policies.Unlock() - name := d.Get("name").(string) derived := d.Get("derived").(bool) - // Generate the policy; this will also check if it exists for safety - _, err := b.policies.generatePolicy(req.Storage, name, derived) - return nil, err + p, lockType, upserted, err := b.lm.GetPolicyUpsert(req.Storage, name, derived) + if err != nil { + return nil, err + } + if p == nil { + return nil, fmt.Errorf("error generating key: returned policy was nil") + } + + defer b.lm.UnlockPolicy(name, lockType) + + resp := &logical.Response{} + if !upserted { + resp.AddWarning(fmt.Sprintf("key %s already existed", name)) + } + + return nil, nil } func (b *backend) pathPolicyRead( req *logical.Request, d *framework.FieldData) (*logical.Response, error) { name := d.Get("name").(string) - lp, err := b.policies.getPolicy(req.Storage, name) + p, lockType, err := b.lm.GetPolicy(req.Storage, name) if err != nil { return nil, err } - if lp == nil { + if p == nil { return nil, nil } - lp.RLock() - defer lp.RUnlock() - - // Verify if wasn't deleted before we grabbed the lock - if lp.Policy() == nil { - return nil, fmt.Errorf("no existing policy named %s could be found", name) - } + defer b.lm.UnlockPolicy(name, lockType) // Return the response resp := &logical.Response{ Data: map[string]interface{}{ - "name": lp.Policy().Name, - "cipher_mode": lp.Policy().CipherMode, - "derived": lp.Policy().Derived, - "deletion_allowed": lp.Policy().DeletionAllowed, - "min_decryption_version": lp.Policy().MinDecryptionVersion, - "latest_version": lp.Policy().LatestVersion, + "name": p.Name, + "cipher_mode": p.CipherMode, + "derived": p.Derived, + "deletion_allowed": p.DeletionAllowed, + "min_decryption_version": p.MinDecryptionVersion, + "latest_version": p.LatestVersion, }, } - if lp.Policy().Derived { - resp.Data["kdf_mode"] = lp.Policy().KDFMode + if p.Derived { + resp.Data["kdf_mode"] = p.KDFMode } retKeys := map[string]int64{} - for k, v := range lp.Policy().Keys { + for k, v := range p.Keys { retKeys[strconv.Itoa(k)] = v.CreationTime } resp.Data["keys"] = retKeys @@ -96,33 +99,17 @@ func (b *backend) pathPolicyDelete( req *logical.Request, d *framework.FieldData) (*logical.Response, error) { name := d.Get("name").(string) - // Some sanity checking - lp, err := b.policies.getPolicy(req.Storage, name) + // Some sanity checking before we lock it all in the DeletePolicy method + p, lockType, err := b.lm.GetPolicy(req.Storage, name) if err != nil { return logical.ErrorResponse(fmt.Sprintf("error looking up policy %s, error is %s", name, err)), err } - if lp == nil { + if p == nil { return logical.ErrorResponse(fmt.Sprintf("no such key %s", name)), logical.ErrInvalidRequest } + b.lm.UnlockPolicy(name, lockType) - // Hold both locks since we'll be affecting both the cache (if it exists) - // and the locking policy itself - b.policies.Lock() - defer b.policies.Unlock() - lp.Lock() - defer lp.Unlock() - - // Make sure that we have up-to-date values since deletePolicy will check - // things like whether deletion is allowed - lp, err = b.policies.refreshPolicy(req.Storage, name) - if err != nil { - return nil, err - } - if lp == nil { - return nil, fmt.Errorf("error finding key %s after locking for deletion", name) - } - - err = b.policies.deletePolicy(req.Storage, lp, name) + err = b.lm.DeletePolicy(req.Storage, name) if err != nil { return logical.ErrorResponse(fmt.Sprintf("error deleting policy %s: %s", name, err)), err } diff --git a/builtin/logical/transit/path_rewrap.go b/builtin/logical/transit/path_rewrap.go index 2233689b3..61e3f607e 100644 --- a/builtin/logical/transit/path_rewrap.go +++ b/builtin/logical/transit/path_rewrap.go @@ -59,25 +59,18 @@ func (b *backend) pathRewrapWrite( } // Get the policy - lp, err := b.policies.getPolicy(req.Storage, name) + p, lockType, err := b.lm.GetPolicy(req.Storage, name) if err != nil { return nil, err } - // Error if invalid policy - if lp == nil { + if p == nil { return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest } - lp.RLock() - defer lp.RUnlock() + defer b.lm.UnlockPolicy(name, lockType) - // Verify if wasn't deleted before we grabbed the lock - if lp.Policy() == nil { - return nil, fmt.Errorf("no existing policy named %s could be found", name) - } - - plaintext, err := lp.Policy().Decrypt(context, value) + plaintext, err := p.Decrypt(context, value) if err != nil { switch err.(type) { case certutil.UserError: @@ -93,7 +86,7 @@ func (b *backend) pathRewrapWrite( return nil, fmt.Errorf("empty plaintext returned during rewrap") } - ciphertext, err := lp.Policy().Encrypt(context, plaintext) + ciphertext, err := p.Encrypt(context, plaintext) if err != nil { switch err.(type) { case certutil.UserError: diff --git a/builtin/logical/transit/path_rotate.go b/builtin/logical/transit/path_rotate.go index 93e1951a9..442646ef8 100644 --- a/builtin/logical/transit/path_rotate.go +++ b/builtin/logical/transit/path_rotate.go @@ -31,48 +31,38 @@ func (b *backend) pathRotateWrite( name := d.Get("name").(string) // Get the policy - lp, err := b.policies.getPolicy(req.Storage, name) + p, lockType, err := b.lm.GetPolicy(req.Storage, name) if err != nil { return nil, err } - - // Error if invalid policy - if lp == nil { + if p == nil { return logical.ErrorResponse("key not found"), logical.ErrInvalidRequest } // Store so we can detect later if this has changed out from under us - keyVersion := lp.Policy().LatestVersion + keyVersion := p.LatestVersion - // lock the policies object so we can refresh - b.policies.Lock() - defer b.policies.Unlock() - lp.Lock() - defer lp.Unlock() + b.lm.UnlockPolicy(name, lockType) // Refresh in case it's changed since before we grabbed the lock - lp, err = b.policies.refreshPolicy(req.Storage, name) + p, err = b.lm.RefreshPolicy(req.Storage, name) if err != nil { return nil, err } - if lp == nil { + if p == nil { return nil, fmt.Errorf("error finding key %s after locking for changes", name) } - - // Verify if wasn't deleted before we grabbed the lock - if lp.Policy() == nil { - return nil, fmt.Errorf("no existing key named %s could be found", name) - } + defer b.lm.UnlockPolicy(name, exclusive) // Make sure that the policy hasn't been rotated simultaneously - if keyVersion != lp.Policy().LatestVersion { + if keyVersion != p.LatestVersion { resp := &logical.Response{} resp.AddWarning("key has been rotated since this endpoint was called; did not perform rotation") return resp, nil } // Rotate the policy - err = lp.Policy().rotate(req.Storage) + err = p.rotate(req.Storage) return nil, err } diff --git a/builtin/logical/transit/policy.go b/builtin/logical/transit/policy.go index ad47af754..c1b09e2d9 100644 --- a/builtin/logical/transit/policy.go +++ b/builtin/logical/transit/policy.go @@ -235,6 +235,67 @@ func (p *Policy) Serialize() ([]byte, error) { return json.Marshal(p) } +func (p *Policy) needsUpgrade() bool { + // Ensure we've moved from Key -> Keys + if p.Key != nil && len(p.Key) > 0 { + return true + } + + // With archiving, past assumptions about the length of the keys map are no longer valid + if p.LatestVersion == 0 && len(p.Keys) != 0 { + return true + } + + // We disallow setting the version to 0, since they start at 1 since moving + // to rotate-able keys, so update if it's set to 0 + if p.MinDecryptionVersion == 0 { + return true + } + + // On first load after an upgrade, copy keys to the archive + if p.ArchiveVersion == 0 { + return true + } + + return false +} + +func (p *Policy) upgrade(storage logical.Storage) error { + persistNeeded := false + // Ensure we've moved from Key -> Keys + if p.Key != nil && len(p.Key) > 0 { + p.migrateKeyToKeysMap() + persistNeeded = true + } + + // With archiving, past assumptions about the length of the keys map are no longer valid + if p.LatestVersion == 0 && len(p.Keys) != 0 { + p.LatestVersion = len(p.Keys) + persistNeeded = true + } + + // We disallow setting the version to 0, since they start at 1 since moving + // to rotate-able keys, so update if it's set to 0 + if p.MinDecryptionVersion == 0 { + p.MinDecryptionVersion = 1 + persistNeeded = true + } + + // On first load after an upgrade, copy keys to the archive + if p.ArchiveVersion == 0 { + persistNeeded = true + } + + if persistNeeded { + err := p.Persist(storage) + if err != nil { + return err + } + } + + return nil +} + // DeriveKey is used to derive the encryption key that should // be used depending on the policy. If derivation is disabled the // raw key is used and no context is required, otherwise the KDF diff --git a/builtin/logical/transit/policy_crud.go b/builtin/logical/transit/policy_crud.go deleted file mode 100644 index 46472de96..000000000 --- a/builtin/logical/transit/policy_crud.go +++ /dev/null @@ -1,188 +0,0 @@ -package transit - -import ( - "encoding/json" - "fmt" - "sync" - - "github.com/hashicorp/vault/logical" -) - -type lockingPolicy interface { - Lock() - RLock() - Unlock() - RUnlock() - Policy() *Policy - SetPolicy(*Policy) -} - -type policyCRUD interface { - // getPolicy returns a lockingPolicy. It performs its own locking according - // to implementation. - getPolicy(storage logical.Storage, name string) (lockingPolicy, error) - - // refreshPolicy returns a lockingPolicy. It does not perform its own - // locking; a write lock must be held before calling. - refreshPolicy(storage logical.Storage, name string) (lockingPolicy, error) - - // generatePolicy generates and returns a lockingPolicy. A write lock must - // be held before calling. - generatePolicy(storage logical.Storage, name string, derived bool) (lockingPolicy, error) - - // deletePolicy deletes a lockingPolicy. A write lock must be held on both - // the CRUD implementation and the lockingPolicy before calling. - deletePolicy(storage logical.Storage, lp lockingPolicy, name string) error - - // These are generally satisfied by embedded mutexes in the implementing struct - Lock() - RLock() - Unlock() - RUnlock() -} - -// The mutex is kept separate from the struct since we may set it to its own -// mutex (if the object is shared) or a shared mutex (if the object isn't -// shared and only the locking is) -type mutexLockingPolicy struct { - mutex *sync.RWMutex - policy *Policy -} - -func (m *mutexLockingPolicy) Lock() { - m.mutex.Lock() -} - -func (m *mutexLockingPolicy) RLock() { - m.mutex.RLock() -} - -func (m *mutexLockingPolicy) Unlock() { - m.mutex.Unlock() -} - -func (m *mutexLockingPolicy) RUnlock() { - m.mutex.RUnlock() -} - -func (m *mutexLockingPolicy) Policy() *Policy { - return m.policy -} - -func (m *mutexLockingPolicy) SetPolicy(p *Policy) { - m.policy = p -} - -// fetchPolicyFromStorage fetches the policy from backend storage. The caller -// should hold the write lock when calling this, to handle upgrades. -func fetchPolicyFromStorage(storage logical.Storage, name string) (*Policy, error) { - // Check if the policy already exists - raw, err := storage.Get("policy/" + name) - if err != nil { - return nil, err - } - if raw == nil { - return nil, nil - } - - // Decode the policy - policy := &Policy{ - Keys: KeyEntryMap{}, - } - err = json.Unmarshal(raw.Value, policy) - if err != nil { - return nil, err - } - - persistNeeded := false - // Ensure we've moved from Key -> Keys - if policy.Key != nil && len(policy.Key) > 0 { - policy.migrateKeyToKeysMap() - persistNeeded = true - } - - // With archiving, past assumptions about the length of the keys map are no longer valid - if policy.LatestVersion == 0 && len(policy.Keys) != 0 { - policy.LatestVersion = len(policy.Keys) - persistNeeded = true - } - - // We disallow setting the version to 0, since they start at 1 since moving - // to rotate-able keys, so update if it's set to 0 - if policy.MinDecryptionVersion == 0 { - policy.MinDecryptionVersion = 1 - persistNeeded = true - } - - // On first load after an upgrade, copy keys to the archive - if policy.ArchiveVersion == 0 { - persistNeeded = true - } - - if persistNeeded { - err = policy.Persist(storage) - if err != nil { - return nil, err - } - } - - return policy, nil -} - -// generatePolicyCommon is used to create a new named policy with a randomly -// generated key. The caller should have a write lock prior to calling this. -func generatePolicyCommon(p policyCRUD, storage logical.Storage, name string, derived bool) (*Policy, error) { - // Make sure this doesn't exist in case it was created before we got the write lock - policy, err := fetchPolicyFromStorage(storage, name) - if err != nil { - return nil, err - } - if policy != nil { - return policy, nil - } - - // Create the policy object - policy = &Policy{ - Name: name, - CipherMode: "aes-gcm", - Derived: derived, - } - if derived { - policy.KDFMode = kdfMode - } - - err = policy.rotate(storage) - if err != nil { - return nil, err - } - - return policy, err -} - -// deletePolicyCommon deletes a policy. The caller should hold the write lock -// for both the policy and lockingPolicy prior to calling this. -func deletePolicyCommon(p policyCRUD, lp lockingPolicy, storage logical.Storage, name string) error { - if lp.Policy() == nil { - // This got deleted before we grabbed the lock - return fmt.Errorf("policy already deleted") - } - - // Verify this hasn't changed - if !lp.Policy().DeletionAllowed { - return fmt.Errorf("deletion not allowed for policy %s", name) - } - - err := storage.Delete("policy/" + name) - if err != nil { - return fmt.Errorf("error deleting policy %s: %s", name, err) - } - - err = storage.Delete("archive/" + name) - if err != nil { - return fmt.Errorf("error deleting archive %s: %s", name, err) - } - - lp.SetPolicy(nil) - - return nil -} diff --git a/builtin/logical/transit/policy_test.go b/builtin/logical/transit/policy_test.go index af49c27d2..b6d0acf41 100644 --- a/builtin/logical/transit/policy_test.go +++ b/builtin/logical/transit/policy_test.go @@ -16,48 +16,51 @@ func resetKeysArchive() { } func Test_KeyUpgrade(t *testing.T) { - testKeyUpgradeCommon(t, newSimplePolicyCRUD()) - testKeyUpgradeCommon(t, newCachingPolicyCRUD()) + testKeyUpgradeCommon(t, newLockManager(false)) + testKeyUpgradeCommon(t, newLockManager(true)) } -func testKeyUpgradeCommon(t *testing.T, policies policyCRUD) { +func testKeyUpgradeCommon(t *testing.T, lm *lockManager) { storage := &logical.InmemStorage{} - lp, err := policies.generatePolicy(storage, "test", false) + p, lockType, upserted, err := lm.GetPolicyUpsert(storage, "test", false) if err != nil { t.Fatal(err) } - if lp == nil { - t.Fatal("nil lockingPolicy") + if p == nil { + t.Fatal("nil policy") + } + defer lm.UnlockPolicy("test", lockType) + + if !upserted { + t.Fatal("expected an upsert") + } + if lockType != exclusive { + t.Fatal("expected an exclusive lock") } - policy := lp.Policy() - if policy == nil { - t.Fatal("nil policy in lockingPolicy") - } + testBytes := make([]byte, len(p.Keys[1].Key)) + copy(testBytes, p.Keys[1].Key) - testBytes := make([]byte, len(policy.Keys[1].Key)) - copy(testBytes, policy.Keys[1].Key) - - policy.Key = policy.Keys[1].Key - policy.Keys = nil - policy.migrateKeyToKeysMap() - if policy.Key != nil { + p.Key = p.Keys[1].Key + p.Keys = nil + p.migrateKeyToKeysMap() + if p.Key != nil { t.Fatal("policy.Key is not nil") } - if len(policy.Keys) != 1 { + if len(p.Keys) != 1 { t.Fatal("policy.Keys is the wrong size") } - if !reflect.DeepEqual(testBytes, policy.Keys[1].Key) { + if !reflect.DeepEqual(testBytes, p.Keys[1].Key) { t.Fatal("key mismatch") } } func Test_ArchivingUpgrade(t *testing.T) { - testArchivingUpgradeCommon(t, newSimplePolicyCRUD()) - testArchivingUpgradeCommon(t, newCachingPolicyCRUD()) + testArchivingUpgradeCommon(t, newLockManager(false)) + testArchivingUpgradeCommon(t, newLockManager(true)) } -func testArchivingUpgradeCommon(t *testing.T, policies policyCRUD) { +func testArchivingUpgradeCommon(t *testing.T, lm *lockManager) { resetKeysArchive() // First, we generate a policy and rotate it a number of times. Each time @@ -67,30 +70,26 @@ func testArchivingUpgradeCommon(t *testing.T, policies policyCRUD) { storage := &logical.InmemStorage{} - lp, err := policies.generatePolicy(storage, "test", false) + p, lockType, _, err := lm.GetPolicyUpsert(storage, "test", false) if err != nil { t.Fatal(err) } - if lp == nil { - t.Fatal("nil lockingPolicy") - } - - policy := lp.Policy() - if policy == nil { - t.Fatal("nil policy in lockingPolicy") + if p == nil { + t.Fatal("nil policy") } + lm.UnlockPolicy("test", lockType) // Store the initial key in the archive - keysArchive = append(keysArchive, policy.Keys[1]) - checkKeys(t, policy, storage, "initial", 1, 1, 1) + keysArchive = append(keysArchive, p.Keys[1]) + checkKeys(t, p, storage, "initial", 1, 1, 1) for i := 2; i <= 10; i++ { - err = policy.rotate(storage) + err = p.rotate(storage) if err != nil { t.Fatal(err) } - keysArchive = append(keysArchive, policy.Keys[i]) - checkKeys(t, policy, storage, "rotate", i, i, i) + keysArchive = append(keysArchive, p.Keys[i]) + checkKeys(t, p, storage, "rotate", i, i, i) } // Now, wipe the archive and set the archive version to zero @@ -98,53 +97,49 @@ func testArchivingUpgradeCommon(t *testing.T, policies policyCRUD) { if err != nil { t.Fatal(err) } - policy.ArchiveVersion = 0 + p.ArchiveVersion = 0 // Store it, but without calling persist, so we don't trigger // handleArchiving() - buf, err := policy.Serialize() + buf, err := p.Serialize() if err != nil { t.Fatal(err) } // Write the policy into storage err = storage.Put(&logical.StorageEntry{ - Key: "policy/" + policy.Name, + Key: "policy/" + p.Name, Value: buf, }) if err != nil { t.Fatal(err) } - // If it's a caching CRUD, expire from the cache since we modified it + // If we're caching, expire from the cache since we modified it // under-the-hood - if cachingCRUD, ok := policies.(*cachingPolicyCRUD); ok { - delete(cachingCRUD.cache, "test") + if lm.CacheActive() { + delete(lm.cache, "test") } // Now get the policy again; the upgrade should happen automatically - lp, err = policies.getPolicy(storage, "test") + p, lockType, err = lm.GetPolicy(storage, "test") if err != nil { t.Fatal(err) } - if lp == nil { + if p == nil { t.Fatal("nil lockingPolicy") } + lm.UnlockPolicy("test", lockType) - policy = lp.Policy() - if policy == nil { - t.Fatal("nil policy in lockingPolicy") - } - - checkKeys(t, policy, storage, "upgrade", 10, 10, 10) + checkKeys(t, p, storage, "upgrade", 10, 10, 10) } func Test_Archiving(t *testing.T) { - testArchivingCommon(t, newSimplePolicyCRUD()) - testArchivingCommon(t, newCachingPolicyCRUD()) + testArchivingCommon(t, newLockManager(false)) + testArchivingCommon(t, newLockManager(true)) } -func testArchivingCommon(t *testing.T, policies policyCRUD) { +func testArchivingCommon(t *testing.T, lm *lockManager) { resetKeysArchive() // First, we generate a policy and rotate it a number of times. Each time @@ -154,37 +149,33 @@ func testArchivingCommon(t *testing.T, policies policyCRUD) { storage := &logical.InmemStorage{} - lp, err := policies.generatePolicy(storage, "test", false) + p, lockType, _, err := lm.GetPolicyUpsert(storage, "test", false) if err != nil { t.Fatal(err) } - if lp == nil { - t.Fatal("nil lockingPolicy") - } - - policy := lp.Policy() - if policy == nil { - t.Fatal("nil policy in lockingPolicy") + if p == nil { + t.Fatal("nil policy") } + defer lm.UnlockPolicy("test", lockType) // Store the initial key in the archive - keysArchive = append(keysArchive, policy.Keys[1]) - checkKeys(t, policy, storage, "initial", 1, 1, 1) + keysArchive = append(keysArchive, p.Keys[1]) + checkKeys(t, p, storage, "initial", 1, 1, 1) for i := 2; i <= 10; i++ { - err = policy.rotate(storage) + err = p.rotate(storage) if err != nil { t.Fatal(err) } - keysArchive = append(keysArchive, policy.Keys[i]) - checkKeys(t, policy, storage, "rotate", i, i, i) + keysArchive = append(keysArchive, p.Keys[i]) + checkKeys(t, p, storage, "rotate", i, i, i) } // Move the min decryption version up for i := 1; i <= 10; i++ { - policy.MinDecryptionVersion = i + p.MinDecryptionVersion = i - err = policy.Persist(storage) + err = p.Persist(storage) if err != nil { t.Fatal(err) } @@ -196,14 +187,14 @@ func testArchivingCommon(t *testing.T, policies policyCRUD) { // 10, you'd need 7, 8, 9, and 10 -- IOW, latest version - min // decryption version plus 1 (the min decryption version key // itself) - checkKeys(t, policy, storage, "minadd", 10, 10, policy.LatestVersion-policy.MinDecryptionVersion+1) + checkKeys(t, p, storage, "minadd", 10, 10, p.LatestVersion-p.MinDecryptionVersion+1) } // Move the min decryption version down for i := 10; i >= 1; i-- { - policy.MinDecryptionVersion = i + p.MinDecryptionVersion = i - err = policy.Persist(storage) + err = p.Persist(storage) if err != nil { t.Fatal(err) } @@ -215,7 +206,7 @@ func testArchivingCommon(t *testing.T, policies policyCRUD) { // 10, you'd need 7, 8, 9, and 10 -- IOW, latest version - min // decryption version plus 1 (the min decryption version key // itself) - checkKeys(t, policy, storage, "minsub", 10, 10, policy.LatestVersion-policy.MinDecryptionVersion+1) + checkKeys(t, p, storage, "minsub", 10, 10, p.LatestVersion-p.MinDecryptionVersion+1) } } diff --git a/builtin/logical/transit/simple_crud.go b/builtin/logical/transit/simple_crud.go deleted file mode 100644 index 6538c7eea..000000000 --- a/builtin/logical/transit/simple_crud.go +++ /dev/null @@ -1,82 +0,0 @@ -package transit - -import ( - "sync" - - "github.com/hashicorp/vault/logical" -) - -// Directly implements CRUD operations without caching, mapped to the backend, -// but implements shared locking to ensure that we can't overwrite data on the -// backend from multiple operators -type simplePolicyCRUD struct { - sync.RWMutex - locks map[string]*sync.RWMutex -} - -func newSimplePolicyCRUD() *simplePolicyCRUD { - return &simplePolicyCRUD{ - locks: map[string]*sync.RWMutex{}, - } -} - -// The write lock must be held before calling this; for this CRUD type this -// should always be the case, since the only method not requiring a write lock -// when called is getPolicy, and that itself grabs a write lock before calling -// refreshPolicy -func (p *simplePolicyCRUD) ensureLockExists(name string) { - if p.locks[name] == nil { - p.locks[name] = &sync.RWMutex{} - } -} - -// See general comments on the interface method -func (p *simplePolicyCRUD) getPolicy(storage logical.Storage, name string) (lockingPolicy, error) { - // Use a write lock since fetching the policy can cause a need for upgrade persistence - p.Lock() - defer p.Unlock() - - return p.refreshPolicy(storage, name) -} - -// See general comments on the interface method -func (p *simplePolicyCRUD) refreshPolicy(storage logical.Storage, name string) (lockingPolicy, error) { - p.ensureLockExists(name) - - policy, err := fetchPolicyFromStorage(storage, name) - if err != nil { - return nil, err - } - if policy == nil { - return nil, nil - } - - lp := &mutexLockingPolicy{ - policy: policy, - mutex: p.locks[name], - } - - return lp, nil -} - -// See general comments on the interface method -func (p *simplePolicyCRUD) generatePolicy(storage logical.Storage, name string, derived bool) (lockingPolicy, error) { - p.ensureLockExists(name) - - policy, err := generatePolicyCommon(p, storage, name, derived) - if err != nil { - return nil, err - } - - lp := &mutexLockingPolicy{ - policy: policy, - mutex: p.locks[name], - } - - return lp, nil -} - -// See general comments on the interface method -func (p *simplePolicyCRUD) deletePolicy(storage logical.Storage, lp lockingPolicy, name string) error { - return deletePolicyCommon(p, lp, storage, name) -}