Update the interface for plugins removing functions for creating creds

This commit is contained in:
Brian Kassouf 2017-04-10 12:24:16 -07:00
parent 459e3eda4e
commit bbbd81220c
6 changed files with 28 additions and 177 deletions

View File

@ -78,20 +78,20 @@ func (dr *databasePluginRPCClient) Type() string {
return fmt.Sprintf("plugin-%s", dbType)
}
func (dr *databasePluginRPCClient) CreateUser(statements Statements, username, password, expiration string) error {
func (dr *databasePluginRPCClient) CreateUser(statements Statements, usernamePrefix string, expiration time.Time) (username string, password string, err error) {
req := CreateUserRequest{
Statements: statements,
Username: username,
Password: password,
UsernamePrefix: usernamePrefix,
Expiration: expiration,
}
err := dr.client.Call("Plugin.CreateUser", req, &struct{}{})
var resp CreateUserResponse
err = dr.client.Call("Plugin.CreateUser", req, &resp)
return err
return resp.Username, resp.Password, err
}
func (dr *databasePluginRPCClient) RenewUser(statements Statements, username, expiration string) error {
func (dr *databasePluginRPCClient) RenewUser(statements Statements, username string, expiration time.Time) error {
req := RenewUserRequest{
Statements: statements,
Username: username,
@ -125,24 +125,3 @@ func (dr *databasePluginRPCClient) Close() error {
return err
}
func (dr *databasePluginRPCClient) GenerateUsername(displayName string) (string, error) {
resp := &GenerateUsernameResponse{}
err := dr.client.Call("Plugin.GenerateUsername", displayName, resp)
return resp.Username, err
}
func (dr *databasePluginRPCClient) GeneratePassword() (string, error) {
resp := &GeneratePasswordResponse{}
err := dr.client.Call("Plugin.GeneratePassword", struct{}{}, resp)
return resp.Password, err
}
func (dr *databasePluginRPCClient) GenerateExpiration(duration time.Duration) (string, error) {
resp := &GenerateExpirationResponse{}
err := dr.client.Call("Plugin.GenerateExpiration", duration, resp)
return resp.Expiration, err
}

View File

@ -20,7 +20,7 @@ func (mw *databaseTracingMiddleware) Type() string {
return mw.next.Type()
}
func (mw *databaseTracingMiddleware) CreateUser(statements Statements, username, password, expiration string) (err error) {
func (mw *databaseTracingMiddleware) CreateUser(statements Statements, usernamePrefix string, expiration time.Time) (username string, password string, err error) {
if mw.logger.IsTrace() {
defer func(then time.Time) {
mw.logger.Trace("database/CreateUser: finished", "type", mw.typeStr, "err", err, "took", time.Since(then))
@ -28,10 +28,10 @@ func (mw *databaseTracingMiddleware) CreateUser(statements Statements, username,
mw.logger.Trace("database/CreateUser: starting", "type", mw.typeStr)
}
return mw.next.CreateUser(statements, username, password, expiration)
return mw.next.CreateUser(statements, usernamePrefix, expiration)
}
func (mw *databaseTracingMiddleware) RenewUser(statements Statements, username, expiration string) (err error) {
func (mw *databaseTracingMiddleware) RenewUser(statements Statements, username string, expiration time.Time) (err error) {
if mw.logger.IsTrace() {
defer func(then time.Time) {
mw.logger.Trace("database/RenewUser: finished", "type", mw.typeStr, "err", err, "took", time.Since(then))
@ -75,39 +75,6 @@ func (mw *databaseTracingMiddleware) Close() (err error) {
return mw.next.Close()
}
func (mw *databaseTracingMiddleware) GenerateUsername(displayName string) (_ string, err error) {
if mw.logger.IsTrace() {
defer func(then time.Time) {
mw.logger.Trace("database/GenerateUsername: finished", "type", mw.typeStr, "err", err, "took", time.Since(then))
}(time.Now())
mw.logger.Trace("database/GenerateUsername: starting", "type", mw.typeStr)
}
return mw.next.GenerateUsername(displayName)
}
func (mw *databaseTracingMiddleware) GeneratePassword() (_ string, err error) {
if mw.logger.IsTrace() {
defer func(then time.Time) {
mw.logger.Trace("database/GeneratePassword: finished", "type", mw.typeStr, "err", err, "took", time.Since(then))
}(time.Now())
mw.logger.Trace("database/GeneratePassword: starting", "type", mw.typeStr)
}
return mw.next.GeneratePassword()
}
func (mw *databaseTracingMiddleware) GenerateExpiration(duration time.Duration) (_ string, err error) {
if mw.logger.IsTrace() {
defer func(then time.Time) {
mw.logger.Trace("database/GenerateExpiration: finished", "type", mw.typeStr, "err", err, "took", time.Since(then))
}(time.Now())
mw.logger.Trace("database/GenerateExpiration: starting", "type", mw.typeStr)
}
return mw.next.GenerateExpiration(duration)
}
// ---- Metrics Middleware Domain ----
type databaseMetricsMiddleware struct {
@ -120,7 +87,7 @@ func (mw *databaseMetricsMiddleware) Type() string {
return mw.next.Type()
}
func (mw *databaseMetricsMiddleware) CreateUser(statements Statements, username, password, expiration string) (err error) {
func (mw *databaseMetricsMiddleware) CreateUser(statements Statements, usernamePrefix string, expiration time.Time) (username string, password string, err error) {
defer func(now time.Time) {
metrics.MeasureSince([]string{"database", "CreateUser"}, now)
metrics.MeasureSince([]string{"database", mw.typeStr, "CreateUser"}, now)
@ -133,10 +100,10 @@ func (mw *databaseMetricsMiddleware) CreateUser(statements Statements, username,
metrics.IncrCounter([]string{"database", "CreateUser"}, 1)
metrics.IncrCounter([]string{"database", mw.typeStr, "CreateUser"}, 1)
return mw.next.CreateUser(statements, username, password, expiration)
return mw.next.CreateUser(statements, usernamePrefix, expiration)
}
func (mw *databaseMetricsMiddleware) RenewUser(statements Statements, username, expiration string) (err error) {
func (mw *databaseMetricsMiddleware) RenewUser(statements Statements, username string, expiration time.Time) (err error) {
defer func(now time.Time) {
metrics.MeasureSince([]string{"database", "RenewUser"}, now)
metrics.MeasureSince([]string{"database", mw.typeStr, "RenewUser"}, now)
@ -199,51 +166,3 @@ func (mw *databaseMetricsMiddleware) Close() (err error) {
metrics.IncrCounter([]string{"database", mw.typeStr, "Close"}, 1)
return mw.next.Close()
}
func (mw *databaseMetricsMiddleware) GenerateUsername(displayName string) (_ string, err error) {
defer func(now time.Time) {
metrics.MeasureSince([]string{"database", "GenerateUsername"}, now)
metrics.MeasureSince([]string{"database", mw.typeStr, "GenerateUsername"}, now)
if err != nil {
metrics.IncrCounter([]string{"database", "GenerateUsername", "error"}, 1)
metrics.IncrCounter([]string{"database", mw.typeStr, "GenerateUsername", "error"}, 1)
}
}(time.Now())
metrics.IncrCounter([]string{"database", "GenerateUsername"}, 1)
metrics.IncrCounter([]string{"database", mw.typeStr, "GenerateUsername"}, 1)
return mw.next.GenerateUsername(displayName)
}
func (mw *databaseMetricsMiddleware) GeneratePassword() (_ string, err error) {
defer func(now time.Time) {
metrics.MeasureSince([]string{"database", "GeneratePassword"}, now)
metrics.MeasureSince([]string{"database", mw.typeStr, "GeneratePassword"}, now)
if err != nil {
metrics.IncrCounter([]string{"database", "GeneratePassword", "error"}, 1)
metrics.IncrCounter([]string{"database", mw.typeStr, "GeneratePassword", "error"}, 1)
}
}(time.Now())
metrics.IncrCounter([]string{"database", "GeneratePassword"}, 1)
metrics.IncrCounter([]string{"database", mw.typeStr, "GeneratePassword"}, 1)
return mw.next.GeneratePassword()
}
func (mw *databaseMetricsMiddleware) GenerateExpiration(duration time.Duration) (_ string, err error) {
defer func(now time.Time) {
metrics.MeasureSince([]string{"database", "GenerateExpiration"}, now)
metrics.MeasureSince([]string{"database", mw.typeStr, "GenerateExpiration"}, now)
if err != nil {
metrics.IncrCounter([]string{"database", "GenerateExpiration", "error"}, 1)
metrics.IncrCounter([]string{"database", mw.typeStr, "GenerateExpiration", "error"}, 1)
}
}(time.Now())
metrics.IncrCounter([]string{"database", "GenerateExpiration"}, 1)
metrics.IncrCounter([]string{"database", mw.typeStr, "GenerateExpiration"}, 1)
return mw.next.GenerateExpiration(duration)
}

View File

@ -17,16 +17,12 @@ var (
// DatabaseType is the interface that all database objects must implement.
type DatabaseType interface {
Type() string
CreateUser(statements Statements, username, password, expiration string) error
RenewUser(statements Statements, username, expiration string) error
CreateUser(statements Statements, usernamePrefix string, expiration time.Time) (username string, password string, err error)
RenewUser(statements Statements, username string, expiration time.Time) error
RevokeUser(statements Statements, username string) error
Initialize(map[string]interface{}) error
Close() error
GenerateUsername(displayName string) (string, error)
GeneratePassword() (string, error)
GenerateExpiration(ttl time.Duration) (string, error)
}
// Statements set in role creation and passed into the database type's functions.
@ -97,15 +93,14 @@ func (DatabasePlugin) Client(b *plugin.MuxBroker, c *rpc.Client) (interface{}, e
type CreateUserRequest struct {
Statements Statements
Username string
Password string
Expiration string
UsernamePrefix string
Expiration time.Time
}
type RenewUserRequest struct {
Statements Statements
Username string
Expiration string
Expiration time.Time
}
type RevokeUserRequest struct {
@ -115,12 +110,7 @@ type RevokeUserRequest struct {
// ---- RPC Response Args Domain ----
type GenerateUsernameResponse struct {
type CreateUserResponse struct {
Username string
}
type GenerateExpirationResponse struct {
Expiration string
}
type GeneratePasswordResponse struct {
Password string
}

View File

@ -1,8 +1,6 @@
package dbplugin
import (
"time"
"github.com/hashicorp/go-plugin"
"github.com/hashicorp/vault/helper/pluginutil"
)
@ -39,8 +37,9 @@ func (ds *databasePluginRPCServer) Type(_ struct{}, resp *string) error {
return nil
}
func (ds *databasePluginRPCServer) CreateUser(args *CreateUserRequest, _ *struct{}) error {
err := ds.impl.CreateUser(args.Statements, args.Username, args.Password, args.Expiration)
func (ds *databasePluginRPCServer) CreateUser(args *CreateUserRequest, resp *CreateUserResponse) error {
var err error
resp.Username, resp.Password, err = ds.impl.CreateUser(args.Statements, args.UsernamePrefix, args.Expiration)
return err
}
@ -67,24 +66,3 @@ func (ds *databasePluginRPCServer) Close(_ struct{}, _ *struct{}) error {
ds.impl.Close()
return nil
}
func (ds *databasePluginRPCServer) GenerateUsername(args string, resp *GenerateUsernameResponse) error {
var err error
resp.Username, err = ds.impl.GenerateUsername(args)
return err
}
func (ds *databasePluginRPCServer) GeneratePassword(_ struct{}, resp *GeneratePasswordResponse) error {
var err error
resp.Password, err = ds.impl.GeneratePassword()
return err
}
func (ds *databasePluginRPCServer) GenerateExpiration(args time.Duration, resp *GenerateExpirationResponse) error {
var err error
resp.Expiration, err = ds.impl.GenerateExpiration(args)
return err
}

View File

@ -2,6 +2,7 @@ package database
import (
"fmt"
"time"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
@ -48,24 +49,10 @@ func (b *databaseBackend) pathRoleCreateRead(req *logical.Request, data *framewo
return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err)
}
// Generate the username, password and expiration
username, err := db.GenerateUsername(req.DisplayName)
if err != nil {
return nil, err
}
password, err := db.GeneratePassword()
if err != nil {
return nil, err
}
expiration, err := db.GenerateExpiration(role.DefaultTTL)
if err != nil {
return nil, err
}
expiration := time.Now().Add(role.DefaultTTL)
// Create the user
err = db.CreateUser(role.Statements, username, password, expiration)
username, password, err := db.CreateUser(role.Statements, req.DisplayName, expiration)
if err != nil {
return nil, err
}

View File

@ -58,9 +58,7 @@ func (b *databaseBackend) secretCredsRenew(req *logical.Request, d *framework.Fi
// Make sure we increase the VALID UNTIL endpoint for this user.
if expireTime := resp.Secret.ExpirationTime(); !expireTime.IsZero() {
expiration := expireTime.Format("2006-01-02 15:04:05-0700")
err := db.RenewUser(role.Statements, username, expiration)
err := db.RenewUser(role.Statements, username, expireTime)
if err != nil {
return nil, err
}