Merge remote-tracking branch 'oss/master' into f-nomad

* oss/master:
  Defer reader.Close that is used to determine sha256
  changelog++
  Avoid unseal failure if plugin backends fail to setup during postUnseal (#3686)
  Add logic for using Auth.Period when handling auth login/renew requests (#3677)
  plugins/database: use context with plugins that use database/sql package (#3691)
  changelog++
  Fix plaintext backup in transit (#3692)
  Database gRPC plugins (#3666)
This commit is contained in:
Chris Hoffman 2017-12-15 17:05:42 -05:00
commit db0006ef65
47 changed files with 1824 additions and 526 deletions

View file

@ -33,6 +33,8 @@ FEATURES:
operation that can export a given key, including all key versions and operation that can export a given key, including all key versions and
configuration, as well as a restore operation allowing import into another configuration, as well as a restore operation allowing import into another
Vault. Vault.
* **gRPC Database Plugins**: Database plugins now use gRPC for transport,
allowing them to be written in other languages.
IMPROVEMENTS: IMPROVEMENTS:
@ -46,6 +48,10 @@ IMPROVEMENTS:
during database configuration. This establishes a session-wide [write during database configuration. This establishes a session-wide [write
concern](https://docs.mongodb.com/manual/reference/write-concern/) for the concern](https://docs.mongodb.com/manual/reference/write-concern/) for the
lifecycle of the mount [GH-3646] lifecycle of the mount [GH-3646]
* mfa/okta: Filter a given email address as a login filter, allowing operation
when login email and account email are different
* plugins: Make Vault more resilient when unsealing when plugins are
unavailable [GH-3686]
* secret/pki: `allowed_domains` and `key_usage` can now be specified * secret/pki: `allowed_domains` and `key_usage` can now be specified
as a comma-separated string or an array of strings [GH-3642] as a comma-separated string or an array of strings [GH-3642]
* secret/ssh: Allow 4096-bit keys to be used in dynamic key method [GH-3593] * secret/ssh: Allow 4096-bit keys to be used in dynamic key method [GH-3593]
@ -58,8 +64,12 @@ BUG FIXES:
* auth/cert: Return `allowed_names` on role read [GH-3654] * auth/cert: Return `allowed_names` on role read [GH-3654]
* auth/ldap: Fix incorrect control information being sent [GH-3402] [GH-3496] * auth/ldap: Fix incorrect control information being sent [GH-3402] [GH-3496]
[GH-3625] [GH-3656] [GH-3625] [GH-3656]
* core: Fix seal status reporting when using an autoseal
* core: Add creation path to wrap info for a control group token
* core: Fix potential panic that could occur using plugins when a node * core: Fix potential panic that could occur using plugins when a node
transitioned from active to standby [GH-3638] transitioned from active to standby [GH-3638]
* core: Fix memory ballooning when a connection would connect to the cluster
port and then go away -- redux! [GH-3680]
* core: Replace recursive token revocation logic with depth-first logic, which * core: Replace recursive token revocation logic with depth-first logic, which
can avoid hitting stack depth limits in extreme cases [GH-2348] can avoid hitting stack depth limits in extreme cases [GH-2348]
* core/pkcs11 (enterprise): Fix panic when PKCS#11 library is not readable * core/pkcs11 (enterprise): Fix panic when PKCS#11 library is not readable

View file

@ -84,6 +84,7 @@ proto:
protoc -I helper/forwarding -I vault -I ../../.. helper/forwarding/types.proto --go_out=plugins=grpc:helper/forwarding protoc -I helper/forwarding -I vault -I ../../.. helper/forwarding/types.proto --go_out=plugins=grpc:helper/forwarding
protoc -I physical physical/types.proto --go_out=plugins=grpc:physical protoc -I physical physical/types.proto --go_out=plugins=grpc:physical
protoc -I helper/identity -I ../../.. helper/identity/types.proto --go_out=plugins=grpc:helper/identity protoc -I helper/identity -I ../../.. helper/identity/types.proto --go_out=plugins=grpc:helper/identity
protoc builtin/logical/database/dbplugin/*.proto --go_out=plugins=grpc:.
sed -i -e 's/Idp/IDP/' -e 's/Url/URL/' -e 's/Id/ID/' -e 's/EntityId/EntityID/' -e 's/Api/API/' -e 's/Qr/QR/' -e 's/protobuf:"/sentinel:"" protobuf:"/' helper/identity/types.pb.go helper/storagepacker/types.pb.go sed -i -e 's/Idp/IDP/' -e 's/Url/URL/' -e 's/Id/ID/' -e 's/EntityId/EntityID/' -e 's/Api/API/' -e 's/Qr/QR/' -e 's/protobuf:"/sentinel:"" protobuf:"/' helper/identity/types.pb.go helper/storagepacker/types.pb.go
sed -i -e 's/Iv/IV/' -e 's/Hmac/HMAC/' physical/types.pb.go sed -i -e 's/Iv/IV/' -e 's/Hmac/HMAC/' physical/types.pb.go

View file

@ -141,7 +141,7 @@ func testAccStepMapUserIdCidr(t *testing.T, cidr string) logicaltest.TestStep {
func testAccLogin(t *testing.T, display string) logicaltest.TestStep { func testAccLogin(t *testing.T, display string) logicaltest.TestStep {
checkTTL := func(resp *logical.Response) error { checkTTL := func(resp *logical.Response) error {
if resp.Auth.LeaseOptions.TTL.String() != "768h0m0s" { if resp.Auth.LeaseOptions.TTL.String() != "768h0m0s" {
return fmt.Errorf("invalid TTL") return fmt.Errorf("invalid TTL: got %s", resp.Auth.LeaseOptions.TTL)
} }
return nil return nil
} }
@ -165,7 +165,7 @@ func testAccLogin(t *testing.T, display string) logicaltest.TestStep {
func testAccLoginAppIDInPath(t *testing.T, display string) logicaltest.TestStep { func testAccLoginAppIDInPath(t *testing.T, display string) logicaltest.TestStep {
checkTTL := func(resp *logical.Response) error { checkTTL := func(resp *logical.Response) error {
if resp.Auth.LeaseOptions.TTL.String() != "768h0m0s" { if resp.Auth.LeaseOptions.TTL.String() != "768h0m0s" {
return fmt.Errorf("invalid TTL") return fmt.Errorf("invalid TTL: got %s", resp.Auth.LeaseOptions.TTL)
} }
return nil return nil
} }

View file

@ -3,7 +3,6 @@ package approle
import ( import (
"fmt" "fmt"
"strings" "strings"
"time"
"github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework" "github.com/hashicorp/vault/logical/framework"
@ -68,20 +67,13 @@ func (b *backend) pathLoginUpdate(req *logical.Request, data *framework.FieldDat
Policies: role.Policies, Policies: role.Policies,
LeaseOptions: logical.LeaseOptions{ LeaseOptions: logical.LeaseOptions{
Renewable: true, Renewable: true,
TTL: role.TokenTTL,
}, },
Alias: &logical.Alias{ Alias: &logical.Alias{
Name: role.RoleID, Name: role.RoleID,
}, },
} }
// If 'Period' is set, use the value of 'Period' as the TTL.
// Otherwise, set the normal TokenTTL.
if role.Period > time.Duration(0) {
auth.TTL = role.Period
} else {
auth.TTL = role.TokenTTL
}
return &logical.Response{ return &logical.Response{
Auth: auth, Auth: auth,
}, nil }, nil
@ -107,16 +99,12 @@ func (b *backend) pathLoginRenew(req *logical.Request, data *framework.FieldData
return nil, fmt.Errorf("role %s does not exist during renewal", roleName) return nil, fmt.Errorf("role %s does not exist during renewal", roleName)
} }
// If 'Period' is set on the Role, the token should never expire. resp, err := framework.LeaseExtend(role.TokenTTL, role.TokenMaxTTL, b.System())(req, data)
// Replenish the TTL with 'Period's value. if err != nil {
if role.Period > time.Duration(0) { return nil, err
// If 'Period' was updated after the token was issued,
// token will bear the updated 'Period' value as its TTL.
req.Auth.TTL = role.Period
return &logical.Response{Auth: req.Auth}, nil
} else {
return framework.LeaseExtend(role.TokenTTL, role.TokenMaxTTL, b.System())(req, data)
} }
resp.Auth.Period = role.Period
return resp, nil
} }
const pathLoginHelpSys = "Issue a token based on the credentials supplied" const pathLoginHelpSys = "Issue a token based on the credentials supplied"

View file

@ -1,6 +1,7 @@
package database package database
import ( import (
"context"
"fmt" "fmt"
"net/rpc" "net/rpc"
"strings" "strings"
@ -87,7 +88,7 @@ func (b *databaseBackend) getDBObj(name string) (dbplugin.Database, bool) {
// This function creates a new db object from the stored configuration and // This function creates a new db object from the stored configuration and
// caches it in the connections map. The caller of this function needs to hold // caches it in the connections map. The caller of this function needs to hold
// the backend's write lock // the backend's write lock
func (b *databaseBackend) createDBObj(s logical.Storage, name string) (dbplugin.Database, error) { func (b *databaseBackend) createDBObj(ctx context.Context, s logical.Storage, name string) (dbplugin.Database, error) {
db, ok := b.connections[name] db, ok := b.connections[name]
if ok { if ok {
return db, nil return db, nil
@ -103,7 +104,7 @@ func (b *databaseBackend) createDBObj(s logical.Storage, name string) (dbplugin.
return nil, err return nil, err
} }
err = db.Initialize(config.ConnectionDetails, true) err = db.Initialize(ctx, config.ConnectionDetails, true)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -170,7 +171,8 @@ func (b *databaseBackend) clearConnection(name string) {
func (b *databaseBackend) closeIfShutdown(name string, err error) { func (b *databaseBackend) closeIfShutdown(name string, err error) {
// Plugin has shutdown, close it so next call can reconnect. // Plugin has shutdown, close it so next call can reconnect.
if err == rpc.ErrShutdown { switch err {
case rpc.ErrShutdown, dbplugin.ErrPluginShutdown:
b.Lock() b.Lock()
b.clearConnection(name) b.clearConnection(name)
b.Unlock() b.Unlock()

View file

@ -488,9 +488,11 @@ func TestBackend_roleCrud(t *testing.T) {
RevocationStatements: defaultRevocationSQL, RevocationStatements: defaultRevocationSQL,
} }
var actual dbplugin.Statements actual := dbplugin.Statements{
if err := mapstructure.Decode(resp.Data, &actual); err != nil { CreationStatements: resp.Data["creation_statements"].(string),
t.Fatal(err) RevocationStatements: resp.Data["revocation_statements"].(string),
RollbackStatements: resp.Data["rollback_statements"].(string),
RenewStatements: resp.Data["renew_statements"].(string),
} }
if !reflect.DeepEqual(expected, actual) { if !reflect.DeepEqual(expected, actual) {

View file

@ -1,10 +1,8 @@
package dbplugin package dbplugin
import ( import (
"fmt" "errors"
"net/rpc"
"sync" "sync"
"time"
"github.com/hashicorp/go-plugin" "github.com/hashicorp/go-plugin"
"github.com/hashicorp/vault/helper/pluginutil" "github.com/hashicorp/vault/helper/pluginutil"
@ -17,11 +15,11 @@ type DatabasePluginClient struct {
client *plugin.Client client *plugin.Client
sync.Mutex sync.Mutex
*databasePluginRPCClient Database
} }
func (dc *DatabasePluginClient) Close() error { func (dc *DatabasePluginClient) Close() error {
err := dc.databasePluginRPCClient.Close() err := dc.Database.Close()
dc.client.Kill() dc.client.Kill()
return err return err
@ -55,79 +53,20 @@ func newPluginClient(sys pluginutil.RunnerUtil, pluginRunner *pluginutil.PluginR
// We should have a database type now. This feels like a normal interface // We should have a database type now. This feels like a normal interface
// implementation but is in fact over an RPC connection. // implementation but is in fact over an RPC connection.
databaseRPC := raw.(*databasePluginRPCClient) var db Database
switch raw.(type) {
case *gRPCClient:
db = raw.(*gRPCClient)
case *databasePluginRPCClient:
logger.Warn("database: plugin is using deprecated net RPC transport, recompile plugin to upgrade to gRPC", "plugin", pluginRunner.Name)
db = raw.(*databasePluginRPCClient)
default:
return nil, errors.New("unsupported client type")
}
// Wrap RPC implimentation in DatabasePluginClient // Wrap RPC implimentation in DatabasePluginClient
return &DatabasePluginClient{ return &DatabasePluginClient{
client: client, client: client,
databasePluginRPCClient: databaseRPC, Database: db,
}, nil }, nil
} }
// ---- RPC client domain ----
// databasePluginRPCClient implements Database and is used on the client to
// make RPC calls to a plugin.
type databasePluginRPCClient struct {
client *rpc.Client
}
func (dr *databasePluginRPCClient) Type() (string, error) {
var dbType string
err := dr.client.Call("Plugin.Type", struct{}{}, &dbType)
return fmt.Sprintf("plugin-%s", dbType), err
}
func (dr *databasePluginRPCClient) CreateUser(statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) {
req := CreateUserRequest{
Statements: statements,
UsernameConfig: usernameConfig,
Expiration: expiration,
}
var resp CreateUserResponse
err = dr.client.Call("Plugin.CreateUser", req, &resp)
return resp.Username, resp.Password, err
}
func (dr *databasePluginRPCClient) RenewUser(statements Statements, username string, expiration time.Time) error {
req := RenewUserRequest{
Statements: statements,
Username: username,
Expiration: expiration,
}
err := dr.client.Call("Plugin.RenewUser", req, &struct{}{})
return err
}
func (dr *databasePluginRPCClient) RevokeUser(statements Statements, username string) error {
req := RevokeUserRequest{
Statements: statements,
Username: username,
}
err := dr.client.Call("Plugin.RevokeUser", req, &struct{}{})
return err
}
func (dr *databasePluginRPCClient) Initialize(conf map[string]interface{}, verifyConnection bool) error {
req := InitializeRequest{
Config: conf,
VerifyConnection: verifyConnection,
}
err := dr.client.Call("Plugin.Initialize", req, &struct{}{})
return err
}
func (dr *databasePluginRPCClient) Close() error {
err := dr.client.Call("Plugin.Close", struct{}{}, &struct{}{})
return err
}

View file

@ -0,0 +1,556 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// source: builtin/logical/database/dbplugin/database.proto
/*
Package dbplugin is a generated protocol buffer package.
It is generated from these files:
builtin/logical/database/dbplugin/database.proto
It has these top-level messages:
InitializeRequest
CreateUserRequest
RenewUserRequest
RevokeUserRequest
Statements
UsernameConfig
CreateUserResponse
TypeResponse
Empty
*/
package dbplugin
import proto "github.com/golang/protobuf/proto"
import fmt "fmt"
import math "math"
import google_protobuf "github.com/golang/protobuf/ptypes/timestamp"
import (
context "golang.org/x/net/context"
grpc "google.golang.org/grpc"
)
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal
var _ = fmt.Errorf
var _ = math.Inf
// This is a compile-time assertion to ensure that this generated file
// is compatible with the proto package it is being compiled against.
// A compilation error at this line likely means your copy of the
// proto package needs to be updated.
const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package
type InitializeRequest struct {
Config []byte `protobuf:"bytes,1,opt,name=config,proto3" json:"config,omitempty"`
VerifyConnection bool `protobuf:"varint,2,opt,name=verify_connection,json=verifyConnection" json:"verify_connection,omitempty"`
}
func (m *InitializeRequest) Reset() { *m = InitializeRequest{} }
func (m *InitializeRequest) String() string { return proto.CompactTextString(m) }
func (*InitializeRequest) ProtoMessage() {}
func (*InitializeRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{0} }
func (m *InitializeRequest) GetConfig() []byte {
if m != nil {
return m.Config
}
return nil
}
func (m *InitializeRequest) GetVerifyConnection() bool {
if m != nil {
return m.VerifyConnection
}
return false
}
type CreateUserRequest struct {
Statements *Statements `protobuf:"bytes,1,opt,name=statements" json:"statements,omitempty"`
UsernameConfig *UsernameConfig `protobuf:"bytes,2,opt,name=username_config,json=usernameConfig" json:"username_config,omitempty"`
Expiration *google_protobuf.Timestamp `protobuf:"bytes,3,opt,name=expiration" json:"expiration,omitempty"`
}
func (m *CreateUserRequest) Reset() { *m = CreateUserRequest{} }
func (m *CreateUserRequest) String() string { return proto.CompactTextString(m) }
func (*CreateUserRequest) ProtoMessage() {}
func (*CreateUserRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{1} }
func (m *CreateUserRequest) GetStatements() *Statements {
if m != nil {
return m.Statements
}
return nil
}
func (m *CreateUserRequest) GetUsernameConfig() *UsernameConfig {
if m != nil {
return m.UsernameConfig
}
return nil
}
func (m *CreateUserRequest) GetExpiration() *google_protobuf.Timestamp {
if m != nil {
return m.Expiration
}
return nil
}
type RenewUserRequest struct {
Statements *Statements `protobuf:"bytes,1,opt,name=statements" json:"statements,omitempty"`
Username string `protobuf:"bytes,2,opt,name=username" json:"username,omitempty"`
Expiration *google_protobuf.Timestamp `protobuf:"bytes,3,opt,name=expiration" json:"expiration,omitempty"`
}
func (m *RenewUserRequest) Reset() { *m = RenewUserRequest{} }
func (m *RenewUserRequest) String() string { return proto.CompactTextString(m) }
func (*RenewUserRequest) ProtoMessage() {}
func (*RenewUserRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{2} }
func (m *RenewUserRequest) GetStatements() *Statements {
if m != nil {
return m.Statements
}
return nil
}
func (m *RenewUserRequest) GetUsername() string {
if m != nil {
return m.Username
}
return ""
}
func (m *RenewUserRequest) GetExpiration() *google_protobuf.Timestamp {
if m != nil {
return m.Expiration
}
return nil
}
type RevokeUserRequest struct {
Statements *Statements `protobuf:"bytes,1,opt,name=statements" json:"statements,omitempty"`
Username string `protobuf:"bytes,2,opt,name=username" json:"username,omitempty"`
}
func (m *RevokeUserRequest) Reset() { *m = RevokeUserRequest{} }
func (m *RevokeUserRequest) String() string { return proto.CompactTextString(m) }
func (*RevokeUserRequest) ProtoMessage() {}
func (*RevokeUserRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{3} }
func (m *RevokeUserRequest) GetStatements() *Statements {
if m != nil {
return m.Statements
}
return nil
}
func (m *RevokeUserRequest) GetUsername() string {
if m != nil {
return m.Username
}
return ""
}
type Statements struct {
CreationStatements string `protobuf:"bytes,1,opt,name=creation_statements,json=creationStatements" json:"creation_statements,omitempty"`
RevocationStatements string `protobuf:"bytes,2,opt,name=revocation_statements,json=revocationStatements" json:"revocation_statements,omitempty"`
RollbackStatements string `protobuf:"bytes,3,opt,name=rollback_statements,json=rollbackStatements" json:"rollback_statements,omitempty"`
RenewStatements string `protobuf:"bytes,4,opt,name=renew_statements,json=renewStatements" json:"renew_statements,omitempty"`
}
func (m *Statements) Reset() { *m = Statements{} }
func (m *Statements) String() string { return proto.CompactTextString(m) }
func (*Statements) ProtoMessage() {}
func (*Statements) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{4} }
func (m *Statements) GetCreationStatements() string {
if m != nil {
return m.CreationStatements
}
return ""
}
func (m *Statements) GetRevocationStatements() string {
if m != nil {
return m.RevocationStatements
}
return ""
}
func (m *Statements) GetRollbackStatements() string {
if m != nil {
return m.RollbackStatements
}
return ""
}
func (m *Statements) GetRenewStatements() string {
if m != nil {
return m.RenewStatements
}
return ""
}
type UsernameConfig struct {
DisplayName string `protobuf:"bytes,1,opt,name=DisplayName" json:"DisplayName,omitempty"`
RoleName string `protobuf:"bytes,2,opt,name=RoleName" json:"RoleName,omitempty"`
}
func (m *UsernameConfig) Reset() { *m = UsernameConfig{} }
func (m *UsernameConfig) String() string { return proto.CompactTextString(m) }
func (*UsernameConfig) ProtoMessage() {}
func (*UsernameConfig) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{5} }
func (m *UsernameConfig) GetDisplayName() string {
if m != nil {
return m.DisplayName
}
return ""
}
func (m *UsernameConfig) GetRoleName() string {
if m != nil {
return m.RoleName
}
return ""
}
type CreateUserResponse struct {
Username string `protobuf:"bytes,1,opt,name=username" json:"username,omitempty"`
Password string `protobuf:"bytes,2,opt,name=password" json:"password,omitempty"`
}
func (m *CreateUserResponse) Reset() { *m = CreateUserResponse{} }
func (m *CreateUserResponse) String() string { return proto.CompactTextString(m) }
func (*CreateUserResponse) ProtoMessage() {}
func (*CreateUserResponse) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{6} }
func (m *CreateUserResponse) GetUsername() string {
if m != nil {
return m.Username
}
return ""
}
func (m *CreateUserResponse) GetPassword() string {
if m != nil {
return m.Password
}
return ""
}
type TypeResponse struct {
Type string `protobuf:"bytes,1,opt,name=type" json:"type,omitempty"`
}
func (m *TypeResponse) Reset() { *m = TypeResponse{} }
func (m *TypeResponse) String() string { return proto.CompactTextString(m) }
func (*TypeResponse) ProtoMessage() {}
func (*TypeResponse) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{7} }
func (m *TypeResponse) GetType() string {
if m != nil {
return m.Type
}
return ""
}
type Empty struct {
}
func (m *Empty) Reset() { *m = Empty{} }
func (m *Empty) String() string { return proto.CompactTextString(m) }
func (*Empty) ProtoMessage() {}
func (*Empty) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{8} }
func init() {
proto.RegisterType((*InitializeRequest)(nil), "dbplugin.InitializeRequest")
proto.RegisterType((*CreateUserRequest)(nil), "dbplugin.CreateUserRequest")
proto.RegisterType((*RenewUserRequest)(nil), "dbplugin.RenewUserRequest")
proto.RegisterType((*RevokeUserRequest)(nil), "dbplugin.RevokeUserRequest")
proto.RegisterType((*Statements)(nil), "dbplugin.Statements")
proto.RegisterType((*UsernameConfig)(nil), "dbplugin.UsernameConfig")
proto.RegisterType((*CreateUserResponse)(nil), "dbplugin.CreateUserResponse")
proto.RegisterType((*TypeResponse)(nil), "dbplugin.TypeResponse")
proto.RegisterType((*Empty)(nil), "dbplugin.Empty")
}
// Reference imports to suppress errors if they are not otherwise used.
var _ context.Context
var _ grpc.ClientConn
// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
const _ = grpc.SupportPackageIsVersion4
// Client API for Database service
type DatabaseClient interface {
Type(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*TypeResponse, error)
CreateUser(ctx context.Context, in *CreateUserRequest, opts ...grpc.CallOption) (*CreateUserResponse, error)
RenewUser(ctx context.Context, in *RenewUserRequest, opts ...grpc.CallOption) (*Empty, error)
RevokeUser(ctx context.Context, in *RevokeUserRequest, opts ...grpc.CallOption) (*Empty, error)
Initialize(ctx context.Context, in *InitializeRequest, opts ...grpc.CallOption) (*Empty, error)
Close(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*Empty, error)
}
type databaseClient struct {
cc *grpc.ClientConn
}
func NewDatabaseClient(cc *grpc.ClientConn) DatabaseClient {
return &databaseClient{cc}
}
func (c *databaseClient) Type(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*TypeResponse, error) {
out := new(TypeResponse)
err := grpc.Invoke(ctx, "/dbplugin.Database/Type", in, out, c.cc, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *databaseClient) CreateUser(ctx context.Context, in *CreateUserRequest, opts ...grpc.CallOption) (*CreateUserResponse, error) {
out := new(CreateUserResponse)
err := grpc.Invoke(ctx, "/dbplugin.Database/CreateUser", in, out, c.cc, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *databaseClient) RenewUser(ctx context.Context, in *RenewUserRequest, opts ...grpc.CallOption) (*Empty, error) {
out := new(Empty)
err := grpc.Invoke(ctx, "/dbplugin.Database/RenewUser", in, out, c.cc, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *databaseClient) RevokeUser(ctx context.Context, in *RevokeUserRequest, opts ...grpc.CallOption) (*Empty, error) {
out := new(Empty)
err := grpc.Invoke(ctx, "/dbplugin.Database/RevokeUser", in, out, c.cc, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *databaseClient) Initialize(ctx context.Context, in *InitializeRequest, opts ...grpc.CallOption) (*Empty, error) {
out := new(Empty)
err := grpc.Invoke(ctx, "/dbplugin.Database/Initialize", in, out, c.cc, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *databaseClient) Close(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*Empty, error) {
out := new(Empty)
err := grpc.Invoke(ctx, "/dbplugin.Database/Close", in, out, c.cc, opts...)
if err != nil {
return nil, err
}
return out, nil
}
// Server API for Database service
type DatabaseServer interface {
Type(context.Context, *Empty) (*TypeResponse, error)
CreateUser(context.Context, *CreateUserRequest) (*CreateUserResponse, error)
RenewUser(context.Context, *RenewUserRequest) (*Empty, error)
RevokeUser(context.Context, *RevokeUserRequest) (*Empty, error)
Initialize(context.Context, *InitializeRequest) (*Empty, error)
Close(context.Context, *Empty) (*Empty, error)
}
func RegisterDatabaseServer(s *grpc.Server, srv DatabaseServer) {
s.RegisterService(&_Database_serviceDesc, srv)
}
func _Database_Type_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(Empty)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DatabaseServer).Type(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/dbplugin.Database/Type",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DatabaseServer).Type(ctx, req.(*Empty))
}
return interceptor(ctx, in, info, handler)
}
func _Database_CreateUser_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(CreateUserRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DatabaseServer).CreateUser(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/dbplugin.Database/CreateUser",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DatabaseServer).CreateUser(ctx, req.(*CreateUserRequest))
}
return interceptor(ctx, in, info, handler)
}
func _Database_RenewUser_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(RenewUserRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DatabaseServer).RenewUser(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/dbplugin.Database/RenewUser",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DatabaseServer).RenewUser(ctx, req.(*RenewUserRequest))
}
return interceptor(ctx, in, info, handler)
}
func _Database_RevokeUser_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(RevokeUserRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DatabaseServer).RevokeUser(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/dbplugin.Database/RevokeUser",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DatabaseServer).RevokeUser(ctx, req.(*RevokeUserRequest))
}
return interceptor(ctx, in, info, handler)
}
func _Database_Initialize_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(InitializeRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DatabaseServer).Initialize(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/dbplugin.Database/Initialize",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DatabaseServer).Initialize(ctx, req.(*InitializeRequest))
}
return interceptor(ctx, in, info, handler)
}
func _Database_Close_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(Empty)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DatabaseServer).Close(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/dbplugin.Database/Close",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DatabaseServer).Close(ctx, req.(*Empty))
}
return interceptor(ctx, in, info, handler)
}
var _Database_serviceDesc = grpc.ServiceDesc{
ServiceName: "dbplugin.Database",
HandlerType: (*DatabaseServer)(nil),
Methods: []grpc.MethodDesc{
{
MethodName: "Type",
Handler: _Database_Type_Handler,
},
{
MethodName: "CreateUser",
Handler: _Database_CreateUser_Handler,
},
{
MethodName: "RenewUser",
Handler: _Database_RenewUser_Handler,
},
{
MethodName: "RevokeUser",
Handler: _Database_RevokeUser_Handler,
},
{
MethodName: "Initialize",
Handler: _Database_Initialize_Handler,
},
{
MethodName: "Close",
Handler: _Database_Close_Handler,
},
},
Streams: []grpc.StreamDesc{},
Metadata: "builtin/logical/database/dbplugin/database.proto",
}
func init() { proto.RegisterFile("builtin/logical/database/dbplugin/database.proto", fileDescriptor0) }
var fileDescriptor0 = []byte{
// 548 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xb4, 0x54, 0xcf, 0x6e, 0xd3, 0x4e,
0x10, 0x96, 0xdb, 0xb4, 0xbf, 0x64, 0x5a, 0x35, 0xc9, 0xfe, 0x4a, 0x15, 0x19, 0x24, 0x22, 0x9f,
0x5a, 0x21, 0xd9, 0xa8, 0xe5, 0x80, 0xb8, 0xa1, 0x14, 0x21, 0x24, 0x94, 0x83, 0x69, 0x25, 0x6e,
0xd1, 0xda, 0x99, 0x44, 0xab, 0x3a, 0xbb, 0xc6, 0xbb, 0x4e, 0x09, 0x4f, 0xc3, 0xe3, 0x70, 0xe2,
0x1d, 0x78, 0x13, 0xe4, 0x75, 0xd6, 0xbb, 0xf9, 0x73, 0xab, 0xb8, 0x79, 0x66, 0xbe, 0x6f, 0xf6,
0xf3, 0xb7, 0x33, 0x0b, 0xaf, 0x93, 0x92, 0x65, 0x8a, 0xf1, 0x28, 0x13, 0x73, 0x96, 0xd2, 0x2c,
0x9a, 0x52, 0x45, 0x13, 0x2a, 0x31, 0x9a, 0x26, 0x79, 0x56, 0xce, 0x19, 0x6f, 0x32, 0x61, 0x5e,
0x08, 0x25, 0x48, 0xdb, 0x14, 0xfc, 0x97, 0x73, 0x21, 0xe6, 0x19, 0x46, 0x3a, 0x9f, 0x94, 0xb3,
0x48, 0xb1, 0x05, 0x4a, 0x45, 0x17, 0x79, 0x0d, 0x0d, 0xbe, 0x42, 0xff, 0x13, 0x67, 0x8a, 0xd1,
0x8c, 0xfd, 0xc0, 0x18, 0xbf, 0x95, 0x28, 0x15, 0xb9, 0x80, 0xe3, 0x54, 0xf0, 0x19, 0x9b, 0x0f,
0xbc, 0xa1, 0x77, 0x79, 0x1a, 0xaf, 0x23, 0xf2, 0x0a, 0xfa, 0x4b, 0x2c, 0xd8, 0x6c, 0x35, 0x49,
0x05, 0xe7, 0x98, 0x2a, 0x26, 0xf8, 0xe0, 0x60, 0xe8, 0x5d, 0xb6, 0xe3, 0x5e, 0x5d, 0x18, 0x35,
0xf9, 0xe0, 0x97, 0x07, 0xfd, 0x51, 0x81, 0x54, 0xe1, 0xbd, 0xc4, 0xc2, 0xb4, 0x7e, 0x03, 0x20,
0x15, 0x55, 0xb8, 0x40, 0xae, 0xa4, 0x6e, 0x7f, 0x72, 0x7d, 0x1e, 0x1a, 0xbd, 0xe1, 0x97, 0xa6,
0x16, 0x3b, 0x38, 0xf2, 0x1e, 0xba, 0xa5, 0xc4, 0x82, 0xd3, 0x05, 0x4e, 0xd6, 0xca, 0x0e, 0x34,
0x75, 0x60, 0xa9, 0xf7, 0x6b, 0xc0, 0x48, 0xd7, 0xe3, 0xb3, 0x72, 0x23, 0x26, 0xef, 0x00, 0xf0,
0x7b, 0xce, 0x0a, 0xaa, 0x45, 0x1f, 0x6a, 0xb6, 0x1f, 0xd6, 0xf6, 0x84, 0xc6, 0x9e, 0xf0, 0xce,
0xd8, 0x13, 0x3b, 0xe8, 0xe0, 0xa7, 0x07, 0xbd, 0x18, 0x39, 0x3e, 0x3e, 0xfd, 0x4f, 0x7c, 0x68,
0x1b, 0x61, 0xfa, 0x17, 0x3a, 0x71, 0x13, 0x3f, 0x49, 0x22, 0x42, 0x3f, 0xc6, 0xa5, 0x78, 0xc0,
0x7f, 0x2a, 0x31, 0xf8, 0xed, 0x01, 0x58, 0x1a, 0x89, 0xe0, 0xff, 0xb4, 0xba, 0x62, 0x26, 0xf8,
0x64, 0xeb, 0xa4, 0x4e, 0x4c, 0x4c, 0xc9, 0x21, 0xdc, 0xc0, 0xb3, 0x02, 0x97, 0x22, 0xdd, 0xa1,
0xd4, 0x07, 0x9d, 0xdb, 0xe2, 0xe6, 0x29, 0x85, 0xc8, 0xb2, 0x84, 0xa6, 0x0f, 0x2e, 0xe5, 0xb0,
0x3e, 0xc5, 0x94, 0x1c, 0xc2, 0x15, 0xf4, 0x8a, 0xea, 0xba, 0x5c, 0x74, 0x4b, 0xa3, 0xbb, 0x3a,
0x6f, 0xa1, 0xc1, 0x18, 0xce, 0x36, 0x07, 0x87, 0x0c, 0xe1, 0xe4, 0x96, 0xc9, 0x3c, 0xa3, 0xab,
0x71, 0xe5, 0x40, 0xfd, 0x2f, 0x6e, 0xaa, 0x32, 0x28, 0x16, 0x19, 0x8e, 0x1d, 0x83, 0x4c, 0x1c,
0x7c, 0x06, 0xe2, 0x0e, 0xbd, 0xcc, 0x05, 0x97, 0xb8, 0x61, 0xa9, 0xb7, 0x75, 0xeb, 0x3e, 0xb4,
0x73, 0x2a, 0xe5, 0xa3, 0x28, 0xa6, 0xa6, 0x9b, 0x89, 0x83, 0x00, 0x4e, 0xef, 0x56, 0x39, 0x36,
0x7d, 0x08, 0xb4, 0xd4, 0x2a, 0x37, 0x3d, 0xf4, 0x77, 0xf0, 0x1f, 0x1c, 0x7d, 0x58, 0xe4, 0x6a,
0x75, 0xfd, 0xe7, 0x00, 0xda, 0xb7, 0xeb, 0x87, 0x80, 0x44, 0xd0, 0xaa, 0x98, 0xa4, 0x6b, 0xaf,
0x5b, 0xa3, 0xfc, 0x0b, 0x9b, 0xd8, 0x68, 0xfd, 0x11, 0xc0, 0x0a, 0x27, 0xcf, 0x2d, 0x6a, 0x67,
0x87, 0xfd, 0x17, 0xfb, 0x8b, 0xeb, 0x46, 0x6f, 0xa1, 0xd3, 0xec, 0x0a, 0xf1, 0x2d, 0x74, 0x7b,
0x81, 0xfc, 0x6d, 0x69, 0xd5, 0xfc, 0xdb, 0x19, 0x76, 0x25, 0xec, 0x4c, 0xf6, 0x5e, 0xae, 0x7d,
0xc7, 0x5c, 0xee, 0xce, 0xeb, 0xb6, 0xcb, 0xbd, 0x82, 0xa3, 0x51, 0x26, 0xe4, 0x1e, 0xb3, 0xb6,
0x13, 0xc9, 0xb1, 0x5e, 0xc3, 0x9b, 0xbf, 0x01, 0x00, 0x00, 0xff, 0xff, 0x8c, 0x55, 0x84, 0x56,
0x94, 0x05, 0x00, 0x00,
}

View file

@ -0,0 +1,58 @@
syntax = "proto3";
package dbplugin;
import "google/protobuf/timestamp.proto";
message InitializeRequest {
bytes config = 1;
bool verify_connection = 2;
}
message CreateUserRequest {
Statements statements = 1;
UsernameConfig username_config = 2;
google.protobuf.Timestamp expiration = 3;
}
message RenewUserRequest {
Statements statements = 1;
string username = 2;
google.protobuf.Timestamp expiration = 3;
}
message RevokeUserRequest {
Statements statements = 1;
string username = 2;
}
message Statements {
string creation_statements = 1;
string revocation_statements = 2;
string rollback_statements = 3;
string renew_statements = 4;
}
message UsernameConfig {
string DisplayName = 1;
string RoleName = 2;
}
message CreateUserResponse {
string username = 1;
string password = 2;
}
message TypeResponse {
string type = 1;
}
message Empty {}
service Database {
rpc Type(Empty) returns (TypeResponse);
rpc CreateUser(CreateUserRequest) returns (CreateUserResponse);
rpc RenewUser(RenewUserRequest) returns (Empty);
rpc RevokeUser(RevokeUserRequest) returns (Empty);
rpc Initialize(InitializeRequest) returns (Empty);
rpc Close(Empty) returns (Empty);
}

View file

@ -1,6 +1,7 @@
package dbplugin package dbplugin
import ( import (
"context"
"time" "time"
metrics "github.com/armon/go-metrics" metrics "github.com/armon/go-metrics"
@ -15,55 +16,56 @@ type databaseTracingMiddleware struct {
next Database next Database
logger log.Logger logger log.Logger
typeStr string typeStr string
transport string
} }
func (mw *databaseTracingMiddleware) Type() (string, error) { func (mw *databaseTracingMiddleware) Type() (string, error) {
return mw.next.Type() return mw.next.Type()
} }
func (mw *databaseTracingMiddleware) CreateUser(statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) { func (mw *databaseTracingMiddleware) CreateUser(ctx context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) {
defer func(then time.Time) { defer func(then time.Time) {
mw.logger.Trace("database", "operation", "CreateUser", "status", "finished", "type", mw.typeStr, "err", err, "took", time.Since(then)) mw.logger.Trace("database", "operation", "CreateUser", "status", "finished", "type", mw.typeStr, "transport", mw.transport, "err", err, "took", time.Since(then))
}(time.Now()) }(time.Now())
mw.logger.Trace("database", "operation", "CreateUser", "status", "started", "type", mw.typeStr) mw.logger.Trace("database", "operation", "CreateUser", "status", "started", "type", mw.typeStr, "transport", mw.transport)
return mw.next.CreateUser(statements, usernameConfig, expiration) return mw.next.CreateUser(ctx, statements, usernameConfig, expiration)
} }
func (mw *databaseTracingMiddleware) RenewUser(statements Statements, username string, expiration time.Time) (err error) { func (mw *databaseTracingMiddleware) RenewUser(ctx context.Context, statements Statements, username string, expiration time.Time) (err error) {
defer func(then time.Time) { defer func(then time.Time) {
mw.logger.Trace("database", "operation", "RenewUser", "status", "finished", "type", mw.typeStr, "err", err, "took", time.Since(then)) mw.logger.Trace("database", "operation", "RenewUser", "status", "finished", "type", mw.typeStr, "transport", mw.transport, "err", err, "took", time.Since(then))
}(time.Now()) }(time.Now())
mw.logger.Trace("database", "operation", "RenewUser", "status", "started", mw.typeStr) mw.logger.Trace("database", "operation", "RenewUser", "status", "started", mw.typeStr, "transport", mw.transport)
return mw.next.RenewUser(statements, username, expiration) return mw.next.RenewUser(ctx, statements, username, expiration)
} }
func (mw *databaseTracingMiddleware) RevokeUser(statements Statements, username string) (err error) { func (mw *databaseTracingMiddleware) RevokeUser(ctx context.Context, statements Statements, username string) (err error) {
defer func(then time.Time) { defer func(then time.Time) {
mw.logger.Trace("database", "operation", "RevokeUser", "status", "finished", "type", mw.typeStr, "err", err, "took", time.Since(then)) mw.logger.Trace("database", "operation", "RevokeUser", "status", "finished", "type", mw.typeStr, "transport", mw.transport, "err", err, "took", time.Since(then))
}(time.Now()) }(time.Now())
mw.logger.Trace("database", "operation", "RevokeUser", "status", "started", "type", mw.typeStr) mw.logger.Trace("database", "operation", "RevokeUser", "status", "started", "type", mw.typeStr, "transport", mw.transport)
return mw.next.RevokeUser(statements, username) return mw.next.RevokeUser(ctx, statements, username)
} }
func (mw *databaseTracingMiddleware) Initialize(conf map[string]interface{}, verifyConnection bool) (err error) { func (mw *databaseTracingMiddleware) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (err error) {
defer func(then time.Time) { defer func(then time.Time) {
mw.logger.Trace("database", "operation", "Initialize", "status", "finished", "type", mw.typeStr, "verify", verifyConnection, "err", err, "took", time.Since(then)) mw.logger.Trace("database", "operation", "Initialize", "status", "finished", "type", mw.typeStr, "transport", mw.transport, "verify", verifyConnection, "err", err, "took", time.Since(then))
}(time.Now()) }(time.Now())
mw.logger.Trace("database", "operation", "Initialize", "status", "started", "type", mw.typeStr) mw.logger.Trace("database", "operation", "Initialize", "status", "started", "type", mw.typeStr, "transport", mw.transport)
return mw.next.Initialize(conf, verifyConnection) return mw.next.Initialize(ctx, conf, verifyConnection)
} }
func (mw *databaseTracingMiddleware) Close() (err error) { func (mw *databaseTracingMiddleware) Close() (err error) {
defer func(then time.Time) { defer func(then time.Time) {
mw.logger.Trace("database", "operation", "Close", "status", "finished", "type", mw.typeStr, "err", err, "took", time.Since(then)) mw.logger.Trace("database", "operation", "Close", "status", "finished", "type", mw.typeStr, "transport", mw.transport, "err", err, "took", time.Since(then))
}(time.Now()) }(time.Now())
mw.logger.Trace("database", "operation", "Close", "status", "started", "type", mw.typeStr) mw.logger.Trace("database", "operation", "Close", "status", "started", "type", mw.typeStr, "transport", mw.transport)
return mw.next.Close() return mw.next.Close()
} }
@ -81,7 +83,7 @@ func (mw *databaseMetricsMiddleware) Type() (string, error) {
return mw.next.Type() return mw.next.Type()
} }
func (mw *databaseMetricsMiddleware) CreateUser(statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) { func (mw *databaseMetricsMiddleware) CreateUser(ctx context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) {
defer func(now time.Time) { defer func(now time.Time) {
metrics.MeasureSince([]string{"database", "CreateUser"}, now) metrics.MeasureSince([]string{"database", "CreateUser"}, now)
metrics.MeasureSince([]string{"database", mw.typeStr, "CreateUser"}, now) metrics.MeasureSince([]string{"database", mw.typeStr, "CreateUser"}, now)
@ -94,10 +96,10 @@ func (mw *databaseMetricsMiddleware) CreateUser(statements Statements, usernameC
metrics.IncrCounter([]string{"database", "CreateUser"}, 1) metrics.IncrCounter([]string{"database", "CreateUser"}, 1)
metrics.IncrCounter([]string{"database", mw.typeStr, "CreateUser"}, 1) metrics.IncrCounter([]string{"database", mw.typeStr, "CreateUser"}, 1)
return mw.next.CreateUser(statements, usernameConfig, expiration) return mw.next.CreateUser(ctx, statements, usernameConfig, expiration)
} }
func (mw *databaseMetricsMiddleware) RenewUser(statements Statements, username string, expiration time.Time) (err error) { func (mw *databaseMetricsMiddleware) RenewUser(ctx context.Context, statements Statements, username string, expiration time.Time) (err error) {
defer func(now time.Time) { defer func(now time.Time) {
metrics.MeasureSince([]string{"database", "RenewUser"}, now) metrics.MeasureSince([]string{"database", "RenewUser"}, now)
metrics.MeasureSince([]string{"database", mw.typeStr, "RenewUser"}, now) metrics.MeasureSince([]string{"database", mw.typeStr, "RenewUser"}, now)
@ -110,10 +112,10 @@ func (mw *databaseMetricsMiddleware) RenewUser(statements Statements, username s
metrics.IncrCounter([]string{"database", "RenewUser"}, 1) metrics.IncrCounter([]string{"database", "RenewUser"}, 1)
metrics.IncrCounter([]string{"database", mw.typeStr, "RenewUser"}, 1) metrics.IncrCounter([]string{"database", mw.typeStr, "RenewUser"}, 1)
return mw.next.RenewUser(statements, username, expiration) return mw.next.RenewUser(ctx, statements, username, expiration)
} }
func (mw *databaseMetricsMiddleware) RevokeUser(statements Statements, username string) (err error) { func (mw *databaseMetricsMiddleware) RevokeUser(ctx context.Context, statements Statements, username string) (err error) {
defer func(now time.Time) { defer func(now time.Time) {
metrics.MeasureSince([]string{"database", "RevokeUser"}, now) metrics.MeasureSince([]string{"database", "RevokeUser"}, now)
metrics.MeasureSince([]string{"database", mw.typeStr, "RevokeUser"}, now) metrics.MeasureSince([]string{"database", mw.typeStr, "RevokeUser"}, now)
@ -126,10 +128,10 @@ func (mw *databaseMetricsMiddleware) RevokeUser(statements Statements, username
metrics.IncrCounter([]string{"database", "RevokeUser"}, 1) metrics.IncrCounter([]string{"database", "RevokeUser"}, 1)
metrics.IncrCounter([]string{"database", mw.typeStr, "RevokeUser"}, 1) metrics.IncrCounter([]string{"database", mw.typeStr, "RevokeUser"}, 1)
return mw.next.RevokeUser(statements, username) return mw.next.RevokeUser(ctx, statements, username)
} }
func (mw *databaseMetricsMiddleware) Initialize(conf map[string]interface{}, verifyConnection bool) (err error) { func (mw *databaseMetricsMiddleware) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (err error) {
defer func(now time.Time) { defer func(now time.Time) {
metrics.MeasureSince([]string{"database", "Initialize"}, now) metrics.MeasureSince([]string{"database", "Initialize"}, now)
metrics.MeasureSince([]string{"database", mw.typeStr, "Initialize"}, now) metrics.MeasureSince([]string{"database", mw.typeStr, "Initialize"}, now)
@ -142,7 +144,7 @@ func (mw *databaseMetricsMiddleware) Initialize(conf map[string]interface{}, ver
metrics.IncrCounter([]string{"database", "Initialize"}, 1) metrics.IncrCounter([]string{"database", "Initialize"}, 1)
metrics.IncrCounter([]string{"database", mw.typeStr, "Initialize"}, 1) metrics.IncrCounter([]string{"database", mw.typeStr, "Initialize"}, 1)
return mw.next.Initialize(conf, verifyConnection) return mw.next.Initialize(ctx, conf, verifyConnection)
} }
func (mw *databaseMetricsMiddleware) Close() (err error) { func (mw *databaseMetricsMiddleware) Close() (err error) {

View file

@ -0,0 +1,198 @@
package dbplugin
import (
"context"
"encoding/json"
"errors"
"time"
"google.golang.org/grpc"
"google.golang.org/grpc/connectivity"
"github.com/golang/protobuf/ptypes"
)
var (
ErrPluginShutdown = errors.New("plugin shutdown")
)
// ---- gRPC Server domain ----
type gRPCServer struct {
impl Database
}
func (s *gRPCServer) Type(context.Context, *Empty) (*TypeResponse, error) {
t, err := s.impl.Type()
if err != nil {
return nil, err
}
return &TypeResponse{
Type: t,
}, nil
}
func (s *gRPCServer) CreateUser(ctx context.Context, req *CreateUserRequest) (*CreateUserResponse, error) {
e, err := ptypes.Timestamp(req.Expiration)
if err != nil {
return nil, err
}
u, p, err := s.impl.CreateUser(ctx, *req.Statements, *req.UsernameConfig, e)
return &CreateUserResponse{
Username: u,
Password: p,
}, err
}
func (s *gRPCServer) RenewUser(ctx context.Context, req *RenewUserRequest) (*Empty, error) {
e, err := ptypes.Timestamp(req.Expiration)
if err != nil {
return nil, err
}
err = s.impl.RenewUser(ctx, *req.Statements, req.Username, e)
return &Empty{}, err
}
func (s *gRPCServer) RevokeUser(ctx context.Context, req *RevokeUserRequest) (*Empty, error) {
err := s.impl.RevokeUser(ctx, *req.Statements, req.Username)
return &Empty{}, err
}
func (s *gRPCServer) Initialize(ctx context.Context, req *InitializeRequest) (*Empty, error) {
config := map[string]interface{}{}
err := json.Unmarshal(req.Config, &config)
if err != nil {
return nil, err
}
err = s.impl.Initialize(ctx, config, req.VerifyConnection)
return &Empty{}, err
}
func (s *gRPCServer) Close(_ context.Context, _ *Empty) (*Empty, error) {
s.impl.Close()
return &Empty{}, nil
}
// ---- gRPC client domain ----
type gRPCClient struct {
client DatabaseClient
clientConn *grpc.ClientConn
}
func (c gRPCClient) Type() (string, error) {
// If the plugin has already shutdown, this will hang forever so we give it
// a one second timeout.
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
switch c.clientConn.GetState() {
case connectivity.Ready, connectivity.Idle:
default:
return "", ErrPluginShutdown
}
resp, err := c.client.Type(ctx, &Empty{})
if err != nil {
return "", err
}
return resp.Type, err
}
func (c gRPCClient) CreateUser(ctx context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) {
t, err := ptypes.TimestampProto(expiration)
if err != nil {
return "", "", err
}
switch c.clientConn.GetState() {
case connectivity.Ready, connectivity.Idle:
default:
return "", "", ErrPluginShutdown
}
resp, err := c.client.CreateUser(ctx, &CreateUserRequest{
Statements: &statements,
UsernameConfig: &usernameConfig,
Expiration: t,
})
if err != nil {
return "", "", err
}
return resp.Username, resp.Password, err
}
func (c *gRPCClient) RenewUser(ctx context.Context, statements Statements, username string, expiration time.Time) error {
t, err := ptypes.TimestampProto(expiration)
if err != nil {
return err
}
switch c.clientConn.GetState() {
case connectivity.Ready, connectivity.Idle:
default:
return ErrPluginShutdown
}
_, err = c.client.RenewUser(ctx, &RenewUserRequest{
Statements: &statements,
Username: username,
Expiration: t,
})
return err
}
func (c *gRPCClient) RevokeUser(ctx context.Context, statements Statements, username string) error {
switch c.clientConn.GetState() {
case connectivity.Ready, connectivity.Idle:
default:
return ErrPluginShutdown
}
_, err := c.client.RevokeUser(ctx, &RevokeUserRequest{
Statements: &statements,
Username: username,
})
return err
}
func (c *gRPCClient) Initialize(ctx context.Context, config map[string]interface{}, verifyConnection bool) error {
configRaw, err := json.Marshal(config)
if err != nil {
return err
}
switch c.clientConn.GetState() {
case connectivity.Ready, connectivity.Idle:
default:
return ErrPluginShutdown
}
_, err = c.client.Initialize(ctx, &InitializeRequest{
Config: configRaw,
VerifyConnection: verifyConnection,
})
return err
}
func (c *gRPCClient) Close() error {
// If the plugin has already shutdown, this will hang forever so we give it
// a one second timeout.
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
switch c.clientConn.GetState() {
case connectivity.Ready, connectivity.Idle:
_, err := c.client.Close(ctx, &Empty{})
return err
}
return nil
}

View file

@ -0,0 +1,139 @@
package dbplugin
import (
"context"
"fmt"
"net/rpc"
"time"
)
// ---- RPC server domain ----
// databasePluginRPCServer implements an RPC version of Database and is run
// inside a plugin. It wraps an underlying implementation of Database.
type databasePluginRPCServer struct {
impl Database
}
func (ds *databasePluginRPCServer) Type(_ struct{}, resp *string) error {
var err error
*resp, err = ds.impl.Type()
return err
}
func (ds *databasePluginRPCServer) CreateUser(args *CreateUserRequestRPC, resp *CreateUserResponse) error {
var err error
resp.Username, resp.Password, err = ds.impl.CreateUser(context.Background(), args.Statements, args.UsernameConfig, args.Expiration)
return err
}
func (ds *databasePluginRPCServer) RenewUser(args *RenewUserRequestRPC, _ *struct{}) error {
err := ds.impl.RenewUser(context.Background(), args.Statements, args.Username, args.Expiration)
return err
}
func (ds *databasePluginRPCServer) RevokeUser(args *RevokeUserRequestRPC, _ *struct{}) error {
err := ds.impl.RevokeUser(context.Background(), args.Statements, args.Username)
return err
}
func (ds *databasePluginRPCServer) Initialize(args *InitializeRequestRPC, _ *struct{}) error {
err := ds.impl.Initialize(context.Background(), args.Config, args.VerifyConnection)
return err
}
func (ds *databasePluginRPCServer) Close(_ struct{}, _ *struct{}) error {
ds.impl.Close()
return nil
}
// ---- RPC client domain ----
// databasePluginRPCClient implements Database and is used on the client to
// make RPC calls to a plugin.
type databasePluginRPCClient struct {
client *rpc.Client
}
func (dr *databasePluginRPCClient) Type() (string, error) {
var dbType string
err := dr.client.Call("Plugin.Type", struct{}{}, &dbType)
return fmt.Sprintf("plugin-%s", dbType), err
}
func (dr *databasePluginRPCClient) CreateUser(_ context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) {
req := CreateUserRequestRPC{
Statements: statements,
UsernameConfig: usernameConfig,
Expiration: expiration,
}
var resp CreateUserResponse
err = dr.client.Call("Plugin.CreateUser", req, &resp)
return resp.Username, resp.Password, err
}
func (dr *databasePluginRPCClient) RenewUser(_ context.Context, statements Statements, username string, expiration time.Time) error {
req := RenewUserRequestRPC{
Statements: statements,
Username: username,
Expiration: expiration,
}
err := dr.client.Call("Plugin.RenewUser", req, &struct{}{})
return err
}
func (dr *databasePluginRPCClient) RevokeUser(_ context.Context, statements Statements, username string) error {
req := RevokeUserRequestRPC{
Statements: statements,
Username: username,
}
err := dr.client.Call("Plugin.RevokeUser", req, &struct{}{})
return err
}
func (dr *databasePluginRPCClient) Initialize(_ context.Context, conf map[string]interface{}, verifyConnection bool) error {
req := InitializeRequestRPC{
Config: conf,
VerifyConnection: verifyConnection,
}
err := dr.client.Call("Plugin.Initialize", req, &struct{}{})
return err
}
func (dr *databasePluginRPCClient) Close() error {
err := dr.client.Call("Plugin.Close", struct{}{}, &struct{}{})
return err
}
// ---- RPC Request Args Domain ----
type InitializeRequestRPC struct {
Config map[string]interface{}
VerifyConnection bool
}
type CreateUserRequestRPC struct {
Statements Statements
UsernameConfig UsernameConfig
Expiration time.Time
}
type RenewUserRequestRPC struct {
Statements Statements
Username string
Expiration time.Time
}
type RevokeUserRequestRPC struct {
Statements Statements
Username string
}

View file

@ -1,10 +1,13 @@
package dbplugin package dbplugin
import ( import (
"context"
"fmt" "fmt"
"net/rpc" "net/rpc"
"time" "time"
"google.golang.org/grpc"
"github.com/hashicorp/go-plugin" "github.com/hashicorp/go-plugin"
"github.com/hashicorp/vault/helper/pluginutil" "github.com/hashicorp/vault/helper/pluginutil"
log "github.com/mgutz/logxi/v1" log "github.com/mgutz/logxi/v1"
@ -13,29 +16,14 @@ import (
// Database is the interface that all database objects must implement. // Database is the interface that all database objects must implement.
type Database interface { type Database interface {
Type() (string, error) Type() (string, error)
CreateUser(statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) CreateUser(ctx context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error)
RenewUser(statements Statements, username string, expiration time.Time) error RenewUser(ctx context.Context, statements Statements, username string, expiration time.Time) error
RevokeUser(statements Statements, username string) error RevokeUser(ctx context.Context, statements Statements, username string) error
Initialize(config map[string]interface{}, verifyConnection bool) error Initialize(ctx context.Context, config map[string]interface{}, verifyConnection bool) error
Close() error Close() error
} }
// Statements set in role creation and passed into the database type's functions.
type Statements struct {
CreationStatements string `json:"creation_statments" mapstructure:"creation_statements" structs:"creation_statments"`
RevocationStatements string `json:"revocation_statements" mapstructure:"revocation_statements" structs:"revocation_statements"`
RollbackStatements string `json:"rollback_statements" mapstructure:"rollback_statements" structs:"rollback_statements"`
RenewStatements string `json:"renew_statements" mapstructure:"renew_statements" structs:"renew_statements"`
}
// UsernameConfig is used to configure prefixes for the username to be
// generated.
type UsernameConfig struct {
DisplayName string
RoleName string
}
// PluginFactory is used to build plugin database types. It wraps the database // PluginFactory is used to build plugin database types. It wraps the database
// object in a logging and metrics middleware. // object in a logging and metrics middleware.
func PluginFactory(pluginName string, sys pluginutil.LookRunnerUtil, logger log.Logger) (Database, error) { func PluginFactory(pluginName string, sys pluginutil.LookRunnerUtil, logger log.Logger) (Database, error) {
@ -45,6 +33,7 @@ func PluginFactory(pluginName string, sys pluginutil.LookRunnerUtil, logger log.
return nil, err return nil, err
} }
var transport string
var db Database var db Database
if pluginRunner.Builtin { if pluginRunner.Builtin {
// Plugin is builtin so we can retrieve an instance of the interface // Plugin is builtin so we can retrieve an instance of the interface
@ -60,12 +49,24 @@ func PluginFactory(pluginName string, sys pluginutil.LookRunnerUtil, logger log.
return nil, fmt.Errorf("unsuported database type: %s", pluginName) return nil, fmt.Errorf("unsuported database type: %s", pluginName)
} }
transport = "builtin"
} else { } else {
// create a DatabasePluginClient instance // create a DatabasePluginClient instance
db, err = newPluginClient(sys, pluginRunner, logger) db, err = newPluginClient(sys, pluginRunner, logger)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// Switch on the underlying database client type to get the transport
// method.
switch db.(*DatabasePluginClient).Database.(type) {
case *gRPCClient:
transport = "gRPC"
case *databasePluginRPCClient:
transport = "netRPC"
}
} }
typeStr, err := db.Type() typeStr, err := db.Type()
@ -82,9 +83,10 @@ func PluginFactory(pluginName string, sys pluginutil.LookRunnerUtil, logger log.
// Wrap with tracing middleware // Wrap with tracing middleware
if logger.IsTrace() { if logger.IsTrace() {
db = &databaseTracingMiddleware{ db = &databaseTracingMiddleware{
next: db, transport: transport,
typeStr: typeStr, next: db,
logger: logger, typeStr: typeStr,
logger: logger,
} }
} }
@ -115,33 +117,14 @@ func (DatabasePlugin) Client(b *plugin.MuxBroker, c *rpc.Client) (interface{}, e
return &databasePluginRPCClient{client: c}, nil return &databasePluginRPCClient{client: c}, nil
} }
// ---- RPC Request Args Domain ---- func (d DatabasePlugin) GRPCServer(s *grpc.Server) error {
RegisterDatabaseServer(s, &gRPCServer{impl: d.impl})
type InitializeRequest struct { return nil
Config map[string]interface{}
VerifyConnection bool
} }
type CreateUserRequest struct { func (DatabasePlugin) GRPCClient(c *grpc.ClientConn) (interface{}, error) {
Statements Statements return &gRPCClient{
UsernameConfig UsernameConfig client: NewDatabaseClient(c),
Expiration time.Time clientConn: c,
} }, nil
type RenewUserRequest struct {
Statements Statements
Username string
Expiration time.Time
}
type RevokeUserRequest struct {
Statements Statements
Username string
}
// ---- RPC Response Args Domain ----
type CreateUserResponse struct {
Username string
Password string
} }

View file

@ -1,11 +1,13 @@
package dbplugin_test package dbplugin_test
import ( import (
"context"
"errors" "errors"
"os" "os"
"testing" "testing"
"time" "time"
plugin "github.com/hashicorp/go-plugin"
"github.com/hashicorp/vault/builtin/logical/database/dbplugin" "github.com/hashicorp/vault/builtin/logical/database/dbplugin"
"github.com/hashicorp/vault/helper/pluginutil" "github.com/hashicorp/vault/helper/pluginutil"
vaulthttp "github.com/hashicorp/vault/http" vaulthttp "github.com/hashicorp/vault/http"
@ -20,7 +22,7 @@ type mockPlugin struct {
} }
func (m *mockPlugin) Type() (string, error) { return "mock", nil } func (m *mockPlugin) Type() (string, error) { return "mock", nil }
func (m *mockPlugin) CreateUser(statements dbplugin.Statements, usernameConf dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) { func (m *mockPlugin) CreateUser(_ context.Context, statements dbplugin.Statements, usernameConf dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
err = errors.New("err") err = errors.New("err")
if usernameConf.DisplayName == "" || expiration.IsZero() { if usernameConf.DisplayName == "" || expiration.IsZero() {
return "", "", err return "", "", err
@ -34,7 +36,7 @@ func (m *mockPlugin) CreateUser(statements dbplugin.Statements, usernameConf dbp
return usernameConf.DisplayName, "test", nil return usernameConf.DisplayName, "test", nil
} }
func (m *mockPlugin) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error { func (m *mockPlugin) RenewUser(_ context.Context, statements dbplugin.Statements, username string, expiration time.Time) error {
err := errors.New("err") err := errors.New("err")
if username == "" || expiration.IsZero() { if username == "" || expiration.IsZero() {
return err return err
@ -46,7 +48,7 @@ func (m *mockPlugin) RenewUser(statements dbplugin.Statements, username string,
return nil return nil
} }
func (m *mockPlugin) RevokeUser(statements dbplugin.Statements, username string) error { func (m *mockPlugin) RevokeUser(_ context.Context, statements dbplugin.Statements, username string) error {
err := errors.New("err") err := errors.New("err")
if username == "" { if username == "" {
return err return err
@ -59,7 +61,7 @@ func (m *mockPlugin) RevokeUser(statements dbplugin.Statements, username string)
delete(m.users, username) delete(m.users, username)
return nil return nil
} }
func (m *mockPlugin) Initialize(conf map[string]interface{}, _ bool) error { func (m *mockPlugin) Initialize(_ context.Context, conf map[string]interface{}, _ bool) error {
err := errors.New("err") err := errors.New("err")
if len(conf) != 1 { if len(conf) != 1 {
return err return err
@ -80,14 +82,15 @@ func getCluster(t *testing.T) (*vault.TestCluster, logical.SystemView) {
cores := cluster.Cores cores := cluster.Cores
sys := vault.TestDynamicSystemView(cores[0].Core) sys := vault.TestDynamicSystemView(cores[0].Core)
vault.TestAddTestPlugin(t, cores[0].Core, "test-plugin", "TestPlugin_Main") vault.TestAddTestPlugin(t, cores[0].Core, "test-plugin", "TestPlugin_GRPC_Main")
vault.TestAddTestPlugin(t, cores[0].Core, "test-plugin-netRPC", "TestPlugin_NetRPC_Main")
return cluster, sys return cluster, sys
} }
// This is not an actual test case, it's a helper function that will be executed // This is not an actual test case, it's a helper function that will be executed
// by the go-plugin client via an exec call. // by the go-plugin client via an exec call.
func TestPlugin_Main(t *testing.T) { func TestPlugin_GRPC_Main(t *testing.T) {
if os.Getenv(pluginutil.PluginUnwrapTokenEnv) == "" { if os.Getenv(pluginutil.PluginUnwrapTokenEnv) == "" {
return return
} }
@ -105,6 +108,30 @@ func TestPlugin_Main(t *testing.T) {
plugins.Serve(plugin, apiClientMeta.GetTLSConfig()) plugins.Serve(plugin, apiClientMeta.GetTLSConfig())
} }
// This is not an actual test case, it's a helper function that will be executed
// by the go-plugin client via an exec call.
func TestPlugin_NetRPC_Main(t *testing.T) {
if os.Getenv(pluginutil.PluginUnwrapTokenEnv) == "" {
return
}
p := &mockPlugin{
users: make(map[string][]string),
}
args := []string{"--tls-skip-verify=true"}
apiClientMeta := &pluginutil.APIClientMeta{}
flags := apiClientMeta.FlagSet()
flags.Parse(args)
tlsProvider := pluginutil.VaultPluginTLSProvider(apiClientMeta.GetTLSConfig())
serveConf := dbplugin.ServeConfig(p, tlsProvider)
serveConf.GRPCServer = nil
plugin.Serve(serveConf)
}
func TestPlugin_Initialize(t *testing.T) { func TestPlugin_Initialize(t *testing.T) {
cluster, sys := getCluster(t) cluster, sys := getCluster(t)
defer cluster.Cleanup() defer cluster.Cleanup()
@ -118,7 +145,7 @@ func TestPlugin_Initialize(t *testing.T) {
"test": 1, "test": 1,
} }
err = dbRaw.Initialize(connectionDetails, true) err = dbRaw.Initialize(context.Background(), connectionDetails, true)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -143,7 +170,7 @@ func TestPlugin_CreateUser(t *testing.T) {
"test": 1, "test": 1,
} }
err = db.Initialize(connectionDetails, true) err = db.Initialize(context.Background(), connectionDetails, true)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -153,7 +180,7 @@ func TestPlugin_CreateUser(t *testing.T) {
RoleName: "test", RoleName: "test",
} }
us, pw, err := db.CreateUser(dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute)) us, pw, err := db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -163,7 +190,7 @@ func TestPlugin_CreateUser(t *testing.T) {
// try and save the same user again to verify it saved the first time, this // try and save the same user again to verify it saved the first time, this
// should return an error // should return an error
_, _, err = db.CreateUser(dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute)) _, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
if err == nil { if err == nil {
t.Fatal("expected an error, user wasn't created correctly") t.Fatal("expected an error, user wasn't created correctly")
} }
@ -182,7 +209,7 @@ func TestPlugin_RenewUser(t *testing.T) {
connectionDetails := map[string]interface{}{ connectionDetails := map[string]interface{}{
"test": 1, "test": 1,
} }
err = db.Initialize(connectionDetails, true) err = db.Initialize(context.Background(), connectionDetails, true)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -192,12 +219,12 @@ func TestPlugin_RenewUser(t *testing.T) {
RoleName: "test", RoleName: "test",
} }
us, _, err := db.CreateUser(dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute)) us, _, err := db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
err = db.RenewUser(dbplugin.Statements{}, us, time.Now().Add(time.Minute)) err = db.RenewUser(context.Background(), dbplugin.Statements{}, us, time.Now().Add(time.Minute))
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -216,7 +243,7 @@ func TestPlugin_RevokeUser(t *testing.T) {
connectionDetails := map[string]interface{}{ connectionDetails := map[string]interface{}{
"test": 1, "test": 1,
} }
err = db.Initialize(connectionDetails, true) err = db.Initialize(context.Background(), connectionDetails, true)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -226,19 +253,159 @@ func TestPlugin_RevokeUser(t *testing.T) {
RoleName: "test", RoleName: "test",
} }
us, _, err := db.CreateUser(dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute)) us, _, err := db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
// Test default revoke statememts // Test default revoke statememts
err = db.RevokeUser(dbplugin.Statements{}, us) err = db.RevokeUser(context.Background(), dbplugin.Statements{}, us)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
// Try adding the same username back so we can verify it was removed // Try adding the same username back so we can verify it was removed
_, _, err = db.CreateUser(dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute)) _, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
if err != nil {
t.Fatalf("err: %s", err)
}
}
// Test the code is still compatible with an old netRPC plugin
func TestPlugin_NetRPC_Initialize(t *testing.T) {
cluster, sys := getCluster(t)
defer cluster.Cleanup()
dbRaw, err := dbplugin.PluginFactory("test-plugin-netRPC", sys, &log.NullLogger{})
if err != nil {
t.Fatalf("err: %s", err)
}
connectionDetails := map[string]interface{}{
"test": 1,
}
err = dbRaw.Initialize(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
err = dbRaw.Close()
if err != nil {
t.Fatalf("err: %s", err)
}
}
func TestPlugin_NetRPC_CreateUser(t *testing.T) {
cluster, sys := getCluster(t)
defer cluster.Cleanup()
db, err := dbplugin.PluginFactory("test-plugin-netRPC", sys, &log.NullLogger{})
if err != nil {
t.Fatalf("err: %s", err)
}
defer db.Close()
connectionDetails := map[string]interface{}{
"test": 1,
}
err = db.Initialize(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
usernameConf := dbplugin.UsernameConfig{
DisplayName: "test",
RoleName: "test",
}
us, pw, err := db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
if err != nil {
t.Fatalf("err: %s", err)
}
if us != "test" || pw != "test" {
t.Fatal("expected username and password to be 'test'")
}
// try and save the same user again to verify it saved the first time, this
// should return an error
_, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
if err == nil {
t.Fatal("expected an error, user wasn't created correctly")
}
}
func TestPlugin_NetRPC_RenewUser(t *testing.T) {
cluster, sys := getCluster(t)
defer cluster.Cleanup()
db, err := dbplugin.PluginFactory("test-plugin-netRPC", sys, &log.NullLogger{})
if err != nil {
t.Fatalf("err: %s", err)
}
defer db.Close()
connectionDetails := map[string]interface{}{
"test": 1,
}
err = db.Initialize(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
usernameConf := dbplugin.UsernameConfig{
DisplayName: "test",
RoleName: "test",
}
us, _, err := db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
if err != nil {
t.Fatalf("err: %s", err)
}
err = db.RenewUser(context.Background(), dbplugin.Statements{}, us, time.Now().Add(time.Minute))
if err != nil {
t.Fatalf("err: %s", err)
}
}
func TestPlugin_NetRPC_RevokeUser(t *testing.T) {
cluster, sys := getCluster(t)
defer cluster.Cleanup()
db, err := dbplugin.PluginFactory("test-plugin-netRPC", sys, &log.NullLogger{})
if err != nil {
t.Fatalf("err: %s", err)
}
defer db.Close()
connectionDetails := map[string]interface{}{
"test": 1,
}
err = db.Initialize(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
usernameConf := dbplugin.UsernameConfig{
DisplayName: "test",
RoleName: "test",
}
us, _, err := db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
if err != nil {
t.Fatalf("err: %s", err)
}
// Test default revoke statememts
err = db.RevokeUser(context.Background(), dbplugin.Statements{}, us)
if err != nil {
t.Fatalf("err: %s", err)
}
// Try adding the same username back so we can verify it was removed
_, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }

View file

@ -10,6 +10,10 @@ import (
// Database implementation in a databasePluginRPCServer object and starts a // Database implementation in a databasePluginRPCServer object and starts a
// RPC server. // RPC server.
func Serve(db Database, tlsProvider func() (*tls.Config, error)) { func Serve(db Database, tlsProvider func() (*tls.Config, error)) {
plugin.Serve(ServeConfig(db, tlsProvider))
}
func ServeConfig(db Database, tlsProvider func() (*tls.Config, error)) *plugin.ServeConfig {
dbPlugin := &DatabasePlugin{ dbPlugin := &DatabasePlugin{
impl: db, impl: db,
} }
@ -19,53 +23,10 @@ func Serve(db Database, tlsProvider func() (*tls.Config, error)) {
"database": dbPlugin, "database": dbPlugin,
} }
plugin.Serve(&plugin.ServeConfig{ return &plugin.ServeConfig{
HandshakeConfig: handshakeConfig, HandshakeConfig: handshakeConfig,
Plugins: pluginMap, Plugins: pluginMap,
TLSProvider: tlsProvider, TLSProvider: tlsProvider,
}) GRPCServer: plugin.DefaultGRPCServer,
} }
// ---- RPC server domain ----
// databasePluginRPCServer implements an RPC version of Database and is run
// inside a plugin. It wraps an underlying implementation of Database.
type databasePluginRPCServer struct {
impl Database
}
func (ds *databasePluginRPCServer) Type(_ struct{}, resp *string) error {
var err error
*resp, err = ds.impl.Type()
return err
}
func (ds *databasePluginRPCServer) CreateUser(args *CreateUserRequest, resp *CreateUserResponse) error {
var err error
resp.Username, resp.Password, err = ds.impl.CreateUser(args.Statements, args.UsernameConfig, args.Expiration)
return err
}
func (ds *databasePluginRPCServer) RenewUser(args *RenewUserRequest, _ *struct{}) error {
err := ds.impl.RenewUser(args.Statements, args.Username, args.Expiration)
return err
}
func (ds *databasePluginRPCServer) RevokeUser(args *RevokeUserRequest, _ *struct{}) error {
err := ds.impl.RevokeUser(args.Statements, args.Username)
return err
}
func (ds *databasePluginRPCServer) Initialize(args *InitializeRequest, _ *struct{}) error {
err := ds.impl.Initialize(args.Config, args.VerifyConnection)
return err
}
func (ds *databasePluginRPCServer) Close(_ struct{}, _ *struct{}) error {
ds.impl.Close()
return nil
} }

View file

@ -1,6 +1,7 @@
package database package database
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
@ -62,7 +63,7 @@ func (b *databaseBackend) pathConnectionReset() framework.OperationFunc {
b.clearConnection(name) b.clearConnection(name)
// Execute plugin again, we don't need the object so throw away. // Execute plugin again, we don't need the object so throw away.
_, err := b.createDBObj(req.Storage, name) _, err := b.createDBObj(context.TODO(), req.Storage, name)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -230,7 +231,7 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc {
return logical.ErrorResponse(fmt.Sprintf("error creating database object: %s", err)), nil return logical.ErrorResponse(fmt.Sprintf("error creating database object: %s", err)), nil
} }
err = db.Initialize(config.ConnectionDetails, verifyConnection) err = db.Initialize(context.TODO(), config.ConnectionDetails, verifyConnection)
if err != nil { if err != nil {
db.Close() db.Close()
return logical.ErrorResponse(fmt.Sprintf("error creating database object: %s", err)), nil return logical.ErrorResponse(fmt.Sprintf("error creating database object: %s", err)), nil

View file

@ -1,6 +1,7 @@
package database package database
import ( import (
"context"
"fmt" "fmt"
"time" "time"
@ -66,7 +67,7 @@ func (b *databaseBackend) pathCredsCreateRead() framework.OperationFunc {
unlockFunc = b.Unlock unlockFunc = b.Unlock
// Create a new DB object // Create a new DB object
db, err = b.createDBObj(req.Storage, role.DBName) db, err = b.createDBObj(context.TODO(), req.Storage, role.DBName)
if err != nil { if err != nil {
unlockFunc() unlockFunc()
return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err) return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err)
@ -81,7 +82,7 @@ func (b *databaseBackend) pathCredsCreateRead() framework.OperationFunc {
} }
// Create the user // Create the user
username, password, err := db.CreateUser(role.Statements, usernameConfig, expiration) username, password, err := db.CreateUser(context.TODO(), role.Statements, usernameConfig, expiration)
// Unlock // Unlock
unlockFunc() unlockFunc()
if err != nil { if err != nil {

View file

@ -1,6 +1,7 @@
package database package database
import ( import (
"context"
"fmt" "fmt"
"github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical"
@ -60,7 +61,7 @@ func (b *databaseBackend) secretCredsRenew() framework.OperationFunc {
unlockFunc = b.Unlock unlockFunc = b.Unlock
// Create a new DB object // Create a new DB object
db, err = b.createDBObj(req.Storage, role.DBName) db, err = b.createDBObj(context.TODO(), req.Storage, role.DBName)
if err != nil { if err != nil {
unlockFunc() unlockFunc()
return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err) return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err)
@ -69,7 +70,7 @@ func (b *databaseBackend) secretCredsRenew() framework.OperationFunc {
// Make sure we increase the VALID UNTIL endpoint for this user. // Make sure we increase the VALID UNTIL endpoint for this user.
if expireTime := resp.Secret.ExpirationTime(); !expireTime.IsZero() { if expireTime := resp.Secret.ExpirationTime(); !expireTime.IsZero() {
err := db.RenewUser(role.Statements, username, expireTime) err := db.RenewUser(context.TODO(), role.Statements, username, expireTime)
// Unlock // Unlock
unlockFunc() unlockFunc()
if err != nil { if err != nil {
@ -119,14 +120,14 @@ func (b *databaseBackend) secretCredsRevoke() framework.OperationFunc {
unlockFunc = b.Unlock unlockFunc = b.Unlock
// Create a new DB object // Create a new DB object
db, err = b.createDBObj(req.Storage, role.DBName) db, err = b.createDBObj(context.TODO(), req.Storage, role.DBName)
if err != nil { if err != nil {
unlockFunc() unlockFunc()
return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err) return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err)
} }
} }
err = db.RevokeUser(role.Statements, username) err = db.RevokeUser(context.TODO(), role.Statements, username)
// Unlock // Unlock
unlockFunc() unlockFunc()
if err != nil { if err != nil {

View file

@ -43,6 +43,9 @@ type PolicyRequest struct {
// Whether to upsert // Whether to upsert
Upsert bool Upsert bool
// Whether to allow plaintext backup
AllowPlaintextBackup bool
} }
type LockManager struct { type LockManager struct {
@ -378,10 +381,11 @@ func (lm *LockManager) getPolicyCommon(req PolicyRequest, lockType bool) (*Polic
} }
p = &Policy{ p = &Policy{
Name: req.Name, Name: req.Name,
Type: req.KeyType, Type: req.KeyType,
Derived: req.Derived, Derived: req.Derived,
Exportable: req.Exportable, Exportable: req.Exportable,
AllowPlaintextBackup: req.AllowPlaintextBackup,
} }
if req.Derived { if req.Derived {
p.KDF = Kdf_hkdf_sha256 p.KDF = Kdf_hkdf_sha256

View file

@ -119,6 +119,10 @@ func (r *PluginRunner) runCommon(wrapper RunnerUtil, pluginMap map[string]plugin
SecureConfig: secureConfig, SecureConfig: secureConfig,
TLSConfig: clientTLSConfig, TLSConfig: clientTLSConfig,
Logger: namedLogger, Logger: namedLogger,
AllowedProtocols: []plugin.Protocol{
plugin.ProtocolNetRPC,
plugin.ProtocolGRPC,
},
} }
client := plugin.NewClient(clientConfig) client := plugin.NewClient(clientConfig)

View file

@ -192,8 +192,7 @@ func TestBackendHandleRequest_helpRoot(t *testing.T) {
func TestBackendHandleRequest_renewAuth(t *testing.T) { func TestBackendHandleRequest_renewAuth(t *testing.T) {
b := &Backend{} b := &Backend{}
resp, err := b.HandleRequest(logical.RenewAuthRequest( resp, err := b.HandleRequest(logical.RenewAuthRequest("/foo", &logical.Auth{}, nil))
"/foo", &logical.Auth{}, nil))
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -213,8 +212,7 @@ func TestBackendHandleRequest_renewAuthCallback(t *testing.T) {
AuthRenew: callback, AuthRenew: callback,
} }
_, err := b.HandleRequest(logical.RenewAuthRequest( _, err := b.HandleRequest(logical.RenewAuthRequest("/foo", &logical.Auth{}, nil))
"/foo", &logical.Auth{}, nil))
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -237,8 +235,7 @@ func TestBackendHandleRequest_renew(t *testing.T) {
Secrets: []*Secret{secret}, Secrets: []*Secret{secret},
} }
_, err := b.HandleRequest(logical.RenewRequest( _, err := b.HandleRequest(logical.RenewRequest("/foo", secret.Response(nil, nil).Secret, nil))
"/foo", secret.Response(nil, nil).Secret, nil))
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -293,8 +290,7 @@ func TestBackendHandleRequest_revoke(t *testing.T) {
Secrets: []*Secret{secret}, Secrets: []*Secret{secret},
} }
_, err := b.HandleRequest(logical.RevokeRequest( _, err := b.HandleRequest(logical.RevokeRequest("/foo", secret.Response(nil, nil).Secret, nil))
"/foo", secret.Response(nil, nil).Secret, nil))
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }

View file

@ -8,7 +8,8 @@ import (
) )
// LeaseExtend returns an OperationFunc that can be used to simply extend the // LeaseExtend returns an OperationFunc that can be used to simply extend the
// lease of the auth/secret for the duration that was requested. // lease of the auth/secret for the duration that was requested. The parameters
// provided are used to determine the lease's new TTL value.
// //
// backendIncrement is the backend's requested increment -- perhaps from a user // backendIncrement is the backend's requested increment -- perhaps from a user
// request, perhaps from a role/config value. If not set, uses the mount/system // request, perhaps from a role/config value. If not set, uses the mount/system

View file

@ -200,8 +200,7 @@ func (r *Request) SetLastRemoteWAL(last uint64) {
} }
// RenewRequest creates the structure of the renew request. // RenewRequest creates the structure of the renew request.
func RenewRequest( func RenewRequest(path string, secret *Secret, data map[string]interface{}) *Request {
path string, secret *Secret, data map[string]interface{}) *Request {
return &Request{ return &Request{
Operation: RenewOperation, Operation: RenewOperation,
Path: path, Path: path,
@ -211,8 +210,7 @@ func RenewRequest(
} }
// RenewAuthRequest creates the structure of the renew request for an auth. // RenewAuthRequest creates the structure of the renew request for an auth.
func RenewAuthRequest( func RenewAuthRequest(path string, auth *Auth, data map[string]interface{}) *Request {
path string, auth *Auth, data map[string]interface{}) *Request {
return &Request{ return &Request{
Operation: RenewOperation, Operation: RenewOperation,
Path: path, Path: path,
@ -222,8 +220,7 @@ func RenewAuthRequest(
} }
// RevokeRequest creates the structure of the revoke request. // RevokeRequest creates the structure of the revoke request.
func RevokeRequest( func RevokeRequest(path string, secret *Secret, data map[string]interface{}) *Request {
path string, secret *Secret, data map[string]interface{}) *Request {
return &Request{ return &Request{
Operation: RevokeOperation, Operation: RevokeOperation,
Path: path, Path: path,

View file

@ -1,6 +1,7 @@
package cassandra package cassandra
import ( import (
"context"
"strings" "strings"
"time" "time"
@ -21,6 +22,8 @@ const (
cassandraTypeName = "cassandra" cassandraTypeName = "cassandra"
) )
var _ dbplugin.Database = &Cassandra{}
// Cassandra is an implementation of Database interface // Cassandra is an implementation of Database interface
type Cassandra struct { type Cassandra struct {
connutil.ConnectionProducer connutil.ConnectionProducer
@ -64,8 +67,8 @@ func (c *Cassandra) Type() (string, error) {
return cassandraTypeName, nil return cassandraTypeName, nil
} }
func (c *Cassandra) getConnection() (*gocql.Session, error) { func (c *Cassandra) getConnection(ctx context.Context) (*gocql.Session, error) {
session, err := c.Connection() session, err := c.Connection(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -75,13 +78,13 @@ func (c *Cassandra) getConnection() (*gocql.Session, error) {
// CreateUser generates the username/password on the underlying Cassandra secret backend as instructed by // CreateUser generates the username/password on the underlying Cassandra secret backend as instructed by
// the CreationStatement provided. // the CreationStatement provided.
func (c *Cassandra) CreateUser(statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) { func (c *Cassandra) CreateUser(ctx context.Context, statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
// Grab the lock // Grab the lock
c.Lock() c.Lock()
defer c.Unlock() defer c.Unlock()
// Get the connection // Get the connection
session, err := c.getConnection() session, err := c.getConnection(ctx)
if err != nil { if err != nil {
return "", "", err return "", "", err
} }
@ -138,18 +141,18 @@ func (c *Cassandra) CreateUser(statements dbplugin.Statements, usernameConfig db
} }
// RenewUser is not supported on Cassandra, so this is a no-op. // RenewUser is not supported on Cassandra, so this is a no-op.
func (c *Cassandra) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error { func (c *Cassandra) RenewUser(ctx context.Context, statements dbplugin.Statements, username string, expiration time.Time) error {
// NOOP // NOOP
return nil return nil
} }
// RevokeUser attempts to drop the specified user. // RevokeUser attempts to drop the specified user.
func (c *Cassandra) RevokeUser(statements dbplugin.Statements, username string) error { func (c *Cassandra) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error {
// Grab the lock // Grab the lock
c.Lock() c.Lock()
defer c.Unlock() defer c.Unlock()
session, err := c.getConnection() session, err := c.getConnection(ctx)
if err != nil { if err != nil {
return err return err
} }

View file

@ -1,6 +1,7 @@
package cassandra package cassandra
import ( import (
"context"
"os" "os"
"strconv" "strconv"
"testing" "testing"
@ -89,7 +90,7 @@ func TestCassandra_Initialize(t *testing.T) {
db := dbRaw.(*Cassandra) db := dbRaw.(*Cassandra)
connProducer := db.ConnectionProducer.(*cassandraConnectionProducer) connProducer := db.ConnectionProducer.(*cassandraConnectionProducer)
err := db.Initialize(connectionDetails, true) err := db.Initialize(context.Background(), connectionDetails, true)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -112,7 +113,7 @@ func TestCassandra_Initialize(t *testing.T) {
"protocol_version": "4", "protocol_version": "4",
} }
err = db.Initialize(connectionDetails, true) err = db.Initialize(context.Background(), connectionDetails, true)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -135,7 +136,7 @@ func TestCassandra_CreateUser(t *testing.T) {
dbRaw, _ := New() dbRaw, _ := New()
db := dbRaw.(*Cassandra) db := dbRaw.(*Cassandra)
err := db.Initialize(connectionDetails, true) err := db.Initialize(context.Background(), connectionDetails, true)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -149,7 +150,7 @@ func TestCassandra_CreateUser(t *testing.T) {
RoleName: "test", RoleName: "test",
} }
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute)) username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -176,7 +177,7 @@ func TestMyCassandra_RenewUser(t *testing.T) {
dbRaw, _ := New() dbRaw, _ := New()
db := dbRaw.(*Cassandra) db := dbRaw.(*Cassandra)
err := db.Initialize(connectionDetails, true) err := db.Initialize(context.Background(), connectionDetails, true)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -190,7 +191,7 @@ func TestMyCassandra_RenewUser(t *testing.T) {
RoleName: "test", RoleName: "test",
} }
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute)) username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -199,7 +200,7 @@ func TestMyCassandra_RenewUser(t *testing.T) {
t.Fatalf("Could not connect with new credentials: %s", err) t.Fatalf("Could not connect with new credentials: %s", err)
} }
err = db.RenewUser(statements, username, time.Now().Add(time.Minute)) err = db.RenewUser(context.Background(), statements, username, time.Now().Add(time.Minute))
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -222,7 +223,7 @@ func TestCassandra_RevokeUser(t *testing.T) {
dbRaw, _ := New() dbRaw, _ := New()
db := dbRaw.(*Cassandra) db := dbRaw.(*Cassandra)
err := db.Initialize(connectionDetails, true) err := db.Initialize(context.Background(), connectionDetails, true)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -236,7 +237,7 @@ func TestCassandra_RevokeUser(t *testing.T) {
RoleName: "test", RoleName: "test",
} }
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute)) username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -246,7 +247,7 @@ func TestCassandra_RevokeUser(t *testing.T) {
} }
// Test default revoke statememts // Test default revoke statememts
err = db.RevokeUser(statements, username) err = db.RevokeUser(context.Background(), statements, username)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }

View file

@ -1,6 +1,7 @@
package cassandra package cassandra
import ( import (
"context"
"crypto/tls" "crypto/tls"
"fmt" "fmt"
"strings" "strings"
@ -43,7 +44,7 @@ type cassandraConnectionProducer struct {
sync.Mutex sync.Mutex
} }
func (c *cassandraConnectionProducer) Initialize(conf map[string]interface{}, verifyConnection bool) error { func (c *cassandraConnectionProducer) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error {
c.Lock() c.Lock()
defer c.Unlock() defer c.Unlock()
@ -106,7 +107,7 @@ func (c *cassandraConnectionProducer) Initialize(conf map[string]interface{}, ve
c.Initialized = true c.Initialized = true
if verifyConnection { if verifyConnection {
if _, err := c.Connection(); err != nil { if _, err := c.Connection(ctx); err != nil {
return fmt.Errorf("error verifying connection: %s", err) return fmt.Errorf("error verifying connection: %s", err)
} }
} }
@ -114,7 +115,7 @@ func (c *cassandraConnectionProducer) Initialize(conf map[string]interface{}, ve
return nil return nil
} }
func (c *cassandraConnectionProducer) Connection() (interface{}, error) { func (c *cassandraConnectionProducer) Connection(_ context.Context) (interface{}, error) {
if !c.Initialized { if !c.Initialized {
return nil, connutil.ErrNotInitialized return nil, connutil.ErrNotInitialized
} }

View file

@ -1,6 +1,7 @@
package hana package hana
import ( import (
"context"
"database/sql" "database/sql"
"fmt" "fmt"
"strings" "strings"
@ -26,6 +27,8 @@ type HANA struct {
credsutil.CredentialsProducer credsutil.CredentialsProducer
} }
var _ dbplugin.Database = &HANA{}
// New implements builtinplugins.BuiltinFactory // New implements builtinplugins.BuiltinFactory
func New() (interface{}, error) { func New() (interface{}, error) {
connProducer := &connutil.SQLConnectionProducer{} connProducer := &connutil.SQLConnectionProducer{}
@ -63,8 +66,8 @@ func (h *HANA) Type() (string, error) {
return hanaTypeName, nil return hanaTypeName, nil
} }
func (h *HANA) getConnection() (*sql.DB, error) { func (h *HANA) getConnection(ctx context.Context) (*sql.DB, error) {
db, err := h.Connection() db, err := h.Connection(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -74,13 +77,13 @@ func (h *HANA) getConnection() (*sql.DB, error) {
// CreateUser generates the username/password on the underlying HANA secret backend // CreateUser generates the username/password on the underlying HANA secret backend
// as instructed by the CreationStatement provided. // as instructed by the CreationStatement provided.
func (h *HANA) CreateUser(statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) { func (h *HANA) CreateUser(ctx context.Context, statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
// Grab the lock // Grab the lock
h.Lock() h.Lock()
defer h.Unlock() defer h.Unlock()
// Get the connection // Get the connection
db, err := h.getConnection() db, err := h.getConnection(ctx)
if err != nil { if err != nil {
return "", "", err return "", "", err
} }
@ -117,7 +120,7 @@ func (h *HANA) CreateUser(statements dbplugin.Statements, usernameConfig dbplugi
} }
// Start a transaction // Start a transaction
tx, err := db.Begin() tx, err := db.BeginTx(ctx, nil)
if err != nil { if err != nil {
return "", "", err return "", "", err
} }
@ -130,7 +133,7 @@ func (h *HANA) CreateUser(statements dbplugin.Statements, usernameConfig dbplugi
continue continue
} }
stmt, err := tx.Prepare(dbutil.QueryHelper(query, map[string]string{ stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{
"name": username, "name": username,
"password": password, "password": password,
"expiration": expirationStr, "expiration": expirationStr,
@ -139,7 +142,7 @@ func (h *HANA) CreateUser(statements dbplugin.Statements, usernameConfig dbplugi
return "", "", err return "", "", err
} }
defer stmt.Close() defer stmt.Close()
if _, err := stmt.Exec(); err != nil { if _, err := stmt.ExecContext(ctx); err != nil {
return "", "", err return "", "", err
} }
} }
@ -153,15 +156,15 @@ func (h *HANA) CreateUser(statements dbplugin.Statements, usernameConfig dbplugi
} }
// Renewing hana user just means altering user's valid until property // Renewing hana user just means altering user's valid until property
func (h *HANA) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error { func (h *HANA) RenewUser(ctx context.Context, statements dbplugin.Statements, username string, expiration time.Time) error {
// Get connection // Get connection
db, err := h.getConnection() db, err := h.getConnection(ctx)
if err != nil { if err != nil {
return err return err
} }
// Start a transaction // Start a transaction
tx, err := db.Begin() tx, err := db.BeginTx(ctx, nil)
if err != nil { if err != nil {
return err return err
} }
@ -175,12 +178,12 @@ func (h *HANA) RenewUser(statements dbplugin.Statements, username string, expira
} }
// Renew user's valid until property field // Renew user's valid until property field
stmt, err := tx.Prepare("ALTER USER " + username + " VALID UNTIL " + "'" + expirationStr + "'") stmt, err := tx.PrepareContext(ctx, "ALTER USER "+username+" VALID UNTIL "+"'"+expirationStr+"'")
if err != nil { if err != nil {
return err return err
} }
defer stmt.Close() defer stmt.Close()
if _, err := stmt.Exec(); err != nil { if _, err := stmt.ExecContext(ctx); err != nil {
return err return err
} }
@ -193,20 +196,20 @@ func (h *HANA) RenewUser(statements dbplugin.Statements, username string, expira
} }
// Revoking hana user will deactivate user and try to perform a soft drop // Revoking hana user will deactivate user and try to perform a soft drop
func (h *HANA) RevokeUser(statements dbplugin.Statements, username string) error { func (h *HANA) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error {
// default revoke will be a soft drop on user // default revoke will be a soft drop on user
if statements.RevocationStatements == "" { if statements.RevocationStatements == "" {
return h.revokeUserDefault(username) return h.revokeUserDefault(ctx, username)
} }
// Get connection // Get connection
db, err := h.getConnection() db, err := h.getConnection(ctx)
if err != nil { if err != nil {
return err return err
} }
// Start a transaction // Start a transaction
tx, err := db.Begin() tx, err := db.BeginTx(ctx, nil)
if err != nil { if err != nil {
return err return err
} }
@ -219,14 +222,14 @@ func (h *HANA) RevokeUser(statements dbplugin.Statements, username string) error
continue continue
} }
stmt, err := tx.Prepare(dbutil.QueryHelper(query, map[string]string{ stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{
"name": username, "name": username,
})) }))
if err != nil { if err != nil {
return err return err
} }
defer stmt.Close() defer stmt.Close()
if _, err := stmt.Exec(); err != nil { if _, err := stmt.ExecContext(ctx); err != nil {
return err return err
} }
} }
@ -239,38 +242,38 @@ func (h *HANA) RevokeUser(statements dbplugin.Statements, username string) error
return nil return nil
} }
func (h *HANA) revokeUserDefault(username string) error { func (h *HANA) revokeUserDefault(ctx context.Context, username string) error {
// Get connection // Get connection
db, err := h.getConnection() db, err := h.getConnection(ctx)
if err != nil { if err != nil {
return err return err
} }
// Start a transaction // Start a transaction
tx, err := db.Begin() tx, err := db.BeginTx(ctx, nil)
if err != nil { if err != nil {
return err return err
} }
defer tx.Rollback() defer tx.Rollback()
// Disable server login for user // Disable server login for user
disableStmt, err := tx.Prepare(fmt.Sprintf("ALTER USER %s DEACTIVATE USER NOW", username)) disableStmt, err := tx.PrepareContext(ctx, fmt.Sprintf("ALTER USER %s DEACTIVATE USER NOW", username))
if err != nil { if err != nil {
return err return err
} }
defer disableStmt.Close() defer disableStmt.Close()
if _, err := disableStmt.Exec(); err != nil { if _, err := disableStmt.ExecContext(ctx); err != nil {
return err return err
} }
// Invalidates current sessions and performs soft drop (drop if no dependencies) // Invalidates current sessions and performs soft drop (drop if no dependencies)
// if hard drop is desired, custom revoke statements should be written for role // if hard drop is desired, custom revoke statements should be written for role
dropStmt, err := tx.Prepare(fmt.Sprintf("DROP USER %s RESTRICT", username)) dropStmt, err := tx.PrepareContext(ctx, fmt.Sprintf("DROP USER %s RESTRICT", username))
if err != nil { if err != nil {
return err return err
} }
defer dropStmt.Close() defer dropStmt.Close()
if _, err := dropStmt.Exec(); err != nil { if _, err := dropStmt.ExecContext(ctx); err != nil {
return err return err
} }

View file

@ -1,6 +1,7 @@
package hana package hana
import ( import (
"context"
"database/sql" "database/sql"
"fmt" "fmt"
"os" "os"
@ -25,7 +26,7 @@ func TestHANA_Initialize(t *testing.T) {
dbRaw, _ := New() dbRaw, _ := New()
db := dbRaw.(*HANA) db := dbRaw.(*HANA)
err := db.Initialize(connectionDetails, true) err := db.Initialize(context.Background(), connectionDetails, true)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -55,7 +56,7 @@ func TestHANA_CreateUser(t *testing.T) {
dbRaw, _ := New() dbRaw, _ := New()
db := dbRaw.(*HANA) db := dbRaw.(*HANA)
err := db.Initialize(connectionDetails, true) err := db.Initialize(context.Background(), connectionDetails, true)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -66,7 +67,7 @@ func TestHANA_CreateUser(t *testing.T) {
} }
// Test with no configured Creation Statememt // Test with no configured Creation Statememt
_, _, err = db.CreateUser(dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Hour)) _, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Hour))
if err == nil { if err == nil {
t.Fatal("Expected error when no creation statement is provided") t.Fatal("Expected error when no creation statement is provided")
} }
@ -75,7 +76,7 @@ func TestHANA_CreateUser(t *testing.T) {
CreationStatements: testHANARole, CreationStatements: testHANARole,
} }
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Hour)) username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Hour))
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -98,7 +99,7 @@ func TestHANA_RevokeUser(t *testing.T) {
dbRaw, _ := New() dbRaw, _ := New()
db := dbRaw.(*HANA) db := dbRaw.(*HANA)
err := db.Initialize(connectionDetails, true) err := db.Initialize(context.Background(), connectionDetails, true)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -113,7 +114,7 @@ func TestHANA_RevokeUser(t *testing.T) {
} }
// Test default revoke statememts // Test default revoke statememts
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Hour)) username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Hour))
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -121,7 +122,7 @@ func TestHANA_RevokeUser(t *testing.T) {
t.Fatalf("Could not connect with new credentials: %s", err) t.Fatalf("Could not connect with new credentials: %s", err)
} }
err = db.RevokeUser(statements, username) err = db.RevokeUser(context.Background(), statements, username)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -130,7 +131,7 @@ func TestHANA_RevokeUser(t *testing.T) {
} }
// Test custom revoke statememt // Test custom revoke statememt
username, password, err = db.CreateUser(statements, usernameConfig, time.Now().Add(time.Hour)) username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Hour))
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -139,7 +140,7 @@ func TestHANA_RevokeUser(t *testing.T) {
} }
statements.RevocationStatements = testHANADrop statements.RevocationStatements = testHANADrop
err = db.RevokeUser(statements, username) err = db.RevokeUser(context.Background(), statements, username)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }

View file

@ -1,6 +1,7 @@
package mongodb package mongodb
import ( import (
"context"
"crypto/tls" "crypto/tls"
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
@ -33,7 +34,7 @@ type mongoDBConnectionProducer struct {
} }
// Initialize parses connection configuration. // Initialize parses connection configuration.
func (c *mongoDBConnectionProducer) Initialize(conf map[string]interface{}, verifyConnection bool) error { func (c *mongoDBConnectionProducer) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error {
c.Lock() c.Lock()
defer c.Unlock() defer c.Unlock()
@ -75,7 +76,7 @@ func (c *mongoDBConnectionProducer) Initialize(conf map[string]interface{}, veri
c.Initialized = true c.Initialized = true
if verifyConnection { if verifyConnection {
if _, err := c.Connection(); err != nil { if _, err := c.Connection(ctx); err != nil {
return fmt.Errorf("error verifying connection: %s", err) return fmt.Errorf("error verifying connection: %s", err)
} }
@ -88,7 +89,7 @@ func (c *mongoDBConnectionProducer) Initialize(conf map[string]interface{}, veri
} }
// Connection creates a database connection. // Connection creates a database connection.
func (c *mongoDBConnectionProducer) Connection() (interface{}, error) { func (c *mongoDBConnectionProducer) Connection(_ context.Context) (interface{}, error) {
if !c.Initialized { if !c.Initialized {
return nil, connutil.ErrNotInitialized return nil, connutil.ErrNotInitialized
} }

View file

@ -1,6 +1,7 @@
package mongodb package mongodb
import ( import (
"context"
"io" "io"
"strings" "strings"
"time" "time"
@ -27,6 +28,8 @@ type MongoDB struct {
credsutil.CredentialsProducer credsutil.CredentialsProducer
} }
var _ dbplugin.Database = &MongoDB{}
// New returns a new MongoDB instance // New returns a new MongoDB instance
func New() (interface{}, error) { func New() (interface{}, error) {
connProducer := &mongoDBConnectionProducer{} connProducer := &mongoDBConnectionProducer{}
@ -63,8 +66,8 @@ func (m *MongoDB) Type() (string, error) {
return mongoDBTypeName, nil return mongoDBTypeName, nil
} }
func (m *MongoDB) getConnection() (*mgo.Session, error) { func (m *MongoDB) getConnection(ctx context.Context) (*mgo.Session, error) {
session, err := m.Connection() session, err := m.Connection(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -80,7 +83,7 @@ func (m *MongoDB) getConnection() (*mgo.Session, error) {
// //
// JSON Example: // JSON Example:
// { "db": "admin", "roles": [{ "role": "readWrite" }, {"role": "read", "db": "foo"}] } // { "db": "admin", "roles": [{ "role": "readWrite" }, {"role": "read", "db": "foo"}] }
func (m *MongoDB) CreateUser(statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) { func (m *MongoDB) CreateUser(ctx context.Context, statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
// Grab the lock // Grab the lock
m.Lock() m.Lock()
defer m.Unlock() defer m.Unlock()
@ -89,7 +92,7 @@ func (m *MongoDB) CreateUser(statements dbplugin.Statements, usernameConfig dbpl
return "", "", dbutil.ErrEmptyCreationStatement return "", "", dbutil.ErrEmptyCreationStatement
} }
session, err := m.getConnection() session, err := m.getConnection(ctx)
if err != nil { if err != nil {
return "", "", err return "", "", err
} }
@ -133,7 +136,7 @@ func (m *MongoDB) CreateUser(statements dbplugin.Statements, usernameConfig dbpl
if err := m.ConnectionProducer.Close(); err != nil { if err := m.ConnectionProducer.Close(); err != nil {
return "", "", errwrap.Wrapf("error closing EOF'd mongo connection: {{err}}", err) return "", "", errwrap.Wrapf("error closing EOF'd mongo connection: {{err}}", err)
} }
session, err := m.getConnection() session, err := m.getConnection(ctx)
if err != nil { if err != nil {
return "", "", err return "", "", err
} }
@ -149,15 +152,15 @@ func (m *MongoDB) CreateUser(statements dbplugin.Statements, usernameConfig dbpl
} }
// RenewUser is not supported on MongoDB, so this is a no-op. // RenewUser is not supported on MongoDB, so this is a no-op.
func (m *MongoDB) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error { func (m *MongoDB) RenewUser(ctx context.Context, statements dbplugin.Statements, username string, expiration time.Time) error {
// NOOP // NOOP
return nil return nil
} }
// RevokeUser drops the specified user from the authentication databse. If none is provided // RevokeUser drops the specified user from the authentication databse. If none is provided
// in the revocation statement, the default "admin" authentication database will be assumed. // in the revocation statement, the default "admin" authentication database will be assumed.
func (m *MongoDB) RevokeUser(statements dbplugin.Statements, username string) error { func (m *MongoDB) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error {
session, err := m.getConnection() session, err := m.getConnection(ctx)
if err != nil { if err != nil {
return err return err
} }
@ -188,7 +191,7 @@ func (m *MongoDB) RevokeUser(statements dbplugin.Statements, username string) er
if err := m.ConnectionProducer.Close(); err != nil { if err := m.ConnectionProducer.Close(); err != nil {
return errwrap.Wrapf("error closing EOF'd mongo connection: {{err}}", err) return errwrap.Wrapf("error closing EOF'd mongo connection: {{err}}", err)
} }
session, err := m.getConnection() session, err := m.getConnection(ctx)
if err != nil { if err != nil {
return err return err
} }

View file

@ -1,6 +1,7 @@
package mongodb package mongodb
import ( import (
"context"
"fmt" "fmt"
"os" "os"
"testing" "testing"
@ -79,7 +80,7 @@ func TestMongoDB_Initialize(t *testing.T) {
db := dbRaw.(*MongoDB) db := dbRaw.(*MongoDB)
connProducer := db.ConnectionProducer.(*mongoDBConnectionProducer) connProducer := db.ConnectionProducer.(*mongoDBConnectionProducer)
err = db.Initialize(connectionDetails, true) err = db.Initialize(context.Background(), connectionDetails, true)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -107,7 +108,7 @@ func TestMongoDB_CreateUser(t *testing.T) {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
db := dbRaw.(*MongoDB) db := dbRaw.(*MongoDB)
err = db.Initialize(connectionDetails, true) err = db.Initialize(context.Background(), connectionDetails, true)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -121,7 +122,7 @@ func TestMongoDB_CreateUser(t *testing.T) {
RoleName: "test", RoleName: "test",
} }
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute)) username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -145,7 +146,7 @@ func TestMongoDB_CreateUser_writeConcern(t *testing.T) {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
db := dbRaw.(*MongoDB) db := dbRaw.(*MongoDB)
err = db.Initialize(connectionDetails, true) err = db.Initialize(context.Background(), connectionDetails, true)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -159,7 +160,7 @@ func TestMongoDB_CreateUser_writeConcern(t *testing.T) {
RoleName: "test", RoleName: "test",
} }
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute)) username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -182,7 +183,7 @@ func TestMongoDB_RevokeUser(t *testing.T) {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
db := dbRaw.(*MongoDB) db := dbRaw.(*MongoDB)
err = db.Initialize(connectionDetails, true) err = db.Initialize(context.Background(), connectionDetails, true)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -196,7 +197,7 @@ func TestMongoDB_RevokeUser(t *testing.T) {
RoleName: "test", RoleName: "test",
} }
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute)) username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -206,7 +207,7 @@ func TestMongoDB_RevokeUser(t *testing.T) {
} }
// Test default revocation statememt // Test default revocation statememt
err = db.RevokeUser(statements, username) err = db.RevokeUser(context.Background(), statements, username)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }

View file

@ -1,6 +1,7 @@
package mssql package mssql
import ( import (
"context"
"database/sql" "database/sql"
"fmt" "fmt"
"strings" "strings"
@ -18,6 +19,8 @@ import (
const msSQLTypeName = "mssql" const msSQLTypeName = "mssql"
var _ dbplugin.Database = &MSSQL{}
// MSSQL is an implementation of Database interface // MSSQL is an implementation of Database interface
type MSSQL struct { type MSSQL struct {
connutil.ConnectionProducer connutil.ConnectionProducer
@ -60,8 +63,8 @@ func (m *MSSQL) Type() (string, error) {
return msSQLTypeName, nil return msSQLTypeName, nil
} }
func (m *MSSQL) getConnection() (*sql.DB, error) { func (m *MSSQL) getConnection(ctx context.Context) (*sql.DB, error) {
db, err := m.Connection() db, err := m.Connection(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -71,13 +74,13 @@ func (m *MSSQL) getConnection() (*sql.DB, error) {
// CreateUser generates the username/password on the underlying MSSQL secret backend as instructed by // CreateUser generates the username/password on the underlying MSSQL secret backend as instructed by
// the CreationStatement provided. // the CreationStatement provided.
func (m *MSSQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) { func (m *MSSQL) CreateUser(ctx context.Context, statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
// Grab the lock // Grab the lock
m.Lock() m.Lock()
defer m.Unlock() defer m.Unlock()
// Get the connection // Get the connection
db, err := m.getConnection() db, err := m.getConnection(ctx)
if err != nil { if err != nil {
return "", "", err return "", "", err
} }
@ -102,7 +105,7 @@ func (m *MSSQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplug
} }
// Start a transaction // Start a transaction
tx, err := db.Begin() tx, err := db.BeginTx(ctx, nil)
if err != nil { if err != nil {
return "", "", err return "", "", err
} }
@ -115,7 +118,7 @@ func (m *MSSQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplug
continue continue
} }
stmt, err := tx.Prepare(dbutil.QueryHelper(query, map[string]string{ stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{
"name": username, "name": username,
"password": password, "password": password,
"expiration": expirationStr, "expiration": expirationStr,
@ -124,7 +127,7 @@ func (m *MSSQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplug
return "", "", err return "", "", err
} }
defer stmt.Close() defer stmt.Close()
if _, err := stmt.Exec(); err != nil { if _, err := stmt.ExecContext(ctx); err != nil {
return "", "", err return "", "", err
} }
} }
@ -138,7 +141,7 @@ func (m *MSSQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplug
} }
// RenewUser is not supported on MSSQL, so this is a no-op. // RenewUser is not supported on MSSQL, so this is a no-op.
func (m *MSSQL) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error { func (m *MSSQL) RenewUser(ctx context.Context, statements dbplugin.Statements, username string, expiration time.Time) error {
// NOOP // NOOP
return nil return nil
} }
@ -146,19 +149,19 @@ func (m *MSSQL) RenewUser(statements dbplugin.Statements, username string, expir
// RevokeUser attempts to drop the specified user. It will first attempt to disable login, // RevokeUser attempts to drop the specified user. It will first attempt to disable login,
// then kill pending connections from that user, and finally drop the user and login from the // then kill pending connections from that user, and finally drop the user and login from the
// database instance. // database instance.
func (m *MSSQL) RevokeUser(statements dbplugin.Statements, username string) error { func (m *MSSQL) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error {
if statements.RevocationStatements == "" { if statements.RevocationStatements == "" {
return m.revokeUserDefault(username) return m.revokeUserDefault(ctx, username)
} }
// Get connection // Get connection
db, err := m.getConnection() db, err := m.getConnection(ctx)
if err != nil { if err != nil {
return err return err
} }
// Start a transaction // Start a transaction
tx, err := db.Begin() tx, err := db.BeginTx(ctx, nil)
if err != nil { if err != nil {
return err return err
} }
@ -171,14 +174,14 @@ func (m *MSSQL) RevokeUser(statements dbplugin.Statements, username string) erro
continue continue
} }
stmt, err := tx.Prepare(dbutil.QueryHelper(query, map[string]string{ stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{
"name": username, "name": username,
})) }))
if err != nil { if err != nil {
return err return err
} }
defer stmt.Close() defer stmt.Close()
if _, err := stmt.Exec(); err != nil { if _, err := stmt.ExecContext(ctx); err != nil {
return err return err
} }
} }
@ -191,20 +194,20 @@ func (m *MSSQL) RevokeUser(statements dbplugin.Statements, username string) erro
return nil return nil
} }
func (m *MSSQL) revokeUserDefault(username string) error { func (m *MSSQL) revokeUserDefault(ctx context.Context, username string) error {
// Get connection // Get connection
db, err := m.getConnection() db, err := m.getConnection(ctx)
if err != nil { if err != nil {
return err return err
} }
// First disable server login // First disable server login
disableStmt, err := db.Prepare(fmt.Sprintf("ALTER LOGIN [%s] DISABLE;", username)) disableStmt, err := db.PrepareContext(ctx, fmt.Sprintf("ALTER LOGIN [%s] DISABLE;", username))
if err != nil { if err != nil {
return err return err
} }
defer disableStmt.Close() defer disableStmt.Close()
if _, err := disableStmt.Exec(); err != nil { if _, err := disableStmt.ExecContext(ctx); err != nil {
return err return err
} }
@ -212,14 +215,14 @@ func (m *MSSQL) revokeUserDefault(username string) error {
// sessions. There cannot be any active sessions before we drop the logins // sessions. There cannot be any active sessions before we drop the logins
// This isn't done in a transaction because even if we fail along the way, // This isn't done in a transaction because even if we fail along the way,
// we want to remove as much access as possible // we want to remove as much access as possible
sessionStmt, err := db.Prepare(fmt.Sprintf( sessionStmt, err := db.PrepareContext(ctx, fmt.Sprintf(
"SELECT session_id FROM sys.dm_exec_sessions WHERE login_name = '%s';", username)) "SELECT session_id FROM sys.dm_exec_sessions WHERE login_name = '%s';", username))
if err != nil { if err != nil {
return err return err
} }
defer sessionStmt.Close() defer sessionStmt.Close()
sessionRows, err := sessionStmt.Query() sessionRows, err := sessionStmt.QueryContext(ctx)
if err != nil { if err != nil {
return err return err
} }
@ -240,13 +243,13 @@ func (m *MSSQL) revokeUserDefault(username string) error {
// we need to drop the database users before we can drop the login and the role // we need to drop the database users before we can drop the login and the role
// This isn't done in a transaction because even if we fail along the way, // This isn't done in a transaction because even if we fail along the way,
// we want to remove as much access as possible // we want to remove as much access as possible
stmt, err := db.Prepare(fmt.Sprintf("EXEC master.dbo.sp_msloginmappings '%s';", username)) stmt, err := db.PrepareContext(ctx, fmt.Sprintf("EXEC master.dbo.sp_msloginmappings '%s';", username))
if err != nil { if err != nil {
return err return err
} }
defer stmt.Close() defer stmt.Close()
rows, err := stmt.Query() rows, err := stmt.QueryContext(ctx)
if err != nil { if err != nil {
return err return err
} }
@ -266,13 +269,13 @@ func (m *MSSQL) revokeUserDefault(username string) error {
// many permissions as possible right now // many permissions as possible right now
var lastStmtError error var lastStmtError error
for _, query := range revokeStmts { for _, query := range revokeStmts {
stmt, err := db.Prepare(query) stmt, err := db.PrepareContext(ctx, query)
if err != nil { if err != nil {
lastStmtError = err lastStmtError = err
continue continue
} }
defer stmt.Close() defer stmt.Close()
_, err = stmt.Exec() _, err = stmt.ExecContext(ctx)
if err != nil { if err != nil {
lastStmtError = err lastStmtError = err
} }
@ -287,12 +290,12 @@ func (m *MSSQL) revokeUserDefault(username string) error {
} }
// Drop this login // Drop this login
stmt, err = db.Prepare(fmt.Sprintf(dropLoginSQL, username, username)) stmt, err = db.PrepareContext(ctx, fmt.Sprintf(dropLoginSQL, username, username))
if err != nil { if err != nil {
return err return err
} }
defer stmt.Close() defer stmt.Close()
if _, err := stmt.Exec(); err != nil { if _, err := stmt.ExecContext(ctx); err != nil {
return err return err
} }

View file

@ -1,6 +1,7 @@
package mssql package mssql
import ( import (
"context"
"database/sql" "database/sql"
"fmt" "fmt"
"os" "os"
@ -30,7 +31,7 @@ func TestMSSQL_Initialize(t *testing.T) {
dbRaw, _ := New() dbRaw, _ := New()
db := dbRaw.(*MSSQL) db := dbRaw.(*MSSQL)
err := db.Initialize(connectionDetails, true) err := db.Initialize(context.Background(), connectionDetails, true)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -51,7 +52,7 @@ func TestMSSQL_Initialize(t *testing.T) {
"max_open_connections": "5", "max_open_connections": "5",
} }
err = db.Initialize(connectionDetails, true) err = db.Initialize(context.Background(), connectionDetails, true)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -69,7 +70,7 @@ func TestMSSQL_CreateUser(t *testing.T) {
dbRaw, _ := New() dbRaw, _ := New()
db := dbRaw.(*MSSQL) db := dbRaw.(*MSSQL)
err := db.Initialize(connectionDetails, true) err := db.Initialize(context.Background(), connectionDetails, true)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -80,7 +81,7 @@ func TestMSSQL_CreateUser(t *testing.T) {
} }
// Test with no configured Creation Statememt // Test with no configured Creation Statememt
_, _, err = db.CreateUser(dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute)) _, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute))
if err == nil { if err == nil {
t.Fatal("Expected error when no creation statement is provided") t.Fatal("Expected error when no creation statement is provided")
} }
@ -89,7 +90,7 @@ func TestMSSQL_CreateUser(t *testing.T) {
CreationStatements: testMSSQLRole, CreationStatements: testMSSQLRole,
} }
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute)) username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -111,7 +112,7 @@ func TestMSSQL_RevokeUser(t *testing.T) {
dbRaw, _ := New() dbRaw, _ := New()
db := dbRaw.(*MSSQL) db := dbRaw.(*MSSQL)
err := db.Initialize(connectionDetails, true) err := db.Initialize(context.Background(), connectionDetails, true)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -125,7 +126,7 @@ func TestMSSQL_RevokeUser(t *testing.T) {
RoleName: "test", RoleName: "test",
} }
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(2*time.Second)) username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second))
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -135,7 +136,7 @@ func TestMSSQL_RevokeUser(t *testing.T) {
} }
// Test default revoke statememts // Test default revoke statememts
err = db.RevokeUser(statements, username) err = db.RevokeUser(context.Background(), statements, username)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -144,7 +145,7 @@ func TestMSSQL_RevokeUser(t *testing.T) {
t.Fatal("Credentials were not revoked") t.Fatal("Credentials were not revoked")
} }
username, password, err = db.CreateUser(statements, usernameConfig, time.Now().Add(2*time.Second)) username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second))
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -155,7 +156,7 @@ func TestMSSQL_RevokeUser(t *testing.T) {
// Test custom revoke statememt // Test custom revoke statememt
statements.RevocationStatements = testMSSQLDrop statements.RevocationStatements = testMSSQLDrop
err = db.RevokeUser(statements, username) err = db.RevokeUser(context.Background(), statements, username)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }

View file

@ -1,6 +1,7 @@
package mysql package mysql
import ( import (
"context"
"database/sql" "database/sql"
"strings" "strings"
"time" "time"
@ -30,6 +31,8 @@ var (
LegacyUsernameLen int = 16 LegacyUsernameLen int = 16
) )
var _ dbplugin.Database = &MySQL{}
type MySQL struct { type MySQL struct {
connutil.ConnectionProducer connutil.ConnectionProducer
credsutil.CredentialsProducer credsutil.CredentialsProducer
@ -88,8 +91,8 @@ func (m *MySQL) Type() (string, error) {
return mySQLTypeName, nil return mySQLTypeName, nil
} }
func (m *MySQL) getConnection() (*sql.DB, error) { func (m *MySQL) getConnection(ctx context.Context) (*sql.DB, error) {
db, err := m.Connection() db, err := m.Connection(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -97,13 +100,13 @@ func (m *MySQL) getConnection() (*sql.DB, error) {
return db.(*sql.DB), nil return db.(*sql.DB), nil
} }
func (m *MySQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) { func (m *MySQL) CreateUser(ctx context.Context, statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
// Grab the lock // Grab the lock
m.Lock() m.Lock()
defer m.Unlock() defer m.Unlock()
// Get the connection // Get the connection
db, err := m.getConnection() db, err := m.getConnection(ctx)
if err != nil { if err != nil {
return "", "", err return "", "", err
} }
@ -128,7 +131,7 @@ func (m *MySQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplug
} }
// Start a transaction // Start a transaction
tx, err := db.Begin() tx, err := db.BeginTx(ctx, nil)
if err != nil { if err != nil {
return "", "", err return "", "", err
} }
@ -146,7 +149,7 @@ func (m *MySQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplug
"expiration": expirationStr, "expiration": expirationStr,
}) })
stmt, err := tx.Prepare(query) stmt, err := tx.PrepareContext(ctx, query)
if err != nil { if err != nil {
// If the error code we get back is Error 1295: This command is not // If the error code we get back is Error 1295: This command is not
// supported in the prepared statement protocol yet, we will execute // supported in the prepared statement protocol yet, we will execute
@ -155,7 +158,7 @@ func (m *MySQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplug
// prepare supported commands. If there is no error when running we // prepare supported commands. If there is no error when running we
// will continue to the next statement. // will continue to the next statement.
if e, ok := err.(*stdmysql.MySQLError); ok && e.Number == 1295 { if e, ok := err.(*stdmysql.MySQLError); ok && e.Number == 1295 {
_, err = tx.Exec(query) _, err = tx.ExecContext(ctx, query)
if err != nil { if err != nil {
return "", "", err return "", "", err
} }
@ -165,7 +168,7 @@ func (m *MySQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplug
return "", "", err return "", "", err
} }
defer stmt.Close() defer stmt.Close()
if _, err := stmt.Exec(); err != nil { if _, err := stmt.ExecContext(ctx); err != nil {
return "", "", err return "", "", err
} }
} }
@ -179,17 +182,17 @@ func (m *MySQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplug
} }
// NOOP // NOOP
func (m *MySQL) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error { func (m *MySQL) RenewUser(ctx context.Context, statements dbplugin.Statements, username string, expiration time.Time) error {
return nil return nil
} }
func (m *MySQL) RevokeUser(statements dbplugin.Statements, username string) error { func (m *MySQL) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error {
// Grab the read lock // Grab the read lock
m.Lock() m.Lock()
defer m.Unlock() defer m.Unlock()
// Get the connection // Get the connection
db, err := m.getConnection() db, err := m.getConnection(ctx)
if err != nil { if err != nil {
return err return err
} }
@ -201,7 +204,7 @@ func (m *MySQL) RevokeUser(statements dbplugin.Statements, username string) erro
} }
// Start a transaction // Start a transaction
tx, err := db.Begin() tx, err := db.BeginTx(ctx, nil)
if err != nil { if err != nil {
return err return err
} }
@ -217,7 +220,7 @@ func (m *MySQL) RevokeUser(statements dbplugin.Statements, username string) erro
// 1295: This command is not supported in the prepared statement protocol yet // 1295: This command is not supported in the prepared statement protocol yet
// Reference https://mariadb.com/kb/en/mariadb/prepare-statement/ // Reference https://mariadb.com/kb/en/mariadb/prepare-statement/
query = strings.Replace(query, "{{name}}", username, -1) query = strings.Replace(query, "{{name}}", username, -1)
_, err = tx.Exec(query) _, err = tx.ExecContext(ctx, query)
if err != nil { if err != nil {
return err return err
} }

View file

@ -1,6 +1,7 @@
package mysql package mysql
import ( import (
"context"
"database/sql" "database/sql"
"fmt" "fmt"
"os" "os"
@ -108,7 +109,7 @@ func TestMySQL_Initialize(t *testing.T) {
db := dbRaw.(*MySQL) db := dbRaw.(*MySQL)
connProducer := db.ConnectionProducer.(*connutil.SQLConnectionProducer) connProducer := db.ConnectionProducer.(*connutil.SQLConnectionProducer)
err := db.Initialize(connectionDetails, true) err := db.Initialize(context.Background(), connectionDetails, true)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -128,7 +129,7 @@ func TestMySQL_Initialize(t *testing.T) {
"max_open_connections": "5", "max_open_connections": "5",
} }
err = db.Initialize(connectionDetails, true) err = db.Initialize(context.Background(), connectionDetails, true)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -146,7 +147,7 @@ func TestMySQL_CreateUser(t *testing.T) {
dbRaw, _ := f() dbRaw, _ := f()
db := dbRaw.(*MySQL) db := dbRaw.(*MySQL)
err := db.Initialize(connectionDetails, true) err := db.Initialize(context.Background(), connectionDetails, true)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -157,7 +158,7 @@ func TestMySQL_CreateUser(t *testing.T) {
} }
// Test with no configured Creation Statememt // Test with no configured Creation Statememt
_, _, err = db.CreateUser(dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute)) _, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute))
if err == nil { if err == nil {
t.Fatal("Expected error when no creation statement is provided") t.Fatal("Expected error when no creation statement is provided")
} }
@ -166,7 +167,7 @@ func TestMySQL_CreateUser(t *testing.T) {
CreationStatements: testMySQLRoleWildCard, CreationStatements: testMySQLRoleWildCard,
} }
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute)) username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -176,7 +177,7 @@ func TestMySQL_CreateUser(t *testing.T) {
} }
// Test a second time to make sure usernames don't collide // Test a second time to make sure usernames don't collide
username, password, err = db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute)) username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -188,7 +189,7 @@ func TestMySQL_CreateUser(t *testing.T) {
// Test with a manualy prepare statement // Test with a manualy prepare statement
statements.CreationStatements = testMySQLRolePreparedStmt statements.CreationStatements = testMySQLRolePreparedStmt
username, password, err = db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute)) username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -211,7 +212,7 @@ func TestMySQL_CreateUser_Legacy(t *testing.T) {
dbRaw, _ := f() dbRaw, _ := f()
db := dbRaw.(*MySQL) db := dbRaw.(*MySQL)
err := db.Initialize(connectionDetails, true) err := db.Initialize(context.Background(), connectionDetails, true)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -222,7 +223,7 @@ func TestMySQL_CreateUser_Legacy(t *testing.T) {
} }
// Test with no configured Creation Statememt // Test with no configured Creation Statememt
_, _, err = db.CreateUser(dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute)) _, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute))
if err == nil { if err == nil {
t.Fatal("Expected error when no creation statement is provided") t.Fatal("Expected error when no creation statement is provided")
} }
@ -231,7 +232,7 @@ func TestMySQL_CreateUser_Legacy(t *testing.T) {
CreationStatements: testMySQLRoleWildCard, CreationStatements: testMySQLRoleWildCard,
} }
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute)) username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -241,7 +242,7 @@ func TestMySQL_CreateUser_Legacy(t *testing.T) {
} }
// Test a second time to make sure usernames don't collide // Test a second time to make sure usernames don't collide
username, password, err = db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute)) username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -263,7 +264,7 @@ func TestMySQL_RevokeUser(t *testing.T) {
dbRaw, _ := f() dbRaw, _ := f()
db := dbRaw.(*MySQL) db := dbRaw.(*MySQL)
err := db.Initialize(connectionDetails, true) err := db.Initialize(context.Background(), connectionDetails, true)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -277,7 +278,7 @@ func TestMySQL_RevokeUser(t *testing.T) {
RoleName: "test", RoleName: "test",
} }
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute)) username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -287,7 +288,7 @@ func TestMySQL_RevokeUser(t *testing.T) {
} }
// Test default revoke statememts // Test default revoke statememts
err = db.RevokeUser(statements, username) err = db.RevokeUser(context.Background(), statements, username)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -297,7 +298,7 @@ func TestMySQL_RevokeUser(t *testing.T) {
} }
statements.CreationStatements = testMySQLRoleWildCard statements.CreationStatements = testMySQLRoleWildCard
username, password, err = db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute)) username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -308,7 +309,7 @@ func TestMySQL_RevokeUser(t *testing.T) {
// Test custom revoke statements // Test custom revoke statements
statements.RevocationStatements = testMySQLRevocationSQL statements.RevocationStatements = testMySQLRevocationSQL
err = db.RevokeUser(statements, username) err = db.RevokeUser(context.Background(), statements, username)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }

View file

@ -1,6 +1,7 @@
package postgresql package postgresql
import ( import (
"context"
"database/sql" "database/sql"
"fmt" "fmt"
"strings" "strings"
@ -24,6 +25,8 @@ ALTER ROLE "{{name}}" VALID UNTIL '{{expiration}}';
` `
) )
var _ dbplugin.Database = &PostgreSQL{}
// New implements builtinplugins.BuiltinFactory // New implements builtinplugins.BuiltinFactory
func New() (interface{}, error) { func New() (interface{}, error) {
connProducer := &connutil.SQLConnectionProducer{} connProducer := &connutil.SQLConnectionProducer{}
@ -65,8 +68,8 @@ func (p *PostgreSQL) Type() (string, error) {
return postgreSQLTypeName, nil return postgreSQLTypeName, nil
} }
func (p *PostgreSQL) getConnection() (*sql.DB, error) { func (p *PostgreSQL) getConnection(ctx context.Context) (*sql.DB, error) {
db, err := p.Connection() db, err := p.Connection(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -74,7 +77,7 @@ func (p *PostgreSQL) getConnection() (*sql.DB, error) {
return db.(*sql.DB), nil return db.(*sql.DB), nil
} }
func (p *PostgreSQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) { func (p *PostgreSQL) CreateUser(ctx context.Context, statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
if statements.CreationStatements == "" { if statements.CreationStatements == "" {
return "", "", dbutil.ErrEmptyCreationStatement return "", "", dbutil.ErrEmptyCreationStatement
} }
@ -99,14 +102,14 @@ func (p *PostgreSQL) CreateUser(statements dbplugin.Statements, usernameConfig d
} }
// Get the connection // Get the connection
db, err := p.getConnection() db, err := p.getConnection(ctx)
if err != nil { if err != nil {
return "", "", err return "", "", err
} }
// Start a transaction // Start a transaction
tx, err := db.Begin() tx, err := db.BeginTx(ctx, nil)
if err != nil { if err != nil {
return "", "", err return "", "", err
@ -123,7 +126,7 @@ func (p *PostgreSQL) CreateUser(statements dbplugin.Statements, usernameConfig d
continue continue
} }
stmt, err := tx.Prepare(dbutil.QueryHelper(query, map[string]string{ stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{
"name": username, "name": username,
"password": password, "password": password,
"expiration": expirationStr, "expiration": expirationStr,
@ -133,7 +136,7 @@ func (p *PostgreSQL) CreateUser(statements dbplugin.Statements, usernameConfig d
} }
defer stmt.Close() defer stmt.Close()
if _, err := stmt.Exec(); err != nil { if _, err := stmt.ExecContext(ctx); err != nil {
return "", "", err return "", "", err
} }
@ -148,7 +151,7 @@ func (p *PostgreSQL) CreateUser(statements dbplugin.Statements, usernameConfig d
return username, password, nil return username, password, nil
} }
func (p *PostgreSQL) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error { func (p *PostgreSQL) RenewUser(ctx context.Context, statements dbplugin.Statements, username string, expiration time.Time) error {
p.Lock() p.Lock()
defer p.Unlock() defer p.Unlock()
@ -157,12 +160,12 @@ func (p *PostgreSQL) RenewUser(statements dbplugin.Statements, username string,
renewStmts = defaultPostgresRenewSQL renewStmts = defaultPostgresRenewSQL
} }
db, err := p.getConnection() db, err := p.getConnection(ctx)
if err != nil { if err != nil {
return err return err
} }
tx, err := db.Begin() tx, err := db.BeginTx(ctx, nil)
if err != nil { if err != nil {
return err return err
} }
@ -180,7 +183,7 @@ func (p *PostgreSQL) RenewUser(statements dbplugin.Statements, username string,
if len(query) == 0 { if len(query) == 0 {
continue continue
} }
stmt, err := tx.Prepare(dbutil.QueryHelper(query, map[string]string{ stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{
"name": username, "name": username,
"expiration": expirationStr, "expiration": expirationStr,
})) }))
@ -189,7 +192,7 @@ func (p *PostgreSQL) RenewUser(statements dbplugin.Statements, username string,
} }
defer stmt.Close() defer stmt.Close()
if _, err := stmt.Exec(); err != nil { if _, err := stmt.ExecContext(ctx); err != nil {
return err return err
} }
} }
@ -201,25 +204,25 @@ func (p *PostgreSQL) RenewUser(statements dbplugin.Statements, username string,
return nil return nil
} }
func (p *PostgreSQL) RevokeUser(statements dbplugin.Statements, username string) error { func (p *PostgreSQL) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error {
// Grab the lock // Grab the lock
p.Lock() p.Lock()
defer p.Unlock() defer p.Unlock()
if statements.RevocationStatements == "" { if statements.RevocationStatements == "" {
return p.defaultRevokeUser(username) return p.defaultRevokeUser(ctx, username)
} }
return p.customRevokeUser(username, statements.RevocationStatements) return p.customRevokeUser(ctx, username, statements.RevocationStatements)
} }
func (p *PostgreSQL) customRevokeUser(username, revocationStmts string) error { func (p *PostgreSQL) customRevokeUser(ctx context.Context, username, revocationStmts string) error {
db, err := p.getConnection() db, err := p.getConnection(ctx)
if err != nil { if err != nil {
return err return err
} }
tx, err := db.Begin() tx, err := db.BeginTx(ctx, nil)
if err != nil { if err != nil {
return err return err
} }
@ -233,7 +236,7 @@ func (p *PostgreSQL) customRevokeUser(username, revocationStmts string) error {
continue continue
} }
stmt, err := tx.Prepare(dbutil.QueryHelper(query, map[string]string{ stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{
"name": username, "name": username,
})) }))
if err != nil { if err != nil {
@ -241,7 +244,7 @@ func (p *PostgreSQL) customRevokeUser(username, revocationStmts string) error {
} }
defer stmt.Close() defer stmt.Close()
if _, err := stmt.Exec(); err != nil { if _, err := stmt.ExecContext(ctx); err != nil {
return err return err
} }
} }
@ -253,15 +256,15 @@ func (p *PostgreSQL) customRevokeUser(username, revocationStmts string) error {
return nil return nil
} }
func (p *PostgreSQL) defaultRevokeUser(username string) error { func (p *PostgreSQL) defaultRevokeUser(ctx context.Context, username string) error {
db, err := p.getConnection() db, err := p.getConnection(ctx)
if err != nil { if err != nil {
return err return err
} }
// Check if the role exists // Check if the role exists
var exists bool var exists bool
err = db.QueryRow("SELECT exists (SELECT rolname FROM pg_roles WHERE rolname=$1);", username).Scan(&exists) err = db.QueryRowContext(ctx, "SELECT exists (SELECT rolname FROM pg_roles WHERE rolname=$1);", username).Scan(&exists)
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
return err return err
} }
@ -274,13 +277,13 @@ func (p *PostgreSQL) defaultRevokeUser(username string) error {
// the role // the role
// This isn't done in a transaction because even if we fail along the way, // This isn't done in a transaction because even if we fail along the way,
// we want to remove as much access as possible // we want to remove as much access as possible
stmt, err := db.Prepare("SELECT DISTINCT table_schema FROM information_schema.role_column_grants WHERE grantee=$1;") stmt, err := db.PrepareContext(ctx, "SELECT DISTINCT table_schema FROM information_schema.role_column_grants WHERE grantee=$1;")
if err != nil { if err != nil {
return err return err
} }
defer stmt.Close() defer stmt.Close()
rows, err := stmt.Query(username) rows, err := stmt.QueryContext(ctx, username)
if err != nil { if err != nil {
return err return err
} }
@ -322,7 +325,7 @@ func (p *PostgreSQL) defaultRevokeUser(username string) error {
// get the current database name so we can issue a REVOKE CONNECT for // get the current database name so we can issue a REVOKE CONNECT for
// this username // this username
var dbname sql.NullString var dbname sql.NullString
if err := db.QueryRow("SELECT current_database();").Scan(&dbname); err != nil { if err := db.QueryRowContext(ctx, "SELECT current_database();").Scan(&dbname); err != nil {
return err return err
} }
@ -337,13 +340,13 @@ func (p *PostgreSQL) defaultRevokeUser(username string) error {
// many permissions as possible right now // many permissions as possible right now
var lastStmtError error var lastStmtError error
for _, query := range revocationStmts { for _, query := range revocationStmts {
stmt, err := db.Prepare(query) stmt, err := db.PrepareContext(ctx, query)
if err != nil { if err != nil {
lastStmtError = err lastStmtError = err
continue continue
} }
defer stmt.Close() defer stmt.Close()
_, err = stmt.Exec() _, err = stmt.ExecContext(ctx)
if err != nil { if err != nil {
lastStmtError = err lastStmtError = err
} }
@ -358,13 +361,13 @@ func (p *PostgreSQL) defaultRevokeUser(username string) error {
} }
// Drop this user // Drop this user
stmt, err = db.Prepare(fmt.Sprintf( stmt, err = db.PrepareContext(ctx, fmt.Sprintf(
`DROP ROLE IF EXISTS %s;`, pq.QuoteIdentifier(username))) `DROP ROLE IF EXISTS %s;`, pq.QuoteIdentifier(username)))
if err != nil { if err != nil {
return err return err
} }
defer stmt.Close() defer stmt.Close()
if _, err := stmt.Exec(); err != nil { if _, err := stmt.ExecContext(ctx); err != nil {
return err return err
} }

View file

@ -1,6 +1,7 @@
package postgresql package postgresql
import ( import (
"context"
"database/sql" "database/sql"
"fmt" "fmt"
"os" "os"
@ -72,7 +73,7 @@ func TestPostgreSQL_Initialize(t *testing.T) {
connProducer := db.ConnectionProducer.(*connutil.SQLConnectionProducer) connProducer := db.ConnectionProducer.(*connutil.SQLConnectionProducer)
err := db.Initialize(connectionDetails, true) err := db.Initialize(context.Background(), connectionDetails, true)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -92,7 +93,7 @@ func TestPostgreSQL_Initialize(t *testing.T) {
"max_open_connections": "5", "max_open_connections": "5",
} }
err = db.Initialize(connectionDetails, true) err = db.Initialize(context.Background(), connectionDetails, true)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -109,7 +110,7 @@ func TestPostgreSQL_CreateUser(t *testing.T) {
dbRaw, _ := New() dbRaw, _ := New()
db := dbRaw.(*PostgreSQL) db := dbRaw.(*PostgreSQL)
err := db.Initialize(connectionDetails, true) err := db.Initialize(context.Background(), connectionDetails, true)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -120,7 +121,7 @@ func TestPostgreSQL_CreateUser(t *testing.T) {
} }
// Test with no configured Creation Statememt // Test with no configured Creation Statememt
_, _, err = db.CreateUser(dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute)) _, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute))
if err == nil { if err == nil {
t.Fatal("Expected error when no creation statement is provided") t.Fatal("Expected error when no creation statement is provided")
} }
@ -129,7 +130,7 @@ func TestPostgreSQL_CreateUser(t *testing.T) {
CreationStatements: testPostgresRole, CreationStatements: testPostgresRole,
} }
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute)) username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -139,7 +140,7 @@ func TestPostgreSQL_CreateUser(t *testing.T) {
} }
statements.CreationStatements = testPostgresReadOnlyRole statements.CreationStatements = testPostgresReadOnlyRole
username, password, err = db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute)) username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -162,7 +163,7 @@ func TestPostgreSQL_RenewUser(t *testing.T) {
dbRaw, _ := New() dbRaw, _ := New()
db := dbRaw.(*PostgreSQL) db := dbRaw.(*PostgreSQL)
err := db.Initialize(connectionDetails, true) err := db.Initialize(context.Background(), connectionDetails, true)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -176,7 +177,7 @@ func TestPostgreSQL_RenewUser(t *testing.T) {
RoleName: "test", RoleName: "test",
} }
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(2*time.Second)) username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second))
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -185,7 +186,7 @@ func TestPostgreSQL_RenewUser(t *testing.T) {
t.Fatalf("Could not connect with new credentials: %s", err) t.Fatalf("Could not connect with new credentials: %s", err)
} }
err = db.RenewUser(statements, username, time.Now().Add(time.Minute)) err = db.RenewUser(context.Background(), statements, username, time.Now().Add(time.Minute))
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -197,7 +198,7 @@ func TestPostgreSQL_RenewUser(t *testing.T) {
t.Fatalf("Could not connect with new credentials: %s", err) t.Fatalf("Could not connect with new credentials: %s", err)
} }
statements.RenewStatements = defaultPostgresRenewSQL statements.RenewStatements = defaultPostgresRenewSQL
username, password, err = db.CreateUser(statements, usernameConfig, time.Now().Add(2*time.Second)) username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second))
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -206,7 +207,7 @@ func TestPostgreSQL_RenewUser(t *testing.T) {
t.Fatalf("Could not connect with new credentials: %s", err) t.Fatalf("Could not connect with new credentials: %s", err)
} }
err = db.RenewUser(statements, username, time.Now().Add(time.Minute)) err = db.RenewUser(context.Background(), statements, username, time.Now().Add(time.Minute))
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -230,7 +231,7 @@ func TestPostgreSQL_RevokeUser(t *testing.T) {
dbRaw, _ := New() dbRaw, _ := New()
db := dbRaw.(*PostgreSQL) db := dbRaw.(*PostgreSQL)
err := db.Initialize(connectionDetails, true) err := db.Initialize(context.Background(), connectionDetails, true)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -244,7 +245,7 @@ func TestPostgreSQL_RevokeUser(t *testing.T) {
RoleName: "test", RoleName: "test",
} }
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(2*time.Second)) username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second))
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -254,7 +255,7 @@ func TestPostgreSQL_RevokeUser(t *testing.T) {
} }
// Test default revoke statememts // Test default revoke statememts
err = db.RevokeUser(statements, username) err = db.RevokeUser(context.Background(), statements, username)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -263,7 +264,7 @@ func TestPostgreSQL_RevokeUser(t *testing.T) {
t.Fatal("Credentials were not revoked") t.Fatal("Credentials were not revoked")
} }
username, password, err = db.CreateUser(statements, usernameConfig, time.Now().Add(2*time.Second)) username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second))
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@ -274,7 +275,7 @@ func TestPostgreSQL_RevokeUser(t *testing.T) {
// Test custom revoke statements // Test custom revoke statements
statements.RevocationStatements = defaultPostgresRevocationSQL statements.RevocationStatements = defaultPostgresRevocationSQL
err = db.RevokeUser(statements, username) err = db.RevokeUser(context.Background(), statements, username)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }

View file

@ -1,6 +1,7 @@
package connutil package connutil
import ( import (
"context"
"errors" "errors"
"sync" "sync"
) )
@ -14,8 +15,8 @@ var (
// connections and is used in all the builtin database types. // connections and is used in all the builtin database types.
type ConnectionProducer interface { type ConnectionProducer interface {
Close() error Close() error
Initialize(map[string]interface{}, bool) error Initialize(context.Context, map[string]interface{}, bool) error
Connection() (interface{}, error) Connection(context.Context) (interface{}, error)
sync.Locker sync.Locker
} }

View file

@ -1,6 +1,7 @@
package connutil package connutil
import ( import (
"context"
"database/sql" "database/sql"
"fmt" "fmt"
"strings" "strings"
@ -25,7 +26,7 @@ type SQLConnectionProducer struct {
sync.Mutex sync.Mutex
} }
func (c *SQLConnectionProducer) Initialize(conf map[string]interface{}, verifyConnection bool) error { func (c *SQLConnectionProducer) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error {
c.Lock() c.Lock()
defer c.Unlock() defer c.Unlock()
@ -62,11 +63,11 @@ func (c *SQLConnectionProducer) Initialize(conf map[string]interface{}, verifyCo
c.Initialized = true c.Initialized = true
if verifyConnection { if verifyConnection {
if _, err := c.Connection(); err != nil { if _, err := c.Connection(ctx); err != nil {
return fmt.Errorf("error verifying connection: %s", err) return fmt.Errorf("error verifying connection: %s", err)
} }
if err := c.db.Ping(); err != nil { if err := c.db.PingContext(ctx); err != nil {
return fmt.Errorf("error verifying connection: %s", err) return fmt.Errorf("error verifying connection: %s", err)
} }
} }
@ -74,14 +75,14 @@ func (c *SQLConnectionProducer) Initialize(conf map[string]interface{}, verifyCo
return nil return nil
} }
func (c *SQLConnectionProducer) Connection() (interface{}, error) { func (c *SQLConnectionProducer) Connection(ctx context.Context) (interface{}, error) {
if !c.Initialized { if !c.Initialized {
return nil, ErrNotInitialized return nil, ErrNotInitialized
} }
// If we already have a DB, test it and return // If we already have a DB, test it and return
if c.db != nil { if c.db != nil {
if err := c.db.Ping(); err == nil { if err := c.db.PingContext(ctx); err == nil {
return c.db, nil return c.db, nil
} }
// If the ping was unsuccessful, close it and ignore errors as we'll be // If the ping was unsuccessful, close it and ignore errors as we'll be

View file

@ -5,7 +5,6 @@ import (
"fmt" "fmt"
"strings" "strings"
"github.com/hashicorp/errwrap"
"github.com/hashicorp/go-uuid" "github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/helper/consts" "github.com/hashicorp/vault/helper/consts"
"github.com/hashicorp/vault/helper/jsonutil" "github.com/hashicorp/vault/helper/jsonutil"
@ -460,10 +459,11 @@ func (c *Core) setupCredentials() error {
backend, err = c.newCredentialBackend(entry.Type, sysView, view, conf) backend, err = c.newCredentialBackend(entry.Type, sysView, view, conf)
if err != nil { if err != nil {
c.logger.Error("core: failed to create credential entry", "path", entry.Path, "error", err) c.logger.Error("core: failed to create credential entry", "path", entry.Path, "error", err)
if errwrap.Contains(err, ErrPluginNotFound.Error()) && entry.Type == "plugin" { if entry.Type == "plugin" {
// If we encounter an error instantiating the backend due to it being missing from the catalog, // If we encounter an error instantiating the backend due to an error,
// skip backend initialization but register the entry to the mount table to preserve storage // skip backend initialization but register the entry to the mount table
// and path. // to preserve storage and path.
c.logger.Warn("core: skipping plugin-based credential entry", "path", entry.Path)
goto ROUTER_MOUNT goto ROUTER_MOUNT
} }
return errLoadAuthFailed return errLoadAuthFailed

View file

@ -750,6 +750,31 @@ func (m *ExpirationManager) RenewToken(req *logical.Request, source string, toke
}, nil }, nil
} }
sysView := m.router.MatchingSystemView(le.Path)
if sysView == nil {
return nil, fmt.Errorf("expiration: unable to retrieve system view from router")
}
retResp := &logical.Response{}
switch {
case resp.Auth.Period > time.Duration(0):
// If it resp.Period is non-zero, use that as the TTL and override backend's
// call on TTL modification, such as a TTL value determined by
// framework.LeaseExtend call against the request. Also, cap period value to
// the sys/mount max value.
if resp.Auth.Period > sysView.MaxLeaseTTL() {
retResp.AddWarning(fmt.Sprintf("Period of %d seconds is greater than current mount/system default of %d seconds, value will be truncated.", resp.Auth.TTL, sysView.MaxLeaseTTL()))
resp.Auth.Period = sysView.MaxLeaseTTL()
}
resp.Auth.TTL = resp.Auth.Period
case resp.Auth.TTL > time.Duration(0):
// Cap TTL value to the sys/mount max value
if resp.Auth.TTL > sysView.MaxLeaseTTL() {
retResp.AddWarning(fmt.Sprintf("TTL of %d seconds is greater than current mount/system default of %d seconds, value will be truncated.", resp.Auth.TTL, sysView.MaxLeaseTTL()))
resp.Auth.TTL = sysView.MaxLeaseTTL()
}
}
// Attach the ClientToken // Attach the ClientToken
resp.Auth.ClientToken = token resp.Auth.ClientToken = token
resp.Auth.Increment = 0 resp.Auth.Increment = 0
@ -764,9 +789,9 @@ func (m *ExpirationManager) RenewToken(req *logical.Request, source string, toke
// Update the expiration time // Update the expiration time
m.updatePending(le, resp.Auth.LeaseTotal()) m.updatePending(le, resp.Auth.LeaseTotal())
return &logical.Response{
Auth: resp.Auth, retResp.Auth = resp.Auth
}, nil return retResp, nil
} }
// Register is used to take a request and response with an associated // Register is used to take a request and response with an associated
@ -866,6 +891,12 @@ func (m *ExpirationManager) RegisterAuth(source string, auth *logical.Auth) erro
return err return err
} }
// If it resp.Period is non-zero, override the TTL value determined
// by the backend.
if auth.Period > time.Duration(0) {
auth.TTL = auth.Period
}
// Create a lease entry // Create a lease entry
le := leaseEntry{ le := leaseEntry{
LeaseID: path.Join(source, saltedID), LeaseID: path.Join(source, saltedID),
@ -1017,8 +1048,7 @@ func (m *ExpirationManager) revokeEntry(le *leaseEntry) error {
} }
// Handle standard revocation via backends // Handle standard revocation via backends
resp, err := m.router.Route(logical.RevokeRequest( resp, err := m.router.Route(logical.RevokeRequest(le.Path, le.Secret, le.Data))
le.Path, le.Secret, le.Data))
if err != nil || (resp != nil && resp.IsError()) { if err != nil || (resp != nil && resp.IsError()) {
return fmt.Errorf("failed to revoke entry: resp:%#v err:%s", resp, err) return fmt.Errorf("failed to revoke entry: resp:%#v err:%s", resp, err)
} }

View file

@ -2,7 +2,9 @@ package vault_test
import ( import (
"fmt" "fmt"
"io/ioutil"
"os" "os"
"path/filepath"
"testing" "testing"
"time" "time"
@ -178,14 +180,13 @@ func testPlugin_CatalogRemoved(t *testing.T, btype logical.BackendType, testMoun
} }
if testMount { if testMount {
// Add plugin back to the catalog
vault.TestAddTestPlugin(t, core.Core, "mock-plugin", "TestBackend_PluginMainLogical")
// Mount the plugin at the same path after plugin is re-added to the catalog // Mount the plugin at the same path after plugin is re-added to the catalog
// and expect an error due to existing path. // and expect an error due to existing path.
var err error var err error
switch btype { switch btype {
case logical.TypeLogical: case logical.TypeLogical:
// Add plugin back to the catalog
vault.TestAddTestPlugin(t, core.Core, "mock-plugin", "TestBackend_PluginMainLogical")
_, err = core.Client.Logical().Write("sys/mounts/mock-0", map[string]interface{}{ _, err = core.Client.Logical().Write("sys/mounts/mock-0", map[string]interface{}{
"type": "plugin", "type": "plugin",
"config": map[string]interface{}{ "config": map[string]interface{}{
@ -193,6 +194,8 @@ func testPlugin_CatalogRemoved(t *testing.T, btype logical.BackendType, testMoun
}, },
}) })
case logical.TypeCredential: case logical.TypeCredential:
// Add plugin back to the catalog
vault.TestAddTestPlugin(t, core.Core, "mock-plugin", "TestBackend_PluginMainCredentials")
_, err = core.Client.Logical().Write("sys/auth/mock-0", map[string]interface{}{ _, err = core.Client.Logical().Write("sys/auth/mock-0", map[string]interface{}{
"type": "plugin", "type": "plugin",
"plugin_name": "mock-plugin", "plugin_name": "mock-plugin",
@ -204,6 +207,129 @@ func testPlugin_CatalogRemoved(t *testing.T, btype logical.BackendType, testMoun
} }
} }
func TestSystemBackend_Plugin_continueOnError(t *testing.T) {
t.Run("secret", func(t *testing.T) {
t.Run("sha256_mismatch", func(t *testing.T) {
testPlugin_continueOnError(t, logical.TypeLogical, true)
})
t.Run("missing_plugin", func(t *testing.T) {
testPlugin_continueOnError(t, logical.TypeLogical, false)
})
})
t.Run("auth", func(t *testing.T) {
t.Run("sha256_mismatch", func(t *testing.T) {
testPlugin_continueOnError(t, logical.TypeCredential, true)
})
t.Run("missing_plugin", func(t *testing.T) {
testPlugin_continueOnError(t, logical.TypeCredential, false)
})
})
}
func testPlugin_continueOnError(t *testing.T, btype logical.BackendType, mismatch bool) {
cluster := testSystemBackendMock(t, 1, 1, btype)
defer cluster.Cleanup()
core := cluster.Cores[0]
// Get the registered plugin
req := logical.TestRequest(t, logical.ReadOperation, "sys/plugins/catalog/mock-plugin")
req.ClientToken = core.Client.Token()
resp, err := core.HandleRequest(req)
if err != nil || resp == nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
command, ok := resp.Data["command"].(string)
if !ok || command == "" {
t.Fatal("invalid command")
}
// Trigger a sha256 mistmatch or missing plugin error
if mismatch {
req = logical.TestRequest(t, logical.UpdateOperation, "sys/plugins/catalog/mock-plugin")
req.Data = map[string]interface{}{
"sha256": "d17bd7334758e53e6fbab15745d2520765c06e296f2ce8e25b7919effa0ac216",
"command": filepath.Base(command),
}
req.ClientToken = core.Client.Token()
resp, err = core.HandleRequest(req)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
} else {
err := os.Remove(filepath.Join(cluster.TempDir, filepath.Base(command)))
if err != nil {
t.Fatal(err)
}
}
// Seal the cluster
cluster.EnsureCoresSealed(t)
// Unseal the cluster
barrierKeys := cluster.BarrierKeys
for _, core := range cluster.Cores {
for _, key := range barrierKeys {
_, err := core.Unseal(vault.TestKeyCopy(key))
if err != nil {
t.Fatal(err)
}
}
sealed, err := core.Sealed()
if err != nil {
t.Fatalf("err checking seal status: %s", err)
}
if sealed {
t.Fatal("should not be sealed")
}
// Wait for active so post-unseal takes place
// If it fails, it means unseal process failed
vault.TestWaitActive(t, core.Core)
}
// Re-add the plugin to the catalog
switch btype {
case logical.TypeLogical:
vault.TestAddTestPluginTempDir(t, core.Core, "mock-plugin", "TestBackend_PluginMainLogical", cluster.TempDir)
case logical.TypeCredential:
vault.TestAddTestPluginTempDir(t, core.Core, "mock-plugin", "TestBackend_PluginMainCredentials", cluster.TempDir)
}
// Reload the plugin
req = logical.TestRequest(t, logical.UpdateOperation, "sys/plugins/reload/backend")
req.Data = map[string]interface{}{
"plugin": "mock-plugin",
}
req.ClientToken = core.Client.Token()
resp, err = core.HandleRequest(req)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
// Make a request to lazy load the plugin
var reqPath string
switch btype {
case logical.TypeLogical:
reqPath = "mock-0/internal"
case logical.TypeCredential:
reqPath = "auth/mock-0/internal"
}
req = logical.TestRequest(t, logical.ReadOperation, reqPath)
req.ClientToken = core.Client.Token()
resp, err = core.HandleRequest(req)
if err != nil {
t.Fatalf("err: %v", err)
}
if resp == nil {
t.Fatalf("bad: response should not be nil")
}
}
func TestSystemBackend_Plugin_autoReload(t *testing.T) { func TestSystemBackend_Plugin_autoReload(t *testing.T) {
cluster := testSystemBackendMock(t, 1, 1, logical.TypeLogical) cluster := testSystemBackendMock(t, 1, 1, logical.TypeLogical)
defer cluster.Cleanup() defer cluster.Cleanup()
@ -332,7 +458,10 @@ func testSystemBackend_PluginReload(t *testing.T, reqData map[string]interface{}
} }
// testSystemBackendMock returns a systemBackend with the desired number // testSystemBackendMock returns a systemBackend with the desired number
// of mounted mock plugin backends // of mounted mock plugin backends. numMounts alternates between different
// ways of providing the plugin_name.
//
// The mounts are mounted at sys/mounts/mock-[numMounts] or sys/auth/mock-[numMounts]
func testSystemBackendMock(t *testing.T, numCores, numMounts int, backendType logical.BackendType) *vault.TestCluster { func testSystemBackendMock(t *testing.T, numCores, numMounts int, backendType logical.BackendType) *vault.TestCluster {
coreConfig := &vault.CoreConfig{ coreConfig := &vault.CoreConfig{
LogicalBackends: map[string]logical.Factory{ LogicalBackends: map[string]logical.Factory{
@ -343,10 +472,17 @@ func testSystemBackendMock(t *testing.T, numCores, numMounts int, backendType lo
}, },
} }
// Create a tempdir, cluster.Cleanup will clean up this directory
tempDir, err := ioutil.TempDir("", "vault-test-cluster")
if err != nil {
t.Fatal(err)
}
cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{ cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{
HandlerFunc: vaulthttp.Handler, HandlerFunc: vaulthttp.Handler,
KeepStandbysSealed: true, KeepStandbysSealed: true,
NumCores: numCores, NumCores: numCores,
TempDir: tempDir,
}) })
cluster.Start() cluster.Start()
@ -358,7 +494,7 @@ func testSystemBackendMock(t *testing.T, numCores, numMounts int, backendType lo
switch backendType { switch backendType {
case logical.TypeLogical: case logical.TypeLogical:
vault.TestAddTestPlugin(t, core.Core, "mock-plugin", "TestBackend_PluginMainLogical") vault.TestAddTestPluginTempDir(t, core.Core, "mock-plugin", "TestBackend_PluginMainLogical", tempDir)
for i := 0; i < numMounts; i++ { for i := 0; i < numMounts; i++ {
// Alternate input styles for plugin_name on every other mount // Alternate input styles for plugin_name on every other mount
options := map[string]interface{}{ options := map[string]interface{}{
@ -380,7 +516,7 @@ func testSystemBackendMock(t *testing.T, numCores, numMounts int, backendType lo
} }
} }
case logical.TypeCredential: case logical.TypeCredential:
vault.TestAddTestPlugin(t, core.Core, "mock-plugin", "TestBackend_PluginMainCredentials") vault.TestAddTestPluginTempDir(t, core.Core, "mock-plugin", "TestBackend_PluginMainCredentials", tempDir)
for i := 0; i < numMounts; i++ { for i := 0; i < numMounts; i++ {
// Alternate input styles for plugin_name on every other mount // Alternate input styles for plugin_name on every other mount
options := map[string]interface{}{ options := map[string]interface{}{

View file

@ -7,7 +7,6 @@ import (
"strings" "strings"
"time" "time"
"github.com/hashicorp/errwrap"
"github.com/hashicorp/go-uuid" "github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/helper/consts" "github.com/hashicorp/vault/helper/consts"
"github.com/hashicorp/vault/helper/jsonutil" "github.com/hashicorp/vault/helper/jsonutil"
@ -753,10 +752,11 @@ func (c *Core) setupMounts() error {
backend, err = c.newLogicalBackend(entry.Type, sysView, view, conf) backend, err = c.newLogicalBackend(entry.Type, sysView, view, conf)
if err != nil { if err != nil {
c.logger.Error("core: failed to create mount entry", "path", entry.Path, "error", err) c.logger.Error("core: failed to create mount entry", "path", entry.Path, "error", err)
if errwrap.Contains(err, ErrPluginNotFound.Error()) && entry.Type == "plugin" { if entry.Type == "plugin" {
// If we encounter an error instantiating the backend due to it being missing from the catalog, // If we encounter an error instantiating the backend due to an error,
// skip backend initialization but register the entry to the mount table to preserve storage // skip backend initialization but register the entry to the mount table
// and path. // to preserve storage and path.
c.logger.Warn("core: skipping plugin-based mount entry", "path", entry.Path)
goto ROUTER_MOUNT goto ROUTER_MOUNT
} }
return errLoadMountsFailed return errLoadMountsFailed

View file

@ -79,15 +79,23 @@ func (c *Core) reloadMatchingPlugin(pluginName string) error {
func (c *Core) reloadPluginCommon(entry *MountEntry, isAuth bool) error { func (c *Core) reloadPluginCommon(entry *MountEntry, isAuth bool) error {
path := entry.Path path := entry.Path
if isAuth {
path = credentialRoutePrefix + path
}
// Fast-path out if the backend doesn't exist // Fast-path out if the backend doesn't exist
raw, ok := c.router.root.Get(path) raw, ok := c.router.root.Get(path)
if !ok { if !ok {
return nil return nil
} }
// Call backend's Cleanup routine
re := raw.(*routeEntry) re := raw.(*routeEntry)
re.backend.Cleanup()
// Only call Cleanup if backend is initialized
if re.backend != nil {
// Call backend's Cleanup routine
re.backend.Cleanup()
}
view := re.storageView view := re.storageView

View file

@ -477,14 +477,27 @@ func (c *Core) handleLoginRequest(req *logical.Request) (retResp *logical.Respon
return nil, nil, ErrInternalError return nil, nil, ErrInternalError
} }
// Set the default lease if not provided // Start off with the sys default value, and update according to period/TTL
if auth.TTL == 0 { // from resp.Auth
auth.TTL = sysView.DefaultLeaseTTL() tokenTTL := sysView.DefaultLeaseTTL()
}
// Limit the lease duration switch {
if auth.TTL > sysView.MaxLeaseTTL() { case auth.Period > time.Duration(0):
auth.TTL = sysView.MaxLeaseTTL() // Cap the period value to the sys max_ttl value. The auth backend should
// have checked for it on its login path, but we check here again for
// sanity.
if auth.Period > sysView.MaxLeaseTTL() {
auth.Period = sysView.MaxLeaseTTL()
}
tokenTTL = auth.Period
case auth.TTL > time.Duration(0):
// Cap the TTL value. The auth backend should have checked for it on its
// login path (e.g. a call to b.SanitizeTTL), but we check here again for
// sanity.
if auth.TTL > sysView.MaxLeaseTTL() {
auth.TTL = sysView.MaxLeaseTTL()
}
tokenTTL = auth.TTL
} }
// Generate a token // Generate a token
@ -494,7 +507,7 @@ func (c *Core) handleLoginRequest(req *logical.Request) (retResp *logical.Respon
Meta: auth.Metadata, Meta: auth.Metadata,
DisplayName: auth.DisplayName, DisplayName: auth.DisplayName,
CreationTime: time.Now().Unix(), CreationTime: time.Now().Unix(),
TTL: auth.TTL, TTL: tokenTTL,
NumUses: auth.NumUses, NumUses: auth.NumUses,
EntityID: auth.EntityID, EntityID: auth.EntityID,
} }
@ -513,10 +526,11 @@ func (c *Core) handleLoginRequest(req *logical.Request) (retResp *logical.Respon
return nil, auth, ErrInternalError return nil, auth, ErrInternalError
} }
// Populate the client token and accessor // Populate the client token, accessor, and TTL
auth.ClientToken = te.ID auth.ClientToken = te.ID
auth.Accessor = te.Accessor auth.Accessor = te.Accessor
auth.Policies = te.Policies auth.Policies = te.Policies
auth.TTL = te.TTL
// Register with the expiration manager // Register with the expiration manager
if err := c.expiration.RegisterAuth(te.Path, auth); err != nil { if err := c.expiration.RegisterAuth(te.Path, auth); err != nil {

View file

@ -379,6 +379,8 @@ func TestDynamicSystemView(c *Core) *dynamicSystemView {
return &dynamicSystemView{c, me} return &dynamicSystemView{c, me}
} }
// TestAddTestPlugin registers the testFunc as part of the plugin command to the
// plugin catalog.
func TestAddTestPlugin(t testing.T, c *Core, name, testFunc string) { func TestAddTestPlugin(t testing.T, c *Core, name, testFunc string) {
file, err := os.Open(os.Args[0]) file, err := os.Open(os.Args[0])
if err != nil { if err != nil {
@ -413,11 +415,74 @@ func TestAddTestPlugin(t testing.T, c *Core, name, testFunc string) {
} }
} }
// TestAddTestPluginTempDir registers the testFunc as part of the plugin command to the
// plugin catalog. It uses tmpDir as the plugin directory.
func TestAddTestPluginTempDir(t testing.T, c *Core, name, testFunc, tempDir string) {
file, err := os.Open(os.Args[0])
if err != nil {
t.Fatal(err)
}
defer file.Close()
fi, err := file.Stat()
if err != nil {
t.Fatal(err)
}
// Copy over the file to the temp dir
dst := filepath.Join(tempDir, filepath.Base(os.Args[0]))
out, err := os.OpenFile(dst, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, fi.Mode())
if err != nil {
t.Fatal(err)
}
defer out.Close()
if _, err = io.Copy(out, file); err != nil {
t.Fatal(err)
}
err = out.Sync()
if err != nil {
t.Fatal(err)
}
// Determine plugin directory full path
fullPath, err := filepath.EvalSymlinks(tempDir)
if err != nil {
t.Fatal(err)
}
reader, err := os.Open(filepath.Join(fullPath, filepath.Base(os.Args[0])))
if err != nil {
t.Fatal(err)
}
defer reader.Close()
// Find out the sha256
hash := sha256.New()
_, err = io.Copy(hash, reader)
if err != nil {
t.Fatal(err)
}
sum := hash.Sum(nil)
// Set core's plugin directory and plugin catalog directory
c.pluginDirectory = fullPath
c.pluginCatalog.directory = fullPath
command := fmt.Sprintf("%s --test.run=%s", filepath.Base(os.Args[0]), testFunc)
err = c.pluginCatalog.Set(name, command, sum)
if err != nil {
t.Fatal(err)
}
}
var testLogicalBackends = map[string]logical.Factory{} var testLogicalBackends = map[string]logical.Factory{}
var testCredentialBackends = map[string]logical.Factory{} var testCredentialBackends = map[string]logical.Factory{}
// Starts the test server which responds to SSH authentication. // StartSSHHostTestServer starts the test server which responds to SSH
// Used to test the SSH secret backend. // authentication. Used to test the SSH secret backend.
func StartSSHHostTestServer() (string, error) { func StartSSHHostTestServer() (string, error) {
pubKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(testSharedPublicKey)) pubKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(testSharedPublicKey))
if err != nil { if err != nil {
@ -760,6 +825,7 @@ func (c *TestCluster) ensureCoresSealed() error {
return nil return nil
} }
// UnsealWithStoredKeys uses stored keys to unseal the test cluster cores
func (c *TestCluster) UnsealWithStoredKeys(t testing.T) error { func (c *TestCluster) UnsealWithStoredKeys(t testing.T) error {
for _, core := range c.Cores { for _, core := range c.Cores {
if err := core.UnsealWithStoredKeys(); err != nil { if err := core.UnsealWithStoredKeys(); err != nil {

View file

@ -2312,7 +2312,7 @@ func TestTokenStore_RolePeriod(t *testing.T) {
req := logical.TestRequest(t, logical.UpdateOperation, "auth/token/roles/test") req := logical.TestRequest(t, logical.UpdateOperation, "auth/token/roles/test")
req.ClientToken = root req.ClientToken = root
req.Data = map[string]interface{}{ req.Data = map[string]interface{}{
"period": 300, "period": 5,
} }
resp, err := core.HandleRequest(req) resp, err := core.HandleRequest(req)
@ -2425,8 +2425,8 @@ func TestTokenStore_RolePeriod(t *testing.T) {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
ttl := resp.Data["ttl"].(int64) ttl := resp.Data["ttl"].(int64)
if ttl < 299 { if ttl > 5 {
t.Fatalf("TTL too small (expected %d, got %d", 299, ttl) t.Fatalf("TTL too large (expected %d, got %d", 5, ttl)
} }
// Let the TTL go down a bit to 3 seconds // Let the TTL go down a bit to 3 seconds
@ -2449,8 +2449,8 @@ func TestTokenStore_RolePeriod(t *testing.T) {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
ttl = resp.Data["ttl"].(int64) ttl = resp.Data["ttl"].(int64)
if ttl < 299 { if ttl > 5 {
t.Fatalf("TTL too small (expected %d, got %d", 299, ttl) t.Fatalf("TTL too large (expected %d, got %d", 5, ttl)
} }
} }
} }
@ -2677,7 +2677,7 @@ func TestTokenStore_Periodic(t *testing.T) {
req := logical.TestRequest(t, logical.UpdateOperation, "auth/token/roles/test") req := logical.TestRequest(t, logical.UpdateOperation, "auth/token/roles/test")
req.ClientToken = root req.ClientToken = root
req.Data = map[string]interface{}{ req.Data = map[string]interface{}{
"period": 300, "period": 5,
} }
resp, err := core.HandleRequest(req) resp, err := core.HandleRequest(req)
@ -2715,8 +2715,8 @@ func TestTokenStore_Periodic(t *testing.T) {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
ttl := resp.Data["ttl"].(int64) ttl := resp.Data["ttl"].(int64)
if ttl < 299 { if ttl > 5 {
t.Fatalf("TTL too small (expected %d, got %d)", 299, ttl) t.Fatalf("TTL too large (expected %d, got %d)", 5, ttl)
} }
// Let the TTL go down a bit // Let the TTL go down a bit
@ -2739,8 +2739,8 @@ func TestTokenStore_Periodic(t *testing.T) {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
ttl = resp.Data["ttl"].(int64) ttl = resp.Data["ttl"].(int64)
if ttl < 299 { if ttl > 5 {
t.Fatalf("TTL too small (expected %d, got %d)", 299, ttl) t.Fatalf("TTL too large (expected %d, got %d)", 5, ttl)
} }
} }
@ -2750,8 +2750,8 @@ func TestTokenStore_Periodic(t *testing.T) {
req.Operation = logical.UpdateOperation req.Operation = logical.UpdateOperation
req.Path = "auth/token/create" req.Path = "auth/token/create"
req.Data = map[string]interface{}{ req.Data = map[string]interface{}{
"period": 300, "period": 5,
"explicit_max_ttl": 150, "explicit_max_ttl": 4,
} }
resp, err = core.HandleRequest(req) resp, err = core.HandleRequest(req)
if err != nil { if err != nil {
@ -2775,8 +2775,8 @@ func TestTokenStore_Periodic(t *testing.T) {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
ttl := resp.Data["ttl"].(int64) ttl := resp.Data["ttl"].(int64)
if ttl < 149 || ttl > 150 { if ttl < 3 || ttl > 4 {
t.Fatalf("TTL bad (expected %d, got %d)", 149, ttl) t.Fatalf("TTL bad (expected %d, got %d)", 3, ttl)
} }
// Let the TTL go down a bit // Let the TTL go down a bit
@ -2799,8 +2799,8 @@ func TestTokenStore_Periodic(t *testing.T) {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
ttl = resp.Data["ttl"].(int64) ttl = resp.Data["ttl"].(int64)
if ttl < 140 || ttl > 150 { if ttl > 2 {
t.Fatalf("TTL bad (expected around %d, got %d)", 145, ttl) t.Fatalf("TTL bad (expected less than %d, got %d)", 2, ttl)
} }
} }
@ -2812,7 +2812,7 @@ func TestTokenStore_Periodic(t *testing.T) {
req.Operation = logical.UpdateOperation req.Operation = logical.UpdateOperation
req.Path = "auth/token/create/test" req.Path = "auth/token/create/test"
req.Data = map[string]interface{}{ req.Data = map[string]interface{}{
"period": 150, "period": 5,
} }
resp, err = core.HandleRequest(req) resp, err = core.HandleRequest(req)
if err != nil { if err != nil {
@ -2836,8 +2836,8 @@ func TestTokenStore_Periodic(t *testing.T) {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
ttl := resp.Data["ttl"].(int64) ttl := resp.Data["ttl"].(int64)
if ttl < 149 || ttl > 150 { if ttl < 4 || ttl > 5 {
t.Fatalf("TTL bad (expected %d, got %d)", 149, ttl) t.Fatalf("TTL bad (expected %d, got %d)", 4, ttl)
} }
// Let the TTL go down a bit // Let the TTL go down a bit
@ -2860,8 +2860,8 @@ func TestTokenStore_Periodic(t *testing.T) {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
ttl = resp.Data["ttl"].(int64) ttl = resp.Data["ttl"].(int64)
if ttl < 149 { if ttl > 5 {
t.Fatalf("TTL bad (expected %d, got %d)", 149, ttl) t.Fatalf("TTL bad (expected less than %d, got %d)", 5, ttl)
} }
} }
@ -2869,18 +2869,23 @@ func TestTokenStore_Periodic(t *testing.T) {
{ {
req.Path = "auth/token/roles/test" req.Path = "auth/token/roles/test"
req.ClientToken = root req.ClientToken = root
req.Operation = logical.UpdateOperation
req.Data = map[string]interface{}{ req.Data = map[string]interface{}{
"period": 300, "period": 5,
"explicit_max_ttl": 150, "explicit_max_ttl": 4,
}
resp, err := core.HandleRequest(req)
if err != nil {
t.Fatalf("err: %v %v", err, resp)
}
if resp != nil {
t.Fatalf("expected a nil response")
} }
req.ClientToken = root req.ClientToken = root
req.Operation = logical.UpdateOperation req.Operation = logical.UpdateOperation
req.Path = "auth/token/create/test" req.Path = "auth/token/create/test"
req.Data = map[string]interface{}{
"period": 150,
"explicit_max_ttl": 130,
}
resp, err = core.HandleRequest(req) resp, err = core.HandleRequest(req)
if err != nil { if err != nil {
t.Fatalf("err: %v %v", err, resp) t.Fatalf("err: %v %v", err, resp)
@ -2903,12 +2908,12 @@ func TestTokenStore_Periodic(t *testing.T) {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
ttl := resp.Data["ttl"].(int64) ttl := resp.Data["ttl"].(int64)
if ttl < 129 || ttl > 130 { if ttl < 3 || ttl > 4 {
t.Fatalf("TTL bad (expected %d, got %d)", 129, ttl) t.Fatalf("TTL bad (expected %d, got %d)", 3, ttl)
} }
// Let the TTL go down a bit // Let the TTL go down a bit
time.Sleep(4 * time.Second) time.Sleep(2 * time.Second)
req.Operation = logical.UpdateOperation req.Operation = logical.UpdateOperation
req.Path = "auth/token/renew-self" req.Path = "auth/token/renew-self"
@ -2927,8 +2932,8 @@ func TestTokenStore_Periodic(t *testing.T) {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
ttl = resp.Data["ttl"].(int64) ttl = resp.Data["ttl"].(int64)
if ttl > 127 { if ttl > 2 {
t.Fatalf("TTL bad (expected < %d, got %d)", 128, ttl) t.Fatalf("TTL bad (expected less than %d, got %d)", 2, ttl)
} }
} }
} }