2016-01-20 00:00:09 +00:00
|
|
|
package physical
|
|
|
|
|
|
|
|
import (
|
|
|
|
"database/sql"
|
|
|
|
"fmt"
|
|
|
|
"sort"
|
|
|
|
"strings"
|
|
|
|
"time"
|
|
|
|
|
|
|
|
"github.com/armon/go-metrics"
|
2016-01-22 16:47:02 +00:00
|
|
|
"github.com/lib/pq"
|
2016-01-20 00:00:09 +00:00
|
|
|
)
|
|
|
|
|
|
|
|
// PostgreSQLBackend is a physical backend that stores data
// within a PostgreSQL database.
//
// Each entry is a row with a vault_key column and a vault_value column in a
// single table. All queries are prepared once at construction time.
type PostgreSQLBackend struct {
	// table is the table name as configured, before identifier quoting.
	table string

	// client is the database handle returned by sql.Open("postgres", ...).
	client *sql.DB

	// statements holds the prepared queries keyed by operation name
	// ("put", "get", "delete", "list").
	statements map[string]*sql.Stmt
}
|
|
|
|
|
|
|
|
// newPostgreSQLBackend constructs a PostgreSQL backend using the given
|
|
|
|
// API client, server address, credentials, and database.
|
|
|
|
func newPostgreSQLBackend(conf map[string]string) (Backend, error) {
|
|
|
|
// Get the PostgreSQL credentials to perform read/write operations.
|
|
|
|
connURL, ok := conf["connection_url"]
|
|
|
|
if !ok || connURL == "" {
|
|
|
|
return nil, fmt.Errorf("missing connection_url")
|
|
|
|
}
|
|
|
|
|
2016-01-22 16:47:02 +00:00
|
|
|
unquoted_table, ok := conf["table"]
|
2016-01-20 00:00:09 +00:00
|
|
|
if !ok {
|
2016-01-22 16:47:02 +00:00
|
|
|
unquoted_table = "vault"
|
2016-01-20 17:47:54 +00:00
|
|
|
}
|
2016-01-22 16:47:02 +00:00
|
|
|
quoted_table := pq.QuoteIdentifier(unquoted_table)
|
2016-01-20 17:47:54 +00:00
|
|
|
|
2016-01-21 01:52:49 +00:00
|
|
|
// Create PostgreSQL handle for the database.
|
|
|
|
db, err := sql.Open("postgres", connURL)
|
|
|
|
if err != nil {
|
|
|
|
return nil, fmt.Errorf("failed to connect to postgres: %v", err)
|
|
|
|
}
|
2016-01-20 17:47:54 +00:00
|
|
|
|
2016-01-21 01:52:49 +00:00
|
|
|
// Determine if we should use an upsert function (versions < 9.5)
|
|
|
|
var upsert_required bool
|
|
|
|
upsert_required_query := "SELECT string_to_array(setting, '.')::int[] < '{9,5}' FROM pg_settings WHERE name = 'server_version'"
|
|
|
|
if err := db.QueryRow(upsert_required_query).Scan(&upsert_required); err != nil {
|
|
|
|
return nil, fmt.Errorf("failed to check for native upsert: %v", err)
|
|
|
|
}
|
2016-01-20 17:47:54 +00:00
|
|
|
|
2016-01-22 16:47:02 +00:00
|
|
|
// Setup our put strategy based on the presence or absence of a native
|
|
|
|
// upsert. The upsert function used is taken [from the PostgreSQL
|
|
|
|
// docs](http://www.postgresql.org/docs/current/static/plpgsql-control-structures.html#PLPGSQL-UPSERT-EXAMPLE)
|
|
|
|
// and chosen primarily for reasons [listed
|
|
|
|
// here](http://www.depesz.com/2012/06/10/why-is-upsert-so-complicated/)
|
2016-01-21 01:52:49 +00:00
|
|
|
var put_statement string
|
2016-01-22 16:47:02 +00:00
|
|
|
create_upsert_sql := `
|
|
|
|
CREATE OR REPLACE FUNCTION vault_upsert(_key TEXT, _value BYTEA) RETURNS VOID AS
|
|
|
|
$$
|
|
|
|
BEGIN
|
|
|
|
LOOP
|
|
|
|
UPDATE ` + quoted_table + ` SET vault_value = _value WHERE vault_key = _key;
|
|
|
|
IF found THEN
|
|
|
|
RETURN;
|
|
|
|
END IF;
|
|
|
|
BEGIN
|
|
|
|
INSERT INTO ` + quoted_table + ` (vault_key, vault_value) VALUES (_key, _value);
|
|
|
|
RETURN;
|
|
|
|
EXCEPTION WHEN unique_violation THEN
|
|
|
|
-- Do nothing, and loop to try the UPDATE again.
|
|
|
|
END;
|
|
|
|
END LOOP;
|
|
|
|
END;
|
|
|
|
$$
|
|
|
|
LANGUAGE plpgsql;`
|
2016-01-21 01:52:49 +00:00
|
|
|
if upsert_required {
|
2016-01-22 16:47:02 +00:00
|
|
|
put_statement = "SELECT vault_upsert($1, $2)"
|
|
|
|
if _, err := db.Exec(create_upsert_sql); err != nil {
|
|
|
|
return nil, fmt.Errorf("failed to create upsert function: %v", err)
|
|
|
|
}
|
2016-01-21 01:52:49 +00:00
|
|
|
} else {
|
2016-01-22 16:47:02 +00:00
|
|
|
put_statement = "INSERT INTO " + quoted_table + " VALUES($1, $2)" +
|
2016-01-22 15:41:31 +00:00
|
|
|
" ON CONFLICT (vault_key) DO " +
|
|
|
|
" UPDATE SET vault_value = $2"
|
2016-01-20 00:00:09 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
// Setup the backend.
|
|
|
|
m := &PostgreSQLBackend{
|
2016-01-22 16:47:02 +00:00
|
|
|
table: unquoted_table,
|
2016-01-20 00:00:09 +00:00
|
|
|
client: db,
|
|
|
|
statements: make(map[string]*sql.Stmt),
|
|
|
|
}
|
|
|
|
|
|
|
|
// Prepare all the statements required
|
|
|
|
statements := map[string]string{
|
2016-01-20 17:47:54 +00:00
|
|
|
"put": put_statement,
|
2016-01-22 16:47:02 +00:00
|
|
|
"get": "SELECT vault_value FROM " + quoted_table + " WHERE vault_key = $1",
|
|
|
|
"delete": "DELETE FROM " + quoted_table + " WHERE vault_key = $1",
|
|
|
|
"list": "SELECT vault_key FROM " + quoted_table + " WHERE vault_key LIKE $1",
|
2016-01-20 00:00:09 +00:00
|
|
|
}
|
|
|
|
for name, query := range statements {
|
|
|
|
if err := m.prepare(name, query); err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return m, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
// prepare is a helper to prepare a query for future execution
|
|
|
|
func (m *PostgreSQLBackend) prepare(name, query string) error {
|
|
|
|
stmt, err := m.client.Prepare(query)
|
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("failed to prepare '%s': %v", name, err)
|
|
|
|
}
|
|
|
|
m.statements[name] = stmt
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
// Put is used to insert or update an entry.
|
|
|
|
func (m *PostgreSQLBackend) Put(entry *Entry) error {
|
|
|
|
defer metrics.MeasureSince([]string{"postgres", "put"}, time.Now())
|
|
|
|
|
|
|
|
_, err := m.statements["put"].Exec(entry.Key, entry.Value)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
// Get is used to fetch and entry.
|
|
|
|
func (m *PostgreSQLBackend) Get(key string) (*Entry, error) {
|
|
|
|
defer metrics.MeasureSince([]string{"postgres", "get"}, time.Now())
|
|
|
|
|
|
|
|
var result []byte
|
|
|
|
err := m.statements["get"].QueryRow(key).Scan(&result)
|
|
|
|
if err == sql.ErrNoRows {
|
|
|
|
return nil, nil
|
|
|
|
}
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
ent := &Entry{
|
|
|
|
Key: key,
|
|
|
|
Value: result,
|
|
|
|
}
|
|
|
|
return ent, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
// Delete is used to permanently delete an entry
|
|
|
|
func (m *PostgreSQLBackend) Delete(key string) error {
|
|
|
|
defer metrics.MeasureSince([]string{"postgres", "delete"}, time.Now())
|
|
|
|
|
|
|
|
_, err := m.statements["delete"].Exec(key)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
// List is used to list all the keys under a given
|
|
|
|
// prefix, up to the next prefix.
|
|
|
|
func (m *PostgreSQLBackend) List(prefix string) ([]string, error) {
|
|
|
|
defer metrics.MeasureSince([]string{"postgres", "list"}, time.Now())
|
|
|
|
|
|
|
|
// Add the % wildcard to the prefix to do the prefix search
|
|
|
|
likePrefix := prefix + "%"
|
|
|
|
rows, err := m.statements["list"].Query(likePrefix)
|
2016-01-21 00:02:23 +00:00
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
defer rows.Close()
|
2016-01-20 00:00:09 +00:00
|
|
|
|
|
|
|
var keys []string
|
|
|
|
for rows.Next() {
|
|
|
|
var key string
|
|
|
|
err = rows.Scan(&key)
|
|
|
|
if err != nil {
|
|
|
|
return nil, fmt.Errorf("failed to scan rows: %v", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
key = strings.TrimPrefix(key, prefix)
|
|
|
|
if i := strings.Index(key, "/"); i == -1 {
|
|
|
|
// Add objects only from the current 'folder'
|
|
|
|
keys = append(keys, key)
|
2016-01-21 00:05:21 +00:00
|
|
|
} else {
|
2016-01-20 00:00:09 +00:00
|
|
|
// Add truncated 'folder' paths
|
|
|
|
keys = appendIfMissing(keys, string(key[:i+1]))
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
sort.Strings(keys)
|
|
|
|
return keys, nil
|
|
|
|
}
|