Merge remote-tracking branch 'oss/master' into f-nomad

* oss/master:
  Defer reader.Close that is used to determine sha256
  changelog++
  Avoid unseal failure if plugin backends fail to setup during postUnseal (#3686)
  Add logic for using Auth.Period when handling auth login/renew requests (#3677)
  plugins/database: use context with plugins that use database/sql package (#3691)
  changelog++
  Fix plaintext backup in transit (#3692)
  Database gRPC plugins (#3666)
This commit is contained in:
Chris Hoffman 2017-12-15 17:05:42 -05:00
commit db0006ef65
47 changed files with 1824 additions and 526 deletions

View File

@ -33,6 +33,8 @@ FEATURES:
operation that can export a given key, including all key versions and
configuration, as well as a restore operation allowing import into another
Vault.
* **gRPC Database Plugins**: Database plugins now use gRPC for transport,
allowing them to be written in other languages.
IMPROVEMENTS:
@ -46,6 +48,10 @@ IMPROVEMENTS:
during database configuration. This establishes a session-wide [write
concern](https://docs.mongodb.com/manual/reference/write-concern/) for the
lifecycle of the mount [GH-3646]
* mfa/okta: Filter a given email address as a login filter, allowing operation
when login email and account email are different
* plugins: Make Vault more resilient when unsealing when plugins are
unavailable [GH-3686]
* secret/pki: `allowed_domains` and `key_usage` can now be specified
as a comma-separated string or an array of strings [GH-3642]
* secret/ssh: Allow 4096-bit keys to be used in dynamic key method [GH-3593]
@ -58,8 +64,12 @@ BUG FIXES:
* auth/cert: Return `allowed_names` on role read [GH-3654]
* auth/ldap: Fix incorrect control information being sent [GH-3402] [GH-3496]
[GH-3625] [GH-3656]
* core: Fix seal status reporting when using an autoseal
* core: Add creation path to wrap info for a control group token
* core: Fix potential panic that could occur using plugins when a node
transitioned from active to standby [GH-3638]
* core: Fix memory ballooning when a connection would connect to the cluster
port and then go away -- redux! [GH-3680]
* core: Replace recursive token revocation logic with depth-first logic, which
can avoid hitting stack depth limits in extreme cases [GH-2348]
* core/pkcs11 (enterprise): Fix panic when PKCS#11 library is not readable

View File

@ -84,6 +84,7 @@ proto:
protoc -I helper/forwarding -I vault -I ../../.. helper/forwarding/types.proto --go_out=plugins=grpc:helper/forwarding
protoc -I physical physical/types.proto --go_out=plugins=grpc:physical
protoc -I helper/identity -I ../../.. helper/identity/types.proto --go_out=plugins=grpc:helper/identity
protoc builtin/logical/database/dbplugin/*.proto --go_out=plugins=grpc:.
sed -i -e 's/Idp/IDP/' -e 's/Url/URL/' -e 's/Id/ID/' -e 's/EntityId/EntityID/' -e 's/Api/API/' -e 's/Qr/QR/' -e 's/protobuf:"/sentinel:"" protobuf:"/' helper/identity/types.pb.go helper/storagepacker/types.pb.go
sed -i -e 's/Iv/IV/' -e 's/Hmac/HMAC/' physical/types.pb.go

View File

@ -141,7 +141,7 @@ func testAccStepMapUserIdCidr(t *testing.T, cidr string) logicaltest.TestStep {
func testAccLogin(t *testing.T, display string) logicaltest.TestStep {
checkTTL := func(resp *logical.Response) error {
if resp.Auth.LeaseOptions.TTL.String() != "768h0m0s" {
return fmt.Errorf("invalid TTL")
return fmt.Errorf("invalid TTL: got %s", resp.Auth.LeaseOptions.TTL)
}
return nil
}
@ -165,7 +165,7 @@ func testAccLogin(t *testing.T, display string) logicaltest.TestStep {
func testAccLoginAppIDInPath(t *testing.T, display string) logicaltest.TestStep {
checkTTL := func(resp *logical.Response) error {
if resp.Auth.LeaseOptions.TTL.String() != "768h0m0s" {
return fmt.Errorf("invalid TTL")
return fmt.Errorf("invalid TTL: got %s", resp.Auth.LeaseOptions.TTL)
}
return nil
}

View File

@ -3,7 +3,6 @@ package approle
import (
"fmt"
"strings"
"time"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
@ -68,20 +67,13 @@ func (b *backend) pathLoginUpdate(req *logical.Request, data *framework.FieldDat
Policies: role.Policies,
LeaseOptions: logical.LeaseOptions{
Renewable: true,
TTL: role.TokenTTL,
},
Alias: &logical.Alias{
Name: role.RoleID,
},
}
// If 'Period' is set, use the value of 'Period' as the TTL.
// Otherwise, set the normal TokenTTL.
if role.Period > time.Duration(0) {
auth.TTL = role.Period
} else {
auth.TTL = role.TokenTTL
}
return &logical.Response{
Auth: auth,
}, nil
@ -107,16 +99,12 @@ func (b *backend) pathLoginRenew(req *logical.Request, data *framework.FieldData
return nil, fmt.Errorf("role %s does not exist during renewal", roleName)
}
// If 'Period' is set on the Role, the token should never expire.
// Replenish the TTL with 'Period's value.
if role.Period > time.Duration(0) {
// If 'Period' was updated after the token was issued,
// token will bear the updated 'Period' value as its TTL.
req.Auth.TTL = role.Period
return &logical.Response{Auth: req.Auth}, nil
} else {
return framework.LeaseExtend(role.TokenTTL, role.TokenMaxTTL, b.System())(req, data)
resp, err := framework.LeaseExtend(role.TokenTTL, role.TokenMaxTTL, b.System())(req, data)
if err != nil {
return nil, err
}
resp.Auth.Period = role.Period
return resp, nil
}
const pathLoginHelpSys = "Issue a token based on the credentials supplied"

View File

@ -1,6 +1,7 @@
package database
import (
"context"
"fmt"
"net/rpc"
"strings"
@ -87,7 +88,7 @@ func (b *databaseBackend) getDBObj(name string) (dbplugin.Database, bool) {
// This function creates a new db object from the stored configuration and
// caches it in the connections map. The caller of this function needs to hold
// the backend's write lock
func (b *databaseBackend) createDBObj(s logical.Storage, name string) (dbplugin.Database, error) {
func (b *databaseBackend) createDBObj(ctx context.Context, s logical.Storage, name string) (dbplugin.Database, error) {
db, ok := b.connections[name]
if ok {
return db, nil
@ -103,7 +104,7 @@ func (b *databaseBackend) createDBObj(s logical.Storage, name string) (dbplugin.
return nil, err
}
err = db.Initialize(config.ConnectionDetails, true)
err = db.Initialize(ctx, config.ConnectionDetails, true)
if err != nil {
return nil, err
}
@ -170,7 +171,8 @@ func (b *databaseBackend) clearConnection(name string) {
func (b *databaseBackend) closeIfShutdown(name string, err error) {
// Plugin has shutdown, close it so next call can reconnect.
if err == rpc.ErrShutdown {
switch err {
case rpc.ErrShutdown, dbplugin.ErrPluginShutdown:
b.Lock()
b.clearConnection(name)
b.Unlock()

View File

@ -488,9 +488,11 @@ func TestBackend_roleCrud(t *testing.T) {
RevocationStatements: defaultRevocationSQL,
}
var actual dbplugin.Statements
if err := mapstructure.Decode(resp.Data, &actual); err != nil {
t.Fatal(err)
actual := dbplugin.Statements{
CreationStatements: resp.Data["creation_statements"].(string),
RevocationStatements: resp.Data["revocation_statements"].(string),
RollbackStatements: resp.Data["rollback_statements"].(string),
RenewStatements: resp.Data["renew_statements"].(string),
}
if !reflect.DeepEqual(expected, actual) {

View File

@ -1,10 +1,8 @@
package dbplugin
import (
"fmt"
"net/rpc"
"errors"
"sync"
"time"
"github.com/hashicorp/go-plugin"
"github.com/hashicorp/vault/helper/pluginutil"
@ -17,11 +15,11 @@ type DatabasePluginClient struct {
client *plugin.Client
sync.Mutex
*databasePluginRPCClient
Database
}
func (dc *DatabasePluginClient) Close() error {
err := dc.databasePluginRPCClient.Close()
err := dc.Database.Close()
dc.client.Kill()
return err
@ -55,79 +53,20 @@ func newPluginClient(sys pluginutil.RunnerUtil, pluginRunner *pluginutil.PluginR
// We should have a database type now. This feels like a normal interface
// implementation but is in fact over an RPC connection.
databaseRPC := raw.(*databasePluginRPCClient)
var db Database
switch raw.(type) {
case *gRPCClient:
db = raw.(*gRPCClient)
case *databasePluginRPCClient:
logger.Warn("database: plugin is using deprecated net RPC transport, recompile plugin to upgrade to gRPC", "plugin", pluginRunner.Name)
db = raw.(*databasePluginRPCClient)
default:
return nil, errors.New("unsupported client type")
}
// Wrap RPC implimentation in DatabasePluginClient
return &DatabasePluginClient{
client: client,
databasePluginRPCClient: databaseRPC,
client: client,
Database: db,
}, nil
}
// ---- RPC client domain ----
// databasePluginRPCClient implements Database and is used on the client to
// make RPC calls to a plugin.
type databasePluginRPCClient struct {
client *rpc.Client
}
func (dr *databasePluginRPCClient) Type() (string, error) {
var dbType string
err := dr.client.Call("Plugin.Type", struct{}{}, &dbType)
return fmt.Sprintf("plugin-%s", dbType), err
}
func (dr *databasePluginRPCClient) CreateUser(statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) {
req := CreateUserRequest{
Statements: statements,
UsernameConfig: usernameConfig,
Expiration: expiration,
}
var resp CreateUserResponse
err = dr.client.Call("Plugin.CreateUser", req, &resp)
return resp.Username, resp.Password, err
}
func (dr *databasePluginRPCClient) RenewUser(statements Statements, username string, expiration time.Time) error {
req := RenewUserRequest{
Statements: statements,
Username: username,
Expiration: expiration,
}
err := dr.client.Call("Plugin.RenewUser", req, &struct{}{})
return err
}
func (dr *databasePluginRPCClient) RevokeUser(statements Statements, username string) error {
req := RevokeUserRequest{
Statements: statements,
Username: username,
}
err := dr.client.Call("Plugin.RevokeUser", req, &struct{}{})
return err
}
func (dr *databasePluginRPCClient) Initialize(conf map[string]interface{}, verifyConnection bool) error {
req := InitializeRequest{
Config: conf,
VerifyConnection: verifyConnection,
}
err := dr.client.Call("Plugin.Initialize", req, &struct{}{})
return err
}
func (dr *databasePluginRPCClient) Close() error {
err := dr.client.Call("Plugin.Close", struct{}{}, &struct{}{})
return err
}

View File

@ -0,0 +1,556 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// source: builtin/logical/database/dbplugin/database.proto
/*
Package dbplugin is a generated protocol buffer package.
It is generated from these files:
builtin/logical/database/dbplugin/database.proto
It has these top-level messages:
InitializeRequest
CreateUserRequest
RenewUserRequest
RevokeUserRequest
Statements
UsernameConfig
CreateUserResponse
TypeResponse
Empty
*/
package dbplugin
import proto "github.com/golang/protobuf/proto"
import fmt "fmt"
import math "math"
import google_protobuf "github.com/golang/protobuf/ptypes/timestamp"
import (
context "golang.org/x/net/context"
grpc "google.golang.org/grpc"
)
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal
var _ = fmt.Errorf
var _ = math.Inf
// This is a compile-time assertion to ensure that this generated file
// is compatible with the proto package it is being compiled against.
// A compilation error at this line likely means your copy of the
// proto package needs to be updated.
const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package
type InitializeRequest struct {
Config []byte `protobuf:"bytes,1,opt,name=config,proto3" json:"config,omitempty"`
VerifyConnection bool `protobuf:"varint,2,opt,name=verify_connection,json=verifyConnection" json:"verify_connection,omitempty"`
}
func (m *InitializeRequest) Reset() { *m = InitializeRequest{} }
func (m *InitializeRequest) String() string { return proto.CompactTextString(m) }
func (*InitializeRequest) ProtoMessage() {}
func (*InitializeRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{0} }
func (m *InitializeRequest) GetConfig() []byte {
if m != nil {
return m.Config
}
return nil
}
func (m *InitializeRequest) GetVerifyConnection() bool {
if m != nil {
return m.VerifyConnection
}
return false
}
type CreateUserRequest struct {
Statements *Statements `protobuf:"bytes,1,opt,name=statements" json:"statements,omitempty"`
UsernameConfig *UsernameConfig `protobuf:"bytes,2,opt,name=username_config,json=usernameConfig" json:"username_config,omitempty"`
Expiration *google_protobuf.Timestamp `protobuf:"bytes,3,opt,name=expiration" json:"expiration,omitempty"`
}
func (m *CreateUserRequest) Reset() { *m = CreateUserRequest{} }
func (m *CreateUserRequest) String() string { return proto.CompactTextString(m) }
func (*CreateUserRequest) ProtoMessage() {}
func (*CreateUserRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{1} }
func (m *CreateUserRequest) GetStatements() *Statements {
if m != nil {
return m.Statements
}
return nil
}
func (m *CreateUserRequest) GetUsernameConfig() *UsernameConfig {
if m != nil {
return m.UsernameConfig
}
return nil
}
func (m *CreateUserRequest) GetExpiration() *google_protobuf.Timestamp {
if m != nil {
return m.Expiration
}
return nil
}
type RenewUserRequest struct {
Statements *Statements `protobuf:"bytes,1,opt,name=statements" json:"statements,omitempty"`
Username string `protobuf:"bytes,2,opt,name=username" json:"username,omitempty"`
Expiration *google_protobuf.Timestamp `protobuf:"bytes,3,opt,name=expiration" json:"expiration,omitempty"`
}
func (m *RenewUserRequest) Reset() { *m = RenewUserRequest{} }
func (m *RenewUserRequest) String() string { return proto.CompactTextString(m) }
func (*RenewUserRequest) ProtoMessage() {}
func (*RenewUserRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{2} }
func (m *RenewUserRequest) GetStatements() *Statements {
if m != nil {
return m.Statements
}
return nil
}
func (m *RenewUserRequest) GetUsername() string {
if m != nil {
return m.Username
}
return ""
}
func (m *RenewUserRequest) GetExpiration() *google_protobuf.Timestamp {
if m != nil {
return m.Expiration
}
return nil
}
type RevokeUserRequest struct {
Statements *Statements `protobuf:"bytes,1,opt,name=statements" json:"statements,omitempty"`
Username string `protobuf:"bytes,2,opt,name=username" json:"username,omitempty"`
}
func (m *RevokeUserRequest) Reset() { *m = RevokeUserRequest{} }
func (m *RevokeUserRequest) String() string { return proto.CompactTextString(m) }
func (*RevokeUserRequest) ProtoMessage() {}
func (*RevokeUserRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{3} }
func (m *RevokeUserRequest) GetStatements() *Statements {
if m != nil {
return m.Statements
}
return nil
}
func (m *RevokeUserRequest) GetUsername() string {
if m != nil {
return m.Username
}
return ""
}
type Statements struct {
CreationStatements string `protobuf:"bytes,1,opt,name=creation_statements,json=creationStatements" json:"creation_statements,omitempty"`
RevocationStatements string `protobuf:"bytes,2,opt,name=revocation_statements,json=revocationStatements" json:"revocation_statements,omitempty"`
RollbackStatements string `protobuf:"bytes,3,opt,name=rollback_statements,json=rollbackStatements" json:"rollback_statements,omitempty"`
RenewStatements string `protobuf:"bytes,4,opt,name=renew_statements,json=renewStatements" json:"renew_statements,omitempty"`
}
func (m *Statements) Reset() { *m = Statements{} }
func (m *Statements) String() string { return proto.CompactTextString(m) }
func (*Statements) ProtoMessage() {}
func (*Statements) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{4} }
func (m *Statements) GetCreationStatements() string {
if m != nil {
return m.CreationStatements
}
return ""
}
func (m *Statements) GetRevocationStatements() string {
if m != nil {
return m.RevocationStatements
}
return ""
}
func (m *Statements) GetRollbackStatements() string {
if m != nil {
return m.RollbackStatements
}
return ""
}
func (m *Statements) GetRenewStatements() string {
if m != nil {
return m.RenewStatements
}
return ""
}
type UsernameConfig struct {
DisplayName string `protobuf:"bytes,1,opt,name=DisplayName" json:"DisplayName,omitempty"`
RoleName string `protobuf:"bytes,2,opt,name=RoleName" json:"RoleName,omitempty"`
}
func (m *UsernameConfig) Reset() { *m = UsernameConfig{} }
func (m *UsernameConfig) String() string { return proto.CompactTextString(m) }
func (*UsernameConfig) ProtoMessage() {}
func (*UsernameConfig) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{5} }
func (m *UsernameConfig) GetDisplayName() string {
if m != nil {
return m.DisplayName
}
return ""
}
func (m *UsernameConfig) GetRoleName() string {
if m != nil {
return m.RoleName
}
return ""
}
type CreateUserResponse struct {
Username string `protobuf:"bytes,1,opt,name=username" json:"username,omitempty"`
Password string `protobuf:"bytes,2,opt,name=password" json:"password,omitempty"`
}
func (m *CreateUserResponse) Reset() { *m = CreateUserResponse{} }
func (m *CreateUserResponse) String() string { return proto.CompactTextString(m) }
func (*CreateUserResponse) ProtoMessage() {}
func (*CreateUserResponse) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{6} }
func (m *CreateUserResponse) GetUsername() string {
if m != nil {
return m.Username
}
return ""
}
func (m *CreateUserResponse) GetPassword() string {
if m != nil {
return m.Password
}
return ""
}
type TypeResponse struct {
Type string `protobuf:"bytes,1,opt,name=type" json:"type,omitempty"`
}
func (m *TypeResponse) Reset() { *m = TypeResponse{} }
func (m *TypeResponse) String() string { return proto.CompactTextString(m) }
func (*TypeResponse) ProtoMessage() {}
func (*TypeResponse) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{7} }
func (m *TypeResponse) GetType() string {
if m != nil {
return m.Type
}
return ""
}
type Empty struct {
}
func (m *Empty) Reset() { *m = Empty{} }
func (m *Empty) String() string { return proto.CompactTextString(m) }
func (*Empty) ProtoMessage() {}
func (*Empty) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{8} }
func init() {
proto.RegisterType((*InitializeRequest)(nil), "dbplugin.InitializeRequest")
proto.RegisterType((*CreateUserRequest)(nil), "dbplugin.CreateUserRequest")
proto.RegisterType((*RenewUserRequest)(nil), "dbplugin.RenewUserRequest")
proto.RegisterType((*RevokeUserRequest)(nil), "dbplugin.RevokeUserRequest")
proto.RegisterType((*Statements)(nil), "dbplugin.Statements")
proto.RegisterType((*UsernameConfig)(nil), "dbplugin.UsernameConfig")
proto.RegisterType((*CreateUserResponse)(nil), "dbplugin.CreateUserResponse")
proto.RegisterType((*TypeResponse)(nil), "dbplugin.TypeResponse")
proto.RegisterType((*Empty)(nil), "dbplugin.Empty")
}
// Reference imports to suppress errors if they are not otherwise used.
var _ context.Context
var _ grpc.ClientConn
// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
const _ = grpc.SupportPackageIsVersion4
// Client API for Database service
type DatabaseClient interface {
Type(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*TypeResponse, error)
CreateUser(ctx context.Context, in *CreateUserRequest, opts ...grpc.CallOption) (*CreateUserResponse, error)
RenewUser(ctx context.Context, in *RenewUserRequest, opts ...grpc.CallOption) (*Empty, error)
RevokeUser(ctx context.Context, in *RevokeUserRequest, opts ...grpc.CallOption) (*Empty, error)
Initialize(ctx context.Context, in *InitializeRequest, opts ...grpc.CallOption) (*Empty, error)
Close(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*Empty, error)
}
type databaseClient struct {
cc *grpc.ClientConn
}
func NewDatabaseClient(cc *grpc.ClientConn) DatabaseClient {
return &databaseClient{cc}
}
func (c *databaseClient) Type(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*TypeResponse, error) {
out := new(TypeResponse)
err := grpc.Invoke(ctx, "/dbplugin.Database/Type", in, out, c.cc, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *databaseClient) CreateUser(ctx context.Context, in *CreateUserRequest, opts ...grpc.CallOption) (*CreateUserResponse, error) {
out := new(CreateUserResponse)
err := grpc.Invoke(ctx, "/dbplugin.Database/CreateUser", in, out, c.cc, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *databaseClient) RenewUser(ctx context.Context, in *RenewUserRequest, opts ...grpc.CallOption) (*Empty, error) {
out := new(Empty)
err := grpc.Invoke(ctx, "/dbplugin.Database/RenewUser", in, out, c.cc, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *databaseClient) RevokeUser(ctx context.Context, in *RevokeUserRequest, opts ...grpc.CallOption) (*Empty, error) {
out := new(Empty)
err := grpc.Invoke(ctx, "/dbplugin.Database/RevokeUser", in, out, c.cc, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *databaseClient) Initialize(ctx context.Context, in *InitializeRequest, opts ...grpc.CallOption) (*Empty, error) {
out := new(Empty)
err := grpc.Invoke(ctx, "/dbplugin.Database/Initialize", in, out, c.cc, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *databaseClient) Close(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*Empty, error) {
out := new(Empty)
err := grpc.Invoke(ctx, "/dbplugin.Database/Close", in, out, c.cc, opts...)
if err != nil {
return nil, err
}
return out, nil
}
// Server API for Database service
type DatabaseServer interface {
Type(context.Context, *Empty) (*TypeResponse, error)
CreateUser(context.Context, *CreateUserRequest) (*CreateUserResponse, error)
RenewUser(context.Context, *RenewUserRequest) (*Empty, error)
RevokeUser(context.Context, *RevokeUserRequest) (*Empty, error)
Initialize(context.Context, *InitializeRequest) (*Empty, error)
Close(context.Context, *Empty) (*Empty, error)
}
func RegisterDatabaseServer(s *grpc.Server, srv DatabaseServer) {
s.RegisterService(&_Database_serviceDesc, srv)
}
func _Database_Type_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(Empty)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DatabaseServer).Type(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/dbplugin.Database/Type",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DatabaseServer).Type(ctx, req.(*Empty))
}
return interceptor(ctx, in, info, handler)
}
func _Database_CreateUser_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(CreateUserRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DatabaseServer).CreateUser(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/dbplugin.Database/CreateUser",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DatabaseServer).CreateUser(ctx, req.(*CreateUserRequest))
}
return interceptor(ctx, in, info, handler)
}
func _Database_RenewUser_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(RenewUserRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DatabaseServer).RenewUser(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/dbplugin.Database/RenewUser",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DatabaseServer).RenewUser(ctx, req.(*RenewUserRequest))
}
return interceptor(ctx, in, info, handler)
}
func _Database_RevokeUser_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(RevokeUserRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DatabaseServer).RevokeUser(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/dbplugin.Database/RevokeUser",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DatabaseServer).RevokeUser(ctx, req.(*RevokeUserRequest))
}
return interceptor(ctx, in, info, handler)
}
func _Database_Initialize_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(InitializeRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DatabaseServer).Initialize(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/dbplugin.Database/Initialize",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DatabaseServer).Initialize(ctx, req.(*InitializeRequest))
}
return interceptor(ctx, in, info, handler)
}
func _Database_Close_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(Empty)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DatabaseServer).Close(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/dbplugin.Database/Close",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DatabaseServer).Close(ctx, req.(*Empty))
}
return interceptor(ctx, in, info, handler)
}
var _Database_serviceDesc = grpc.ServiceDesc{
ServiceName: "dbplugin.Database",
HandlerType: (*DatabaseServer)(nil),
Methods: []grpc.MethodDesc{
{
MethodName: "Type",
Handler: _Database_Type_Handler,
},
{
MethodName: "CreateUser",
Handler: _Database_CreateUser_Handler,
},
{
MethodName: "RenewUser",
Handler: _Database_RenewUser_Handler,
},
{
MethodName: "RevokeUser",
Handler: _Database_RevokeUser_Handler,
},
{
MethodName: "Initialize",
Handler: _Database_Initialize_Handler,
},
{
MethodName: "Close",
Handler: _Database_Close_Handler,
},
},
Streams: []grpc.StreamDesc{},
Metadata: "builtin/logical/database/dbplugin/database.proto",
}
func init() { proto.RegisterFile("builtin/logical/database/dbplugin/database.proto", fileDescriptor0) }
var fileDescriptor0 = []byte{
// 548 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xb4, 0x54, 0xcf, 0x6e, 0xd3, 0x4e,
0x10, 0x96, 0xdb, 0xb4, 0xbf, 0x64, 0x5a, 0x35, 0xc9, 0xfe, 0x4a, 0x15, 0x19, 0x24, 0x22, 0x9f,
0x5a, 0x21, 0xd9, 0xa8, 0xe5, 0x80, 0xb8, 0xa1, 0x14, 0x21, 0x24, 0x94, 0x83, 0x69, 0x25, 0x6e,
0xd1, 0xda, 0x99, 0x44, 0xab, 0x3a, 0xbb, 0xc6, 0xbb, 0x4e, 0x09, 0x4f, 0xc3, 0xe3, 0x70, 0xe2,
0x1d, 0x78, 0x13, 0xe4, 0x75, 0xd6, 0xbb, 0xf9, 0x73, 0xab, 0xb8, 0x79, 0x66, 0xbe, 0x6f, 0xf6,
0xf3, 0xb7, 0x33, 0x0b, 0xaf, 0x93, 0x92, 0x65, 0x8a, 0xf1, 0x28, 0x13, 0x73, 0x96, 0xd2, 0x2c,
0x9a, 0x52, 0x45, 0x13, 0x2a, 0x31, 0x9a, 0x26, 0x79, 0x56, 0xce, 0x19, 0x6f, 0x32, 0x61, 0x5e,
0x08, 0x25, 0x48, 0xdb, 0x14, 0xfc, 0x97, 0x73, 0x21, 0xe6, 0x19, 0x46, 0x3a, 0x9f, 0x94, 0xb3,
0x48, 0xb1, 0x05, 0x4a, 0x45, 0x17, 0x79, 0x0d, 0x0d, 0xbe, 0x42, 0xff, 0x13, 0x67, 0x8a, 0xd1,
0x8c, 0xfd, 0xc0, 0x18, 0xbf, 0x95, 0x28, 0x15, 0xb9, 0x80, 0xe3, 0x54, 0xf0, 0x19, 0x9b, 0x0f,
0xbc, 0xa1, 0x77, 0x79, 0x1a, 0xaf, 0x23, 0xf2, 0x0a, 0xfa, 0x4b, 0x2c, 0xd8, 0x6c, 0x35, 0x49,
0x05, 0xe7, 0x98, 0x2a, 0x26, 0xf8, 0xe0, 0x60, 0xe8, 0x5d, 0xb6, 0xe3, 0x5e, 0x5d, 0x18, 0x35,
0xf9, 0xe0, 0x97, 0x07, 0xfd, 0x51, 0x81, 0x54, 0xe1, 0xbd, 0xc4, 0xc2, 0xb4, 0x7e, 0x03, 0x20,
0x15, 0x55, 0xb8, 0x40, 0xae, 0xa4, 0x6e, 0x7f, 0x72, 0x7d, 0x1e, 0x1a, 0xbd, 0xe1, 0x97, 0xa6,
0x16, 0x3b, 0x38, 0xf2, 0x1e, 0xba, 0xa5, 0xc4, 0x82, 0xd3, 0x05, 0x4e, 0xd6, 0xca, 0x0e, 0x34,
0x75, 0x60, 0xa9, 0xf7, 0x6b, 0xc0, 0x48, 0xd7, 0xe3, 0xb3, 0x72, 0x23, 0x26, 0xef, 0x00, 0xf0,
0x7b, 0xce, 0x0a, 0xaa, 0x45, 0x1f, 0x6a, 0xb6, 0x1f, 0xd6, 0xf6, 0x84, 0xc6, 0x9e, 0xf0, 0xce,
0xd8, 0x13, 0x3b, 0xe8, 0xe0, 0xa7, 0x07, 0xbd, 0x18, 0x39, 0x3e, 0x3e, 0xfd, 0x4f, 0x7c, 0x68,
0x1b, 0x61, 0xfa, 0x17, 0x3a, 0x71, 0x13, 0x3f, 0x49, 0x22, 0x42, 0x3f, 0xc6, 0xa5, 0x78, 0xc0,
0x7f, 0x2a, 0x31, 0xf8, 0xed, 0x01, 0x58, 0x1a, 0x89, 0xe0, 0xff, 0xb4, 0xba, 0x62, 0x26, 0xf8,
0x64, 0xeb, 0xa4, 0x4e, 0x4c, 0x4c, 0xc9, 0x21, 0xdc, 0xc0, 0xb3, 0x02, 0x97, 0x22, 0xdd, 0xa1,
0xd4, 0x07, 0x9d, 0xdb, 0xe2, 0xe6, 0x29, 0x85, 0xc8, 0xb2, 0x84, 0xa6, 0x0f, 0x2e, 0xe5, 0xb0,
0x3e, 0xc5, 0x94, 0x1c, 0xc2, 0x15, 0xf4, 0x8a, 0xea, 0xba, 0x5c, 0x74, 0x4b, 0xa3, 0xbb, 0x3a,
0x6f, 0xa1, 0xc1, 0x18, 0xce, 0x36, 0x07, 0x87, 0x0c, 0xe1, 0xe4, 0x96, 0xc9, 0x3c, 0xa3, 0xab,
0x71, 0xe5, 0x40, 0xfd, 0x2f, 0x6e, 0xaa, 0x32, 0x28, 0x16, 0x19, 0x8e, 0x1d, 0x83, 0x4c, 0x1c,
0x7c, 0x06, 0xe2, 0x0e, 0xbd, 0xcc, 0x05, 0x97, 0xb8, 0x61, 0xa9, 0xb7, 0x75, 0xeb, 0x3e, 0xb4,
0x73, 0x2a, 0xe5, 0xa3, 0x28, 0xa6, 0xa6, 0x9b, 0x89, 0x83, 0x00, 0x4e, 0xef, 0x56, 0x39, 0x36,
0x7d, 0x08, 0xb4, 0xd4, 0x2a, 0x37, 0x3d, 0xf4, 0x77, 0xf0, 0x1f, 0x1c, 0x7d, 0x58, 0xe4, 0x6a,
0x75, 0xfd, 0xe7, 0x00, 0xda, 0xb7, 0xeb, 0x87, 0x80, 0x44, 0xd0, 0xaa, 0x98, 0xa4, 0x6b, 0xaf,
0x5b, 0xa3, 0xfc, 0x0b, 0x9b, 0xd8, 0x68, 0xfd, 0x11, 0xc0, 0x0a, 0x27, 0xcf, 0x2d, 0x6a, 0x67,
0x87, 0xfd, 0x17, 0xfb, 0x8b, 0xeb, 0x46, 0x6f, 0xa1, 0xd3, 0xec, 0x0a, 0xf1, 0x2d, 0x74, 0x7b,
0x81, 0xfc, 0x6d, 0x69, 0xd5, 0xfc, 0xdb, 0x19, 0x76, 0x25, 0xec, 0x4c, 0xf6, 0x5e, 0xae, 0x7d,
0xc7, 0x5c, 0xee, 0xce, 0xeb, 0xb6, 0xcb, 0xbd, 0x82, 0xa3, 0x51, 0x26, 0xe4, 0x1e, 0xb3, 0xb6,
0x13, 0xc9, 0xb1, 0x5e, 0xc3, 0x9b, 0xbf, 0x01, 0x00, 0x00, 0xff, 0xff, 0x8c, 0x55, 0x84, 0x56,
0x94, 0x05, 0x00, 0x00,
}

View File

@ -0,0 +1,58 @@
syntax = "proto3";
package dbplugin;
import "google/protobuf/timestamp.proto";
message InitializeRequest {
bytes config = 1;
bool verify_connection = 2;
}
message CreateUserRequest {
Statements statements = 1;
UsernameConfig username_config = 2;
google.protobuf.Timestamp expiration = 3;
}
message RenewUserRequest {
Statements statements = 1;
string username = 2;
google.protobuf.Timestamp expiration = 3;
}
message RevokeUserRequest {
Statements statements = 1;
string username = 2;
}
message Statements {
string creation_statements = 1;
string revocation_statements = 2;
string rollback_statements = 3;
string renew_statements = 4;
}
message UsernameConfig {
string DisplayName = 1;
string RoleName = 2;
}
message CreateUserResponse {
string username = 1;
string password = 2;
}
message TypeResponse {
string type = 1;
}
message Empty {}
service Database {
rpc Type(Empty) returns (TypeResponse);
rpc CreateUser(CreateUserRequest) returns (CreateUserResponse);
rpc RenewUser(RenewUserRequest) returns (Empty);
rpc RevokeUser(RevokeUserRequest) returns (Empty);
rpc Initialize(InitializeRequest) returns (Empty);
rpc Close(Empty) returns (Empty);
}

View File

@ -1,6 +1,7 @@
package dbplugin
import (
"context"
"time"
metrics "github.com/armon/go-metrics"
@ -15,55 +16,56 @@ type databaseTracingMiddleware struct {
next Database
logger log.Logger
typeStr string
typeStr string
transport string
}
func (mw *databaseTracingMiddleware) Type() (string, error) {
return mw.next.Type()
}
func (mw *databaseTracingMiddleware) CreateUser(statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) {
func (mw *databaseTracingMiddleware) CreateUser(ctx context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) {
defer func(then time.Time) {
mw.logger.Trace("database", "operation", "CreateUser", "status", "finished", "type", mw.typeStr, "err", err, "took", time.Since(then))
mw.logger.Trace("database", "operation", "CreateUser", "status", "finished", "type", mw.typeStr, "transport", mw.transport, "err", err, "took", time.Since(then))
}(time.Now())
mw.logger.Trace("database", "operation", "CreateUser", "status", "started", "type", mw.typeStr)
return mw.next.CreateUser(statements, usernameConfig, expiration)
mw.logger.Trace("database", "operation", "CreateUser", "status", "started", "type", mw.typeStr, "transport", mw.transport)
return mw.next.CreateUser(ctx, statements, usernameConfig, expiration)
}
func (mw *databaseTracingMiddleware) RenewUser(statements Statements, username string, expiration time.Time) (err error) {
func (mw *databaseTracingMiddleware) RenewUser(ctx context.Context, statements Statements, username string, expiration time.Time) (err error) {
defer func(then time.Time) {
mw.logger.Trace("database", "operation", "RenewUser", "status", "finished", "type", mw.typeStr, "err", err, "took", time.Since(then))
mw.logger.Trace("database", "operation", "RenewUser", "status", "finished", "type", mw.typeStr, "transport", mw.transport, "err", err, "took", time.Since(then))
}(time.Now())
mw.logger.Trace("database", "operation", "RenewUser", "status", "started", mw.typeStr)
return mw.next.RenewUser(statements, username, expiration)
mw.logger.Trace("database", "operation", "RenewUser", "status", "started", mw.typeStr, "transport", mw.transport)
return mw.next.RenewUser(ctx, statements, username, expiration)
}
func (mw *databaseTracingMiddleware) RevokeUser(statements Statements, username string) (err error) {
func (mw *databaseTracingMiddleware) RevokeUser(ctx context.Context, statements Statements, username string) (err error) {
defer func(then time.Time) {
mw.logger.Trace("database", "operation", "RevokeUser", "status", "finished", "type", mw.typeStr, "err", err, "took", time.Since(then))
mw.logger.Trace("database", "operation", "RevokeUser", "status", "finished", "type", mw.typeStr, "transport", mw.transport, "err", err, "took", time.Since(then))
}(time.Now())
mw.logger.Trace("database", "operation", "RevokeUser", "status", "started", "type", mw.typeStr)
return mw.next.RevokeUser(statements, username)
mw.logger.Trace("database", "operation", "RevokeUser", "status", "started", "type", mw.typeStr, "transport", mw.transport)
return mw.next.RevokeUser(ctx, statements, username)
}
func (mw *databaseTracingMiddleware) Initialize(conf map[string]interface{}, verifyConnection bool) (err error) {
func (mw *databaseTracingMiddleware) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (err error) {
defer func(then time.Time) {
mw.logger.Trace("database", "operation", "Initialize", "status", "finished", "type", mw.typeStr, "verify", verifyConnection, "err", err, "took", time.Since(then))
mw.logger.Trace("database", "operation", "Initialize", "status", "finished", "type", mw.typeStr, "transport", mw.transport, "verify", verifyConnection, "err", err, "took", time.Since(then))
}(time.Now())
mw.logger.Trace("database", "operation", "Initialize", "status", "started", "type", mw.typeStr)
return mw.next.Initialize(conf, verifyConnection)
mw.logger.Trace("database", "operation", "Initialize", "status", "started", "type", mw.typeStr, "transport", mw.transport)
return mw.next.Initialize(ctx, conf, verifyConnection)
}
func (mw *databaseTracingMiddleware) Close() (err error) {
defer func(then time.Time) {
mw.logger.Trace("database", "operation", "Close", "status", "finished", "type", mw.typeStr, "err", err, "took", time.Since(then))
mw.logger.Trace("database", "operation", "Close", "status", "finished", "type", mw.typeStr, "transport", mw.transport, "err", err, "took", time.Since(then))
}(time.Now())
mw.logger.Trace("database", "operation", "Close", "status", "started", "type", mw.typeStr)
mw.logger.Trace("database", "operation", "Close", "status", "started", "type", mw.typeStr, "transport", mw.transport)
return mw.next.Close()
}
@ -81,7 +83,7 @@ func (mw *databaseMetricsMiddleware) Type() (string, error) {
return mw.next.Type()
}
func (mw *databaseMetricsMiddleware) CreateUser(statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) {
func (mw *databaseMetricsMiddleware) CreateUser(ctx context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) {
defer func(now time.Time) {
metrics.MeasureSince([]string{"database", "CreateUser"}, now)
metrics.MeasureSince([]string{"database", mw.typeStr, "CreateUser"}, now)
@ -94,10 +96,10 @@ func (mw *databaseMetricsMiddleware) CreateUser(statements Statements, usernameC
metrics.IncrCounter([]string{"database", "CreateUser"}, 1)
metrics.IncrCounter([]string{"database", mw.typeStr, "CreateUser"}, 1)
return mw.next.CreateUser(statements, usernameConfig, expiration)
return mw.next.CreateUser(ctx, statements, usernameConfig, expiration)
}
func (mw *databaseMetricsMiddleware) RenewUser(statements Statements, username string, expiration time.Time) (err error) {
func (mw *databaseMetricsMiddleware) RenewUser(ctx context.Context, statements Statements, username string, expiration time.Time) (err error) {
defer func(now time.Time) {
metrics.MeasureSince([]string{"database", "RenewUser"}, now)
metrics.MeasureSince([]string{"database", mw.typeStr, "RenewUser"}, now)
@ -110,10 +112,10 @@ func (mw *databaseMetricsMiddleware) RenewUser(statements Statements, username s
metrics.IncrCounter([]string{"database", "RenewUser"}, 1)
metrics.IncrCounter([]string{"database", mw.typeStr, "RenewUser"}, 1)
return mw.next.RenewUser(statements, username, expiration)
return mw.next.RenewUser(ctx, statements, username, expiration)
}
func (mw *databaseMetricsMiddleware) RevokeUser(statements Statements, username string) (err error) {
func (mw *databaseMetricsMiddleware) RevokeUser(ctx context.Context, statements Statements, username string) (err error) {
defer func(now time.Time) {
metrics.MeasureSince([]string{"database", "RevokeUser"}, now)
metrics.MeasureSince([]string{"database", mw.typeStr, "RevokeUser"}, now)
@ -126,10 +128,10 @@ func (mw *databaseMetricsMiddleware) RevokeUser(statements Statements, username
metrics.IncrCounter([]string{"database", "RevokeUser"}, 1)
metrics.IncrCounter([]string{"database", mw.typeStr, "RevokeUser"}, 1)
return mw.next.RevokeUser(statements, username)
return mw.next.RevokeUser(ctx, statements, username)
}
func (mw *databaseMetricsMiddleware) Initialize(conf map[string]interface{}, verifyConnection bool) (err error) {
func (mw *databaseMetricsMiddleware) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (err error) {
defer func(now time.Time) {
metrics.MeasureSince([]string{"database", "Initialize"}, now)
metrics.MeasureSince([]string{"database", mw.typeStr, "Initialize"}, now)
@ -142,7 +144,7 @@ func (mw *databaseMetricsMiddleware) Initialize(conf map[string]interface{}, ver
metrics.IncrCounter([]string{"database", "Initialize"}, 1)
metrics.IncrCounter([]string{"database", mw.typeStr, "Initialize"}, 1)
return mw.next.Initialize(conf, verifyConnection)
return mw.next.Initialize(ctx, conf, verifyConnection)
}
func (mw *databaseMetricsMiddleware) Close() (err error) {

View File

@ -0,0 +1,198 @@
package dbplugin
import (
"context"
"encoding/json"
"errors"
"time"
"google.golang.org/grpc"
"google.golang.org/grpc/connectivity"
"github.com/golang/protobuf/ptypes"
)
var (
ErrPluginShutdown = errors.New("plugin shutdown")
)
// ---- gRPC Server domain ----
type gRPCServer struct {
impl Database
}
func (s *gRPCServer) Type(context.Context, *Empty) (*TypeResponse, error) {
t, err := s.impl.Type()
if err != nil {
return nil, err
}
return &TypeResponse{
Type: t,
}, nil
}
func (s *gRPCServer) CreateUser(ctx context.Context, req *CreateUserRequest) (*CreateUserResponse, error) {
e, err := ptypes.Timestamp(req.Expiration)
if err != nil {
return nil, err
}
u, p, err := s.impl.CreateUser(ctx, *req.Statements, *req.UsernameConfig, e)
return &CreateUserResponse{
Username: u,
Password: p,
}, err
}
func (s *gRPCServer) RenewUser(ctx context.Context, req *RenewUserRequest) (*Empty, error) {
e, err := ptypes.Timestamp(req.Expiration)
if err != nil {
return nil, err
}
err = s.impl.RenewUser(ctx, *req.Statements, req.Username, e)
return &Empty{}, err
}
func (s *gRPCServer) RevokeUser(ctx context.Context, req *RevokeUserRequest) (*Empty, error) {
err := s.impl.RevokeUser(ctx, *req.Statements, req.Username)
return &Empty{}, err
}
func (s *gRPCServer) Initialize(ctx context.Context, req *InitializeRequest) (*Empty, error) {
config := map[string]interface{}{}
err := json.Unmarshal(req.Config, &config)
if err != nil {
return nil, err
}
err = s.impl.Initialize(ctx, config, req.VerifyConnection)
return &Empty{}, err
}
func (s *gRPCServer) Close(_ context.Context, _ *Empty) (*Empty, error) {
s.impl.Close()
return &Empty{}, nil
}
// ---- gRPC client domain ----
type gRPCClient struct {
client DatabaseClient
clientConn *grpc.ClientConn
}
func (c gRPCClient) Type() (string, error) {
// If the plugin has already shutdown, this will hang forever so we give it
// a one second timeout.
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
switch c.clientConn.GetState() {
case connectivity.Ready, connectivity.Idle:
default:
return "", ErrPluginShutdown
}
resp, err := c.client.Type(ctx, &Empty{})
if err != nil {
return "", err
}
return resp.Type, err
}
func (c gRPCClient) CreateUser(ctx context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) {
t, err := ptypes.TimestampProto(expiration)
if err != nil {
return "", "", err
}
switch c.clientConn.GetState() {
case connectivity.Ready, connectivity.Idle:
default:
return "", "", ErrPluginShutdown
}
resp, err := c.client.CreateUser(ctx, &CreateUserRequest{
Statements: &statements,
UsernameConfig: &usernameConfig,
Expiration: t,
})
if err != nil {
return "", "", err
}
return resp.Username, resp.Password, err
}
func (c *gRPCClient) RenewUser(ctx context.Context, statements Statements, username string, expiration time.Time) error {
t, err := ptypes.TimestampProto(expiration)
if err != nil {
return err
}
switch c.clientConn.GetState() {
case connectivity.Ready, connectivity.Idle:
default:
return ErrPluginShutdown
}
_, err = c.client.RenewUser(ctx, &RenewUserRequest{
Statements: &statements,
Username: username,
Expiration: t,
})
return err
}
func (c *gRPCClient) RevokeUser(ctx context.Context, statements Statements, username string) error {
switch c.clientConn.GetState() {
case connectivity.Ready, connectivity.Idle:
default:
return ErrPluginShutdown
}
_, err := c.client.RevokeUser(ctx, &RevokeUserRequest{
Statements: &statements,
Username: username,
})
return err
}
func (c *gRPCClient) Initialize(ctx context.Context, config map[string]interface{}, verifyConnection bool) error {
configRaw, err := json.Marshal(config)
if err != nil {
return err
}
switch c.clientConn.GetState() {
case connectivity.Ready, connectivity.Idle:
default:
return ErrPluginShutdown
}
_, err = c.client.Initialize(ctx, &InitializeRequest{
Config: configRaw,
VerifyConnection: verifyConnection,
})
return err
}
func (c *gRPCClient) Close() error {
// If the plugin has already shutdown, this will hang forever so we give it
// a one second timeout.
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
switch c.clientConn.GetState() {
case connectivity.Ready, connectivity.Idle:
_, err := c.client.Close(ctx, &Empty{})
return err
}
return nil
}

View File

@ -0,0 +1,139 @@
package dbplugin
import (
"context"
"fmt"
"net/rpc"
"time"
)
// ---- RPC server domain ----
// databasePluginRPCServer implements an RPC version of Database and is run
// inside a plugin. It wraps an underlying implementation of Database.
type databasePluginRPCServer struct {
impl Database
}
func (ds *databasePluginRPCServer) Type(_ struct{}, resp *string) error {
var err error
*resp, err = ds.impl.Type()
return err
}
func (ds *databasePluginRPCServer) CreateUser(args *CreateUserRequestRPC, resp *CreateUserResponse) error {
var err error
resp.Username, resp.Password, err = ds.impl.CreateUser(context.Background(), args.Statements, args.UsernameConfig, args.Expiration)
return err
}
func (ds *databasePluginRPCServer) RenewUser(args *RenewUserRequestRPC, _ *struct{}) error {
err := ds.impl.RenewUser(context.Background(), args.Statements, args.Username, args.Expiration)
return err
}
func (ds *databasePluginRPCServer) RevokeUser(args *RevokeUserRequestRPC, _ *struct{}) error {
err := ds.impl.RevokeUser(context.Background(), args.Statements, args.Username)
return err
}
func (ds *databasePluginRPCServer) Initialize(args *InitializeRequestRPC, _ *struct{}) error {
err := ds.impl.Initialize(context.Background(), args.Config, args.VerifyConnection)
return err
}
func (ds *databasePluginRPCServer) Close(_ struct{}, _ *struct{}) error {
ds.impl.Close()
return nil
}
// ---- RPC client domain ----
// databasePluginRPCClient implements Database and is used on the client to
// make RPC calls to a plugin.
type databasePluginRPCClient struct {
client *rpc.Client
}
func (dr *databasePluginRPCClient) Type() (string, error) {
var dbType string
err := dr.client.Call("Plugin.Type", struct{}{}, &dbType)
return fmt.Sprintf("plugin-%s", dbType), err
}
func (dr *databasePluginRPCClient) CreateUser(_ context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) {
req := CreateUserRequestRPC{
Statements: statements,
UsernameConfig: usernameConfig,
Expiration: expiration,
}
var resp CreateUserResponse
err = dr.client.Call("Plugin.CreateUser", req, &resp)
return resp.Username, resp.Password, err
}
func (dr *databasePluginRPCClient) RenewUser(_ context.Context, statements Statements, username string, expiration time.Time) error {
req := RenewUserRequestRPC{
Statements: statements,
Username: username,
Expiration: expiration,
}
err := dr.client.Call("Plugin.RenewUser", req, &struct{}{})
return err
}
func (dr *databasePluginRPCClient) RevokeUser(_ context.Context, statements Statements, username string) error {
req := RevokeUserRequestRPC{
Statements: statements,
Username: username,
}
err := dr.client.Call("Plugin.RevokeUser", req, &struct{}{})
return err
}
func (dr *databasePluginRPCClient) Initialize(_ context.Context, conf map[string]interface{}, verifyConnection bool) error {
req := InitializeRequestRPC{
Config: conf,
VerifyConnection: verifyConnection,
}
err := dr.client.Call("Plugin.Initialize", req, &struct{}{})
return err
}
func (dr *databasePluginRPCClient) Close() error {
err := dr.client.Call("Plugin.Close", struct{}{}, &struct{}{})
return err
}
// ---- RPC Request Args Domain ----
type InitializeRequestRPC struct {
Config map[string]interface{}
VerifyConnection bool
}
type CreateUserRequestRPC struct {
Statements Statements
UsernameConfig UsernameConfig
Expiration time.Time
}
type RenewUserRequestRPC struct {
Statements Statements
Username string
Expiration time.Time
}
type RevokeUserRequestRPC struct {
Statements Statements
Username string
}

View File

@ -1,10 +1,13 @@
package dbplugin
import (
"context"
"fmt"
"net/rpc"
"time"
"google.golang.org/grpc"
"github.com/hashicorp/go-plugin"
"github.com/hashicorp/vault/helper/pluginutil"
log "github.com/mgutz/logxi/v1"
@ -13,29 +16,14 @@ import (
// Database is the interface that all database objects must implement.
type Database interface {
Type() (string, error)
CreateUser(statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error)
RenewUser(statements Statements, username string, expiration time.Time) error
RevokeUser(statements Statements, username string) error
CreateUser(ctx context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error)
RenewUser(ctx context.Context, statements Statements, username string, expiration time.Time) error
RevokeUser(ctx context.Context, statements Statements, username string) error
Initialize(config map[string]interface{}, verifyConnection bool) error
Initialize(ctx context.Context, config map[string]interface{}, verifyConnection bool) error
Close() error
}
// Statements set in role creation and passed into the database type's functions.
type Statements struct {
CreationStatements string `json:"creation_statments" mapstructure:"creation_statements" structs:"creation_statments"`
RevocationStatements string `json:"revocation_statements" mapstructure:"revocation_statements" structs:"revocation_statements"`
RollbackStatements string `json:"rollback_statements" mapstructure:"rollback_statements" structs:"rollback_statements"`
RenewStatements string `json:"renew_statements" mapstructure:"renew_statements" structs:"renew_statements"`
}
// UsernameConfig is used to configure prefixes for the username to be
// generated.
type UsernameConfig struct {
DisplayName string
RoleName string
}
// PluginFactory is used to build plugin database types. It wraps the database
// object in a logging and metrics middleware.
func PluginFactory(pluginName string, sys pluginutil.LookRunnerUtil, logger log.Logger) (Database, error) {
@ -45,6 +33,7 @@ func PluginFactory(pluginName string, sys pluginutil.LookRunnerUtil, logger log.
return nil, err
}
var transport string
var db Database
if pluginRunner.Builtin {
// Plugin is builtin so we can retrieve an instance of the interface
@ -60,12 +49,24 @@ func PluginFactory(pluginName string, sys pluginutil.LookRunnerUtil, logger log.
return nil, fmt.Errorf("unsuported database type: %s", pluginName)
}
transport = "builtin"
} else {
// create a DatabasePluginClient instance
db, err = newPluginClient(sys, pluginRunner, logger)
if err != nil {
return nil, err
}
// Switch on the underlying database client type to get the transport
// method.
switch db.(*DatabasePluginClient).Database.(type) {
case *gRPCClient:
transport = "gRPC"
case *databasePluginRPCClient:
transport = "netRPC"
}
}
typeStr, err := db.Type()
@ -82,9 +83,10 @@ func PluginFactory(pluginName string, sys pluginutil.LookRunnerUtil, logger log.
// Wrap with tracing middleware
if logger.IsTrace() {
db = &databaseTracingMiddleware{
next: db,
typeStr: typeStr,
logger: logger,
transport: transport,
next: db,
typeStr: typeStr,
logger: logger,
}
}
@ -115,33 +117,14 @@ func (DatabasePlugin) Client(b *plugin.MuxBroker, c *rpc.Client) (interface{}, e
return &databasePluginRPCClient{client: c}, nil
}
// ---- RPC Request Args Domain ----
type InitializeRequest struct {
Config map[string]interface{}
VerifyConnection bool
func (d DatabasePlugin) GRPCServer(s *grpc.Server) error {
RegisterDatabaseServer(s, &gRPCServer{impl: d.impl})
return nil
}
type CreateUserRequest struct {
Statements Statements
UsernameConfig UsernameConfig
Expiration time.Time
}
type RenewUserRequest struct {
Statements Statements
Username string
Expiration time.Time
}
type RevokeUserRequest struct {
Statements Statements
Username string
}
// ---- RPC Response Args Domain ----
type CreateUserResponse struct {
Username string
Password string
func (DatabasePlugin) GRPCClient(c *grpc.ClientConn) (interface{}, error) {
return &gRPCClient{
client: NewDatabaseClient(c),
clientConn: c,
}, nil
}

View File

@ -1,11 +1,13 @@
package dbplugin_test
import (
"context"
"errors"
"os"
"testing"
"time"
plugin "github.com/hashicorp/go-plugin"
"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
"github.com/hashicorp/vault/helper/pluginutil"
vaulthttp "github.com/hashicorp/vault/http"
@ -20,7 +22,7 @@ type mockPlugin struct {
}
func (m *mockPlugin) Type() (string, error) { return "mock", nil }
func (m *mockPlugin) CreateUser(statements dbplugin.Statements, usernameConf dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
func (m *mockPlugin) CreateUser(_ context.Context, statements dbplugin.Statements, usernameConf dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
err = errors.New("err")
if usernameConf.DisplayName == "" || expiration.IsZero() {
return "", "", err
@ -34,7 +36,7 @@ func (m *mockPlugin) CreateUser(statements dbplugin.Statements, usernameConf dbp
return usernameConf.DisplayName, "test", nil
}
func (m *mockPlugin) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error {
func (m *mockPlugin) RenewUser(_ context.Context, statements dbplugin.Statements, username string, expiration time.Time) error {
err := errors.New("err")
if username == "" || expiration.IsZero() {
return err
@ -46,7 +48,7 @@ func (m *mockPlugin) RenewUser(statements dbplugin.Statements, username string,
return nil
}
func (m *mockPlugin) RevokeUser(statements dbplugin.Statements, username string) error {
func (m *mockPlugin) RevokeUser(_ context.Context, statements dbplugin.Statements, username string) error {
err := errors.New("err")
if username == "" {
return err
@ -59,7 +61,7 @@ func (m *mockPlugin) RevokeUser(statements dbplugin.Statements, username string)
delete(m.users, username)
return nil
}
func (m *mockPlugin) Initialize(conf map[string]interface{}, _ bool) error {
func (m *mockPlugin) Initialize(_ context.Context, conf map[string]interface{}, _ bool) error {
err := errors.New("err")
if len(conf) != 1 {
return err
@ -80,14 +82,15 @@ func getCluster(t *testing.T) (*vault.TestCluster, logical.SystemView) {
cores := cluster.Cores
sys := vault.TestDynamicSystemView(cores[0].Core)
vault.TestAddTestPlugin(t, cores[0].Core, "test-plugin", "TestPlugin_Main")
vault.TestAddTestPlugin(t, cores[0].Core, "test-plugin", "TestPlugin_GRPC_Main")
vault.TestAddTestPlugin(t, cores[0].Core, "test-plugin-netRPC", "TestPlugin_NetRPC_Main")
return cluster, sys
}
// This is not an actual test case, it's a helper function that will be executed
// by the go-plugin client via an exec call.
func TestPlugin_Main(t *testing.T) {
func TestPlugin_GRPC_Main(t *testing.T) {
if os.Getenv(pluginutil.PluginUnwrapTokenEnv) == "" {
return
}
@ -105,6 +108,30 @@ func TestPlugin_Main(t *testing.T) {
plugins.Serve(plugin, apiClientMeta.GetTLSConfig())
}
// This is not an actual test case, it's a helper function that will be executed
// by the go-plugin client via an exec call.
func TestPlugin_NetRPC_Main(t *testing.T) {
if os.Getenv(pluginutil.PluginUnwrapTokenEnv) == "" {
return
}
p := &mockPlugin{
users: make(map[string][]string),
}
args := []string{"--tls-skip-verify=true"}
apiClientMeta := &pluginutil.APIClientMeta{}
flags := apiClientMeta.FlagSet()
flags.Parse(args)
tlsProvider := pluginutil.VaultPluginTLSProvider(apiClientMeta.GetTLSConfig())
serveConf := dbplugin.ServeConfig(p, tlsProvider)
serveConf.GRPCServer = nil
plugin.Serve(serveConf)
}
func TestPlugin_Initialize(t *testing.T) {
cluster, sys := getCluster(t)
defer cluster.Cleanup()
@ -118,7 +145,7 @@ func TestPlugin_Initialize(t *testing.T) {
"test": 1,
}
err = dbRaw.Initialize(connectionDetails, true)
err = dbRaw.Initialize(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
@ -143,7 +170,7 @@ func TestPlugin_CreateUser(t *testing.T) {
"test": 1,
}
err = db.Initialize(connectionDetails, true)
err = db.Initialize(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
@ -153,7 +180,7 @@ func TestPlugin_CreateUser(t *testing.T) {
RoleName: "test",
}
us, pw, err := db.CreateUser(dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
us, pw, err := db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
if err != nil {
t.Fatalf("err: %s", err)
}
@ -163,7 +190,7 @@ func TestPlugin_CreateUser(t *testing.T) {
// try and save the same user again to verify it saved the first time, this
// should return an error
_, _, err = db.CreateUser(dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
_, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
if err == nil {
t.Fatal("expected an error, user wasn't created correctly")
}
@ -182,7 +209,7 @@ func TestPlugin_RenewUser(t *testing.T) {
connectionDetails := map[string]interface{}{
"test": 1,
}
err = db.Initialize(connectionDetails, true)
err = db.Initialize(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
@ -192,12 +219,12 @@ func TestPlugin_RenewUser(t *testing.T) {
RoleName: "test",
}
us, _, err := db.CreateUser(dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
us, _, err := db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
if err != nil {
t.Fatalf("err: %s", err)
}
err = db.RenewUser(dbplugin.Statements{}, us, time.Now().Add(time.Minute))
err = db.RenewUser(context.Background(), dbplugin.Statements{}, us, time.Now().Add(time.Minute))
if err != nil {
t.Fatalf("err: %s", err)
}
@ -216,7 +243,7 @@ func TestPlugin_RevokeUser(t *testing.T) {
connectionDetails := map[string]interface{}{
"test": 1,
}
err = db.Initialize(connectionDetails, true)
err = db.Initialize(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
@ -226,19 +253,159 @@ func TestPlugin_RevokeUser(t *testing.T) {
RoleName: "test",
}
us, _, err := db.CreateUser(dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
us, _, err := db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
if err != nil {
t.Fatalf("err: %s", err)
}
// Test default revoke statememts
err = db.RevokeUser(dbplugin.Statements{}, us)
err = db.RevokeUser(context.Background(), dbplugin.Statements{}, us)
if err != nil {
t.Fatalf("err: %s", err)
}
// Try adding the same username back so we can verify it was removed
_, _, err = db.CreateUser(dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
_, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
if err != nil {
t.Fatalf("err: %s", err)
}
}
// Test the code is still compatible with an old netRPC plugin
func TestPlugin_NetRPC_Initialize(t *testing.T) {
cluster, sys := getCluster(t)
defer cluster.Cleanup()
dbRaw, err := dbplugin.PluginFactory("test-plugin-netRPC", sys, &log.NullLogger{})
if err != nil {
t.Fatalf("err: %s", err)
}
connectionDetails := map[string]interface{}{
"test": 1,
}
err = dbRaw.Initialize(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
err = dbRaw.Close()
if err != nil {
t.Fatalf("err: %s", err)
}
}
func TestPlugin_NetRPC_CreateUser(t *testing.T) {
cluster, sys := getCluster(t)
defer cluster.Cleanup()
db, err := dbplugin.PluginFactory("test-plugin-netRPC", sys, &log.NullLogger{})
if err != nil {
t.Fatalf("err: %s", err)
}
defer db.Close()
connectionDetails := map[string]interface{}{
"test": 1,
}
err = db.Initialize(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
usernameConf := dbplugin.UsernameConfig{
DisplayName: "test",
RoleName: "test",
}
us, pw, err := db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
if err != nil {
t.Fatalf("err: %s", err)
}
if us != "test" || pw != "test" {
t.Fatal("expected username and password to be 'test'")
}
// try and save the same user again to verify it saved the first time, this
// should return an error
_, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
if err == nil {
t.Fatal("expected an error, user wasn't created correctly")
}
}
func TestPlugin_NetRPC_RenewUser(t *testing.T) {
cluster, sys := getCluster(t)
defer cluster.Cleanup()
db, err := dbplugin.PluginFactory("test-plugin-netRPC", sys, &log.NullLogger{})
if err != nil {
t.Fatalf("err: %s", err)
}
defer db.Close()
connectionDetails := map[string]interface{}{
"test": 1,
}
err = db.Initialize(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
usernameConf := dbplugin.UsernameConfig{
DisplayName: "test",
RoleName: "test",
}
us, _, err := db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
if err != nil {
t.Fatalf("err: %s", err)
}
err = db.RenewUser(context.Background(), dbplugin.Statements{}, us, time.Now().Add(time.Minute))
if err != nil {
t.Fatalf("err: %s", err)
}
}
func TestPlugin_NetRPC_RevokeUser(t *testing.T) {
cluster, sys := getCluster(t)
defer cluster.Cleanup()
db, err := dbplugin.PluginFactory("test-plugin-netRPC", sys, &log.NullLogger{})
if err != nil {
t.Fatalf("err: %s", err)
}
defer db.Close()
connectionDetails := map[string]interface{}{
"test": 1,
}
err = db.Initialize(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
usernameConf := dbplugin.UsernameConfig{
DisplayName: "test",
RoleName: "test",
}
us, _, err := db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
if err != nil {
t.Fatalf("err: %s", err)
}
// Test default revoke statememts
err = db.RevokeUser(context.Background(), dbplugin.Statements{}, us)
if err != nil {
t.Fatalf("err: %s", err)
}
// Try adding the same username back so we can verify it was removed
_, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
if err != nil {
t.Fatalf("err: %s", err)
}

View File

@ -10,6 +10,10 @@ import (
// Database implementation in a databasePluginRPCServer object and starts a
// RPC server.
func Serve(db Database, tlsProvider func() (*tls.Config, error)) {
plugin.Serve(ServeConfig(db, tlsProvider))
}
func ServeConfig(db Database, tlsProvider func() (*tls.Config, error)) *plugin.ServeConfig {
dbPlugin := &DatabasePlugin{
impl: db,
}
@ -19,53 +23,10 @@ func Serve(db Database, tlsProvider func() (*tls.Config, error)) {
"database": dbPlugin,
}
plugin.Serve(&plugin.ServeConfig{
return &plugin.ServeConfig{
HandshakeConfig: handshakeConfig,
Plugins: pluginMap,
TLSProvider: tlsProvider,
})
}
// ---- RPC server domain ----
// databasePluginRPCServer implements an RPC version of Database and is run
// inside a plugin. It wraps an underlying implementation of Database.
type databasePluginRPCServer struct {
impl Database
}
func (ds *databasePluginRPCServer) Type(_ struct{}, resp *string) error {
var err error
*resp, err = ds.impl.Type()
return err
}
func (ds *databasePluginRPCServer) CreateUser(args *CreateUserRequest, resp *CreateUserResponse) error {
var err error
resp.Username, resp.Password, err = ds.impl.CreateUser(args.Statements, args.UsernameConfig, args.Expiration)
return err
}
func (ds *databasePluginRPCServer) RenewUser(args *RenewUserRequest, _ *struct{}) error {
err := ds.impl.RenewUser(args.Statements, args.Username, args.Expiration)
return err
}
func (ds *databasePluginRPCServer) RevokeUser(args *RevokeUserRequest, _ *struct{}) error {
err := ds.impl.RevokeUser(args.Statements, args.Username)
return err
}
func (ds *databasePluginRPCServer) Initialize(args *InitializeRequest, _ *struct{}) error {
err := ds.impl.Initialize(args.Config, args.VerifyConnection)
return err
}
func (ds *databasePluginRPCServer) Close(_ struct{}, _ *struct{}) error {
ds.impl.Close()
return nil
GRPCServer: plugin.DefaultGRPCServer,
}
}

View File

@ -1,6 +1,7 @@
package database
import (
"context"
"errors"
"fmt"
@ -62,7 +63,7 @@ func (b *databaseBackend) pathConnectionReset() framework.OperationFunc {
b.clearConnection(name)
// Execute plugin again, we don't need the object so throw away.
_, err := b.createDBObj(req.Storage, name)
_, err := b.createDBObj(context.TODO(), req.Storage, name)
if err != nil {
return nil, err
}
@ -230,7 +231,7 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc {
return logical.ErrorResponse(fmt.Sprintf("error creating database object: %s", err)), nil
}
err = db.Initialize(config.ConnectionDetails, verifyConnection)
err = db.Initialize(context.TODO(), config.ConnectionDetails, verifyConnection)
if err != nil {
db.Close()
return logical.ErrorResponse(fmt.Sprintf("error creating database object: %s", err)), nil

View File

@ -1,6 +1,7 @@
package database
import (
"context"
"fmt"
"time"
@ -66,7 +67,7 @@ func (b *databaseBackend) pathCredsCreateRead() framework.OperationFunc {
unlockFunc = b.Unlock
// Create a new DB object
db, err = b.createDBObj(req.Storage, role.DBName)
db, err = b.createDBObj(context.TODO(), req.Storage, role.DBName)
if err != nil {
unlockFunc()
return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err)
@ -81,7 +82,7 @@ func (b *databaseBackend) pathCredsCreateRead() framework.OperationFunc {
}
// Create the user
username, password, err := db.CreateUser(role.Statements, usernameConfig, expiration)
username, password, err := db.CreateUser(context.TODO(), role.Statements, usernameConfig, expiration)
// Unlock
unlockFunc()
if err != nil {

View File

@ -1,6 +1,7 @@
package database
import (
"context"
"fmt"
"github.com/hashicorp/vault/logical"
@ -60,7 +61,7 @@ func (b *databaseBackend) secretCredsRenew() framework.OperationFunc {
unlockFunc = b.Unlock
// Create a new DB object
db, err = b.createDBObj(req.Storage, role.DBName)
db, err = b.createDBObj(context.TODO(), req.Storage, role.DBName)
if err != nil {
unlockFunc()
return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err)
@ -69,7 +70,7 @@ func (b *databaseBackend) secretCredsRenew() framework.OperationFunc {
// Make sure we increase the VALID UNTIL endpoint for this user.
if expireTime := resp.Secret.ExpirationTime(); !expireTime.IsZero() {
err := db.RenewUser(role.Statements, username, expireTime)
err := db.RenewUser(context.TODO(), role.Statements, username, expireTime)
// Unlock
unlockFunc()
if err != nil {
@ -119,14 +120,14 @@ func (b *databaseBackend) secretCredsRevoke() framework.OperationFunc {
unlockFunc = b.Unlock
// Create a new DB object
db, err = b.createDBObj(req.Storage, role.DBName)
db, err = b.createDBObj(context.TODO(), req.Storage, role.DBName)
if err != nil {
unlockFunc()
return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err)
}
}
err = db.RevokeUser(role.Statements, username)
err = db.RevokeUser(context.TODO(), role.Statements, username)
// Unlock
unlockFunc()
if err != nil {

View File

@ -43,6 +43,9 @@ type PolicyRequest struct {
// Whether to upsert
Upsert bool
// Whether to allow plaintext backup
AllowPlaintextBackup bool
}
type LockManager struct {
@ -378,10 +381,11 @@ func (lm *LockManager) getPolicyCommon(req PolicyRequest, lockType bool) (*Polic
}
p = &Policy{
Name: req.Name,
Type: req.KeyType,
Derived: req.Derived,
Exportable: req.Exportable,
Name: req.Name,
Type: req.KeyType,
Derived: req.Derived,
Exportable: req.Exportable,
AllowPlaintextBackup: req.AllowPlaintextBackup,
}
if req.Derived {
p.KDF = Kdf_hkdf_sha256

View File

@ -119,6 +119,10 @@ func (r *PluginRunner) runCommon(wrapper RunnerUtil, pluginMap map[string]plugin
SecureConfig: secureConfig,
TLSConfig: clientTLSConfig,
Logger: namedLogger,
AllowedProtocols: []plugin.Protocol{
plugin.ProtocolNetRPC,
plugin.ProtocolGRPC,
},
}
client := plugin.NewClient(clientConfig)

View File

@ -192,8 +192,7 @@ func TestBackendHandleRequest_helpRoot(t *testing.T) {
func TestBackendHandleRequest_renewAuth(t *testing.T) {
b := &Backend{}
resp, err := b.HandleRequest(logical.RenewAuthRequest(
"/foo", &logical.Auth{}, nil))
resp, err := b.HandleRequest(logical.RenewAuthRequest("/foo", &logical.Auth{}, nil))
if err != nil {
t.Fatalf("err: %s", err)
}
@ -213,8 +212,7 @@ func TestBackendHandleRequest_renewAuthCallback(t *testing.T) {
AuthRenew: callback,
}
_, err := b.HandleRequest(logical.RenewAuthRequest(
"/foo", &logical.Auth{}, nil))
_, err := b.HandleRequest(logical.RenewAuthRequest("/foo", &logical.Auth{}, nil))
if err != nil {
t.Fatalf("err: %s", err)
}
@ -237,8 +235,7 @@ func TestBackendHandleRequest_renew(t *testing.T) {
Secrets: []*Secret{secret},
}
_, err := b.HandleRequest(logical.RenewRequest(
"/foo", secret.Response(nil, nil).Secret, nil))
_, err := b.HandleRequest(logical.RenewRequest("/foo", secret.Response(nil, nil).Secret, nil))
if err != nil {
t.Fatalf("err: %s", err)
}
@ -293,8 +290,7 @@ func TestBackendHandleRequest_revoke(t *testing.T) {
Secrets: []*Secret{secret},
}
_, err := b.HandleRequest(logical.RevokeRequest(
"/foo", secret.Response(nil, nil).Secret, nil))
_, err := b.HandleRequest(logical.RevokeRequest("/foo", secret.Response(nil, nil).Secret, nil))
if err != nil {
t.Fatalf("err: %s", err)
}

View File

@ -8,7 +8,8 @@ import (
)
// LeaseExtend returns an OperationFunc that can be used to simply extend the
// lease of the auth/secret for the duration that was requested.
// lease of the auth/secret for the duration that was requested. The parameters
// provided are used to determine the lease's new TTL value.
//
// backendIncrement is the backend's requested increment -- perhaps from a user
// request, perhaps from a role/config value. If not set, uses the mount/system

View File

@ -200,8 +200,7 @@ func (r *Request) SetLastRemoteWAL(last uint64) {
}
// RenewRequest creates the structure of the renew request.
func RenewRequest(
path string, secret *Secret, data map[string]interface{}) *Request {
func RenewRequest(path string, secret *Secret, data map[string]interface{}) *Request {
return &Request{
Operation: RenewOperation,
Path: path,
@ -211,8 +210,7 @@ func RenewRequest(
}
// RenewAuthRequest creates the structure of the renew request for an auth.
func RenewAuthRequest(
path string, auth *Auth, data map[string]interface{}) *Request {
func RenewAuthRequest(path string, auth *Auth, data map[string]interface{}) *Request {
return &Request{
Operation: RenewOperation,
Path: path,
@ -222,8 +220,7 @@ func RenewAuthRequest(
}
// RevokeRequest creates the structure of the revoke request.
func RevokeRequest(
path string, secret *Secret, data map[string]interface{}) *Request {
func RevokeRequest(path string, secret *Secret, data map[string]interface{}) *Request {
return &Request{
Operation: RevokeOperation,
Path: path,

View File

@ -1,6 +1,7 @@
package cassandra
import (
"context"
"strings"
"time"
@ -21,6 +22,8 @@ const (
cassandraTypeName = "cassandra"
)
var _ dbplugin.Database = &Cassandra{}
// Cassandra is an implementation of Database interface
type Cassandra struct {
connutil.ConnectionProducer
@ -64,8 +67,8 @@ func (c *Cassandra) Type() (string, error) {
return cassandraTypeName, nil
}
func (c *Cassandra) getConnection() (*gocql.Session, error) {
session, err := c.Connection()
func (c *Cassandra) getConnection(ctx context.Context) (*gocql.Session, error) {
session, err := c.Connection(ctx)
if err != nil {
return nil, err
}
@ -75,13 +78,13 @@ func (c *Cassandra) getConnection() (*gocql.Session, error) {
// CreateUser generates the username/password on the underlying Cassandra secret backend as instructed by
// the CreationStatement provided.
func (c *Cassandra) CreateUser(statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
func (c *Cassandra) CreateUser(ctx context.Context, statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
// Grab the lock
c.Lock()
defer c.Unlock()
// Get the connection
session, err := c.getConnection()
session, err := c.getConnection(ctx)
if err != nil {
return "", "", err
}
@ -138,18 +141,18 @@ func (c *Cassandra) CreateUser(statements dbplugin.Statements, usernameConfig db
}
// RenewUser is not supported on Cassandra, so this is a no-op.
func (c *Cassandra) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error {
func (c *Cassandra) RenewUser(ctx context.Context, statements dbplugin.Statements, username string, expiration time.Time) error {
// NOOP
return nil
}
// RevokeUser attempts to drop the specified user.
func (c *Cassandra) RevokeUser(statements dbplugin.Statements, username string) error {
func (c *Cassandra) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error {
// Grab the lock
c.Lock()
defer c.Unlock()
session, err := c.getConnection()
session, err := c.getConnection(ctx)
if err != nil {
return err
}

View File

@ -1,6 +1,7 @@
package cassandra
import (
"context"
"os"
"strconv"
"testing"
@ -89,7 +90,7 @@ func TestCassandra_Initialize(t *testing.T) {
db := dbRaw.(*Cassandra)
connProducer := db.ConnectionProducer.(*cassandraConnectionProducer)
err := db.Initialize(connectionDetails, true)
err := db.Initialize(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
@ -112,7 +113,7 @@ func TestCassandra_Initialize(t *testing.T) {
"protocol_version": "4",
}
err = db.Initialize(connectionDetails, true)
err = db.Initialize(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
@ -135,7 +136,7 @@ func TestCassandra_CreateUser(t *testing.T) {
dbRaw, _ := New()
db := dbRaw.(*Cassandra)
err := db.Initialize(connectionDetails, true)
err := db.Initialize(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
@ -149,7 +150,7 @@ func TestCassandra_CreateUser(t *testing.T) {
RoleName: "test",
}
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute))
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
if err != nil {
t.Fatalf("err: %s", err)
}
@ -176,7 +177,7 @@ func TestMyCassandra_RenewUser(t *testing.T) {
dbRaw, _ := New()
db := dbRaw.(*Cassandra)
err := db.Initialize(connectionDetails, true)
err := db.Initialize(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
@ -190,7 +191,7 @@ func TestMyCassandra_RenewUser(t *testing.T) {
RoleName: "test",
}
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute))
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
if err != nil {
t.Fatalf("err: %s", err)
}
@ -199,7 +200,7 @@ func TestMyCassandra_RenewUser(t *testing.T) {
t.Fatalf("Could not connect with new credentials: %s", err)
}
err = db.RenewUser(statements, username, time.Now().Add(time.Minute))
err = db.RenewUser(context.Background(), statements, username, time.Now().Add(time.Minute))
if err != nil {
t.Fatalf("err: %s", err)
}
@ -222,7 +223,7 @@ func TestCassandra_RevokeUser(t *testing.T) {
dbRaw, _ := New()
db := dbRaw.(*Cassandra)
err := db.Initialize(connectionDetails, true)
err := db.Initialize(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
@ -236,7 +237,7 @@ func TestCassandra_RevokeUser(t *testing.T) {
RoleName: "test",
}
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute))
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
if err != nil {
t.Fatalf("err: %s", err)
}
@ -246,7 +247,7 @@ func TestCassandra_RevokeUser(t *testing.T) {
}
// Test default revoke statememts
err = db.RevokeUser(statements, username)
err = db.RevokeUser(context.Background(), statements, username)
if err != nil {
t.Fatalf("err: %s", err)
}

View File

@ -1,6 +1,7 @@
package cassandra
import (
"context"
"crypto/tls"
"fmt"
"strings"
@ -43,7 +44,7 @@ type cassandraConnectionProducer struct {
sync.Mutex
}
func (c *cassandraConnectionProducer) Initialize(conf map[string]interface{}, verifyConnection bool) error {
func (c *cassandraConnectionProducer) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error {
c.Lock()
defer c.Unlock()
@ -106,7 +107,7 @@ func (c *cassandraConnectionProducer) Initialize(conf map[string]interface{}, ve
c.Initialized = true
if verifyConnection {
if _, err := c.Connection(); err != nil {
if _, err := c.Connection(ctx); err != nil {
return fmt.Errorf("error verifying connection: %s", err)
}
}
@ -114,7 +115,7 @@ func (c *cassandraConnectionProducer) Initialize(conf map[string]interface{}, ve
return nil
}
func (c *cassandraConnectionProducer) Connection() (interface{}, error) {
func (c *cassandraConnectionProducer) Connection(_ context.Context) (interface{}, error) {
if !c.Initialized {
return nil, connutil.ErrNotInitialized
}

View File

@ -1,6 +1,7 @@
package hana
import (
"context"
"database/sql"
"fmt"
"strings"
@ -26,6 +27,8 @@ type HANA struct {
credsutil.CredentialsProducer
}
var _ dbplugin.Database = &HANA{}
// New implements builtinplugins.BuiltinFactory
func New() (interface{}, error) {
connProducer := &connutil.SQLConnectionProducer{}
@ -63,8 +66,8 @@ func (h *HANA) Type() (string, error) {
return hanaTypeName, nil
}
func (h *HANA) getConnection() (*sql.DB, error) {
db, err := h.Connection()
func (h *HANA) getConnection(ctx context.Context) (*sql.DB, error) {
db, err := h.Connection(ctx)
if err != nil {
return nil, err
}
@ -74,13 +77,13 @@ func (h *HANA) getConnection() (*sql.DB, error) {
// CreateUser generates the username/password on the underlying HANA secret backend
// as instructed by the CreationStatement provided.
func (h *HANA) CreateUser(statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
func (h *HANA) CreateUser(ctx context.Context, statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
// Grab the lock
h.Lock()
defer h.Unlock()
// Get the connection
db, err := h.getConnection()
db, err := h.getConnection(ctx)
if err != nil {
return "", "", err
}
@ -117,7 +120,7 @@ func (h *HANA) CreateUser(statements dbplugin.Statements, usernameConfig dbplugi
}
// Start a transaction
tx, err := db.Begin()
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return "", "", err
}
@ -130,7 +133,7 @@ func (h *HANA) CreateUser(statements dbplugin.Statements, usernameConfig dbplugi
continue
}
stmt, err := tx.Prepare(dbutil.QueryHelper(query, map[string]string{
stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{
"name": username,
"password": password,
"expiration": expirationStr,
@ -139,7 +142,7 @@ func (h *HANA) CreateUser(statements dbplugin.Statements, usernameConfig dbplugi
return "", "", err
}
defer stmt.Close()
if _, err := stmt.Exec(); err != nil {
if _, err := stmt.ExecContext(ctx); err != nil {
return "", "", err
}
}
@ -153,15 +156,15 @@ func (h *HANA) CreateUser(statements dbplugin.Statements, usernameConfig dbplugi
}
// Renewing hana user just means altering user's valid until property
func (h *HANA) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error {
func (h *HANA) RenewUser(ctx context.Context, statements dbplugin.Statements, username string, expiration time.Time) error {
// Get connection
db, err := h.getConnection()
db, err := h.getConnection(ctx)
if err != nil {
return err
}
// Start a transaction
tx, err := db.Begin()
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return err
}
@ -175,12 +178,12 @@ func (h *HANA) RenewUser(statements dbplugin.Statements, username string, expira
}
// Renew user's valid until property field
stmt, err := tx.Prepare("ALTER USER " + username + " VALID UNTIL " + "'" + expirationStr + "'")
stmt, err := tx.PrepareContext(ctx, "ALTER USER "+username+" VALID UNTIL "+"'"+expirationStr+"'")
if err != nil {
return err
}
defer stmt.Close()
if _, err := stmt.Exec(); err != nil {
if _, err := stmt.ExecContext(ctx); err != nil {
return err
}
@ -193,20 +196,20 @@ func (h *HANA) RenewUser(statements dbplugin.Statements, username string, expira
}
// Revoking hana user will deactivate user and try to perform a soft drop
func (h *HANA) RevokeUser(statements dbplugin.Statements, username string) error {
func (h *HANA) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error {
// default revoke will be a soft drop on user
if statements.RevocationStatements == "" {
return h.revokeUserDefault(username)
return h.revokeUserDefault(ctx, username)
}
// Get connection
db, err := h.getConnection()
db, err := h.getConnection(ctx)
if err != nil {
return err
}
// Start a transaction
tx, err := db.Begin()
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return err
}
@ -219,14 +222,14 @@ func (h *HANA) RevokeUser(statements dbplugin.Statements, username string) error
continue
}
stmt, err := tx.Prepare(dbutil.QueryHelper(query, map[string]string{
stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{
"name": username,
}))
if err != nil {
return err
}
defer stmt.Close()
if _, err := stmt.Exec(); err != nil {
if _, err := stmt.ExecContext(ctx); err != nil {
return err
}
}
@ -239,38 +242,38 @@ func (h *HANA) RevokeUser(statements dbplugin.Statements, username string) error
return nil
}
func (h *HANA) revokeUserDefault(username string) error {
func (h *HANA) revokeUserDefault(ctx context.Context, username string) error {
// Get connection
db, err := h.getConnection()
db, err := h.getConnection(ctx)
if err != nil {
return err
}
// Start a transaction
tx, err := db.Begin()
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return err
}
defer tx.Rollback()
// Disable server login for user
disableStmt, err := tx.Prepare(fmt.Sprintf("ALTER USER %s DEACTIVATE USER NOW", username))
disableStmt, err := tx.PrepareContext(ctx, fmt.Sprintf("ALTER USER %s DEACTIVATE USER NOW", username))
if err != nil {
return err
}
defer disableStmt.Close()
if _, err := disableStmt.Exec(); err != nil {
if _, err := disableStmt.ExecContext(ctx); err != nil {
return err
}
// Invalidates current sessions and performs soft drop (drop if no dependencies)
// if hard drop is desired, custom revoke statements should be written for role
dropStmt, err := tx.Prepare(fmt.Sprintf("DROP USER %s RESTRICT", username))
dropStmt, err := tx.PrepareContext(ctx, fmt.Sprintf("DROP USER %s RESTRICT", username))
if err != nil {
return err
}
defer dropStmt.Close()
if _, err := dropStmt.Exec(); err != nil {
if _, err := dropStmt.ExecContext(ctx); err != nil {
return err
}

View File

@ -1,6 +1,7 @@
package hana
import (
"context"
"database/sql"
"fmt"
"os"
@ -25,7 +26,7 @@ func TestHANA_Initialize(t *testing.T) {
dbRaw, _ := New()
db := dbRaw.(*HANA)
err := db.Initialize(connectionDetails, true)
err := db.Initialize(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
@ -55,7 +56,7 @@ func TestHANA_CreateUser(t *testing.T) {
dbRaw, _ := New()
db := dbRaw.(*HANA)
err := db.Initialize(connectionDetails, true)
err := db.Initialize(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
@ -66,7 +67,7 @@ func TestHANA_CreateUser(t *testing.T) {
}
// Test with no configured Creation Statememt
_, _, err = db.CreateUser(dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Hour))
_, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Hour))
if err == nil {
t.Fatal("Expected error when no creation statement is provided")
}
@ -75,7 +76,7 @@ func TestHANA_CreateUser(t *testing.T) {
CreationStatements: testHANARole,
}
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Hour))
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Hour))
if err != nil {
t.Fatalf("err: %s", err)
}
@ -98,7 +99,7 @@ func TestHANA_RevokeUser(t *testing.T) {
dbRaw, _ := New()
db := dbRaw.(*HANA)
err := db.Initialize(connectionDetails, true)
err := db.Initialize(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
@ -113,7 +114,7 @@ func TestHANA_RevokeUser(t *testing.T) {
}
// Test default revoke statememts
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Hour))
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Hour))
if err != nil {
t.Fatalf("err: %s", err)
}
@ -121,7 +122,7 @@ func TestHANA_RevokeUser(t *testing.T) {
t.Fatalf("Could not connect with new credentials: %s", err)
}
err = db.RevokeUser(statements, username)
err = db.RevokeUser(context.Background(), statements, username)
if err != nil {
t.Fatalf("err: %s", err)
}
@ -130,7 +131,7 @@ func TestHANA_RevokeUser(t *testing.T) {
}
// Test custom revoke statememt
username, password, err = db.CreateUser(statements, usernameConfig, time.Now().Add(time.Hour))
username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Hour))
if err != nil {
t.Fatalf("err: %s", err)
}
@ -139,7 +140,7 @@ func TestHANA_RevokeUser(t *testing.T) {
}
statements.RevocationStatements = testHANADrop
err = db.RevokeUser(statements, username)
err = db.RevokeUser(context.Background(), statements, username)
if err != nil {
t.Fatalf("err: %s", err)
}

View File

@ -1,6 +1,7 @@
package mongodb
import (
"context"
"crypto/tls"
"encoding/base64"
"encoding/json"
@ -33,7 +34,7 @@ type mongoDBConnectionProducer struct {
}
// Initialize parses connection configuration.
func (c *mongoDBConnectionProducer) Initialize(conf map[string]interface{}, verifyConnection bool) error {
func (c *mongoDBConnectionProducer) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error {
c.Lock()
defer c.Unlock()
@ -75,7 +76,7 @@ func (c *mongoDBConnectionProducer) Initialize(conf map[string]interface{}, veri
c.Initialized = true
if verifyConnection {
if _, err := c.Connection(); err != nil {
if _, err := c.Connection(ctx); err != nil {
return fmt.Errorf("error verifying connection: %s", err)
}
@ -88,7 +89,7 @@ func (c *mongoDBConnectionProducer) Initialize(conf map[string]interface{}, veri
}
// Connection creates a database connection.
func (c *mongoDBConnectionProducer) Connection() (interface{}, error) {
func (c *mongoDBConnectionProducer) Connection(_ context.Context) (interface{}, error) {
if !c.Initialized {
return nil, connutil.ErrNotInitialized
}

View File

@ -1,6 +1,7 @@
package mongodb
import (
"context"
"io"
"strings"
"time"
@ -27,6 +28,8 @@ type MongoDB struct {
credsutil.CredentialsProducer
}
var _ dbplugin.Database = &MongoDB{}
// New returns a new MongoDB instance
func New() (interface{}, error) {
connProducer := &mongoDBConnectionProducer{}
@ -63,8 +66,8 @@ func (m *MongoDB) Type() (string, error) {
return mongoDBTypeName, nil
}
func (m *MongoDB) getConnection() (*mgo.Session, error) {
session, err := m.Connection()
func (m *MongoDB) getConnection(ctx context.Context) (*mgo.Session, error) {
session, err := m.Connection(ctx)
if err != nil {
return nil, err
}
@ -80,7 +83,7 @@ func (m *MongoDB) getConnection() (*mgo.Session, error) {
//
// JSON Example:
// { "db": "admin", "roles": [{ "role": "readWrite" }, {"role": "read", "db": "foo"}] }
func (m *MongoDB) CreateUser(statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
func (m *MongoDB) CreateUser(ctx context.Context, statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
// Grab the lock
m.Lock()
defer m.Unlock()
@ -89,7 +92,7 @@ func (m *MongoDB) CreateUser(statements dbplugin.Statements, usernameConfig dbpl
return "", "", dbutil.ErrEmptyCreationStatement
}
session, err := m.getConnection()
session, err := m.getConnection(ctx)
if err != nil {
return "", "", err
}
@ -133,7 +136,7 @@ func (m *MongoDB) CreateUser(statements dbplugin.Statements, usernameConfig dbpl
if err := m.ConnectionProducer.Close(); err != nil {
return "", "", errwrap.Wrapf("error closing EOF'd mongo connection: {{err}}", err)
}
session, err := m.getConnection()
session, err := m.getConnection(ctx)
if err != nil {
return "", "", err
}
@ -149,15 +152,15 @@ func (m *MongoDB) CreateUser(statements dbplugin.Statements, usernameConfig dbpl
}
// RenewUser is not supported on MongoDB, so this is a no-op.
func (m *MongoDB) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error {
func (m *MongoDB) RenewUser(ctx context.Context, statements dbplugin.Statements, username string, expiration time.Time) error {
// NOOP
return nil
}
// RevokeUser drops the specified user from the authentication databse. If none is provided
// in the revocation statement, the default "admin" authentication database will be assumed.
func (m *MongoDB) RevokeUser(statements dbplugin.Statements, username string) error {
session, err := m.getConnection()
func (m *MongoDB) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error {
session, err := m.getConnection(ctx)
if err != nil {
return err
}
@ -188,7 +191,7 @@ func (m *MongoDB) RevokeUser(statements dbplugin.Statements, username string) er
if err := m.ConnectionProducer.Close(); err != nil {
return errwrap.Wrapf("error closing EOF'd mongo connection: {{err}}", err)
}
session, err := m.getConnection()
session, err := m.getConnection(ctx)
if err != nil {
return err
}

View File

@ -1,6 +1,7 @@
package mongodb
import (
"context"
"fmt"
"os"
"testing"
@ -79,7 +80,7 @@ func TestMongoDB_Initialize(t *testing.T) {
db := dbRaw.(*MongoDB)
connProducer := db.ConnectionProducer.(*mongoDBConnectionProducer)
err = db.Initialize(connectionDetails, true)
err = db.Initialize(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
@ -107,7 +108,7 @@ func TestMongoDB_CreateUser(t *testing.T) {
t.Fatalf("err: %s", err)
}
db := dbRaw.(*MongoDB)
err = db.Initialize(connectionDetails, true)
err = db.Initialize(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
@ -121,7 +122,7 @@ func TestMongoDB_CreateUser(t *testing.T) {
RoleName: "test",
}
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute))
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
if err != nil {
t.Fatalf("err: %s", err)
}
@ -145,7 +146,7 @@ func TestMongoDB_CreateUser_writeConcern(t *testing.T) {
t.Fatalf("err: %s", err)
}
db := dbRaw.(*MongoDB)
err = db.Initialize(connectionDetails, true)
err = db.Initialize(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
@ -159,7 +160,7 @@ func TestMongoDB_CreateUser_writeConcern(t *testing.T) {
RoleName: "test",
}
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute))
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
if err != nil {
t.Fatalf("err: %s", err)
}
@ -182,7 +183,7 @@ func TestMongoDB_RevokeUser(t *testing.T) {
t.Fatalf("err: %s", err)
}
db := dbRaw.(*MongoDB)
err = db.Initialize(connectionDetails, true)
err = db.Initialize(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
@ -196,7 +197,7 @@ func TestMongoDB_RevokeUser(t *testing.T) {
RoleName: "test",
}
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute))
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
if err != nil {
t.Fatalf("err: %s", err)
}
@ -206,7 +207,7 @@ func TestMongoDB_RevokeUser(t *testing.T) {
}
// Test default revocation statememt
err = db.RevokeUser(statements, username)
err = db.RevokeUser(context.Background(), statements, username)
if err != nil {
t.Fatalf("err: %s", err)
}

View File

@ -1,6 +1,7 @@
package mssql
import (
"context"
"database/sql"
"fmt"
"strings"
@ -18,6 +19,8 @@ import (
const msSQLTypeName = "mssql"
var _ dbplugin.Database = &MSSQL{}
// MSSQL is an implementation of Database interface
type MSSQL struct {
connutil.ConnectionProducer
@ -60,8 +63,8 @@ func (m *MSSQL) Type() (string, error) {
return msSQLTypeName, nil
}
func (m *MSSQL) getConnection() (*sql.DB, error) {
db, err := m.Connection()
func (m *MSSQL) getConnection(ctx context.Context) (*sql.DB, error) {
db, err := m.Connection(ctx)
if err != nil {
return nil, err
}
@ -71,13 +74,13 @@ func (m *MSSQL) getConnection() (*sql.DB, error) {
// CreateUser generates the username/password on the underlying MSSQL secret backend as instructed by
// the CreationStatement provided.
func (m *MSSQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
func (m *MSSQL) CreateUser(ctx context.Context, statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
// Grab the lock
m.Lock()
defer m.Unlock()
// Get the connection
db, err := m.getConnection()
db, err := m.getConnection(ctx)
if err != nil {
return "", "", err
}
@ -102,7 +105,7 @@ func (m *MSSQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplug
}
// Start a transaction
tx, err := db.Begin()
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return "", "", err
}
@ -115,7 +118,7 @@ func (m *MSSQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplug
continue
}
stmt, err := tx.Prepare(dbutil.QueryHelper(query, map[string]string{
stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{
"name": username,
"password": password,
"expiration": expirationStr,
@ -124,7 +127,7 @@ func (m *MSSQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplug
return "", "", err
}
defer stmt.Close()
if _, err := stmt.Exec(); err != nil {
if _, err := stmt.ExecContext(ctx); err != nil {
return "", "", err
}
}
@ -138,7 +141,7 @@ func (m *MSSQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplug
}
// RenewUser is not supported on MSSQL, so this is a no-op.
func (m *MSSQL) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error {
func (m *MSSQL) RenewUser(ctx context.Context, statements dbplugin.Statements, username string, expiration time.Time) error {
// NOOP
return nil
}
@ -146,19 +149,19 @@ func (m *MSSQL) RenewUser(statements dbplugin.Statements, username string, expir
// RevokeUser attempts to drop the specified user. It will first attempt to disable login,
// then kill pending connections from that user, and finally drop the user and login from the
// database instance.
func (m *MSSQL) RevokeUser(statements dbplugin.Statements, username string) error {
func (m *MSSQL) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error {
if statements.RevocationStatements == "" {
return m.revokeUserDefault(username)
return m.revokeUserDefault(ctx, username)
}
// Get connection
db, err := m.getConnection()
db, err := m.getConnection(ctx)
if err != nil {
return err
}
// Start a transaction
tx, err := db.Begin()
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return err
}
@ -171,14 +174,14 @@ func (m *MSSQL) RevokeUser(statements dbplugin.Statements, username string) erro
continue
}
stmt, err := tx.Prepare(dbutil.QueryHelper(query, map[string]string{
stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{
"name": username,
}))
if err != nil {
return err
}
defer stmt.Close()
if _, err := stmt.Exec(); err != nil {
if _, err := stmt.ExecContext(ctx); err != nil {
return err
}
}
@ -191,20 +194,20 @@ func (m *MSSQL) RevokeUser(statements dbplugin.Statements, username string) erro
return nil
}
func (m *MSSQL) revokeUserDefault(username string) error {
func (m *MSSQL) revokeUserDefault(ctx context.Context, username string) error {
// Get connection
db, err := m.getConnection()
db, err := m.getConnection(ctx)
if err != nil {
return err
}
// First disable server login
disableStmt, err := db.Prepare(fmt.Sprintf("ALTER LOGIN [%s] DISABLE;", username))
disableStmt, err := db.PrepareContext(ctx, fmt.Sprintf("ALTER LOGIN [%s] DISABLE;", username))
if err != nil {
return err
}
defer disableStmt.Close()
if _, err := disableStmt.Exec(); err != nil {
if _, err := disableStmt.ExecContext(ctx); err != nil {
return err
}
@ -212,14 +215,14 @@ func (m *MSSQL) revokeUserDefault(username string) error {
// sessions. There cannot be any active sessions before we drop the logins
// This isn't done in a transaction because even if we fail along the way,
// we want to remove as much access as possible
sessionStmt, err := db.Prepare(fmt.Sprintf(
sessionStmt, err := db.PrepareContext(ctx, fmt.Sprintf(
"SELECT session_id FROM sys.dm_exec_sessions WHERE login_name = '%s';", username))
if err != nil {
return err
}
defer sessionStmt.Close()
sessionRows, err := sessionStmt.Query()
sessionRows, err := sessionStmt.QueryContext(ctx)
if err != nil {
return err
}
@ -240,13 +243,13 @@ func (m *MSSQL) revokeUserDefault(username string) error {
// we need to drop the database users before we can drop the login and the role
// This isn't done in a transaction because even if we fail along the way,
// we want to remove as much access as possible
stmt, err := db.Prepare(fmt.Sprintf("EXEC master.dbo.sp_msloginmappings '%s';", username))
stmt, err := db.PrepareContext(ctx, fmt.Sprintf("EXEC master.dbo.sp_msloginmappings '%s';", username))
if err != nil {
return err
}
defer stmt.Close()
rows, err := stmt.Query()
rows, err := stmt.QueryContext(ctx)
if err != nil {
return err
}
@ -266,13 +269,13 @@ func (m *MSSQL) revokeUserDefault(username string) error {
// many permissions as possible right now
var lastStmtError error
for _, query := range revokeStmts {
stmt, err := db.Prepare(query)
stmt, err := db.PrepareContext(ctx, query)
if err != nil {
lastStmtError = err
continue
}
defer stmt.Close()
_, err = stmt.Exec()
_, err = stmt.ExecContext(ctx)
if err != nil {
lastStmtError = err
}
@ -287,12 +290,12 @@ func (m *MSSQL) revokeUserDefault(username string) error {
}
// Drop this login
stmt, err = db.Prepare(fmt.Sprintf(dropLoginSQL, username, username))
stmt, err = db.PrepareContext(ctx, fmt.Sprintf(dropLoginSQL, username, username))
if err != nil {
return err
}
defer stmt.Close()
if _, err := stmt.Exec(); err != nil {
if _, err := stmt.ExecContext(ctx); err != nil {
return err
}

View File

@ -1,6 +1,7 @@
package mssql
import (
"context"
"database/sql"
"fmt"
"os"
@ -30,7 +31,7 @@ func TestMSSQL_Initialize(t *testing.T) {
dbRaw, _ := New()
db := dbRaw.(*MSSQL)
err := db.Initialize(connectionDetails, true)
err := db.Initialize(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
@ -51,7 +52,7 @@ func TestMSSQL_Initialize(t *testing.T) {
"max_open_connections": "5",
}
err = db.Initialize(connectionDetails, true)
err = db.Initialize(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
@ -69,7 +70,7 @@ func TestMSSQL_CreateUser(t *testing.T) {
dbRaw, _ := New()
db := dbRaw.(*MSSQL)
err := db.Initialize(connectionDetails, true)
err := db.Initialize(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
@ -80,7 +81,7 @@ func TestMSSQL_CreateUser(t *testing.T) {
}
// Test with no configured Creation Statememt
_, _, err = db.CreateUser(dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute))
_, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute))
if err == nil {
t.Fatal("Expected error when no creation statement is provided")
}
@ -89,7 +90,7 @@ func TestMSSQL_CreateUser(t *testing.T) {
CreationStatements: testMSSQLRole,
}
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute))
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
if err != nil {
t.Fatalf("err: %s", err)
}
@ -111,7 +112,7 @@ func TestMSSQL_RevokeUser(t *testing.T) {
dbRaw, _ := New()
db := dbRaw.(*MSSQL)
err := db.Initialize(connectionDetails, true)
err := db.Initialize(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
@ -125,7 +126,7 @@ func TestMSSQL_RevokeUser(t *testing.T) {
RoleName: "test",
}
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(2*time.Second))
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second))
if err != nil {
t.Fatalf("err: %s", err)
}
@ -135,7 +136,7 @@ func TestMSSQL_RevokeUser(t *testing.T) {
}
// Test default revoke statememts
err = db.RevokeUser(statements, username)
err = db.RevokeUser(context.Background(), statements, username)
if err != nil {
t.Fatalf("err: %s", err)
}
@ -144,7 +145,7 @@ func TestMSSQL_RevokeUser(t *testing.T) {
t.Fatal("Credentials were not revoked")
}
username, password, err = db.CreateUser(statements, usernameConfig, time.Now().Add(2*time.Second))
username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second))
if err != nil {
t.Fatalf("err: %s", err)
}
@ -155,7 +156,7 @@ func TestMSSQL_RevokeUser(t *testing.T) {
// Test custom revoke statememt
statements.RevocationStatements = testMSSQLDrop
err = db.RevokeUser(statements, username)
err = db.RevokeUser(context.Background(), statements, username)
if err != nil {
t.Fatalf("err: %s", err)
}

View File

@ -1,6 +1,7 @@
package mysql
import (
"context"
"database/sql"
"strings"
"time"
@ -30,6 +31,8 @@ var (
LegacyUsernameLen int = 16
)
var _ dbplugin.Database = &MySQL{}
type MySQL struct {
connutil.ConnectionProducer
credsutil.CredentialsProducer
@ -88,8 +91,8 @@ func (m *MySQL) Type() (string, error) {
return mySQLTypeName, nil
}
func (m *MySQL) getConnection() (*sql.DB, error) {
db, err := m.Connection()
func (m *MySQL) getConnection(ctx context.Context) (*sql.DB, error) {
db, err := m.Connection(ctx)
if err != nil {
return nil, err
}
@ -97,13 +100,13 @@ func (m *MySQL) getConnection() (*sql.DB, error) {
return db.(*sql.DB), nil
}
func (m *MySQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
func (m *MySQL) CreateUser(ctx context.Context, statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
// Grab the lock
m.Lock()
defer m.Unlock()
// Get the connection
db, err := m.getConnection()
db, err := m.getConnection(ctx)
if err != nil {
return "", "", err
}
@ -128,7 +131,7 @@ func (m *MySQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplug
}
// Start a transaction
tx, err := db.Begin()
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return "", "", err
}
@ -146,7 +149,7 @@ func (m *MySQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplug
"expiration": expirationStr,
})
stmt, err := tx.Prepare(query)
stmt, err := tx.PrepareContext(ctx, query)
if err != nil {
// If the error code we get back is Error 1295: This command is not
// supported in the prepared statement protocol yet, we will execute
@ -155,7 +158,7 @@ func (m *MySQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplug
// prepare supported commands. If there is no error when running we
// will continue to the next statement.
if e, ok := err.(*stdmysql.MySQLError); ok && e.Number == 1295 {
_, err = tx.Exec(query)
_, err = tx.ExecContext(ctx, query)
if err != nil {
return "", "", err
}
@ -165,7 +168,7 @@ func (m *MySQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplug
return "", "", err
}
defer stmt.Close()
if _, err := stmt.Exec(); err != nil {
if _, err := stmt.ExecContext(ctx); err != nil {
return "", "", err
}
}
@ -179,17 +182,17 @@ func (m *MySQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplug
}
// NOOP
func (m *MySQL) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error {
func (m *MySQL) RenewUser(ctx context.Context, statements dbplugin.Statements, username string, expiration time.Time) error {
return nil
}
func (m *MySQL) RevokeUser(statements dbplugin.Statements, username string) error {
func (m *MySQL) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error {
// Grab the read lock
m.Lock()
defer m.Unlock()
// Get the connection
db, err := m.getConnection()
db, err := m.getConnection(ctx)
if err != nil {
return err
}
@ -201,7 +204,7 @@ func (m *MySQL) RevokeUser(statements dbplugin.Statements, username string) erro
}
// Start a transaction
tx, err := db.Begin()
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return err
}
@ -217,7 +220,7 @@ func (m *MySQL) RevokeUser(statements dbplugin.Statements, username string) erro
// 1295: This command is not supported in the prepared statement protocol yet
// Reference https://mariadb.com/kb/en/mariadb/prepare-statement/
query = strings.Replace(query, "{{name}}", username, -1)
_, err = tx.Exec(query)
_, err = tx.ExecContext(ctx, query)
if err != nil {
return err
}

View File

@ -1,6 +1,7 @@
package mysql
import (
"context"
"database/sql"
"fmt"
"os"
@ -108,7 +109,7 @@ func TestMySQL_Initialize(t *testing.T) {
db := dbRaw.(*MySQL)
connProducer := db.ConnectionProducer.(*connutil.SQLConnectionProducer)
err := db.Initialize(connectionDetails, true)
err := db.Initialize(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
@ -128,7 +129,7 @@ func TestMySQL_Initialize(t *testing.T) {
"max_open_connections": "5",
}
err = db.Initialize(connectionDetails, true)
err = db.Initialize(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
@ -146,7 +147,7 @@ func TestMySQL_CreateUser(t *testing.T) {
dbRaw, _ := f()
db := dbRaw.(*MySQL)
err := db.Initialize(connectionDetails, true)
err := db.Initialize(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
@ -157,7 +158,7 @@ func TestMySQL_CreateUser(t *testing.T) {
}
// Test with no configured Creation Statememt
_, _, err = db.CreateUser(dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute))
_, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute))
if err == nil {
t.Fatal("Expected error when no creation statement is provided")
}
@ -166,7 +167,7 @@ func TestMySQL_CreateUser(t *testing.T) {
CreationStatements: testMySQLRoleWildCard,
}
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute))
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
if err != nil {
t.Fatalf("err: %s", err)
}
@ -176,7 +177,7 @@ func TestMySQL_CreateUser(t *testing.T) {
}
// Test a second time to make sure usernames don't collide
username, password, err = db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute))
username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
if err != nil {
t.Fatalf("err: %s", err)
}
@ -188,7 +189,7 @@ func TestMySQL_CreateUser(t *testing.T) {
// Test with a manualy prepare statement
statements.CreationStatements = testMySQLRolePreparedStmt
username, password, err = db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute))
username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
if err != nil {
t.Fatalf("err: %s", err)
}
@ -211,7 +212,7 @@ func TestMySQL_CreateUser_Legacy(t *testing.T) {
dbRaw, _ := f()
db := dbRaw.(*MySQL)
err := db.Initialize(connectionDetails, true)
err := db.Initialize(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
@ -222,7 +223,7 @@ func TestMySQL_CreateUser_Legacy(t *testing.T) {
}
// Test with no configured Creation Statememt
_, _, err = db.CreateUser(dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute))
_, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute))
if err == nil {
t.Fatal("Expected error when no creation statement is provided")
}
@ -231,7 +232,7 @@ func TestMySQL_CreateUser_Legacy(t *testing.T) {
CreationStatements: testMySQLRoleWildCard,
}
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute))
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
if err != nil {
t.Fatalf("err: %s", err)
}
@ -241,7 +242,7 @@ func TestMySQL_CreateUser_Legacy(t *testing.T) {
}
// Test a second time to make sure usernames don't collide
username, password, err = db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute))
username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
if err != nil {
t.Fatalf("err: %s", err)
}
@ -263,7 +264,7 @@ func TestMySQL_RevokeUser(t *testing.T) {
dbRaw, _ := f()
db := dbRaw.(*MySQL)
err := db.Initialize(connectionDetails, true)
err := db.Initialize(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
@ -277,7 +278,7 @@ func TestMySQL_RevokeUser(t *testing.T) {
RoleName: "test",
}
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute))
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
if err != nil {
t.Fatalf("err: %s", err)
}
@ -287,7 +288,7 @@ func TestMySQL_RevokeUser(t *testing.T) {
}
// Test default revoke statememts
err = db.RevokeUser(statements, username)
err = db.RevokeUser(context.Background(), statements, username)
if err != nil {
t.Fatalf("err: %s", err)
}
@ -297,7 +298,7 @@ func TestMySQL_RevokeUser(t *testing.T) {
}
statements.CreationStatements = testMySQLRoleWildCard
username, password, err = db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute))
username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
if err != nil {
t.Fatalf("err: %s", err)
}
@ -308,7 +309,7 @@ func TestMySQL_RevokeUser(t *testing.T) {
// Test custom revoke statements
statements.RevocationStatements = testMySQLRevocationSQL
err = db.RevokeUser(statements, username)
err = db.RevokeUser(context.Background(), statements, username)
if err != nil {
t.Fatalf("err: %s", err)
}

View File

@ -1,6 +1,7 @@
package postgresql
import (
"context"
"database/sql"
"fmt"
"strings"
@ -24,6 +25,8 @@ ALTER ROLE "{{name}}" VALID UNTIL '{{expiration}}';
`
)
var _ dbplugin.Database = &PostgreSQL{}
// New implements builtinplugins.BuiltinFactory
func New() (interface{}, error) {
connProducer := &connutil.SQLConnectionProducer{}
@ -65,8 +68,8 @@ func (p *PostgreSQL) Type() (string, error) {
return postgreSQLTypeName, nil
}
func (p *PostgreSQL) getConnection() (*sql.DB, error) {
db, err := p.Connection()
func (p *PostgreSQL) getConnection(ctx context.Context) (*sql.DB, error) {
db, err := p.Connection(ctx)
if err != nil {
return nil, err
}
@ -74,7 +77,7 @@ func (p *PostgreSQL) getConnection() (*sql.DB, error) {
return db.(*sql.DB), nil
}
func (p *PostgreSQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
func (p *PostgreSQL) CreateUser(ctx context.Context, statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
if statements.CreationStatements == "" {
return "", "", dbutil.ErrEmptyCreationStatement
}
@ -99,14 +102,14 @@ func (p *PostgreSQL) CreateUser(statements dbplugin.Statements, usernameConfig d
}
// Get the connection
db, err := p.getConnection()
db, err := p.getConnection(ctx)
if err != nil {
return "", "", err
}
// Start a transaction
tx, err := db.Begin()
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return "", "", err
@ -123,7 +126,7 @@ func (p *PostgreSQL) CreateUser(statements dbplugin.Statements, usernameConfig d
continue
}
stmt, err := tx.Prepare(dbutil.QueryHelper(query, map[string]string{
stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{
"name": username,
"password": password,
"expiration": expirationStr,
@ -133,7 +136,7 @@ func (p *PostgreSQL) CreateUser(statements dbplugin.Statements, usernameConfig d
}
defer stmt.Close()
if _, err := stmt.Exec(); err != nil {
if _, err := stmt.ExecContext(ctx); err != nil {
return "", "", err
}
@ -148,7 +151,7 @@ func (p *PostgreSQL) CreateUser(statements dbplugin.Statements, usernameConfig d
return username, password, nil
}
func (p *PostgreSQL) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error {
func (p *PostgreSQL) RenewUser(ctx context.Context, statements dbplugin.Statements, username string, expiration time.Time) error {
p.Lock()
defer p.Unlock()
@ -157,12 +160,12 @@ func (p *PostgreSQL) RenewUser(statements dbplugin.Statements, username string,
renewStmts = defaultPostgresRenewSQL
}
db, err := p.getConnection()
db, err := p.getConnection(ctx)
if err != nil {
return err
}
tx, err := db.Begin()
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return err
}
@ -180,7 +183,7 @@ func (p *PostgreSQL) RenewUser(statements dbplugin.Statements, username string,
if len(query) == 0 {
continue
}
stmt, err := tx.Prepare(dbutil.QueryHelper(query, map[string]string{
stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{
"name": username,
"expiration": expirationStr,
}))
@ -189,7 +192,7 @@ func (p *PostgreSQL) RenewUser(statements dbplugin.Statements, username string,
}
defer stmt.Close()
if _, err := stmt.Exec(); err != nil {
if _, err := stmt.ExecContext(ctx); err != nil {
return err
}
}
@ -201,25 +204,25 @@ func (p *PostgreSQL) RenewUser(statements dbplugin.Statements, username string,
return nil
}
func (p *PostgreSQL) RevokeUser(statements dbplugin.Statements, username string) error {
func (p *PostgreSQL) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error {
// Grab the lock
p.Lock()
defer p.Unlock()
if statements.RevocationStatements == "" {
return p.defaultRevokeUser(username)
return p.defaultRevokeUser(ctx, username)
}
return p.customRevokeUser(username, statements.RevocationStatements)
return p.customRevokeUser(ctx, username, statements.RevocationStatements)
}
func (p *PostgreSQL) customRevokeUser(username, revocationStmts string) error {
db, err := p.getConnection()
func (p *PostgreSQL) customRevokeUser(ctx context.Context, username, revocationStmts string) error {
db, err := p.getConnection(ctx)
if err != nil {
return err
}
tx, err := db.Begin()
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return err
}
@ -233,7 +236,7 @@ func (p *PostgreSQL) customRevokeUser(username, revocationStmts string) error {
continue
}
stmt, err := tx.Prepare(dbutil.QueryHelper(query, map[string]string{
stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{
"name": username,
}))
if err != nil {
@ -241,7 +244,7 @@ func (p *PostgreSQL) customRevokeUser(username, revocationStmts string) error {
}
defer stmt.Close()
if _, err := stmt.Exec(); err != nil {
if _, err := stmt.ExecContext(ctx); err != nil {
return err
}
}
@ -253,15 +256,15 @@ func (p *PostgreSQL) customRevokeUser(username, revocationStmts string) error {
return nil
}
func (p *PostgreSQL) defaultRevokeUser(username string) error {
db, err := p.getConnection()
func (p *PostgreSQL) defaultRevokeUser(ctx context.Context, username string) error {
db, err := p.getConnection(ctx)
if err != nil {
return err
}
// Check if the role exists
var exists bool
err = db.QueryRow("SELECT exists (SELECT rolname FROM pg_roles WHERE rolname=$1);", username).Scan(&exists)
err = db.QueryRowContext(ctx, "SELECT exists (SELECT rolname FROM pg_roles WHERE rolname=$1);", username).Scan(&exists)
if err != nil && err != sql.ErrNoRows {
return err
}
@ -274,13 +277,13 @@ func (p *PostgreSQL) defaultRevokeUser(username string) error {
// the role
// This isn't done in a transaction because even if we fail along the way,
// we want to remove as much access as possible
stmt, err := db.Prepare("SELECT DISTINCT table_schema FROM information_schema.role_column_grants WHERE grantee=$1;")
stmt, err := db.PrepareContext(ctx, "SELECT DISTINCT table_schema FROM information_schema.role_column_grants WHERE grantee=$1;")
if err != nil {
return err
}
defer stmt.Close()
rows, err := stmt.Query(username)
rows, err := stmt.QueryContext(ctx, username)
if err != nil {
return err
}
@ -322,7 +325,7 @@ func (p *PostgreSQL) defaultRevokeUser(username string) error {
// get the current database name so we can issue a REVOKE CONNECT for
// this username
var dbname sql.NullString
if err := db.QueryRow("SELECT current_database();").Scan(&dbname); err != nil {
if err := db.QueryRowContext(ctx, "SELECT current_database();").Scan(&dbname); err != nil {
return err
}
@ -337,13 +340,13 @@ func (p *PostgreSQL) defaultRevokeUser(username string) error {
// many permissions as possible right now
var lastStmtError error
for _, query := range revocationStmts {
stmt, err := db.Prepare(query)
stmt, err := db.PrepareContext(ctx, query)
if err != nil {
lastStmtError = err
continue
}
defer stmt.Close()
_, err = stmt.Exec()
_, err = stmt.ExecContext(ctx)
if err != nil {
lastStmtError = err
}
@ -358,13 +361,13 @@ func (p *PostgreSQL) defaultRevokeUser(username string) error {
}
// Drop this user
stmt, err = db.Prepare(fmt.Sprintf(
stmt, err = db.PrepareContext(ctx, fmt.Sprintf(
`DROP ROLE IF EXISTS %s;`, pq.QuoteIdentifier(username)))
if err != nil {
return err
}
defer stmt.Close()
if _, err := stmt.Exec(); err != nil {
if _, err := stmt.ExecContext(ctx); err != nil {
return err
}

View File

@ -1,6 +1,7 @@
package postgresql
import (
"context"
"database/sql"
"fmt"
"os"
@ -72,7 +73,7 @@ func TestPostgreSQL_Initialize(t *testing.T) {
connProducer := db.ConnectionProducer.(*connutil.SQLConnectionProducer)
err := db.Initialize(connectionDetails, true)
err := db.Initialize(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
@ -92,7 +93,7 @@ func TestPostgreSQL_Initialize(t *testing.T) {
"max_open_connections": "5",
}
err = db.Initialize(connectionDetails, true)
err = db.Initialize(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
@ -109,7 +110,7 @@ func TestPostgreSQL_CreateUser(t *testing.T) {
dbRaw, _ := New()
db := dbRaw.(*PostgreSQL)
err := db.Initialize(connectionDetails, true)
err := db.Initialize(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
@ -120,7 +121,7 @@ func TestPostgreSQL_CreateUser(t *testing.T) {
}
// Test with no configured Creation Statememt
_, _, err = db.CreateUser(dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute))
_, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute))
if err == nil {
t.Fatal("Expected error when no creation statement is provided")
}
@ -129,7 +130,7 @@ func TestPostgreSQL_CreateUser(t *testing.T) {
CreationStatements: testPostgresRole,
}
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute))
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
if err != nil {
t.Fatalf("err: %s", err)
}
@ -139,7 +140,7 @@ func TestPostgreSQL_CreateUser(t *testing.T) {
}
statements.CreationStatements = testPostgresReadOnlyRole
username, password, err = db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute))
username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
if err != nil {
t.Fatalf("err: %s", err)
}
@ -162,7 +163,7 @@ func TestPostgreSQL_RenewUser(t *testing.T) {
dbRaw, _ := New()
db := dbRaw.(*PostgreSQL)
err := db.Initialize(connectionDetails, true)
err := db.Initialize(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
@ -176,7 +177,7 @@ func TestPostgreSQL_RenewUser(t *testing.T) {
RoleName: "test",
}
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(2*time.Second))
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second))
if err != nil {
t.Fatalf("err: %s", err)
}
@ -185,7 +186,7 @@ func TestPostgreSQL_RenewUser(t *testing.T) {
t.Fatalf("Could not connect with new credentials: %s", err)
}
err = db.RenewUser(statements, username, time.Now().Add(time.Minute))
err = db.RenewUser(context.Background(), statements, username, time.Now().Add(time.Minute))
if err != nil {
t.Fatalf("err: %s", err)
}
@ -197,7 +198,7 @@ func TestPostgreSQL_RenewUser(t *testing.T) {
t.Fatalf("Could not connect with new credentials: %s", err)
}
statements.RenewStatements = defaultPostgresRenewSQL
username, password, err = db.CreateUser(statements, usernameConfig, time.Now().Add(2*time.Second))
username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second))
if err != nil {
t.Fatalf("err: %s", err)
}
@ -206,7 +207,7 @@ func TestPostgreSQL_RenewUser(t *testing.T) {
t.Fatalf("Could not connect with new credentials: %s", err)
}
err = db.RenewUser(statements, username, time.Now().Add(time.Minute))
err = db.RenewUser(context.Background(), statements, username, time.Now().Add(time.Minute))
if err != nil {
t.Fatalf("err: %s", err)
}
@ -230,7 +231,7 @@ func TestPostgreSQL_RevokeUser(t *testing.T) {
dbRaw, _ := New()
db := dbRaw.(*PostgreSQL)
err := db.Initialize(connectionDetails, true)
err := db.Initialize(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
@ -244,7 +245,7 @@ func TestPostgreSQL_RevokeUser(t *testing.T) {
RoleName: "test",
}
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(2*time.Second))
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second))
if err != nil {
t.Fatalf("err: %s", err)
}
@ -254,7 +255,7 @@ func TestPostgreSQL_RevokeUser(t *testing.T) {
}
// Test default revoke statememts
err = db.RevokeUser(statements, username)
err = db.RevokeUser(context.Background(), statements, username)
if err != nil {
t.Fatalf("err: %s", err)
}
@ -263,7 +264,7 @@ func TestPostgreSQL_RevokeUser(t *testing.T) {
t.Fatal("Credentials were not revoked")
}
username, password, err = db.CreateUser(statements, usernameConfig, time.Now().Add(2*time.Second))
username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second))
if err != nil {
t.Fatalf("err: %s", err)
}
@ -274,7 +275,7 @@ func TestPostgreSQL_RevokeUser(t *testing.T) {
// Test custom revoke statements
statements.RevocationStatements = defaultPostgresRevocationSQL
err = db.RevokeUser(statements, username)
err = db.RevokeUser(context.Background(), statements, username)
if err != nil {
t.Fatalf("err: %s", err)
}

View File

@ -1,6 +1,7 @@
package connutil
import (
"context"
"errors"
"sync"
)
@ -14,8 +15,8 @@ var (
// connections and is used in all the builtin database types.
type ConnectionProducer interface {
Close() error
Initialize(map[string]interface{}, bool) error
Connection() (interface{}, error)
Initialize(context.Context, map[string]interface{}, bool) error
Connection(context.Context) (interface{}, error)
sync.Locker
}

View File

@ -1,6 +1,7 @@
package connutil
import (
"context"
"database/sql"
"fmt"
"strings"
@ -25,7 +26,7 @@ type SQLConnectionProducer struct {
sync.Mutex
}
func (c *SQLConnectionProducer) Initialize(conf map[string]interface{}, verifyConnection bool) error {
func (c *SQLConnectionProducer) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error {
c.Lock()
defer c.Unlock()
@ -62,11 +63,11 @@ func (c *SQLConnectionProducer) Initialize(conf map[string]interface{}, verifyCo
c.Initialized = true
if verifyConnection {
if _, err := c.Connection(); err != nil {
if _, err := c.Connection(ctx); err != nil {
return fmt.Errorf("error verifying connection: %s", err)
}
if err := c.db.Ping(); err != nil {
if err := c.db.PingContext(ctx); err != nil {
return fmt.Errorf("error verifying connection: %s", err)
}
}
@ -74,14 +75,14 @@ func (c *SQLConnectionProducer) Initialize(conf map[string]interface{}, verifyCo
return nil
}
func (c *SQLConnectionProducer) Connection() (interface{}, error) {
func (c *SQLConnectionProducer) Connection(ctx context.Context) (interface{}, error) {
if !c.Initialized {
return nil, ErrNotInitialized
}
// If we already have a DB, test it and return
if c.db != nil {
if err := c.db.Ping(); err == nil {
if err := c.db.PingContext(ctx); err == nil {
return c.db, nil
}
// If the ping was unsuccessful, close it and ignore errors as we'll be

View File

@ -5,7 +5,6 @@ import (
"fmt"
"strings"
"github.com/hashicorp/errwrap"
"github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/helper/consts"
"github.com/hashicorp/vault/helper/jsonutil"
@ -460,10 +459,11 @@ func (c *Core) setupCredentials() error {
backend, err = c.newCredentialBackend(entry.Type, sysView, view, conf)
if err != nil {
c.logger.Error("core: failed to create credential entry", "path", entry.Path, "error", err)
if errwrap.Contains(err, ErrPluginNotFound.Error()) && entry.Type == "plugin" {
// If we encounter an error instantiating the backend due to it being missing from the catalog,
// skip backend initialization but register the entry to the mount table to preserve storage
// and path.
if entry.Type == "plugin" {
// If we encounter an error instantiating the backend due to an error,
// skip backend initialization but register the entry to the mount table
// to preserve storage and path.
c.logger.Warn("core: skipping plugin-based credential entry", "path", entry.Path)
goto ROUTER_MOUNT
}
return errLoadAuthFailed

View File

@ -750,6 +750,31 @@ func (m *ExpirationManager) RenewToken(req *logical.Request, source string, toke
}, nil
}
sysView := m.router.MatchingSystemView(le.Path)
if sysView == nil {
return nil, fmt.Errorf("expiration: unable to retrieve system view from router")
}
retResp := &logical.Response{}
switch {
case resp.Auth.Period > time.Duration(0):
// If it resp.Period is non-zero, use that as the TTL and override backend's
// call on TTL modification, such as a TTL value determined by
// framework.LeaseExtend call against the request. Also, cap period value to
// the sys/mount max value.
if resp.Auth.Period > sysView.MaxLeaseTTL() {
retResp.AddWarning(fmt.Sprintf("Period of %d seconds is greater than current mount/system default of %d seconds, value will be truncated.", resp.Auth.TTL, sysView.MaxLeaseTTL()))
resp.Auth.Period = sysView.MaxLeaseTTL()
}
resp.Auth.TTL = resp.Auth.Period
case resp.Auth.TTL > time.Duration(0):
// Cap TTL value to the sys/mount max value
if resp.Auth.TTL > sysView.MaxLeaseTTL() {
retResp.AddWarning(fmt.Sprintf("TTL of %d seconds is greater than current mount/system default of %d seconds, value will be truncated.", resp.Auth.TTL, sysView.MaxLeaseTTL()))
resp.Auth.TTL = sysView.MaxLeaseTTL()
}
}
// Attach the ClientToken
resp.Auth.ClientToken = token
resp.Auth.Increment = 0
@ -764,9 +789,9 @@ func (m *ExpirationManager) RenewToken(req *logical.Request, source string, toke
// Update the expiration time
m.updatePending(le, resp.Auth.LeaseTotal())
return &logical.Response{
Auth: resp.Auth,
}, nil
retResp.Auth = resp.Auth
return retResp, nil
}
// Register is used to take a request and response with an associated
@ -866,6 +891,12 @@ func (m *ExpirationManager) RegisterAuth(source string, auth *logical.Auth) erro
return err
}
// If it resp.Period is non-zero, override the TTL value determined
// by the backend.
if auth.Period > time.Duration(0) {
auth.TTL = auth.Period
}
// Create a lease entry
le := leaseEntry{
LeaseID: path.Join(source, saltedID),
@ -1017,8 +1048,7 @@ func (m *ExpirationManager) revokeEntry(le *leaseEntry) error {
}
// Handle standard revocation via backends
resp, err := m.router.Route(logical.RevokeRequest(
le.Path, le.Secret, le.Data))
resp, err := m.router.Route(logical.RevokeRequest(le.Path, le.Secret, le.Data))
if err != nil || (resp != nil && resp.IsError()) {
return fmt.Errorf("failed to revoke entry: resp:%#v err:%s", resp, err)
}

View File

@ -2,7 +2,9 @@ package vault_test
import (
"fmt"
"io/ioutil"
"os"
"path/filepath"
"testing"
"time"
@ -178,14 +180,13 @@ func testPlugin_CatalogRemoved(t *testing.T, btype logical.BackendType, testMoun
}
if testMount {
// Add plugin back to the catalog
vault.TestAddTestPlugin(t, core.Core, "mock-plugin", "TestBackend_PluginMainLogical")
// Mount the plugin at the same path after plugin is re-added to the catalog
// and expect an error due to existing path.
var err error
switch btype {
case logical.TypeLogical:
// Add plugin back to the catalog
vault.TestAddTestPlugin(t, core.Core, "mock-plugin", "TestBackend_PluginMainLogical")
_, err = core.Client.Logical().Write("sys/mounts/mock-0", map[string]interface{}{
"type": "plugin",
"config": map[string]interface{}{
@ -193,6 +194,8 @@ func testPlugin_CatalogRemoved(t *testing.T, btype logical.BackendType, testMoun
},
})
case logical.TypeCredential:
// Add plugin back to the catalog
vault.TestAddTestPlugin(t, core.Core, "mock-plugin", "TestBackend_PluginMainCredentials")
_, err = core.Client.Logical().Write("sys/auth/mock-0", map[string]interface{}{
"type": "plugin",
"plugin_name": "mock-plugin",
@ -204,6 +207,129 @@ func testPlugin_CatalogRemoved(t *testing.T, btype logical.BackendType, testMoun
}
}
func TestSystemBackend_Plugin_continueOnError(t *testing.T) {
t.Run("secret", func(t *testing.T) {
t.Run("sha256_mismatch", func(t *testing.T) {
testPlugin_continueOnError(t, logical.TypeLogical, true)
})
t.Run("missing_plugin", func(t *testing.T) {
testPlugin_continueOnError(t, logical.TypeLogical, false)
})
})
t.Run("auth", func(t *testing.T) {
t.Run("sha256_mismatch", func(t *testing.T) {
testPlugin_continueOnError(t, logical.TypeCredential, true)
})
t.Run("missing_plugin", func(t *testing.T) {
testPlugin_continueOnError(t, logical.TypeCredential, false)
})
})
}
func testPlugin_continueOnError(t *testing.T, btype logical.BackendType, mismatch bool) {
cluster := testSystemBackendMock(t, 1, 1, btype)
defer cluster.Cleanup()
core := cluster.Cores[0]
// Get the registered plugin
req := logical.TestRequest(t, logical.ReadOperation, "sys/plugins/catalog/mock-plugin")
req.ClientToken = core.Client.Token()
resp, err := core.HandleRequest(req)
if err != nil || resp == nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
command, ok := resp.Data["command"].(string)
if !ok || command == "" {
t.Fatal("invalid command")
}
// Trigger a sha256 mistmatch or missing plugin error
if mismatch {
req = logical.TestRequest(t, logical.UpdateOperation, "sys/plugins/catalog/mock-plugin")
req.Data = map[string]interface{}{
"sha256": "d17bd7334758e53e6fbab15745d2520765c06e296f2ce8e25b7919effa0ac216",
"command": filepath.Base(command),
}
req.ClientToken = core.Client.Token()
resp, err = core.HandleRequest(req)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
} else {
err := os.Remove(filepath.Join(cluster.TempDir, filepath.Base(command)))
if err != nil {
t.Fatal(err)
}
}
// Seal the cluster
cluster.EnsureCoresSealed(t)
// Unseal the cluster
barrierKeys := cluster.BarrierKeys
for _, core := range cluster.Cores {
for _, key := range barrierKeys {
_, err := core.Unseal(vault.TestKeyCopy(key))
if err != nil {
t.Fatal(err)
}
}
sealed, err := core.Sealed()
if err != nil {
t.Fatalf("err checking seal status: %s", err)
}
if sealed {
t.Fatal("should not be sealed")
}
// Wait for active so post-unseal takes place
// If it fails, it means unseal process failed
vault.TestWaitActive(t, core.Core)
}
// Re-add the plugin to the catalog
switch btype {
case logical.TypeLogical:
vault.TestAddTestPluginTempDir(t, core.Core, "mock-plugin", "TestBackend_PluginMainLogical", cluster.TempDir)
case logical.TypeCredential:
vault.TestAddTestPluginTempDir(t, core.Core, "mock-plugin", "TestBackend_PluginMainCredentials", cluster.TempDir)
}
// Reload the plugin
req = logical.TestRequest(t, logical.UpdateOperation, "sys/plugins/reload/backend")
req.Data = map[string]interface{}{
"plugin": "mock-plugin",
}
req.ClientToken = core.Client.Token()
resp, err = core.HandleRequest(req)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
// Make a request to lazy load the plugin
var reqPath string
switch btype {
case logical.TypeLogical:
reqPath = "mock-0/internal"
case logical.TypeCredential:
reqPath = "auth/mock-0/internal"
}
req = logical.TestRequest(t, logical.ReadOperation, reqPath)
req.ClientToken = core.Client.Token()
resp, err = core.HandleRequest(req)
if err != nil {
t.Fatalf("err: %v", err)
}
if resp == nil {
t.Fatalf("bad: response should not be nil")
}
}
func TestSystemBackend_Plugin_autoReload(t *testing.T) {
cluster := testSystemBackendMock(t, 1, 1, logical.TypeLogical)
defer cluster.Cleanup()
@ -332,7 +458,10 @@ func testSystemBackend_PluginReload(t *testing.T, reqData map[string]interface{}
}
// testSystemBackendMock returns a systemBackend with the desired number
// of mounted mock plugin backends
// of mounted mock plugin backends. numMounts alternates between different
// ways of providing the plugin_name.
//
// The mounts are mounted at sys/mounts/mock-[numMounts] or sys/auth/mock-[numMounts]
func testSystemBackendMock(t *testing.T, numCores, numMounts int, backendType logical.BackendType) *vault.TestCluster {
coreConfig := &vault.CoreConfig{
LogicalBackends: map[string]logical.Factory{
@ -343,10 +472,17 @@ func testSystemBackendMock(t *testing.T, numCores, numMounts int, backendType lo
},
}
// Create a tempdir, cluster.Cleanup will clean up this directory
tempDir, err := ioutil.TempDir("", "vault-test-cluster")
if err != nil {
t.Fatal(err)
}
cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{
HandlerFunc: vaulthttp.Handler,
KeepStandbysSealed: true,
NumCores: numCores,
TempDir: tempDir,
})
cluster.Start()
@ -358,7 +494,7 @@ func testSystemBackendMock(t *testing.T, numCores, numMounts int, backendType lo
switch backendType {
case logical.TypeLogical:
vault.TestAddTestPlugin(t, core.Core, "mock-plugin", "TestBackend_PluginMainLogical")
vault.TestAddTestPluginTempDir(t, core.Core, "mock-plugin", "TestBackend_PluginMainLogical", tempDir)
for i := 0; i < numMounts; i++ {
// Alternate input styles for plugin_name on every other mount
options := map[string]interface{}{
@ -380,7 +516,7 @@ func testSystemBackendMock(t *testing.T, numCores, numMounts int, backendType lo
}
}
case logical.TypeCredential:
vault.TestAddTestPlugin(t, core.Core, "mock-plugin", "TestBackend_PluginMainCredentials")
vault.TestAddTestPluginTempDir(t, core.Core, "mock-plugin", "TestBackend_PluginMainCredentials", tempDir)
for i := 0; i < numMounts; i++ {
// Alternate input styles for plugin_name on every other mount
options := map[string]interface{}{

View File

@ -7,7 +7,6 @@ import (
"strings"
"time"
"github.com/hashicorp/errwrap"
"github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/helper/consts"
"github.com/hashicorp/vault/helper/jsonutil"
@ -753,10 +752,11 @@ func (c *Core) setupMounts() error {
backend, err = c.newLogicalBackend(entry.Type, sysView, view, conf)
if err != nil {
c.logger.Error("core: failed to create mount entry", "path", entry.Path, "error", err)
if errwrap.Contains(err, ErrPluginNotFound.Error()) && entry.Type == "plugin" {
// If we encounter an error instantiating the backend due to it being missing from the catalog,
// skip backend initialization but register the entry to the mount table to preserve storage
// and path.
if entry.Type == "plugin" {
// If we encounter an error instantiating the backend due to an error,
// skip backend initialization but register the entry to the mount table
// to preserve storage and path.
c.logger.Warn("core: skipping plugin-based mount entry", "path", entry.Path)
goto ROUTER_MOUNT
}
return errLoadMountsFailed

View File

@ -79,15 +79,23 @@ func (c *Core) reloadMatchingPlugin(pluginName string) error {
func (c *Core) reloadPluginCommon(entry *MountEntry, isAuth bool) error {
path := entry.Path
if isAuth {
path = credentialRoutePrefix + path
}
// Fast-path out if the backend doesn't exist
raw, ok := c.router.root.Get(path)
if !ok {
return nil
}
// Call backend's Cleanup routine
re := raw.(*routeEntry)
re.backend.Cleanup()
// Only call Cleanup if backend is initialized
if re.backend != nil {
// Call backend's Cleanup routine
re.backend.Cleanup()
}
view := re.storageView

View File

@ -477,14 +477,27 @@ func (c *Core) handleLoginRequest(req *logical.Request) (retResp *logical.Respon
return nil, nil, ErrInternalError
}
// Set the default lease if not provided
if auth.TTL == 0 {
auth.TTL = sysView.DefaultLeaseTTL()
}
// Start off with the sys default value, and update according to period/TTL
// from resp.Auth
tokenTTL := sysView.DefaultLeaseTTL()
// Limit the lease duration
if auth.TTL > sysView.MaxLeaseTTL() {
auth.TTL = sysView.MaxLeaseTTL()
switch {
case auth.Period > time.Duration(0):
// Cap the period value to the sys max_ttl value. The auth backend should
// have checked for it on its login path, but we check here again for
// sanity.
if auth.Period > sysView.MaxLeaseTTL() {
auth.Period = sysView.MaxLeaseTTL()
}
tokenTTL = auth.Period
case auth.TTL > time.Duration(0):
// Cap the TTL value. The auth backend should have checked for it on its
// login path (e.g. a call to b.SanitizeTTL), but we check here again for
// sanity.
if auth.TTL > sysView.MaxLeaseTTL() {
auth.TTL = sysView.MaxLeaseTTL()
}
tokenTTL = auth.TTL
}
// Generate a token
@ -494,7 +507,7 @@ func (c *Core) handleLoginRequest(req *logical.Request) (retResp *logical.Respon
Meta: auth.Metadata,
DisplayName: auth.DisplayName,
CreationTime: time.Now().Unix(),
TTL: auth.TTL,
TTL: tokenTTL,
NumUses: auth.NumUses,
EntityID: auth.EntityID,
}
@ -513,10 +526,11 @@ func (c *Core) handleLoginRequest(req *logical.Request) (retResp *logical.Respon
return nil, auth, ErrInternalError
}
// Populate the client token and accessor
// Populate the client token, accessor, and TTL
auth.ClientToken = te.ID
auth.Accessor = te.Accessor
auth.Policies = te.Policies
auth.TTL = te.TTL
// Register with the expiration manager
if err := c.expiration.RegisterAuth(te.Path, auth); err != nil {

View File

@ -379,6 +379,8 @@ func TestDynamicSystemView(c *Core) *dynamicSystemView {
return &dynamicSystemView{c, me}
}
// TestAddTestPlugin registers the testFunc as part of the plugin command to the
// plugin catalog.
func TestAddTestPlugin(t testing.T, c *Core, name, testFunc string) {
file, err := os.Open(os.Args[0])
if err != nil {
@ -413,11 +415,74 @@ func TestAddTestPlugin(t testing.T, c *Core, name, testFunc string) {
}
}
// TestAddTestPluginTempDir registers the testFunc as part of the plugin command to the
// plugin catalog. It uses tmpDir as the plugin directory.
func TestAddTestPluginTempDir(t testing.T, c *Core, name, testFunc, tempDir string) {
file, err := os.Open(os.Args[0])
if err != nil {
t.Fatal(err)
}
defer file.Close()
fi, err := file.Stat()
if err != nil {
t.Fatal(err)
}
// Copy over the file to the temp dir
dst := filepath.Join(tempDir, filepath.Base(os.Args[0]))
out, err := os.OpenFile(dst, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, fi.Mode())
if err != nil {
t.Fatal(err)
}
defer out.Close()
if _, err = io.Copy(out, file); err != nil {
t.Fatal(err)
}
err = out.Sync()
if err != nil {
t.Fatal(err)
}
// Determine plugin directory full path
fullPath, err := filepath.EvalSymlinks(tempDir)
if err != nil {
t.Fatal(err)
}
reader, err := os.Open(filepath.Join(fullPath, filepath.Base(os.Args[0])))
if err != nil {
t.Fatal(err)
}
defer reader.Close()
// Find out the sha256
hash := sha256.New()
_, err = io.Copy(hash, reader)
if err != nil {
t.Fatal(err)
}
sum := hash.Sum(nil)
// Set core's plugin directory and plugin catalog directory
c.pluginDirectory = fullPath
c.pluginCatalog.directory = fullPath
command := fmt.Sprintf("%s --test.run=%s", filepath.Base(os.Args[0]), testFunc)
err = c.pluginCatalog.Set(name, command, sum)
if err != nil {
t.Fatal(err)
}
}
var testLogicalBackends = map[string]logical.Factory{}
var testCredentialBackends = map[string]logical.Factory{}
// Starts the test server which responds to SSH authentication.
// Used to test the SSH secret backend.
// StartSSHHostTestServer starts the test server which responds to SSH
// authentication. Used to test the SSH secret backend.
func StartSSHHostTestServer() (string, error) {
pubKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(testSharedPublicKey))
if err != nil {
@ -760,6 +825,7 @@ func (c *TestCluster) ensureCoresSealed() error {
return nil
}
// UnsealWithStoredKeys uses stored keys to unseal the test cluster cores
func (c *TestCluster) UnsealWithStoredKeys(t testing.T) error {
for _, core := range c.Cores {
if err := core.UnsealWithStoredKeys(); err != nil {

View File

@ -2312,7 +2312,7 @@ func TestTokenStore_RolePeriod(t *testing.T) {
req := logical.TestRequest(t, logical.UpdateOperation, "auth/token/roles/test")
req.ClientToken = root
req.Data = map[string]interface{}{
"period": 300,
"period": 5,
}
resp, err := core.HandleRequest(req)
@ -2425,8 +2425,8 @@ func TestTokenStore_RolePeriod(t *testing.T) {
t.Fatalf("err: %v", err)
}
ttl := resp.Data["ttl"].(int64)
if ttl < 299 {
t.Fatalf("TTL too small (expected %d, got %d", 299, ttl)
if ttl > 5 {
t.Fatalf("TTL too large (expected %d, got %d", 5, ttl)
}
// Let the TTL go down a bit to 3 seconds
@ -2449,8 +2449,8 @@ func TestTokenStore_RolePeriod(t *testing.T) {
t.Fatalf("err: %v", err)
}
ttl = resp.Data["ttl"].(int64)
if ttl < 299 {
t.Fatalf("TTL too small (expected %d, got %d", 299, ttl)
if ttl > 5 {
t.Fatalf("TTL too large (expected %d, got %d", 5, ttl)
}
}
}
@ -2677,7 +2677,7 @@ func TestTokenStore_Periodic(t *testing.T) {
req := logical.TestRequest(t, logical.UpdateOperation, "auth/token/roles/test")
req.ClientToken = root
req.Data = map[string]interface{}{
"period": 300,
"period": 5,
}
resp, err := core.HandleRequest(req)
@ -2715,8 +2715,8 @@ func TestTokenStore_Periodic(t *testing.T) {
t.Fatalf("err: %v", err)
}
ttl := resp.Data["ttl"].(int64)
if ttl < 299 {
t.Fatalf("TTL too small (expected %d, got %d)", 299, ttl)
if ttl > 5 {
t.Fatalf("TTL too large (expected %d, got %d)", 5, ttl)
}
// Let the TTL go down a bit
@ -2739,8 +2739,8 @@ func TestTokenStore_Periodic(t *testing.T) {
t.Fatalf("err: %v", err)
}
ttl = resp.Data["ttl"].(int64)
if ttl < 299 {
t.Fatalf("TTL too small (expected %d, got %d)", 299, ttl)
if ttl > 5 {
t.Fatalf("TTL too large (expected %d, got %d)", 5, ttl)
}
}
@ -2750,8 +2750,8 @@ func TestTokenStore_Periodic(t *testing.T) {
req.Operation = logical.UpdateOperation
req.Path = "auth/token/create"
req.Data = map[string]interface{}{
"period": 300,
"explicit_max_ttl": 150,
"period": 5,
"explicit_max_ttl": 4,
}
resp, err = core.HandleRequest(req)
if err != nil {
@ -2775,8 +2775,8 @@ func TestTokenStore_Periodic(t *testing.T) {
t.Fatalf("err: %v", err)
}
ttl := resp.Data["ttl"].(int64)
if ttl < 149 || ttl > 150 {
t.Fatalf("TTL bad (expected %d, got %d)", 149, ttl)
if ttl < 3 || ttl > 4 {
t.Fatalf("TTL bad (expected %d, got %d)", 3, ttl)
}
// Let the TTL go down a bit
@ -2799,8 +2799,8 @@ func TestTokenStore_Periodic(t *testing.T) {
t.Fatalf("err: %v", err)
}
ttl = resp.Data["ttl"].(int64)
if ttl < 140 || ttl > 150 {
t.Fatalf("TTL bad (expected around %d, got %d)", 145, ttl)
if ttl > 2 {
t.Fatalf("TTL bad (expected less than %d, got %d)", 2, ttl)
}
}
@ -2812,7 +2812,7 @@ func TestTokenStore_Periodic(t *testing.T) {
req.Operation = logical.UpdateOperation
req.Path = "auth/token/create/test"
req.Data = map[string]interface{}{
"period": 150,
"period": 5,
}
resp, err = core.HandleRequest(req)
if err != nil {
@ -2836,8 +2836,8 @@ func TestTokenStore_Periodic(t *testing.T) {
t.Fatalf("err: %v", err)
}
ttl := resp.Data["ttl"].(int64)
if ttl < 149 || ttl > 150 {
t.Fatalf("TTL bad (expected %d, got %d)", 149, ttl)
if ttl < 4 || ttl > 5 {
t.Fatalf("TTL bad (expected %d, got %d)", 4, ttl)
}
// Let the TTL go down a bit
@ -2860,8 +2860,8 @@ func TestTokenStore_Periodic(t *testing.T) {
t.Fatalf("err: %v", err)
}
ttl = resp.Data["ttl"].(int64)
if ttl < 149 {
t.Fatalf("TTL bad (expected %d, got %d)", 149, ttl)
if ttl > 5 {
t.Fatalf("TTL bad (expected less than %d, got %d)", 5, ttl)
}
}
@ -2869,18 +2869,23 @@ func TestTokenStore_Periodic(t *testing.T) {
{
req.Path = "auth/token/roles/test"
req.ClientToken = root
req.Operation = logical.UpdateOperation
req.Data = map[string]interface{}{
"period": 300,
"explicit_max_ttl": 150,
"period": 5,
"explicit_max_ttl": 4,
}
resp, err := core.HandleRequest(req)
if err != nil {
t.Fatalf("err: %v %v", err, resp)
}
if resp != nil {
t.Fatalf("expected a nil response")
}
req.ClientToken = root
req.Operation = logical.UpdateOperation
req.Path = "auth/token/create/test"
req.Data = map[string]interface{}{
"period": 150,
"explicit_max_ttl": 130,
}
resp, err = core.HandleRequest(req)
if err != nil {
t.Fatalf("err: %v %v", err, resp)
@ -2903,12 +2908,12 @@ func TestTokenStore_Periodic(t *testing.T) {
t.Fatalf("err: %v", err)
}
ttl := resp.Data["ttl"].(int64)
if ttl < 129 || ttl > 130 {
t.Fatalf("TTL bad (expected %d, got %d)", 129, ttl)
if ttl < 3 || ttl > 4 {
t.Fatalf("TTL bad (expected %d, got %d)", 3, ttl)
}
// Let the TTL go down a bit
time.Sleep(4 * time.Second)
time.Sleep(2 * time.Second)
req.Operation = logical.UpdateOperation
req.Path = "auth/token/renew-self"
@ -2927,8 +2932,8 @@ func TestTokenStore_Periodic(t *testing.T) {
t.Fatalf("err: %v", err)
}
ttl = resp.Data["ttl"].(int64)
if ttl > 127 {
t.Fatalf("TTL bad (expected < %d, got %d)", 128, ttl)
if ttl > 2 {
t.Fatalf("TTL bad (expected less than %d, got %d)", 2, ttl)
}
}
}