open-vault/physical/postgresql.go
Devin Christensen b0f50ecb6c Remove prepared statements from the PostgreSQL physical backend
Prepared statements prevent the use of connection-multiplexing software
such as PgBouncer. Even when PgBouncer is configured for [session mode][1],
a server connection to PostgreSQL may still be reused by different
clients. This leads to errors when clients use session-based features
(like prepared statements).

This change removes prepared statements from the PostgreSQL physical
backend, allowing Vault to work in infrastructures that use PgBouncer
or other connection-multiplexing software.

[1]: https://pgbouncer.github.io/config.html#poolmode
2016-05-26 17:07:21 -06:00

176 lines
4.5 KiB
Go

package physical
import (
"database/sql"
"fmt"
"log"
"strings"
"time"
"github.com/armon/go-metrics"
"github.com/lib/pq"
)
// PostgreSQLBackend is a physical backend that keeps its entries in a single
// PostgreSQL table. The SQL text for every operation is assembled once by
// newPostgreSQLBackend and then sent, together with its arguments, on each call.
type PostgreSQLBackend struct {
	// table is the table name after pq.QuoteIdentifier, ready to be spliced
	// into SQL text.
	table string

	// client is the database handle every statement is executed on.
	client *sql.DB

	// SQL text for the individual operations.
	put_query    string
	get_query    string
	delete_query string
	list_query   string

	// logger is the logger handed to newPostgreSQLBackend.
	logger *log.Logger
}
// newPostgreSQLBackend constructs a PostgreSQL backend from the physical
// backend configuration map.
//
// Recognised keys in conf:
//   - "connection_url" (required): the connection string handed to
//     sql.Open("postgres", ...).
//   - "table" (optional): the table holding the entries. It defaults to
//     "vault_kv_store" when absent or empty, and is quoted with
//     pq.QuoteIdentifier before being spliced into SQL.
//
// One query is issued against the server during construction. It decides
// between a native INSERT ... ON CONFLICT (PostgreSQL 9.5+) and the
// vault_kv_put() function for older servers. Every other statement is kept
// as plain SQL text and executed with its arguments on each call.
//
// An error is returned if connection_url is missing, if sql.Open fails, or
// if the server version cannot be queried. In the last case the database
// handle is closed before returning.
func newPostgreSQLBackend(conf map[string]string, logger *log.Logger) (Backend, error) {
	// Get the PostgreSQL credentials to perform read/write operations.
	connURL, ok := conf["connection_url"]
	if !ok || connURL == "" {
		return nil, fmt.Errorf("missing connection_url")
	}

	// An empty name would quote to "", which is not a valid identifier, so
	// treat it the same as an absent key.
	unquotedTable := conf["table"]
	if unquotedTable == "" {
		unquotedTable = "vault_kv_store"
	}
	quotedTable := pq.QuoteIdentifier(unquotedTable)

	// Create PostgreSQL handle for the database.
	db, err := sql.Open("postgres", connURL)
	if err != nil {
		return nil, fmt.Errorf("failed to connect to postgres: %v", err)
	}

	// Native upsert (INSERT ... ON CONFLICT) is only available from 9.5 on.
	// Compare the integer server_version_num (e.g. 90500) instead of casting
	// the free-form server_version string to int[]. That string can carry
	// suffixes such as "beta1" or a distribution tag, and the cast fails on them.
	var upsertRequired bool
	upsertRequiredQuery := "SELECT current_setting('server_version_num')::int < 90500"
	if err := db.QueryRow(upsertRequiredQuery).Scan(&upsertRequired); err != nil {
		// The handle is never returned on this path, so nobody else could
		// close it. Release the pool here on a best-effort basis; the probe
		// error is the one worth reporting.
		_ = db.Close()
		return nil, fmt.Errorf("failed to check for native upsert: %v", err)
	}

	// Setup our put strategy based on the presence or absence of a native
	// upsert. NOTE(review): vault_kv_put is not created here; it is assumed
	// to already exist on pre-9.5 servers.
	var putQuery string
	if upsertRequired {
		putQuery = "SELECT vault_kv_put($1, $2, $3, $4)"
	} else {
		putQuery = "INSERT INTO " + quotedTable + " VALUES($1, $2, $3, $4)" +
			" ON CONFLICT (path, key) DO " +
			" UPDATE SET (parent_path, path, key, value) = ($1, $2, $3, $4)"
	}

	// Setup the backend. Every concatenation seam keeps an explicit space.
	// Newer PostgreSQL releases (15+) reject a parameter that runs straight
	// into identifier characters, such as "$1UNION".
	m := &PostgreSQLBackend{
		table:        quotedTable,
		client:       db,
		put_query:    putQuery,
		get_query:    "SELECT value FROM " + quotedTable + " WHERE path = $1 AND key = $2",
		delete_query: "DELETE FROM " + quotedTable + " WHERE path = $1 AND key = $2",
		list_query: "SELECT key FROM " + quotedTable + " WHERE path = $1" +
			" UNION SELECT substr(path, length($1)+1) FROM " + quotedTable + " WHERE parent_path = $1",
		logger: logger,
	}
	return m, nil
}
// splitKey breaks a slash-separated key into the three values stored in the
// table's parent_path, path and key columns, returned in that order.
//
// key is the final segment. path is the directory containing key, written
// with a leading and a trailing slash, or "/" for a top-level key.
// parentPath is the directory one level above path, in the same form, or ""
// when the key is top-level.
//
//	"foo"         -> "",      "/",         "foo"
//	"foo/bar"     -> "/",     "/foo/",     "bar"
//	"foo/bar/baz" -> "/foo/", "/foo/bar/", "baz"
func (m *PostgreSQLBackend) splitKey(fullPath string) (string, string, string) {
	segments := strings.Split(fullPath, "/")
	last := len(segments) - 1
	key := segments[last]

	switch last {
	case 0:
		// No slash at all: the key lives at the root.
		return "", "/", key
	case 1:
		// One directory deep: the parent is the root itself.
		return "/", "/" + segments[0] + "/", key
	default:
		dir := "/" + strings.Join(segments[:last], "/") + "/"
		parent := "/" + strings.Join(segments[:last-1], "/") + "/"
		return parent, dir, key
	}
}
// Put writes entry, inserting a new row or updating the existing one for
// entry.Key. The key is split into its parent_path, path and key columns and
// sent together with entry.Value using the put statement chosen at
// construction time. Database errors are returned unchanged.
func (m *PostgreSQLBackend) Put(entry *Entry) error {
	defer metrics.MeasureSince([]string{"postgres", "put"}, time.Now())

	parent, dir, name := m.splitKey(entry.Key)
	if _, err := m.client.Exec(m.put_query, parent, dir, name, entry.Value); err != nil {
		return err
	}
	return nil
}
// Get fetches the entry stored under fullPath.
//
// It returns (nil, nil) when no row matches, and any other database error
// unchanged. The returned Entry's Key is fullPath, the same key the caller
// used with Put, rather than just the final path segment.
func (m *PostgreSQLBackend) Get(fullPath string) (*Entry, error) {
	defer metrics.MeasureSince([]string{"postgres", "get"}, time.Now())

	// Rows are addressed by (path, key); parent_path is not needed here.
	_, path, key := m.splitKey(fullPath)

	var value []byte
	err := m.client.QueryRow(m.get_query, path, key).Scan(&value)
	switch {
	case err == sql.ErrNoRows:
		return nil, nil
	case err != nil:
		return nil, err
	}

	ent := &Entry{
		Key:   fullPath,
		Value: value,
	}
	return ent, nil
}
// Delete permanently removes the entry stored under fullPath.
//
// The affected-row count is not inspected, so deleting a key that does not
// exist is not reported as an error. Database errors are returned unchanged.
func (m *PostgreSQLBackend) Delete(fullPath string) error {
	defer metrics.MeasureSince([]string{"postgres", "delete"}, time.Now())

	_, dir, name := m.splitKey(fullPath)
	if _, err := m.client.Exec(m.delete_query, dir, name); err != nil {
		return err
	}
	return nil
}
// List returns the keys directly under prefix, going only one level deep.
//
// The prefix gets a leading "/" so it matches the path form produced by
// splitKey, and is passed as the single argument to the list statement.
// NOTE(review): this assumes prefix is either empty or ends in "/"; confirm
// this with callers.
//
// Scan failures and errors that end row iteration early are both returned,
// so a nil error means the key list is complete.
func (m *PostgreSQLBackend) List(prefix string) ([]string, error) {
	defer metrics.MeasureSince([]string{"postgres", "list"}, time.Now())

	rows, err := m.client.Query(m.list_query, "/"+prefix)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var keys []string
	for rows.Next() {
		var key string
		if err := rows.Scan(&key); err != nil {
			return nil, fmt.Errorf("failed to scan rows: %v", err)
		}
		keys = append(keys, key)
	}
	// rows.Next returns false both when the rows are exhausted and when
	// iteration fails. Only rows.Err tells the two apart, so check it before
	// treating keys as complete.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("failed to iterate rows: %v", err)
	}
	return keys, nil
}