open-vault/physical/postgresql.go
Sean Chittenden aeea7628d6 Add a *log.Logger argument to physical.Factory
Logging in the backend is a good thing.  This is a noisy interface change but should be a functional noop.
2016-04-25 20:10:32 -07:00

192 lines
5 KiB
Go

package physical
import (
"database/sql"
"fmt"
"log"
"strings"
"time"
"github.com/armon/go-metrics"
"github.com/lib/pq"
)
// PostgreSQLBackend is a physical backend that keeps its data in a single
// PostgreSQL table, accessed through a set of statements prepared once at
// construction time.
type PostgreSQLBackend struct {
	// table is the already-quoted table name spliced into SQL text.
	table string
	// client is the database handle the statements were prepared on.
	client *sql.DB
	// statements maps an operation name ("put", "get", ...) to its prepared statement.
	statements map[string]*sql.Stmt
	// logger is the logger handed to the constructor.
	logger *log.Logger
}
// newPostgreSQLBackend constructs a PostgreSQL backend from the given
// configuration map and logger.
//
// Configuration keys read:
//   - "connection_url" (required, non-empty): passed to sql.Open("postgres", ...).
//   - "table" (optional): table holding the data; "vault_kv_store" is used when
//     the key is absent or empty. The name is quoted with pq.QuoteIdentifier
//     before being spliced into SQL text.
//
// PostgreSQL versions older than 9.5 lack INSERT ... ON CONFLICT, so on those
// servers writes go through a vault_kv_put($1, $2, $3, $4) SQL function.
// NOTE(review): that function is not created here; presumably it must already
// exist in the database — confirm.
//
// If anything fails after the database handle has been opened, the handle is
// closed before the error is returned so a failed construction does not leak
// the connection pool.
func newPostgreSQLBackend(conf map[string]string, logger *log.Logger) (Backend, error) {
	// Get the PostgreSQL credentials to perform read/write operations.
	connURL, ok := conf["connection_url"]
	if !ok || connURL == "" {
		return nil, fmt.Errorf("missing connection_url")
	}

	unquotedTable, ok := conf["table"]
	if !ok || unquotedTable == "" {
		// An empty name would quote to "", which is not a valid identifier.
		unquotedTable = "vault_kv_store"
	}
	quotedTable := pq.QuoteIdentifier(unquotedTable)

	// Create PostgreSQL handle for the database. sql.Open may only validate
	// its arguments; the first real round-trip is the version query below.
	db, err := sql.Open("postgres", connURL)
	if err != nil {
		return nil, fmt.Errorf("failed to connect to postgres: %v", err)
	}

	// fail releases the half-built handle and reports err. Any error from
	// Close is dropped on purpose: the construction error is the one the
	// caller needs to see.
	fail := func(err error) (Backend, error) {
		db.Close()
		return nil, err
	}

	// Determine whether native upsert is available (PostgreSQL >= 9.5).
	// server_version_num is a plain integer such as 90500, whereas
	// server_version can carry text like "10beta1" or "10.4 (Debian ...)"
	// that breaks an int[] cast.
	var upsertAvailable bool
	upsertQuery := "SELECT current_setting('server_version_num')::int >= 90500"
	if err := db.QueryRow(upsertQuery).Scan(&upsertAvailable); err != nil {
		return fail(fmt.Errorf("failed to check for native upsert: %v", err))
	}

	// Setup our put strategy based on the presence or absence of a native
	// upsert.
	var putStatement string
	if upsertAvailable {
		putStatement = "INSERT INTO " + quotedTable + " VALUES($1, $2, $3, $4)" +
			" ON CONFLICT (path, key) DO " +
			" UPDATE SET (parent_path, path, key, value) = ($1, $2, $3, $4)"
	} else {
		putStatement = "SELECT vault_kv_put($1, $2, $3, $4)"
	}

	// Setup the backend.
	m := &PostgreSQLBackend{
		table:      quotedTable,
		client:     db,
		statements: make(map[string]*sql.Stmt),
		logger:     logger,
	}

	// Prepare all the statements required. Every concatenation boundary
	// carries an explicit space so that tokens such as "$1" and "UNION", or
	// the quoted table name and "WHERE", never run together.
	statements := map[string]string{
		"put":    putStatement,
		"get":    "SELECT value FROM " + quotedTable + " WHERE path = $1 AND key = $2",
		"delete": "DELETE FROM " + quotedTable + " WHERE path = $1 AND key = $2",
		"list": "SELECT key FROM " + quotedTable + " WHERE path = $1" +
			" UNION SELECT substr(path, length($1)+1) FROM " + quotedTable + " WHERE parent_path = $1",
	}
	for name, query := range statements {
		if err := m.prepare(name, query); err != nil {
			return fail(err)
		}
	}
	return m, nil
}
// prepare compiles query on the backend's database handle and stores the
// resulting statement in m.statements under name, where the Put/Get/Delete/List
// methods look it up. The returned error names the statement that failed.
func (m *PostgreSQLBackend) prepare(name, query string) error {
	prepared, prepErr := m.client.Prepare(query)
	if prepErr != nil {
		return fmt.Errorf("failed to prepare '%s': %v", name, prepErr)
	}

	m.statements[name] = prepared
	return nil
}
// splitKey breaks a slash-separated key into the three values stored per row:
// the parent directory, the directory holding the entry, and the final key
// segment. Directories are rendered with leading and trailing slashes, e.g.
//
//	"foo"   -> ("",    "/",     "foo")
//	"a/b"   -> ("/",   "/a/",   "b")
//	"a/b/c" -> ("/a/", "/a/b/", "c")
func (m *PostgreSQLBackend) splitKey(fullPath string) (string, string, string) {
	segments := strings.Split(fullPath, "/")
	n := len(segments)
	leaf := segments[n-1]

	switch n {
	case 1:
		// Top-level key: it lives in the root and has no parent directory.
		return "", "/", leaf
	case 2:
		// One level deep: the parent is the root itself.
		return "/", "/" + segments[0] + "/", leaf
	default:
		parent := "/" + strings.Join(segments[:n-2], "/") + "/"
		dir := "/" + strings.Join(segments[:n-1], "/") + "/"
		return parent, dir, leaf
	}
}
// Put writes entry through the "put" statement prepared at construction time.
// That statement is a native INSERT ... ON CONFLICT upsert on PostgreSQL 9.5+,
// or a call to the vault_kv_put SQL function on older servers. The entry's key
// is split into parent path, path and key columns by splitKey.
func (m *PostgreSQLBackend) Put(entry *Entry) error {
	defer metrics.MeasureSince([]string{"postgres", "put"}, time.Now())

	parent, dir, leaf := m.splitKey(entry.Key)
	if _, execErr := m.statements["put"].Exec(parent, dir, leaf, entry.Value); execErr != nil {
		return execErr
	}
	return nil
}
// Get fetches the entry stored under fullPath.
//
// It returns (nil, nil) when no row matches. On success, the returned Entry's
// Key is fullPath itself, the same key the caller asked for (and the form Put
// is given), rather than only its final path segment.
func (m *PostgreSQLBackend) Get(fullPath string) (*Entry, error) {
	defer metrics.MeasureSince([]string{"postgres", "get"}, time.Now())

	_, path, key := m.splitKey(fullPath)

	var result []byte
	err := m.statements["get"].QueryRow(path, key).Scan(&result)
	if err == sql.ErrNoRows {
		return nil, nil
	}
	if err != nil {
		return nil, err
	}

	ent := &Entry{
		// Previously this was the bare last segment, which made the entry
		// disagree with the key it was stored and requested under.
		Key:   fullPath,
		Value: result,
	}
	return ent, nil
}
// Delete removes the row stored under fullPath using the prepared "delete"
// statement. The Exec result is not inspected, so deleting a key that does
// not exist returns nil; only database errors are reported.
func (m *PostgreSQLBackend) Delete(fullPath string) error {
	defer metrics.MeasureSince([]string{"postgres", "delete"}, time.Now())

	_, dir, leaf := m.splitKey(fullPath)
	_, execErr := m.statements["delete"].Exec(dir, leaf)
	return execErr
}
// List returns the names found directly under prefix. These are the keys
// stored in that directory, plus child directory names taken from rows whose
// parent_path equals the prefix. Because splitKey always ends directory paths
// with "/", child directory names come back with a trailing slash (e.g. "b/").
//
// prefix is presumably slash-terminated ("" for the root) so that "/"+prefix
// matches the "/"-wrapped paths produced by splitKey. Confirm against callers.
//
// NOTE(review): a child directory is only reported if some row has
// parent_path equal to the prefix. A directory that holds nothing but deeper
// sub-directories is not listed. For example, listing "a/" when the only key
// is "a/b/c/k" returns nothing. Confirm whether this is acceptable.
func (m *PostgreSQLBackend) List(prefix string) ([]string, error) {
	defer metrics.MeasureSince([]string{"postgres", "list"}, time.Now())

	rows, err := m.statements["list"].Query("/" + prefix)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var keys []string
	for rows.Next() {
		var key string
		if err := rows.Scan(&key); err != nil {
			return nil, fmt.Errorf("failed to scan rows: %v", err)
		}
		keys = append(keys, key)
	}
	// rows.Next returns false both when the result set is exhausted and when
	// an error occurs. Without this check, a mid-stream failure would yield
	// a silently truncated listing.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("failed to iterate rows: %v", err)
	}
	return keys, nil
}