open-vault/sdk/database/dbplugin/v5/middleware.go
Christopher Swenson 2c8e88ab67
Check if plugin version matches running version (#17182)
Check if plugin version matches running version

When registering a plugin, we check if the request version matches the
self-reported version from the plugin. If these do not match, we log a
warning.

This uncovered a few missing pieces for getting the database version
code fully working.

We added an environment variable that helps us unit test the running
version behavior as well, but only for approle, postgresql, and consul
plugins.

Return 400 on plugin not found or version mismatch

Populate the running SHA256 of plugins in the mount and auth tables (#17217)
2022-09-21 12:25:04 -07:00

322 lines
9.9 KiB
Go

package dbplugin
import (
"context"
"errors"
"net/url"
"strings"
"time"
"github.com/armon/go-metrics"
"github.com/hashicorp/errwrap"
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/sdk/logical"
"google.golang.org/grpc/status"
)
// ///////////////////////////////////////////////////
// Tracing Middleware
// ///////////////////////////////////////////////////
// Compile-time assertions that databaseTracingMiddleware satisfies both the
// Database and logical.PluginVersioner interfaces.
var (
_ Database = databaseTracingMiddleware{}
_ logical.PluginVersioner = databaseTracingMiddleware{}
)
// databaseTracingMiddleware wraps an implementation of Database and writes
// trace-level log entries around the calls it forwards. Type is forwarded
// without any logging.
type databaseTracingMiddleware struct {
// next is the wrapped Database every call is delegated to.
next Database
// logger receives the "started"/"finished" trace entries.
logger log.Logger
}
// PluginVersion returns the version reported by the wrapped Database, or
// logical.EmptyPluginVersion when the wrapped Database does not implement
// logical.PluginVersioner. A trace entry is logged before the lookup and
// another, carrying the resolved version and the elapsed time, afterwards.
func (mw databaseTracingMiddleware) PluginVersion() (resp logical.PluginVersion) {
	start := time.Now()
	// The closure reads the named result, so it logs whatever is returned.
	defer func() {
		mw.logger.Trace("version",
			"status", "finished",
			"version", resp,
			"took", time.Since(start))
	}()

	mw.logger.Trace("version", "status", "started")

	versioner, ok := mw.next.(logical.PluginVersioner)
	if !ok {
		return logical.EmptyPluginVersion
	}
	return versioner.PluginVersion()
}
// Initialize forwards the request to the wrapped Database. It logs a trace
// entry when the call starts and another when it finishes; the latter includes
// req.VerifyConnection, the returned error and the elapsed time.
func (mw databaseTracingMiddleware) Initialize(ctx context.Context, req InitializeRequest) (resp InitializeResponse, err error) {
	start := time.Now()
	// The closure reads the named err result, so it logs the final error.
	defer func() {
		mw.logger.Trace("initialize",
			"status", "finished",
			"verify", req.VerifyConnection,
			"err", err,
			"took", time.Since(start))
	}()

	mw.logger.Trace("initialize", "status", "started")

	resp, err = mw.next.Initialize(ctx, req)
	return resp, err
}
// NewUser forwards the request to the wrapped Database, logging a trace entry
// before the call and another, with the returned error and elapsed time, after
// it completes.
func (mw databaseTracingMiddleware) NewUser(ctx context.Context, req NewUserRequest) (resp NewUserResponse, err error) {
	start := time.Now()
	// The closure reads the named err result, so it logs the final error.
	defer func() {
		mw.logger.Trace("create user",
			"status", "finished",
			"err", err,
			"took", time.Since(start))
	}()

	mw.logger.Trace("create user", "status", "started")

	resp, err = mw.next.NewUser(ctx, req)
	return resp, err
}
// UpdateUser forwards the request to the wrapped Database, logging a trace
// entry before the call and another, with the returned error and elapsed time,
// after it completes.
func (mw databaseTracingMiddleware) UpdateUser(ctx context.Context, req UpdateUserRequest) (resp UpdateUserResponse, err error) {
	start := time.Now()
	// The closure reads the named err result, so it logs the final error.
	defer func() {
		mw.logger.Trace("update user",
			"status", "finished",
			"err", err,
			"took", time.Since(start))
	}()

	mw.logger.Trace("update user", "status", "started")

	resp, err = mw.next.UpdateUser(ctx, req)
	return resp, err
}
// DeleteUser forwards the request to the wrapped Database, logging a trace
// entry before the call and another, with the returned error and elapsed time,
// after it completes.
func (mw databaseTracingMiddleware) DeleteUser(ctx context.Context, req DeleteUserRequest) (resp DeleteUserResponse, err error) {
	start := time.Now()
	// The closure reads the named err result, so it logs the final error.
	defer func() {
		mw.logger.Trace("delete user",
			"status", "finished",
			"err", err,
			"took", time.Since(start))
	}()

	mw.logger.Trace("delete user", "status", "started")

	resp, err = mw.next.DeleteUser(ctx, req)
	return resp, err
}
// Type returns whatever the wrapped Database reports as its type. Unlike the
// other methods on this middleware, no trace entries are written.
func (mw databaseTracingMiddleware) Type() (string, error) {
	dbType, err := mw.next.Type()
	return dbType, err
}
// Close closes the wrapped Database, logging a trace entry before the call and
// another, with the returned error and elapsed time, after it completes.
func (mw databaseTracingMiddleware) Close() (err error) {
	start := time.Now()
	// The closure reads the named err result, so it logs the final error.
	defer func() {
		mw.logger.Trace("close",
			"status", "finished",
			"err", err,
			"took", time.Since(start))
	}()

	mw.logger.Trace("close", "status", "started")

	err = mw.next.Close()
	return err
}
// ///////////////////////////////////////////////////
// Metrics Middleware Domain
// ///////////////////////////////////////////////////
// Compile-time assertions that databaseMetricsMiddleware satisfies both the
// Database and logical.PluginVersioner interfaces.
var (
_ Database = databaseMetricsMiddleware{}
_ logical.PluginVersioner = databaseMetricsMiddleware{}
)
// databaseMetricsMiddleware wraps an implementation of Database and records
// metrics (call counters, timings and, where an error is returned, error
// counters) for the calls it forwards. Type is forwarded without metrics.
type databaseMetricsMiddleware struct {
// next is the wrapped Database every call is delegated to.
next Database
// typeStr is inserted as the second element of the per-type metric keys,
// e.g. {"database", typeStr, "Initialize"}.
typeStr string
}
// PluginVersion returns the version reported by the wrapped Database, or
// logical.EmptyPluginVersion when the wrapped Database does not implement
// logical.PluginVersioner. Each call increments the "PluginVersion" counters
// and records its duration under both the generic and the per-type key.
func (mw databaseMetricsMiddleware) PluginVersion() logical.PluginVersion {
	start := time.Now()
	defer func() {
		metrics.MeasureSince([]string{"database", "PluginVersion"}, start)
		metrics.MeasureSince([]string{"database", mw.typeStr, "PluginVersion"}, start)
	}()

	metrics.IncrCounter([]string{"database", "PluginVersion"}, 1)
	metrics.IncrCounter([]string{"database", mw.typeStr, "PluginVersion"}, 1)

	versioner, ok := mw.next.(logical.PluginVersioner)
	if !ok {
		return logical.EmptyPluginVersion
	}
	return versioner.PluginVersion()
}
// Initialize forwards the request to the wrapped Database. Each call
// increments the "Initialize" counters and records its duration under both the
// generic and the per-type key; a non-nil returned error additionally
// increments the matching "error" counters.
func (mw databaseMetricsMiddleware) Initialize(ctx context.Context, req InitializeRequest) (resp InitializeResponse, err error) {
	start := time.Now()
	defer func() {
		metrics.MeasureSince([]string{"database", "Initialize"}, start)
		metrics.MeasureSince([]string{"database", mw.typeStr, "Initialize"}, start)

		// err is the named result, so this sees the error being returned.
		if err == nil {
			return
		}
		metrics.IncrCounter([]string{"database", "Initialize", "error"}, 1)
		metrics.IncrCounter([]string{"database", mw.typeStr, "Initialize", "error"}, 1)
	}()

	metrics.IncrCounter([]string{"database", "Initialize"}, 1)
	metrics.IncrCounter([]string{"database", mw.typeStr, "Initialize"}, 1)

	resp, err = mw.next.Initialize(ctx, req)
	return resp, err
}
// NewUser forwards the request to the wrapped Database. Each call increments
// the "NewUser" counters and records its duration under both the generic and
// the per-type key; a non-nil returned error additionally increments the
// matching "error" counters.
func (mw databaseMetricsMiddleware) NewUser(ctx context.Context, req NewUserRequest) (resp NewUserResponse, err error) {
	began := time.Now()
	defer func() {
		metrics.MeasureSince([]string{"database", "NewUser"}, began)
		metrics.MeasureSince([]string{"database", mw.typeStr, "NewUser"}, began)

		// err is the named result, so this sees the error being returned.
		if err == nil {
			return
		}
		metrics.IncrCounter([]string{"database", "NewUser", "error"}, 1)
		metrics.IncrCounter([]string{"database", mw.typeStr, "NewUser", "error"}, 1)
	}()

	metrics.IncrCounter([]string{"database", "NewUser"}, 1)
	metrics.IncrCounter([]string{"database", mw.typeStr, "NewUser"}, 1)

	resp, err = mw.next.NewUser(ctx, req)
	return resp, err
}
// UpdateUser forwards the request to the wrapped Database. Each call
// increments the "UpdateUser" counters and records its duration under both the
// generic and the per-type key; a non-nil returned error additionally
// increments the matching "error" counters.
func (mw databaseMetricsMiddleware) UpdateUser(ctx context.Context, req UpdateUserRequest) (resp UpdateUserResponse, err error) {
	start := time.Now()
	defer func() {
		metrics.MeasureSince([]string{"database", "UpdateUser"}, start)
		metrics.MeasureSince([]string{"database", mw.typeStr, "UpdateUser"}, start)

		// err is the named result, so this sees the error being returned.
		if err == nil {
			return
		}
		metrics.IncrCounter([]string{"database", "UpdateUser", "error"}, 1)
		metrics.IncrCounter([]string{"database", mw.typeStr, "UpdateUser", "error"}, 1)
	}()

	metrics.IncrCounter([]string{"database", "UpdateUser"}, 1)
	metrics.IncrCounter([]string{"database", mw.typeStr, "UpdateUser"}, 1)

	resp, err = mw.next.UpdateUser(ctx, req)
	return resp, err
}
// DeleteUser forwards the request to the wrapped Database. Each call
// increments the "DeleteUser" counters and records its duration under both the
// generic and the per-type key; a non-nil returned error additionally
// increments the matching "error" counters.
func (mw databaseMetricsMiddleware) DeleteUser(ctx context.Context, req DeleteUserRequest) (resp DeleteUserResponse, err error) {
	start := time.Now()
	defer func() {
		metrics.MeasureSince([]string{"database", "DeleteUser"}, start)
		metrics.MeasureSince([]string{"database", mw.typeStr, "DeleteUser"}, start)

		// err is the named result, so this sees the error being returned.
		if err == nil {
			return
		}
		metrics.IncrCounter([]string{"database", "DeleteUser", "error"}, 1)
		metrics.IncrCounter([]string{"database", mw.typeStr, "DeleteUser", "error"}, 1)
	}()

	metrics.IncrCounter([]string{"database", "DeleteUser"}, 1)
	metrics.IncrCounter([]string{"database", mw.typeStr, "DeleteUser"}, 1)

	resp, err = mw.next.DeleteUser(ctx, req)
	return resp, err
}
// Type returns whatever the wrapped Database reports as its type. Unlike the
// other methods on this middleware, no metrics are recorded.
func (mw databaseMetricsMiddleware) Type() (string, error) {
	dbType, err := mw.next.Type()
	return dbType, err
}
// Close closes the wrapped Database. Each call increments the "Close" counters
// and records its duration under both the generic and the per-type key; a
// non-nil returned error additionally increments the matching "error"
// counters.
func (mw databaseMetricsMiddleware) Close() (err error) {
	start := time.Now()
	defer func() {
		metrics.MeasureSince([]string{"database", "Close"}, start)
		metrics.MeasureSince([]string{"database", mw.typeStr, "Close"}, start)

		// err is the named result, so this sees the error being returned.
		if err == nil {
			return
		}
		metrics.IncrCounter([]string{"database", "Close", "error"}, 1)
		metrics.IncrCounter([]string{"database", mw.typeStr, "Close", "error"}, 1)
	}()

	metrics.IncrCounter([]string{"database", "Close"}, 1)
	metrics.IncrCounter([]string{"database", mw.typeStr, "Close"}, 1)

	err = mw.next.Close()
	return err
}
// ///////////////////////////////////////////////////
// Error Sanitizer Middleware Domain
// ///////////////////////////////////////////////////
// Compile-time assertions that *DatabaseErrorSanitizerMiddleware satisfies
// both the Database and logical.PluginVersioner interfaces.
var (
_ Database = (*DatabaseErrorSanitizerMiddleware)(nil)
_ logical.PluginVersioner = (*DatabaseErrorSanitizerMiddleware)(nil)
)
// DatabaseErrorSanitizerMiddleware wraps an implementation of Database and
// sanitizes the error messages it returns.
type DatabaseErrorSanitizerMiddleware struct {
// next is the wrapped Database every call is delegated to.
next Database
// secretsFn supplies the find/replace pairs applied to error messages.
// When it is nil, errors are not rewritten using secrets.
secretsFn secretsFn
}
// secretsFn returns a map whose keys are the strings to remove from error
// messages and whose values are the text substituted in their place.
type secretsFn func() map[string]string
// NewDatabaseErrorSanitizerMiddleware returns a DatabaseErrorSanitizerMiddleware
// that delegates to next and uses secrets to obtain the strings that must be
// replaced in returned error messages.
func NewDatabaseErrorSanitizerMiddleware(next Database, secrets secretsFn) DatabaseErrorSanitizerMiddleware {
	var mw DatabaseErrorSanitizerMiddleware
	mw.next = next
	mw.secretsFn = secrets
	return mw
}
// Initialize forwards the request to the wrapped Database and passes the
// returned error through sanitize before handing it back.
func (mw DatabaseErrorSanitizerMiddleware) Initialize(ctx context.Context, req InitializeRequest) (resp InitializeResponse, err error) {
	resp, err = mw.next.Initialize(ctx, req)
	err = mw.sanitize(err)
	return resp, err
}
// NewUser forwards the request to the wrapped Database and passes the returned
// error through sanitize before handing it back.
func (mw DatabaseErrorSanitizerMiddleware) NewUser(ctx context.Context, req NewUserRequest) (resp NewUserResponse, err error) {
	resp, err = mw.next.NewUser(ctx, req)
	err = mw.sanitize(err)
	return resp, err
}
// UpdateUser forwards the request to the wrapped Database and passes the
// returned error through sanitize before handing it back.
func (mw DatabaseErrorSanitizerMiddleware) UpdateUser(ctx context.Context, req UpdateUserRequest) (UpdateUserResponse, error) {
	resp, rawErr := mw.next.UpdateUser(ctx, req)
	cleanErr := mw.sanitize(rawErr)
	return resp, cleanErr
}
// DeleteUser forwards the request to the wrapped Database and passes the
// returned error through sanitize before handing it back.
func (mw DatabaseErrorSanitizerMiddleware) DeleteUser(ctx context.Context, req DeleteUserRequest) (DeleteUserResponse, error) {
	resp, rawErr := mw.next.DeleteUser(ctx, req)
	cleanErr := mw.sanitize(rawErr)
	return resp, cleanErr
}
// Type returns the type reported by the wrapped Database, with any returned
// error passed through sanitize.
func (mw DatabaseErrorSanitizerMiddleware) Type() (string, error) {
	name, rawErr := mw.next.Type()
	cleanErr := mw.sanitize(rawErr)
	return name, cleanErr
}
// Close closes the wrapped Database and returns its error after passing it
// through sanitize.
func (mw DatabaseErrorSanitizerMiddleware) Close() (err error) {
	err = mw.next.Close()
	return mw.sanitize(err)
}
// PluginVersion returns the version reported by the wrapped Database, or
// logical.EmptyPluginVersion when the wrapped Database does not implement
// logical.PluginVersioner.
func (mw DatabaseErrorSanitizerMiddleware) PluginVersion() logical.PluginVersion {
	versioner, ok := mw.next.(logical.PluginVersioner)
	if !ok {
		return logical.EmptyPluginVersion
	}
	return versioner.PluginVersion()
}
// sanitize strips sensitive strings from the message of err, using the
// find/replace pairs supplied by secretsFn.
//
// A nil err yields nil. If errwrap.ContainsType reports a *url.Error in err,
// the whole error is replaced by a fixed generic message. If no secretsFn is
// configured, err is returned unchanged. Otherwise every non-empty key returned
// by secretsFn is replaced by its value in the message, producing a new error
// on each replacement.
func (mw DatabaseErrorSanitizerMiddleware) sanitize(err error) error {
	switch {
	case err == nil:
		return nil
	case errwrap.ContainsType(err, new(url.Error)):
		return errors.New("unable to parse connection url")
	case mw.secretsFn == nil:
		return err
	}

	for secret, mask := range mw.secretsFn() {
		if secret == "" {
			continue
		}

		// When the error carries a gRPC status, rebuild it with the same
		// status code so only the message text changes.
		if st, isStatus := status.FromError(err); isStatus {
			err = status.Error(st.Code(), strings.ReplaceAll(st.Message(), secret, mask))
		} else {
			err = errors.New(strings.ReplaceAll(err.Error(), secret, mask))
		}
	}
	return err
}