Update the interface for plugins removing functions for creating creds

This commit is contained in:
Brian Kassouf 2017-04-10 12:24:16 -07:00
parent 459e3eda4e
commit bbbd81220c
6 changed files with 28 additions and 177 deletions

View File

@ -78,20 +78,20 @@ func (dr *databasePluginRPCClient) Type() string {
return fmt.Sprintf("plugin-%s", dbType) return fmt.Sprintf("plugin-%s", dbType)
} }
func (dr *databasePluginRPCClient) CreateUser(statements Statements, username, password, expiration string) error { func (dr *databasePluginRPCClient) CreateUser(statements Statements, usernamePrefix string, expiration time.Time) (username string, password string, err error) {
req := CreateUserRequest{ req := CreateUserRequest{
Statements: statements, Statements: statements,
Username: username, UsernamePrefix: usernamePrefix,
Password: password, Expiration: expiration,
Expiration: expiration,
} }
err := dr.client.Call("Plugin.CreateUser", req, &struct{}{}) var resp CreateUserResponse
err = dr.client.Call("Plugin.CreateUser", req, &resp)
return err return resp.Username, resp.Password, err
} }
func (dr *databasePluginRPCClient) RenewUser(statements Statements, username, expiration string) error { func (dr *databasePluginRPCClient) RenewUser(statements Statements, username string, expiration time.Time) error {
req := RenewUserRequest{ req := RenewUserRequest{
Statements: statements, Statements: statements,
Username: username, Username: username,
@ -125,24 +125,3 @@ func (dr *databasePluginRPCClient) Close() error {
return err return err
} }
func (dr *databasePluginRPCClient) GenerateUsername(displayName string) (string, error) {
resp := &GenerateUsernameResponse{}
err := dr.client.Call("Plugin.GenerateUsername", displayName, resp)
return resp.Username, err
}
func (dr *databasePluginRPCClient) GeneratePassword() (string, error) {
resp := &GeneratePasswordResponse{}
err := dr.client.Call("Plugin.GeneratePassword", struct{}{}, resp)
return resp.Password, err
}
func (dr *databasePluginRPCClient) GenerateExpiration(duration time.Duration) (string, error) {
resp := &GenerateExpirationResponse{}
err := dr.client.Call("Plugin.GenerateExpiration", duration, resp)
return resp.Expiration, err
}

View File

@ -20,7 +20,7 @@ func (mw *databaseTracingMiddleware) Type() string {
return mw.next.Type() return mw.next.Type()
} }
func (mw *databaseTracingMiddleware) CreateUser(statements Statements, username, password, expiration string) (err error) { func (mw *databaseTracingMiddleware) CreateUser(statements Statements, usernamePrefix string, expiration time.Time) (username string, password string, err error) {
if mw.logger.IsTrace() { if mw.logger.IsTrace() {
defer func(then time.Time) { defer func(then time.Time) {
mw.logger.Trace("database/CreateUser: finished", "type", mw.typeStr, "err", err, "took", time.Since(then)) mw.logger.Trace("database/CreateUser: finished", "type", mw.typeStr, "err", err, "took", time.Since(then))
@ -28,10 +28,10 @@ func (mw *databaseTracingMiddleware) CreateUser(statements Statements, username,
mw.logger.Trace("database/CreateUser: starting", "type", mw.typeStr) mw.logger.Trace("database/CreateUser: starting", "type", mw.typeStr)
} }
return mw.next.CreateUser(statements, username, password, expiration) return mw.next.CreateUser(statements, usernamePrefix, expiration)
} }
func (mw *databaseTracingMiddleware) RenewUser(statements Statements, username, expiration string) (err error) { func (mw *databaseTracingMiddleware) RenewUser(statements Statements, username string, expiration time.Time) (err error) {
if mw.logger.IsTrace() { if mw.logger.IsTrace() {
defer func(then time.Time) { defer func(then time.Time) {
mw.logger.Trace("database/RenewUser: finished", "type", mw.typeStr, "err", err, "took", time.Since(then)) mw.logger.Trace("database/RenewUser: finished", "type", mw.typeStr, "err", err, "took", time.Since(then))
@ -75,39 +75,6 @@ func (mw *databaseTracingMiddleware) Close() (err error) {
return mw.next.Close() return mw.next.Close()
} }
func (mw *databaseTracingMiddleware) GenerateUsername(displayName string) (_ string, err error) {
if mw.logger.IsTrace() {
defer func(then time.Time) {
mw.logger.Trace("database/GenerateUsername: finished", "type", mw.typeStr, "err", err, "took", time.Since(then))
}(time.Now())
mw.logger.Trace("database/GenerateUsername: starting", "type", mw.typeStr)
}
return mw.next.GenerateUsername(displayName)
}
func (mw *databaseTracingMiddleware) GeneratePassword() (_ string, err error) {
if mw.logger.IsTrace() {
defer func(then time.Time) {
mw.logger.Trace("database/GeneratePassword: finished", "type", mw.typeStr, "err", err, "took", time.Since(then))
}(time.Now())
mw.logger.Trace("database/GeneratePassword: starting", "type", mw.typeStr)
}
return mw.next.GeneratePassword()
}
func (mw *databaseTracingMiddleware) GenerateExpiration(duration time.Duration) (_ string, err error) {
if mw.logger.IsTrace() {
defer func(then time.Time) {
mw.logger.Trace("database/GenerateExpiration: finished", "type", mw.typeStr, "err", err, "took", time.Since(then))
}(time.Now())
mw.logger.Trace("database/GenerateExpiration: starting", "type", mw.typeStr)
}
return mw.next.GenerateExpiration(duration)
}
// ---- Metrics Middleware Domain ---- // ---- Metrics Middleware Domain ----
type databaseMetricsMiddleware struct { type databaseMetricsMiddleware struct {
@ -120,7 +87,7 @@ func (mw *databaseMetricsMiddleware) Type() string {
return mw.next.Type() return mw.next.Type()
} }
func (mw *databaseMetricsMiddleware) CreateUser(statements Statements, username, password, expiration string) (err error) { func (mw *databaseMetricsMiddleware) CreateUser(statements Statements, usernamePrefix string, expiration time.Time) (username string, password string, err error) {
defer func(now time.Time) { defer func(now time.Time) {
metrics.MeasureSince([]string{"database", "CreateUser"}, now) metrics.MeasureSince([]string{"database", "CreateUser"}, now)
metrics.MeasureSince([]string{"database", mw.typeStr, "CreateUser"}, now) metrics.MeasureSince([]string{"database", mw.typeStr, "CreateUser"}, now)
@ -133,10 +100,10 @@ func (mw *databaseMetricsMiddleware) CreateUser(statements Statements, username,
metrics.IncrCounter([]string{"database", "CreateUser"}, 1) metrics.IncrCounter([]string{"database", "CreateUser"}, 1)
metrics.IncrCounter([]string{"database", mw.typeStr, "CreateUser"}, 1) metrics.IncrCounter([]string{"database", mw.typeStr, "CreateUser"}, 1)
return mw.next.CreateUser(statements, username, password, expiration) return mw.next.CreateUser(statements, usernamePrefix, expiration)
} }
func (mw *databaseMetricsMiddleware) RenewUser(statements Statements, username, expiration string) (err error) { func (mw *databaseMetricsMiddleware) RenewUser(statements Statements, username string, expiration time.Time) (err error) {
defer func(now time.Time) { defer func(now time.Time) {
metrics.MeasureSince([]string{"database", "RenewUser"}, now) metrics.MeasureSince([]string{"database", "RenewUser"}, now)
metrics.MeasureSince([]string{"database", mw.typeStr, "RenewUser"}, now) metrics.MeasureSince([]string{"database", mw.typeStr, "RenewUser"}, now)
@ -199,51 +166,3 @@ func (mw *databaseMetricsMiddleware) Close() (err error) {
metrics.IncrCounter([]string{"database", mw.typeStr, "Close"}, 1) metrics.IncrCounter([]string{"database", mw.typeStr, "Close"}, 1)
return mw.next.Close() return mw.next.Close()
} }
func (mw *databaseMetricsMiddleware) GenerateUsername(displayName string) (_ string, err error) {
defer func(now time.Time) {
metrics.MeasureSince([]string{"database", "GenerateUsername"}, now)
metrics.MeasureSince([]string{"database", mw.typeStr, "GenerateUsername"}, now)
if err != nil {
metrics.IncrCounter([]string{"database", "GenerateUsername", "error"}, 1)
metrics.IncrCounter([]string{"database", mw.typeStr, "GenerateUsername", "error"}, 1)
}
}(time.Now())
metrics.IncrCounter([]string{"database", "GenerateUsername"}, 1)
metrics.IncrCounter([]string{"database", mw.typeStr, "GenerateUsername"}, 1)
return mw.next.GenerateUsername(displayName)
}
func (mw *databaseMetricsMiddleware) GeneratePassword() (_ string, err error) {
defer func(now time.Time) {
metrics.MeasureSince([]string{"database", "GeneratePassword"}, now)
metrics.MeasureSince([]string{"database", mw.typeStr, "GeneratePassword"}, now)
if err != nil {
metrics.IncrCounter([]string{"database", "GeneratePassword", "error"}, 1)
metrics.IncrCounter([]string{"database", mw.typeStr, "GeneratePassword", "error"}, 1)
}
}(time.Now())
metrics.IncrCounter([]string{"database", "GeneratePassword"}, 1)
metrics.IncrCounter([]string{"database", mw.typeStr, "GeneratePassword"}, 1)
return mw.next.GeneratePassword()
}
func (mw *databaseMetricsMiddleware) GenerateExpiration(duration time.Duration) (_ string, err error) {
defer func(now time.Time) {
metrics.MeasureSince([]string{"database", "GenerateExpiration"}, now)
metrics.MeasureSince([]string{"database", mw.typeStr, "GenerateExpiration"}, now)
if err != nil {
metrics.IncrCounter([]string{"database", "GenerateExpiration", "error"}, 1)
metrics.IncrCounter([]string{"database", mw.typeStr, "GenerateExpiration", "error"}, 1)
}
}(time.Now())
metrics.IncrCounter([]string{"database", "GenerateExpiration"}, 1)
metrics.IncrCounter([]string{"database", mw.typeStr, "GenerateExpiration"}, 1)
return mw.next.GenerateExpiration(duration)
}

View File

@ -17,16 +17,12 @@ var (
// DatabaseType is the interface that all database objects must implement. // DatabaseType is the interface that all database objects must implement.
type DatabaseType interface { type DatabaseType interface {
Type() string Type() string
CreateUser(statements Statements, username, password, expiration string) error CreateUser(statements Statements, usernamePrefix string, expiration time.Time) (username string, password string, err error)
RenewUser(statements Statements, username, expiration string) error RenewUser(statements Statements, username string, expiration time.Time) error
RevokeUser(statements Statements, username string) error RevokeUser(statements Statements, username string) error
Initialize(map[string]interface{}) error Initialize(map[string]interface{}) error
Close() error Close() error
GenerateUsername(displayName string) (string, error)
GeneratePassword() (string, error)
GenerateExpiration(ttl time.Duration) (string, error)
} }
// Statements set in role creation and passed into the database type's functions. // Statements set in role creation and passed into the database type's functions.
@ -96,16 +92,15 @@ func (DatabasePlugin) Client(b *plugin.MuxBroker, c *rpc.Client) (interface{}, e
// ---- RPC Request Args Domain ---- // ---- RPC Request Args Domain ----
type CreateUserRequest struct { type CreateUserRequest struct {
Statements Statements Statements Statements
Username string UsernamePrefix string
Password string Expiration time.Time
Expiration string
} }
type RenewUserRequest struct { type RenewUserRequest struct {
Statements Statements Statements Statements
Username string Username string
Expiration string Expiration time.Time
} }
type RevokeUserRequest struct { type RevokeUserRequest struct {
@ -115,12 +110,7 @@ type RevokeUserRequest struct {
// ---- RPC Response Args Domain ---- // ---- RPC Response Args Domain ----
type GenerateUsernameResponse struct { type CreateUserResponse struct {
Username string Username string
}
type GenerateExpirationResponse struct {
Expiration string
}
type GeneratePasswordResponse struct {
Password string Password string
} }

View File

@ -1,8 +1,6 @@
package dbplugin package dbplugin
import ( import (
"time"
"github.com/hashicorp/go-plugin" "github.com/hashicorp/go-plugin"
"github.com/hashicorp/vault/helper/pluginutil" "github.com/hashicorp/vault/helper/pluginutil"
) )
@ -39,8 +37,9 @@ func (ds *databasePluginRPCServer) Type(_ struct{}, resp *string) error {
return nil return nil
} }
func (ds *databasePluginRPCServer) CreateUser(args *CreateUserRequest, _ *struct{}) error { func (ds *databasePluginRPCServer) CreateUser(args *CreateUserRequest, resp *CreateUserResponse) error {
err := ds.impl.CreateUser(args.Statements, args.Username, args.Password, args.Expiration) var err error
resp.Username, resp.Password, err = ds.impl.CreateUser(args.Statements, args.UsernamePrefix, args.Expiration)
return err return err
} }
@ -67,24 +66,3 @@ func (ds *databasePluginRPCServer) Close(_ struct{}, _ *struct{}) error {
ds.impl.Close() ds.impl.Close()
return nil return nil
} }
func (ds *databasePluginRPCServer) GenerateUsername(args string, resp *GenerateUsernameResponse) error {
var err error
resp.Username, err = ds.impl.GenerateUsername(args)
return err
}
func (ds *databasePluginRPCServer) GeneratePassword(_ struct{}, resp *GeneratePasswordResponse) error {
var err error
resp.Password, err = ds.impl.GeneratePassword()
return err
}
func (ds *databasePluginRPCServer) GenerateExpiration(args time.Duration, resp *GenerateExpirationResponse) error {
var err error
resp.Expiration, err = ds.impl.GenerateExpiration(args)
return err
}

View File

@ -2,6 +2,7 @@ package database
import ( import (
"fmt" "fmt"
"time"
"github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework" "github.com/hashicorp/vault/logical/framework"
@ -48,24 +49,10 @@ func (b *databaseBackend) pathRoleCreateRead(req *logical.Request, data *framewo
return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err) return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err)
} }
// Generate the username, password and expiration expiration := time.Now().Add(role.DefaultTTL)
username, err := db.GenerateUsername(req.DisplayName)
if err != nil {
return nil, err
}
password, err := db.GeneratePassword()
if err != nil {
return nil, err
}
expiration, err := db.GenerateExpiration(role.DefaultTTL)
if err != nil {
return nil, err
}
// Create the user // Create the user
err = db.CreateUser(role.Statements, username, password, expiration) username, password, err := db.CreateUser(role.Statements, req.DisplayName, expiration)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -58,9 +58,7 @@ func (b *databaseBackend) secretCredsRenew(req *logical.Request, d *framework.Fi
// Make sure we increase the VALID UNTIL endpoint for this user. // Make sure we increase the VALID UNTIL endpoint for this user.
if expireTime := resp.Secret.ExpirationTime(); !expireTime.IsZero() { if expireTime := resp.Secret.ExpirationTime(); !expireTime.IsZero() {
expiration := expireTime.Format("2006-01-02 15:04:05-0700") err := db.RenewUser(role.Statements, username, expireTime)
err := db.RenewUser(role.Statements, username, expiration)
if err != nil { if err != nil {
return nil, err return nil, err
} }