feature: secrets/auth plugin multiplexing (#14946)

* enable registering backend muxed plugins in plugin catalog

* set the sysview on the pluginconfig to allow enabling secrets/auth plugins

* store backend instances in map

* store single implementations in the instances map

cleanup instance map and ensure we don't deadlock

* fix system backend unit tests

move GetMultiplexIDFromContext to pluginutil package

fix pluginutil test

fix dbplugin ut

* return error(s) if we can't get the plugin client

update comments

* refactor/move GetMultiplexIDFromContext test

* add changelog

* remove unnecessary field on pluginClient

* add unit tests to PluginCatalog for secrets/auth plugins

* fix comment

* return pluginClient from TestRunTestPlugin

* add multiplexed backend test

* honor metadatamode value in newbackend pluginconfig

* check that connection exists on cleanup

* add automtls to secrets/auth plugins

* don't remove apiclientmeta parsing

* use formatting directive for fmt.Errorf

* fix ut: remove tls provider func

* remove tlsproviderfunc from backend plugin tests

* use env var to prevent test plugin from running as a unit test

* WIP: remove lazy loading

* move non lazy loaded backend to new package

* use version wrapper for backend plugin factory

* remove backendVersionWrapper type

* implement getBackendPluginType for plugin catalog

* handle backend plugin v4 registration

* add plugin automtls env guard

* modify plugin factory to determine the backend to use

* remove old pluginsets from v5 and log pid in plugin catalog

* add reload mechanism via context

* readd v3 and v4 to pluginset

* call cleanup from reload if non-muxed

* move v5 backend code to new package

* use context reload for for ErrPluginShutdown case

* add wrapper on v5 backend

* fix run config UTs

* fix unit tests

- use v4/v5 mapping for plugin versions
- fix test build err
- add reload method on fakePluginClient
- add multiplexed cases for integration tests

* remove comment and update AutoMTLS field in test

* remove comment

* remove errwrap and unused context

* only support metadatamode false for v5 backend plugins

* update plugin catalog errors

* use const for env variables

* rename locks and remove unused

* remove unneeded nil check

* improvements based on staticcheck recommendations

* use const for single implementation string

* use const for context key

* use info default log level

* move pid to pluginClient struct

* remove v3 and v4 from multiplexed plugin set

* return from reload when non-multiplexed

* update automtls env string

* combine getBackend and getBrokeredClient

* update comments for plugin reload, Backend return val and log

* revert Backend return type

* allow non-muxed plugins to serve v5

* move v5 code to existing sdk plugin package

* do next export sdk fields now that we have removed extra plugin pkg

* set TLSProvider in ServeMultiplex for backwards compat

* use bool to flag multiplexing support on grpc backend server

* revert userpass main.go

* refactor plugin sdk

- update comments
- make use of multiplexing boolean and single implementation ID const

* update comment and use multierr

* attempt v4 if dispense fails on getPluginTypeForUnknown

* update comments on sdk plugin backend
This commit is contained in:
John-Michael Faircloth 2022-08-29 21:42:26 -05:00 committed by GitHub
parent cf332842cc
commit b6c05fae33
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
32 changed files with 1742 additions and 666 deletions

View File

@ -16,7 +16,11 @@ import (
"github.com/hashicorp/errwrap"
)
var (
const (
// PluginAutoMTLSEnv is used to ensure AutoMTLS is used. This will override
// setting a TLSProviderFunc for a plugin.
PluginAutoMTLSEnv = "VAULT_PLUGIN_AUTOMTLS_ENABLED"
// PluginMetadataModeEnv is an ENV name used to disable TLS communication
// to bootstrap mounting plugins.
PluginMetadataModeEnv = "VAULT_PLUGIN_METADATA_MODE"
@ -24,51 +28,51 @@ var (
// PluginUnwrapTokenEnv is the ENV name used to pass unwrap tokens to the
// plugin.
PluginUnwrapTokenEnv = "VAULT_UNWRAP_TOKEN"
// sudoPaths is a map containing the paths that require a token's policy
// to have the "sudo" capability. The keys are the paths as strings, in
// the same format as they are returned by the OpenAPI spec. The values
// are the regular expressions that can be used to test whether a given
// path matches that path or not (useful specifically for the paths that
// contain templated fields.)
sudoPaths = map[string]*regexp.Regexp{
"/auth/token/accessors/": regexp.MustCompile(`^/auth/token/accessors/$`),
"/pki/root": regexp.MustCompile(`^/pki/root$`),
"/pki/root/sign-self-issued": regexp.MustCompile(`^/pki/root/sign-self-issued$`),
"/sys/audit": regexp.MustCompile(`^/sys/audit$`),
"/sys/audit/{path}": regexp.MustCompile(`^/sys/audit/.+$`),
"/sys/auth/{path}": regexp.MustCompile(`^/sys/auth/.+$`),
"/sys/auth/{path}/tune": regexp.MustCompile(`^/sys/auth/.+/tune$`),
"/sys/config/auditing/request-headers": regexp.MustCompile(`^/sys/config/auditing/request-headers$`),
"/sys/config/auditing/request-headers/{header}": regexp.MustCompile(`^/sys/config/auditing/request-headers/.+$`),
"/sys/config/cors": regexp.MustCompile(`^/sys/config/cors$`),
"/sys/config/ui/headers/": regexp.MustCompile(`^/sys/config/ui/headers/$`),
"/sys/config/ui/headers/{header}": regexp.MustCompile(`^/sys/config/ui/headers/.+$`),
"/sys/leases": regexp.MustCompile(`^/sys/leases$`),
"/sys/leases/lookup/": regexp.MustCompile(`^/sys/leases/lookup/$`),
"/sys/leases/lookup/{prefix}": regexp.MustCompile(`^/sys/leases/lookup/.+$`),
"/sys/leases/revoke-force/{prefix}": regexp.MustCompile(`^/sys/leases/revoke-force/.+$`),
"/sys/leases/revoke-prefix/{prefix}": regexp.MustCompile(`^/sys/leases/revoke-prefix/.+$`),
"/sys/plugins/catalog/{name}": regexp.MustCompile(`^/sys/plugins/catalog/[^/]+$`),
"/sys/plugins/catalog/{type}": regexp.MustCompile(`^/sys/plugins/catalog/[\w-]+$`),
"/sys/plugins/catalog/{type}/{name}": regexp.MustCompile(`^/sys/plugins/catalog/[\w-]+/[^/]+$`),
"/sys/raw": regexp.MustCompile(`^/sys/raw$`),
"/sys/raw/{path}": regexp.MustCompile(`^/sys/raw/.+$`),
"/sys/remount": regexp.MustCompile(`^/sys/remount$`),
"/sys/revoke-force/{prefix}": regexp.MustCompile(`^/sys/revoke-force/.+$`),
"/sys/revoke-prefix/{prefix}": regexp.MustCompile(`^/sys/revoke-prefix/.+$`),
"/sys/rotate": regexp.MustCompile(`^/sys/rotate$`),
// enterprise-only paths
"/sys/replication/dr/primary/secondary-token": regexp.MustCompile(`^/sys/replication/dr/primary/secondary-token$`),
"/sys/replication/performance/primary/secondary-token": regexp.MustCompile(`^/sys/replication/performance/primary/secondary-token$`),
"/sys/replication/primary/secondary-token": regexp.MustCompile(`^/sys/replication/primary/secondary-token$`),
"/sys/replication/reindex": regexp.MustCompile(`^/sys/replication/reindex$`),
"/sys/storage/raft/snapshot-auto/config/": regexp.MustCompile(`^/sys/storage/raft/snapshot-auto/config/$`),
"/sys/storage/raft/snapshot-auto/config/{name}": regexp.MustCompile(`^/sys/storage/raft/snapshot-auto/config/[^/]+$`),
}
)
// sudoPaths is a map containing the paths that require a token's policy
// to have the "sudo" capability. The keys are the paths as strings, in
// the same format as they are returned by the OpenAPI spec. The values
// are the regular expressions that can be used to test whether a given
// path matches that path or not (useful specifically for the paths that
// contain templated fields.)
var sudoPaths = map[string]*regexp.Regexp{
"/auth/token/accessors/": regexp.MustCompile(`^/auth/token/accessors/$`),
"/pki/root": regexp.MustCompile(`^/pki/root$`),
"/pki/root/sign-self-issued": regexp.MustCompile(`^/pki/root/sign-self-issued$`),
"/sys/audit": regexp.MustCompile(`^/sys/audit$`),
"/sys/audit/{path}": regexp.MustCompile(`^/sys/audit/.+$`),
"/sys/auth/{path}": regexp.MustCompile(`^/sys/auth/.+$`),
"/sys/auth/{path}/tune": regexp.MustCompile(`^/sys/auth/.+/tune$`),
"/sys/config/auditing/request-headers": regexp.MustCompile(`^/sys/config/auditing/request-headers$`),
"/sys/config/auditing/request-headers/{header}": regexp.MustCompile(`^/sys/config/auditing/request-headers/.+$`),
"/sys/config/cors": regexp.MustCompile(`^/sys/config/cors$`),
"/sys/config/ui/headers/": regexp.MustCompile(`^/sys/config/ui/headers/$`),
"/sys/config/ui/headers/{header}": regexp.MustCompile(`^/sys/config/ui/headers/.+$`),
"/sys/leases": regexp.MustCompile(`^/sys/leases$`),
"/sys/leases/lookup/": regexp.MustCompile(`^/sys/leases/lookup/$`),
"/sys/leases/lookup/{prefix}": regexp.MustCompile(`^/sys/leases/lookup/.+$`),
"/sys/leases/revoke-force/{prefix}": regexp.MustCompile(`^/sys/leases/revoke-force/.+$`),
"/sys/leases/revoke-prefix/{prefix}": regexp.MustCompile(`^/sys/leases/revoke-prefix/.+$`),
"/sys/plugins/catalog/{name}": regexp.MustCompile(`^/sys/plugins/catalog/[^/]+$`),
"/sys/plugins/catalog/{type}": regexp.MustCompile(`^/sys/plugins/catalog/[\w-]+$`),
"/sys/plugins/catalog/{type}/{name}": regexp.MustCompile(`^/sys/plugins/catalog/[\w-]+/[^/]+$`),
"/sys/raw": regexp.MustCompile(`^/sys/raw$`),
"/sys/raw/{path}": regexp.MustCompile(`^/sys/raw/.+$`),
"/sys/remount": regexp.MustCompile(`^/sys/remount$`),
"/sys/revoke-force/{prefix}": regexp.MustCompile(`^/sys/revoke-force/.+$`),
"/sys/revoke-prefix/{prefix}": regexp.MustCompile(`^/sys/revoke-prefix/.+$`),
"/sys/rotate": regexp.MustCompile(`^/sys/rotate$`),
// enterprise-only paths
"/sys/replication/dr/primary/secondary-token": regexp.MustCompile(`^/sys/replication/dr/primary/secondary-token$`),
"/sys/replication/performance/primary/secondary-token": regexp.MustCompile(`^/sys/replication/performance/primary/secondary-token$`),
"/sys/replication/primary/secondary-token": regexp.MustCompile(`^/sys/replication/primary/secondary-token$`),
"/sys/replication/reindex": regexp.MustCompile(`^/sys/replication/reindex$`),
"/sys/storage/raft/snapshot-auto/config/": regexp.MustCompile(`^/sys/storage/raft/snapshot-auto/config/$`),
"/sys/storage/raft/snapshot-auto/config/{name}": regexp.MustCompile(`^/sys/storage/raft/snapshot-auto/config/[^/]+$`),
}
// PluginAPIClientMeta is a helper that plugins can use to configure TLS connections
// back to Vault.
type PluginAPIClientMeta struct {
@ -120,7 +124,7 @@ func VaultPluginTLSProvider(apiTLSConfig *TLSConfig) func() (*tls.Config, error)
// VaultPluginTLSProviderContext is run inside a plugin and retrieves the response
// wrapped TLS certificate from vault. It returns a configured TLS Config.
func VaultPluginTLSProviderContext(ctx context.Context, apiTLSConfig *TLSConfig) func() (*tls.Config, error) {
if os.Getenv(PluginMetadataModeEnv) == "true" {
if os.Getenv(PluginAutoMTLSEnv) == "true" || os.Getenv(PluginMetadataModeEnv) == "true" {
return nil
}

View File

@ -13,7 +13,6 @@ func main() {
apiClientMeta := &api.PluginAPIClientMeta{}
flags := apiClientMeta.FlagSet()
flags.Parse(os.Args[1:])
tlsConfig := apiClientMeta.GetTLSConfig()
tlsProviderFunc := api.VaultPluginTLSProvider(tlsConfig)

View File

@ -9,7 +9,9 @@ import (
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-multierror"
uuid "github.com/hashicorp/go-uuid"
v5 "github.com/hashicorp/vault/builtin/plugin/v5"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/consts"
"github.com/hashicorp/vault/sdk/logical"
@ -23,17 +25,29 @@ var (
// Factory returns a configured plugin logical.Backend.
func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
merr := &multierror.Error{}
_, ok := conf.Config["plugin_name"]
if !ok {
return nil, fmt.Errorf("plugin_name not provided")
}
b, err := Backend(ctx, conf)
b, err := v5.Backend(ctx, conf)
if err == nil {
if err := b.Setup(ctx, conf); err != nil {
return nil, err
}
return b, nil
}
merr = multierror.Append(merr, err)
b, err = Backend(ctx, conf)
if err != nil {
return nil, err
merr = multierror.Append(merr, err)
return nil, fmt.Errorf("invalid backend version: %s", merr)
}
if err := b.Setup(ctx, conf); err != nil {
return nil, err
merr = multierror.Append(merr, err)
return nil, merr.ErrorOrNil()
}
return b, nil
}

View File

@ -24,22 +24,34 @@ func TestBackend_impl(t *testing.T) {
}
func TestBackend(t *testing.T) {
config, cleanup := testConfig(t)
defer cleanup()
pluginCmds := []string{"TestBackend_PluginMain", "TestBackend_PluginMain_Multiplexed"}
_, err := plugin.Backend(context.Background(), config)
if err != nil {
t.Fatal(err)
for _, pluginCmd := range pluginCmds {
t.Run(pluginCmd, func(t *testing.T) {
config, cleanup := testConfig(t, pluginCmd)
defer cleanup()
_, err := plugin.Backend(context.Background(), config)
if err != nil {
t.Fatal(err)
}
})
}
}
func TestBackend_Factory(t *testing.T) {
config, cleanup := testConfig(t)
defer cleanup()
pluginCmds := []string{"TestBackend_PluginMain", "TestBackend_PluginMain_Multiplexed"}
_, err := plugin.Factory(context.Background(), config)
if err != nil {
t.Fatal(err)
for _, pluginCmd := range pluginCmds {
t.Run(pluginCmd, func(t *testing.T) {
config, cleanup := testConfig(t, pluginCmd)
defer cleanup()
_, err := plugin.Factory(context.Background(), config)
if err != nil {
t.Fatal(err)
}
})
}
}
@ -71,7 +83,35 @@ func TestBackend_PluginMain(t *testing.T) {
}
}
func testConfig(t *testing.T) (*logical.BackendConfig, func()) {
func TestBackend_PluginMain_Multiplexed(t *testing.T) {
args := []string{}
if os.Getenv(pluginutil.PluginUnwrapTokenEnv) == "" && os.Getenv(pluginutil.PluginMetadataModeEnv) != "true" {
return
}
caPEM := os.Getenv(pluginutil.PluginCACertPEMEnv)
if caPEM == "" {
t.Fatal("CA cert not passed in")
}
args = append(args, fmt.Sprintf("--ca-cert=%s", caPEM))
apiClientMeta := &api.PluginAPIClientMeta{}
flags := apiClientMeta.FlagSet()
flags.Parse(args)
tlsConfig := apiClientMeta.GetTLSConfig()
tlsProviderFunc := api.VaultPluginTLSProvider(tlsConfig)
err := logicalPlugin.ServeMultiplex(&logicalPlugin.ServeOpts{
BackendFactoryFunc: mock.Factory,
TLSProviderFunc: tlsProviderFunc,
})
if err != nil {
t.Fatal(err)
}
}
func testConfig(t *testing.T, pluginCmd string) (*logical.BackendConfig, func()) {
cluster := vault.NewTestCluster(t, nil, &vault.TestClusterOptions{
HandlerFunc: vaulthttp.Handler,
})
@ -93,7 +133,7 @@ func testConfig(t *testing.T) (*logical.BackendConfig, func()) {
os.Setenv(pluginutil.PluginCACertPEMEnv, cluster.CACertPEMFile)
vault.TestAddTestPlugin(t, core.Core, "mock-plugin", consts.PluginTypeSecrets, "TestBackend_PluginMain", []string{}, "")
vault.TestAddTestPlugin(t, core.Core, "mock-plugin", consts.PluginTypeSecrets, pluginCmd, []string{}, "")
return config, func() {
cluster.Cleanup()

View File

@ -0,0 +1,147 @@
package plugin
import (
"context"
"net/rpc"
"sync"
uuid "github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/sdk/helper/consts"
"github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/sdk/plugin"
bplugin "github.com/hashicorp/vault/sdk/plugin"
)
// Backend returns an instance of the backend, either as a plugin if external
// or as a concrete implementation if builtin, casted as logical.Backend.
func Backend(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
var b backend
name := conf.Config["plugin_name"]
pluginType, err := consts.ParsePluginType(conf.Config["plugin_type"])
if err != nil {
return nil, err
}
sys := conf.System
raw, err := plugin.NewBackendV5(ctx, name, pluginType, sys, conf)
if err != nil {
return nil, err
}
b.Backend = raw
b.config = conf
return &b, nil
}
// backend is a thin wrapper around plugin.BackendPluginClientV5
type backend struct {
logical.Backend
mu sync.RWMutex
config *logical.BackendConfig
// Used to detect if we already reloaded
canary string
}
func (b *backend) reloadBackend(ctx context.Context) error {
pluginName := b.config.Config["plugin_name"]
pluginType, err := consts.ParsePluginType(b.config.Config["plugin_type"])
if err != nil {
return err
}
b.Logger().Debug("plugin: reloading plugin backend", "plugin", pluginName)
// Ensure proper cleanup of the backend
// Pass a context value so that the plugin client will call the appropriate
// cleanup method for reloading
reloadCtx := context.WithValue(ctx, plugin.ContextKeyPluginReload, "reload")
b.Backend.Cleanup(reloadCtx)
nb, err := plugin.NewBackendV5(ctx, pluginName, pluginType, b.config.System, b.config)
if err != nil {
return err
}
err = nb.Setup(ctx, b.config)
if err != nil {
return err
}
b.Backend = nb
return nil
}
// HandleRequest is a thin wrapper implementation of HandleRequest that includes automatic plugin reload.
func (b *backend) HandleRequest(ctx context.Context, req *logical.Request) (*logical.Response, error) {
b.mu.RLock()
canary := b.canary
resp, err := b.Backend.HandleRequest(ctx, req)
b.mu.RUnlock()
// Need to compare string value for case were err comes from plugin RPC
// and is returned as plugin.BasicError type.
if err != nil &&
(err.Error() == rpc.ErrShutdown.Error() || err == bplugin.ErrPluginShutdown) {
// Reload plugin if it's an rpc.ErrShutdown
b.mu.Lock()
if b.canary == canary {
err := b.reloadBackend(ctx)
if err != nil {
b.mu.Unlock()
return nil, err
}
b.canary, err = uuid.GenerateUUID()
if err != nil {
b.mu.Unlock()
return nil, err
}
}
b.mu.Unlock()
// Try request once more
b.mu.RLock()
defer b.mu.RUnlock()
return b.Backend.HandleRequest(ctx, req)
}
return resp, err
}
// HandleExistenceCheck is a thin wrapper implementation of HandleRequest that includes automatic plugin reload.
func (b *backend) HandleExistenceCheck(ctx context.Context, req *logical.Request) (bool, bool, error) {
b.mu.RLock()
canary := b.canary
checkFound, exists, err := b.Backend.HandleExistenceCheck(ctx, req)
b.mu.RUnlock()
if err != nil &&
(err.Error() == rpc.ErrShutdown.Error() || err == bplugin.ErrPluginShutdown) {
// Reload plugin if it's an rpc.ErrShutdown
b.mu.Lock()
if b.canary == canary {
err := b.reloadBackend(ctx)
if err != nil {
b.mu.Unlock()
return false, false, err
}
b.canary, err = uuid.GenerateUUID()
if err != nil {
b.mu.Unlock()
return false, false, err
}
}
b.mu.Unlock()
// Try request once more
b.mu.RLock()
defer b.mu.RUnlock()
return b.Backend.HandleExistenceCheck(ctx, req)
}
return checkFound, exists, err
}
// InvalidateKey is a thin wrapper used to ensure we grab the lock for race purposes
func (b *backend) InvalidateKey(ctx context.Context, key string) {
b.mu.RLock()
defer b.mu.RUnlock()
b.Backend.InvalidateKey(ctx, key)
}

3
changelog/14946.txt Normal file
View File

@ -0,0 +1,3 @@
```release-note:feature
**Secrets/auth plugin multiplexing**: manage multiple plugin configurations with a single plugin process
```

View File

@ -81,14 +81,10 @@ func TestPlugin_PluginMain(t *testing.T) {
flags := apiClientMeta.FlagSet()
flags.Parse(args)
tlsConfig := apiClientMeta.GetTLSConfig()
tlsProviderFunc := api.VaultPluginTLSProvider(tlsConfig)
factoryFunc := mock.FactoryType(logical.TypeLogical)
err := plugin.Serve(&plugin.ServeOpts{
BackendFactoryFunc: factoryFunc,
TLSProviderFunc: tlsProviderFunc,
})
if err != nil {
t.Fatal(err)

View File

@ -10,7 +10,6 @@ import (
"github.com/hashicorp/vault/sdk/database/dbplugin/v5/proto"
"github.com/hashicorp/vault/sdk/helper/pluginutil"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
)
@ -30,25 +29,6 @@ type gRPCServer struct {
sync.RWMutex
}
func getMultiplexIDFromContext(ctx context.Context) (string, error) {
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return "", fmt.Errorf("missing plugin multiplexing metadata")
}
multiplexIDs := md[pluginutil.MultiplexingCtxKey]
if len(multiplexIDs) != 1 {
return "", fmt.Errorf("unexpected number of IDs in metadata: (%d)", len(multiplexIDs))
}
multiplexID := multiplexIDs[0]
if multiplexID == "" {
return "", fmt.Errorf("empty multiplex ID in metadata")
}
return multiplexID, nil
}
func (g *gRPCServer) getOrCreateDatabase(ctx context.Context) (Database, error) {
g.Lock()
defer g.Unlock()
@ -57,7 +37,7 @@ func (g *gRPCServer) getOrCreateDatabase(ctx context.Context) (Database, error)
return g.singleImpl, nil
}
id, err := getMultiplexIDFromContext(ctx)
id, err := pluginutil.GetMultiplexIDFromContext(ctx)
if err != nil {
return nil, err
}
@ -83,7 +63,7 @@ func (g *gRPCServer) getDatabaseInternal(ctx context.Context) (Database, error)
return g.singleImpl, nil
}
id, err := getMultiplexIDFromContext(ctx)
id, err := pluginutil.GetMultiplexIDFromContext(ctx)
if err != nil {
return nil, err
}
@ -312,7 +292,7 @@ func (g *gRPCServer) Close(ctx context.Context, _ *proto.Empty) (*proto.Empty, e
if g.singleImpl == nil {
// only cleanup instances map when multiplexing is supported
id, err := getMultiplexIDFromContext(ctx)
id, err := pluginutil.GetMultiplexIDFromContext(ctx)
if err != nil {
return nil, err
}

View File

@ -581,57 +581,6 @@ func TestGRPCServer_Close(t *testing.T) {
}
}
func TestGetMultiplexIDFromContext(t *testing.T) {
type testCase struct {
ctx context.Context
expectedResp string
expectedErr error
}
tests := map[string]testCase{
"missing plugin multiplexing metadata": {
ctx: context.Background(),
expectedResp: "",
expectedErr: fmt.Errorf("missing plugin multiplexing metadata"),
},
"unexpected number of IDs in metadata": {
ctx: idCtx(t, "12345", "67891"),
expectedResp: "",
expectedErr: fmt.Errorf("unexpected number of IDs in metadata: (2)"),
},
"empty multiplex ID in metadata": {
ctx: idCtx(t, ""),
expectedResp: "",
expectedErr: fmt.Errorf("empty multiplex ID in metadata"),
},
"happy path, id is returned from metadata": {
ctx: idCtx(t, "12345"),
expectedResp: "12345",
expectedErr: nil,
},
}
for name, test := range tests {
t.Run(name, func(t *testing.T) {
resp, err := getMultiplexIDFromContext(test.ctx)
if test.expectedErr != nil && test.expectedErr.Error() != "" && err == nil {
t.Fatalf("err expected, got nil")
} else if !reflect.DeepEqual(err, test.expectedErr) {
t.Fatalf("Actual error: %#v\nExpected error: %#v", err, test.expectedErr)
}
if test.expectedErr != nil && test.expectedErr.Error() == "" && err != nil {
t.Fatalf("no error expected, got: %s", err)
}
if !reflect.DeepEqual(resp, test.expectedResp) {
t.Fatalf("Actual response: %#v\nExpected response: %#v", resp, test.expectedResp)
}
})
}
}
// testGrpcServer is a test helper that returns a context with an ID set in its
// metadata and a gRPCServer instance for a multiplexed plugin
func testGrpcServer(t *testing.T, db Database) (context.Context, gRPCServer) {

View File

@ -112,6 +112,10 @@ func (f *fakePluginClient) Conn() grpc.ClientConnInterface {
return nil
}
func (f *fakePluginClient) Reload() error {
return nil
}
func (f *fakePluginClient) Dispense(name string) (interface{}, error) {
return f.dispenseResp, f.dispenseErr
}

View File

@ -7,7 +7,11 @@ import (
version "github.com/hashicorp/go-version"
)
var (
const (
// PluginAutoMTLSEnv is used to ensure AutoMTLS is used. This will override
// setting a TLSProviderFunc for a plugin.
PluginAutoMTLSEnv = "VAULT_PLUGIN_AUTOMTLS_ENABLED"
// PluginMlockEnabled is the ENV name used to pass the configuration for
// enabling mlock
PluginMlockEnabled = "VAULT_PLUGIN_MLOCK_ENABLED"

View File

@ -6,6 +6,7 @@ import (
grpc "google.golang.org/grpc"
codes "google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
status "google.golang.org/grpc/status"
)
@ -45,3 +46,22 @@ func MultiplexingSupported(ctx context.Context, cc grpc.ClientConnInterface) (bo
return resp.Supported, nil
}
func GetMultiplexIDFromContext(ctx context.Context) (string, error) {
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return "", fmt.Errorf("missing plugin multiplexing metadata")
}
multiplexIDs := md[MultiplexingCtxKey]
if len(multiplexIDs) != 1 {
return "", fmt.Errorf("unexpected number of IDs in metadata: (%d)", len(multiplexIDs))
}
multiplexID := multiplexIDs[0]
if multiplexID == "" {
return "", fmt.Errorf("empty multiplex ID in metadata")
}
return multiplexID, nil
}

View File

@ -0,0 +1,73 @@
package pluginutil
import (
"context"
"fmt"
"reflect"
"testing"
"google.golang.org/grpc/metadata"
)
func TestGetMultiplexIDFromContext(t *testing.T) {
type testCase struct {
ctx context.Context
expectedResp string
expectedErr error
}
tests := map[string]testCase{
"missing plugin multiplexing metadata": {
ctx: context.Background(),
expectedResp: "",
expectedErr: fmt.Errorf("missing plugin multiplexing metadata"),
},
"unexpected number of IDs in metadata": {
ctx: idCtx(t, "12345", "67891"),
expectedResp: "",
expectedErr: fmt.Errorf("unexpected number of IDs in metadata: (2)"),
},
"empty multiplex ID in metadata": {
ctx: idCtx(t, ""),
expectedResp: "",
expectedErr: fmt.Errorf("empty multiplex ID in metadata"),
},
"happy path, id is returned from metadata": {
ctx: idCtx(t, "12345"),
expectedResp: "12345",
expectedErr: nil,
},
}
for name, test := range tests {
t.Run(name, func(t *testing.T) {
resp, err := GetMultiplexIDFromContext(test.ctx)
if test.expectedErr != nil && test.expectedErr.Error() != "" && err == nil {
t.Fatalf("err expected, got nil")
} else if !reflect.DeepEqual(err, test.expectedErr) {
t.Fatalf("Actual error: %#v\nExpected error: %#v", err, test.expectedErr)
}
if test.expectedErr != nil && test.expectedErr.Error() == "" && err != nil {
t.Fatalf("no error expected, got: %s", err)
}
if !reflect.DeepEqual(resp, test.expectedResp) {
t.Fatalf("Actual response: %#v\nExpected response: %#v", resp, test.expectedResp)
}
})
}
}
// idCtx is a test helper that will return a context with the IDs set in its
// metadata
func idCtx(t *testing.T, ids ...string) context.Context {
// Context doesn't need to timeout since this is just passed through
ctx := context.Background()
md := metadata.MD{}
for _, id := range ids {
md.Append(MultiplexingCtxKey, id)
}
return metadata.NewIncomingContext(ctx, md)
}

View File

@ -23,6 +23,7 @@ type PluginClientConfig struct {
IsMetadataMode bool
AutoMTLS bool
MLock bool
Wrapper RunnerUtil
}
type runConfig struct {
@ -34,8 +35,6 @@ type runConfig struct {
// Initialized with what's in PluginRunner.Env, but can be added to
env []string
wrapper RunnerUtil
PluginClientConfig
}
@ -44,7 +43,7 @@ func (rc runConfig) makeConfig(ctx context.Context) (*plugin.ClientConfig, error
cmd.Env = append(cmd.Env, rc.env...)
// Add the mlock setting to the ENV of the plugin
if rc.MLock || (rc.wrapper != nil && rc.wrapper.MlockEnabled()) {
if rc.MLock || (rc.Wrapper != nil && rc.Wrapper.MlockEnabled()) {
cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", PluginMlockEnabled, "true"))
}
cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", PluginVaultVersionEnv, version.GetVersion().Version))
@ -55,6 +54,9 @@ func (rc runConfig) makeConfig(ctx context.Context) (*plugin.ClientConfig, error
metadataEnv := fmt.Sprintf("%s=%t", PluginMetadataModeEnv, rc.IsMetadataMode)
cmd.Env = append(cmd.Env, metadataEnv)
automtlsEnv := fmt.Sprintf("%s=%t", PluginAutoMTLSEnv, rc.AutoMTLS)
cmd.Env = append(cmd.Env, automtlsEnv)
var clientTLSConfig *tls.Config
if !rc.AutoMTLS && !rc.IsMetadataMode {
// Get a CA TLS Certificate
@ -71,7 +73,7 @@ func (rc runConfig) makeConfig(ctx context.Context) (*plugin.ClientConfig, error
// Use CA to sign a server cert and wrap the values in a response wrapped
// token.
wrapToken, err := wrapServerConfig(ctx, rc.wrapper, certBytes, key)
wrapToken, err := wrapServerConfig(ctx, rc.Wrapper, certBytes, key)
if err != nil {
return nil, err
}
@ -121,7 +123,7 @@ func Env(env ...string) RunOpt {
func Runner(wrapper RunnerUtil) RunOpt {
return func(rc *runConfig) {
rc.wrapper = wrapper
rc.Wrapper = wrapper
}
}

View File

@ -4,7 +4,6 @@ import (
"context"
"fmt"
"os/exec"
"reflect"
"testing"
"time"
@ -14,6 +13,7 @@ import (
"github.com/hashicorp/go-plugin"
"github.com/hashicorp/vault/sdk/helper/wrapping"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
func TestMakeConfig(t *testing.T) {
@ -78,6 +78,7 @@ func TestMakeConfig(t *testing.T) {
"initial=true",
fmt.Sprintf("%s=%s", PluginVaultVersionEnv, version.GetVersion().Version),
fmt.Sprintf("%s=%t", PluginMetadataModeEnv, true),
fmt.Sprintf("%s=%t", PluginAutoMTLSEnv, false),
},
),
SecureConfig: &plugin.SecureConfig{
@ -143,6 +144,7 @@ func TestMakeConfig(t *testing.T) {
fmt.Sprintf("%s=%t", PluginMlockEnabled, true),
fmt.Sprintf("%s=%s", PluginVaultVersionEnv, version.GetVersion().Version),
fmt.Sprintf("%s=%t", PluginMetadataModeEnv, false),
fmt.Sprintf("%s=%t", PluginAutoMTLSEnv, false),
fmt.Sprintf("%s=%s", PluginUnwrapTokenEnv, "testtoken"),
},
),
@ -205,6 +207,7 @@ func TestMakeConfig(t *testing.T) {
"initial=true",
fmt.Sprintf("%s=%s", PluginVaultVersionEnv, version.GetVersion().Version),
fmt.Sprintf("%s=%t", PluginMetadataModeEnv, true),
fmt.Sprintf("%s=%t", PluginAutoMTLSEnv, true),
},
),
SecureConfig: &plugin.SecureConfig{
@ -266,6 +269,7 @@ func TestMakeConfig(t *testing.T) {
"initial=true",
fmt.Sprintf("%s=%s", PluginVaultVersionEnv, version.GetVersion().Version),
fmt.Sprintf("%s=%t", PluginMetadataModeEnv, false),
fmt.Sprintf("%s=%t", PluginAutoMTLSEnv, true),
},
),
SecureConfig: &plugin.SecureConfig{
@ -290,7 +294,7 @@ func TestMakeConfig(t *testing.T) {
Return(test.responseWrapInfo, test.responseWrapInfoErr)
mockWrapper.On("MlockEnabled").
Return(test.mlockEnabled)
test.rc.wrapper = mockWrapper
test.rc.Wrapper = mockWrapper
defer mockWrapper.AssertNumberOfCalls(t, "ResponseWrapData", test.responseWrapInfoTimes)
defer mockWrapper.AssertNumberOfCalls(t, "MlockEnabled", test.mlockEnabledTimes)
@ -318,9 +322,7 @@ func TestMakeConfig(t *testing.T) {
}
config.TLSConfig = nil
if !reflect.DeepEqual(config, test.expectedConfig) {
t.Fatalf("Actual config: %#v\nExpected config: %#v", config, test.expectedConfig)
}
require.Equal(t, test.expectedConfig, config)
})
}
}

View File

@ -36,6 +36,7 @@ type LookRunnerUtil interface {
type PluginClient interface {
Conn() grpc.ClientConnInterface
Reload() error
plugin.ClientProtocol
}

View File

@ -8,6 +8,7 @@ import (
log "github.com/hashicorp/go-hclog"
plugin "github.com/hashicorp/go-plugin"
"github.com/hashicorp/vault/sdk/helper/pluginutil"
"github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/sdk/plugin/pb"
)
@ -24,18 +25,32 @@ type GRPCBackendPlugin struct {
MetadataMode bool
Logger log.Logger
MultiplexingSupport bool
// Embeding this will disable the netRPC protocol
plugin.NetRPCUnsupportedPlugin
}
func (b GRPCBackendPlugin) GRPCServer(broker *plugin.GRPCBroker, s *grpc.Server) error {
pb.RegisterBackendServer(s, &backendGRPCPluginServer{
broker: broker,
factory: b.Factory,
// We pass the logger down into the backend so go-plugin will forward
// logs for us.
server := backendGRPCPluginServer{
broker: broker,
factory: b.Factory,
instances: make(map[string]backendInstance),
// We pass the logger down into the backend so go-plugin will
// forward logs for us.
logger: b.Logger,
})
}
if b.MultiplexingSupport {
// Multiplexing is enabled for this plugin, register the server so we
// can tell the client in Vault.
pluginutil.RegisterPluginMultiplexingServer(s, pluginutil.PluginMultiplexingServerImpl{
Supported: true,
})
server.multiplexingSupport = true
}
pb.RegisterBackendServer(s, &server)
return nil
}

View File

@ -3,6 +3,8 @@ package plugin
import (
"context"
"errors"
"fmt"
"sync"
log "github.com/hashicorp/go-hclog"
plugin "github.com/hashicorp/go-plugin"
@ -14,29 +16,79 @@ import (
var ErrServerInMetadataMode = errors.New("plugin server can not perform action while in metadata mode")
// singleImplementationID is the string used to define the instance ID of a
// non-multiplexed plugin
const singleImplementationID string = "single"
type backendInstance struct {
brokeredClient *grpc.ClientConn
backend logical.Backend
}
type backendGRPCPluginServer struct {
pb.UnimplementedBackendServer
broker *plugin.GRPCBroker
backend logical.Backend
broker *plugin.GRPCBroker
instances map[string]backendInstance
instancesLock sync.RWMutex
multiplexingSupport bool
factory logical.Factory
brokeredClient *grpc.ClientConn
logger log.Logger
}
// getBackendAndBrokeredClientInternal returns the backend and client
// connection but does not hold a lock
func (b *backendGRPCPluginServer) getBackendAndBrokeredClientInternal(ctx context.Context) (logical.Backend, *grpc.ClientConn, error) {
if b.multiplexingSupport {
id, err := pluginutil.GetMultiplexIDFromContext(ctx)
if err != nil {
return nil, nil, err
}
if inst, ok := b.instances[id]; ok {
return inst.backend, inst.brokeredClient, nil
}
}
if singleImpl, ok := b.instances[singleImplementationID]; ok {
return singleImpl.backend, singleImpl.brokeredClient, nil
}
return nil, nil, fmt.Errorf("no backend instance found")
}
// getBackendAndBrokeredClient holds a read lock and returns the backend and
// client connection
func (b *backendGRPCPluginServer) getBackendAndBrokeredClient(ctx context.Context) (logical.Backend, *grpc.ClientConn, error) {
b.instancesLock.RLock()
defer b.instancesLock.RUnlock()
return b.getBackendAndBrokeredClientInternal(ctx)
}
// Setup dials into the plugin's broker to get a shimmed storage, logger, and
// system view of the backend. This method also instantiates the underlying
// backend through its factory func for the server side of the plugin.
func (b *backendGRPCPluginServer) Setup(ctx context.Context, args *pb.SetupArgs) (*pb.SetupReply, error) {
var err error
id := singleImplementationID
if b.multiplexingSupport {
id, err = pluginutil.GetMultiplexIDFromContext(ctx)
if err != nil {
return &pb.SetupReply{}, err
}
}
// Dial for storage
brokeredClient, err := b.broker.Dial(args.BrokerID)
if err != nil {
return &pb.SetupReply{}, err
}
b.brokeredClient = brokeredClient
storage := newGRPCStorageClient(brokeredClient)
sysView := newGRPCSystemView(brokeredClient)
@ -56,12 +108,20 @@ func (b *backendGRPCPluginServer) Setup(ctx context.Context, args *pb.SetupArgs)
Err: pb.ErrToString(err),
}, nil
}
b.backend = backend
b.instances[id] = backendInstance{
brokeredClient: brokeredClient,
backend: backend,
}
return &pb.SetupReply{}, nil
}
func (b *backendGRPCPluginServer) HandleRequest(ctx context.Context, args *pb.HandleRequestArgs) (*pb.HandleRequestReply, error) {
backend, brokeredClient, err := b.getBackendAndBrokeredClient(ctx)
if err != nil {
return &pb.HandleRequestReply{}, err
}
if pluginutil.InMetadataMode() {
return &pb.HandleRequestReply{}, ErrServerInMetadataMode
}
@ -71,9 +131,9 @@ func (b *backendGRPCPluginServer) HandleRequest(ctx context.Context, args *pb.Ha
return &pb.HandleRequestReply{}, err
}
logicalReq.Storage = newGRPCStorageClient(b.brokeredClient)
logicalReq.Storage = newGRPCStorageClient(brokeredClient)
resp, respErr := b.backend.HandleRequest(ctx, logicalReq)
resp, respErr := backend.HandleRequest(ctx, logicalReq)
pbResp, err := pb.LogicalResponseToProtoResponse(resp)
if err != nil {
@ -87,15 +147,20 @@ func (b *backendGRPCPluginServer) HandleRequest(ctx context.Context, args *pb.Ha
}
func (b *backendGRPCPluginServer) Initialize(ctx context.Context, _ *pb.InitializeArgs) (*pb.InitializeReply, error) {
backend, brokeredClient, err := b.getBackendAndBrokeredClient(ctx)
if err != nil {
return &pb.InitializeReply{}, err
}
if pluginutil.InMetadataMode() {
return &pb.InitializeReply{}, ErrServerInMetadataMode
}
req := &logical.InitializationRequest{
Storage: newGRPCStorageClient(b.brokeredClient),
Storage: newGRPCStorageClient(brokeredClient),
}
respErr := b.backend.Initialize(ctx, req)
respErr := backend.Initialize(ctx, req)
return &pb.InitializeReply{
Err: pb.ErrToProtoErr(respErr),
@ -103,7 +168,12 @@ func (b *backendGRPCPluginServer) Initialize(ctx context.Context, _ *pb.Initiali
}
func (b *backendGRPCPluginServer) SpecialPaths(ctx context.Context, args *pb.Empty) (*pb.SpecialPathsReply, error) {
paths := b.backend.SpecialPaths()
backend, _, err := b.getBackendAndBrokeredClient(ctx)
if err != nil {
return &pb.SpecialPathsReply{}, err
}
paths := backend.SpecialPaths()
if paths == nil {
return &pb.SpecialPathsReply{
Paths: nil,
@ -121,6 +191,11 @@ func (b *backendGRPCPluginServer) SpecialPaths(ctx context.Context, args *pb.Emp
}
func (b *backendGRPCPluginServer) HandleExistenceCheck(ctx context.Context, args *pb.HandleExistenceCheckArgs) (*pb.HandleExistenceCheckReply, error) {
backend, brokeredClient, err := b.getBackendAndBrokeredClient(ctx)
if err != nil {
return &pb.HandleExistenceCheckReply{}, err
}
if pluginutil.InMetadataMode() {
return &pb.HandleExistenceCheckReply{}, ErrServerInMetadataMode
}
@ -129,9 +204,10 @@ func (b *backendGRPCPluginServer) HandleExistenceCheck(ctx context.Context, args
if err != nil {
return &pb.HandleExistenceCheckReply{}, err
}
logicalReq.Storage = newGRPCStorageClient(b.brokeredClient)
checkFound, exists, err := b.backend.HandleExistenceCheck(ctx, logicalReq)
logicalReq.Storage = newGRPCStorageClient(brokeredClient)
checkFound, exists, err := backend.HandleExistenceCheck(ctx, logicalReq)
return &pb.HandleExistenceCheckReply{
CheckFound: checkFound,
Exists: exists,
@ -140,24 +216,53 @@ func (b *backendGRPCPluginServer) HandleExistenceCheck(ctx context.Context, args
}
func (b *backendGRPCPluginServer) Cleanup(ctx context.Context, _ *pb.Empty) (*pb.Empty, error) {
b.backend.Cleanup(ctx)
b.instancesLock.Lock()
defer b.instancesLock.Unlock()
backend, brokeredClient, err := b.getBackendAndBrokeredClientInternal(ctx)
if err != nil {
return &pb.Empty{}, err
}
backend.Cleanup(ctx)
// Close rpc clients
b.brokeredClient.Close()
brokeredClient.Close()
if b.multiplexingSupport {
id, err := pluginutil.GetMultiplexIDFromContext(ctx)
if err != nil {
return nil, err
}
delete(b.instances, id)
} else if _, ok := b.instances[singleImplementationID]; ok {
delete(b.instances, singleImplementationID)
}
return &pb.Empty{}, nil
}
func (b *backendGRPCPluginServer) InvalidateKey(ctx context.Context, args *pb.InvalidateKeyArgs) (*pb.Empty, error) {
backend, _, err := b.getBackendAndBrokeredClient(ctx)
if err != nil {
return &pb.Empty{}, err
}
if pluginutil.InMetadataMode() {
return &pb.Empty{}, ErrServerInMetadataMode
}
b.backend.InvalidateKey(ctx, args.Key)
backend.InvalidateKey(ctx, args.Key)
return &pb.Empty{}, nil
}
func (b *backendGRPCPluginServer) Type(ctx context.Context, _ *pb.Empty) (*pb.TypeReply, error) {
backend, _, err := b.getBackendAndBrokeredClient(ctx)
if err != nil {
return &pb.TypeReply{}, err
}
return &pb.TypeReply{
Type: uint32(b.backend.Type()),
Type: uint32(backend.Type()),
}, nil
}

View File

@ -6,12 +6,12 @@ import (
"testing"
"time"
"github.com/golang/protobuf/proto"
plugin "github.com/hashicorp/go-plugin"
"github.com/hashicorp/vault/sdk/helper/consts"
"github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/sdk/plugin/pb"
"google.golang.org/grpc"
"google.golang.org/protobuf/proto"
)
func TestSystem_GRPC_GRPC_impl(t *testing.T) {

View File

@ -222,7 +222,6 @@ func (l *deprecatedLoggerClient) Error(msg string, args ...interface{}) error {
func (l *deprecatedLoggerClient) Fatal(msg string, args ...interface{}) {
// NOOP since it's not actually used within vault
return
}
func (l *deprecatedLoggerClient) Log(level int, msg string, args []interface{}) {

View File

@ -10,16 +10,16 @@ import (
// backendPluginClient implements logical.Backend and is the
// go-plugin client.
type backendTracingMiddleware struct {
type BackendTracingMiddleware struct {
logger log.Logger
next logical.Backend
}
// Validate the backendTracingMiddle object satisfies the backend interface
var _ logical.Backend = &backendTracingMiddleware{}
var _ logical.Backend = &BackendTracingMiddleware{}
func (b *backendTracingMiddleware) Initialize(ctx context.Context, req *logical.InitializationRequest) (err error) {
func (b *BackendTracingMiddleware) Initialize(ctx context.Context, req *logical.InitializationRequest) (err error) {
defer func(then time.Time) {
b.logger.Trace("initialize", "status", "finished", "err", err, "took", time.Since(then))
}(time.Now())
@ -28,7 +28,7 @@ func (b *backendTracingMiddleware) Initialize(ctx context.Context, req *logical.
return b.next.Initialize(ctx, req)
}
func (b *backendTracingMiddleware) HandleRequest(ctx context.Context, req *logical.Request) (resp *logical.Response, err error) {
func (b *BackendTracingMiddleware) HandleRequest(ctx context.Context, req *logical.Request) (resp *logical.Response, err error) {
defer func(then time.Time) {
b.logger.Trace("handle request", "path", req.Path, "status", "finished", "err", err, "took", time.Since(then))
}(time.Now())
@ -37,7 +37,7 @@ func (b *backendTracingMiddleware) HandleRequest(ctx context.Context, req *logic
return b.next.HandleRequest(ctx, req)
}
func (b *backendTracingMiddleware) SpecialPaths() *logical.Paths {
func (b *BackendTracingMiddleware) SpecialPaths() *logical.Paths {
defer func(then time.Time) {
b.logger.Trace("special paths", "status", "finished", "took", time.Since(then))
}(time.Now())
@ -46,15 +46,15 @@ func (b *backendTracingMiddleware) SpecialPaths() *logical.Paths {
return b.next.SpecialPaths()
}
func (b *backendTracingMiddleware) System() logical.SystemView {
func (b *BackendTracingMiddleware) System() logical.SystemView {
return b.next.System()
}
func (b *backendTracingMiddleware) Logger() log.Logger {
func (b *BackendTracingMiddleware) Logger() log.Logger {
return b.next.Logger()
}
func (b *backendTracingMiddleware) HandleExistenceCheck(ctx context.Context, req *logical.Request) (found bool, exists bool, err error) {
func (b *BackendTracingMiddleware) HandleExistenceCheck(ctx context.Context, req *logical.Request) (found bool, exists bool, err error) {
defer func(then time.Time) {
b.logger.Trace("handle existence check", "path", req.Path, "status", "finished", "err", err, "took", time.Since(then))
}(time.Now())
@ -63,7 +63,7 @@ func (b *backendTracingMiddleware) HandleExistenceCheck(ctx context.Context, req
return b.next.HandleExistenceCheck(ctx, req)
}
func (b *backendTracingMiddleware) Cleanup(ctx context.Context) {
func (b *BackendTracingMiddleware) Cleanup(ctx context.Context) {
defer func(then time.Time) {
b.logger.Trace("cleanup", "status", "finished", "took", time.Since(then))
}(time.Now())
@ -72,7 +72,7 @@ func (b *backendTracingMiddleware) Cleanup(ctx context.Context) {
b.next.Cleanup(ctx)
}
func (b *backendTracingMiddleware) InvalidateKey(ctx context.Context, key string) {
func (b *BackendTracingMiddleware) InvalidateKey(ctx context.Context, key string) {
defer func(then time.Time) {
b.logger.Trace("invalidate key", "key", key, "status", "finished", "took", time.Since(then))
}(time.Now())
@ -81,7 +81,7 @@ func (b *backendTracingMiddleware) InvalidateKey(ctx context.Context, key string
b.next.InvalidateKey(ctx, key)
}
func (b *backendTracingMiddleware) Setup(ctx context.Context, config *logical.BackendConfig) (err error) {
func (b *BackendTracingMiddleware) Setup(ctx context.Context, config *logical.BackendConfig) (err error) {
defer func(then time.Time) {
b.logger.Trace("setup", "status", "finished", "err", err, "took", time.Since(then))
}(time.Now())
@ -90,7 +90,7 @@ func (b *backendTracingMiddleware) Setup(ctx context.Context, config *logical.Ba
return b.next.Setup(ctx, config)
}
func (b *backendTracingMiddleware) Type() logical.BackendType {
func (b *BackendTracingMiddleware) Type() logical.BackendType {
defer func(then time.Time) {
b.logger.Trace("type", "status", "finished", "took", time.Since(then))
}(time.Now())

View File

@ -4,9 +4,7 @@ import (
"context"
"errors"
"fmt"
"sync"
"github.com/hashicorp/errwrap"
log "github.com/hashicorp/go-hclog"
plugin "github.com/hashicorp/go-plugin"
"github.com/hashicorp/vault/sdk/helper/consts"
@ -19,7 +17,6 @@ import (
// used to cleanly kill the client on Cleanup()
type BackendPluginClient struct {
client *plugin.Client
sync.Mutex
logical.Backend
}
@ -48,7 +45,7 @@ func NewBackend(ctx context.Context, pluginName string, pluginType consts.Plugin
// from the pluginRunner. Then cast it to logical.Factory.
rawFactory, err := pluginRunner.BuiltinFactory()
if err != nil {
return nil, errwrap.Wrapf("error getting plugin type: {{err}}", err)
return nil, fmt.Errorf("error getting plugin type: %q", err)
}
if factory, ok := rawFactory.(logical.Factory); !ok {
@ -93,9 +90,9 @@ func NewPluginClient(ctx context.Context, sys pluginutil.RunnerUtil, pluginRunne
var client *plugin.Client
var err error
if isMetadataMode {
client, err = pluginRunner.RunMetadataMode(ctx, sys, pluginSet, handshakeConfig, []string{}, namedLogger)
client, err = pluginRunner.RunMetadataMode(ctx, sys, pluginSet, HandshakeConfig, []string{}, namedLogger)
} else {
client, err = pluginRunner.Run(ctx, sys, pluginSet, handshakeConfig, []string{}, namedLogger)
client, err = pluginRunner.Run(ctx, sys, pluginSet, HandshakeConfig, []string{}, namedLogger)
}
if err != nil {
return nil, err
@ -117,9 +114,9 @@ func NewPluginClient(ctx context.Context, sys pluginutil.RunnerUtil, pluginRunne
var transport string
// We should have a logical backend type now. This feels like a normal interface
// implementation but is in fact over an RPC connection.
switch raw.(type) {
switch b := raw.(type) {
case *backendGRPCPluginClient:
backend = raw.(*backendGRPCPluginClient)
backend = b
transport = "gRPC"
default:
return nil, errors.New("unsupported plugin client type")
@ -127,7 +124,7 @@ func NewPluginClient(ctx context.Context, sys pluginutil.RunnerUtil, pluginRunne
// Wrap the backend in a tracing middleware
if namedLogger.IsTrace() {
backend = &backendTracingMiddleware{
backend = &BackendTracingMiddleware{
logger: namedLogger.With("transport", transport),
next: backend,
}
@ -138,22 +135,3 @@ func NewPluginClient(ctx context.Context, sys pluginutil.RunnerUtil, pluginRunne
Backend: backend,
}, nil
}
// wrapError takes a generic error type and makes it usable with the plugin
// interface. Only errors which have exported fields and have been registered
// with gob can be unwrapped and transported. This checks error types and, if
// none match, wrap the error in a plugin.BasicError.
func wrapError(err error) error {
if err == nil {
return nil
}
switch err.(type) {
case *plugin.BasicError,
logical.HTTPCodedError,
*logical.StatusBadRequest:
return err
}
return plugin.NewBasicError(err)
}

165
sdk/plugin/plugin_v5.go Normal file
View File

@ -0,0 +1,165 @@
package plugin
import (
"context"
"errors"
"fmt"
plugin "github.com/hashicorp/go-plugin"
"github.com/hashicorp/vault/sdk/helper/consts"
"github.com/hashicorp/vault/sdk/helper/pluginutil"
"github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/sdk/plugin/pb"
)
// BackendPluginClientV5 is a wrapper around backendPluginClient
// that also contains its plugin.Client instance. It's primarily
// used to cleanly kill the client on Cleanup()
type BackendPluginClientV5 struct {
client pluginutil.PluginClient
logical.Backend
}
type ContextKey string
func (c ContextKey) String() string {
return "plugin" + string(c)
}
const ContextKeyPluginReload = ContextKey("plugin-reload")
// Cleanup cleans up the go-plugin client and the plugin catalog
func (b *BackendPluginClientV5) Cleanup(ctx context.Context) {
_, ok := ctx.Value(ContextKeyPluginReload).(string)
if !ok {
b.Backend.Cleanup(ctx)
b.client.Close()
return
}
b.Backend.Cleanup(ctx)
b.client.Reload()
}
// NewBackendV5 will return an instance of an RPC-based client implementation of
// the backend for external plugins, or a concrete implementation of the
// backend if it is a builtin backend. The backend is returned as a
// logical.Backend interface.
func NewBackendV5(ctx context.Context, pluginName string, pluginType consts.PluginType, sys pluginutil.LookRunnerUtil, conf *logical.BackendConfig) (logical.Backend, error) {
// Look for plugin in the plugin catalog
pluginRunner, err := sys.LookupPlugin(ctx, pluginName, pluginType)
if err != nil {
return nil, err
}
var backend logical.Backend
if pluginRunner.Builtin {
// Plugin is builtin so we can retrieve an instance of the interface
// from the pluginRunner. Then cast it to logical.Factory.
rawFactory, err := pluginRunner.BuiltinFactory()
if err != nil {
return nil, fmt.Errorf("error getting plugin type: %q", err)
}
if factory, ok := rawFactory.(logical.Factory); !ok {
return nil, fmt.Errorf("unsupported backend type: %q", pluginName)
} else {
if backend, err = factory(ctx, conf); err != nil {
return nil, err
}
}
} else {
// create a backendPluginClient instance
config := pluginutil.PluginClientConfig{
Name: pluginName,
PluginSets: PluginSet,
PluginType: pluginType,
HandshakeConfig: HandshakeConfig,
Logger: conf.Logger.Named(pluginName),
AutoMTLS: true,
Wrapper: sys,
}
backend, err = NewPluginClientV5(ctx, sys, config)
if err != nil {
return nil, err
}
}
return backend, nil
}
// PluginSet is the map of plugins we can dispense.
var PluginSet = map[int]plugin.PluginSet{
5: {
"backend": &GRPCBackendPlugin{},
},
}
func Dispense(rpcClient plugin.ClientProtocol, pluginClient pluginutil.PluginClient) (logical.Backend, error) {
// Request the plugin
raw, err := rpcClient.Dispense("backend")
if err != nil {
return nil, err
}
var backend logical.Backend
// We should have a logical backend type now. This feels like a normal interface
// implementation but is in fact over an RPC connection.
switch c := raw.(type) {
case *backendGRPCPluginClient:
// This is an abstraction leak from go-plugin but it is necessary in
// order to enable multiplexing on multiplexed plugins
c.client = pb.NewBackendClient(pluginClient.Conn())
backend = c
default:
return nil, errors.New("unsupported plugin client type")
}
return &BackendPluginClientV5{
client: pluginClient,
Backend: backend,
}, nil
}
func NewPluginClientV5(ctx context.Context, sys pluginutil.RunnerUtil, config pluginutil.PluginClientConfig) (logical.Backend, error) {
pluginClient, err := sys.NewPluginClient(ctx, config)
if err != nil {
return nil, err
}
// Request the plugin
raw, err := pluginClient.Dispense("backend")
if err != nil {
return nil, err
}
var backend logical.Backend
var transport string
// We should have a logical backend type now. This feels like a normal interface
// implementation but is in fact over an RPC connection.
switch c := raw.(type) {
case *backendGRPCPluginClient:
// This is an abstraction leak from go-plugin but it is necessary in
// order to enable multiplexing on multiplexed plugins
c.client = pb.NewBackendClient(pluginClient.Conn())
backend = c
transport = "gRPC"
default:
return nil, errors.New("unsupported plugin client type")
}
// Wrap the backend in a tracing middleware
if config.Logger.IsTrace() {
backend = &BackendTracingMiddleware{
logger: config.Logger.With("transport", transport),
next: backend,
}
}
return &BackendPluginClientV5{
client: pluginClient,
Backend: backend,
}, nil
}

View File

@ -55,6 +55,13 @@ func Serve(opts *ServeOpts) error {
Logger: logger,
},
},
5: {
"backend": &GRPCBackendPlugin{
Factory: opts.BackendFactoryFunc,
MultiplexingSupport: false,
Logger: logger,
},
},
}
err := pluginutil.OptionallyEnableMlock()
@ -63,7 +70,7 @@ func Serve(opts *ServeOpts) error {
}
serveOpts := &plugin.ServeConfig{
HandshakeConfig: handshakeConfig,
HandshakeConfig: HandshakeConfig,
VersionedPlugins: pluginSets,
TLSProvider: opts.TLSProviderFunc,
Logger: logger,
@ -81,12 +88,77 @@ func Serve(opts *ServeOpts) error {
return nil
}
// ServeMultiplex is a helper function used to serve a backend plugin. This
// should be ran on the plugin's main process.
func ServeMultiplex(opts *ServeOpts) error {
logger := opts.Logger
if logger == nil {
logger = log.New(&log.LoggerOptions{
Level: log.Info,
Output: os.Stderr,
JSONFormat: true,
})
}
// pluginMap is the map of plugins we can dispense.
pluginSets := map[int]plugin.PluginSet{
// Version 3 used to supports both protocols. We want to keep it around
// since it's possible old plugins built against this version will still
// work with gRPC. There is currently no difference between version 3
// and version 4.
3: {
"backend": &GRPCBackendPlugin{
Factory: opts.BackendFactoryFunc,
Logger: logger,
},
},
4: {
"backend": &GRPCBackendPlugin{
Factory: opts.BackendFactoryFunc,
Logger: logger,
},
},
5: {
"backend": &GRPCBackendPlugin{
Factory: opts.BackendFactoryFunc,
MultiplexingSupport: true,
Logger: logger,
},
},
}
err := pluginutil.OptionallyEnableMlock()
if err != nil {
return err
}
serveOpts := &plugin.ServeConfig{
HandshakeConfig: HandshakeConfig,
VersionedPlugins: pluginSets,
Logger: logger,
// A non-nil value here enables gRPC serving for this plugin...
GRPCServer: func(opts []grpc.ServerOption) *grpc.Server {
opts = append(opts, grpc.MaxRecvMsgSize(math.MaxInt32))
opts = append(opts, grpc.MaxSendMsgSize(math.MaxInt32))
return plugin.DefaultGRPCServer(opts)
},
// TLSProvider is required to support v3 and v4 plugins.
// It will be ignored for v5 which uses AutoMTLS
TLSProvider: opts.TLSProviderFunc,
}
plugin.Serve(serveOpts)
return nil
}
// handshakeConfigs are used to just do a basic handshake between
// a plugin and host. If the handshake fails, a user friendly error is shown.
// This prevents users from executing bad plugins or executing a plugin
// directory. It is a UX feature, not a security feature.
var handshakeConfig = plugin.HandshakeConfig{
ProtocolVersion: 4,
var HandshakeConfig = plugin.HandshakeConfig{
MagicCookieKey: "VAULT_BACKEND_PLUGIN",
MagicCookieValue: "6669da05-b1c8-4f49-97d9-c8e5bed98e20",
}

View File

@ -7,7 +7,7 @@ import (
"strings"
"github.com/hashicorp/go-secure-stdlib/strutil"
"github.com/hashicorp/go-uuid"
uuid "github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/builtin/plugin"
"github.com/hashicorp/vault/helper/namespace"
"github.com/hashicorp/vault/sdk/helper/consts"
@ -924,7 +924,6 @@ func (c *Core) newCredentialBackend(ctx context.Context, entry *MountEntry, sysV
f = wrapFactoryCheckPerms(c, plugin.Factory)
}
}
// Set up conf to pass in plugin_name
conf := make(map[string]string)
for k, v := range entry.Options {

File diff suppressed because it is too large Load Diff

View File

@ -1432,7 +1432,6 @@ func (c *Core) newLogicalBackend(ctx context.Context, entry *MountEntry, sysView
f = wrapFactoryCheckPerms(c, plugin.Factory)
}
}
// Set up conf to pass in plugin_name
conf := make(map[string]string)
for k, v := range entry.Options {

View File

@ -30,10 +30,11 @@ import (
)
var (
pluginCatalogPath = "core/plugin-catalog/"
ErrDirectoryNotConfigured = errors.New("could not set plugin, plugin directory is not configured")
ErrPluginNotFound = errors.New("plugin not found in the catalog")
ErrPluginBadType = errors.New("unable to determine plugin type")
pluginCatalogPath = "core/plugin-catalog/"
ErrDirectoryNotConfigured = errors.New("could not set plugin, plugin directory is not configured")
ErrPluginNotFound = errors.New("plugin not found in the catalog")
ErrPluginConnectionNotFound = errors.New("plugin connection not found for client")
ErrPluginBadType = errors.New("unable to determine plugin type")
)
// PluginCatalog keeps a record of plugins known to vault. External plugins need
@ -78,13 +79,15 @@ type pluginClient struct {
logger log.Logger
// id is the connection ID
id string
id string
pid int
// client handles the lifecycle of a plugin process
// multiplexed plugins share the same client
client *plugin.Client
clientConn grpc.ClientConnInterface
cleanupFunc func() error
reloadFunc func() error
plugin.ClientProtocol
}
@ -148,6 +151,38 @@ func (p *pluginClient) Conn() grpc.ClientConnInterface {
return p.clientConn
}
func (p *pluginClient) Reload() error {
p.logger.Debug("reload external plugin process")
return p.reloadFunc()
}
// reloadExternalPlugin
// This should be called with the write lock held.
func (c *PluginCatalog) reloadExternalPlugin(name, id string) error {
extPlugin, ok := c.externalPlugins[name]
if !ok {
return fmt.Errorf("plugin client not found")
}
if !extPlugin.multiplexingSupport {
err := c.cleanupExternalPlugin(name, id)
if err != nil {
return err
}
return nil
}
pc, ok := extPlugin.connections[id]
if !ok {
return fmt.Errorf("%w id: %s", ErrPluginConnectionNotFound, id)
}
delete(c.externalPlugins, name)
pc.client.Kill()
c.logger.Debug("killed external plugin process for reload", "name", name, "pid", pc.pid)
return nil
}
// Close calls the plugin client's cleanupFunc to do any necessary cleanup on
// the plugin client and the PluginCatalog. This implements the
// plugin.ClientProtocol interface.
@ -167,19 +202,26 @@ func (c *PluginCatalog) cleanupExternalPlugin(name, id string) error {
pc, ok := extPlugin.connections[id]
if !ok {
return fmt.Errorf("plugin connection not found")
// this can happen if the backend is reloaded due to a plugin process
// being killed out of band
c.logger.Warn(ErrPluginConnectionNotFound.Error(), "id", id)
return fmt.Errorf("%w id: %s", ErrPluginConnectionNotFound, id)
}
delete(extPlugin.connections, id)
c.logger.Debug("removed plugin client connection", "id", id)
if !extPlugin.multiplexingSupport {
pc.client.Kill()
if len(extPlugin.connections) == 0 {
delete(c.externalPlugins, name)
}
c.logger.Debug("killed external plugin process", "name", name, "pid", pc.pid)
} else if len(extPlugin.connections) == 0 || pc.client.Exited() {
pc.client.Kill()
delete(c.externalPlugins, name)
c.logger.Debug("killed external multiplexed plugin process", "name", name, "pid", pc.pid)
}
return nil
@ -252,6 +294,11 @@ func (c *PluginCatalog) newPluginClient(ctx context.Context, pluginRunner *plugi
defer c.lock.Unlock()
return c.cleanupExternalPlugin(pluginRunner.Name, id)
},
reloadFunc: func() error {
c.lock.Lock()
defer c.lock.Unlock()
return c.reloadExternalPlugin(pluginRunner.Name, id)
},
}
// Multiplexing support will always be false initially, but will be
@ -264,9 +311,8 @@ func (c *PluginCatalog) newPluginClient(ctx context.Context, pluginRunner *plugi
pluginutil.Logger(config.Logger),
pluginutil.MetadataMode(config.IsMetadataMode),
pluginutil.MLock(c.mlockPlugins),
// NewPluginClient only supports AutoMTLS today
pluginutil.AutoMTLS(true),
pluginutil.AutoMTLS(config.AutoMTLS),
pluginutil.Runner(config.Wrapper),
)
if err != nil {
return nil, err
@ -294,6 +340,12 @@ func (c *PluginCatalog) newPluginClient(ctx context.Context, pluginRunner *plugi
return nil, err
}
// get the external plugin pid
conf := pc.client.ReattachConfig()
if conf != nil {
pc.pid = conf.Pid
}
clientConn := rpcClient.(*plugin.GRPCClient).Conn
muxed, err := pluginutil.MultiplexingSupported(ctx, clientConn)
@ -322,9 +374,8 @@ func (c *PluginCatalog) newPluginClient(ctx context.Context, pluginRunner *plugi
}
// getPluginTypeFromUnknown will attempt to run the plugin to determine the
// type and if it supports multiplexing. It will first attempt to run as a
// database plugin then a backend plugin. Both of these will be run in metadata
// mode.
// type. It will first attempt to run as a database plugin then a backend
// plugin.
func (c *PluginCatalog) getPluginTypeFromUnknown(ctx context.Context, logger log.Logger, plugin *pluginutil.PluginRunner) (consts.PluginType, error) {
merr := &multierror.Error{}
err := c.isDatabasePlugin(ctx, plugin)
@ -333,43 +384,95 @@ func (c *PluginCatalog) getPluginTypeFromUnknown(ctx context.Context, logger log
}
merr = multierror.Append(merr, err)
// Attempt to run as backend plugin
client, err := backendplugin.NewPluginClient(ctx, nil, plugin, log.NewNullLogger(), true)
pluginType, err := c.getBackendPluginType(ctx, plugin)
if err == nil {
err := client.Setup(ctx, &logical.BackendConfig{})
return pluginType, nil
}
merr = multierror.Append(merr, err)
return consts.PluginTypeUnknown, merr
}
// getBackendPluginType returns an error if the plugin is not a backend plugin.
func (c *PluginCatalog) getBackendPluginType(ctx context.Context, pluginRunner *pluginutil.PluginRunner) (consts.PluginType, error) {
merr := &multierror.Error{}
// Attempt to run as backend plugin
config := pluginutil.PluginClientConfig{
Name: pluginRunner.Name,
PluginSets: backendplugin.PluginSet,
HandshakeConfig: backendplugin.HandshakeConfig,
Logger: log.NewNullLogger(),
IsMetadataMode: false,
AutoMTLS: true,
}
var client logical.Backend
var attemptV4 bool
// First, attempt to run as backend V5 plugin
c.logger.Debug("attempting to load backend plugin", "name", pluginRunner.Name)
pc, err := c.newPluginClient(ctx, pluginRunner, config)
if err == nil {
// we spawned a subprocess, so make sure to clean it up
defer c.cleanupExternalPlugin(pluginRunner.Name, pc.id)
// dispense the plugin so we can get its type
client, err = backendplugin.Dispense(pc.ClientProtocol, pc)
if err != nil {
return consts.PluginTypeUnknown, err
}
backendType := client.Type()
client.Cleanup(ctx)
switch backendType {
case logical.TypeCredential:
return consts.PluginTypeCredential, nil
case logical.TypeLogical:
return consts.PluginTypeSecrets, nil
merr = multierror.Append(merr, fmt.Errorf("failed to dispense plugin as backend v5: %w", err))
c.logger.Debug("failed to dispense v5 backend plugin", "name", pluginRunner.Name)
attemptV4 = true
} else {
c.logger.Debug("successfully dispensed v5 backend plugin", "name", pluginRunner.Name)
}
} else {
merr = multierror.Append(merr, err)
attemptV4 = true
}
if attemptV4 {
c.logger.Debug("failed to dispense v5 backend plugin", "name", pluginRunner.Name, "error", err)
config.AutoMTLS = false
config.IsMetadataMode = true
// attempt to run as a v4 backend plugin
client, err = backendplugin.NewPluginClient(ctx, nil, pluginRunner, log.NewNullLogger(), true)
if err != nil {
c.logger.Debug("failed to dispense v4 backend plugin", "name", pluginRunner.Name, "error", err)
merr = multierror.Append(merr, fmt.Errorf("failed to dispense v4 backend plugin: %w", err))
return consts.PluginTypeUnknown, merr.ErrorOrNil()
}
c.logger.Debug("successfully dispensed v4 backend plugin", "name", pluginRunner.Name)
defer client.Cleanup(ctx)
}
err = client.Setup(ctx, &logical.BackendConfig{})
if err != nil {
return consts.PluginTypeUnknown, err
}
backendType := client.Type()
switch backendType {
case logical.TypeCredential:
return consts.PluginTypeCredential, nil
case logical.TypeLogical:
return consts.PluginTypeSecrets, nil
}
if client == nil || client.Type() == logical.TypeUnknown {
logger.Warn("unknown plugin type",
"plugin name", plugin.Name,
c.logger.Warn("unknown plugin type",
"plugin name", pluginRunner.Name,
"error", merr.Error())
} else {
logger.Warn("unsupported plugin type",
"plugin name", plugin.Name,
c.logger.Warn("unsupported plugin type",
"plugin name", pluginRunner.Name,
"plugin type", client.Type().String(),
"error", merr.Error())
}
return consts.PluginTypeUnknown, nil
merr = multierror.Append(merr, fmt.Errorf("failed to load plugin as backend plugin: %w", err))
return consts.PluginTypeUnknown, merr.ErrorOrNil()
}
// isDatabasePlugin returns true if the plugin supports multiplexing. An error
// is returned if the plugin is not a database plugin.
// isDatabasePlugin returns an error if the plugin is not a database plugin.
func (c *PluginCatalog) isDatabasePlugin(ctx context.Context, pluginRunner *pluginutil.PluginRunner) error {
merr := &multierror.Error{}
config := pluginutil.PluginClientConfig{

View File

@ -10,11 +10,13 @@ import (
"sort"
"testing"
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/builtin/credential/userpass"
"github.com/hashicorp/vault/plugins/database/postgresql"
v5 "github.com/hashicorp/vault/sdk/database/dbplugin/v5"
"github.com/hashicorp/vault/sdk/helper/consts"
"github.com/hashicorp/vault/sdk/helper/pluginutil"
backendplugin "github.com/hashicorp/vault/sdk/plugin"
"github.com/hashicorp/vault/helper/builtinplugins"
)
@ -59,7 +61,7 @@ func TestPluginCatalog_CRUD(t *testing.T) {
}
defer file.Close()
command := fmt.Sprintf("%s", filepath.Base(file.Name()))
command := filepath.Base(file.Name())
err = core.pluginCatalog.Set(context.Background(), pluginName, consts.PluginTypeDatabase, "", command, []string{"--test"}, []string{"FOO=BAR"}, []byte{'1'})
if err != nil {
t.Fatal(err)
@ -375,47 +377,109 @@ func TestPluginCatalog_NewPluginClient(t *testing.T) {
TestAddTestPlugin(t, core, "single-postgres-1", consts.PluginTypeUnknown, "TestPluginCatalog_PluginMain_Postgres", []string{}, "")
TestAddTestPlugin(t, core, "single-postgres-2", consts.PluginTypeUnknown, "TestPluginCatalog_PluginMain_Postgres", []string{}, "")
TestAddTestPlugin(t, core, "mux-userpass", consts.PluginTypeUnknown, "TestPluginCatalog_PluginMain_UserpassMultiplexed", []string{}, "")
TestAddTestPlugin(t, core, "single-userpass-1", consts.PluginTypeUnknown, "TestPluginCatalog_PluginMain_Userpass", []string{}, "")
TestAddTestPlugin(t, core, "single-userpass-2", consts.PluginTypeUnknown, "TestPluginCatalog_PluginMain_Userpass", []string{}, "")
var pluginClients []*pluginClient
// run plugins
if _, err := core.pluginCatalog.NewPluginClient(context.Background(), testPluginClientConfig("mux-postgres")); err != nil {
t.Fatal(err)
}
if _, err := core.pluginCatalog.NewPluginClient(context.Background(), testPluginClientConfig("mux-postgres")); err != nil {
t.Fatal(err)
}
if _, err := core.pluginCatalog.NewPluginClient(context.Background(), testPluginClientConfig("single-postgres-1")); err != nil {
t.Fatal(err)
}
if _, err := core.pluginCatalog.NewPluginClient(context.Background(), testPluginClientConfig("single-postgres-2")); err != nil {
t.Fatal(err)
}
// run "mux-postgres" twice which will start a single plugin for 2
// distinct connections
c := TestRunTestPlugin(t, core, consts.PluginTypeDatabase, "mux-postgres")
pluginClients = append(pluginClients, c)
c = TestRunTestPlugin(t, core, consts.PluginTypeDatabase, "mux-postgres")
pluginClients = append(pluginClients, c)
c = TestRunTestPlugin(t, core, consts.PluginTypeDatabase, "single-postgres-1")
pluginClients = append(pluginClients, c)
c = TestRunTestPlugin(t, core, consts.PluginTypeDatabase, "single-postgres-2")
pluginClients = append(pluginClients, c)
// run "mux-userpass" twice which will start a single plugin for 2
// distinct connections
c = TestRunTestPlugin(t, core, consts.PluginTypeCredential, "mux-userpass")
pluginClients = append(pluginClients, c)
c = TestRunTestPlugin(t, core, consts.PluginTypeCredential, "mux-userpass")
pluginClients = append(pluginClients, c)
c = TestRunTestPlugin(t, core, consts.PluginTypeCredential, "single-userpass-1")
pluginClients = append(pluginClients, c)
c = TestRunTestPlugin(t, core, consts.PluginTypeCredential, "single-userpass-2")
pluginClients = append(pluginClients, c)
externalPlugins := core.pluginCatalog.externalPlugins
if len(externalPlugins) != 3 {
t.Fatalf("expected externalPlugins map to be of len 3 but got %d", len(externalPlugins))
if len(externalPlugins) != 6 {
t.Fatalf("expected externalPlugins map to be of len 6 but got %d", len(externalPlugins))
}
// check connections map
expectedLen := 2
if len(externalPlugins["mux-postgres"].connections) != expectedLen {
t.Fatalf("expected multiplexed external plugin's connections map to be of len %d but got %d", expectedLen, len(externalPlugins["mux-postgres"].connections))
}
expectedLen = 1
if len(externalPlugins["single-postgres-1"].connections) != expectedLen {
t.Fatalf("expected multiplexed external plugin's connections map to be of len %d but got %d", expectedLen, len(externalPlugins["mux-postgres"].connections))
}
if len(externalPlugins["single-postgres-2"].connections) != expectedLen {
t.Fatalf("expected multiplexed external plugin's connections map to be of len %d but got %d", expectedLen, len(externalPlugins["mux-postgres"].connections))
}
expectConnectionLen(t, 2, externalPlugins["mux-postgres"].connections)
expectConnectionLen(t, 1, externalPlugins["single-postgres-1"].connections)
expectConnectionLen(t, 1, externalPlugins["single-postgres-2"].connections)
expectConnectionLen(t, 2, externalPlugins["mux-userpass"].connections)
expectConnectionLen(t, 1, externalPlugins["single-userpass-1"].connections)
expectConnectionLen(t, 1, externalPlugins["single-userpass-2"].connections)
// check multiplexing support
if !externalPlugins["mux-postgres"].multiplexingSupport {
t.Fatalf("expected external plugin to be multiplexed")
expectMultiplexingSupport(t, true, externalPlugins["mux-postgres"].multiplexingSupport)
expectMultiplexingSupport(t, false, externalPlugins["single-postgres-1"].multiplexingSupport)
expectMultiplexingSupport(t, false, externalPlugins["single-postgres-2"].multiplexingSupport)
expectMultiplexingSupport(t, true, externalPlugins["mux-userpass"].multiplexingSupport)
expectMultiplexingSupport(t, false, externalPlugins["single-userpass-1"].multiplexingSupport)
expectMultiplexingSupport(t, false, externalPlugins["single-userpass-2"].multiplexingSupport)
// cleanup all of the external plugin processes
for _, client := range pluginClients {
client.Close()
}
if externalPlugins["single-postgres-1"].multiplexingSupport {
t.Fatalf("expected external plugin to be non-multiplexed")
// check that externalPlugins map is cleaned up
if len(externalPlugins) != 0 {
t.Fatalf("expected external plugin map to be of len 0 but got %d", len(externalPlugins))
}
if externalPlugins["single-postgres-2"].multiplexingSupport {
t.Fatalf("expected external plugin to be non-multiplexed")
}
func TestPluginCatalog_PluginMain_Userpass(t *testing.T) {
if os.Getenv(pluginutil.PluginVaultVersionEnv) == "" {
return
}
apiClientMeta := &api.PluginAPIClientMeta{}
flags := apiClientMeta.FlagSet()
flags.Parse(os.Args[1:])
tlsConfig := apiClientMeta.GetTLSConfig()
tlsProviderFunc := api.VaultPluginTLSProvider(tlsConfig)
err := backendplugin.Serve(
&backendplugin.ServeOpts{
BackendFactoryFunc: userpass.Factory,
TLSProviderFunc: tlsProviderFunc,
},
)
if err != nil {
t.Fatalf("Failed to initialize userpass: %s", err)
}
}
func TestPluginCatalog_PluginMain_UserpassMultiplexed(t *testing.T) {
if os.Getenv(pluginutil.PluginVaultVersionEnv) == "" {
return
}
apiClientMeta := &api.PluginAPIClientMeta{}
flags := apiClientMeta.FlagSet()
flags.Parse(os.Args[1:])
tlsConfig := apiClientMeta.GetTLSConfig()
tlsProviderFunc := api.VaultPluginTLSProvider(tlsConfig)
err := backendplugin.ServeMultiplex(
&backendplugin.ServeOpts{
BackendFactoryFunc: userpass.Factory,
TLSProviderFunc: tlsProviderFunc,
},
)
if err != nil {
t.Fatalf("Failed to initialize userpass: %s", err)
}
}
@ -440,14 +504,16 @@ func TestPluginCatalog_PluginMain_PostgresMultiplexed(_ *testing.T) {
v5.ServeMultiplex(postgresql.New)
}
func testPluginClientConfig(pluginName string) pluginutil.PluginClientConfig {
return pluginutil.PluginClientConfig{
Name: pluginName,
PluginType: consts.PluginTypeDatabase,
PluginSets: v5.PluginSets,
HandshakeConfig: v5.HandshakeConfig,
Logger: log.NewNullLogger(),
IsMetadataMode: false,
AutoMTLS: true,
// expectConnectionLen asserts that the PluginCatalog's externalPlugin
// connections map has a length of expectedLen
func expectConnectionLen(t *testing.T, expectedLen int, connections map[string]*pluginClient) {
if len(connections) != expectedLen {
t.Fatalf("expected external plugin's connections map to be of len %d but got %d", expectedLen, len(connections))
}
}
func expectMultiplexingSupport(t *testing.T, expected, actual bool) {
if expected != actual {
t.Fatalf("expected external plugin multiplexing support to be %t", expected)
}
}

View File

@ -10,6 +10,7 @@ import (
multierror "github.com/hashicorp/go-multierror"
"github.com/hashicorp/go-secure-stdlib/strutil"
"github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/sdk/plugin"
)
// reloadMatchingPluginMounts reloads provided mounts, regardless of
@ -147,8 +148,11 @@ func (c *Core) reloadBackendCommon(ctx context.Context, entry *MountEntry, isAut
// Only call Cleanup if backend is initialized
if re.backend != nil {
// Pass a context value so that the plugin client will call the
// appropriate cleanup method for reloading
reloadCtx := context.WithValue(ctx, plugin.ContextKeyPluginReload, "reload")
// Call backend's Cleanup routine
re.backend.Cleanup(ctx)
re.backend.Cleanup(reloadCtx)
}
view := re.storageView

View File

@ -39,13 +39,16 @@ import (
"github.com/hashicorp/vault/helper/namespace"
"github.com/hashicorp/vault/internalshared/configutil"
dbMysql "github.com/hashicorp/vault/plugins/database/mysql"
v5 "github.com/hashicorp/vault/sdk/database/dbplugin/v5"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/consts"
"github.com/hashicorp/vault/sdk/helper/logging"
"github.com/hashicorp/vault/sdk/helper/pluginutil"
"github.com/hashicorp/vault/sdk/helper/salt"
"github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/sdk/physical"
physInmem "github.com/hashicorp/vault/sdk/physical/inmem"
backendplugin "github.com/hashicorp/vault/sdk/plugin"
"github.com/hashicorp/vault/vault/cluster"
"github.com/hashicorp/vault/vault/seal"
"github.com/mitchellh/copystructure"
@ -563,6 +566,48 @@ func TestAddTestPlugin(t testing.T, c *Core, name string, pluginType consts.Plug
}
}
// TestRunTestPlugin runs the testFunc which has already been registered to the
// plugin catalog and returns a pluginClient. This can be called after calling
// TestAddTestPlugin.
func TestRunTestPlugin(t testing.T, c *Core, pluginType consts.PluginType, pluginName string) *pluginClient {
t.Helper()
config := TestPluginClientConfig(c, pluginType, pluginName)
client, err := c.pluginCatalog.NewPluginClient(context.Background(), config)
if err != nil {
t.Fatal(err)
}
return client
}
func TestPluginClientConfig(c *Core, pluginType consts.PluginType, pluginName string) pluginutil.PluginClientConfig {
switch pluginType {
case consts.PluginTypeCredential, consts.PluginTypeSecrets:
dsv := TestDynamicSystemView(c, nil)
return pluginutil.PluginClientConfig{
Name: pluginName,
PluginType: pluginType,
PluginSets: backendplugin.PluginSet,
HandshakeConfig: backendplugin.HandshakeConfig,
Logger: log.NewNullLogger(),
AutoMTLS: true,
IsMetadataMode: false,
Wrapper: dsv,
}
case consts.PluginTypeDatabase:
return pluginutil.PluginClientConfig{
Name: pluginName,
PluginType: pluginType,
PluginSets: v5.PluginSets,
HandshakeConfig: v5.HandshakeConfig,
Logger: log.NewNullLogger(),
AutoMTLS: true,
IsMetadataMode: false,
}
}
return pluginutil.PluginClientConfig{}
}
var (
testLogicalBackends = map[string]logical.Factory{}
testCredentialBackends = map[string]logical.Factory{}

View File

@ -141,7 +141,7 @@ DONELISTHANDLING:
NamespaceID: ns.ID,
}
if err := c.tokenStore.create(ctx, &te); err != nil {
if err := c.CreateToken(ctx, &te); err != nil {
c.logger.Error("failed to create wrapping token", "error", err)
return nil, ErrInternalError
}