66f4589195
use server_version_num instead of parsing the text version. See: https://www.postgresql.org/docs/10/static/functions-info.html
215 lines
5.6 KiB
Go
215 lines
5.6 KiB
Go
package postgresql
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"fmt"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/hashicorp/errwrap"
|
|
"github.com/hashicorp/vault/physical"
|
|
log "github.com/mgutz/logxi/v1"
|
|
|
|
"github.com/armon/go-metrics"
|
|
"github.com/lib/pq"
|
|
)
|
|
|
|
// Verify PostgreSQLBackend satisfies the correct interfaces
|
|
var _ physical.Backend = (*PostgreSQLBackend)(nil)
|
|
|
|
// PostgreSQL Backend is a physical backend that stores data
|
|
// within a PostgreSQL database.
|
|
type PostgreSQLBackend struct {
|
|
table string
|
|
client *sql.DB
|
|
put_query string
|
|
get_query string
|
|
delete_query string
|
|
list_query string
|
|
logger log.Logger
|
|
permitPool *physical.PermitPool
|
|
}
|
|
|
|
// NewPostgreSQLBackend constructs a PostgreSQL backend using the given
|
|
// API client, server address, credentials, and database.
|
|
func NewPostgreSQLBackend(conf map[string]string, logger log.Logger) (physical.Backend, error) {
|
|
// Get the PostgreSQL credentials to perform read/write operations.
|
|
connURL, ok := conf["connection_url"]
|
|
if !ok || connURL == "" {
|
|
return nil, fmt.Errorf("missing connection_url")
|
|
}
|
|
|
|
unquoted_table, ok := conf["table"]
|
|
if !ok {
|
|
unquoted_table = "vault_kv_store"
|
|
}
|
|
quoted_table := pq.QuoteIdentifier(unquoted_table)
|
|
|
|
maxParStr, ok := conf["max_parallel"]
|
|
var maxParInt int
|
|
var err error
|
|
if ok {
|
|
maxParInt, err = strconv.Atoi(maxParStr)
|
|
if err != nil {
|
|
return nil, errwrap.Wrapf("failed parsing max_parallel parameter: {{err}}", err)
|
|
}
|
|
if logger.IsDebug() {
|
|
logger.Debug("postgres: max_parallel set", "max_parallel", maxParInt)
|
|
}
|
|
} else {
|
|
maxParInt = physical.DefaultParallelOperations
|
|
}
|
|
|
|
// Create PostgreSQL handle for the database.
|
|
db, err := sql.Open("postgres", connURL)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to connect to postgres: %v", err)
|
|
}
|
|
db.SetMaxOpenConns(maxParInt)
|
|
|
|
// Determine if we should use an upsert function (versions < 9.5)
|
|
var upsert_required bool
|
|
upsert_required_query := "SELECT current_setting('server_version_num')::int < 90500"
|
|
if err := db.QueryRow(upsert_required_query).Scan(&upsert_required); err != nil {
|
|
return nil, fmt.Errorf("failed to check for native upsert: %v", err)
|
|
}
|
|
|
|
// Setup our put strategy based on the presence or absence of a native
|
|
// upsert.
|
|
var put_query string
|
|
if upsert_required {
|
|
put_query = "SELECT vault_kv_put($1, $2, $3, $4)"
|
|
} else {
|
|
put_query = "INSERT INTO " + quoted_table + " VALUES($1, $2, $3, $4)" +
|
|
" ON CONFLICT (path, key) DO " +
|
|
" UPDATE SET (parent_path, path, key, value) = ($1, $2, $3, $4)"
|
|
}
|
|
|
|
// Setup the backend.
|
|
m := &PostgreSQLBackend{
|
|
table: quoted_table,
|
|
client: db,
|
|
put_query: put_query,
|
|
get_query: "SELECT value FROM " + quoted_table + " WHERE path = $1 AND key = $2",
|
|
delete_query: "DELETE FROM " + quoted_table + " WHERE path = $1 AND key = $2",
|
|
list_query: "SELECT key FROM " + quoted_table + " WHERE path = $1" +
|
|
"UNION SELECT DISTINCT substring(substr(path, length($1)+1) from '^.*?/') FROM " +
|
|
quoted_table + " WHERE parent_path LIKE $1 || '%'",
|
|
logger: logger,
|
|
permitPool: physical.NewPermitPool(maxParInt),
|
|
}
|
|
|
|
return m, nil
|
|
}
|
|
|
|
// splitKey is a helper to split a full path key into individual
|
|
// parts: parentPath, path, key
|
|
func (m *PostgreSQLBackend) splitKey(fullPath string) (string, string, string) {
|
|
var parentPath string
|
|
var path string
|
|
|
|
pieces := strings.Split(fullPath, "/")
|
|
depth := len(pieces)
|
|
key := pieces[depth-1]
|
|
|
|
if depth == 1 {
|
|
parentPath = ""
|
|
path = "/"
|
|
} else if depth == 2 {
|
|
parentPath = "/"
|
|
path = "/" + pieces[0] + "/"
|
|
} else {
|
|
parentPath = "/" + strings.Join(pieces[:depth-2], "/") + "/"
|
|
path = "/" + strings.Join(pieces[:depth-1], "/") + "/"
|
|
}
|
|
|
|
return parentPath, path, key
|
|
}
|
|
|
|
// Put is used to insert or update an entry.
|
|
func (m *PostgreSQLBackend) Put(ctx context.Context, entry *physical.Entry) error {
|
|
defer metrics.MeasureSince([]string{"postgres", "put"}, time.Now())
|
|
|
|
m.permitPool.Acquire()
|
|
defer m.permitPool.Release()
|
|
|
|
parentPath, path, key := m.splitKey(entry.Key)
|
|
|
|
_, err := m.client.Exec(m.put_query, parentPath, path, key, entry.Value)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Get is used to fetch and entry.
|
|
func (m *PostgreSQLBackend) Get(ctx context.Context, fullPath string) (*physical.Entry, error) {
|
|
defer metrics.MeasureSince([]string{"postgres", "get"}, time.Now())
|
|
|
|
m.permitPool.Acquire()
|
|
defer m.permitPool.Release()
|
|
|
|
_, path, key := m.splitKey(fullPath)
|
|
|
|
var result []byte
|
|
err := m.client.QueryRow(m.get_query, path, key).Scan(&result)
|
|
if err == sql.ErrNoRows {
|
|
return nil, nil
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
ent := &physical.Entry{
|
|
Key: key,
|
|
Value: result,
|
|
}
|
|
return ent, nil
|
|
}
|
|
|
|
// Delete is used to permanently delete an entry
|
|
func (m *PostgreSQLBackend) Delete(ctx context.Context, fullPath string) error {
|
|
defer metrics.MeasureSince([]string{"postgres", "delete"}, time.Now())
|
|
|
|
m.permitPool.Acquire()
|
|
defer m.permitPool.Release()
|
|
|
|
_, path, key := m.splitKey(fullPath)
|
|
|
|
_, err := m.client.Exec(m.delete_query, path, key)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// List is used to list all the keys under a given
|
|
// prefix, up to the next prefix.
|
|
func (m *PostgreSQLBackend) List(ctx context.Context, prefix string) ([]string, error) {
|
|
defer metrics.MeasureSince([]string{"postgres", "list"}, time.Now())
|
|
|
|
m.permitPool.Acquire()
|
|
defer m.permitPool.Release()
|
|
|
|
rows, err := m.client.Query(m.list_query, "/"+prefix)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
var keys []string
|
|
for rows.Next() {
|
|
var key string
|
|
err = rows.Scan(&key)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to scan rows: %v", err)
|
|
}
|
|
|
|
keys = append(keys, key)
|
|
}
|
|
|
|
return keys, nil
|
|
}
|