1319 lines
35 KiB
Go
1319 lines
35 KiB
Go
package vault
|
|
|
|
import (
|
|
"bytes"
|
|
"crypto/ecdsa"
|
|
"crypto/elliptic"
|
|
"crypto/rand"
|
|
"crypto/sha256"
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"crypto/x509/pkix"
|
|
"encoding/base64"
|
|
"encoding/pem"
|
|
"fmt"
|
|
"io"
|
|
"io/ioutil"
|
|
"math/big"
|
|
mathrand "math/rand"
|
|
"net"
|
|
"net/http"
|
|
"os"
|
|
"os/exec"
|
|
"path/filepath"
|
|
"sync"
|
|
"time"
|
|
|
|
log "github.com/mgutz/logxi/v1"
|
|
"github.com/mitchellh/copystructure"
|
|
|
|
"golang.org/x/crypto/ssh"
|
|
"golang.org/x/net/http2"
|
|
|
|
cleanhttp "github.com/hashicorp/go-cleanhttp"
|
|
"github.com/hashicorp/go-uuid"
|
|
"github.com/hashicorp/vault/api"
|
|
"github.com/hashicorp/vault/audit"
|
|
"github.com/hashicorp/vault/helper/logformat"
|
|
"github.com/hashicorp/vault/helper/reload"
|
|
"github.com/hashicorp/vault/helper/salt"
|
|
"github.com/hashicorp/vault/logical"
|
|
"github.com/hashicorp/vault/logical/framework"
|
|
"github.com/hashicorp/vault/physical"
|
|
"github.com/mitchellh/go-testing-interface"
|
|
|
|
physInmem "github.com/hashicorp/vault/physical/inmem"
|
|
)
|
|
|
|
// This file contains a number of helper functions that are useful for unit
|
|
// tests within other packages.
|
|
|
|
const (
|
|
// testSharedPublicKey is an OpenSSH authorized_keys-format RSA public key.
// StartSSHHostTestServer parses it with ssh.ParseAuthorizedKey and accepts
// only clients authenticating with this key.
testSharedPublicKey = `
|
|
ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQC9i+hFxZHGo6KblVme4zrAcJstR6I0PTJozW286X4WyvPnkMYDQ5mnhEYC7UWCvjoTWbPEXPX7NjhRtwQTGD67bV+lrxgfyzK1JZbUXK4PwgKJvQD+XyyWYMzDgGSQY61KUSqCxymSm/9NZkPU3ElaQ9xQuTzPpztM4ROfb8f2Yv6/ZESZsTo0MTAkp8Pcy+WkioI/uJ1H7zqs0EA4OMY4aDJRu0UtP4rTVeYNEAuRXdX+eH4aW3KMvhzpFTjMbaJHJXlEeUm2SaX5TNQyTOvghCeQILfYIL/Ca2ij8iwCmulwdV6eQGfd4VDu40PvSnmfoaE38o6HaPnX0kUcnKiT
|
|
`
|
|
// testSharedPrivateKey is a PEM-encoded RSA private key that
// StartSSHHostTestServer installs as the SSH server's host key.
// NOTE(review): presumably the private half of testSharedPublicKey — confirm.
testSharedPrivateKey = `
|
|
-----BEGIN RSA PRIVATE KEY-----
|
|
MIIEogIBAAKCAQEAvYvoRcWRxqOim5VZnuM6wHCbLUeiND0yaM1tvOl+Fsrz55DG
|
|
A0OZp4RGAu1Fgr46E1mzxFz1+zY4UbcEExg+u21fpa8YH8sytSWW1FyuD8ICib0A
|
|
/l8slmDMw4BkkGOtSlEqgscpkpv/TWZD1NxJWkPcULk8z6c7TOETn2/H9mL+v2RE
|
|
mbE6NDEwJKfD3MvlpIqCP7idR+86rNBAODjGOGgyUbtFLT+K01XmDRALkV3V/nh+
|
|
GltyjL4c6RU4zG2iRyV5RHlJtkml+UzUMkzr4IQnkCC32CC/wmtoo/IsAprpcHVe
|
|
nkBn3eFQ7uND70p5n6GhN/KOh2j519JFHJyokwIDAQABAoIBAHX7VOvBC3kCN9/x
|
|
+aPdup84OE7Z7MvpX6w+WlUhXVugnmsAAVDczhKoUc/WktLLx2huCGhsmKvyVuH+
|
|
MioUiE+vx75gm3qGx5xbtmOfALVMRLopjCnJYf6EaFA0ZeQ+NwowNW7Lu0PHmAU8
|
|
Z3JiX8IwxTz14DU82buDyewO7v+cEr97AnERe3PUcSTDoUXNaoNxjNpEJkKREY6h
|
|
4hAY676RT/GsRcQ8tqe/rnCqPHNd7JGqL+207FK4tJw7daoBjQyijWuB7K5chSal
|
|
oPInylM6b13ASXuOAOT/2uSUBWmFVCZPDCmnZxy2SdnJGbsJAMl7Ma3MUlaGvVI+
|
|
Tfh1aQkCgYEA4JlNOabTb3z42wz6mz+Nz3JRwbawD+PJXOk5JsSnV7DtPtfgkK9y
|
|
6FTQdhnozGWShAvJvc+C4QAihs9AlHXoaBY5bEU7R/8UK/pSqwzam+MmxmhVDV7G
|
|
IMQPV0FteoXTaJSikhZ88mETTegI2mik+zleBpVxvfdhE5TR+lq8Br0CgYEA2AwJ
|
|
CUD5CYUSj09PluR0HHqamWOrJkKPFPwa+5eiTTCzfBBxImYZh7nXnWuoviXC0sg2
|
|
AuvCW+uZ48ygv/D8gcz3j1JfbErKZJuV+TotK9rRtNIF5Ub7qysP7UjyI7zCssVM
|
|
kuDd9LfRXaB/qGAHNkcDA8NxmHW3gpln4CFdSY8CgYANs4xwfercHEWaJ1qKagAe
|
|
rZyrMpffAEhicJ/Z65lB0jtG4CiE6w8ZeUMWUVJQVcnwYD+4YpZbX4S7sJ0B8Ydy
|
|
AhkSr86D/92dKTIt2STk6aCN7gNyQ1vW198PtaAWH1/cO2UHgHOy3ZUt5X/Uwxl9
|
|
cex4flln+1Viumts2GgsCQKBgCJH7psgSyPekK5auFdKEr5+Gc/jB8I/Z3K9+g4X
|
|
5nH3G1PBTCJYLw7hRzw8W/8oALzvddqKzEFHphiGXK94Lqjt/A4q1OdbCrhiE68D
|
|
My21P/dAKB1UYRSs9Y8CNyHCjuZM9jSMJ8vv6vG/SOJPsnVDWVAckAbQDvlTHC9t
|
|
O98zAoGAcbW6uFDkrv0XMCpB9Su3KaNXOR0wzag+WIFQRXCcoTvxVi9iYfUReQPi
|
|
oOyBJU/HMVvBfv4g+OVFLVgSwwm6owwsouZ0+D/LasbuHqYyqYqdyPJQYzWA2Y+F
|
|
+B6f4RoPdSXj24JHPg/ioRxjaj094UXJxua2yfkcecGNEuBQHSs=
|
|
-----END RSA PRIVATE KEY-----
|
|
`
|
|
)
|
|
|
|
// TestCore returns a pure in-memory, uninitialized core for testing.
|
|
func TestCore(t testing.T) *Core {
|
|
return TestCoreWithSeal(t, nil)
|
|
}
|
|
|
|
// TestCoreNewSeal returns a pure in-memory, uninitialized core with
|
|
// the new seal configuration.
|
|
func TestCoreNewSeal(t testing.T) *Core {
|
|
return TestCoreWithSeal(t, &TestSeal{})
|
|
}
|
|
|
|
// TestCoreWithSeal returns a pure in-memory, uninitialized core with the
|
|
// specified seal for testing.
|
|
func TestCoreWithSeal(t testing.T, testSeal Seal) *Core {
|
|
logger := logformat.NewVaultLogger(log.LevelTrace)
|
|
physicalBackend, err := physInmem.NewInmem(nil, logger)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
conf := testCoreConfig(t, physicalBackend, logger)
|
|
|
|
if testSeal != nil {
|
|
conf.Seal = testSeal
|
|
}
|
|
|
|
c, err := NewCore(conf)
|
|
if err != nil {
|
|
t.Fatalf("err: %s", err)
|
|
}
|
|
|
|
return c
|
|
}
|
|
|
|
func testCoreConfig(t testing.T, physicalBackend physical.Backend, logger log.Logger) *CoreConfig {
|
|
noopAudits := map[string]audit.Factory{
|
|
"noop": func(config *audit.BackendConfig) (audit.Backend, error) {
|
|
view := &logical.InmemStorage{}
|
|
view.Put(&logical.StorageEntry{
|
|
Key: "salt",
|
|
Value: []byte("foo"),
|
|
})
|
|
config.SaltConfig = &salt.Config{
|
|
HMAC: sha256.New,
|
|
HMACType: "hmac-sha256",
|
|
}
|
|
config.SaltView = view
|
|
return &noopAudit{
|
|
Config: config,
|
|
}, nil
|
|
},
|
|
}
|
|
noopBackends := make(map[string]logical.Factory)
|
|
noopBackends["noop"] = func(config *logical.BackendConfig) (logical.Backend, error) {
|
|
b := new(framework.Backend)
|
|
b.Setup(config)
|
|
return b, nil
|
|
}
|
|
noopBackends["http"] = func(config *logical.BackendConfig) (logical.Backend, error) {
|
|
return new(rawHTTP), nil
|
|
}
|
|
logicalBackends := make(map[string]logical.Factory)
|
|
for backendName, backendFactory := range noopBackends {
|
|
logicalBackends[backendName] = backendFactory
|
|
}
|
|
logicalBackends["generic"] = LeasedPassthroughBackendFactory
|
|
for backendName, backendFactory := range testLogicalBackends {
|
|
logicalBackends[backendName] = backendFactory
|
|
}
|
|
|
|
conf := &CoreConfig{
|
|
Physical: physicalBackend,
|
|
AuditBackends: noopAudits,
|
|
LogicalBackends: logicalBackends,
|
|
CredentialBackends: noopBackends,
|
|
DisableMlock: true,
|
|
Logger: logger,
|
|
}
|
|
|
|
return conf
|
|
}
|
|
|
|
// TestCoreInit initializes the core with a single key, and returns
|
|
// the key that must be used to unseal the core and a root token.
|
|
func TestCoreInit(t testing.T, core *Core) ([][]byte, string) {
|
|
return TestCoreInitClusterWrapperSetup(t, core, nil, nil)
|
|
}
|
|
|
|
func TestCoreInitClusterWrapperSetup(t testing.T, core *Core, clusterAddrs []*net.TCPAddr, handler http.Handler) ([][]byte, string) {
|
|
core.SetClusterListenerAddrs(clusterAddrs)
|
|
core.SetClusterHandler(handler)
|
|
result, err := core.Initialize(&InitParams{
|
|
BarrierConfig: &SealConfig{
|
|
SecretShares: 3,
|
|
SecretThreshold: 3,
|
|
},
|
|
RecoveryConfig: &SealConfig{
|
|
SecretShares: 3,
|
|
SecretThreshold: 3,
|
|
},
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("err: %s", err)
|
|
}
|
|
return result.SecretShares, result.RootToken
|
|
}
|
|
|
|
func TestCoreUnseal(core *Core, key []byte) (bool, error) {
|
|
return core.Unseal(key)
|
|
}
|
|
|
|
// TestCoreUnsealed returns a pure in-memory core that is already
|
|
// initialized and unsealed.
|
|
func TestCoreUnsealed(t testing.T) (*Core, [][]byte, string) {
|
|
core := TestCore(t)
|
|
keys, token := TestCoreInit(t, core)
|
|
for _, key := range keys {
|
|
if _, err := TestCoreUnseal(core, TestKeyCopy(key)); err != nil {
|
|
t.Fatalf("unseal err: %s", err)
|
|
}
|
|
}
|
|
|
|
sealed, err := core.Sealed()
|
|
if err != nil {
|
|
t.Fatalf("err checking seal status: %s", err)
|
|
}
|
|
if sealed {
|
|
t.Fatal("should not be sealed")
|
|
}
|
|
|
|
return core, keys, token
|
|
}
|
|
|
|
func TestCoreUnsealedBackend(t testing.T, backend physical.Backend) (*Core, [][]byte, string) {
|
|
logger := logformat.NewVaultLogger(log.LevelTrace)
|
|
conf := testCoreConfig(t, backend, logger)
|
|
conf.Seal = &TestSeal{}
|
|
|
|
core, err := NewCore(conf)
|
|
if err != nil {
|
|
t.Fatalf("err: %s", err)
|
|
}
|
|
|
|
keys, token := TestCoreInit(t, core)
|
|
for _, key := range keys {
|
|
if _, err := TestCoreUnseal(core, TestKeyCopy(key)); err != nil {
|
|
t.Fatalf("unseal err: %s", err)
|
|
}
|
|
}
|
|
|
|
sealed, err := core.Sealed()
|
|
if err != nil {
|
|
t.Fatalf("err checking seal status: %s", err)
|
|
}
|
|
if sealed {
|
|
t.Fatal("should not be sealed")
|
|
}
|
|
|
|
return core, keys, token
|
|
}
|
|
|
|
func testTokenStore(t testing.T, c *Core) *TokenStore {
|
|
me := &MountEntry{
|
|
Table: credentialTableType,
|
|
Path: "token/",
|
|
Type: "token",
|
|
Description: "token based credentials",
|
|
}
|
|
|
|
meUUID, err := uuid.GenerateUUID()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
me.UUID = meUUID
|
|
|
|
view := NewBarrierView(c.barrier, credentialBarrierPrefix+me.UUID+"/")
|
|
sysView := c.mountEntrySysView(me)
|
|
|
|
tokenstore, _ := c.newCredentialBackend("token", sysView, view, nil)
|
|
if err := tokenstore.Initialize(); err != nil {
|
|
panic(err)
|
|
}
|
|
ts := tokenstore.(*TokenStore)
|
|
|
|
router := NewRouter()
|
|
err = router.Mount(ts, "auth/token/", &MountEntry{Table: credentialTableType, UUID: "authtokenuuid", Path: "auth/token", Accessor: "authtokenaccessor"}, ts.view)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
subview := c.systemBarrierView.SubView(expirationSubPath)
|
|
logger := logformat.NewVaultLogger(log.LevelTrace)
|
|
|
|
exp := NewExpirationManager(router, subview, ts, logger)
|
|
ts.SetExpirationManager(exp)
|
|
|
|
return ts
|
|
}
|
|
|
|
// TestCoreWithTokenStore returns an in-memory core that has a token store
|
|
// mounted, so that logical token functions can be used
|
|
func TestCoreWithTokenStore(t testing.T) (*Core, *TokenStore, [][]byte, string) {
|
|
c, keys, root := TestCoreUnsealed(t)
|
|
ts := testTokenStore(t, c)
|
|
|
|
return c, ts, keys, root
|
|
}
|
|
|
|
// TestCoreWithBackendTokenStore returns a core that has a token store
|
|
// mounted and used the provided physical backend, so that logical token
|
|
// functions can be used
|
|
func TestCoreWithBackendTokenStore(t testing.T, backend physical.Backend) (*Core, *TokenStore, [][]byte, string) {
|
|
c, keys, root := TestCoreUnsealedBackend(t, backend)
|
|
ts := testTokenStore(t, c)
|
|
|
|
return c, ts, keys, root
|
|
}
|
|
|
|
// TestKeyCopy returns a newly allocated copy of key that shares no backing
// array with the caller's slice, so the copy can be handed to Unseal. The
// result is always non-nil, even for a nil or empty key.
func TestKeyCopy(key []byte) []byte {
	return append(make([]byte, 0, len(key)), key...)
}
|
|
|
|
func TestDynamicSystemView(c *Core) *dynamicSystemView {
|
|
me := &MountEntry{
|
|
Config: MountConfig{
|
|
DefaultLeaseTTL: 24 * time.Hour,
|
|
MaxLeaseTTL: 2 * 24 * time.Hour,
|
|
},
|
|
}
|
|
|
|
return &dynamicSystemView{c, me}
|
|
}
|
|
|
|
func TestAddTestPlugin(t testing.T, c *Core, name, testFunc string) {
|
|
file, err := os.Open(os.Args[0])
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer file.Close()
|
|
|
|
hash := sha256.New()
|
|
|
|
_, err = io.Copy(hash, file)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
sum := hash.Sum(nil)
|
|
c.pluginCatalog.directory, err = filepath.EvalSymlinks(os.Args[0])
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
c.pluginCatalog.directory = filepath.Dir(c.pluginCatalog.directory)
|
|
|
|
command := fmt.Sprintf("%s --test.run=%s", filepath.Base(os.Args[0]), testFunc)
|
|
err = c.pluginCatalog.Set(name, command, sum)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}
|
|
|
|
// testLogicalBackends holds logical backend factories registered by tests via
// AddTestLogicalBackend. testCoreConfig copies them into every test core's
// LogicalBackends, after the defaults, so they override same-named entries.
var testLogicalBackends = map[string]logical.Factory{}
|
|
|
|
// Starts the test server which responds to SSH authentication.
|
|
// Used to test the SSH secret backend.
|
|
func StartSSHHostTestServer() (string, error) {
|
|
pubKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(testSharedPublicKey))
|
|
if err != nil {
|
|
return "", fmt.Errorf("Error parsing public key")
|
|
}
|
|
serverConfig := &ssh.ServerConfig{
|
|
PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
|
|
if bytes.Compare(pubKey.Marshal(), key.Marshal()) == 0 {
|
|
return &ssh.Permissions{}, nil
|
|
} else {
|
|
return nil, fmt.Errorf("Key does not match")
|
|
}
|
|
},
|
|
}
|
|
signer, err := ssh.ParsePrivateKey([]byte(testSharedPrivateKey))
|
|
if err != nil {
|
|
panic("Error parsing private key")
|
|
}
|
|
serverConfig.AddHostKey(signer)
|
|
|
|
soc, err := net.Listen("tcp", "127.0.0.1:0")
|
|
if err != nil {
|
|
return "", fmt.Errorf("Error listening to connection")
|
|
}
|
|
|
|
go func() {
|
|
for {
|
|
conn, err := soc.Accept()
|
|
if err != nil {
|
|
panic(fmt.Sprintf("Error accepting incoming connection: %s", err))
|
|
}
|
|
defer conn.Close()
|
|
sshConn, chanReqs, _, err := ssh.NewServerConn(conn, serverConfig)
|
|
if err != nil {
|
|
panic(fmt.Sprintf("Handshaking error: %v", err))
|
|
}
|
|
|
|
go func() {
|
|
for chanReq := range chanReqs {
|
|
go func(chanReq ssh.NewChannel) {
|
|
if chanReq.ChannelType() != "session" {
|
|
chanReq.Reject(ssh.UnknownChannelType, "unknown channel type")
|
|
return
|
|
}
|
|
|
|
ch, requests, err := chanReq.Accept()
|
|
if err != nil {
|
|
panic(fmt.Sprintf("Error accepting channel: %s", err))
|
|
}
|
|
|
|
go func(ch ssh.Channel, in <-chan *ssh.Request) {
|
|
for req := range in {
|
|
executeServerCommand(ch, req)
|
|
}
|
|
}(ch, requests)
|
|
}(chanReq)
|
|
}
|
|
sshConn.Close()
|
|
}()
|
|
}
|
|
}()
|
|
return soc.Addr().String(), nil
|
|
}
|
|
|
|
// This executes the commands requested to be run on the server.
|
|
// Used to test the SSH secret backend.
|
|
func executeServerCommand(ch ssh.Channel, req *ssh.Request) {
|
|
command := string(req.Payload[4:])
|
|
cmd := exec.Command("/bin/bash", []string{"-c", command}...)
|
|
req.Reply(true, nil)
|
|
|
|
cmd.Stdout = ch
|
|
cmd.Stderr = ch
|
|
cmd.Stdin = ch
|
|
|
|
err := cmd.Start()
|
|
if err != nil {
|
|
panic(fmt.Sprintf("Error starting the command: '%s'", err))
|
|
}
|
|
|
|
go func() {
|
|
_, err := cmd.Process.Wait()
|
|
if err != nil {
|
|
panic(fmt.Sprintf("Error while waiting for command to finish:'%s'", err))
|
|
}
|
|
ch.Close()
|
|
}()
|
|
}
|
|
|
|
// This adds a logical backend for the test core. This needs to be
|
|
// invoked before the test core is created.
|
|
func AddTestLogicalBackend(name string, factory logical.Factory) error {
|
|
if name == "" {
|
|
return fmt.Errorf("Missing backend name")
|
|
}
|
|
if factory == nil {
|
|
return fmt.Errorf("Missing backend factory function")
|
|
}
|
|
testLogicalBackends[name] = factory
|
|
return nil
|
|
}
|
|
|
|
type noopAudit struct {
|
|
Config *audit.BackendConfig
|
|
salt *salt.Salt
|
|
saltMutex sync.RWMutex
|
|
}
|
|
|
|
func (n *noopAudit) GetHash(data string) (string, error) {
|
|
salt, err := n.Salt()
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
return salt.GetIdentifiedHMAC(data), nil
|
|
}
|
|
|
|
func (n *noopAudit) LogRequest(a *logical.Auth, r *logical.Request, e error) error {
|
|
return nil
|
|
}
|
|
|
|
func (n *noopAudit) LogResponse(a *logical.Auth, r *logical.Request, re *logical.Response, err error) error {
|
|
return nil
|
|
}
|
|
|
|
func (n *noopAudit) Reload() error {
|
|
return nil
|
|
}
|
|
|
|
func (n *noopAudit) Invalidate() {
|
|
n.saltMutex.Lock()
|
|
defer n.saltMutex.Unlock()
|
|
n.salt = nil
|
|
}
|
|
|
|
func (n *noopAudit) Salt() (*salt.Salt, error) {
|
|
n.saltMutex.RLock()
|
|
if n.salt != nil {
|
|
defer n.saltMutex.RUnlock()
|
|
return n.salt, nil
|
|
}
|
|
n.saltMutex.RUnlock()
|
|
n.saltMutex.Lock()
|
|
defer n.saltMutex.Unlock()
|
|
if n.salt != nil {
|
|
return n.salt, nil
|
|
}
|
|
salt, err := salt.NewSalt(n.Config.SaltView, n.Config.SaltConfig)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
n.salt = salt
|
|
return salt, nil
|
|
}
|
|
|
|
type rawHTTP struct{}
|
|
|
|
func (n *rawHTTP) HandleRequest(req *logical.Request) (*logical.Response, error) {
|
|
return &logical.Response{
|
|
Data: map[string]interface{}{
|
|
logical.HTTPStatusCode: 200,
|
|
logical.HTTPContentType: "plain/text",
|
|
logical.HTTPRawBody: []byte("hello world"),
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
func (n *rawHTTP) HandleExistenceCheck(req *logical.Request) (bool, bool, error) {
|
|
return false, false, nil
|
|
}
|
|
|
|
func (n *rawHTTP) SpecialPaths() *logical.Paths {
|
|
return &logical.Paths{Unauthenticated: []string{"*"}}
|
|
}
|
|
|
|
func (n *rawHTTP) System() logical.SystemView {
|
|
return logical.StaticSystemView{
|
|
DefaultLeaseTTLVal: time.Hour * 24,
|
|
MaxLeaseTTLVal: time.Hour * 24 * 32,
|
|
}
|
|
}
|
|
|
|
func (n *rawHTTP) Logger() log.Logger {
|
|
return logformat.NewVaultLogger(log.LevelTrace)
|
|
}
|
|
|
|
func (n *rawHTTP) Cleanup() {
|
|
// noop
|
|
}
|
|
|
|
func (n *rawHTTP) Initialize() error {
|
|
// noop
|
|
return nil
|
|
}
|
|
|
|
func (n *rawHTTP) InvalidateKey(string) {
|
|
// noop
|
|
}
|
|
|
|
func (n *rawHTTP) Setup(config *logical.BackendConfig) error {
|
|
// noop
|
|
return nil
|
|
}
|
|
|
|
func (n *rawHTTP) Type() logical.BackendType {
|
|
return logical.TypeUnknown
|
|
}
|
|
|
|
func (n *rawHTTP) RegisterLicense(license interface{}) error {
|
|
return nil
|
|
}
|
|
|
|
// GenerateRandBytes returns length bytes read from crypto/rand.
// A negative length is an error; a zero length yields an empty, non-nil slice
// without touching the random source.
func GenerateRandBytes(length int) ([]byte, error) {
	switch {
	case length < 0:
		return nil, fmt.Errorf("length must be >= 0")
	case length == 0:
		return make([]byte, 0), nil
	}

	out := make([]byte, length)
	got, err := rand.Read(out)
	if err != nil {
		return nil, err
	}
	if got != length {
		return nil, fmt.Errorf("unable to read %d bytes; only read %d", length, got)
	}
	return out, nil
}
|
|
|
|
func TestWaitActive(t testing.T, core *Core) {
|
|
start := time.Now()
|
|
var standby bool
|
|
var err error
|
|
for time.Now().Sub(start) < time.Second {
|
|
standby, err = core.Standby()
|
|
if err != nil {
|
|
t.Fatalf("err: %v", err)
|
|
}
|
|
if !standby {
|
|
break
|
|
}
|
|
}
|
|
if standby {
|
|
t.Fatalf("should not be in standby mode")
|
|
}
|
|
}
|
|
|
|
// TestCluster describes a multi-core test cluster assembled by NewTestCluster:
// the CA material that issues the nodes' certificates, the cores themselves,
// and the temporary directory holding the PEM files written to disk.
type TestCluster struct {
|
|
// BarrierKeys holds barrier key shares. NOTE(review): not populated in the
// visible part of NewTestCluster — confirm where it is set.
BarrierKeys [][]byte
|
|
// CACert is the parsed self-signed CA certificate.
CACert *x509.Certificate
|
|
// CACertBytes is the DER encoding of CACert.
CACertBytes []byte
|
|
// CACertPEM is the PEM encoding of CACert.
CACertPEM []byte
|
|
// CACertPEMFile is the path of CACertPEM inside TempDir ("ca_cert.pem").
CACertPEMFile string
|
|
// CAKey is the ECDSA P-256 private key of the CA.
CAKey *ecdsa.PrivateKey
|
|
// CAKeyPEM is CAKey as an "EC PRIVATE KEY" PEM block.
CAKeyPEM []byte
|
|
// Cores are the cluster nodes; Start and Cleanup iterate over them.
Cores []*TestClusterCore
|
|
// ID presumably identifies the cluster. NOTE(review): not set in the
// visible code — confirm.
ID string
|
|
// RootToken presumably holds the cluster's root token. NOTE(review): not
// set in the visible code — confirm.
RootToken string
|
|
// RootCAs is a pool containing only CACert; the node TLS configs use it for
// both server and client verification.
RootCAs *x509.CertPool
|
|
// TempDir holds the written certificate/key files; Cleanup removes it.
TempDir string
|
|
}
|
|
|
|
func (t *TestCluster) Start() {
|
|
for _, core := range t.Cores {
|
|
if core.Server != nil {
|
|
for _, ln := range core.Listeners {
|
|
go core.Server.Serve(ln)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (t *TestCluster) Cleanup() {
|
|
for _, core := range t.Cores {
|
|
if core.Listeners != nil {
|
|
for _, ln := range core.Listeners {
|
|
ln.Close()
|
|
}
|
|
}
|
|
}
|
|
|
|
if t.TempDir != "" {
|
|
os.RemoveAll(t.TempDir)
|
|
}
|
|
|
|
// Give time to actually shut down/clean up before the next test
|
|
time.Sleep(time.Second)
|
|
}
|
|
|
|
// TestListener pairs a (TLS-wrapped, in NewTestCluster) net.Listener with the
// TCP address it is bound to.
type TestListener struct {
|
|
net.Listener
|
|
// Address is the bound address, taken from the underlying TCP listener's
// Addr().
Address *net.TCPAddr
|
|
}
|
|
|
|
// TestClusterCore is one node of a TestCluster: the embedded *Core plus the
// HTTP server, listeners, and TLS material used to serve it.
type TestClusterCore struct {
|
|
*Core
|
|
// Client is an API client for this node. NOTE(review): not set in the
// visible part of NewTestCluster — confirm.
Client *api.Client
|
|
// Handler is this node's HTTP handler. NOTE(review): presumably the same
// handler installed on Server — confirm.
Handler http.Handler
|
|
// Listeners are the listeners this node serves on; TestCluster.Start serves
// Server on each and TestCluster.Cleanup closes them.
Listeners []*TestListener
|
|
// ReloadFuncs and ReloadFuncsLock: NOTE(review): usage not visible here —
// presumably reload callbacks guarded by the lock; confirm.
ReloadFuncs *map[string][]reload.ReloadFunc
|
|
ReloadFuncsLock *sync.RWMutex
|
|
// Server is the HTTP server started by TestCluster.Start; nil skips serving.
Server *http.Server
|
|
// ServerCert, ServerCertBytes, ServerCertPEM, ServerKey, and ServerKeyPEM
// hold this node's certificate (parsed, DER, PEM) and private key (parsed,
// PEM).
ServerCert *x509.Certificate
|
|
ServerCertBytes []byte
|
|
ServerCertPEM []byte
|
|
ServerKey *ecdsa.PrivateKey
|
|
ServerKeyPEM []byte
|
|
// TLSConfig is this node's TLS configuration.
TLSConfig *tls.Config
|
|
}
|
|
|
|
// TestClusterOptions customizes NewTestCluster; a nil *TestClusterOptions
// uses the defaults.
type TestClusterOptions struct {
|
|
// KeepStandbysSealed: NOTE(review): not read in the visible part of
// NewTestCluster; presumably leaves non-active cores sealed — confirm.
KeepStandbysSealed bool
|
|
// HandlerFunc, when set, is called with each newly created core, and its
// result replaces that core's HTTP server handler.
HandlerFunc func(*Core) http.Handler
|
|
// BaseListenAddress, when set, is resolved as a TCP address. Its IP is added
// to the certificates' IP SANs, and the three nodes listen on its port,
// port+1, and port+2. When empty, nodes listen on 127.0.0.1 with
// OS-assigned ports.
BaseListenAddress string
|
|
}
|
|
|
|
func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *TestCluster {
|
|
certIPs := []net.IP{
|
|
net.IPv6loopback,
|
|
net.ParseIP("127.0.0.1"),
|
|
}
|
|
var baseAddr *net.TCPAddr
|
|
if opts != nil && opts.BaseListenAddress != "" {
|
|
var err error
|
|
baseAddr, err = net.ResolveTCPAddr("tcp", opts.BaseListenAddress)
|
|
if err != nil {
|
|
t.Fatal("could not parse given base IP")
|
|
}
|
|
certIPs = append(certIPs, baseAddr.IP)
|
|
}
|
|
|
|
var testCluster TestCluster
|
|
tempDir, err := ioutil.TempDir("", "vault-test-cluster-")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
testCluster.TempDir = tempDir
|
|
|
|
caKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
testCluster.CAKey = caKey
|
|
caCertTemplate := &x509.Certificate{
|
|
Subject: pkix.Name{
|
|
CommonName: "localhost",
|
|
},
|
|
DNSNames: []string{"localhost"},
|
|
IPAddresses: certIPs,
|
|
KeyUsage: x509.KeyUsage(x509.KeyUsageCertSign | x509.KeyUsageCRLSign),
|
|
SerialNumber: big.NewInt(mathrand.Int63()),
|
|
NotBefore: time.Now().Add(-30 * time.Second),
|
|
NotAfter: time.Now().Add(262980 * time.Hour),
|
|
BasicConstraintsValid: true,
|
|
IsCA: true,
|
|
}
|
|
caBytes, err := x509.CreateCertificate(rand.Reader, caCertTemplate, caCertTemplate, caKey.Public(), caKey)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
caCert, err := x509.ParseCertificate(caBytes)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
testCluster.CACert = caCert
|
|
testCluster.CACertBytes = caBytes
|
|
testCluster.RootCAs = x509.NewCertPool()
|
|
testCluster.RootCAs.AddCert(caCert)
|
|
caCertPEMBlock := &pem.Block{
|
|
Type: "CERTIFICATE",
|
|
Bytes: caBytes,
|
|
}
|
|
testCluster.CACertPEM = pem.EncodeToMemory(caCertPEMBlock)
|
|
testCluster.CACertPEMFile = filepath.Join(testCluster.TempDir, "ca_cert.pem")
|
|
err = ioutil.WriteFile(testCluster.CACertPEMFile, testCluster.CACertPEM, 0755)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
marshaledCAKey, err := x509.MarshalECPrivateKey(caKey)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
caKeyPEMBlock := &pem.Block{
|
|
Type: "EC PRIVATE KEY",
|
|
Bytes: marshaledCAKey,
|
|
}
|
|
testCluster.CAKeyPEM = pem.EncodeToMemory(caKeyPEMBlock)
|
|
err = ioutil.WriteFile(filepath.Join(testCluster.TempDir, "ca_key.pem"), testCluster.CAKeyPEM, 0755)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
s1Key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
s1CertTemplate := &x509.Certificate{
|
|
Subject: pkix.Name{
|
|
CommonName: "localhost",
|
|
},
|
|
DNSNames: []string{"localhost"},
|
|
IPAddresses: certIPs,
|
|
ExtKeyUsage: []x509.ExtKeyUsage{
|
|
x509.ExtKeyUsageServerAuth,
|
|
x509.ExtKeyUsageClientAuth,
|
|
},
|
|
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement,
|
|
SerialNumber: big.NewInt(mathrand.Int63()),
|
|
NotBefore: time.Now().Add(-30 * time.Second),
|
|
NotAfter: time.Now().Add(262980 * time.Hour),
|
|
}
|
|
s1CertBytes, err := x509.CreateCertificate(rand.Reader, s1CertTemplate, caCert, s1Key.Public(), caKey)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
s1Cert, err := x509.ParseCertificate(s1CertBytes)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
s1CertPEMBlock := &pem.Block{
|
|
Type: "CERTIFICATE",
|
|
Bytes: s1CertBytes,
|
|
}
|
|
s1CertPEM := pem.EncodeToMemory(s1CertPEMBlock)
|
|
s1MarshaledKey, err := x509.MarshalECPrivateKey(s1Key)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
s1KeyPEMBlock := &pem.Block{
|
|
Type: "EC PRIVATE KEY",
|
|
Bytes: s1MarshaledKey,
|
|
}
|
|
s1KeyPEM := pem.EncodeToMemory(s1KeyPEMBlock)
|
|
|
|
s2Key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
s2CertTemplate := &x509.Certificate{
|
|
Subject: pkix.Name{
|
|
CommonName: "localhost",
|
|
},
|
|
DNSNames: []string{"localhost"},
|
|
IPAddresses: certIPs,
|
|
ExtKeyUsage: []x509.ExtKeyUsage{
|
|
x509.ExtKeyUsageServerAuth,
|
|
x509.ExtKeyUsageClientAuth,
|
|
},
|
|
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement,
|
|
SerialNumber: big.NewInt(mathrand.Int63()),
|
|
NotBefore: time.Now().Add(-30 * time.Second),
|
|
NotAfter: time.Now().Add(262980 * time.Hour),
|
|
}
|
|
s2CertBytes, err := x509.CreateCertificate(rand.Reader, s2CertTemplate, caCert, s2Key.Public(), caKey)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
s2Cert, err := x509.ParseCertificate(s2CertBytes)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
s2CertPEMBlock := &pem.Block{
|
|
Type: "CERTIFICATE",
|
|
Bytes: s2CertBytes,
|
|
}
|
|
s2CertPEM := pem.EncodeToMemory(s2CertPEMBlock)
|
|
s2MarshaledKey, err := x509.MarshalECPrivateKey(s2Key)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
s2KeyPEMBlock := &pem.Block{
|
|
Type: "EC PRIVATE KEY",
|
|
Bytes: s2MarshaledKey,
|
|
}
|
|
s2KeyPEM := pem.EncodeToMemory(s2KeyPEMBlock)
|
|
|
|
s3Key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
s3CertTemplate := &x509.Certificate{
|
|
Subject: pkix.Name{
|
|
CommonName: "localhost",
|
|
},
|
|
DNSNames: []string{"localhost"},
|
|
IPAddresses: certIPs,
|
|
ExtKeyUsage: []x509.ExtKeyUsage{
|
|
x509.ExtKeyUsageServerAuth,
|
|
x509.ExtKeyUsageClientAuth,
|
|
},
|
|
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement,
|
|
SerialNumber: big.NewInt(mathrand.Int63()),
|
|
NotBefore: time.Now().Add(-30 * time.Second),
|
|
NotAfter: time.Now().Add(262980 * time.Hour),
|
|
}
|
|
s3CertBytes, err := x509.CreateCertificate(rand.Reader, s3CertTemplate, caCert, s3Key.Public(), caKey)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
s3Cert, err := x509.ParseCertificate(s3CertBytes)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
s3CertPEMBlock := &pem.Block{
|
|
Type: "CERTIFICATE",
|
|
Bytes: s3CertBytes,
|
|
}
|
|
s3CertPEM := pem.EncodeToMemory(s3CertPEMBlock)
|
|
s3MarshaledKey, err := x509.MarshalECPrivateKey(s3Key)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
s3KeyPEMBlock := &pem.Block{
|
|
Type: "EC PRIVATE KEY",
|
|
Bytes: s3MarshaledKey,
|
|
}
|
|
s3KeyPEM := pem.EncodeToMemory(s3KeyPEMBlock)
|
|
|
|
logger := logformat.NewVaultLogger(log.LevelTrace)
|
|
|
|
//
|
|
// Listener setup
|
|
//
|
|
ports := []int{0, 0, 0}
|
|
if baseAddr != nil {
|
|
ports = []int{baseAddr.Port, baseAddr.Port + 1, baseAddr.Port + 2}
|
|
} else {
|
|
baseAddr = &net.TCPAddr{
|
|
IP: net.ParseIP("127.0.0.1"),
|
|
Port: 0,
|
|
}
|
|
}
|
|
baseAddr.Port = ports[0]
|
|
ln, err := net.ListenTCP("tcp", baseAddr)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
s1CertFile := filepath.Join(testCluster.TempDir, fmt.Sprintf("node1_port_%d_cert.pem", ln.Addr().(*net.TCPAddr).Port))
|
|
s1KeyFile := filepath.Join(testCluster.TempDir, fmt.Sprintf("node1_port_%d_key.pem", ln.Addr().(*net.TCPAddr).Port))
|
|
err = ioutil.WriteFile(s1CertFile, s1CertPEM, 0755)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
err = ioutil.WriteFile(s1KeyFile, s1KeyPEM, 0755)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
s1TLSCert, err := tls.X509KeyPair(s1CertPEM, s1KeyPEM)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
s1CertGetter := reload.NewCertificateGetter(s1CertFile, s1KeyFile)
|
|
s1TLSConfig := &tls.Config{
|
|
Certificates: []tls.Certificate{s1TLSCert},
|
|
RootCAs: testCluster.RootCAs,
|
|
ClientCAs: testCluster.RootCAs,
|
|
ClientAuth: tls.VerifyClientCertIfGiven,
|
|
NextProtos: []string{"h2", "http/1.1"},
|
|
GetCertificate: s1CertGetter.GetCertificate,
|
|
}
|
|
s1TLSConfig.BuildNameToCertificate()
|
|
c1lns := []*TestListener{&TestListener{
|
|
Listener: tls.NewListener(ln, s1TLSConfig),
|
|
Address: ln.Addr().(*net.TCPAddr),
|
|
},
|
|
}
|
|
var handler1 http.Handler = http.NewServeMux()
|
|
server1 := &http.Server{
|
|
Handler: handler1,
|
|
}
|
|
if err := http2.ConfigureServer(server1, nil); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
baseAddr.Port = ports[1]
|
|
ln, err = net.ListenTCP("tcp", baseAddr)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
s2CertFile := filepath.Join(testCluster.TempDir, fmt.Sprintf("node2_port_%d_cert.pem", ln.Addr().(*net.TCPAddr).Port))
|
|
s2KeyFile := filepath.Join(testCluster.TempDir, fmt.Sprintf("node2_port_%d_key.pem", ln.Addr().(*net.TCPAddr).Port))
|
|
err = ioutil.WriteFile(s2CertFile, s2CertPEM, 0755)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
err = ioutil.WriteFile(s2KeyFile, s2KeyPEM, 0755)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
s2TLSCert, err := tls.X509KeyPair(s2CertPEM, s2KeyPEM)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
s2CertGetter := reload.NewCertificateGetter(s2CertFile, s2KeyFile)
|
|
s2TLSConfig := &tls.Config{
|
|
Certificates: []tls.Certificate{s2TLSCert},
|
|
RootCAs: testCluster.RootCAs,
|
|
ClientCAs: testCluster.RootCAs,
|
|
ClientAuth: tls.VerifyClientCertIfGiven,
|
|
NextProtos: []string{"h2", "http/1.1"},
|
|
GetCertificate: s2CertGetter.GetCertificate,
|
|
}
|
|
s2TLSConfig.BuildNameToCertificate()
|
|
c2lns := []*TestListener{&TestListener{
|
|
Listener: tls.NewListener(ln, s2TLSConfig),
|
|
Address: ln.Addr().(*net.TCPAddr),
|
|
},
|
|
}
|
|
var handler2 http.Handler = http.NewServeMux()
|
|
server2 := &http.Server{
|
|
Handler: handler2,
|
|
}
|
|
if err := http2.ConfigureServer(server2, nil); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
baseAddr.Port = ports[2]
|
|
ln, err = net.ListenTCP("tcp", baseAddr)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
s3CertFile := filepath.Join(testCluster.TempDir, fmt.Sprintf("node3_port_%d_cert.pem", ln.Addr().(*net.TCPAddr).Port))
|
|
s3KeyFile := filepath.Join(testCluster.TempDir, fmt.Sprintf("node3_port_%d_key.pem", ln.Addr().(*net.TCPAddr).Port))
|
|
err = ioutil.WriteFile(s3CertFile, s3CertPEM, 0755)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
err = ioutil.WriteFile(s3KeyFile, s3KeyPEM, 0755)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
s3TLSCert, err := tls.X509KeyPair(s3CertPEM, s3KeyPEM)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
s3CertGetter := reload.NewCertificateGetter(s3CertFile, s3KeyFile)
|
|
s3TLSConfig := &tls.Config{
|
|
Certificates: []tls.Certificate{s3TLSCert},
|
|
RootCAs: testCluster.RootCAs,
|
|
ClientCAs: testCluster.RootCAs,
|
|
ClientAuth: tls.VerifyClientCertIfGiven,
|
|
NextProtos: []string{"h2", "http/1.1"},
|
|
GetCertificate: s3CertGetter.GetCertificate,
|
|
}
|
|
s3TLSConfig.BuildNameToCertificate()
|
|
c3lns := []*TestListener{&TestListener{
|
|
Listener: tls.NewListener(ln, s3TLSConfig),
|
|
Address: ln.Addr().(*net.TCPAddr),
|
|
},
|
|
}
|
|
var handler3 http.Handler = http.NewServeMux()
|
|
server3 := &http.Server{
|
|
Handler: handler3,
|
|
}
|
|
if err := http2.ConfigureServer(server3, nil); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
// Create three cores with the same physical and different redirect/cluster
|
|
// addrs.
|
|
// N.B.: On OSX, instead of random ports, it assigns new ports to new
|
|
// listeners sequentially. Aside from being a bad idea in a security sense,
|
|
// it also broke tests that assumed it was OK to just use the port above
|
|
// the redirect addr. This has now been changed to 105 ports above, but if
|
|
// we ever do more than three nodes in a cluster it may need to be bumped.
|
|
// Note: it's 105 so that we don't conflict with a running Consul by
|
|
// default.
|
|
coreConfig := &CoreConfig{
|
|
LogicalBackends: make(map[string]logical.Factory),
|
|
CredentialBackends: make(map[string]logical.Factory),
|
|
AuditBackends: make(map[string]audit.Factory),
|
|
RedirectAddr: fmt.Sprintf("https://127.0.0.1:%d", c1lns[0].Address.Port),
|
|
ClusterAddr: fmt.Sprintf("https://127.0.0.1:%d", c1lns[0].Address.Port+105),
|
|
DisableMlock: true,
|
|
EnableUI: true,
|
|
}
|
|
|
|
if base != nil {
|
|
coreConfig.DisableCache = base.DisableCache
|
|
coreConfig.EnableUI = base.EnableUI
|
|
coreConfig.DefaultLeaseTTL = base.DefaultLeaseTTL
|
|
coreConfig.MaxLeaseTTL = base.MaxLeaseTTL
|
|
coreConfig.CacheSize = base.CacheSize
|
|
coreConfig.PluginDirectory = base.PluginDirectory
|
|
coreConfig.Seal = base.Seal
|
|
coreConfig.DevToken = base.DevToken
|
|
|
|
if !coreConfig.DisableMlock {
|
|
base.DisableMlock = false
|
|
}
|
|
|
|
if base.Physical != nil {
|
|
coreConfig.Physical = base.Physical
|
|
}
|
|
|
|
if base.HAPhysical != nil {
|
|
coreConfig.HAPhysical = base.HAPhysical
|
|
}
|
|
|
|
// Used to set something non-working to test fallback
|
|
switch base.ClusterAddr {
|
|
case "empty":
|
|
coreConfig.ClusterAddr = ""
|
|
case "":
|
|
default:
|
|
coreConfig.ClusterAddr = base.ClusterAddr
|
|
}
|
|
|
|
if base.LogicalBackends != nil {
|
|
for k, v := range base.LogicalBackends {
|
|
coreConfig.LogicalBackends[k] = v
|
|
}
|
|
}
|
|
if base.CredentialBackends != nil {
|
|
for k, v := range base.CredentialBackends {
|
|
coreConfig.CredentialBackends[k] = v
|
|
}
|
|
}
|
|
if base.AuditBackends != nil {
|
|
for k, v := range base.AuditBackends {
|
|
coreConfig.AuditBackends[k] = v
|
|
}
|
|
}
|
|
if base.Logger != nil {
|
|
coreConfig.Logger = base.Logger
|
|
}
|
|
|
|
coreConfig.DisableCache = base.DisableCache
|
|
|
|
coreConfig.DevToken = base.DevToken
|
|
}
|
|
|
|
if coreConfig.Physical == nil {
|
|
coreConfig.Physical, err = physInmem.NewInmem(nil, logger)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}
|
|
if coreConfig.HAPhysical == nil {
|
|
haPhys, err := physInmem.NewInmemHA(nil, logger)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
coreConfig.HAPhysical = haPhys.(physical.HABackend)
|
|
}
|
|
|
|
c1, err := NewCore(coreConfig)
|
|
if err != nil {
|
|
t.Fatalf("err: %v", err)
|
|
}
|
|
if opts != nil && opts.HandlerFunc != nil {
|
|
handler1 = opts.HandlerFunc(c1)
|
|
server1.Handler = handler1
|
|
}
|
|
|
|
coreConfig.RedirectAddr = fmt.Sprintf("https://127.0.0.1:%d", c2lns[0].Address.Port)
|
|
if coreConfig.ClusterAddr != "" {
|
|
coreConfig.ClusterAddr = fmt.Sprintf("https://127.0.0.1:%d", c2lns[0].Address.Port+105)
|
|
}
|
|
c2, err := NewCore(coreConfig)
|
|
if err != nil {
|
|
t.Fatalf("err: %v", err)
|
|
}
|
|
if opts != nil && opts.HandlerFunc != nil {
|
|
handler2 = opts.HandlerFunc(c2)
|
|
server2.Handler = handler2
|
|
}
|
|
|
|
coreConfig.RedirectAddr = fmt.Sprintf("https://127.0.0.1:%d", c3lns[0].Address.Port)
|
|
if coreConfig.ClusterAddr != "" {
|
|
coreConfig.ClusterAddr = fmt.Sprintf("https://127.0.0.1:%d", c3lns[0].Address.Port+105)
|
|
}
|
|
c3, err := NewCore(coreConfig)
|
|
if err != nil {
|
|
t.Fatalf("err: %v", err)
|
|
}
|
|
if opts != nil && opts.HandlerFunc != nil {
|
|
handler3 = opts.HandlerFunc(c3)
|
|
server3.Handler = handler3
|
|
}
|
|
|
|
//
|
|
// Clustering setup
|
|
//
|
|
clusterAddrGen := func(lns []*TestListener) []*net.TCPAddr {
|
|
ret := make([]*net.TCPAddr, len(lns))
|
|
for i, ln := range lns {
|
|
ret[i] = &net.TCPAddr{
|
|
IP: ln.Address.IP,
|
|
Port: ln.Address.Port + 105,
|
|
}
|
|
}
|
|
return ret
|
|
}
|
|
|
|
c2.SetClusterListenerAddrs(clusterAddrGen(c2lns))
|
|
c2.SetClusterHandler(handler2)
|
|
c3.SetClusterListenerAddrs(clusterAddrGen(c3lns))
|
|
c3.SetClusterHandler(handler3)
|
|
|
|
keys, root := TestCoreInitClusterWrapperSetup(t, c1, clusterAddrGen(c1lns), handler1)
|
|
barrierKeys, _ := copystructure.Copy(keys)
|
|
testCluster.BarrierKeys = barrierKeys.([][]byte)
|
|
testCluster.RootToken = root
|
|
|
|
err = ioutil.WriteFile(filepath.Join(testCluster.TempDir, "root_token"), []byte(root), 0755)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
var buf bytes.Buffer
|
|
for i, key := range testCluster.BarrierKeys {
|
|
buf.Write([]byte(base64.StdEncoding.EncodeToString(key)))
|
|
if i < len(testCluster.BarrierKeys)-1 {
|
|
buf.WriteRune('\n')
|
|
}
|
|
}
|
|
err = ioutil.WriteFile(filepath.Join(testCluster.TempDir, "barrier_keys"), buf.Bytes(), 0755)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
for _, key := range keys {
|
|
if _, err := c1.Unseal(TestKeyCopy(key)); err != nil {
|
|
t.Fatalf("unseal err: %s", err)
|
|
}
|
|
}
|
|
|
|
// Verify unsealed
|
|
sealed, err := c1.Sealed()
|
|
if err != nil {
|
|
t.Fatalf("err checking seal status: %s", err)
|
|
}
|
|
if sealed {
|
|
t.Fatal("should not be sealed")
|
|
}
|
|
|
|
TestWaitActive(t, c1)
|
|
|
|
if opts == nil || !opts.KeepStandbysSealed {
|
|
for _, key := range keys {
|
|
if _, err := c2.Unseal(TestKeyCopy(key)); err != nil {
|
|
t.Fatalf("unseal err: %s", err)
|
|
}
|
|
}
|
|
for _, key := range keys {
|
|
if _, err := c3.Unseal(TestKeyCopy(key)); err != nil {
|
|
t.Fatalf("unseal err: %s", err)
|
|
}
|
|
}
|
|
|
|
// Let them come fully up to standby
|
|
time.Sleep(2 * time.Second)
|
|
|
|
// Ensure cluster connection info is populated
|
|
isLeader, _, _, err := c2.Leader()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if isLeader {
|
|
t.Fatal("c2 should not be leader")
|
|
}
|
|
isLeader, _, _, err = c3.Leader()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if isLeader {
|
|
t.Fatal("c3 should not be leader")
|
|
}
|
|
}
|
|
|
|
cluster, err := c1.Cluster()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
testCluster.ID = cluster.ID
|
|
|
|
getAPIClient := func(port int, tlsConfig *tls.Config) *api.Client {
|
|
transport := cleanhttp.DefaultPooledTransport()
|
|
transport.TLSClientConfig = tlsConfig
|
|
client := &http.Client{
|
|
Transport: transport,
|
|
CheckRedirect: func(*http.Request, []*http.Request) error {
|
|
// This can of course be overridden per-test by using its own client
|
|
return fmt.Errorf("redirects not allowed in these tests")
|
|
},
|
|
}
|
|
config := api.DefaultConfig()
|
|
config.Address = fmt.Sprintf("https://127.0.0.1:%d", port)
|
|
config.HttpClient = client
|
|
apiClient, err := api.NewClient(config)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
apiClient.SetToken(root)
|
|
return apiClient
|
|
}
|
|
|
|
var ret []*TestClusterCore
|
|
t1 := &TestClusterCore{
|
|
Core: c1,
|
|
ServerKey: s1Key,
|
|
ServerKeyPEM: s1KeyPEM,
|
|
ServerCert: s1Cert,
|
|
ServerCertBytes: s1CertBytes,
|
|
ServerCertPEM: s1CertPEM,
|
|
Listeners: c1lns,
|
|
Handler: handler1,
|
|
Server: server1,
|
|
TLSConfig: s1TLSConfig,
|
|
Client: getAPIClient(c1lns[0].Address.Port, s1TLSConfig),
|
|
}
|
|
t1.ReloadFuncs = &c1.reloadFuncs
|
|
t1.ReloadFuncsLock = &c1.reloadFuncsLock
|
|
t1.ReloadFuncsLock.Lock()
|
|
(*t1.ReloadFuncs)["listener|tcp"] = []reload.ReloadFunc{s1CertGetter.Reload}
|
|
t1.ReloadFuncsLock.Unlock()
|
|
ret = append(ret, t1)
|
|
|
|
t2 := &TestClusterCore{
|
|
Core: c2,
|
|
ServerKey: s2Key,
|
|
ServerKeyPEM: s2KeyPEM,
|
|
ServerCert: s2Cert,
|
|
ServerCertBytes: s2CertBytes,
|
|
ServerCertPEM: s2CertPEM,
|
|
Listeners: c2lns,
|
|
Handler: handler2,
|
|
Server: server2,
|
|
TLSConfig: s2TLSConfig,
|
|
Client: getAPIClient(c2lns[0].Address.Port, s2TLSConfig),
|
|
}
|
|
t2.ReloadFuncs = &c2.reloadFuncs
|
|
t2.ReloadFuncsLock = &c2.reloadFuncsLock
|
|
t2.ReloadFuncsLock.Lock()
|
|
(*t2.ReloadFuncs)["listener|tcp"] = []reload.ReloadFunc{s2CertGetter.Reload}
|
|
t2.ReloadFuncsLock.Unlock()
|
|
ret = append(ret, t2)
|
|
|
|
t3 := &TestClusterCore{
|
|
Core: c3,
|
|
ServerKey: s3Key,
|
|
ServerKeyPEM: s3KeyPEM,
|
|
ServerCert: s3Cert,
|
|
ServerCertBytes: s3CertBytes,
|
|
ServerCertPEM: s3CertPEM,
|
|
Listeners: c3lns,
|
|
Handler: handler3,
|
|
Server: server3,
|
|
TLSConfig: s3TLSConfig,
|
|
Client: getAPIClient(c3lns[0].Address.Port, s3TLSConfig),
|
|
}
|
|
t3.ReloadFuncs = &c3.reloadFuncs
|
|
t3.ReloadFuncsLock = &c3.reloadFuncsLock
|
|
t3.ReloadFuncsLock.Lock()
|
|
(*t3.ReloadFuncs)["listener|tcp"] = []reload.ReloadFunc{s3CertGetter.Reload}
|
|
t3.ReloadFuncsLock.Unlock()
|
|
ret = append(ret, t3)
|
|
|
|
testCluster.Cores = ret
|
|
return &testCluster
|
|
}
|