From bbbd81220c92b4f3f7818b74d586994f9d13ea4c Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Mon, 10 Apr 2017 12:24:16 -0700 Subject: [PATCH] Update the interface for plugins removing functions for creating creds --- builtin/logical/database/dbplugin/client.go | 37 ++------ .../database/dbplugin/databasemiddleware.go | 93 ++----------------- builtin/logical/database/dbplugin/plugin.go | 24 ++--- builtin/logical/database/dbplugin/server.go | 28 +----- builtin/logical/database/path_role_create.go | 19 +--- builtin/logical/database/secret_creds.go | 4 +- 6 files changed, 28 insertions(+), 177 deletions(-) diff --git a/builtin/logical/database/dbplugin/client.go b/builtin/logical/database/dbplugin/client.go index db6b3d1fd..0dae61d27 100644 --- a/builtin/logical/database/dbplugin/client.go +++ b/builtin/logical/database/dbplugin/client.go @@ -78,20 +78,20 @@ func (dr *databasePluginRPCClient) Type() string { return fmt.Sprintf("plugin-%s", dbType) } -func (dr *databasePluginRPCClient) CreateUser(statements Statements, username, password, expiration string) error { +func (dr *databasePluginRPCClient) CreateUser(statements Statements, usernamePrefix string, expiration time.Time) (username string, password string, err error) { req := CreateUserRequest{ - Statements: statements, - Username: username, - Password: password, - Expiration: expiration, + Statements: statements, + UsernamePrefix: usernamePrefix, + Expiration: expiration, } - err := dr.client.Call("Plugin.CreateUser", req, &struct{}{}) + var resp CreateUserResponse + err = dr.client.Call("Plugin.CreateUser", req, &resp) - return err + return resp.Username, resp.Password, err } -func (dr *databasePluginRPCClient) RenewUser(statements Statements, username, expiration string) error { +func (dr *databasePluginRPCClient) RenewUser(statements Statements, username string, expiration time.Time) error { req := RenewUserRequest{ Statements: statements, Username: username, @@ -125,24 +125,3 @@ func (dr *databasePluginRPCClient) Close() error { return err } - -func (dr *databasePluginRPCClient) GenerateUsername(displayName string) (string, error) { - resp := &GenerateUsernameResponse{} - err := dr.client.Call("Plugin.GenerateUsername", displayName, resp) - - return resp.Username, err -} - -func (dr *databasePluginRPCClient) GeneratePassword() (string, error) { - resp := &GeneratePasswordResponse{} - err := dr.client.Call("Plugin.GeneratePassword", struct{}{}, resp) - - return resp.Password, err -} - -func (dr *databasePluginRPCClient) GenerateExpiration(duration time.Duration) (string, error) { - resp := &GenerateExpirationResponse{} - err := dr.client.Call("Plugin.GenerateExpiration", duration, resp) - - return resp.Expiration, err -} diff --git a/builtin/logical/database/dbplugin/databasemiddleware.go b/builtin/logical/database/dbplugin/databasemiddleware.go index b4a980950..2748f2f11 100644 --- a/builtin/logical/database/dbplugin/databasemiddleware.go +++ b/builtin/logical/database/dbplugin/databasemiddleware.go @@ -20,7 +20,7 @@ func (mw *databaseTracingMiddleware) Type() string { return mw.next.Type() } -func (mw *databaseTracingMiddleware) CreateUser(statements Statements, username, password, expiration string) (err error) { +func (mw *databaseTracingMiddleware) CreateUser(statements Statements, usernamePrefix string, expiration time.Time) (username string, password string, err error) { if mw.logger.IsTrace() { defer func(then time.Time) { mw.logger.Trace("database/CreateUser: finished", "type", mw.typeStr, "err", err, "took", time.Since(then)) @@ -28,10 +28,10 @@ func (mw *databaseTracingMiddleware) CreateUser(statements Statements, username, mw.logger.Trace("database/CreateUser: starting", "type", mw.typeStr) } - return mw.next.CreateUser(statements, username, password, expiration) + return mw.next.CreateUser(statements, usernamePrefix, expiration) } -func (mw *databaseTracingMiddleware) RenewUser(statements Statements, username, expiration string) (err error) { +func (mw *databaseTracingMiddleware) RenewUser(statements Statements, username string, expiration time.Time) (err error) { if mw.logger.IsTrace() { defer func(then time.Time) { mw.logger.Trace("database/RenewUser: finished", "type", mw.typeStr, "err", err, "took", time.Since(then)) @@ -75,39 +75,6 @@ func (mw *databaseTracingMiddleware) Close() (err error) { return mw.next.Close() } -func (mw *databaseTracingMiddleware) GenerateUsername(displayName string) (_ string, err error) { - if mw.logger.IsTrace() { - defer func(then time.Time) { - mw.logger.Trace("database/GenerateUsername: finished", "type", mw.typeStr, "err", err, "took", time.Since(then)) - }(time.Now()) - - mw.logger.Trace("database/GenerateUsername: starting", "type", mw.typeStr) - } - return mw.next.GenerateUsername(displayName) -} - -func (mw *databaseTracingMiddleware) GeneratePassword() (_ string, err error) { - if mw.logger.IsTrace() { - defer func(then time.Time) { - mw.logger.Trace("database/GeneratePassword: finished", "type", mw.typeStr, "err", err, "took", time.Since(then)) - }(time.Now()) - - mw.logger.Trace("database/GeneratePassword: starting", "type", mw.typeStr) - } - return mw.next.GeneratePassword() -} - -func (mw *databaseTracingMiddleware) GenerateExpiration(duration time.Duration) (_ string, err error) { - if mw.logger.IsTrace() { - defer func(then time.Time) { - mw.logger.Trace("database/GenerateExpiration: finished", "type", mw.typeStr, "err", err, "took", time.Since(then)) - }(time.Now()) - - mw.logger.Trace("database/GenerateExpiration: starting", "type", mw.typeStr) - } - return mw.next.GenerateExpiration(duration) -} - // ---- Metrics Middleware Domain ---- type databaseMetricsMiddleware struct { @@ -120,7 +87,7 @@ func (mw *databaseMetricsMiddleware) Type() string { return mw.next.Type() } -func (mw *databaseMetricsMiddleware) CreateUser(statements Statements, username, password, expiration string) (err error) { +func (mw *databaseMetricsMiddleware) CreateUser(statements Statements, usernamePrefix string, expiration time.Time) (username string, password string, err error) { defer func(now time.Time) { metrics.MeasureSince([]string{"database", "CreateUser"}, now) metrics.MeasureSince([]string{"database", mw.typeStr, "CreateUser"}, now) @@ -133,10 +100,10 @@ func (mw *databaseMetricsMiddleware) CreateUser(statements Statements, username, metrics.IncrCounter([]string{"database", "CreateUser"}, 1) metrics.IncrCounter([]string{"database", mw.typeStr, "CreateUser"}, 1) - return mw.next.CreateUser(statements, username, password, expiration) + return mw.next.CreateUser(statements, usernamePrefix, expiration) } -func (mw *databaseMetricsMiddleware) RenewUser(statements Statements, username, expiration string) (err error) { +func (mw *databaseMetricsMiddleware) RenewUser(statements Statements, username string, expiration time.Time) (err error) { defer func(now time.Time) { metrics.MeasureSince([]string{"database", "RenewUser"}, now) metrics.MeasureSince([]string{"database", mw.typeStr, "RenewUser"}, now) @@ -199,51 +166,3 @@ func (mw *databaseMetricsMiddleware) Close() (err error) { metrics.IncrCounter([]string{"database", mw.typeStr, "Close"}, 1) return mw.next.Close() } - -func (mw *databaseMetricsMiddleware) GenerateUsername(displayName string) (_ string, err error) { - defer func(now time.Time) { - metrics.MeasureSince([]string{"database", "GenerateUsername"}, now) - metrics.MeasureSince([]string{"database", mw.typeStr, "GenerateUsername"}, now) - - if err != nil { - metrics.IncrCounter([]string{"database", "GenerateUsername", "error"}, 1) - metrics.IncrCounter([]string{"database", mw.typeStr, "GenerateUsername", "error"}, 1) - } - }(time.Now()) - - metrics.IncrCounter([]string{"database", "GenerateUsername"}, 1) - metrics.IncrCounter([]string{"database", mw.typeStr, "GenerateUsername"}, 1) - return mw.next.GenerateUsername(displayName) -} - -func (mw *databaseMetricsMiddleware) GeneratePassword() (_ string, err error) { - defer func(now time.Time) { - metrics.MeasureSince([]string{"database", "GeneratePassword"}, now) - metrics.MeasureSince([]string{"database", mw.typeStr, "GeneratePassword"}, now) - - if err != nil { - metrics.IncrCounter([]string{"database", "GeneratePassword", "error"}, 1) - metrics.IncrCounter([]string{"database", mw.typeStr, "GeneratePassword", "error"}, 1) - } - }(time.Now()) - - metrics.IncrCounter([]string{"database", "GeneratePassword"}, 1) - metrics.IncrCounter([]string{"database", mw.typeStr, "GeneratePassword"}, 1) - return mw.next.GeneratePassword() -} - -func (mw *databaseMetricsMiddleware) GenerateExpiration(duration time.Duration) (_ string, err error) { - defer func(now time.Time) { - metrics.MeasureSince([]string{"database", "GenerateExpiration"}, now) - metrics.MeasureSince([]string{"database", mw.typeStr, "GenerateExpiration"}, now) - - if err != nil { - metrics.IncrCounter([]string{"database", "GenerateExpiration", "error"}, 1) - metrics.IncrCounter([]string{"database", mw.typeStr, "GenerateExpiration", "error"}, 1) - } - }(time.Now()) - - metrics.IncrCounter([]string{"database", "GenerateExpiration"}, 1) - metrics.IncrCounter([]string{"database", mw.typeStr, "GenerateExpiration"}, 1) - return mw.next.GenerateExpiration(duration) -} diff --git a/builtin/logical/database/dbplugin/plugin.go b/builtin/logical/database/dbplugin/plugin.go index 994f3b0ce..5cd24e879 100644 --- a/builtin/logical/database/dbplugin/plugin.go +++ b/builtin/logical/database/dbplugin/plugin.go @@ -17,16 +17,12 @@ var ( // DatabaseType is the interface that all database objects must implement. type DatabaseType interface { Type() string - CreateUser(statements Statements, username, password, expiration string) error - RenewUser(statements Statements, username, expiration string) error + CreateUser(statements Statements, usernamePrefix string, expiration time.Time) (username string, password string, err error) + RenewUser(statements Statements, username string, expiration time.Time) error RevokeUser(statements Statements, username string) error Initialize(map[string]interface{}) error Close() error - - GenerateUsername(displayName string) (string, error) - GeneratePassword() (string, error) - GenerateExpiration(ttl time.Duration) (string, error) } // Statements set in role creation and passed into the database type's functions. @@ -96,16 +92,15 @@ func (DatabasePlugin) Client(b *plugin.MuxBroker, c *rpc.Client) (interface{}, e // ---- RPC Request Args Domain ---- type CreateUserRequest struct { - Statements Statements - Username string - Password string - Expiration string + Statements Statements + UsernamePrefix string + Expiration time.Time } type RenewUserRequest struct { Statements Statements Username string - Expiration string + Expiration time.Time } type RevokeUserRequest struct { @@ -115,12 +110,7 @@ type RevokeUserRequest struct { // ---- RPC Response Args Domain ---- -type GenerateUsernameResponse struct { +type CreateUserResponse struct { Username string -} -type GenerateExpirationResponse struct { - Expiration string -} -type GeneratePasswordResponse struct { Password string } diff --git a/builtin/logical/database/dbplugin/server.go b/builtin/logical/database/dbplugin/server.go index 018d9b8db..2dddbaffd 100644 --- a/builtin/logical/database/dbplugin/server.go +++ b/builtin/logical/database/dbplugin/server.go @@ -1,8 +1,6 @@ package dbplugin import ( - "time" - "github.com/hashicorp/go-plugin" "github.com/hashicorp/vault/helper/pluginutil" ) @@ -39,8 +37,9 @@ func (ds *databasePluginRPCServer) Type(_ struct{}, resp *string) error { return nil } -func (ds *databasePluginRPCServer) CreateUser(args *CreateUserRequest, _ *struct{}) error { - err := ds.impl.CreateUser(args.Statements, args.Username, args.Password, args.Expiration) +func (ds *databasePluginRPCServer) CreateUser(args *CreateUserRequest, resp *CreateUserResponse) error { + var err error + resp.Username, resp.Password, err = ds.impl.CreateUser(args.Statements, args.UsernamePrefix, args.Expiration) return err } @@ -67,24 +66,3 @@ func (ds *databasePluginRPCServer) Close(_ struct{}, _ *struct{}) error { ds.impl.Close() return nil } - -func (ds *databasePluginRPCServer) GenerateUsername(args string, resp *GenerateUsernameResponse) error { - var err error - resp.Username, err = ds.impl.GenerateUsername(args) - - return err -} - -func (ds *databasePluginRPCServer) GeneratePassword(_ struct{}, resp *GeneratePasswordResponse) error { - var err error - resp.Password, err = ds.impl.GeneratePassword() - - return err -} - -func (ds *databasePluginRPCServer) GenerateExpiration(args time.Duration, resp *GenerateExpirationResponse) error { - var err error - resp.Expiration, err = ds.impl.GenerateExpiration(args) - - return err -} diff --git a/builtin/logical/database/path_role_create.go b/builtin/logical/database/path_role_create.go index d379ef267..5a16c8926 100644 --- a/builtin/logical/database/path_role_create.go +++ b/builtin/logical/database/path_role_create.go @@ -2,6 +2,7 @@ package database import ( "fmt" + "time" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" @@ -48,24 +49,10 @@ func (b *databaseBackend) pathRoleCreateRead(req *logical.Request, data *framewo return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err) } - // Generate the username, password and expiration - username, err := db.GenerateUsername(req.DisplayName) - if err != nil { - return nil, err - } - - password, err := db.GeneratePassword() - if err != nil { - return nil, err - } - - expiration, err := db.GenerateExpiration(role.DefaultTTL) - if err != nil { - return nil, err - } + expiration := time.Now().Add(role.DefaultTTL) // Create the user - err = db.CreateUser(role.Statements, username, password, expiration) + username, password, err := db.CreateUser(role.Statements, req.DisplayName, expiration) if err != nil { return nil, err } diff --git a/builtin/logical/database/secret_creds.go b/builtin/logical/database/secret_creds.go index 353541c0c..5701e373a 100644 --- a/builtin/logical/database/secret_creds.go +++ b/builtin/logical/database/secret_creds.go @@ -58,9 +58,7 @@ func (b *databaseBackend) secretCredsRenew(req *logical.Request, d *framework.Fi // Make sure we increase the VALID UNTIL endpoint for this user. if expireTime := resp.Secret.ExpirationTime(); !expireTime.IsZero() { - expiration := expireTime.Format("2006-01-02 15:04:05-0700") - - err := db.RenewUser(role.Statements, username, expiration) + err := db.RenewUser(role.Statements, username, expireTime) if err != nil { return nil, err }