Return non-retryable errors on transit encrypt and decrypt failures (#13111)
* Return HTTP 400s on transit decrypt requests where decryption fails. (#10842) * Don't abort transit batch encryption when a single batch item fails. * Add unit tests for updated transit batch decryption behavior. * Add changelog entry for transit encrypt/decrypt batch abort fix. * Simplify transit batch error message generation when ciphertext is empty. * Return error HTTP status codes in transit on partial batch decrypt failure. * Return error HTTP status codes in transit on partial batch encrypt failure. * Properly account for non-batch transit decryption failure return. Simplify transit batch decryption test data. Ensure HTTP status codes are expected values on batch transit batch decryption partial failure. * Properly account for non-batch transit encryption failure return. Actually return error HTTP status code on transit batch encryption failure (partial or full).
This commit is contained in:
parent
3d46021d4e
commit
0abd248c9f
|
@ -4,6 +4,7 @@ import (
|
|||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/hashicorp/vault/sdk/framework"
|
||||
"github.com/hashicorp/vault/sdk/helper/errutil"
|
||||
|
@ -91,12 +92,16 @@ func (b *backend) pathDecryptWrite(ctx context.Context, req *logical.Request, d
|
|||
batchResponseItems := make([]DecryptBatchResponseItem, len(batchInputItems))
|
||||
contextSet := len(batchInputItems[0].Context) != 0
|
||||
|
||||
userErrorInBatch := false
|
||||
internalErrorInBatch := false
|
||||
|
||||
for i, item := range batchInputItems {
|
||||
if (len(item.Context) == 0 && contextSet) || (len(item.Context) != 0 && !contextSet) {
|
||||
return logical.ErrorResponse("context should be set either in all the request blocks or in none"), logical.ErrInvalidRequest
|
||||
}
|
||||
|
||||
if item.Ciphertext == "" {
|
||||
userErrorInBatch = true
|
||||
batchResponseItems[i].Error = "missing ciphertext to decrypt"
|
||||
continue
|
||||
}
|
||||
|
@ -105,6 +110,7 @@ func (b *backend) pathDecryptWrite(ctx context.Context, req *logical.Request, d
|
|||
if len(item.Context) != 0 {
|
||||
batchInputItems[i].DecodedContext, err = base64.StdEncoding.DecodeString(item.Context)
|
||||
if err != nil {
|
||||
userErrorInBatch = true
|
||||
batchResponseItems[i].Error = err.Error()
|
||||
continue
|
||||
}
|
||||
|
@ -114,6 +120,7 @@ func (b *backend) pathDecryptWrite(ctx context.Context, req *logical.Request, d
|
|||
if len(item.Nonce) != 0 {
|
||||
batchInputItems[i].DecodedNonce, err = base64.StdEncoding.DecodeString(item.Nonce)
|
||||
if err != nil {
|
||||
userErrorInBatch = true
|
||||
batchResponseItems[i].Error = err.Error()
|
||||
continue
|
||||
}
|
||||
|
@ -143,13 +150,13 @@ func (b *backend) pathDecryptWrite(ctx context.Context, req *logical.Request, d
|
|||
plaintext, err := p.Decrypt(item.DecodedContext, item.DecodedNonce, item.Ciphertext)
|
||||
if err != nil {
|
||||
switch err.(type) {
|
||||
case errutil.UserError:
|
||||
case errutil.InternalError:
|
||||
internalErrorInBatch = true
|
||||
default:
|
||||
userErrorInBatch = true
|
||||
}
|
||||
batchResponseItems[i].Error = err.Error()
|
||||
continue
|
||||
default:
|
||||
p.Unlock()
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
batchResponseItems[i].Plaintext = plaintext
|
||||
}
|
||||
|
@ -162,6 +169,11 @@ func (b *backend) pathDecryptWrite(ctx context.Context, req *logical.Request, d
|
|||
} else {
|
||||
if batchResponseItems[0].Error != "" {
|
||||
p.Unlock()
|
||||
|
||||
if internalErrorInBatch {
|
||||
return nil, errutil.InternalError{Err: batchResponseItems[0].Error}
|
||||
}
|
||||
|
||||
return logical.ErrorResponse(batchResponseItems[0].Error), logical.ErrInvalidRequest
|
||||
}
|
||||
resp.Data = map[string]interface{}{
|
||||
|
@ -170,6 +182,18 @@ func (b *backend) pathDecryptWrite(ctx context.Context, req *logical.Request, d
|
|||
}
|
||||
|
||||
p.Unlock()
|
||||
|
||||
// Depending on the errors in the batch, different status codes should be returned. User errors
|
||||
// will return a 400 and precede internal errors which return a 500. The reasoning behind this is
|
||||
// that user errors are non-retryable without making changes to the request, and should be surfaced
|
||||
// to the user first.
|
||||
switch {
|
||||
case userErrorInBatch:
|
||||
return logical.RespondWithStatusCode(resp, req, http.StatusBadRequest)
|
||||
case internalErrorInBatch:
|
||||
return logical.RespondWithStatusCode(resp, req, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -3,9 +3,13 @@ package transit
|
|||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/vault/sdk/helper/jsonutil"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
)
|
||||
|
||||
func TestTransit_BatchDecryption(t *testing.T) {
|
||||
|
@ -64,74 +68,179 @@ func TestTransit_BatchDecryption(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestTransit_BatchDecryption_DerivedKey(t *testing.T) {
|
||||
var req *logical.Request
|
||||
var resp *logical.Response
|
||||
var err error
|
||||
|
||||
b, s := createBackendWithStorage(t)
|
||||
|
||||
policyData := map[string]interface{}{
|
||||
"derived": true,
|
||||
}
|
||||
|
||||
policyReq := &logical.Request{
|
||||
// Create a derived key.
|
||||
req = &logical.Request{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "keys/existing_key",
|
||||
Storage: s,
|
||||
Data: policyData,
|
||||
Data: map[string]interface{}{
|
||||
"derived": true,
|
||||
},
|
||||
}
|
||||
|
||||
resp, err = b.HandleRequest(context.Background(), policyReq)
|
||||
resp, err = b.HandleRequest(context.Background(), req)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("err:%v resp:%#v", err, resp)
|
||||
}
|
||||
|
||||
batchInput := []interface{}{
|
||||
map[string]interface{}{"plaintext": "dGhlIHF1aWNrIGJyb3duIGZveA==", "context": "dGVzdGNvbnRleHQ="},
|
||||
map[string]interface{}{"plaintext": "dGhlIHF1aWNrIGJyb3duIGZveA==", "context": "dGVzdGNvbnRleHQ="},
|
||||
// Encrypt some values for use in test cases.
|
||||
plaintextItems := []struct {
|
||||
plaintext, context string
|
||||
}{
|
||||
{plaintext: "dGhlIHF1aWNrIGJyb3duIGZveA==", context: "dGVzdGNvbnRleHQ="},
|
||||
{plaintext: "anVtcGVkIG92ZXIgdGhlIGxhenkgZG9n", context: "dGVzdGNvbnRleHQy"},
|
||||
}
|
||||
|
||||
batchData := map[string]interface{}{
|
||||
"batch_input": batchInput,
|
||||
}
|
||||
batchReq := &logical.Request{
|
||||
req = &logical.Request{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "encrypt/existing_key",
|
||||
Storage: s,
|
||||
Data: batchData,
|
||||
Data: map[string]interface{}{
|
||||
"batch_input": []interface{}{
|
||||
map[string]interface{}{"plaintext": plaintextItems[0].plaintext, "context": plaintextItems[0].context},
|
||||
map[string]interface{}{"plaintext": plaintextItems[1].plaintext, "context": plaintextItems[1].context},
|
||||
},
|
||||
},
|
||||
}
|
||||
resp, err = b.HandleRequest(context.Background(), batchReq)
|
||||
resp, err = b.HandleRequest(context.Background(), req)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("err:%v resp:%#v", err, resp)
|
||||
}
|
||||
|
||||
batchDecryptionInputItems := resp.Data["batch_results"].([]EncryptBatchResponseItem)
|
||||
encryptedItems := resp.Data["batch_results"].([]EncryptBatchResponseItem)
|
||||
|
||||
batchDecryptionInput := make([]interface{}, len(batchDecryptionInputItems))
|
||||
for i, item := range batchDecryptionInputItems {
|
||||
batchDecryptionInput[i] = map[string]interface{}{"ciphertext": item.Ciphertext, "context": "dGVzdGNvbnRleHQ="}
|
||||
tests := []struct {
|
||||
name string
|
||||
in []interface{}
|
||||
want []DecryptBatchResponseItem
|
||||
shouldErr bool
|
||||
wantHTTPStatus int
|
||||
}{
|
||||
{
|
||||
name: "nil-input",
|
||||
in: nil,
|
||||
shouldErr: true,
|
||||
},
|
||||
{
|
||||
name: "empty-input",
|
||||
in: []interface{}{},
|
||||
shouldErr: true,
|
||||
},
|
||||
{
|
||||
name: "single-item-success",
|
||||
in: []interface{}{
|
||||
map[string]interface{}{"ciphertext": encryptedItems[0].Ciphertext, "context": plaintextItems[0].context},
|
||||
},
|
||||
want: []DecryptBatchResponseItem{
|
||||
{Plaintext: plaintextItems[0].plaintext},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "single-item-invalid-ciphertext",
|
||||
in: []interface{}{
|
||||
map[string]interface{}{"ciphertext": "xxx", "context": plaintextItems[0].context},
|
||||
},
|
||||
want: []DecryptBatchResponseItem{
|
||||
{Error: "invalid ciphertext: no prefix"},
|
||||
},
|
||||
wantHTTPStatus: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "single-item-wrong-context",
|
||||
in: []interface{}{
|
||||
map[string]interface{}{"ciphertext": encryptedItems[0].Ciphertext, "context": plaintextItems[1].context},
|
||||
},
|
||||
want: []DecryptBatchResponseItem{
|
||||
{Error: "cipher: message authentication failed"},
|
||||
},
|
||||
wantHTTPStatus: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "batch-full-success",
|
||||
in: []interface{}{
|
||||
map[string]interface{}{"ciphertext": encryptedItems[0].Ciphertext, "context": plaintextItems[0].context},
|
||||
map[string]interface{}{"ciphertext": encryptedItems[1].Ciphertext, "context": plaintextItems[1].context},
|
||||
},
|
||||
want: []DecryptBatchResponseItem{
|
||||
{Plaintext: plaintextItems[0].plaintext},
|
||||
{Plaintext: plaintextItems[1].plaintext},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "batch-partial-success",
|
||||
in: []interface{}{
|
||||
map[string]interface{}{"ciphertext": encryptedItems[0].Ciphertext, "context": plaintextItems[1].context},
|
||||
map[string]interface{}{"ciphertext": encryptedItems[1].Ciphertext, "context": plaintextItems[1].context},
|
||||
},
|
||||
want: []DecryptBatchResponseItem{
|
||||
{Error: "cipher: message authentication failed"},
|
||||
{Plaintext: plaintextItems[1].plaintext},
|
||||
},
|
||||
wantHTTPStatus: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "batch-full-failure",
|
||||
in: []interface{}{
|
||||
map[string]interface{}{"ciphertext": encryptedItems[0].Ciphertext, "context": plaintextItems[1].context},
|
||||
map[string]interface{}{"ciphertext": encryptedItems[1].Ciphertext, "context": plaintextItems[0].context},
|
||||
},
|
||||
want: []DecryptBatchResponseItem{
|
||||
{Error: "cipher: message authentication failed"},
|
||||
{Error: "cipher: message authentication failed"},
|
||||
},
|
||||
wantHTTPStatus: http.StatusBadRequest,
|
||||
},
|
||||
}
|
||||
|
||||
batchDecryptionData := map[string]interface{}{
|
||||
"batch_input": batchDecryptionInput,
|
||||
}
|
||||
|
||||
batchDecryptionReq := &logical.Request{
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req = &logical.Request{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "decrypt/existing_key",
|
||||
Storage: s,
|
||||
Data: batchDecryptionData,
|
||||
Data: map[string]interface{}{
|
||||
"batch_input": tt.in,
|
||||
},
|
||||
}
|
||||
resp, err = b.HandleRequest(context.Background(), batchDecryptionReq)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("err:%v resp:%#v", err, resp)
|
||||
resp, err = b.HandleRequest(context.Background(), req)
|
||||
|
||||
didErr := err != nil || (resp != nil && resp.IsError())
|
||||
if didErr {
|
||||
if !tt.shouldErr {
|
||||
t.Fatalf("unexpected error err:%v, resp:%#v", err, resp)
|
||||
}
|
||||
} else {
|
||||
if tt.shouldErr {
|
||||
t.Fatal("expected error, but none occurred")
|
||||
}
|
||||
|
||||
batchDecryptionResponseItems := resp.Data["batch_results"].([]DecryptBatchResponseItem)
|
||||
if rawRespBody, ok := resp.Data[logical.HTTPRawBody]; ok {
|
||||
httpResp := &logical.HTTPResponse{}
|
||||
err = jsonutil.DecodeJSON([]byte(rawRespBody.(string)), httpResp)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to unmarshal nested response: err:%v, resp:%#v", err, resp)
|
||||
}
|
||||
|
||||
plaintext := "dGhlIHF1aWNrIGJyb3duIGZveA=="
|
||||
for _, item := range batchDecryptionResponseItems {
|
||||
if item.Plaintext != plaintext {
|
||||
t.Fatalf("bad: plaintext. Expected: %q, Actual: %q", plaintext, item.Plaintext)
|
||||
if respStatus, ok := resp.Data[logical.HTTPStatusCode]; !ok || respStatus != tt.wantHTTPStatus {
|
||||
t.Fatalf("HTTP response status code mismatch, want:%d, got:%d", tt.wantHTTPStatus, respStatus)
|
||||
}
|
||||
|
||||
resp = logical.HTTPResponseToLogicalResponse(httpResp)
|
||||
}
|
||||
|
||||
var respItems []DecryptBatchResponseItem
|
||||
err = mapstructure.Decode(resp.Data["batch_results"], &respItems)
|
||||
if err != nil {
|
||||
t.Fatalf("problem decoding response items: err:%v, resp:%#v", err, resp)
|
||||
}
|
||||
if !reflect.DeepEqual(tt.want, respItems) {
|
||||
t.Fatalf("response items mismatch, want:%#v, got:%#v", tt.want, respItems)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"reflect"
|
||||
|
||||
"github.com/hashicorp/vault/sdk/framework"
|
||||
|
@ -263,6 +264,9 @@ func (b *backend) pathEncryptWrite(ctx context.Context, req *logical.Request, d
|
|||
batchResponseItems := make([]EncryptBatchResponseItem, len(batchInputItems))
|
||||
contextSet := len(batchInputItems[0].Context) != 0
|
||||
|
||||
userErrorInBatch := false
|
||||
internalErrorInBatch := false
|
||||
|
||||
// Before processing the batch request items, get the policy. If the
|
||||
// policy is supposed to be upserted, then determine if 'derived' is to
|
||||
// be set or not, based on the presence of 'context' field in all the
|
||||
|
@ -274,6 +278,7 @@ func (b *backend) pathEncryptWrite(ctx context.Context, req *logical.Request, d
|
|||
|
||||
_, err := base64.StdEncoding.DecodeString(item.Plaintext)
|
||||
if err != nil {
|
||||
userErrorInBatch = true
|
||||
batchResponseItems[i].Error = err.Error()
|
||||
continue
|
||||
}
|
||||
|
@ -282,6 +287,7 @@ func (b *backend) pathEncryptWrite(ctx context.Context, req *logical.Request, d
|
|||
if len(item.Context) != 0 {
|
||||
batchInputItems[i].DecodedContext, err = base64.StdEncoding.DecodeString(item.Context)
|
||||
if err != nil {
|
||||
userErrorInBatch = true
|
||||
batchResponseItems[i].Error = err.Error()
|
||||
continue
|
||||
}
|
||||
|
@ -291,6 +297,7 @@ func (b *backend) pathEncryptWrite(ctx context.Context, req *logical.Request, d
|
|||
if len(item.Nonce) != 0 {
|
||||
batchInputItems[i].DecodedNonce, err = base64.StdEncoding.DecodeString(item.Nonce)
|
||||
if err != nil {
|
||||
userErrorInBatch = true
|
||||
batchResponseItems[i].Error = err.Error()
|
||||
continue
|
||||
}
|
||||
|
@ -358,18 +365,19 @@ func (b *backend) pathEncryptWrite(ctx context.Context, req *logical.Request, d
|
|||
ciphertext, err := p.Encrypt(item.KeyVersion, item.DecodedContext, item.DecodedNonce, item.Plaintext)
|
||||
if err != nil {
|
||||
switch err.(type) {
|
||||
case errutil.UserError:
|
||||
case errutil.InternalError:
|
||||
internalErrorInBatch = true
|
||||
default:
|
||||
userErrorInBatch = true
|
||||
}
|
||||
batchResponseItems[i].Error = err.Error()
|
||||
continue
|
||||
default:
|
||||
p.Unlock()
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if ciphertext == "" {
|
||||
p.Unlock()
|
||||
return nil, fmt.Errorf("empty ciphertext returned for input item %d", i)
|
||||
userErrorInBatch = true
|
||||
batchResponseItems[i].Error = fmt.Sprintf("empty ciphertext returned for input item %d", i)
|
||||
continue
|
||||
}
|
||||
|
||||
keyVersion := item.KeyVersion
|
||||
|
@ -389,6 +397,11 @@ func (b *backend) pathEncryptWrite(ctx context.Context, req *logical.Request, d
|
|||
} else {
|
||||
if batchResponseItems[0].Error != "" {
|
||||
p.Unlock()
|
||||
|
||||
if internalErrorInBatch {
|
||||
return nil, errutil.InternalError{Err: batchResponseItems[0].Error}
|
||||
}
|
||||
|
||||
return logical.ErrorResponse(batchResponseItems[0].Error), logical.ErrInvalidRequest
|
||||
}
|
||||
|
||||
|
@ -403,6 +416,18 @@ func (b *backend) pathEncryptWrite(ctx context.Context, req *logical.Request, d
|
|||
}
|
||||
|
||||
p.Unlock()
|
||||
|
||||
// Depending on the errors in the batch, different status codes should be returned. User errors
|
||||
// will return a 400 and precede internal errors which return a 500. The reasoning behind this is
|
||||
// that user errors are non-retryable without making changes to the request, and should be surfaced
|
||||
// to the user first.
|
||||
switch {
|
||||
case userErrorInBatch:
|
||||
return logical.RespondWithStatusCode(resp, req, http.StatusBadRequest)
|
||||
case internalErrorInBatch:
|
||||
return logical.RespondWithStatusCode(resp, req, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,3 @@
|
|||
```release-note:improvement
|
||||
secrets/transit: Don't abort transit encrypt or decrypt batches on single item failure.
|
||||
```
|
Loading…
Reference in New Issue