185 lines
4.4 KiB
Go
185 lines
4.4 KiB
Go
// Copyright (c) HashiCorp, Inc.
|
|
// SPDX-License-Identifier: MPL-2.0
|
|
|
|
package server
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net/url"
|
|
"path"
|
|
"reflect"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/hashicorp/go-uuid"
|
|
"github.com/hashicorp/vault/api"
|
|
"github.com/hashicorp/vault/internalshared/configutil"
|
|
"github.com/hashicorp/vault/sdk/helper/docker"
|
|
)
|
|
|
|
func TestTransitWrapper_Lifecycle(t *testing.T) {
|
|
cleanup, config := prepareTestContainer(t)
|
|
defer cleanup()
|
|
|
|
wrapperConfig := map[string]string{
|
|
"address": config.URL().String(),
|
|
"token": config.token,
|
|
"mount_path": config.mountPath,
|
|
"key_name": config.keyName,
|
|
}
|
|
|
|
kms, _, err := configutil.GetTransitKMSFunc(&configutil.KMS{Config: wrapperConfig})
|
|
if err != nil {
|
|
t.Fatalf("error setting wrapper config: %v", err)
|
|
}
|
|
|
|
// Test Encrypt and Decrypt calls
|
|
input := []byte("foo")
|
|
swi, err := kms.Encrypt(context.Background(), input, nil)
|
|
if err != nil {
|
|
t.Fatalf("err: %s", err.Error())
|
|
}
|
|
|
|
pt, err := kms.Decrypt(context.Background(), swi, nil)
|
|
if err != nil {
|
|
t.Fatalf("err: %s", err.Error())
|
|
}
|
|
|
|
if !reflect.DeepEqual(input, pt) {
|
|
t.Fatalf("expected %s, got %s", input, pt)
|
|
}
|
|
}
|
|
|
|
func TestTransitSeal_TokenRenewal(t *testing.T) {
|
|
cleanup, config := prepareTestContainer(t)
|
|
defer cleanup()
|
|
|
|
remoteClient, err := api.NewClient(config.apiConfig())
|
|
if err != nil {
|
|
t.Fatalf("err: %s", err)
|
|
}
|
|
remoteClient.SetToken(config.token)
|
|
|
|
req := &api.TokenCreateRequest{
|
|
Period: "5s",
|
|
}
|
|
rsp, err := remoteClient.Auth().Token().Create(req)
|
|
if err != nil {
|
|
t.Fatalf("err: %s", err)
|
|
}
|
|
|
|
wrapperConfig := map[string]string{
|
|
"address": config.URL().String(),
|
|
"token": rsp.Auth.ClientToken,
|
|
"mount_path": config.mountPath,
|
|
"key_name": config.keyName,
|
|
}
|
|
kms, _, err := configutil.GetTransitKMSFunc(&configutil.KMS{Config: wrapperConfig})
|
|
if err != nil {
|
|
t.Fatalf("error setting wrapper config: %v", err)
|
|
}
|
|
|
|
time.Sleep(7 * time.Second)
|
|
|
|
// Test Encrypt and Decrypt calls
|
|
input := []byte("foo")
|
|
swi, err := kms.Encrypt(context.Background(), input, nil)
|
|
if err != nil {
|
|
t.Fatalf("err: %s", err.Error())
|
|
}
|
|
|
|
pt, err := kms.Decrypt(context.Background(), swi, nil)
|
|
if err != nil {
|
|
t.Fatalf("err: %s", err.Error())
|
|
}
|
|
|
|
if !reflect.DeepEqual(input, pt) {
|
|
t.Fatalf("expected %s, got %s", input, pt)
|
|
}
|
|
}
|
|
|
|
type DockerVaultConfig struct {
|
|
docker.ServiceURL
|
|
token string
|
|
mountPath string
|
|
keyName string
|
|
tlsConfig *api.TLSConfig
|
|
}
|
|
|
|
func (c *DockerVaultConfig) apiConfig() *api.Config {
|
|
vaultConfig := api.DefaultConfig()
|
|
vaultConfig.Address = c.URL().String()
|
|
if err := vaultConfig.ConfigureTLS(c.tlsConfig); err != nil {
|
|
panic("unable to configure TLS")
|
|
}
|
|
|
|
return vaultConfig
|
|
}
|
|
|
|
// Compile-time check that *DockerVaultConfig satisfies docker.ServiceConfig.
var _ docker.ServiceConfig = (*DockerVaultConfig)(nil)
|
|
|
|
func prepareTestContainer(t *testing.T) (func(), *DockerVaultConfig) {
|
|
rootToken, err := uuid.GenerateUUID()
|
|
if err != nil {
|
|
t.Fatalf("err: %s", err)
|
|
}
|
|
testMountPath, err := uuid.GenerateUUID()
|
|
if err != nil {
|
|
t.Fatalf("err: %s", err)
|
|
}
|
|
testKeyName, err := uuid.GenerateUUID()
|
|
if err != nil {
|
|
t.Fatalf("err: %s", err)
|
|
}
|
|
|
|
runner, err := docker.NewServiceRunner(docker.RunOptions{
|
|
ContainerName: "vault",
|
|
ImageRepo: "docker.mirror.hashicorp.services/hashicorp/vault",
|
|
ImageTag: "latest",
|
|
Cmd: []string{
|
|
"server", "-log-level=trace", "-dev", fmt.Sprintf("-dev-root-token-id=%s", rootToken),
|
|
"-dev-listen-address=0.0.0.0:8200",
|
|
},
|
|
Ports: []string{"8200/tcp"},
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("could not start docker vault: %s", err)
|
|
}
|
|
|
|
svc, err := runner.StartService(context.Background(), func(ctx context.Context, host string, port int) (docker.ServiceConfig, error) {
|
|
c := &DockerVaultConfig{
|
|
ServiceURL: *docker.NewServiceURL(url.URL{Scheme: "http", Host: fmt.Sprintf("%s:%d", host, port)}),
|
|
tlsConfig: &api.TLSConfig{
|
|
Insecure: true,
|
|
},
|
|
token: rootToken,
|
|
mountPath: testMountPath,
|
|
keyName: testKeyName,
|
|
}
|
|
vault, err := api.NewClient(c.apiConfig())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
vault.SetToken(rootToken)
|
|
|
|
// Set up transit
|
|
if err := vault.Sys().Mount(testMountPath, &api.MountInput{
|
|
Type: "transit",
|
|
}); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Create default aesgcm key
|
|
if _, err := vault.Logical().Write(path.Join(testMountPath, "keys", testKeyName), map[string]interface{}{}); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return c, nil
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("could not start docker vault: %s", err)
|
|
}
|
|
return svc.Cleanup, svc.Config.(*DockerVaultConfig)
|
|
}
|