Add mysql database type

This commit is contained in:
Brian Kassouf 2017-01-04 10:18:10 -08:00 committed by Brian Kassouf
parent 2ec5ab5616
commit 46aa7142c1
4 changed files with 145 additions and 8 deletions

View File

@ -8,6 +8,7 @@ import (
"sync" "sync"
"time" "time"
_ "github.com/go-sql-driver/mysql"
"github.com/gocql/gocql" "github.com/gocql/gocql"
"github.com/hashicorp/vault/helper/certutil" "github.com/hashicorp/vault/helper/certutil"
"github.com/hashicorp/vault/helper/tlsutil" "github.com/hashicorp/vault/helper/tlsutil"

View File

@ -20,24 +20,24 @@ type sqlCredentialsProducer struct {
usernameLen int usernameLen int
} }
func (scg *sqlCredentialsProducer) GenerateUsername(displayName string) (string, error) { func (scp *sqlCredentialsProducer) GenerateUsername(displayName string) (string, error) {
// Generate the username, password and expiration. PG limits user to 63 characters // Generate the username, password and expiration. PG limits user to 63 characters
if scg.displayNameLen > 0 && len(displayName) > scg.displayNameLen { if scp.displayNameLen > 0 && len(displayName) > scp.displayNameLen {
displayName = displayName[:scg.displayNameLen] displayName = displayName[:scp.displayNameLen]
} }
userUUID, err := uuid.GenerateUUID() userUUID, err := uuid.GenerateUUID()
if err != nil { if err != nil {
return "", err return "", err
} }
username := fmt.Sprintf("%s-%s", displayName, userUUID) username := fmt.Sprintf("%s-%s", displayName, userUUID)
if scg.usernameLen > 0 && len(username) > scg.usernameLen { if scp.usernameLen > 0 && len(username) > scp.usernameLen {
username = username[:scg.usernameLen] username = username[:scp.usernameLen]
} }
return username, nil return username, nil
} }
func (scg *sqlCredentialsProducer) GeneratePassword() (string, error) { func (scp *sqlCredentialsProducer) GeneratePassword() (string, error) {
password, err := uuid.GenerateUUID() password, err := uuid.GenerateUUID()
if err != nil { if err != nil {
return "", err return "", err
@ -46,7 +46,7 @@ func (scg *sqlCredentialsProducer) GeneratePassword() (string, error) {
return password, nil return password, nil
} }
func (scg *sqlCredentialsProducer) GenerateExpiration(ttl time.Duration) string { func (scp *sqlCredentialsProducer) GenerateExpiration(ttl time.Duration) string {
return time.Now(). return time.Now().
Add(ttl). Add(ttl).
Format("2006-01-02 15:04:05-0700") Format("2006-01-02 15:04:05-0700")

View File

@ -0,0 +1,136 @@
package dbs
import (
"database/sql"
"strings"
"sync"
"github.com/hashicorp/vault/helper/strutil"
)
const defaultRevocationSQL = `
REVOKE ALL PRIVILEGES, GRANT OPTION FROM '{{name}}'@'%';
DROP USER '{{name}}'@'%'
`
type MySQL struct {
db *sql.DB
ConnectionProducer
CredentialsProducer
sync.RWMutex
}
func (p *MySQL) Type() string {
return postgreSQLTypeName
}
func (p *MySQL) getConnection() (*sql.DB, error) {
db, err := p.Connection()
if err != nil {
return nil, err
}
return db.(*sql.DB), nil
}
func (p *MySQL) CreateUser(createStmt, rollbackStmt, username, password, expiration string) error {
// Get the connection
db, err := p.getConnection()
if err != nil {
return err
}
// TODO: This is racey
// Grab a read lock
p.RLock()
defer p.RUnlock()
// Start a transaction
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
// Execute each query
for _, query := range strutil.ParseArbitraryStringSlice(createStmt, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
}
stmt, err := tx.Prepare(queryHelper(query, map[string]string{
"name": username,
"password": password,
}))
if err != nil {
return err
}
defer stmt.Close()
if _, err := stmt.Exec(); err != nil {
return err
}
}
// Commit the transaction
if err := tx.Commit(); err != nil {
return err
}
return nil
}
// NOOP
func (p *MySQL) RenewUser(username, expiration string) error {
return nil
}
func (p *MySQL) RevokeUser(username, revocationStmt string) error {
// Get the connection
db, err := p.getConnection()
if err != nil {
return err
}
// Grab the read lock
p.RLock()
defer p.RUnlock()
// Use a default SQL statement for revocation if one cannot be fetched from the role
if revocationStmt == "" {
revocationStmt = defaultRevocationSQL
}
// Start a transaction
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
for _, query := range strutil.ParseArbitraryStringSlice(revocationStmt, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
}
// This is not a prepared statement because not all commands are supported
// 1295: This command is not supported in the prepared statement protocol yet
// Reference https://mariadb.com/kb/en/mariadb/prepare-statement/
query = strings.Replace(query, "{{name}}", username, -1)
_, err = tx.Exec(query)
if err != nil {
return err
}
}
// Commit the transaction
if err := tx.Commit(); err != nil {
return err
}
return nil
}

View File

@ -124,7 +124,7 @@ func (b *databaseBackend) pathConnectionWrite(req *logical.Request, data *framew
// Don't allow the connection type to change // Don't allow the connection type to change
if b.connections[name].Type() != connType { if b.connections[name].Type() != connType {
return logical.ErrorResponse("can not change type of existing connection"), nil return logical.ErrorResponse("Can not change type of existing connection."), nil
} }
db = b.connections[name] db = b.connections[name]