Database gRPC plugins (#3666)

* Start work on context aware backends

* Start work on moving the database plugins to gRPC in order to pass context

* Add context to builtin database plugins

* use byte slice instead of string

* Context all the things

* Move proto messages to the dbplugin package

* Add a grpc mechanism for running backend plugins

* Serve the GRPC plugin

* Add backwards compatibility to the database plugins

* Remove backend plugin changes

* Remove backend plugin changes

* Cleanup the transport implementations

* If grpc connection is in an unexpected state restart the plugin

* Fix tests

* Fix tests

* Remove context from the request object, replace it with context.TODO

* Add a test to verify netRPC plugins still work

* Remove unused mapstructure call

* Code review fixes

* Code review fixes

* Code review fixes
This commit is contained in:
Brian Kassouf 2017-12-14 14:03:11 -08:00 committed by GitHub
parent 829f2c38dc
commit afe53eb862
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
35 changed files with 1425 additions and 390 deletions

View File

@ -84,6 +84,7 @@ proto:
protoc -I helper/forwarding -I vault -I ../../.. helper/forwarding/types.proto --go_out=plugins=grpc:helper/forwarding
protoc -I physical physical/types.proto --go_out=plugins=grpc:physical
protoc -I helper/identity -I ../../.. helper/identity/types.proto --go_out=plugins=grpc:helper/identity
protoc builtin/logical/database/dbplugin/*.proto --go_out=plugins=grpc:.
sed -i -e 's/Idp/IDP/' -e 's/Url/URL/' -e 's/Id/ID/' -e 's/EntityId/EntityID/' -e 's/Api/API/' -e 's/Qr/QR/' -e 's/protobuf:"/sentinel:"" protobuf:"/' helper/identity/types.pb.go helper/storagepacker/types.pb.go
sed -i -e 's/Iv/IV/' -e 's/Hmac/HMAC/' physical/types.pb.go

View File

@ -1,6 +1,7 @@
package database
import (
"context"
"fmt"
"net/rpc"
"strings"
@ -87,7 +88,7 @@ func (b *databaseBackend) getDBObj(name string) (dbplugin.Database, bool) {
// This function creates a new db object from the stored configuration and
// caches it in the connections map. The caller of this function needs to hold
// the backend's write lock
func (b *databaseBackend) createDBObj(s logical.Storage, name string) (dbplugin.Database, error) {
func (b *databaseBackend) createDBObj(ctx context.Context, s logical.Storage, name string) (dbplugin.Database, error) {
db, ok := b.connections[name]
if ok {
return db, nil
@ -103,7 +104,7 @@ func (b *databaseBackend) createDBObj(s logical.Storage, name string) (dbplugin.
return nil, err
}
err = db.Initialize(config.ConnectionDetails, true)
err = db.Initialize(ctx, config.ConnectionDetails, true)
if err != nil {
return nil, err
}
@ -170,7 +171,8 @@ func (b *databaseBackend) clearConnection(name string) {
func (b *databaseBackend) closeIfShutdown(name string, err error) {
// Plugin has shutdown, close it so next call can reconnect.
if err == rpc.ErrShutdown {
switch err {
case rpc.ErrShutdown, dbplugin.ErrPluginShutdown:
b.Lock()
b.clearConnection(name)
b.Unlock()

View File

@ -488,9 +488,11 @@ func TestBackend_roleCrud(t *testing.T) {
RevocationStatements: defaultRevocationSQL,
}
var actual dbplugin.Statements
if err := mapstructure.Decode(resp.Data, &actual); err != nil {
t.Fatal(err)
actual := dbplugin.Statements{
CreationStatements: resp.Data["creation_statements"].(string),
RevocationStatements: resp.Data["revocation_statements"].(string),
RollbackStatements: resp.Data["rollback_statements"].(string),
RenewStatements: resp.Data["renew_statements"].(string),
}
if !reflect.DeepEqual(expected, actual) {

View File

@ -1,10 +1,8 @@
package dbplugin
import (
"fmt"
"net/rpc"
"errors"
"sync"
"time"
"github.com/hashicorp/go-plugin"
"github.com/hashicorp/vault/helper/pluginutil"
@ -17,11 +15,11 @@ type DatabasePluginClient struct {
client *plugin.Client
sync.Mutex
*databasePluginRPCClient
Database
}
func (dc *DatabasePluginClient) Close() error {
err := dc.databasePluginRPCClient.Close()
err := dc.Database.Close()
dc.client.Kill()
return err
@ -55,79 +53,20 @@ func newPluginClient(sys pluginutil.RunnerUtil, pluginRunner *pluginutil.PluginR
// We should have a database type now. This feels like a normal interface
// implementation but is in fact over an RPC connection.
databaseRPC := raw.(*databasePluginRPCClient)
var db Database
switch raw.(type) {
case *gRPCClient:
db = raw.(*gRPCClient)
case *databasePluginRPCClient:
logger.Warn("database: plugin is using deprecated net RPC transport, recompile plugin to upgrade to gRPC", "plugin", pluginRunner.Name)
db = raw.(*databasePluginRPCClient)
default:
return nil, errors.New("unsupported client type")
}
// Wrap RPC implimentation in DatabasePluginClient
return &DatabasePluginClient{
client: client,
databasePluginRPCClient: databaseRPC,
client: client,
Database: db,
}, nil
}
// ---- RPC client domain ----
// databasePluginRPCClient implements Database and is used on the client to
// make RPC calls to a plugin.
type databasePluginRPCClient struct {
client *rpc.Client
}
func (dr *databasePluginRPCClient) Type() (string, error) {
var dbType string
err := dr.client.Call("Plugin.Type", struct{}{}, &dbType)
return fmt.Sprintf("plugin-%s", dbType), err
}
func (dr *databasePluginRPCClient) CreateUser(statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) {
req := CreateUserRequest{
Statements: statements,
UsernameConfig: usernameConfig,
Expiration: expiration,
}
var resp CreateUserResponse
err = dr.client.Call("Plugin.CreateUser", req, &resp)
return resp.Username, resp.Password, err
}
func (dr *databasePluginRPCClient) RenewUser(statements Statements, username string, expiration time.Time) error {
req := RenewUserRequest{
Statements: statements,
Username: username,
Expiration: expiration,
}
err := dr.client.Call("Plugin.RenewUser", req, &struct{}{})
return err
}
func (dr *databasePluginRPCClient) RevokeUser(statements Statements, username string) error {
req := RevokeUserRequest{
Statements: statements,
Username: username,
}
err := dr.client.Call("Plugin.RevokeUser", req, &struct{}{})
return err
}
func (dr *databasePluginRPCClient) Initialize(conf map[string]interface{}, verifyConnection bool) error {
req := InitializeRequest{
Config: conf,
VerifyConnection: verifyConnection,
}
err := dr.client.Call("Plugin.Initialize", req, &struct{}{})
return err
}
func (dr *databasePluginRPCClient) Close() error {
err := dr.client.Call("Plugin.Close", struct{}{}, &struct{}{})
return err
}

View File

@ -0,0 +1,556 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// source: builtin/logical/database/dbplugin/database.proto
/*
Package dbplugin is a generated protocol buffer package.
It is generated from these files:
builtin/logical/database/dbplugin/database.proto
It has these top-level messages:
InitializeRequest
CreateUserRequest
RenewUserRequest
RevokeUserRequest
Statements
UsernameConfig
CreateUserResponse
TypeResponse
Empty
*/
package dbplugin
import proto "github.com/golang/protobuf/proto"
import fmt "fmt"
import math "math"
import google_protobuf "github.com/golang/protobuf/ptypes/timestamp"
import (
context "golang.org/x/net/context"
grpc "google.golang.org/grpc"
)
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal
var _ = fmt.Errorf
var _ = math.Inf
// This is a compile-time assertion to ensure that this generated file
// is compatible with the proto package it is being compiled against.
// A compilation error at this line likely means your copy of the
// proto package needs to be updated.
const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package
type InitializeRequest struct {
Config []byte `protobuf:"bytes,1,opt,name=config,proto3" json:"config,omitempty"`
VerifyConnection bool `protobuf:"varint,2,opt,name=verify_connection,json=verifyConnection" json:"verify_connection,omitempty"`
}
func (m *InitializeRequest) Reset() { *m = InitializeRequest{} }
func (m *InitializeRequest) String() string { return proto.CompactTextString(m) }
func (*InitializeRequest) ProtoMessage() {}
func (*InitializeRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{0} }
func (m *InitializeRequest) GetConfig() []byte {
if m != nil {
return m.Config
}
return nil
}
func (m *InitializeRequest) GetVerifyConnection() bool {
if m != nil {
return m.VerifyConnection
}
return false
}
type CreateUserRequest struct {
Statements *Statements `protobuf:"bytes,1,opt,name=statements" json:"statements,omitempty"`
UsernameConfig *UsernameConfig `protobuf:"bytes,2,opt,name=username_config,json=usernameConfig" json:"username_config,omitempty"`
Expiration *google_protobuf.Timestamp `protobuf:"bytes,3,opt,name=expiration" json:"expiration,omitempty"`
}
func (m *CreateUserRequest) Reset() { *m = CreateUserRequest{} }
func (m *CreateUserRequest) String() string { return proto.CompactTextString(m) }
func (*CreateUserRequest) ProtoMessage() {}
func (*CreateUserRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{1} }
func (m *CreateUserRequest) GetStatements() *Statements {
if m != nil {
return m.Statements
}
return nil
}
func (m *CreateUserRequest) GetUsernameConfig() *UsernameConfig {
if m != nil {
return m.UsernameConfig
}
return nil
}
func (m *CreateUserRequest) GetExpiration() *google_protobuf.Timestamp {
if m != nil {
return m.Expiration
}
return nil
}
type RenewUserRequest struct {
Statements *Statements `protobuf:"bytes,1,opt,name=statements" json:"statements,omitempty"`
Username string `protobuf:"bytes,2,opt,name=username" json:"username,omitempty"`
Expiration *google_protobuf.Timestamp `protobuf:"bytes,3,opt,name=expiration" json:"expiration,omitempty"`
}
func (m *RenewUserRequest) Reset() { *m = RenewUserRequest{} }
func (m *RenewUserRequest) String() string { return proto.CompactTextString(m) }
func (*RenewUserRequest) ProtoMessage() {}
func (*RenewUserRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{2} }
func (m *RenewUserRequest) GetStatements() *Statements {
if m != nil {
return m.Statements
}
return nil
}
func (m *RenewUserRequest) GetUsername() string {
if m != nil {
return m.Username
}
return ""
}
func (m *RenewUserRequest) GetExpiration() *google_protobuf.Timestamp {
if m != nil {
return m.Expiration
}
return nil
}
type RevokeUserRequest struct {
Statements *Statements `protobuf:"bytes,1,opt,name=statements" json:"statements,omitempty"`
Username string `protobuf:"bytes,2,opt,name=username" json:"username,omitempty"`
}
func (m *RevokeUserRequest) Reset() { *m = RevokeUserRequest{} }
func (m *RevokeUserRequest) String() string { return proto.CompactTextString(m) }
func (*RevokeUserRequest) ProtoMessage() {}
func (*RevokeUserRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{3} }
func (m *RevokeUserRequest) GetStatements() *Statements {
if m != nil {
return m.Statements
}
return nil
}
func (m *RevokeUserRequest) GetUsername() string {
if m != nil {
return m.Username
}
return ""
}
type Statements struct {
CreationStatements string `protobuf:"bytes,1,opt,name=creation_statements,json=creationStatements" json:"creation_statements,omitempty"`
RevocationStatements string `protobuf:"bytes,2,opt,name=revocation_statements,json=revocationStatements" json:"revocation_statements,omitempty"`
RollbackStatements string `protobuf:"bytes,3,opt,name=rollback_statements,json=rollbackStatements" json:"rollback_statements,omitempty"`
RenewStatements string `protobuf:"bytes,4,opt,name=renew_statements,json=renewStatements" json:"renew_statements,omitempty"`
}
func (m *Statements) Reset() { *m = Statements{} }
func (m *Statements) String() string { return proto.CompactTextString(m) }
func (*Statements) ProtoMessage() {}
func (*Statements) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{4} }
func (m *Statements) GetCreationStatements() string {
if m != nil {
return m.CreationStatements
}
return ""
}
func (m *Statements) GetRevocationStatements() string {
if m != nil {
return m.RevocationStatements
}
return ""
}
func (m *Statements) GetRollbackStatements() string {
if m != nil {
return m.RollbackStatements
}
return ""
}
func (m *Statements) GetRenewStatements() string {
if m != nil {
return m.RenewStatements
}
return ""
}
type UsernameConfig struct {
DisplayName string `protobuf:"bytes,1,opt,name=DisplayName" json:"DisplayName,omitempty"`
RoleName string `protobuf:"bytes,2,opt,name=RoleName" json:"RoleName,omitempty"`
}
func (m *UsernameConfig) Reset() { *m = UsernameConfig{} }
func (m *UsernameConfig) String() string { return proto.CompactTextString(m) }
func (*UsernameConfig) ProtoMessage() {}
func (*UsernameConfig) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{5} }
func (m *UsernameConfig) GetDisplayName() string {
if m != nil {
return m.DisplayName
}
return ""
}
func (m *UsernameConfig) GetRoleName() string {
if m != nil {
return m.RoleName
}
return ""
}
type CreateUserResponse struct {
Username string `protobuf:"bytes,1,opt,name=username" json:"username,omitempty"`
Password string `protobuf:"bytes,2,opt,name=password" json:"password,omitempty"`
}
func (m *CreateUserResponse) Reset() { *m = CreateUserResponse{} }
func (m *CreateUserResponse) String() string { return proto.CompactTextString(m) }
func (*CreateUserResponse) ProtoMessage() {}
func (*CreateUserResponse) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{6} }
func (m *CreateUserResponse) GetUsername() string {
if m != nil {
return m.Username
}
return ""
}
func (m *CreateUserResponse) GetPassword() string {
if m != nil {
return m.Password
}
return ""
}
type TypeResponse struct {
Type string `protobuf:"bytes,1,opt,name=type" json:"type,omitempty"`
}
func (m *TypeResponse) Reset() { *m = TypeResponse{} }
func (m *TypeResponse) String() string { return proto.CompactTextString(m) }
func (*TypeResponse) ProtoMessage() {}
func (*TypeResponse) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{7} }
func (m *TypeResponse) GetType() string {
if m != nil {
return m.Type
}
return ""
}
type Empty struct {
}
func (m *Empty) Reset() { *m = Empty{} }
func (m *Empty) String() string { return proto.CompactTextString(m) }
func (*Empty) ProtoMessage() {}
func (*Empty) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{8} }
func init() {
proto.RegisterType((*InitializeRequest)(nil), "dbplugin.InitializeRequest")
proto.RegisterType((*CreateUserRequest)(nil), "dbplugin.CreateUserRequest")
proto.RegisterType((*RenewUserRequest)(nil), "dbplugin.RenewUserRequest")
proto.RegisterType((*RevokeUserRequest)(nil), "dbplugin.RevokeUserRequest")
proto.RegisterType((*Statements)(nil), "dbplugin.Statements")
proto.RegisterType((*UsernameConfig)(nil), "dbplugin.UsernameConfig")
proto.RegisterType((*CreateUserResponse)(nil), "dbplugin.CreateUserResponse")
proto.RegisterType((*TypeResponse)(nil), "dbplugin.TypeResponse")
proto.RegisterType((*Empty)(nil), "dbplugin.Empty")
}
// Reference imports to suppress errors if they are not otherwise used.
var _ context.Context
var _ grpc.ClientConn
// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
const _ = grpc.SupportPackageIsVersion4
// Client API for Database service
type DatabaseClient interface {
Type(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*TypeResponse, error)
CreateUser(ctx context.Context, in *CreateUserRequest, opts ...grpc.CallOption) (*CreateUserResponse, error)
RenewUser(ctx context.Context, in *RenewUserRequest, opts ...grpc.CallOption) (*Empty, error)
RevokeUser(ctx context.Context, in *RevokeUserRequest, opts ...grpc.CallOption) (*Empty, error)
Initialize(ctx context.Context, in *InitializeRequest, opts ...grpc.CallOption) (*Empty, error)
Close(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*Empty, error)
}
type databaseClient struct {
cc *grpc.ClientConn
}
func NewDatabaseClient(cc *grpc.ClientConn) DatabaseClient {
return &databaseClient{cc}
}
func (c *databaseClient) Type(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*TypeResponse, error) {
out := new(TypeResponse)
err := grpc.Invoke(ctx, "/dbplugin.Database/Type", in, out, c.cc, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *databaseClient) CreateUser(ctx context.Context, in *CreateUserRequest, opts ...grpc.CallOption) (*CreateUserResponse, error) {
out := new(CreateUserResponse)
err := grpc.Invoke(ctx, "/dbplugin.Database/CreateUser", in, out, c.cc, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *databaseClient) RenewUser(ctx context.Context, in *RenewUserRequest, opts ...grpc.CallOption) (*Empty, error) {
out := new(Empty)
err := grpc.Invoke(ctx, "/dbplugin.Database/RenewUser", in, out, c.cc, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *databaseClient) RevokeUser(ctx context.Context, in *RevokeUserRequest, opts ...grpc.CallOption) (*Empty, error) {
out := new(Empty)
err := grpc.Invoke(ctx, "/dbplugin.Database/RevokeUser", in, out, c.cc, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *databaseClient) Initialize(ctx context.Context, in *InitializeRequest, opts ...grpc.CallOption) (*Empty, error) {
out := new(Empty)
err := grpc.Invoke(ctx, "/dbplugin.Database/Initialize", in, out, c.cc, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *databaseClient) Close(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*Empty, error) {
out := new(Empty)
err := grpc.Invoke(ctx, "/dbplugin.Database/Close", in, out, c.cc, opts...)
if err != nil {
return nil, err
}
return out, nil
}
// Server API for Database service
type DatabaseServer interface {
Type(context.Context, *Empty) (*TypeResponse, error)
CreateUser(context.Context, *CreateUserRequest) (*CreateUserResponse, error)
RenewUser(context.Context, *RenewUserRequest) (*Empty, error)
RevokeUser(context.Context, *RevokeUserRequest) (*Empty, error)
Initialize(context.Context, *InitializeRequest) (*Empty, error)
Close(context.Context, *Empty) (*Empty, error)
}
func RegisterDatabaseServer(s *grpc.Server, srv DatabaseServer) {
s.RegisterService(&_Database_serviceDesc, srv)
}
func _Database_Type_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(Empty)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DatabaseServer).Type(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/dbplugin.Database/Type",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DatabaseServer).Type(ctx, req.(*Empty))
}
return interceptor(ctx, in, info, handler)
}
func _Database_CreateUser_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(CreateUserRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DatabaseServer).CreateUser(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/dbplugin.Database/CreateUser",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DatabaseServer).CreateUser(ctx, req.(*CreateUserRequest))
}
return interceptor(ctx, in, info, handler)
}
func _Database_RenewUser_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(RenewUserRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DatabaseServer).RenewUser(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/dbplugin.Database/RenewUser",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DatabaseServer).RenewUser(ctx, req.(*RenewUserRequest))
}
return interceptor(ctx, in, info, handler)
}
func _Database_RevokeUser_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(RevokeUserRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DatabaseServer).RevokeUser(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/dbplugin.Database/RevokeUser",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DatabaseServer).RevokeUser(ctx, req.(*RevokeUserRequest))
}
return interceptor(ctx, in, info, handler)
}
func _Database_Initialize_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(InitializeRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DatabaseServer).Initialize(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/dbplugin.Database/Initialize",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DatabaseServer).Initialize(ctx, req.(*InitializeRequest))
}
return interceptor(ctx, in, info, handler)
}
func _Database_Close_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(Empty)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DatabaseServer).Close(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/dbplugin.Database/Close",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DatabaseServer).Close(ctx, req.(*Empty))
}
return interceptor(ctx, in, info, handler)
}
var _Database_serviceDesc = grpc.ServiceDesc{
ServiceName: "dbplugin.Database",
HandlerType: (*DatabaseServer)(nil),
Methods: []grpc.MethodDesc{
{
MethodName: "Type",
Handler: _Database_Type_Handler,
},
{
MethodName: "CreateUser",
Handler: _Database_CreateUser_Handler,
},
{
MethodName: "RenewUser",
Handler: _Database_RenewUser_Handler,
},
{
MethodName: "RevokeUser",
Handler: _Database_RevokeUser_Handler,
},
{
MethodName: "Initialize",
Handler: _Database_Initialize_Handler,
},
{
MethodName: "Close",
Handler: _Database_Close_Handler,
},
},
Streams: []grpc.StreamDesc{},
Metadata: "builtin/logical/database/dbplugin/database.proto",
}
func init() { proto.RegisterFile("builtin/logical/database/dbplugin/database.proto", fileDescriptor0) }
var fileDescriptor0 = []byte{
// 548 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xb4, 0x54, 0xcf, 0x6e, 0xd3, 0x4e,
0x10, 0x96, 0xdb, 0xb4, 0xbf, 0x64, 0x5a, 0x35, 0xc9, 0xfe, 0x4a, 0x15, 0x19, 0x24, 0x22, 0x9f,
0x5a, 0x21, 0xd9, 0xa8, 0xe5, 0x80, 0xb8, 0xa1, 0x14, 0x21, 0x24, 0x94, 0x83, 0x69, 0x25, 0x6e,
0xd1, 0xda, 0x99, 0x44, 0xab, 0x3a, 0xbb, 0xc6, 0xbb, 0x4e, 0x09, 0x4f, 0xc3, 0xe3, 0x70, 0xe2,
0x1d, 0x78, 0x13, 0xe4, 0x75, 0xd6, 0xbb, 0xf9, 0x73, 0xab, 0xb8, 0x79, 0x66, 0xbe, 0x6f, 0xf6,
0xf3, 0xb7, 0x33, 0x0b, 0xaf, 0x93, 0x92, 0x65, 0x8a, 0xf1, 0x28, 0x13, 0x73, 0x96, 0xd2, 0x2c,
0x9a, 0x52, 0x45, 0x13, 0x2a, 0x31, 0x9a, 0x26, 0x79, 0x56, 0xce, 0x19, 0x6f, 0x32, 0x61, 0x5e,
0x08, 0x25, 0x48, 0xdb, 0x14, 0xfc, 0x97, 0x73, 0x21, 0xe6, 0x19, 0x46, 0x3a, 0x9f, 0x94, 0xb3,
0x48, 0xb1, 0x05, 0x4a, 0x45, 0x17, 0x79, 0x0d, 0x0d, 0xbe, 0x42, 0xff, 0x13, 0x67, 0x8a, 0xd1,
0x8c, 0xfd, 0xc0, 0x18, 0xbf, 0x95, 0x28, 0x15, 0xb9, 0x80, 0xe3, 0x54, 0xf0, 0x19, 0x9b, 0x0f,
0xbc, 0xa1, 0x77, 0x79, 0x1a, 0xaf, 0x23, 0xf2, 0x0a, 0xfa, 0x4b, 0x2c, 0xd8, 0x6c, 0x35, 0x49,
0x05, 0xe7, 0x98, 0x2a, 0x26, 0xf8, 0xe0, 0x60, 0xe8, 0x5d, 0xb6, 0xe3, 0x5e, 0x5d, 0x18, 0x35,
0xf9, 0xe0, 0x97, 0x07, 0xfd, 0x51, 0x81, 0x54, 0xe1, 0xbd, 0xc4, 0xc2, 0xb4, 0x7e, 0x03, 0x20,
0x15, 0x55, 0xb8, 0x40, 0xae, 0xa4, 0x6e, 0x7f, 0x72, 0x7d, 0x1e, 0x1a, 0xbd, 0xe1, 0x97, 0xa6,
0x16, 0x3b, 0x38, 0xf2, 0x1e, 0xba, 0xa5, 0xc4, 0x82, 0xd3, 0x05, 0x4e, 0xd6, 0xca, 0x0e, 0x34,
0x75, 0x60, 0xa9, 0xf7, 0x6b, 0xc0, 0x48, 0xd7, 0xe3, 0xb3, 0x72, 0x23, 0x26, 0xef, 0x00, 0xf0,
0x7b, 0xce, 0x0a, 0xaa, 0x45, 0x1f, 0x6a, 0xb6, 0x1f, 0xd6, 0xf6, 0x84, 0xc6, 0x9e, 0xf0, 0xce,
0xd8, 0x13, 0x3b, 0xe8, 0xe0, 0xa7, 0x07, 0xbd, 0x18, 0x39, 0x3e, 0x3e, 0xfd, 0x4f, 0x7c, 0x68,
0x1b, 0x61, 0xfa, 0x17, 0x3a, 0x71, 0x13, 0x3f, 0x49, 0x22, 0x42, 0x3f, 0xc6, 0xa5, 0x78, 0xc0,
0x7f, 0x2a, 0x31, 0xf8, 0xed, 0x01, 0x58, 0x1a, 0x89, 0xe0, 0xff, 0xb4, 0xba, 0x62, 0x26, 0xf8,
0x64, 0xeb, 0xa4, 0x4e, 0x4c, 0x4c, 0xc9, 0x21, 0xdc, 0xc0, 0xb3, 0x02, 0x97, 0x22, 0xdd, 0xa1,
0xd4, 0x07, 0x9d, 0xdb, 0xe2, 0xe6, 0x29, 0x85, 0xc8, 0xb2, 0x84, 0xa6, 0x0f, 0x2e, 0xe5, 0xb0,
0x3e, 0xc5, 0x94, 0x1c, 0xc2, 0x15, 0xf4, 0x8a, 0xea, 0xba, 0x5c, 0x74, 0x4b, 0xa3, 0xbb, 0x3a,
0x6f, 0xa1, 0xc1, 0x18, 0xce, 0x36, 0x07, 0x87, 0x0c, 0xe1, 0xe4, 0x96, 0xc9, 0x3c, 0xa3, 0xab,
0x71, 0xe5, 0x40, 0xfd, 0x2f, 0x6e, 0xaa, 0x32, 0x28, 0x16, 0x19, 0x8e, 0x1d, 0x83, 0x4c, 0x1c,
0x7c, 0x06, 0xe2, 0x0e, 0xbd, 0xcc, 0x05, 0x97, 0xb8, 0x61, 0xa9, 0xb7, 0x75, 0xeb, 0x3e, 0xb4,
0x73, 0x2a, 0xe5, 0xa3, 0x28, 0xa6, 0xa6, 0x9b, 0x89, 0x83, 0x00, 0x4e, 0xef, 0x56, 0x39, 0x36,
0x7d, 0x08, 0xb4, 0xd4, 0x2a, 0x37, 0x3d, 0xf4, 0x77, 0xf0, 0x1f, 0x1c, 0x7d, 0x58, 0xe4, 0x6a,
0x75, 0xfd, 0xe7, 0x00, 0xda, 0xb7, 0xeb, 0x87, 0x80, 0x44, 0xd0, 0xaa, 0x98, 0xa4, 0x6b, 0xaf,
0x5b, 0xa3, 0xfc, 0x0b, 0x9b, 0xd8, 0x68, 0xfd, 0x11, 0xc0, 0x0a, 0x27, 0xcf, 0x2d, 0x6a, 0x67,
0x87, 0xfd, 0x17, 0xfb, 0x8b, 0xeb, 0x46, 0x6f, 0xa1, 0xd3, 0xec, 0x0a, 0xf1, 0x2d, 0x74, 0x7b,
0x81, 0xfc, 0x6d, 0x69, 0xd5, 0xfc, 0xdb, 0x19, 0x76, 0x25, 0xec, 0x4c, 0xf6, 0x5e, 0xae, 0x7d,
0xc7, 0x5c, 0xee, 0xce, 0xeb, 0xb6, 0xcb, 0xbd, 0x82, 0xa3, 0x51, 0x26, 0xe4, 0x1e, 0xb3, 0xb6,
0x13, 0xc9, 0xb1, 0x5e, 0xc3, 0x9b, 0xbf, 0x01, 0x00, 0x00, 0xff, 0xff, 0x8c, 0x55, 0x84, 0x56,
0x94, 0x05, 0x00, 0x00,
}

View File

@ -0,0 +1,58 @@
syntax = "proto3";
package dbplugin;
import "google/protobuf/timestamp.proto";
message InitializeRequest {
bytes config = 1;
bool verify_connection = 2;
}
message CreateUserRequest {
Statements statements = 1;
UsernameConfig username_config = 2;
google.protobuf.Timestamp expiration = 3;
}
message RenewUserRequest {
Statements statements = 1;
string username = 2;
google.protobuf.Timestamp expiration = 3;
}
message RevokeUserRequest {
Statements statements = 1;
string username = 2;
}
message Statements {
string creation_statements = 1;
string revocation_statements = 2;
string rollback_statements = 3;
string renew_statements = 4;
}
message UsernameConfig {
string DisplayName = 1;
string RoleName = 2;
}
message CreateUserResponse {
string username = 1;
string password = 2;
}
message TypeResponse {
string type = 1;
}
message Empty {}
service Database {
rpc Type(Empty) returns (TypeResponse);
rpc CreateUser(CreateUserRequest) returns (CreateUserResponse);
rpc RenewUser(RenewUserRequest) returns (Empty);
rpc RevokeUser(RevokeUserRequest) returns (Empty);
rpc Initialize(InitializeRequest) returns (Empty);
rpc Close(Empty) returns (Empty);
}

View File

@ -1,6 +1,7 @@
package dbplugin
import (
"context"
"time"
metrics "github.com/armon/go-metrics"
@ -15,55 +16,56 @@ type databaseTracingMiddleware struct {
next Database
logger log.Logger
typeStr string
typeStr string
transport string
}
func (mw *databaseTracingMiddleware) Type() (string, error) {
return mw.next.Type()
}
func (mw *databaseTracingMiddleware) CreateUser(statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) {
func (mw *databaseTracingMiddleware) CreateUser(ctx context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) {
defer func(then time.Time) {
mw.logger.Trace("database", "operation", "CreateUser", "status", "finished", "type", mw.typeStr, "err", err, "took", time.Since(then))
mw.logger.Trace("database", "operation", "CreateUser", "status", "finished", "type", mw.typeStr, "transport", mw.transport, "err", err, "took", time.Since(then))
}(time.Now())
mw.logger.Trace("database", "operation", "CreateUser", "status", "started", "type", mw.typeStr)
return mw.next.CreateUser(statements, usernameConfig, expiration)
mw.logger.Trace("database", "operation", "CreateUser", "status", "started", "type", mw.typeStr, "transport", mw.transport)
return mw.next.CreateUser(ctx, statements, usernameConfig, expiration)
}
func (mw *databaseTracingMiddleware) RenewUser(statements Statements, username string, expiration time.Time) (err error) {
func (mw *databaseTracingMiddleware) RenewUser(ctx context.Context, statements Statements, username string, expiration time.Time) (err error) {
defer func(then time.Time) {
mw.logger.Trace("database", "operation", "RenewUser", "status", "finished", "type", mw.typeStr, "err", err, "took", time.Since(then))
mw.logger.Trace("database", "operation", "RenewUser", "status", "finished", "type", mw.typeStr, "transport", mw.transport, "err", err, "took", time.Since(then))
}(time.Now())
mw.logger.Trace("database", "operation", "RenewUser", "status", "started", mw.typeStr)
return mw.next.RenewUser(statements, username, expiration)
mw.logger.Trace("database", "operation", "RenewUser", "status", "started", mw.typeStr, "transport", mw.transport)
return mw.next.RenewUser(ctx, statements, username, expiration)
}
func (mw *databaseTracingMiddleware) RevokeUser(statements Statements, username string) (err error) {
func (mw *databaseTracingMiddleware) RevokeUser(ctx context.Context, statements Statements, username string) (err error) {
defer func(then time.Time) {
mw.logger.Trace("database", "operation", "RevokeUser", "status", "finished", "type", mw.typeStr, "err", err, "took", time.Since(then))
mw.logger.Trace("database", "operation", "RevokeUser", "status", "finished", "type", mw.typeStr, "transport", mw.transport, "err", err, "took", time.Since(then))
}(time.Now())
mw.logger.Trace("database", "operation", "RevokeUser", "status", "started", "type", mw.typeStr)
return mw.next.RevokeUser(statements, username)
mw.logger.Trace("database", "operation", "RevokeUser", "status", "started", "type", mw.typeStr, "transport", mw.transport)
return mw.next.RevokeUser(ctx, statements, username)
}
func (mw *databaseTracingMiddleware) Initialize(conf map[string]interface{}, verifyConnection bool) (err error) {
func (mw *databaseTracingMiddleware) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (err error) {
defer func(then time.Time) {
mw.logger.Trace("database", "operation", "Initialize", "status", "finished", "type", mw.typeStr, "verify", verifyConnection, "err", err, "took", time.Since(then))
mw.logger.Trace("database", "operation", "Initialize", "status", "finished", "type", mw.typeStr, "transport", mw.transport, "verify", verifyConnection, "err", err, "took", time.Since(then))
}(time.Now())
mw.logger.Trace("database", "operation", "Initialize", "status", "started", "type", mw.typeStr)
return mw.next.Initialize(conf, verifyConnection)
mw.logger.Trace("database", "operation", "Initialize", "status", "started", "type", mw.typeStr, "transport", mw.transport)
return mw.next.Initialize(ctx, conf, verifyConnection)
}
func (mw *databaseTracingMiddleware) Close() (err error) {
defer func(then time.Time) {
mw.logger.Trace("database", "operation", "Close", "status", "finished", "type", mw.typeStr, "err", err, "took", time.Since(then))
mw.logger.Trace("database", "operation", "Close", "status", "finished", "type", mw.typeStr, "transport", mw.transport, "err", err, "took", time.Since(then))
}(time.Now())
mw.logger.Trace("database", "operation", "Close", "status", "started", "type", mw.typeStr)
mw.logger.Trace("database", "operation", "Close", "status", "started", "type", mw.typeStr, "transport", mw.transport)
return mw.next.Close()
}
@ -81,7 +83,7 @@ func (mw *databaseMetricsMiddleware) Type() (string, error) {
return mw.next.Type()
}
func (mw *databaseMetricsMiddleware) CreateUser(statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) {
func (mw *databaseMetricsMiddleware) CreateUser(ctx context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) {
defer func(now time.Time) {
metrics.MeasureSince([]string{"database", "CreateUser"}, now)
metrics.MeasureSince([]string{"database", mw.typeStr, "CreateUser"}, now)
@ -94,10 +96,10 @@ func (mw *databaseMetricsMiddleware) CreateUser(statements Statements, usernameC
metrics.IncrCounter([]string{"database", "CreateUser"}, 1)
metrics.IncrCounter([]string{"database", mw.typeStr, "CreateUser"}, 1)
return mw.next.CreateUser(statements, usernameConfig, expiration)
return mw.next.CreateUser(ctx, statements, usernameConfig, expiration)
}
func (mw *databaseMetricsMiddleware) RenewUser(statements Statements, username string, expiration time.Time) (err error) {
func (mw *databaseMetricsMiddleware) RenewUser(ctx context.Context, statements Statements, username string, expiration time.Time) (err error) {
defer func(now time.Time) {
metrics.MeasureSince([]string{"database", "RenewUser"}, now)
metrics.MeasureSince([]string{"database", mw.typeStr, "RenewUser"}, now)
@ -110,10 +112,10 @@ func (mw *databaseMetricsMiddleware) RenewUser(statements Statements, username s
metrics.IncrCounter([]string{"database", "RenewUser"}, 1)
metrics.IncrCounter([]string{"database", mw.typeStr, "RenewUser"}, 1)
return mw.next.RenewUser(statements, username, expiration)
return mw.next.RenewUser(ctx, statements, username, expiration)
}
func (mw *databaseMetricsMiddleware) RevokeUser(statements Statements, username string) (err error) {
func (mw *databaseMetricsMiddleware) RevokeUser(ctx context.Context, statements Statements, username string) (err error) {
defer func(now time.Time) {
metrics.MeasureSince([]string{"database", "RevokeUser"}, now)
metrics.MeasureSince([]string{"database", mw.typeStr, "RevokeUser"}, now)
@ -126,10 +128,10 @@ func (mw *databaseMetricsMiddleware) RevokeUser(statements Statements, username
metrics.IncrCounter([]string{"database", "RevokeUser"}, 1)
metrics.IncrCounter([]string{"database", mw.typeStr, "RevokeUser"}, 1)
return mw.next.RevokeUser(statements, username)
return mw.next.RevokeUser(ctx, statements, username)
}
func (mw *databaseMetricsMiddleware) Initialize(conf map[string]interface{}, verifyConnection bool) (err error) {
func (mw *databaseMetricsMiddleware) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (err error) {
defer func(now time.Time) {
metrics.MeasureSince([]string{"database", "Initialize"}, now)
metrics.MeasureSince([]string{"database", mw.typeStr, "Initialize"}, now)
@ -142,7 +144,7 @@ func (mw *databaseMetricsMiddleware) Initialize(conf map[string]interface{}, ver
metrics.IncrCounter([]string{"database", "Initialize"}, 1)
metrics.IncrCounter([]string{"database", mw.typeStr, "Initialize"}, 1)
return mw.next.Initialize(conf, verifyConnection)
return mw.next.Initialize(ctx, conf, verifyConnection)
}
func (mw *databaseMetricsMiddleware) Close() (err error) {

View File

@ -0,0 +1,198 @@
package dbplugin
import (
"context"
"encoding/json"
"errors"
"time"
"google.golang.org/grpc"
"google.golang.org/grpc/connectivity"
"github.com/golang/protobuf/ptypes"
)
var (
ErrPluginShutdown = errors.New("plugin shutdown")
)
// ---- gRPC Server domain ----
type gRPCServer struct {
impl Database
}
func (s *gRPCServer) Type(context.Context, *Empty) (*TypeResponse, error) {
t, err := s.impl.Type()
if err != nil {
return nil, err
}
return &TypeResponse{
Type: t,
}, nil
}
func (s *gRPCServer) CreateUser(ctx context.Context, req *CreateUserRequest) (*CreateUserResponse, error) {
e, err := ptypes.Timestamp(req.Expiration)
if err != nil {
return nil, err
}
u, p, err := s.impl.CreateUser(ctx, *req.Statements, *req.UsernameConfig, e)
return &CreateUserResponse{
Username: u,
Password: p,
}, err
}
func (s *gRPCServer) RenewUser(ctx context.Context, req *RenewUserRequest) (*Empty, error) {
e, err := ptypes.Timestamp(req.Expiration)
if err != nil {
return nil, err
}
err = s.impl.RenewUser(ctx, *req.Statements, req.Username, e)
return &Empty{}, err
}
func (s *gRPCServer) RevokeUser(ctx context.Context, req *RevokeUserRequest) (*Empty, error) {
err := s.impl.RevokeUser(ctx, *req.Statements, req.Username)
return &Empty{}, err
}
func (s *gRPCServer) Initialize(ctx context.Context, req *InitializeRequest) (*Empty, error) {
config := map[string]interface{}{}
err := json.Unmarshal(req.Config, &config)
if err != nil {
return nil, err
}
err = s.impl.Initialize(ctx, config, req.VerifyConnection)
return &Empty{}, err
}
func (s *gRPCServer) Close(_ context.Context, _ *Empty) (*Empty, error) {
s.impl.Close()
return &Empty{}, nil
}
// ---- gRPC client domain ----
type gRPCClient struct {
client DatabaseClient
clientConn *grpc.ClientConn
}
func (c gRPCClient) Type() (string, error) {
// If the plugin has already shutdown, this will hang forever so we give it
// a one second timeout.
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
switch c.clientConn.GetState() {
case connectivity.Ready, connectivity.Idle:
default:
return "", ErrPluginShutdown
}
resp, err := c.client.Type(ctx, &Empty{})
if err != nil {
return "", err
}
return resp.Type, err
}
func (c gRPCClient) CreateUser(ctx context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) {
t, err := ptypes.TimestampProto(expiration)
if err != nil {
return "", "", err
}
switch c.clientConn.GetState() {
case connectivity.Ready, connectivity.Idle:
default:
return "", "", ErrPluginShutdown
}
resp, err := c.client.CreateUser(ctx, &CreateUserRequest{
Statements: &statements,
UsernameConfig: &usernameConfig,
Expiration: t,
})
if err != nil {
return "", "", err
}
return resp.Username, resp.Password, err
}
func (c *gRPCClient) RenewUser(ctx context.Context, statements Statements, username string, expiration time.Time) error {
t, err := ptypes.TimestampProto(expiration)
if err != nil {
return err
}
switch c.clientConn.GetState() {
case connectivity.Ready, connectivity.Idle:
default:
return ErrPluginShutdown
}
_, err = c.client.RenewUser(ctx, &RenewUserRequest{
Statements: &statements,
Username: username,
Expiration: t,
})
return err
}
func (c *gRPCClient) RevokeUser(ctx context.Context, statements Statements, username string) error {
switch c.clientConn.GetState() {
case connectivity.Ready, connectivity.Idle:
default:
return ErrPluginShutdown
}
_, err := c.client.RevokeUser(ctx, &RevokeUserRequest{
Statements: &statements,
Username: username,
})
return err
}
func (c *gRPCClient) Initialize(ctx context.Context, config map[string]interface{}, verifyConnection bool) error {
configRaw, err := json.Marshal(config)
if err != nil {
return err
}
switch c.clientConn.GetState() {
case connectivity.Ready, connectivity.Idle:
default:
return ErrPluginShutdown
}
_, err = c.client.Initialize(ctx, &InitializeRequest{
Config: configRaw,
VerifyConnection: verifyConnection,
})
return err
}
func (c *gRPCClient) Close() error {
// If the plugin has already shutdown, this will hang forever so we give it
// a one second timeout.
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
switch c.clientConn.GetState() {
case connectivity.Ready, connectivity.Idle:
_, err := c.client.Close(ctx, &Empty{})
return err
}
return nil
}

View File

@ -0,0 +1,139 @@
package dbplugin
import (
"context"
"fmt"
"net/rpc"
"time"
)
// ---- RPC server domain ----
// databasePluginRPCServer implements an RPC version of Database and is run
// inside a plugin. It wraps an underlying implementation of Database.
type databasePluginRPCServer struct {
impl Database
}
func (ds *databasePluginRPCServer) Type(_ struct{}, resp *string) error {
var err error
*resp, err = ds.impl.Type()
return err
}
func (ds *databasePluginRPCServer) CreateUser(args *CreateUserRequestRPC, resp *CreateUserResponse) error {
var err error
resp.Username, resp.Password, err = ds.impl.CreateUser(context.Background(), args.Statements, args.UsernameConfig, args.Expiration)
return err
}
func (ds *databasePluginRPCServer) RenewUser(args *RenewUserRequestRPC, _ *struct{}) error {
err := ds.impl.RenewUser(context.Background(), args.Statements, args.Username, args.Expiration)
return err
}
func (ds *databasePluginRPCServer) RevokeUser(args *RevokeUserRequestRPC, _ *struct{}) error {
err := ds.impl.RevokeUser(context.Background(), args.Statements, args.Username)
return err
}
func (ds *databasePluginRPCServer) Initialize(args *InitializeRequestRPC, _ *struct{}) error {
err := ds.impl.Initialize(context.Background(), args.Config, args.VerifyConnection)
return err
}
func (ds *databasePluginRPCServer) Close(_ struct{}, _ *struct{}) error {
ds.impl.Close()
return nil
}
// ---- RPC client domain ----
// databasePluginRPCClient implements Database and is used on the client to
// make RPC calls to a plugin.
type databasePluginRPCClient struct {
client *rpc.Client
}
func (dr *databasePluginRPCClient) Type() (string, error) {
var dbType string
err := dr.client.Call("Plugin.Type", struct{}{}, &dbType)
return fmt.Sprintf("plugin-%s", dbType), err
}
func (dr *databasePluginRPCClient) CreateUser(_ context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) {
req := CreateUserRequestRPC{
Statements: statements,
UsernameConfig: usernameConfig,
Expiration: expiration,
}
var resp CreateUserResponse
err = dr.client.Call("Plugin.CreateUser", req, &resp)
return resp.Username, resp.Password, err
}
func (dr *databasePluginRPCClient) RenewUser(_ context.Context, statements Statements, username string, expiration time.Time) error {
req := RenewUserRequestRPC{
Statements: statements,
Username: username,
Expiration: expiration,
}
err := dr.client.Call("Plugin.RenewUser", req, &struct{}{})
return err
}
func (dr *databasePluginRPCClient) RevokeUser(_ context.Context, statements Statements, username string) error {
req := RevokeUserRequestRPC{
Statements: statements,
Username: username,
}
err := dr.client.Call("Plugin.RevokeUser", req, &struct{}{})
return err
}
func (dr *databasePluginRPCClient) Initialize(_ context.Context, conf map[string]interface{}, verifyConnection bool) error {
req := InitializeRequestRPC{
Config: conf,
VerifyConnection: verifyConnection,
}
err := dr.client.Call("Plugin.Initialize", req, &struct{}{})
return err
}
func (dr *databasePluginRPCClient) Close() error {
err := dr.client.Call("Plugin.Close", struct{}{}, &struct{}{})
return err
}
// ---- RPC Request Args Domain ----
type InitializeRequestRPC struct {
Config map[string]interface{}
VerifyConnection bool
}
type CreateUserRequestRPC struct {
Statements Statements
UsernameConfig UsernameConfig
Expiration time.Time
}
type RenewUserRequestRPC struct {
Statements Statements
Username string
Expiration time.Time
}
type RevokeUserRequestRPC struct {
Statements Statements
Username string
}

View File

@ -1,10 +1,13 @@
package dbplugin
import (
"context"
"fmt"
"net/rpc"
"time"
"google.golang.org/grpc"
"github.com/hashicorp/go-plugin"
"github.com/hashicorp/vault/helper/pluginutil"
log "github.com/mgutz/logxi/v1"
@ -13,29 +16,14 @@ import (
// Database is the interface that all database objects must implement.
type Database interface {
Type() (string, error)
CreateUser(statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error)
RenewUser(statements Statements, username string, expiration time.Time) error
RevokeUser(statements Statements, username string) error
CreateUser(ctx context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error)
RenewUser(ctx context.Context, statements Statements, username string, expiration time.Time) error
RevokeUser(ctx context.Context, statements Statements, username string) error
Initialize(config map[string]interface{}, verifyConnection bool) error
Initialize(ctx context.Context, config map[string]interface{}, verifyConnection bool) error
Close() error
}
// Statements set in role creation and passed into the database type's functions.
type Statements struct {
CreationStatements string `json:"creation_statments" mapstructure:"creation_statements" structs:"creation_statments"`
RevocationStatements string `json:"revocation_statements" mapstructure:"revocation_statements" structs:"revocation_statements"`
RollbackStatements string `json:"rollback_statements" mapstructure:"rollback_statements" structs:"rollback_statements"`
RenewStatements string `json:"renew_statements" mapstructure:"renew_statements" structs:"renew_statements"`
}
// UsernameConfig is used to configure prefixes for the username to be
// generated.
type UsernameConfig struct {
DisplayName string
RoleName string
}
// PluginFactory is used to build plugin database types. It wraps the database
// object in a logging and metrics middleware.
func PluginFactory(pluginName string, sys pluginutil.LookRunnerUtil, logger log.Logger) (Database, error) {
@ -45,6 +33,7 @@ func PluginFactory(pluginName string, sys pluginutil.LookRunnerUtil, logger log.
return nil, err
}
var transport string
var db Database
if pluginRunner.Builtin {
// Plugin is builtin so we can retrieve an instance of the interface
@ -60,12 +49,24 @@ func PluginFactory(pluginName string, sys pluginutil.LookRunnerUtil, logger log.
return nil, fmt.Errorf("unsuported database type: %s", pluginName)
}
transport = "builtin"
} else {
// create a DatabasePluginClient instance
db, err = newPluginClient(sys, pluginRunner, logger)
if err != nil {
return nil, err
}
// Switch on the underlying database client type to get the transport
// method.
switch db.(*DatabasePluginClient).Database.(type) {
case *gRPCClient:
transport = "gRPC"
case *databasePluginRPCClient:
transport = "netRPC"
}
}
typeStr, err := db.Type()
@ -82,9 +83,10 @@ func PluginFactory(pluginName string, sys pluginutil.LookRunnerUtil, logger log.
// Wrap with tracing middleware
if logger.IsTrace() {
db = &databaseTracingMiddleware{
next: db,
typeStr: typeStr,
logger: logger,
transport: transport,
next: db,
typeStr: typeStr,
logger: logger,
}
}
@ -115,33 +117,14 @@ func (DatabasePlugin) Client(b *plugin.MuxBroker, c *rpc.Client) (interface{}, e
return &databasePluginRPCClient{client: c}, nil
}
// ---- RPC Request Args Domain ----
type InitializeRequest struct {
Config map[string]interface{}
VerifyConnection bool
func (d DatabasePlugin) GRPCServer(s *grpc.Server) error {
RegisterDatabaseServer(s, &gRPCServer{impl: d.impl})
return nil
}
type CreateUserRequest struct {
Statements Statements
UsernameConfig UsernameConfig
Expiration time.Time
}
type RenewUserRequest struct {
Statements Statements
Username string
Expiration time.Time
}
type RevokeUserRequest struct {
Statements Statements
Username string
}
// ---- RPC Response Args Domain ----
type CreateUserResponse struct {
Username string
Password string
func (DatabasePlugin) GRPCClient(c *grpc.ClientConn) (interface{}, error) {
return &gRPCClient{
client: NewDatabaseClient(c),
clientConn: c,
}, nil
}

View File

@ -1,11 +1,13 @@
package dbplugin_test
import (
"context"
"errors"
"os"
"testing"
"time"
plugin "github.com/hashicorp/go-plugin"
"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
"github.com/hashicorp/vault/helper/pluginutil"
vaulthttp "github.com/hashicorp/vault/http"
@ -20,7 +22,7 @@ type mockPlugin struct {
}
func (m *mockPlugin) Type() (string, error) { return "mock", nil }
func (m *mockPlugin) CreateUser(statements dbplugin.Statements, usernameConf dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
func (m *mockPlugin) CreateUser(_ context.Context, statements dbplugin.Statements, usernameConf dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
err = errors.New("err")
if usernameConf.DisplayName == "" || expiration.IsZero() {
return "", "", err
@ -34,7 +36,7 @@ func (m *mockPlugin) CreateUser(statements dbplugin.Statements, usernameConf dbp
return usernameConf.DisplayName, "test", nil
}
func (m *mockPlugin) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error {
func (m *mockPlugin) RenewUser(_ context.Context, statements dbplugin.Statements, username string, expiration time.Time) error {
err := errors.New("err")
if username == "" || expiration.IsZero() {
return err
@ -46,7 +48,7 @@ func (m *mockPlugin) RenewUser(statements dbplugin.Statements, username string,
return nil
}
func (m *mockPlugin) RevokeUser(statements dbplugin.Statements, username string) error {
func (m *mockPlugin) RevokeUser(_ context.Context, statements dbplugin.Statements, username string) error {
err := errors.New("err")
if username == "" {
return err
@ -59,7 +61,7 @@ func (m *mockPlugin) RevokeUser(statements dbplugin.Statements, username string)
delete(m.users, username)
return nil
}
func (m *mockPlugin) Initialize(conf map[string]interface{}, _ bool) error {
func (m *mockPlugin) Initialize(_ context.Context, conf map[string]interface{}, _ bool) error {
err := errors.New("err")
if len(conf) != 1 {
return err
@ -80,14 +82,15 @@ func getCluster(t *testing.T) (*vault.TestCluster, logical.SystemView) {
cores := cluster.Cores
sys := vault.TestDynamicSystemView(cores[0].Core)
vault.TestAddTestPlugin(t, cores[0].Core, "test-plugin", "TestPlugin_Main")
vault.TestAddTestPlugin(t, cores[0].Core, "test-plugin", "TestPlugin_GRPC_Main")
vault.TestAddTestPlugin(t, cores[0].Core, "test-plugin-netRPC", "TestPlugin_NetRPC_Main")
return cluster, sys
}
// This is not an actual test case, it's a helper function that will be executed
// by the go-plugin client via an exec call.
func TestPlugin_Main(t *testing.T) {
func TestPlugin_GRPC_Main(t *testing.T) {
if os.Getenv(pluginutil.PluginUnwrapTokenEnv) == "" {
return
}
@ -105,6 +108,30 @@ func TestPlugin_Main(t *testing.T) {
plugins.Serve(plugin, apiClientMeta.GetTLSConfig())
}
// This is not an actual test case, it's a helper function that will be executed
// by the go-plugin client via an exec call.
func TestPlugin_NetRPC_Main(t *testing.T) {
if os.Getenv(pluginutil.PluginUnwrapTokenEnv) == "" {
return
}
p := &mockPlugin{
users: make(map[string][]string),
}
args := []string{"--tls-skip-verify=true"}
apiClientMeta := &pluginutil.APIClientMeta{}
flags := apiClientMeta.FlagSet()
flags.Parse(args)
tlsProvider := pluginutil.VaultPluginTLSProvider(apiClientMeta.GetTLSConfig())
serveConf := dbplugin.ServeConfig(p, tlsProvider)
serveConf.GRPCServer = nil
plugin.Serve(serveConf)
}
func TestPlugin_Initialize(t *testing.T) {
cluster, sys := getCluster(t)
defer cluster.Cleanup()
@ -118,7 +145,7 @@ func TestPlugin_Initialize(t *testing.T) {
"test": 1,
}
err = dbRaw.Initialize(connectionDetails, true)
err = dbRaw.Initialize(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
@ -143,7 +170,7 @@ func TestPlugin_CreateUser(t *testing.T) {
"test": 1,
}
err = db.Initialize(connectionDetails, true)
err = db.Initialize(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
@ -153,7 +180,7 @@ func TestPlugin_CreateUser(t *testing.T) {
RoleName: "test",
}
us, pw, err := db.CreateUser(dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
us, pw, err := db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
if err != nil {
t.Fatalf("err: %s", err)
}
@ -163,7 +190,7 @@ func TestPlugin_CreateUser(t *testing.T) {
// try and save the same user again to verify it saved the first time, this
// should return an error
_, _, err = db.CreateUser(dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
_, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
if err == nil {
t.Fatal("expected an error, user wasn't created correctly")
}
@ -182,7 +209,7 @@ func TestPlugin_RenewUser(t *testing.T) {
connectionDetails := map[string]interface{}{
"test": 1,
}
err = db.Initialize(connectionDetails, true)
err = db.Initialize(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
@ -192,12 +219,12 @@ func TestPlugin_RenewUser(t *testing.T) {
RoleName: "test",
}
us, _, err := db.CreateUser(dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
us, _, err := db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
if err != nil {
t.Fatalf("err: %s", err)
}
err = db.RenewUser(dbplugin.Statements{}, us, time.Now().Add(time.Minute))
err = db.RenewUser(context.Background(), dbplugin.Statements{}, us, time.Now().Add(time.Minute))
if err != nil {
t.Fatalf("err: %s", err)
}
@ -216,7 +243,7 @@ func TestPlugin_RevokeUser(t *testing.T) {
connectionDetails := map[string]interface{}{
"test": 1,
}
err = db.Initialize(connectionDetails, true)
err = db.Initialize(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
@ -226,19 +253,159 @@ func TestPlugin_RevokeUser(t *testing.T) {
RoleName: "test",
}
us, _, err := db.CreateUser(dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
us, _, err := db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
if err != nil {
t.Fatalf("err: %s", err)
}
// Test default revoke statememts
err = db.RevokeUser(dbplugin.Statements{}, us)
err = db.RevokeUser(context.Background(), dbplugin.Statements{}, us)
if err != nil {
t.Fatalf("err: %s", err)
}
// Try adding the same username back so we can verify it was removed
_, _, err = db.CreateUser(dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
_, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
if err != nil {
t.Fatalf("err: %s", err)
}
}
// Test the code is still compatible with an old netRPC plugin
func TestPlugin_NetRPC_Initialize(t *testing.T) {
cluster, sys := getCluster(t)
defer cluster.Cleanup()
dbRaw, err := dbplugin.PluginFactory("test-plugin-netRPC", sys, &log.NullLogger{})
if err != nil {
t.Fatalf("err: %s", err)
}
connectionDetails := map[string]interface{}{
"test": 1,
}
err = dbRaw.Initialize(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
err = dbRaw.Close()
if err != nil {
t.Fatalf("err: %s", err)
}
}
func TestPlugin_NetRPC_CreateUser(t *testing.T) {
cluster, sys := getCluster(t)
defer cluster.Cleanup()
db, err := dbplugin.PluginFactory("test-plugin-netRPC", sys, &log.NullLogger{})
if err != nil {
t.Fatalf("err: %s", err)
}
defer db.Close()
connectionDetails := map[string]interface{}{
"test": 1,
}
err = db.Initialize(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
usernameConf := dbplugin.UsernameConfig{
DisplayName: "test",
RoleName: "test",
}
us, pw, err := db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
if err != nil {
t.Fatalf("err: %s", err)
}
if us != "test" || pw != "test" {
t.Fatal("expected username and password to be 'test'")
}
// try and save the same user again to verify it saved the first time, this
// should return an error
_, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
if err == nil {
t.Fatal("expected an error, user wasn't created correctly")
}
}
func TestPlugin_NetRPC_RenewUser(t *testing.T) {
cluster, sys := getCluster(t)
defer cluster.Cleanup()
db, err := dbplugin.PluginFactory("test-plugin-netRPC", sys, &log.NullLogger{})
if err != nil {
t.Fatalf("err: %s", err)
}
defer db.Close()
connectionDetails := map[string]interface{}{
"test": 1,
}
err = db.Initialize(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
usernameConf := dbplugin.UsernameConfig{
DisplayName: "test",
RoleName: "test",
}
us, _, err := db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
if err != nil {
t.Fatalf("err: %s", err)
}
err = db.RenewUser(context.Background(), dbplugin.Statements{}, us, time.Now().Add(time.Minute))
if err != nil {
t.Fatalf("err: %s", err)
}
}
func TestPlugin_NetRPC_RevokeUser(t *testing.T) {
cluster, sys := getCluster(t)
defer cluster.Cleanup()
db, err := dbplugin.PluginFactory("test-plugin-netRPC", sys, &log.NullLogger{})
if err != nil {
t.Fatalf("err: %s", err)
}
defer db.Close()
connectionDetails := map[string]interface{}{
"test": 1,
}
err = db.Initialize(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
usernameConf := dbplugin.UsernameConfig{
DisplayName: "test",
RoleName: "test",
}
us, _, err := db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
if err != nil {
t.Fatalf("err: %s", err)
}
// Test default revoke statememts
err = db.RevokeUser(context.Background(), dbplugin.Statements{}, us)
if err != nil {
t.Fatalf("err: %s", err)
}
// Try adding the same username back so we can verify it was removed
_, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
if err != nil {
t.Fatalf("err: %s", err)
}

View File

@ -10,6 +10,10 @@ import (
// Database implementation in a databasePluginRPCServer object and starts a
// RPC server.
func Serve(db Database, tlsProvider func() (*tls.Config, error)) {
plugin.Serve(ServeConfig(db, tlsProvider))
}
func ServeConfig(db Database, tlsProvider func() (*tls.Config, error)) *plugin.ServeConfig {
dbPlugin := &DatabasePlugin{
impl: db,
}
@ -19,53 +23,10 @@ func Serve(db Database, tlsProvider func() (*tls.Config, error)) {
"database": dbPlugin,
}
plugin.Serve(&plugin.ServeConfig{
return &plugin.ServeConfig{
HandshakeConfig: handshakeConfig,
Plugins: pluginMap,
TLSProvider: tlsProvider,
})
}
// ---- RPC server domain ----
// databasePluginRPCServer implements an RPC version of Database and is run
// inside a plugin. It wraps an underlying implementation of Database.
type databasePluginRPCServer struct {
impl Database
}
func (ds *databasePluginRPCServer) Type(_ struct{}, resp *string) error {
var err error
*resp, err = ds.impl.Type()
return err
}
func (ds *databasePluginRPCServer) CreateUser(args *CreateUserRequest, resp *CreateUserResponse) error {
var err error
resp.Username, resp.Password, err = ds.impl.CreateUser(args.Statements, args.UsernameConfig, args.Expiration)
return err
}
func (ds *databasePluginRPCServer) RenewUser(args *RenewUserRequest, _ *struct{}) error {
err := ds.impl.RenewUser(args.Statements, args.Username, args.Expiration)
return err
}
func (ds *databasePluginRPCServer) RevokeUser(args *RevokeUserRequest, _ *struct{}) error {
err := ds.impl.RevokeUser(args.Statements, args.Username)
return err
}
func (ds *databasePluginRPCServer) Initialize(args *InitializeRequest, _ *struct{}) error {
err := ds.impl.Initialize(args.Config, args.VerifyConnection)
return err
}
func (ds *databasePluginRPCServer) Close(_ struct{}, _ *struct{}) error {
ds.impl.Close()
return nil
GRPCServer: plugin.DefaultGRPCServer,
}
}

View File

@ -1,6 +1,7 @@
package database
import (
"context"
"errors"
"fmt"
@ -62,7 +63,7 @@ func (b *databaseBackend) pathConnectionReset() framework.OperationFunc {
b.clearConnection(name)
// Execute plugin again, we don't need the object so throw away.
_, err := b.createDBObj(req.Storage, name)
_, err := b.createDBObj(context.TODO(), req.Storage, name)
if err != nil {
return nil, err
}
@ -230,7 +231,7 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc {
return logical.ErrorResponse(fmt.Sprintf("error creating database object: %s", err)), nil
}
err = db.Initialize(config.ConnectionDetails, verifyConnection)
err = db.Initialize(context.TODO(), config.ConnectionDetails, verifyConnection)
if err != nil {
db.Close()
return logical.ErrorResponse(fmt.Sprintf("error creating database object: %s", err)), nil

View File

@ -1,6 +1,7 @@
package database
import (
"context"
"fmt"
"time"
@ -66,7 +67,7 @@ func (b *databaseBackend) pathCredsCreateRead() framework.OperationFunc {
unlockFunc = b.Unlock
// Create a new DB object
db, err = b.createDBObj(req.Storage, role.DBName)
db, err = b.createDBObj(context.TODO(), req.Storage, role.DBName)
if err != nil {
unlockFunc()
return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err)
@ -81,7 +82,7 @@ func (b *databaseBackend) pathCredsCreateRead() framework.OperationFunc {
}
// Create the user
username, password, err := db.CreateUser(role.Statements, usernameConfig, expiration)
username, password, err := db.CreateUser(context.TODO(), role.Statements, usernameConfig, expiration)
// Unlock
unlockFunc()
if err != nil {

View File

@ -1,6 +1,7 @@
package database
import (
"context"
"fmt"
"github.com/hashicorp/vault/logical"
@ -60,7 +61,7 @@ func (b *databaseBackend) secretCredsRenew() framework.OperationFunc {
unlockFunc = b.Unlock
// Create a new DB object
db, err = b.createDBObj(req.Storage, role.DBName)
db, err = b.createDBObj(context.TODO(), req.Storage, role.DBName)
if err != nil {
unlockFunc()
return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err)
@ -69,7 +70,7 @@ func (b *databaseBackend) secretCredsRenew() framework.OperationFunc {
// Make sure we increase the VALID UNTIL endpoint for this user.
if expireTime := resp.Secret.ExpirationTime(); !expireTime.IsZero() {
err := db.RenewUser(role.Statements, username, expireTime)
err := db.RenewUser(context.TODO(), role.Statements, username, expireTime)
// Unlock
unlockFunc()
if err != nil {
@ -119,14 +120,14 @@ func (b *databaseBackend) secretCredsRevoke() framework.OperationFunc {
unlockFunc = b.Unlock
// Create a new DB object
db, err = b.createDBObj(req.Storage, role.DBName)
db, err = b.createDBObj(context.TODO(), req.Storage, role.DBName)
if err != nil {
unlockFunc()
return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err)
}
}
err = db.RevokeUser(role.Statements, username)
err = db.RevokeUser(context.TODO(), role.Statements, username)
// Unlock
unlockFunc()
if err != nil {

View File

@ -119,6 +119,10 @@ func (r *PluginRunner) runCommon(wrapper RunnerUtil, pluginMap map[string]plugin
SecureConfig: secureConfig,
TLSConfig: clientTLSConfig,
Logger: namedLogger,
AllowedProtocols: []plugin.Protocol{
plugin.ProtocolNetRPC,
plugin.ProtocolGRPC,
},
}
client := plugin.NewClient(clientConfig)

View File

@ -192,8 +192,7 @@ func TestBackendHandleRequest_helpRoot(t *testing.T) {
func TestBackendHandleRequest_renewAuth(t *testing.T) {
b := &Backend{}
resp, err := b.HandleRequest(logical.RenewAuthRequest(
"/foo", &logical.Auth{}, nil))
resp, err := b.HandleRequest(logical.RenewAuthRequest("/foo", &logical.Auth{}, nil))
if err != nil {
t.Fatalf("err: %s", err)
}
@ -213,8 +212,7 @@ func TestBackendHandleRequest_renewAuthCallback(t *testing.T) {
AuthRenew: callback,
}
_, err := b.HandleRequest(logical.RenewAuthRequest(
"/foo", &logical.Auth{}, nil))
_, err := b.HandleRequest(logical.RenewAuthRequest("/foo", &logical.Auth{}, nil))
if err != nil {
t.Fatalf("err: %s", err)
}
@ -237,8 +235,7 @@ func TestBackendHandleRequest_renew(t *testing.T) {
Secrets: []*Secret{secret},
}
_, err := b.HandleRequest(logical.RenewRequest(
"/foo", secret.Response(nil, nil).Secret, nil))
_, err := b.HandleRequest(logical.RenewRequest("/foo", secret.Response(nil, nil).Secret, nil))
if err != nil {
t.Fatalf("err: %s", err)
}
@ -293,8 +290,7 @@ func TestBackendHandleRequest_revoke(t *testing.T) {
Secrets: []*Secret{secret},
}
_, err := b.HandleRequest(logical.RevokeRequest(
"/foo", secret.Response(nil, nil).Secret, nil))
_, err := b.HandleRequest(logical.RevokeRequest("/foo", secret.Response(nil, nil).Secret, nil))
if err != nil {
t.Fatalf("err: %s", err)
}

View File

@ -200,8 +200,7 @@ func (r *Request) SetLastRemoteWAL(last uint64) {
}
// RenewRequest creates the structure of the renew request.
func RenewRequest(
path string, secret *Secret, data map[string]interface{}) *Request {
func RenewRequest(path string, secret *Secret, data map[string]interface{}) *Request {
return &Request{
Operation: RenewOperation,
Path: path,
@ -211,8 +210,7 @@ func RenewRequest(
}
// RenewAuthRequest creates the structure of the renew request for an auth.
func RenewAuthRequest(
path string, auth *Auth, data map[string]interface{}) *Request {
func RenewAuthRequest(path string, auth *Auth, data map[string]interface{}) *Request {
return &Request{
Operation: RenewOperation,
Path: path,
@ -222,8 +220,7 @@ func RenewAuthRequest(
}
// RevokeRequest creates the structure of the revoke request.
func RevokeRequest(
path string, secret *Secret, data map[string]interface{}) *Request {
func RevokeRequest(path string, secret *Secret, data map[string]interface{}) *Request {
return &Request{
Operation: RevokeOperation,
Path: path,

View File

@ -1,6 +1,7 @@
package cassandra
import (
"context"
"strings"
"time"
@ -21,6 +22,8 @@ const (
cassandraTypeName = "cassandra"
)
var _ dbplugin.Database = &Cassandra{}
// Cassandra is an implementation of Database interface
type Cassandra struct {
connutil.ConnectionProducer
@ -64,8 +67,8 @@ func (c *Cassandra) Type() (string, error) {
return cassandraTypeName, nil
}
func (c *Cassandra) getConnection() (*gocql.Session, error) {
session, err := c.Connection()
func (c *Cassandra) getConnection(ctx context.Context) (*gocql.Session, error) {
session, err := c.Connection(ctx)
if err != nil {
return nil, err
}
@ -75,13 +78,13 @@ func (c *Cassandra) getConnection() (*gocql.Session, error) {
// CreateUser generates the username/password on the underlying Cassandra secret backend as instructed by
// the CreationStatement provided.
func (c *Cassandra) CreateUser(statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
func (c *Cassandra) CreateUser(ctx context.Context, statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
// Grab the lock
c.Lock()
defer c.Unlock()
// Get the connection
session, err := c.getConnection()
session, err := c.getConnection(ctx)
if err != nil {
return "", "", err
}
@ -138,18 +141,18 @@ func (c *Cassandra) CreateUser(statements dbplugin.Statements, usernameConfig db
}
// RenewUser is not supported on Cassandra, so this is a no-op.
func (c *Cassandra) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error {
func (c *Cassandra) RenewUser(ctx context.Context, statements dbplugin.Statements, username string, expiration time.Time) error {
// NOOP
return nil
}
// RevokeUser attempts to drop the specified user.
func (c *Cassandra) RevokeUser(statements dbplugin.Statements, username string) error {
func (c *Cassandra) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error {
// Grab the lock
c.Lock()
defer c.Unlock()
session, err := c.getConnection()
session, err := c.getConnection(ctx)
if err != nil {
return err
}

View File

@ -1,6 +1,7 @@
package cassandra
import (
"context"
"os"
"strconv"
"testing"
@ -89,7 +90,7 @@ func TestCassandra_Initialize(t *testing.T) {
db := dbRaw.(*Cassandra)
connProducer := db.ConnectionProducer.(*cassandraConnectionProducer)
err := db.Initialize(connectionDetails, true)
err := db.Initialize(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
@ -112,7 +113,7 @@ func TestCassandra_Initialize(t *testing.T) {
"protocol_version": "4",
}
err = db.Initialize(connectionDetails, true)
err = db.Initialize(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
@ -135,7 +136,7 @@ func TestCassandra_CreateUser(t *testing.T) {
dbRaw, _ := New()
db := dbRaw.(*Cassandra)
err := db.Initialize(connectionDetails, true)
err := db.Initialize(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
@ -149,7 +150,7 @@ func TestCassandra_CreateUser(t *testing.T) {
RoleName: "test",
}
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute))
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
if err != nil {
t.Fatalf("err: %s", err)
}
@ -176,7 +177,7 @@ func TestMyCassandra_RenewUser(t *testing.T) {
dbRaw, _ := New()
db := dbRaw.(*Cassandra)
err := db.Initialize(connectionDetails, true)
err := db.Initialize(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
@ -190,7 +191,7 @@ func TestMyCassandra_RenewUser(t *testing.T) {
RoleName: "test",
}
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute))
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
if err != nil {
t.Fatalf("err: %s", err)
}
@ -199,7 +200,7 @@ func TestMyCassandra_RenewUser(t *testing.T) {
t.Fatalf("Could not connect with new credentials: %s", err)
}
err = db.RenewUser(statements, username, time.Now().Add(time.Minute))
err = db.RenewUser(context.Background(), statements, username, time.Now().Add(time.Minute))
if err != nil {
t.Fatalf("err: %s", err)
}
@ -222,7 +223,7 @@ func TestCassandra_RevokeUser(t *testing.T) {
dbRaw, _ := New()
db := dbRaw.(*Cassandra)
err := db.Initialize(connectionDetails, true)
err := db.Initialize(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
@ -236,7 +237,7 @@ func TestCassandra_RevokeUser(t *testing.T) {
RoleName: "test",
}
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute))
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
if err != nil {
t.Fatalf("err: %s", err)
}
@ -246,7 +247,7 @@ func TestCassandra_RevokeUser(t *testing.T) {
}
// Test default revoke statememts
err = db.RevokeUser(statements, username)
err = db.RevokeUser(context.Background(), statements, username)
if err != nil {
t.Fatalf("err: %s", err)
}

View File

@ -1,6 +1,7 @@
package cassandra
import (
"context"
"crypto/tls"
"fmt"
"strings"
@ -43,7 +44,7 @@ type cassandraConnectionProducer struct {
sync.Mutex
}
func (c *cassandraConnectionProducer) Initialize(conf map[string]interface{}, verifyConnection bool) error {
func (c *cassandraConnectionProducer) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error {
c.Lock()
defer c.Unlock()
@ -106,7 +107,7 @@ func (c *cassandraConnectionProducer) Initialize(conf map[string]interface{}, ve
c.Initialized = true
if verifyConnection {
if _, err := c.Connection(); err != nil {
if _, err := c.Connection(ctx); err != nil {
return fmt.Errorf("error verifying connection: %s", err)
}
}
@ -114,7 +115,7 @@ func (c *cassandraConnectionProducer) Initialize(conf map[string]interface{}, ve
return nil
}
func (c *cassandraConnectionProducer) Connection() (interface{}, error) {
func (c *cassandraConnectionProducer) Connection(_ context.Context) (interface{}, error) {
if !c.Initialized {
return nil, connutil.ErrNotInitialized
}

View File

@ -1,6 +1,7 @@
package hana
import (
"context"
"database/sql"
"fmt"
"strings"
@ -26,6 +27,8 @@ type HANA struct {
credsutil.CredentialsProducer
}
var _ dbplugin.Database = &HANA{}
// New implements builtinplugins.BuiltinFactory
func New() (interface{}, error) {
connProducer := &connutil.SQLConnectionProducer{}
@ -63,8 +66,8 @@ func (h *HANA) Type() (string, error) {
return hanaTypeName, nil
}
func (h *HANA) getConnection() (*sql.DB, error) {
db, err := h.Connection()
func (h *HANA) getConnection(ctx context.Context) (*sql.DB, error) {
db, err := h.Connection(ctx)
if err != nil {
return nil, err
}
@ -74,13 +77,13 @@ func (h *HANA) getConnection() (*sql.DB, error) {
// CreateUser generates the username/password on the underlying HANA secret backend
// as instructed by the CreationStatement provided.
func (h *HANA) CreateUser(statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
func (h *HANA) CreateUser(ctx context.Context, statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
// Grab the lock
h.Lock()
defer h.Unlock()
// Get the connection
db, err := h.getConnection()
db, err := h.getConnection(ctx)
if err != nil {
return "", "", err
}
@ -153,9 +156,9 @@ func (h *HANA) CreateUser(statements dbplugin.Statements, usernameConfig dbplugi
}
// Renewing hana user just means altering user's valid until property
func (h *HANA) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error {
func (h *HANA) RenewUser(ctx context.Context, statements dbplugin.Statements, username string, expiration time.Time) error {
// Get connection
db, err := h.getConnection()
db, err := h.getConnection(ctx)
if err != nil {
return err
}
@ -193,14 +196,14 @@ func (h *HANA) RenewUser(statements dbplugin.Statements, username string, expira
}
// Revoking hana user will deactivate user and try to perform a soft drop
func (h *HANA) RevokeUser(statements dbplugin.Statements, username string) error {
func (h *HANA) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error {
// default revoke will be a soft drop on user
if statements.RevocationStatements == "" {
return h.revokeUserDefault(username)
return h.revokeUserDefault(ctx, username)
}
// Get connection
db, err := h.getConnection()
db, err := h.getConnection(ctx)
if err != nil {
return err
}
@ -239,9 +242,9 @@ func (h *HANA) RevokeUser(statements dbplugin.Statements, username string) error
return nil
}
func (h *HANA) revokeUserDefault(username string) error {
func (h *HANA) revokeUserDefault(ctx context.Context, username string) error {
// Get connection
db, err := h.getConnection()
db, err := h.getConnection(ctx)
if err != nil {
return err
}

View File

@ -1,6 +1,7 @@
package hana
import (
"context"
"database/sql"
"fmt"
"os"
@ -25,7 +26,7 @@ func TestHANA_Initialize(t *testing.T) {
dbRaw, _ := New()
db := dbRaw.(*HANA)
err := db.Initialize(connectionDetails, true)
err := db.Initialize(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
@ -55,7 +56,7 @@ func TestHANA_CreateUser(t *testing.T) {
dbRaw, _ := New()
db := dbRaw.(*HANA)
err := db.Initialize(connectionDetails, true)
err := db.Initialize(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
@ -66,7 +67,7 @@ func TestHANA_CreateUser(t *testing.T) {
}
// Test with no configured Creation Statememt
_, _, err = db.CreateUser(dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Hour))
_, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Hour))
if err == nil {
t.Fatal("Expected error when no creation statement is provided")
}
@ -75,7 +76,7 @@ func TestHANA_CreateUser(t *testing.T) {
CreationStatements: testHANARole,
}
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Hour))
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Hour))
if err != nil {
t.Fatalf("err: %s", err)
}
@ -98,7 +99,7 @@ func TestHANA_RevokeUser(t *testing.T) {
dbRaw, _ := New()
db := dbRaw.(*HANA)
err := db.Initialize(connectionDetails, true)
err := db.Initialize(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
@ -113,7 +114,7 @@ func TestHANA_RevokeUser(t *testing.T) {
}
// Test default revoke statememts
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Hour))
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Hour))
if err != nil {
t.Fatalf("err: %s", err)
}
@ -121,7 +122,7 @@ func TestHANA_RevokeUser(t *testing.T) {
t.Fatalf("Could not connect with new credentials: %s", err)
}
err = db.RevokeUser(statements, username)
err = db.RevokeUser(context.Background(), statements, username)
if err != nil {
t.Fatalf("err: %s", err)
}
@ -130,7 +131,7 @@ func TestHANA_RevokeUser(t *testing.T) {
}
// Test custom revoke statememt
username, password, err = db.CreateUser(statements, usernameConfig, time.Now().Add(time.Hour))
username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Hour))
if err != nil {
t.Fatalf("err: %s", err)
}
@ -139,7 +140,7 @@ func TestHANA_RevokeUser(t *testing.T) {
}
statements.RevocationStatements = testHANADrop
err = db.RevokeUser(statements, username)
err = db.RevokeUser(context.Background(), statements, username)
if err != nil {
t.Fatalf("err: %s", err)
}

View File

@ -1,6 +1,7 @@
package mongodb
import (
"context"
"crypto/tls"
"encoding/base64"
"encoding/json"
@ -33,7 +34,7 @@ type mongoDBConnectionProducer struct {
}
// Initialize parses connection configuration.
func (c *mongoDBConnectionProducer) Initialize(conf map[string]interface{}, verifyConnection bool) error {
func (c *mongoDBConnectionProducer) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error {
c.Lock()
defer c.Unlock()
@ -75,7 +76,7 @@ func (c *mongoDBConnectionProducer) Initialize(conf map[string]interface{}, veri
c.Initialized = true
if verifyConnection {
if _, err := c.Connection(); err != nil {
if _, err := c.Connection(ctx); err != nil {
return fmt.Errorf("error verifying connection: %s", err)
}
@ -88,7 +89,7 @@ func (c *mongoDBConnectionProducer) Initialize(conf map[string]interface{}, veri
}
// Connection creates a database connection.
func (c *mongoDBConnectionProducer) Connection() (interface{}, error) {
func (c *mongoDBConnectionProducer) Connection(_ context.Context) (interface{}, error) {
if !c.Initialized {
return nil, connutil.ErrNotInitialized
}

View File

@ -1,6 +1,7 @@
package mongodb
import (
"context"
"io"
"strings"
"time"
@ -27,6 +28,8 @@ type MongoDB struct {
credsutil.CredentialsProducer
}
var _ dbplugin.Database = &MongoDB{}
// New returns a new MongoDB instance
func New() (interface{}, error) {
connProducer := &mongoDBConnectionProducer{}
@ -63,8 +66,8 @@ func (m *MongoDB) Type() (string, error) {
return mongoDBTypeName, nil
}
func (m *MongoDB) getConnection() (*mgo.Session, error) {
session, err := m.Connection()
func (m *MongoDB) getConnection(ctx context.Context) (*mgo.Session, error) {
session, err := m.Connection(ctx)
if err != nil {
return nil, err
}
@ -80,7 +83,7 @@ func (m *MongoDB) getConnection() (*mgo.Session, error) {
//
// JSON Example:
// { "db": "admin", "roles": [{ "role": "readWrite" }, {"role": "read", "db": "foo"}] }
func (m *MongoDB) CreateUser(statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
func (m *MongoDB) CreateUser(ctx context.Context, statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
// Grab the lock
m.Lock()
defer m.Unlock()
@ -89,7 +92,7 @@ func (m *MongoDB) CreateUser(statements dbplugin.Statements, usernameConfig dbpl
return "", "", dbutil.ErrEmptyCreationStatement
}
session, err := m.getConnection()
session, err := m.getConnection(ctx)
if err != nil {
return "", "", err
}
@ -133,7 +136,7 @@ func (m *MongoDB) CreateUser(statements dbplugin.Statements, usernameConfig dbpl
if err := m.ConnectionProducer.Close(); err != nil {
return "", "", errwrap.Wrapf("error closing EOF'd mongo connection: {{err}}", err)
}
session, err := m.getConnection()
session, err := m.getConnection(ctx)
if err != nil {
return "", "", err
}
@ -149,15 +152,15 @@ func (m *MongoDB) CreateUser(statements dbplugin.Statements, usernameConfig dbpl
}
// RenewUser is not supported on MongoDB, so this is a no-op.
func (m *MongoDB) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error {
func (m *MongoDB) RenewUser(ctx context.Context, statements dbplugin.Statements, username string, expiration time.Time) error {
// NOOP
return nil
}
// RevokeUser drops the specified user from the authentication databse. If none is provided
// in the revocation statement, the default "admin" authentication database will be assumed.
func (m *MongoDB) RevokeUser(statements dbplugin.Statements, username string) error {
session, err := m.getConnection()
func (m *MongoDB) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error {
session, err := m.getConnection(ctx)
if err != nil {
return err
}
@ -188,7 +191,7 @@ func (m *MongoDB) RevokeUser(statements dbplugin.Statements, username string) er
if err := m.ConnectionProducer.Close(); err != nil {
return errwrap.Wrapf("error closing EOF'd mongo connection: {{err}}", err)
}
session, err := m.getConnection()
session, err := m.getConnection(ctx)
if err != nil {
return err
}

View File

@ -1,6 +1,7 @@
package mongodb
import (
"context"
"fmt"
"os"
"testing"
@ -79,7 +80,7 @@ func TestMongoDB_Initialize(t *testing.T) {
db := dbRaw.(*MongoDB)
connProducer := db.ConnectionProducer.(*mongoDBConnectionProducer)
err = db.Initialize(connectionDetails, true)
err = db.Initialize(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
@ -107,7 +108,7 @@ func TestMongoDB_CreateUser(t *testing.T) {
t.Fatalf("err: %s", err)
}
db := dbRaw.(*MongoDB)
err = db.Initialize(connectionDetails, true)
err = db.Initialize(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
@ -121,7 +122,7 @@ func TestMongoDB_CreateUser(t *testing.T) {
RoleName: "test",
}
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute))
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
if err != nil {
t.Fatalf("err: %s", err)
}
@ -145,7 +146,7 @@ func TestMongoDB_CreateUser_writeConcern(t *testing.T) {
t.Fatalf("err: %s", err)
}
db := dbRaw.(*MongoDB)
err = db.Initialize(connectionDetails, true)
err = db.Initialize(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
@ -159,7 +160,7 @@ func TestMongoDB_CreateUser_writeConcern(t *testing.T) {
RoleName: "test",
}
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute))
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
if err != nil {
t.Fatalf("err: %s", err)
}
@ -182,7 +183,7 @@ func TestMongoDB_RevokeUser(t *testing.T) {
t.Fatalf("err: %s", err)
}
db := dbRaw.(*MongoDB)
err = db.Initialize(connectionDetails, true)
err = db.Initialize(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
@ -196,7 +197,7 @@ func TestMongoDB_RevokeUser(t *testing.T) {
RoleName: "test",
}
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute))
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
if err != nil {
t.Fatalf("err: %s", err)
}
@ -206,7 +207,7 @@ func TestMongoDB_RevokeUser(t *testing.T) {
}
// Test default revocation statememt
err = db.RevokeUser(statements, username)
err = db.RevokeUser(context.Background(), statements, username)
if err != nil {
t.Fatalf("err: %s", err)
}

View File

@ -1,6 +1,7 @@
package mssql
import (
"context"
"database/sql"
"fmt"
"strings"
@ -18,6 +19,8 @@ import (
const msSQLTypeName = "mssql"
var _ dbplugin.Database = &MSSQL{}
// MSSQL is an implementation of Database interface
type MSSQL struct {
connutil.ConnectionProducer
@ -60,8 +63,8 @@ func (m *MSSQL) Type() (string, error) {
return msSQLTypeName, nil
}
func (m *MSSQL) getConnection() (*sql.DB, error) {
db, err := m.Connection()
func (m *MSSQL) getConnection(ctx context.Context) (*sql.DB, error) {
db, err := m.Connection(ctx)
if err != nil {
return nil, err
}
@ -71,13 +74,13 @@ func (m *MSSQL) getConnection() (*sql.DB, error) {
// CreateUser generates the username/password on the underlying MSSQL secret backend as instructed by
// the CreationStatement provided.
func (m *MSSQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
func (m *MSSQL) CreateUser(ctx context.Context, statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
// Grab the lock
m.Lock()
defer m.Unlock()
// Get the connection
db, err := m.getConnection()
db, err := m.getConnection(ctx)
if err != nil {
return "", "", err
}
@ -138,7 +141,7 @@ func (m *MSSQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplug
}
// RenewUser is not supported on MSSQL, so this is a no-op.
func (m *MSSQL) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error {
func (m *MSSQL) RenewUser(ctx context.Context, statements dbplugin.Statements, username string, expiration time.Time) error {
// NOOP
return nil
}
@ -146,13 +149,13 @@ func (m *MSSQL) RenewUser(statements dbplugin.Statements, username string, expir
// RevokeUser attempts to drop the specified user. It will first attempt to disable login,
// then kill pending connections from that user, and finally drop the user and login from the
// database instance.
func (m *MSSQL) RevokeUser(statements dbplugin.Statements, username string) error {
func (m *MSSQL) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error {
if statements.RevocationStatements == "" {
return m.revokeUserDefault(username)
return m.revokeUserDefault(ctx, username)
}
// Get connection
db, err := m.getConnection()
db, err := m.getConnection(ctx)
if err != nil {
return err
}
@ -191,9 +194,9 @@ func (m *MSSQL) RevokeUser(statements dbplugin.Statements, username string) erro
return nil
}
func (m *MSSQL) revokeUserDefault(username string) error {
func (m *MSSQL) revokeUserDefault(ctx context.Context, username string) error {
// Get connection
db, err := m.getConnection()
db, err := m.getConnection(ctx)
if err != nil {
return err
}

View File

@ -1,6 +1,7 @@
package mssql
import (
"context"
"database/sql"
"fmt"
"os"
@ -30,7 +31,7 @@ func TestMSSQL_Initialize(t *testing.T) {
dbRaw, _ := New()
db := dbRaw.(*MSSQL)
err := db.Initialize(connectionDetails, true)
err := db.Initialize(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
@ -51,7 +52,7 @@ func TestMSSQL_Initialize(t *testing.T) {
"max_open_connections": "5",
}
err = db.Initialize(connectionDetails, true)
err = db.Initialize(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
@ -69,7 +70,7 @@ func TestMSSQL_CreateUser(t *testing.T) {
dbRaw, _ := New()
db := dbRaw.(*MSSQL)
err := db.Initialize(connectionDetails, true)
err := db.Initialize(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
@ -80,7 +81,7 @@ func TestMSSQL_CreateUser(t *testing.T) {
}
// Test with no configured Creation Statememt
_, _, err = db.CreateUser(dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute))
_, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute))
if err == nil {
t.Fatal("Expected error when no creation statement is provided")
}
@ -89,7 +90,7 @@ func TestMSSQL_CreateUser(t *testing.T) {
CreationStatements: testMSSQLRole,
}
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute))
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
if err != nil {
t.Fatalf("err: %s", err)
}
@ -111,7 +112,7 @@ func TestMSSQL_RevokeUser(t *testing.T) {
dbRaw, _ := New()
db := dbRaw.(*MSSQL)
err := db.Initialize(connectionDetails, true)
err := db.Initialize(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
@ -125,7 +126,7 @@ func TestMSSQL_RevokeUser(t *testing.T) {
RoleName: "test",
}
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(2*time.Second))
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second))
if err != nil {
t.Fatalf("err: %s", err)
}
@ -135,7 +136,7 @@ func TestMSSQL_RevokeUser(t *testing.T) {
}
// Test default revoke statememts
err = db.RevokeUser(statements, username)
err = db.RevokeUser(context.Background(), statements, username)
if err != nil {
t.Fatalf("err: %s", err)
}
@ -144,7 +145,7 @@ func TestMSSQL_RevokeUser(t *testing.T) {
t.Fatal("Credentials were not revoked")
}
username, password, err = db.CreateUser(statements, usernameConfig, time.Now().Add(2*time.Second))
username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second))
if err != nil {
t.Fatalf("err: %s", err)
}
@ -155,7 +156,7 @@ func TestMSSQL_RevokeUser(t *testing.T) {
// Test custom revoke statememt
statements.RevocationStatements = testMSSQLDrop
err = db.RevokeUser(statements, username)
err = db.RevokeUser(context.Background(), statements, username)
if err != nil {
t.Fatalf("err: %s", err)
}

View File

@ -1,6 +1,7 @@
package mysql
import (
"context"
"database/sql"
"strings"
"time"
@ -30,6 +31,8 @@ var (
LegacyUsernameLen int = 16
)
var _ dbplugin.Database = &MySQL{}
type MySQL struct {
connutil.ConnectionProducer
credsutil.CredentialsProducer
@ -88,8 +91,8 @@ func (m *MySQL) Type() (string, error) {
return mySQLTypeName, nil
}
func (m *MySQL) getConnection() (*sql.DB, error) {
db, err := m.Connection()
func (m *MySQL) getConnection(ctx context.Context) (*sql.DB, error) {
db, err := m.Connection(ctx)
if err != nil {
return nil, err
}
@ -97,13 +100,13 @@ func (m *MySQL) getConnection() (*sql.DB, error) {
return db.(*sql.DB), nil
}
func (m *MySQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
func (m *MySQL) CreateUser(ctx context.Context, statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
// Grab the lock
m.Lock()
defer m.Unlock()
// Get the connection
db, err := m.getConnection()
db, err := m.getConnection(ctx)
if err != nil {
return "", "", err
}
@ -128,7 +131,7 @@ func (m *MySQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplug
}
// Start a transaction
tx, err := db.Begin()
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return "", "", err
}
@ -146,7 +149,7 @@ func (m *MySQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplug
"expiration": expirationStr,
})
stmt, err := tx.Prepare(query)
stmt, err := tx.PrepareContext(ctx, query)
if err != nil {
// If the error code we get back is Error 1295: This command is not
// supported in the prepared statement protocol yet, we will execute
@ -155,7 +158,7 @@ func (m *MySQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplug
// prepare supported commands. If there is no error when running we
// will continue to the next statement.
if e, ok := err.(*stdmysql.MySQLError); ok && e.Number == 1295 {
_, err = tx.Exec(query)
_, err = tx.ExecContext(ctx, query)
if err != nil {
return "", "", err
}
@ -165,7 +168,7 @@ func (m *MySQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplug
return "", "", err
}
defer stmt.Close()
if _, err := stmt.Exec(); err != nil {
if _, err := stmt.ExecContext(ctx); err != nil {
return "", "", err
}
}
@ -179,17 +182,17 @@ func (m *MySQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplug
}
// NOOP
func (m *MySQL) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error {
func (m *MySQL) RenewUser(ctx context.Context, statements dbplugin.Statements, username string, expiration time.Time) error {
return nil
}
func (m *MySQL) RevokeUser(statements dbplugin.Statements, username string) error {
func (m *MySQL) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error {
// Grab the read lock
m.Lock()
defer m.Unlock()
// Get the connection
db, err := m.getConnection()
db, err := m.getConnection(ctx)
if err != nil {
return err
}
@ -201,7 +204,7 @@ func (m *MySQL) RevokeUser(statements dbplugin.Statements, username string) erro
}
// Start a transaction
tx, err := db.Begin()
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return err
}
@ -217,7 +220,7 @@ func (m *MySQL) RevokeUser(statements dbplugin.Statements, username string) erro
// 1295: This command is not supported in the prepared statement protocol yet
// Reference https://mariadb.com/kb/en/mariadb/prepare-statement/
query = strings.Replace(query, "{{name}}", username, -1)
_, err = tx.Exec(query)
_, err = tx.ExecContext(ctx, query)
if err != nil {
return err
}

View File

@ -1,6 +1,7 @@
package mysql
import (
"context"
"database/sql"
"fmt"
"os"
@ -108,7 +109,7 @@ func TestMySQL_Initialize(t *testing.T) {
db := dbRaw.(*MySQL)
connProducer := db.ConnectionProducer.(*connutil.SQLConnectionProducer)
err := db.Initialize(connectionDetails, true)
err := db.Initialize(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
@ -128,7 +129,7 @@ func TestMySQL_Initialize(t *testing.T) {
"max_open_connections": "5",
}
err = db.Initialize(connectionDetails, true)
err = db.Initialize(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
@ -146,7 +147,7 @@ func TestMySQL_CreateUser(t *testing.T) {
dbRaw, _ := f()
db := dbRaw.(*MySQL)
err := db.Initialize(connectionDetails, true)
err := db.Initialize(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
@ -157,7 +158,7 @@ func TestMySQL_CreateUser(t *testing.T) {
}
// Test with no configured Creation Statememt
_, _, err = db.CreateUser(dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute))
_, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute))
if err == nil {
t.Fatal("Expected error when no creation statement is provided")
}
@ -166,7 +167,7 @@ func TestMySQL_CreateUser(t *testing.T) {
CreationStatements: testMySQLRoleWildCard,
}
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute))
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
if err != nil {
t.Fatalf("err: %s", err)
}
@ -176,7 +177,7 @@ func TestMySQL_CreateUser(t *testing.T) {
}
// Test a second time to make sure usernames don't collide
username, password, err = db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute))
username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
if err != nil {
t.Fatalf("err: %s", err)
}
@ -188,7 +189,7 @@ func TestMySQL_CreateUser(t *testing.T) {
// Test with a manualy prepare statement
statements.CreationStatements = testMySQLRolePreparedStmt
username, password, err = db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute))
username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
if err != nil {
t.Fatalf("err: %s", err)
}
@ -211,7 +212,7 @@ func TestMySQL_CreateUser_Legacy(t *testing.T) {
dbRaw, _ := f()
db := dbRaw.(*MySQL)
err := db.Initialize(connectionDetails, true)
err := db.Initialize(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
@ -222,7 +223,7 @@ func TestMySQL_CreateUser_Legacy(t *testing.T) {
}
// Test with no configured Creation Statememt
_, _, err = db.CreateUser(dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute))
_, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute))
if err == nil {
t.Fatal("Expected error when no creation statement is provided")
}
@ -231,7 +232,7 @@ func TestMySQL_CreateUser_Legacy(t *testing.T) {
CreationStatements: testMySQLRoleWildCard,
}
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute))
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
if err != nil {
t.Fatalf("err: %s", err)
}
@ -241,7 +242,7 @@ func TestMySQL_CreateUser_Legacy(t *testing.T) {
}
// Test a second time to make sure usernames don't collide
username, password, err = db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute))
username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
if err != nil {
t.Fatalf("err: %s", err)
}
@ -263,7 +264,7 @@ func TestMySQL_RevokeUser(t *testing.T) {
dbRaw, _ := f()
db := dbRaw.(*MySQL)
err := db.Initialize(connectionDetails, true)
err := db.Initialize(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
@ -277,7 +278,7 @@ func TestMySQL_RevokeUser(t *testing.T) {
RoleName: "test",
}
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute))
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
if err != nil {
t.Fatalf("err: %s", err)
}
@ -287,7 +288,7 @@ func TestMySQL_RevokeUser(t *testing.T) {
}
// Test default revoke statememts
err = db.RevokeUser(statements, username)
err = db.RevokeUser(context.Background(), statements, username)
if err != nil {
t.Fatalf("err: %s", err)
}
@ -297,7 +298,7 @@ func TestMySQL_RevokeUser(t *testing.T) {
}
statements.CreationStatements = testMySQLRoleWildCard
username, password, err = db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute))
username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
if err != nil {
t.Fatalf("err: %s", err)
}
@ -308,7 +309,7 @@ func TestMySQL_RevokeUser(t *testing.T) {
// Test custom revoke statements
statements.RevocationStatements = testMySQLRevocationSQL
err = db.RevokeUser(statements, username)
err = db.RevokeUser(context.Background(), statements, username)
if err != nil {
t.Fatalf("err: %s", err)
}

View File

@ -1,6 +1,7 @@
package postgresql
import (
"context"
"database/sql"
"fmt"
"strings"
@ -24,6 +25,8 @@ ALTER ROLE "{{name}}" VALID UNTIL '{{expiration}}';
`
)
var _ dbplugin.Database = &PostgreSQL{}
// New implements builtinplugins.BuiltinFactory
func New() (interface{}, error) {
connProducer := &connutil.SQLConnectionProducer{}
@ -65,8 +68,8 @@ func (p *PostgreSQL) Type() (string, error) {
return postgreSQLTypeName, nil
}
func (p *PostgreSQL) getConnection() (*sql.DB, error) {
db, err := p.Connection()
func (p *PostgreSQL) getConnection(ctx context.Context) (*sql.DB, error) {
db, err := p.Connection(ctx)
if err != nil {
return nil, err
}
@ -74,7 +77,7 @@ func (p *PostgreSQL) getConnection() (*sql.DB, error) {
return db.(*sql.DB), nil
}
func (p *PostgreSQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
func (p *PostgreSQL) CreateUser(ctx context.Context, statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
if statements.CreationStatements == "" {
return "", "", dbutil.ErrEmptyCreationStatement
}
@ -99,7 +102,7 @@ func (p *PostgreSQL) CreateUser(statements dbplugin.Statements, usernameConfig d
}
// Get the connection
db, err := p.getConnection()
db, err := p.getConnection(ctx)
if err != nil {
return "", "", err
@ -148,7 +151,7 @@ func (p *PostgreSQL) CreateUser(statements dbplugin.Statements, usernameConfig d
return username, password, nil
}
func (p *PostgreSQL) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error {
func (p *PostgreSQL) RenewUser(ctx context.Context, statements dbplugin.Statements, username string, expiration time.Time) error {
p.Lock()
defer p.Unlock()
@ -157,7 +160,7 @@ func (p *PostgreSQL) RenewUser(statements dbplugin.Statements, username string,
renewStmts = defaultPostgresRenewSQL
}
db, err := p.getConnection()
db, err := p.getConnection(ctx)
if err != nil {
return err
}
@ -201,20 +204,20 @@ func (p *PostgreSQL) RenewUser(statements dbplugin.Statements, username string,
return nil
}
func (p *PostgreSQL) RevokeUser(statements dbplugin.Statements, username string) error {
func (p *PostgreSQL) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error {
// Grab the lock
p.Lock()
defer p.Unlock()
if statements.RevocationStatements == "" {
return p.defaultRevokeUser(username)
return p.defaultRevokeUser(ctx, username)
}
return p.customRevokeUser(username, statements.RevocationStatements)
return p.customRevokeUser(ctx, username, statements.RevocationStatements)
}
func (p *PostgreSQL) customRevokeUser(username, revocationStmts string) error {
db, err := p.getConnection()
func (p *PostgreSQL) customRevokeUser(ctx context.Context, username, revocationStmts string) error {
db, err := p.getConnection(ctx)
if err != nil {
return err
}
@ -253,8 +256,8 @@ func (p *PostgreSQL) customRevokeUser(username, revocationStmts string) error {
return nil
}
func (p *PostgreSQL) defaultRevokeUser(username string) error {
db, err := p.getConnection()
func (p *PostgreSQL) defaultRevokeUser(ctx context.Context, username string) error {
db, err := p.getConnection(ctx)
if err != nil {
return err
}

View File

@ -1,6 +1,7 @@
package postgresql
import (
"context"
"database/sql"
"fmt"
"os"
@ -72,7 +73,7 @@ func TestPostgreSQL_Initialize(t *testing.T) {
connProducer := db.ConnectionProducer.(*connutil.SQLConnectionProducer)
err := db.Initialize(connectionDetails, true)
err := db.Initialize(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
@ -92,7 +93,7 @@ func TestPostgreSQL_Initialize(t *testing.T) {
"max_open_connections": "5",
}
err = db.Initialize(connectionDetails, true)
err = db.Initialize(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
@ -109,7 +110,7 @@ func TestPostgreSQL_CreateUser(t *testing.T) {
dbRaw, _ := New()
db := dbRaw.(*PostgreSQL)
err := db.Initialize(connectionDetails, true)
err := db.Initialize(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
@ -120,7 +121,7 @@ func TestPostgreSQL_CreateUser(t *testing.T) {
}
// Test with no configured Creation Statememt
_, _, err = db.CreateUser(dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute))
_, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute))
if err == nil {
t.Fatal("Expected error when no creation statement is provided")
}
@ -129,7 +130,7 @@ func TestPostgreSQL_CreateUser(t *testing.T) {
CreationStatements: testPostgresRole,
}
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute))
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
if err != nil {
t.Fatalf("err: %s", err)
}
@ -139,7 +140,7 @@ func TestPostgreSQL_CreateUser(t *testing.T) {
}
statements.CreationStatements = testPostgresReadOnlyRole
username, password, err = db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute))
username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
if err != nil {
t.Fatalf("err: %s", err)
}
@ -162,7 +163,7 @@ func TestPostgreSQL_RenewUser(t *testing.T) {
dbRaw, _ := New()
db := dbRaw.(*PostgreSQL)
err := db.Initialize(connectionDetails, true)
err := db.Initialize(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
@ -176,7 +177,7 @@ func TestPostgreSQL_RenewUser(t *testing.T) {
RoleName: "test",
}
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(2*time.Second))
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second))
if err != nil {
t.Fatalf("err: %s", err)
}
@ -185,7 +186,7 @@ func TestPostgreSQL_RenewUser(t *testing.T) {
t.Fatalf("Could not connect with new credentials: %s", err)
}
err = db.RenewUser(statements, username, time.Now().Add(time.Minute))
err = db.RenewUser(context.Background(), statements, username, time.Now().Add(time.Minute))
if err != nil {
t.Fatalf("err: %s", err)
}
@ -197,7 +198,7 @@ func TestPostgreSQL_RenewUser(t *testing.T) {
t.Fatalf("Could not connect with new credentials: %s", err)
}
statements.RenewStatements = defaultPostgresRenewSQL
username, password, err = db.CreateUser(statements, usernameConfig, time.Now().Add(2*time.Second))
username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second))
if err != nil {
t.Fatalf("err: %s", err)
}
@ -206,7 +207,7 @@ func TestPostgreSQL_RenewUser(t *testing.T) {
t.Fatalf("Could not connect with new credentials: %s", err)
}
err = db.RenewUser(statements, username, time.Now().Add(time.Minute))
err = db.RenewUser(context.Background(), statements, username, time.Now().Add(time.Minute))
if err != nil {
t.Fatalf("err: %s", err)
}
@ -230,7 +231,7 @@ func TestPostgreSQL_RevokeUser(t *testing.T) {
dbRaw, _ := New()
db := dbRaw.(*PostgreSQL)
err := db.Initialize(connectionDetails, true)
err := db.Initialize(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
@ -244,7 +245,7 @@ func TestPostgreSQL_RevokeUser(t *testing.T) {
RoleName: "test",
}
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(2*time.Second))
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second))
if err != nil {
t.Fatalf("err: %s", err)
}
@ -254,7 +255,7 @@ func TestPostgreSQL_RevokeUser(t *testing.T) {
}
// Test default revoke statememts
err = db.RevokeUser(statements, username)
err = db.RevokeUser(context.Background(), statements, username)
if err != nil {
t.Fatalf("err: %s", err)
}
@ -263,7 +264,7 @@ func TestPostgreSQL_RevokeUser(t *testing.T) {
t.Fatal("Credentials were not revoked")
}
username, password, err = db.CreateUser(statements, usernameConfig, time.Now().Add(2*time.Second))
username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second))
if err != nil {
t.Fatalf("err: %s", err)
}
@ -274,7 +275,7 @@ func TestPostgreSQL_RevokeUser(t *testing.T) {
// Test custom revoke statements
statements.RevocationStatements = defaultPostgresRevocationSQL
err = db.RevokeUser(statements, username)
err = db.RevokeUser(context.Background(), statements, username)
if err != nil {
t.Fatalf("err: %s", err)
}

View File

@ -1,6 +1,7 @@
package connutil
import (
"context"
"errors"
"sync"
)
@ -14,8 +15,8 @@ var (
// connections and is used in all the builtin database types.
type ConnectionProducer interface {
Close() error
Initialize(map[string]interface{}, bool) error
Connection() (interface{}, error)
Initialize(context.Context, map[string]interface{}, bool) error
Connection(context.Context) (interface{}, error)
sync.Locker
}

View File

@ -1,6 +1,7 @@
package connutil
import (
"context"
"database/sql"
"fmt"
"strings"
@ -25,7 +26,7 @@ type SQLConnectionProducer struct {
sync.Mutex
}
func (c *SQLConnectionProducer) Initialize(conf map[string]interface{}, verifyConnection bool) error {
func (c *SQLConnectionProducer) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error {
c.Lock()
defer c.Unlock()
@ -62,11 +63,11 @@ func (c *SQLConnectionProducer) Initialize(conf map[string]interface{}, verifyCo
c.Initialized = true
if verifyConnection {
if _, err := c.Connection(); err != nil {
if _, err := c.Connection(ctx); err != nil {
return fmt.Errorf("error verifying connection: %s", err)
}
if err := c.db.Ping(); err != nil {
if err := c.db.PingContext(ctx); err != nil {
return fmt.Errorf("error verifying connection: %s", err)
}
}
@ -74,14 +75,14 @@ func (c *SQLConnectionProducer) Initialize(conf map[string]interface{}, verifyCo
return nil
}
func (c *SQLConnectionProducer) Connection() (interface{}, error) {
func (c *SQLConnectionProducer) Connection(ctx context.Context) (interface{}, error) {
if !c.Initialized {
return nil, ErrNotInitialized
}
// If we already have a DB, test it and return
if c.db != nil {
if err := c.db.Ping(); err == nil {
if err := c.db.PingContext(ctx); err == nil {
return c.db, nil
}
// If the ping was unsuccessful, close it and ignore errors as we'll be

View File

@ -1017,8 +1017,7 @@ func (m *ExpirationManager) revokeEntry(le *leaseEntry) error {
}
// Handle standard revocation via backends
resp, err := m.router.Route(logical.RevokeRequest(
le.Path, le.Secret, le.Data))
resp, err := m.router.Route(logical.RevokeRequest(le.Path, le.Secret, le.Data))
if err != nil || (resp != nil && resp.IsError()) {
return fmt.Errorf("failed to revoke entry: resp:%#v err:%s", resp, err)
}