Plugins: Tighten requirements for multiplexing (#17403)

Change the multiplexing key to use all `PluginRunner` config (converted to a struct which is comparable), so that plugins with the same name but different env, args, types, versions etc are not incorrectly multiplexed together.

Co-authored-by: Christopher Swenson <christopher.swenson@hashicorp.com>
This commit is contained in:
Tom Proctor 2022-10-05 09:29:29 +01:00 committed by GitHub
parent 138c516498
commit 3aa2fe8d8f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 375 additions and 101 deletions

View File

@ -2,11 +2,14 @@ package mock
import ( import (
"context" "context"
"os"
"github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/sdk/logical"
) )
const MockPluginVersionEnv = "TESTING_MOCK_VAULT_PLUGIN_VERSION"
// New returns a new backend as an interface. This func // New returns a new backend as an interface. This func
// is only necessary for builtin backend plugins. // is only necessary for builtin backend plugins.
func New() (interface{}, error) { func New() (interface{}, error) {
@ -60,6 +63,9 @@ func Backend() *backend {
} }
b.internal = "bar" b.internal = "bar"
b.RunningVersion = "v0.0.0+mock" b.RunningVersion = "v0.0.0+mock"
if version := os.Getenv(MockPluginVersionEnv); version != "" {
b.RunningVersion = version
}
return &b return &b
} }

View File

@ -18,24 +18,42 @@ import (
"github.com/hashicorp/vault/sdk/helper/consts" "github.com/hashicorp/vault/sdk/helper/consts"
"github.com/hashicorp/vault/sdk/helper/pluginutil" "github.com/hashicorp/vault/sdk/helper/pluginutil"
"github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/sdk/plugin"
"github.com/hashicorp/vault/sdk/plugin/mock"
) )
const vaultTestingMockPluginEnv = "VAULT_TESTING_MOCK_PLUGIN"
var ( var (
pluginCacheLock sync.Mutex pluginCacheLock sync.Mutex
pluginCache = map[string][]byte{} pluginCache = map[string][]byte{}
) )
type testPlugin struct {
name string
typ consts.PluginType
version string
fileName string
sha256 string
}
// version is used to override the plugin's self-reported version // version is used to override the plugin's self-reported version
func testCoreWithPlugin(t *testing.T, typ consts.PluginType, version string) (*Core, string, string) { func testCoreWithPlugins(t *testing.T, typ consts.PluginType, versions ...string) (*Core, []testPlugin) {
t.Helper() t.Helper()
pluginName, pluginSHA256, pluginDir := compilePlugin(t, typ, version) pluginDir, cleanup := MakeTestPluginDir(t)
t.Cleanup(func() { cleanup(t) })
var plugins []testPlugin
for _, version := range versions {
plugins = append(plugins, compilePlugin(t, typ, version, pluginDir))
}
conf := &CoreConfig{ conf := &CoreConfig{
BuiltinRegistry: NewMockBuiltinRegistry(), BuiltinRegistry: NewMockBuiltinRegistry(),
PluginDirectory: pluginDir, PluginDirectory: pluginDir,
} }
core := TestCoreWithSealAndUI(t, conf) core := TestCoreWithSealAndUI(t, conf)
core, _, _ = testCoreUnsealed(t, core) core, _, _ = testCoreUnsealed(t, core)
return core, pluginName, pluginSHA256 return core, plugins
} }
func getPlugin(t *testing.T, typ consts.PluginType) (string, string, string, string) { func getPlugin(t *testing.T, typ consts.PluginType) (string, string, string, string) {
@ -69,7 +87,7 @@ func getPlugin(t *testing.T, typ consts.PluginType) (string, string, string, str
// to mount a plugin, we need a working binary plugin, so we compile one here. // to mount a plugin, we need a working binary plugin, so we compile one here.
// pluginVersion is used to override the plugin's self-reported version // pluginVersion is used to override the plugin's self-reported version
func compilePlugin(t *testing.T, typ consts.PluginType, pluginVersion string) (pluginName string, shasum string, pluginDir string) { func compilePlugin(t *testing.T, typ consts.PluginType, pluginVersion string, pluginDir string) testPlugin {
t.Helper() t.Helper()
pluginName, pluginType, pluginMain, pluginVersionLocation := getPlugin(t, typ) pluginName, pluginType, pluginMain, pluginVersionLocation := getPlugin(t, typ)
@ -89,10 +107,10 @@ func compilePlugin(t *testing.T, typ consts.PluginType, pluginVersion string) (p
dir = filepath.Dir(wd) dir = filepath.Dir(wd)
} }
pluginDir, cleanup := MakeTestPluginDir(t)
t.Cleanup(func() { cleanup(t) })
pluginPath := path.Join(pluginDir, pluginName) pluginPath := path.Join(pluginDir, pluginName)
if pluginVersion != "" {
pluginPath += "-" + pluginVersion
}
key := fmt.Sprintf("%s %s %s", pluginName, pluginType, pluginVersion) key := fmt.Sprintf("%s %s %s", pluginName, pluginType, pluginVersion)
// cache the compilation to only run once // cache the compilation to only run once
@ -132,7 +150,13 @@ func compilePlugin(t *testing.T, typ consts.PluginType, pluginVersion string) (p
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
return pluginName, fmt.Sprintf("%x", sha.Sum(nil)), pluginDir return testPlugin{
name: pluginName,
typ: typ,
version: pluginVersion,
fileName: path.Base(pluginPath),
sha256: fmt.Sprintf("%x", sha.Sum(nil)),
}
} }
func TestCore_EnableExternalPlugin(t *testing.T) { func TestCore_EnableExternalPlugin(t *testing.T) {
@ -153,10 +177,10 @@ func TestCore_EnableExternalPlugin(t *testing.T) {
}, },
} { } {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
c, pluginName, pluginSHA256 := testCoreWithPlugin(t, tc.pluginType, "") c, plugins := testCoreWithPlugins(t, tc.pluginType, "")
registerPlugin(t, c.systemBackend, pluginName, tc.pluginType.String(), "1.0.0", pluginSHA256) registerPlugin(t, c.systemBackend, plugins[0].name, tc.pluginType.String(), "1.0.0", plugins[0].sha256, plugins[0].fileName)
mountPlugin(t, c.systemBackend, pluginName, tc.pluginType, "v1.0.0") mountPlugin(t, c.systemBackend, plugins[0].name, tc.pluginType, "v1.0.0", "")
match := c.router.MatchingMount(namespace.RootContext(nil), tc.routerPath) match := c.router.MatchingMount(namespace.RootContext(nil), tc.routerPath)
if match != tc.expectedMatch { if match != tc.expectedMatch {
@ -225,12 +249,12 @@ func TestCore_EnableExternalPlugin_MultipleVersions(t *testing.T) {
}, },
} { } {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
c, pluginName, pluginSHA256 := testCoreWithPlugin(t, tc.pluginType, "") c, plugins := testCoreWithPlugins(t, tc.pluginType, "")
for _, version := range tc.registerVersions { for _, version := range tc.registerVersions {
registerPlugin(t, c.systemBackend, pluginName, tc.pluginType.String(), version, pluginSHA256) registerPlugin(t, c.systemBackend, plugins[0].name, tc.pluginType.String(), version, plugins[0].sha256, plugins[0].fileName)
} }
mountPlugin(t, c.systemBackend, pluginName, tc.pluginType, tc.mountVersion) mountPlugin(t, c.systemBackend, plugins[0].name, tc.pluginType, tc.mountVersion, "")
match := c.router.MatchingMount(namespace.RootContext(nil), tc.routerPath) match := c.router.MatchingMount(namespace.RootContext(nil), tc.routerPath)
if match != tc.expectedMatch { if match != tc.expectedMatch {
@ -254,13 +278,16 @@ func TestCore_EnableExternalPlugin_MultipleVersions(t *testing.T) {
} }
func TestCore_EnableExternalKv_MultipleVersions(t *testing.T) { func TestCore_EnableExternalKv_MultipleVersions(t *testing.T) {
pluginDir, cleanup := MakeTestPluginDir(t)
t.Cleanup(func() { cleanup(t) })
// new kv plugin can be registered but not mounted // new kv plugin can be registered but not mounted
pluginName, pluginSHA256, pluginDir := compilePlugin(t, consts.PluginTypeSecrets, "v1.2.3") plugin := compilePlugin(t, consts.PluginTypeSecrets, "v1.2.3", pluginDir)
err := os.Link(path.Join(pluginDir, pluginName), path.Join(pluginDir, "kv")) err := os.Link(path.Join(pluginDir, plugin.fileName), path.Join(pluginDir, "kv"))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
pluginName = "kv" pluginName := "kv"
conf := &CoreConfig{ conf := &CoreConfig{
BuiltinRegistry: NewMockBuiltinRegistry(), BuiltinRegistry: NewMockBuiltinRegistry(),
PluginDirectory: pluginDir, PluginDirectory: pluginDir,
@ -268,7 +295,7 @@ func TestCore_EnableExternalKv_MultipleVersions(t *testing.T) {
c := TestCoreWithSealAndUI(t, conf) c := TestCoreWithSealAndUI(t, conf)
c, _, _ = testCoreUnsealed(t, c) c, _, _ = testCoreUnsealed(t, c)
registerPlugin(t, c.systemBackend, pluginName, consts.PluginTypeSecrets.String(), "v1.2.3", pluginSHA256) registerPlugin(t, c.systemBackend, pluginName, consts.PluginTypeSecrets.String(), "v1.2.3", plugin.sha256, plugin.fileName)
req := logical.TestRequest(t, logical.ReadOperation, "plugins/catalog") req := logical.TestRequest(t, logical.ReadOperation, "plugins/catalog")
resp, err := c.systemBackend.HandleRequest(namespace.RootContext(nil), req) resp, err := c.systemBackend.HandleRequest(namespace.RootContext(nil), req)
if err != nil { if err != nil {
@ -279,7 +306,7 @@ func TestCore_EnableExternalKv_MultipleVersions(t *testing.T) {
} }
found := false found := false
for _, plugin := range resp.Data["detailed"].([]pluginutil.VersionedPlugin) { for _, plugin := range resp.Data["detailed"].([]pluginutil.VersionedPlugin) {
if plugin.Name == "kv" && plugin.Version == "v1.2.3" { if plugin.Name == pluginName && plugin.Version == "v1.2.3" {
found = true found = true
break break
} }
@ -304,13 +331,16 @@ func TestCore_EnableExternalKv_MultipleVersions(t *testing.T) {
} }
func TestCore_EnableExternalNoop_MultipleVersions(t *testing.T) { func TestCore_EnableExternalNoop_MultipleVersions(t *testing.T) {
pluginDir, cleanup := MakeTestPluginDir(t)
t.Cleanup(func() { cleanup(t) })
// new noop plugin can be registered but not mounted // new noop plugin can be registered but not mounted
pluginName, pluginSHA256, pluginDir := compilePlugin(t, consts.PluginTypeCredential, "v1.2.3") plugin := compilePlugin(t, consts.PluginTypeCredential, "v1.2.3", pluginDir)
err := os.Link(path.Join(pluginDir, pluginName), path.Join(pluginDir, "noop")) err := os.Link(path.Join(pluginDir, plugin.fileName), path.Join(pluginDir, "noop"))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
pluginName = "noop" pluginName := "noop"
conf := &CoreConfig{ conf := &CoreConfig{
BuiltinRegistry: NewMockBuiltinRegistry(), BuiltinRegistry: NewMockBuiltinRegistry(),
PluginDirectory: pluginDir, PluginDirectory: pluginDir,
@ -318,7 +348,7 @@ func TestCore_EnableExternalNoop_MultipleVersions(t *testing.T) {
c := TestCoreWithSealAndUI(t, conf) c := TestCoreWithSealAndUI(t, conf)
c, _, _ = testCoreUnsealed(t, c) c, _, _ = testCoreUnsealed(t, c)
registerPlugin(t, c.systemBackend, pluginName, consts.PluginTypeCredential.String(), "v1.2.3", pluginSHA256) registerPlugin(t, c.systemBackend, pluginName, consts.PluginTypeCredential.String(), "v1.2.3", plugin.sha256, plugin.fileName)
req := logical.TestRequest(t, logical.ReadOperation, "plugins/catalog") req := logical.TestRequest(t, logical.ReadOperation, "plugins/catalog")
resp, err := c.systemBackend.HandleRequest(namespace.RootContext(nil), req) resp, err := c.systemBackend.HandleRequest(namespace.RootContext(nil), req)
if err != nil { if err != nil {
@ -371,15 +401,15 @@ func TestCore_EnableExternalPlugin_NoVersionsOkay(t *testing.T) {
}, },
} { } {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
c, pluginName, pluginSHA256 := testCoreWithPlugin(t, tc.pluginType, "") c, plugins := testCoreWithPlugins(t, tc.pluginType, "")
// When an unversioned plugin is registered, mounting a plugin with no // When an unversioned plugin is registered, mounting a plugin with no
// version specified should mount the unversioned plugin even if there // version specified should mount the unversioned plugin even if there
// are versioned plugins available. // are versioned plugins available.
for _, version := range []string{"", "v1.0.0"} { for _, version := range []string{"", "v1.0.0"} {
registerPlugin(t, c.systemBackend, pluginName, tc.pluginType.String(), version, pluginSHA256) registerPlugin(t, c.systemBackend, plugins[0].name, tc.pluginType.String(), version, plugins[0].sha256, plugins[0].fileName)
} }
mountPlugin(t, c.systemBackend, pluginName, tc.pluginType, "") mountPlugin(t, c.systemBackend, plugins[0].name, tc.pluginType, "", "")
match := c.router.MatchingMount(namespace.RootContext(nil), tc.routerPath) match := c.router.MatchingMount(namespace.RootContext(nil), tc.routerPath)
if match != tc.expectedMatch { if match != tc.expectedMatch {
@ -412,12 +442,12 @@ func TestCore_EnableExternalCredentialPlugin_NoVersionOnRegister(t *testing.T) {
}, },
} { } {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
c, pluginName, pluginSHA256 := testCoreWithPlugin(t, tc.pluginType, "") c, plugins := testCoreWithPlugins(t, tc.pluginType, "")
registerPlugin(t, c.systemBackend, pluginName, tc.pluginType.String(), "", pluginSHA256) registerPlugin(t, c.systemBackend, plugins[0].name, tc.pluginType.String(), "", plugins[0].sha256, plugins[0].fileName)
req := logical.TestRequest(t, logical.UpdateOperation, mountTable(tc.pluginType)) req := logical.TestRequest(t, logical.UpdateOperation, mountTable(tc.pluginType))
req.Data = map[string]interface{}{ req.Data = map[string]interface{}{
"type": pluginName, "type": plugins[0].name,
"config": map[string]interface{}{ "config": map[string]interface{}{
"plugin_version": "v1.0.0", "plugin_version": "v1.0.0",
}, },
@ -442,13 +472,13 @@ func TestCore_EnableExternalCredentialPlugin_InvalidName(t *testing.T) {
}, },
} { } {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
c, pluginName, pluginSHA256 := testCoreWithPlugin(t, tc.pluginType, "") c, plugins := testCoreWithPlugins(t, tc.pluginType, "")
d := &framework.FieldData{ d := &framework.FieldData{
Raw: map[string]interface{}{ Raw: map[string]interface{}{
"name": pluginName, "name": plugins[0].name,
"sha256": pluginSHA256, "sha256": plugins[0].sha256,
"version": "v1.0.0", "version": "v1.0.0",
"command": pluginName + "xyz", "command": plugins[0].name + "xyz",
}, },
Schema: c.systemBackend.pluginsCatalogCRUDPath().Fields, Schema: c.systemBackend.pluginsCatalogCRUDPath().Fields,
} }
@ -479,13 +509,13 @@ func TestExternalPlugin_getBackendTypeVersion(t *testing.T) {
}, },
} { } {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
c, pluginName, pluginSHA256 := testCoreWithPlugin(t, tc.pluginType, tc.setRunningVersion) c, plugins := testCoreWithPlugins(t, tc.pluginType, tc.setRunningVersion)
registerPlugin(t, c.systemBackend, pluginName, tc.pluginType.String(), tc.setRunningVersion, pluginSHA256) registerPlugin(t, c.systemBackend, plugins[0].name, tc.pluginType.String(), tc.setRunningVersion, plugins[0].sha256, plugins[0].fileName)
shaBytes, _ := hex.DecodeString(pluginSHA256) shaBytes, _ := hex.DecodeString(plugins[0].sha256)
commandFull := filepath.Join(c.pluginCatalog.directory, pluginName) commandFull := filepath.Join(c.pluginCatalog.directory, plugins[0].fileName)
entry := &pluginutil.PluginRunner{ entry := &pluginutil.PluginRunner{
Name: pluginName, Name: plugins[0].name,
Command: commandFull, Command: commandFull,
Args: nil, Args: nil,
Sha256: shaBytes, Sha256: shaBytes,
@ -544,14 +574,14 @@ func TestExternalPlugin_CheckFilePermissions(t *testing.T) {
}, },
} { } {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
c, pluginName, pluginSHA256 := testCoreWithPlugin(t, tc.pluginType, tc.pluginVersion) c, plugins := testCoreWithPlugins(t, tc.pluginType, tc.pluginVersion)
registeredPluginName := fmt.Sprintf(tc.pluginNameFmt, pluginName) registeredPluginName := fmt.Sprintf(tc.pluginNameFmt, plugins[0].name)
// Permissions will be checked once during registration. // Permissions will be checked once during registration.
req := logical.TestRequest(t, logical.UpdateOperation, fmt.Sprintf("plugins/catalog/%s/%s", tc.pluginType.String(), registeredPluginName)) req := logical.TestRequest(t, logical.UpdateOperation, fmt.Sprintf("plugins/catalog/%s/%s", tc.pluginType.String(), registeredPluginName))
req.Data = map[string]interface{}{ req.Data = map[string]interface{}{
"command": pluginName, "command": plugins[0].fileName,
"sha256": pluginSHA256, "sha256": plugins[0].sha256,
"version": tc.pluginVersion, "version": tc.pluginVersion,
} }
resp, err := c.systemBackend.HandleRequest(namespace.RootContext(nil), req) resp, err := c.systemBackend.HandleRequest(namespace.RootContext(nil), req)
@ -583,26 +613,138 @@ func TestExternalPlugin_CheckFilePermissions(t *testing.T) {
} }
} }
func registerPlugin(t *testing.T, sys *SystemBackend, pluginName, pluginType, version, sha string) { func TestExternalPlugin_DifferentVersionsAndArgs_AreNotMultiplexed(t *testing.T) {
env := []string{fmt.Sprintf("%s=yes", vaultTestingMockPluginEnv)}
core, _, _ := TestCoreUnsealed(t)
for i, tc := range []struct {
version string
testName string
}{
{"v1.2.3", "TestBackend_PluginMain_Multiplexed_Logical_v123"},
{"v1.2.4", "TestBackend_PluginMain_Multiplexed_Logical_v124"},
} {
// Register and mount plugins.
TestAddTestPlugin(t, core, "mux-secret", consts.PluginTypeSecrets, tc.version, tc.testName, env, "")
mountPlugin(t, core.systemBackend, "mux-secret", consts.PluginTypeSecrets, tc.version, fmt.Sprintf("foo%d", i))
}
if len(core.pluginCatalog.externalPlugins) != 2 {
t.Fatalf("expected 2 external plugins, but got %d", len(core.pluginCatalog.externalPlugins))
}
}
func TestExternalPlugin_DifferentTypes_AreNotMultiplexed(t *testing.T) {
const version = "v1.2.3"
env := []string{fmt.Sprintf("%s=yes", vaultTestingMockPluginEnv)}
core, _, _ := TestCoreUnsealed(t)
// Register and mount plugins.
TestAddTestPlugin(t, core, "mux-aws", consts.PluginTypeSecrets, version, "TestBackend_PluginMain_Multiplexed_Logical_v123", env, "")
TestAddTestPlugin(t, core, "mux-aws", consts.PluginTypeCredential, version, "TestBackend_PluginMain_Multiplexed_Credential_v123", env, "")
mountPlugin(t, core.systemBackend, "mux-aws", consts.PluginTypeSecrets, version, "")
mountPlugin(t, core.systemBackend, "mux-aws", consts.PluginTypeCredential, version, "")
if len(core.pluginCatalog.externalPlugins) != 2 {
t.Fatalf("expected 2 external plugins, but got %d", len(core.pluginCatalog.externalPlugins))
}
}
func TestExternalPlugin_DifferentEnv_AreNotMultiplexed(t *testing.T) {
const version = "v1.2.3"
baseEnv := []string{
fmt.Sprintf("%s=yes", vaultTestingMockPluginEnv),
}
alteredEnv := []string{
fmt.Sprintf("%s=yes", vaultTestingMockPluginEnv),
"FOO=BAR",
}
core, _, _ := TestCoreUnsealed(t)
// Register and mount plugins.
for i, env := range [][]string{baseEnv, alteredEnv} {
TestAddTestPlugin(t, core, "mux-secret", consts.PluginTypeSecrets, version, "TestBackend_PluginMain_Multiplexed_Logical_v123", env, "")
mountPlugin(t, core.systemBackend, "mux-secret", consts.PluginTypeSecrets, version, fmt.Sprintf("foo%d", i))
}
if len(core.pluginCatalog.externalPlugins) != 2 {
t.Fatalf("expected 2 external plugins, but got %d", len(core.pluginCatalog.externalPlugins))
}
}
// Used to run a mock multiplexed secrets plugin
func TestBackend_PluginMain_Multiplexed_Logical_v123(t *testing.T) {
if os.Getenv(vaultTestingMockPluginEnv) == "" {
return
}
os.Setenv(mock.MockPluginVersionEnv, "v1.2.3")
err := plugin.ServeMultiplex(&plugin.ServeOpts{
BackendFactoryFunc: mock.FactoryType(logical.TypeLogical),
})
if err != nil {
t.Fatal(err)
}
}
// Used to run a mock multiplexed secrets plugin
func TestBackend_PluginMain_Multiplexed_Logical_v124(t *testing.T) {
if os.Getenv(vaultTestingMockPluginEnv) == "" {
return
}
os.Setenv(mock.MockPluginVersionEnv, "v1.2.4")
err := plugin.ServeMultiplex(&plugin.ServeOpts{
BackendFactoryFunc: mock.FactoryType(logical.TypeLogical),
})
if err != nil {
t.Fatal(err)
}
}
// Used to run a mock multiplexed auth plugin
func TestBackend_PluginMain_Multiplexed_Credential_v123(t *testing.T) {
if os.Getenv(vaultTestingMockPluginEnv) == "" {
return
}
os.Setenv(mock.MockPluginVersionEnv, "v1.2.3")
err := plugin.ServeMultiplex(&plugin.ServeOpts{
BackendFactoryFunc: mock.FactoryType(logical.TypeCredential),
})
if err != nil {
t.Fatal(err)
}
}
func registerPlugin(t *testing.T, sys *SystemBackend, pluginName, pluginType, version, sha, command string) {
t.Helper() t.Helper()
req := logical.TestRequest(t, logical.UpdateOperation, fmt.Sprintf("plugins/catalog/%s/%s", pluginType, pluginName)) req := logical.TestRequest(t, logical.UpdateOperation, fmt.Sprintf("plugins/catalog/%s/%s", pluginType, pluginName))
req.Data = map[string]interface{}{ req.Data = map[string]interface{}{
"command": pluginName, "command": command,
"sha256": sha, "sha256": sha,
"version": version, "version": version,
} }
resp, err := sys.HandleRequest(namespace.RootContext(nil), req) resp, err := sys.HandleRequest(namespace.RootContext(nil), req)
if err != nil { if err != nil || (resp != nil && resp.IsError()) {
t.Fatal(err) t.Fatalf("err:%v resp:%#v", err, resp)
}
if resp.Error() != nil {
t.Fatal(resp.Error())
} }
} }
func mountPlugin(t *testing.T, sys *SystemBackend, pluginName string, pluginType consts.PluginType, version string) { func mountPlugin(t *testing.T, sys *SystemBackend, pluginName string, pluginType consts.PluginType, version, path string) {
t.Helper() t.Helper()
req := logical.TestRequest(t, logical.UpdateOperation, mountTable(pluginType)) var mountPath string
if path == "" {
mountPath = mountTable(pluginType)
} else {
mountPath = mountTableWithPath(consts.PluginTypeSecrets, path)
}
req := logical.TestRequest(t, logical.UpdateOperation, mountPath)
req.Data = map[string]interface{}{ req.Data = map[string]interface{}{
"type": pluginName, "type": pluginName,
} }
@ -612,20 +754,21 @@ func mountPlugin(t *testing.T, sys *SystemBackend, pluginName string, pluginType
} }
} }
resp, err := sys.HandleRequest(namespace.RootContext(nil), req) resp, err := sys.HandleRequest(namespace.RootContext(nil), req)
if err != nil { if err != nil || (resp != nil && resp.IsError()) {
t.Fatal(err) t.Fatalf("err:%v resp:%#v", err, resp)
}
if resp.Error() != nil {
t.Fatal(resp.Error())
} }
} }
func mountTable(pluginType consts.PluginType) string { func mountTable(pluginType consts.PluginType) string {
return mountTableWithPath(pluginType, "foo")
}
func mountTableWithPath(pluginType consts.PluginType, path string) string {
switch pluginType { switch pluginType {
case consts.PluginTypeCredential: case consts.PluginTypeCredential:
return "auth/foo" return "auth/" + path
case consts.PluginTypeSecrets: case consts.PluginTypeSecrets:
return "mounts/foo" return "mounts/" + path
default: default:
panic("test does not support mounting plugin type yet: " + pluginType.String()) panic("test does not support mounting plugin type yet: " + pluginType.String())
} }

View File

@ -45,24 +45,63 @@ type PluginCatalog struct {
directory string directory string
logger log.Logger logger log.Logger
// externalPlugins holds plugin process connections by plugin name // externalPlugins holds plugin process connections by a key which is
// generated from the plugin runner config.
// //
// This allows plugins that suppport multiplexing to use a single grpc // This allows plugins that suppport multiplexing to use a single grpc
// connection to communicate with multiple "backends". Each backend // connection to communicate with multiple "backends". Each backend
// configuration using the same plugin will be routed to the existing // configuration using the same plugin will be routed to the existing
// plugin process. // plugin process.
externalPlugins map[string]*externalPlugin externalPlugins map[externalPluginsKey]*externalPlugin
mlockPlugins bool mlockPlugins bool
lock sync.RWMutex lock sync.RWMutex
} }
// Only plugins running with identical PluginRunner config can be multiplexed,
// so we use the PluginRunner input as the key for the external plugins map.
//
// However, to be a map key, it must be comparable:
// https://go.dev/ref/spec#Comparison_operators.
// In particular, the PluginRunner struct has slices and a function which are not
// comparable, so we need to transform it into a struct which is.
type externalPluginsKey struct {
name string
typ consts.PluginType
version string
command string
args string
env string
sha256 string
builtin bool
}
func makeExternalPluginsKey(p *pluginutil.PluginRunner) (externalPluginsKey, error) {
args, err := json.Marshal(p.Args)
if err != nil {
return externalPluginsKey{}, err
}
env, err := json.Marshal(p.Env)
if err != nil {
return externalPluginsKey{}, err
}
return externalPluginsKey{
name: p.Name,
typ: p.Type,
version: p.Version,
command: p.Command,
args: string(args),
env: string(env),
sha256: hex.EncodeToString(p.Sha256),
builtin: p.Builtin,
}, nil
}
// externalPlugin holds client connections for multiplexed and // externalPlugin holds client connections for multiplexed and
// non-multiplexed plugin processes // non-multiplexed plugin processes
type externalPlugin struct { type externalPlugin struct {
// name is the plugin name
name string
// connections holds client connections by ID // connections holds client connections by ID
connections map[string]*pluginClient connections map[string]*pluginClient
@ -179,13 +218,13 @@ func (p *pluginClient) Reload() error {
// reloadExternalPlugin // reloadExternalPlugin
// This should be called with the write lock held. // This should be called with the write lock held.
func (c *PluginCatalog) reloadExternalPlugin(name, id string) error { func (c *PluginCatalog) reloadExternalPlugin(key externalPluginsKey, id, path string) error {
extPlugin, ok := c.externalPlugins[name] extPlugin, ok := c.externalPlugins[key]
if !ok { if !ok {
return fmt.Errorf("plugin client not found") return fmt.Errorf("plugin client not found")
} }
if !extPlugin.multiplexingSupport { if !extPlugin.multiplexingSupport {
err := c.cleanupExternalPlugin(name, id) err := c.cleanupExternalPlugin(key, id, path)
if err != nil { if err != nil {
return err return err
} }
@ -197,9 +236,9 @@ func (c *PluginCatalog) reloadExternalPlugin(name, id string) error {
return fmt.Errorf("%w id: %s", ErrPluginConnectionNotFound, id) return fmt.Errorf("%w id: %s", ErrPluginConnectionNotFound, id)
} }
delete(c.externalPlugins, name) delete(c.externalPlugins, key)
pc.client.Kill() pc.client.Kill()
c.logger.Debug("killed external plugin process for reload", "name", name, "pid", pc.pid) c.logger.Debug("killed external plugin process for reload", "path", path, "pid", pc.pid)
return nil return nil
} }
@ -215,8 +254,8 @@ func (p *pluginClient) Close() error {
// cleanupExternalPlugin will kill plugin processes and perform any necessary // cleanupExternalPlugin will kill plugin processes and perform any necessary
// cleanup on the externalPlugins map for multiplexed and non-multiplexed // cleanup on the externalPlugins map for multiplexed and non-multiplexed
// plugins. This should be called with the write lock held. // plugins. This should be called with the write lock held.
func (c *PluginCatalog) cleanupExternalPlugin(name, id string) error { func (c *PluginCatalog) cleanupExternalPlugin(key externalPluginsKey, id, path string) error {
extPlugin, ok := c.externalPlugins[name] extPlugin, ok := c.externalPlugins[key]
if !ok { if !ok {
return fmt.Errorf("plugin client not found") return fmt.Errorf("plugin client not found")
} }
@ -236,37 +275,36 @@ func (c *PluginCatalog) cleanupExternalPlugin(name, id string) error {
pc.client.Kill() pc.client.Kill()
if len(extPlugin.connections) == 0 { if len(extPlugin.connections) == 0 {
delete(c.externalPlugins, name) delete(c.externalPlugins, key)
} }
c.logger.Debug("killed external plugin process", "name", name, "pid", pc.pid) c.logger.Debug("killed external plugin process", "path", path, "pid", pc.pid)
} else if len(extPlugin.connections) == 0 || pc.client.Exited() { } else if len(extPlugin.connections) == 0 || pc.client.Exited() {
pc.client.Kill() pc.client.Kill()
delete(c.externalPlugins, name) delete(c.externalPlugins, key)
c.logger.Debug("killed external multiplexed plugin process", "name", name, "pid", pc.pid) c.logger.Debug("killed external multiplexed plugin process", "path", path, "pid", pc.pid)
} }
return nil return nil
} }
func (c *PluginCatalog) getExternalPlugin(pluginName string) *externalPlugin { func (c *PluginCatalog) getExternalPlugin(key externalPluginsKey) *externalPlugin {
if extPlugin, ok := c.externalPlugins[pluginName]; ok { if extPlugin, ok := c.externalPlugins[key]; ok {
return extPlugin return extPlugin
} }
return c.newExternalPlugin(pluginName) return c.newExternalPlugin(key)
} }
func (c *PluginCatalog) newExternalPlugin(pluginName string) *externalPlugin { func (c *PluginCatalog) newExternalPlugin(key externalPluginsKey) *externalPlugin {
if c.externalPlugins == nil { if c.externalPlugins == nil {
c.externalPlugins = make(map[string]*externalPlugin) c.externalPlugins = make(map[externalPluginsKey]*externalPlugin)
} }
extPlugin := &externalPlugin{ extPlugin := &externalPlugin{
connections: make(map[string]*pluginClient), connections: make(map[string]*pluginClient),
name: pluginName,
} }
c.externalPlugins[pluginName] = extPlugin c.externalPlugins[key] = extPlugin
return extPlugin return extPlugin
} }
@ -301,7 +339,12 @@ func (c *PluginCatalog) newPluginClient(ctx context.Context, pluginRunner *plugi
return nil, fmt.Errorf("no plugin found") return nil, fmt.Errorf("no plugin found")
} }
extPlugin := c.getExternalPlugin(pluginRunner.Name) key, err := makeExternalPluginsKey(pluginRunner)
if err != nil {
return nil, err
}
extPlugin := c.getExternalPlugin(key)
id, err := base62.Random(10) id, err := base62.Random(10)
if err != nil { if err != nil {
return nil, err return nil, err
@ -313,12 +356,12 @@ func (c *PluginCatalog) newPluginClient(ctx context.Context, pluginRunner *plugi
cleanupFunc: func() error { cleanupFunc: func() error {
c.lock.Lock() c.lock.Lock()
defer c.lock.Unlock() defer c.lock.Unlock()
return c.cleanupExternalPlugin(pluginRunner.Name, id) return c.cleanupExternalPlugin(key, id, pluginRunner.Command)
}, },
reloadFunc: func() error { reloadFunc: func() error {
c.lock.Lock() c.lock.Lock()
defer c.lock.Unlock() defer c.lock.Unlock()
return c.reloadExternalPlugin(pluginRunner.Name, id) return c.reloadExternalPlugin(key, id, pluginRunner.Command)
}, },
} }
@ -382,7 +425,6 @@ func (c *PluginCatalog) newPluginClient(ctx context.Context, pluginRunner *plugi
pc.ClientProtocol = rpcClient pc.ClientProtocol = rpcClient
extPlugin.connections[id] = pc extPlugin.connections[id] = pc
extPlugin.name = pluginRunner.Name
extPlugin.multiplexingSupport = muxed extPlugin.multiplexingSupport = muxed
return extPlugin.connections[id], nil return extPlugin.connections[id], nil
@ -428,7 +470,17 @@ func (c *PluginCatalog) getBackendPluginType(ctx context.Context, pluginRunner *
pc, err := c.newPluginClient(ctx, pluginRunner, config) pc, err := c.newPluginClient(ctx, pluginRunner, config)
if err == nil { if err == nil {
// we spawned a subprocess, so make sure to clean it up // we spawned a subprocess, so make sure to clean it up
defer c.cleanupExternalPlugin(pluginRunner.Name, pc.id) key, err := makeExternalPluginsKey(pluginRunner)
if err != nil {
return consts.PluginTypeUnknown, err
}
defer func() {
// Close the client and cleanup the plugin process
err = c.cleanupExternalPlugin(key, pc.id, pluginRunner.Command)
if err != nil {
c.logger.Error("error closing plugin client", "error", err)
}
}()
// dispense the plugin so we can get its type // dispense the plugin so we can get its type
client, err = backendplugin.Dispense(pc.ClientProtocol, pc) client, err = backendplugin.Dispense(pc.ClientProtocol, pc)
@ -506,7 +558,17 @@ func (c *PluginCatalog) getBackendRunningVersion(ctx context.Context, pluginRunn
pc, err := c.newPluginClient(ctx, pluginRunner, config) pc, err := c.newPluginClient(ctx, pluginRunner, config)
if err == nil { if err == nil {
// we spawned a subprocess, so make sure to clean it up // we spawned a subprocess, so make sure to clean it up
defer c.cleanupExternalPlugin(pluginRunner.Name, pc.id) key, err := makeExternalPluginsKey(pluginRunner)
if err != nil {
return logical.EmptyPluginVersion, err
}
defer func() {
// Close the client and cleanup the plugin process
err = c.cleanupExternalPlugin(key, pc.id, pluginRunner.Command)
if err != nil {
c.logger.Error("error closing plugin client", "error", err)
}
}()
// dispense the plugin so we can get its version // dispense the plugin so we can get its version
client, err = backendplugin.Dispense(pc.ClientProtocol, pc) client, err = backendplugin.Dispense(pc.ClientProtocol, pc)
@ -565,9 +627,13 @@ func (c *PluginCatalog) getDatabaseRunningVersion(ctx context.Context, pluginRun
c.logger.Debug("attempting to load database plugin as v5", "name", pluginRunner.Name) c.logger.Debug("attempting to load database plugin as v5", "name", pluginRunner.Name)
v5Client, err := c.newPluginClient(ctx, pluginRunner, config) v5Client, err := c.newPluginClient(ctx, pluginRunner, config)
if err == nil { if err == nil {
key, err := makeExternalPluginsKey(pluginRunner)
if err != nil {
return logical.EmptyPluginVersion, err
}
defer func() { defer func() {
// Close the client and cleanup the plugin process // Close the client and cleanup the plugin process
err = c.cleanupExternalPlugin(pluginRunner.Name, v5Client.id) err = c.cleanupExternalPlugin(key, v5Client.id, pluginRunner.Command)
if err != nil { if err != nil {
c.logger.Error("error closing plugin client", "error", err) c.logger.Error("error closing plugin client", "error", err)
} }
@ -624,7 +690,11 @@ func (c *PluginCatalog) isDatabasePlugin(ctx context.Context, pluginRunner *plug
v5Client, err := c.newPluginClient(ctx, pluginRunner, config) v5Client, err := c.newPluginClient(ctx, pluginRunner, config)
if err == nil { if err == nil {
// Close the client and cleanup the plugin process // Close the client and cleanup the plugin process
err = c.cleanupExternalPlugin(pluginRunner.Name, v5Client.id) key, err := makeExternalPluginsKey(pluginRunner)
if err != nil {
return err
}
err = c.cleanupExternalPlugin(key, v5Client.id, pluginRunner.Command)
if err != nil { if err != nil {
c.logger.Error("error closing plugin client", "error", err) c.logger.Error("error closing plugin client", "error", err)
} }

View File

@ -2,6 +2,7 @@ package vault
import ( import (
"context" "context"
"crypto/sha256"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
@ -480,6 +481,23 @@ func TestPluginCatalog_NewPluginClient(t *testing.T) {
TestAddTestPlugin(t, core, "single-userpass-1", consts.PluginTypeUnknown, "", "TestPluginCatalog_PluginMain_Userpass", []string{}, "") TestAddTestPlugin(t, core, "single-userpass-1", consts.PluginTypeUnknown, "", "TestPluginCatalog_PluginMain_Userpass", []string{}, "")
TestAddTestPlugin(t, core, "single-userpass-2", consts.PluginTypeUnknown, "", "TestPluginCatalog_PluginMain_Userpass", []string{}, "") TestAddTestPlugin(t, core, "single-userpass-2", consts.PluginTypeUnknown, "", "TestPluginCatalog_PluginMain_Userpass", []string{}, "")
getKey := func(pluginName string, pluginType consts.PluginType) externalPluginsKey {
t.Helper()
ctx := context.Background()
plugin, err := core.pluginCatalog.Get(ctx, pluginName, pluginType, "")
if err != nil {
t.Fatal(err)
}
if plugin == nil {
t.Fatal("did not find " + pluginName)
}
key, err := makeExternalPluginsKey(plugin)
if err != nil {
t.Fatal(err)
}
return key
}
var pluginClients []*pluginClient var pluginClients []*pluginClient
// run plugins // run plugins
// run "mux-postgres" twice which will start a single plugin for 2 // run "mux-postgres" twice which will start a single plugin for 2
@ -510,20 +528,20 @@ func TestPluginCatalog_NewPluginClient(t *testing.T) {
} }
// check connections map // check connections map
expectConnectionLen(t, 2, externalPlugins["mux-postgres"].connections) expectConnectionLen(t, 2, externalPlugins[getKey("mux-postgres", consts.PluginTypeDatabase)].connections)
expectConnectionLen(t, 1, externalPlugins["single-postgres-1"].connections) expectConnectionLen(t, 1, externalPlugins[getKey("single-postgres-1", consts.PluginTypeDatabase)].connections)
expectConnectionLen(t, 1, externalPlugins["single-postgres-2"].connections) expectConnectionLen(t, 1, externalPlugins[getKey("single-postgres-2", consts.PluginTypeDatabase)].connections)
expectConnectionLen(t, 2, externalPlugins["mux-userpass"].connections) expectConnectionLen(t, 2, externalPlugins[getKey("mux-userpass", consts.PluginTypeCredential)].connections)
expectConnectionLen(t, 1, externalPlugins["single-userpass-1"].connections) expectConnectionLen(t, 1, externalPlugins[getKey("single-userpass-1", consts.PluginTypeCredential)].connections)
expectConnectionLen(t, 1, externalPlugins["single-userpass-2"].connections) expectConnectionLen(t, 1, externalPlugins[getKey("single-userpass-2", consts.PluginTypeCredential)].connections)
// check multiplexing support // check multiplexing support
expectMultiplexingSupport(t, true, externalPlugins["mux-postgres"].multiplexingSupport) expectMultiplexingSupport(t, true, externalPlugins[getKey("mux-postgres", consts.PluginTypeDatabase)].multiplexingSupport)
expectMultiplexingSupport(t, false, externalPlugins["single-postgres-1"].multiplexingSupport) expectMultiplexingSupport(t, false, externalPlugins[getKey("single-postgres-1", consts.PluginTypeDatabase)].multiplexingSupport)
expectMultiplexingSupport(t, false, externalPlugins["single-postgres-2"].multiplexingSupport) expectMultiplexingSupport(t, false, externalPlugins[getKey("single-postgres-2", consts.PluginTypeDatabase)].multiplexingSupport)
expectMultiplexingSupport(t, true, externalPlugins["mux-userpass"].multiplexingSupport) expectMultiplexingSupport(t, true, externalPlugins[getKey("mux-userpass", consts.PluginTypeCredential)].multiplexingSupport)
expectMultiplexingSupport(t, false, externalPlugins["single-userpass-1"].multiplexingSupport) expectMultiplexingSupport(t, false, externalPlugins[getKey("single-userpass-1", consts.PluginTypeCredential)].multiplexingSupport)
expectMultiplexingSupport(t, false, externalPlugins["single-userpass-2"].multiplexingSupport) expectMultiplexingSupport(t, false, externalPlugins[getKey("single-userpass-2", consts.PluginTypeCredential)].multiplexingSupport)
// cleanup all of the external plugin processes // cleanup all of the external plugin processes
for _, client := range pluginClients { for _, client := range pluginClients {
@ -536,6 +554,38 @@ func TestPluginCatalog_NewPluginClient(t *testing.T) {
} }
} }
func TestPluginCatalog_MakeExternalPluginsKey_Comparable(t *testing.T) {
var plugins []pluginutil.PluginRunner
hasher := sha256.New()
hasher.Write([]byte("Some random input"))
for i := 0; i < 2; i++ {
plugins = append(plugins, pluginutil.PluginRunner{
Name: "Name",
Type: consts.PluginTypeDatabase,
Version: "Version",
Command: "Command",
Args: []string{"Some", "Args"},
Env: []string{"Env=foo", "bar=", "baz=foo"},
Sha256: hasher.Sum(nil),
Builtin: true,
})
}
var keys []externalPluginsKey
for _, plugin := range plugins {
key, err := makeExternalPluginsKey(&plugin)
if err != nil {
t.Fatal(err)
}
keys = append(keys, key)
}
if keys[0] != keys[1] {
t.Fatal("expected equality")
}
}
func TestPluginCatalog_PluginMain_Userpass(t *testing.T) { func TestPluginCatalog_PluginMain_Userpass(t *testing.T) {
if os.Getenv(pluginutil.PluginVaultVersionEnv) == "" { if os.Getenv(pluginutil.PluginVaultVersionEnv) == "" {
return return

View File

@ -513,6 +513,11 @@ func TestDynamicSystemView(c *Core, ns *namespace.Namespace) *dynamicSystemView
// TestAddTestPlugin registers the testFunc as part of the plugin command to the // TestAddTestPlugin registers the testFunc as part of the plugin command to the
// plugin catalog. If provided, uses tmpDir as the plugin directory. // plugin catalog. If provided, uses tmpDir as the plugin directory.
// NB: The test func you pass in MUST be in the same package as the parent test,
// or the test func won't be compiled into the test binary being run and the output
// will be something like:
// stderr (ignored by go-plugin): "testing: warning: no tests to run"
// stdout: "PASS"
func TestAddTestPlugin(t testing.T, c *Core, name string, pluginType consts.PluginType, version string, testFunc string, env []string, tempDir string) { func TestAddTestPlugin(t testing.T, c *Core, name string, pluginType consts.PluginType, version string, testFunc string, env []string, tempDir string) {
file, err := os.Open(os.Args[0]) file, err := os.Open(os.Args[0])
if err != nil { if err != nil {