Refactor convergent encryption to make specifying a nonce in addition to context possible

This commit is contained in:
Jeff Mitchell 2016-08-05 17:52:44 -04:00
parent 8c209dd0d6
commit 8b1d47037e
9 changed files with 152 additions and 48 deletions

View File

@ -601,6 +601,7 @@ func TestConvergentEncryption(t *testing.T) {
Data: map[string]interface{}{
"derived": false,
"convergent_encryption": true,
"context_as_nonce": true,
},
}
@ -619,6 +620,7 @@ func TestConvergentEncryption(t *testing.T) {
req.Data = map[string]interface{}{
"derived": true,
"convergent_encryption": true,
"context_as_nonce": true,
}
resp, err = b.HandleRequest(req)

View File

@ -105,42 +105,42 @@ func (lm *lockManager) UnlockPolicy(lock *sync.RWMutex, lockType bool) {
// is needed (for instance, for an upgrade/migration), give up the read lock,
// call again with an exclusive lock, then swap back out for a read lock.
func (lm *lockManager) GetPolicyShared(storage logical.Storage, name string) (*Policy, *sync.RWMutex, error) {
p, lock, _, err := lm.getPolicyCommon(storage, name, false, false, false, shared)
p, lock, _, err := lm.getPolicyCommon(storage, name, false, false, false, false, shared)
if err == nil ||
(err != nil && err != errNeedExclusiveLock) {
return p, lock, err
}
// Try again while asking for an exlusive lock
p, lock, _, err = lm.getPolicyCommon(storage, name, false, false, false, exclusive)
p, lock, _, err = lm.getPolicyCommon(storage, name, false, false, false, false, exclusive)
if err != nil || p == nil || lock == nil {
return p, lock, err
}
lock.Unlock()
p, lock, _, err = lm.getPolicyCommon(storage, name, false, false, false, shared)
p, lock, _, err = lm.getPolicyCommon(storage, name, false, false, false, false, shared)
return p, lock, err
}
// Get the policy with an exclusive lock
func (lm *lockManager) GetPolicyExclusive(storage logical.Storage, name string) (*Policy, *sync.RWMutex, error) {
p, lock, _, err := lm.getPolicyCommon(storage, name, false, false, false, exclusive)
p, lock, _, err := lm.getPolicyCommon(storage, name, false, false, false, false, exclusive)
return p, lock, err
}
// Get the policy with a read lock; if it returns that an exclusive lock is
// needed, retry. If successful, call one more time to get a read lock and
// return the value.
func (lm *lockManager) GetPolicyUpsert(storage logical.Storage, name string, derived bool, convergent bool) (*Policy, *sync.RWMutex, bool, error) {
p, lock, _, err := lm.getPolicyCommon(storage, name, true, derived, convergent, shared)
func (lm *lockManager) GetPolicyUpsert(storage logical.Storage, name string, derived, convergent, contextAsNonce bool) (*Policy, *sync.RWMutex, bool, error) {
p, lock, _, err := lm.getPolicyCommon(storage, name, true, derived, convergent, contextAsNonce, shared)
if err == nil ||
(err != nil && err != errNeedExclusiveLock) {
return p, lock, false, err
}
// Try again while asking for an exlusive lock
p, lock, upserted, err := lm.getPolicyCommon(storage, name, true, derived, convergent, exclusive)
p, lock, upserted, err := lm.getPolicyCommon(storage, name, true, derived, convergent, contextAsNonce, exclusive)
if err != nil || p == nil || lock == nil {
return p, lock, upserted, err
}
@ -148,14 +148,14 @@ func (lm *lockManager) GetPolicyUpsert(storage logical.Storage, name string, der
lock.Unlock()
// Now get a shared lock for the return, but preserve the value of upsert
p, lock, _, err = lm.getPolicyCommon(storage, name, true, derived, convergent, shared)
p, lock, _, err = lm.getPolicyCommon(storage, name, true, derived, convergent, contextAsNonce, shared)
return p, lock, upserted, err
}
// When the function returns, a lock will be held on the policy if err == nil.
// It is the caller's responsibility to unlock.
func (lm *lockManager) getPolicyCommon(storage logical.Storage, name string, upsert, derived, convergent, lockType bool) (*Policy, *sync.RWMutex, bool, error) {
func (lm *lockManager) getPolicyCommon(storage logical.Storage, name string, upsert, derived, convergent, contextAsNonce, lockType bool) (*Policy, *sync.RWMutex, bool, error) {
lock := lm.policyLock(name, lockType)
var p *Policy
@ -204,6 +204,8 @@ func (lm *lockManager) getPolicyCommon(storage logical.Storage, name string, ups
if derived {
p.KDFMode = kdfMode
p.ConvergentEncryption = convergent
p.ContextAsNonce = new(bool)
*p.ContextAsNonce = contextAsNonce
}
err = p.rotate(storage)

View File

@ -100,7 +100,7 @@ func (b *backend) pathDatakeyWrite(
return nil, err
}
ciphertext, err := p.Encrypt(context, base64.StdEncoding.EncodeToString(newKey))
ciphertext, err := p.Encrypt(context, nil, base64.StdEncoding.EncodeToString(newKey))
if err != nil {
switch err.(type) {
case errutil.UserError:

View File

@ -27,6 +27,11 @@ func (b *backend) pathDecrypt() *framework.Path {
Type: framework.TypeString,
Description: "Context for key derivation. Required for derived keys.",
},
"nonce": &framework.FieldSchema{
Type: framework.TypeString,
Description: "Nonce for when convergent encryption is used and the context is not used as the nonce",
},
},
Callbacks: map[logical.Operation]framework.OperationFunc{
@ -46,17 +51,28 @@ func (b *backend) pathDecryptWrite(
return logical.ErrorResponse("missing ciphertext to decrypt"), logical.ErrInvalidRequest
}
var err error
// Decode the context if any
contextRaw := d.Get("context").(string)
var context []byte
if len(contextRaw) != 0 {
var err error
context, err = base64.StdEncoding.DecodeString(contextRaw)
if err != nil {
return logical.ErrorResponse("failed to decode context as base64"), logical.ErrInvalidRequest
}
}
// Decode the nonce if any
nonceRaw := d.Get("nonce").(string)
var nonce []byte
if len(nonceRaw) != 0 {
nonce, err = base64.StdEncoding.DecodeString(nonceRaw)
if err != nil {
return logical.ErrorResponse("failed to decode nonce as base64"), logical.ErrInvalidRequest
}
}
// Get the policy
p, lock, err := b.lm.GetPolicyShared(req.Storage, name)
if lock != nil {
@ -69,7 +85,7 @@ func (b *backend) pathDecryptWrite(
return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest
}
plaintext, err := p.Decrypt(context, ciphertext)
plaintext, err := p.Decrypt(context, nonce, ciphertext)
if err != nil {
switch err.(type) {
case errutil.UserError:

View File

@ -28,6 +28,11 @@ func (b *backend) pathEncrypt() *framework.Path {
Type: framework.TypeString,
Description: "Context for key derivation. Required for derived keys.",
},
"nonce": &framework.FieldSchema{
Type: framework.TypeString,
Description: "Nonce for when convergent encryption is used and the context is not used as the nonce",
},
},
Callbacks: map[logical.Operation]framework.OperationFunc{
@ -63,10 +68,11 @@ func (b *backend) pathEncryptWrite(
return logical.ErrorResponse("missing plaintext to encrypt"), logical.ErrInvalidRequest
}
var err error
// Decode the context if any
contextRaw := d.Get("context").(string)
var context []byte
var err error
if len(contextRaw) != 0 {
context, err = base64.StdEncoding.DecodeString(contextRaw)
if err != nil {
@ -74,12 +80,22 @@ func (b *backend) pathEncryptWrite(
}
}
// Decode the nonce if any
nonceRaw := d.Get("nonce").(string)
var nonce []byte
if len(nonceRaw) != 0 {
nonce, err = base64.StdEncoding.DecodeString(nonceRaw)
if err != nil {
return logical.ErrorResponse("failed to decode nonce as base64"), logical.ErrInvalidRequest
}
}
// Get the policy
var p *Policy
var lock *sync.RWMutex
var upserted bool
if req.Operation == logical.CreateOperation {
p, lock, upserted, err = b.lm.GetPolicyUpsert(req.Storage, name, len(context) != 0, false)
p, lock, upserted, err = b.lm.GetPolicyUpsert(req.Storage, name, len(context) != 0, false, false)
} else {
p, lock, err = b.lm.GetPolicyShared(req.Storage, name)
}
@ -93,7 +109,7 @@ func (b *backend) pathEncryptWrite(
return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest
}
ciphertext, err := p.Encrypt(context, value)
ciphertext, err := p.Encrypt(context, nonce, value)
if err != nil {
switch err.(type) {
case errutil.UserError:

View File

@ -18,22 +18,41 @@ func (b *backend) pathKeys() *framework.Path {
},
"derived": &framework.FieldSchema{
Type: framework.TypeBool,
Description: "Enables key derivation mode. This allows for per-transaction unique keys",
Type: framework.TypeBool,
Description: `Enables key derivation mode. This
allows for per-transaction unique keys.`,
},
"convergent_encryption": &framework.FieldSchema{
Type: framework.TypeBool,
Description: `Whether to use convergent encryption.
Description: `Whether to support convergent encryption.
This is only supported when using a key with
key derivation enabled and will require all
context values to be 96 bits (12 bytes) when
base64-decoded. This mode ensures that when
the same context is supplied, the same
ciphertext is emitted from the encryption
function. It is *very important* when using
this mode that you ensure that all contexts
are *globally unique*. Failing to do so will
requests to carry both a context and 96-bit
(12-byte) nonce, unless the "context_as_nonce"
feature is also enabled. The given nonce will
be used in place of a randomly generated nonce.
As a result, when the same context and nonce
(or context, if "context_as_nonce" is enabled)
are supplied, the same ciphertext is emitted
from the encryption function. It is *very
important* when using this mode that you ensure
that all nonces are unique for a given context,
or, when using "context_as_nonce", that all
contexts are unique for a given key. Failing to
do so will severely impact the ciphertext's
security.`,
},
"context_as_nonce": &framework.FieldSchema{
Type: framework.TypeBool,
Description: `Whether to use the context value as the
nonce in the convergent encryption operation
mode. If set true, the user will have to
supply a 96-bit (12-byte) context value.
It is *very important* when using this
mode that you ensure that all contexts are
*globally unique*. Failing to do so will
severely impact the security of the key.`,
},
},
@ -54,12 +73,13 @@ func (b *backend) pathPolicyWrite(
name := d.Get("name").(string)
derived := d.Get("derived").(bool)
convergent := d.Get("convergent_encryption").(bool)
contextAsNonce := d.Get("context_as_nonce").(bool)
if !derived && convergent {
return logical.ErrorResponse("convergent encryption requires derivation to be enabled"), nil
}
p, lock, upserted, err := b.lm.GetPolicyUpsert(req.Storage, name, derived, convergent)
p, lock, upserted, err := b.lm.GetPolicyUpsert(req.Storage, name, derived, convergent, contextAsNonce)
if lock != nil {
defer lock.RUnlock()
}
@ -107,6 +127,9 @@ func (b *backend) pathPolicyRead(
if p.Derived {
resp.Data["kdf_mode"] = p.KDFMode
resp.Data["convergent_encryption"] = p.ConvergentEncryption
if p.ContextAsNonce != nil {
resp.Data["context_as_nonce"] = *p.ContextAsNonce
}
}
retKeys := map[string]int64{}

View File

@ -27,6 +27,11 @@ func (b *backend) pathRewrap() *framework.Path {
Type: framework.TypeString,
Description: "Context for key derivation. Required for derived keys.",
},
"nonce": &framework.FieldSchema{
Type: framework.TypeString,
Description: "Nonce for when convergent encryption is used and the context is not used as the nonce",
},
},
Callbacks: map[logical.Operation]framework.OperationFunc{
@ -47,17 +52,28 @@ func (b *backend) pathRewrapWrite(
return logical.ErrorResponse("missing ciphertext to decrypt"), logical.ErrInvalidRequest
}
var err error
// Decode the context if any
contextRaw := d.Get("context").(string)
var context []byte
if len(contextRaw) != 0 {
var err error
context, err = base64.StdEncoding.DecodeString(contextRaw)
if err != nil {
return logical.ErrorResponse("failed to decode context as base64"), logical.ErrInvalidRequest
}
}
// Decode the nonce if any
nonceRaw := d.Get("nonce").(string)
var nonce []byte
if len(nonceRaw) != 0 {
nonce, err = base64.StdEncoding.DecodeString(nonceRaw)
if err != nil {
return logical.ErrorResponse("failed to decode nonce as base64"), logical.ErrInvalidRequest
}
}
// Get the policy
p, lock, err := b.lm.GetPolicyShared(req.Storage, name)
if lock != nil {
@ -71,7 +87,7 @@ func (b *backend) pathRewrapWrite(
return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest
}
plaintext, err := p.Decrypt(context, value)
plaintext, err := p.Decrypt(context, nonce, value)
if err != nil {
switch err.(type) {
case errutil.UserError:
@ -87,7 +103,7 @@ func (b *backend) pathRewrapWrite(
return nil, fmt.Errorf("empty plaintext returned during rewrap")
}
ciphertext, err := p.Encrypt(context, plaintext)
ciphertext, err := p.Encrypt(context, nonce, plaintext)
if err != nil {
switch err.(type) {
case errutil.UserError:

View File

@ -72,6 +72,7 @@ type Policy struct {
Derived bool `json:"derived"`
KDFMode string `json:"kdf_mode"`
ConvergentEncryption bool `json:"convergent_encryption"`
ContextAsNonce *bool `json:"context_as_nonce"`
// The minimum version of the key allowed to be used
// for decryption
@ -259,6 +260,10 @@ func (p *Policy) needsUpgrade() bool {
return true
}
if p.ConvergentEncryption && p.ContextAsNonce == nil {
return true
}
return false
}
@ -288,6 +293,14 @@ func (p *Policy) upgrade(storage logical.Storage) error {
persistNeeded = true
}
// Originally the context-as-nonce mode was the only mode, so keep that
// behavior if convergent encryption is already in use
if p.ConvergentEncryption && p.ContextAsNonce == nil {
p.ContextAsNonce = new(bool)
*p.ContextAsNonce = true
persistNeeded = true
}
if persistNeeded {
err := p.Persist(storage)
if err != nil {
@ -307,10 +320,6 @@ func (p *Policy) DeriveKey(context []byte, ver int) ([]byte, error) {
return nil, errutil.InternalError{Err: "unable to access the key; no key versions found"}
}
if p.LatestVersion == 0 {
return nil, errutil.InternalError{Err: "unable to access the key; no key versions found"}
}
if ver <= 0 || ver > p.LatestVersion {
return nil, errutil.UserError{Err: "invalid key version"}
}
@ -335,7 +344,7 @@ func (p *Policy) DeriveKey(context []byte, ver int) ([]byte, error) {
}
}
func (p *Policy) Encrypt(context []byte, value string) (string, error) {
func (p *Policy) Encrypt(context, nonce []byte, value string) (string, error) {
// Decode the plaintext value
plaintext, err := base64.StdEncoding.DecodeString(value)
if err != nil {
@ -367,15 +376,20 @@ func (p *Policy) Encrypt(context []byte, value string) (string, error) {
return "", errutil.InternalError{Err: err.Error()}
}
if p.ConvergentEncryption && len(context) != gcm.NonceSize() {
return "", errutil.UserError{Err: fmt.Sprintf("base64-decoded context must be %d bytes long when using convergent encryption with this key", gcm.NonceSize())}
}
// Compute random nonce
var nonce []byte
if p.ConvergentEncryption {
nonce = context
if *p.ContextAsNonce {
if len(context) != gcm.NonceSize() {
return "", errutil.UserError{Err: fmt.Sprintf("base64-decoded context must be %d bytes long when using convergent encryption with context-as-nonce with this key", gcm.NonceSize())}
}
nonce = context
} else if len(nonce) != gcm.NonceSize() {
return "", errutil.UserError{Err: fmt.Sprintf("base64-decoded nonce must be %d bytes long when using convergent encryption with this key", gcm.NonceSize())}
}
} else {
// Compute random nonce
nonce = make([]byte, gcm.NonceSize())
_, err = rand.Read(nonce)
if err != nil {
@ -387,7 +401,10 @@ func (p *Policy) Encrypt(context []byte, value string) (string, error) {
out := gcm.Seal(nil, nonce, plaintext, nil)
// Place the encrypted data after the nonce
full := append(nonce, out...)
full := out
if !p.ConvergentEncryption {
full = append(nonce, out...)
}
// Convert to base64
encoded := base64.StdEncoding.EncodeToString(full)
@ -398,12 +415,16 @@ func (p *Policy) Encrypt(context []byte, value string) (string, error) {
return encoded, nil
}
func (p *Policy) Decrypt(context []byte, value string) (string, error) {
func (p *Policy) Decrypt(context, nonce []byte, value string) (string, error) {
// Verify the prefix
if !strings.HasPrefix(value, "vault:v") {
return "", errutil.UserError{Err: "invalid ciphertext: no prefix"}
}
if p.ConvergentEncryption && !*p.ContextAsNonce && (nonce == nil || len(nonce) == 0) {
return "", errutil.UserError{Err: "invalid convergent nonce supplied"}
}
splitVerCiphertext := strings.SplitN(strings.TrimPrefix(value, "vault:v"), ":", 2)
if len(splitVerCiphertext) != 2 {
return "", errutil.UserError{Err: "invalid ciphertext: wrong number of fields"}
@ -460,8 +481,16 @@ func (p *Policy) Decrypt(context []byte, value string) (string, error) {
}
// Extract the nonce and ciphertext
nonce := decoded[:gcm.NonceSize()]
ciphertext := decoded[gcm.NonceSize():]
var ciphertext []byte
if p.ConvergentEncryption {
if *p.ContextAsNonce {
nonce = context
}
ciphertext = decoded
} else {
nonce = decoded[:gcm.NonceSize()]
ciphertext = decoded[gcm.NonceSize():]
}
// Verify and Decrypt
plain, err := gcm.Open(nil, nonce, ciphertext, nil)

View File

@ -22,7 +22,7 @@ func Test_KeyUpgrade(t *testing.T) {
func testKeyUpgradeCommon(t *testing.T, lm *lockManager) {
storage := &logical.InmemStorage{}
p, lock, upserted, err := lm.GetPolicyUpsert(storage, "test", false, false)
p, lock, upserted, err := lm.GetPolicyUpsert(storage, "test", false, false, false)
if lock != nil {
defer lock.RUnlock()
}
@ -68,7 +68,7 @@ func testArchivingUpgradeCommon(t *testing.T, lm *lockManager) {
storage := &logical.InmemStorage{}
p, lock, _, err := lm.GetPolicyUpsert(storage, "test", false, false)
p, lock, _, err := lm.GetPolicyUpsert(storage, "test", false, false, false)
if err != nil {
t.Fatal(err)
}
@ -198,7 +198,7 @@ func testArchivingCommon(t *testing.T, lm *lockManager) {
storage := &logical.InmemStorage{}
p, lock, _, err := lm.GetPolicyUpsert(storage, "test", false, false)
p, lock, _, err := lm.GetPolicyUpsert(storage, "test", false, false, false)
if lock != nil {
defer lock.RUnlock()
}