b0f50ecb6c
Prepared statements prevent the use of connection multiplexing software such as PgBouncer. Even when PgBouncer is configured for [session mode][1], there's a possibility that a connection to PostgreSQL can be reused by different clients. This leads to errors when clients use session-based features (like prepared statements). This change removes prepared statements from the PostgreSQL physical backend, allowing Vault to work in infrastructures that use PgBouncer or other connection multiplexing software. [1]: https://pgbouncer.github.io/config.html#poolmode
176 lines
4.5 KiB
Go
176 lines
4.5 KiB
Go
package physical
|
|
|
|
import (
|
|
"database/sql"
|
|
"fmt"
|
|
"log"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/armon/go-metrics"
|
|
"github.com/lib/pq"
|
|
)
|
|
|
|
// PostgreSQL Backend is a physical backend that stores data
|
|
// within a PostgreSQL database.
|
|
type PostgreSQLBackend struct {
|
|
table string
|
|
client *sql.DB
|
|
put_query string
|
|
get_query string
|
|
delete_query string
|
|
list_query string
|
|
logger *log.Logger
|
|
}
|
|
|
|
// newPostgreSQLBackend constructs a PostgreSQL backend from conf.
//
// Recognized configuration keys:
//   - "connection_url" (required): data source name handed to
//     sql.Open("postgres", ...).
//   - "table" (optional): table holding the entries; defaults to
//     "vault_kv_store" when absent or empty. It is quoted with
//     pq.QuoteIdentifier before being spliced into the SQL text.
//
// All queries are kept as plain SQL strings and executed with bound
// parameters rather than as prepared statements, so the backend keeps
// working behind connection poolers such as PgBouncer.
//
// An error is returned if connection_url is missing, if sql.Open fails, or if
// the server version cannot be queried. In the last case the database handle
// is closed first so it is not leaked.
func newPostgreSQLBackend(conf map[string]string, logger *log.Logger) (Backend, error) {
	// Get the PostgreSQL credentials to perform read/write operations.
	connURL, ok := conf["connection_url"]
	if !ok || connURL == "" {
		return nil, fmt.Errorf("missing connection_url")
	}

	table, ok := conf["table"]
	if !ok || table == "" {
		// An empty name would quote to "" (a zero-length identifier), which
		// is invalid SQL, so treat it like an absent setting.
		table = "vault_kv_store"
	}
	quotedTable := pq.QuoteIdentifier(table)

	// Create PostgreSQL handle for the database.
	db, err := sql.Open("postgres", connURL)
	if err != nil {
		return nil, fmt.Errorf("failed to connect to postgres: %v", err)
	}

	// Determine whether we must fall back to the vault_kv_put function
	// because the server predates native upsert (versions < 9.5).
	// server_version_num is always a plain integer such as 90503, whereas
	// server_version may carry a distribution suffix (for example
	// "10.3 (Ubuntu 10.3-1)") that breaks a cast to int[].
	var upsertRequired bool
	upsertRequiredQuery := "SELECT current_setting('server_version_num')::int < 90500"
	if err := db.QueryRow(upsertRequiredQuery).Scan(&upsertRequired); err != nil {
		// The Close error is deliberately ignored: the version-check error
		// is the one worth reporting to the caller.
		db.Close()
		return nil, fmt.Errorf("failed to check for native upsert: %v", err)
	}

	// Setup our put strategy based on the presence or absence of a native
	// upsert.
	var putQuery string
	if upsertRequired {
		// NOTE(review): this form does not reference quotedTable; presumably
		// vault_kv_put targets its own table — confirm it honours a custom
		// "table" setting.
		putQuery = "SELECT vault_kv_put($1, $2, $3, $4)"
	} else {
		putQuery = "INSERT INTO " + quotedTable + " VALUES($1, $2, $3, $4)" +
			" ON CONFLICT (path, key) DO " +
			" UPDATE SET (parent_path, path, key, value) = ($1, $2, $3, $4)"
	}

	// listQuery, for a directory $1 such as "/foo/", returns:
	//   - every key stored directly in that directory (path = $1), and
	//   - the name of each immediate sub-directory, with a trailing slash,
	//     for every entry stored anywhere beneath it; "/foo/bar/baz/k"
	//     contributes "bar/".
	// parent_path is matched by exact prefix rather than equality so that a
	// sub-directory holding only deeper entries is still listed. substr is
	// used instead of LIKE so '%' and '_' in keys are not wildcards.
	// NOTE(review): a prefix test on parent_path cannot use a plain btree
	// index the way the former equality did — check performance on large
	// tables.
	listQuery := "SELECT key FROM " + quotedTable + " WHERE path = $1" +
		" UNION SELECT split_part(substr(path, length($1)+1), '/', 1) || '/' FROM " + quotedTable +
		" WHERE substr(parent_path, 1, length($1)) = $1"

	// Setup the backend.
	m := &PostgreSQLBackend{
		table:        quotedTable,
		client:       db,
		put_query:    putQuery,
		get_query:    "SELECT value FROM " + quotedTable + " WHERE path = $1 AND key = $2",
		delete_query: "DELETE FROM " + quotedTable + " WHERE path = $1 AND key = $2",
		list_query:   listQuery,
		logger:       logger,
	}

	return m, nil
}
|
|
|
|
// splitKey is a helper to split a full path key into individual
|
|
// parts: parentPath, path, key
|
|
func (m *PostgreSQLBackend) splitKey(fullPath string) (string, string, string) {
|
|
var parentPath string
|
|
var path string
|
|
|
|
pieces := strings.Split(fullPath, "/")
|
|
depth := len(pieces)
|
|
key := pieces[depth-1]
|
|
|
|
if depth == 1 {
|
|
parentPath = ""
|
|
path = "/"
|
|
} else if depth == 2 {
|
|
parentPath = "/"
|
|
path = "/" + pieces[0] + "/"
|
|
} else {
|
|
parentPath = "/" + strings.Join(pieces[:depth-2], "/") + "/"
|
|
path = "/" + strings.Join(pieces[:depth-1], "/") + "/"
|
|
}
|
|
|
|
return parentPath, path, key
|
|
}
|
|
|
|
// Put is used to insert or update an entry.
//
// entry.Key is split with splitKey into (parent_path, path, key). Those values
// are written together with entry.Value using put_query, which the
// constructor chose as either a native upsert or a vault_kv_put call. Any
// database error is returned unchanged.
func (m *PostgreSQLBackend) Put(entry *Entry) error {
	defer metrics.MeasureSince([]string{"postgres", "put"}, time.Now())

	parent, dir, name := m.splitKey(entry.Key)
	_, err := m.client.Exec(m.put_query, parent, dir, name, entry.Value)
	return err
}
|
|
|
|
// Get is used to fetch an entry.
//
// fullPath is split with splitKey and looked up by its (path, key) pair. A
// missing row is not an error: Get returns (nil, nil). Other database errors
// are returned unchanged. The returned Entry's Key is fullPath, which is the
// same key Put stores the entry under.
func (m *PostgreSQLBackend) Get(fullPath string) (*Entry, error) {
	defer metrics.MeasureSince([]string{"postgres", "get"}, time.Now())

	_, path, key := m.splitKey(fullPath)

	var result []byte
	err := m.client.QueryRow(m.get_query, path, key).Scan(&result)
	switch {
	case err == sql.ErrNoRows:
		return nil, nil
	case err != nil:
		return nil, err
	}

	// Report the full key that was asked for, not only its last segment
	// (key), so the entry round-trips with Put, which takes entry.Key as the
	// full path.
	ent := &Entry{
		Key:   fullPath,
		Value: result,
	}
	return ent, nil
}
|
|
|
|
// Delete is used to permanently delete an entry.
//
// fullPath is split with splitKey, and the row matching its (path, key) pair
// is removed. The number of affected rows is not inspected, so deleting a
// key that is not present succeeds. Database errors are returned unchanged.
func (m *PostgreSQLBackend) Delete(fullPath string) error {
	defer metrics.MeasureSince([]string{"postgres", "delete"}, time.Now())

	_, dir, name := m.splitKey(fullPath)
	_, err := m.client.Exec(m.delete_query, dir, name)
	return err
}
|
|
|
|
// List is used to list all the keys under a given
// prefix, up to the next prefix.
//
// prefix is a directory in Vault key form (e.g. "" or "foo/"). It is
// converted to the stored path form by prepending "/" and passed to
// list_query, and each returned column value is collected in order. A query
// error, a scan error, or an error raised while iterating the result set
// aborts the listing and is returned. A prefix with no matches yields a nil
// slice and no error.
func (m *PostgreSQLBackend) List(prefix string) ([]string, error) {
	defer metrics.MeasureSince([]string{"postgres", "list"}, time.Now())

	rows, err := m.client.Query(m.list_query, "/"+prefix)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var keys []string
	for rows.Next() {
		var key string
		if err := rows.Scan(&key); err != nil {
			return nil, fmt.Errorf("failed to scan rows: %v", err)
		}
		keys = append(keys, key)
	}
	// rows.Next reports false both at the end of the result set and when
	// iteration fails (e.g. the connection drops mid-stream). Only rows.Err
	// distinguishes the two; without this check a truncated listing would
	// be returned as if it were complete.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("failed to iterate rows: %v", err)
	}

	return keys, nil
}
|