27bb03bbc0
* adding copyright header * fix fmt and a test
357 lines
10 KiB
Go
357 lines
10 KiB
Go
// Copyright (c) HashiCorp, Inc.
|
|
// SPDX-License-Identifier: MPL-2.0
|
|
|
|
package cockroachdb
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"fmt"
|
|
"sort"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
"unicode"
|
|
|
|
metrics "github.com/armon/go-metrics"
|
|
"github.com/cockroachdb/cockroach-go/crdb"
|
|
log "github.com/hashicorp/go-hclog"
|
|
"github.com/hashicorp/go-multierror"
|
|
"github.com/hashicorp/go-secure-stdlib/strutil"
|
|
"github.com/hashicorp/vault/sdk/physical"
|
|
|
|
// CockroachDB uses the Postgres SQL driver
|
|
_ "github.com/jackc/pgx/v4/stdlib"
|
|
)
|
|
|
|
// Verify CockroachDBBackend satisfies the correct interfaces
|
|
var (
|
|
_ physical.Backend = (*CockroachDBBackend)(nil)
|
|
_ physical.Transactional = (*CockroachDBBackend)(nil)
|
|
)
|
|
|
|
// defaultTableName is the key/value table used when the "table" option is
// not configured.
const defaultTableName = "vault_kv_store"

// defaultHATableName is the HA lock table used when the "ha_table" option is
// not configured.
const defaultHATableName = "vault_ha_locks"
|
|
|
|
// CockroachDBBackend Backend is a physical backend that stores data
|
|
// within a CockroachDB database.
|
|
type CockroachDBBackend struct {
|
|
table string
|
|
haTable string
|
|
client *sql.DB
|
|
rawStatements map[string]string
|
|
statements map[string]*sql.Stmt
|
|
rawHAStatements map[string]string
|
|
haStatements map[string]*sql.Stmt
|
|
logger log.Logger
|
|
permitPool *physical.PermitPool
|
|
haEnabled bool
|
|
}
|
|
|
|
// NewCockroachDBBackend constructs a CockroachDB backend using the given
|
|
// API client, server address, credentials, and database.
|
|
func NewCockroachDBBackend(conf map[string]string, logger log.Logger) (physical.Backend, error) {
|
|
// Get the CockroachDB credentials to perform read/write operations.
|
|
connURL, ok := conf["connection_url"]
|
|
if !ok || connURL == "" {
|
|
return nil, fmt.Errorf("missing connection_url")
|
|
}
|
|
|
|
haEnabled := conf["ha_enabled"] == "true"
|
|
|
|
dbTable := conf["table"]
|
|
if dbTable == "" {
|
|
dbTable = defaultTableName
|
|
}
|
|
|
|
err := validateDBTable(dbTable)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("invalid table: %w", err)
|
|
}
|
|
|
|
dbHATable, ok := conf["ha_table"]
|
|
if !ok {
|
|
dbHATable = defaultHATableName
|
|
}
|
|
|
|
err = validateDBTable(dbHATable)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("invalid HA table: %w", err)
|
|
}
|
|
|
|
maxParStr, ok := conf["max_parallel"]
|
|
var maxParInt int
|
|
if ok {
|
|
maxParInt, err = strconv.Atoi(maxParStr)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed parsing max_parallel parameter: %w", err)
|
|
}
|
|
if logger.IsDebug() {
|
|
logger.Debug("max_parallel set", "max_parallel", maxParInt)
|
|
}
|
|
}
|
|
|
|
// Create CockroachDB handle for the database.
|
|
db, err := sql.Open("pgx", connURL)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to connect to cockroachdb: %w", err)
|
|
}
|
|
|
|
// Create the required tables if they don't exist.
|
|
createQuery := "CREATE TABLE IF NOT EXISTS " + dbTable +
|
|
" (path STRING, value BYTES, PRIMARY KEY (path))"
|
|
if _, err := db.Exec(createQuery); err != nil {
|
|
return nil, fmt.Errorf("failed to create CockroachDB table: %w", err)
|
|
}
|
|
if haEnabled {
|
|
createHATableQuery := "CREATE TABLE IF NOT EXISTS " + dbHATable +
|
|
"(ha_key TEXT NOT NULL, " +
|
|
" ha_identity TEXT NOT NULL, " +
|
|
" ha_value TEXT, " +
|
|
" valid_until TIMESTAMP WITH TIME ZONE NOT NULL, " +
|
|
" CONSTRAINT ha_key PRIMARY KEY (ha_key) " +
|
|
");"
|
|
if _, err := db.Exec(createHATableQuery); err != nil {
|
|
return nil, fmt.Errorf("failed to create CockroachDB HA table: %w", err)
|
|
}
|
|
}
|
|
|
|
// Setup the backend
|
|
c := &CockroachDBBackend{
|
|
table: dbTable,
|
|
haTable: dbHATable,
|
|
client: db,
|
|
rawStatements: map[string]string{
|
|
"put": "INSERT INTO " + dbTable + " VALUES($1, $2)" +
|
|
" ON CONFLICT (path) DO " +
|
|
" UPDATE SET (path, value) = ($1, $2)",
|
|
"get": "SELECT value FROM " + dbTable + " WHERE path = $1",
|
|
"delete": "DELETE FROM " + dbTable + " WHERE path = $1",
|
|
"list": "SELECT path FROM " + dbTable + " WHERE path LIKE $1",
|
|
},
|
|
statements: make(map[string]*sql.Stmt),
|
|
rawHAStatements: map[string]string{
|
|
"get": "SELECT ha_value FROM " + dbHATable + " WHERE NOW() <= valid_until AND ha_key = $1",
|
|
"upsert": "INSERT INTO " + dbHATable + " as t (ha_identity, ha_key, ha_value, valid_until)" +
|
|
" VALUES ($1, $2, $3, NOW() + $4) " +
|
|
" ON CONFLICT (ha_key) DO " +
|
|
" UPDATE SET (ha_identity, ha_key, ha_value, valid_until) = ($1, $2, $3, NOW() + $4) " +
|
|
" WHERE (t.valid_until < NOW() AND t.ha_key = $2) OR " +
|
|
" (t.ha_identity = $1 AND t.ha_key = $2) ",
|
|
"delete": "DELETE FROM " + dbHATable + " WHERE ha_key = $1",
|
|
},
|
|
haStatements: make(map[string]*sql.Stmt),
|
|
logger: logger,
|
|
permitPool: physical.NewPermitPool(maxParInt),
|
|
haEnabled: haEnabled,
|
|
}
|
|
|
|
// Prepare all the statements required
|
|
for name, query := range c.rawStatements {
|
|
if err := c.prepare(c.statements, name, query); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
if haEnabled {
|
|
for name, query := range c.rawHAStatements {
|
|
if err := c.prepare(c.haStatements, name, query); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
}
|
|
return c, nil
|
|
}
|
|
|
|
// prepare is a helper to prepare a query for future execution.
|
|
func (c *CockroachDBBackend) prepare(statementMap map[string]*sql.Stmt, name, query string) error {
|
|
stmt, err := c.client.Prepare(query)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to prepare %q: %w", name, err)
|
|
}
|
|
statementMap[name] = stmt
|
|
return nil
|
|
}
|
|
|
|
// Put is used to insert or update an entry.
|
|
func (c *CockroachDBBackend) Put(ctx context.Context, entry *physical.Entry) error {
|
|
defer metrics.MeasureSince([]string{"cockroachdb", "put"}, time.Now())
|
|
|
|
c.permitPool.Acquire()
|
|
defer c.permitPool.Release()
|
|
|
|
_, err := c.statements["put"].Exec(entry.Key, entry.Value)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Get is used to fetch and entry.
|
|
func (c *CockroachDBBackend) Get(ctx context.Context, key string) (*physical.Entry, error) {
|
|
defer metrics.MeasureSince([]string{"cockroachdb", "get"}, time.Now())
|
|
|
|
c.permitPool.Acquire()
|
|
defer c.permitPool.Release()
|
|
|
|
var result []byte
|
|
err := c.statements["get"].QueryRow(key).Scan(&result)
|
|
if err == sql.ErrNoRows {
|
|
return nil, nil
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
ent := &physical.Entry{
|
|
Key: key,
|
|
Value: result,
|
|
}
|
|
return ent, nil
|
|
}
|
|
|
|
// Delete is used to permanently delete an entry
|
|
func (c *CockroachDBBackend) Delete(ctx context.Context, key string) error {
|
|
defer metrics.MeasureSince([]string{"cockroachdb", "delete"}, time.Now())
|
|
|
|
c.permitPool.Acquire()
|
|
defer c.permitPool.Release()
|
|
|
|
_, err := c.statements["delete"].Exec(key)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// List is used to list all the keys under a given
|
|
// prefix, up to the next prefix.
|
|
func (c *CockroachDBBackend) List(ctx context.Context, prefix string) ([]string, error) {
|
|
defer metrics.MeasureSince([]string{"cockroachdb", "list"}, time.Now())
|
|
|
|
c.permitPool.Acquire()
|
|
defer c.permitPool.Release()
|
|
|
|
likePrefix := prefix + "%"
|
|
rows, err := c.statements["list"].Query(likePrefix)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
var keys []string
|
|
for rows.Next() {
|
|
var key string
|
|
err = rows.Scan(&key)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to scan rows: %w", err)
|
|
}
|
|
|
|
key = strings.TrimPrefix(key, prefix)
|
|
if i := strings.Index(key, "/"); i == -1 {
|
|
// Add objects only from the current 'folder'
|
|
keys = append(keys, key)
|
|
} else if i != -1 {
|
|
// Add truncated 'folder' paths
|
|
keys = strutil.AppendIfMissing(keys, string(key[:i+1]))
|
|
}
|
|
}
|
|
|
|
sort.Strings(keys)
|
|
return keys, nil
|
|
}
|
|
|
|
// Transaction is used to run multiple entries via a transaction
|
|
func (c *CockroachDBBackend) Transaction(ctx context.Context, txns []*physical.TxnEntry) error {
|
|
defer metrics.MeasureSince([]string{"cockroachdb", "transaction"}, time.Now())
|
|
if len(txns) == 0 {
|
|
return nil
|
|
}
|
|
|
|
c.permitPool.Acquire()
|
|
defer c.permitPool.Release()
|
|
|
|
return crdb.ExecuteTx(context.Background(), c.client, nil, func(tx *sql.Tx) error {
|
|
return c.transaction(tx, txns)
|
|
})
|
|
}
|
|
|
|
func (c *CockroachDBBackend) transaction(tx *sql.Tx, txns []*physical.TxnEntry) error {
|
|
deleteStmt, err := tx.Prepare(c.rawStatements["delete"])
|
|
if err != nil {
|
|
return err
|
|
}
|
|
putStmt, err := tx.Prepare(c.rawStatements["put"])
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
for _, op := range txns {
|
|
switch op.Operation {
|
|
case physical.DeleteOperation:
|
|
_, err = deleteStmt.Exec(op.Entry.Key)
|
|
case physical.PutOperation:
|
|
_, err = putStmt.Exec(op.Entry.Key, op.Entry.Value)
|
|
default:
|
|
return fmt.Errorf("%q is not a supported transaction operation", op.Operation)
|
|
}
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// validateDBTable against the CockroachDB rules for table names:
|
|
// https://www.cockroachlabs.com/docs/stable/keywords-and-identifiers.html#identifiers
|
|
//
|
|
// - All values that accept an identifier must:
|
|
// - Begin with a Unicode letter or an underscore (_). Subsequent characters can be letters,
|
|
// - underscores, digits (0-9), or dollar signs ($).
|
|
// - Not equal any SQL keyword unless the keyword is accepted by the element's syntax. For example,
|
|
// name accepts Unreserved or Column Name keywords.
|
|
//
|
|
// The docs do state that we can bypass these rules with double quotes, however I think it
|
|
// is safer to just require these rules across the board.
|
|
func validateDBTable(dbTable string) (err error) {
|
|
// Check if this is 'database.table' formatted. If so, split them apart and check the two
|
|
// parts from each other
|
|
split := strings.SplitN(dbTable, ".", 2)
|
|
if len(split) == 2 {
|
|
merr := &multierror.Error{}
|
|
merr = multierror.Append(merr, wrapErr("invalid database: %w", validateDBTable(split[0])))
|
|
merr = multierror.Append(merr, wrapErr("invalid table name: %w", validateDBTable(split[1])))
|
|
return merr.ErrorOrNil()
|
|
}
|
|
|
|
// Disallow SQL keywords as the table name
|
|
if sqlKeywords[strings.ToUpper(dbTable)] {
|
|
return fmt.Errorf("name must not be a SQL keyword")
|
|
}
|
|
|
|
runes := []rune(dbTable)
|
|
for i, r := range runes {
|
|
if i == 0 && !unicode.IsLetter(r) && r != '_' {
|
|
return fmt.Errorf("must use a letter or an underscore as the first character")
|
|
}
|
|
|
|
if !unicode.IsLetter(r) && r != '_' && !unicode.IsDigit(r) && r != '$' {
|
|
return fmt.Errorf("must only contain letters, underscores, digits, and dollar signs")
|
|
}
|
|
|
|
if r == '`' || r == '\'' || r == '"' {
|
|
return fmt.Errorf("cannot contain backticks, single quotes, or double quotes")
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// wrapErr formats err into message, which is used as the fmt.Errorf format
// string and is expected to contain a single verb such as %w. A nil err is
// passed through as nil so callers can wrap unconditionally.
func wrapErr(message string, err error) error {
	if err != nil {
		return fmt.Errorf(message, err)
	}
	return nil
}
|