open-vault/builtin/logical/transit/stepwise_test.go

234 lines
7.3 KiB
Go

package transit
import (
"encoding/base64"
"fmt"
"os"
"testing"
stepwise "github.com/hashicorp/vault-testing-stepwise"
dockerEnvironment "github.com/hashicorp/vault-testing-stepwise/environments/docker"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/sdk/helper/keysutil"
"github.com/mitchellh/mapstructure"
)
// TestBackend_basic_docker is an example test using the Docker Environment
func TestAccBackend_basic_docker(t *testing.T) {
decryptData := make(map[string]interface{})
envOptions := stepwise.MountOptions{
RegistryName: "updatedtransit",
PluginType: stepwise.PluginTypeSecrets,
PluginName: "transit",
MountPathPrefix: "transit_temp",
}
stepwise.Run(t, stepwise.Case{
Environment: dockerEnvironment.NewEnvironment("updatedtransit", &envOptions),
Steps: []stepwise.Step{
testAccStepwiseListPolicy(t, "test", true),
testAccStepwiseWritePolicy(t, "test", true),
testAccStepwiseListPolicy(t, "test", false),
testAccStepwiseReadPolicy(t, "test", false, true),
testAccStepwiseEncryptContext(t, "test", testPlaintext, "my-cool-context", decryptData),
testAccStepwiseDecrypt(t, "test", testPlaintext, decryptData),
testAccStepwiseEnableDeletion(t, "test"),
testAccStepwiseDeletePolicy(t, "test"),
testAccStepwiseReadPolicy(t, "test", true, true),
},
})
}
func testAccStepwiseWritePolicy(t *testing.T, name string, derived bool) stepwise.Step {
ts := stepwise.Step{
Operation: stepwise.WriteOperation,
Path: "keys/" + name,
Data: map[string]interface{}{
"derived": derived,
},
}
if os.Getenv("TRANSIT_ACC_KEY_TYPE") == "CHACHA" {
ts.Data["type"] = "chacha20-poly1305"
}
return ts
}
func testAccStepwiseListPolicy(t *testing.T, name string, expectNone bool) stepwise.Step {
return stepwise.Step{
Operation: stepwise.ListOperation,
Path: "keys",
Assert: func(resp *api.Secret, err error) error {
if (resp == nil || len(resp.Data) == 0) && !expectNone {
return fmt.Errorf("missing response")
}
if expectNone && resp != nil {
return fmt.Errorf("response data when expecting none")
}
if expectNone && resp == nil {
return nil
}
var d struct {
Keys []string `mapstructure:"keys"`
}
if err := mapstructure.Decode(resp.Data, &d); err != nil {
return err
}
if len(d.Keys) == 0 {
return fmt.Errorf("missing keys")
}
if len(d.Keys) > 1 {
return fmt.Errorf("only 1 key expected, %d returned", len(d.Keys))
}
if d.Keys[0] != name {
return fmt.Errorf("Actual key: %s\nExpected key: %s", d.Keys[0], name)
}
return nil
},
}
}
func testAccStepwiseReadPolicy(t *testing.T, name string, expectNone, derived bool) stepwise.Step {
t.Helper()
return testAccStepwiseReadPolicyWithVersions(t, name, expectNone, derived, 1, 0)
}
func testAccStepwiseReadPolicyWithVersions(t *testing.T, name string, expectNone, derived bool, minDecryptionVersion int, minEncryptionVersion int) stepwise.Step {
t.Helper()
return stepwise.Step{
Operation: stepwise.ReadOperation,
Path: "keys/" + name,
Assert: func(resp *api.Secret, err error) error {
t.Helper()
if resp == nil && !expectNone {
return fmt.Errorf("missing response")
} else if expectNone {
if resp != nil {
return fmt.Errorf("response when expecting none")
}
return nil
}
var d struct {
Name string `mapstructure:"name"`
Key []byte `mapstructure:"key"`
Keys map[string]int64 `mapstructure:"keys"`
Type string `mapstructure:"type"`
Derived bool `mapstructure:"derived"`
KDF string `mapstructure:"kdf"`
DeletionAllowed bool `mapstructure:"deletion_allowed"`
ConvergentEncryption bool `mapstructure:"convergent_encryption"`
MinDecryptionVersion int `mapstructure:"min_decryption_version"`
MinEncryptionVersion int `mapstructure:"min_encryption_version"`
}
if err := mapstructure.Decode(resp.Data, &d); err != nil {
return err
}
if d.Name != name {
return fmt.Errorf("bad name: %#v", d)
}
if os.Getenv("TRANSIT_ACC_KEY_TYPE") == "CHACHA" {
if d.Type != keysutil.KeyType(keysutil.KeyType_ChaCha20_Poly1305).String() {
return fmt.Errorf("bad key type: %#v", d)
}
} else if d.Type != keysutil.KeyType(keysutil.KeyType_AES256_GCM96).String() {
return fmt.Errorf("bad key type: %#v", d)
}
// Should NOT get a key back
if d.Key != nil {
return fmt.Errorf("unexpected key found")
}
if d.Keys == nil {
return fmt.Errorf("no keys found")
}
if d.MinDecryptionVersion != minDecryptionVersion {
return fmt.Errorf("minimum decryption version mismatch, expected (%#v), found (%#v)", minEncryptionVersion, d.MinDecryptionVersion)
}
if d.MinEncryptionVersion != minEncryptionVersion {
return fmt.Errorf("minimum encryption version mismatch, expected (%#v), found (%#v)", minEncryptionVersion, d.MinDecryptionVersion)
}
if d.DeletionAllowed == true {
return fmt.Errorf("expected DeletionAllowed to be false, but got true")
}
if d.Derived != derived {
return fmt.Errorf("derived mismatch, expected (%t), got (%t)", derived, d.Derived)
}
if derived && d.KDF != "hkdf_sha256" {
return fmt.Errorf("expected KDF to be hkdf_sha256, but got (%s)", d.KDF)
}
return nil
},
}
}
func testAccStepwiseEncryptContext(
t *testing.T, name, plaintext, context string, decryptData map[string]interface{}) stepwise.Step {
return stepwise.Step{
Operation: stepwise.UpdateOperation,
Path: "encrypt/" + name,
Data: map[string]interface{}{
"plaintext": base64.StdEncoding.EncodeToString([]byte(plaintext)),
"context": base64.StdEncoding.EncodeToString([]byte(context)),
},
Assert: func(resp *api.Secret, err error) error {
var d struct {
Ciphertext string `mapstructure:"ciphertext"`
}
if err := mapstructure.Decode(resp.Data, &d); err != nil {
return err
}
if d.Ciphertext == "" {
return fmt.Errorf("missing ciphertext")
}
decryptData["ciphertext"] = d.Ciphertext
decryptData["context"] = base64.StdEncoding.EncodeToString([]byte(context))
return nil
},
}
}
func testAccStepwiseDecrypt(
t *testing.T, name, plaintext string, decryptData map[string]interface{}) stepwise.Step {
return stepwise.Step{
Operation: stepwise.UpdateOperation,
Path: "decrypt/" + name,
Data: decryptData,
Assert: func(resp *api.Secret, err error) error {
var d struct {
Plaintext string `mapstructure:"plaintext"`
}
if err := mapstructure.Decode(resp.Data, &d); err != nil {
return err
}
// Decode the base64
plainRaw, err := base64.StdEncoding.DecodeString(d.Plaintext)
if err != nil {
return err
}
if string(plainRaw) != plaintext {
return fmt.Errorf("plaintext mismatch: %s expect: %s, decryptData was %#v", plainRaw, plaintext, decryptData)
}
return nil
},
}
}
func testAccStepwiseEnableDeletion(t *testing.T, name string) stepwise.Step {
return stepwise.Step{
Operation: stepwise.UpdateOperation,
Path: "keys/" + name + "/config",
Data: map[string]interface{}{
"deletion_allowed": true,
},
}
}
func testAccStepwiseDeletePolicy(t *testing.T, name string) stepwise.Step {
return stepwise.Step{
Operation: stepwise.DeleteOperation,
Path: "keys/" + name,
}
}