Return non-retryable errors on transit encrypt and decrypt failures (#13111)

* Return HTTP 400s on transit decrypt requests where decryption fails. (#10842)

* Don't abort transit batch encryption when a single batch item fails.

* Add unit tests for updated transit batch decryption behavior.

* Add changelog entry for transit encrypt/decrypt batch abort fix.

* Simplify transit batch error message generation when ciphertext is empty.

* Return error HTTP status codes in transit on partial batch decrypt failure.

* Return error HTTP status codes in transit on partial batch encrypt failure.

* Properly account for non-batch transit decryption failure return. Simplify transit batch decryption test data. Ensure HTTP status codes are expected values on batch transit batch decryption partial failure.

* Properly account for non-batch transit encryption failure return. Actually return error HTTP status code on transit batch encryption failure (partial or full).
This commit is contained in:
Matt Schultz 2021-11-15 15:53:22 -06:00 committed by GitHub
parent 3d46021d4e
commit 0abd248c9f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 214 additions and 53 deletions

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"encoding/base64" "encoding/base64"
"fmt" "fmt"
"net/http"
"github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/errutil" "github.com/hashicorp/vault/sdk/helper/errutil"
@ -91,12 +92,16 @@ func (b *backend) pathDecryptWrite(ctx context.Context, req *logical.Request, d
batchResponseItems := make([]DecryptBatchResponseItem, len(batchInputItems)) batchResponseItems := make([]DecryptBatchResponseItem, len(batchInputItems))
contextSet := len(batchInputItems[0].Context) != 0 contextSet := len(batchInputItems[0].Context) != 0
userErrorInBatch := false
internalErrorInBatch := false
for i, item := range batchInputItems { for i, item := range batchInputItems {
if (len(item.Context) == 0 && contextSet) || (len(item.Context) != 0 && !contextSet) { if (len(item.Context) == 0 && contextSet) || (len(item.Context) != 0 && !contextSet) {
return logical.ErrorResponse("context should be set either in all the request blocks or in none"), logical.ErrInvalidRequest return logical.ErrorResponse("context should be set either in all the request blocks or in none"), logical.ErrInvalidRequest
} }
if item.Ciphertext == "" { if item.Ciphertext == "" {
userErrorInBatch = true
batchResponseItems[i].Error = "missing ciphertext to decrypt" batchResponseItems[i].Error = "missing ciphertext to decrypt"
continue continue
} }
@ -105,6 +110,7 @@ func (b *backend) pathDecryptWrite(ctx context.Context, req *logical.Request, d
if len(item.Context) != 0 { if len(item.Context) != 0 {
batchInputItems[i].DecodedContext, err = base64.StdEncoding.DecodeString(item.Context) batchInputItems[i].DecodedContext, err = base64.StdEncoding.DecodeString(item.Context)
if err != nil { if err != nil {
userErrorInBatch = true
batchResponseItems[i].Error = err.Error() batchResponseItems[i].Error = err.Error()
continue continue
} }
@ -114,6 +120,7 @@ func (b *backend) pathDecryptWrite(ctx context.Context, req *logical.Request, d
if len(item.Nonce) != 0 { if len(item.Nonce) != 0 {
batchInputItems[i].DecodedNonce, err = base64.StdEncoding.DecodeString(item.Nonce) batchInputItems[i].DecodedNonce, err = base64.StdEncoding.DecodeString(item.Nonce)
if err != nil { if err != nil {
userErrorInBatch = true
batchResponseItems[i].Error = err.Error() batchResponseItems[i].Error = err.Error()
continue continue
} }
@ -143,13 +150,13 @@ func (b *backend) pathDecryptWrite(ctx context.Context, req *logical.Request, d
plaintext, err := p.Decrypt(item.DecodedContext, item.DecodedNonce, item.Ciphertext) plaintext, err := p.Decrypt(item.DecodedContext, item.DecodedNonce, item.Ciphertext)
if err != nil { if err != nil {
switch err.(type) { switch err.(type) {
case errutil.UserError: case errutil.InternalError:
internalErrorInBatch = true
default:
userErrorInBatch = true
}
batchResponseItems[i].Error = err.Error() batchResponseItems[i].Error = err.Error()
continue continue
default:
p.Unlock()
return nil, err
}
} }
batchResponseItems[i].Plaintext = plaintext batchResponseItems[i].Plaintext = plaintext
} }
@ -162,6 +169,11 @@ func (b *backend) pathDecryptWrite(ctx context.Context, req *logical.Request, d
} else { } else {
if batchResponseItems[0].Error != "" { if batchResponseItems[0].Error != "" {
p.Unlock() p.Unlock()
if internalErrorInBatch {
return nil, errutil.InternalError{Err: batchResponseItems[0].Error}
}
return logical.ErrorResponse(batchResponseItems[0].Error), logical.ErrInvalidRequest return logical.ErrorResponse(batchResponseItems[0].Error), logical.ErrInvalidRequest
} }
resp.Data = map[string]interface{}{ resp.Data = map[string]interface{}{
@ -170,6 +182,18 @@ func (b *backend) pathDecryptWrite(ctx context.Context, req *logical.Request, d
} }
p.Unlock() p.Unlock()
// Depending on the errors in the batch, different status codes should be returned. User errors
// will return a 400 and precede internal errors which return a 500. The reasoning behind this is
// that user errors are non-retryable without making changes to the request, and should be surfaced
// to the user first.
switch {
case userErrorInBatch:
return logical.RespondWithStatusCode(resp, req, http.StatusBadRequest)
case internalErrorInBatch:
return logical.RespondWithStatusCode(resp, req, http.StatusInternalServerError)
}
return resp, nil return resp, nil
} }

View File

@ -3,9 +3,13 @@ package transit
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"net/http"
"reflect"
"testing" "testing"
"github.com/hashicorp/vault/sdk/helper/jsonutil"
"github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/sdk/logical"
"github.com/mitchellh/mapstructure"
) )
func TestTransit_BatchDecryption(t *testing.T) { func TestTransit_BatchDecryption(t *testing.T) {
@ -64,74 +68,179 @@ func TestTransit_BatchDecryption(t *testing.T) {
} }
func TestTransit_BatchDecryption_DerivedKey(t *testing.T) { func TestTransit_BatchDecryption_DerivedKey(t *testing.T) {
var req *logical.Request
var resp *logical.Response var resp *logical.Response
var err error var err error
b, s := createBackendWithStorage(t) b, s := createBackendWithStorage(t)
policyData := map[string]interface{}{ // Create a derived key.
"derived": true, req = &logical.Request{
}
policyReq := &logical.Request{
Operation: logical.UpdateOperation, Operation: logical.UpdateOperation,
Path: "keys/existing_key", Path: "keys/existing_key",
Storage: s, Storage: s,
Data: policyData, Data: map[string]interface{}{
"derived": true,
},
} }
resp, err = b.HandleRequest(context.Background(), req)
resp, err = b.HandleRequest(context.Background(), policyReq)
if err != nil || (resp != nil && resp.IsError()) { if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp) t.Fatalf("err:%v resp:%#v", err, resp)
} }
batchInput := []interface{}{ // Encrypt some values for use in test cases.
map[string]interface{}{"plaintext": "dGhlIHF1aWNrIGJyb3duIGZveA==", "context": "dGVzdGNvbnRleHQ="}, plaintextItems := []struct {
map[string]interface{}{"plaintext": "dGhlIHF1aWNrIGJyb3duIGZveA==", "context": "dGVzdGNvbnRleHQ="}, plaintext, context string
}{
{plaintext: "dGhlIHF1aWNrIGJyb3duIGZveA==", context: "dGVzdGNvbnRleHQ="},
{plaintext: "anVtcGVkIG92ZXIgdGhlIGxhenkgZG9n", context: "dGVzdGNvbnRleHQy"},
} }
req = &logical.Request{
batchData := map[string]interface{}{
"batch_input": batchInput,
}
batchReq := &logical.Request{
Operation: logical.UpdateOperation, Operation: logical.UpdateOperation,
Path: "encrypt/existing_key", Path: "encrypt/existing_key",
Storage: s, Storage: s,
Data: batchData, Data: map[string]interface{}{
"batch_input": []interface{}{
map[string]interface{}{"plaintext": plaintextItems[0].plaintext, "context": plaintextItems[0].context},
map[string]interface{}{"plaintext": plaintextItems[1].plaintext, "context": plaintextItems[1].context},
},
},
} }
resp, err = b.HandleRequest(context.Background(), batchReq) resp, err = b.HandleRequest(context.Background(), req)
if err != nil || (resp != nil && resp.IsError()) { if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp) t.Fatalf("err:%v resp:%#v", err, resp)
} }
batchDecryptionInputItems := resp.Data["batch_results"].([]EncryptBatchResponseItem) encryptedItems := resp.Data["batch_results"].([]EncryptBatchResponseItem)
batchDecryptionInput := make([]interface{}, len(batchDecryptionInputItems)) tests := []struct {
for i, item := range batchDecryptionInputItems { name string
batchDecryptionInput[i] = map[string]interface{}{"ciphertext": item.Ciphertext, "context": "dGVzdGNvbnRleHQ="} in []interface{}
want []DecryptBatchResponseItem
shouldErr bool
wantHTTPStatus int
}{
{
name: "nil-input",
in: nil,
shouldErr: true,
},
{
name: "empty-input",
in: []interface{}{},
shouldErr: true,
},
{
name: "single-item-success",
in: []interface{}{
map[string]interface{}{"ciphertext": encryptedItems[0].Ciphertext, "context": plaintextItems[0].context},
},
want: []DecryptBatchResponseItem{
{Plaintext: plaintextItems[0].plaintext},
},
},
{
name: "single-item-invalid-ciphertext",
in: []interface{}{
map[string]interface{}{"ciphertext": "xxx", "context": plaintextItems[0].context},
},
want: []DecryptBatchResponseItem{
{Error: "invalid ciphertext: no prefix"},
},
wantHTTPStatus: http.StatusBadRequest,
},
{
name: "single-item-wrong-context",
in: []interface{}{
map[string]interface{}{"ciphertext": encryptedItems[0].Ciphertext, "context": plaintextItems[1].context},
},
want: []DecryptBatchResponseItem{
{Error: "cipher: message authentication failed"},
},
wantHTTPStatus: http.StatusBadRequest,
},
{
name: "batch-full-success",
in: []interface{}{
map[string]interface{}{"ciphertext": encryptedItems[0].Ciphertext, "context": plaintextItems[0].context},
map[string]interface{}{"ciphertext": encryptedItems[1].Ciphertext, "context": plaintextItems[1].context},
},
want: []DecryptBatchResponseItem{
{Plaintext: plaintextItems[0].plaintext},
{Plaintext: plaintextItems[1].plaintext},
},
},
{
name: "batch-partial-success",
in: []interface{}{
map[string]interface{}{"ciphertext": encryptedItems[0].Ciphertext, "context": plaintextItems[1].context},
map[string]interface{}{"ciphertext": encryptedItems[1].Ciphertext, "context": plaintextItems[1].context},
},
want: []DecryptBatchResponseItem{
{Error: "cipher: message authentication failed"},
{Plaintext: plaintextItems[1].plaintext},
},
wantHTTPStatus: http.StatusBadRequest,
},
{
name: "batch-full-failure",
in: []interface{}{
map[string]interface{}{"ciphertext": encryptedItems[0].Ciphertext, "context": plaintextItems[1].context},
map[string]interface{}{"ciphertext": encryptedItems[1].Ciphertext, "context": plaintextItems[0].context},
},
want: []DecryptBatchResponseItem{
{Error: "cipher: message authentication failed"},
{Error: "cipher: message authentication failed"},
},
wantHTTPStatus: http.StatusBadRequest,
},
} }
batchDecryptionData := map[string]interface{}{ for _, tt := range tests {
"batch_input": batchDecryptionInput, t.Run(tt.name, func(t *testing.T) {
} req = &logical.Request{
batchDecryptionReq := &logical.Request{
Operation: logical.UpdateOperation, Operation: logical.UpdateOperation,
Path: "decrypt/existing_key", Path: "decrypt/existing_key",
Storage: s, Storage: s,
Data: batchDecryptionData, Data: map[string]interface{}{
"batch_input": tt.in,
},
} }
resp, err = b.HandleRequest(context.Background(), batchDecryptionReq) resp, err = b.HandleRequest(context.Background(), req)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp) didErr := err != nil || (resp != nil && resp.IsError())
if didErr {
if !tt.shouldErr {
t.Fatalf("unexpected error err:%v, resp:%#v", err, resp)
}
} else {
if tt.shouldErr {
t.Fatal("expected error, but none occurred")
} }
batchDecryptionResponseItems := resp.Data["batch_results"].([]DecryptBatchResponseItem) if rawRespBody, ok := resp.Data[logical.HTTPRawBody]; ok {
httpResp := &logical.HTTPResponse{}
plaintext := "dGhlIHF1aWNrIGJyb3duIGZveA==" err = jsonutil.DecodeJSON([]byte(rawRespBody.(string)), httpResp)
for _, item := range batchDecryptionResponseItems { if err != nil {
if item.Plaintext != plaintext { t.Fatalf("failed to unmarshal nested response: err:%v, resp:%#v", err, resp)
t.Fatalf("bad: plaintext. Expected: %q, Actual: %q", plaintext, item.Plaintext)
} }
if respStatus, ok := resp.Data[logical.HTTPStatusCode]; !ok || respStatus != tt.wantHTTPStatus {
t.Fatalf("HTTP response status code mismatch, want:%d, got:%d", tt.wantHTTPStatus, respStatus)
}
resp = logical.HTTPResponseToLogicalResponse(httpResp)
}
var respItems []DecryptBatchResponseItem
err = mapstructure.Decode(resp.Data["batch_results"], &respItems)
if err != nil {
t.Fatalf("problem decoding response items: err:%v, resp:%#v", err, resp)
}
if !reflect.DeepEqual(tt.want, respItems) {
t.Fatalf("response items mismatch, want:%#v, got:%#v", tt.want, respItems)
}
}
})
} }
} }

View File

@ -5,6 +5,7 @@ import (
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http"
"reflect" "reflect"
"github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/framework"
@ -263,6 +264,9 @@ func (b *backend) pathEncryptWrite(ctx context.Context, req *logical.Request, d
batchResponseItems := make([]EncryptBatchResponseItem, len(batchInputItems)) batchResponseItems := make([]EncryptBatchResponseItem, len(batchInputItems))
contextSet := len(batchInputItems[0].Context) != 0 contextSet := len(batchInputItems[0].Context) != 0
userErrorInBatch := false
internalErrorInBatch := false
// Before processing the batch request items, get the policy. If the // Before processing the batch request items, get the policy. If the
// policy is supposed to be upserted, then determine if 'derived' is to // policy is supposed to be upserted, then determine if 'derived' is to
// be set or not, based on the presence of 'context' field in all the // be set or not, based on the presence of 'context' field in all the
@ -274,6 +278,7 @@ func (b *backend) pathEncryptWrite(ctx context.Context, req *logical.Request, d
_, err := base64.StdEncoding.DecodeString(item.Plaintext) _, err := base64.StdEncoding.DecodeString(item.Plaintext)
if err != nil { if err != nil {
userErrorInBatch = true
batchResponseItems[i].Error = err.Error() batchResponseItems[i].Error = err.Error()
continue continue
} }
@ -282,6 +287,7 @@ func (b *backend) pathEncryptWrite(ctx context.Context, req *logical.Request, d
if len(item.Context) != 0 { if len(item.Context) != 0 {
batchInputItems[i].DecodedContext, err = base64.StdEncoding.DecodeString(item.Context) batchInputItems[i].DecodedContext, err = base64.StdEncoding.DecodeString(item.Context)
if err != nil { if err != nil {
userErrorInBatch = true
batchResponseItems[i].Error = err.Error() batchResponseItems[i].Error = err.Error()
continue continue
} }
@ -291,6 +297,7 @@ func (b *backend) pathEncryptWrite(ctx context.Context, req *logical.Request, d
if len(item.Nonce) != 0 { if len(item.Nonce) != 0 {
batchInputItems[i].DecodedNonce, err = base64.StdEncoding.DecodeString(item.Nonce) batchInputItems[i].DecodedNonce, err = base64.StdEncoding.DecodeString(item.Nonce)
if err != nil { if err != nil {
userErrorInBatch = true
batchResponseItems[i].Error = err.Error() batchResponseItems[i].Error = err.Error()
continue continue
} }
@ -358,18 +365,19 @@ func (b *backend) pathEncryptWrite(ctx context.Context, req *logical.Request, d
ciphertext, err := p.Encrypt(item.KeyVersion, item.DecodedContext, item.DecodedNonce, item.Plaintext) ciphertext, err := p.Encrypt(item.KeyVersion, item.DecodedContext, item.DecodedNonce, item.Plaintext)
if err != nil { if err != nil {
switch err.(type) { switch err.(type) {
case errutil.UserError: case errutil.InternalError:
internalErrorInBatch = true
default:
userErrorInBatch = true
}
batchResponseItems[i].Error = err.Error() batchResponseItems[i].Error = err.Error()
continue continue
default:
p.Unlock()
return nil, err
}
} }
if ciphertext == "" { if ciphertext == "" {
p.Unlock() userErrorInBatch = true
return nil, fmt.Errorf("empty ciphertext returned for input item %d", i) batchResponseItems[i].Error = fmt.Sprintf("empty ciphertext returned for input item %d", i)
continue
} }
keyVersion := item.KeyVersion keyVersion := item.KeyVersion
@ -389,6 +397,11 @@ func (b *backend) pathEncryptWrite(ctx context.Context, req *logical.Request, d
} else { } else {
if batchResponseItems[0].Error != "" { if batchResponseItems[0].Error != "" {
p.Unlock() p.Unlock()
if internalErrorInBatch {
return nil, errutil.InternalError{Err: batchResponseItems[0].Error}
}
return logical.ErrorResponse(batchResponseItems[0].Error), logical.ErrInvalidRequest return logical.ErrorResponse(batchResponseItems[0].Error), logical.ErrInvalidRequest
} }
@ -403,6 +416,18 @@ func (b *backend) pathEncryptWrite(ctx context.Context, req *logical.Request, d
} }
p.Unlock() p.Unlock()
// Depending on the errors in the batch, different status codes should be returned. User errors
// will return a 400 and precede internal errors which return a 500. The reasoning behind this is
// that user errors are non-retryable without making changes to the request, and should be surfaced
// to the user first.
switch {
case userErrorInBatch:
return logical.RespondWithStatusCode(resp, req, http.StatusBadRequest)
case internalErrorInBatch:
return logical.RespondWithStatusCode(resp, req, http.StatusInternalServerError)
}
return resp, nil return resp, nil
} }

3
changelog/13111.txt Normal file
View File

@ -0,0 +1,3 @@
```release-note:improvement
secrets/transit: Don't abort transit encrypt or decrypt batches on single item failure.
```