Add AD mode to Transit's AEAD ciphers (#17638)

* Allow passing AssociatedData factories in keysutil

This allows the high-level, algorithm-agnostic Encrypt/Decrypt with
Factory to pass in AssociatedData, and potentially take multiple
factories (to allow KMS keys to work). On AEAD ciphers with a relevant
factory, an AssociatedData factory will be used to populate the
AdditionalData field of the SymmetricOpts struct, using it in the AEAD
Seal process.

Signed-off-by: Alexander Scheel <alex.scheel@hashicorp.com>

* Add associated_data to Transit Encrypt/Decrypt API

This allows passing the associated_data (the last AD in AEAD) to
Transit's encrypt/decrypt when using an AEAD cipher (currently
aes128-gcm96, aes256-gcm96, and chacha20-poly1305). We err if this
parameter is passed on non-AEAD ciphers presently.

This associated data can be safely transited in plaintext, without risk
of modifications. In the event of tampering with either the ciphertext
or the associated data, decryption will fail.

Signed-off-by: Alexander Scheel <alex.scheel@hashicorp.com>

* Add changelog

Signed-off-by: Alexander Scheel <alex.scheel@hashicorp.com>

* Add to documentation

Signed-off-by: Alexander Scheel <alex.scheel@hashicorp.com>

Signed-off-by: Alexander Scheel <alex.scheel@hashicorp.com>
This commit is contained in:
Alexander Scheel 2022-10-24 13:41:02 -04:00 committed by GitHub
parent 73f9b13762
commit 09939f0ba9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 259 additions and 22 deletions

View File

@ -1717,3 +1717,115 @@ func TestTransit_AutoRotateKeys(t *testing.T) {
)
}
}
func TestTransit_AEAD(t *testing.T) {
testTransit_AEAD(t, "aes128-gcm96")
testTransit_AEAD(t, "aes256-gcm96")
testTransit_AEAD(t, "chacha20-poly1305")
}
func testTransit_AEAD(t *testing.T, keyType string) {
var resp *logical.Response
var err error
b, storage := createBackendWithStorage(t)
keyReq := &logical.Request{
Path: "keys/aead",
Operation: logical.UpdateOperation,
Data: map[string]interface{}{
"type": keyType,
},
Storage: storage,
}
resp, err = b.HandleRequest(context.Background(), keyReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: err: %v\nresp: %#v", err, resp)
}
plaintext := "dGhlIHF1aWNrIGJyb3duIGZveA==" // "the quick brown fox"
associated := "U3BoaW54IG9mIGJsYWNrIHF1YXJ0eiwganVkZ2UgbXkgdm93Lgo=" // "Sphinx of black quartz, judge my vow."
// Basic encrypt/decrypt should work.
encryptReq := &logical.Request{
Path: "encrypt/aead",
Operation: logical.UpdateOperation,
Storage: storage,
Data: map[string]interface{}{
"plaintext": plaintext,
},
}
resp, err = b.HandleRequest(context.Background(), encryptReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: err: %v\nresp: %#v", err, resp)
}
ciphertext1 := resp.Data["ciphertext"].(string)
decryptReq := &logical.Request{
Path: "decrypt/aead",
Operation: logical.UpdateOperation,
Storage: storage,
Data: map[string]interface{}{
"ciphertext": ciphertext1,
},
}
resp, err = b.HandleRequest(context.Background(), decryptReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: err: %v\nresp: %#v", err, resp)
}
decryptedPlaintext := resp.Data["plaintext"]
if plaintext != decryptedPlaintext {
t.Fatalf("bad: plaintext; expected: %q\nactual: %q", plaintext, decryptedPlaintext)
}
// Using associated as ciphertext should fail.
decryptReq.Data["ciphertext"] = associated
resp, err = b.HandleRequest(context.Background(), decryptReq)
if err == nil || (resp != nil && !resp.IsError()) {
t.Fatalf("bad expected error: err: %v\nresp: %#v", err, resp)
}
// Redoing the above with additional data should work.
encryptReq.Data["associated_data"] = associated
resp, err = b.HandleRequest(context.Background(), encryptReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: err: %v\nresp: %#v", err, resp)
}
ciphertext2 := resp.Data["ciphertext"].(string)
decryptReq.Data["ciphertext"] = ciphertext2
decryptReq.Data["associated_data"] = associated
resp, err = b.HandleRequest(context.Background(), decryptReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: err: %v\nresp: %#v", err, resp)
}
decryptedPlaintext = resp.Data["plaintext"]
if plaintext != decryptedPlaintext {
t.Fatalf("bad: plaintext; expected: %q\nactual: %q", plaintext, decryptedPlaintext)
}
// Removing the associated_data should break the decryption.
decryptReq.Data = map[string]interface{}{
"ciphertext": ciphertext2,
}
resp, err = b.HandleRequest(context.Background(), decryptReq)
if err == nil || (resp != nil && !resp.IsError()) {
t.Fatalf("bad expected error: err: %v\nresp: %#v", err, resp)
}
// Using a valid ciphertext with associated_data should also break the
// decryption.
decryptReq.Data["ciphertext"] = ciphertext1
decryptReq.Data["associated_data"] = associated
resp, err = b.HandleRequest(context.Background(), decryptReq)
if err == nil || (resp != nil && !resp.IsError()) {
t.Fatalf("bad expected error: err: %v\nresp: %#v", err, resp)
}
}

View File

@ -50,6 +50,7 @@ Base64 encoded nonce value used during encryption. Must be provided if
convergent encryption is enabled for this key and the key was generated with
Vault 0.6.1. Not required for keys created in 0.6.2+.`,
},
"partial_failure_response_code": {
Type: framework.TypeInt,
Description: `
@ -58,6 +59,17 @@ the HTTP response code is 400 (Bad Request). Some applications may want to trea
Providing the parameter returns the given response code integer instead of a 400 in this case. If all values fail
HTTP 400 is still returned.`,
},
"associated_data": {
Type: framework.TypeString,
Description: `
When using an AEAD cipher mode, such as AES-GCM, this parameter allows
passing associated data (AD/AAD) into the encryption function; this data
must be passed on subsequent decryption requests but can be transited in
plaintext. On successful decryption, both the ciphertext and the associated
data are attested not to have been tampered with.
`,
},
},
Callbacks: map[logical.Operation]framework.OperationFunc{
@ -90,9 +102,10 @@ func (b *backend) pathDecryptWrite(ctx context.Context, req *logical.Request, d
batchInputItems = make([]BatchRequestItem, 1)
batchInputItems[0] = BatchRequestItem{
Ciphertext: ciphertext,
Context: d.Get("context").(string),
Nonce: d.Get("nonce").(string),
Ciphertext: ciphertext,
Context: d.Get("context").(string),
Nonce: d.Get("nonce").(string),
AssociatedData: d.Get("associated_data").(string),
}
}
@ -155,7 +168,17 @@ func (b *backend) pathDecryptWrite(ctx context.Context, req *logical.Request, d
continue
}
plaintext, err := p.Decrypt(item.DecodedContext, item.DecodedNonce, item.Ciphertext)
var factory interface{}
if item.AssociatedData != "" {
if !p.Type.AssociatedDataSupported() {
batchResponseItems[i].Error = fmt.Sprintf("'[%d].associated_data' provided for non-AEAD cipher suite %v", i, p.Type.String())
continue
}
factory = AssocDataFactory{item.AssociatedData}
}
plaintext, err := p.DecryptWithFactory(item.DecodedContext, item.DecodedNonce, item.Ciphertext, factory)
if err != nil {
switch err.(type) {
case errutil.InternalError:

View File

@ -39,6 +39,9 @@ type BatchRequestItem struct {
// DecodedNonce is the base64 decoded version of Nonce
DecodedNonce []byte
// Associated Data for AEAD ciphers
AssociatedData string `json:"associated_data" struct:"associated_data" mapstructure:"associated_data"`
}
// EncryptBatchResponseItem represents a response item for batch processing
@ -55,6 +58,14 @@ type EncryptBatchResponseItem struct {
Error string `json:"error,omitempty" structs:"error" mapstructure:"error"`
}
type AssocDataFactory struct {
Encoded string
}
func (a AssocDataFactory) GetAssociatedData() ([]byte, error) {
return base64.StdEncoding.DecodeString(a.Encoded)
}
func (b *backend) pathEncrypt() *framework.Path {
return &framework.Path{
Pattern: "encrypt/" + framework.GenericNameRegex("name"),
@ -113,6 +124,7 @@ will severely impact the ciphertext's security.`,
Must be 0 (for latest) or a value greater than or equal
to the min_encryption_version configured on the key.`,
},
"partial_failure_response_code": {
Type: framework.TypeInt,
Description: `
@ -121,6 +133,17 @@ the HTTP response code is 400 (Bad Request). Some applications may want to trea
Providing the parameter returns the given response code integer instead of a 400 in this case. If all values fail
HTTP 400 is still returned.`,
},
"associated_data": {
Type: framework.TypeString,
Description: `
When using an AEAD cipher mode, such as AES-GCM, this parameter allows
passing associated data (AD/AAD) into the encryption function; this data
must be passed on subsequent decryption requests but can be transited in
plaintext. On successful decryption, both the ciphertext and the associated
data are attested not to have been tampered with.
`,
},
},
Callbacks: map[logical.Operation]framework.OperationFunc{
@ -229,6 +252,15 @@ func decodeBatchRequestItems(src interface{}, requirePlaintext bool, requireCiph
errs.Errors = append(errs.Errors, fmt.Sprintf("'[%d].key_version' expected type 'int', got unconvertible type '%T'", i, item["key_version"]))
}
}
if v, has := item["associated_data"]; has {
if !reflect.ValueOf(v).IsValid() {
} else if casted, ok := v.(string); ok {
(*dst)[i].AssociatedData = casted
} else {
errs.Errors = append(errs.Errors, fmt.Sprintf("'[%d].associated_data' expected type 'string', got unconvertible type '%T'", i, item["associated_data"]))
}
}
}
if len(errs.Errors) > 0 {
@ -279,10 +311,11 @@ func (b *backend) pathEncryptWrite(ctx context.Context, req *logical.Request, d
batchInputItems = make([]BatchRequestItem, 1)
batchInputItems[0] = BatchRequestItem{
Plaintext: valueRaw.(string),
Context: d.Get("context").(string),
Nonce: d.Get("nonce").(string),
KeyVersion: d.Get("key_version").(int),
Plaintext: valueRaw.(string),
Context: d.Get("context").(string),
Nonce: d.Get("nonce").(string),
KeyVersion: d.Get("key_version").(int),
AssociatedData: d.Get("associated_data").(string),
}
}
@ -393,7 +426,17 @@ func (b *backend) pathEncryptWrite(ctx context.Context, req *logical.Request, d
warnAboutNonceUsage = true
}
ciphertext, err := p.Encrypt(item.KeyVersion, item.DecodedContext, item.DecodedNonce, item.Plaintext)
var factory interface{}
if item.AssociatedData != "" {
if !p.Type.AssociatedDataSupported() {
batchResponseItems[i].Error = fmt.Sprintf("'[%d].associated_data' provided for non-AEAD cipher suite %v", i, p.Type.String())
continue
}
factory = AssocDataFactory{item.AssociatedData}
}
ciphertext, err := p.EncryptWithFactory(item.KeyVersion, item.DecodedContext, item.DecodedNonce, item.Plaintext, factory)
if err != nil {
switch err.(type) {
case errutil.InternalError:

3
changelog/17638.txt Normal file
View File

@ -0,0 +1,3 @@
```release-note:improvement
secrets/transit: Add associated_data parameter for additional authenticated data in AEAD ciphers
```

View File

@ -79,6 +79,10 @@ type AEADFactory interface {
GetAEAD(iv []byte) (cipher.AEAD, error)
}
type AssociatedDataFactory interface {
GetAssociatedData() ([]byte, error)
}
type RestoreInfo struct {
Time time.Time `json:"time"`
Version int `json:"version"`
@ -147,6 +151,14 @@ func (kt KeyType) DerivationSupported() bool {
return false
}
func (kt KeyType) AssociatedDataSupported() bool {
switch kt {
case KeyType_AES128_GCM96, KeyType_AES256_GCM96, KeyType_ChaCha20_Poly1305:
return true
}
return false
}
func (kt KeyType) String() string {
switch kt {
case KeyType_AES128_GCM96:
@ -844,6 +856,10 @@ func (p *Policy) Encrypt(ver int, context, nonce []byte, value string) (string,
}
func (p *Policy) Decrypt(context, nonce []byte, value string) (string, error) {
return p.DecryptWithFactory(context, nonce, value, nil)
}
func (p *Policy) DecryptWithFactory(context, nonce []byte, value string, factories ...interface{}) (string, error) {
if !p.Type.DecryptionSupported() {
return "", errutil.UserError{Err: fmt.Sprintf("message decryption not supported for key type %v", p.Type)}
}
@ -911,11 +927,28 @@ func (p *Policy) Decrypt(context, nonce []byte, value string) (string, error) {
return "", errutil.InternalError{Err: "could not derive enc key, length not correct"}
}
plain, err = p.SymmetricDecryptRaw(encKey, decoded,
SymmetricOpts{
Convergent: p.ConvergentEncryption,
ConvergentVersion: p.ConvergentVersion,
})
symopts := SymmetricOpts{
Convergent: p.ConvergentEncryption,
ConvergentVersion: p.ConvergentVersion,
}
for index, rawFactory := range factories {
if rawFactory == nil {
continue
}
switch factory := rawFactory.(type) {
case AEADFactory:
symopts.AEADFactory = factory
case AssociatedDataFactory:
symopts.AdditionalData, err = factory.GetAssociatedData()
if err != nil {
return "", errutil.InternalError{Err: fmt.Sprintf("unable to get associated_data/additional_data from factory[%d]: %v", index, err)}
}
default:
return "", errutil.InternalError{Err: fmt.Sprintf("unknown type of factory[%d]: %T", index, rawFactory)}
}
}
plain, err = p.SymmetricDecryptRaw(encKey, decoded, symopts)
if err != nil {
return "", err
}
@ -1830,7 +1863,7 @@ func (p *Policy) SymmetricDecryptRaw(encKey, ciphertext []byte, opts SymmetricOp
return plain, nil
}
func (p *Policy) EncryptWithFactory(ver int, context []byte, nonce []byte, value string, factory AEADFactory) (string, error) {
func (p *Policy) EncryptWithFactory(ver int, context []byte, nonce []byte, value string, factories ...interface{}) (string, error) {
if !p.Type.EncryptionSupported() {
return "", errutil.UserError{Err: fmt.Sprintf("message encryption not supported for key type %v", p.Type)}
}
@ -1891,14 +1924,29 @@ func (p *Policy) EncryptWithFactory(ver int, context []byte, nonce []byte, value
}
}
ciphertext, err = p.SymmetricEncryptRaw(ver, encKey, plaintext,
SymmetricOpts{
Convergent: p.ConvergentEncryption,
HMACKey: hmacKey,
Nonce: nonce,
AEADFactory: factory,
})
symopts := SymmetricOpts{
Convergent: p.ConvergentEncryption,
HMACKey: hmacKey,
Nonce: nonce,
}
for index, rawFactory := range factories {
if rawFactory == nil {
continue
}
switch factory := rawFactory.(type) {
case AEADFactory:
symopts.AEADFactory = factory
case AssociatedDataFactory:
symopts.AdditionalData, err = factory.GetAssociatedData()
if err != nil {
return "", errutil.InternalError{Err: fmt.Sprintf("unable to get associated_data/additional_data from factory[%d]: %v", index, err)}
}
default:
return "", errutil.InternalError{Err: fmt.Sprintf("unknown type of factory[%d]: %T", index, rawFactory)}
}
}
ciphertext, err = p.SymmetricEncryptRaw(ver, encKey, plaintext, symopts)
if err != nil {
return "", err
}

View File

@ -535,6 +535,10 @@ will be returned.
- `plaintext` `(string: <required>)` Specifies **base64 encoded** plaintext to
be encoded.
- `associated_data` `(string: "")` - Specifies **base64 encoded** associated
data (also known as additional data or AAD) to also be authenticated with
AEAD ciphers (`aes128-gcm96`, `aes256-gcm`, and `chacha20-poly1305`).
- `context` `(string: "")`  Specifies the **base64 encoded** context for key
derivation. This is required if key derivation is enabled for this key.
@ -646,6 +650,10 @@ This endpoint decrypts the provided ciphertext using the named key.
- `ciphertext` `(string: <required>)`  Specifies the ciphertext to decrypt.
- `associated_data` `(string: "")` - Specifies **base64 encoded** associated
data (also known as additional data or AAD) to also be authenticated with
AEAD ciphers (`aes128-gcm96`, `aes256-gcm`, and `chacha20-poly1305`).
- `context` `(string: "")` Specifies the **base64 encoded** context for key
derivation. This is required if key derivation is enabled.