From 3d1320d997f411e3e07bae87aa133d16ae01094b Mon Sep 17 00:00:00 2001 From: Chris Hoffman Date: Mon, 12 Nov 2018 10:57:02 -0500 Subject: [PATCH] Fixing AliCloud KMS seal encryption/decryption (#5756) * fixing seal encryption/decryption * Address feedback. Co-Authored-By: chrishoffman --- vault/seal/alicloudkms/alicloudkms.go | 18 +++++-- .../seal/alicloudkms/alicloudkms_acc_test.go | 48 +++++++++++++++++++ 2 files changed, 61 insertions(+), 5 deletions(-) create mode 100644 vault/seal/alicloudkms/alicloudkms_acc_test.go diff --git a/vault/seal/alicloudkms/alicloudkms.go b/vault/seal/alicloudkms/alicloudkms.go index e34ccf7d5..a2bf24713 100644 --- a/vault/seal/alicloudkms/alicloudkms.go +++ b/vault/seal/alicloudkms/alicloudkms.go @@ -2,6 +2,7 @@ package alicloudkms import ( "context" + "encoding/base64" "errors" "fmt" "os" @@ -64,13 +65,12 @@ func (k *AliCloudKMSSeal) SetConfig(config map[string]string) (map[string]string region := "" if k.client == nil { - // Check and set region. region = os.Getenv("ALICLOUD_REGION") if region == "" { ok := false if region, ok = config["region"]; !ok { - region = "us-east-1" + region = "cn-beijing" } } @@ -129,6 +129,9 @@ func (k *AliCloudKMSSeal) SetConfig(config map[string]string) (map[string]string if keyInfo == nil || keyInfo.KeyMetadata.KeyId == "" { return nil, errors.New("no key information returned") } + + // Store the current key id. If using a key alias, this will point to the actual + // unique key that that was used for this encrypt operation. k.currentKeyID.Store(keyInfo.KeyMetadata.KeyId) // Map that holds non-sensitive configuration info @@ -178,7 +181,7 @@ func (k *AliCloudKMSSeal) Encrypt(_ context.Context, plaintext []byte) (*physica input := kms.CreateEncryptRequest() input.KeyId = k.keyID - input.Plaintext = string(env.Key) + input.Plaintext = base64.StdEncoding.EncodeToString(env.Key) input.Domain = k.domain output, err := k.client.Encrypt(input) @@ -208,7 +211,7 @@ func (k *AliCloudKMSSeal) Decrypt(_ context.Context, in *physical.EncryptedBlobI return nil, fmt.Errorf("given input for decryption is nil") } - // KeyID is not passed to this call because AWS handles this + // KeyID is not passed to this call because AliCloud handles this // internally based on the metadata stored with the encrypted data input := kms.CreateDecryptRequest() input.CiphertextBlob = string(in.KeyInfo.WrappedKey) @@ -219,8 +222,13 @@ func (k *AliCloudKMSSeal) Decrypt(_ context.Context, in *physical.EncryptedBlobI return nil, errwrap.Wrapf("error decrypting data encryption key: {{err}}", err) } + keyBytes, err := base64.StdEncoding.DecodeString(output.Plaintext) + if err != nil { + return nil, err + } + envInfo := &seal.EnvelopeInfo{ - Key: []byte(output.Plaintext), + Key: keyBytes, IV: in.IV, Ciphertext: in.Ciphertext, } diff --git a/vault/seal/alicloudkms/alicloudkms_acc_test.go b/vault/seal/alicloudkms/alicloudkms_acc_test.go new file mode 100644 index 000000000..656167845 --- /dev/null +++ b/vault/seal/alicloudkms/alicloudkms_acc_test.go @@ -0,0 +1,48 @@ +package alicloudkms + +import ( + "context" + "os" + "reflect" + "testing" + + log "github.com/hashicorp/go-hclog" + "github.com/hashicorp/vault/helper/logging" +) + +// This test executes real calls. The calls themselves should be free, +// but the KMS key used is generally not free. Alibaba doesn't publish +// the price but it can be assumed to be around $1/month because that's +// what AWS charges for the same. +// +// To run this test, the following env variables need to be set: +// - VAULT_ALICLOUDKMS_SEAL_KEY_ID +// - ALICLOUD_REGION +// - ALICLOUD_ACCESS_KEY +// - ALICLOUD_SECRET_KEY +func TestAccAliCloudKMSSeal_Lifecycle(t *testing.T) { + if os.Getenv("VAULT_ACC") == "" { + t.SkipNow() + } + + s := NewSeal(logging.NewVaultLogger(log.Trace)) + _, err := s.SetConfig(nil) + if err != nil { + t.Fatalf("err : %s", err) + } + + input := []byte("foo") + swi, err := s.Encrypt(context.Background(), input) + if err != nil { + t.Fatalf("err: %s", err.Error()) + } + + pt, err := s.Decrypt(context.Background(), swi) + if err != nil { + t.Fatalf("err: %s", err.Error()) + } + + if !reflect.DeepEqual(input, pt) { + t.Fatalf("expected %s, got %s", input, pt) + } +}