Merge remote-tracking branch 'oss/master' into f-nomad
* oss/master: Defer reader.Close that is used to determine sha256 changelog++ Avoid unseal failure if plugin backends fail to setup during postUnseal (#3686) Add logic for using Auth.Period when handling auth login/renew requests (#3677) plugins/database: use context with plugins that use database/sql package (#3691) changelog++ Fix plaintext backup in transit (#3692) Database gRPC plugins (#3666)
This commit is contained in:
commit
db0006ef65
10
CHANGELOG.md
10
CHANGELOG.md
|
@ -33,6 +33,8 @@ FEATURES:
|
|||
operation that can export a given key, including all key versions and
|
||||
configuration, as well as a restore operation allowing import into another
|
||||
Vault.
|
||||
* **gRPC Database Plugins**: Database plugins now use gRPC for transport,
|
||||
allowing them to be written in other languages.
|
||||
|
||||
IMPROVEMENTS:
|
||||
|
||||
|
@ -46,6 +48,10 @@ IMPROVEMENTS:
|
|||
during database configuration. This establishes a session-wide [write
|
||||
concern](https://docs.mongodb.com/manual/reference/write-concern/) for the
|
||||
lifecycle of the mount [GH-3646]
|
||||
* mfa/okta: Filter a given email address as a login filter, allowing operation
|
||||
when login email and account email are different
|
||||
* plugins: Make Vault more resilient when unsealing when plugins are
|
||||
unavailable [GH-3686]
|
||||
* secret/pki: `allowed_domains` and `key_usage` can now be specified
|
||||
as a comma-separated string or an array of strings [GH-3642]
|
||||
* secret/ssh: Allow 4096-bit keys to be used in dynamic key method [GH-3593]
|
||||
|
@ -58,8 +64,12 @@ BUG FIXES:
|
|||
* auth/cert: Return `allowed_names` on role read [GH-3654]
|
||||
* auth/ldap: Fix incorrect control information being sent [GH-3402] [GH-3496]
|
||||
[GH-3625] [GH-3656]
|
||||
* core: Fix seal status reporting when using an autoseal
|
||||
* core: Add creation path to wrap info for a control group token
|
||||
* core: Fix potential panic that could occur using plugins when a node
|
||||
transitioned from active to standby [GH-3638]
|
||||
* core: Fix memory ballooning when a connection would connect to the cluster
|
||||
port and then go away -- redux! [GH-3680]
|
||||
* core: Replace recursive token revocation logic with depth-first logic, which
|
||||
can avoid hitting stack depth limits in extreme cases [GH-2348]
|
||||
* core/pkcs11 (enterprise): Fix panic when PKCS#11 library is not readable
|
||||
|
|
1
Makefile
1
Makefile
|
@ -84,6 +84,7 @@ proto:
|
|||
protoc -I helper/forwarding -I vault -I ../../.. helper/forwarding/types.proto --go_out=plugins=grpc:helper/forwarding
|
||||
protoc -I physical physical/types.proto --go_out=plugins=grpc:physical
|
||||
protoc -I helper/identity -I ../../.. helper/identity/types.proto --go_out=plugins=grpc:helper/identity
|
||||
protoc builtin/logical/database/dbplugin/*.proto --go_out=plugins=grpc:.
|
||||
sed -i -e 's/Idp/IDP/' -e 's/Url/URL/' -e 's/Id/ID/' -e 's/EntityId/EntityID/' -e 's/Api/API/' -e 's/Qr/QR/' -e 's/protobuf:"/sentinel:"" protobuf:"/' helper/identity/types.pb.go helper/storagepacker/types.pb.go
|
||||
sed -i -e 's/Iv/IV/' -e 's/Hmac/HMAC/' physical/types.pb.go
|
||||
|
||||
|
|
|
@ -141,7 +141,7 @@ func testAccStepMapUserIdCidr(t *testing.T, cidr string) logicaltest.TestStep {
|
|||
func testAccLogin(t *testing.T, display string) logicaltest.TestStep {
|
||||
checkTTL := func(resp *logical.Response) error {
|
||||
if resp.Auth.LeaseOptions.TTL.String() != "768h0m0s" {
|
||||
return fmt.Errorf("invalid TTL")
|
||||
return fmt.Errorf("invalid TTL: got %s", resp.Auth.LeaseOptions.TTL)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -165,7 +165,7 @@ func testAccLogin(t *testing.T, display string) logicaltest.TestStep {
|
|||
func testAccLoginAppIDInPath(t *testing.T, display string) logicaltest.TestStep {
|
||||
checkTTL := func(resp *logical.Response) error {
|
||||
if resp.Auth.LeaseOptions.TTL.String() != "768h0m0s" {
|
||||
return fmt.Errorf("invalid TTL")
|
||||
return fmt.Errorf("invalid TTL: got %s", resp.Auth.LeaseOptions.TTL)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -3,7 +3,6 @@ package approle
|
|||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
|
@ -68,20 +67,13 @@ func (b *backend) pathLoginUpdate(req *logical.Request, data *framework.FieldDat
|
|||
Policies: role.Policies,
|
||||
LeaseOptions: logical.LeaseOptions{
|
||||
Renewable: true,
|
||||
TTL: role.TokenTTL,
|
||||
},
|
||||
Alias: &logical.Alias{
|
||||
Name: role.RoleID,
|
||||
},
|
||||
}
|
||||
|
||||
// If 'Period' is set, use the value of 'Period' as the TTL.
|
||||
// Otherwise, set the normal TokenTTL.
|
||||
if role.Period > time.Duration(0) {
|
||||
auth.TTL = role.Period
|
||||
} else {
|
||||
auth.TTL = role.TokenTTL
|
||||
}
|
||||
|
||||
return &logical.Response{
|
||||
Auth: auth,
|
||||
}, nil
|
||||
|
@ -107,16 +99,12 @@ func (b *backend) pathLoginRenew(req *logical.Request, data *framework.FieldData
|
|||
return nil, fmt.Errorf("role %s does not exist during renewal", roleName)
|
||||
}
|
||||
|
||||
// If 'Period' is set on the Role, the token should never expire.
|
||||
// Replenish the TTL with 'Period's value.
|
||||
if role.Period > time.Duration(0) {
|
||||
// If 'Period' was updated after the token was issued,
|
||||
// token will bear the updated 'Period' value as its TTL.
|
||||
req.Auth.TTL = role.Period
|
||||
return &logical.Response{Auth: req.Auth}, nil
|
||||
} else {
|
||||
return framework.LeaseExtend(role.TokenTTL, role.TokenMaxTTL, b.System())(req, data)
|
||||
resp, err := framework.LeaseExtend(role.TokenTTL, role.TokenMaxTTL, b.System())(req, data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp.Auth.Period = role.Period
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
const pathLoginHelpSys = "Issue a token based on the credentials supplied"
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/rpc"
|
||||
"strings"
|
||||
|
@ -87,7 +88,7 @@ func (b *databaseBackend) getDBObj(name string) (dbplugin.Database, bool) {
|
|||
// This function creates a new db object from the stored configuration and
|
||||
// caches it in the connections map. The caller of this function needs to hold
|
||||
// the backend's write lock
|
||||
func (b *databaseBackend) createDBObj(s logical.Storage, name string) (dbplugin.Database, error) {
|
||||
func (b *databaseBackend) createDBObj(ctx context.Context, s logical.Storage, name string) (dbplugin.Database, error) {
|
||||
db, ok := b.connections[name]
|
||||
if ok {
|
||||
return db, nil
|
||||
|
@ -103,7 +104,7 @@ func (b *databaseBackend) createDBObj(s logical.Storage, name string) (dbplugin.
|
|||
return nil, err
|
||||
}
|
||||
|
||||
err = db.Initialize(config.ConnectionDetails, true)
|
||||
err = db.Initialize(ctx, config.ConnectionDetails, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -170,7 +171,8 @@ func (b *databaseBackend) clearConnection(name string) {
|
|||
|
||||
func (b *databaseBackend) closeIfShutdown(name string, err error) {
|
||||
// Plugin has shutdown, close it so next call can reconnect.
|
||||
if err == rpc.ErrShutdown {
|
||||
switch err {
|
||||
case rpc.ErrShutdown, dbplugin.ErrPluginShutdown:
|
||||
b.Lock()
|
||||
b.clearConnection(name)
|
||||
b.Unlock()
|
||||
|
|
|
@ -488,9 +488,11 @@ func TestBackend_roleCrud(t *testing.T) {
|
|||
RevocationStatements: defaultRevocationSQL,
|
||||
}
|
||||
|
||||
var actual dbplugin.Statements
|
||||
if err := mapstructure.Decode(resp.Data, &actual); err != nil {
|
||||
t.Fatal(err)
|
||||
actual := dbplugin.Statements{
|
||||
CreationStatements: resp.Data["creation_statements"].(string),
|
||||
RevocationStatements: resp.Data["revocation_statements"].(string),
|
||||
RollbackStatements: resp.Data["rollback_statements"].(string),
|
||||
RenewStatements: resp.Data["renew_statements"].(string),
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(expected, actual) {
|
||||
|
|
|
@ -1,10 +1,8 @@
|
|||
package dbplugin
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/rpc"
|
||||
"errors"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-plugin"
|
||||
"github.com/hashicorp/vault/helper/pluginutil"
|
||||
|
@ -17,11 +15,11 @@ type DatabasePluginClient struct {
|
|||
client *plugin.Client
|
||||
sync.Mutex
|
||||
|
||||
*databasePluginRPCClient
|
||||
Database
|
||||
}
|
||||
|
||||
func (dc *DatabasePluginClient) Close() error {
|
||||
err := dc.databasePluginRPCClient.Close()
|
||||
err := dc.Database.Close()
|
||||
dc.client.Kill()
|
||||
|
||||
return err
|
||||
|
@ -55,79 +53,20 @@ func newPluginClient(sys pluginutil.RunnerUtil, pluginRunner *pluginutil.PluginR
|
|||
|
||||
// We should have a database type now. This feels like a normal interface
|
||||
// implementation but is in fact over an RPC connection.
|
||||
databaseRPC := raw.(*databasePluginRPCClient)
|
||||
var db Database
|
||||
switch raw.(type) {
|
||||
case *gRPCClient:
|
||||
db = raw.(*gRPCClient)
|
||||
case *databasePluginRPCClient:
|
||||
logger.Warn("database: plugin is using deprecated net RPC transport, recompile plugin to upgrade to gRPC", "plugin", pluginRunner.Name)
|
||||
db = raw.(*databasePluginRPCClient)
|
||||
default:
|
||||
return nil, errors.New("unsupported client type")
|
||||
}
|
||||
|
||||
// Wrap RPC implimentation in DatabasePluginClient
|
||||
return &DatabasePluginClient{
|
||||
client: client,
|
||||
databasePluginRPCClient: databaseRPC,
|
||||
client: client,
|
||||
Database: db,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ---- RPC client domain ----
|
||||
|
||||
// databasePluginRPCClient implements Database and is used on the client to
|
||||
// make RPC calls to a plugin.
|
||||
type databasePluginRPCClient struct {
|
||||
client *rpc.Client
|
||||
}
|
||||
|
||||
func (dr *databasePluginRPCClient) Type() (string, error) {
|
||||
var dbType string
|
||||
err := dr.client.Call("Plugin.Type", struct{}{}, &dbType)
|
||||
|
||||
return fmt.Sprintf("plugin-%s", dbType), err
|
||||
}
|
||||
|
||||
func (dr *databasePluginRPCClient) CreateUser(statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) {
|
||||
req := CreateUserRequest{
|
||||
Statements: statements,
|
||||
UsernameConfig: usernameConfig,
|
||||
Expiration: expiration,
|
||||
}
|
||||
|
||||
var resp CreateUserResponse
|
||||
err = dr.client.Call("Plugin.CreateUser", req, &resp)
|
||||
|
||||
return resp.Username, resp.Password, err
|
||||
}
|
||||
|
||||
func (dr *databasePluginRPCClient) RenewUser(statements Statements, username string, expiration time.Time) error {
|
||||
req := RenewUserRequest{
|
||||
Statements: statements,
|
||||
Username: username,
|
||||
Expiration: expiration,
|
||||
}
|
||||
|
||||
err := dr.client.Call("Plugin.RenewUser", req, &struct{}{})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (dr *databasePluginRPCClient) RevokeUser(statements Statements, username string) error {
|
||||
req := RevokeUserRequest{
|
||||
Statements: statements,
|
||||
Username: username,
|
||||
}
|
||||
|
||||
err := dr.client.Call("Plugin.RevokeUser", req, &struct{}{})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (dr *databasePluginRPCClient) Initialize(conf map[string]interface{}, verifyConnection bool) error {
|
||||
req := InitializeRequest{
|
||||
Config: conf,
|
||||
VerifyConnection: verifyConnection,
|
||||
}
|
||||
|
||||
err := dr.client.Call("Plugin.Initialize", req, &struct{}{})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (dr *databasePluginRPCClient) Close() error {
|
||||
err := dr.client.Call("Plugin.Close", struct{}{}, &struct{}{})
|
||||
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -0,0 +1,556 @@
|
|||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||
// source: builtin/logical/database/dbplugin/database.proto
|
||||
|
||||
/*
|
||||
Package dbplugin is a generated protocol buffer package.
|
||||
|
||||
It is generated from these files:
|
||||
builtin/logical/database/dbplugin/database.proto
|
||||
|
||||
It has these top-level messages:
|
||||
InitializeRequest
|
||||
CreateUserRequest
|
||||
RenewUserRequest
|
||||
RevokeUserRequest
|
||||
Statements
|
||||
UsernameConfig
|
||||
CreateUserResponse
|
||||
TypeResponse
|
||||
Empty
|
||||
*/
|
||||
package dbplugin
|
||||
|
||||
import proto "github.com/golang/protobuf/proto"
|
||||
import fmt "fmt"
|
||||
import math "math"
|
||||
import google_protobuf "github.com/golang/protobuf/ptypes/timestamp"
|
||||
|
||||
import (
|
||||
context "golang.org/x/net/context"
|
||||
grpc "google.golang.org/grpc"
|
||||
)
|
||||
|
||||
// Reference imports to suppress errors if they are not otherwise used.
|
||||
var _ = proto.Marshal
|
||||
var _ = fmt.Errorf
|
||||
var _ = math.Inf
|
||||
|
||||
// This is a compile-time assertion to ensure that this generated file
|
||||
// is compatible with the proto package it is being compiled against.
|
||||
// A compilation error at this line likely means your copy of the
|
||||
// proto package needs to be updated.
|
||||
const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package
|
||||
|
||||
type InitializeRequest struct {
|
||||
Config []byte `protobuf:"bytes,1,opt,name=config,proto3" json:"config,omitempty"`
|
||||
VerifyConnection bool `protobuf:"varint,2,opt,name=verify_connection,json=verifyConnection" json:"verify_connection,omitempty"`
|
||||
}
|
||||
|
||||
func (m *InitializeRequest) Reset() { *m = InitializeRequest{} }
|
||||
func (m *InitializeRequest) String() string { return proto.CompactTextString(m) }
|
||||
func (*InitializeRequest) ProtoMessage() {}
|
||||
func (*InitializeRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{0} }
|
||||
|
||||
func (m *InitializeRequest) GetConfig() []byte {
|
||||
if m != nil {
|
||||
return m.Config
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *InitializeRequest) GetVerifyConnection() bool {
|
||||
if m != nil {
|
||||
return m.VerifyConnection
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type CreateUserRequest struct {
|
||||
Statements *Statements `protobuf:"bytes,1,opt,name=statements" json:"statements,omitempty"`
|
||||
UsernameConfig *UsernameConfig `protobuf:"bytes,2,opt,name=username_config,json=usernameConfig" json:"username_config,omitempty"`
|
||||
Expiration *google_protobuf.Timestamp `protobuf:"bytes,3,opt,name=expiration" json:"expiration,omitempty"`
|
||||
}
|
||||
|
||||
func (m *CreateUserRequest) Reset() { *m = CreateUserRequest{} }
|
||||
func (m *CreateUserRequest) String() string { return proto.CompactTextString(m) }
|
||||
func (*CreateUserRequest) ProtoMessage() {}
|
||||
func (*CreateUserRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{1} }
|
||||
|
||||
func (m *CreateUserRequest) GetStatements() *Statements {
|
||||
if m != nil {
|
||||
return m.Statements
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *CreateUserRequest) GetUsernameConfig() *UsernameConfig {
|
||||
if m != nil {
|
||||
return m.UsernameConfig
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *CreateUserRequest) GetExpiration() *google_protobuf.Timestamp {
|
||||
if m != nil {
|
||||
return m.Expiration
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type RenewUserRequest struct {
|
||||
Statements *Statements `protobuf:"bytes,1,opt,name=statements" json:"statements,omitempty"`
|
||||
Username string `protobuf:"bytes,2,opt,name=username" json:"username,omitempty"`
|
||||
Expiration *google_protobuf.Timestamp `protobuf:"bytes,3,opt,name=expiration" json:"expiration,omitempty"`
|
||||
}
|
||||
|
||||
func (m *RenewUserRequest) Reset() { *m = RenewUserRequest{} }
|
||||
func (m *RenewUserRequest) String() string { return proto.CompactTextString(m) }
|
||||
func (*RenewUserRequest) ProtoMessage() {}
|
||||
func (*RenewUserRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{2} }
|
||||
|
||||
func (m *RenewUserRequest) GetStatements() *Statements {
|
||||
if m != nil {
|
||||
return m.Statements
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *RenewUserRequest) GetUsername() string {
|
||||
if m != nil {
|
||||
return m.Username
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (m *RenewUserRequest) GetExpiration() *google_protobuf.Timestamp {
|
||||
if m != nil {
|
||||
return m.Expiration
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type RevokeUserRequest struct {
|
||||
Statements *Statements `protobuf:"bytes,1,opt,name=statements" json:"statements,omitempty"`
|
||||
Username string `protobuf:"bytes,2,opt,name=username" json:"username,omitempty"`
|
||||
}
|
||||
|
||||
func (m *RevokeUserRequest) Reset() { *m = RevokeUserRequest{} }
|
||||
func (m *RevokeUserRequest) String() string { return proto.CompactTextString(m) }
|
||||
func (*RevokeUserRequest) ProtoMessage() {}
|
||||
func (*RevokeUserRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{3} }
|
||||
|
||||
func (m *RevokeUserRequest) GetStatements() *Statements {
|
||||
if m != nil {
|
||||
return m.Statements
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *RevokeUserRequest) GetUsername() string {
|
||||
if m != nil {
|
||||
return m.Username
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
type Statements struct {
|
||||
CreationStatements string `protobuf:"bytes,1,opt,name=creation_statements,json=creationStatements" json:"creation_statements,omitempty"`
|
||||
RevocationStatements string `protobuf:"bytes,2,opt,name=revocation_statements,json=revocationStatements" json:"revocation_statements,omitempty"`
|
||||
RollbackStatements string `protobuf:"bytes,3,opt,name=rollback_statements,json=rollbackStatements" json:"rollback_statements,omitempty"`
|
||||
RenewStatements string `protobuf:"bytes,4,opt,name=renew_statements,json=renewStatements" json:"renew_statements,omitempty"`
|
||||
}
|
||||
|
||||
func (m *Statements) Reset() { *m = Statements{} }
|
||||
func (m *Statements) String() string { return proto.CompactTextString(m) }
|
||||
func (*Statements) ProtoMessage() {}
|
||||
func (*Statements) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{4} }
|
||||
|
||||
func (m *Statements) GetCreationStatements() string {
|
||||
if m != nil {
|
||||
return m.CreationStatements
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (m *Statements) GetRevocationStatements() string {
|
||||
if m != nil {
|
||||
return m.RevocationStatements
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (m *Statements) GetRollbackStatements() string {
|
||||
if m != nil {
|
||||
return m.RollbackStatements
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (m *Statements) GetRenewStatements() string {
|
||||
if m != nil {
|
||||
return m.RenewStatements
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
type UsernameConfig struct {
|
||||
DisplayName string `protobuf:"bytes,1,opt,name=DisplayName" json:"DisplayName,omitempty"`
|
||||
RoleName string `protobuf:"bytes,2,opt,name=RoleName" json:"RoleName,omitempty"`
|
||||
}
|
||||
|
||||
func (m *UsernameConfig) Reset() { *m = UsernameConfig{} }
|
||||
func (m *UsernameConfig) String() string { return proto.CompactTextString(m) }
|
||||
func (*UsernameConfig) ProtoMessage() {}
|
||||
func (*UsernameConfig) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{5} }
|
||||
|
||||
func (m *UsernameConfig) GetDisplayName() string {
|
||||
if m != nil {
|
||||
return m.DisplayName
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (m *UsernameConfig) GetRoleName() string {
|
||||
if m != nil {
|
||||
return m.RoleName
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
type CreateUserResponse struct {
|
||||
Username string `protobuf:"bytes,1,opt,name=username" json:"username,omitempty"`
|
||||
Password string `protobuf:"bytes,2,opt,name=password" json:"password,omitempty"`
|
||||
}
|
||||
|
||||
func (m *CreateUserResponse) Reset() { *m = CreateUserResponse{} }
|
||||
func (m *CreateUserResponse) String() string { return proto.CompactTextString(m) }
|
||||
func (*CreateUserResponse) ProtoMessage() {}
|
||||
func (*CreateUserResponse) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{6} }
|
||||
|
||||
func (m *CreateUserResponse) GetUsername() string {
|
||||
if m != nil {
|
||||
return m.Username
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (m *CreateUserResponse) GetPassword() string {
|
||||
if m != nil {
|
||||
return m.Password
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
type TypeResponse struct {
|
||||
Type string `protobuf:"bytes,1,opt,name=type" json:"type,omitempty"`
|
||||
}
|
||||
|
||||
func (m *TypeResponse) Reset() { *m = TypeResponse{} }
|
||||
func (m *TypeResponse) String() string { return proto.CompactTextString(m) }
|
||||
func (*TypeResponse) ProtoMessage() {}
|
||||
func (*TypeResponse) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{7} }
|
||||
|
||||
func (m *TypeResponse) GetType() string {
|
||||
if m != nil {
|
||||
return m.Type
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
type Empty struct {
|
||||
}
|
||||
|
||||
func (m *Empty) Reset() { *m = Empty{} }
|
||||
func (m *Empty) String() string { return proto.CompactTextString(m) }
|
||||
func (*Empty) ProtoMessage() {}
|
||||
func (*Empty) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{8} }
|
||||
|
||||
func init() {
|
||||
proto.RegisterType((*InitializeRequest)(nil), "dbplugin.InitializeRequest")
|
||||
proto.RegisterType((*CreateUserRequest)(nil), "dbplugin.CreateUserRequest")
|
||||
proto.RegisterType((*RenewUserRequest)(nil), "dbplugin.RenewUserRequest")
|
||||
proto.RegisterType((*RevokeUserRequest)(nil), "dbplugin.RevokeUserRequest")
|
||||
proto.RegisterType((*Statements)(nil), "dbplugin.Statements")
|
||||
proto.RegisterType((*UsernameConfig)(nil), "dbplugin.UsernameConfig")
|
||||
proto.RegisterType((*CreateUserResponse)(nil), "dbplugin.CreateUserResponse")
|
||||
proto.RegisterType((*TypeResponse)(nil), "dbplugin.TypeResponse")
|
||||
proto.RegisterType((*Empty)(nil), "dbplugin.Empty")
|
||||
}
|
||||
|
||||
// Reference imports to suppress errors if they are not otherwise used.
|
||||
var _ context.Context
|
||||
var _ grpc.ClientConn
|
||||
|
||||
// This is a compile-time assertion to ensure that this generated file
|
||||
// is compatible with the grpc package it is being compiled against.
|
||||
const _ = grpc.SupportPackageIsVersion4
|
||||
|
||||
// Client API for Database service
|
||||
|
||||
type DatabaseClient interface {
|
||||
Type(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*TypeResponse, error)
|
||||
CreateUser(ctx context.Context, in *CreateUserRequest, opts ...grpc.CallOption) (*CreateUserResponse, error)
|
||||
RenewUser(ctx context.Context, in *RenewUserRequest, opts ...grpc.CallOption) (*Empty, error)
|
||||
RevokeUser(ctx context.Context, in *RevokeUserRequest, opts ...grpc.CallOption) (*Empty, error)
|
||||
Initialize(ctx context.Context, in *InitializeRequest, opts ...grpc.CallOption) (*Empty, error)
|
||||
Close(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*Empty, error)
|
||||
}
|
||||
|
||||
type databaseClient struct {
|
||||
cc *grpc.ClientConn
|
||||
}
|
||||
|
||||
func NewDatabaseClient(cc *grpc.ClientConn) DatabaseClient {
|
||||
return &databaseClient{cc}
|
||||
}
|
||||
|
||||
func (c *databaseClient) Type(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*TypeResponse, error) {
|
||||
out := new(TypeResponse)
|
||||
err := grpc.Invoke(ctx, "/dbplugin.Database/Type", in, out, c.cc, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *databaseClient) CreateUser(ctx context.Context, in *CreateUserRequest, opts ...grpc.CallOption) (*CreateUserResponse, error) {
|
||||
out := new(CreateUserResponse)
|
||||
err := grpc.Invoke(ctx, "/dbplugin.Database/CreateUser", in, out, c.cc, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *databaseClient) RenewUser(ctx context.Context, in *RenewUserRequest, opts ...grpc.CallOption) (*Empty, error) {
|
||||
out := new(Empty)
|
||||
err := grpc.Invoke(ctx, "/dbplugin.Database/RenewUser", in, out, c.cc, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *databaseClient) RevokeUser(ctx context.Context, in *RevokeUserRequest, opts ...grpc.CallOption) (*Empty, error) {
|
||||
out := new(Empty)
|
||||
err := grpc.Invoke(ctx, "/dbplugin.Database/RevokeUser", in, out, c.cc, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *databaseClient) Initialize(ctx context.Context, in *InitializeRequest, opts ...grpc.CallOption) (*Empty, error) {
|
||||
out := new(Empty)
|
||||
err := grpc.Invoke(ctx, "/dbplugin.Database/Initialize", in, out, c.cc, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *databaseClient) Close(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*Empty, error) {
|
||||
out := new(Empty)
|
||||
err := grpc.Invoke(ctx, "/dbplugin.Database/Close", in, out, c.cc, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// Server API for Database service
|
||||
|
||||
type DatabaseServer interface {
|
||||
Type(context.Context, *Empty) (*TypeResponse, error)
|
||||
CreateUser(context.Context, *CreateUserRequest) (*CreateUserResponse, error)
|
||||
RenewUser(context.Context, *RenewUserRequest) (*Empty, error)
|
||||
RevokeUser(context.Context, *RevokeUserRequest) (*Empty, error)
|
||||
Initialize(context.Context, *InitializeRequest) (*Empty, error)
|
||||
Close(context.Context, *Empty) (*Empty, error)
|
||||
}
|
||||
|
||||
func RegisterDatabaseServer(s *grpc.Server, srv DatabaseServer) {
|
||||
s.RegisterService(&_Database_serviceDesc, srv)
|
||||
}
|
||||
|
||||
func _Database_Type_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(Empty)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(DatabaseServer).Type(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/dbplugin.Database/Type",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(DatabaseServer).Type(ctx, req.(*Empty))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _Database_CreateUser_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(CreateUserRequest)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(DatabaseServer).CreateUser(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/dbplugin.Database/CreateUser",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(DatabaseServer).CreateUser(ctx, req.(*CreateUserRequest))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _Database_RenewUser_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(RenewUserRequest)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(DatabaseServer).RenewUser(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/dbplugin.Database/RenewUser",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(DatabaseServer).RenewUser(ctx, req.(*RenewUserRequest))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _Database_RevokeUser_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(RevokeUserRequest)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(DatabaseServer).RevokeUser(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/dbplugin.Database/RevokeUser",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(DatabaseServer).RevokeUser(ctx, req.(*RevokeUserRequest))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _Database_Initialize_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(InitializeRequest)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(DatabaseServer).Initialize(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/dbplugin.Database/Initialize",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(DatabaseServer).Initialize(ctx, req.(*InitializeRequest))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _Database_Close_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(Empty)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(DatabaseServer).Close(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/dbplugin.Database/Close",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(DatabaseServer).Close(ctx, req.(*Empty))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
var _Database_serviceDesc = grpc.ServiceDesc{
|
||||
ServiceName: "dbplugin.Database",
|
||||
HandlerType: (*DatabaseServer)(nil),
|
||||
Methods: []grpc.MethodDesc{
|
||||
{
|
||||
MethodName: "Type",
|
||||
Handler: _Database_Type_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "CreateUser",
|
||||
Handler: _Database_CreateUser_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "RenewUser",
|
||||
Handler: _Database_RenewUser_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "RevokeUser",
|
||||
Handler: _Database_RevokeUser_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "Initialize",
|
||||
Handler: _Database_Initialize_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "Close",
|
||||
Handler: _Database_Close_Handler,
|
||||
},
|
||||
},
|
||||
Streams: []grpc.StreamDesc{},
|
||||
Metadata: "builtin/logical/database/dbplugin/database.proto",
|
||||
}
|
||||
|
||||
func init() { proto.RegisterFile("builtin/logical/database/dbplugin/database.proto", fileDescriptor0) }
|
||||
|
||||
var fileDescriptor0 = []byte{
|
||||
// 548 bytes of a gzipped FileDescriptorProto
|
||||
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xb4, 0x54, 0xcf, 0x6e, 0xd3, 0x4e,
|
||||
0x10, 0x96, 0xdb, 0xb4, 0xbf, 0x64, 0x5a, 0x35, 0xc9, 0xfe, 0x4a, 0x15, 0x19, 0x24, 0x22, 0x9f,
|
||||
0x5a, 0x21, 0xd9, 0xa8, 0xe5, 0x80, 0xb8, 0xa1, 0x14, 0x21, 0x24, 0x94, 0x83, 0x69, 0x25, 0x6e,
|
||||
0xd1, 0xda, 0x99, 0x44, 0xab, 0x3a, 0xbb, 0xc6, 0xbb, 0x4e, 0x09, 0x4f, 0xc3, 0xe3, 0x70, 0xe2,
|
||||
0x1d, 0x78, 0x13, 0xe4, 0x75, 0xd6, 0xbb, 0xf9, 0x73, 0xab, 0xb8, 0x79, 0x66, 0xbe, 0x6f, 0xf6,
|
||||
0xf3, 0xb7, 0x33, 0x0b, 0xaf, 0x93, 0x92, 0x65, 0x8a, 0xf1, 0x28, 0x13, 0x73, 0x96, 0xd2, 0x2c,
|
||||
0x9a, 0x52, 0x45, 0x13, 0x2a, 0x31, 0x9a, 0x26, 0x79, 0x56, 0xce, 0x19, 0x6f, 0x32, 0x61, 0x5e,
|
||||
0x08, 0x25, 0x48, 0xdb, 0x14, 0xfc, 0x97, 0x73, 0x21, 0xe6, 0x19, 0x46, 0x3a, 0x9f, 0x94, 0xb3,
|
||||
0x48, 0xb1, 0x05, 0x4a, 0x45, 0x17, 0x79, 0x0d, 0x0d, 0xbe, 0x42, 0xff, 0x13, 0x67, 0x8a, 0xd1,
|
||||
0x8c, 0xfd, 0xc0, 0x18, 0xbf, 0x95, 0x28, 0x15, 0xb9, 0x80, 0xe3, 0x54, 0xf0, 0x19, 0x9b, 0x0f,
|
||||
0xbc, 0xa1, 0x77, 0x79, 0x1a, 0xaf, 0x23, 0xf2, 0x0a, 0xfa, 0x4b, 0x2c, 0xd8, 0x6c, 0x35, 0x49,
|
||||
0x05, 0xe7, 0x98, 0x2a, 0x26, 0xf8, 0xe0, 0x60, 0xe8, 0x5d, 0xb6, 0xe3, 0x5e, 0x5d, 0x18, 0x35,
|
||||
0xf9, 0xe0, 0x97, 0x07, 0xfd, 0x51, 0x81, 0x54, 0xe1, 0xbd, 0xc4, 0xc2, 0xb4, 0x7e, 0x03, 0x20,
|
||||
0x15, 0x55, 0xb8, 0x40, 0xae, 0xa4, 0x6e, 0x7f, 0x72, 0x7d, 0x1e, 0x1a, 0xbd, 0xe1, 0x97, 0xa6,
|
||||
0x16, 0x3b, 0x38, 0xf2, 0x1e, 0xba, 0xa5, 0xc4, 0x82, 0xd3, 0x05, 0x4e, 0xd6, 0xca, 0x0e, 0x34,
|
||||
0x75, 0x60, 0xa9, 0xf7, 0x6b, 0xc0, 0x48, 0xd7, 0xe3, 0xb3, 0x72, 0x23, 0x26, 0xef, 0x00, 0xf0,
|
||||
0x7b, 0xce, 0x0a, 0xaa, 0x45, 0x1f, 0x6a, 0xb6, 0x1f, 0xd6, 0xf6, 0x84, 0xc6, 0x9e, 0xf0, 0xce,
|
||||
0xd8, 0x13, 0x3b, 0xe8, 0xe0, 0xa7, 0x07, 0xbd, 0x18, 0x39, 0x3e, 0x3e, 0xfd, 0x4f, 0x7c, 0x68,
|
||||
0x1b, 0x61, 0xfa, 0x17, 0x3a, 0x71, 0x13, 0x3f, 0x49, 0x22, 0x42, 0x3f, 0xc6, 0xa5, 0x78, 0xc0,
|
||||
0x7f, 0x2a, 0x31, 0xf8, 0xed, 0x01, 0x58, 0x1a, 0x89, 0xe0, 0xff, 0xb4, 0xba, 0x62, 0x26, 0xf8,
|
||||
0x64, 0xeb, 0xa4, 0x4e, 0x4c, 0x4c, 0xc9, 0x21, 0xdc, 0xc0, 0xb3, 0x02, 0x97, 0x22, 0xdd, 0xa1,
|
||||
0xd4, 0x07, 0x9d, 0xdb, 0xe2, 0xe6, 0x29, 0x85, 0xc8, 0xb2, 0x84, 0xa6, 0x0f, 0x2e, 0xe5, 0xb0,
|
||||
0x3e, 0xc5, 0x94, 0x1c, 0xc2, 0x15, 0xf4, 0x8a, 0xea, 0xba, 0x5c, 0x74, 0x4b, 0xa3, 0xbb, 0x3a,
|
||||
0x6f, 0xa1, 0xc1, 0x18, 0xce, 0x36, 0x07, 0x87, 0x0c, 0xe1, 0xe4, 0x96, 0xc9, 0x3c, 0xa3, 0xab,
|
||||
0x71, 0xe5, 0x40, 0xfd, 0x2f, 0x6e, 0xaa, 0x32, 0x28, 0x16, 0x19, 0x8e, 0x1d, 0x83, 0x4c, 0x1c,
|
||||
0x7c, 0x06, 0xe2, 0x0e, 0xbd, 0xcc, 0x05, 0x97, 0xb8, 0x61, 0xa9, 0xb7, 0x75, 0xeb, 0x3e, 0xb4,
|
||||
0x73, 0x2a, 0xe5, 0xa3, 0x28, 0xa6, 0xa6, 0x9b, 0x89, 0x83, 0x00, 0x4e, 0xef, 0x56, 0x39, 0x36,
|
||||
0x7d, 0x08, 0xb4, 0xd4, 0x2a, 0x37, 0x3d, 0xf4, 0x77, 0xf0, 0x1f, 0x1c, 0x7d, 0x58, 0xe4, 0x6a,
|
||||
0x75, 0xfd, 0xe7, 0x00, 0xda, 0xb7, 0xeb, 0x87, 0x80, 0x44, 0xd0, 0xaa, 0x98, 0xa4, 0x6b, 0xaf,
|
||||
0x5b, 0xa3, 0xfc, 0x0b, 0x9b, 0xd8, 0x68, 0xfd, 0x11, 0xc0, 0x0a, 0x27, 0xcf, 0x2d, 0x6a, 0x67,
|
||||
0x87, 0xfd, 0x17, 0xfb, 0x8b, 0xeb, 0x46, 0x6f, 0xa1, 0xd3, 0xec, 0x0a, 0xf1, 0x2d, 0x74, 0x7b,
|
||||
0x81, 0xfc, 0x6d, 0x69, 0xd5, 0xfc, 0xdb, 0x19, 0x76, 0x25, 0xec, 0x4c, 0xf6, 0x5e, 0xae, 0x7d,
|
||||
0xc7, 0x5c, 0xee, 0xce, 0xeb, 0xb6, 0xcb, 0xbd, 0x82, 0xa3, 0x51, 0x26, 0xe4, 0x1e, 0xb3, 0xb6,
|
||||
0x13, 0xc9, 0xb1, 0x5e, 0xc3, 0x9b, 0xbf, 0x01, 0x00, 0x00, 0xff, 0xff, 0x8c, 0x55, 0x84, 0x56,
|
||||
0x94, 0x05, 0x00, 0x00,
|
||||
}
|
|
@ -0,0 +1,58 @@
|
|||
syntax = "proto3";
|
||||
package dbplugin;
|
||||
|
||||
import "google/protobuf/timestamp.proto";
|
||||
|
||||
message InitializeRequest {
|
||||
bytes config = 1;
|
||||
bool verify_connection = 2;
|
||||
}
|
||||
|
||||
message CreateUserRequest {
|
||||
Statements statements = 1;
|
||||
UsernameConfig username_config = 2;
|
||||
google.protobuf.Timestamp expiration = 3;
|
||||
}
|
||||
|
||||
message RenewUserRequest {
|
||||
Statements statements = 1;
|
||||
string username = 2;
|
||||
google.protobuf.Timestamp expiration = 3;
|
||||
}
|
||||
|
||||
message RevokeUserRequest {
|
||||
Statements statements = 1;
|
||||
string username = 2;
|
||||
}
|
||||
|
||||
message Statements {
|
||||
string creation_statements = 1;
|
||||
string revocation_statements = 2;
|
||||
string rollback_statements = 3;
|
||||
string renew_statements = 4;
|
||||
}
|
||||
|
||||
message UsernameConfig {
|
||||
string DisplayName = 1;
|
||||
string RoleName = 2;
|
||||
}
|
||||
|
||||
message CreateUserResponse {
|
||||
string username = 1;
|
||||
string password = 2;
|
||||
}
|
||||
|
||||
message TypeResponse {
|
||||
string type = 1;
|
||||
}
|
||||
|
||||
message Empty {}
|
||||
|
||||
service Database {
|
||||
rpc Type(Empty) returns (TypeResponse);
|
||||
rpc CreateUser(CreateUserRequest) returns (CreateUserResponse);
|
||||
rpc RenewUser(RenewUserRequest) returns (Empty);
|
||||
rpc RevokeUser(RevokeUserRequest) returns (Empty);
|
||||
rpc Initialize(InitializeRequest) returns (Empty);
|
||||
rpc Close(Empty) returns (Empty);
|
||||
}
|
|
@ -1,6 +1,7 @@
|
|||
package dbplugin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
metrics "github.com/armon/go-metrics"
|
||||
|
@ -15,55 +16,56 @@ type databaseTracingMiddleware struct {
|
|||
next Database
|
||||
logger log.Logger
|
||||
|
||||
typeStr string
|
||||
typeStr string
|
||||
transport string
|
||||
}
|
||||
|
||||
func (mw *databaseTracingMiddleware) Type() (string, error) {
|
||||
return mw.next.Type()
|
||||
}
|
||||
|
||||
func (mw *databaseTracingMiddleware) CreateUser(statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) {
|
||||
func (mw *databaseTracingMiddleware) CreateUser(ctx context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) {
|
||||
defer func(then time.Time) {
|
||||
mw.logger.Trace("database", "operation", "CreateUser", "status", "finished", "type", mw.typeStr, "err", err, "took", time.Since(then))
|
||||
mw.logger.Trace("database", "operation", "CreateUser", "status", "finished", "type", mw.typeStr, "transport", mw.transport, "err", err, "took", time.Since(then))
|
||||
}(time.Now())
|
||||
|
||||
mw.logger.Trace("database", "operation", "CreateUser", "status", "started", "type", mw.typeStr)
|
||||
return mw.next.CreateUser(statements, usernameConfig, expiration)
|
||||
mw.logger.Trace("database", "operation", "CreateUser", "status", "started", "type", mw.typeStr, "transport", mw.transport)
|
||||
return mw.next.CreateUser(ctx, statements, usernameConfig, expiration)
|
||||
}
|
||||
|
||||
func (mw *databaseTracingMiddleware) RenewUser(statements Statements, username string, expiration time.Time) (err error) {
|
||||
func (mw *databaseTracingMiddleware) RenewUser(ctx context.Context, statements Statements, username string, expiration time.Time) (err error) {
|
||||
defer func(then time.Time) {
|
||||
mw.logger.Trace("database", "operation", "RenewUser", "status", "finished", "type", mw.typeStr, "err", err, "took", time.Since(then))
|
||||
mw.logger.Trace("database", "operation", "RenewUser", "status", "finished", "type", mw.typeStr, "transport", mw.transport, "err", err, "took", time.Since(then))
|
||||
}(time.Now())
|
||||
|
||||
mw.logger.Trace("database", "operation", "RenewUser", "status", "started", mw.typeStr)
|
||||
return mw.next.RenewUser(statements, username, expiration)
|
||||
mw.logger.Trace("database", "operation", "RenewUser", "status", "started", mw.typeStr, "transport", mw.transport)
|
||||
return mw.next.RenewUser(ctx, statements, username, expiration)
|
||||
}
|
||||
|
||||
func (mw *databaseTracingMiddleware) RevokeUser(statements Statements, username string) (err error) {
|
||||
func (mw *databaseTracingMiddleware) RevokeUser(ctx context.Context, statements Statements, username string) (err error) {
|
||||
defer func(then time.Time) {
|
||||
mw.logger.Trace("database", "operation", "RevokeUser", "status", "finished", "type", mw.typeStr, "err", err, "took", time.Since(then))
|
||||
mw.logger.Trace("database", "operation", "RevokeUser", "status", "finished", "type", mw.typeStr, "transport", mw.transport, "err", err, "took", time.Since(then))
|
||||
}(time.Now())
|
||||
|
||||
mw.logger.Trace("database", "operation", "RevokeUser", "status", "started", "type", mw.typeStr)
|
||||
return mw.next.RevokeUser(statements, username)
|
||||
mw.logger.Trace("database", "operation", "RevokeUser", "status", "started", "type", mw.typeStr, "transport", mw.transport)
|
||||
return mw.next.RevokeUser(ctx, statements, username)
|
||||
}
|
||||
|
||||
func (mw *databaseTracingMiddleware) Initialize(conf map[string]interface{}, verifyConnection bool) (err error) {
|
||||
func (mw *databaseTracingMiddleware) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (err error) {
|
||||
defer func(then time.Time) {
|
||||
mw.logger.Trace("database", "operation", "Initialize", "status", "finished", "type", mw.typeStr, "verify", verifyConnection, "err", err, "took", time.Since(then))
|
||||
mw.logger.Trace("database", "operation", "Initialize", "status", "finished", "type", mw.typeStr, "transport", mw.transport, "verify", verifyConnection, "err", err, "took", time.Since(then))
|
||||
}(time.Now())
|
||||
|
||||
mw.logger.Trace("database", "operation", "Initialize", "status", "started", "type", mw.typeStr)
|
||||
return mw.next.Initialize(conf, verifyConnection)
|
||||
mw.logger.Trace("database", "operation", "Initialize", "status", "started", "type", mw.typeStr, "transport", mw.transport)
|
||||
return mw.next.Initialize(ctx, conf, verifyConnection)
|
||||
}
|
||||
|
||||
func (mw *databaseTracingMiddleware) Close() (err error) {
|
||||
defer func(then time.Time) {
|
||||
mw.logger.Trace("database", "operation", "Close", "status", "finished", "type", mw.typeStr, "err", err, "took", time.Since(then))
|
||||
mw.logger.Trace("database", "operation", "Close", "status", "finished", "type", mw.typeStr, "transport", mw.transport, "err", err, "took", time.Since(then))
|
||||
}(time.Now())
|
||||
|
||||
mw.logger.Trace("database", "operation", "Close", "status", "started", "type", mw.typeStr)
|
||||
mw.logger.Trace("database", "operation", "Close", "status", "started", "type", mw.typeStr, "transport", mw.transport)
|
||||
return mw.next.Close()
|
||||
}
|
||||
|
||||
|
@ -81,7 +83,7 @@ func (mw *databaseMetricsMiddleware) Type() (string, error) {
|
|||
return mw.next.Type()
|
||||
}
|
||||
|
||||
func (mw *databaseMetricsMiddleware) CreateUser(statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) {
|
||||
func (mw *databaseMetricsMiddleware) CreateUser(ctx context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) {
|
||||
defer func(now time.Time) {
|
||||
metrics.MeasureSince([]string{"database", "CreateUser"}, now)
|
||||
metrics.MeasureSince([]string{"database", mw.typeStr, "CreateUser"}, now)
|
||||
|
@ -94,10 +96,10 @@ func (mw *databaseMetricsMiddleware) CreateUser(statements Statements, usernameC
|
|||
|
||||
metrics.IncrCounter([]string{"database", "CreateUser"}, 1)
|
||||
metrics.IncrCounter([]string{"database", mw.typeStr, "CreateUser"}, 1)
|
||||
return mw.next.CreateUser(statements, usernameConfig, expiration)
|
||||
return mw.next.CreateUser(ctx, statements, usernameConfig, expiration)
|
||||
}
|
||||
|
||||
func (mw *databaseMetricsMiddleware) RenewUser(statements Statements, username string, expiration time.Time) (err error) {
|
||||
func (mw *databaseMetricsMiddleware) RenewUser(ctx context.Context, statements Statements, username string, expiration time.Time) (err error) {
|
||||
defer func(now time.Time) {
|
||||
metrics.MeasureSince([]string{"database", "RenewUser"}, now)
|
||||
metrics.MeasureSince([]string{"database", mw.typeStr, "RenewUser"}, now)
|
||||
|
@ -110,10 +112,10 @@ func (mw *databaseMetricsMiddleware) RenewUser(statements Statements, username s
|
|||
|
||||
metrics.IncrCounter([]string{"database", "RenewUser"}, 1)
|
||||
metrics.IncrCounter([]string{"database", mw.typeStr, "RenewUser"}, 1)
|
||||
return mw.next.RenewUser(statements, username, expiration)
|
||||
return mw.next.RenewUser(ctx, statements, username, expiration)
|
||||
}
|
||||
|
||||
func (mw *databaseMetricsMiddleware) RevokeUser(statements Statements, username string) (err error) {
|
||||
func (mw *databaseMetricsMiddleware) RevokeUser(ctx context.Context, statements Statements, username string) (err error) {
|
||||
defer func(now time.Time) {
|
||||
metrics.MeasureSince([]string{"database", "RevokeUser"}, now)
|
||||
metrics.MeasureSince([]string{"database", mw.typeStr, "RevokeUser"}, now)
|
||||
|
@ -126,10 +128,10 @@ func (mw *databaseMetricsMiddleware) RevokeUser(statements Statements, username
|
|||
|
||||
metrics.IncrCounter([]string{"database", "RevokeUser"}, 1)
|
||||
metrics.IncrCounter([]string{"database", mw.typeStr, "RevokeUser"}, 1)
|
||||
return mw.next.RevokeUser(statements, username)
|
||||
return mw.next.RevokeUser(ctx, statements, username)
|
||||
}
|
||||
|
||||
func (mw *databaseMetricsMiddleware) Initialize(conf map[string]interface{}, verifyConnection bool) (err error) {
|
||||
func (mw *databaseMetricsMiddleware) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (err error) {
|
||||
defer func(now time.Time) {
|
||||
metrics.MeasureSince([]string{"database", "Initialize"}, now)
|
||||
metrics.MeasureSince([]string{"database", mw.typeStr, "Initialize"}, now)
|
||||
|
@ -142,7 +144,7 @@ func (mw *databaseMetricsMiddleware) Initialize(conf map[string]interface{}, ver
|
|||
|
||||
metrics.IncrCounter([]string{"database", "Initialize"}, 1)
|
||||
metrics.IncrCounter([]string{"database", mw.typeStr, "Initialize"}, 1)
|
||||
return mw.next.Initialize(conf, verifyConnection)
|
||||
return mw.next.Initialize(ctx, conf, verifyConnection)
|
||||
}
|
||||
|
||||
func (mw *databaseMetricsMiddleware) Close() (err error) {
|
||||
|
|
|
@ -0,0 +1,198 @@
|
|||
package dbplugin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/connectivity"
|
||||
|
||||
"github.com/golang/protobuf/ptypes"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrPluginShutdown = errors.New("plugin shutdown")
|
||||
)
|
||||
|
||||
// ---- gRPC Server domain ----
|
||||
|
||||
type gRPCServer struct {
|
||||
impl Database
|
||||
}
|
||||
|
||||
func (s *gRPCServer) Type(context.Context, *Empty) (*TypeResponse, error) {
|
||||
t, err := s.impl.Type()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &TypeResponse{
|
||||
Type: t,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *gRPCServer) CreateUser(ctx context.Context, req *CreateUserRequest) (*CreateUserResponse, error) {
|
||||
e, err := ptypes.Timestamp(req.Expiration)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
u, p, err := s.impl.CreateUser(ctx, *req.Statements, *req.UsernameConfig, e)
|
||||
|
||||
return &CreateUserResponse{
|
||||
Username: u,
|
||||
Password: p,
|
||||
}, err
|
||||
}
|
||||
|
||||
func (s *gRPCServer) RenewUser(ctx context.Context, req *RenewUserRequest) (*Empty, error) {
|
||||
e, err := ptypes.Timestamp(req.Expiration)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = s.impl.RenewUser(ctx, *req.Statements, req.Username, e)
|
||||
return &Empty{}, err
|
||||
}
|
||||
|
||||
func (s *gRPCServer) RevokeUser(ctx context.Context, req *RevokeUserRequest) (*Empty, error) {
|
||||
err := s.impl.RevokeUser(ctx, *req.Statements, req.Username)
|
||||
return &Empty{}, err
|
||||
}
|
||||
|
||||
func (s *gRPCServer) Initialize(ctx context.Context, req *InitializeRequest) (*Empty, error) {
|
||||
config := map[string]interface{}{}
|
||||
|
||||
err := json.Unmarshal(req.Config, &config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = s.impl.Initialize(ctx, config, req.VerifyConnection)
|
||||
return &Empty{}, err
|
||||
}
|
||||
|
||||
func (s *gRPCServer) Close(_ context.Context, _ *Empty) (*Empty, error) {
|
||||
s.impl.Close()
|
||||
return &Empty{}, nil
|
||||
}
|
||||
|
||||
// ---- gRPC client domain ----
|
||||
|
||||
type gRPCClient struct {
|
||||
client DatabaseClient
|
||||
clientConn *grpc.ClientConn
|
||||
}
|
||||
|
||||
func (c gRPCClient) Type() (string, error) {
|
||||
// If the plugin has already shutdown, this will hang forever so we give it
|
||||
// a one second timeout.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
|
||||
switch c.clientConn.GetState() {
|
||||
case connectivity.Ready, connectivity.Idle:
|
||||
default:
|
||||
return "", ErrPluginShutdown
|
||||
}
|
||||
resp, err := c.client.Type(ctx, &Empty{})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return resp.Type, err
|
||||
}
|
||||
|
||||
func (c gRPCClient) CreateUser(ctx context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) {
|
||||
t, err := ptypes.TimestampProto(expiration)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
switch c.clientConn.GetState() {
|
||||
case connectivity.Ready, connectivity.Idle:
|
||||
default:
|
||||
return "", "", ErrPluginShutdown
|
||||
}
|
||||
|
||||
resp, err := c.client.CreateUser(ctx, &CreateUserRequest{
|
||||
Statements: &statements,
|
||||
UsernameConfig: &usernameConfig,
|
||||
Expiration: t,
|
||||
})
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
return resp.Username, resp.Password, err
|
||||
}
|
||||
|
||||
func (c *gRPCClient) RenewUser(ctx context.Context, statements Statements, username string, expiration time.Time) error {
|
||||
t, err := ptypes.TimestampProto(expiration)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch c.clientConn.GetState() {
|
||||
case connectivity.Ready, connectivity.Idle:
|
||||
default:
|
||||
return ErrPluginShutdown
|
||||
}
|
||||
|
||||
_, err = c.client.RenewUser(ctx, &RenewUserRequest{
|
||||
Statements: &statements,
|
||||
Username: username,
|
||||
Expiration: t,
|
||||
})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *gRPCClient) RevokeUser(ctx context.Context, statements Statements, username string) error {
|
||||
switch c.clientConn.GetState() {
|
||||
case connectivity.Ready, connectivity.Idle:
|
||||
default:
|
||||
return ErrPluginShutdown
|
||||
}
|
||||
_, err := c.client.RevokeUser(ctx, &RevokeUserRequest{
|
||||
Statements: &statements,
|
||||
Username: username,
|
||||
})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *gRPCClient) Initialize(ctx context.Context, config map[string]interface{}, verifyConnection bool) error {
|
||||
configRaw, err := json.Marshal(config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch c.clientConn.GetState() {
|
||||
case connectivity.Ready, connectivity.Idle:
|
||||
default:
|
||||
return ErrPluginShutdown
|
||||
}
|
||||
|
||||
_, err = c.client.Initialize(ctx, &InitializeRequest{
|
||||
Config: configRaw,
|
||||
VerifyConnection: verifyConnection,
|
||||
})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *gRPCClient) Close() error {
|
||||
// If the plugin has already shutdown, this will hang forever so we give it
|
||||
// a one second timeout.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
switch c.clientConn.GetState() {
|
||||
case connectivity.Ready, connectivity.Idle:
|
||||
_, err := c.client.Close(ctx, &Empty{})
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,139 @@
|
|||
package dbplugin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/rpc"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ---- RPC server domain ----
|
||||
|
||||
// databasePluginRPCServer implements an RPC version of Database and is run
|
||||
// inside a plugin. It wraps an underlying implementation of Database.
|
||||
type databasePluginRPCServer struct {
|
||||
impl Database
|
||||
}
|
||||
|
||||
func (ds *databasePluginRPCServer) Type(_ struct{}, resp *string) error {
|
||||
var err error
|
||||
*resp, err = ds.impl.Type()
|
||||
return err
|
||||
}
|
||||
|
||||
func (ds *databasePluginRPCServer) CreateUser(args *CreateUserRequestRPC, resp *CreateUserResponse) error {
|
||||
var err error
|
||||
resp.Username, resp.Password, err = ds.impl.CreateUser(context.Background(), args.Statements, args.UsernameConfig, args.Expiration)
|
||||
return err
|
||||
}
|
||||
|
||||
func (ds *databasePluginRPCServer) RenewUser(args *RenewUserRequestRPC, _ *struct{}) error {
|
||||
err := ds.impl.RenewUser(context.Background(), args.Statements, args.Username, args.Expiration)
|
||||
return err
|
||||
}
|
||||
|
||||
func (ds *databasePluginRPCServer) RevokeUser(args *RevokeUserRequestRPC, _ *struct{}) error {
|
||||
err := ds.impl.RevokeUser(context.Background(), args.Statements, args.Username)
|
||||
return err
|
||||
}
|
||||
|
||||
func (ds *databasePluginRPCServer) Initialize(args *InitializeRequestRPC, _ *struct{}) error {
|
||||
err := ds.impl.Initialize(context.Background(), args.Config, args.VerifyConnection)
|
||||
return err
|
||||
}
|
||||
|
||||
func (ds *databasePluginRPCServer) Close(_ struct{}, _ *struct{}) error {
|
||||
ds.impl.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
// ---- RPC client domain ----
|
||||
// databasePluginRPCClient implements Database and is used on the client to
|
||||
// make RPC calls to a plugin.
|
||||
type databasePluginRPCClient struct {
|
||||
client *rpc.Client
|
||||
}
|
||||
|
||||
func (dr *databasePluginRPCClient) Type() (string, error) {
|
||||
var dbType string
|
||||
err := dr.client.Call("Plugin.Type", struct{}{}, &dbType)
|
||||
|
||||
return fmt.Sprintf("plugin-%s", dbType), err
|
||||
}
|
||||
|
||||
func (dr *databasePluginRPCClient) CreateUser(_ context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) {
|
||||
req := CreateUserRequestRPC{
|
||||
Statements: statements,
|
||||
UsernameConfig: usernameConfig,
|
||||
Expiration: expiration,
|
||||
}
|
||||
|
||||
var resp CreateUserResponse
|
||||
err = dr.client.Call("Plugin.CreateUser", req, &resp)
|
||||
|
||||
return resp.Username, resp.Password, err
|
||||
}
|
||||
|
||||
func (dr *databasePluginRPCClient) RenewUser(_ context.Context, statements Statements, username string, expiration time.Time) error {
|
||||
req := RenewUserRequestRPC{
|
||||
Statements: statements,
|
||||
Username: username,
|
||||
Expiration: expiration,
|
||||
}
|
||||
|
||||
err := dr.client.Call("Plugin.RenewUser", req, &struct{}{})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (dr *databasePluginRPCClient) RevokeUser(_ context.Context, statements Statements, username string) error {
|
||||
req := RevokeUserRequestRPC{
|
||||
Statements: statements,
|
||||
Username: username,
|
||||
}
|
||||
|
||||
err := dr.client.Call("Plugin.RevokeUser", req, &struct{}{})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (dr *databasePluginRPCClient) Initialize(_ context.Context, conf map[string]interface{}, verifyConnection bool) error {
|
||||
req := InitializeRequestRPC{
|
||||
Config: conf,
|
||||
VerifyConnection: verifyConnection,
|
||||
}
|
||||
|
||||
err := dr.client.Call("Plugin.Initialize", req, &struct{}{})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (dr *databasePluginRPCClient) Close() error {
|
||||
err := dr.client.Call("Plugin.Close", struct{}{}, &struct{}{})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// ---- RPC Request Args Domain ----
|
||||
|
||||
type InitializeRequestRPC struct {
|
||||
Config map[string]interface{}
|
||||
VerifyConnection bool
|
||||
}
|
||||
|
||||
type CreateUserRequestRPC struct {
|
||||
Statements Statements
|
||||
UsernameConfig UsernameConfig
|
||||
Expiration time.Time
|
||||
}
|
||||
|
||||
type RenewUserRequestRPC struct {
|
||||
Statements Statements
|
||||
Username string
|
||||
Expiration time.Time
|
||||
}
|
||||
|
||||
type RevokeUserRequestRPC struct {
|
||||
Statements Statements
|
||||
Username string
|
||||
}
|
|
@ -1,10 +1,13 @@
|
|||
package dbplugin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/rpc"
|
||||
"time"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/hashicorp/go-plugin"
|
||||
"github.com/hashicorp/vault/helper/pluginutil"
|
||||
log "github.com/mgutz/logxi/v1"
|
||||
|
@ -13,29 +16,14 @@ import (
|
|||
// Database is the interface that all database objects must implement.
|
||||
type Database interface {
|
||||
Type() (string, error)
|
||||
CreateUser(statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error)
|
||||
RenewUser(statements Statements, username string, expiration time.Time) error
|
||||
RevokeUser(statements Statements, username string) error
|
||||
CreateUser(ctx context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error)
|
||||
RenewUser(ctx context.Context, statements Statements, username string, expiration time.Time) error
|
||||
RevokeUser(ctx context.Context, statements Statements, username string) error
|
||||
|
||||
Initialize(config map[string]interface{}, verifyConnection bool) error
|
||||
Initialize(ctx context.Context, config map[string]interface{}, verifyConnection bool) error
|
||||
Close() error
|
||||
}
|
||||
|
||||
// Statements set in role creation and passed into the database type's functions.
|
||||
type Statements struct {
|
||||
CreationStatements string `json:"creation_statments" mapstructure:"creation_statements" structs:"creation_statments"`
|
||||
RevocationStatements string `json:"revocation_statements" mapstructure:"revocation_statements" structs:"revocation_statements"`
|
||||
RollbackStatements string `json:"rollback_statements" mapstructure:"rollback_statements" structs:"rollback_statements"`
|
||||
RenewStatements string `json:"renew_statements" mapstructure:"renew_statements" structs:"renew_statements"`
|
||||
}
|
||||
|
||||
// UsernameConfig is used to configure prefixes for the username to be
|
||||
// generated.
|
||||
type UsernameConfig struct {
|
||||
DisplayName string
|
||||
RoleName string
|
||||
}
|
||||
|
||||
// PluginFactory is used to build plugin database types. It wraps the database
|
||||
// object in a logging and metrics middleware.
|
||||
func PluginFactory(pluginName string, sys pluginutil.LookRunnerUtil, logger log.Logger) (Database, error) {
|
||||
|
@ -45,6 +33,7 @@ func PluginFactory(pluginName string, sys pluginutil.LookRunnerUtil, logger log.
|
|||
return nil, err
|
||||
}
|
||||
|
||||
var transport string
|
||||
var db Database
|
||||
if pluginRunner.Builtin {
|
||||
// Plugin is builtin so we can retrieve an instance of the interface
|
||||
|
@ -60,12 +49,24 @@ func PluginFactory(pluginName string, sys pluginutil.LookRunnerUtil, logger log.
|
|||
return nil, fmt.Errorf("unsuported database type: %s", pluginName)
|
||||
}
|
||||
|
||||
transport = "builtin"
|
||||
|
||||
} else {
|
||||
// create a DatabasePluginClient instance
|
||||
db, err = newPluginClient(sys, pluginRunner, logger)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Switch on the underlying database client type to get the transport
|
||||
// method.
|
||||
switch db.(*DatabasePluginClient).Database.(type) {
|
||||
case *gRPCClient:
|
||||
transport = "gRPC"
|
||||
case *databasePluginRPCClient:
|
||||
transport = "netRPC"
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
typeStr, err := db.Type()
|
||||
|
@ -82,9 +83,10 @@ func PluginFactory(pluginName string, sys pluginutil.LookRunnerUtil, logger log.
|
|||
// Wrap with tracing middleware
|
||||
if logger.IsTrace() {
|
||||
db = &databaseTracingMiddleware{
|
||||
next: db,
|
||||
typeStr: typeStr,
|
||||
logger: logger,
|
||||
transport: transport,
|
||||
next: db,
|
||||
typeStr: typeStr,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -115,33 +117,14 @@ func (DatabasePlugin) Client(b *plugin.MuxBroker, c *rpc.Client) (interface{}, e
|
|||
return &databasePluginRPCClient{client: c}, nil
|
||||
}
|
||||
|
||||
// ---- RPC Request Args Domain ----
|
||||
|
||||
type InitializeRequest struct {
|
||||
Config map[string]interface{}
|
||||
VerifyConnection bool
|
||||
func (d DatabasePlugin) GRPCServer(s *grpc.Server) error {
|
||||
RegisterDatabaseServer(s, &gRPCServer{impl: d.impl})
|
||||
return nil
|
||||
}
|
||||
|
||||
type CreateUserRequest struct {
|
||||
Statements Statements
|
||||
UsernameConfig UsernameConfig
|
||||
Expiration time.Time
|
||||
}
|
||||
|
||||
type RenewUserRequest struct {
|
||||
Statements Statements
|
||||
Username string
|
||||
Expiration time.Time
|
||||
}
|
||||
|
||||
type RevokeUserRequest struct {
|
||||
Statements Statements
|
||||
Username string
|
||||
}
|
||||
|
||||
// ---- RPC Response Args Domain ----
|
||||
|
||||
type CreateUserResponse struct {
|
||||
Username string
|
||||
Password string
|
||||
func (DatabasePlugin) GRPCClient(c *grpc.ClientConn) (interface{}, error) {
|
||||
return &gRPCClient{
|
||||
client: NewDatabaseClient(c),
|
||||
clientConn: c,
|
||||
}, nil
|
||||
}
|
||||
|
|
|
@ -1,11 +1,13 @@
|
|||
package dbplugin_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
plugin "github.com/hashicorp/go-plugin"
|
||||
"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
|
||||
"github.com/hashicorp/vault/helper/pluginutil"
|
||||
vaulthttp "github.com/hashicorp/vault/http"
|
||||
|
@ -20,7 +22,7 @@ type mockPlugin struct {
|
|||
}
|
||||
|
||||
func (m *mockPlugin) Type() (string, error) { return "mock", nil }
|
||||
func (m *mockPlugin) CreateUser(statements dbplugin.Statements, usernameConf dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
|
||||
func (m *mockPlugin) CreateUser(_ context.Context, statements dbplugin.Statements, usernameConf dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
|
||||
err = errors.New("err")
|
||||
if usernameConf.DisplayName == "" || expiration.IsZero() {
|
||||
return "", "", err
|
||||
|
@ -34,7 +36,7 @@ func (m *mockPlugin) CreateUser(statements dbplugin.Statements, usernameConf dbp
|
|||
|
||||
return usernameConf.DisplayName, "test", nil
|
||||
}
|
||||
func (m *mockPlugin) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error {
|
||||
func (m *mockPlugin) RenewUser(_ context.Context, statements dbplugin.Statements, username string, expiration time.Time) error {
|
||||
err := errors.New("err")
|
||||
if username == "" || expiration.IsZero() {
|
||||
return err
|
||||
|
@ -46,7 +48,7 @@ func (m *mockPlugin) RenewUser(statements dbplugin.Statements, username string,
|
|||
|
||||
return nil
|
||||
}
|
||||
func (m *mockPlugin) RevokeUser(statements dbplugin.Statements, username string) error {
|
||||
func (m *mockPlugin) RevokeUser(_ context.Context, statements dbplugin.Statements, username string) error {
|
||||
err := errors.New("err")
|
||||
if username == "" {
|
||||
return err
|
||||
|
@ -59,7 +61,7 @@ func (m *mockPlugin) RevokeUser(statements dbplugin.Statements, username string)
|
|||
delete(m.users, username)
|
||||
return nil
|
||||
}
|
||||
func (m *mockPlugin) Initialize(conf map[string]interface{}, _ bool) error {
|
||||
func (m *mockPlugin) Initialize(_ context.Context, conf map[string]interface{}, _ bool) error {
|
||||
err := errors.New("err")
|
||||
if len(conf) != 1 {
|
||||
return err
|
||||
|
@ -80,14 +82,15 @@ func getCluster(t *testing.T) (*vault.TestCluster, logical.SystemView) {
|
|||
cores := cluster.Cores
|
||||
|
||||
sys := vault.TestDynamicSystemView(cores[0].Core)
|
||||
vault.TestAddTestPlugin(t, cores[0].Core, "test-plugin", "TestPlugin_Main")
|
||||
vault.TestAddTestPlugin(t, cores[0].Core, "test-plugin", "TestPlugin_GRPC_Main")
|
||||
vault.TestAddTestPlugin(t, cores[0].Core, "test-plugin-netRPC", "TestPlugin_NetRPC_Main")
|
||||
|
||||
return cluster, sys
|
||||
}
|
||||
|
||||
// This is not an actual test case, it's a helper function that will be executed
|
||||
// by the go-plugin client via an exec call.
|
||||
func TestPlugin_Main(t *testing.T) {
|
||||
func TestPlugin_GRPC_Main(t *testing.T) {
|
||||
if os.Getenv(pluginutil.PluginUnwrapTokenEnv) == "" {
|
||||
return
|
||||
}
|
||||
|
@ -105,6 +108,30 @@ func TestPlugin_Main(t *testing.T) {
|
|||
plugins.Serve(plugin, apiClientMeta.GetTLSConfig())
|
||||
}
|
||||
|
||||
// This is not an actual test case, it's a helper function that will be executed
|
||||
// by the go-plugin client via an exec call.
|
||||
func TestPlugin_NetRPC_Main(t *testing.T) {
|
||||
if os.Getenv(pluginutil.PluginUnwrapTokenEnv) == "" {
|
||||
return
|
||||
}
|
||||
|
||||
p := &mockPlugin{
|
||||
users: make(map[string][]string),
|
||||
}
|
||||
|
||||
args := []string{"--tls-skip-verify=true"}
|
||||
|
||||
apiClientMeta := &pluginutil.APIClientMeta{}
|
||||
flags := apiClientMeta.FlagSet()
|
||||
flags.Parse(args)
|
||||
|
||||
tlsProvider := pluginutil.VaultPluginTLSProvider(apiClientMeta.GetTLSConfig())
|
||||
serveConf := dbplugin.ServeConfig(p, tlsProvider)
|
||||
serveConf.GRPCServer = nil
|
||||
|
||||
plugin.Serve(serveConf)
|
||||
}
|
||||
|
||||
func TestPlugin_Initialize(t *testing.T) {
|
||||
cluster, sys := getCluster(t)
|
||||
defer cluster.Cleanup()
|
||||
|
@ -118,7 +145,7 @@ func TestPlugin_Initialize(t *testing.T) {
|
|||
"test": 1,
|
||||
}
|
||||
|
||||
err = dbRaw.Initialize(connectionDetails, true)
|
||||
err = dbRaw.Initialize(context.Background(), connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -143,7 +170,7 @@ func TestPlugin_CreateUser(t *testing.T) {
|
|||
"test": 1,
|
||||
}
|
||||
|
||||
err = db.Initialize(connectionDetails, true)
|
||||
err = db.Initialize(context.Background(), connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -153,7 +180,7 @@ func TestPlugin_CreateUser(t *testing.T) {
|
|||
RoleName: "test",
|
||||
}
|
||||
|
||||
us, pw, err := db.CreateUser(dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
|
||||
us, pw, err := db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -163,7 +190,7 @@ func TestPlugin_CreateUser(t *testing.T) {
|
|||
|
||||
// try and save the same user again to verify it saved the first time, this
|
||||
// should return an error
|
||||
_, _, err = db.CreateUser(dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
|
||||
_, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
|
||||
if err == nil {
|
||||
t.Fatal("expected an error, user wasn't created correctly")
|
||||
}
|
||||
|
@ -182,7 +209,7 @@ func TestPlugin_RenewUser(t *testing.T) {
|
|||
connectionDetails := map[string]interface{}{
|
||||
"test": 1,
|
||||
}
|
||||
err = db.Initialize(connectionDetails, true)
|
||||
err = db.Initialize(context.Background(), connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -192,12 +219,12 @@ func TestPlugin_RenewUser(t *testing.T) {
|
|||
RoleName: "test",
|
||||
}
|
||||
|
||||
us, _, err := db.CreateUser(dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
|
||||
us, _, err := db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
err = db.RenewUser(dbplugin.Statements{}, us, time.Now().Add(time.Minute))
|
||||
err = db.RenewUser(context.Background(), dbplugin.Statements{}, us, time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -216,7 +243,7 @@ func TestPlugin_RevokeUser(t *testing.T) {
|
|||
connectionDetails := map[string]interface{}{
|
||||
"test": 1,
|
||||
}
|
||||
err = db.Initialize(connectionDetails, true)
|
||||
err = db.Initialize(context.Background(), connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -226,19 +253,159 @@ func TestPlugin_RevokeUser(t *testing.T) {
|
|||
RoleName: "test",
|
||||
}
|
||||
|
||||
us, _, err := db.CreateUser(dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
|
||||
us, _, err := db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
// Test default revoke statememts
|
||||
err = db.RevokeUser(dbplugin.Statements{}, us)
|
||||
err = db.RevokeUser(context.Background(), dbplugin.Statements{}, us)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
// Try adding the same username back so we can verify it was removed
|
||||
_, _, err = db.CreateUser(dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
|
||||
_, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Test the code is still compatible with an old netRPC plugin
|
||||
func TestPlugin_NetRPC_Initialize(t *testing.T) {
|
||||
cluster, sys := getCluster(t)
|
||||
defer cluster.Cleanup()
|
||||
|
||||
dbRaw, err := dbplugin.PluginFactory("test-plugin-netRPC", sys, &log.NullLogger{})
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
connectionDetails := map[string]interface{}{
|
||||
"test": 1,
|
||||
}
|
||||
|
||||
err = dbRaw.Initialize(context.Background(), connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
err = dbRaw.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPlugin_NetRPC_CreateUser(t *testing.T) {
|
||||
cluster, sys := getCluster(t)
|
||||
defer cluster.Cleanup()
|
||||
|
||||
db, err := dbplugin.PluginFactory("test-plugin-netRPC", sys, &log.NullLogger{})
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
connectionDetails := map[string]interface{}{
|
||||
"test": 1,
|
||||
}
|
||||
|
||||
err = db.Initialize(context.Background(), connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
usernameConf := dbplugin.UsernameConfig{
|
||||
DisplayName: "test",
|
||||
RoleName: "test",
|
||||
}
|
||||
|
||||
us, pw, err := db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
if us != "test" || pw != "test" {
|
||||
t.Fatal("expected username and password to be 'test'")
|
||||
}
|
||||
|
||||
// try and save the same user again to verify it saved the first time, this
|
||||
// should return an error
|
||||
_, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
|
||||
if err == nil {
|
||||
t.Fatal("expected an error, user wasn't created correctly")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPlugin_NetRPC_RenewUser(t *testing.T) {
|
||||
cluster, sys := getCluster(t)
|
||||
defer cluster.Cleanup()
|
||||
|
||||
db, err := dbplugin.PluginFactory("test-plugin-netRPC", sys, &log.NullLogger{})
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
connectionDetails := map[string]interface{}{
|
||||
"test": 1,
|
||||
}
|
||||
err = db.Initialize(context.Background(), connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
usernameConf := dbplugin.UsernameConfig{
|
||||
DisplayName: "test",
|
||||
RoleName: "test",
|
||||
}
|
||||
|
||||
us, _, err := db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
err = db.RenewUser(context.Background(), dbplugin.Statements{}, us, time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPlugin_NetRPC_RevokeUser(t *testing.T) {
|
||||
cluster, sys := getCluster(t)
|
||||
defer cluster.Cleanup()
|
||||
|
||||
db, err := dbplugin.PluginFactory("test-plugin-netRPC", sys, &log.NullLogger{})
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
connectionDetails := map[string]interface{}{
|
||||
"test": 1,
|
||||
}
|
||||
err = db.Initialize(context.Background(), connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
usernameConf := dbplugin.UsernameConfig{
|
||||
DisplayName: "test",
|
||||
RoleName: "test",
|
||||
}
|
||||
|
||||
us, _, err := db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
// Test default revoke statememts
|
||||
err = db.RevokeUser(context.Background(), dbplugin.Statements{}, us)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
// Try adding the same username back so we can verify it was removed
|
||||
_, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
|
|
@ -10,6 +10,10 @@ import (
|
|||
// Database implementation in a databasePluginRPCServer object and starts a
|
||||
// RPC server.
|
||||
func Serve(db Database, tlsProvider func() (*tls.Config, error)) {
|
||||
plugin.Serve(ServeConfig(db, tlsProvider))
|
||||
}
|
||||
|
||||
func ServeConfig(db Database, tlsProvider func() (*tls.Config, error)) *plugin.ServeConfig {
|
||||
dbPlugin := &DatabasePlugin{
|
||||
impl: db,
|
||||
}
|
||||
|
@ -19,53 +23,10 @@ func Serve(db Database, tlsProvider func() (*tls.Config, error)) {
|
|||
"database": dbPlugin,
|
||||
}
|
||||
|
||||
plugin.Serve(&plugin.ServeConfig{
|
||||
return &plugin.ServeConfig{
|
||||
HandshakeConfig: handshakeConfig,
|
||||
Plugins: pluginMap,
|
||||
TLSProvider: tlsProvider,
|
||||
})
|
||||
}
|
||||
|
||||
// ---- RPC server domain ----
|
||||
|
||||
// databasePluginRPCServer implements an RPC version of Database and is run
|
||||
// inside a plugin. It wraps an underlying implementation of Database.
|
||||
type databasePluginRPCServer struct {
|
||||
impl Database
|
||||
}
|
||||
|
||||
func (ds *databasePluginRPCServer) Type(_ struct{}, resp *string) error {
|
||||
var err error
|
||||
*resp, err = ds.impl.Type()
|
||||
return err
|
||||
}
|
||||
|
||||
func (ds *databasePluginRPCServer) CreateUser(args *CreateUserRequest, resp *CreateUserResponse) error {
|
||||
var err error
|
||||
resp.Username, resp.Password, err = ds.impl.CreateUser(args.Statements, args.UsernameConfig, args.Expiration)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (ds *databasePluginRPCServer) RenewUser(args *RenewUserRequest, _ *struct{}) error {
|
||||
err := ds.impl.RenewUser(args.Statements, args.Username, args.Expiration)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (ds *databasePluginRPCServer) RevokeUser(args *RevokeUserRequest, _ *struct{}) error {
|
||||
err := ds.impl.RevokeUser(args.Statements, args.Username)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (ds *databasePluginRPCServer) Initialize(args *InitializeRequest, _ *struct{}) error {
|
||||
err := ds.impl.Initialize(args.Config, args.VerifyConnection)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (ds *databasePluginRPCServer) Close(_ struct{}, _ *struct{}) error {
|
||||
ds.impl.Close()
|
||||
return nil
|
||||
GRPCServer: plugin.DefaultGRPCServer,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
|
@ -62,7 +63,7 @@ func (b *databaseBackend) pathConnectionReset() framework.OperationFunc {
|
|||
b.clearConnection(name)
|
||||
|
||||
// Execute plugin again, we don't need the object so throw away.
|
||||
_, err := b.createDBObj(req.Storage, name)
|
||||
_, err := b.createDBObj(context.TODO(), req.Storage, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -230,7 +231,7 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc {
|
|||
return logical.ErrorResponse(fmt.Sprintf("error creating database object: %s", err)), nil
|
||||
}
|
||||
|
||||
err = db.Initialize(config.ConnectionDetails, verifyConnection)
|
||||
err = db.Initialize(context.TODO(), config.ConnectionDetails, verifyConnection)
|
||||
if err != nil {
|
||||
db.Close()
|
||||
return logical.ErrorResponse(fmt.Sprintf("error creating database object: %s", err)), nil
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
|
@ -66,7 +67,7 @@ func (b *databaseBackend) pathCredsCreateRead() framework.OperationFunc {
|
|||
unlockFunc = b.Unlock
|
||||
|
||||
// Create a new DB object
|
||||
db, err = b.createDBObj(req.Storage, role.DBName)
|
||||
db, err = b.createDBObj(context.TODO(), req.Storage, role.DBName)
|
||||
if err != nil {
|
||||
unlockFunc()
|
||||
return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err)
|
||||
|
@ -81,7 +82,7 @@ func (b *databaseBackend) pathCredsCreateRead() framework.OperationFunc {
|
|||
}
|
||||
|
||||
// Create the user
|
||||
username, password, err := db.CreateUser(role.Statements, usernameConfig, expiration)
|
||||
username, password, err := db.CreateUser(context.TODO(), role.Statements, usernameConfig, expiration)
|
||||
// Unlock
|
||||
unlockFunc()
|
||||
if err != nil {
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/hashicorp/vault/logical"
|
||||
|
@ -60,7 +61,7 @@ func (b *databaseBackend) secretCredsRenew() framework.OperationFunc {
|
|||
unlockFunc = b.Unlock
|
||||
|
||||
// Create a new DB object
|
||||
db, err = b.createDBObj(req.Storage, role.DBName)
|
||||
db, err = b.createDBObj(context.TODO(), req.Storage, role.DBName)
|
||||
if err != nil {
|
||||
unlockFunc()
|
||||
return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err)
|
||||
|
@ -69,7 +70,7 @@ func (b *databaseBackend) secretCredsRenew() framework.OperationFunc {
|
|||
|
||||
// Make sure we increase the VALID UNTIL endpoint for this user.
|
||||
if expireTime := resp.Secret.ExpirationTime(); !expireTime.IsZero() {
|
||||
err := db.RenewUser(role.Statements, username, expireTime)
|
||||
err := db.RenewUser(context.TODO(), role.Statements, username, expireTime)
|
||||
// Unlock
|
||||
unlockFunc()
|
||||
if err != nil {
|
||||
|
@ -119,14 +120,14 @@ func (b *databaseBackend) secretCredsRevoke() framework.OperationFunc {
|
|||
unlockFunc = b.Unlock
|
||||
|
||||
// Create a new DB object
|
||||
db, err = b.createDBObj(req.Storage, role.DBName)
|
||||
db, err = b.createDBObj(context.TODO(), req.Storage, role.DBName)
|
||||
if err != nil {
|
||||
unlockFunc()
|
||||
return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err)
|
||||
}
|
||||
}
|
||||
|
||||
err = db.RevokeUser(role.Statements, username)
|
||||
err = db.RevokeUser(context.TODO(), role.Statements, username)
|
||||
// Unlock
|
||||
unlockFunc()
|
||||
if err != nil {
|
||||
|
|
|
@ -43,6 +43,9 @@ type PolicyRequest struct {
|
|||
|
||||
// Whether to upsert
|
||||
Upsert bool
|
||||
|
||||
// Whether to allow plaintext backup
|
||||
AllowPlaintextBackup bool
|
||||
}
|
||||
|
||||
type LockManager struct {
|
||||
|
@ -378,10 +381,11 @@ func (lm *LockManager) getPolicyCommon(req PolicyRequest, lockType bool) (*Polic
|
|||
}
|
||||
|
||||
p = &Policy{
|
||||
Name: req.Name,
|
||||
Type: req.KeyType,
|
||||
Derived: req.Derived,
|
||||
Exportable: req.Exportable,
|
||||
Name: req.Name,
|
||||
Type: req.KeyType,
|
||||
Derived: req.Derived,
|
||||
Exportable: req.Exportable,
|
||||
AllowPlaintextBackup: req.AllowPlaintextBackup,
|
||||
}
|
||||
if req.Derived {
|
||||
p.KDF = Kdf_hkdf_sha256
|
||||
|
|
|
@ -119,6 +119,10 @@ func (r *PluginRunner) runCommon(wrapper RunnerUtil, pluginMap map[string]plugin
|
|||
SecureConfig: secureConfig,
|
||||
TLSConfig: clientTLSConfig,
|
||||
Logger: namedLogger,
|
||||
AllowedProtocols: []plugin.Protocol{
|
||||
plugin.ProtocolNetRPC,
|
||||
plugin.ProtocolGRPC,
|
||||
},
|
||||
}
|
||||
|
||||
client := plugin.NewClient(clientConfig)
|
||||
|
|
|
@ -192,8 +192,7 @@ func TestBackendHandleRequest_helpRoot(t *testing.T) {
|
|||
func TestBackendHandleRequest_renewAuth(t *testing.T) {
|
||||
b := &Backend{}
|
||||
|
||||
resp, err := b.HandleRequest(logical.RenewAuthRequest(
|
||||
"/foo", &logical.Auth{}, nil))
|
||||
resp, err := b.HandleRequest(logical.RenewAuthRequest("/foo", &logical.Auth{}, nil))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -213,8 +212,7 @@ func TestBackendHandleRequest_renewAuthCallback(t *testing.T) {
|
|||
AuthRenew: callback,
|
||||
}
|
||||
|
||||
_, err := b.HandleRequest(logical.RenewAuthRequest(
|
||||
"/foo", &logical.Auth{}, nil))
|
||||
_, err := b.HandleRequest(logical.RenewAuthRequest("/foo", &logical.Auth{}, nil))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -237,8 +235,7 @@ func TestBackendHandleRequest_renew(t *testing.T) {
|
|||
Secrets: []*Secret{secret},
|
||||
}
|
||||
|
||||
_, err := b.HandleRequest(logical.RenewRequest(
|
||||
"/foo", secret.Response(nil, nil).Secret, nil))
|
||||
_, err := b.HandleRequest(logical.RenewRequest("/foo", secret.Response(nil, nil).Secret, nil))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -293,8 +290,7 @@ func TestBackendHandleRequest_revoke(t *testing.T) {
|
|||
Secrets: []*Secret{secret},
|
||||
}
|
||||
|
||||
_, err := b.HandleRequest(logical.RevokeRequest(
|
||||
"/foo", secret.Response(nil, nil).Secret, nil))
|
||||
_, err := b.HandleRequest(logical.RevokeRequest("/foo", secret.Response(nil, nil).Secret, nil))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
|
|
@ -8,7 +8,8 @@ import (
|
|||
)
|
||||
|
||||
// LeaseExtend returns an OperationFunc that can be used to simply extend the
|
||||
// lease of the auth/secret for the duration that was requested.
|
||||
// lease of the auth/secret for the duration that was requested. The parameters
|
||||
// provided are used to determine the lease's new TTL value.
|
||||
//
|
||||
// backendIncrement is the backend's requested increment -- perhaps from a user
|
||||
// request, perhaps from a role/config value. If not set, uses the mount/system
|
||||
|
|
|
@ -200,8 +200,7 @@ func (r *Request) SetLastRemoteWAL(last uint64) {
|
|||
}
|
||||
|
||||
// RenewRequest creates the structure of the renew request.
|
||||
func RenewRequest(
|
||||
path string, secret *Secret, data map[string]interface{}) *Request {
|
||||
func RenewRequest(path string, secret *Secret, data map[string]interface{}) *Request {
|
||||
return &Request{
|
||||
Operation: RenewOperation,
|
||||
Path: path,
|
||||
|
@ -211,8 +210,7 @@ func RenewRequest(
|
|||
}
|
||||
|
||||
// RenewAuthRequest creates the structure of the renew request for an auth.
|
||||
func RenewAuthRequest(
|
||||
path string, auth *Auth, data map[string]interface{}) *Request {
|
||||
func RenewAuthRequest(path string, auth *Auth, data map[string]interface{}) *Request {
|
||||
return &Request{
|
||||
Operation: RenewOperation,
|
||||
Path: path,
|
||||
|
@ -222,8 +220,7 @@ func RenewAuthRequest(
|
|||
}
|
||||
|
||||
// RevokeRequest creates the structure of the revoke request.
|
||||
func RevokeRequest(
|
||||
path string, secret *Secret, data map[string]interface{}) *Request {
|
||||
func RevokeRequest(path string, secret *Secret, data map[string]interface{}) *Request {
|
||||
return &Request{
|
||||
Operation: RevokeOperation,
|
||||
Path: path,
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package cassandra
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
|
@ -21,6 +22,8 @@ const (
|
|||
cassandraTypeName = "cassandra"
|
||||
)
|
||||
|
||||
var _ dbplugin.Database = &Cassandra{}
|
||||
|
||||
// Cassandra is an implementation of Database interface
|
||||
type Cassandra struct {
|
||||
connutil.ConnectionProducer
|
||||
|
@ -64,8 +67,8 @@ func (c *Cassandra) Type() (string, error) {
|
|||
return cassandraTypeName, nil
|
||||
}
|
||||
|
||||
func (c *Cassandra) getConnection() (*gocql.Session, error) {
|
||||
session, err := c.Connection()
|
||||
func (c *Cassandra) getConnection(ctx context.Context) (*gocql.Session, error) {
|
||||
session, err := c.Connection(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -75,13 +78,13 @@ func (c *Cassandra) getConnection() (*gocql.Session, error) {
|
|||
|
||||
// CreateUser generates the username/password on the underlying Cassandra secret backend as instructed by
|
||||
// the CreationStatement provided.
|
||||
func (c *Cassandra) CreateUser(statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
|
||||
func (c *Cassandra) CreateUser(ctx context.Context, statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
|
||||
// Grab the lock
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
|
||||
// Get the connection
|
||||
session, err := c.getConnection()
|
||||
session, err := c.getConnection(ctx)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
@ -138,18 +141,18 @@ func (c *Cassandra) CreateUser(statements dbplugin.Statements, usernameConfig db
|
|||
}
|
||||
|
||||
// RenewUser is not supported on Cassandra, so this is a no-op.
|
||||
func (c *Cassandra) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error {
|
||||
func (c *Cassandra) RenewUser(ctx context.Context, statements dbplugin.Statements, username string, expiration time.Time) error {
|
||||
// NOOP
|
||||
return nil
|
||||
}
|
||||
|
||||
// RevokeUser attempts to drop the specified user.
|
||||
func (c *Cassandra) RevokeUser(statements dbplugin.Statements, username string) error {
|
||||
func (c *Cassandra) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error {
|
||||
// Grab the lock
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
|
||||
session, err := c.getConnection()
|
||||
session, err := c.getConnection(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package cassandra
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
@ -89,7 +90,7 @@ func TestCassandra_Initialize(t *testing.T) {
|
|||
db := dbRaw.(*Cassandra)
|
||||
connProducer := db.ConnectionProducer.(*cassandraConnectionProducer)
|
||||
|
||||
err := db.Initialize(connectionDetails, true)
|
||||
err := db.Initialize(context.Background(), connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -112,7 +113,7 @@ func TestCassandra_Initialize(t *testing.T) {
|
|||
"protocol_version": "4",
|
||||
}
|
||||
|
||||
err = db.Initialize(connectionDetails, true)
|
||||
err = db.Initialize(context.Background(), connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -135,7 +136,7 @@ func TestCassandra_CreateUser(t *testing.T) {
|
|||
|
||||
dbRaw, _ := New()
|
||||
db := dbRaw.(*Cassandra)
|
||||
err := db.Initialize(connectionDetails, true)
|
||||
err := db.Initialize(context.Background(), connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -149,7 +150,7 @@ func TestCassandra_CreateUser(t *testing.T) {
|
|||
RoleName: "test",
|
||||
}
|
||||
|
||||
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute))
|
||||
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -176,7 +177,7 @@ func TestMyCassandra_RenewUser(t *testing.T) {
|
|||
|
||||
dbRaw, _ := New()
|
||||
db := dbRaw.(*Cassandra)
|
||||
err := db.Initialize(connectionDetails, true)
|
||||
err := db.Initialize(context.Background(), connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -190,7 +191,7 @@ func TestMyCassandra_RenewUser(t *testing.T) {
|
|||
RoleName: "test",
|
||||
}
|
||||
|
||||
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute))
|
||||
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -199,7 +200,7 @@ func TestMyCassandra_RenewUser(t *testing.T) {
|
|||
t.Fatalf("Could not connect with new credentials: %s", err)
|
||||
}
|
||||
|
||||
err = db.RenewUser(statements, username, time.Now().Add(time.Minute))
|
||||
err = db.RenewUser(context.Background(), statements, username, time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -222,7 +223,7 @@ func TestCassandra_RevokeUser(t *testing.T) {
|
|||
|
||||
dbRaw, _ := New()
|
||||
db := dbRaw.(*Cassandra)
|
||||
err := db.Initialize(connectionDetails, true)
|
||||
err := db.Initialize(context.Background(), connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -236,7 +237,7 @@ func TestCassandra_RevokeUser(t *testing.T) {
|
|||
RoleName: "test",
|
||||
}
|
||||
|
||||
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute))
|
||||
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -246,7 +247,7 @@ func TestCassandra_RevokeUser(t *testing.T) {
|
|||
}
|
||||
|
||||
// Test default revoke statememts
|
||||
err = db.RevokeUser(statements, username)
|
||||
err = db.RevokeUser(context.Background(), statements, username)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package cassandra
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
@ -43,7 +44,7 @@ type cassandraConnectionProducer struct {
|
|||
sync.Mutex
|
||||
}
|
||||
|
||||
func (c *cassandraConnectionProducer) Initialize(conf map[string]interface{}, verifyConnection bool) error {
|
||||
func (c *cassandraConnectionProducer) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
|
||||
|
@ -106,7 +107,7 @@ func (c *cassandraConnectionProducer) Initialize(conf map[string]interface{}, ve
|
|||
c.Initialized = true
|
||||
|
||||
if verifyConnection {
|
||||
if _, err := c.Connection(); err != nil {
|
||||
if _, err := c.Connection(ctx); err != nil {
|
||||
return fmt.Errorf("error verifying connection: %s", err)
|
||||
}
|
||||
}
|
||||
|
@ -114,7 +115,7 @@ func (c *cassandraConnectionProducer) Initialize(conf map[string]interface{}, ve
|
|||
return nil
|
||||
}
|
||||
|
||||
func (c *cassandraConnectionProducer) Connection() (interface{}, error) {
|
||||
func (c *cassandraConnectionProducer) Connection(_ context.Context) (interface{}, error) {
|
||||
if !c.Initialized {
|
||||
return nil, connutil.ErrNotInitialized
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package hana
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
@ -26,6 +27,8 @@ type HANA struct {
|
|||
credsutil.CredentialsProducer
|
||||
}
|
||||
|
||||
var _ dbplugin.Database = &HANA{}
|
||||
|
||||
// New implements builtinplugins.BuiltinFactory
|
||||
func New() (interface{}, error) {
|
||||
connProducer := &connutil.SQLConnectionProducer{}
|
||||
|
@ -63,8 +66,8 @@ func (h *HANA) Type() (string, error) {
|
|||
return hanaTypeName, nil
|
||||
}
|
||||
|
||||
func (h *HANA) getConnection() (*sql.DB, error) {
|
||||
db, err := h.Connection()
|
||||
func (h *HANA) getConnection(ctx context.Context) (*sql.DB, error) {
|
||||
db, err := h.Connection(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -74,13 +77,13 @@ func (h *HANA) getConnection() (*sql.DB, error) {
|
|||
|
||||
// CreateUser generates the username/password on the underlying HANA secret backend
|
||||
// as instructed by the CreationStatement provided.
|
||||
func (h *HANA) CreateUser(statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
|
||||
func (h *HANA) CreateUser(ctx context.Context, statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
|
||||
// Grab the lock
|
||||
h.Lock()
|
||||
defer h.Unlock()
|
||||
|
||||
// Get the connection
|
||||
db, err := h.getConnection()
|
||||
db, err := h.getConnection(ctx)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
@ -117,7 +120,7 @@ func (h *HANA) CreateUser(statements dbplugin.Statements, usernameConfig dbplugi
|
|||
}
|
||||
|
||||
// Start a transaction
|
||||
tx, err := db.Begin()
|
||||
tx, err := db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
@ -130,7 +133,7 @@ func (h *HANA) CreateUser(statements dbplugin.Statements, usernameConfig dbplugi
|
|||
continue
|
||||
}
|
||||
|
||||
stmt, err := tx.Prepare(dbutil.QueryHelper(query, map[string]string{
|
||||
stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{
|
||||
"name": username,
|
||||
"password": password,
|
||||
"expiration": expirationStr,
|
||||
|
@ -139,7 +142,7 @@ func (h *HANA) CreateUser(statements dbplugin.Statements, usernameConfig dbplugi
|
|||
return "", "", err
|
||||
}
|
||||
defer stmt.Close()
|
||||
if _, err := stmt.Exec(); err != nil {
|
||||
if _, err := stmt.ExecContext(ctx); err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
}
|
||||
|
@ -153,15 +156,15 @@ func (h *HANA) CreateUser(statements dbplugin.Statements, usernameConfig dbplugi
|
|||
}
|
||||
|
||||
// Renewing hana user just means altering user's valid until property
|
||||
func (h *HANA) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error {
|
||||
func (h *HANA) RenewUser(ctx context.Context, statements dbplugin.Statements, username string, expiration time.Time) error {
|
||||
// Get connection
|
||||
db, err := h.getConnection()
|
||||
db, err := h.getConnection(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Start a transaction
|
||||
tx, err := db.Begin()
|
||||
tx, err := db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -175,12 +178,12 @@ func (h *HANA) RenewUser(statements dbplugin.Statements, username string, expira
|
|||
}
|
||||
|
||||
// Renew user's valid until property field
|
||||
stmt, err := tx.Prepare("ALTER USER " + username + " VALID UNTIL " + "'" + expirationStr + "'")
|
||||
stmt, err := tx.PrepareContext(ctx, "ALTER USER "+username+" VALID UNTIL "+"'"+expirationStr+"'")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
if _, err := stmt.Exec(); err != nil {
|
||||
if _, err := stmt.ExecContext(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -193,20 +196,20 @@ func (h *HANA) RenewUser(statements dbplugin.Statements, username string, expira
|
|||
}
|
||||
|
||||
// Revoking hana user will deactivate user and try to perform a soft drop
|
||||
func (h *HANA) RevokeUser(statements dbplugin.Statements, username string) error {
|
||||
func (h *HANA) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error {
|
||||
// default revoke will be a soft drop on user
|
||||
if statements.RevocationStatements == "" {
|
||||
return h.revokeUserDefault(username)
|
||||
return h.revokeUserDefault(ctx, username)
|
||||
}
|
||||
|
||||
// Get connection
|
||||
db, err := h.getConnection()
|
||||
db, err := h.getConnection(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Start a transaction
|
||||
tx, err := db.Begin()
|
||||
tx, err := db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -219,14 +222,14 @@ func (h *HANA) RevokeUser(statements dbplugin.Statements, username string) error
|
|||
continue
|
||||
}
|
||||
|
||||
stmt, err := tx.Prepare(dbutil.QueryHelper(query, map[string]string{
|
||||
stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{
|
||||
"name": username,
|
||||
}))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
if _, err := stmt.Exec(); err != nil {
|
||||
if _, err := stmt.ExecContext(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
@ -239,38 +242,38 @@ func (h *HANA) RevokeUser(statements dbplugin.Statements, username string) error
|
|||
return nil
|
||||
}
|
||||
|
||||
func (h *HANA) revokeUserDefault(username string) error {
|
||||
func (h *HANA) revokeUserDefault(ctx context.Context, username string) error {
|
||||
// Get connection
|
||||
db, err := h.getConnection()
|
||||
db, err := h.getConnection(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Start a transaction
|
||||
tx, err := db.Begin()
|
||||
tx, err := db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
// Disable server login for user
|
||||
disableStmt, err := tx.Prepare(fmt.Sprintf("ALTER USER %s DEACTIVATE USER NOW", username))
|
||||
disableStmt, err := tx.PrepareContext(ctx, fmt.Sprintf("ALTER USER %s DEACTIVATE USER NOW", username))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer disableStmt.Close()
|
||||
if _, err := disableStmt.Exec(); err != nil {
|
||||
if _, err := disableStmt.ExecContext(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Invalidates current sessions and performs soft drop (drop if no dependencies)
|
||||
// if hard drop is desired, custom revoke statements should be written for role
|
||||
dropStmt, err := tx.Prepare(fmt.Sprintf("DROP USER %s RESTRICT", username))
|
||||
dropStmt, err := tx.PrepareContext(ctx, fmt.Sprintf("DROP USER %s RESTRICT", username))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer dropStmt.Close()
|
||||
if _, err := dropStmt.Exec(); err != nil {
|
||||
if _, err := dropStmt.ExecContext(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package hana
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"os"
|
||||
|
@ -25,7 +26,7 @@ func TestHANA_Initialize(t *testing.T) {
|
|||
dbRaw, _ := New()
|
||||
db := dbRaw.(*HANA)
|
||||
|
||||
err := db.Initialize(connectionDetails, true)
|
||||
err := db.Initialize(context.Background(), connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -55,7 +56,7 @@ func TestHANA_CreateUser(t *testing.T) {
|
|||
dbRaw, _ := New()
|
||||
db := dbRaw.(*HANA)
|
||||
|
||||
err := db.Initialize(connectionDetails, true)
|
||||
err := db.Initialize(context.Background(), connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -66,7 +67,7 @@ func TestHANA_CreateUser(t *testing.T) {
|
|||
}
|
||||
|
||||
// Test with no configured Creation Statememt
|
||||
_, _, err = db.CreateUser(dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Hour))
|
||||
_, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Hour))
|
||||
if err == nil {
|
||||
t.Fatal("Expected error when no creation statement is provided")
|
||||
}
|
||||
|
@ -75,7 +76,7 @@ func TestHANA_CreateUser(t *testing.T) {
|
|||
CreationStatements: testHANARole,
|
||||
}
|
||||
|
||||
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Hour))
|
||||
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Hour))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -98,7 +99,7 @@ func TestHANA_RevokeUser(t *testing.T) {
|
|||
dbRaw, _ := New()
|
||||
db := dbRaw.(*HANA)
|
||||
|
||||
err := db.Initialize(connectionDetails, true)
|
||||
err := db.Initialize(context.Background(), connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -113,7 +114,7 @@ func TestHANA_RevokeUser(t *testing.T) {
|
|||
}
|
||||
|
||||
// Test default revoke statememts
|
||||
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Hour))
|
||||
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Hour))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -121,7 +122,7 @@ func TestHANA_RevokeUser(t *testing.T) {
|
|||
t.Fatalf("Could not connect with new credentials: %s", err)
|
||||
}
|
||||
|
||||
err = db.RevokeUser(statements, username)
|
||||
err = db.RevokeUser(context.Background(), statements, username)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -130,7 +131,7 @@ func TestHANA_RevokeUser(t *testing.T) {
|
|||
}
|
||||
|
||||
// Test custom revoke statememt
|
||||
username, password, err = db.CreateUser(statements, usernameConfig, time.Now().Add(time.Hour))
|
||||
username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Hour))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -139,7 +140,7 @@ func TestHANA_RevokeUser(t *testing.T) {
|
|||
}
|
||||
|
||||
statements.RevocationStatements = testHANADrop
|
||||
err = db.RevokeUser(statements, username)
|
||||
err = db.RevokeUser(context.Background(), statements, username)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package mongodb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
|
@ -33,7 +34,7 @@ type mongoDBConnectionProducer struct {
|
|||
}
|
||||
|
||||
// Initialize parses connection configuration.
|
||||
func (c *mongoDBConnectionProducer) Initialize(conf map[string]interface{}, verifyConnection bool) error {
|
||||
func (c *mongoDBConnectionProducer) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
|
||||
|
@ -75,7 +76,7 @@ func (c *mongoDBConnectionProducer) Initialize(conf map[string]interface{}, veri
|
|||
c.Initialized = true
|
||||
|
||||
if verifyConnection {
|
||||
if _, err := c.Connection(); err != nil {
|
||||
if _, err := c.Connection(ctx); err != nil {
|
||||
return fmt.Errorf("error verifying connection: %s", err)
|
||||
}
|
||||
|
||||
|
@ -88,7 +89,7 @@ func (c *mongoDBConnectionProducer) Initialize(conf map[string]interface{}, veri
|
|||
}
|
||||
|
||||
// Connection creates a database connection.
|
||||
func (c *mongoDBConnectionProducer) Connection() (interface{}, error) {
|
||||
func (c *mongoDBConnectionProducer) Connection(_ context.Context) (interface{}, error) {
|
||||
if !c.Initialized {
|
||||
return nil, connutil.ErrNotInitialized
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package mongodb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"strings"
|
||||
"time"
|
||||
|
@ -27,6 +28,8 @@ type MongoDB struct {
|
|||
credsutil.CredentialsProducer
|
||||
}
|
||||
|
||||
var _ dbplugin.Database = &MongoDB{}
|
||||
|
||||
// New returns a new MongoDB instance
|
||||
func New() (interface{}, error) {
|
||||
connProducer := &mongoDBConnectionProducer{}
|
||||
|
@ -63,8 +66,8 @@ func (m *MongoDB) Type() (string, error) {
|
|||
return mongoDBTypeName, nil
|
||||
}
|
||||
|
||||
func (m *MongoDB) getConnection() (*mgo.Session, error) {
|
||||
session, err := m.Connection()
|
||||
func (m *MongoDB) getConnection(ctx context.Context) (*mgo.Session, error) {
|
||||
session, err := m.Connection(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -80,7 +83,7 @@ func (m *MongoDB) getConnection() (*mgo.Session, error) {
|
|||
//
|
||||
// JSON Example:
|
||||
// { "db": "admin", "roles": [{ "role": "readWrite" }, {"role": "read", "db": "foo"}] }
|
||||
func (m *MongoDB) CreateUser(statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
|
||||
func (m *MongoDB) CreateUser(ctx context.Context, statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
|
||||
// Grab the lock
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
|
@ -89,7 +92,7 @@ func (m *MongoDB) CreateUser(statements dbplugin.Statements, usernameConfig dbpl
|
|||
return "", "", dbutil.ErrEmptyCreationStatement
|
||||
}
|
||||
|
||||
session, err := m.getConnection()
|
||||
session, err := m.getConnection(ctx)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
@ -133,7 +136,7 @@ func (m *MongoDB) CreateUser(statements dbplugin.Statements, usernameConfig dbpl
|
|||
if err := m.ConnectionProducer.Close(); err != nil {
|
||||
return "", "", errwrap.Wrapf("error closing EOF'd mongo connection: {{err}}", err)
|
||||
}
|
||||
session, err := m.getConnection()
|
||||
session, err := m.getConnection(ctx)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
@ -149,15 +152,15 @@ func (m *MongoDB) CreateUser(statements dbplugin.Statements, usernameConfig dbpl
|
|||
}
|
||||
|
||||
// RenewUser is not supported on MongoDB, so this is a no-op.
|
||||
func (m *MongoDB) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error {
|
||||
func (m *MongoDB) RenewUser(ctx context.Context, statements dbplugin.Statements, username string, expiration time.Time) error {
|
||||
// NOOP
|
||||
return nil
|
||||
}
|
||||
|
||||
// RevokeUser drops the specified user from the authentication databse. If none is provided
|
||||
// in the revocation statement, the default "admin" authentication database will be assumed.
|
||||
func (m *MongoDB) RevokeUser(statements dbplugin.Statements, username string) error {
|
||||
session, err := m.getConnection()
|
||||
func (m *MongoDB) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error {
|
||||
session, err := m.getConnection(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -188,7 +191,7 @@ func (m *MongoDB) RevokeUser(statements dbplugin.Statements, username string) er
|
|||
if err := m.ConnectionProducer.Close(); err != nil {
|
||||
return errwrap.Wrapf("error closing EOF'd mongo connection: {{err}}", err)
|
||||
}
|
||||
session, err := m.getConnection()
|
||||
session, err := m.getConnection(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package mongodb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
|
@ -79,7 +80,7 @@ func TestMongoDB_Initialize(t *testing.T) {
|
|||
db := dbRaw.(*MongoDB)
|
||||
connProducer := db.ConnectionProducer.(*mongoDBConnectionProducer)
|
||||
|
||||
err = db.Initialize(connectionDetails, true)
|
||||
err = db.Initialize(context.Background(), connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -107,7 +108,7 @@ func TestMongoDB_CreateUser(t *testing.T) {
|
|||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
db := dbRaw.(*MongoDB)
|
||||
err = db.Initialize(connectionDetails, true)
|
||||
err = db.Initialize(context.Background(), connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -121,7 +122,7 @@ func TestMongoDB_CreateUser(t *testing.T) {
|
|||
RoleName: "test",
|
||||
}
|
||||
|
||||
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute))
|
||||
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -145,7 +146,7 @@ func TestMongoDB_CreateUser_writeConcern(t *testing.T) {
|
|||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
db := dbRaw.(*MongoDB)
|
||||
err = db.Initialize(connectionDetails, true)
|
||||
err = db.Initialize(context.Background(), connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -159,7 +160,7 @@ func TestMongoDB_CreateUser_writeConcern(t *testing.T) {
|
|||
RoleName: "test",
|
||||
}
|
||||
|
||||
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute))
|
||||
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -182,7 +183,7 @@ func TestMongoDB_RevokeUser(t *testing.T) {
|
|||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
db := dbRaw.(*MongoDB)
|
||||
err = db.Initialize(connectionDetails, true)
|
||||
err = db.Initialize(context.Background(), connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -196,7 +197,7 @@ func TestMongoDB_RevokeUser(t *testing.T) {
|
|||
RoleName: "test",
|
||||
}
|
||||
|
||||
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute))
|
||||
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -206,7 +207,7 @@ func TestMongoDB_RevokeUser(t *testing.T) {
|
|||
}
|
||||
|
||||
// Test default revocation statememt
|
||||
err = db.RevokeUser(statements, username)
|
||||
err = db.RevokeUser(context.Background(), statements, username)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package mssql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
@ -18,6 +19,8 @@ import (
|
|||
|
||||
const msSQLTypeName = "mssql"
|
||||
|
||||
var _ dbplugin.Database = &MSSQL{}
|
||||
|
||||
// MSSQL is an implementation of Database interface
|
||||
type MSSQL struct {
|
||||
connutil.ConnectionProducer
|
||||
|
@ -60,8 +63,8 @@ func (m *MSSQL) Type() (string, error) {
|
|||
return msSQLTypeName, nil
|
||||
}
|
||||
|
||||
func (m *MSSQL) getConnection() (*sql.DB, error) {
|
||||
db, err := m.Connection()
|
||||
func (m *MSSQL) getConnection(ctx context.Context) (*sql.DB, error) {
|
||||
db, err := m.Connection(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -71,13 +74,13 @@ func (m *MSSQL) getConnection() (*sql.DB, error) {
|
|||
|
||||
// CreateUser generates the username/password on the underlying MSSQL secret backend as instructed by
|
||||
// the CreationStatement provided.
|
||||
func (m *MSSQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
|
||||
func (m *MSSQL) CreateUser(ctx context.Context, statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
|
||||
// Grab the lock
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
|
||||
// Get the connection
|
||||
db, err := m.getConnection()
|
||||
db, err := m.getConnection(ctx)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
@ -102,7 +105,7 @@ func (m *MSSQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplug
|
|||
}
|
||||
|
||||
// Start a transaction
|
||||
tx, err := db.Begin()
|
||||
tx, err := db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
@ -115,7 +118,7 @@ func (m *MSSQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplug
|
|||
continue
|
||||
}
|
||||
|
||||
stmt, err := tx.Prepare(dbutil.QueryHelper(query, map[string]string{
|
||||
stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{
|
||||
"name": username,
|
||||
"password": password,
|
||||
"expiration": expirationStr,
|
||||
|
@ -124,7 +127,7 @@ func (m *MSSQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplug
|
|||
return "", "", err
|
||||
}
|
||||
defer stmt.Close()
|
||||
if _, err := stmt.Exec(); err != nil {
|
||||
if _, err := stmt.ExecContext(ctx); err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
}
|
||||
|
@ -138,7 +141,7 @@ func (m *MSSQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplug
|
|||
}
|
||||
|
||||
// RenewUser is not supported on MSSQL, so this is a no-op.
|
||||
func (m *MSSQL) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error {
|
||||
func (m *MSSQL) RenewUser(ctx context.Context, statements dbplugin.Statements, username string, expiration time.Time) error {
|
||||
// NOOP
|
||||
return nil
|
||||
}
|
||||
|
@ -146,19 +149,19 @@ func (m *MSSQL) RenewUser(statements dbplugin.Statements, username string, expir
|
|||
// RevokeUser attempts to drop the specified user. It will first attempt to disable login,
|
||||
// then kill pending connections from that user, and finally drop the user and login from the
|
||||
// database instance.
|
||||
func (m *MSSQL) RevokeUser(statements dbplugin.Statements, username string) error {
|
||||
func (m *MSSQL) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error {
|
||||
if statements.RevocationStatements == "" {
|
||||
return m.revokeUserDefault(username)
|
||||
return m.revokeUserDefault(ctx, username)
|
||||
}
|
||||
|
||||
// Get connection
|
||||
db, err := m.getConnection()
|
||||
db, err := m.getConnection(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Start a transaction
|
||||
tx, err := db.Begin()
|
||||
tx, err := db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -171,14 +174,14 @@ func (m *MSSQL) RevokeUser(statements dbplugin.Statements, username string) erro
|
|||
continue
|
||||
}
|
||||
|
||||
stmt, err := tx.Prepare(dbutil.QueryHelper(query, map[string]string{
|
||||
stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{
|
||||
"name": username,
|
||||
}))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
if _, err := stmt.Exec(); err != nil {
|
||||
if _, err := stmt.ExecContext(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
@ -191,20 +194,20 @@ func (m *MSSQL) RevokeUser(statements dbplugin.Statements, username string) erro
|
|||
return nil
|
||||
}
|
||||
|
||||
func (m *MSSQL) revokeUserDefault(username string) error {
|
||||
func (m *MSSQL) revokeUserDefault(ctx context.Context, username string) error {
|
||||
// Get connection
|
||||
db, err := m.getConnection()
|
||||
db, err := m.getConnection(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// First disable server login
|
||||
disableStmt, err := db.Prepare(fmt.Sprintf("ALTER LOGIN [%s] DISABLE;", username))
|
||||
disableStmt, err := db.PrepareContext(ctx, fmt.Sprintf("ALTER LOGIN [%s] DISABLE;", username))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer disableStmt.Close()
|
||||
if _, err := disableStmt.Exec(); err != nil {
|
||||
if _, err := disableStmt.ExecContext(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -212,14 +215,14 @@ func (m *MSSQL) revokeUserDefault(username string) error {
|
|||
// sessions. There cannot be any active sessions before we drop the logins
|
||||
// This isn't done in a transaction because even if we fail along the way,
|
||||
// we want to remove as much access as possible
|
||||
sessionStmt, err := db.Prepare(fmt.Sprintf(
|
||||
sessionStmt, err := db.PrepareContext(ctx, fmt.Sprintf(
|
||||
"SELECT session_id FROM sys.dm_exec_sessions WHERE login_name = '%s';", username))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer sessionStmt.Close()
|
||||
|
||||
sessionRows, err := sessionStmt.Query()
|
||||
sessionRows, err := sessionStmt.QueryContext(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -240,13 +243,13 @@ func (m *MSSQL) revokeUserDefault(username string) error {
|
|||
// we need to drop the database users before we can drop the login and the role
|
||||
// This isn't done in a transaction because even if we fail along the way,
|
||||
// we want to remove as much access as possible
|
||||
stmt, err := db.Prepare(fmt.Sprintf("EXEC master.dbo.sp_msloginmappings '%s';", username))
|
||||
stmt, err := db.PrepareContext(ctx, fmt.Sprintf("EXEC master.dbo.sp_msloginmappings '%s';", username))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
rows, err := stmt.Query()
|
||||
rows, err := stmt.QueryContext(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -266,13 +269,13 @@ func (m *MSSQL) revokeUserDefault(username string) error {
|
|||
// many permissions as possible right now
|
||||
var lastStmtError error
|
||||
for _, query := range revokeStmts {
|
||||
stmt, err := db.Prepare(query)
|
||||
stmt, err := db.PrepareContext(ctx, query)
|
||||
if err != nil {
|
||||
lastStmtError = err
|
||||
continue
|
||||
}
|
||||
defer stmt.Close()
|
||||
_, err = stmt.Exec()
|
||||
_, err = stmt.ExecContext(ctx)
|
||||
if err != nil {
|
||||
lastStmtError = err
|
||||
}
|
||||
|
@ -287,12 +290,12 @@ func (m *MSSQL) revokeUserDefault(username string) error {
|
|||
}
|
||||
|
||||
// Drop this login
|
||||
stmt, err = db.Prepare(fmt.Sprintf(dropLoginSQL, username, username))
|
||||
stmt, err = db.PrepareContext(ctx, fmt.Sprintf(dropLoginSQL, username, username))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
if _, err := stmt.Exec(); err != nil {
|
||||
if _, err := stmt.ExecContext(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package mssql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"os"
|
||||
|
@ -30,7 +31,7 @@ func TestMSSQL_Initialize(t *testing.T) {
|
|||
dbRaw, _ := New()
|
||||
db := dbRaw.(*MSSQL)
|
||||
|
||||
err := db.Initialize(connectionDetails, true)
|
||||
err := db.Initialize(context.Background(), connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -51,7 +52,7 @@ func TestMSSQL_Initialize(t *testing.T) {
|
|||
"max_open_connections": "5",
|
||||
}
|
||||
|
||||
err = db.Initialize(connectionDetails, true)
|
||||
err = db.Initialize(context.Background(), connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -69,7 +70,7 @@ func TestMSSQL_CreateUser(t *testing.T) {
|
|||
|
||||
dbRaw, _ := New()
|
||||
db := dbRaw.(*MSSQL)
|
||||
err := db.Initialize(connectionDetails, true)
|
||||
err := db.Initialize(context.Background(), connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -80,7 +81,7 @@ func TestMSSQL_CreateUser(t *testing.T) {
|
|||
}
|
||||
|
||||
// Test with no configured Creation Statememt
|
||||
_, _, err = db.CreateUser(dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute))
|
||||
_, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute))
|
||||
if err == nil {
|
||||
t.Fatal("Expected error when no creation statement is provided")
|
||||
}
|
||||
|
@ -89,7 +90,7 @@ func TestMSSQL_CreateUser(t *testing.T) {
|
|||
CreationStatements: testMSSQLRole,
|
||||
}
|
||||
|
||||
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute))
|
||||
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -111,7 +112,7 @@ func TestMSSQL_RevokeUser(t *testing.T) {
|
|||
|
||||
dbRaw, _ := New()
|
||||
db := dbRaw.(*MSSQL)
|
||||
err := db.Initialize(connectionDetails, true)
|
||||
err := db.Initialize(context.Background(), connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -125,7 +126,7 @@ func TestMSSQL_RevokeUser(t *testing.T) {
|
|||
RoleName: "test",
|
||||
}
|
||||
|
||||
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(2*time.Second))
|
||||
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -135,7 +136,7 @@ func TestMSSQL_RevokeUser(t *testing.T) {
|
|||
}
|
||||
|
||||
// Test default revoke statememts
|
||||
err = db.RevokeUser(statements, username)
|
||||
err = db.RevokeUser(context.Background(), statements, username)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -144,7 +145,7 @@ func TestMSSQL_RevokeUser(t *testing.T) {
|
|||
t.Fatal("Credentials were not revoked")
|
||||
}
|
||||
|
||||
username, password, err = db.CreateUser(statements, usernameConfig, time.Now().Add(2*time.Second))
|
||||
username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -155,7 +156,7 @@ func TestMSSQL_RevokeUser(t *testing.T) {
|
|||
|
||||
// Test custom revoke statememt
|
||||
statements.RevocationStatements = testMSSQLDrop
|
||||
err = db.RevokeUser(statements, username)
|
||||
err = db.RevokeUser(context.Background(), statements, username)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"strings"
|
||||
"time"
|
||||
|
@ -30,6 +31,8 @@ var (
|
|||
LegacyUsernameLen int = 16
|
||||
)
|
||||
|
||||
var _ dbplugin.Database = &MySQL{}
|
||||
|
||||
type MySQL struct {
|
||||
connutil.ConnectionProducer
|
||||
credsutil.CredentialsProducer
|
||||
|
@ -88,8 +91,8 @@ func (m *MySQL) Type() (string, error) {
|
|||
return mySQLTypeName, nil
|
||||
}
|
||||
|
||||
func (m *MySQL) getConnection() (*sql.DB, error) {
|
||||
db, err := m.Connection()
|
||||
func (m *MySQL) getConnection(ctx context.Context) (*sql.DB, error) {
|
||||
db, err := m.Connection(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -97,13 +100,13 @@ func (m *MySQL) getConnection() (*sql.DB, error) {
|
|||
return db.(*sql.DB), nil
|
||||
}
|
||||
|
||||
func (m *MySQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
|
||||
func (m *MySQL) CreateUser(ctx context.Context, statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
|
||||
// Grab the lock
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
|
||||
// Get the connection
|
||||
db, err := m.getConnection()
|
||||
db, err := m.getConnection(ctx)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
@ -128,7 +131,7 @@ func (m *MySQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplug
|
|||
}
|
||||
|
||||
// Start a transaction
|
||||
tx, err := db.Begin()
|
||||
tx, err := db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
@ -146,7 +149,7 @@ func (m *MySQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplug
|
|||
"expiration": expirationStr,
|
||||
})
|
||||
|
||||
stmt, err := tx.Prepare(query)
|
||||
stmt, err := tx.PrepareContext(ctx, query)
|
||||
if err != nil {
|
||||
// If the error code we get back is Error 1295: This command is not
|
||||
// supported in the prepared statement protocol yet, we will execute
|
||||
|
@ -155,7 +158,7 @@ func (m *MySQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplug
|
|||
// prepare supported commands. If there is no error when running we
|
||||
// will continue to the next statement.
|
||||
if e, ok := err.(*stdmysql.MySQLError); ok && e.Number == 1295 {
|
||||
_, err = tx.Exec(query)
|
||||
_, err = tx.ExecContext(ctx, query)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
@ -165,7 +168,7 @@ func (m *MySQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplug
|
|||
return "", "", err
|
||||
}
|
||||
defer stmt.Close()
|
||||
if _, err := stmt.Exec(); err != nil {
|
||||
if _, err := stmt.ExecContext(ctx); err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
}
|
||||
|
@ -179,17 +182,17 @@ func (m *MySQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplug
|
|||
}
|
||||
|
||||
// NOOP
|
||||
func (m *MySQL) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error {
|
||||
func (m *MySQL) RenewUser(ctx context.Context, statements dbplugin.Statements, username string, expiration time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MySQL) RevokeUser(statements dbplugin.Statements, username string) error {
|
||||
func (m *MySQL) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error {
|
||||
// Grab the read lock
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
|
||||
// Get the connection
|
||||
db, err := m.getConnection()
|
||||
db, err := m.getConnection(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -201,7 +204,7 @@ func (m *MySQL) RevokeUser(statements dbplugin.Statements, username string) erro
|
|||
}
|
||||
|
||||
// Start a transaction
|
||||
tx, err := db.Begin()
|
||||
tx, err := db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -217,7 +220,7 @@ func (m *MySQL) RevokeUser(statements dbplugin.Statements, username string) erro
|
|||
// 1295: This command is not supported in the prepared statement protocol yet
|
||||
// Reference https://mariadb.com/kb/en/mariadb/prepare-statement/
|
||||
query = strings.Replace(query, "{{name}}", username, -1)
|
||||
_, err = tx.Exec(query)
|
||||
_, err = tx.ExecContext(ctx, query)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"os"
|
||||
|
@ -108,7 +109,7 @@ func TestMySQL_Initialize(t *testing.T) {
|
|||
db := dbRaw.(*MySQL)
|
||||
connProducer := db.ConnectionProducer.(*connutil.SQLConnectionProducer)
|
||||
|
||||
err := db.Initialize(connectionDetails, true)
|
||||
err := db.Initialize(context.Background(), connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -128,7 +129,7 @@ func TestMySQL_Initialize(t *testing.T) {
|
|||
"max_open_connections": "5",
|
||||
}
|
||||
|
||||
err = db.Initialize(connectionDetails, true)
|
||||
err = db.Initialize(context.Background(), connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -146,7 +147,7 @@ func TestMySQL_CreateUser(t *testing.T) {
|
|||
dbRaw, _ := f()
|
||||
db := dbRaw.(*MySQL)
|
||||
|
||||
err := db.Initialize(connectionDetails, true)
|
||||
err := db.Initialize(context.Background(), connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -157,7 +158,7 @@ func TestMySQL_CreateUser(t *testing.T) {
|
|||
}
|
||||
|
||||
// Test with no configured Creation Statememt
|
||||
_, _, err = db.CreateUser(dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute))
|
||||
_, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute))
|
||||
if err == nil {
|
||||
t.Fatal("Expected error when no creation statement is provided")
|
||||
}
|
||||
|
@ -166,7 +167,7 @@ func TestMySQL_CreateUser(t *testing.T) {
|
|||
CreationStatements: testMySQLRoleWildCard,
|
||||
}
|
||||
|
||||
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute))
|
||||
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -176,7 +177,7 @@ func TestMySQL_CreateUser(t *testing.T) {
|
|||
}
|
||||
|
||||
// Test a second time to make sure usernames don't collide
|
||||
username, password, err = db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute))
|
||||
username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -188,7 +189,7 @@ func TestMySQL_CreateUser(t *testing.T) {
|
|||
// Test with a manualy prepare statement
|
||||
statements.CreationStatements = testMySQLRolePreparedStmt
|
||||
|
||||
username, password, err = db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute))
|
||||
username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -211,7 +212,7 @@ func TestMySQL_CreateUser_Legacy(t *testing.T) {
|
|||
dbRaw, _ := f()
|
||||
db := dbRaw.(*MySQL)
|
||||
|
||||
err := db.Initialize(connectionDetails, true)
|
||||
err := db.Initialize(context.Background(), connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -222,7 +223,7 @@ func TestMySQL_CreateUser_Legacy(t *testing.T) {
|
|||
}
|
||||
|
||||
// Test with no configured Creation Statememt
|
||||
_, _, err = db.CreateUser(dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute))
|
||||
_, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute))
|
||||
if err == nil {
|
||||
t.Fatal("Expected error when no creation statement is provided")
|
||||
}
|
||||
|
@ -231,7 +232,7 @@ func TestMySQL_CreateUser_Legacy(t *testing.T) {
|
|||
CreationStatements: testMySQLRoleWildCard,
|
||||
}
|
||||
|
||||
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute))
|
||||
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -241,7 +242,7 @@ func TestMySQL_CreateUser_Legacy(t *testing.T) {
|
|||
}
|
||||
|
||||
// Test a second time to make sure usernames don't collide
|
||||
username, password, err = db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute))
|
||||
username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -263,7 +264,7 @@ func TestMySQL_RevokeUser(t *testing.T) {
|
|||
dbRaw, _ := f()
|
||||
db := dbRaw.(*MySQL)
|
||||
|
||||
err := db.Initialize(connectionDetails, true)
|
||||
err := db.Initialize(context.Background(), connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -277,7 +278,7 @@ func TestMySQL_RevokeUser(t *testing.T) {
|
|||
RoleName: "test",
|
||||
}
|
||||
|
||||
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute))
|
||||
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -287,7 +288,7 @@ func TestMySQL_RevokeUser(t *testing.T) {
|
|||
}
|
||||
|
||||
// Test default revoke statememts
|
||||
err = db.RevokeUser(statements, username)
|
||||
err = db.RevokeUser(context.Background(), statements, username)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -297,7 +298,7 @@ func TestMySQL_RevokeUser(t *testing.T) {
|
|||
}
|
||||
|
||||
statements.CreationStatements = testMySQLRoleWildCard
|
||||
username, password, err = db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute))
|
||||
username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -308,7 +309,7 @@ func TestMySQL_RevokeUser(t *testing.T) {
|
|||
|
||||
// Test custom revoke statements
|
||||
statements.RevocationStatements = testMySQLRevocationSQL
|
||||
err = db.RevokeUser(statements, username)
|
||||
err = db.RevokeUser(context.Background(), statements, username)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package postgresql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
@ -24,6 +25,8 @@ ALTER ROLE "{{name}}" VALID UNTIL '{{expiration}}';
|
|||
`
|
||||
)
|
||||
|
||||
var _ dbplugin.Database = &PostgreSQL{}
|
||||
|
||||
// New implements builtinplugins.BuiltinFactory
|
||||
func New() (interface{}, error) {
|
||||
connProducer := &connutil.SQLConnectionProducer{}
|
||||
|
@ -65,8 +68,8 @@ func (p *PostgreSQL) Type() (string, error) {
|
|||
return postgreSQLTypeName, nil
|
||||
}
|
||||
|
||||
func (p *PostgreSQL) getConnection() (*sql.DB, error) {
|
||||
db, err := p.Connection()
|
||||
func (p *PostgreSQL) getConnection(ctx context.Context) (*sql.DB, error) {
|
||||
db, err := p.Connection(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -74,7 +77,7 @@ func (p *PostgreSQL) getConnection() (*sql.DB, error) {
|
|||
return db.(*sql.DB), nil
|
||||
}
|
||||
|
||||
func (p *PostgreSQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
|
||||
func (p *PostgreSQL) CreateUser(ctx context.Context, statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
|
||||
if statements.CreationStatements == "" {
|
||||
return "", "", dbutil.ErrEmptyCreationStatement
|
||||
}
|
||||
|
@ -99,14 +102,14 @@ func (p *PostgreSQL) CreateUser(statements dbplugin.Statements, usernameConfig d
|
|||
}
|
||||
|
||||
// Get the connection
|
||||
db, err := p.getConnection()
|
||||
db, err := p.getConnection(ctx)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
|
||||
}
|
||||
|
||||
// Start a transaction
|
||||
tx, err := db.Begin()
|
||||
tx, err := db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
|
||||
|
@ -123,7 +126,7 @@ func (p *PostgreSQL) CreateUser(statements dbplugin.Statements, usernameConfig d
|
|||
continue
|
||||
}
|
||||
|
||||
stmt, err := tx.Prepare(dbutil.QueryHelper(query, map[string]string{
|
||||
stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{
|
||||
"name": username,
|
||||
"password": password,
|
||||
"expiration": expirationStr,
|
||||
|
@ -133,7 +136,7 @@ func (p *PostgreSQL) CreateUser(statements dbplugin.Statements, usernameConfig d
|
|||
|
||||
}
|
||||
defer stmt.Close()
|
||||
if _, err := stmt.Exec(); err != nil {
|
||||
if _, err := stmt.ExecContext(ctx); err != nil {
|
||||
return "", "", err
|
||||
|
||||
}
|
||||
|
@ -148,7 +151,7 @@ func (p *PostgreSQL) CreateUser(statements dbplugin.Statements, usernameConfig d
|
|||
return username, password, nil
|
||||
}
|
||||
|
||||
func (p *PostgreSQL) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error {
|
||||
func (p *PostgreSQL) RenewUser(ctx context.Context, statements dbplugin.Statements, username string, expiration time.Time) error {
|
||||
p.Lock()
|
||||
defer p.Unlock()
|
||||
|
||||
|
@ -157,12 +160,12 @@ func (p *PostgreSQL) RenewUser(statements dbplugin.Statements, username string,
|
|||
renewStmts = defaultPostgresRenewSQL
|
||||
}
|
||||
|
||||
db, err := p.getConnection()
|
||||
db, err := p.getConnection(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tx, err := db.Begin()
|
||||
tx, err := db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -180,7 +183,7 @@ func (p *PostgreSQL) RenewUser(statements dbplugin.Statements, username string,
|
|||
if len(query) == 0 {
|
||||
continue
|
||||
}
|
||||
stmt, err := tx.Prepare(dbutil.QueryHelper(query, map[string]string{
|
||||
stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{
|
||||
"name": username,
|
||||
"expiration": expirationStr,
|
||||
}))
|
||||
|
@ -189,7 +192,7 @@ func (p *PostgreSQL) RenewUser(statements dbplugin.Statements, username string,
|
|||
}
|
||||
|
||||
defer stmt.Close()
|
||||
if _, err := stmt.Exec(); err != nil {
|
||||
if _, err := stmt.ExecContext(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
@ -201,25 +204,25 @@ func (p *PostgreSQL) RenewUser(statements dbplugin.Statements, username string,
|
|||
return nil
|
||||
}
|
||||
|
||||
func (p *PostgreSQL) RevokeUser(statements dbplugin.Statements, username string) error {
|
||||
func (p *PostgreSQL) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error {
|
||||
// Grab the lock
|
||||
p.Lock()
|
||||
defer p.Unlock()
|
||||
|
||||
if statements.RevocationStatements == "" {
|
||||
return p.defaultRevokeUser(username)
|
||||
return p.defaultRevokeUser(ctx, username)
|
||||
}
|
||||
|
||||
return p.customRevokeUser(username, statements.RevocationStatements)
|
||||
return p.customRevokeUser(ctx, username, statements.RevocationStatements)
|
||||
}
|
||||
|
||||
func (p *PostgreSQL) customRevokeUser(username, revocationStmts string) error {
|
||||
db, err := p.getConnection()
|
||||
func (p *PostgreSQL) customRevokeUser(ctx context.Context, username, revocationStmts string) error {
|
||||
db, err := p.getConnection(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tx, err := db.Begin()
|
||||
tx, err := db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -233,7 +236,7 @@ func (p *PostgreSQL) customRevokeUser(username, revocationStmts string) error {
|
|||
continue
|
||||
}
|
||||
|
||||
stmt, err := tx.Prepare(dbutil.QueryHelper(query, map[string]string{
|
||||
stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{
|
||||
"name": username,
|
||||
}))
|
||||
if err != nil {
|
||||
|
@ -241,7 +244,7 @@ func (p *PostgreSQL) customRevokeUser(username, revocationStmts string) error {
|
|||
}
|
||||
defer stmt.Close()
|
||||
|
||||
if _, err := stmt.Exec(); err != nil {
|
||||
if _, err := stmt.ExecContext(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
@ -253,15 +256,15 @@ func (p *PostgreSQL) customRevokeUser(username, revocationStmts string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (p *PostgreSQL) defaultRevokeUser(username string) error {
|
||||
db, err := p.getConnection()
|
||||
func (p *PostgreSQL) defaultRevokeUser(ctx context.Context, username string) error {
|
||||
db, err := p.getConnection(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Check if the role exists
|
||||
var exists bool
|
||||
err = db.QueryRow("SELECT exists (SELECT rolname FROM pg_roles WHERE rolname=$1);", username).Scan(&exists)
|
||||
err = db.QueryRowContext(ctx, "SELECT exists (SELECT rolname FROM pg_roles WHERE rolname=$1);", username).Scan(&exists)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
return err
|
||||
}
|
||||
|
@ -274,13 +277,13 @@ func (p *PostgreSQL) defaultRevokeUser(username string) error {
|
|||
// the role
|
||||
// This isn't done in a transaction because even if we fail along the way,
|
||||
// we want to remove as much access as possible
|
||||
stmt, err := db.Prepare("SELECT DISTINCT table_schema FROM information_schema.role_column_grants WHERE grantee=$1;")
|
||||
stmt, err := db.PrepareContext(ctx, "SELECT DISTINCT table_schema FROM information_schema.role_column_grants WHERE grantee=$1;")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
rows, err := stmt.Query(username)
|
||||
rows, err := stmt.QueryContext(ctx, username)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -322,7 +325,7 @@ func (p *PostgreSQL) defaultRevokeUser(username string) error {
|
|||
// get the current database name so we can issue a REVOKE CONNECT for
|
||||
// this username
|
||||
var dbname sql.NullString
|
||||
if err := db.QueryRow("SELECT current_database();").Scan(&dbname); err != nil {
|
||||
if err := db.QueryRowContext(ctx, "SELECT current_database();").Scan(&dbname); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -337,13 +340,13 @@ func (p *PostgreSQL) defaultRevokeUser(username string) error {
|
|||
// many permissions as possible right now
|
||||
var lastStmtError error
|
||||
for _, query := range revocationStmts {
|
||||
stmt, err := db.Prepare(query)
|
||||
stmt, err := db.PrepareContext(ctx, query)
|
||||
if err != nil {
|
||||
lastStmtError = err
|
||||
continue
|
||||
}
|
||||
defer stmt.Close()
|
||||
_, err = stmt.Exec()
|
||||
_, err = stmt.ExecContext(ctx)
|
||||
if err != nil {
|
||||
lastStmtError = err
|
||||
}
|
||||
|
@ -358,13 +361,13 @@ func (p *PostgreSQL) defaultRevokeUser(username string) error {
|
|||
}
|
||||
|
||||
// Drop this user
|
||||
stmt, err = db.Prepare(fmt.Sprintf(
|
||||
stmt, err = db.PrepareContext(ctx, fmt.Sprintf(
|
||||
`DROP ROLE IF EXISTS %s;`, pq.QuoteIdentifier(username)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
if _, err := stmt.Exec(); err != nil {
|
||||
if _, err := stmt.ExecContext(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package postgresql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"os"
|
||||
|
@ -72,7 +73,7 @@ func TestPostgreSQL_Initialize(t *testing.T) {
|
|||
|
||||
connProducer := db.ConnectionProducer.(*connutil.SQLConnectionProducer)
|
||||
|
||||
err := db.Initialize(connectionDetails, true)
|
||||
err := db.Initialize(context.Background(), connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -92,7 +93,7 @@ func TestPostgreSQL_Initialize(t *testing.T) {
|
|||
"max_open_connections": "5",
|
||||
}
|
||||
|
||||
err = db.Initialize(connectionDetails, true)
|
||||
err = db.Initialize(context.Background(), connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -109,7 +110,7 @@ func TestPostgreSQL_CreateUser(t *testing.T) {
|
|||
|
||||
dbRaw, _ := New()
|
||||
db := dbRaw.(*PostgreSQL)
|
||||
err := db.Initialize(connectionDetails, true)
|
||||
err := db.Initialize(context.Background(), connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -120,7 +121,7 @@ func TestPostgreSQL_CreateUser(t *testing.T) {
|
|||
}
|
||||
|
||||
// Test with no configured Creation Statememt
|
||||
_, _, err = db.CreateUser(dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute))
|
||||
_, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute))
|
||||
if err == nil {
|
||||
t.Fatal("Expected error when no creation statement is provided")
|
||||
}
|
||||
|
@ -129,7 +130,7 @@ func TestPostgreSQL_CreateUser(t *testing.T) {
|
|||
CreationStatements: testPostgresRole,
|
||||
}
|
||||
|
||||
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute))
|
||||
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -139,7 +140,7 @@ func TestPostgreSQL_CreateUser(t *testing.T) {
|
|||
}
|
||||
|
||||
statements.CreationStatements = testPostgresReadOnlyRole
|
||||
username, password, err = db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute))
|
||||
username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -162,7 +163,7 @@ func TestPostgreSQL_RenewUser(t *testing.T) {
|
|||
|
||||
dbRaw, _ := New()
|
||||
db := dbRaw.(*PostgreSQL)
|
||||
err := db.Initialize(connectionDetails, true)
|
||||
err := db.Initialize(context.Background(), connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -176,7 +177,7 @@ func TestPostgreSQL_RenewUser(t *testing.T) {
|
|||
RoleName: "test",
|
||||
}
|
||||
|
||||
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(2*time.Second))
|
||||
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -185,7 +186,7 @@ func TestPostgreSQL_RenewUser(t *testing.T) {
|
|||
t.Fatalf("Could not connect with new credentials: %s", err)
|
||||
}
|
||||
|
||||
err = db.RenewUser(statements, username, time.Now().Add(time.Minute))
|
||||
err = db.RenewUser(context.Background(), statements, username, time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -197,7 +198,7 @@ func TestPostgreSQL_RenewUser(t *testing.T) {
|
|||
t.Fatalf("Could not connect with new credentials: %s", err)
|
||||
}
|
||||
statements.RenewStatements = defaultPostgresRenewSQL
|
||||
username, password, err = db.CreateUser(statements, usernameConfig, time.Now().Add(2*time.Second))
|
||||
username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -206,7 +207,7 @@ func TestPostgreSQL_RenewUser(t *testing.T) {
|
|||
t.Fatalf("Could not connect with new credentials: %s", err)
|
||||
}
|
||||
|
||||
err = db.RenewUser(statements, username, time.Now().Add(time.Minute))
|
||||
err = db.RenewUser(context.Background(), statements, username, time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -230,7 +231,7 @@ func TestPostgreSQL_RevokeUser(t *testing.T) {
|
|||
|
||||
dbRaw, _ := New()
|
||||
db := dbRaw.(*PostgreSQL)
|
||||
err := db.Initialize(connectionDetails, true)
|
||||
err := db.Initialize(context.Background(), connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -244,7 +245,7 @@ func TestPostgreSQL_RevokeUser(t *testing.T) {
|
|||
RoleName: "test",
|
||||
}
|
||||
|
||||
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(2*time.Second))
|
||||
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -254,7 +255,7 @@ func TestPostgreSQL_RevokeUser(t *testing.T) {
|
|||
}
|
||||
|
||||
// Test default revoke statememts
|
||||
err = db.RevokeUser(statements, username)
|
||||
err = db.RevokeUser(context.Background(), statements, username)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -263,7 +264,7 @@ func TestPostgreSQL_RevokeUser(t *testing.T) {
|
|||
t.Fatal("Credentials were not revoked")
|
||||
}
|
||||
|
||||
username, password, err = db.CreateUser(statements, usernameConfig, time.Now().Add(2*time.Second))
|
||||
username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
@ -274,7 +275,7 @@ func TestPostgreSQL_RevokeUser(t *testing.T) {
|
|||
|
||||
// Test custom revoke statements
|
||||
statements.RevocationStatements = defaultPostgresRevocationSQL
|
||||
err = db.RevokeUser(statements, username)
|
||||
err = db.RevokeUser(context.Background(), statements, username)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package connutil
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
)
|
||||
|
@ -14,8 +15,8 @@ var (
|
|||
// connections and is used in all the builtin database types.
|
||||
type ConnectionProducer interface {
|
||||
Close() error
|
||||
Initialize(map[string]interface{}, bool) error
|
||||
Connection() (interface{}, error)
|
||||
Initialize(context.Context, map[string]interface{}, bool) error
|
||||
Connection(context.Context) (interface{}, error)
|
||||
|
||||
sync.Locker
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package connutil
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
@ -25,7 +26,7 @@ type SQLConnectionProducer struct {
|
|||
sync.Mutex
|
||||
}
|
||||
|
||||
func (c *SQLConnectionProducer) Initialize(conf map[string]interface{}, verifyConnection bool) error {
|
||||
func (c *SQLConnectionProducer) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
|
||||
|
@ -62,11 +63,11 @@ func (c *SQLConnectionProducer) Initialize(conf map[string]interface{}, verifyCo
|
|||
c.Initialized = true
|
||||
|
||||
if verifyConnection {
|
||||
if _, err := c.Connection(); err != nil {
|
||||
if _, err := c.Connection(ctx); err != nil {
|
||||
return fmt.Errorf("error verifying connection: %s", err)
|
||||
}
|
||||
|
||||
if err := c.db.Ping(); err != nil {
|
||||
if err := c.db.PingContext(ctx); err != nil {
|
||||
return fmt.Errorf("error verifying connection: %s", err)
|
||||
}
|
||||
}
|
||||
|
@ -74,14 +75,14 @@ func (c *SQLConnectionProducer) Initialize(conf map[string]interface{}, verifyCo
|
|||
return nil
|
||||
}
|
||||
|
||||
func (c *SQLConnectionProducer) Connection() (interface{}, error) {
|
||||
func (c *SQLConnectionProducer) Connection(ctx context.Context) (interface{}, error) {
|
||||
if !c.Initialized {
|
||||
return nil, ErrNotInitialized
|
||||
}
|
||||
|
||||
// If we already have a DB, test it and return
|
||||
if c.db != nil {
|
||||
if err := c.db.Ping(); err == nil {
|
||||
if err := c.db.PingContext(ctx); err == nil {
|
||||
return c.db, nil
|
||||
}
|
||||
// If the ping was unsuccessful, close it and ignore errors as we'll be
|
||||
|
|
|
@ -5,7 +5,6 @@ import (
|
|||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/errwrap"
|
||||
"github.com/hashicorp/go-uuid"
|
||||
"github.com/hashicorp/vault/helper/consts"
|
||||
"github.com/hashicorp/vault/helper/jsonutil"
|
||||
|
@ -460,10 +459,11 @@ func (c *Core) setupCredentials() error {
|
|||
backend, err = c.newCredentialBackend(entry.Type, sysView, view, conf)
|
||||
if err != nil {
|
||||
c.logger.Error("core: failed to create credential entry", "path", entry.Path, "error", err)
|
||||
if errwrap.Contains(err, ErrPluginNotFound.Error()) && entry.Type == "plugin" {
|
||||
// If we encounter an error instantiating the backend due to it being missing from the catalog,
|
||||
// skip backend initialization but register the entry to the mount table to preserve storage
|
||||
// and path.
|
||||
if entry.Type == "plugin" {
|
||||
// If we encounter an error instantiating the backend due to an error,
|
||||
// skip backend initialization but register the entry to the mount table
|
||||
// to preserve storage and path.
|
||||
c.logger.Warn("core: skipping plugin-based credential entry", "path", entry.Path)
|
||||
goto ROUTER_MOUNT
|
||||
}
|
||||
return errLoadAuthFailed
|
||||
|
|
|
@ -750,6 +750,31 @@ func (m *ExpirationManager) RenewToken(req *logical.Request, source string, toke
|
|||
}, nil
|
||||
}
|
||||
|
||||
sysView := m.router.MatchingSystemView(le.Path)
|
||||
if sysView == nil {
|
||||
return nil, fmt.Errorf("expiration: unable to retrieve system view from router")
|
||||
}
|
||||
|
||||
retResp := &logical.Response{}
|
||||
switch {
|
||||
case resp.Auth.Period > time.Duration(0):
|
||||
// If it resp.Period is non-zero, use that as the TTL and override backend's
|
||||
// call on TTL modification, such as a TTL value determined by
|
||||
// framework.LeaseExtend call against the request. Also, cap period value to
|
||||
// the sys/mount max value.
|
||||
if resp.Auth.Period > sysView.MaxLeaseTTL() {
|
||||
retResp.AddWarning(fmt.Sprintf("Period of %d seconds is greater than current mount/system default of %d seconds, value will be truncated.", resp.Auth.TTL, sysView.MaxLeaseTTL()))
|
||||
resp.Auth.Period = sysView.MaxLeaseTTL()
|
||||
}
|
||||
resp.Auth.TTL = resp.Auth.Period
|
||||
case resp.Auth.TTL > time.Duration(0):
|
||||
// Cap TTL value to the sys/mount max value
|
||||
if resp.Auth.TTL > sysView.MaxLeaseTTL() {
|
||||
retResp.AddWarning(fmt.Sprintf("TTL of %d seconds is greater than current mount/system default of %d seconds, value will be truncated.", resp.Auth.TTL, sysView.MaxLeaseTTL()))
|
||||
resp.Auth.TTL = sysView.MaxLeaseTTL()
|
||||
}
|
||||
}
|
||||
|
||||
// Attach the ClientToken
|
||||
resp.Auth.ClientToken = token
|
||||
resp.Auth.Increment = 0
|
||||
|
@ -764,9 +789,9 @@ func (m *ExpirationManager) RenewToken(req *logical.Request, source string, toke
|
|||
|
||||
// Update the expiration time
|
||||
m.updatePending(le, resp.Auth.LeaseTotal())
|
||||
return &logical.Response{
|
||||
Auth: resp.Auth,
|
||||
}, nil
|
||||
|
||||
retResp.Auth = resp.Auth
|
||||
return retResp, nil
|
||||
}
|
||||
|
||||
// Register is used to take a request and response with an associated
|
||||
|
@ -866,6 +891,12 @@ func (m *ExpirationManager) RegisterAuth(source string, auth *logical.Auth) erro
|
|||
return err
|
||||
}
|
||||
|
||||
// If it resp.Period is non-zero, override the TTL value determined
|
||||
// by the backend.
|
||||
if auth.Period > time.Duration(0) {
|
||||
auth.TTL = auth.Period
|
||||
}
|
||||
|
||||
// Create a lease entry
|
||||
le := leaseEntry{
|
||||
LeaseID: path.Join(source, saltedID),
|
||||
|
@ -1017,8 +1048,7 @@ func (m *ExpirationManager) revokeEntry(le *leaseEntry) error {
|
|||
}
|
||||
|
||||
// Handle standard revocation via backends
|
||||
resp, err := m.router.Route(logical.RevokeRequest(
|
||||
le.Path, le.Secret, le.Data))
|
||||
resp, err := m.router.Route(logical.RevokeRequest(le.Path, le.Secret, le.Data))
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
return fmt.Errorf("failed to revoke entry: resp:%#v err:%s", resp, err)
|
||||
}
|
||||
|
|
|
@ -2,7 +2,9 @@ package vault_test
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -178,14 +180,13 @@ func testPlugin_CatalogRemoved(t *testing.T, btype logical.BackendType, testMoun
|
|||
}
|
||||
|
||||
if testMount {
|
||||
// Add plugin back to the catalog
|
||||
vault.TestAddTestPlugin(t, core.Core, "mock-plugin", "TestBackend_PluginMainLogical")
|
||||
|
||||
// Mount the plugin at the same path after plugin is re-added to the catalog
|
||||
// and expect an error due to existing path.
|
||||
var err error
|
||||
switch btype {
|
||||
case logical.TypeLogical:
|
||||
// Add plugin back to the catalog
|
||||
vault.TestAddTestPlugin(t, core.Core, "mock-plugin", "TestBackend_PluginMainLogical")
|
||||
_, err = core.Client.Logical().Write("sys/mounts/mock-0", map[string]interface{}{
|
||||
"type": "plugin",
|
||||
"config": map[string]interface{}{
|
||||
|
@ -193,6 +194,8 @@ func testPlugin_CatalogRemoved(t *testing.T, btype logical.BackendType, testMoun
|
|||
},
|
||||
})
|
||||
case logical.TypeCredential:
|
||||
// Add plugin back to the catalog
|
||||
vault.TestAddTestPlugin(t, core.Core, "mock-plugin", "TestBackend_PluginMainCredentials")
|
||||
_, err = core.Client.Logical().Write("sys/auth/mock-0", map[string]interface{}{
|
||||
"type": "plugin",
|
||||
"plugin_name": "mock-plugin",
|
||||
|
@ -204,6 +207,129 @@ func testPlugin_CatalogRemoved(t *testing.T, btype logical.BackendType, testMoun
|
|||
}
|
||||
}
|
||||
|
||||
func TestSystemBackend_Plugin_continueOnError(t *testing.T) {
|
||||
t.Run("secret", func(t *testing.T) {
|
||||
t.Run("sha256_mismatch", func(t *testing.T) {
|
||||
testPlugin_continueOnError(t, logical.TypeLogical, true)
|
||||
})
|
||||
|
||||
t.Run("missing_plugin", func(t *testing.T) {
|
||||
testPlugin_continueOnError(t, logical.TypeLogical, false)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("auth", func(t *testing.T) {
|
||||
t.Run("sha256_mismatch", func(t *testing.T) {
|
||||
testPlugin_continueOnError(t, logical.TypeCredential, true)
|
||||
})
|
||||
|
||||
t.Run("missing_plugin", func(t *testing.T) {
|
||||
testPlugin_continueOnError(t, logical.TypeCredential, false)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func testPlugin_continueOnError(t *testing.T, btype logical.BackendType, mismatch bool) {
|
||||
cluster := testSystemBackendMock(t, 1, 1, btype)
|
||||
defer cluster.Cleanup()
|
||||
|
||||
core := cluster.Cores[0]
|
||||
|
||||
// Get the registered plugin
|
||||
req := logical.TestRequest(t, logical.ReadOperation, "sys/plugins/catalog/mock-plugin")
|
||||
req.ClientToken = core.Client.Token()
|
||||
resp, err := core.HandleRequest(req)
|
||||
if err != nil || resp == nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("err:%v resp:%#v", err, resp)
|
||||
}
|
||||
|
||||
command, ok := resp.Data["command"].(string)
|
||||
if !ok || command == "" {
|
||||
t.Fatal("invalid command")
|
||||
}
|
||||
|
||||
// Trigger a sha256 mistmatch or missing plugin error
|
||||
if mismatch {
|
||||
req = logical.TestRequest(t, logical.UpdateOperation, "sys/plugins/catalog/mock-plugin")
|
||||
req.Data = map[string]interface{}{
|
||||
"sha256": "d17bd7334758e53e6fbab15745d2520765c06e296f2ce8e25b7919effa0ac216",
|
||||
"command": filepath.Base(command),
|
||||
}
|
||||
req.ClientToken = core.Client.Token()
|
||||
resp, err = core.HandleRequest(req)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("err:%v resp:%#v", err, resp)
|
||||
}
|
||||
} else {
|
||||
err := os.Remove(filepath.Join(cluster.TempDir, filepath.Base(command)))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Seal the cluster
|
||||
cluster.EnsureCoresSealed(t)
|
||||
|
||||
// Unseal the cluster
|
||||
barrierKeys := cluster.BarrierKeys
|
||||
for _, core := range cluster.Cores {
|
||||
for _, key := range barrierKeys {
|
||||
_, err := core.Unseal(vault.TestKeyCopy(key))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
sealed, err := core.Sealed()
|
||||
if err != nil {
|
||||
t.Fatalf("err checking seal status: %s", err)
|
||||
}
|
||||
if sealed {
|
||||
t.Fatal("should not be sealed")
|
||||
}
|
||||
// Wait for active so post-unseal takes place
|
||||
// If it fails, it means unseal process failed
|
||||
vault.TestWaitActive(t, core.Core)
|
||||
}
|
||||
|
||||
// Re-add the plugin to the catalog
|
||||
switch btype {
|
||||
case logical.TypeLogical:
|
||||
vault.TestAddTestPluginTempDir(t, core.Core, "mock-plugin", "TestBackend_PluginMainLogical", cluster.TempDir)
|
||||
case logical.TypeCredential:
|
||||
vault.TestAddTestPluginTempDir(t, core.Core, "mock-plugin", "TestBackend_PluginMainCredentials", cluster.TempDir)
|
||||
}
|
||||
|
||||
// Reload the plugin
|
||||
req = logical.TestRequest(t, logical.UpdateOperation, "sys/plugins/reload/backend")
|
||||
req.Data = map[string]interface{}{
|
||||
"plugin": "mock-plugin",
|
||||
}
|
||||
req.ClientToken = core.Client.Token()
|
||||
resp, err = core.HandleRequest(req)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("err:%v resp:%#v", err, resp)
|
||||
}
|
||||
|
||||
// Make a request to lazy load the plugin
|
||||
var reqPath string
|
||||
switch btype {
|
||||
case logical.TypeLogical:
|
||||
reqPath = "mock-0/internal"
|
||||
case logical.TypeCredential:
|
||||
reqPath = "auth/mock-0/internal"
|
||||
}
|
||||
|
||||
req = logical.TestRequest(t, logical.ReadOperation, reqPath)
|
||||
req.ClientToken = core.Client.Token()
|
||||
resp, err = core.HandleRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if resp == nil {
|
||||
t.Fatalf("bad: response should not be nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSystemBackend_Plugin_autoReload(t *testing.T) {
|
||||
cluster := testSystemBackendMock(t, 1, 1, logical.TypeLogical)
|
||||
defer cluster.Cleanup()
|
||||
|
@ -332,7 +458,10 @@ func testSystemBackend_PluginReload(t *testing.T, reqData map[string]interface{}
|
|||
}
|
||||
|
||||
// testSystemBackendMock returns a systemBackend with the desired number
|
||||
// of mounted mock plugin backends
|
||||
// of mounted mock plugin backends. numMounts alternates between different
|
||||
// ways of providing the plugin_name.
|
||||
//
|
||||
// The mounts are mounted at sys/mounts/mock-[numMounts] or sys/auth/mock-[numMounts]
|
||||
func testSystemBackendMock(t *testing.T, numCores, numMounts int, backendType logical.BackendType) *vault.TestCluster {
|
||||
coreConfig := &vault.CoreConfig{
|
||||
LogicalBackends: map[string]logical.Factory{
|
||||
|
@ -343,10 +472,17 @@ func testSystemBackendMock(t *testing.T, numCores, numMounts int, backendType lo
|
|||
},
|
||||
}
|
||||
|
||||
// Create a tempdir, cluster.Cleanup will clean up this directory
|
||||
tempDir, err := ioutil.TempDir("", "vault-test-cluster")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{
|
||||
HandlerFunc: vaulthttp.Handler,
|
||||
KeepStandbysSealed: true,
|
||||
NumCores: numCores,
|
||||
TempDir: tempDir,
|
||||
})
|
||||
cluster.Start()
|
||||
|
||||
|
@ -358,7 +494,7 @@ func testSystemBackendMock(t *testing.T, numCores, numMounts int, backendType lo
|
|||
|
||||
switch backendType {
|
||||
case logical.TypeLogical:
|
||||
vault.TestAddTestPlugin(t, core.Core, "mock-plugin", "TestBackend_PluginMainLogical")
|
||||
vault.TestAddTestPluginTempDir(t, core.Core, "mock-plugin", "TestBackend_PluginMainLogical", tempDir)
|
||||
for i := 0; i < numMounts; i++ {
|
||||
// Alternate input styles for plugin_name on every other mount
|
||||
options := map[string]interface{}{
|
||||
|
@ -380,7 +516,7 @@ func testSystemBackendMock(t *testing.T, numCores, numMounts int, backendType lo
|
|||
}
|
||||
}
|
||||
case logical.TypeCredential:
|
||||
vault.TestAddTestPlugin(t, core.Core, "mock-plugin", "TestBackend_PluginMainCredentials")
|
||||
vault.TestAddTestPluginTempDir(t, core.Core, "mock-plugin", "TestBackend_PluginMainCredentials", tempDir)
|
||||
for i := 0; i < numMounts; i++ {
|
||||
// Alternate input styles for plugin_name on every other mount
|
||||
options := map[string]interface{}{
|
||||
|
|
|
@ -7,7 +7,6 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/errwrap"
|
||||
"github.com/hashicorp/go-uuid"
|
||||
"github.com/hashicorp/vault/helper/consts"
|
||||
"github.com/hashicorp/vault/helper/jsonutil"
|
||||
|
@ -753,10 +752,11 @@ func (c *Core) setupMounts() error {
|
|||
backend, err = c.newLogicalBackend(entry.Type, sysView, view, conf)
|
||||
if err != nil {
|
||||
c.logger.Error("core: failed to create mount entry", "path", entry.Path, "error", err)
|
||||
if errwrap.Contains(err, ErrPluginNotFound.Error()) && entry.Type == "plugin" {
|
||||
// If we encounter an error instantiating the backend due to it being missing from the catalog,
|
||||
// skip backend initialization but register the entry to the mount table to preserve storage
|
||||
// and path.
|
||||
if entry.Type == "plugin" {
|
||||
// If we encounter an error instantiating the backend due to an error,
|
||||
// skip backend initialization but register the entry to the mount table
|
||||
// to preserve storage and path.
|
||||
c.logger.Warn("core: skipping plugin-based mount entry", "path", entry.Path)
|
||||
goto ROUTER_MOUNT
|
||||
}
|
||||
return errLoadMountsFailed
|
||||
|
|
|
@ -79,15 +79,23 @@ func (c *Core) reloadMatchingPlugin(pluginName string) error {
|
|||
func (c *Core) reloadPluginCommon(entry *MountEntry, isAuth bool) error {
|
||||
path := entry.Path
|
||||
|
||||
if isAuth {
|
||||
path = credentialRoutePrefix + path
|
||||
}
|
||||
|
||||
// Fast-path out if the backend doesn't exist
|
||||
raw, ok := c.router.root.Get(path)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Call backend's Cleanup routine
|
||||
re := raw.(*routeEntry)
|
||||
re.backend.Cleanup()
|
||||
|
||||
// Only call Cleanup if backend is initialized
|
||||
if re.backend != nil {
|
||||
// Call backend's Cleanup routine
|
||||
re.backend.Cleanup()
|
||||
}
|
||||
|
||||
view := re.storageView
|
||||
|
||||
|
|
|
@ -477,14 +477,27 @@ func (c *Core) handleLoginRequest(req *logical.Request) (retResp *logical.Respon
|
|||
return nil, nil, ErrInternalError
|
||||
}
|
||||
|
||||
// Set the default lease if not provided
|
||||
if auth.TTL == 0 {
|
||||
auth.TTL = sysView.DefaultLeaseTTL()
|
||||
}
|
||||
// Start off with the sys default value, and update according to period/TTL
|
||||
// from resp.Auth
|
||||
tokenTTL := sysView.DefaultLeaseTTL()
|
||||
|
||||
// Limit the lease duration
|
||||
if auth.TTL > sysView.MaxLeaseTTL() {
|
||||
auth.TTL = sysView.MaxLeaseTTL()
|
||||
switch {
|
||||
case auth.Period > time.Duration(0):
|
||||
// Cap the period value to the sys max_ttl value. The auth backend should
|
||||
// have checked for it on its login path, but we check here again for
|
||||
// sanity.
|
||||
if auth.Period > sysView.MaxLeaseTTL() {
|
||||
auth.Period = sysView.MaxLeaseTTL()
|
||||
}
|
||||
tokenTTL = auth.Period
|
||||
case auth.TTL > time.Duration(0):
|
||||
// Cap the TTL value. The auth backend should have checked for it on its
|
||||
// login path (e.g. a call to b.SanitizeTTL), but we check here again for
|
||||
// sanity.
|
||||
if auth.TTL > sysView.MaxLeaseTTL() {
|
||||
auth.TTL = sysView.MaxLeaseTTL()
|
||||
}
|
||||
tokenTTL = auth.TTL
|
||||
}
|
||||
|
||||
// Generate a token
|
||||
|
@ -494,7 +507,7 @@ func (c *Core) handleLoginRequest(req *logical.Request) (retResp *logical.Respon
|
|||
Meta: auth.Metadata,
|
||||
DisplayName: auth.DisplayName,
|
||||
CreationTime: time.Now().Unix(),
|
||||
TTL: auth.TTL,
|
||||
TTL: tokenTTL,
|
||||
NumUses: auth.NumUses,
|
||||
EntityID: auth.EntityID,
|
||||
}
|
||||
|
@ -513,10 +526,11 @@ func (c *Core) handleLoginRequest(req *logical.Request) (retResp *logical.Respon
|
|||
return nil, auth, ErrInternalError
|
||||
}
|
||||
|
||||
// Populate the client token and accessor
|
||||
// Populate the client token, accessor, and TTL
|
||||
auth.ClientToken = te.ID
|
||||
auth.Accessor = te.Accessor
|
||||
auth.Policies = te.Policies
|
||||
auth.TTL = te.TTL
|
||||
|
||||
// Register with the expiration manager
|
||||
if err := c.expiration.RegisterAuth(te.Path, auth); err != nil {
|
||||
|
|
|
@ -379,6 +379,8 @@ func TestDynamicSystemView(c *Core) *dynamicSystemView {
|
|||
return &dynamicSystemView{c, me}
|
||||
}
|
||||
|
||||
// TestAddTestPlugin registers the testFunc as part of the plugin command to the
|
||||
// plugin catalog.
|
||||
func TestAddTestPlugin(t testing.T, c *Core, name, testFunc string) {
|
||||
file, err := os.Open(os.Args[0])
|
||||
if err != nil {
|
||||
|
@ -413,11 +415,74 @@ func TestAddTestPlugin(t testing.T, c *Core, name, testFunc string) {
|
|||
}
|
||||
}
|
||||
|
||||
// TestAddTestPluginTempDir registers the testFunc as part of the plugin command to the
|
||||
// plugin catalog. It uses tmpDir as the plugin directory.
|
||||
func TestAddTestPluginTempDir(t testing.T, c *Core, name, testFunc, tempDir string) {
|
||||
file, err := os.Open(os.Args[0])
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
fi, err := file.Stat()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Copy over the file to the temp dir
|
||||
dst := filepath.Join(tempDir, filepath.Base(os.Args[0]))
|
||||
out, err := os.OpenFile(dst, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, fi.Mode())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer out.Close()
|
||||
|
||||
if _, err = io.Copy(out, file); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = out.Sync()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Determine plugin directory full path
|
||||
fullPath, err := filepath.EvalSymlinks(tempDir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
reader, err := os.Open(filepath.Join(fullPath, filepath.Base(os.Args[0])))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer reader.Close()
|
||||
|
||||
// Find out the sha256
|
||||
hash := sha256.New()
|
||||
|
||||
_, err = io.Copy(hash, reader)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
sum := hash.Sum(nil)
|
||||
|
||||
// Set core's plugin directory and plugin catalog directory
|
||||
c.pluginDirectory = fullPath
|
||||
c.pluginCatalog.directory = fullPath
|
||||
|
||||
command := fmt.Sprintf("%s --test.run=%s", filepath.Base(os.Args[0]), testFunc)
|
||||
err = c.pluginCatalog.Set(name, command, sum)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
var testLogicalBackends = map[string]logical.Factory{}
|
||||
var testCredentialBackends = map[string]logical.Factory{}
|
||||
|
||||
// Starts the test server which responds to SSH authentication.
|
||||
// Used to test the SSH secret backend.
|
||||
// StartSSHHostTestServer starts the test server which responds to SSH
|
||||
// authentication. Used to test the SSH secret backend.
|
||||
func StartSSHHostTestServer() (string, error) {
|
||||
pubKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(testSharedPublicKey))
|
||||
if err != nil {
|
||||
|
@ -760,6 +825,7 @@ func (c *TestCluster) ensureCoresSealed() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// UnsealWithStoredKeys uses stored keys to unseal the test cluster cores
|
||||
func (c *TestCluster) UnsealWithStoredKeys(t testing.T) error {
|
||||
for _, core := range c.Cores {
|
||||
if err := core.UnsealWithStoredKeys(); err != nil {
|
||||
|
|
|
@ -2312,7 +2312,7 @@ func TestTokenStore_RolePeriod(t *testing.T) {
|
|||
req := logical.TestRequest(t, logical.UpdateOperation, "auth/token/roles/test")
|
||||
req.ClientToken = root
|
||||
req.Data = map[string]interface{}{
|
||||
"period": 300,
|
||||
"period": 5,
|
||||
}
|
||||
|
||||
resp, err := core.HandleRequest(req)
|
||||
|
@ -2425,8 +2425,8 @@ func TestTokenStore_RolePeriod(t *testing.T) {
|
|||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
ttl := resp.Data["ttl"].(int64)
|
||||
if ttl < 299 {
|
||||
t.Fatalf("TTL too small (expected %d, got %d", 299, ttl)
|
||||
if ttl > 5 {
|
||||
t.Fatalf("TTL too large (expected %d, got %d", 5, ttl)
|
||||
}
|
||||
|
||||
// Let the TTL go down a bit to 3 seconds
|
||||
|
@ -2449,8 +2449,8 @@ func TestTokenStore_RolePeriod(t *testing.T) {
|
|||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
ttl = resp.Data["ttl"].(int64)
|
||||
if ttl < 299 {
|
||||
t.Fatalf("TTL too small (expected %d, got %d", 299, ttl)
|
||||
if ttl > 5 {
|
||||
t.Fatalf("TTL too large (expected %d, got %d", 5, ttl)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -2677,7 +2677,7 @@ func TestTokenStore_Periodic(t *testing.T) {
|
|||
req := logical.TestRequest(t, logical.UpdateOperation, "auth/token/roles/test")
|
||||
req.ClientToken = root
|
||||
req.Data = map[string]interface{}{
|
||||
"period": 300,
|
||||
"period": 5,
|
||||
}
|
||||
|
||||
resp, err := core.HandleRequest(req)
|
||||
|
@ -2715,8 +2715,8 @@ func TestTokenStore_Periodic(t *testing.T) {
|
|||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
ttl := resp.Data["ttl"].(int64)
|
||||
if ttl < 299 {
|
||||
t.Fatalf("TTL too small (expected %d, got %d)", 299, ttl)
|
||||
if ttl > 5 {
|
||||
t.Fatalf("TTL too large (expected %d, got %d)", 5, ttl)
|
||||
}
|
||||
|
||||
// Let the TTL go down a bit
|
||||
|
@ -2739,8 +2739,8 @@ func TestTokenStore_Periodic(t *testing.T) {
|
|||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
ttl = resp.Data["ttl"].(int64)
|
||||
if ttl < 299 {
|
||||
t.Fatalf("TTL too small (expected %d, got %d)", 299, ttl)
|
||||
if ttl > 5 {
|
||||
t.Fatalf("TTL too large (expected %d, got %d)", 5, ttl)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -2750,8 +2750,8 @@ func TestTokenStore_Periodic(t *testing.T) {
|
|||
req.Operation = logical.UpdateOperation
|
||||
req.Path = "auth/token/create"
|
||||
req.Data = map[string]interface{}{
|
||||
"period": 300,
|
||||
"explicit_max_ttl": 150,
|
||||
"period": 5,
|
||||
"explicit_max_ttl": 4,
|
||||
}
|
||||
resp, err = core.HandleRequest(req)
|
||||
if err != nil {
|
||||
|
@ -2775,8 +2775,8 @@ func TestTokenStore_Periodic(t *testing.T) {
|
|||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
ttl := resp.Data["ttl"].(int64)
|
||||
if ttl < 149 || ttl > 150 {
|
||||
t.Fatalf("TTL bad (expected %d, got %d)", 149, ttl)
|
||||
if ttl < 3 || ttl > 4 {
|
||||
t.Fatalf("TTL bad (expected %d, got %d)", 3, ttl)
|
||||
}
|
||||
|
||||
// Let the TTL go down a bit
|
||||
|
@ -2799,8 +2799,8 @@ func TestTokenStore_Periodic(t *testing.T) {
|
|||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
ttl = resp.Data["ttl"].(int64)
|
||||
if ttl < 140 || ttl > 150 {
|
||||
t.Fatalf("TTL bad (expected around %d, got %d)", 145, ttl)
|
||||
if ttl > 2 {
|
||||
t.Fatalf("TTL bad (expected less than %d, got %d)", 2, ttl)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -2812,7 +2812,7 @@ func TestTokenStore_Periodic(t *testing.T) {
|
|||
req.Operation = logical.UpdateOperation
|
||||
req.Path = "auth/token/create/test"
|
||||
req.Data = map[string]interface{}{
|
||||
"period": 150,
|
||||
"period": 5,
|
||||
}
|
||||
resp, err = core.HandleRequest(req)
|
||||
if err != nil {
|
||||
|
@ -2836,8 +2836,8 @@ func TestTokenStore_Periodic(t *testing.T) {
|
|||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
ttl := resp.Data["ttl"].(int64)
|
||||
if ttl < 149 || ttl > 150 {
|
||||
t.Fatalf("TTL bad (expected %d, got %d)", 149, ttl)
|
||||
if ttl < 4 || ttl > 5 {
|
||||
t.Fatalf("TTL bad (expected %d, got %d)", 4, ttl)
|
||||
}
|
||||
|
||||
// Let the TTL go down a bit
|
||||
|
@ -2860,8 +2860,8 @@ func TestTokenStore_Periodic(t *testing.T) {
|
|||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
ttl = resp.Data["ttl"].(int64)
|
||||
if ttl < 149 {
|
||||
t.Fatalf("TTL bad (expected %d, got %d)", 149, ttl)
|
||||
if ttl > 5 {
|
||||
t.Fatalf("TTL bad (expected less than %d, got %d)", 5, ttl)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -2869,18 +2869,23 @@ func TestTokenStore_Periodic(t *testing.T) {
|
|||
{
|
||||
req.Path = "auth/token/roles/test"
|
||||
req.ClientToken = root
|
||||
req.Operation = logical.UpdateOperation
|
||||
req.Data = map[string]interface{}{
|
||||
"period": 300,
|
||||
"explicit_max_ttl": 150,
|
||||
"period": 5,
|
||||
"explicit_max_ttl": 4,
|
||||
}
|
||||
|
||||
resp, err := core.HandleRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v %v", err, resp)
|
||||
}
|
||||
if resp != nil {
|
||||
t.Fatalf("expected a nil response")
|
||||
}
|
||||
|
||||
req.ClientToken = root
|
||||
req.Operation = logical.UpdateOperation
|
||||
req.Path = "auth/token/create/test"
|
||||
req.Data = map[string]interface{}{
|
||||
"period": 150,
|
||||
"explicit_max_ttl": 130,
|
||||
}
|
||||
resp, err = core.HandleRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v %v", err, resp)
|
||||
|
@ -2903,12 +2908,12 @@ func TestTokenStore_Periodic(t *testing.T) {
|
|||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
ttl := resp.Data["ttl"].(int64)
|
||||
if ttl < 129 || ttl > 130 {
|
||||
t.Fatalf("TTL bad (expected %d, got %d)", 129, ttl)
|
||||
if ttl < 3 || ttl > 4 {
|
||||
t.Fatalf("TTL bad (expected %d, got %d)", 3, ttl)
|
||||
}
|
||||
|
||||
// Let the TTL go down a bit
|
||||
time.Sleep(4 * time.Second)
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
req.Operation = logical.UpdateOperation
|
||||
req.Path = "auth/token/renew-self"
|
||||
|
@ -2927,8 +2932,8 @@ func TestTokenStore_Periodic(t *testing.T) {
|
|||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
ttl = resp.Data["ttl"].(int64)
|
||||
if ttl > 127 {
|
||||
t.Fatalf("TTL bad (expected < %d, got %d)", 128, ttl)
|
||||
if ttl > 2 {
|
||||
t.Fatalf("TTL bad (expected less than %d, got %d)", 2, ttl)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue