feature: multiplexing support for database plugins (#14033)
* feat: DB plugin multiplexing (#13734) * WIP: start from main and get a plugin runner from core * move MultiplexedClient map to plugin catalog - call sys.NewPluginClient from PluginFactory - updates to getPluginClient - thread through isMetadataMode * use go-plugin ClientProtocol interface - call sys.NewPluginClient from dbplugin.NewPluginClient * move PluginSets to dbplugin package - export dbplugin HandshakeConfig - small refactor of PluginCatalog.getPluginClient * add removeMultiplexedClient; clean up on Close() - call client.Kill from plugin catalog - set rpcClient when muxed client exists * add ID to dbplugin.DatabasePluginClient struct * only create one plugin process per plugin type * update NewPluginClient to return connection ID to sdk - wrap grpc.ClientConn so we can inject the ID into context - get ID from context on grpc server * add v6 multiplexing protocol version * WIP: backwards compat for db plugins * Ensure locking on plugin catalog access - Create public GetPluginClient method for plugin catalog - rename postgres db plugin * use the New constructor for db plugins * grpc server: use write lock for Close and rlock for CRUD * cleanup MultiplexedClients on Close * remove TODO * fix multiplexing regression with grpc server connection * cleanup grpc server instances on close * embed ClientProtocol in Multiplexer interface * use PluginClientConfig arg to make NewPluginClient plugin type agnostic * create a new plugin process for non-muxed plugins * feat: plugin multiplexing: handle plugin client cleanup (#13896) * use closure for plugin client cleanup * log and return errors; add comments * move rpcClient wrapping to core for ID injection * refactor core plugin client and sdk * remove unused ID method * refactor and only wrap clientConn on multiplexed plugins * rename structs and do not export types * Slight refactor of system view interface * Revert "Slight refactor of system view interface" This reverts commit 73d420e5cd2f0415e000c5a9284ea72a58016dd6. * Revert "Revert "Slight refactor of system view interface"" This reverts commit f75527008a1db06d04a23e04c3059674be8adb5f. * only provide pluginRunner arg to the internal newPluginClient method * embed ClientProtocol in pluginClient and name logger * Add back MLock support * remove enableMlock arg from setupPluginCatalog * rename plugin util interface to PluginClient Co-authored-by: Brian Kassouf <bkassouf@hashicorp.com> * feature: multiplexing: fix unit tests (#14007) * fix grpc_server tests and add coverage * update run_config tests * add happy path test case for grpc_server ID from context * update test helpers * feat: multiplexing: handle v5 plugin compiled with new sdk * add mux supported flag and increase test coverage * set multiplexingSupport field in plugin server * remove multiplexingSupport field in sdk * revert postgres to non-multiplexed * add comments on grpc server fields * use pointer receiver on grpc server methods * add changelog * use pointer for grpcserver instance * Use a gRPC server to determine if a plugin should be multiplexed * Apply suggestions from code review Co-authored-by: Brian Kassouf <briankassouf@users.noreply.github.com> * add lock to removePluginClient * add multiplexingSupport field to externalPlugin struct * do not send nil to grpc MultiplexingSupport * check err before logging * handle locking scenario for cleanupFunc * allow ServeConfigMultiplex to dispense v5 plugin * reposition structs, add err check and comments * add comment on locking for cleanupExternalPlugin Co-authored-by: Brian Kassouf <bkassouf@hashicorp.com> Co-authored-by: Brian Kassouf <briankassouf@users.noreply.github.com>
This commit is contained in:
parent
91f5069c03
commit
1cf74e1179
1
Makefile
1
Makefile
|
@ -194,6 +194,7 @@ proto: bootstrap
|
|||
protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative sdk/database/dbplugin/*.proto
|
||||
protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative sdk/database/dbplugin/v5/proto/*.proto
|
||||
protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative sdk/plugin/pb/*.proto
|
||||
protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative sdk/helper/pluginutil/*.proto
|
||||
|
||||
# No additional sed expressions should be added to this list. Going forward
|
||||
# we should just use the variable names choosen by protobuf. These are left
|
||||
|
|
|
@ -110,6 +110,7 @@ func Backend(conf *logical.BackendConfig) *databaseBackend {
|
|||
}
|
||||
|
||||
type databaseBackend struct {
|
||||
// connections holds configured database connections by config name
|
||||
connections map[string]*dbPluginInstance
|
||||
logger log.Logger
|
||||
|
||||
|
|
|
@ -329,6 +329,8 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc {
|
|||
}
|
||||
config.ConnectionDetails = initResp.Config
|
||||
|
||||
b.Logger().Debug("created database object", "name", name, "plugin_name", config.PluginName)
|
||||
|
||||
b.Lock()
|
||||
defer b.Unlock()
|
||||
|
||||
|
@ -365,6 +367,9 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc {
|
|||
"Vault (or the sdk if using a custom plugin) to gain password policy support", config.PluginName))
|
||||
}
|
||||
|
||||
if len(resp.Warnings) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,3 @@
|
|||
```release-note:feature
|
||||
**Database plugin multiplexing**: manage multiple database connections with a single plugin process
|
||||
```
|
|
@ -150,8 +150,8 @@ type registry struct {
|
|||
logicalBackends map[string]logical.Factory
|
||||
}
|
||||
|
||||
// Get returns the BuiltinFactory func for a particular backend plugin
|
||||
// from the plugins map.
|
||||
// Get returns the Factory func for a particular backend plugin from the
|
||||
// plugins map.
|
||||
func (r *registry) Get(name string, pluginType consts.PluginType) (func() (interface{}, error), bool) {
|
||||
switch pluginType {
|
||||
case consts.PluginTypeCredential:
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||
// versions:
|
||||
// protoc-gen-go v1.27.1
|
||||
// protoc v3.17.3
|
||||
// protoc v3.19.4
|
||||
// source: helper/forwarding/types.proto
|
||||
|
||||
package forwarding
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||
// versions:
|
||||
// protoc-gen-go v1.27.1
|
||||
// protoc v3.17.3
|
||||
// protoc v3.19.4
|
||||
// source: helper/identity/mfa/types.proto
|
||||
|
||||
package mfa
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||
// versions:
|
||||
// protoc-gen-go v1.27.1
|
||||
// protoc v3.17.3
|
||||
// protoc v3.19.4
|
||||
// source: helper/identity/types.proto
|
||||
|
||||
package identity
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||
// versions:
|
||||
// protoc-gen-go v1.27.1
|
||||
// protoc v3.17.3
|
||||
// protoc v3.19.4
|
||||
// source: helper/storagepacker/types.proto
|
||||
|
||||
package storagepacker
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||
// versions:
|
||||
// protoc-gen-go v1.27.1
|
||||
// protoc v3.17.3
|
||||
// protoc v3.19.4
|
||||
// source: physical/raft/types.proto
|
||||
|
||||
package raft
|
||||
|
|
|
@ -48,7 +48,6 @@ var (
|
|||
singleQuotedPhrases = regexp.MustCompile(`('.*?')`)
|
||||
)
|
||||
|
||||
// New implements builtinplugins.BuiltinFactory
|
||||
func New() (interface{}, error) {
|
||||
db := new()
|
||||
// Wrap the plugin with middleware to sanitize errors
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||
// versions:
|
||||
// protoc-gen-go v1.27.1
|
||||
// protoc v3.17.3
|
||||
// protoc v3.19.4
|
||||
// source: sdk/database/dbplugin/database.proto
|
||||
|
||||
package dbplugin
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
|
||||
"github.com/hashicorp/go-plugin"
|
||||
"github.com/hashicorp/vault/sdk/database/dbplugin/v5/proto"
|
||||
"github.com/hashicorp/vault/sdk/helper/pluginutil"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
|
@ -12,14 +13,17 @@ import (
|
|||
// a plugin and host. If the handshake fails, a user friendly error is shown.
|
||||
// This prevents users from executing bad plugins or executing a plugin
|
||||
// directory. It is a UX feature, not a security feature.
|
||||
var handshakeConfig = plugin.HandshakeConfig{
|
||||
ProtocolVersion: 5,
|
||||
var HandshakeConfig = plugin.HandshakeConfig{
|
||||
MagicCookieKey: "VAULT_DATABASE_PLUGIN",
|
||||
MagicCookieValue: "926a0820-aea2-be28-51d6-83cdf00e8edb",
|
||||
}
|
||||
|
||||
// Factory is the factory function to create a dbplugin Database.
|
||||
type Factory func() (interface{}, error)
|
||||
|
||||
type GRPCDatabasePlugin struct {
|
||||
Impl Database
|
||||
FactoryFunc Factory
|
||||
Impl Database
|
||||
|
||||
// Embeding this will disable the netRPC protocol
|
||||
plugin.NetRPCUnsupportedPlugin
|
||||
|
@ -31,7 +35,25 @@ var (
|
|||
)
|
||||
|
||||
func (d GRPCDatabasePlugin) GRPCServer(_ *plugin.GRPCBroker, s *grpc.Server) error {
|
||||
proto.RegisterDatabaseServer(s, gRPCServer{impl: d.Impl})
|
||||
var server gRPCServer
|
||||
|
||||
if d.Impl != nil {
|
||||
server = gRPCServer{singleImpl: d.Impl}
|
||||
} else {
|
||||
// multiplexing is supported
|
||||
server = gRPCServer{
|
||||
factoryFunc: d.FactoryFunc,
|
||||
instances: make(map[string]Database),
|
||||
}
|
||||
|
||||
// Multiplexing is enabled for this plugin, register the server so we
|
||||
// can tell the client in Vault.
|
||||
pluginutil.RegisterPluginMultiplexingServer(s, pluginutil.PluginMultiplexingServerImpl{
|
||||
Supported: true,
|
||||
})
|
||||
}
|
||||
|
||||
proto.RegisterDatabaseServer(s, &server)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -3,24 +3,113 @@ package dbplugin
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/golang/protobuf/ptypes"
|
||||
"github.com/hashicorp/vault/sdk/database/dbplugin/v5/proto"
|
||||
"github.com/hashicorp/vault/sdk/helper/pluginutil"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/metadata"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
var _ proto.DatabaseServer = gRPCServer{}
|
||||
var _ proto.DatabaseServer = &gRPCServer{}
|
||||
|
||||
type gRPCServer struct {
|
||||
proto.UnimplementedDatabaseServer
|
||||
|
||||
impl Database
|
||||
// holds the non-multiplexed Database
|
||||
// when this is set the plugin does not support multiplexing
|
||||
singleImpl Database
|
||||
|
||||
// instances holds the multiplexed Databases
|
||||
instances map[string]Database
|
||||
factoryFunc func() (interface{}, error)
|
||||
|
||||
sync.RWMutex
|
||||
}
|
||||
|
||||
func getMultiplexIDFromContext(ctx context.Context) (string, error) {
|
||||
md, ok := metadata.FromIncomingContext(ctx)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("missing plugin multiplexing metadata")
|
||||
}
|
||||
|
||||
multiplexIDs := md[pluginutil.MultiplexingCtxKey]
|
||||
if len(multiplexIDs) != 1 {
|
||||
return "", fmt.Errorf("unexpected number of IDs in metadata: (%d)", len(multiplexIDs))
|
||||
}
|
||||
|
||||
multiplexID := multiplexIDs[0]
|
||||
if multiplexID == "" {
|
||||
return "", fmt.Errorf("empty multiplex ID in metadata")
|
||||
}
|
||||
|
||||
return multiplexID, nil
|
||||
}
|
||||
|
||||
func (g *gRPCServer) getOrCreateDatabase(ctx context.Context) (Database, error) {
|
||||
g.Lock()
|
||||
defer g.Unlock()
|
||||
|
||||
if g.singleImpl != nil {
|
||||
return g.singleImpl, nil
|
||||
}
|
||||
|
||||
id, err := getMultiplexIDFromContext(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if db, ok := g.instances[id]; ok {
|
||||
return db, nil
|
||||
}
|
||||
|
||||
db, err := g.factoryFunc()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
database := db.(Database)
|
||||
g.instances[id] = database
|
||||
|
||||
return database, nil
|
||||
}
|
||||
|
||||
// getDatabaseInternal returns the database but does not hold a lock
|
||||
func (g *gRPCServer) getDatabaseInternal(ctx context.Context) (Database, error) {
|
||||
if g.singleImpl != nil {
|
||||
return g.singleImpl, nil
|
||||
}
|
||||
|
||||
id, err := getMultiplexIDFromContext(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if db, ok := g.instances[id]; ok {
|
||||
return db, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("no database instance found")
|
||||
}
|
||||
|
||||
// getDatabase holds a read lock and returns the database
|
||||
func (g *gRPCServer) getDatabase(ctx context.Context) (Database, error) {
|
||||
g.RLock()
|
||||
impl, err := g.getDatabaseInternal(ctx)
|
||||
g.RUnlock()
|
||||
return impl, err
|
||||
}
|
||||
|
||||
// Initialize the database plugin
|
||||
func (g gRPCServer) Initialize(ctx context.Context, request *proto.InitializeRequest) (*proto.InitializeResponse, error) {
|
||||
func (g *gRPCServer) Initialize(ctx context.Context, request *proto.InitializeRequest) (*proto.InitializeResponse, error) {
|
||||
impl, err := g.getOrCreateDatabase(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rawConfig := structToMap(request.ConfigData)
|
||||
|
||||
dbReq := InitializeRequest{
|
||||
|
@ -28,7 +117,7 @@ func (g gRPCServer) Initialize(ctx context.Context, request *proto.InitializeReq
|
|||
VerifyConnection: request.VerifyConnection,
|
||||
}
|
||||
|
||||
dbResp, err := g.impl.Initialize(ctx, dbReq)
|
||||
dbResp, err := impl.Initialize(ctx, dbReq)
|
||||
if err != nil {
|
||||
return &proto.InitializeResponse{}, status.Errorf(codes.Internal, "failed to initialize: %s", err)
|
||||
}
|
||||
|
@ -45,7 +134,7 @@ func (g gRPCServer) Initialize(ctx context.Context, request *proto.InitializeReq
|
|||
return resp, nil
|
||||
}
|
||||
|
||||
func (g gRPCServer) NewUser(ctx context.Context, req *proto.NewUserRequest) (*proto.NewUserResponse, error) {
|
||||
func (g *gRPCServer) NewUser(ctx context.Context, req *proto.NewUserRequest) (*proto.NewUserResponse, error) {
|
||||
if req.GetUsernameConfig() == nil {
|
||||
return &proto.NewUserResponse{}, status.Errorf(codes.InvalidArgument, "missing username config")
|
||||
}
|
||||
|
@ -60,6 +149,11 @@ func (g gRPCServer) NewUser(ctx context.Context, req *proto.NewUserRequest) (*pr
|
|||
expiration = exp
|
||||
}
|
||||
|
||||
impl, err := g.getDatabase(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dbReq := NewUserRequest{
|
||||
UsernameConfig: UsernameMetadata{
|
||||
DisplayName: req.GetUsernameConfig().GetDisplayName(),
|
||||
|
@ -71,7 +165,7 @@ func (g gRPCServer) NewUser(ctx context.Context, req *proto.NewUserRequest) (*pr
|
|||
RollbackStatements: getStatementsFromProto(req.GetRollbackStatements()),
|
||||
}
|
||||
|
||||
dbResp, err := g.impl.NewUser(ctx, dbReq)
|
||||
dbResp, err := impl.NewUser(ctx, dbReq)
|
||||
if err != nil {
|
||||
return &proto.NewUserResponse{}, status.Errorf(codes.Internal, "unable to create new user: %s", err)
|
||||
}
|
||||
|
@ -82,7 +176,7 @@ func (g gRPCServer) NewUser(ctx context.Context, req *proto.NewUserRequest) (*pr
|
|||
return resp, nil
|
||||
}
|
||||
|
||||
func (g gRPCServer) UpdateUser(ctx context.Context, req *proto.UpdateUserRequest) (*proto.UpdateUserResponse, error) {
|
||||
func (g *gRPCServer) UpdateUser(ctx context.Context, req *proto.UpdateUserRequest) (*proto.UpdateUserResponse, error) {
|
||||
if req.GetUsername() == "" {
|
||||
return &proto.UpdateUserResponse{}, status.Errorf(codes.InvalidArgument, "no username provided")
|
||||
}
|
||||
|
@ -92,7 +186,12 @@ func (g gRPCServer) UpdateUser(ctx context.Context, req *proto.UpdateUserRequest
|
|||
return &proto.UpdateUserResponse{}, status.Errorf(codes.InvalidArgument, err.Error())
|
||||
}
|
||||
|
||||
_, err = g.impl.UpdateUser(ctx, dbReq)
|
||||
impl, err := g.getDatabase(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err = impl.UpdateUser(ctx, dbReq)
|
||||
if err != nil {
|
||||
return &proto.UpdateUserResponse{}, status.Errorf(codes.Internal, "unable to update user: %s", err)
|
||||
}
|
||||
|
@ -144,7 +243,7 @@ func hasChange(dbReq UpdateUserRequest) bool {
|
|||
return false
|
||||
}
|
||||
|
||||
func (g gRPCServer) DeleteUser(ctx context.Context, req *proto.DeleteUserRequest) (*proto.DeleteUserResponse, error) {
|
||||
func (g *gRPCServer) DeleteUser(ctx context.Context, req *proto.DeleteUserRequest) (*proto.DeleteUserResponse, error) {
|
||||
if req.GetUsername() == "" {
|
||||
return &proto.DeleteUserResponse{}, status.Errorf(codes.InvalidArgument, "no username provided")
|
||||
}
|
||||
|
@ -153,15 +252,25 @@ func (g gRPCServer) DeleteUser(ctx context.Context, req *proto.DeleteUserRequest
|
|||
Statements: getStatementsFromProto(req.GetStatements()),
|
||||
}
|
||||
|
||||
_, err := g.impl.DeleteUser(ctx, dbReq)
|
||||
impl, err := g.getDatabase(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err = impl.DeleteUser(ctx, dbReq)
|
||||
if err != nil {
|
||||
return &proto.DeleteUserResponse{}, status.Errorf(codes.Internal, "unable to delete user: %s", err)
|
||||
}
|
||||
return &proto.DeleteUserResponse{}, nil
|
||||
}
|
||||
|
||||
func (g gRPCServer) Type(ctx context.Context, _ *proto.Empty) (*proto.TypeResponse, error) {
|
||||
t, err := g.impl.Type()
|
||||
func (g *gRPCServer) Type(ctx context.Context, _ *proto.Empty) (*proto.TypeResponse, error) {
|
||||
impl, err := g.getOrCreateDatabase(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
t, err := impl.Type()
|
||||
if err != nil {
|
||||
return &proto.TypeResponse{}, status.Errorf(codes.Internal, "unable to retrieve type: %s", err)
|
||||
}
|
||||
|
@ -172,11 +281,29 @@ func (g gRPCServer) Type(ctx context.Context, _ *proto.Empty) (*proto.TypeRespon
|
|||
return resp, nil
|
||||
}
|
||||
|
||||
func (g gRPCServer) Close(ctx context.Context, _ *proto.Empty) (*proto.Empty, error) {
|
||||
err := g.impl.Close()
|
||||
func (g *gRPCServer) Close(ctx context.Context, _ *proto.Empty) (*proto.Empty, error) {
|
||||
g.Lock()
|
||||
defer g.Unlock()
|
||||
|
||||
impl, err := g.getDatabaseInternal(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = impl.Close()
|
||||
if err != nil {
|
||||
return &proto.Empty{}, status.Errorf(codes.Internal, "unable to close database plugin: %s", err)
|
||||
}
|
||||
|
||||
if g.singleImpl == nil {
|
||||
// only cleanup instances map when multiplexing is supported
|
||||
id, err := getMultiplexIDFromContext(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
delete(g.instances, id)
|
||||
}
|
||||
|
||||
return &proto.Empty{}, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -13,7 +13,9 @@ import (
|
|||
"github.com/golang/protobuf/ptypes"
|
||||
"github.com/golang/protobuf/ptypes/timestamp"
|
||||
"github.com/hashicorp/vault/sdk/database/dbplugin/v5/proto"
|
||||
"github.com/hashicorp/vault/sdk/helper/pluginutil"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/metadata"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
|
@ -22,11 +24,12 @@ var invalidExpiration = time.Date(0, 1, 1, 0, 0, 0, 0, time.UTC)
|
|||
|
||||
func TestGRPCServer_Initialize(t *testing.T) {
|
||||
type testCase struct {
|
||||
db Database
|
||||
req *proto.InitializeRequest
|
||||
expectedResp *proto.InitializeResponse
|
||||
expectErr bool
|
||||
expectCode codes.Code
|
||||
db Database
|
||||
req *proto.InitializeRequest
|
||||
expectedResp *proto.InitializeResponse
|
||||
expectErr bool
|
||||
expectCode codes.Code
|
||||
grpcSetupFunc func(*testing.T, Database) (context.Context, gRPCServer)
|
||||
}
|
||||
|
||||
tests := map[string]testCase{
|
||||
|
@ -34,10 +37,11 @@ func TestGRPCServer_Initialize(t *testing.T) {
|
|||
db: fakeDatabase{
|
||||
initErr: errors.New("initialization error"),
|
||||
},
|
||||
req: &proto.InitializeRequest{},
|
||||
expectedResp: &proto.InitializeResponse{},
|
||||
expectErr: true,
|
||||
expectCode: codes.Internal,
|
||||
req: &proto.InitializeRequest{},
|
||||
expectedResp: &proto.InitializeResponse{},
|
||||
expectErr: true,
|
||||
expectCode: codes.Internal,
|
||||
grpcSetupFunc: testGrpcServer,
|
||||
},
|
||||
"newConfig can't marshal to JSON": {
|
||||
db: fakeDatabase{
|
||||
|
@ -47,12 +51,13 @@ func TestGRPCServer_Initialize(t *testing.T) {
|
|||
},
|
||||
},
|
||||
},
|
||||
req: &proto.InitializeRequest{},
|
||||
expectedResp: &proto.InitializeResponse{},
|
||||
expectErr: true,
|
||||
expectCode: codes.Internal,
|
||||
req: &proto.InitializeRequest{},
|
||||
expectedResp: &proto.InitializeResponse{},
|
||||
expectErr: true,
|
||||
expectCode: codes.Internal,
|
||||
grpcSetupFunc: testGrpcServer,
|
||||
},
|
||||
"happy path with config data": {
|
||||
"happy path with config data for multiplexed plugin": {
|
||||
db: fakeDatabase{
|
||||
initResp: InitializeResponse{
|
||||
Config: map[string]interface{}{
|
||||
|
@ -70,21 +75,39 @@ func TestGRPCServer_Initialize(t *testing.T) {
|
|||
"foo": "bar",
|
||||
}),
|
||||
},
|
||||
expectErr: false,
|
||||
expectCode: codes.OK,
|
||||
expectErr: false,
|
||||
expectCode: codes.OK,
|
||||
grpcSetupFunc: testGrpcServer,
|
||||
},
|
||||
"happy path with config data for non-multiplexed plugin": {
|
||||
db: fakeDatabase{
|
||||
initResp: InitializeResponse{
|
||||
Config: map[string]interface{}{
|
||||
"foo": "bar",
|
||||
},
|
||||
},
|
||||
},
|
||||
req: &proto.InitializeRequest{
|
||||
ConfigData: marshal(t, map[string]interface{}{
|
||||
"foo": "bar",
|
||||
}),
|
||||
},
|
||||
expectedResp: &proto.InitializeResponse{
|
||||
ConfigData: marshal(t, map[string]interface{}{
|
||||
"foo": "bar",
|
||||
}),
|
||||
},
|
||||
expectErr: false,
|
||||
expectCode: codes.OK,
|
||||
grpcSetupFunc: testGrpcServerSingleImpl,
|
||||
},
|
||||
}
|
||||
|
||||
for name, test := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
g := gRPCServer{
|
||||
impl: test.db,
|
||||
}
|
||||
idCtx, g := test.grpcSetupFunc(t, test.db)
|
||||
resp, err := g.Initialize(idCtx, test.req)
|
||||
|
||||
// Context doesn't need to timeout since this is just passed through
|
||||
ctx := context.Background()
|
||||
|
||||
resp, err := g.Initialize(ctx, test.req)
|
||||
if test.expectErr && err == nil {
|
||||
t.Fatalf("err expected, got nil")
|
||||
}
|
||||
|
@ -252,14 +275,9 @@ func TestGRPCServer_NewUser(t *testing.T) {
|
|||
|
||||
for name, test := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
g := gRPCServer{
|
||||
impl: test.db,
|
||||
}
|
||||
idCtx, g := testGrpcServer(t, test.db)
|
||||
resp, err := g.NewUser(idCtx, test.req)
|
||||
|
||||
// Context doesn't need to timeout since this is just passed through
|
||||
ctx := context.Background()
|
||||
|
||||
resp, err := g.NewUser(ctx, test.req)
|
||||
if test.expectErr && err == nil {
|
||||
t.Fatalf("err expected, got nil")
|
||||
}
|
||||
|
@ -362,14 +380,9 @@ func TestGRPCServer_UpdateUser(t *testing.T) {
|
|||
|
||||
for name, test := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
g := gRPCServer{
|
||||
impl: test.db,
|
||||
}
|
||||
idCtx, g := testGrpcServer(t, test.db)
|
||||
resp, err := g.UpdateUser(idCtx, test.req)
|
||||
|
||||
// Context doesn't need to timeout since this is just passed through
|
||||
ctx := context.Background()
|
||||
|
||||
resp, err := g.UpdateUser(ctx, test.req)
|
||||
if test.expectErr && err == nil {
|
||||
t.Fatalf("err expected, got nil")
|
||||
}
|
||||
|
@ -430,14 +443,9 @@ func TestGRPCServer_DeleteUser(t *testing.T) {
|
|||
|
||||
for name, test := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
g := gRPCServer{
|
||||
impl: test.db,
|
||||
}
|
||||
idCtx, g := testGrpcServer(t, test.db)
|
||||
resp, err := g.DeleteUser(idCtx, test.req)
|
||||
|
||||
// Context doesn't need to timeout since this is just passed through
|
||||
ctx := context.Background()
|
||||
|
||||
resp, err := g.DeleteUser(ctx, test.req)
|
||||
if test.expectErr && err == nil {
|
||||
t.Fatalf("err expected, got nil")
|
||||
}
|
||||
|
@ -488,14 +496,9 @@ func TestGRPCServer_Type(t *testing.T) {
|
|||
|
||||
for name, test := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
g := gRPCServer{
|
||||
impl: test.db,
|
||||
}
|
||||
idCtx, g := testGrpcServer(t, test.db)
|
||||
resp, err := g.Type(idCtx, &proto.Empty{})
|
||||
|
||||
// Context doesn't need to timeout since this is just passed through
|
||||
ctx := context.Background()
|
||||
|
||||
resp, err := g.Type(ctx, &proto.Empty{})
|
||||
if test.expectErr && err == nil {
|
||||
t.Fatalf("err expected, got nil")
|
||||
}
|
||||
|
@ -517,9 +520,11 @@ func TestGRPCServer_Type(t *testing.T) {
|
|||
|
||||
func TestGRPCServer_Close(t *testing.T) {
|
||||
type testCase struct {
|
||||
db Database
|
||||
expectErr bool
|
||||
expectCode codes.Code
|
||||
db Database
|
||||
expectErr bool
|
||||
expectCode codes.Code
|
||||
grpcSetupFunc func(*testing.T, Database) (context.Context, gRPCServer)
|
||||
assertFunc func(t *testing.T, g gRPCServer)
|
||||
}
|
||||
|
||||
tests := map[string]testCase{
|
||||
|
@ -527,26 +532,36 @@ func TestGRPCServer_Close(t *testing.T) {
|
|||
db: fakeDatabase{
|
||||
closeErr: errors.New("close error"),
|
||||
},
|
||||
expectErr: true,
|
||||
expectCode: codes.Internal,
|
||||
expectErr: true,
|
||||
expectCode: codes.Internal,
|
||||
grpcSetupFunc: testGrpcServer,
|
||||
assertFunc: nil,
|
||||
},
|
||||
"happy path": {
|
||||
db: fakeDatabase{},
|
||||
expectErr: false,
|
||||
expectCode: codes.OK,
|
||||
"happy path for multiplexed plugin": {
|
||||
db: fakeDatabase{},
|
||||
expectErr: false,
|
||||
expectCode: codes.OK,
|
||||
grpcSetupFunc: testGrpcServer,
|
||||
assertFunc: func(t *testing.T, g gRPCServer) {
|
||||
if len(g.instances) != 0 {
|
||||
t.Fatalf("err expected instances map to be empty")
|
||||
}
|
||||
},
|
||||
},
|
||||
"happy path for non-multiplexed plugin": {
|
||||
db: fakeDatabase{},
|
||||
expectErr: false,
|
||||
expectCode: codes.OK,
|
||||
grpcSetupFunc: testGrpcServerSingleImpl,
|
||||
assertFunc: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for name, test := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
g := gRPCServer{
|
||||
impl: test.db,
|
||||
}
|
||||
idCtx, g := test.grpcSetupFunc(t, test.db)
|
||||
_, err := g.Close(idCtx, &proto.Empty{})
|
||||
|
||||
// Context doesn't need to timeout since this is just passed through
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := g.Close(ctx, &proto.Empty{})
|
||||
if test.expectErr && err == nil {
|
||||
t.Fatalf("err expected, got nil")
|
||||
}
|
||||
|
@ -558,10 +573,105 @@ func TestGRPCServer_Close(t *testing.T) {
|
|||
if actualCode != test.expectCode {
|
||||
t.Fatalf("Actual code: %s Expected code: %s", actualCode, test.expectCode)
|
||||
}
|
||||
|
||||
if test.assertFunc != nil {
|
||||
test.assertFunc(t, g)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetMultiplexIDFromContext(t *testing.T) {
|
||||
type testCase struct {
|
||||
ctx context.Context
|
||||
expectedResp string
|
||||
expectedErr error
|
||||
}
|
||||
|
||||
tests := map[string]testCase{
|
||||
"missing plugin multiplexing metadata": {
|
||||
ctx: context.Background(),
|
||||
expectedResp: "",
|
||||
expectedErr: fmt.Errorf("missing plugin multiplexing metadata"),
|
||||
},
|
||||
"unexpected number of IDs in metadata": {
|
||||
ctx: idCtx(t, "12345", "67891"),
|
||||
expectedResp: "",
|
||||
expectedErr: fmt.Errorf("unexpected number of IDs in metadata: (2)"),
|
||||
},
|
||||
"empty multiplex ID in metadata": {
|
||||
ctx: idCtx(t, ""),
|
||||
expectedResp: "",
|
||||
expectedErr: fmt.Errorf("empty multiplex ID in metadata"),
|
||||
},
|
||||
"happy path, id is returned from metadata": {
|
||||
ctx: idCtx(t, "12345"),
|
||||
expectedResp: "12345",
|
||||
expectedErr: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for name, test := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
resp, err := getMultiplexIDFromContext(test.ctx)
|
||||
|
||||
if test.expectedErr != nil && test.expectedErr.Error() != "" && err == nil {
|
||||
t.Fatalf("err expected, got nil")
|
||||
} else if !reflect.DeepEqual(err, test.expectedErr) {
|
||||
t.Fatalf("Actual error: %#v\nExpected error: %#v", err, test.expectedErr)
|
||||
}
|
||||
|
||||
if test.expectedErr != nil && test.expectedErr.Error() == "" && err != nil {
|
||||
t.Fatalf("no error expected, got: %s", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(resp, test.expectedResp) {
|
||||
t.Fatalf("Actual response: %#v\nExpected response: %#v", resp, test.expectedResp)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// testGrpcServer is a test helper that returns a context with an ID set in its
|
||||
// metadata and a gRPCServer instance for a multiplexed plugin
|
||||
func testGrpcServer(t *testing.T, db Database) (context.Context, gRPCServer) {
|
||||
t.Helper()
|
||||
g := gRPCServer{
|
||||
factoryFunc: func() (interface{}, error) {
|
||||
return db, nil
|
||||
},
|
||||
instances: make(map[string]Database),
|
||||
}
|
||||
|
||||
id := "12345"
|
||||
idCtx := idCtx(t, id)
|
||||
g.instances[id] = db
|
||||
|
||||
return idCtx, g
|
||||
}
|
||||
|
||||
// testGrpcServerSingleImpl is a test helper that returns a context and a
|
||||
// gRPCServer instance for a non-multiplexed plugin
|
||||
func testGrpcServerSingleImpl(t *testing.T, db Database) (context.Context, gRPCServer) {
|
||||
t.Helper()
|
||||
return context.Background(), gRPCServer{
|
||||
singleImpl: db,
|
||||
}
|
||||
}
|
||||
|
||||
// idCtx is a test helper that will return a context with the IDs set in its
|
||||
// metadata
|
||||
func idCtx(t *testing.T, ids ...string) context.Context {
|
||||
t.Helper()
|
||||
// Context doesn't need to timeout since this is just passed through
|
||||
ctx := context.Background()
|
||||
md := metadata.MD{}
|
||||
for _, id := range ids {
|
||||
md.Append(pluginutil.MultiplexingCtxKey, id)
|
||||
}
|
||||
return metadata.NewIncomingContext(ctx, md)
|
||||
}
|
||||
|
||||
func marshal(t *testing.T, m map[string]interface{}) *structpb.Struct {
|
||||
t.Helper()
|
||||
|
||||
|
|
|
@ -3,19 +3,14 @@ package dbplugin
|
|||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
|
||||
log "github.com/hashicorp/go-hclog"
|
||||
plugin "github.com/hashicorp/go-plugin"
|
||||
"github.com/hashicorp/vault/sdk/database/dbplugin/v5/proto"
|
||||
"github.com/hashicorp/vault/sdk/helper/pluginutil"
|
||||
)
|
||||
|
||||
// DatabasePluginClient embeds a databasePluginRPCClient and wraps it's Close
|
||||
// method to also call Kill() on the plugin.Client.
|
||||
type DatabasePluginClient struct {
|
||||
client *plugin.Client
|
||||
sync.Mutex
|
||||
|
||||
client pluginutil.PluginClient
|
||||
Database
|
||||
}
|
||||
|
||||
|
@ -23,42 +18,31 @@ type DatabasePluginClient struct {
|
|||
// and kill the plugin.
|
||||
func (dc *DatabasePluginClient) Close() error {
|
||||
err := dc.Database.Close()
|
||||
dc.client.Kill()
|
||||
dc.client.Close()
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// pluginSets is the map of plugins we can dispense.
|
||||
var PluginSets = map[int]plugin.PluginSet{
|
||||
5: {
|
||||
"database": &GRPCDatabasePlugin{},
|
||||
},
|
||||
6: {
|
||||
"database": &GRPCDatabasePlugin{},
|
||||
},
|
||||
}
|
||||
|
||||
// NewPluginClient returns a databaseRPCClient with a connection to a running
|
||||
// plugin. The client is wrapped in a DatabasePluginClient object to ensure the
|
||||
// plugin is killed on call of Close().
|
||||
func NewPluginClient(ctx context.Context, sys pluginutil.RunnerUtil, pluginRunner *pluginutil.PluginRunner, logger log.Logger, isMetadataMode bool) (Database, error) {
|
||||
// pluginSets is the map of plugins we can dispense.
|
||||
pluginSets := map[int]plugin.PluginSet{
|
||||
5: {
|
||||
"database": new(GRPCDatabasePlugin),
|
||||
},
|
||||
}
|
||||
|
||||
client, err := pluginRunner.RunConfig(ctx,
|
||||
pluginutil.Runner(sys),
|
||||
pluginutil.PluginSets(pluginSets),
|
||||
pluginutil.HandshakeConfig(handshakeConfig),
|
||||
pluginutil.Logger(logger),
|
||||
pluginutil.MetadataMode(isMetadataMode),
|
||||
pluginutil.AutoMTLS(true),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Connect via RPC
|
||||
rpcClient, err := client.Client()
|
||||
// plugin.
|
||||
func NewPluginClient(ctx context.Context, sys pluginutil.RunnerUtil, config pluginutil.PluginClientConfig) (Database, error) {
|
||||
pluginClient, err := sys.NewPluginClient(ctx, config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Request the plugin
|
||||
raw, err := rpcClient.Dispense("database")
|
||||
raw, err := pluginClient.Dispense("database")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -66,16 +50,19 @@ func NewPluginClient(ctx context.Context, sys pluginutil.RunnerUtil, pluginRunne
|
|||
// We should have a database type now. This feels like a normal interface
|
||||
// implementation but is in fact over an RPC connection.
|
||||
var db Database
|
||||
switch raw.(type) {
|
||||
switch c := raw.(type) {
|
||||
case gRPCClient:
|
||||
db = raw.(gRPCClient)
|
||||
// This is an abstraction leak from go-plugin but it is necessary in
|
||||
// order to enable multiplexing on multiplexed plugins
|
||||
c.client = proto.NewDatabaseClient(pluginClient.Conn())
|
||||
|
||||
db = c
|
||||
default:
|
||||
return nil, errors.New("unsupported client type")
|
||||
}
|
||||
|
||||
// Wrap RPC implementation in DatabasePluginClient
|
||||
return &DatabasePluginClient{
|
||||
client: client,
|
||||
client: pluginClient,
|
||||
Database: db,
|
||||
}, nil
|
||||
}
|
||||
|
|
|
@ -40,8 +40,17 @@ func PluginFactory(ctx context.Context, pluginName string, sys pluginutil.LookRu
|
|||
transport = "builtin"
|
||||
|
||||
} else {
|
||||
config := pluginutil.PluginClientConfig{
|
||||
Name: pluginName,
|
||||
PluginType: consts.PluginTypeDatabase,
|
||||
PluginSets: PluginSets,
|
||||
HandshakeConfig: HandshakeConfig,
|
||||
Logger: namedLogger,
|
||||
IsMetadataMode: false,
|
||||
AutoMTLS: true,
|
||||
}
|
||||
// create a DatabasePluginClient instance
|
||||
db, err = NewPluginClient(ctx, sys, pluginRunner, namedLogger, false)
|
||||
db, err = NewPluginClient(ctx, sys, config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -59,6 +68,7 @@ func PluginFactory(ctx context.Context, pluginName string, sys pluginutil.LookRu
|
|||
if err != nil {
|
||||
return nil, errwrap.Wrapf("error getting plugin type: {{err}}", err)
|
||||
}
|
||||
logger.Debug("got database plugin instance", "type", typeStr)
|
||||
|
||||
// Wrap with metrics middleware
|
||||
db = &databaseMetricsMiddleware{
|
||||
|
|
|
@ -31,7 +31,49 @@ func ServeConfig(db Database) *plugin.ServeConfig {
|
|||
}
|
||||
|
||||
conf := &plugin.ServeConfig{
|
||||
HandshakeConfig: handshakeConfig,
|
||||
HandshakeConfig: HandshakeConfig,
|
||||
VersionedPlugins: pluginSets,
|
||||
GRPCServer: plugin.DefaultGRPCServer,
|
||||
}
|
||||
|
||||
return conf
|
||||
}
|
||||
|
||||
func ServeMultiplex(factory Factory) {
|
||||
plugin.Serve(ServeConfigMultiplex(factory))
|
||||
}
|
||||
|
||||
func ServeConfigMultiplex(factory Factory) *plugin.ServeConfig {
|
||||
err := pluginutil.OptionallyEnableMlock()
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
return nil
|
||||
}
|
||||
|
||||
db, err := factory()
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
return nil
|
||||
}
|
||||
|
||||
database := db.(Database)
|
||||
|
||||
// pluginSets is the map of plugins we can dispense.
|
||||
pluginSets := map[int]plugin.PluginSet{
|
||||
5: {
|
||||
"database": &GRPCDatabasePlugin{
|
||||
Impl: database,
|
||||
},
|
||||
},
|
||||
6: {
|
||||
"database": &GRPCDatabasePlugin{
|
||||
FactoryFunc: factory,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
conf := &plugin.ServeConfig{
|
||||
HandshakeConfig: HandshakeConfig,
|
||||
VersionedPlugins: pluginSets,
|
||||
GRPCServer: plugin.DefaultGRPCServer,
|
||||
}
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||
// versions:
|
||||
// protoc-gen-go v1.27.1
|
||||
// protoc v3.17.3
|
||||
// protoc v3.19.4
|
||||
// source: sdk/database/dbplugin/v5/proto/database.proto
|
||||
|
||||
package proto
|
||||
|
|
|
@ -0,0 +1,47 @@
|
|||
package pluginutil
|
||||
|
||||
import (
|
||||
context "context"
|
||||
"fmt"
|
||||
|
||||
grpc "google.golang.org/grpc"
|
||||
codes "google.golang.org/grpc/codes"
|
||||
status "google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
type PluginMultiplexingServerImpl struct {
|
||||
UnimplementedPluginMultiplexingServer
|
||||
|
||||
Supported bool
|
||||
}
|
||||
|
||||
func (pm PluginMultiplexingServerImpl) MultiplexingSupport(ctx context.Context, req *MultiplexingSupportRequest) (*MultiplexingSupportResponse, error) {
|
||||
return &MultiplexingSupportResponse{
|
||||
Supported: pm.Supported,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func MultiplexingSupported(ctx context.Context, cc grpc.ClientConnInterface) (bool, error) {
|
||||
if cc == nil {
|
||||
return false, fmt.Errorf("client connection is nil")
|
||||
}
|
||||
|
||||
req := new(MultiplexingSupportRequest)
|
||||
resp, err := NewPluginMultiplexingClient(cc).MultiplexingSupport(ctx, req)
|
||||
if err != nil {
|
||||
|
||||
// If the server does not implement the multiplexing server then we can
|
||||
// assume it is not multiplexed
|
||||
if status.Code(err) == codes.Unimplemented {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return false, err
|
||||
}
|
||||
if resp == nil {
|
||||
// Somehow got a nil response, assume not multiplexed
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return resp.Supported, nil
|
||||
}
|
|
@ -0,0 +1,213 @@
|
|||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||
// versions:
|
||||
// protoc-gen-go v1.27.1
|
||||
// protoc v3.19.4
|
||||
// source: sdk/helper/pluginutil/multiplexing.proto
|
||||
|
||||
package pluginutil
|
||||
|
||||
import (
|
||||
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
|
||||
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
|
||||
reflect "reflect"
|
||||
sync "sync"
|
||||
)
|
||||
|
||||
const (
|
||||
// Verify that this generated code is sufficiently up-to-date.
|
||||
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
|
||||
// Verify that runtime/protoimpl is sufficiently up-to-date.
|
||||
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
|
||||
)
|
||||
|
||||
type MultiplexingSupportRequest struct {
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
}
|
||||
|
||||
func (x *MultiplexingSupportRequest) Reset() {
|
||||
*x = MultiplexingSupportRequest{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_sdk_helper_pluginutil_multiplexing_proto_msgTypes[0]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
}
|
||||
|
||||
func (x *MultiplexingSupportRequest) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*MultiplexingSupportRequest) ProtoMessage() {}
|
||||
|
||||
func (x *MultiplexingSupportRequest) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_sdk_helper_pluginutil_multiplexing_proto_msgTypes[0]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use MultiplexingSupportRequest.ProtoReflect.Descriptor instead.
|
||||
func (*MultiplexingSupportRequest) Descriptor() ([]byte, []int) {
|
||||
return file_sdk_helper_pluginutil_multiplexing_proto_rawDescGZIP(), []int{0}
|
||||
}
|
||||
|
||||
type MultiplexingSupportResponse struct {
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
|
||||
Supported bool `protobuf:"varint,1,opt,name=supported,proto3" json:"supported,omitempty"`
|
||||
}
|
||||
|
||||
func (x *MultiplexingSupportResponse) Reset() {
|
||||
*x = MultiplexingSupportResponse{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_sdk_helper_pluginutil_multiplexing_proto_msgTypes[1]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
}
|
||||
|
||||
func (x *MultiplexingSupportResponse) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*MultiplexingSupportResponse) ProtoMessage() {}
|
||||
|
||||
func (x *MultiplexingSupportResponse) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_sdk_helper_pluginutil_multiplexing_proto_msgTypes[1]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use MultiplexingSupportResponse.ProtoReflect.Descriptor instead.
|
||||
func (*MultiplexingSupportResponse) Descriptor() ([]byte, []int) {
|
||||
return file_sdk_helper_pluginutil_multiplexing_proto_rawDescGZIP(), []int{1}
|
||||
}
|
||||
|
||||
func (x *MultiplexingSupportResponse) GetSupported() bool {
|
||||
if x != nil {
|
||||
return x.Supported
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
var File_sdk_helper_pluginutil_multiplexing_proto protoreflect.FileDescriptor
|
||||
|
||||
var file_sdk_helper_pluginutil_multiplexing_proto_rawDesc = []byte{
|
||||
0x0a, 0x28, 0x73, 0x64, 0x6b, 0x2f, 0x68, 0x65, 0x6c, 0x70, 0x65, 0x72, 0x2f, 0x70, 0x6c, 0x75,
|
||||
0x67, 0x69, 0x6e, 0x75, 0x74, 0x69, 0x6c, 0x2f, 0x6d, 0x75, 0x6c, 0x74, 0x69, 0x70, 0x6c, 0x65,
|
||||
0x78, 0x69, 0x6e, 0x67, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x17, 0x70, 0x6c, 0x75, 0x67,
|
||||
0x69, 0x6e, 0x75, 0x74, 0x69, 0x6c, 0x2e, 0x6d, 0x75, 0x6c, 0x74, 0x69, 0x70, 0x6c, 0x65, 0x78,
|
||||
0x69, 0x6e, 0x67, 0x22, 0x1c, 0x0a, 0x1a, 0x4d, 0x75, 0x6c, 0x74, 0x69, 0x70, 0x6c, 0x65, 0x78,
|
||||
0x69, 0x6e, 0x67, 0x53, 0x75, 0x70, 0x70, 0x6f, 0x72, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73,
|
||||
0x74, 0x22, 0x3b, 0x0a, 0x1b, 0x4d, 0x75, 0x6c, 0x74, 0x69, 0x70, 0x6c, 0x65, 0x78, 0x69, 0x6e,
|
||||
0x67, 0x53, 0x75, 0x70, 0x70, 0x6f, 0x72, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65,
|
||||
0x12, 0x1c, 0x0a, 0x09, 0x73, 0x75, 0x70, 0x70, 0x6f, 0x72, 0x74, 0x65, 0x64, 0x18, 0x01, 0x20,
|
||||
0x01, 0x28, 0x08, 0x52, 0x09, 0x73, 0x75, 0x70, 0x70, 0x6f, 0x72, 0x74, 0x65, 0x64, 0x32, 0x97,
|
||||
0x01, 0x0a, 0x12, 0x50, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x4d, 0x75, 0x6c, 0x74, 0x69, 0x70, 0x6c,
|
||||
0x65, 0x78, 0x69, 0x6e, 0x67, 0x12, 0x80, 0x01, 0x0a, 0x13, 0x4d, 0x75, 0x6c, 0x74, 0x69, 0x70,
|
||||
0x6c, 0x65, 0x78, 0x69, 0x6e, 0x67, 0x53, 0x75, 0x70, 0x70, 0x6f, 0x72, 0x74, 0x12, 0x33, 0x2e,
|
||||
0x70, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x75, 0x74, 0x69, 0x6c, 0x2e, 0x6d, 0x75, 0x6c, 0x74, 0x69,
|
||||
0x70, 0x6c, 0x65, 0x78, 0x69, 0x6e, 0x67, 0x2e, 0x4d, 0x75, 0x6c, 0x74, 0x69, 0x70, 0x6c, 0x65,
|
||||
0x78, 0x69, 0x6e, 0x67, 0x53, 0x75, 0x70, 0x70, 0x6f, 0x72, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65,
|
||||
0x73, 0x74, 0x1a, 0x34, 0x2e, 0x70, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x75, 0x74, 0x69, 0x6c, 0x2e,
|
||||
0x6d, 0x75, 0x6c, 0x74, 0x69, 0x70, 0x6c, 0x65, 0x78, 0x69, 0x6e, 0x67, 0x2e, 0x4d, 0x75, 0x6c,
|
||||
0x74, 0x69, 0x70, 0x6c, 0x65, 0x78, 0x69, 0x6e, 0x67, 0x53, 0x75, 0x70, 0x70, 0x6f, 0x72, 0x74,
|
||||
0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x42, 0x32, 0x5a, 0x30, 0x67, 0x69, 0x74, 0x68,
|
||||
0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x68, 0x61, 0x73, 0x68, 0x69, 0x63, 0x6f, 0x72, 0x70,
|
||||
0x2f, 0x76, 0x61, 0x75, 0x6c, 0x74, 0x2f, 0x73, 0x64, 0x6b, 0x2f, 0x68, 0x65, 0x6c, 0x70, 0x65,
|
||||
0x72, 0x2f, 0x70, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x75, 0x74, 0x69, 0x6c, 0x62, 0x06, 0x70, 0x72,
|
||||
0x6f, 0x74, 0x6f, 0x33,
|
||||
}
|
||||
|
||||
var (
|
||||
file_sdk_helper_pluginutil_multiplexing_proto_rawDescOnce sync.Once
|
||||
file_sdk_helper_pluginutil_multiplexing_proto_rawDescData = file_sdk_helper_pluginutil_multiplexing_proto_rawDesc
|
||||
)
|
||||
|
||||
func file_sdk_helper_pluginutil_multiplexing_proto_rawDescGZIP() []byte {
|
||||
file_sdk_helper_pluginutil_multiplexing_proto_rawDescOnce.Do(func() {
|
||||
file_sdk_helper_pluginutil_multiplexing_proto_rawDescData = protoimpl.X.CompressGZIP(file_sdk_helper_pluginutil_multiplexing_proto_rawDescData)
|
||||
})
|
||||
return file_sdk_helper_pluginutil_multiplexing_proto_rawDescData
|
||||
}
|
||||
|
||||
var file_sdk_helper_pluginutil_multiplexing_proto_msgTypes = make([]protoimpl.MessageInfo, 2)
|
||||
var file_sdk_helper_pluginutil_multiplexing_proto_goTypes = []interface{}{
|
||||
(*MultiplexingSupportRequest)(nil), // 0: pluginutil.multiplexing.MultiplexingSupportRequest
|
||||
(*MultiplexingSupportResponse)(nil), // 1: pluginutil.multiplexing.MultiplexingSupportResponse
|
||||
}
|
||||
var file_sdk_helper_pluginutil_multiplexing_proto_depIdxs = []int32{
|
||||
0, // 0: pluginutil.multiplexing.PluginMultiplexing.MultiplexingSupport:input_type -> pluginutil.multiplexing.MultiplexingSupportRequest
|
||||
1, // 1: pluginutil.multiplexing.PluginMultiplexing.MultiplexingSupport:output_type -> pluginutil.multiplexing.MultiplexingSupportResponse
|
||||
1, // [1:2] is the sub-list for method output_type
|
||||
0, // [0:1] is the sub-list for method input_type
|
||||
0, // [0:0] is the sub-list for extension type_name
|
||||
0, // [0:0] is the sub-list for extension extendee
|
||||
0, // [0:0] is the sub-list for field type_name
|
||||
}
|
||||
|
||||
func init() { file_sdk_helper_pluginutil_multiplexing_proto_init() }
|
||||
func file_sdk_helper_pluginutil_multiplexing_proto_init() {
|
||||
if File_sdk_helper_pluginutil_multiplexing_proto != nil {
|
||||
return
|
||||
}
|
||||
if !protoimpl.UnsafeEnabled {
|
||||
file_sdk_helper_pluginutil_multiplexing_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
|
||||
switch v := v.(*MultiplexingSupportRequest); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
case 1:
|
||||
return &v.sizeCache
|
||||
case 2:
|
||||
return &v.unknownFields
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
file_sdk_helper_pluginutil_multiplexing_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
|
||||
switch v := v.(*MultiplexingSupportResponse); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
case 1:
|
||||
return &v.sizeCache
|
||||
case 2:
|
||||
return &v.unknownFields
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
type x struct{}
|
||||
out := protoimpl.TypeBuilder{
|
||||
File: protoimpl.DescBuilder{
|
||||
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
|
||||
RawDescriptor: file_sdk_helper_pluginutil_multiplexing_proto_rawDesc,
|
||||
NumEnums: 0,
|
||||
NumMessages: 2,
|
||||
NumExtensions: 0,
|
||||
NumServices: 1,
|
||||
},
|
||||
GoTypes: file_sdk_helper_pluginutil_multiplexing_proto_goTypes,
|
||||
DependencyIndexes: file_sdk_helper_pluginutil_multiplexing_proto_depIdxs,
|
||||
MessageInfos: file_sdk_helper_pluginutil_multiplexing_proto_msgTypes,
|
||||
}.Build()
|
||||
File_sdk_helper_pluginutil_multiplexing_proto = out.File
|
||||
file_sdk_helper_pluginutil_multiplexing_proto_rawDesc = nil
|
||||
file_sdk_helper_pluginutil_multiplexing_proto_goTypes = nil
|
||||
file_sdk_helper_pluginutil_multiplexing_proto_depIdxs = nil
|
||||
}
|
|
@ -0,0 +1,13 @@
|
|||
syntax = "proto3";
|
||||
package pluginutil.multiplexing;
|
||||
|
||||
option go_package = "github.com/hashicorp/vault/sdk/helper/pluginutil";
|
||||
|
||||
message MultiplexingSupportRequest {}
|
||||
message MultiplexingSupportResponse {
|
||||
bool supported = 1;
|
||||
}
|
||||
|
||||
service PluginMultiplexing {
|
||||
rpc MultiplexingSupport(MultiplexingSupportRequest) returns (MultiplexingSupportResponse);
|
||||
}
|
|
@ -0,0 +1,101 @@
|
|||
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
|
||||
|
||||
package pluginutil
|
||||
|
||||
import (
|
||||
context "context"
|
||||
grpc "google.golang.org/grpc"
|
||||
codes "google.golang.org/grpc/codes"
|
||||
status "google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
// This is a compile-time assertion to ensure that this generated file
|
||||
// is compatible with the grpc package it is being compiled against.
|
||||
// Requires gRPC-Go v1.32.0 or later.
|
||||
const _ = grpc.SupportPackageIsVersion7
|
||||
|
||||
// PluginMultiplexingClient is the client API for PluginMultiplexing service.
|
||||
//
|
||||
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
|
||||
type PluginMultiplexingClient interface {
|
||||
MultiplexingSupport(ctx context.Context, in *MultiplexingSupportRequest, opts ...grpc.CallOption) (*MultiplexingSupportResponse, error)
|
||||
}
|
||||
|
||||
type pluginMultiplexingClient struct {
|
||||
cc grpc.ClientConnInterface
|
||||
}
|
||||
|
||||
func NewPluginMultiplexingClient(cc grpc.ClientConnInterface) PluginMultiplexingClient {
|
||||
return &pluginMultiplexingClient{cc}
|
||||
}
|
||||
|
||||
func (c *pluginMultiplexingClient) MultiplexingSupport(ctx context.Context, in *MultiplexingSupportRequest, opts ...grpc.CallOption) (*MultiplexingSupportResponse, error) {
|
||||
out := new(MultiplexingSupportResponse)
|
||||
err := c.cc.Invoke(ctx, "/pluginutil.multiplexing.PluginMultiplexing/MultiplexingSupport", in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// PluginMultiplexingServer is the server API for PluginMultiplexing service.
|
||||
// All implementations must embed UnimplementedPluginMultiplexingServer
|
||||
// for forward compatibility
|
||||
type PluginMultiplexingServer interface {
|
||||
MultiplexingSupport(context.Context, *MultiplexingSupportRequest) (*MultiplexingSupportResponse, error)
|
||||
mustEmbedUnimplementedPluginMultiplexingServer()
|
||||
}
|
||||
|
||||
// UnimplementedPluginMultiplexingServer must be embedded to have forward compatible implementations.
|
||||
type UnimplementedPluginMultiplexingServer struct {
|
||||
}
|
||||
|
||||
func (UnimplementedPluginMultiplexingServer) MultiplexingSupport(context.Context, *MultiplexingSupportRequest) (*MultiplexingSupportResponse, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method MultiplexingSupport not implemented")
|
||||
}
|
||||
func (UnimplementedPluginMultiplexingServer) mustEmbedUnimplementedPluginMultiplexingServer() {}
|
||||
|
||||
// UnsafePluginMultiplexingServer may be embedded to opt out of forward compatibility for this service.
|
||||
// Use of this interface is not recommended, as added methods to PluginMultiplexingServer will
|
||||
// result in compilation errors.
|
||||
type UnsafePluginMultiplexingServer interface {
|
||||
mustEmbedUnimplementedPluginMultiplexingServer()
|
||||
}
|
||||
|
||||
func RegisterPluginMultiplexingServer(s grpc.ServiceRegistrar, srv PluginMultiplexingServer) {
|
||||
s.RegisterService(&PluginMultiplexing_ServiceDesc, srv)
|
||||
}
|
||||
|
||||
func _PluginMultiplexing_MultiplexingSupport_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(MultiplexingSupportRequest)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(PluginMultiplexingServer).MultiplexingSupport(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/pluginutil.multiplexing.PluginMultiplexing/MultiplexingSupport",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(PluginMultiplexingServer).MultiplexingSupport(ctx, req.(*MultiplexingSupportRequest))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
// PluginMultiplexing_ServiceDesc is the grpc.ServiceDesc for PluginMultiplexing service.
|
||||
// It's only intended for direct use with grpc.RegisterService,
|
||||
// and not to be introspected or modified (even as a copy)
|
||||
var PluginMultiplexing_ServiceDesc = grpc.ServiceDesc{
|
||||
ServiceName: "pluginutil.multiplexing.PluginMultiplexing",
|
||||
HandlerType: (*PluginMultiplexingServer)(nil),
|
||||
Methods: []grpc.MethodDesc{
|
||||
{
|
||||
MethodName: "MultiplexingSupport",
|
||||
Handler: _PluginMultiplexing_MultiplexingSupport_Handler,
|
||||
},
|
||||
},
|
||||
Streams: []grpc.StreamDesc{},
|
||||
Metadata: "sdk/helper/pluginutil/multiplexing.proto",
|
||||
}
|
|
@ -9,9 +9,21 @@ import (
|
|||
|
||||
log "github.com/hashicorp/go-hclog"
|
||||
"github.com/hashicorp/go-plugin"
|
||||
"github.com/hashicorp/vault/sdk/helper/consts"
|
||||
"github.com/hashicorp/vault/sdk/version"
|
||||
)
|
||||
|
||||
type PluginClientConfig struct {
|
||||
Name string
|
||||
PluginType consts.PluginType
|
||||
PluginSets map[int]plugin.PluginSet
|
||||
HandshakeConfig plugin.HandshakeConfig
|
||||
Logger log.Logger
|
||||
IsMetadataMode bool
|
||||
AutoMTLS bool
|
||||
MLock bool
|
||||
}
|
||||
|
||||
type runConfig struct {
|
||||
// Provided by PluginRunner
|
||||
command string
|
||||
|
@ -21,12 +33,9 @@ type runConfig struct {
|
|||
// Initialized with what's in PluginRunner.Env, but can be added to
|
||||
env []string
|
||||
|
||||
wrapper RunnerUtil
|
||||
pluginSets map[int]plugin.PluginSet
|
||||
hs plugin.HandshakeConfig
|
||||
logger log.Logger
|
||||
isMetadataMode bool
|
||||
autoMTLS bool
|
||||
wrapper RunnerUtil
|
||||
|
||||
PluginClientConfig
|
||||
}
|
||||
|
||||
func (rc runConfig) makeConfig(ctx context.Context) (*plugin.ClientConfig, error) {
|
||||
|
@ -34,19 +43,19 @@ func (rc runConfig) makeConfig(ctx context.Context) (*plugin.ClientConfig, error
|
|||
cmd.Env = append(cmd.Env, rc.env...)
|
||||
|
||||
// Add the mlock setting to the ENV of the plugin
|
||||
if rc.wrapper != nil && rc.wrapper.MlockEnabled() {
|
||||
if rc.MLock || (rc.wrapper != nil && rc.wrapper.MlockEnabled()) {
|
||||
cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", PluginMlockEnabled, "true"))
|
||||
}
|
||||
cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", PluginVaultVersionEnv, version.GetVersion().Version))
|
||||
|
||||
if rc.isMetadataMode {
|
||||
rc.logger = rc.logger.With("metadata", "true")
|
||||
if rc.IsMetadataMode {
|
||||
rc.Logger = rc.Logger.With("metadata", "true")
|
||||
}
|
||||
metadataEnv := fmt.Sprintf("%s=%t", PluginMetadataModeEnv, rc.isMetadataMode)
|
||||
metadataEnv := fmt.Sprintf("%s=%t", PluginMetadataModeEnv, rc.IsMetadataMode)
|
||||
cmd.Env = append(cmd.Env, metadataEnv)
|
||||
|
||||
var clientTLSConfig *tls.Config
|
||||
if !rc.autoMTLS && !rc.isMetadataMode {
|
||||
if !rc.AutoMTLS && !rc.IsMetadataMode {
|
||||
// Get a CA TLS Certificate
|
||||
certBytes, key, err := generateCert()
|
||||
if err != nil {
|
||||
|
@ -76,17 +85,17 @@ func (rc runConfig) makeConfig(ctx context.Context) (*plugin.ClientConfig, error
|
|||
}
|
||||
|
||||
clientConfig := &plugin.ClientConfig{
|
||||
HandshakeConfig: rc.hs,
|
||||
VersionedPlugins: rc.pluginSets,
|
||||
HandshakeConfig: rc.HandshakeConfig,
|
||||
VersionedPlugins: rc.PluginSets,
|
||||
Cmd: cmd,
|
||||
SecureConfig: secureConfig,
|
||||
TLSConfig: clientTLSConfig,
|
||||
Logger: rc.logger,
|
||||
Logger: rc.Logger,
|
||||
AllowedProtocols: []plugin.Protocol{
|
||||
plugin.ProtocolNetRPC,
|
||||
plugin.ProtocolGRPC,
|
||||
},
|
||||
AutoMTLS: rc.autoMTLS,
|
||||
AutoMTLS: rc.AutoMTLS,
|
||||
}
|
||||
return clientConfig, nil
|
||||
}
|
||||
|
@ -117,31 +126,37 @@ func Runner(wrapper RunnerUtil) RunOpt {
|
|||
|
||||
func PluginSets(pluginSets map[int]plugin.PluginSet) RunOpt {
|
||||
return func(rc *runConfig) {
|
||||
rc.pluginSets = pluginSets
|
||||
rc.PluginSets = pluginSets
|
||||
}
|
||||
}
|
||||
|
||||
func HandshakeConfig(hs plugin.HandshakeConfig) RunOpt {
|
||||
return func(rc *runConfig) {
|
||||
rc.hs = hs
|
||||
rc.HandshakeConfig = hs
|
||||
}
|
||||
}
|
||||
|
||||
func Logger(logger log.Logger) RunOpt {
|
||||
return func(rc *runConfig) {
|
||||
rc.logger = logger
|
||||
rc.Logger = logger
|
||||
}
|
||||
}
|
||||
|
||||
func MetadataMode(isMetadataMode bool) RunOpt {
|
||||
return func(rc *runConfig) {
|
||||
rc.isMetadataMode = isMetadataMode
|
||||
rc.IsMetadataMode = isMetadataMode
|
||||
}
|
||||
}
|
||||
|
||||
func AutoMTLS(autoMTLS bool) RunOpt {
|
||||
return func(rc *runConfig) {
|
||||
rc.autoMTLS = autoMTLS
|
||||
rc.AutoMTLS = autoMTLS
|
||||
}
|
||||
}
|
||||
|
||||
func MLock(mlock bool) RunOpt {
|
||||
return func(rc *runConfig) {
|
||||
rc.MLock = mlock
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -38,19 +38,21 @@ func TestMakeConfig(t *testing.T) {
|
|||
args: []string{"foo", "bar"},
|
||||
sha256: []byte("some_sha256"),
|
||||
env: []string{"initial=true"},
|
||||
pluginSets: map[int]plugin.PluginSet{
|
||||
1: {
|
||||
"bogus": nil,
|
||||
PluginClientConfig: PluginClientConfig{
|
||||
PluginSets: map[int]plugin.PluginSet{
|
||||
1: {
|
||||
"bogus": nil,
|
||||
},
|
||||
},
|
||||
HandshakeConfig: plugin.HandshakeConfig{
|
||||
ProtocolVersion: 1,
|
||||
MagicCookieKey: "magic_cookie_key",
|
||||
MagicCookieValue: "magic_cookie_value",
|
||||
},
|
||||
Logger: hclog.NewNullLogger(),
|
||||
IsMetadataMode: true,
|
||||
AutoMTLS: false,
|
||||
},
|
||||
hs: plugin.HandshakeConfig{
|
||||
ProtocolVersion: 1,
|
||||
MagicCookieKey: "magic_cookie_key",
|
||||
MagicCookieValue: "magic_cookie_value",
|
||||
},
|
||||
logger: hclog.NewNullLogger(),
|
||||
isMetadataMode: true,
|
||||
autoMTLS: false,
|
||||
},
|
||||
|
||||
responseWrapInfoTimes: 0,
|
||||
|
@ -97,19 +99,21 @@ func TestMakeConfig(t *testing.T) {
|
|||
args: []string{"foo", "bar"},
|
||||
sha256: []byte("some_sha256"),
|
||||
env: []string{"initial=true"},
|
||||
pluginSets: map[int]plugin.PluginSet{
|
||||
1: {
|
||||
"bogus": nil,
|
||||
PluginClientConfig: PluginClientConfig{
|
||||
PluginSets: map[int]plugin.PluginSet{
|
||||
1: {
|
||||
"bogus": nil,
|
||||
},
|
||||
},
|
||||
HandshakeConfig: plugin.HandshakeConfig{
|
||||
ProtocolVersion: 1,
|
||||
MagicCookieKey: "magic_cookie_key",
|
||||
MagicCookieValue: "magic_cookie_value",
|
||||
},
|
||||
Logger: hclog.NewNullLogger(),
|
||||
IsMetadataMode: false,
|
||||
AutoMTLS: false,
|
||||
},
|
||||
hs: plugin.HandshakeConfig{
|
||||
ProtocolVersion: 1,
|
||||
MagicCookieKey: "magic_cookie_key",
|
||||
MagicCookieValue: "magic_cookie_value",
|
||||
},
|
||||
logger: hclog.NewNullLogger(),
|
||||
isMetadataMode: false,
|
||||
autoMTLS: false,
|
||||
},
|
||||
|
||||
responseWrapInfo: &wrapping.ResponseWrapInfo{
|
||||
|
@ -161,19 +165,21 @@ func TestMakeConfig(t *testing.T) {
|
|||
args: []string{"foo", "bar"},
|
||||
sha256: []byte("some_sha256"),
|
||||
env: []string{"initial=true"},
|
||||
pluginSets: map[int]plugin.PluginSet{
|
||||
1: {
|
||||
"bogus": nil,
|
||||
PluginClientConfig: PluginClientConfig{
|
||||
PluginSets: map[int]plugin.PluginSet{
|
||||
1: {
|
||||
"bogus": nil,
|
||||
},
|
||||
},
|
||||
HandshakeConfig: plugin.HandshakeConfig{
|
||||
ProtocolVersion: 1,
|
||||
MagicCookieKey: "magic_cookie_key",
|
||||
MagicCookieValue: "magic_cookie_value",
|
||||
},
|
||||
Logger: hclog.NewNullLogger(),
|
||||
IsMetadataMode: true,
|
||||
AutoMTLS: true,
|
||||
},
|
||||
hs: plugin.HandshakeConfig{
|
||||
ProtocolVersion: 1,
|
||||
MagicCookieKey: "magic_cookie_key",
|
||||
MagicCookieValue: "magic_cookie_value",
|
||||
},
|
||||
logger: hclog.NewNullLogger(),
|
||||
isMetadataMode: true,
|
||||
autoMTLS: true,
|
||||
},
|
||||
|
||||
responseWrapInfoTimes: 0,
|
||||
|
@ -220,19 +226,21 @@ func TestMakeConfig(t *testing.T) {
|
|||
args: []string{"foo", "bar"},
|
||||
sha256: []byte("some_sha256"),
|
||||
env: []string{"initial=true"},
|
||||
pluginSets: map[int]plugin.PluginSet{
|
||||
1: {
|
||||
"bogus": nil,
|
||||
PluginClientConfig: PluginClientConfig{
|
||||
PluginSets: map[int]plugin.PluginSet{
|
||||
1: {
|
||||
"bogus": nil,
|
||||
},
|
||||
},
|
||||
HandshakeConfig: plugin.HandshakeConfig{
|
||||
ProtocolVersion: 1,
|
||||
MagicCookieKey: "magic_cookie_key",
|
||||
MagicCookieValue: "magic_cookie_value",
|
||||
},
|
||||
Logger: hclog.NewNullLogger(),
|
||||
IsMetadataMode: false,
|
||||
AutoMTLS: true,
|
||||
},
|
||||
hs: plugin.HandshakeConfig{
|
||||
ProtocolVersion: 1,
|
||||
MagicCookieKey: "magic_cookie_key",
|
||||
MagicCookieValue: "magic_cookie_value",
|
||||
},
|
||||
logger: hclog.NewNullLogger(),
|
||||
isMetadataMode: false,
|
||||
autoMTLS: true,
|
||||
},
|
||||
|
||||
responseWrapInfoTimes: 0,
|
||||
|
@ -329,6 +337,11 @@ type mockRunnerUtil struct {
|
|||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *mockRunnerUtil) NewPluginClient(ctx context.Context, config PluginClientConfig) (PluginClient, error) {
|
||||
args := m.Called(ctx, config)
|
||||
return args.Get(0).(PluginClient), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *mockRunnerUtil) ResponseWrapData(ctx context.Context, data map[string]interface{}, ttl time.Duration, jwt bool) (*wrapping.ResponseWrapInfo, error) {
|
||||
args := m.Called(ctx, data, ttl, jwt)
|
||||
return args.Get(0).(*wrapping.ResponseWrapInfo), args.Error(1)
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
plugin "github.com/hashicorp/go-plugin"
|
||||
"github.com/hashicorp/vault/sdk/helper/consts"
|
||||
"github.com/hashicorp/vault/sdk/helper/wrapping"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
// Looker defines the plugin Lookup function that looks into the plugin catalog
|
||||
|
@ -21,6 +22,7 @@ type Looker interface {
|
|||
// configuration and wrapping data in a response wrapped token.
|
||||
// logical.SystemView implementations satisfy this interface.
|
||||
type RunnerUtil interface {
|
||||
NewPluginClient(ctx context.Context, config PluginClientConfig) (PluginClient, error)
|
||||
ResponseWrapData(ctx context.Context, data map[string]interface{}, ttl time.Duration, jwt bool) (*wrapping.ResponseWrapInfo, error)
|
||||
MlockEnabled() bool
|
||||
}
|
||||
|
@ -31,17 +33,25 @@ type LookRunnerUtil interface {
|
|||
RunnerUtil
|
||||
}
|
||||
|
||||
type PluginClient interface {
|
||||
Conn() grpc.ClientConnInterface
|
||||
plugin.ClientProtocol
|
||||
}
|
||||
|
||||
const MultiplexingCtxKey string = "multiplex_id"
|
||||
|
||||
// PluginRunner defines the metadata needed to run a plugin securely with
|
||||
// go-plugin.
|
||||
type PluginRunner struct {
|
||||
Name string `json:"name" structs:"name"`
|
||||
Type consts.PluginType `json:"type" structs:"type"`
|
||||
Command string `json:"command" structs:"command"`
|
||||
Args []string `json:"args" structs:"args"`
|
||||
Env []string `json:"env" structs:"env"`
|
||||
Sha256 []byte `json:"sha256" structs:"sha256"`
|
||||
Builtin bool `json:"builtin" structs:"builtin"`
|
||||
BuiltinFactory func() (interface{}, error) `json:"-" structs:"-"`
|
||||
Name string `json:"name" structs:"name"`
|
||||
Type consts.PluginType `json:"type" structs:"type"`
|
||||
Command string `json:"command" structs:"command"`
|
||||
Args []string `json:"args" structs:"args"`
|
||||
Env []string `json:"env" structs:"env"`
|
||||
Sha256 []byte `json:"sha256" structs:"sha256"`
|
||||
Builtin bool `json:"builtin" structs:"builtin"`
|
||||
BuiltinFactory func() (interface{}, error) `json:"-" structs:"-"`
|
||||
MultiplexingSupport bool `json:"multiplexing_support" structs:"multiplexing_support"`
|
||||
}
|
||||
|
||||
// Run takes a wrapper RunnerUtil instance along with the go-plugin parameters and
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||
// versions:
|
||||
// protoc-gen-go v1.27.1
|
||||
// protoc v3.17.3
|
||||
// protoc v3.19.4
|
||||
// source: sdk/logical/identity.proto
|
||||
|
||||
package logical
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||
// versions:
|
||||
// protoc-gen-go v1.27.1
|
||||
// protoc v3.17.3
|
||||
// protoc v3.19.4
|
||||
// source: sdk/logical/plugin.proto
|
||||
|
||||
package logical
|
||||
|
|
|
@ -56,6 +56,10 @@ type SystemView interface {
|
|||
// name. Returns a PluginRunner or an error if a plugin can not be found.
|
||||
LookupPlugin(context.Context, string, consts.PluginType) (*pluginutil.PluginRunner, error)
|
||||
|
||||
// NewPluginClient returns a client for managing the lifecycle of plugin
|
||||
// processes
|
||||
NewPluginClient(ctx context.Context, config pluginutil.PluginClientConfig) (pluginutil.PluginClient, error)
|
||||
|
||||
// MlockEnabled returns the configuration setting for enabling mlock on
|
||||
// plugins.
|
||||
MlockEnabled() bool
|
||||
|
@ -152,6 +156,10 @@ func (d StaticSystemView) ReplicationState() consts.ReplicationState {
|
|||
return d.ReplicationStateVal
|
||||
}
|
||||
|
||||
func (d StaticSystemView) NewPluginClient(ctx context.Context, config pluginutil.PluginClientConfig) (pluginutil.PluginClient, error) {
|
||||
return nil, errors.New("NewPluginClient is not implemented in StaticSystemView")
|
||||
}
|
||||
|
||||
func (d StaticSystemView) ResponseWrapData(_ context.Context, data map[string]interface{}, ttl time.Duration, jwt bool) (*wrapping.ResponseWrapInfo, error) {
|
||||
return nil, errors.New("ResponseWrapData is not implemented in StaticSystemView")
|
||||
}
|
||||
|
|
|
@ -99,6 +99,10 @@ func (s *gRPCSystemViewClient) ResponseWrapData(ctx context.Context, data map[st
|
|||
return info, nil
|
||||
}
|
||||
|
||||
func (s *gRPCSystemViewClient) NewPluginClient(ctx context.Context, config pluginutil.PluginClientConfig) (pluginutil.PluginClient, error) {
|
||||
return nil, fmt.Errorf("cannot call NewPluginClient from a plugin backend")
|
||||
}
|
||||
|
||||
func (s *gRPCSystemViewClient) LookupPlugin(_ context.Context, _ string, _ consts.PluginType) (*pluginutil.PluginRunner, error) {
|
||||
return nil, fmt.Errorf("cannot call LookupPlugin from a plugin backend")
|
||||
}
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||
// versions:
|
||||
// protoc-gen-go v1.27.1
|
||||
// protoc v3.17.3
|
||||
// protoc v3.19.4
|
||||
// source: sdk/plugin/pb/backend.proto
|
||||
|
||||
package pb
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||
// versions:
|
||||
// protoc-gen-go v1.27.1
|
||||
// protoc v3.17.3
|
||||
// protoc v3.19.4
|
||||
// source: vault/activity/activity_log.proto
|
||||
|
||||
package activity
|
||||
|
|
|
@ -215,6 +215,22 @@ func (d dynamicSystemView) ResponseWrapData(ctx context.Context, data map[string
|
|||
return resp.WrapInfo, nil
|
||||
}
|
||||
|
||||
func (d dynamicSystemView) NewPluginClient(ctx context.Context, config pluginutil.PluginClientConfig) (pluginutil.PluginClient, error) {
|
||||
if d.core == nil {
|
||||
return nil, fmt.Errorf("system view core is nil")
|
||||
}
|
||||
if d.core.pluginCatalog == nil {
|
||||
return nil, fmt.Errorf("system view core plugin catalog is nil")
|
||||
}
|
||||
|
||||
c, err := d.core.pluginCatalog.NewPluginClient(ctx, config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// LookupPlugin looks for a plugin with the given name in the plugin catalog. It
|
||||
// returns a PluginRunner or an error if no plugin was found.
|
||||
func (d dynamicSystemView) LookupPlugin(ctx context.Context, name string, pluginType consts.PluginType) (*pluginutil.PluginRunner, error) {
|
||||
|
|
|
@ -12,6 +12,8 @@ import (
|
|||
|
||||
log "github.com/hashicorp/go-hclog"
|
||||
multierror "github.com/hashicorp/go-multierror"
|
||||
plugin "github.com/hashicorp/go-plugin"
|
||||
"github.com/hashicorp/go-secure-stdlib/base62"
|
||||
v4 "github.com/hashicorp/vault/sdk/database/dbplugin"
|
||||
v5 "github.com/hashicorp/vault/sdk/database/dbplugin/v5"
|
||||
"github.com/hashicorp/vault/sdk/helper/consts"
|
||||
|
@ -19,6 +21,8 @@ import (
|
|||
"github.com/hashicorp/vault/sdk/helper/pluginutil"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
backendplugin "github.com/hashicorp/vault/sdk/plugin"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/metadata"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -35,21 +39,62 @@ type PluginCatalog struct {
|
|||
builtinRegistry BuiltinRegistry
|
||||
catalogView *BarrierView
|
||||
directory string
|
||||
logger log.Logger
|
||||
|
||||
// externalPlugins holds plugin process connections by plugin name
|
||||
//
|
||||
// This allows plugins that suppport multiplexing to use a single grpc
|
||||
// connection to communicate with multiple "backends". Each backend
|
||||
// configuration using the same plugin will be routed to the existing
|
||||
// plugin process.
|
||||
externalPlugins map[string]*externalPlugin
|
||||
mlockPlugins bool
|
||||
|
||||
lock sync.RWMutex
|
||||
}
|
||||
|
||||
// externalPlugin holds client connections for multiplexed and
|
||||
// non-multiplexed plugin processes
|
||||
type externalPlugin struct {
|
||||
// name is the plugin name
|
||||
name string
|
||||
|
||||
// connections holds client connections by ID
|
||||
connections map[string]*pluginClient
|
||||
|
||||
multiplexingSupport bool
|
||||
}
|
||||
|
||||
// pluginClient represents a connection to a plugin process
|
||||
type pluginClient struct {
|
||||
logger log.Logger
|
||||
|
||||
// id is the connection ID
|
||||
id string
|
||||
|
||||
// client handles the lifecycle of a plugin process
|
||||
// multiplexed plugins share the same client
|
||||
client *plugin.Client
|
||||
clientConn grpc.ClientConnInterface
|
||||
cleanupFunc func() error
|
||||
|
||||
plugin.ClientProtocol
|
||||
}
|
||||
|
||||
func (c *Core) setupPluginCatalog(ctx context.Context) error {
|
||||
c.pluginCatalog = &PluginCatalog{
|
||||
builtinRegistry: c.builtinRegistry,
|
||||
catalogView: NewBarrierView(c.barrier, pluginCatalogPath),
|
||||
directory: c.pluginDirectory,
|
||||
logger: c.logger,
|
||||
mlockPlugins: c.enableMlock,
|
||||
}
|
||||
|
||||
// Run upgrade if untyped plugins exist
|
||||
err := c.pluginCatalog.UpgradePlugins(ctx, c.logger)
|
||||
if err != nil {
|
||||
c.logger.Error("error while upgrading plugin storage", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
if c.logger.IsInfo() {
|
||||
|
@ -59,14 +104,205 @@ func (c *Core) setupPluginCatalog(ctx context.Context) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
type pluginClientConn struct {
|
||||
*grpc.ClientConn
|
||||
id string
|
||||
}
|
||||
|
||||
var _ grpc.ClientConnInterface = &pluginClientConn{}
|
||||
|
||||
func (d *pluginClientConn) Invoke(ctx context.Context, method string, args interface{}, reply interface{}, opts ...grpc.CallOption) error {
|
||||
// Inject ID to the context
|
||||
md := metadata.Pairs(pluginutil.MultiplexingCtxKey, d.id)
|
||||
idCtx := metadata.NewOutgoingContext(ctx, md)
|
||||
|
||||
return d.ClientConn.Invoke(idCtx, method, args, reply, opts...)
|
||||
}
|
||||
|
||||
func (d *pluginClientConn) NewStream(ctx context.Context, desc *grpc.StreamDesc, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) {
|
||||
// Inject ID to the context
|
||||
md := metadata.Pairs(pluginutil.MultiplexingCtxKey, d.id)
|
||||
idCtx := metadata.NewOutgoingContext(ctx, md)
|
||||
|
||||
return d.ClientConn.NewStream(idCtx, desc, method, opts...)
|
||||
}
|
||||
|
||||
func (p *pluginClient) Conn() grpc.ClientConnInterface {
|
||||
return p.clientConn
|
||||
}
|
||||
|
||||
// Close calls the plugin client's cleanupFunc to do any necessary cleanup on
|
||||
// the plugin client and the PluginCatalog. This implements the
|
||||
// plugin.ClientProtocol interface.
|
||||
func (p *pluginClient) Close() error {
|
||||
p.logger.Debug("cleaning up plugin client connection", "id", p.id)
|
||||
return p.cleanupFunc()
|
||||
}
|
||||
|
||||
// cleanupExternalPlugin will kill plugin processes and perform any necessary
|
||||
// cleanup on the externalPlugins map for multiplexed and non-multiplexed
|
||||
// plugins. This should be called with the write lock held.
|
||||
func (c *PluginCatalog) cleanupExternalPlugin(name, id string) error {
|
||||
extPlugin, ok := c.externalPlugins[name]
|
||||
if !ok {
|
||||
return fmt.Errorf("plugin client not found")
|
||||
}
|
||||
|
||||
pluginClient := extPlugin.connections[id]
|
||||
|
||||
delete(extPlugin.connections, id)
|
||||
if !extPlugin.multiplexingSupport {
|
||||
pluginClient.client.Kill()
|
||||
|
||||
if len(extPlugin.connections) == 0 {
|
||||
delete(c.externalPlugins, name)
|
||||
}
|
||||
} else if len(extPlugin.connections) == 0 {
|
||||
pluginClient.client.Kill()
|
||||
delete(c.externalPlugins, name)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *PluginCatalog) getExternalPlugin(pluginName string) *externalPlugin {
|
||||
if extPlugin, ok := c.externalPlugins[pluginName]; ok {
|
||||
return extPlugin
|
||||
}
|
||||
|
||||
return c.newExternalPlugin(pluginName)
|
||||
}
|
||||
|
||||
func (c *PluginCatalog) newExternalPlugin(pluginName string) *externalPlugin {
|
||||
if c.externalPlugins == nil {
|
||||
c.externalPlugins = make(map[string]*externalPlugin)
|
||||
}
|
||||
|
||||
extPlugin := &externalPlugin{
|
||||
connections: make(map[string]*pluginClient),
|
||||
name: pluginName,
|
||||
}
|
||||
|
||||
c.externalPlugins[pluginName] = extPlugin
|
||||
return extPlugin
|
||||
}
|
||||
|
||||
// NewPluginClient returns a client for managing the lifecycle of a plugin
|
||||
// process
|
||||
func (c *PluginCatalog) NewPluginClient(ctx context.Context, config pluginutil.PluginClientConfig) (*pluginClient, error) {
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
|
||||
if config.Name == "" {
|
||||
return nil, fmt.Errorf("no name provided for plugin")
|
||||
}
|
||||
if config.PluginType == consts.PluginTypeUnknown {
|
||||
return nil, fmt.Errorf("no plugin type provided")
|
||||
}
|
||||
|
||||
pluginRunner, err := c.get(ctx, config.Name, config.PluginType)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to lookup plugin: %w", err)
|
||||
}
|
||||
if pluginRunner == nil {
|
||||
return nil, fmt.Errorf("no plugin found")
|
||||
}
|
||||
pc, err := c.newPluginClient(ctx, pluginRunner, config)
|
||||
return pc, err
|
||||
}
|
||||
|
||||
// newPluginClient returns a client for managing the lifecycle of a plugin
|
||||
// process. Callers should have the write lock held.
|
||||
func (c *PluginCatalog) newPluginClient(ctx context.Context, pluginRunner *pluginutil.PluginRunner, config pluginutil.PluginClientConfig) (*pluginClient, error) {
|
||||
if pluginRunner == nil {
|
||||
return nil, fmt.Errorf("no plugin found")
|
||||
}
|
||||
|
||||
extPlugin := c.getExternalPlugin(pluginRunner.Name)
|
||||
id, err := base62.Random(10)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pc := &pluginClient{
|
||||
id: id,
|
||||
logger: c.logger.Named(pluginRunner.Name),
|
||||
cleanupFunc: func() error {
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
return c.cleanupExternalPlugin(pluginRunner.Name, id)
|
||||
},
|
||||
}
|
||||
|
||||
if !pluginRunner.MultiplexingSupport || len(extPlugin.connections) == 0 {
|
||||
c.logger.Debug("spawning a new plugin process", "plugin_name", pluginRunner.Name, "id", id)
|
||||
client, err := pluginRunner.RunConfig(ctx,
|
||||
pluginutil.PluginSets(config.PluginSets),
|
||||
pluginutil.HandshakeConfig(config.HandshakeConfig),
|
||||
pluginutil.Logger(config.Logger),
|
||||
pluginutil.MetadataMode(config.IsMetadataMode),
|
||||
pluginutil.MLock(c.mlockPlugins),
|
||||
|
||||
// NewPluginClient only supports AutoMTLS today
|
||||
pluginutil.AutoMTLS(true),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pc.client = client
|
||||
} else {
|
||||
c.logger.Debug("returning existing plugin client for multiplexed plugin", "id", id)
|
||||
|
||||
// get the first client, since they are all the same
|
||||
for k := range extPlugin.connections {
|
||||
pc.client = extPlugin.connections[k].client
|
||||
break
|
||||
}
|
||||
|
||||
if pc.client == nil {
|
||||
return nil, fmt.Errorf("plugin client is nil")
|
||||
}
|
||||
}
|
||||
|
||||
// Get the protocol client for this connection.
|
||||
// Subsequent calls to this will return the same client.
|
||||
rpcClient, err := pc.client.Client()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
clientConn := rpcClient.(*plugin.GRPCClient).Conn
|
||||
|
||||
if pluginRunner.MultiplexingSupport {
|
||||
// Wrap rpcClient with our implementation so that we can inject the
|
||||
// ID into the context
|
||||
pc.clientConn = &pluginClientConn{
|
||||
ClientConn: clientConn,
|
||||
id: id,
|
||||
}
|
||||
} else {
|
||||
pc.clientConn = clientConn
|
||||
}
|
||||
|
||||
pc.ClientProtocol = rpcClient
|
||||
|
||||
extPlugin.connections[id] = pc
|
||||
extPlugin.name = pluginRunner.Name
|
||||
extPlugin.multiplexingSupport = pluginRunner.MultiplexingSupport
|
||||
|
||||
return extPlugin.connections[id], nil
|
||||
}
|
||||
|
||||
// getPluginTypeFromUnknown will attempt to run the plugin to determine the
|
||||
// type. It will first attempt to run as a database plugin then a backend
|
||||
// plugin. Both of these will be run in metadata mode.
|
||||
func (c *PluginCatalog) getPluginTypeFromUnknown(ctx context.Context, logger log.Logger, plugin *pluginutil.PluginRunner) (consts.PluginType, error) {
|
||||
// type and if it supports multiplexing. It will first attempt to run as a
|
||||
// database plugin then a backend plugin. Both of these will be run in metadata
|
||||
// mode.
|
||||
func (c *PluginCatalog) getPluginTypeFromUnknown(ctx context.Context, logger log.Logger, plugin *pluginutil.PluginRunner) (consts.PluginType, bool, error) {
|
||||
merr := &multierror.Error{}
|
||||
err := isDatabasePlugin(ctx, plugin)
|
||||
multiplexingSupport, err := c.isDatabasePlugin(ctx, plugin)
|
||||
if err == nil {
|
||||
return consts.PluginTypeDatabase, nil
|
||||
return consts.PluginTypeDatabase, multiplexingSupport, nil
|
||||
}
|
||||
merr = multierror.Append(merr, err)
|
||||
|
||||
|
@ -75,7 +311,7 @@ func (c *PluginCatalog) getPluginTypeFromUnknown(ctx context.Context, logger log
|
|||
if err == nil {
|
||||
err := client.Setup(ctx, &logical.BackendConfig{})
|
||||
if err != nil {
|
||||
return consts.PluginTypeUnknown, err
|
||||
return consts.PluginTypeUnknown, false, err
|
||||
}
|
||||
|
||||
backendType := client.Type()
|
||||
|
@ -83,9 +319,9 @@ func (c *PluginCatalog) getPluginTypeFromUnknown(ctx context.Context, logger log
|
|||
|
||||
switch backendType {
|
||||
case logical.TypeCredential:
|
||||
return consts.PluginTypeCredential, nil
|
||||
return consts.PluginTypeCredential, false, nil
|
||||
case logical.TypeLogical:
|
||||
return consts.PluginTypeSecrets, nil
|
||||
return consts.PluginTypeSecrets, false, nil
|
||||
}
|
||||
} else {
|
||||
merr = multierror.Append(merr, err)
|
||||
|
@ -102,29 +338,55 @@ func (c *PluginCatalog) getPluginTypeFromUnknown(ctx context.Context, logger log
|
|||
"error", merr.Error())
|
||||
}
|
||||
|
||||
return consts.PluginTypeUnknown, nil
|
||||
return consts.PluginTypeUnknown, false, nil
|
||||
}
|
||||
|
||||
func isDatabasePlugin(ctx context.Context, plugin *pluginutil.PluginRunner) error {
|
||||
// isDatabasePlugin returns true if the plugin supports multiplexing. An error
|
||||
// is returned if the plugin is not a database plugin.
|
||||
func (c *PluginCatalog) isDatabasePlugin(ctx context.Context, pluginRunner *pluginutil.PluginRunner) (bool, error) {
|
||||
merr := &multierror.Error{}
|
||||
// Attempt to run as database V5 plugin
|
||||
v5Client, err := v5.NewPluginClient(ctx, nil, plugin, log.NewNullLogger(), true)
|
||||
config := pluginutil.PluginClientConfig{
|
||||
Name: pluginRunner.Name,
|
||||
PluginSets: v5.PluginSets,
|
||||
PluginType: consts.PluginTypeDatabase,
|
||||
HandshakeConfig: v5.HandshakeConfig,
|
||||
Logger: log.NewNullLogger(),
|
||||
IsMetadataMode: true,
|
||||
AutoMTLS: true,
|
||||
}
|
||||
// Attempt to run as database V5 or V6 multiplexed plugin
|
||||
v5Client, err := c.newPluginClient(ctx, pluginRunner, config)
|
||||
if err == nil {
|
||||
// At this point the pluginRunner does not know if multiplexing is
|
||||
// supported or not. So we need to ask the plugin client itself.
|
||||
multiplexingSupport, err := pluginutil.MultiplexingSupported(ctx, v5Client.clientConn)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// Close the client and cleanup the plugin process
|
||||
v5Client.Close()
|
||||
return nil
|
||||
err = c.cleanupExternalPlugin(pluginRunner.Name, v5Client.id)
|
||||
if err != nil {
|
||||
c.logger.Error("error closing plugin client", "error", err)
|
||||
}
|
||||
|
||||
return multiplexingSupport, nil
|
||||
}
|
||||
merr = multierror.Append(merr, fmt.Errorf("failed to load plugin as database v5: %w", err))
|
||||
|
||||
v4Client, err := v4.NewPluginClient(ctx, nil, plugin, log.NewNullLogger(), true)
|
||||
v4Client, err := v4.NewPluginClient(ctx, nil, pluginRunner, log.NewNullLogger(), true)
|
||||
if err == nil {
|
||||
// Close the client and cleanup the plugin process
|
||||
v4Client.Close()
|
||||
return nil
|
||||
err = v4Client.Close()
|
||||
if err != nil {
|
||||
c.logger.Error("error closing plugin client", "error", err)
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
merr = multierror.Append(merr, fmt.Errorf("failed to load plugin as database v4: %w", err))
|
||||
|
||||
return merr.ErrorOrNil()
|
||||
return false, merr.ErrorOrNil()
|
||||
}
|
||||
|
||||
// UpdatePlugins will loop over all the plugins of unknown type and attempt to
|
||||
|
@ -170,7 +432,7 @@ func (c *PluginCatalog) UpgradePlugins(ctx context.Context, logger log.Logger) e
|
|||
cmdOld := plugin.Command
|
||||
plugin.Command = filepath.Join(c.directory, plugin.Command)
|
||||
|
||||
pluginType, err := c.getPluginTypeFromUnknown(ctx, logger, plugin)
|
||||
pluginType, multiplexingSupport, err := c.getPluginTypeFromUnknown(ctx, logger, plugin)
|
||||
if err != nil {
|
||||
retErr = multierror.Append(retErr, fmt.Errorf("could not upgrade plugin %s: %s", pluginName, err))
|
||||
continue
|
||||
|
@ -181,7 +443,7 @@ func (c *PluginCatalog) UpgradePlugins(ctx context.Context, logger log.Logger) e
|
|||
}
|
||||
|
||||
// Upgrade the storage
|
||||
err = c.setInternal(ctx, pluginName, pluginType, cmdOld, plugin.Args, plugin.Env, plugin.Sha256)
|
||||
err = c.setInternal(ctx, pluginName, pluginType, multiplexingSupport, cmdOld, plugin.Args, plugin.Env, plugin.Sha256)
|
||||
if err != nil {
|
||||
retErr = multierror.Append(retErr, fmt.Errorf("could not upgrade plugin %s: %s", pluginName, err))
|
||||
continue
|
||||
|
@ -269,10 +531,14 @@ func (c *PluginCatalog) Set(ctx context.Context, name string, pluginType consts.
|
|||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
|
||||
return c.setInternal(ctx, name, pluginType, command, args, env, sha256)
|
||||
// During plugin registration, we can't know if a plugin is multiplexed or
|
||||
// not until we run it. So we set it to false here. Once started, we ask
|
||||
// the plugin if it is multiplexed and set this value accordingly.
|
||||
multiplexingSupport := false
|
||||
return c.setInternal(ctx, name, pluginType, multiplexingSupport, command, args, env, sha256)
|
||||
}
|
||||
|
||||
func (c *PluginCatalog) setInternal(ctx context.Context, name string, pluginType consts.PluginType, command string, args []string, env []string, sha256 []byte) error {
|
||||
func (c *PluginCatalog) setInternal(ctx context.Context, name string, pluginType consts.PluginType, multiplexingSupport bool, command string, args []string, env []string, sha256 []byte) error {
|
||||
// Best effort check to make sure the command isn't breaking out of the
|
||||
// configured plugin directory.
|
||||
commandFull := filepath.Join(c.directory, command)
|
||||
|
@ -294,15 +560,16 @@ func (c *PluginCatalog) setInternal(ctx context.Context, name string, pluginType
|
|||
// entryTmp should only be used for the below type check, it uses the
|
||||
// full command instead of the relative command.
|
||||
entryTmp := &pluginutil.PluginRunner{
|
||||
Name: name,
|
||||
Command: commandFull,
|
||||
Args: args,
|
||||
Env: env,
|
||||
Sha256: sha256,
|
||||
Builtin: false,
|
||||
Name: name,
|
||||
Command: commandFull,
|
||||
Args: args,
|
||||
Env: env,
|
||||
Sha256: sha256,
|
||||
Builtin: false,
|
||||
MultiplexingSupport: multiplexingSupport,
|
||||
}
|
||||
|
||||
pluginType, err = c.getPluginTypeFromUnknown(ctx, log.Default(), entryTmp)
|
||||
pluginType, multiplexingSupport, err = c.getPluginTypeFromUnknown(ctx, log.Default(), entryTmp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -312,13 +579,14 @@ func (c *PluginCatalog) setInternal(ctx context.Context, name string, pluginType
|
|||
}
|
||||
|
||||
entry := &pluginutil.PluginRunner{
|
||||
Name: name,
|
||||
Type: pluginType,
|
||||
Command: command,
|
||||
Args: args,
|
||||
Env: env,
|
||||
Sha256: sha256,
|
||||
Builtin: false,
|
||||
Name: name,
|
||||
Type: pluginType,
|
||||
Command: command,
|
||||
Args: args,
|
||||
Env: env,
|
||||
Sha256: sha256,
|
||||
Builtin: false,
|
||||
MultiplexingSupport: multiplexingSupport,
|
||||
}
|
||||
|
||||
buf, err := json.Marshal(entry)
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||
// versions:
|
||||
// protoc-gen-go v1.27.1
|
||||
// protoc v3.17.3
|
||||
// protoc v3.19.4
|
||||
// source: vault/request_forwarding_service.proto
|
||||
|
||||
package vault
|
||||
|
|
Loading…
Reference in New Issue