Add plugin auto-reload capability (#3171)

* Add automatic plugin reload

* Refactor builtin/backend

* Remove plugin reload at the core level

* Refactor plugin tests

* Add auto-reload test case

* Change backend to use sync.RWMutex, fix dangling test plugin processes

* Add a canary to plugin backends to avoid reloading many times (#3174)

* Call setupPluginCatalog before mount-related operations in postUnseal

* Don't create multiple system backends since core only holds a reference (#3176)

to one.
This commit is contained in:
Calvin Leung Huang 2017-08-15 22:10:32 -04:00 committed by GitHub
parent 102848b30a
commit 86ea7e945d
8 changed files with 261 additions and 106 deletions

View File

@ -2,7 +2,10 @@ package plugin
import (
"fmt"
"net/rpc"
"sync"
uuid "github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/logical"
bplugin "github.com/hashicorp/vault/logical/plugin"
)
@ -27,13 +30,111 @@ func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
// Backend returns an instance of the backend, either as a plugin if external
// or as a concrete implementation if builtin, casted as logical.Backend.
func Backend(conf *logical.BackendConfig) (logical.Backend, error) {
var b backend
name := conf.Config["plugin_name"]
sys := conf.System
b, err := bplugin.NewBackend(name, sys, conf.Logger)
raw, err := bplugin.NewBackend(name, sys, conf.Logger)
if err != nil {
return nil, err
}
b.Backend = raw
b.config = conf
return b, nil
return &b, nil
}
// backend is a thin wrapper around plugin.BackendPluginClient
type backend struct {
logical.Backend
sync.RWMutex
config *logical.BackendConfig
// Used to detect if we already reloaded
canary string
}
func (b *backend) reloadBackend() error {
pluginName := b.config.Config["plugin_name"]
b.Logger().Trace("plugin: reloading plugin backend", "plugin", pluginName)
// Ensure proper cleanup of the backend (i.e. call client.Kill())
b.Backend.Cleanup()
nb, err := bplugin.NewBackend(pluginName, b.config.System, b.config.Logger)
if err != nil {
return err
}
err = nb.Setup(b.config)
if err != nil {
return err
}
b.Backend = nb
return nil
}
// HandleRequest is a thin wrapper implementation of HandleRequest that includes automatic plugin reload.
func (b *backend) HandleRequest(req *logical.Request) (*logical.Response, error) {
b.RLock()
canary := b.canary
resp, err := b.Backend.HandleRequest(req)
b.RUnlock()
// Need to compare string value for case were err comes from plugin RPC
// and is returned as plugin.BasicError type.
if err != nil && err.Error() == rpc.ErrShutdown.Error() {
// Reload plugin if it's an rpc.ErrShutdown
b.Lock()
if b.canary == canary {
err := b.reloadBackend()
if err != nil {
b.Unlock()
return nil, err
}
b.canary, err = uuid.GenerateUUID()
if err != nil {
b.Unlock()
return nil, err
}
}
b.Unlock()
// Try request once more
b.RLock()
defer b.RUnlock()
return b.Backend.HandleRequest(req)
}
return resp, err
}
// HandleExistenceCheck is a thin wrapper implementation of HandleRequest that includes automatic plugin reload.
func (b *backend) HandleExistenceCheck(req *logical.Request) (bool, bool, error) {
b.RLock()
canary := b.canary
checkFound, exists, err := b.Backend.HandleExistenceCheck(req)
b.RUnlock()
if err != nil && err.Error() == rpc.ErrShutdown.Error() {
// Reload plugin if it's an rpc.ErrShutdown
b.Lock()
if b.canary == canary {
err := b.reloadBackend()
if err != nil {
b.Unlock()
return false, false, err
}
b.canary, err = uuid.GenerateUUID()
if err != nil {
b.Unlock()
return false, false, err
}
}
b.Unlock()
// Try request once more
b.RLock()
defer b.RUnlock()
return b.Backend.HandleExistenceCheck(req)
}
return checkFound, exists, err
}

View File

@ -14,6 +14,10 @@ import (
log "github.com/mgutz/logxi/v1"
)
func TestBackend_impl(t *testing.T) {
var _ logical.Backend = &backend{}
}
func TestBackend(t *testing.T) {
config, cleanup := testConfig(t)
defer cleanup()

View File

@ -38,10 +38,13 @@ func Backend() *backend {
var b backend
b.Backend = &framework.Backend{
Help: "",
Paths: []*framework.Path{
pathKV(&b),
pathInternal(&b),
},
Paths: framework.PathAppend(
errorPaths(&b),
kvPaths(&b),
[]*framework.Path{
pathInternal(&b),
},
),
PathsSpecial: &logical.Paths{
Unauthenticated: []string{
"special",

View File

@ -0,0 +1,32 @@
package mock
import (
"net/rpc"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
// pathInternal is used to test viewing internal backend values. In this case,
// it is used to test the invalidate func.
func errorPaths(b *backend) []*framework.Path {
return []*framework.Path{
&framework.Path{
Pattern: "errors/rpc",
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.ReadOperation: b.pathErrorRPCRead,
},
},
&framework.Path{
Pattern: "errors/kill",
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.ReadOperation: b.pathErrorRPCRead,
},
},
}
}
func (b *backend) pathErrorRPCRead(
req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
return nil, rpc.ErrShutdown
}

View File

@ -7,22 +7,29 @@ import (
"github.com/hashicorp/vault/logical/framework"
)
// pathKV is used to test CRUD and List operations. It is a simplified
// kvPaths is used to test CRUD and List operations. It is a simplified
// version of the passthrough backend that only accepts string values.
func pathKV(b *backend) *framework.Path {
return &framework.Path{
Pattern: "kv/" + framework.GenericNameRegex("key"),
Fields: map[string]*framework.FieldSchema{
"key": &framework.FieldSchema{Type: framework.TypeString},
"value": &framework.FieldSchema{Type: framework.TypeString},
func kvPaths(b *backend) []*framework.Path {
return []*framework.Path{
&framework.Path{
Pattern: "kv/?",
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.ListOperation: b.pathKVList,
},
},
ExistenceCheck: b.pathExistenceCheck,
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.ReadOperation: b.pathKVRead,
logical.CreateOperation: b.pathKVCreateUpdate,
logical.UpdateOperation: b.pathKVCreateUpdate,
logical.DeleteOperation: b.pathKVDelete,
logical.ListOperation: b.pathKVList,
&framework.Path{
Pattern: "kv/" + framework.GenericNameRegex("key"),
Fields: map[string]*framework.FieldSchema{
"key": &framework.FieldSchema{Type: framework.TypeString},
"value": &framework.FieldSchema{Type: framework.TypeString},
},
ExistenceCheck: b.pathExistenceCheck,
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.ReadOperation: b.pathKVRead,
logical.CreateOperation: b.pathKVCreateUpdate,
logical.UpdateOperation: b.pathKVCreateUpdate,
logical.DeleteOperation: b.pathKVDelete,
},
},
}
}

View File

@ -1347,6 +1347,9 @@ func (c *Core) postUnseal() (retErr error) {
if err := c.ensureWrappingKey(); err != nil {
return err
}
if err := c.setupPluginCatalog(); err != nil {
return err
}
if err := c.loadMounts(); err != nil {
return err
}
@ -1380,9 +1383,6 @@ func (c *Core) postUnseal() (retErr error) {
if err := c.setupAuditedHeadersConfig(); err != nil {
return err
}
if err := c.setupPluginCatalog(); err != nil {
return err
}
if c.ha != nil {
if err := c.startClusterListener(); err != nil {

View File

@ -121,6 +121,12 @@ func (d dynamicSystemView) ResponseWrapData(data map[string]interface{}, ttl tim
// LookupPlugin looks for a plugin with the given name in the plugin catalog. It
// returns a PluginRunner or an error if no plugin was found.
func (d dynamicSystemView) LookupPlugin(name string) (*pluginutil.PluginRunner, error) {
if d.core == nil {
return nil, fmt.Errorf("system view core is nil")
}
if d.core.pluginCatalog == nil {
return nil, fmt.Errorf("system view core plugin catalog is nil")
}
r, err := d.core.pluginCatalog.Get(name)
if err != nil {
return nil, err

View File

@ -4,67 +4,68 @@ import (
"fmt"
"os"
"testing"
"time"
"github.com/hashicorp/vault/builtin/plugin"
"github.com/hashicorp/vault/helper/logformat"
"github.com/hashicorp/vault/helper/pluginutil"
vaulthttp "github.com/hashicorp/vault/http"
"github.com/hashicorp/vault/logical"
lplugin "github.com/hashicorp/vault/logical/plugin"
"github.com/hashicorp/vault/logical/plugin/mock"
"github.com/hashicorp/vault/vault"
log "github.com/mgutz/logxi/v1"
)
func TestSystemBackend_enableAuth_plugin(t *testing.T) {
coreConfig := &vault.CoreConfig{
CredentialBackends: map[string]logical.Factory{
"plugin": plugin.Factory,
},
}
cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{
HandlerFunc: vaulthttp.Handler,
})
cluster.Start()
func TestSystemBackend_Plugin_secret(t *testing.T) {
cluster := testSystemBackendMock(t, 1, logical.TypeLogical)
defer cluster.Cleanup()
core := cluster.Cores[0].Core
vault.TestWaitActive(t, core)
}
b := vault.NewSystemBackend(core)
logger := logformat.NewVaultLogger(log.LevelTrace)
bc := &logical.BackendConfig{
Logger: logger,
System: logical.StaticSystemView{
DefaultLeaseTTLVal: time.Hour * 24,
MaxLeaseTTLVal: time.Hour * 24 * 32,
},
}
func TestSystemBackend_Plugin_auth(t *testing.T) {
cluster := testSystemBackendMock(t, 1, logical.TypeCredential)
defer cluster.Cleanup()
}
err := b.Backend.Setup(bc)
if err != nil {
t.Fatal(err)
}
func TestSystemBackend_Plugin_autoReload(t *testing.T) {
cluster := testSystemBackendMock(t, 1, logical.TypeLogical)
defer cluster.Cleanup()
os.Setenv(pluginutil.PluginCACertPEMEnv, cluster.CACertPEMFile)
core := cluster.Cores[0]
vault.TestAddTestPlugin(t, core, "mock-plugin", "TestBackend_PluginMainCredentials")
req := logical.TestRequest(t, logical.UpdateOperation, "auth/mock-plugin")
req.Data["type"] = "plugin"
req.Data["plugin_name"] = "mock-plugin"
resp, err := b.HandleRequest(req)
// Update internal value
req := logical.TestRequest(t, logical.UpdateOperation, "mock-0/internal")
req.ClientToken = core.Client.Token()
req.Data["value"] = "baz"
resp, err := core.HandleRequest(req)
if err != nil {
t.Fatalf("err: %v", err)
}
if resp != nil {
t.Fatalf("bad: %v", resp)
}
// Call errors/rpc endpoint to trigger reload
req = logical.TestRequest(t, logical.ReadOperation, "mock-0/errors/rpc")
req.ClientToken = core.Client.Token()
resp, err = core.HandleRequest(req)
if err == nil {
t.Fatalf("expected error from error/rpc request")
}
// Check internal value to make sure it's reset
req = logical.TestRequest(t, logical.ReadOperation, "mock-0/internal")
req.ClientToken = core.Client.Token()
resp, err = core.HandleRequest(req)
if err != nil {
t.Fatalf("err: %v", err)
}
if resp == nil {
t.Fatalf("bad: response should not be nil")
}
if resp.Data["value"].(string) == "baz" {
t.Fatal("did not expect backend internal value to be 'baz'")
}
}
func TestSystemBackend_PluginReload(t *testing.T) {
func TestSystemBackend_Plugin_reload(t *testing.T) {
data := map[string]interface{}{
"plugin": "mock-plugin",
}
@ -77,17 +78,17 @@ func TestSystemBackend_PluginReload(t *testing.T) {
}
func testSystemBackend_PluginReload(t *testing.T, reqData map[string]interface{}) {
cluster, b := testSystemBackendMock(t, 2)
cluster := testSystemBackendMock(t, 2, logical.TypeLogical)
defer cluster.Cleanup()
core := cluster.Cores[0]
client := core.Client
for i := 0; i < 2; i++ {
// Update internal value in the backend
req := logical.TestRequest(t, logical.UpdateOperation, fmt.Sprintf("mock-%d/internal", i))
req.ClientToken = core.Client.Token()
req.Data["value"] = "baz"
resp, err := core.HandleRequest(req)
resp, err := client.Logical().Write(fmt.Sprintf("mock-%d/internal", i), map[string]interface{}{
"value": "baz",
})
if err != nil {
t.Fatalf("err: %v", err)
}
@ -97,10 +98,7 @@ func testSystemBackend_PluginReload(t *testing.T, reqData map[string]interface{}
}
// Perform plugin reload
req := logical.TestRequest(t, logical.UpdateOperation, "plugins/backend/reload")
req.ClientToken = core.Client.Token()
req.Data = reqData
resp, err := b.HandleRequest(req)
resp, err := client.Logical().Write("sys/plugins/backend/reload", reqData)
if err != nil {
t.Fatalf("err: %v", err)
}
@ -110,9 +108,7 @@ func testSystemBackend_PluginReload(t *testing.T, reqData map[string]interface{}
for i := 0; i < 2; i++ {
// Ensure internal backed value is reset
req := logical.TestRequest(t, logical.ReadOperation, "mock-1/internal")
req.ClientToken = core.Client.Token()
resp, err := core.HandleRequest(req)
resp, err := client.Logical().Read(fmt.Sprintf("mock-%d/internal", i))
if err != nil {
t.Fatalf("err: %v", err)
}
@ -127,11 +123,14 @@ func testSystemBackend_PluginReload(t *testing.T, reqData map[string]interface{}
// testSystemBackendMock returns a systemBackend with the desired number
// of mounted mock plugin backends
func testSystemBackendMock(t *testing.T, numMounts int) (*vault.TestCluster, *vault.SystemBackend) {
func testSystemBackendMock(t *testing.T, numMounts int, backendType logical.BackendType) *vault.TestCluster {
coreConfig := &vault.CoreConfig{
LogicalBackends: map[string]logical.Factory{
"plugin": plugin.Factory,
},
CredentialBackends: map[string]logical.Factory{
"plugin": plugin.Factory,
},
}
cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{
@ -139,45 +138,48 @@ func testSystemBackendMock(t *testing.T, numMounts int) (*vault.TestCluster, *va
})
cluster.Start()
core := cluster.Cores[0].Core
vault.TestWaitActive(t, core)
b := vault.NewSystemBackend(core)
logger := logformat.NewVaultLogger(log.LevelTrace)
bc := &logical.BackendConfig{
Logger: logger,
System: logical.StaticSystemView{
DefaultLeaseTTLVal: time.Hour * 24,
MaxLeaseTTLVal: time.Hour * 24 * 32,
},
}
err := b.Backend.Setup(bc)
if err != nil {
t.Fatal(err)
}
core := cluster.Cores[0]
vault.TestWaitActive(t, core.Core)
client := core.Client
os.Setenv(pluginutil.PluginCACertPEMEnv, cluster.CACertPEMFile)
vault.TestAddTestPlugin(t, core, "mock-plugin", "TestBackend_PluginMainLogical")
for i := 0; i < numMounts; i++ {
req := logical.TestRequest(t, logical.UpdateOperation, fmt.Sprintf("mounts/mock-%d/", i))
req.Data["type"] = "plugin"
req.Data["config"] = map[string]interface{}{
"plugin_name": "mock-plugin",
switch backendType {
case logical.TypeLogical:
vault.TestAddTestPlugin(t, core.Core, "mock-plugin", "TestBackend_PluginMainLogical")
for i := 0; i < numMounts; i++ {
resp, err := client.Logical().Write(fmt.Sprintf("sys/mounts/mock-%d", i), map[string]interface{}{
"type": "plugin",
"config": map[string]interface{}{
"plugin_name": "mock-plugin",
},
})
if err != nil {
t.Fatalf("err: %v", err)
}
if resp != nil {
t.Fatalf("bad: %v", resp)
}
}
resp, err := b.HandleRequest(req)
if err != nil {
t.Fatalf("err: %v", err)
}
if resp != nil {
t.Fatalf("bad: %v", resp)
case logical.TypeCredential:
vault.TestAddTestPlugin(t, core.Core, "mock-plugin", "TestBackend_PluginMainCredentials")
for i := 0; i < numMounts; i++ {
resp, err := client.Logical().Write(fmt.Sprintf("sys/auth/mock-%d", i), map[string]interface{}{
"type": "plugin",
"plugin_name": "mock-plugin",
})
if err != nil {
t.Fatalf("err: %v", err)
}
if resp != nil {
t.Fatalf("bad: %v", resp)
}
}
default:
t.Fatal("unknown backend type provided")
}
return cluster, b
return cluster
}
func TestBackend_PluginMainLogical(t *testing.T) {