2017-08-03 17:24:27 +00:00
package mysql
2015-06-08 10:32:44 +00:00
import (
2018-01-19 06:44:44 +00:00
"context"
2015-07-17 15:54:38 +00:00
"crypto/tls"
"crypto/x509"
2015-06-08 10:32:44 +00:00
"database/sql"
2018-08-13 21:02:31 +00:00
"errors"
2015-06-12 05:56:25 +00:00
"fmt"
2015-07-17 15:54:38 +00:00
"io/ioutil"
2018-09-19 19:05:05 +00:00
"math"
2015-07-17 15:54:38 +00:00
"net/url"
2015-06-08 10:32:44 +00:00
"sort"
2017-06-01 22:20:32 +00:00
"strconv"
2015-06-12 17:31:46 +00:00
"strings"
2018-08-13 21:02:31 +00:00
"sync"
2015-06-08 10:32:44 +00:00
"time"
2020-06-12 17:08:56 +00:00
"unicode"
2015-06-08 10:32:44 +00:00
2018-04-03 00:46:59 +00:00
log "github.com/hashicorp/go-hclog"
2020-06-12 17:08:56 +00:00
"github.com/hashicorp/go-multierror"
2016-08-19 20:45:17 +00:00
2019-01-09 00:48:57 +00:00
metrics "github.com/armon/go-metrics"
2015-07-17 15:54:38 +00:00
mysql "github.com/go-sql-driver/mysql"
2017-06-01 22:20:32 +00:00
"github.com/hashicorp/errwrap"
2019-04-12 21:54:35 +00:00
"github.com/hashicorp/vault/sdk/helper/strutil"
"github.com/hashicorp/vault/sdk/physical"
2015-07-17 15:54:38 +00:00
)
2018-01-20 01:44:24 +00:00
// Verify MySQLBackend satisfies the correct interfaces
var _ physical . Backend = ( * MySQLBackend ) ( nil )
2018-08-13 21:02:31 +00:00
var _ physical . HABackend = ( * MySQLBackend ) ( nil )
var _ physical . Lock = ( * MySQLHALock ) ( nil )
2018-01-20 01:44:24 +00:00
2015-07-29 20:03:46 +00:00
// mysqlTLSKey is the name under which setupMySQLTLSConfig registers the
// custom TLS configuration with the MySQL driver; the DSN refers to it via
// the "tls" parameter. It must not collide with the driver's reserved values
// "true", "false" and "skip-verify".
const mysqlTLSKey = "default"
2015-06-08 10:32:44 +00:00
// MySQLBackend is a physical backend that stores data
// within MySQL database.
type MySQLBackend struct {
2018-08-13 21:02:31 +00:00
dbTable string
2018-08-16 18:03:16 +00:00
dbLockTable string
2018-08-13 21:02:31 +00:00
client * sql . DB
statements map [ string ] * sql . Stmt
logger log . Logger
permitPool * physical . PermitPool
conf map [ string ] string
redirectHost string
redirectPort int64
2018-08-16 18:03:16 +00:00
haEnabled bool
2015-06-08 10:32:44 +00:00
}
2017-08-03 17:24:27 +00:00
// NewMySQLBackend constructs a MySQL backend using the given API client and
2015-06-08 10:32:44 +00:00
// server address and credential for accessing mysql database.
2017-08-03 17:24:27 +00:00
func NewMySQLBackend ( conf map [ string ] string , logger log . Logger ) ( physical . Backend , error ) {
2017-06-01 22:20:32 +00:00
var err error
2018-08-13 21:02:31 +00:00
db , err := NewMySQLClient ( conf , logger )
if err != nil {
return nil , err
2015-06-08 10:32:44 +00:00
}
2020-06-12 17:08:56 +00:00
database := conf [ "database" ]
if database == "" {
2015-06-18 21:31:00 +00:00
database = "vault"
2015-06-08 10:32:44 +00:00
}
2020-06-12 17:08:56 +00:00
table := conf [ "table" ]
if table == "" {
2015-06-18 21:31:00 +00:00
table = "vault"
2015-06-08 10:32:44 +00:00
}
2020-06-12 17:08:56 +00:00
err = validateDBTable ( database , table )
if err != nil {
return nil , err
}
dbTable := fmt . Sprintf ( "`%s`.`%s`" , database , table )
2015-06-08 10:32:44 +00:00
2017-06-01 22:20:32 +00:00
maxParStr , ok := conf [ "max_parallel" ]
var maxParInt int
if ok {
maxParInt , err = strconv . Atoi ( maxParStr )
if err != nil {
return nil , errwrap . Wrapf ( "failed parsing max_parallel parameter: {{err}}" , err )
}
if logger . IsDebug ( ) {
2018-04-03 00:46:59 +00:00
logger . Debug ( "max_parallel set" , "max_parallel" , maxParInt )
2017-06-01 22:20:32 +00:00
}
2017-07-17 17:04:49 +00:00
} else {
2017-08-03 17:24:27 +00:00
maxParInt = physical . DefaultParallelOperations
2017-06-01 22:20:32 +00:00
}
2017-12-19 19:23:58 +00:00
// Check schema exists
var schemaExist bool
schemaRows , err := db . Query ( "SELECT SCHEMA_NAME FROM information_schema.SCHEMATA WHERE SCHEMA_NAME = ?" , database )
if err != nil {
2018-04-05 15:49:21 +00:00
return nil , errwrap . Wrapf ( "failed to check mysql schema exist: {{err}}" , err )
2017-12-19 19:23:58 +00:00
}
defer schemaRows . Close ( )
schemaExist = schemaRows . Next ( )
// Check table exists
var tableExist bool
tableRows , err := db . Query ( "SELECT TABLE_NAME FROM information_schema.TABLES WHERE TABLE_NAME = ? AND TABLE_SCHEMA = ?" , table , database )
if err != nil {
2018-04-05 15:49:21 +00:00
return nil , errwrap . Wrapf ( "failed to check mysql table exist: {{err}}" , err )
2017-12-19 19:23:58 +00:00
}
defer tableRows . Close ( )
tableExist = tableRows . Next ( )
2015-06-18 21:31:00 +00:00
// Create the required database if it doesn't exists.
2017-12-19 19:23:58 +00:00
if ! schemaExist {
2018-08-10 23:38:20 +00:00
if _ , err := db . Exec ( "CREATE DATABASE IF NOT EXISTS `" + database + "`" ) ; err != nil {
2018-04-05 15:49:21 +00:00
return nil , errwrap . Wrapf ( "failed to create mysql database: {{err}}" , err )
2017-12-19 19:23:58 +00:00
}
2015-06-08 10:32:44 +00:00
}
2015-06-18 21:31:00 +00:00
// Create the required table if it doesn't exists.
2017-12-19 19:23:58 +00:00
if ! tableExist {
create_query := "CREATE TABLE IF NOT EXISTS " + dbTable +
" (vault_key varbinary(512), vault_value mediumblob, PRIMARY KEY (vault_key))"
if _ , err := db . Exec ( create_query ) ; err != nil {
2018-04-05 15:49:21 +00:00
return nil , errwrap . Wrapf ( "failed to create mysql table: {{err}}" , err )
2017-12-19 19:23:58 +00:00
}
2015-06-12 05:56:25 +00:00
}
2018-08-16 18:03:16 +00:00
// Default value for ha_enabled
haEnabledStr , ok := conf [ "ha_enabled" ]
2018-08-13 21:02:31 +00:00
if ! ok {
2018-08-16 18:03:16 +00:00
haEnabledStr = "false"
}
haEnabled , err := strconv . ParseBool ( haEnabledStr )
if err != nil {
return nil , fmt . Errorf ( "value [%v] of 'ha_enabled' could not be understood" , haEnabledStr )
2018-08-13 21:02:31 +00:00
}
2018-08-16 18:03:16 +00:00
locktable , ok := conf [ "lock_table" ]
if ! ok {
locktable = table + "_lock"
}
2018-08-13 21:02:31 +00:00
2018-08-16 18:03:16 +00:00
dbLockTable := "`" + database + "`.`" + locktable + "`"
2018-08-13 21:02:31 +00:00
2018-08-16 18:03:16 +00:00
// Only create lock table if ha_enabled is true
if haEnabled {
// Check table exists
var lockTableExist bool
lockTableRows , err := db . Query ( "SELECT TABLE_NAME FROM information_schema.TABLES WHERE TABLE_NAME = ? AND TABLE_SCHEMA = ?" , locktable , database )
2018-08-13 21:02:31 +00:00
2018-08-16 18:03:16 +00:00
if err != nil {
return nil , errwrap . Wrapf ( "failed to check mysql table exist: {{err}}" , err )
}
defer lockTableRows . Close ( )
lockTableExist = lockTableRows . Next ( )
// Create the required table if it doesn't exists.
if ! lockTableExist {
create_query := "CREATE TABLE IF NOT EXISTS " + dbLockTable +
2018-10-18 17:35:04 +00:00
" (node_job varbinary(512), current_leader varbinary(512), PRIMARY KEY (node_job))"
2018-08-16 18:03:16 +00:00
if _ , err := db . Exec ( create_query ) ; err != nil {
return nil , errwrap . Wrapf ( "failed to create mysql table: {{err}}" , err )
}
2018-08-13 21:02:31 +00:00
}
}
2015-06-08 10:32:44 +00:00
// Setup the backend.
m := & MySQLBackend {
2018-08-16 18:03:16 +00:00
dbTable : dbTable ,
dbLockTable : dbLockTable ,
client : db ,
statements : make ( map [ string ] * sql . Stmt ) ,
logger : logger ,
permitPool : physical . NewPermitPool ( maxParInt ) ,
conf : conf ,
haEnabled : haEnabled ,
2015-06-08 10:32:44 +00:00
}
2015-06-18 21:31:00 +00:00
// Prepare all the statements required
statements := map [ string ] string {
"put" : "INSERT INTO " + dbTable +
" VALUES( ?, ? ) ON DUPLICATE KEY UPDATE vault_value=VALUES(vault_value)" ,
2018-08-16 18:03:16 +00:00
"get" : "SELECT vault_value FROM " + dbTable + " WHERE vault_key = ?" ,
"delete" : "DELETE FROM " + dbTable + " WHERE vault_key = ?" ,
"list" : "SELECT vault_key FROM " + dbTable + " WHERE vault_key LIKE ?" ,
}
// Only prepare ha-related statements if we need them
if haEnabled {
statements [ "get_lock" ] = "SELECT current_leader FROM " + dbLockTable + " WHERE node_job = ?"
statements [ "used_lock" ] = "SELECT IS_USED_LOCK(?)"
2015-06-18 21:31:00 +00:00
}
2018-08-16 18:03:16 +00:00
2015-06-18 21:31:00 +00:00
for name , query := range statements {
if err := m . prepare ( name , query ) ; err != nil {
return nil , err
}
}
2017-06-01 22:20:32 +00:00
2015-06-08 10:32:44 +00:00
return m , nil
}
2020-06-12 17:08:56 +00:00
// validateDBTable guards against SQL injection: the database and table names
// are spliced into SQL statements, so each must pass validate. Problems with
// both names are collected into a single multierror; nil means both are
// acceptable. MySQL permits more characters than this allows, but there is
// no easy way to check against the full Unicode Basic Multilingual Plane.
// https://dev.mysql.com/doc/refman/5.7/en/identifiers.html
func validateDBTable(db, table string) (err error) {
	collected := &multierror.Error{}
	for _, ident := range []struct{ kind, name string }{
		{"database", db},
		{"table", table},
	} {
		if problem := validate(ident.name); problem != nil {
			collected = multierror.Append(collected, fmt.Errorf("invalid %s: %w", ident.kind, problem))
		}
	}
	return collected.ErrorOrNil()
}
// validate reports whether name is safe to use as a backquoted MySQL
// identifier. It returns an error for an empty name or for the first
// offending character it finds, and nil otherwise.
//
// From https://dev.mysql.com/doc/refman/5.7/en/identifiers.html:
//   - Quoted identifiers may contain the full Unicode Basic Multilingual
//     Plane (U+0001 .. U+FFFF) except U+0000.
//   - NUL and supplementary characters (U+10000 and higher) are never
//     permitted.
//   - Identifiers may begin with a digit but, unless quoted, may not consist
//     solely of digits (irrelevant here: names are always quoted).
//   - Database, table and column names cannot end with space characters.
//
// This check is stricter than MySQL: it also rejects backquotes, quote
// characters, non-printable characters and every space character.
func validate(name string) (err error) {
	if len(name) == 0 {
		return fmt.Errorf("missing name")
	}

	for _, ch := range name {
		switch {
		case ch == 0x0000:
			return fmt.Errorf("invalid character: cannot include 0x0000")
		case ch > 0xFFFF:
			return fmt.Errorf("invalid character: cannot include any characters above 0xFFFF")
		case ch == '`':
			return fmt.Errorf("invalid character: cannot include '`' character")
		case ch == '\'', ch == '"':
			return fmt.Errorf("invalid character: cannot include quotes")
		case !unicode.IsPrint(ch), unicode.IsSpace(ch):
			// Non-printable and space characters are excluded deliberately
			// even though MySQL would accept most of them.
			return fmt.Errorf("invalid character found: can only include printable, non-space characters between [0x0001-0xFFFF]")
		}
	}
	return nil
}
// wrapErr formats err into message, which is expected to contain a %w verb.
// A nil err yields nil, so callers can wrap results unconditionally.
func wrapErr(message string, err error) error {
	if err != nil {
		return fmt.Errorf(message, err)
	}
	return nil
}
2018-08-13 21:02:31 +00:00
func NewMySQLClient ( conf map [ string ] string , logger log . Logger ) ( * sql . DB , error ) {
var err error
// Get the MySQL credentials to perform read/write operations.
username , ok := conf [ "username" ]
if ! ok || username == "" {
return nil , fmt . Errorf ( "missing username" )
}
password , ok := conf [ "password" ]
if ! ok || password == "" {
return nil , fmt . Errorf ( "missing password" )
}
// Get or set MySQL server address. Defaults to localhost and default port(3306)
address , ok := conf [ "address" ]
if ! ok {
address = "127.0.0.1:3306"
}
maxIdleConnStr , ok := conf [ "max_idle_connections" ]
var maxIdleConnInt int
if ok {
maxIdleConnInt , err = strconv . Atoi ( maxIdleConnStr )
if err != nil {
return nil , errwrap . Wrapf ( "failed parsing max_idle_connections parameter: {{err}}" , err )
}
if logger . IsDebug ( ) {
logger . Debug ( "max_idle_connections set" , "max_idle_connections" , maxIdleConnInt )
}
}
maxConnLifeStr , ok := conf [ "max_connection_lifetime" ]
var maxConnLifeInt int
if ok {
maxConnLifeInt , err = strconv . Atoi ( maxConnLifeStr )
if err != nil {
return nil , errwrap . Wrapf ( "failed parsing max_connection_lifetime parameter: {{err}}" , err )
}
if logger . IsDebug ( ) {
logger . Debug ( "max_connection_lifetime set" , "max_connection_lifetime" , maxConnLifeInt )
}
}
maxParStr , ok := conf [ "max_parallel" ]
var maxParInt int
if ok {
maxParInt , err = strconv . Atoi ( maxParStr )
if err != nil {
return nil , errwrap . Wrapf ( "failed parsing max_parallel parameter: {{err}}" , err )
}
if logger . IsDebug ( ) {
logger . Debug ( "max_parallel set" , "max_parallel" , maxParInt )
}
} else {
maxParInt = physical . DefaultParallelOperations
}
dsnParams := url . Values { }
2020-05-21 16:09:37 +00:00
tlsCaFile , tlsOk := conf [ "tls_ca_file" ]
if tlsOk {
2018-08-13 21:02:31 +00:00
if err := setupMySQLTLSConfig ( tlsCaFile ) ; err != nil {
return nil , errwrap . Wrapf ( "failed register TLS config: {{err}}" , err )
}
dsnParams . Add ( "tls" , mysqlTLSKey )
}
2020-05-21 16:09:37 +00:00
ptAllowed , ptOk := conf [ "plaintext_connection_allowed" ]
if ! ( ptOk && strings . ToLower ( ptAllowed ) == "true" ) && ! tlsOk {
logger . Warn ( "No TLS specified, credentials will be sent in plaintext. To mute this warning add 'plaintext_connection_allowed' with a true value to your MySQL configuration in your config file." )
}
2018-08-13 21:02:31 +00:00
// Create MySQL handle for the database.
dsn := username + ":" + password + "@tcp(" + address + ")/?" + dsnParams . Encode ( )
db , err := sql . Open ( "mysql" , dsn )
if err != nil {
return nil , errwrap . Wrapf ( "failed to connect to mysql: {{err}}" , err )
}
db . SetMaxOpenConns ( maxParInt )
if maxIdleConnInt != 0 {
db . SetMaxIdleConns ( maxIdleConnInt )
}
if maxConnLifeInt != 0 {
db . SetConnMaxLifetime ( time . Duration ( maxConnLifeInt ) * time . Second )
}
return db , err
}
2015-06-18 21:31:00 +00:00
// prepare is a helper to prepare a query for future execution
func ( m * MySQLBackend ) prepare ( name , query string ) error {
stmt , err := m . client . Prepare ( query )
if err != nil {
2018-04-05 15:49:21 +00:00
return errwrap . Wrapf ( fmt . Sprintf ( "failed to prepare %q: {{err}}" , name ) , err )
2015-06-18 21:31:00 +00:00
}
m . statements [ name ] = stmt
return nil
}
2015-06-08 10:32:44 +00:00
// Put is used to insert or update an entry.
2018-01-19 06:44:44 +00:00
func ( m * MySQLBackend ) Put ( ctx context . Context , entry * physical . Entry ) error {
2015-06-08 10:32:44 +00:00
defer metrics . MeasureSince ( [ ] string { "mysql" , "put" } , time . Now ( ) )
2017-06-01 22:20:32 +00:00
m . permitPool . Acquire ( )
defer m . permitPool . Release ( )
2015-06-12 05:56:25 +00:00
_ , err := m . statements [ "put" ] . Exec ( entry . Key , entry . Value )
2015-06-08 10:32:44 +00:00
if err != nil {
return err
}
return nil
}
2018-08-13 21:02:31 +00:00
// Get is used to fetch an entry.
2018-01-19 06:44:44 +00:00
func ( m * MySQLBackend ) Get ( ctx context . Context , key string ) ( * physical . Entry , error ) {
2015-06-08 10:32:44 +00:00
defer metrics . MeasureSince ( [ ] string { "mysql" , "get" } , time . Now ( ) )
2017-06-01 22:20:32 +00:00
m . permitPool . Acquire ( )
defer m . permitPool . Release ( )
2015-06-08 10:32:44 +00:00
var result [ ] byte
2015-06-12 05:56:25 +00:00
err := m . statements [ "get" ] . QueryRow ( key ) . Scan ( & result )
2015-06-18 21:31:00 +00:00
if err == sql . ErrNoRows {
2015-06-13 02:19:40 +00:00
return nil , nil
2015-06-08 10:32:44 +00:00
}
2015-06-18 21:31:00 +00:00
if err != nil {
return nil , err
2015-06-12 17:31:46 +00:00
}
2017-08-03 17:24:27 +00:00
ent := & physical . Entry {
2015-06-08 10:32:44 +00:00
Key : key ,
Value : result ,
}
return ent , nil
}
// Delete is used to permanently delete an entry
2018-01-19 06:44:44 +00:00
func ( m * MySQLBackend ) Delete ( ctx context . Context , key string ) error {
2015-06-08 10:32:44 +00:00
defer metrics . MeasureSince ( [ ] string { "mysql" , "delete" } , time . Now ( ) )
2017-06-01 22:20:32 +00:00
m . permitPool . Acquire ( )
defer m . permitPool . Release ( )
2015-06-12 05:56:25 +00:00
_ , err := m . statements [ "delete" ] . Exec ( key )
2015-06-08 10:32:44 +00:00
if err != nil {
return err
}
return nil
}
// List is used to list all the keys under a given
// prefix, up to the next prefix.
2018-01-19 06:44:44 +00:00
func ( m * MySQLBackend ) List ( ctx context . Context , prefix string ) ( [ ] string , error ) {
2015-06-08 10:32:44 +00:00
defer metrics . MeasureSince ( [ ] string { "mysql" , "list" } , time . Now ( ) )
2017-06-01 22:20:32 +00:00
m . permitPool . Acquire ( )
defer m . permitPool . Release ( )
2015-06-18 21:31:00 +00:00
// Add the % wildcard to the prefix to do the prefix search
likePrefix := prefix + "%"
rows , err := m . statements [ "list" ] . Query ( likePrefix )
2016-11-17 14:59:27 +00:00
if err != nil {
2018-04-05 15:49:21 +00:00
return nil , errwrap . Wrapf ( "failed to execute statement: {{err}}" , err )
2016-11-17 14:59:27 +00:00
}
2015-06-08 10:32:44 +00:00
2015-06-18 21:31:00 +00:00
var keys [ ] string
2015-06-08 10:32:44 +00:00
for rows . Next ( ) {
2015-06-18 21:31:00 +00:00
var key string
err = rows . Scan ( & key )
2015-06-08 10:32:44 +00:00
if err != nil {
2018-04-05 15:49:21 +00:00
return nil , errwrap . Wrapf ( "failed to scan rows: {{err}}" , err )
2015-06-08 10:32:44 +00:00
}
2015-06-18 21:31:00 +00:00
key = strings . TrimPrefix ( key , prefix )
if i := strings . Index ( key , "/" ) ; i == - 1 {
// Add objects only from the current 'folder'
keys = append ( keys , key )
} else if i != - 1 {
// Add truncated 'folder' paths
2017-06-16 15:09:15 +00:00
keys = strutil . AppendIfMissing ( keys , string ( key [ : i + 1 ] ) )
2015-06-08 10:32:44 +00:00
}
}
sort . Strings ( keys )
return keys , nil
}
2015-07-17 15:54:38 +00:00
2018-08-13 21:02:31 +00:00
// LockWith returns a lock for mutual exclusion on key. Once acquired, value
// is recorded as the current leader. No database work happens here; it is
// deferred until Lock is called.
func (m *MySQLBackend) LockWith(key, value string) (physical.Lock, error) {
	return &MySQLHALock{
		in:     m,
		key:    key,
		value:  value,
		logger: m.logger,
	}, nil
}
func ( m * MySQLBackend ) HAEnabled ( ) bool {
2018-08-16 18:03:16 +00:00
return m . haEnabled
2018-08-13 21:02:31 +00:00
}
// MySQLHALock is a MySQL Lock implementation for the HABackend, created by
// MySQLBackend.LockWith.
type MySQLHALock struct {
	in     *MySQLBackend // backend whose prepared statements check and read leadership
	key    string        // name of the MySQL named lock
	value  string        // recorded as the current leader once the lock is acquired
	logger log.Logger

	held      bool            // true between a successful Lock and a successful Unlock
	localLock sync.Mutex      // serializes Lock and Unlock
	leaderCh  chan struct{}   // returned by Lock; closed by monitorLock when the lock is lost
	stopCh    <-chan struct{} // stop channel passed to the last successful Lock
	lock      *MySQLLock      // the MySQLLock set by attemptLock
}
// Lock blocks until the named lock for i.key is acquired, acquisition fails,
// or stopCh fires.
//
// On success it returns a channel that monitorLock closes if the lock is
// later found to be lost. A failed acquisition returns its error. When
// stopCh fires first, Lock returns (nil, nil) and tells the background
// attempt to release the lock should it still obtain it. Calling Lock while
// the lock is already held is an error.
func (i *MySQLHALock) Lock(stopCh <-chan struct{}) (<-chan struct{}, error) {
	i.localLock.Lock()
	defer i.localLock.Unlock()

	if i.held {
		return nil, fmt.Errorf("lock already held")
	}

	// Acquire asynchronously so stopCh can interrupt the wait.
	acquired := make(chan struct{})
	failed := make(chan error, 1)
	abort := make(chan bool, 1)
	go i.attemptLock(i.key, i.value, acquired, failed, abort)

	select {
	case err := <-failed:
		return nil, err
	case <-stopCh:
		abort <- true
		return nil, nil
	case <-acquired:
		abort <- false
	}

	i.held = true
	i.leaderCh = make(chan struct{})
	go i.monitorLock(i.leaderCh)
	i.stopCh = stopCh
	return i.leaderCh, nil
}
func ( i * MySQLHALock ) attemptLock ( key , value string , didLock chan struct { } , failLock chan error , releaseCh chan bool ) {
lock , err := NewMySQLLock ( i . in , i . logger , key , value )
if err != nil {
failLock <- err
2020-02-05 20:08:48 +00:00
return
2018-08-13 21:02:31 +00:00
}
2020-02-05 20:08:48 +00:00
// Set node value
i . lock = lock
2018-08-13 21:02:31 +00:00
err = lock . Lock ( )
if err != nil {
failLock <- err
return
}
// Signal that lock is held
close ( didLock )
// Handle an early abort
release := <- releaseCh
if release {
lock . Unlock ( )
}
}
// monitorLock checks every five seconds that this node still holds the lock.
// When hasLock reports an error, meaning another connection holds the lock
// or the check failed, it closes leaderCh and exits. It also exits, without
// closing leaderCh, once the leadership term it was started for is over:
// Unlock was called, or a later Lock handed out a new leader channel.
// Without that exit the goroutine would poll the database forever after
// every step-down.
func (i *MySQLHALock) monitorLock(leaderCh chan struct{}) {
	for {
		// The only way to lose this lock is if someone is logging into the DB
		// and altering system tables, or the connection holding it is lost,
		// in which case the lock is lost anyway.
		if err := i.hasLock(i.key); err != nil {
			close(leaderCh)
			return
		}
		time.Sleep(5 * time.Second)
		if !i.stillMonitoring(leaderCh) {
			return
		}
	}
}

// stillMonitoring reports whether leaderCh still belongs to a held lock,
// i.e. Unlock has not run and Lock has not created a newer leader channel.
func (i *MySQLHALock) stillMonitoring(leaderCh chan struct{}) bool {
	i.localLock.Lock()
	defer i.localLock.Unlock()
	return i.held && i.leaderCh == leaderCh
}
// Unlock releases the lock if it is currently held and is a no-op
// otherwise. The held flag is cleared only when the underlying release
// succeeds, so a failed Unlock leaves the lock marked as held.
func (i *MySQLHALock) Unlock() error {
	i.localLock.Lock()
	defer i.localLock.Unlock()

	if !i.held {
		return nil
	}
	if releaseErr := i.lock.Unlock(); releaseErr != nil {
		return releaseErr
	}
	i.held = false
	return nil
}
// hasLock checks whether the lock for key is still ours by comparing the
// connection ID that IS_USED_LOCK reports against GlobalLockID.
//
// It returns nil when the lock is ours or is not held by anyone, ErrLockHeld
// when a different connection holds it, and the query error if the check
// itself fails.
func (i *MySQLHALock) hasLock(key string) error {
	var result sql.NullInt64
	err := i.in.statements["used_lock"].QueryRow(key).Scan(&result)
	switch {
	case err == sql.ErrNoRows:
		return nil
	case err != nil:
		// Must be checked before result.Valid: a failed query leaves result
		// zero-valued (Valid == false), which would otherwise be mistaken
		// for "lock not held" and hide the error.
		return err
	case !result.Valid:
		// NULL: nobody holds the lock. This is not an error to us.
		return nil
	case result.Int64 != GlobalLockID:
		// IS_USED_LOCK returns the ID of the connection holding the lock.
		return ErrLockHeld
	}
	return nil
}
// GetLeader returns the current_leader value stored in the "leader" row of
// the lock table. Every query error is returned to the caller, including
// sql.ErrNoRows when no leader has been recorded yet. Other errors used to
// be swallowed and reported as an empty leader with a nil error.
func (i *MySQLHALock) GetLeader() (string, error) {
	defer metrics.MeasureSince([]string{"mysql", "lock_get"}, time.Now())

	var result string
	if err := i.in.statements["get_lock"].QueryRow("leader").Scan(&result); err != nil {
		return "", err
	}
	return result, nil
}
// Value returns the leader recorded in the lock table, as read by GetLeader.
// The boolean is true whenever the lookup succeeds and false on error.
func (i *MySQLHALock) Value() (bool, string, error) {
	leader, lookupErr := i.GetLeader()
	if lookupErr != nil {
		return false, "", lookupErr
	}
	return true, leader, nil
}
// MySQLLock provides an easy way to grab and release MySQL named locks using
// the built-in GET_LOCK function. These locks are released when the
// connection holding them is lost.
type MySQLLock struct {
	parentConn *MySQLBackend        // backend this lock was created for
	in         *sql.DB              // dedicated connection pool; Unlock closes it
	logger     log.Logger
	statements map[string]*sql.Stmt // statements prepared on in, keyed by name
	key        string               // lock name passed to GET_LOCK
	value      string               // recorded as current_leader by becomeLeader
}
// GlobalLockID is the MySQL connection ID that held the lock when it was
// last acquired. hasLock compares IS_USED_LOCK's result against it to check
// whether the lock we got is still the current lock.
var GlobalLockID int64

// Errors specific to trying to grab a lock in MySQL.
var (
	// ErrLockHeld is returned when another vault instance already has a lock held for the given key.
	ErrLockHeld = errors.New("mysql: lock already held")
	// ErrUnlockFailed is returned by MySQLLock.Unlock when closing the lock's connection fails.
	ErrUnlockFailed = errors.New("mysql: unable to release lock, already released or not held by this session")
	// ErrClaimFailed signals that the DB could not be updated to name this node as the new leader.
	ErrClaimFailed = errors.New("mysql: unable to update DB with new leader information")
	// ErrSettingGlobalID signals that the connection ID of a just-acquired lock could not be determined.
	ErrSettingGlobalID = errors.New("mysql: getting global lock id failed")
)
// NewMySQLLock helper function
func NewMySQLLock ( in * MySQLBackend , l log . Logger , key , value string ) ( * MySQLLock , error ) {
// Create a new MySQL connection so we can close this and have no effect on
// the rest of the MySQL backend and any cleanup that might need to be done.
conn , _ := NewMySQLClient ( in . conf , in . logger )
m := & MySQLLock {
parentConn : in ,
in : conn ,
logger : l ,
statements : make ( map [ string ] * sql . Stmt ) ,
key : key ,
value : value ,
}
statements := map [ string ] string {
2018-08-16 18:03:16 +00:00
"put" : "INSERT INTO " + in . dbLockTable +
2018-08-13 21:02:31 +00:00
" VALUES( ?, ? ) ON DUPLICATE KEY UPDATE current_leader=VALUES(current_leader)" ,
}
for name , query := range statements {
if err := m . prepare ( name , query ) ; err != nil {
return nil , err
}
}
return m , nil
}
// prepare compiles query on the lock's dedicated connection pool and stores
// it in m.statements under name.
func (m *MySQLLock) prepare(name, query string) error {
	stmt, err := m.in.Prepare(query)
	if err == nil {
		m.statements[name] = stmt
		return nil
	}
	msg := fmt.Sprintf("failed to prepare %q: {{err}}", name)
	return errwrap.Wrapf(msg, err)
}
// becomeLeader records i.value as the current cluster leader by upserting
// the "leader" row of the lock table, so that standby servers can tell who
// the active leader is.
func (i *MySQLLock) becomeLeader() error {
	_, execErr := i.statements["put"].Exec("leader", i.value)
	return execErr
}
// Lock will try to get a lock for an indefinite amount of time
// based on the given key that has been requested.
func ( i * MySQLLock ) Lock ( ) error {
defer metrics . MeasureSince ( [ ] string { "mysql" , "get_lock" } , time . Now ( ) )
2018-10-02 18:42:50 +00:00
// Lock timeout math.MaxInt32 instead of -1 solves compatibility issues with
2018-09-19 19:05:05 +00:00
// different MySQL flavours i.e. MariaDB
2018-10-02 18:42:50 +00:00
rows , err := i . in . Query ( "SELECT GET_LOCK(?, ?), IS_USED_LOCK(?)" , i . key , math . MaxInt32 , i . key )
2018-08-13 21:02:31 +00:00
if err != nil {
return err
}
defer rows . Close ( )
rows . Next ( )
var lock sql . NullInt64
var connectionID sql . NullInt64
rows . Scan ( & lock , & connectionID )
if rows . Err ( ) != nil {
return rows . Err ( )
}
// 1 is returned from GET_LOCK if it was able to get the lock
2019-03-19 13:32:45 +00:00
// 0 if it failed and NULL if some strange error happened.
2018-08-13 21:02:31 +00:00
// https://dev.mysql.com/doc/refman/8.0/en/miscellaneous-functions.html#function_get-lock
if ! lock . Valid || lock . Int64 != 1 {
return ErrLockHeld
}
// Since we have the lock alert the rest of the cluster
// that we are now the active leader.
err = i . becomeLeader ( )
if err != nil {
return ErrLockHeld
}
// This will return the connection ID of NULL if an error happens
// https://dev.mysql.com/doc/refman/8.0/en/miscellaneous-functions.html#function_is-used-lock
if ! connectionID . Valid {
return ErrSettingGlobalID
}
GlobalLockID = connectionID . Int64
return nil
}
// Unlock releases the lock by closing the lock's dedicated connection pool.
// MySQL drops named locks when the owning connection closes, which makes
// this a reliable release. RELEASE_LOCK, by contrast, must run on the exact
// connection that acquired the lock, which is awkward to guarantee through a
// pooled *sql.DB and risks hanging. Any close failure is reported as
// ErrUnlockFailed.
func (i *MySQLLock) Unlock() error {
	if i.in.Close() != nil {
		return ErrUnlockFailed
	}
	return nil
}
2015-07-29 20:03:46 +00:00
// Establish a TLS connection with a given CA certificate
2018-03-20 18:54:10 +00:00
// Register a tsl.Config associated with the same key as the dns param from sql.Open
2015-07-29 20:03:46 +00:00
// foo:bar@tcp(127.0.0.1:3306)/dbname?tls=default
func setupMySQLTLSConfig ( tlsCaFile string ) error {
2015-07-17 15:54:38 +00:00
rootCertPool := x509 . NewCertPool ( )
2015-07-29 20:03:46 +00:00
pem , err := ioutil . ReadFile ( tlsCaFile )
2015-07-17 15:54:38 +00:00
if err != nil {
return err
}
if ok := rootCertPool . AppendCertsFromPEM ( pem ) ; ! ok {
return err
}
2015-07-29 20:03:46 +00:00
err = mysql . RegisterTLSConfig ( mysqlTLSKey , & tls . Config {
2015-07-17 15:54:38 +00:00
RootCAs : rootCertPool ,
} )
if err != nil {
return err
}
return nil
}