Remove some instances of potential recursive locking (#6548)

This commit is contained in:
Jeff Mitchell 2019-04-08 12:45:28 -04:00 committed by GitHub
parent 991373c969
commit 9f0a6edfcb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 72 additions and 23 deletions

View File

@ -423,7 +423,7 @@ func NewClient(c *Config) (*Client, error) {
}
if namespace := os.Getenv(EnvVaultNamespace); namespace != "" {
client.SetNamespace(namespace)
client.setNamespace(namespace)
}
return client, nil
@ -535,7 +535,10 @@ func (c *Client) SetMFACreds(creds []string) {
func (c *Client) SetNamespace(namespace string) {
c.modifyLock.Lock()
defer c.modifyLock.Unlock()
c.setNamespace(namespace)
}
func (c *Client) setNamespace(namespace string) {
if c.headers == nil {
c.headers = make(http.Header)
}

View File

@ -308,17 +308,26 @@ func (b *AESGCMBarrier) ReloadMasterKey(ctx context.Context) error {
return nil
}
defer memzero(out.Value)
// Grab write lock and refetch
b.l.Lock()
defer b.l.Unlock()
out, err = b.lockSwitchedGet(ctx, masterKeyPath, false)
if err != nil {
return errwrap.Wrapf("failed to read master key path: {{err}}", err)
}
if out == nil {
return nil
}
// Deserialize the master key
key, err := DeserializeKey(out.Value)
memzero(out.Value)
if err != nil {
return errwrap.Wrapf("failed to deserialize key: {{err}}", err)
}
b.l.Lock()
defer b.l.Unlock()
// Check if the master key is the same
if subtle.ConstantTimeCompare(b.keyring.MasterKey(), key.Value) == 1 {
return nil
@ -499,8 +508,8 @@ func (b *AESGCMBarrier) Rotate(ctx context.Context) (uint32, error) {
// CreateUpgrade creates an upgrade path key to the given term from the previous term
func (b *AESGCMBarrier) CreateUpgrade(ctx context.Context, term uint32) error {
b.l.RLock()
defer b.l.RUnlock()
if b.sealed {
b.l.RUnlock()
return ErrBarrierSealed
}
@ -509,6 +518,7 @@ func (b *AESGCMBarrier) CreateUpgrade(ctx context.Context, term uint32) error {
buf, err := termKey.Serialize()
defer memzero(buf)
if err != nil {
b.l.RUnlock()
return err
}
@ -516,11 +526,13 @@ func (b *AESGCMBarrier) CreateUpgrade(ctx context.Context, term uint32) error {
prevTerm := term - 1
primary, err := b.aeadForTerm(prevTerm)
if err != nil {
b.l.RUnlock()
return err
}
key := fmt.Sprintf("%s%d", keyringUpgradePrefix, prevTerm)
value, err := b.encrypt(key, prevTerm, primary, buf)
b.l.RUnlock()
if err != nil {
return err
}
@ -541,8 +553,8 @@ func (b *AESGCMBarrier) DestroyUpgrade(ctx context.Context, term uint32) error {
// CheckUpgrade looks for an upgrade to the current term and installs it
func (b *AESGCMBarrier) CheckUpgrade(ctx context.Context) (bool, uint32, error) {
b.l.RLock()
defer b.l.RUnlock()
if b.sealed {
b.l.RUnlock()
return false, 0, ErrBarrierSealed
}
@ -551,30 +563,48 @@ func (b *AESGCMBarrier) CheckUpgrade(ctx context.Context) (bool, uint32, error)
// Check for an upgrade key
upgrade := fmt.Sprintf("%s%d", keyringUpgradePrefix, activeTerm)
entry, err := b.Get(ctx, upgrade)
entry, err := b.lockSwitchedGet(ctx, upgrade, false)
if err != nil {
b.l.RUnlock()
return false, 0, err
}
// Nothing to do if no upgrade
if entry == nil {
b.l.RUnlock()
return false, 0, nil
}
defer memzero(entry.Value)
// Deserialize the key
key, err := DeserializeKey(entry.Value)
if err != nil {
return false, 0, err
}
// Upgrade from read lock to write lock
b.l.RUnlock()
defer b.l.RLock()
b.l.Lock()
defer b.l.Unlock()
// Validate base cases and refetch values again
if b.sealed {
return false, 0, ErrBarrierSealed
}
activeTerm = b.keyring.ActiveTerm()
upgrade = fmt.Sprintf("%s%d", keyringUpgradePrefix, activeTerm)
entry, err = b.lockSwitchedGet(ctx, upgrade, false)
if err != nil {
return false, 0, err
}
if entry == nil {
return false, 0, nil
}
// Deserialize the key
key, err := DeserializeKey(entry.Value)
memzero(entry.Value)
if err != nil {
return false, 0, err
}
// Update the keyring
newKeyring, err := b.keyring.AddKey(key)
if err != nil {
@ -692,25 +722,39 @@ func (b *AESGCMBarrier) Put(ctx context.Context, entry *logical.StorageEntry) er
// Get is used to fetch an entry
func (b *AESGCMBarrier) Get(ctx context.Context, key string) (*logical.StorageEntry, error) {
return b.lockSwitchedGet(ctx, key, true)
}
func (b *AESGCMBarrier) lockSwitchedGet(ctx context.Context, key string, getLock bool) (*logical.StorageEntry, error) {
defer metrics.MeasureSince([]string{"barrier", "get"}, time.Now())
b.l.RLock()
if getLock {
b.l.RLock()
}
if b.sealed {
b.l.RUnlock()
if getLock {
b.l.RUnlock()
}
return nil, ErrBarrierSealed
}
// Read the key from the backend
pe, err := b.backend.Get(ctx, key)
if err != nil {
b.l.RUnlock()
if getLock {
b.l.RUnlock()
}
return nil, err
} else if pe == nil {
b.l.RUnlock()
if getLock {
b.l.RUnlock()
}
return nil, nil
}
if len(pe.Value) < 4 {
b.l.RUnlock()
if getLock {
b.l.RUnlock()
}
return nil, errors.New("invalid value")
}
@ -721,7 +765,9 @@ func (b *AESGCMBarrier) Get(ctx context.Context, key string) (*logical.StorageEn
// It is expensive to do this first but it is not a
// normal case that this won't match
gcm, err := b.aeadForTerm(term)
b.l.RUnlock()
if getLock {
b.l.RUnlock()
}
if err != nil {
return nil, err
}