2c8e88ab67
Check if plugin version matches running version When registering a plugin, we check if the request version matches the self-reported version from the plugin. If these do not match, we log a warning. This uncovered a few missing pieces for getting the database version code fully working. We added an environment variable that helps us unit test the running version behavior as well, but only for approle, postgresql, and consul plugins. Return 400 on plugin not found or version mismatch Populate the running SHA256 of plugins in the mount and auth tables (#17217)
358 lines
9.4 KiB
Go
358 lines
9.4 KiB
Go
package dbplugin
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/golang/protobuf/ptypes"
|
|
"github.com/hashicorp/vault/sdk/database/dbplugin/v5/proto"
|
|
"github.com/hashicorp/vault/sdk/helper/base62"
|
|
"github.com/hashicorp/vault/sdk/helper/pluginutil"
|
|
"github.com/hashicorp/vault/sdk/logical"
|
|
"google.golang.org/grpc/codes"
|
|
"google.golang.org/grpc/status"
|
|
)
|
|
|
|
var _ proto.DatabaseServer = &gRPCServer{}
|
|
|
|
type gRPCServer struct {
|
|
proto.UnimplementedDatabaseServer
|
|
logical.UnimplementedPluginVersionServer
|
|
|
|
// holds the non-multiplexed Database
|
|
// when this is set the plugin does not support multiplexing
|
|
singleImpl Database
|
|
|
|
// instances holds the multiplexed Databases
|
|
instances map[string]Database
|
|
factoryFunc func() (interface{}, error)
|
|
|
|
sync.RWMutex
|
|
}
|
|
|
|
func (g *gRPCServer) getOrCreateDatabase(ctx context.Context) (Database, error) {
|
|
g.Lock()
|
|
defer g.Unlock()
|
|
|
|
if g.singleImpl != nil {
|
|
return g.singleImpl, nil
|
|
}
|
|
|
|
id, err := pluginutil.GetMultiplexIDFromContext(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if db, ok := g.instances[id]; ok {
|
|
return db, nil
|
|
}
|
|
return g.createDatabase(id)
|
|
}
|
|
|
|
// must hold the g.Lock() to call this function
|
|
func (g *gRPCServer) createDatabase(id string) (Database, error) {
|
|
db, err := g.factoryFunc()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
database := db.(Database)
|
|
g.instances[id] = database
|
|
|
|
return database, nil
|
|
}
|
|
|
|
// getDatabaseInternal returns the database but does not hold a lock
|
|
func (g *gRPCServer) getDatabaseInternal(ctx context.Context) (Database, error) {
|
|
if g.singleImpl != nil {
|
|
return g.singleImpl, nil
|
|
}
|
|
|
|
id, err := pluginutil.GetMultiplexIDFromContext(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if db, ok := g.instances[id]; ok {
|
|
return db, nil
|
|
}
|
|
|
|
return nil, fmt.Errorf("no database instance found")
|
|
}
|
|
|
|
// getDatabase holds a read lock and returns the database
|
|
func (g *gRPCServer) getDatabase(ctx context.Context) (Database, error) {
|
|
g.RLock()
|
|
impl, err := g.getDatabaseInternal(ctx)
|
|
g.RUnlock()
|
|
return impl, err
|
|
}
|
|
|
|
// Initialize the database plugin
|
|
func (g *gRPCServer) Initialize(ctx context.Context, request *proto.InitializeRequest) (*proto.InitializeResponse, error) {
|
|
impl, err := g.getOrCreateDatabase(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
rawConfig := structToMap(request.ConfigData)
|
|
|
|
dbReq := InitializeRequest{
|
|
Config: rawConfig,
|
|
VerifyConnection: request.VerifyConnection,
|
|
}
|
|
|
|
dbResp, err := impl.Initialize(ctx, dbReq)
|
|
if err != nil {
|
|
return &proto.InitializeResponse{}, status.Errorf(codes.Internal, "failed to initialize: %s", err)
|
|
}
|
|
|
|
newConfig, err := mapToStruct(dbResp.Config)
|
|
if err != nil {
|
|
return &proto.InitializeResponse{}, status.Errorf(codes.Internal, "failed to marshal new config to JSON: %s", err)
|
|
}
|
|
|
|
resp := &proto.InitializeResponse{
|
|
ConfigData: newConfig,
|
|
}
|
|
|
|
return resp, nil
|
|
}
|
|
|
|
func (g *gRPCServer) NewUser(ctx context.Context, req *proto.NewUserRequest) (*proto.NewUserResponse, error) {
|
|
if req.GetUsernameConfig() == nil {
|
|
return &proto.NewUserResponse{}, status.Errorf(codes.InvalidArgument, "missing username config")
|
|
}
|
|
|
|
var expiration time.Time
|
|
|
|
if req.GetExpiration() != nil {
|
|
exp, err := ptypes.Timestamp(req.GetExpiration())
|
|
if err != nil {
|
|
return &proto.NewUserResponse{}, status.Errorf(codes.InvalidArgument, "unable to parse expiration date: %s", err)
|
|
}
|
|
expiration = exp
|
|
}
|
|
|
|
impl, err := g.getDatabase(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
dbReq := NewUserRequest{
|
|
UsernameConfig: UsernameMetadata{
|
|
DisplayName: req.GetUsernameConfig().GetDisplayName(),
|
|
RoleName: req.GetUsernameConfig().GetRoleName(),
|
|
},
|
|
CredentialType: CredentialType(req.GetCredentialType()),
|
|
Password: req.GetPassword(),
|
|
PublicKey: req.GetPublicKey(),
|
|
Expiration: expiration,
|
|
Statements: getStatementsFromProto(req.GetStatements()),
|
|
RollbackStatements: getStatementsFromProto(req.GetRollbackStatements()),
|
|
}
|
|
|
|
dbResp, err := impl.NewUser(ctx, dbReq)
|
|
if err != nil {
|
|
return &proto.NewUserResponse{}, status.Errorf(codes.Internal, "unable to create new user: %s", err)
|
|
}
|
|
|
|
resp := &proto.NewUserResponse{
|
|
Username: dbResp.Username,
|
|
}
|
|
return resp, nil
|
|
}
|
|
|
|
func (g *gRPCServer) UpdateUser(ctx context.Context, req *proto.UpdateUserRequest) (*proto.UpdateUserResponse, error) {
|
|
if req.GetUsername() == "" {
|
|
return &proto.UpdateUserResponse{}, status.Errorf(codes.InvalidArgument, "no username provided")
|
|
}
|
|
|
|
dbReq, err := getUpdateUserRequest(req)
|
|
if err != nil {
|
|
return &proto.UpdateUserResponse{}, status.Errorf(codes.InvalidArgument, err.Error())
|
|
}
|
|
|
|
impl, err := g.getDatabase(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
_, err = impl.UpdateUser(ctx, dbReq)
|
|
if err != nil {
|
|
return &proto.UpdateUserResponse{}, status.Errorf(codes.Internal, "unable to update user: %s", err)
|
|
}
|
|
return &proto.UpdateUserResponse{}, nil
|
|
}
|
|
|
|
func getUpdateUserRequest(req *proto.UpdateUserRequest) (UpdateUserRequest, error) {
|
|
var password *ChangePassword
|
|
if req.GetPassword() != nil && req.GetPassword().GetNewPassword() != "" {
|
|
password = &ChangePassword{
|
|
NewPassword: req.GetPassword().GetNewPassword(),
|
|
Statements: getStatementsFromProto(req.GetPassword().GetStatements()),
|
|
}
|
|
}
|
|
|
|
var publicKey *ChangePublicKey
|
|
if req.GetPublicKey() != nil && len(req.GetPublicKey().GetNewPublicKey()) > 0 {
|
|
publicKey = &ChangePublicKey{
|
|
NewPublicKey: req.GetPublicKey().GetNewPublicKey(),
|
|
Statements: getStatementsFromProto(req.GetPublicKey().GetStatements()),
|
|
}
|
|
}
|
|
|
|
var expiration *ChangeExpiration
|
|
if req.GetExpiration() != nil && req.GetExpiration().GetNewExpiration() != nil {
|
|
newExpiration, err := ptypes.Timestamp(req.GetExpiration().GetNewExpiration())
|
|
if err != nil {
|
|
return UpdateUserRequest{}, fmt.Errorf("unable to parse new expiration: %w", err)
|
|
}
|
|
|
|
expiration = &ChangeExpiration{
|
|
NewExpiration: newExpiration,
|
|
Statements: getStatementsFromProto(req.GetExpiration().GetStatements()),
|
|
}
|
|
}
|
|
|
|
dbReq := UpdateUserRequest{
|
|
Username: req.GetUsername(),
|
|
CredentialType: CredentialType(req.GetCredentialType()),
|
|
Password: password,
|
|
PublicKey: publicKey,
|
|
Expiration: expiration,
|
|
}
|
|
|
|
if !hasChange(dbReq) {
|
|
return UpdateUserRequest{}, fmt.Errorf("update user request has no changes")
|
|
}
|
|
|
|
return dbReq, nil
|
|
}
|
|
|
|
func hasChange(dbReq UpdateUserRequest) bool {
|
|
if dbReq.Password != nil && dbReq.Password.NewPassword != "" {
|
|
return true
|
|
}
|
|
if dbReq.PublicKey != nil && len(dbReq.PublicKey.NewPublicKey) > 0 {
|
|
return true
|
|
}
|
|
if dbReq.Expiration != nil && !dbReq.Expiration.NewExpiration.IsZero() {
|
|
return true
|
|
}
|
|
return false
|
|
}
|
|
|
|
func (g *gRPCServer) DeleteUser(ctx context.Context, req *proto.DeleteUserRequest) (*proto.DeleteUserResponse, error) {
|
|
if req.GetUsername() == "" {
|
|
return &proto.DeleteUserResponse{}, status.Errorf(codes.InvalidArgument, "no username provided")
|
|
}
|
|
dbReq := DeleteUserRequest{
|
|
Username: req.GetUsername(),
|
|
Statements: getStatementsFromProto(req.GetStatements()),
|
|
}
|
|
|
|
impl, err := g.getDatabase(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
_, err = impl.DeleteUser(ctx, dbReq)
|
|
if err != nil {
|
|
return &proto.DeleteUserResponse{}, status.Errorf(codes.Internal, "unable to delete user: %s", err)
|
|
}
|
|
return &proto.DeleteUserResponse{}, nil
|
|
}
|
|
|
|
func (g *gRPCServer) Type(ctx context.Context, _ *proto.Empty) (*proto.TypeResponse, error) {
|
|
impl, err := g.getOrCreateDatabase(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
t, err := impl.Type()
|
|
if err != nil {
|
|
return &proto.TypeResponse{}, status.Errorf(codes.Internal, "unable to retrieve type: %s", err)
|
|
}
|
|
|
|
resp := &proto.TypeResponse{
|
|
Type: t,
|
|
}
|
|
return resp, nil
|
|
}
|
|
|
|
func (g *gRPCServer) Close(ctx context.Context, _ *proto.Empty) (*proto.Empty, error) {
|
|
g.Lock()
|
|
defer g.Unlock()
|
|
|
|
impl, err := g.getDatabaseInternal(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
err = impl.Close()
|
|
if err != nil {
|
|
return &proto.Empty{}, status.Errorf(codes.Internal, "unable to close database plugin: %s", err)
|
|
}
|
|
|
|
if g.singleImpl == nil {
|
|
// only cleanup instances map when multiplexing is supported
|
|
id, err := pluginutil.GetMultiplexIDFromContext(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
delete(g.instances, id)
|
|
}
|
|
|
|
return &proto.Empty{}, nil
|
|
}
|
|
|
|
// getOrForceCreateDatabase will create a database even if the multiplexing ID is not present
|
|
func (g *gRPCServer) getOrForceCreateDatabase(ctx context.Context) (Database, error) {
|
|
impl, err := g.getOrCreateDatabase(ctx)
|
|
if errors.Is(err, pluginutil.ErrNoMultiplexingIDFound) {
|
|
// if this is called without a multiplexing context, like from the plugin catalog directly,
|
|
// then we won't have a database ID, so let's generate a new database instance
|
|
id, err := base62.Random(10)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
g.Lock()
|
|
defer g.Unlock()
|
|
impl, err = g.createDatabase(id)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
} else if err != nil {
|
|
return nil, err
|
|
}
|
|
return impl, nil
|
|
}
|
|
|
|
// Version forwards the version request to the underlying Database implementation.
|
|
func (g *gRPCServer) Version(ctx context.Context, _ *logical.Empty) (*logical.VersionReply, error) {
|
|
impl, err := g.getOrForceCreateDatabase(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if versioner, ok := impl.(logical.PluginVersioner); ok {
|
|
return &logical.VersionReply{PluginVersion: versioner.PluginVersion().Version}, nil
|
|
}
|
|
return &logical.VersionReply{}, nil
|
|
}
|
|
|
|
func getStatementsFromProto(protoStmts *proto.Statements) (statements Statements) {
|
|
if protoStmts == nil {
|
|
return statements
|
|
}
|
|
cmds := protoStmts.GetCommands()
|
|
statements = Statements{
|
|
Commands: cmds,
|
|
}
|
|
return statements
|
|
}
|