Wrap the database calls with tracing information
This commit is contained in:
parent
2799586f45
commit
494f963581
|
@ -4,8 +4,112 @@ import (
|
|||
"time"
|
||||
|
||||
metrics "github.com/armon/go-metrics"
|
||||
log "github.com/mgutz/logxi/v1"
|
||||
)
|
||||
|
||||
// ---- Tracing Middleware Domain ----
|
||||
|
||||
type databaseTracingMiddleware struct {
|
||||
next DatabaseType
|
||||
logger log.Logger
|
||||
|
||||
typeStr string
|
||||
}
|
||||
|
||||
func (mw *databaseTracingMiddleware) Type() string {
|
||||
return mw.next.Type()
|
||||
}
|
||||
|
||||
func (mw *databaseTracingMiddleware) CreateUser(statements Statements, username, password, expiration string) (err error) {
|
||||
if mw.logger.IsTrace() {
|
||||
defer func(then time.Time) {
|
||||
mw.logger.Trace("database/CreateUser: finished", "type", mw.typeStr, "err", err, "took", time.Since(then))
|
||||
}(time.Now())
|
||||
|
||||
mw.logger.Trace("database/CreateUser: starting", "type", mw.typeStr)
|
||||
}
|
||||
return mw.next.CreateUser(statements, username, password, expiration)
|
||||
}
|
||||
|
||||
func (mw *databaseTracingMiddleware) RenewUser(statements Statements, username, expiration string) (err error) {
|
||||
if mw.logger.IsTrace() {
|
||||
defer func(then time.Time) {
|
||||
mw.logger.Trace("database/RenewUser: finished", "type", mw.typeStr, "err", err, "took", time.Since(then))
|
||||
}(time.Now())
|
||||
|
||||
mw.logger.Trace("database/RenewUser: starting", "type", mw.typeStr)
|
||||
}
|
||||
return mw.next.RenewUser(statements, username, expiration)
|
||||
}
|
||||
|
||||
func (mw *databaseTracingMiddleware) RevokeUser(statements Statements, username string) (err error) {
|
||||
if mw.logger.IsTrace() {
|
||||
defer func(then time.Time) {
|
||||
mw.logger.Trace("database/RevokeUser: finished", "type", mw.typeStr, "err", err, "took", time.Since(then))
|
||||
}(time.Now())
|
||||
|
||||
mw.logger.Trace("database/RevokeUser: starting", "type", mw.typeStr)
|
||||
}
|
||||
return mw.next.RevokeUser(statements, username)
|
||||
}
|
||||
|
||||
func (mw *databaseTracingMiddleware) Initialize(conf map[string]interface{}) (err error) {
|
||||
if mw.logger.IsTrace() {
|
||||
defer func(then time.Time) {
|
||||
mw.logger.Trace("database/Initialize: finished", "type", mw.typeStr, "err", err, "took", time.Since(then))
|
||||
}(time.Now())
|
||||
|
||||
mw.logger.Trace("database/Initialize: starting", "type", mw.typeStr)
|
||||
}
|
||||
return mw.next.Initialize(conf)
|
||||
}
|
||||
|
||||
func (mw *databaseTracingMiddleware) Close() (err error) {
|
||||
if mw.logger.IsTrace() {
|
||||
defer func(then time.Time) {
|
||||
mw.logger.Trace("database/Close: finished", "type", mw.typeStr, "err", err, "took", time.Since(then))
|
||||
}(time.Now())
|
||||
|
||||
mw.logger.Trace("database/Close: starting", "type", mw.typeStr)
|
||||
}
|
||||
return mw.next.Close()
|
||||
}
|
||||
|
||||
func (mw *databaseTracingMiddleware) GenerateUsername(displayName string) (_ string, err error) {
|
||||
if mw.logger.IsTrace() {
|
||||
defer func(then time.Time) {
|
||||
mw.logger.Trace("database/GenerateUsername: finished", "type", mw.typeStr, "err", err, "took", time.Since(then))
|
||||
}(time.Now())
|
||||
|
||||
mw.logger.Trace("database/GenerateUsername: starting", "type", mw.typeStr)
|
||||
}
|
||||
return mw.next.GenerateUsername(displayName)
|
||||
}
|
||||
|
||||
func (mw *databaseTracingMiddleware) GeneratePassword() (_ string, err error) {
|
||||
if mw.logger.IsTrace() {
|
||||
defer func(then time.Time) {
|
||||
mw.logger.Trace("database/GeneratePassword: finished", "type", mw.typeStr, "err", err, "took", time.Since(then))
|
||||
}(time.Now())
|
||||
|
||||
mw.logger.Trace("database/GeneratePassword: starting", "type", mw.typeStr)
|
||||
}
|
||||
return mw.next.GeneratePassword()
|
||||
}
|
||||
|
||||
func (mw *databaseTracingMiddleware) GenerateExpiration(duration time.Duration) (_ string, err error) {
|
||||
if mw.logger.IsTrace() {
|
||||
defer func(then time.Time) {
|
||||
mw.logger.Trace("database/GenerateExpiration: finished", "type", mw.typeStr, "err", err, "took", time.Since(then))
|
||||
}(time.Now())
|
||||
|
||||
mw.logger.Trace("database/GenerateExpiration: starting", "type", mw.typeStr)
|
||||
}
|
||||
return mw.next.GenerateExpiration(duration)
|
||||
}
|
||||
|
||||
// ---- Metrics Middleware Domain ----
|
||||
|
||||
type databaseMetricsMiddleware struct {
|
||||
next DatabaseType
|
||||
|
|
@ -7,6 +7,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/logical"
|
||||
log "github.com/mgutz/logxi/v1"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -21,9 +22,9 @@ var (
|
|||
ErrEmptyCreationStatement = errors.New("Empty creation statements")
|
||||
)
|
||||
|
||||
type Factory func(*DatabaseConfig, logical.SystemView) (DatabaseType, error)
|
||||
type Factory func(*DatabaseConfig, logical.SystemView, log.Logger) (DatabaseType, error)
|
||||
|
||||
func BuiltinFactory(conf *DatabaseConfig, sys logical.SystemView) (DatabaseType, error) {
|
||||
func BuiltinFactory(conf *DatabaseConfig, sys logical.SystemView, logger log.Logger) (DatabaseType, error) {
|
||||
var dbType DatabaseType
|
||||
|
||||
switch conf.DatabaseType {
|
||||
|
@ -76,10 +77,17 @@ func BuiltinFactory(conf *DatabaseConfig, sys logical.SystemView) (DatabaseType,
|
|||
typeStr: dbType.Type(),
|
||||
}
|
||||
|
||||
// Wrap with tracing middleware
|
||||
dbType = &databaseTracingMiddleware{
|
||||
next: dbType,
|
||||
typeStr: dbType.Type(),
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
return dbType, nil
|
||||
}
|
||||
|
||||
func PluginFactory(conf *DatabaseConfig, sys logical.SystemView) (DatabaseType, error) {
|
||||
func PluginFactory(conf *DatabaseConfig, sys logical.SystemView, logger log.Logger) (DatabaseType, error) {
|
||||
if conf.PluginCommand == "" {
|
||||
return nil, errors.New("ERROR")
|
||||
}
|
||||
|
@ -99,6 +107,13 @@ func PluginFactory(conf *DatabaseConfig, sys logical.SystemView) (DatabaseType,
|
|||
typeStr: db.Type(),
|
||||
}
|
||||
|
||||
// Wrap with tracing middleware
|
||||
db = &databaseTracingMiddleware{
|
||||
next: db,
|
||||
typeStr: db.Type(),
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -82,7 +82,12 @@ func newPluginClient(sys pluginutil.Wrapper, command, checksum string) (Database
|
|||
|
||||
// Add the response wrap token to the ENV of the plugin
|
||||
commandArr := strings.Split(command, " ")
|
||||
cmd := exec.Command(commandArr[0], commandArr[1])
|
||||
var cmd *exec.Cmd
|
||||
if len(commandArr) > 1 {
|
||||
cmd = exec.Command(commandArr[0], commandArr[1])
|
||||
} else {
|
||||
cmd = exec.Command(commandArr[0])
|
||||
}
|
||||
cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", pluginutil.PluginUnwrapTokenEnv, wrapToken))
|
||||
|
||||
checksumDecoded, err := hex.DecodeString(checksum)
|
||||
|
|
|
@ -43,13 +43,11 @@ func (p *PostgreSQL) CreateUser(statements Statements, username, password, expir
|
|||
}
|
||||
|
||||
// Start a transaction
|
||||
// b.logger.Trace("postgres/pathRoleCreateRead: starting transaction")
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
// b.logger.Trace("postgres/pathRoleCreateRead: rolling back transaction")
|
||||
tx.Rollback()
|
||||
}()
|
||||
// Return the secret
|
||||
|
@ -61,7 +59,6 @@ func (p *PostgreSQL) CreateUser(statements Statements, username, password, expir
|
|||
continue
|
||||
}
|
||||
|
||||
// b.logger.Trace("postgres/pathRoleCreateRead: preparing statement")
|
||||
stmt, err := tx.Prepare(queryHelper(query, map[string]string{
|
||||
"name": username,
|
||||
"password": password,
|
||||
|
@ -71,15 +68,12 @@ func (p *PostgreSQL) CreateUser(statements Statements, username, password, expir
|
|||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
// b.logger.Trace("postgres/pathRoleCreateRead: executing statement")
|
||||
if _, err := stmt.Exec(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Commit the transaction
|
||||
|
||||
// b.logger.Trace("postgres/pathRoleCreateRead: committing transaction")
|
||||
if err := tx.Commit(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -63,7 +63,7 @@ func (b *databaseBackend) pathConnectionReset(req *logical.Request, data *framew
|
|||
|
||||
factory := config.GetFactory()
|
||||
|
||||
db, err = factory(&config, b.System())
|
||||
db, err = factory(&config, b.System(), b.logger)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil
|
||||
}
|
||||
|
@ -262,7 +262,7 @@ func (b *databaseBackend) connectionWriteHandler(factory dbs.Factory) framework.
|
|||
b.Lock()
|
||||
defer b.Unlock()
|
||||
|
||||
db, err := factory(config, b.System())
|
||||
db, err := factory(config, b.System(), b.logger)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue