190 lines
4.0 KiB
Go
190 lines
4.0 KiB
Go
|
// Copyright (c) HashiCorp, Inc.
|
||
|
// SPDX-License-Identifier: MPL-2.0
|
||
|
|
||
|
package command
|
||
|
|
||
|
import (
|
||
|
"bytes"
|
||
|
"crypto/rand"
|
||
|
"crypto/rsa"
|
||
|
"crypto/x509"
|
||
|
"encoding/base64"
|
||
|
"testing"
|
||
|
|
||
|
"github.com/hashicorp/vault/api"
|
||
|
|
||
|
"github.com/stretchr/testify/require"
|
||
|
)
|
||
|
|
||
|
// TestTransitImport validates that the `vault transit import` and
// `vault transit import-version` commands accept and reject key material as
// expected against a freshly mounted transit engine.
//
// The cases share server state and run strictly in order: later cases rely on
// keys created by earlier ones (e.g. the second "import" of transit/keys/rsa1
// is expected to fail only because the first one succeeded).
func TestTransitImport(t *testing.T) {
	t.Parallel()

	client, closer := testVaultServer(t)
	defer closer()

	mountErr := client.Sys().Mount("transit", &api.MountInput{Type: "transit"})
	if mountErr != nil {
		t.Fatalf("transit mount error: %#v", mountErr)
	}

	rsa1, rsa2, aes128, aes256 := generateKeys(t)

	// Each case runs `vault transit <variant> <path> <base64 key> <args...>`.
	type testCase struct {
		variant    string
		path       string
		key        []byte
		args       []string
		shouldFail bool
	}

	tests := []testCase{
		// RSA key rsa1: initial import, duplicate import, then a new version.
		{variant: "import", path: "transit/keys/rsa1", key: rsa1, args: []string{"type=rsa-2048"}, shouldFail: false},  // first import
		{variant: "import", path: "transit/keys/rsa1", key: rsa2, args: []string{"type=rsa-2048"}, shouldFail: true},   // already exists
		{variant: "import-version", path: "transit/keys/rsa1", key: rsa2, args: []string{"type=rsa-2048"}, shouldFail: false}, // new version

		// RSA key rsa2: mismatched declared type, then a valid import.
		{variant: "import", path: "transit/keys/rsa2", key: rsa2, args: []string{"type=rsa-4096"}, shouldFail: true},  // wrong type
		{variant: "import", path: "transit/keys/rsa2", key: rsa2, args: []string{"type=rsa-2048"}, shouldFail: false}, // new name

		// AES key aes1: initial import, duplicate import, mismatched version, valid version.
		{variant: "import", path: "transit/keys/aes1", key: aes128, args: []string{"type=aes128-gcm96"}, shouldFail: false},         // first import
		{variant: "import", path: "transit/keys/aes1", key: aes256, args: []string{"type=aes256-gcm96"}, shouldFail: true},          // already exists
		{variant: "import-version", path: "transit/keys/aes1", key: aes256, args: []string{"type=aes256-gcm96"}, shouldFail: true},  // new version, different type
		{variant: "import-version", path: "transit/keys/aes1", key: aes128, args: []string{"type=aes128-gcm96"}, shouldFail: false}, // new version

		// AES key aes2: mismatched declared type, then a valid import.
		{variant: "import", path: "transit/keys/aes2", key: aes256, args: []string{"type=aes128-gcm96"}, shouldFail: true},  // wrong type
		{variant: "import", path: "transit/keys/aes2", key: aes256, args: []string{"type=aes256-gcm96"}, shouldFail: false}, // new name
	}

	for i := range tests {
		tc := tests[i]
		t.Logf("Running test case %d: %v", i, tc)
		execTransitImport(t, client, tc.variant, tc.path, tc.key, tc.args, tc.shouldFail)
	}
}
|
||
|
|
||
|
// execTransitImport runs `vault transit <method> <path> <base64(key)> <data...>`
// through RunCustom against client and fails the test when the outcome does
// not match expectFailure. A non-zero exit code counts as failure. On a
// mismatch, the combined stdout+stderr output is included in the message.
func execTransitImport(t *testing.T, client *api.Client, method string, path string, key []byte, data []string, expectFailure bool) {
	t.Helper()

	// The key travels on the command line, so it is base64-encoded first;
	// any extra key=value arguments follow it.
	args := append([]string{
		"transit",
		method,
		path,
		base64.StdEncoding.EncodeToString(key),
	}, data...)

	var stdout, stderr bytes.Buffer
	code := RunCustom(args, &RunOptions{
		Stdout: &stdout,
		Stderr: &stderr,
		Client: client,
	})
	combined := stdout.String() + stderr.String()

	failed := code != 0
	switch {
	case failed && !expectFailure:
		t.Fatalf("Got unexpected failure from test (ret %d): %v", code, combined)
	case !failed && expectFailure:
		t.Fatalf("Expected failure, got success from test (ret %d): %v", code, combined)
	}
}
|
||
|
|
||
|
// generateKeys returns fresh key material for the transit import tests:
//   - rsa1, rsa2: two independent RSA-2048 private keys, each marshaled with
//     x509.MarshalPKCS8PrivateKey;
//   - aes128, aes256: 16 and 32 random bytes read from crypto/rand.
//
// Any generation or encoding failure aborts the calling test via require.
func generateKeys(t *testing.T) (rsa1 []byte, rsa2 []byte, aes128 []byte, aes256 []byte) {
	t.Helper()

	// newRSA generates an RSA-2048 key and marshals it to PKCS#8.
	// The error is checked before the value so that a failure surfaces the
	// underlying cause instead of only "expected value not to be nil".
	newRSA := func(label string) []byte {
		t.Helper()

		priv, err := rsa.GenerateKey(rand.Reader, 2048)
		require.NoError(t, err, "failed generating "+label+" key")
		require.NotNil(t, priv, "failed generating "+label+" key")

		der, err := x509.MarshalPKCS8PrivateKey(priv)
		require.NoError(t, err, "failed marshaling "+label+" key")
		require.NotNil(t, der, "failed marshaling "+label+" key")
		return der
	}

	// newAES returns bits/8 bytes read from crypto/rand.
	newAES := func(bits int, label string) []byte {
		t.Helper()

		key := make([]byte, bits/8)
		_, err := rand.Read(key)
		require.NoError(t, err, "failed generating "+label+" key")
		return key
	}

	rsa1 = newRSA("RSA 1")
	rsa2 = newRSA("RSA 2")
	aes128 = newAES(128, "AES 128")
	aes256 = newAES(256, "AES 256")
	return rsa1, rsa2, aes128, aes256
}
|