feature: multiplexing support for database plugins (#14033)

* feat: DB plugin multiplexing (#13734)

* WIP: start from main and get a plugin runner from core

* move MultiplexedClient map to plugin catalog
- call sys.NewPluginClient from PluginFactory
- updates to getPluginClient
- thread through isMetadataMode

* use go-plugin ClientProtocol interface
- call sys.NewPluginClient from dbplugin.NewPluginClient

* move PluginSets to dbplugin package
- export dbplugin HandshakeConfig
- small refactor of PluginCatalog.getPluginClient

* add removeMultiplexedClient; clean up on Close()
- call client.Kill from plugin catalog
- set rpcClient when muxed client exists

* add ID to dbplugin.DatabasePluginClient struct

* only create one plugin process per plugin type

* update NewPluginClient to return connection ID to sdk
- wrap grpc.ClientConn so we can inject the ID into context
- get ID from context on grpc server

* add v6 multiplexing  protocol version

* WIP: backwards compat for db plugins

* Ensure locking on plugin catalog access

- Create public GetPluginClient method for plugin catalog
- rename postgres db plugin

* use the New constructor for db plugins

* grpc server: use write lock for Close and rlock for CRUD

* cleanup MultiplexedClients on Close

* remove TODO

* fix multiplexing regression with grpc server connection

* cleanup grpc server instances on close

* embed ClientProtocol in Multiplexer interface

* use PluginClientConfig arg to make NewPluginClient plugin type agnostic

* create a new plugin process for non-muxed plugins

* feat: plugin multiplexing: handle plugin client cleanup (#13896)

* use closure for plugin client cleanup

* log and return errors; add comments

* move rpcClient wrapping to core for ID injection

* refactor core plugin client and sdk

* remove unused ID method

* refactor and only wrap clientConn on multiplexed plugins

* rename structs and do not export types

* Slight refactor of system view interface

* Revert "Slight refactor of system view interface"

This reverts commit 73d420e5cd2f0415e000c5a9284ea72a58016dd6.

* Revert "Revert "Slight refactor of system view interface""

This reverts commit f75527008a1db06d04a23e04c3059674be8adb5f.

* only provide pluginRunner arg to the internal newPluginClient method

* embed ClientProtocol in pluginClient and name logger

* Add back MLock support

* remove enableMlock arg from setupPluginCatalog

* rename plugin util interface to PluginClient

Co-authored-by: Brian Kassouf <bkassouf@hashicorp.com>

* feature: multiplexing: fix unit tests (#14007)

* fix grpc_server tests and add coverage

* update run_config tests

* add happy path test case for grpc_server ID from context

* update test helpers

* feat: multiplexing: handle v5 plugin compiled with new sdk

* add mux supported flag and increase test coverage

* set multiplexingSupport field in plugin server

* remove multiplexingSupport field in sdk

* revert postgres to non-multiplexed

* add comments on grpc server fields

* use pointer receiver on grpc server methods

* add changelog

* use pointer for grpcserver instance

* Use a gRPC server to determine if a plugin should be multiplexed

* Apply suggestions from code review

Co-authored-by: Brian Kassouf <briankassouf@users.noreply.github.com>

* add lock to removePluginClient

* add multiplexingSupport field to externalPlugin struct

* do not send nil to grpc MultiplexingSupport

* check err before logging

* handle locking scenario for cleanupFunc

* allow ServeConfigMultiplex to dispense v5 plugin

* reposition structs, add err check and comments

* add comment on locking for cleanupExternalPlugin

Co-authored-by: Brian Kassouf <bkassouf@hashicorp.com>
Co-authored-by: Brian Kassouf <briankassouf@users.noreply.github.com>
This commit is contained in:
John-Michael Faircloth 2022-02-17 08:50:33 -06:00 committed by GitHub
parent 91f5069c03
commit 1cf74e1179
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
35 changed files with 1262 additions and 247 deletions

View File

@ -194,6 +194,7 @@ proto: bootstrap
protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative sdk/database/dbplugin/*.proto
protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative sdk/database/dbplugin/v5/proto/*.proto
protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative sdk/plugin/pb/*.proto
protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative sdk/helper/pluginutil/*.proto
# No additional sed expressions should be added to this list. Going forward
# we should just use the variable names choosen by protobuf. These are left

View File

@ -110,6 +110,7 @@ func Backend(conf *logical.BackendConfig) *databaseBackend {
}
type databaseBackend struct {
// connections holds configured database connections by config name
connections map[string]*dbPluginInstance
logger log.Logger

View File

@ -329,6 +329,8 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc {
}
config.ConnectionDetails = initResp.Config
b.Logger().Debug("created database object", "name", name, "plugin_name", config.PluginName)
b.Lock()
defer b.Unlock()
@ -365,6 +367,9 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc {
"Vault (or the sdk if using a custom plugin) to gain password policy support", config.PluginName))
}
if len(resp.Warnings) == 0 {
return nil, nil
}
return resp, nil
}
}

3
changelog/14033.txt Normal file
View File

@ -0,0 +1,3 @@
```release-note:feature
**Database plugin multiplexing**: manage multiple database connections with a single plugin process
```

View File

@ -150,8 +150,8 @@ type registry struct {
logicalBackends map[string]logical.Factory
}
// Get returns the BuiltinFactory func for a particular backend plugin
// from the plugins map.
// Get returns the Factory func for a particular backend plugin from the
// plugins map.
func (r *registry) Get(name string, pluginType consts.PluginType) (func() (interface{}, error), bool) {
switch pluginType {
case consts.PluginTypeCredential:

View File

@ -1,7 +1,7 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.27.1
// protoc v3.17.3
// protoc v3.19.4
// source: helper/forwarding/types.proto
package forwarding

View File

@ -1,7 +1,7 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.27.1
// protoc v3.17.3
// protoc v3.19.4
// source: helper/identity/mfa/types.proto
package mfa

View File

@ -1,7 +1,7 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.27.1
// protoc v3.17.3
// protoc v3.19.4
// source: helper/identity/types.proto
package identity

View File

@ -1,7 +1,7 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.27.1
// protoc v3.17.3
// protoc v3.19.4
// source: helper/storagepacker/types.proto
package storagepacker

View File

@ -1,7 +1,7 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.27.1
// protoc v3.17.3
// protoc v3.19.4
// source: physical/raft/types.proto
package raft

View File

@ -48,7 +48,6 @@ var (
singleQuotedPhrases = regexp.MustCompile(`('.*?')`)
)
// New implements builtinplugins.BuiltinFactory
func New() (interface{}, error) {
db := new()
// Wrap the plugin with middleware to sanitize errors

View File

@ -1,7 +1,7 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.27.1
// protoc v3.17.3
// protoc v3.19.4
// source: sdk/database/dbplugin/database.proto
package dbplugin

View File

@ -5,6 +5,7 @@ import (
"github.com/hashicorp/go-plugin"
"github.com/hashicorp/vault/sdk/database/dbplugin/v5/proto"
"github.com/hashicorp/vault/sdk/helper/pluginutil"
"google.golang.org/grpc"
)
@ -12,14 +13,17 @@ import (
// a plugin and host. If the handshake fails, a user friendly error is shown.
// This prevents users from executing bad plugins or executing a plugin
// directory. It is a UX feature, not a security feature.
var handshakeConfig = plugin.HandshakeConfig{
ProtocolVersion: 5,
var HandshakeConfig = plugin.HandshakeConfig{
MagicCookieKey: "VAULT_DATABASE_PLUGIN",
MagicCookieValue: "926a0820-aea2-be28-51d6-83cdf00e8edb",
}
// Factory is the factory function to create a dbplugin Database.
type Factory func() (interface{}, error)
type GRPCDatabasePlugin struct {
Impl Database
FactoryFunc Factory
Impl Database
// Embeding this will disable the netRPC protocol
plugin.NetRPCUnsupportedPlugin
@ -31,7 +35,25 @@ var (
)
func (d GRPCDatabasePlugin) GRPCServer(_ *plugin.GRPCBroker, s *grpc.Server) error {
proto.RegisterDatabaseServer(s, gRPCServer{impl: d.Impl})
var server gRPCServer
if d.Impl != nil {
server = gRPCServer{singleImpl: d.Impl}
} else {
// multiplexing is supported
server = gRPCServer{
factoryFunc: d.FactoryFunc,
instances: make(map[string]Database),
}
// Multiplexing is enabled for this plugin, register the server so we
// can tell the client in Vault.
pluginutil.RegisterPluginMultiplexingServer(s, pluginutil.PluginMultiplexingServerImpl{
Supported: true,
})
}
proto.RegisterDatabaseServer(s, &server)
return nil
}

View File

@ -3,24 +3,113 @@ package dbplugin
import (
"context"
"fmt"
"sync"
"time"
"github.com/golang/protobuf/ptypes"
"github.com/hashicorp/vault/sdk/database/dbplugin/v5/proto"
"github.com/hashicorp/vault/sdk/helper/pluginutil"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
)
var _ proto.DatabaseServer = gRPCServer{}
var _ proto.DatabaseServer = &gRPCServer{}
type gRPCServer struct {
proto.UnimplementedDatabaseServer
impl Database
// holds the non-multiplexed Database
// when this is set the plugin does not support multiplexing
singleImpl Database
// instances holds the multiplexed Databases
instances map[string]Database
factoryFunc func() (interface{}, error)
sync.RWMutex
}
func getMultiplexIDFromContext(ctx context.Context) (string, error) {
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return "", fmt.Errorf("missing plugin multiplexing metadata")
}
multiplexIDs := md[pluginutil.MultiplexingCtxKey]
if len(multiplexIDs) != 1 {
return "", fmt.Errorf("unexpected number of IDs in metadata: (%d)", len(multiplexIDs))
}
multiplexID := multiplexIDs[0]
if multiplexID == "" {
return "", fmt.Errorf("empty multiplex ID in metadata")
}
return multiplexID, nil
}
func (g *gRPCServer) getOrCreateDatabase(ctx context.Context) (Database, error) {
g.Lock()
defer g.Unlock()
if g.singleImpl != nil {
return g.singleImpl, nil
}
id, err := getMultiplexIDFromContext(ctx)
if err != nil {
return nil, err
}
if db, ok := g.instances[id]; ok {
return db, nil
}
db, err := g.factoryFunc()
if err != nil {
return nil, err
}
database := db.(Database)
g.instances[id] = database
return database, nil
}
// getDatabaseInternal returns the database but does not hold a lock
func (g *gRPCServer) getDatabaseInternal(ctx context.Context) (Database, error) {
if g.singleImpl != nil {
return g.singleImpl, nil
}
id, err := getMultiplexIDFromContext(ctx)
if err != nil {
return nil, err
}
if db, ok := g.instances[id]; ok {
return db, nil
}
return nil, fmt.Errorf("no database instance found")
}
// getDatabase holds a read lock and returns the database
func (g *gRPCServer) getDatabase(ctx context.Context) (Database, error) {
g.RLock()
impl, err := g.getDatabaseInternal(ctx)
g.RUnlock()
return impl, err
}
// Initialize the database plugin
func (g gRPCServer) Initialize(ctx context.Context, request *proto.InitializeRequest) (*proto.InitializeResponse, error) {
func (g *gRPCServer) Initialize(ctx context.Context, request *proto.InitializeRequest) (*proto.InitializeResponse, error) {
impl, err := g.getOrCreateDatabase(ctx)
if err != nil {
return nil, err
}
rawConfig := structToMap(request.ConfigData)
dbReq := InitializeRequest{
@ -28,7 +117,7 @@ func (g gRPCServer) Initialize(ctx context.Context, request *proto.InitializeReq
VerifyConnection: request.VerifyConnection,
}
dbResp, err := g.impl.Initialize(ctx, dbReq)
dbResp, err := impl.Initialize(ctx, dbReq)
if err != nil {
return &proto.InitializeResponse{}, status.Errorf(codes.Internal, "failed to initialize: %s", err)
}
@ -45,7 +134,7 @@ func (g gRPCServer) Initialize(ctx context.Context, request *proto.InitializeReq
return resp, nil
}
func (g gRPCServer) NewUser(ctx context.Context, req *proto.NewUserRequest) (*proto.NewUserResponse, error) {
func (g *gRPCServer) NewUser(ctx context.Context, req *proto.NewUserRequest) (*proto.NewUserResponse, error) {
if req.GetUsernameConfig() == nil {
return &proto.NewUserResponse{}, status.Errorf(codes.InvalidArgument, "missing username config")
}
@ -60,6 +149,11 @@ func (g gRPCServer) NewUser(ctx context.Context, req *proto.NewUserRequest) (*pr
expiration = exp
}
impl, err := g.getDatabase(ctx)
if err != nil {
return nil, err
}
dbReq := NewUserRequest{
UsernameConfig: UsernameMetadata{
DisplayName: req.GetUsernameConfig().GetDisplayName(),
@ -71,7 +165,7 @@ func (g gRPCServer) NewUser(ctx context.Context, req *proto.NewUserRequest) (*pr
RollbackStatements: getStatementsFromProto(req.GetRollbackStatements()),
}
dbResp, err := g.impl.NewUser(ctx, dbReq)
dbResp, err := impl.NewUser(ctx, dbReq)
if err != nil {
return &proto.NewUserResponse{}, status.Errorf(codes.Internal, "unable to create new user: %s", err)
}
@ -82,7 +176,7 @@ func (g gRPCServer) NewUser(ctx context.Context, req *proto.NewUserRequest) (*pr
return resp, nil
}
func (g gRPCServer) UpdateUser(ctx context.Context, req *proto.UpdateUserRequest) (*proto.UpdateUserResponse, error) {
func (g *gRPCServer) UpdateUser(ctx context.Context, req *proto.UpdateUserRequest) (*proto.UpdateUserResponse, error) {
if req.GetUsername() == "" {
return &proto.UpdateUserResponse{}, status.Errorf(codes.InvalidArgument, "no username provided")
}
@ -92,7 +186,12 @@ func (g gRPCServer) UpdateUser(ctx context.Context, req *proto.UpdateUserRequest
return &proto.UpdateUserResponse{}, status.Errorf(codes.InvalidArgument, err.Error())
}
_, err = g.impl.UpdateUser(ctx, dbReq)
impl, err := g.getDatabase(ctx)
if err != nil {
return nil, err
}
_, err = impl.UpdateUser(ctx, dbReq)
if err != nil {
return &proto.UpdateUserResponse{}, status.Errorf(codes.Internal, "unable to update user: %s", err)
}
@ -144,7 +243,7 @@ func hasChange(dbReq UpdateUserRequest) bool {
return false
}
func (g gRPCServer) DeleteUser(ctx context.Context, req *proto.DeleteUserRequest) (*proto.DeleteUserResponse, error) {
func (g *gRPCServer) DeleteUser(ctx context.Context, req *proto.DeleteUserRequest) (*proto.DeleteUserResponse, error) {
if req.GetUsername() == "" {
return &proto.DeleteUserResponse{}, status.Errorf(codes.InvalidArgument, "no username provided")
}
@ -153,15 +252,25 @@ func (g gRPCServer) DeleteUser(ctx context.Context, req *proto.DeleteUserRequest
Statements: getStatementsFromProto(req.GetStatements()),
}
_, err := g.impl.DeleteUser(ctx, dbReq)
impl, err := g.getDatabase(ctx)
if err != nil {
return nil, err
}
_, err = impl.DeleteUser(ctx, dbReq)
if err != nil {
return &proto.DeleteUserResponse{}, status.Errorf(codes.Internal, "unable to delete user: %s", err)
}
return &proto.DeleteUserResponse{}, nil
}
func (g gRPCServer) Type(ctx context.Context, _ *proto.Empty) (*proto.TypeResponse, error) {
t, err := g.impl.Type()
func (g *gRPCServer) Type(ctx context.Context, _ *proto.Empty) (*proto.TypeResponse, error) {
impl, err := g.getOrCreateDatabase(ctx)
if err != nil {
return nil, err
}
t, err := impl.Type()
if err != nil {
return &proto.TypeResponse{}, status.Errorf(codes.Internal, "unable to retrieve type: %s", err)
}
@ -172,11 +281,29 @@ func (g gRPCServer) Type(ctx context.Context, _ *proto.Empty) (*proto.TypeRespon
return resp, nil
}
func (g gRPCServer) Close(ctx context.Context, _ *proto.Empty) (*proto.Empty, error) {
err := g.impl.Close()
func (g *gRPCServer) Close(ctx context.Context, _ *proto.Empty) (*proto.Empty, error) {
g.Lock()
defer g.Unlock()
impl, err := g.getDatabaseInternal(ctx)
if err != nil {
return nil, err
}
err = impl.Close()
if err != nil {
return &proto.Empty{}, status.Errorf(codes.Internal, "unable to close database plugin: %s", err)
}
if g.singleImpl == nil {
// only cleanup instances map when multiplexing is supported
id, err := getMultiplexIDFromContext(ctx)
if err != nil {
return nil, err
}
delete(g.instances, id)
}
return &proto.Empty{}, nil
}

View File

@ -13,7 +13,9 @@ import (
"github.com/golang/protobuf/ptypes"
"github.com/golang/protobuf/ptypes/timestamp"
"github.com/hashicorp/vault/sdk/database/dbplugin/v5/proto"
"github.com/hashicorp/vault/sdk/helper/pluginutil"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
)
@ -22,11 +24,12 @@ var invalidExpiration = time.Date(0, 1, 1, 0, 0, 0, 0, time.UTC)
func TestGRPCServer_Initialize(t *testing.T) {
type testCase struct {
db Database
req *proto.InitializeRequest
expectedResp *proto.InitializeResponse
expectErr bool
expectCode codes.Code
db Database
req *proto.InitializeRequest
expectedResp *proto.InitializeResponse
expectErr bool
expectCode codes.Code
grpcSetupFunc func(*testing.T, Database) (context.Context, gRPCServer)
}
tests := map[string]testCase{
@ -34,10 +37,11 @@ func TestGRPCServer_Initialize(t *testing.T) {
db: fakeDatabase{
initErr: errors.New("initialization error"),
},
req: &proto.InitializeRequest{},
expectedResp: &proto.InitializeResponse{},
expectErr: true,
expectCode: codes.Internal,
req: &proto.InitializeRequest{},
expectedResp: &proto.InitializeResponse{},
expectErr: true,
expectCode: codes.Internal,
grpcSetupFunc: testGrpcServer,
},
"newConfig can't marshal to JSON": {
db: fakeDatabase{
@ -47,12 +51,13 @@ func TestGRPCServer_Initialize(t *testing.T) {
},
},
},
req: &proto.InitializeRequest{},
expectedResp: &proto.InitializeResponse{},
expectErr: true,
expectCode: codes.Internal,
req: &proto.InitializeRequest{},
expectedResp: &proto.InitializeResponse{},
expectErr: true,
expectCode: codes.Internal,
grpcSetupFunc: testGrpcServer,
},
"happy path with config data": {
"happy path with config data for multiplexed plugin": {
db: fakeDatabase{
initResp: InitializeResponse{
Config: map[string]interface{}{
@ -70,21 +75,39 @@ func TestGRPCServer_Initialize(t *testing.T) {
"foo": "bar",
}),
},
expectErr: false,
expectCode: codes.OK,
expectErr: false,
expectCode: codes.OK,
grpcSetupFunc: testGrpcServer,
},
"happy path with config data for non-multiplexed plugin": {
db: fakeDatabase{
initResp: InitializeResponse{
Config: map[string]interface{}{
"foo": "bar",
},
},
},
req: &proto.InitializeRequest{
ConfigData: marshal(t, map[string]interface{}{
"foo": "bar",
}),
},
expectedResp: &proto.InitializeResponse{
ConfigData: marshal(t, map[string]interface{}{
"foo": "bar",
}),
},
expectErr: false,
expectCode: codes.OK,
grpcSetupFunc: testGrpcServerSingleImpl,
},
}
for name, test := range tests {
t.Run(name, func(t *testing.T) {
g := gRPCServer{
impl: test.db,
}
idCtx, g := test.grpcSetupFunc(t, test.db)
resp, err := g.Initialize(idCtx, test.req)
// Context doesn't need to timeout since this is just passed through
ctx := context.Background()
resp, err := g.Initialize(ctx, test.req)
if test.expectErr && err == nil {
t.Fatalf("err expected, got nil")
}
@ -252,14 +275,9 @@ func TestGRPCServer_NewUser(t *testing.T) {
for name, test := range tests {
t.Run(name, func(t *testing.T) {
g := gRPCServer{
impl: test.db,
}
idCtx, g := testGrpcServer(t, test.db)
resp, err := g.NewUser(idCtx, test.req)
// Context doesn't need to timeout since this is just passed through
ctx := context.Background()
resp, err := g.NewUser(ctx, test.req)
if test.expectErr && err == nil {
t.Fatalf("err expected, got nil")
}
@ -362,14 +380,9 @@ func TestGRPCServer_UpdateUser(t *testing.T) {
for name, test := range tests {
t.Run(name, func(t *testing.T) {
g := gRPCServer{
impl: test.db,
}
idCtx, g := testGrpcServer(t, test.db)
resp, err := g.UpdateUser(idCtx, test.req)
// Context doesn't need to timeout since this is just passed through
ctx := context.Background()
resp, err := g.UpdateUser(ctx, test.req)
if test.expectErr && err == nil {
t.Fatalf("err expected, got nil")
}
@ -430,14 +443,9 @@ func TestGRPCServer_DeleteUser(t *testing.T) {
for name, test := range tests {
t.Run(name, func(t *testing.T) {
g := gRPCServer{
impl: test.db,
}
idCtx, g := testGrpcServer(t, test.db)
resp, err := g.DeleteUser(idCtx, test.req)
// Context doesn't need to timeout since this is just passed through
ctx := context.Background()
resp, err := g.DeleteUser(ctx, test.req)
if test.expectErr && err == nil {
t.Fatalf("err expected, got nil")
}
@ -488,14 +496,9 @@ func TestGRPCServer_Type(t *testing.T) {
for name, test := range tests {
t.Run(name, func(t *testing.T) {
g := gRPCServer{
impl: test.db,
}
idCtx, g := testGrpcServer(t, test.db)
resp, err := g.Type(idCtx, &proto.Empty{})
// Context doesn't need to timeout since this is just passed through
ctx := context.Background()
resp, err := g.Type(ctx, &proto.Empty{})
if test.expectErr && err == nil {
t.Fatalf("err expected, got nil")
}
@ -517,9 +520,11 @@ func TestGRPCServer_Type(t *testing.T) {
func TestGRPCServer_Close(t *testing.T) {
type testCase struct {
db Database
expectErr bool
expectCode codes.Code
db Database
expectErr bool
expectCode codes.Code
grpcSetupFunc func(*testing.T, Database) (context.Context, gRPCServer)
assertFunc func(t *testing.T, g gRPCServer)
}
tests := map[string]testCase{
@ -527,26 +532,36 @@ func TestGRPCServer_Close(t *testing.T) {
db: fakeDatabase{
closeErr: errors.New("close error"),
},
expectErr: true,
expectCode: codes.Internal,
expectErr: true,
expectCode: codes.Internal,
grpcSetupFunc: testGrpcServer,
assertFunc: nil,
},
"happy path": {
db: fakeDatabase{},
expectErr: false,
expectCode: codes.OK,
"happy path for multiplexed plugin": {
db: fakeDatabase{},
expectErr: false,
expectCode: codes.OK,
grpcSetupFunc: testGrpcServer,
assertFunc: func(t *testing.T, g gRPCServer) {
if len(g.instances) != 0 {
t.Fatalf("err expected instances map to be empty")
}
},
},
"happy path for non-multiplexed plugin": {
db: fakeDatabase{},
expectErr: false,
expectCode: codes.OK,
grpcSetupFunc: testGrpcServerSingleImpl,
assertFunc: nil,
},
}
for name, test := range tests {
t.Run(name, func(t *testing.T) {
g := gRPCServer{
impl: test.db,
}
idCtx, g := test.grpcSetupFunc(t, test.db)
_, err := g.Close(idCtx, &proto.Empty{})
// Context doesn't need to timeout since this is just passed through
ctx := context.Background()
_, err := g.Close(ctx, &proto.Empty{})
if test.expectErr && err == nil {
t.Fatalf("err expected, got nil")
}
@ -558,10 +573,105 @@ func TestGRPCServer_Close(t *testing.T) {
if actualCode != test.expectCode {
t.Fatalf("Actual code: %s Expected code: %s", actualCode, test.expectCode)
}
if test.assertFunc != nil {
test.assertFunc(t, g)
}
})
}
}
func TestGetMultiplexIDFromContext(t *testing.T) {
type testCase struct {
ctx context.Context
expectedResp string
expectedErr error
}
tests := map[string]testCase{
"missing plugin multiplexing metadata": {
ctx: context.Background(),
expectedResp: "",
expectedErr: fmt.Errorf("missing plugin multiplexing metadata"),
},
"unexpected number of IDs in metadata": {
ctx: idCtx(t, "12345", "67891"),
expectedResp: "",
expectedErr: fmt.Errorf("unexpected number of IDs in metadata: (2)"),
},
"empty multiplex ID in metadata": {
ctx: idCtx(t, ""),
expectedResp: "",
expectedErr: fmt.Errorf("empty multiplex ID in metadata"),
},
"happy path, id is returned from metadata": {
ctx: idCtx(t, "12345"),
expectedResp: "12345",
expectedErr: nil,
},
}
for name, test := range tests {
t.Run(name, func(t *testing.T) {
resp, err := getMultiplexIDFromContext(test.ctx)
if test.expectedErr != nil && test.expectedErr.Error() != "" && err == nil {
t.Fatalf("err expected, got nil")
} else if !reflect.DeepEqual(err, test.expectedErr) {
t.Fatalf("Actual error: %#v\nExpected error: %#v", err, test.expectedErr)
}
if test.expectedErr != nil && test.expectedErr.Error() == "" && err != nil {
t.Fatalf("no error expected, got: %s", err)
}
if !reflect.DeepEqual(resp, test.expectedResp) {
t.Fatalf("Actual response: %#v\nExpected response: %#v", resp, test.expectedResp)
}
})
}
}
// testGrpcServer is a test helper that returns a context with an ID set in its
// metadata and a gRPCServer instance for a multiplexed plugin
func testGrpcServer(t *testing.T, db Database) (context.Context, gRPCServer) {
t.Helper()
g := gRPCServer{
factoryFunc: func() (interface{}, error) {
return db, nil
},
instances: make(map[string]Database),
}
id := "12345"
idCtx := idCtx(t, id)
g.instances[id] = db
return idCtx, g
}
// testGrpcServerSingleImpl is a test helper that returns a context and a
// gRPCServer instance for a non-multiplexed plugin
func testGrpcServerSingleImpl(t *testing.T, db Database) (context.Context, gRPCServer) {
t.Helper()
return context.Background(), gRPCServer{
singleImpl: db,
}
}
// idCtx is a test helper that will return a context with the IDs set in its
// metadata
func idCtx(t *testing.T, ids ...string) context.Context {
t.Helper()
// Context doesn't need to timeout since this is just passed through
ctx := context.Background()
md := metadata.MD{}
for _, id := range ids {
md.Append(pluginutil.MultiplexingCtxKey, id)
}
return metadata.NewIncomingContext(ctx, md)
}
func marshal(t *testing.T, m map[string]interface{}) *structpb.Struct {
t.Helper()

View File

@ -3,19 +3,14 @@ package dbplugin
import (
"context"
"errors"
"sync"
log "github.com/hashicorp/go-hclog"
plugin "github.com/hashicorp/go-plugin"
"github.com/hashicorp/vault/sdk/database/dbplugin/v5/proto"
"github.com/hashicorp/vault/sdk/helper/pluginutil"
)
// DatabasePluginClient embeds a databasePluginRPCClient and wraps it's Close
// method to also call Kill() on the plugin.Client.
type DatabasePluginClient struct {
client *plugin.Client
sync.Mutex
client pluginutil.PluginClient
Database
}
@ -23,42 +18,31 @@ type DatabasePluginClient struct {
// and kill the plugin.
func (dc *DatabasePluginClient) Close() error {
err := dc.Database.Close()
dc.client.Kill()
dc.client.Close()
return err
}
// pluginSets is the map of plugins we can dispense.
var PluginSets = map[int]plugin.PluginSet{
5: {
"database": &GRPCDatabasePlugin{},
},
6: {
"database": &GRPCDatabasePlugin{},
},
}
// NewPluginClient returns a databaseRPCClient with a connection to a running
// plugin. The client is wrapped in a DatabasePluginClient object to ensure the
// plugin is killed on call of Close().
func NewPluginClient(ctx context.Context, sys pluginutil.RunnerUtil, pluginRunner *pluginutil.PluginRunner, logger log.Logger, isMetadataMode bool) (Database, error) {
// pluginSets is the map of plugins we can dispense.
pluginSets := map[int]plugin.PluginSet{
5: {
"database": new(GRPCDatabasePlugin),
},
}
client, err := pluginRunner.RunConfig(ctx,
pluginutil.Runner(sys),
pluginutil.PluginSets(pluginSets),
pluginutil.HandshakeConfig(handshakeConfig),
pluginutil.Logger(logger),
pluginutil.MetadataMode(isMetadataMode),
pluginutil.AutoMTLS(true),
)
if err != nil {
return nil, err
}
// Connect via RPC
rpcClient, err := client.Client()
// plugin.
func NewPluginClient(ctx context.Context, sys pluginutil.RunnerUtil, config pluginutil.PluginClientConfig) (Database, error) {
pluginClient, err := sys.NewPluginClient(ctx, config)
if err != nil {
return nil, err
}
// Request the plugin
raw, err := rpcClient.Dispense("database")
raw, err := pluginClient.Dispense("database")
if err != nil {
return nil, err
}
@ -66,16 +50,19 @@ func NewPluginClient(ctx context.Context, sys pluginutil.RunnerUtil, pluginRunne
// We should have a database type now. This feels like a normal interface
// implementation but is in fact over an RPC connection.
var db Database
switch raw.(type) {
switch c := raw.(type) {
case gRPCClient:
db = raw.(gRPCClient)
// This is an abstraction leak from go-plugin but it is necessary in
// order to enable multiplexing on multiplexed plugins
c.client = proto.NewDatabaseClient(pluginClient.Conn())
db = c
default:
return nil, errors.New("unsupported client type")
}
// Wrap RPC implementation in DatabasePluginClient
return &DatabasePluginClient{
client: client,
client: pluginClient,
Database: db,
}, nil
}

View File

@ -40,8 +40,17 @@ func PluginFactory(ctx context.Context, pluginName string, sys pluginutil.LookRu
transport = "builtin"
} else {
config := pluginutil.PluginClientConfig{
Name: pluginName,
PluginType: consts.PluginTypeDatabase,
PluginSets: PluginSets,
HandshakeConfig: HandshakeConfig,
Logger: namedLogger,
IsMetadataMode: false,
AutoMTLS: true,
}
// create a DatabasePluginClient instance
db, err = NewPluginClient(ctx, sys, pluginRunner, namedLogger, false)
db, err = NewPluginClient(ctx, sys, config)
if err != nil {
return nil, err
}
@ -59,6 +68,7 @@ func PluginFactory(ctx context.Context, pluginName string, sys pluginutil.LookRu
if err != nil {
return nil, errwrap.Wrapf("error getting plugin type: {{err}}", err)
}
logger.Debug("got database plugin instance", "type", typeStr)
// Wrap with metrics middleware
db = &databaseMetricsMiddleware{

View File

@ -31,7 +31,49 @@ func ServeConfig(db Database) *plugin.ServeConfig {
}
conf := &plugin.ServeConfig{
HandshakeConfig: handshakeConfig,
HandshakeConfig: HandshakeConfig,
VersionedPlugins: pluginSets,
GRPCServer: plugin.DefaultGRPCServer,
}
return conf
}
func ServeMultiplex(factory Factory) {
plugin.Serve(ServeConfigMultiplex(factory))
}
func ServeConfigMultiplex(factory Factory) *plugin.ServeConfig {
err := pluginutil.OptionallyEnableMlock()
if err != nil {
fmt.Println(err)
return nil
}
db, err := factory()
if err != nil {
fmt.Println(err)
return nil
}
database := db.(Database)
// pluginSets is the map of plugins we can dispense.
pluginSets := map[int]plugin.PluginSet{
5: {
"database": &GRPCDatabasePlugin{
Impl: database,
},
},
6: {
"database": &GRPCDatabasePlugin{
FactoryFunc: factory,
},
},
}
conf := &plugin.ServeConfig{
HandshakeConfig: HandshakeConfig,
VersionedPlugins: pluginSets,
GRPCServer: plugin.DefaultGRPCServer,
}

View File

@ -1,7 +1,7 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.27.1
// protoc v3.17.3
// protoc v3.19.4
// source: sdk/database/dbplugin/v5/proto/database.proto
package proto

View File

@ -0,0 +1,47 @@
package pluginutil
import (
context "context"
"fmt"
grpc "google.golang.org/grpc"
codes "google.golang.org/grpc/codes"
status "google.golang.org/grpc/status"
)
type PluginMultiplexingServerImpl struct {
UnimplementedPluginMultiplexingServer
Supported bool
}
func (pm PluginMultiplexingServerImpl) MultiplexingSupport(ctx context.Context, req *MultiplexingSupportRequest) (*MultiplexingSupportResponse, error) {
return &MultiplexingSupportResponse{
Supported: pm.Supported,
}, nil
}
func MultiplexingSupported(ctx context.Context, cc grpc.ClientConnInterface) (bool, error) {
if cc == nil {
return false, fmt.Errorf("client connection is nil")
}
req := new(MultiplexingSupportRequest)
resp, err := NewPluginMultiplexingClient(cc).MultiplexingSupport(ctx, req)
if err != nil {
// If the server does not implement the multiplexing server then we can
// assume it is not multiplexed
if status.Code(err) == codes.Unimplemented {
return false, nil
}
return false, err
}
if resp == nil {
// Somehow got a nil response, assume not multiplexed
return false, nil
}
return resp.Supported, nil
}

View File

@ -0,0 +1,213 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.27.1
// protoc v3.19.4
// source: sdk/helper/pluginutil/multiplexing.proto
package pluginutil
import (
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
reflect "reflect"
sync "sync"
)
const (
// Verify that this generated code is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
// Verify that runtime/protoimpl is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
type MultiplexingSupportRequest struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
}
func (x *MultiplexingSupportRequest) Reset() {
*x = MultiplexingSupportRequest{}
if protoimpl.UnsafeEnabled {
mi := &file_sdk_helper_pluginutil_multiplexing_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *MultiplexingSupportRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*MultiplexingSupportRequest) ProtoMessage() {}
func (x *MultiplexingSupportRequest) ProtoReflect() protoreflect.Message {
mi := &file_sdk_helper_pluginutil_multiplexing_proto_msgTypes[0]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use MultiplexingSupportRequest.ProtoReflect.Descriptor instead.
func (*MultiplexingSupportRequest) Descriptor() ([]byte, []int) {
return file_sdk_helper_pluginutil_multiplexing_proto_rawDescGZIP(), []int{0}
}
type MultiplexingSupportResponse struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Supported bool `protobuf:"varint,1,opt,name=supported,proto3" json:"supported,omitempty"`
}
func (x *MultiplexingSupportResponse) Reset() {
*x = MultiplexingSupportResponse{}
if protoimpl.UnsafeEnabled {
mi := &file_sdk_helper_pluginutil_multiplexing_proto_msgTypes[1]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *MultiplexingSupportResponse) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*MultiplexingSupportResponse) ProtoMessage() {}
func (x *MultiplexingSupportResponse) ProtoReflect() protoreflect.Message {
mi := &file_sdk_helper_pluginutil_multiplexing_proto_msgTypes[1]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use MultiplexingSupportResponse.ProtoReflect.Descriptor instead.
func (*MultiplexingSupportResponse) Descriptor() ([]byte, []int) {
return file_sdk_helper_pluginutil_multiplexing_proto_rawDescGZIP(), []int{1}
}
func (x *MultiplexingSupportResponse) GetSupported() bool {
if x != nil {
return x.Supported
}
return false
}
var File_sdk_helper_pluginutil_multiplexing_proto protoreflect.FileDescriptor
var file_sdk_helper_pluginutil_multiplexing_proto_rawDesc = []byte{
0x0a, 0x28, 0x73, 0x64, 0x6b, 0x2f, 0x68, 0x65, 0x6c, 0x70, 0x65, 0x72, 0x2f, 0x70, 0x6c, 0x75,
0x67, 0x69, 0x6e, 0x75, 0x74, 0x69, 0x6c, 0x2f, 0x6d, 0x75, 0x6c, 0x74, 0x69, 0x70, 0x6c, 0x65,
0x78, 0x69, 0x6e, 0x67, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x17, 0x70, 0x6c, 0x75, 0x67,
0x69, 0x6e, 0x75, 0x74, 0x69, 0x6c, 0x2e, 0x6d, 0x75, 0x6c, 0x74, 0x69, 0x70, 0x6c, 0x65, 0x78,
0x69, 0x6e, 0x67, 0x22, 0x1c, 0x0a, 0x1a, 0x4d, 0x75, 0x6c, 0x74, 0x69, 0x70, 0x6c, 0x65, 0x78,
0x69, 0x6e, 0x67, 0x53, 0x75, 0x70, 0x70, 0x6f, 0x72, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73,
0x74, 0x22, 0x3b, 0x0a, 0x1b, 0x4d, 0x75, 0x6c, 0x74, 0x69, 0x70, 0x6c, 0x65, 0x78, 0x69, 0x6e,
0x67, 0x53, 0x75, 0x70, 0x70, 0x6f, 0x72, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65,
0x12, 0x1c, 0x0a, 0x09, 0x73, 0x75, 0x70, 0x70, 0x6f, 0x72, 0x74, 0x65, 0x64, 0x18, 0x01, 0x20,
0x01, 0x28, 0x08, 0x52, 0x09, 0x73, 0x75, 0x70, 0x70, 0x6f, 0x72, 0x74, 0x65, 0x64, 0x32, 0x97,
0x01, 0x0a, 0x12, 0x50, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x4d, 0x75, 0x6c, 0x74, 0x69, 0x70, 0x6c,
0x65, 0x78, 0x69, 0x6e, 0x67, 0x12, 0x80, 0x01, 0x0a, 0x13, 0x4d, 0x75, 0x6c, 0x74, 0x69, 0x70,
0x6c, 0x65, 0x78, 0x69, 0x6e, 0x67, 0x53, 0x75, 0x70, 0x70, 0x6f, 0x72, 0x74, 0x12, 0x33, 0x2e,
0x70, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x75, 0x74, 0x69, 0x6c, 0x2e, 0x6d, 0x75, 0x6c, 0x74, 0x69,
0x70, 0x6c, 0x65, 0x78, 0x69, 0x6e, 0x67, 0x2e, 0x4d, 0x75, 0x6c, 0x74, 0x69, 0x70, 0x6c, 0x65,
0x78, 0x69, 0x6e, 0x67, 0x53, 0x75, 0x70, 0x70, 0x6f, 0x72, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65,
0x73, 0x74, 0x1a, 0x34, 0x2e, 0x70, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x75, 0x74, 0x69, 0x6c, 0x2e,
0x6d, 0x75, 0x6c, 0x74, 0x69, 0x70, 0x6c, 0x65, 0x78, 0x69, 0x6e, 0x67, 0x2e, 0x4d, 0x75, 0x6c,
0x74, 0x69, 0x70, 0x6c, 0x65, 0x78, 0x69, 0x6e, 0x67, 0x53, 0x75, 0x70, 0x70, 0x6f, 0x72, 0x74,
0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x42, 0x32, 0x5a, 0x30, 0x67, 0x69, 0x74, 0x68,
0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x68, 0x61, 0x73, 0x68, 0x69, 0x63, 0x6f, 0x72, 0x70,
0x2f, 0x76, 0x61, 0x75, 0x6c, 0x74, 0x2f, 0x73, 0x64, 0x6b, 0x2f, 0x68, 0x65, 0x6c, 0x70, 0x65,
0x72, 0x2f, 0x70, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x75, 0x74, 0x69, 0x6c, 0x62, 0x06, 0x70, 0x72,
0x6f, 0x74, 0x6f, 0x33,
}
var (
file_sdk_helper_pluginutil_multiplexing_proto_rawDescOnce sync.Once
file_sdk_helper_pluginutil_multiplexing_proto_rawDescData = file_sdk_helper_pluginutil_multiplexing_proto_rawDesc
)
func file_sdk_helper_pluginutil_multiplexing_proto_rawDescGZIP() []byte {
file_sdk_helper_pluginutil_multiplexing_proto_rawDescOnce.Do(func() {
file_sdk_helper_pluginutil_multiplexing_proto_rawDescData = protoimpl.X.CompressGZIP(file_sdk_helper_pluginutil_multiplexing_proto_rawDescData)
})
return file_sdk_helper_pluginutil_multiplexing_proto_rawDescData
}
var file_sdk_helper_pluginutil_multiplexing_proto_msgTypes = make([]protoimpl.MessageInfo, 2)
var file_sdk_helper_pluginutil_multiplexing_proto_goTypes = []interface{}{
(*MultiplexingSupportRequest)(nil), // 0: pluginutil.multiplexing.MultiplexingSupportRequest
(*MultiplexingSupportResponse)(nil), // 1: pluginutil.multiplexing.MultiplexingSupportResponse
}
var file_sdk_helper_pluginutil_multiplexing_proto_depIdxs = []int32{
0, // 0: pluginutil.multiplexing.PluginMultiplexing.MultiplexingSupport:input_type -> pluginutil.multiplexing.MultiplexingSupportRequest
1, // 1: pluginutil.multiplexing.PluginMultiplexing.MultiplexingSupport:output_type -> pluginutil.multiplexing.MultiplexingSupportResponse
1, // [1:2] is the sub-list for method output_type
0, // [0:1] is the sub-list for method input_type
0, // [0:0] is the sub-list for extension type_name
0, // [0:0] is the sub-list for extension extendee
0, // [0:0] is the sub-list for field type_name
}
func init() { file_sdk_helper_pluginutil_multiplexing_proto_init() }
func file_sdk_helper_pluginutil_multiplexing_proto_init() {
if File_sdk_helper_pluginutil_multiplexing_proto != nil {
return
}
if !protoimpl.UnsafeEnabled {
file_sdk_helper_pluginutil_multiplexing_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*MultiplexingSupportRequest); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_sdk_helper_pluginutil_multiplexing_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*MultiplexingSupportResponse); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_sdk_helper_pluginutil_multiplexing_proto_rawDesc,
NumEnums: 0,
NumMessages: 2,
NumExtensions: 0,
NumServices: 1,
},
GoTypes: file_sdk_helper_pluginutil_multiplexing_proto_goTypes,
DependencyIndexes: file_sdk_helper_pluginutil_multiplexing_proto_depIdxs,
MessageInfos: file_sdk_helper_pluginutil_multiplexing_proto_msgTypes,
}.Build()
File_sdk_helper_pluginutil_multiplexing_proto = out.File
file_sdk_helper_pluginutil_multiplexing_proto_rawDesc = nil
file_sdk_helper_pluginutil_multiplexing_proto_goTypes = nil
file_sdk_helper_pluginutil_multiplexing_proto_depIdxs = nil
}

View File

@ -0,0 +1,13 @@
syntax = "proto3";
package pluginutil.multiplexing;
option go_package = "github.com/hashicorp/vault/sdk/helper/pluginutil";
message MultiplexingSupportRequest {}
message MultiplexingSupportResponse {
bool supported = 1;
}
service PluginMultiplexing {
rpc MultiplexingSupport(MultiplexingSupportRequest) returns (MultiplexingSupportResponse);
}

View File

@ -0,0 +1,101 @@
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
package pluginutil
import (
context "context"
grpc "google.golang.org/grpc"
codes "google.golang.org/grpc/codes"
status "google.golang.org/grpc/status"
)
// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
// Requires gRPC-Go v1.32.0 or later.
const _ = grpc.SupportPackageIsVersion7
// PluginMultiplexingClient is the client API for PluginMultiplexing service.
//
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
type PluginMultiplexingClient interface {
MultiplexingSupport(ctx context.Context, in *MultiplexingSupportRequest, opts ...grpc.CallOption) (*MultiplexingSupportResponse, error)
}
type pluginMultiplexingClient struct {
cc grpc.ClientConnInterface
}
func NewPluginMultiplexingClient(cc grpc.ClientConnInterface) PluginMultiplexingClient {
return &pluginMultiplexingClient{cc}
}
func (c *pluginMultiplexingClient) MultiplexingSupport(ctx context.Context, in *MultiplexingSupportRequest, opts ...grpc.CallOption) (*MultiplexingSupportResponse, error) {
out := new(MultiplexingSupportResponse)
err := c.cc.Invoke(ctx, "/pluginutil.multiplexing.PluginMultiplexing/MultiplexingSupport", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
// PluginMultiplexingServer is the server API for PluginMultiplexing service.
// All implementations must embed UnimplementedPluginMultiplexingServer
// for forward compatibility
type PluginMultiplexingServer interface {
MultiplexingSupport(context.Context, *MultiplexingSupportRequest) (*MultiplexingSupportResponse, error)
mustEmbedUnimplementedPluginMultiplexingServer()
}
// UnimplementedPluginMultiplexingServer must be embedded to have forward compatible implementations.
type UnimplementedPluginMultiplexingServer struct {
}
func (UnimplementedPluginMultiplexingServer) MultiplexingSupport(context.Context, *MultiplexingSupportRequest) (*MultiplexingSupportResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method MultiplexingSupport not implemented")
}
func (UnimplementedPluginMultiplexingServer) mustEmbedUnimplementedPluginMultiplexingServer() {}
// UnsafePluginMultiplexingServer may be embedded to opt out of forward compatibility for this service.
// Use of this interface is not recommended, as added methods to PluginMultiplexingServer will
// result in compilation errors.
type UnsafePluginMultiplexingServer interface {
mustEmbedUnimplementedPluginMultiplexingServer()
}
func RegisterPluginMultiplexingServer(s grpc.ServiceRegistrar, srv PluginMultiplexingServer) {
s.RegisterService(&PluginMultiplexing_ServiceDesc, srv)
}
func _PluginMultiplexing_MultiplexingSupport_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(MultiplexingSupportRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(PluginMultiplexingServer).MultiplexingSupport(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/pluginutil.multiplexing.PluginMultiplexing/MultiplexingSupport",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(PluginMultiplexingServer).MultiplexingSupport(ctx, req.(*MultiplexingSupportRequest))
}
return interceptor(ctx, in, info, handler)
}
// PluginMultiplexing_ServiceDesc is the grpc.ServiceDesc for PluginMultiplexing service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
var PluginMultiplexing_ServiceDesc = grpc.ServiceDesc{
ServiceName: "pluginutil.multiplexing.PluginMultiplexing",
HandlerType: (*PluginMultiplexingServer)(nil),
Methods: []grpc.MethodDesc{
{
MethodName: "MultiplexingSupport",
Handler: _PluginMultiplexing_MultiplexingSupport_Handler,
},
},
Streams: []grpc.StreamDesc{},
Metadata: "sdk/helper/pluginutil/multiplexing.proto",
}

View File

@ -9,9 +9,21 @@ import (
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-plugin"
"github.com/hashicorp/vault/sdk/helper/consts"
"github.com/hashicorp/vault/sdk/version"
)
type PluginClientConfig struct {
Name string
PluginType consts.PluginType
PluginSets map[int]plugin.PluginSet
HandshakeConfig plugin.HandshakeConfig
Logger log.Logger
IsMetadataMode bool
AutoMTLS bool
MLock bool
}
type runConfig struct {
// Provided by PluginRunner
command string
@ -21,12 +33,9 @@ type runConfig struct {
// Initialized with what's in PluginRunner.Env, but can be added to
env []string
wrapper RunnerUtil
pluginSets map[int]plugin.PluginSet
hs plugin.HandshakeConfig
logger log.Logger
isMetadataMode bool
autoMTLS bool
wrapper RunnerUtil
PluginClientConfig
}
func (rc runConfig) makeConfig(ctx context.Context) (*plugin.ClientConfig, error) {
@ -34,19 +43,19 @@ func (rc runConfig) makeConfig(ctx context.Context) (*plugin.ClientConfig, error
cmd.Env = append(cmd.Env, rc.env...)
// Add the mlock setting to the ENV of the plugin
if rc.wrapper != nil && rc.wrapper.MlockEnabled() {
if rc.MLock || (rc.wrapper != nil && rc.wrapper.MlockEnabled()) {
cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", PluginMlockEnabled, "true"))
}
cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", PluginVaultVersionEnv, version.GetVersion().Version))
if rc.isMetadataMode {
rc.logger = rc.logger.With("metadata", "true")
if rc.IsMetadataMode {
rc.Logger = rc.Logger.With("metadata", "true")
}
metadataEnv := fmt.Sprintf("%s=%t", PluginMetadataModeEnv, rc.isMetadataMode)
metadataEnv := fmt.Sprintf("%s=%t", PluginMetadataModeEnv, rc.IsMetadataMode)
cmd.Env = append(cmd.Env, metadataEnv)
var clientTLSConfig *tls.Config
if !rc.autoMTLS && !rc.isMetadataMode {
if !rc.AutoMTLS && !rc.IsMetadataMode {
// Get a CA TLS Certificate
certBytes, key, err := generateCert()
if err != nil {
@ -76,17 +85,17 @@ func (rc runConfig) makeConfig(ctx context.Context) (*plugin.ClientConfig, error
}
clientConfig := &plugin.ClientConfig{
HandshakeConfig: rc.hs,
VersionedPlugins: rc.pluginSets,
HandshakeConfig: rc.HandshakeConfig,
VersionedPlugins: rc.PluginSets,
Cmd: cmd,
SecureConfig: secureConfig,
TLSConfig: clientTLSConfig,
Logger: rc.logger,
Logger: rc.Logger,
AllowedProtocols: []plugin.Protocol{
plugin.ProtocolNetRPC,
plugin.ProtocolGRPC,
},
AutoMTLS: rc.autoMTLS,
AutoMTLS: rc.AutoMTLS,
}
return clientConfig, nil
}
@ -117,31 +126,37 @@ func Runner(wrapper RunnerUtil) RunOpt {
func PluginSets(pluginSets map[int]plugin.PluginSet) RunOpt {
return func(rc *runConfig) {
rc.pluginSets = pluginSets
rc.PluginSets = pluginSets
}
}
func HandshakeConfig(hs plugin.HandshakeConfig) RunOpt {
return func(rc *runConfig) {
rc.hs = hs
rc.HandshakeConfig = hs
}
}
func Logger(logger log.Logger) RunOpt {
return func(rc *runConfig) {
rc.logger = logger
rc.Logger = logger
}
}
func MetadataMode(isMetadataMode bool) RunOpt {
return func(rc *runConfig) {
rc.isMetadataMode = isMetadataMode
rc.IsMetadataMode = isMetadataMode
}
}
func AutoMTLS(autoMTLS bool) RunOpt {
return func(rc *runConfig) {
rc.autoMTLS = autoMTLS
rc.AutoMTLS = autoMTLS
}
}
func MLock(mlock bool) RunOpt {
return func(rc *runConfig) {
rc.MLock = mlock
}
}

View File

@ -38,19 +38,21 @@ func TestMakeConfig(t *testing.T) {
args: []string{"foo", "bar"},
sha256: []byte("some_sha256"),
env: []string{"initial=true"},
pluginSets: map[int]plugin.PluginSet{
1: {
"bogus": nil,
PluginClientConfig: PluginClientConfig{
PluginSets: map[int]plugin.PluginSet{
1: {
"bogus": nil,
},
},
HandshakeConfig: plugin.HandshakeConfig{
ProtocolVersion: 1,
MagicCookieKey: "magic_cookie_key",
MagicCookieValue: "magic_cookie_value",
},
Logger: hclog.NewNullLogger(),
IsMetadataMode: true,
AutoMTLS: false,
},
hs: plugin.HandshakeConfig{
ProtocolVersion: 1,
MagicCookieKey: "magic_cookie_key",
MagicCookieValue: "magic_cookie_value",
},
logger: hclog.NewNullLogger(),
isMetadataMode: true,
autoMTLS: false,
},
responseWrapInfoTimes: 0,
@ -97,19 +99,21 @@ func TestMakeConfig(t *testing.T) {
args: []string{"foo", "bar"},
sha256: []byte("some_sha256"),
env: []string{"initial=true"},
pluginSets: map[int]plugin.PluginSet{
1: {
"bogus": nil,
PluginClientConfig: PluginClientConfig{
PluginSets: map[int]plugin.PluginSet{
1: {
"bogus": nil,
},
},
HandshakeConfig: plugin.HandshakeConfig{
ProtocolVersion: 1,
MagicCookieKey: "magic_cookie_key",
MagicCookieValue: "magic_cookie_value",
},
Logger: hclog.NewNullLogger(),
IsMetadataMode: false,
AutoMTLS: false,
},
hs: plugin.HandshakeConfig{
ProtocolVersion: 1,
MagicCookieKey: "magic_cookie_key",
MagicCookieValue: "magic_cookie_value",
},
logger: hclog.NewNullLogger(),
isMetadataMode: false,
autoMTLS: false,
},
responseWrapInfo: &wrapping.ResponseWrapInfo{
@ -161,19 +165,21 @@ func TestMakeConfig(t *testing.T) {
args: []string{"foo", "bar"},
sha256: []byte("some_sha256"),
env: []string{"initial=true"},
pluginSets: map[int]plugin.PluginSet{
1: {
"bogus": nil,
PluginClientConfig: PluginClientConfig{
PluginSets: map[int]plugin.PluginSet{
1: {
"bogus": nil,
},
},
HandshakeConfig: plugin.HandshakeConfig{
ProtocolVersion: 1,
MagicCookieKey: "magic_cookie_key",
MagicCookieValue: "magic_cookie_value",
},
Logger: hclog.NewNullLogger(),
IsMetadataMode: true,
AutoMTLS: true,
},
hs: plugin.HandshakeConfig{
ProtocolVersion: 1,
MagicCookieKey: "magic_cookie_key",
MagicCookieValue: "magic_cookie_value",
},
logger: hclog.NewNullLogger(),
isMetadataMode: true,
autoMTLS: true,
},
responseWrapInfoTimes: 0,
@ -220,19 +226,21 @@ func TestMakeConfig(t *testing.T) {
args: []string{"foo", "bar"},
sha256: []byte("some_sha256"),
env: []string{"initial=true"},
pluginSets: map[int]plugin.PluginSet{
1: {
"bogus": nil,
PluginClientConfig: PluginClientConfig{
PluginSets: map[int]plugin.PluginSet{
1: {
"bogus": nil,
},
},
HandshakeConfig: plugin.HandshakeConfig{
ProtocolVersion: 1,
MagicCookieKey: "magic_cookie_key",
MagicCookieValue: "magic_cookie_value",
},
Logger: hclog.NewNullLogger(),
IsMetadataMode: false,
AutoMTLS: true,
},
hs: plugin.HandshakeConfig{
ProtocolVersion: 1,
MagicCookieKey: "magic_cookie_key",
MagicCookieValue: "magic_cookie_value",
},
logger: hclog.NewNullLogger(),
isMetadataMode: false,
autoMTLS: true,
},
responseWrapInfoTimes: 0,
@ -329,6 +337,11 @@ type mockRunnerUtil struct {
mock.Mock
}
func (m *mockRunnerUtil) NewPluginClient(ctx context.Context, config PluginClientConfig) (PluginClient, error) {
args := m.Called(ctx, config)
return args.Get(0).(PluginClient), args.Error(1)
}
func (m *mockRunnerUtil) ResponseWrapData(ctx context.Context, data map[string]interface{}, ttl time.Duration, jwt bool) (*wrapping.ResponseWrapInfo, error) {
args := m.Called(ctx, data, ttl, jwt)
return args.Get(0).(*wrapping.ResponseWrapInfo), args.Error(1)

View File

@ -8,6 +8,7 @@ import (
plugin "github.com/hashicorp/go-plugin"
"github.com/hashicorp/vault/sdk/helper/consts"
"github.com/hashicorp/vault/sdk/helper/wrapping"
"google.golang.org/grpc"
)
// Looker defines the plugin Lookup function that looks into the plugin catalog
@ -21,6 +22,7 @@ type Looker interface {
// configuration and wrapping data in a response wrapped token.
// logical.SystemView implementations satisfy this interface.
type RunnerUtil interface {
NewPluginClient(ctx context.Context, config PluginClientConfig) (PluginClient, error)
ResponseWrapData(ctx context.Context, data map[string]interface{}, ttl time.Duration, jwt bool) (*wrapping.ResponseWrapInfo, error)
MlockEnabled() bool
}
@ -31,17 +33,25 @@ type LookRunnerUtil interface {
RunnerUtil
}
type PluginClient interface {
Conn() grpc.ClientConnInterface
plugin.ClientProtocol
}
const MultiplexingCtxKey string = "multiplex_id"
// PluginRunner defines the metadata needed to run a plugin securely with
// go-plugin.
type PluginRunner struct {
Name string `json:"name" structs:"name"`
Type consts.PluginType `json:"type" structs:"type"`
Command string `json:"command" structs:"command"`
Args []string `json:"args" structs:"args"`
Env []string `json:"env" structs:"env"`
Sha256 []byte `json:"sha256" structs:"sha256"`
Builtin bool `json:"builtin" structs:"builtin"`
BuiltinFactory func() (interface{}, error) `json:"-" structs:"-"`
Name string `json:"name" structs:"name"`
Type consts.PluginType `json:"type" structs:"type"`
Command string `json:"command" structs:"command"`
Args []string `json:"args" structs:"args"`
Env []string `json:"env" structs:"env"`
Sha256 []byte `json:"sha256" structs:"sha256"`
Builtin bool `json:"builtin" structs:"builtin"`
BuiltinFactory func() (interface{}, error) `json:"-" structs:"-"`
MultiplexingSupport bool `json:"multiplexing_support" structs:"multiplexing_support"`
}
// Run takes a wrapper RunnerUtil instance along with the go-plugin parameters and

View File

@ -1,7 +1,7 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.27.1
// protoc v3.17.3
// protoc v3.19.4
// source: sdk/logical/identity.proto
package logical

View File

@ -1,7 +1,7 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.27.1
// protoc v3.17.3
// protoc v3.19.4
// source: sdk/logical/plugin.proto
package logical

View File

@ -56,6 +56,10 @@ type SystemView interface {
// name. Returns a PluginRunner or an error if a plugin can not be found.
LookupPlugin(context.Context, string, consts.PluginType) (*pluginutil.PluginRunner, error)
// NewPluginClient returns a client for managing the lifecycle of plugin
// processes
NewPluginClient(ctx context.Context, config pluginutil.PluginClientConfig) (pluginutil.PluginClient, error)
// MlockEnabled returns the configuration setting for enabling mlock on
// plugins.
MlockEnabled() bool
@ -152,6 +156,10 @@ func (d StaticSystemView) ReplicationState() consts.ReplicationState {
return d.ReplicationStateVal
}
func (d StaticSystemView) NewPluginClient(ctx context.Context, config pluginutil.PluginClientConfig) (pluginutil.PluginClient, error) {
return nil, errors.New("NewPluginClient is not implemented in StaticSystemView")
}
func (d StaticSystemView) ResponseWrapData(_ context.Context, data map[string]interface{}, ttl time.Duration, jwt bool) (*wrapping.ResponseWrapInfo, error) {
return nil, errors.New("ResponseWrapData is not implemented in StaticSystemView")
}

View File

@ -99,6 +99,10 @@ func (s *gRPCSystemViewClient) ResponseWrapData(ctx context.Context, data map[st
return info, nil
}
func (s *gRPCSystemViewClient) NewPluginClient(ctx context.Context, config pluginutil.PluginClientConfig) (pluginutil.PluginClient, error) {
return nil, fmt.Errorf("cannot call NewPluginClient from a plugin backend")
}
func (s *gRPCSystemViewClient) LookupPlugin(_ context.Context, _ string, _ consts.PluginType) (*pluginutil.PluginRunner, error) {
return nil, fmt.Errorf("cannot call LookupPlugin from a plugin backend")
}

View File

@ -1,7 +1,7 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.27.1
// protoc v3.17.3
// protoc v3.19.4
// source: sdk/plugin/pb/backend.proto
package pb

View File

@ -1,7 +1,7 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.27.1
// protoc v3.17.3
// protoc v3.19.4
// source: vault/activity/activity_log.proto
package activity

View File

@ -215,6 +215,22 @@ func (d dynamicSystemView) ResponseWrapData(ctx context.Context, data map[string
return resp.WrapInfo, nil
}
func (d dynamicSystemView) NewPluginClient(ctx context.Context, config pluginutil.PluginClientConfig) (pluginutil.PluginClient, error) {
if d.core == nil {
return nil, fmt.Errorf("system view core is nil")
}
if d.core.pluginCatalog == nil {
return nil, fmt.Errorf("system view core plugin catalog is nil")
}
c, err := d.core.pluginCatalog.NewPluginClient(ctx, config)
if err != nil {
return nil, err
}
return c, nil
}
// LookupPlugin looks for a plugin with the given name in the plugin catalog. It
// returns a PluginRunner or an error if no plugin was found.
func (d dynamicSystemView) LookupPlugin(ctx context.Context, name string, pluginType consts.PluginType) (*pluginutil.PluginRunner, error) {

View File

@ -12,6 +12,8 @@ import (
log "github.com/hashicorp/go-hclog"
multierror "github.com/hashicorp/go-multierror"
plugin "github.com/hashicorp/go-plugin"
"github.com/hashicorp/go-secure-stdlib/base62"
v4 "github.com/hashicorp/vault/sdk/database/dbplugin"
v5 "github.com/hashicorp/vault/sdk/database/dbplugin/v5"
"github.com/hashicorp/vault/sdk/helper/consts"
@ -19,6 +21,8 @@ import (
"github.com/hashicorp/vault/sdk/helper/pluginutil"
"github.com/hashicorp/vault/sdk/logical"
backendplugin "github.com/hashicorp/vault/sdk/plugin"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
)
var (
@ -35,21 +39,62 @@ type PluginCatalog struct {
builtinRegistry BuiltinRegistry
catalogView *BarrierView
directory string
logger log.Logger
// externalPlugins holds plugin process connections by plugin name
//
// This allows plugins that suppport multiplexing to use a single grpc
// connection to communicate with multiple "backends". Each backend
// configuration using the same plugin will be routed to the existing
// plugin process.
externalPlugins map[string]*externalPlugin
mlockPlugins bool
lock sync.RWMutex
}
// externalPlugin holds client connections for multiplexed and
// non-multiplexed plugin processes
type externalPlugin struct {
// name is the plugin name
name string
// connections holds client connections by ID
connections map[string]*pluginClient
multiplexingSupport bool
}
// pluginClient represents a connection to a plugin process
type pluginClient struct {
logger log.Logger
// id is the connection ID
id string
// client handles the lifecycle of a plugin process
// multiplexed plugins share the same client
client *plugin.Client
clientConn grpc.ClientConnInterface
cleanupFunc func() error
plugin.ClientProtocol
}
func (c *Core) setupPluginCatalog(ctx context.Context) error {
c.pluginCatalog = &PluginCatalog{
builtinRegistry: c.builtinRegistry,
catalogView: NewBarrierView(c.barrier, pluginCatalogPath),
directory: c.pluginDirectory,
logger: c.logger,
mlockPlugins: c.enableMlock,
}
// Run upgrade if untyped plugins exist
err := c.pluginCatalog.UpgradePlugins(ctx, c.logger)
if err != nil {
c.logger.Error("error while upgrading plugin storage", "error", err)
return err
}
if c.logger.IsInfo() {
@ -59,14 +104,205 @@ func (c *Core) setupPluginCatalog(ctx context.Context) error {
return nil
}
type pluginClientConn struct {
*grpc.ClientConn
id string
}
var _ grpc.ClientConnInterface = &pluginClientConn{}
func (d *pluginClientConn) Invoke(ctx context.Context, method string, args interface{}, reply interface{}, opts ...grpc.CallOption) error {
// Inject ID to the context
md := metadata.Pairs(pluginutil.MultiplexingCtxKey, d.id)
idCtx := metadata.NewOutgoingContext(ctx, md)
return d.ClientConn.Invoke(idCtx, method, args, reply, opts...)
}
func (d *pluginClientConn) NewStream(ctx context.Context, desc *grpc.StreamDesc, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) {
// Inject ID to the context
md := metadata.Pairs(pluginutil.MultiplexingCtxKey, d.id)
idCtx := metadata.NewOutgoingContext(ctx, md)
return d.ClientConn.NewStream(idCtx, desc, method, opts...)
}
func (p *pluginClient) Conn() grpc.ClientConnInterface {
return p.clientConn
}
// Close calls the plugin client's cleanupFunc to do any necessary cleanup on
// the plugin client and the PluginCatalog. This implements the
// plugin.ClientProtocol interface.
func (p *pluginClient) Close() error {
p.logger.Debug("cleaning up plugin client connection", "id", p.id)
return p.cleanupFunc()
}
// cleanupExternalPlugin will kill plugin processes and perform any necessary
// cleanup on the externalPlugins map for multiplexed and non-multiplexed
// plugins. This should be called with the write lock held.
func (c *PluginCatalog) cleanupExternalPlugin(name, id string) error {
extPlugin, ok := c.externalPlugins[name]
if !ok {
return fmt.Errorf("plugin client not found")
}
pluginClient := extPlugin.connections[id]
delete(extPlugin.connections, id)
if !extPlugin.multiplexingSupport {
pluginClient.client.Kill()
if len(extPlugin.connections) == 0 {
delete(c.externalPlugins, name)
}
} else if len(extPlugin.connections) == 0 {
pluginClient.client.Kill()
delete(c.externalPlugins, name)
}
return nil
}
func (c *PluginCatalog) getExternalPlugin(pluginName string) *externalPlugin {
if extPlugin, ok := c.externalPlugins[pluginName]; ok {
return extPlugin
}
return c.newExternalPlugin(pluginName)
}
func (c *PluginCatalog) newExternalPlugin(pluginName string) *externalPlugin {
if c.externalPlugins == nil {
c.externalPlugins = make(map[string]*externalPlugin)
}
extPlugin := &externalPlugin{
connections: make(map[string]*pluginClient),
name: pluginName,
}
c.externalPlugins[pluginName] = extPlugin
return extPlugin
}
// NewPluginClient returns a client for managing the lifecycle of a plugin
// process
func (c *PluginCatalog) NewPluginClient(ctx context.Context, config pluginutil.PluginClientConfig) (*pluginClient, error) {
c.lock.Lock()
defer c.lock.Unlock()
if config.Name == "" {
return nil, fmt.Errorf("no name provided for plugin")
}
if config.PluginType == consts.PluginTypeUnknown {
return nil, fmt.Errorf("no plugin type provided")
}
pluginRunner, err := c.get(ctx, config.Name, config.PluginType)
if err != nil {
return nil, fmt.Errorf("failed to lookup plugin: %w", err)
}
if pluginRunner == nil {
return nil, fmt.Errorf("no plugin found")
}
pc, err := c.newPluginClient(ctx, pluginRunner, config)
return pc, err
}
// newPluginClient returns a client for managing the lifecycle of a plugin
// process. Callers should have the write lock held.
func (c *PluginCatalog) newPluginClient(ctx context.Context, pluginRunner *pluginutil.PluginRunner, config pluginutil.PluginClientConfig) (*pluginClient, error) {
if pluginRunner == nil {
return nil, fmt.Errorf("no plugin found")
}
extPlugin := c.getExternalPlugin(pluginRunner.Name)
id, err := base62.Random(10)
if err != nil {
return nil, err
}
pc := &pluginClient{
id: id,
logger: c.logger.Named(pluginRunner.Name),
cleanupFunc: func() error {
c.lock.Lock()
defer c.lock.Unlock()
return c.cleanupExternalPlugin(pluginRunner.Name, id)
},
}
if !pluginRunner.MultiplexingSupport || len(extPlugin.connections) == 0 {
c.logger.Debug("spawning a new plugin process", "plugin_name", pluginRunner.Name, "id", id)
client, err := pluginRunner.RunConfig(ctx,
pluginutil.PluginSets(config.PluginSets),
pluginutil.HandshakeConfig(config.HandshakeConfig),
pluginutil.Logger(config.Logger),
pluginutil.MetadataMode(config.IsMetadataMode),
pluginutil.MLock(c.mlockPlugins),
// NewPluginClient only supports AutoMTLS today
pluginutil.AutoMTLS(true),
)
if err != nil {
return nil, err
}
pc.client = client
} else {
c.logger.Debug("returning existing plugin client for multiplexed plugin", "id", id)
// get the first client, since they are all the same
for k := range extPlugin.connections {
pc.client = extPlugin.connections[k].client
break
}
if pc.client == nil {
return nil, fmt.Errorf("plugin client is nil")
}
}
// Get the protocol client for this connection.
// Subsequent calls to this will return the same client.
rpcClient, err := pc.client.Client()
if err != nil {
return nil, err
}
clientConn := rpcClient.(*plugin.GRPCClient).Conn
if pluginRunner.MultiplexingSupport {
// Wrap rpcClient with our implementation so that we can inject the
// ID into the context
pc.clientConn = &pluginClientConn{
ClientConn: clientConn,
id: id,
}
} else {
pc.clientConn = clientConn
}
pc.ClientProtocol = rpcClient
extPlugin.connections[id] = pc
extPlugin.name = pluginRunner.Name
extPlugin.multiplexingSupport = pluginRunner.MultiplexingSupport
return extPlugin.connections[id], nil
}
// getPluginTypeFromUnknown will attempt to run the plugin to determine the
// type. It will first attempt to run as a database plugin then a backend
// plugin. Both of these will be run in metadata mode.
func (c *PluginCatalog) getPluginTypeFromUnknown(ctx context.Context, logger log.Logger, plugin *pluginutil.PluginRunner) (consts.PluginType, error) {
// type and if it supports multiplexing. It will first attempt to run as a
// database plugin then a backend plugin. Both of these will be run in metadata
// mode.
func (c *PluginCatalog) getPluginTypeFromUnknown(ctx context.Context, logger log.Logger, plugin *pluginutil.PluginRunner) (consts.PluginType, bool, error) {
merr := &multierror.Error{}
err := isDatabasePlugin(ctx, plugin)
multiplexingSupport, err := c.isDatabasePlugin(ctx, plugin)
if err == nil {
return consts.PluginTypeDatabase, nil
return consts.PluginTypeDatabase, multiplexingSupport, nil
}
merr = multierror.Append(merr, err)
@ -75,7 +311,7 @@ func (c *PluginCatalog) getPluginTypeFromUnknown(ctx context.Context, logger log
if err == nil {
err := client.Setup(ctx, &logical.BackendConfig{})
if err != nil {
return consts.PluginTypeUnknown, err
return consts.PluginTypeUnknown, false, err
}
backendType := client.Type()
@ -83,9 +319,9 @@ func (c *PluginCatalog) getPluginTypeFromUnknown(ctx context.Context, logger log
switch backendType {
case logical.TypeCredential:
return consts.PluginTypeCredential, nil
return consts.PluginTypeCredential, false, nil
case logical.TypeLogical:
return consts.PluginTypeSecrets, nil
return consts.PluginTypeSecrets, false, nil
}
} else {
merr = multierror.Append(merr, err)
@ -102,29 +338,55 @@ func (c *PluginCatalog) getPluginTypeFromUnknown(ctx context.Context, logger log
"error", merr.Error())
}
return consts.PluginTypeUnknown, nil
return consts.PluginTypeUnknown, false, nil
}
func isDatabasePlugin(ctx context.Context, plugin *pluginutil.PluginRunner) error {
// isDatabasePlugin returns true if the plugin supports multiplexing. An error
// is returned if the plugin is not a database plugin.
func (c *PluginCatalog) isDatabasePlugin(ctx context.Context, pluginRunner *pluginutil.PluginRunner) (bool, error) {
merr := &multierror.Error{}
// Attempt to run as database V5 plugin
v5Client, err := v5.NewPluginClient(ctx, nil, plugin, log.NewNullLogger(), true)
config := pluginutil.PluginClientConfig{
Name: pluginRunner.Name,
PluginSets: v5.PluginSets,
PluginType: consts.PluginTypeDatabase,
HandshakeConfig: v5.HandshakeConfig,
Logger: log.NewNullLogger(),
IsMetadataMode: true,
AutoMTLS: true,
}
// Attempt to run as database V5 or V6 multiplexed plugin
v5Client, err := c.newPluginClient(ctx, pluginRunner, config)
if err == nil {
// At this point the pluginRunner does not know if multiplexing is
// supported or not. So we need to ask the plugin client itself.
multiplexingSupport, err := pluginutil.MultiplexingSupported(ctx, v5Client.clientConn)
if err != nil {
return false, err
}
// Close the client and cleanup the plugin process
v5Client.Close()
return nil
err = c.cleanupExternalPlugin(pluginRunner.Name, v5Client.id)
if err != nil {
c.logger.Error("error closing plugin client", "error", err)
}
return multiplexingSupport, nil
}
merr = multierror.Append(merr, fmt.Errorf("failed to load plugin as database v5: %w", err))
v4Client, err := v4.NewPluginClient(ctx, nil, plugin, log.NewNullLogger(), true)
v4Client, err := v4.NewPluginClient(ctx, nil, pluginRunner, log.NewNullLogger(), true)
if err == nil {
// Close the client and cleanup the plugin process
v4Client.Close()
return nil
err = v4Client.Close()
if err != nil {
c.logger.Error("error closing plugin client", "error", err)
}
return false, nil
}
merr = multierror.Append(merr, fmt.Errorf("failed to load plugin as database v4: %w", err))
return merr.ErrorOrNil()
return false, merr.ErrorOrNil()
}
// UpdatePlugins will loop over all the plugins of unknown type and attempt to
@ -170,7 +432,7 @@ func (c *PluginCatalog) UpgradePlugins(ctx context.Context, logger log.Logger) e
cmdOld := plugin.Command
plugin.Command = filepath.Join(c.directory, plugin.Command)
pluginType, err := c.getPluginTypeFromUnknown(ctx, logger, plugin)
pluginType, multiplexingSupport, err := c.getPluginTypeFromUnknown(ctx, logger, plugin)
if err != nil {
retErr = multierror.Append(retErr, fmt.Errorf("could not upgrade plugin %s: %s", pluginName, err))
continue
@ -181,7 +443,7 @@ func (c *PluginCatalog) UpgradePlugins(ctx context.Context, logger log.Logger) e
}
// Upgrade the storage
err = c.setInternal(ctx, pluginName, pluginType, cmdOld, plugin.Args, plugin.Env, plugin.Sha256)
err = c.setInternal(ctx, pluginName, pluginType, multiplexingSupport, cmdOld, plugin.Args, plugin.Env, plugin.Sha256)
if err != nil {
retErr = multierror.Append(retErr, fmt.Errorf("could not upgrade plugin %s: %s", pluginName, err))
continue
@ -269,10 +531,14 @@ func (c *PluginCatalog) Set(ctx context.Context, name string, pluginType consts.
c.lock.Lock()
defer c.lock.Unlock()
return c.setInternal(ctx, name, pluginType, command, args, env, sha256)
// During plugin registration, we can't know if a plugin is multiplexed or
// not until we run it. So we set it to false here. Once started, we ask
// the plugin if it is multiplexed and set this value accordingly.
multiplexingSupport := false
return c.setInternal(ctx, name, pluginType, multiplexingSupport, command, args, env, sha256)
}
func (c *PluginCatalog) setInternal(ctx context.Context, name string, pluginType consts.PluginType, command string, args []string, env []string, sha256 []byte) error {
func (c *PluginCatalog) setInternal(ctx context.Context, name string, pluginType consts.PluginType, multiplexingSupport bool, command string, args []string, env []string, sha256 []byte) error {
// Best effort check to make sure the command isn't breaking out of the
// configured plugin directory.
commandFull := filepath.Join(c.directory, command)
@ -294,15 +560,16 @@ func (c *PluginCatalog) setInternal(ctx context.Context, name string, pluginType
// entryTmp should only be used for the below type check, it uses the
// full command instead of the relative command.
entryTmp := &pluginutil.PluginRunner{
Name: name,
Command: commandFull,
Args: args,
Env: env,
Sha256: sha256,
Builtin: false,
Name: name,
Command: commandFull,
Args: args,
Env: env,
Sha256: sha256,
Builtin: false,
MultiplexingSupport: multiplexingSupport,
}
pluginType, err = c.getPluginTypeFromUnknown(ctx, log.Default(), entryTmp)
pluginType, multiplexingSupport, err = c.getPluginTypeFromUnknown(ctx, log.Default(), entryTmp)
if err != nil {
return err
}
@ -312,13 +579,14 @@ func (c *PluginCatalog) setInternal(ctx context.Context, name string, pluginType
}
entry := &pluginutil.PluginRunner{
Name: name,
Type: pluginType,
Command: command,
Args: args,
Env: env,
Sha256: sha256,
Builtin: false,
Name: name,
Type: pluginType,
Command: command,
Args: args,
Env: env,
Sha256: sha256,
Builtin: false,
MultiplexingSupport: multiplexingSupport,
}
buf, err := json.Marshal(entry)

View File

@ -1,7 +1,7 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.27.1
// protoc v3.17.3
// protoc v3.19.4
// source: vault/request_forwarding_service.proto
package vault