diff --git a/builtin/logical/transit/path_decrypt.go b/builtin/logical/transit/path_decrypt.go index cf6d45060..c48df88ac 100644 --- a/builtin/logical/transit/path_decrypt.go +++ b/builtin/logical/transit/path_decrypt.go @@ -4,6 +4,7 @@ import ( "context" "encoding/base64" "fmt" + "net/http" "github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/helper/errutil" @@ -91,12 +92,16 @@ func (b *backend) pathDecryptWrite(ctx context.Context, req *logical.Request, d batchResponseItems := make([]DecryptBatchResponseItem, len(batchInputItems)) contextSet := len(batchInputItems[0].Context) != 0 + userErrorInBatch := false + internalErrorInBatch := false + for i, item := range batchInputItems { if (len(item.Context) == 0 && contextSet) || (len(item.Context) != 0 && !contextSet) { return logical.ErrorResponse("context should be set either in all the request blocks or in none"), logical.ErrInvalidRequest } if item.Ciphertext == "" { + userErrorInBatch = true batchResponseItems[i].Error = "missing ciphertext to decrypt" continue } @@ -105,6 +110,7 @@ func (b *backend) pathDecryptWrite(ctx context.Context, req *logical.Request, d if len(item.Context) != 0 { batchInputItems[i].DecodedContext, err = base64.StdEncoding.DecodeString(item.Context) if err != nil { + userErrorInBatch = true batchResponseItems[i].Error = err.Error() continue } @@ -114,6 +120,7 @@ func (b *backend) pathDecryptWrite(ctx context.Context, req *logical.Request, d if len(item.Nonce) != 0 { batchInputItems[i].DecodedNonce, err = base64.StdEncoding.DecodeString(item.Nonce) if err != nil { + userErrorInBatch = true batchResponseItems[i].Error = err.Error() continue } @@ -143,13 +150,13 @@ func (b *backend) pathDecryptWrite(ctx context.Context, req *logical.Request, d plaintext, err := p.Decrypt(item.DecodedContext, item.DecodedNonce, item.Ciphertext) if err != nil { switch err.(type) { - case errutil.UserError: - batchResponseItems[i].Error = err.Error() - continue + case errutil.InternalError: + internalErrorInBatch = true default: - p.Unlock() - return nil, err + userErrorInBatch = true } + batchResponseItems[i].Error = err.Error() + continue } batchResponseItems[i].Plaintext = plaintext } @@ -162,6 +169,11 @@ func (b *backend) pathDecryptWrite(ctx context.Context, req *logical.Request, d } else { if batchResponseItems[0].Error != "" { p.Unlock() + + if internalErrorInBatch { + return nil, errutil.InternalError{Err: batchResponseItems[0].Error} + } + return logical.ErrorResponse(batchResponseItems[0].Error), logical.ErrInvalidRequest } resp.Data = map[string]interface{}{ @@ -170,6 +182,18 @@ func (b *backend) pathDecryptWrite(ctx context.Context, req *logical.Request, d } p.Unlock() + + // Depending on the errors in the batch, different status codes should be returned. User errors + // will return a 400 and precede internal errors which return a 500. The reasoning behind this is + // that user errors are non-retryable without making changes to the request, and should be surfaced + // to the user first. + switch { + case userErrorInBatch: + return logical.RespondWithStatusCode(resp, req, http.StatusBadRequest) + case internalErrorInBatch: + return logical.RespondWithStatusCode(resp, req, http.StatusInternalServerError) + } + return resp, nil } diff --git a/builtin/logical/transit/path_decrypt_test.go b/builtin/logical/transit/path_decrypt_test.go index 7e4d0a38e..c52cd6cf0 100644 --- a/builtin/logical/transit/path_decrypt_test.go +++ b/builtin/logical/transit/path_decrypt_test.go @@ -3,9 +3,13 @@ package transit import ( "context" "encoding/json" + "net/http" + "reflect" "testing" + "github.com/hashicorp/vault/sdk/helper/jsonutil" "github.com/hashicorp/vault/sdk/logical" + "github.com/mitchellh/mapstructure" ) func TestTransit_BatchDecryption(t *testing.T) { @@ -64,74 +68,179 @@ func TestTransit_BatchDecryption(t *testing.T) { } func TestTransit_BatchDecryption_DerivedKey(t *testing.T) { + var req *logical.Request var resp *logical.Response var err error b, s := createBackendWithStorage(t) - policyData := map[string]interface{}{ - "derived": true, - } - - policyReq := &logical.Request{ + // Create a derived key. + req = &logical.Request{ Operation: logical.UpdateOperation, Path: "keys/existing_key", Storage: s, - Data: policyData, + Data: map[string]interface{}{ + "derived": true, + }, } - - resp, err = b.HandleRequest(context.Background(), policyReq) + resp, err = b.HandleRequest(context.Background(), req) if err != nil || (resp != nil && resp.IsError()) { t.Fatalf("err:%v resp:%#v", err, resp) } - batchInput := []interface{}{ - map[string]interface{}{"plaintext": "dGhlIHF1aWNrIGJyb3duIGZveA==", "context": "dGVzdGNvbnRleHQ="}, - map[string]interface{}{"plaintext": "dGhlIHF1aWNrIGJyb3duIGZveA==", "context": "dGVzdGNvbnRleHQ="}, + // Encrypt some values for use in test cases. + plaintextItems := []struct { + plaintext, context string + }{ + {plaintext: "dGhlIHF1aWNrIGJyb3duIGZveA==", context: "dGVzdGNvbnRleHQ="}, + {plaintext: "anVtcGVkIG92ZXIgdGhlIGxhenkgZG9n", context: "dGVzdGNvbnRleHQy"}, } - - batchData := map[string]interface{}{ - "batch_input": batchInput, - } - batchReq := &logical.Request{ + req = &logical.Request{ Operation: logical.UpdateOperation, Path: "encrypt/existing_key", Storage: s, - Data: batchData, + Data: map[string]interface{}{ + "batch_input": []interface{}{ + map[string]interface{}{"plaintext": plaintextItems[0].plaintext, "context": plaintextItems[0].context}, + map[string]interface{}{"plaintext": plaintextItems[1].plaintext, "context": plaintextItems[1].context}, + }, + }, } - resp, err = b.HandleRequest(context.Background(), batchReq) + resp, err = b.HandleRequest(context.Background(), req) if err != nil || (resp != nil && resp.IsError()) { t.Fatalf("err:%v resp:%#v", err, resp) } - batchDecryptionInputItems := resp.Data["batch_results"].([]EncryptBatchResponseItem) + encryptedItems := resp.Data["batch_results"].([]EncryptBatchResponseItem) - batchDecryptionInput := make([]interface{}, len(batchDecryptionInputItems)) - for i, item := range batchDecryptionInputItems { - batchDecryptionInput[i] = map[string]interface{}{"ciphertext": item.Ciphertext, "context": "dGVzdGNvbnRleHQ="} + tests := []struct { + name string + in []interface{} + want []DecryptBatchResponseItem + shouldErr bool + wantHTTPStatus int + }{ + { + name: "nil-input", + in: nil, + shouldErr: true, + }, + { + name: "empty-input", + in: []interface{}{}, + shouldErr: true, + }, + { + name: "single-item-success", + in: []interface{}{ + map[string]interface{}{"ciphertext": encryptedItems[0].Ciphertext, "context": plaintextItems[0].context}, + }, + want: []DecryptBatchResponseItem{ + {Plaintext: plaintextItems[0].plaintext}, + }, + }, + { + name: "single-item-invalid-ciphertext", + in: []interface{}{ + map[string]interface{}{"ciphertext": "xxx", "context": plaintextItems[0].context}, + }, + want: []DecryptBatchResponseItem{ + {Error: "invalid ciphertext: no prefix"}, + }, + wantHTTPStatus: http.StatusBadRequest, + }, + { + name: "single-item-wrong-context", + in: []interface{}{ + map[string]interface{}{"ciphertext": encryptedItems[0].Ciphertext, "context": plaintextItems[1].context}, + }, + want: []DecryptBatchResponseItem{ + {Error: "cipher: message authentication failed"}, + }, + wantHTTPStatus: http.StatusBadRequest, + }, + { + name: "batch-full-success", + in: []interface{}{ + map[string]interface{}{"ciphertext": encryptedItems[0].Ciphertext, "context": plaintextItems[0].context}, + map[string]interface{}{"ciphertext": encryptedItems[1].Ciphertext, "context": plaintextItems[1].context}, + }, + want: []DecryptBatchResponseItem{ + {Plaintext: plaintextItems[0].plaintext}, + {Plaintext: plaintextItems[1].plaintext}, + }, + }, + { + name: "batch-partial-success", + in: []interface{}{ + map[string]interface{}{"ciphertext": encryptedItems[0].Ciphertext, "context": plaintextItems[1].context}, + map[string]interface{}{"ciphertext": encryptedItems[1].Ciphertext, "context": plaintextItems[1].context}, + }, + want: []DecryptBatchResponseItem{ + {Error: "cipher: message authentication failed"}, + {Plaintext: plaintextItems[1].plaintext}, + }, + wantHTTPStatus: http.StatusBadRequest, + }, + { + name: "batch-full-failure", + in: []interface{}{ + map[string]interface{}{"ciphertext": encryptedItems[0].Ciphertext, "context": plaintextItems[1].context}, + map[string]interface{}{"ciphertext": encryptedItems[1].Ciphertext, "context": plaintextItems[0].context}, + }, + want: []DecryptBatchResponseItem{ + {Error: "cipher: message authentication failed"}, + {Error: "cipher: message authentication failed"}, + }, + wantHTTPStatus: http.StatusBadRequest, + }, } - batchDecryptionData := map[string]interface{}{ - "batch_input": batchDecryptionInput, - } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req = &logical.Request{ + Operation: logical.UpdateOperation, + Path: "decrypt/existing_key", + Storage: s, + Data: map[string]interface{}{ + "batch_input": tt.in, + }, + } + resp, err = b.HandleRequest(context.Background(), req) - batchDecryptionReq := &logical.Request{ - Operation: logical.UpdateOperation, - Path: "decrypt/existing_key", - Storage: s, - Data: batchDecryptionData, - } - resp, err = b.HandleRequest(context.Background(), batchDecryptionReq) - if err != nil || (resp != nil && resp.IsError()) { - t.Fatalf("err:%v resp:%#v", err, resp) - } + didErr := err != nil || (resp != nil && resp.IsError()) + if didErr { + if !tt.shouldErr { + t.Fatalf("unexpected error err:%v, resp:%#v", err, resp) + } + } else { + if tt.shouldErr { + t.Fatal("expected error, but none occurred") + } - batchDecryptionResponseItems := resp.Data["batch_results"].([]DecryptBatchResponseItem) + if rawRespBody, ok := resp.Data[logical.HTTPRawBody]; ok { + httpResp := &logical.HTTPResponse{} + err = jsonutil.DecodeJSON([]byte(rawRespBody.(string)), httpResp) + if err != nil { + t.Fatalf("failed to unmarshal nested response: err:%v, resp:%#v", err, resp) + } - plaintext := "dGhlIHF1aWNrIGJyb3duIGZveA==" - for _, item := range batchDecryptionResponseItems { - if item.Plaintext != plaintext { - t.Fatalf("bad: plaintext. Expected: %q, Actual: %q", plaintext, item.Plaintext) - } + if respStatus, ok := resp.Data[logical.HTTPStatusCode]; !ok || respStatus != tt.wantHTTPStatus { + t.Fatalf("HTTP response status code mismatch, want:%d, got:%d", tt.wantHTTPStatus, respStatus) + } + + resp = logical.HTTPResponseToLogicalResponse(httpResp) + } + + var respItems []DecryptBatchResponseItem + err = mapstructure.Decode(resp.Data["batch_results"], &respItems) + if err != nil { + t.Fatalf("problem decoding response items: err:%v, resp:%#v", err, resp) + } + if !reflect.DeepEqual(tt.want, respItems) { + t.Fatalf("response items mismatch, want:%#v, got:%#v", tt.want, respItems) + } + } + }) } } diff --git a/builtin/logical/transit/path_encrypt.go b/builtin/logical/transit/path_encrypt.go index c32895164..deb564b12 100644 --- a/builtin/logical/transit/path_encrypt.go +++ b/builtin/logical/transit/path_encrypt.go @@ -5,6 +5,7 @@ import ( "encoding/base64" "encoding/json" "fmt" + "net/http" "reflect" "github.com/hashicorp/vault/sdk/framework" @@ -263,6 +264,9 @@ func (b *backend) pathEncryptWrite(ctx context.Context, req *logical.Request, d batchResponseItems := make([]EncryptBatchResponseItem, len(batchInputItems)) contextSet := len(batchInputItems[0].Context) != 0 + userErrorInBatch := false + internalErrorInBatch := false + // Before processing the batch request items, get the policy. If the // policy is supposed to be upserted, then determine if 'derived' is to // be set or not, based on the presence of 'context' field in all the @@ -274,6 +278,7 @@ func (b *backend) pathEncryptWrite(ctx context.Context, req *logical.Request, d _, err := base64.StdEncoding.DecodeString(item.Plaintext) if err != nil { + userErrorInBatch = true batchResponseItems[i].Error = err.Error() continue } @@ -282,6 +287,7 @@ func (b *backend) pathEncryptWrite(ctx context.Context, req *logical.Request, d if len(item.Context) != 0 { batchInputItems[i].DecodedContext, err = base64.StdEncoding.DecodeString(item.Context) if err != nil { + userErrorInBatch = true batchResponseItems[i].Error = err.Error() continue } @@ -291,6 +297,7 @@ func (b *backend) pathEncryptWrite(ctx context.Context, req *logical.Request, d if len(item.Nonce) != 0 { batchInputItems[i].DecodedNonce, err = base64.StdEncoding.DecodeString(item.Nonce) if err != nil { + userErrorInBatch = true batchResponseItems[i].Error = err.Error() continue } @@ -358,18 +365,19 @@ func (b *backend) pathEncryptWrite(ctx context.Context, req *logical.Request, d ciphertext, err := p.Encrypt(item.KeyVersion, item.DecodedContext, item.DecodedNonce, item.Plaintext) if err != nil { switch err.(type) { - case errutil.UserError: - batchResponseItems[i].Error = err.Error() - continue + case errutil.InternalError: + internalErrorInBatch = true default: - p.Unlock() - return nil, err + userErrorInBatch = true } + batchResponseItems[i].Error = err.Error() + continue } if ciphertext == "" { - p.Unlock() - return nil, fmt.Errorf("empty ciphertext returned for input item %d", i) + userErrorInBatch = true + batchResponseItems[i].Error = fmt.Sprintf("empty ciphertext returned for input item %d", i) + continue } keyVersion := item.KeyVersion @@ -389,6 +397,11 @@ func (b *backend) pathEncryptWrite(ctx context.Context, req *logical.Request, d } else { if batchResponseItems[0].Error != "" { p.Unlock() + + if internalErrorInBatch { + return nil, errutil.InternalError{Err: batchResponseItems[0].Error} + } + return logical.ErrorResponse(batchResponseItems[0].Error), logical.ErrInvalidRequest } @@ -403,6 +416,18 @@ func (b *backend) pathEncryptWrite(ctx context.Context, req *logical.Request, d } p.Unlock() + + // Depending on the errors in the batch, different status codes should be returned. User errors + // will return a 400 and precede internal errors which return a 500. The reasoning behind this is + // that user errors are non-retryable without making changes to the request, and should be surfaced + // to the user first. + switch { + case userErrorInBatch: + return logical.RespondWithStatusCode(resp, req, http.StatusBadRequest) + case internalErrorInBatch: + return logical.RespondWithStatusCode(resp, req, http.StatusInternalServerError) + } + return resp, nil } diff --git a/changelog/13111.txt b/changelog/13111.txt new file mode 100644 index 000000000..800cabd9a --- /dev/null +++ b/changelog/13111.txt @@ -0,0 +1,3 @@ +```release-note:improvement +secrets/transit: Don't abort transit encrypt or decrypt batches on single item failure. +```