// Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: MPL-2.0 package database import ( "context" "fmt" "net/rpc" "strings" "sync" "time" "github.com/armon/go-metrics" log "github.com/hashicorp/go-hclog" "github.com/hashicorp/go-secure-stdlib/strutil" "github.com/hashicorp/go-uuid" "github.com/hashicorp/vault/helper/metricsutil" "github.com/hashicorp/vault/internalshared/configutil" v4 "github.com/hashicorp/vault/sdk/database/dbplugin" v5 "github.com/hashicorp/vault/sdk/database/dbplugin/v5" "github.com/hashicorp/vault/sdk/database/helper/dbutil" "github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/helper/locksutil" "github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/sdk/queue" ) const ( operationPrefixDatabase = "database" databaseConfigPath = "config/" databaseRolePath = "role/" databaseStaticRolePath = "static-role/" minRootCredRollbackAge = 1 * time.Minute ) type dbPluginInstance struct { sync.RWMutex database databaseVersionWrapper id string name string closed bool } func (dbi *dbPluginInstance) Close() error { dbi.Lock() defer dbi.Unlock() if dbi.closed { return nil } dbi.closed = true return dbi.database.Close() } func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) { b := Backend(conf) if err := b.Setup(ctx, conf); err != nil { return nil, err } b.credRotationQueue = queue.New() // Load queue and kickoff new periodic ticker go b.initQueue(b.queueCtx, conf, conf.System.ReplicationState()) // collect metrics on number of plugin instances var err error b.gaugeCollectionProcess, err = metricsutil.NewGaugeCollectionProcess( []string{"secrets", "database", "backend", "pluginInstances", "count"}, []metricsutil.Label{}, b.collectPluginInstanceGaugeValues, metrics.Default(), configutil.UsageGaugeDefaultPeriod, // TODO: add config settings for these, or add plumbing to the main config settings configutil.MaximumGaugeCardinalityDefault, b.logger) if err != nil { return nil, err } go b.gaugeCollectionProcess.Run() return b, nil } func Backend(conf *logical.BackendConfig) *databaseBackend { var b databaseBackend b.Backend = &framework.Backend{ Help: strings.TrimSpace(backendHelp), PathsSpecial: &logical.Paths{ LocalStorage: []string{ framework.WALPrefix, }, SealWrapStorage: []string{ "config/*", "static-role/*", }, }, Paths: framework.PathAppend( []*framework.Path{ pathListPluginConnection(&b), pathConfigurePluginConnection(&b), pathResetConnection(&b), }, pathListRoles(&b), pathRoles(&b), pathCredsCreate(&b), pathRotateRootCredentials(&b), ), Secrets: []*framework.Secret{ secretCreds(&b), }, Clean: b.clean, Invalidate: b.invalidate, WALRollback: b.walRollback, WALRollbackMinAge: minRootCredRollbackAge, BackendType: logical.TypeLogical, } b.logger = conf.Logger b.connections = make(map[string]*dbPluginInstance) b.queueCtx, b.cancelQueueCtx = context.WithCancel(context.Background()) b.roleLocks = locksutil.CreateLocks() return &b } func (b *databaseBackend) collectPluginInstanceGaugeValues(context.Context) ([]metricsutil.GaugeLabelValues, error) { // copy the map so we can release the lock connMapCopy := func() map[string]*dbPluginInstance { b.connLock.RLock() defer b.connLock.RUnlock() mapCopy := map[string]*dbPluginInstance{} for k, v := range b.connections { mapCopy[k] = v } return mapCopy }() counts := map[string]int{} for _, v := range connMapCopy { dbType, err := v.database.Type() if err != nil { // there's a chance this will already be closed since we don't hold the lock continue } if _, ok := counts[dbType]; !ok { counts[dbType] = 0 } counts[dbType] += 1 } var gauges []metricsutil.GaugeLabelValues for k, v := range counts { gauges = append(gauges, metricsutil.GaugeLabelValues{Labels: []metricsutil.Label{{Name: "dbType", Value: k}}, Value: float32(v)}) } return gauges, nil } type databaseBackend struct { // connLock is used to synchronize access to the connections map connLock sync.RWMutex // connections holds configured database connections by config name connections map[string]*dbPluginInstance logger log.Logger *framework.Backend // credRotationQueue is an in-memory priority queue used to track Static Roles // that require periodic rotation. Backends will have a PriorityQueue // initialized on setup, but only backends that are mounted by a primary // server or mounted as a local mount will perform the rotations. credRotationQueue *queue.PriorityQueue // queueCtx is the context for the priority queue queueCtx context.Context // cancelQueueCtx is used to terminate the background ticker cancelQueueCtx context.CancelFunc // roleLocks is used to lock modifications to roles in the queue, to ensure // concurrent requests are not modifying the same role and possibly causing // issues with the priority queue. roleLocks []*locksutil.LockEntry // the running gauge collection process gaugeCollectionProcess *metricsutil.GaugeCollectionProcess gaugeCollectionProcessStop sync.Once } func (b *databaseBackend) connGet(name string) *dbPluginInstance { b.connLock.RLock() defer b.connLock.RUnlock() return b.connections[name] } func (b *databaseBackend) connPop(name string) *dbPluginInstance { b.connLock.Lock() defer b.connLock.Unlock() dbi, ok := b.connections[name] if ok { delete(b.connections, name) } return dbi } func (b *databaseBackend) connPopIfEqual(name, id string) *dbPluginInstance { b.connLock.Lock() defer b.connLock.Unlock() dbi, ok := b.connections[name] if ok && dbi.id == id { delete(b.connections, name) return dbi } return nil } func (b *databaseBackend) connPut(name string, newDbi *dbPluginInstance) *dbPluginInstance { b.connLock.Lock() defer b.connLock.Unlock() dbi := b.connections[name] b.connections[name] = newDbi return dbi } func (b *databaseBackend) connClear() map[string]*dbPluginInstance { b.connLock.Lock() defer b.connLock.Unlock() old := b.connections b.connections = make(map[string]*dbPluginInstance) return old } func (b *databaseBackend) DatabaseConfig(ctx context.Context, s logical.Storage, name string) (*DatabaseConfig, error) { entry, err := s.Get(ctx, fmt.Sprintf("config/%s", name)) if err != nil { return nil, fmt.Errorf("failed to read connection configuration: %w", err) } if entry == nil { return nil, fmt.Errorf("failed to find entry for connection with name: %q", name) } var config DatabaseConfig if err := entry.DecodeJSON(&config); err != nil { return nil, err } return &config, nil } type upgradeStatements struct { // This json tag has a typo in it, the new version does not. This // necessitates this upgrade logic. CreationStatements string `json:"creation_statments"` RevocationStatements string `json:"revocation_statements"` RollbackStatements string `json:"rollback_statements"` RenewStatements string `json:"renew_statements"` } type upgradeCheck struct { // This json tag has a typo in it, the new version does not. This // necessitates this upgrade logic. Statements *upgradeStatements `json:"statments,omitempty"` } func (b *databaseBackend) Role(ctx context.Context, s logical.Storage, roleName string) (*roleEntry, error) { return b.roleAtPath(ctx, s, roleName, databaseRolePath) } func (b *databaseBackend) StaticRole(ctx context.Context, s logical.Storage, roleName string) (*roleEntry, error) { return b.roleAtPath(ctx, s, roleName, databaseStaticRolePath) } func (b *databaseBackend) roleAtPath(ctx context.Context, s logical.Storage, roleName string, pathPrefix string) (*roleEntry, error) { entry, err := s.Get(ctx, pathPrefix+roleName) if err != nil { return nil, err } if entry == nil { return nil, nil } var upgradeCh upgradeCheck if err := entry.DecodeJSON(&upgradeCh); err != nil { return nil, err } var result roleEntry if err := entry.DecodeJSON(&result); err != nil { return nil, err } switch { case upgradeCh.Statements != nil: var stmts v4.Statements if upgradeCh.Statements.CreationStatements != "" { stmts.Creation = []string{upgradeCh.Statements.CreationStatements} } if upgradeCh.Statements.RevocationStatements != "" { stmts.Revocation = []string{upgradeCh.Statements.RevocationStatements} } if upgradeCh.Statements.RollbackStatements != "" { stmts.Rollback = []string{upgradeCh.Statements.RollbackStatements} } if upgradeCh.Statements.RenewStatements != "" { stmts.Renewal = []string{upgradeCh.Statements.RenewStatements} } result.Statements = stmts } result.Statements.Revocation = strutil.RemoveEmpty(result.Statements.Revocation) // For backwards compatibility, copy the values back into the string form // of the fields result.Statements = dbutil.StatementCompatibilityHelper(result.Statements) return &result, nil } func (b *databaseBackend) invalidate(ctx context.Context, key string) { switch { case strings.HasPrefix(key, databaseConfigPath): name := strings.TrimPrefix(key, databaseConfigPath) b.ClearConnection(name) } } func (b *databaseBackend) GetConnection(ctx context.Context, s logical.Storage, name string) (*dbPluginInstance, error) { config, err := b.DatabaseConfig(ctx, s, name) if err != nil { return nil, err } return b.GetConnectionWithConfig(ctx, name, config) } func (b *databaseBackend) GetConnectionWithConfig(ctx context.Context, name string, config *DatabaseConfig) (*dbPluginInstance, error) { dbi := b.connGet(name) if dbi != nil { return dbi, nil } id, err := uuid.GenerateUUID() if err != nil { return nil, err } dbw, err := newDatabaseWrapper(ctx, config.PluginName, config.PluginVersion, b.System(), b.logger) if err != nil { return nil, fmt.Errorf("unable to create database instance: %w", err) } initReq := v5.InitializeRequest{ Config: config.ConnectionDetails, VerifyConnection: true, } _, err = dbw.Initialize(ctx, initReq) if err != nil { dbw.Close() return nil, err } dbi = &dbPluginInstance{ database: dbw, id: id, name: name, } oldConn := b.connPut(name, dbi) if oldConn != nil { err := oldConn.Close() if err != nil { b.Logger().Warn("Error closing database connection", "error", err) } } return dbi, nil } // ClearConnection closes the database connection and // removes it from the b.connections map. func (b *databaseBackend) ClearConnection(name string) error { db := b.connPop(name) if db != nil { // Ignore error here since the database client is always killed db.Close() } return nil } // ClearConnectionId closes the database connection with a specific id and // removes it from the b.connections map. func (b *databaseBackend) ClearConnectionId(name, id string) error { db := b.connPopIfEqual(name, id) if db != nil { // Ignore error here since the database client is always killed db.Close() } return nil } func (b *databaseBackend) CloseIfShutdown(db *dbPluginInstance, err error) { // Plugin has shutdown, close it so next call can reconnect. switch err { case rpc.ErrShutdown, v4.ErrPluginShutdown, v5.ErrPluginShutdown: // Put this in a goroutine so that requests can run with the read or write lock // and simply defer the unlock. Since we are attaching the instance and matching // the id in the connection map, we can safely do this. go func() { db.Close() // Delete the connection if it is still active. b.connPopIfEqual(db.name, db.id) }() } } // clean closes all connections from all database types // and cancels any rotation queue loading operation. func (b *databaseBackend) clean(_ context.Context) { // kill the queue and terminate the background ticker if b.cancelQueueCtx != nil { b.cancelQueueCtx() } connections := b.connClear() for _, db := range connections { go db.Close() } b.gaugeCollectionProcessStop.Do(func() { if b.gaugeCollectionProcess != nil { b.gaugeCollectionProcess.Stop() } b.gaugeCollectionProcess = nil }) } const backendHelp = ` The database backend supports using many different databases as secret backends, including but not limited to: cassandra, mssql, mysql, postgres After mounting this backend, configure it using the endpoints within the "database/config/" path. `