Lazy-load plugin mounts (#3255)
* Lazy load plugins to avoid setup-unwrap cycle * Remove commented blocks * Refactor NewTestCluster, use single core cluster on basic plugin tests * Set c.pluginDirectory in TestAddTestPlugin for setupPluginCatalog to work properly * Add special path to mock plugin * Move ensureCoresSealed to vault/testing.go * Use same method for EnsureCoresSealed and Cleanup * Bump ensureCoresSealed timeout to 60s * Correctly handle nil opts on NewTestCluster * Add metadata flag to APIClientMeta, use meta-enabled plugin when mounting to bootstrap * Check metadata flag directly on the plugin process * Plumb isMetadataMode down to PluginRunner * Add NOOP shims when running in metadata mode * Remove unused flag from the APIMetadata object * Remove setupSecretPlugins and setupCredentialPlugins functions * Move when we setup rollback manager to after the plugins are initialized * Fix tests * Fix merge issue * start rollback manager after the credential setup * Add guards against running certain client and server functions while in metadata mode * Call initialize once a plugin is loaded on the fly * Add more tests, update basic secret/auth plugin tests to trigger lazy loading * Skip mount if plugin removed from catalog * Fixup * Remove commented line on LookupPlugin * Fail on mount operation if plugin is re-added to catalog and mount is on existing path * Check type and special paths on startBackend * Fix merge conflicts * Refactor PluginRunner run methods to use runCommon, fix TestSystemBackend_Plugin_auth
This commit is contained in:
parent
590e2de328
commit
a581e96b78
|
@ -3,13 +3,20 @@ package plugin
|
|||
import (
|
||||
"fmt"
|
||||
"net/rpc"
|
||||
"reflect"
|
||||
"sync"
|
||||
|
||||
uuid "github.com/hashicorp/go-uuid"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
bplugin "github.com/hashicorp/vault/logical/plugin"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrMismatchType = fmt.Errorf("mismatch on mounted backend and plugin backend type")
|
||||
ErrMismatchPaths = fmt.Errorf("mismatch on mounted backend and plugin backend special paths")
|
||||
)
|
||||
|
||||
// Factory returns a configured plugin logical.Backend.
|
||||
func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
|
||||
_, ok := conf.Config["plugin_name"]
|
||||
|
@ -31,14 +38,33 @@ func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
|
|||
// or as a concrete implementation if builtin, casted as logical.Backend.
|
||||
func Backend(conf *logical.BackendConfig) (logical.Backend, error) {
|
||||
var b backend
|
||||
|
||||
name := conf.Config["plugin_name"]
|
||||
sys := conf.System
|
||||
|
||||
raw, err := bplugin.NewBackend(name, sys, conf.Logger)
|
||||
// NewBackend with isMetadataMode set to true
|
||||
raw, err := bplugin.NewBackend(name, sys, conf.Logger, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
b.Backend = raw
|
||||
err = raw.Setup(conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Get SpecialPaths and BackendType
|
||||
paths := raw.SpecialPaths()
|
||||
btype := raw.Type()
|
||||
|
||||
// Cleanup meta plugin backend
|
||||
raw.Cleanup()
|
||||
|
||||
// Initialize b.Backend with dummy backend since plugin
|
||||
// backends will need to be lazy loaded.
|
||||
b.Backend = &framework.Backend{
|
||||
PathsSpecial: paths,
|
||||
BackendType: btype,
|
||||
}
|
||||
|
||||
b.config = conf
|
||||
|
||||
return &b, nil
|
||||
|
@ -53,16 +79,24 @@ type backend struct {
|
|||
|
||||
// Used to detect if we already reloaded
|
||||
canary string
|
||||
|
||||
// Used to detect if plugin is set
|
||||
loaded bool
|
||||
}
|
||||
|
||||
func (b *backend) reloadBackend() error {
|
||||
b.Logger().Trace("plugin: reloading plugin backend", "plugin", b.config.Config["plugin_name"])
|
||||
return b.startBackend()
|
||||
}
|
||||
|
||||
// startBackend starts a plugin backend
|
||||
func (b *backend) startBackend() error {
|
||||
pluginName := b.config.Config["plugin_name"]
|
||||
b.Logger().Trace("plugin: reloading plugin backend", "plugin", pluginName)
|
||||
|
||||
// Ensure proper cleanup of the backend (i.e. call client.Kill())
|
||||
b.Backend.Cleanup()
|
||||
|
||||
nb, err := bplugin.NewBackend(pluginName, b.config.System, b.config.Logger)
|
||||
nb, err := bplugin.NewBackend(pluginName, b.config.System, b.config.Logger, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -70,7 +104,29 @@ func (b *backend) reloadBackend() error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If the backend has not been loaded (i.e. still in metadata mode),
|
||||
// check if type and special paths still matches
|
||||
if !b.loaded {
|
||||
if b.Backend.Type() != nb.Type() {
|
||||
nb.Cleanup()
|
||||
b.Logger().Warn("plugin: failed to start plugin process", "plugin", b.config.Config["plugin_name"], "error", ErrMismatchType)
|
||||
return ErrMismatchType
|
||||
}
|
||||
if !reflect.DeepEqual(b.Backend.SpecialPaths(), nb.SpecialPaths()) {
|
||||
nb.Cleanup()
|
||||
b.Logger().Warn("plugin: failed to start plugin process", "plugin", b.config.Config["plugin_name"], "error", ErrMismatchPaths)
|
||||
return ErrMismatchPaths
|
||||
}
|
||||
}
|
||||
|
||||
b.Backend = nb
|
||||
b.loaded = true
|
||||
|
||||
// Call initialize
|
||||
if err := b.Backend.Initialize(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -79,6 +135,23 @@ func (b *backend) reloadBackend() error {
|
|||
func (b *backend) HandleRequest(req *logical.Request) (*logical.Response, error) {
|
||||
b.RLock()
|
||||
canary := b.canary
|
||||
|
||||
// Lazy-load backend
|
||||
if !b.loaded {
|
||||
// Upgrade lock
|
||||
b.RUnlock()
|
||||
b.Lock()
|
||||
// Check once more after lock swap
|
||||
if !b.loaded {
|
||||
err := b.startBackend()
|
||||
if err != nil {
|
||||
b.Unlock()
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
b.Unlock()
|
||||
b.RLock()
|
||||
}
|
||||
resp, err := b.Backend.HandleRequest(req)
|
||||
b.RUnlock()
|
||||
// Need to compare string value for case were err comes from plugin RPC
|
||||
|
@ -112,6 +185,24 @@ func (b *backend) HandleRequest(req *logical.Request) (*logical.Response, error)
|
|||
func (b *backend) HandleExistenceCheck(req *logical.Request) (bool, bool, error) {
|
||||
b.RLock()
|
||||
canary := b.canary
|
||||
|
||||
// Lazy-load backend
|
||||
if !b.loaded {
|
||||
// Upgrade lock
|
||||
b.RUnlock()
|
||||
b.Lock()
|
||||
// Check once more after lock swap
|
||||
if !b.loaded {
|
||||
err := b.startBackend()
|
||||
if err != nil {
|
||||
b.Unlock()
|
||||
return false, false, err
|
||||
}
|
||||
}
|
||||
b.Unlock()
|
||||
b.RLock()
|
||||
}
|
||||
|
||||
checkFound, exists, err := b.Backend.HandleExistenceCheck(req)
|
||||
b.RUnlock()
|
||||
if err != nil && err.Error() == rpc.ErrShutdown.Error() {
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package plugin
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
|
@ -39,7 +40,8 @@ func TestBackend_Factory(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestBackend_PluginMain(t *testing.T) {
|
||||
if os.Getenv(pluginutil.PluginUnwrapTokenEnv) == "" {
|
||||
args := []string{}
|
||||
if os.Getenv(pluginutil.PluginUnwrapTokenEnv) == "" && os.Getenv(pluginutil.PluginMetadaModeEnv) != "true" {
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -48,7 +50,7 @@ func TestBackend_PluginMain(t *testing.T) {
|
|||
t.Fatal("CA cert not passed in")
|
||||
}
|
||||
|
||||
args := []string{"--ca-cert=" + caPEM}
|
||||
args = append(args, fmt.Sprintf("--ca-cert=%s", caPEM))
|
||||
|
||||
apiClientMeta := &pluginutil.APIClientMeta{}
|
||||
flags := apiClientMeta.FlagSet()
|
||||
|
|
|
@ -7,7 +7,7 @@ import (
|
|||
)
|
||||
|
||||
var (
|
||||
// PluginUnwrapTokenEnv is the ENV name used to pass the configuration for
|
||||
// PluginMlockEnabled is the ENV name used to pass the configuration for
|
||||
// enabling mlock
|
||||
PluginMlockEnabled = "VAULT_PLUGIN_MLOCK_ENABLED"
|
||||
)
|
||||
|
|
|
@ -2,6 +2,7 @@ package pluginutil
|
|||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"crypto/tls"
|
||||
"flag"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
|
@ -22,6 +23,7 @@ type Looker interface {
|
|||
// Wrapper interface defines the functions needed by the runner to wrap the
|
||||
// metadata needed to run a plugin process. This includes looking up Mlock
|
||||
// configuration and wrapping data in a respose wrapped token.
|
||||
// logical.SystemView implementataions satisfy this interface.
|
||||
type RunnerUtil interface {
|
||||
ResponseWrapData(data map[string]interface{}, ttl time.Duration, jwt bool) (*wrapping.ResponseWrapInfo, error)
|
||||
MlockEnabled() bool
|
||||
|
@ -44,56 +46,82 @@ type PluginRunner struct {
|
|||
BuiltinFactory func() (interface{}, error) `json:"-" structs:"-"`
|
||||
}
|
||||
|
||||
// Run takes a wrapper instance, and the go-plugin paramaters and executes a
|
||||
// plugin.
|
||||
// Run takes a wrapper RunnerUtil instance along with the go-plugin paramaters and
|
||||
// returns a configured plugin.Client with TLS Configured and a wrapping token set
|
||||
// on PluginUnwrapTokenEnv for plugin process consumption.
|
||||
func (r *PluginRunner) Run(wrapper RunnerUtil, pluginMap map[string]plugin.Plugin, hs plugin.HandshakeConfig, env []string, logger log.Logger) (*plugin.Client, error) {
|
||||
// Get a CA TLS Certificate
|
||||
certBytes, key, err := generateCert()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return r.runCommon(wrapper, pluginMap, hs, env, logger, false)
|
||||
}
|
||||
|
||||
// Use CA to sign a client cert and return a configured TLS config
|
||||
clientTLSConfig, err := createClientTLSConfig(certBytes, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// RunMetadataMode returns a configured plugin.Client that will dispense a plugin
|
||||
// in metadata mode. The PluginMetadaModeEnv is passed in as part of the Cmd to
|
||||
// plugin.Client, and consumed by the plugin process on pluginutil.VaultPluginTLSProvider.
|
||||
func (r *PluginRunner) RunMetadataMode(wrapper RunnerUtil, pluginMap map[string]plugin.Plugin, hs plugin.HandshakeConfig, env []string, logger log.Logger) (*plugin.Client, error) {
|
||||
return r.runCommon(wrapper, pluginMap, hs, env, logger, true)
|
||||
|
||||
// Use CA to sign a server cert and wrap the values in a response wrapped
|
||||
// token.
|
||||
wrapToken, err := wrapServerConfig(wrapper, certBytes, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
func (r *PluginRunner) runCommon(wrapper RunnerUtil, pluginMap map[string]plugin.Plugin, hs plugin.HandshakeConfig, env []string, logger log.Logger, isMetadataMode bool) (*plugin.Client, error) {
|
||||
cmd := exec.Command(r.Command, r.Args...)
|
||||
cmd.Env = append(cmd.Env, env...)
|
||||
// Add the response wrap token to the ENV of the plugin
|
||||
cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", PluginUnwrapTokenEnv, wrapToken))
|
||||
|
||||
// Add the mlock setting to the ENV of the plugin
|
||||
if wrapper.MlockEnabled() {
|
||||
cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", PluginMlockEnabled, "true"))
|
||||
}
|
||||
|
||||
secureConfig := &plugin.SecureConfig{
|
||||
Checksum: r.Sha256,
|
||||
Hash: sha256.New(),
|
||||
}
|
||||
|
||||
// Create logger for the plugin client
|
||||
clogger := &hclogFaker{
|
||||
logger: logger,
|
||||
}
|
||||
namedLogger := clogger.ResetNamed("plugin")
|
||||
|
||||
client := plugin.NewClient(&plugin.ClientConfig{
|
||||
var clientTLSConfig *tls.Config
|
||||
if !isMetadataMode {
|
||||
// Add the metadata mode ENV and set it to false
|
||||
cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", PluginMetadaModeEnv, "false"))
|
||||
|
||||
// Get a CA TLS Certificate
|
||||
certBytes, key, err := generateCert()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Use CA to sign a client cert and return a configured TLS config
|
||||
clientTLSConfig, err = createClientTLSConfig(certBytes, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Use CA to sign a server cert and wrap the values in a response wrapped
|
||||
// token.
|
||||
wrapToken, err := wrapServerConfig(wrapper, certBytes, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Add the response wrap token to the ENV of the plugin
|
||||
cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", PluginUnwrapTokenEnv, wrapToken))
|
||||
} else {
|
||||
namedLogger = clogger.ResetNamed("plugin.metadata")
|
||||
cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", PluginMetadaModeEnv, "true"))
|
||||
}
|
||||
|
||||
secureConfig := &plugin.SecureConfig{
|
||||
Checksum: r.Sha256,
|
||||
Hash: sha256.New(),
|
||||
}
|
||||
|
||||
clientConfig := &plugin.ClientConfig{
|
||||
HandshakeConfig: hs,
|
||||
Plugins: pluginMap,
|
||||
Cmd: cmd,
|
||||
TLSConfig: clientTLSConfig,
|
||||
SecureConfig: secureConfig,
|
||||
TLSConfig: clientTLSConfig,
|
||||
Logger: namedLogger,
|
||||
})
|
||||
}
|
||||
|
||||
client := plugin.NewClient(clientConfig)
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
@ -108,7 +136,7 @@ type APIClientMeta struct {
|
|||
}
|
||||
|
||||
func (f *APIClientMeta) FlagSet() *flag.FlagSet {
|
||||
fs := flag.NewFlagSet("tls settings", flag.ContinueOnError)
|
||||
fs := flag.NewFlagSet("vault plugin settings", flag.ContinueOnError)
|
||||
|
||||
fs.StringVar(&f.flagCACert, "ca-cert", "", "")
|
||||
fs.StringVar(&f.flagCAPath, "ca-path", "", "")
|
||||
|
|
|
@ -29,6 +29,10 @@ var (
|
|||
// PluginCACertPEMEnv is an ENV name used for holding a CA PEM-encoded
|
||||
// string. Used for testing.
|
||||
PluginCACertPEMEnv = "VAULT_TESTING_PLUGIN_CA_PEM"
|
||||
|
||||
// PluginMetadaModeEnv is an ENV name used to disable TLS communication
|
||||
// to bootstrap mounting plugins.
|
||||
PluginMetadaModeEnv = "VAULT_PLUGIN_METADATA_MODE"
|
||||
)
|
||||
|
||||
// generateCert is used internally to create certificates for the plugin
|
||||
|
@ -124,6 +128,10 @@ func wrapServerConfig(sys RunnerUtil, certBytes []byte, key *ecdsa.PrivateKey) (
|
|||
// VaultPluginTLSProvider is run inside a plugin and retrives the response
|
||||
// wrapped TLS certificate from vault. It returns a configured TLS Config.
|
||||
func VaultPluginTLSProvider(apiTLSConfig *api.TLSConfig) func() (*tls.Config, error) {
|
||||
if os.Getenv(PluginMetadaModeEnv) == "true" {
|
||||
return nil
|
||||
}
|
||||
|
||||
return func() (*tls.Config, error) {
|
||||
unwrapToken := os.Getenv(PluginUnwrapTokenEnv)
|
||||
|
||||
|
@ -157,7 +165,10 @@ func VaultPluginTLSProvider(apiTLSConfig *api.TLSConfig) func() (*tls.Config, er
|
|||
clientConf := api.DefaultConfig()
|
||||
clientConf.Address = vaultAddr
|
||||
if apiTLSConfig != nil {
|
||||
clientConf.ConfigureTLS(apiTLSConfig)
|
||||
err := clientConf.ConfigureTLS(apiTLSConfig)
|
||||
if err != nil {
|
||||
return nil, errwrap.Wrapf("error configuring api client {{err}}", err)
|
||||
}
|
||||
}
|
||||
client, err := api.NewClient(clientConf)
|
||||
if err != nil {
|
||||
|
|
|
@ -9,7 +9,8 @@ import (
|
|||
|
||||
// BackendPlugin is the plugin.Plugin implementation
|
||||
type BackendPlugin struct {
|
||||
Factory func(*logical.BackendConfig) (logical.Backend, error)
|
||||
Factory func(*logical.BackendConfig) (logical.Backend, error)
|
||||
metadataMode bool
|
||||
}
|
||||
|
||||
// Server gets called when on plugin.Serve()
|
||||
|
@ -19,5 +20,5 @@ func (b *BackendPlugin) Server(broker *plugin.MuxBroker) (interface{}, error) {
|
|||
|
||||
// Client gets called on plugin.NewClient()
|
||||
func (b BackendPlugin) Client(broker *plugin.MuxBroker, c *rpc.Client) (interface{}, error) {
|
||||
return &backendPluginClient{client: c, broker: broker}, nil
|
||||
return &backendPluginClient{client: c, broker: broker, metadataMode: b.metadataMode}, nil
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package plugin
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/rpc"
|
||||
|
||||
"github.com/hashicorp/go-plugin"
|
||||
|
@ -8,11 +9,16 @@ import (
|
|||
log "github.com/mgutz/logxi/v1"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrClientInMetadataMode = errors.New("plugin client can not perform action while in metadata mode")
|
||||
)
|
||||
|
||||
// backendPluginClient implements logical.Backend and is the
|
||||
// go-plugin client.
|
||||
type backendPluginClient struct {
|
||||
broker *plugin.MuxBroker
|
||||
client *rpc.Client
|
||||
broker *plugin.MuxBroker
|
||||
client *rpc.Client
|
||||
metadataMode bool
|
||||
|
||||
system logical.SystemView
|
||||
logger log.Logger
|
||||
|
@ -83,6 +89,10 @@ type RegisterLicenseReply struct {
|
|||
}
|
||||
|
||||
func (b *backendPluginClient) HandleRequest(req *logical.Request) (*logical.Response, error) {
|
||||
if b.metadataMode {
|
||||
return nil, ErrClientInMetadataMode
|
||||
}
|
||||
|
||||
// Do not send the storage, since go-plugin cannot serialize
|
||||
// interfaces. The server will pick up the storage from the shim.
|
||||
req.Storage = nil
|
||||
|
@ -136,6 +146,10 @@ func (b *backendPluginClient) Logger() log.Logger {
|
|||
}
|
||||
|
||||
func (b *backendPluginClient) HandleExistenceCheck(req *logical.Request) (bool, bool, error) {
|
||||
if b.metadataMode {
|
||||
return false, false, ErrClientInMetadataMode
|
||||
}
|
||||
|
||||
// Do not send the storage, since go-plugin cannot serialize
|
||||
// interfaces. The server will pick up the storage from the shim.
|
||||
req.Storage = nil
|
||||
|
@ -172,31 +186,49 @@ func (b *backendPluginClient) Cleanup() {
|
|||
}
|
||||
|
||||
func (b *backendPluginClient) Initialize() error {
|
||||
if b.metadataMode {
|
||||
return ErrClientInMetadataMode
|
||||
}
|
||||
err := b.client.Call("Plugin.Initialize", new(interface{}), &struct{}{})
|
||||
return err
|
||||
}
|
||||
|
||||
func (b *backendPluginClient) InvalidateKey(key string) {
|
||||
if b.metadataMode {
|
||||
return
|
||||
}
|
||||
b.client.Call("Plugin.InvalidateKey", key, &struct{}{})
|
||||
}
|
||||
|
||||
func (b *backendPluginClient) Setup(config *logical.BackendConfig) error {
|
||||
// Shim logical.Storage
|
||||
storageImpl := config.StorageView
|
||||
if b.metadataMode {
|
||||
storageImpl = &NOOPStorage{}
|
||||
}
|
||||
storageID := b.broker.NextId()
|
||||
go b.broker.AcceptAndServe(storageID, &StorageServer{
|
||||
impl: config.StorageView,
|
||||
impl: storageImpl,
|
||||
})
|
||||
|
||||
// Shim log.Logger
|
||||
loggerImpl := config.Logger
|
||||
if b.metadataMode {
|
||||
loggerImpl = log.NullLog
|
||||
}
|
||||
loggerID := b.broker.NextId()
|
||||
go b.broker.AcceptAndServe(loggerID, &LoggerServer{
|
||||
logger: config.Logger,
|
||||
logger: loggerImpl,
|
||||
})
|
||||
|
||||
// Shim logical.SystemView
|
||||
sysViewImpl := config.System
|
||||
if b.metadataMode {
|
||||
sysViewImpl = &logical.StaticSystemView{}
|
||||
}
|
||||
sysViewID := b.broker.NextId()
|
||||
go b.broker.AcceptAndServe(sysViewID, &SystemViewServer{
|
||||
impl: config.System,
|
||||
impl: sysViewImpl,
|
||||
})
|
||||
|
||||
args := &SetupArgs{
|
||||
|
@ -233,6 +265,10 @@ func (b *backendPluginClient) Type() logical.BackendType {
|
|||
}
|
||||
|
||||
func (b *backendPluginClient) RegisterLicense(license interface{}) error {
|
||||
if b.metadataMode {
|
||||
return ErrClientInMetadataMode
|
||||
}
|
||||
|
||||
var reply RegisterLicenseReply
|
||||
args := RegisterLicenseArgs{
|
||||
License: license,
|
||||
|
|
|
@ -1,12 +1,19 @@
|
|||
package plugin
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/rpc"
|
||||
"os"
|
||||
|
||||
"github.com/hashicorp/go-plugin"
|
||||
"github.com/hashicorp/vault/helper/pluginutil"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrServerInMetadataMode = errors.New("plugin server can not perform action while in metadata mode")
|
||||
)
|
||||
|
||||
// backendPluginServer is the RPC server that backendPluginClient talks to,
|
||||
// it methods conforming to requirements by net/rpc
|
||||
type backendPluginServer struct {
|
||||
|
@ -19,7 +26,15 @@ type backendPluginServer struct {
|
|||
storageClient *rpc.Client
|
||||
}
|
||||
|
||||
func inMetadataMode() bool {
|
||||
return os.Getenv(pluginutil.PluginMetadaModeEnv) == "true"
|
||||
}
|
||||
|
||||
func (b *backendPluginServer) HandleRequest(args *HandleRequestArgs, reply *HandleRequestReply) error {
|
||||
if inMetadataMode() {
|
||||
return ErrServerInMetadataMode
|
||||
}
|
||||
|
||||
storage := &StorageClient{client: b.storageClient}
|
||||
args.Request.Storage = storage
|
||||
|
||||
|
@ -40,6 +55,10 @@ func (b *backendPluginServer) SpecialPaths(_ interface{}, reply *SpecialPathsRep
|
|||
}
|
||||
|
||||
func (b *backendPluginServer) HandleExistenceCheck(args *HandleExistenceCheckArgs, reply *HandleExistenceCheckReply) error {
|
||||
if inMetadataMode() {
|
||||
return ErrServerInMetadataMode
|
||||
}
|
||||
|
||||
storage := &StorageClient{client: b.storageClient}
|
||||
args.Request.Storage = storage
|
||||
|
||||
|
@ -64,11 +83,19 @@ func (b *backendPluginServer) Cleanup(_ interface{}, _ *struct{}) error {
|
|||
}
|
||||
|
||||
func (b *backendPluginServer) Initialize(_ interface{}, _ *struct{}) error {
|
||||
if inMetadataMode() {
|
||||
return ErrServerInMetadataMode
|
||||
}
|
||||
|
||||
err := b.backend.Initialize()
|
||||
return err
|
||||
}
|
||||
|
||||
func (b *backendPluginServer) InvalidateKey(args string, _ *struct{}) error {
|
||||
if inMetadataMode() {
|
||||
return ErrServerInMetadataMode
|
||||
}
|
||||
|
||||
b.backend.InvalidateKey(args)
|
||||
return nil
|
||||
}
|
||||
|
@ -145,6 +172,10 @@ func (b *backendPluginServer) Type(_ interface{}, reply *TypeReply) error {
|
|||
}
|
||||
|
||||
func (b *backendPluginServer) RegisterLicense(args *RegisterLicenseArgs, reply *RegisterLicenseReply) error {
|
||||
if inMetadataMode() {
|
||||
return ErrServerInMetadataMode
|
||||
}
|
||||
|
||||
err := b.backend.RegisterLicense(args.License)
|
||||
if err != nil {
|
||||
*reply = RegisterLicenseReply{
|
||||
|
|
|
@ -43,6 +43,7 @@ func Backend() *backend {
|
|||
kvPaths(&b),
|
||||
[]*framework.Path{
|
||||
pathInternal(&b),
|
||||
pathSpecial(&b),
|
||||
},
|
||||
),
|
||||
PathsSpecial: &logical.Paths{
|
||||
|
|
|
@ -13,7 +13,7 @@ import (
|
|||
func main() {
|
||||
apiClientMeta := &pluginutil.APIClientMeta{}
|
||||
flags := apiClientMeta.FlagSet()
|
||||
flags.Parse(os.Args)
|
||||
flags.Parse(os.Args[1:]) // Ignore command, strictly parse flags
|
||||
|
||||
tlsConfig := apiClientMeta.GetTLSConfig()
|
||||
tlsProviderFunc := pluginutil.VaultPluginTLSProvider(tlsConfig)
|
||||
|
|
|
@ -0,0 +1,27 @@
|
|||
package mock
|
||||
|
||||
import (
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
)
|
||||
|
||||
// pathSpecial is used to test special paths.
|
||||
func pathSpecial(b *backend) *framework.Path {
|
||||
return &framework.Path{
|
||||
Pattern: "special",
|
||||
Callbacks: map[logical.Operation]framework.OperationFunc{
|
||||
logical.ReadOperation: b.pathSpecialRead,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (b *backend) pathSpecialRead(
|
||||
req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
// Return the secret
|
||||
return &logical.Response{
|
||||
Data: map[string]interface{}{
|
||||
"data": "foo",
|
||||
},
|
||||
}, nil
|
||||
|
||||
}
|
|
@ -40,8 +40,9 @@ func (b *BackendPluginClient) Cleanup() {
|
|||
|
||||
// NewBackend will return an instance of an RPC-based client implementation of the backend for
|
||||
// external plugins, or a concrete implementation of the backend if it is a builtin backend.
|
||||
// The backend is returned as a logical.Backend interface.
|
||||
func NewBackend(pluginName string, sys pluginutil.LookRunnerUtil, logger log.Logger) (logical.Backend, error) {
|
||||
// The backend is returned as a logical.Backend interface. The isMetadataMode param determines whether
|
||||
// the plugin should run in metadata mode.
|
||||
func NewBackend(pluginName string, sys pluginutil.LookRunnerUtil, logger log.Logger, isMetadataMode bool) (logical.Backend, error) {
|
||||
// Look for plugin in the plugin catalog
|
||||
pluginRunner, err := sys.LookupPlugin(pluginName)
|
||||
if err != nil {
|
||||
|
@ -65,7 +66,7 @@ func NewBackend(pluginName string, sys pluginutil.LookRunnerUtil, logger log.Log
|
|||
|
||||
} else {
|
||||
// create a backendPluginClient instance
|
||||
backend, err = newPluginClient(sys, pluginRunner, logger)
|
||||
backend, err = newPluginClient(sys, pluginRunner, logger, isMetadataMode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -74,12 +75,21 @@ func NewBackend(pluginName string, sys pluginutil.LookRunnerUtil, logger log.Log
|
|||
return backend, nil
|
||||
}
|
||||
|
||||
func newPluginClient(sys pluginutil.RunnerUtil, pluginRunner *pluginutil.PluginRunner, logger log.Logger) (logical.Backend, error) {
|
||||
func newPluginClient(sys pluginutil.RunnerUtil, pluginRunner *pluginutil.PluginRunner, logger log.Logger, isMetadataMode bool) (logical.Backend, error) {
|
||||
// pluginMap is the map of plugins we can dispense.
|
||||
pluginMap := map[string]plugin.Plugin{
|
||||
"backend": &BackendPlugin{},
|
||||
"backend": &BackendPlugin{
|
||||
metadataMode: isMetadataMode,
|
||||
},
|
||||
}
|
||||
|
||||
var client *plugin.Client
|
||||
var err error
|
||||
if isMetadataMode {
|
||||
client, err = pluginRunner.RunMetadataMode(sys, pluginMap, handshakeConfig, []string{}, logger)
|
||||
} else {
|
||||
client, err = pluginRunner.Run(sys, pluginMap, handshakeConfig, []string{}, logger)
|
||||
}
|
||||
client, err := pluginRunner.Run(sys, pluginMap, handshakeConfig, []string{}, logger)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -20,7 +20,8 @@ type ServeOpts struct {
|
|||
TLSProviderFunc TLSProdiverFunc
|
||||
}
|
||||
|
||||
// Serve is used to serve a backend plugin
|
||||
// Serve is a helper function used to serve a backend plugin. This
|
||||
// should be ran on the plugin's main process.
|
||||
func Serve(opts *ServeOpts) error {
|
||||
// pluginMap is the map of plugins we can dispense.
|
||||
var pluginMap = map[string]plugin.Plugin{
|
||||
|
@ -34,6 +35,7 @@ func Serve(opts *ServeOpts) error {
|
|||
return err
|
||||
}
|
||||
|
||||
// If FetchMetadata is true, run without TLSProvider
|
||||
plugin.Serve(&plugin.ServeConfig{
|
||||
HandshakeConfig: handshakeConfig,
|
||||
Plugins: pluginMap,
|
||||
|
|
|
@ -117,3 +117,23 @@ type StoragePutReply struct {
|
|||
type StorageDeleteReply struct {
|
||||
Error *plugin.BasicError
|
||||
}
|
||||
|
||||
// NOOPStorage is used to deny access to the storage interface while running a
|
||||
// backend plugin in metadata mode.
|
||||
type NOOPStorage struct{}
|
||||
|
||||
func (s *NOOPStorage) List(prefix string) ([]string, error) {
|
||||
return []string{}, nil
|
||||
}
|
||||
|
||||
func (s *NOOPStorage) Get(key string) (*logical.StorageEntry, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *NOOPStorage) Put(entry *logical.StorageEntry) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *NOOPStorage) Delete(key string) error {
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/errwrap"
|
||||
"github.com/hashicorp/go-uuid"
|
||||
"github.com/hashicorp/vault/helper/jsonutil"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
|
@ -397,7 +398,6 @@ func (c *Core) persistAuth(table *MountTable, localOnly bool) error {
|
|||
// setupCredentials is invoked after we've loaded the auth table to
|
||||
// initialize the credential backends and setup the router
|
||||
func (c *Core) setupCredentials() error {
|
||||
var backend logical.Backend
|
||||
var view *BarrierView
|
||||
var err error
|
||||
var persistNeeded bool
|
||||
|
@ -406,6 +406,7 @@ func (c *Core) setupCredentials() error {
|
|||
defer c.authLock.Unlock()
|
||||
|
||||
for _, entry := range c.auth.Entries {
|
||||
var backend logical.Backend
|
||||
// Work around some problematic code that existed in master for a while
|
||||
if strings.HasPrefix(entry.Path, credentialRoutePrefix) {
|
||||
entry.Path = strings.TrimPrefix(entry.Path, credentialRoutePrefix)
|
||||
|
@ -425,6 +426,9 @@ func (c *Core) setupCredentials() error {
|
|||
backend, err = c.newCredentialBackend(entry.Type, sysView, view, conf)
|
||||
if err != nil {
|
||||
c.logger.Error("core: failed to create credential entry", "path", entry.Path, "error", err)
|
||||
if errwrap.Contains(err, ErrPluginNotFound.Error()) && entry.Type == "plugin" {
|
||||
goto ROUTER_MOUNT
|
||||
}
|
||||
return errLoadAuthFailed
|
||||
}
|
||||
if backend == nil {
|
||||
|
@ -432,15 +436,14 @@ func (c *Core) setupCredentials() error {
|
|||
}
|
||||
|
||||
// Check for the correct backend type
|
||||
backendType := backend.Type()
|
||||
if entry.Type == "plugin" && backendType != logical.TypeCredential {
|
||||
return fmt.Errorf("cannot mount '%s' of type '%s' as an auth backend", entry.Config.PluginName, backendType)
|
||||
if entry.Type == "plugin" && backend.Type() != logical.TypeCredential {
|
||||
return fmt.Errorf("cannot mount '%s' of type '%s' as an auth backend", entry.Config.PluginName, backend.Type())
|
||||
}
|
||||
|
||||
if err := backend.Initialize(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ROUTER_MOUNT:
|
||||
// Mount the backend
|
||||
path := credentialRoutePrefix + entry.Path
|
||||
err = c.router.Mount(backend, path, entry, view)
|
||||
|
|
|
@ -1369,9 +1369,6 @@ func (c *Core) postUnseal() (retErr error) {
|
|||
if err := c.setupMounts(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := c.startRollback(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := c.setupPolicyStore(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -1384,6 +1381,9 @@ func (c *Core) postUnseal() (retErr error) {
|
|||
if err := c.setupCredentials(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := c.startRollback(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := c.setupExpiration(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -4,6 +4,8 @@ import (
|
|||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/errwrap"
|
||||
|
||||
"github.com/hashicorp/vault/helper/consts"
|
||||
"github.com/hashicorp/vault/helper/pluginutil"
|
||||
"github.com/hashicorp/vault/helper/wrapping"
|
||||
|
@ -132,7 +134,7 @@ func (d dynamicSystemView) LookupPlugin(name string) (*pluginutil.PluginRunner,
|
|||
return nil, err
|
||||
}
|
||||
if r == nil {
|
||||
return nil, fmt.Errorf("no plugin found with name: %s", name)
|
||||
return nil, errwrap.Wrapf(fmt.Sprintf("{{err}}: %s", name), ErrPluginNotFound)
|
||||
}
|
||||
|
||||
return r, nil
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/builtin/plugin"
|
||||
"github.com/hashicorp/vault/helper/pluginutil"
|
||||
|
@ -15,17 +16,196 @@ import (
|
|||
)
|
||||
|
||||
func TestSystemBackend_Plugin_secret(t *testing.T) {
|
||||
cluster := testSystemBackendMock(t, 2, logical.TypeLogical)
|
||||
cluster := testSystemBackendMock(t, 1, 1, logical.TypeLogical)
|
||||
defer cluster.Cleanup()
|
||||
|
||||
core := cluster.Cores[0]
|
||||
|
||||
// Make a request to lazy load the plugin
|
||||
req := logical.TestRequest(t, logical.ReadOperation, "mock-0/internal")
|
||||
req.ClientToken = core.Client.Token()
|
||||
resp, err := core.HandleRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if resp == nil {
|
||||
t.Fatalf("bad: response should not be nil")
|
||||
}
|
||||
|
||||
// Seal the cluster
|
||||
cluster.EnsureCoresSealed(t)
|
||||
|
||||
// Unseal the cluster
|
||||
barrierKeys := cluster.BarrierKeys
|
||||
for _, core := range cluster.Cores {
|
||||
for _, key := range barrierKeys {
|
||||
_, err := core.Unseal(vault.TestKeyCopy(key))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
sealed, err := core.Sealed()
|
||||
if err != nil {
|
||||
t.Fatalf("err checking seal status: %s", err)
|
||||
}
|
||||
if sealed {
|
||||
t.Fatal("should not be sealed")
|
||||
}
|
||||
// Wait for active so post-unseal takes place
|
||||
// If it fails, it means unseal process failed
|
||||
vault.TestWaitActive(t, core.Core)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSystemBackend_Plugin_auth(t *testing.T) {
|
||||
cluster := testSystemBackendMock(t, 2, logical.TypeCredential)
|
||||
cluster := testSystemBackendMock(t, 1, 1, logical.TypeCredential)
|
||||
defer cluster.Cleanup()
|
||||
|
||||
core := cluster.Cores[0]
|
||||
|
||||
// Make a request to lazy load the plugin
|
||||
req := logical.TestRequest(t, logical.ReadOperation, "auth/mock-0/internal")
|
||||
req.ClientToken = core.Client.Token()
|
||||
resp, err := core.HandleRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if resp == nil {
|
||||
t.Fatalf("bad: response should not be nil")
|
||||
}
|
||||
|
||||
// Seal the cluster
|
||||
cluster.EnsureCoresSealed(t)
|
||||
|
||||
// Unseal the cluster
|
||||
barrierKeys := cluster.BarrierKeys
|
||||
for _, core := range cluster.Cores {
|
||||
for _, key := range barrierKeys {
|
||||
_, err := core.Unseal(vault.TestKeyCopy(key))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
sealed, err := core.Sealed()
|
||||
if err != nil {
|
||||
t.Fatalf("err checking seal status: %s", err)
|
||||
}
|
||||
if sealed {
|
||||
t.Fatal("should not be sealed")
|
||||
}
|
||||
// Wait for active so post-unseal takes place
|
||||
// If it fails, it means unseal process failed
|
||||
vault.TestWaitActive(t, core.Core)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSystemBackend_Plugin_MismatchType(t *testing.T) {
|
||||
cluster := testSystemBackendMock(t, 1, 1, logical.TypeLogical)
|
||||
defer cluster.Cleanup()
|
||||
|
||||
core := cluster.Cores[0]
|
||||
|
||||
// Replace the plugin with a credential backend
|
||||
vault.TestAddTestPlugin(t, core.Core, "mock-plugin", "TestBackend_PluginMainCredentials")
|
||||
|
||||
// Make a request to lazy load the now-credential plugin
|
||||
// and expect an error
|
||||
req := logical.TestRequest(t, logical.ReadOperation, "mock-0/internal")
|
||||
req.ClientToken = core.Client.Token()
|
||||
_, err := core.HandleRequest(req)
|
||||
if err == nil {
|
||||
t.Fatalf("expected error due to mismatch on error type: %s", err)
|
||||
}
|
||||
|
||||
// Sleep a bit before cleanup is called
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
|
||||
func TestSystemBackend_Plugin_CatalogRemoved(t *testing.T) {
|
||||
t.Run("secret", func(t *testing.T) {
|
||||
testPlugin_CatalogRemoved(t, logical.TypeLogical, false)
|
||||
})
|
||||
|
||||
t.Run("auth", func(t *testing.T) {
|
||||
testPlugin_CatalogRemoved(t, logical.TypeCredential, false)
|
||||
})
|
||||
|
||||
t.Run("secret-mount-existing", func(t *testing.T) {
|
||||
testPlugin_CatalogRemoved(t, logical.TypeLogical, true)
|
||||
})
|
||||
|
||||
t.Run("auth-mount-existing", func(t *testing.T) {
|
||||
testPlugin_CatalogRemoved(t, logical.TypeCredential, true)
|
||||
})
|
||||
}
|
||||
|
||||
func testPlugin_CatalogRemoved(t *testing.T, btype logical.BackendType, testMount bool) {
|
||||
cluster := testSystemBackendMock(t, 1, 1, btype)
|
||||
defer cluster.Cleanup()
|
||||
|
||||
core := cluster.Cores[0]
|
||||
|
||||
// Remove the plugin from the catalog
|
||||
req := logical.TestRequest(t, logical.DeleteOperation, "sys/plugins/catalog/mock-plugin")
|
||||
req.ClientToken = core.Client.Token()
|
||||
resp, err := core.HandleRequest(req)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("err:%v resp:%#v", err, resp)
|
||||
}
|
||||
|
||||
// Seal the cluster
|
||||
cluster.EnsureCoresSealed(t)
|
||||
|
||||
// Unseal the cluster
|
||||
barrierKeys := cluster.BarrierKeys
|
||||
for _, core := range cluster.Cores {
|
||||
for _, key := range barrierKeys {
|
||||
_, err := core.Unseal(vault.TestKeyCopy(key))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
sealed, err := core.Sealed()
|
||||
if err != nil {
|
||||
t.Fatalf("err checking seal status: %s", err)
|
||||
}
|
||||
if sealed {
|
||||
t.Fatal("should not be sealed")
|
||||
}
|
||||
// Wait for active so post-unseal takes place
|
||||
// If it fails, it means unseal process failed
|
||||
vault.TestWaitActive(t, core.Core)
|
||||
}
|
||||
|
||||
if testMount {
|
||||
// Add plugin back to the catalog
|
||||
vault.TestAddTestPlugin(t, core.Core, "mock-plugin", "TestBackend_PluginMainLogical")
|
||||
|
||||
// Mount the plugin at the same path after plugin is re-added to the catalog
|
||||
// and expect an error due to existing path.
|
||||
var err error
|
||||
switch btype {
|
||||
case logical.TypeLogical:
|
||||
_, err = core.Client.Logical().Write("sys/mounts/mock-0", map[string]interface{}{
|
||||
"type": "plugin",
|
||||
"config": map[string]interface{}{
|
||||
"plugin_name": "mock-plugin",
|
||||
},
|
||||
})
|
||||
case logical.TypeCredential:
|
||||
_, err = core.Client.Logical().Write("sys/auth/mock-0", map[string]interface{}{
|
||||
"type": "plugin",
|
||||
"plugin_name": "mock-plugin",
|
||||
})
|
||||
}
|
||||
if err == nil {
|
||||
t.Fatal("expected error when mounting on existing path")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSystemBackend_Plugin_autoReload(t *testing.T) {
|
||||
cluster := testSystemBackendMock(t, 1, logical.TypeLogical)
|
||||
cluster := testSystemBackendMock(t, 1, 1, logical.TypeLogical)
|
||||
defer cluster.Cleanup()
|
||||
|
||||
core := cluster.Cores[0]
|
||||
|
@ -65,6 +245,35 @@ func TestSystemBackend_Plugin_autoReload(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestSystemBackend_Plugin_SealUnseal(t *testing.T) {
|
||||
cluster := testSystemBackendMock(t, 1, 1, logical.TypeLogical)
|
||||
defer cluster.Cleanup()
|
||||
|
||||
// Seal the cluster
|
||||
cluster.EnsureCoresSealed(t)
|
||||
|
||||
// Unseal the cluster
|
||||
barrierKeys := cluster.BarrierKeys
|
||||
for _, core := range cluster.Cores {
|
||||
for _, key := range barrierKeys {
|
||||
_, err := core.Unseal(vault.TestKeyCopy(key))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
sealed, err := core.Sealed()
|
||||
if err != nil {
|
||||
t.Fatalf("err checking seal status: %s", err)
|
||||
}
|
||||
if sealed {
|
||||
t.Fatal("should not be sealed")
|
||||
}
|
||||
// Wait for active so post-unseal takes place
|
||||
// If it fails, it means unseal process failed
|
||||
vault.TestWaitActive(t, core.Core)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSystemBackend_Plugin_reload(t *testing.T) {
|
||||
data := map[string]interface{}{
|
||||
"plugin": "mock-plugin",
|
||||
|
@ -77,8 +286,9 @@ func TestSystemBackend_Plugin_reload(t *testing.T) {
|
|||
t.Run("mounts", func(t *testing.T) { testSystemBackend_PluginReload(t, data) })
|
||||
}
|
||||
|
||||
// Helper func to test different reload methods on plugin reload endpoint
|
||||
func testSystemBackend_PluginReload(t *testing.T, reqData map[string]interface{}) {
|
||||
cluster := testSystemBackendMock(t, 2, logical.TypeLogical)
|
||||
cluster := testSystemBackendMock(t, 1, 2, logical.TypeLogical)
|
||||
defer cluster.Cleanup()
|
||||
|
||||
core := cluster.Cores[0]
|
||||
|
@ -123,7 +333,7 @@ func testSystemBackend_PluginReload(t *testing.T, reqData map[string]interface{}
|
|||
|
||||
// testSystemBackendMock returns a systemBackend with the desired number
|
||||
// of mounted mock plugin backends
|
||||
func testSystemBackendMock(t *testing.T, numMounts int, backendType logical.BackendType) *vault.TestCluster {
|
||||
func testSystemBackendMock(t *testing.T, numCores, numMounts int, backendType logical.BackendType) *vault.TestCluster {
|
||||
coreConfig := &vault.CoreConfig{
|
||||
LogicalBackends: map[string]logical.Factory{
|
||||
"plugin": plugin.Factory,
|
||||
|
@ -134,7 +344,9 @@ func testSystemBackendMock(t *testing.T, numMounts int, backendType logical.Back
|
|||
}
|
||||
|
||||
cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{
|
||||
HandlerFunc: vaulthttp.Handler,
|
||||
HandlerFunc: vaulthttp.Handler,
|
||||
KeepStandbysSealed: true,
|
||||
NumCores: numCores,
|
||||
})
|
||||
cluster.Start()
|
||||
|
||||
|
@ -197,7 +409,8 @@ func testSystemBackendMock(t *testing.T, numMounts int, backendType logical.Back
|
|||
}
|
||||
|
||||
func TestBackend_PluginMainLogical(t *testing.T) {
|
||||
if os.Getenv(pluginutil.PluginUnwrapTokenEnv) == "" {
|
||||
args := []string{}
|
||||
if os.Getenv(pluginutil.PluginUnwrapTokenEnv) == "" && os.Getenv(pluginutil.PluginMetadaModeEnv) != "true" {
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -205,16 +418,16 @@ func TestBackend_PluginMainLogical(t *testing.T) {
|
|||
if caPEM == "" {
|
||||
t.Fatal("CA cert not passed in")
|
||||
}
|
||||
|
||||
factoryFunc := mock.FactoryType(logical.TypeLogical)
|
||||
|
||||
args := []string{"--ca-cert=" + caPEM}
|
||||
args = append(args, fmt.Sprintf("--ca-cert=%s", caPEM))
|
||||
|
||||
apiClientMeta := &pluginutil.APIClientMeta{}
|
||||
flags := apiClientMeta.FlagSet()
|
||||
flags.Parse(args)
|
||||
tlsConfig := apiClientMeta.GetTLSConfig()
|
||||
tlsProviderFunc := pluginutil.VaultPluginTLSProvider(tlsConfig)
|
||||
|
||||
factoryFunc := mock.FactoryType(logical.TypeLogical)
|
||||
|
||||
err := lplugin.Serve(&lplugin.ServeOpts{
|
||||
BackendFactoryFunc: factoryFunc,
|
||||
TLSProviderFunc: tlsProviderFunc,
|
||||
|
@ -225,7 +438,8 @@ func TestBackend_PluginMainLogical(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestBackend_PluginMainCredentials(t *testing.T) {
|
||||
if os.Getenv(pluginutil.PluginUnwrapTokenEnv) == "" {
|
||||
args := []string{}
|
||||
if os.Getenv(pluginutil.PluginUnwrapTokenEnv) == "" && os.Getenv(pluginutil.PluginMetadaModeEnv) != "true" {
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -233,16 +447,16 @@ func TestBackend_PluginMainCredentials(t *testing.T) {
|
|||
if caPEM == "" {
|
||||
t.Fatal("CA cert not passed in")
|
||||
}
|
||||
|
||||
factoryFunc := mock.FactoryType(logical.TypeCredential)
|
||||
|
||||
args := []string{"--ca-cert=" + caPEM}
|
||||
args = append(args, fmt.Sprintf("--ca-cert=%s", caPEM))
|
||||
|
||||
apiClientMeta := &pluginutil.APIClientMeta{}
|
||||
flags := apiClientMeta.FlagSet()
|
||||
flags.Parse(args)
|
||||
tlsConfig := apiClientMeta.GetTLSConfig()
|
||||
tlsProviderFunc := pluginutil.VaultPluginTLSProvider(tlsConfig)
|
||||
|
||||
factoryFunc := mock.FactoryType(logical.TypeCredential)
|
||||
|
||||
err := lplugin.Serve(&lplugin.ServeOpts{
|
||||
BackendFactoryFunc: factoryFunc,
|
||||
TLSProviderFunc: tlsProviderFunc,
|
||||
|
|
|
@ -9,6 +9,7 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/errwrap"
|
||||
"github.com/hashicorp/go-uuid"
|
||||
"github.com/hashicorp/vault/helper/consts"
|
||||
"github.com/hashicorp/vault/helper/jsonutil"
|
||||
|
@ -663,11 +664,12 @@ func (c *Core) setupMounts() error {
|
|||
c.mountsLock.Lock()
|
||||
defer c.mountsLock.Unlock()
|
||||
|
||||
var backend logical.Backend
|
||||
var view *BarrierView
|
||||
var err error
|
||||
|
||||
for _, entry := range c.mounts.Entries {
|
||||
var backend logical.Backend
|
||||
|
||||
// Initialize the backend, special casing for system
|
||||
barrierPath := backendBarrierPrefix + entry.UUID + "/"
|
||||
if entry.Type == "system" {
|
||||
|
@ -686,6 +688,9 @@ func (c *Core) setupMounts() error {
|
|||
backend, err = c.newLogicalBackend(entry.Type, sysView, view, conf)
|
||||
if err != nil {
|
||||
c.logger.Error("core: failed to create mount entry", "path", entry.Path, "error", err)
|
||||
if errwrap.Contains(err, ErrPluginNotFound.Error()) && entry.Type == "plugin" {
|
||||
goto ROUTER_MOUNT
|
||||
}
|
||||
return errLoadMountsFailed
|
||||
}
|
||||
if backend == nil {
|
||||
|
@ -693,9 +698,8 @@ func (c *Core) setupMounts() error {
|
|||
}
|
||||
|
||||
// Check for the correct backend type
|
||||
backendType := backend.Type()
|
||||
if entry.Type == "plugin" && backendType != logical.TypeLogical {
|
||||
return fmt.Errorf("cannot mount '%s' of type '%s' as a logical backend", entry.Config.PluginName, backendType)
|
||||
if entry.Type == "plugin" && backend.Type() != logical.TypeLogical {
|
||||
return fmt.Errorf("cannot mount '%s' of type '%s' as a logical backend", entry.Config.PluginName, backend.Type())
|
||||
}
|
||||
|
||||
if err := backend.Initialize(); err != nil {
|
||||
|
@ -710,7 +714,7 @@ func (c *Core) setupMounts() error {
|
|||
ch.saltUUID = entry.UUID
|
||||
ch.storageView = view
|
||||
}
|
||||
|
||||
ROUTER_MOUNT:
|
||||
// Mount the backend
|
||||
err = c.router.Mount(backend, entry.Path, entry, view)
|
||||
if err != nil {
|
||||
|
|
|
@ -19,6 +19,7 @@ import (
|
|||
var (
|
||||
pluginCatalogPath = "core/plugin-catalog/"
|
||||
ErrDirectoryNotConfigured = errors.New("could not set plugin, plugin directory is not configured")
|
||||
ErrPluginNotFound = errors.New("plugin not found in the catalog")
|
||||
)
|
||||
|
||||
// PluginCatalog keeps a record of plugins known to vault. External plugins need
|
||||
|
@ -37,6 +38,10 @@ func (c *Core) setupPluginCatalog() error {
|
|||
directory: c.pluginDirectory,
|
||||
}
|
||||
|
||||
if c.logger.IsInfo() {
|
||||
c.logger.Info("core: successfully setup plugin catalog", "plugin-directory", c.pluginDirectory)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -64,9 +64,12 @@ func (r *Router) Mount(backend logical.Backend, prefix string, mountEntry *Mount
|
|||
}
|
||||
|
||||
// Build the paths
|
||||
paths := backend.SpecialPaths()
|
||||
if paths == nil {
|
||||
paths = new(logical.Paths)
|
||||
paths := new(logical.Paths)
|
||||
if backend != nil {
|
||||
specialPaths := backend.SpecialPaths()
|
||||
if specialPaths != nil {
|
||||
paths = specialPaths
|
||||
}
|
||||
}
|
||||
|
||||
// Create a mount entry
|
||||
|
|
635
vault/testing.go
635
vault/testing.go
|
@ -335,11 +335,17 @@ func TestAddTestPlugin(t testing.T, c *Core, name, testFunc string) {
|
|||
}
|
||||
|
||||
sum := hash.Sum(nil)
|
||||
c.pluginCatalog.directory, err = filepath.EvalSymlinks(os.Args[0])
|
||||
|
||||
// Determine plugin directory path
|
||||
fullPath, err := filepath.EvalSymlinks(os.Args[0])
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
c.pluginCatalog.directory = filepath.Dir(c.pluginCatalog.directory)
|
||||
directoryPath := filepath.Dir(fullPath)
|
||||
|
||||
// Set core's plugin directory and plugin catalog directory
|
||||
c.pluginDirectory = directoryPath
|
||||
c.pluginCatalog.directory = directoryPath
|
||||
|
||||
command := fmt.Sprintf("%s --test.run=%s", filepath.Base(os.Args[0]), testFunc)
|
||||
err = c.pluginCatalog.Set(name, command, sum)
|
||||
|
@ -585,6 +591,7 @@ func GenerateRandBytes(length int) ([]byte, error) {
|
|||
}
|
||||
|
||||
func TestWaitActive(t testing.T, core *Core) {
|
||||
t.Helper()
|
||||
start := time.Now()
|
||||
var standby bool
|
||||
var err error
|
||||
|
@ -627,6 +634,13 @@ func (c *TestCluster) Start() {
|
|||
}
|
||||
}
|
||||
|
||||
func (c *TestCluster) EnsureCoresSealed(t testing.T) {
|
||||
t.Helper()
|
||||
if err := c.ensureCoresSealed(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *TestCluster) Cleanup() {
|
||||
// Close listeners
|
||||
for _, core := range c.Cores {
|
||||
|
@ -638,25 +652,7 @@ func (c *TestCluster) Cleanup() {
|
|||
}
|
||||
|
||||
// Seal the cores
|
||||
for _, core := range c.Cores {
|
||||
if err := core.Shutdown(); err != nil {
|
||||
continue
|
||||
}
|
||||
timeout := time.Now().Add(60 * time.Second)
|
||||
for {
|
||||
if time.Now().After(timeout) {
|
||||
continue
|
||||
}
|
||||
sealed, err := core.Sealed()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if sealed {
|
||||
break
|
||||
}
|
||||
time.Sleep(250 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
c.ensureCoresSealed()
|
||||
|
||||
// Remove any temp dir that exists
|
||||
if c.TempDir != "" {
|
||||
|
@ -667,6 +663,29 @@ func (c *TestCluster) Cleanup() {
|
|||
time.Sleep(time.Second)
|
||||
}
|
||||
|
||||
func (c *TestCluster) ensureCoresSealed() error {
|
||||
for _, core := range c.Cores {
|
||||
if err := core.Shutdown(); err != nil {
|
||||
return err
|
||||
}
|
||||
timeout := time.Now().Add(60 * time.Second)
|
||||
for {
|
||||
if time.Now().After(timeout) {
|
||||
return fmt.Errorf("timeout waiting for core to seal")
|
||||
}
|
||||
sealed, err := core.Sealed()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if sealed {
|
||||
break
|
||||
}
|
||||
time.Sleep(250 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type TestListener struct {
|
||||
net.Listener
|
||||
Address *net.TCPAddr
|
||||
|
@ -692,9 +711,29 @@ type TestClusterOptions struct {
|
|||
KeepStandbysSealed bool
|
||||
HandlerFunc func(*Core) http.Handler
|
||||
BaseListenAddress string
|
||||
NumCores int
|
||||
}
|
||||
|
||||
var DefaultNumCores = 3
|
||||
|
||||
type certInfo struct {
|
||||
cert *x509.Certificate
|
||||
certPEM []byte
|
||||
certBytes []byte
|
||||
key *ecdsa.PrivateKey
|
||||
keyPEM []byte
|
||||
}
|
||||
|
||||
// NewTestCluster creates a new test cluster based on the provided core config
|
||||
// and test cluster options.
|
||||
func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *TestCluster {
|
||||
var numCores int
|
||||
if opts == nil || opts.NumCores == 0 {
|
||||
numCores = DefaultNumCores
|
||||
} else {
|
||||
numCores = opts.NumCores
|
||||
}
|
||||
|
||||
certIPs := []net.IP{
|
||||
net.IPv6loopback,
|
||||
net.ParseIP("127.0.0.1"),
|
||||
|
@ -770,270 +809,131 @@ func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *Te
|
|||
t.Fatal(err)
|
||||
}
|
||||
|
||||
s1Key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
s1CertTemplate := &x509.Certificate{
|
||||
Subject: pkix.Name{
|
||||
CommonName: "localhost",
|
||||
},
|
||||
DNSNames: []string{"localhost"},
|
||||
IPAddresses: certIPs,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{
|
||||
x509.ExtKeyUsageServerAuth,
|
||||
x509.ExtKeyUsageClientAuth,
|
||||
},
|
||||
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement,
|
||||
SerialNumber: big.NewInt(mathrand.Int63()),
|
||||
NotBefore: time.Now().Add(-30 * time.Second),
|
||||
NotAfter: time.Now().Add(262980 * time.Hour),
|
||||
}
|
||||
s1CertBytes, err := x509.CreateCertificate(rand.Reader, s1CertTemplate, caCert, s1Key.Public(), caKey)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
s1Cert, err := x509.ParseCertificate(s1CertBytes)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
s1CertPEMBlock := &pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: s1CertBytes,
|
||||
}
|
||||
s1CertPEM := pem.EncodeToMemory(s1CertPEMBlock)
|
||||
s1MarshaledKey, err := x509.MarshalECPrivateKey(s1Key)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
s1KeyPEMBlock := &pem.Block{
|
||||
Type: "EC PRIVATE KEY",
|
||||
Bytes: s1MarshaledKey,
|
||||
}
|
||||
s1KeyPEM := pem.EncodeToMemory(s1KeyPEMBlock)
|
||||
var certInfoSlice []*certInfo
|
||||
|
||||
s2Key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
s2CertTemplate := &x509.Certificate{
|
||||
Subject: pkix.Name{
|
||||
CommonName: "localhost",
|
||||
},
|
||||
DNSNames: []string{"localhost"},
|
||||
IPAddresses: certIPs,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{
|
||||
x509.ExtKeyUsageServerAuth,
|
||||
x509.ExtKeyUsageClientAuth,
|
||||
},
|
||||
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement,
|
||||
SerialNumber: big.NewInt(mathrand.Int63()),
|
||||
NotBefore: time.Now().Add(-30 * time.Second),
|
||||
NotAfter: time.Now().Add(262980 * time.Hour),
|
||||
}
|
||||
s2CertBytes, err := x509.CreateCertificate(rand.Reader, s2CertTemplate, caCert, s2Key.Public(), caKey)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
s2Cert, err := x509.ParseCertificate(s2CertBytes)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
s2CertPEMBlock := &pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: s2CertBytes,
|
||||
}
|
||||
s2CertPEM := pem.EncodeToMemory(s2CertPEMBlock)
|
||||
s2MarshaledKey, err := x509.MarshalECPrivateKey(s2Key)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
s2KeyPEMBlock := &pem.Block{
|
||||
Type: "EC PRIVATE KEY",
|
||||
Bytes: s2MarshaledKey,
|
||||
}
|
||||
s2KeyPEM := pem.EncodeToMemory(s2KeyPEMBlock)
|
||||
//
|
||||
// Certs generation
|
||||
//
|
||||
for i := 0; i < numCores; i++ {
|
||||
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
certTemplate := &x509.Certificate{
|
||||
Subject: pkix.Name{
|
||||
CommonName: "localhost",
|
||||
},
|
||||
DNSNames: []string{"localhost"},
|
||||
IPAddresses: certIPs,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{
|
||||
x509.ExtKeyUsageServerAuth,
|
||||
x509.ExtKeyUsageClientAuth,
|
||||
},
|
||||
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement,
|
||||
SerialNumber: big.NewInt(mathrand.Int63()),
|
||||
NotBefore: time.Now().Add(-30 * time.Second),
|
||||
NotAfter: time.Now().Add(262980 * time.Hour),
|
||||
}
|
||||
certBytes, err := x509.CreateCertificate(rand.Reader, certTemplate, caCert, key.Public(), caKey)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
cert, err := x509.ParseCertificate(certBytes)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
certPEMBlock := &pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: certBytes,
|
||||
}
|
||||
certPEM := pem.EncodeToMemory(certPEMBlock)
|
||||
marshaledKey, err := x509.MarshalECPrivateKey(key)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
keyPEMBlock := &pem.Block{
|
||||
Type: "EC PRIVATE KEY",
|
||||
Bytes: marshaledKey,
|
||||
}
|
||||
keyPEM := pem.EncodeToMemory(keyPEMBlock)
|
||||
|
||||
s3Key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
certInfoSlice = append(certInfoSlice, &certInfo{
|
||||
cert: cert,
|
||||
certPEM: certPEM,
|
||||
certBytes: certBytes,
|
||||
key: key,
|
||||
keyPEM: keyPEM,
|
||||
})
|
||||
}
|
||||
s3CertTemplate := &x509.Certificate{
|
||||
Subject: pkix.Name{
|
||||
CommonName: "localhost",
|
||||
},
|
||||
DNSNames: []string{"localhost"},
|
||||
IPAddresses: certIPs,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{
|
||||
x509.ExtKeyUsageServerAuth,
|
||||
x509.ExtKeyUsageClientAuth,
|
||||
},
|
||||
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement,
|
||||
SerialNumber: big.NewInt(mathrand.Int63()),
|
||||
NotBefore: time.Now().Add(-30 * time.Second),
|
||||
NotAfter: time.Now().Add(262980 * time.Hour),
|
||||
}
|
||||
s3CertBytes, err := x509.CreateCertificate(rand.Reader, s3CertTemplate, caCert, s3Key.Public(), caKey)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
s3Cert, err := x509.ParseCertificate(s3CertBytes)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
s3CertPEMBlock := &pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: s3CertBytes,
|
||||
}
|
||||
s3CertPEM := pem.EncodeToMemory(s3CertPEMBlock)
|
||||
s3MarshaledKey, err := x509.MarshalECPrivateKey(s3Key)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
s3KeyPEMBlock := &pem.Block{
|
||||
Type: "EC PRIVATE KEY",
|
||||
Bytes: s3MarshaledKey,
|
||||
}
|
||||
s3KeyPEM := pem.EncodeToMemory(s3KeyPEMBlock)
|
||||
|
||||
logger := logformat.NewVaultLogger(log.LevelTrace)
|
||||
|
||||
//
|
||||
// Listener setup
|
||||
//
|
||||
ports := []int{0, 0, 0}
|
||||
logger := logformat.NewVaultLogger(log.LevelTrace)
|
||||
ports := make([]int, numCores)
|
||||
if baseAddr != nil {
|
||||
ports = []int{baseAddr.Port, baseAddr.Port + 1, baseAddr.Port + 2}
|
||||
for i := 0; i < numCores; i++ {
|
||||
ports[i] = baseAddr.Port + i
|
||||
}
|
||||
} else {
|
||||
baseAddr = &net.TCPAddr{
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
Port: 0,
|
||||
}
|
||||
}
|
||||
baseAddr.Port = ports[0]
|
||||
ln, err := net.ListenTCP("tcp", baseAddr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
s1CertFile := filepath.Join(testCluster.TempDir, fmt.Sprintf("node1_port_%d_cert.pem", ln.Addr().(*net.TCPAddr).Port))
|
||||
s1KeyFile := filepath.Join(testCluster.TempDir, fmt.Sprintf("node1_port_%d_key.pem", ln.Addr().(*net.TCPAddr).Port))
|
||||
err = ioutil.WriteFile(s1CertFile, s1CertPEM, 0755)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = ioutil.WriteFile(s1KeyFile, s1KeyPEM, 0755)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
s1TLSCert, err := tls.X509KeyPair(s1CertPEM, s1KeyPEM)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
s1CertGetter := reload.NewCertificateGetter(s1CertFile, s1KeyFile)
|
||||
s1TLSConfig := &tls.Config{
|
||||
Certificates: []tls.Certificate{s1TLSCert},
|
||||
RootCAs: testCluster.RootCAs,
|
||||
ClientCAs: testCluster.RootCAs,
|
||||
ClientAuth: tls.VerifyClientCertIfGiven,
|
||||
NextProtos: []string{"h2", "http/1.1"},
|
||||
GetCertificate: s1CertGetter.GetCertificate,
|
||||
}
|
||||
s1TLSConfig.BuildNameToCertificate()
|
||||
c1lns := []*TestListener{&TestListener{
|
||||
Listener: tls.NewListener(ln, s1TLSConfig),
|
||||
Address: ln.Addr().(*net.TCPAddr),
|
||||
},
|
||||
}
|
||||
var handler1 http.Handler = http.NewServeMux()
|
||||
server1 := &http.Server{
|
||||
Handler: handler1,
|
||||
}
|
||||
if err := http2.ConfigureServer(server1, nil); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
baseAddr.Port = ports[1]
|
||||
ln, err = net.ListenTCP("tcp", baseAddr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
s2CertFile := filepath.Join(testCluster.TempDir, fmt.Sprintf("node2_port_%d_cert.pem", ln.Addr().(*net.TCPAddr).Port))
|
||||
s2KeyFile := filepath.Join(testCluster.TempDir, fmt.Sprintf("node2_port_%d_key.pem", ln.Addr().(*net.TCPAddr).Port))
|
||||
err = ioutil.WriteFile(s2CertFile, s2CertPEM, 0755)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = ioutil.WriteFile(s2KeyFile, s2KeyPEM, 0755)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
s2TLSCert, err := tls.X509KeyPair(s2CertPEM, s2KeyPEM)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
s2CertGetter := reload.NewCertificateGetter(s2CertFile, s2KeyFile)
|
||||
s2TLSConfig := &tls.Config{
|
||||
Certificates: []tls.Certificate{s2TLSCert},
|
||||
RootCAs: testCluster.RootCAs,
|
||||
ClientCAs: testCluster.RootCAs,
|
||||
ClientAuth: tls.VerifyClientCertIfGiven,
|
||||
NextProtos: []string{"h2", "http/1.1"},
|
||||
GetCertificate: s2CertGetter.GetCertificate,
|
||||
}
|
||||
s2TLSConfig.BuildNameToCertificate()
|
||||
c2lns := []*TestListener{&TestListener{
|
||||
Listener: tls.NewListener(ln, s2TLSConfig),
|
||||
Address: ln.Addr().(*net.TCPAddr),
|
||||
},
|
||||
}
|
||||
var handler2 http.Handler = http.NewServeMux()
|
||||
server2 := &http.Server{
|
||||
Handler: handler2,
|
||||
}
|
||||
if err := http2.ConfigureServer(server2, nil); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
baseAddr.Port = ports[2]
|
||||
ln, err = net.ListenTCP("tcp", baseAddr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
s3CertFile := filepath.Join(testCluster.TempDir, fmt.Sprintf("node3_port_%d_cert.pem", ln.Addr().(*net.TCPAddr).Port))
|
||||
s3KeyFile := filepath.Join(testCluster.TempDir, fmt.Sprintf("node3_port_%d_key.pem", ln.Addr().(*net.TCPAddr).Port))
|
||||
err = ioutil.WriteFile(s3CertFile, s3CertPEM, 0755)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = ioutil.WriteFile(s3KeyFile, s3KeyPEM, 0755)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
s3TLSCert, err := tls.X509KeyPair(s3CertPEM, s3KeyPEM)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
s3CertGetter := reload.NewCertificateGetter(s3CertFile, s3KeyFile)
|
||||
s3TLSConfig := &tls.Config{
|
||||
Certificates: []tls.Certificate{s3TLSCert},
|
||||
RootCAs: testCluster.RootCAs,
|
||||
ClientCAs: testCluster.RootCAs,
|
||||
ClientAuth: tls.VerifyClientCertIfGiven,
|
||||
NextProtos: []string{"h2", "http/1.1"},
|
||||
GetCertificate: s3CertGetter.GetCertificate,
|
||||
}
|
||||
s3TLSConfig.BuildNameToCertificate()
|
||||
c3lns := []*TestListener{&TestListener{
|
||||
Listener: tls.NewListener(ln, s3TLSConfig),
|
||||
Address: ln.Addr().(*net.TCPAddr),
|
||||
},
|
||||
}
|
||||
var handler3 http.Handler = http.NewServeMux()
|
||||
server3 := &http.Server{
|
||||
Handler: handler3,
|
||||
}
|
||||
if err := http2.ConfigureServer(server3, nil); err != nil {
|
||||
t.Fatal(err)
|
||||
listeners := [][]*TestListener{}
|
||||
servers := []*http.Server{}
|
||||
handlers := []http.Handler{}
|
||||
tlsConfigs := []*tls.Config{}
|
||||
certGetters := []*reload.CertificateGetter{}
|
||||
for i := 0; i < numCores; i++ {
|
||||
baseAddr.Port = ports[i]
|
||||
ln, err := net.ListenTCP("tcp", baseAddr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
certFile := filepath.Join(testCluster.TempDir, fmt.Sprintf("node%d_port_%d_cert.pem", i+1, ln.Addr().(*net.TCPAddr).Port))
|
||||
keyFile := filepath.Join(testCluster.TempDir, fmt.Sprintf("node%d_port_%d_key.pem", i+1, ln.Addr().(*net.TCPAddr).Port))
|
||||
err = ioutil.WriteFile(certFile, certInfoSlice[i].certPEM, 0755)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = ioutil.WriteFile(keyFile, certInfoSlice[i].keyPEM, 0755)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
tlsCert, err := tls.X509KeyPair(certInfoSlice[i].certPEM, certInfoSlice[i].keyPEM)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
certGetter := reload.NewCertificateGetter(certFile, keyFile)
|
||||
certGetters = append(certGetters, certGetter)
|
||||
tlsConfig := &tls.Config{
|
||||
Certificates: []tls.Certificate{tlsCert},
|
||||
RootCAs: testCluster.RootCAs,
|
||||
ClientCAs: testCluster.RootCAs,
|
||||
ClientAuth: tls.VerifyClientCertIfGiven,
|
||||
NextProtos: []string{"h2", "http/1.1"},
|
||||
GetCertificate: certGetter.GetCertificate,
|
||||
}
|
||||
tlsConfig.BuildNameToCertificate()
|
||||
tlsConfigs = append(tlsConfigs, tlsConfig)
|
||||
lns := []*TestListener{&TestListener{
|
||||
Listener: tls.NewListener(ln, tlsConfig),
|
||||
Address: ln.Addr().(*net.TCPAddr),
|
||||
},
|
||||
}
|
||||
listeners = append(listeners, lns)
|
||||
var handler http.Handler = http.NewServeMux()
|
||||
handlers = append(handlers, handler)
|
||||
server := &http.Server{
|
||||
Handler: handler,
|
||||
}
|
||||
servers = append(servers, server)
|
||||
if err := http2.ConfigureServer(server, nil); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Create three cores with the same physical and different redirect/cluster
|
||||
|
@ -1049,8 +949,8 @@ func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *Te
|
|||
LogicalBackends: make(map[string]logical.Factory),
|
||||
CredentialBackends: make(map[string]logical.Factory),
|
||||
AuditBackends: make(map[string]audit.Factory),
|
||||
RedirectAddr: fmt.Sprintf("https://127.0.0.1:%d", c1lns[0].Address.Port),
|
||||
ClusterAddr: fmt.Sprintf("https://127.0.0.1:%d", c1lns[0].Address.Port+105),
|
||||
RedirectAddr: fmt.Sprintf("https://127.0.0.1:%d", listeners[0][0].Address.Port),
|
||||
ClusterAddr: fmt.Sprintf("https://127.0.0.1:%d", listeners[0][0].Address.Port+105),
|
||||
DisableMlock: true,
|
||||
EnableUI: true,
|
||||
}
|
||||
|
@ -1126,39 +1026,21 @@ func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *Te
|
|||
coreConfig.HAPhysical = haPhys.(physical.HABackend)
|
||||
}
|
||||
|
||||
c1, err := NewCore(coreConfig)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if opts != nil && opts.HandlerFunc != nil {
|
||||
handler1 = opts.HandlerFunc(c1)
|
||||
server1.Handler = handler1
|
||||
}
|
||||
|
||||
coreConfig.RedirectAddr = fmt.Sprintf("https://127.0.0.1:%d", c2lns[0].Address.Port)
|
||||
if coreConfig.ClusterAddr != "" {
|
||||
coreConfig.ClusterAddr = fmt.Sprintf("https://127.0.0.1:%d", c2lns[0].Address.Port+105)
|
||||
}
|
||||
c2, err := NewCore(coreConfig)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if opts != nil && opts.HandlerFunc != nil {
|
||||
handler2 = opts.HandlerFunc(c2)
|
||||
server2.Handler = handler2
|
||||
}
|
||||
|
||||
coreConfig.RedirectAddr = fmt.Sprintf("https://127.0.0.1:%d", c3lns[0].Address.Port)
|
||||
if coreConfig.ClusterAddr != "" {
|
||||
coreConfig.ClusterAddr = fmt.Sprintf("https://127.0.0.1:%d", c3lns[0].Address.Port+105)
|
||||
}
|
||||
c3, err := NewCore(coreConfig)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if opts != nil && opts.HandlerFunc != nil {
|
||||
handler3 = opts.HandlerFunc(c3)
|
||||
server3.Handler = handler3
|
||||
cores := []*Core{}
|
||||
for i := 0; i < numCores; i++ {
|
||||
coreConfig.RedirectAddr = fmt.Sprintf("https://127.0.0.1:%d", listeners[i][0].Address.Port)
|
||||
if coreConfig.ClusterAddr != "" {
|
||||
coreConfig.ClusterAddr = fmt.Sprintf("https://127.0.0.1:%d", listeners[i][0].Address.Port+105)
|
||||
}
|
||||
c, err := NewCore(coreConfig)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
cores = append(cores, c)
|
||||
if opts != nil && opts.HandlerFunc != nil {
|
||||
handlers[i] = opts.HandlerFunc(c)
|
||||
servers[i].Handler = handlers[i]
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
|
@ -1175,16 +1057,19 @@ func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *Te
|
|||
return ret
|
||||
}
|
||||
|
||||
c2.SetClusterListenerAddrs(clusterAddrGen(c2lns))
|
||||
c2.SetClusterHandler(handler2)
|
||||
c3.SetClusterListenerAddrs(clusterAddrGen(c3lns))
|
||||
c3.SetClusterHandler(handler3)
|
||||
if numCores > 1 {
|
||||
for i := 1; i < numCores; i++ {
|
||||
cores[i].SetClusterListenerAddrs(clusterAddrGen(listeners[i]))
|
||||
cores[i].SetClusterHandler(handlers[i])
|
||||
}
|
||||
}
|
||||
|
||||
keys, root := TestCoreInitClusterWrapperSetup(t, c1, clusterAddrGen(c1lns), handler1)
|
||||
keys, root := TestCoreInitClusterWrapperSetup(t, cores[0], clusterAddrGen(listeners[0]), handlers[0])
|
||||
barrierKeys, _ := copystructure.Copy(keys)
|
||||
testCluster.BarrierKeys = barrierKeys.([][]byte)
|
||||
testCluster.RootToken = root
|
||||
|
||||
// Write root token and barrier keys
|
||||
err = ioutil.WriteFile(filepath.Join(testCluster.TempDir, "root_token"), []byte(root), 0755)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
@ -1201,14 +1086,15 @@ func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *Te
|
|||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Unseal first core
|
||||
for _, key := range keys {
|
||||
if _, err := c1.Unseal(TestKeyCopy(key)); err != nil {
|
||||
if _, err := cores[0].Unseal(TestKeyCopy(key)); err != nil {
|
||||
t.Fatalf("unseal err: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify unsealed
|
||||
sealed, err := c1.Sealed()
|
||||
sealed, err := cores[0].Sealed()
|
||||
if err != nil {
|
||||
t.Fatalf("err checking seal status: %s", err)
|
||||
}
|
||||
|
@ -1216,41 +1102,38 @@ func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *Te
|
|||
t.Fatal("should not be sealed")
|
||||
}
|
||||
|
||||
TestWaitActive(t, c1)
|
||||
TestWaitActive(t, cores[0])
|
||||
|
||||
if opts == nil || !opts.KeepStandbysSealed {
|
||||
for _, key := range keys {
|
||||
if _, err := c2.Unseal(TestKeyCopy(key)); err != nil {
|
||||
t.Fatalf("unseal err: %s", err)
|
||||
}
|
||||
}
|
||||
for _, key := range keys {
|
||||
if _, err := c3.Unseal(TestKeyCopy(key)); err != nil {
|
||||
t.Fatalf("unseal err: %s", err)
|
||||
// Unseal other cores unless otherwise specified
|
||||
if (opts == nil || !opts.KeepStandbysSealed) && numCores > 1 {
|
||||
for i := 1; i < numCores; i++ {
|
||||
for _, key := range keys {
|
||||
if _, err := cores[i].Unseal(TestKeyCopy(key)); err != nil {
|
||||
t.Fatalf("unseal err: %s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Let them come fully up to standby
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
// Ensure cluster connection info is populated
|
||||
isLeader, _, _, err := c2.Leader()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if isLeader {
|
||||
t.Fatal("c2 should not be leader")
|
||||
}
|
||||
isLeader, _, _, err = c3.Leader()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if isLeader {
|
||||
t.Fatal("c3 should not be leader")
|
||||
// Ensure cluster connection info is populated.
|
||||
// Other cores should not come up as leaders.
|
||||
for i := 1; i < numCores; i++ {
|
||||
isLeader, _, _, err := cores[i].Leader()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if isLeader {
|
||||
t.Fatalf("core[%d] should not be leader", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
cluster, err := c1.Cluster()
|
||||
//
|
||||
// Set test cluster core(s) and test cluster
|
||||
//
|
||||
cluster, err := cores[0].Cluster()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -1278,65 +1161,27 @@ func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *Te
|
|||
}
|
||||
|
||||
var ret []*TestClusterCore
|
||||
t1 := &TestClusterCore{
|
||||
Core: c1,
|
||||
ServerKey: s1Key,
|
||||
ServerKeyPEM: s1KeyPEM,
|
||||
ServerCert: s1Cert,
|
||||
ServerCertBytes: s1CertBytes,
|
||||
ServerCertPEM: s1CertPEM,
|
||||
Listeners: c1lns,
|
||||
Handler: handler1,
|
||||
Server: server1,
|
||||
TLSConfig: s1TLSConfig,
|
||||
Client: getAPIClient(c1lns[0].Address.Port, s1TLSConfig),
|
||||
for i := 0; i < numCores; i++ {
|
||||
tcc := &TestClusterCore{
|
||||
Core: cores[i],
|
||||
ServerKey: certInfoSlice[i].key,
|
||||
ServerKeyPEM: certInfoSlice[i].keyPEM,
|
||||
ServerCert: certInfoSlice[i].cert,
|
||||
ServerCertBytes: certInfoSlice[i].certBytes,
|
||||
ServerCertPEM: certInfoSlice[i].certPEM,
|
||||
Listeners: listeners[i],
|
||||
Handler: handlers[i],
|
||||
Server: servers[i],
|
||||
TLSConfig: tlsConfigs[i],
|
||||
Client: getAPIClient(listeners[i][0].Address.Port, tlsConfigs[i]),
|
||||
}
|
||||
tcc.ReloadFuncs = &cores[i].reloadFuncs
|
||||
tcc.ReloadFuncsLock = &cores[i].reloadFuncsLock
|
||||
tcc.ReloadFuncsLock.Lock()
|
||||
(*tcc.ReloadFuncs)["listener|tcp"] = []reload.ReloadFunc{certGetters[i].Reload}
|
||||
tcc.ReloadFuncsLock.Unlock()
|
||||
ret = append(ret, tcc)
|
||||
}
|
||||
t1.ReloadFuncs = &c1.reloadFuncs
|
||||
t1.ReloadFuncsLock = &c1.reloadFuncsLock
|
||||
t1.ReloadFuncsLock.Lock()
|
||||
(*t1.ReloadFuncs)["listener|tcp"] = []reload.ReloadFunc{s1CertGetter.Reload}
|
||||
t1.ReloadFuncsLock.Unlock()
|
||||
ret = append(ret, t1)
|
||||
|
||||
t2 := &TestClusterCore{
|
||||
Core: c2,
|
||||
ServerKey: s2Key,
|
||||
ServerKeyPEM: s2KeyPEM,
|
||||
ServerCert: s2Cert,
|
||||
ServerCertBytes: s2CertBytes,
|
||||
ServerCertPEM: s2CertPEM,
|
||||
Listeners: c2lns,
|
||||
Handler: handler2,
|
||||
Server: server2,
|
||||
TLSConfig: s2TLSConfig,
|
||||
Client: getAPIClient(c2lns[0].Address.Port, s2TLSConfig),
|
||||
}
|
||||
t2.ReloadFuncs = &c2.reloadFuncs
|
||||
t2.ReloadFuncsLock = &c2.reloadFuncsLock
|
||||
t2.ReloadFuncsLock.Lock()
|
||||
(*t2.ReloadFuncs)["listener|tcp"] = []reload.ReloadFunc{s2CertGetter.Reload}
|
||||
t2.ReloadFuncsLock.Unlock()
|
||||
ret = append(ret, t2)
|
||||
|
||||
t3 := &TestClusterCore{
|
||||
Core: c3,
|
||||
ServerKey: s3Key,
|
||||
ServerKeyPEM: s3KeyPEM,
|
||||
ServerCert: s3Cert,
|
||||
ServerCertBytes: s3CertBytes,
|
||||
ServerCertPEM: s3CertPEM,
|
||||
Listeners: c3lns,
|
||||
Handler: handler3,
|
||||
Server: server3,
|
||||
TLSConfig: s3TLSConfig,
|
||||
Client: getAPIClient(c3lns[0].Address.Port, s3TLSConfig),
|
||||
}
|
||||
t3.ReloadFuncs = &c3.reloadFuncs
|
||||
t3.ReloadFuncsLock = &c3.reloadFuncsLock
|
||||
t3.ReloadFuncsLock.Lock()
|
||||
(*t3.ReloadFuncs)["listener|tcp"] = []reload.ReloadFunc{s3CertGetter.Reload}
|
||||
t3.ReloadFuncsLock.Unlock()
|
||||
ret = append(ret, t3)
|
||||
|
||||
testCluster.Cores = ret
|
||||
return &testCluster
|
||||
|
|
Loading…
Reference in New Issue