Lazy-load plugin mounts (#3255)

* Lazy load plugins to avoid setup-unwrap cycle

* Remove commented blocks

* Refactor NewTestCluster, use single core cluster on basic plugin tests

* Set c.pluginDirectory in TestAddTestPlugin for setupPluginCatalog to work properly

* Add special path to mock plugin

* Move ensureCoresSealed to vault/testing.go

* Use same method for EnsureCoresSealed and Cleanup

* Bump ensureCoresSealed timeout to 60s

* Correctly handle nil opts on NewTestCluster

* Add metadata flag to APIClientMeta, use meta-enabled plugin when mounting to bootstrap

* Check metadata flag directly on the plugin process

* Plumb isMetadataMode down to PluginRunner

* Add NOOP shims when running in metadata mode

* Remove unused flag from the APIMetadata object

* Remove setupSecretPlugins and setupCredentialPlugins functions

* Move when we setup rollback manager to after the plugins are initialized

* Fix tests

* Fix merge issue

* start rollback manager after the credential setup

* Add guards against running certain client and server functions while in metadata mode

* Call initialize once a plugin is loaded on the fly

* Add more tests, update basic secret/auth plugin tests to trigger lazy loading

* Skip mount if plugin removed from catalog

* Fixup

* Remove commented line on LookupPlugin

* Fail on mount operation if plugin is re-added to catalog and mount is on existing path

* Check type and special paths on startBackend

* Fix merge conflicts

* Refactor PluginRunner run methods to use runCommon, fix TestSystemBackend_Plugin_auth
This commit is contained in:
Calvin Leung Huang 2017-09-01 01:02:03 -04:00 committed by GitHub
parent 590e2de328
commit a581e96b78
22 changed files with 816 additions and 480 deletions

View File

@ -3,13 +3,20 @@ package plugin
import (
"fmt"
"net/rpc"
"reflect"
"sync"
uuid "github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
bplugin "github.com/hashicorp/vault/logical/plugin"
)
var (
ErrMismatchType = fmt.Errorf("mismatch on mounted backend and plugin backend type")
ErrMismatchPaths = fmt.Errorf("mismatch on mounted backend and plugin backend special paths")
)
// Factory returns a configured plugin logical.Backend.
func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
_, ok := conf.Config["plugin_name"]
@ -31,14 +38,33 @@ func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
// or as a concrete implementation if builtin, casted as logical.Backend.
func Backend(conf *logical.BackendConfig) (logical.Backend, error) {
var b backend
name := conf.Config["plugin_name"]
sys := conf.System
raw, err := bplugin.NewBackend(name, sys, conf.Logger)
// NewBackend with isMetadataMode set to true
raw, err := bplugin.NewBackend(name, sys, conf.Logger, true)
if err != nil {
return nil, err
}
b.Backend = raw
err = raw.Setup(conf)
if err != nil {
return nil, err
}
// Get SpecialPaths and BackendType
paths := raw.SpecialPaths()
btype := raw.Type()
// Cleanup meta plugin backend
raw.Cleanup()
// Initialize b.Backend with dummy backend since plugin
// backends will need to be lazy loaded.
b.Backend = &framework.Backend{
PathsSpecial: paths,
BackendType: btype,
}
b.config = conf
return &b, nil
@ -53,16 +79,24 @@ type backend struct {
// Used to detect if we already reloaded
canary string
// Used to detect if plugin is set
loaded bool
}
func (b *backend) reloadBackend() error {
b.Logger().Trace("plugin: reloading plugin backend", "plugin", b.config.Config["plugin_name"])
return b.startBackend()
}
// startBackend starts a plugin backend
func (b *backend) startBackend() error {
pluginName := b.config.Config["plugin_name"]
b.Logger().Trace("plugin: reloading plugin backend", "plugin", pluginName)
// Ensure proper cleanup of the backend (i.e. call client.Kill())
b.Backend.Cleanup()
nb, err := bplugin.NewBackend(pluginName, b.config.System, b.config.Logger)
nb, err := bplugin.NewBackend(pluginName, b.config.System, b.config.Logger, false)
if err != nil {
return err
}
@ -70,7 +104,29 @@ func (b *backend) reloadBackend() error {
if err != nil {
return err
}
// If the backend has not been loaded (i.e. still in metadata mode),
// check if type and special paths still matches
if !b.loaded {
if b.Backend.Type() != nb.Type() {
nb.Cleanup()
b.Logger().Warn("plugin: failed to start plugin process", "plugin", b.config.Config["plugin_name"], "error", ErrMismatchType)
return ErrMismatchType
}
if !reflect.DeepEqual(b.Backend.SpecialPaths(), nb.SpecialPaths()) {
nb.Cleanup()
b.Logger().Warn("plugin: failed to start plugin process", "plugin", b.config.Config["plugin_name"], "error", ErrMismatchPaths)
return ErrMismatchPaths
}
}
b.Backend = nb
b.loaded = true
// Call initialize
if err := b.Backend.Initialize(); err != nil {
return err
}
return nil
}
@ -79,6 +135,23 @@ func (b *backend) reloadBackend() error {
func (b *backend) HandleRequest(req *logical.Request) (*logical.Response, error) {
b.RLock()
canary := b.canary
// Lazy-load backend
if !b.loaded {
// Upgrade lock
b.RUnlock()
b.Lock()
// Check once more after lock swap
if !b.loaded {
err := b.startBackend()
if err != nil {
b.Unlock()
return nil, err
}
}
b.Unlock()
b.RLock()
}
resp, err := b.Backend.HandleRequest(req)
b.RUnlock()
// Need to compare string value for case were err comes from plugin RPC
@ -112,6 +185,24 @@ func (b *backend) HandleRequest(req *logical.Request) (*logical.Response, error)
func (b *backend) HandleExistenceCheck(req *logical.Request) (bool, bool, error) {
b.RLock()
canary := b.canary
// Lazy-load backend
if !b.loaded {
// Upgrade lock
b.RUnlock()
b.Lock()
// Check once more after lock swap
if !b.loaded {
err := b.startBackend()
if err != nil {
b.Unlock()
return false, false, err
}
}
b.Unlock()
b.RLock()
}
checkFound, exists, err := b.Backend.HandleExistenceCheck(req)
b.RUnlock()
if err != nil && err.Error() == rpc.ErrShutdown.Error() {

View File

@ -1,6 +1,7 @@
package plugin
import (
"fmt"
"os"
"testing"
@ -39,7 +40,8 @@ func TestBackend_Factory(t *testing.T) {
}
func TestBackend_PluginMain(t *testing.T) {
if os.Getenv(pluginutil.PluginUnwrapTokenEnv) == "" {
args := []string{}
if os.Getenv(pluginutil.PluginUnwrapTokenEnv) == "" && os.Getenv(pluginutil.PluginMetadaModeEnv) != "true" {
return
}
@ -48,7 +50,7 @@ func TestBackend_PluginMain(t *testing.T) {
t.Fatal("CA cert not passed in")
}
args := []string{"--ca-cert=" + caPEM}
args = append(args, fmt.Sprintf("--ca-cert=%s", caPEM))
apiClientMeta := &pluginutil.APIClientMeta{}
flags := apiClientMeta.FlagSet()

View File

@ -7,7 +7,7 @@ import (
)
var (
// PluginUnwrapTokenEnv is the ENV name used to pass the configuration for
// PluginMlockEnabled is the ENV name used to pass the configuration for
// enabling mlock
PluginMlockEnabled = "VAULT_PLUGIN_MLOCK_ENABLED"
)

View File

@ -2,6 +2,7 @@ package pluginutil
import (
"crypto/sha256"
"crypto/tls"
"flag"
"fmt"
"os/exec"
@ -22,6 +23,7 @@ type Looker interface {
// Wrapper interface defines the functions needed by the runner to wrap the
// metadata needed to run a plugin process. This includes looking up Mlock
// configuration and wrapping data in a respose wrapped token.
// logical.SystemView implementataions satisfy this interface.
type RunnerUtil interface {
ResponseWrapData(data map[string]interface{}, ttl time.Duration, jwt bool) (*wrapping.ResponseWrapInfo, error)
MlockEnabled() bool
@ -44,56 +46,82 @@ type PluginRunner struct {
BuiltinFactory func() (interface{}, error) `json:"-" structs:"-"`
}
// Run takes a wrapper instance, and the go-plugin paramaters and executes a
// plugin.
// Run takes a wrapper RunnerUtil instance along with the go-plugin paramaters and
// returns a configured plugin.Client with TLS Configured and a wrapping token set
// on PluginUnwrapTokenEnv for plugin process consumption.
func (r *PluginRunner) Run(wrapper RunnerUtil, pluginMap map[string]plugin.Plugin, hs plugin.HandshakeConfig, env []string, logger log.Logger) (*plugin.Client, error) {
// Get a CA TLS Certificate
certBytes, key, err := generateCert()
if err != nil {
return nil, err
}
return r.runCommon(wrapper, pluginMap, hs, env, logger, false)
}
// Use CA to sign a client cert and return a configured TLS config
clientTLSConfig, err := createClientTLSConfig(certBytes, key)
if err != nil {
return nil, err
}
// RunMetadataMode returns a configured plugin.Client that will dispense a plugin
// in metadata mode. The PluginMetadaModeEnv is passed in as part of the Cmd to
// plugin.Client, and consumed by the plugin process on pluginutil.VaultPluginTLSProvider.
func (r *PluginRunner) RunMetadataMode(wrapper RunnerUtil, pluginMap map[string]plugin.Plugin, hs plugin.HandshakeConfig, env []string, logger log.Logger) (*plugin.Client, error) {
return r.runCommon(wrapper, pluginMap, hs, env, logger, true)
// Use CA to sign a server cert and wrap the values in a response wrapped
// token.
wrapToken, err := wrapServerConfig(wrapper, certBytes, key)
if err != nil {
return nil, err
}
}
func (r *PluginRunner) runCommon(wrapper RunnerUtil, pluginMap map[string]plugin.Plugin, hs plugin.HandshakeConfig, env []string, logger log.Logger, isMetadataMode bool) (*plugin.Client, error) {
cmd := exec.Command(r.Command, r.Args...)
cmd.Env = append(cmd.Env, env...)
// Add the response wrap token to the ENV of the plugin
cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", PluginUnwrapTokenEnv, wrapToken))
// Add the mlock setting to the ENV of the plugin
if wrapper.MlockEnabled() {
cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", PluginMlockEnabled, "true"))
}
secureConfig := &plugin.SecureConfig{
Checksum: r.Sha256,
Hash: sha256.New(),
}
// Create logger for the plugin client
clogger := &hclogFaker{
logger: logger,
}
namedLogger := clogger.ResetNamed("plugin")
client := plugin.NewClient(&plugin.ClientConfig{
var clientTLSConfig *tls.Config
if !isMetadataMode {
// Add the metadata mode ENV and set it to false
cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", PluginMetadaModeEnv, "false"))
// Get a CA TLS Certificate
certBytes, key, err := generateCert()
if err != nil {
return nil, err
}
// Use CA to sign a client cert and return a configured TLS config
clientTLSConfig, err = createClientTLSConfig(certBytes, key)
if err != nil {
return nil, err
}
// Use CA to sign a server cert and wrap the values in a response wrapped
// token.
wrapToken, err := wrapServerConfig(wrapper, certBytes, key)
if err != nil {
return nil, err
}
// Add the response wrap token to the ENV of the plugin
cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", PluginUnwrapTokenEnv, wrapToken))
} else {
namedLogger = clogger.ResetNamed("plugin.metadata")
cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", PluginMetadaModeEnv, "true"))
}
secureConfig := &plugin.SecureConfig{
Checksum: r.Sha256,
Hash: sha256.New(),
}
clientConfig := &plugin.ClientConfig{
HandshakeConfig: hs,
Plugins: pluginMap,
Cmd: cmd,
TLSConfig: clientTLSConfig,
SecureConfig: secureConfig,
TLSConfig: clientTLSConfig,
Logger: namedLogger,
})
}
client := plugin.NewClient(clientConfig)
return client, nil
}
@ -108,7 +136,7 @@ type APIClientMeta struct {
}
func (f *APIClientMeta) FlagSet() *flag.FlagSet {
fs := flag.NewFlagSet("tls settings", flag.ContinueOnError)
fs := flag.NewFlagSet("vault plugin settings", flag.ContinueOnError)
fs.StringVar(&f.flagCACert, "ca-cert", "", "")
fs.StringVar(&f.flagCAPath, "ca-path", "", "")

View File

@ -29,6 +29,10 @@ var (
// PluginCACertPEMEnv is an ENV name used for holding a CA PEM-encoded
// string. Used for testing.
PluginCACertPEMEnv = "VAULT_TESTING_PLUGIN_CA_PEM"
// PluginMetadaModeEnv is an ENV name used to disable TLS communication
// to bootstrap mounting plugins.
PluginMetadaModeEnv = "VAULT_PLUGIN_METADATA_MODE"
)
// generateCert is used internally to create certificates for the plugin
@ -124,6 +128,10 @@ func wrapServerConfig(sys RunnerUtil, certBytes []byte, key *ecdsa.PrivateKey) (
// VaultPluginTLSProvider is run inside a plugin and retrives the response
// wrapped TLS certificate from vault. It returns a configured TLS Config.
func VaultPluginTLSProvider(apiTLSConfig *api.TLSConfig) func() (*tls.Config, error) {
if os.Getenv(PluginMetadaModeEnv) == "true" {
return nil
}
return func() (*tls.Config, error) {
unwrapToken := os.Getenv(PluginUnwrapTokenEnv)
@ -157,7 +165,10 @@ func VaultPluginTLSProvider(apiTLSConfig *api.TLSConfig) func() (*tls.Config, er
clientConf := api.DefaultConfig()
clientConf.Address = vaultAddr
if apiTLSConfig != nil {
clientConf.ConfigureTLS(apiTLSConfig)
err := clientConf.ConfigureTLS(apiTLSConfig)
if err != nil {
return nil, errwrap.Wrapf("error configuring api client {{err}}", err)
}
}
client, err := api.NewClient(clientConf)
if err != nil {

View File

@ -9,7 +9,8 @@ import (
// BackendPlugin is the plugin.Plugin implementation
type BackendPlugin struct {
Factory func(*logical.BackendConfig) (logical.Backend, error)
Factory func(*logical.BackendConfig) (logical.Backend, error)
metadataMode bool
}
// Server gets called when on plugin.Serve()
@ -19,5 +20,5 @@ func (b *BackendPlugin) Server(broker *plugin.MuxBroker) (interface{}, error) {
// Client gets called on plugin.NewClient()
func (b BackendPlugin) Client(broker *plugin.MuxBroker, c *rpc.Client) (interface{}, error) {
return &backendPluginClient{client: c, broker: broker}, nil
return &backendPluginClient{client: c, broker: broker, metadataMode: b.metadataMode}, nil
}

View File

@ -1,6 +1,7 @@
package plugin
import (
"errors"
"net/rpc"
"github.com/hashicorp/go-plugin"
@ -8,11 +9,16 @@ import (
log "github.com/mgutz/logxi/v1"
)
var (
ErrClientInMetadataMode = errors.New("plugin client can not perform action while in metadata mode")
)
// backendPluginClient implements logical.Backend and is the
// go-plugin client.
type backendPluginClient struct {
broker *plugin.MuxBroker
client *rpc.Client
broker *plugin.MuxBroker
client *rpc.Client
metadataMode bool
system logical.SystemView
logger log.Logger
@ -83,6 +89,10 @@ type RegisterLicenseReply struct {
}
func (b *backendPluginClient) HandleRequest(req *logical.Request) (*logical.Response, error) {
if b.metadataMode {
return nil, ErrClientInMetadataMode
}
// Do not send the storage, since go-plugin cannot serialize
// interfaces. The server will pick up the storage from the shim.
req.Storage = nil
@ -136,6 +146,10 @@ func (b *backendPluginClient) Logger() log.Logger {
}
func (b *backendPluginClient) HandleExistenceCheck(req *logical.Request) (bool, bool, error) {
if b.metadataMode {
return false, false, ErrClientInMetadataMode
}
// Do not send the storage, since go-plugin cannot serialize
// interfaces. The server will pick up the storage from the shim.
req.Storage = nil
@ -172,31 +186,49 @@ func (b *backendPluginClient) Cleanup() {
}
func (b *backendPluginClient) Initialize() error {
if b.metadataMode {
return ErrClientInMetadataMode
}
err := b.client.Call("Plugin.Initialize", new(interface{}), &struct{}{})
return err
}
func (b *backendPluginClient) InvalidateKey(key string) {
if b.metadataMode {
return
}
b.client.Call("Plugin.InvalidateKey", key, &struct{}{})
}
func (b *backendPluginClient) Setup(config *logical.BackendConfig) error {
// Shim logical.Storage
storageImpl := config.StorageView
if b.metadataMode {
storageImpl = &NOOPStorage{}
}
storageID := b.broker.NextId()
go b.broker.AcceptAndServe(storageID, &StorageServer{
impl: config.StorageView,
impl: storageImpl,
})
// Shim log.Logger
loggerImpl := config.Logger
if b.metadataMode {
loggerImpl = log.NullLog
}
loggerID := b.broker.NextId()
go b.broker.AcceptAndServe(loggerID, &LoggerServer{
logger: config.Logger,
logger: loggerImpl,
})
// Shim logical.SystemView
sysViewImpl := config.System
if b.metadataMode {
sysViewImpl = &logical.StaticSystemView{}
}
sysViewID := b.broker.NextId()
go b.broker.AcceptAndServe(sysViewID, &SystemViewServer{
impl: config.System,
impl: sysViewImpl,
})
args := &SetupArgs{
@ -233,6 +265,10 @@ func (b *backendPluginClient) Type() logical.BackendType {
}
func (b *backendPluginClient) RegisterLicense(license interface{}) error {
if b.metadataMode {
return ErrClientInMetadataMode
}
var reply RegisterLicenseReply
args := RegisterLicenseArgs{
License: license,

View File

@ -1,12 +1,19 @@
package plugin
import (
"errors"
"net/rpc"
"os"
"github.com/hashicorp/go-plugin"
"github.com/hashicorp/vault/helper/pluginutil"
"github.com/hashicorp/vault/logical"
)
var (
ErrServerInMetadataMode = errors.New("plugin server can not perform action while in metadata mode")
)
// backendPluginServer is the RPC server that backendPluginClient talks to,
// it methods conforming to requirements by net/rpc
type backendPluginServer struct {
@ -19,7 +26,15 @@ type backendPluginServer struct {
storageClient *rpc.Client
}
func inMetadataMode() bool {
return os.Getenv(pluginutil.PluginMetadaModeEnv) == "true"
}
func (b *backendPluginServer) HandleRequest(args *HandleRequestArgs, reply *HandleRequestReply) error {
if inMetadataMode() {
return ErrServerInMetadataMode
}
storage := &StorageClient{client: b.storageClient}
args.Request.Storage = storage
@ -40,6 +55,10 @@ func (b *backendPluginServer) SpecialPaths(_ interface{}, reply *SpecialPathsRep
}
func (b *backendPluginServer) HandleExistenceCheck(args *HandleExistenceCheckArgs, reply *HandleExistenceCheckReply) error {
if inMetadataMode() {
return ErrServerInMetadataMode
}
storage := &StorageClient{client: b.storageClient}
args.Request.Storage = storage
@ -64,11 +83,19 @@ func (b *backendPluginServer) Cleanup(_ interface{}, _ *struct{}) error {
}
func (b *backendPluginServer) Initialize(_ interface{}, _ *struct{}) error {
if inMetadataMode() {
return ErrServerInMetadataMode
}
err := b.backend.Initialize()
return err
}
func (b *backendPluginServer) InvalidateKey(args string, _ *struct{}) error {
if inMetadataMode() {
return ErrServerInMetadataMode
}
b.backend.InvalidateKey(args)
return nil
}
@ -145,6 +172,10 @@ func (b *backendPluginServer) Type(_ interface{}, reply *TypeReply) error {
}
func (b *backendPluginServer) RegisterLicense(args *RegisterLicenseArgs, reply *RegisterLicenseReply) error {
if inMetadataMode() {
return ErrServerInMetadataMode
}
err := b.backend.RegisterLicense(args.License)
if err != nil {
*reply = RegisterLicenseReply{

View File

@ -43,6 +43,7 @@ func Backend() *backend {
kvPaths(&b),
[]*framework.Path{
pathInternal(&b),
pathSpecial(&b),
},
),
PathsSpecial: &logical.Paths{

View File

@ -13,7 +13,7 @@ import (
func main() {
apiClientMeta := &pluginutil.APIClientMeta{}
flags := apiClientMeta.FlagSet()
flags.Parse(os.Args)
flags.Parse(os.Args[1:]) // Ignore command, strictly parse flags
tlsConfig := apiClientMeta.GetTLSConfig()
tlsProviderFunc := pluginutil.VaultPluginTLSProvider(tlsConfig)

View File

@ -0,0 +1,27 @@
package mock
import (
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
// pathSpecial is used to test special paths.
func pathSpecial(b *backend) *framework.Path {
return &framework.Path{
Pattern: "special",
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.ReadOperation: b.pathSpecialRead,
},
}
}
func (b *backend) pathSpecialRead(
req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
// Return the secret
return &logical.Response{
Data: map[string]interface{}{
"data": "foo",
},
}, nil
}

View File

@ -40,8 +40,9 @@ func (b *BackendPluginClient) Cleanup() {
// NewBackend will return an instance of an RPC-based client implementation of the backend for
// external plugins, or a concrete implementation of the backend if it is a builtin backend.
// The backend is returned as a logical.Backend interface.
func NewBackend(pluginName string, sys pluginutil.LookRunnerUtil, logger log.Logger) (logical.Backend, error) {
// The backend is returned as a logical.Backend interface. The isMetadataMode param determines whether
// the plugin should run in metadata mode.
func NewBackend(pluginName string, sys pluginutil.LookRunnerUtil, logger log.Logger, isMetadataMode bool) (logical.Backend, error) {
// Look for plugin in the plugin catalog
pluginRunner, err := sys.LookupPlugin(pluginName)
if err != nil {
@ -65,7 +66,7 @@ func NewBackend(pluginName string, sys pluginutil.LookRunnerUtil, logger log.Log
} else {
// create a backendPluginClient instance
backend, err = newPluginClient(sys, pluginRunner, logger)
backend, err = newPluginClient(sys, pluginRunner, logger, isMetadataMode)
if err != nil {
return nil, err
}
@ -74,12 +75,21 @@ func NewBackend(pluginName string, sys pluginutil.LookRunnerUtil, logger log.Log
return backend, nil
}
func newPluginClient(sys pluginutil.RunnerUtil, pluginRunner *pluginutil.PluginRunner, logger log.Logger) (logical.Backend, error) {
func newPluginClient(sys pluginutil.RunnerUtil, pluginRunner *pluginutil.PluginRunner, logger log.Logger, isMetadataMode bool) (logical.Backend, error) {
// pluginMap is the map of plugins we can dispense.
pluginMap := map[string]plugin.Plugin{
"backend": &BackendPlugin{},
"backend": &BackendPlugin{
metadataMode: isMetadataMode,
},
}
var client *plugin.Client
var err error
if isMetadataMode {
client, err = pluginRunner.RunMetadataMode(sys, pluginMap, handshakeConfig, []string{}, logger)
} else {
client, err = pluginRunner.Run(sys, pluginMap, handshakeConfig, []string{}, logger)
}
client, err := pluginRunner.Run(sys, pluginMap, handshakeConfig, []string{}, logger)
if err != nil {
return nil, err
}

View File

@ -20,7 +20,8 @@ type ServeOpts struct {
TLSProviderFunc TLSProdiverFunc
}
// Serve is used to serve a backend plugin
// Serve is a helper function used to serve a backend plugin. This
// should be ran on the plugin's main process.
func Serve(opts *ServeOpts) error {
// pluginMap is the map of plugins we can dispense.
var pluginMap = map[string]plugin.Plugin{
@ -34,6 +35,7 @@ func Serve(opts *ServeOpts) error {
return err
}
// If FetchMetadata is true, run without TLSProvider
plugin.Serve(&plugin.ServeConfig{
HandshakeConfig: handshakeConfig,
Plugins: pluginMap,

View File

@ -117,3 +117,23 @@ type StoragePutReply struct {
type StorageDeleteReply struct {
Error *plugin.BasicError
}
// NOOPStorage is used to deny access to the storage interface while running a
// backend plugin in metadata mode.
type NOOPStorage struct{}
func (s *NOOPStorage) List(prefix string) ([]string, error) {
return []string{}, nil
}
func (s *NOOPStorage) Get(key string) (*logical.StorageEntry, error) {
return nil, nil
}
func (s *NOOPStorage) Put(entry *logical.StorageEntry) error {
return nil
}
func (s *NOOPStorage) Delete(key string) error {
return nil
}

View File

@ -5,6 +5,7 @@ import (
"fmt"
"strings"
"github.com/hashicorp/errwrap"
"github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/helper/jsonutil"
"github.com/hashicorp/vault/logical"
@ -397,7 +398,6 @@ func (c *Core) persistAuth(table *MountTable, localOnly bool) error {
// setupCredentials is invoked after we've loaded the auth table to
// initialize the credential backends and setup the router
func (c *Core) setupCredentials() error {
var backend logical.Backend
var view *BarrierView
var err error
var persistNeeded bool
@ -406,6 +406,7 @@ func (c *Core) setupCredentials() error {
defer c.authLock.Unlock()
for _, entry := range c.auth.Entries {
var backend logical.Backend
// Work around some problematic code that existed in master for a while
if strings.HasPrefix(entry.Path, credentialRoutePrefix) {
entry.Path = strings.TrimPrefix(entry.Path, credentialRoutePrefix)
@ -425,6 +426,9 @@ func (c *Core) setupCredentials() error {
backend, err = c.newCredentialBackend(entry.Type, sysView, view, conf)
if err != nil {
c.logger.Error("core: failed to create credential entry", "path", entry.Path, "error", err)
if errwrap.Contains(err, ErrPluginNotFound.Error()) && entry.Type == "plugin" {
goto ROUTER_MOUNT
}
return errLoadAuthFailed
}
if backend == nil {
@ -432,15 +436,14 @@ func (c *Core) setupCredentials() error {
}
// Check for the correct backend type
backendType := backend.Type()
if entry.Type == "plugin" && backendType != logical.TypeCredential {
return fmt.Errorf("cannot mount '%s' of type '%s' as an auth backend", entry.Config.PluginName, backendType)
if entry.Type == "plugin" && backend.Type() != logical.TypeCredential {
return fmt.Errorf("cannot mount '%s' of type '%s' as an auth backend", entry.Config.PluginName, backend.Type())
}
if err := backend.Initialize(); err != nil {
return err
}
ROUTER_MOUNT:
// Mount the backend
path := credentialRoutePrefix + entry.Path
err = c.router.Mount(backend, path, entry, view)

View File

@ -1369,9 +1369,6 @@ func (c *Core) postUnseal() (retErr error) {
if err := c.setupMounts(); err != nil {
return err
}
if err := c.startRollback(); err != nil {
return err
}
if err := c.setupPolicyStore(); err != nil {
return err
}
@ -1384,6 +1381,9 @@ func (c *Core) postUnseal() (retErr error) {
if err := c.setupCredentials(); err != nil {
return err
}
if err := c.startRollback(); err != nil {
return err
}
if err := c.setupExpiration(); err != nil {
return err
}

View File

@ -4,6 +4,8 @@ import (
"fmt"
"time"
"github.com/hashicorp/errwrap"
"github.com/hashicorp/vault/helper/consts"
"github.com/hashicorp/vault/helper/pluginutil"
"github.com/hashicorp/vault/helper/wrapping"
@ -132,7 +134,7 @@ func (d dynamicSystemView) LookupPlugin(name string) (*pluginutil.PluginRunner,
return nil, err
}
if r == nil {
return nil, fmt.Errorf("no plugin found with name: %s", name)
return nil, errwrap.Wrapf(fmt.Sprintf("{{err}}: %s", name), ErrPluginNotFound)
}
return r, nil

View File

@ -4,6 +4,7 @@ import (
"fmt"
"os"
"testing"
"time"
"github.com/hashicorp/vault/builtin/plugin"
"github.com/hashicorp/vault/helper/pluginutil"
@ -15,17 +16,196 @@ import (
)
func TestSystemBackend_Plugin_secret(t *testing.T) {
cluster := testSystemBackendMock(t, 2, logical.TypeLogical)
cluster := testSystemBackendMock(t, 1, 1, logical.TypeLogical)
defer cluster.Cleanup()
core := cluster.Cores[0]
// Make a request to lazy load the plugin
req := logical.TestRequest(t, logical.ReadOperation, "mock-0/internal")
req.ClientToken = core.Client.Token()
resp, err := core.HandleRequest(req)
if err != nil {
t.Fatalf("err: %v", err)
}
if resp == nil {
t.Fatalf("bad: response should not be nil")
}
// Seal the cluster
cluster.EnsureCoresSealed(t)
// Unseal the cluster
barrierKeys := cluster.BarrierKeys
for _, core := range cluster.Cores {
for _, key := range barrierKeys {
_, err := core.Unseal(vault.TestKeyCopy(key))
if err != nil {
t.Fatal(err)
}
}
sealed, err := core.Sealed()
if err != nil {
t.Fatalf("err checking seal status: %s", err)
}
if sealed {
t.Fatal("should not be sealed")
}
// Wait for active so post-unseal takes place
// If it fails, it means unseal process failed
vault.TestWaitActive(t, core.Core)
}
}
func TestSystemBackend_Plugin_auth(t *testing.T) {
cluster := testSystemBackendMock(t, 2, logical.TypeCredential)
cluster := testSystemBackendMock(t, 1, 1, logical.TypeCredential)
defer cluster.Cleanup()
core := cluster.Cores[0]
// Make a request to lazy load the plugin
req := logical.TestRequest(t, logical.ReadOperation, "auth/mock-0/internal")
req.ClientToken = core.Client.Token()
resp, err := core.HandleRequest(req)
if err != nil {
t.Fatalf("err: %v", err)
}
if resp == nil {
t.Fatalf("bad: response should not be nil")
}
// Seal the cluster
cluster.EnsureCoresSealed(t)
// Unseal the cluster
barrierKeys := cluster.BarrierKeys
for _, core := range cluster.Cores {
for _, key := range barrierKeys {
_, err := core.Unseal(vault.TestKeyCopy(key))
if err != nil {
t.Fatal(err)
}
}
sealed, err := core.Sealed()
if err != nil {
t.Fatalf("err checking seal status: %s", err)
}
if sealed {
t.Fatal("should not be sealed")
}
// Wait for active so post-unseal takes place
// If it fails, it means unseal process failed
vault.TestWaitActive(t, core.Core)
}
}
func TestSystemBackend_Plugin_MismatchType(t *testing.T) {
cluster := testSystemBackendMock(t, 1, 1, logical.TypeLogical)
defer cluster.Cleanup()
core := cluster.Cores[0]
// Replace the plugin with a credential backend
vault.TestAddTestPlugin(t, core.Core, "mock-plugin", "TestBackend_PluginMainCredentials")
// Make a request to lazy load the now-credential plugin
// and expect an error
req := logical.TestRequest(t, logical.ReadOperation, "mock-0/internal")
req.ClientToken = core.Client.Token()
_, err := core.HandleRequest(req)
if err == nil {
t.Fatalf("expected error due to mismatch on error type: %s", err)
}
// Sleep a bit before cleanup is called
time.Sleep(1 * time.Second)
}
func TestSystemBackend_Plugin_CatalogRemoved(t *testing.T) {
t.Run("secret", func(t *testing.T) {
testPlugin_CatalogRemoved(t, logical.TypeLogical, false)
})
t.Run("auth", func(t *testing.T) {
testPlugin_CatalogRemoved(t, logical.TypeCredential, false)
})
t.Run("secret-mount-existing", func(t *testing.T) {
testPlugin_CatalogRemoved(t, logical.TypeLogical, true)
})
t.Run("auth-mount-existing", func(t *testing.T) {
testPlugin_CatalogRemoved(t, logical.TypeCredential, true)
})
}
func testPlugin_CatalogRemoved(t *testing.T, btype logical.BackendType, testMount bool) {
cluster := testSystemBackendMock(t, 1, 1, btype)
defer cluster.Cleanup()
core := cluster.Cores[0]
// Remove the plugin from the catalog
req := logical.TestRequest(t, logical.DeleteOperation, "sys/plugins/catalog/mock-plugin")
req.ClientToken = core.Client.Token()
resp, err := core.HandleRequest(req)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
// Seal the cluster
cluster.EnsureCoresSealed(t)
// Unseal the cluster
barrierKeys := cluster.BarrierKeys
for _, core := range cluster.Cores {
for _, key := range barrierKeys {
_, err := core.Unseal(vault.TestKeyCopy(key))
if err != nil {
t.Fatal(err)
}
}
sealed, err := core.Sealed()
if err != nil {
t.Fatalf("err checking seal status: %s", err)
}
if sealed {
t.Fatal("should not be sealed")
}
// Wait for active so post-unseal takes place
// If it fails, it means unseal process failed
vault.TestWaitActive(t, core.Core)
}
if testMount {
// Add plugin back to the catalog
vault.TestAddTestPlugin(t, core.Core, "mock-plugin", "TestBackend_PluginMainLogical")
// Mount the plugin at the same path after plugin is re-added to the catalog
// and expect an error due to existing path.
var err error
switch btype {
case logical.TypeLogical:
_, err = core.Client.Logical().Write("sys/mounts/mock-0", map[string]interface{}{
"type": "plugin",
"config": map[string]interface{}{
"plugin_name": "mock-plugin",
},
})
case logical.TypeCredential:
_, err = core.Client.Logical().Write("sys/auth/mock-0", map[string]interface{}{
"type": "plugin",
"plugin_name": "mock-plugin",
})
}
if err == nil {
t.Fatal("expected error when mounting on existing path")
}
}
}
func TestSystemBackend_Plugin_autoReload(t *testing.T) {
cluster := testSystemBackendMock(t, 1, logical.TypeLogical)
cluster := testSystemBackendMock(t, 1, 1, logical.TypeLogical)
defer cluster.Cleanup()
core := cluster.Cores[0]
@ -65,6 +245,35 @@ func TestSystemBackend_Plugin_autoReload(t *testing.T) {
}
}
func TestSystemBackend_Plugin_SealUnseal(t *testing.T) {
cluster := testSystemBackendMock(t, 1, 1, logical.TypeLogical)
defer cluster.Cleanup()
// Seal the cluster
cluster.EnsureCoresSealed(t)
// Unseal the cluster
barrierKeys := cluster.BarrierKeys
for _, core := range cluster.Cores {
for _, key := range barrierKeys {
_, err := core.Unseal(vault.TestKeyCopy(key))
if err != nil {
t.Fatal(err)
}
}
sealed, err := core.Sealed()
if err != nil {
t.Fatalf("err checking seal status: %s", err)
}
if sealed {
t.Fatal("should not be sealed")
}
// Wait for active so post-unseal takes place
// If it fails, it means unseal process failed
vault.TestWaitActive(t, core.Core)
}
}
func TestSystemBackend_Plugin_reload(t *testing.T) {
data := map[string]interface{}{
"plugin": "mock-plugin",
@ -77,8 +286,9 @@ func TestSystemBackend_Plugin_reload(t *testing.T) {
t.Run("mounts", func(t *testing.T) { testSystemBackend_PluginReload(t, data) })
}
// Helper func to test different reload methods on plugin reload endpoint
func testSystemBackend_PluginReload(t *testing.T, reqData map[string]interface{}) {
cluster := testSystemBackendMock(t, 2, logical.TypeLogical)
cluster := testSystemBackendMock(t, 1, 2, logical.TypeLogical)
defer cluster.Cleanup()
core := cluster.Cores[0]
@ -123,7 +333,7 @@ func testSystemBackend_PluginReload(t *testing.T, reqData map[string]interface{}
// testSystemBackendMock returns a systemBackend with the desired number
// of mounted mock plugin backends
func testSystemBackendMock(t *testing.T, numMounts int, backendType logical.BackendType) *vault.TestCluster {
func testSystemBackendMock(t *testing.T, numCores, numMounts int, backendType logical.BackendType) *vault.TestCluster {
coreConfig := &vault.CoreConfig{
LogicalBackends: map[string]logical.Factory{
"plugin": plugin.Factory,
@ -134,7 +344,9 @@ func testSystemBackendMock(t *testing.T, numMounts int, backendType logical.Back
}
cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{
HandlerFunc: vaulthttp.Handler,
HandlerFunc: vaulthttp.Handler,
KeepStandbysSealed: true,
NumCores: numCores,
})
cluster.Start()
@ -197,7 +409,8 @@ func testSystemBackendMock(t *testing.T, numMounts int, backendType logical.Back
}
func TestBackend_PluginMainLogical(t *testing.T) {
if os.Getenv(pluginutil.PluginUnwrapTokenEnv) == "" {
args := []string{}
if os.Getenv(pluginutil.PluginUnwrapTokenEnv) == "" && os.Getenv(pluginutil.PluginMetadaModeEnv) != "true" {
return
}
@ -205,16 +418,16 @@ func TestBackend_PluginMainLogical(t *testing.T) {
if caPEM == "" {
t.Fatal("CA cert not passed in")
}
factoryFunc := mock.FactoryType(logical.TypeLogical)
args := []string{"--ca-cert=" + caPEM}
args = append(args, fmt.Sprintf("--ca-cert=%s", caPEM))
apiClientMeta := &pluginutil.APIClientMeta{}
flags := apiClientMeta.FlagSet()
flags.Parse(args)
tlsConfig := apiClientMeta.GetTLSConfig()
tlsProviderFunc := pluginutil.VaultPluginTLSProvider(tlsConfig)
factoryFunc := mock.FactoryType(logical.TypeLogical)
err := lplugin.Serve(&lplugin.ServeOpts{
BackendFactoryFunc: factoryFunc,
TLSProviderFunc: tlsProviderFunc,
@ -225,7 +438,8 @@ func TestBackend_PluginMainLogical(t *testing.T) {
}
func TestBackend_PluginMainCredentials(t *testing.T) {
if os.Getenv(pluginutil.PluginUnwrapTokenEnv) == "" {
args := []string{}
if os.Getenv(pluginutil.PluginUnwrapTokenEnv) == "" && os.Getenv(pluginutil.PluginMetadaModeEnv) != "true" {
return
}
@ -233,16 +447,16 @@ func TestBackend_PluginMainCredentials(t *testing.T) {
if caPEM == "" {
t.Fatal("CA cert not passed in")
}
factoryFunc := mock.FactoryType(logical.TypeCredential)
args := []string{"--ca-cert=" + caPEM}
args = append(args, fmt.Sprintf("--ca-cert=%s", caPEM))
apiClientMeta := &pluginutil.APIClientMeta{}
flags := apiClientMeta.FlagSet()
flags.Parse(args)
tlsConfig := apiClientMeta.GetTLSConfig()
tlsProviderFunc := pluginutil.VaultPluginTLSProvider(tlsConfig)
factoryFunc := mock.FactoryType(logical.TypeCredential)
err := lplugin.Serve(&lplugin.ServeOpts{
BackendFactoryFunc: factoryFunc,
TLSProviderFunc: tlsProviderFunc,

View File

@ -9,6 +9,7 @@ import (
"strings"
"time"
"github.com/hashicorp/errwrap"
"github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/helper/consts"
"github.com/hashicorp/vault/helper/jsonutil"
@ -663,11 +664,12 @@ func (c *Core) setupMounts() error {
c.mountsLock.Lock()
defer c.mountsLock.Unlock()
var backend logical.Backend
var view *BarrierView
var err error
for _, entry := range c.mounts.Entries {
var backend logical.Backend
// Initialize the backend, special casing for system
barrierPath := backendBarrierPrefix + entry.UUID + "/"
if entry.Type == "system" {
@ -686,6 +688,9 @@ func (c *Core) setupMounts() error {
backend, err = c.newLogicalBackend(entry.Type, sysView, view, conf)
if err != nil {
c.logger.Error("core: failed to create mount entry", "path", entry.Path, "error", err)
if errwrap.Contains(err, ErrPluginNotFound.Error()) && entry.Type == "plugin" {
goto ROUTER_MOUNT
}
return errLoadMountsFailed
}
if backend == nil {
@ -693,9 +698,8 @@ func (c *Core) setupMounts() error {
}
// Check for the correct backend type
backendType := backend.Type()
if entry.Type == "plugin" && backendType != logical.TypeLogical {
return fmt.Errorf("cannot mount '%s' of type '%s' as a logical backend", entry.Config.PluginName, backendType)
if entry.Type == "plugin" && backend.Type() != logical.TypeLogical {
return fmt.Errorf("cannot mount '%s' of type '%s' as a logical backend", entry.Config.PluginName, backend.Type())
}
if err := backend.Initialize(); err != nil {
@ -710,7 +714,7 @@ func (c *Core) setupMounts() error {
ch.saltUUID = entry.UUID
ch.storageView = view
}
ROUTER_MOUNT:
// Mount the backend
err = c.router.Mount(backend, entry.Path, entry, view)
if err != nil {

View File

@ -19,6 +19,7 @@ import (
var (
pluginCatalogPath = "core/plugin-catalog/"
ErrDirectoryNotConfigured = errors.New("could not set plugin, plugin directory is not configured")
ErrPluginNotFound = errors.New("plugin not found in the catalog")
)
// PluginCatalog keeps a record of plugins known to vault. External plugins need
@ -37,6 +38,10 @@ func (c *Core) setupPluginCatalog() error {
directory: c.pluginDirectory,
}
if c.logger.IsInfo() {
c.logger.Info("core: successfully setup plugin catalog", "plugin-directory", c.pluginDirectory)
}
return nil
}

View File

@ -64,9 +64,12 @@ func (r *Router) Mount(backend logical.Backend, prefix string, mountEntry *Mount
}
// Build the paths
paths := backend.SpecialPaths()
if paths == nil {
paths = new(logical.Paths)
paths := new(logical.Paths)
if backend != nil {
specialPaths := backend.SpecialPaths()
if specialPaths != nil {
paths = specialPaths
}
}
// Create a mount entry

View File

@ -335,11 +335,17 @@ func TestAddTestPlugin(t testing.T, c *Core, name, testFunc string) {
}
sum := hash.Sum(nil)
c.pluginCatalog.directory, err = filepath.EvalSymlinks(os.Args[0])
// Determine plugin directory path
fullPath, err := filepath.EvalSymlinks(os.Args[0])
if err != nil {
t.Fatal(err)
}
c.pluginCatalog.directory = filepath.Dir(c.pluginCatalog.directory)
directoryPath := filepath.Dir(fullPath)
// Set core's plugin directory and plugin catalog directory
c.pluginDirectory = directoryPath
c.pluginCatalog.directory = directoryPath
command := fmt.Sprintf("%s --test.run=%s", filepath.Base(os.Args[0]), testFunc)
err = c.pluginCatalog.Set(name, command, sum)
@ -585,6 +591,7 @@ func GenerateRandBytes(length int) ([]byte, error) {
}
func TestWaitActive(t testing.T, core *Core) {
t.Helper()
start := time.Now()
var standby bool
var err error
@ -627,6 +634,13 @@ func (c *TestCluster) Start() {
}
}
func (c *TestCluster) EnsureCoresSealed(t testing.T) {
t.Helper()
if err := c.ensureCoresSealed(); err != nil {
t.Fatal(err)
}
}
func (c *TestCluster) Cleanup() {
// Close listeners
for _, core := range c.Cores {
@ -638,25 +652,7 @@ func (c *TestCluster) Cleanup() {
}
// Seal the cores
for _, core := range c.Cores {
if err := core.Shutdown(); err != nil {
continue
}
timeout := time.Now().Add(60 * time.Second)
for {
if time.Now().After(timeout) {
continue
}
sealed, err := core.Sealed()
if err != nil {
continue
}
if sealed {
break
}
time.Sleep(250 * time.Millisecond)
}
}
c.ensureCoresSealed()
// Remove any temp dir that exists
if c.TempDir != "" {
@ -667,6 +663,29 @@ func (c *TestCluster) Cleanup() {
time.Sleep(time.Second)
}
func (c *TestCluster) ensureCoresSealed() error {
for _, core := range c.Cores {
if err := core.Shutdown(); err != nil {
return err
}
timeout := time.Now().Add(60 * time.Second)
for {
if time.Now().After(timeout) {
return fmt.Errorf("timeout waiting for core to seal")
}
sealed, err := core.Sealed()
if err != nil {
return err
}
if sealed {
break
}
time.Sleep(250 * time.Millisecond)
}
}
return nil
}
type TestListener struct {
net.Listener
Address *net.TCPAddr
@ -692,9 +711,29 @@ type TestClusterOptions struct {
KeepStandbysSealed bool
HandlerFunc func(*Core) http.Handler
BaseListenAddress string
NumCores int
}
var DefaultNumCores = 3
type certInfo struct {
cert *x509.Certificate
certPEM []byte
certBytes []byte
key *ecdsa.PrivateKey
keyPEM []byte
}
// NewTestCluster creates a new test cluster based on the provided core config
// and test cluster options.
func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *TestCluster {
var numCores int
if opts == nil || opts.NumCores == 0 {
numCores = DefaultNumCores
} else {
numCores = opts.NumCores
}
certIPs := []net.IP{
net.IPv6loopback,
net.ParseIP("127.0.0.1"),
@ -770,270 +809,131 @@ func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *Te
t.Fatal(err)
}
s1Key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
t.Fatal(err)
}
s1CertTemplate := &x509.Certificate{
Subject: pkix.Name{
CommonName: "localhost",
},
DNSNames: []string{"localhost"},
IPAddresses: certIPs,
ExtKeyUsage: []x509.ExtKeyUsage{
x509.ExtKeyUsageServerAuth,
x509.ExtKeyUsageClientAuth,
},
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement,
SerialNumber: big.NewInt(mathrand.Int63()),
NotBefore: time.Now().Add(-30 * time.Second),
NotAfter: time.Now().Add(262980 * time.Hour),
}
s1CertBytes, err := x509.CreateCertificate(rand.Reader, s1CertTemplate, caCert, s1Key.Public(), caKey)
if err != nil {
t.Fatal(err)
}
s1Cert, err := x509.ParseCertificate(s1CertBytes)
if err != nil {
t.Fatal(err)
}
s1CertPEMBlock := &pem.Block{
Type: "CERTIFICATE",
Bytes: s1CertBytes,
}
s1CertPEM := pem.EncodeToMemory(s1CertPEMBlock)
s1MarshaledKey, err := x509.MarshalECPrivateKey(s1Key)
if err != nil {
t.Fatal(err)
}
s1KeyPEMBlock := &pem.Block{
Type: "EC PRIVATE KEY",
Bytes: s1MarshaledKey,
}
s1KeyPEM := pem.EncodeToMemory(s1KeyPEMBlock)
var certInfoSlice []*certInfo
s2Key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
t.Fatal(err)
}
s2CertTemplate := &x509.Certificate{
Subject: pkix.Name{
CommonName: "localhost",
},
DNSNames: []string{"localhost"},
IPAddresses: certIPs,
ExtKeyUsage: []x509.ExtKeyUsage{
x509.ExtKeyUsageServerAuth,
x509.ExtKeyUsageClientAuth,
},
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement,
SerialNumber: big.NewInt(mathrand.Int63()),
NotBefore: time.Now().Add(-30 * time.Second),
NotAfter: time.Now().Add(262980 * time.Hour),
}
s2CertBytes, err := x509.CreateCertificate(rand.Reader, s2CertTemplate, caCert, s2Key.Public(), caKey)
if err != nil {
t.Fatal(err)
}
s2Cert, err := x509.ParseCertificate(s2CertBytes)
if err != nil {
t.Fatal(err)
}
s2CertPEMBlock := &pem.Block{
Type: "CERTIFICATE",
Bytes: s2CertBytes,
}
s2CertPEM := pem.EncodeToMemory(s2CertPEMBlock)
s2MarshaledKey, err := x509.MarshalECPrivateKey(s2Key)
if err != nil {
t.Fatal(err)
}
s2KeyPEMBlock := &pem.Block{
Type: "EC PRIVATE KEY",
Bytes: s2MarshaledKey,
}
s2KeyPEM := pem.EncodeToMemory(s2KeyPEMBlock)
//
// Certs generation
//
for i := 0; i < numCores; i++ {
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
t.Fatal(err)
}
certTemplate := &x509.Certificate{
Subject: pkix.Name{
CommonName: "localhost",
},
DNSNames: []string{"localhost"},
IPAddresses: certIPs,
ExtKeyUsage: []x509.ExtKeyUsage{
x509.ExtKeyUsageServerAuth,
x509.ExtKeyUsageClientAuth,
},
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement,
SerialNumber: big.NewInt(mathrand.Int63()),
NotBefore: time.Now().Add(-30 * time.Second),
NotAfter: time.Now().Add(262980 * time.Hour),
}
certBytes, err := x509.CreateCertificate(rand.Reader, certTemplate, caCert, key.Public(), caKey)
if err != nil {
t.Fatal(err)
}
cert, err := x509.ParseCertificate(certBytes)
if err != nil {
t.Fatal(err)
}
certPEMBlock := &pem.Block{
Type: "CERTIFICATE",
Bytes: certBytes,
}
certPEM := pem.EncodeToMemory(certPEMBlock)
marshaledKey, err := x509.MarshalECPrivateKey(key)
if err != nil {
t.Fatal(err)
}
keyPEMBlock := &pem.Block{
Type: "EC PRIVATE KEY",
Bytes: marshaledKey,
}
keyPEM := pem.EncodeToMemory(keyPEMBlock)
s3Key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
t.Fatal(err)
certInfoSlice = append(certInfoSlice, &certInfo{
cert: cert,
certPEM: certPEM,
certBytes: certBytes,
key: key,
keyPEM: keyPEM,
})
}
s3CertTemplate := &x509.Certificate{
Subject: pkix.Name{
CommonName: "localhost",
},
DNSNames: []string{"localhost"},
IPAddresses: certIPs,
ExtKeyUsage: []x509.ExtKeyUsage{
x509.ExtKeyUsageServerAuth,
x509.ExtKeyUsageClientAuth,
},
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement,
SerialNumber: big.NewInt(mathrand.Int63()),
NotBefore: time.Now().Add(-30 * time.Second),
NotAfter: time.Now().Add(262980 * time.Hour),
}
s3CertBytes, err := x509.CreateCertificate(rand.Reader, s3CertTemplate, caCert, s3Key.Public(), caKey)
if err != nil {
t.Fatal(err)
}
s3Cert, err := x509.ParseCertificate(s3CertBytes)
if err != nil {
t.Fatal(err)
}
s3CertPEMBlock := &pem.Block{
Type: "CERTIFICATE",
Bytes: s3CertBytes,
}
s3CertPEM := pem.EncodeToMemory(s3CertPEMBlock)
s3MarshaledKey, err := x509.MarshalECPrivateKey(s3Key)
if err != nil {
t.Fatal(err)
}
s3KeyPEMBlock := &pem.Block{
Type: "EC PRIVATE KEY",
Bytes: s3MarshaledKey,
}
s3KeyPEM := pem.EncodeToMemory(s3KeyPEMBlock)
logger := logformat.NewVaultLogger(log.LevelTrace)
//
// Listener setup
//
ports := []int{0, 0, 0}
logger := logformat.NewVaultLogger(log.LevelTrace)
ports := make([]int, numCores)
if baseAddr != nil {
ports = []int{baseAddr.Port, baseAddr.Port + 1, baseAddr.Port + 2}
for i := 0; i < numCores; i++ {
ports[i] = baseAddr.Port + i
}
} else {
baseAddr = &net.TCPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 0,
}
}
baseAddr.Port = ports[0]
ln, err := net.ListenTCP("tcp", baseAddr)
if err != nil {
t.Fatal(err)
}
s1CertFile := filepath.Join(testCluster.TempDir, fmt.Sprintf("node1_port_%d_cert.pem", ln.Addr().(*net.TCPAddr).Port))
s1KeyFile := filepath.Join(testCluster.TempDir, fmt.Sprintf("node1_port_%d_key.pem", ln.Addr().(*net.TCPAddr).Port))
err = ioutil.WriteFile(s1CertFile, s1CertPEM, 0755)
if err != nil {
t.Fatal(err)
}
err = ioutil.WriteFile(s1KeyFile, s1KeyPEM, 0755)
if err != nil {
t.Fatal(err)
}
s1TLSCert, err := tls.X509KeyPair(s1CertPEM, s1KeyPEM)
if err != nil {
t.Fatal(err)
}
s1CertGetter := reload.NewCertificateGetter(s1CertFile, s1KeyFile)
s1TLSConfig := &tls.Config{
Certificates: []tls.Certificate{s1TLSCert},
RootCAs: testCluster.RootCAs,
ClientCAs: testCluster.RootCAs,
ClientAuth: tls.VerifyClientCertIfGiven,
NextProtos: []string{"h2", "http/1.1"},
GetCertificate: s1CertGetter.GetCertificate,
}
s1TLSConfig.BuildNameToCertificate()
c1lns := []*TestListener{&TestListener{
Listener: tls.NewListener(ln, s1TLSConfig),
Address: ln.Addr().(*net.TCPAddr),
},
}
var handler1 http.Handler = http.NewServeMux()
server1 := &http.Server{
Handler: handler1,
}
if err := http2.ConfigureServer(server1, nil); err != nil {
t.Fatal(err)
}
baseAddr.Port = ports[1]
ln, err = net.ListenTCP("tcp", baseAddr)
if err != nil {
t.Fatal(err)
}
s2CertFile := filepath.Join(testCluster.TempDir, fmt.Sprintf("node2_port_%d_cert.pem", ln.Addr().(*net.TCPAddr).Port))
s2KeyFile := filepath.Join(testCluster.TempDir, fmt.Sprintf("node2_port_%d_key.pem", ln.Addr().(*net.TCPAddr).Port))
err = ioutil.WriteFile(s2CertFile, s2CertPEM, 0755)
if err != nil {
t.Fatal(err)
}
err = ioutil.WriteFile(s2KeyFile, s2KeyPEM, 0755)
if err != nil {
t.Fatal(err)
}
s2TLSCert, err := tls.X509KeyPair(s2CertPEM, s2KeyPEM)
if err != nil {
t.Fatal(err)
}
s2CertGetter := reload.NewCertificateGetter(s2CertFile, s2KeyFile)
s2TLSConfig := &tls.Config{
Certificates: []tls.Certificate{s2TLSCert},
RootCAs: testCluster.RootCAs,
ClientCAs: testCluster.RootCAs,
ClientAuth: tls.VerifyClientCertIfGiven,
NextProtos: []string{"h2", "http/1.1"},
GetCertificate: s2CertGetter.GetCertificate,
}
s2TLSConfig.BuildNameToCertificate()
c2lns := []*TestListener{&TestListener{
Listener: tls.NewListener(ln, s2TLSConfig),
Address: ln.Addr().(*net.TCPAddr),
},
}
var handler2 http.Handler = http.NewServeMux()
server2 := &http.Server{
Handler: handler2,
}
if err := http2.ConfigureServer(server2, nil); err != nil {
t.Fatal(err)
}
baseAddr.Port = ports[2]
ln, err = net.ListenTCP("tcp", baseAddr)
if err != nil {
t.Fatal(err)
}
s3CertFile := filepath.Join(testCluster.TempDir, fmt.Sprintf("node3_port_%d_cert.pem", ln.Addr().(*net.TCPAddr).Port))
s3KeyFile := filepath.Join(testCluster.TempDir, fmt.Sprintf("node3_port_%d_key.pem", ln.Addr().(*net.TCPAddr).Port))
err = ioutil.WriteFile(s3CertFile, s3CertPEM, 0755)
if err != nil {
t.Fatal(err)
}
err = ioutil.WriteFile(s3KeyFile, s3KeyPEM, 0755)
if err != nil {
t.Fatal(err)
}
s3TLSCert, err := tls.X509KeyPair(s3CertPEM, s3KeyPEM)
if err != nil {
t.Fatal(err)
}
s3CertGetter := reload.NewCertificateGetter(s3CertFile, s3KeyFile)
s3TLSConfig := &tls.Config{
Certificates: []tls.Certificate{s3TLSCert},
RootCAs: testCluster.RootCAs,
ClientCAs: testCluster.RootCAs,
ClientAuth: tls.VerifyClientCertIfGiven,
NextProtos: []string{"h2", "http/1.1"},
GetCertificate: s3CertGetter.GetCertificate,
}
s3TLSConfig.BuildNameToCertificate()
c3lns := []*TestListener{&TestListener{
Listener: tls.NewListener(ln, s3TLSConfig),
Address: ln.Addr().(*net.TCPAddr),
},
}
var handler3 http.Handler = http.NewServeMux()
server3 := &http.Server{
Handler: handler3,
}
if err := http2.ConfigureServer(server3, nil); err != nil {
t.Fatal(err)
listeners := [][]*TestListener{}
servers := []*http.Server{}
handlers := []http.Handler{}
tlsConfigs := []*tls.Config{}
certGetters := []*reload.CertificateGetter{}
for i := 0; i < numCores; i++ {
baseAddr.Port = ports[i]
ln, err := net.ListenTCP("tcp", baseAddr)
if err != nil {
t.Fatal(err)
}
certFile := filepath.Join(testCluster.TempDir, fmt.Sprintf("node%d_port_%d_cert.pem", i+1, ln.Addr().(*net.TCPAddr).Port))
keyFile := filepath.Join(testCluster.TempDir, fmt.Sprintf("node%d_port_%d_key.pem", i+1, ln.Addr().(*net.TCPAddr).Port))
err = ioutil.WriteFile(certFile, certInfoSlice[i].certPEM, 0755)
if err != nil {
t.Fatal(err)
}
err = ioutil.WriteFile(keyFile, certInfoSlice[i].keyPEM, 0755)
if err != nil {
t.Fatal(err)
}
tlsCert, err := tls.X509KeyPair(certInfoSlice[i].certPEM, certInfoSlice[i].keyPEM)
if err != nil {
t.Fatal(err)
}
certGetter := reload.NewCertificateGetter(certFile, keyFile)
certGetters = append(certGetters, certGetter)
tlsConfig := &tls.Config{
Certificates: []tls.Certificate{tlsCert},
RootCAs: testCluster.RootCAs,
ClientCAs: testCluster.RootCAs,
ClientAuth: tls.VerifyClientCertIfGiven,
NextProtos: []string{"h2", "http/1.1"},
GetCertificate: certGetter.GetCertificate,
}
tlsConfig.BuildNameToCertificate()
tlsConfigs = append(tlsConfigs, tlsConfig)
lns := []*TestListener{&TestListener{
Listener: tls.NewListener(ln, tlsConfig),
Address: ln.Addr().(*net.TCPAddr),
},
}
listeners = append(listeners, lns)
var handler http.Handler = http.NewServeMux()
handlers = append(handlers, handler)
server := &http.Server{
Handler: handler,
}
servers = append(servers, server)
if err := http2.ConfigureServer(server, nil); err != nil {
t.Fatal(err)
}
}
// Create three cores with the same physical and different redirect/cluster
@ -1049,8 +949,8 @@ func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *Te
LogicalBackends: make(map[string]logical.Factory),
CredentialBackends: make(map[string]logical.Factory),
AuditBackends: make(map[string]audit.Factory),
RedirectAddr: fmt.Sprintf("https://127.0.0.1:%d", c1lns[0].Address.Port),
ClusterAddr: fmt.Sprintf("https://127.0.0.1:%d", c1lns[0].Address.Port+105),
RedirectAddr: fmt.Sprintf("https://127.0.0.1:%d", listeners[0][0].Address.Port),
ClusterAddr: fmt.Sprintf("https://127.0.0.1:%d", listeners[0][0].Address.Port+105),
DisableMlock: true,
EnableUI: true,
}
@ -1126,39 +1026,21 @@ func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *Te
coreConfig.HAPhysical = haPhys.(physical.HABackend)
}
c1, err := NewCore(coreConfig)
if err != nil {
t.Fatalf("err: %v", err)
}
if opts != nil && opts.HandlerFunc != nil {
handler1 = opts.HandlerFunc(c1)
server1.Handler = handler1
}
coreConfig.RedirectAddr = fmt.Sprintf("https://127.0.0.1:%d", c2lns[0].Address.Port)
if coreConfig.ClusterAddr != "" {
coreConfig.ClusterAddr = fmt.Sprintf("https://127.0.0.1:%d", c2lns[0].Address.Port+105)
}
c2, err := NewCore(coreConfig)
if err != nil {
t.Fatalf("err: %v", err)
}
if opts != nil && opts.HandlerFunc != nil {
handler2 = opts.HandlerFunc(c2)
server2.Handler = handler2
}
coreConfig.RedirectAddr = fmt.Sprintf("https://127.0.0.1:%d", c3lns[0].Address.Port)
if coreConfig.ClusterAddr != "" {
coreConfig.ClusterAddr = fmt.Sprintf("https://127.0.0.1:%d", c3lns[0].Address.Port+105)
}
c3, err := NewCore(coreConfig)
if err != nil {
t.Fatalf("err: %v", err)
}
if opts != nil && opts.HandlerFunc != nil {
handler3 = opts.HandlerFunc(c3)
server3.Handler = handler3
cores := []*Core{}
for i := 0; i < numCores; i++ {
coreConfig.RedirectAddr = fmt.Sprintf("https://127.0.0.1:%d", listeners[i][0].Address.Port)
if coreConfig.ClusterAddr != "" {
coreConfig.ClusterAddr = fmt.Sprintf("https://127.0.0.1:%d", listeners[i][0].Address.Port+105)
}
c, err := NewCore(coreConfig)
if err != nil {
t.Fatalf("err: %v", err)
}
cores = append(cores, c)
if opts != nil && opts.HandlerFunc != nil {
handlers[i] = opts.HandlerFunc(c)
servers[i].Handler = handlers[i]
}
}
//
@ -1175,16 +1057,19 @@ func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *Te
return ret
}
c2.SetClusterListenerAddrs(clusterAddrGen(c2lns))
c2.SetClusterHandler(handler2)
c3.SetClusterListenerAddrs(clusterAddrGen(c3lns))
c3.SetClusterHandler(handler3)
if numCores > 1 {
for i := 1; i < numCores; i++ {
cores[i].SetClusterListenerAddrs(clusterAddrGen(listeners[i]))
cores[i].SetClusterHandler(handlers[i])
}
}
keys, root := TestCoreInitClusterWrapperSetup(t, c1, clusterAddrGen(c1lns), handler1)
keys, root := TestCoreInitClusterWrapperSetup(t, cores[0], clusterAddrGen(listeners[0]), handlers[0])
barrierKeys, _ := copystructure.Copy(keys)
testCluster.BarrierKeys = barrierKeys.([][]byte)
testCluster.RootToken = root
// Write root token and barrier keys
err = ioutil.WriteFile(filepath.Join(testCluster.TempDir, "root_token"), []byte(root), 0755)
if err != nil {
t.Fatal(err)
@ -1201,14 +1086,15 @@ func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *Te
t.Fatal(err)
}
// Unseal first core
for _, key := range keys {
if _, err := c1.Unseal(TestKeyCopy(key)); err != nil {
if _, err := cores[0].Unseal(TestKeyCopy(key)); err != nil {
t.Fatalf("unseal err: %s", err)
}
}
// Verify unsealed
sealed, err := c1.Sealed()
sealed, err := cores[0].Sealed()
if err != nil {
t.Fatalf("err checking seal status: %s", err)
}
@ -1216,41 +1102,38 @@ func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *Te
t.Fatal("should not be sealed")
}
TestWaitActive(t, c1)
TestWaitActive(t, cores[0])
if opts == nil || !opts.KeepStandbysSealed {
for _, key := range keys {
if _, err := c2.Unseal(TestKeyCopy(key)); err != nil {
t.Fatalf("unseal err: %s", err)
}
}
for _, key := range keys {
if _, err := c3.Unseal(TestKeyCopy(key)); err != nil {
t.Fatalf("unseal err: %s", err)
// Unseal other cores unless otherwise specified
if (opts == nil || !opts.KeepStandbysSealed) && numCores > 1 {
for i := 1; i < numCores; i++ {
for _, key := range keys {
if _, err := cores[i].Unseal(TestKeyCopy(key)); err != nil {
t.Fatalf("unseal err: %s", err)
}
}
}
// Let them come fully up to standby
time.Sleep(2 * time.Second)
// Ensure cluster connection info is populated
isLeader, _, _, err := c2.Leader()
if err != nil {
t.Fatal(err)
}
if isLeader {
t.Fatal("c2 should not be leader")
}
isLeader, _, _, err = c3.Leader()
if err != nil {
t.Fatal(err)
}
if isLeader {
t.Fatal("c3 should not be leader")
// Ensure cluster connection info is populated.
// Other cores should not come up as leaders.
for i := 1; i < numCores; i++ {
isLeader, _, _, err := cores[i].Leader()
if err != nil {
t.Fatal(err)
}
if isLeader {
t.Fatalf("core[%d] should not be leader", i)
}
}
}
cluster, err := c1.Cluster()
//
// Set test cluster core(s) and test cluster
//
cluster, err := cores[0].Cluster()
if err != nil {
t.Fatal(err)
}
@ -1278,65 +1161,27 @@ func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *Te
}
var ret []*TestClusterCore
t1 := &TestClusterCore{
Core: c1,
ServerKey: s1Key,
ServerKeyPEM: s1KeyPEM,
ServerCert: s1Cert,
ServerCertBytes: s1CertBytes,
ServerCertPEM: s1CertPEM,
Listeners: c1lns,
Handler: handler1,
Server: server1,
TLSConfig: s1TLSConfig,
Client: getAPIClient(c1lns[0].Address.Port, s1TLSConfig),
for i := 0; i < numCores; i++ {
tcc := &TestClusterCore{
Core: cores[i],
ServerKey: certInfoSlice[i].key,
ServerKeyPEM: certInfoSlice[i].keyPEM,
ServerCert: certInfoSlice[i].cert,
ServerCertBytes: certInfoSlice[i].certBytes,
ServerCertPEM: certInfoSlice[i].certPEM,
Listeners: listeners[i],
Handler: handlers[i],
Server: servers[i],
TLSConfig: tlsConfigs[i],
Client: getAPIClient(listeners[i][0].Address.Port, tlsConfigs[i]),
}
tcc.ReloadFuncs = &cores[i].reloadFuncs
tcc.ReloadFuncsLock = &cores[i].reloadFuncsLock
tcc.ReloadFuncsLock.Lock()
(*tcc.ReloadFuncs)["listener|tcp"] = []reload.ReloadFunc{certGetters[i].Reload}
tcc.ReloadFuncsLock.Unlock()
ret = append(ret, tcc)
}
t1.ReloadFuncs = &c1.reloadFuncs
t1.ReloadFuncsLock = &c1.reloadFuncsLock
t1.ReloadFuncsLock.Lock()
(*t1.ReloadFuncs)["listener|tcp"] = []reload.ReloadFunc{s1CertGetter.Reload}
t1.ReloadFuncsLock.Unlock()
ret = append(ret, t1)
t2 := &TestClusterCore{
Core: c2,
ServerKey: s2Key,
ServerKeyPEM: s2KeyPEM,
ServerCert: s2Cert,
ServerCertBytes: s2CertBytes,
ServerCertPEM: s2CertPEM,
Listeners: c2lns,
Handler: handler2,
Server: server2,
TLSConfig: s2TLSConfig,
Client: getAPIClient(c2lns[0].Address.Port, s2TLSConfig),
}
t2.ReloadFuncs = &c2.reloadFuncs
t2.ReloadFuncsLock = &c2.reloadFuncsLock
t2.ReloadFuncsLock.Lock()
(*t2.ReloadFuncs)["listener|tcp"] = []reload.ReloadFunc{s2CertGetter.Reload}
t2.ReloadFuncsLock.Unlock()
ret = append(ret, t2)
t3 := &TestClusterCore{
Core: c3,
ServerKey: s3Key,
ServerKeyPEM: s3KeyPEM,
ServerCert: s3Cert,
ServerCertBytes: s3CertBytes,
ServerCertPEM: s3CertPEM,
Listeners: c3lns,
Handler: handler3,
Server: server3,
TLSConfig: s3TLSConfig,
Client: getAPIClient(c3lns[0].Address.Port, s3TLSConfig),
}
t3.ReloadFuncs = &c3.reloadFuncs
t3.ReloadFuncsLock = &c3.reloadFuncsLock
t3.ReloadFuncsLock.Lock()
(*t3.ReloadFuncs)["listener|tcp"] = []reload.ReloadFunc{s3CertGetter.Reload}
t3.ReloadFuncsLock.Unlock()
ret = append(ret, t3)
testCluster.Cores = ret
return &testCluster