Remove some instances of potential recursive locking (#6548)
This commit is contained in:
parent
991373c969
commit
9f0a6edfcb
|
@ -423,7 +423,7 @@ func NewClient(c *Config) (*Client, error) {
|
|||
}
|
||||
|
||||
if namespace := os.Getenv(EnvVaultNamespace); namespace != "" {
|
||||
client.SetNamespace(namespace)
|
||||
client.setNamespace(namespace)
|
||||
}
|
||||
|
||||
return client, nil
|
||||
|
@ -535,7 +535,10 @@ func (c *Client) SetMFACreds(creds []string) {
|
|||
func (c *Client) SetNamespace(namespace string) {
|
||||
c.modifyLock.Lock()
|
||||
defer c.modifyLock.Unlock()
|
||||
c.setNamespace(namespace)
|
||||
}
|
||||
|
||||
func (c *Client) setNamespace(namespace string) {
|
||||
if c.headers == nil {
|
||||
c.headers = make(http.Header)
|
||||
}
|
||||
|
|
|
@ -308,17 +308,26 @@ func (b *AESGCMBarrier) ReloadMasterKey(ctx context.Context) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
defer memzero(out.Value)
|
||||
// Grab write lock and refetch
|
||||
b.l.Lock()
|
||||
defer b.l.Unlock()
|
||||
|
||||
out, err = b.lockSwitchedGet(ctx, masterKeyPath, false)
|
||||
if err != nil {
|
||||
return errwrap.Wrapf("failed to read master key path: {{err}}", err)
|
||||
}
|
||||
|
||||
if out == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Deserialize the master key
|
||||
key, err := DeserializeKey(out.Value)
|
||||
memzero(out.Value)
|
||||
if err != nil {
|
||||
return errwrap.Wrapf("failed to deserialize key: {{err}}", err)
|
||||
}
|
||||
|
||||
b.l.Lock()
|
||||
defer b.l.Unlock()
|
||||
|
||||
// Check if the master key is the same
|
||||
if subtle.ConstantTimeCompare(b.keyring.MasterKey(), key.Value) == 1 {
|
||||
return nil
|
||||
|
@ -499,8 +508,8 @@ func (b *AESGCMBarrier) Rotate(ctx context.Context) (uint32, error) {
|
|||
// CreateUpgrade creates an upgrade path key to the given term from the previous term
|
||||
func (b *AESGCMBarrier) CreateUpgrade(ctx context.Context, term uint32) error {
|
||||
b.l.RLock()
|
||||
defer b.l.RUnlock()
|
||||
if b.sealed {
|
||||
b.l.RUnlock()
|
||||
return ErrBarrierSealed
|
||||
}
|
||||
|
||||
|
@ -509,6 +518,7 @@ func (b *AESGCMBarrier) CreateUpgrade(ctx context.Context, term uint32) error {
|
|||
buf, err := termKey.Serialize()
|
||||
defer memzero(buf)
|
||||
if err != nil {
|
||||
b.l.RUnlock()
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -516,11 +526,13 @@ func (b *AESGCMBarrier) CreateUpgrade(ctx context.Context, term uint32) error {
|
|||
prevTerm := term - 1
|
||||
primary, err := b.aeadForTerm(prevTerm)
|
||||
if err != nil {
|
||||
b.l.RUnlock()
|
||||
return err
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("%s%d", keyringUpgradePrefix, prevTerm)
|
||||
value, err := b.encrypt(key, prevTerm, primary, buf)
|
||||
b.l.RUnlock()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -541,8 +553,8 @@ func (b *AESGCMBarrier) DestroyUpgrade(ctx context.Context, term uint32) error {
|
|||
// CheckUpgrade looks for an upgrade to the current term and installs it
|
||||
func (b *AESGCMBarrier) CheckUpgrade(ctx context.Context) (bool, uint32, error) {
|
||||
b.l.RLock()
|
||||
defer b.l.RUnlock()
|
||||
if b.sealed {
|
||||
b.l.RUnlock()
|
||||
return false, 0, ErrBarrierSealed
|
||||
}
|
||||
|
||||
|
@ -551,30 +563,48 @@ func (b *AESGCMBarrier) CheckUpgrade(ctx context.Context) (bool, uint32, error)
|
|||
|
||||
// Check for an upgrade key
|
||||
upgrade := fmt.Sprintf("%s%d", keyringUpgradePrefix, activeTerm)
|
||||
entry, err := b.Get(ctx, upgrade)
|
||||
entry, err := b.lockSwitchedGet(ctx, upgrade, false)
|
||||
if err != nil {
|
||||
b.l.RUnlock()
|
||||
return false, 0, err
|
||||
}
|
||||
|
||||
// Nothing to do if no upgrade
|
||||
if entry == nil {
|
||||
b.l.RUnlock()
|
||||
return false, 0, nil
|
||||
}
|
||||
|
||||
defer memzero(entry.Value)
|
||||
|
||||
// Deserialize the key
|
||||
key, err := DeserializeKey(entry.Value)
|
||||
if err != nil {
|
||||
return false, 0, err
|
||||
}
|
||||
|
||||
// Upgrade from read lock to write lock
|
||||
b.l.RUnlock()
|
||||
defer b.l.RLock()
|
||||
b.l.Lock()
|
||||
defer b.l.Unlock()
|
||||
|
||||
// Validate base cases and refetch values again
|
||||
|
||||
if b.sealed {
|
||||
return false, 0, ErrBarrierSealed
|
||||
}
|
||||
|
||||
activeTerm = b.keyring.ActiveTerm()
|
||||
|
||||
upgrade = fmt.Sprintf("%s%d", keyringUpgradePrefix, activeTerm)
|
||||
entry, err = b.lockSwitchedGet(ctx, upgrade, false)
|
||||
if err != nil {
|
||||
return false, 0, err
|
||||
}
|
||||
|
||||
if entry == nil {
|
||||
return false, 0, nil
|
||||
}
|
||||
|
||||
// Deserialize the key
|
||||
key, err := DeserializeKey(entry.Value)
|
||||
memzero(entry.Value)
|
||||
if err != nil {
|
||||
return false, 0, err
|
||||
}
|
||||
|
||||
// Update the keyring
|
||||
newKeyring, err := b.keyring.AddKey(key)
|
||||
if err != nil {
|
||||
|
@ -692,25 +722,39 @@ func (b *AESGCMBarrier) Put(ctx context.Context, entry *logical.StorageEntry) er
|
|||
|
||||
// Get is used to fetch an entry
|
||||
func (b *AESGCMBarrier) Get(ctx context.Context, key string) (*logical.StorageEntry, error) {
|
||||
return b.lockSwitchedGet(ctx, key, true)
|
||||
}
|
||||
|
||||
func (b *AESGCMBarrier) lockSwitchedGet(ctx context.Context, key string, getLock bool) (*logical.StorageEntry, error) {
|
||||
defer metrics.MeasureSince([]string{"barrier", "get"}, time.Now())
|
||||
b.l.RLock()
|
||||
if getLock {
|
||||
b.l.RLock()
|
||||
}
|
||||
if b.sealed {
|
||||
b.l.RUnlock()
|
||||
if getLock {
|
||||
b.l.RUnlock()
|
||||
}
|
||||
return nil, ErrBarrierSealed
|
||||
}
|
||||
|
||||
// Read the key from the backend
|
||||
pe, err := b.backend.Get(ctx, key)
|
||||
if err != nil {
|
||||
b.l.RUnlock()
|
||||
if getLock {
|
||||
b.l.RUnlock()
|
||||
}
|
||||
return nil, err
|
||||
} else if pe == nil {
|
||||
b.l.RUnlock()
|
||||
if getLock {
|
||||
b.l.RUnlock()
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if len(pe.Value) < 4 {
|
||||
b.l.RUnlock()
|
||||
if getLock {
|
||||
b.l.RUnlock()
|
||||
}
|
||||
return nil, errors.New("invalid value")
|
||||
}
|
||||
|
||||
|
@ -721,7 +765,9 @@ func (b *AESGCMBarrier) Get(ctx context.Context, key string) (*logical.StorageEn
|
|||
// It is expensive to do this first but it is not a
|
||||
// normal case that this won't match
|
||||
gcm, err := b.aeadForTerm(term)
|
||||
b.l.RUnlock()
|
||||
if getLock {
|
||||
b.l.RUnlock()
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue