diff --git a/builtin/logical/transit/backend.go b/builtin/logical/transit/backend.go index 3bb20d16d..bd50d8d2e 100644 --- a/builtin/logical/transit/backend.go +++ b/builtin/logical/transit/backend.go @@ -6,7 +6,7 @@ import ( ) func Factory(conf *logical.BackendConfig) (logical.Backend, error) { - b := Backend() + b := Backend(conf) be, err := b.Backend.Setup(conf) if err != nil { return nil, err @@ -15,7 +15,7 @@ func Factory(conf *logical.BackendConfig) (logical.Backend, error) { return be, nil } -func Backend() *backend { +func Backend(conf *logical.BackendConfig) *backend { var b backend b.Backend = &framework.Backend{ Paths: []*framework.Path{ @@ -33,8 +33,10 @@ func Backend() *backend { Secrets: []*framework.Secret{}, } - b.policies = policyCache{ - cache: map[string]*lockingPolicy{}, + if conf.System.CachingDisabled() { + b.policies = newSimplePolicyCRUD() + } else { + b.policies = newCachingPolicyCRUD() } return &b @@ -42,5 +44,5 @@ func Backend() *backend { type backend struct { *framework.Backend - policies policyCache + policies policyCRUD } diff --git a/builtin/logical/transit/backend_test.go b/builtin/logical/transit/backend_test.go index dc0efaad4..8e385f31b 100644 --- a/builtin/logical/transit/backend_test.go +++ b/builtin/logical/transit/backend_test.go @@ -559,16 +559,31 @@ func TestPolicyFuzzing(t *testing.T) { return } - be := Backend() + var be *backend + sysView := logical.TestSystemView() + be = Backend(&logical.BackendConfig{ + System: sysView, + }) + testPolicyFuzzingCommon(t, be) + + sysView.CachingDisabledVal = true + be = Backend(&logical.BackendConfig{ + System: sysView, + }) + testPolicyFuzzingCommon(t, be) +} + +func testPolicyFuzzingCommon(t *testing.T, be *backend) { storage := &logical.LockingInmemStorage{} wg := sync.WaitGroup{} funcs := []string{"encrypt", "decrypt", "rotate", "change_min_version"} + //keys := []string{"test1", "test2", "test3", "test4", "test5"} keys := []string{"test1", "test2", "test3"} // This is the goroutine loop - doFuzzy := func() { + doFuzzy := func(id int) { // Check for panics, otherwise notify we're done defer func() { if err := recover(); err != nil { @@ -587,15 +602,24 @@ func TestPolicyFuzzing(t *testing.T) { } fd := &framework.FieldData{} + var retest bool + var chosenFunc, chosenKey string + + //t.Logf("Starting") for { // Stop after 10 seconds if time.Now().Sub(startTime) > 10*time.Second { + if retest { + t.Errorf("ended runtime on a retest, id is %d", id) + } return } // Pick a function and a key - chosenFunc := funcs[rand.Int()%len(funcs)] - chosenKey := keys[rand.Int()%len(keys)] + if !retest { + chosenFunc = funcs[rand.Int()%len(funcs)] + chosenKey = keys[rand.Int()%len(keys)] + } fd.Raw = map[string]interface{}{ "name": chosenKey, @@ -605,33 +629,36 @@ func TestPolicyFuzzing(t *testing.T) { // Try to write the key to make sure it exists _, err := be.pathPolicyWrite(req, fd) if err != nil { - t.Errorf("got an error: %v", err) + t.Fatalf("got an error: %v", err) return } switch chosenFunc { // Encrypt our plaintext and store the result case "encrypt": + //t.Logf("%s, %s", chosenFunc, chosenKey) fd.Raw["plaintext"] = base64.StdEncoding.EncodeToString([]byte(testPlaintext)) fd.Schema = be.pathEncrypt().Fields resp, err := be.pathEncryptWrite(req, fd) if err != nil { - t.Errorf("got an error: %v, resp is %#v", err, *resp) + t.Fatalf("got an error: %v, resp is %#v", err, *resp) return } latestEncryptedText[chosenKey] = resp.Data["ciphertext"].(string) // Rotate to a new key version case "rotate": + //t.Errorf("%s, %s, %d", chosenFunc, chosenKey, id) fd.Schema = be.pathRotate().Fields resp, err := be.pathRotateWrite(req, fd) if err != nil { - t.Errorf("got an error: %v, resp is %#v", err, *resp) + t.Fatalf("got an error: %v, resp is %#v, chosenKey is %s", err, *resp, chosenKey) return } // Decrypt the ciphertext and compare the result case "decrypt": + //t.Logf("%s, %s", chosenFunc, chosenKey) ct := latestEncryptedText[chosenKey] if ct == "" { continue @@ -645,13 +672,14 @@ func TestPolicyFuzzing(t *testing.T) { if resp.Data["error"].(string) == ErrTooOld { continue } - t.Errorf("got an error: %v, resp is %#v, ciphertext was %s", err, *resp, latestEncryptedText[chosenKey]) - return + t.Errorf("got an error: %v, resp is %#v, ciphertext was %s, chosenKey is %s, id is %d", err, *resp, ct, chosenKey, id) + retest = true + continue } ptb64 := resp.Data["plaintext"].(string) pt, err := base64.StdEncoding.DecodeString(ptb64) if err != nil { - t.Errorf("got an error decoding base64 plaintext: %v", err) + t.Fatalf("got an error decoding base64 plaintext: %v", err) return } if string(pt) != testPlaintext { @@ -660,9 +688,10 @@ func TestPolicyFuzzing(t *testing.T) { // Change the min version, which also tests the archive functionality case "change_min_version": + //t.Logf("%s, %s", chosenFunc, chosenKey) resp, err := be.pathPolicyRead(req, fd) if err != nil { - t.Errorf("got an error reading policy %s: %v", chosenKey, err) + t.Fatalf("got an error reading policy %s: %v", chosenKey, err) return } latestVersion := resp.Data["latest_version"].(int) @@ -673,17 +702,22 @@ func TestPolicyFuzzing(t *testing.T) { fd.Schema = be.pathConfig().Fields resp, err = be.pathConfigWrite(req, fd) if err != nil { - t.Errorf("got an error setting min decryption version: %v", err) + t.Fatalf("got an error setting min decryption version: %v", err) return } } + + if retest { + t.Errorf("success, setting retest false, id is %d", id) + } + retest = false } } // Spawn 1000 of these workers for 10 seconds for i := 0; i < 1000; i++ { wg.Add(1) - go doFuzzy() + go doFuzzy(i) } // Wait for them all to finish diff --git a/builtin/logical/transit/caching_crud.go b/builtin/logical/transit/caching_crud.go new file mode 100644 index 000000000..b0f3a9cca --- /dev/null +++ b/builtin/logical/transit/caching_crud.go @@ -0,0 +1,109 @@ +package transit + +import ( + "sync" + + "github.com/hashicorp/vault/logical" +) + +// policyCache implements CRUD operations with a simple locking cache of +// policies +type cachingPolicyCRUD struct { + sync.RWMutex + cache map[string]lockingPolicy +} + +func newCachingPolicyCRUD() *cachingPolicyCRUD { + return &cachingPolicyCRUD{ + cache: map[string]lockingPolicy{}, + } +} + +func (p *cachingPolicyCRUD) getPolicy(storage logical.Storage, name string) (lockingPolicy, error) { + // We don't defer this since we may need to give it up and get a write lock + p.RLock() + + // First, see if we're in the cache -- if so, return that + if p.cache[name] != nil { + defer p.RUnlock() + return p.cache[name], nil + } + + // If we didn't find anything, we'll need to write into the cache, plus possibly + // persist the entry, so lock the cache + p.RUnlock() + p.Lock() + defer p.Unlock() + + // Check one more time to ensure that another process did not write during + // our lock switcheroo. + if p.cache[name] != nil { + return p.cache[name], nil + } + + return p.refreshPolicy(storage, name) +} + +func (p *cachingPolicyCRUD) refreshPolicy(storage logical.Storage, name string) (lockingPolicy, error) { + // Check once more to ensure it hasn't been added to the cache since the lock was acquired + if p.cache[name] != nil { + return p.cache[name], nil + } + + // Note that we don't need to create the locking entry until the end, + // because the policy wasn't in the cache so we don't know about it, and we + // hold the cache lock so nothing else can be writing it in right now + policy, err := fetchPolicyFromStorage(storage, name) + if err != nil { + return nil, err + } + if policy == nil { + return nil, nil + } + + lp := &mutexLockingPolicy{ + policy: policy, + mutex: &sync.RWMutex{}, + } + p.cache[name] = lp + + return lp, nil +} + +// generatePolicy is used to create a new named policy with a randomly +// generated key. The caller should hold the write lock prior to calling this. +func (p *cachingPolicyCRUD) generatePolicy(storage logical.Storage, name string, derived bool) (lockingPolicy, error) { + policy, err := generatePolicyCommon(p, storage, name, derived) + if err != nil { + return nil, err + } + + // Now we need to check again in the cache to ensure the policy wasn't + // created since we ran generatePolicy and then got the lock. A policy + // being created holds a write lock until it's done (starting from this + // point), so it'll be in the cache at this point. + if lp := p.cache[name]; lp != nil { + return lp, nil + } + + lp := &mutexLockingPolicy{ + policy: policy, + mutex: &sync.RWMutex{}, + } + p.cache[name] = lp + + // Return the policy + return lp, nil +} + +// deletePolicy deletes a policy +func (p *cachingPolicyCRUD) deletePolicy(storage logical.Storage, lp lockingPolicy, name string) error { + err := deletePolicyCommon(p, lp, storage, name) + if err != nil { + return err + } + + delete(p.cache, name) + + return nil +} diff --git a/builtin/logical/transit/path_config.go b/builtin/logical/transit/path_config.go index 967b74341..cf4348bd5 100644 --- a/builtin/logical/transit/path_config.go +++ b/builtin/logical/transit/path_config.go @@ -42,26 +42,48 @@ func (b *backend) pathConfigWrite( name := d.Get("name").(string) // Check if the policy already exists - lp, err := b.policies.getPolicy(req, name) + lp, err := b.policies.getPolicy(req.Storage, name) if err != nil { return nil, err } if lp == nil { return logical.ErrorResponse( - fmt.Sprintf("no existing role named %s could be found", name)), + fmt.Sprintf("no existing key named %s could be found", name)), logical.ErrInvalidRequest } + currDeletionAllowed := lp.Policy().DeletionAllowed + currMinDecryptionVersion := lp.Policy().MinDecryptionVersion + + // Hold both locks since we want to ensure the policy doesn't change from underneath us + b.policies.Lock() + defer b.policies.Unlock() lp.Lock() defer lp.Unlock() + // Refresh in case it's changed since before we grabbed the lock + lp, err = b.policies.refreshPolicy(req.Storage, name) + if err != nil { + return nil, err + } + if lp == nil { + return nil, fmt.Errorf("error finding key %s after locking for changes", name) + } + // Verify if wasn't deleted before we grabbed the lock - if lp.policy == nil { - return nil, fmt.Errorf("no existing role named %s could be found", name) + if lp.Policy() == nil { + return nil, fmt.Errorf("no existing key named %s could be found", name) } resp := &logical.Response{} + // Check for anything to have been updated since we got the policy + if currDeletionAllowed != lp.Policy().DeletionAllowed || + currMinDecryptionVersion != lp.Policy().MinDecryptionVersion { + resp.AddWarning("key configuration has changed since this endpoint was called, not updating") + return resp, nil + } + persistNeeded := false minDecryptionVersionRaw, ok := d.GetOk("min_decryption_version") @@ -78,12 +100,12 @@ func (b *backend) pathConfigWrite( } if minDecryptionVersion > 0 && - minDecryptionVersion != lp.policy.MinDecryptionVersion { - if minDecryptionVersion > lp.policy.LatestVersion { + minDecryptionVersion != lp.Policy().MinDecryptionVersion { + if minDecryptionVersion > lp.Policy().LatestVersion { return logical.ErrorResponse( - fmt.Sprintf("cannot set min decryption version of %d, latest key version is %d", minDecryptionVersion, lp.policy.LatestVersion)), nil + fmt.Sprintf("cannot set min decryption version of %d, latest key version is %d", minDecryptionVersion, lp.Policy().LatestVersion)), nil } - lp.policy.MinDecryptionVersion = minDecryptionVersion + lp.Policy().MinDecryptionVersion = minDecryptionVersion persistNeeded = true } } @@ -91,8 +113,8 @@ func (b *backend) pathConfigWrite( allowDeletionInt, ok := d.GetOk("deletion_allowed") if ok { allowDeletion := allowDeletionInt.(bool) - if allowDeletion != lp.policy.DeletionAllowed { - lp.policy.DeletionAllowed = allowDeletion + if allowDeletion != lp.Policy().DeletionAllowed { + lp.Policy().DeletionAllowed = allowDeletion persistNeeded = true } } @@ -100,8 +122,8 @@ func (b *backend) pathConfigWrite( // Add this as a guard here before persisting since we now require the min // decryption version to start at 1; even if it's not explicitly set here, // force the upgrade - if lp.policy.MinDecryptionVersion == 0 { - lp.policy.MinDecryptionVersion = 1 + if lp.Policy().MinDecryptionVersion == 0 { + lp.Policy().MinDecryptionVersion = 1 persistNeeded = true } @@ -109,7 +131,7 @@ func (b *backend) pathConfigWrite( return nil, nil } - return resp, lp.policy.Persist(req.Storage) + return resp, lp.Policy().Persist(req.Storage) } const pathConfigHelpSyn = `Configure a named encryption key` diff --git a/builtin/logical/transit/path_datakey.go b/builtin/logical/transit/path_datakey.go index f2bec71d0..9e6483673 100644 --- a/builtin/logical/transit/path_datakey.go +++ b/builtin/logical/transit/path_datakey.go @@ -73,7 +73,7 @@ func (b *backend) pathDatakeyWrite( } // Get the policy - lp, err := b.policies.getPolicy(req, name) + lp, err := b.policies.getPolicy(req.Storage, name) if err != nil { return nil, err } @@ -87,7 +87,7 @@ func (b *backend) pathDatakeyWrite( defer lp.RUnlock() // Verify if wasn't deleted before we grabbed the lock - if lp.policy == nil { + if lp.Policy() == nil { return nil, fmt.Errorf("no existing policy named %s could be found", name) } @@ -107,7 +107,7 @@ func (b *backend) pathDatakeyWrite( return nil, err } - ciphertext, err := lp.policy.Encrypt(context, base64.StdEncoding.EncodeToString(newKey)) + ciphertext, err := lp.Policy().Encrypt(context, base64.StdEncoding.EncodeToString(newKey)) if err != nil { switch err.(type) { case certutil.UserError: diff --git a/builtin/logical/transit/path_decrypt.go b/builtin/logical/transit/path_decrypt.go index 66191b629..a58709858 100644 --- a/builtin/logical/transit/path_decrypt.go +++ b/builtin/logical/transit/path_decrypt.go @@ -58,7 +58,7 @@ func (b *backend) pathDecryptWrite( } // Get the policy - lp, err := b.policies.getPolicy(req, name) + lp, err := b.policies.getPolicy(req.Storage, name) if err != nil { return nil, err } @@ -72,11 +72,11 @@ func (b *backend) pathDecryptWrite( defer lp.RUnlock() // Verify if wasn't deleted before we grabbed the lock - if lp.policy == nil { + if lp.Policy() == nil { return nil, fmt.Errorf("no existing policy named %s could be found", name) } - plaintext, err := lp.policy.Decrypt(context, ciphertext) + plaintext, err := lp.Policy().Decrypt(context, ciphertext) if err != nil { switch err.(type) { case certutil.UserError: diff --git a/builtin/logical/transit/path_encrypt.go b/builtin/logical/transit/path_encrypt.go index fc2e2048f..8c4cc91e7 100644 --- a/builtin/logical/transit/path_encrypt.go +++ b/builtin/logical/transit/path_encrypt.go @@ -44,7 +44,7 @@ func (b *backend) pathEncrypt() *framework.Path { func (b *backend) pathEncryptExistenceCheck( req *logical.Request, d *framework.FieldData) (bool, error) { name := d.Get("name").(string) - lp, err := b.policies.getPolicy(req, name) + lp, err := b.policies.getPolicy(req.Storage, name) if err != nil { return false, err } @@ -72,37 +72,43 @@ func (b *backend) pathEncryptWrite( } // Get the policy - lp, err := b.policies.getPolicy(req, name) + lp, err := b.policies.getPolicy(req.Storage, name) if err != nil { return nil, err } - // Error if invalid policy + // Error or upsert if invalid policy if lp == nil { if req.Operation != logical.CreateOperation { return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest } + // Get a write lock + b.policies.Lock() + isDerived := len(context) != 0 + // This also checks to make sure one hasn't been created since we grabbed the write lock lp, err = b.policies.generatePolicy(req.Storage, name, isDerived) // If the error is that the policy has been created in the interim we // will get the policy back, so only consider it an error if err is not // nil and we do not get a policy back if err != nil && lp != nil { + b.policies.Unlock() return nil, err } + b.policies.Unlock() } lp.RLock() defer lp.RUnlock() // Verify if wasn't deleted before we grabbed the lock - if lp.policy == nil { + if lp.Policy() == nil { return nil, fmt.Errorf("no existing policy named %s could be found", name) } - ciphertext, err := lp.policy.Encrypt(context, value) + ciphertext, err := lp.Policy().Encrypt(context, value) if err != nil { switch err.(type) { case certutil.UserError: diff --git a/builtin/logical/transit/path_keys.go b/builtin/logical/transit/path_keys.go index 32587b6c6..02d7f882b 100644 --- a/builtin/logical/transit/path_keys.go +++ b/builtin/logical/transit/path_keys.go @@ -36,20 +36,14 @@ func (b *backend) pathKeys() *framework.Path { func (b *backend) pathPolicyWrite( req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + b.policies.Lock() + defer b.policies.Unlock() + name := d.Get("name").(string) derived := d.Get("derived").(bool) - // Check if the policy already exists - existing, err := b.policies.getPolicy(req, name) - if err != nil { - return nil, err - } - if existing != nil { - return nil, nil - } - - // Generate the policy - _, err = b.policies.generatePolicy(req.Storage, name, derived) + // Generate the policy; this will also check if it exists for safety + _, err := b.policies.generatePolicy(req.Storage, name, derived) return nil, err } @@ -57,7 +51,7 @@ func (b *backend) pathPolicyRead( req *logical.Request, d *framework.FieldData) (*logical.Response, error) { name := d.Get("name").(string) - lp, err := b.policies.getPolicy(req, name) + lp, err := b.policies.getPolicy(req.Storage, name) if err != nil { return nil, err } @@ -69,27 +63,27 @@ func (b *backend) pathPolicyRead( defer lp.RUnlock() // Verify if wasn't deleted before we grabbed the lock - if lp.policy == nil { + if lp.Policy() == nil { return nil, fmt.Errorf("no existing policy named %s could be found", name) } // Return the response resp := &logical.Response{ Data: map[string]interface{}{ - "name": lp.policy.Name, - "cipher_mode": lp.policy.CipherMode, - "derived": lp.policy.Derived, - "deletion_allowed": lp.policy.DeletionAllowed, - "min_decryption_version": lp.policy.MinDecryptionVersion, - "latest_version": lp.policy.LatestVersion, + "name": lp.Policy().Name, + "cipher_mode": lp.Policy().CipherMode, + "derived": lp.Policy().Derived, + "deletion_allowed": lp.Policy().DeletionAllowed, + "min_decryption_version": lp.Policy().MinDecryptionVersion, + "latest_version": lp.Policy().LatestVersion, }, } - if lp.policy.Derived { - resp.Data["kdf_mode"] = lp.policy.KDFMode + if lp.Policy().Derived { + resp.Data["kdf_mode"] = lp.Policy().KDFMode } retKeys := map[string]int64{} - for k, v := range lp.policy.Keys { + for k, v := range lp.Policy().Keys { retKeys[strconv.Itoa(k)] = v.CreationTime } resp.Data["keys"] = retKeys @@ -101,7 +95,7 @@ func (b *backend) pathPolicyDelete( req *logical.Request, d *framework.FieldData) (*logical.Response, error) { name := d.Get("name").(string) - lp, err := b.policies.getPolicy(req, name) + lp, err := b.policies.getPolicy(req.Storage, name) if err != nil { return logical.ErrorResponse(fmt.Sprintf("error looking up policy %s, error is %s", name, err)), err } @@ -109,7 +103,12 @@ func (b *backend) pathPolicyDelete( return logical.ErrorResponse(fmt.Sprintf("no such key %s", name)), logical.ErrInvalidRequest } - err = b.policies.deletePolicy(req.Storage, name) + b.policies.Lock() + defer b.policies.Unlock() + lp.Lock() + defer lp.Unlock() + + err = b.policies.deletePolicy(req.Storage, lp, name) if err != nil { return logical.ErrorResponse(fmt.Sprintf("error deleting policy %s: %s", name, err)), err } diff --git a/builtin/logical/transit/path_rewrap.go b/builtin/logical/transit/path_rewrap.go index d8d01bddc..2233689b3 100644 --- a/builtin/logical/transit/path_rewrap.go +++ b/builtin/logical/transit/path_rewrap.go @@ -59,7 +59,7 @@ func (b *backend) pathRewrapWrite( } // Get the policy - lp, err := b.policies.getPolicy(req, name) + lp, err := b.policies.getPolicy(req.Storage, name) if err != nil { return nil, err } @@ -73,11 +73,11 @@ func (b *backend) pathRewrapWrite( defer lp.RUnlock() // Verify if wasn't deleted before we grabbed the lock - if lp.policy == nil { + if lp.Policy() == nil { return nil, fmt.Errorf("no existing policy named %s could be found", name) } - plaintext, err := lp.policy.Decrypt(context, value) + plaintext, err := lp.Policy().Decrypt(context, value) if err != nil { switch err.(type) { case certutil.UserError: @@ -93,7 +93,7 @@ func (b *backend) pathRewrapWrite( return nil, fmt.Errorf("empty plaintext returned during rewrap") } - ciphertext, err := lp.policy.Encrypt(context, plaintext) + ciphertext, err := lp.Policy().Encrypt(context, plaintext) if err != nil { switch err.(type) { case certutil.UserError: diff --git a/builtin/logical/transit/path_rotate.go b/builtin/logical/transit/path_rotate.go index 90bbc2e18..dbde58e28 100644 --- a/builtin/logical/transit/path_rotate.go +++ b/builtin/logical/transit/path_rotate.go @@ -31,26 +31,48 @@ func (b *backend) pathRotateWrite( name := d.Get("name").(string) // Get the policy - lp, err := b.policies.getPolicy(req, name) + lp, err := b.policies.getPolicy(req.Storage, name) if err != nil { return nil, err } // Error if invalid policy if lp == nil { - return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest + return logical.ErrorResponse("key not found"), logical.ErrInvalidRequest } + keyVersion := lp.Policy().LatestVersion + + // lock the policies object so we can refresh + b.policies.Lock() + defer b.policies.Unlock() lp.Lock() defer lp.Unlock() - // Verify if wasn't deleted before we grabbed the lock - if lp.policy == nil { - return nil, fmt.Errorf("no existing policy named %s could be found", name) + // Refresh in case it's changed since before we grabbed the lock + lp, err = b.policies.refreshPolicy(req.Storage, name) + if err != nil { + return nil, err + } + if lp == nil { + return nil, fmt.Errorf("error finding key %s after locking for changes", name) } - // Generate the policy - err = lp.policy.rotate(req.Storage) + // Verify if wasn't deleted before we grabbed the lock + if lp.Policy() == nil { + return nil, fmt.Errorf("no existing key named %s could be found", name) + } + + // Make sure that the policy hasn't been rotated simultaneously + if keyVersion != lp.Policy().LatestVersion { + resp := &logical.Response{} + resp.AddWarning("key has been rotated since this endpoint was called; did not perform rotation") + return resp, nil + } + + //fmt.Printf("Rotating key %s, orig seen version is %d, currVersion is %d\n", name, keyVersion, lp.Policy().LatestVersion) + // Rotate the policy + err = lp.Policy().rotate(req.Storage) return nil, err } diff --git a/builtin/logical/transit/policy.go b/builtin/logical/transit/policy.go index 17d7cc00b..ad47af754 100644 --- a/builtin/logical/transit/policy.go +++ b/builtin/logical/transit/policy.go @@ -9,7 +9,6 @@ import ( "fmt" "strconv" "strings" - "sync" "time" "github.com/hashicorp/vault/helper/certutil" @@ -24,197 +23,6 @@ const ( ErrTooOld = "ciphertext version is disallowed by policy (too old)" ) -// policyCache implements a simple locking cache of policies -type policyCache struct { - sync.RWMutex - cache map[string]*lockingPolicy -} - -// getPolicy loads a policy into the cache or returns one already in the cache -func (p *policyCache) getPolicy(req *logical.Request, name string) (*lockingPolicy, error) { - // We don't defer this since we may need to give it up and get a write lock - p.RLock() - - // First, see if we're in the cache -- if so, return that - if p.cache[name] != nil { - defer p.RUnlock() - return p.cache[name], nil - } - - // If we didn't find anything, we'll need to write into the cache, plus possibly - // persist the entry, so lock the cache - p.RUnlock() - p.Lock() - defer p.Unlock() - - // Check one more time to ensure that another process did not write during - // our lock switcheroo. - if p.cache[name] != nil { - return p.cache[name], nil - } - - // Note that we don't need to create the locking entry until the end, - // because the policy wasn't in the cache so we don't know about it, and we - // hold the cache lock so nothing else can be writing it in right now - - // Check if the policy already exists - raw, err := req.Storage.Get("policy/" + name) - if err != nil { - return nil, err - } - if raw == nil { - return nil, nil - } - - // Decode the policy - policy := &Policy{ - Keys: KeyEntryMap{}, - } - err = json.Unmarshal(raw.Value, policy) - if err != nil { - return nil, err - } - - persistNeeded := false - // Ensure we've moved from Key -> Keys - if policy.Key != nil && len(policy.Key) > 0 { - policy.migrateKeyToKeysMap() - persistNeeded = true - } - - // With archiving, past assumptions about the length of the keys map are no longer valid - if policy.LatestVersion == 0 && len(policy.Keys) != 0 { - policy.LatestVersion = len(policy.Keys) - persistNeeded = true - } - - // We disallow setting the version to 0, since they start at 1 since moving - // to rotate-able keys, so update if it's set to 0 - if policy.MinDecryptionVersion == 0 { - policy.MinDecryptionVersion = 1 - persistNeeded = true - } - - // On first load after an upgrade, copy keys to the archive - if policy.ArchiveVersion == 0 { - persistNeeded = true - } - - if persistNeeded { - err = policy.Persist(req.Storage) - if err != nil { - return nil, err - } - } - - lp := &lockingPolicy{ - policy: policy, - } - p.cache[name] = lp - - return lp, nil -} - -// generatePolicy is used to create a new named policy with a randomly -// generated key -func (p *policyCache) generatePolicy(storage logical.Storage, name string, derived bool) (*lockingPolicy, error) { - // Ensure one with this name doesn't already exist - lp, err := p.getPolicy(&logical.Request{ - Storage: storage, - }, name) - if err != nil { - return nil, fmt.Errorf("error checking if policy already exists: %s", err) - } - if lp != nil { - return nil, fmt.Errorf("policy %s already exists", name) - } - - p.Lock() - defer p.Unlock() - - // Now we need to check again in the cache to ensure the policy wasn't - // created since we checked getPolicy. A policy being created holds a write - // lock until it's done, so it'll be in the cache at this point. - if lp := p.cache[name]; lp != nil { - return nil, fmt.Errorf("policy %s already exists", name) - } - - // Create the policy object - policy := &Policy{ - Name: name, - CipherMode: "aes-gcm", - Derived: derived, - } - if derived { - policy.KDFMode = kdfMode - } - - err = policy.rotate(storage) - if err != nil { - return nil, err - } - - lp = &lockingPolicy{ - policy: policy, - } - p.cache[name] = lp - - // Return the policy - return lp, nil -} - -// deletePolicy deletes a policy -func (p *policyCache) deletePolicy(storage logical.Storage, name string) error { - // Ensure one with this name exists - lp, err := p.getPolicy(&logical.Request{ - Storage: storage, - }, name) - if err != nil { - return fmt.Errorf("error checking if policy already exists: %s", err) - } - if lp == nil { - return fmt.Errorf("policy %s does not exist", name) - } - - p.Lock() - defer p.Unlock() - - lp = p.cache[name] - if lp == nil { - return fmt.Errorf("policy %s not found", name) - } - - // We need to ensure all other access has stopped - lp.Lock() - defer lp.Unlock() - - // Verify this hasn't changed - if !lp.policy.DeletionAllowed { - return fmt.Errorf("deletion not allowed for policy %s", name) - } - - err = storage.Delete("policy/" + name) - if err != nil { - return fmt.Errorf("error deleting policy %s: %s", name, err) - } - - err = storage.Delete("archive/" + name) - if err != nil { - return fmt.Errorf("error deleting archive %s: %s", name, err) - } - - lp.policy = nil - delete(p.cache, name) - - return nil -} - -// lockingPolicy holds a Policy guarded by a lock -type lockingPolicy struct { - sync.RWMutex - policy *Policy -} - // KeyEntry stores the key and metadata type KeyEntry struct { Key []byte `json:"key"` @@ -521,17 +329,17 @@ func (p *Policy) Encrypt(context []byte, value string) (string, error) { func (p *Policy) Decrypt(context []byte, value string) (string, error) { // Verify the prefix if !strings.HasPrefix(value, "vault:v") { - return "", certutil.UserError{Err: "invalid ciphertext"} + return "", certutil.UserError{Err: "invalid ciphertext: no prefix"} } splitVerCiphertext := strings.SplitN(strings.TrimPrefix(value, "vault:v"), ":", 2) if len(splitVerCiphertext) != 2 { - return "", certutil.UserError{Err: "invalid ciphertext"} + return "", certutil.UserError{Err: "invalid ciphertext: wrong number of fields"} } ver, err := strconv.Atoi(splitVerCiphertext[0]) if err != nil { - return "", certutil.UserError{Err: "invalid ciphertext"} + return "", certutil.UserError{Err: "invalid ciphertext: version number could not be decoded"} } if ver == 0 { @@ -540,6 +348,10 @@ func (p *Policy) Decrypt(context []byte, value string) (string, error) { ver = 1 } + if ver > p.LatestVersion { + return "", certutil.UserError{Err: "invalid ciphertext: version is too new"} + } + if p.MinDecryptionVersion > 0 && ver < p.MinDecryptionVersion { return "", certutil.UserError{Err: ErrTooOld} } @@ -560,7 +372,7 @@ func (p *Policy) Decrypt(context []byte, value string) (string, error) { // Decode the base64 decoded, err := base64.StdEncoding.DecodeString(splitVerCiphertext[1]) if err != nil { - return "", certutil.UserError{Err: "invalid ciphertext"} + return "", certutil.UserError{Err: "invalid ciphertext: could not decode base64"} } // Setup the cipher @@ -582,7 +394,7 @@ func (p *Policy) Decrypt(context []byte, value string) (string, error) { // Verify and Decrypt plain, err := gcm.Open(nil, nonce, ciphertext, nil) if err != nil { - return "", certutil.UserError{Err: "invalid ciphertext"} + return "", certutil.UserError{Err: "invalid ciphertext: unable to decrypt"} } return base64.StdEncoding.EncodeToString(plain), nil @@ -617,6 +429,8 @@ func (p *Policy) rotate(storage logical.Storage) error { p.MinDecryptionVersion = 1 } + //fmt.Printf("policy %s rotated to %d\n", p.Name, p.LatestVersion) + return p.Persist(storage) } diff --git a/builtin/logical/transit/policy_crud.go b/builtin/logical/transit/policy_crud.go new file mode 100644 index 000000000..90c68c90f --- /dev/null +++ b/builtin/logical/transit/policy_crud.go @@ -0,0 +1,185 @@ +package transit + +import ( + "encoding/json" + "fmt" + "sync" + + "github.com/hashicorp/vault/logical" +) + +type lockingPolicy interface { + Lock() + RLock() + Unlock() + RUnlock() + Policy() *Policy + SetPolicy(*Policy) +} + +type policyCRUD interface { + // getPolicy returns a lockingPolicy. It performs its own locking according + // to implementation. + getPolicy(storage logical.Storage, name string) (lockingPolicy, error) + + // refreshPolicy returns a lockingPolicy. It does not perform its own + // locking; a write lock must be held before calling. + refreshPolicy(storage logical.Storage, name string) (lockingPolicy, error) + + // generatePolicy generates and returns a lockingPolicy. A write lock must + // be held before calling. + generatePolicy(storage logical.Storage, name string, derived bool) (lockingPolicy, error) + + // deletePolicy deletes a lockingPolicy. A write lock must be held on both + // the CRUD implementation and the lockingPolicy before calling. + deletePolicy(storage logical.Storage, lp lockingPolicy, name string) error + + // These are generally satisfied by embedded mutexes in the implementing struct + Lock() + RLock() + Unlock() + RUnlock() +} + +type mutexLockingPolicy struct { + mutex *sync.RWMutex + policy *Policy +} + +func (m *mutexLockingPolicy) Lock() { + m.mutex.Lock() +} + +func (m *mutexLockingPolicy) RLock() { + m.mutex.RLock() +} + +func (m *mutexLockingPolicy) Unlock() { + m.mutex.Unlock() +} + +func (m *mutexLockingPolicy) RUnlock() { + m.mutex.RUnlock() +} + +func (m *mutexLockingPolicy) Policy() *Policy { + return m.policy +} + +func (m *mutexLockingPolicy) SetPolicy(p *Policy) { + m.policy = p +} + +// The caller should hold the write lock when calling this +func fetchPolicyFromStorage(storage logical.Storage, name string) (*Policy, error) { + // Check if the policy already exists + raw, err := storage.Get("policy/" + name) + if err != nil { + return nil, err + } + if raw == nil { + return nil, nil + } + + // Decode the policy + policy := &Policy{ + Keys: KeyEntryMap{}, + } + err = json.Unmarshal(raw.Value, policy) + if err != nil { + return nil, err + } + + persistNeeded := false + // Ensure we've moved from Key -> Keys + if policy.Key != nil && len(policy.Key) > 0 { + policy.migrateKeyToKeysMap() + persistNeeded = true + } + + // With archiving, past assumptions about the length of the keys map are no longer valid + if policy.LatestVersion == 0 && len(policy.Keys) != 0 { + policy.LatestVersion = len(policy.Keys) + persistNeeded = true + } + + // We disallow setting the version to 0, since they start at 1 since moving + // to rotate-able keys, so update if it's set to 0 + if policy.MinDecryptionVersion == 0 { + policy.MinDecryptionVersion = 1 + persistNeeded = true + } + + // On first load after an upgrade, copy keys to the archive + if policy.ArchiveVersion == 0 { + persistNeeded = true + } + + if persistNeeded { + err = policy.Persist(storage) + if err != nil { + return nil, err + } + } + + return policy, nil +} + +// generatePolicyCommon is used to create a new named policy with a randomly +// generated key. The caller should have a write lock prior to calling this. +func generatePolicyCommon(p policyCRUD, storage logical.Storage, name string, derived bool) (*Policy, error) { + // Make sure this doesn't exist in case it was created before we got the write lock + policy, err := fetchPolicyFromStorage(storage, name) + if err != nil { + return nil, err + } + if policy != nil { + return policy, nil + } + + //log.Printf("generating a new policy with name %s", name) + + // Create the policy object + policy = &Policy{ + Name: name, + CipherMode: "aes-gcm", + Derived: derived, + } + if derived { + policy.KDFMode = kdfMode + } + + err = policy.rotate(storage) + if err != nil { + return nil, err + } + + return policy, err +} + +// deletePolicy deletes a policy. The caller should hold the write lock for both the policy and lockingPolicy prior to calling this. +func deletePolicyCommon(p policyCRUD, lp lockingPolicy, storage logical.Storage, name string) error { + if lp.Policy() == nil { + // This got deleted before we grabbed the lock + return fmt.Errorf("policy already deleted") + } + + // Verify this hasn't changed + if !lp.Policy().DeletionAllowed { + return fmt.Errorf("deletion not allowed for policy %s", name) + } + + err := storage.Delete("policy/" + name) + if err != nil { + return fmt.Errorf("error deleting policy %s: %s", name, err) + } + + err = storage.Delete("archive/" + name) + if err != nil { + return fmt.Errorf("error deleting archive %s: %s", name, err) + } + + lp.SetPolicy(nil) + + return nil +} diff --git a/builtin/logical/transit/policy_test.go b/builtin/logical/transit/policy_test.go index 04b3c3bbb..af49c27d2 100644 --- a/builtin/logical/transit/policy_test.go +++ b/builtin/logical/transit/policy_test.go @@ -16,19 +16,24 @@ func resetKeysArchive() { } func Test_KeyUpgrade(t *testing.T) { + testKeyUpgradeCommon(t, newSimplePolicyCRUD()) + testKeyUpgradeCommon(t, newCachingPolicyCRUD()) +} + +func testKeyUpgradeCommon(t *testing.T, policies policyCRUD) { storage := &logical.InmemStorage{} - policies := &policyCache{ - cache: map[string]*lockingPolicy{}, - } lp, err := policies.generatePolicy(storage, "test", false) if err != nil { t.Fatal(err) } if lp == nil { - t.Fatal("nil policy") + t.Fatal("nil lockingPolicy") } - policy := lp.policy + policy := lp.Policy() + if policy == nil { + t.Fatal("nil policy in lockingPolicy") + } testBytes := make([]byte, len(policy.Keys[1].Key)) copy(testBytes, policy.Keys[1].Key) @@ -48,6 +53,11 @@ func Test_KeyUpgrade(t *testing.T) { } func Test_ArchivingUpgrade(t *testing.T) { + testArchivingUpgradeCommon(t, newSimplePolicyCRUD()) + testArchivingUpgradeCommon(t, newCachingPolicyCRUD()) +} + +func testArchivingUpgradeCommon(t *testing.T, policies policyCRUD) { resetKeysArchive() // First, we generate a policy and rotate it a number of times. Each time @@ -56,19 +66,19 @@ func Test_ArchivingUpgrade(t *testing.T) { // zero and latest, respectively storage := &logical.InmemStorage{} - policies := &policyCache{ - cache: map[string]*lockingPolicy{}, - } lp, err := policies.generatePolicy(storage, "test", false) if err != nil { t.Fatal(err) } if lp == nil { - t.Fatal("policy is nil") + t.Fatal("nil lockingPolicy") } - policy := lp.policy + policy := lp.Policy() + if policy == nil { + t.Fatal("nil policy in lockingPolicy") + } // Store the initial key in the archive keysArchive = append(keysArchive, policy.Keys[1]) @@ -106,26 +116,35 @@ func Test_ArchivingUpgrade(t *testing.T) { t.Fatal(err) } - // Expire from the cache since we modified it under-the-hood - delete(policies.cache, "test") + // If it's a caching CRUD, expire from the cache since we modified it + // under-the-hood + if cachingCRUD, ok := policies.(*cachingPolicyCRUD); ok { + delete(cachingCRUD.cache, "test") + } // Now get the policy again; the upgrade should happen automatically - lp, err = policies.getPolicy(&logical.Request{ - Storage: storage, - }, "test") + lp, err = policies.getPolicy(storage, "test") if err != nil { t.Fatal(err) } if lp == nil { - t.Fatal("policy is nil") + t.Fatal("nil lockingPolicy") } - policy = lp.policy + policy = lp.Policy() + if policy == nil { + t.Fatal("nil policy in lockingPolicy") + } checkKeys(t, policy, storage, "upgrade", 10, 10, 10) } func Test_Archiving(t *testing.T) { + testArchivingCommon(t, newSimplePolicyCRUD()) + testArchivingCommon(t, newCachingPolicyCRUD()) +} + +func testArchivingCommon(t *testing.T, policies policyCRUD) { resetKeysArchive() // First, we generate a policy and rotate it a number of times. Each time @@ -135,19 +154,18 @@ func Test_Archiving(t *testing.T) { storage := &logical.InmemStorage{} - policies := &policyCache{ - cache: map[string]*lockingPolicy{}, - } - lp, err := policies.generatePolicy(storage, "test", false) if err != nil { t.Fatal(err) } if lp == nil { - t.Fatal("policy is nil") + t.Fatal("nil lockingPolicy") } - policy := lp.policy + policy := lp.Policy() + if policy == nil { + t.Fatal("nil policy in lockingPolicy") + } // Store the initial key in the archive keysArchive = append(keysArchive, policy.Keys[1]) diff --git a/builtin/logical/transit/simple_crud.go b/builtin/logical/transit/simple_crud.go new file mode 100644 index 000000000..874f760d7 --- /dev/null +++ b/builtin/logical/transit/simple_crud.go @@ -0,0 +1,88 @@ +package transit + +import ( + "sync" + + "github.com/hashicorp/vault/logical" +) + +// Directly implements CRUD operations without caching, mapped to the backend, +// but implements locking to ensure that we can't overwrite data on the backend +// from multiple operators +type simplePolicyCRUD struct { + sync.RWMutex + locks map[string]*sync.RWMutex + locksMapMutex sync.RWMutex +} + +func newSimplePolicyCRUD() *simplePolicyCRUD { + return &simplePolicyCRUD{ + locks: map[string]*sync.RWMutex{}, + } +} + +func (p *simplePolicyCRUD) ensureLockExists(name string) { + p.locksMapMutex.RLock() + + if p.locks[name] == nil { + p.locksMapMutex.RUnlock() + p.locksMapMutex.Lock() + // Make sure nothing has appeared since we switched the lock type + if p.locks[name] == nil { + p.locks[name] = &sync.RWMutex{} + } + p.locksMapMutex.Unlock() + return + } + + p.locksMapMutex.RUnlock() +} + +func (p *simplePolicyCRUD) getPolicy(storage logical.Storage, name string) (lockingPolicy, error) { + // Use a write lock since fetching the policy can cause a need for upgrade persistence + p.Lock() + defer p.Unlock() + + return p.refreshPolicy(storage, name) +} + +func (p *simplePolicyCRUD) refreshPolicy(storage logical.Storage, name string) (lockingPolicy, error) { + p.ensureLockExists(name) + + policy, err := fetchPolicyFromStorage(storage, name) + if err != nil { + return nil, err + } + if policy == nil { + return nil, nil + } + + lp := &mutexLockingPolicy{ + policy: policy, + mutex: p.locks[name], + } + + return lp, nil +} + +// The caller must hold the write lock when calling this +func (p *simplePolicyCRUD) generatePolicy(storage logical.Storage, name string, derived bool) (lockingPolicy, error) { + p.ensureLockExists(name) + + policy, err := generatePolicyCommon(p, storage, name, derived) + if err != nil { + return nil, err + } + + lp := &mutexLockingPolicy{ + policy: policy, + mutex: p.locks[name], + } + + return lp, nil +} + +// The caller must hold the write lock when calling this +func (p *simplePolicyCRUD) deletePolicy(storage logical.Storage, lp lockingPolicy, name string) error { + return deletePolicyCommon(p, lp, storage, name) +} diff --git a/logical/system_view.go b/logical/system_view.go index 4e26300cb..d20bf0c37 100644 --- a/logical/system_view.go +++ b/logical/system_view.go @@ -29,7 +29,7 @@ type SystemView interface { // Returns true if caching is disabled. If true, no caches should be used, // despite known slowdowns. - CacheDisabled() bool + CachingDisabled() bool } type StaticSystemView struct { @@ -37,7 +37,7 @@ type StaticSystemView struct { MaxLeaseTTLVal time.Duration SudoPrivilegeVal bool TaintedVal bool - CacheDisabledVal bool + CachingDisabledVal bool } func (d StaticSystemView) DefaultLeaseTTL() time.Duration { @@ -56,6 +56,6 @@ func (d StaticSystemView) Tainted() bool { return d.TaintedVal } -func (d StaticSystemView) CacheDisabled() bool { - return d.CacheDisabledVal +func (d StaticSystemView) CachingDisabled() bool { + return d.CachingDisabledVal } diff --git a/vault/core.go b/vault/core.go index 12ceefb92..44dcd298e 100644 --- a/vault/core.go +++ b/vault/core.go @@ -219,8 +219,8 @@ type Core struct { logger *log.Logger - // cacheDisabled indicates whether caches are disabled - cacheDisabled bool + // cachingDisabled indicates whether caches are disabled + cachingDisabled bool } // CoreConfig is used to parameterize a core @@ -318,7 +318,7 @@ func NewCore(conf *CoreConfig) (*Core, error) { logger: conf.Logger, defaultLeaseTTL: conf.DefaultLeaseTTL, maxLeaseTTL: conf.MaxLeaseTTL, - cacheDisabled: conf.DisableCache, + cachingDisabled: conf.DisableCache, } // Setup the backends diff --git a/vault/dynamic_system_view.go b/vault/dynamic_system_view.go index b4ef6a77a..8dc806de6 100644 --- a/vault/dynamic_system_view.go +++ b/vault/dynamic_system_view.go @@ -70,7 +70,7 @@ func (d dynamicSystemView) Tainted() bool { return d.mountEntry.Tainted } -// CacheDisabled indicates whether to use caching behavior -func (d dynamicSystemView) CacheDisabled() bool { - return d.core.cacheDisabled +// CachingDisabled indicates whether to use caching behavior +func (d dynamicSystemView) CachingDisabled() bool { + return d.core.cachingDisabled } diff --git a/vault/policy_store.go b/vault/policy_store.go index d4729f740..862002a42 100644 --- a/vault/policy_store.go +++ b/vault/policy_store.go @@ -40,7 +40,7 @@ func NewPolicyStore(view *BarrierView, system logical.SystemView) *PolicyStore { view: view, system: system, } - if !system.CacheDisabled() { + if !system.CachingDisabled() { cache, _ := lru.New2Q(policyCacheSize) p.lru = cache } @@ -100,7 +100,7 @@ func (ps *PolicyStore) SetPolicy(p *Policy) error { return fmt.Errorf("failed to persist policy: %v", err) } - if !ps.system.CacheDisabled() { + if !ps.system.CachingDisabled() { // Update the LRU cache ps.lru.Add(p.Name, p) } @@ -110,7 +110,7 @@ func (ps *PolicyStore) SetPolicy(p *Policy) error { // GetPolicy is used to fetch the named policy func (ps *PolicyStore) GetPolicy(name string) (*Policy, error) { defer metrics.MeasureSince([]string{"policy", "get_policy"}, time.Now()) - if !ps.system.CacheDisabled() { + if !ps.system.CachingDisabled() { // Check for cached policy if raw, ok := ps.lru.Get(name); ok { return raw.(*Policy), nil @@ -120,7 +120,7 @@ func (ps *PolicyStore) GetPolicy(name string) (*Policy, error) { // Special case the root policy if name == "root" { p := &Policy{Name: "root"} - if !ps.system.CacheDisabled() { + if !ps.system.CachingDisabled() { ps.lru.Add(p.Name, p) } return p, nil @@ -163,7 +163,7 @@ func (ps *PolicyStore) GetPolicy(name string) (*Policy, error) { policy = p } - if !ps.system.CacheDisabled() { + if !ps.system.CachingDisabled() { // Update the LRU cache ps.lru.Add(name, policy) } @@ -192,7 +192,7 @@ func (ps *PolicyStore) DeletePolicy(name string) error { return fmt.Errorf("failed to delete policy: %v", err) } - if !ps.system.CacheDisabled() { + if !ps.system.CachingDisabled() { // Clear the cache ps.lru.Remove(name) } diff --git a/vault/policy_store_test.go b/vault/policy_store_test.go index 456bc8375..05cbd1c79 100644 --- a/vault/policy_store_test.go +++ b/vault/policy_store_test.go @@ -16,7 +16,7 @@ func mockPolicyStore(t *testing.T) *PolicyStore { func mockPolicyStoreNoCache(t *testing.T) *PolicyStore { sysView := logical.TestSystemView() - sysView.CacheDisabledVal = true + sysView.CachingDisabledVal = true _, barrier, _ := mockBarrier(t) view := NewBarrierView(barrier, "foo/") p := NewPolicyStore(view, sysView)