Wrap the database calls with tracing information

This commit is contained in:
Brian Kassouf 2017-03-27 15:17:28 -07:00
parent 2799586f45
commit 494f963581
5 changed files with 130 additions and 12 deletions

View file

@ -4,8 +4,112 @@ import (
"time"
metrics "github.com/armon/go-metrics"
log "github.com/mgutz/logxi/v1"
)
// ---- Tracing Middleware Domain ----
type databaseTracingMiddleware struct {
next DatabaseType
logger log.Logger
typeStr string
}
func (mw *databaseTracingMiddleware) Type() string {
return mw.next.Type()
}
func (mw *databaseTracingMiddleware) CreateUser(statements Statements, username, password, expiration string) (err error) {
if mw.logger.IsTrace() {
defer func(then time.Time) {
mw.logger.Trace("database/CreateUser: finished", "type", mw.typeStr, "err", err, "took", time.Since(then))
}(time.Now())
mw.logger.Trace("database/CreateUser: starting", "type", mw.typeStr)
}
return mw.next.CreateUser(statements, username, password, expiration)
}
func (mw *databaseTracingMiddleware) RenewUser(statements Statements, username, expiration string) (err error) {
if mw.logger.IsTrace() {
defer func(then time.Time) {
mw.logger.Trace("database/RenewUser: finished", "type", mw.typeStr, "err", err, "took", time.Since(then))
}(time.Now())
mw.logger.Trace("database/RenewUser: starting", "type", mw.typeStr)
}
return mw.next.RenewUser(statements, username, expiration)
}
func (mw *databaseTracingMiddleware) RevokeUser(statements Statements, username string) (err error) {
if mw.logger.IsTrace() {
defer func(then time.Time) {
mw.logger.Trace("database/RevokeUser: finished", "type", mw.typeStr, "err", err, "took", time.Since(then))
}(time.Now())
mw.logger.Trace("database/RevokeUser: starting", "type", mw.typeStr)
}
return mw.next.RevokeUser(statements, username)
}
func (mw *databaseTracingMiddleware) Initialize(conf map[string]interface{}) (err error) {
if mw.logger.IsTrace() {
defer func(then time.Time) {
mw.logger.Trace("database/Initialize: finished", "type", mw.typeStr, "err", err, "took", time.Since(then))
}(time.Now())
mw.logger.Trace("database/Initialize: starting", "type", mw.typeStr)
}
return mw.next.Initialize(conf)
}
func (mw *databaseTracingMiddleware) Close() (err error) {
if mw.logger.IsTrace() {
defer func(then time.Time) {
mw.logger.Trace("database/Close: finished", "type", mw.typeStr, "err", err, "took", time.Since(then))
}(time.Now())
mw.logger.Trace("database/Close: starting", "type", mw.typeStr)
}
return mw.next.Close()
}
func (mw *databaseTracingMiddleware) GenerateUsername(displayName string) (_ string, err error) {
if mw.logger.IsTrace() {
defer func(then time.Time) {
mw.logger.Trace("database/GenerateUsername: finished", "type", mw.typeStr, "err", err, "took", time.Since(then))
}(time.Now())
mw.logger.Trace("database/GenerateUsername: starting", "type", mw.typeStr)
}
return mw.next.GenerateUsername(displayName)
}
func (mw *databaseTracingMiddleware) GeneratePassword() (_ string, err error) {
if mw.logger.IsTrace() {
defer func(then time.Time) {
mw.logger.Trace("database/GeneratePassword: finished", "type", mw.typeStr, "err", err, "took", time.Since(then))
}(time.Now())
mw.logger.Trace("database/GeneratePassword: starting", "type", mw.typeStr)
}
return mw.next.GeneratePassword()
}
func (mw *databaseTracingMiddleware) GenerateExpiration(duration time.Duration) (_ string, err error) {
if mw.logger.IsTrace() {
defer func(then time.Time) {
mw.logger.Trace("database/GenerateExpiration: finished", "type", mw.typeStr, "err", err, "took", time.Since(then))
}(time.Now())
mw.logger.Trace("database/GenerateExpiration: starting", "type", mw.typeStr)
}
return mw.next.GenerateExpiration(duration)
}
// ---- Metrics Middleware Domain ----
type databaseMetricsMiddleware struct {
next DatabaseType

View file

@ -7,6 +7,7 @@ import (
"time"
"github.com/hashicorp/vault/logical"
log "github.com/mgutz/logxi/v1"
)
const (
@ -21,9 +22,9 @@ var (
ErrEmptyCreationStatement = errors.New("Empty creation statements")
)
type Factory func(*DatabaseConfig, logical.SystemView) (DatabaseType, error)
type Factory func(*DatabaseConfig, logical.SystemView, log.Logger) (DatabaseType, error)
func BuiltinFactory(conf *DatabaseConfig, sys logical.SystemView) (DatabaseType, error) {
func BuiltinFactory(conf *DatabaseConfig, sys logical.SystemView, logger log.Logger) (DatabaseType, error) {
var dbType DatabaseType
switch conf.DatabaseType {
@ -76,10 +77,17 @@ func BuiltinFactory(conf *DatabaseConfig, sys logical.SystemView) (DatabaseType,
typeStr: dbType.Type(),
}
// Wrap with tracing middleware
dbType = &databaseTracingMiddleware{
next: dbType,
typeStr: dbType.Type(),
logger: logger,
}
return dbType, nil
}
func PluginFactory(conf *DatabaseConfig, sys logical.SystemView) (DatabaseType, error) {
func PluginFactory(conf *DatabaseConfig, sys logical.SystemView, logger log.Logger) (DatabaseType, error) {
if conf.PluginCommand == "" {
return nil, errors.New("ERROR")
}
@ -99,6 +107,13 @@ func PluginFactory(conf *DatabaseConfig, sys logical.SystemView) (DatabaseType,
typeStr: db.Type(),
}
// Wrap with tracing middleware
db = &databaseTracingMiddleware{
next: db,
typeStr: db.Type(),
logger: logger,
}
return db, nil
}

View file

@ -82,7 +82,12 @@ func newPluginClient(sys pluginutil.Wrapper, command, checksum string) (Database
// Add the response wrap token to the ENV of the plugin
commandArr := strings.Split(command, " ")
cmd := exec.Command(commandArr[0], commandArr[1])
var cmd *exec.Cmd
if len(commandArr) > 1 {
cmd = exec.Command(commandArr[0], commandArr[1])
} else {
cmd = exec.Command(commandArr[0])
}
cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", pluginutil.PluginUnwrapTokenEnv, wrapToken))
checksumDecoded, err := hex.DecodeString(checksum)

View file

@ -43,13 +43,11 @@ func (p *PostgreSQL) CreateUser(statements Statements, username, password, expir
}
// Start a transaction
// b.logger.Trace("postgres/pathRoleCreateRead: starting transaction")
tx, err := db.Begin()
if err != nil {
return err
}
defer func() {
// b.logger.Trace("postgres/pathRoleCreateRead: rolling back transaction")
tx.Rollback()
}()
// Return the secret
@ -61,7 +59,6 @@ func (p *PostgreSQL) CreateUser(statements Statements, username, password, expir
continue
}
// b.logger.Trace("postgres/pathRoleCreateRead: preparing statement")
stmt, err := tx.Prepare(queryHelper(query, map[string]string{
"name": username,
"password": password,
@ -71,15 +68,12 @@ func (p *PostgreSQL) CreateUser(statements Statements, username, password, expir
return err
}
defer stmt.Close()
// b.logger.Trace("postgres/pathRoleCreateRead: executing statement")
if _, err := stmt.Exec(); err != nil {
return err
}
}
// Commit the transaction
// b.logger.Trace("postgres/pathRoleCreateRead: committing transaction")
if err := tx.Commit(); err != nil {
return err
}

View file

@ -63,7 +63,7 @@ func (b *databaseBackend) pathConnectionReset(req *logical.Request, data *framew
factory := config.GetFactory()
db, err = factory(&config, b.System())
db, err = factory(&config, b.System(), b.logger)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil
}
@ -262,7 +262,7 @@ func (b *databaseBackend) connectionWriteHandler(factory dbs.Factory) framework.
b.Lock()
defer b.Unlock()
db, err := factory(config, b.System())
db, err := factory(config, b.System(), b.logger)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil
}