88118dce0f
* Add max_parallel parameter to MySQL backend. This limits the number of concurrent connections, so that Vault does not fail suddenly with "Too many connections". This can happen when, e.g., Vault starts up and tries to load all the existing leases in parallel. At the time of writing, the value ExpirationRestoreWorkerCount in vault/helper/consts/const.go is set to 64, meaning that if there are enough leases in Vault's DB, it will open AT LEAST 64 concurrent connections to MySQL while loading the data during start-up. On certain configurations, e.g. smaller AWS RDS/Aurora instances, this will cause Vault to fail on startup. * Fix a typo in mysql storage readme
257 lines
6.3 KiB
Go
257 lines
6.3 KiB
Go
package physical
|
|
|
|
import (
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"database/sql"
|
|
"fmt"
|
|
"io/ioutil"
|
|
"net/url"
|
|
"sort"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
log "github.com/mgutz/logxi/v1"
|
|
|
|
"github.com/armon/go-metrics"
|
|
mysql "github.com/go-sql-driver/mysql"
|
|
"github.com/hashicorp/errwrap"
|
|
)
|
|
|
|
// mysqlTLSKey is the name under which the custom TLS config is registered
// with the MySQL driver and then referenced from the DSN's "tls" parameter.
// It must not collide with the driver's reserved values "true", "false" and
// "skip-verify".
const (
	mysqlTLSKey = "default"
)
|
|
|
|
// MySQLBackend is a physical backend that stores data
|
|
// within MySQL database.
|
|
type MySQLBackend struct {
|
|
dbTable string
|
|
client *sql.DB
|
|
statements map[string]*sql.Stmt
|
|
logger log.Logger
|
|
permitPool *PermitPool
|
|
}
|
|
|
|
// newMySQLBackend constructs a MySQL-backed physical Backend from conf.
//
// Recognized conf keys:
//   - "username"     required, must be non-empty
//   - "password"     required, must be non-empty
//   - "address"      MySQL host:port, defaults to "127.0.0.1:3306"
//   - "database"     defaults to "vault"
//   - "table"        defaults to "vault"
//   - "max_parallel" integer passed to NewPermitPool to bound concurrent
//     operations; when absent, 0 is passed (NOTE(review): presumably
//     NewPermitPool substitutes a default for values < 1 — confirm)
//   - "tls_ca_file"  path to a PEM CA bundle; when set, the connection uses TLS
//     verified against that bundle
//
// The database and table are created if missing, and the put/get/delete/list
// statements are prepared up front. If any step after opening the *sql.DB
// fails, the handle is closed before returning so its pool is not leaked.
func newMySQLBackend(conf map[string]string, logger log.Logger) (Backend, error) {
	var err error

	// Get the MySQL credentials to perform read/write operations.
	username, ok := conf["username"]
	if !ok || username == "" {
		return nil, fmt.Errorf("missing username")
	}
	password, ok := conf["password"]
	if !ok || password == "" {
		return nil, fmt.Errorf("missing password")
	}

	// Get or set MySQL server address. Defaults to localhost and default port (3306).
	address, ok := conf["address"]
	if !ok {
		address = "127.0.0.1:3306"
	}

	// Get the MySQL database and table details.
	database, ok := conf["database"]
	if !ok {
		database = "vault"
	}
	table, ok := conf["table"]
	if !ok {
		table = "vault"
	}
	dbTable := database + "." + table

	// max_parallel sizes the permit pool, which caps how many backend
	// operations (and hence, presumably, MySQL connections) are in flight.
	var maxParInt int
	if maxParStr, ok := conf["max_parallel"]; ok {
		maxParInt, err = strconv.Atoi(maxParStr)
		if err != nil {
			return nil, errwrap.Wrapf("failed parsing max_parallel parameter: {{err}}", err)
		}
		if logger.IsDebug() {
			logger.Debug("mysql: max_parallel set", "max_parallel", maxParInt)
		}
	}

	dsnParams := url.Values{}
	if tlsCaFile, ok := conf["tls_ca_file"]; ok {
		if err := setupMySQLTLSConfig(tlsCaFile); err != nil {
			return nil, fmt.Errorf("failed to register TLS config: %v", err)
		}
		// Point the driver at the config registered under mysqlTLSKey.
		dsnParams.Add("tls", mysqlTLSKey)
	}

	// Create the MySQL handle. sql.Open may only validate its arguments; the
	// first real connection can happen on the Exec below.
	dsn := username + ":" + password + "@tcp(" + address + ")/?" + dsnParams.Encode()
	db, err := sql.Open("mysql", dsn)
	if err != nil {
		return nil, fmt.Errorf("failed to connect to mysql: %v", err)
	}

	// Create the required database if it doesn't exist.
	if _, err := db.Exec("CREATE DATABASE IF NOT EXISTS " + database); err != nil {
		db.Close() // best-effort cleanup; the Exec error is the one worth reporting
		return nil, fmt.Errorf("failed to create mysql database: %v", err)
	}

	// Create the required table if it doesn't exist.
	createQuery := "CREATE TABLE IF NOT EXISTS " + dbTable +
		" (vault_key varbinary(512), vault_value mediumblob, PRIMARY KEY (vault_key))"
	if _, err := db.Exec(createQuery); err != nil {
		db.Close() // best-effort cleanup; the Exec error is the one worth reporting
		return nil, fmt.Errorf("failed to create mysql table: %v", err)
	}

	// Setup the backend.
	m := &MySQLBackend{
		dbTable:    dbTable,
		client:     db,
		statements: make(map[string]*sql.Stmt),
		logger:     logger,
		permitPool: NewPermitPool(maxParInt),
	}

	// Prepare all the statements required.
	statements := map[string]string{
		"put": "INSERT INTO " + dbTable +
			" VALUES( ?, ? ) ON DUPLICATE KEY UPDATE vault_value=VALUES(vault_value)",
		"get":    "SELECT vault_value FROM " + dbTable + " WHERE vault_key = ?",
		"delete": "DELETE FROM " + dbTable + " WHERE vault_key = ?",
		"list":   "SELECT vault_key FROM " + dbTable + " WHERE vault_key LIKE ?",
	}
	for name, query := range statements {
		if err := m.prepare(name, query); err != nil {
			// The backend is never returned, so release its handle here.
			db.Close() // best-effort cleanup; the prepare error is the one worth reporting
			return nil, err
		}
	}

	return m, nil
}
|
|
|
|
// prepare compiles query on the backend's client and records the resulting
// statement in m.statements under name. On failure nothing is recorded and
// the returned error names the statement that could not be prepared.
func (m *MySQLBackend) prepare(name, query string) error {
	prepared, prepErr := m.client.Prepare(query)
	if prepErr == nil {
		m.statements[name] = prepared
		return nil
	}
	return fmt.Errorf("failed to prepare '%s': %v", name, prepErr)
}
|
|
|
|
// Put writes entry through the prepared "put" statement, inserting a new row
// or replacing the value of the row that already has entry.Key. The call
// holds a permit from the backend's pool for its whole duration, and its
// latency is recorded under the {"mysql", "put"} metric key.
func (m *MySQLBackend) Put(entry *Entry) error {
	start := time.Now()
	defer metrics.MeasureSince([]string{"mysql", "put"}, start)

	pool := m.permitPool
	pool.Acquire()
	defer pool.Release()

	if _, execErr := m.statements["put"].Exec(entry.Key, entry.Value); execErr != nil {
		return execErr
	}
	return nil
}
|
|
|
|
// Get fetches the entry stored under key via the prepared "get" statement.
// It returns (nil, nil) when no row matches, (nil, err) on any other query
// failure, and otherwise an Entry carrying key and the stored value. Latency
// is recorded under the {"mysql", "get"} metric key.
func (m *MySQLBackend) Get(key string) (*Entry, error) {
	defer metrics.MeasureSince([]string{"mysql", "get"}, time.Now())

	m.permitPool.Acquire()
	defer m.permitPool.Release()

	var value []byte
	switch err := m.statements["get"].QueryRow(key).Scan(&value); {
	case err == sql.ErrNoRows:
		// A missing key is not an error for callers.
		return nil, nil
	case err != nil:
		return nil, err
	}

	return &Entry{Key: key, Value: value}, nil
}
|
|
|
|
// Delete removes the row for key via the prepared "delete" statement. Only
// the execution error is reported; the number of affected rows is ignored.
// Latency is recorded under the {"mysql", "delete"} metric key.
func (m *MySQLBackend) Delete(key string) error {
	defer metrics.MeasureSince([]string{"mysql", "delete"}, time.Now())

	m.permitPool.Acquire()
	defer m.permitPool.Release()

	_, err := m.statements["delete"].Exec(key)
	return err
}
|
|
|
|
// List returns, sorted, the keys located directly under prefix. Keys nested
// deeper are collapsed to their first path segment plus a trailing "/"
// (a "folder"), and appendIfMissing is used so each folder is listed once.
//
// The underlying query is `vault_key LIKE prefix + "%"`. Any "%" or "_" inside
// prefix also acts as a LIKE wildcard, so the result set may contain keys that
// do not literally start with prefix; those are discarded here.
// NOTE(review): a "\" in prefix is treated as a LIKE escape by MySQL and can
// cause matching keys to be missed — escaping is not handled here.
func (m *MySQLBackend) List(prefix string) ([]string, error) {
	defer metrics.MeasureSince([]string{"mysql", "list"}, time.Now())

	m.permitPool.Acquire()
	defer m.permitPool.Release()

	// Add the % wildcard to the prefix to do the prefix search.
	likePrefix := prefix + "%"
	rows, err := m.statements["list"].Query(likePrefix)
	if err != nil {
		return nil, fmt.Errorf("failed to execute statement: %v", err)
	}
	// Release the result set and its connection on every path, including the
	// early return on a Scan failure. Deferred after Release, so it runs first.
	defer rows.Close()

	var keys []string
	for rows.Next() {
		var key string
		if err := rows.Scan(&key); err != nil {
			return nil, fmt.Errorf("failed to scan rows: %v", err)
		}

		// Drop LIKE wildcard over-matches (see doc comment).
		if !strings.HasPrefix(key, prefix) {
			continue
		}
		key = key[len(prefix):]

		if i := strings.Index(key, "/"); i == -1 {
			// Add objects only from the current 'folder'.
			keys = append(keys, key)
		} else {
			// Add truncated 'folder' paths.
			keys = appendIfMissing(keys, key[:i+1])
		}
	}
	// rows.Next returns false both at the end and on error; tell them apart so
	// a failed iteration is not reported as a short listing.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("failed to iterate rows: %v", err)
	}

	sort.Strings(keys)
	return keys, nil
}
|
|
|
|
// setupMySQLTLSConfig registers with the MySQL driver, under mysqlTLSKey, a
// tls.Config whose root CA pool is loaded from the PEM file at tlsCaFile. A
// DSN then opts into it through the matching "tls" parameter, e.g.
//
//	foo:bar@tcp(127.0.0.1:3306)/dbname?tls=default
//
// It returns an error if the file cannot be read, if it contains no
// parseable PEM certificate, or if registration fails.
func setupMySQLTLSConfig(tlsCaFile string) error {
	rootCertPool := x509.NewCertPool()

	pem, err := ioutil.ReadFile(tlsCaFile)
	if err != nil {
		return err
	}

	// AppendCertsFromPEM signals failure only through its bool and err is nil
	// here, so an explicit error is required; otherwise an unusable CA file
	// would be accepted with an empty trust pool.
	if ok := rootCertPool.AppendCertsFromPEM(pem); !ok {
		return fmt.Errorf("failed to parse any PEM certificates from %q", tlsCaFile)
	}

	return mysql.RegisterTLSConfig(mysqlTLSKey, &tls.Config{
		RootCAs: rootCertPool,
	})
}
|