Fix a possible data race with rollback manager and plugin reload (#19468)

* fix data race on plugin reload

* add changelog

* add comment for posterity

* revert comment and return assignment in router.go

* rework plugin continue on error tests to use compilePlugin

* fix race condition on route entry

* add test for plugin reload and rollback race detection

* add go doc for test
This commit is contained in:
John-Michael Faircloth 2023-03-14 09:36:37 -05:00 committed by GitHub
parent f7f19aab3c
commit 1553c310c4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 227 additions and 206 deletions

3
changelog/19468.txt Normal file
View File

@ -0,0 +1,3 @@
```release-note:bug
plugin/reload: Fix a possible data race with rollback manager and plugin reload
```

View File

@ -3,13 +3,18 @@ package plugin_test
import ( import (
"context" "context"
"fmt" "fmt"
"os"
"path/filepath"
"testing" "testing"
"time"
"github.com/hashicorp/vault/api" "github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/api/auth/approle" "github.com/hashicorp/vault/api/auth/approle"
"github.com/hashicorp/vault/builtin/logical/database" "github.com/hashicorp/vault/builtin/logical/database"
"github.com/hashicorp/vault/helper/namespace"
"github.com/hashicorp/vault/helper/testhelpers/consul" "github.com/hashicorp/vault/helper/testhelpers/consul"
"github.com/hashicorp/vault/helper/testhelpers/corehelpers" "github.com/hashicorp/vault/helper/testhelpers/corehelpers"
"github.com/hashicorp/vault/helper/testhelpers/pluginhelpers"
postgreshelper "github.com/hashicorp/vault/helper/testhelpers/postgresql" postgreshelper "github.com/hashicorp/vault/helper/testhelpers/postgresql"
vaulthttp "github.com/hashicorp/vault/http" vaulthttp "github.com/hashicorp/vault/http"
"github.com/hashicorp/vault/sdk/helper/consts" "github.com/hashicorp/vault/sdk/helper/consts"
@ -30,6 +35,7 @@ func getCluster(t *testing.T, typ consts.PluginType, numCores int) *vault.TestCl
} }
cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{ cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{
TempDir: pluginDir,
NumCores: numCores, NumCores: numCores,
Plugins: &vault.TestPluginConfig{ Plugins: &vault.TestPluginConfig{
Typ: typ, Typ: typ,
@ -44,6 +50,208 @@ func getCluster(t *testing.T, typ consts.PluginType, numCores int) *vault.TestCl
return cluster return cluster
} }
// TestExternalPlugin_RollbackAndReload ensures that we can successfully
// rollback and reload a plugin without triggering race conditions by the go
// race detector
func TestExternalPlugin_RollbackAndReload(t *testing.T) {
pluginDir, cleanup := corehelpers.MakeTestPluginDir(t)
t.Cleanup(func() { cleanup(t) })
coreConfig := &vault.CoreConfig{
// set rollback period to a short interval to make conditions more "racy"
RollbackPeriod: 1 * time.Second,
PluginDirectory: pluginDir,
}
cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{
TempDir: pluginDir,
NumCores: 1,
Plugins: &vault.TestPluginConfig{
Typ: consts.PluginTypeSecrets,
Versions: []string{""},
},
HandlerFunc: vaulthttp.Handler,
})
cluster.Start()
vault.TestWaitActive(t, cluster.Cores[0].Core)
core := cluster.Cores[0]
plugin := cluster.Plugins[0]
client := core.Client
client.SetToken(cluster.RootToken)
testRegisterAndEnable(t, client, plugin)
if _, err := client.Sys().ReloadPlugin(&api.ReloadPluginInput{
Plugin: plugin.Name,
}); err != nil {
t.Fatal(err)
}
}
func testRegisterAndEnable(t *testing.T, client *api.Client, plugin pluginhelpers.TestPlugin) {
t.Helper()
if err := client.Sys().RegisterPlugin(&api.RegisterPluginInput{
Name: plugin.Name,
Type: api.PluginType(plugin.Typ),
Command: plugin.Name,
SHA256: plugin.Sha256,
Version: plugin.Version,
}); err != nil {
t.Fatal(err)
}
switch plugin.Typ {
case consts.PluginTypeSecrets:
if err := client.Sys().Mount(plugin.Name, &api.MountInput{
Type: plugin.Name,
}); err != nil {
t.Fatal(err)
}
case consts.PluginTypeCredential:
if err := client.Sys().EnableAuthWithOptions(plugin.Name, &api.EnableAuthOptions{
Type: plugin.Name,
}); err != nil {
t.Fatal(err)
}
}
}
// TestExternalPlugin_ContinueOnError tests that vault can recover from a
// sha256 mismatch or missing plugin binary scenario
func TestExternalPlugin_ContinueOnError(t *testing.T) {
t.Run("secret", func(t *testing.T) {
t.Parallel()
t.Run("sha256_mismatch", func(t *testing.T) {
t.Parallel()
testExternalPlugin_ContinueOnError(t, true, consts.PluginTypeSecrets)
})
t.Run("missing_plugin", func(t *testing.T) {
t.Parallel()
testExternalPlugin_ContinueOnError(t, false, consts.PluginTypeSecrets)
})
})
t.Run("auth", func(t *testing.T) {
t.Parallel()
t.Run("sha256_mismatch", func(t *testing.T) {
t.Parallel()
testExternalPlugin_ContinueOnError(t, true, consts.PluginTypeCredential)
})
t.Run("missing_plugin", func(t *testing.T) {
t.Parallel()
testExternalPlugin_ContinueOnError(t, false, consts.PluginTypeCredential)
})
})
}
func testExternalPlugin_ContinueOnError(t *testing.T, mismatch bool, pluginType consts.PluginType) {
cluster := getCluster(t, pluginType, 1)
defer cluster.Cleanup()
core := cluster.Cores[0]
plugin := cluster.Plugins[0]
client := core.Client
client.SetToken(cluster.RootToken)
testRegisterAndEnable(t, client, plugin)
pluginPath := fmt.Sprintf("sys/plugins/catalog/%s/%s", pluginType, plugin.Name)
// Get the registered plugin
req := logical.TestRequest(t, logical.ReadOperation, pluginPath)
req.ClientToken = core.Client.Token()
resp, err := core.HandleRequest(namespace.RootContext(testCtx), req)
if err != nil || resp == nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
command, ok := resp.Data["command"].(string)
if !ok || command == "" {
t.Fatal("invalid command")
}
// Trigger a sha256 mismatch or missing plugin error
if mismatch {
req = logical.TestRequest(t, logical.UpdateOperation, pluginPath)
req.Data = map[string]interface{}{
"sha256": "d17bd7334758e53e6fbab15745d2520765c06e296f2ce8e25b7919effa0ac216",
"command": filepath.Base(command),
}
req.ClientToken = core.Client.Token()
resp, err = core.HandleRequest(namespace.RootContext(testCtx), req)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
} else {
err := os.Remove(filepath.Join(cluster.TempDir, filepath.Base(command)))
if err != nil {
t.Fatal(err)
}
}
// Seal the cluster
cluster.EnsureCoresSealed(t)
// Unseal the cluster
barrierKeys := cluster.BarrierKeys
for _, core := range cluster.Cores {
for _, key := range barrierKeys {
_, err := core.Unseal(vault.TestKeyCopy(key))
if err != nil {
t.Fatal(err)
}
}
if core.Sealed() {
t.Fatal("should not be sealed")
}
}
// Wait for active so post-unseal takes place
// If it fails, it means unseal process failed
vault.TestWaitActive(t, core.Core)
// unmount
switch pluginType {
case consts.PluginTypeSecrets:
if err := client.Sys().Unmount(plugin.Name); err != nil {
t.Fatal(err)
}
case consts.PluginTypeCredential:
if err := client.Sys().DisableAuth(plugin.Name); err != nil {
t.Fatal(err)
}
}
// Re-compile plugin
var plugins []pluginhelpers.TestPlugin
plugins = append(plugins, pluginhelpers.CompilePlugin(t, pluginType, "", core.CoreConfig.PluginDirectory))
cluster.Plugins = plugins
// Re-add the plugin to the catalog
testRegisterAndEnable(t, client, plugin)
// Reload the plugin
req = logical.TestRequest(t, logical.UpdateOperation, "sys/plugins/reload/backend")
req.Data = map[string]interface{}{
"plugin": plugin.Name,
}
req.ClientToken = core.Client.Token()
resp, err = core.HandleRequest(namespace.RootContext(testCtx), req)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
req = logical.TestRequest(t, logical.ReadOperation, pluginPath)
req.ClientToken = core.Client.Token()
resp, err = core.HandleRequest(namespace.RootContext(testCtx), req)
if err != nil {
t.Fatalf("err: %v", err)
}
if resp == nil {
t.Fatalf("bad: response should not be nil")
}
}
// TestExternalPlugin_AuthMethod tests that we can build, register and use an // TestExternalPlugin_AuthMethod tests that we can build, register and use an
// external auth method // external auth method
func TestExternalPlugin_AuthMethod(t *testing.T) { func TestExternalPlugin_AuthMethod(t *testing.T) {
@ -173,28 +381,10 @@ func TestExternalPlugin_AuthMethodReload(t *testing.T) {
client := cluster.Cores[0].Client client := cluster.Cores[0].Client
client.SetToken(cluster.RootToken) client.SetToken(cluster.RootToken)
// Register testRegisterAndEnable(t, client, plugin)
if err := client.Sys().RegisterPlugin(&api.RegisterPluginInput{
Name: plugin.Name,
Type: api.PluginType(plugin.Typ),
Command: plugin.Name,
SHA256: plugin.Sha256,
Version: plugin.Version,
}); err != nil {
t.Fatal(err)
}
pluginPath := fmt.Sprintf("%s-%d", plugin.Name, 0)
// Enable
if err := client.Sys().EnableAuthWithOptions(pluginPath, &api.EnableAuthOptions{
Type: plugin.Name,
}); err != nil {
t.Fatal(err)
}
// Configure // Configure
_, err := client.Logical().Write("auth/"+pluginPath+"/role/role1", map[string]interface{}{ _, err := client.Logical().Write("auth/"+plugin.Name+"/role/role1", map[string]interface{}{
"bind_secret_id": "true", "bind_secret_id": "true",
"period": "300", "period": "300",
}) })
@ -202,13 +392,13 @@ func TestExternalPlugin_AuthMethodReload(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
secret, err := client.Logical().Write("auth/"+pluginPath+"/role/role1/secret-id", nil) secret, err := client.Logical().Write("auth/"+plugin.Name+"/role/role1/secret-id", nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
secretID := secret.Data["secret_id"].(string) secretID := secret.Data["secret_id"].(string)
secret, err = client.Logical().Read("auth/" + pluginPath + "/role/role1/role-id") secret, err = client.Logical().Read("auth/" + plugin.Name + "/role/role1/role-id")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -218,7 +408,7 @@ func TestExternalPlugin_AuthMethodReload(t *testing.T) {
authMethod, err := approle.NewAppRoleAuth( authMethod, err := approle.NewAppRoleAuth(
roleID, roleID,
&approle.SecretID{FromString: secretID}, &approle.SecretID{FromString: secretID},
approle.WithMountPath(pluginPath), approle.WithMountPath(plugin.Name),
) )
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -346,30 +536,13 @@ func TestExternalPlugin_SecretsEngineReload(t *testing.T) {
client := cluster.Cores[0].Client client := cluster.Cores[0].Client
client.SetToken(cluster.RootToken) client.SetToken(cluster.RootToken)
// Register testRegisterAndEnable(t, client, plugin)
if err := client.Sys().RegisterPlugin(&api.RegisterPluginInput{
Name: plugin.Name,
Type: api.PluginType(plugin.Typ),
Command: plugin.Name,
SHA256: plugin.Sha256,
Version: plugin.Version,
}); err != nil {
t.Fatal(err)
}
pluginPath := fmt.Sprintf("%s-%d", plugin.Name, 0)
// Enable
if err := client.Sys().Mount(pluginPath, &api.MountInput{
Type: plugin.Name,
}); err != nil {
t.Fatal(err)
}
// Configure // Configure
cleanupConsul, consulConfig := consul.PrepareTestContainer(t, "", false, true) cleanupConsul, consulConfig := consul.PrepareTestContainer(t, "", false, true)
defer cleanupConsul() defer cleanupConsul()
_, err := client.Logical().Write(pluginPath+"/config/access", map[string]interface{}{ _, err := client.Logical().Write(plugin.Name+"/config/access", map[string]interface{}{
"address": consulConfig.Address(), "address": consulConfig.Address(),
"token": consulConfig.Token, "token": consulConfig.Token,
}) })
@ -377,7 +550,7 @@ func TestExternalPlugin_SecretsEngineReload(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
_, err = client.Logical().Write(pluginPath+"/roles/test", map[string]interface{}{ _, err = client.Logical().Write(plugin.Name+"/roles/test", map[string]interface{}{
"consul_policies": []string{"test"}, "consul_policies": []string{"test"},
"ttl": "6h", "ttl": "6h",
"local": false, "local": false,
@ -386,7 +559,7 @@ func TestExternalPlugin_SecretsEngineReload(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
resp, err := client.Logical().Read(pluginPath + "/creds/test") resp, err := client.Logical().Read(plugin.Name + "/creds/test")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -401,7 +574,7 @@ func TestExternalPlugin_SecretsEngineReload(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
resp, err = client.Logical().Read(pluginPath + "/creds/test") resp, err = client.Logical().Read(plugin.Name + "/creds/test")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@ -360,166 +360,6 @@ func testPlugin_CatalogRemoved(t *testing.T, btype logical.BackendType, testMoun
} }
} }
func TestSystemBackend_Plugin_continueOnError(t *testing.T) {
t.Run("secret", func(t *testing.T) {
t.Parallel()
t.Run("sha256_mismatch", func(t *testing.T) {
t.Parallel()
testPlugin_continueOnError(t, logical.TypeLogical, true, "mock-plugin", consts.PluginTypeSecrets)
})
t.Run("missing_plugin", func(t *testing.T) {
t.Parallel()
testPlugin_continueOnError(t, logical.TypeLogical, false, "mock-plugin", consts.PluginTypeSecrets)
})
})
t.Run("auth", func(t *testing.T) {
t.Parallel()
t.Run("sha256_mismatch", func(t *testing.T) {
t.Parallel()
testPlugin_continueOnError(t, logical.TypeCredential, true, "mock-plugin", consts.PluginTypeCredential)
})
t.Run("missing_plugin", func(t *testing.T) {
t.Parallel()
testPlugin_continueOnError(t, logical.TypeCredential, false, "mock-plugin", consts.PluginTypeCredential)
})
t.Run("sha256_mismatch", func(t *testing.T) {
t.Parallel()
testPlugin_continueOnError(t, logical.TypeCredential, true, "oidc", consts.PluginTypeCredential)
})
t.Run("missing_plugin", func(t *testing.T) {
t.Parallel()
testPlugin_continueOnError(t, logical.TypeCredential, false, "oidc", consts.PluginTypeCredential)
})
})
}
func testPlugin_continueOnError(t *testing.T, btype logical.BackendType, mismatch bool, mountPoint string, pluginType consts.PluginType) {
testCases := []struct {
pluginVersion string
}{
{
pluginVersion: "v5_multiplexed",
},
{
pluginVersion: "v5",
},
{
pluginVersion: "v4",
},
}
for _, tc := range testCases {
t.Run(tc.pluginVersion, func(t *testing.T) {
t.Parallel()
cluster := testSystemBackendMock(t, 1, 1, btype, tc.pluginVersion)
defer cluster.Cleanup()
core := cluster.Cores[0]
// Get the registered plugin
req := logical.TestRequest(t, logical.ReadOperation, fmt.Sprintf("sys/plugins/catalog/%s/mock-plugin", pluginType))
// We are using the mock backend from vault/sdk/plugin/mock/backend.go which sets the plugin version.
req.Data["version"] = "v0.0.0+mock"
req.ClientToken = core.Client.Token()
resp, err := core.HandleRequest(namespace.RootContext(testCtx), req)
if err != nil || resp == nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
command, ok := resp.Data["command"].(string)
if !ok || command == "" {
t.Fatal("invalid command")
}
// Trigger a sha256 mismatch or missing plugin error
if mismatch {
req = logical.TestRequest(t, logical.UpdateOperation, fmt.Sprintf("sys/plugins/catalog/%s/mock-plugin", pluginType))
req.Data = map[string]interface{}{
"sha256": "d17bd7334758e53e6fbab15745d2520765c06e296f2ce8e25b7919effa0ac216",
"command": filepath.Base(command),
}
req.ClientToken = core.Client.Token()
resp, err = core.HandleRequest(namespace.RootContext(testCtx), req)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
} else {
err := os.Remove(filepath.Join(cluster.TempDir, filepath.Base(command)))
if err != nil {
t.Fatal(err)
}
}
// Seal the cluster
cluster.EnsureCoresSealed(t)
// Unseal the cluster
barrierKeys := cluster.BarrierKeys
for _, core := range cluster.Cores {
for _, key := range barrierKeys {
_, err := core.Unseal(vault.TestKeyCopy(key))
if err != nil {
t.Fatal(err)
}
}
if core.Sealed() {
t.Fatal("should not be sealed")
}
}
// Wait for active so post-unseal takes place
// If it fails, it means unseal process failed
vault.TestWaitActive(t, core.Core)
env := []string{pluginutil.PluginCACertPEMEnv + "=" + cluster.CACertPEMFile}
// Re-add the plugin to the catalog
switch btype {
case logical.TypeLogical:
plugin := logicalVersionMap[tc.pluginVersion]
vault.TestAddTestPlugin(t, core.Core, "mock-plugin", consts.PluginTypeSecrets, "", plugin, env, cluster.TempDir)
case logical.TypeCredential:
plugin := credentialVersionMap[tc.pluginVersion]
vault.TestAddTestPlugin(t, core.Core, "mock-plugin", consts.PluginTypeCredential, "", plugin, env, cluster.TempDir)
}
// Reload the plugin
req = logical.TestRequest(t, logical.UpdateOperation, "sys/plugins/reload/backend")
req.Data = map[string]interface{}{
"plugin": "mock-plugin",
}
req.ClientToken = core.Client.Token()
resp, err = core.HandleRequest(namespace.RootContext(testCtx), req)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
// Make a request to lazy load the plugin
var reqPath string
switch btype {
case logical.TypeLogical:
reqPath = "mock-0/internal"
case logical.TypeCredential:
reqPath = "auth/mock-0/internal"
}
req = logical.TestRequest(t, logical.ReadOperation, reqPath)
req.ClientToken = core.Client.Token()
resp, err = core.HandleRequest(namespace.RootContext(testCtx), req)
if err != nil {
t.Fatalf("err: %v", err)
}
if resp == nil {
t.Fatalf("bad: response should not be nil")
}
})
}
}
func TestSystemBackend_Plugin_autoReload(t *testing.T) { func TestSystemBackend_Plugin_autoReload(t *testing.T) {
t.Parallel() t.Parallel()
testCases := []struct { testCases := []struct {

View File

@ -462,7 +462,12 @@ func (r *Router) MatchingBackend(ctx context.Context, path string) logical.Backe
if !ok { if !ok {
return nil return nil
} }
return raw.(*routeEntry).backend
re := raw.(*routeEntry)
re.l.RLock()
defer re.l.RUnlock()
return re.backend
} }
// MatchingSystemView returns the SystemView used for a path // MatchingSystemView returns the SystemView used for a path