diff --git a/CHANGELOG.md b/CHANGELOG.md index 280b07b05..b2d7c0532 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -33,6 +33,8 @@ FEATURES: operation that can export a given key, including all key versions and configuration, as well as a restore operation allowing import into another Vault. + * **gRPC Database Plugins**: Database plugins now use gRPC for transport, + allowing them to be written in other languages. IMPROVEMENTS: @@ -46,6 +48,10 @@ IMPROVEMENTS: during database configuration. This establishes a session-wide [write concern](https://docs.mongodb.com/manual/reference/write-concern/) for the lifecycle of the mount [GH-3646] + * mfa/okta: Filter a given email address as a login filter, allowing operation + when login email and account email are different + * plugins: Make Vault more resilient when unsealing when plugins are + unavailable [GH-3686] * secret/pki: `allowed_domains` and `key_usage` can now be specified as a comma-separated string or an array of strings [GH-3642] * secret/ssh: Allow 4096-bit keys to be used in dynamic key method [GH-3593] @@ -58,8 +64,12 @@ BUG FIXES: * auth/cert: Return `allowed_names` on role read [GH-3654] * auth/ldap: Fix incorrect control information being sent [GH-3402] [GH-3496] [GH-3625] [GH-3656] + * core: Fix seal status reporting when using an autoseal + * core: Add creation path to wrap info for a control group token * core: Fix potential panic that could occur using plugins when a node transitioned from active to standby [GH-3638] + * core: Fix memory ballooning when a connection would connect to the cluster + port and then go away -- redux! [GH-3680] * core: Replace recursive token revocation logic with depth-first logic, which can avoid hitting stack depth limits in extreme cases [GH-2348] * core/pkcs11 (enterprise): Fix panic when PKCS#11 library is not readable diff --git a/Makefile b/Makefile index 4e228c19c..60fcc39c4 100644 --- a/Makefile +++ b/Makefile @@ -84,6 +84,7 @@ proto: protoc -I helper/forwarding -I vault -I ../../.. helper/forwarding/types.proto --go_out=plugins=grpc:helper/forwarding protoc -I physical physical/types.proto --go_out=plugins=grpc:physical protoc -I helper/identity -I ../../.. helper/identity/types.proto --go_out=plugins=grpc:helper/identity + protoc builtin/logical/database/dbplugin/*.proto --go_out=plugins=grpc:. sed -i -e 's/Idp/IDP/' -e 's/Url/URL/' -e 's/Id/ID/' -e 's/EntityId/EntityID/' -e 's/Api/API/' -e 's/Qr/QR/' -e 's/protobuf:"/sentinel:"" protobuf:"/' helper/identity/types.pb.go helper/storagepacker/types.pb.go sed -i -e 's/Iv/IV/' -e 's/Hmac/HMAC/' physical/types.pb.go diff --git a/builtin/credential/app-id/backend_test.go b/builtin/credential/app-id/backend_test.go index 4ae5d3e1c..e5d335b4f 100644 --- a/builtin/credential/app-id/backend_test.go +++ b/builtin/credential/app-id/backend_test.go @@ -141,7 +141,7 @@ func testAccStepMapUserIdCidr(t *testing.T, cidr string) logicaltest.TestStep { func testAccLogin(t *testing.T, display string) logicaltest.TestStep { checkTTL := func(resp *logical.Response) error { if resp.Auth.LeaseOptions.TTL.String() != "768h0m0s" { - return fmt.Errorf("invalid TTL") + return fmt.Errorf("invalid TTL: got %s", resp.Auth.LeaseOptions.TTL) } return nil } @@ -165,7 +165,7 @@ func testAccLogin(t *testing.T, display string) logicaltest.TestStep { func testAccLoginAppIDInPath(t *testing.T, display string) logicaltest.TestStep { checkTTL := func(resp *logical.Response) error { if resp.Auth.LeaseOptions.TTL.String() != "768h0m0s" { - return fmt.Errorf("invalid TTL") + return fmt.Errorf("invalid TTL: got %s", resp.Auth.LeaseOptions.TTL) } return nil } diff --git a/builtin/credential/approle/path_login.go b/builtin/credential/approle/path_login.go index 300ee9409..3dd829a84 100644 --- a/builtin/credential/approle/path_login.go +++ b/builtin/credential/approle/path_login.go @@ -3,7 +3,6 @@ package approle import ( "fmt" "strings" - "time" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" @@ -68,20 +67,13 @@ func (b *backend) pathLoginUpdate(req *logical.Request, data *framework.FieldDat Policies: role.Policies, LeaseOptions: logical.LeaseOptions{ Renewable: true, + TTL: role.TokenTTL, }, Alias: &logical.Alias{ Name: role.RoleID, }, } - // If 'Period' is set, use the value of 'Period' as the TTL. - // Otherwise, set the normal TokenTTL. - if role.Period > time.Duration(0) { - auth.TTL = role.Period - } else { - auth.TTL = role.TokenTTL - } - return &logical.Response{ Auth: auth, }, nil @@ -107,16 +99,12 @@ func (b *backend) pathLoginRenew(req *logical.Request, data *framework.FieldData return nil, fmt.Errorf("role %s does not exist during renewal", roleName) } - // If 'Period' is set on the Role, the token should never expire. - // Replenish the TTL with 'Period's value. - if role.Period > time.Duration(0) { - // If 'Period' was updated after the token was issued, - // token will bear the updated 'Period' value as its TTL. - req.Auth.TTL = role.Period - return &logical.Response{Auth: req.Auth}, nil - } else { - return framework.LeaseExtend(role.TokenTTL, role.TokenMaxTTL, b.System())(req, data) + resp, err := framework.LeaseExtend(role.TokenTTL, role.TokenMaxTTL, b.System())(req, data) + if err != nil { + return nil, err } + resp.Auth.Period = role.Period + return resp, nil } const pathLoginHelpSys = "Issue a token based on the credentials supplied" diff --git a/builtin/logical/database/backend.go b/builtin/logical/database/backend.go index a8a54a7cd..d605153a4 100644 --- a/builtin/logical/database/backend.go +++ b/builtin/logical/database/backend.go @@ -1,6 +1,7 @@ package database import ( + "context" "fmt" "net/rpc" "strings" @@ -87,7 +88,7 @@ func (b *databaseBackend) getDBObj(name string) (dbplugin.Database, bool) { // This function creates a new db object from the stored configuration and // caches it in the connections map. The caller of this function needs to hold // the backend's write lock -func (b *databaseBackend) createDBObj(s logical.Storage, name string) (dbplugin.Database, error) { +func (b *databaseBackend) createDBObj(ctx context.Context, s logical.Storage, name string) (dbplugin.Database, error) { db, ok := b.connections[name] if ok { return db, nil @@ -103,7 +104,7 @@ func (b *databaseBackend) createDBObj(s logical.Storage, name string) (dbplugin. return nil, err } - err = db.Initialize(config.ConnectionDetails, true) + err = db.Initialize(ctx, config.ConnectionDetails, true) if err != nil { return nil, err } @@ -170,7 +171,8 @@ func (b *databaseBackend) clearConnection(name string) { func (b *databaseBackend) closeIfShutdown(name string, err error) { // Plugin has shutdown, close it so next call can reconnect. - if err == rpc.ErrShutdown { + switch err { + case rpc.ErrShutdown, dbplugin.ErrPluginShutdown: b.Lock() b.clearConnection(name) b.Unlock() diff --git a/builtin/logical/database/backend_test.go b/builtin/logical/database/backend_test.go index f4dfd3b48..35d3639cd 100644 --- a/builtin/logical/database/backend_test.go +++ b/builtin/logical/database/backend_test.go @@ -488,9 +488,11 @@ func TestBackend_roleCrud(t *testing.T) { RevocationStatements: defaultRevocationSQL, } - var actual dbplugin.Statements - if err := mapstructure.Decode(resp.Data, &actual); err != nil { - t.Fatal(err) + actual := dbplugin.Statements{ + CreationStatements: resp.Data["creation_statements"].(string), + RevocationStatements: resp.Data["revocation_statements"].(string), + RollbackStatements: resp.Data["rollback_statements"].(string), + RenewStatements: resp.Data["renew_statements"].(string), } if !reflect.DeepEqual(expected, actual) { diff --git a/builtin/logical/database/dbplugin/client.go b/builtin/logical/database/dbplugin/client.go index 6df39489f..1d36386bc 100644 --- a/builtin/logical/database/dbplugin/client.go +++ b/builtin/logical/database/dbplugin/client.go @@ -1,10 +1,8 @@ package dbplugin import ( - "fmt" - "net/rpc" + "errors" "sync" - "time" "github.com/hashicorp/go-plugin" "github.com/hashicorp/vault/helper/pluginutil" @@ -17,11 +15,11 @@ type DatabasePluginClient struct { client *plugin.Client sync.Mutex - *databasePluginRPCClient + Database } func (dc *DatabasePluginClient) Close() error { - err := dc.databasePluginRPCClient.Close() + err := dc.Database.Close() dc.client.Kill() return err @@ -55,79 +53,20 @@ func newPluginClient(sys pluginutil.RunnerUtil, pluginRunner *pluginutil.PluginR // We should have a database type now. This feels like a normal interface // implementation but is in fact over an RPC connection. - databaseRPC := raw.(*databasePluginRPCClient) + var db Database + switch raw.(type) { + case *gRPCClient: + db = raw.(*gRPCClient) + case *databasePluginRPCClient: + logger.Warn("database: plugin is using deprecated net RPC transport, recompile plugin to upgrade to gRPC", "plugin", pluginRunner.Name) + db = raw.(*databasePluginRPCClient) + default: + return nil, errors.New("unsupported client type") + } // Wrap RPC implimentation in DatabasePluginClient return &DatabasePluginClient{ - client: client, - databasePluginRPCClient: databaseRPC, + client: client, + Database: db, }, nil } - -// ---- RPC client domain ---- - -// databasePluginRPCClient implements Database and is used on the client to -// make RPC calls to a plugin. -type databasePluginRPCClient struct { - client *rpc.Client -} - -func (dr *databasePluginRPCClient) Type() (string, error) { - var dbType string - err := dr.client.Call("Plugin.Type", struct{}{}, &dbType) - - return fmt.Sprintf("plugin-%s", dbType), err -} - -func (dr *databasePluginRPCClient) CreateUser(statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) { - req := CreateUserRequest{ - Statements: statements, - UsernameConfig: usernameConfig, - Expiration: expiration, - } - - var resp CreateUserResponse - err = dr.client.Call("Plugin.CreateUser", req, &resp) - - return resp.Username, resp.Password, err -} - -func (dr *databasePluginRPCClient) RenewUser(statements Statements, username string, expiration time.Time) error { - req := RenewUserRequest{ - Statements: statements, - Username: username, - Expiration: expiration, - } - - err := dr.client.Call("Plugin.RenewUser", req, &struct{}{}) - - return err -} - -func (dr *databasePluginRPCClient) RevokeUser(statements Statements, username string) error { - req := RevokeUserRequest{ - Statements: statements, - Username: username, - } - - err := dr.client.Call("Plugin.RevokeUser", req, &struct{}{}) - - return err -} - -func (dr *databasePluginRPCClient) Initialize(conf map[string]interface{}, verifyConnection bool) error { - req := InitializeRequest{ - Config: conf, - VerifyConnection: verifyConnection, - } - - err := dr.client.Call("Plugin.Initialize", req, &struct{}{}) - - return err -} - -func (dr *databasePluginRPCClient) Close() error { - err := dr.client.Call("Plugin.Close", struct{}{}, &struct{}{}) - - return err -} diff --git a/builtin/logical/database/dbplugin/database.pb.go b/builtin/logical/database/dbplugin/database.pb.go new file mode 100644 index 000000000..c4c410196 --- /dev/null +++ b/builtin/logical/database/dbplugin/database.pb.go @@ -0,0 +1,556 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// source: builtin/logical/database/dbplugin/database.proto + +/* +Package dbplugin is a generated protocol buffer package. + +It is generated from these files: + builtin/logical/database/dbplugin/database.proto + +It has these top-level messages: + InitializeRequest + CreateUserRequest + RenewUserRequest + RevokeUserRequest + Statements + UsernameConfig + CreateUserResponse + TypeResponse + Empty +*/ +package dbplugin + +import proto "github.com/golang/protobuf/proto" +import fmt "fmt" +import math "math" +import google_protobuf "github.com/golang/protobuf/ptypes/timestamp" + +import ( + context "golang.org/x/net/context" + grpc "google.golang.org/grpc" +) + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. +const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package + +type InitializeRequest struct { + Config []byte `protobuf:"bytes,1,opt,name=config,proto3" json:"config,omitempty"` + VerifyConnection bool `protobuf:"varint,2,opt,name=verify_connection,json=verifyConnection" json:"verify_connection,omitempty"` +} + +func (m *InitializeRequest) Reset() { *m = InitializeRequest{} } +func (m *InitializeRequest) String() string { return proto.CompactTextString(m) } +func (*InitializeRequest) ProtoMessage() {} +func (*InitializeRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{0} } + +func (m *InitializeRequest) GetConfig() []byte { + if m != nil { + return m.Config + } + return nil +} + +func (m *InitializeRequest) GetVerifyConnection() bool { + if m != nil { + return m.VerifyConnection + } + return false +} + +type CreateUserRequest struct { + Statements *Statements `protobuf:"bytes,1,opt,name=statements" json:"statements,omitempty"` + UsernameConfig *UsernameConfig `protobuf:"bytes,2,opt,name=username_config,json=usernameConfig" json:"username_config,omitempty"` + Expiration *google_protobuf.Timestamp `protobuf:"bytes,3,opt,name=expiration" json:"expiration,omitempty"` +} + +func (m *CreateUserRequest) Reset() { *m = CreateUserRequest{} } +func (m *CreateUserRequest) String() string { return proto.CompactTextString(m) } +func (*CreateUserRequest) ProtoMessage() {} +func (*CreateUserRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{1} } + +func (m *CreateUserRequest) GetStatements() *Statements { + if m != nil { + return m.Statements + } + return nil +} + +func (m *CreateUserRequest) GetUsernameConfig() *UsernameConfig { + if m != nil { + return m.UsernameConfig + } + return nil +} + +func (m *CreateUserRequest) GetExpiration() *google_protobuf.Timestamp { + if m != nil { + return m.Expiration + } + return nil +} + +type RenewUserRequest struct { + Statements *Statements `protobuf:"bytes,1,opt,name=statements" json:"statements,omitempty"` + Username string `protobuf:"bytes,2,opt,name=username" json:"username,omitempty"` + Expiration *google_protobuf.Timestamp `protobuf:"bytes,3,opt,name=expiration" json:"expiration,omitempty"` +} + +func (m *RenewUserRequest) Reset() { *m = RenewUserRequest{} } +func (m *RenewUserRequest) String() string { return proto.CompactTextString(m) } +func (*RenewUserRequest) ProtoMessage() {} +func (*RenewUserRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{2} } + +func (m *RenewUserRequest) GetStatements() *Statements { + if m != nil { + return m.Statements + } + return nil +} + +func (m *RenewUserRequest) GetUsername() string { + if m != nil { + return m.Username + } + return "" +} + +func (m *RenewUserRequest) GetExpiration() *google_protobuf.Timestamp { + if m != nil { + return m.Expiration + } + return nil +} + +type RevokeUserRequest struct { + Statements *Statements `protobuf:"bytes,1,opt,name=statements" json:"statements,omitempty"` + Username string `protobuf:"bytes,2,opt,name=username" json:"username,omitempty"` +} + +func (m *RevokeUserRequest) Reset() { *m = RevokeUserRequest{} } +func (m *RevokeUserRequest) String() string { return proto.CompactTextString(m) } +func (*RevokeUserRequest) ProtoMessage() {} +func (*RevokeUserRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{3} } + +func (m *RevokeUserRequest) GetStatements() *Statements { + if m != nil { + return m.Statements + } + return nil +} + +func (m *RevokeUserRequest) GetUsername() string { + if m != nil { + return m.Username + } + return "" +} + +type Statements struct { + CreationStatements string `protobuf:"bytes,1,opt,name=creation_statements,json=creationStatements" json:"creation_statements,omitempty"` + RevocationStatements string `protobuf:"bytes,2,opt,name=revocation_statements,json=revocationStatements" json:"revocation_statements,omitempty"` + RollbackStatements string `protobuf:"bytes,3,opt,name=rollback_statements,json=rollbackStatements" json:"rollback_statements,omitempty"` + RenewStatements string `protobuf:"bytes,4,opt,name=renew_statements,json=renewStatements" json:"renew_statements,omitempty"` +} + +func (m *Statements) Reset() { *m = Statements{} } +func (m *Statements) String() string { return proto.CompactTextString(m) } +func (*Statements) ProtoMessage() {} +func (*Statements) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{4} } + +func (m *Statements) GetCreationStatements() string { + if m != nil { + return m.CreationStatements + } + return "" +} + +func (m *Statements) GetRevocationStatements() string { + if m != nil { + return m.RevocationStatements + } + return "" +} + +func (m *Statements) GetRollbackStatements() string { + if m != nil { + return m.RollbackStatements + } + return "" +} + +func (m *Statements) GetRenewStatements() string { + if m != nil { + return m.RenewStatements + } + return "" +} + +type UsernameConfig struct { + DisplayName string `protobuf:"bytes,1,opt,name=DisplayName" json:"DisplayName,omitempty"` + RoleName string `protobuf:"bytes,2,opt,name=RoleName" json:"RoleName,omitempty"` +} + +func (m *UsernameConfig) Reset() { *m = UsernameConfig{} } +func (m *UsernameConfig) String() string { return proto.CompactTextString(m) } +func (*UsernameConfig) ProtoMessage() {} +func (*UsernameConfig) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{5} } + +func (m *UsernameConfig) GetDisplayName() string { + if m != nil { + return m.DisplayName + } + return "" +} + +func (m *UsernameConfig) GetRoleName() string { + if m != nil { + return m.RoleName + } + return "" +} + +type CreateUserResponse struct { + Username string `protobuf:"bytes,1,opt,name=username" json:"username,omitempty"` + Password string `protobuf:"bytes,2,opt,name=password" json:"password,omitempty"` +} + +func (m *CreateUserResponse) Reset() { *m = CreateUserResponse{} } +func (m *CreateUserResponse) String() string { return proto.CompactTextString(m) } +func (*CreateUserResponse) ProtoMessage() {} +func (*CreateUserResponse) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{6} } + +func (m *CreateUserResponse) GetUsername() string { + if m != nil { + return m.Username + } + return "" +} + +func (m *CreateUserResponse) GetPassword() string { + if m != nil { + return m.Password + } + return "" +} + +type TypeResponse struct { + Type string `protobuf:"bytes,1,opt,name=type" json:"type,omitempty"` +} + +func (m *TypeResponse) Reset() { *m = TypeResponse{} } +func (m *TypeResponse) String() string { return proto.CompactTextString(m) } +func (*TypeResponse) ProtoMessage() {} +func (*TypeResponse) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{7} } + +func (m *TypeResponse) GetType() string { + if m != nil { + return m.Type + } + return "" +} + +type Empty struct { +} + +func (m *Empty) Reset() { *m = Empty{} } +func (m *Empty) String() string { return proto.CompactTextString(m) } +func (*Empty) ProtoMessage() {} +func (*Empty) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{8} } + +func init() { + proto.RegisterType((*InitializeRequest)(nil), "dbplugin.InitializeRequest") + proto.RegisterType((*CreateUserRequest)(nil), "dbplugin.CreateUserRequest") + proto.RegisterType((*RenewUserRequest)(nil), "dbplugin.RenewUserRequest") + proto.RegisterType((*RevokeUserRequest)(nil), "dbplugin.RevokeUserRequest") + proto.RegisterType((*Statements)(nil), "dbplugin.Statements") + proto.RegisterType((*UsernameConfig)(nil), "dbplugin.UsernameConfig") + proto.RegisterType((*CreateUserResponse)(nil), "dbplugin.CreateUserResponse") + proto.RegisterType((*TypeResponse)(nil), "dbplugin.TypeResponse") + proto.RegisterType((*Empty)(nil), "dbplugin.Empty") +} + +// Reference imports to suppress errors if they are not otherwise used. +var _ context.Context +var _ grpc.ClientConn + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +const _ = grpc.SupportPackageIsVersion4 + +// Client API for Database service + +type DatabaseClient interface { + Type(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*TypeResponse, error) + CreateUser(ctx context.Context, in *CreateUserRequest, opts ...grpc.CallOption) (*CreateUserResponse, error) + RenewUser(ctx context.Context, in *RenewUserRequest, opts ...grpc.CallOption) (*Empty, error) + RevokeUser(ctx context.Context, in *RevokeUserRequest, opts ...grpc.CallOption) (*Empty, error) + Initialize(ctx context.Context, in *InitializeRequest, opts ...grpc.CallOption) (*Empty, error) + Close(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*Empty, error) +} + +type databaseClient struct { + cc *grpc.ClientConn +} + +func NewDatabaseClient(cc *grpc.ClientConn) DatabaseClient { + return &databaseClient{cc} +} + +func (c *databaseClient) Type(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*TypeResponse, error) { + out := new(TypeResponse) + err := grpc.Invoke(ctx, "/dbplugin.Database/Type", in, out, c.cc, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *databaseClient) CreateUser(ctx context.Context, in *CreateUserRequest, opts ...grpc.CallOption) (*CreateUserResponse, error) { + out := new(CreateUserResponse) + err := grpc.Invoke(ctx, "/dbplugin.Database/CreateUser", in, out, c.cc, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *databaseClient) RenewUser(ctx context.Context, in *RenewUserRequest, opts ...grpc.CallOption) (*Empty, error) { + out := new(Empty) + err := grpc.Invoke(ctx, "/dbplugin.Database/RenewUser", in, out, c.cc, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *databaseClient) RevokeUser(ctx context.Context, in *RevokeUserRequest, opts ...grpc.CallOption) (*Empty, error) { + out := new(Empty) + err := grpc.Invoke(ctx, "/dbplugin.Database/RevokeUser", in, out, c.cc, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *databaseClient) Initialize(ctx context.Context, in *InitializeRequest, opts ...grpc.CallOption) (*Empty, error) { + out := new(Empty) + err := grpc.Invoke(ctx, "/dbplugin.Database/Initialize", in, out, c.cc, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *databaseClient) Close(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*Empty, error) { + out := new(Empty) + err := grpc.Invoke(ctx, "/dbplugin.Database/Close", in, out, c.cc, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +// Server API for Database service + +type DatabaseServer interface { + Type(context.Context, *Empty) (*TypeResponse, error) + CreateUser(context.Context, *CreateUserRequest) (*CreateUserResponse, error) + RenewUser(context.Context, *RenewUserRequest) (*Empty, error) + RevokeUser(context.Context, *RevokeUserRequest) (*Empty, error) + Initialize(context.Context, *InitializeRequest) (*Empty, error) + Close(context.Context, *Empty) (*Empty, error) +} + +func RegisterDatabaseServer(s *grpc.Server, srv DatabaseServer) { + s.RegisterService(&_Database_serviceDesc, srv) +} + +func _Database_Type_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(Empty) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(DatabaseServer).Type(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/dbplugin.Database/Type", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(DatabaseServer).Type(ctx, req.(*Empty)) + } + return interceptor(ctx, in, info, handler) +} + +func _Database_CreateUser_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(CreateUserRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(DatabaseServer).CreateUser(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/dbplugin.Database/CreateUser", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(DatabaseServer).CreateUser(ctx, req.(*CreateUserRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _Database_RenewUser_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(RenewUserRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(DatabaseServer).RenewUser(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/dbplugin.Database/RenewUser", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(DatabaseServer).RenewUser(ctx, req.(*RenewUserRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _Database_RevokeUser_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(RevokeUserRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(DatabaseServer).RevokeUser(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/dbplugin.Database/RevokeUser", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(DatabaseServer).RevokeUser(ctx, req.(*RevokeUserRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _Database_Initialize_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(InitializeRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(DatabaseServer).Initialize(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/dbplugin.Database/Initialize", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(DatabaseServer).Initialize(ctx, req.(*InitializeRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _Database_Close_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(Empty) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(DatabaseServer).Close(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/dbplugin.Database/Close", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(DatabaseServer).Close(ctx, req.(*Empty)) + } + return interceptor(ctx, in, info, handler) +} + +var _Database_serviceDesc = grpc.ServiceDesc{ + ServiceName: "dbplugin.Database", + HandlerType: (*DatabaseServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "Type", + Handler: _Database_Type_Handler, + }, + { + MethodName: "CreateUser", + Handler: _Database_CreateUser_Handler, + }, + { + MethodName: "RenewUser", + Handler: _Database_RenewUser_Handler, + }, + { + MethodName: "RevokeUser", + Handler: _Database_RevokeUser_Handler, + }, + { + MethodName: "Initialize", + Handler: _Database_Initialize_Handler, + }, + { + MethodName: "Close", + Handler: _Database_Close_Handler, + }, + }, + Streams: []grpc.StreamDesc{}, + Metadata: "builtin/logical/database/dbplugin/database.proto", +} + +func init() { proto.RegisterFile("builtin/logical/database/dbplugin/database.proto", fileDescriptor0) } + +var fileDescriptor0 = []byte{ + // 548 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xb4, 0x54, 0xcf, 0x6e, 0xd3, 0x4e, + 0x10, 0x96, 0xdb, 0xb4, 0xbf, 0x64, 0x5a, 0x35, 0xc9, 0xfe, 0x4a, 0x15, 0x19, 0x24, 0x22, 0x9f, + 0x5a, 0x21, 0xd9, 0xa8, 0xe5, 0x80, 0xb8, 0xa1, 0x14, 0x21, 0x24, 0x94, 0x83, 0x69, 0x25, 0x6e, + 0xd1, 0xda, 0x99, 0x44, 0xab, 0x3a, 0xbb, 0xc6, 0xbb, 0x4e, 0x09, 0x4f, 0xc3, 0xe3, 0x70, 0xe2, + 0x1d, 0x78, 0x13, 0xe4, 0x75, 0xd6, 0xbb, 0xf9, 0x73, 0xab, 0xb8, 0x79, 0x66, 0xbe, 0x6f, 0xf6, + 0xf3, 0xb7, 0x33, 0x0b, 0xaf, 0x93, 0x92, 0x65, 0x8a, 0xf1, 0x28, 0x13, 0x73, 0x96, 0xd2, 0x2c, + 0x9a, 0x52, 0x45, 0x13, 0x2a, 0x31, 0x9a, 0x26, 0x79, 0x56, 0xce, 0x19, 0x6f, 0x32, 0x61, 0x5e, + 0x08, 0x25, 0x48, 0xdb, 0x14, 0xfc, 0x97, 0x73, 0x21, 0xe6, 0x19, 0x46, 0x3a, 0x9f, 0x94, 0xb3, + 0x48, 0xb1, 0x05, 0x4a, 0x45, 0x17, 0x79, 0x0d, 0x0d, 0xbe, 0x42, 0xff, 0x13, 0x67, 0x8a, 0xd1, + 0x8c, 0xfd, 0xc0, 0x18, 0xbf, 0x95, 0x28, 0x15, 0xb9, 0x80, 0xe3, 0x54, 0xf0, 0x19, 0x9b, 0x0f, + 0xbc, 0xa1, 0x77, 0x79, 0x1a, 0xaf, 0x23, 0xf2, 0x0a, 0xfa, 0x4b, 0x2c, 0xd8, 0x6c, 0x35, 0x49, + 0x05, 0xe7, 0x98, 0x2a, 0x26, 0xf8, 0xe0, 0x60, 0xe8, 0x5d, 0xb6, 0xe3, 0x5e, 0x5d, 0x18, 0x35, + 0xf9, 0xe0, 0x97, 0x07, 0xfd, 0x51, 0x81, 0x54, 0xe1, 0xbd, 0xc4, 0xc2, 0xb4, 0x7e, 0x03, 0x20, + 0x15, 0x55, 0xb8, 0x40, 0xae, 0xa4, 0x6e, 0x7f, 0x72, 0x7d, 0x1e, 0x1a, 0xbd, 0xe1, 0x97, 0xa6, + 0x16, 0x3b, 0x38, 0xf2, 0x1e, 0xba, 0xa5, 0xc4, 0x82, 0xd3, 0x05, 0x4e, 0xd6, 0xca, 0x0e, 0x34, + 0x75, 0x60, 0xa9, 0xf7, 0x6b, 0xc0, 0x48, 0xd7, 0xe3, 0xb3, 0x72, 0x23, 0x26, 0xef, 0x00, 0xf0, + 0x7b, 0xce, 0x0a, 0xaa, 0x45, 0x1f, 0x6a, 0xb6, 0x1f, 0xd6, 0xf6, 0x84, 0xc6, 0x9e, 0xf0, 0xce, + 0xd8, 0x13, 0x3b, 0xe8, 0xe0, 0xa7, 0x07, 0xbd, 0x18, 0x39, 0x3e, 0x3e, 0xfd, 0x4f, 0x7c, 0x68, + 0x1b, 0x61, 0xfa, 0x17, 0x3a, 0x71, 0x13, 0x3f, 0x49, 0x22, 0x42, 0x3f, 0xc6, 0xa5, 0x78, 0xc0, + 0x7f, 0x2a, 0x31, 0xf8, 0xed, 0x01, 0x58, 0x1a, 0x89, 0xe0, 0xff, 0xb4, 0xba, 0x62, 0x26, 0xf8, + 0x64, 0xeb, 0xa4, 0x4e, 0x4c, 0x4c, 0xc9, 0x21, 0xdc, 0xc0, 0xb3, 0x02, 0x97, 0x22, 0xdd, 0xa1, + 0xd4, 0x07, 0x9d, 0xdb, 0xe2, 0xe6, 0x29, 0x85, 0xc8, 0xb2, 0x84, 0xa6, 0x0f, 0x2e, 0xe5, 0xb0, + 0x3e, 0xc5, 0x94, 0x1c, 0xc2, 0x15, 0xf4, 0x8a, 0xea, 0xba, 0x5c, 0x74, 0x4b, 0xa3, 0xbb, 0x3a, + 0x6f, 0xa1, 0xc1, 0x18, 0xce, 0x36, 0x07, 0x87, 0x0c, 0xe1, 0xe4, 0x96, 0xc9, 0x3c, 0xa3, 0xab, + 0x71, 0xe5, 0x40, 0xfd, 0x2f, 0x6e, 0xaa, 0x32, 0x28, 0x16, 0x19, 0x8e, 0x1d, 0x83, 0x4c, 0x1c, + 0x7c, 0x06, 0xe2, 0x0e, 0xbd, 0xcc, 0x05, 0x97, 0xb8, 0x61, 0xa9, 0xb7, 0x75, 0xeb, 0x3e, 0xb4, + 0x73, 0x2a, 0xe5, 0xa3, 0x28, 0xa6, 0xa6, 0x9b, 0x89, 0x83, 0x00, 0x4e, 0xef, 0x56, 0x39, 0x36, + 0x7d, 0x08, 0xb4, 0xd4, 0x2a, 0x37, 0x3d, 0xf4, 0x77, 0xf0, 0x1f, 0x1c, 0x7d, 0x58, 0xe4, 0x6a, + 0x75, 0xfd, 0xe7, 0x00, 0xda, 0xb7, 0xeb, 0x87, 0x80, 0x44, 0xd0, 0xaa, 0x98, 0xa4, 0x6b, 0xaf, + 0x5b, 0xa3, 0xfc, 0x0b, 0x9b, 0xd8, 0x68, 0xfd, 0x11, 0xc0, 0x0a, 0x27, 0xcf, 0x2d, 0x6a, 0x67, + 0x87, 0xfd, 0x17, 0xfb, 0x8b, 0xeb, 0x46, 0x6f, 0xa1, 0xd3, 0xec, 0x0a, 0xf1, 0x2d, 0x74, 0x7b, + 0x81, 0xfc, 0x6d, 0x69, 0xd5, 0xfc, 0xdb, 0x19, 0x76, 0x25, 0xec, 0x4c, 0xf6, 0x5e, 0xae, 0x7d, + 0xc7, 0x5c, 0xee, 0xce, 0xeb, 0xb6, 0xcb, 0xbd, 0x82, 0xa3, 0x51, 0x26, 0xe4, 0x1e, 0xb3, 0xb6, + 0x13, 0xc9, 0xb1, 0x5e, 0xc3, 0x9b, 0xbf, 0x01, 0x00, 0x00, 0xff, 0xff, 0x8c, 0x55, 0x84, 0x56, + 0x94, 0x05, 0x00, 0x00, +} diff --git a/builtin/logical/database/dbplugin/database.proto b/builtin/logical/database/dbplugin/database.proto new file mode 100644 index 000000000..d5e7d4068 --- /dev/null +++ b/builtin/logical/database/dbplugin/database.proto @@ -0,0 +1,58 @@ +syntax = "proto3"; +package dbplugin; + +import "google/protobuf/timestamp.proto"; + +message InitializeRequest { + bytes config = 1; + bool verify_connection = 2; +} + +message CreateUserRequest { + Statements statements = 1; + UsernameConfig username_config = 2; + google.protobuf.Timestamp expiration = 3; +} + +message RenewUserRequest { + Statements statements = 1; + string username = 2; + google.protobuf.Timestamp expiration = 3; +} + +message RevokeUserRequest { + Statements statements = 1; + string username = 2; +} + +message Statements { + string creation_statements = 1; + string revocation_statements = 2; + string rollback_statements = 3; + string renew_statements = 4; +} + +message UsernameConfig { + string DisplayName = 1; + string RoleName = 2; +} + +message CreateUserResponse { + string username = 1; + string password = 2; +} + +message TypeResponse { + string type = 1; +} + +message Empty {} + +service Database { + rpc Type(Empty) returns (TypeResponse); + rpc CreateUser(CreateUserRequest) returns (CreateUserResponse); + rpc RenewUser(RenewUserRequest) returns (Empty); + rpc RevokeUser(RevokeUserRequest) returns (Empty); + rpc Initialize(InitializeRequest) returns (Empty); + rpc Close(Empty) returns (Empty); +} diff --git a/builtin/logical/database/dbplugin/databasemiddleware.go b/builtin/logical/database/dbplugin/databasemiddleware.go index 87dfa6c31..c8bbdf61d 100644 --- a/builtin/logical/database/dbplugin/databasemiddleware.go +++ b/builtin/logical/database/dbplugin/databasemiddleware.go @@ -1,6 +1,7 @@ package dbplugin import ( + "context" "time" metrics "github.com/armon/go-metrics" @@ -15,55 +16,56 @@ type databaseTracingMiddleware struct { next Database logger log.Logger - typeStr string + typeStr string + transport string } func (mw *databaseTracingMiddleware) Type() (string, error) { return mw.next.Type() } -func (mw *databaseTracingMiddleware) CreateUser(statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) { +func (mw *databaseTracingMiddleware) CreateUser(ctx context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) { defer func(then time.Time) { - mw.logger.Trace("database", "operation", "CreateUser", "status", "finished", "type", mw.typeStr, "err", err, "took", time.Since(then)) + mw.logger.Trace("database", "operation", "CreateUser", "status", "finished", "type", mw.typeStr, "transport", mw.transport, "err", err, "took", time.Since(then)) }(time.Now()) - mw.logger.Trace("database", "operation", "CreateUser", "status", "started", "type", mw.typeStr) - return mw.next.CreateUser(statements, usernameConfig, expiration) + mw.logger.Trace("database", "operation", "CreateUser", "status", "started", "type", mw.typeStr, "transport", mw.transport) + return mw.next.CreateUser(ctx, statements, usernameConfig, expiration) } -func (mw *databaseTracingMiddleware) RenewUser(statements Statements, username string, expiration time.Time) (err error) { +func (mw *databaseTracingMiddleware) RenewUser(ctx context.Context, statements Statements, username string, expiration time.Time) (err error) { defer func(then time.Time) { - mw.logger.Trace("database", "operation", "RenewUser", "status", "finished", "type", mw.typeStr, "err", err, "took", time.Since(then)) + mw.logger.Trace("database", "operation", "RenewUser", "status", "finished", "type", mw.typeStr, "transport", mw.transport, "err", err, "took", time.Since(then)) }(time.Now()) - mw.logger.Trace("database", "operation", "RenewUser", "status", "started", mw.typeStr) - return mw.next.RenewUser(statements, username, expiration) + mw.logger.Trace("database", "operation", "RenewUser", "status", "started", mw.typeStr, "transport", mw.transport) + return mw.next.RenewUser(ctx, statements, username, expiration) } -func (mw *databaseTracingMiddleware) RevokeUser(statements Statements, username string) (err error) { +func (mw *databaseTracingMiddleware) RevokeUser(ctx context.Context, statements Statements, username string) (err error) { defer func(then time.Time) { - mw.logger.Trace("database", "operation", "RevokeUser", "status", "finished", "type", mw.typeStr, "err", err, "took", time.Since(then)) + mw.logger.Trace("database", "operation", "RevokeUser", "status", "finished", "type", mw.typeStr, "transport", mw.transport, "err", err, "took", time.Since(then)) }(time.Now()) - mw.logger.Trace("database", "operation", "RevokeUser", "status", "started", "type", mw.typeStr) - return mw.next.RevokeUser(statements, username) + mw.logger.Trace("database", "operation", "RevokeUser", "status", "started", "type", mw.typeStr, "transport", mw.transport) + return mw.next.RevokeUser(ctx, statements, username) } -func (mw *databaseTracingMiddleware) Initialize(conf map[string]interface{}, verifyConnection bool) (err error) { +func (mw *databaseTracingMiddleware) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (err error) { defer func(then time.Time) { - mw.logger.Trace("database", "operation", "Initialize", "status", "finished", "type", mw.typeStr, "verify", verifyConnection, "err", err, "took", time.Since(then)) + mw.logger.Trace("database", "operation", "Initialize", "status", "finished", "type", mw.typeStr, "transport", mw.transport, "verify", verifyConnection, "err", err, "took", time.Since(then)) }(time.Now()) - mw.logger.Trace("database", "operation", "Initialize", "status", "started", "type", mw.typeStr) - return mw.next.Initialize(conf, verifyConnection) + mw.logger.Trace("database", "operation", "Initialize", "status", "started", "type", mw.typeStr, "transport", mw.transport) + return mw.next.Initialize(ctx, conf, verifyConnection) } func (mw *databaseTracingMiddleware) Close() (err error) { defer func(then time.Time) { - mw.logger.Trace("database", "operation", "Close", "status", "finished", "type", mw.typeStr, "err", err, "took", time.Since(then)) + mw.logger.Trace("database", "operation", "Close", "status", "finished", "type", mw.typeStr, "transport", mw.transport, "err", err, "took", time.Since(then)) }(time.Now()) - mw.logger.Trace("database", "operation", "Close", "status", "started", "type", mw.typeStr) + mw.logger.Trace("database", "operation", "Close", "status", "started", "type", mw.typeStr, "transport", mw.transport) return mw.next.Close() } @@ -81,7 +83,7 @@ func (mw *databaseMetricsMiddleware) Type() (string, error) { return mw.next.Type() } -func (mw *databaseMetricsMiddleware) CreateUser(statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) { +func (mw *databaseMetricsMiddleware) CreateUser(ctx context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) { defer func(now time.Time) { metrics.MeasureSince([]string{"database", "CreateUser"}, now) metrics.MeasureSince([]string{"database", mw.typeStr, "CreateUser"}, now) @@ -94,10 +96,10 @@ func (mw *databaseMetricsMiddleware) CreateUser(statements Statements, usernameC metrics.IncrCounter([]string{"database", "CreateUser"}, 1) metrics.IncrCounter([]string{"database", mw.typeStr, "CreateUser"}, 1) - return mw.next.CreateUser(statements, usernameConfig, expiration) + return mw.next.CreateUser(ctx, statements, usernameConfig, expiration) } -func (mw *databaseMetricsMiddleware) RenewUser(statements Statements, username string, expiration time.Time) (err error) { +func (mw *databaseMetricsMiddleware) RenewUser(ctx context.Context, statements Statements, username string, expiration time.Time) (err error) { defer func(now time.Time) { metrics.MeasureSince([]string{"database", "RenewUser"}, now) metrics.MeasureSince([]string{"database", mw.typeStr, "RenewUser"}, now) @@ -110,10 +112,10 @@ func (mw *databaseMetricsMiddleware) RenewUser(statements Statements, username s metrics.IncrCounter([]string{"database", "RenewUser"}, 1) metrics.IncrCounter([]string{"database", mw.typeStr, "RenewUser"}, 1) - return mw.next.RenewUser(statements, username, expiration) + return mw.next.RenewUser(ctx, statements, username, expiration) } -func (mw *databaseMetricsMiddleware) RevokeUser(statements Statements, username string) (err error) { +func (mw *databaseMetricsMiddleware) RevokeUser(ctx context.Context, statements Statements, username string) (err error) { defer func(now time.Time) { metrics.MeasureSince([]string{"database", "RevokeUser"}, now) metrics.MeasureSince([]string{"database", mw.typeStr, "RevokeUser"}, now) @@ -126,10 +128,10 @@ func (mw *databaseMetricsMiddleware) RevokeUser(statements Statements, username metrics.IncrCounter([]string{"database", "RevokeUser"}, 1) metrics.IncrCounter([]string{"database", mw.typeStr, "RevokeUser"}, 1) - return mw.next.RevokeUser(statements, username) + return mw.next.RevokeUser(ctx, statements, username) } -func (mw *databaseMetricsMiddleware) Initialize(conf map[string]interface{}, verifyConnection bool) (err error) { +func (mw *databaseMetricsMiddleware) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (err error) { defer func(now time.Time) { metrics.MeasureSince([]string{"database", "Initialize"}, now) metrics.MeasureSince([]string{"database", mw.typeStr, "Initialize"}, now) @@ -142,7 +144,7 @@ func (mw *databaseMetricsMiddleware) Initialize(conf map[string]interface{}, ver metrics.IncrCounter([]string{"database", "Initialize"}, 1) metrics.IncrCounter([]string{"database", mw.typeStr, "Initialize"}, 1) - return mw.next.Initialize(conf, verifyConnection) + return mw.next.Initialize(ctx, conf, verifyConnection) } func (mw *databaseMetricsMiddleware) Close() (err error) { diff --git a/builtin/logical/database/dbplugin/grpc_transport.go b/builtin/logical/database/dbplugin/grpc_transport.go new file mode 100644 index 000000000..0b277968c --- /dev/null +++ b/builtin/logical/database/dbplugin/grpc_transport.go @@ -0,0 +1,198 @@ +package dbplugin + +import ( + "context" + "encoding/json" + "errors" + "time" + + "google.golang.org/grpc" + "google.golang.org/grpc/connectivity" + + "github.com/golang/protobuf/ptypes" +) + +var ( + ErrPluginShutdown = errors.New("plugin shutdown") +) + +// ---- gRPC Server domain ---- + +type gRPCServer struct { + impl Database +} + +func (s *gRPCServer) Type(context.Context, *Empty) (*TypeResponse, error) { + t, err := s.impl.Type() + if err != nil { + return nil, err + } + + return &TypeResponse{ + Type: t, + }, nil +} + +func (s *gRPCServer) CreateUser(ctx context.Context, req *CreateUserRequest) (*CreateUserResponse, error) { + e, err := ptypes.Timestamp(req.Expiration) + if err != nil { + return nil, err + } + + u, p, err := s.impl.CreateUser(ctx, *req.Statements, *req.UsernameConfig, e) + + return &CreateUserResponse{ + Username: u, + Password: p, + }, err +} + +func (s *gRPCServer) RenewUser(ctx context.Context, req *RenewUserRequest) (*Empty, error) { + e, err := ptypes.Timestamp(req.Expiration) + if err != nil { + return nil, err + } + err = s.impl.RenewUser(ctx, *req.Statements, req.Username, e) + return &Empty{}, err +} + +func (s *gRPCServer) RevokeUser(ctx context.Context, req *RevokeUserRequest) (*Empty, error) { + err := s.impl.RevokeUser(ctx, *req.Statements, req.Username) + return &Empty{}, err +} + +func (s *gRPCServer) Initialize(ctx context.Context, req *InitializeRequest) (*Empty, error) { + config := map[string]interface{}{} + + err := json.Unmarshal(req.Config, &config) + if err != nil { + return nil, err + } + + err = s.impl.Initialize(ctx, config, req.VerifyConnection) + return &Empty{}, err +} + +func (s *gRPCServer) Close(_ context.Context, _ *Empty) (*Empty, error) { + s.impl.Close() + return &Empty{}, nil +} + +// ---- gRPC client domain ---- + +type gRPCClient struct { + client DatabaseClient + clientConn *grpc.ClientConn +} + +func (c gRPCClient) Type() (string, error) { + // If the plugin has already shutdown, this will hang forever so we give it + // a one second timeout. + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + switch c.clientConn.GetState() { + case connectivity.Ready, connectivity.Idle: + default: + return "", ErrPluginShutdown + } + resp, err := c.client.Type(ctx, &Empty{}) + if err != nil { + return "", err + } + + return resp.Type, err +} + +func (c gRPCClient) CreateUser(ctx context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) { + t, err := ptypes.TimestampProto(expiration) + if err != nil { + return "", "", err + } + + switch c.clientConn.GetState() { + case connectivity.Ready, connectivity.Idle: + default: + return "", "", ErrPluginShutdown + } + + resp, err := c.client.CreateUser(ctx, &CreateUserRequest{ + Statements: &statements, + UsernameConfig: &usernameConfig, + Expiration: t, + }) + if err != nil { + return "", "", err + } + + return resp.Username, resp.Password, err +} + +func (c *gRPCClient) RenewUser(ctx context.Context, statements Statements, username string, expiration time.Time) error { + t, err := ptypes.TimestampProto(expiration) + if err != nil { + return err + } + + switch c.clientConn.GetState() { + case connectivity.Ready, connectivity.Idle: + default: + return ErrPluginShutdown + } + + _, err = c.client.RenewUser(ctx, &RenewUserRequest{ + Statements: &statements, + Username: username, + Expiration: t, + }) + + return err +} + +func (c *gRPCClient) RevokeUser(ctx context.Context, statements Statements, username string) error { + switch c.clientConn.GetState() { + case connectivity.Ready, connectivity.Idle: + default: + return ErrPluginShutdown + } + _, err := c.client.RevokeUser(ctx, &RevokeUserRequest{ + Statements: &statements, + Username: username, + }) + + return err +} + +func (c *gRPCClient) Initialize(ctx context.Context, config map[string]interface{}, verifyConnection bool) error { + configRaw, err := json.Marshal(config) + if err != nil { + return err + } + + switch c.clientConn.GetState() { + case connectivity.Ready, connectivity.Idle: + default: + return ErrPluginShutdown + } + + _, err = c.client.Initialize(ctx, &InitializeRequest{ + Config: configRaw, + VerifyConnection: verifyConnection, + }) + + return err +} + +func (c *gRPCClient) Close() error { + // If the plugin has already shutdown, this will hang forever so we give it + // a one second timeout. + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + switch c.clientConn.GetState() { + case connectivity.Ready, connectivity.Idle: + _, err := c.client.Close(ctx, &Empty{}) + return err + } + + return nil +} diff --git a/builtin/logical/database/dbplugin/netrpc_transport.go b/builtin/logical/database/dbplugin/netrpc_transport.go new file mode 100644 index 000000000..6f6f3a5bf --- /dev/null +++ b/builtin/logical/database/dbplugin/netrpc_transport.go @@ -0,0 +1,139 @@ +package dbplugin + +import ( + "context" + "fmt" + "net/rpc" + "time" +) + +// ---- RPC server domain ---- + +// databasePluginRPCServer implements an RPC version of Database and is run +// inside a plugin. It wraps an underlying implementation of Database. +type databasePluginRPCServer struct { + impl Database +} + +func (ds *databasePluginRPCServer) Type(_ struct{}, resp *string) error { + var err error + *resp, err = ds.impl.Type() + return err +} + +func (ds *databasePluginRPCServer) CreateUser(args *CreateUserRequestRPC, resp *CreateUserResponse) error { + var err error + resp.Username, resp.Password, err = ds.impl.CreateUser(context.Background(), args.Statements, args.UsernameConfig, args.Expiration) + return err +} + +func (ds *databasePluginRPCServer) RenewUser(args *RenewUserRequestRPC, _ *struct{}) error { + err := ds.impl.RenewUser(context.Background(), args.Statements, args.Username, args.Expiration) + return err +} + +func (ds *databasePluginRPCServer) RevokeUser(args *RevokeUserRequestRPC, _ *struct{}) error { + err := ds.impl.RevokeUser(context.Background(), args.Statements, args.Username) + return err +} + +func (ds *databasePluginRPCServer) Initialize(args *InitializeRequestRPC, _ *struct{}) error { + err := ds.impl.Initialize(context.Background(), args.Config, args.VerifyConnection) + return err +} + +func (ds *databasePluginRPCServer) Close(_ struct{}, _ *struct{}) error { + ds.impl.Close() + return nil +} + +// ---- RPC client domain ---- +// databasePluginRPCClient implements Database and is used on the client to +// make RPC calls to a plugin. +type databasePluginRPCClient struct { + client *rpc.Client +} + +func (dr *databasePluginRPCClient) Type() (string, error) { + var dbType string + err := dr.client.Call("Plugin.Type", struct{}{}, &dbType) + + return fmt.Sprintf("plugin-%s", dbType), err +} + +func (dr *databasePluginRPCClient) CreateUser(_ context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) { + req := CreateUserRequestRPC{ + Statements: statements, + UsernameConfig: usernameConfig, + Expiration: expiration, + } + + var resp CreateUserResponse + err = dr.client.Call("Plugin.CreateUser", req, &resp) + + return resp.Username, resp.Password, err +} + +func (dr *databasePluginRPCClient) RenewUser(_ context.Context, statements Statements, username string, expiration time.Time) error { + req := RenewUserRequestRPC{ + Statements: statements, + Username: username, + Expiration: expiration, + } + + err := dr.client.Call("Plugin.RenewUser", req, &struct{}{}) + + return err +} + +func (dr *databasePluginRPCClient) RevokeUser(_ context.Context, statements Statements, username string) error { + req := RevokeUserRequestRPC{ + Statements: statements, + Username: username, + } + + err := dr.client.Call("Plugin.RevokeUser", req, &struct{}{}) + + return err +} + +func (dr *databasePluginRPCClient) Initialize(_ context.Context, conf map[string]interface{}, verifyConnection bool) error { + req := InitializeRequestRPC{ + Config: conf, + VerifyConnection: verifyConnection, + } + + err := dr.client.Call("Plugin.Initialize", req, &struct{}{}) + + return err +} + +func (dr *databasePluginRPCClient) Close() error { + err := dr.client.Call("Plugin.Close", struct{}{}, &struct{}{}) + + return err +} + +// ---- RPC Request Args Domain ---- + +type InitializeRequestRPC struct { + Config map[string]interface{} + VerifyConnection bool +} + +type CreateUserRequestRPC struct { + Statements Statements + UsernameConfig UsernameConfig + Expiration time.Time +} + +type RenewUserRequestRPC struct { + Statements Statements + Username string + Expiration time.Time +} + +type RevokeUserRequestRPC struct { + Statements Statements + Username string +} diff --git a/builtin/logical/database/dbplugin/plugin.go b/builtin/logical/database/dbplugin/plugin.go index 0becc9f4a..0f4bfee80 100644 --- a/builtin/logical/database/dbplugin/plugin.go +++ b/builtin/logical/database/dbplugin/plugin.go @@ -1,10 +1,13 @@ package dbplugin import ( + "context" "fmt" "net/rpc" "time" + "google.golang.org/grpc" + "github.com/hashicorp/go-plugin" "github.com/hashicorp/vault/helper/pluginutil" log "github.com/mgutz/logxi/v1" @@ -13,29 +16,14 @@ import ( // Database is the interface that all database objects must implement. type Database interface { Type() (string, error) - CreateUser(statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) - RenewUser(statements Statements, username string, expiration time.Time) error - RevokeUser(statements Statements, username string) error + CreateUser(ctx context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) + RenewUser(ctx context.Context, statements Statements, username string, expiration time.Time) error + RevokeUser(ctx context.Context, statements Statements, username string) error - Initialize(config map[string]interface{}, verifyConnection bool) error + Initialize(ctx context.Context, config map[string]interface{}, verifyConnection bool) error Close() error } -// Statements set in role creation and passed into the database type's functions. -type Statements struct { - CreationStatements string `json:"creation_statments" mapstructure:"creation_statements" structs:"creation_statments"` - RevocationStatements string `json:"revocation_statements" mapstructure:"revocation_statements" structs:"revocation_statements"` - RollbackStatements string `json:"rollback_statements" mapstructure:"rollback_statements" structs:"rollback_statements"` - RenewStatements string `json:"renew_statements" mapstructure:"renew_statements" structs:"renew_statements"` -} - -// UsernameConfig is used to configure prefixes for the username to be -// generated. -type UsernameConfig struct { - DisplayName string - RoleName string -} - // PluginFactory is used to build plugin database types. It wraps the database // object in a logging and metrics middleware. func PluginFactory(pluginName string, sys pluginutil.LookRunnerUtil, logger log.Logger) (Database, error) { @@ -45,6 +33,7 @@ func PluginFactory(pluginName string, sys pluginutil.LookRunnerUtil, logger log. return nil, err } + var transport string var db Database if pluginRunner.Builtin { // Plugin is builtin so we can retrieve an instance of the interface @@ -60,12 +49,24 @@ func PluginFactory(pluginName string, sys pluginutil.LookRunnerUtil, logger log. return nil, fmt.Errorf("unsuported database type: %s", pluginName) } + transport = "builtin" + } else { // create a DatabasePluginClient instance db, err = newPluginClient(sys, pluginRunner, logger) if err != nil { return nil, err } + + // Switch on the underlying database client type to get the transport + // method. + switch db.(*DatabasePluginClient).Database.(type) { + case *gRPCClient: + transport = "gRPC" + case *databasePluginRPCClient: + transport = "netRPC" + } + } typeStr, err := db.Type() @@ -82,9 +83,10 @@ func PluginFactory(pluginName string, sys pluginutil.LookRunnerUtil, logger log. // Wrap with tracing middleware if logger.IsTrace() { db = &databaseTracingMiddleware{ - next: db, - typeStr: typeStr, - logger: logger, + transport: transport, + next: db, + typeStr: typeStr, + logger: logger, } } @@ -115,33 +117,14 @@ func (DatabasePlugin) Client(b *plugin.MuxBroker, c *rpc.Client) (interface{}, e return &databasePluginRPCClient{client: c}, nil } -// ---- RPC Request Args Domain ---- - -type InitializeRequest struct { - Config map[string]interface{} - VerifyConnection bool +func (d DatabasePlugin) GRPCServer(s *grpc.Server) error { + RegisterDatabaseServer(s, &gRPCServer{impl: d.impl}) + return nil } -type CreateUserRequest struct { - Statements Statements - UsernameConfig UsernameConfig - Expiration time.Time -} - -type RenewUserRequest struct { - Statements Statements - Username string - Expiration time.Time -} - -type RevokeUserRequest struct { - Statements Statements - Username string -} - -// ---- RPC Response Args Domain ---- - -type CreateUserResponse struct { - Username string - Password string +func (DatabasePlugin) GRPCClient(c *grpc.ClientConn) (interface{}, error) { + return &gRPCClient{ + client: NewDatabaseClient(c), + clientConn: c, + }, nil } diff --git a/builtin/logical/database/dbplugin/plugin_test.go b/builtin/logical/database/dbplugin/plugin_test.go index 3a785953d..96ef886b2 100644 --- a/builtin/logical/database/dbplugin/plugin_test.go +++ b/builtin/logical/database/dbplugin/plugin_test.go @@ -1,11 +1,13 @@ package dbplugin_test import ( + "context" "errors" "os" "testing" "time" + plugin "github.com/hashicorp/go-plugin" "github.com/hashicorp/vault/builtin/logical/database/dbplugin" "github.com/hashicorp/vault/helper/pluginutil" vaulthttp "github.com/hashicorp/vault/http" @@ -20,7 +22,7 @@ type mockPlugin struct { } func (m *mockPlugin) Type() (string, error) { return "mock", nil } -func (m *mockPlugin) CreateUser(statements dbplugin.Statements, usernameConf dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) { +func (m *mockPlugin) CreateUser(_ context.Context, statements dbplugin.Statements, usernameConf dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) { err = errors.New("err") if usernameConf.DisplayName == "" || expiration.IsZero() { return "", "", err @@ -34,7 +36,7 @@ func (m *mockPlugin) CreateUser(statements dbplugin.Statements, usernameConf dbp return usernameConf.DisplayName, "test", nil } -func (m *mockPlugin) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error { +func (m *mockPlugin) RenewUser(_ context.Context, statements dbplugin.Statements, username string, expiration time.Time) error { err := errors.New("err") if username == "" || expiration.IsZero() { return err @@ -46,7 +48,7 @@ func (m *mockPlugin) RenewUser(statements dbplugin.Statements, username string, return nil } -func (m *mockPlugin) RevokeUser(statements dbplugin.Statements, username string) error { +func (m *mockPlugin) RevokeUser(_ context.Context, statements dbplugin.Statements, username string) error { err := errors.New("err") if username == "" { return err @@ -59,7 +61,7 @@ func (m *mockPlugin) RevokeUser(statements dbplugin.Statements, username string) delete(m.users, username) return nil } -func (m *mockPlugin) Initialize(conf map[string]interface{}, _ bool) error { +func (m *mockPlugin) Initialize(_ context.Context, conf map[string]interface{}, _ bool) error { err := errors.New("err") if len(conf) != 1 { return err @@ -80,14 +82,15 @@ func getCluster(t *testing.T) (*vault.TestCluster, logical.SystemView) { cores := cluster.Cores sys := vault.TestDynamicSystemView(cores[0].Core) - vault.TestAddTestPlugin(t, cores[0].Core, "test-plugin", "TestPlugin_Main") + vault.TestAddTestPlugin(t, cores[0].Core, "test-plugin", "TestPlugin_GRPC_Main") + vault.TestAddTestPlugin(t, cores[0].Core, "test-plugin-netRPC", "TestPlugin_NetRPC_Main") return cluster, sys } // This is not an actual test case, it's a helper function that will be executed // by the go-plugin client via an exec call. -func TestPlugin_Main(t *testing.T) { +func TestPlugin_GRPC_Main(t *testing.T) { if os.Getenv(pluginutil.PluginUnwrapTokenEnv) == "" { return } @@ -105,6 +108,30 @@ func TestPlugin_Main(t *testing.T) { plugins.Serve(plugin, apiClientMeta.GetTLSConfig()) } +// This is not an actual test case, it's a helper function that will be executed +// by the go-plugin client via an exec call. +func TestPlugin_NetRPC_Main(t *testing.T) { + if os.Getenv(pluginutil.PluginUnwrapTokenEnv) == "" { + return + } + + p := &mockPlugin{ + users: make(map[string][]string), + } + + args := []string{"--tls-skip-verify=true"} + + apiClientMeta := &pluginutil.APIClientMeta{} + flags := apiClientMeta.FlagSet() + flags.Parse(args) + + tlsProvider := pluginutil.VaultPluginTLSProvider(apiClientMeta.GetTLSConfig()) + serveConf := dbplugin.ServeConfig(p, tlsProvider) + serveConf.GRPCServer = nil + + plugin.Serve(serveConf) +} + func TestPlugin_Initialize(t *testing.T) { cluster, sys := getCluster(t) defer cluster.Cleanup() @@ -118,7 +145,7 @@ func TestPlugin_Initialize(t *testing.T) { "test": 1, } - err = dbRaw.Initialize(connectionDetails, true) + err = dbRaw.Initialize(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } @@ -143,7 +170,7 @@ func TestPlugin_CreateUser(t *testing.T) { "test": 1, } - err = db.Initialize(connectionDetails, true) + err = db.Initialize(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } @@ -153,7 +180,7 @@ func TestPlugin_CreateUser(t *testing.T) { RoleName: "test", } - us, pw, err := db.CreateUser(dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute)) + us, pw, err := db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute)) if err != nil { t.Fatalf("err: %s", err) } @@ -163,7 +190,7 @@ func TestPlugin_CreateUser(t *testing.T) { // try and save the same user again to verify it saved the first time, this // should return an error - _, _, err = db.CreateUser(dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute)) + _, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute)) if err == nil { t.Fatal("expected an error, user wasn't created correctly") } @@ -182,7 +209,7 @@ func TestPlugin_RenewUser(t *testing.T) { connectionDetails := map[string]interface{}{ "test": 1, } - err = db.Initialize(connectionDetails, true) + err = db.Initialize(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } @@ -192,12 +219,12 @@ func TestPlugin_RenewUser(t *testing.T) { RoleName: "test", } - us, _, err := db.CreateUser(dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute)) + us, _, err := db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute)) if err != nil { t.Fatalf("err: %s", err) } - err = db.RenewUser(dbplugin.Statements{}, us, time.Now().Add(time.Minute)) + err = db.RenewUser(context.Background(), dbplugin.Statements{}, us, time.Now().Add(time.Minute)) if err != nil { t.Fatalf("err: %s", err) } @@ -216,7 +243,7 @@ func TestPlugin_RevokeUser(t *testing.T) { connectionDetails := map[string]interface{}{ "test": 1, } - err = db.Initialize(connectionDetails, true) + err = db.Initialize(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } @@ -226,19 +253,159 @@ func TestPlugin_RevokeUser(t *testing.T) { RoleName: "test", } - us, _, err := db.CreateUser(dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute)) + us, _, err := db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute)) if err != nil { t.Fatalf("err: %s", err) } // Test default revoke statememts - err = db.RevokeUser(dbplugin.Statements{}, us) + err = db.RevokeUser(context.Background(), dbplugin.Statements{}, us) if err != nil { t.Fatalf("err: %s", err) } // Try adding the same username back so we can verify it was removed - _, _, err = db.CreateUser(dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute)) + _, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute)) + if err != nil { + t.Fatalf("err: %s", err) + } +} + +// Test the code is still compatible with an old netRPC plugin +func TestPlugin_NetRPC_Initialize(t *testing.T) { + cluster, sys := getCluster(t) + defer cluster.Cleanup() + + dbRaw, err := dbplugin.PluginFactory("test-plugin-netRPC", sys, &log.NullLogger{}) + if err != nil { + t.Fatalf("err: %s", err) + } + + connectionDetails := map[string]interface{}{ + "test": 1, + } + + err = dbRaw.Initialize(context.Background(), connectionDetails, true) + if err != nil { + t.Fatalf("err: %s", err) + } + + err = dbRaw.Close() + if err != nil { + t.Fatalf("err: %s", err) + } +} + +func TestPlugin_NetRPC_CreateUser(t *testing.T) { + cluster, sys := getCluster(t) + defer cluster.Cleanup() + + db, err := dbplugin.PluginFactory("test-plugin-netRPC", sys, &log.NullLogger{}) + if err != nil { + t.Fatalf("err: %s", err) + } + defer db.Close() + + connectionDetails := map[string]interface{}{ + "test": 1, + } + + err = db.Initialize(context.Background(), connectionDetails, true) + if err != nil { + t.Fatalf("err: %s", err) + } + + usernameConf := dbplugin.UsernameConfig{ + DisplayName: "test", + RoleName: "test", + } + + us, pw, err := db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute)) + if err != nil { + t.Fatalf("err: %s", err) + } + if us != "test" || pw != "test" { + t.Fatal("expected username and password to be 'test'") + } + + // try and save the same user again to verify it saved the first time, this + // should return an error + _, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute)) + if err == nil { + t.Fatal("expected an error, user wasn't created correctly") + } +} + +func TestPlugin_NetRPC_RenewUser(t *testing.T) { + cluster, sys := getCluster(t) + defer cluster.Cleanup() + + db, err := dbplugin.PluginFactory("test-plugin-netRPC", sys, &log.NullLogger{}) + if err != nil { + t.Fatalf("err: %s", err) + } + defer db.Close() + + connectionDetails := map[string]interface{}{ + "test": 1, + } + err = db.Initialize(context.Background(), connectionDetails, true) + if err != nil { + t.Fatalf("err: %s", err) + } + + usernameConf := dbplugin.UsernameConfig{ + DisplayName: "test", + RoleName: "test", + } + + us, _, err := db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute)) + if err != nil { + t.Fatalf("err: %s", err) + } + + err = db.RenewUser(context.Background(), dbplugin.Statements{}, us, time.Now().Add(time.Minute)) + if err != nil { + t.Fatalf("err: %s", err) + } +} + +func TestPlugin_NetRPC_RevokeUser(t *testing.T) { + cluster, sys := getCluster(t) + defer cluster.Cleanup() + + db, err := dbplugin.PluginFactory("test-plugin-netRPC", sys, &log.NullLogger{}) + if err != nil { + t.Fatalf("err: %s", err) + } + defer db.Close() + + connectionDetails := map[string]interface{}{ + "test": 1, + } + err = db.Initialize(context.Background(), connectionDetails, true) + if err != nil { + t.Fatalf("err: %s", err) + } + + usernameConf := dbplugin.UsernameConfig{ + DisplayName: "test", + RoleName: "test", + } + + us, _, err := db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute)) + if err != nil { + t.Fatalf("err: %s", err) + } + + // Test default revoke statememts + err = db.RevokeUser(context.Background(), dbplugin.Statements{}, us) + if err != nil { + t.Fatalf("err: %s", err) + } + + // Try adding the same username back so we can verify it was removed + _, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute)) if err != nil { t.Fatalf("err: %s", err) } diff --git a/builtin/logical/database/dbplugin/server.go b/builtin/logical/database/dbplugin/server.go index 381f0ae2a..0f44905af 100644 --- a/builtin/logical/database/dbplugin/server.go +++ b/builtin/logical/database/dbplugin/server.go @@ -10,6 +10,10 @@ import ( // Database implementation in a databasePluginRPCServer object and starts a // RPC server. func Serve(db Database, tlsProvider func() (*tls.Config, error)) { + plugin.Serve(ServeConfig(db, tlsProvider)) +} + +func ServeConfig(db Database, tlsProvider func() (*tls.Config, error)) *plugin.ServeConfig { dbPlugin := &DatabasePlugin{ impl: db, } @@ -19,53 +23,10 @@ func Serve(db Database, tlsProvider func() (*tls.Config, error)) { "database": dbPlugin, } - plugin.Serve(&plugin.ServeConfig{ + return &plugin.ServeConfig{ HandshakeConfig: handshakeConfig, Plugins: pluginMap, TLSProvider: tlsProvider, - }) -} - -// ---- RPC server domain ---- - -// databasePluginRPCServer implements an RPC version of Database and is run -// inside a plugin. It wraps an underlying implementation of Database. -type databasePluginRPCServer struct { - impl Database -} - -func (ds *databasePluginRPCServer) Type(_ struct{}, resp *string) error { - var err error - *resp, err = ds.impl.Type() - return err -} - -func (ds *databasePluginRPCServer) CreateUser(args *CreateUserRequest, resp *CreateUserResponse) error { - var err error - resp.Username, resp.Password, err = ds.impl.CreateUser(args.Statements, args.UsernameConfig, args.Expiration) - - return err -} - -func (ds *databasePluginRPCServer) RenewUser(args *RenewUserRequest, _ *struct{}) error { - err := ds.impl.RenewUser(args.Statements, args.Username, args.Expiration) - - return err -} - -func (ds *databasePluginRPCServer) RevokeUser(args *RevokeUserRequest, _ *struct{}) error { - err := ds.impl.RevokeUser(args.Statements, args.Username) - - return err -} - -func (ds *databasePluginRPCServer) Initialize(args *InitializeRequest, _ *struct{}) error { - err := ds.impl.Initialize(args.Config, args.VerifyConnection) - - return err -} - -func (ds *databasePluginRPCServer) Close(_ struct{}, _ *struct{}) error { - ds.impl.Close() - return nil + GRPCServer: plugin.DefaultGRPCServer, + } } diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index d1e6cb292..95ec216d2 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -1,6 +1,7 @@ package database import ( + "context" "errors" "fmt" @@ -62,7 +63,7 @@ func (b *databaseBackend) pathConnectionReset() framework.OperationFunc { b.clearConnection(name) // Execute plugin again, we don't need the object so throw away. - _, err := b.createDBObj(req.Storage, name) + _, err := b.createDBObj(context.TODO(), req.Storage, name) if err != nil { return nil, err } @@ -230,7 +231,7 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc { return logical.ErrorResponse(fmt.Sprintf("error creating database object: %s", err)), nil } - err = db.Initialize(config.ConnectionDetails, verifyConnection) + err = db.Initialize(context.TODO(), config.ConnectionDetails, verifyConnection) if err != nil { db.Close() return logical.ErrorResponse(fmt.Sprintf("error creating database object: %s", err)), nil diff --git a/builtin/logical/database/path_creds_create.go b/builtin/logical/database/path_creds_create.go index 610da726c..8e1adce4a 100644 --- a/builtin/logical/database/path_creds_create.go +++ b/builtin/logical/database/path_creds_create.go @@ -1,6 +1,7 @@ package database import ( + "context" "fmt" "time" @@ -66,7 +67,7 @@ func (b *databaseBackend) pathCredsCreateRead() framework.OperationFunc { unlockFunc = b.Unlock // Create a new DB object - db, err = b.createDBObj(req.Storage, role.DBName) + db, err = b.createDBObj(context.TODO(), req.Storage, role.DBName) if err != nil { unlockFunc() return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err) @@ -81,7 +82,7 @@ func (b *databaseBackend) pathCredsCreateRead() framework.OperationFunc { } // Create the user - username, password, err := db.CreateUser(role.Statements, usernameConfig, expiration) + username, password, err := db.CreateUser(context.TODO(), role.Statements, usernameConfig, expiration) // Unlock unlockFunc() if err != nil { diff --git a/builtin/logical/database/secret_creds.go b/builtin/logical/database/secret_creds.go index c3dfcb973..fb3e05bdf 100644 --- a/builtin/logical/database/secret_creds.go +++ b/builtin/logical/database/secret_creds.go @@ -1,6 +1,7 @@ package database import ( + "context" "fmt" "github.com/hashicorp/vault/logical" @@ -60,7 +61,7 @@ func (b *databaseBackend) secretCredsRenew() framework.OperationFunc { unlockFunc = b.Unlock // Create a new DB object - db, err = b.createDBObj(req.Storage, role.DBName) + db, err = b.createDBObj(context.TODO(), req.Storage, role.DBName) if err != nil { unlockFunc() return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err) @@ -69,7 +70,7 @@ func (b *databaseBackend) secretCredsRenew() framework.OperationFunc { // Make sure we increase the VALID UNTIL endpoint for this user. if expireTime := resp.Secret.ExpirationTime(); !expireTime.IsZero() { - err := db.RenewUser(role.Statements, username, expireTime) + err := db.RenewUser(context.TODO(), role.Statements, username, expireTime) // Unlock unlockFunc() if err != nil { @@ -119,14 +120,14 @@ func (b *databaseBackend) secretCredsRevoke() framework.OperationFunc { unlockFunc = b.Unlock // Create a new DB object - db, err = b.createDBObj(req.Storage, role.DBName) + db, err = b.createDBObj(context.TODO(), req.Storage, role.DBName) if err != nil { unlockFunc() return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err) } } - err = db.RevokeUser(role.Statements, username) + err = db.RevokeUser(context.TODO(), role.Statements, username) // Unlock unlockFunc() if err != nil { diff --git a/helper/keysutil/lock_manager.go b/helper/keysutil/lock_manager.go index 16724049f..9d5cf63ae 100644 --- a/helper/keysutil/lock_manager.go +++ b/helper/keysutil/lock_manager.go @@ -43,6 +43,9 @@ type PolicyRequest struct { // Whether to upsert Upsert bool + + // Whether to allow plaintext backup + AllowPlaintextBackup bool } type LockManager struct { @@ -378,10 +381,11 @@ func (lm *LockManager) getPolicyCommon(req PolicyRequest, lockType bool) (*Polic } p = &Policy{ - Name: req.Name, - Type: req.KeyType, - Derived: req.Derived, - Exportable: req.Exportable, + Name: req.Name, + Type: req.KeyType, + Derived: req.Derived, + Exportable: req.Exportable, + AllowPlaintextBackup: req.AllowPlaintextBackup, } if req.Derived { p.KDF = Kdf_hkdf_sha256 diff --git a/helper/pluginutil/runner.go b/helper/pluginutil/runner.go index 2047651ed..bd9986e2d 100644 --- a/helper/pluginutil/runner.go +++ b/helper/pluginutil/runner.go @@ -119,6 +119,10 @@ func (r *PluginRunner) runCommon(wrapper RunnerUtil, pluginMap map[string]plugin SecureConfig: secureConfig, TLSConfig: clientTLSConfig, Logger: namedLogger, + AllowedProtocols: []plugin.Protocol{ + plugin.ProtocolNetRPC, + plugin.ProtocolGRPC, + }, } client := plugin.NewClient(clientConfig) diff --git a/logical/framework/backend_test.go b/logical/framework/backend_test.go index d94beedd3..040b52af3 100644 --- a/logical/framework/backend_test.go +++ b/logical/framework/backend_test.go @@ -192,8 +192,7 @@ func TestBackendHandleRequest_helpRoot(t *testing.T) { func TestBackendHandleRequest_renewAuth(t *testing.T) { b := &Backend{} - resp, err := b.HandleRequest(logical.RenewAuthRequest( - "/foo", &logical.Auth{}, nil)) + resp, err := b.HandleRequest(logical.RenewAuthRequest("/foo", &logical.Auth{}, nil)) if err != nil { t.Fatalf("err: %s", err) } @@ -213,8 +212,7 @@ func TestBackendHandleRequest_renewAuthCallback(t *testing.T) { AuthRenew: callback, } - _, err := b.HandleRequest(logical.RenewAuthRequest( - "/foo", &logical.Auth{}, nil)) + _, err := b.HandleRequest(logical.RenewAuthRequest("/foo", &logical.Auth{}, nil)) if err != nil { t.Fatalf("err: %s", err) } @@ -237,8 +235,7 @@ func TestBackendHandleRequest_renew(t *testing.T) { Secrets: []*Secret{secret}, } - _, err := b.HandleRequest(logical.RenewRequest( - "/foo", secret.Response(nil, nil).Secret, nil)) + _, err := b.HandleRequest(logical.RenewRequest("/foo", secret.Response(nil, nil).Secret, nil)) if err != nil { t.Fatalf("err: %s", err) } @@ -293,8 +290,7 @@ func TestBackendHandleRequest_revoke(t *testing.T) { Secrets: []*Secret{secret}, } - _, err := b.HandleRequest(logical.RevokeRequest( - "/foo", secret.Response(nil, nil).Secret, nil)) + _, err := b.HandleRequest(logical.RevokeRequest("/foo", secret.Response(nil, nil).Secret, nil)) if err != nil { t.Fatalf("err: %s", err) } diff --git a/logical/framework/lease.go b/logical/framework/lease.go index 4fd2ac902..d2678f712 100644 --- a/logical/framework/lease.go +++ b/logical/framework/lease.go @@ -8,7 +8,8 @@ import ( ) // LeaseExtend returns an OperationFunc that can be used to simply extend the -// lease of the auth/secret for the duration that was requested. +// lease of the auth/secret for the duration that was requested. The parameters +// provided are used to determine the lease's new TTL value. // // backendIncrement is the backend's requested increment -- perhaps from a user // request, perhaps from a role/config value. If not set, uses the mount/system diff --git a/logical/request.go b/logical/request.go index edde0417f..5e5102d1c 100644 --- a/logical/request.go +++ b/logical/request.go @@ -200,8 +200,7 @@ func (r *Request) SetLastRemoteWAL(last uint64) { } // RenewRequest creates the structure of the renew request. -func RenewRequest( - path string, secret *Secret, data map[string]interface{}) *Request { +func RenewRequest(path string, secret *Secret, data map[string]interface{}) *Request { return &Request{ Operation: RenewOperation, Path: path, @@ -211,8 +210,7 @@ func RenewRequest( } // RenewAuthRequest creates the structure of the renew request for an auth. -func RenewAuthRequest( - path string, auth *Auth, data map[string]interface{}) *Request { +func RenewAuthRequest(path string, auth *Auth, data map[string]interface{}) *Request { return &Request{ Operation: RenewOperation, Path: path, @@ -222,8 +220,7 @@ func RenewAuthRequest( } // RevokeRequest creates the structure of the revoke request. -func RevokeRequest( - path string, secret *Secret, data map[string]interface{}) *Request { +func RevokeRequest(path string, secret *Secret, data map[string]interface{}) *Request { return &Request{ Operation: RevokeOperation, Path: path, diff --git a/plugins/database/cassandra/cassandra.go b/plugins/database/cassandra/cassandra.go index c0b5fd5d4..221784e0f 100644 --- a/plugins/database/cassandra/cassandra.go +++ b/plugins/database/cassandra/cassandra.go @@ -1,6 +1,7 @@ package cassandra import ( + "context" "strings" "time" @@ -21,6 +22,8 @@ const ( cassandraTypeName = "cassandra" ) +var _ dbplugin.Database = &Cassandra{} + // Cassandra is an implementation of Database interface type Cassandra struct { connutil.ConnectionProducer @@ -64,8 +67,8 @@ func (c *Cassandra) Type() (string, error) { return cassandraTypeName, nil } -func (c *Cassandra) getConnection() (*gocql.Session, error) { - session, err := c.Connection() +func (c *Cassandra) getConnection(ctx context.Context) (*gocql.Session, error) { + session, err := c.Connection(ctx) if err != nil { return nil, err } @@ -75,13 +78,13 @@ func (c *Cassandra) getConnection() (*gocql.Session, error) { // CreateUser generates the username/password on the underlying Cassandra secret backend as instructed by // the CreationStatement provided. -func (c *Cassandra) CreateUser(statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) { +func (c *Cassandra) CreateUser(ctx context.Context, statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) { // Grab the lock c.Lock() defer c.Unlock() // Get the connection - session, err := c.getConnection() + session, err := c.getConnection(ctx) if err != nil { return "", "", err } @@ -138,18 +141,18 @@ func (c *Cassandra) CreateUser(statements dbplugin.Statements, usernameConfig db } // RenewUser is not supported on Cassandra, so this is a no-op. -func (c *Cassandra) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error { +func (c *Cassandra) RenewUser(ctx context.Context, statements dbplugin.Statements, username string, expiration time.Time) error { // NOOP return nil } // RevokeUser attempts to drop the specified user. -func (c *Cassandra) RevokeUser(statements dbplugin.Statements, username string) error { +func (c *Cassandra) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error { // Grab the lock c.Lock() defer c.Unlock() - session, err := c.getConnection() + session, err := c.getConnection(ctx) if err != nil { return err } diff --git a/plugins/database/cassandra/cassandra_test.go b/plugins/database/cassandra/cassandra_test.go index 0f4d3306e..c31139de7 100644 --- a/plugins/database/cassandra/cassandra_test.go +++ b/plugins/database/cassandra/cassandra_test.go @@ -1,6 +1,7 @@ package cassandra import ( + "context" "os" "strconv" "testing" @@ -89,7 +90,7 @@ func TestCassandra_Initialize(t *testing.T) { db := dbRaw.(*Cassandra) connProducer := db.ConnectionProducer.(*cassandraConnectionProducer) - err := db.Initialize(connectionDetails, true) + err := db.Initialize(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } @@ -112,7 +113,7 @@ func TestCassandra_Initialize(t *testing.T) { "protocol_version": "4", } - err = db.Initialize(connectionDetails, true) + err = db.Initialize(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } @@ -135,7 +136,7 @@ func TestCassandra_CreateUser(t *testing.T) { dbRaw, _ := New() db := dbRaw.(*Cassandra) - err := db.Initialize(connectionDetails, true) + err := db.Initialize(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } @@ -149,7 +150,7 @@ func TestCassandra_CreateUser(t *testing.T) { RoleName: "test", } - username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute)) + username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute)) if err != nil { t.Fatalf("err: %s", err) } @@ -176,7 +177,7 @@ func TestMyCassandra_RenewUser(t *testing.T) { dbRaw, _ := New() db := dbRaw.(*Cassandra) - err := db.Initialize(connectionDetails, true) + err := db.Initialize(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } @@ -190,7 +191,7 @@ func TestMyCassandra_RenewUser(t *testing.T) { RoleName: "test", } - username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute)) + username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute)) if err != nil { t.Fatalf("err: %s", err) } @@ -199,7 +200,7 @@ func TestMyCassandra_RenewUser(t *testing.T) { t.Fatalf("Could not connect with new credentials: %s", err) } - err = db.RenewUser(statements, username, time.Now().Add(time.Minute)) + err = db.RenewUser(context.Background(), statements, username, time.Now().Add(time.Minute)) if err != nil { t.Fatalf("err: %s", err) } @@ -222,7 +223,7 @@ func TestCassandra_RevokeUser(t *testing.T) { dbRaw, _ := New() db := dbRaw.(*Cassandra) - err := db.Initialize(connectionDetails, true) + err := db.Initialize(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } @@ -236,7 +237,7 @@ func TestCassandra_RevokeUser(t *testing.T) { RoleName: "test", } - username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute)) + username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute)) if err != nil { t.Fatalf("err: %s", err) } @@ -246,7 +247,7 @@ func TestCassandra_RevokeUser(t *testing.T) { } // Test default revoke statememts - err = db.RevokeUser(statements, username) + err = db.RevokeUser(context.Background(), statements, username) if err != nil { t.Fatalf("err: %s", err) } diff --git a/plugins/database/cassandra/connection_producer.go b/plugins/database/cassandra/connection_producer.go index 45d46518b..2f7e06f86 100644 --- a/plugins/database/cassandra/connection_producer.go +++ b/plugins/database/cassandra/connection_producer.go @@ -1,6 +1,7 @@ package cassandra import ( + "context" "crypto/tls" "fmt" "strings" @@ -43,7 +44,7 @@ type cassandraConnectionProducer struct { sync.Mutex } -func (c *cassandraConnectionProducer) Initialize(conf map[string]interface{}, verifyConnection bool) error { +func (c *cassandraConnectionProducer) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error { c.Lock() defer c.Unlock() @@ -106,7 +107,7 @@ func (c *cassandraConnectionProducer) Initialize(conf map[string]interface{}, ve c.Initialized = true if verifyConnection { - if _, err := c.Connection(); err != nil { + if _, err := c.Connection(ctx); err != nil { return fmt.Errorf("error verifying connection: %s", err) } } @@ -114,7 +115,7 @@ func (c *cassandraConnectionProducer) Initialize(conf map[string]interface{}, ve return nil } -func (c *cassandraConnectionProducer) Connection() (interface{}, error) { +func (c *cassandraConnectionProducer) Connection(_ context.Context) (interface{}, error) { if !c.Initialized { return nil, connutil.ErrNotInitialized } diff --git a/plugins/database/hana/hana.go b/plugins/database/hana/hana.go index aa2b53d65..5411505c8 100644 --- a/plugins/database/hana/hana.go +++ b/plugins/database/hana/hana.go @@ -1,6 +1,7 @@ package hana import ( + "context" "database/sql" "fmt" "strings" @@ -26,6 +27,8 @@ type HANA struct { credsutil.CredentialsProducer } +var _ dbplugin.Database = &HANA{} + // New implements builtinplugins.BuiltinFactory func New() (interface{}, error) { connProducer := &connutil.SQLConnectionProducer{} @@ -63,8 +66,8 @@ func (h *HANA) Type() (string, error) { return hanaTypeName, nil } -func (h *HANA) getConnection() (*sql.DB, error) { - db, err := h.Connection() +func (h *HANA) getConnection(ctx context.Context) (*sql.DB, error) { + db, err := h.Connection(ctx) if err != nil { return nil, err } @@ -74,13 +77,13 @@ func (h *HANA) getConnection() (*sql.DB, error) { // CreateUser generates the username/password on the underlying HANA secret backend // as instructed by the CreationStatement provided. -func (h *HANA) CreateUser(statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) { +func (h *HANA) CreateUser(ctx context.Context, statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) { // Grab the lock h.Lock() defer h.Unlock() // Get the connection - db, err := h.getConnection() + db, err := h.getConnection(ctx) if err != nil { return "", "", err } @@ -117,7 +120,7 @@ func (h *HANA) CreateUser(statements dbplugin.Statements, usernameConfig dbplugi } // Start a transaction - tx, err := db.Begin() + tx, err := db.BeginTx(ctx, nil) if err != nil { return "", "", err } @@ -130,7 +133,7 @@ func (h *HANA) CreateUser(statements dbplugin.Statements, usernameConfig dbplugi continue } - stmt, err := tx.Prepare(dbutil.QueryHelper(query, map[string]string{ + stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{ "name": username, "password": password, "expiration": expirationStr, @@ -139,7 +142,7 @@ func (h *HANA) CreateUser(statements dbplugin.Statements, usernameConfig dbplugi return "", "", err } defer stmt.Close() - if _, err := stmt.Exec(); err != nil { + if _, err := stmt.ExecContext(ctx); err != nil { return "", "", err } } @@ -153,15 +156,15 @@ func (h *HANA) CreateUser(statements dbplugin.Statements, usernameConfig dbplugi } // Renewing hana user just means altering user's valid until property -func (h *HANA) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error { +func (h *HANA) RenewUser(ctx context.Context, statements dbplugin.Statements, username string, expiration time.Time) error { // Get connection - db, err := h.getConnection() + db, err := h.getConnection(ctx) if err != nil { return err } // Start a transaction - tx, err := db.Begin() + tx, err := db.BeginTx(ctx, nil) if err != nil { return err } @@ -175,12 +178,12 @@ func (h *HANA) RenewUser(statements dbplugin.Statements, username string, expira } // Renew user's valid until property field - stmt, err := tx.Prepare("ALTER USER " + username + " VALID UNTIL " + "'" + expirationStr + "'") + stmt, err := tx.PrepareContext(ctx, "ALTER USER "+username+" VALID UNTIL "+"'"+expirationStr+"'") if err != nil { return err } defer stmt.Close() - if _, err := stmt.Exec(); err != nil { + if _, err := stmt.ExecContext(ctx); err != nil { return err } @@ -193,20 +196,20 @@ func (h *HANA) RenewUser(statements dbplugin.Statements, username string, expira } // Revoking hana user will deactivate user and try to perform a soft drop -func (h *HANA) RevokeUser(statements dbplugin.Statements, username string) error { +func (h *HANA) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error { // default revoke will be a soft drop on user if statements.RevocationStatements == "" { - return h.revokeUserDefault(username) + return h.revokeUserDefault(ctx, username) } // Get connection - db, err := h.getConnection() + db, err := h.getConnection(ctx) if err != nil { return err } // Start a transaction - tx, err := db.Begin() + tx, err := db.BeginTx(ctx, nil) if err != nil { return err } @@ -219,14 +222,14 @@ func (h *HANA) RevokeUser(statements dbplugin.Statements, username string) error continue } - stmt, err := tx.Prepare(dbutil.QueryHelper(query, map[string]string{ + stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{ "name": username, })) if err != nil { return err } defer stmt.Close() - if _, err := stmt.Exec(); err != nil { + if _, err := stmt.ExecContext(ctx); err != nil { return err } } @@ -239,38 +242,38 @@ func (h *HANA) RevokeUser(statements dbplugin.Statements, username string) error return nil } -func (h *HANA) revokeUserDefault(username string) error { +func (h *HANA) revokeUserDefault(ctx context.Context, username string) error { // Get connection - db, err := h.getConnection() + db, err := h.getConnection(ctx) if err != nil { return err } // Start a transaction - tx, err := db.Begin() + tx, err := db.BeginTx(ctx, nil) if err != nil { return err } defer tx.Rollback() // Disable server login for user - disableStmt, err := tx.Prepare(fmt.Sprintf("ALTER USER %s DEACTIVATE USER NOW", username)) + disableStmt, err := tx.PrepareContext(ctx, fmt.Sprintf("ALTER USER %s DEACTIVATE USER NOW", username)) if err != nil { return err } defer disableStmt.Close() - if _, err := disableStmt.Exec(); err != nil { + if _, err := disableStmt.ExecContext(ctx); err != nil { return err } // Invalidates current sessions and performs soft drop (drop if no dependencies) // if hard drop is desired, custom revoke statements should be written for role - dropStmt, err := tx.Prepare(fmt.Sprintf("DROP USER %s RESTRICT", username)) + dropStmt, err := tx.PrepareContext(ctx, fmt.Sprintf("DROP USER %s RESTRICT", username)) if err != nil { return err } defer dropStmt.Close() - if _, err := dropStmt.Exec(); err != nil { + if _, err := dropStmt.ExecContext(ctx); err != nil { return err } diff --git a/plugins/database/hana/hana_test.go b/plugins/database/hana/hana_test.go index 7cff7f1f3..8845fa3b8 100644 --- a/plugins/database/hana/hana_test.go +++ b/plugins/database/hana/hana_test.go @@ -1,6 +1,7 @@ package hana import ( + "context" "database/sql" "fmt" "os" @@ -25,7 +26,7 @@ func TestHANA_Initialize(t *testing.T) { dbRaw, _ := New() db := dbRaw.(*HANA) - err := db.Initialize(connectionDetails, true) + err := db.Initialize(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } @@ -55,7 +56,7 @@ func TestHANA_CreateUser(t *testing.T) { dbRaw, _ := New() db := dbRaw.(*HANA) - err := db.Initialize(connectionDetails, true) + err := db.Initialize(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } @@ -66,7 +67,7 @@ func TestHANA_CreateUser(t *testing.T) { } // Test with no configured Creation Statememt - _, _, err = db.CreateUser(dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Hour)) + _, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Hour)) if err == nil { t.Fatal("Expected error when no creation statement is provided") } @@ -75,7 +76,7 @@ func TestHANA_CreateUser(t *testing.T) { CreationStatements: testHANARole, } - username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Hour)) + username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Hour)) if err != nil { t.Fatalf("err: %s", err) } @@ -98,7 +99,7 @@ func TestHANA_RevokeUser(t *testing.T) { dbRaw, _ := New() db := dbRaw.(*HANA) - err := db.Initialize(connectionDetails, true) + err := db.Initialize(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } @@ -113,7 +114,7 @@ func TestHANA_RevokeUser(t *testing.T) { } // Test default revoke statememts - username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Hour)) + username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Hour)) if err != nil { t.Fatalf("err: %s", err) } @@ -121,7 +122,7 @@ func TestHANA_RevokeUser(t *testing.T) { t.Fatalf("Could not connect with new credentials: %s", err) } - err = db.RevokeUser(statements, username) + err = db.RevokeUser(context.Background(), statements, username) if err != nil { t.Fatalf("err: %s", err) } @@ -130,7 +131,7 @@ func TestHANA_RevokeUser(t *testing.T) { } // Test custom revoke statememt - username, password, err = db.CreateUser(statements, usernameConfig, time.Now().Add(time.Hour)) + username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Hour)) if err != nil { t.Fatalf("err: %s", err) } @@ -139,7 +140,7 @@ func TestHANA_RevokeUser(t *testing.T) { } statements.RevocationStatements = testHANADrop - err = db.RevokeUser(statements, username) + err = db.RevokeUser(context.Background(), statements, username) if err != nil { t.Fatalf("err: %s", err) } diff --git a/plugins/database/mongodb/connection_producer.go b/plugins/database/mongodb/connection_producer.go index a9ff64a94..9674d10bf 100644 --- a/plugins/database/mongodb/connection_producer.go +++ b/plugins/database/mongodb/connection_producer.go @@ -1,6 +1,7 @@ package mongodb import ( + "context" "crypto/tls" "encoding/base64" "encoding/json" @@ -33,7 +34,7 @@ type mongoDBConnectionProducer struct { } // Initialize parses connection configuration. -func (c *mongoDBConnectionProducer) Initialize(conf map[string]interface{}, verifyConnection bool) error { +func (c *mongoDBConnectionProducer) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error { c.Lock() defer c.Unlock() @@ -75,7 +76,7 @@ func (c *mongoDBConnectionProducer) Initialize(conf map[string]interface{}, veri c.Initialized = true if verifyConnection { - if _, err := c.Connection(); err != nil { + if _, err := c.Connection(ctx); err != nil { return fmt.Errorf("error verifying connection: %s", err) } @@ -88,7 +89,7 @@ func (c *mongoDBConnectionProducer) Initialize(conf map[string]interface{}, veri } // Connection creates a database connection. -func (c *mongoDBConnectionProducer) Connection() (interface{}, error) { +func (c *mongoDBConnectionProducer) Connection(_ context.Context) (interface{}, error) { if !c.Initialized { return nil, connutil.ErrNotInitialized } diff --git a/plugins/database/mongodb/mongodb.go b/plugins/database/mongodb/mongodb.go index 52671dae2..8b2ee802b 100644 --- a/plugins/database/mongodb/mongodb.go +++ b/plugins/database/mongodb/mongodb.go @@ -1,6 +1,7 @@ package mongodb import ( + "context" "io" "strings" "time" @@ -27,6 +28,8 @@ type MongoDB struct { credsutil.CredentialsProducer } +var _ dbplugin.Database = &MongoDB{} + // New returns a new MongoDB instance func New() (interface{}, error) { connProducer := &mongoDBConnectionProducer{} @@ -63,8 +66,8 @@ func (m *MongoDB) Type() (string, error) { return mongoDBTypeName, nil } -func (m *MongoDB) getConnection() (*mgo.Session, error) { - session, err := m.Connection() +func (m *MongoDB) getConnection(ctx context.Context) (*mgo.Session, error) { + session, err := m.Connection(ctx) if err != nil { return nil, err } @@ -80,7 +83,7 @@ func (m *MongoDB) getConnection() (*mgo.Session, error) { // // JSON Example: // { "db": "admin", "roles": [{ "role": "readWrite" }, {"role": "read", "db": "foo"}] } -func (m *MongoDB) CreateUser(statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) { +func (m *MongoDB) CreateUser(ctx context.Context, statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) { // Grab the lock m.Lock() defer m.Unlock() @@ -89,7 +92,7 @@ func (m *MongoDB) CreateUser(statements dbplugin.Statements, usernameConfig dbpl return "", "", dbutil.ErrEmptyCreationStatement } - session, err := m.getConnection() + session, err := m.getConnection(ctx) if err != nil { return "", "", err } @@ -133,7 +136,7 @@ func (m *MongoDB) CreateUser(statements dbplugin.Statements, usernameConfig dbpl if err := m.ConnectionProducer.Close(); err != nil { return "", "", errwrap.Wrapf("error closing EOF'd mongo connection: {{err}}", err) } - session, err := m.getConnection() + session, err := m.getConnection(ctx) if err != nil { return "", "", err } @@ -149,15 +152,15 @@ func (m *MongoDB) CreateUser(statements dbplugin.Statements, usernameConfig dbpl } // RenewUser is not supported on MongoDB, so this is a no-op. -func (m *MongoDB) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error { +func (m *MongoDB) RenewUser(ctx context.Context, statements dbplugin.Statements, username string, expiration time.Time) error { // NOOP return nil } // RevokeUser drops the specified user from the authentication databse. If none is provided // in the revocation statement, the default "admin" authentication database will be assumed. -func (m *MongoDB) RevokeUser(statements dbplugin.Statements, username string) error { - session, err := m.getConnection() +func (m *MongoDB) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error { + session, err := m.getConnection(ctx) if err != nil { return err } @@ -188,7 +191,7 @@ func (m *MongoDB) RevokeUser(statements dbplugin.Statements, username string) er if err := m.ConnectionProducer.Close(); err != nil { return errwrap.Wrapf("error closing EOF'd mongo connection: {{err}}", err) } - session, err := m.getConnection() + session, err := m.getConnection(ctx) if err != nil { return err } diff --git a/plugins/database/mongodb/mongodb_test.go b/plugins/database/mongodb/mongodb_test.go index 4c1eacb66..cd948af81 100644 --- a/plugins/database/mongodb/mongodb_test.go +++ b/plugins/database/mongodb/mongodb_test.go @@ -1,6 +1,7 @@ package mongodb import ( + "context" "fmt" "os" "testing" @@ -79,7 +80,7 @@ func TestMongoDB_Initialize(t *testing.T) { db := dbRaw.(*MongoDB) connProducer := db.ConnectionProducer.(*mongoDBConnectionProducer) - err = db.Initialize(connectionDetails, true) + err = db.Initialize(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } @@ -107,7 +108,7 @@ func TestMongoDB_CreateUser(t *testing.T) { t.Fatalf("err: %s", err) } db := dbRaw.(*MongoDB) - err = db.Initialize(connectionDetails, true) + err = db.Initialize(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } @@ -121,7 +122,7 @@ func TestMongoDB_CreateUser(t *testing.T) { RoleName: "test", } - username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute)) + username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute)) if err != nil { t.Fatalf("err: %s", err) } @@ -145,7 +146,7 @@ func TestMongoDB_CreateUser_writeConcern(t *testing.T) { t.Fatalf("err: %s", err) } db := dbRaw.(*MongoDB) - err = db.Initialize(connectionDetails, true) + err = db.Initialize(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } @@ -159,7 +160,7 @@ func TestMongoDB_CreateUser_writeConcern(t *testing.T) { RoleName: "test", } - username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute)) + username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute)) if err != nil { t.Fatalf("err: %s", err) } @@ -182,7 +183,7 @@ func TestMongoDB_RevokeUser(t *testing.T) { t.Fatalf("err: %s", err) } db := dbRaw.(*MongoDB) - err = db.Initialize(connectionDetails, true) + err = db.Initialize(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } @@ -196,7 +197,7 @@ func TestMongoDB_RevokeUser(t *testing.T) { RoleName: "test", } - username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute)) + username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute)) if err != nil { t.Fatalf("err: %s", err) } @@ -206,7 +207,7 @@ func TestMongoDB_RevokeUser(t *testing.T) { } // Test default revocation statememt - err = db.RevokeUser(statements, username) + err = db.RevokeUser(context.Background(), statements, username) if err != nil { t.Fatalf("err: %s", err) } diff --git a/plugins/database/mssql/mssql.go b/plugins/database/mssql/mssql.go index 7b920c8c9..27b36a0d0 100644 --- a/plugins/database/mssql/mssql.go +++ b/plugins/database/mssql/mssql.go @@ -1,6 +1,7 @@ package mssql import ( + "context" "database/sql" "fmt" "strings" @@ -18,6 +19,8 @@ import ( const msSQLTypeName = "mssql" +var _ dbplugin.Database = &MSSQL{} + // MSSQL is an implementation of Database interface type MSSQL struct { connutil.ConnectionProducer @@ -60,8 +63,8 @@ func (m *MSSQL) Type() (string, error) { return msSQLTypeName, nil } -func (m *MSSQL) getConnection() (*sql.DB, error) { - db, err := m.Connection() +func (m *MSSQL) getConnection(ctx context.Context) (*sql.DB, error) { + db, err := m.Connection(ctx) if err != nil { return nil, err } @@ -71,13 +74,13 @@ func (m *MSSQL) getConnection() (*sql.DB, error) { // CreateUser generates the username/password on the underlying MSSQL secret backend as instructed by // the CreationStatement provided. -func (m *MSSQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) { +func (m *MSSQL) CreateUser(ctx context.Context, statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) { // Grab the lock m.Lock() defer m.Unlock() // Get the connection - db, err := m.getConnection() + db, err := m.getConnection(ctx) if err != nil { return "", "", err } @@ -102,7 +105,7 @@ func (m *MSSQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplug } // Start a transaction - tx, err := db.Begin() + tx, err := db.BeginTx(ctx, nil) if err != nil { return "", "", err } @@ -115,7 +118,7 @@ func (m *MSSQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplug continue } - stmt, err := tx.Prepare(dbutil.QueryHelper(query, map[string]string{ + stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{ "name": username, "password": password, "expiration": expirationStr, @@ -124,7 +127,7 @@ func (m *MSSQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplug return "", "", err } defer stmt.Close() - if _, err := stmt.Exec(); err != nil { + if _, err := stmt.ExecContext(ctx); err != nil { return "", "", err } } @@ -138,7 +141,7 @@ func (m *MSSQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplug } // RenewUser is not supported on MSSQL, so this is a no-op. -func (m *MSSQL) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error { +func (m *MSSQL) RenewUser(ctx context.Context, statements dbplugin.Statements, username string, expiration time.Time) error { // NOOP return nil } @@ -146,19 +149,19 @@ func (m *MSSQL) RenewUser(statements dbplugin.Statements, username string, expir // RevokeUser attempts to drop the specified user. It will first attempt to disable login, // then kill pending connections from that user, and finally drop the user and login from the // database instance. -func (m *MSSQL) RevokeUser(statements dbplugin.Statements, username string) error { +func (m *MSSQL) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error { if statements.RevocationStatements == "" { - return m.revokeUserDefault(username) + return m.revokeUserDefault(ctx, username) } // Get connection - db, err := m.getConnection() + db, err := m.getConnection(ctx) if err != nil { return err } // Start a transaction - tx, err := db.Begin() + tx, err := db.BeginTx(ctx, nil) if err != nil { return err } @@ -171,14 +174,14 @@ func (m *MSSQL) RevokeUser(statements dbplugin.Statements, username string) erro continue } - stmt, err := tx.Prepare(dbutil.QueryHelper(query, map[string]string{ + stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{ "name": username, })) if err != nil { return err } defer stmt.Close() - if _, err := stmt.Exec(); err != nil { + if _, err := stmt.ExecContext(ctx); err != nil { return err } } @@ -191,20 +194,20 @@ func (m *MSSQL) RevokeUser(statements dbplugin.Statements, username string) erro return nil } -func (m *MSSQL) revokeUserDefault(username string) error { +func (m *MSSQL) revokeUserDefault(ctx context.Context, username string) error { // Get connection - db, err := m.getConnection() + db, err := m.getConnection(ctx) if err != nil { return err } // First disable server login - disableStmt, err := db.Prepare(fmt.Sprintf("ALTER LOGIN [%s] DISABLE;", username)) + disableStmt, err := db.PrepareContext(ctx, fmt.Sprintf("ALTER LOGIN [%s] DISABLE;", username)) if err != nil { return err } defer disableStmt.Close() - if _, err := disableStmt.Exec(); err != nil { + if _, err := disableStmt.ExecContext(ctx); err != nil { return err } @@ -212,14 +215,14 @@ func (m *MSSQL) revokeUserDefault(username string) error { // sessions. There cannot be any active sessions before we drop the logins // This isn't done in a transaction because even if we fail along the way, // we want to remove as much access as possible - sessionStmt, err := db.Prepare(fmt.Sprintf( + sessionStmt, err := db.PrepareContext(ctx, fmt.Sprintf( "SELECT session_id FROM sys.dm_exec_sessions WHERE login_name = '%s';", username)) if err != nil { return err } defer sessionStmt.Close() - sessionRows, err := sessionStmt.Query() + sessionRows, err := sessionStmt.QueryContext(ctx) if err != nil { return err } @@ -240,13 +243,13 @@ func (m *MSSQL) revokeUserDefault(username string) error { // we need to drop the database users before we can drop the login and the role // This isn't done in a transaction because even if we fail along the way, // we want to remove as much access as possible - stmt, err := db.Prepare(fmt.Sprintf("EXEC master.dbo.sp_msloginmappings '%s';", username)) + stmt, err := db.PrepareContext(ctx, fmt.Sprintf("EXEC master.dbo.sp_msloginmappings '%s';", username)) if err != nil { return err } defer stmt.Close() - rows, err := stmt.Query() + rows, err := stmt.QueryContext(ctx) if err != nil { return err } @@ -266,13 +269,13 @@ func (m *MSSQL) revokeUserDefault(username string) error { // many permissions as possible right now var lastStmtError error for _, query := range revokeStmts { - stmt, err := db.Prepare(query) + stmt, err := db.PrepareContext(ctx, query) if err != nil { lastStmtError = err continue } defer stmt.Close() - _, err = stmt.Exec() + _, err = stmt.ExecContext(ctx) if err != nil { lastStmtError = err } @@ -287,12 +290,12 @@ func (m *MSSQL) revokeUserDefault(username string) error { } // Drop this login - stmt, err = db.Prepare(fmt.Sprintf(dropLoginSQL, username, username)) + stmt, err = db.PrepareContext(ctx, fmt.Sprintf(dropLoginSQL, username, username)) if err != nil { return err } defer stmt.Close() - if _, err := stmt.Exec(); err != nil { + if _, err := stmt.ExecContext(ctx); err != nil { return err } diff --git a/plugins/database/mssql/mssql_test.go b/plugins/database/mssql/mssql_test.go index 5a00890bf..7d2571c3d 100644 --- a/plugins/database/mssql/mssql_test.go +++ b/plugins/database/mssql/mssql_test.go @@ -1,6 +1,7 @@ package mssql import ( + "context" "database/sql" "fmt" "os" @@ -30,7 +31,7 @@ func TestMSSQL_Initialize(t *testing.T) { dbRaw, _ := New() db := dbRaw.(*MSSQL) - err := db.Initialize(connectionDetails, true) + err := db.Initialize(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } @@ -51,7 +52,7 @@ func TestMSSQL_Initialize(t *testing.T) { "max_open_connections": "5", } - err = db.Initialize(connectionDetails, true) + err = db.Initialize(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } @@ -69,7 +70,7 @@ func TestMSSQL_CreateUser(t *testing.T) { dbRaw, _ := New() db := dbRaw.(*MSSQL) - err := db.Initialize(connectionDetails, true) + err := db.Initialize(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } @@ -80,7 +81,7 @@ func TestMSSQL_CreateUser(t *testing.T) { } // Test with no configured Creation Statememt - _, _, err = db.CreateUser(dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute)) + _, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute)) if err == nil { t.Fatal("Expected error when no creation statement is provided") } @@ -89,7 +90,7 @@ func TestMSSQL_CreateUser(t *testing.T) { CreationStatements: testMSSQLRole, } - username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute)) + username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute)) if err != nil { t.Fatalf("err: %s", err) } @@ -111,7 +112,7 @@ func TestMSSQL_RevokeUser(t *testing.T) { dbRaw, _ := New() db := dbRaw.(*MSSQL) - err := db.Initialize(connectionDetails, true) + err := db.Initialize(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } @@ -125,7 +126,7 @@ func TestMSSQL_RevokeUser(t *testing.T) { RoleName: "test", } - username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(2*time.Second)) + username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second)) if err != nil { t.Fatalf("err: %s", err) } @@ -135,7 +136,7 @@ func TestMSSQL_RevokeUser(t *testing.T) { } // Test default revoke statememts - err = db.RevokeUser(statements, username) + err = db.RevokeUser(context.Background(), statements, username) if err != nil { t.Fatalf("err: %s", err) } @@ -144,7 +145,7 @@ func TestMSSQL_RevokeUser(t *testing.T) { t.Fatal("Credentials were not revoked") } - username, password, err = db.CreateUser(statements, usernameConfig, time.Now().Add(2*time.Second)) + username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second)) if err != nil { t.Fatalf("err: %s", err) } @@ -155,7 +156,7 @@ func TestMSSQL_RevokeUser(t *testing.T) { // Test custom revoke statememt statements.RevocationStatements = testMSSQLDrop - err = db.RevokeUser(statements, username) + err = db.RevokeUser(context.Background(), statements, username) if err != nil { t.Fatalf("err: %s", err) } diff --git a/plugins/database/mysql/mysql.go b/plugins/database/mysql/mysql.go index 87289274e..38c928c35 100644 --- a/plugins/database/mysql/mysql.go +++ b/plugins/database/mysql/mysql.go @@ -1,6 +1,7 @@ package mysql import ( + "context" "database/sql" "strings" "time" @@ -30,6 +31,8 @@ var ( LegacyUsernameLen int = 16 ) +var _ dbplugin.Database = &MySQL{} + type MySQL struct { connutil.ConnectionProducer credsutil.CredentialsProducer @@ -88,8 +91,8 @@ func (m *MySQL) Type() (string, error) { return mySQLTypeName, nil } -func (m *MySQL) getConnection() (*sql.DB, error) { - db, err := m.Connection() +func (m *MySQL) getConnection(ctx context.Context) (*sql.DB, error) { + db, err := m.Connection(ctx) if err != nil { return nil, err } @@ -97,13 +100,13 @@ func (m *MySQL) getConnection() (*sql.DB, error) { return db.(*sql.DB), nil } -func (m *MySQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) { +func (m *MySQL) CreateUser(ctx context.Context, statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) { // Grab the lock m.Lock() defer m.Unlock() // Get the connection - db, err := m.getConnection() + db, err := m.getConnection(ctx) if err != nil { return "", "", err } @@ -128,7 +131,7 @@ func (m *MySQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplug } // Start a transaction - tx, err := db.Begin() + tx, err := db.BeginTx(ctx, nil) if err != nil { return "", "", err } @@ -146,7 +149,7 @@ func (m *MySQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplug "expiration": expirationStr, }) - stmt, err := tx.Prepare(query) + stmt, err := tx.PrepareContext(ctx, query) if err != nil { // If the error code we get back is Error 1295: This command is not // supported in the prepared statement protocol yet, we will execute @@ -155,7 +158,7 @@ func (m *MySQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplug // prepare supported commands. If there is no error when running we // will continue to the next statement. if e, ok := err.(*stdmysql.MySQLError); ok && e.Number == 1295 { - _, err = tx.Exec(query) + _, err = tx.ExecContext(ctx, query) if err != nil { return "", "", err } @@ -165,7 +168,7 @@ func (m *MySQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplug return "", "", err } defer stmt.Close() - if _, err := stmt.Exec(); err != nil { + if _, err := stmt.ExecContext(ctx); err != nil { return "", "", err } } @@ -179,17 +182,17 @@ func (m *MySQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplug } // NOOP -func (m *MySQL) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error { +func (m *MySQL) RenewUser(ctx context.Context, statements dbplugin.Statements, username string, expiration time.Time) error { return nil } -func (m *MySQL) RevokeUser(statements dbplugin.Statements, username string) error { +func (m *MySQL) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error { // Grab the read lock m.Lock() defer m.Unlock() // Get the connection - db, err := m.getConnection() + db, err := m.getConnection(ctx) if err != nil { return err } @@ -201,7 +204,7 @@ func (m *MySQL) RevokeUser(statements dbplugin.Statements, username string) erro } // Start a transaction - tx, err := db.Begin() + tx, err := db.BeginTx(ctx, nil) if err != nil { return err } @@ -217,7 +220,7 @@ func (m *MySQL) RevokeUser(statements dbplugin.Statements, username string) erro // 1295: This command is not supported in the prepared statement protocol yet // Reference https://mariadb.com/kb/en/mariadb/prepare-statement/ query = strings.Replace(query, "{{name}}", username, -1) - _, err = tx.Exec(query) + _, err = tx.ExecContext(ctx, query) if err != nil { return err } diff --git a/plugins/database/mysql/mysql_test.go b/plugins/database/mysql/mysql_test.go index 203158b46..fbbc87008 100644 --- a/plugins/database/mysql/mysql_test.go +++ b/plugins/database/mysql/mysql_test.go @@ -1,6 +1,7 @@ package mysql import ( + "context" "database/sql" "fmt" "os" @@ -108,7 +109,7 @@ func TestMySQL_Initialize(t *testing.T) { db := dbRaw.(*MySQL) connProducer := db.ConnectionProducer.(*connutil.SQLConnectionProducer) - err := db.Initialize(connectionDetails, true) + err := db.Initialize(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } @@ -128,7 +129,7 @@ func TestMySQL_Initialize(t *testing.T) { "max_open_connections": "5", } - err = db.Initialize(connectionDetails, true) + err = db.Initialize(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } @@ -146,7 +147,7 @@ func TestMySQL_CreateUser(t *testing.T) { dbRaw, _ := f() db := dbRaw.(*MySQL) - err := db.Initialize(connectionDetails, true) + err := db.Initialize(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } @@ -157,7 +158,7 @@ func TestMySQL_CreateUser(t *testing.T) { } // Test with no configured Creation Statememt - _, _, err = db.CreateUser(dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute)) + _, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute)) if err == nil { t.Fatal("Expected error when no creation statement is provided") } @@ -166,7 +167,7 @@ func TestMySQL_CreateUser(t *testing.T) { CreationStatements: testMySQLRoleWildCard, } - username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute)) + username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute)) if err != nil { t.Fatalf("err: %s", err) } @@ -176,7 +177,7 @@ func TestMySQL_CreateUser(t *testing.T) { } // Test a second time to make sure usernames don't collide - username, password, err = db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute)) + username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute)) if err != nil { t.Fatalf("err: %s", err) } @@ -188,7 +189,7 @@ func TestMySQL_CreateUser(t *testing.T) { // Test with a manualy prepare statement statements.CreationStatements = testMySQLRolePreparedStmt - username, password, err = db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute)) + username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute)) if err != nil { t.Fatalf("err: %s", err) } @@ -211,7 +212,7 @@ func TestMySQL_CreateUser_Legacy(t *testing.T) { dbRaw, _ := f() db := dbRaw.(*MySQL) - err := db.Initialize(connectionDetails, true) + err := db.Initialize(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } @@ -222,7 +223,7 @@ func TestMySQL_CreateUser_Legacy(t *testing.T) { } // Test with no configured Creation Statememt - _, _, err = db.CreateUser(dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute)) + _, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute)) if err == nil { t.Fatal("Expected error when no creation statement is provided") } @@ -231,7 +232,7 @@ func TestMySQL_CreateUser_Legacy(t *testing.T) { CreationStatements: testMySQLRoleWildCard, } - username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute)) + username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute)) if err != nil { t.Fatalf("err: %s", err) } @@ -241,7 +242,7 @@ func TestMySQL_CreateUser_Legacy(t *testing.T) { } // Test a second time to make sure usernames don't collide - username, password, err = db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute)) + username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute)) if err != nil { t.Fatalf("err: %s", err) } @@ -263,7 +264,7 @@ func TestMySQL_RevokeUser(t *testing.T) { dbRaw, _ := f() db := dbRaw.(*MySQL) - err := db.Initialize(connectionDetails, true) + err := db.Initialize(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } @@ -277,7 +278,7 @@ func TestMySQL_RevokeUser(t *testing.T) { RoleName: "test", } - username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute)) + username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute)) if err != nil { t.Fatalf("err: %s", err) } @@ -287,7 +288,7 @@ func TestMySQL_RevokeUser(t *testing.T) { } // Test default revoke statememts - err = db.RevokeUser(statements, username) + err = db.RevokeUser(context.Background(), statements, username) if err != nil { t.Fatalf("err: %s", err) } @@ -297,7 +298,7 @@ func TestMySQL_RevokeUser(t *testing.T) { } statements.CreationStatements = testMySQLRoleWildCard - username, password, err = db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute)) + username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute)) if err != nil { t.Fatalf("err: %s", err) } @@ -308,7 +309,7 @@ func TestMySQL_RevokeUser(t *testing.T) { // Test custom revoke statements statements.RevocationStatements = testMySQLRevocationSQL - err = db.RevokeUser(statements, username) + err = db.RevokeUser(context.Background(), statements, username) if err != nil { t.Fatalf("err: %s", err) } diff --git a/plugins/database/postgresql/postgresql.go b/plugins/database/postgresql/postgresql.go index 93fa8a854..f2e20d3f4 100644 --- a/plugins/database/postgresql/postgresql.go +++ b/plugins/database/postgresql/postgresql.go @@ -1,6 +1,7 @@ package postgresql import ( + "context" "database/sql" "fmt" "strings" @@ -24,6 +25,8 @@ ALTER ROLE "{{name}}" VALID UNTIL '{{expiration}}'; ` ) +var _ dbplugin.Database = &PostgreSQL{} + // New implements builtinplugins.BuiltinFactory func New() (interface{}, error) { connProducer := &connutil.SQLConnectionProducer{} @@ -65,8 +68,8 @@ func (p *PostgreSQL) Type() (string, error) { return postgreSQLTypeName, nil } -func (p *PostgreSQL) getConnection() (*sql.DB, error) { - db, err := p.Connection() +func (p *PostgreSQL) getConnection(ctx context.Context) (*sql.DB, error) { + db, err := p.Connection(ctx) if err != nil { return nil, err } @@ -74,7 +77,7 @@ func (p *PostgreSQL) getConnection() (*sql.DB, error) { return db.(*sql.DB), nil } -func (p *PostgreSQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) { +func (p *PostgreSQL) CreateUser(ctx context.Context, statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) { if statements.CreationStatements == "" { return "", "", dbutil.ErrEmptyCreationStatement } @@ -99,14 +102,14 @@ func (p *PostgreSQL) CreateUser(statements dbplugin.Statements, usernameConfig d } // Get the connection - db, err := p.getConnection() + db, err := p.getConnection(ctx) if err != nil { return "", "", err } // Start a transaction - tx, err := db.Begin() + tx, err := db.BeginTx(ctx, nil) if err != nil { return "", "", err @@ -123,7 +126,7 @@ func (p *PostgreSQL) CreateUser(statements dbplugin.Statements, usernameConfig d continue } - stmt, err := tx.Prepare(dbutil.QueryHelper(query, map[string]string{ + stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{ "name": username, "password": password, "expiration": expirationStr, @@ -133,7 +136,7 @@ func (p *PostgreSQL) CreateUser(statements dbplugin.Statements, usernameConfig d } defer stmt.Close() - if _, err := stmt.Exec(); err != nil { + if _, err := stmt.ExecContext(ctx); err != nil { return "", "", err } @@ -148,7 +151,7 @@ func (p *PostgreSQL) CreateUser(statements dbplugin.Statements, usernameConfig d return username, password, nil } -func (p *PostgreSQL) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error { +func (p *PostgreSQL) RenewUser(ctx context.Context, statements dbplugin.Statements, username string, expiration time.Time) error { p.Lock() defer p.Unlock() @@ -157,12 +160,12 @@ func (p *PostgreSQL) RenewUser(statements dbplugin.Statements, username string, renewStmts = defaultPostgresRenewSQL } - db, err := p.getConnection() + db, err := p.getConnection(ctx) if err != nil { return err } - tx, err := db.Begin() + tx, err := db.BeginTx(ctx, nil) if err != nil { return err } @@ -180,7 +183,7 @@ func (p *PostgreSQL) RenewUser(statements dbplugin.Statements, username string, if len(query) == 0 { continue } - stmt, err := tx.Prepare(dbutil.QueryHelper(query, map[string]string{ + stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{ "name": username, "expiration": expirationStr, })) @@ -189,7 +192,7 @@ func (p *PostgreSQL) RenewUser(statements dbplugin.Statements, username string, } defer stmt.Close() - if _, err := stmt.Exec(); err != nil { + if _, err := stmt.ExecContext(ctx); err != nil { return err } } @@ -201,25 +204,25 @@ func (p *PostgreSQL) RenewUser(statements dbplugin.Statements, username string, return nil } -func (p *PostgreSQL) RevokeUser(statements dbplugin.Statements, username string) error { +func (p *PostgreSQL) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error { // Grab the lock p.Lock() defer p.Unlock() if statements.RevocationStatements == "" { - return p.defaultRevokeUser(username) + return p.defaultRevokeUser(ctx, username) } - return p.customRevokeUser(username, statements.RevocationStatements) + return p.customRevokeUser(ctx, username, statements.RevocationStatements) } -func (p *PostgreSQL) customRevokeUser(username, revocationStmts string) error { - db, err := p.getConnection() +func (p *PostgreSQL) customRevokeUser(ctx context.Context, username, revocationStmts string) error { + db, err := p.getConnection(ctx) if err != nil { return err } - tx, err := db.Begin() + tx, err := db.BeginTx(ctx, nil) if err != nil { return err } @@ -233,7 +236,7 @@ func (p *PostgreSQL) customRevokeUser(username, revocationStmts string) error { continue } - stmt, err := tx.Prepare(dbutil.QueryHelper(query, map[string]string{ + stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{ "name": username, })) if err != nil { @@ -241,7 +244,7 @@ func (p *PostgreSQL) customRevokeUser(username, revocationStmts string) error { } defer stmt.Close() - if _, err := stmt.Exec(); err != nil { + if _, err := stmt.ExecContext(ctx); err != nil { return err } } @@ -253,15 +256,15 @@ func (p *PostgreSQL) customRevokeUser(username, revocationStmts string) error { return nil } -func (p *PostgreSQL) defaultRevokeUser(username string) error { - db, err := p.getConnection() +func (p *PostgreSQL) defaultRevokeUser(ctx context.Context, username string) error { + db, err := p.getConnection(ctx) if err != nil { return err } // Check if the role exists var exists bool - err = db.QueryRow("SELECT exists (SELECT rolname FROM pg_roles WHERE rolname=$1);", username).Scan(&exists) + err = db.QueryRowContext(ctx, "SELECT exists (SELECT rolname FROM pg_roles WHERE rolname=$1);", username).Scan(&exists) if err != nil && err != sql.ErrNoRows { return err } @@ -274,13 +277,13 @@ func (p *PostgreSQL) defaultRevokeUser(username string) error { // the role // This isn't done in a transaction because even if we fail along the way, // we want to remove as much access as possible - stmt, err := db.Prepare("SELECT DISTINCT table_schema FROM information_schema.role_column_grants WHERE grantee=$1;") + stmt, err := db.PrepareContext(ctx, "SELECT DISTINCT table_schema FROM information_schema.role_column_grants WHERE grantee=$1;") if err != nil { return err } defer stmt.Close() - rows, err := stmt.Query(username) + rows, err := stmt.QueryContext(ctx, username) if err != nil { return err } @@ -322,7 +325,7 @@ func (p *PostgreSQL) defaultRevokeUser(username string) error { // get the current database name so we can issue a REVOKE CONNECT for // this username var dbname sql.NullString - if err := db.QueryRow("SELECT current_database();").Scan(&dbname); err != nil { + if err := db.QueryRowContext(ctx, "SELECT current_database();").Scan(&dbname); err != nil { return err } @@ -337,13 +340,13 @@ func (p *PostgreSQL) defaultRevokeUser(username string) error { // many permissions as possible right now var lastStmtError error for _, query := range revocationStmts { - stmt, err := db.Prepare(query) + stmt, err := db.PrepareContext(ctx, query) if err != nil { lastStmtError = err continue } defer stmt.Close() - _, err = stmt.Exec() + _, err = stmt.ExecContext(ctx) if err != nil { lastStmtError = err } @@ -358,13 +361,13 @@ func (p *PostgreSQL) defaultRevokeUser(username string) error { } // Drop this user - stmt, err = db.Prepare(fmt.Sprintf( + stmt, err = db.PrepareContext(ctx, fmt.Sprintf( `DROP ROLE IF EXISTS %s;`, pq.QuoteIdentifier(username))) if err != nil { return err } defer stmt.Close() - if _, err := stmt.Exec(); err != nil { + if _, err := stmt.ExecContext(ctx); err != nil { return err } diff --git a/plugins/database/postgresql/postgresql_test.go b/plugins/database/postgresql/postgresql_test.go index 1e53d9118..8f4ebb67a 100644 --- a/plugins/database/postgresql/postgresql_test.go +++ b/plugins/database/postgresql/postgresql_test.go @@ -1,6 +1,7 @@ package postgresql import ( + "context" "database/sql" "fmt" "os" @@ -72,7 +73,7 @@ func TestPostgreSQL_Initialize(t *testing.T) { connProducer := db.ConnectionProducer.(*connutil.SQLConnectionProducer) - err := db.Initialize(connectionDetails, true) + err := db.Initialize(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } @@ -92,7 +93,7 @@ func TestPostgreSQL_Initialize(t *testing.T) { "max_open_connections": "5", } - err = db.Initialize(connectionDetails, true) + err = db.Initialize(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } @@ -109,7 +110,7 @@ func TestPostgreSQL_CreateUser(t *testing.T) { dbRaw, _ := New() db := dbRaw.(*PostgreSQL) - err := db.Initialize(connectionDetails, true) + err := db.Initialize(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } @@ -120,7 +121,7 @@ func TestPostgreSQL_CreateUser(t *testing.T) { } // Test with no configured Creation Statememt - _, _, err = db.CreateUser(dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute)) + _, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute)) if err == nil { t.Fatal("Expected error when no creation statement is provided") } @@ -129,7 +130,7 @@ func TestPostgreSQL_CreateUser(t *testing.T) { CreationStatements: testPostgresRole, } - username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute)) + username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute)) if err != nil { t.Fatalf("err: %s", err) } @@ -139,7 +140,7 @@ func TestPostgreSQL_CreateUser(t *testing.T) { } statements.CreationStatements = testPostgresReadOnlyRole - username, password, err = db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute)) + username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute)) if err != nil { t.Fatalf("err: %s", err) } @@ -162,7 +163,7 @@ func TestPostgreSQL_RenewUser(t *testing.T) { dbRaw, _ := New() db := dbRaw.(*PostgreSQL) - err := db.Initialize(connectionDetails, true) + err := db.Initialize(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } @@ -176,7 +177,7 @@ func TestPostgreSQL_RenewUser(t *testing.T) { RoleName: "test", } - username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(2*time.Second)) + username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second)) if err != nil { t.Fatalf("err: %s", err) } @@ -185,7 +186,7 @@ func TestPostgreSQL_RenewUser(t *testing.T) { t.Fatalf("Could not connect with new credentials: %s", err) } - err = db.RenewUser(statements, username, time.Now().Add(time.Minute)) + err = db.RenewUser(context.Background(), statements, username, time.Now().Add(time.Minute)) if err != nil { t.Fatalf("err: %s", err) } @@ -197,7 +198,7 @@ func TestPostgreSQL_RenewUser(t *testing.T) { t.Fatalf("Could not connect with new credentials: %s", err) } statements.RenewStatements = defaultPostgresRenewSQL - username, password, err = db.CreateUser(statements, usernameConfig, time.Now().Add(2*time.Second)) + username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second)) if err != nil { t.Fatalf("err: %s", err) } @@ -206,7 +207,7 @@ func TestPostgreSQL_RenewUser(t *testing.T) { t.Fatalf("Could not connect with new credentials: %s", err) } - err = db.RenewUser(statements, username, time.Now().Add(time.Minute)) + err = db.RenewUser(context.Background(), statements, username, time.Now().Add(time.Minute)) if err != nil { t.Fatalf("err: %s", err) } @@ -230,7 +231,7 @@ func TestPostgreSQL_RevokeUser(t *testing.T) { dbRaw, _ := New() db := dbRaw.(*PostgreSQL) - err := db.Initialize(connectionDetails, true) + err := db.Initialize(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } @@ -244,7 +245,7 @@ func TestPostgreSQL_RevokeUser(t *testing.T) { RoleName: "test", } - username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(2*time.Second)) + username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second)) if err != nil { t.Fatalf("err: %s", err) } @@ -254,7 +255,7 @@ func TestPostgreSQL_RevokeUser(t *testing.T) { } // Test default revoke statememts - err = db.RevokeUser(statements, username) + err = db.RevokeUser(context.Background(), statements, username) if err != nil { t.Fatalf("err: %s", err) } @@ -263,7 +264,7 @@ func TestPostgreSQL_RevokeUser(t *testing.T) { t.Fatal("Credentials were not revoked") } - username, password, err = db.CreateUser(statements, usernameConfig, time.Now().Add(2*time.Second)) + username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second)) if err != nil { t.Fatalf("err: %s", err) } @@ -274,7 +275,7 @@ func TestPostgreSQL_RevokeUser(t *testing.T) { // Test custom revoke statements statements.RevocationStatements = defaultPostgresRevocationSQL - err = db.RevokeUser(statements, username) + err = db.RevokeUser(context.Background(), statements, username) if err != nil { t.Fatalf("err: %s", err) } diff --git a/plugins/helper/database/connutil/connutil.go b/plugins/helper/database/connutil/connutil.go index d36d5719d..7cf23c5c3 100644 --- a/plugins/helper/database/connutil/connutil.go +++ b/plugins/helper/database/connutil/connutil.go @@ -1,6 +1,7 @@ package connutil import ( + "context" "errors" "sync" ) @@ -14,8 +15,8 @@ var ( // connections and is used in all the builtin database types. type ConnectionProducer interface { Close() error - Initialize(map[string]interface{}, bool) error - Connection() (interface{}, error) + Initialize(context.Context, map[string]interface{}, bool) error + Connection(context.Context) (interface{}, error) sync.Locker } diff --git a/plugins/helper/database/connutil/sql.go b/plugins/helper/database/connutil/sql.go index c325cbc18..2e34065d0 100644 --- a/plugins/helper/database/connutil/sql.go +++ b/plugins/helper/database/connutil/sql.go @@ -1,6 +1,7 @@ package connutil import ( + "context" "database/sql" "fmt" "strings" @@ -25,7 +26,7 @@ type SQLConnectionProducer struct { sync.Mutex } -func (c *SQLConnectionProducer) Initialize(conf map[string]interface{}, verifyConnection bool) error { +func (c *SQLConnectionProducer) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error { c.Lock() defer c.Unlock() @@ -62,11 +63,11 @@ func (c *SQLConnectionProducer) Initialize(conf map[string]interface{}, verifyCo c.Initialized = true if verifyConnection { - if _, err := c.Connection(); err != nil { + if _, err := c.Connection(ctx); err != nil { return fmt.Errorf("error verifying connection: %s", err) } - if err := c.db.Ping(); err != nil { + if err := c.db.PingContext(ctx); err != nil { return fmt.Errorf("error verifying connection: %s", err) } } @@ -74,14 +75,14 @@ func (c *SQLConnectionProducer) Initialize(conf map[string]interface{}, verifyCo return nil } -func (c *SQLConnectionProducer) Connection() (interface{}, error) { +func (c *SQLConnectionProducer) Connection(ctx context.Context) (interface{}, error) { if !c.Initialized { return nil, ErrNotInitialized } // If we already have a DB, test it and return if c.db != nil { - if err := c.db.Ping(); err == nil { + if err := c.db.PingContext(ctx); err == nil { return c.db, nil } // If the ping was unsuccessful, close it and ignore errors as we'll be diff --git a/vault/auth.go b/vault/auth.go index db11ce8fc..8a1f5d6c9 100644 --- a/vault/auth.go +++ b/vault/auth.go @@ -5,7 +5,6 @@ import ( "fmt" "strings" - "github.com/hashicorp/errwrap" "github.com/hashicorp/go-uuid" "github.com/hashicorp/vault/helper/consts" "github.com/hashicorp/vault/helper/jsonutil" @@ -460,10 +459,11 @@ func (c *Core) setupCredentials() error { backend, err = c.newCredentialBackend(entry.Type, sysView, view, conf) if err != nil { c.logger.Error("core: failed to create credential entry", "path", entry.Path, "error", err) - if errwrap.Contains(err, ErrPluginNotFound.Error()) && entry.Type == "plugin" { - // If we encounter an error instantiating the backend due to it being missing from the catalog, - // skip backend initialization but register the entry to the mount table to preserve storage - // and path. + if entry.Type == "plugin" { + // If we encounter an error instantiating the backend due to an error, + // skip backend initialization but register the entry to the mount table + // to preserve storage and path. + c.logger.Warn("core: skipping plugin-based credential entry", "path", entry.Path) goto ROUTER_MOUNT } return errLoadAuthFailed diff --git a/vault/expiration.go b/vault/expiration.go index 23176ae44..b42f334d5 100644 --- a/vault/expiration.go +++ b/vault/expiration.go @@ -750,6 +750,31 @@ func (m *ExpirationManager) RenewToken(req *logical.Request, source string, toke }, nil } + sysView := m.router.MatchingSystemView(le.Path) + if sysView == nil { + return nil, fmt.Errorf("expiration: unable to retrieve system view from router") + } + + retResp := &logical.Response{} + switch { + case resp.Auth.Period > time.Duration(0): + // If it resp.Period is non-zero, use that as the TTL and override backend's + // call on TTL modification, such as a TTL value determined by + // framework.LeaseExtend call against the request. Also, cap period value to + // the sys/mount max value. + if resp.Auth.Period > sysView.MaxLeaseTTL() { + retResp.AddWarning(fmt.Sprintf("Period of %d seconds is greater than current mount/system default of %d seconds, value will be truncated.", resp.Auth.TTL, sysView.MaxLeaseTTL())) + resp.Auth.Period = sysView.MaxLeaseTTL() + } + resp.Auth.TTL = resp.Auth.Period + case resp.Auth.TTL > time.Duration(0): + // Cap TTL value to the sys/mount max value + if resp.Auth.TTL > sysView.MaxLeaseTTL() { + retResp.AddWarning(fmt.Sprintf("TTL of %d seconds is greater than current mount/system default of %d seconds, value will be truncated.", resp.Auth.TTL, sysView.MaxLeaseTTL())) + resp.Auth.TTL = sysView.MaxLeaseTTL() + } + } + // Attach the ClientToken resp.Auth.ClientToken = token resp.Auth.Increment = 0 @@ -764,9 +789,9 @@ func (m *ExpirationManager) RenewToken(req *logical.Request, source string, toke // Update the expiration time m.updatePending(le, resp.Auth.LeaseTotal()) - return &logical.Response{ - Auth: resp.Auth, - }, nil + + retResp.Auth = resp.Auth + return retResp, nil } // Register is used to take a request and response with an associated @@ -866,6 +891,12 @@ func (m *ExpirationManager) RegisterAuth(source string, auth *logical.Auth) erro return err } + // If it resp.Period is non-zero, override the TTL value determined + // by the backend. + if auth.Period > time.Duration(0) { + auth.TTL = auth.Period + } + // Create a lease entry le := leaseEntry{ LeaseID: path.Join(source, saltedID), @@ -1017,8 +1048,7 @@ func (m *ExpirationManager) revokeEntry(le *leaseEntry) error { } // Handle standard revocation via backends - resp, err := m.router.Route(logical.RevokeRequest( - le.Path, le.Secret, le.Data)) + resp, err := m.router.Route(logical.RevokeRequest(le.Path, le.Secret, le.Data)) if err != nil || (resp != nil && resp.IsError()) { return fmt.Errorf("failed to revoke entry: resp:%#v err:%s", resp, err) } diff --git a/vault/logical_system_integ_test.go b/vault/logical_system_integ_test.go index 60eab6b69..403d4892b 100644 --- a/vault/logical_system_integ_test.go +++ b/vault/logical_system_integ_test.go @@ -2,7 +2,9 @@ package vault_test import ( "fmt" + "io/ioutil" "os" + "path/filepath" "testing" "time" @@ -178,14 +180,13 @@ func testPlugin_CatalogRemoved(t *testing.T, btype logical.BackendType, testMoun } if testMount { - // Add plugin back to the catalog - vault.TestAddTestPlugin(t, core.Core, "mock-plugin", "TestBackend_PluginMainLogical") - // Mount the plugin at the same path after plugin is re-added to the catalog // and expect an error due to existing path. var err error switch btype { case logical.TypeLogical: + // Add plugin back to the catalog + vault.TestAddTestPlugin(t, core.Core, "mock-plugin", "TestBackend_PluginMainLogical") _, err = core.Client.Logical().Write("sys/mounts/mock-0", map[string]interface{}{ "type": "plugin", "config": map[string]interface{}{ @@ -193,6 +194,8 @@ func testPlugin_CatalogRemoved(t *testing.T, btype logical.BackendType, testMoun }, }) case logical.TypeCredential: + // Add plugin back to the catalog + vault.TestAddTestPlugin(t, core.Core, "mock-plugin", "TestBackend_PluginMainCredentials") _, err = core.Client.Logical().Write("sys/auth/mock-0", map[string]interface{}{ "type": "plugin", "plugin_name": "mock-plugin", @@ -204,6 +207,129 @@ func testPlugin_CatalogRemoved(t *testing.T, btype logical.BackendType, testMoun } } +func TestSystemBackend_Plugin_continueOnError(t *testing.T) { + t.Run("secret", func(t *testing.T) { + t.Run("sha256_mismatch", func(t *testing.T) { + testPlugin_continueOnError(t, logical.TypeLogical, true) + }) + + t.Run("missing_plugin", func(t *testing.T) { + testPlugin_continueOnError(t, logical.TypeLogical, false) + }) + }) + + t.Run("auth", func(t *testing.T) { + t.Run("sha256_mismatch", func(t *testing.T) { + testPlugin_continueOnError(t, logical.TypeCredential, true) + }) + + t.Run("missing_plugin", func(t *testing.T) { + testPlugin_continueOnError(t, logical.TypeCredential, false) + }) + }) +} + +func testPlugin_continueOnError(t *testing.T, btype logical.BackendType, mismatch bool) { + cluster := testSystemBackendMock(t, 1, 1, btype) + defer cluster.Cleanup() + + core := cluster.Cores[0] + + // Get the registered plugin + req := logical.TestRequest(t, logical.ReadOperation, "sys/plugins/catalog/mock-plugin") + req.ClientToken = core.Client.Token() + resp, err := core.HandleRequest(req) + if err != nil || resp == nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%v resp:%#v", err, resp) + } + + command, ok := resp.Data["command"].(string) + if !ok || command == "" { + t.Fatal("invalid command") + } + + // Trigger a sha256 mistmatch or missing plugin error + if mismatch { + req = logical.TestRequest(t, logical.UpdateOperation, "sys/plugins/catalog/mock-plugin") + req.Data = map[string]interface{}{ + "sha256": "d17bd7334758e53e6fbab15745d2520765c06e296f2ce8e25b7919effa0ac216", + "command": filepath.Base(command), + } + req.ClientToken = core.Client.Token() + resp, err = core.HandleRequest(req) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%v resp:%#v", err, resp) + } + } else { + err := os.Remove(filepath.Join(cluster.TempDir, filepath.Base(command))) + if err != nil { + t.Fatal(err) + } + } + + // Seal the cluster + cluster.EnsureCoresSealed(t) + + // Unseal the cluster + barrierKeys := cluster.BarrierKeys + for _, core := range cluster.Cores { + for _, key := range barrierKeys { + _, err := core.Unseal(vault.TestKeyCopy(key)) + if err != nil { + t.Fatal(err) + } + } + sealed, err := core.Sealed() + if err != nil { + t.Fatalf("err checking seal status: %s", err) + } + if sealed { + t.Fatal("should not be sealed") + } + // Wait for active so post-unseal takes place + // If it fails, it means unseal process failed + vault.TestWaitActive(t, core.Core) + } + + // Re-add the plugin to the catalog + switch btype { + case logical.TypeLogical: + vault.TestAddTestPluginTempDir(t, core.Core, "mock-plugin", "TestBackend_PluginMainLogical", cluster.TempDir) + case logical.TypeCredential: + vault.TestAddTestPluginTempDir(t, core.Core, "mock-plugin", "TestBackend_PluginMainCredentials", cluster.TempDir) + } + + // Reload the plugin + req = logical.TestRequest(t, logical.UpdateOperation, "sys/plugins/reload/backend") + req.Data = map[string]interface{}{ + "plugin": "mock-plugin", + } + req.ClientToken = core.Client.Token() + resp, err = core.HandleRequest(req) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%v resp:%#v", err, resp) + } + + // Make a request to lazy load the plugin + var reqPath string + switch btype { + case logical.TypeLogical: + reqPath = "mock-0/internal" + case logical.TypeCredential: + reqPath = "auth/mock-0/internal" + } + + req = logical.TestRequest(t, logical.ReadOperation, reqPath) + req.ClientToken = core.Client.Token() + resp, err = core.HandleRequest(req) + if err != nil { + t.Fatalf("err: %v", err) + } + if resp == nil { + t.Fatalf("bad: response should not be nil") + } +} + func TestSystemBackend_Plugin_autoReload(t *testing.T) { cluster := testSystemBackendMock(t, 1, 1, logical.TypeLogical) defer cluster.Cleanup() @@ -332,7 +458,10 @@ func testSystemBackend_PluginReload(t *testing.T, reqData map[string]interface{} } // testSystemBackendMock returns a systemBackend with the desired number -// of mounted mock plugin backends +// of mounted mock plugin backends. numMounts alternates between different +// ways of providing the plugin_name. +// +// The mounts are mounted at sys/mounts/mock-[numMounts] or sys/auth/mock-[numMounts] func testSystemBackendMock(t *testing.T, numCores, numMounts int, backendType logical.BackendType) *vault.TestCluster { coreConfig := &vault.CoreConfig{ LogicalBackends: map[string]logical.Factory{ @@ -343,10 +472,17 @@ func testSystemBackendMock(t *testing.T, numCores, numMounts int, backendType lo }, } + // Create a tempdir, cluster.Cleanup will clean up this directory + tempDir, err := ioutil.TempDir("", "vault-test-cluster") + if err != nil { + t.Fatal(err) + } + cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{ HandlerFunc: vaulthttp.Handler, KeepStandbysSealed: true, NumCores: numCores, + TempDir: tempDir, }) cluster.Start() @@ -358,7 +494,7 @@ func testSystemBackendMock(t *testing.T, numCores, numMounts int, backendType lo switch backendType { case logical.TypeLogical: - vault.TestAddTestPlugin(t, core.Core, "mock-plugin", "TestBackend_PluginMainLogical") + vault.TestAddTestPluginTempDir(t, core.Core, "mock-plugin", "TestBackend_PluginMainLogical", tempDir) for i := 0; i < numMounts; i++ { // Alternate input styles for plugin_name on every other mount options := map[string]interface{}{ @@ -380,7 +516,7 @@ func testSystemBackendMock(t *testing.T, numCores, numMounts int, backendType lo } } case logical.TypeCredential: - vault.TestAddTestPlugin(t, core.Core, "mock-plugin", "TestBackend_PluginMainCredentials") + vault.TestAddTestPluginTempDir(t, core.Core, "mock-plugin", "TestBackend_PluginMainCredentials", tempDir) for i := 0; i < numMounts; i++ { // Alternate input styles for plugin_name on every other mount options := map[string]interface{}{ diff --git a/vault/mount.go b/vault/mount.go index e4aea2100..7e3cab478 100644 --- a/vault/mount.go +++ b/vault/mount.go @@ -7,7 +7,6 @@ import ( "strings" "time" - "github.com/hashicorp/errwrap" "github.com/hashicorp/go-uuid" "github.com/hashicorp/vault/helper/consts" "github.com/hashicorp/vault/helper/jsonutil" @@ -753,10 +752,11 @@ func (c *Core) setupMounts() error { backend, err = c.newLogicalBackend(entry.Type, sysView, view, conf) if err != nil { c.logger.Error("core: failed to create mount entry", "path", entry.Path, "error", err) - if errwrap.Contains(err, ErrPluginNotFound.Error()) && entry.Type == "plugin" { - // If we encounter an error instantiating the backend due to it being missing from the catalog, - // skip backend initialization but register the entry to the mount table to preserve storage - // and path. + if entry.Type == "plugin" { + // If we encounter an error instantiating the backend due to an error, + // skip backend initialization but register the entry to the mount table + // to preserve storage and path. + c.logger.Warn("core: skipping plugin-based mount entry", "path", entry.Path) goto ROUTER_MOUNT } return errLoadMountsFailed diff --git a/vault/plugin_reload.go b/vault/plugin_reload.go index eaff18b48..8f699557c 100644 --- a/vault/plugin_reload.go +++ b/vault/plugin_reload.go @@ -79,15 +79,23 @@ func (c *Core) reloadMatchingPlugin(pluginName string) error { func (c *Core) reloadPluginCommon(entry *MountEntry, isAuth bool) error { path := entry.Path + if isAuth { + path = credentialRoutePrefix + path + } + // Fast-path out if the backend doesn't exist raw, ok := c.router.root.Get(path) if !ok { return nil } - // Call backend's Cleanup routine re := raw.(*routeEntry) - re.backend.Cleanup() + + // Only call Cleanup if backend is initialized + if re.backend != nil { + // Call backend's Cleanup routine + re.backend.Cleanup() + } view := re.storageView diff --git a/vault/request_handling.go b/vault/request_handling.go index 7fb7ea1d3..3453ff1b0 100644 --- a/vault/request_handling.go +++ b/vault/request_handling.go @@ -477,14 +477,27 @@ func (c *Core) handleLoginRequest(req *logical.Request) (retResp *logical.Respon return nil, nil, ErrInternalError } - // Set the default lease if not provided - if auth.TTL == 0 { - auth.TTL = sysView.DefaultLeaseTTL() - } + // Start off with the sys default value, and update according to period/TTL + // from resp.Auth + tokenTTL := sysView.DefaultLeaseTTL() - // Limit the lease duration - if auth.TTL > sysView.MaxLeaseTTL() { - auth.TTL = sysView.MaxLeaseTTL() + switch { + case auth.Period > time.Duration(0): + // Cap the period value to the sys max_ttl value. The auth backend should + // have checked for it on its login path, but we check here again for + // sanity. + if auth.Period > sysView.MaxLeaseTTL() { + auth.Period = sysView.MaxLeaseTTL() + } + tokenTTL = auth.Period + case auth.TTL > time.Duration(0): + // Cap the TTL value. The auth backend should have checked for it on its + // login path (e.g. a call to b.SanitizeTTL), but we check here again for + // sanity. + if auth.TTL > sysView.MaxLeaseTTL() { + auth.TTL = sysView.MaxLeaseTTL() + } + tokenTTL = auth.TTL } // Generate a token @@ -494,7 +507,7 @@ func (c *Core) handleLoginRequest(req *logical.Request) (retResp *logical.Respon Meta: auth.Metadata, DisplayName: auth.DisplayName, CreationTime: time.Now().Unix(), - TTL: auth.TTL, + TTL: tokenTTL, NumUses: auth.NumUses, EntityID: auth.EntityID, } @@ -513,10 +526,11 @@ func (c *Core) handleLoginRequest(req *logical.Request) (retResp *logical.Respon return nil, auth, ErrInternalError } - // Populate the client token and accessor + // Populate the client token, accessor, and TTL auth.ClientToken = te.ID auth.Accessor = te.Accessor auth.Policies = te.Policies + auth.TTL = te.TTL // Register with the expiration manager if err := c.expiration.RegisterAuth(te.Path, auth); err != nil { diff --git a/vault/testing.go b/vault/testing.go index 8fddbd1bc..a1e469ed2 100644 --- a/vault/testing.go +++ b/vault/testing.go @@ -379,6 +379,8 @@ func TestDynamicSystemView(c *Core) *dynamicSystemView { return &dynamicSystemView{c, me} } +// TestAddTestPlugin registers the testFunc as part of the plugin command to the +// plugin catalog. func TestAddTestPlugin(t testing.T, c *Core, name, testFunc string) { file, err := os.Open(os.Args[0]) if err != nil { @@ -413,11 +415,74 @@ func TestAddTestPlugin(t testing.T, c *Core, name, testFunc string) { } } +// TestAddTestPluginTempDir registers the testFunc as part of the plugin command to the +// plugin catalog. It uses tmpDir as the plugin directory. +func TestAddTestPluginTempDir(t testing.T, c *Core, name, testFunc, tempDir string) { + file, err := os.Open(os.Args[0]) + if err != nil { + t.Fatal(err) + } + defer file.Close() + + fi, err := file.Stat() + if err != nil { + t.Fatal(err) + } + + // Copy over the file to the temp dir + dst := filepath.Join(tempDir, filepath.Base(os.Args[0])) + out, err := os.OpenFile(dst, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, fi.Mode()) + if err != nil { + t.Fatal(err) + } + defer out.Close() + + if _, err = io.Copy(out, file); err != nil { + t.Fatal(err) + } + err = out.Sync() + if err != nil { + t.Fatal(err) + } + + // Determine plugin directory full path + fullPath, err := filepath.EvalSymlinks(tempDir) + if err != nil { + t.Fatal(err) + } + + reader, err := os.Open(filepath.Join(fullPath, filepath.Base(os.Args[0]))) + if err != nil { + t.Fatal(err) + } + defer reader.Close() + + // Find out the sha256 + hash := sha256.New() + + _, err = io.Copy(hash, reader) + if err != nil { + t.Fatal(err) + } + + sum := hash.Sum(nil) + + // Set core's plugin directory and plugin catalog directory + c.pluginDirectory = fullPath + c.pluginCatalog.directory = fullPath + + command := fmt.Sprintf("%s --test.run=%s", filepath.Base(os.Args[0]), testFunc) + err = c.pluginCatalog.Set(name, command, sum) + if err != nil { + t.Fatal(err) + } +} + var testLogicalBackends = map[string]logical.Factory{} var testCredentialBackends = map[string]logical.Factory{} -// Starts the test server which responds to SSH authentication. -// Used to test the SSH secret backend. +// StartSSHHostTestServer starts the test server which responds to SSH +// authentication. Used to test the SSH secret backend. func StartSSHHostTestServer() (string, error) { pubKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(testSharedPublicKey)) if err != nil { @@ -760,6 +825,7 @@ func (c *TestCluster) ensureCoresSealed() error { return nil } +// UnsealWithStoredKeys uses stored keys to unseal the test cluster cores func (c *TestCluster) UnsealWithStoredKeys(t testing.T) error { for _, core := range c.Cores { if err := core.UnsealWithStoredKeys(); err != nil { diff --git a/vault/token_store_test.go b/vault/token_store_test.go index 1b3d7c346..db8cf4816 100644 --- a/vault/token_store_test.go +++ b/vault/token_store_test.go @@ -2312,7 +2312,7 @@ func TestTokenStore_RolePeriod(t *testing.T) { req := logical.TestRequest(t, logical.UpdateOperation, "auth/token/roles/test") req.ClientToken = root req.Data = map[string]interface{}{ - "period": 300, + "period": 5, } resp, err := core.HandleRequest(req) @@ -2425,8 +2425,8 @@ func TestTokenStore_RolePeriod(t *testing.T) { t.Fatalf("err: %v", err) } ttl := resp.Data["ttl"].(int64) - if ttl < 299 { - t.Fatalf("TTL too small (expected %d, got %d", 299, ttl) + if ttl > 5 { + t.Fatalf("TTL too large (expected %d, got %d", 5, ttl) } // Let the TTL go down a bit to 3 seconds @@ -2449,8 +2449,8 @@ func TestTokenStore_RolePeriod(t *testing.T) { t.Fatalf("err: %v", err) } ttl = resp.Data["ttl"].(int64) - if ttl < 299 { - t.Fatalf("TTL too small (expected %d, got %d", 299, ttl) + if ttl > 5 { + t.Fatalf("TTL too large (expected %d, got %d", 5, ttl) } } } @@ -2677,7 +2677,7 @@ func TestTokenStore_Periodic(t *testing.T) { req := logical.TestRequest(t, logical.UpdateOperation, "auth/token/roles/test") req.ClientToken = root req.Data = map[string]interface{}{ - "period": 300, + "period": 5, } resp, err := core.HandleRequest(req) @@ -2715,8 +2715,8 @@ func TestTokenStore_Periodic(t *testing.T) { t.Fatalf("err: %v", err) } ttl := resp.Data["ttl"].(int64) - if ttl < 299 { - t.Fatalf("TTL too small (expected %d, got %d)", 299, ttl) + if ttl > 5 { + t.Fatalf("TTL too large (expected %d, got %d)", 5, ttl) } // Let the TTL go down a bit @@ -2739,8 +2739,8 @@ func TestTokenStore_Periodic(t *testing.T) { t.Fatalf("err: %v", err) } ttl = resp.Data["ttl"].(int64) - if ttl < 299 { - t.Fatalf("TTL too small (expected %d, got %d)", 299, ttl) + if ttl > 5 { + t.Fatalf("TTL too large (expected %d, got %d)", 5, ttl) } } @@ -2750,8 +2750,8 @@ func TestTokenStore_Periodic(t *testing.T) { req.Operation = logical.UpdateOperation req.Path = "auth/token/create" req.Data = map[string]interface{}{ - "period": 300, - "explicit_max_ttl": 150, + "period": 5, + "explicit_max_ttl": 4, } resp, err = core.HandleRequest(req) if err != nil { @@ -2775,8 +2775,8 @@ func TestTokenStore_Periodic(t *testing.T) { t.Fatalf("err: %v", err) } ttl := resp.Data["ttl"].(int64) - if ttl < 149 || ttl > 150 { - t.Fatalf("TTL bad (expected %d, got %d)", 149, ttl) + if ttl < 3 || ttl > 4 { + t.Fatalf("TTL bad (expected %d, got %d)", 3, ttl) } // Let the TTL go down a bit @@ -2799,8 +2799,8 @@ func TestTokenStore_Periodic(t *testing.T) { t.Fatalf("err: %v", err) } ttl = resp.Data["ttl"].(int64) - if ttl < 140 || ttl > 150 { - t.Fatalf("TTL bad (expected around %d, got %d)", 145, ttl) + if ttl > 2 { + t.Fatalf("TTL bad (expected less than %d, got %d)", 2, ttl) } } @@ -2812,7 +2812,7 @@ func TestTokenStore_Periodic(t *testing.T) { req.Operation = logical.UpdateOperation req.Path = "auth/token/create/test" req.Data = map[string]interface{}{ - "period": 150, + "period": 5, } resp, err = core.HandleRequest(req) if err != nil { @@ -2836,8 +2836,8 @@ func TestTokenStore_Periodic(t *testing.T) { t.Fatalf("err: %v", err) } ttl := resp.Data["ttl"].(int64) - if ttl < 149 || ttl > 150 { - t.Fatalf("TTL bad (expected %d, got %d)", 149, ttl) + if ttl < 4 || ttl > 5 { + t.Fatalf("TTL bad (expected %d, got %d)", 4, ttl) } // Let the TTL go down a bit @@ -2860,8 +2860,8 @@ func TestTokenStore_Periodic(t *testing.T) { t.Fatalf("err: %v", err) } ttl = resp.Data["ttl"].(int64) - if ttl < 149 { - t.Fatalf("TTL bad (expected %d, got %d)", 149, ttl) + if ttl > 5 { + t.Fatalf("TTL bad (expected less than %d, got %d)", 5, ttl) } } @@ -2869,18 +2869,23 @@ func TestTokenStore_Periodic(t *testing.T) { { req.Path = "auth/token/roles/test" req.ClientToken = root + req.Operation = logical.UpdateOperation req.Data = map[string]interface{}{ - "period": 300, - "explicit_max_ttl": 150, + "period": 5, + "explicit_max_ttl": 4, + } + + resp, err := core.HandleRequest(req) + if err != nil { + t.Fatalf("err: %v %v", err, resp) + } + if resp != nil { + t.Fatalf("expected a nil response") } req.ClientToken = root req.Operation = logical.UpdateOperation req.Path = "auth/token/create/test" - req.Data = map[string]interface{}{ - "period": 150, - "explicit_max_ttl": 130, - } resp, err = core.HandleRequest(req) if err != nil { t.Fatalf("err: %v %v", err, resp) @@ -2903,12 +2908,12 @@ func TestTokenStore_Periodic(t *testing.T) { t.Fatalf("err: %v", err) } ttl := resp.Data["ttl"].(int64) - if ttl < 129 || ttl > 130 { - t.Fatalf("TTL bad (expected %d, got %d)", 129, ttl) + if ttl < 3 || ttl > 4 { + t.Fatalf("TTL bad (expected %d, got %d)", 3, ttl) } // Let the TTL go down a bit - time.Sleep(4 * time.Second) + time.Sleep(2 * time.Second) req.Operation = logical.UpdateOperation req.Path = "auth/token/renew-self" @@ -2927,8 +2932,8 @@ func TestTokenStore_Periodic(t *testing.T) { t.Fatalf("err: %v", err) } ttl = resp.Data["ttl"].(int64) - if ttl > 127 { - t.Fatalf("TTL bad (expected < %d, got %d)", 128, ttl) + if ttl > 2 { + t.Fatalf("TTL bad (expected less than %d, got %d)", 2, ttl) } } }