Database refactor mssql (#2562)

* WIP on mssql secret backend refactor

* Add RevokeUser test, and use sqlserver driver internally

* Remove debug statements

* Fix code comment
This commit is contained in:
Calvin Leung Huang 2017-04-03 12:59:30 -04:00 committed by Brian Kassouf
parent 210fa77e3c
commit aa15a1d3a9
4 changed files with 464 additions and 2 deletions

View file

@ -10,6 +10,7 @@ import (
"time"
// Import sql drivers
_ "github.com/denisenkom/go-mssqldb"
_ "github.com/go-sql-driver/mysql"
_ "github.com/lib/pq"
"github.com/mitchellh/mapstructure"
@ -73,6 +74,12 @@ func (c *sqlConnectionProducer) connection() (interface{}, error) {
c.db.Close()
}
// For mssql backend, switch to sqlserver instead
dbType := c.config.DatabaseType
if c.config.DatabaseType == "mssql" {
dbType = "sqlserver"
}
// Otherwise, attempt to make connection
conn := c.ConnectionURL
@ -86,7 +93,7 @@ func (c *sqlConnectionProducer) connection() (interface{}, error) {
}
var err error
c.db, err = sql.Open(c.config.DatabaseType, conn)
c.db, err = sql.Open(dbType, conn)
if err != nil {
return nil, err
}

View file

@ -13,6 +13,7 @@ import (
const (
postgreSQLTypeName = "postgres"
mySQLTypeName = "mysql"
msSQLTypeName = "mssql"
cassandraTypeName = "cassandra"
pluginTypeName = "plugin"
)
@ -61,6 +62,20 @@ func BuiltinFactory(conf *DatabaseConfig, sys logical.SystemView, logger log.Log
CredentialsProducer: credsProducer,
}
case msSQLTypeName:
connProducer := &sqlConnectionProducer{}
connProducer.config = conf
credsProducer := &sqlCredentialsProducer{
displayNameLen: 10,
usernameLen: 63,
}
dbType = &MSSQL{
ConnectionProducer: connProducer,
CredentialsProducer: credsProducer,
}
case cassandraTypeName:
connProducer := &cassandraConnectionProducer{}
connProducer.config = conf
@ -163,7 +178,7 @@ func (dc *DatabaseConfig) GetFactory() Factory {
return BuiltinFactory
}
// Statments set in role creation and passed into the database type's functions.
// Statements set in role creation and passed into the database type's functions.
// TODO: Add a way of setting defaults here.
type Statements struct {
CreationStatements string `json:"creation_statments" mapstructure:"creation_statements" structs:"creation_statments"`

View file

@ -0,0 +1,219 @@
package dbs
import (
"database/sql"
"fmt"
"strings"
"github.com/hashicorp/vault/helper/strutil"
)
// MSSQL is an implementation of DatabaseType interface
type MSSQL struct {
ConnectionProducer
CredentialsProducer
}
// Type returns the TypeName for this backend
func (m *MSSQL) Type() string {
return msSQLTypeName
}
func (m *MSSQL) getConnection() (*sql.DB, error) {
db, err := m.connection()
if err != nil {
return nil, err
}
return db.(*sql.DB), nil
}
// CreateUser generates the username/password on the underlying MSSQL secret backend as instructed by
// the CreationStatement provided.
func (m *MSSQL) CreateUser(statements Statements, username, password, expiration string) error {
// Grab the lock
m.Lock()
defer m.Unlock()
// Get the connection
db, err := m.getConnection()
if err != nil {
return err
}
if statements.CreationStatements == "" {
return ErrEmptyCreationStatement
}
// Start a transaction
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
// Execute each query
for _, query := range strutil.ParseArbitraryStringSlice(statements.CreationStatements, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
}
stmt, err := tx.Prepare(queryHelper(query, map[string]string{
"name": username,
"password": password,
}))
if err != nil {
return err
}
defer stmt.Close()
if _, err := stmt.Exec(); err != nil {
return err
}
}
// Commit the transaction
if err := tx.Commit(); err != nil {
return err
}
return nil
}
// RenewUser is not supported on MSSQL, so this is a no-op.
func (m *MSSQL) RenewUser(statements Statements, username, expiration string) error {
// NOOP
return nil
}
// RevokeUser attempts to drop the specified user. It will first attempt to disable login,
// then kill pending connections from that user, and finally drop the user and login from the
// database instance.
func (m *MSSQL) RevokeUser(statements Statements, username string) error {
// Get connection
db, err := m.getConnection()
if err != nil {
return err
}
// First disable server login
disableStmt, err := db.Prepare(fmt.Sprintf("ALTER LOGIN [%s] DISABLE;", username))
if err != nil {
return err
}
defer disableStmt.Close()
if _, err := disableStmt.Exec(); err != nil {
return err
}
// Query for sessions for the login so that we can kill any outstanding
// sessions. There cannot be any active sessions before we drop the logins
// This isn't done in a transaction because even if we fail along the way,
// we want to remove as much access as possible
sessionStmt, err := db.Prepare(fmt.Sprintf(
"SELECT session_id FROM sys.dm_exec_sessions WHERE login_name = '%s';", username))
if err != nil {
return err
}
defer sessionStmt.Close()
sessionRows, err := sessionStmt.Query()
if err != nil {
return err
}
defer sessionRows.Close()
var revokeStmts []string
for sessionRows.Next() {
var sessionID int
err = sessionRows.Scan(&sessionID)
if err != nil {
return err
}
revokeStmts = append(revokeStmts, fmt.Sprintf("KILL %d;", sessionID))
}
// Query for database users using undocumented stored procedure for now since
// it is the easiest way to get this information;
// we need to drop the database users before we can drop the login and the role
// This isn't done in a transaction because even if we fail along the way,
// we want to remove as much access as possible
stmt, err := db.Prepare(fmt.Sprintf("EXEC sp_msloginmappings '%s';", username))
if err != nil {
return err
}
defer stmt.Close()
rows, err := stmt.Query()
if err != nil {
return err
}
defer rows.Close()
for rows.Next() {
var loginName, dbName, qUsername string
var aliasName sql.NullString
err = rows.Scan(&loginName, &dbName, &qUsername, &aliasName)
if err != nil {
return err
}
revokeStmts = append(revokeStmts, fmt.Sprintf(dropUserSQL, dbName, username, username))
}
// we do not stop on error, as we want to remove as
// many permissions as possible right now
var lastStmtError error
for _, query := range revokeStmts {
stmt, err := db.Prepare(query)
if err != nil {
lastStmtError = err
continue
}
defer stmt.Close()
_, err = stmt.Exec()
if err != nil {
lastStmtError = err
}
}
// can't drop if not all database users are dropped
if rows.Err() != nil {
return fmt.Errorf("cound not generate sql statements for all rows: %s", rows.Err())
}
if lastStmtError != nil {
return fmt.Errorf("could not perform all sql statements: %s", lastStmtError)
}
// Drop this login
stmt, err = db.Prepare(fmt.Sprintf(dropLoginSQL, username, username))
if err != nil {
return err
}
defer stmt.Close()
if _, err := stmt.Exec(); err != nil {
return err
}
return nil
}
const dropUserSQL = `
USE [%s]
IF EXISTS
(SELECT name
FROM sys.database_principals
WHERE name = N'%s')
BEGIN
DROP USER [%s]
END
`
const dropLoginSQL = `
IF EXISTS
(SELECT name
FROM master.sys.server_principals
WHERE name = N'%s')
BEGIN
DROP LOGIN [%s]
END
`

View file

@ -0,0 +1,221 @@
package dbs
import (
"database/sql"
"fmt"
"os"
"sync"
"testing"
"time"
_ "github.com/denisenkom/go-mssqldb"
log "github.com/mgutz/logxi/v1"
dockertest "gopkg.in/ory-am/dockertest.v3"
)
var (
testMSQLImagePull sync.Once
)
func prepareMSSQLTestContainer(t *testing.T) (cleanup func(), retURL string) {
if os.Getenv("MSSQL_URL") != "" {
return func() {}, os.Getenv("MSSQL_URL")
}
pool, err := dockertest.NewPool("")
if err != nil {
t.Fatalf("Failed to connect to docker: %s", err)
}
resource, err := pool.Run("microsoft/mssql-server-linux", "latest", []string{"ACCEPT_EULA=Y", "SA_PASSWORD=yourStrong(!)Password"})
if err != nil {
t.Fatalf("Could not start local MSSQL docker container: %s", err)
}
cleanup = func() {
err := pool.Purge(resource)
if err != nil {
t.Fatalf("Failed to cleanup local DynamoDB: %s", err)
}
}
retURL = fmt.Sprintf("sqlserver://sa:yourStrong(!)Password@localhost:%s", resource.GetPort("1433/tcp"))
// exponential backoff-retry, because the mssql container may not be able to accept connections yet
if err = pool.Retry(func() error {
var err error
var db *sql.DB
db, err = sql.Open("mssql", retURL)
if err != nil {
return err
}
return db.Ping()
}); err != nil {
t.Fatalf("Could not connect to MSSQL docker container: %s", err)
}
return
}
func TestMSSQL_Initialize(t *testing.T) {
cleanup, connURL := prepareMSSQLTestContainer(t)
defer cleanup()
conf := &DatabaseConfig{
DatabaseType: msSQLTypeName,
ConnectionDetails: map[string]interface{}{
"connection_url": connURL,
},
}
dbRaw, err := BuiltinFactory(conf, nil, &log.NullLogger{})
if err != nil {
t.Fatalf("err: %s", err)
}
// Deconsturct the middleware chain to get the underlying mssql object
dbTracer := dbRaw.(*databaseTracingMiddleware)
dbMetrics := dbTracer.next.(*databaseMetricsMiddleware)
db := dbMetrics.next.(*MSSQL)
connProducer := db.ConnectionProducer.(*sqlConnectionProducer)
err = dbRaw.Initialize(conf.ConnectionDetails)
if err != nil {
t.Fatalf("err: %s", err)
}
if !connProducer.initalized {
t.Fatal("Database should be initalized")
}
err = dbRaw.Close()
if err != nil {
t.Fatalf("err: %s", err)
}
if connProducer.db != nil {
t.Fatal("db object should be nil")
}
}
func TestMSSQL_CreateUser(t *testing.T) {
cleanup, connURL := prepareMSSQLTestContainer(t)
defer cleanup()
conf := &DatabaseConfig{
DatabaseType: msSQLTypeName,
ConnectionDetails: map[string]interface{}{
"connection_url": connURL,
},
}
db, err := BuiltinFactory(conf, nil, &log.NullLogger{})
if err != nil {
t.Fatalf("err: %s", err)
}
err = db.Initialize(conf.ConnectionDetails)
if err != nil {
t.Fatalf("err: %s", err)
}
username, err := db.GenerateUsername("test")
if err != nil {
t.Fatalf("err: %s", err)
}
password, err := db.GeneratePassword()
if err != nil {
t.Fatalf("err: %s", err)
}
expiration, err := db.GenerateExpiration(time.Minute)
if err != nil {
t.Fatalf("err: %s", err)
}
// Test with no configured Creation Statememt
err = db.CreateUser(Statements{}, username, password, expiration)
if err == nil {
t.Fatal("Expected error when no creation statement is provided")
}
statements := Statements{
CreationStatements: testMSSQLRole,
}
err = db.CreateUser(statements, username, password, expiration)
if err != nil {
t.Fatalf("err: %s", err)
}
username, err = db.GenerateUsername("test")
if err != nil {
t.Fatalf("err: %s", err)
}
password, err = db.GeneratePassword()
if err != nil {
t.Fatalf("err: %s", err)
}
expiration, err = db.GenerateExpiration(time.Minute)
if err != nil {
t.Fatalf("err: %s", err)
}
}
func TestMSSQL_RevokeUser(t *testing.T) {
cleanup, connURL := prepareMSSQLTestContainer(t)
defer cleanup()
conf := &DatabaseConfig{
DatabaseType: msSQLTypeName,
ConnectionDetails: map[string]interface{}{
"connection_url": connURL,
},
}
db, err := BuiltinFactory(conf, nil, &log.NullLogger{})
if err != nil {
t.Fatalf("err: %s", err)
}
err = db.Initialize(conf.ConnectionDetails)
if err != nil {
t.Fatalf("err: %s", err)
}
username, err := db.GenerateUsername("test")
if err != nil {
t.Fatalf("err: %s", err)
}
password, err := db.GeneratePassword()
if err != nil {
t.Fatalf("err: %s", err)
}
expiration, err := db.GenerateExpiration(time.Minute)
if err != nil {
t.Fatalf("err: %s", err)
}
statements := Statements{
CreationStatements: testMSSQLRole,
}
err = db.CreateUser(statements, username, password, expiration)
if err != nil {
t.Fatalf("err: %s", err)
}
// Test default revoke statememts
err = db.RevokeUser(statements, username)
if err != nil {
t.Fatalf("err: %s", err)
}
}
const testMSSQLRole = `
CREATE LOGIN [{{name}}] WITH PASSWORD = '{{password}}';
CREATE USER [{{name}}] FOR LOGIN [{{name}}];
GRANT SELECT, INSERT, UPDATE, DELETE ON SCHEMA::dbo TO [{{name}}];`