From b6c05fae33434fcfa7812f4afe8f63c39461177f Mon Sep 17 00:00:00 2001 From: John-Michael Faircloth Date: Mon, 29 Aug 2022 21:42:26 -0500 Subject: [PATCH] feature: secrets/auth plugin multiplexing (#14946) * enable registering backend muxed plugins in plugin catalog * set the sysview on the pluginconfig to allow enabling secrets/auth plugins * store backend instances in map * store single implementations in the instances map cleanup instance map and ensure we don't deadlock * fix system backend unit tests move GetMultiplexIDFromContext to pluginutil package fix pluginutil test fix dbplugin ut * return error(s) if we can't get the plugin client update comments * refactor/move GetMultiplexIDFromContext test * add changelog * remove unnecessary field on pluginClient * add unit tests to PluginCatalog for secrets/auth plugins * fix comment * return pluginClient from TestRunTestPlugin * add multiplexed backend test * honor metadatamode value in newbackend pluginconfig * check that connection exists on cleanup * add automtls to secrets/auth plugins * don't remove apiclientmeta parsing * use formatting directive for fmt.Errorf * fix ut: remove tls provider func * remove tlsproviderfunc from backend plugin tests * use env var to prevent test plugin from running as a unit test * WIP: remove lazy loading * move non lazy loaded backend to new package * use version wrapper for backend plugin factory * remove backendVersionWrapper type * implement getBackendPluginType for plugin catalog * handle backend plugin v4 registration * add plugin automtls env guard * modify plugin factory to determine the backend to use * remove old pluginsets from v5 and log pid in plugin catalog * add reload mechanism via context * readd v3 and v4 to pluginset * call cleanup from reload if non-muxed * move v5 backend code to new package * use context reload for for ErrPluginShutdown case * add wrapper on v5 backend * fix run config UTs * fix unit tests - use v4/v5 mapping for plugin versions - fix test build err - add reload method on fakePluginClient - add multiplexed cases for integration tests * remove comment and update AutoMTLS field in test * remove comment * remove errwrap and unused context * only support metadatamode false for v5 backend plugins * update plugin catalog errors * use const for env variables * rename locks and remove unused * remove unneeded nil check * improvements based on staticcheck recommendations * use const for single implementation string * use const for context key * use info default log level * move pid to pluginClient struct * remove v3 and v4 from multiplexed plugin set * return from reload when non-multiplexed * update automtls env string * combine getBackend and getBrokeredClient * update comments for plugin reload, Backend return val and log * revert Backend return type * allow non-muxed plugins to serve v5 * move v5 code to existing sdk plugin package * do next export sdk fields now that we have removed extra plugin pkg * set TLSProvider in ServeMultiplex for backwards compat * use bool to flag multiplexing support on grpc backend server * revert userpass main.go * refactor plugin sdk - update comments - make use of multiplexing boolean and single implementation ID const * update comment and use multierr * attempt v4 if dispense fails on getPluginTypeForUnknown * update comments on sdk plugin backend --- api/plugin_helpers.go | 94 +- .../credential/userpass/cmd/userpass/main.go | 1 - builtin/plugin/backend.go | 20 +- builtin/plugin/backend_test.go | 64 +- builtin/plugin/v5/backend.go | 147 +++ changelog/14946.txt | 3 + http/plugin_test.go | 4 - sdk/database/dbplugin/v5/grpc_server.go | 26 +- sdk/database/dbplugin/v5/grpc_server_test.go | 51 - .../dbplugin/v5/plugin_client_test.go | 4 + sdk/helper/pluginutil/env.go | 6 +- sdk/helper/pluginutil/multiplexing.go | 20 + sdk/helper/pluginutil/multiplexing_test.go | 73 ++ sdk/helper/pluginutil/run_config.go | 12 +- sdk/helper/pluginutil/run_config_test.go | 12 +- sdk/helper/pluginutil/runner.go | 1 + sdk/plugin/backend.go | 27 +- sdk/plugin/grpc_backend_server.go | 139 ++- sdk/plugin/grpc_system_test.go | 2 +- sdk/plugin/logger_test.go | 1 - sdk/plugin/middleware.go | 24 +- sdk/plugin/plugin.go | 34 +- sdk/plugin/plugin_v5.go | 165 +++ sdk/plugin/serve.go | 78 +- vault/auth.go | 3 +- vault/logical_system_integ_test.go | 1022 +++++++++++------ vault/mount.go | 1 - vault/plugin_catalog.go | 171 ++- vault/plugin_catalog_test.go | 150 ++- vault/plugin_reload.go | 6 +- vault/testing.go | 45 + vault/wrapping.go | 2 +- 32 files changed, 1742 insertions(+), 666 deletions(-) create mode 100644 builtin/plugin/v5/backend.go create mode 100644 changelog/14946.txt create mode 100644 sdk/helper/pluginutil/multiplexing_test.go create mode 100644 sdk/plugin/plugin_v5.go diff --git a/api/plugin_helpers.go b/api/plugin_helpers.go index e8ceb9c2f..2b1b35c3b 100644 --- a/api/plugin_helpers.go +++ b/api/plugin_helpers.go @@ -16,7 +16,11 @@ import ( "github.com/hashicorp/errwrap" ) -var ( +const ( + // PluginAutoMTLSEnv is used to ensure AutoMTLS is used. This will override + // setting a TLSProviderFunc for a plugin. + PluginAutoMTLSEnv = "VAULT_PLUGIN_AUTOMTLS_ENABLED" + // PluginMetadataModeEnv is an ENV name used to disable TLS communication // to bootstrap mounting plugins. PluginMetadataModeEnv = "VAULT_PLUGIN_METADATA_MODE" @@ -24,51 +28,51 @@ var ( // PluginUnwrapTokenEnv is the ENV name used to pass unwrap tokens to the // plugin. PluginUnwrapTokenEnv = "VAULT_UNWRAP_TOKEN" - - // sudoPaths is a map containing the paths that require a token's policy - // to have the "sudo" capability. The keys are the paths as strings, in - // the same format as they are returned by the OpenAPI spec. The values - // are the regular expressions that can be used to test whether a given - // path matches that path or not (useful specifically for the paths that - // contain templated fields.) - sudoPaths = map[string]*regexp.Regexp{ - "/auth/token/accessors/": regexp.MustCompile(`^/auth/token/accessors/$`), - "/pki/root": regexp.MustCompile(`^/pki/root$`), - "/pki/root/sign-self-issued": regexp.MustCompile(`^/pki/root/sign-self-issued$`), - "/sys/audit": regexp.MustCompile(`^/sys/audit$`), - "/sys/audit/{path}": regexp.MustCompile(`^/sys/audit/.+$`), - "/sys/auth/{path}": regexp.MustCompile(`^/sys/auth/.+$`), - "/sys/auth/{path}/tune": regexp.MustCompile(`^/sys/auth/.+/tune$`), - "/sys/config/auditing/request-headers": regexp.MustCompile(`^/sys/config/auditing/request-headers$`), - "/sys/config/auditing/request-headers/{header}": regexp.MustCompile(`^/sys/config/auditing/request-headers/.+$`), - "/sys/config/cors": regexp.MustCompile(`^/sys/config/cors$`), - "/sys/config/ui/headers/": regexp.MustCompile(`^/sys/config/ui/headers/$`), - "/sys/config/ui/headers/{header}": regexp.MustCompile(`^/sys/config/ui/headers/.+$`), - "/sys/leases": regexp.MustCompile(`^/sys/leases$`), - "/sys/leases/lookup/": regexp.MustCompile(`^/sys/leases/lookup/$`), - "/sys/leases/lookup/{prefix}": regexp.MustCompile(`^/sys/leases/lookup/.+$`), - "/sys/leases/revoke-force/{prefix}": regexp.MustCompile(`^/sys/leases/revoke-force/.+$`), - "/sys/leases/revoke-prefix/{prefix}": regexp.MustCompile(`^/sys/leases/revoke-prefix/.+$`), - "/sys/plugins/catalog/{name}": regexp.MustCompile(`^/sys/plugins/catalog/[^/]+$`), - "/sys/plugins/catalog/{type}": regexp.MustCompile(`^/sys/plugins/catalog/[\w-]+$`), - "/sys/plugins/catalog/{type}/{name}": regexp.MustCompile(`^/sys/plugins/catalog/[\w-]+/[^/]+$`), - "/sys/raw": regexp.MustCompile(`^/sys/raw$`), - "/sys/raw/{path}": regexp.MustCompile(`^/sys/raw/.+$`), - "/sys/remount": regexp.MustCompile(`^/sys/remount$`), - "/sys/revoke-force/{prefix}": regexp.MustCompile(`^/sys/revoke-force/.+$`), - "/sys/revoke-prefix/{prefix}": regexp.MustCompile(`^/sys/revoke-prefix/.+$`), - "/sys/rotate": regexp.MustCompile(`^/sys/rotate$`), - - // enterprise-only paths - "/sys/replication/dr/primary/secondary-token": regexp.MustCompile(`^/sys/replication/dr/primary/secondary-token$`), - "/sys/replication/performance/primary/secondary-token": regexp.MustCompile(`^/sys/replication/performance/primary/secondary-token$`), - "/sys/replication/primary/secondary-token": regexp.MustCompile(`^/sys/replication/primary/secondary-token$`), - "/sys/replication/reindex": regexp.MustCompile(`^/sys/replication/reindex$`), - "/sys/storage/raft/snapshot-auto/config/": regexp.MustCompile(`^/sys/storage/raft/snapshot-auto/config/$`), - "/sys/storage/raft/snapshot-auto/config/{name}": regexp.MustCompile(`^/sys/storage/raft/snapshot-auto/config/[^/]+$`), - } ) +// sudoPaths is a map containing the paths that require a token's policy +// to have the "sudo" capability. The keys are the paths as strings, in +// the same format as they are returned by the OpenAPI spec. The values +// are the regular expressions that can be used to test whether a given +// path matches that path or not (useful specifically for the paths that +// contain templated fields.) +var sudoPaths = map[string]*regexp.Regexp{ + "/auth/token/accessors/": regexp.MustCompile(`^/auth/token/accessors/$`), + "/pki/root": regexp.MustCompile(`^/pki/root$`), + "/pki/root/sign-self-issued": regexp.MustCompile(`^/pki/root/sign-self-issued$`), + "/sys/audit": regexp.MustCompile(`^/sys/audit$`), + "/sys/audit/{path}": regexp.MustCompile(`^/sys/audit/.+$`), + "/sys/auth/{path}": regexp.MustCompile(`^/sys/auth/.+$`), + "/sys/auth/{path}/tune": regexp.MustCompile(`^/sys/auth/.+/tune$`), + "/sys/config/auditing/request-headers": regexp.MustCompile(`^/sys/config/auditing/request-headers$`), + "/sys/config/auditing/request-headers/{header}": regexp.MustCompile(`^/sys/config/auditing/request-headers/.+$`), + "/sys/config/cors": regexp.MustCompile(`^/sys/config/cors$`), + "/sys/config/ui/headers/": regexp.MustCompile(`^/sys/config/ui/headers/$`), + "/sys/config/ui/headers/{header}": regexp.MustCompile(`^/sys/config/ui/headers/.+$`), + "/sys/leases": regexp.MustCompile(`^/sys/leases$`), + "/sys/leases/lookup/": regexp.MustCompile(`^/sys/leases/lookup/$`), + "/sys/leases/lookup/{prefix}": regexp.MustCompile(`^/sys/leases/lookup/.+$`), + "/sys/leases/revoke-force/{prefix}": regexp.MustCompile(`^/sys/leases/revoke-force/.+$`), + "/sys/leases/revoke-prefix/{prefix}": regexp.MustCompile(`^/sys/leases/revoke-prefix/.+$`), + "/sys/plugins/catalog/{name}": regexp.MustCompile(`^/sys/plugins/catalog/[^/]+$`), + "/sys/plugins/catalog/{type}": regexp.MustCompile(`^/sys/plugins/catalog/[\w-]+$`), + "/sys/plugins/catalog/{type}/{name}": regexp.MustCompile(`^/sys/plugins/catalog/[\w-]+/[^/]+$`), + "/sys/raw": regexp.MustCompile(`^/sys/raw$`), + "/sys/raw/{path}": regexp.MustCompile(`^/sys/raw/.+$`), + "/sys/remount": regexp.MustCompile(`^/sys/remount$`), + "/sys/revoke-force/{prefix}": regexp.MustCompile(`^/sys/revoke-force/.+$`), + "/sys/revoke-prefix/{prefix}": regexp.MustCompile(`^/sys/revoke-prefix/.+$`), + "/sys/rotate": regexp.MustCompile(`^/sys/rotate$`), + + // enterprise-only paths + "/sys/replication/dr/primary/secondary-token": regexp.MustCompile(`^/sys/replication/dr/primary/secondary-token$`), + "/sys/replication/performance/primary/secondary-token": regexp.MustCompile(`^/sys/replication/performance/primary/secondary-token$`), + "/sys/replication/primary/secondary-token": regexp.MustCompile(`^/sys/replication/primary/secondary-token$`), + "/sys/replication/reindex": regexp.MustCompile(`^/sys/replication/reindex$`), + "/sys/storage/raft/snapshot-auto/config/": regexp.MustCompile(`^/sys/storage/raft/snapshot-auto/config/$`), + "/sys/storage/raft/snapshot-auto/config/{name}": regexp.MustCompile(`^/sys/storage/raft/snapshot-auto/config/[^/]+$`), +} + // PluginAPIClientMeta is a helper that plugins can use to configure TLS connections // back to Vault. type PluginAPIClientMeta struct { @@ -120,7 +124,7 @@ func VaultPluginTLSProvider(apiTLSConfig *TLSConfig) func() (*tls.Config, error) // VaultPluginTLSProviderContext is run inside a plugin and retrieves the response // wrapped TLS certificate from vault. It returns a configured TLS Config. func VaultPluginTLSProviderContext(ctx context.Context, apiTLSConfig *TLSConfig) func() (*tls.Config, error) { - if os.Getenv(PluginMetadataModeEnv) == "true" { + if os.Getenv(PluginAutoMTLSEnv) == "true" || os.Getenv(PluginMetadataModeEnv) == "true" { return nil } diff --git a/builtin/credential/userpass/cmd/userpass/main.go b/builtin/credential/userpass/cmd/userpass/main.go index 43098807a..5ea1894d2 100644 --- a/builtin/credential/userpass/cmd/userpass/main.go +++ b/builtin/credential/userpass/cmd/userpass/main.go @@ -13,7 +13,6 @@ func main() { apiClientMeta := &api.PluginAPIClientMeta{} flags := apiClientMeta.FlagSet() flags.Parse(os.Args[1:]) - tlsConfig := apiClientMeta.GetTLSConfig() tlsProviderFunc := api.VaultPluginTLSProvider(tlsConfig) diff --git a/builtin/plugin/backend.go b/builtin/plugin/backend.go index 751588905..b9dd409b2 100644 --- a/builtin/plugin/backend.go +++ b/builtin/plugin/backend.go @@ -9,7 +9,9 @@ import ( log "github.com/hashicorp/go-hclog" + "github.com/hashicorp/go-multierror" uuid "github.com/hashicorp/go-uuid" + v5 "github.com/hashicorp/vault/builtin/plugin/v5" "github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/helper/consts" "github.com/hashicorp/vault/sdk/logical" @@ -23,17 +25,29 @@ var ( // Factory returns a configured plugin logical.Backend. func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) { + merr := &multierror.Error{} _, ok := conf.Config["plugin_name"] if !ok { return nil, fmt.Errorf("plugin_name not provided") } - b, err := Backend(ctx, conf) + b, err := v5.Backend(ctx, conf) + if err == nil { + if err := b.Setup(ctx, conf); err != nil { + return nil, err + } + return b, nil + } + merr = multierror.Append(merr, err) + + b, err = Backend(ctx, conf) if err != nil { - return nil, err + merr = multierror.Append(merr, err) + return nil, fmt.Errorf("invalid backend version: %s", merr) } if err := b.Setup(ctx, conf); err != nil { - return nil, err + merr = multierror.Append(merr, err) + return nil, merr.ErrorOrNil() } return b, nil } diff --git a/builtin/plugin/backend_test.go b/builtin/plugin/backend_test.go index 9354463bf..ef05f748b 100644 --- a/builtin/plugin/backend_test.go +++ b/builtin/plugin/backend_test.go @@ -24,22 +24,34 @@ func TestBackend_impl(t *testing.T) { } func TestBackend(t *testing.T) { - config, cleanup := testConfig(t) - defer cleanup() + pluginCmds := []string{"TestBackend_PluginMain", "TestBackend_PluginMain_Multiplexed"} - _, err := plugin.Backend(context.Background(), config) - if err != nil { - t.Fatal(err) + for _, pluginCmd := range pluginCmds { + t.Run(pluginCmd, func(t *testing.T) { + config, cleanup := testConfig(t, pluginCmd) + defer cleanup() + + _, err := plugin.Backend(context.Background(), config) + if err != nil { + t.Fatal(err) + } + }) } } func TestBackend_Factory(t *testing.T) { - config, cleanup := testConfig(t) - defer cleanup() + pluginCmds := []string{"TestBackend_PluginMain", "TestBackend_PluginMain_Multiplexed"} - _, err := plugin.Factory(context.Background(), config) - if err != nil { - t.Fatal(err) + for _, pluginCmd := range pluginCmds { + t.Run(pluginCmd, func(t *testing.T) { + config, cleanup := testConfig(t, pluginCmd) + defer cleanup() + + _, err := plugin.Factory(context.Background(), config) + if err != nil { + t.Fatal(err) + } + }) } } @@ -71,7 +83,35 @@ func TestBackend_PluginMain(t *testing.T) { } } -func testConfig(t *testing.T) (*logical.BackendConfig, func()) { +func TestBackend_PluginMain_Multiplexed(t *testing.T) { + args := []string{} + if os.Getenv(pluginutil.PluginUnwrapTokenEnv) == "" && os.Getenv(pluginutil.PluginMetadataModeEnv) != "true" { + return + } + + caPEM := os.Getenv(pluginutil.PluginCACertPEMEnv) + if caPEM == "" { + t.Fatal("CA cert not passed in") + } + + args = append(args, fmt.Sprintf("--ca-cert=%s", caPEM)) + + apiClientMeta := &api.PluginAPIClientMeta{} + flags := apiClientMeta.FlagSet() + flags.Parse(args) + tlsConfig := apiClientMeta.GetTLSConfig() + tlsProviderFunc := api.VaultPluginTLSProvider(tlsConfig) + + err := logicalPlugin.ServeMultiplex(&logicalPlugin.ServeOpts{ + BackendFactoryFunc: mock.Factory, + TLSProviderFunc: tlsProviderFunc, + }) + if err != nil { + t.Fatal(err) + } +} + +func testConfig(t *testing.T, pluginCmd string) (*logical.BackendConfig, func()) { cluster := vault.NewTestCluster(t, nil, &vault.TestClusterOptions{ HandlerFunc: vaulthttp.Handler, }) @@ -93,7 +133,7 @@ func testConfig(t *testing.T) (*logical.BackendConfig, func()) { os.Setenv(pluginutil.PluginCACertPEMEnv, cluster.CACertPEMFile) - vault.TestAddTestPlugin(t, core.Core, "mock-plugin", consts.PluginTypeSecrets, "TestBackend_PluginMain", []string{}, "") + vault.TestAddTestPlugin(t, core.Core, "mock-plugin", consts.PluginTypeSecrets, pluginCmd, []string{}, "") return config, func() { cluster.Cleanup() diff --git a/builtin/plugin/v5/backend.go b/builtin/plugin/v5/backend.go new file mode 100644 index 000000000..9ce97b246 --- /dev/null +++ b/builtin/plugin/v5/backend.go @@ -0,0 +1,147 @@ +package plugin + +import ( + "context" + "net/rpc" + "sync" + + uuid "github.com/hashicorp/go-uuid" + "github.com/hashicorp/vault/sdk/helper/consts" + "github.com/hashicorp/vault/sdk/logical" + "github.com/hashicorp/vault/sdk/plugin" + bplugin "github.com/hashicorp/vault/sdk/plugin" +) + +// Backend returns an instance of the backend, either as a plugin if external +// or as a concrete implementation if builtin, casted as logical.Backend. +func Backend(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) { + var b backend + name := conf.Config["plugin_name"] + pluginType, err := consts.ParsePluginType(conf.Config["plugin_type"]) + if err != nil { + return nil, err + } + + sys := conf.System + + raw, err := plugin.NewBackendV5(ctx, name, pluginType, sys, conf) + if err != nil { + return nil, err + } + b.Backend = raw + b.config = conf + + return &b, nil +} + +// backend is a thin wrapper around plugin.BackendPluginClientV5 +type backend struct { + logical.Backend + mu sync.RWMutex + + config *logical.BackendConfig + + // Used to detect if we already reloaded + canary string +} + +func (b *backend) reloadBackend(ctx context.Context) error { + pluginName := b.config.Config["plugin_name"] + pluginType, err := consts.ParsePluginType(b.config.Config["plugin_type"]) + if err != nil { + return err + } + + b.Logger().Debug("plugin: reloading plugin backend", "plugin", pluginName) + + // Ensure proper cleanup of the backend + // Pass a context value so that the plugin client will call the appropriate + // cleanup method for reloading + reloadCtx := context.WithValue(ctx, plugin.ContextKeyPluginReload, "reload") + b.Backend.Cleanup(reloadCtx) + + nb, err := plugin.NewBackendV5(ctx, pluginName, pluginType, b.config.System, b.config) + if err != nil { + return err + } + err = nb.Setup(ctx, b.config) + if err != nil { + return err + } + b.Backend = nb + + return nil +} + +// HandleRequest is a thin wrapper implementation of HandleRequest that includes automatic plugin reload. +func (b *backend) HandleRequest(ctx context.Context, req *logical.Request) (*logical.Response, error) { + b.mu.RLock() + canary := b.canary + resp, err := b.Backend.HandleRequest(ctx, req) + b.mu.RUnlock() + // Need to compare string value for case were err comes from plugin RPC + // and is returned as plugin.BasicError type. + if err != nil && + (err.Error() == rpc.ErrShutdown.Error() || err == bplugin.ErrPluginShutdown) { + // Reload plugin if it's an rpc.ErrShutdown + b.mu.Lock() + if b.canary == canary { + err := b.reloadBackend(ctx) + if err != nil { + b.mu.Unlock() + return nil, err + } + b.canary, err = uuid.GenerateUUID() + if err != nil { + b.mu.Unlock() + return nil, err + } + } + b.mu.Unlock() + + // Try request once more + b.mu.RLock() + defer b.mu.RUnlock() + return b.Backend.HandleRequest(ctx, req) + } + return resp, err +} + +// HandleExistenceCheck is a thin wrapper implementation of HandleRequest that includes automatic plugin reload. +func (b *backend) HandleExistenceCheck(ctx context.Context, req *logical.Request) (bool, bool, error) { + b.mu.RLock() + canary := b.canary + checkFound, exists, err := b.Backend.HandleExistenceCheck(ctx, req) + b.mu.RUnlock() + if err != nil && + (err.Error() == rpc.ErrShutdown.Error() || err == bplugin.ErrPluginShutdown) { + // Reload plugin if it's an rpc.ErrShutdown + b.mu.Lock() + if b.canary == canary { + err := b.reloadBackend(ctx) + if err != nil { + b.mu.Unlock() + return false, false, err + } + b.canary, err = uuid.GenerateUUID() + if err != nil { + b.mu.Unlock() + return false, false, err + } + } + b.mu.Unlock() + + // Try request once more + b.mu.RLock() + defer b.mu.RUnlock() + return b.Backend.HandleExistenceCheck(ctx, req) + } + return checkFound, exists, err +} + +// InvalidateKey is a thin wrapper used to ensure we grab the lock for race purposes +func (b *backend) InvalidateKey(ctx context.Context, key string) { + b.mu.RLock() + defer b.mu.RUnlock() + b.Backend.InvalidateKey(ctx, key) +} diff --git a/changelog/14946.txt b/changelog/14946.txt new file mode 100644 index 000000000..43ee1e55d --- /dev/null +++ b/changelog/14946.txt @@ -0,0 +1,3 @@ +```release-note:feature +**Secrets/auth plugin multiplexing**: manage multiple plugin configurations with a single plugin process +``` diff --git a/http/plugin_test.go b/http/plugin_test.go index 38b8669eb..dcb5d436d 100644 --- a/http/plugin_test.go +++ b/http/plugin_test.go @@ -81,14 +81,10 @@ func TestPlugin_PluginMain(t *testing.T) { flags := apiClientMeta.FlagSet() flags.Parse(args) - tlsConfig := apiClientMeta.GetTLSConfig() - tlsProviderFunc := api.VaultPluginTLSProvider(tlsConfig) - factoryFunc := mock.FactoryType(logical.TypeLogical) err := plugin.Serve(&plugin.ServeOpts{ BackendFactoryFunc: factoryFunc, - TLSProviderFunc: tlsProviderFunc, }) if err != nil { t.Fatal(err) diff --git a/sdk/database/dbplugin/v5/grpc_server.go b/sdk/database/dbplugin/v5/grpc_server.go index 4d29a5a62..d38e97127 100644 --- a/sdk/database/dbplugin/v5/grpc_server.go +++ b/sdk/database/dbplugin/v5/grpc_server.go @@ -10,7 +10,6 @@ import ( "github.com/hashicorp/vault/sdk/database/dbplugin/v5/proto" "github.com/hashicorp/vault/sdk/helper/pluginutil" "google.golang.org/grpc/codes" - "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" ) @@ -30,25 +29,6 @@ type gRPCServer struct { sync.RWMutex } -func getMultiplexIDFromContext(ctx context.Context) (string, error) { - md, ok := metadata.FromIncomingContext(ctx) - if !ok { - return "", fmt.Errorf("missing plugin multiplexing metadata") - } - - multiplexIDs := md[pluginutil.MultiplexingCtxKey] - if len(multiplexIDs) != 1 { - return "", fmt.Errorf("unexpected number of IDs in metadata: (%d)", len(multiplexIDs)) - } - - multiplexID := multiplexIDs[0] - if multiplexID == "" { - return "", fmt.Errorf("empty multiplex ID in metadata") - } - - return multiplexID, nil -} - func (g *gRPCServer) getOrCreateDatabase(ctx context.Context) (Database, error) { g.Lock() defer g.Unlock() @@ -57,7 +37,7 @@ func (g *gRPCServer) getOrCreateDatabase(ctx context.Context) (Database, error) return g.singleImpl, nil } - id, err := getMultiplexIDFromContext(ctx) + id, err := pluginutil.GetMultiplexIDFromContext(ctx) if err != nil { return nil, err } @@ -83,7 +63,7 @@ func (g *gRPCServer) getDatabaseInternal(ctx context.Context) (Database, error) return g.singleImpl, nil } - id, err := getMultiplexIDFromContext(ctx) + id, err := pluginutil.GetMultiplexIDFromContext(ctx) if err != nil { return nil, err } @@ -312,7 +292,7 @@ func (g *gRPCServer) Close(ctx context.Context, _ *proto.Empty) (*proto.Empty, e if g.singleImpl == nil { // only cleanup instances map when multiplexing is supported - id, err := getMultiplexIDFromContext(ctx) + id, err := pluginutil.GetMultiplexIDFromContext(ctx) if err != nil { return nil, err } diff --git a/sdk/database/dbplugin/v5/grpc_server_test.go b/sdk/database/dbplugin/v5/grpc_server_test.go index 4f45e54bb..b901839d0 100644 --- a/sdk/database/dbplugin/v5/grpc_server_test.go +++ b/sdk/database/dbplugin/v5/grpc_server_test.go @@ -581,57 +581,6 @@ func TestGRPCServer_Close(t *testing.T) { } } -func TestGetMultiplexIDFromContext(t *testing.T) { - type testCase struct { - ctx context.Context - expectedResp string - expectedErr error - } - - tests := map[string]testCase{ - "missing plugin multiplexing metadata": { - ctx: context.Background(), - expectedResp: "", - expectedErr: fmt.Errorf("missing plugin multiplexing metadata"), - }, - "unexpected number of IDs in metadata": { - ctx: idCtx(t, "12345", "67891"), - expectedResp: "", - expectedErr: fmt.Errorf("unexpected number of IDs in metadata: (2)"), - }, - "empty multiplex ID in metadata": { - ctx: idCtx(t, ""), - expectedResp: "", - expectedErr: fmt.Errorf("empty multiplex ID in metadata"), - }, - "happy path, id is returned from metadata": { - ctx: idCtx(t, "12345"), - expectedResp: "12345", - expectedErr: nil, - }, - } - - for name, test := range tests { - t.Run(name, func(t *testing.T) { - resp, err := getMultiplexIDFromContext(test.ctx) - - if test.expectedErr != nil && test.expectedErr.Error() != "" && err == nil { - t.Fatalf("err expected, got nil") - } else if !reflect.DeepEqual(err, test.expectedErr) { - t.Fatalf("Actual error: %#v\nExpected error: %#v", err, test.expectedErr) - } - - if test.expectedErr != nil && test.expectedErr.Error() == "" && err != nil { - t.Fatalf("no error expected, got: %s", err) - } - - if !reflect.DeepEqual(resp, test.expectedResp) { - t.Fatalf("Actual response: %#v\nExpected response: %#v", resp, test.expectedResp) - } - }) - } -} - // testGrpcServer is a test helper that returns a context with an ID set in its // metadata and a gRPCServer instance for a multiplexed plugin func testGrpcServer(t *testing.T, db Database) (context.Context, gRPCServer) { diff --git a/sdk/database/dbplugin/v5/plugin_client_test.go b/sdk/database/dbplugin/v5/plugin_client_test.go index 0ff8309f1..c8faf55e1 100644 --- a/sdk/database/dbplugin/v5/plugin_client_test.go +++ b/sdk/database/dbplugin/v5/plugin_client_test.go @@ -112,6 +112,10 @@ func (f *fakePluginClient) Conn() grpc.ClientConnInterface { return nil } +func (f *fakePluginClient) Reload() error { + return nil +} + func (f *fakePluginClient) Dispense(name string) (interface{}, error) { return f.dispenseResp, f.dispenseErr } diff --git a/sdk/helper/pluginutil/env.go b/sdk/helper/pluginutil/env.go index fd0cd4fb8..24f82daec 100644 --- a/sdk/helper/pluginutil/env.go +++ b/sdk/helper/pluginutil/env.go @@ -7,7 +7,11 @@ import ( version "github.com/hashicorp/go-version" ) -var ( +const ( + // PluginAutoMTLSEnv is used to ensure AutoMTLS is used. This will override + // setting a TLSProviderFunc for a plugin. + PluginAutoMTLSEnv = "VAULT_PLUGIN_AUTOMTLS_ENABLED" + // PluginMlockEnabled is the ENV name used to pass the configuration for // enabling mlock PluginMlockEnabled = "VAULT_PLUGIN_MLOCK_ENABLED" diff --git a/sdk/helper/pluginutil/multiplexing.go b/sdk/helper/pluginutil/multiplexing.go index cbf50335d..726a4ca45 100644 --- a/sdk/helper/pluginutil/multiplexing.go +++ b/sdk/helper/pluginutil/multiplexing.go @@ -6,6 +6,7 @@ import ( grpc "google.golang.org/grpc" codes "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" status "google.golang.org/grpc/status" ) @@ -45,3 +46,22 @@ func MultiplexingSupported(ctx context.Context, cc grpc.ClientConnInterface) (bo return resp.Supported, nil } + +func GetMultiplexIDFromContext(ctx context.Context) (string, error) { + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return "", fmt.Errorf("missing plugin multiplexing metadata") + } + + multiplexIDs := md[MultiplexingCtxKey] + if len(multiplexIDs) != 1 { + return "", fmt.Errorf("unexpected number of IDs in metadata: (%d)", len(multiplexIDs)) + } + + multiplexID := multiplexIDs[0] + if multiplexID == "" { + return "", fmt.Errorf("empty multiplex ID in metadata") + } + + return multiplexID, nil +} diff --git a/sdk/helper/pluginutil/multiplexing_test.go b/sdk/helper/pluginutil/multiplexing_test.go new file mode 100644 index 000000000..bb230853d --- /dev/null +++ b/sdk/helper/pluginutil/multiplexing_test.go @@ -0,0 +1,73 @@ +package pluginutil + +import ( + "context" + "fmt" + "reflect" + "testing" + + "google.golang.org/grpc/metadata" +) + +func TestGetMultiplexIDFromContext(t *testing.T) { + type testCase struct { + ctx context.Context + expectedResp string + expectedErr error + } + + tests := map[string]testCase{ + "missing plugin multiplexing metadata": { + ctx: context.Background(), + expectedResp: "", + expectedErr: fmt.Errorf("missing plugin multiplexing metadata"), + }, + "unexpected number of IDs in metadata": { + ctx: idCtx(t, "12345", "67891"), + expectedResp: "", + expectedErr: fmt.Errorf("unexpected number of IDs in metadata: (2)"), + }, + "empty multiplex ID in metadata": { + ctx: idCtx(t, ""), + expectedResp: "", + expectedErr: fmt.Errorf("empty multiplex ID in metadata"), + }, + "happy path, id is returned from metadata": { + ctx: idCtx(t, "12345"), + expectedResp: "12345", + expectedErr: nil, + }, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + resp, err := GetMultiplexIDFromContext(test.ctx) + + if test.expectedErr != nil && test.expectedErr.Error() != "" && err == nil { + t.Fatalf("err expected, got nil") + } else if !reflect.DeepEqual(err, test.expectedErr) { + t.Fatalf("Actual error: %#v\nExpected error: %#v", err, test.expectedErr) + } + + if test.expectedErr != nil && test.expectedErr.Error() == "" && err != nil { + t.Fatalf("no error expected, got: %s", err) + } + + if !reflect.DeepEqual(resp, test.expectedResp) { + t.Fatalf("Actual response: %#v\nExpected response: %#v", resp, test.expectedResp) + } + }) + } +} + +// idCtx is a test helper that will return a context with the IDs set in its +// metadata +func idCtx(t *testing.T, ids ...string) context.Context { + // Context doesn't need to timeout since this is just passed through + ctx := context.Background() + md := metadata.MD{} + for _, id := range ids { + md.Append(MultiplexingCtxKey, id) + } + return metadata.NewIncomingContext(ctx, md) +} diff --git a/sdk/helper/pluginutil/run_config.go b/sdk/helper/pluginutil/run_config.go index 47228abb9..3eb8fb2b2 100644 --- a/sdk/helper/pluginutil/run_config.go +++ b/sdk/helper/pluginutil/run_config.go @@ -23,6 +23,7 @@ type PluginClientConfig struct { IsMetadataMode bool AutoMTLS bool MLock bool + Wrapper RunnerUtil } type runConfig struct { @@ -34,8 +35,6 @@ type runConfig struct { // Initialized with what's in PluginRunner.Env, but can be added to env []string - wrapper RunnerUtil - PluginClientConfig } @@ -44,7 +43,7 @@ func (rc runConfig) makeConfig(ctx context.Context) (*plugin.ClientConfig, error cmd.Env = append(cmd.Env, rc.env...) // Add the mlock setting to the ENV of the plugin - if rc.MLock || (rc.wrapper != nil && rc.wrapper.MlockEnabled()) { + if rc.MLock || (rc.Wrapper != nil && rc.Wrapper.MlockEnabled()) { cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", PluginMlockEnabled, "true")) } cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", PluginVaultVersionEnv, version.GetVersion().Version)) @@ -55,6 +54,9 @@ func (rc runConfig) makeConfig(ctx context.Context) (*plugin.ClientConfig, error metadataEnv := fmt.Sprintf("%s=%t", PluginMetadataModeEnv, rc.IsMetadataMode) cmd.Env = append(cmd.Env, metadataEnv) + automtlsEnv := fmt.Sprintf("%s=%t", PluginAutoMTLSEnv, rc.AutoMTLS) + cmd.Env = append(cmd.Env, automtlsEnv) + var clientTLSConfig *tls.Config if !rc.AutoMTLS && !rc.IsMetadataMode { // Get a CA TLS Certificate @@ -71,7 +73,7 @@ func (rc runConfig) makeConfig(ctx context.Context) (*plugin.ClientConfig, error // Use CA to sign a server cert and wrap the values in a response wrapped // token. - wrapToken, err := wrapServerConfig(ctx, rc.wrapper, certBytes, key) + wrapToken, err := wrapServerConfig(ctx, rc.Wrapper, certBytes, key) if err != nil { return nil, err } @@ -121,7 +123,7 @@ func Env(env ...string) RunOpt { func Runner(wrapper RunnerUtil) RunOpt { return func(rc *runConfig) { - rc.wrapper = wrapper + rc.Wrapper = wrapper } } diff --git a/sdk/helper/pluginutil/run_config_test.go b/sdk/helper/pluginutil/run_config_test.go index f2373fe9b..3c2fef219 100644 --- a/sdk/helper/pluginutil/run_config_test.go +++ b/sdk/helper/pluginutil/run_config_test.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "os/exec" - "reflect" "testing" "time" @@ -14,6 +13,7 @@ import ( "github.com/hashicorp/go-plugin" "github.com/hashicorp/vault/sdk/helper/wrapping" "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" ) func TestMakeConfig(t *testing.T) { @@ -78,6 +78,7 @@ func TestMakeConfig(t *testing.T) { "initial=true", fmt.Sprintf("%s=%s", PluginVaultVersionEnv, version.GetVersion().Version), fmt.Sprintf("%s=%t", PluginMetadataModeEnv, true), + fmt.Sprintf("%s=%t", PluginAutoMTLSEnv, false), }, ), SecureConfig: &plugin.SecureConfig{ @@ -143,6 +144,7 @@ func TestMakeConfig(t *testing.T) { fmt.Sprintf("%s=%t", PluginMlockEnabled, true), fmt.Sprintf("%s=%s", PluginVaultVersionEnv, version.GetVersion().Version), fmt.Sprintf("%s=%t", PluginMetadataModeEnv, false), + fmt.Sprintf("%s=%t", PluginAutoMTLSEnv, false), fmt.Sprintf("%s=%s", PluginUnwrapTokenEnv, "testtoken"), }, ), @@ -205,6 +207,7 @@ func TestMakeConfig(t *testing.T) { "initial=true", fmt.Sprintf("%s=%s", PluginVaultVersionEnv, version.GetVersion().Version), fmt.Sprintf("%s=%t", PluginMetadataModeEnv, true), + fmt.Sprintf("%s=%t", PluginAutoMTLSEnv, true), }, ), SecureConfig: &plugin.SecureConfig{ @@ -266,6 +269,7 @@ func TestMakeConfig(t *testing.T) { "initial=true", fmt.Sprintf("%s=%s", PluginVaultVersionEnv, version.GetVersion().Version), fmt.Sprintf("%s=%t", PluginMetadataModeEnv, false), + fmt.Sprintf("%s=%t", PluginAutoMTLSEnv, true), }, ), SecureConfig: &plugin.SecureConfig{ @@ -290,7 +294,7 @@ func TestMakeConfig(t *testing.T) { Return(test.responseWrapInfo, test.responseWrapInfoErr) mockWrapper.On("MlockEnabled"). Return(test.mlockEnabled) - test.rc.wrapper = mockWrapper + test.rc.Wrapper = mockWrapper defer mockWrapper.AssertNumberOfCalls(t, "ResponseWrapData", test.responseWrapInfoTimes) defer mockWrapper.AssertNumberOfCalls(t, "MlockEnabled", test.mlockEnabledTimes) @@ -318,9 +322,7 @@ func TestMakeConfig(t *testing.T) { } config.TLSConfig = nil - if !reflect.DeepEqual(config, test.expectedConfig) { - t.Fatalf("Actual config: %#v\nExpected config: %#v", config, test.expectedConfig) - } + require.Equal(t, test.expectedConfig, config) }) } } diff --git a/sdk/helper/pluginutil/runner.go b/sdk/helper/pluginutil/runner.go index b18951e37..370da22d1 100644 --- a/sdk/helper/pluginutil/runner.go +++ b/sdk/helper/pluginutil/runner.go @@ -36,6 +36,7 @@ type LookRunnerUtil interface { type PluginClient interface { Conn() grpc.ClientConnInterface + Reload() error plugin.ClientProtocol } diff --git a/sdk/plugin/backend.go b/sdk/plugin/backend.go index 82c728732..545565684 100644 --- a/sdk/plugin/backend.go +++ b/sdk/plugin/backend.go @@ -8,6 +8,7 @@ import ( log "github.com/hashicorp/go-hclog" plugin "github.com/hashicorp/go-plugin" + "github.com/hashicorp/vault/sdk/helper/pluginutil" "github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/sdk/plugin/pb" ) @@ -24,18 +25,32 @@ type GRPCBackendPlugin struct { MetadataMode bool Logger log.Logger + MultiplexingSupport bool + // Embeding this will disable the netRPC protocol plugin.NetRPCUnsupportedPlugin } func (b GRPCBackendPlugin) GRPCServer(broker *plugin.GRPCBroker, s *grpc.Server) error { - pb.RegisterBackendServer(s, &backendGRPCPluginServer{ - broker: broker, - factory: b.Factory, - // We pass the logger down into the backend so go-plugin will forward - // logs for us. + server := backendGRPCPluginServer{ + broker: broker, + factory: b.Factory, + instances: make(map[string]backendInstance), + // We pass the logger down into the backend so go-plugin will + // forward logs for us. logger: b.Logger, - }) + } + + if b.MultiplexingSupport { + // Multiplexing is enabled for this plugin, register the server so we + // can tell the client in Vault. + pluginutil.RegisterPluginMultiplexingServer(s, pluginutil.PluginMultiplexingServerImpl{ + Supported: true, + }) + server.multiplexingSupport = true + } + + pb.RegisterBackendServer(s, &server) return nil } diff --git a/sdk/plugin/grpc_backend_server.go b/sdk/plugin/grpc_backend_server.go index ce9ecdf06..3d38a0e26 100644 --- a/sdk/plugin/grpc_backend_server.go +++ b/sdk/plugin/grpc_backend_server.go @@ -3,6 +3,8 @@ package plugin import ( "context" "errors" + "fmt" + "sync" log "github.com/hashicorp/go-hclog" plugin "github.com/hashicorp/go-plugin" @@ -14,29 +16,79 @@ import ( var ErrServerInMetadataMode = errors.New("plugin server can not perform action while in metadata mode") +// singleImplementationID is the string used to define the instance ID of a +// non-multiplexed plugin +const singleImplementationID string = "single" + +type backendInstance struct { + brokeredClient *grpc.ClientConn + backend logical.Backend +} + type backendGRPCPluginServer struct { pb.UnimplementedBackendServer - broker *plugin.GRPCBroker - backend logical.Backend + broker *plugin.GRPCBroker + + instances map[string]backendInstance + instancesLock sync.RWMutex + multiplexingSupport bool factory logical.Factory - brokeredClient *grpc.ClientConn - logger log.Logger } +// getBackendAndBrokeredClientInternal returns the backend and client +// connection but does not hold a lock +func (b *backendGRPCPluginServer) getBackendAndBrokeredClientInternal(ctx context.Context) (logical.Backend, *grpc.ClientConn, error) { + if b.multiplexingSupport { + id, err := pluginutil.GetMultiplexIDFromContext(ctx) + if err != nil { + return nil, nil, err + } + + if inst, ok := b.instances[id]; ok { + return inst.backend, inst.brokeredClient, nil + } + + } + + if singleImpl, ok := b.instances[singleImplementationID]; ok { + return singleImpl.backend, singleImpl.brokeredClient, nil + } + + return nil, nil, fmt.Errorf("no backend instance found") +} + +// getBackendAndBrokeredClient holds a read lock and returns the backend and +// client connection +func (b *backendGRPCPluginServer) getBackendAndBrokeredClient(ctx context.Context) (logical.Backend, *grpc.ClientConn, error) { + b.instancesLock.RLock() + defer b.instancesLock.RUnlock() + return b.getBackendAndBrokeredClientInternal(ctx) +} + // Setup dials into the plugin's broker to get a shimmed storage, logger, and // system view of the backend. This method also instantiates the underlying // backend through its factory func for the server side of the plugin. func (b *backendGRPCPluginServer) Setup(ctx context.Context, args *pb.SetupArgs) (*pb.SetupReply, error) { + var err error + id := singleImplementationID + + if b.multiplexingSupport { + id, err = pluginutil.GetMultiplexIDFromContext(ctx) + if err != nil { + return &pb.SetupReply{}, err + } + } + // Dial for storage brokeredClient, err := b.broker.Dial(args.BrokerID) if err != nil { return &pb.SetupReply{}, err } - b.brokeredClient = brokeredClient + storage := newGRPCStorageClient(brokeredClient) sysView := newGRPCSystemView(brokeredClient) @@ -56,12 +108,20 @@ func (b *backendGRPCPluginServer) Setup(ctx context.Context, args *pb.SetupArgs) Err: pb.ErrToString(err), }, nil } - b.backend = backend + b.instances[id] = backendInstance{ + brokeredClient: brokeredClient, + backend: backend, + } return &pb.SetupReply{}, nil } func (b *backendGRPCPluginServer) HandleRequest(ctx context.Context, args *pb.HandleRequestArgs) (*pb.HandleRequestReply, error) { + backend, brokeredClient, err := b.getBackendAndBrokeredClient(ctx) + if err != nil { + return &pb.HandleRequestReply{}, err + } + if pluginutil.InMetadataMode() { return &pb.HandleRequestReply{}, ErrServerInMetadataMode } @@ -71,9 +131,9 @@ func (b *backendGRPCPluginServer) HandleRequest(ctx context.Context, args *pb.Ha return &pb.HandleRequestReply{}, err } - logicalReq.Storage = newGRPCStorageClient(b.brokeredClient) + logicalReq.Storage = newGRPCStorageClient(brokeredClient) - resp, respErr := b.backend.HandleRequest(ctx, logicalReq) + resp, respErr := backend.HandleRequest(ctx, logicalReq) pbResp, err := pb.LogicalResponseToProtoResponse(resp) if err != nil { @@ -87,15 +147,20 @@ func (b *backendGRPCPluginServer) HandleRequest(ctx context.Context, args *pb.Ha } func (b *backendGRPCPluginServer) Initialize(ctx context.Context, _ *pb.InitializeArgs) (*pb.InitializeReply, error) { + backend, brokeredClient, err := b.getBackendAndBrokeredClient(ctx) + if err != nil { + return &pb.InitializeReply{}, err + } + if pluginutil.InMetadataMode() { return &pb.InitializeReply{}, ErrServerInMetadataMode } req := &logical.InitializationRequest{ - Storage: newGRPCStorageClient(b.brokeredClient), + Storage: newGRPCStorageClient(brokeredClient), } - respErr := b.backend.Initialize(ctx, req) + respErr := backend.Initialize(ctx, req) return &pb.InitializeReply{ Err: pb.ErrToProtoErr(respErr), @@ -103,7 +168,12 @@ func (b *backendGRPCPluginServer) Initialize(ctx context.Context, _ *pb.Initiali } func (b *backendGRPCPluginServer) SpecialPaths(ctx context.Context, args *pb.Empty) (*pb.SpecialPathsReply, error) { - paths := b.backend.SpecialPaths() + backend, _, err := b.getBackendAndBrokeredClient(ctx) + if err != nil { + return &pb.SpecialPathsReply{}, err + } + + paths := backend.SpecialPaths() if paths == nil { return &pb.SpecialPathsReply{ Paths: nil, @@ -121,6 +191,11 @@ func (b *backendGRPCPluginServer) SpecialPaths(ctx context.Context, args *pb.Emp } func (b *backendGRPCPluginServer) HandleExistenceCheck(ctx context.Context, args *pb.HandleExistenceCheckArgs) (*pb.HandleExistenceCheckReply, error) { + backend, brokeredClient, err := b.getBackendAndBrokeredClient(ctx) + if err != nil { + return &pb.HandleExistenceCheckReply{}, err + } + if pluginutil.InMetadataMode() { return &pb.HandleExistenceCheckReply{}, ErrServerInMetadataMode } @@ -129,9 +204,10 @@ func (b *backendGRPCPluginServer) HandleExistenceCheck(ctx context.Context, args if err != nil { return &pb.HandleExistenceCheckReply{}, err } - logicalReq.Storage = newGRPCStorageClient(b.brokeredClient) - checkFound, exists, err := b.backend.HandleExistenceCheck(ctx, logicalReq) + logicalReq.Storage = newGRPCStorageClient(brokeredClient) + + checkFound, exists, err := backend.HandleExistenceCheck(ctx, logicalReq) return &pb.HandleExistenceCheckReply{ CheckFound: checkFound, Exists: exists, @@ -140,24 +216,53 @@ func (b *backendGRPCPluginServer) HandleExistenceCheck(ctx context.Context, args } func (b *backendGRPCPluginServer) Cleanup(ctx context.Context, _ *pb.Empty) (*pb.Empty, error) { - b.backend.Cleanup(ctx) + b.instancesLock.Lock() + defer b.instancesLock.Unlock() + + backend, brokeredClient, err := b.getBackendAndBrokeredClientInternal(ctx) + if err != nil { + return &pb.Empty{}, err + } + + backend.Cleanup(ctx) // Close rpc clients - b.brokeredClient.Close() + brokeredClient.Close() + + if b.multiplexingSupport { + id, err := pluginutil.GetMultiplexIDFromContext(ctx) + if err != nil { + return nil, err + } + delete(b.instances, id) + } else if _, ok := b.instances[singleImplementationID]; ok { + delete(b.instances, singleImplementationID) + } + return &pb.Empty{}, nil } func (b *backendGRPCPluginServer) InvalidateKey(ctx context.Context, args *pb.InvalidateKeyArgs) (*pb.Empty, error) { + backend, _, err := b.getBackendAndBrokeredClient(ctx) + if err != nil { + return &pb.Empty{}, err + } + if pluginutil.InMetadataMode() { return &pb.Empty{}, ErrServerInMetadataMode } - b.backend.InvalidateKey(ctx, args.Key) + backend.InvalidateKey(ctx, args.Key) return &pb.Empty{}, nil } func (b *backendGRPCPluginServer) Type(ctx context.Context, _ *pb.Empty) (*pb.TypeReply, error) { + backend, _, err := b.getBackendAndBrokeredClient(ctx) + if err != nil { + return &pb.TypeReply{}, err + } + return &pb.TypeReply{ - Type: uint32(b.backend.Type()), + Type: uint32(backend.Type()), }, nil } diff --git a/sdk/plugin/grpc_system_test.go b/sdk/plugin/grpc_system_test.go index fa9744434..8d4de0afa 100644 --- a/sdk/plugin/grpc_system_test.go +++ b/sdk/plugin/grpc_system_test.go @@ -6,12 +6,12 @@ import ( "testing" "time" - "github.com/golang/protobuf/proto" plugin "github.com/hashicorp/go-plugin" "github.com/hashicorp/vault/sdk/helper/consts" "github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/sdk/plugin/pb" "google.golang.org/grpc" + "google.golang.org/protobuf/proto" ) func TestSystem_GRPC_GRPC_impl(t *testing.T) { diff --git a/sdk/plugin/logger_test.go b/sdk/plugin/logger_test.go index 99c27b15b..a2b8a8015 100644 --- a/sdk/plugin/logger_test.go +++ b/sdk/plugin/logger_test.go @@ -222,7 +222,6 @@ func (l *deprecatedLoggerClient) Error(msg string, args ...interface{}) error { func (l *deprecatedLoggerClient) Fatal(msg string, args ...interface{}) { // NOOP since it's not actually used within vault - return } func (l *deprecatedLoggerClient) Log(level int, msg string, args []interface{}) { diff --git a/sdk/plugin/middleware.go b/sdk/plugin/middleware.go index 04a6f4c50..3f0babde2 100644 --- a/sdk/plugin/middleware.go +++ b/sdk/plugin/middleware.go @@ -10,16 +10,16 @@ import ( // backendPluginClient implements logical.Backend and is the // go-plugin client. -type backendTracingMiddleware struct { +type BackendTracingMiddleware struct { logger log.Logger next logical.Backend } // Validate the backendTracingMiddle object satisfies the backend interface -var _ logical.Backend = &backendTracingMiddleware{} +var _ logical.Backend = &BackendTracingMiddleware{} -func (b *backendTracingMiddleware) Initialize(ctx context.Context, req *logical.InitializationRequest) (err error) { +func (b *BackendTracingMiddleware) Initialize(ctx context.Context, req *logical.InitializationRequest) (err error) { defer func(then time.Time) { b.logger.Trace("initialize", "status", "finished", "err", err, "took", time.Since(then)) }(time.Now()) @@ -28,7 +28,7 @@ func (b *backendTracingMiddleware) Initialize(ctx context.Context, req *logical. return b.next.Initialize(ctx, req) } -func (b *backendTracingMiddleware) HandleRequest(ctx context.Context, req *logical.Request) (resp *logical.Response, err error) { +func (b *BackendTracingMiddleware) HandleRequest(ctx context.Context, req *logical.Request) (resp *logical.Response, err error) { defer func(then time.Time) { b.logger.Trace("handle request", "path", req.Path, "status", "finished", "err", err, "took", time.Since(then)) }(time.Now()) @@ -37,7 +37,7 @@ func (b *backendTracingMiddleware) HandleRequest(ctx context.Context, req *logic return b.next.HandleRequest(ctx, req) } -func (b *backendTracingMiddleware) SpecialPaths() *logical.Paths { +func (b *BackendTracingMiddleware) SpecialPaths() *logical.Paths { defer func(then time.Time) { b.logger.Trace("special paths", "status", "finished", "took", time.Since(then)) }(time.Now()) @@ -46,15 +46,15 @@ func (b *backendTracingMiddleware) SpecialPaths() *logical.Paths { return b.next.SpecialPaths() } -func (b *backendTracingMiddleware) System() logical.SystemView { +func (b *BackendTracingMiddleware) System() logical.SystemView { return b.next.System() } -func (b *backendTracingMiddleware) Logger() log.Logger { +func (b *BackendTracingMiddleware) Logger() log.Logger { return b.next.Logger() } -func (b *backendTracingMiddleware) HandleExistenceCheck(ctx context.Context, req *logical.Request) (found bool, exists bool, err error) { +func (b *BackendTracingMiddleware) HandleExistenceCheck(ctx context.Context, req *logical.Request) (found bool, exists bool, err error) { defer func(then time.Time) { b.logger.Trace("handle existence check", "path", req.Path, "status", "finished", "err", err, "took", time.Since(then)) }(time.Now()) @@ -63,7 +63,7 @@ func (b *backendTracingMiddleware) HandleExistenceCheck(ctx context.Context, req return b.next.HandleExistenceCheck(ctx, req) } -func (b *backendTracingMiddleware) Cleanup(ctx context.Context) { +func (b *BackendTracingMiddleware) Cleanup(ctx context.Context) { defer func(then time.Time) { b.logger.Trace("cleanup", "status", "finished", "took", time.Since(then)) }(time.Now()) @@ -72,7 +72,7 @@ func (b *backendTracingMiddleware) Cleanup(ctx context.Context) { b.next.Cleanup(ctx) } -func (b *backendTracingMiddleware) InvalidateKey(ctx context.Context, key string) { +func (b *BackendTracingMiddleware) InvalidateKey(ctx context.Context, key string) { defer func(then time.Time) { b.logger.Trace("invalidate key", "key", key, "status", "finished", "took", time.Since(then)) }(time.Now()) @@ -81,7 +81,7 @@ func (b *backendTracingMiddleware) InvalidateKey(ctx context.Context, key string b.next.InvalidateKey(ctx, key) } -func (b *backendTracingMiddleware) Setup(ctx context.Context, config *logical.BackendConfig) (err error) { +func (b *BackendTracingMiddleware) Setup(ctx context.Context, config *logical.BackendConfig) (err error) { defer func(then time.Time) { b.logger.Trace("setup", "status", "finished", "err", err, "took", time.Since(then)) }(time.Now()) @@ -90,7 +90,7 @@ func (b *backendTracingMiddleware) Setup(ctx context.Context, config *logical.Ba return b.next.Setup(ctx, config) } -func (b *backendTracingMiddleware) Type() logical.BackendType { +func (b *BackendTracingMiddleware) Type() logical.BackendType { defer func(then time.Time) { b.logger.Trace("type", "status", "finished", "took", time.Since(then)) }(time.Now()) diff --git a/sdk/plugin/plugin.go b/sdk/plugin/plugin.go index f4f2d8e18..58163f2b3 100644 --- a/sdk/plugin/plugin.go +++ b/sdk/plugin/plugin.go @@ -4,9 +4,7 @@ import ( "context" "errors" "fmt" - "sync" - "github.com/hashicorp/errwrap" log "github.com/hashicorp/go-hclog" plugin "github.com/hashicorp/go-plugin" "github.com/hashicorp/vault/sdk/helper/consts" @@ -19,7 +17,6 @@ import ( // used to cleanly kill the client on Cleanup() type BackendPluginClient struct { client *plugin.Client - sync.Mutex logical.Backend } @@ -48,7 +45,7 @@ func NewBackend(ctx context.Context, pluginName string, pluginType consts.Plugin // from the pluginRunner. Then cast it to logical.Factory. rawFactory, err := pluginRunner.BuiltinFactory() if err != nil { - return nil, errwrap.Wrapf("error getting plugin type: {{err}}", err) + return nil, fmt.Errorf("error getting plugin type: %q", err) } if factory, ok := rawFactory.(logical.Factory); !ok { @@ -93,9 +90,9 @@ func NewPluginClient(ctx context.Context, sys pluginutil.RunnerUtil, pluginRunne var client *plugin.Client var err error if isMetadataMode { - client, err = pluginRunner.RunMetadataMode(ctx, sys, pluginSet, handshakeConfig, []string{}, namedLogger) + client, err = pluginRunner.RunMetadataMode(ctx, sys, pluginSet, HandshakeConfig, []string{}, namedLogger) } else { - client, err = pluginRunner.Run(ctx, sys, pluginSet, handshakeConfig, []string{}, namedLogger) + client, err = pluginRunner.Run(ctx, sys, pluginSet, HandshakeConfig, []string{}, namedLogger) } if err != nil { return nil, err @@ -117,9 +114,9 @@ func NewPluginClient(ctx context.Context, sys pluginutil.RunnerUtil, pluginRunne var transport string // We should have a logical backend type now. This feels like a normal interface // implementation but is in fact over an RPC connection. - switch raw.(type) { + switch b := raw.(type) { case *backendGRPCPluginClient: - backend = raw.(*backendGRPCPluginClient) + backend = b transport = "gRPC" default: return nil, errors.New("unsupported plugin client type") @@ -127,7 +124,7 @@ func NewPluginClient(ctx context.Context, sys pluginutil.RunnerUtil, pluginRunne // Wrap the backend in a tracing middleware if namedLogger.IsTrace() { - backend = &backendTracingMiddleware{ + backend = &BackendTracingMiddleware{ logger: namedLogger.With("transport", transport), next: backend, } @@ -138,22 +135,3 @@ func NewPluginClient(ctx context.Context, sys pluginutil.RunnerUtil, pluginRunne Backend: backend, }, nil } - -// wrapError takes a generic error type and makes it usable with the plugin -// interface. Only errors which have exported fields and have been registered -// with gob can be unwrapped and transported. This checks error types and, if -// none match, wrap the error in a plugin.BasicError. -func wrapError(err error) error { - if err == nil { - return nil - } - - switch err.(type) { - case *plugin.BasicError, - logical.HTTPCodedError, - *logical.StatusBadRequest: - return err - } - - return plugin.NewBasicError(err) -} diff --git a/sdk/plugin/plugin_v5.go b/sdk/plugin/plugin_v5.go new file mode 100644 index 000000000..fda30f1ce --- /dev/null +++ b/sdk/plugin/plugin_v5.go @@ -0,0 +1,165 @@ +package plugin + +import ( + "context" + "errors" + "fmt" + + plugin "github.com/hashicorp/go-plugin" + "github.com/hashicorp/vault/sdk/helper/consts" + "github.com/hashicorp/vault/sdk/helper/pluginutil" + "github.com/hashicorp/vault/sdk/logical" + "github.com/hashicorp/vault/sdk/plugin/pb" +) + +// BackendPluginClientV5 is a wrapper around backendPluginClient +// that also contains its plugin.Client instance. It's primarily +// used to cleanly kill the client on Cleanup() +type BackendPluginClientV5 struct { + client pluginutil.PluginClient + + logical.Backend +} + +type ContextKey string + +func (c ContextKey) String() string { + return "plugin" + string(c) +} + +const ContextKeyPluginReload = ContextKey("plugin-reload") + +// Cleanup cleans up the go-plugin client and the plugin catalog +func (b *BackendPluginClientV5) Cleanup(ctx context.Context) { + _, ok := ctx.Value(ContextKeyPluginReload).(string) + if !ok { + b.Backend.Cleanup(ctx) + b.client.Close() + return + } + b.Backend.Cleanup(ctx) + b.client.Reload() +} + +// NewBackendV5 will return an instance of an RPC-based client implementation of +// the backend for external plugins, or a concrete implementation of the +// backend if it is a builtin backend. The backend is returned as a +// logical.Backend interface. +func NewBackendV5(ctx context.Context, pluginName string, pluginType consts.PluginType, sys pluginutil.LookRunnerUtil, conf *logical.BackendConfig) (logical.Backend, error) { + // Look for plugin in the plugin catalog + pluginRunner, err := sys.LookupPlugin(ctx, pluginName, pluginType) + if err != nil { + return nil, err + } + + var backend logical.Backend + if pluginRunner.Builtin { + // Plugin is builtin so we can retrieve an instance of the interface + // from the pluginRunner. Then cast it to logical.Factory. + rawFactory, err := pluginRunner.BuiltinFactory() + if err != nil { + return nil, fmt.Errorf("error getting plugin type: %q", err) + } + + if factory, ok := rawFactory.(logical.Factory); !ok { + return nil, fmt.Errorf("unsupported backend type: %q", pluginName) + } else { + if backend, err = factory(ctx, conf); err != nil { + return nil, err + } + } + } else { + // create a backendPluginClient instance + config := pluginutil.PluginClientConfig{ + Name: pluginName, + PluginSets: PluginSet, + PluginType: pluginType, + HandshakeConfig: HandshakeConfig, + Logger: conf.Logger.Named(pluginName), + AutoMTLS: true, + Wrapper: sys, + } + backend, err = NewPluginClientV5(ctx, sys, config) + if err != nil { + return nil, err + } + } + + return backend, nil +} + +// PluginSet is the map of plugins we can dispense. +var PluginSet = map[int]plugin.PluginSet{ + 5: { + "backend": &GRPCBackendPlugin{}, + }, +} + +func Dispense(rpcClient plugin.ClientProtocol, pluginClient pluginutil.PluginClient) (logical.Backend, error) { + // Request the plugin + raw, err := rpcClient.Dispense("backend") + if err != nil { + return nil, err + } + + var backend logical.Backend + // We should have a logical backend type now. This feels like a normal interface + // implementation but is in fact over an RPC connection. + switch c := raw.(type) { + case *backendGRPCPluginClient: + // This is an abstraction leak from go-plugin but it is necessary in + // order to enable multiplexing on multiplexed plugins + c.client = pb.NewBackendClient(pluginClient.Conn()) + + backend = c + default: + return nil, errors.New("unsupported plugin client type") + } + + return &BackendPluginClientV5{ + client: pluginClient, + Backend: backend, + }, nil +} + +func NewPluginClientV5(ctx context.Context, sys pluginutil.RunnerUtil, config pluginutil.PluginClientConfig) (logical.Backend, error) { + pluginClient, err := sys.NewPluginClient(ctx, config) + if err != nil { + return nil, err + } + + // Request the plugin + raw, err := pluginClient.Dispense("backend") + if err != nil { + return nil, err + } + + var backend logical.Backend + var transport string + // We should have a logical backend type now. This feels like a normal interface + // implementation but is in fact over an RPC connection. + switch c := raw.(type) { + case *backendGRPCPluginClient: + // This is an abstraction leak from go-plugin but it is necessary in + // order to enable multiplexing on multiplexed plugins + c.client = pb.NewBackendClient(pluginClient.Conn()) + + backend = c + transport = "gRPC" + default: + return nil, errors.New("unsupported plugin client type") + } + + // Wrap the backend in a tracing middleware + if config.Logger.IsTrace() { + backend = &BackendTracingMiddleware{ + logger: config.Logger.With("transport", transport), + next: backend, + } + } + + return &BackendPluginClientV5{ + client: pluginClient, + Backend: backend, + }, nil +} diff --git a/sdk/plugin/serve.go b/sdk/plugin/serve.go index 1119a2dac..0da143f76 100644 --- a/sdk/plugin/serve.go +++ b/sdk/plugin/serve.go @@ -55,6 +55,13 @@ func Serve(opts *ServeOpts) error { Logger: logger, }, }, + 5: { + "backend": &GRPCBackendPlugin{ + Factory: opts.BackendFactoryFunc, + MultiplexingSupport: false, + Logger: logger, + }, + }, } err := pluginutil.OptionallyEnableMlock() @@ -63,7 +70,7 @@ func Serve(opts *ServeOpts) error { } serveOpts := &plugin.ServeConfig{ - HandshakeConfig: handshakeConfig, + HandshakeConfig: HandshakeConfig, VersionedPlugins: pluginSets, TLSProvider: opts.TLSProviderFunc, Logger: logger, @@ -81,12 +88,77 @@ func Serve(opts *ServeOpts) error { return nil } +// ServeMultiplex is a helper function used to serve a backend plugin. This +// should be ran on the plugin's main process. +func ServeMultiplex(opts *ServeOpts) error { + logger := opts.Logger + if logger == nil { + logger = log.New(&log.LoggerOptions{ + Level: log.Info, + Output: os.Stderr, + JSONFormat: true, + }) + } + + // pluginMap is the map of plugins we can dispense. + pluginSets := map[int]plugin.PluginSet{ + // Version 3 used to supports both protocols. We want to keep it around + // since it's possible old plugins built against this version will still + // work with gRPC. There is currently no difference between version 3 + // and version 4. + 3: { + "backend": &GRPCBackendPlugin{ + Factory: opts.BackendFactoryFunc, + Logger: logger, + }, + }, + 4: { + "backend": &GRPCBackendPlugin{ + Factory: opts.BackendFactoryFunc, + Logger: logger, + }, + }, + 5: { + "backend": &GRPCBackendPlugin{ + Factory: opts.BackendFactoryFunc, + MultiplexingSupport: true, + Logger: logger, + }, + }, + } + + err := pluginutil.OptionallyEnableMlock() + if err != nil { + return err + } + + serveOpts := &plugin.ServeConfig{ + HandshakeConfig: HandshakeConfig, + VersionedPlugins: pluginSets, + Logger: logger, + + // A non-nil value here enables gRPC serving for this plugin... + GRPCServer: func(opts []grpc.ServerOption) *grpc.Server { + opts = append(opts, grpc.MaxRecvMsgSize(math.MaxInt32)) + opts = append(opts, grpc.MaxSendMsgSize(math.MaxInt32)) + return plugin.DefaultGRPCServer(opts) + }, + + // TLSProvider is required to support v3 and v4 plugins. + // It will be ignored for v5 which uses AutoMTLS + TLSProvider: opts.TLSProviderFunc, + } + + plugin.Serve(serveOpts) + + return nil +} + // handshakeConfigs are used to just do a basic handshake between // a plugin and host. If the handshake fails, a user friendly error is shown. // This prevents users from executing bad plugins or executing a plugin // directory. It is a UX feature, not a security feature. -var handshakeConfig = plugin.HandshakeConfig{ - ProtocolVersion: 4, +var HandshakeConfig = plugin.HandshakeConfig{ MagicCookieKey: "VAULT_BACKEND_PLUGIN", MagicCookieValue: "6669da05-b1c8-4f49-97d9-c8e5bed98e20", } diff --git a/vault/auth.go b/vault/auth.go index 58cbd2e38..0720cc656 100644 --- a/vault/auth.go +++ b/vault/auth.go @@ -7,7 +7,7 @@ import ( "strings" "github.com/hashicorp/go-secure-stdlib/strutil" - "github.com/hashicorp/go-uuid" + uuid "github.com/hashicorp/go-uuid" "github.com/hashicorp/vault/builtin/plugin" "github.com/hashicorp/vault/helper/namespace" "github.com/hashicorp/vault/sdk/helper/consts" @@ -924,7 +924,6 @@ func (c *Core) newCredentialBackend(ctx context.Context, entry *MountEntry, sysV f = wrapFactoryCheckPerms(c, plugin.Factory) } } - // Set up conf to pass in plugin_name conf := make(map[string]string) for k, v := range entry.Options { diff --git a/vault/logical_system_integ_test.go b/vault/logical_system_integ_test.go index 174231642..e573e431e 100644 --- a/vault/logical_system_integ_test.go +++ b/vault/logical_system_integ_test.go @@ -1,6 +1,7 @@ package vault_test import ( + "context" "fmt" "io/ioutil" "os" @@ -31,219 +32,325 @@ const ( expectedEnvValue = "BAR" ) +// logicalVersionMap is a map of version to test plugin +var logicalVersionMap = map[string]string{ + "v4": "TestBackend_PluginMain_V4_Logical", + "v5": "TestBackend_PluginMainLogical", + "v5_multiplexed": "TestBackend_PluginMain_Multiplexed_Logical", +} + +// credentialVersionMap is a map of version to test plugin +var credentialVersionMap = map[string]string{ + "v4": "TestBackend_PluginMain_V4_Credentials", + "v5": "TestBackend_PluginMainCredentials", + "v5_multiplexed": "TestBackend_PluginMain_Multiplexed_Credentials", +} + +var testCtx = context.TODO() + func TestSystemBackend_Plugin_secret(t *testing.T) { - cluster := testSystemBackendMock(t, 1, 1, logical.TypeLogical) - defer cluster.Cleanup() - - core := cluster.Cores[0] - - // Make a request to lazy load the plugin - req := logical.TestRequest(t, logical.ReadOperation, "mock-0/internal") - req.ClientToken = core.Client.Token() - resp, err := core.HandleRequest(namespace.RootContext(nil), req) - if err != nil { - t.Fatalf("err: %v", err) - } - if resp == nil { - t.Fatalf("bad: response should not be nil") + testCases := []struct { + pluginVersion string + }{ + { + pluginVersion: "v5_multiplexed", + }, + { + pluginVersion: "v5", + }, + { + pluginVersion: "v4", + }, } - // Seal the cluster - cluster.EnsureCoresSealed(t) + for _, tc := range testCases { + t.Run(tc.pluginVersion, func(t *testing.T) { + cluster := testSystemBackendMock(t, 1, 1, logical.TypeLogical, tc.pluginVersion) + defer cluster.Cleanup() - // Unseal the cluster - barrierKeys := cluster.BarrierKeys - for _, core := range cluster.Cores { - for _, key := range barrierKeys { - _, err := core.Unseal(vault.TestKeyCopy(key)) + core := cluster.Cores[0] + + // Make a request to lazy load the plugin + req := logical.TestRequest(t, logical.ReadOperation, "mock-0/internal") + req.ClientToken = core.Client.Token() + resp, err := core.HandleRequest(namespace.RootContext(testCtx), req) if err != nil { - t.Fatal(err) + t.Fatalf("err: %v", err) } - } - if core.Sealed() { - t.Fatal("should not be sealed") - } - // Wait for active so post-unseal takes place - // If it fails, it means unseal process failed - vault.TestWaitActive(t, core.Core) + if resp == nil { + t.Fatalf("bad: response should not be nil") + } + + // Seal the cluster + cluster.EnsureCoresSealed(t) + + // Unseal the cluster + barrierKeys := cluster.BarrierKeys + for _, core := range cluster.Cores { + for _, key := range barrierKeys { + _, err := core.Unseal(vault.TestKeyCopy(key)) + if err != nil { + t.Fatal(err) + } + } + if core.Sealed() { + t.Fatal("should not be sealed") + } + // Wait for active so post-unseal takes place + // If it fails, it means unseal process failed + vault.TestWaitActive(t, core.Core) + } + }) } } func TestSystemBackend_Plugin_auth(t *testing.T) { - cluster := testSystemBackendMock(t, 1, 1, logical.TypeCredential) - defer cluster.Cleanup() - - core := cluster.Cores[0] - - // Make a request to lazy load the plugin - req := logical.TestRequest(t, logical.ReadOperation, "auth/mock-0/internal") - req.ClientToken = core.Client.Token() - resp, err := core.HandleRequest(namespace.RootContext(nil), req) - if err != nil { - t.Fatalf("err: %v", err) - } - if resp == nil { - t.Fatalf("bad: response should not be nil") + testCases := []struct { + pluginVersion string + }{ + { + pluginVersion: "v5_multiplexed", + }, + { + pluginVersion: "v5", + }, + { + pluginVersion: "v4", + }, } - // Seal the cluster - cluster.EnsureCoresSealed(t) + for _, tc := range testCases { + t.Run(tc.pluginVersion, func(t *testing.T) { + cluster := testSystemBackendMock(t, 1, 1, logical.TypeCredential, tc.pluginVersion) + defer cluster.Cleanup() - // Unseal the cluster - barrierKeys := cluster.BarrierKeys - for _, core := range cluster.Cores { - for _, key := range barrierKeys { - _, err := core.Unseal(vault.TestKeyCopy(key)) + core := cluster.Cores[0] + + // Make a request to lazy load the plugin + req := logical.TestRequest(t, logical.ReadOperation, "auth/mock-0/internal") + req.ClientToken = core.Client.Token() + resp, err := core.HandleRequest(namespace.RootContext(testCtx), req) if err != nil { - t.Fatal(err) + t.Fatalf("err: %v", err) } - } - if core.Sealed() { - t.Fatal("should not be sealed") - } - // Wait for active so post-unseal takes place - // If it fails, it means unseal process failed - vault.TestWaitActive(t, core.Core) + if resp == nil { + t.Fatalf("bad: response should not be nil") + } + + // Seal the cluster + cluster.EnsureCoresSealed(t) + + // Unseal the cluster + barrierKeys := cluster.BarrierKeys + for _, core := range cluster.Cores { + for _, key := range barrierKeys { + _, err := core.Unseal(vault.TestKeyCopy(key)) + if err != nil { + t.Fatal(err) + } + } + if core.Sealed() { + t.Fatal("should not be sealed") + } + // Wait for active so post-unseal takes place + // If it fails, it means unseal process failed + vault.TestWaitActive(t, core.Core) + } + }) } } func TestSystemBackend_Plugin_MissingBinary(t *testing.T) { - cluster := testSystemBackendMock(t, 1, 1, logical.TypeLogical) - defer cluster.Cleanup() - - core := cluster.Cores[0] - - // Make a request to lazy load the plugin - req := logical.TestRequest(t, logical.ReadOperation, "mock-0/internal") - req.ClientToken = core.Client.Token() - resp, err := core.HandleRequest(namespace.RootContext(nil), req) - if err != nil { - t.Fatalf("err: %v", err) - } - if resp == nil { - t.Fatalf("bad: response should not be nil") + testCases := []struct { + pluginVersion string + }{ + { + pluginVersion: "v5_multiplexed", + }, + { + pluginVersion: "v5", + }, + { + pluginVersion: "v4", + }, } - // Seal the cluster - cluster.EnsureCoresSealed(t) + for _, tc := range testCases { + t.Run(tc.pluginVersion, func(t *testing.T) { + cluster := testSystemBackendMock(t, 1, 1, logical.TypeLogical, tc.pluginVersion) + defer cluster.Cleanup() - // Simulate removal of the plugin binary. Use os.Args to determine file name - // since that's how we create the file for catalog registration in the test - // helper. - pluginFileName := filepath.Base(os.Args[0]) - err = os.Remove(filepath.Join(cluster.TempDir, pluginFileName)) - if err != nil { - t.Fatal(err) - } + core := cluster.Cores[0] - // Unseal the cluster - cluster.UnsealCores(t) + // Make a request to lazy load the plugin + req := logical.TestRequest(t, logical.ReadOperation, "mock-0/internal") + req.ClientToken = core.Client.Token() + resp, err := core.HandleRequest(namespace.RootContext(testCtx), req) + if err != nil { + t.Fatalf("err: %v", err) + } + if resp == nil { + t.Fatalf("bad: response should not be nil") + } - // Make a request against on tune after it is removed - req = logical.TestRequest(t, logical.ReadOperation, "sys/mounts/mock-0/tune") - req.ClientToken = core.Client.Token() - resp, err = core.HandleRequest(namespace.RootContext(nil), req) - if err == nil { - t.Fatalf("expected error") + // Seal the cluster + cluster.EnsureCoresSealed(t) + + // Simulate removal of the plugin binary. Use os.Args to determine file name + // since that's how we create the file for catalog registration in the test + // helper. + pluginFileName := filepath.Base(os.Args[0]) + err = os.Remove(filepath.Join(cluster.TempDir, pluginFileName)) + if err != nil { + t.Fatal(err) + } + + // Unseal the cluster + cluster.UnsealCores(t) + + // Make a request against on tune after it is removed + req = logical.TestRequest(t, logical.ReadOperation, "sys/mounts/mock-0/tune") + req.ClientToken = core.Client.Token() + _, err = core.HandleRequest(namespace.RootContext(testCtx), req) + if err == nil { + t.Fatalf("expected error") + } + }) } } func TestSystemBackend_Plugin_MismatchType(t *testing.T) { - cluster := testSystemBackendMock(t, 1, 1, logical.TypeLogical) - defer cluster.Cleanup() - - core := cluster.Cores[0] - - // Add a credential backend with the same name - vault.TestAddTestPlugin(t, core.Core, "mock-plugin", consts.PluginTypeCredential, "TestBackend_PluginMainCredentials", []string{}, "") - - // Make a request to lazy load the now-credential plugin - // and expect an error - req := logical.TestRequest(t, logical.ReadOperation, "mock-0/internal") - req.ClientToken = core.Client.Token() - _, err := core.HandleRequest(namespace.RootContext(nil), req) - if err != nil { - t.Fatalf("adding a same-named plugin of a different type should be no problem: %s", err) + testCases := []struct { + pluginVersion string + }{ + { + pluginVersion: "v5_multiplexed", + }, + { + pluginVersion: "v5", + }, + { + pluginVersion: "v4", + }, } - // Sleep a bit before cleanup is called - time.Sleep(1 * time.Second) + for _, tc := range testCases { + t.Run(tc.pluginVersion, func(t *testing.T) { + cluster := testSystemBackendMock(t, 1, 1, logical.TypeLogical, tc.pluginVersion) + defer cluster.Cleanup() + + core := cluster.Cores[0] + + // Add a credential backend with the same name + vault.TestAddTestPlugin(t, core.Core, "mock-plugin", consts.PluginTypeCredential, "TestBackend_PluginMainCredentials", []string{}, "") + + // Make a request to lazy load the now-credential plugin + // and expect an error + req := logical.TestRequest(t, logical.ReadOperation, "mock-0/internal") + req.ClientToken = core.Client.Token() + _, err := core.HandleRequest(namespace.RootContext(testCtx), req) + if err != nil { + t.Fatalf("adding a same-named plugin of a different type should be no problem: %s", err) + } + + // Sleep a bit before cleanup is called + time.Sleep(1 * time.Second) + }) + } } func TestSystemBackend_Plugin_CatalogRemoved(t *testing.T) { t.Run("secret", func(t *testing.T) { - testPlugin_CatalogRemoved(t, logical.TypeLogical, false) + testPlugin_CatalogRemoved(t, logical.TypeLogical, false, logicalVersionMap) }) t.Run("auth", func(t *testing.T) { - testPlugin_CatalogRemoved(t, logical.TypeCredential, false) + testPlugin_CatalogRemoved(t, logical.TypeCredential, false, credentialVersionMap) }) t.Run("secret-mount-existing", func(t *testing.T) { - testPlugin_CatalogRemoved(t, logical.TypeLogical, true) + testPlugin_CatalogRemoved(t, logical.TypeLogical, true, logicalVersionMap) }) t.Run("auth-mount-existing", func(t *testing.T) { - testPlugin_CatalogRemoved(t, logical.TypeCredential, true) + testPlugin_CatalogRemoved(t, logical.TypeCredential, true, credentialVersionMap) }) } -func testPlugin_CatalogRemoved(t *testing.T, btype logical.BackendType, testMount bool) { - cluster := testSystemBackendMock(t, 1, 1, btype) - defer cluster.Cleanup() - - core := cluster.Cores[0] - - // Remove the plugin from the catalog - req := logical.TestRequest(t, logical.DeleteOperation, "sys/plugins/catalog/database/mock-plugin") - req.ClientToken = core.Client.Token() - resp, err := core.HandleRequest(namespace.RootContext(nil), req) - if err != nil || (resp != nil && resp.IsError()) { - t.Fatalf("err:%v resp:%#v", err, resp) +func testPlugin_CatalogRemoved(t *testing.T, btype logical.BackendType, testMount bool, versionMap map[string]string) { + testCases := []struct { + pluginVersion string + }{ + { + pluginVersion: "v5_multiplexed", + }, + { + pluginVersion: "v5", + }, + { + pluginVersion: "v4", + }, } - // Seal the cluster - cluster.EnsureCoresSealed(t) + for _, tc := range testCases { + t.Run(tc.pluginVersion, func(t *testing.T) { + cluster := testSystemBackendMock(t, 1, 1, logical.TypeLogical, tc.pluginVersion) + defer cluster.Cleanup() - // Unseal the cluster - barrierKeys := cluster.BarrierKeys - for _, core := range cluster.Cores { - for _, key := range barrierKeys { - _, err := core.Unseal(vault.TestKeyCopy(key)) - if err != nil { - t.Fatal(err) + core := cluster.Cores[0] + + // Remove the plugin from the catalog + req := logical.TestRequest(t, logical.DeleteOperation, "sys/plugins/catalog/database/mock-plugin") + req.ClientToken = core.Client.Token() + resp, err := core.HandleRequest(namespace.RootContext(testCtx), req) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%v resp:%#v", err, resp) } - } - if core.Sealed() { - t.Fatal("should not be sealed") - } - } - // Wait for active so post-unseal takes place - // If it fails, it means unseal process failed - vault.TestWaitActive(t, core.Core) + // Seal the cluster + cluster.EnsureCoresSealed(t) - if testMount { - // Mount the plugin at the same path after plugin is re-added to the catalog - // and expect an error due to existing path. - var err error - switch btype { - case logical.TypeLogical: - // Add plugin back to the catalog - vault.TestAddTestPlugin(t, core.Core, "mock-plugin", consts.PluginTypeSecrets, "TestBackend_PluginMainLogical", []string{}, "") - _, err = core.Client.Logical().Write("sys/mounts/mock-0", map[string]interface{}{ - "type": "test", - }) - case logical.TypeCredential: - // Add plugin back to the catalog - vault.TestAddTestPlugin(t, core.Core, "mock-plugin", consts.PluginTypeCredential, "TestBackend_PluginMainCredentials", []string{}, "") - _, err = core.Client.Logical().Write("sys/auth/mock-0", map[string]interface{}{ - "type": "test", - }) - } - if err == nil { - t.Fatal("expected error when mounting on existing path") - } + // Unseal the cluster + barrierKeys := cluster.BarrierKeys + for _, core := range cluster.Cores { + for _, key := range barrierKeys { + _, err := core.Unseal(vault.TestKeyCopy(key)) + if err != nil { + t.Fatal(err) + } + } + if core.Sealed() { + t.Fatal("should not be sealed") + } + } + + // Wait for active so post-unseal takes place + // If it fails, it means unseal process failed + vault.TestWaitActive(t, core.Core) + + if testMount { + // Mount the plugin at the same path after plugin is re-added to the catalog + // and expect an error due to existing path. + var err error + switch btype { + case logical.TypeLogical: + // Add plugin back to the catalog + vault.TestAddTestPlugin(t, core.Core, "mock-plugin", consts.PluginTypeSecrets, logicalVersionMap[tc.pluginVersion], []string{}, "") + _, err = core.Client.Logical().Write("sys/mounts/mock-0", map[string]interface{}{ + "type": "test", + }) + case logical.TypeCredential: + // Add plugin back to the catalog + vault.TestAddTestPlugin(t, core.Core, "mock-plugin", consts.PluginTypeCredential, credentialVersionMap[tc.pluginVersion], []string{}, "") + _, err = core.Client.Logical().Write("sys/auth/mock-0", map[string]interface{}{ + "type": "test", + }) + } + if err == nil { + t.Fatal("expected error when mounting on existing path") + } + } + }) } } @@ -278,180 +385,224 @@ func TestSystemBackend_Plugin_continueOnError(t *testing.T) { } func testPlugin_continueOnError(t *testing.T, btype logical.BackendType, mismatch bool, mountPoint string, pluginType consts.PluginType) { - cluster := testSystemBackendMock(t, 1, 1, btype) - defer cluster.Cleanup() - - core := cluster.Cores[0] - - // Get the registered plugin - req := logical.TestRequest(t, logical.ReadOperation, fmt.Sprintf("sys/plugins/catalog/%s/mock-plugin", pluginType)) - req.ClientToken = core.Client.Token() - resp, err := core.HandleRequest(namespace.RootContext(nil), req) - if err != nil || resp == nil || (resp != nil && resp.IsError()) { - t.Fatalf("err:%v resp:%#v", err, resp) + testCases := []struct { + pluginVersion string + }{ + { + pluginVersion: "v5_multiplexed", + }, + { + pluginVersion: "v5", + }, + { + pluginVersion: "v4", + }, } - command, ok := resp.Data["command"].(string) - if !ok || command == "" { - t.Fatal("invalid command") - } + for _, tc := range testCases { + t.Run(tc.pluginVersion, func(t *testing.T) { + cluster := testSystemBackendMock(t, 1, 1, btype, tc.pluginVersion) + defer cluster.Cleanup() - // Mount credential type plugins - switch btype { - case logical.TypeCredential: - vault.TestAddTestPlugin(t, core.Core, mountPoint, consts.PluginTypeCredential, "TestBackend_PluginMainCredentials", []string{}, cluster.TempDir) - _, err = core.Client.Logical().Write(fmt.Sprintf("sys/auth/%s", mountPoint), map[string]interface{}{ - "type": "mock-plugin", - }) - if err != nil { - t.Fatalf("err:%v", err) - } - } + core := cluster.Cores[0] - // Trigger a sha256 mismatch or missing plugin error - if mismatch { - req = logical.TestRequest(t, logical.UpdateOperation, fmt.Sprintf("sys/plugins/catalog/%s/mock-plugin", pluginType)) - req.Data = map[string]interface{}{ - "sha256": "d17bd7334758e53e6fbab15745d2520765c06e296f2ce8e25b7919effa0ac216", - "command": filepath.Base(command), - } - req.ClientToken = core.Client.Token() - resp, err = core.HandleRequest(namespace.RootContext(nil), req) - if err != nil || (resp != nil && resp.IsError()) { - t.Fatalf("err:%v resp:%#v", err, resp) - } - } else { - err := os.Remove(filepath.Join(cluster.TempDir, filepath.Base(command))) - if err != nil { - t.Fatal(err) - } - } - - // Seal the cluster - cluster.EnsureCoresSealed(t) - - // Unseal the cluster - barrierKeys := cluster.BarrierKeys - for _, core := range cluster.Cores { - for _, key := range barrierKeys { - _, err := core.Unseal(vault.TestKeyCopy(key)) - if err != nil { - t.Fatal(err) + // Get the registered plugin + req := logical.TestRequest(t, logical.ReadOperation, fmt.Sprintf("sys/plugins/catalog/%s/mock-plugin", pluginType)) + req.ClientToken = core.Client.Token() + resp, err := core.HandleRequest(namespace.RootContext(testCtx), req) + if err != nil || resp == nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%v resp:%#v", err, resp) } - } - if core.Sealed() { - t.Fatal("should not be sealed") - } - } - // Wait for active so post-unseal takes place - // If it fails, it means unseal process failed - vault.TestWaitActive(t, core.Core) + command, ok := resp.Data["command"].(string) + if !ok || command == "" { + t.Fatal("invalid command") + } - // Re-add the plugin to the catalog - switch btype { - case logical.TypeLogical: - vault.TestAddTestPlugin(t, core.Core, "mock-plugin", consts.PluginTypeSecrets, "TestBackend_PluginMainLogical", []string{}, cluster.TempDir) - case logical.TypeCredential: - vault.TestAddTestPlugin(t, core.Core, "mock-plugin", consts.PluginTypeCredential, "TestBackend_PluginMainCredentials", []string{}, cluster.TempDir) - } + // Trigger a sha256 mismatch or missing plugin error + if mismatch { + req = logical.TestRequest(t, logical.UpdateOperation, fmt.Sprintf("sys/plugins/catalog/%s/mock-plugin", pluginType)) + req.Data = map[string]interface{}{ + "sha256": "d17bd7334758e53e6fbab15745d2520765c06e296f2ce8e25b7919effa0ac216", + "command": filepath.Base(command), + } + req.ClientToken = core.Client.Token() + resp, err = core.HandleRequest(namespace.RootContext(testCtx), req) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%v resp:%#v", err, resp) + } + } else { + err := os.Remove(filepath.Join(cluster.TempDir, filepath.Base(command))) + if err != nil { + t.Fatal(err) + } + } - // Reload the plugin - req = logical.TestRequest(t, logical.UpdateOperation, "sys/plugins/reload/backend") - req.Data = map[string]interface{}{ - "plugin": "mock-plugin", - } - req.ClientToken = core.Client.Token() - resp, err = core.HandleRequest(namespace.RootContext(nil), req) - if err != nil || (resp != nil && resp.IsError()) { - t.Fatalf("err:%v resp:%#v", err, resp) - } + // Seal the cluster + cluster.EnsureCoresSealed(t) - // Make a request to lazy load the plugin - var reqPath string - switch btype { - case logical.TypeLogical: - reqPath = "mock-0/internal" - case logical.TypeCredential: - reqPath = "auth/mock-0/internal" - } + // Unseal the cluster + barrierKeys := cluster.BarrierKeys + for _, core := range cluster.Cores { + for _, key := range barrierKeys { + _, err := core.Unseal(vault.TestKeyCopy(key)) + if err != nil { + t.Fatal(err) + } + } + if core.Sealed() { + t.Fatal("should not be sealed") + } + } - req = logical.TestRequest(t, logical.ReadOperation, reqPath) - req.ClientToken = core.Client.Token() - resp, err = core.HandleRequest(namespace.RootContext(nil), req) - if err != nil { - t.Fatalf("err: %v", err) - } - if resp == nil { - t.Fatalf("bad: response should not be nil") + // Wait for active so post-unseal takes place + // If it fails, it means unseal process failed + vault.TestWaitActive(t, core.Core) + + // Re-add the plugin to the catalog + switch btype { + case logical.TypeLogical: + plugin := logicalVersionMap[tc.pluginVersion] + vault.TestAddTestPlugin(t, core.Core, "mock-plugin", consts.PluginTypeSecrets, plugin, []string{}, cluster.TempDir) + case logical.TypeCredential: + plugin := credentialVersionMap[tc.pluginVersion] + vault.TestAddTestPlugin(t, core.Core, "mock-plugin", consts.PluginTypeCredential, plugin, []string{}, cluster.TempDir) + } + + // Reload the plugin + req = logical.TestRequest(t, logical.UpdateOperation, "sys/plugins/reload/backend") + req.Data = map[string]interface{}{ + "plugin": "mock-plugin", + } + req.ClientToken = core.Client.Token() + resp, err = core.HandleRequest(namespace.RootContext(testCtx), req) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%v resp:%#v", err, resp) + } + + // Make a request to lazy load the plugin + var reqPath string + switch btype { + case logical.TypeLogical: + reqPath = "mock-0/internal" + case logical.TypeCredential: + reqPath = "auth/mock-0/internal" + } + + req = logical.TestRequest(t, logical.ReadOperation, reqPath) + req.ClientToken = core.Client.Token() + resp, err = core.HandleRequest(namespace.RootContext(testCtx), req) + if err != nil { + t.Fatalf("err: %v", err) + } + if resp == nil { + t.Fatalf("bad: response should not be nil") + } + }) } } func TestSystemBackend_Plugin_autoReload(t *testing.T) { - cluster := testSystemBackendMock(t, 1, 1, logical.TypeLogical) - defer cluster.Cleanup() - - core := cluster.Cores[0] - - // Update internal value - req := logical.TestRequest(t, logical.UpdateOperation, "mock-0/internal") - req.ClientToken = core.Client.Token() - req.Data["value"] = "baz" - resp, err := core.HandleRequest(namespace.RootContext(nil), req) - if err != nil { - t.Fatalf("err: %v", err) - } - if resp != nil { - t.Fatalf("bad: %v", resp) + testCases := []struct { + pluginVersion string + }{ + { + pluginVersion: "v5_multiplexed", + }, + { + pluginVersion: "v5", + }, + { + pluginVersion: "v4", + }, } - // Call errors/rpc endpoint to trigger reload - req = logical.TestRequest(t, logical.ReadOperation, "mock-0/errors/rpc") - req.ClientToken = core.Client.Token() - resp, err = core.HandleRequest(namespace.RootContext(nil), req) - if err == nil { - t.Fatalf("expected error from error/rpc request") - } + for _, tc := range testCases { + t.Run(tc.pluginVersion, func(t *testing.T) { + cluster := testSystemBackendMock(t, 1, 1, logical.TypeLogical, tc.pluginVersion) + defer cluster.Cleanup() - // Check internal value to make sure it's reset - req = logical.TestRequest(t, logical.ReadOperation, "mock-0/internal") - req.ClientToken = core.Client.Token() - resp, err = core.HandleRequest(namespace.RootContext(nil), req) - if err != nil { - t.Fatalf("err: %v", err) - } - if resp == nil { - t.Fatalf("bad: response should not be nil") - } - if resp.Data["value"].(string) == "baz" { - t.Fatal("did not expect backend internal value to be 'baz'") + core := cluster.Cores[0] + + // Update internal value + req := logical.TestRequest(t, logical.UpdateOperation, "mock-0/internal") + req.ClientToken = core.Client.Token() + req.Data["value"] = "baz" + resp, err := core.HandleRequest(namespace.RootContext(testCtx), req) + if err != nil { + t.Fatalf("err: %v", err) + } + if resp != nil { + t.Fatalf("bad: %v", resp) + } + + // Call errors/rpc endpoint to trigger reload + req = logical.TestRequest(t, logical.ReadOperation, "mock-0/errors/rpc") + req.ClientToken = core.Client.Token() + _, err = core.HandleRequest(namespace.RootContext(testCtx), req) + if err == nil { + t.Fatalf("expected error from error/rpc request") + } + + // Check internal value to make sure it's reset + req = logical.TestRequest(t, logical.ReadOperation, "mock-0/internal") + req.ClientToken = core.Client.Token() + resp, err = core.HandleRequest(namespace.RootContext(testCtx), req) + if err != nil { + t.Fatalf("err: %v", err) + } + if resp == nil { + t.Fatalf("bad: response should not be nil") + } + if resp.Data["value"].(string) == "baz" { + t.Fatal("did not expect backend internal value to be 'baz'") + } + }) } } func TestSystemBackend_Plugin_SealUnseal(t *testing.T) { - cluster := testSystemBackendMock(t, 1, 1, logical.TypeLogical) - defer cluster.Cleanup() - - // Seal the cluster - cluster.EnsureCoresSealed(t) - - // Unseal the cluster - barrierKeys := cluster.BarrierKeys - for _, core := range cluster.Cores { - for _, key := range barrierKeys { - _, err := core.Unseal(vault.TestKeyCopy(key)) - if err != nil { - t.Fatal(err) - } - } - if core.Sealed() { - t.Fatal("should not be sealed") - } + testCases := []struct { + pluginVersion string + }{ + { + pluginVersion: "v5_multiplexed", + }, + { + pluginVersion: "v5", + }, + { + pluginVersion: "v4", + }, } - // Wait for active so post-unseal takes place - // If it fails, it means unseal process failed - vault.TestWaitActive(t, cluster.Cores[0].Core) + for _, tc := range testCases { + t.Run(tc.pluginVersion, func(t *testing.T) { + cluster := testSystemBackendMock(t, 1, 1, logical.TypeLogical, tc.pluginVersion) + defer cluster.Cleanup() + + // Seal the cluster + cluster.EnsureCoresSealed(t) + + // Unseal the cluster + barrierKeys := cluster.BarrierKeys + for _, core := range cluster.Cores { + for _, key := range barrierKeys { + _, err := core.Unseal(vault.TestKeyCopy(key)) + if err != nil { + t.Fatal(err) + } + } + if core.Sealed() { + t.Fatal("should not be sealed") + } + } + + // Wait for active so post-unseal takes place + // If it fails, it means unseal process failed + vault.TestWaitActive(t, cluster.Cores[0].Core) + }) + } } func TestSystemBackend_Plugin_reload(t *testing.T) { @@ -498,53 +649,71 @@ func TestSystemBackend_Plugin_reload(t *testing.T) { // Helper func to test different reload methods on plugin reload endpoint func testSystemBackend_PluginReload(t *testing.T, reqData map[string]interface{}, backendType logical.BackendType) { - cluster := testSystemBackendMock(t, 1, 2, backendType) - defer cluster.Cleanup() - - core := cluster.Cores[0] - client := core.Client - - pathPrefix := "mock-" - if backendType == logical.TypeCredential { - pathPrefix = "auth/" + pathPrefix + testCases := []struct { + pluginVersion string + }{ + { + pluginVersion: "v5_multiplexed", + }, + { + pluginVersion: "v5", + }, + { + pluginVersion: "v4", + }, } - for i := 0; i < 2; i++ { - // Update internal value in the backend - resp, err := client.Logical().Write(fmt.Sprintf("%s%d/internal", pathPrefix, i), map[string]interface{}{ - "value": "baz", + + for _, tc := range testCases { + t.Run(tc.pluginVersion, func(t *testing.T) { + cluster := testSystemBackendMock(t, 1, 2, backendType, tc.pluginVersion) + defer cluster.Cleanup() + + core := cluster.Cores[0] + client := core.Client + + pathPrefix := "mock-" + if backendType == logical.TypeCredential { + pathPrefix = "auth/" + pathPrefix + } + for i := 0; i < 2; i++ { + // Update internal value in the backend + resp, err := client.Logical().Write(fmt.Sprintf("%s%d/internal", pathPrefix, i), map[string]interface{}{ + "value": "baz", + }) + if err != nil { + t.Fatalf("err: %v", err) + } + if resp != nil { + t.Fatalf("bad: %v", resp) + } + } + + // Perform plugin reload + resp, err := client.Logical().Write("sys/plugins/reload/backend", reqData) + if err != nil { + t.Fatalf("err: %v", err) + } + if resp == nil { + t.Fatalf("bad: %v", resp) + } + if resp.Data["reload_id"] == nil { + t.Fatal("no reload_id in response") + } + + for i := 0; i < 2; i++ { + // Ensure internal backed value is reset + resp, err := client.Logical().Read(fmt.Sprintf("%s%d/internal", pathPrefix, i)) + if err != nil { + t.Fatalf("err: %v", err) + } + if resp == nil { + t.Fatalf("bad: response should not be nil") + } + if resp.Data["value"].(string) == "baz" { + t.Fatal("did not expect backend internal value to be 'baz'") + } + } }) - if err != nil { - t.Fatalf("err: %v", err) - } - if resp != nil { - t.Fatalf("bad: %v", resp) - } - } - - // Perform plugin reload - resp, err := client.Logical().Write("sys/plugins/reload/backend", reqData) - if err != nil { - t.Fatalf("err: %v", err) - } - if resp == nil { - t.Fatalf("bad: %v", resp) - } - if resp.Data["reload_id"] == nil { - t.Fatal("no reload_id in response") - } - - for i := 0; i < 2; i++ { - // Ensure internal backed value is reset - resp, err := client.Logical().Read(fmt.Sprintf("%s%d/internal", pathPrefix, i)) - if err != nil { - t.Fatalf("err: %v", err) - } - if resp == nil { - t.Fatalf("bad: response should not be nil") - } - if resp.Data["value"].(string) == "baz" { - t.Fatal("did not expect backend internal value to be 'baz'") - } } } @@ -553,7 +722,7 @@ func testSystemBackend_PluginReload(t *testing.T, reqData map[string]interface{} // ways of providing the plugin_name. // // The mounts are mounted at sys/mounts/mock-[numMounts] or sys/auth/mock-[numMounts] -func testSystemBackendMock(t *testing.T, numCores, numMounts int, backendType logical.BackendType) *vault.TestCluster { +func testSystemBackendMock(t *testing.T, numCores, numMounts int, backendType logical.BackendType, pluginVersion string) *vault.TestCluster { coreConfig := &vault.CoreConfig{ LogicalBackends: map[string]logical.Factory{ "plugin": plugin.Factory, @@ -585,7 +754,8 @@ func testSystemBackendMock(t *testing.T, numCores, numMounts int, backendType lo switch backendType { case logical.TypeLogical: - vault.TestAddTestPlugin(t, core.Core, "mock-plugin", consts.PluginTypeSecrets, "TestBackend_PluginMainLogical", []string{}, tempDir) + plugin := logicalVersionMap[pluginVersion] + vault.TestAddTestPlugin(t, core.Core, "mock-plugin", consts.PluginTypeSecrets, plugin, []string{}, tempDir) for i := 0; i < numMounts; i++ { // Alternate input styles for plugin_name on every other mount options := map[string]interface{}{ @@ -600,7 +770,8 @@ func testSystemBackendMock(t *testing.T, numCores, numMounts int, backendType lo } } case logical.TypeCredential: - vault.TestAddTestPlugin(t, core.Core, "mock-plugin", consts.PluginTypeCredential, "TestBackend_PluginMainCredentials", []string{}, tempDir) + plugin := credentialVersionMap[pluginVersion] + vault.TestAddTestPlugin(t, core.Core, "mock-plugin", consts.PluginTypeCredential, plugin, []string{}, tempDir) for i := 0; i < numMounts; i++ { // Alternate input styles for plugin_name on every other mount options := map[string]interface{}{ @@ -671,9 +842,15 @@ func testSystemBackend_SingleCluster_Env(t *testing.T, env []string) *vault.Test return cluster } -func TestBackend_PluginMainLogical(t *testing.T) { +func TestBackend_PluginMain_V4_Logical(t *testing.T) { args := []string{} - if os.Getenv(pluginutil.PluginUnwrapTokenEnv) == "" && os.Getenv(pluginutil.PluginMetadataModeEnv) != "true" { + // don't run as a standalone unit test + if os.Getenv(pluginutil.PluginVaultVersionEnv) == "" { + return + } + + // don't run as a V5 plugin + if os.Getenv(pluginutil.PluginAutoMTLSEnv) == "true" { return } @@ -686,6 +863,8 @@ func TestBackend_PluginMainLogical(t *testing.T) { apiClientMeta := &api.PluginAPIClientMeta{} flags := apiClientMeta.FlagSet() flags.Parse(args) + + // V4 does not support AutoMTLS so we set a TLSConfig via TLSProviderFunc tlsConfig := apiClientMeta.GetTLSConfig() tlsProviderFunc := api.VaultPluginTLSProvider(tlsConfig) @@ -700,9 +879,9 @@ func TestBackend_PluginMainLogical(t *testing.T) { } } -func TestBackend_PluginMainCredentials(t *testing.T) { +func TestBackend_PluginMain_Multiplexed_Logical(t *testing.T) { args := []string{} - if os.Getenv(pluginutil.PluginUnwrapTokenEnv) == "" && os.Getenv(pluginutil.PluginMetadataModeEnv) != "true" { + if os.Getenv(pluginutil.PluginVaultVersionEnv) == "" { return } @@ -715,6 +894,66 @@ func TestBackend_PluginMainCredentials(t *testing.T) { apiClientMeta := &api.PluginAPIClientMeta{} flags := apiClientMeta.FlagSet() flags.Parse(args) + + factoryFunc := mock.FactoryType(logical.TypeLogical) + + err := lplugin.ServeMultiplex(&lplugin.ServeOpts{ + BackendFactoryFunc: factoryFunc, + }) + if err != nil { + t.Fatal(err) + } +} + +func TestBackend_PluginMainLogical(t *testing.T) { + args := []string{} + if os.Getenv(pluginutil.PluginVaultVersionEnv) == "" { + return + } + + caPEM := os.Getenv(pluginutil.PluginCACertPEMEnv) + if caPEM == "" { + t.Fatal("CA cert not passed in") + } + args = append(args, fmt.Sprintf("--ca-cert=%s", caPEM)) + + apiClientMeta := &api.PluginAPIClientMeta{} + flags := apiClientMeta.FlagSet() + flags.Parse(args) + + factoryFunc := mock.FactoryType(logical.TypeLogical) + + err := lplugin.Serve(&lplugin.ServeOpts{ + BackendFactoryFunc: factoryFunc, + }) + if err != nil { + t.Fatal(err) + } +} + +func TestBackend_PluginMain_V4_Credentials(t *testing.T) { + args := []string{} + // don't run as a standalone unit test + if os.Getenv(pluginutil.PluginVaultVersionEnv) == "" { + return + } + + // don't run as a V5 plugin + if os.Getenv(pluginutil.PluginAutoMTLSEnv) == "true" { + return + } + + caPEM := os.Getenv(pluginutil.PluginCACertPEMEnv) + if caPEM == "" { + t.Fatal("CA cert not passed in") + } + args = append(args, fmt.Sprintf("--ca-cert=%s", caPEM)) + + apiClientMeta := &api.PluginAPIClientMeta{} + flags := apiClientMeta.FlagSet() + flags.Parse(args) + + // V4 does not support AutoMTLS so we set a TLSConfig via TLSProviderFunc tlsConfig := apiClientMeta.GetTLSConfig() tlsProviderFunc := api.VaultPluginTLSProvider(tlsConfig) @@ -729,6 +968,58 @@ func TestBackend_PluginMainCredentials(t *testing.T) { } } +func TestBackend_PluginMain_Multiplexed_Credentials(t *testing.T) { + args := []string{} + if os.Getenv(pluginutil.PluginVaultVersionEnv) == "" { + return + } + + caPEM := os.Getenv(pluginutil.PluginCACertPEMEnv) + if caPEM == "" { + t.Fatal("CA cert not passed in") + } + args = append(args, fmt.Sprintf("--ca-cert=%s", caPEM)) + + apiClientMeta := &api.PluginAPIClientMeta{} + flags := apiClientMeta.FlagSet() + flags.Parse(args) + + factoryFunc := mock.FactoryType(logical.TypeCredential) + + err := lplugin.ServeMultiplex(&lplugin.ServeOpts{ + BackendFactoryFunc: factoryFunc, + }) + if err != nil { + t.Fatal(err) + } +} + +func TestBackend_PluginMainCredentials(t *testing.T) { + args := []string{} + if os.Getenv(pluginutil.PluginVaultVersionEnv) == "" { + return + } + + caPEM := os.Getenv(pluginutil.PluginCACertPEMEnv) + if caPEM == "" { + t.Fatal("CA cert not passed in") + } + args = append(args, fmt.Sprintf("--ca-cert=%s", caPEM)) + + apiClientMeta := &api.PluginAPIClientMeta{} + flags := apiClientMeta.FlagSet() + flags.Parse(args) + + factoryFunc := mock.FactoryType(logical.TypeCredential) + + err := lplugin.Serve(&lplugin.ServeOpts{ + BackendFactoryFunc: factoryFunc, + }) + if err != nil { + t.Fatal(err) + } +} + // TestBackend_PluginMainEnv is a mock plugin that simply checks for the existence of FOO env var. func TestBackend_PluginMainEnv(t *testing.T) { args := []string{} @@ -751,14 +1042,11 @@ func TestBackend_PluginMainEnv(t *testing.T) { apiClientMeta := &api.PluginAPIClientMeta{} flags := apiClientMeta.FlagSet() flags.Parse(args) - tlsConfig := apiClientMeta.GetTLSConfig() - tlsProviderFunc := api.VaultPluginTLSProvider(tlsConfig) factoryFunc := mock.FactoryType(logical.TypeLogical) err := lplugin.Serve(&lplugin.ServeOpts{ BackendFactoryFunc: factoryFunc, - TLSProviderFunc: tlsProviderFunc, }) if err != nil { t.Fatal(err) diff --git a/vault/mount.go b/vault/mount.go index 4ab0e5200..f69cd072c 100644 --- a/vault/mount.go +++ b/vault/mount.go @@ -1432,7 +1432,6 @@ func (c *Core) newLogicalBackend(ctx context.Context, entry *MountEntry, sysView f = wrapFactoryCheckPerms(c, plugin.Factory) } } - // Set up conf to pass in plugin_name conf := make(map[string]string) for k, v := range entry.Options { diff --git a/vault/plugin_catalog.go b/vault/plugin_catalog.go index 155e82ca1..c49b3c67e 100644 --- a/vault/plugin_catalog.go +++ b/vault/plugin_catalog.go @@ -30,10 +30,11 @@ import ( ) var ( - pluginCatalogPath = "core/plugin-catalog/" - ErrDirectoryNotConfigured = errors.New("could not set plugin, plugin directory is not configured") - ErrPluginNotFound = errors.New("plugin not found in the catalog") - ErrPluginBadType = errors.New("unable to determine plugin type") + pluginCatalogPath = "core/plugin-catalog/" + ErrDirectoryNotConfigured = errors.New("could not set plugin, plugin directory is not configured") + ErrPluginNotFound = errors.New("plugin not found in the catalog") + ErrPluginConnectionNotFound = errors.New("plugin connection not found for client") + ErrPluginBadType = errors.New("unable to determine plugin type") ) // PluginCatalog keeps a record of plugins known to vault. External plugins need @@ -78,13 +79,15 @@ type pluginClient struct { logger log.Logger // id is the connection ID - id string + id string + pid int // client handles the lifecycle of a plugin process // multiplexed plugins share the same client client *plugin.Client clientConn grpc.ClientConnInterface cleanupFunc func() error + reloadFunc func() error plugin.ClientProtocol } @@ -148,6 +151,38 @@ func (p *pluginClient) Conn() grpc.ClientConnInterface { return p.clientConn } +func (p *pluginClient) Reload() error { + p.logger.Debug("reload external plugin process") + return p.reloadFunc() +} + +// reloadExternalPlugin +// This should be called with the write lock held. +func (c *PluginCatalog) reloadExternalPlugin(name, id string) error { + extPlugin, ok := c.externalPlugins[name] + if !ok { + return fmt.Errorf("plugin client not found") + } + if !extPlugin.multiplexingSupport { + err := c.cleanupExternalPlugin(name, id) + if err != nil { + return err + } + return nil + } + + pc, ok := extPlugin.connections[id] + if !ok { + return fmt.Errorf("%w id: %s", ErrPluginConnectionNotFound, id) + } + + delete(c.externalPlugins, name) + pc.client.Kill() + c.logger.Debug("killed external plugin process for reload", "name", name, "pid", pc.pid) + + return nil +} + // Close calls the plugin client's cleanupFunc to do any necessary cleanup on // the plugin client and the PluginCatalog. This implements the // plugin.ClientProtocol interface. @@ -167,19 +202,26 @@ func (c *PluginCatalog) cleanupExternalPlugin(name, id string) error { pc, ok := extPlugin.connections[id] if !ok { - return fmt.Errorf("plugin connection not found") + // this can happen if the backend is reloaded due to a plugin process + // being killed out of band + c.logger.Warn(ErrPluginConnectionNotFound.Error(), "id", id) + return fmt.Errorf("%w id: %s", ErrPluginConnectionNotFound, id) } delete(extPlugin.connections, id) + c.logger.Debug("removed plugin client connection", "id", id) + if !extPlugin.multiplexingSupport { pc.client.Kill() if len(extPlugin.connections) == 0 { delete(c.externalPlugins, name) } + c.logger.Debug("killed external plugin process", "name", name, "pid", pc.pid) } else if len(extPlugin.connections) == 0 || pc.client.Exited() { pc.client.Kill() delete(c.externalPlugins, name) + c.logger.Debug("killed external multiplexed plugin process", "name", name, "pid", pc.pid) } return nil @@ -252,6 +294,11 @@ func (c *PluginCatalog) newPluginClient(ctx context.Context, pluginRunner *plugi defer c.lock.Unlock() return c.cleanupExternalPlugin(pluginRunner.Name, id) }, + reloadFunc: func() error { + c.lock.Lock() + defer c.lock.Unlock() + return c.reloadExternalPlugin(pluginRunner.Name, id) + }, } // Multiplexing support will always be false initially, but will be @@ -264,9 +311,8 @@ func (c *PluginCatalog) newPluginClient(ctx context.Context, pluginRunner *plugi pluginutil.Logger(config.Logger), pluginutil.MetadataMode(config.IsMetadataMode), pluginutil.MLock(c.mlockPlugins), - - // NewPluginClient only supports AutoMTLS today - pluginutil.AutoMTLS(true), + pluginutil.AutoMTLS(config.AutoMTLS), + pluginutil.Runner(config.Wrapper), ) if err != nil { return nil, err @@ -294,6 +340,12 @@ func (c *PluginCatalog) newPluginClient(ctx context.Context, pluginRunner *plugi return nil, err } + // get the external plugin pid + conf := pc.client.ReattachConfig() + if conf != nil { + pc.pid = conf.Pid + } + clientConn := rpcClient.(*plugin.GRPCClient).Conn muxed, err := pluginutil.MultiplexingSupported(ctx, clientConn) @@ -322,9 +374,8 @@ func (c *PluginCatalog) newPluginClient(ctx context.Context, pluginRunner *plugi } // getPluginTypeFromUnknown will attempt to run the plugin to determine the -// type and if it supports multiplexing. It will first attempt to run as a -// database plugin then a backend plugin. Both of these will be run in metadata -// mode. +// type. It will first attempt to run as a database plugin then a backend +// plugin. func (c *PluginCatalog) getPluginTypeFromUnknown(ctx context.Context, logger log.Logger, plugin *pluginutil.PluginRunner) (consts.PluginType, error) { merr := &multierror.Error{} err := c.isDatabasePlugin(ctx, plugin) @@ -333,43 +384,95 @@ func (c *PluginCatalog) getPluginTypeFromUnknown(ctx context.Context, logger log } merr = multierror.Append(merr, err) - // Attempt to run as backend plugin - client, err := backendplugin.NewPluginClient(ctx, nil, plugin, log.NewNullLogger(), true) + pluginType, err := c.getBackendPluginType(ctx, plugin) if err == nil { - err := client.Setup(ctx, &logical.BackendConfig{}) + return pluginType, nil + } + merr = multierror.Append(merr, err) + + return consts.PluginTypeUnknown, merr +} + +// getBackendPluginType returns an error if the plugin is not a backend plugin. +func (c *PluginCatalog) getBackendPluginType(ctx context.Context, pluginRunner *pluginutil.PluginRunner) (consts.PluginType, error) { + merr := &multierror.Error{} + // Attempt to run as backend plugin + config := pluginutil.PluginClientConfig{ + Name: pluginRunner.Name, + PluginSets: backendplugin.PluginSet, + HandshakeConfig: backendplugin.HandshakeConfig, + Logger: log.NewNullLogger(), + IsMetadataMode: false, + AutoMTLS: true, + } + + var client logical.Backend + var attemptV4 bool + // First, attempt to run as backend V5 plugin + c.logger.Debug("attempting to load backend plugin", "name", pluginRunner.Name) + pc, err := c.newPluginClient(ctx, pluginRunner, config) + if err == nil { + // we spawned a subprocess, so make sure to clean it up + defer c.cleanupExternalPlugin(pluginRunner.Name, pc.id) + + // dispense the plugin so we can get its type + client, err = backendplugin.Dispense(pc.ClientProtocol, pc) if err != nil { - return consts.PluginTypeUnknown, err - } - - backendType := client.Type() - client.Cleanup(ctx) - - switch backendType { - case logical.TypeCredential: - return consts.PluginTypeCredential, nil - case logical.TypeLogical: - return consts.PluginTypeSecrets, nil + merr = multierror.Append(merr, fmt.Errorf("failed to dispense plugin as backend v5: %w", err)) + c.logger.Debug("failed to dispense v5 backend plugin", "name", pluginRunner.Name) + attemptV4 = true + } else { + c.logger.Debug("successfully dispensed v5 backend plugin", "name", pluginRunner.Name) } } else { - merr = multierror.Append(merr, err) + attemptV4 = true + } + + if attemptV4 { + c.logger.Debug("failed to dispense v5 backend plugin", "name", pluginRunner.Name, "error", err) + config.AutoMTLS = false + config.IsMetadataMode = true + // attempt to run as a v4 backend plugin + client, err = backendplugin.NewPluginClient(ctx, nil, pluginRunner, log.NewNullLogger(), true) + if err != nil { + c.logger.Debug("failed to dispense v4 backend plugin", "name", pluginRunner.Name, "error", err) + merr = multierror.Append(merr, fmt.Errorf("failed to dispense v4 backend plugin: %w", err)) + return consts.PluginTypeUnknown, merr.ErrorOrNil() + } + c.logger.Debug("successfully dispensed v4 backend plugin", "name", pluginRunner.Name) + defer client.Cleanup(ctx) + } + + err = client.Setup(ctx, &logical.BackendConfig{}) + if err != nil { + return consts.PluginTypeUnknown, err + } + backendType := client.Type() + + switch backendType { + case logical.TypeCredential: + return consts.PluginTypeCredential, nil + case logical.TypeLogical: + return consts.PluginTypeSecrets, nil } if client == nil || client.Type() == logical.TypeUnknown { - logger.Warn("unknown plugin type", - "plugin name", plugin.Name, + c.logger.Warn("unknown plugin type", + "plugin name", pluginRunner.Name, "error", merr.Error()) } else { - logger.Warn("unsupported plugin type", - "plugin name", plugin.Name, + c.logger.Warn("unsupported plugin type", + "plugin name", pluginRunner.Name, "plugin type", client.Type().String(), "error", merr.Error()) } - return consts.PluginTypeUnknown, nil + merr = multierror.Append(merr, fmt.Errorf("failed to load plugin as backend plugin: %w", err)) + + return consts.PluginTypeUnknown, merr.ErrorOrNil() } -// isDatabasePlugin returns true if the plugin supports multiplexing. An error -// is returned if the plugin is not a database plugin. +// isDatabasePlugin returns an error if the plugin is not a database plugin. func (c *PluginCatalog) isDatabasePlugin(ctx context.Context, pluginRunner *pluginutil.PluginRunner) error { merr := &multierror.Error{} config := pluginutil.PluginClientConfig{ diff --git a/vault/plugin_catalog_test.go b/vault/plugin_catalog_test.go index 0e3b31e54..e2a16d8b1 100644 --- a/vault/plugin_catalog_test.go +++ b/vault/plugin_catalog_test.go @@ -10,11 +10,13 @@ import ( "sort" "testing" - log "github.com/hashicorp/go-hclog" + "github.com/hashicorp/vault/api" + "github.com/hashicorp/vault/builtin/credential/userpass" "github.com/hashicorp/vault/plugins/database/postgresql" v5 "github.com/hashicorp/vault/sdk/database/dbplugin/v5" "github.com/hashicorp/vault/sdk/helper/consts" "github.com/hashicorp/vault/sdk/helper/pluginutil" + backendplugin "github.com/hashicorp/vault/sdk/plugin" "github.com/hashicorp/vault/helper/builtinplugins" ) @@ -59,7 +61,7 @@ func TestPluginCatalog_CRUD(t *testing.T) { } defer file.Close() - command := fmt.Sprintf("%s", filepath.Base(file.Name())) + command := filepath.Base(file.Name()) err = core.pluginCatalog.Set(context.Background(), pluginName, consts.PluginTypeDatabase, "", command, []string{"--test"}, []string{"FOO=BAR"}, []byte{'1'}) if err != nil { t.Fatal(err) @@ -375,47 +377,109 @@ func TestPluginCatalog_NewPluginClient(t *testing.T) { TestAddTestPlugin(t, core, "single-postgres-1", consts.PluginTypeUnknown, "TestPluginCatalog_PluginMain_Postgres", []string{}, "") TestAddTestPlugin(t, core, "single-postgres-2", consts.PluginTypeUnknown, "TestPluginCatalog_PluginMain_Postgres", []string{}, "") + TestAddTestPlugin(t, core, "mux-userpass", consts.PluginTypeUnknown, "TestPluginCatalog_PluginMain_UserpassMultiplexed", []string{}, "") + TestAddTestPlugin(t, core, "single-userpass-1", consts.PluginTypeUnknown, "TestPluginCatalog_PluginMain_Userpass", []string{}, "") + TestAddTestPlugin(t, core, "single-userpass-2", consts.PluginTypeUnknown, "TestPluginCatalog_PluginMain_Userpass", []string{}, "") + + var pluginClients []*pluginClient // run plugins - if _, err := core.pluginCatalog.NewPluginClient(context.Background(), testPluginClientConfig("mux-postgres")); err != nil { - t.Fatal(err) - } - if _, err := core.pluginCatalog.NewPluginClient(context.Background(), testPluginClientConfig("mux-postgres")); err != nil { - t.Fatal(err) - } - if _, err := core.pluginCatalog.NewPluginClient(context.Background(), testPluginClientConfig("single-postgres-1")); err != nil { - t.Fatal(err) - } - if _, err := core.pluginCatalog.NewPluginClient(context.Background(), testPluginClientConfig("single-postgres-2")); err != nil { - t.Fatal(err) - } + // run "mux-postgres" twice which will start a single plugin for 2 + // distinct connections + c := TestRunTestPlugin(t, core, consts.PluginTypeDatabase, "mux-postgres") + pluginClients = append(pluginClients, c) + c = TestRunTestPlugin(t, core, consts.PluginTypeDatabase, "mux-postgres") + pluginClients = append(pluginClients, c) + c = TestRunTestPlugin(t, core, consts.PluginTypeDatabase, "single-postgres-1") + pluginClients = append(pluginClients, c) + c = TestRunTestPlugin(t, core, consts.PluginTypeDatabase, "single-postgres-2") + pluginClients = append(pluginClients, c) + + // run "mux-userpass" twice which will start a single plugin for 2 + // distinct connections + c = TestRunTestPlugin(t, core, consts.PluginTypeCredential, "mux-userpass") + pluginClients = append(pluginClients, c) + c = TestRunTestPlugin(t, core, consts.PluginTypeCredential, "mux-userpass") + pluginClients = append(pluginClients, c) + c = TestRunTestPlugin(t, core, consts.PluginTypeCredential, "single-userpass-1") + pluginClients = append(pluginClients, c) + c = TestRunTestPlugin(t, core, consts.PluginTypeCredential, "single-userpass-2") + pluginClients = append(pluginClients, c) externalPlugins := core.pluginCatalog.externalPlugins - if len(externalPlugins) != 3 { - t.Fatalf("expected externalPlugins map to be of len 3 but got %d", len(externalPlugins)) + if len(externalPlugins) != 6 { + t.Fatalf("expected externalPlugins map to be of len 6 but got %d", len(externalPlugins)) } // check connections map - expectedLen := 2 - if len(externalPlugins["mux-postgres"].connections) != expectedLen { - t.Fatalf("expected multiplexed external plugin's connections map to be of len %d but got %d", expectedLen, len(externalPlugins["mux-postgres"].connections)) - } - expectedLen = 1 - if len(externalPlugins["single-postgres-1"].connections) != expectedLen { - t.Fatalf("expected multiplexed external plugin's connections map to be of len %d but got %d", expectedLen, len(externalPlugins["mux-postgres"].connections)) - } - if len(externalPlugins["single-postgres-2"].connections) != expectedLen { - t.Fatalf("expected multiplexed external plugin's connections map to be of len %d but got %d", expectedLen, len(externalPlugins["mux-postgres"].connections)) - } + expectConnectionLen(t, 2, externalPlugins["mux-postgres"].connections) + expectConnectionLen(t, 1, externalPlugins["single-postgres-1"].connections) + expectConnectionLen(t, 1, externalPlugins["single-postgres-2"].connections) + expectConnectionLen(t, 2, externalPlugins["mux-userpass"].connections) + expectConnectionLen(t, 1, externalPlugins["single-userpass-1"].connections) + expectConnectionLen(t, 1, externalPlugins["single-userpass-2"].connections) // check multiplexing support - if !externalPlugins["mux-postgres"].multiplexingSupport { - t.Fatalf("expected external plugin to be multiplexed") + expectMultiplexingSupport(t, true, externalPlugins["mux-postgres"].multiplexingSupport) + expectMultiplexingSupport(t, false, externalPlugins["single-postgres-1"].multiplexingSupport) + expectMultiplexingSupport(t, false, externalPlugins["single-postgres-2"].multiplexingSupport) + expectMultiplexingSupport(t, true, externalPlugins["mux-userpass"].multiplexingSupport) + expectMultiplexingSupport(t, false, externalPlugins["single-userpass-1"].multiplexingSupport) + expectMultiplexingSupport(t, false, externalPlugins["single-userpass-2"].multiplexingSupport) + + // cleanup all of the external plugin processes + for _, client := range pluginClients { + client.Close() } - if externalPlugins["single-postgres-1"].multiplexingSupport { - t.Fatalf("expected external plugin to be non-multiplexed") + + // check that externalPlugins map is cleaned up + if len(externalPlugins) != 0 { + t.Fatalf("expected external plugin map to be of len 0 but got %d", len(externalPlugins)) } - if externalPlugins["single-postgres-2"].multiplexingSupport { - t.Fatalf("expected external plugin to be non-multiplexed") +} + +func TestPluginCatalog_PluginMain_Userpass(t *testing.T) { + if os.Getenv(pluginutil.PluginVaultVersionEnv) == "" { + return + } + + apiClientMeta := &api.PluginAPIClientMeta{} + flags := apiClientMeta.FlagSet() + flags.Parse(os.Args[1:]) + + tlsConfig := apiClientMeta.GetTLSConfig() + tlsProviderFunc := api.VaultPluginTLSProvider(tlsConfig) + + err := backendplugin.Serve( + &backendplugin.ServeOpts{ + BackendFactoryFunc: userpass.Factory, + TLSProviderFunc: tlsProviderFunc, + }, + ) + if err != nil { + t.Fatalf("Failed to initialize userpass: %s", err) + } +} + +func TestPluginCatalog_PluginMain_UserpassMultiplexed(t *testing.T) { + if os.Getenv(pluginutil.PluginVaultVersionEnv) == "" { + return + } + + apiClientMeta := &api.PluginAPIClientMeta{} + flags := apiClientMeta.FlagSet() + flags.Parse(os.Args[1:]) + + tlsConfig := apiClientMeta.GetTLSConfig() + tlsProviderFunc := api.VaultPluginTLSProvider(tlsConfig) + + err := backendplugin.ServeMultiplex( + &backendplugin.ServeOpts{ + BackendFactoryFunc: userpass.Factory, + TLSProviderFunc: tlsProviderFunc, + }, + ) + if err != nil { + t.Fatalf("Failed to initialize userpass: %s", err) } } @@ -440,14 +504,16 @@ func TestPluginCatalog_PluginMain_PostgresMultiplexed(_ *testing.T) { v5.ServeMultiplex(postgresql.New) } -func testPluginClientConfig(pluginName string) pluginutil.PluginClientConfig { - return pluginutil.PluginClientConfig{ - Name: pluginName, - PluginType: consts.PluginTypeDatabase, - PluginSets: v5.PluginSets, - HandshakeConfig: v5.HandshakeConfig, - Logger: log.NewNullLogger(), - IsMetadataMode: false, - AutoMTLS: true, +// expectConnectionLen asserts that the PluginCatalog's externalPlugin +// connections map has a length of expectedLen +func expectConnectionLen(t *testing.T, expectedLen int, connections map[string]*pluginClient) { + if len(connections) != expectedLen { + t.Fatalf("expected external plugin's connections map to be of len %d but got %d", expectedLen, len(connections)) + } +} + +func expectMultiplexingSupport(t *testing.T, expected, actual bool) { + if expected != actual { + t.Fatalf("expected external plugin multiplexing support to be %t", expected) } } diff --git a/vault/plugin_reload.go b/vault/plugin_reload.go index b0c513e25..e60228a9b 100644 --- a/vault/plugin_reload.go +++ b/vault/plugin_reload.go @@ -10,6 +10,7 @@ import ( multierror "github.com/hashicorp/go-multierror" "github.com/hashicorp/go-secure-stdlib/strutil" "github.com/hashicorp/vault/sdk/logical" + "github.com/hashicorp/vault/sdk/plugin" ) // reloadMatchingPluginMounts reloads provided mounts, regardless of @@ -147,8 +148,11 @@ func (c *Core) reloadBackendCommon(ctx context.Context, entry *MountEntry, isAut // Only call Cleanup if backend is initialized if re.backend != nil { + // Pass a context value so that the plugin client will call the + // appropriate cleanup method for reloading + reloadCtx := context.WithValue(ctx, plugin.ContextKeyPluginReload, "reload") // Call backend's Cleanup routine - re.backend.Cleanup(ctx) + re.backend.Cleanup(reloadCtx) } view := re.storageView diff --git a/vault/testing.go b/vault/testing.go index 55ac127e6..31bf8e575 100644 --- a/vault/testing.go +++ b/vault/testing.go @@ -39,13 +39,16 @@ import ( "github.com/hashicorp/vault/helper/namespace" "github.com/hashicorp/vault/internalshared/configutil" dbMysql "github.com/hashicorp/vault/plugins/database/mysql" + v5 "github.com/hashicorp/vault/sdk/database/dbplugin/v5" "github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/helper/consts" "github.com/hashicorp/vault/sdk/helper/logging" + "github.com/hashicorp/vault/sdk/helper/pluginutil" "github.com/hashicorp/vault/sdk/helper/salt" "github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/sdk/physical" physInmem "github.com/hashicorp/vault/sdk/physical/inmem" + backendplugin "github.com/hashicorp/vault/sdk/plugin" "github.com/hashicorp/vault/vault/cluster" "github.com/hashicorp/vault/vault/seal" "github.com/mitchellh/copystructure" @@ -563,6 +566,48 @@ func TestAddTestPlugin(t testing.T, c *Core, name string, pluginType consts.Plug } } +// TestRunTestPlugin runs the testFunc which has already been registered to the +// plugin catalog and returns a pluginClient. This can be called after calling +// TestAddTestPlugin. +func TestRunTestPlugin(t testing.T, c *Core, pluginType consts.PluginType, pluginName string) *pluginClient { + t.Helper() + config := TestPluginClientConfig(c, pluginType, pluginName) + client, err := c.pluginCatalog.NewPluginClient(context.Background(), config) + if err != nil { + t.Fatal(err) + } + + return client +} + +func TestPluginClientConfig(c *Core, pluginType consts.PluginType, pluginName string) pluginutil.PluginClientConfig { + switch pluginType { + case consts.PluginTypeCredential, consts.PluginTypeSecrets: + dsv := TestDynamicSystemView(c, nil) + return pluginutil.PluginClientConfig{ + Name: pluginName, + PluginType: pluginType, + PluginSets: backendplugin.PluginSet, + HandshakeConfig: backendplugin.HandshakeConfig, + Logger: log.NewNullLogger(), + AutoMTLS: true, + IsMetadataMode: false, + Wrapper: dsv, + } + case consts.PluginTypeDatabase: + return pluginutil.PluginClientConfig{ + Name: pluginName, + PluginType: pluginType, + PluginSets: v5.PluginSets, + HandshakeConfig: v5.HandshakeConfig, + Logger: log.NewNullLogger(), + AutoMTLS: true, + IsMetadataMode: false, + } + } + return pluginutil.PluginClientConfig{} +} + var ( testLogicalBackends = map[string]logical.Factory{} testCredentialBackends = map[string]logical.Factory{} diff --git a/vault/wrapping.go b/vault/wrapping.go index 0f613ba79..850d5ad06 100644 --- a/vault/wrapping.go +++ b/vault/wrapping.go @@ -141,7 +141,7 @@ DONELISTHANDLING: NamespaceID: ns.ID, } - if err := c.tokenStore.create(ctx, &te); err != nil { + if err := c.CreateToken(ctx, &te); err != nil { c.logger.Error("failed to create wrapping token", "error", err) return nil, ErrInternalError }