Database gRPC plugins (#3666)
* Start work on context aware backends * Start work on moving the database plugins to gRPC in order to pass context * Add context to builtin database plugins * use byte slice instead of string * Context all the things * Move proto messages to the dbplugin package * Add a grpc mechanism for running backend plugins * Serve the GRPC plugin * Add backwards compatibility to the database plugins * Remove backend plugin changes * Remove backend plugin changes * Cleanup the transport implementations * If grpc connection is in an unexpected state restart the plugin * Fix tests * Fix tests * Remove context from the request object, replace it with context.TODO * Add a test to verify netRPC plugins still work * Remove unused mapstructure call * Code review fixes * Code review fixes * Code review fixes
This commit is contained in:
parent
829f2c38dc
commit
afe53eb862
1
Makefile
1
Makefile
|
@ -84,6 +84,7 @@ proto:
|
|||
protoc -I helper/forwarding -I vault -I ../../.. helper/forwarding/types.proto --go_out=plugins=grpc:helper/forwarding
|
||||
protoc -I physical physical/types.proto --go_out=plugins=grpc:physical
|
||||
protoc -I helper/identity -I ../../.. helper/identity/types.proto --go_out=plugins=grpc:helper/identity
|
||||
protoc builtin/logical/database/dbplugin/*.proto --go_out=plugins=grpc:.
|
||||
sed -i -e 's/Idp/IDP/' -e 's/Url/URL/' -e 's/Id/ID/' -e 's/EntityId/EntityID/' -e 's/Api/API/' -e 's/Qr/QR/' -e 's/protobuf:"/sentinel:"" protobuf:"/' helper/identity/types.pb.go helper/storagepacker/types.pb.go
|
||||
sed -i -e 's/Iv/IV/' -e 's/Hmac/HMAC/' physical/types.pb.go
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/rpc"
|
||||
"strings"
|
||||
|
@ -87,7 +88,7 @@ func (b *databaseBackend) getDBObj(name string) (dbplugin.Database, bool) {
|
|||
// This function creates a new db object from the stored configuration and
|
||||
// caches it in the connections map. The caller of this function needs to hold
|
||||
// the backend's write lock
|
||||
func (b *databaseBackend) createDBObj(s logical.Storage, name string) (dbplugin.Database, error) {
|
||||
func (b *databaseBackend) createDBObj(ctx context.Context, s logical.Storage, name string) (dbplugin.Database, error) {
|
||||
db, ok := b.connections[name]
|
||||
if ok {
|
||||
return db, nil
|
||||
|
@ -103,7 +104,7 @@ func (b *databaseBackend) createDBObj(s logical.Storage, name string) (dbplugin.
|
|||
return nil, err
|
||||
}
|
||||
|
||||
err = db.Initialize(config.ConnectionDetails, true)
|
||||
err = db.Initialize(ctx, config.ConnectionDetails, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -170,7 +171,8 @@ func (b *databaseBackend) clearConnection(name string) {
|
|||
|
||||
func (b *databaseBackend) closeIfShutdown(name string, err error) {
|
||||
// Plugin has shutdown, close it so next call can reconnect.
|
||||
if err == rpc.ErrShutdown {
|
||||
switch err {
|
||||
case rpc.ErrShutdown, dbplugin.ErrPluginShutdown:
|
||||
b.Lock()
|
||||
b.clearConnection(name)
|
||||
b.Unlock()
|
||||
|
|
|
@ -488,9 +488,11 @@ func TestBackend_roleCrud(t *testing.T) {
|
|||
RevocationStatements: defaultRevocationSQL,
|
||||
}
|
||||
|
||||
var actual dbplugin.Statements
|
||||
if err := mapstructure.Decode(resp.Data, &actual); err != nil {
|
||||
t.Fatal(err)
|
||||
actual := dbplugin.Statements{
|
||||
CreationStatements: resp.Data["creation_statements"].(string),
|
||||
RevocationStatements: resp.Data["revocation_statements"].(string),
|
||||
RollbackStatements: resp.Data["rollback_statements"].(string),
|
||||
RenewStatements: resp.Data["renew_statements"].(string),
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(expected, actual) {
|
||||
|
|
|
@ -1,10 +1,8 @@
|
|||
package dbplugin
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/rpc"
|
||||
"errors"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-plugin"
|
||||
"github.com/hashicorp/vault/helper/pluginutil"
|
||||
|
@ -17,11 +15,11 @@ type DatabasePluginClient struct {
|
|||
client *plugin.Client
|
||||
sync.Mutex
|
||||
|
||||
*databasePluginRPCClient
|
||||
Database
|
||||
}
|
||||
|
||||
func (dc *DatabasePluginClient) Close() error {
|
||||
err := dc.databasePluginRPCClient.Close()
|
||||
err := dc.Database.Close()
|
||||
dc.client.Kill()
|
||||
|
||||
return err
|
||||
|
@ -55,79 +53,20 @@ func newPluginClient(sys pluginutil.RunnerUtil, pluginRunner *pluginutil.PluginR
|
|||
|
||||
// We should have a database type now. This feels like a normal interface
|
||||
// implementation but is in fact over an RPC connection.
|
||||
databaseRPC := raw.(*databasePluginRPCClient)
|
||||
var db Database
|
||||
switch raw.(type) {
|
||||
case *gRPCClient:
|
||||
db = raw.(*gRPCClient)
|
||||
case *databasePluginRPCClient:
|
||||
logger.Warn("database: plugin is using deprecated net RPC transport, recompile plugin to upgrade to gRPC", "plugin", pluginRunner.Name)
|
||||
db = raw.(*databasePluginRPCClient)
|
||||
default:
|
||||
return nil, errors.New("unsupported client type")
|
||||
}
|
||||
|
||||
// Wrap RPC implimentation in DatabasePluginClient
|
||||
return &DatabasePluginClient{
|
||||
client: client,
|
||||
databasePluginRPCClient: databaseRPC,
|
||||
client: client,
|
||||
Database: db,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ---- RPC client domain ----
|
||||
|
||||
// databasePluginRPCClient implements Database and is used on the client to
|
||||
// make RPC calls to a plugin.
|
||||
type databasePluginRPCClient struct {
|
||||
client *rpc.Client
|
||||
}
|
||||
|
||||
func (dr *databasePluginRPCClient) Type() (string, error) {
|
||||
var dbType string
|
||||
err := dr.client.Call("Plugin.Type", struct{}{}, &dbType)
|
||||
|
||||
return fmt.Sprintf("plugin-%s", dbType), err
|
||||
}
|
||||
|
||||
func (dr *databasePluginRPCClient) CreateUser(statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) {
|
||||
req := CreateUserRequest{
|
||||
Statements: statements,
|
||||
UsernameConfig: usernameConfig,
|
||||
Expiration: expiration,
|
||||
}
|
||||
|
||||
var resp CreateUserResponse
|
||||
err = dr.client.Call("Plugin.CreateUser", req, &resp)
|
||||
|
||||
return resp.Username, resp.Password, err
|
||||
}
|
||||
|
||||
func (dr *databasePluginRPCClient) RenewUser(statements Statements, username string, expiration time.Time) error {
|
||||
req := RenewUserRequest{
|
||||
Statements: statements,
|
||||
Username: username,
|
||||
Expiration: expiration,
|
||||
}
|
||||
|
||||
err := dr.client.Call("Plugin.RenewUser", req, &struct{}{})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (dr *databasePluginRPCClient) RevokeUser(statements Statements, username string) error {
|
||||
req := RevokeUserRequest{
|
||||
Statements: statements,
|
||||
Username: username,
|
||||
}
|
||||
|
||||
err := dr.client.Call("Plugin.RevokeUser", req, &struct{}{})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (dr *databasePluginRPCClient) Initialize(conf map[string]interface{}, verifyConnection bool) error {
|
||||
req := InitializeRequest{
|
||||
Config: conf,
|
||||
VerifyConnection: verifyConnection,
|
||||
}
|
||||
|
||||
err := dr.client.Call("Plugin.Initialize", req, &struct{}{})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (dr *databasePluginRPCClient) Close() error {
|
||||
err := dr.client.Call("Plugin.Close", struct{}{}, &struct{}{})
|
||||
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -0,0 +1,556 @@
|
|||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||
// source: builtin/logical/database/dbplugin/database.proto
|
||||
|
||||
/*
|
||||
Package dbplugin is a generated protocol buffer package.
|
||||
|
||||
It is generated from these files:
|
||||
builtin/logical/database/dbplugin/database.proto
|
||||
|
||||
It has these top-level messages:
|
||||
InitializeRequest
|
||||
CreateUserRequest
|
||||
RenewUserRequest
|
||||
RevokeUserRequest
|
||||
Statements
|
||||
UsernameConfig
|
||||
CreateUserResponse
|
||||
TypeResponse
|
||||
Empty
|
||||
*/
|
||||
package dbplugin
|
||||
|
||||
import proto "github.com/golang/protobuf/proto"
|
||||
import fmt "fmt"
|
||||
import math "math"
|
||||
import google_protobuf "github.com/golang/protobuf/ptypes/timestamp"
|
||||
|
||||
import (
|
||||
context "golang.org/x/net/context"
|
||||
grpc "google.golang.org/grpc"
|
||||
)
|
||||
|
||||
// Reference imports to suppress errors if they are not otherwise used.
|
||||
var _ = proto.Marshal
|
||||
var _ = fmt.Errorf
|
||||
var _ = math.Inf
|
||||
|
||||
// This is a compile-time assertion to ensure that this generated file
|
||||
// is compatible with the proto package it is being compiled against.
|
||||
// A compilation error at this line likely means your copy of the
|
||||
// proto package needs to be updated.
|
||||
const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package
|
||||
|
||||
type InitializeRequest struct {
|
||||
Config []byte `protobuf:"bytes,1,opt,name=config,proto3" json:"config,omitempty"`
|
||||
VerifyConnection bool `protobuf:"varint,2,opt,name=verify_connection,json=verifyConnection" json:"verify_connection,omitempty"`
|
||||
}
|
||||
|
||||
func (m *InitializeRequest) Reset() { *m = InitializeRequest{} }
|
||||
func (m *InitializeRequest) String() string { return proto.CompactTextString(m) }
|
||||
func (*InitializeRequest) ProtoMessage() {}
|
||||
func (*InitializeRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{0} }
|
||||
|
||||
func (m *InitializeRequest) GetConfig() []byte {
|
||||
if m != nil {
|
||||
return m.Config
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *InitializeRequest) GetVerifyConnection() bool {
|
||||
if m != nil {
|
||||
return m.VerifyConnection
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type CreateUserRequest struct {
|
||||
Statements *Statements `protobuf:"bytes,1,opt,name=statements" json:"statements,omitempty"`
|
||||
UsernameConfig *UsernameConfig `protobuf:"bytes,2,opt,name=username_config,json=usernameConfig" json:"username_config,omitempty"`
|
||||
Expiration *google_protobuf.Timestamp `protobuf:"bytes,3,opt,name=expiration" json:"expiration,omitempty"`
|
||||
}
|
||||
|
||||
func (m *CreateUserRequest) Reset() { *m = CreateUserRequest{} }
|
||||
func (m *CreateUserRequest) String() string { return proto.CompactTextString(m) }
|
||||
func (*CreateUserRequest) ProtoMessage() {}
|
||||
func (*CreateUserRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{1} }
|
||||
|
||||
func (m *CreateUserRequest) GetStatements() *Statements {
|
||||
if m != nil {
|
||||
return m.Statements
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *CreateUserRequest) GetUsernameConfig() *UsernameConfig {
|
||||
if m != nil {
|
||||
return m.UsernameConfig
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *CreateUserRequest) GetExpiration() *google_protobuf.Timestamp {
|
||||
if m != nil {
|
||||
return m.Expiration
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type RenewUserRequest struct {
|
||||
Statements *Statements `protobuf:"bytes,1,opt,name=statements" json:"statements,omitempty"`
|
||||
Username string `protobuf:"bytes,2,opt,name=username" json:"username,omitempty"`
|
||||
Expiration *google_protobuf.Timestamp `protobuf:"bytes,3,opt,name=expiration" json:"expiration,omitempty"`
|
||||
}
|
||||
|
||||
func (m *RenewUserRequest) Reset() { *m = RenewUserRequest{} }
|
||||
func (m *RenewUserRequest) String() string { return proto.CompactTextString(m) }
|
||||
func (*RenewUserRequest) ProtoMessage() {}
|
||||
func (*RenewUserRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{2} }
|
||||
|
||||
func (m *RenewUserRequest) GetStatements() *Statements {
|
||||
if m != nil {
|
||||
return m.Statements
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *RenewUserRequest) GetUsername() string {
|
||||
if m != nil {
|
||||
return m.Username
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (m *RenewUserRequest) GetExpiration() *google_protobuf.Timestamp {
|
||||
if m != nil {
|
||||
return m.Expiration
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type RevokeUserRequest struct {
|
||||
Statements *Statements `protobuf:"bytes,1,opt,name=statements" json:"statements,omitempty"`
|
||||
Username string `protobuf:"bytes,2,opt,name=username" json:"username,omitempty"`
|
||||
}
|
||||
|
||||
func (m *RevokeUserRequest) Reset() { *m = RevokeUserRequest{} }
|
||||
func (m *RevokeUserRequest) String() string { return proto.CompactTextString(m) }
|
||||
func (*RevokeUserRequest) ProtoMessage() {}
|
||||
func (*RevokeUserRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{3} }
|
||||
|
||||
func (m *RevokeUserRequest) GetStatements() *Statements {
|
||||
if m != nil {
|
||||
return m.Statements
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *RevokeUserRequest) GetUsername() string {
|
||||
if m != nil {
|
||||
return m.Username
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
type Statements struct {
|
||||
CreationStatements string `protobuf:"bytes,1,opt,name=creation_statements,json=creationStatements" json:"creation_statements,omitempty"`
|
||||
RevocationStatements string `protobuf:"bytes,2,opt,name=revocation_statements,json=revocationStatements" json:"revocation_statements,omitempty"`
|
||||
RollbackStatements string `protobuf:"bytes,3,opt,name=rollback_statements,json=rollbackStatements" json:"rollback_statements,omitempty"`
|
||||
RenewStatements string `protobuf:"bytes,4,opt,name=renew_statements,json=renewStatements" json:"renew_statements,omitempty"`
|
||||
}
|
||||
|
||||
func (m *Statements) Reset() { *m = Statements{} }
|
||||
func (m *Statements) String() string { return proto.CompactTextString(m) }
|
||||
func (*Statements) ProtoMessage() {}
|
||||
func (*Statements) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{4} }
|
||||
|
||||
func (m *Statements) GetCreationStatements() string {
|
||||
if m != nil {
|
||||
return m.CreationStatements
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (m *Statements) GetRevocationStatements() string {
|
||||
if m != nil {
|
||||
return m.RevocationStatements
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (m *Statements) GetRollbackStatements() string {
|
||||
if m != nil {
|
||||
return m.RollbackStatements
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (m *Statements) GetRenewStatements() string {
|
||||
if m != nil {
|
||||
return m.RenewStatements
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
type UsernameConfig struct {
|
||||
DisplayName string `protobuf:"bytes,1,opt,name=DisplayName" json:"DisplayName,omitempty"`
|
||||
RoleName string `protobuf:"bytes,2,opt,name=RoleName" json:"RoleName,omitempty"`
|
||||
}
|
||||
|
||||
func (m *UsernameConfig) Reset() { *m = UsernameConfig{} }
|
||||
func (m *UsernameConfig) String() string { return proto.CompactTextString(m) }
|
||||
func (*UsernameConfig) ProtoMessage() {}
|
||||
func (*UsernameConfig) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{5} }
|
||||
|
||||
func (m *UsernameConfig) GetDisplayName() string {
|
||||
if m != nil {
|
||||
return m.DisplayName
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (m *UsernameConfig) GetRoleName() string {
|
||||
if m != nil {
|
||||
return m.RoleName
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
type CreateUserResponse struct {
|
||||
Username string `protobuf:"bytes,1,opt,name=username" json:"username,omitempty"`
|
||||
Password string `protobuf:"bytes,2,opt,name=password" json:"password,omitempty"`
|
||||
}
|
||||
|
||||
func (m *CreateUserResponse) Reset() { *m = CreateUserResponse{} }
|
||||
func (m *CreateUserResponse) String() string { return proto.CompactTextString(m) }
|
||||
func (*CreateUserResponse) ProtoMessage() {}
|
||||
func (*CreateUserResponse) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{6} }
|
||||
|
||||
func (m *CreateUserResponse) GetUsername() string {
|
||||
if m != nil {
|
||||
return m.Username
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (m *CreateUserResponse) GetPassword() string {
|
||||
if m != nil {
|
||||
return m.Password
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
type TypeResponse struct {
|
||||
Type string `protobuf:"bytes,1,opt,name=type" json:"type,omitempty"`
|
||||
}
|
||||
|
||||
func (m *TypeResponse) Reset() { *m = TypeResponse{} }
|
||||
func (m *TypeResponse) String() string { return proto.CompactTextString(m) }
|
||||
func (*TypeResponse) ProtoMessage() {}
|
||||
func (*TypeResponse) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{7} }
|
||||
|
||||
func (m *TypeResponse) GetType() string {
|
||||
if m != nil {
|
||||
return m.Type
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
type Empty struct {
|
||||
}
|
||||
|
||||
func (m *Empty) Reset() { *m = Empty{} }
|
||||
func (m *Empty) String() string { return proto.CompactTextString(m) }
|
||||
func (*Empty) ProtoMessage() {}
|
||||
func (*Empty) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{8} }
|
||||
|
||||
func init() {
|
||||
proto.RegisterType((*InitializeRequest)(nil), "dbplugin.InitializeRequest")
|
||||
proto.RegisterType((*CreateUserRequest)(nil), "dbplugin.CreateUserRequest")
|
||||
proto.RegisterType((*RenewUserRequest)(nil), "dbplugin.RenewUserRequest")
|
||||
proto.RegisterType((*RevokeUserRequest)(nil), "dbplugin.RevokeUserRequest")
|
||||
proto.RegisterType((*Statements)(nil), "dbplugin.Statements")
|
||||
proto.RegisterType((*UsernameConfig)(nil), "dbplugin.UsernameConfig")
|
||||
proto.RegisterType((*CreateUserResponse)(nil), "dbplugin.CreateUserResponse")
|
||||
proto.RegisterType((*TypeResponse)(nil), "dbplugin.TypeResponse")
|
||||
proto.RegisterType((*Empty)(nil), "dbplugin.Empty")
|
||||
}
|
||||
|
||||
// Reference imports to suppress errors if they are not otherwise used.
|
||||
var _ context.Context
|
||||
var _ grpc.ClientConn
|
||||
|
||||
// This is a compile-time assertion to ensure that this generated file
|
||||
// is compatible with the grpc package it is being compiled against.
|
||||
const _ = grpc.SupportPackageIsVersion4
|
||||
|
||||
// Client API for Database service
|
||||
|
||||
type DatabaseClient interface {
|
||||
Type(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*TypeResponse, error)
|
||||
CreateUser(ctx context.Context, in *CreateUserRequest, opts ...grpc.CallOption) (*CreateUserResponse, error)
|
||||
RenewUser(ctx context.Context, in *RenewUserRequest, opts ...grpc.CallOption) (*Empty, error)
|
||||
RevokeUser(ctx context.Context, in *RevokeUserRequest, opts ...grpc.CallOption) (*Empty, error)
|
||||
Initialize(ctx context.Context, in *InitializeRequest, opts ...grpc.CallOption) (*Empty, error)
|
||||
Close(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*Empty, error)
|
||||
}
|
||||
|
||||
type databaseClient struct {
|
||||
cc *grpc.ClientConn
|
||||
}
|
||||
|
||||
func NewDatabaseClient(cc *grpc.ClientConn) DatabaseClient {
|
||||
return &databaseClient{cc}
|
||||
}
|
||||
|
||||
func (c *databaseClient) Type(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*TypeResponse, error) {
|
||||
out := new(TypeResponse)
|
||||
err := grpc.Invoke(ctx, "/dbplugin.Database/Type", in, out, c.cc, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *databaseClient) CreateUser(ctx context.Context, in *CreateUserRequest, opts ...grpc.CallOption) (*CreateUserResponse, error) {
|
||||
out := new(CreateUserResponse)
|
||||
err := grpc.Invoke(ctx, "/dbplugin.Database/CreateUser", in, out, c.cc, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *databaseClient) RenewUser(ctx context.Context, in *RenewUserRequest, opts ...grpc.CallOption) (*Empty, error) {
|
||||
out := new(Empty)
|
||||
err := grpc.Invoke(ctx, "/dbplugin.Database/RenewUser", in, out, c.cc, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *databaseClient) RevokeUser(ctx context.Context, in *RevokeUserRequest, opts ...grpc.CallOption) (*Empty, error) {
|
||||
out := new(Empty)
|
||||
err := grpc.Invoke(ctx, "/dbplugin.Database/RevokeUser", in, out, c.cc, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *databaseClient) Initialize(ctx context.Context, in *InitializeRequest, opts ...grpc.CallOption) (*Empty, error) {
|
||||
out := new(Empty)
|
||||
err := grpc.Invoke(ctx, "/dbplugin.Database/Initialize", in, out, c.cc, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *databaseClient) Close(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*Empty, error) {
|
||||
out := new(Empty)
|
||||
err := grpc.Invoke(ctx, "/dbplugin.Database/Close", in, out, c.cc, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// Server API for Database service
|
||||
|
||||
type DatabaseServer interface {
|
||||
Type(context.Context, *Empty) (*TypeResponse, error)
|
||||
CreateUser(context.Context, *CreateUserRequest) (*CreateUserResponse, error)
|
||||
RenewUser(context.Context, *RenewUserRequest) (*Empty, error)
|
||||
RevokeUser(context.Context, *RevokeUserRequest) (*Empty, error)
|
||||
Initialize(context.Context, *InitializeRequest) (*Empty, error)
|
||||
Close(context.Context, *Empty) (*Empty, error)
|
||||
}
|
||||
|
||||
func RegisterDatabaseServer(s *grpc.Server, srv DatabaseServer) {
|
||||
s.RegisterService(&_Database_serviceDesc, srv)
|
||||
}
|
||||
|
||||
func _Database_Type_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(Empty)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(DatabaseServer).Type(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/dbplugin.Database/Type",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(DatabaseServer).Type(ctx, req.(*Empty))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _Database_CreateUser_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(CreateUserRequest)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(DatabaseServer).CreateUser(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/dbplugin.Database/CreateUser",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(DatabaseServer).CreateUser(ctx, req.(*CreateUserRequest))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _Database_RenewUser_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(RenewUserRequest)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(DatabaseServer).RenewUser(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/dbplugin.Database/RenewUser",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(DatabaseServer).RenewUser(ctx, req.(*RenewUserRequest))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _Database_RevokeUser_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(RevokeUserRequest)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(DatabaseServer).RevokeUser(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/dbplugin.Database/RevokeUser",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(DatabaseServer).RevokeUser(ctx, req.(*RevokeUserRequest))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _Database_Initialize_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(InitializeRequest)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(DatabaseServer).Initialize(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/dbplugin.Database/Initialize",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(DatabaseServer).Initialize(ctx, req.(*InitializeRequest))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _Database_Close_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(Empty)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(DatabaseServer).Close(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/dbplugin.Database/Close",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(DatabaseServer).Close(ctx, req.(*Empty))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
var _Database_serviceDesc = grpc.ServiceDesc{
|
||||
ServiceName: "dbplugin.Database",
|
||||
HandlerType: (*DatabaseServer)(nil),
|
||||
Methods: []grpc.MethodDesc{
|
||||
{
|
||||
MethodName: "Type",
|
||||
Handler: _Database_Type_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "CreateUser",
|
||||
Handler: _Database_CreateUser_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "RenewUser",
|
||||
Handler: _Database_RenewUser_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "RevokeUser",
|
||||
Handler: _Database_RevokeUser_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "Initialize",
|
||||
Handler: _Database_Initialize_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "Close",
|
||||
Handler: _Database_Close_Handler,
|
||||
},
|
||||
},
|
||||
Streams: []grpc.StreamDesc{},
|
||||
Metadata: "builtin/logical/database/dbplugin/database.proto",
|
||||
}
|
||||
|
||||
func init() { proto.RegisterFile("builtin/logical/database/dbplugin/database.proto", fileDescriptor0) }
|
||||
|
||||
var fileDescriptor0 = []byte{
|
||||
// 548 bytes of a gzipped FileDescriptorProto
|
||||
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xb4, 0x54, 0xcf, 0x6e, 0xd3, 0x4e,
|
||||
0x10, 0x96, 0xdb, 0xb4, 0xbf, 0x64, 0x5a, 0x35, 0xc9, 0xfe, 0x4a, 0x15, 0x19, 0x24, 0x22, 0x9f,
|
||||
0x5a, 0x21, 0xd9, 0xa8, 0xe5, 0x80, 0xb8, 0xa1, 0x14, 0x21, 0x24, 0x94, 0x83, 0x69, 0x25, 0x6e,
|
||||
0xd1, 0xda, 0x99, 0x44, 0xab, 0x3a, 0xbb, 0xc6, 0xbb, 0x4e, 0x09, 0x4f, 0xc3, 0xe3, 0x70, 0xe2,
|
||||
0x1d, 0x78, 0x13, 0xe4, 0x75, 0xd6, 0xbb, 0xf9, 0x73, 0xab, 0xb8, 0x79, 0x66, 0xbe, 0x6f, 0xf6,
|
||||
0xf3, 0xb7, 0x33, 0x0b, 0xaf, 0x93, 0x92, 0x65, 0x8a, 0xf1, 0x28, 0x13, 0x73, 0x96, 0xd2, 0x2c,
|
||||
0x9a, 0x52, 0x45, 0x13, 0x2a, 0x31, 0x9a, 0x26, 0x79, 0x56, 0xce, 0x19, 0x6f, 0x32, 0x61, 0x5e,
|
||||
0x08, 0x25, 0x48, 0xdb, 0x14, 0xfc, 0x97, 0x73, 0x21, 0xe6, 0x19, 0x46, 0x3a, 0x9f, 0x94, 0xb3,
|
||||
0x48, 0xb1, 0x05, 0x4a, 0x45, 0x17, 0x79, 0x0d, 0x0d, 0xbe, 0x42, 0xff, 0x13, 0x67, 0x8a, 0xd1,
|
||||
0x8c, 0xfd, 0xc0, 0x18, 0xbf, 0x95, 0x28, 0x15, 0xb9, 0x80, 0xe3, 0x54, 0xf0, 0x19, 0x9b, 0x0f,
|
||||
0xbc, 0xa1, 0x77, 0x79, 0x1a, 0xaf, 0x23, 0xf2, 0x0a, 0xfa, 0x4b, 0x2c, 0xd8, 0x6c, 0x35, 0x49,
|
||||
0x05, 0xe7, 0x98, 0x2a, 0x26, 0xf8, 0xe0, 0x60, 0xe8, 0x5d, 0xb6, 0xe3, 0x5e, 0x5d, 0x18, 0x35,
|
||||
0xf9, 0xe0, 0x97, 0x07, 0xfd, 0x51, 0x81, 0x54, 0xe1, 0xbd, 0xc4, 0xc2, 0xb4, 0x7e, 0x03, 0x20,
|
||||
0x15, 0x55, 0xb8, 0x40, 0xae, 0xa4, 0x6e, 0x7f, 0x72, 0x7d, 0x1e, 0x1a, 0xbd, 0xe1, 0x97, 0xa6,
|
||||
0x16, 0x3b, 0x38, 0xf2, 0x1e, 0xba, 0xa5, 0xc4, 0x82, 0xd3, 0x05, 0x4e, 0xd6, 0xca, 0x0e, 0x34,
|
||||
0x75, 0x60, 0xa9, 0xf7, 0x6b, 0xc0, 0x48, 0xd7, 0xe3, 0xb3, 0x72, 0x23, 0x26, 0xef, 0x00, 0xf0,
|
||||
0x7b, 0xce, 0x0a, 0xaa, 0x45, 0x1f, 0x6a, 0xb6, 0x1f, 0xd6, 0xf6, 0x84, 0xc6, 0x9e, 0xf0, 0xce,
|
||||
0xd8, 0x13, 0x3b, 0xe8, 0xe0, 0xa7, 0x07, 0xbd, 0x18, 0x39, 0x3e, 0x3e, 0xfd, 0x4f, 0x7c, 0x68,
|
||||
0x1b, 0x61, 0xfa, 0x17, 0x3a, 0x71, 0x13, 0x3f, 0x49, 0x22, 0x42, 0x3f, 0xc6, 0xa5, 0x78, 0xc0,
|
||||
0x7f, 0x2a, 0x31, 0xf8, 0xed, 0x01, 0x58, 0x1a, 0x89, 0xe0, 0xff, 0xb4, 0xba, 0x62, 0x26, 0xf8,
|
||||
0x64, 0xeb, 0xa4, 0x4e, 0x4c, 0x4c, 0xc9, 0x21, 0xdc, 0xc0, 0xb3, 0x02, 0x97, 0x22, 0xdd, 0xa1,
|
||||
0xd4, 0x07, 0x9d, 0xdb, 0xe2, 0xe6, 0x29, 0x85, 0xc8, 0xb2, 0x84, 0xa6, 0x0f, 0x2e, 0xe5, 0xb0,
|
||||
0x3e, 0xc5, 0x94, 0x1c, 0xc2, 0x15, 0xf4, 0x8a, 0xea, 0xba, 0x5c, 0x74, 0x4b, 0xa3, 0xbb, 0x3a,
|
||||
0x6f, 0xa1, 0xc1, 0x18, 0xce, 0x36, 0x07, 0x87, 0x0c, 0xe1, 0xe4, 0x96, 0xc9, 0x3c, 0xa3, 0xab,
|
||||
0x71, 0xe5, 0x40, 0xfd, 0x2f, 0x6e, 0xaa, 0x32, 0x28, 0x16, 0x19, 0x8e, 0x1d, 0x83, 0x4c, 0x1c,
|
||||
0x7c, 0x06, 0xe2, 0x0e, 0xbd, 0xcc, 0x05, 0x97, 0xb8, 0x61, 0xa9, 0xb7, 0x75, 0xeb, 0x3e, 0xb4,
|
||||
0x73, 0x2a, 0xe5, 0xa3, 0x28, 0xa6, 0xa6, 0x9b, 0x89, 0x83, 0x00, 0x4e, 0xef, 0x56, 0x39, 0x36,
|
||||
0x7d, 0x08, 0xb4, 0xd4, 0x2a, 0x37, 0x3d, 0xf4, 0x77, 0xf0, 0x1f, 0x1c, 0x7d, 0x58, 0xe4, 0x6a,
|
||||
0x75, 0xfd, 0xe7, 0x00, 0xda, 0xb7, 0xeb, 0x87, 0x80, 0x44, 0xd0, 0xaa, 0x98, 0xa4, 0x6b, 0xaf,
|
||||
0x5b, 0xa3, 0xfc, 0x0b, 0x9b, 0xd8, 0x68, 0xfd, 0x11, 0xc0, 0x0a, 0x27, 0xcf, 0x2d, 0x6a, 0x67,
|
||||
0x87, 0xfd, 0x17, 0xfb, 0x8b, 0xeb, 0x46, 0x6f, 0xa1, 0xd3, 0xec, 0x0a, 0xf1, 0x2d, 0x74, 0x7b,
|
||||
0x81, 0xfc, 0x6d, 0x69, 0xd5, 0xfc, 0xdb, 0x19, 0x76, 0x25, 0xec, 0x4c, 0xf6, 0x5e, 0xae, 0x7d,
|
||||
0xc7, 0x5c, 0xee, 0xce, 0xeb, 0xb6, 0xcb, 0xbd, 0x82, 0xa3, 0x51, 0x26, 0xe4, 0x1e, 0xb3, 0xb6,
|
||||
0x13, 0xc9, 0xb1, 0x5e, 0xc3, 0x9b, 0xbf, 0x01, 0x00, 0x00, 0xff, 0xff, 0x8c, 0x55, 0x84, 0x56,
|
||||
0x94, 0x05, 0x00, 0x00,
|
||||
}
|
|
@ -0,0 +1,58 @@
|
|||
syntax = "proto3";
|
||||
package dbplugin;
|
||||
|
||||
import "google/protobuf/timestamp.proto";
|
||||
|
||||
message InitializeRequest {
|
||||
bytes config = 1;
|
||||
bool verify_connection = 2;
|
||||
}
|
||||
|
||||
message CreateUserRequest {
|
||||
Statements statements = 1;
|
||||
UsernameConfig username_config = 2;
|
||||
google.protobuf.Timestamp expiration = 3;
|
||||
}
|
||||
|
||||
message RenewUserRequest {
|
||||
Statements statements = 1;
|
||||
string username = 2;
|
||||
google.protobuf.Timestamp expiration = 3;
|
||||
}
|
||||
|
||||
message RevokeUserRequest {
|
||||
Statements statements = 1;
|
||||
string username = 2;
|
||||
}
|
||||
|
||||
message Statements {
|
||||
string creation_statements = 1;
|
||||
string revocation_statements = 2;
|
||||
string rollback_statements = 3;
|
||||
string renew_statements = 4;
|
||||
}
|
||||
|
||||
message UsernameConfig {
|
||||
string DisplayName = 1;
|
||||
string RoleName = 2;
|
||||
}
|
||||
|
||||
message CreateUserResponse {
|
||||
string username = 1;
|
||||
string password = 2;
|
||||
}
|
||||
|
||||
message TypeResponse {
|
||||
string type = 1;
|
||||
}
|
||||
|
||||
message Empty {}
|
||||
|
||||
service Database {
|
||||
rpc Type(Empty) returns (TypeResponse);
|
||||
rpc CreateUser(CreateUserRequest) returns (CreateUserResponse);
|
||||
rpc RenewUser(RenewUserRequest) returns (Empty);
|
||||
rpc RevokeUser(RevokeUserRequest) returns (Empty);
|
||||
rpc Initialize(InitializeRequest) returns (Empty);
|
||||
rpc Close(Empty) returns (Empty);
|
||||
}
|
|
@ -1,6 +1,7 @@
|
|||
package dbplugin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
metrics "github.com/armon/go-metrics"
|
||||
|
@ -15,55 +16,56 @@ type databaseTracingMiddleware struct {
|
|||
next Database
|
||||
logger log.Logger
|
||||
|
||||
typeStr string
|
||||
typeStr string
|
||||
transport string
|
||||
}
|
||||
|
||||
func (mw *databaseTracingMiddleware) Type() (string, error) {
|
||||
return mw.next.Type()
|
||||
}
|
||||
|
||||
func (mw *databaseTracingMiddleware) CreateUser(statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) {
|
||||
func (mw *databaseTracingMiddleware) CreateUser(ctx context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) {
|
||||
defer func(then time.Time) {
|
||||
mw.logger.Trace("database", "operation", "CreateUser", "status", "finished", "type", mw.typeStr, "err", err, "took", time.Since(then))
|
||||
mw.logger.Trace("database", "operation", "CreateUser", "status", "finished", "type", mw.typeStr, "transport", mw.transport, "err", err, "took", time.Since(then))
|
||||
}(time.Now())
|
||||
|
||||
mw.logger.Trace("database", "operation", "CreateUser", "status", "started", "type", mw.typeStr)
|
||||
return mw.next.CreateUser(statements, usernameConfig, expiration)
|
||||
mw.logger.Trace("database", "operation", "CreateUser", "status", "started", "type", mw.typeStr, "transport", mw.transport)
|
||||
return mw.next.CreateUser(ctx, statements, usernameConfig, expiration)
|
||||
}
|
||||
|
||||
func (mw *databaseTracingMiddleware) RenewUser(statements Statements, username string, expiration time.Time) (err error) {
|
||||
func (mw *databaseTracingMiddleware) RenewUser(ctx context.Context, statements Statements, username string, expiration time.Time) (err error) {
|
||||
defer func(then time.Time) {
|
||||
mw.logger.Trace("database", "operation", "RenewUser", "status", "finished", "type", mw.typeStr, "err", err, "took", time.Since(then))
|
||||
mw.logger.Trace("database", "operation", "RenewUser", "status", "finished", "type", mw.typeStr, "transport", mw.transport, "err", err, "took", time.Since(then))
|
||||
}(time.Now())
|
||||
|
||||
mw.logger.Trace("database", "operation", "RenewUser", "status", "started", mw.typeStr)
|
||||
return mw.next.RenewUser(statements, username, expiration)
|
||||
mw.logger.Trace("database", "operation", "RenewUser", "status", "started", mw.typeStr, "transport", mw.transport)
|
||||
return mw.next.RenewUser(ctx, statements, username, expiration)
|
||||
}
|
||||
|
||||
func (mw *databaseTracingMiddleware) RevokeUser(statements Statements, username string) (err error) {
|
||||
func (mw *databaseTracingMiddleware) RevokeUser(ctx context.Context, statements Statements, username string) (err error) {
|
||||
defer func(then time.Time) {
|
||||
mw.logger.Trace("database", "operation", "RevokeUser", "status", "finished", "type", mw.typeStr, "err", err, "took", time.Since(then))
|
||||
mw.logger.Trace("database", "operation", "RevokeUser", "status", "finished", "type", mw.typeStr, "transport", mw.transport, "err", err, "took", time.Since(then))
|
||||
}(time.Now())
|
||||
|
||||
mw.logger.Trace("database", "operation", "RevokeUser", "status", "started", "type", mw.typeStr)
|
||||
return mw.next.RevokeUser(statements, username)
|
||||
mw.logger.Trace("database", "operation", "RevokeUser", "status", "started", "type", mw.typeStr, "transport", mw.transport)
|
||||
return mw.next.RevokeUser(ctx, statements, username)
|
||||
}
|
||||
|
||||
func (mw *databaseTracingMiddleware) Initialize(conf map[string]interface{}, verifyConnection bool) (err error) {
|
||||
func (mw *databaseTracingMiddleware) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (err error) {
|
||||
defer func(then time.Time) {
|
||||
mw.logger.Trace("database", "operation", "Initialize", "status", "finished", "type", mw.typeStr, "verify", verifyConnection, "err", err, "took", time.Since(then))
|
||||
mw.logger.Trace("database", "operation", "Initialize", "status", "finished", "type", mw.typeStr, "transport", mw.transport, "verify", verifyConnection, "err", err, "took", time.Since(then))
|
||||
}(time.Now())
|
||||
|
||||
mw.logger.Trace("database", "operation", "Initialize", "status", "started", "type", mw.typeStr)
|
||||
return mw.next.Initialize(conf, verifyConnection)
|
||||
mw.logger.Trace("database", "operation", "Initialize", "status", "started", "type", mw.typeStr, "transport", mw.transport)
|
||||
return mw.next.Initialize(ctx, conf, verifyConnection)
|
||||
}
|
||||
|
||||
func (mw *databaseTracingMiddleware) Close() (err error) {
|
||||
defer func(then time.Time) {
|
||||
mw.logger.Trace("database", "operation", "Close", "status", "finished", "type", mw.typeStr, "err", err, "took", time.Since(then))
|
||||
mw.logger.Trace("database", "operation", "Close", "status", "finished", "type", mw.typeStr, "transport", mw.transport, "err", err, "took", time.Since(then))
|
||||
}(time.Now())
|
||||
|
||||
mw.logger.Trace("database", "operation", "Close", "status", "started", "type", mw.typeStr)
|
||||
mw.logger.Trace("database", "operation", "Close", "status", "started", "type", mw.typeStr, "transport", mw.transport)
|
||||
return mw.next.Close()
|
||||
}
|
||||
|
||||
|
@ -81,7 +83,7 @@ func (mw *databaseMetricsMiddleware) Type() (string, error) {
|
|||
return mw.next.Type()
|
||||
}
|
||||
|
||||
func (mw *databaseMetricsMiddleware) CreateUser(statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) {
|
||||
func (mw *databaseMetricsMiddleware) CreateUser(ctx context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) {
|
||||
defer func(now time.Time) {
|
||||
metrics.MeasureSince([]string{"database", "CreateUser"}, now)
|
||||
metrics.MeasureSince([]string{"database", mw.typeStr, "CreateUser"}, now)
|
||||
|
@ -94,10 +96,10 @@ func (mw *databaseMetricsMiddleware) CreateUser(statements Statements, usernameC
|
|||
|
||||
metrics.IncrCounter([]string{"database", "CreateUser"}, 1)
|
||||
metrics.IncrCounter([]string{"database", mw.typeStr, "CreateUser"}, 1)
|
||||
return mw.next.CreateUser(statements, usernameConfig, expiration)
|
||||
return mw.next.CreateUser(ctx, statements, usernameConfig, expiration)
|
||||
}
|
||||
|
||||
func (mw *databaseMetricsMiddleware) RenewUser(statements Statements, username string, expiration time.Time) (err error) {
|
||||
func (mw *databaseMetricsMiddleware) RenewUser(ctx context.Context, statements Statements, username string, expiration time.Time) (err error) {
|
||||
defer func(now time.Time) {
|
||||
metrics.MeasureSince([]string{"database", "RenewUser"}, now)
|
||||
metrics.MeasureSince([]string{"database", mw.typeStr, "RenewUser"}, now)
|
||||
|
@ -110,10 +112,10 @@ func (mw *databaseMetricsMiddleware) RenewUser(statements Statements, username s
|
|||
|
||||
metrics.IncrCounter([]string{"database", "RenewUser"}, 1)
|
||||
metrics.IncrCounter([]string{"database", mw.typeStr, "RenewUser"}, 1)
|
||||
return mw.next.RenewUser(statements, username, expiration)
|
||||
return mw.next.RenewUser(ctx, statements, username, expiration)
|
||||
}
|
||||
|
||||
func (mw *databaseMetricsMiddleware) RevokeUser(statements Statements, username string) (err error) {
|
||||
func (mw *databaseMetricsMiddleware) RevokeUser(ctx context.Context, statements Statements, username string) (err error) {
|
||||
defer func(now time.Time) {
|
||||
metrics.MeasureSince([]string{"database", "RevokeUser"}, now)
|
||||
metrics.MeasureSince([]string{"database", mw.typeStr, "RevokeUser"}, now)
|
||||
|
@ -126,10 +128,10 @@ func (mw *databaseMetricsMiddleware) RevokeUser(statements Statements, username
|
|||
|
||||
metrics.IncrCounter([]string{"database", "RevokeUser"}, 1)
|
||||
metrics.IncrCounter([]string{"database", mw.typeStr, "RevokeUser"}, 1)
|
||||
return mw.next.RevokeUser(statements, username)
|
||||
return mw.next.RevokeUser(ctx, statements, username)
|
||||
}
|
||||
|
||||
func (mw *databaseMetricsMiddleware) Initialize(conf map[string]interface{}, verifyConnection bool) (err error) {
|
||||
func (mw *databaseMetricsMiddleware) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (err error) {
|
||||
defer func(now time.Time) {
|
||||
metrics.MeasureSince([]string{"database", "Initialize"}, now)
|
||||
metrics.MeasureSince([]string{"database", mw.typeStr, "Initialize"}, now)
|
||||
|
@ -142,7 +144,7 @@ func (mw *databaseMetricsMiddleware) Initialize(conf map[string]interface{}, ver
|
|||
|
||||
metrics.IncrCounter([]string{"database", "Initialize"}, 1)
|
||||
metrics.IncrCounter([]string{"database", mw.typeStr, "Initialize"}, 1)
|
||||
return mw.next.Initialize(conf, verifyConnection)
|
||||
return mw.next.Initialize(ctx, conf, verifyConnection)
|
||||
}
|
||||
|
||||
func (mw *databaseMetricsMiddleware) Close() (err error) {
|
||||
|
|
|
@ -0,0 +1,198 @@
|
|||
package dbplugin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/connectivity"
|
||||
|
||||
"github.com/golang/protobuf/ptypes"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrPluginShutdown = errors.New("plugin shutdown")
|
||||
)
|
||||
|
||||
// ---- gRPC Server domain ----
|
||||
|
||||
type gRPCServer struct {
|
||||
impl Database
|
||||
}
|
||||
|
||||
func (s *gRPCServer) Type(context.Context, *Empty) (*TypeResponse, error) {
|
||||
t, err := s.impl.Type()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &TypeResponse{
|
||||
Type: t,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *gRPCServer) CreateUser(ctx context.Context, req *CreateUserRequest) (*CreateUserResponse, error) {
|
||||
e, err := ptypes.Timestamp(req.Expiration)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
u, p, err := s.impl.CreateUser(ctx, *req.Statements, *req.UsernameConfig, e)
|
||||
|
||||
return &CreateUserResponse{
|
||||
Username: u,
|
||||
Password: p,
|
||||
}, err
|
||||
}
|
||||
|
||||
func (s *gRPCServer) RenewUser(ctx context.Context, req *RenewUserRequest) (*Empty, error) {
|
||||
e, err := ptypes.Timestamp(req.Expiration)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = s.impl.RenewUser(ctx, *req.Statements, req.Username, e)
|
||||
return &Empty{}, err
|
||||
}
|
||||
|
||||
func (s *gRPCServer) RevokeUser(ctx context.Context, req *RevokeUserRequest) (*Empty, error) {
|
||||
err := s.impl.RevokeUser(ctx, *req.Statements, req.Username)
|
||||
return &Empty{}, err
|
||||
}
|
||||
|
||||
func (s *gRPCServer) Initialize(ctx context.Context, req *InitializeRequest) (*Empty, error) {
|
||||
config := map[string]interface{}{}
|
||||
|
||||
err := json.Unmarshal(req.Config, &config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = s.impl.Initialize(ctx, config, req.VerifyConnection)
|
||||
return &Empty{}, err
|
||||
}
|
||||
|
||||
func (s *gRPCServer) Close(_ context.Context, _ *Empty) (*Empty, error) {
|
||||
s.impl.Close()
|
||||
return &Empty{}, nil
|
||||
}
|
||||
|
||||
// ---- gRPC client domain ----
|
||||
|
||||
type gRPCClient struct {
|
||||
client DatabaseClient
|
||||
clientConn *grpc.ClientConn
|
||||
}
|
||||
|
||||
func (c gRPCClient) Type() (string, error) {
|
||||
// If the plugin has already shutdown, this will hang forever so we give it
|
||||
// a one second timeout.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
|
||||
switch c.clientConn.GetState() {
|
||||
case connectivity.Ready, connectivity.Idle:
|
||||
default:
|
||||
return "", ErrPluginShutdown
|
||||
}
|
||||
resp, err := c.client.Type(ctx, &Empty{})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return resp.Type, err
|
||||
}
|
||||
|
||||
func (c gRPCClient) CreateUser(ctx context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) {
|
||||
t, err := ptypes.TimestampProto(expiration)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
switch c.clientConn.GetState() {
|
||||
case connectivity.Ready, connectivity.Idle:
|
||||
default:
|
||||
return "", "", ErrPluginShutdown
|
||||
}
|
||||
|
||||
resp, err := c.client.CreateUser(ctx, &CreateUserRequest{
|
||||
Statements: &statements,
|
||||
UsernameConfig: &usernameConfig,
|
||||
Expiration: t,
|
||||
})
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
return resp.Username, resp.Password, err
|
||||
}
|
||||
|
||||
func (c *gRPCClient) RenewUser(ctx context.Context, statements Statements, username string, expiration time.Time) error {
|
||||
t, err := ptypes.TimestampProto(expiration)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch c.clientConn.GetState() {
|
||||
case connectivity.Ready, connectivity.Idle:
|
||||
default:
|
||||
return ErrPluginShutdown
|
||||
}
|
||||
|
||||
_, err = c.client.RenewUser(ctx, &RenewUserRequest{
|
||||
Statements: &statements,
|
||||
Username: username,
|
||||
Expiration: t,
|
||||
})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *gRPCClient) RevokeUser(ctx context.Context, statements Statements, username string) error {
|
||||
switch c.clientConn.GetState() {
|
||||
case connectivity.Ready, connectivity.Idle:
|
||||
default:
|
||||
return ErrPluginShutdown
|
||||
}
|
||||
_, err := c.client.RevokeUser(ctx, &RevokeUserRequest{
|
||||
Statements: &statements,
|
||||
Username: username,
|
||||
})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *gRPCClient) Initialize(ctx context.Context, config map[string]interface{}, verifyConnection bool) error {
|
||||
configRaw, err := json.Marshal(config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch c.clientConn.GetState() {
|
||||
case connectivity.Ready, connectivity.Idle:
|
||||
default:
|
||||
return ErrPluginShutdown
|
||||
}
|
||||
|
||||
_, err = c.client.Initialize(ctx, &InitializeRequest{
|
||||
Config: configRaw,
|
||||
VerifyConnection: verifyConnection,
|
||||
})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *gRPCClient) Close() error {
|
||||
// If the plugin has already shutdown, this will hang forever so we give it
|
||||
// a one second timeout.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
switch c.clientConn.GetState() {
|
||||
case connectivity.Ready, connectivity.Idle:
|
||||
_, err := c.client.Close(ctx, &Empty{})
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,139 @@
|
|||
package dbplugin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/rpc"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ---- RPC server domain ----
|
||||
|
||||
// databasePluginRPCServer implements an RPC version of Database and is run
|
||||
// inside a plugin. It wraps an underlying implementation of Database.
|
||||
type databasePluginRPCServer struct {
|
||||
impl Database
|
||||
}
|
||||
|
||||
func (ds *databasePluginRPCServer) Type(_ struct{}, resp *string) error {
|
||||
var err error
|
||||
*resp, err = ds.impl.Type()
|
||||
return err
|
||||
}
|
||||
|
||||
func (ds *databasePluginRPCServer) CreateUser(args *CreateUserRequestRPC, resp *CreateUserResponse) error {
|
||||
var err error
|
||||
resp.Username, resp.Password, err = ds.impl.CreateUser(context.Background(), args.Statements, args.UsernameConfig, args.Expiration)
|
||||
return err
|
||||
}
|
||||
|
||||
func (ds *databasePluginRPCServer) RenewUser(args *RenewUserRequestRPC, _ *struct{}) error {
|
||||
err := ds.impl.RenewUser(context.Background(), args.Statements, args.Username, args.Expiration)
|
||||
return err
|
||||
}
|
||||
|
||||
func (ds *databasePluginRPCServer) RevokeUser(args *RevokeUserRequestRPC, _ *struct{}) error {
|
||||
err := ds.impl.RevokeUser(context.Background(), args.Statements, args.Username)
|
||||
return err
|
||||
}
|
||||
|
||||
func (ds *databasePluginRPCServer) Initialize(args *InitializeRequestRPC, _ *struct{}) error {
|
||||
err := ds.impl.Initialize(context.Background(), args.Config, args.VerifyConnection)
|
||||
return err
|
||||
}
|
||||
|
||||
func (ds *databasePluginRPCServer) Close(_ struct{}, _ *struct{}) error {
|
||||
ds.impl.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
// ---- RPC client domain ----
|
||||
// databasePluginRPCClient implements Database and is used on the client to
|
||||
// make RPC calls to a plugin.
|
||||
type databasePluginRPCClient struct {
|
||||
client *rpc.Client
|
||||
}
|
||||
|
||||
func (dr *databasePluginRPCClient) Type() (string, error) {
|
||||
var dbType string
|
||||
err := dr.client.Call("Plugin.Type", struct{}{}, &dbType)
|
||||
|
||||
return fmt.Sprintf("plugin-%s", dbType), err
|
||||
}
|
||||
|
||||
func (dr *databasePluginRPCClient) CreateUser(_ context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) {
|
||||
req := CreateUserRequestRPC{
|
||||
Statements: statements,
|
||||
UsernameConfig: usernameConfig,
|
||||
Expiration: expiration,
|
||||
}
|
||||
|
||||
var resp CreateUserResponse
|
||||
err = dr.client.Call("Plugin.CreateUser", req, &resp)
|
||||
|
||||
return resp.Username, resp.Password, err
|
||||
}
|
||||
|
||||
func (dr *databasePluginRPCClient) RenewUser(_ context.Context, statements Statements, username string, expiration time.Time) error {
|
||||
req := RenewUserRequestRPC{
|
||||
Statements: statements,
|
||||
Username: username,
|
||||
Expiration: expiration,
|
||||
}
|
||||
|
||||
err := dr.client.Call("Plugin.RenewUser", req, &struct{}{})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (dr *databasePluginRPCClient) RevokeUser(_ context.Context, statements Statements, username string) error {
|
||||
req := RevokeUserRequestRPC{
|
||||
Statements: statements,
|
||||
Username: username,
|
||||
}
|
||||
|
||||
err := dr.client.Call("Plugin.RevokeUser", req, &struct{}{})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (dr *databasePluginRPCClient) Initialize(_ context.Context, conf map[string]interface{}, verifyConnection bool) error {
|
||||
req := InitializeRequestRPC{
|
||||
Config: conf,
|
||||
VerifyConnection: verifyConnection,
|
||||
}
|
||||
|
||||
err := dr.client.Call("Plugin.Initialize", req, &struct{}{})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (dr *databasePluginRPCClient) Close() error {
|
||||
err := dr.client.Call("Plugin.Close", struct{}{}, &struct{}{})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// ---- RPC Request Args Domain ----
|
||||
|
||||
type InitializeRequestRPC struct {
|
||||
Config map[string]interface{}
|
||||
VerifyConnection bool
|
||||
}
|
||||
|
||||
type CreateUserRequestRPC struct {
|
||||
Statements Statements
|
||||
UsernameConfig UsernameConfig
|
||||
Expiration time.Time
|
||||
}
|
||||
|
||||
type RenewUserRequestRPC struct {
|
||||
Statements Statements
|
||||
Username string
|
||||
Expiration time.Time
|
||||
}
|
||||
|
||||
type RevokeUserRequestRPC struct {
|
||||
Statements Statements
|
||||
Username string
|
||||
}
|
|
@ -1,10 +1,13 @@
|
|||
package dbplugin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/rpc"
|
||||
"time"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/hashicorp/go-plugin"
|
||||
"github.com/hashicorp/vault/helper/pluginutil"
|
||||
log "github.com/mgutz/logxi/v1"
|
||||
|
@ -13,29 +16,14 @@ import (
|
|||
// Database is the interface that all database objects must implement.
|
||||
type Database interface {
|
||||
Type() (string, error)
|
||||
CreateUser(statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error)
|
||||
RenewUser(statements Statements, username string, expiration time.Time) error
|
||||
RevokeUser(statements Statements, username string) error
|
||||
CreateUser(ctx context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error)
|
||||
RenewUser(ctx context.Context, statements Statements, username string, expiration time.Time) error
|
||||
RevokeUser(ctx context.Context, statements Statements, username string) error
|
||||
|
||||
Initialize(config map[string]interface{}, verifyConnection bool) error
|
||||
Initialize(ctx context.Context, config map[string]interface{}, verifyConnection bool) error
|
||||
Close() error
|
||||
}
|
||||
|
||||
// Statements set in role creation and passed into the database type's functions.
|
||||
type Statements struct {
|
||||
CreationStatements string `json:"creation_statments" mapstructure:"creation_statements" structs:"creation_statments"`
|
||||
RevocationStatements string `json:"revocation_statements" mapstructure:"revocation_statements" structs:"revocation_statements"`
|
||||
RollbackStatements string `json:"rollback_statements" mapstructure:"rollback_statements" structs:"rollback_statements"`
|
||||
RenewStatements string `json:"renew_statements" mapstructure:"renew_statements" structs:"renew_statements"`
|
||||
}
|
||||
|
||||
// UsernameConfig is used to configure prefixes for the username to be
|
||||
// generated.
|
||||
type UsernameConfig struct {
|
||||
DisplayName string
|
||||
RoleName string
|
||||
}
|
||||
|
||||
// PluginFactory is used to build plugin database types. It wraps the database
|
||||
// object in a logging and metrics middleware.
|
||||
func PluginFactory(pluginName string, sys pluginutil.LookRunnerUtil, logger log.Logger) (Database, error) {
|
||||
|
@ -45,6 +33,7 @@ func PluginFactory(pluginName string, sys pluginutil.LookRunnerUtil, logger log.
|
|||
return nil, err
|
||||
}
|
||||
|
||||
var transport string
|
||||
var db Database
|
||||
if pluginRunner.Builtin {
|
||||
// Plugin is builtin so we can retrieve an instance of the interface
|
||||
|
@ -60,12 +49,24 @@ func PluginFactory(pluginName string, sys pluginutil.LookRunnerUtil, logger log.
|
|||
return nil, fmt.Errorf("unsuported database type: %s", pluginName)
|
||||
}
|
||||
|
||||
transport = "builtin"
|
||||
|
||||
} else {
|
||||
// create a DatabasePluginClient instance
|
||||
db, err = newPluginClient(sys, pluginRunner, logger)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Switch on the underlying database client type to get the transport
|
||||
// method.
|
||||
switch db.(*DatabasePluginClient).Database.(type) {
|
||||
case *gRPCClient:
|
||||
transport = "gRPC"
|
||||
case *databasePluginRPCClient:
|
||||
transport = "netRPC"
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
typeStr, err := db.Type()
|
||||
|
@ -82,9 +83,10 @@ func PluginFactory(pluginName string, sys pluginutil.LookRunnerUtil, logger log.
|
|||
// Wrap with tracing middleware
|
||||
if logger.IsTrace() {
|
||||
db = &databaseTracingMiddleware{
|
||||
next: db,
|
||||
typeStr: typeStr,
|
||||
logger: logger,
|
||||
transport: transport,
|
||||
next: db,
|
||||
typeStr: typeStr,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -115,33 +117,14 @@ func (DatabasePlugin) Client(b *plugin.MuxBroker, c *rpc.Client) (interface{}, e
|
|||
return &databasePluginRPCClient{client: c}, nil
|
||||
}
|
||||
|
||||
// ---- RPC Request Args Domain ----
|
||||
|
||||
type InitializeRequest struct {
|
||||
Config map[string]interface{}
|
||||
VerifyConnection bool
|
||||
func (d DatabasePlugin) GRPCServer(s *grpc.Server) error {
|
||||
RegisterDatabaseServer(s, &gRPCServer{impl: d.impl})
|
||||
return nil
|
||||
}
|
||||
|
||||
type CreateUserRequest struct {
|
||||
Statements Statements
|
||||
UsernameConfig UsernameConfig
|
||||
Expiration time.Time
|
||||
}
|
||||
|
||||
type RenewUserRequest struct {
|
||||
Statements Statements
|
||||
Username string
|
||||
Expiration time.Time
|
||||
}
|
||||
|
||||
type RevokeUserRequest struct {
|
||||
Statements Statements
|
||||
Username string
|
||||
}
|
||||
|
||||
// ---- RPC Response Args Domain ----
|
||||
|
||||
type CreateUserResponse struct {
|
||||
Username string
|
||||
Password string
|
||||
func (DatabasePlugin) GRPCClient(c *grpc.ClientConn) (interface{}, error) {
|
||||
return &gRPCClient{
|
||||
client: NewDatabaseClient(c),
|
||||
clientConn: c,
|
||||
}, nil
|
||||
}
|
||||
|
|
|
@ -1,11 +1,13 @@
|
|||
package dbplugin_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
plugin "github.com/hashicorp/go-plugin"
|
||||
"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
|
||||
"github.com/hashicorp/vault/helper/pluginutil"
|
||||
vaulthttp "github.com/hashicorp/vault/http"
|
||||
|
@ -20,7 +22,7 @@ type mockPlugin struct {
|
|||
}
|
||||
|
||||
func (m *mockPlugin) Type() (string, error) { return "mock", nil }
|
||||
func (m *mockPlugin) CreateUser(statements dbplugin.Statements, usernameConf dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
|
||||
func (m *mockPlugin) CreateUser(_ context.Context, statements dbplugin.Statements, usernameConf dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
|
||||
err = errors.New("err")
|
||||
if usernameConf.DisplayName == "" || expiration.IsZero() {
|
||||
return "", "", err
|
||||
|
@ -34,7 +36,7 @@ func (m *mockPlugin) CreateUser(statements dbplugin.Statements, usernameConf dbp
|
|||
|
||||
return usernameConf.DisplayName, "test", nil
|
||||
}
|
||||
func (m *mockPlugin) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error {
|
||||
func (m *mockPlugin) RenewUser(_ context.Context, statements dbplugin.Statements, username string, expiration time.Time) error {
|
||||
err := errors.New("err")
|
||||
if username == "" || expiration.IsZero() {
|
||||
return err
|
||||
|
@ -46,7 +48,7 @@ func (m *mockPlugin) RenewUser(statements dbplugin.Statements, username string,
|
|||
|
||||
return nil
|
||||
}
|
||||
func (m *mockPlugin) RevokeUser(statements dbplugin.Statements, username string) error {
|
||||
func (m *mockPlugin) RevokeUser(_ context.Context, statements dbplugin.Statements, username string) error {
|
||||
err := errors.New("err")
|
||||
if username == "" {
|
||||
return err
|
||||
|
@ -59,7 +61,7 @@ func (m *mockPlugin) RevokeUser(statements dbplugin.Statements, username string)
|
|||
delete(m.users, username)
|
||||
return nil
|
||||
}
|
||||
func (m *mockPlugin) Initialize(conf map[string]interface{}, _ bool) error {
|
||||
func (m *mockPlugin) Initialize(_ context.Context, conf map[string]interface{}, _ bool) error {
|
||||
err := errors.New("err")
|
||||
if len(conf) != 1 {
|
||||
return err
|
||||
|
@ -80,14 +82,15 @@ func getCluster(t *testing.T) (*vault.TestCluster, logical.SystemView) {
|
|||
cores := cluster.Cores
|
||||
|
||||
sys := vault.TestDynamicSystemView(cores[0].Core)
|
||||
vault.TestAddTestPlugin(t, cores[0].Core, "test-plugin", "TestPlugin_Main")
|
||||
vault.TestAddTestPlugin(t, cores[0].Core, "test-plugin", "TestPlugin_GRPC_Main")
|
||||
vault.TestAddTestPlugin(t, cores[0].Core, "test-plugin-netRPC", "TestPlugin_NetRPC_Main")
|
||||
|
||||
return cluster, sys
|
||||
}
|
||||
|
||||
// This is not an actual test case, it's a helper function that will be executed
|
||||
// by the go-plugin client via an exec call.
|
||||
func TestPlugin_Main(t *testing.T) {
|
||||
func TestPlugin_GRPC_Main(t *testing.T) {
|
||||
if os.Getenv(pluginutil.PluginUnwrapTokenEnv) == "" {
|
||||
return
|
||||
}
|
||||
|
@ -105,6 +108,30 @@ func TestPlugin_Main(t *testing.T) {
|
|||
plugins.Serve(plugin, apiClientMeta.GetTLSConfig())
|
||||
}
|
||||
|
||||
// This is not an actual test case, it's a helper function that will be executed
|
||||
// by the go-plugin client via an exec call.
|
||||
func TestPlugin_NetRPC_Main(t *testing.T) {
|
||||
if os.Getenv(pluginutil.PluginUnwrapTokenEnv) == "" {
|
||||
return
|
||||
}
|
||||
|
||||
p := &mockPlugin{
|
||||
users: make(map[string][]string),
|
||||
}
|
||||
|
||||
args := []string{"--tls-skip-verify=true"}
|
||||
|
||||
apiClientMeta := &pluginutil.APIClientMeta{}
|
||||
flags := apiClientMeta.FlagSet()
|
||||
flags.Parse(args)
|
||||
|
||||
tlsProvider := pluginutil.VaultPluginTLSProvider(apiClientMeta.GetTLSConfig())
|
||||
serveConf := dbplugin.ServeConfig(p, tlsProvider)
|
||||
serveConf.GRPCServer = nil
|
||||
|
||||
plugin.Serve(serveConf)
|
||||
}
|
||||
|
||||
func TestPlugin_Initialize(t *testing.T) {
|
||||
cluster, sys := getCluster(t)
|
||||
defer cluster.Cleanup()
|
||||
|
@ -118,7 +145,7 @@ func TestPlugin_Initialize(t *testing.T) {
|
|||
"test": 1,
|
||||
}
|
||||
|
||||
err = dbRaw.Initialize(connectionDetails, true)
|
||||
err = dbRaw.Initialize(context.Background(), connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -143,7 +170,7 @@ func TestPlugin_CreateUser(t *testing.T) {
|
|||
"test": 1,
|
||||
}
|
||||
|
||||
err = db.Initialize(connectionDetails, true)
|
||||
err = db.Initialize(context.Background(), connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -153,7 +180,7 @@ func TestPlugin_CreateUser(t *testing.T) {
|
|||
RoleName: "test",
|
||||
}
|
||||
|
||||
us, pw, err := db.CreateUser(dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
|
||||
us, pw, err := db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -163,7 +190,7 @@ func TestPlugin_CreateUser(t *testing.T) {
|
|||
|
||||
// try and save the same user again to verify it saved the first time, this
|
||||
// should return an error
|
||||
_, _, err = db.CreateUser(dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
|
||||
_, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
|
||||
if err == nil {
|
||||
t.Fatal("expected an error, user wasn't created correctly")
|
||||
}
|
||||
|
@ -182,7 +209,7 @@ func TestPlugin_RenewUser(t *testing.T) {
|
|||
connectionDetails := map[string]interface{}{
|
||||
"test": 1,
|
||||
}
|
||||
err = db.Initialize(connectionDetails, true)
|
||||
err = db.Initialize(context.Background(), connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -192,12 +219,12 @@ func TestPlugin_RenewUser(t *testing.T) {
|
|||
RoleName: "test",
|
||||
}
|
||||
|
||||
us, _, err := db.CreateUser(dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
|
||||
us, _, err := db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
err = db.RenewUser(dbplugin.Statements{}, us, time.Now().Add(time.Minute))
|
||||
err = db.RenewUser(context.Background(), dbplugin.Statements{}, us, time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -216,7 +243,7 @@ func TestPlugin_RevokeUser(t *testing.T) {
|
|||
connectionDetails := map[string]interface{}{
|
||||
"test": 1,
|
||||
}
|
||||
err = db.Initialize(connectionDetails, true)
|
||||
err = db.Initialize(context.Background(), connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -226,19 +253,159 @@ func TestPlugin_RevokeUser(t *testing.T) {
|
|||
RoleName: "test",
|
||||
}
|
||||
|
||||
us, _, err := db.CreateUser(dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
|
||||
us, _, err := db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
// Test default revoke statememts
|
||||
err = db.RevokeUser(dbplugin.Statements{}, us)
|
||||
err = db.RevokeUser(context.Background(), dbplugin.Statements{}, us)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
// Try adding the same username back so we can verify it was removed
|
||||
_, _, err = db.CreateUser(dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
|
||||
_, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Test the code is still compatible with an old netRPC plugin
|
||||
func TestPlugin_NetRPC_Initialize(t *testing.T) {
|
||||
cluster, sys := getCluster(t)
|
||||
defer cluster.Cleanup()
|
||||
|
||||
dbRaw, err := dbplugin.PluginFactory("test-plugin-netRPC", sys, &log.NullLogger{})
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
connectionDetails := map[string]interface{}{
|
||||
"test": 1,
|
||||
}
|
||||
|
||||
err = dbRaw.Initialize(context.Background(), connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
err = dbRaw.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPlugin_NetRPC_CreateUser(t *testing.T) {
|
||||
cluster, sys := getCluster(t)
|
||||
defer cluster.Cleanup()
|
||||
|
||||
db, err := dbplugin.PluginFactory("test-plugin-netRPC", sys, &log.NullLogger{})
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
connectionDetails := map[string]interface{}{
|
||||
"test": 1,
|
||||
}
|
||||
|
||||
err = db.Initialize(context.Background(), connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
usernameConf := dbplugin.UsernameConfig{
|
||||
DisplayName: "test",
|
||||
RoleName: "test",
|
||||
}
|
||||
|
||||
us, pw, err := db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
if us != "test" || pw != "test" {
|
||||
t.Fatal("expected username and password to be 'test'")
|
||||
}
|
||||
|
||||
// try and save the same user again to verify it saved the first time, this
|
||||
// should return an error
|
||||
_, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
|
||||
if err == nil {
|
||||
t.Fatal("expected an error, user wasn't created correctly")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPlugin_NetRPC_RenewUser(t *testing.T) {
|
||||
cluster, sys := getCluster(t)
|
||||
defer cluster.Cleanup()
|
||||
|
||||
db, err := dbplugin.PluginFactory("test-plugin-netRPC", sys, &log.NullLogger{})
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
connectionDetails := map[string]interface{}{
|
||||
"test": 1,
|
||||
}
|
||||
err = db.Initialize(context.Background(), connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
usernameConf := dbplugin.UsernameConfig{
|
||||
DisplayName: "test",
|
||||
RoleName: "test",
|
||||
}
|
||||
|
||||
us, _, err := db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
err = db.RenewUser(context.Background(), dbplugin.Statements{}, us, time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPlugin_NetRPC_RevokeUser(t *testing.T) {
|
||||
cluster, sys := getCluster(t)
|
||||
defer cluster.Cleanup()
|
||||
|
||||
db, err := dbplugin.PluginFactory("test-plugin-netRPC", sys, &log.NullLogger{})
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
connectionDetails := map[string]interface{}{
|
||||
"test": 1,
|
||||
}
|
||||
err = db.Initialize(context.Background(), connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
usernameConf := dbplugin.UsernameConfig{
|
||||
DisplayName: "test",
|
||||
RoleName: "test",
|
||||
}
|
||||
|
||||
us, _, err := db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
// Test default revoke statememts
|
||||
err = db.RevokeUser(context.Background(), dbplugin.Statements{}, us)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
// Try adding the same username back so we can verify it was removed
|
||||
_, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
|
|
@ -10,6 +10,10 @@ import (
|
|||
// Database implementation in a databasePluginRPCServer object and starts a
|
||||
// RPC server.
|
||||
func Serve(db Database, tlsProvider func() (*tls.Config, error)) {
|
||||
plugin.Serve(ServeConfig(db, tlsProvider))
|
||||
}
|
||||
|
||||
func ServeConfig(db Database, tlsProvider func() (*tls.Config, error)) *plugin.ServeConfig {
|
||||
dbPlugin := &DatabasePlugin{
|
||||
impl: db,
|
||||
}
|
||||
|
@ -19,53 +23,10 @@ func Serve(db Database, tlsProvider func() (*tls.Config, error)) {
|
|||
"database": dbPlugin,
|
||||
}
|
||||
|
||||
plugin.Serve(&plugin.ServeConfig{
|
||||
return &plugin.ServeConfig{
|
||||
HandshakeConfig: handshakeConfig,
|
||||
Plugins: pluginMap,
|
||||
TLSProvider: tlsProvider,
|
||||
})
|
||||
}
|
||||
|
||||
// ---- RPC server domain ----
|
||||
|
||||
// databasePluginRPCServer implements an RPC version of Database and is run
|
||||
// inside a plugin. It wraps an underlying implementation of Database.
|
||||
type databasePluginRPCServer struct {
|
||||
impl Database
|
||||
}
|
||||
|
||||
func (ds *databasePluginRPCServer) Type(_ struct{}, resp *string) error {
|
||||
var err error
|
||||
*resp, err = ds.impl.Type()
|
||||
return err
|
||||
}
|
||||
|
||||
func (ds *databasePluginRPCServer) CreateUser(args *CreateUserRequest, resp *CreateUserResponse) error {
|
||||
var err error
|
||||
resp.Username, resp.Password, err = ds.impl.CreateUser(args.Statements, args.UsernameConfig, args.Expiration)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (ds *databasePluginRPCServer) RenewUser(args *RenewUserRequest, _ *struct{}) error {
|
||||
err := ds.impl.RenewUser(args.Statements, args.Username, args.Expiration)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (ds *databasePluginRPCServer) RevokeUser(args *RevokeUserRequest, _ *struct{}) error {
|
||||
err := ds.impl.RevokeUser(args.Statements, args.Username)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (ds *databasePluginRPCServer) Initialize(args *InitializeRequest, _ *struct{}) error {
|
||||
err := ds.impl.Initialize(args.Config, args.VerifyConnection)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (ds *databasePluginRPCServer) Close(_ struct{}, _ *struct{}) error {
|
||||
ds.impl.Close()
|
||||
return nil
|
||||
GRPCServer: plugin.DefaultGRPCServer,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
|
@ -62,7 +63,7 @@ func (b *databaseBackend) pathConnectionReset() framework.OperationFunc {
|
|||
b.clearConnection(name)
|
||||
|
||||
// Execute plugin again, we don't need the object so throw away.
|
||||
_, err := b.createDBObj(req.Storage, name)
|
||||
_, err := b.createDBObj(context.TODO(), req.Storage, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -230,7 +231,7 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc {
|
|||
return logical.ErrorResponse(fmt.Sprintf("error creating database object: %s", err)), nil
|
||||
}
|
||||
|
||||
err = db.Initialize(config.ConnectionDetails, verifyConnection)
|
||||
err = db.Initialize(context.TODO(), config.ConnectionDetails, verifyConnection)
|
||||
if err != nil {
|
||||
db.Close()
|
||||
return logical.ErrorResponse(fmt.Sprintf("error creating database object: %s", err)), nil
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
|
@ -66,7 +67,7 @@ func (b *databaseBackend) pathCredsCreateRead() framework.OperationFunc {
|
|||
unlockFunc = b.Unlock
|
||||
|
||||
// Create a new DB object
|
||||
db, err = b.createDBObj(req.Storage, role.DBName)
|
||||
db, err = b.createDBObj(context.TODO(), req.Storage, role.DBName)
|
||||
if err != nil {
|
||||
unlockFunc()
|
||||
return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err)
|
||||
|
@ -81,7 +82,7 @@ func (b *databaseBackend) pathCredsCreateRead() framework.OperationFunc {
|
|||
}
|
||||
|
||||
// Create the user
|
||||
username, password, err := db.CreateUser(role.Statements, usernameConfig, expiration)
|
||||
username, password, err := db.CreateUser(context.TODO(), role.Statements, usernameConfig, expiration)
|
||||
// Unlock
|
||||
unlockFunc()
|
||||
if err != nil {
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/hashicorp/vault/logical"
|
||||
|
@ -60,7 +61,7 @@ func (b *databaseBackend) secretCredsRenew() framework.OperationFunc {
|
|||
unlockFunc = b.Unlock
|
||||
|
||||
// Create a new DB object
|
||||
db, err = b.createDBObj(req.Storage, role.DBName)
|
||||
db, err = b.createDBObj(context.TODO(), req.Storage, role.DBName)
|
||||
if err != nil {
|
||||
unlockFunc()
|
||||
return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err)
|
||||
|
@ -69,7 +70,7 @@ func (b *databaseBackend) secretCredsRenew() framework.OperationFunc {
|
|||
|
||||
// Make sure we increase the VALID UNTIL endpoint for this user.
|
||||
if expireTime := resp.Secret.ExpirationTime(); !expireTime.IsZero() {
|
||||
err := db.RenewUser(role.Statements, username, expireTime)
|
||||
err := db.RenewUser(context.TODO(), role.Statements, username, expireTime)
|
||||
// Unlock
|
||||
unlockFunc()
|
||||
if err != nil {
|
||||
|
@ -119,14 +120,14 @@ func (b *databaseBackend) secretCredsRevoke() framework.OperationFunc {
|
|||
unlockFunc = b.Unlock
|
||||
|
||||
// Create a new DB object
|
||||
db, err = b.createDBObj(req.Storage, role.DBName)
|
||||
db, err = b.createDBObj(context.TODO(), req.Storage, role.DBName)
|
||||
if err != nil {
|
||||
unlockFunc()
|
||||
return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err)
|
||||
}
|
||||
}
|
||||
|
||||
err = db.RevokeUser(role.Statements, username)
|
||||
err = db.RevokeUser(context.TODO(), role.Statements, username)
|
||||
// Unlock
|
||||
unlockFunc()
|
||||
if err != nil {
|
||||
|
|
|
@ -119,6 +119,10 @@ func (r *PluginRunner) runCommon(wrapper RunnerUtil, pluginMap map[string]plugin
|
|||
SecureConfig: secureConfig,
|
||||
TLSConfig: clientTLSConfig,
|
||||
Logger: namedLogger,
|
||||
AllowedProtocols: []plugin.Protocol{
|
||||
plugin.ProtocolNetRPC,
|
||||
plugin.ProtocolGRPC,
|
||||
},
|
||||
}
|
||||
|
||||
client := plugin.NewClient(clientConfig)
|
||||
|
|
|
@ -192,8 +192,7 @@ func TestBackendHandleRequest_helpRoot(t *testing.T) {
|
|||
func TestBackendHandleRequest_renewAuth(t *testing.T) {
|
||||
b := &Backend{}
|
||||
|
||||
resp, err := b.HandleRequest(logical.RenewAuthRequest(
|
||||
"/foo", &logical.Auth{}, nil))
|
||||
resp, err := b.HandleRequest(logical.RenewAuthRequest("/foo", &logical.Auth{}, nil))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -213,8 +212,7 @@ func TestBackendHandleRequest_renewAuthCallback(t *testing.T) {
|
|||
AuthRenew: callback,
|
||||
}
|
||||
|
||||
_, err := b.HandleRequest(logical.RenewAuthRequest(
|
||||
"/foo", &logical.Auth{}, nil))
|
||||
_, err := b.HandleRequest(logical.RenewAuthRequest("/foo", &logical.Auth{}, nil))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -237,8 +235,7 @@ func TestBackendHandleRequest_renew(t *testing.T) {
|
|||
Secrets: []*Secret{secret},
|
||||
}
|
||||
|
||||
_, err := b.HandleRequest(logical.RenewRequest(
|
||||
"/foo", secret.Response(nil, nil).Secret, nil))
|
||||
_, err := b.HandleRequest(logical.RenewRequest("/foo", secret.Response(nil, nil).Secret, nil))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -293,8 +290,7 @@ func TestBackendHandleRequest_revoke(t *testing.T) {
|
|||
Secrets: []*Secret{secret},
|
||||
}
|
||||
|
||||
_, err := b.HandleRequest(logical.RevokeRequest(
|
||||
"/foo", secret.Response(nil, nil).Secret, nil))
|
||||
_, err := b.HandleRequest(logical.RevokeRequest("/foo", secret.Response(nil, nil).Secret, nil))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
|
|
@ -200,8 +200,7 @@ func (r *Request) SetLastRemoteWAL(last uint64) {
|
|||
}
|
||||
|
||||
// RenewRequest creates the structure of the renew request.
|
||||
func RenewRequest(
|
||||
path string, secret *Secret, data map[string]interface{}) *Request {
|
||||
func RenewRequest(path string, secret *Secret, data map[string]interface{}) *Request {
|
||||
return &Request{
|
||||
Operation: RenewOperation,
|
||||
Path: path,
|
||||
|
@ -211,8 +210,7 @@ func RenewRequest(
|
|||
}
|
||||
|
||||
// RenewAuthRequest creates the structure of the renew request for an auth.
|
||||
func RenewAuthRequest(
|
||||
path string, auth *Auth, data map[string]interface{}) *Request {
|
||||
func RenewAuthRequest(path string, auth *Auth, data map[string]interface{}) *Request {
|
||||
return &Request{
|
||||
Operation: RenewOperation,
|
||||
Path: path,
|
||||
|
@ -222,8 +220,7 @@ func RenewAuthRequest(
|
|||
}
|
||||
|
||||
// RevokeRequest creates the structure of the revoke request.
|
||||
func RevokeRequest(
|
||||
path string, secret *Secret, data map[string]interface{}) *Request {
|
||||
func RevokeRequest(path string, secret *Secret, data map[string]interface{}) *Request {
|
||||
return &Request{
|
||||
Operation: RevokeOperation,
|
||||
Path: path,
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package cassandra
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
|
@ -21,6 +22,8 @@ const (
|
|||
cassandraTypeName = "cassandra"
|
||||
)
|
||||
|
||||
var _ dbplugin.Database = &Cassandra{}
|
||||
|
||||
// Cassandra is an implementation of Database interface
|
||||
type Cassandra struct {
|
||||
connutil.ConnectionProducer
|
||||
|
@ -64,8 +67,8 @@ func (c *Cassandra) Type() (string, error) {
|
|||
return cassandraTypeName, nil
|
||||
}
|
||||
|
||||
func (c *Cassandra) getConnection() (*gocql.Session, error) {
|
||||
session, err := c.Connection()
|
||||
func (c *Cassandra) getConnection(ctx context.Context) (*gocql.Session, error) {
|
||||
session, err := c.Connection(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -75,13 +78,13 @@ func (c *Cassandra) getConnection() (*gocql.Session, error) {
|
|||
|
||||
// CreateUser generates the username/password on the underlying Cassandra secret backend as instructed by
|
||||
// the CreationStatement provided.
|
||||
func (c *Cassandra) CreateUser(statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
|
||||
func (c *Cassandra) CreateUser(ctx context.Context, statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
|
||||
// Grab the lock
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
|
||||
// Get the connection
|
||||
session, err := c.getConnection()
|
||||
session, err := c.getConnection(ctx)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
@ -138,18 +141,18 @@ func (c *Cassandra) CreateUser(statements dbplugin.Statements, usernameConfig db
|
|||
}
|
||||
|
||||
// RenewUser is not supported on Cassandra, so this is a no-op.
|
||||
func (c *Cassandra) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error {
|
||||
func (c *Cassandra) RenewUser(ctx context.Context, statements dbplugin.Statements, username string, expiration time.Time) error {
|
||||
// NOOP
|
||||
return nil
|
||||
}
|
||||
|
||||
// RevokeUser attempts to drop the specified user.
|
||||
func (c *Cassandra) RevokeUser(statements dbplugin.Statements, username string) error {
|
||||
func (c *Cassandra) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error {
|
||||
// Grab the lock
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
|
||||
session, err := c.getConnection()
|
||||
session, err := c.getConnection(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package cassandra
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
@ -89,7 +90,7 @@ func TestCassandra_Initialize(t *testing.T) {
|
|||
db := dbRaw.(*Cassandra)
|
||||
connProducer := db.ConnectionProducer.(*cassandraConnectionProducer)
|
||||
|
||||
err := db.Initialize(connectionDetails, true)
|
||||
err := db.Initialize(context.Background(), connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -112,7 +113,7 @@ func TestCassandra_Initialize(t *testing.T) {
|
|||
"protocol_version": "4",
|
||||
}
|
||||
|
||||
err = db.Initialize(connectionDetails, true)
|
||||
err = db.Initialize(context.Background(), connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -135,7 +136,7 @@ func TestCassandra_CreateUser(t *testing.T) {
|
|||
|
||||
dbRaw, _ := New()
|
||||
db := dbRaw.(*Cassandra)
|
||||
err := db.Initialize(connectionDetails, true)
|
||||
err := db.Initialize(context.Background(), connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -149,7 +150,7 @@ func TestCassandra_CreateUser(t *testing.T) {
|
|||
RoleName: "test",
|
||||
}
|
||||
|
||||
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute))
|
||||
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -176,7 +177,7 @@ func TestMyCassandra_RenewUser(t *testing.T) {
|
|||
|
||||
dbRaw, _ := New()
|
||||
db := dbRaw.(*Cassandra)
|
||||
err := db.Initialize(connectionDetails, true)
|
||||
err := db.Initialize(context.Background(), connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -190,7 +191,7 @@ func TestMyCassandra_RenewUser(t *testing.T) {
|
|||
RoleName: "test",
|
||||
}
|
||||
|
||||
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute))
|
||||
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -199,7 +200,7 @@ func TestMyCassandra_RenewUser(t *testing.T) {
|
|||
t.Fatalf("Could not connect with new credentials: %s", err)
|
||||
}
|
||||
|
||||
err = db.RenewUser(statements, username, time.Now().Add(time.Minute))
|
||||
err = db.RenewUser(context.Background(), statements, username, time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -222,7 +223,7 @@ func TestCassandra_RevokeUser(t *testing.T) {
|
|||
|
||||
dbRaw, _ := New()
|
||||
db := dbRaw.(*Cassandra)
|
||||
err := db.Initialize(connectionDetails, true)
|
||||
err := db.Initialize(context.Background(), connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -236,7 +237,7 @@ func TestCassandra_RevokeUser(t *testing.T) {
|
|||
RoleName: "test",
|
||||
}
|
||||
|
||||
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute))
|
||||
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -246,7 +247,7 @@ func TestCassandra_RevokeUser(t *testing.T) {
|
|||
}
|
||||
|
||||
// Test default revoke statememts
|
||||
err = db.RevokeUser(statements, username)
|
||||
err = db.RevokeUser(context.Background(), statements, username)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package cassandra
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
@ -43,7 +44,7 @@ type cassandraConnectionProducer struct {
|
|||
sync.Mutex
|
||||
}
|
||||
|
||||
func (c *cassandraConnectionProducer) Initialize(conf map[string]interface{}, verifyConnection bool) error {
|
||||
func (c *cassandraConnectionProducer) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
|
||||
|
@ -106,7 +107,7 @@ func (c *cassandraConnectionProducer) Initialize(conf map[string]interface{}, ve
|
|||
c.Initialized = true
|
||||
|
||||
if verifyConnection {
|
||||
if _, err := c.Connection(); err != nil {
|
||||
if _, err := c.Connection(ctx); err != nil {
|
||||
return fmt.Errorf("error verifying connection: %s", err)
|
||||
}
|
||||
}
|
||||
|
@ -114,7 +115,7 @@ func (c *cassandraConnectionProducer) Initialize(conf map[string]interface{}, ve
|
|||
return nil
|
||||
}
|
||||
|
||||
func (c *cassandraConnectionProducer) Connection() (interface{}, error) {
|
||||
func (c *cassandraConnectionProducer) Connection(_ context.Context) (interface{}, error) {
|
||||
if !c.Initialized {
|
||||
return nil, connutil.ErrNotInitialized
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package hana
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
@ -26,6 +27,8 @@ type HANA struct {
|
|||
credsutil.CredentialsProducer
|
||||
}
|
||||
|
||||
var _ dbplugin.Database = &HANA{}
|
||||
|
||||
// New implements builtinplugins.BuiltinFactory
|
||||
func New() (interface{}, error) {
|
||||
connProducer := &connutil.SQLConnectionProducer{}
|
||||
|
@ -63,8 +66,8 @@ func (h *HANA) Type() (string, error) {
|
|||
return hanaTypeName, nil
|
||||
}
|
||||
|
||||
func (h *HANA) getConnection() (*sql.DB, error) {
|
||||
db, err := h.Connection()
|
||||
func (h *HANA) getConnection(ctx context.Context) (*sql.DB, error) {
|
||||
db, err := h.Connection(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -74,13 +77,13 @@ func (h *HANA) getConnection() (*sql.DB, error) {
|
|||
|
||||
// CreateUser generates the username/password on the underlying HANA secret backend
|
||||
// as instructed by the CreationStatement provided.
|
||||
func (h *HANA) CreateUser(statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
|
||||
func (h *HANA) CreateUser(ctx context.Context, statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
|
||||
// Grab the lock
|
||||
h.Lock()
|
||||
defer h.Unlock()
|
||||
|
||||
// Get the connection
|
||||
db, err := h.getConnection()
|
||||
db, err := h.getConnection(ctx)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
@ -153,9 +156,9 @@ func (h *HANA) CreateUser(statements dbplugin.Statements, usernameConfig dbplugi
|
|||
}
|
||||
|
||||
// Renewing hana user just means altering user's valid until property
|
||||
func (h *HANA) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error {
|
||||
func (h *HANA) RenewUser(ctx context.Context, statements dbplugin.Statements, username string, expiration time.Time) error {
|
||||
// Get connection
|
||||
db, err := h.getConnection()
|
||||
db, err := h.getConnection(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -193,14 +196,14 @@ func (h *HANA) RenewUser(statements dbplugin.Statements, username string, expira
|
|||
}
|
||||
|
||||
// Revoking hana user will deactivate user and try to perform a soft drop
|
||||
func (h *HANA) RevokeUser(statements dbplugin.Statements, username string) error {
|
||||
func (h *HANA) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error {
|
||||
// default revoke will be a soft drop on user
|
||||
if statements.RevocationStatements == "" {
|
||||
return h.revokeUserDefault(username)
|
||||
return h.revokeUserDefault(ctx, username)
|
||||
}
|
||||
|
||||
// Get connection
|
||||
db, err := h.getConnection()
|
||||
db, err := h.getConnection(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -239,9 +242,9 @@ func (h *HANA) RevokeUser(statements dbplugin.Statements, username string) error
|
|||
return nil
|
||||
}
|
||||
|
||||
func (h *HANA) revokeUserDefault(username string) error {
|
||||
func (h *HANA) revokeUserDefault(ctx context.Context, username string) error {
|
||||
// Get connection
|
||||
db, err := h.getConnection()
|
||||
db, err := h.getConnection(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package hana
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"os"
|
||||
|
@ -25,7 +26,7 @@ func TestHANA_Initialize(t *testing.T) {
|
|||
dbRaw, _ := New()
|
||||
db := dbRaw.(*HANA)
|
||||
|
||||
err := db.Initialize(connectionDetails, true)
|
||||
err := db.Initialize(context.Background(), connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -55,7 +56,7 @@ func TestHANA_CreateUser(t *testing.T) {
|
|||
dbRaw, _ := New()
|
||||
db := dbRaw.(*HANA)
|
||||
|
||||
err := db.Initialize(connectionDetails, true)
|
||||
err := db.Initialize(context.Background(), connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -66,7 +67,7 @@ func TestHANA_CreateUser(t *testing.T) {
|
|||
}
|
||||
|
||||
// Test with no configured Creation Statememt
|
||||
_, _, err = db.CreateUser(dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Hour))
|
||||
_, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Hour))
|
||||
if err == nil {
|
||||
t.Fatal("Expected error when no creation statement is provided")
|
||||
}
|
||||
|
@ -75,7 +76,7 @@ func TestHANA_CreateUser(t *testing.T) {
|
|||
CreationStatements: testHANARole,
|
||||
}
|
||||
|
||||
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Hour))
|
||||
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Hour))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -98,7 +99,7 @@ func TestHANA_RevokeUser(t *testing.T) {
|
|||
dbRaw, _ := New()
|
||||
db := dbRaw.(*HANA)
|
||||
|
||||
err := db.Initialize(connectionDetails, true)
|
||||
err := db.Initialize(context.Background(), connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -113,7 +114,7 @@ func TestHANA_RevokeUser(t *testing.T) {
|
|||
}
|
||||
|
||||
// Test default revoke statememts
|
||||
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Hour))
|
||||
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Hour))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -121,7 +122,7 @@ func TestHANA_RevokeUser(t *testing.T) {
|
|||
t.Fatalf("Could not connect with new credentials: %s", err)
|
||||
}
|
||||
|
||||
err = db.RevokeUser(statements, username)
|
||||
err = db.RevokeUser(context.Background(), statements, username)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -130,7 +131,7 @@ func TestHANA_RevokeUser(t *testing.T) {
|
|||
}
|
||||
|
||||
// Test custom revoke statememt
|
||||
username, password, err = db.CreateUser(statements, usernameConfig, time.Now().Add(time.Hour))
|
||||
username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Hour))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -139,7 +140,7 @@ func TestHANA_RevokeUser(t *testing.T) {
|
|||
}
|
||||
|
||||
statements.RevocationStatements = testHANADrop
|
||||
err = db.RevokeUser(statements, username)
|
||||
err = db.RevokeUser(context.Background(), statements, username)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package mongodb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
|
@ -33,7 +34,7 @@ type mongoDBConnectionProducer struct {
|
|||
}
|
||||
|
||||
// Initialize parses connection configuration.
|
||||
func (c *mongoDBConnectionProducer) Initialize(conf map[string]interface{}, verifyConnection bool) error {
|
||||
func (c *mongoDBConnectionProducer) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
|
||||
|
@ -75,7 +76,7 @@ func (c *mongoDBConnectionProducer) Initialize(conf map[string]interface{}, veri
|
|||
c.Initialized = true
|
||||
|
||||
if verifyConnection {
|
||||
if _, err := c.Connection(); err != nil {
|
||||
if _, err := c.Connection(ctx); err != nil {
|
||||
return fmt.Errorf("error verifying connection: %s", err)
|
||||
}
|
||||
|
||||
|
@ -88,7 +89,7 @@ func (c *mongoDBConnectionProducer) Initialize(conf map[string]interface{}, veri
|
|||
}
|
||||
|
||||
// Connection creates a database connection.
|
||||
func (c *mongoDBConnectionProducer) Connection() (interface{}, error) {
|
||||
func (c *mongoDBConnectionProducer) Connection(_ context.Context) (interface{}, error) {
|
||||
if !c.Initialized {
|
||||
return nil, connutil.ErrNotInitialized
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package mongodb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"strings"
|
||||
"time"
|
||||
|
@ -27,6 +28,8 @@ type MongoDB struct {
|
|||
credsutil.CredentialsProducer
|
||||
}
|
||||
|
||||
var _ dbplugin.Database = &MongoDB{}
|
||||
|
||||
// New returns a new MongoDB instance
|
||||
func New() (interface{}, error) {
|
||||
connProducer := &mongoDBConnectionProducer{}
|
||||
|
@ -63,8 +66,8 @@ func (m *MongoDB) Type() (string, error) {
|
|||
return mongoDBTypeName, nil
|
||||
}
|
||||
|
||||
func (m *MongoDB) getConnection() (*mgo.Session, error) {
|
||||
session, err := m.Connection()
|
||||
func (m *MongoDB) getConnection(ctx context.Context) (*mgo.Session, error) {
|
||||
session, err := m.Connection(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -80,7 +83,7 @@ func (m *MongoDB) getConnection() (*mgo.Session, error) {
|
|||
//
|
||||
// JSON Example:
|
||||
// { "db": "admin", "roles": [{ "role": "readWrite" }, {"role": "read", "db": "foo"}] }
|
||||
func (m *MongoDB) CreateUser(statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
|
||||
func (m *MongoDB) CreateUser(ctx context.Context, statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
|
||||
// Grab the lock
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
|
@ -89,7 +92,7 @@ func (m *MongoDB) CreateUser(statements dbplugin.Statements, usernameConfig dbpl
|
|||
return "", "", dbutil.ErrEmptyCreationStatement
|
||||
}
|
||||
|
||||
session, err := m.getConnection()
|
||||
session, err := m.getConnection(ctx)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
@ -133,7 +136,7 @@ func (m *MongoDB) CreateUser(statements dbplugin.Statements, usernameConfig dbpl
|
|||
if err := m.ConnectionProducer.Close(); err != nil {
|
||||
return "", "", errwrap.Wrapf("error closing EOF'd mongo connection: {{err}}", err)
|
||||
}
|
||||
session, err := m.getConnection()
|
||||
session, err := m.getConnection(ctx)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
@ -149,15 +152,15 @@ func (m *MongoDB) CreateUser(statements dbplugin.Statements, usernameConfig dbpl
|
|||
}
|
||||
|
||||
// RenewUser is not supported on MongoDB, so this is a no-op.
|
||||
func (m *MongoDB) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error {
|
||||
func (m *MongoDB) RenewUser(ctx context.Context, statements dbplugin.Statements, username string, expiration time.Time) error {
|
||||
// NOOP
|
||||
return nil
|
||||
}
|
||||
|
||||
// RevokeUser drops the specified user from the authentication databse. If none is provided
|
||||
// in the revocation statement, the default "admin" authentication database will be assumed.
|
||||
func (m *MongoDB) RevokeUser(statements dbplugin.Statements, username string) error {
|
||||
session, err := m.getConnection()
|
||||
func (m *MongoDB) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error {
|
||||
session, err := m.getConnection(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -188,7 +191,7 @@ func (m *MongoDB) RevokeUser(statements dbplugin.Statements, username string) er
|
|||
if err := m.ConnectionProducer.Close(); err != nil {
|
||||
return errwrap.Wrapf("error closing EOF'd mongo connection: {{err}}", err)
|
||||
}
|
||||
session, err := m.getConnection()
|
||||
session, err := m.getConnection(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package mongodb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
|
@ -79,7 +80,7 @@ func TestMongoDB_Initialize(t *testing.T) {
|
|||
db := dbRaw.(*MongoDB)
|
||||
connProducer := db.ConnectionProducer.(*mongoDBConnectionProducer)
|
||||
|
||||
err = db.Initialize(connectionDetails, true)
|
||||
err = db.Initialize(context.Background(), connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -107,7 +108,7 @@ func TestMongoDB_CreateUser(t *testing.T) {
|
|||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
db := dbRaw.(*MongoDB)
|
||||
err = db.Initialize(connectionDetails, true)
|
||||
err = db.Initialize(context.Background(), connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -121,7 +122,7 @@ func TestMongoDB_CreateUser(t *testing.T) {
|
|||
RoleName: "test",
|
||||
}
|
||||
|
||||
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute))
|
||||
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -145,7 +146,7 @@ func TestMongoDB_CreateUser_writeConcern(t *testing.T) {
|
|||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
db := dbRaw.(*MongoDB)
|
||||
err = db.Initialize(connectionDetails, true)
|
||||
err = db.Initialize(context.Background(), connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -159,7 +160,7 @@ func TestMongoDB_CreateUser_writeConcern(t *testing.T) {
|
|||
RoleName: "test",
|
||||
}
|
||||
|
||||
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute))
|
||||
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -182,7 +183,7 @@ func TestMongoDB_RevokeUser(t *testing.T) {
|
|||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
db := dbRaw.(*MongoDB)
|
||||
err = db.Initialize(connectionDetails, true)
|
||||
err = db.Initialize(context.Background(), connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -196,7 +197,7 @@ func TestMongoDB_RevokeUser(t *testing.T) {
|
|||
RoleName: "test",
|
||||
}
|
||||
|
||||
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute))
|
||||
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -206,7 +207,7 @@ func TestMongoDB_RevokeUser(t *testing.T) {
|
|||
}
|
||||
|
||||
// Test default revocation statememt
|
||||
err = db.RevokeUser(statements, username)
|
||||
err = db.RevokeUser(context.Background(), statements, username)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package mssql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
@ -18,6 +19,8 @@ import (
|
|||
|
||||
const msSQLTypeName = "mssql"
|
||||
|
||||
var _ dbplugin.Database = &MSSQL{}
|
||||
|
||||
// MSSQL is an implementation of Database interface
|
||||
type MSSQL struct {
|
||||
connutil.ConnectionProducer
|
||||
|
@ -60,8 +63,8 @@ func (m *MSSQL) Type() (string, error) {
|
|||
return msSQLTypeName, nil
|
||||
}
|
||||
|
||||
func (m *MSSQL) getConnection() (*sql.DB, error) {
|
||||
db, err := m.Connection()
|
||||
func (m *MSSQL) getConnection(ctx context.Context) (*sql.DB, error) {
|
||||
db, err := m.Connection(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -71,13 +74,13 @@ func (m *MSSQL) getConnection() (*sql.DB, error) {
|
|||
|
||||
// CreateUser generates the username/password on the underlying MSSQL secret backend as instructed by
|
||||
// the CreationStatement provided.
|
||||
func (m *MSSQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
|
||||
func (m *MSSQL) CreateUser(ctx context.Context, statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
|
||||
// Grab the lock
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
|
||||
// Get the connection
|
||||
db, err := m.getConnection()
|
||||
db, err := m.getConnection(ctx)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
@ -138,7 +141,7 @@ func (m *MSSQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplug
|
|||
}
|
||||
|
||||
// RenewUser is not supported on MSSQL, so this is a no-op.
|
||||
func (m *MSSQL) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error {
|
||||
func (m *MSSQL) RenewUser(ctx context.Context, statements dbplugin.Statements, username string, expiration time.Time) error {
|
||||
// NOOP
|
||||
return nil
|
||||
}
|
||||
|
@ -146,13 +149,13 @@ func (m *MSSQL) RenewUser(statements dbplugin.Statements, username string, expir
|
|||
// RevokeUser attempts to drop the specified user. It will first attempt to disable login,
|
||||
// then kill pending connections from that user, and finally drop the user and login from the
|
||||
// database instance.
|
||||
func (m *MSSQL) RevokeUser(statements dbplugin.Statements, username string) error {
|
||||
func (m *MSSQL) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error {
|
||||
if statements.RevocationStatements == "" {
|
||||
return m.revokeUserDefault(username)
|
||||
return m.revokeUserDefault(ctx, username)
|
||||
}
|
||||
|
||||
// Get connection
|
||||
db, err := m.getConnection()
|
||||
db, err := m.getConnection(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -191,9 +194,9 @@ func (m *MSSQL) RevokeUser(statements dbplugin.Statements, username string) erro
|
|||
return nil
|
||||
}
|
||||
|
||||
func (m *MSSQL) revokeUserDefault(username string) error {
|
||||
func (m *MSSQL) revokeUserDefault(ctx context.Context, username string) error {
|
||||
// Get connection
|
||||
db, err := m.getConnection()
|
||||
db, err := m.getConnection(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package mssql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"os"
|
||||
|
@ -30,7 +31,7 @@ func TestMSSQL_Initialize(t *testing.T) {
|
|||
dbRaw, _ := New()
|
||||
db := dbRaw.(*MSSQL)
|
||||
|
||||
err := db.Initialize(connectionDetails, true)
|
||||
err := db.Initialize(context.Background(), connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -51,7 +52,7 @@ func TestMSSQL_Initialize(t *testing.T) {
|
|||
"max_open_connections": "5",
|
||||
}
|
||||
|
||||
err = db.Initialize(connectionDetails, true)
|
||||
err = db.Initialize(context.Background(), connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -69,7 +70,7 @@ func TestMSSQL_CreateUser(t *testing.T) {
|
|||
|
||||
dbRaw, _ := New()
|
||||
db := dbRaw.(*MSSQL)
|
||||
err := db.Initialize(connectionDetails, true)
|
||||
err := db.Initialize(context.Background(), connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -80,7 +81,7 @@ func TestMSSQL_CreateUser(t *testing.T) {
|
|||
}
|
||||
|
||||
// Test with no configured Creation Statememt
|
||||
_, _, err = db.CreateUser(dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute))
|
||||
_, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute))
|
||||
if err == nil {
|
||||
t.Fatal("Expected error when no creation statement is provided")
|
||||
}
|
||||
|
@ -89,7 +90,7 @@ func TestMSSQL_CreateUser(t *testing.T) {
|
|||
CreationStatements: testMSSQLRole,
|
||||
}
|
||||
|
||||
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute))
|
||||
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -111,7 +112,7 @@ func TestMSSQL_RevokeUser(t *testing.T) {
|
|||
|
||||
dbRaw, _ := New()
|
||||
db := dbRaw.(*MSSQL)
|
||||
err := db.Initialize(connectionDetails, true)
|
||||
err := db.Initialize(context.Background(), connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -125,7 +126,7 @@ func TestMSSQL_RevokeUser(t *testing.T) {
|
|||
RoleName: "test",
|
||||
}
|
||||
|
||||
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(2*time.Second))
|
||||
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -135,7 +136,7 @@ func TestMSSQL_RevokeUser(t *testing.T) {
|
|||
}
|
||||
|
||||
// Test default revoke statememts
|
||||
err = db.RevokeUser(statements, username)
|
||||
err = db.RevokeUser(context.Background(), statements, username)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -144,7 +145,7 @@ func TestMSSQL_RevokeUser(t *testing.T) {
|
|||
t.Fatal("Credentials were not revoked")
|
||||
}
|
||||
|
||||
username, password, err = db.CreateUser(statements, usernameConfig, time.Now().Add(2*time.Second))
|
||||
username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -155,7 +156,7 @@ func TestMSSQL_RevokeUser(t *testing.T) {
|
|||
|
||||
// Test custom revoke statememt
|
||||
statements.RevocationStatements = testMSSQLDrop
|
||||
err = db.RevokeUser(statements, username)
|
||||
err = db.RevokeUser(context.Background(), statements, username)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"strings"
|
||||
"time"
|
||||
|
@ -30,6 +31,8 @@ var (
|
|||
LegacyUsernameLen int = 16
|
||||
)
|
||||
|
||||
var _ dbplugin.Database = &MySQL{}
|
||||
|
||||
type MySQL struct {
|
||||
connutil.ConnectionProducer
|
||||
credsutil.CredentialsProducer
|
||||
|
@ -88,8 +91,8 @@ func (m *MySQL) Type() (string, error) {
|
|||
return mySQLTypeName, nil
|
||||
}
|
||||
|
||||
func (m *MySQL) getConnection() (*sql.DB, error) {
|
||||
db, err := m.Connection()
|
||||
func (m *MySQL) getConnection(ctx context.Context) (*sql.DB, error) {
|
||||
db, err := m.Connection(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -97,13 +100,13 @@ func (m *MySQL) getConnection() (*sql.DB, error) {
|
|||
return db.(*sql.DB), nil
|
||||
}
|
||||
|
||||
func (m *MySQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
|
||||
func (m *MySQL) CreateUser(ctx context.Context, statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
|
||||
// Grab the lock
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
|
||||
// Get the connection
|
||||
db, err := m.getConnection()
|
||||
db, err := m.getConnection(ctx)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
@ -128,7 +131,7 @@ func (m *MySQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplug
|
|||
}
|
||||
|
||||
// Start a transaction
|
||||
tx, err := db.Begin()
|
||||
tx, err := db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
@ -146,7 +149,7 @@ func (m *MySQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplug
|
|||
"expiration": expirationStr,
|
||||
})
|
||||
|
||||
stmt, err := tx.Prepare(query)
|
||||
stmt, err := tx.PrepareContext(ctx, query)
|
||||
if err != nil {
|
||||
// If the error code we get back is Error 1295: This command is not
|
||||
// supported in the prepared statement protocol yet, we will execute
|
||||
|
@ -155,7 +158,7 @@ func (m *MySQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplug
|
|||
// prepare supported commands. If there is no error when running we
|
||||
// will continue to the next statement.
|
||||
if e, ok := err.(*stdmysql.MySQLError); ok && e.Number == 1295 {
|
||||
_, err = tx.Exec(query)
|
||||
_, err = tx.ExecContext(ctx, query)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
@ -165,7 +168,7 @@ func (m *MySQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplug
|
|||
return "", "", err
|
||||
}
|
||||
defer stmt.Close()
|
||||
if _, err := stmt.Exec(); err != nil {
|
||||
if _, err := stmt.ExecContext(ctx); err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
}
|
||||
|
@ -179,17 +182,17 @@ func (m *MySQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplug
|
|||
}
|
||||
|
||||
// NOOP
|
||||
func (m *MySQL) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error {
|
||||
func (m *MySQL) RenewUser(ctx context.Context, statements dbplugin.Statements, username string, expiration time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MySQL) RevokeUser(statements dbplugin.Statements, username string) error {
|
||||
func (m *MySQL) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error {
|
||||
// Grab the read lock
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
|
||||
// Get the connection
|
||||
db, err := m.getConnection()
|
||||
db, err := m.getConnection(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -201,7 +204,7 @@ func (m *MySQL) RevokeUser(statements dbplugin.Statements, username string) erro
|
|||
}
|
||||
|
||||
// Start a transaction
|
||||
tx, err := db.Begin()
|
||||
tx, err := db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -217,7 +220,7 @@ func (m *MySQL) RevokeUser(statements dbplugin.Statements, username string) erro
|
|||
// 1295: This command is not supported in the prepared statement protocol yet
|
||||
// Reference https://mariadb.com/kb/en/mariadb/prepare-statement/
|
||||
query = strings.Replace(query, "{{name}}", username, -1)
|
||||
_, err = tx.Exec(query)
|
||||
_, err = tx.ExecContext(ctx, query)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"os"
|
||||
|
@ -108,7 +109,7 @@ func TestMySQL_Initialize(t *testing.T) {
|
|||
db := dbRaw.(*MySQL)
|
||||
connProducer := db.ConnectionProducer.(*connutil.SQLConnectionProducer)
|
||||
|
||||
err := db.Initialize(connectionDetails, true)
|
||||
err := db.Initialize(context.Background(), connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -128,7 +129,7 @@ func TestMySQL_Initialize(t *testing.T) {
|
|||
"max_open_connections": "5",
|
||||
}
|
||||
|
||||
err = db.Initialize(connectionDetails, true)
|
||||
err = db.Initialize(context.Background(), connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -146,7 +147,7 @@ func TestMySQL_CreateUser(t *testing.T) {
|
|||
dbRaw, _ := f()
|
||||
db := dbRaw.(*MySQL)
|
||||
|
||||
err := db.Initialize(connectionDetails, true)
|
||||
err := db.Initialize(context.Background(), connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -157,7 +158,7 @@ func TestMySQL_CreateUser(t *testing.T) {
|
|||
}
|
||||
|
||||
// Test with no configured Creation Statememt
|
||||
_, _, err = db.CreateUser(dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute))
|
||||
_, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute))
|
||||
if err == nil {
|
||||
t.Fatal("Expected error when no creation statement is provided")
|
||||
}
|
||||
|
@ -166,7 +167,7 @@ func TestMySQL_CreateUser(t *testing.T) {
|
|||
CreationStatements: testMySQLRoleWildCard,
|
||||
}
|
||||
|
||||
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute))
|
||||
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -176,7 +177,7 @@ func TestMySQL_CreateUser(t *testing.T) {
|
|||
}
|
||||
|
||||
// Test a second time to make sure usernames don't collide
|
||||
username, password, err = db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute))
|
||||
username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -188,7 +189,7 @@ func TestMySQL_CreateUser(t *testing.T) {
|
|||
// Test with a manualy prepare statement
|
||||
statements.CreationStatements = testMySQLRolePreparedStmt
|
||||
|
||||
username, password, err = db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute))
|
||||
username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -211,7 +212,7 @@ func TestMySQL_CreateUser_Legacy(t *testing.T) {
|
|||
dbRaw, _ := f()
|
||||
db := dbRaw.(*MySQL)
|
||||
|
||||
err := db.Initialize(connectionDetails, true)
|
||||
err := db.Initialize(context.Background(), connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -222,7 +223,7 @@ func TestMySQL_CreateUser_Legacy(t *testing.T) {
|
|||
}
|
||||
|
||||
// Test with no configured Creation Statememt
|
||||
_, _, err = db.CreateUser(dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute))
|
||||
_, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute))
|
||||
if err == nil {
|
||||
t.Fatal("Expected error when no creation statement is provided")
|
||||
}
|
||||
|
@ -231,7 +232,7 @@ func TestMySQL_CreateUser_Legacy(t *testing.T) {
|
|||
CreationStatements: testMySQLRoleWildCard,
|
||||
}
|
||||
|
||||
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute))
|
||||
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -241,7 +242,7 @@ func TestMySQL_CreateUser_Legacy(t *testing.T) {
|
|||
}
|
||||
|
||||
// Test a second time to make sure usernames don't collide
|
||||
username, password, err = db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute))
|
||||
username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -263,7 +264,7 @@ func TestMySQL_RevokeUser(t *testing.T) {
|
|||
dbRaw, _ := f()
|
||||
db := dbRaw.(*MySQL)
|
||||
|
||||
err := db.Initialize(connectionDetails, true)
|
||||
err := db.Initialize(context.Background(), connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -277,7 +278,7 @@ func TestMySQL_RevokeUser(t *testing.T) {
|
|||
RoleName: "test",
|
||||
}
|
||||
|
||||
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute))
|
||||
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -287,7 +288,7 @@ func TestMySQL_RevokeUser(t *testing.T) {
|
|||
}
|
||||
|
||||
// Test default revoke statememts
|
||||
err = db.RevokeUser(statements, username)
|
||||
err = db.RevokeUser(context.Background(), statements, username)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -297,7 +298,7 @@ func TestMySQL_RevokeUser(t *testing.T) {
|
|||
}
|
||||
|
||||
statements.CreationStatements = testMySQLRoleWildCard
|
||||
username, password, err = db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute))
|
||||
username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -308,7 +309,7 @@ func TestMySQL_RevokeUser(t *testing.T) {
|
|||
|
||||
// Test custom revoke statements
|
||||
statements.RevocationStatements = testMySQLRevocationSQL
|
||||
err = db.RevokeUser(statements, username)
|
||||
err = db.RevokeUser(context.Background(), statements, username)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package postgresql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
@ -24,6 +25,8 @@ ALTER ROLE "{{name}}" VALID UNTIL '{{expiration}}';
|
|||
`
|
||||
)
|
||||
|
||||
var _ dbplugin.Database = &PostgreSQL{}
|
||||
|
||||
// New implements builtinplugins.BuiltinFactory
|
||||
func New() (interface{}, error) {
|
||||
connProducer := &connutil.SQLConnectionProducer{}
|
||||
|
@ -65,8 +68,8 @@ func (p *PostgreSQL) Type() (string, error) {
|
|||
return postgreSQLTypeName, nil
|
||||
}
|
||||
|
||||
func (p *PostgreSQL) getConnection() (*sql.DB, error) {
|
||||
db, err := p.Connection()
|
||||
func (p *PostgreSQL) getConnection(ctx context.Context) (*sql.DB, error) {
|
||||
db, err := p.Connection(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -74,7 +77,7 @@ func (p *PostgreSQL) getConnection() (*sql.DB, error) {
|
|||
return db.(*sql.DB), nil
|
||||
}
|
||||
|
||||
func (p *PostgreSQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
|
||||
func (p *PostgreSQL) CreateUser(ctx context.Context, statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
|
||||
if statements.CreationStatements == "" {
|
||||
return "", "", dbutil.ErrEmptyCreationStatement
|
||||
}
|
||||
|
@ -99,7 +102,7 @@ func (p *PostgreSQL) CreateUser(statements dbplugin.Statements, usernameConfig d
|
|||
}
|
||||
|
||||
// Get the connection
|
||||
db, err := p.getConnection()
|
||||
db, err := p.getConnection(ctx)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
|
||||
|
@ -148,7 +151,7 @@ func (p *PostgreSQL) CreateUser(statements dbplugin.Statements, usernameConfig d
|
|||
return username, password, nil
|
||||
}
|
||||
|
||||
func (p *PostgreSQL) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error {
|
||||
func (p *PostgreSQL) RenewUser(ctx context.Context, statements dbplugin.Statements, username string, expiration time.Time) error {
|
||||
p.Lock()
|
||||
defer p.Unlock()
|
||||
|
||||
|
@ -157,7 +160,7 @@ func (p *PostgreSQL) RenewUser(statements dbplugin.Statements, username string,
|
|||
renewStmts = defaultPostgresRenewSQL
|
||||
}
|
||||
|
||||
db, err := p.getConnection()
|
||||
db, err := p.getConnection(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -201,20 +204,20 @@ func (p *PostgreSQL) RenewUser(statements dbplugin.Statements, username string,
|
|||
return nil
|
||||
}
|
||||
|
||||
func (p *PostgreSQL) RevokeUser(statements dbplugin.Statements, username string) error {
|
||||
func (p *PostgreSQL) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error {
|
||||
// Grab the lock
|
||||
p.Lock()
|
||||
defer p.Unlock()
|
||||
|
||||
if statements.RevocationStatements == "" {
|
||||
return p.defaultRevokeUser(username)
|
||||
return p.defaultRevokeUser(ctx, username)
|
||||
}
|
||||
|
||||
return p.customRevokeUser(username, statements.RevocationStatements)
|
||||
return p.customRevokeUser(ctx, username, statements.RevocationStatements)
|
||||
}
|
||||
|
||||
func (p *PostgreSQL) customRevokeUser(username, revocationStmts string) error {
|
||||
db, err := p.getConnection()
|
||||
func (p *PostgreSQL) customRevokeUser(ctx context.Context, username, revocationStmts string) error {
|
||||
db, err := p.getConnection(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -253,8 +256,8 @@ func (p *PostgreSQL) customRevokeUser(username, revocationStmts string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (p *PostgreSQL) defaultRevokeUser(username string) error {
|
||||
db, err := p.getConnection()
|
||||
func (p *PostgreSQL) defaultRevokeUser(ctx context.Context, username string) error {
|
||||
db, err := p.getConnection(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package postgresql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"os"
|
||||
|
@ -72,7 +73,7 @@ func TestPostgreSQL_Initialize(t *testing.T) {
|
|||
|
||||
connProducer := db.ConnectionProducer.(*connutil.SQLConnectionProducer)
|
||||
|
||||
err := db.Initialize(connectionDetails, true)
|
||||
err := db.Initialize(context.Background(), connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -92,7 +93,7 @@ func TestPostgreSQL_Initialize(t *testing.T) {
|
|||
"max_open_connections": "5",
|
||||
}
|
||||
|
||||
err = db.Initialize(connectionDetails, true)
|
||||
err = db.Initialize(context.Background(), connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -109,7 +110,7 @@ func TestPostgreSQL_CreateUser(t *testing.T) {
|
|||
|
||||
dbRaw, _ := New()
|
||||
db := dbRaw.(*PostgreSQL)
|
||||
err := db.Initialize(connectionDetails, true)
|
||||
err := db.Initialize(context.Background(), connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -120,7 +121,7 @@ func TestPostgreSQL_CreateUser(t *testing.T) {
|
|||
}
|
||||
|
||||
// Test with no configured Creation Statememt
|
||||
_, _, err = db.CreateUser(dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute))
|
||||
_, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute))
|
||||
if err == nil {
|
||||
t.Fatal("Expected error when no creation statement is provided")
|
||||
}
|
||||
|
@ -129,7 +130,7 @@ func TestPostgreSQL_CreateUser(t *testing.T) {
|
|||
CreationStatements: testPostgresRole,
|
||||
}
|
||||
|
||||
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute))
|
||||
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -139,7 +140,7 @@ func TestPostgreSQL_CreateUser(t *testing.T) {
|
|||
}
|
||||
|
||||
statements.CreationStatements = testPostgresReadOnlyRole
|
||||
username, password, err = db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute))
|
||||
username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -162,7 +163,7 @@ func TestPostgreSQL_RenewUser(t *testing.T) {
|
|||
|
||||
dbRaw, _ := New()
|
||||
db := dbRaw.(*PostgreSQL)
|
||||
err := db.Initialize(connectionDetails, true)
|
||||
err := db.Initialize(context.Background(), connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -176,7 +177,7 @@ func TestPostgreSQL_RenewUser(t *testing.T) {
|
|||
RoleName: "test",
|
||||
}
|
||||
|
||||
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(2*time.Second))
|
||||
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -185,7 +186,7 @@ func TestPostgreSQL_RenewUser(t *testing.T) {
|
|||
t.Fatalf("Could not connect with new credentials: %s", err)
|
||||
}
|
||||
|
||||
err = db.RenewUser(statements, username, time.Now().Add(time.Minute))
|
||||
err = db.RenewUser(context.Background(), statements, username, time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -197,7 +198,7 @@ func TestPostgreSQL_RenewUser(t *testing.T) {
|
|||
t.Fatalf("Could not connect with new credentials: %s", err)
|
||||
}
|
||||
statements.RenewStatements = defaultPostgresRenewSQL
|
||||
username, password, err = db.CreateUser(statements, usernameConfig, time.Now().Add(2*time.Second))
|
||||
username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -206,7 +207,7 @@ func TestPostgreSQL_RenewUser(t *testing.T) {
|
|||
t.Fatalf("Could not connect with new credentials: %s", err)
|
||||
}
|
||||
|
||||
err = db.RenewUser(statements, username, time.Now().Add(time.Minute))
|
||||
err = db.RenewUser(context.Background(), statements, username, time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -230,7 +231,7 @@ func TestPostgreSQL_RevokeUser(t *testing.T) {
|
|||
|
||||
dbRaw, _ := New()
|
||||
db := dbRaw.(*PostgreSQL)
|
||||
err := db.Initialize(connectionDetails, true)
|
||||
err := db.Initialize(context.Background(), connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -244,7 +245,7 @@ func TestPostgreSQL_RevokeUser(t *testing.T) {
|
|||
RoleName: "test",
|
||||
}
|
||||
|
||||
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(2*time.Second))
|
||||
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -254,7 +255,7 @@ func TestPostgreSQL_RevokeUser(t *testing.T) {
|
|||
}
|
||||
|
||||
// Test default revoke statememts
|
||||
err = db.RevokeUser(statements, username)
|
||||
err = db.RevokeUser(context.Background(), statements, username)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -263,7 +264,7 @@ func TestPostgreSQL_RevokeUser(t *testing.T) {
|
|||
t.Fatal("Credentials were not revoked")
|
||||
}
|
||||
|
||||
username, password, err = db.CreateUser(statements, usernameConfig, time.Now().Add(2*time.Second))
|
||||
username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -274,7 +275,7 @@ func TestPostgreSQL_RevokeUser(t *testing.T) {
|
|||
|
||||
// Test custom revoke statements
|
||||
statements.RevocationStatements = defaultPostgresRevocationSQL
|
||||
err = db.RevokeUser(statements, username)
|
||||
err = db.RevokeUser(context.Background(), statements, username)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package connutil
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
)
|
||||
|
@ -14,8 +15,8 @@ var (
|
|||
// connections and is used in all the builtin database types.
|
||||
type ConnectionProducer interface {
|
||||
Close() error
|
||||
Initialize(map[string]interface{}, bool) error
|
||||
Connection() (interface{}, error)
|
||||
Initialize(context.Context, map[string]interface{}, bool) error
|
||||
Connection(context.Context) (interface{}, error)
|
||||
|
||||
sync.Locker
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package connutil
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
@ -25,7 +26,7 @@ type SQLConnectionProducer struct {
|
|||
sync.Mutex
|
||||
}
|
||||
|
||||
func (c *SQLConnectionProducer) Initialize(conf map[string]interface{}, verifyConnection bool) error {
|
||||
func (c *SQLConnectionProducer) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
|
||||
|
@ -62,11 +63,11 @@ func (c *SQLConnectionProducer) Initialize(conf map[string]interface{}, verifyCo
|
|||
c.Initialized = true
|
||||
|
||||
if verifyConnection {
|
||||
if _, err := c.Connection(); err != nil {
|
||||
if _, err := c.Connection(ctx); err != nil {
|
||||
return fmt.Errorf("error verifying connection: %s", err)
|
||||
}
|
||||
|
||||
if err := c.db.Ping(); err != nil {
|
||||
if err := c.db.PingContext(ctx); err != nil {
|
||||
return fmt.Errorf("error verifying connection: %s", err)
|
||||
}
|
||||
}
|
||||
|
@ -74,14 +75,14 @@ func (c *SQLConnectionProducer) Initialize(conf map[string]interface{}, verifyCo
|
|||
return nil
|
||||
}
|
||||
|
||||
func (c *SQLConnectionProducer) Connection() (interface{}, error) {
|
||||
func (c *SQLConnectionProducer) Connection(ctx context.Context) (interface{}, error) {
|
||||
if !c.Initialized {
|
||||
return nil, ErrNotInitialized
|
||||
}
|
||||
|
||||
// If we already have a DB, test it and return
|
||||
if c.db != nil {
|
||||
if err := c.db.Ping(); err == nil {
|
||||
if err := c.db.PingContext(ctx); err == nil {
|
||||
return c.db, nil
|
||||
}
|
||||
// If the ping was unsuccessful, close it and ignore errors as we'll be
|
||||
|
|
|
@ -1017,8 +1017,7 @@ func (m *ExpirationManager) revokeEntry(le *leaseEntry) error {
|
|||
}
|
||||
|
||||
// Handle standard revocation via backends
|
||||
resp, err := m.router.Route(logical.RevokeRequest(
|
||||
le.Path, le.Secret, le.Data))
|
||||
resp, err := m.router.Route(logical.RevokeRequest(le.Path, le.Secret, le.Data))
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
return fmt.Errorf("failed to revoke entry: resp:%#v err:%s", resp, err)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue