diff --git a/builtin/plugin/backend.go b/builtin/plugin/backend.go index 3945f78d2..a1c781f57 100644 --- a/builtin/plugin/backend.go +++ b/builtin/plugin/backend.go @@ -3,13 +3,20 @@ package plugin import ( "fmt" "net/rpc" + "reflect" "sync" uuid "github.com/hashicorp/go-uuid" "github.com/hashicorp/vault/logical" + "github.com/hashicorp/vault/logical/framework" bplugin "github.com/hashicorp/vault/logical/plugin" ) +var ( + ErrMismatchType = fmt.Errorf("mismatch on mounted backend and plugin backend type") + ErrMismatchPaths = fmt.Errorf("mismatch on mounted backend and plugin backend special paths") +) + // Factory returns a configured plugin logical.Backend. func Factory(conf *logical.BackendConfig) (logical.Backend, error) { _, ok := conf.Config["plugin_name"] @@ -31,14 +38,33 @@ func Factory(conf *logical.BackendConfig) (logical.Backend, error) { // or as a concrete implementation if builtin, casted as logical.Backend. func Backend(conf *logical.BackendConfig) (logical.Backend, error) { var b backend + name := conf.Config["plugin_name"] sys := conf.System - raw, err := bplugin.NewBackend(name, sys, conf.Logger) + // NewBackend with isMetadataMode set to true + raw, err := bplugin.NewBackend(name, sys, conf.Logger, true) if err != nil { return nil, err } - b.Backend = raw + err = raw.Setup(conf) + if err != nil { + return nil, err + } + // Get SpecialPaths and BackendType + paths := raw.SpecialPaths() + btype := raw.Type() + + // Cleanup meta plugin backend + raw.Cleanup() + + // Initialize b.Backend with dummy backend since plugin + // backends will need to be lazy loaded. + b.Backend = &framework.Backend{ + PathsSpecial: paths, + BackendType: btype, + } + b.config = conf return &b, nil @@ -53,16 +79,24 @@ type backend struct { // Used to detect if we already reloaded canary string + + // Used to detect if plugin is set + loaded bool } func (b *backend) reloadBackend() error { + b.Logger().Trace("plugin: reloading plugin backend", "plugin", b.config.Config["plugin_name"]) + return b.startBackend() +} + +// startBackend starts a plugin backend +func (b *backend) startBackend() error { pluginName := b.config.Config["plugin_name"] - b.Logger().Trace("plugin: reloading plugin backend", "plugin", pluginName) // Ensure proper cleanup of the backend (i.e. call client.Kill()) b.Backend.Cleanup() - nb, err := bplugin.NewBackend(pluginName, b.config.System, b.config.Logger) + nb, err := bplugin.NewBackend(pluginName, b.config.System, b.config.Logger, false) if err != nil { return err } @@ -70,7 +104,29 @@ func (b *backend) reloadBackend() error { if err != nil { return err } + + // If the backend has not been loaded (i.e. still in metadata mode), + // check if type and special paths still matches + if !b.loaded { + if b.Backend.Type() != nb.Type() { + nb.Cleanup() + b.Logger().Warn("plugin: failed to start plugin process", "plugin", b.config.Config["plugin_name"], "error", ErrMismatchType) + return ErrMismatchType + } + if !reflect.DeepEqual(b.Backend.SpecialPaths(), nb.SpecialPaths()) { + nb.Cleanup() + b.Logger().Warn("plugin: failed to start plugin process", "plugin", b.config.Config["plugin_name"], "error", ErrMismatchPaths) + return ErrMismatchPaths + } + } + b.Backend = nb + b.loaded = true + + // Call initialize + if err := b.Backend.Initialize(); err != nil { + return err + } return nil } @@ -79,6 +135,23 @@ func (b *backend) reloadBackend() error { func (b *backend) HandleRequest(req *logical.Request) (*logical.Response, error) { b.RLock() canary := b.canary + + // Lazy-load backend + if !b.loaded { + // Upgrade lock + b.RUnlock() + b.Lock() + // Check once more after lock swap + if !b.loaded { + err := b.startBackend() + if err != nil { + b.Unlock() + return nil, err + } + } + b.Unlock() + b.RLock() + } resp, err := b.Backend.HandleRequest(req) b.RUnlock() // Need to compare string value for case were err comes from plugin RPC @@ -112,6 +185,24 @@ func (b *backend) HandleRequest(req *logical.Request) (*logical.Response, error) func (b *backend) HandleExistenceCheck(req *logical.Request) (bool, bool, error) { b.RLock() canary := b.canary + + // Lazy-load backend + if !b.loaded { + // Upgrade lock + b.RUnlock() + b.Lock() + // Check once more after lock swap + if !b.loaded { + err := b.startBackend() + if err != nil { + b.Unlock() + return false, false, err + } + } + b.Unlock() + b.RLock() + } + checkFound, exists, err := b.Backend.HandleExistenceCheck(req) b.RUnlock() if err != nil && err.Error() == rpc.ErrShutdown.Error() { diff --git a/builtin/plugin/backend_test.go b/builtin/plugin/backend_test.go index 0a37691d6..5b0719709 100644 --- a/builtin/plugin/backend_test.go +++ b/builtin/plugin/backend_test.go @@ -1,6 +1,7 @@ package plugin import ( + "fmt" "os" "testing" @@ -39,7 +40,8 @@ func TestBackend_Factory(t *testing.T) { } func TestBackend_PluginMain(t *testing.T) { - if os.Getenv(pluginutil.PluginUnwrapTokenEnv) == "" { + args := []string{} + if os.Getenv(pluginutil.PluginUnwrapTokenEnv) == "" && os.Getenv(pluginutil.PluginMetadaModeEnv) != "true" { return } @@ -48,7 +50,7 @@ func TestBackend_PluginMain(t *testing.T) { t.Fatal("CA cert not passed in") } - args := []string{"--ca-cert=" + caPEM} + args = append(args, fmt.Sprintf("--ca-cert=%s", caPEM)) apiClientMeta := &pluginutil.APIClientMeta{} flags := apiClientMeta.FlagSet() diff --git a/helper/pluginutil/mlock.go b/helper/pluginutil/mlock.go index dd9115a89..1660ca8e0 100644 --- a/helper/pluginutil/mlock.go +++ b/helper/pluginutil/mlock.go @@ -7,7 +7,7 @@ import ( ) var ( - // PluginUnwrapTokenEnv is the ENV name used to pass the configuration for + // PluginMlockEnabled is the ENV name used to pass the configuration for // enabling mlock PluginMlockEnabled = "VAULT_PLUGIN_MLOCK_ENABLED" ) diff --git a/helper/pluginutil/runner.go b/helper/pluginutil/runner.go index e34f070c2..2047651ed 100644 --- a/helper/pluginutil/runner.go +++ b/helper/pluginutil/runner.go @@ -2,6 +2,7 @@ package pluginutil import ( "crypto/sha256" + "crypto/tls" "flag" "fmt" "os/exec" @@ -22,6 +23,7 @@ type Looker interface { // Wrapper interface defines the functions needed by the runner to wrap the // metadata needed to run a plugin process. This includes looking up Mlock // configuration and wrapping data in a respose wrapped token. +// logical.SystemView implementataions satisfy this interface. type RunnerUtil interface { ResponseWrapData(data map[string]interface{}, ttl time.Duration, jwt bool) (*wrapping.ResponseWrapInfo, error) MlockEnabled() bool @@ -44,56 +46,82 @@ type PluginRunner struct { BuiltinFactory func() (interface{}, error) `json:"-" structs:"-"` } -// Run takes a wrapper instance, and the go-plugin paramaters and executes a -// plugin. +// Run takes a wrapper RunnerUtil instance along with the go-plugin paramaters and +// returns a configured plugin.Client with TLS Configured and a wrapping token set +// on PluginUnwrapTokenEnv for plugin process consumption. func (r *PluginRunner) Run(wrapper RunnerUtil, pluginMap map[string]plugin.Plugin, hs plugin.HandshakeConfig, env []string, logger log.Logger) (*plugin.Client, error) { - // Get a CA TLS Certificate - certBytes, key, err := generateCert() - if err != nil { - return nil, err - } + return r.runCommon(wrapper, pluginMap, hs, env, logger, false) +} - // Use CA to sign a client cert and return a configured TLS config - clientTLSConfig, err := createClientTLSConfig(certBytes, key) - if err != nil { - return nil, err - } +// RunMetadataMode returns a configured plugin.Client that will dispense a plugin +// in metadata mode. The PluginMetadaModeEnv is passed in as part of the Cmd to +// plugin.Client, and consumed by the plugin process on pluginutil.VaultPluginTLSProvider. +func (r *PluginRunner) RunMetadataMode(wrapper RunnerUtil, pluginMap map[string]plugin.Plugin, hs plugin.HandshakeConfig, env []string, logger log.Logger) (*plugin.Client, error) { + return r.runCommon(wrapper, pluginMap, hs, env, logger, true) - // Use CA to sign a server cert and wrap the values in a response wrapped - // token. - wrapToken, err := wrapServerConfig(wrapper, certBytes, key) - if err != nil { - return nil, err - } +} +func (r *PluginRunner) runCommon(wrapper RunnerUtil, pluginMap map[string]plugin.Plugin, hs plugin.HandshakeConfig, env []string, logger log.Logger, isMetadataMode bool) (*plugin.Client, error) { cmd := exec.Command(r.Command, r.Args...) cmd.Env = append(cmd.Env, env...) - // Add the response wrap token to the ENV of the plugin - cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", PluginUnwrapTokenEnv, wrapToken)) + // Add the mlock setting to the ENV of the plugin if wrapper.MlockEnabled() { cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", PluginMlockEnabled, "true")) } - secureConfig := &plugin.SecureConfig{ - Checksum: r.Sha256, - Hash: sha256.New(), - } - // Create logger for the plugin client clogger := &hclogFaker{ logger: logger, } namedLogger := clogger.ResetNamed("plugin") - client := plugin.NewClient(&plugin.ClientConfig{ + var clientTLSConfig *tls.Config + if !isMetadataMode { + // Add the metadata mode ENV and set it to false + cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", PluginMetadaModeEnv, "false")) + + // Get a CA TLS Certificate + certBytes, key, err := generateCert() + if err != nil { + return nil, err + } + + // Use CA to sign a client cert and return a configured TLS config + clientTLSConfig, err = createClientTLSConfig(certBytes, key) + if err != nil { + return nil, err + } + + // Use CA to sign a server cert and wrap the values in a response wrapped + // token. + wrapToken, err := wrapServerConfig(wrapper, certBytes, key) + if err != nil { + return nil, err + } + + // Add the response wrap token to the ENV of the plugin + cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", PluginUnwrapTokenEnv, wrapToken)) + } else { + namedLogger = clogger.ResetNamed("plugin.metadata") + cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", PluginMetadaModeEnv, "true")) + } + + secureConfig := &plugin.SecureConfig{ + Checksum: r.Sha256, + Hash: sha256.New(), + } + + clientConfig := &plugin.ClientConfig{ HandshakeConfig: hs, Plugins: pluginMap, Cmd: cmd, - TLSConfig: clientTLSConfig, SecureConfig: secureConfig, + TLSConfig: clientTLSConfig, Logger: namedLogger, - }) + } + + client := plugin.NewClient(clientConfig) return client, nil } @@ -108,7 +136,7 @@ type APIClientMeta struct { } func (f *APIClientMeta) FlagSet() *flag.FlagSet { - fs := flag.NewFlagSet("tls settings", flag.ContinueOnError) + fs := flag.NewFlagSet("vault plugin settings", flag.ContinueOnError) fs.StringVar(&f.flagCACert, "ca-cert", "", "") fs.StringVar(&f.flagCAPath, "ca-path", "", "") diff --git a/helper/pluginutil/tls.go b/helper/pluginutil/tls.go index d31344e3c..112d33cf0 100644 --- a/helper/pluginutil/tls.go +++ b/helper/pluginutil/tls.go @@ -29,6 +29,10 @@ var ( // PluginCACertPEMEnv is an ENV name used for holding a CA PEM-encoded // string. Used for testing. PluginCACertPEMEnv = "VAULT_TESTING_PLUGIN_CA_PEM" + + // PluginMetadaModeEnv is an ENV name used to disable TLS communication + // to bootstrap mounting plugins. + PluginMetadaModeEnv = "VAULT_PLUGIN_METADATA_MODE" ) // generateCert is used internally to create certificates for the plugin @@ -124,6 +128,10 @@ func wrapServerConfig(sys RunnerUtil, certBytes []byte, key *ecdsa.PrivateKey) ( // VaultPluginTLSProvider is run inside a plugin and retrives the response // wrapped TLS certificate from vault. It returns a configured TLS Config. func VaultPluginTLSProvider(apiTLSConfig *api.TLSConfig) func() (*tls.Config, error) { + if os.Getenv(PluginMetadaModeEnv) == "true" { + return nil + } + return func() (*tls.Config, error) { unwrapToken := os.Getenv(PluginUnwrapTokenEnv) @@ -157,7 +165,10 @@ func VaultPluginTLSProvider(apiTLSConfig *api.TLSConfig) func() (*tls.Config, er clientConf := api.DefaultConfig() clientConf.Address = vaultAddr if apiTLSConfig != nil { - clientConf.ConfigureTLS(apiTLSConfig) + err := clientConf.ConfigureTLS(apiTLSConfig) + if err != nil { + return nil, errwrap.Wrapf("error configuring api client {{err}}", err) + } } client, err := api.NewClient(clientConf) if err != nil { diff --git a/logical/plugin/backend.go b/logical/plugin/backend.go index 8a80be029..081922c9b 100644 --- a/logical/plugin/backend.go +++ b/logical/plugin/backend.go @@ -9,7 +9,8 @@ import ( // BackendPlugin is the plugin.Plugin implementation type BackendPlugin struct { - Factory func(*logical.BackendConfig) (logical.Backend, error) + Factory func(*logical.BackendConfig) (logical.Backend, error) + metadataMode bool } // Server gets called when on plugin.Serve() @@ -19,5 +20,5 @@ func (b *BackendPlugin) Server(broker *plugin.MuxBroker) (interface{}, error) { // Client gets called on plugin.NewClient() func (b BackendPlugin) Client(broker *plugin.MuxBroker, c *rpc.Client) (interface{}, error) { - return &backendPluginClient{client: c, broker: broker}, nil + return &backendPluginClient{client: c, broker: broker, metadataMode: b.metadataMode}, nil } diff --git a/logical/plugin/backend_client.go b/logical/plugin/backend_client.go index f18564b13..cc2d83bcf 100644 --- a/logical/plugin/backend_client.go +++ b/logical/plugin/backend_client.go @@ -1,6 +1,7 @@ package plugin import ( + "errors" "net/rpc" "github.com/hashicorp/go-plugin" @@ -8,11 +9,16 @@ import ( log "github.com/mgutz/logxi/v1" ) +var ( + ErrClientInMetadataMode = errors.New("plugin client can not perform action while in metadata mode") +) + // backendPluginClient implements logical.Backend and is the // go-plugin client. type backendPluginClient struct { - broker *plugin.MuxBroker - client *rpc.Client + broker *plugin.MuxBroker + client *rpc.Client + metadataMode bool system logical.SystemView logger log.Logger @@ -83,6 +89,10 @@ type RegisterLicenseReply struct { } func (b *backendPluginClient) HandleRequest(req *logical.Request) (*logical.Response, error) { + if b.metadataMode { + return nil, ErrClientInMetadataMode + } + // Do not send the storage, since go-plugin cannot serialize // interfaces. The server will pick up the storage from the shim. req.Storage = nil @@ -136,6 +146,10 @@ func (b *backendPluginClient) Logger() log.Logger { } func (b *backendPluginClient) HandleExistenceCheck(req *logical.Request) (bool, bool, error) { + if b.metadataMode { + return false, false, ErrClientInMetadataMode + } + // Do not send the storage, since go-plugin cannot serialize // interfaces. The server will pick up the storage from the shim. req.Storage = nil @@ -172,31 +186,49 @@ func (b *backendPluginClient) Cleanup() { } func (b *backendPluginClient) Initialize() error { + if b.metadataMode { + return ErrClientInMetadataMode + } err := b.client.Call("Plugin.Initialize", new(interface{}), &struct{}{}) return err } func (b *backendPluginClient) InvalidateKey(key string) { + if b.metadataMode { + return + } b.client.Call("Plugin.InvalidateKey", key, &struct{}{}) } func (b *backendPluginClient) Setup(config *logical.BackendConfig) error { // Shim logical.Storage + storageImpl := config.StorageView + if b.metadataMode { + storageImpl = &NOOPStorage{} + } storageID := b.broker.NextId() go b.broker.AcceptAndServe(storageID, &StorageServer{ - impl: config.StorageView, + impl: storageImpl, }) // Shim log.Logger + loggerImpl := config.Logger + if b.metadataMode { + loggerImpl = log.NullLog + } loggerID := b.broker.NextId() go b.broker.AcceptAndServe(loggerID, &LoggerServer{ - logger: config.Logger, + logger: loggerImpl, }) // Shim logical.SystemView + sysViewImpl := config.System + if b.metadataMode { + sysViewImpl = &logical.StaticSystemView{} + } sysViewID := b.broker.NextId() go b.broker.AcceptAndServe(sysViewID, &SystemViewServer{ - impl: config.System, + impl: sysViewImpl, }) args := &SetupArgs{ @@ -233,6 +265,10 @@ func (b *backendPluginClient) Type() logical.BackendType { } func (b *backendPluginClient) RegisterLicense(license interface{}) error { + if b.metadataMode { + return ErrClientInMetadataMode + } + var reply RegisterLicenseReply args := RegisterLicenseArgs{ License: license, diff --git a/logical/plugin/backend_server.go b/logical/plugin/backend_server.go index 335bfa5f7..47045b1c1 100644 --- a/logical/plugin/backend_server.go +++ b/logical/plugin/backend_server.go @@ -1,12 +1,19 @@ package plugin import ( + "errors" "net/rpc" + "os" "github.com/hashicorp/go-plugin" + "github.com/hashicorp/vault/helper/pluginutil" "github.com/hashicorp/vault/logical" ) +var ( + ErrServerInMetadataMode = errors.New("plugin server can not perform action while in metadata mode") +) + // backendPluginServer is the RPC server that backendPluginClient talks to, // it methods conforming to requirements by net/rpc type backendPluginServer struct { @@ -19,7 +26,15 @@ type backendPluginServer struct { storageClient *rpc.Client } +func inMetadataMode() bool { + return os.Getenv(pluginutil.PluginMetadaModeEnv) == "true" +} + func (b *backendPluginServer) HandleRequest(args *HandleRequestArgs, reply *HandleRequestReply) error { + if inMetadataMode() { + return ErrServerInMetadataMode + } + storage := &StorageClient{client: b.storageClient} args.Request.Storage = storage @@ -40,6 +55,10 @@ func (b *backendPluginServer) SpecialPaths(_ interface{}, reply *SpecialPathsRep } func (b *backendPluginServer) HandleExistenceCheck(args *HandleExistenceCheckArgs, reply *HandleExistenceCheckReply) error { + if inMetadataMode() { + return ErrServerInMetadataMode + } + storage := &StorageClient{client: b.storageClient} args.Request.Storage = storage @@ -64,11 +83,19 @@ func (b *backendPluginServer) Cleanup(_ interface{}, _ *struct{}) error { } func (b *backendPluginServer) Initialize(_ interface{}, _ *struct{}) error { + if inMetadataMode() { + return ErrServerInMetadataMode + } + err := b.backend.Initialize() return err } func (b *backendPluginServer) InvalidateKey(args string, _ *struct{}) error { + if inMetadataMode() { + return ErrServerInMetadataMode + } + b.backend.InvalidateKey(args) return nil } @@ -145,6 +172,10 @@ func (b *backendPluginServer) Type(_ interface{}, reply *TypeReply) error { } func (b *backendPluginServer) RegisterLicense(args *RegisterLicenseArgs, reply *RegisterLicenseReply) error { + if inMetadataMode() { + return ErrServerInMetadataMode + } + err := b.backend.RegisterLicense(args.License) if err != nil { *reply = RegisterLicenseReply{ diff --git a/logical/plugin/mock/backend.go b/logical/plugin/mock/backend.go index 5f4c97749..ac8c0ba88 100644 --- a/logical/plugin/mock/backend.go +++ b/logical/plugin/mock/backend.go @@ -43,6 +43,7 @@ func Backend() *backend { kvPaths(&b), []*framework.Path{ pathInternal(&b), + pathSpecial(&b), }, ), PathsSpecial: &logical.Paths{ diff --git a/logical/plugin/mock/mock-plugin/main.go b/logical/plugin/mock/mock-plugin/main.go index 1cb47bfcf..b1b7fbd71 100644 --- a/logical/plugin/mock/mock-plugin/main.go +++ b/logical/plugin/mock/mock-plugin/main.go @@ -13,7 +13,7 @@ import ( func main() { apiClientMeta := &pluginutil.APIClientMeta{} flags := apiClientMeta.FlagSet() - flags.Parse(os.Args) + flags.Parse(os.Args[1:]) // Ignore command, strictly parse flags tlsConfig := apiClientMeta.GetTLSConfig() tlsProviderFunc := pluginutil.VaultPluginTLSProvider(tlsConfig) diff --git a/logical/plugin/mock/path_special.go b/logical/plugin/mock/path_special.go new file mode 100644 index 000000000..f695e209e --- /dev/null +++ b/logical/plugin/mock/path_special.go @@ -0,0 +1,27 @@ +package mock + +import ( + "github.com/hashicorp/vault/logical" + "github.com/hashicorp/vault/logical/framework" +) + +// pathSpecial is used to test special paths. +func pathSpecial(b *backend) *framework.Path { + return &framework.Path{ + Pattern: "special", + Callbacks: map[logical.Operation]framework.OperationFunc{ + logical.ReadOperation: b.pathSpecialRead, + }, + } +} + +func (b *backend) pathSpecialRead( + req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + // Return the secret + return &logical.Response{ + Data: map[string]interface{}{ + "data": "foo", + }, + }, nil + +} diff --git a/logical/plugin/plugin.go b/logical/plugin/plugin.go index eeb6e073c..7eba9ed5c 100644 --- a/logical/plugin/plugin.go +++ b/logical/plugin/plugin.go @@ -40,8 +40,9 @@ func (b *BackendPluginClient) Cleanup() { // NewBackend will return an instance of an RPC-based client implementation of the backend for // external plugins, or a concrete implementation of the backend if it is a builtin backend. -// The backend is returned as a logical.Backend interface. -func NewBackend(pluginName string, sys pluginutil.LookRunnerUtil, logger log.Logger) (logical.Backend, error) { +// The backend is returned as a logical.Backend interface. The isMetadataMode param determines whether +// the plugin should run in metadata mode. +func NewBackend(pluginName string, sys pluginutil.LookRunnerUtil, logger log.Logger, isMetadataMode bool) (logical.Backend, error) { // Look for plugin in the plugin catalog pluginRunner, err := sys.LookupPlugin(pluginName) if err != nil { @@ -65,7 +66,7 @@ func NewBackend(pluginName string, sys pluginutil.LookRunnerUtil, logger log.Log } else { // create a backendPluginClient instance - backend, err = newPluginClient(sys, pluginRunner, logger) + backend, err = newPluginClient(sys, pluginRunner, logger, isMetadataMode) if err != nil { return nil, err } @@ -74,12 +75,21 @@ func NewBackend(pluginName string, sys pluginutil.LookRunnerUtil, logger log.Log return backend, nil } -func newPluginClient(sys pluginutil.RunnerUtil, pluginRunner *pluginutil.PluginRunner, logger log.Logger) (logical.Backend, error) { +func newPluginClient(sys pluginutil.RunnerUtil, pluginRunner *pluginutil.PluginRunner, logger log.Logger, isMetadataMode bool) (logical.Backend, error) { // pluginMap is the map of plugins we can dispense. pluginMap := map[string]plugin.Plugin{ - "backend": &BackendPlugin{}, + "backend": &BackendPlugin{ + metadataMode: isMetadataMode, + }, + } + + var client *plugin.Client + var err error + if isMetadataMode { + client, err = pluginRunner.RunMetadataMode(sys, pluginMap, handshakeConfig, []string{}, logger) + } else { + client, err = pluginRunner.Run(sys, pluginMap, handshakeConfig, []string{}, logger) } - client, err := pluginRunner.Run(sys, pluginMap, handshakeConfig, []string{}, logger) if err != nil { return nil, err } diff --git a/logical/plugin/serve.go b/logical/plugin/serve.go index 4eb69a8c5..7a52754c3 100644 --- a/logical/plugin/serve.go +++ b/logical/plugin/serve.go @@ -20,7 +20,8 @@ type ServeOpts struct { TLSProviderFunc TLSProdiverFunc } -// Serve is used to serve a backend plugin +// Serve is a helper function used to serve a backend plugin. This +// should be ran on the plugin's main process. func Serve(opts *ServeOpts) error { // pluginMap is the map of plugins we can dispense. var pluginMap = map[string]plugin.Plugin{ @@ -34,6 +35,7 @@ func Serve(opts *ServeOpts) error { return err } + // If FetchMetadata is true, run without TLSProvider plugin.Serve(&plugin.ServeConfig{ HandshakeConfig: handshakeConfig, Plugins: pluginMap, diff --git a/logical/plugin/storage.go b/logical/plugin/storage.go index 55cea8449..99c21f646 100644 --- a/logical/plugin/storage.go +++ b/logical/plugin/storage.go @@ -117,3 +117,23 @@ type StoragePutReply struct { type StorageDeleteReply struct { Error *plugin.BasicError } + +// NOOPStorage is used to deny access to the storage interface while running a +// backend plugin in metadata mode. +type NOOPStorage struct{} + +func (s *NOOPStorage) List(prefix string) ([]string, error) { + return []string{}, nil +} + +func (s *NOOPStorage) Get(key string) (*logical.StorageEntry, error) { + return nil, nil +} + +func (s *NOOPStorage) Put(entry *logical.StorageEntry) error { + return nil +} + +func (s *NOOPStorage) Delete(key string) error { + return nil +} diff --git a/vault/auth.go b/vault/auth.go index 608a217a5..3cd93be2a 100644 --- a/vault/auth.go +++ b/vault/auth.go @@ -5,6 +5,7 @@ import ( "fmt" "strings" + "github.com/hashicorp/errwrap" "github.com/hashicorp/go-uuid" "github.com/hashicorp/vault/helper/jsonutil" "github.com/hashicorp/vault/logical" @@ -397,7 +398,6 @@ func (c *Core) persistAuth(table *MountTable, localOnly bool) error { // setupCredentials is invoked after we've loaded the auth table to // initialize the credential backends and setup the router func (c *Core) setupCredentials() error { - var backend logical.Backend var view *BarrierView var err error var persistNeeded bool @@ -406,6 +406,7 @@ func (c *Core) setupCredentials() error { defer c.authLock.Unlock() for _, entry := range c.auth.Entries { + var backend logical.Backend // Work around some problematic code that existed in master for a while if strings.HasPrefix(entry.Path, credentialRoutePrefix) { entry.Path = strings.TrimPrefix(entry.Path, credentialRoutePrefix) @@ -425,6 +426,9 @@ func (c *Core) setupCredentials() error { backend, err = c.newCredentialBackend(entry.Type, sysView, view, conf) if err != nil { c.logger.Error("core: failed to create credential entry", "path", entry.Path, "error", err) + if errwrap.Contains(err, ErrPluginNotFound.Error()) && entry.Type == "plugin" { + goto ROUTER_MOUNT + } return errLoadAuthFailed } if backend == nil { @@ -432,15 +436,14 @@ func (c *Core) setupCredentials() error { } // Check for the correct backend type - backendType := backend.Type() - if entry.Type == "plugin" && backendType != logical.TypeCredential { - return fmt.Errorf("cannot mount '%s' of type '%s' as an auth backend", entry.Config.PluginName, backendType) + if entry.Type == "plugin" && backend.Type() != logical.TypeCredential { + return fmt.Errorf("cannot mount '%s' of type '%s' as an auth backend", entry.Config.PluginName, backend.Type()) } if err := backend.Initialize(); err != nil { return err } - + ROUTER_MOUNT: // Mount the backend path := credentialRoutePrefix + entry.Path err = c.router.Mount(backend, path, entry, view) diff --git a/vault/core.go b/vault/core.go index e63e29aaf..b79f3fbde 100644 --- a/vault/core.go +++ b/vault/core.go @@ -1369,9 +1369,6 @@ func (c *Core) postUnseal() (retErr error) { if err := c.setupMounts(); err != nil { return err } - if err := c.startRollback(); err != nil { - return err - } if err := c.setupPolicyStore(); err != nil { return err } @@ -1384,6 +1381,9 @@ func (c *Core) postUnseal() (retErr error) { if err := c.setupCredentials(); err != nil { return err } + if err := c.startRollback(); err != nil { + return err + } if err := c.setupExpiration(); err != nil { return err } diff --git a/vault/dynamic_system_view.go b/vault/dynamic_system_view.go index af6cb7fbe..3a6456364 100644 --- a/vault/dynamic_system_view.go +++ b/vault/dynamic_system_view.go @@ -4,6 +4,8 @@ import ( "fmt" "time" + "github.com/hashicorp/errwrap" + "github.com/hashicorp/vault/helper/consts" "github.com/hashicorp/vault/helper/pluginutil" "github.com/hashicorp/vault/helper/wrapping" @@ -132,7 +134,7 @@ func (d dynamicSystemView) LookupPlugin(name string) (*pluginutil.PluginRunner, return nil, err } if r == nil { - return nil, fmt.Errorf("no plugin found with name: %s", name) + return nil, errwrap.Wrapf(fmt.Sprintf("{{err}}: %s", name), ErrPluginNotFound) } return r, nil diff --git a/vault/logical_system_integ_test.go b/vault/logical_system_integ_test.go index 599c5de2a..60eab6b69 100644 --- a/vault/logical_system_integ_test.go +++ b/vault/logical_system_integ_test.go @@ -4,6 +4,7 @@ import ( "fmt" "os" "testing" + "time" "github.com/hashicorp/vault/builtin/plugin" "github.com/hashicorp/vault/helper/pluginutil" @@ -15,17 +16,196 @@ import ( ) func TestSystemBackend_Plugin_secret(t *testing.T) { - cluster := testSystemBackendMock(t, 2, logical.TypeLogical) + cluster := testSystemBackendMock(t, 1, 1, logical.TypeLogical) defer cluster.Cleanup() + + core := cluster.Cores[0] + + // Make a request to lazy load the plugin + req := logical.TestRequest(t, logical.ReadOperation, "mock-0/internal") + req.ClientToken = core.Client.Token() + resp, err := core.HandleRequest(req) + if err != nil { + t.Fatalf("err: %v", err) + } + if resp == nil { + t.Fatalf("bad: response should not be nil") + } + + // Seal the cluster + cluster.EnsureCoresSealed(t) + + // Unseal the cluster + barrierKeys := cluster.BarrierKeys + for _, core := range cluster.Cores { + for _, key := range barrierKeys { + _, err := core.Unseal(vault.TestKeyCopy(key)) + if err != nil { + t.Fatal(err) + } + } + sealed, err := core.Sealed() + if err != nil { + t.Fatalf("err checking seal status: %s", err) + } + if sealed { + t.Fatal("should not be sealed") + } + // Wait for active so post-unseal takes place + // If it fails, it means unseal process failed + vault.TestWaitActive(t, core.Core) + } } func TestSystemBackend_Plugin_auth(t *testing.T) { - cluster := testSystemBackendMock(t, 2, logical.TypeCredential) + cluster := testSystemBackendMock(t, 1, 1, logical.TypeCredential) defer cluster.Cleanup() + + core := cluster.Cores[0] + + // Make a request to lazy load the plugin + req := logical.TestRequest(t, logical.ReadOperation, "auth/mock-0/internal") + req.ClientToken = core.Client.Token() + resp, err := core.HandleRequest(req) + if err != nil { + t.Fatalf("err: %v", err) + } + if resp == nil { + t.Fatalf("bad: response should not be nil") + } + + // Seal the cluster + cluster.EnsureCoresSealed(t) + + // Unseal the cluster + barrierKeys := cluster.BarrierKeys + for _, core := range cluster.Cores { + for _, key := range barrierKeys { + _, err := core.Unseal(vault.TestKeyCopy(key)) + if err != nil { + t.Fatal(err) + } + } + sealed, err := core.Sealed() + if err != nil { + t.Fatalf("err checking seal status: %s", err) + } + if sealed { + t.Fatal("should not be sealed") + } + // Wait for active so post-unseal takes place + // If it fails, it means unseal process failed + vault.TestWaitActive(t, core.Core) + } +} + +func TestSystemBackend_Plugin_MismatchType(t *testing.T) { + cluster := testSystemBackendMock(t, 1, 1, logical.TypeLogical) + defer cluster.Cleanup() + + core := cluster.Cores[0] + + // Replace the plugin with a credential backend + vault.TestAddTestPlugin(t, core.Core, "mock-plugin", "TestBackend_PluginMainCredentials") + + // Make a request to lazy load the now-credential plugin + // and expect an error + req := logical.TestRequest(t, logical.ReadOperation, "mock-0/internal") + req.ClientToken = core.Client.Token() + _, err := core.HandleRequest(req) + if err == nil { + t.Fatalf("expected error due to mismatch on error type: %s", err) + } + + // Sleep a bit before cleanup is called + time.Sleep(1 * time.Second) +} + +func TestSystemBackend_Plugin_CatalogRemoved(t *testing.T) { + t.Run("secret", func(t *testing.T) { + testPlugin_CatalogRemoved(t, logical.TypeLogical, false) + }) + + t.Run("auth", func(t *testing.T) { + testPlugin_CatalogRemoved(t, logical.TypeCredential, false) + }) + + t.Run("secret-mount-existing", func(t *testing.T) { + testPlugin_CatalogRemoved(t, logical.TypeLogical, true) + }) + + t.Run("auth-mount-existing", func(t *testing.T) { + testPlugin_CatalogRemoved(t, logical.TypeCredential, true) + }) +} + +func testPlugin_CatalogRemoved(t *testing.T, btype logical.BackendType, testMount bool) { + cluster := testSystemBackendMock(t, 1, 1, btype) + defer cluster.Cleanup() + + core := cluster.Cores[0] + + // Remove the plugin from the catalog + req := logical.TestRequest(t, logical.DeleteOperation, "sys/plugins/catalog/mock-plugin") + req.ClientToken = core.Client.Token() + resp, err := core.HandleRequest(req) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%v resp:%#v", err, resp) + } + + // Seal the cluster + cluster.EnsureCoresSealed(t) + + // Unseal the cluster + barrierKeys := cluster.BarrierKeys + for _, core := range cluster.Cores { + for _, key := range barrierKeys { + _, err := core.Unseal(vault.TestKeyCopy(key)) + if err != nil { + t.Fatal(err) + } + } + sealed, err := core.Sealed() + if err != nil { + t.Fatalf("err checking seal status: %s", err) + } + if sealed { + t.Fatal("should not be sealed") + } + // Wait for active so post-unseal takes place + // If it fails, it means unseal process failed + vault.TestWaitActive(t, core.Core) + } + + if testMount { + // Add plugin back to the catalog + vault.TestAddTestPlugin(t, core.Core, "mock-plugin", "TestBackend_PluginMainLogical") + + // Mount the plugin at the same path after plugin is re-added to the catalog + // and expect an error due to existing path. + var err error + switch btype { + case logical.TypeLogical: + _, err = core.Client.Logical().Write("sys/mounts/mock-0", map[string]interface{}{ + "type": "plugin", + "config": map[string]interface{}{ + "plugin_name": "mock-plugin", + }, + }) + case logical.TypeCredential: + _, err = core.Client.Logical().Write("sys/auth/mock-0", map[string]interface{}{ + "type": "plugin", + "plugin_name": "mock-plugin", + }) + } + if err == nil { + t.Fatal("expected error when mounting on existing path") + } + } } func TestSystemBackend_Plugin_autoReload(t *testing.T) { - cluster := testSystemBackendMock(t, 1, logical.TypeLogical) + cluster := testSystemBackendMock(t, 1, 1, logical.TypeLogical) defer cluster.Cleanup() core := cluster.Cores[0] @@ -65,6 +245,35 @@ func TestSystemBackend_Plugin_autoReload(t *testing.T) { } } +func TestSystemBackend_Plugin_SealUnseal(t *testing.T) { + cluster := testSystemBackendMock(t, 1, 1, logical.TypeLogical) + defer cluster.Cleanup() + + // Seal the cluster + cluster.EnsureCoresSealed(t) + + // Unseal the cluster + barrierKeys := cluster.BarrierKeys + for _, core := range cluster.Cores { + for _, key := range barrierKeys { + _, err := core.Unseal(vault.TestKeyCopy(key)) + if err != nil { + t.Fatal(err) + } + } + sealed, err := core.Sealed() + if err != nil { + t.Fatalf("err checking seal status: %s", err) + } + if sealed { + t.Fatal("should not be sealed") + } + // Wait for active so post-unseal takes place + // If it fails, it means unseal process failed + vault.TestWaitActive(t, core.Core) + } +} + func TestSystemBackend_Plugin_reload(t *testing.T) { data := map[string]interface{}{ "plugin": "mock-plugin", @@ -77,8 +286,9 @@ func TestSystemBackend_Plugin_reload(t *testing.T) { t.Run("mounts", func(t *testing.T) { testSystemBackend_PluginReload(t, data) }) } +// Helper func to test different reload methods on plugin reload endpoint func testSystemBackend_PluginReload(t *testing.T, reqData map[string]interface{}) { - cluster := testSystemBackendMock(t, 2, logical.TypeLogical) + cluster := testSystemBackendMock(t, 1, 2, logical.TypeLogical) defer cluster.Cleanup() core := cluster.Cores[0] @@ -123,7 +333,7 @@ func testSystemBackend_PluginReload(t *testing.T, reqData map[string]interface{} // testSystemBackendMock returns a systemBackend with the desired number // of mounted mock plugin backends -func testSystemBackendMock(t *testing.T, numMounts int, backendType logical.BackendType) *vault.TestCluster { +func testSystemBackendMock(t *testing.T, numCores, numMounts int, backendType logical.BackendType) *vault.TestCluster { coreConfig := &vault.CoreConfig{ LogicalBackends: map[string]logical.Factory{ "plugin": plugin.Factory, @@ -134,7 +344,9 @@ func testSystemBackendMock(t *testing.T, numMounts int, backendType logical.Back } cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{ - HandlerFunc: vaulthttp.Handler, + HandlerFunc: vaulthttp.Handler, + KeepStandbysSealed: true, + NumCores: numCores, }) cluster.Start() @@ -197,7 +409,8 @@ func testSystemBackendMock(t *testing.T, numMounts int, backendType logical.Back } func TestBackend_PluginMainLogical(t *testing.T) { - if os.Getenv(pluginutil.PluginUnwrapTokenEnv) == "" { + args := []string{} + if os.Getenv(pluginutil.PluginUnwrapTokenEnv) == "" && os.Getenv(pluginutil.PluginMetadaModeEnv) != "true" { return } @@ -205,16 +418,16 @@ func TestBackend_PluginMainLogical(t *testing.T) { if caPEM == "" { t.Fatal("CA cert not passed in") } - - factoryFunc := mock.FactoryType(logical.TypeLogical) - - args := []string{"--ca-cert=" + caPEM} + args = append(args, fmt.Sprintf("--ca-cert=%s", caPEM)) apiClientMeta := &pluginutil.APIClientMeta{} flags := apiClientMeta.FlagSet() flags.Parse(args) tlsConfig := apiClientMeta.GetTLSConfig() tlsProviderFunc := pluginutil.VaultPluginTLSProvider(tlsConfig) + + factoryFunc := mock.FactoryType(logical.TypeLogical) + err := lplugin.Serve(&lplugin.ServeOpts{ BackendFactoryFunc: factoryFunc, TLSProviderFunc: tlsProviderFunc, @@ -225,7 +438,8 @@ func TestBackend_PluginMainLogical(t *testing.T) { } func TestBackend_PluginMainCredentials(t *testing.T) { - if os.Getenv(pluginutil.PluginUnwrapTokenEnv) == "" { + args := []string{} + if os.Getenv(pluginutil.PluginUnwrapTokenEnv) == "" && os.Getenv(pluginutil.PluginMetadaModeEnv) != "true" { return } @@ -233,16 +447,16 @@ func TestBackend_PluginMainCredentials(t *testing.T) { if caPEM == "" { t.Fatal("CA cert not passed in") } - - factoryFunc := mock.FactoryType(logical.TypeCredential) - - args := []string{"--ca-cert=" + caPEM} + args = append(args, fmt.Sprintf("--ca-cert=%s", caPEM)) apiClientMeta := &pluginutil.APIClientMeta{} flags := apiClientMeta.FlagSet() flags.Parse(args) tlsConfig := apiClientMeta.GetTLSConfig() tlsProviderFunc := pluginutil.VaultPluginTLSProvider(tlsConfig) + + factoryFunc := mock.FactoryType(logical.TypeCredential) + err := lplugin.Serve(&lplugin.ServeOpts{ BackendFactoryFunc: factoryFunc, TLSProviderFunc: tlsProviderFunc, diff --git a/vault/mount.go b/vault/mount.go index 16b048a69..1fd5f748c 100644 --- a/vault/mount.go +++ b/vault/mount.go @@ -9,6 +9,7 @@ import ( "strings" "time" + "github.com/hashicorp/errwrap" "github.com/hashicorp/go-uuid" "github.com/hashicorp/vault/helper/consts" "github.com/hashicorp/vault/helper/jsonutil" @@ -663,11 +664,12 @@ func (c *Core) setupMounts() error { c.mountsLock.Lock() defer c.mountsLock.Unlock() - var backend logical.Backend var view *BarrierView var err error for _, entry := range c.mounts.Entries { + var backend logical.Backend + // Initialize the backend, special casing for system barrierPath := backendBarrierPrefix + entry.UUID + "/" if entry.Type == "system" { @@ -686,6 +688,9 @@ func (c *Core) setupMounts() error { backend, err = c.newLogicalBackend(entry.Type, sysView, view, conf) if err != nil { c.logger.Error("core: failed to create mount entry", "path", entry.Path, "error", err) + if errwrap.Contains(err, ErrPluginNotFound.Error()) && entry.Type == "plugin" { + goto ROUTER_MOUNT + } return errLoadMountsFailed } if backend == nil { @@ -693,9 +698,8 @@ func (c *Core) setupMounts() error { } // Check for the correct backend type - backendType := backend.Type() - if entry.Type == "plugin" && backendType != logical.TypeLogical { - return fmt.Errorf("cannot mount '%s' of type '%s' as a logical backend", entry.Config.PluginName, backendType) + if entry.Type == "plugin" && backend.Type() != logical.TypeLogical { + return fmt.Errorf("cannot mount '%s' of type '%s' as a logical backend", entry.Config.PluginName, backend.Type()) } if err := backend.Initialize(); err != nil { @@ -710,7 +714,7 @@ func (c *Core) setupMounts() error { ch.saltUUID = entry.UUID ch.storageView = view } - + ROUTER_MOUNT: // Mount the backend err = c.router.Mount(backend, entry.Path, entry, view) if err != nil { diff --git a/vault/plugin_catalog.go b/vault/plugin_catalog.go index 09f612cc8..3e2466ff6 100644 --- a/vault/plugin_catalog.go +++ b/vault/plugin_catalog.go @@ -19,6 +19,7 @@ import ( var ( pluginCatalogPath = "core/plugin-catalog/" ErrDirectoryNotConfigured = errors.New("could not set plugin, plugin directory is not configured") + ErrPluginNotFound = errors.New("plugin not found in the catalog") ) // PluginCatalog keeps a record of plugins known to vault. External plugins need @@ -37,6 +38,10 @@ func (c *Core) setupPluginCatalog() error { directory: c.pluginDirectory, } + if c.logger.IsInfo() { + c.logger.Info("core: successfully setup plugin catalog", "plugin-directory", c.pluginDirectory) + } + return nil } diff --git a/vault/router.go b/vault/router.go index e6c36c4c9..6e516be6a 100644 --- a/vault/router.go +++ b/vault/router.go @@ -64,9 +64,12 @@ func (r *Router) Mount(backend logical.Backend, prefix string, mountEntry *Mount } // Build the paths - paths := backend.SpecialPaths() - if paths == nil { - paths = new(logical.Paths) + paths := new(logical.Paths) + if backend != nil { + specialPaths := backend.SpecialPaths() + if specialPaths != nil { + paths = specialPaths + } } // Create a mount entry diff --git a/vault/testing.go b/vault/testing.go index 3c6a2713c..f4d2cfc5d 100644 --- a/vault/testing.go +++ b/vault/testing.go @@ -335,11 +335,17 @@ func TestAddTestPlugin(t testing.T, c *Core, name, testFunc string) { } sum := hash.Sum(nil) - c.pluginCatalog.directory, err = filepath.EvalSymlinks(os.Args[0]) + + // Determine plugin directory path + fullPath, err := filepath.EvalSymlinks(os.Args[0]) if err != nil { t.Fatal(err) } - c.pluginCatalog.directory = filepath.Dir(c.pluginCatalog.directory) + directoryPath := filepath.Dir(fullPath) + + // Set core's plugin directory and plugin catalog directory + c.pluginDirectory = directoryPath + c.pluginCatalog.directory = directoryPath command := fmt.Sprintf("%s --test.run=%s", filepath.Base(os.Args[0]), testFunc) err = c.pluginCatalog.Set(name, command, sum) @@ -585,6 +591,7 @@ func GenerateRandBytes(length int) ([]byte, error) { } func TestWaitActive(t testing.T, core *Core) { + t.Helper() start := time.Now() var standby bool var err error @@ -627,6 +634,13 @@ func (c *TestCluster) Start() { } } +func (c *TestCluster) EnsureCoresSealed(t testing.T) { + t.Helper() + if err := c.ensureCoresSealed(); err != nil { + t.Fatal(err) + } +} + func (c *TestCluster) Cleanup() { // Close listeners for _, core := range c.Cores { @@ -638,25 +652,7 @@ func (c *TestCluster) Cleanup() { } // Seal the cores - for _, core := range c.Cores { - if err := core.Shutdown(); err != nil { - continue - } - timeout := time.Now().Add(60 * time.Second) - for { - if time.Now().After(timeout) { - continue - } - sealed, err := core.Sealed() - if err != nil { - continue - } - if sealed { - break - } - time.Sleep(250 * time.Millisecond) - } - } + c.ensureCoresSealed() // Remove any temp dir that exists if c.TempDir != "" { @@ -667,6 +663,29 @@ func (c *TestCluster) Cleanup() { time.Sleep(time.Second) } +func (c *TestCluster) ensureCoresSealed() error { + for _, core := range c.Cores { + if err := core.Shutdown(); err != nil { + return err + } + timeout := time.Now().Add(60 * time.Second) + for { + if time.Now().After(timeout) { + return fmt.Errorf("timeout waiting for core to seal") + } + sealed, err := core.Sealed() + if err != nil { + return err + } + if sealed { + break + } + time.Sleep(250 * time.Millisecond) + } + } + return nil +} + type TestListener struct { net.Listener Address *net.TCPAddr @@ -692,9 +711,29 @@ type TestClusterOptions struct { KeepStandbysSealed bool HandlerFunc func(*Core) http.Handler BaseListenAddress string + NumCores int } +var DefaultNumCores = 3 + +type certInfo struct { + cert *x509.Certificate + certPEM []byte + certBytes []byte + key *ecdsa.PrivateKey + keyPEM []byte +} + +// NewTestCluster creates a new test cluster based on the provided core config +// and test cluster options. func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *TestCluster { + var numCores int + if opts == nil || opts.NumCores == 0 { + numCores = DefaultNumCores + } else { + numCores = opts.NumCores + } + certIPs := []net.IP{ net.IPv6loopback, net.ParseIP("127.0.0.1"), @@ -770,270 +809,131 @@ func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *Te t.Fatal(err) } - s1Key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - if err != nil { - t.Fatal(err) - } - s1CertTemplate := &x509.Certificate{ - Subject: pkix.Name{ - CommonName: "localhost", - }, - DNSNames: []string{"localhost"}, - IPAddresses: certIPs, - ExtKeyUsage: []x509.ExtKeyUsage{ - x509.ExtKeyUsageServerAuth, - x509.ExtKeyUsageClientAuth, - }, - KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement, - SerialNumber: big.NewInt(mathrand.Int63()), - NotBefore: time.Now().Add(-30 * time.Second), - NotAfter: time.Now().Add(262980 * time.Hour), - } - s1CertBytes, err := x509.CreateCertificate(rand.Reader, s1CertTemplate, caCert, s1Key.Public(), caKey) - if err != nil { - t.Fatal(err) - } - s1Cert, err := x509.ParseCertificate(s1CertBytes) - if err != nil { - t.Fatal(err) - } - s1CertPEMBlock := &pem.Block{ - Type: "CERTIFICATE", - Bytes: s1CertBytes, - } - s1CertPEM := pem.EncodeToMemory(s1CertPEMBlock) - s1MarshaledKey, err := x509.MarshalECPrivateKey(s1Key) - if err != nil { - t.Fatal(err) - } - s1KeyPEMBlock := &pem.Block{ - Type: "EC PRIVATE KEY", - Bytes: s1MarshaledKey, - } - s1KeyPEM := pem.EncodeToMemory(s1KeyPEMBlock) + var certInfoSlice []*certInfo - s2Key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - if err != nil { - t.Fatal(err) - } - s2CertTemplate := &x509.Certificate{ - Subject: pkix.Name{ - CommonName: "localhost", - }, - DNSNames: []string{"localhost"}, - IPAddresses: certIPs, - ExtKeyUsage: []x509.ExtKeyUsage{ - x509.ExtKeyUsageServerAuth, - x509.ExtKeyUsageClientAuth, - }, - KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement, - SerialNumber: big.NewInt(mathrand.Int63()), - NotBefore: time.Now().Add(-30 * time.Second), - NotAfter: time.Now().Add(262980 * time.Hour), - } - s2CertBytes, err := x509.CreateCertificate(rand.Reader, s2CertTemplate, caCert, s2Key.Public(), caKey) - if err != nil { - t.Fatal(err) - } - s2Cert, err := x509.ParseCertificate(s2CertBytes) - if err != nil { - t.Fatal(err) - } - s2CertPEMBlock := &pem.Block{ - Type: "CERTIFICATE", - Bytes: s2CertBytes, - } - s2CertPEM := pem.EncodeToMemory(s2CertPEMBlock) - s2MarshaledKey, err := x509.MarshalECPrivateKey(s2Key) - if err != nil { - t.Fatal(err) - } - s2KeyPEMBlock := &pem.Block{ - Type: "EC PRIVATE KEY", - Bytes: s2MarshaledKey, - } - s2KeyPEM := pem.EncodeToMemory(s2KeyPEMBlock) + // + // Certs generation + // + for i := 0; i < numCores; i++ { + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatal(err) + } + certTemplate := &x509.Certificate{ + Subject: pkix.Name{ + CommonName: "localhost", + }, + DNSNames: []string{"localhost"}, + IPAddresses: certIPs, + ExtKeyUsage: []x509.ExtKeyUsage{ + x509.ExtKeyUsageServerAuth, + x509.ExtKeyUsageClientAuth, + }, + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement, + SerialNumber: big.NewInt(mathrand.Int63()), + NotBefore: time.Now().Add(-30 * time.Second), + NotAfter: time.Now().Add(262980 * time.Hour), + } + certBytes, err := x509.CreateCertificate(rand.Reader, certTemplate, caCert, key.Public(), caKey) + if err != nil { + t.Fatal(err) + } + cert, err := x509.ParseCertificate(certBytes) + if err != nil { + t.Fatal(err) + } + certPEMBlock := &pem.Block{ + Type: "CERTIFICATE", + Bytes: certBytes, + } + certPEM := pem.EncodeToMemory(certPEMBlock) + marshaledKey, err := x509.MarshalECPrivateKey(key) + if err != nil { + t.Fatal(err) + } + keyPEMBlock := &pem.Block{ + Type: "EC PRIVATE KEY", + Bytes: marshaledKey, + } + keyPEM := pem.EncodeToMemory(keyPEMBlock) - s3Key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - if err != nil { - t.Fatal(err) + certInfoSlice = append(certInfoSlice, &certInfo{ + cert: cert, + certPEM: certPEM, + certBytes: certBytes, + key: key, + keyPEM: keyPEM, + }) } - s3CertTemplate := &x509.Certificate{ - Subject: pkix.Name{ - CommonName: "localhost", - }, - DNSNames: []string{"localhost"}, - IPAddresses: certIPs, - ExtKeyUsage: []x509.ExtKeyUsage{ - x509.ExtKeyUsageServerAuth, - x509.ExtKeyUsageClientAuth, - }, - KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement, - SerialNumber: big.NewInt(mathrand.Int63()), - NotBefore: time.Now().Add(-30 * time.Second), - NotAfter: time.Now().Add(262980 * time.Hour), - } - s3CertBytes, err := x509.CreateCertificate(rand.Reader, s3CertTemplate, caCert, s3Key.Public(), caKey) - if err != nil { - t.Fatal(err) - } - s3Cert, err := x509.ParseCertificate(s3CertBytes) - if err != nil { - t.Fatal(err) - } - s3CertPEMBlock := &pem.Block{ - Type: "CERTIFICATE", - Bytes: s3CertBytes, - } - s3CertPEM := pem.EncodeToMemory(s3CertPEMBlock) - s3MarshaledKey, err := x509.MarshalECPrivateKey(s3Key) - if err != nil { - t.Fatal(err) - } - s3KeyPEMBlock := &pem.Block{ - Type: "EC PRIVATE KEY", - Bytes: s3MarshaledKey, - } - s3KeyPEM := pem.EncodeToMemory(s3KeyPEMBlock) - - logger := logformat.NewVaultLogger(log.LevelTrace) // // Listener setup // - ports := []int{0, 0, 0} + logger := logformat.NewVaultLogger(log.LevelTrace) + ports := make([]int, numCores) if baseAddr != nil { - ports = []int{baseAddr.Port, baseAddr.Port + 1, baseAddr.Port + 2} + for i := 0; i < numCores; i++ { + ports[i] = baseAddr.Port + i + } } else { baseAddr = &net.TCPAddr{ IP: net.ParseIP("127.0.0.1"), Port: 0, } } - baseAddr.Port = ports[0] - ln, err := net.ListenTCP("tcp", baseAddr) - if err != nil { - t.Fatal(err) - } - s1CertFile := filepath.Join(testCluster.TempDir, fmt.Sprintf("node1_port_%d_cert.pem", ln.Addr().(*net.TCPAddr).Port)) - s1KeyFile := filepath.Join(testCluster.TempDir, fmt.Sprintf("node1_port_%d_key.pem", ln.Addr().(*net.TCPAddr).Port)) - err = ioutil.WriteFile(s1CertFile, s1CertPEM, 0755) - if err != nil { - t.Fatal(err) - } - err = ioutil.WriteFile(s1KeyFile, s1KeyPEM, 0755) - if err != nil { - t.Fatal(err) - } - s1TLSCert, err := tls.X509KeyPair(s1CertPEM, s1KeyPEM) - if err != nil { - t.Fatal(err) - } - s1CertGetter := reload.NewCertificateGetter(s1CertFile, s1KeyFile) - s1TLSConfig := &tls.Config{ - Certificates: []tls.Certificate{s1TLSCert}, - RootCAs: testCluster.RootCAs, - ClientCAs: testCluster.RootCAs, - ClientAuth: tls.VerifyClientCertIfGiven, - NextProtos: []string{"h2", "http/1.1"}, - GetCertificate: s1CertGetter.GetCertificate, - } - s1TLSConfig.BuildNameToCertificate() - c1lns := []*TestListener{&TestListener{ - Listener: tls.NewListener(ln, s1TLSConfig), - Address: ln.Addr().(*net.TCPAddr), - }, - } - var handler1 http.Handler = http.NewServeMux() - server1 := &http.Server{ - Handler: handler1, - } - if err := http2.ConfigureServer(server1, nil); err != nil { - t.Fatal(err) - } - baseAddr.Port = ports[1] - ln, err = net.ListenTCP("tcp", baseAddr) - if err != nil { - t.Fatal(err) - } - s2CertFile := filepath.Join(testCluster.TempDir, fmt.Sprintf("node2_port_%d_cert.pem", ln.Addr().(*net.TCPAddr).Port)) - s2KeyFile := filepath.Join(testCluster.TempDir, fmt.Sprintf("node2_port_%d_key.pem", ln.Addr().(*net.TCPAddr).Port)) - err = ioutil.WriteFile(s2CertFile, s2CertPEM, 0755) - if err != nil { - t.Fatal(err) - } - err = ioutil.WriteFile(s2KeyFile, s2KeyPEM, 0755) - if err != nil { - t.Fatal(err) - } - s2TLSCert, err := tls.X509KeyPair(s2CertPEM, s2KeyPEM) - if err != nil { - t.Fatal(err) - } - s2CertGetter := reload.NewCertificateGetter(s2CertFile, s2KeyFile) - s2TLSConfig := &tls.Config{ - Certificates: []tls.Certificate{s2TLSCert}, - RootCAs: testCluster.RootCAs, - ClientCAs: testCluster.RootCAs, - ClientAuth: tls.VerifyClientCertIfGiven, - NextProtos: []string{"h2", "http/1.1"}, - GetCertificate: s2CertGetter.GetCertificate, - } - s2TLSConfig.BuildNameToCertificate() - c2lns := []*TestListener{&TestListener{ - Listener: tls.NewListener(ln, s2TLSConfig), - Address: ln.Addr().(*net.TCPAddr), - }, - } - var handler2 http.Handler = http.NewServeMux() - server2 := &http.Server{ - Handler: handler2, - } - if err := http2.ConfigureServer(server2, nil); err != nil { - t.Fatal(err) - } - - baseAddr.Port = ports[2] - ln, err = net.ListenTCP("tcp", baseAddr) - if err != nil { - t.Fatal(err) - } - s3CertFile := filepath.Join(testCluster.TempDir, fmt.Sprintf("node3_port_%d_cert.pem", ln.Addr().(*net.TCPAddr).Port)) - s3KeyFile := filepath.Join(testCluster.TempDir, fmt.Sprintf("node3_port_%d_key.pem", ln.Addr().(*net.TCPAddr).Port)) - err = ioutil.WriteFile(s3CertFile, s3CertPEM, 0755) - if err != nil { - t.Fatal(err) - } - err = ioutil.WriteFile(s3KeyFile, s3KeyPEM, 0755) - if err != nil { - t.Fatal(err) - } - s3TLSCert, err := tls.X509KeyPair(s3CertPEM, s3KeyPEM) - if err != nil { - t.Fatal(err) - } - s3CertGetter := reload.NewCertificateGetter(s3CertFile, s3KeyFile) - s3TLSConfig := &tls.Config{ - Certificates: []tls.Certificate{s3TLSCert}, - RootCAs: testCluster.RootCAs, - ClientCAs: testCluster.RootCAs, - ClientAuth: tls.VerifyClientCertIfGiven, - NextProtos: []string{"h2", "http/1.1"}, - GetCertificate: s3CertGetter.GetCertificate, - } - s3TLSConfig.BuildNameToCertificate() - c3lns := []*TestListener{&TestListener{ - Listener: tls.NewListener(ln, s3TLSConfig), - Address: ln.Addr().(*net.TCPAddr), - }, - } - var handler3 http.Handler = http.NewServeMux() - server3 := &http.Server{ - Handler: handler3, - } - if err := http2.ConfigureServer(server3, nil); err != nil { - t.Fatal(err) + listeners := [][]*TestListener{} + servers := []*http.Server{} + handlers := []http.Handler{} + tlsConfigs := []*tls.Config{} + certGetters := []*reload.CertificateGetter{} + for i := 0; i < numCores; i++ { + baseAddr.Port = ports[i] + ln, err := net.ListenTCP("tcp", baseAddr) + if err != nil { + t.Fatal(err) + } + certFile := filepath.Join(testCluster.TempDir, fmt.Sprintf("node%d_port_%d_cert.pem", i+1, ln.Addr().(*net.TCPAddr).Port)) + keyFile := filepath.Join(testCluster.TempDir, fmt.Sprintf("node%d_port_%d_key.pem", i+1, ln.Addr().(*net.TCPAddr).Port)) + err = ioutil.WriteFile(certFile, certInfoSlice[i].certPEM, 0755) + if err != nil { + t.Fatal(err) + } + err = ioutil.WriteFile(keyFile, certInfoSlice[i].keyPEM, 0755) + if err != nil { + t.Fatal(err) + } + tlsCert, err := tls.X509KeyPair(certInfoSlice[i].certPEM, certInfoSlice[i].keyPEM) + if err != nil { + t.Fatal(err) + } + certGetter := reload.NewCertificateGetter(certFile, keyFile) + certGetters = append(certGetters, certGetter) + tlsConfig := &tls.Config{ + Certificates: []tls.Certificate{tlsCert}, + RootCAs: testCluster.RootCAs, + ClientCAs: testCluster.RootCAs, + ClientAuth: tls.VerifyClientCertIfGiven, + NextProtos: []string{"h2", "http/1.1"}, + GetCertificate: certGetter.GetCertificate, + } + tlsConfig.BuildNameToCertificate() + tlsConfigs = append(tlsConfigs, tlsConfig) + lns := []*TestListener{&TestListener{ + Listener: tls.NewListener(ln, tlsConfig), + Address: ln.Addr().(*net.TCPAddr), + }, + } + listeners = append(listeners, lns) + var handler http.Handler = http.NewServeMux() + handlers = append(handlers, handler) + server := &http.Server{ + Handler: handler, + } + servers = append(servers, server) + if err := http2.ConfigureServer(server, nil); err != nil { + t.Fatal(err) + } } // Create three cores with the same physical and different redirect/cluster @@ -1049,8 +949,8 @@ func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *Te LogicalBackends: make(map[string]logical.Factory), CredentialBackends: make(map[string]logical.Factory), AuditBackends: make(map[string]audit.Factory), - RedirectAddr: fmt.Sprintf("https://127.0.0.1:%d", c1lns[0].Address.Port), - ClusterAddr: fmt.Sprintf("https://127.0.0.1:%d", c1lns[0].Address.Port+105), + RedirectAddr: fmt.Sprintf("https://127.0.0.1:%d", listeners[0][0].Address.Port), + ClusterAddr: fmt.Sprintf("https://127.0.0.1:%d", listeners[0][0].Address.Port+105), DisableMlock: true, EnableUI: true, } @@ -1126,39 +1026,21 @@ func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *Te coreConfig.HAPhysical = haPhys.(physical.HABackend) } - c1, err := NewCore(coreConfig) - if err != nil { - t.Fatalf("err: %v", err) - } - if opts != nil && opts.HandlerFunc != nil { - handler1 = opts.HandlerFunc(c1) - server1.Handler = handler1 - } - - coreConfig.RedirectAddr = fmt.Sprintf("https://127.0.0.1:%d", c2lns[0].Address.Port) - if coreConfig.ClusterAddr != "" { - coreConfig.ClusterAddr = fmt.Sprintf("https://127.0.0.1:%d", c2lns[0].Address.Port+105) - } - c2, err := NewCore(coreConfig) - if err != nil { - t.Fatalf("err: %v", err) - } - if opts != nil && opts.HandlerFunc != nil { - handler2 = opts.HandlerFunc(c2) - server2.Handler = handler2 - } - - coreConfig.RedirectAddr = fmt.Sprintf("https://127.0.0.1:%d", c3lns[0].Address.Port) - if coreConfig.ClusterAddr != "" { - coreConfig.ClusterAddr = fmt.Sprintf("https://127.0.0.1:%d", c3lns[0].Address.Port+105) - } - c3, err := NewCore(coreConfig) - if err != nil { - t.Fatalf("err: %v", err) - } - if opts != nil && opts.HandlerFunc != nil { - handler3 = opts.HandlerFunc(c3) - server3.Handler = handler3 + cores := []*Core{} + for i := 0; i < numCores; i++ { + coreConfig.RedirectAddr = fmt.Sprintf("https://127.0.0.1:%d", listeners[i][0].Address.Port) + if coreConfig.ClusterAddr != "" { + coreConfig.ClusterAddr = fmt.Sprintf("https://127.0.0.1:%d", listeners[i][0].Address.Port+105) + } + c, err := NewCore(coreConfig) + if err != nil { + t.Fatalf("err: %v", err) + } + cores = append(cores, c) + if opts != nil && opts.HandlerFunc != nil { + handlers[i] = opts.HandlerFunc(c) + servers[i].Handler = handlers[i] + } } // @@ -1175,16 +1057,19 @@ func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *Te return ret } - c2.SetClusterListenerAddrs(clusterAddrGen(c2lns)) - c2.SetClusterHandler(handler2) - c3.SetClusterListenerAddrs(clusterAddrGen(c3lns)) - c3.SetClusterHandler(handler3) + if numCores > 1 { + for i := 1; i < numCores; i++ { + cores[i].SetClusterListenerAddrs(clusterAddrGen(listeners[i])) + cores[i].SetClusterHandler(handlers[i]) + } + } - keys, root := TestCoreInitClusterWrapperSetup(t, c1, clusterAddrGen(c1lns), handler1) + keys, root := TestCoreInitClusterWrapperSetup(t, cores[0], clusterAddrGen(listeners[0]), handlers[0]) barrierKeys, _ := copystructure.Copy(keys) testCluster.BarrierKeys = barrierKeys.([][]byte) testCluster.RootToken = root + // Write root token and barrier keys err = ioutil.WriteFile(filepath.Join(testCluster.TempDir, "root_token"), []byte(root), 0755) if err != nil { t.Fatal(err) @@ -1201,14 +1086,15 @@ func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *Te t.Fatal(err) } + // Unseal first core for _, key := range keys { - if _, err := c1.Unseal(TestKeyCopy(key)); err != nil { + if _, err := cores[0].Unseal(TestKeyCopy(key)); err != nil { t.Fatalf("unseal err: %s", err) } } // Verify unsealed - sealed, err := c1.Sealed() + sealed, err := cores[0].Sealed() if err != nil { t.Fatalf("err checking seal status: %s", err) } @@ -1216,41 +1102,38 @@ func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *Te t.Fatal("should not be sealed") } - TestWaitActive(t, c1) + TestWaitActive(t, cores[0]) - if opts == nil || !opts.KeepStandbysSealed { - for _, key := range keys { - if _, err := c2.Unseal(TestKeyCopy(key)); err != nil { - t.Fatalf("unseal err: %s", err) - } - } - for _, key := range keys { - if _, err := c3.Unseal(TestKeyCopy(key)); err != nil { - t.Fatalf("unseal err: %s", err) + // Unseal other cores unless otherwise specified + if (opts == nil || !opts.KeepStandbysSealed) && numCores > 1 { + for i := 1; i < numCores; i++ { + for _, key := range keys { + if _, err := cores[i].Unseal(TestKeyCopy(key)); err != nil { + t.Fatalf("unseal err: %s", err) + } } } // Let them come fully up to standby time.Sleep(2 * time.Second) - // Ensure cluster connection info is populated - isLeader, _, _, err := c2.Leader() - if err != nil { - t.Fatal(err) - } - if isLeader { - t.Fatal("c2 should not be leader") - } - isLeader, _, _, err = c3.Leader() - if err != nil { - t.Fatal(err) - } - if isLeader { - t.Fatal("c3 should not be leader") + // Ensure cluster connection info is populated. + // Other cores should not come up as leaders. + for i := 1; i < numCores; i++ { + isLeader, _, _, err := cores[i].Leader() + if err != nil { + t.Fatal(err) + } + if isLeader { + t.Fatalf("core[%d] should not be leader", i) + } } } - cluster, err := c1.Cluster() + // + // Set test cluster core(s) and test cluster + // + cluster, err := cores[0].Cluster() if err != nil { t.Fatal(err) } @@ -1278,65 +1161,27 @@ func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *Te } var ret []*TestClusterCore - t1 := &TestClusterCore{ - Core: c1, - ServerKey: s1Key, - ServerKeyPEM: s1KeyPEM, - ServerCert: s1Cert, - ServerCertBytes: s1CertBytes, - ServerCertPEM: s1CertPEM, - Listeners: c1lns, - Handler: handler1, - Server: server1, - TLSConfig: s1TLSConfig, - Client: getAPIClient(c1lns[0].Address.Port, s1TLSConfig), + for i := 0; i < numCores; i++ { + tcc := &TestClusterCore{ + Core: cores[i], + ServerKey: certInfoSlice[i].key, + ServerKeyPEM: certInfoSlice[i].keyPEM, + ServerCert: certInfoSlice[i].cert, + ServerCertBytes: certInfoSlice[i].certBytes, + ServerCertPEM: certInfoSlice[i].certPEM, + Listeners: listeners[i], + Handler: handlers[i], + Server: servers[i], + TLSConfig: tlsConfigs[i], + Client: getAPIClient(listeners[i][0].Address.Port, tlsConfigs[i]), + } + tcc.ReloadFuncs = &cores[i].reloadFuncs + tcc.ReloadFuncsLock = &cores[i].reloadFuncsLock + tcc.ReloadFuncsLock.Lock() + (*tcc.ReloadFuncs)["listener|tcp"] = []reload.ReloadFunc{certGetters[i].Reload} + tcc.ReloadFuncsLock.Unlock() + ret = append(ret, tcc) } - t1.ReloadFuncs = &c1.reloadFuncs - t1.ReloadFuncsLock = &c1.reloadFuncsLock - t1.ReloadFuncsLock.Lock() - (*t1.ReloadFuncs)["listener|tcp"] = []reload.ReloadFunc{s1CertGetter.Reload} - t1.ReloadFuncsLock.Unlock() - ret = append(ret, t1) - - t2 := &TestClusterCore{ - Core: c2, - ServerKey: s2Key, - ServerKeyPEM: s2KeyPEM, - ServerCert: s2Cert, - ServerCertBytes: s2CertBytes, - ServerCertPEM: s2CertPEM, - Listeners: c2lns, - Handler: handler2, - Server: server2, - TLSConfig: s2TLSConfig, - Client: getAPIClient(c2lns[0].Address.Port, s2TLSConfig), - } - t2.ReloadFuncs = &c2.reloadFuncs - t2.ReloadFuncsLock = &c2.reloadFuncsLock - t2.ReloadFuncsLock.Lock() - (*t2.ReloadFuncs)["listener|tcp"] = []reload.ReloadFunc{s2CertGetter.Reload} - t2.ReloadFuncsLock.Unlock() - ret = append(ret, t2) - - t3 := &TestClusterCore{ - Core: c3, - ServerKey: s3Key, - ServerKeyPEM: s3KeyPEM, - ServerCert: s3Cert, - ServerCertBytes: s3CertBytes, - ServerCertPEM: s3CertPEM, - Listeners: c3lns, - Handler: handler3, - Server: server3, - TLSConfig: s3TLSConfig, - Client: getAPIClient(c3lns[0].Address.Port, s3TLSConfig), - } - t3.ReloadFuncs = &c3.reloadFuncs - t3.ReloadFuncsLock = &c3.reloadFuncsLock - t3.ReloadFuncsLock.Lock() - (*t3.ReloadFuncs)["listener|tcp"] = []reload.ReloadFunc{s3CertGetter.Reload} - t3.ReloadFuncsLock.Unlock() - ret = append(ret, t3) testCluster.Cores = ret return &testCluster