Make a non-caching but still locking variant of transit for when caches are disabled

This commit is contained in:
Jeff Mitchell 2016-04-21 20:32:06 +00:00
parent 8572190b64
commit fe1f56de40
19 changed files with 613 additions and 314 deletions

View File

@ -6,7 +6,7 @@ import (
)
func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
b := Backend()
b := Backend(conf)
be, err := b.Backend.Setup(conf)
if err != nil {
return nil, err
@ -15,7 +15,7 @@ func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
return be, nil
}
func Backend() *backend {
func Backend(conf *logical.BackendConfig) *backend {
var b backend
b.Backend = &framework.Backend{
Paths: []*framework.Path{
@ -33,8 +33,10 @@ func Backend() *backend {
Secrets: []*framework.Secret{},
}
b.policies = policyCache{
cache: map[string]*lockingPolicy{},
if conf.System.CachingDisabled() {
b.policies = newSimplePolicyCRUD()
} else {
b.policies = newCachingPolicyCRUD()
}
return &b
@ -42,5 +44,5 @@ func Backend() *backend {
type backend struct {
*framework.Backend
policies policyCache
policies policyCRUD
}

View File

@ -559,16 +559,31 @@ func TestPolicyFuzzing(t *testing.T) {
return
}
be := Backend()
var be *backend
sysView := logical.TestSystemView()
be = Backend(&logical.BackendConfig{
System: sysView,
})
testPolicyFuzzingCommon(t, be)
sysView.CachingDisabledVal = true
be = Backend(&logical.BackendConfig{
System: sysView,
})
testPolicyFuzzingCommon(t, be)
}
func testPolicyFuzzingCommon(t *testing.T, be *backend) {
storage := &logical.LockingInmemStorage{}
wg := sync.WaitGroup{}
funcs := []string{"encrypt", "decrypt", "rotate", "change_min_version"}
//keys := []string{"test1", "test2", "test3", "test4", "test5"}
keys := []string{"test1", "test2", "test3"}
// This is the goroutine loop
doFuzzy := func() {
doFuzzy := func(id int) {
// Check for panics, otherwise notify we're done
defer func() {
if err := recover(); err != nil {
@ -587,15 +602,24 @@ func TestPolicyFuzzing(t *testing.T) {
}
fd := &framework.FieldData{}
var retest bool
var chosenFunc, chosenKey string
//t.Logf("Starting")
for {
// Stop after 10 seconds
if time.Now().Sub(startTime) > 10*time.Second {
if retest {
t.Errorf("ended runtime on a retest, id is %d", id)
}
return
}
// Pick a function and a key
chosenFunc := funcs[rand.Int()%len(funcs)]
chosenKey := keys[rand.Int()%len(keys)]
if !retest {
chosenFunc = funcs[rand.Int()%len(funcs)]
chosenKey = keys[rand.Int()%len(keys)]
}
fd.Raw = map[string]interface{}{
"name": chosenKey,
@ -605,33 +629,36 @@ func TestPolicyFuzzing(t *testing.T) {
// Try to write the key to make sure it exists
_, err := be.pathPolicyWrite(req, fd)
if err != nil {
t.Errorf("got an error: %v", err)
t.Fatalf("got an error: %v", err)
return
}
switch chosenFunc {
// Encrypt our plaintext and store the result
case "encrypt":
//t.Logf("%s, %s", chosenFunc, chosenKey)
fd.Raw["plaintext"] = base64.StdEncoding.EncodeToString([]byte(testPlaintext))
fd.Schema = be.pathEncrypt().Fields
resp, err := be.pathEncryptWrite(req, fd)
if err != nil {
t.Errorf("got an error: %v, resp is %#v", err, *resp)
t.Fatalf("got an error: %v, resp is %#v", err, *resp)
return
}
latestEncryptedText[chosenKey] = resp.Data["ciphertext"].(string)
// Rotate to a new key version
case "rotate":
//t.Errorf("%s, %s, %d", chosenFunc, chosenKey, id)
fd.Schema = be.pathRotate().Fields
resp, err := be.pathRotateWrite(req, fd)
if err != nil {
t.Errorf("got an error: %v, resp is %#v", err, *resp)
t.Fatalf("got an error: %v, resp is %#v, chosenKey is %s", err, *resp, chosenKey)
return
}
// Decrypt the ciphertext and compare the result
case "decrypt":
//t.Logf("%s, %s", chosenFunc, chosenKey)
ct := latestEncryptedText[chosenKey]
if ct == "" {
continue
@ -645,13 +672,14 @@ func TestPolicyFuzzing(t *testing.T) {
if resp.Data["error"].(string) == ErrTooOld {
continue
}
t.Errorf("got an error: %v, resp is %#v, ciphertext was %s", err, *resp, latestEncryptedText[chosenKey])
return
t.Errorf("got an error: %v, resp is %#v, ciphertext was %s, chosenKey is %s, id is %d", err, *resp, ct, chosenKey, id)
retest = true
continue
}
ptb64 := resp.Data["plaintext"].(string)
pt, err := base64.StdEncoding.DecodeString(ptb64)
if err != nil {
t.Errorf("got an error decoding base64 plaintext: %v", err)
t.Fatalf("got an error decoding base64 plaintext: %v", err)
return
}
if string(pt) != testPlaintext {
@ -660,9 +688,10 @@ func TestPolicyFuzzing(t *testing.T) {
// Change the min version, which also tests the archive functionality
case "change_min_version":
//t.Logf("%s, %s", chosenFunc, chosenKey)
resp, err := be.pathPolicyRead(req, fd)
if err != nil {
t.Errorf("got an error reading policy %s: %v", chosenKey, err)
t.Fatalf("got an error reading policy %s: %v", chosenKey, err)
return
}
latestVersion := resp.Data["latest_version"].(int)
@ -673,17 +702,22 @@ func TestPolicyFuzzing(t *testing.T) {
fd.Schema = be.pathConfig().Fields
resp, err = be.pathConfigWrite(req, fd)
if err != nil {
t.Errorf("got an error setting min decryption version: %v", err)
t.Fatalf("got an error setting min decryption version: %v", err)
return
}
}
if retest {
t.Errorf("success, setting retest false, id is %d", id)
}
retest = false
}
}
// Spawn 1000 of these workers for 10 seconds
for i := 0; i < 1000; i++ {
wg.Add(1)
go doFuzzy()
go doFuzzy(i)
}
// Wait for them all to finish

View File

@ -0,0 +1,109 @@
package transit
import (
"sync"
"github.com/hashicorp/vault/logical"
)
// policyCache implements CRUD operations with a simple locking cache of
// policies
type cachingPolicyCRUD struct {
sync.RWMutex
cache map[string]lockingPolicy
}
func newCachingPolicyCRUD() *cachingPolicyCRUD {
return &cachingPolicyCRUD{
cache: map[string]lockingPolicy{},
}
}
func (p *cachingPolicyCRUD) getPolicy(storage logical.Storage, name string) (lockingPolicy, error) {
// We don't defer this since we may need to give it up and get a write lock
p.RLock()
// First, see if we're in the cache -- if so, return that
if p.cache[name] != nil {
defer p.RUnlock()
return p.cache[name], nil
}
// If we didn't find anything, we'll need to write into the cache, plus possibly
// persist the entry, so lock the cache
p.RUnlock()
p.Lock()
defer p.Unlock()
// Check one more time to ensure that another process did not write during
// our lock switcheroo.
if p.cache[name] != nil {
return p.cache[name], nil
}
return p.refreshPolicy(storage, name)
}
func (p *cachingPolicyCRUD) refreshPolicy(storage logical.Storage, name string) (lockingPolicy, error) {
// Check once more to ensure it hasn't been added to the cache since the lock was acquired
if p.cache[name] != nil {
return p.cache[name], nil
}
// Note that we don't need to create the locking entry until the end,
// because the policy wasn't in the cache so we don't know about it, and we
// hold the cache lock so nothing else can be writing it in right now
policy, err := fetchPolicyFromStorage(storage, name)
if err != nil {
return nil, err
}
if policy == nil {
return nil, nil
}
lp := &mutexLockingPolicy{
policy: policy,
mutex: &sync.RWMutex{},
}
p.cache[name] = lp
return lp, nil
}
// generatePolicy is used to create a new named policy with a randomly
// generated key. The caller should hold the write lock prior to calling this.
func (p *cachingPolicyCRUD) generatePolicy(storage logical.Storage, name string, derived bool) (lockingPolicy, error) {
policy, err := generatePolicyCommon(p, storage, name, derived)
if err != nil {
return nil, err
}
// Now we need to check again in the cache to ensure the policy wasn't
// created since we ran generatePolicy and then got the lock. A policy
// being created holds a write lock until it's done (starting from this
// point), so it'll be in the cache at this point.
if lp := p.cache[name]; lp != nil {
return lp, nil
}
lp := &mutexLockingPolicy{
policy: policy,
mutex: &sync.RWMutex{},
}
p.cache[name] = lp
// Return the policy
return lp, nil
}
// deletePolicy deletes a policy
func (p *cachingPolicyCRUD) deletePolicy(storage logical.Storage, lp lockingPolicy, name string) error {
err := deletePolicyCommon(p, lp, storage, name)
if err != nil {
return err
}
delete(p.cache, name)
return nil
}

View File

@ -42,26 +42,48 @@ func (b *backend) pathConfigWrite(
name := d.Get("name").(string)
// Check if the policy already exists
lp, err := b.policies.getPolicy(req, name)
lp, err := b.policies.getPolicy(req.Storage, name)
if err != nil {
return nil, err
}
if lp == nil {
return logical.ErrorResponse(
fmt.Sprintf("no existing role named %s could be found", name)),
fmt.Sprintf("no existing key named %s could be found", name)),
logical.ErrInvalidRequest
}
currDeletionAllowed := lp.Policy().DeletionAllowed
currMinDecryptionVersion := lp.Policy().MinDecryptionVersion
// Hold both locks since we want to ensure the policy doesn't change from underneath us
b.policies.Lock()
defer b.policies.Unlock()
lp.Lock()
defer lp.Unlock()
// Refresh in case it's changed since before we grabbed the lock
lp, err = b.policies.refreshPolicy(req.Storage, name)
if err != nil {
return nil, err
}
if lp == nil {
return nil, fmt.Errorf("error finding key %s after locking for changes", name)
}
// Verify if wasn't deleted before we grabbed the lock
if lp.policy == nil {
return nil, fmt.Errorf("no existing role named %s could be found", name)
if lp.Policy() == nil {
return nil, fmt.Errorf("no existing key named %s could be found", name)
}
resp := &logical.Response{}
// Check for anything to have been updated since we got the policy
if currDeletionAllowed != lp.Policy().DeletionAllowed ||
currMinDecryptionVersion != lp.Policy().MinDecryptionVersion {
resp.AddWarning("key configuration has changed since this endpoint was called, not updating")
return resp, nil
}
persistNeeded := false
minDecryptionVersionRaw, ok := d.GetOk("min_decryption_version")
@ -78,12 +100,12 @@ func (b *backend) pathConfigWrite(
}
if minDecryptionVersion > 0 &&
minDecryptionVersion != lp.policy.MinDecryptionVersion {
if minDecryptionVersion > lp.policy.LatestVersion {
minDecryptionVersion != lp.Policy().MinDecryptionVersion {
if minDecryptionVersion > lp.Policy().LatestVersion {
return logical.ErrorResponse(
fmt.Sprintf("cannot set min decryption version of %d, latest key version is %d", minDecryptionVersion, lp.policy.LatestVersion)), nil
fmt.Sprintf("cannot set min decryption version of %d, latest key version is %d", minDecryptionVersion, lp.Policy().LatestVersion)), nil
}
lp.policy.MinDecryptionVersion = minDecryptionVersion
lp.Policy().MinDecryptionVersion = minDecryptionVersion
persistNeeded = true
}
}
@ -91,8 +113,8 @@ func (b *backend) pathConfigWrite(
allowDeletionInt, ok := d.GetOk("deletion_allowed")
if ok {
allowDeletion := allowDeletionInt.(bool)
if allowDeletion != lp.policy.DeletionAllowed {
lp.policy.DeletionAllowed = allowDeletion
if allowDeletion != lp.Policy().DeletionAllowed {
lp.Policy().DeletionAllowed = allowDeletion
persistNeeded = true
}
}
@ -100,8 +122,8 @@ func (b *backend) pathConfigWrite(
// Add this as a guard here before persisting since we now require the min
// decryption version to start at 1; even if it's not explicitly set here,
// force the upgrade
if lp.policy.MinDecryptionVersion == 0 {
lp.policy.MinDecryptionVersion = 1
if lp.Policy().MinDecryptionVersion == 0 {
lp.Policy().MinDecryptionVersion = 1
persistNeeded = true
}
@ -109,7 +131,7 @@ func (b *backend) pathConfigWrite(
return nil, nil
}
return resp, lp.policy.Persist(req.Storage)
return resp, lp.Policy().Persist(req.Storage)
}
const pathConfigHelpSyn = `Configure a named encryption key`

View File

@ -73,7 +73,7 @@ func (b *backend) pathDatakeyWrite(
}
// Get the policy
lp, err := b.policies.getPolicy(req, name)
lp, err := b.policies.getPolicy(req.Storage, name)
if err != nil {
return nil, err
}
@ -87,7 +87,7 @@ func (b *backend) pathDatakeyWrite(
defer lp.RUnlock()
// Verify if wasn't deleted before we grabbed the lock
if lp.policy == nil {
if lp.Policy() == nil {
return nil, fmt.Errorf("no existing policy named %s could be found", name)
}
@ -107,7 +107,7 @@ func (b *backend) pathDatakeyWrite(
return nil, err
}
ciphertext, err := lp.policy.Encrypt(context, base64.StdEncoding.EncodeToString(newKey))
ciphertext, err := lp.Policy().Encrypt(context, base64.StdEncoding.EncodeToString(newKey))
if err != nil {
switch err.(type) {
case certutil.UserError:

View File

@ -58,7 +58,7 @@ func (b *backend) pathDecryptWrite(
}
// Get the policy
lp, err := b.policies.getPolicy(req, name)
lp, err := b.policies.getPolicy(req.Storage, name)
if err != nil {
return nil, err
}
@ -72,11 +72,11 @@ func (b *backend) pathDecryptWrite(
defer lp.RUnlock()
// Verify if wasn't deleted before we grabbed the lock
if lp.policy == nil {
if lp.Policy() == nil {
return nil, fmt.Errorf("no existing policy named %s could be found", name)
}
plaintext, err := lp.policy.Decrypt(context, ciphertext)
plaintext, err := lp.Policy().Decrypt(context, ciphertext)
if err != nil {
switch err.(type) {
case certutil.UserError:

View File

@ -44,7 +44,7 @@ func (b *backend) pathEncrypt() *framework.Path {
func (b *backend) pathEncryptExistenceCheck(
req *logical.Request, d *framework.FieldData) (bool, error) {
name := d.Get("name").(string)
lp, err := b.policies.getPolicy(req, name)
lp, err := b.policies.getPolicy(req.Storage, name)
if err != nil {
return false, err
}
@ -72,37 +72,43 @@ func (b *backend) pathEncryptWrite(
}
// Get the policy
lp, err := b.policies.getPolicy(req, name)
lp, err := b.policies.getPolicy(req.Storage, name)
if err != nil {
return nil, err
}
// Error if invalid policy
// Error or upsert if invalid policy
if lp == nil {
if req.Operation != logical.CreateOperation {
return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest
}
// Get a write lock
b.policies.Lock()
isDerived := len(context) != 0
// This also checks to make sure one hasn't been created since we grabbed the write lock
lp, err = b.policies.generatePolicy(req.Storage, name, isDerived)
// If the error is that the policy has been created in the interim we
// will get the policy back, so only consider it an error if err is not
// nil and we do not get a policy back
if err != nil && lp != nil {
b.policies.Unlock()
return nil, err
}
b.policies.Unlock()
}
lp.RLock()
defer lp.RUnlock()
// Verify if wasn't deleted before we grabbed the lock
if lp.policy == nil {
if lp.Policy() == nil {
return nil, fmt.Errorf("no existing policy named %s could be found", name)
}
ciphertext, err := lp.policy.Encrypt(context, value)
ciphertext, err := lp.Policy().Encrypt(context, value)
if err != nil {
switch err.(type) {
case certutil.UserError:

View File

@ -36,20 +36,14 @@ func (b *backend) pathKeys() *framework.Path {
func (b *backend) pathPolicyWrite(
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
b.policies.Lock()
defer b.policies.Unlock()
name := d.Get("name").(string)
derived := d.Get("derived").(bool)
// Check if the policy already exists
existing, err := b.policies.getPolicy(req, name)
if err != nil {
return nil, err
}
if existing != nil {
return nil, nil
}
// Generate the policy
_, err = b.policies.generatePolicy(req.Storage, name, derived)
// Generate the policy; this will also check if it exists for safety
_, err := b.policies.generatePolicy(req.Storage, name, derived)
return nil, err
}
@ -57,7 +51,7 @@ func (b *backend) pathPolicyRead(
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
name := d.Get("name").(string)
lp, err := b.policies.getPolicy(req, name)
lp, err := b.policies.getPolicy(req.Storage, name)
if err != nil {
return nil, err
}
@ -69,27 +63,27 @@ func (b *backend) pathPolicyRead(
defer lp.RUnlock()
// Verify if wasn't deleted before we grabbed the lock
if lp.policy == nil {
if lp.Policy() == nil {
return nil, fmt.Errorf("no existing policy named %s could be found", name)
}
// Return the response
resp := &logical.Response{
Data: map[string]interface{}{
"name": lp.policy.Name,
"cipher_mode": lp.policy.CipherMode,
"derived": lp.policy.Derived,
"deletion_allowed": lp.policy.DeletionAllowed,
"min_decryption_version": lp.policy.MinDecryptionVersion,
"latest_version": lp.policy.LatestVersion,
"name": lp.Policy().Name,
"cipher_mode": lp.Policy().CipherMode,
"derived": lp.Policy().Derived,
"deletion_allowed": lp.Policy().DeletionAllowed,
"min_decryption_version": lp.Policy().MinDecryptionVersion,
"latest_version": lp.Policy().LatestVersion,
},
}
if lp.policy.Derived {
resp.Data["kdf_mode"] = lp.policy.KDFMode
if lp.Policy().Derived {
resp.Data["kdf_mode"] = lp.Policy().KDFMode
}
retKeys := map[string]int64{}
for k, v := range lp.policy.Keys {
for k, v := range lp.Policy().Keys {
retKeys[strconv.Itoa(k)] = v.CreationTime
}
resp.Data["keys"] = retKeys
@ -101,7 +95,7 @@ func (b *backend) pathPolicyDelete(
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
name := d.Get("name").(string)
lp, err := b.policies.getPolicy(req, name)
lp, err := b.policies.getPolicy(req.Storage, name)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf("error looking up policy %s, error is %s", name, err)), err
}
@ -109,7 +103,12 @@ func (b *backend) pathPolicyDelete(
return logical.ErrorResponse(fmt.Sprintf("no such key %s", name)), logical.ErrInvalidRequest
}
err = b.policies.deletePolicy(req.Storage, name)
b.policies.Lock()
defer b.policies.Unlock()
lp.Lock()
defer lp.Unlock()
err = b.policies.deletePolicy(req.Storage, lp, name)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf("error deleting policy %s: %s", name, err)), err
}

View File

@ -59,7 +59,7 @@ func (b *backend) pathRewrapWrite(
}
// Get the policy
lp, err := b.policies.getPolicy(req, name)
lp, err := b.policies.getPolicy(req.Storage, name)
if err != nil {
return nil, err
}
@ -73,11 +73,11 @@ func (b *backend) pathRewrapWrite(
defer lp.RUnlock()
// Verify if wasn't deleted before we grabbed the lock
if lp.policy == nil {
if lp.Policy() == nil {
return nil, fmt.Errorf("no existing policy named %s could be found", name)
}
plaintext, err := lp.policy.Decrypt(context, value)
plaintext, err := lp.Policy().Decrypt(context, value)
if err != nil {
switch err.(type) {
case certutil.UserError:
@ -93,7 +93,7 @@ func (b *backend) pathRewrapWrite(
return nil, fmt.Errorf("empty plaintext returned during rewrap")
}
ciphertext, err := lp.policy.Encrypt(context, plaintext)
ciphertext, err := lp.Policy().Encrypt(context, plaintext)
if err != nil {
switch err.(type) {
case certutil.UserError:

View File

@ -31,26 +31,48 @@ func (b *backend) pathRotateWrite(
name := d.Get("name").(string)
// Get the policy
lp, err := b.policies.getPolicy(req, name)
lp, err := b.policies.getPolicy(req.Storage, name)
if err != nil {
return nil, err
}
// Error if invalid policy
if lp == nil {
return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest
return logical.ErrorResponse("key not found"), logical.ErrInvalidRequest
}
keyVersion := lp.Policy().LatestVersion
// lock the policies object so we can refresh
b.policies.Lock()
defer b.policies.Unlock()
lp.Lock()
defer lp.Unlock()
// Verify if wasn't deleted before we grabbed the lock
if lp.policy == nil {
return nil, fmt.Errorf("no existing policy named %s could be found", name)
// Refresh in case it's changed since before we grabbed the lock
lp, err = b.policies.refreshPolicy(req.Storage, name)
if err != nil {
return nil, err
}
if lp == nil {
return nil, fmt.Errorf("error finding key %s after locking for changes", name)
}
// Generate the policy
err = lp.policy.rotate(req.Storage)
// Verify if wasn't deleted before we grabbed the lock
if lp.Policy() == nil {
return nil, fmt.Errorf("no existing key named %s could be found", name)
}
// Make sure that the policy hasn't been rotated simultaneously
if keyVersion != lp.Policy().LatestVersion {
resp := &logical.Response{}
resp.AddWarning("key has been rotated since this endpoint was called; did not perform rotation")
return resp, nil
}
//fmt.Printf("Rotating key %s, orig seen version is %d, currVersion is %d\n", name, keyVersion, lp.Policy().LatestVersion)
// Rotate the policy
err = lp.Policy().rotate(req.Storage)
return nil, err
}

View File

@ -9,7 +9,6 @@ import (
"fmt"
"strconv"
"strings"
"sync"
"time"
"github.com/hashicorp/vault/helper/certutil"
@ -24,197 +23,6 @@ const (
ErrTooOld = "ciphertext version is disallowed by policy (too old)"
)
// policyCache implements a simple locking cache of policies
type policyCache struct {
sync.RWMutex
cache map[string]*lockingPolicy
}
// getPolicy loads a policy into the cache or returns one already in the cache
func (p *policyCache) getPolicy(req *logical.Request, name string) (*lockingPolicy, error) {
// We don't defer this since we may need to give it up and get a write lock
p.RLock()
// First, see if we're in the cache -- if so, return that
if p.cache[name] != nil {
defer p.RUnlock()
return p.cache[name], nil
}
// If we didn't find anything, we'll need to write into the cache, plus possibly
// persist the entry, so lock the cache
p.RUnlock()
p.Lock()
defer p.Unlock()
// Check one more time to ensure that another process did not write during
// our lock switcheroo.
if p.cache[name] != nil {
return p.cache[name], nil
}
// Note that we don't need to create the locking entry until the end,
// because the policy wasn't in the cache so we don't know about it, and we
// hold the cache lock so nothing else can be writing it in right now
// Check if the policy already exists
raw, err := req.Storage.Get("policy/" + name)
if err != nil {
return nil, err
}
if raw == nil {
return nil, nil
}
// Decode the policy
policy := &Policy{
Keys: KeyEntryMap{},
}
err = json.Unmarshal(raw.Value, policy)
if err != nil {
return nil, err
}
persistNeeded := false
// Ensure we've moved from Key -> Keys
if policy.Key != nil && len(policy.Key) > 0 {
policy.migrateKeyToKeysMap()
persistNeeded = true
}
// With archiving, past assumptions about the length of the keys map are no longer valid
if policy.LatestVersion == 0 && len(policy.Keys) != 0 {
policy.LatestVersion = len(policy.Keys)
persistNeeded = true
}
// We disallow setting the version to 0, since they start at 1 since moving
// to rotate-able keys, so update if it's set to 0
if policy.MinDecryptionVersion == 0 {
policy.MinDecryptionVersion = 1
persistNeeded = true
}
// On first load after an upgrade, copy keys to the archive
if policy.ArchiveVersion == 0 {
persistNeeded = true
}
if persistNeeded {
err = policy.Persist(req.Storage)
if err != nil {
return nil, err
}
}
lp := &lockingPolicy{
policy: policy,
}
p.cache[name] = lp
return lp, nil
}
// generatePolicy is used to create a new named policy with a randomly
// generated key
func (p *policyCache) generatePolicy(storage logical.Storage, name string, derived bool) (*lockingPolicy, error) {
// Ensure one with this name doesn't already exist
lp, err := p.getPolicy(&logical.Request{
Storage: storage,
}, name)
if err != nil {
return nil, fmt.Errorf("error checking if policy already exists: %s", err)
}
if lp != nil {
return nil, fmt.Errorf("policy %s already exists", name)
}
p.Lock()
defer p.Unlock()
// Now we need to check again in the cache to ensure the policy wasn't
// created since we checked getPolicy. A policy being created holds a write
// lock until it's done, so it'll be in the cache at this point.
if lp := p.cache[name]; lp != nil {
return nil, fmt.Errorf("policy %s already exists", name)
}
// Create the policy object
policy := &Policy{
Name: name,
CipherMode: "aes-gcm",
Derived: derived,
}
if derived {
policy.KDFMode = kdfMode
}
err = policy.rotate(storage)
if err != nil {
return nil, err
}
lp = &lockingPolicy{
policy: policy,
}
p.cache[name] = lp
// Return the policy
return lp, nil
}
// deletePolicy deletes a policy
func (p *policyCache) deletePolicy(storage logical.Storage, name string) error {
// Ensure one with this name exists
lp, err := p.getPolicy(&logical.Request{
Storage: storage,
}, name)
if err != nil {
return fmt.Errorf("error checking if policy already exists: %s", err)
}
if lp == nil {
return fmt.Errorf("policy %s does not exist", name)
}
p.Lock()
defer p.Unlock()
lp = p.cache[name]
if lp == nil {
return fmt.Errorf("policy %s not found", name)
}
// We need to ensure all other access has stopped
lp.Lock()
defer lp.Unlock()
// Verify this hasn't changed
if !lp.policy.DeletionAllowed {
return fmt.Errorf("deletion not allowed for policy %s", name)
}
err = storage.Delete("policy/" + name)
if err != nil {
return fmt.Errorf("error deleting policy %s: %s", name, err)
}
err = storage.Delete("archive/" + name)
if err != nil {
return fmt.Errorf("error deleting archive %s: %s", name, err)
}
lp.policy = nil
delete(p.cache, name)
return nil
}
// lockingPolicy holds a Policy guarded by a lock
type lockingPolicy struct {
sync.RWMutex
policy *Policy
}
// KeyEntry stores the key and metadata
type KeyEntry struct {
Key []byte `json:"key"`
@ -521,17 +329,17 @@ func (p *Policy) Encrypt(context []byte, value string) (string, error) {
func (p *Policy) Decrypt(context []byte, value string) (string, error) {
// Verify the prefix
if !strings.HasPrefix(value, "vault:v") {
return "", certutil.UserError{Err: "invalid ciphertext"}
return "", certutil.UserError{Err: "invalid ciphertext: no prefix"}
}
splitVerCiphertext := strings.SplitN(strings.TrimPrefix(value, "vault:v"), ":", 2)
if len(splitVerCiphertext) != 2 {
return "", certutil.UserError{Err: "invalid ciphertext"}
return "", certutil.UserError{Err: "invalid ciphertext: wrong number of fields"}
}
ver, err := strconv.Atoi(splitVerCiphertext[0])
if err != nil {
return "", certutil.UserError{Err: "invalid ciphertext"}
return "", certutil.UserError{Err: "invalid ciphertext: version number could not be decoded"}
}
if ver == 0 {
@ -540,6 +348,10 @@ func (p *Policy) Decrypt(context []byte, value string) (string, error) {
ver = 1
}
if ver > p.LatestVersion {
return "", certutil.UserError{Err: "invalid ciphertext: version is too new"}
}
if p.MinDecryptionVersion > 0 && ver < p.MinDecryptionVersion {
return "", certutil.UserError{Err: ErrTooOld}
}
@ -560,7 +372,7 @@ func (p *Policy) Decrypt(context []byte, value string) (string, error) {
// Decode the base64
decoded, err := base64.StdEncoding.DecodeString(splitVerCiphertext[1])
if err != nil {
return "", certutil.UserError{Err: "invalid ciphertext"}
return "", certutil.UserError{Err: "invalid ciphertext: could not decode base64"}
}
// Setup the cipher
@ -582,7 +394,7 @@ func (p *Policy) Decrypt(context []byte, value string) (string, error) {
// Verify and Decrypt
plain, err := gcm.Open(nil, nonce, ciphertext, nil)
if err != nil {
return "", certutil.UserError{Err: "invalid ciphertext"}
return "", certutil.UserError{Err: "invalid ciphertext: unable to decrypt"}
}
return base64.StdEncoding.EncodeToString(plain), nil
@ -617,6 +429,8 @@ func (p *Policy) rotate(storage logical.Storage) error {
p.MinDecryptionVersion = 1
}
//fmt.Printf("policy %s rotated to %d\n", p.Name, p.LatestVersion)
return p.Persist(storage)
}

View File

@ -0,0 +1,185 @@
package transit
import (
"encoding/json"
"fmt"
"sync"
"github.com/hashicorp/vault/logical"
)
type lockingPolicy interface {
Lock()
RLock()
Unlock()
RUnlock()
Policy() *Policy
SetPolicy(*Policy)
}
type policyCRUD interface {
// getPolicy returns a lockingPolicy. It performs its own locking according
// to implementation.
getPolicy(storage logical.Storage, name string) (lockingPolicy, error)
// refreshPolicy returns a lockingPolicy. It does not perform its own
// locking; a write lock must be held before calling.
refreshPolicy(storage logical.Storage, name string) (lockingPolicy, error)
// generatePolicy generates and returns a lockingPolicy. A write lock must
// be held before calling.
generatePolicy(storage logical.Storage, name string, derived bool) (lockingPolicy, error)
// deletePolicy deletes a lockingPolicy. A write lock must be held on both
// the CRUD implementation and the lockingPolicy before calling.
deletePolicy(storage logical.Storage, lp lockingPolicy, name string) error
// These are generally satisfied by embedded mutexes in the implementing struct
Lock()
RLock()
Unlock()
RUnlock()
}
type mutexLockingPolicy struct {
mutex *sync.RWMutex
policy *Policy
}
func (m *mutexLockingPolicy) Lock() {
m.mutex.Lock()
}
func (m *mutexLockingPolicy) RLock() {
m.mutex.RLock()
}
func (m *mutexLockingPolicy) Unlock() {
m.mutex.Unlock()
}
func (m *mutexLockingPolicy) RUnlock() {
m.mutex.RUnlock()
}
func (m *mutexLockingPolicy) Policy() *Policy {
return m.policy
}
func (m *mutexLockingPolicy) SetPolicy(p *Policy) {
m.policy = p
}
// The caller should hold the write lock when calling this
func fetchPolicyFromStorage(storage logical.Storage, name string) (*Policy, error) {
// Check if the policy already exists
raw, err := storage.Get("policy/" + name)
if err != nil {
return nil, err
}
if raw == nil {
return nil, nil
}
// Decode the policy
policy := &Policy{
Keys: KeyEntryMap{},
}
err = json.Unmarshal(raw.Value, policy)
if err != nil {
return nil, err
}
persistNeeded := false
// Ensure we've moved from Key -> Keys
if policy.Key != nil && len(policy.Key) > 0 {
policy.migrateKeyToKeysMap()
persistNeeded = true
}
// With archiving, past assumptions about the length of the keys map are no longer valid
if policy.LatestVersion == 0 && len(policy.Keys) != 0 {
policy.LatestVersion = len(policy.Keys)
persistNeeded = true
}
// We disallow setting the version to 0, since they start at 1 since moving
// to rotate-able keys, so update if it's set to 0
if policy.MinDecryptionVersion == 0 {
policy.MinDecryptionVersion = 1
persistNeeded = true
}
// On first load after an upgrade, copy keys to the archive
if policy.ArchiveVersion == 0 {
persistNeeded = true
}
if persistNeeded {
err = policy.Persist(storage)
if err != nil {
return nil, err
}
}
return policy, nil
}
// generatePolicyCommon is used to create a new named policy with a randomly
// generated key. The caller should have a write lock prior to calling this.
func generatePolicyCommon(p policyCRUD, storage logical.Storage, name string, derived bool) (*Policy, error) {
// Make sure this doesn't exist in case it was created before we got the write lock
policy, err := fetchPolicyFromStorage(storage, name)
if err != nil {
return nil, err
}
if policy != nil {
return policy, nil
}
//log.Printf("generating a new policy with name %s", name)
// Create the policy object
policy = &Policy{
Name: name,
CipherMode: "aes-gcm",
Derived: derived,
}
if derived {
policy.KDFMode = kdfMode
}
err = policy.rotate(storage)
if err != nil {
return nil, err
}
return policy, err
}
// deletePolicy deletes a policy. The caller should hold the write lock for both the policy and lockingPolicy prior to calling this.
func deletePolicyCommon(p policyCRUD, lp lockingPolicy, storage logical.Storage, name string) error {
if lp.Policy() == nil {
// This got deleted before we grabbed the lock
return fmt.Errorf("policy already deleted")
}
// Verify this hasn't changed
if !lp.Policy().DeletionAllowed {
return fmt.Errorf("deletion not allowed for policy %s", name)
}
err := storage.Delete("policy/" + name)
if err != nil {
return fmt.Errorf("error deleting policy %s: %s", name, err)
}
err = storage.Delete("archive/" + name)
if err != nil {
return fmt.Errorf("error deleting archive %s: %s", name, err)
}
lp.SetPolicy(nil)
return nil
}

View File

@ -16,19 +16,24 @@ func resetKeysArchive() {
}
func Test_KeyUpgrade(t *testing.T) {
testKeyUpgradeCommon(t, newSimplePolicyCRUD())
testKeyUpgradeCommon(t, newCachingPolicyCRUD())
}
func testKeyUpgradeCommon(t *testing.T, policies policyCRUD) {
storage := &logical.InmemStorage{}
policies := &policyCache{
cache: map[string]*lockingPolicy{},
}
lp, err := policies.generatePolicy(storage, "test", false)
if err != nil {
t.Fatal(err)
}
if lp == nil {
t.Fatal("nil policy")
t.Fatal("nil lockingPolicy")
}
policy := lp.policy
policy := lp.Policy()
if policy == nil {
t.Fatal("nil policy in lockingPolicy")
}
testBytes := make([]byte, len(policy.Keys[1].Key))
copy(testBytes, policy.Keys[1].Key)
@ -48,6 +53,11 @@ func Test_KeyUpgrade(t *testing.T) {
}
func Test_ArchivingUpgrade(t *testing.T) {
testArchivingUpgradeCommon(t, newSimplePolicyCRUD())
testArchivingUpgradeCommon(t, newCachingPolicyCRUD())
}
func testArchivingUpgradeCommon(t *testing.T, policies policyCRUD) {
resetKeysArchive()
// First, we generate a policy and rotate it a number of times. Each time
@ -56,19 +66,19 @@ func Test_ArchivingUpgrade(t *testing.T) {
// zero and latest, respectively
storage := &logical.InmemStorage{}
policies := &policyCache{
cache: map[string]*lockingPolicy{},
}
lp, err := policies.generatePolicy(storage, "test", false)
if err != nil {
t.Fatal(err)
}
if lp == nil {
t.Fatal("policy is nil")
t.Fatal("nil lockingPolicy")
}
policy := lp.policy
policy := lp.Policy()
if policy == nil {
t.Fatal("nil policy in lockingPolicy")
}
// Store the initial key in the archive
keysArchive = append(keysArchive, policy.Keys[1])
@ -106,26 +116,35 @@ func Test_ArchivingUpgrade(t *testing.T) {
t.Fatal(err)
}
// Expire from the cache since we modified it under-the-hood
delete(policies.cache, "test")
// If it's a caching CRUD, expire from the cache since we modified it
// under-the-hood
if cachingCRUD, ok := policies.(*cachingPolicyCRUD); ok {
delete(cachingCRUD.cache, "test")
}
// Now get the policy again; the upgrade should happen automatically
lp, err = policies.getPolicy(&logical.Request{
Storage: storage,
}, "test")
lp, err = policies.getPolicy(storage, "test")
if err != nil {
t.Fatal(err)
}
if lp == nil {
t.Fatal("policy is nil")
t.Fatal("nil lockingPolicy")
}
policy = lp.policy
policy = lp.Policy()
if policy == nil {
t.Fatal("nil policy in lockingPolicy")
}
checkKeys(t, policy, storage, "upgrade", 10, 10, 10)
}
func Test_Archiving(t *testing.T) {
testArchivingCommon(t, newSimplePolicyCRUD())
testArchivingCommon(t, newCachingPolicyCRUD())
}
func testArchivingCommon(t *testing.T, policies policyCRUD) {
resetKeysArchive()
// First, we generate a policy and rotate it a number of times. Each time
@ -135,19 +154,18 @@ func Test_Archiving(t *testing.T) {
storage := &logical.InmemStorage{}
policies := &policyCache{
cache: map[string]*lockingPolicy{},
}
lp, err := policies.generatePolicy(storage, "test", false)
if err != nil {
t.Fatal(err)
}
if lp == nil {
t.Fatal("policy is nil")
t.Fatal("nil lockingPolicy")
}
policy := lp.policy
policy := lp.Policy()
if policy == nil {
t.Fatal("nil policy in lockingPolicy")
}
// Store the initial key in the archive
keysArchive = append(keysArchive, policy.Keys[1])

View File

@ -0,0 +1,88 @@
package transit
import (
"sync"
"github.com/hashicorp/vault/logical"
)
// Directly implements CRUD operations without caching, mapped to the backend,
// but implements locking to ensure that we can't overwrite data on the backend
// from multiple operators
type simplePolicyCRUD struct {
sync.RWMutex
locks map[string]*sync.RWMutex
locksMapMutex sync.RWMutex
}
func newSimplePolicyCRUD() *simplePolicyCRUD {
return &simplePolicyCRUD{
locks: map[string]*sync.RWMutex{},
}
}
func (p *simplePolicyCRUD) ensureLockExists(name string) {
p.locksMapMutex.RLock()
if p.locks[name] == nil {
p.locksMapMutex.RUnlock()
p.locksMapMutex.Lock()
// Make sure nothing has appeared since we switched the lock type
if p.locks[name] == nil {
p.locks[name] = &sync.RWMutex{}
}
p.locksMapMutex.Unlock()
return
}
p.locksMapMutex.RUnlock()
}
func (p *simplePolicyCRUD) getPolicy(storage logical.Storage, name string) (lockingPolicy, error) {
// Use a write lock since fetching the policy can cause a need for upgrade persistence
p.Lock()
defer p.Unlock()
return p.refreshPolicy(storage, name)
}
func (p *simplePolicyCRUD) refreshPolicy(storage logical.Storage, name string) (lockingPolicy, error) {
p.ensureLockExists(name)
policy, err := fetchPolicyFromStorage(storage, name)
if err != nil {
return nil, err
}
if policy == nil {
return nil, nil
}
lp := &mutexLockingPolicy{
policy: policy,
mutex: p.locks[name],
}
return lp, nil
}
// The caller must hold the write lock when calling this
func (p *simplePolicyCRUD) generatePolicy(storage logical.Storage, name string, derived bool) (lockingPolicy, error) {
p.ensureLockExists(name)
policy, err := generatePolicyCommon(p, storage, name, derived)
if err != nil {
return nil, err
}
lp := &mutexLockingPolicy{
policy: policy,
mutex: p.locks[name],
}
return lp, nil
}
// The caller must hold the write lock when calling this
func (p *simplePolicyCRUD) deletePolicy(storage logical.Storage, lp lockingPolicy, name string) error {
return deletePolicyCommon(p, lp, storage, name)
}

View File

@ -29,7 +29,7 @@ type SystemView interface {
// Returns true if caching is disabled. If true, no caches should be used,
// despite known slowdowns.
CacheDisabled() bool
CachingDisabled() bool
}
type StaticSystemView struct {
@ -37,7 +37,7 @@ type StaticSystemView struct {
MaxLeaseTTLVal time.Duration
SudoPrivilegeVal bool
TaintedVal bool
CacheDisabledVal bool
CachingDisabledVal bool
}
func (d StaticSystemView) DefaultLeaseTTL() time.Duration {
@ -56,6 +56,6 @@ func (d StaticSystemView) Tainted() bool {
return d.TaintedVal
}
func (d StaticSystemView) CacheDisabled() bool {
return d.CacheDisabledVal
func (d StaticSystemView) CachingDisabled() bool {
return d.CachingDisabledVal
}

View File

@ -219,8 +219,8 @@ type Core struct {
logger *log.Logger
// cacheDisabled indicates whether caches are disabled
cacheDisabled bool
// cachingDisabled indicates whether caches are disabled
cachingDisabled bool
}
// CoreConfig is used to parameterize a core
@ -318,7 +318,7 @@ func NewCore(conf *CoreConfig) (*Core, error) {
logger: conf.Logger,
defaultLeaseTTL: conf.DefaultLeaseTTL,
maxLeaseTTL: conf.MaxLeaseTTL,
cacheDisabled: conf.DisableCache,
cachingDisabled: conf.DisableCache,
}
// Setup the backends

View File

@ -70,7 +70,7 @@ func (d dynamicSystemView) Tainted() bool {
return d.mountEntry.Tainted
}
// CacheDisabled indicates whether to use caching behavior
func (d dynamicSystemView) CacheDisabled() bool {
return d.core.cacheDisabled
// CachingDisabled indicates whether to use caching behavior
func (d dynamicSystemView) CachingDisabled() bool {
return d.core.cachingDisabled
}

View File

@ -40,7 +40,7 @@ func NewPolicyStore(view *BarrierView, system logical.SystemView) *PolicyStore {
view: view,
system: system,
}
if !system.CacheDisabled() {
if !system.CachingDisabled() {
cache, _ := lru.New2Q(policyCacheSize)
p.lru = cache
}
@ -100,7 +100,7 @@ func (ps *PolicyStore) SetPolicy(p *Policy) error {
return fmt.Errorf("failed to persist policy: %v", err)
}
if !ps.system.CacheDisabled() {
if !ps.system.CachingDisabled() {
// Update the LRU cache
ps.lru.Add(p.Name, p)
}
@ -110,7 +110,7 @@ func (ps *PolicyStore) SetPolicy(p *Policy) error {
// GetPolicy is used to fetch the named policy
func (ps *PolicyStore) GetPolicy(name string) (*Policy, error) {
defer metrics.MeasureSince([]string{"policy", "get_policy"}, time.Now())
if !ps.system.CacheDisabled() {
if !ps.system.CachingDisabled() {
// Check for cached policy
if raw, ok := ps.lru.Get(name); ok {
return raw.(*Policy), nil
@ -120,7 +120,7 @@ func (ps *PolicyStore) GetPolicy(name string) (*Policy, error) {
// Special case the root policy
if name == "root" {
p := &Policy{Name: "root"}
if !ps.system.CacheDisabled() {
if !ps.system.CachingDisabled() {
ps.lru.Add(p.Name, p)
}
return p, nil
@ -163,7 +163,7 @@ func (ps *PolicyStore) GetPolicy(name string) (*Policy, error) {
policy = p
}
if !ps.system.CacheDisabled() {
if !ps.system.CachingDisabled() {
// Update the LRU cache
ps.lru.Add(name, policy)
}
@ -192,7 +192,7 @@ func (ps *PolicyStore) DeletePolicy(name string) error {
return fmt.Errorf("failed to delete policy: %v", err)
}
if !ps.system.CacheDisabled() {
if !ps.system.CachingDisabled() {
// Clear the cache
ps.lru.Remove(name)
}

View File

@ -16,7 +16,7 @@ func mockPolicyStore(t *testing.T) *PolicyStore {
func mockPolicyStoreNoCache(t *testing.T) *PolicyStore {
sysView := logical.TestSystemView()
sysView.CacheDisabledVal = true
sysView.CachingDisabledVal = true
_, barrier, _ := mockBarrier(t)
view := NewBarrierView(barrier, "foo/")
p := NewPolicyStore(view, sysView)