// File: open-vault/physical/postgresql.go

package physical
import (
"database/sql"
"fmt"
"sort"
"strings"
"time"
"github.com/armon/go-metrics"
"github.com/lib/pq"
)
// PostgreSQLBackend is a physical backend that stores data
// within a PostgreSQL database.
type PostgreSQLBackend struct {
	// table is the table name as configured, before identifier quoting.
	table string
	// client is the database handle all statements were prepared on.
	client *sql.DB
	// statements holds the prepared queries keyed by operation name
	// ("put", "get", "delete", "list").
	statements map[string]*sql.Stmt
}
// newPostgreSQLBackend constructs a PostgreSQL backend using the given
// API client, server address, credentials, and database.
func newPostgreSQLBackend(conf map[string]string) (Backend, error) {
// Get the PostgreSQL credentials to perform read/write operations.
connURL, ok := conf["connection_url"]
if !ok || connURL == "" {
return nil, fmt.Errorf("missing connection_url")
}
unquoted_table, ok := conf["table"]
if !ok {
unquoted_table = "vault"
2016-01-20 17:47:54 +00:00
}
quoted_table := pq.QuoteIdentifier(unquoted_table)
2016-01-20 17:47:54 +00:00
2016-01-21 01:52:49 +00:00
// Create PostgreSQL handle for the database.
db, err := sql.Open("postgres", connURL)
if err != nil {
return nil, fmt.Errorf("failed to connect to postgres: %v", err)
}
2016-01-20 17:47:54 +00:00
2016-01-21 01:52:49 +00:00
// Determine if we should use an upsert function (versions < 9.5)
var upsert_required bool
upsert_required_query := "SELECT string_to_array(setting, '.')::int[] < '{9,5}' FROM pg_settings WHERE name = 'server_version'"
if err := db.QueryRow(upsert_required_query).Scan(&upsert_required); err != nil {
return nil, fmt.Errorf("failed to check for native upsert: %v", err)
}
2016-01-20 17:47:54 +00:00
// Setup our put strategy based on the presence or absence of a native
// upsert. The upsert function used is taken [from the PostgreSQL
// docs](http://www.postgresql.org/docs/current/static/plpgsql-control-structures.html#PLPGSQL-UPSERT-EXAMPLE)
// and chosen primarily for reasons [listed
// here](http://www.depesz.com/2012/06/10/why-is-upsert-so-complicated/)
2016-01-21 01:52:49 +00:00
var put_statement string
create_upsert_sql := `
CREATE OR REPLACE FUNCTION vault_upsert(_key TEXT, _value BYTEA) RETURNS VOID AS
$$
BEGIN
LOOP
UPDATE ` + quoted_table + ` SET vault_value = _value WHERE vault_key = _key;
IF found THEN
RETURN;
END IF;
BEGIN
INSERT INTO ` + quoted_table + ` (vault_key, vault_value) VALUES (_key, _value);
RETURN;
EXCEPTION WHEN unique_violation THEN
-- Do nothing, and loop to try the UPDATE again.
END;
END LOOP;
END;
$$
LANGUAGE plpgsql;`
2016-01-21 01:52:49 +00:00
if upsert_required {
put_statement = "SELECT vault_upsert($1, $2)"
if _, err := db.Exec(create_upsert_sql); err != nil {
return nil, fmt.Errorf("failed to create upsert function: %v", err)
}
2016-01-21 01:52:49 +00:00
} else {
put_statement = "INSERT INTO " + quoted_table + " VALUES($1, $2)" +
" ON CONFLICT (vault_key) DO " +
" UPDATE SET vault_value = $2"
}
// Setup the backend.
m := &PostgreSQLBackend{
table: unquoted_table,
client: db,
statements: make(map[string]*sql.Stmt),
}
// Prepare all the statements required
statements := map[string]string{
2016-01-20 17:47:54 +00:00
"put": put_statement,
"get": "SELECT vault_value FROM " + quoted_table + " WHERE vault_key = $1",
"delete": "DELETE FROM " + quoted_table + " WHERE vault_key = $1",
"list": "SELECT vault_key FROM " + quoted_table + " WHERE vault_key LIKE $1",
}
for name, query := range statements {
if err := m.prepare(name, query); err != nil {
return nil, err
}
}
return m, nil
}
// prepare compiles query against the backend's database handle and, on
// success, records the statement in m.statements under name. A failure is
// returned as an error that mentions name.
func (m *PostgreSQLBackend) prepare(name, query string) error {
	var (
		stmt *sql.Stmt
		err  error
	)
	if stmt, err = m.client.Prepare(query); err != nil {
		return fmt.Errorf("failed to prepare '%s': %v", name, err)
	}

	m.statements[name] = stmt
	return nil
}
// Put writes entry.Value under entry.Key using the "put" statement
// prepared at construction time. The call's duration is recorded under
// the {"postgres", "put"} metric key. Any execution error is returned
// unchanged.
func (m *PostgreSQLBackend) Put(entry *Entry) error {
	start := time.Now()
	defer metrics.MeasureSince([]string{"postgres", "put"}, start)

	putStmt := m.statements["put"]
	_, err := putStmt.Exec(entry.Key, entry.Value)
	return err
}
// Get fetches the entry stored under key.
//
// A key with no row is not an error: Get returns (nil, nil) in that case.
// Any other query or scan error is returned unchanged. The call's
// duration is recorded under the {"postgres", "get"} metric key.
func (m *PostgreSQLBackend) Get(key string) (*Entry, error) {
	defer metrics.MeasureSince([]string{"postgres", "get"}, time.Now())

	var value []byte
	switch err := m.statements["get"].QueryRow(key).Scan(&value); {
	case err == sql.ErrNoRows:
		return nil, nil
	case err != nil:
		return nil, err
	}

	return &Entry{Key: key, Value: value}, nil
}
// Delete removes the entry stored under key using the prepared "delete"
// statement. The number of affected rows is not inspected, so deleting a
// missing key is not reported as an error. The call's duration is
// recorded under the {"postgres", "delete"} metric key.
func (m *PostgreSQLBackend) Delete(key string) error {
	defer metrics.MeasureSince([]string{"postgres", "delete"}, time.Now())

	deleteStmt := m.statements["delete"]
	_, err := deleteStmt.Exec(key)
	return err
}
// List is used to list all the keys under a given
// prefix, up to the next prefix.
func (m *PostgreSQLBackend) List(prefix string) ([]string, error) {
defer metrics.MeasureSince([]string{"postgres", "list"}, time.Now())
// Add the % wildcard to the prefix to do the prefix search
likePrefix := prefix + "%"
rows, err := m.statements["list"].Query(likePrefix)
2016-01-21 00:02:23 +00:00
if err != nil {
return nil, err
}
defer rows.Close()
var keys []string
for rows.Next() {
var key string
err = rows.Scan(&key)
if err != nil {
return nil, fmt.Errorf("failed to scan rows: %v", err)
}
key = strings.TrimPrefix(key, prefix)
if i := strings.Index(key, "/"); i == -1 {
// Add objects only from the current 'folder'
keys = append(keys, key)
2016-01-21 00:05:21 +00:00
} else {
// Add truncated 'folder' paths
keys = appendIfMissing(keys, string(key[:i+1]))
}
}
sort.Strings(keys)
return keys, nil
}