diff --git a/builtin/logical/transit/backend_test.go b/builtin/logical/transit/backend_test.go index c4d92a3dd..cfbd4f4da 100644 --- a/builtin/logical/transit/backend_test.go +++ b/builtin/logical/transit/backend_test.go @@ -1717,3 +1717,115 @@ func TestTransit_AutoRotateKeys(t *testing.T) { ) } } + +func TestTransit_AEAD(t *testing.T) { + testTransit_AEAD(t, "aes128-gcm96") + testTransit_AEAD(t, "aes256-gcm96") + testTransit_AEAD(t, "chacha20-poly1305") +} + +func testTransit_AEAD(t *testing.T, keyType string) { + var resp *logical.Response + var err error + b, storage := createBackendWithStorage(t) + + keyReq := &logical.Request{ + Path: "keys/aead", + Operation: logical.UpdateOperation, + Data: map[string]interface{}{ + "type": keyType, + }, + Storage: storage, + } + + resp, err = b.HandleRequest(context.Background(), keyReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("bad: err: %v\nresp: %#v", err, resp) + } + + plaintext := "dGhlIHF1aWNrIGJyb3duIGZveA==" // "the quick brown fox" + associated := "U3BoaW54IG9mIGJsYWNrIHF1YXJ0eiwganVkZ2UgbXkgdm93Lgo=" // "Sphinx of black quartz, judge my vow." + + // Basic encrypt/decrypt should work. + encryptReq := &logical.Request{ + Path: "encrypt/aead", + Operation: logical.UpdateOperation, + Storage: storage, + Data: map[string]interface{}{ + "plaintext": plaintext, + }, + } + + resp, err = b.HandleRequest(context.Background(), encryptReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("bad: err: %v\nresp: %#v", err, resp) + } + + ciphertext1 := resp.Data["ciphertext"].(string) + + decryptReq := &logical.Request{ + Path: "decrypt/aead", + Operation: logical.UpdateOperation, + Storage: storage, + Data: map[string]interface{}{ + "ciphertext": ciphertext1, + }, + } + + resp, err = b.HandleRequest(context.Background(), decryptReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("bad: err: %v\nresp: %#v", err, resp) + } + + decryptedPlaintext := resp.Data["plaintext"] + + if plaintext != decryptedPlaintext { + t.Fatalf("bad: plaintext; expected: %q\nactual: %q", plaintext, decryptedPlaintext) + } + + // Using associated as ciphertext should fail. + decryptReq.Data["ciphertext"] = associated + resp, err = b.HandleRequest(context.Background(), decryptReq) + if err == nil || (resp != nil && !resp.IsError()) { + t.Fatalf("bad expected error: err: %v\nresp: %#v", err, resp) + } + + // Redoing the above with additional data should work. + encryptReq.Data["associated_data"] = associated + resp, err = b.HandleRequest(context.Background(), encryptReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("bad: err: %v\nresp: %#v", err, resp) + } + + ciphertext2 := resp.Data["ciphertext"].(string) + decryptReq.Data["ciphertext"] = ciphertext2 + decryptReq.Data["associated_data"] = associated + + resp, err = b.HandleRequest(context.Background(), decryptReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("bad: err: %v\nresp: %#v", err, resp) + } + + decryptedPlaintext = resp.Data["plaintext"] + if plaintext != decryptedPlaintext { + t.Fatalf("bad: plaintext; expected: %q\nactual: %q", plaintext, decryptedPlaintext) + } + + // Removing the associated_data should break the decryption. + decryptReq.Data = map[string]interface{}{ + "ciphertext": ciphertext2, + } + resp, err = b.HandleRequest(context.Background(), decryptReq) + if err == nil || (resp != nil && !resp.IsError()) { + t.Fatalf("bad expected error: err: %v\nresp: %#v", err, resp) + } + + // Using a valid ciphertext with associated_data should also break the + // decryption. + decryptReq.Data["ciphertext"] = ciphertext1 + decryptReq.Data["associated_data"] = associated + resp, err = b.HandleRequest(context.Background(), decryptReq) + if err == nil || (resp != nil && !resp.IsError()) { + t.Fatalf("bad expected error: err: %v\nresp: %#v", err, resp) + } +} diff --git a/builtin/logical/transit/path_decrypt.go b/builtin/logical/transit/path_decrypt.go index 820079873..c6c705442 100644 --- a/builtin/logical/transit/path_decrypt.go +++ b/builtin/logical/transit/path_decrypt.go @@ -50,6 +50,7 @@ Base64 encoded nonce value used during encryption. Must be provided if convergent encryption is enabled for this key and the key was generated with Vault 0.6.1. Not required for keys created in 0.6.2+.`, }, + "partial_failure_response_code": { Type: framework.TypeInt, Description: ` @@ -58,6 +59,17 @@ the HTTP response code is 400 (Bad Request). Some applications may want to trea Providing the parameter returns the given response code integer instead of a 400 in this case. If all values fail HTTP 400 is still returned.`, }, + + "associated_data": { + Type: framework.TypeString, + Description: ` +When using an AEAD cipher mode, such as AES-GCM, this parameter allows +passing associated data (AD/AAD) into the encryption function; this data +must be passed on subsequent decryption requests but can be transited in +plaintext. On successful decryption, both the ciphertext and the associated +data are attested not to have been tampered with. + `, + }, }, Callbacks: map[logical.Operation]framework.OperationFunc{ @@ -90,9 +102,10 @@ func (b *backend) pathDecryptWrite(ctx context.Context, req *logical.Request, d batchInputItems = make([]BatchRequestItem, 1) batchInputItems[0] = BatchRequestItem{ - Ciphertext: ciphertext, - Context: d.Get("context").(string), - Nonce: d.Get("nonce").(string), + Ciphertext: ciphertext, + Context: d.Get("context").(string), + Nonce: d.Get("nonce").(string), + AssociatedData: d.Get("associated_data").(string), } } @@ -155,7 +168,17 @@ func (b *backend) pathDecryptWrite(ctx context.Context, req *logical.Request, d continue } - plaintext, err := p.Decrypt(item.DecodedContext, item.DecodedNonce, item.Ciphertext) + var factory interface{} + if item.AssociatedData != "" { + if !p.Type.AssociatedDataSupported() { + batchResponseItems[i].Error = fmt.Sprintf("'[%d].associated_data' provided for non-AEAD cipher suite %v", i, p.Type.String()) + continue + } + + factory = AssocDataFactory{item.AssociatedData} + } + + plaintext, err := p.DecryptWithFactory(item.DecodedContext, item.DecodedNonce, item.Ciphertext, factory) if err != nil { switch err.(type) { case errutil.InternalError: diff --git a/builtin/logical/transit/path_encrypt.go b/builtin/logical/transit/path_encrypt.go index 3e0c72037..6c0a3fc98 100644 --- a/builtin/logical/transit/path_encrypt.go +++ b/builtin/logical/transit/path_encrypt.go @@ -39,6 +39,9 @@ type BatchRequestItem struct { // DecodedNonce is the base64 decoded version of Nonce DecodedNonce []byte + + // Associated Data for AEAD ciphers + AssociatedData string `json:"associated_data" struct:"associated_data" mapstructure:"associated_data"` } // EncryptBatchResponseItem represents a response item for batch processing @@ -55,6 +58,14 @@ type EncryptBatchResponseItem struct { Error string `json:"error,omitempty" structs:"error" mapstructure:"error"` } +type AssocDataFactory struct { + Encoded string +} + +func (a AssocDataFactory) GetAssociatedData() ([]byte, error) { + return base64.StdEncoding.DecodeString(a.Encoded) +} + func (b *backend) pathEncrypt() *framework.Path { return &framework.Path{ Pattern: "encrypt/" + framework.GenericNameRegex("name"), @@ -113,6 +124,7 @@ will severely impact the ciphertext's security.`, Must be 0 (for latest) or a value greater than or equal to the min_encryption_version configured on the key.`, }, + "partial_failure_response_code": { Type: framework.TypeInt, Description: ` @@ -121,6 +133,17 @@ the HTTP response code is 400 (Bad Request). Some applications may want to trea Providing the parameter returns the given response code integer instead of a 400 in this case. If all values fail HTTP 400 is still returned.`, }, + + "associated_data": { + Type: framework.TypeString, + Description: ` +When using an AEAD cipher mode, such as AES-GCM, this parameter allows +passing associated data (AD/AAD) into the encryption function; this data +must be passed on subsequent decryption requests but can be transited in +plaintext. On successful decryption, both the ciphertext and the associated +data are attested not to have been tampered with. + `, + }, }, Callbacks: map[logical.Operation]framework.OperationFunc{ @@ -229,6 +252,15 @@ func decodeBatchRequestItems(src interface{}, requirePlaintext bool, requireCiph errs.Errors = append(errs.Errors, fmt.Sprintf("'[%d].key_version' expected type 'int', got unconvertible type '%T'", i, item["key_version"])) } } + + if v, has := item["associated_data"]; has { + if !reflect.ValueOf(v).IsValid() { + } else if casted, ok := v.(string); ok { + (*dst)[i].AssociatedData = casted + } else { + errs.Errors = append(errs.Errors, fmt.Sprintf("'[%d].associated_data' expected type 'string', got unconvertible type '%T'", i, item["associated_data"])) + } + } } if len(errs.Errors) > 0 { @@ -279,10 +311,11 @@ func (b *backend) pathEncryptWrite(ctx context.Context, req *logical.Request, d batchInputItems = make([]BatchRequestItem, 1) batchInputItems[0] = BatchRequestItem{ - Plaintext: valueRaw.(string), - Context: d.Get("context").(string), - Nonce: d.Get("nonce").(string), - KeyVersion: d.Get("key_version").(int), + Plaintext: valueRaw.(string), + Context: d.Get("context").(string), + Nonce: d.Get("nonce").(string), + KeyVersion: d.Get("key_version").(int), + AssociatedData: d.Get("associated_data").(string), } } @@ -393,7 +426,17 @@ func (b *backend) pathEncryptWrite(ctx context.Context, req *logical.Request, d warnAboutNonceUsage = true } - ciphertext, err := p.Encrypt(item.KeyVersion, item.DecodedContext, item.DecodedNonce, item.Plaintext) + var factory interface{} + if item.AssociatedData != "" { + if !p.Type.AssociatedDataSupported() { + batchResponseItems[i].Error = fmt.Sprintf("'[%d].associated_data' provided for non-AEAD cipher suite %v", i, p.Type.String()) + continue + } + + factory = AssocDataFactory{item.AssociatedData} + } + + ciphertext, err := p.EncryptWithFactory(item.KeyVersion, item.DecodedContext, item.DecodedNonce, item.Plaintext, factory) if err != nil { switch err.(type) { case errutil.InternalError: diff --git a/changelog/17638.txt b/changelog/17638.txt new file mode 100644 index 000000000..37e057539 --- /dev/null +++ b/changelog/17638.txt @@ -0,0 +1,3 @@ +```release-note:improvement +secrets/transit: Add associated_data parameter for additional authenticated data in AEAD ciphers +``` diff --git a/sdk/helper/keysutil/policy.go b/sdk/helper/keysutil/policy.go index 73afb9ecd..3417c2992 100644 --- a/sdk/helper/keysutil/policy.go +++ b/sdk/helper/keysutil/policy.go @@ -79,6 +79,10 @@ type AEADFactory interface { GetAEAD(iv []byte) (cipher.AEAD, error) } +type AssociatedDataFactory interface { + GetAssociatedData() ([]byte, error) +} + type RestoreInfo struct { Time time.Time `json:"time"` Version int `json:"version"` @@ -147,6 +151,14 @@ func (kt KeyType) DerivationSupported() bool { return false } +func (kt KeyType) AssociatedDataSupported() bool { + switch kt { + case KeyType_AES128_GCM96, KeyType_AES256_GCM96, KeyType_ChaCha20_Poly1305: + return true + } + return false +} + func (kt KeyType) String() string { switch kt { case KeyType_AES128_GCM96: @@ -844,6 +856,10 @@ func (p *Policy) Encrypt(ver int, context, nonce []byte, value string) (string, } func (p *Policy) Decrypt(context, nonce []byte, value string) (string, error) { + return p.DecryptWithFactory(context, nonce, value, nil) +} + +func (p *Policy) DecryptWithFactory(context, nonce []byte, value string, factories ...interface{}) (string, error) { if !p.Type.DecryptionSupported() { return "", errutil.UserError{Err: fmt.Sprintf("message decryption not supported for key type %v", p.Type)} } @@ -911,11 +927,28 @@ func (p *Policy) Decrypt(context, nonce []byte, value string) (string, error) { return "", errutil.InternalError{Err: "could not derive enc key, length not correct"} } - plain, err = p.SymmetricDecryptRaw(encKey, decoded, - SymmetricOpts{ - Convergent: p.ConvergentEncryption, - ConvergentVersion: p.ConvergentVersion, - }) + symopts := SymmetricOpts{ + Convergent: p.ConvergentEncryption, + ConvergentVersion: p.ConvergentVersion, + } + for index, rawFactory := range factories { + if rawFactory == nil { + continue + } + switch factory := rawFactory.(type) { + case AEADFactory: + symopts.AEADFactory = factory + case AssociatedDataFactory: + symopts.AdditionalData, err = factory.GetAssociatedData() + if err != nil { + return "", errutil.InternalError{Err: fmt.Sprintf("unable to get associated_data/additional_data from factory[%d]: %v", index, err)} + } + default: + return "", errutil.InternalError{Err: fmt.Sprintf("unknown type of factory[%d]: %T", index, rawFactory)} + } + } + + plain, err = p.SymmetricDecryptRaw(encKey, decoded, symopts) if err != nil { return "", err } @@ -1830,7 +1863,7 @@ func (p *Policy) SymmetricDecryptRaw(encKey, ciphertext []byte, opts SymmetricOp return plain, nil } -func (p *Policy) EncryptWithFactory(ver int, context []byte, nonce []byte, value string, factory AEADFactory) (string, error) { +func (p *Policy) EncryptWithFactory(ver int, context []byte, nonce []byte, value string, factories ...interface{}) (string, error) { if !p.Type.EncryptionSupported() { return "", errutil.UserError{Err: fmt.Sprintf("message encryption not supported for key type %v", p.Type)} } @@ -1891,14 +1924,29 @@ func (p *Policy) EncryptWithFactory(ver int, context []byte, nonce []byte, value } } - ciphertext, err = p.SymmetricEncryptRaw(ver, encKey, plaintext, - SymmetricOpts{ - Convergent: p.ConvergentEncryption, - HMACKey: hmacKey, - Nonce: nonce, - AEADFactory: factory, - }) + symopts := SymmetricOpts{ + Convergent: p.ConvergentEncryption, + HMACKey: hmacKey, + Nonce: nonce, + } + for index, rawFactory := range factories { + if rawFactory == nil { + continue + } + switch factory := rawFactory.(type) { + case AEADFactory: + symopts.AEADFactory = factory + case AssociatedDataFactory: + symopts.AdditionalData, err = factory.GetAssociatedData() + if err != nil { + return "", errutil.InternalError{Err: fmt.Sprintf("unable to get associated_data/additional_data from factory[%d]: %v", index, err)} + } + default: + return "", errutil.InternalError{Err: fmt.Sprintf("unknown type of factory[%d]: %T", index, rawFactory)} + } + } + ciphertext, err = p.SymmetricEncryptRaw(ver, encKey, plaintext, symopts) if err != nil { return "", err } diff --git a/website/content/api-docs/secret/transit.mdx b/website/content/api-docs/secret/transit.mdx index 7c1827b46..f5870af1f 100644 --- a/website/content/api-docs/secret/transit.mdx +++ b/website/content/api-docs/secret/transit.mdx @@ -535,6 +535,10 @@ will be returned. - `plaintext` `(string: )` – Specifies **base64 encoded** plaintext to be encoded. +- `associated_data` `(string: "")` - Specifies **base64 encoded** associated + data (also known as additional data or AAD) to also be authenticated with + AEAD ciphers (`aes128-gcm96`, `aes256-gcm`, and `chacha20-poly1305`). + - `context` `(string: "")` – Specifies the **base64 encoded** context for key derivation. This is required if key derivation is enabled for this key. @@ -646,6 +650,10 @@ This endpoint decrypts the provided ciphertext using the named key. - `ciphertext` `(string: )` – Specifies the ciphertext to decrypt. +- `associated_data` `(string: "")` - Specifies **base64 encoded** associated + data (also known as additional data or AAD) to also be authenticated with + AEAD ciphers (`aes128-gcm96`, `aes256-gcm`, and `chacha20-poly1305`). + - `context` `(string: "")` – Specifies the **base64 encoded** context for key derivation. This is required if key derivation is enabled.