// Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: MPL-2.0 package postgresql import ( "context" "database/sql" "fmt" "os" "strconv" "strings" "sync" "time" "github.com/armon/go-metrics" log "github.com/hashicorp/go-hclog" "github.com/hashicorp/go-uuid" "github.com/hashicorp/vault/sdk/database/helper/dbutil" "github.com/hashicorp/vault/sdk/physical" _ "github.com/jackc/pgx/v4/stdlib" ) const ( // The lock TTL matches the default that Consul API uses, 15 seconds. // Used as part of SQL commands to set/extend lock expiry time relative to // database clock. PostgreSQLLockTTLSeconds = 15 // The amount of time to wait between the lock renewals PostgreSQLLockRenewInterval = 5 * time.Second // PostgreSQLLockRetryInterval is the amount of time to wait // if a lock fails before trying again. PostgreSQLLockRetryInterval = time.Second ) // Verify PostgreSQLBackend satisfies the correct interfaces var _ physical.Backend = (*PostgreSQLBackend)(nil) // HA backend was implemented based on the DynamoDB backend pattern // With distinction using central postgres clock, hereby avoiding // possible issues with multiple clocks var ( _ physical.HABackend = (*PostgreSQLBackend)(nil) _ physical.Lock = (*PostgreSQLLock)(nil) ) // PostgreSQL Backend is a physical backend that stores data // within a PostgreSQL database. type PostgreSQLBackend struct { table string client *sql.DB put_query string get_query string delete_query string list_query string ha_table string haGetLockValueQuery string haUpsertLockIdentityExec string haDeleteLockExec string haEnabled bool logger log.Logger permitPool *physical.PermitPool } // PostgreSQLLock implements a lock using an PostgreSQL client. type PostgreSQLLock struct { backend *PostgreSQLBackend value, key string identity string lock sync.Mutex renewTicker *time.Ticker // ttlSeconds is how long a lock is valid for ttlSeconds int // renewInterval is how much time to wait between lock renewals. must be << ttl renewInterval time.Duration // retryInterval is how much time to wait between attempts to grab the lock retryInterval time.Duration } // NewPostgreSQLBackend constructs a PostgreSQL backend using the given // API client, server address, credentials, and database. func NewPostgreSQLBackend(conf map[string]string, logger log.Logger) (physical.Backend, error) { // Get the PostgreSQL credentials to perform read/write operations. connURL := connectionURL(conf) if connURL == "" { return nil, fmt.Errorf("missing connection_url") } unquoted_table, ok := conf["table"] if !ok { unquoted_table = "vault_kv_store" } quoted_table := dbutil.QuoteIdentifier(unquoted_table) maxParStr, ok := conf["max_parallel"] var maxParInt int var err error if ok { maxParInt, err = strconv.Atoi(maxParStr) if err != nil { return nil, fmt.Errorf("failed parsing max_parallel parameter: %w", err) } if logger.IsDebug() { logger.Debug("max_parallel set", "max_parallel", maxParInt) } } else { maxParInt = physical.DefaultParallelOperations } maxIdleConnsStr, maxIdleConnsIsSet := conf["max_idle_connections"] var maxIdleConns int if maxIdleConnsIsSet { maxIdleConns, err = strconv.Atoi(maxIdleConnsStr) if err != nil { return nil, fmt.Errorf("failed parsing max_idle_connections parameter: %w", err) } if logger.IsDebug() { logger.Debug("max_idle_connections set", "max_idle_connections", maxIdleConnsStr) } } // Create PostgreSQL handle for the database. db, err := sql.Open("pgx", connURL) if err != nil { return nil, fmt.Errorf("failed to connect to postgres: %w", err) } db.SetMaxOpenConns(maxParInt) if maxIdleConnsIsSet { db.SetMaxIdleConns(maxIdleConns) } // Determine if we should use a function to work around lack of upsert (versions < 9.5) var upsertAvailable bool upsertAvailableQuery := "SELECT current_setting('server_version_num')::int >= 90500" if err := db.QueryRow(upsertAvailableQuery).Scan(&upsertAvailable); err != nil { return nil, fmt.Errorf("failed to check for native upsert: %w", err) } if !upsertAvailable && conf["ha_enabled"] == "true" { return nil, fmt.Errorf("ha_enabled=true in config but PG version doesn't support HA, must be at least 9.5") } // Setup our put strategy based on the presence or absence of a native // upsert. var put_query string if !upsertAvailable { put_query = "SELECT vault_kv_put($1, $2, $3, $4)" } else { put_query = "INSERT INTO " + quoted_table + " VALUES($1, $2, $3, $4)" + " ON CONFLICT (path, key) DO " + " UPDATE SET (parent_path, path, key, value) = ($1, $2, $3, $4)" } unquoted_ha_table, ok := conf["ha_table"] if !ok { unquoted_ha_table = "vault_ha_locks" } quoted_ha_table := dbutil.QuoteIdentifier(unquoted_ha_table) // Setup the backend. m := &PostgreSQLBackend{ table: quoted_table, client: db, put_query: put_query, get_query: "SELECT value FROM " + quoted_table + " WHERE path = $1 AND key = $2", delete_query: "DELETE FROM " + quoted_table + " WHERE path = $1 AND key = $2", list_query: "SELECT key FROM " + quoted_table + " WHERE path = $1" + " UNION ALL SELECT DISTINCT substring(substr(path, length($1)+1) from '^.*?/') FROM " + quoted_table + " WHERE parent_path LIKE $1 || '%'", haGetLockValueQuery: // only read non expired data " SELECT ha_value FROM " + quoted_ha_table + " WHERE NOW() <= valid_until AND ha_key = $1 ", haUpsertLockIdentityExec: // $1=identity $2=ha_key $3=ha_value $4=TTL in seconds // update either steal expired lock OR update expiry for lock owned by me " INSERT INTO " + quoted_ha_table + " as t (ha_identity, ha_key, ha_value, valid_until) VALUES ($1, $2, $3, NOW() + $4 * INTERVAL '1 seconds' ) " + " ON CONFLICT (ha_key) DO " + " UPDATE SET (ha_identity, ha_key, ha_value, valid_until) = ($1, $2, $3, NOW() + $4 * INTERVAL '1 seconds') " + " WHERE (t.valid_until < NOW() AND t.ha_key = $2) OR " + " (t.ha_identity = $1 AND t.ha_key = $2) ", haDeleteLockExec: // $1=ha_identity $2=ha_key " DELETE FROM " + quoted_ha_table + " WHERE ha_identity=$1 AND ha_key=$2 ", logger: logger, permitPool: physical.NewPermitPool(maxParInt), haEnabled: conf["ha_enabled"] == "true", } return m, nil } // connectionURL first check the environment variables for a connection URL. If // no connection URL exists in the environment variable, the Vault config file is // checked. If neither the environment variables or the config file set the connection // URL for the Postgres backend, because it is a required field, an error is returned. func connectionURL(conf map[string]string) string { connURL := conf["connection_url"] if envURL := os.Getenv("VAULT_PG_CONNECTION_URL"); envURL != "" { connURL = envURL } return connURL } // splitKey is a helper to split a full path key into individual // parts: parentPath, path, key func (m *PostgreSQLBackend) splitKey(fullPath string) (string, string, string) { var parentPath string var path string pieces := strings.Split(fullPath, "/") depth := len(pieces) key := pieces[depth-1] if depth == 1 { parentPath = "" path = "/" } else if depth == 2 { parentPath = "/" path = "/" + pieces[0] + "/" } else { parentPath = "/" + strings.Join(pieces[:depth-2], "/") + "/" path = "/" + strings.Join(pieces[:depth-1], "/") + "/" } return parentPath, path, key } // Put is used to insert or update an entry. func (m *PostgreSQLBackend) Put(ctx context.Context, entry *physical.Entry) error { defer metrics.MeasureSince([]string{"postgres", "put"}, time.Now()) m.permitPool.Acquire() defer m.permitPool.Release() parentPath, path, key := m.splitKey(entry.Key) _, err := m.client.ExecContext(ctx, m.put_query, parentPath, path, key, entry.Value) if err != nil { return err } return nil } // Get is used to fetch and entry. func (m *PostgreSQLBackend) Get(ctx context.Context, fullPath string) (*physical.Entry, error) { defer metrics.MeasureSince([]string{"postgres", "get"}, time.Now()) m.permitPool.Acquire() defer m.permitPool.Release() _, path, key := m.splitKey(fullPath) var result []byte err := m.client.QueryRowContext(ctx, m.get_query, path, key).Scan(&result) if err == sql.ErrNoRows { return nil, nil } if err != nil { return nil, err } ent := &physical.Entry{ Key: fullPath, Value: result, } return ent, nil } // Delete is used to permanently delete an entry func (m *PostgreSQLBackend) Delete(ctx context.Context, fullPath string) error { defer metrics.MeasureSince([]string{"postgres", "delete"}, time.Now()) m.permitPool.Acquire() defer m.permitPool.Release() _, path, key := m.splitKey(fullPath) _, err := m.client.ExecContext(ctx, m.delete_query, path, key) if err != nil { return err } return nil } // List is used to list all the keys under a given // prefix, up to the next prefix. func (m *PostgreSQLBackend) List(ctx context.Context, prefix string) ([]string, error) { defer metrics.MeasureSince([]string{"postgres", "list"}, time.Now()) m.permitPool.Acquire() defer m.permitPool.Release() rows, err := m.client.QueryContext(ctx, m.list_query, "/"+prefix) if err != nil { return nil, err } defer rows.Close() var keys []string for rows.Next() { var key string err = rows.Scan(&key) if err != nil { return nil, fmt.Errorf("failed to scan rows: %w", err) } keys = append(keys, key) } return keys, nil } // LockWith is used for mutual exclusion based on the given key. func (p *PostgreSQLBackend) LockWith(key, value string) (physical.Lock, error) { identity, err := uuid.GenerateUUID() if err != nil { return nil, err } return &PostgreSQLLock{ backend: p, key: key, value: value, identity: identity, ttlSeconds: PostgreSQLLockTTLSeconds, renewInterval: PostgreSQLLockRenewInterval, retryInterval: PostgreSQLLockRetryInterval, }, nil } func (p *PostgreSQLBackend) HAEnabled() bool { return p.haEnabled } // Lock tries to acquire the lock by repeatedly trying to create a record in the // PostgreSQL table. It will block until either the stop channel is closed or // the lock could be acquired successfully. The returned channel will be closed // once the lock in the PostgreSQL table cannot be renewed, either due to an // error speaking to PostgreSQL or because someone else has taken it. func (l *PostgreSQLLock) Lock(stopCh <-chan struct{}) (<-chan struct{}, error) { l.lock.Lock() defer l.lock.Unlock() var ( success = make(chan struct{}) errors = make(chan error) leader = make(chan struct{}) ) // try to acquire the lock asynchronously go l.tryToLock(stopCh, success, errors) select { case <-success: // after acquiring it successfully, we must renew the lock periodically l.renewTicker = time.NewTicker(l.renewInterval) go l.periodicallyRenewLock(leader) case err := <-errors: return nil, err case <-stopCh: return nil, nil } return leader, nil } // Unlock releases the lock by deleting the lock record from the // PostgreSQL table. func (l *PostgreSQLLock) Unlock() error { pg := l.backend pg.permitPool.Acquire() defer pg.permitPool.Release() if l.renewTicker != nil { l.renewTicker.Stop() } // Delete lock owned by me _, err := pg.client.Exec(pg.haDeleteLockExec, l.identity, l.key) return err } // Value checks whether or not the lock is held by any instance of PostgreSQLLock, // including this one, and returns the current value. func (l *PostgreSQLLock) Value() (bool, string, error) { pg := l.backend pg.permitPool.Acquire() defer pg.permitPool.Release() var result string err := pg.client.QueryRow(pg.haGetLockValueQuery, l.key).Scan(&result) switch err { case nil: return true, result, nil case sql.ErrNoRows: return false, "", nil default: return false, "", err } } // tryToLock tries to create a new item in PostgreSQL every `retryInterval`. // As long as the item cannot be created (because it already exists), it will // be retried. If the operation fails due to an error, it is sent to the errors // channel. When the lock could be acquired successfully, the success channel // is closed. func (l *PostgreSQLLock) tryToLock(stop <-chan struct{}, success chan struct{}, errors chan error) { ticker := time.NewTicker(l.retryInterval) defer ticker.Stop() for { select { case <-stop: return case <-ticker.C: gotlock, err := l.writeItem() switch { case err != nil: errors <- err return case gotlock: close(success) return } } } } func (l *PostgreSQLLock) periodicallyRenewLock(done chan struct{}) { for range l.renewTicker.C { gotlock, err := l.writeItem() if err != nil || !gotlock { close(done) l.renewTicker.Stop() return } } } // Attempts to put/update the PostgreSQL item using condition expressions to // evaluate the TTL. Returns true if the lock was obtained, false if not. // If false error may be nil or non-nil: nil indicates simply that someone // else has the lock, whereas non-nil means that something unexpected happened. func (l *PostgreSQLLock) writeItem() (bool, error) { pg := l.backend pg.permitPool.Acquire() defer pg.permitPool.Release() // Try steal lock or update expiry on my lock sqlResult, err := pg.client.Exec(pg.haUpsertLockIdentityExec, l.identity, l.key, l.value, l.ttlSeconds) if err != nil { return false, err } if sqlResult == nil { return false, fmt.Errorf("empty SQL response received") } ar, err := sqlResult.RowsAffected() if err != nil { return false, err } return ar == 1, nil }