diff --git a/builtin/logical/database/dbs/connectionproducer.go b/builtin/logical/database/dbs/connectionproducer.go index adecfd55a..dc8f6c82c 100644 --- a/builtin/logical/database/dbs/connectionproducer.go +++ b/builtin/logical/database/dbs/connectionproducer.go @@ -8,6 +8,7 @@ import ( "sync" "time" + _ "github.com/go-sql-driver/mysql" "github.com/gocql/gocql" "github.com/hashicorp/vault/helper/certutil" "github.com/hashicorp/vault/helper/tlsutil" diff --git a/builtin/logical/database/dbs/credentialsproducer.go b/builtin/logical/database/dbs/credentialsproducer.go index 20210c2e8..94fce6275 100644 --- a/builtin/logical/database/dbs/credentialsproducer.go +++ b/builtin/logical/database/dbs/credentialsproducer.go @@ -20,24 +20,24 @@ type sqlCredentialsProducer struct { usernameLen int } -func (scg *sqlCredentialsProducer) GenerateUsername(displayName string) (string, error) { +func (scp *sqlCredentialsProducer) GenerateUsername(displayName string) (string, error) { // Generate the username, password and expiration. PG limits user to 63 characters - if scg.displayNameLen > 0 && len(displayName) > scg.displayNameLen { - displayName = displayName[:scg.displayNameLen] + if scp.displayNameLen > 0 && len(displayName) > scp.displayNameLen { + displayName = displayName[:scp.displayNameLen] } userUUID, err := uuid.GenerateUUID() if err != nil { return "", err } username := fmt.Sprintf("%s-%s", displayName, userUUID) - if scg.usernameLen > 0 && len(username) > scg.usernameLen { - username = username[:scg.usernameLen] + if scp.usernameLen > 0 && len(username) > scp.usernameLen { + username = username[:scp.usernameLen] } return username, nil } -func (scg *sqlCredentialsProducer) GeneratePassword() (string, error) { +func (scp *sqlCredentialsProducer) GeneratePassword() (string, error) { password, err := uuid.GenerateUUID() if err != nil { return "", err @@ -46,7 +46,7 @@ func (scg *sqlCredentialsProducer) GeneratePassword() (string, error) { return password, nil } -func (scg *sqlCredentialsProducer) GenerateExpiration(ttl time.Duration) string { +func (scp *sqlCredentialsProducer) GenerateExpiration(ttl time.Duration) string { return time.Now(). Add(ttl). Format("2006-01-02 15:04:05-0700") diff --git a/builtin/logical/database/dbs/mysql.go b/builtin/logical/database/dbs/mysql.go new file mode 100644 index 000000000..314d4c929 --- /dev/null +++ b/builtin/logical/database/dbs/mysql.go @@ -0,0 +1,136 @@ +package dbs + +import ( + "database/sql" + "strings" + "sync" + + "github.com/hashicorp/vault/helper/strutil" +) + +const defaultRevocationSQL = ` + REVOKE ALL PRIVILEGES, GRANT OPTION FROM '{{name}}'@'%'; + DROP USER '{{name}}'@'%' +` + +type MySQL struct { + db *sql.DB + + ConnectionProducer + CredentialsProducer + sync.RWMutex +} + +func (p *MySQL) Type() string { + return postgreSQLTypeName +} + +func (p *MySQL) getConnection() (*sql.DB, error) { + db, err := p.Connection() + if err != nil { + return nil, err + } + + return db.(*sql.DB), nil +} + +func (p *MySQL) CreateUser(createStmt, rollbackStmt, username, password, expiration string) error { + // Get the connection + db, err := p.getConnection() + if err != nil { + return err + } + + // TODO: This is racey + // Grab a read lock + p.RLock() + defer p.RUnlock() + + // Start a transaction + tx, err := db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + + // Execute each query + for _, query := range strutil.ParseArbitraryStringSlice(createStmt, ";") { + query = strings.TrimSpace(query) + if len(query) == 0 { + continue + } + + stmt, err := tx.Prepare(queryHelper(query, map[string]string{ + "name": username, + "password": password, + })) + if err != nil { + return err + } + defer stmt.Close() + if _, err := stmt.Exec(); err != nil { + return err + } + } + + // Commit the transaction + if err := tx.Commit(); err != nil { + return err + } + + return nil +} + +// NOOP +func (p *MySQL) RenewUser(username, expiration string) error { + return nil +} + +func (p *MySQL) RevokeUser(username, revocationStmt string) error { + // Get the connection + db, err := p.getConnection() + if err != nil { + return err + } + + // Grab the read lock + p.RLock() + defer p.RUnlock() + + // Use a default SQL statement for revocation if one cannot be fetched from the role + + if revocationStmt == "" { + revocationStmt = defaultRevocationSQL + } + + // Start a transaction + tx, err := db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + + for _, query := range strutil.ParseArbitraryStringSlice(revocationStmt, ";") { + query = strings.TrimSpace(query) + if len(query) == 0 { + continue + } + + // This is not a prepared statement because not all commands are supported + // 1295: This command is not supported in the prepared statement protocol yet + // Reference https://mariadb.com/kb/en/mariadb/prepare-statement/ + query = strings.Replace(query, "{{name}}", username, -1) + _, err = tx.Exec(query) + if err != nil { + return err + } + + } + + // Commit the transaction + if err := tx.Commit(); err != nil { + return err + } + + return nil +} diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index d4a969a69..90dfea4cd 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -124,7 +124,7 @@ func (b *databaseBackend) pathConnectionWrite(req *logical.Request, data *framew // Don't allow the connection type to change if b.connections[name].Type() != connType { - return logical.ErrorResponse("can not change type of existing connection"), nil + return logical.ErrorResponse("Can not change type of existing connection."), nil } db = b.connections[name]