Check if plugin version matches running version (#17182)

Check if plugin version matches running version

When registering a plugin, we check if the request version matches the
self-reported version from the plugin. If these do not match, we log a
warning.

This uncovered a few missing pieces for getting the database version
code fully working.

We added an environment variable that helps us unit test the running
version behavior as well, but only for approle, postgresql, and consul
plugins.

Return 400 on plugin not found or version mismatch

Populate the running SHA256 of plugins in the mount and auth tables (#17217)
This commit is contained in:
Christopher Swenson 2022-09-21 12:25:04 -07:00 committed by GitHub
parent dc3beb428e
commit 2c8e88ab67
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 394 additions and 72 deletions

View File

@ -18,6 +18,9 @@ const (
secretIDAccessorLocalPrefix = "accessor_local/"
)
// ReportedVersion is used to report a specific version to Vault.
var ReportedVersion = ""
type backend struct {
*framework.Backend
@ -111,8 +114,9 @@ func Backend(conf *logical.BackendConfig) (*backend, error) {
pathTidySecretID(b),
},
),
Invalidate: b.invalidate,
BackendType: logical.TypeCredential,
Invalidate: b.invalidate,
BackendType: logical.TypeCredential,
RunningVersion: ReportedVersion,
}
return b, nil
}

View File

@ -7,6 +7,9 @@ import (
"github.com/hashicorp/vault/sdk/logical"
)
// ReportedVersion is used to report a specific version to Vault.
var ReportedVersion = ""
func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
b := Backend()
if err := b.Setup(ctx, conf); err != nil {
@ -34,7 +37,8 @@ func Backend() *backend {
Secrets: []*framework.Secret{
secretToken(&b),
},
BackendType: logical.TypeLogical,
BackendType: logical.TypeLogical,
RunningVersion: ReportedVersion,
}
return &b

View File

@ -14,6 +14,7 @@ import (
"github.com/hashicorp/vault/sdk/database/helper/dbutil"
"github.com/hashicorp/vault/sdk/helper/dbtxn"
"github.com/hashicorp/vault/sdk/helper/template"
"github.com/hashicorp/vault/sdk/logical"
_ "github.com/jackc/pgx/v4/stdlib"
)
@ -32,7 +33,8 @@ ALTER ROLE "{{username}}" WITH PASSWORD '{{password}}';
)
var (
_ dbplugin.Database = &PostgreSQL{}
_ dbplugin.Database = (*PostgreSQL)(nil)
_ logical.PluginVersioner = (*PostgreSQL)(nil)
// postgresEndStatement is basically the word "END" but
// surrounded by a word boundary to differentiate it from
@ -46,6 +48,9 @@ var (
// singleQuotedPhrases finds substrings like 'hello'
// and pulls them out with the quotes included.
singleQuotedPhrases = regexp.MustCompile(`('.*?')`)
// ReportedVersion is used to report a specific version to Vault.
ReportedVersion = ""
)
func New() (interface{}, error) {
@ -469,6 +474,10 @@ func (p *PostgreSQL) secretValues() map[string]string {
}
}
func (p *PostgreSQL) PluginVersion() logical.PluginVersion {
return logical.PluginVersion{Version: ReportedVersion}
}
// containsMultilineStatement is a best effort to determine whether
// a particular statement is multiline, and therefore should not be
// split upon semicolons. If it's unsure, it defaults to false.

View File

@ -2,12 +2,14 @@ package dbplugin
import (
"context"
"errors"
"fmt"
"sync"
"time"
"github.com/golang/protobuf/ptypes"
"github.com/hashicorp/vault/sdk/database/dbplugin/v5/proto"
"github.com/hashicorp/vault/sdk/helper/base62"
"github.com/hashicorp/vault/sdk/helper/pluginutil"
"github.com/hashicorp/vault/sdk/logical"
"google.golang.org/grpc/codes"
@ -43,11 +45,14 @@ func (g *gRPCServer) getOrCreateDatabase(ctx context.Context) (Database, error)
if err != nil {
return nil, err
}
if db, ok := g.instances[id]; ok {
return db, nil
}
return g.createDatabase(id)
}
// must hold the g.Lock() to call this function
func (g *gRPCServer) createDatabase(id string) (Database, error) {
db, err := g.factoryFunc()
if err != nil {
return nil, err
@ -304,12 +309,36 @@ func (g *gRPCServer) Close(ctx context.Context, _ *proto.Empty) (*proto.Empty, e
return &proto.Empty{}, nil
}
// getOrForceCreateDatabase will create a database even if the multiplexing ID is not present
func (g *gRPCServer) getOrForceCreateDatabase(ctx context.Context) (Database, error) {
impl, err := g.getOrCreateDatabase(ctx)
if errors.Is(err, pluginutil.ErrNoMultiplexingIDFound) {
// if this is called without a multiplexing context, like from the plugin catalog directly,
// then we won't have a database ID, so let's generate a new database instance
id, err := base62.Random(10)
if err != nil {
return nil, err
}
g.Lock()
defer g.Unlock()
impl, err = g.createDatabase(id)
if err != nil {
return nil, err
}
} else if err != nil {
return nil, err
}
return impl, nil
}
// Version forwards the version request to the underlying Database implementation.
func (g *gRPCServer) Version(ctx context.Context, _ *logical.Empty) (*logical.VersionReply, error) {
impl, err := g.getDatabaseInternal(ctx)
impl, err := g.getOrForceCreateDatabase(ctx)
if err != nil {
return nil, err
}
if versioner, ok := impl.(logical.PluginVersioner); ok {
return &logical.VersionReply{PluginVersion: versioner.PluginVersion().Version}, nil
}

View File

@ -233,7 +233,10 @@ func (mw databaseMetricsMiddleware) Close() (err error) {
// Error Sanitizer Middleware Domain
// ///////////////////////////////////////////////////
var _ Database = DatabaseErrorSanitizerMiddleware{}
var (
_ Database = (*DatabaseErrorSanitizerMiddleware)(nil)
_ logical.PluginVersioner = (*DatabaseErrorSanitizerMiddleware)(nil)
)
// DatabaseErrorSanitizerMiddleware wraps an implementation of Databases and
// sanitizes returned error messages
@ -280,6 +283,13 @@ func (mw DatabaseErrorSanitizerMiddleware) Close() (err error) {
return mw.sanitize(mw.next.Close())
}
func (mw DatabaseErrorSanitizerMiddleware) PluginVersion() logical.PluginVersion {
if versioner, ok := mw.next.(logical.PluginVersioner); ok {
return versioner.PluginVersion()
}
return logical.EmptyPluginVersion
}
// sanitize errors by removing any sensitive strings within their messages. This uses
// the secretsFn to determine what fields should be sanitized.
func (mw DatabaseErrorSanitizerMiddleware) sanitize(err error) error {

View File

@ -2,6 +2,7 @@ package pluginutil
import (
"context"
"errors"
"fmt"
"os"
"strings"
@ -13,6 +14,8 @@ import (
"google.golang.org/grpc/status"
)
var ErrNoMultiplexingIDFound = errors.New("no multiplexing ID found")
type PluginMultiplexingServerImpl struct {
UnimplementedPluginMultiplexingServer
@ -62,7 +65,9 @@ func GetMultiplexIDFromContext(ctx context.Context) (string, error) {
}
multiplexIDs := md[MultiplexingCtxKey]
if len(multiplexIDs) != 1 {
if len(multiplexIDs) == 0 {
return "", ErrNoMultiplexingIDFound
} else if len(multiplexIDs) != 1 {
return "", fmt.Errorf("unexpected number of IDs in metadata: (%d)", len(multiplexIDs))
}

View File

@ -2,6 +2,7 @@ package vault
import (
"context"
"encoding/hex"
"errors"
"fmt"
"strings"
@ -170,7 +171,7 @@ func (c *Core) enableCredentialInternal(ctx context.Context, entry *MountEntry,
var backend logical.Backend
// Create the new backend
sysView := c.mountEntrySysView(entry)
backend, err = c.newCredentialBackend(ctx, entry, sysView, view)
backend, entry.RunningSha, err = c.newCredentialBackend(ctx, entry, sysView, view)
if err != nil {
return err
}
@ -794,7 +795,7 @@ func (c *Core) setupCredentials(ctx context.Context) error {
// Initialize the backend
sysView := c.mountEntrySysView(entry)
backend, err = c.newCredentialBackend(ctx, entry, sysView, view)
backend, entry.RunningSha, err = c.newCredentialBackend(ctx, entry, sysView, view)
if err != nil {
c.logger.Error("failed to create credential entry", "path", entry.Path, "error", err)
plug, plugerr := c.pluginCatalog.Get(ctx, entry.Type, consts.PluginTypeCredential, "")
@ -913,25 +914,30 @@ func (c *Core) teardownCredentials(ctx context.Context) error {
return nil
}
// newCredentialBackend is used to create and configure a new credential backend by name
func (c *Core) newCredentialBackend(ctx context.Context, entry *MountEntry, sysView logical.SystemView, view logical.Storage) (logical.Backend, error) {
// newCredentialBackend is used to create and configure a new credential backend by name.
// It also returns the SHA256 of the plugin, if available.
func (c *Core) newCredentialBackend(ctx context.Context, entry *MountEntry, sysView logical.SystemView, view logical.Storage) (logical.Backend, string, error) {
t := entry.Type
if alias, ok := credentialAliases[t]; ok {
t = alias
}
var runningSha string
f, ok := c.credentialBackends[t]
if !ok {
plug, err := c.pluginCatalog.Get(ctx, t, consts.PluginTypeCredential, entry.Version)
if err != nil {
return nil, err
return nil, "", err
}
if plug == nil {
errContext := t
if entry.Version != "" {
errContext += fmt.Sprintf(", version=%s", entry.Version)
}
return nil, fmt.Errorf("%w: %s", ErrPluginNotFound, errContext)
return nil, "", fmt.Errorf("%w: %s", ErrPluginNotFound, errContext)
}
if len(plug.Sha256) > 0 {
runningSha = hex.EncodeToString(plug.Sha256)
}
f = plugin.Factory
@ -967,10 +973,10 @@ func (c *Core) newCredentialBackend(ctx context.Context, entry *MountEntry, sysV
b, err := f(ctx, config)
if err != nil {
return nil, err
return nil, "", err
}
return b, nil
return b, runningSha, nil
}
// defaultAuthTable creates a default auth table

View File

@ -3,6 +3,7 @@ package vault
import (
"context"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"os"
@ -16,18 +17,19 @@ import (
"github.com/hashicorp/vault/helper/namespace"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/consts"
"github.com/hashicorp/vault/sdk/helper/pluginutil"
"github.com/hashicorp/vault/sdk/logical"
)
var (
compileAuthOnce sync.Once
compileSecretOnce sync.Once
authPluginBytes []byte
secretPluginBytes []byte
pluginCacheLock sync.Mutex
pluginCache = map[string][]byte{}
)
func testCoreWithPlugin(t *testing.T, typ consts.PluginType) (*Core, string, string) {
// version is used to override the plugin's self-reported version
func testCoreWithPlugin(t *testing.T, typ consts.PluginType, version string) (*Core, string, string) {
t.Helper()
pluginName, pluginSHA256, pluginDir := compilePlugin(t, typ)
pluginName, pluginSHA256, pluginDir := compilePlugin(t, typ, version)
conf := &CoreConfig{
BuiltinRegistry: NewMockBuiltinRegistry(),
PluginDirectory: pluginDir,
@ -37,29 +39,46 @@ func testCoreWithPlugin(t *testing.T, typ consts.PluginType) (*Core, string, str
return core, pluginName, pluginSHA256
}
// to mount a plugin, we need a working binary plugin, so we compile one here.
func compilePlugin(t *testing.T, typ consts.PluginType) (name string, shasum string, pluginDir string) {
func getPlugin(t *testing.T, typ consts.PluginType) (string, string, string, string) {
t.Helper()
var pluginName string
var pluginType string
var pluginMain string
var pluginVersionLocation string
var pluginType, pluginName, builtinDirectory string
var once *sync.Once
var pluginBytes *[]byte
switch typ {
case consts.PluginTypeCredential:
pluginType = "approle"
pluginName = "vault-plugin-auth-" + pluginType
builtinDirectory = "credential"
once = &compileAuthOnce
pluginBytes = &authPluginBytes
pluginMain = filepath.Join("builtin", "credential", pluginType, "cmd", pluginType, "main.go")
pluginVersionLocation = fmt.Sprintf("github.com/hashicorp/vault/builtin/credential/%s.ReportedVersion", pluginType)
case consts.PluginTypeSecrets:
pluginType = "consul"
pluginName = "vault-plugin-secrets-" + pluginType
builtinDirectory = "logical"
once = &compileSecretOnce
pluginBytes = &secretPluginBytes
pluginMain = filepath.Join("builtin", "logical", pluginType, "cmd", pluginType, "main.go")
pluginVersionLocation = fmt.Sprintf("github.com/hashicorp/vault/builtin/logical/%s.ReportedVersion", pluginType)
case consts.PluginTypeDatabase:
pluginType = "postgresql"
pluginName = "vault-plugin-database-" + pluginType
pluginMain = filepath.Join("plugins", "database", pluginType, fmt.Sprintf("%s-database-plugin", pluginType), "main.go")
pluginVersionLocation = fmt.Sprintf("github.com/hashicorp/vault/plugins/database/%s.ReportedVersion", pluginType)
default:
t.Fatal(typ.String())
}
return pluginName, pluginType, pluginMain, pluginVersionLocation
}
// to mount a plugin, we need a working binary plugin, so we compile one here.
// pluginVersion is used to override the plugin's self-reported version
func compilePlugin(t *testing.T, typ consts.PluginType, pluginVersion string) (pluginName string, shasum string, pluginDir string) {
t.Helper()
pluginName, pluginType, pluginMain, pluginVersionLocation := getPlugin(t, typ)
pluginCacheLock.Lock()
defer pluginCacheLock.Unlock()
var pluginBytes []byte
dir := ""
// detect if we are in the "vault/" or the root directory and compensate
@ -76,31 +95,41 @@ func compilePlugin(t *testing.T, typ consts.PluginType) (name string, shasum str
pluginPath := path.Join(pluginDir, pluginName)
key := fmt.Sprintf("%s %s %s", pluginName, pluginType, pluginVersion)
// cache the compilation to only run once
once.Do(func() {
cmd := exec.Command("go", "build", "-o", pluginPath, fmt.Sprintf("builtin/%s/%s/cmd/%s/main.go", builtinDirectory, pluginType, pluginType))
var ok bool
pluginBytes, ok = pluginCache[key]
if !ok {
// we need to compile
line := []string{"build"}
if pluginVersion != "" {
line = append(line, "-ldflags", fmt.Sprintf("-X %s=%s", pluginVersionLocation, pluginVersion))
}
line = append(line, "-o", pluginPath, pluginMain)
cmd := exec.Command("go", line...)
cmd.Dir = dir
output, err := cmd.CombinedOutput()
if err != nil {
t.Fatal(fmt.Errorf("error running go build %v output: %s", err, output))
}
*pluginBytes, err = os.ReadFile(pluginPath)
pluginCache[key], err = os.ReadFile(pluginPath)
if err != nil {
t.Fatal(err)
}
})
pluginBytes = pluginCache[key]
}
// write the cached plugin if necessary
var err error
if _, err := os.Stat(pluginPath); os.IsNotExist(err) {
err = os.WriteFile(pluginPath, *pluginBytes, 0o777)
err = os.WriteFile(pluginPath, pluginBytes, 0o777)
}
if err != nil {
t.Fatal(err)
}
sha := sha256.New()
_, err = sha.Write(*pluginBytes)
_, err = sha.Write(pluginBytes)
if err != nil {
t.Fatal(err)
}
@ -125,7 +154,7 @@ func TestCore_EnableExternalPlugin(t *testing.T) {
},
} {
t.Run(name, func(t *testing.T) {
c, pluginName, pluginSHA256 := testCoreWithPlugin(t, tc.pluginType)
c, pluginName, pluginSHA256 := testCoreWithPlugin(t, tc.pluginType, "")
d := &framework.FieldData{
Raw: map[string]interface{}{
"name": pluginName,
@ -201,7 +230,7 @@ func TestCore_EnableExternalPlugin_MultipleVersions(t *testing.T) {
},
} {
t.Run(name, func(t *testing.T) {
c, pluginName, pluginSHA256 := testCoreWithPlugin(t, tc.pluginType)
c, pluginName, pluginSHA256 := testCoreWithPlugin(t, tc.pluginType, "")
for _, version := range tc.registerVersions {
d := &framework.FieldData{
Raw: map[string]interface{}{
@ -247,6 +276,10 @@ func TestCore_EnableExternalPlugin_MultipleVersions(t *testing.T) {
if raw.(*routeEntry).mountEntry.RunningVersion != "" {
t.Errorf("Expected mount to have no running version but got %s", raw.(*routeEntry).mountEntry.RunningVersion)
}
if raw.(*routeEntry).mountEntry.RunningSha == "" {
t.Errorf("Expected RunningSha to be present: %+v", raw.(*routeEntry).mountEntry.RunningSha)
}
})
}
}
@ -269,7 +302,7 @@ func TestCore_EnableExternalPlugin_NoVersionsOkay(t *testing.T) {
},
} {
t.Run(name, func(t *testing.T) {
c, pluginName, pluginSHA256 := testCoreWithPlugin(t, tc.pluginType)
c, pluginName, pluginSHA256 := testCoreWithPlugin(t, tc.pluginType, "")
d := &framework.FieldData{
Raw: map[string]interface{}{
"name": pluginName,
@ -328,7 +361,7 @@ func TestCore_EnableExternalCredentialPlugin_NoVersionOnRegister(t *testing.T) {
},
} {
t.Run(name, func(t *testing.T) {
c, pluginName, pluginSHA256 := testCoreWithPlugin(t, tc.pluginType)
c, pluginName, pluginSHA256 := testCoreWithPlugin(t, tc.pluginType, "")
d := &framework.FieldData{
Raw: map[string]interface{}{
"name": pluginName,
@ -372,7 +405,7 @@ func TestCore_EnableExternalCredentialPlugin_InvalidName(t *testing.T) {
},
} {
t.Run(name, func(t *testing.T) {
c, pluginName, pluginSHA256 := testCoreWithPlugin(t, tc.pluginType)
c, pluginName, pluginSHA256 := testCoreWithPlugin(t, tc.pluginType, "")
d := &framework.FieldData{
Raw: map[string]interface{}{
"name": pluginName,
@ -390,6 +423,69 @@ func TestCore_EnableExternalCredentialPlugin_InvalidName(t *testing.T) {
}
}
func TestExternalPlugin_getBackendTypeVersion(t *testing.T) {
for name, tc := range map[string]struct {
pluginType consts.PluginType
setRunningVersion string
}{
"external credential plugin": {
pluginType: consts.PluginTypeCredential,
setRunningVersion: "v1.2.3",
},
"external secrets plugin": {
pluginType: consts.PluginTypeSecrets,
setRunningVersion: "v1.2.3",
},
"external database plugin": {
pluginType: consts.PluginTypeDatabase,
setRunningVersion: "v1.2.3",
},
} {
t.Run(name, func(t *testing.T) {
c, pluginName, pluginSHA256 := testCoreWithPlugin(t, tc.pluginType, tc.setRunningVersion)
d := &framework.FieldData{
Raw: map[string]interface{}{
"name": pluginName,
"sha256": pluginSHA256,
"version": tc.setRunningVersion,
"command": pluginName,
},
Schema: c.systemBackend.pluginsCatalogCRUDPath().Fields,
}
resp, err := c.systemBackend.handlePluginCatalogUpdate(context.Background(), nil, d)
if err != nil {
t.Fatal(err)
}
if resp.Error() != nil {
t.Fatalf("%#v", resp)
}
shaBytes, _ := hex.DecodeString(pluginSHA256)
commandFull := filepath.Join(c.pluginCatalog.directory, pluginName)
entry := &pluginutil.PluginRunner{
Name: pluginName,
Command: commandFull,
Args: nil,
Sha256: shaBytes,
Builtin: false,
}
var version logical.PluginVersion
if tc.pluginType == consts.PluginTypeDatabase {
version, err = c.pluginCatalog.getDatabaseRunningVersion(context.Background(), entry)
} else {
version, err = c.pluginCatalog.getBackendRunningVersion(context.Background(), entry)
}
if err != nil {
t.Fatal(err)
}
if version.Version != tc.setRunningVersion {
t.Errorf("Expected to get version %v but got %v", tc.setRunningVersion, version.Version)
}
})
}
}
func mountTable(pluginType consts.PluginType) string {
switch pluginType {
case consts.PluginTypeCredential:

View File

@ -8,7 +8,7 @@ import (
"strings"
"testing"
uuid "github.com/hashicorp/go-uuid"
"github.com/hashicorp/go-uuid"
credGithub "github.com/hashicorp/vault/builtin/credential/github"
"github.com/hashicorp/vault/helper/identity"
"github.com/hashicorp/vault/helper/namespace"
@ -688,7 +688,7 @@ func TestIdentityStore_LoadingEntities(t *testing.T) {
ghSysview := c.mountEntrySysView(meGH)
// Create new github auth credential backend
ghAuth, err := c.newCredentialBackend(context.Background(), meGH, ghSysview, ghView)
ghAuth, _, err := c.newCredentialBackend(context.Background(), meGH, ghSysview, ghView)
if err != nil {
t.Fatal(err)
}

View File

@ -517,6 +517,9 @@ func (b *SystemBackend) handlePluginCatalogUpdate(ctx context.Context, _ *logica
err = b.Core.pluginCatalog.Set(ctx, pluginName, pluginType, pluginVersion, parts[0], args, env, sha256Bytes)
if err != nil {
if errors.Is(err, ErrPluginNotFound) || strings.HasPrefix(err.Error(), "plugin version mismatch") {
return logical.ErrorResponse(err.Error()), nil
}
return nil, err
}

View File

@ -2,6 +2,7 @@ package vault
import (
"context"
"encoding/hex"
"errors"
"fmt"
"os"
@ -608,7 +609,7 @@ func (c *Core) mountInternal(ctx context.Context, entry *MountEntry, updateStora
var backend logical.Backend
sysView := c.mountEntrySysView(entry)
backend, err = c.newLogicalBackend(ctx, entry, sysView, view)
backend, entry.RunningSha, err = c.newLogicalBackend(ctx, entry, sysView, view)
if err != nil {
return err
}
@ -1419,7 +1420,7 @@ func (c *Core) setupMounts(ctx context.Context) error {
var backend logical.Backend
// Create the new backend
sysView := c.mountEntrySysView(entry)
backend, err = c.newLogicalBackend(ctx, entry, sysView, view)
backend, entry.RunningSha, err = c.newLogicalBackend(ctx, entry, sysView, view)
if err != nil {
c.logger.Error("failed to create mount entry", "path", entry.Path, "error", err)
if !c.builtinRegistry.Contains(entry.Type, consts.PluginTypeSecrets) {
@ -1523,25 +1524,30 @@ func (c *Core) unloadMounts(ctx context.Context) error {
return nil
}
// newLogicalBackend is used to create and configure a new logical backend by name
func (c *Core) newLogicalBackend(ctx context.Context, entry *MountEntry, sysView logical.SystemView, view logical.Storage) (logical.Backend, error) {
// newLogicalBackend is used to create and configure a new logical backend by name.
// It also returns the SHA256 of the plugin, if available.
func (c *Core) newLogicalBackend(ctx context.Context, entry *MountEntry, sysView logical.SystemView, view logical.Storage) (logical.Backend, string, error) {
t := entry.Type
if alias, ok := mountAliases[t]; ok {
t = alias
}
var runningSha string
f, ok := c.logicalBackends[t]
if !ok {
plug, err := c.pluginCatalog.Get(ctx, t, consts.PluginTypeSecrets, entry.Version)
if err != nil {
return nil, err
return nil, "", err
}
if plug == nil {
errContext := t
if entry.Version != "" {
errContext += fmt.Sprintf(", version=%s", entry.Version)
}
return nil, fmt.Errorf("%w: %s", ErrPluginNotFound, errContext)
return nil, "", fmt.Errorf("%w: %s", ErrPluginNotFound, errContext)
}
if len(plug.Sha256) > 0 {
runningSha = hex.EncodeToString(plug.Sha256)
}
f = plugin.Factory
@ -1578,14 +1584,14 @@ func (c *Core) newLogicalBackend(ctx context.Context, entry *MountEntry, sysView
ctx = context.WithValue(ctx, "core_number", c.coreNumber)
b, err := f(ctx, config)
if err != nil {
return nil, err
return nil, "", err
}
if b == nil {
return nil, fmt.Errorf("nil backend of type %q returned from factory", t)
return nil, "", fmt.Errorf("nil backend of type %q returned from factory", t)
}
addLicenseCallback(c, b)
return b, nil
return b, runningSha, nil
}
// defaultMountTable creates a default mount table

View File

@ -365,7 +365,7 @@ func (c *PluginCatalog) newPluginClient(ctx context.Context, pluginRunner *plugi
// getPluginTypeFromUnknown will attempt to run the plugin to determine the
// type. It will first attempt to run as a database plugin then a backend
// plugin.
func (c *PluginCatalog) getPluginTypeFromUnknown(ctx context.Context, logger log.Logger, plugin *pluginutil.PluginRunner) (consts.PluginType, error) {
func (c *PluginCatalog) getPluginTypeFromUnknown(ctx context.Context, plugin *pluginutil.PluginRunner) (consts.PluginType, error) {
merr := &multierror.Error{}
err := c.isDatabasePlugin(ctx, plugin)
if err == nil {
@ -461,6 +461,124 @@ func (c *PluginCatalog) getBackendPluginType(ctx context.Context, pluginRunner *
return consts.PluginTypeUnknown, merr.ErrorOrNil()
}
// getBackendRunningVersion attempts to get the plugin version
func (c *PluginCatalog) getBackendRunningVersion(ctx context.Context, pluginRunner *pluginutil.PluginRunner) (logical.PluginVersion, error) {
merr := &multierror.Error{}
// Attempt to run as backend plugin
config := pluginutil.PluginClientConfig{
Name: pluginRunner.Name,
PluginSets: backendplugin.PluginSet,
HandshakeConfig: backendplugin.HandshakeConfig,
Logger: log.NewNullLogger(),
IsMetadataMode: false,
AutoMTLS: true,
}
var client logical.Backend
// First, attempt to run as backend V5 plugin
c.logger.Debug("attempting to load backend plugin", "name", pluginRunner.Name)
pc, err := c.newPluginClient(ctx, pluginRunner, config)
if err == nil {
// we spawned a subprocess, so make sure to clean it up
defer c.cleanupExternalPlugin(pluginRunner.Name, pc.id)
// dispense the plugin so we can get its version
client, err = backendplugin.Dispense(pc.ClientProtocol, pc)
if err == nil {
c.logger.Debug("successfully dispensed v5 backend plugin", "name", pluginRunner.Name)
err = client.Setup(ctx, &logical.BackendConfig{})
if err != nil {
return logical.EmptyPluginVersion, nil
}
if versioner, ok := client.(logical.PluginVersioner); ok {
return versioner.PluginVersion(), nil
}
return logical.EmptyPluginVersion, nil
}
merr = multierror.Append(merr, fmt.Errorf("failed to dispense plugin as backend v5: %w", err))
}
c.logger.Debug("failed to dispense v5 backend plugin", "name", pluginRunner.Name, "error", err)
config.AutoMTLS = false
config.IsMetadataMode = true
// attempt to run as a v4 backend plugin
client, err = backendplugin.NewPluginClient(ctx, nil, pluginRunner, log.NewNullLogger(), true)
if err != nil {
merr = multierror.Append(merr, fmt.Errorf("failed to dispense v4 backend plugin: %w", err))
c.logger.Debug("failed to dispense v4 backend plugin", "name", pluginRunner.Name, "error", merr)
return logical.EmptyPluginVersion, merr.ErrorOrNil()
}
c.logger.Debug("successfully dispensed v4 backend plugin", "name", pluginRunner.Name)
defer client.Cleanup(ctx)
err = client.Setup(ctx, &logical.BackendConfig{})
if err != nil {
return logical.EmptyPluginVersion, err
}
if versioner, ok := client.(logical.PluginVersioner); ok {
return versioner.PluginVersion(), nil
}
return logical.EmptyPluginVersion, nil
}
// getDatabaseRunningVersion returns the version reported by a database plugin
func (c *PluginCatalog) getDatabaseRunningVersion(ctx context.Context, pluginRunner *pluginutil.PluginRunner) (logical.PluginVersion, error) {
merr := &multierror.Error{}
config := pluginutil.PluginClientConfig{
Name: pluginRunner.Name,
PluginSets: v5.PluginSets,
PluginType: consts.PluginTypeDatabase,
Version: pluginRunner.Version,
HandshakeConfig: v5.HandshakeConfig,
Logger: log.Default(),
IsMetadataMode: true,
AutoMTLS: true,
}
// Attempt to run as database V5+ multiplexed plugin
c.logger.Debug("attempting to load database plugin as v5", "name", pluginRunner.Name)
v5Client, err := c.newPluginClient(ctx, pluginRunner, config)
if err == nil {
defer func() {
// Close the client and cleanup the plugin process
err = c.cleanupExternalPlugin(pluginRunner.Name, v5Client.id)
if err != nil {
c.logger.Error("error closing plugin client", "error", err)
}
}()
raw, err := v5Client.Dispense("database")
if err != nil {
return logical.EmptyPluginVersion, err
}
if versioner, ok := raw.(logical.PluginVersioner); ok {
return versioner.PluginVersion(), nil
}
return logical.EmptyPluginVersion, nil
}
merr = multierror.Append(merr, fmt.Errorf("failed to load plugin as database v5: %w", err))
c.logger.Debug("attempting to load database plugin as v4", "name", pluginRunner.Name)
v4Client, err := v4.NewPluginClient(ctx, nil, pluginRunner, log.NewNullLogger(), true)
if err == nil {
// Close the client and cleanup the plugin process
defer func() {
err = v4Client.Close()
if err != nil {
c.logger.Error("error closing plugin client", "error", err)
}
}()
if versioner, ok := v4Client.(logical.PluginVersioner); ok {
return versioner.PluginVersion(), nil
}
return logical.EmptyPluginVersion, nil
}
merr = multierror.Append(merr, fmt.Errorf("failed to load plugin as database v4: %w", err))
return logical.EmptyPluginVersion, merr
}
// isDatabasePlugin returns an error if the plugin is not a database plugin.
func (c *PluginCatalog) isDatabasePlugin(ctx context.Context, pluginRunner *pluginutil.PluginRunner) error {
merr := &multierror.Error{}
@ -475,7 +593,7 @@ func (c *PluginCatalog) isDatabasePlugin(ctx context.Context, pluginRunner *plug
AutoMTLS: true,
}
// Attempt to run as database V5 or V6 multiplexed plugin
// Attempt to run as database V5+ multiplexed plugin
c.logger.Debug("attempting to load database plugin as v5", "name", pluginRunner.Name)
v5Client, err := c.newPluginClient(ctx, pluginRunner, config)
if err == nil {
@ -671,20 +789,19 @@ func (c *PluginCatalog) setInternal(ctx context.Context, name string, pluginType
return nil, errors.New("cannot execute files outside of configured plugin directory")
}
// entryTmp should only be used for the below type and version checks, it uses the
// full command instead of the relative command.
entryTmp := &pluginutil.PluginRunner{
Name: name,
Command: commandFull,
Args: args,
Env: env,
Sha256: sha256,
Builtin: false,
}
// If the plugin type is unknown, we want to attempt to determine the type
if pluginType == consts.PluginTypeUnknown {
// entryTmp should only be used for the below type check, it uses the
// full command instead of the relative command.
entryTmp := &pluginutil.PluginRunner{
Name: name,
Command: commandFull,
Args: args,
Env: env,
Sha256: sha256,
Builtin: false,
}
pluginType, err = c.getPluginTypeFromUnknown(ctx, log.Default(), entryTmp)
pluginType, err = c.getPluginTypeFromUnknown(ctx, entryTmp)
if err != nil {
return nil, err
}
@ -693,6 +810,24 @@ func (c *PluginCatalog) setInternal(ctx context.Context, name string, pluginType
}
}
// getting the plugin version is best-effort, so errors are not fatal
runningVersion := logical.EmptyPluginVersion
var versionErr error
switch pluginType {
case consts.PluginTypeSecrets, consts.PluginTypeCredential:
runningVersion, versionErr = c.getBackendRunningVersion(ctx, entryTmp)
case consts.PluginTypeDatabase:
runningVersion, versionErr = c.getDatabaseRunningVersion(ctx, entryTmp)
default:
return nil, fmt.Errorf("unknown plugin type: %v", pluginType)
}
if versionErr != nil {
c.logger.Warn("Error determining plugin version", "error", versionErr)
} else if version != "" && runningVersion.Version != "" && version != runningVersion.Version {
c.logger.Warn("Plugin self-reported version did not match requested version", "plugin", name, "requestedVersion", version, "reportedVersion", runningVersion.Version)
return nil, fmt.Errorf("plugin version mismatch: %s reported version (%s) did not match requested version (%s)", name, runningVersion.Version, version)
}
entry := &pluginutil.PluginRunner{
Name: name,
Type: pluginType,

View File

@ -7,7 +7,7 @@ import (
"github.com/hashicorp/vault/helper/namespace"
multierror "github.com/hashicorp/go-multierror"
"github.com/hashicorp/go-multierror"
"github.com/hashicorp/go-secure-stdlib/strutil"
"github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/sdk/plugin"
@ -174,11 +174,12 @@ func (c *Core) reloadBackendCommon(ctx context.Context, entry *MountEntry, isAut
}
var backend logical.Backend
oldSha := entry.RunningSha
if !isAuth {
// Dispense a new backend
backend, err = c.newLogicalBackend(ctx, entry, sysView, view)
backend, entry.RunningSha, err = c.newLogicalBackend(ctx, entry, sysView, view)
} else {
backend, err = c.newCredentialBackend(ctx, entry, sysView, view)
backend, entry.RunningSha, err = c.newCredentialBackend(ctx, entry, sysView, view)
}
if err != nil {
return err
@ -187,6 +188,20 @@ func (c *Core) reloadBackendCommon(ctx context.Context, entry *MountEntry, isAut
return fmt.Errorf("nil backend of type %q returned from creation function", entry.Type)
}
// update the mount table since we changed the runningSha
if oldSha != entry.RunningSha && MountTableUpdateStorage {
if isAuth {
err = c.persistAuth(ctx, c.auth, &entry.Local)
if err != nil {
return err
}
} else {
err = c.persistMounts(ctx, c.mounts, &entry.Local)
if err != nil {
return err
}
}
}
addPathCheckers(c, entry, backend, viewPath)
if nilMount {