package physical

import (
	"crypto/tls"
	"crypto/x509"
	"database/sql"
	"fmt"
	"io/ioutil"
	"net/url"
	"sort"
	"strconv"
	"strings"
	"time"

	log "github.com/mgutz/logxi/v1"

	"github.com/armon/go-metrics"
	mysql "github.com/go-sql-driver/mysql"
	"github.com/hashicorp/errwrap"
	"github.com/hashicorp/vault/helper/strutil"
)

// mysqlTLSKey is the name under which the custom TLS configuration is
// registered with the MySQL driver and referenced from the DSN "tls" parameter.
// It must be an unreserved value; reserved values are "true", "false" and
// "skip-verify".
const mysqlTLSKey = "default"

// MySQLBackend is a physical backend that stores data
// within a MySQL database.
type MySQLBackend struct {
	dbTable    string               // fully qualified "<database>.<table>" name
	client     *sql.DB              // database handle used to prepare statements
	statements map[string]*sql.Stmt // prepared statements keyed by operation name
	logger     log.Logger
	permitPool *PermitPool // bounds the number of concurrent operations
}

// newMySQLBackend constructs a MySQL backend from the given configuration.
//
// Recognised conf keys:
//   - "username":     required, must be non-empty.
//   - "password":     required, must be non-empty.
//   - "address":      server address; defaults to "127.0.0.1:3306".
//   - "database":     database name; defaults to "vault".
//   - "table":        table name; defaults to "vault".
//   - "max_parallel": integer limit on concurrent operations; defaults to
//     DefaultParallelOperations.
//   - "tls_ca_file":  optional path to a CA certificate file; when present a TLS
//     config is registered and the connection is made with tls=default.
//
// The database and table are created if they do not already exist, and the
// statements used by Put, Get, Delete and List are prepared up front. If any
// step after opening the handle fails, the handle is closed before returning.
func newMySQLBackend(conf map[string]string, logger log.Logger) (Backend, error) {
	var err error

	// Get the MySQL credentials to perform read/write operations.
	username, ok := conf["username"]
	if !ok || username == "" {
		return nil, fmt.Errorf("missing username")
	}
	password, ok := conf["password"]
	if !ok || password == "" {
		return nil, fmt.Errorf("missing password")
	}

	// Get or set MySQL server address. Defaults to localhost and default port(3306)
	address, ok := conf["address"]
	if !ok {
		address = "127.0.0.1:3306"
	}

	// Get the MySQL database and table details.
	database, ok := conf["database"]
	if !ok {
		database = "vault"
	}
	table, ok := conf["table"]
	if !ok {
		table = "vault"
	}
	dbTable := database + "." + table

	maxParStr, ok := conf["max_parallel"]
	var maxParInt int
	if ok {
		maxParInt, err = strconv.Atoi(maxParStr)
		if err != nil {
			return nil, errwrap.Wrapf("failed parsing max_parallel parameter: {{err}}", err)
		}
		if logger.IsDebug() {
			logger.Debug("mysql: max_parallel set", "max_parallel", maxParInt)
		}
	} else {
		maxParInt = DefaultParallelOperations
	}

	dsnParams := url.Values{}
	tlsCaFile, ok := conf["tls_ca_file"]
	if ok {
		if err := setupMySQLTLSConfig(tlsCaFile); err != nil {
			return nil, fmt.Errorf("failed register TLS config: %v", err)
		}

		dsnParams.Add("tls", mysqlTLSKey)
	}

	// Create MySQL handle for the database.
	dsn := username + ":" + password + "@tcp(" + address + ")/?" + dsnParams.Encode()
	db, err := sql.Open("mysql", dsn)
	if err != nil {
		return nil, fmt.Errorf("failed to connect to mysql: %v", err)
	}

	// From here on the handle is owned by this function until it is handed to
	// the backend; close it on every failure path so it is not leaked.
	ready := false
	defer func() {
		if !ready {
			// The setup error is what the caller needs to see; a failure to
			// close an unusable handle adds nothing, so it is ignored.
			_ = db.Close()
		}
	}()

	db.SetMaxOpenConns(maxParInt)

	// Create the required database if it doesn't exist.
	if _, err := db.Exec("CREATE DATABASE IF NOT EXISTS " + database); err != nil {
		return nil, fmt.Errorf("failed to create mysql database: %v", err)
	}

	// Create the required table if it doesn't exist.
	createQuery := "CREATE TABLE IF NOT EXISTS " + dbTable +
		" (vault_key varbinary(512), vault_value mediumblob, PRIMARY KEY (vault_key))"
	if _, err := db.Exec(createQuery); err != nil {
		return nil, fmt.Errorf("failed to create mysql table: %v", err)
	}

	// Setup the backend.
	m := &MySQLBackend{
		dbTable:    dbTable,
		client:     db,
		statements: make(map[string]*sql.Stmt),
		logger:     logger,
		permitPool: NewPermitPool(maxParInt),
	}

	// Prepare all the statements required
	statements := map[string]string{
		"put": "INSERT INTO " + dbTable +
			" VALUES( ?, ? ) ON DUPLICATE KEY UPDATE vault_value=VALUES(vault_value)",
		"get":    "SELECT vault_value FROM " + dbTable + " WHERE vault_key = ?",
		"delete": "DELETE FROM " + dbTable + " WHERE vault_key = ?",
		"list":   "SELECT vault_key FROM " + dbTable + " WHERE vault_key LIKE ?",
	}
	for name, query := range statements {
		if err := m.prepare(name, query); err != nil {
			return nil, err
		}
	}

	ready = true
	return m, nil
}

// prepare is a helper to prepare a query for future execution. The prepared
// statement is stored in m.statements under name.
func (m *MySQLBackend) prepare(name, query string) error {
	stmt, err := m.client.Prepare(query)
	if err != nil {
		return fmt.Errorf("failed to prepare '%s': %v", name, err)
	}
	m.statements[name] = stmt
	return nil
}

// Put is used to insert or update an entry.
func (m *MySQLBackend) Put(entry *Entry) error {
	defer metrics.MeasureSince([]string{"mysql", "put"}, time.Now())

	m.permitPool.Acquire()
	defer m.permitPool.Release()

	_, err := m.statements["put"].Exec(entry.Key, entry.Value)
	return err
}

// Get is used to fetch an entry. It returns (nil, nil) when no row exists
// for key.
func (m *MySQLBackend) Get(key string) (*Entry, error) {
	defer metrics.MeasureSince([]string{"mysql", "get"}, time.Now())

	m.permitPool.Acquire()
	defer m.permitPool.Release()

	var result []byte
	err := m.statements["get"].QueryRow(key).Scan(&result)
	if err == sql.ErrNoRows {
		// A missing key is not an error for callers.
		return nil, nil
	}
	if err != nil {
		return nil, err
	}

	ent := &Entry{
		Key:   key,
		Value: result,
	}
	return ent, nil
}

// Delete is used to permanently delete an entry
func (m *MySQLBackend) Delete(key string) error {
	defer metrics.MeasureSince([]string{"mysql", "delete"}, time.Now())

	m.permitPool.Acquire()
	defer m.permitPool.Release()

	_, err := m.statements["delete"].Exec(key)
	return err
}

// List is used to list all the keys under a given
// prefix, up to the next prefix.
func (m *MySQLBackend) List(prefix string) ([]string, error) { defer metrics.MeasureSince([]string{"mysql", "list"}, time.Now()) m.permitPool.Acquire() defer m.permitPool.Release() // Add the % wildcard to the prefix to do the prefix search likePrefix := prefix + "%" rows, err := m.statements["list"].Query(likePrefix) if err != nil { return nil, fmt.Errorf("failed to execute statement: %v", err) } var keys []string for rows.Next() { var key string err = rows.Scan(&key) if err != nil { return nil, fmt.Errorf("failed to scan rows: %v", err) } key = strings.TrimPrefix(key, prefix) if i := strings.Index(key, "/"); i == -1 { // Add objects only from the current 'folder' keys = append(keys, key) } else if i != -1 { // Add truncated 'folder' paths keys = strutil.AppendIfMissing(keys, string(key[:i+1])) } } sort.Strings(keys) return keys, nil } // Establish a TLS connection with a given CA certificate // Register a tsl.Config associted with the same key as the dns param from sql.Open // foo:bar@tcp(127.0.0.1:3306)/dbname?tls=default func setupMySQLTLSConfig(tlsCaFile string) error { rootCertPool := x509.NewCertPool() pem, err := ioutil.ReadFile(tlsCaFile) if err != nil { return err } if ok := rootCertPool.AppendCertsFromPEM(pem); !ok { return err } err = mysql.RegisterTLSConfig(mysqlTLSKey, &tls.Config{ RootCAs: rootCertPool, }) if err != nil { return err } return nil }