Merge remote-tracking branch 'oss/master' into f-nomad
* oss/master: Defer reader.Close that is used to determine sha256 changelog++ Avoid unseal failure if plugin backends fail to setup during postUnseal (#3686) Add logic for using Auth.Period when handling auth login/renew requests (#3677) plugins/database: use context with plugins that use database/sql package (#3691) changelog++ Fix plaintext backup in transit (#3692) Database gRPC plugins (#3666)
This commit is contained in:
commit
db0006ef65
10
CHANGELOG.md
10
CHANGELOG.md
|
@ -33,6 +33,8 @@ FEATURES:
|
||||||
operation that can export a given key, including all key versions and
|
operation that can export a given key, including all key versions and
|
||||||
configuration, as well as a restore operation allowing import into another
|
configuration, as well as a restore operation allowing import into another
|
||||||
Vault.
|
Vault.
|
||||||
|
* **gRPC Database Plugins**: Database plugins now use gRPC for transport,
|
||||||
|
allowing them to be written in other languages.
|
||||||
|
|
||||||
IMPROVEMENTS:
|
IMPROVEMENTS:
|
||||||
|
|
||||||
|
@ -46,6 +48,10 @@ IMPROVEMENTS:
|
||||||
during database configuration. This establishes a session-wide [write
|
during database configuration. This establishes a session-wide [write
|
||||||
concern](https://docs.mongodb.com/manual/reference/write-concern/) for the
|
concern](https://docs.mongodb.com/manual/reference/write-concern/) for the
|
||||||
lifecycle of the mount [GH-3646]
|
lifecycle of the mount [GH-3646]
|
||||||
|
* mfa/okta: Filter a given email address as a login filter, allowing operation
|
||||||
|
when login email and account email are different
|
||||||
|
* plugins: Make Vault more resilient when unsealing when plugins are
|
||||||
|
unavailable [GH-3686]
|
||||||
* secret/pki: `allowed_domains` and `key_usage` can now be specified
|
* secret/pki: `allowed_domains` and `key_usage` can now be specified
|
||||||
as a comma-separated string or an array of strings [GH-3642]
|
as a comma-separated string or an array of strings [GH-3642]
|
||||||
* secret/ssh: Allow 4096-bit keys to be used in dynamic key method [GH-3593]
|
* secret/ssh: Allow 4096-bit keys to be used in dynamic key method [GH-3593]
|
||||||
|
@ -58,8 +64,12 @@ BUG FIXES:
|
||||||
* auth/cert: Return `allowed_names` on role read [GH-3654]
|
* auth/cert: Return `allowed_names` on role read [GH-3654]
|
||||||
* auth/ldap: Fix incorrect control information being sent [GH-3402] [GH-3496]
|
* auth/ldap: Fix incorrect control information being sent [GH-3402] [GH-3496]
|
||||||
[GH-3625] [GH-3656]
|
[GH-3625] [GH-3656]
|
||||||
|
* core: Fix seal status reporting when using an autoseal
|
||||||
|
* core: Add creation path to wrap info for a control group token
|
||||||
* core: Fix potential panic that could occur using plugins when a node
|
* core: Fix potential panic that could occur using plugins when a node
|
||||||
transitioned from active to standby [GH-3638]
|
transitioned from active to standby [GH-3638]
|
||||||
|
* core: Fix memory ballooning when a connection would connect to the cluster
|
||||||
|
port and then go away -- redux! [GH-3680]
|
||||||
* core: Replace recursive token revocation logic with depth-first logic, which
|
* core: Replace recursive token revocation logic with depth-first logic, which
|
||||||
can avoid hitting stack depth limits in extreme cases [GH-2348]
|
can avoid hitting stack depth limits in extreme cases [GH-2348]
|
||||||
* core/pkcs11 (enterprise): Fix panic when PKCS#11 library is not readable
|
* core/pkcs11 (enterprise): Fix panic when PKCS#11 library is not readable
|
||||||
|
|
1
Makefile
1
Makefile
|
@ -84,6 +84,7 @@ proto:
|
||||||
protoc -I helper/forwarding -I vault -I ../../.. helper/forwarding/types.proto --go_out=plugins=grpc:helper/forwarding
|
protoc -I helper/forwarding -I vault -I ../../.. helper/forwarding/types.proto --go_out=plugins=grpc:helper/forwarding
|
||||||
protoc -I physical physical/types.proto --go_out=plugins=grpc:physical
|
protoc -I physical physical/types.proto --go_out=plugins=grpc:physical
|
||||||
protoc -I helper/identity -I ../../.. helper/identity/types.proto --go_out=plugins=grpc:helper/identity
|
protoc -I helper/identity -I ../../.. helper/identity/types.proto --go_out=plugins=grpc:helper/identity
|
||||||
|
protoc builtin/logical/database/dbplugin/*.proto --go_out=plugins=grpc:.
|
||||||
sed -i -e 's/Idp/IDP/' -e 's/Url/URL/' -e 's/Id/ID/' -e 's/EntityId/EntityID/' -e 's/Api/API/' -e 's/Qr/QR/' -e 's/protobuf:"/sentinel:"" protobuf:"/' helper/identity/types.pb.go helper/storagepacker/types.pb.go
|
sed -i -e 's/Idp/IDP/' -e 's/Url/URL/' -e 's/Id/ID/' -e 's/EntityId/EntityID/' -e 's/Api/API/' -e 's/Qr/QR/' -e 's/protobuf:"/sentinel:"" protobuf:"/' helper/identity/types.pb.go helper/storagepacker/types.pb.go
|
||||||
sed -i -e 's/Iv/IV/' -e 's/Hmac/HMAC/' physical/types.pb.go
|
sed -i -e 's/Iv/IV/' -e 's/Hmac/HMAC/' physical/types.pb.go
|
||||||
|
|
||||||
|
|
|
@ -141,7 +141,7 @@ func testAccStepMapUserIdCidr(t *testing.T, cidr string) logicaltest.TestStep {
|
||||||
func testAccLogin(t *testing.T, display string) logicaltest.TestStep {
|
func testAccLogin(t *testing.T, display string) logicaltest.TestStep {
|
||||||
checkTTL := func(resp *logical.Response) error {
|
checkTTL := func(resp *logical.Response) error {
|
||||||
if resp.Auth.LeaseOptions.TTL.String() != "768h0m0s" {
|
if resp.Auth.LeaseOptions.TTL.String() != "768h0m0s" {
|
||||||
return fmt.Errorf("invalid TTL")
|
return fmt.Errorf("invalid TTL: got %s", resp.Auth.LeaseOptions.TTL)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -165,7 +165,7 @@ func testAccLogin(t *testing.T, display string) logicaltest.TestStep {
|
||||||
func testAccLoginAppIDInPath(t *testing.T, display string) logicaltest.TestStep {
|
func testAccLoginAppIDInPath(t *testing.T, display string) logicaltest.TestStep {
|
||||||
checkTTL := func(resp *logical.Response) error {
|
checkTTL := func(resp *logical.Response) error {
|
||||||
if resp.Auth.LeaseOptions.TTL.String() != "768h0m0s" {
|
if resp.Auth.LeaseOptions.TTL.String() != "768h0m0s" {
|
||||||
return fmt.Errorf("invalid TTL")
|
return fmt.Errorf("invalid TTL: got %s", resp.Auth.LeaseOptions.TTL)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,7 +3,6 @@ package approle
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/hashicorp/vault/logical"
|
"github.com/hashicorp/vault/logical"
|
||||||
"github.com/hashicorp/vault/logical/framework"
|
"github.com/hashicorp/vault/logical/framework"
|
||||||
|
@ -68,20 +67,13 @@ func (b *backend) pathLoginUpdate(req *logical.Request, data *framework.FieldDat
|
||||||
Policies: role.Policies,
|
Policies: role.Policies,
|
||||||
LeaseOptions: logical.LeaseOptions{
|
LeaseOptions: logical.LeaseOptions{
|
||||||
Renewable: true,
|
Renewable: true,
|
||||||
|
TTL: role.TokenTTL,
|
||||||
},
|
},
|
||||||
Alias: &logical.Alias{
|
Alias: &logical.Alias{
|
||||||
Name: role.RoleID,
|
Name: role.RoleID,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
// If 'Period' is set, use the value of 'Period' as the TTL.
|
|
||||||
// Otherwise, set the normal TokenTTL.
|
|
||||||
if role.Period > time.Duration(0) {
|
|
||||||
auth.TTL = role.Period
|
|
||||||
} else {
|
|
||||||
auth.TTL = role.TokenTTL
|
|
||||||
}
|
|
||||||
|
|
||||||
return &logical.Response{
|
return &logical.Response{
|
||||||
Auth: auth,
|
Auth: auth,
|
||||||
}, nil
|
}, nil
|
||||||
|
@ -107,16 +99,12 @@ func (b *backend) pathLoginRenew(req *logical.Request, data *framework.FieldData
|
||||||
return nil, fmt.Errorf("role %s does not exist during renewal", roleName)
|
return nil, fmt.Errorf("role %s does not exist during renewal", roleName)
|
||||||
}
|
}
|
||||||
|
|
||||||
// If 'Period' is set on the Role, the token should never expire.
|
resp, err := framework.LeaseExtend(role.TokenTTL, role.TokenMaxTTL, b.System())(req, data)
|
||||||
// Replenish the TTL with 'Period's value.
|
if err != nil {
|
||||||
if role.Period > time.Duration(0) {
|
return nil, err
|
||||||
// If 'Period' was updated after the token was issued,
|
|
||||||
// token will bear the updated 'Period' value as its TTL.
|
|
||||||
req.Auth.TTL = role.Period
|
|
||||||
return &logical.Response{Auth: req.Auth}, nil
|
|
||||||
} else {
|
|
||||||
return framework.LeaseExtend(role.TokenTTL, role.TokenMaxTTL, b.System())(req, data)
|
|
||||||
}
|
}
|
||||||
|
resp.Auth.Period = role.Period
|
||||||
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
const pathLoginHelpSys = "Issue a token based on the credentials supplied"
|
const pathLoginHelpSys = "Issue a token based on the credentials supplied"
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package database
|
package database
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/rpc"
|
"net/rpc"
|
||||||
"strings"
|
"strings"
|
||||||
|
@ -87,7 +88,7 @@ func (b *databaseBackend) getDBObj(name string) (dbplugin.Database, bool) {
|
||||||
// This function creates a new db object from the stored configuration and
|
// This function creates a new db object from the stored configuration and
|
||||||
// caches it in the connections map. The caller of this function needs to hold
|
// caches it in the connections map. The caller of this function needs to hold
|
||||||
// the backend's write lock
|
// the backend's write lock
|
||||||
func (b *databaseBackend) createDBObj(s logical.Storage, name string) (dbplugin.Database, error) {
|
func (b *databaseBackend) createDBObj(ctx context.Context, s logical.Storage, name string) (dbplugin.Database, error) {
|
||||||
db, ok := b.connections[name]
|
db, ok := b.connections[name]
|
||||||
if ok {
|
if ok {
|
||||||
return db, nil
|
return db, nil
|
||||||
|
@ -103,7 +104,7 @@ func (b *databaseBackend) createDBObj(s logical.Storage, name string) (dbplugin.
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = db.Initialize(config.ConnectionDetails, true)
|
err = db.Initialize(ctx, config.ConnectionDetails, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -170,7 +171,8 @@ func (b *databaseBackend) clearConnection(name string) {
|
||||||
|
|
||||||
func (b *databaseBackend) closeIfShutdown(name string, err error) {
|
func (b *databaseBackend) closeIfShutdown(name string, err error) {
|
||||||
// Plugin has shutdown, close it so next call can reconnect.
|
// Plugin has shutdown, close it so next call can reconnect.
|
||||||
if err == rpc.ErrShutdown {
|
switch err {
|
||||||
|
case rpc.ErrShutdown, dbplugin.ErrPluginShutdown:
|
||||||
b.Lock()
|
b.Lock()
|
||||||
b.clearConnection(name)
|
b.clearConnection(name)
|
||||||
b.Unlock()
|
b.Unlock()
|
||||||
|
|
|
@ -488,9 +488,11 @@ func TestBackend_roleCrud(t *testing.T) {
|
||||||
RevocationStatements: defaultRevocationSQL,
|
RevocationStatements: defaultRevocationSQL,
|
||||||
}
|
}
|
||||||
|
|
||||||
var actual dbplugin.Statements
|
actual := dbplugin.Statements{
|
||||||
if err := mapstructure.Decode(resp.Data, &actual); err != nil {
|
CreationStatements: resp.Data["creation_statements"].(string),
|
||||||
t.Fatal(err)
|
RevocationStatements: resp.Data["revocation_statements"].(string),
|
||||||
|
RollbackStatements: resp.Data["rollback_statements"].(string),
|
||||||
|
RenewStatements: resp.Data["renew_statements"].(string),
|
||||||
}
|
}
|
||||||
|
|
||||||
if !reflect.DeepEqual(expected, actual) {
|
if !reflect.DeepEqual(expected, actual) {
|
||||||
|
|
|
@ -1,10 +1,8 @@
|
||||||
package dbplugin
|
package dbplugin
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"errors"
|
||||||
"net/rpc"
|
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/hashicorp/go-plugin"
|
"github.com/hashicorp/go-plugin"
|
||||||
"github.com/hashicorp/vault/helper/pluginutil"
|
"github.com/hashicorp/vault/helper/pluginutil"
|
||||||
|
@ -17,11 +15,11 @@ type DatabasePluginClient struct {
|
||||||
client *plugin.Client
|
client *plugin.Client
|
||||||
sync.Mutex
|
sync.Mutex
|
||||||
|
|
||||||
*databasePluginRPCClient
|
Database
|
||||||
}
|
}
|
||||||
|
|
||||||
func (dc *DatabasePluginClient) Close() error {
|
func (dc *DatabasePluginClient) Close() error {
|
||||||
err := dc.databasePluginRPCClient.Close()
|
err := dc.Database.Close()
|
||||||
dc.client.Kill()
|
dc.client.Kill()
|
||||||
|
|
||||||
return err
|
return err
|
||||||
|
@ -55,79 +53,20 @@ func newPluginClient(sys pluginutil.RunnerUtil, pluginRunner *pluginutil.PluginR
|
||||||
|
|
||||||
// We should have a database type now. This feels like a normal interface
|
// We should have a database type now. This feels like a normal interface
|
||||||
// implementation but is in fact over an RPC connection.
|
// implementation but is in fact over an RPC connection.
|
||||||
databaseRPC := raw.(*databasePluginRPCClient)
|
var db Database
|
||||||
|
switch raw.(type) {
|
||||||
|
case *gRPCClient:
|
||||||
|
db = raw.(*gRPCClient)
|
||||||
|
case *databasePluginRPCClient:
|
||||||
|
logger.Warn("database: plugin is using deprecated net RPC transport, recompile plugin to upgrade to gRPC", "plugin", pluginRunner.Name)
|
||||||
|
db = raw.(*databasePluginRPCClient)
|
||||||
|
default:
|
||||||
|
return nil, errors.New("unsupported client type")
|
||||||
|
}
|
||||||
|
|
||||||
// Wrap RPC implimentation in DatabasePluginClient
|
// Wrap RPC implimentation in DatabasePluginClient
|
||||||
return &DatabasePluginClient{
|
return &DatabasePluginClient{
|
||||||
client: client,
|
client: client,
|
||||||
databasePluginRPCClient: databaseRPC,
|
Database: db,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ---- RPC client domain ----
|
|
||||||
|
|
||||||
// databasePluginRPCClient implements Database and is used on the client to
|
|
||||||
// make RPC calls to a plugin.
|
|
||||||
type databasePluginRPCClient struct {
|
|
||||||
client *rpc.Client
|
|
||||||
}
|
|
||||||
|
|
||||||
func (dr *databasePluginRPCClient) Type() (string, error) {
|
|
||||||
var dbType string
|
|
||||||
err := dr.client.Call("Plugin.Type", struct{}{}, &dbType)
|
|
||||||
|
|
||||||
return fmt.Sprintf("plugin-%s", dbType), err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (dr *databasePluginRPCClient) CreateUser(statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) {
|
|
||||||
req := CreateUserRequest{
|
|
||||||
Statements: statements,
|
|
||||||
UsernameConfig: usernameConfig,
|
|
||||||
Expiration: expiration,
|
|
||||||
}
|
|
||||||
|
|
||||||
var resp CreateUserResponse
|
|
||||||
err = dr.client.Call("Plugin.CreateUser", req, &resp)
|
|
||||||
|
|
||||||
return resp.Username, resp.Password, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (dr *databasePluginRPCClient) RenewUser(statements Statements, username string, expiration time.Time) error {
|
|
||||||
req := RenewUserRequest{
|
|
||||||
Statements: statements,
|
|
||||||
Username: username,
|
|
||||||
Expiration: expiration,
|
|
||||||
}
|
|
||||||
|
|
||||||
err := dr.client.Call("Plugin.RenewUser", req, &struct{}{})
|
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (dr *databasePluginRPCClient) RevokeUser(statements Statements, username string) error {
|
|
||||||
req := RevokeUserRequest{
|
|
||||||
Statements: statements,
|
|
||||||
Username: username,
|
|
||||||
}
|
|
||||||
|
|
||||||
err := dr.client.Call("Plugin.RevokeUser", req, &struct{}{})
|
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (dr *databasePluginRPCClient) Initialize(conf map[string]interface{}, verifyConnection bool) error {
|
|
||||||
req := InitializeRequest{
|
|
||||||
Config: conf,
|
|
||||||
VerifyConnection: verifyConnection,
|
|
||||||
}
|
|
||||||
|
|
||||||
err := dr.client.Call("Plugin.Initialize", req, &struct{}{})
|
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (dr *databasePluginRPCClient) Close() error {
|
|
||||||
err := dr.client.Call("Plugin.Close", struct{}{}, &struct{}{})
|
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
556
builtin/logical/database/dbplugin/database.pb.go
Normal file
556
builtin/logical/database/dbplugin/database.pb.go
Normal file
|
@ -0,0 +1,556 @@
|
||||||
|
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||||
|
// source: builtin/logical/database/dbplugin/database.proto
|
||||||
|
|
||||||
|
/*
|
||||||
|
Package dbplugin is a generated protocol buffer package.
|
||||||
|
|
||||||
|
It is generated from these files:
|
||||||
|
builtin/logical/database/dbplugin/database.proto
|
||||||
|
|
||||||
|
It has these top-level messages:
|
||||||
|
InitializeRequest
|
||||||
|
CreateUserRequest
|
||||||
|
RenewUserRequest
|
||||||
|
RevokeUserRequest
|
||||||
|
Statements
|
||||||
|
UsernameConfig
|
||||||
|
CreateUserResponse
|
||||||
|
TypeResponse
|
||||||
|
Empty
|
||||||
|
*/
|
||||||
|
package dbplugin
|
||||||
|
|
||||||
|
import proto "github.com/golang/protobuf/proto"
|
||||||
|
import fmt "fmt"
|
||||||
|
import math "math"
|
||||||
|
import google_protobuf "github.com/golang/protobuf/ptypes/timestamp"
|
||||||
|
|
||||||
|
import (
|
||||||
|
context "golang.org/x/net/context"
|
||||||
|
grpc "google.golang.org/grpc"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Reference imports to suppress errors if they are not otherwise used.
|
||||||
|
var _ = proto.Marshal
|
||||||
|
var _ = fmt.Errorf
|
||||||
|
var _ = math.Inf
|
||||||
|
|
||||||
|
// This is a compile-time assertion to ensure that this generated file
|
||||||
|
// is compatible with the proto package it is being compiled against.
|
||||||
|
// A compilation error at this line likely means your copy of the
|
||||||
|
// proto package needs to be updated.
|
||||||
|
const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package
|
||||||
|
|
||||||
|
type InitializeRequest struct {
|
||||||
|
Config []byte `protobuf:"bytes,1,opt,name=config,proto3" json:"config,omitempty"`
|
||||||
|
VerifyConnection bool `protobuf:"varint,2,opt,name=verify_connection,json=verifyConnection" json:"verify_connection,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *InitializeRequest) Reset() { *m = InitializeRequest{} }
|
||||||
|
func (m *InitializeRequest) String() string { return proto.CompactTextString(m) }
|
||||||
|
func (*InitializeRequest) ProtoMessage() {}
|
||||||
|
func (*InitializeRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{0} }
|
||||||
|
|
||||||
|
func (m *InitializeRequest) GetConfig() []byte {
|
||||||
|
if m != nil {
|
||||||
|
return m.Config
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *InitializeRequest) GetVerifyConnection() bool {
|
||||||
|
if m != nil {
|
||||||
|
return m.VerifyConnection
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
type CreateUserRequest struct {
|
||||||
|
Statements *Statements `protobuf:"bytes,1,opt,name=statements" json:"statements,omitempty"`
|
||||||
|
UsernameConfig *UsernameConfig `protobuf:"bytes,2,opt,name=username_config,json=usernameConfig" json:"username_config,omitempty"`
|
||||||
|
Expiration *google_protobuf.Timestamp `protobuf:"bytes,3,opt,name=expiration" json:"expiration,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *CreateUserRequest) Reset() { *m = CreateUserRequest{} }
|
||||||
|
func (m *CreateUserRequest) String() string { return proto.CompactTextString(m) }
|
||||||
|
func (*CreateUserRequest) ProtoMessage() {}
|
||||||
|
func (*CreateUserRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{1} }
|
||||||
|
|
||||||
|
func (m *CreateUserRequest) GetStatements() *Statements {
|
||||||
|
if m != nil {
|
||||||
|
return m.Statements
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *CreateUserRequest) GetUsernameConfig() *UsernameConfig {
|
||||||
|
if m != nil {
|
||||||
|
return m.UsernameConfig
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *CreateUserRequest) GetExpiration() *google_protobuf.Timestamp {
|
||||||
|
if m != nil {
|
||||||
|
return m.Expiration
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type RenewUserRequest struct {
|
||||||
|
Statements *Statements `protobuf:"bytes,1,opt,name=statements" json:"statements,omitempty"`
|
||||||
|
Username string `protobuf:"bytes,2,opt,name=username" json:"username,omitempty"`
|
||||||
|
Expiration *google_protobuf.Timestamp `protobuf:"bytes,3,opt,name=expiration" json:"expiration,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *RenewUserRequest) Reset() { *m = RenewUserRequest{} }
|
||||||
|
func (m *RenewUserRequest) String() string { return proto.CompactTextString(m) }
|
||||||
|
func (*RenewUserRequest) ProtoMessage() {}
|
||||||
|
func (*RenewUserRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{2} }
|
||||||
|
|
||||||
|
func (m *RenewUserRequest) GetStatements() *Statements {
|
||||||
|
if m != nil {
|
||||||
|
return m.Statements
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *RenewUserRequest) GetUsername() string {
|
||||||
|
if m != nil {
|
||||||
|
return m.Username
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *RenewUserRequest) GetExpiration() *google_protobuf.Timestamp {
|
||||||
|
if m != nil {
|
||||||
|
return m.Expiration
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type RevokeUserRequest struct {
|
||||||
|
Statements *Statements `protobuf:"bytes,1,opt,name=statements" json:"statements,omitempty"`
|
||||||
|
Username string `protobuf:"bytes,2,opt,name=username" json:"username,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *RevokeUserRequest) Reset() { *m = RevokeUserRequest{} }
|
||||||
|
func (m *RevokeUserRequest) String() string { return proto.CompactTextString(m) }
|
||||||
|
func (*RevokeUserRequest) ProtoMessage() {}
|
||||||
|
func (*RevokeUserRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{3} }
|
||||||
|
|
||||||
|
func (m *RevokeUserRequest) GetStatements() *Statements {
|
||||||
|
if m != nil {
|
||||||
|
return m.Statements
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *RevokeUserRequest) GetUsername() string {
|
||||||
|
if m != nil {
|
||||||
|
return m.Username
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
type Statements struct {
|
||||||
|
CreationStatements string `protobuf:"bytes,1,opt,name=creation_statements,json=creationStatements" json:"creation_statements,omitempty"`
|
||||||
|
RevocationStatements string `protobuf:"bytes,2,opt,name=revocation_statements,json=revocationStatements" json:"revocation_statements,omitempty"`
|
||||||
|
RollbackStatements string `protobuf:"bytes,3,opt,name=rollback_statements,json=rollbackStatements" json:"rollback_statements,omitempty"`
|
||||||
|
RenewStatements string `protobuf:"bytes,4,opt,name=renew_statements,json=renewStatements" json:"renew_statements,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Statements) Reset() { *m = Statements{} }
|
||||||
|
func (m *Statements) String() string { return proto.CompactTextString(m) }
|
||||||
|
func (*Statements) ProtoMessage() {}
|
||||||
|
func (*Statements) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{4} }
|
||||||
|
|
||||||
|
func (m *Statements) GetCreationStatements() string {
|
||||||
|
if m != nil {
|
||||||
|
return m.CreationStatements
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Statements) GetRevocationStatements() string {
|
||||||
|
if m != nil {
|
||||||
|
return m.RevocationStatements
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Statements) GetRollbackStatements() string {
|
||||||
|
if m != nil {
|
||||||
|
return m.RollbackStatements
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Statements) GetRenewStatements() string {
|
||||||
|
if m != nil {
|
||||||
|
return m.RenewStatements
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
type UsernameConfig struct {
|
||||||
|
DisplayName string `protobuf:"bytes,1,opt,name=DisplayName" json:"DisplayName,omitempty"`
|
||||||
|
RoleName string `protobuf:"bytes,2,opt,name=RoleName" json:"RoleName,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *UsernameConfig) Reset() { *m = UsernameConfig{} }
|
||||||
|
func (m *UsernameConfig) String() string { return proto.CompactTextString(m) }
|
||||||
|
func (*UsernameConfig) ProtoMessage() {}
|
||||||
|
func (*UsernameConfig) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{5} }
|
||||||
|
|
||||||
|
func (m *UsernameConfig) GetDisplayName() string {
|
||||||
|
if m != nil {
|
||||||
|
return m.DisplayName
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *UsernameConfig) GetRoleName() string {
|
||||||
|
if m != nil {
|
||||||
|
return m.RoleName
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
type CreateUserResponse struct {
|
||||||
|
Username string `protobuf:"bytes,1,opt,name=username" json:"username,omitempty"`
|
||||||
|
Password string `protobuf:"bytes,2,opt,name=password" json:"password,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *CreateUserResponse) Reset() { *m = CreateUserResponse{} }
|
||||||
|
func (m *CreateUserResponse) String() string { return proto.CompactTextString(m) }
|
||||||
|
func (*CreateUserResponse) ProtoMessage() {}
|
||||||
|
func (*CreateUserResponse) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{6} }
|
||||||
|
|
||||||
|
func (m *CreateUserResponse) GetUsername() string {
|
||||||
|
if m != nil {
|
||||||
|
return m.Username
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *CreateUserResponse) GetPassword() string {
|
||||||
|
if m != nil {
|
||||||
|
return m.Password
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
type TypeResponse struct {
|
||||||
|
Type string `protobuf:"bytes,1,opt,name=type" json:"type,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *TypeResponse) Reset() { *m = TypeResponse{} }
|
||||||
|
func (m *TypeResponse) String() string { return proto.CompactTextString(m) }
|
||||||
|
func (*TypeResponse) ProtoMessage() {}
|
||||||
|
func (*TypeResponse) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{7} }
|
||||||
|
|
||||||
|
func (m *TypeResponse) GetType() string {
|
||||||
|
if m != nil {
|
||||||
|
return m.Type
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
type Empty struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Empty) Reset() { *m = Empty{} }
|
||||||
|
func (m *Empty) String() string { return proto.CompactTextString(m) }
|
||||||
|
func (*Empty) ProtoMessage() {}
|
||||||
|
func (*Empty) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{8} }
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
proto.RegisterType((*InitializeRequest)(nil), "dbplugin.InitializeRequest")
|
||||||
|
proto.RegisterType((*CreateUserRequest)(nil), "dbplugin.CreateUserRequest")
|
||||||
|
proto.RegisterType((*RenewUserRequest)(nil), "dbplugin.RenewUserRequest")
|
||||||
|
proto.RegisterType((*RevokeUserRequest)(nil), "dbplugin.RevokeUserRequest")
|
||||||
|
proto.RegisterType((*Statements)(nil), "dbplugin.Statements")
|
||||||
|
proto.RegisterType((*UsernameConfig)(nil), "dbplugin.UsernameConfig")
|
||||||
|
proto.RegisterType((*CreateUserResponse)(nil), "dbplugin.CreateUserResponse")
|
||||||
|
proto.RegisterType((*TypeResponse)(nil), "dbplugin.TypeResponse")
|
||||||
|
proto.RegisterType((*Empty)(nil), "dbplugin.Empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reference imports to suppress errors if they are not otherwise used.
|
||||||
|
var _ context.Context
|
||||||
|
var _ grpc.ClientConn
|
||||||
|
|
||||||
|
// This is a compile-time assertion to ensure that this generated file
|
||||||
|
// is compatible with the grpc package it is being compiled against.
|
||||||
|
const _ = grpc.SupportPackageIsVersion4
|
||||||
|
|
||||||
|
// Client API for Database service
|
||||||
|
|
||||||
|
type DatabaseClient interface {
|
||||||
|
Type(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*TypeResponse, error)
|
||||||
|
CreateUser(ctx context.Context, in *CreateUserRequest, opts ...grpc.CallOption) (*CreateUserResponse, error)
|
||||||
|
RenewUser(ctx context.Context, in *RenewUserRequest, opts ...grpc.CallOption) (*Empty, error)
|
||||||
|
RevokeUser(ctx context.Context, in *RevokeUserRequest, opts ...grpc.CallOption) (*Empty, error)
|
||||||
|
Initialize(ctx context.Context, in *InitializeRequest, opts ...grpc.CallOption) (*Empty, error)
|
||||||
|
Close(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*Empty, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type databaseClient struct {
|
||||||
|
cc *grpc.ClientConn
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewDatabaseClient(cc *grpc.ClientConn) DatabaseClient {
|
||||||
|
return &databaseClient{cc}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *databaseClient) Type(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*TypeResponse, error) {
|
||||||
|
out := new(TypeResponse)
|
||||||
|
err := grpc.Invoke(ctx, "/dbplugin.Database/Type", in, out, c.cc, opts...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *databaseClient) CreateUser(ctx context.Context, in *CreateUserRequest, opts ...grpc.CallOption) (*CreateUserResponse, error) {
|
||||||
|
out := new(CreateUserResponse)
|
||||||
|
err := grpc.Invoke(ctx, "/dbplugin.Database/CreateUser", in, out, c.cc, opts...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *databaseClient) RenewUser(ctx context.Context, in *RenewUserRequest, opts ...grpc.CallOption) (*Empty, error) {
|
||||||
|
out := new(Empty)
|
||||||
|
err := grpc.Invoke(ctx, "/dbplugin.Database/RenewUser", in, out, c.cc, opts...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *databaseClient) RevokeUser(ctx context.Context, in *RevokeUserRequest, opts ...grpc.CallOption) (*Empty, error) {
|
||||||
|
out := new(Empty)
|
||||||
|
err := grpc.Invoke(ctx, "/dbplugin.Database/RevokeUser", in, out, c.cc, opts...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *databaseClient) Initialize(ctx context.Context, in *InitializeRequest, opts ...grpc.CallOption) (*Empty, error) {
|
||||||
|
out := new(Empty)
|
||||||
|
err := grpc.Invoke(ctx, "/dbplugin.Database/Initialize", in, out, c.cc, opts...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *databaseClient) Close(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*Empty, error) {
|
||||||
|
out := new(Empty)
|
||||||
|
err := grpc.Invoke(ctx, "/dbplugin.Database/Close", in, out, c.cc, opts...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Server API for Database service
|
||||||
|
|
||||||
|
type DatabaseServer interface {
|
||||||
|
Type(context.Context, *Empty) (*TypeResponse, error)
|
||||||
|
CreateUser(context.Context, *CreateUserRequest) (*CreateUserResponse, error)
|
||||||
|
RenewUser(context.Context, *RenewUserRequest) (*Empty, error)
|
||||||
|
RevokeUser(context.Context, *RevokeUserRequest) (*Empty, error)
|
||||||
|
Initialize(context.Context, *InitializeRequest) (*Empty, error)
|
||||||
|
Close(context.Context, *Empty) (*Empty, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func RegisterDatabaseServer(s *grpc.Server, srv DatabaseServer) {
|
||||||
|
s.RegisterService(&_Database_serviceDesc, srv)
|
||||||
|
}
|
||||||
|
|
||||||
|
func _Database_Type_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||||
|
in := new(Empty)
|
||||||
|
if err := dec(in); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if interceptor == nil {
|
||||||
|
return srv.(DatabaseServer).Type(ctx, in)
|
||||||
|
}
|
||||||
|
info := &grpc.UnaryServerInfo{
|
||||||
|
Server: srv,
|
||||||
|
FullMethod: "/dbplugin.Database/Type",
|
||||||
|
}
|
||||||
|
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||||
|
return srv.(DatabaseServer).Type(ctx, req.(*Empty))
|
||||||
|
}
|
||||||
|
return interceptor(ctx, in, info, handler)
|
||||||
|
}
|
||||||
|
|
||||||
|
func _Database_CreateUser_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||||
|
in := new(CreateUserRequest)
|
||||||
|
if err := dec(in); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if interceptor == nil {
|
||||||
|
return srv.(DatabaseServer).CreateUser(ctx, in)
|
||||||
|
}
|
||||||
|
info := &grpc.UnaryServerInfo{
|
||||||
|
Server: srv,
|
||||||
|
FullMethod: "/dbplugin.Database/CreateUser",
|
||||||
|
}
|
||||||
|
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||||
|
return srv.(DatabaseServer).CreateUser(ctx, req.(*CreateUserRequest))
|
||||||
|
}
|
||||||
|
return interceptor(ctx, in, info, handler)
|
||||||
|
}
|
||||||
|
|
||||||
|
func _Database_RenewUser_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||||
|
in := new(RenewUserRequest)
|
||||||
|
if err := dec(in); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if interceptor == nil {
|
||||||
|
return srv.(DatabaseServer).RenewUser(ctx, in)
|
||||||
|
}
|
||||||
|
info := &grpc.UnaryServerInfo{
|
||||||
|
Server: srv,
|
||||||
|
FullMethod: "/dbplugin.Database/RenewUser",
|
||||||
|
}
|
||||||
|
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||||
|
return srv.(DatabaseServer).RenewUser(ctx, req.(*RenewUserRequest))
|
||||||
|
}
|
||||||
|
return interceptor(ctx, in, info, handler)
|
||||||
|
}
|
||||||
|
|
||||||
|
func _Database_RevokeUser_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||||
|
in := new(RevokeUserRequest)
|
||||||
|
if err := dec(in); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if interceptor == nil {
|
||||||
|
return srv.(DatabaseServer).RevokeUser(ctx, in)
|
||||||
|
}
|
||||||
|
info := &grpc.UnaryServerInfo{
|
||||||
|
Server: srv,
|
||||||
|
FullMethod: "/dbplugin.Database/RevokeUser",
|
||||||
|
}
|
||||||
|
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||||
|
return srv.(DatabaseServer).RevokeUser(ctx, req.(*RevokeUserRequest))
|
||||||
|
}
|
||||||
|
return interceptor(ctx, in, info, handler)
|
||||||
|
}
|
||||||
|
|
||||||
|
func _Database_Initialize_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||||
|
in := new(InitializeRequest)
|
||||||
|
if err := dec(in); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if interceptor == nil {
|
||||||
|
return srv.(DatabaseServer).Initialize(ctx, in)
|
||||||
|
}
|
||||||
|
info := &grpc.UnaryServerInfo{
|
||||||
|
Server: srv,
|
||||||
|
FullMethod: "/dbplugin.Database/Initialize",
|
||||||
|
}
|
||||||
|
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||||
|
return srv.(DatabaseServer).Initialize(ctx, req.(*InitializeRequest))
|
||||||
|
}
|
||||||
|
return interceptor(ctx, in, info, handler)
|
||||||
|
}
|
||||||
|
|
||||||
|
func _Database_Close_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||||
|
in := new(Empty)
|
||||||
|
if err := dec(in); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if interceptor == nil {
|
||||||
|
return srv.(DatabaseServer).Close(ctx, in)
|
||||||
|
}
|
||||||
|
info := &grpc.UnaryServerInfo{
|
||||||
|
Server: srv,
|
||||||
|
FullMethod: "/dbplugin.Database/Close",
|
||||||
|
}
|
||||||
|
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||||
|
return srv.(DatabaseServer).Close(ctx, req.(*Empty))
|
||||||
|
}
|
||||||
|
return interceptor(ctx, in, info, handler)
|
||||||
|
}
|
||||||
|
|
||||||
|
var _Database_serviceDesc = grpc.ServiceDesc{
|
||||||
|
ServiceName: "dbplugin.Database",
|
||||||
|
HandlerType: (*DatabaseServer)(nil),
|
||||||
|
Methods: []grpc.MethodDesc{
|
||||||
|
{
|
||||||
|
MethodName: "Type",
|
||||||
|
Handler: _Database_Type_Handler,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
MethodName: "CreateUser",
|
||||||
|
Handler: _Database_CreateUser_Handler,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
MethodName: "RenewUser",
|
||||||
|
Handler: _Database_RenewUser_Handler,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
MethodName: "RevokeUser",
|
||||||
|
Handler: _Database_RevokeUser_Handler,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
MethodName: "Initialize",
|
||||||
|
Handler: _Database_Initialize_Handler,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
MethodName: "Close",
|
||||||
|
Handler: _Database_Close_Handler,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Streams: []grpc.StreamDesc{},
|
||||||
|
Metadata: "builtin/logical/database/dbplugin/database.proto",
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() { proto.RegisterFile("builtin/logical/database/dbplugin/database.proto", fileDescriptor0) }
|
||||||
|
|
||||||
|
var fileDescriptor0 = []byte{
|
||||||
|
// 548 bytes of a gzipped FileDescriptorProto
|
||||||
|
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xb4, 0x54, 0xcf, 0x6e, 0xd3, 0x4e,
|
||||||
|
0x10, 0x96, 0xdb, 0xb4, 0xbf, 0x64, 0x5a, 0x35, 0xc9, 0xfe, 0x4a, 0x15, 0x19, 0x24, 0x22, 0x9f,
|
||||||
|
0x5a, 0x21, 0xd9, 0xa8, 0xe5, 0x80, 0xb8, 0xa1, 0x14, 0x21, 0x24, 0x94, 0x83, 0x69, 0x25, 0x6e,
|
||||||
|
0xd1, 0xda, 0x99, 0x44, 0xab, 0x3a, 0xbb, 0xc6, 0xbb, 0x4e, 0x09, 0x4f, 0xc3, 0xe3, 0x70, 0xe2,
|
||||||
|
0x1d, 0x78, 0x13, 0xe4, 0x75, 0xd6, 0xbb, 0xf9, 0x73, 0xab, 0xb8, 0x79, 0x66, 0xbe, 0x6f, 0xf6,
|
||||||
|
0xf3, 0xb7, 0x33, 0x0b, 0xaf, 0x93, 0x92, 0x65, 0x8a, 0xf1, 0x28, 0x13, 0x73, 0x96, 0xd2, 0x2c,
|
||||||
|
0x9a, 0x52, 0x45, 0x13, 0x2a, 0x31, 0x9a, 0x26, 0x79, 0x56, 0xce, 0x19, 0x6f, 0x32, 0x61, 0x5e,
|
||||||
|
0x08, 0x25, 0x48, 0xdb, 0x14, 0xfc, 0x97, 0x73, 0x21, 0xe6, 0x19, 0x46, 0x3a, 0x9f, 0x94, 0xb3,
|
||||||
|
0x48, 0xb1, 0x05, 0x4a, 0x45, 0x17, 0x79, 0x0d, 0x0d, 0xbe, 0x42, 0xff, 0x13, 0x67, 0x8a, 0xd1,
|
||||||
|
0x8c, 0xfd, 0xc0, 0x18, 0xbf, 0x95, 0x28, 0x15, 0xb9, 0x80, 0xe3, 0x54, 0xf0, 0x19, 0x9b, 0x0f,
|
||||||
|
0xbc, 0xa1, 0x77, 0x79, 0x1a, 0xaf, 0x23, 0xf2, 0x0a, 0xfa, 0x4b, 0x2c, 0xd8, 0x6c, 0x35, 0x49,
|
||||||
|
0x05, 0xe7, 0x98, 0x2a, 0x26, 0xf8, 0xe0, 0x60, 0xe8, 0x5d, 0xb6, 0xe3, 0x5e, 0x5d, 0x18, 0x35,
|
||||||
|
0xf9, 0xe0, 0x97, 0x07, 0xfd, 0x51, 0x81, 0x54, 0xe1, 0xbd, 0xc4, 0xc2, 0xb4, 0x7e, 0x03, 0x20,
|
||||||
|
0x15, 0x55, 0xb8, 0x40, 0xae, 0xa4, 0x6e, 0x7f, 0x72, 0x7d, 0x1e, 0x1a, 0xbd, 0xe1, 0x97, 0xa6,
|
||||||
|
0x16, 0x3b, 0x38, 0xf2, 0x1e, 0xba, 0xa5, 0xc4, 0x82, 0xd3, 0x05, 0x4e, 0xd6, 0xca, 0x0e, 0x34,
|
||||||
|
0x75, 0x60, 0xa9, 0xf7, 0x6b, 0xc0, 0x48, 0xd7, 0xe3, 0xb3, 0x72, 0x23, 0x26, 0xef, 0x00, 0xf0,
|
||||||
|
0x7b, 0xce, 0x0a, 0xaa, 0x45, 0x1f, 0x6a, 0xb6, 0x1f, 0xd6, 0xf6, 0x84, 0xc6, 0x9e, 0xf0, 0xce,
|
||||||
|
0xd8, 0x13, 0x3b, 0xe8, 0xe0, 0xa7, 0x07, 0xbd, 0x18, 0x39, 0x3e, 0x3e, 0xfd, 0x4f, 0x7c, 0x68,
|
||||||
|
0x1b, 0x61, 0xfa, 0x17, 0x3a, 0x71, 0x13, 0x3f, 0x49, 0x22, 0x42, 0x3f, 0xc6, 0xa5, 0x78, 0xc0,
|
||||||
|
0x7f, 0x2a, 0x31, 0xf8, 0xed, 0x01, 0x58, 0x1a, 0x89, 0xe0, 0xff, 0xb4, 0xba, 0x62, 0x26, 0xf8,
|
||||||
|
0x64, 0xeb, 0xa4, 0x4e, 0x4c, 0x4c, 0xc9, 0x21, 0xdc, 0xc0, 0xb3, 0x02, 0x97, 0x22, 0xdd, 0xa1,
|
||||||
|
0xd4, 0x07, 0x9d, 0xdb, 0xe2, 0xe6, 0x29, 0x85, 0xc8, 0xb2, 0x84, 0xa6, 0x0f, 0x2e, 0xe5, 0xb0,
|
||||||
|
0x3e, 0xc5, 0x94, 0x1c, 0xc2, 0x15, 0xf4, 0x8a, 0xea, 0xba, 0x5c, 0x74, 0x4b, 0xa3, 0xbb, 0x3a,
|
||||||
|
0x6f, 0xa1, 0xc1, 0x18, 0xce, 0x36, 0x07, 0x87, 0x0c, 0xe1, 0xe4, 0x96, 0xc9, 0x3c, 0xa3, 0xab,
|
||||||
|
0x71, 0xe5, 0x40, 0xfd, 0x2f, 0x6e, 0xaa, 0x32, 0x28, 0x16, 0x19, 0x8e, 0x1d, 0x83, 0x4c, 0x1c,
|
||||||
|
0x7c, 0x06, 0xe2, 0x0e, 0xbd, 0xcc, 0x05, 0x97, 0xb8, 0x61, 0xa9, 0xb7, 0x75, 0xeb, 0x3e, 0xb4,
|
||||||
|
0x73, 0x2a, 0xe5, 0xa3, 0x28, 0xa6, 0xa6, 0x9b, 0x89, 0x83, 0x00, 0x4e, 0xef, 0x56, 0x39, 0x36,
|
||||||
|
0x7d, 0x08, 0xb4, 0xd4, 0x2a, 0x37, 0x3d, 0xf4, 0x77, 0xf0, 0x1f, 0x1c, 0x7d, 0x58, 0xe4, 0x6a,
|
||||||
|
0x75, 0xfd, 0xe7, 0x00, 0xda, 0xb7, 0xeb, 0x87, 0x80, 0x44, 0xd0, 0xaa, 0x98, 0xa4, 0x6b, 0xaf,
|
||||||
|
0x5b, 0xa3, 0xfc, 0x0b, 0x9b, 0xd8, 0x68, 0xfd, 0x11, 0xc0, 0x0a, 0x27, 0xcf, 0x2d, 0x6a, 0x67,
|
||||||
|
0x87, 0xfd, 0x17, 0xfb, 0x8b, 0xeb, 0x46, 0x6f, 0xa1, 0xd3, 0xec, 0x0a, 0xf1, 0x2d, 0x74, 0x7b,
|
||||||
|
0x81, 0xfc, 0x6d, 0x69, 0xd5, 0xfc, 0xdb, 0x19, 0x76, 0x25, 0xec, 0x4c, 0xf6, 0x5e, 0xae, 0x7d,
|
||||||
|
0xc7, 0x5c, 0xee, 0xce, 0xeb, 0xb6, 0xcb, 0xbd, 0x82, 0xa3, 0x51, 0x26, 0xe4, 0x1e, 0xb3, 0xb6,
|
||||||
|
0x13, 0xc9, 0xb1, 0x5e, 0xc3, 0x9b, 0xbf, 0x01, 0x00, 0x00, 0xff, 0xff, 0x8c, 0x55, 0x84, 0x56,
|
||||||
|
0x94, 0x05, 0x00, 0x00,
|
||||||
|
}
|
58
builtin/logical/database/dbplugin/database.proto
Normal file
58
builtin/logical/database/dbplugin/database.proto
Normal file
|
@ -0,0 +1,58 @@
|
||||||
|
syntax = "proto3";
|
||||||
|
package dbplugin;
|
||||||
|
|
||||||
|
import "google/protobuf/timestamp.proto";
|
||||||
|
|
||||||
|
message InitializeRequest {
|
||||||
|
bytes config = 1;
|
||||||
|
bool verify_connection = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
message CreateUserRequest {
|
||||||
|
Statements statements = 1;
|
||||||
|
UsernameConfig username_config = 2;
|
||||||
|
google.protobuf.Timestamp expiration = 3;
|
||||||
|
}
|
||||||
|
|
||||||
|
message RenewUserRequest {
|
||||||
|
Statements statements = 1;
|
||||||
|
string username = 2;
|
||||||
|
google.protobuf.Timestamp expiration = 3;
|
||||||
|
}
|
||||||
|
|
||||||
|
message RevokeUserRequest {
|
||||||
|
Statements statements = 1;
|
||||||
|
string username = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
message Statements {
|
||||||
|
string creation_statements = 1;
|
||||||
|
string revocation_statements = 2;
|
||||||
|
string rollback_statements = 3;
|
||||||
|
string renew_statements = 4;
|
||||||
|
}
|
||||||
|
|
||||||
|
message UsernameConfig {
|
||||||
|
string DisplayName = 1;
|
||||||
|
string RoleName = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
message CreateUserResponse {
|
||||||
|
string username = 1;
|
||||||
|
string password = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
message TypeResponse {
|
||||||
|
string type = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
message Empty {}
|
||||||
|
|
||||||
|
service Database {
|
||||||
|
rpc Type(Empty) returns (TypeResponse);
|
||||||
|
rpc CreateUser(CreateUserRequest) returns (CreateUserResponse);
|
||||||
|
rpc RenewUser(RenewUserRequest) returns (Empty);
|
||||||
|
rpc RevokeUser(RevokeUserRequest) returns (Empty);
|
||||||
|
rpc Initialize(InitializeRequest) returns (Empty);
|
||||||
|
rpc Close(Empty) returns (Empty);
|
||||||
|
}
|
|
@ -1,6 +1,7 @@
|
||||||
package dbplugin
|
package dbplugin
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
metrics "github.com/armon/go-metrics"
|
metrics "github.com/armon/go-metrics"
|
||||||
|
@ -15,55 +16,56 @@ type databaseTracingMiddleware struct {
|
||||||
next Database
|
next Database
|
||||||
logger log.Logger
|
logger log.Logger
|
||||||
|
|
||||||
typeStr string
|
typeStr string
|
||||||
|
transport string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mw *databaseTracingMiddleware) Type() (string, error) {
|
func (mw *databaseTracingMiddleware) Type() (string, error) {
|
||||||
return mw.next.Type()
|
return mw.next.Type()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mw *databaseTracingMiddleware) CreateUser(statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) {
|
func (mw *databaseTracingMiddleware) CreateUser(ctx context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) {
|
||||||
defer func(then time.Time) {
|
defer func(then time.Time) {
|
||||||
mw.logger.Trace("database", "operation", "CreateUser", "status", "finished", "type", mw.typeStr, "err", err, "took", time.Since(then))
|
mw.logger.Trace("database", "operation", "CreateUser", "status", "finished", "type", mw.typeStr, "transport", mw.transport, "err", err, "took", time.Since(then))
|
||||||
}(time.Now())
|
}(time.Now())
|
||||||
|
|
||||||
mw.logger.Trace("database", "operation", "CreateUser", "status", "started", "type", mw.typeStr)
|
mw.logger.Trace("database", "operation", "CreateUser", "status", "started", "type", mw.typeStr, "transport", mw.transport)
|
||||||
return mw.next.CreateUser(statements, usernameConfig, expiration)
|
return mw.next.CreateUser(ctx, statements, usernameConfig, expiration)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mw *databaseTracingMiddleware) RenewUser(statements Statements, username string, expiration time.Time) (err error) {
|
func (mw *databaseTracingMiddleware) RenewUser(ctx context.Context, statements Statements, username string, expiration time.Time) (err error) {
|
||||||
defer func(then time.Time) {
|
defer func(then time.Time) {
|
||||||
mw.logger.Trace("database", "operation", "RenewUser", "status", "finished", "type", mw.typeStr, "err", err, "took", time.Since(then))
|
mw.logger.Trace("database", "operation", "RenewUser", "status", "finished", "type", mw.typeStr, "transport", mw.transport, "err", err, "took", time.Since(then))
|
||||||
}(time.Now())
|
}(time.Now())
|
||||||
|
|
||||||
mw.logger.Trace("database", "operation", "RenewUser", "status", "started", mw.typeStr)
|
mw.logger.Trace("database", "operation", "RenewUser", "status", "started", mw.typeStr, "transport", mw.transport)
|
||||||
return mw.next.RenewUser(statements, username, expiration)
|
return mw.next.RenewUser(ctx, statements, username, expiration)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mw *databaseTracingMiddleware) RevokeUser(statements Statements, username string) (err error) {
|
func (mw *databaseTracingMiddleware) RevokeUser(ctx context.Context, statements Statements, username string) (err error) {
|
||||||
defer func(then time.Time) {
|
defer func(then time.Time) {
|
||||||
mw.logger.Trace("database", "operation", "RevokeUser", "status", "finished", "type", mw.typeStr, "err", err, "took", time.Since(then))
|
mw.logger.Trace("database", "operation", "RevokeUser", "status", "finished", "type", mw.typeStr, "transport", mw.transport, "err", err, "took", time.Since(then))
|
||||||
}(time.Now())
|
}(time.Now())
|
||||||
|
|
||||||
mw.logger.Trace("database", "operation", "RevokeUser", "status", "started", "type", mw.typeStr)
|
mw.logger.Trace("database", "operation", "RevokeUser", "status", "started", "type", mw.typeStr, "transport", mw.transport)
|
||||||
return mw.next.RevokeUser(statements, username)
|
return mw.next.RevokeUser(ctx, statements, username)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mw *databaseTracingMiddleware) Initialize(conf map[string]interface{}, verifyConnection bool) (err error) {
|
func (mw *databaseTracingMiddleware) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (err error) {
|
||||||
defer func(then time.Time) {
|
defer func(then time.Time) {
|
||||||
mw.logger.Trace("database", "operation", "Initialize", "status", "finished", "type", mw.typeStr, "verify", verifyConnection, "err", err, "took", time.Since(then))
|
mw.logger.Trace("database", "operation", "Initialize", "status", "finished", "type", mw.typeStr, "transport", mw.transport, "verify", verifyConnection, "err", err, "took", time.Since(then))
|
||||||
}(time.Now())
|
}(time.Now())
|
||||||
|
|
||||||
mw.logger.Trace("database", "operation", "Initialize", "status", "started", "type", mw.typeStr)
|
mw.logger.Trace("database", "operation", "Initialize", "status", "started", "type", mw.typeStr, "transport", mw.transport)
|
||||||
return mw.next.Initialize(conf, verifyConnection)
|
return mw.next.Initialize(ctx, conf, verifyConnection)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mw *databaseTracingMiddleware) Close() (err error) {
|
func (mw *databaseTracingMiddleware) Close() (err error) {
|
||||||
defer func(then time.Time) {
|
defer func(then time.Time) {
|
||||||
mw.logger.Trace("database", "operation", "Close", "status", "finished", "type", mw.typeStr, "err", err, "took", time.Since(then))
|
mw.logger.Trace("database", "operation", "Close", "status", "finished", "type", mw.typeStr, "transport", mw.transport, "err", err, "took", time.Since(then))
|
||||||
}(time.Now())
|
}(time.Now())
|
||||||
|
|
||||||
mw.logger.Trace("database", "operation", "Close", "status", "started", "type", mw.typeStr)
|
mw.logger.Trace("database", "operation", "Close", "status", "started", "type", mw.typeStr, "transport", mw.transport)
|
||||||
return mw.next.Close()
|
return mw.next.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -81,7 +83,7 @@ func (mw *databaseMetricsMiddleware) Type() (string, error) {
|
||||||
return mw.next.Type()
|
return mw.next.Type()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mw *databaseMetricsMiddleware) CreateUser(statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) {
|
func (mw *databaseMetricsMiddleware) CreateUser(ctx context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) {
|
||||||
defer func(now time.Time) {
|
defer func(now time.Time) {
|
||||||
metrics.MeasureSince([]string{"database", "CreateUser"}, now)
|
metrics.MeasureSince([]string{"database", "CreateUser"}, now)
|
||||||
metrics.MeasureSince([]string{"database", mw.typeStr, "CreateUser"}, now)
|
metrics.MeasureSince([]string{"database", mw.typeStr, "CreateUser"}, now)
|
||||||
|
@ -94,10 +96,10 @@ func (mw *databaseMetricsMiddleware) CreateUser(statements Statements, usernameC
|
||||||
|
|
||||||
metrics.IncrCounter([]string{"database", "CreateUser"}, 1)
|
metrics.IncrCounter([]string{"database", "CreateUser"}, 1)
|
||||||
metrics.IncrCounter([]string{"database", mw.typeStr, "CreateUser"}, 1)
|
metrics.IncrCounter([]string{"database", mw.typeStr, "CreateUser"}, 1)
|
||||||
return mw.next.CreateUser(statements, usernameConfig, expiration)
|
return mw.next.CreateUser(ctx, statements, usernameConfig, expiration)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mw *databaseMetricsMiddleware) RenewUser(statements Statements, username string, expiration time.Time) (err error) {
|
func (mw *databaseMetricsMiddleware) RenewUser(ctx context.Context, statements Statements, username string, expiration time.Time) (err error) {
|
||||||
defer func(now time.Time) {
|
defer func(now time.Time) {
|
||||||
metrics.MeasureSince([]string{"database", "RenewUser"}, now)
|
metrics.MeasureSince([]string{"database", "RenewUser"}, now)
|
||||||
metrics.MeasureSince([]string{"database", mw.typeStr, "RenewUser"}, now)
|
metrics.MeasureSince([]string{"database", mw.typeStr, "RenewUser"}, now)
|
||||||
|
@ -110,10 +112,10 @@ func (mw *databaseMetricsMiddleware) RenewUser(statements Statements, username s
|
||||||
|
|
||||||
metrics.IncrCounter([]string{"database", "RenewUser"}, 1)
|
metrics.IncrCounter([]string{"database", "RenewUser"}, 1)
|
||||||
metrics.IncrCounter([]string{"database", mw.typeStr, "RenewUser"}, 1)
|
metrics.IncrCounter([]string{"database", mw.typeStr, "RenewUser"}, 1)
|
||||||
return mw.next.RenewUser(statements, username, expiration)
|
return mw.next.RenewUser(ctx, statements, username, expiration)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mw *databaseMetricsMiddleware) RevokeUser(statements Statements, username string) (err error) {
|
func (mw *databaseMetricsMiddleware) RevokeUser(ctx context.Context, statements Statements, username string) (err error) {
|
||||||
defer func(now time.Time) {
|
defer func(now time.Time) {
|
||||||
metrics.MeasureSince([]string{"database", "RevokeUser"}, now)
|
metrics.MeasureSince([]string{"database", "RevokeUser"}, now)
|
||||||
metrics.MeasureSince([]string{"database", mw.typeStr, "RevokeUser"}, now)
|
metrics.MeasureSince([]string{"database", mw.typeStr, "RevokeUser"}, now)
|
||||||
|
@ -126,10 +128,10 @@ func (mw *databaseMetricsMiddleware) RevokeUser(statements Statements, username
|
||||||
|
|
||||||
metrics.IncrCounter([]string{"database", "RevokeUser"}, 1)
|
metrics.IncrCounter([]string{"database", "RevokeUser"}, 1)
|
||||||
metrics.IncrCounter([]string{"database", mw.typeStr, "RevokeUser"}, 1)
|
metrics.IncrCounter([]string{"database", mw.typeStr, "RevokeUser"}, 1)
|
||||||
return mw.next.RevokeUser(statements, username)
|
return mw.next.RevokeUser(ctx, statements, username)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mw *databaseMetricsMiddleware) Initialize(conf map[string]interface{}, verifyConnection bool) (err error) {
|
func (mw *databaseMetricsMiddleware) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (err error) {
|
||||||
defer func(now time.Time) {
|
defer func(now time.Time) {
|
||||||
metrics.MeasureSince([]string{"database", "Initialize"}, now)
|
metrics.MeasureSince([]string{"database", "Initialize"}, now)
|
||||||
metrics.MeasureSince([]string{"database", mw.typeStr, "Initialize"}, now)
|
metrics.MeasureSince([]string{"database", mw.typeStr, "Initialize"}, now)
|
||||||
|
@ -142,7 +144,7 @@ func (mw *databaseMetricsMiddleware) Initialize(conf map[string]interface{}, ver
|
||||||
|
|
||||||
metrics.IncrCounter([]string{"database", "Initialize"}, 1)
|
metrics.IncrCounter([]string{"database", "Initialize"}, 1)
|
||||||
metrics.IncrCounter([]string{"database", mw.typeStr, "Initialize"}, 1)
|
metrics.IncrCounter([]string{"database", mw.typeStr, "Initialize"}, 1)
|
||||||
return mw.next.Initialize(conf, verifyConnection)
|
return mw.next.Initialize(ctx, conf, verifyConnection)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mw *databaseMetricsMiddleware) Close() (err error) {
|
func (mw *databaseMetricsMiddleware) Close() (err error) {
|
||||||
|
|
198
builtin/logical/database/dbplugin/grpc_transport.go
Normal file
198
builtin/logical/database/dbplugin/grpc_transport.go
Normal file
|
@ -0,0 +1,198 @@
|
||||||
|
package dbplugin
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
"google.golang.org/grpc/connectivity"
|
||||||
|
|
||||||
|
"github.com/golang/protobuf/ptypes"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrPluginShutdown = errors.New("plugin shutdown")
|
||||||
|
)
|
||||||
|
|
||||||
|
// ---- gRPC Server domain ----
|
||||||
|
|
||||||
|
type gRPCServer struct {
|
||||||
|
impl Database
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *gRPCServer) Type(context.Context, *Empty) (*TypeResponse, error) {
|
||||||
|
t, err := s.impl.Type()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &TypeResponse{
|
||||||
|
Type: t,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *gRPCServer) CreateUser(ctx context.Context, req *CreateUserRequest) (*CreateUserResponse, error) {
|
||||||
|
e, err := ptypes.Timestamp(req.Expiration)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
u, p, err := s.impl.CreateUser(ctx, *req.Statements, *req.UsernameConfig, e)
|
||||||
|
|
||||||
|
return &CreateUserResponse{
|
||||||
|
Username: u,
|
||||||
|
Password: p,
|
||||||
|
}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *gRPCServer) RenewUser(ctx context.Context, req *RenewUserRequest) (*Empty, error) {
|
||||||
|
e, err := ptypes.Timestamp(req.Expiration)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
err = s.impl.RenewUser(ctx, *req.Statements, req.Username, e)
|
||||||
|
return &Empty{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *gRPCServer) RevokeUser(ctx context.Context, req *RevokeUserRequest) (*Empty, error) {
|
||||||
|
err := s.impl.RevokeUser(ctx, *req.Statements, req.Username)
|
||||||
|
return &Empty{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *gRPCServer) Initialize(ctx context.Context, req *InitializeRequest) (*Empty, error) {
|
||||||
|
config := map[string]interface{}{}
|
||||||
|
|
||||||
|
err := json.Unmarshal(req.Config, &config)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = s.impl.Initialize(ctx, config, req.VerifyConnection)
|
||||||
|
return &Empty{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *gRPCServer) Close(_ context.Context, _ *Empty) (*Empty, error) {
|
||||||
|
s.impl.Close()
|
||||||
|
return &Empty{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---- gRPC client domain ----
|
||||||
|
|
||||||
|
type gRPCClient struct {
|
||||||
|
client DatabaseClient
|
||||||
|
clientConn *grpc.ClientConn
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c gRPCClient) Type() (string, error) {
|
||||||
|
// If the plugin has already shutdown, this will hang forever so we give it
|
||||||
|
// a one second timeout.
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
switch c.clientConn.GetState() {
|
||||||
|
case connectivity.Ready, connectivity.Idle:
|
||||||
|
default:
|
||||||
|
return "", ErrPluginShutdown
|
||||||
|
}
|
||||||
|
resp, err := c.client.Type(ctx, &Empty{})
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return resp.Type, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c gRPCClient) CreateUser(ctx context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) {
|
||||||
|
t, err := ptypes.TimestampProto(expiration)
|
||||||
|
if err != nil {
|
||||||
|
return "", "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
switch c.clientConn.GetState() {
|
||||||
|
case connectivity.Ready, connectivity.Idle:
|
||||||
|
default:
|
||||||
|
return "", "", ErrPluginShutdown
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := c.client.CreateUser(ctx, &CreateUserRequest{
|
||||||
|
Statements: &statements,
|
||||||
|
UsernameConfig: &usernameConfig,
|
||||||
|
Expiration: t,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return "", "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return resp.Username, resp.Password, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *gRPCClient) RenewUser(ctx context.Context, statements Statements, username string, expiration time.Time) error {
|
||||||
|
t, err := ptypes.TimestampProto(expiration)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
switch c.clientConn.GetState() {
|
||||||
|
case connectivity.Ready, connectivity.Idle:
|
||||||
|
default:
|
||||||
|
return ErrPluginShutdown
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = c.client.RenewUser(ctx, &RenewUserRequest{
|
||||||
|
Statements: &statements,
|
||||||
|
Username: username,
|
||||||
|
Expiration: t,
|
||||||
|
})
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *gRPCClient) RevokeUser(ctx context.Context, statements Statements, username string) error {
|
||||||
|
switch c.clientConn.GetState() {
|
||||||
|
case connectivity.Ready, connectivity.Idle:
|
||||||
|
default:
|
||||||
|
return ErrPluginShutdown
|
||||||
|
}
|
||||||
|
_, err := c.client.RevokeUser(ctx, &RevokeUserRequest{
|
||||||
|
Statements: &statements,
|
||||||
|
Username: username,
|
||||||
|
})
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *gRPCClient) Initialize(ctx context.Context, config map[string]interface{}, verifyConnection bool) error {
|
||||||
|
configRaw, err := json.Marshal(config)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
switch c.clientConn.GetState() {
|
||||||
|
case connectivity.Ready, connectivity.Idle:
|
||||||
|
default:
|
||||||
|
return ErrPluginShutdown
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = c.client.Initialize(ctx, &InitializeRequest{
|
||||||
|
Config: configRaw,
|
||||||
|
VerifyConnection: verifyConnection,
|
||||||
|
})
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *gRPCClient) Close() error {
|
||||||
|
// If the plugin has already shutdown, this will hang forever so we give it
|
||||||
|
// a one second timeout.
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||||
|
defer cancel()
|
||||||
|
switch c.clientConn.GetState() {
|
||||||
|
case connectivity.Ready, connectivity.Idle:
|
||||||
|
_, err := c.client.Close(ctx, &Empty{})
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
139
builtin/logical/database/dbplugin/netrpc_transport.go
Normal file
139
builtin/logical/database/dbplugin/netrpc_transport.go
Normal file
|
@ -0,0 +1,139 @@
|
||||||
|
package dbplugin
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net/rpc"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ---- RPC server domain ----
|
||||||
|
|
||||||
|
// databasePluginRPCServer implements an RPC version of Database and is run
|
||||||
|
// inside a plugin. It wraps an underlying implementation of Database.
|
||||||
|
type databasePluginRPCServer struct {
|
||||||
|
impl Database
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ds *databasePluginRPCServer) Type(_ struct{}, resp *string) error {
|
||||||
|
var err error
|
||||||
|
*resp, err = ds.impl.Type()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ds *databasePluginRPCServer) CreateUser(args *CreateUserRequestRPC, resp *CreateUserResponse) error {
|
||||||
|
var err error
|
||||||
|
resp.Username, resp.Password, err = ds.impl.CreateUser(context.Background(), args.Statements, args.UsernameConfig, args.Expiration)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ds *databasePluginRPCServer) RenewUser(args *RenewUserRequestRPC, _ *struct{}) error {
|
||||||
|
err := ds.impl.RenewUser(context.Background(), args.Statements, args.Username, args.Expiration)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ds *databasePluginRPCServer) RevokeUser(args *RevokeUserRequestRPC, _ *struct{}) error {
|
||||||
|
err := ds.impl.RevokeUser(context.Background(), args.Statements, args.Username)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ds *databasePluginRPCServer) Initialize(args *InitializeRequestRPC, _ *struct{}) error {
|
||||||
|
err := ds.impl.Initialize(context.Background(), args.Config, args.VerifyConnection)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ds *databasePluginRPCServer) Close(_ struct{}, _ *struct{}) error {
|
||||||
|
ds.impl.Close()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---- RPC client domain ----
|
||||||
|
// databasePluginRPCClient implements Database and is used on the client to
|
||||||
|
// make RPC calls to a plugin.
|
||||||
|
type databasePluginRPCClient struct {
|
||||||
|
client *rpc.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dr *databasePluginRPCClient) Type() (string, error) {
|
||||||
|
var dbType string
|
||||||
|
err := dr.client.Call("Plugin.Type", struct{}{}, &dbType)
|
||||||
|
|
||||||
|
return fmt.Sprintf("plugin-%s", dbType), err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dr *databasePluginRPCClient) CreateUser(_ context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) {
|
||||||
|
req := CreateUserRequestRPC{
|
||||||
|
Statements: statements,
|
||||||
|
UsernameConfig: usernameConfig,
|
||||||
|
Expiration: expiration,
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp CreateUserResponse
|
||||||
|
err = dr.client.Call("Plugin.CreateUser", req, &resp)
|
||||||
|
|
||||||
|
return resp.Username, resp.Password, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dr *databasePluginRPCClient) RenewUser(_ context.Context, statements Statements, username string, expiration time.Time) error {
|
||||||
|
req := RenewUserRequestRPC{
|
||||||
|
Statements: statements,
|
||||||
|
Username: username,
|
||||||
|
Expiration: expiration,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := dr.client.Call("Plugin.RenewUser", req, &struct{}{})
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dr *databasePluginRPCClient) RevokeUser(_ context.Context, statements Statements, username string) error {
|
||||||
|
req := RevokeUserRequestRPC{
|
||||||
|
Statements: statements,
|
||||||
|
Username: username,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := dr.client.Call("Plugin.RevokeUser", req, &struct{}{})
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dr *databasePluginRPCClient) Initialize(_ context.Context, conf map[string]interface{}, verifyConnection bool) error {
|
||||||
|
req := InitializeRequestRPC{
|
||||||
|
Config: conf,
|
||||||
|
VerifyConnection: verifyConnection,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := dr.client.Call("Plugin.Initialize", req, &struct{}{})
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dr *databasePluginRPCClient) Close() error {
|
||||||
|
err := dr.client.Call("Plugin.Close", struct{}{}, &struct{}{})
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---- RPC Request Args Domain ----
|
||||||
|
|
||||||
|
type InitializeRequestRPC struct {
|
||||||
|
Config map[string]interface{}
|
||||||
|
VerifyConnection bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type CreateUserRequestRPC struct {
|
||||||
|
Statements Statements
|
||||||
|
UsernameConfig UsernameConfig
|
||||||
|
Expiration time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
type RenewUserRequestRPC struct {
|
||||||
|
Statements Statements
|
||||||
|
Username string
|
||||||
|
Expiration time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
type RevokeUserRequestRPC struct {
|
||||||
|
Statements Statements
|
||||||
|
Username string
|
||||||
|
}
|
|
@ -1,10 +1,13 @@
|
||||||
package dbplugin
|
package dbplugin
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/rpc"
|
"net/rpc"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
|
||||||
"github.com/hashicorp/go-plugin"
|
"github.com/hashicorp/go-plugin"
|
||||||
"github.com/hashicorp/vault/helper/pluginutil"
|
"github.com/hashicorp/vault/helper/pluginutil"
|
||||||
log "github.com/mgutz/logxi/v1"
|
log "github.com/mgutz/logxi/v1"
|
||||||
|
@ -13,29 +16,14 @@ import (
|
||||||
// Database is the interface that all database objects must implement.
|
// Database is the interface that all database objects must implement.
|
||||||
type Database interface {
|
type Database interface {
|
||||||
Type() (string, error)
|
Type() (string, error)
|
||||||
CreateUser(statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error)
|
CreateUser(ctx context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error)
|
||||||
RenewUser(statements Statements, username string, expiration time.Time) error
|
RenewUser(ctx context.Context, statements Statements, username string, expiration time.Time) error
|
||||||
RevokeUser(statements Statements, username string) error
|
RevokeUser(ctx context.Context, statements Statements, username string) error
|
||||||
|
|
||||||
Initialize(config map[string]interface{}, verifyConnection bool) error
|
Initialize(ctx context.Context, config map[string]interface{}, verifyConnection bool) error
|
||||||
Close() error
|
Close() error
|
||||||
}
|
}
|
||||||
|
|
||||||
// Statements set in role creation and passed into the database type's functions.
|
|
||||||
type Statements struct {
|
|
||||||
CreationStatements string `json:"creation_statments" mapstructure:"creation_statements" structs:"creation_statments"`
|
|
||||||
RevocationStatements string `json:"revocation_statements" mapstructure:"revocation_statements" structs:"revocation_statements"`
|
|
||||||
RollbackStatements string `json:"rollback_statements" mapstructure:"rollback_statements" structs:"rollback_statements"`
|
|
||||||
RenewStatements string `json:"renew_statements" mapstructure:"renew_statements" structs:"renew_statements"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// UsernameConfig is used to configure prefixes for the username to be
|
|
||||||
// generated.
|
|
||||||
type UsernameConfig struct {
|
|
||||||
DisplayName string
|
|
||||||
RoleName string
|
|
||||||
}
|
|
||||||
|
|
||||||
// PluginFactory is used to build plugin database types. It wraps the database
|
// PluginFactory is used to build plugin database types. It wraps the database
|
||||||
// object in a logging and metrics middleware.
|
// object in a logging and metrics middleware.
|
||||||
func PluginFactory(pluginName string, sys pluginutil.LookRunnerUtil, logger log.Logger) (Database, error) {
|
func PluginFactory(pluginName string, sys pluginutil.LookRunnerUtil, logger log.Logger) (Database, error) {
|
||||||
|
@ -45,6 +33,7 @@ func PluginFactory(pluginName string, sys pluginutil.LookRunnerUtil, logger log.
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var transport string
|
||||||
var db Database
|
var db Database
|
||||||
if pluginRunner.Builtin {
|
if pluginRunner.Builtin {
|
||||||
// Plugin is builtin so we can retrieve an instance of the interface
|
// Plugin is builtin so we can retrieve an instance of the interface
|
||||||
|
@ -60,12 +49,24 @@ func PluginFactory(pluginName string, sys pluginutil.LookRunnerUtil, logger log.
|
||||||
return nil, fmt.Errorf("unsuported database type: %s", pluginName)
|
return nil, fmt.Errorf("unsuported database type: %s", pluginName)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
transport = "builtin"
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
// create a DatabasePluginClient instance
|
// create a DatabasePluginClient instance
|
||||||
db, err = newPluginClient(sys, pluginRunner, logger)
|
db, err = newPluginClient(sys, pluginRunner, logger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Switch on the underlying database client type to get the transport
|
||||||
|
// method.
|
||||||
|
switch db.(*DatabasePluginClient).Database.(type) {
|
||||||
|
case *gRPCClient:
|
||||||
|
transport = "gRPC"
|
||||||
|
case *databasePluginRPCClient:
|
||||||
|
transport = "netRPC"
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
typeStr, err := db.Type()
|
typeStr, err := db.Type()
|
||||||
|
@ -82,9 +83,10 @@ func PluginFactory(pluginName string, sys pluginutil.LookRunnerUtil, logger log.
|
||||||
// Wrap with tracing middleware
|
// Wrap with tracing middleware
|
||||||
if logger.IsTrace() {
|
if logger.IsTrace() {
|
||||||
db = &databaseTracingMiddleware{
|
db = &databaseTracingMiddleware{
|
||||||
next: db,
|
transport: transport,
|
||||||
typeStr: typeStr,
|
next: db,
|
||||||
logger: logger,
|
typeStr: typeStr,
|
||||||
|
logger: logger,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -115,33 +117,14 @@ func (DatabasePlugin) Client(b *plugin.MuxBroker, c *rpc.Client) (interface{}, e
|
||||||
return &databasePluginRPCClient{client: c}, nil
|
return &databasePluginRPCClient{client: c}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ---- RPC Request Args Domain ----
|
func (d DatabasePlugin) GRPCServer(s *grpc.Server) error {
|
||||||
|
RegisterDatabaseServer(s, &gRPCServer{impl: d.impl})
|
||||||
type InitializeRequest struct {
|
return nil
|
||||||
Config map[string]interface{}
|
|
||||||
VerifyConnection bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type CreateUserRequest struct {
|
func (DatabasePlugin) GRPCClient(c *grpc.ClientConn) (interface{}, error) {
|
||||||
Statements Statements
|
return &gRPCClient{
|
||||||
UsernameConfig UsernameConfig
|
client: NewDatabaseClient(c),
|
||||||
Expiration time.Time
|
clientConn: c,
|
||||||
}
|
}, nil
|
||||||
|
|
||||||
type RenewUserRequest struct {
|
|
||||||
Statements Statements
|
|
||||||
Username string
|
|
||||||
Expiration time.Time
|
|
||||||
}
|
|
||||||
|
|
||||||
type RevokeUserRequest struct {
|
|
||||||
Statements Statements
|
|
||||||
Username string
|
|
||||||
}
|
|
||||||
|
|
||||||
// ---- RPC Response Args Domain ----
|
|
||||||
|
|
||||||
type CreateUserResponse struct {
|
|
||||||
Username string
|
|
||||||
Password string
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,11 +1,13 @@
|
||||||
package dbplugin_test
|
package dbplugin_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"os"
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
plugin "github.com/hashicorp/go-plugin"
|
||||||
"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
|
"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
|
||||||
"github.com/hashicorp/vault/helper/pluginutil"
|
"github.com/hashicorp/vault/helper/pluginutil"
|
||||||
vaulthttp "github.com/hashicorp/vault/http"
|
vaulthttp "github.com/hashicorp/vault/http"
|
||||||
|
@ -20,7 +22,7 @@ type mockPlugin struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockPlugin) Type() (string, error) { return "mock", nil }
|
func (m *mockPlugin) Type() (string, error) { return "mock", nil }
|
||||||
func (m *mockPlugin) CreateUser(statements dbplugin.Statements, usernameConf dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
|
func (m *mockPlugin) CreateUser(_ context.Context, statements dbplugin.Statements, usernameConf dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
|
||||||
err = errors.New("err")
|
err = errors.New("err")
|
||||||
if usernameConf.DisplayName == "" || expiration.IsZero() {
|
if usernameConf.DisplayName == "" || expiration.IsZero() {
|
||||||
return "", "", err
|
return "", "", err
|
||||||
|
@ -34,7 +36,7 @@ func (m *mockPlugin) CreateUser(statements dbplugin.Statements, usernameConf dbp
|
||||||
|
|
||||||
return usernameConf.DisplayName, "test", nil
|
return usernameConf.DisplayName, "test", nil
|
||||||
}
|
}
|
||||||
func (m *mockPlugin) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error {
|
func (m *mockPlugin) RenewUser(_ context.Context, statements dbplugin.Statements, username string, expiration time.Time) error {
|
||||||
err := errors.New("err")
|
err := errors.New("err")
|
||||||
if username == "" || expiration.IsZero() {
|
if username == "" || expiration.IsZero() {
|
||||||
return err
|
return err
|
||||||
|
@ -46,7 +48,7 @@ func (m *mockPlugin) RenewUser(statements dbplugin.Statements, username string,
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
func (m *mockPlugin) RevokeUser(statements dbplugin.Statements, username string) error {
|
func (m *mockPlugin) RevokeUser(_ context.Context, statements dbplugin.Statements, username string) error {
|
||||||
err := errors.New("err")
|
err := errors.New("err")
|
||||||
if username == "" {
|
if username == "" {
|
||||||
return err
|
return err
|
||||||
|
@ -59,7 +61,7 @@ func (m *mockPlugin) RevokeUser(statements dbplugin.Statements, username string)
|
||||||
delete(m.users, username)
|
delete(m.users, username)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
func (m *mockPlugin) Initialize(conf map[string]interface{}, _ bool) error {
|
func (m *mockPlugin) Initialize(_ context.Context, conf map[string]interface{}, _ bool) error {
|
||||||
err := errors.New("err")
|
err := errors.New("err")
|
||||||
if len(conf) != 1 {
|
if len(conf) != 1 {
|
||||||
return err
|
return err
|
||||||
|
@ -80,14 +82,15 @@ func getCluster(t *testing.T) (*vault.TestCluster, logical.SystemView) {
|
||||||
cores := cluster.Cores
|
cores := cluster.Cores
|
||||||
|
|
||||||
sys := vault.TestDynamicSystemView(cores[0].Core)
|
sys := vault.TestDynamicSystemView(cores[0].Core)
|
||||||
vault.TestAddTestPlugin(t, cores[0].Core, "test-plugin", "TestPlugin_Main")
|
vault.TestAddTestPlugin(t, cores[0].Core, "test-plugin", "TestPlugin_GRPC_Main")
|
||||||
|
vault.TestAddTestPlugin(t, cores[0].Core, "test-plugin-netRPC", "TestPlugin_NetRPC_Main")
|
||||||
|
|
||||||
return cluster, sys
|
return cluster, sys
|
||||||
}
|
}
|
||||||
|
|
||||||
// This is not an actual test case, it's a helper function that will be executed
|
// This is not an actual test case, it's a helper function that will be executed
|
||||||
// by the go-plugin client via an exec call.
|
// by the go-plugin client via an exec call.
|
||||||
func TestPlugin_Main(t *testing.T) {
|
func TestPlugin_GRPC_Main(t *testing.T) {
|
||||||
if os.Getenv(pluginutil.PluginUnwrapTokenEnv) == "" {
|
if os.Getenv(pluginutil.PluginUnwrapTokenEnv) == "" {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -105,6 +108,30 @@ func TestPlugin_Main(t *testing.T) {
|
||||||
plugins.Serve(plugin, apiClientMeta.GetTLSConfig())
|
plugins.Serve(plugin, apiClientMeta.GetTLSConfig())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// This is not an actual test case, it's a helper function that will be executed
|
||||||
|
// by the go-plugin client via an exec call.
|
||||||
|
func TestPlugin_NetRPC_Main(t *testing.T) {
|
||||||
|
if os.Getenv(pluginutil.PluginUnwrapTokenEnv) == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
p := &mockPlugin{
|
||||||
|
users: make(map[string][]string),
|
||||||
|
}
|
||||||
|
|
||||||
|
args := []string{"--tls-skip-verify=true"}
|
||||||
|
|
||||||
|
apiClientMeta := &pluginutil.APIClientMeta{}
|
||||||
|
flags := apiClientMeta.FlagSet()
|
||||||
|
flags.Parse(args)
|
||||||
|
|
||||||
|
tlsProvider := pluginutil.VaultPluginTLSProvider(apiClientMeta.GetTLSConfig())
|
||||||
|
serveConf := dbplugin.ServeConfig(p, tlsProvider)
|
||||||
|
serveConf.GRPCServer = nil
|
||||||
|
|
||||||
|
plugin.Serve(serveConf)
|
||||||
|
}
|
||||||
|
|
||||||
func TestPlugin_Initialize(t *testing.T) {
|
func TestPlugin_Initialize(t *testing.T) {
|
||||||
cluster, sys := getCluster(t)
|
cluster, sys := getCluster(t)
|
||||||
defer cluster.Cleanup()
|
defer cluster.Cleanup()
|
||||||
|
@ -118,7 +145,7 @@ func TestPlugin_Initialize(t *testing.T) {
|
||||||
"test": 1,
|
"test": 1,
|
||||||
}
|
}
|
||||||
|
|
||||||
err = dbRaw.Initialize(connectionDetails, true)
|
err = dbRaw.Initialize(context.Background(), connectionDetails, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -143,7 +170,7 @@ func TestPlugin_CreateUser(t *testing.T) {
|
||||||
"test": 1,
|
"test": 1,
|
||||||
}
|
}
|
||||||
|
|
||||||
err = db.Initialize(connectionDetails, true)
|
err = db.Initialize(context.Background(), connectionDetails, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -153,7 +180,7 @@ func TestPlugin_CreateUser(t *testing.T) {
|
||||||
RoleName: "test",
|
RoleName: "test",
|
||||||
}
|
}
|
||||||
|
|
||||||
us, pw, err := db.CreateUser(dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
|
us, pw, err := db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -163,7 +190,7 @@ func TestPlugin_CreateUser(t *testing.T) {
|
||||||
|
|
||||||
// try and save the same user again to verify it saved the first time, this
|
// try and save the same user again to verify it saved the first time, this
|
||||||
// should return an error
|
// should return an error
|
||||||
_, _, err = db.CreateUser(dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
|
_, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("expected an error, user wasn't created correctly")
|
t.Fatal("expected an error, user wasn't created correctly")
|
||||||
}
|
}
|
||||||
|
@ -182,7 +209,7 @@ func TestPlugin_RenewUser(t *testing.T) {
|
||||||
connectionDetails := map[string]interface{}{
|
connectionDetails := map[string]interface{}{
|
||||||
"test": 1,
|
"test": 1,
|
||||||
}
|
}
|
||||||
err = db.Initialize(connectionDetails, true)
|
err = db.Initialize(context.Background(), connectionDetails, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -192,12 +219,12 @@ func TestPlugin_RenewUser(t *testing.T) {
|
||||||
RoleName: "test",
|
RoleName: "test",
|
||||||
}
|
}
|
||||||
|
|
||||||
us, _, err := db.CreateUser(dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
|
us, _, err := db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = db.RenewUser(dbplugin.Statements{}, us, time.Now().Add(time.Minute))
|
err = db.RenewUser(context.Background(), dbplugin.Statements{}, us, time.Now().Add(time.Minute))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -216,7 +243,7 @@ func TestPlugin_RevokeUser(t *testing.T) {
|
||||||
connectionDetails := map[string]interface{}{
|
connectionDetails := map[string]interface{}{
|
||||||
"test": 1,
|
"test": 1,
|
||||||
}
|
}
|
||||||
err = db.Initialize(connectionDetails, true)
|
err = db.Initialize(context.Background(), connectionDetails, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -226,19 +253,159 @@ func TestPlugin_RevokeUser(t *testing.T) {
|
||||||
RoleName: "test",
|
RoleName: "test",
|
||||||
}
|
}
|
||||||
|
|
||||||
us, _, err := db.CreateUser(dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
|
us, _, err := db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test default revoke statememts
|
// Test default revoke statememts
|
||||||
err = db.RevokeUser(dbplugin.Statements{}, us)
|
err = db.RevokeUser(context.Background(), dbplugin.Statements{}, us)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Try adding the same username back so we can verify it was removed
|
// Try adding the same username back so we can verify it was removed
|
||||||
_, _, err = db.CreateUser(dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
|
_, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("err: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test the code is still compatible with an old netRPC plugin
|
||||||
|
func TestPlugin_NetRPC_Initialize(t *testing.T) {
|
||||||
|
cluster, sys := getCluster(t)
|
||||||
|
defer cluster.Cleanup()
|
||||||
|
|
||||||
|
dbRaw, err := dbplugin.PluginFactory("test-plugin-netRPC", sys, &log.NullLogger{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("err: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
connectionDetails := map[string]interface{}{
|
||||||
|
"test": 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
err = dbRaw.Initialize(context.Background(), connectionDetails, true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("err: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = dbRaw.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("err: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPlugin_NetRPC_CreateUser(t *testing.T) {
|
||||||
|
cluster, sys := getCluster(t)
|
||||||
|
defer cluster.Cleanup()
|
||||||
|
|
||||||
|
db, err := dbplugin.PluginFactory("test-plugin-netRPC", sys, &log.NullLogger{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("err: %s", err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
connectionDetails := map[string]interface{}{
|
||||||
|
"test": 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
err = db.Initialize(context.Background(), connectionDetails, true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("err: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
usernameConf := dbplugin.UsernameConfig{
|
||||||
|
DisplayName: "test",
|
||||||
|
RoleName: "test",
|
||||||
|
}
|
||||||
|
|
||||||
|
us, pw, err := db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("err: %s", err)
|
||||||
|
}
|
||||||
|
if us != "test" || pw != "test" {
|
||||||
|
t.Fatal("expected username and password to be 'test'")
|
||||||
|
}
|
||||||
|
|
||||||
|
// try and save the same user again to verify it saved the first time, this
|
||||||
|
// should return an error
|
||||||
|
_, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected an error, user wasn't created correctly")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPlugin_NetRPC_RenewUser(t *testing.T) {
|
||||||
|
cluster, sys := getCluster(t)
|
||||||
|
defer cluster.Cleanup()
|
||||||
|
|
||||||
|
db, err := dbplugin.PluginFactory("test-plugin-netRPC", sys, &log.NullLogger{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("err: %s", err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
connectionDetails := map[string]interface{}{
|
||||||
|
"test": 1,
|
||||||
|
}
|
||||||
|
err = db.Initialize(context.Background(), connectionDetails, true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("err: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
usernameConf := dbplugin.UsernameConfig{
|
||||||
|
DisplayName: "test",
|
||||||
|
RoleName: "test",
|
||||||
|
}
|
||||||
|
|
||||||
|
us, _, err := db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("err: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = db.RenewUser(context.Background(), dbplugin.Statements{}, us, time.Now().Add(time.Minute))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("err: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPlugin_NetRPC_RevokeUser(t *testing.T) {
|
||||||
|
cluster, sys := getCluster(t)
|
||||||
|
defer cluster.Cleanup()
|
||||||
|
|
||||||
|
db, err := dbplugin.PluginFactory("test-plugin-netRPC", sys, &log.NullLogger{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("err: %s", err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
connectionDetails := map[string]interface{}{
|
||||||
|
"test": 1,
|
||||||
|
}
|
||||||
|
err = db.Initialize(context.Background(), connectionDetails, true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("err: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
usernameConf := dbplugin.UsernameConfig{
|
||||||
|
DisplayName: "test",
|
||||||
|
RoleName: "test",
|
||||||
|
}
|
||||||
|
|
||||||
|
us, _, err := db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("err: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test default revoke statememts
|
||||||
|
err = db.RevokeUser(context.Background(), dbplugin.Statements{}, us)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("err: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try adding the same username back so we can verify it was removed
|
||||||
|
_, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -10,6 +10,10 @@ import (
|
||||||
// Database implementation in a databasePluginRPCServer object and starts a
|
// Database implementation in a databasePluginRPCServer object and starts a
|
||||||
// RPC server.
|
// RPC server.
|
||||||
func Serve(db Database, tlsProvider func() (*tls.Config, error)) {
|
func Serve(db Database, tlsProvider func() (*tls.Config, error)) {
|
||||||
|
plugin.Serve(ServeConfig(db, tlsProvider))
|
||||||
|
}
|
||||||
|
|
||||||
|
func ServeConfig(db Database, tlsProvider func() (*tls.Config, error)) *plugin.ServeConfig {
|
||||||
dbPlugin := &DatabasePlugin{
|
dbPlugin := &DatabasePlugin{
|
||||||
impl: db,
|
impl: db,
|
||||||
}
|
}
|
||||||
|
@ -19,53 +23,10 @@ func Serve(db Database, tlsProvider func() (*tls.Config, error)) {
|
||||||
"database": dbPlugin,
|
"database": dbPlugin,
|
||||||
}
|
}
|
||||||
|
|
||||||
plugin.Serve(&plugin.ServeConfig{
|
return &plugin.ServeConfig{
|
||||||
HandshakeConfig: handshakeConfig,
|
HandshakeConfig: handshakeConfig,
|
||||||
Plugins: pluginMap,
|
Plugins: pluginMap,
|
||||||
TLSProvider: tlsProvider,
|
TLSProvider: tlsProvider,
|
||||||
})
|
GRPCServer: plugin.DefaultGRPCServer,
|
||||||
}
|
}
|
||||||
|
|
||||||
// ---- RPC server domain ----
|
|
||||||
|
|
||||||
// databasePluginRPCServer implements an RPC version of Database and is run
|
|
||||||
// inside a plugin. It wraps an underlying implementation of Database.
|
|
||||||
type databasePluginRPCServer struct {
|
|
||||||
impl Database
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ds *databasePluginRPCServer) Type(_ struct{}, resp *string) error {
|
|
||||||
var err error
|
|
||||||
*resp, err = ds.impl.Type()
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ds *databasePluginRPCServer) CreateUser(args *CreateUserRequest, resp *CreateUserResponse) error {
|
|
||||||
var err error
|
|
||||||
resp.Username, resp.Password, err = ds.impl.CreateUser(args.Statements, args.UsernameConfig, args.Expiration)
|
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ds *databasePluginRPCServer) RenewUser(args *RenewUserRequest, _ *struct{}) error {
|
|
||||||
err := ds.impl.RenewUser(args.Statements, args.Username, args.Expiration)
|
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ds *databasePluginRPCServer) RevokeUser(args *RevokeUserRequest, _ *struct{}) error {
|
|
||||||
err := ds.impl.RevokeUser(args.Statements, args.Username)
|
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ds *databasePluginRPCServer) Initialize(args *InitializeRequest, _ *struct{}) error {
|
|
||||||
err := ds.impl.Initialize(args.Config, args.VerifyConnection)
|
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ds *databasePluginRPCServer) Close(_ struct{}, _ *struct{}) error {
|
|
||||||
ds.impl.Close()
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package database
|
package database
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
|
@ -62,7 +63,7 @@ func (b *databaseBackend) pathConnectionReset() framework.OperationFunc {
|
||||||
b.clearConnection(name)
|
b.clearConnection(name)
|
||||||
|
|
||||||
// Execute plugin again, we don't need the object so throw away.
|
// Execute plugin again, we don't need the object so throw away.
|
||||||
_, err := b.createDBObj(req.Storage, name)
|
_, err := b.createDBObj(context.TODO(), req.Storage, name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -230,7 +231,7 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc {
|
||||||
return logical.ErrorResponse(fmt.Sprintf("error creating database object: %s", err)), nil
|
return logical.ErrorResponse(fmt.Sprintf("error creating database object: %s", err)), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
err = db.Initialize(config.ConnectionDetails, verifyConnection)
|
err = db.Initialize(context.TODO(), config.ConnectionDetails, verifyConnection)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
db.Close()
|
db.Close()
|
||||||
return logical.ErrorResponse(fmt.Sprintf("error creating database object: %s", err)), nil
|
return logical.ErrorResponse(fmt.Sprintf("error creating database object: %s", err)), nil
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package database
|
package database
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -66,7 +67,7 @@ func (b *databaseBackend) pathCredsCreateRead() framework.OperationFunc {
|
||||||
unlockFunc = b.Unlock
|
unlockFunc = b.Unlock
|
||||||
|
|
||||||
// Create a new DB object
|
// Create a new DB object
|
||||||
db, err = b.createDBObj(req.Storage, role.DBName)
|
db, err = b.createDBObj(context.TODO(), req.Storage, role.DBName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
unlockFunc()
|
unlockFunc()
|
||||||
return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err)
|
return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err)
|
||||||
|
@ -81,7 +82,7 @@ func (b *databaseBackend) pathCredsCreateRead() framework.OperationFunc {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create the user
|
// Create the user
|
||||||
username, password, err := db.CreateUser(role.Statements, usernameConfig, expiration)
|
username, password, err := db.CreateUser(context.TODO(), role.Statements, usernameConfig, expiration)
|
||||||
// Unlock
|
// Unlock
|
||||||
unlockFunc()
|
unlockFunc()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package database
|
package database
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/hashicorp/vault/logical"
|
"github.com/hashicorp/vault/logical"
|
||||||
|
@ -60,7 +61,7 @@ func (b *databaseBackend) secretCredsRenew() framework.OperationFunc {
|
||||||
unlockFunc = b.Unlock
|
unlockFunc = b.Unlock
|
||||||
|
|
||||||
// Create a new DB object
|
// Create a new DB object
|
||||||
db, err = b.createDBObj(req.Storage, role.DBName)
|
db, err = b.createDBObj(context.TODO(), req.Storage, role.DBName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
unlockFunc()
|
unlockFunc()
|
||||||
return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err)
|
return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err)
|
||||||
|
@ -69,7 +70,7 @@ func (b *databaseBackend) secretCredsRenew() framework.OperationFunc {
|
||||||
|
|
||||||
// Make sure we increase the VALID UNTIL endpoint for this user.
|
// Make sure we increase the VALID UNTIL endpoint for this user.
|
||||||
if expireTime := resp.Secret.ExpirationTime(); !expireTime.IsZero() {
|
if expireTime := resp.Secret.ExpirationTime(); !expireTime.IsZero() {
|
||||||
err := db.RenewUser(role.Statements, username, expireTime)
|
err := db.RenewUser(context.TODO(), role.Statements, username, expireTime)
|
||||||
// Unlock
|
// Unlock
|
||||||
unlockFunc()
|
unlockFunc()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -119,14 +120,14 @@ func (b *databaseBackend) secretCredsRevoke() framework.OperationFunc {
|
||||||
unlockFunc = b.Unlock
|
unlockFunc = b.Unlock
|
||||||
|
|
||||||
// Create a new DB object
|
// Create a new DB object
|
||||||
db, err = b.createDBObj(req.Storage, role.DBName)
|
db, err = b.createDBObj(context.TODO(), req.Storage, role.DBName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
unlockFunc()
|
unlockFunc()
|
||||||
return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err)
|
return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = db.RevokeUser(role.Statements, username)
|
err = db.RevokeUser(context.TODO(), role.Statements, username)
|
||||||
// Unlock
|
// Unlock
|
||||||
unlockFunc()
|
unlockFunc()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -43,6 +43,9 @@ type PolicyRequest struct {
|
||||||
|
|
||||||
// Whether to upsert
|
// Whether to upsert
|
||||||
Upsert bool
|
Upsert bool
|
||||||
|
|
||||||
|
// Whether to allow plaintext backup
|
||||||
|
AllowPlaintextBackup bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type LockManager struct {
|
type LockManager struct {
|
||||||
|
@ -378,10 +381,11 @@ func (lm *LockManager) getPolicyCommon(req PolicyRequest, lockType bool) (*Polic
|
||||||
}
|
}
|
||||||
|
|
||||||
p = &Policy{
|
p = &Policy{
|
||||||
Name: req.Name,
|
Name: req.Name,
|
||||||
Type: req.KeyType,
|
Type: req.KeyType,
|
||||||
Derived: req.Derived,
|
Derived: req.Derived,
|
||||||
Exportable: req.Exportable,
|
Exportable: req.Exportable,
|
||||||
|
AllowPlaintextBackup: req.AllowPlaintextBackup,
|
||||||
}
|
}
|
||||||
if req.Derived {
|
if req.Derived {
|
||||||
p.KDF = Kdf_hkdf_sha256
|
p.KDF = Kdf_hkdf_sha256
|
||||||
|
|
|
@ -119,6 +119,10 @@ func (r *PluginRunner) runCommon(wrapper RunnerUtil, pluginMap map[string]plugin
|
||||||
SecureConfig: secureConfig,
|
SecureConfig: secureConfig,
|
||||||
TLSConfig: clientTLSConfig,
|
TLSConfig: clientTLSConfig,
|
||||||
Logger: namedLogger,
|
Logger: namedLogger,
|
||||||
|
AllowedProtocols: []plugin.Protocol{
|
||||||
|
plugin.ProtocolNetRPC,
|
||||||
|
plugin.ProtocolGRPC,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
client := plugin.NewClient(clientConfig)
|
client := plugin.NewClient(clientConfig)
|
||||||
|
|
|
@ -192,8 +192,7 @@ func TestBackendHandleRequest_helpRoot(t *testing.T) {
|
||||||
func TestBackendHandleRequest_renewAuth(t *testing.T) {
|
func TestBackendHandleRequest_renewAuth(t *testing.T) {
|
||||||
b := &Backend{}
|
b := &Backend{}
|
||||||
|
|
||||||
resp, err := b.HandleRequest(logical.RenewAuthRequest(
|
resp, err := b.HandleRequest(logical.RenewAuthRequest("/foo", &logical.Auth{}, nil))
|
||||||
"/foo", &logical.Auth{}, nil))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -213,8 +212,7 @@ func TestBackendHandleRequest_renewAuthCallback(t *testing.T) {
|
||||||
AuthRenew: callback,
|
AuthRenew: callback,
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := b.HandleRequest(logical.RenewAuthRequest(
|
_, err := b.HandleRequest(logical.RenewAuthRequest("/foo", &logical.Auth{}, nil))
|
||||||
"/foo", &logical.Auth{}, nil))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -237,8 +235,7 @@ func TestBackendHandleRequest_renew(t *testing.T) {
|
||||||
Secrets: []*Secret{secret},
|
Secrets: []*Secret{secret},
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := b.HandleRequest(logical.RenewRequest(
|
_, err := b.HandleRequest(logical.RenewRequest("/foo", secret.Response(nil, nil).Secret, nil))
|
||||||
"/foo", secret.Response(nil, nil).Secret, nil))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -293,8 +290,7 @@ func TestBackendHandleRequest_revoke(t *testing.T) {
|
||||||
Secrets: []*Secret{secret},
|
Secrets: []*Secret{secret},
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := b.HandleRequest(logical.RevokeRequest(
|
_, err := b.HandleRequest(logical.RevokeRequest("/foo", secret.Response(nil, nil).Secret, nil))
|
||||||
"/foo", secret.Response(nil, nil).Secret, nil))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,7 +8,8 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
// LeaseExtend returns an OperationFunc that can be used to simply extend the
|
// LeaseExtend returns an OperationFunc that can be used to simply extend the
|
||||||
// lease of the auth/secret for the duration that was requested.
|
// lease of the auth/secret for the duration that was requested. The parameters
|
||||||
|
// provided are used to determine the lease's new TTL value.
|
||||||
//
|
//
|
||||||
// backendIncrement is the backend's requested increment -- perhaps from a user
|
// backendIncrement is the backend's requested increment -- perhaps from a user
|
||||||
// request, perhaps from a role/config value. If not set, uses the mount/system
|
// request, perhaps from a role/config value. If not set, uses the mount/system
|
||||||
|
|
|
@ -200,8 +200,7 @@ func (r *Request) SetLastRemoteWAL(last uint64) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// RenewRequest creates the structure of the renew request.
|
// RenewRequest creates the structure of the renew request.
|
||||||
func RenewRequest(
|
func RenewRequest(path string, secret *Secret, data map[string]interface{}) *Request {
|
||||||
path string, secret *Secret, data map[string]interface{}) *Request {
|
|
||||||
return &Request{
|
return &Request{
|
||||||
Operation: RenewOperation,
|
Operation: RenewOperation,
|
||||||
Path: path,
|
Path: path,
|
||||||
|
@ -211,8 +210,7 @@ func RenewRequest(
|
||||||
}
|
}
|
||||||
|
|
||||||
// RenewAuthRequest creates the structure of the renew request for an auth.
|
// RenewAuthRequest creates the structure of the renew request for an auth.
|
||||||
func RenewAuthRequest(
|
func RenewAuthRequest(path string, auth *Auth, data map[string]interface{}) *Request {
|
||||||
path string, auth *Auth, data map[string]interface{}) *Request {
|
|
||||||
return &Request{
|
return &Request{
|
||||||
Operation: RenewOperation,
|
Operation: RenewOperation,
|
||||||
Path: path,
|
Path: path,
|
||||||
|
@ -222,8 +220,7 @@ func RenewAuthRequest(
|
||||||
}
|
}
|
||||||
|
|
||||||
// RevokeRequest creates the structure of the revoke request.
|
// RevokeRequest creates the structure of the revoke request.
|
||||||
func RevokeRequest(
|
func RevokeRequest(path string, secret *Secret, data map[string]interface{}) *Request {
|
||||||
path string, secret *Secret, data map[string]interface{}) *Request {
|
|
||||||
return &Request{
|
return &Request{
|
||||||
Operation: RevokeOperation,
|
Operation: RevokeOperation,
|
||||||
Path: path,
|
Path: path,
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package cassandra
|
package cassandra
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -21,6 +22,8 @@ const (
|
||||||
cassandraTypeName = "cassandra"
|
cassandraTypeName = "cassandra"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var _ dbplugin.Database = &Cassandra{}
|
||||||
|
|
||||||
// Cassandra is an implementation of Database interface
|
// Cassandra is an implementation of Database interface
|
||||||
type Cassandra struct {
|
type Cassandra struct {
|
||||||
connutil.ConnectionProducer
|
connutil.ConnectionProducer
|
||||||
|
@ -64,8 +67,8 @@ func (c *Cassandra) Type() (string, error) {
|
||||||
return cassandraTypeName, nil
|
return cassandraTypeName, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Cassandra) getConnection() (*gocql.Session, error) {
|
func (c *Cassandra) getConnection(ctx context.Context) (*gocql.Session, error) {
|
||||||
session, err := c.Connection()
|
session, err := c.Connection(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -75,13 +78,13 @@ func (c *Cassandra) getConnection() (*gocql.Session, error) {
|
||||||
|
|
||||||
// CreateUser generates the username/password on the underlying Cassandra secret backend as instructed by
|
// CreateUser generates the username/password on the underlying Cassandra secret backend as instructed by
|
||||||
// the CreationStatement provided.
|
// the CreationStatement provided.
|
||||||
func (c *Cassandra) CreateUser(statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
|
func (c *Cassandra) CreateUser(ctx context.Context, statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
|
||||||
// Grab the lock
|
// Grab the lock
|
||||||
c.Lock()
|
c.Lock()
|
||||||
defer c.Unlock()
|
defer c.Unlock()
|
||||||
|
|
||||||
// Get the connection
|
// Get the connection
|
||||||
session, err := c.getConnection()
|
session, err := c.getConnection(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", err
|
return "", "", err
|
||||||
}
|
}
|
||||||
|
@ -138,18 +141,18 @@ func (c *Cassandra) CreateUser(statements dbplugin.Statements, usernameConfig db
|
||||||
}
|
}
|
||||||
|
|
||||||
// RenewUser is not supported on Cassandra, so this is a no-op.
|
// RenewUser is not supported on Cassandra, so this is a no-op.
|
||||||
func (c *Cassandra) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error {
|
func (c *Cassandra) RenewUser(ctx context.Context, statements dbplugin.Statements, username string, expiration time.Time) error {
|
||||||
// NOOP
|
// NOOP
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// RevokeUser attempts to drop the specified user.
|
// RevokeUser attempts to drop the specified user.
|
||||||
func (c *Cassandra) RevokeUser(statements dbplugin.Statements, username string) error {
|
func (c *Cassandra) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error {
|
||||||
// Grab the lock
|
// Grab the lock
|
||||||
c.Lock()
|
c.Lock()
|
||||||
defer c.Unlock()
|
defer c.Unlock()
|
||||||
|
|
||||||
session, err := c.getConnection()
|
session, err := c.getConnection(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package cassandra
|
package cassandra
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"os"
|
"os"
|
||||||
"strconv"
|
"strconv"
|
||||||
"testing"
|
"testing"
|
||||||
|
@ -89,7 +90,7 @@ func TestCassandra_Initialize(t *testing.T) {
|
||||||
db := dbRaw.(*Cassandra)
|
db := dbRaw.(*Cassandra)
|
||||||
connProducer := db.ConnectionProducer.(*cassandraConnectionProducer)
|
connProducer := db.ConnectionProducer.(*cassandraConnectionProducer)
|
||||||
|
|
||||||
err := db.Initialize(connectionDetails, true)
|
err := db.Initialize(context.Background(), connectionDetails, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -112,7 +113,7 @@ func TestCassandra_Initialize(t *testing.T) {
|
||||||
"protocol_version": "4",
|
"protocol_version": "4",
|
||||||
}
|
}
|
||||||
|
|
||||||
err = db.Initialize(connectionDetails, true)
|
err = db.Initialize(context.Background(), connectionDetails, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -135,7 +136,7 @@ func TestCassandra_CreateUser(t *testing.T) {
|
||||||
|
|
||||||
dbRaw, _ := New()
|
dbRaw, _ := New()
|
||||||
db := dbRaw.(*Cassandra)
|
db := dbRaw.(*Cassandra)
|
||||||
err := db.Initialize(connectionDetails, true)
|
err := db.Initialize(context.Background(), connectionDetails, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -149,7 +150,7 @@ func TestCassandra_CreateUser(t *testing.T) {
|
||||||
RoleName: "test",
|
RoleName: "test",
|
||||||
}
|
}
|
||||||
|
|
||||||
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute))
|
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -176,7 +177,7 @@ func TestMyCassandra_RenewUser(t *testing.T) {
|
||||||
|
|
||||||
dbRaw, _ := New()
|
dbRaw, _ := New()
|
||||||
db := dbRaw.(*Cassandra)
|
db := dbRaw.(*Cassandra)
|
||||||
err := db.Initialize(connectionDetails, true)
|
err := db.Initialize(context.Background(), connectionDetails, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -190,7 +191,7 @@ func TestMyCassandra_RenewUser(t *testing.T) {
|
||||||
RoleName: "test",
|
RoleName: "test",
|
||||||
}
|
}
|
||||||
|
|
||||||
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute))
|
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -199,7 +200,7 @@ func TestMyCassandra_RenewUser(t *testing.T) {
|
||||||
t.Fatalf("Could not connect with new credentials: %s", err)
|
t.Fatalf("Could not connect with new credentials: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = db.RenewUser(statements, username, time.Now().Add(time.Minute))
|
err = db.RenewUser(context.Background(), statements, username, time.Now().Add(time.Minute))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -222,7 +223,7 @@ func TestCassandra_RevokeUser(t *testing.T) {
|
||||||
|
|
||||||
dbRaw, _ := New()
|
dbRaw, _ := New()
|
||||||
db := dbRaw.(*Cassandra)
|
db := dbRaw.(*Cassandra)
|
||||||
err := db.Initialize(connectionDetails, true)
|
err := db.Initialize(context.Background(), connectionDetails, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -236,7 +237,7 @@ func TestCassandra_RevokeUser(t *testing.T) {
|
||||||
RoleName: "test",
|
RoleName: "test",
|
||||||
}
|
}
|
||||||
|
|
||||||
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute))
|
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -246,7 +247,7 @@ func TestCassandra_RevokeUser(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test default revoke statememts
|
// Test default revoke statememts
|
||||||
err = db.RevokeUser(statements, username)
|
err = db.RevokeUser(context.Background(), statements, username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package cassandra
|
package cassandra
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
@ -43,7 +44,7 @@ type cassandraConnectionProducer struct {
|
||||||
sync.Mutex
|
sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *cassandraConnectionProducer) Initialize(conf map[string]interface{}, verifyConnection bool) error {
|
func (c *cassandraConnectionProducer) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error {
|
||||||
c.Lock()
|
c.Lock()
|
||||||
defer c.Unlock()
|
defer c.Unlock()
|
||||||
|
|
||||||
|
@ -106,7 +107,7 @@ func (c *cassandraConnectionProducer) Initialize(conf map[string]interface{}, ve
|
||||||
c.Initialized = true
|
c.Initialized = true
|
||||||
|
|
||||||
if verifyConnection {
|
if verifyConnection {
|
||||||
if _, err := c.Connection(); err != nil {
|
if _, err := c.Connection(ctx); err != nil {
|
||||||
return fmt.Errorf("error verifying connection: %s", err)
|
return fmt.Errorf("error verifying connection: %s", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -114,7 +115,7 @@ func (c *cassandraConnectionProducer) Initialize(conf map[string]interface{}, ve
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *cassandraConnectionProducer) Connection() (interface{}, error) {
|
func (c *cassandraConnectionProducer) Connection(_ context.Context) (interface{}, error) {
|
||||||
if !c.Initialized {
|
if !c.Initialized {
|
||||||
return nil, connutil.ErrNotInitialized
|
return nil, connutil.ErrNotInitialized
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package hana
|
package hana
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
@ -26,6 +27,8 @@ type HANA struct {
|
||||||
credsutil.CredentialsProducer
|
credsutil.CredentialsProducer
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var _ dbplugin.Database = &HANA{}
|
||||||
|
|
||||||
// New implements builtinplugins.BuiltinFactory
|
// New implements builtinplugins.BuiltinFactory
|
||||||
func New() (interface{}, error) {
|
func New() (interface{}, error) {
|
||||||
connProducer := &connutil.SQLConnectionProducer{}
|
connProducer := &connutil.SQLConnectionProducer{}
|
||||||
|
@ -63,8 +66,8 @@ func (h *HANA) Type() (string, error) {
|
||||||
return hanaTypeName, nil
|
return hanaTypeName, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *HANA) getConnection() (*sql.DB, error) {
|
func (h *HANA) getConnection(ctx context.Context) (*sql.DB, error) {
|
||||||
db, err := h.Connection()
|
db, err := h.Connection(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -74,13 +77,13 @@ func (h *HANA) getConnection() (*sql.DB, error) {
|
||||||
|
|
||||||
// CreateUser generates the username/password on the underlying HANA secret backend
|
// CreateUser generates the username/password on the underlying HANA secret backend
|
||||||
// as instructed by the CreationStatement provided.
|
// as instructed by the CreationStatement provided.
|
||||||
func (h *HANA) CreateUser(statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
|
func (h *HANA) CreateUser(ctx context.Context, statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
|
||||||
// Grab the lock
|
// Grab the lock
|
||||||
h.Lock()
|
h.Lock()
|
||||||
defer h.Unlock()
|
defer h.Unlock()
|
||||||
|
|
||||||
// Get the connection
|
// Get the connection
|
||||||
db, err := h.getConnection()
|
db, err := h.getConnection(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", err
|
return "", "", err
|
||||||
}
|
}
|
||||||
|
@ -117,7 +120,7 @@ func (h *HANA) CreateUser(statements dbplugin.Statements, usernameConfig dbplugi
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start a transaction
|
// Start a transaction
|
||||||
tx, err := db.Begin()
|
tx, err := db.BeginTx(ctx, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", err
|
return "", "", err
|
||||||
}
|
}
|
||||||
|
@ -130,7 +133,7 @@ func (h *HANA) CreateUser(statements dbplugin.Statements, usernameConfig dbplugi
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
stmt, err := tx.Prepare(dbutil.QueryHelper(query, map[string]string{
|
stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{
|
||||||
"name": username,
|
"name": username,
|
||||||
"password": password,
|
"password": password,
|
||||||
"expiration": expirationStr,
|
"expiration": expirationStr,
|
||||||
|
@ -139,7 +142,7 @@ func (h *HANA) CreateUser(statements dbplugin.Statements, usernameConfig dbplugi
|
||||||
return "", "", err
|
return "", "", err
|
||||||
}
|
}
|
||||||
defer stmt.Close()
|
defer stmt.Close()
|
||||||
if _, err := stmt.Exec(); err != nil {
|
if _, err := stmt.ExecContext(ctx); err != nil {
|
||||||
return "", "", err
|
return "", "", err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -153,15 +156,15 @@ func (h *HANA) CreateUser(statements dbplugin.Statements, usernameConfig dbplugi
|
||||||
}
|
}
|
||||||
|
|
||||||
// Renewing hana user just means altering user's valid until property
|
// Renewing hana user just means altering user's valid until property
|
||||||
func (h *HANA) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error {
|
func (h *HANA) RenewUser(ctx context.Context, statements dbplugin.Statements, username string, expiration time.Time) error {
|
||||||
// Get connection
|
// Get connection
|
||||||
db, err := h.getConnection()
|
db, err := h.getConnection(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start a transaction
|
// Start a transaction
|
||||||
tx, err := db.Begin()
|
tx, err := db.BeginTx(ctx, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -175,12 +178,12 @@ func (h *HANA) RenewUser(statements dbplugin.Statements, username string, expira
|
||||||
}
|
}
|
||||||
|
|
||||||
// Renew user's valid until property field
|
// Renew user's valid until property field
|
||||||
stmt, err := tx.Prepare("ALTER USER " + username + " VALID UNTIL " + "'" + expirationStr + "'")
|
stmt, err := tx.PrepareContext(ctx, "ALTER USER "+username+" VALID UNTIL "+"'"+expirationStr+"'")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer stmt.Close()
|
defer stmt.Close()
|
||||||
if _, err := stmt.Exec(); err != nil {
|
if _, err := stmt.ExecContext(ctx); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -193,20 +196,20 @@ func (h *HANA) RenewUser(statements dbplugin.Statements, username string, expira
|
||||||
}
|
}
|
||||||
|
|
||||||
// Revoking hana user will deactivate user and try to perform a soft drop
|
// Revoking hana user will deactivate user and try to perform a soft drop
|
||||||
func (h *HANA) RevokeUser(statements dbplugin.Statements, username string) error {
|
func (h *HANA) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error {
|
||||||
// default revoke will be a soft drop on user
|
// default revoke will be a soft drop on user
|
||||||
if statements.RevocationStatements == "" {
|
if statements.RevocationStatements == "" {
|
||||||
return h.revokeUserDefault(username)
|
return h.revokeUserDefault(ctx, username)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get connection
|
// Get connection
|
||||||
db, err := h.getConnection()
|
db, err := h.getConnection(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start a transaction
|
// Start a transaction
|
||||||
tx, err := db.Begin()
|
tx, err := db.BeginTx(ctx, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -219,14 +222,14 @@ func (h *HANA) RevokeUser(statements dbplugin.Statements, username string) error
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
stmt, err := tx.Prepare(dbutil.QueryHelper(query, map[string]string{
|
stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{
|
||||||
"name": username,
|
"name": username,
|
||||||
}))
|
}))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer stmt.Close()
|
defer stmt.Close()
|
||||||
if _, err := stmt.Exec(); err != nil {
|
if _, err := stmt.ExecContext(ctx); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -239,38 +242,38 @@ func (h *HANA) RevokeUser(statements dbplugin.Statements, username string) error
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *HANA) revokeUserDefault(username string) error {
|
func (h *HANA) revokeUserDefault(ctx context.Context, username string) error {
|
||||||
// Get connection
|
// Get connection
|
||||||
db, err := h.getConnection()
|
db, err := h.getConnection(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start a transaction
|
// Start a transaction
|
||||||
tx, err := db.Begin()
|
tx, err := db.BeginTx(ctx, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer tx.Rollback()
|
defer tx.Rollback()
|
||||||
|
|
||||||
// Disable server login for user
|
// Disable server login for user
|
||||||
disableStmt, err := tx.Prepare(fmt.Sprintf("ALTER USER %s DEACTIVATE USER NOW", username))
|
disableStmt, err := tx.PrepareContext(ctx, fmt.Sprintf("ALTER USER %s DEACTIVATE USER NOW", username))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer disableStmt.Close()
|
defer disableStmt.Close()
|
||||||
if _, err := disableStmt.Exec(); err != nil {
|
if _, err := disableStmt.ExecContext(ctx); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Invalidates current sessions and performs soft drop (drop if no dependencies)
|
// Invalidates current sessions and performs soft drop (drop if no dependencies)
|
||||||
// if hard drop is desired, custom revoke statements should be written for role
|
// if hard drop is desired, custom revoke statements should be written for role
|
||||||
dropStmt, err := tx.Prepare(fmt.Sprintf("DROP USER %s RESTRICT", username))
|
dropStmt, err := tx.PrepareContext(ctx, fmt.Sprintf("DROP USER %s RESTRICT", username))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer dropStmt.Close()
|
defer dropStmt.Close()
|
||||||
if _, err := dropStmt.Exec(); err != nil {
|
if _, err := dropStmt.ExecContext(ctx); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package hana
|
package hana
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
@ -25,7 +26,7 @@ func TestHANA_Initialize(t *testing.T) {
|
||||||
dbRaw, _ := New()
|
dbRaw, _ := New()
|
||||||
db := dbRaw.(*HANA)
|
db := dbRaw.(*HANA)
|
||||||
|
|
||||||
err := db.Initialize(connectionDetails, true)
|
err := db.Initialize(context.Background(), connectionDetails, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -55,7 +56,7 @@ func TestHANA_CreateUser(t *testing.T) {
|
||||||
dbRaw, _ := New()
|
dbRaw, _ := New()
|
||||||
db := dbRaw.(*HANA)
|
db := dbRaw.(*HANA)
|
||||||
|
|
||||||
err := db.Initialize(connectionDetails, true)
|
err := db.Initialize(context.Background(), connectionDetails, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -66,7 +67,7 @@ func TestHANA_CreateUser(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test with no configured Creation Statememt
|
// Test with no configured Creation Statememt
|
||||||
_, _, err = db.CreateUser(dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Hour))
|
_, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Hour))
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("Expected error when no creation statement is provided")
|
t.Fatal("Expected error when no creation statement is provided")
|
||||||
}
|
}
|
||||||
|
@ -75,7 +76,7 @@ func TestHANA_CreateUser(t *testing.T) {
|
||||||
CreationStatements: testHANARole,
|
CreationStatements: testHANARole,
|
||||||
}
|
}
|
||||||
|
|
||||||
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Hour))
|
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Hour))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -98,7 +99,7 @@ func TestHANA_RevokeUser(t *testing.T) {
|
||||||
dbRaw, _ := New()
|
dbRaw, _ := New()
|
||||||
db := dbRaw.(*HANA)
|
db := dbRaw.(*HANA)
|
||||||
|
|
||||||
err := db.Initialize(connectionDetails, true)
|
err := db.Initialize(context.Background(), connectionDetails, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -113,7 +114,7 @@ func TestHANA_RevokeUser(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test default revoke statememts
|
// Test default revoke statememts
|
||||||
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Hour))
|
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Hour))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -121,7 +122,7 @@ func TestHANA_RevokeUser(t *testing.T) {
|
||||||
t.Fatalf("Could not connect with new credentials: %s", err)
|
t.Fatalf("Could not connect with new credentials: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = db.RevokeUser(statements, username)
|
err = db.RevokeUser(context.Background(), statements, username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -130,7 +131,7 @@ func TestHANA_RevokeUser(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test custom revoke statememt
|
// Test custom revoke statememt
|
||||||
username, password, err = db.CreateUser(statements, usernameConfig, time.Now().Add(time.Hour))
|
username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Hour))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -139,7 +140,7 @@ func TestHANA_RevokeUser(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
statements.RevocationStatements = testHANADrop
|
statements.RevocationStatements = testHANADrop
|
||||||
err = db.RevokeUser(statements, username)
|
err = db.RevokeUser(context.Background(), statements, username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package mongodb
|
package mongodb
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
@ -33,7 +34,7 @@ type mongoDBConnectionProducer struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize parses connection configuration.
|
// Initialize parses connection configuration.
|
||||||
func (c *mongoDBConnectionProducer) Initialize(conf map[string]interface{}, verifyConnection bool) error {
|
func (c *mongoDBConnectionProducer) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error {
|
||||||
c.Lock()
|
c.Lock()
|
||||||
defer c.Unlock()
|
defer c.Unlock()
|
||||||
|
|
||||||
|
@ -75,7 +76,7 @@ func (c *mongoDBConnectionProducer) Initialize(conf map[string]interface{}, veri
|
||||||
c.Initialized = true
|
c.Initialized = true
|
||||||
|
|
||||||
if verifyConnection {
|
if verifyConnection {
|
||||||
if _, err := c.Connection(); err != nil {
|
if _, err := c.Connection(ctx); err != nil {
|
||||||
return fmt.Errorf("error verifying connection: %s", err)
|
return fmt.Errorf("error verifying connection: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -88,7 +89,7 @@ func (c *mongoDBConnectionProducer) Initialize(conf map[string]interface{}, veri
|
||||||
}
|
}
|
||||||
|
|
||||||
// Connection creates a database connection.
|
// Connection creates a database connection.
|
||||||
func (c *mongoDBConnectionProducer) Connection() (interface{}, error) {
|
func (c *mongoDBConnectionProducer) Connection(_ context.Context) (interface{}, error) {
|
||||||
if !c.Initialized {
|
if !c.Initialized {
|
||||||
return nil, connutil.ErrNotInitialized
|
return nil, connutil.ErrNotInitialized
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package mongodb
|
package mongodb
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"io"
|
"io"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
@ -27,6 +28,8 @@ type MongoDB struct {
|
||||||
credsutil.CredentialsProducer
|
credsutil.CredentialsProducer
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var _ dbplugin.Database = &MongoDB{}
|
||||||
|
|
||||||
// New returns a new MongoDB instance
|
// New returns a new MongoDB instance
|
||||||
func New() (interface{}, error) {
|
func New() (interface{}, error) {
|
||||||
connProducer := &mongoDBConnectionProducer{}
|
connProducer := &mongoDBConnectionProducer{}
|
||||||
|
@ -63,8 +66,8 @@ func (m *MongoDB) Type() (string, error) {
|
||||||
return mongoDBTypeName, nil
|
return mongoDBTypeName, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MongoDB) getConnection() (*mgo.Session, error) {
|
func (m *MongoDB) getConnection(ctx context.Context) (*mgo.Session, error) {
|
||||||
session, err := m.Connection()
|
session, err := m.Connection(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -80,7 +83,7 @@ func (m *MongoDB) getConnection() (*mgo.Session, error) {
|
||||||
//
|
//
|
||||||
// JSON Example:
|
// JSON Example:
|
||||||
// { "db": "admin", "roles": [{ "role": "readWrite" }, {"role": "read", "db": "foo"}] }
|
// { "db": "admin", "roles": [{ "role": "readWrite" }, {"role": "read", "db": "foo"}] }
|
||||||
func (m *MongoDB) CreateUser(statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
|
func (m *MongoDB) CreateUser(ctx context.Context, statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
|
||||||
// Grab the lock
|
// Grab the lock
|
||||||
m.Lock()
|
m.Lock()
|
||||||
defer m.Unlock()
|
defer m.Unlock()
|
||||||
|
@ -89,7 +92,7 @@ func (m *MongoDB) CreateUser(statements dbplugin.Statements, usernameConfig dbpl
|
||||||
return "", "", dbutil.ErrEmptyCreationStatement
|
return "", "", dbutil.ErrEmptyCreationStatement
|
||||||
}
|
}
|
||||||
|
|
||||||
session, err := m.getConnection()
|
session, err := m.getConnection(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", err
|
return "", "", err
|
||||||
}
|
}
|
||||||
|
@ -133,7 +136,7 @@ func (m *MongoDB) CreateUser(statements dbplugin.Statements, usernameConfig dbpl
|
||||||
if err := m.ConnectionProducer.Close(); err != nil {
|
if err := m.ConnectionProducer.Close(); err != nil {
|
||||||
return "", "", errwrap.Wrapf("error closing EOF'd mongo connection: {{err}}", err)
|
return "", "", errwrap.Wrapf("error closing EOF'd mongo connection: {{err}}", err)
|
||||||
}
|
}
|
||||||
session, err := m.getConnection()
|
session, err := m.getConnection(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", err
|
return "", "", err
|
||||||
}
|
}
|
||||||
|
@ -149,15 +152,15 @@ func (m *MongoDB) CreateUser(statements dbplugin.Statements, usernameConfig dbpl
|
||||||
}
|
}
|
||||||
|
|
||||||
// RenewUser is not supported on MongoDB, so this is a no-op.
|
// RenewUser is not supported on MongoDB, so this is a no-op.
|
||||||
func (m *MongoDB) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error {
|
func (m *MongoDB) RenewUser(ctx context.Context, statements dbplugin.Statements, username string, expiration time.Time) error {
|
||||||
// NOOP
|
// NOOP
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// RevokeUser drops the specified user from the authentication databse. If none is provided
|
// RevokeUser drops the specified user from the authentication databse. If none is provided
|
||||||
// in the revocation statement, the default "admin" authentication database will be assumed.
|
// in the revocation statement, the default "admin" authentication database will be assumed.
|
||||||
func (m *MongoDB) RevokeUser(statements dbplugin.Statements, username string) error {
|
func (m *MongoDB) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error {
|
||||||
session, err := m.getConnection()
|
session, err := m.getConnection(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -188,7 +191,7 @@ func (m *MongoDB) RevokeUser(statements dbplugin.Statements, username string) er
|
||||||
if err := m.ConnectionProducer.Close(); err != nil {
|
if err := m.ConnectionProducer.Close(); err != nil {
|
||||||
return errwrap.Wrapf("error closing EOF'd mongo connection: {{err}}", err)
|
return errwrap.Wrapf("error closing EOF'd mongo connection: {{err}}", err)
|
||||||
}
|
}
|
||||||
session, err := m.getConnection()
|
session, err := m.getConnection(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package mongodb
|
package mongodb
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
|
@ -79,7 +80,7 @@ func TestMongoDB_Initialize(t *testing.T) {
|
||||||
db := dbRaw.(*MongoDB)
|
db := dbRaw.(*MongoDB)
|
||||||
connProducer := db.ConnectionProducer.(*mongoDBConnectionProducer)
|
connProducer := db.ConnectionProducer.(*mongoDBConnectionProducer)
|
||||||
|
|
||||||
err = db.Initialize(connectionDetails, true)
|
err = db.Initialize(context.Background(), connectionDetails, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -107,7 +108,7 @@ func TestMongoDB_CreateUser(t *testing.T) {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
db := dbRaw.(*MongoDB)
|
db := dbRaw.(*MongoDB)
|
||||||
err = db.Initialize(connectionDetails, true)
|
err = db.Initialize(context.Background(), connectionDetails, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -121,7 +122,7 @@ func TestMongoDB_CreateUser(t *testing.T) {
|
||||||
RoleName: "test",
|
RoleName: "test",
|
||||||
}
|
}
|
||||||
|
|
||||||
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute))
|
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -145,7 +146,7 @@ func TestMongoDB_CreateUser_writeConcern(t *testing.T) {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
db := dbRaw.(*MongoDB)
|
db := dbRaw.(*MongoDB)
|
||||||
err = db.Initialize(connectionDetails, true)
|
err = db.Initialize(context.Background(), connectionDetails, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -159,7 +160,7 @@ func TestMongoDB_CreateUser_writeConcern(t *testing.T) {
|
||||||
RoleName: "test",
|
RoleName: "test",
|
||||||
}
|
}
|
||||||
|
|
||||||
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute))
|
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -182,7 +183,7 @@ func TestMongoDB_RevokeUser(t *testing.T) {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
db := dbRaw.(*MongoDB)
|
db := dbRaw.(*MongoDB)
|
||||||
err = db.Initialize(connectionDetails, true)
|
err = db.Initialize(context.Background(), connectionDetails, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -196,7 +197,7 @@ func TestMongoDB_RevokeUser(t *testing.T) {
|
||||||
RoleName: "test",
|
RoleName: "test",
|
||||||
}
|
}
|
||||||
|
|
||||||
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute))
|
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -206,7 +207,7 @@ func TestMongoDB_RevokeUser(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test default revocation statememt
|
// Test default revocation statememt
|
||||||
err = db.RevokeUser(statements, username)
|
err = db.RevokeUser(context.Background(), statements, username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package mssql
|
package mssql
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
@ -18,6 +19,8 @@ import (
|
||||||
|
|
||||||
const msSQLTypeName = "mssql"
|
const msSQLTypeName = "mssql"
|
||||||
|
|
||||||
|
var _ dbplugin.Database = &MSSQL{}
|
||||||
|
|
||||||
// MSSQL is an implementation of Database interface
|
// MSSQL is an implementation of Database interface
|
||||||
type MSSQL struct {
|
type MSSQL struct {
|
||||||
connutil.ConnectionProducer
|
connutil.ConnectionProducer
|
||||||
|
@ -60,8 +63,8 @@ func (m *MSSQL) Type() (string, error) {
|
||||||
return msSQLTypeName, nil
|
return msSQLTypeName, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MSSQL) getConnection() (*sql.DB, error) {
|
func (m *MSSQL) getConnection(ctx context.Context) (*sql.DB, error) {
|
||||||
db, err := m.Connection()
|
db, err := m.Connection(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -71,13 +74,13 @@ func (m *MSSQL) getConnection() (*sql.DB, error) {
|
||||||
|
|
||||||
// CreateUser generates the username/password on the underlying MSSQL secret backend as instructed by
|
// CreateUser generates the username/password on the underlying MSSQL secret backend as instructed by
|
||||||
// the CreationStatement provided.
|
// the CreationStatement provided.
|
||||||
func (m *MSSQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
|
func (m *MSSQL) CreateUser(ctx context.Context, statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
|
||||||
// Grab the lock
|
// Grab the lock
|
||||||
m.Lock()
|
m.Lock()
|
||||||
defer m.Unlock()
|
defer m.Unlock()
|
||||||
|
|
||||||
// Get the connection
|
// Get the connection
|
||||||
db, err := m.getConnection()
|
db, err := m.getConnection(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", err
|
return "", "", err
|
||||||
}
|
}
|
||||||
|
@ -102,7 +105,7 @@ func (m *MSSQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplug
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start a transaction
|
// Start a transaction
|
||||||
tx, err := db.Begin()
|
tx, err := db.BeginTx(ctx, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", err
|
return "", "", err
|
||||||
}
|
}
|
||||||
|
@ -115,7 +118,7 @@ func (m *MSSQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplug
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
stmt, err := tx.Prepare(dbutil.QueryHelper(query, map[string]string{
|
stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{
|
||||||
"name": username,
|
"name": username,
|
||||||
"password": password,
|
"password": password,
|
||||||
"expiration": expirationStr,
|
"expiration": expirationStr,
|
||||||
|
@ -124,7 +127,7 @@ func (m *MSSQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplug
|
||||||
return "", "", err
|
return "", "", err
|
||||||
}
|
}
|
||||||
defer stmt.Close()
|
defer stmt.Close()
|
||||||
if _, err := stmt.Exec(); err != nil {
|
if _, err := stmt.ExecContext(ctx); err != nil {
|
||||||
return "", "", err
|
return "", "", err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -138,7 +141,7 @@ func (m *MSSQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplug
|
||||||
}
|
}
|
||||||
|
|
||||||
// RenewUser is not supported on MSSQL, so this is a no-op.
|
// RenewUser is not supported on MSSQL, so this is a no-op.
|
||||||
func (m *MSSQL) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error {
|
func (m *MSSQL) RenewUser(ctx context.Context, statements dbplugin.Statements, username string, expiration time.Time) error {
|
||||||
// NOOP
|
// NOOP
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -146,19 +149,19 @@ func (m *MSSQL) RenewUser(statements dbplugin.Statements, username string, expir
|
||||||
// RevokeUser attempts to drop the specified user. It will first attempt to disable login,
|
// RevokeUser attempts to drop the specified user. It will first attempt to disable login,
|
||||||
// then kill pending connections from that user, and finally drop the user and login from the
|
// then kill pending connections from that user, and finally drop the user and login from the
|
||||||
// database instance.
|
// database instance.
|
||||||
func (m *MSSQL) RevokeUser(statements dbplugin.Statements, username string) error {
|
func (m *MSSQL) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error {
|
||||||
if statements.RevocationStatements == "" {
|
if statements.RevocationStatements == "" {
|
||||||
return m.revokeUserDefault(username)
|
return m.revokeUserDefault(ctx, username)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get connection
|
// Get connection
|
||||||
db, err := m.getConnection()
|
db, err := m.getConnection(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start a transaction
|
// Start a transaction
|
||||||
tx, err := db.Begin()
|
tx, err := db.BeginTx(ctx, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -171,14 +174,14 @@ func (m *MSSQL) RevokeUser(statements dbplugin.Statements, username string) erro
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
stmt, err := tx.Prepare(dbutil.QueryHelper(query, map[string]string{
|
stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{
|
||||||
"name": username,
|
"name": username,
|
||||||
}))
|
}))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer stmt.Close()
|
defer stmt.Close()
|
||||||
if _, err := stmt.Exec(); err != nil {
|
if _, err := stmt.ExecContext(ctx); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -191,20 +194,20 @@ func (m *MSSQL) RevokeUser(statements dbplugin.Statements, username string) erro
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MSSQL) revokeUserDefault(username string) error {
|
func (m *MSSQL) revokeUserDefault(ctx context.Context, username string) error {
|
||||||
// Get connection
|
// Get connection
|
||||||
db, err := m.getConnection()
|
db, err := m.getConnection(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// First disable server login
|
// First disable server login
|
||||||
disableStmt, err := db.Prepare(fmt.Sprintf("ALTER LOGIN [%s] DISABLE;", username))
|
disableStmt, err := db.PrepareContext(ctx, fmt.Sprintf("ALTER LOGIN [%s] DISABLE;", username))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer disableStmt.Close()
|
defer disableStmt.Close()
|
||||||
if _, err := disableStmt.Exec(); err != nil {
|
if _, err := disableStmt.ExecContext(ctx); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -212,14 +215,14 @@ func (m *MSSQL) revokeUserDefault(username string) error {
|
||||||
// sessions. There cannot be any active sessions before we drop the logins
|
// sessions. There cannot be any active sessions before we drop the logins
|
||||||
// This isn't done in a transaction because even if we fail along the way,
|
// This isn't done in a transaction because even if we fail along the way,
|
||||||
// we want to remove as much access as possible
|
// we want to remove as much access as possible
|
||||||
sessionStmt, err := db.Prepare(fmt.Sprintf(
|
sessionStmt, err := db.PrepareContext(ctx, fmt.Sprintf(
|
||||||
"SELECT session_id FROM sys.dm_exec_sessions WHERE login_name = '%s';", username))
|
"SELECT session_id FROM sys.dm_exec_sessions WHERE login_name = '%s';", username))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer sessionStmt.Close()
|
defer sessionStmt.Close()
|
||||||
|
|
||||||
sessionRows, err := sessionStmt.Query()
|
sessionRows, err := sessionStmt.QueryContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -240,13 +243,13 @@ func (m *MSSQL) revokeUserDefault(username string) error {
|
||||||
// we need to drop the database users before we can drop the login and the role
|
// we need to drop the database users before we can drop the login and the role
|
||||||
// This isn't done in a transaction because even if we fail along the way,
|
// This isn't done in a transaction because even if we fail along the way,
|
||||||
// we want to remove as much access as possible
|
// we want to remove as much access as possible
|
||||||
stmt, err := db.Prepare(fmt.Sprintf("EXEC master.dbo.sp_msloginmappings '%s';", username))
|
stmt, err := db.PrepareContext(ctx, fmt.Sprintf("EXEC master.dbo.sp_msloginmappings '%s';", username))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer stmt.Close()
|
defer stmt.Close()
|
||||||
|
|
||||||
rows, err := stmt.Query()
|
rows, err := stmt.QueryContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -266,13 +269,13 @@ func (m *MSSQL) revokeUserDefault(username string) error {
|
||||||
// many permissions as possible right now
|
// many permissions as possible right now
|
||||||
var lastStmtError error
|
var lastStmtError error
|
||||||
for _, query := range revokeStmts {
|
for _, query := range revokeStmts {
|
||||||
stmt, err := db.Prepare(query)
|
stmt, err := db.PrepareContext(ctx, query)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
lastStmtError = err
|
lastStmtError = err
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
defer stmt.Close()
|
defer stmt.Close()
|
||||||
_, err = stmt.Exec()
|
_, err = stmt.ExecContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
lastStmtError = err
|
lastStmtError = err
|
||||||
}
|
}
|
||||||
|
@ -287,12 +290,12 @@ func (m *MSSQL) revokeUserDefault(username string) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Drop this login
|
// Drop this login
|
||||||
stmt, err = db.Prepare(fmt.Sprintf(dropLoginSQL, username, username))
|
stmt, err = db.PrepareContext(ctx, fmt.Sprintf(dropLoginSQL, username, username))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer stmt.Close()
|
defer stmt.Close()
|
||||||
if _, err := stmt.Exec(); err != nil {
|
if _, err := stmt.ExecContext(ctx); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package mssql
|
package mssql
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
@ -30,7 +31,7 @@ func TestMSSQL_Initialize(t *testing.T) {
|
||||||
dbRaw, _ := New()
|
dbRaw, _ := New()
|
||||||
db := dbRaw.(*MSSQL)
|
db := dbRaw.(*MSSQL)
|
||||||
|
|
||||||
err := db.Initialize(connectionDetails, true)
|
err := db.Initialize(context.Background(), connectionDetails, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -51,7 +52,7 @@ func TestMSSQL_Initialize(t *testing.T) {
|
||||||
"max_open_connections": "5",
|
"max_open_connections": "5",
|
||||||
}
|
}
|
||||||
|
|
||||||
err = db.Initialize(connectionDetails, true)
|
err = db.Initialize(context.Background(), connectionDetails, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -69,7 +70,7 @@ func TestMSSQL_CreateUser(t *testing.T) {
|
||||||
|
|
||||||
dbRaw, _ := New()
|
dbRaw, _ := New()
|
||||||
db := dbRaw.(*MSSQL)
|
db := dbRaw.(*MSSQL)
|
||||||
err := db.Initialize(connectionDetails, true)
|
err := db.Initialize(context.Background(), connectionDetails, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -80,7 +81,7 @@ func TestMSSQL_CreateUser(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test with no configured Creation Statememt
|
// Test with no configured Creation Statememt
|
||||||
_, _, err = db.CreateUser(dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute))
|
_, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute))
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("Expected error when no creation statement is provided")
|
t.Fatal("Expected error when no creation statement is provided")
|
||||||
}
|
}
|
||||||
|
@ -89,7 +90,7 @@ func TestMSSQL_CreateUser(t *testing.T) {
|
||||||
CreationStatements: testMSSQLRole,
|
CreationStatements: testMSSQLRole,
|
||||||
}
|
}
|
||||||
|
|
||||||
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute))
|
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -111,7 +112,7 @@ func TestMSSQL_RevokeUser(t *testing.T) {
|
||||||
|
|
||||||
dbRaw, _ := New()
|
dbRaw, _ := New()
|
||||||
db := dbRaw.(*MSSQL)
|
db := dbRaw.(*MSSQL)
|
||||||
err := db.Initialize(connectionDetails, true)
|
err := db.Initialize(context.Background(), connectionDetails, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -125,7 +126,7 @@ func TestMSSQL_RevokeUser(t *testing.T) {
|
||||||
RoleName: "test",
|
RoleName: "test",
|
||||||
}
|
}
|
||||||
|
|
||||||
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(2*time.Second))
|
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -135,7 +136,7 @@ func TestMSSQL_RevokeUser(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test default revoke statememts
|
// Test default revoke statememts
|
||||||
err = db.RevokeUser(statements, username)
|
err = db.RevokeUser(context.Background(), statements, username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -144,7 +145,7 @@ func TestMSSQL_RevokeUser(t *testing.T) {
|
||||||
t.Fatal("Credentials were not revoked")
|
t.Fatal("Credentials were not revoked")
|
||||||
}
|
}
|
||||||
|
|
||||||
username, password, err = db.CreateUser(statements, usernameConfig, time.Now().Add(2*time.Second))
|
username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -155,7 +156,7 @@ func TestMSSQL_RevokeUser(t *testing.T) {
|
||||||
|
|
||||||
// Test custom revoke statememt
|
// Test custom revoke statememt
|
||||||
statements.RevocationStatements = testMSSQLDrop
|
statements.RevocationStatements = testMSSQLDrop
|
||||||
err = db.RevokeUser(statements, username)
|
err = db.RevokeUser(context.Background(), statements, username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package mysql
|
package mysql
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
@ -30,6 +31,8 @@ var (
|
||||||
LegacyUsernameLen int = 16
|
LegacyUsernameLen int = 16
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var _ dbplugin.Database = &MySQL{}
|
||||||
|
|
||||||
type MySQL struct {
|
type MySQL struct {
|
||||||
connutil.ConnectionProducer
|
connutil.ConnectionProducer
|
||||||
credsutil.CredentialsProducer
|
credsutil.CredentialsProducer
|
||||||
|
@ -88,8 +91,8 @@ func (m *MySQL) Type() (string, error) {
|
||||||
return mySQLTypeName, nil
|
return mySQLTypeName, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MySQL) getConnection() (*sql.DB, error) {
|
func (m *MySQL) getConnection(ctx context.Context) (*sql.DB, error) {
|
||||||
db, err := m.Connection()
|
db, err := m.Connection(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -97,13 +100,13 @@ func (m *MySQL) getConnection() (*sql.DB, error) {
|
||||||
return db.(*sql.DB), nil
|
return db.(*sql.DB), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MySQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
|
func (m *MySQL) CreateUser(ctx context.Context, statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
|
||||||
// Grab the lock
|
// Grab the lock
|
||||||
m.Lock()
|
m.Lock()
|
||||||
defer m.Unlock()
|
defer m.Unlock()
|
||||||
|
|
||||||
// Get the connection
|
// Get the connection
|
||||||
db, err := m.getConnection()
|
db, err := m.getConnection(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", err
|
return "", "", err
|
||||||
}
|
}
|
||||||
|
@ -128,7 +131,7 @@ func (m *MySQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplug
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start a transaction
|
// Start a transaction
|
||||||
tx, err := db.Begin()
|
tx, err := db.BeginTx(ctx, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", err
|
return "", "", err
|
||||||
}
|
}
|
||||||
|
@ -146,7 +149,7 @@ func (m *MySQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplug
|
||||||
"expiration": expirationStr,
|
"expiration": expirationStr,
|
||||||
})
|
})
|
||||||
|
|
||||||
stmt, err := tx.Prepare(query)
|
stmt, err := tx.PrepareContext(ctx, query)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// If the error code we get back is Error 1295: This command is not
|
// If the error code we get back is Error 1295: This command is not
|
||||||
// supported in the prepared statement protocol yet, we will execute
|
// supported in the prepared statement protocol yet, we will execute
|
||||||
|
@ -155,7 +158,7 @@ func (m *MySQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplug
|
||||||
// prepare supported commands. If there is no error when running we
|
// prepare supported commands. If there is no error when running we
|
||||||
// will continue to the next statement.
|
// will continue to the next statement.
|
||||||
if e, ok := err.(*stdmysql.MySQLError); ok && e.Number == 1295 {
|
if e, ok := err.(*stdmysql.MySQLError); ok && e.Number == 1295 {
|
||||||
_, err = tx.Exec(query)
|
_, err = tx.ExecContext(ctx, query)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", err
|
return "", "", err
|
||||||
}
|
}
|
||||||
|
@ -165,7 +168,7 @@ func (m *MySQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplug
|
||||||
return "", "", err
|
return "", "", err
|
||||||
}
|
}
|
||||||
defer stmt.Close()
|
defer stmt.Close()
|
||||||
if _, err := stmt.Exec(); err != nil {
|
if _, err := stmt.ExecContext(ctx); err != nil {
|
||||||
return "", "", err
|
return "", "", err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -179,17 +182,17 @@ func (m *MySQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplug
|
||||||
}
|
}
|
||||||
|
|
||||||
// NOOP
|
// NOOP
|
||||||
func (m *MySQL) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error {
|
func (m *MySQL) RenewUser(ctx context.Context, statements dbplugin.Statements, username string, expiration time.Time) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MySQL) RevokeUser(statements dbplugin.Statements, username string) error {
|
func (m *MySQL) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error {
|
||||||
// Grab the read lock
|
// Grab the read lock
|
||||||
m.Lock()
|
m.Lock()
|
||||||
defer m.Unlock()
|
defer m.Unlock()
|
||||||
|
|
||||||
// Get the connection
|
// Get the connection
|
||||||
db, err := m.getConnection()
|
db, err := m.getConnection(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -201,7 +204,7 @@ func (m *MySQL) RevokeUser(statements dbplugin.Statements, username string) erro
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start a transaction
|
// Start a transaction
|
||||||
tx, err := db.Begin()
|
tx, err := db.BeginTx(ctx, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -217,7 +220,7 @@ func (m *MySQL) RevokeUser(statements dbplugin.Statements, username string) erro
|
||||||
// 1295: This command is not supported in the prepared statement protocol yet
|
// 1295: This command is not supported in the prepared statement protocol yet
|
||||||
// Reference https://mariadb.com/kb/en/mariadb/prepare-statement/
|
// Reference https://mariadb.com/kb/en/mariadb/prepare-statement/
|
||||||
query = strings.Replace(query, "{{name}}", username, -1)
|
query = strings.Replace(query, "{{name}}", username, -1)
|
||||||
_, err = tx.Exec(query)
|
_, err = tx.ExecContext(ctx, query)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package mysql
|
package mysql
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
@ -108,7 +109,7 @@ func TestMySQL_Initialize(t *testing.T) {
|
||||||
db := dbRaw.(*MySQL)
|
db := dbRaw.(*MySQL)
|
||||||
connProducer := db.ConnectionProducer.(*connutil.SQLConnectionProducer)
|
connProducer := db.ConnectionProducer.(*connutil.SQLConnectionProducer)
|
||||||
|
|
||||||
err := db.Initialize(connectionDetails, true)
|
err := db.Initialize(context.Background(), connectionDetails, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -128,7 +129,7 @@ func TestMySQL_Initialize(t *testing.T) {
|
||||||
"max_open_connections": "5",
|
"max_open_connections": "5",
|
||||||
}
|
}
|
||||||
|
|
||||||
err = db.Initialize(connectionDetails, true)
|
err = db.Initialize(context.Background(), connectionDetails, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -146,7 +147,7 @@ func TestMySQL_CreateUser(t *testing.T) {
|
||||||
dbRaw, _ := f()
|
dbRaw, _ := f()
|
||||||
db := dbRaw.(*MySQL)
|
db := dbRaw.(*MySQL)
|
||||||
|
|
||||||
err := db.Initialize(connectionDetails, true)
|
err := db.Initialize(context.Background(), connectionDetails, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -157,7 +158,7 @@ func TestMySQL_CreateUser(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test with no configured Creation Statememt
|
// Test with no configured Creation Statememt
|
||||||
_, _, err = db.CreateUser(dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute))
|
_, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute))
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("Expected error when no creation statement is provided")
|
t.Fatal("Expected error when no creation statement is provided")
|
||||||
}
|
}
|
||||||
|
@ -166,7 +167,7 @@ func TestMySQL_CreateUser(t *testing.T) {
|
||||||
CreationStatements: testMySQLRoleWildCard,
|
CreationStatements: testMySQLRoleWildCard,
|
||||||
}
|
}
|
||||||
|
|
||||||
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute))
|
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -176,7 +177,7 @@ func TestMySQL_CreateUser(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test a second time to make sure usernames don't collide
|
// Test a second time to make sure usernames don't collide
|
||||||
username, password, err = db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute))
|
username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -188,7 +189,7 @@ func TestMySQL_CreateUser(t *testing.T) {
|
||||||
// Test with a manualy prepare statement
|
// Test with a manualy prepare statement
|
||||||
statements.CreationStatements = testMySQLRolePreparedStmt
|
statements.CreationStatements = testMySQLRolePreparedStmt
|
||||||
|
|
||||||
username, password, err = db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute))
|
username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -211,7 +212,7 @@ func TestMySQL_CreateUser_Legacy(t *testing.T) {
|
||||||
dbRaw, _ := f()
|
dbRaw, _ := f()
|
||||||
db := dbRaw.(*MySQL)
|
db := dbRaw.(*MySQL)
|
||||||
|
|
||||||
err := db.Initialize(connectionDetails, true)
|
err := db.Initialize(context.Background(), connectionDetails, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -222,7 +223,7 @@ func TestMySQL_CreateUser_Legacy(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test with no configured Creation Statememt
|
// Test with no configured Creation Statememt
|
||||||
_, _, err = db.CreateUser(dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute))
|
_, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute))
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("Expected error when no creation statement is provided")
|
t.Fatal("Expected error when no creation statement is provided")
|
||||||
}
|
}
|
||||||
|
@ -231,7 +232,7 @@ func TestMySQL_CreateUser_Legacy(t *testing.T) {
|
||||||
CreationStatements: testMySQLRoleWildCard,
|
CreationStatements: testMySQLRoleWildCard,
|
||||||
}
|
}
|
||||||
|
|
||||||
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute))
|
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -241,7 +242,7 @@ func TestMySQL_CreateUser_Legacy(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test a second time to make sure usernames don't collide
|
// Test a second time to make sure usernames don't collide
|
||||||
username, password, err = db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute))
|
username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -263,7 +264,7 @@ func TestMySQL_RevokeUser(t *testing.T) {
|
||||||
dbRaw, _ := f()
|
dbRaw, _ := f()
|
||||||
db := dbRaw.(*MySQL)
|
db := dbRaw.(*MySQL)
|
||||||
|
|
||||||
err := db.Initialize(connectionDetails, true)
|
err := db.Initialize(context.Background(), connectionDetails, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -277,7 +278,7 @@ func TestMySQL_RevokeUser(t *testing.T) {
|
||||||
RoleName: "test",
|
RoleName: "test",
|
||||||
}
|
}
|
||||||
|
|
||||||
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute))
|
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -287,7 +288,7 @@ func TestMySQL_RevokeUser(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test default revoke statememts
|
// Test default revoke statememts
|
||||||
err = db.RevokeUser(statements, username)
|
err = db.RevokeUser(context.Background(), statements, username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -297,7 +298,7 @@ func TestMySQL_RevokeUser(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
statements.CreationStatements = testMySQLRoleWildCard
|
statements.CreationStatements = testMySQLRoleWildCard
|
||||||
username, password, err = db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute))
|
username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -308,7 +309,7 @@ func TestMySQL_RevokeUser(t *testing.T) {
|
||||||
|
|
||||||
// Test custom revoke statements
|
// Test custom revoke statements
|
||||||
statements.RevocationStatements = testMySQLRevocationSQL
|
statements.RevocationStatements = testMySQLRevocationSQL
|
||||||
err = db.RevokeUser(statements, username)
|
err = db.RevokeUser(context.Background(), statements, username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package postgresql
|
package postgresql
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
@ -24,6 +25,8 @@ ALTER ROLE "{{name}}" VALID UNTIL '{{expiration}}';
|
||||||
`
|
`
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var _ dbplugin.Database = &PostgreSQL{}
|
||||||
|
|
||||||
// New implements builtinplugins.BuiltinFactory
|
// New implements builtinplugins.BuiltinFactory
|
||||||
func New() (interface{}, error) {
|
func New() (interface{}, error) {
|
||||||
connProducer := &connutil.SQLConnectionProducer{}
|
connProducer := &connutil.SQLConnectionProducer{}
|
||||||
|
@ -65,8 +68,8 @@ func (p *PostgreSQL) Type() (string, error) {
|
||||||
return postgreSQLTypeName, nil
|
return postgreSQLTypeName, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *PostgreSQL) getConnection() (*sql.DB, error) {
|
func (p *PostgreSQL) getConnection(ctx context.Context) (*sql.DB, error) {
|
||||||
db, err := p.Connection()
|
db, err := p.Connection(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -74,7 +77,7 @@ func (p *PostgreSQL) getConnection() (*sql.DB, error) {
|
||||||
return db.(*sql.DB), nil
|
return db.(*sql.DB), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *PostgreSQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
|
func (p *PostgreSQL) CreateUser(ctx context.Context, statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
|
||||||
if statements.CreationStatements == "" {
|
if statements.CreationStatements == "" {
|
||||||
return "", "", dbutil.ErrEmptyCreationStatement
|
return "", "", dbutil.ErrEmptyCreationStatement
|
||||||
}
|
}
|
||||||
|
@ -99,14 +102,14 @@ func (p *PostgreSQL) CreateUser(statements dbplugin.Statements, usernameConfig d
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the connection
|
// Get the connection
|
||||||
db, err := p.getConnection()
|
db, err := p.getConnection(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", err
|
return "", "", err
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start a transaction
|
// Start a transaction
|
||||||
tx, err := db.Begin()
|
tx, err := db.BeginTx(ctx, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", err
|
return "", "", err
|
||||||
|
|
||||||
|
@ -123,7 +126,7 @@ func (p *PostgreSQL) CreateUser(statements dbplugin.Statements, usernameConfig d
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
stmt, err := tx.Prepare(dbutil.QueryHelper(query, map[string]string{
|
stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{
|
||||||
"name": username,
|
"name": username,
|
||||||
"password": password,
|
"password": password,
|
||||||
"expiration": expirationStr,
|
"expiration": expirationStr,
|
||||||
|
@ -133,7 +136,7 @@ func (p *PostgreSQL) CreateUser(statements dbplugin.Statements, usernameConfig d
|
||||||
|
|
||||||
}
|
}
|
||||||
defer stmt.Close()
|
defer stmt.Close()
|
||||||
if _, err := stmt.Exec(); err != nil {
|
if _, err := stmt.ExecContext(ctx); err != nil {
|
||||||
return "", "", err
|
return "", "", err
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -148,7 +151,7 @@ func (p *PostgreSQL) CreateUser(statements dbplugin.Statements, usernameConfig d
|
||||||
return username, password, nil
|
return username, password, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *PostgreSQL) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error {
|
func (p *PostgreSQL) RenewUser(ctx context.Context, statements dbplugin.Statements, username string, expiration time.Time) error {
|
||||||
p.Lock()
|
p.Lock()
|
||||||
defer p.Unlock()
|
defer p.Unlock()
|
||||||
|
|
||||||
|
@ -157,12 +160,12 @@ func (p *PostgreSQL) RenewUser(statements dbplugin.Statements, username string,
|
||||||
renewStmts = defaultPostgresRenewSQL
|
renewStmts = defaultPostgresRenewSQL
|
||||||
}
|
}
|
||||||
|
|
||||||
db, err := p.getConnection()
|
db, err := p.getConnection(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
tx, err := db.Begin()
|
tx, err := db.BeginTx(ctx, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -180,7 +183,7 @@ func (p *PostgreSQL) RenewUser(statements dbplugin.Statements, username string,
|
||||||
if len(query) == 0 {
|
if len(query) == 0 {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
stmt, err := tx.Prepare(dbutil.QueryHelper(query, map[string]string{
|
stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{
|
||||||
"name": username,
|
"name": username,
|
||||||
"expiration": expirationStr,
|
"expiration": expirationStr,
|
||||||
}))
|
}))
|
||||||
|
@ -189,7 +192,7 @@ func (p *PostgreSQL) RenewUser(statements dbplugin.Statements, username string,
|
||||||
}
|
}
|
||||||
|
|
||||||
defer stmt.Close()
|
defer stmt.Close()
|
||||||
if _, err := stmt.Exec(); err != nil {
|
if _, err := stmt.ExecContext(ctx); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -201,25 +204,25 @@ func (p *PostgreSQL) RenewUser(statements dbplugin.Statements, username string,
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *PostgreSQL) RevokeUser(statements dbplugin.Statements, username string) error {
|
func (p *PostgreSQL) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error {
|
||||||
// Grab the lock
|
// Grab the lock
|
||||||
p.Lock()
|
p.Lock()
|
||||||
defer p.Unlock()
|
defer p.Unlock()
|
||||||
|
|
||||||
if statements.RevocationStatements == "" {
|
if statements.RevocationStatements == "" {
|
||||||
return p.defaultRevokeUser(username)
|
return p.defaultRevokeUser(ctx, username)
|
||||||
}
|
}
|
||||||
|
|
||||||
return p.customRevokeUser(username, statements.RevocationStatements)
|
return p.customRevokeUser(ctx, username, statements.RevocationStatements)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *PostgreSQL) customRevokeUser(username, revocationStmts string) error {
|
func (p *PostgreSQL) customRevokeUser(ctx context.Context, username, revocationStmts string) error {
|
||||||
db, err := p.getConnection()
|
db, err := p.getConnection(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
tx, err := db.Begin()
|
tx, err := db.BeginTx(ctx, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -233,7 +236,7 @@ func (p *PostgreSQL) customRevokeUser(username, revocationStmts string) error {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
stmt, err := tx.Prepare(dbutil.QueryHelper(query, map[string]string{
|
stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{
|
||||||
"name": username,
|
"name": username,
|
||||||
}))
|
}))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -241,7 +244,7 @@ func (p *PostgreSQL) customRevokeUser(username, revocationStmts string) error {
|
||||||
}
|
}
|
||||||
defer stmt.Close()
|
defer stmt.Close()
|
||||||
|
|
||||||
if _, err := stmt.Exec(); err != nil {
|
if _, err := stmt.ExecContext(ctx); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -253,15 +256,15 @@ func (p *PostgreSQL) customRevokeUser(username, revocationStmts string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *PostgreSQL) defaultRevokeUser(username string) error {
|
func (p *PostgreSQL) defaultRevokeUser(ctx context.Context, username string) error {
|
||||||
db, err := p.getConnection()
|
db, err := p.getConnection(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if the role exists
|
// Check if the role exists
|
||||||
var exists bool
|
var exists bool
|
||||||
err = db.QueryRow("SELECT exists (SELECT rolname FROM pg_roles WHERE rolname=$1);", username).Scan(&exists)
|
err = db.QueryRowContext(ctx, "SELECT exists (SELECT rolname FROM pg_roles WHERE rolname=$1);", username).Scan(&exists)
|
||||||
if err != nil && err != sql.ErrNoRows {
|
if err != nil && err != sql.ErrNoRows {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -274,13 +277,13 @@ func (p *PostgreSQL) defaultRevokeUser(username string) error {
|
||||||
// the role
|
// the role
|
||||||
// This isn't done in a transaction because even if we fail along the way,
|
// This isn't done in a transaction because even if we fail along the way,
|
||||||
// we want to remove as much access as possible
|
// we want to remove as much access as possible
|
||||||
stmt, err := db.Prepare("SELECT DISTINCT table_schema FROM information_schema.role_column_grants WHERE grantee=$1;")
|
stmt, err := db.PrepareContext(ctx, "SELECT DISTINCT table_schema FROM information_schema.role_column_grants WHERE grantee=$1;")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer stmt.Close()
|
defer stmt.Close()
|
||||||
|
|
||||||
rows, err := stmt.Query(username)
|
rows, err := stmt.QueryContext(ctx, username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -322,7 +325,7 @@ func (p *PostgreSQL) defaultRevokeUser(username string) error {
|
||||||
// get the current database name so we can issue a REVOKE CONNECT for
|
// get the current database name so we can issue a REVOKE CONNECT for
|
||||||
// this username
|
// this username
|
||||||
var dbname sql.NullString
|
var dbname sql.NullString
|
||||||
if err := db.QueryRow("SELECT current_database();").Scan(&dbname); err != nil {
|
if err := db.QueryRowContext(ctx, "SELECT current_database();").Scan(&dbname); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -337,13 +340,13 @@ func (p *PostgreSQL) defaultRevokeUser(username string) error {
|
||||||
// many permissions as possible right now
|
// many permissions as possible right now
|
||||||
var lastStmtError error
|
var lastStmtError error
|
||||||
for _, query := range revocationStmts {
|
for _, query := range revocationStmts {
|
||||||
stmt, err := db.Prepare(query)
|
stmt, err := db.PrepareContext(ctx, query)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
lastStmtError = err
|
lastStmtError = err
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
defer stmt.Close()
|
defer stmt.Close()
|
||||||
_, err = stmt.Exec()
|
_, err = stmt.ExecContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
lastStmtError = err
|
lastStmtError = err
|
||||||
}
|
}
|
||||||
|
@ -358,13 +361,13 @@ func (p *PostgreSQL) defaultRevokeUser(username string) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Drop this user
|
// Drop this user
|
||||||
stmt, err = db.Prepare(fmt.Sprintf(
|
stmt, err = db.PrepareContext(ctx, fmt.Sprintf(
|
||||||
`DROP ROLE IF EXISTS %s;`, pq.QuoteIdentifier(username)))
|
`DROP ROLE IF EXISTS %s;`, pq.QuoteIdentifier(username)))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer stmt.Close()
|
defer stmt.Close()
|
||||||
if _, err := stmt.Exec(); err != nil {
|
if _, err := stmt.ExecContext(ctx); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package postgresql
|
package postgresql
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
@ -72,7 +73,7 @@ func TestPostgreSQL_Initialize(t *testing.T) {
|
||||||
|
|
||||||
connProducer := db.ConnectionProducer.(*connutil.SQLConnectionProducer)
|
connProducer := db.ConnectionProducer.(*connutil.SQLConnectionProducer)
|
||||||
|
|
||||||
err := db.Initialize(connectionDetails, true)
|
err := db.Initialize(context.Background(), connectionDetails, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -92,7 +93,7 @@ func TestPostgreSQL_Initialize(t *testing.T) {
|
||||||
"max_open_connections": "5",
|
"max_open_connections": "5",
|
||||||
}
|
}
|
||||||
|
|
||||||
err = db.Initialize(connectionDetails, true)
|
err = db.Initialize(context.Background(), connectionDetails, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -109,7 +110,7 @@ func TestPostgreSQL_CreateUser(t *testing.T) {
|
||||||
|
|
||||||
dbRaw, _ := New()
|
dbRaw, _ := New()
|
||||||
db := dbRaw.(*PostgreSQL)
|
db := dbRaw.(*PostgreSQL)
|
||||||
err := db.Initialize(connectionDetails, true)
|
err := db.Initialize(context.Background(), connectionDetails, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -120,7 +121,7 @@ func TestPostgreSQL_CreateUser(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test with no configured Creation Statememt
|
// Test with no configured Creation Statememt
|
||||||
_, _, err = db.CreateUser(dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute))
|
_, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute))
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("Expected error when no creation statement is provided")
|
t.Fatal("Expected error when no creation statement is provided")
|
||||||
}
|
}
|
||||||
|
@ -129,7 +130,7 @@ func TestPostgreSQL_CreateUser(t *testing.T) {
|
||||||
CreationStatements: testPostgresRole,
|
CreationStatements: testPostgresRole,
|
||||||
}
|
}
|
||||||
|
|
||||||
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute))
|
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -139,7 +140,7 @@ func TestPostgreSQL_CreateUser(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
statements.CreationStatements = testPostgresReadOnlyRole
|
statements.CreationStatements = testPostgresReadOnlyRole
|
||||||
username, password, err = db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute))
|
username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -162,7 +163,7 @@ func TestPostgreSQL_RenewUser(t *testing.T) {
|
||||||
|
|
||||||
dbRaw, _ := New()
|
dbRaw, _ := New()
|
||||||
db := dbRaw.(*PostgreSQL)
|
db := dbRaw.(*PostgreSQL)
|
||||||
err := db.Initialize(connectionDetails, true)
|
err := db.Initialize(context.Background(), connectionDetails, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -176,7 +177,7 @@ func TestPostgreSQL_RenewUser(t *testing.T) {
|
||||||
RoleName: "test",
|
RoleName: "test",
|
||||||
}
|
}
|
||||||
|
|
||||||
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(2*time.Second))
|
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -185,7 +186,7 @@ func TestPostgreSQL_RenewUser(t *testing.T) {
|
||||||
t.Fatalf("Could not connect with new credentials: %s", err)
|
t.Fatalf("Could not connect with new credentials: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = db.RenewUser(statements, username, time.Now().Add(time.Minute))
|
err = db.RenewUser(context.Background(), statements, username, time.Now().Add(time.Minute))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -197,7 +198,7 @@ func TestPostgreSQL_RenewUser(t *testing.T) {
|
||||||
t.Fatalf("Could not connect with new credentials: %s", err)
|
t.Fatalf("Could not connect with new credentials: %s", err)
|
||||||
}
|
}
|
||||||
statements.RenewStatements = defaultPostgresRenewSQL
|
statements.RenewStatements = defaultPostgresRenewSQL
|
||||||
username, password, err = db.CreateUser(statements, usernameConfig, time.Now().Add(2*time.Second))
|
username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -206,7 +207,7 @@ func TestPostgreSQL_RenewUser(t *testing.T) {
|
||||||
t.Fatalf("Could not connect with new credentials: %s", err)
|
t.Fatalf("Could not connect with new credentials: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = db.RenewUser(statements, username, time.Now().Add(time.Minute))
|
err = db.RenewUser(context.Background(), statements, username, time.Now().Add(time.Minute))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -230,7 +231,7 @@ func TestPostgreSQL_RevokeUser(t *testing.T) {
|
||||||
|
|
||||||
dbRaw, _ := New()
|
dbRaw, _ := New()
|
||||||
db := dbRaw.(*PostgreSQL)
|
db := dbRaw.(*PostgreSQL)
|
||||||
err := db.Initialize(connectionDetails, true)
|
err := db.Initialize(context.Background(), connectionDetails, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -244,7 +245,7 @@ func TestPostgreSQL_RevokeUser(t *testing.T) {
|
||||||
RoleName: "test",
|
RoleName: "test",
|
||||||
}
|
}
|
||||||
|
|
||||||
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(2*time.Second))
|
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -254,7 +255,7 @@ func TestPostgreSQL_RevokeUser(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test default revoke statememts
|
// Test default revoke statememts
|
||||||
err = db.RevokeUser(statements, username)
|
err = db.RevokeUser(context.Background(), statements, username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -263,7 +264,7 @@ func TestPostgreSQL_RevokeUser(t *testing.T) {
|
||||||
t.Fatal("Credentials were not revoked")
|
t.Fatal("Credentials were not revoked")
|
||||||
}
|
}
|
||||||
|
|
||||||
username, password, err = db.CreateUser(statements, usernameConfig, time.Now().Add(2*time.Second))
|
username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -274,7 +275,7 @@ func TestPostgreSQL_RevokeUser(t *testing.T) {
|
||||||
|
|
||||||
// Test custom revoke statements
|
// Test custom revoke statements
|
||||||
statements.RevocationStatements = defaultPostgresRevocationSQL
|
statements.RevocationStatements = defaultPostgresRevocationSQL
|
||||||
err = db.RevokeUser(statements, username)
|
err = db.RevokeUser(context.Background(), statements, username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package connutil
|
package connutil
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"sync"
|
"sync"
|
||||||
)
|
)
|
||||||
|
@ -14,8 +15,8 @@ var (
|
||||||
// connections and is used in all the builtin database types.
|
// connections and is used in all the builtin database types.
|
||||||
type ConnectionProducer interface {
|
type ConnectionProducer interface {
|
||||||
Close() error
|
Close() error
|
||||||
Initialize(map[string]interface{}, bool) error
|
Initialize(context.Context, map[string]interface{}, bool) error
|
||||||
Connection() (interface{}, error)
|
Connection(context.Context) (interface{}, error)
|
||||||
|
|
||||||
sync.Locker
|
sync.Locker
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package connutil
|
package connutil
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
@ -25,7 +26,7 @@ type SQLConnectionProducer struct {
|
||||||
sync.Mutex
|
sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *SQLConnectionProducer) Initialize(conf map[string]interface{}, verifyConnection bool) error {
|
func (c *SQLConnectionProducer) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error {
|
||||||
c.Lock()
|
c.Lock()
|
||||||
defer c.Unlock()
|
defer c.Unlock()
|
||||||
|
|
||||||
|
@ -62,11 +63,11 @@ func (c *SQLConnectionProducer) Initialize(conf map[string]interface{}, verifyCo
|
||||||
c.Initialized = true
|
c.Initialized = true
|
||||||
|
|
||||||
if verifyConnection {
|
if verifyConnection {
|
||||||
if _, err := c.Connection(); err != nil {
|
if _, err := c.Connection(ctx); err != nil {
|
||||||
return fmt.Errorf("error verifying connection: %s", err)
|
return fmt.Errorf("error verifying connection: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := c.db.Ping(); err != nil {
|
if err := c.db.PingContext(ctx); err != nil {
|
||||||
return fmt.Errorf("error verifying connection: %s", err)
|
return fmt.Errorf("error verifying connection: %s", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -74,14 +75,14 @@ func (c *SQLConnectionProducer) Initialize(conf map[string]interface{}, verifyCo
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *SQLConnectionProducer) Connection() (interface{}, error) {
|
func (c *SQLConnectionProducer) Connection(ctx context.Context) (interface{}, error) {
|
||||||
if !c.Initialized {
|
if !c.Initialized {
|
||||||
return nil, ErrNotInitialized
|
return nil, ErrNotInitialized
|
||||||
}
|
}
|
||||||
|
|
||||||
// If we already have a DB, test it and return
|
// If we already have a DB, test it and return
|
||||||
if c.db != nil {
|
if c.db != nil {
|
||||||
if err := c.db.Ping(); err == nil {
|
if err := c.db.PingContext(ctx); err == nil {
|
||||||
return c.db, nil
|
return c.db, nil
|
||||||
}
|
}
|
||||||
// If the ping was unsuccessful, close it and ignore errors as we'll be
|
// If the ping was unsuccessful, close it and ignore errors as we'll be
|
||||||
|
|
|
@ -5,7 +5,6 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/hashicorp/errwrap"
|
|
||||||
"github.com/hashicorp/go-uuid"
|
"github.com/hashicorp/go-uuid"
|
||||||
"github.com/hashicorp/vault/helper/consts"
|
"github.com/hashicorp/vault/helper/consts"
|
||||||
"github.com/hashicorp/vault/helper/jsonutil"
|
"github.com/hashicorp/vault/helper/jsonutil"
|
||||||
|
@ -460,10 +459,11 @@ func (c *Core) setupCredentials() error {
|
||||||
backend, err = c.newCredentialBackend(entry.Type, sysView, view, conf)
|
backend, err = c.newCredentialBackend(entry.Type, sysView, view, conf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.logger.Error("core: failed to create credential entry", "path", entry.Path, "error", err)
|
c.logger.Error("core: failed to create credential entry", "path", entry.Path, "error", err)
|
||||||
if errwrap.Contains(err, ErrPluginNotFound.Error()) && entry.Type == "plugin" {
|
if entry.Type == "plugin" {
|
||||||
// If we encounter an error instantiating the backend due to it being missing from the catalog,
|
// If we encounter an error instantiating the backend due to an error,
|
||||||
// skip backend initialization but register the entry to the mount table to preserve storage
|
// skip backend initialization but register the entry to the mount table
|
||||||
// and path.
|
// to preserve storage and path.
|
||||||
|
c.logger.Warn("core: skipping plugin-based credential entry", "path", entry.Path)
|
||||||
goto ROUTER_MOUNT
|
goto ROUTER_MOUNT
|
||||||
}
|
}
|
||||||
return errLoadAuthFailed
|
return errLoadAuthFailed
|
||||||
|
|
|
@ -750,6 +750,31 @@ func (m *ExpirationManager) RenewToken(req *logical.Request, source string, toke
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
sysView := m.router.MatchingSystemView(le.Path)
|
||||||
|
if sysView == nil {
|
||||||
|
return nil, fmt.Errorf("expiration: unable to retrieve system view from router")
|
||||||
|
}
|
||||||
|
|
||||||
|
retResp := &logical.Response{}
|
||||||
|
switch {
|
||||||
|
case resp.Auth.Period > time.Duration(0):
|
||||||
|
// If it resp.Period is non-zero, use that as the TTL and override backend's
|
||||||
|
// call on TTL modification, such as a TTL value determined by
|
||||||
|
// framework.LeaseExtend call against the request. Also, cap period value to
|
||||||
|
// the sys/mount max value.
|
||||||
|
if resp.Auth.Period > sysView.MaxLeaseTTL() {
|
||||||
|
retResp.AddWarning(fmt.Sprintf("Period of %d seconds is greater than current mount/system default of %d seconds, value will be truncated.", resp.Auth.TTL, sysView.MaxLeaseTTL()))
|
||||||
|
resp.Auth.Period = sysView.MaxLeaseTTL()
|
||||||
|
}
|
||||||
|
resp.Auth.TTL = resp.Auth.Period
|
||||||
|
case resp.Auth.TTL > time.Duration(0):
|
||||||
|
// Cap TTL value to the sys/mount max value
|
||||||
|
if resp.Auth.TTL > sysView.MaxLeaseTTL() {
|
||||||
|
retResp.AddWarning(fmt.Sprintf("TTL of %d seconds is greater than current mount/system default of %d seconds, value will be truncated.", resp.Auth.TTL, sysView.MaxLeaseTTL()))
|
||||||
|
resp.Auth.TTL = sysView.MaxLeaseTTL()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Attach the ClientToken
|
// Attach the ClientToken
|
||||||
resp.Auth.ClientToken = token
|
resp.Auth.ClientToken = token
|
||||||
resp.Auth.Increment = 0
|
resp.Auth.Increment = 0
|
||||||
|
@ -764,9 +789,9 @@ func (m *ExpirationManager) RenewToken(req *logical.Request, source string, toke
|
||||||
|
|
||||||
// Update the expiration time
|
// Update the expiration time
|
||||||
m.updatePending(le, resp.Auth.LeaseTotal())
|
m.updatePending(le, resp.Auth.LeaseTotal())
|
||||||
return &logical.Response{
|
|
||||||
Auth: resp.Auth,
|
retResp.Auth = resp.Auth
|
||||||
}, nil
|
return retResp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Register is used to take a request and response with an associated
|
// Register is used to take a request and response with an associated
|
||||||
|
@ -866,6 +891,12 @@ func (m *ExpirationManager) RegisterAuth(source string, auth *logical.Auth) erro
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If it resp.Period is non-zero, override the TTL value determined
|
||||||
|
// by the backend.
|
||||||
|
if auth.Period > time.Duration(0) {
|
||||||
|
auth.TTL = auth.Period
|
||||||
|
}
|
||||||
|
|
||||||
// Create a lease entry
|
// Create a lease entry
|
||||||
le := leaseEntry{
|
le := leaseEntry{
|
||||||
LeaseID: path.Join(source, saltedID),
|
LeaseID: path.Join(source, saltedID),
|
||||||
|
@ -1017,8 +1048,7 @@ func (m *ExpirationManager) revokeEntry(le *leaseEntry) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle standard revocation via backends
|
// Handle standard revocation via backends
|
||||||
resp, err := m.router.Route(logical.RevokeRequest(
|
resp, err := m.router.Route(logical.RevokeRequest(le.Path, le.Secret, le.Data))
|
||||||
le.Path, le.Secret, le.Data))
|
|
||||||
if err != nil || (resp != nil && resp.IsError()) {
|
if err != nil || (resp != nil && resp.IsError()) {
|
||||||
return fmt.Errorf("failed to revoke entry: resp:%#v err:%s", resp, err)
|
return fmt.Errorf("failed to revoke entry: resp:%#v err:%s", resp, err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,7 +2,9 @@ package vault_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
"os"
|
"os"
|
||||||
|
"path/filepath"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -178,14 +180,13 @@ func testPlugin_CatalogRemoved(t *testing.T, btype logical.BackendType, testMoun
|
||||||
}
|
}
|
||||||
|
|
||||||
if testMount {
|
if testMount {
|
||||||
// Add plugin back to the catalog
|
|
||||||
vault.TestAddTestPlugin(t, core.Core, "mock-plugin", "TestBackend_PluginMainLogical")
|
|
||||||
|
|
||||||
// Mount the plugin at the same path after plugin is re-added to the catalog
|
// Mount the plugin at the same path after plugin is re-added to the catalog
|
||||||
// and expect an error due to existing path.
|
// and expect an error due to existing path.
|
||||||
var err error
|
var err error
|
||||||
switch btype {
|
switch btype {
|
||||||
case logical.TypeLogical:
|
case logical.TypeLogical:
|
||||||
|
// Add plugin back to the catalog
|
||||||
|
vault.TestAddTestPlugin(t, core.Core, "mock-plugin", "TestBackend_PluginMainLogical")
|
||||||
_, err = core.Client.Logical().Write("sys/mounts/mock-0", map[string]interface{}{
|
_, err = core.Client.Logical().Write("sys/mounts/mock-0", map[string]interface{}{
|
||||||
"type": "plugin",
|
"type": "plugin",
|
||||||
"config": map[string]interface{}{
|
"config": map[string]interface{}{
|
||||||
|
@ -193,6 +194,8 @@ func testPlugin_CatalogRemoved(t *testing.T, btype logical.BackendType, testMoun
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
case logical.TypeCredential:
|
case logical.TypeCredential:
|
||||||
|
// Add plugin back to the catalog
|
||||||
|
vault.TestAddTestPlugin(t, core.Core, "mock-plugin", "TestBackend_PluginMainCredentials")
|
||||||
_, err = core.Client.Logical().Write("sys/auth/mock-0", map[string]interface{}{
|
_, err = core.Client.Logical().Write("sys/auth/mock-0", map[string]interface{}{
|
||||||
"type": "plugin",
|
"type": "plugin",
|
||||||
"plugin_name": "mock-plugin",
|
"plugin_name": "mock-plugin",
|
||||||
|
@ -204,6 +207,129 @@ func testPlugin_CatalogRemoved(t *testing.T, btype logical.BackendType, testMoun
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSystemBackend_Plugin_continueOnError(t *testing.T) {
|
||||||
|
t.Run("secret", func(t *testing.T) {
|
||||||
|
t.Run("sha256_mismatch", func(t *testing.T) {
|
||||||
|
testPlugin_continueOnError(t, logical.TypeLogical, true)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("missing_plugin", func(t *testing.T) {
|
||||||
|
testPlugin_continueOnError(t, logical.TypeLogical, false)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("auth", func(t *testing.T) {
|
||||||
|
t.Run("sha256_mismatch", func(t *testing.T) {
|
||||||
|
testPlugin_continueOnError(t, logical.TypeCredential, true)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("missing_plugin", func(t *testing.T) {
|
||||||
|
testPlugin_continueOnError(t, logical.TypeCredential, false)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func testPlugin_continueOnError(t *testing.T, btype logical.BackendType, mismatch bool) {
|
||||||
|
cluster := testSystemBackendMock(t, 1, 1, btype)
|
||||||
|
defer cluster.Cleanup()
|
||||||
|
|
||||||
|
core := cluster.Cores[0]
|
||||||
|
|
||||||
|
// Get the registered plugin
|
||||||
|
req := logical.TestRequest(t, logical.ReadOperation, "sys/plugins/catalog/mock-plugin")
|
||||||
|
req.ClientToken = core.Client.Token()
|
||||||
|
resp, err := core.HandleRequest(req)
|
||||||
|
if err != nil || resp == nil || (resp != nil && resp.IsError()) {
|
||||||
|
t.Fatalf("err:%v resp:%#v", err, resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
command, ok := resp.Data["command"].(string)
|
||||||
|
if !ok || command == "" {
|
||||||
|
t.Fatal("invalid command")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Trigger a sha256 mistmatch or missing plugin error
|
||||||
|
if mismatch {
|
||||||
|
req = logical.TestRequest(t, logical.UpdateOperation, "sys/plugins/catalog/mock-plugin")
|
||||||
|
req.Data = map[string]interface{}{
|
||||||
|
"sha256": "d17bd7334758e53e6fbab15745d2520765c06e296f2ce8e25b7919effa0ac216",
|
||||||
|
"command": filepath.Base(command),
|
||||||
|
}
|
||||||
|
req.ClientToken = core.Client.Token()
|
||||||
|
resp, err = core.HandleRequest(req)
|
||||||
|
if err != nil || (resp != nil && resp.IsError()) {
|
||||||
|
t.Fatalf("err:%v resp:%#v", err, resp)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
err := os.Remove(filepath.Join(cluster.TempDir, filepath.Base(command)))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Seal the cluster
|
||||||
|
cluster.EnsureCoresSealed(t)
|
||||||
|
|
||||||
|
// Unseal the cluster
|
||||||
|
barrierKeys := cluster.BarrierKeys
|
||||||
|
for _, core := range cluster.Cores {
|
||||||
|
for _, key := range barrierKeys {
|
||||||
|
_, err := core.Unseal(vault.TestKeyCopy(key))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sealed, err := core.Sealed()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("err checking seal status: %s", err)
|
||||||
|
}
|
||||||
|
if sealed {
|
||||||
|
t.Fatal("should not be sealed")
|
||||||
|
}
|
||||||
|
// Wait for active so post-unseal takes place
|
||||||
|
// If it fails, it means unseal process failed
|
||||||
|
vault.TestWaitActive(t, core.Core)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Re-add the plugin to the catalog
|
||||||
|
switch btype {
|
||||||
|
case logical.TypeLogical:
|
||||||
|
vault.TestAddTestPluginTempDir(t, core.Core, "mock-plugin", "TestBackend_PluginMainLogical", cluster.TempDir)
|
||||||
|
case logical.TypeCredential:
|
||||||
|
vault.TestAddTestPluginTempDir(t, core.Core, "mock-plugin", "TestBackend_PluginMainCredentials", cluster.TempDir)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reload the plugin
|
||||||
|
req = logical.TestRequest(t, logical.UpdateOperation, "sys/plugins/reload/backend")
|
||||||
|
req.Data = map[string]interface{}{
|
||||||
|
"plugin": "mock-plugin",
|
||||||
|
}
|
||||||
|
req.ClientToken = core.Client.Token()
|
||||||
|
resp, err = core.HandleRequest(req)
|
||||||
|
if err != nil || (resp != nil && resp.IsError()) {
|
||||||
|
t.Fatalf("err:%v resp:%#v", err, resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Make a request to lazy load the plugin
|
||||||
|
var reqPath string
|
||||||
|
switch btype {
|
||||||
|
case logical.TypeLogical:
|
||||||
|
reqPath = "mock-0/internal"
|
||||||
|
case logical.TypeCredential:
|
||||||
|
reqPath = "auth/mock-0/internal"
|
||||||
|
}
|
||||||
|
|
||||||
|
req = logical.TestRequest(t, logical.ReadOperation, reqPath)
|
||||||
|
req.ClientToken = core.Client.Token()
|
||||||
|
resp, err = core.HandleRequest(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("err: %v", err)
|
||||||
|
}
|
||||||
|
if resp == nil {
|
||||||
|
t.Fatalf("bad: response should not be nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestSystemBackend_Plugin_autoReload(t *testing.T) {
|
func TestSystemBackend_Plugin_autoReload(t *testing.T) {
|
||||||
cluster := testSystemBackendMock(t, 1, 1, logical.TypeLogical)
|
cluster := testSystemBackendMock(t, 1, 1, logical.TypeLogical)
|
||||||
defer cluster.Cleanup()
|
defer cluster.Cleanup()
|
||||||
|
@ -332,7 +458,10 @@ func testSystemBackend_PluginReload(t *testing.T, reqData map[string]interface{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// testSystemBackendMock returns a systemBackend with the desired number
|
// testSystemBackendMock returns a systemBackend with the desired number
|
||||||
// of mounted mock plugin backends
|
// of mounted mock plugin backends. numMounts alternates between different
|
||||||
|
// ways of providing the plugin_name.
|
||||||
|
//
|
||||||
|
// The mounts are mounted at sys/mounts/mock-[numMounts] or sys/auth/mock-[numMounts]
|
||||||
func testSystemBackendMock(t *testing.T, numCores, numMounts int, backendType logical.BackendType) *vault.TestCluster {
|
func testSystemBackendMock(t *testing.T, numCores, numMounts int, backendType logical.BackendType) *vault.TestCluster {
|
||||||
coreConfig := &vault.CoreConfig{
|
coreConfig := &vault.CoreConfig{
|
||||||
LogicalBackends: map[string]logical.Factory{
|
LogicalBackends: map[string]logical.Factory{
|
||||||
|
@ -343,10 +472,17 @@ func testSystemBackendMock(t *testing.T, numCores, numMounts int, backendType lo
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Create a tempdir, cluster.Cleanup will clean up this directory
|
||||||
|
tempDir, err := ioutil.TempDir("", "vault-test-cluster")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{
|
cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{
|
||||||
HandlerFunc: vaulthttp.Handler,
|
HandlerFunc: vaulthttp.Handler,
|
||||||
KeepStandbysSealed: true,
|
KeepStandbysSealed: true,
|
||||||
NumCores: numCores,
|
NumCores: numCores,
|
||||||
|
TempDir: tempDir,
|
||||||
})
|
})
|
||||||
cluster.Start()
|
cluster.Start()
|
||||||
|
|
||||||
|
@ -358,7 +494,7 @@ func testSystemBackendMock(t *testing.T, numCores, numMounts int, backendType lo
|
||||||
|
|
||||||
switch backendType {
|
switch backendType {
|
||||||
case logical.TypeLogical:
|
case logical.TypeLogical:
|
||||||
vault.TestAddTestPlugin(t, core.Core, "mock-plugin", "TestBackend_PluginMainLogical")
|
vault.TestAddTestPluginTempDir(t, core.Core, "mock-plugin", "TestBackend_PluginMainLogical", tempDir)
|
||||||
for i := 0; i < numMounts; i++ {
|
for i := 0; i < numMounts; i++ {
|
||||||
// Alternate input styles for plugin_name on every other mount
|
// Alternate input styles for plugin_name on every other mount
|
||||||
options := map[string]interface{}{
|
options := map[string]interface{}{
|
||||||
|
@ -380,7 +516,7 @@ func testSystemBackendMock(t *testing.T, numCores, numMounts int, backendType lo
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
case logical.TypeCredential:
|
case logical.TypeCredential:
|
||||||
vault.TestAddTestPlugin(t, core.Core, "mock-plugin", "TestBackend_PluginMainCredentials")
|
vault.TestAddTestPluginTempDir(t, core.Core, "mock-plugin", "TestBackend_PluginMainCredentials", tempDir)
|
||||||
for i := 0; i < numMounts; i++ {
|
for i := 0; i < numMounts; i++ {
|
||||||
// Alternate input styles for plugin_name on every other mount
|
// Alternate input styles for plugin_name on every other mount
|
||||||
options := map[string]interface{}{
|
options := map[string]interface{}{
|
||||||
|
|
|
@ -7,7 +7,6 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/hashicorp/errwrap"
|
|
||||||
"github.com/hashicorp/go-uuid"
|
"github.com/hashicorp/go-uuid"
|
||||||
"github.com/hashicorp/vault/helper/consts"
|
"github.com/hashicorp/vault/helper/consts"
|
||||||
"github.com/hashicorp/vault/helper/jsonutil"
|
"github.com/hashicorp/vault/helper/jsonutil"
|
||||||
|
@ -753,10 +752,11 @@ func (c *Core) setupMounts() error {
|
||||||
backend, err = c.newLogicalBackend(entry.Type, sysView, view, conf)
|
backend, err = c.newLogicalBackend(entry.Type, sysView, view, conf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.logger.Error("core: failed to create mount entry", "path", entry.Path, "error", err)
|
c.logger.Error("core: failed to create mount entry", "path", entry.Path, "error", err)
|
||||||
if errwrap.Contains(err, ErrPluginNotFound.Error()) && entry.Type == "plugin" {
|
if entry.Type == "plugin" {
|
||||||
// If we encounter an error instantiating the backend due to it being missing from the catalog,
|
// If we encounter an error instantiating the backend due to an error,
|
||||||
// skip backend initialization but register the entry to the mount table to preserve storage
|
// skip backend initialization but register the entry to the mount table
|
||||||
// and path.
|
// to preserve storage and path.
|
||||||
|
c.logger.Warn("core: skipping plugin-based mount entry", "path", entry.Path)
|
||||||
goto ROUTER_MOUNT
|
goto ROUTER_MOUNT
|
||||||
}
|
}
|
||||||
return errLoadMountsFailed
|
return errLoadMountsFailed
|
||||||
|
|
|
@ -79,15 +79,23 @@ func (c *Core) reloadMatchingPlugin(pluginName string) error {
|
||||||
func (c *Core) reloadPluginCommon(entry *MountEntry, isAuth bool) error {
|
func (c *Core) reloadPluginCommon(entry *MountEntry, isAuth bool) error {
|
||||||
path := entry.Path
|
path := entry.Path
|
||||||
|
|
||||||
|
if isAuth {
|
||||||
|
path = credentialRoutePrefix + path
|
||||||
|
}
|
||||||
|
|
||||||
// Fast-path out if the backend doesn't exist
|
// Fast-path out if the backend doesn't exist
|
||||||
raw, ok := c.router.root.Get(path)
|
raw, ok := c.router.root.Get(path)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Call backend's Cleanup routine
|
|
||||||
re := raw.(*routeEntry)
|
re := raw.(*routeEntry)
|
||||||
re.backend.Cleanup()
|
|
||||||
|
// Only call Cleanup if backend is initialized
|
||||||
|
if re.backend != nil {
|
||||||
|
// Call backend's Cleanup routine
|
||||||
|
re.backend.Cleanup()
|
||||||
|
}
|
||||||
|
|
||||||
view := re.storageView
|
view := re.storageView
|
||||||
|
|
||||||
|
|
|
@ -477,14 +477,27 @@ func (c *Core) handleLoginRequest(req *logical.Request) (retResp *logical.Respon
|
||||||
return nil, nil, ErrInternalError
|
return nil, nil, ErrInternalError
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set the default lease if not provided
|
// Start off with the sys default value, and update according to period/TTL
|
||||||
if auth.TTL == 0 {
|
// from resp.Auth
|
||||||
auth.TTL = sysView.DefaultLeaseTTL()
|
tokenTTL := sysView.DefaultLeaseTTL()
|
||||||
}
|
|
||||||
|
|
||||||
// Limit the lease duration
|
switch {
|
||||||
if auth.TTL > sysView.MaxLeaseTTL() {
|
case auth.Period > time.Duration(0):
|
||||||
auth.TTL = sysView.MaxLeaseTTL()
|
// Cap the period value to the sys max_ttl value. The auth backend should
|
||||||
|
// have checked for it on its login path, but we check here again for
|
||||||
|
// sanity.
|
||||||
|
if auth.Period > sysView.MaxLeaseTTL() {
|
||||||
|
auth.Period = sysView.MaxLeaseTTL()
|
||||||
|
}
|
||||||
|
tokenTTL = auth.Period
|
||||||
|
case auth.TTL > time.Duration(0):
|
||||||
|
// Cap the TTL value. The auth backend should have checked for it on its
|
||||||
|
// login path (e.g. a call to b.SanitizeTTL), but we check here again for
|
||||||
|
// sanity.
|
||||||
|
if auth.TTL > sysView.MaxLeaseTTL() {
|
||||||
|
auth.TTL = sysView.MaxLeaseTTL()
|
||||||
|
}
|
||||||
|
tokenTTL = auth.TTL
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate a token
|
// Generate a token
|
||||||
|
@ -494,7 +507,7 @@ func (c *Core) handleLoginRequest(req *logical.Request) (retResp *logical.Respon
|
||||||
Meta: auth.Metadata,
|
Meta: auth.Metadata,
|
||||||
DisplayName: auth.DisplayName,
|
DisplayName: auth.DisplayName,
|
||||||
CreationTime: time.Now().Unix(),
|
CreationTime: time.Now().Unix(),
|
||||||
TTL: auth.TTL,
|
TTL: tokenTTL,
|
||||||
NumUses: auth.NumUses,
|
NumUses: auth.NumUses,
|
||||||
EntityID: auth.EntityID,
|
EntityID: auth.EntityID,
|
||||||
}
|
}
|
||||||
|
@ -513,10 +526,11 @@ func (c *Core) handleLoginRequest(req *logical.Request) (retResp *logical.Respon
|
||||||
return nil, auth, ErrInternalError
|
return nil, auth, ErrInternalError
|
||||||
}
|
}
|
||||||
|
|
||||||
// Populate the client token and accessor
|
// Populate the client token, accessor, and TTL
|
||||||
auth.ClientToken = te.ID
|
auth.ClientToken = te.ID
|
||||||
auth.Accessor = te.Accessor
|
auth.Accessor = te.Accessor
|
||||||
auth.Policies = te.Policies
|
auth.Policies = te.Policies
|
||||||
|
auth.TTL = te.TTL
|
||||||
|
|
||||||
// Register with the expiration manager
|
// Register with the expiration manager
|
||||||
if err := c.expiration.RegisterAuth(te.Path, auth); err != nil {
|
if err := c.expiration.RegisterAuth(te.Path, auth); err != nil {
|
||||||
|
|
|
@ -379,6 +379,8 @@ func TestDynamicSystemView(c *Core) *dynamicSystemView {
|
||||||
return &dynamicSystemView{c, me}
|
return &dynamicSystemView{c, me}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestAddTestPlugin registers the testFunc as part of the plugin command to the
|
||||||
|
// plugin catalog.
|
||||||
func TestAddTestPlugin(t testing.T, c *Core, name, testFunc string) {
|
func TestAddTestPlugin(t testing.T, c *Core, name, testFunc string) {
|
||||||
file, err := os.Open(os.Args[0])
|
file, err := os.Open(os.Args[0])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -413,11 +415,74 @@ func TestAddTestPlugin(t testing.T, c *Core, name, testFunc string) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestAddTestPluginTempDir registers the testFunc as part of the plugin command to the
|
||||||
|
// plugin catalog. It uses tmpDir as the plugin directory.
|
||||||
|
func TestAddTestPluginTempDir(t testing.T, c *Core, name, testFunc, tempDir string) {
|
||||||
|
file, err := os.Open(os.Args[0])
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer file.Close()
|
||||||
|
|
||||||
|
fi, err := file.Stat()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy over the file to the temp dir
|
||||||
|
dst := filepath.Join(tempDir, filepath.Base(os.Args[0]))
|
||||||
|
out, err := os.OpenFile(dst, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, fi.Mode())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer out.Close()
|
||||||
|
|
||||||
|
if _, err = io.Copy(out, file); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
err = out.Sync()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Determine plugin directory full path
|
||||||
|
fullPath, err := filepath.EvalSymlinks(tempDir)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
reader, err := os.Open(filepath.Join(fullPath, filepath.Base(os.Args[0])))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer reader.Close()
|
||||||
|
|
||||||
|
// Find out the sha256
|
||||||
|
hash := sha256.New()
|
||||||
|
|
||||||
|
_, err = io.Copy(hash, reader)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sum := hash.Sum(nil)
|
||||||
|
|
||||||
|
// Set core's plugin directory and plugin catalog directory
|
||||||
|
c.pluginDirectory = fullPath
|
||||||
|
c.pluginCatalog.directory = fullPath
|
||||||
|
|
||||||
|
command := fmt.Sprintf("%s --test.run=%s", filepath.Base(os.Args[0]), testFunc)
|
||||||
|
err = c.pluginCatalog.Set(name, command, sum)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
var testLogicalBackends = map[string]logical.Factory{}
|
var testLogicalBackends = map[string]logical.Factory{}
|
||||||
var testCredentialBackends = map[string]logical.Factory{}
|
var testCredentialBackends = map[string]logical.Factory{}
|
||||||
|
|
||||||
// Starts the test server which responds to SSH authentication.
|
// StartSSHHostTestServer starts the test server which responds to SSH
|
||||||
// Used to test the SSH secret backend.
|
// authentication. Used to test the SSH secret backend.
|
||||||
func StartSSHHostTestServer() (string, error) {
|
func StartSSHHostTestServer() (string, error) {
|
||||||
pubKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(testSharedPublicKey))
|
pubKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(testSharedPublicKey))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -760,6 +825,7 @@ func (c *TestCluster) ensureCoresSealed() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UnsealWithStoredKeys uses stored keys to unseal the test cluster cores
|
||||||
func (c *TestCluster) UnsealWithStoredKeys(t testing.T) error {
|
func (c *TestCluster) UnsealWithStoredKeys(t testing.T) error {
|
||||||
for _, core := range c.Cores {
|
for _, core := range c.Cores {
|
||||||
if err := core.UnsealWithStoredKeys(); err != nil {
|
if err := core.UnsealWithStoredKeys(); err != nil {
|
||||||
|
|
|
@ -2312,7 +2312,7 @@ func TestTokenStore_RolePeriod(t *testing.T) {
|
||||||
req := logical.TestRequest(t, logical.UpdateOperation, "auth/token/roles/test")
|
req := logical.TestRequest(t, logical.UpdateOperation, "auth/token/roles/test")
|
||||||
req.ClientToken = root
|
req.ClientToken = root
|
||||||
req.Data = map[string]interface{}{
|
req.Data = map[string]interface{}{
|
||||||
"period": 300,
|
"period": 5,
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := core.HandleRequest(req)
|
resp, err := core.HandleRequest(req)
|
||||||
|
@ -2425,8 +2425,8 @@ func TestTokenStore_RolePeriod(t *testing.T) {
|
||||||
t.Fatalf("err: %v", err)
|
t.Fatalf("err: %v", err)
|
||||||
}
|
}
|
||||||
ttl := resp.Data["ttl"].(int64)
|
ttl := resp.Data["ttl"].(int64)
|
||||||
if ttl < 299 {
|
if ttl > 5 {
|
||||||
t.Fatalf("TTL too small (expected %d, got %d", 299, ttl)
|
t.Fatalf("TTL too large (expected %d, got %d", 5, ttl)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Let the TTL go down a bit to 3 seconds
|
// Let the TTL go down a bit to 3 seconds
|
||||||
|
@ -2449,8 +2449,8 @@ func TestTokenStore_RolePeriod(t *testing.T) {
|
||||||
t.Fatalf("err: %v", err)
|
t.Fatalf("err: %v", err)
|
||||||
}
|
}
|
||||||
ttl = resp.Data["ttl"].(int64)
|
ttl = resp.Data["ttl"].(int64)
|
||||||
if ttl < 299 {
|
if ttl > 5 {
|
||||||
t.Fatalf("TTL too small (expected %d, got %d", 299, ttl)
|
t.Fatalf("TTL too large (expected %d, got %d", 5, ttl)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -2677,7 +2677,7 @@ func TestTokenStore_Periodic(t *testing.T) {
|
||||||
req := logical.TestRequest(t, logical.UpdateOperation, "auth/token/roles/test")
|
req := logical.TestRequest(t, logical.UpdateOperation, "auth/token/roles/test")
|
||||||
req.ClientToken = root
|
req.ClientToken = root
|
||||||
req.Data = map[string]interface{}{
|
req.Data = map[string]interface{}{
|
||||||
"period": 300,
|
"period": 5,
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := core.HandleRequest(req)
|
resp, err := core.HandleRequest(req)
|
||||||
|
@ -2715,8 +2715,8 @@ func TestTokenStore_Periodic(t *testing.T) {
|
||||||
t.Fatalf("err: %v", err)
|
t.Fatalf("err: %v", err)
|
||||||
}
|
}
|
||||||
ttl := resp.Data["ttl"].(int64)
|
ttl := resp.Data["ttl"].(int64)
|
||||||
if ttl < 299 {
|
if ttl > 5 {
|
||||||
t.Fatalf("TTL too small (expected %d, got %d)", 299, ttl)
|
t.Fatalf("TTL too large (expected %d, got %d)", 5, ttl)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Let the TTL go down a bit
|
// Let the TTL go down a bit
|
||||||
|
@ -2739,8 +2739,8 @@ func TestTokenStore_Periodic(t *testing.T) {
|
||||||
t.Fatalf("err: %v", err)
|
t.Fatalf("err: %v", err)
|
||||||
}
|
}
|
||||||
ttl = resp.Data["ttl"].(int64)
|
ttl = resp.Data["ttl"].(int64)
|
||||||
if ttl < 299 {
|
if ttl > 5 {
|
||||||
t.Fatalf("TTL too small (expected %d, got %d)", 299, ttl)
|
t.Fatalf("TTL too large (expected %d, got %d)", 5, ttl)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2750,8 +2750,8 @@ func TestTokenStore_Periodic(t *testing.T) {
|
||||||
req.Operation = logical.UpdateOperation
|
req.Operation = logical.UpdateOperation
|
||||||
req.Path = "auth/token/create"
|
req.Path = "auth/token/create"
|
||||||
req.Data = map[string]interface{}{
|
req.Data = map[string]interface{}{
|
||||||
"period": 300,
|
"period": 5,
|
||||||
"explicit_max_ttl": 150,
|
"explicit_max_ttl": 4,
|
||||||
}
|
}
|
||||||
resp, err = core.HandleRequest(req)
|
resp, err = core.HandleRequest(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -2775,8 +2775,8 @@ func TestTokenStore_Periodic(t *testing.T) {
|
||||||
t.Fatalf("err: %v", err)
|
t.Fatalf("err: %v", err)
|
||||||
}
|
}
|
||||||
ttl := resp.Data["ttl"].(int64)
|
ttl := resp.Data["ttl"].(int64)
|
||||||
if ttl < 149 || ttl > 150 {
|
if ttl < 3 || ttl > 4 {
|
||||||
t.Fatalf("TTL bad (expected %d, got %d)", 149, ttl)
|
t.Fatalf("TTL bad (expected %d, got %d)", 3, ttl)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Let the TTL go down a bit
|
// Let the TTL go down a bit
|
||||||
|
@ -2799,8 +2799,8 @@ func TestTokenStore_Periodic(t *testing.T) {
|
||||||
t.Fatalf("err: %v", err)
|
t.Fatalf("err: %v", err)
|
||||||
}
|
}
|
||||||
ttl = resp.Data["ttl"].(int64)
|
ttl = resp.Data["ttl"].(int64)
|
||||||
if ttl < 140 || ttl > 150 {
|
if ttl > 2 {
|
||||||
t.Fatalf("TTL bad (expected around %d, got %d)", 145, ttl)
|
t.Fatalf("TTL bad (expected less than %d, got %d)", 2, ttl)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2812,7 +2812,7 @@ func TestTokenStore_Periodic(t *testing.T) {
|
||||||
req.Operation = logical.UpdateOperation
|
req.Operation = logical.UpdateOperation
|
||||||
req.Path = "auth/token/create/test"
|
req.Path = "auth/token/create/test"
|
||||||
req.Data = map[string]interface{}{
|
req.Data = map[string]interface{}{
|
||||||
"period": 150,
|
"period": 5,
|
||||||
}
|
}
|
||||||
resp, err = core.HandleRequest(req)
|
resp, err = core.HandleRequest(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -2836,8 +2836,8 @@ func TestTokenStore_Periodic(t *testing.T) {
|
||||||
t.Fatalf("err: %v", err)
|
t.Fatalf("err: %v", err)
|
||||||
}
|
}
|
||||||
ttl := resp.Data["ttl"].(int64)
|
ttl := resp.Data["ttl"].(int64)
|
||||||
if ttl < 149 || ttl > 150 {
|
if ttl < 4 || ttl > 5 {
|
||||||
t.Fatalf("TTL bad (expected %d, got %d)", 149, ttl)
|
t.Fatalf("TTL bad (expected %d, got %d)", 4, ttl)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Let the TTL go down a bit
|
// Let the TTL go down a bit
|
||||||
|
@ -2860,8 +2860,8 @@ func TestTokenStore_Periodic(t *testing.T) {
|
||||||
t.Fatalf("err: %v", err)
|
t.Fatalf("err: %v", err)
|
||||||
}
|
}
|
||||||
ttl = resp.Data["ttl"].(int64)
|
ttl = resp.Data["ttl"].(int64)
|
||||||
if ttl < 149 {
|
if ttl > 5 {
|
||||||
t.Fatalf("TTL bad (expected %d, got %d)", 149, ttl)
|
t.Fatalf("TTL bad (expected less than %d, got %d)", 5, ttl)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2869,18 +2869,23 @@ func TestTokenStore_Periodic(t *testing.T) {
|
||||||
{
|
{
|
||||||
req.Path = "auth/token/roles/test"
|
req.Path = "auth/token/roles/test"
|
||||||
req.ClientToken = root
|
req.ClientToken = root
|
||||||
|
req.Operation = logical.UpdateOperation
|
||||||
req.Data = map[string]interface{}{
|
req.Data = map[string]interface{}{
|
||||||
"period": 300,
|
"period": 5,
|
||||||
"explicit_max_ttl": 150,
|
"explicit_max_ttl": 4,
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := core.HandleRequest(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("err: %v %v", err, resp)
|
||||||
|
}
|
||||||
|
if resp != nil {
|
||||||
|
t.Fatalf("expected a nil response")
|
||||||
}
|
}
|
||||||
|
|
||||||
req.ClientToken = root
|
req.ClientToken = root
|
||||||
req.Operation = logical.UpdateOperation
|
req.Operation = logical.UpdateOperation
|
||||||
req.Path = "auth/token/create/test"
|
req.Path = "auth/token/create/test"
|
||||||
req.Data = map[string]interface{}{
|
|
||||||
"period": 150,
|
|
||||||
"explicit_max_ttl": 130,
|
|
||||||
}
|
|
||||||
resp, err = core.HandleRequest(req)
|
resp, err = core.HandleRequest(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %v %v", err, resp)
|
t.Fatalf("err: %v %v", err, resp)
|
||||||
|
@ -2903,12 +2908,12 @@ func TestTokenStore_Periodic(t *testing.T) {
|
||||||
t.Fatalf("err: %v", err)
|
t.Fatalf("err: %v", err)
|
||||||
}
|
}
|
||||||
ttl := resp.Data["ttl"].(int64)
|
ttl := resp.Data["ttl"].(int64)
|
||||||
if ttl < 129 || ttl > 130 {
|
if ttl < 3 || ttl > 4 {
|
||||||
t.Fatalf("TTL bad (expected %d, got %d)", 129, ttl)
|
t.Fatalf("TTL bad (expected %d, got %d)", 3, ttl)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Let the TTL go down a bit
|
// Let the TTL go down a bit
|
||||||
time.Sleep(4 * time.Second)
|
time.Sleep(2 * time.Second)
|
||||||
|
|
||||||
req.Operation = logical.UpdateOperation
|
req.Operation = logical.UpdateOperation
|
||||||
req.Path = "auth/token/renew-self"
|
req.Path = "auth/token/renew-self"
|
||||||
|
@ -2927,8 +2932,8 @@ func TestTokenStore_Periodic(t *testing.T) {
|
||||||
t.Fatalf("err: %v", err)
|
t.Fatalf("err: %v", err)
|
||||||
}
|
}
|
||||||
ttl = resp.Data["ttl"].(int64)
|
ttl = resp.Data["ttl"].(int64)
|
||||||
if ttl > 127 {
|
if ttl > 2 {
|
||||||
t.Fatalf("TTL bad (expected < %d, got %d)", 128, ttl)
|
t.Fatalf("TTL bad (expected less than %d, got %d)", 2, ttl)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue