package cockroachdb import ( "context" "database/sql" "fmt" "sort" "strconv" "strings" "time" "unicode" metrics "github.com/armon/go-metrics" "github.com/cockroachdb/cockroach-go/crdb" log "github.com/hashicorp/go-hclog" "github.com/hashicorp/go-multierror" "github.com/hashicorp/go-secure-stdlib/strutil" "github.com/hashicorp/vault/sdk/physical" // CockroachDB uses the Postgres SQL driver _ "github.com/jackc/pgx/v4/stdlib" ) // Verify CockroachDBBackend satisfies the correct interfaces var ( _ physical.Backend = (*CockroachDBBackend)(nil) _ physical.Transactional = (*CockroachDBBackend)(nil) ) const ( defaultTableName = "vault_kv_store" defaultHATableName = "vault_ha_locks" ) // CockroachDBBackend Backend is a physical backend that stores data // within a CockroachDB database. type CockroachDBBackend struct { table string haTable string client *sql.DB rawStatements map[string]string statements map[string]*sql.Stmt rawHAStatements map[string]string haStatements map[string]*sql.Stmt logger log.Logger permitPool *physical.PermitPool haEnabled bool } // NewCockroachDBBackend constructs a CockroachDB backend using the given // API client, server address, credentials, and database. func NewCockroachDBBackend(conf map[string]string, logger log.Logger) (physical.Backend, error) { // Get the CockroachDB credentials to perform read/write operations. connURL, ok := conf["connection_url"] if !ok || connURL == "" { return nil, fmt.Errorf("missing connection_url") } haEnabled := conf["ha_enabled"] == "true" dbTable := conf["table"] if dbTable == "" { dbTable = defaultTableName } err := validateDBTable(dbTable) if err != nil { return nil, fmt.Errorf("invalid table: %w", err) } dbHATable, ok := conf["ha_table"] if !ok { dbHATable = defaultHATableName } err = validateDBTable(dbHATable) if err != nil { return nil, fmt.Errorf("invalid HA table: %w", err) } maxParStr, ok := conf["max_parallel"] var maxParInt int if ok { maxParInt, err = strconv.Atoi(maxParStr) if err != nil { return nil, fmt.Errorf("failed parsing max_parallel parameter: %w", err) } if logger.IsDebug() { logger.Debug("max_parallel set", "max_parallel", maxParInt) } } // Create CockroachDB handle for the database. db, err := sql.Open("pgx", connURL) if err != nil { return nil, fmt.Errorf("failed to connect to cockroachdb: %w", err) } // Create the required tables if they don't exist. createQuery := "CREATE TABLE IF NOT EXISTS " + dbTable + " (path STRING, value BYTES, PRIMARY KEY (path))" if _, err := db.Exec(createQuery); err != nil { return nil, fmt.Errorf("failed to create CockroachDB table: %w", err) } if haEnabled { createHATableQuery := "CREATE TABLE IF NOT EXISTS " + dbHATable + "(ha_key TEXT NOT NULL, " + " ha_identity TEXT NOT NULL, " + " ha_value TEXT, " + " valid_until TIMESTAMP WITH TIME ZONE NOT NULL, " + " CONSTRAINT ha_key PRIMARY KEY (ha_key) " + ");" if _, err := db.Exec(createHATableQuery); err != nil { return nil, fmt.Errorf("failed to create CockroachDB HA table: %w", err) } } // Setup the backend c := &CockroachDBBackend{ table: dbTable, haTable: dbHATable, client: db, rawStatements: map[string]string{ "put": "INSERT INTO " + dbTable + " VALUES($1, $2)" + " ON CONFLICT (path) DO " + " UPDATE SET (path, value) = ($1, $2)", "get": "SELECT value FROM " + dbTable + " WHERE path = $1", "delete": "DELETE FROM " + dbTable + " WHERE path = $1", "list": "SELECT path FROM " + dbTable + " WHERE path LIKE $1", }, statements: make(map[string]*sql.Stmt), rawHAStatements: map[string]string{ "get": "SELECT ha_value FROM " + dbHATable + " WHERE NOW() <= valid_until AND ha_key = $1", "upsert": "INSERT INTO " + dbHATable + " as t (ha_identity, ha_key, ha_value, valid_until)" + " VALUES ($1, $2, $3, NOW() + $4) " + " ON CONFLICT (ha_key) DO " + " UPDATE SET (ha_identity, ha_key, ha_value, valid_until) = ($1, $2, $3, NOW() + $4) " + " WHERE (t.valid_until < NOW() AND t.ha_key = $2) OR " + " (t.ha_identity = $1 AND t.ha_key = $2) ", "delete": "DELETE FROM " + dbHATable + " WHERE ha_key = $1", }, haStatements: make(map[string]*sql.Stmt), logger: logger, permitPool: physical.NewPermitPool(maxParInt), haEnabled: haEnabled, } // Prepare all the statements required for name, query := range c.rawStatements { if err := c.prepare(c.statements, name, query); err != nil { return nil, err } } if haEnabled { for name, query := range c.rawHAStatements { if err := c.prepare(c.haStatements, name, query); err != nil { return nil, err } } } return c, nil } // prepare is a helper to prepare a query for future execution. func (c *CockroachDBBackend) prepare(statementMap map[string]*sql.Stmt, name, query string) error { stmt, err := c.client.Prepare(query) if err != nil { return fmt.Errorf("failed to prepare %q: %w", name, err) } statementMap[name] = stmt return nil } // Put is used to insert or update an entry. func (c *CockroachDBBackend) Put(ctx context.Context, entry *physical.Entry) error { defer metrics.MeasureSince([]string{"cockroachdb", "put"}, time.Now()) c.permitPool.Acquire() defer c.permitPool.Release() _, err := c.statements["put"].Exec(entry.Key, entry.Value) if err != nil { return err } return nil } // Get is used to fetch and entry. func (c *CockroachDBBackend) Get(ctx context.Context, key string) (*physical.Entry, error) { defer metrics.MeasureSince([]string{"cockroachdb", "get"}, time.Now()) c.permitPool.Acquire() defer c.permitPool.Release() var result []byte err := c.statements["get"].QueryRow(key).Scan(&result) if err == sql.ErrNoRows { return nil, nil } if err != nil { return nil, err } ent := &physical.Entry{ Key: key, Value: result, } return ent, nil } // Delete is used to permanently delete an entry func (c *CockroachDBBackend) Delete(ctx context.Context, key string) error { defer metrics.MeasureSince([]string{"cockroachdb", "delete"}, time.Now()) c.permitPool.Acquire() defer c.permitPool.Release() _, err := c.statements["delete"].Exec(key) if err != nil { return err } return nil } // List is used to list all the keys under a given // prefix, up to the next prefix. func (c *CockroachDBBackend) List(ctx context.Context, prefix string) ([]string, error) { defer metrics.MeasureSince([]string{"cockroachdb", "list"}, time.Now()) c.permitPool.Acquire() defer c.permitPool.Release() likePrefix := prefix + "%" rows, err := c.statements["list"].Query(likePrefix) if err != nil { return nil, err } defer rows.Close() var keys []string for rows.Next() { var key string err = rows.Scan(&key) if err != nil { return nil, fmt.Errorf("failed to scan rows: %w", err) } key = strings.TrimPrefix(key, prefix) if i := strings.Index(key, "/"); i == -1 { // Add objects only from the current 'folder' keys = append(keys, key) } else if i != -1 { // Add truncated 'folder' paths keys = strutil.AppendIfMissing(keys, string(key[:i+1])) } } sort.Strings(keys) return keys, nil } // Transaction is used to run multiple entries via a transaction func (c *CockroachDBBackend) Transaction(ctx context.Context, txns []*physical.TxnEntry) error { defer metrics.MeasureSince([]string{"cockroachdb", "transaction"}, time.Now()) if len(txns) == 0 { return nil } c.permitPool.Acquire() defer c.permitPool.Release() return crdb.ExecuteTx(context.Background(), c.client, nil, func(tx *sql.Tx) error { return c.transaction(tx, txns) }) } func (c *CockroachDBBackend) transaction(tx *sql.Tx, txns []*physical.TxnEntry) error { deleteStmt, err := tx.Prepare(c.rawStatements["delete"]) if err != nil { return err } putStmt, err := tx.Prepare(c.rawStatements["put"]) if err != nil { return err } for _, op := range txns { switch op.Operation { case physical.DeleteOperation: _, err = deleteStmt.Exec(op.Entry.Key) case physical.PutOperation: _, err = putStmt.Exec(op.Entry.Key, op.Entry.Value) default: return fmt.Errorf("%q is not a supported transaction operation", op.Operation) } if err != nil { return err } } return nil } // validateDBTable against the CockroachDB rules for table names: // https://www.cockroachlabs.com/docs/stable/keywords-and-identifiers.html#identifiers // // - All values that accept an identifier must: // - Begin with a Unicode letter or an underscore (_). Subsequent characters can be letters, // - underscores, digits (0-9), or dollar signs ($). // - Not equal any SQL keyword unless the keyword is accepted by the element's syntax. For example, // name accepts Unreserved or Column Name keywords. // // The docs do state that we can bypass these rules with double quotes, however I think it // is safer to just require these rules across the board. func validateDBTable(dbTable string) (err error) { // Check if this is 'database.table' formatted. If so, split them apart and check the two // parts from each other split := strings.SplitN(dbTable, ".", 2) if len(split) == 2 { merr := &multierror.Error{} merr = multierror.Append(merr, wrapErr("invalid database: %w", validateDBTable(split[0]))) merr = multierror.Append(merr, wrapErr("invalid table name: %w", validateDBTable(split[1]))) return merr.ErrorOrNil() } // Disallow SQL keywords as the table name if sqlKeywords[strings.ToUpper(dbTable)] { return fmt.Errorf("name must not be a SQL keyword") } runes := []rune(dbTable) for i, r := range runes { if i == 0 && !unicode.IsLetter(r) && r != '_' { return fmt.Errorf("must use a letter or an underscore as the first character") } if !unicode.IsLetter(r) && r != '_' && !unicode.IsDigit(r) && r != '$' { return fmt.Errorf("must only contain letters, underscores, digits, and dollar signs") } if r == '`' || r == '\'' || r == '"' { return fmt.Errorf("cannot contain backticks, single quotes, or double quotes") } } return nil } func wrapErr(message string, err error) error { if err == nil { return nil } return fmt.Errorf(message, err) }