feature: secrets/auth plugin multiplexing (#14946)
* enable registering backend muxed plugins in plugin catalog * set the sysview on the pluginconfig to allow enabling secrets/auth plugins * store backend instances in map * store single implementations in the instances map cleanup instance map and ensure we don't deadlock * fix system backend unit tests move GetMultiplexIDFromContext to pluginutil package fix pluginutil test fix dbplugin ut * return error(s) if we can't get the plugin client update comments * refactor/move GetMultiplexIDFromContext test * add changelog * remove unnecessary field on pluginClient * add unit tests to PluginCatalog for secrets/auth plugins * fix comment * return pluginClient from TestRunTestPlugin * add multiplexed backend test * honor metadatamode value in newbackend pluginconfig * check that connection exists on cleanup * add automtls to secrets/auth plugins * don't remove apiclientmeta parsing * use formatting directive for fmt.Errorf * fix ut: remove tls provider func * remove tlsproviderfunc from backend plugin tests * use env var to prevent test plugin from running as a unit test * WIP: remove lazy loading * move non lazy loaded backend to new package * use version wrapper for backend plugin factory * remove backendVersionWrapper type * implement getBackendPluginType for plugin catalog * handle backend plugin v4 registration * add plugin automtls env guard * modify plugin factory to determine the backend to use * remove old pluginsets from v5 and log pid in plugin catalog * add reload mechanism via context * readd v3 and v4 to pluginset * call cleanup from reload if non-muxed * move v5 backend code to new package * use context reload for for ErrPluginShutdown case * add wrapper on v5 backend * fix run config UTs * fix unit tests - use v4/v5 mapping for plugin versions - fix test build err - add reload method on fakePluginClient - add multiplexed cases for integration tests * remove comment and update AutoMTLS field in test * remove comment * remove errwrap and unused context * only support metadatamode false for v5 backend plugins * update plugin catalog errors * use const for env variables * rename locks and remove unused * remove unneeded nil check * improvements based on staticcheck recommendations * use const for single implementation string * use const for context key * use info default log level * move pid to pluginClient struct * remove v3 and v4 from multiplexed plugin set * return from reload when non-multiplexed * update automtls env string * combine getBackend and getBrokeredClient * update comments for plugin reload, Backend return val and log * revert Backend return type * allow non-muxed plugins to serve v5 * move v5 code to existing sdk plugin package * do next export sdk fields now that we have removed extra plugin pkg * set TLSProvider in ServeMultiplex for backwards compat * use bool to flag multiplexing support on grpc backend server * revert userpass main.go * refactor plugin sdk - update comments - make use of multiplexing boolean and single implementation ID const * update comment and use multierr * attempt v4 if dispense fails on getPluginTypeForUnknown * update comments on sdk plugin backend
This commit is contained in:
parent
cf332842cc
commit
b6c05fae33
|
@ -16,7 +16,11 @@ import (
|
|||
"github.com/hashicorp/errwrap"
|
||||
)
|
||||
|
||||
var (
|
||||
const (
|
||||
// PluginAutoMTLSEnv is used to ensure AutoMTLS is used. This will override
|
||||
// setting a TLSProviderFunc for a plugin.
|
||||
PluginAutoMTLSEnv = "VAULT_PLUGIN_AUTOMTLS_ENABLED"
|
||||
|
||||
// PluginMetadataModeEnv is an ENV name used to disable TLS communication
|
||||
// to bootstrap mounting plugins.
|
||||
PluginMetadataModeEnv = "VAULT_PLUGIN_METADATA_MODE"
|
||||
|
@ -24,14 +28,15 @@ var (
|
|||
// PluginUnwrapTokenEnv is the ENV name used to pass unwrap tokens to the
|
||||
// plugin.
|
||||
PluginUnwrapTokenEnv = "VAULT_UNWRAP_TOKEN"
|
||||
)
|
||||
|
||||
// sudoPaths is a map containing the paths that require a token's policy
|
||||
// to have the "sudo" capability. The keys are the paths as strings, in
|
||||
// the same format as they are returned by the OpenAPI spec. The values
|
||||
// are the regular expressions that can be used to test whether a given
|
||||
// path matches that path or not (useful specifically for the paths that
|
||||
// contain templated fields.)
|
||||
sudoPaths = map[string]*regexp.Regexp{
|
||||
// sudoPaths is a map containing the paths that require a token's policy
|
||||
// to have the "sudo" capability. The keys are the paths as strings, in
|
||||
// the same format as they are returned by the OpenAPI spec. The values
|
||||
// are the regular expressions that can be used to test whether a given
|
||||
// path matches that path or not (useful specifically for the paths that
|
||||
// contain templated fields.)
|
||||
var sudoPaths = map[string]*regexp.Regexp{
|
||||
"/auth/token/accessors/": regexp.MustCompile(`^/auth/token/accessors/$`),
|
||||
"/pki/root": regexp.MustCompile(`^/pki/root$`),
|
||||
"/pki/root/sign-self-issued": regexp.MustCompile(`^/pki/root/sign-self-issued$`),
|
||||
|
@ -66,8 +71,7 @@ var (
|
|||
"/sys/replication/reindex": regexp.MustCompile(`^/sys/replication/reindex$`),
|
||||
"/sys/storage/raft/snapshot-auto/config/": regexp.MustCompile(`^/sys/storage/raft/snapshot-auto/config/$`),
|
||||
"/sys/storage/raft/snapshot-auto/config/{name}": regexp.MustCompile(`^/sys/storage/raft/snapshot-auto/config/[^/]+$`),
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
// PluginAPIClientMeta is a helper that plugins can use to configure TLS connections
|
||||
// back to Vault.
|
||||
|
@ -120,7 +124,7 @@ func VaultPluginTLSProvider(apiTLSConfig *TLSConfig) func() (*tls.Config, error)
|
|||
// VaultPluginTLSProviderContext is run inside a plugin and retrieves the response
|
||||
// wrapped TLS certificate from vault. It returns a configured TLS Config.
|
||||
func VaultPluginTLSProviderContext(ctx context.Context, apiTLSConfig *TLSConfig) func() (*tls.Config, error) {
|
||||
if os.Getenv(PluginMetadataModeEnv) == "true" {
|
||||
if os.Getenv(PluginAutoMTLSEnv) == "true" || os.Getenv(PluginMetadataModeEnv) == "true" {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -13,7 +13,6 @@ func main() {
|
|||
apiClientMeta := &api.PluginAPIClientMeta{}
|
||||
flags := apiClientMeta.FlagSet()
|
||||
flags.Parse(os.Args[1:])
|
||||
|
||||
tlsConfig := apiClientMeta.GetTLSConfig()
|
||||
tlsProviderFunc := api.VaultPluginTLSProvider(tlsConfig)
|
||||
|
||||
|
|
|
@ -9,7 +9,9 @@ import (
|
|||
|
||||
log "github.com/hashicorp/go-hclog"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
uuid "github.com/hashicorp/go-uuid"
|
||||
v5 "github.com/hashicorp/vault/builtin/plugin/v5"
|
||||
"github.com/hashicorp/vault/sdk/framework"
|
||||
"github.com/hashicorp/vault/sdk/helper/consts"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
|
@ -23,17 +25,29 @@ var (
|
|||
|
||||
// Factory returns a configured plugin logical.Backend.
|
||||
func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
|
||||
merr := &multierror.Error{}
|
||||
_, ok := conf.Config["plugin_name"]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("plugin_name not provided")
|
||||
}
|
||||
b, err := Backend(ctx, conf)
|
||||
if err != nil {
|
||||
b, err := v5.Backend(ctx, conf)
|
||||
if err == nil {
|
||||
if err := b.Setup(ctx, conf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return b, nil
|
||||
}
|
||||
merr = multierror.Append(merr, err)
|
||||
|
||||
b, err = Backend(ctx, conf)
|
||||
if err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
return nil, fmt.Errorf("invalid backend version: %s", merr)
|
||||
}
|
||||
|
||||
if err := b.Setup(ctx, conf); err != nil {
|
||||
return nil, err
|
||||
merr = multierror.Append(merr, err)
|
||||
return nil, merr.ErrorOrNil()
|
||||
}
|
||||
return b, nil
|
||||
}
|
||||
|
|
|
@ -24,23 +24,35 @@ func TestBackend_impl(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestBackend(t *testing.T) {
|
||||
config, cleanup := testConfig(t)
|
||||
pluginCmds := []string{"TestBackend_PluginMain", "TestBackend_PluginMain_Multiplexed"}
|
||||
|
||||
for _, pluginCmd := range pluginCmds {
|
||||
t.Run(pluginCmd, func(t *testing.T) {
|
||||
config, cleanup := testConfig(t, pluginCmd)
|
||||
defer cleanup()
|
||||
|
||||
_, err := plugin.Backend(context.Background(), config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackend_Factory(t *testing.T) {
|
||||
config, cleanup := testConfig(t)
|
||||
pluginCmds := []string{"TestBackend_PluginMain", "TestBackend_PluginMain_Multiplexed"}
|
||||
|
||||
for _, pluginCmd := range pluginCmds {
|
||||
t.Run(pluginCmd, func(t *testing.T) {
|
||||
config, cleanup := testConfig(t, pluginCmd)
|
||||
defer cleanup()
|
||||
|
||||
_, err := plugin.Factory(context.Background(), config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackend_PluginMain(t *testing.T) {
|
||||
|
@ -71,7 +83,35 @@ func TestBackend_PluginMain(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func testConfig(t *testing.T) (*logical.BackendConfig, func()) {
|
||||
func TestBackend_PluginMain_Multiplexed(t *testing.T) {
|
||||
args := []string{}
|
||||
if os.Getenv(pluginutil.PluginUnwrapTokenEnv) == "" && os.Getenv(pluginutil.PluginMetadataModeEnv) != "true" {
|
||||
return
|
||||
}
|
||||
|
||||
caPEM := os.Getenv(pluginutil.PluginCACertPEMEnv)
|
||||
if caPEM == "" {
|
||||
t.Fatal("CA cert not passed in")
|
||||
}
|
||||
|
||||
args = append(args, fmt.Sprintf("--ca-cert=%s", caPEM))
|
||||
|
||||
apiClientMeta := &api.PluginAPIClientMeta{}
|
||||
flags := apiClientMeta.FlagSet()
|
||||
flags.Parse(args)
|
||||
tlsConfig := apiClientMeta.GetTLSConfig()
|
||||
tlsProviderFunc := api.VaultPluginTLSProvider(tlsConfig)
|
||||
|
||||
err := logicalPlugin.ServeMultiplex(&logicalPlugin.ServeOpts{
|
||||
BackendFactoryFunc: mock.Factory,
|
||||
TLSProviderFunc: tlsProviderFunc,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func testConfig(t *testing.T, pluginCmd string) (*logical.BackendConfig, func()) {
|
||||
cluster := vault.NewTestCluster(t, nil, &vault.TestClusterOptions{
|
||||
HandlerFunc: vaulthttp.Handler,
|
||||
})
|
||||
|
@ -93,7 +133,7 @@ func testConfig(t *testing.T) (*logical.BackendConfig, func()) {
|
|||
|
||||
os.Setenv(pluginutil.PluginCACertPEMEnv, cluster.CACertPEMFile)
|
||||
|
||||
vault.TestAddTestPlugin(t, core.Core, "mock-plugin", consts.PluginTypeSecrets, "TestBackend_PluginMain", []string{}, "")
|
||||
vault.TestAddTestPlugin(t, core.Core, "mock-plugin", consts.PluginTypeSecrets, pluginCmd, []string{}, "")
|
||||
|
||||
return config, func() {
|
||||
cluster.Cleanup()
|
||||
|
|
|
@ -0,0 +1,147 @@
|
|||
package plugin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/rpc"
|
||||
"sync"
|
||||
|
||||
uuid "github.com/hashicorp/go-uuid"
|
||||
"github.com/hashicorp/vault/sdk/helper/consts"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
"github.com/hashicorp/vault/sdk/plugin"
|
||||
bplugin "github.com/hashicorp/vault/sdk/plugin"
|
||||
)
|
||||
|
||||
// Backend returns an instance of the backend, either as a plugin if external
|
||||
// or as a concrete implementation if builtin, casted as logical.Backend.
|
||||
func Backend(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
|
||||
var b backend
|
||||
name := conf.Config["plugin_name"]
|
||||
pluginType, err := consts.ParsePluginType(conf.Config["plugin_type"])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sys := conf.System
|
||||
|
||||
raw, err := plugin.NewBackendV5(ctx, name, pluginType, sys, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
b.Backend = raw
|
||||
b.config = conf
|
||||
|
||||
return &b, nil
|
||||
}
|
||||
|
||||
// backend is a thin wrapper around plugin.BackendPluginClientV5
|
||||
type backend struct {
|
||||
logical.Backend
|
||||
mu sync.RWMutex
|
||||
|
||||
config *logical.BackendConfig
|
||||
|
||||
// Used to detect if we already reloaded
|
||||
canary string
|
||||
}
|
||||
|
||||
func (b *backend) reloadBackend(ctx context.Context) error {
|
||||
pluginName := b.config.Config["plugin_name"]
|
||||
pluginType, err := consts.ParsePluginType(b.config.Config["plugin_type"])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
b.Logger().Debug("plugin: reloading plugin backend", "plugin", pluginName)
|
||||
|
||||
// Ensure proper cleanup of the backend
|
||||
// Pass a context value so that the plugin client will call the appropriate
|
||||
// cleanup method for reloading
|
||||
reloadCtx := context.WithValue(ctx, plugin.ContextKeyPluginReload, "reload")
|
||||
b.Backend.Cleanup(reloadCtx)
|
||||
|
||||
nb, err := plugin.NewBackendV5(ctx, pluginName, pluginType, b.config.System, b.config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = nb.Setup(ctx, b.config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
b.Backend = nb
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// HandleRequest is a thin wrapper implementation of HandleRequest that includes automatic plugin reload.
|
||||
func (b *backend) HandleRequest(ctx context.Context, req *logical.Request) (*logical.Response, error) {
|
||||
b.mu.RLock()
|
||||
canary := b.canary
|
||||
resp, err := b.Backend.HandleRequest(ctx, req)
|
||||
b.mu.RUnlock()
|
||||
// Need to compare string value for case were err comes from plugin RPC
|
||||
// and is returned as plugin.BasicError type.
|
||||
if err != nil &&
|
||||
(err.Error() == rpc.ErrShutdown.Error() || err == bplugin.ErrPluginShutdown) {
|
||||
// Reload plugin if it's an rpc.ErrShutdown
|
||||
b.mu.Lock()
|
||||
if b.canary == canary {
|
||||
err := b.reloadBackend(ctx)
|
||||
if err != nil {
|
||||
b.mu.Unlock()
|
||||
return nil, err
|
||||
}
|
||||
b.canary, err = uuid.GenerateUUID()
|
||||
if err != nil {
|
||||
b.mu.Unlock()
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
b.mu.Unlock()
|
||||
|
||||
// Try request once more
|
||||
b.mu.RLock()
|
||||
defer b.mu.RUnlock()
|
||||
return b.Backend.HandleRequest(ctx, req)
|
||||
}
|
||||
return resp, err
|
||||
}
|
||||
|
||||
// HandleExistenceCheck is a thin wrapper implementation of HandleRequest that includes automatic plugin reload.
|
||||
func (b *backend) HandleExistenceCheck(ctx context.Context, req *logical.Request) (bool, bool, error) {
|
||||
b.mu.RLock()
|
||||
canary := b.canary
|
||||
checkFound, exists, err := b.Backend.HandleExistenceCheck(ctx, req)
|
||||
b.mu.RUnlock()
|
||||
if err != nil &&
|
||||
(err.Error() == rpc.ErrShutdown.Error() || err == bplugin.ErrPluginShutdown) {
|
||||
// Reload plugin if it's an rpc.ErrShutdown
|
||||
b.mu.Lock()
|
||||
if b.canary == canary {
|
||||
err := b.reloadBackend(ctx)
|
||||
if err != nil {
|
||||
b.mu.Unlock()
|
||||
return false, false, err
|
||||
}
|
||||
b.canary, err = uuid.GenerateUUID()
|
||||
if err != nil {
|
||||
b.mu.Unlock()
|
||||
return false, false, err
|
||||
}
|
||||
}
|
||||
b.mu.Unlock()
|
||||
|
||||
// Try request once more
|
||||
b.mu.RLock()
|
||||
defer b.mu.RUnlock()
|
||||
return b.Backend.HandleExistenceCheck(ctx, req)
|
||||
}
|
||||
return checkFound, exists, err
|
||||
}
|
||||
|
||||
// InvalidateKey is a thin wrapper used to ensure we grab the lock for race purposes
|
||||
func (b *backend) InvalidateKey(ctx context.Context, key string) {
|
||||
b.mu.RLock()
|
||||
defer b.mu.RUnlock()
|
||||
b.Backend.InvalidateKey(ctx, key)
|
||||
}
|
|
@ -0,0 +1,3 @@
|
|||
```release-note:feature
|
||||
**Secrets/auth plugin multiplexing**: manage multiple plugin configurations with a single plugin process
|
||||
```
|
|
@ -81,14 +81,10 @@ func TestPlugin_PluginMain(t *testing.T) {
|
|||
flags := apiClientMeta.FlagSet()
|
||||
flags.Parse(args)
|
||||
|
||||
tlsConfig := apiClientMeta.GetTLSConfig()
|
||||
tlsProviderFunc := api.VaultPluginTLSProvider(tlsConfig)
|
||||
|
||||
factoryFunc := mock.FactoryType(logical.TypeLogical)
|
||||
|
||||
err := plugin.Serve(&plugin.ServeOpts{
|
||||
BackendFactoryFunc: factoryFunc,
|
||||
TLSProviderFunc: tlsProviderFunc,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
|
|
@ -10,7 +10,6 @@ import (
|
|||
"github.com/hashicorp/vault/sdk/database/dbplugin/v5/proto"
|
||||
"github.com/hashicorp/vault/sdk/helper/pluginutil"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/metadata"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
|
@ -30,25 +29,6 @@ type gRPCServer struct {
|
|||
sync.RWMutex
|
||||
}
|
||||
|
||||
func getMultiplexIDFromContext(ctx context.Context) (string, error) {
|
||||
md, ok := metadata.FromIncomingContext(ctx)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("missing plugin multiplexing metadata")
|
||||
}
|
||||
|
||||
multiplexIDs := md[pluginutil.MultiplexingCtxKey]
|
||||
if len(multiplexIDs) != 1 {
|
||||
return "", fmt.Errorf("unexpected number of IDs in metadata: (%d)", len(multiplexIDs))
|
||||
}
|
||||
|
||||
multiplexID := multiplexIDs[0]
|
||||
if multiplexID == "" {
|
||||
return "", fmt.Errorf("empty multiplex ID in metadata")
|
||||
}
|
||||
|
||||
return multiplexID, nil
|
||||
}
|
||||
|
||||
func (g *gRPCServer) getOrCreateDatabase(ctx context.Context) (Database, error) {
|
||||
g.Lock()
|
||||
defer g.Unlock()
|
||||
|
@ -57,7 +37,7 @@ func (g *gRPCServer) getOrCreateDatabase(ctx context.Context) (Database, error)
|
|||
return g.singleImpl, nil
|
||||
}
|
||||
|
||||
id, err := getMultiplexIDFromContext(ctx)
|
||||
id, err := pluginutil.GetMultiplexIDFromContext(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -83,7 +63,7 @@ func (g *gRPCServer) getDatabaseInternal(ctx context.Context) (Database, error)
|
|||
return g.singleImpl, nil
|
||||
}
|
||||
|
||||
id, err := getMultiplexIDFromContext(ctx)
|
||||
id, err := pluginutil.GetMultiplexIDFromContext(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -312,7 +292,7 @@ func (g *gRPCServer) Close(ctx context.Context, _ *proto.Empty) (*proto.Empty, e
|
|||
|
||||
if g.singleImpl == nil {
|
||||
// only cleanup instances map when multiplexing is supported
|
||||
id, err := getMultiplexIDFromContext(ctx)
|
||||
id, err := pluginutil.GetMultiplexIDFromContext(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -581,57 +581,6 @@ func TestGRPCServer_Close(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestGetMultiplexIDFromContext(t *testing.T) {
|
||||
type testCase struct {
|
||||
ctx context.Context
|
||||
expectedResp string
|
||||
expectedErr error
|
||||
}
|
||||
|
||||
tests := map[string]testCase{
|
||||
"missing plugin multiplexing metadata": {
|
||||
ctx: context.Background(),
|
||||
expectedResp: "",
|
||||
expectedErr: fmt.Errorf("missing plugin multiplexing metadata"),
|
||||
},
|
||||
"unexpected number of IDs in metadata": {
|
||||
ctx: idCtx(t, "12345", "67891"),
|
||||
expectedResp: "",
|
||||
expectedErr: fmt.Errorf("unexpected number of IDs in metadata: (2)"),
|
||||
},
|
||||
"empty multiplex ID in metadata": {
|
||||
ctx: idCtx(t, ""),
|
||||
expectedResp: "",
|
||||
expectedErr: fmt.Errorf("empty multiplex ID in metadata"),
|
||||
},
|
||||
"happy path, id is returned from metadata": {
|
||||
ctx: idCtx(t, "12345"),
|
||||
expectedResp: "12345",
|
||||
expectedErr: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for name, test := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
resp, err := getMultiplexIDFromContext(test.ctx)
|
||||
|
||||
if test.expectedErr != nil && test.expectedErr.Error() != "" && err == nil {
|
||||
t.Fatalf("err expected, got nil")
|
||||
} else if !reflect.DeepEqual(err, test.expectedErr) {
|
||||
t.Fatalf("Actual error: %#v\nExpected error: %#v", err, test.expectedErr)
|
||||
}
|
||||
|
||||
if test.expectedErr != nil && test.expectedErr.Error() == "" && err != nil {
|
||||
t.Fatalf("no error expected, got: %s", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(resp, test.expectedResp) {
|
||||
t.Fatalf("Actual response: %#v\nExpected response: %#v", resp, test.expectedResp)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// testGrpcServer is a test helper that returns a context with an ID set in its
|
||||
// metadata and a gRPCServer instance for a multiplexed plugin
|
||||
func testGrpcServer(t *testing.T, db Database) (context.Context, gRPCServer) {
|
||||
|
|
|
@ -112,6 +112,10 @@ func (f *fakePluginClient) Conn() grpc.ClientConnInterface {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (f *fakePluginClient) Reload() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakePluginClient) Dispense(name string) (interface{}, error) {
|
||||
return f.dispenseResp, f.dispenseErr
|
||||
}
|
||||
|
|
|
@ -7,7 +7,11 @@ import (
|
|||
version "github.com/hashicorp/go-version"
|
||||
)
|
||||
|
||||
var (
|
||||
const (
|
||||
// PluginAutoMTLSEnv is used to ensure AutoMTLS is used. This will override
|
||||
// setting a TLSProviderFunc for a plugin.
|
||||
PluginAutoMTLSEnv = "VAULT_PLUGIN_AUTOMTLS_ENABLED"
|
||||
|
||||
// PluginMlockEnabled is the ENV name used to pass the configuration for
|
||||
// enabling mlock
|
||||
PluginMlockEnabled = "VAULT_PLUGIN_MLOCK_ENABLED"
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
|
||||
grpc "google.golang.org/grpc"
|
||||
codes "google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/metadata"
|
||||
status "google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
|
@ -45,3 +46,22 @@ func MultiplexingSupported(ctx context.Context, cc grpc.ClientConnInterface) (bo
|
|||
|
||||
return resp.Supported, nil
|
||||
}
|
||||
|
||||
func GetMultiplexIDFromContext(ctx context.Context) (string, error) {
|
||||
md, ok := metadata.FromIncomingContext(ctx)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("missing plugin multiplexing metadata")
|
||||
}
|
||||
|
||||
multiplexIDs := md[MultiplexingCtxKey]
|
||||
if len(multiplexIDs) != 1 {
|
||||
return "", fmt.Errorf("unexpected number of IDs in metadata: (%d)", len(multiplexIDs))
|
||||
}
|
||||
|
||||
multiplexID := multiplexIDs[0]
|
||||
if multiplexID == "" {
|
||||
return "", fmt.Errorf("empty multiplex ID in metadata")
|
||||
}
|
||||
|
||||
return multiplexID, nil
|
||||
}
|
||||
|
|
|
@ -0,0 +1,73 @@
|
|||
package pluginutil
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"google.golang.org/grpc/metadata"
|
||||
)
|
||||
|
||||
func TestGetMultiplexIDFromContext(t *testing.T) {
|
||||
type testCase struct {
|
||||
ctx context.Context
|
||||
expectedResp string
|
||||
expectedErr error
|
||||
}
|
||||
|
||||
tests := map[string]testCase{
|
||||
"missing plugin multiplexing metadata": {
|
||||
ctx: context.Background(),
|
||||
expectedResp: "",
|
||||
expectedErr: fmt.Errorf("missing plugin multiplexing metadata"),
|
||||
},
|
||||
"unexpected number of IDs in metadata": {
|
||||
ctx: idCtx(t, "12345", "67891"),
|
||||
expectedResp: "",
|
||||
expectedErr: fmt.Errorf("unexpected number of IDs in metadata: (2)"),
|
||||
},
|
||||
"empty multiplex ID in metadata": {
|
||||
ctx: idCtx(t, ""),
|
||||
expectedResp: "",
|
||||
expectedErr: fmt.Errorf("empty multiplex ID in metadata"),
|
||||
},
|
||||
"happy path, id is returned from metadata": {
|
||||
ctx: idCtx(t, "12345"),
|
||||
expectedResp: "12345",
|
||||
expectedErr: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for name, test := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
resp, err := GetMultiplexIDFromContext(test.ctx)
|
||||
|
||||
if test.expectedErr != nil && test.expectedErr.Error() != "" && err == nil {
|
||||
t.Fatalf("err expected, got nil")
|
||||
} else if !reflect.DeepEqual(err, test.expectedErr) {
|
||||
t.Fatalf("Actual error: %#v\nExpected error: %#v", err, test.expectedErr)
|
||||
}
|
||||
|
||||
if test.expectedErr != nil && test.expectedErr.Error() == "" && err != nil {
|
||||
t.Fatalf("no error expected, got: %s", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(resp, test.expectedResp) {
|
||||
t.Fatalf("Actual response: %#v\nExpected response: %#v", resp, test.expectedResp)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// idCtx is a test helper that will return a context with the IDs set in its
|
||||
// metadata
|
||||
func idCtx(t *testing.T, ids ...string) context.Context {
|
||||
// Context doesn't need to timeout since this is just passed through
|
||||
ctx := context.Background()
|
||||
md := metadata.MD{}
|
||||
for _, id := range ids {
|
||||
md.Append(MultiplexingCtxKey, id)
|
||||
}
|
||||
return metadata.NewIncomingContext(ctx, md)
|
||||
}
|
|
@ -23,6 +23,7 @@ type PluginClientConfig struct {
|
|||
IsMetadataMode bool
|
||||
AutoMTLS bool
|
||||
MLock bool
|
||||
Wrapper RunnerUtil
|
||||
}
|
||||
|
||||
type runConfig struct {
|
||||
|
@ -34,8 +35,6 @@ type runConfig struct {
|
|||
// Initialized with what's in PluginRunner.Env, but can be added to
|
||||
env []string
|
||||
|
||||
wrapper RunnerUtil
|
||||
|
||||
PluginClientConfig
|
||||
}
|
||||
|
||||
|
@ -44,7 +43,7 @@ func (rc runConfig) makeConfig(ctx context.Context) (*plugin.ClientConfig, error
|
|||
cmd.Env = append(cmd.Env, rc.env...)
|
||||
|
||||
// Add the mlock setting to the ENV of the plugin
|
||||
if rc.MLock || (rc.wrapper != nil && rc.wrapper.MlockEnabled()) {
|
||||
if rc.MLock || (rc.Wrapper != nil && rc.Wrapper.MlockEnabled()) {
|
||||
cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", PluginMlockEnabled, "true"))
|
||||
}
|
||||
cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", PluginVaultVersionEnv, version.GetVersion().Version))
|
||||
|
@ -55,6 +54,9 @@ func (rc runConfig) makeConfig(ctx context.Context) (*plugin.ClientConfig, error
|
|||
metadataEnv := fmt.Sprintf("%s=%t", PluginMetadataModeEnv, rc.IsMetadataMode)
|
||||
cmd.Env = append(cmd.Env, metadataEnv)
|
||||
|
||||
automtlsEnv := fmt.Sprintf("%s=%t", PluginAutoMTLSEnv, rc.AutoMTLS)
|
||||
cmd.Env = append(cmd.Env, automtlsEnv)
|
||||
|
||||
var clientTLSConfig *tls.Config
|
||||
if !rc.AutoMTLS && !rc.IsMetadataMode {
|
||||
// Get a CA TLS Certificate
|
||||
|
@ -71,7 +73,7 @@ func (rc runConfig) makeConfig(ctx context.Context) (*plugin.ClientConfig, error
|
|||
|
||||
// Use CA to sign a server cert and wrap the values in a response wrapped
|
||||
// token.
|
||||
wrapToken, err := wrapServerConfig(ctx, rc.wrapper, certBytes, key)
|
||||
wrapToken, err := wrapServerConfig(ctx, rc.Wrapper, certBytes, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -121,7 +123,7 @@ func Env(env ...string) RunOpt {
|
|||
|
||||
func Runner(wrapper RunnerUtil) RunOpt {
|
||||
return func(rc *runConfig) {
|
||||
rc.wrapper = wrapper
|
||||
rc.Wrapper = wrapper
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -4,7 +4,6 @@ import (
|
|||
"context"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -14,6 +13,7 @@ import (
|
|||
"github.com/hashicorp/go-plugin"
|
||||
"github.com/hashicorp/vault/sdk/helper/wrapping"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMakeConfig(t *testing.T) {
|
||||
|
@ -78,6 +78,7 @@ func TestMakeConfig(t *testing.T) {
|
|||
"initial=true",
|
||||
fmt.Sprintf("%s=%s", PluginVaultVersionEnv, version.GetVersion().Version),
|
||||
fmt.Sprintf("%s=%t", PluginMetadataModeEnv, true),
|
||||
fmt.Sprintf("%s=%t", PluginAutoMTLSEnv, false),
|
||||
},
|
||||
),
|
||||
SecureConfig: &plugin.SecureConfig{
|
||||
|
@ -143,6 +144,7 @@ func TestMakeConfig(t *testing.T) {
|
|||
fmt.Sprintf("%s=%t", PluginMlockEnabled, true),
|
||||
fmt.Sprintf("%s=%s", PluginVaultVersionEnv, version.GetVersion().Version),
|
||||
fmt.Sprintf("%s=%t", PluginMetadataModeEnv, false),
|
||||
fmt.Sprintf("%s=%t", PluginAutoMTLSEnv, false),
|
||||
fmt.Sprintf("%s=%s", PluginUnwrapTokenEnv, "testtoken"),
|
||||
},
|
||||
),
|
||||
|
@ -205,6 +207,7 @@ func TestMakeConfig(t *testing.T) {
|
|||
"initial=true",
|
||||
fmt.Sprintf("%s=%s", PluginVaultVersionEnv, version.GetVersion().Version),
|
||||
fmt.Sprintf("%s=%t", PluginMetadataModeEnv, true),
|
||||
fmt.Sprintf("%s=%t", PluginAutoMTLSEnv, true),
|
||||
},
|
||||
),
|
||||
SecureConfig: &plugin.SecureConfig{
|
||||
|
@ -266,6 +269,7 @@ func TestMakeConfig(t *testing.T) {
|
|||
"initial=true",
|
||||
fmt.Sprintf("%s=%s", PluginVaultVersionEnv, version.GetVersion().Version),
|
||||
fmt.Sprintf("%s=%t", PluginMetadataModeEnv, false),
|
||||
fmt.Sprintf("%s=%t", PluginAutoMTLSEnv, true),
|
||||
},
|
||||
),
|
||||
SecureConfig: &plugin.SecureConfig{
|
||||
|
@ -290,7 +294,7 @@ func TestMakeConfig(t *testing.T) {
|
|||
Return(test.responseWrapInfo, test.responseWrapInfoErr)
|
||||
mockWrapper.On("MlockEnabled").
|
||||
Return(test.mlockEnabled)
|
||||
test.rc.wrapper = mockWrapper
|
||||
test.rc.Wrapper = mockWrapper
|
||||
defer mockWrapper.AssertNumberOfCalls(t, "ResponseWrapData", test.responseWrapInfoTimes)
|
||||
defer mockWrapper.AssertNumberOfCalls(t, "MlockEnabled", test.mlockEnabledTimes)
|
||||
|
||||
|
@ -318,9 +322,7 @@ func TestMakeConfig(t *testing.T) {
|
|||
}
|
||||
config.TLSConfig = nil
|
||||
|
||||
if !reflect.DeepEqual(config, test.expectedConfig) {
|
||||
t.Fatalf("Actual config: %#v\nExpected config: %#v", config, test.expectedConfig)
|
||||
}
|
||||
require.Equal(t, test.expectedConfig, config)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -36,6 +36,7 @@ type LookRunnerUtil interface {
|
|||
|
||||
type PluginClient interface {
|
||||
Conn() grpc.ClientConnInterface
|
||||
Reload() error
|
||||
plugin.ClientProtocol
|
||||
}
|
||||
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
|
||||
log "github.com/hashicorp/go-hclog"
|
||||
plugin "github.com/hashicorp/go-plugin"
|
||||
"github.com/hashicorp/vault/sdk/helper/pluginutil"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
"github.com/hashicorp/vault/sdk/plugin/pb"
|
||||
)
|
||||
|
@ -24,18 +25,32 @@ type GRPCBackendPlugin struct {
|
|||
MetadataMode bool
|
||||
Logger log.Logger
|
||||
|
||||
MultiplexingSupport bool
|
||||
|
||||
// Embeding this will disable the netRPC protocol
|
||||
plugin.NetRPCUnsupportedPlugin
|
||||
}
|
||||
|
||||
func (b GRPCBackendPlugin) GRPCServer(broker *plugin.GRPCBroker, s *grpc.Server) error {
|
||||
pb.RegisterBackendServer(s, &backendGRPCPluginServer{
|
||||
server := backendGRPCPluginServer{
|
||||
broker: broker,
|
||||
factory: b.Factory,
|
||||
// We pass the logger down into the backend so go-plugin will forward
|
||||
// logs for us.
|
||||
instances: make(map[string]backendInstance),
|
||||
// We pass the logger down into the backend so go-plugin will
|
||||
// forward logs for us.
|
||||
logger: b.Logger,
|
||||
}
|
||||
|
||||
if b.MultiplexingSupport {
|
||||
// Multiplexing is enabled for this plugin, register the server so we
|
||||
// can tell the client in Vault.
|
||||
pluginutil.RegisterPluginMultiplexingServer(s, pluginutil.PluginMultiplexingServerImpl{
|
||||
Supported: true,
|
||||
})
|
||||
server.multiplexingSupport = true
|
||||
}
|
||||
|
||||
pb.RegisterBackendServer(s, &server)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -3,6 +3,8 @@ package plugin
|
|||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
log "github.com/hashicorp/go-hclog"
|
||||
plugin "github.com/hashicorp/go-plugin"
|
||||
|
@ -14,29 +16,79 @@ import (
|
|||
|
||||
var ErrServerInMetadataMode = errors.New("plugin server can not perform action while in metadata mode")
|
||||
|
||||
// singleImplementationID is the string used to define the instance ID of a
|
||||
// non-multiplexed plugin
|
||||
const singleImplementationID string = "single"
|
||||
|
||||
type backendInstance struct {
|
||||
brokeredClient *grpc.ClientConn
|
||||
backend logical.Backend
|
||||
}
|
||||
|
||||
type backendGRPCPluginServer struct {
|
||||
pb.UnimplementedBackendServer
|
||||
|
||||
broker *plugin.GRPCBroker
|
||||
backend logical.Backend
|
||||
|
||||
instances map[string]backendInstance
|
||||
instancesLock sync.RWMutex
|
||||
multiplexingSupport bool
|
||||
|
||||
factory logical.Factory
|
||||
|
||||
brokeredClient *grpc.ClientConn
|
||||
|
||||
logger log.Logger
|
||||
}
|
||||
|
||||
// getBackendAndBrokeredClientInternal returns the backend and client
|
||||
// connection but does not hold a lock
|
||||
func (b *backendGRPCPluginServer) getBackendAndBrokeredClientInternal(ctx context.Context) (logical.Backend, *grpc.ClientConn, error) {
|
||||
if b.multiplexingSupport {
|
||||
id, err := pluginutil.GetMultiplexIDFromContext(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if inst, ok := b.instances[id]; ok {
|
||||
return inst.backend, inst.brokeredClient, nil
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
if singleImpl, ok := b.instances[singleImplementationID]; ok {
|
||||
return singleImpl.backend, singleImpl.brokeredClient, nil
|
||||
}
|
||||
|
||||
return nil, nil, fmt.Errorf("no backend instance found")
|
||||
}
|
||||
|
||||
// getBackendAndBrokeredClient holds a read lock and returns the backend and
|
||||
// client connection
|
||||
func (b *backendGRPCPluginServer) getBackendAndBrokeredClient(ctx context.Context) (logical.Backend, *grpc.ClientConn, error) {
|
||||
b.instancesLock.RLock()
|
||||
defer b.instancesLock.RUnlock()
|
||||
return b.getBackendAndBrokeredClientInternal(ctx)
|
||||
}
|
||||
|
||||
// Setup dials into the plugin's broker to get a shimmed storage, logger, and
|
||||
// system view of the backend. This method also instantiates the underlying
|
||||
// backend through its factory func for the server side of the plugin.
|
||||
func (b *backendGRPCPluginServer) Setup(ctx context.Context, args *pb.SetupArgs) (*pb.SetupReply, error) {
|
||||
var err error
|
||||
id := singleImplementationID
|
||||
|
||||
if b.multiplexingSupport {
|
||||
id, err = pluginutil.GetMultiplexIDFromContext(ctx)
|
||||
if err != nil {
|
||||
return &pb.SetupReply{}, err
|
||||
}
|
||||
}
|
||||
|
||||
// Dial for storage
|
||||
brokeredClient, err := b.broker.Dial(args.BrokerID)
|
||||
if err != nil {
|
||||
return &pb.SetupReply{}, err
|
||||
}
|
||||
b.brokeredClient = brokeredClient
|
||||
|
||||
storage := newGRPCStorageClient(brokeredClient)
|
||||
sysView := newGRPCSystemView(brokeredClient)
|
||||
|
||||
|
@ -56,12 +108,20 @@ func (b *backendGRPCPluginServer) Setup(ctx context.Context, args *pb.SetupArgs)
|
|||
Err: pb.ErrToString(err),
|
||||
}, nil
|
||||
}
|
||||
b.backend = backend
|
||||
b.instances[id] = backendInstance{
|
||||
brokeredClient: brokeredClient,
|
||||
backend: backend,
|
||||
}
|
||||
|
||||
return &pb.SetupReply{}, nil
|
||||
}
|
||||
|
||||
func (b *backendGRPCPluginServer) HandleRequest(ctx context.Context, args *pb.HandleRequestArgs) (*pb.HandleRequestReply, error) {
|
||||
backend, brokeredClient, err := b.getBackendAndBrokeredClient(ctx)
|
||||
if err != nil {
|
||||
return &pb.HandleRequestReply{}, err
|
||||
}
|
||||
|
||||
if pluginutil.InMetadataMode() {
|
||||
return &pb.HandleRequestReply{}, ErrServerInMetadataMode
|
||||
}
|
||||
|
@ -71,9 +131,9 @@ func (b *backendGRPCPluginServer) HandleRequest(ctx context.Context, args *pb.Ha
|
|||
return &pb.HandleRequestReply{}, err
|
||||
}
|
||||
|
||||
logicalReq.Storage = newGRPCStorageClient(b.brokeredClient)
|
||||
logicalReq.Storage = newGRPCStorageClient(brokeredClient)
|
||||
|
||||
resp, respErr := b.backend.HandleRequest(ctx, logicalReq)
|
||||
resp, respErr := backend.HandleRequest(ctx, logicalReq)
|
||||
|
||||
pbResp, err := pb.LogicalResponseToProtoResponse(resp)
|
||||
if err != nil {
|
||||
|
@ -87,15 +147,20 @@ func (b *backendGRPCPluginServer) HandleRequest(ctx context.Context, args *pb.Ha
|
|||
}
|
||||
|
||||
func (b *backendGRPCPluginServer) Initialize(ctx context.Context, _ *pb.InitializeArgs) (*pb.InitializeReply, error) {
|
||||
backend, brokeredClient, err := b.getBackendAndBrokeredClient(ctx)
|
||||
if err != nil {
|
||||
return &pb.InitializeReply{}, err
|
||||
}
|
||||
|
||||
if pluginutil.InMetadataMode() {
|
||||
return &pb.InitializeReply{}, ErrServerInMetadataMode
|
||||
}
|
||||
|
||||
req := &logical.InitializationRequest{
|
||||
Storage: newGRPCStorageClient(b.brokeredClient),
|
||||
Storage: newGRPCStorageClient(brokeredClient),
|
||||
}
|
||||
|
||||
respErr := b.backend.Initialize(ctx, req)
|
||||
respErr := backend.Initialize(ctx, req)
|
||||
|
||||
return &pb.InitializeReply{
|
||||
Err: pb.ErrToProtoErr(respErr),
|
||||
|
@ -103,7 +168,12 @@ func (b *backendGRPCPluginServer) Initialize(ctx context.Context, _ *pb.Initiali
|
|||
}
|
||||
|
||||
func (b *backendGRPCPluginServer) SpecialPaths(ctx context.Context, args *pb.Empty) (*pb.SpecialPathsReply, error) {
|
||||
paths := b.backend.SpecialPaths()
|
||||
backend, _, err := b.getBackendAndBrokeredClient(ctx)
|
||||
if err != nil {
|
||||
return &pb.SpecialPathsReply{}, err
|
||||
}
|
||||
|
||||
paths := backend.SpecialPaths()
|
||||
if paths == nil {
|
||||
return &pb.SpecialPathsReply{
|
||||
Paths: nil,
|
||||
|
@ -121,6 +191,11 @@ func (b *backendGRPCPluginServer) SpecialPaths(ctx context.Context, args *pb.Emp
|
|||
}
|
||||
|
||||
func (b *backendGRPCPluginServer) HandleExistenceCheck(ctx context.Context, args *pb.HandleExistenceCheckArgs) (*pb.HandleExistenceCheckReply, error) {
|
||||
backend, brokeredClient, err := b.getBackendAndBrokeredClient(ctx)
|
||||
if err != nil {
|
||||
return &pb.HandleExistenceCheckReply{}, err
|
||||
}
|
||||
|
||||
if pluginutil.InMetadataMode() {
|
||||
return &pb.HandleExistenceCheckReply{}, ErrServerInMetadataMode
|
||||
}
|
||||
|
@ -129,9 +204,10 @@ func (b *backendGRPCPluginServer) HandleExistenceCheck(ctx context.Context, args
|
|||
if err != nil {
|
||||
return &pb.HandleExistenceCheckReply{}, err
|
||||
}
|
||||
logicalReq.Storage = newGRPCStorageClient(b.brokeredClient)
|
||||
|
||||
checkFound, exists, err := b.backend.HandleExistenceCheck(ctx, logicalReq)
|
||||
logicalReq.Storage = newGRPCStorageClient(brokeredClient)
|
||||
|
||||
checkFound, exists, err := backend.HandleExistenceCheck(ctx, logicalReq)
|
||||
return &pb.HandleExistenceCheckReply{
|
||||
CheckFound: checkFound,
|
||||
Exists: exists,
|
||||
|
@ -140,24 +216,53 @@ func (b *backendGRPCPluginServer) HandleExistenceCheck(ctx context.Context, args
|
|||
}
|
||||
|
||||
func (b *backendGRPCPluginServer) Cleanup(ctx context.Context, _ *pb.Empty) (*pb.Empty, error) {
|
||||
b.backend.Cleanup(ctx)
|
||||
b.instancesLock.Lock()
|
||||
defer b.instancesLock.Unlock()
|
||||
|
||||
backend, brokeredClient, err := b.getBackendAndBrokeredClientInternal(ctx)
|
||||
if err != nil {
|
||||
return &pb.Empty{}, err
|
||||
}
|
||||
|
||||
backend.Cleanup(ctx)
|
||||
|
||||
// Close rpc clients
|
||||
b.brokeredClient.Close()
|
||||
brokeredClient.Close()
|
||||
|
||||
if b.multiplexingSupport {
|
||||
id, err := pluginutil.GetMultiplexIDFromContext(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
delete(b.instances, id)
|
||||
} else if _, ok := b.instances[singleImplementationID]; ok {
|
||||
delete(b.instances, singleImplementationID)
|
||||
}
|
||||
|
||||
return &pb.Empty{}, nil
|
||||
}
|
||||
|
||||
func (b *backendGRPCPluginServer) InvalidateKey(ctx context.Context, args *pb.InvalidateKeyArgs) (*pb.Empty, error) {
|
||||
backend, _, err := b.getBackendAndBrokeredClient(ctx)
|
||||
if err != nil {
|
||||
return &pb.Empty{}, err
|
||||
}
|
||||
|
||||
if pluginutil.InMetadataMode() {
|
||||
return &pb.Empty{}, ErrServerInMetadataMode
|
||||
}
|
||||
|
||||
b.backend.InvalidateKey(ctx, args.Key)
|
||||
backend.InvalidateKey(ctx, args.Key)
|
||||
return &pb.Empty{}, nil
|
||||
}
|
||||
|
||||
func (b *backendGRPCPluginServer) Type(ctx context.Context, _ *pb.Empty) (*pb.TypeReply, error) {
|
||||
backend, _, err := b.getBackendAndBrokeredClient(ctx)
|
||||
if err != nil {
|
||||
return &pb.TypeReply{}, err
|
||||
}
|
||||
|
||||
return &pb.TypeReply{
|
||||
Type: uint32(b.backend.Type()),
|
||||
Type: uint32(backend.Type()),
|
||||
}, nil
|
||||
}
|
||||
|
|
|
@ -6,12 +6,12 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
plugin "github.com/hashicorp/go-plugin"
|
||||
"github.com/hashicorp/vault/sdk/helper/consts"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
"github.com/hashicorp/vault/sdk/plugin/pb"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
|
||||
func TestSystem_GRPC_GRPC_impl(t *testing.T) {
|
||||
|
|
|
@ -222,7 +222,6 @@ func (l *deprecatedLoggerClient) Error(msg string, args ...interface{}) error {
|
|||
|
||||
func (l *deprecatedLoggerClient) Fatal(msg string, args ...interface{}) {
|
||||
// NOOP since it's not actually used within vault
|
||||
return
|
||||
}
|
||||
|
||||
func (l *deprecatedLoggerClient) Log(level int, msg string, args []interface{}) {
|
||||
|
|
|
@ -10,16 +10,16 @@ import (
|
|||
|
||||
// backendPluginClient implements logical.Backend and is the
|
||||
// go-plugin client.
|
||||
type backendTracingMiddleware struct {
|
||||
type BackendTracingMiddleware struct {
|
||||
logger log.Logger
|
||||
|
||||
next logical.Backend
|
||||
}
|
||||
|
||||
// Validate the backendTracingMiddle object satisfies the backend interface
|
||||
var _ logical.Backend = &backendTracingMiddleware{}
|
||||
var _ logical.Backend = &BackendTracingMiddleware{}
|
||||
|
||||
func (b *backendTracingMiddleware) Initialize(ctx context.Context, req *logical.InitializationRequest) (err error) {
|
||||
func (b *BackendTracingMiddleware) Initialize(ctx context.Context, req *logical.InitializationRequest) (err error) {
|
||||
defer func(then time.Time) {
|
||||
b.logger.Trace("initialize", "status", "finished", "err", err, "took", time.Since(then))
|
||||
}(time.Now())
|
||||
|
@ -28,7 +28,7 @@ func (b *backendTracingMiddleware) Initialize(ctx context.Context, req *logical.
|
|||
return b.next.Initialize(ctx, req)
|
||||
}
|
||||
|
||||
func (b *backendTracingMiddleware) HandleRequest(ctx context.Context, req *logical.Request) (resp *logical.Response, err error) {
|
||||
func (b *BackendTracingMiddleware) HandleRequest(ctx context.Context, req *logical.Request) (resp *logical.Response, err error) {
|
||||
defer func(then time.Time) {
|
||||
b.logger.Trace("handle request", "path", req.Path, "status", "finished", "err", err, "took", time.Since(then))
|
||||
}(time.Now())
|
||||
|
@ -37,7 +37,7 @@ func (b *backendTracingMiddleware) HandleRequest(ctx context.Context, req *logic
|
|||
return b.next.HandleRequest(ctx, req)
|
||||
}
|
||||
|
||||
func (b *backendTracingMiddleware) SpecialPaths() *logical.Paths {
|
||||
func (b *BackendTracingMiddleware) SpecialPaths() *logical.Paths {
|
||||
defer func(then time.Time) {
|
||||
b.logger.Trace("special paths", "status", "finished", "took", time.Since(then))
|
||||
}(time.Now())
|
||||
|
@ -46,15 +46,15 @@ func (b *backendTracingMiddleware) SpecialPaths() *logical.Paths {
|
|||
return b.next.SpecialPaths()
|
||||
}
|
||||
|
||||
func (b *backendTracingMiddleware) System() logical.SystemView {
|
||||
func (b *BackendTracingMiddleware) System() logical.SystemView {
|
||||
return b.next.System()
|
||||
}
|
||||
|
||||
func (b *backendTracingMiddleware) Logger() log.Logger {
|
||||
func (b *BackendTracingMiddleware) Logger() log.Logger {
|
||||
return b.next.Logger()
|
||||
}
|
||||
|
||||
func (b *backendTracingMiddleware) HandleExistenceCheck(ctx context.Context, req *logical.Request) (found bool, exists bool, err error) {
|
||||
func (b *BackendTracingMiddleware) HandleExistenceCheck(ctx context.Context, req *logical.Request) (found bool, exists bool, err error) {
|
||||
defer func(then time.Time) {
|
||||
b.logger.Trace("handle existence check", "path", req.Path, "status", "finished", "err", err, "took", time.Since(then))
|
||||
}(time.Now())
|
||||
|
@ -63,7 +63,7 @@ func (b *backendTracingMiddleware) HandleExistenceCheck(ctx context.Context, req
|
|||
return b.next.HandleExistenceCheck(ctx, req)
|
||||
}
|
||||
|
||||
func (b *backendTracingMiddleware) Cleanup(ctx context.Context) {
|
||||
func (b *BackendTracingMiddleware) Cleanup(ctx context.Context) {
|
||||
defer func(then time.Time) {
|
||||
b.logger.Trace("cleanup", "status", "finished", "took", time.Since(then))
|
||||
}(time.Now())
|
||||
|
@ -72,7 +72,7 @@ func (b *backendTracingMiddleware) Cleanup(ctx context.Context) {
|
|||
b.next.Cleanup(ctx)
|
||||
}
|
||||
|
||||
func (b *backendTracingMiddleware) InvalidateKey(ctx context.Context, key string) {
|
||||
func (b *BackendTracingMiddleware) InvalidateKey(ctx context.Context, key string) {
|
||||
defer func(then time.Time) {
|
||||
b.logger.Trace("invalidate key", "key", key, "status", "finished", "took", time.Since(then))
|
||||
}(time.Now())
|
||||
|
@ -81,7 +81,7 @@ func (b *backendTracingMiddleware) InvalidateKey(ctx context.Context, key string
|
|||
b.next.InvalidateKey(ctx, key)
|
||||
}
|
||||
|
||||
func (b *backendTracingMiddleware) Setup(ctx context.Context, config *logical.BackendConfig) (err error) {
|
||||
func (b *BackendTracingMiddleware) Setup(ctx context.Context, config *logical.BackendConfig) (err error) {
|
||||
defer func(then time.Time) {
|
||||
b.logger.Trace("setup", "status", "finished", "err", err, "took", time.Since(then))
|
||||
}(time.Now())
|
||||
|
@ -90,7 +90,7 @@ func (b *backendTracingMiddleware) Setup(ctx context.Context, config *logical.Ba
|
|||
return b.next.Setup(ctx, config)
|
||||
}
|
||||
|
||||
func (b *backendTracingMiddleware) Type() logical.BackendType {
|
||||
func (b *BackendTracingMiddleware) Type() logical.BackendType {
|
||||
defer func(then time.Time) {
|
||||
b.logger.Trace("type", "status", "finished", "took", time.Since(then))
|
||||
}(time.Now())
|
||||
|
|
|
@ -4,9 +4,7 @@ import (
|
|||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/hashicorp/errwrap"
|
||||
log "github.com/hashicorp/go-hclog"
|
||||
plugin "github.com/hashicorp/go-plugin"
|
||||
"github.com/hashicorp/vault/sdk/helper/consts"
|
||||
|
@ -19,7 +17,6 @@ import (
|
|||
// used to cleanly kill the client on Cleanup()
|
||||
type BackendPluginClient struct {
|
||||
client *plugin.Client
|
||||
sync.Mutex
|
||||
|
||||
logical.Backend
|
||||
}
|
||||
|
@ -48,7 +45,7 @@ func NewBackend(ctx context.Context, pluginName string, pluginType consts.Plugin
|
|||
// from the pluginRunner. Then cast it to logical.Factory.
|
||||
rawFactory, err := pluginRunner.BuiltinFactory()
|
||||
if err != nil {
|
||||
return nil, errwrap.Wrapf("error getting plugin type: {{err}}", err)
|
||||
return nil, fmt.Errorf("error getting plugin type: %q", err)
|
||||
}
|
||||
|
||||
if factory, ok := rawFactory.(logical.Factory); !ok {
|
||||
|
@ -93,9 +90,9 @@ func NewPluginClient(ctx context.Context, sys pluginutil.RunnerUtil, pluginRunne
|
|||
var client *plugin.Client
|
||||
var err error
|
||||
if isMetadataMode {
|
||||
client, err = pluginRunner.RunMetadataMode(ctx, sys, pluginSet, handshakeConfig, []string{}, namedLogger)
|
||||
client, err = pluginRunner.RunMetadataMode(ctx, sys, pluginSet, HandshakeConfig, []string{}, namedLogger)
|
||||
} else {
|
||||
client, err = pluginRunner.Run(ctx, sys, pluginSet, handshakeConfig, []string{}, namedLogger)
|
||||
client, err = pluginRunner.Run(ctx, sys, pluginSet, HandshakeConfig, []string{}, namedLogger)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -117,9 +114,9 @@ func NewPluginClient(ctx context.Context, sys pluginutil.RunnerUtil, pluginRunne
|
|||
var transport string
|
||||
// We should have a logical backend type now. This feels like a normal interface
|
||||
// implementation but is in fact over an RPC connection.
|
||||
switch raw.(type) {
|
||||
switch b := raw.(type) {
|
||||
case *backendGRPCPluginClient:
|
||||
backend = raw.(*backendGRPCPluginClient)
|
||||
backend = b
|
||||
transport = "gRPC"
|
||||
default:
|
||||
return nil, errors.New("unsupported plugin client type")
|
||||
|
@ -127,7 +124,7 @@ func NewPluginClient(ctx context.Context, sys pluginutil.RunnerUtil, pluginRunne
|
|||
|
||||
// Wrap the backend in a tracing middleware
|
||||
if namedLogger.IsTrace() {
|
||||
backend = &backendTracingMiddleware{
|
||||
backend = &BackendTracingMiddleware{
|
||||
logger: namedLogger.With("transport", transport),
|
||||
next: backend,
|
||||
}
|
||||
|
@ -138,22 +135,3 @@ func NewPluginClient(ctx context.Context, sys pluginutil.RunnerUtil, pluginRunne
|
|||
Backend: backend,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// wrapError takes a generic error type and makes it usable with the plugin
|
||||
// interface. Only errors which have exported fields and have been registered
|
||||
// with gob can be unwrapped and transported. This checks error types and, if
|
||||
// none match, wrap the error in a plugin.BasicError.
|
||||
func wrapError(err error) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch err.(type) {
|
||||
case *plugin.BasicError,
|
||||
logical.HTTPCodedError,
|
||||
*logical.StatusBadRequest:
|
||||
return err
|
||||
}
|
||||
|
||||
return plugin.NewBasicError(err)
|
||||
}
|
||||
|
|
|
@ -0,0 +1,165 @@
|
|||
package plugin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
plugin "github.com/hashicorp/go-plugin"
|
||||
"github.com/hashicorp/vault/sdk/helper/consts"
|
||||
"github.com/hashicorp/vault/sdk/helper/pluginutil"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
"github.com/hashicorp/vault/sdk/plugin/pb"
|
||||
)
|
||||
|
||||
// BackendPluginClientV5 is a wrapper around backendPluginClient
|
||||
// that also contains its plugin.Client instance. It's primarily
|
||||
// used to cleanly kill the client on Cleanup()
|
||||
type BackendPluginClientV5 struct {
|
||||
client pluginutil.PluginClient
|
||||
|
||||
logical.Backend
|
||||
}
|
||||
|
||||
type ContextKey string
|
||||
|
||||
func (c ContextKey) String() string {
|
||||
return "plugin" + string(c)
|
||||
}
|
||||
|
||||
const ContextKeyPluginReload = ContextKey("plugin-reload")
|
||||
|
||||
// Cleanup cleans up the go-plugin client and the plugin catalog
|
||||
func (b *BackendPluginClientV5) Cleanup(ctx context.Context) {
|
||||
_, ok := ctx.Value(ContextKeyPluginReload).(string)
|
||||
if !ok {
|
||||
b.Backend.Cleanup(ctx)
|
||||
b.client.Close()
|
||||
return
|
||||
}
|
||||
b.Backend.Cleanup(ctx)
|
||||
b.client.Reload()
|
||||
}
|
||||
|
||||
// NewBackendV5 will return an instance of an RPC-based client implementation of
|
||||
// the backend for external plugins, or a concrete implementation of the
|
||||
// backend if it is a builtin backend. The backend is returned as a
|
||||
// logical.Backend interface.
|
||||
func NewBackendV5(ctx context.Context, pluginName string, pluginType consts.PluginType, sys pluginutil.LookRunnerUtil, conf *logical.BackendConfig) (logical.Backend, error) {
|
||||
// Look for plugin in the plugin catalog
|
||||
pluginRunner, err := sys.LookupPlugin(ctx, pluginName, pluginType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var backend logical.Backend
|
||||
if pluginRunner.Builtin {
|
||||
// Plugin is builtin so we can retrieve an instance of the interface
|
||||
// from the pluginRunner. Then cast it to logical.Factory.
|
||||
rawFactory, err := pluginRunner.BuiltinFactory()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting plugin type: %q", err)
|
||||
}
|
||||
|
||||
if factory, ok := rawFactory.(logical.Factory); !ok {
|
||||
return nil, fmt.Errorf("unsupported backend type: %q", pluginName)
|
||||
} else {
|
||||
if backend, err = factory(ctx, conf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// create a backendPluginClient instance
|
||||
config := pluginutil.PluginClientConfig{
|
||||
Name: pluginName,
|
||||
PluginSets: PluginSet,
|
||||
PluginType: pluginType,
|
||||
HandshakeConfig: HandshakeConfig,
|
||||
Logger: conf.Logger.Named(pluginName),
|
||||
AutoMTLS: true,
|
||||
Wrapper: sys,
|
||||
}
|
||||
backend, err = NewPluginClientV5(ctx, sys, config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return backend, nil
|
||||
}
|
||||
|
||||
// PluginSet is the map of plugins we can dispense.
|
||||
var PluginSet = map[int]plugin.PluginSet{
|
||||
5: {
|
||||
"backend": &GRPCBackendPlugin{},
|
||||
},
|
||||
}
|
||||
|
||||
func Dispense(rpcClient plugin.ClientProtocol, pluginClient pluginutil.PluginClient) (logical.Backend, error) {
|
||||
// Request the plugin
|
||||
raw, err := rpcClient.Dispense("backend")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var backend logical.Backend
|
||||
// We should have a logical backend type now. This feels like a normal interface
|
||||
// implementation but is in fact over an RPC connection.
|
||||
switch c := raw.(type) {
|
||||
case *backendGRPCPluginClient:
|
||||
// This is an abstraction leak from go-plugin but it is necessary in
|
||||
// order to enable multiplexing on multiplexed plugins
|
||||
c.client = pb.NewBackendClient(pluginClient.Conn())
|
||||
|
||||
backend = c
|
||||
default:
|
||||
return nil, errors.New("unsupported plugin client type")
|
||||
}
|
||||
|
||||
return &BackendPluginClientV5{
|
||||
client: pluginClient,
|
||||
Backend: backend,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func NewPluginClientV5(ctx context.Context, sys pluginutil.RunnerUtil, config pluginutil.PluginClientConfig) (logical.Backend, error) {
|
||||
pluginClient, err := sys.NewPluginClient(ctx, config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Request the plugin
|
||||
raw, err := pluginClient.Dispense("backend")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var backend logical.Backend
|
||||
var transport string
|
||||
// We should have a logical backend type now. This feels like a normal interface
|
||||
// implementation but is in fact over an RPC connection.
|
||||
switch c := raw.(type) {
|
||||
case *backendGRPCPluginClient:
|
||||
// This is an abstraction leak from go-plugin but it is necessary in
|
||||
// order to enable multiplexing on multiplexed plugins
|
||||
c.client = pb.NewBackendClient(pluginClient.Conn())
|
||||
|
||||
backend = c
|
||||
transport = "gRPC"
|
||||
default:
|
||||
return nil, errors.New("unsupported plugin client type")
|
||||
}
|
||||
|
||||
// Wrap the backend in a tracing middleware
|
||||
if config.Logger.IsTrace() {
|
||||
backend = &BackendTracingMiddleware{
|
||||
logger: config.Logger.With("transport", transport),
|
||||
next: backend,
|
||||
}
|
||||
}
|
||||
|
||||
return &BackendPluginClientV5{
|
||||
client: pluginClient,
|
||||
Backend: backend,
|
||||
}, nil
|
||||
}
|
|
@ -55,6 +55,13 @@ func Serve(opts *ServeOpts) error {
|
|||
Logger: logger,
|
||||
},
|
||||
},
|
||||
5: {
|
||||
"backend": &GRPCBackendPlugin{
|
||||
Factory: opts.BackendFactoryFunc,
|
||||
MultiplexingSupport: false,
|
||||
Logger: logger,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err := pluginutil.OptionallyEnableMlock()
|
||||
|
@ -63,7 +70,7 @@ func Serve(opts *ServeOpts) error {
|
|||
}
|
||||
|
||||
serveOpts := &plugin.ServeConfig{
|
||||
HandshakeConfig: handshakeConfig,
|
||||
HandshakeConfig: HandshakeConfig,
|
||||
VersionedPlugins: pluginSets,
|
||||
TLSProvider: opts.TLSProviderFunc,
|
||||
Logger: logger,
|
||||
|
@ -81,12 +88,77 @@ func Serve(opts *ServeOpts) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// ServeMultiplex is a helper function used to serve a backend plugin. This
|
||||
// should be ran on the plugin's main process.
|
||||
func ServeMultiplex(opts *ServeOpts) error {
|
||||
logger := opts.Logger
|
||||
if logger == nil {
|
||||
logger = log.New(&log.LoggerOptions{
|
||||
Level: log.Info,
|
||||
Output: os.Stderr,
|
||||
JSONFormat: true,
|
||||
})
|
||||
}
|
||||
|
||||
// pluginMap is the map of plugins we can dispense.
|
||||
pluginSets := map[int]plugin.PluginSet{
|
||||
// Version 3 used to supports both protocols. We want to keep it around
|
||||
// since it's possible old plugins built against this version will still
|
||||
// work with gRPC. There is currently no difference between version 3
|
||||
// and version 4.
|
||||
3: {
|
||||
"backend": &GRPCBackendPlugin{
|
||||
Factory: opts.BackendFactoryFunc,
|
||||
Logger: logger,
|
||||
},
|
||||
},
|
||||
4: {
|
||||
"backend": &GRPCBackendPlugin{
|
||||
Factory: opts.BackendFactoryFunc,
|
||||
Logger: logger,
|
||||
},
|
||||
},
|
||||
5: {
|
||||
"backend": &GRPCBackendPlugin{
|
||||
Factory: opts.BackendFactoryFunc,
|
||||
MultiplexingSupport: true,
|
||||
Logger: logger,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err := pluginutil.OptionallyEnableMlock()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
serveOpts := &plugin.ServeConfig{
|
||||
HandshakeConfig: HandshakeConfig,
|
||||
VersionedPlugins: pluginSets,
|
||||
Logger: logger,
|
||||
|
||||
// A non-nil value here enables gRPC serving for this plugin...
|
||||
GRPCServer: func(opts []grpc.ServerOption) *grpc.Server {
|
||||
opts = append(opts, grpc.MaxRecvMsgSize(math.MaxInt32))
|
||||
opts = append(opts, grpc.MaxSendMsgSize(math.MaxInt32))
|
||||
return plugin.DefaultGRPCServer(opts)
|
||||
},
|
||||
|
||||
// TLSProvider is required to support v3 and v4 plugins.
|
||||
// It will be ignored for v5 which uses AutoMTLS
|
||||
TLSProvider: opts.TLSProviderFunc,
|
||||
}
|
||||
|
||||
plugin.Serve(serveOpts)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handshakeConfigs are used to just do a basic handshake between
|
||||
// a plugin and host. If the handshake fails, a user friendly error is shown.
|
||||
// This prevents users from executing bad plugins or executing a plugin
|
||||
// directory. It is a UX feature, not a security feature.
|
||||
var handshakeConfig = plugin.HandshakeConfig{
|
||||
ProtocolVersion: 4,
|
||||
var HandshakeConfig = plugin.HandshakeConfig{
|
||||
MagicCookieKey: "VAULT_BACKEND_PLUGIN",
|
||||
MagicCookieValue: "6669da05-b1c8-4f49-97d9-c8e5bed98e20",
|
||||
}
|
||||
|
|
|
@ -7,7 +7,7 @@ import (
|
|||
"strings"
|
||||
|
||||
"github.com/hashicorp/go-secure-stdlib/strutil"
|
||||
"github.com/hashicorp/go-uuid"
|
||||
uuid "github.com/hashicorp/go-uuid"
|
||||
"github.com/hashicorp/vault/builtin/plugin"
|
||||
"github.com/hashicorp/vault/helper/namespace"
|
||||
"github.com/hashicorp/vault/sdk/helper/consts"
|
||||
|
@ -924,7 +924,6 @@ func (c *Core) newCredentialBackend(ctx context.Context, entry *MountEntry, sysV
|
|||
f = wrapFactoryCheckPerms(c, plugin.Factory)
|
||||
}
|
||||
}
|
||||
|
||||
// Set up conf to pass in plugin_name
|
||||
conf := make(map[string]string)
|
||||
for k, v := range entry.Options {
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package vault_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
|
@ -31,8 +32,40 @@ const (
|
|||
expectedEnvValue = "BAR"
|
||||
)
|
||||
|
||||
// logicalVersionMap is a map of version to test plugin
|
||||
var logicalVersionMap = map[string]string{
|
||||
"v4": "TestBackend_PluginMain_V4_Logical",
|
||||
"v5": "TestBackend_PluginMainLogical",
|
||||
"v5_multiplexed": "TestBackend_PluginMain_Multiplexed_Logical",
|
||||
}
|
||||
|
||||
// credentialVersionMap is a map of version to test plugin
|
||||
var credentialVersionMap = map[string]string{
|
||||
"v4": "TestBackend_PluginMain_V4_Credentials",
|
||||
"v5": "TestBackend_PluginMainCredentials",
|
||||
"v5_multiplexed": "TestBackend_PluginMain_Multiplexed_Credentials",
|
||||
}
|
||||
|
||||
var testCtx = context.TODO()
|
||||
|
||||
func TestSystemBackend_Plugin_secret(t *testing.T) {
|
||||
cluster := testSystemBackendMock(t, 1, 1, logical.TypeLogical)
|
||||
testCases := []struct {
|
||||
pluginVersion string
|
||||
}{
|
||||
{
|
||||
pluginVersion: "v5_multiplexed",
|
||||
},
|
||||
{
|
||||
pluginVersion: "v5",
|
||||
},
|
||||
{
|
||||
pluginVersion: "v4",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.pluginVersion, func(t *testing.T) {
|
||||
cluster := testSystemBackendMock(t, 1, 1, logical.TypeLogical, tc.pluginVersion)
|
||||
defer cluster.Cleanup()
|
||||
|
||||
core := cluster.Cores[0]
|
||||
|
@ -40,7 +73,7 @@ func TestSystemBackend_Plugin_secret(t *testing.T) {
|
|||
// Make a request to lazy load the plugin
|
||||
req := logical.TestRequest(t, logical.ReadOperation, "mock-0/internal")
|
||||
req.ClientToken = core.Client.Token()
|
||||
resp, err := core.HandleRequest(namespace.RootContext(nil), req)
|
||||
resp, err := core.HandleRequest(namespace.RootContext(testCtx), req)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
@ -67,10 +100,28 @@ func TestSystemBackend_Plugin_secret(t *testing.T) {
|
|||
// If it fails, it means unseal process failed
|
||||
vault.TestWaitActive(t, core.Core)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSystemBackend_Plugin_auth(t *testing.T) {
|
||||
cluster := testSystemBackendMock(t, 1, 1, logical.TypeCredential)
|
||||
testCases := []struct {
|
||||
pluginVersion string
|
||||
}{
|
||||
{
|
||||
pluginVersion: "v5_multiplexed",
|
||||
},
|
||||
{
|
||||
pluginVersion: "v5",
|
||||
},
|
||||
{
|
||||
pluginVersion: "v4",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.pluginVersion, func(t *testing.T) {
|
||||
cluster := testSystemBackendMock(t, 1, 1, logical.TypeCredential, tc.pluginVersion)
|
||||
defer cluster.Cleanup()
|
||||
|
||||
core := cluster.Cores[0]
|
||||
|
@ -78,7 +129,7 @@ func TestSystemBackend_Plugin_auth(t *testing.T) {
|
|||
// Make a request to lazy load the plugin
|
||||
req := logical.TestRequest(t, logical.ReadOperation, "auth/mock-0/internal")
|
||||
req.ClientToken = core.Client.Token()
|
||||
resp, err := core.HandleRequest(namespace.RootContext(nil), req)
|
||||
resp, err := core.HandleRequest(namespace.RootContext(testCtx), req)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
@ -105,10 +156,28 @@ func TestSystemBackend_Plugin_auth(t *testing.T) {
|
|||
// If it fails, it means unseal process failed
|
||||
vault.TestWaitActive(t, core.Core)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSystemBackend_Plugin_MissingBinary(t *testing.T) {
|
||||
cluster := testSystemBackendMock(t, 1, 1, logical.TypeLogical)
|
||||
testCases := []struct {
|
||||
pluginVersion string
|
||||
}{
|
||||
{
|
||||
pluginVersion: "v5_multiplexed",
|
||||
},
|
||||
{
|
||||
pluginVersion: "v5",
|
||||
},
|
||||
{
|
||||
pluginVersion: "v4",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.pluginVersion, func(t *testing.T) {
|
||||
cluster := testSystemBackendMock(t, 1, 1, logical.TypeLogical, tc.pluginVersion)
|
||||
defer cluster.Cleanup()
|
||||
|
||||
core := cluster.Cores[0]
|
||||
|
@ -116,7 +185,7 @@ func TestSystemBackend_Plugin_MissingBinary(t *testing.T) {
|
|||
// Make a request to lazy load the plugin
|
||||
req := logical.TestRequest(t, logical.ReadOperation, "mock-0/internal")
|
||||
req.ClientToken = core.Client.Token()
|
||||
resp, err := core.HandleRequest(namespace.RootContext(nil), req)
|
||||
resp, err := core.HandleRequest(namespace.RootContext(testCtx), req)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
@ -142,14 +211,32 @@ func TestSystemBackend_Plugin_MissingBinary(t *testing.T) {
|
|||
// Make a request against on tune after it is removed
|
||||
req = logical.TestRequest(t, logical.ReadOperation, "sys/mounts/mock-0/tune")
|
||||
req.ClientToken = core.Client.Token()
|
||||
resp, err = core.HandleRequest(namespace.RootContext(nil), req)
|
||||
_, err = core.HandleRequest(namespace.RootContext(testCtx), req)
|
||||
if err == nil {
|
||||
t.Fatalf("expected error")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSystemBackend_Plugin_MismatchType(t *testing.T) {
|
||||
cluster := testSystemBackendMock(t, 1, 1, logical.TypeLogical)
|
||||
testCases := []struct {
|
||||
pluginVersion string
|
||||
}{
|
||||
{
|
||||
pluginVersion: "v5_multiplexed",
|
||||
},
|
||||
{
|
||||
pluginVersion: "v5",
|
||||
},
|
||||
{
|
||||
pluginVersion: "v4",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.pluginVersion, func(t *testing.T) {
|
||||
cluster := testSystemBackendMock(t, 1, 1, logical.TypeLogical, tc.pluginVersion)
|
||||
defer cluster.Cleanup()
|
||||
|
||||
core := cluster.Cores[0]
|
||||
|
@ -161,35 +248,53 @@ func TestSystemBackend_Plugin_MismatchType(t *testing.T) {
|
|||
// and expect an error
|
||||
req := logical.TestRequest(t, logical.ReadOperation, "mock-0/internal")
|
||||
req.ClientToken = core.Client.Token()
|
||||
_, err := core.HandleRequest(namespace.RootContext(nil), req)
|
||||
_, err := core.HandleRequest(namespace.RootContext(testCtx), req)
|
||||
if err != nil {
|
||||
t.Fatalf("adding a same-named plugin of a different type should be no problem: %s", err)
|
||||
}
|
||||
|
||||
// Sleep a bit before cleanup is called
|
||||
time.Sleep(1 * time.Second)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSystemBackend_Plugin_CatalogRemoved(t *testing.T) {
|
||||
t.Run("secret", func(t *testing.T) {
|
||||
testPlugin_CatalogRemoved(t, logical.TypeLogical, false)
|
||||
testPlugin_CatalogRemoved(t, logical.TypeLogical, false, logicalVersionMap)
|
||||
})
|
||||
|
||||
t.Run("auth", func(t *testing.T) {
|
||||
testPlugin_CatalogRemoved(t, logical.TypeCredential, false)
|
||||
testPlugin_CatalogRemoved(t, logical.TypeCredential, false, credentialVersionMap)
|
||||
})
|
||||
|
||||
t.Run("secret-mount-existing", func(t *testing.T) {
|
||||
testPlugin_CatalogRemoved(t, logical.TypeLogical, true)
|
||||
testPlugin_CatalogRemoved(t, logical.TypeLogical, true, logicalVersionMap)
|
||||
})
|
||||
|
||||
t.Run("auth-mount-existing", func(t *testing.T) {
|
||||
testPlugin_CatalogRemoved(t, logical.TypeCredential, true)
|
||||
testPlugin_CatalogRemoved(t, logical.TypeCredential, true, credentialVersionMap)
|
||||
})
|
||||
}
|
||||
|
||||
func testPlugin_CatalogRemoved(t *testing.T, btype logical.BackendType, testMount bool) {
|
||||
cluster := testSystemBackendMock(t, 1, 1, btype)
|
||||
func testPlugin_CatalogRemoved(t *testing.T, btype logical.BackendType, testMount bool, versionMap map[string]string) {
|
||||
testCases := []struct {
|
||||
pluginVersion string
|
||||
}{
|
||||
{
|
||||
pluginVersion: "v5_multiplexed",
|
||||
},
|
||||
{
|
||||
pluginVersion: "v5",
|
||||
},
|
||||
{
|
||||
pluginVersion: "v4",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.pluginVersion, func(t *testing.T) {
|
||||
cluster := testSystemBackendMock(t, 1, 1, logical.TypeLogical, tc.pluginVersion)
|
||||
defer cluster.Cleanup()
|
||||
|
||||
core := cluster.Cores[0]
|
||||
|
@ -197,7 +302,7 @@ func testPlugin_CatalogRemoved(t *testing.T, btype logical.BackendType, testMoun
|
|||
// Remove the plugin from the catalog
|
||||
req := logical.TestRequest(t, logical.DeleteOperation, "sys/plugins/catalog/database/mock-plugin")
|
||||
req.ClientToken = core.Client.Token()
|
||||
resp, err := core.HandleRequest(namespace.RootContext(nil), req)
|
||||
resp, err := core.HandleRequest(namespace.RootContext(testCtx), req)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("err:%v resp:%#v", err, resp)
|
||||
}
|
||||
|
@ -230,13 +335,13 @@ func testPlugin_CatalogRemoved(t *testing.T, btype logical.BackendType, testMoun
|
|||
switch btype {
|
||||
case logical.TypeLogical:
|
||||
// Add plugin back to the catalog
|
||||
vault.TestAddTestPlugin(t, core.Core, "mock-plugin", consts.PluginTypeSecrets, "TestBackend_PluginMainLogical", []string{}, "")
|
||||
vault.TestAddTestPlugin(t, core.Core, "mock-plugin", consts.PluginTypeSecrets, logicalVersionMap[tc.pluginVersion], []string{}, "")
|
||||
_, err = core.Client.Logical().Write("sys/mounts/mock-0", map[string]interface{}{
|
||||
"type": "test",
|
||||
})
|
||||
case logical.TypeCredential:
|
||||
// Add plugin back to the catalog
|
||||
vault.TestAddTestPlugin(t, core.Core, "mock-plugin", consts.PluginTypeCredential, "TestBackend_PluginMainCredentials", []string{}, "")
|
||||
vault.TestAddTestPlugin(t, core.Core, "mock-plugin", consts.PluginTypeCredential, credentialVersionMap[tc.pluginVersion], []string{}, "")
|
||||
_, err = core.Client.Logical().Write("sys/auth/mock-0", map[string]interface{}{
|
||||
"type": "test",
|
||||
})
|
||||
|
@ -245,6 +350,8 @@ func testPlugin_CatalogRemoved(t *testing.T, btype logical.BackendType, testMoun
|
|||
t.Fatal("expected error when mounting on existing path")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSystemBackend_Plugin_continueOnError(t *testing.T) {
|
||||
|
@ -278,7 +385,23 @@ func TestSystemBackend_Plugin_continueOnError(t *testing.T) {
|
|||
}
|
||||
|
||||
func testPlugin_continueOnError(t *testing.T, btype logical.BackendType, mismatch bool, mountPoint string, pluginType consts.PluginType) {
|
||||
cluster := testSystemBackendMock(t, 1, 1, btype)
|
||||
testCases := []struct {
|
||||
pluginVersion string
|
||||
}{
|
||||
{
|
||||
pluginVersion: "v5_multiplexed",
|
||||
},
|
||||
{
|
||||
pluginVersion: "v5",
|
||||
},
|
||||
{
|
||||
pluginVersion: "v4",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.pluginVersion, func(t *testing.T) {
|
||||
cluster := testSystemBackendMock(t, 1, 1, btype, tc.pluginVersion)
|
||||
defer cluster.Cleanup()
|
||||
|
||||
core := cluster.Cores[0]
|
||||
|
@ -286,7 +409,7 @@ func testPlugin_continueOnError(t *testing.T, btype logical.BackendType, mismatc
|
|||
// Get the registered plugin
|
||||
req := logical.TestRequest(t, logical.ReadOperation, fmt.Sprintf("sys/plugins/catalog/%s/mock-plugin", pluginType))
|
||||
req.ClientToken = core.Client.Token()
|
||||
resp, err := core.HandleRequest(namespace.RootContext(nil), req)
|
||||
resp, err := core.HandleRequest(namespace.RootContext(testCtx), req)
|
||||
if err != nil || resp == nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("err:%v resp:%#v", err, resp)
|
||||
}
|
||||
|
@ -296,18 +419,6 @@ func testPlugin_continueOnError(t *testing.T, btype logical.BackendType, mismatc
|
|||
t.Fatal("invalid command")
|
||||
}
|
||||
|
||||
// Mount credential type plugins
|
||||
switch btype {
|
||||
case logical.TypeCredential:
|
||||
vault.TestAddTestPlugin(t, core.Core, mountPoint, consts.PluginTypeCredential, "TestBackend_PluginMainCredentials", []string{}, cluster.TempDir)
|
||||
_, err = core.Client.Logical().Write(fmt.Sprintf("sys/auth/%s", mountPoint), map[string]interface{}{
|
||||
"type": "mock-plugin",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("err:%v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Trigger a sha256 mismatch or missing plugin error
|
||||
if mismatch {
|
||||
req = logical.TestRequest(t, logical.UpdateOperation, fmt.Sprintf("sys/plugins/catalog/%s/mock-plugin", pluginType))
|
||||
|
@ -316,7 +427,7 @@ func testPlugin_continueOnError(t *testing.T, btype logical.BackendType, mismatc
|
|||
"command": filepath.Base(command),
|
||||
}
|
||||
req.ClientToken = core.Client.Token()
|
||||
resp, err = core.HandleRequest(namespace.RootContext(nil), req)
|
||||
resp, err = core.HandleRequest(namespace.RootContext(testCtx), req)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("err:%v resp:%#v", err, resp)
|
||||
}
|
||||
|
@ -351,9 +462,11 @@ func testPlugin_continueOnError(t *testing.T, btype logical.BackendType, mismatc
|
|||
// Re-add the plugin to the catalog
|
||||
switch btype {
|
||||
case logical.TypeLogical:
|
||||
vault.TestAddTestPlugin(t, core.Core, "mock-plugin", consts.PluginTypeSecrets, "TestBackend_PluginMainLogical", []string{}, cluster.TempDir)
|
||||
plugin := logicalVersionMap[tc.pluginVersion]
|
||||
vault.TestAddTestPlugin(t, core.Core, "mock-plugin", consts.PluginTypeSecrets, plugin, []string{}, cluster.TempDir)
|
||||
case logical.TypeCredential:
|
||||
vault.TestAddTestPlugin(t, core.Core, "mock-plugin", consts.PluginTypeCredential, "TestBackend_PluginMainCredentials", []string{}, cluster.TempDir)
|
||||
plugin := credentialVersionMap[tc.pluginVersion]
|
||||
vault.TestAddTestPlugin(t, core.Core, "mock-plugin", consts.PluginTypeCredential, plugin, []string{}, cluster.TempDir)
|
||||
}
|
||||
|
||||
// Reload the plugin
|
||||
|
@ -362,7 +475,7 @@ func testPlugin_continueOnError(t *testing.T, btype logical.BackendType, mismatc
|
|||
"plugin": "mock-plugin",
|
||||
}
|
||||
req.ClientToken = core.Client.Token()
|
||||
resp, err = core.HandleRequest(namespace.RootContext(nil), req)
|
||||
resp, err = core.HandleRequest(namespace.RootContext(testCtx), req)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("err:%v resp:%#v", err, resp)
|
||||
}
|
||||
|
@ -378,17 +491,35 @@ func testPlugin_continueOnError(t *testing.T, btype logical.BackendType, mismatc
|
|||
|
||||
req = logical.TestRequest(t, logical.ReadOperation, reqPath)
|
||||
req.ClientToken = core.Client.Token()
|
||||
resp, err = core.HandleRequest(namespace.RootContext(nil), req)
|
||||
resp, err = core.HandleRequest(namespace.RootContext(testCtx), req)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if resp == nil {
|
||||
t.Fatalf("bad: response should not be nil")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSystemBackend_Plugin_autoReload(t *testing.T) {
|
||||
cluster := testSystemBackendMock(t, 1, 1, logical.TypeLogical)
|
||||
testCases := []struct {
|
||||
pluginVersion string
|
||||
}{
|
||||
{
|
||||
pluginVersion: "v5_multiplexed",
|
||||
},
|
||||
{
|
||||
pluginVersion: "v5",
|
||||
},
|
||||
{
|
||||
pluginVersion: "v4",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.pluginVersion, func(t *testing.T) {
|
||||
cluster := testSystemBackendMock(t, 1, 1, logical.TypeLogical, tc.pluginVersion)
|
||||
defer cluster.Cleanup()
|
||||
|
||||
core := cluster.Cores[0]
|
||||
|
@ -397,7 +528,7 @@ func TestSystemBackend_Plugin_autoReload(t *testing.T) {
|
|||
req := logical.TestRequest(t, logical.UpdateOperation, "mock-0/internal")
|
||||
req.ClientToken = core.Client.Token()
|
||||
req.Data["value"] = "baz"
|
||||
resp, err := core.HandleRequest(namespace.RootContext(nil), req)
|
||||
resp, err := core.HandleRequest(namespace.RootContext(testCtx), req)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
@ -408,7 +539,7 @@ func TestSystemBackend_Plugin_autoReload(t *testing.T) {
|
|||
// Call errors/rpc endpoint to trigger reload
|
||||
req = logical.TestRequest(t, logical.ReadOperation, "mock-0/errors/rpc")
|
||||
req.ClientToken = core.Client.Token()
|
||||
resp, err = core.HandleRequest(namespace.RootContext(nil), req)
|
||||
_, err = core.HandleRequest(namespace.RootContext(testCtx), req)
|
||||
if err == nil {
|
||||
t.Fatalf("expected error from error/rpc request")
|
||||
}
|
||||
|
@ -416,7 +547,7 @@ func TestSystemBackend_Plugin_autoReload(t *testing.T) {
|
|||
// Check internal value to make sure it's reset
|
||||
req = logical.TestRequest(t, logical.ReadOperation, "mock-0/internal")
|
||||
req.ClientToken = core.Client.Token()
|
||||
resp, err = core.HandleRequest(namespace.RootContext(nil), req)
|
||||
resp, err = core.HandleRequest(namespace.RootContext(testCtx), req)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
@ -426,10 +557,28 @@ func TestSystemBackend_Plugin_autoReload(t *testing.T) {
|
|||
if resp.Data["value"].(string) == "baz" {
|
||||
t.Fatal("did not expect backend internal value to be 'baz'")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSystemBackend_Plugin_SealUnseal(t *testing.T) {
|
||||
cluster := testSystemBackendMock(t, 1, 1, logical.TypeLogical)
|
||||
testCases := []struct {
|
||||
pluginVersion string
|
||||
}{
|
||||
{
|
||||
pluginVersion: "v5_multiplexed",
|
||||
},
|
||||
{
|
||||
pluginVersion: "v5",
|
||||
},
|
||||
{
|
||||
pluginVersion: "v4",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.pluginVersion, func(t *testing.T) {
|
||||
cluster := testSystemBackendMock(t, 1, 1, logical.TypeLogical, tc.pluginVersion)
|
||||
defer cluster.Cleanup()
|
||||
|
||||
// Seal the cluster
|
||||
|
@ -452,6 +601,8 @@ func TestSystemBackend_Plugin_SealUnseal(t *testing.T) {
|
|||
// Wait for active so post-unseal takes place
|
||||
// If it fails, it means unseal process failed
|
||||
vault.TestWaitActive(t, cluster.Cores[0].Core)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSystemBackend_Plugin_reload(t *testing.T) {
|
||||
|
@ -498,7 +649,23 @@ func TestSystemBackend_Plugin_reload(t *testing.T) {
|
|||
|
||||
// Helper func to test different reload methods on plugin reload endpoint
|
||||
func testSystemBackend_PluginReload(t *testing.T, reqData map[string]interface{}, backendType logical.BackendType) {
|
||||
cluster := testSystemBackendMock(t, 1, 2, backendType)
|
||||
testCases := []struct {
|
||||
pluginVersion string
|
||||
}{
|
||||
{
|
||||
pluginVersion: "v5_multiplexed",
|
||||
},
|
||||
{
|
||||
pluginVersion: "v5",
|
||||
},
|
||||
{
|
||||
pluginVersion: "v4",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.pluginVersion, func(t *testing.T) {
|
||||
cluster := testSystemBackendMock(t, 1, 2, backendType, tc.pluginVersion)
|
||||
defer cluster.Cleanup()
|
||||
|
||||
core := cluster.Cores[0]
|
||||
|
@ -546,6 +713,8 @@ func testSystemBackend_PluginReload(t *testing.T, reqData map[string]interface{}
|
|||
t.Fatal("did not expect backend internal value to be 'baz'")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// testSystemBackendMock returns a systemBackend with the desired number
|
||||
|
@ -553,7 +722,7 @@ func testSystemBackend_PluginReload(t *testing.T, reqData map[string]interface{}
|
|||
// ways of providing the plugin_name.
|
||||
//
|
||||
// The mounts are mounted at sys/mounts/mock-[numMounts] or sys/auth/mock-[numMounts]
|
||||
func testSystemBackendMock(t *testing.T, numCores, numMounts int, backendType logical.BackendType) *vault.TestCluster {
|
||||
func testSystemBackendMock(t *testing.T, numCores, numMounts int, backendType logical.BackendType, pluginVersion string) *vault.TestCluster {
|
||||
coreConfig := &vault.CoreConfig{
|
||||
LogicalBackends: map[string]logical.Factory{
|
||||
"plugin": plugin.Factory,
|
||||
|
@ -585,7 +754,8 @@ func testSystemBackendMock(t *testing.T, numCores, numMounts int, backendType lo
|
|||
|
||||
switch backendType {
|
||||
case logical.TypeLogical:
|
||||
vault.TestAddTestPlugin(t, core.Core, "mock-plugin", consts.PluginTypeSecrets, "TestBackend_PluginMainLogical", []string{}, tempDir)
|
||||
plugin := logicalVersionMap[pluginVersion]
|
||||
vault.TestAddTestPlugin(t, core.Core, "mock-plugin", consts.PluginTypeSecrets, plugin, []string{}, tempDir)
|
||||
for i := 0; i < numMounts; i++ {
|
||||
// Alternate input styles for plugin_name on every other mount
|
||||
options := map[string]interface{}{
|
||||
|
@ -600,7 +770,8 @@ func testSystemBackendMock(t *testing.T, numCores, numMounts int, backendType lo
|
|||
}
|
||||
}
|
||||
case logical.TypeCredential:
|
||||
vault.TestAddTestPlugin(t, core.Core, "mock-plugin", consts.PluginTypeCredential, "TestBackend_PluginMainCredentials", []string{}, tempDir)
|
||||
plugin := credentialVersionMap[pluginVersion]
|
||||
vault.TestAddTestPlugin(t, core.Core, "mock-plugin", consts.PluginTypeCredential, plugin, []string{}, tempDir)
|
||||
for i := 0; i < numMounts; i++ {
|
||||
// Alternate input styles for plugin_name on every other mount
|
||||
options := map[string]interface{}{
|
||||
|
@ -671,9 +842,15 @@ func testSystemBackend_SingleCluster_Env(t *testing.T, env []string) *vault.Test
|
|||
return cluster
|
||||
}
|
||||
|
||||
func TestBackend_PluginMainLogical(t *testing.T) {
|
||||
func TestBackend_PluginMain_V4_Logical(t *testing.T) {
|
||||
args := []string{}
|
||||
if os.Getenv(pluginutil.PluginUnwrapTokenEnv) == "" && os.Getenv(pluginutil.PluginMetadataModeEnv) != "true" {
|
||||
// don't run as a standalone unit test
|
||||
if os.Getenv(pluginutil.PluginVaultVersionEnv) == "" {
|
||||
return
|
||||
}
|
||||
|
||||
// don't run as a V5 plugin
|
||||
if os.Getenv(pluginutil.PluginAutoMTLSEnv) == "true" {
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -686,6 +863,8 @@ func TestBackend_PluginMainLogical(t *testing.T) {
|
|||
apiClientMeta := &api.PluginAPIClientMeta{}
|
||||
flags := apiClientMeta.FlagSet()
|
||||
flags.Parse(args)
|
||||
|
||||
// V4 does not support AutoMTLS so we set a TLSConfig via TLSProviderFunc
|
||||
tlsConfig := apiClientMeta.GetTLSConfig()
|
||||
tlsProviderFunc := api.VaultPluginTLSProvider(tlsConfig)
|
||||
|
||||
|
@ -700,9 +879,9 @@ func TestBackend_PluginMainLogical(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestBackend_PluginMainCredentials(t *testing.T) {
|
||||
func TestBackend_PluginMain_Multiplexed_Logical(t *testing.T) {
|
||||
args := []string{}
|
||||
if os.Getenv(pluginutil.PluginUnwrapTokenEnv) == "" && os.Getenv(pluginutil.PluginMetadataModeEnv) != "true" {
|
||||
if os.Getenv(pluginutil.PluginVaultVersionEnv) == "" {
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -715,6 +894,66 @@ func TestBackend_PluginMainCredentials(t *testing.T) {
|
|||
apiClientMeta := &api.PluginAPIClientMeta{}
|
||||
flags := apiClientMeta.FlagSet()
|
||||
flags.Parse(args)
|
||||
|
||||
factoryFunc := mock.FactoryType(logical.TypeLogical)
|
||||
|
||||
err := lplugin.ServeMultiplex(&lplugin.ServeOpts{
|
||||
BackendFactoryFunc: factoryFunc,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackend_PluginMainLogical(t *testing.T) {
|
||||
args := []string{}
|
||||
if os.Getenv(pluginutil.PluginVaultVersionEnv) == "" {
|
||||
return
|
||||
}
|
||||
|
||||
caPEM := os.Getenv(pluginutil.PluginCACertPEMEnv)
|
||||
if caPEM == "" {
|
||||
t.Fatal("CA cert not passed in")
|
||||
}
|
||||
args = append(args, fmt.Sprintf("--ca-cert=%s", caPEM))
|
||||
|
||||
apiClientMeta := &api.PluginAPIClientMeta{}
|
||||
flags := apiClientMeta.FlagSet()
|
||||
flags.Parse(args)
|
||||
|
||||
factoryFunc := mock.FactoryType(logical.TypeLogical)
|
||||
|
||||
err := lplugin.Serve(&lplugin.ServeOpts{
|
||||
BackendFactoryFunc: factoryFunc,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackend_PluginMain_V4_Credentials(t *testing.T) {
|
||||
args := []string{}
|
||||
// don't run as a standalone unit test
|
||||
if os.Getenv(pluginutil.PluginVaultVersionEnv) == "" {
|
||||
return
|
||||
}
|
||||
|
||||
// don't run as a V5 plugin
|
||||
if os.Getenv(pluginutil.PluginAutoMTLSEnv) == "true" {
|
||||
return
|
||||
}
|
||||
|
||||
caPEM := os.Getenv(pluginutil.PluginCACertPEMEnv)
|
||||
if caPEM == "" {
|
||||
t.Fatal("CA cert not passed in")
|
||||
}
|
||||
args = append(args, fmt.Sprintf("--ca-cert=%s", caPEM))
|
||||
|
||||
apiClientMeta := &api.PluginAPIClientMeta{}
|
||||
flags := apiClientMeta.FlagSet()
|
||||
flags.Parse(args)
|
||||
|
||||
// V4 does not support AutoMTLS so we set a TLSConfig via TLSProviderFunc
|
||||
tlsConfig := apiClientMeta.GetTLSConfig()
|
||||
tlsProviderFunc := api.VaultPluginTLSProvider(tlsConfig)
|
||||
|
||||
|
@ -729,6 +968,58 @@ func TestBackend_PluginMainCredentials(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestBackend_PluginMain_Multiplexed_Credentials(t *testing.T) {
|
||||
args := []string{}
|
||||
if os.Getenv(pluginutil.PluginVaultVersionEnv) == "" {
|
||||
return
|
||||
}
|
||||
|
||||
caPEM := os.Getenv(pluginutil.PluginCACertPEMEnv)
|
||||
if caPEM == "" {
|
||||
t.Fatal("CA cert not passed in")
|
||||
}
|
||||
args = append(args, fmt.Sprintf("--ca-cert=%s", caPEM))
|
||||
|
||||
apiClientMeta := &api.PluginAPIClientMeta{}
|
||||
flags := apiClientMeta.FlagSet()
|
||||
flags.Parse(args)
|
||||
|
||||
factoryFunc := mock.FactoryType(logical.TypeCredential)
|
||||
|
||||
err := lplugin.ServeMultiplex(&lplugin.ServeOpts{
|
||||
BackendFactoryFunc: factoryFunc,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackend_PluginMainCredentials(t *testing.T) {
|
||||
args := []string{}
|
||||
if os.Getenv(pluginutil.PluginVaultVersionEnv) == "" {
|
||||
return
|
||||
}
|
||||
|
||||
caPEM := os.Getenv(pluginutil.PluginCACertPEMEnv)
|
||||
if caPEM == "" {
|
||||
t.Fatal("CA cert not passed in")
|
||||
}
|
||||
args = append(args, fmt.Sprintf("--ca-cert=%s", caPEM))
|
||||
|
||||
apiClientMeta := &api.PluginAPIClientMeta{}
|
||||
flags := apiClientMeta.FlagSet()
|
||||
flags.Parse(args)
|
||||
|
||||
factoryFunc := mock.FactoryType(logical.TypeCredential)
|
||||
|
||||
err := lplugin.Serve(&lplugin.ServeOpts{
|
||||
BackendFactoryFunc: factoryFunc,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestBackend_PluginMainEnv is a mock plugin that simply checks for the existence of FOO env var.
|
||||
func TestBackend_PluginMainEnv(t *testing.T) {
|
||||
args := []string{}
|
||||
|
@ -751,14 +1042,11 @@ func TestBackend_PluginMainEnv(t *testing.T) {
|
|||
apiClientMeta := &api.PluginAPIClientMeta{}
|
||||
flags := apiClientMeta.FlagSet()
|
||||
flags.Parse(args)
|
||||
tlsConfig := apiClientMeta.GetTLSConfig()
|
||||
tlsProviderFunc := api.VaultPluginTLSProvider(tlsConfig)
|
||||
|
||||
factoryFunc := mock.FactoryType(logical.TypeLogical)
|
||||
|
||||
err := lplugin.Serve(&lplugin.ServeOpts{
|
||||
BackendFactoryFunc: factoryFunc,
|
||||
TLSProviderFunc: tlsProviderFunc,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
|
|
@ -1432,7 +1432,6 @@ func (c *Core) newLogicalBackend(ctx context.Context, entry *MountEntry, sysView
|
|||
f = wrapFactoryCheckPerms(c, plugin.Factory)
|
||||
}
|
||||
}
|
||||
|
||||
// Set up conf to pass in plugin_name
|
||||
conf := make(map[string]string)
|
||||
for k, v := range entry.Options {
|
||||
|
|
|
@ -33,6 +33,7 @@ var (
|
|||
pluginCatalogPath = "core/plugin-catalog/"
|
||||
ErrDirectoryNotConfigured = errors.New("could not set plugin, plugin directory is not configured")
|
||||
ErrPluginNotFound = errors.New("plugin not found in the catalog")
|
||||
ErrPluginConnectionNotFound = errors.New("plugin connection not found for client")
|
||||
ErrPluginBadType = errors.New("unable to determine plugin type")
|
||||
)
|
||||
|
||||
|
@ -79,12 +80,14 @@ type pluginClient struct {
|
|||
|
||||
// id is the connection ID
|
||||
id string
|
||||
pid int
|
||||
|
||||
// client handles the lifecycle of a plugin process
|
||||
// multiplexed plugins share the same client
|
||||
client *plugin.Client
|
||||
clientConn grpc.ClientConnInterface
|
||||
cleanupFunc func() error
|
||||
reloadFunc func() error
|
||||
|
||||
plugin.ClientProtocol
|
||||
}
|
||||
|
@ -148,6 +151,38 @@ func (p *pluginClient) Conn() grpc.ClientConnInterface {
|
|||
return p.clientConn
|
||||
}
|
||||
|
||||
func (p *pluginClient) Reload() error {
|
||||
p.logger.Debug("reload external plugin process")
|
||||
return p.reloadFunc()
|
||||
}
|
||||
|
||||
// reloadExternalPlugin
|
||||
// This should be called with the write lock held.
|
||||
func (c *PluginCatalog) reloadExternalPlugin(name, id string) error {
|
||||
extPlugin, ok := c.externalPlugins[name]
|
||||
if !ok {
|
||||
return fmt.Errorf("plugin client not found")
|
||||
}
|
||||
if !extPlugin.multiplexingSupport {
|
||||
err := c.cleanupExternalPlugin(name, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
pc, ok := extPlugin.connections[id]
|
||||
if !ok {
|
||||
return fmt.Errorf("%w id: %s", ErrPluginConnectionNotFound, id)
|
||||
}
|
||||
|
||||
delete(c.externalPlugins, name)
|
||||
pc.client.Kill()
|
||||
c.logger.Debug("killed external plugin process for reload", "name", name, "pid", pc.pid)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close calls the plugin client's cleanupFunc to do any necessary cleanup on
|
||||
// the plugin client and the PluginCatalog. This implements the
|
||||
// plugin.ClientProtocol interface.
|
||||
|
@ -167,19 +202,26 @@ func (c *PluginCatalog) cleanupExternalPlugin(name, id string) error {
|
|||
|
||||
pc, ok := extPlugin.connections[id]
|
||||
if !ok {
|
||||
return fmt.Errorf("plugin connection not found")
|
||||
// this can happen if the backend is reloaded due to a plugin process
|
||||
// being killed out of band
|
||||
c.logger.Warn(ErrPluginConnectionNotFound.Error(), "id", id)
|
||||
return fmt.Errorf("%w id: %s", ErrPluginConnectionNotFound, id)
|
||||
}
|
||||
|
||||
delete(extPlugin.connections, id)
|
||||
c.logger.Debug("removed plugin client connection", "id", id)
|
||||
|
||||
if !extPlugin.multiplexingSupport {
|
||||
pc.client.Kill()
|
||||
|
||||
if len(extPlugin.connections) == 0 {
|
||||
delete(c.externalPlugins, name)
|
||||
}
|
||||
c.logger.Debug("killed external plugin process", "name", name, "pid", pc.pid)
|
||||
} else if len(extPlugin.connections) == 0 || pc.client.Exited() {
|
||||
pc.client.Kill()
|
||||
delete(c.externalPlugins, name)
|
||||
c.logger.Debug("killed external multiplexed plugin process", "name", name, "pid", pc.pid)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@ -252,6 +294,11 @@ func (c *PluginCatalog) newPluginClient(ctx context.Context, pluginRunner *plugi
|
|||
defer c.lock.Unlock()
|
||||
return c.cleanupExternalPlugin(pluginRunner.Name, id)
|
||||
},
|
||||
reloadFunc: func() error {
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
return c.reloadExternalPlugin(pluginRunner.Name, id)
|
||||
},
|
||||
}
|
||||
|
||||
// Multiplexing support will always be false initially, but will be
|
||||
|
@ -264,9 +311,8 @@ func (c *PluginCatalog) newPluginClient(ctx context.Context, pluginRunner *plugi
|
|||
pluginutil.Logger(config.Logger),
|
||||
pluginutil.MetadataMode(config.IsMetadataMode),
|
||||
pluginutil.MLock(c.mlockPlugins),
|
||||
|
||||
// NewPluginClient only supports AutoMTLS today
|
||||
pluginutil.AutoMTLS(true),
|
||||
pluginutil.AutoMTLS(config.AutoMTLS),
|
||||
pluginutil.Runner(config.Wrapper),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -294,6 +340,12 @@ func (c *PluginCatalog) newPluginClient(ctx context.Context, pluginRunner *plugi
|
|||
return nil, err
|
||||
}
|
||||
|
||||
// get the external plugin pid
|
||||
conf := pc.client.ReattachConfig()
|
||||
if conf != nil {
|
||||
pc.pid = conf.Pid
|
||||
}
|
||||
|
||||
clientConn := rpcClient.(*plugin.GRPCClient).Conn
|
||||
|
||||
muxed, err := pluginutil.MultiplexingSupported(ctx, clientConn)
|
||||
|
@ -322,9 +374,8 @@ func (c *PluginCatalog) newPluginClient(ctx context.Context, pluginRunner *plugi
|
|||
}
|
||||
|
||||
// getPluginTypeFromUnknown will attempt to run the plugin to determine the
|
||||
// type and if it supports multiplexing. It will first attempt to run as a
|
||||
// database plugin then a backend plugin. Both of these will be run in metadata
|
||||
// mode.
|
||||
// type. It will first attempt to run as a database plugin then a backend
|
||||
// plugin.
|
||||
func (c *PluginCatalog) getPluginTypeFromUnknown(ctx context.Context, logger log.Logger, plugin *pluginutil.PluginRunner) (consts.PluginType, error) {
|
||||
merr := &multierror.Error{}
|
||||
err := c.isDatabasePlugin(ctx, plugin)
|
||||
|
@ -333,16 +384,70 @@ func (c *PluginCatalog) getPluginTypeFromUnknown(ctx context.Context, logger log
|
|||
}
|
||||
merr = multierror.Append(merr, err)
|
||||
|
||||
// Attempt to run as backend plugin
|
||||
client, err := backendplugin.NewPluginClient(ctx, nil, plugin, log.NewNullLogger(), true)
|
||||
pluginType, err := c.getBackendPluginType(ctx, plugin)
|
||||
if err == nil {
|
||||
err := client.Setup(ctx, &logical.BackendConfig{})
|
||||
return pluginType, nil
|
||||
}
|
||||
merr = multierror.Append(merr, err)
|
||||
|
||||
return consts.PluginTypeUnknown, merr
|
||||
}
|
||||
|
||||
// getBackendPluginType returns an error if the plugin is not a backend plugin.
|
||||
func (c *PluginCatalog) getBackendPluginType(ctx context.Context, pluginRunner *pluginutil.PluginRunner) (consts.PluginType, error) {
|
||||
merr := &multierror.Error{}
|
||||
// Attempt to run as backend plugin
|
||||
config := pluginutil.PluginClientConfig{
|
||||
Name: pluginRunner.Name,
|
||||
PluginSets: backendplugin.PluginSet,
|
||||
HandshakeConfig: backendplugin.HandshakeConfig,
|
||||
Logger: log.NewNullLogger(),
|
||||
IsMetadataMode: false,
|
||||
AutoMTLS: true,
|
||||
}
|
||||
|
||||
var client logical.Backend
|
||||
var attemptV4 bool
|
||||
// First, attempt to run as backend V5 plugin
|
||||
c.logger.Debug("attempting to load backend plugin", "name", pluginRunner.Name)
|
||||
pc, err := c.newPluginClient(ctx, pluginRunner, config)
|
||||
if err == nil {
|
||||
// we spawned a subprocess, so make sure to clean it up
|
||||
defer c.cleanupExternalPlugin(pluginRunner.Name, pc.id)
|
||||
|
||||
// dispense the plugin so we can get its type
|
||||
client, err = backendplugin.Dispense(pc.ClientProtocol, pc)
|
||||
if err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("failed to dispense plugin as backend v5: %w", err))
|
||||
c.logger.Debug("failed to dispense v5 backend plugin", "name", pluginRunner.Name)
|
||||
attemptV4 = true
|
||||
} else {
|
||||
c.logger.Debug("successfully dispensed v5 backend plugin", "name", pluginRunner.Name)
|
||||
}
|
||||
} else {
|
||||
attemptV4 = true
|
||||
}
|
||||
|
||||
if attemptV4 {
|
||||
c.logger.Debug("failed to dispense v5 backend plugin", "name", pluginRunner.Name, "error", err)
|
||||
config.AutoMTLS = false
|
||||
config.IsMetadataMode = true
|
||||
// attempt to run as a v4 backend plugin
|
||||
client, err = backendplugin.NewPluginClient(ctx, nil, pluginRunner, log.NewNullLogger(), true)
|
||||
if err != nil {
|
||||
c.logger.Debug("failed to dispense v4 backend plugin", "name", pluginRunner.Name, "error", err)
|
||||
merr = multierror.Append(merr, fmt.Errorf("failed to dispense v4 backend plugin: %w", err))
|
||||
return consts.PluginTypeUnknown, merr.ErrorOrNil()
|
||||
}
|
||||
c.logger.Debug("successfully dispensed v4 backend plugin", "name", pluginRunner.Name)
|
||||
defer client.Cleanup(ctx)
|
||||
}
|
||||
|
||||
err = client.Setup(ctx, &logical.BackendConfig{})
|
||||
if err != nil {
|
||||
return consts.PluginTypeUnknown, err
|
||||
}
|
||||
|
||||
backendType := client.Type()
|
||||
client.Cleanup(ctx)
|
||||
|
||||
switch backendType {
|
||||
case logical.TypeCredential:
|
||||
|
@ -350,26 +455,24 @@ func (c *PluginCatalog) getPluginTypeFromUnknown(ctx context.Context, logger log
|
|||
case logical.TypeLogical:
|
||||
return consts.PluginTypeSecrets, nil
|
||||
}
|
||||
} else {
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
|
||||
if client == nil || client.Type() == logical.TypeUnknown {
|
||||
logger.Warn("unknown plugin type",
|
||||
"plugin name", plugin.Name,
|
||||
c.logger.Warn("unknown plugin type",
|
||||
"plugin name", pluginRunner.Name,
|
||||
"error", merr.Error())
|
||||
} else {
|
||||
logger.Warn("unsupported plugin type",
|
||||
"plugin name", plugin.Name,
|
||||
c.logger.Warn("unsupported plugin type",
|
||||
"plugin name", pluginRunner.Name,
|
||||
"plugin type", client.Type().String(),
|
||||
"error", merr.Error())
|
||||
}
|
||||
|
||||
return consts.PluginTypeUnknown, nil
|
||||
merr = multierror.Append(merr, fmt.Errorf("failed to load plugin as backend plugin: %w", err))
|
||||
|
||||
return consts.PluginTypeUnknown, merr.ErrorOrNil()
|
||||
}
|
||||
|
||||
// isDatabasePlugin returns true if the plugin supports multiplexing. An error
|
||||
// is returned if the plugin is not a database plugin.
|
||||
// isDatabasePlugin returns an error if the plugin is not a database plugin.
|
||||
func (c *PluginCatalog) isDatabasePlugin(ctx context.Context, pluginRunner *pluginutil.PluginRunner) error {
|
||||
merr := &multierror.Error{}
|
||||
config := pluginutil.PluginClientConfig{
|
||||
|
|
|
@ -10,11 +10,13 @@ import (
|
|||
"sort"
|
||||
"testing"
|
||||
|
||||
log "github.com/hashicorp/go-hclog"
|
||||
"github.com/hashicorp/vault/api"
|
||||
"github.com/hashicorp/vault/builtin/credential/userpass"
|
||||
"github.com/hashicorp/vault/plugins/database/postgresql"
|
||||
v5 "github.com/hashicorp/vault/sdk/database/dbplugin/v5"
|
||||
"github.com/hashicorp/vault/sdk/helper/consts"
|
||||
"github.com/hashicorp/vault/sdk/helper/pluginutil"
|
||||
backendplugin "github.com/hashicorp/vault/sdk/plugin"
|
||||
|
||||
"github.com/hashicorp/vault/helper/builtinplugins"
|
||||
)
|
||||
|
@ -59,7 +61,7 @@ func TestPluginCatalog_CRUD(t *testing.T) {
|
|||
}
|
||||
defer file.Close()
|
||||
|
||||
command := fmt.Sprintf("%s", filepath.Base(file.Name()))
|
||||
command := filepath.Base(file.Name())
|
||||
err = core.pluginCatalog.Set(context.Background(), pluginName, consts.PluginTypeDatabase, "", command, []string{"--test"}, []string{"FOO=BAR"}, []byte{'1'})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
@ -375,47 +377,109 @@ func TestPluginCatalog_NewPluginClient(t *testing.T) {
|
|||
TestAddTestPlugin(t, core, "single-postgres-1", consts.PluginTypeUnknown, "TestPluginCatalog_PluginMain_Postgres", []string{}, "")
|
||||
TestAddTestPlugin(t, core, "single-postgres-2", consts.PluginTypeUnknown, "TestPluginCatalog_PluginMain_Postgres", []string{}, "")
|
||||
|
||||
TestAddTestPlugin(t, core, "mux-userpass", consts.PluginTypeUnknown, "TestPluginCatalog_PluginMain_UserpassMultiplexed", []string{}, "")
|
||||
TestAddTestPlugin(t, core, "single-userpass-1", consts.PluginTypeUnknown, "TestPluginCatalog_PluginMain_Userpass", []string{}, "")
|
||||
TestAddTestPlugin(t, core, "single-userpass-2", consts.PluginTypeUnknown, "TestPluginCatalog_PluginMain_Userpass", []string{}, "")
|
||||
|
||||
var pluginClients []*pluginClient
|
||||
// run plugins
|
||||
if _, err := core.pluginCatalog.NewPluginClient(context.Background(), testPluginClientConfig("mux-postgres")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := core.pluginCatalog.NewPluginClient(context.Background(), testPluginClientConfig("mux-postgres")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := core.pluginCatalog.NewPluginClient(context.Background(), testPluginClientConfig("single-postgres-1")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := core.pluginCatalog.NewPluginClient(context.Background(), testPluginClientConfig("single-postgres-2")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// run "mux-postgres" twice which will start a single plugin for 2
|
||||
// distinct connections
|
||||
c := TestRunTestPlugin(t, core, consts.PluginTypeDatabase, "mux-postgres")
|
||||
pluginClients = append(pluginClients, c)
|
||||
c = TestRunTestPlugin(t, core, consts.PluginTypeDatabase, "mux-postgres")
|
||||
pluginClients = append(pluginClients, c)
|
||||
c = TestRunTestPlugin(t, core, consts.PluginTypeDatabase, "single-postgres-1")
|
||||
pluginClients = append(pluginClients, c)
|
||||
c = TestRunTestPlugin(t, core, consts.PluginTypeDatabase, "single-postgres-2")
|
||||
pluginClients = append(pluginClients, c)
|
||||
|
||||
// run "mux-userpass" twice which will start a single plugin for 2
|
||||
// distinct connections
|
||||
c = TestRunTestPlugin(t, core, consts.PluginTypeCredential, "mux-userpass")
|
||||
pluginClients = append(pluginClients, c)
|
||||
c = TestRunTestPlugin(t, core, consts.PluginTypeCredential, "mux-userpass")
|
||||
pluginClients = append(pluginClients, c)
|
||||
c = TestRunTestPlugin(t, core, consts.PluginTypeCredential, "single-userpass-1")
|
||||
pluginClients = append(pluginClients, c)
|
||||
c = TestRunTestPlugin(t, core, consts.PluginTypeCredential, "single-userpass-2")
|
||||
pluginClients = append(pluginClients, c)
|
||||
|
||||
externalPlugins := core.pluginCatalog.externalPlugins
|
||||
if len(externalPlugins) != 3 {
|
||||
t.Fatalf("expected externalPlugins map to be of len 3 but got %d", len(externalPlugins))
|
||||
if len(externalPlugins) != 6 {
|
||||
t.Fatalf("expected externalPlugins map to be of len 6 but got %d", len(externalPlugins))
|
||||
}
|
||||
|
||||
// check connections map
|
||||
expectedLen := 2
|
||||
if len(externalPlugins["mux-postgres"].connections) != expectedLen {
|
||||
t.Fatalf("expected multiplexed external plugin's connections map to be of len %d but got %d", expectedLen, len(externalPlugins["mux-postgres"].connections))
|
||||
}
|
||||
expectedLen = 1
|
||||
if len(externalPlugins["single-postgres-1"].connections) != expectedLen {
|
||||
t.Fatalf("expected multiplexed external plugin's connections map to be of len %d but got %d", expectedLen, len(externalPlugins["mux-postgres"].connections))
|
||||
}
|
||||
if len(externalPlugins["single-postgres-2"].connections) != expectedLen {
|
||||
t.Fatalf("expected multiplexed external plugin's connections map to be of len %d but got %d", expectedLen, len(externalPlugins["mux-postgres"].connections))
|
||||
}
|
||||
expectConnectionLen(t, 2, externalPlugins["mux-postgres"].connections)
|
||||
expectConnectionLen(t, 1, externalPlugins["single-postgres-1"].connections)
|
||||
expectConnectionLen(t, 1, externalPlugins["single-postgres-2"].connections)
|
||||
expectConnectionLen(t, 2, externalPlugins["mux-userpass"].connections)
|
||||
expectConnectionLen(t, 1, externalPlugins["single-userpass-1"].connections)
|
||||
expectConnectionLen(t, 1, externalPlugins["single-userpass-2"].connections)
|
||||
|
||||
// check multiplexing support
|
||||
if !externalPlugins["mux-postgres"].multiplexingSupport {
|
||||
t.Fatalf("expected external plugin to be multiplexed")
|
||||
expectMultiplexingSupport(t, true, externalPlugins["mux-postgres"].multiplexingSupport)
|
||||
expectMultiplexingSupport(t, false, externalPlugins["single-postgres-1"].multiplexingSupport)
|
||||
expectMultiplexingSupport(t, false, externalPlugins["single-postgres-2"].multiplexingSupport)
|
||||
expectMultiplexingSupport(t, true, externalPlugins["mux-userpass"].multiplexingSupport)
|
||||
expectMultiplexingSupport(t, false, externalPlugins["single-userpass-1"].multiplexingSupport)
|
||||
expectMultiplexingSupport(t, false, externalPlugins["single-userpass-2"].multiplexingSupport)
|
||||
|
||||
// cleanup all of the external plugin processes
|
||||
for _, client := range pluginClients {
|
||||
client.Close()
|
||||
}
|
||||
if externalPlugins["single-postgres-1"].multiplexingSupport {
|
||||
t.Fatalf("expected external plugin to be non-multiplexed")
|
||||
|
||||
// check that externalPlugins map is cleaned up
|
||||
if len(externalPlugins) != 0 {
|
||||
t.Fatalf("expected external plugin map to be of len 0 but got %d", len(externalPlugins))
|
||||
}
|
||||
if externalPlugins["single-postgres-2"].multiplexingSupport {
|
||||
t.Fatalf("expected external plugin to be non-multiplexed")
|
||||
}
|
||||
|
||||
func TestPluginCatalog_PluginMain_Userpass(t *testing.T) {
|
||||
if os.Getenv(pluginutil.PluginVaultVersionEnv) == "" {
|
||||
return
|
||||
}
|
||||
|
||||
apiClientMeta := &api.PluginAPIClientMeta{}
|
||||
flags := apiClientMeta.FlagSet()
|
||||
flags.Parse(os.Args[1:])
|
||||
|
||||
tlsConfig := apiClientMeta.GetTLSConfig()
|
||||
tlsProviderFunc := api.VaultPluginTLSProvider(tlsConfig)
|
||||
|
||||
err := backendplugin.Serve(
|
||||
&backendplugin.ServeOpts{
|
||||
BackendFactoryFunc: userpass.Factory,
|
||||
TLSProviderFunc: tlsProviderFunc,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to initialize userpass: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPluginCatalog_PluginMain_UserpassMultiplexed(t *testing.T) {
|
||||
if os.Getenv(pluginutil.PluginVaultVersionEnv) == "" {
|
||||
return
|
||||
}
|
||||
|
||||
apiClientMeta := &api.PluginAPIClientMeta{}
|
||||
flags := apiClientMeta.FlagSet()
|
||||
flags.Parse(os.Args[1:])
|
||||
|
||||
tlsConfig := apiClientMeta.GetTLSConfig()
|
||||
tlsProviderFunc := api.VaultPluginTLSProvider(tlsConfig)
|
||||
|
||||
err := backendplugin.ServeMultiplex(
|
||||
&backendplugin.ServeOpts{
|
||||
BackendFactoryFunc: userpass.Factory,
|
||||
TLSProviderFunc: tlsProviderFunc,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to initialize userpass: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -440,14 +504,16 @@ func TestPluginCatalog_PluginMain_PostgresMultiplexed(_ *testing.T) {
|
|||
v5.ServeMultiplex(postgresql.New)
|
||||
}
|
||||
|
||||
func testPluginClientConfig(pluginName string) pluginutil.PluginClientConfig {
|
||||
return pluginutil.PluginClientConfig{
|
||||
Name: pluginName,
|
||||
PluginType: consts.PluginTypeDatabase,
|
||||
PluginSets: v5.PluginSets,
|
||||
HandshakeConfig: v5.HandshakeConfig,
|
||||
Logger: log.NewNullLogger(),
|
||||
IsMetadataMode: false,
|
||||
AutoMTLS: true,
|
||||
// expectConnectionLen asserts that the PluginCatalog's externalPlugin
|
||||
// connections map has a length of expectedLen
|
||||
func expectConnectionLen(t *testing.T, expectedLen int, connections map[string]*pluginClient) {
|
||||
if len(connections) != expectedLen {
|
||||
t.Fatalf("expected external plugin's connections map to be of len %d but got %d", expectedLen, len(connections))
|
||||
}
|
||||
}
|
||||
|
||||
func expectMultiplexingSupport(t *testing.T, expected, actual bool) {
|
||||
if expected != actual {
|
||||
t.Fatalf("expected external plugin multiplexing support to be %t", expected)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -10,6 +10,7 @@ import (
|
|||
multierror "github.com/hashicorp/go-multierror"
|
||||
"github.com/hashicorp/go-secure-stdlib/strutil"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
"github.com/hashicorp/vault/sdk/plugin"
|
||||
)
|
||||
|
||||
// reloadMatchingPluginMounts reloads provided mounts, regardless of
|
||||
|
@ -147,8 +148,11 @@ func (c *Core) reloadBackendCommon(ctx context.Context, entry *MountEntry, isAut
|
|||
|
||||
// Only call Cleanup if backend is initialized
|
||||
if re.backend != nil {
|
||||
// Pass a context value so that the plugin client will call the
|
||||
// appropriate cleanup method for reloading
|
||||
reloadCtx := context.WithValue(ctx, plugin.ContextKeyPluginReload, "reload")
|
||||
// Call backend's Cleanup routine
|
||||
re.backend.Cleanup(ctx)
|
||||
re.backend.Cleanup(reloadCtx)
|
||||
}
|
||||
|
||||
view := re.storageView
|
||||
|
|
|
@ -39,13 +39,16 @@ import (
|
|||
"github.com/hashicorp/vault/helper/namespace"
|
||||
"github.com/hashicorp/vault/internalshared/configutil"
|
||||
dbMysql "github.com/hashicorp/vault/plugins/database/mysql"
|
||||
v5 "github.com/hashicorp/vault/sdk/database/dbplugin/v5"
|
||||
"github.com/hashicorp/vault/sdk/framework"
|
||||
"github.com/hashicorp/vault/sdk/helper/consts"
|
||||
"github.com/hashicorp/vault/sdk/helper/logging"
|
||||
"github.com/hashicorp/vault/sdk/helper/pluginutil"
|
||||
"github.com/hashicorp/vault/sdk/helper/salt"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
"github.com/hashicorp/vault/sdk/physical"
|
||||
physInmem "github.com/hashicorp/vault/sdk/physical/inmem"
|
||||
backendplugin "github.com/hashicorp/vault/sdk/plugin"
|
||||
"github.com/hashicorp/vault/vault/cluster"
|
||||
"github.com/hashicorp/vault/vault/seal"
|
||||
"github.com/mitchellh/copystructure"
|
||||
|
@ -563,6 +566,48 @@ func TestAddTestPlugin(t testing.T, c *Core, name string, pluginType consts.Plug
|
|||
}
|
||||
}
|
||||
|
||||
// TestRunTestPlugin runs the testFunc which has already been registered to the
|
||||
// plugin catalog and returns a pluginClient. This can be called after calling
|
||||
// TestAddTestPlugin.
|
||||
func TestRunTestPlugin(t testing.T, c *Core, pluginType consts.PluginType, pluginName string) *pluginClient {
|
||||
t.Helper()
|
||||
config := TestPluginClientConfig(c, pluginType, pluginName)
|
||||
client, err := c.pluginCatalog.NewPluginClient(context.Background(), config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return client
|
||||
}
|
||||
|
||||
func TestPluginClientConfig(c *Core, pluginType consts.PluginType, pluginName string) pluginutil.PluginClientConfig {
|
||||
switch pluginType {
|
||||
case consts.PluginTypeCredential, consts.PluginTypeSecrets:
|
||||
dsv := TestDynamicSystemView(c, nil)
|
||||
return pluginutil.PluginClientConfig{
|
||||
Name: pluginName,
|
||||
PluginType: pluginType,
|
||||
PluginSets: backendplugin.PluginSet,
|
||||
HandshakeConfig: backendplugin.HandshakeConfig,
|
||||
Logger: log.NewNullLogger(),
|
||||
AutoMTLS: true,
|
||||
IsMetadataMode: false,
|
||||
Wrapper: dsv,
|
||||
}
|
||||
case consts.PluginTypeDatabase:
|
||||
return pluginutil.PluginClientConfig{
|
||||
Name: pluginName,
|
||||
PluginType: pluginType,
|
||||
PluginSets: v5.PluginSets,
|
||||
HandshakeConfig: v5.HandshakeConfig,
|
||||
Logger: log.NewNullLogger(),
|
||||
AutoMTLS: true,
|
||||
IsMetadataMode: false,
|
||||
}
|
||||
}
|
||||
return pluginutil.PluginClientConfig{}
|
||||
}
|
||||
|
||||
var (
|
||||
testLogicalBackends = map[string]logical.Factory{}
|
||||
testCredentialBackends = map[string]logical.Factory{}
|
||||
|
|
|
@ -141,7 +141,7 @@ DONELISTHANDLING:
|
|||
NamespaceID: ns.ID,
|
||||
}
|
||||
|
||||
if err := c.tokenStore.create(ctx, &te); err != nil {
|
||||
if err := c.CreateToken(ctx, &te); err != nil {
|
||||
c.logger.Error("failed to create wrapping token", "error", err)
|
||||
return nil, ErrInternalError
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue