open-vault/physical/mysql/mysql.go

777 lines
21 KiB
Go

package mysql
import (
"context"
"crypto/tls"
"crypto/x509"
"database/sql"
"errors"
"fmt"
"io/ioutil"
"math"
"net/url"
"sort"
"strconv"
"strings"
"sync"
"time"
"unicode"
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-multierror"
metrics "github.com/armon/go-metrics"
mysql "github.com/go-sql-driver/mysql"
"github.com/hashicorp/go-secure-stdlib/strutil"
"github.com/hashicorp/vault/sdk/physical"
)
// Compile-time checks: the build fails if MySQLBackend stops implementing
// the Vault physical storage and HA interfaces, or if MySQLHALock stops
// implementing physical.Lock.
var _ physical.Backend = (*MySQLBackend)(nil)

var _ physical.HABackend = (*MySQLBackend)(nil)

var _ physical.Lock = (*MySQLHALock)(nil)
// mysqlTLSKey is the name under which setupMySQLTLSConfig registers the
// custom tls.Config with the MySQL driver; NewMySQLClient then selects it in
// the DSN via the "tls" parameter. It must be an unreserved key: the driver
// reserves "true", "false" and "skip-verify".
const mysqlTLSKey = "default"
// MySQLBackend is a physical backend that stores data
// within MySQL database.
//
// Entries live in a single table (vault_key varbinary primary key,
// vault_value mediumblob). When HA is enabled, a separate lock table records
// the current leader.
type MySQLBackend struct {
// dbTable is the fully quoted `database`.`table` name of the data table.
dbTable string
// dbLockTable is the fully quoted `database`.`lock_table` name of the HA lock table.
dbLockTable string
// client is the shared connection pool returned by NewMySQLClient.
client *sql.DB
// statements holds prepared statements keyed by name ("put", "get",
// "delete", "list", plus "get_lock" and "used_lock" when HA is enabled).
statements map[string]*sql.Stmt
logger log.Logger
// permitPool is sized from max_parallel; Put/Get/Delete/List acquire a
// permit for the duration of each call.
permitPool *physical.PermitPool
// conf is the raw backend configuration. NewMySQLLock reuses it to open a
// dedicated connection pool for each lock.
conf map[string]string
// NOTE(review): redirectHost and redirectPort are not set or read anywhere in this file.
redirectHost string
redirectPort int64
// haEnabled mirrors the parsed ha_enabled configuration value (default false).
haEnabled bool
}
// NewMySQLBackend constructs a MySQL physical backend from conf.
//
// It opens a connection pool via NewMySQLClient and validates the database
// and table names. It creates the database and data table when they do not
// exist, and creates the lock table when ha_enabled is true. Finally it
// prepares the statements used by the storage operations.
//
// Keys read here, in addition to those read by NewMySQLClient:
//   - "database" and "table" (both default to "vault")
//   - "max_parallel"
//   - "ha_enabled" (default false)
//   - "lock_table" (default "<table>_lock")
//
// If any step after the pool is opened fails, the pool is closed before the
// error is returned so that it is not leaked.
func NewMySQLBackend(conf map[string]string, logger log.Logger) (physical.Backend, error) {
	db, err := NewMySQLClient(conf, logger)
	if err != nil {
		return nil, err
	}

	m, err := newMySQLBackendWithDB(db, conf, logger)
	if err != nil {
		// Best effort: the setup error is the one worth reporting.
		_ = db.Close()
		return nil, err
	}
	return m, nil
}

// newMySQLBackendWithDB performs NewMySQLBackend's schema setup and statement
// preparation on an already-opened pool. The caller owns db and must close it
// if an error is returned.
func newMySQLBackendWithDB(db *sql.DB, conf map[string]string, logger log.Logger) (*MySQLBackend, error) {
	database := conf["database"]
	if database == "" {
		database = "vault"
	}
	table := conf["table"]
	if table == "" {
		table = "vault"
	}
	// The names are spliced into SQL below, so they must be validated first.
	if err := validateDBTable(database, table); err != nil {
		return nil, err
	}
	dbTable := fmt.Sprintf("`%s`.`%s`", database, table)

	var maxParInt int
	if maxParStr, ok := conf["max_parallel"]; ok {
		var err error
		maxParInt, err = strconv.Atoi(maxParStr)
		if err != nil {
			return nil, fmt.Errorf("failed parsing max_parallel parameter: %w", err)
		}
		if logger.IsDebug() {
			logger.Debug("max_parallel set", "max_parallel", maxParInt)
		}
	} else {
		maxParInt = physical.DefaultParallelOperations
	}

	// Each existence check closes its result set immediately (see
	// mysqlRowExists). Leaving them open would pin one pooled connection
	// apiece and can deadlock when max_parallel is small.
	schemaExist, err := mysqlRowExists(db,
		"SELECT SCHEMA_NAME FROM information_schema.SCHEMATA WHERE SCHEMA_NAME = ?", database)
	if err != nil {
		return nil, fmt.Errorf("failed to check mysql schema exist: %w", err)
	}
	tableExist, err := mysqlRowExists(db,
		"SELECT TABLE_NAME FROM information_schema.TABLES WHERE TABLE_NAME = ? AND TABLE_SCHEMA = ?", table, database)
	if err != nil {
		return nil, fmt.Errorf("failed to check mysql table exist: %w", err)
	}

	// Create the required database if it doesn't exist.
	if !schemaExist {
		if _, err := db.Exec("CREATE DATABASE IF NOT EXISTS `" + database + "`"); err != nil {
			return nil, fmt.Errorf("failed to create mysql database: %w", err)
		}
	}

	// Create the required table if it doesn't exist.
	if !tableExist {
		createQuery := "CREATE TABLE IF NOT EXISTS " + dbTable +
			" (vault_key varbinary(3072), vault_value mediumblob, PRIMARY KEY (vault_key))"
		if _, err := db.Exec(createQuery); err != nil {
			return nil, fmt.Errorf("failed to create mysql table: %w", err)
		}
	}

	// Default value for ha_enabled.
	haEnabledStr, ok := conf["ha_enabled"]
	if !ok {
		haEnabledStr = "false"
	}
	haEnabled, err := strconv.ParseBool(haEnabledStr)
	if err != nil {
		return nil, fmt.Errorf("value [%v] of 'ha_enabled' could not be understood", haEnabledStr)
	}

	locktable, ok := conf["lock_table"]
	if !ok {
		locktable = table + "_lock"
	}
	dbLockTable := "`" + database + "`.`" + locktable + "`"

	// The lock table is only created (and used) when HA is enabled.
	if haEnabled {
		// The lock table name is spliced into SQL exactly like the data table
		// name, so it gets the same injection check.
		if err := validate(locktable); err != nil {
			return nil, fmt.Errorf("invalid lock table: %w", err)
		}

		lockTableExist, err := mysqlRowExists(db,
			"SELECT TABLE_NAME FROM information_schema.TABLES WHERE TABLE_NAME = ? AND TABLE_SCHEMA = ?", locktable, database)
		if err != nil {
			return nil, fmt.Errorf("failed to check mysql table exist: %w", err)
		}
		if !lockTableExist {
			createQuery := "CREATE TABLE IF NOT EXISTS " + dbLockTable +
				" (node_job varbinary(512), current_leader varbinary(512), PRIMARY KEY (node_job))"
			if _, err := db.Exec(createQuery); err != nil {
				return nil, fmt.Errorf("failed to create mysql table: %w", err)
			}
		}
	}

	m := &MySQLBackend{
		dbTable:     dbTable,
		dbLockTable: dbLockTable,
		client:      db,
		statements:  make(map[string]*sql.Stmt),
		logger:      logger,
		permitPool:  physical.NewPermitPool(maxParInt),
		conf:        conf,
		haEnabled:   haEnabled,
	}

	// Prepare all the statements required.
	statements := map[string]string{
		"put": "INSERT INTO " + dbTable +
			" VALUES( ?, ? ) ON DUPLICATE KEY UPDATE vault_value=VALUES(vault_value)",
		"get":    "SELECT vault_value FROM " + dbTable + " WHERE vault_key = ?",
		"delete": "DELETE FROM " + dbTable + " WHERE vault_key = ?",
		"list":   "SELECT vault_key FROM " + dbTable + " WHERE vault_key LIKE ?",
	}
	// Only prepare HA-related statements if we need them.
	if haEnabled {
		statements["get_lock"] = "SELECT current_leader FROM " + dbLockTable + " WHERE node_job = ?"
		statements["used_lock"] = "SELECT IS_USED_LOCK(?)"
	}
	for name, query := range statements {
		if err := m.prepare(name, query); err != nil {
			return nil, err
		}
	}
	return m, nil
}

// mysqlRowExists runs query against db and reports whether it produced at
// least one row. The result set is closed before returning, so the
// connection goes straight back to the pool. Any error raised while reading
// the result set is returned.
func mysqlRowExists(db *sql.DB, query string, args ...interface{}) (bool, error) {
	rows, err := db.Query(query, args...)
	if err != nil {
		return false, err
	}
	defer rows.Close()

	found := rows.Next()
	if err := rows.Err(); err != nil {
		return false, err
	}
	return found, nil
}
// validateDBTable to prevent SQL injection attacks. This ensures that the database and table names only have valid
// characters in them. MySQL allows for more characters that this will allow, but there isn't an easy way of
// representing the full Unicode Basic Multilingual Plane to check against.
// https://dev.mysql.com/doc/refman/5.7/en/identifiers.html
func validateDBTable(db, table string) (err error) {
merr := &multierror.Error{}
merr = multierror.Append(merr, wrapErr("invalid database: %w", validate(db)))
merr = multierror.Append(merr, wrapErr("invalid table: %w", validate(table)))
return merr.ErrorOrNil()
}
// validate reports whether name is acceptable as a quoted MySQL identifier
// for this backend. It returns nil for a valid name, or an error describing
// the first offending character.
//
// From https://dev.mysql.com/doc/refman/5.7/en/identifiers.html:
//   - Quoted identifiers may use the full Unicode BMP except U+0000
//     (ASCII U+0001..U+007F, extended U+0080..U+FFFF).
//   - NUL and supplementary characters (U+10000 and above) are never allowed.
//   - Identifiers may not consist solely of digits unless quoted.
//   - Database, table, and column names cannot end with space characters.
//
// This check is deliberately stricter than MySQL. It rejects backticks and
// quotes (which could break out of the quoting), non-printable characters,
// and every space character. Because the name is always quoted by the
// caller, all-digit names are allowed.
func validate(name string) (err error) {
	if len(name) == 0 {
		return fmt.Errorf("missing name")
	}

	for _, ch := range name {
		switch {
		case ch == 0x0000:
			return fmt.Errorf("invalid character: cannot include 0x0000")
		case ch > 0xFFFF:
			return fmt.Errorf("invalid character: cannot include any characters above 0xFFFF")
		case ch == '`':
			return fmt.Errorf("invalid character: cannot include '`' character")
		case ch == '\'', ch == '"':
			return fmt.Errorf("invalid character: cannot include quotes")
		case !unicode.IsPrint(ch), unicode.IsSpace(ch):
			return fmt.Errorf("invalid character found: can only include printable, non-space characters between [0x0001-0xFFFF]")
		}
	}
	return nil
}
// wrapErr formats err into message (which is expected to contain a %w verb)
// so the result wraps err. A nil err yields nil, which lets callers pass the
// result of a check straight through without an explicit nil test.
func wrapErr(message string, err error) error {
	if err != nil {
		return fmt.Errorf(message, err)
	}
	return nil
}
func NewMySQLClient(conf map[string]string, logger log.Logger) (*sql.DB, error) {
var err error
// Get the MySQL credentials to perform read/write operations.
username, ok := conf["username"]
if !ok || username == "" {
return nil, fmt.Errorf("missing username")
}
password, ok := conf["password"]
if !ok || password == "" {
return nil, fmt.Errorf("missing password")
}
// Get or set MySQL server address. Defaults to localhost and default port(3306)
address, ok := conf["address"]
if !ok {
address = "127.0.0.1:3306"
}
maxIdleConnStr, ok := conf["max_idle_connections"]
var maxIdleConnInt int
if ok {
maxIdleConnInt, err = strconv.Atoi(maxIdleConnStr)
if err != nil {
return nil, fmt.Errorf("failed parsing max_idle_connections parameter: %w", err)
}
if logger.IsDebug() {
logger.Debug("max_idle_connections set", "max_idle_connections", maxIdleConnInt)
}
}
maxConnLifeStr, ok := conf["max_connection_lifetime"]
var maxConnLifeInt int
if ok {
maxConnLifeInt, err = strconv.Atoi(maxConnLifeStr)
if err != nil {
return nil, fmt.Errorf("failed parsing max_connection_lifetime parameter: %w", err)
}
if logger.IsDebug() {
logger.Debug("max_connection_lifetime set", "max_connection_lifetime", maxConnLifeInt)
}
}
maxParStr, ok := conf["max_parallel"]
var maxParInt int
if ok {
maxParInt, err = strconv.Atoi(maxParStr)
if err != nil {
return nil, fmt.Errorf("failed parsing max_parallel parameter: %w", err)
}
if logger.IsDebug() {
logger.Debug("max_parallel set", "max_parallel", maxParInt)
}
} else {
maxParInt = physical.DefaultParallelOperations
}
dsnParams := url.Values{}
tlsCaFile, tlsOk := conf["tls_ca_file"]
if tlsOk {
if err := setupMySQLTLSConfig(tlsCaFile); err != nil {
return nil, fmt.Errorf("failed register TLS config: %w", err)
}
dsnParams.Add("tls", mysqlTLSKey)
}
ptAllowed, ptOk := conf["plaintext_connection_allowed"]
if !(ptOk && strings.ToLower(ptAllowed) == "true") && !tlsOk {
logger.Warn("No TLS specified, credentials will be sent in plaintext. To mute this warning add 'plaintext_connection_allowed' with a true value to your MySQL configuration in your config file.")
}
// Create MySQL handle for the database.
dsn := username + ":" + password + "@tcp(" + address + ")/?" + dsnParams.Encode()
db, err := sql.Open("mysql", dsn)
if err != nil {
return nil, fmt.Errorf("failed to connect to mysql: %w", err)
}
db.SetMaxOpenConns(maxParInt)
if maxIdleConnInt != 0 {
db.SetMaxIdleConns(maxIdleConnInt)
}
if maxConnLifeInt != 0 {
db.SetConnMaxLifetime(time.Duration(maxConnLifeInt) * time.Second)
}
return db, err
}
// prepare compiles query on the backend's shared pool and stores the
// resulting statement in m.statements under name. The error names the
// statement that failed.
func (m *MySQLBackend) prepare(name, query string) error {
	stmt, err := m.client.Prepare(query)
	if err == nil {
		m.statements[name] = stmt
		return nil
	}
	return fmt.Errorf("failed to prepare %q: %w", name, err)
}
// Put inserts entry, or overwrites the value of the existing row with the
// same key. The statement runs under ctx, so a cancelled or expired ctx aborts
// the write, as for the other storage operations.
func (m *MySQLBackend) Put(ctx context.Context, entry *physical.Entry) error {
	defer metrics.MeasureSince([]string{"mysql", "put"}, time.Now())

	m.permitPool.Acquire()
	defer m.permitPool.Release()

	_, err := m.statements["put"].ExecContext(ctx, entry.Key, entry.Value)
	return err
}
// Get fetches the entry stored under key. It returns (nil, nil) when no row
// exists. The query runs under ctx, so cancellation and deadlines are honored.
func (m *MySQLBackend) Get(ctx context.Context, key string) (*physical.Entry, error) {
	defer metrics.MeasureSince([]string{"mysql", "get"}, time.Now())

	m.permitPool.Acquire()
	defer m.permitPool.Release()

	var result []byte
	err := m.statements["get"].QueryRowContext(ctx, key).Scan(&result)
	if errors.Is(err, sql.ErrNoRows) {
		return nil, nil
	}
	if err != nil {
		return nil, err
	}

	return &physical.Entry{
		Key:   key,
		Value: result,
	}, nil
}
// Delete permanently removes the entry stored under key. Deleting a key that
// does not exist is not an error. The statement runs under ctx, so
// cancellation and deadlines are honored.
func (m *MySQLBackend) Delete(ctx context.Context, key string) error {
	defer metrics.MeasureSince([]string{"mysql", "delete"}, time.Now())

	m.permitPool.Acquire()
	defer m.permitPool.Release()

	_, err := m.statements["delete"].ExecContext(ctx, key)
	return err
}
// List returns the keys directly under prefix, sorted. Keys in deeper
// "folders" are collapsed to their first path segment with a trailing "/",
// and each folder appears once.
//
// The lookup uses `vault_key LIKE prefix%`. LIKE also treats '%' and '_'
// inside prefix as wildcards, so the query can return rows that do not
// literally start with prefix; those rows are filtered out here.
// NOTE(review): MySQL LIKE treats '\' as an escape character by default, so
// keys whose prefix contains a backslash may be missed. Fixing that needs an
// ESCAPE clause in the prepared "list" statement.
func (m *MySQLBackend) List(ctx context.Context, prefix string) ([]string, error) {
	defer metrics.MeasureSince([]string{"mysql", "list"}, time.Now())

	m.permitPool.Acquire()
	defer m.permitPool.Release()

	// Add the % wildcard to the prefix to do the prefix search.
	likePrefix := prefix + "%"
	rows, err := m.statements["list"].QueryContext(ctx, likePrefix)
	if err != nil {
		return nil, fmt.Errorf("failed to execute statement: %w", err)
	}
	// Always release the connection, including on a mid-iteration Scan error.
	defer rows.Close()

	var keys []string
	for rows.Next() {
		var key string
		if err := rows.Scan(&key); err != nil {
			return nil, fmt.Errorf("failed to scan rows: %w", err)
		}

		// Drop wildcard over-matches (see the note on LIKE above).
		if !strings.HasPrefix(key, prefix) {
			continue
		}
		key = strings.TrimPrefix(key, prefix)

		if i := strings.Index(key, "/"); i == -1 {
			// Add objects only from the current 'folder'.
			keys = append(keys, key)
		} else {
			// Add truncated 'folder' paths.
			keys = strutil.AppendIfMissing(keys, key[:i+1])
		}
	}
	// Without this check an error during iteration would yield a partial list.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("failed to iterate rows: %w", err)
	}

	sort.Strings(keys)
	return keys, nil
}
// LockWith returns an HA lock for key that records value as the current
// leader once acquired. Nothing is acquired here; call Lock on the result.
// The returned error is always nil.
func (m *MySQLBackend) LockWith(key, value string) (physical.Lock, error) {
	return &MySQLHALock{
		in:     m,
		key:    key,
		value:  value,
		logger: m.logger,
	}, nil
}
// HAEnabled reports whether the backend was configured with ha_enabled=true.
// Only then were the lock table and the "get_lock"/"used_lock" statements set up.
func (m *MySQLBackend) HAEnabled() bool {
return m.haEnabled
}
// MySQLHALock is a MySQL Lock implementation for the HABackend.
// It wraps a MySQLLock (a GET_LOCK taken on a dedicated connection pool) with
// the bookkeeping physical.Lock needs: held state, the leader channel, and
// the background monitor.
type MySQLHALock struct {
// in is the backend that created this lock. Its prepared "used_lock" and
// "get_lock" statements are used by hasLock and GetLeader.
in *MySQLBackend
// key is the lock name passed to GET_LOCK / IS_USED_LOCK.
key string
// value is written as current_leader once the lock is acquired.
value string
logger log.Logger
// held is true between a successful Lock and a successful Unlock; both
// methods take localLock before reading or writing it.
held bool
localLock sync.Mutex
// leaderCh is returned from Lock and closed by monitorLock when the lock is lost.
leaderCh chan struct{}
// stopCh is the channel passed to the last successful Lock.
// NOTE(review): it is not read anywhere in this file.
stopCh <-chan struct{}
// lock is the underlying MySQLLock, assigned by attemptLock.
lock *MySQLLock
}
// Lock acquires the HA lock, blocking until it is acquired, the acquisition
// fails, or stopCh is closed.
//
// Acquisition runs in attemptLock on a separate goroutine. On success Lock
// marks the lock held. It then returns a channel that monitorLock closes when
// it detects the lock has been lost. If stopCh is closed first, Lock returns
// (nil, nil) and tells attemptLock to release the lock if it acquires it
// later. Calling Lock while the lock is already held returns an error.
//
// NOTE(review): i.stopCh is assigned after monitorLock has been started and is
// not read anywhere in this file, so the monitor goroutine has no stop signal.
func (i *MySQLHALock) Lock(stopCh <-chan struct{}) (<-chan struct{}, error) {
i.localLock.Lock()
defer i.localLock.Unlock()
if i.held {
return nil, fmt.Errorf("lock already held")
}
// Attempt an async acquisition
didLock := make(chan struct{})
// Buffered so attemptLock can report a failure even if Lock has already
// returned via stopCh and nobody is receiving.
failLock := make(chan error, 1)
// Buffered so Lock never blocks handing over the release decision, e.g.
// when attemptLock has already failed and will never read it.
releaseCh := make(chan bool, 1)
go i.attemptLock(i.key, i.value, didLock, failLock, releaseCh)
// Wait for lock acquisition, failure, or shutdown
select {
case <-didLock:
// Keep the lock: tell attemptLock not to release it.
releaseCh <- false
case err := <-failLock:
return nil, err
case <-stopCh:
// Abandon the attempt; attemptLock unlocks if it acquires the lock afterwards.
releaseCh <- true
return nil, nil
}
// Create the leader channel
i.held = true
i.leaderCh = make(chan struct{})
go i.monitorLock(i.leaderCh)
i.stopCh = stopCh
return i.leaderCh, nil
}
// attemptLock runs on its own goroutine on behalf of Lock. It creates a
// MySQLLock for key/value and blocks in MySQLLock.Lock until the lock is
// obtained.
//
// Outcomes are reported through the channels: any failure is sent on
// failLock, and success closes didLock. After success it waits for Lock's
// decision on releaseCh and unlocks immediately when told to release, which
// happens when Lock was stopped before acquisition completed.
func (i *MySQLHALock) attemptLock(key, value string, didLock chan struct{}, failLock chan error, releaseCh chan bool) {
	lock, err := NewMySQLLock(i.in, i.logger, key, value)
	if err != nil {
		failLock <- err
		return
	}

	// Set node value.
	i.lock = lock

	if err := lock.Lock(); err != nil {
		// Close this attempt's dedicated connection pool. This stops the pool
		// from leaking. If GET_LOCK had already succeeded, it also releases
		// the MySQL lock so other nodes are not blocked by a lock nobody uses.
		// Best effort: the Lock error is what the caller needs to see.
		_ = lock.Unlock()
		failLock <- err
		return
	}

	// Signal that lock is held.
	close(didLock)

	// Handle an early abort.
	if release := <-releaseCh; release {
		lock.Unlock()
	}
}
// monitorLock polls hasLock, sleeping 5 seconds between checks. It closes
// leaderCh as soon as hasLock reports an error, which is how the caller of
// Lock learns that leadership was lost.
//
// NOTE(review): the goroutine only exits when hasLock fails; nothing signals
// it after Unlock, so a Lock/Unlock cycle may leave it running.
func (i *MySQLHALock) monitorLock(leaderCh chan struct{}) {
	// The lock can only be lost if someone alters the lock state in the DB by
	// hand, or if the connection holding it goes away (MySQL then drops the
	// lock anyway).
	for err := i.hasLock(i.key); err == nil; err = i.hasLock(i.key) {
		time.Sleep(5 * time.Second)
	}
	close(leaderCh)
}
// Unlock releases the HA lock if this instance holds it and is a no-op
// otherwise. held is cleared only when the underlying MySQLLock.Unlock
// succeeds; on failure that error is returned and held stays true.
func (i *MySQLHALock) Unlock() error {
	i.localLock.Lock()
	defer i.localLock.Unlock()

	if !i.held {
		return nil
	}
	if err := i.lock.Unlock(); err != nil {
		return err
	}
	i.held = false
	return nil
}
// hasLock checks, via IS_USED_LOCK, whether the MySQL lock named key is held
// by the connection recorded in GlobalLockID.
//
// Return values:
//   - nil when that connection holds the lock, or when no session holds it
//     (IS_USED_LOCK returns NULL)
//   - ErrLockHeld when a different connection holds it
//   - the query error when the check itself fails, so monitorLock treats a
//     lock it cannot verify as lost instead of silently assuming it is held
func (i *MySQLHALock) hasLock(key string) error {
	var result sql.NullInt64
	err := i.in.statements["used_lock"].QueryRow(key).Scan(&result)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		// Nothing to compare against; not an error to us.
		return nil
	case err != nil:
		// Checked before result.Valid: a failed Scan leaves result NULL,
		// which would otherwise hide every query error.
		return err
	case !result.Valid:
		// NULL: the lock isn't held by anyone. This is not treated as an error.
		// NOTE(review): a lock dropped because our own connection died is
		// therefore only noticed once another node acquires it.
		return nil
	case result.Int64 != GlobalLockID:
		// IS_USED_LOCK returns the ID of the connection that holds the lock.
		return ErrLockHeld
	default:
		return nil
	}
}
// GetLeader returns the current_leader value stored for the "leader" row of
// the lock table.
//
// Any query error is returned unchanged, including sql.ErrNoRows when no
// leader has been recorded yet. Previously only sql.ErrNoRows was returned;
// other errors produced ("", nil), which reports an empty leader as valid.
func (i *MySQLHALock) GetLeader() (string, error) {
	defer metrics.MeasureSince([]string{"mysql", "lock_get"}, time.Now())

	var result string
	if err := i.in.statements["get_lock"].QueryRow("leader").Scan(&result); err != nil {
		return "", err
	}
	return result, nil
}
// Value reports the current leader as recorded in the lock table. If
// GetLeader succeeds it returns (true, leader, nil); otherwise it returns
// (false, "", err) with GetLeader's error.
func (i *MySQLHALock) Value() (bool, string, error) {
	leader, err := i.GetLeader()
	if err == nil {
		return true, leader, nil
	}
	return false, "", err
}
// MySQLLock provides an easy way to grab and release mysql
// locks using the built in GET_LOCK function. Note that these
// locks are released when you lose connection to the server.
type MySQLLock struct {
// parentConn is the backend this lock was created from.
parentConn *MySQLBackend
// in is the dedicated connection pool opened by NewMySQLLock. Unlock
// closes it to release the lock.
in *sql.DB
logger log.Logger
// statements holds the prepared "put" statement for the lock table.
statements map[string]*sql.Stmt
// key is the GET_LOCK name.
key string
// value is written as current_leader by becomeLeader.
value string
}
// Errors specific to trying to grab a lock in MySQL, plus the shared
// connection-ID state used to verify lock ownership.
var (
// GlobalLockID is the MySQL connection ID (as reported by IS_USED_LOCK) of
// the session that took the lock in the most recent successful
// MySQLLock.Lock. hasLock compares IS_USED_LOCK results against it.
// NOTE(review): package-level state shared by every lock in the process and
// written without synchronization.
GlobalLockID int64
// ErrLockHeld is returned when another vault instance already has a lock held for the given key.
// MySQLLock.Lock also returns it when GET_LOCK does not return 1 and when
// recording the new leader fails. hasLock returns it when another connection
// owns the lock.
ErrLockHeld = errors.New("mysql: lock already held")
// ErrUnlockFailed is returned by MySQLLock.Unlock when closing the lock's
// connection pool fails.
ErrUnlockFailed = errors.New("mysql: unable to release lock, already released or not held by this session")
// ErrClaimFailed describes a failure to record this node as the new leader in the lock table.
// NOTE(review): not returned anywhere in this file.
ErrClaimFailed = errors.New("mysql: unable to update DB with new leader information")
// ErrSettingGlobalID is returned by MySQLLock.Lock when IS_USED_LOCK yields
// NULL after GET_LOCK succeeded, so GlobalLockID cannot be set.
ErrSettingGlobalID = errors.New("mysql: getting global lock id failed")
)
// NewMySQLLock creates a MySQLLock for key that records value as the current
// leader once acquired.
//
// The lock opens its own connection pool from the backend's configuration.
// Closing it (see MySQLLock.Unlock) then releases the lock without affecting
// the backend's pool or any cleanup running on it. An error is returned if
// the pool cannot be created or the lock-table statement cannot be prepared;
// in the latter case the new pool is closed first.
func NewMySQLLock(in *MySQLBackend, l log.Logger, key, value string) (*MySQLLock, error) {
	conn, err := NewMySQLClient(in.conf, in.logger)
	if err != nil {
		// Without this check a nil *sql.DB reaches prepare and panics.
		return nil, fmt.Errorf("failed to create mysql lock connection: %w", err)
	}

	m := &MySQLLock{
		parentConn: in,
		in:         conn,
		logger:     l,
		statements: make(map[string]*sql.Stmt),
		key:        key,
		value:      value,
	}

	statements := map[string]string{
		"put": "INSERT INTO " + in.dbLockTable +
			" VALUES( ?, ? ) ON DUPLICATE KEY UPDATE current_leader=VALUES(current_leader)",
	}
	for name, query := range statements {
		if err := m.prepare(name, query); err != nil {
			// Don't leak the dedicated pool when setup fails.
			_ = conn.Close()
			return nil, err
		}
	}
	return m, nil
}
// prepare compiles query on the lock's dedicated pool and stores the
// resulting statement in m.statements under name. The error names the
// statement that failed.
func (m *MySQLLock) prepare(name, query string) error {
	stmt, err := m.in.Prepare(query)
	if err == nil {
		m.statements[name] = stmt
		return nil
	}
	return fmt.Errorf("failed to prepare %q: %w", name, err)
}
// becomeLeader upserts the "leader" row of the lock table with this node's
// value so standby servers can discover the active leader.
func (i *MySQLLock) becomeLeader() error {
	upsert := i.statements["put"]
	_, err := upsert.Exec("leader", i.value)
	return err
}
// Lock blocks until it obtains the MySQL lock named i.key (GET_LOCK with a
// math.MaxInt32-second timeout). It then records i.value as the current
// leader and stores the holding connection's ID in GlobalLockID.
//
// It returns ErrLockHeld if GET_LOCK does not return 1 or the leader record
// cannot be written. It returns ErrSettingGlobalID if the connection ID cannot
// be determined, and the query error if the lock query itself fails. When
// GET_LOCK succeeded but a later step fails, the lock's connection pool is
// closed. That releases the MySQL lock, so this node never holds it without
// being leader.
func (i *MySQLLock) Lock() error {
	defer metrics.MeasureSince([]string{"mysql", "get_lock"}, time.Now())

	var lock, connectionID sql.NullInt64
	// Lock timeout math.MaxInt32 instead of -1 solves compatibility issues
	// with different MySQL flavours, e.g. MariaDB.
	//
	// QueryRow/Scan closes the result set before returning. Keeping it open
	// until the end of the function (as a deferred rows.Close would) pins a
	// connection, and becomeLeader would deadlock on a single-connection pool.
	err := i.in.QueryRow("SELECT GET_LOCK(?, ?), IS_USED_LOCK(?)", i.key, math.MaxInt32, i.key).
		Scan(&lock, &connectionID)
	if err != nil {
		return err
	}

	// GET_LOCK returns 1 if it got the lock, 0 on timeout, and NULL on error.
	// https://dev.mysql.com/doc/refman/8.0/en/miscellaneous-functions.html#function_get-lock
	if !lock.Valid || lock.Int64 != 1 {
		return ErrLockHeld
	}

	// Since we have the lock, alert the rest of the cluster that we are now
	// the active leader.
	if err := i.becomeLeader(); err != nil {
		// Release the session lock by closing the pool (best effort).
		_ = i.in.Close()
		return ErrLockHeld
	}

	// IS_USED_LOCK returns the holding connection ID, or NULL if an error happens.
	// https://dev.mysql.com/doc/refman/8.0/en/miscellaneous-functions.html#function_is-used-lock
	if !connectionID.Valid {
		_ = i.in.Close()
		return ErrSettingGlobalID
	}

	GlobalLockID = connectionID.Int64
	return nil
}
// Unlock releases the lock by closing the lock's dedicated connection pool.
// MySQL drops a session's named locks when the connection closes.
// RELEASE_LOCK, by contrast, would have to run on the exact connection that
// called GET_LOCK, which the pool does not make easy. Closing cannot leave
// the release hanging on a mismatched session. Any Close failure is reported
// as ErrUnlockFailed.
func (i *MySQLLock) Unlock() error {
	if closeErr := i.in.Close(); closeErr != nil {
		return ErrUnlockFailed
	}
	return nil
}
// setupMySQLTLSConfig loads the PEM-encoded CA certificate(s) in tlsCaFile
// and registers a tls.Config trusting them with the MySQL driver under
// mysqlTLSKey. A DSN then selects that config with its tls parameter, e.g.
//
//	foo:bar@tcp(127.0.0.1:3306)/dbname?tls=default
//
// It returns an error if the file cannot be read, if it contains no usable
// PEM certificate, or if the driver rejects the registration.
func setupMySQLTLSConfig(tlsCaFile string) error {
	pem, err := ioutil.ReadFile(tlsCaFile)
	if err != nil {
		return err
	}

	rootCertPool := x509.NewCertPool()
	if ok := rootCertPool.AppendCertsFromPEM(pem); !ok {
		// Without this error an empty pool would be registered, and every
		// server certificate would later fail verification with a much less
		// obvious message.
		return fmt.Errorf("no valid PEM certificates found in %q", tlsCaFile)
	}

	return mysql.RegisterTLSConfig(mysqlTLSKey, &tls.Config{
		RootCAs: rootCertPool,
	})
}