Move logic around a bit to avoid holding locks when not necessary (#5277)

Also, ensure we are error checking the rand call
This commit is contained in:
Jeff Mitchell 2018-09-05 11:49:32 -04:00 committed by GitHub
parent cb4fd4ff4b
commit c9e2cd93e8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 113 additions and 64 deletions

View File

@ -7,6 +7,7 @@ import (
"crypto/rand"
"crypto/subtle"
"encoding/binary"
"errors"
"fmt"
"strings"
"sync"
@ -154,7 +155,10 @@ func (b *AESGCMBarrier) persistKeyring(ctx context.Context, keyring *Keyring) er
}
// Encrypt the barrier init value
value := b.encrypt(keyringPath, initialKeyTerm, gcm, keyringBuf)
value, err := b.encrypt(keyringPath, initialKeyTerm, gcm, keyringBuf)
if err != nil {
return err
}
// Create the keyring physical entry
pe := &physical.Entry{
@ -183,7 +187,10 @@ func (b *AESGCMBarrier) persistKeyring(ctx context.Context, keyring *Keyring) er
if err != nil {
return err
}
value = b.encrypt(masterKeyPath, activeKey.Term, aead, keyBuf)
value, err = b.encrypt(masterKeyPath, activeKey.Term, aead, keyBuf)
if err != nil {
return err
}
// Update the masterKeyPath for standby instances
pe = &physical.Entry{
@ -253,7 +260,13 @@ func (b *AESGCMBarrier) ReloadKeyring(ctx context.Context) error {
// Ensure that the keyring exists. This should never happen,
// and indicates something really bad has happened.
if out == nil {
return fmt.Errorf("keyring unexpectedly missing")
return errors.New("keyring unexpectedly missing")
}
// Verify the term is always just one
term := binary.BigEndian.Uint32(out.Value[:4])
if term != initialKeyTerm {
return errors.New("term mis-match")
}
// Decrypt the barrier init key
@ -340,6 +353,12 @@ func (b *AESGCMBarrier) Unseal(ctx context.Context, key []byte) error {
return errwrap.Wrapf("failed to check for keyring: {{err}}", err)
}
if out != nil {
// Verify the term is always just one
term := binary.BigEndian.Uint32(out.Value[:4])
if term != initialKeyTerm {
return errors.New("term mis-match")
}
// Decrypt the barrier init key
plain, err := b.decrypt(keyringPath, gcm, out.Value)
defer memzero(plain)
@ -371,6 +390,12 @@ func (b *AESGCMBarrier) Unseal(ctx context.Context, key []byte) error {
return ErrBarrierNotInit
}
// Verify the term is always just one
term := binary.BigEndian.Uint32(out.Value[:4])
if term != initialKeyTerm {
return errors.New("term mis-match")
}
// Decrypt the barrier init key
plain, err := b.decrypt(barrierInitPath, gcm, out.Value)
if err != nil {
@ -494,7 +519,10 @@ func (b *AESGCMBarrier) CreateUpgrade(ctx context.Context, term uint32) error {
}
key := fmt.Sprintf("%s%d", keyringUpgradePrefix, prevTerm)
value := b.encrypt(key, prevTerm, primary, buf)
value, err := b.encrypt(key, prevTerm, primary, buf)
if err != nil {
return err
}
// Create upgrade key
pe := &physical.Entry{
Key: key,
@ -637,20 +665,25 @@ func (b *AESGCMBarrier) updateMasterKeyCommon(key []byte) (*Keyring, error) {
func (b *AESGCMBarrier) Put(ctx context.Context, entry *Entry) error {
defer metrics.MeasureSince([]string{"barrier", "put"}, time.Now())
b.l.RLock()
defer b.l.RUnlock()
if b.sealed {
b.l.RUnlock()
return ErrBarrierSealed
}
term := b.keyring.ActiveTerm()
primary, err := b.aeadForTerm(term)
b.l.RUnlock()
if err != nil {
return err
}
value, err := b.encrypt(entry.Key, term, primary, entry.Value)
if err != nil {
return err
}
pe := &physical.Entry{
Key: entry.Key,
Value: b.encrypt(entry.Key, term, primary, entry.Value),
Value: value,
SealWrap: entry.SealWrap,
}
return b.backend.Put(ctx, pe)
@ -660,21 +693,38 @@ func (b *AESGCMBarrier) Put(ctx context.Context, entry *Entry) error {
func (b *AESGCMBarrier) Get(ctx context.Context, key string) (*Entry, error) {
defer metrics.MeasureSince([]string{"barrier", "get"}, time.Now())
b.l.RLock()
defer b.l.RUnlock()
if b.sealed {
b.l.RUnlock()
return nil, ErrBarrierSealed
}
// Read the key from the backend
pe, err := b.backend.Get(ctx, key)
if err != nil {
b.l.RUnlock()
return nil, err
} else if pe == nil {
b.l.RUnlock()
return nil, nil
}
// Verify the term
term := binary.BigEndian.Uint32(pe.Value[:4])
// Get the GCM by term
// It is expensive to do this first but it is not a
// normal case that this won't match
gcm, err := b.aeadForTerm(term)
b.l.RUnlock()
if err != nil {
return nil, err
}
if gcm == nil {
return nil, fmt.Errorf("no decryption key available for term %d", term)
}
// Decrypt the ciphertext
plain, err := b.decryptKeyring(key, pe.Value)
plain, err := b.decrypt(key, gcm, pe.Value)
if err != nil {
return nil, errwrap.Wrapf("decryption failed: {{err}}", err)
}
@ -692,8 +742,9 @@ func (b *AESGCMBarrier) Get(ctx context.Context, key string) (*Entry, error) {
func (b *AESGCMBarrier) Delete(ctx context.Context, key string) error {
defer metrics.MeasureSince([]string{"barrier", "delete"}, time.Now())
b.l.RLock()
defer b.l.RUnlock()
if b.sealed {
sealed := b.sealed
b.l.RUnlock()
if sealed {
return ErrBarrierSealed
}
@ -705,8 +756,9 @@ func (b *AESGCMBarrier) Delete(ctx context.Context, key string) error {
func (b *AESGCMBarrier) List(ctx context.Context, prefix string) ([]string, error) {
defer metrics.MeasureSince([]string{"barrier", "list"}, time.Now())
b.l.RLock()
defer b.l.RUnlock()
if b.sealed {
sealed := b.sealed
b.l.RUnlock()
if sealed {
return nil, ErrBarrierSealed
}
@ -765,7 +817,7 @@ func (b *AESGCMBarrier) aeadFromKey(key []byte) (cipher.AEAD, error) {
}
// encrypt is used to encrypt a value
func (b *AESGCMBarrier) encrypt(path string, term uint32, gcm cipher.AEAD, plain []byte) []byte {
func (b *AESGCMBarrier) encrypt(path string, term uint32, gcm cipher.AEAD, plain []byte) ([]byte, error) {
// Allocate the output buffer with room for tern, version byte,
// nonce, GCM tag and the plaintext
capacity := termSize + 1 + gcm.NonceSize() + gcm.Overhead() + len(plain)
@ -780,7 +832,13 @@ func (b *AESGCMBarrier) encrypt(path string, term uint32, gcm cipher.AEAD, plain
// Generate a random nonce
nonce := out[5 : 5+gcm.NonceSize()]
rand.Read(nonce)
n, err := rand.Read(nonce)
if err != nil {
return nil, err
}
if n != len(nonce) {
return nil, errors.New("unable to read enough random bytes to fill gcm nonce")
}
// Seal the output
switch b.currentAESGCMVersionByte {
@ -792,53 +850,16 @@ func (b *AESGCMBarrier) encrypt(path string, term uint32, gcm cipher.AEAD, plain
panic("Unknown AESGCM version")
}
return out
return out, nil
}
// decrypt is used to decrypt a value
// decrypt is used to decrypt a value using the keyring
func (b *AESGCMBarrier) decrypt(path string, gcm cipher.AEAD, cipher []byte) ([]byte, error) {
// Verify the term is always just one
term := binary.BigEndian.Uint32(cipher[:4])
if term != initialKeyTerm {
return nil, fmt.Errorf("term mis-match")
}
// Capture the parts
nonce := cipher[5 : 5+gcm.NonceSize()]
raw := cipher[5+gcm.NonceSize():]
out := make([]byte, 0, len(raw)-gcm.NonceSize())
// Verify the cipher byte and attempt to open
switch cipher[4] {
case AESGCMVersion1:
return gcm.Open(out, nonce, raw, nil)
case AESGCMVersion2:
return gcm.Open(out, nonce, raw, []byte(path))
default:
return nil, fmt.Errorf("version bytes mis-match")
}
}
// decryptKeyring is used to decrypt a value using the keyring
func (b *AESGCMBarrier) decryptKeyring(path string, cipher []byte) ([]byte, error) {
// Verify the term
term := binary.BigEndian.Uint32(cipher[:4])
// Get the GCM by term
// It is expensive to do this first but it is not a
// normal case that this won't match
gcm, err := b.aeadForTerm(term)
if err != nil {
return nil, err
}
if gcm == nil {
return nil, fmt.Errorf("no decryption key available for term %d", term)
}
nonce := cipher[5 : 5+gcm.NonceSize()]
raw := cipher[5+gcm.NonceSize():]
out := make([]byte, 0, len(raw)-gcm.NonceSize())
// Attempt to open
switch cipher[4] {
case AESGCMVersion1:
@ -860,13 +881,15 @@ func (b *AESGCMBarrier) Encrypt(ctx context.Context, key string, plaintext []byt
term := b.keyring.ActiveTerm()
primary, err := b.aeadForTerm(term)
b.l.RUnlock()
if err != nil {
b.l.RUnlock()
return nil, err
}
ciphertext := b.encrypt(key, term, primary, plaintext)
b.l.RUnlock()
ciphertext, err := b.encrypt(key, term, primary, plaintext)
if err != nil {
return nil, err
}
return ciphertext, nil
}
@ -878,14 +901,27 @@ func (b *AESGCMBarrier) Decrypt(ctx context.Context, key string, ciphertext []by
return nil, ErrBarrierSealed
}
// Decrypt the ciphertext
plain, err := b.decryptKeyring(key, ciphertext)
// Verify the term
term := binary.BigEndian.Uint32(ciphertext[:4])
// Get the GCM by term
// It is expensive to do this first but it is not a
// normal case that this won't match
gcm, err := b.aeadForTerm(term)
b.l.RUnlock()
if err != nil {
return nil, err
}
if gcm == nil {
return nil, fmt.Errorf("no decryption key available for term %d", term)
}
// Decrypt the ciphertext
plain, err := b.decrypt(key, gcm, ciphertext)
if err != nil {
b.l.RUnlock()
return nil, errwrap.Wrapf("decryption failed: {{err}}", err)
}
b.l.RUnlock()
return plain, nil
}

View File

@ -125,7 +125,10 @@ func TestAESGCMBarrier_BackwardsCompatible(t *testing.T) {
// Protect with master key
master, _ := b.GenerateKey()
gcm, _ := b.aeadFromKey(master)
value := b.encrypt(barrierInitPath, initialKeyTerm, gcm, buf)
value, err := b.encrypt(barrierInitPath, initialKeyTerm, gcm, buf)
if err != nil {
t.Fatal(err)
}
// Write to the physical backend
pe := &physical.Entry{
@ -136,9 +139,13 @@ func TestAESGCMBarrier_BackwardsCompatible(t *testing.T) {
// Create a fake key
gcm, _ = b.aeadFromKey(encrypt)
value, err = b.encrypt("test/foo", initialKeyTerm, gcm, []byte("test"))
if err != nil {
t.Fatal(err)
}
pe = &physical.Entry{
Key: "test/foo",
Value: b.encrypt("test/foo", initialKeyTerm, gcm, []byte("test")),
Value: value,
}
inm.Put(context.Background(), pe)
@ -429,8 +436,14 @@ func TestEncrypt_Unique(t *testing.T) {
term := b.keyring.ActiveTerm()
primary, _ := b.aeadForTerm(term)
first := b.encrypt("test", term, primary, entry.Value)
second := b.encrypt("test", term, primary, entry.Value)
first, err := b.encrypt("test", term, primary, entry.Value)
if err != nil {
t.Fatal(err)
}
second, err := b.encrypt("test", term, primary, entry.Value)
if err != nil {
t.Fatal(err)
}
if bytes.Equal(first, second) == true {
t.Fatalf("improper random seeding detected")