From 1553c310c40b6471cac7e7d2948fc822451502bd Mon Sep 17 00:00:00 2001 From: John-Michael Faircloth Date: Tue, 14 Mar 2023 09:36:37 -0500 Subject: [PATCH] Fix a possible data race with rollback manager and plugin reload (#19468) * fix data race on plugin reload * add changelog * add comment for posterity * revert comment and return assignment in router.go * rework plugin continue on error tests to use compilePlugin * fix race condition on route entry * add test for plugin reload and rollback race detection * add go doc for test --- changelog/19468.txt | 3 + .../plugin/external_plugin_test.go | 263 +++++++++++++++--- vault/external_tests/plugin/plugin_test.go | 160 ----------- vault/router.go | 7 +- 4 files changed, 227 insertions(+), 206 deletions(-) create mode 100644 changelog/19468.txt diff --git a/changelog/19468.txt b/changelog/19468.txt new file mode 100644 index 000000000..5afce90eb --- /dev/null +++ b/changelog/19468.txt @@ -0,0 +1,3 @@ +```release-note:bug +plugin/reload: Fix a possible data race with rollback manager and plugin reload +``` diff --git a/vault/external_tests/plugin/external_plugin_test.go b/vault/external_tests/plugin/external_plugin_test.go index 5df47861a..cdb6da225 100644 --- a/vault/external_tests/plugin/external_plugin_test.go +++ b/vault/external_tests/plugin/external_plugin_test.go @@ -3,13 +3,18 @@ package plugin_test import ( "context" "fmt" + "os" + "path/filepath" "testing" + "time" "github.com/hashicorp/vault/api" "github.com/hashicorp/vault/api/auth/approle" "github.com/hashicorp/vault/builtin/logical/database" + "github.com/hashicorp/vault/helper/namespace" "github.com/hashicorp/vault/helper/testhelpers/consul" "github.com/hashicorp/vault/helper/testhelpers/corehelpers" + "github.com/hashicorp/vault/helper/testhelpers/pluginhelpers" postgreshelper "github.com/hashicorp/vault/helper/testhelpers/postgresql" vaulthttp "github.com/hashicorp/vault/http" "github.com/hashicorp/vault/sdk/helper/consts" @@ -30,6 +35,7 @@ func getCluster(t *testing.T, typ consts.PluginType, numCores int) *vault.TestCl } cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{ + TempDir: pluginDir, NumCores: numCores, Plugins: &vault.TestPluginConfig{ Typ: typ, @@ -44,6 +50,208 @@ func getCluster(t *testing.T, typ consts.PluginType, numCores int) *vault.TestCl return cluster } +// TestExternalPlugin_RollbackAndReload ensures that we can successfully +// rollback and reload a plugin without triggering race conditions by the go +// race detector +func TestExternalPlugin_RollbackAndReload(t *testing.T) { + pluginDir, cleanup := corehelpers.MakeTestPluginDir(t) + t.Cleanup(func() { cleanup(t) }) + coreConfig := &vault.CoreConfig{ + // set rollback period to a short interval to make conditions more "racy" + RollbackPeriod: 1 * time.Second, + PluginDirectory: pluginDir, + } + + cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{ + TempDir: pluginDir, + NumCores: 1, + Plugins: &vault.TestPluginConfig{ + Typ: consts.PluginTypeSecrets, + Versions: []string{""}, + }, + HandlerFunc: vaulthttp.Handler, + }) + + cluster.Start() + vault.TestWaitActive(t, cluster.Cores[0].Core) + + core := cluster.Cores[0] + plugin := cluster.Plugins[0] + client := core.Client + client.SetToken(cluster.RootToken) + + testRegisterAndEnable(t, client, plugin) + if _, err := client.Sys().ReloadPlugin(&api.ReloadPluginInput{ + Plugin: plugin.Name, + }); err != nil { + t.Fatal(err) + } +} + +func testRegisterAndEnable(t *testing.T, client *api.Client, plugin pluginhelpers.TestPlugin) { + t.Helper() + if err := client.Sys().RegisterPlugin(&api.RegisterPluginInput{ + Name: plugin.Name, + Type: api.PluginType(plugin.Typ), + Command: plugin.Name, + SHA256: plugin.Sha256, + Version: plugin.Version, + }); err != nil { + t.Fatal(err) + } + + switch plugin.Typ { + case consts.PluginTypeSecrets: + if err := client.Sys().Mount(plugin.Name, &api.MountInput{ + Type: plugin.Name, + }); err != nil { + t.Fatal(err) + } + case consts.PluginTypeCredential: + if err := client.Sys().EnableAuthWithOptions(plugin.Name, &api.EnableAuthOptions{ + Type: plugin.Name, + }); err != nil { + t.Fatal(err) + } + } +} + +// TestExternalPlugin_ContinueOnError tests that vault can recover from a +// sha256 mismatch or missing plugin binary scenario +func TestExternalPlugin_ContinueOnError(t *testing.T) { + t.Run("secret", func(t *testing.T) { + t.Parallel() + t.Run("sha256_mismatch", func(t *testing.T) { + t.Parallel() + testExternalPlugin_ContinueOnError(t, true, consts.PluginTypeSecrets) + }) + + t.Run("missing_plugin", func(t *testing.T) { + t.Parallel() + testExternalPlugin_ContinueOnError(t, false, consts.PluginTypeSecrets) + }) + }) + + t.Run("auth", func(t *testing.T) { + t.Parallel() + t.Run("sha256_mismatch", func(t *testing.T) { + t.Parallel() + testExternalPlugin_ContinueOnError(t, true, consts.PluginTypeCredential) + }) + + t.Run("missing_plugin", func(t *testing.T) { + t.Parallel() + testExternalPlugin_ContinueOnError(t, false, consts.PluginTypeCredential) + }) + }) +} + +func testExternalPlugin_ContinueOnError(t *testing.T, mismatch bool, pluginType consts.PluginType) { + cluster := getCluster(t, pluginType, 1) + defer cluster.Cleanup() + + core := cluster.Cores[0] + plugin := cluster.Plugins[0] + client := core.Client + client.SetToken(cluster.RootToken) + + testRegisterAndEnable(t, client, plugin) + pluginPath := fmt.Sprintf("sys/plugins/catalog/%s/%s", pluginType, plugin.Name) + // Get the registered plugin + req := logical.TestRequest(t, logical.ReadOperation, pluginPath) + req.ClientToken = core.Client.Token() + resp, err := core.HandleRequest(namespace.RootContext(testCtx), req) + if err != nil || resp == nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%v resp:%#v", err, resp) + } + + command, ok := resp.Data["command"].(string) + if !ok || command == "" { + t.Fatal("invalid command") + } + + // Trigger a sha256 mismatch or missing plugin error + if mismatch { + req = logical.TestRequest(t, logical.UpdateOperation, pluginPath) + req.Data = map[string]interface{}{ + "sha256": "d17bd7334758e53e6fbab15745d2520765c06e296f2ce8e25b7919effa0ac216", + "command": filepath.Base(command), + } + req.ClientToken = core.Client.Token() + resp, err = core.HandleRequest(namespace.RootContext(testCtx), req) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%v resp:%#v", err, resp) + } + } else { + err := os.Remove(filepath.Join(cluster.TempDir, filepath.Base(command))) + if err != nil { + t.Fatal(err) + } + } + + // Seal the cluster + cluster.EnsureCoresSealed(t) + + // Unseal the cluster + barrierKeys := cluster.BarrierKeys + for _, core := range cluster.Cores { + for _, key := range barrierKeys { + _, err := core.Unseal(vault.TestKeyCopy(key)) + if err != nil { + t.Fatal(err) + } + } + if core.Sealed() { + t.Fatal("should not be sealed") + } + } + + // Wait for active so post-unseal takes place + // If it fails, it means unseal process failed + vault.TestWaitActive(t, core.Core) + + // unmount + switch pluginType { + case consts.PluginTypeSecrets: + if err := client.Sys().Unmount(plugin.Name); err != nil { + t.Fatal(err) + } + case consts.PluginTypeCredential: + if err := client.Sys().DisableAuth(plugin.Name); err != nil { + t.Fatal(err) + } + } + + // Re-compile plugin + var plugins []pluginhelpers.TestPlugin + plugins = append(plugins, pluginhelpers.CompilePlugin(t, pluginType, "", core.CoreConfig.PluginDirectory)) + cluster.Plugins = plugins + + // Re-add the plugin to the catalog + testRegisterAndEnable(t, client, plugin) + + // Reload the plugin + req = logical.TestRequest(t, logical.UpdateOperation, "sys/plugins/reload/backend") + req.Data = map[string]interface{}{ + "plugin": plugin.Name, + } + req.ClientToken = core.Client.Token() + resp, err = core.HandleRequest(namespace.RootContext(testCtx), req) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%v resp:%#v", err, resp) + } + + req = logical.TestRequest(t, logical.ReadOperation, pluginPath) + req.ClientToken = core.Client.Token() + resp, err = core.HandleRequest(namespace.RootContext(testCtx), req) + if err != nil { + t.Fatalf("err: %v", err) + } + if resp == nil { + t.Fatalf("bad: response should not be nil") + } +} + // TestExternalPlugin_AuthMethod tests that we can build, register and use an // external auth method func TestExternalPlugin_AuthMethod(t *testing.T) { @@ -173,28 +381,10 @@ func TestExternalPlugin_AuthMethodReload(t *testing.T) { client := cluster.Cores[0].Client client.SetToken(cluster.RootToken) - // Register - if err := client.Sys().RegisterPlugin(&api.RegisterPluginInput{ - Name: plugin.Name, - Type: api.PluginType(plugin.Typ), - Command: plugin.Name, - SHA256: plugin.Sha256, - Version: plugin.Version, - }); err != nil { - t.Fatal(err) - } - - pluginPath := fmt.Sprintf("%s-%d", plugin.Name, 0) - - // Enable - if err := client.Sys().EnableAuthWithOptions(pluginPath, &api.EnableAuthOptions{ - Type: plugin.Name, - }); err != nil { - t.Fatal(err) - } + testRegisterAndEnable(t, client, plugin) // Configure - _, err := client.Logical().Write("auth/"+pluginPath+"/role/role1", map[string]interface{}{ + _, err := client.Logical().Write("auth/"+plugin.Name+"/role/role1", map[string]interface{}{ "bind_secret_id": "true", "period": "300", }) @@ -202,13 +392,13 @@ func TestExternalPlugin_AuthMethodReload(t *testing.T) { t.Fatal(err) } - secret, err := client.Logical().Write("auth/"+pluginPath+"/role/role1/secret-id", nil) + secret, err := client.Logical().Write("auth/"+plugin.Name+"/role/role1/secret-id", nil) if err != nil { t.Fatal(err) } secretID := secret.Data["secret_id"].(string) - secret, err = client.Logical().Read("auth/" + pluginPath + "/role/role1/role-id") + secret, err = client.Logical().Read("auth/" + plugin.Name + "/role/role1/role-id") if err != nil { t.Fatal(err) } @@ -218,7 +408,7 @@ func TestExternalPlugin_AuthMethodReload(t *testing.T) { authMethod, err := approle.NewAppRoleAuth( roleID, &approle.SecretID{FromString: secretID}, - approle.WithMountPath(pluginPath), + approle.WithMountPath(plugin.Name), ) if err != nil { t.Fatal(err) @@ -346,30 +536,13 @@ func TestExternalPlugin_SecretsEngineReload(t *testing.T) { client := cluster.Cores[0].Client client.SetToken(cluster.RootToken) - // Register - if err := client.Sys().RegisterPlugin(&api.RegisterPluginInput{ - Name: plugin.Name, - Type: api.PluginType(plugin.Typ), - Command: plugin.Name, - SHA256: plugin.Sha256, - Version: plugin.Version, - }); err != nil { - t.Fatal(err) - } - - pluginPath := fmt.Sprintf("%s-%d", plugin.Name, 0) - // Enable - if err := client.Sys().Mount(pluginPath, &api.MountInput{ - Type: plugin.Name, - }); err != nil { - t.Fatal(err) - } + testRegisterAndEnable(t, client, plugin) // Configure cleanupConsul, consulConfig := consul.PrepareTestContainer(t, "", false, true) defer cleanupConsul() - _, err := client.Logical().Write(pluginPath+"/config/access", map[string]interface{}{ + _, err := client.Logical().Write(plugin.Name+"/config/access", map[string]interface{}{ "address": consulConfig.Address(), "token": consulConfig.Token, }) @@ -377,7 +550,7 @@ func TestExternalPlugin_SecretsEngineReload(t *testing.T) { t.Fatal(err) } - _, err = client.Logical().Write(pluginPath+"/roles/test", map[string]interface{}{ + _, err = client.Logical().Write(plugin.Name+"/roles/test", map[string]interface{}{ "consul_policies": []string{"test"}, "ttl": "6h", "local": false, @@ -386,7 +559,7 @@ func TestExternalPlugin_SecretsEngineReload(t *testing.T) { t.Fatal(err) } - resp, err := client.Logical().Read(pluginPath + "/creds/test") + resp, err := client.Logical().Read(plugin.Name + "/creds/test") if err != nil { t.Fatal(err) } @@ -401,7 +574,7 @@ func TestExternalPlugin_SecretsEngineReload(t *testing.T) { t.Fatal(err) } - resp, err = client.Logical().Read(pluginPath + "/creds/test") + resp, err = client.Logical().Read(plugin.Name + "/creds/test") if err != nil { t.Fatal(err) } diff --git a/vault/external_tests/plugin/plugin_test.go b/vault/external_tests/plugin/plugin_test.go index c0380bf56..6c33c2615 100644 --- a/vault/external_tests/plugin/plugin_test.go +++ b/vault/external_tests/plugin/plugin_test.go @@ -360,166 +360,6 @@ func testPlugin_CatalogRemoved(t *testing.T, btype logical.BackendType, testMoun } } -func TestSystemBackend_Plugin_continueOnError(t *testing.T) { - t.Run("secret", func(t *testing.T) { - t.Parallel() - t.Run("sha256_mismatch", func(t *testing.T) { - t.Parallel() - testPlugin_continueOnError(t, logical.TypeLogical, true, "mock-plugin", consts.PluginTypeSecrets) - }) - - t.Run("missing_plugin", func(t *testing.T) { - t.Parallel() - testPlugin_continueOnError(t, logical.TypeLogical, false, "mock-plugin", consts.PluginTypeSecrets) - }) - }) - - t.Run("auth", func(t *testing.T) { - t.Parallel() - t.Run("sha256_mismatch", func(t *testing.T) { - t.Parallel() - testPlugin_continueOnError(t, logical.TypeCredential, true, "mock-plugin", consts.PluginTypeCredential) - }) - - t.Run("missing_plugin", func(t *testing.T) { - t.Parallel() - testPlugin_continueOnError(t, logical.TypeCredential, false, "mock-plugin", consts.PluginTypeCredential) - }) - - t.Run("sha256_mismatch", func(t *testing.T) { - t.Parallel() - testPlugin_continueOnError(t, logical.TypeCredential, true, "oidc", consts.PluginTypeCredential) - }) - - t.Run("missing_plugin", func(t *testing.T) { - t.Parallel() - testPlugin_continueOnError(t, logical.TypeCredential, false, "oidc", consts.PluginTypeCredential) - }) - }) -} - -func testPlugin_continueOnError(t *testing.T, btype logical.BackendType, mismatch bool, mountPoint string, pluginType consts.PluginType) { - testCases := []struct { - pluginVersion string - }{ - { - pluginVersion: "v5_multiplexed", - }, - { - pluginVersion: "v5", - }, - { - pluginVersion: "v4", - }, - } - - for _, tc := range testCases { - t.Run(tc.pluginVersion, func(t *testing.T) { - t.Parallel() - cluster := testSystemBackendMock(t, 1, 1, btype, tc.pluginVersion) - defer cluster.Cleanup() - - core := cluster.Cores[0] - - // Get the registered plugin - req := logical.TestRequest(t, logical.ReadOperation, fmt.Sprintf("sys/plugins/catalog/%s/mock-plugin", pluginType)) - // We are using the mock backend from vault/sdk/plugin/mock/backend.go which sets the plugin version. - req.Data["version"] = "v0.0.0+mock" - req.ClientToken = core.Client.Token() - resp, err := core.HandleRequest(namespace.RootContext(testCtx), req) - if err != nil || resp == nil || (resp != nil && resp.IsError()) { - t.Fatalf("err:%v resp:%#v", err, resp) - } - - command, ok := resp.Data["command"].(string) - if !ok || command == "" { - t.Fatal("invalid command") - } - - // Trigger a sha256 mismatch or missing plugin error - if mismatch { - req = logical.TestRequest(t, logical.UpdateOperation, fmt.Sprintf("sys/plugins/catalog/%s/mock-plugin", pluginType)) - req.Data = map[string]interface{}{ - "sha256": "d17bd7334758e53e6fbab15745d2520765c06e296f2ce8e25b7919effa0ac216", - "command": filepath.Base(command), - } - req.ClientToken = core.Client.Token() - resp, err = core.HandleRequest(namespace.RootContext(testCtx), req) - if err != nil || (resp != nil && resp.IsError()) { - t.Fatalf("err:%v resp:%#v", err, resp) - } - } else { - err := os.Remove(filepath.Join(cluster.TempDir, filepath.Base(command))) - if err != nil { - t.Fatal(err) - } - } - - // Seal the cluster - cluster.EnsureCoresSealed(t) - - // Unseal the cluster - barrierKeys := cluster.BarrierKeys - for _, core := range cluster.Cores { - for _, key := range barrierKeys { - _, err := core.Unseal(vault.TestKeyCopy(key)) - if err != nil { - t.Fatal(err) - } - } - if core.Sealed() { - t.Fatal("should not be sealed") - } - } - - // Wait for active so post-unseal takes place - // If it fails, it means unseal process failed - vault.TestWaitActive(t, core.Core) - - env := []string{pluginutil.PluginCACertPEMEnv + "=" + cluster.CACertPEMFile} - // Re-add the plugin to the catalog - switch btype { - case logical.TypeLogical: - plugin := logicalVersionMap[tc.pluginVersion] - vault.TestAddTestPlugin(t, core.Core, "mock-plugin", consts.PluginTypeSecrets, "", plugin, env, cluster.TempDir) - case logical.TypeCredential: - plugin := credentialVersionMap[tc.pluginVersion] - vault.TestAddTestPlugin(t, core.Core, "mock-plugin", consts.PluginTypeCredential, "", plugin, env, cluster.TempDir) - } - - // Reload the plugin - req = logical.TestRequest(t, logical.UpdateOperation, "sys/plugins/reload/backend") - req.Data = map[string]interface{}{ - "plugin": "mock-plugin", - } - req.ClientToken = core.Client.Token() - resp, err = core.HandleRequest(namespace.RootContext(testCtx), req) - if err != nil || (resp != nil && resp.IsError()) { - t.Fatalf("err:%v resp:%#v", err, resp) - } - - // Make a request to lazy load the plugin - var reqPath string - switch btype { - case logical.TypeLogical: - reqPath = "mock-0/internal" - case logical.TypeCredential: - reqPath = "auth/mock-0/internal" - } - - req = logical.TestRequest(t, logical.ReadOperation, reqPath) - req.ClientToken = core.Client.Token() - resp, err = core.HandleRequest(namespace.RootContext(testCtx), req) - if err != nil { - t.Fatalf("err: %v", err) - } - if resp == nil { - t.Fatalf("bad: response should not be nil") - } - }) - } -} - func TestSystemBackend_Plugin_autoReload(t *testing.T) { t.Parallel() testCases := []struct { diff --git a/vault/router.go b/vault/router.go index 211c054ef..fe9ab26a1 100644 --- a/vault/router.go +++ b/vault/router.go @@ -462,7 +462,12 @@ func (r *Router) MatchingBackend(ctx context.Context, path string) logical.Backe if !ok { return nil } - return raw.(*routeEntry).backend + + re := raw.(*routeEntry) + re.l.RLock() + defer re.l.RUnlock() + + return re.backend } // MatchingSystemView returns the SystemView used for a path