1558 lines
41 KiB
Go
1558 lines
41 KiB
Go
package vault
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/ecdsa"
|
|
"crypto/elliptic"
|
|
"crypto/rand"
|
|
"crypto/sha256"
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"crypto/x509/pkix"
|
|
"encoding/base64"
|
|
"encoding/pem"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"io/ioutil"
|
|
"math/big"
|
|
mathrand "math/rand"
|
|
"net"
|
|
"net/http"
|
|
"os"
|
|
"os/exec"
|
|
"path/filepath"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
log "github.com/hashicorp/go-hclog"
|
|
"github.com/mitchellh/copystructure"
|
|
|
|
"golang.org/x/crypto/ssh"
|
|
"golang.org/x/net/http2"
|
|
|
|
"github.com/hashicorp/go-cleanhttp"
|
|
"github.com/hashicorp/vault/api"
|
|
"github.com/hashicorp/vault/audit"
|
|
"github.com/hashicorp/vault/helper/consts"
|
|
"github.com/hashicorp/vault/helper/logging"
|
|
"github.com/hashicorp/vault/helper/reload"
|
|
"github.com/hashicorp/vault/helper/salt"
|
|
"github.com/hashicorp/vault/logical"
|
|
"github.com/hashicorp/vault/logical/framework"
|
|
"github.com/hashicorp/vault/physical"
|
|
dbMysql "github.com/hashicorp/vault/plugins/database/mysql"
|
|
dbPostgres "github.com/hashicorp/vault/plugins/database/postgresql"
|
|
"github.com/mitchellh/go-testing-interface"
|
|
|
|
physInmem "github.com/hashicorp/vault/physical/inmem"
|
|
)
|
|
|
|
// This file contains a number of methods that are useful for unit
|
|
// tests within other packages.
|
|
|
|
const (
|
|
testSharedPublicKey = `
|
|
ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQC9i+hFxZHGo6KblVme4zrAcJstR6I0PTJozW286X4WyvPnkMYDQ5mnhEYC7UWCvjoTWbPEXPX7NjhRtwQTGD67bV+lrxgfyzK1JZbUXK4PwgKJvQD+XyyWYMzDgGSQY61KUSqCxymSm/9NZkPU3ElaQ9xQuTzPpztM4ROfb8f2Yv6/ZESZsTo0MTAkp8Pcy+WkioI/uJ1H7zqs0EA4OMY4aDJRu0UtP4rTVeYNEAuRXdX+eH4aW3KMvhzpFTjMbaJHJXlEeUm2SaX5TNQyTOvghCeQILfYIL/Ca2ij8iwCmulwdV6eQGfd4VDu40PvSnmfoaE38o6HaPnX0kUcnKiT
|
|
`
|
|
testSharedPrivateKey = `
|
|
-----BEGIN RSA PRIVATE KEY-----
|
|
MIIEogIBAAKCAQEAvYvoRcWRxqOim5VZnuM6wHCbLUeiND0yaM1tvOl+Fsrz55DG
|
|
A0OZp4RGAu1Fgr46E1mzxFz1+zY4UbcEExg+u21fpa8YH8sytSWW1FyuD8ICib0A
|
|
/l8slmDMw4BkkGOtSlEqgscpkpv/TWZD1NxJWkPcULk8z6c7TOETn2/H9mL+v2RE
|
|
mbE6NDEwJKfD3MvlpIqCP7idR+86rNBAODjGOGgyUbtFLT+K01XmDRALkV3V/nh+
|
|
GltyjL4c6RU4zG2iRyV5RHlJtkml+UzUMkzr4IQnkCC32CC/wmtoo/IsAprpcHVe
|
|
nkBn3eFQ7uND70p5n6GhN/KOh2j519JFHJyokwIDAQABAoIBAHX7VOvBC3kCN9/x
|
|
+aPdup84OE7Z7MvpX6w+WlUhXVugnmsAAVDczhKoUc/WktLLx2huCGhsmKvyVuH+
|
|
MioUiE+vx75gm3qGx5xbtmOfALVMRLopjCnJYf6EaFA0ZeQ+NwowNW7Lu0PHmAU8
|
|
Z3JiX8IwxTz14DU82buDyewO7v+cEr97AnERe3PUcSTDoUXNaoNxjNpEJkKREY6h
|
|
4hAY676RT/GsRcQ8tqe/rnCqPHNd7JGqL+207FK4tJw7daoBjQyijWuB7K5chSal
|
|
oPInylM6b13ASXuOAOT/2uSUBWmFVCZPDCmnZxy2SdnJGbsJAMl7Ma3MUlaGvVI+
|
|
Tfh1aQkCgYEA4JlNOabTb3z42wz6mz+Nz3JRwbawD+PJXOk5JsSnV7DtPtfgkK9y
|
|
6FTQdhnozGWShAvJvc+C4QAihs9AlHXoaBY5bEU7R/8UK/pSqwzam+MmxmhVDV7G
|
|
IMQPV0FteoXTaJSikhZ88mETTegI2mik+zleBpVxvfdhE5TR+lq8Br0CgYEA2AwJ
|
|
CUD5CYUSj09PluR0HHqamWOrJkKPFPwa+5eiTTCzfBBxImYZh7nXnWuoviXC0sg2
|
|
AuvCW+uZ48ygv/D8gcz3j1JfbErKZJuV+TotK9rRtNIF5Ub7qysP7UjyI7zCssVM
|
|
kuDd9LfRXaB/qGAHNkcDA8NxmHW3gpln4CFdSY8CgYANs4xwfercHEWaJ1qKagAe
|
|
rZyrMpffAEhicJ/Z65lB0jtG4CiE6w8ZeUMWUVJQVcnwYD+4YpZbX4S7sJ0B8Ydy
|
|
AhkSr86D/92dKTIt2STk6aCN7gNyQ1vW198PtaAWH1/cO2UHgHOy3ZUt5X/Uwxl9
|
|
cex4flln+1Viumts2GgsCQKBgCJH7psgSyPekK5auFdKEr5+Gc/jB8I/Z3K9+g4X
|
|
5nH3G1PBTCJYLw7hRzw8W/8oALzvddqKzEFHphiGXK94Lqjt/A4q1OdbCrhiE68D
|
|
My21P/dAKB1UYRSs9Y8CNyHCjuZM9jSMJ8vv6vG/SOJPsnVDWVAckAbQDvlTHC9t
|
|
O98zAoGAcbW6uFDkrv0XMCpB9Su3KaNXOR0wzag+WIFQRXCcoTvxVi9iYfUReQPi
|
|
oOyBJU/HMVvBfv4g+OVFLVgSwwm6owwsouZ0+D/LasbuHqYyqYqdyPJQYzWA2Y+F
|
|
+B6f4RoPdSXj24JHPg/ioRxjaj094UXJxua2yfkcecGNEuBQHSs=
|
|
-----END RSA PRIVATE KEY-----
|
|
`
|
|
)
|
|
|
|
// TestCore returns a pure in-memory, uninitialized core for testing.
|
|
func TestCore(t testing.T) *Core {
|
|
return TestCoreWithSeal(t, nil, false)
|
|
}
|
|
|
|
// TestCoreRaw returns a pure in-memory, uninitialized core for testing. The raw
|
|
// storage endpoints are enabled with this core.
|
|
func TestCoreRaw(t testing.T) *Core {
|
|
return TestCoreWithSeal(t, nil, true)
|
|
}
|
|
|
|
// TestCoreNewSeal returns a pure in-memory, uninitialized core with
|
|
// the new seal configuration.
|
|
func TestCoreNewSeal(t testing.T) *Core {
|
|
seal := NewTestSeal(t, nil)
|
|
return TestCoreWithSeal(t, seal, false)
|
|
}
|
|
|
|
// TestCoreWithConfig returns a pure in-memory, uninitialized core with the
|
|
// specified core configurations overridden for testing.
|
|
func TestCoreWithConfig(t testing.T, conf *CoreConfig) *Core {
|
|
return TestCoreWithSealAndUI(t, conf)
|
|
}
|
|
|
|
// TestCoreWithSeal returns a pure in-memory, uninitialized core with the
|
|
// specified seal for testing.
|
|
func TestCoreWithSeal(t testing.T, testSeal Seal, enableRaw bool) *Core {
|
|
conf := &CoreConfig{
|
|
Seal: testSeal,
|
|
EnableUI: false,
|
|
EnableRaw: enableRaw,
|
|
BuiltinRegistry: NewMockBuiltinRegistry(),
|
|
}
|
|
return TestCoreWithSealAndUI(t, conf)
|
|
}
|
|
|
|
func TestCoreUI(t testing.T, enableUI bool) *Core {
|
|
conf := &CoreConfig{
|
|
EnableUI: enableUI,
|
|
EnableRaw: true,
|
|
BuiltinRegistry: NewMockBuiltinRegistry(),
|
|
}
|
|
return TestCoreWithSealAndUI(t, conf)
|
|
}
|
|
|
|
func TestCoreWithSealAndUI(t testing.T, opts *CoreConfig) *Core {
|
|
logger := logging.NewVaultLogger(log.Trace)
|
|
physicalBackend, err := physInmem.NewInmem(nil, logger)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
// Start off with base test core config
|
|
conf := testCoreConfig(t, physicalBackend, logger)
|
|
|
|
// Override config values with ones that gets passed in
|
|
conf.EnableUI = opts.EnableUI
|
|
conf.EnableRaw = opts.EnableRaw
|
|
conf.Seal = opts.Seal
|
|
conf.LicensingConfig = opts.LicensingConfig
|
|
conf.DisableKeyEncodingChecks = opts.DisableKeyEncodingChecks
|
|
|
|
for k, v := range opts.LogicalBackends {
|
|
conf.LogicalBackends[k] = v
|
|
}
|
|
for k, v := range opts.CredentialBackends {
|
|
conf.CredentialBackends[k] = v
|
|
}
|
|
|
|
c, err := NewCore(conf)
|
|
if err != nil {
|
|
t.Fatalf("err: %s", err)
|
|
}
|
|
|
|
return c
|
|
}
|
|
|
|
func testCoreConfig(t testing.T, physicalBackend physical.Backend, logger log.Logger) *CoreConfig {
|
|
t.Helper()
|
|
noopAudits := map[string]audit.Factory{
|
|
"noop": func(_ context.Context, config *audit.BackendConfig) (audit.Backend, error) {
|
|
view := &logical.InmemStorage{}
|
|
view.Put(context.Background(), &logical.StorageEntry{
|
|
Key: "salt",
|
|
Value: []byte("foo"),
|
|
})
|
|
config.SaltConfig = &salt.Config{
|
|
HMAC: sha256.New,
|
|
HMACType: "hmac-sha256",
|
|
}
|
|
config.SaltView = view
|
|
return &noopAudit{
|
|
Config: config,
|
|
}, nil
|
|
},
|
|
}
|
|
|
|
noopBackends := make(map[string]logical.Factory)
|
|
noopBackends["noop"] = func(ctx context.Context, config *logical.BackendConfig) (logical.Backend, error) {
|
|
b := new(framework.Backend)
|
|
b.Setup(ctx, config)
|
|
b.BackendType = logical.TypeCredential
|
|
return b, nil
|
|
}
|
|
noopBackends["http"] = func(ctx context.Context, config *logical.BackendConfig) (logical.Backend, error) {
|
|
return new(rawHTTP), nil
|
|
}
|
|
|
|
credentialBackends := make(map[string]logical.Factory)
|
|
for backendName, backendFactory := range noopBackends {
|
|
credentialBackends[backendName] = backendFactory
|
|
}
|
|
for backendName, backendFactory := range testCredentialBackends {
|
|
credentialBackends[backendName] = backendFactory
|
|
}
|
|
|
|
logicalBackends := make(map[string]logical.Factory)
|
|
for backendName, backendFactory := range noopBackends {
|
|
logicalBackends[backendName] = backendFactory
|
|
}
|
|
|
|
logicalBackends["kv"] = LeasedPassthroughBackendFactory
|
|
for backendName, backendFactory := range testLogicalBackends {
|
|
logicalBackends[backendName] = backendFactory
|
|
}
|
|
|
|
conf := &CoreConfig{
|
|
Physical: physicalBackend,
|
|
AuditBackends: noopAudits,
|
|
LogicalBackends: logicalBackends,
|
|
CredentialBackends: credentialBackends,
|
|
DisableMlock: true,
|
|
Logger: logger,
|
|
BuiltinRegistry: NewMockBuiltinRegistry(),
|
|
}
|
|
|
|
return conf
|
|
}
|
|
|
|
// TestCoreInit initializes the core with a single key, and returns
|
|
// the key that must be used to unseal the core and a root token.
|
|
func TestCoreInit(t testing.T, core *Core) ([][]byte, string) {
|
|
t.Helper()
|
|
secretShares, _, root := TestCoreInitClusterWrapperSetup(t, core, nil, nil)
|
|
return secretShares, root
|
|
}
|
|
|
|
func TestCoreInitClusterWrapperSetup(t testing.T, core *Core, clusterAddrs []*net.TCPAddr, handler http.Handler) ([][]byte, [][]byte, string) {
|
|
t.Helper()
|
|
core.SetClusterListenerAddrs(clusterAddrs)
|
|
core.SetClusterHandler(handler)
|
|
|
|
barrierConfig := &SealConfig{
|
|
SecretShares: 3,
|
|
SecretThreshold: 3,
|
|
}
|
|
|
|
// If we support storing barrier keys, then set that to equal the min threshold to unseal
|
|
if core.seal.StoredKeysSupported() {
|
|
barrierConfig.StoredShares = barrierConfig.SecretThreshold
|
|
}
|
|
|
|
recoveryConfig := &SealConfig{
|
|
SecretShares: 3,
|
|
SecretThreshold: 3,
|
|
}
|
|
|
|
result, err := core.Initialize(context.Background(), &InitParams{
|
|
BarrierConfig: barrierConfig,
|
|
RecoveryConfig: recoveryConfig,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("err: %s", err)
|
|
}
|
|
return result.SecretShares, result.RecoveryShares, result.RootToken
|
|
}
|
|
|
|
func TestCoreUnseal(core *Core, key []byte) (bool, error) {
|
|
return core.Unseal(key)
|
|
}
|
|
|
|
func TestCoreUnsealWithRecoveryKeys(core *Core, key []byte) (bool, error) {
|
|
return core.UnsealWithRecoveryKeys(key)
|
|
}
|
|
|
|
// TestCoreUnsealed returns a pure in-memory core that is already
|
|
// initialized and unsealed.
|
|
func TestCoreUnsealed(t testing.T) (*Core, [][]byte, string) {
|
|
t.Helper()
|
|
core := TestCore(t)
|
|
return testCoreUnsealed(t, core)
|
|
}
|
|
|
|
// TestCoreUnsealedRaw returns a pure in-memory core that is already
|
|
// initialized, unsealed, and with raw endpoints enabled.
|
|
func TestCoreUnsealedRaw(t testing.T) (*Core, [][]byte, string) {
|
|
t.Helper()
|
|
core := TestCoreRaw(t)
|
|
return testCoreUnsealed(t, core)
|
|
}
|
|
|
|
// TestCoreUnsealedWithConfig returns a pure in-memory core that is already
|
|
// initialized, unsealed, with the any provided core config values overridden.
|
|
func TestCoreUnsealedWithConfig(t testing.T, conf *CoreConfig) (*Core, [][]byte, string) {
|
|
t.Helper()
|
|
core := TestCoreWithConfig(t, conf)
|
|
return testCoreUnsealed(t, core)
|
|
}
|
|
|
|
func testCoreUnsealed(t testing.T, core *Core) (*Core, [][]byte, string) {
|
|
t.Helper()
|
|
keys, token := TestCoreInit(t, core)
|
|
for _, key := range keys {
|
|
if _, err := TestCoreUnseal(core, TestKeyCopy(key)); err != nil {
|
|
t.Fatalf("unseal err: %s", err)
|
|
}
|
|
}
|
|
|
|
if core.Sealed() {
|
|
t.Fatal("should not be sealed")
|
|
}
|
|
|
|
return core, keys, token
|
|
}
|
|
|
|
func TestCoreUnsealedBackend(t testing.T, backend physical.Backend) (*Core, [][]byte, string) {
|
|
t.Helper()
|
|
logger := logging.NewVaultLogger(log.Trace)
|
|
conf := testCoreConfig(t, backend, logger)
|
|
conf.Seal = NewTestSeal(t, nil)
|
|
|
|
core, err := NewCore(conf)
|
|
if err != nil {
|
|
t.Fatalf("err: %s", err)
|
|
}
|
|
|
|
keys, token := TestCoreInit(t, core)
|
|
for _, key := range keys {
|
|
if _, err := TestCoreUnseal(core, TestKeyCopy(key)); err != nil {
|
|
t.Fatalf("unseal err: %s", err)
|
|
}
|
|
}
|
|
|
|
if err := core.UnsealWithStoredKeys(context.Background()); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
if core.Sealed() {
|
|
t.Fatal("should not be sealed")
|
|
}
|
|
|
|
return core, keys, token
|
|
}
|
|
|
|
// TestKeyCopy is a silly little function to just copy the key so that
|
|
// it can be used with Unseal easily.
|
|
func TestKeyCopy(key []byte) []byte {
|
|
result := make([]byte, len(key))
|
|
copy(result, key)
|
|
return result
|
|
}
|
|
|
|
func TestDynamicSystemView(c *Core) *dynamicSystemView {
|
|
me := &MountEntry{
|
|
Config: MountConfig{
|
|
DefaultLeaseTTL: 24 * time.Hour,
|
|
MaxLeaseTTL: 2 * 24 * time.Hour,
|
|
},
|
|
}
|
|
|
|
return &dynamicSystemView{c, me}
|
|
}
|
|
|
|
// TestAddTestPlugin registers the testFunc as part of the plugin command to the
|
|
// plugin catalog. If provided, uses tmpDir as the plugin directory.
|
|
func TestAddTestPlugin(t testing.T, c *Core, name string, pluginType consts.PluginType, testFunc string, env []string, tempDir string) {
|
|
file, err := os.Open(os.Args[0])
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer file.Close()
|
|
|
|
dirPath := filepath.Dir(os.Args[0])
|
|
fileName := filepath.Base(os.Args[0])
|
|
|
|
if tempDir != "" {
|
|
fi, err := file.Stat()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
// Copy over the file to the temp dir
|
|
dst := filepath.Join(tempDir, fileName)
|
|
out, err := os.OpenFile(dst, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, fi.Mode())
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer out.Close()
|
|
|
|
if _, err = io.Copy(out, file); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
err = out.Sync()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
dirPath = tempDir
|
|
}
|
|
|
|
// Determine plugin directory full path, evaluating potential symlink path
|
|
fullPath, err := filepath.EvalSymlinks(dirPath)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
reader, err := os.Open(filepath.Join(fullPath, fileName))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer reader.Close()
|
|
|
|
// Find out the sha256
|
|
hash := sha256.New()
|
|
|
|
_, err = io.Copy(hash, reader)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
sum := hash.Sum(nil)
|
|
|
|
// Set core's plugin directory and plugin catalog directory
|
|
c.pluginDirectory = fullPath
|
|
c.pluginCatalog.directory = fullPath
|
|
|
|
args := []string{fmt.Sprintf("--test.run=%s", testFunc)}
|
|
err = c.pluginCatalog.Set(context.Background(), name, pluginType, fileName, args, env, sum)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}
|
|
|
|
var testLogicalBackends = map[string]logical.Factory{}
|
|
var testCredentialBackends = map[string]logical.Factory{}
|
|
|
|
// StartSSHHostTestServer starts the test server which responds to SSH
|
|
// authentication. Used to test the SSH secret backend.
|
|
func StartSSHHostTestServer() (string, error) {
|
|
pubKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(testSharedPublicKey))
|
|
if err != nil {
|
|
return "", fmt.Errorf("error parsing public key")
|
|
}
|
|
serverConfig := &ssh.ServerConfig{
|
|
PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
|
|
if bytes.Compare(pubKey.Marshal(), key.Marshal()) == 0 {
|
|
return &ssh.Permissions{}, nil
|
|
} else {
|
|
return nil, fmt.Errorf("key does not match")
|
|
}
|
|
},
|
|
}
|
|
signer, err := ssh.ParsePrivateKey([]byte(testSharedPrivateKey))
|
|
if err != nil {
|
|
panic("Error parsing private key")
|
|
}
|
|
serverConfig.AddHostKey(signer)
|
|
|
|
soc, err := net.Listen("tcp", "127.0.0.1:0")
|
|
if err != nil {
|
|
return "", fmt.Errorf("error listening to connection")
|
|
}
|
|
|
|
go func() {
|
|
for {
|
|
conn, err := soc.Accept()
|
|
if err != nil {
|
|
panic(fmt.Sprintf("Error accepting incoming connection: %s", err))
|
|
}
|
|
defer conn.Close()
|
|
sshConn, chanReqs, _, err := ssh.NewServerConn(conn, serverConfig)
|
|
if err != nil {
|
|
panic(fmt.Sprintf("Handshaking error: %v", err))
|
|
}
|
|
|
|
go func() {
|
|
for chanReq := range chanReqs {
|
|
go func(chanReq ssh.NewChannel) {
|
|
if chanReq.ChannelType() != "session" {
|
|
chanReq.Reject(ssh.UnknownChannelType, "unknown channel type")
|
|
return
|
|
}
|
|
|
|
ch, requests, err := chanReq.Accept()
|
|
if err != nil {
|
|
panic(fmt.Sprintf("Error accepting channel: %s", err))
|
|
}
|
|
|
|
go func(ch ssh.Channel, in <-chan *ssh.Request) {
|
|
for req := range in {
|
|
executeServerCommand(ch, req)
|
|
}
|
|
}(ch, requests)
|
|
}(chanReq)
|
|
}
|
|
sshConn.Close()
|
|
}()
|
|
}
|
|
}()
|
|
return soc.Addr().String(), nil
|
|
}
|
|
|
|
// This executes the commands requested to be run on the server.
|
|
// Used to test the SSH secret backend.
|
|
func executeServerCommand(ch ssh.Channel, req *ssh.Request) {
|
|
command := string(req.Payload[4:])
|
|
cmd := exec.Command("/bin/bash", []string{"-c", command}...)
|
|
req.Reply(true, nil)
|
|
|
|
cmd.Stdout = ch
|
|
cmd.Stderr = ch
|
|
cmd.Stdin = ch
|
|
|
|
err := cmd.Start()
|
|
if err != nil {
|
|
panic(fmt.Sprintf("Error starting the command: '%s'", err))
|
|
}
|
|
|
|
go func() {
|
|
_, err := cmd.Process.Wait()
|
|
if err != nil {
|
|
panic(fmt.Sprintf("Error while waiting for command to finish:'%s'", err))
|
|
}
|
|
ch.Close()
|
|
}()
|
|
}
|
|
|
|
// This adds a credential backend for the test core. This needs to be
|
|
// invoked before the test core is created.
|
|
func AddTestCredentialBackend(name string, factory logical.Factory) error {
|
|
if name == "" {
|
|
return fmt.Errorf("missing backend name")
|
|
}
|
|
if factory == nil {
|
|
return fmt.Errorf("missing backend factory function")
|
|
}
|
|
testCredentialBackends[name] = factory
|
|
return nil
|
|
}
|
|
|
|
// This adds a logical backend for the test core. This needs to be
|
|
// invoked before the test core is created.
|
|
func AddTestLogicalBackend(name string, factory logical.Factory) error {
|
|
if name == "" {
|
|
return fmt.Errorf("missing backend name")
|
|
}
|
|
if factory == nil {
|
|
return fmt.Errorf("missing backend factory function")
|
|
}
|
|
testLogicalBackends[name] = factory
|
|
return nil
|
|
}
|
|
|
|
type noopAudit struct {
|
|
Config *audit.BackendConfig
|
|
salt *salt.Salt
|
|
saltMutex sync.RWMutex
|
|
}
|
|
|
|
func (n *noopAudit) GetHash(ctx context.Context, data string) (string, error) {
|
|
salt, err := n.Salt(ctx)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
return salt.GetIdentifiedHMAC(data), nil
|
|
}
|
|
|
|
func (n *noopAudit) LogRequest(_ context.Context, _ *audit.LogInput) error {
|
|
return nil
|
|
}
|
|
|
|
func (n *noopAudit) LogResponse(_ context.Context, _ *audit.LogInput) error {
|
|
return nil
|
|
}
|
|
|
|
func (n *noopAudit) Reload(_ context.Context) error {
|
|
return nil
|
|
}
|
|
|
|
func (n *noopAudit) Invalidate(_ context.Context) {
|
|
n.saltMutex.Lock()
|
|
defer n.saltMutex.Unlock()
|
|
n.salt = nil
|
|
}
|
|
|
|
func (n *noopAudit) Salt(ctx context.Context) (*salt.Salt, error) {
|
|
n.saltMutex.RLock()
|
|
if n.salt != nil {
|
|
defer n.saltMutex.RUnlock()
|
|
return n.salt, nil
|
|
}
|
|
n.saltMutex.RUnlock()
|
|
n.saltMutex.Lock()
|
|
defer n.saltMutex.Unlock()
|
|
if n.salt != nil {
|
|
return n.salt, nil
|
|
}
|
|
salt, err := salt.NewSalt(ctx, n.Config.SaltView, n.Config.SaltConfig)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
n.salt = salt
|
|
return salt, nil
|
|
}
|
|
|
|
type rawHTTP struct{}
|
|
|
|
func (n *rawHTTP) HandleRequest(ctx context.Context, req *logical.Request) (*logical.Response, error) {
|
|
return &logical.Response{
|
|
Data: map[string]interface{}{
|
|
logical.HTTPStatusCode: 200,
|
|
logical.HTTPContentType: "plain/text",
|
|
logical.HTTPRawBody: []byte("hello world"),
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
func (n *rawHTTP) HandleExistenceCheck(ctx context.Context, req *logical.Request) (bool, bool, error) {
|
|
return false, false, nil
|
|
}
|
|
|
|
func (n *rawHTTP) SpecialPaths() *logical.Paths {
|
|
return &logical.Paths{Unauthenticated: []string{"*"}}
|
|
}
|
|
|
|
func (n *rawHTTP) System() logical.SystemView {
|
|
return logical.StaticSystemView{
|
|
DefaultLeaseTTLVal: time.Hour * 24,
|
|
MaxLeaseTTLVal: time.Hour * 24 * 32,
|
|
}
|
|
}
|
|
|
|
func (n *rawHTTP) Logger() log.Logger {
|
|
return logging.NewVaultLogger(log.Trace)
|
|
}
|
|
|
|
func (n *rawHTTP) Cleanup(ctx context.Context) {
|
|
// noop
|
|
}
|
|
|
|
func (n *rawHTTP) Initialize(ctx context.Context) error {
|
|
// noop
|
|
return nil
|
|
}
|
|
|
|
func (n *rawHTTP) InvalidateKey(context.Context, string) {
|
|
// noop
|
|
}
|
|
|
|
func (n *rawHTTP) Setup(ctx context.Context, config *logical.BackendConfig) error {
|
|
// noop
|
|
return nil
|
|
}
|
|
|
|
func (n *rawHTTP) Type() logical.BackendType {
|
|
return logical.TypeLogical
|
|
}
|
|
|
|
func GenerateRandBytes(length int) ([]byte, error) {
|
|
if length < 0 {
|
|
return nil, fmt.Errorf("length must be >= 0")
|
|
}
|
|
|
|
buf := make([]byte, length)
|
|
if length == 0 {
|
|
return buf, nil
|
|
}
|
|
|
|
n, err := rand.Read(buf)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if n != length {
|
|
return nil, fmt.Errorf("unable to read %d bytes; only read %d", length, n)
|
|
}
|
|
|
|
return buf, nil
|
|
}
|
|
|
|
func TestWaitActive(t testing.T, core *Core) {
|
|
t.Helper()
|
|
if err := TestWaitActiveWithError(core); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}
|
|
|
|
func TestWaitActiveWithError(core *Core) error {
|
|
start := time.Now()
|
|
var standby bool
|
|
var err error
|
|
for time.Now().Sub(start) < time.Second {
|
|
standby, err = core.Standby()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if !standby {
|
|
break
|
|
}
|
|
}
|
|
if standby {
|
|
return errors.New("should not be in standby mode")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
type TestCluster struct {
|
|
BarrierKeys [][]byte
|
|
RecoveryKeys [][]byte
|
|
CACert *x509.Certificate
|
|
CACertBytes []byte
|
|
CACertPEM []byte
|
|
CACertPEMFile string
|
|
CAKey *ecdsa.PrivateKey
|
|
CAKeyPEM []byte
|
|
Cores []*TestClusterCore
|
|
ID string
|
|
RootToken string
|
|
RootCAs *x509.CertPool
|
|
TempDir string
|
|
}
|
|
|
|
func (c *TestCluster) Start() {
|
|
for _, core := range c.Cores {
|
|
if core.Server != nil {
|
|
for _, ln := range core.Listeners {
|
|
go core.Server.Serve(ln)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// UnsealCores uses the cluster barrier keys to unseal the test cluster cores
|
|
func (c *TestCluster) UnsealCores(t testing.T) {
|
|
if err := c.UnsealCoresWithError(); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}
|
|
|
|
func (c *TestCluster) UnsealCoresWithError() error {
|
|
numCores := len(c.Cores)
|
|
|
|
// Unseal first core
|
|
for _, key := range c.BarrierKeys {
|
|
if _, err := c.Cores[0].Unseal(TestKeyCopy(key)); err != nil {
|
|
return fmt.Errorf("unseal err: %s", err)
|
|
}
|
|
}
|
|
|
|
// Verify unsealed
|
|
if c.Cores[0].Sealed() {
|
|
return fmt.Errorf("should not be sealed")
|
|
}
|
|
|
|
if err := TestWaitActiveWithError(c.Cores[0].Core); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Unseal other cores
|
|
for i := 1; i < numCores; i++ {
|
|
for _, key := range c.BarrierKeys {
|
|
if _, err := c.Cores[i].Core.Unseal(TestKeyCopy(key)); err != nil {
|
|
return fmt.Errorf("unseal err: %s", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Let them come fully up to standby
|
|
time.Sleep(2 * time.Second)
|
|
|
|
// Ensure cluster connection info is populated.
|
|
// Other cores should not come up as leaders.
|
|
for i := 1; i < numCores; i++ {
|
|
isLeader, _, _, err := c.Cores[i].Leader()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if isLeader {
|
|
return fmt.Errorf("core[%d] should not be leader", i)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (c *TestCluster) EnsureCoresSealed(t testing.T) {
|
|
t.Helper()
|
|
if err := c.ensureCoresSealed(); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}
|
|
|
|
func CleanupClusters(clusters []*TestCluster) {
|
|
wg := &sync.WaitGroup{}
|
|
for _, cluster := range clusters {
|
|
wg.Add(1)
|
|
lc := cluster
|
|
go func() {
|
|
defer wg.Done()
|
|
lc.Cleanup()
|
|
}()
|
|
}
|
|
wg.Wait()
|
|
}
|
|
|
|
func (c *TestCluster) Cleanup() {
|
|
// Close listeners
|
|
wg := &sync.WaitGroup{}
|
|
for _, core := range c.Cores {
|
|
wg.Add(1)
|
|
lc := core
|
|
|
|
go func() {
|
|
defer wg.Done()
|
|
if lc.Listeners != nil {
|
|
for _, ln := range lc.Listeners {
|
|
ln.Close()
|
|
}
|
|
}
|
|
if lc.licensingStopCh != nil {
|
|
close(lc.licensingStopCh)
|
|
lc.licensingStopCh = nil
|
|
}
|
|
|
|
if err := lc.Shutdown(); err != nil {
|
|
lc.Logger().Error("error during shutdown; abandoning sealing", "error", err)
|
|
} else {
|
|
timeout := time.Now().Add(60 * time.Second)
|
|
for {
|
|
if time.Now().After(timeout) {
|
|
lc.Logger().Error("timeout waiting for core to seal")
|
|
}
|
|
if lc.Sealed() {
|
|
break
|
|
}
|
|
time.Sleep(250 * time.Millisecond)
|
|
}
|
|
}
|
|
}()
|
|
}
|
|
|
|
wg.Wait()
|
|
|
|
// Remove any temp dir that exists
|
|
if c.TempDir != "" {
|
|
os.RemoveAll(c.TempDir)
|
|
}
|
|
|
|
// Give time to actually shut down/clean up before the next test
|
|
time.Sleep(time.Second)
|
|
}
|
|
|
|
func (c *TestCluster) ensureCoresSealed() error {
|
|
for _, core := range c.Cores {
|
|
if err := core.Shutdown(); err != nil {
|
|
return err
|
|
}
|
|
timeout := time.Now().Add(60 * time.Second)
|
|
for {
|
|
if time.Now().After(timeout) {
|
|
return fmt.Errorf("timeout waiting for core to seal")
|
|
}
|
|
if core.Sealed() {
|
|
break
|
|
}
|
|
time.Sleep(250 * time.Millisecond)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// UnsealWithStoredKeys uses stored keys to unseal the test cluster cores
|
|
func (c *TestCluster) UnsealWithStoredKeys(t testing.T) error {
|
|
for _, core := range c.Cores {
|
|
if err := core.UnsealWithStoredKeys(context.Background()); err != nil {
|
|
return err
|
|
}
|
|
timeout := time.Now().Add(60 * time.Second)
|
|
for {
|
|
if time.Now().After(timeout) {
|
|
return fmt.Errorf("timeout waiting for core to unseal")
|
|
}
|
|
if !core.Sealed() {
|
|
break
|
|
}
|
|
time.Sleep(250 * time.Millisecond)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func SetReplicationFailureMode(core *TestClusterCore, mode uint32) {
|
|
atomic.StoreUint32(core.Core.replicationFailure, mode)
|
|
}
|
|
|
|
type TestListener struct {
|
|
net.Listener
|
|
Address *net.TCPAddr
|
|
}
|
|
|
|
type TestClusterCore struct {
|
|
*Core
|
|
CoreConfig *CoreConfig
|
|
Client *api.Client
|
|
Handler http.Handler
|
|
Listeners []*TestListener
|
|
ReloadFuncs *map[string][]reload.ReloadFunc
|
|
ReloadFuncsLock *sync.RWMutex
|
|
Server *http.Server
|
|
ServerCert *x509.Certificate
|
|
ServerCertBytes []byte
|
|
ServerCertPEM []byte
|
|
ServerKey *ecdsa.PrivateKey
|
|
ServerKeyPEM []byte
|
|
TLSConfig *tls.Config
|
|
UnderlyingStorage physical.Backend
|
|
}
|
|
|
|
type TestClusterOptions struct {
|
|
KeepStandbysSealed bool
|
|
SkipInit bool
|
|
HandlerFunc func(*HandlerProperties) http.Handler
|
|
BaseListenAddress string
|
|
NumCores int
|
|
SealFunc func() Seal
|
|
Logger log.Logger
|
|
TempDir string
|
|
CACert []byte
|
|
CAKey *ecdsa.PrivateKey
|
|
}
|
|
|
|
var DefaultNumCores = 3
|
|
|
|
type certInfo struct {
|
|
cert *x509.Certificate
|
|
certPEM []byte
|
|
certBytes []byte
|
|
key *ecdsa.PrivateKey
|
|
keyPEM []byte
|
|
}
|
|
|
|
// NewTestCluster creates a new test cluster based on the provided core config
|
|
// and test cluster options.
|
|
//
|
|
// N.B. Even though a single base CoreConfig is provided, NewTestCluster will instantiate a
|
|
// core config for each core it creates. If separate seal per core is desired, opts.SealFunc
|
|
// can be provided to generate a seal for each one. Otherwise, the provided base.Seal will be
|
|
// shared among cores. NewCore's default behavior is to generate a new DefaultSeal if the
|
|
// provided Seal in coreConfig (i.e. base.Seal) is nil.
|
|
func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *TestCluster {
|
|
var err error
|
|
|
|
var numCores int
|
|
if opts == nil || opts.NumCores == 0 {
|
|
numCores = DefaultNumCores
|
|
} else {
|
|
numCores = opts.NumCores
|
|
}
|
|
|
|
certIPs := []net.IP{
|
|
net.IPv6loopback,
|
|
net.ParseIP("127.0.0.1"),
|
|
}
|
|
var baseAddr *net.TCPAddr
|
|
if opts != nil && opts.BaseListenAddress != "" {
|
|
baseAddr, err = net.ResolveTCPAddr("tcp", opts.BaseListenAddress)
|
|
if err != nil {
|
|
t.Fatal("could not parse given base IP")
|
|
}
|
|
certIPs = append(certIPs, baseAddr.IP)
|
|
}
|
|
|
|
var testCluster TestCluster
|
|
if opts != nil && opts.TempDir != "" {
|
|
if _, err := os.Stat(opts.TempDir); os.IsNotExist(err) {
|
|
if err := os.MkdirAll(opts.TempDir, 0700); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}
|
|
testCluster.TempDir = opts.TempDir
|
|
} else {
|
|
tempDir, err := ioutil.TempDir("", "vault-test-cluster-")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
testCluster.TempDir = tempDir
|
|
}
|
|
|
|
var caKey *ecdsa.PrivateKey
|
|
if opts != nil && opts.CAKey != nil {
|
|
caKey = opts.CAKey
|
|
} else {
|
|
caKey, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}
|
|
testCluster.CAKey = caKey
|
|
var caBytes []byte
|
|
if opts != nil && len(opts.CACert) > 0 {
|
|
caBytes = opts.CACert
|
|
} else {
|
|
caCertTemplate := &x509.Certificate{
|
|
Subject: pkix.Name{
|
|
CommonName: "localhost",
|
|
},
|
|
DNSNames: []string{"localhost"},
|
|
IPAddresses: certIPs,
|
|
KeyUsage: x509.KeyUsage(x509.KeyUsageCertSign | x509.KeyUsageCRLSign),
|
|
SerialNumber: big.NewInt(mathrand.Int63()),
|
|
NotBefore: time.Now().Add(-30 * time.Second),
|
|
NotAfter: time.Now().Add(262980 * time.Hour),
|
|
BasicConstraintsValid: true,
|
|
IsCA: true,
|
|
}
|
|
caBytes, err = x509.CreateCertificate(rand.Reader, caCertTemplate, caCertTemplate, caKey.Public(), caKey)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}
|
|
caCert, err := x509.ParseCertificate(caBytes)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
testCluster.CACert = caCert
|
|
testCluster.CACertBytes = caBytes
|
|
testCluster.RootCAs = x509.NewCertPool()
|
|
testCluster.RootCAs.AddCert(caCert)
|
|
caCertPEMBlock := &pem.Block{
|
|
Type: "CERTIFICATE",
|
|
Bytes: caBytes,
|
|
}
|
|
testCluster.CACertPEM = pem.EncodeToMemory(caCertPEMBlock)
|
|
testCluster.CACertPEMFile = filepath.Join(testCluster.TempDir, "ca_cert.pem")
|
|
err = ioutil.WriteFile(testCluster.CACertPEMFile, testCluster.CACertPEM, 0755)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
marshaledCAKey, err := x509.MarshalECPrivateKey(caKey)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
caKeyPEMBlock := &pem.Block{
|
|
Type: "EC PRIVATE KEY",
|
|
Bytes: marshaledCAKey,
|
|
}
|
|
testCluster.CAKeyPEM = pem.EncodeToMemory(caKeyPEMBlock)
|
|
err = ioutil.WriteFile(filepath.Join(testCluster.TempDir, "ca_key.pem"), testCluster.CAKeyPEM, 0755)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
var certInfoSlice []*certInfo
|
|
|
|
//
|
|
// Certs generation
|
|
//
|
|
for i := 0; i < numCores; i++ {
|
|
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
certTemplate := &x509.Certificate{
|
|
Subject: pkix.Name{
|
|
CommonName: "localhost",
|
|
},
|
|
DNSNames: []string{"localhost"},
|
|
IPAddresses: certIPs,
|
|
ExtKeyUsage: []x509.ExtKeyUsage{
|
|
x509.ExtKeyUsageServerAuth,
|
|
x509.ExtKeyUsageClientAuth,
|
|
},
|
|
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement,
|
|
SerialNumber: big.NewInt(mathrand.Int63()),
|
|
NotBefore: time.Now().Add(-30 * time.Second),
|
|
NotAfter: time.Now().Add(262980 * time.Hour),
|
|
}
|
|
certBytes, err := x509.CreateCertificate(rand.Reader, certTemplate, caCert, key.Public(), caKey)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
cert, err := x509.ParseCertificate(certBytes)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
certPEMBlock := &pem.Block{
|
|
Type: "CERTIFICATE",
|
|
Bytes: certBytes,
|
|
}
|
|
certPEM := pem.EncodeToMemory(certPEMBlock)
|
|
marshaledKey, err := x509.MarshalECPrivateKey(key)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
keyPEMBlock := &pem.Block{
|
|
Type: "EC PRIVATE KEY",
|
|
Bytes: marshaledKey,
|
|
}
|
|
keyPEM := pem.EncodeToMemory(keyPEMBlock)
|
|
|
|
certInfoSlice = append(certInfoSlice, &certInfo{
|
|
cert: cert,
|
|
certPEM: certPEM,
|
|
certBytes: certBytes,
|
|
key: key,
|
|
keyPEM: keyPEM,
|
|
})
|
|
}
|
|
|
|
//
|
|
// Listener setup
|
|
//
|
|
logger := logging.NewVaultLogger(log.Trace)
|
|
ports := make([]int, numCores)
|
|
if baseAddr != nil {
|
|
for i := 0; i < numCores; i++ {
|
|
ports[i] = baseAddr.Port + i
|
|
}
|
|
} else {
|
|
baseAddr = &net.TCPAddr{
|
|
IP: net.ParseIP("127.0.0.1"),
|
|
Port: 0,
|
|
}
|
|
}
|
|
|
|
listeners := [][]*TestListener{}
|
|
servers := []*http.Server{}
|
|
handlers := []http.Handler{}
|
|
tlsConfigs := []*tls.Config{}
|
|
certGetters := []*reload.CertificateGetter{}
|
|
for i := 0; i < numCores; i++ {
|
|
baseAddr.Port = ports[i]
|
|
ln, err := net.ListenTCP("tcp", baseAddr)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
certFile := filepath.Join(testCluster.TempDir, fmt.Sprintf("node%d_port_%d_cert.pem", i+1, ln.Addr().(*net.TCPAddr).Port))
|
|
keyFile := filepath.Join(testCluster.TempDir, fmt.Sprintf("node%d_port_%d_key.pem", i+1, ln.Addr().(*net.TCPAddr).Port))
|
|
err = ioutil.WriteFile(certFile, certInfoSlice[i].certPEM, 0755)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
err = ioutil.WriteFile(keyFile, certInfoSlice[i].keyPEM, 0755)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
tlsCert, err := tls.X509KeyPair(certInfoSlice[i].certPEM, certInfoSlice[i].keyPEM)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
certGetter := reload.NewCertificateGetter(certFile, keyFile, "")
|
|
certGetters = append(certGetters, certGetter)
|
|
tlsConfig := &tls.Config{
|
|
Certificates: []tls.Certificate{tlsCert},
|
|
RootCAs: testCluster.RootCAs,
|
|
ClientCAs: testCluster.RootCAs,
|
|
ClientAuth: tls.RequestClientCert,
|
|
NextProtos: []string{"h2", "http/1.1"},
|
|
GetCertificate: certGetter.GetCertificate,
|
|
}
|
|
tlsConfig.BuildNameToCertificate()
|
|
tlsConfigs = append(tlsConfigs, tlsConfig)
|
|
lns := []*TestListener{&TestListener{
|
|
Listener: tls.NewListener(ln, tlsConfig),
|
|
Address: ln.Addr().(*net.TCPAddr),
|
|
},
|
|
}
|
|
listeners = append(listeners, lns)
|
|
var handler http.Handler = http.NewServeMux()
|
|
handlers = append(handlers, handler)
|
|
server := &http.Server{
|
|
Handler: handler,
|
|
ErrorLog: logger.StandardLogger(nil),
|
|
}
|
|
servers = append(servers, server)
|
|
}
|
|
|
|
// Create three cores with the same physical and different redirect/cluster
|
|
// addrs.
|
|
// N.B.: On OSX, instead of random ports, it assigns new ports to new
|
|
// listeners sequentially. Aside from being a bad idea in a security sense,
|
|
// it also broke tests that assumed it was OK to just use the port above
|
|
// the redirect addr. This has now been changed to 105 ports above, but if
|
|
// we ever do more than three nodes in a cluster it may need to be bumped.
|
|
// Note: it's 105 so that we don't conflict with a running Consul by
|
|
// default.
|
|
coreConfig := &CoreConfig{
|
|
LogicalBackends: make(map[string]logical.Factory),
|
|
CredentialBackends: make(map[string]logical.Factory),
|
|
AuditBackends: make(map[string]audit.Factory),
|
|
RedirectAddr: fmt.Sprintf("https://127.0.0.1:%d", listeners[0][0].Address.Port),
|
|
ClusterAddr: fmt.Sprintf("https://127.0.0.1:%d", listeners[0][0].Address.Port+105),
|
|
DisableMlock: true,
|
|
EnableUI: true,
|
|
EnableRaw: true,
|
|
BuiltinRegistry: NewMockBuiltinRegistry(),
|
|
}
|
|
|
|
if base != nil {
|
|
coreConfig.DisableCache = base.DisableCache
|
|
coreConfig.EnableUI = base.EnableUI
|
|
coreConfig.DefaultLeaseTTL = base.DefaultLeaseTTL
|
|
coreConfig.MaxLeaseTTL = base.MaxLeaseTTL
|
|
coreConfig.CacheSize = base.CacheSize
|
|
coreConfig.PluginDirectory = base.PluginDirectory
|
|
coreConfig.Seal = base.Seal
|
|
coreConfig.DevToken = base.DevToken
|
|
coreConfig.EnableRaw = base.EnableRaw
|
|
coreConfig.DisableSealWrap = base.DisableSealWrap
|
|
coreConfig.DevLicenseDuration = base.DevLicenseDuration
|
|
coreConfig.DisableCache = base.DisableCache
|
|
if base.BuiltinRegistry != nil {
|
|
coreConfig.BuiltinRegistry = base.BuiltinRegistry
|
|
}
|
|
|
|
if !coreConfig.DisableMlock {
|
|
base.DisableMlock = false
|
|
}
|
|
|
|
if base.Physical != nil {
|
|
coreConfig.Physical = base.Physical
|
|
}
|
|
|
|
if base.HAPhysical != nil {
|
|
coreConfig.HAPhysical = base.HAPhysical
|
|
}
|
|
|
|
// Used to set something non-working to test fallback
|
|
switch base.ClusterAddr {
|
|
case "empty":
|
|
coreConfig.ClusterAddr = ""
|
|
case "":
|
|
default:
|
|
coreConfig.ClusterAddr = base.ClusterAddr
|
|
}
|
|
|
|
if base.LogicalBackends != nil {
|
|
for k, v := range base.LogicalBackends {
|
|
coreConfig.LogicalBackends[k] = v
|
|
}
|
|
}
|
|
if base.CredentialBackends != nil {
|
|
for k, v := range base.CredentialBackends {
|
|
coreConfig.CredentialBackends[k] = v
|
|
}
|
|
}
|
|
if base.AuditBackends != nil {
|
|
for k, v := range base.AuditBackends {
|
|
coreConfig.AuditBackends[k] = v
|
|
}
|
|
}
|
|
if base.Logger != nil {
|
|
coreConfig.Logger = base.Logger
|
|
}
|
|
|
|
coreConfig.ClusterCipherSuites = base.ClusterCipherSuites
|
|
|
|
coreConfig.DisableCache = base.DisableCache
|
|
|
|
coreConfig.DevToken = base.DevToken
|
|
}
|
|
|
|
if coreConfig.Physical == nil {
|
|
coreConfig.Physical, err = physInmem.NewInmem(nil, logger)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}
|
|
if coreConfig.HAPhysical == nil {
|
|
haPhys, err := physInmem.NewInmemHA(nil, logger)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
coreConfig.HAPhysical = haPhys.(physical.HABackend)
|
|
}
|
|
|
|
pubKey, priKey, err := testGenerateCoreKeys()
|
|
if err != nil {
|
|
t.Fatalf("err: %v", err)
|
|
}
|
|
|
|
cores := []*Core{}
|
|
coreConfigs := []*CoreConfig{}
|
|
for i := 0; i < numCores; i++ {
|
|
localConfig := *coreConfig
|
|
localConfig.RedirectAddr = fmt.Sprintf("https://127.0.0.1:%d", listeners[i][0].Address.Port)
|
|
if localConfig.ClusterAddr != "" {
|
|
localConfig.ClusterAddr = fmt.Sprintf("https://127.0.0.1:%d", listeners[i][0].Address.Port+105)
|
|
}
|
|
|
|
// if opts.SealFunc is provided, use that to generate a seal for the config instead
|
|
if opts != nil && opts.SealFunc != nil {
|
|
localConfig.Seal = opts.SealFunc()
|
|
}
|
|
|
|
if opts != nil && opts.Logger != nil {
|
|
localConfig.Logger = opts.Logger.Named(fmt.Sprintf("core%d", i))
|
|
}
|
|
|
|
localConfig.LicensingConfig = testGetLicensingConfig(pubKey)
|
|
|
|
c, err := NewCore(&localConfig)
|
|
if err != nil {
|
|
t.Fatalf("err: %v", err)
|
|
}
|
|
cores = append(cores, c)
|
|
coreConfigs = append(coreConfigs, &localConfig)
|
|
if opts != nil && opts.HandlerFunc != nil {
|
|
handlers[i] = opts.HandlerFunc(&HandlerProperties{
|
|
Core: c,
|
|
MaxRequestDuration: DefaultMaxRequestDuration,
|
|
})
|
|
servers[i].Handler = handlers[i]
|
|
}
|
|
|
|
// Set this in case the Seal was manually set before the core was
|
|
// created
|
|
if localConfig.Seal != nil {
|
|
localConfig.Seal.SetCore(c)
|
|
}
|
|
}
|
|
|
|
//
|
|
// Clustering setup
|
|
//
|
|
clusterAddrGen := func(lns []*TestListener) []*net.TCPAddr {
|
|
ret := make([]*net.TCPAddr, len(lns))
|
|
for i, ln := range lns {
|
|
ret[i] = &net.TCPAddr{
|
|
IP: ln.Address.IP,
|
|
Port: ln.Address.Port + 105,
|
|
}
|
|
}
|
|
return ret
|
|
}
|
|
|
|
for i := 0; i < numCores; i++ {
|
|
if coreConfigs[i].ClusterAddr != "" {
|
|
cores[i].SetClusterListenerAddrs(clusterAddrGen(listeners[i]))
|
|
cores[i].SetClusterHandler(handlers[i])
|
|
}
|
|
}
|
|
|
|
if opts == nil || !opts.SkipInit {
|
|
bKeys, rKeys, root := TestCoreInitClusterWrapperSetup(t, cores[0], clusterAddrGen(listeners[0]), handlers[0])
|
|
barrierKeys, _ := copystructure.Copy(bKeys)
|
|
testCluster.BarrierKeys = barrierKeys.([][]byte)
|
|
recoveryKeys, _ := copystructure.Copy(rKeys)
|
|
testCluster.RecoveryKeys = recoveryKeys.([][]byte)
|
|
testCluster.RootToken = root
|
|
|
|
// Write root token and barrier keys
|
|
err = ioutil.WriteFile(filepath.Join(testCluster.TempDir, "root_token"), []byte(root), 0755)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
var buf bytes.Buffer
|
|
for i, key := range testCluster.BarrierKeys {
|
|
buf.Write([]byte(base64.StdEncoding.EncodeToString(key)))
|
|
if i < len(testCluster.BarrierKeys)-1 {
|
|
buf.WriteRune('\n')
|
|
}
|
|
}
|
|
err = ioutil.WriteFile(filepath.Join(testCluster.TempDir, "barrier_keys"), buf.Bytes(), 0755)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
for i, key := range testCluster.RecoveryKeys {
|
|
buf.Write([]byte(base64.StdEncoding.EncodeToString(key)))
|
|
if i < len(testCluster.RecoveryKeys)-1 {
|
|
buf.WriteRune('\n')
|
|
}
|
|
}
|
|
err = ioutil.WriteFile(filepath.Join(testCluster.TempDir, "recovery_keys"), buf.Bytes(), 0755)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
// Unseal first core
|
|
for _, key := range bKeys {
|
|
if _, err := cores[0].Unseal(TestKeyCopy(key)); err != nil {
|
|
t.Fatalf("unseal err: %s", err)
|
|
}
|
|
}
|
|
|
|
ctx := context.Background()
|
|
|
|
// If stored keys is supported, the above will no no-op, so trigger auto-unseal
|
|
// using stored keys to try to unseal
|
|
if err := cores[0].UnsealWithStoredKeys(ctx); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
// Verify unsealed
|
|
if cores[0].Sealed() {
|
|
t.Fatal("should not be sealed")
|
|
}
|
|
|
|
TestWaitActive(t, cores[0])
|
|
|
|
// Unseal other cores unless otherwise specified
|
|
if (opts == nil || !opts.KeepStandbysSealed) && numCores > 1 {
|
|
for i := 1; i < numCores; i++ {
|
|
for _, key := range bKeys {
|
|
if _, err := cores[i].Unseal(TestKeyCopy(key)); err != nil {
|
|
t.Fatalf("unseal err: %s", err)
|
|
}
|
|
}
|
|
|
|
// If stored keys is supported, the above will no no-op, so trigger auto-unseal
|
|
// using stored keys
|
|
if err := cores[i].UnsealWithStoredKeys(ctx); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}
|
|
|
|
// Let them come fully up to standby
|
|
time.Sleep(2 * time.Second)
|
|
|
|
// Ensure cluster connection info is populated.
|
|
// Other cores should not come up as leaders.
|
|
for i := 1; i < numCores; i++ {
|
|
isLeader, _, _, err := cores[i].Leader()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if isLeader {
|
|
t.Fatalf("core[%d] should not be leader", i)
|
|
}
|
|
}
|
|
}
|
|
|
|
//
|
|
// Set test cluster core(s) and test cluster
|
|
//
|
|
cluster, err := cores[0].Cluster(context.Background())
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
testCluster.ID = cluster.ID
|
|
}
|
|
|
|
getAPIClient := func(port int, tlsConfig *tls.Config) *api.Client {
|
|
transport := cleanhttp.DefaultPooledTransport()
|
|
transport.TLSClientConfig = tlsConfig.Clone()
|
|
if err := http2.ConfigureTransport(transport); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
client := &http.Client{
|
|
Transport: transport,
|
|
CheckRedirect: func(*http.Request, []*http.Request) error {
|
|
// This can of course be overridden per-test by using its own client
|
|
return fmt.Errorf("redirects not allowed in these tests")
|
|
},
|
|
}
|
|
config := api.DefaultConfig()
|
|
if config.Error != nil {
|
|
t.Fatal(config.Error)
|
|
}
|
|
config.Address = fmt.Sprintf("https://127.0.0.1:%d", port)
|
|
config.HttpClient = client
|
|
config.MaxRetries = 0
|
|
apiClient, err := api.NewClient(config)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if opts == nil || !opts.SkipInit {
|
|
apiClient.SetToken(testCluster.RootToken)
|
|
}
|
|
return apiClient
|
|
}
|
|
|
|
var ret []*TestClusterCore
|
|
for i := 0; i < numCores; i++ {
|
|
tcc := &TestClusterCore{
|
|
Core: cores[i],
|
|
CoreConfig: coreConfigs[i],
|
|
ServerKey: certInfoSlice[i].key,
|
|
ServerKeyPEM: certInfoSlice[i].keyPEM,
|
|
ServerCert: certInfoSlice[i].cert,
|
|
ServerCertBytes: certInfoSlice[i].certBytes,
|
|
ServerCertPEM: certInfoSlice[i].certPEM,
|
|
Listeners: listeners[i],
|
|
Handler: handlers[i],
|
|
Server: servers[i],
|
|
TLSConfig: tlsConfigs[i],
|
|
Client: getAPIClient(listeners[i][0].Address.Port, tlsConfigs[i]),
|
|
}
|
|
tcc.ReloadFuncs = &cores[i].reloadFuncs
|
|
tcc.ReloadFuncsLock = &cores[i].reloadFuncsLock
|
|
tcc.ReloadFuncsLock.Lock()
|
|
(*tcc.ReloadFuncs)["listener|tcp"] = []reload.ReloadFunc{certGetters[i].Reload}
|
|
tcc.ReloadFuncsLock.Unlock()
|
|
|
|
testAdjustTestCore(base, tcc)
|
|
|
|
ret = append(ret, tcc)
|
|
}
|
|
|
|
testCluster.Cores = ret
|
|
|
|
testExtraClusterCoresTestSetup(t, priKey, testCluster.Cores)
|
|
|
|
return &testCluster
|
|
}
|
|
|
|
func NewMockBuiltinRegistry() *mockBuiltinRegistry {
|
|
return &mockBuiltinRegistry{
|
|
forTesting: map[string]consts.PluginType{
|
|
"mysql-database-plugin": consts.PluginTypeDatabase,
|
|
"postgresql-database-plugin": consts.PluginTypeDatabase,
|
|
},
|
|
}
|
|
}
|
|
|
|
type mockBuiltinRegistry struct {
|
|
forTesting map[string]consts.PluginType
|
|
}
|
|
|
|
func (m *mockBuiltinRegistry) Get(name string, pluginType consts.PluginType) (func() (interface{}, error), bool) {
|
|
testPluginType, ok := m.forTesting[name]
|
|
if !ok {
|
|
return nil, false
|
|
}
|
|
if pluginType != testPluginType {
|
|
return nil, false
|
|
}
|
|
if name == "postgresql-database-plugin" {
|
|
return dbPostgres.New, true
|
|
}
|
|
return dbMysql.New(dbMysql.MetadataLen, dbMysql.MetadataLen, dbMysql.UsernameLen), true
|
|
}
|
|
|
|
// Keys only supports getting a realistic list of the keys for database plugins.
|
|
func (m *mockBuiltinRegistry) Keys(pluginType consts.PluginType) []string {
|
|
if pluginType != consts.PluginTypeDatabase {
|
|
return []string{}
|
|
}
|
|
/*
|
|
This is a hard-coded reproduction of the db plugin keys in helper/builtinplugins/registry.go.
|
|
The registry isn't directly used because it causes import cycles.
|
|
*/
|
|
return []string{
|
|
"mysql-database-plugin",
|
|
"mysql-aurora-database-plugin",
|
|
"mysql-rds-database-plugin",
|
|
"mysql-legacy-database-plugin",
|
|
"postgresql-database-plugin",
|
|
"mssql-database-plugin",
|
|
"cassandra-database-plugin",
|
|
"mongodb-database-plugin",
|
|
"hana-database-plugin",
|
|
}
|
|
}
|
|
|
|
func (m *mockBuiltinRegistry) Contains(name string, pluginType consts.PluginType) bool {
|
|
return false
|
|
}
|