diff --git a/builtin/logical/transit/policy.go b/builtin/logical/transit/policy.go index d4bd72231..d8f674fad 100644 --- a/builtin/logical/transit/policy.go +++ b/builtin/logical/transit/policy.go @@ -116,10 +116,23 @@ func (p *policyCache) getPolicy(req *logical.Request, name string) (*lockingPoli // generatePolicy is used to create a new named policy with a randomly // generated key func (p *policyCache) generatePolicy(storage logical.Storage, name string, derived bool) (*lockingPolicy, error) { + // Ensure one with this name doesn't already exist + lp, err := p.getPolicy(&logical.Request{ + Storage: storage, + }, name) + if err != nil { + return nil, fmt.Errorf("error checking if policy already exists: %s", err) + } + if lp != nil { + return nil, fmt.Errorf("policy %s already exists", name) + } + p.lock.Lock() defer p.lock.Unlock() - // Ensure one doesn't already exist + // Now we need to check again in the cache to ensure the policy wasn't + // created since we checked getPolicy. A policy being created holds a write + // lock until it's done, so it'll be in the cache at this point. if lp := p.cache[name]; lp != nil { return nil, fmt.Errorf("policy %s already exists", name) } @@ -134,12 +147,12 @@ func (p *policyCache) generatePolicy(storage logical.Storage, name string, deriv policy.KDFMode = kdfMode } - err := policy.rotate(storage) + err = policy.rotate(storage) if err != nil { return nil, err } - lp := &lockingPolicy{ + lp = &lockingPolicy{ policy: policy, } p.cache[name] = lp