Check if plugin version matches running version (#17182)
Check if plugin version matches running version When registering a plugin, we check if the request version matches the self-reported version from the plugin. If these do not match, we log a warning. This uncovered a few missing pieces for getting the database version code fully working. We added an environment variable that helps us unit test the running version behavior as well, but only for approle, postgresql, and consul plugins. Return 400 on plugin not found or version mismatch Populate the running SHA256 of plugins in the mount and auth tables (#17217)
This commit is contained in:
parent
dc3beb428e
commit
2c8e88ab67
|
@ -18,6 +18,9 @@ const (
|
|||
secretIDAccessorLocalPrefix = "accessor_local/"
|
||||
)
|
||||
|
||||
// ReportedVersion is used to report a specific version to Vault.
|
||||
var ReportedVersion = ""
|
||||
|
||||
type backend struct {
|
||||
*framework.Backend
|
||||
|
||||
|
@ -111,8 +114,9 @@ func Backend(conf *logical.BackendConfig) (*backend, error) {
|
|||
pathTidySecretID(b),
|
||||
},
|
||||
),
|
||||
Invalidate: b.invalidate,
|
||||
BackendType: logical.TypeCredential,
|
||||
Invalidate: b.invalidate,
|
||||
BackendType: logical.TypeCredential,
|
||||
RunningVersion: ReportedVersion,
|
||||
}
|
||||
return b, nil
|
||||
}
|
||||
|
|
|
@ -7,6 +7,9 @@ import (
|
|||
"github.com/hashicorp/vault/sdk/logical"
|
||||
)
|
||||
|
||||
// ReportedVersion is used to report a specific version to Vault.
|
||||
var ReportedVersion = ""
|
||||
|
||||
func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
|
||||
b := Backend()
|
||||
if err := b.Setup(ctx, conf); err != nil {
|
||||
|
@ -34,7 +37,8 @@ func Backend() *backend {
|
|||
Secrets: []*framework.Secret{
|
||||
secretToken(&b),
|
||||
},
|
||||
BackendType: logical.TypeLogical,
|
||||
BackendType: logical.TypeLogical,
|
||||
RunningVersion: ReportedVersion,
|
||||
}
|
||||
|
||||
return &b
|
||||
|
|
|
@ -14,6 +14,7 @@ import (
|
|||
"github.com/hashicorp/vault/sdk/database/helper/dbutil"
|
||||
"github.com/hashicorp/vault/sdk/helper/dbtxn"
|
||||
"github.com/hashicorp/vault/sdk/helper/template"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
_ "github.com/jackc/pgx/v4/stdlib"
|
||||
)
|
||||
|
||||
|
@ -32,7 +33,8 @@ ALTER ROLE "{{username}}" WITH PASSWORD '{{password}}';
|
|||
)
|
||||
|
||||
var (
|
||||
_ dbplugin.Database = &PostgreSQL{}
|
||||
_ dbplugin.Database = (*PostgreSQL)(nil)
|
||||
_ logical.PluginVersioner = (*PostgreSQL)(nil)
|
||||
|
||||
// postgresEndStatement is basically the word "END" but
|
||||
// surrounded by a word boundary to differentiate it from
|
||||
|
@ -46,6 +48,9 @@ var (
|
|||
// singleQuotedPhrases finds substrings like 'hello'
|
||||
// and pulls them out with the quotes included.
|
||||
singleQuotedPhrases = regexp.MustCompile(`('.*?')`)
|
||||
|
||||
// ReportedVersion is used to report a specific version to Vault.
|
||||
ReportedVersion = ""
|
||||
)
|
||||
|
||||
func New() (interface{}, error) {
|
||||
|
@ -469,6 +474,10 @@ func (p *PostgreSQL) secretValues() map[string]string {
|
|||
}
|
||||
}
|
||||
|
||||
func (p *PostgreSQL) PluginVersion() logical.PluginVersion {
|
||||
return logical.PluginVersion{Version: ReportedVersion}
|
||||
}
|
||||
|
||||
// containsMultilineStatement is a best effort to determine whether
|
||||
// a particular statement is multiline, and therefore should not be
|
||||
// split upon semicolons. If it's unsure, it defaults to false.
|
||||
|
|
|
@ -2,12 +2,14 @@ package dbplugin
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/golang/protobuf/ptypes"
|
||||
"github.com/hashicorp/vault/sdk/database/dbplugin/v5/proto"
|
||||
"github.com/hashicorp/vault/sdk/helper/base62"
|
||||
"github.com/hashicorp/vault/sdk/helper/pluginutil"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
"google.golang.org/grpc/codes"
|
||||
|
@ -43,11 +45,14 @@ func (g *gRPCServer) getOrCreateDatabase(ctx context.Context) (Database, error)
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if db, ok := g.instances[id]; ok {
|
||||
return db, nil
|
||||
}
|
||||
return g.createDatabase(id)
|
||||
}
|
||||
|
||||
// must hold the g.Lock() to call this function
|
||||
func (g *gRPCServer) createDatabase(id string) (Database, error) {
|
||||
db, err := g.factoryFunc()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -304,12 +309,36 @@ func (g *gRPCServer) Close(ctx context.Context, _ *proto.Empty) (*proto.Empty, e
|
|||
return &proto.Empty{}, nil
|
||||
}
|
||||
|
||||
// getOrForceCreateDatabase will create a database even if the multiplexing ID is not present
|
||||
func (g *gRPCServer) getOrForceCreateDatabase(ctx context.Context) (Database, error) {
|
||||
impl, err := g.getOrCreateDatabase(ctx)
|
||||
if errors.Is(err, pluginutil.ErrNoMultiplexingIDFound) {
|
||||
// if this is called without a multiplexing context, like from the plugin catalog directly,
|
||||
// then we won't have a database ID, so let's generate a new database instance
|
||||
id, err := base62.Random(10)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
g.Lock()
|
||||
defer g.Unlock()
|
||||
impl, err = g.createDatabase(id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return impl, nil
|
||||
}
|
||||
|
||||
// Version forwards the version request to the underlying Database implementation.
|
||||
func (g *gRPCServer) Version(ctx context.Context, _ *logical.Empty) (*logical.VersionReply, error) {
|
||||
impl, err := g.getDatabaseInternal(ctx)
|
||||
impl, err := g.getOrForceCreateDatabase(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if versioner, ok := impl.(logical.PluginVersioner); ok {
|
||||
return &logical.VersionReply{PluginVersion: versioner.PluginVersion().Version}, nil
|
||||
}
|
||||
|
|
|
@ -233,7 +233,10 @@ func (mw databaseMetricsMiddleware) Close() (err error) {
|
|||
// Error Sanitizer Middleware Domain
|
||||
// ///////////////////////////////////////////////////
|
||||
|
||||
var _ Database = DatabaseErrorSanitizerMiddleware{}
|
||||
var (
|
||||
_ Database = (*DatabaseErrorSanitizerMiddleware)(nil)
|
||||
_ logical.PluginVersioner = (*DatabaseErrorSanitizerMiddleware)(nil)
|
||||
)
|
||||
|
||||
// DatabaseErrorSanitizerMiddleware wraps an implementation of Databases and
|
||||
// sanitizes returned error messages
|
||||
|
@ -280,6 +283,13 @@ func (mw DatabaseErrorSanitizerMiddleware) Close() (err error) {
|
|||
return mw.sanitize(mw.next.Close())
|
||||
}
|
||||
|
||||
func (mw DatabaseErrorSanitizerMiddleware) PluginVersion() logical.PluginVersion {
|
||||
if versioner, ok := mw.next.(logical.PluginVersioner); ok {
|
||||
return versioner.PluginVersion()
|
||||
}
|
||||
return logical.EmptyPluginVersion
|
||||
}
|
||||
|
||||
// sanitize errors by removing any sensitive strings within their messages. This uses
|
||||
// the secretsFn to determine what fields should be sanitized.
|
||||
func (mw DatabaseErrorSanitizerMiddleware) sanitize(err error) error {
|
||||
|
|
|
@ -2,6 +2,7 @@ package pluginutil
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
@ -13,6 +14,8 @@ import (
|
|||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
var ErrNoMultiplexingIDFound = errors.New("no multiplexing ID found")
|
||||
|
||||
type PluginMultiplexingServerImpl struct {
|
||||
UnimplementedPluginMultiplexingServer
|
||||
|
||||
|
@ -62,7 +65,9 @@ func GetMultiplexIDFromContext(ctx context.Context) (string, error) {
|
|||
}
|
||||
|
||||
multiplexIDs := md[MultiplexingCtxKey]
|
||||
if len(multiplexIDs) != 1 {
|
||||
if len(multiplexIDs) == 0 {
|
||||
return "", ErrNoMultiplexingIDFound
|
||||
} else if len(multiplexIDs) != 1 {
|
||||
return "", fmt.Errorf("unexpected number of IDs in metadata: (%d)", len(multiplexIDs))
|
||||
}
|
||||
|
||||
|
|
|
@ -2,6 +2,7 @@ package vault
|
|||
|
||||
import (
|
||||
"context"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
@ -170,7 +171,7 @@ func (c *Core) enableCredentialInternal(ctx context.Context, entry *MountEntry,
|
|||
var backend logical.Backend
|
||||
// Create the new backend
|
||||
sysView := c.mountEntrySysView(entry)
|
||||
backend, err = c.newCredentialBackend(ctx, entry, sysView, view)
|
||||
backend, entry.RunningSha, err = c.newCredentialBackend(ctx, entry, sysView, view)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -794,7 +795,7 @@ func (c *Core) setupCredentials(ctx context.Context) error {
|
|||
// Initialize the backend
|
||||
sysView := c.mountEntrySysView(entry)
|
||||
|
||||
backend, err = c.newCredentialBackend(ctx, entry, sysView, view)
|
||||
backend, entry.RunningSha, err = c.newCredentialBackend(ctx, entry, sysView, view)
|
||||
if err != nil {
|
||||
c.logger.Error("failed to create credential entry", "path", entry.Path, "error", err)
|
||||
plug, plugerr := c.pluginCatalog.Get(ctx, entry.Type, consts.PluginTypeCredential, "")
|
||||
|
@ -913,25 +914,30 @@ func (c *Core) teardownCredentials(ctx context.Context) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// newCredentialBackend is used to create and configure a new credential backend by name
|
||||
func (c *Core) newCredentialBackend(ctx context.Context, entry *MountEntry, sysView logical.SystemView, view logical.Storage) (logical.Backend, error) {
|
||||
// newCredentialBackend is used to create and configure a new credential backend by name.
|
||||
// It also returns the SHA256 of the plugin, if available.
|
||||
func (c *Core) newCredentialBackend(ctx context.Context, entry *MountEntry, sysView logical.SystemView, view logical.Storage) (logical.Backend, string, error) {
|
||||
t := entry.Type
|
||||
if alias, ok := credentialAliases[t]; ok {
|
||||
t = alias
|
||||
}
|
||||
|
||||
var runningSha string
|
||||
f, ok := c.credentialBackends[t]
|
||||
if !ok {
|
||||
plug, err := c.pluginCatalog.Get(ctx, t, consts.PluginTypeCredential, entry.Version)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, "", err
|
||||
}
|
||||
if plug == nil {
|
||||
errContext := t
|
||||
if entry.Version != "" {
|
||||
errContext += fmt.Sprintf(", version=%s", entry.Version)
|
||||
}
|
||||
return nil, fmt.Errorf("%w: %s", ErrPluginNotFound, errContext)
|
||||
return nil, "", fmt.Errorf("%w: %s", ErrPluginNotFound, errContext)
|
||||
}
|
||||
if len(plug.Sha256) > 0 {
|
||||
runningSha = hex.EncodeToString(plug.Sha256)
|
||||
}
|
||||
|
||||
f = plugin.Factory
|
||||
|
@ -967,10 +973,10 @@ func (c *Core) newCredentialBackend(ctx context.Context, entry *MountEntry, sysV
|
|||
|
||||
b, err := f(ctx, config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
return b, nil
|
||||
return b, runningSha, nil
|
||||
}
|
||||
|
||||
// defaultAuthTable creates a default auth table
|
||||
|
|
|
@ -3,6 +3,7 @@ package vault
|
|||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
|
@ -16,18 +17,19 @@ import (
|
|||
"github.com/hashicorp/vault/helper/namespace"
|
||||
"github.com/hashicorp/vault/sdk/framework"
|
||||
"github.com/hashicorp/vault/sdk/helper/consts"
|
||||
"github.com/hashicorp/vault/sdk/helper/pluginutil"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
)
|
||||
|
||||
var (
|
||||
compileAuthOnce sync.Once
|
||||
compileSecretOnce sync.Once
|
||||
authPluginBytes []byte
|
||||
secretPluginBytes []byte
|
||||
pluginCacheLock sync.Mutex
|
||||
pluginCache = map[string][]byte{}
|
||||
)
|
||||
|
||||
func testCoreWithPlugin(t *testing.T, typ consts.PluginType) (*Core, string, string) {
|
||||
// version is used to override the plugin's self-reported version
|
||||
func testCoreWithPlugin(t *testing.T, typ consts.PluginType, version string) (*Core, string, string) {
|
||||
t.Helper()
|
||||
pluginName, pluginSHA256, pluginDir := compilePlugin(t, typ)
|
||||
pluginName, pluginSHA256, pluginDir := compilePlugin(t, typ, version)
|
||||
conf := &CoreConfig{
|
||||
BuiltinRegistry: NewMockBuiltinRegistry(),
|
||||
PluginDirectory: pluginDir,
|
||||
|
@ -37,29 +39,46 @@ func testCoreWithPlugin(t *testing.T, typ consts.PluginType) (*Core, string, str
|
|||
return core, pluginName, pluginSHA256
|
||||
}
|
||||
|
||||
// to mount a plugin, we need a working binary plugin, so we compile one here.
|
||||
func compilePlugin(t *testing.T, typ consts.PluginType) (name string, shasum string, pluginDir string) {
|
||||
func getPlugin(t *testing.T, typ consts.PluginType) (string, string, string, string) {
|
||||
t.Helper()
|
||||
var pluginName string
|
||||
var pluginType string
|
||||
var pluginMain string
|
||||
var pluginVersionLocation string
|
||||
|
||||
var pluginType, pluginName, builtinDirectory string
|
||||
var once *sync.Once
|
||||
var pluginBytes *[]byte
|
||||
switch typ {
|
||||
case consts.PluginTypeCredential:
|
||||
pluginType = "approle"
|
||||
pluginName = "vault-plugin-auth-" + pluginType
|
||||
builtinDirectory = "credential"
|
||||
once = &compileAuthOnce
|
||||
pluginBytes = &authPluginBytes
|
||||
pluginMain = filepath.Join("builtin", "credential", pluginType, "cmd", pluginType, "main.go")
|
||||
pluginVersionLocation = fmt.Sprintf("github.com/hashicorp/vault/builtin/credential/%s.ReportedVersion", pluginType)
|
||||
case consts.PluginTypeSecrets:
|
||||
pluginType = "consul"
|
||||
pluginName = "vault-plugin-secrets-" + pluginType
|
||||
builtinDirectory = "logical"
|
||||
once = &compileSecretOnce
|
||||
pluginBytes = &secretPluginBytes
|
||||
pluginMain = filepath.Join("builtin", "logical", pluginType, "cmd", pluginType, "main.go")
|
||||
pluginVersionLocation = fmt.Sprintf("github.com/hashicorp/vault/builtin/logical/%s.ReportedVersion", pluginType)
|
||||
case consts.PluginTypeDatabase:
|
||||
pluginType = "postgresql"
|
||||
pluginName = "vault-plugin-database-" + pluginType
|
||||
pluginMain = filepath.Join("plugins", "database", pluginType, fmt.Sprintf("%s-database-plugin", pluginType), "main.go")
|
||||
pluginVersionLocation = fmt.Sprintf("github.com/hashicorp/vault/plugins/database/%s.ReportedVersion", pluginType)
|
||||
default:
|
||||
t.Fatal(typ.String())
|
||||
}
|
||||
return pluginName, pluginType, pluginMain, pluginVersionLocation
|
||||
}
|
||||
|
||||
// to mount a plugin, we need a working binary plugin, so we compile one here.
|
||||
// pluginVersion is used to override the plugin's self-reported version
|
||||
func compilePlugin(t *testing.T, typ consts.PluginType, pluginVersion string) (pluginName string, shasum string, pluginDir string) {
|
||||
t.Helper()
|
||||
|
||||
pluginName, pluginType, pluginMain, pluginVersionLocation := getPlugin(t, typ)
|
||||
|
||||
pluginCacheLock.Lock()
|
||||
defer pluginCacheLock.Unlock()
|
||||
|
||||
var pluginBytes []byte
|
||||
|
||||
dir := ""
|
||||
// detect if we are in the "vault/" or the root directory and compensate
|
||||
|
@ -76,31 +95,41 @@ func compilePlugin(t *testing.T, typ consts.PluginType) (name string, shasum str
|
|||
|
||||
pluginPath := path.Join(pluginDir, pluginName)
|
||||
|
||||
key := fmt.Sprintf("%s %s %s", pluginName, pluginType, pluginVersion)
|
||||
// cache the compilation to only run once
|
||||
once.Do(func() {
|
||||
cmd := exec.Command("go", "build", "-o", pluginPath, fmt.Sprintf("builtin/%s/%s/cmd/%s/main.go", builtinDirectory, pluginType, pluginType))
|
||||
var ok bool
|
||||
pluginBytes, ok = pluginCache[key]
|
||||
if !ok {
|
||||
// we need to compile
|
||||
line := []string{"build"}
|
||||
if pluginVersion != "" {
|
||||
line = append(line, "-ldflags", fmt.Sprintf("-X %s=%s", pluginVersionLocation, pluginVersion))
|
||||
}
|
||||
line = append(line, "-o", pluginPath, pluginMain)
|
||||
cmd := exec.Command("go", line...)
|
||||
cmd.Dir = dir
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
t.Fatal(fmt.Errorf("error running go build %v output: %s", err, output))
|
||||
}
|
||||
*pluginBytes, err = os.ReadFile(pluginPath)
|
||||
pluginCache[key], err = os.ReadFile(pluginPath)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
pluginBytes = pluginCache[key]
|
||||
}
|
||||
|
||||
// write the cached plugin if necessary
|
||||
var err error
|
||||
if _, err := os.Stat(pluginPath); os.IsNotExist(err) {
|
||||
err = os.WriteFile(pluginPath, *pluginBytes, 0o777)
|
||||
err = os.WriteFile(pluginPath, pluginBytes, 0o777)
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
sha := sha256.New()
|
||||
_, err = sha.Write(*pluginBytes)
|
||||
_, err = sha.Write(pluginBytes)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -125,7 +154,7 @@ func TestCore_EnableExternalPlugin(t *testing.T) {
|
|||
},
|
||||
} {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
c, pluginName, pluginSHA256 := testCoreWithPlugin(t, tc.pluginType)
|
||||
c, pluginName, pluginSHA256 := testCoreWithPlugin(t, tc.pluginType, "")
|
||||
d := &framework.FieldData{
|
||||
Raw: map[string]interface{}{
|
||||
"name": pluginName,
|
||||
|
@ -201,7 +230,7 @@ func TestCore_EnableExternalPlugin_MultipleVersions(t *testing.T) {
|
|||
},
|
||||
} {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
c, pluginName, pluginSHA256 := testCoreWithPlugin(t, tc.pluginType)
|
||||
c, pluginName, pluginSHA256 := testCoreWithPlugin(t, tc.pluginType, "")
|
||||
for _, version := range tc.registerVersions {
|
||||
d := &framework.FieldData{
|
||||
Raw: map[string]interface{}{
|
||||
|
@ -247,6 +276,10 @@ func TestCore_EnableExternalPlugin_MultipleVersions(t *testing.T) {
|
|||
if raw.(*routeEntry).mountEntry.RunningVersion != "" {
|
||||
t.Errorf("Expected mount to have no running version but got %s", raw.(*routeEntry).mountEntry.RunningVersion)
|
||||
}
|
||||
|
||||
if raw.(*routeEntry).mountEntry.RunningSha == "" {
|
||||
t.Errorf("Expected RunningSha to be present: %+v", raw.(*routeEntry).mountEntry.RunningSha)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -269,7 +302,7 @@ func TestCore_EnableExternalPlugin_NoVersionsOkay(t *testing.T) {
|
|||
},
|
||||
} {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
c, pluginName, pluginSHA256 := testCoreWithPlugin(t, tc.pluginType)
|
||||
c, pluginName, pluginSHA256 := testCoreWithPlugin(t, tc.pluginType, "")
|
||||
d := &framework.FieldData{
|
||||
Raw: map[string]interface{}{
|
||||
"name": pluginName,
|
||||
|
@ -328,7 +361,7 @@ func TestCore_EnableExternalCredentialPlugin_NoVersionOnRegister(t *testing.T) {
|
|||
},
|
||||
} {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
c, pluginName, pluginSHA256 := testCoreWithPlugin(t, tc.pluginType)
|
||||
c, pluginName, pluginSHA256 := testCoreWithPlugin(t, tc.pluginType, "")
|
||||
d := &framework.FieldData{
|
||||
Raw: map[string]interface{}{
|
||||
"name": pluginName,
|
||||
|
@ -372,7 +405,7 @@ func TestCore_EnableExternalCredentialPlugin_InvalidName(t *testing.T) {
|
|||
},
|
||||
} {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
c, pluginName, pluginSHA256 := testCoreWithPlugin(t, tc.pluginType)
|
||||
c, pluginName, pluginSHA256 := testCoreWithPlugin(t, tc.pluginType, "")
|
||||
d := &framework.FieldData{
|
||||
Raw: map[string]interface{}{
|
||||
"name": pluginName,
|
||||
|
@ -390,6 +423,69 @@ func TestCore_EnableExternalCredentialPlugin_InvalidName(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestExternalPlugin_getBackendTypeVersion(t *testing.T) {
|
||||
for name, tc := range map[string]struct {
|
||||
pluginType consts.PluginType
|
||||
setRunningVersion string
|
||||
}{
|
||||
"external credential plugin": {
|
||||
pluginType: consts.PluginTypeCredential,
|
||||
setRunningVersion: "v1.2.3",
|
||||
},
|
||||
"external secrets plugin": {
|
||||
pluginType: consts.PluginTypeSecrets,
|
||||
setRunningVersion: "v1.2.3",
|
||||
},
|
||||
"external database plugin": {
|
||||
pluginType: consts.PluginTypeDatabase,
|
||||
setRunningVersion: "v1.2.3",
|
||||
},
|
||||
} {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
c, pluginName, pluginSHA256 := testCoreWithPlugin(t, tc.pluginType, tc.setRunningVersion)
|
||||
d := &framework.FieldData{
|
||||
Raw: map[string]interface{}{
|
||||
"name": pluginName,
|
||||
"sha256": pluginSHA256,
|
||||
"version": tc.setRunningVersion,
|
||||
"command": pluginName,
|
||||
},
|
||||
Schema: c.systemBackend.pluginsCatalogCRUDPath().Fields,
|
||||
}
|
||||
resp, err := c.systemBackend.handlePluginCatalogUpdate(context.Background(), nil, d)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp.Error() != nil {
|
||||
t.Fatalf("%#v", resp)
|
||||
}
|
||||
|
||||
shaBytes, _ := hex.DecodeString(pluginSHA256)
|
||||
commandFull := filepath.Join(c.pluginCatalog.directory, pluginName)
|
||||
entry := &pluginutil.PluginRunner{
|
||||
Name: pluginName,
|
||||
Command: commandFull,
|
||||
Args: nil,
|
||||
Sha256: shaBytes,
|
||||
Builtin: false,
|
||||
}
|
||||
|
||||
var version logical.PluginVersion
|
||||
if tc.pluginType == consts.PluginTypeDatabase {
|
||||
version, err = c.pluginCatalog.getDatabaseRunningVersion(context.Background(), entry)
|
||||
} else {
|
||||
version, err = c.pluginCatalog.getBackendRunningVersion(context.Background(), entry)
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if version.Version != tc.setRunningVersion {
|
||||
t.Errorf("Expected to get version %v but got %v", tc.setRunningVersion, version.Version)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func mountTable(pluginType consts.PluginType) string {
|
||||
switch pluginType {
|
||||
case consts.PluginTypeCredential:
|
||||
|
|
|
@ -8,7 +8,7 @@ import (
|
|||
"strings"
|
||||
"testing"
|
||||
|
||||
uuid "github.com/hashicorp/go-uuid"
|
||||
"github.com/hashicorp/go-uuid"
|
||||
credGithub "github.com/hashicorp/vault/builtin/credential/github"
|
||||
"github.com/hashicorp/vault/helper/identity"
|
||||
"github.com/hashicorp/vault/helper/namespace"
|
||||
|
@ -688,7 +688,7 @@ func TestIdentityStore_LoadingEntities(t *testing.T) {
|
|||
ghSysview := c.mountEntrySysView(meGH)
|
||||
|
||||
// Create new github auth credential backend
|
||||
ghAuth, err := c.newCredentialBackend(context.Background(), meGH, ghSysview, ghView)
|
||||
ghAuth, _, err := c.newCredentialBackend(context.Background(), meGH, ghSysview, ghView)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
|
|
@ -517,6 +517,9 @@ func (b *SystemBackend) handlePluginCatalogUpdate(ctx context.Context, _ *logica
|
|||
|
||||
err = b.Core.pluginCatalog.Set(ctx, pluginName, pluginType, pluginVersion, parts[0], args, env, sha256Bytes)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrPluginNotFound) || strings.HasPrefix(err.Error(), "plugin version mismatch") {
|
||||
return logical.ErrorResponse(err.Error()), nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
|
|
@ -2,6 +2,7 @@ package vault
|
|||
|
||||
import (
|
||||
"context"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
|
@ -608,7 +609,7 @@ func (c *Core) mountInternal(ctx context.Context, entry *MountEntry, updateStora
|
|||
var backend logical.Backend
|
||||
sysView := c.mountEntrySysView(entry)
|
||||
|
||||
backend, err = c.newLogicalBackend(ctx, entry, sysView, view)
|
||||
backend, entry.RunningSha, err = c.newLogicalBackend(ctx, entry, sysView, view)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -1419,7 +1420,7 @@ func (c *Core) setupMounts(ctx context.Context) error {
|
|||
var backend logical.Backend
|
||||
// Create the new backend
|
||||
sysView := c.mountEntrySysView(entry)
|
||||
backend, err = c.newLogicalBackend(ctx, entry, sysView, view)
|
||||
backend, entry.RunningSha, err = c.newLogicalBackend(ctx, entry, sysView, view)
|
||||
if err != nil {
|
||||
c.logger.Error("failed to create mount entry", "path", entry.Path, "error", err)
|
||||
if !c.builtinRegistry.Contains(entry.Type, consts.PluginTypeSecrets) {
|
||||
|
@ -1523,25 +1524,30 @@ func (c *Core) unloadMounts(ctx context.Context) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// newLogicalBackend is used to create and configure a new logical backend by name
|
||||
func (c *Core) newLogicalBackend(ctx context.Context, entry *MountEntry, sysView logical.SystemView, view logical.Storage) (logical.Backend, error) {
|
||||
// newLogicalBackend is used to create and configure a new logical backend by name.
|
||||
// It also returns the SHA256 of the plugin, if available.
|
||||
func (c *Core) newLogicalBackend(ctx context.Context, entry *MountEntry, sysView logical.SystemView, view logical.Storage) (logical.Backend, string, error) {
|
||||
t := entry.Type
|
||||
if alias, ok := mountAliases[t]; ok {
|
||||
t = alias
|
||||
}
|
||||
|
||||
var runningSha string
|
||||
f, ok := c.logicalBackends[t]
|
||||
if !ok {
|
||||
plug, err := c.pluginCatalog.Get(ctx, t, consts.PluginTypeSecrets, entry.Version)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, "", err
|
||||
}
|
||||
if plug == nil {
|
||||
errContext := t
|
||||
if entry.Version != "" {
|
||||
errContext += fmt.Sprintf(", version=%s", entry.Version)
|
||||
}
|
||||
return nil, fmt.Errorf("%w: %s", ErrPluginNotFound, errContext)
|
||||
return nil, "", fmt.Errorf("%w: %s", ErrPluginNotFound, errContext)
|
||||
}
|
||||
if len(plug.Sha256) > 0 {
|
||||
runningSha = hex.EncodeToString(plug.Sha256)
|
||||
}
|
||||
|
||||
f = plugin.Factory
|
||||
|
@ -1578,14 +1584,14 @@ func (c *Core) newLogicalBackend(ctx context.Context, entry *MountEntry, sysView
|
|||
ctx = context.WithValue(ctx, "core_number", c.coreNumber)
|
||||
b, err := f(ctx, config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, "", err
|
||||
}
|
||||
if b == nil {
|
||||
return nil, fmt.Errorf("nil backend of type %q returned from factory", t)
|
||||
return nil, "", fmt.Errorf("nil backend of type %q returned from factory", t)
|
||||
}
|
||||
addLicenseCallback(c, b)
|
||||
|
||||
return b, nil
|
||||
return b, runningSha, nil
|
||||
}
|
||||
|
||||
// defaultMountTable creates a default mount table
|
||||
|
|
|
@ -365,7 +365,7 @@ func (c *PluginCatalog) newPluginClient(ctx context.Context, pluginRunner *plugi
|
|||
// getPluginTypeFromUnknown will attempt to run the plugin to determine the
|
||||
// type. It will first attempt to run as a database plugin then a backend
|
||||
// plugin.
|
||||
func (c *PluginCatalog) getPluginTypeFromUnknown(ctx context.Context, logger log.Logger, plugin *pluginutil.PluginRunner) (consts.PluginType, error) {
|
||||
func (c *PluginCatalog) getPluginTypeFromUnknown(ctx context.Context, plugin *pluginutil.PluginRunner) (consts.PluginType, error) {
|
||||
merr := &multierror.Error{}
|
||||
err := c.isDatabasePlugin(ctx, plugin)
|
||||
if err == nil {
|
||||
|
@ -461,6 +461,124 @@ func (c *PluginCatalog) getBackendPluginType(ctx context.Context, pluginRunner *
|
|||
return consts.PluginTypeUnknown, merr.ErrorOrNil()
|
||||
}
|
||||
|
||||
// getBackendRunningVersion attempts to get the plugin version
|
||||
func (c *PluginCatalog) getBackendRunningVersion(ctx context.Context, pluginRunner *pluginutil.PluginRunner) (logical.PluginVersion, error) {
|
||||
merr := &multierror.Error{}
|
||||
// Attempt to run as backend plugin
|
||||
config := pluginutil.PluginClientConfig{
|
||||
Name: pluginRunner.Name,
|
||||
PluginSets: backendplugin.PluginSet,
|
||||
HandshakeConfig: backendplugin.HandshakeConfig,
|
||||
Logger: log.NewNullLogger(),
|
||||
IsMetadataMode: false,
|
||||
AutoMTLS: true,
|
||||
}
|
||||
|
||||
var client logical.Backend
|
||||
// First, attempt to run as backend V5 plugin
|
||||
c.logger.Debug("attempting to load backend plugin", "name", pluginRunner.Name)
|
||||
pc, err := c.newPluginClient(ctx, pluginRunner, config)
|
||||
if err == nil {
|
||||
// we spawned a subprocess, so make sure to clean it up
|
||||
defer c.cleanupExternalPlugin(pluginRunner.Name, pc.id)
|
||||
|
||||
// dispense the plugin so we can get its version
|
||||
client, err = backendplugin.Dispense(pc.ClientProtocol, pc)
|
||||
if err == nil {
|
||||
c.logger.Debug("successfully dispensed v5 backend plugin", "name", pluginRunner.Name)
|
||||
|
||||
err = client.Setup(ctx, &logical.BackendConfig{})
|
||||
if err != nil {
|
||||
return logical.EmptyPluginVersion, nil
|
||||
}
|
||||
if versioner, ok := client.(logical.PluginVersioner); ok {
|
||||
return versioner.PluginVersion(), nil
|
||||
}
|
||||
return logical.EmptyPluginVersion, nil
|
||||
}
|
||||
merr = multierror.Append(merr, fmt.Errorf("failed to dispense plugin as backend v5: %w", err))
|
||||
}
|
||||
c.logger.Debug("failed to dispense v5 backend plugin", "name", pluginRunner.Name, "error", err)
|
||||
config.AutoMTLS = false
|
||||
config.IsMetadataMode = true
|
||||
// attempt to run as a v4 backend plugin
|
||||
client, err = backendplugin.NewPluginClient(ctx, nil, pluginRunner, log.NewNullLogger(), true)
|
||||
if err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("failed to dispense v4 backend plugin: %w", err))
|
||||
c.logger.Debug("failed to dispense v4 backend plugin", "name", pluginRunner.Name, "error", merr)
|
||||
return logical.EmptyPluginVersion, merr.ErrorOrNil()
|
||||
}
|
||||
c.logger.Debug("successfully dispensed v4 backend plugin", "name", pluginRunner.Name)
|
||||
defer client.Cleanup(ctx)
|
||||
|
||||
err = client.Setup(ctx, &logical.BackendConfig{})
|
||||
if err != nil {
|
||||
return logical.EmptyPluginVersion, err
|
||||
}
|
||||
if versioner, ok := client.(logical.PluginVersioner); ok {
|
||||
return versioner.PluginVersion(), nil
|
||||
}
|
||||
return logical.EmptyPluginVersion, nil
|
||||
}
|
||||
|
||||
// getDatabaseRunningVersion returns the version reported by a database plugin
|
||||
func (c *PluginCatalog) getDatabaseRunningVersion(ctx context.Context, pluginRunner *pluginutil.PluginRunner) (logical.PluginVersion, error) {
|
||||
merr := &multierror.Error{}
|
||||
config := pluginutil.PluginClientConfig{
|
||||
Name: pluginRunner.Name,
|
||||
PluginSets: v5.PluginSets,
|
||||
PluginType: consts.PluginTypeDatabase,
|
||||
Version: pluginRunner.Version,
|
||||
HandshakeConfig: v5.HandshakeConfig,
|
||||
Logger: log.Default(),
|
||||
IsMetadataMode: true,
|
||||
AutoMTLS: true,
|
||||
}
|
||||
|
||||
// Attempt to run as database V5+ multiplexed plugin
|
||||
c.logger.Debug("attempting to load database plugin as v5", "name", pluginRunner.Name)
|
||||
v5Client, err := c.newPluginClient(ctx, pluginRunner, config)
|
||||
if err == nil {
|
||||
defer func() {
|
||||
// Close the client and cleanup the plugin process
|
||||
err = c.cleanupExternalPlugin(pluginRunner.Name, v5Client.id)
|
||||
if err != nil {
|
||||
c.logger.Error("error closing plugin client", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
raw, err := v5Client.Dispense("database")
|
||||
if err != nil {
|
||||
return logical.EmptyPluginVersion, err
|
||||
}
|
||||
if versioner, ok := raw.(logical.PluginVersioner); ok {
|
||||
return versioner.PluginVersion(), nil
|
||||
}
|
||||
return logical.EmptyPluginVersion, nil
|
||||
}
|
||||
merr = multierror.Append(merr, fmt.Errorf("failed to load plugin as database v5: %w", err))
|
||||
|
||||
c.logger.Debug("attempting to load database plugin as v4", "name", pluginRunner.Name)
|
||||
v4Client, err := v4.NewPluginClient(ctx, nil, pluginRunner, log.NewNullLogger(), true)
|
||||
if err == nil {
|
||||
// Close the client and cleanup the plugin process
|
||||
defer func() {
|
||||
err = v4Client.Close()
|
||||
if err != nil {
|
||||
c.logger.Error("error closing plugin client", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if versioner, ok := v4Client.(logical.PluginVersioner); ok {
|
||||
return versioner.PluginVersion(), nil
|
||||
}
|
||||
|
||||
return logical.EmptyPluginVersion, nil
|
||||
}
|
||||
merr = multierror.Append(merr, fmt.Errorf("failed to load plugin as database v4: %w", err))
|
||||
return logical.EmptyPluginVersion, merr
|
||||
}
|
||||
|
||||
// isDatabasePlugin returns an error if the plugin is not a database plugin.
|
||||
func (c *PluginCatalog) isDatabasePlugin(ctx context.Context, pluginRunner *pluginutil.PluginRunner) error {
|
||||
merr := &multierror.Error{}
|
||||
|
@ -475,7 +593,7 @@ func (c *PluginCatalog) isDatabasePlugin(ctx context.Context, pluginRunner *plug
|
|||
AutoMTLS: true,
|
||||
}
|
||||
|
||||
// Attempt to run as database V5 or V6 multiplexed plugin
|
||||
// Attempt to run as database V5+ multiplexed plugin
|
||||
c.logger.Debug("attempting to load database plugin as v5", "name", pluginRunner.Name)
|
||||
v5Client, err := c.newPluginClient(ctx, pluginRunner, config)
|
||||
if err == nil {
|
||||
|
@ -671,20 +789,19 @@ func (c *PluginCatalog) setInternal(ctx context.Context, name string, pluginType
|
|||
return nil, errors.New("cannot execute files outside of configured plugin directory")
|
||||
}
|
||||
|
||||
// entryTmp should only be used for the below type and version checks, it uses the
|
||||
// full command instead of the relative command.
|
||||
entryTmp := &pluginutil.PluginRunner{
|
||||
Name: name,
|
||||
Command: commandFull,
|
||||
Args: args,
|
||||
Env: env,
|
||||
Sha256: sha256,
|
||||
Builtin: false,
|
||||
}
|
||||
// If the plugin type is unknown, we want to attempt to determine the type
|
||||
if pluginType == consts.PluginTypeUnknown {
|
||||
// entryTmp should only be used for the below type check, it uses the
|
||||
// full command instead of the relative command.
|
||||
entryTmp := &pluginutil.PluginRunner{
|
||||
Name: name,
|
||||
Command: commandFull,
|
||||
Args: args,
|
||||
Env: env,
|
||||
Sha256: sha256,
|
||||
Builtin: false,
|
||||
}
|
||||
|
||||
pluginType, err = c.getPluginTypeFromUnknown(ctx, log.Default(), entryTmp)
|
||||
pluginType, err = c.getPluginTypeFromUnknown(ctx, entryTmp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -693,6 +810,24 @@ func (c *PluginCatalog) setInternal(ctx context.Context, name string, pluginType
|
|||
}
|
||||
}
|
||||
|
||||
// getting the plugin version is best-effort, so errors are not fatal
|
||||
runningVersion := logical.EmptyPluginVersion
|
||||
var versionErr error
|
||||
switch pluginType {
|
||||
case consts.PluginTypeSecrets, consts.PluginTypeCredential:
|
||||
runningVersion, versionErr = c.getBackendRunningVersion(ctx, entryTmp)
|
||||
case consts.PluginTypeDatabase:
|
||||
runningVersion, versionErr = c.getDatabaseRunningVersion(ctx, entryTmp)
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown plugin type: %v", pluginType)
|
||||
}
|
||||
if versionErr != nil {
|
||||
c.logger.Warn("Error determining plugin version", "error", versionErr)
|
||||
} else if version != "" && runningVersion.Version != "" && version != runningVersion.Version {
|
||||
c.logger.Warn("Plugin self-reported version did not match requested version", "plugin", name, "requestedVersion", version, "reportedVersion", runningVersion.Version)
|
||||
return nil, fmt.Errorf("plugin version mismatch: %s reported version (%s) did not match requested version (%s)", name, runningVersion.Version, version)
|
||||
}
|
||||
|
||||
entry := &pluginutil.PluginRunner{
|
||||
Name: name,
|
||||
Type: pluginType,
|
||||
|
|
|
@ -7,7 +7,7 @@ import (
|
|||
|
||||
"github.com/hashicorp/vault/helper/namespace"
|
||||
|
||||
multierror "github.com/hashicorp/go-multierror"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/hashicorp/go-secure-stdlib/strutil"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
"github.com/hashicorp/vault/sdk/plugin"
|
||||
|
@ -174,11 +174,12 @@ func (c *Core) reloadBackendCommon(ctx context.Context, entry *MountEntry, isAut
|
|||
}
|
||||
|
||||
var backend logical.Backend
|
||||
oldSha := entry.RunningSha
|
||||
if !isAuth {
|
||||
// Dispense a new backend
|
||||
backend, err = c.newLogicalBackend(ctx, entry, sysView, view)
|
||||
backend, entry.RunningSha, err = c.newLogicalBackend(ctx, entry, sysView, view)
|
||||
} else {
|
||||
backend, err = c.newCredentialBackend(ctx, entry, sysView, view)
|
||||
backend, entry.RunningSha, err = c.newCredentialBackend(ctx, entry, sysView, view)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -187,6 +188,20 @@ func (c *Core) reloadBackendCommon(ctx context.Context, entry *MountEntry, isAut
|
|||
return fmt.Errorf("nil backend of type %q returned from creation function", entry.Type)
|
||||
}
|
||||
|
||||
// update the mount table since we changed the runningSha
|
||||
if oldSha != entry.RunningSha && MountTableUpdateStorage {
|
||||
if isAuth {
|
||||
err = c.persistAuth(ctx, c.auth, &entry.Local)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
err = c.persistMounts(ctx, c.mounts, &entry.Local)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
addPathCheckers(c, entry, backend, viewPath)
|
||||
|
||||
if nilMount {
|
||||
|
|
Loading…
Reference in New Issue