From aa15a1d3a9a62e926be7c8fd74b2e13bccdd2a06 Mon Sep 17 00:00:00 2001 From: Calvin Leung Huang Date: Mon, 3 Apr 2017 12:59:30 -0400 Subject: [PATCH] Database refactor mssql (#2562) * WIP on mssql secret backend refactor * Add RevokeUser test, and use sqlserver driver internally * Remove debug statements * Fix code comment --- .../database/dbs/connectionproducer.go | 9 +- builtin/logical/database/dbs/db.go | 17 +- builtin/logical/database/dbs/mssql.go | 219 +++++++++++++++++ builtin/logical/database/dbs/mssql_test.go | 221 ++++++++++++++++++ 4 files changed, 464 insertions(+), 2 deletions(-) create mode 100644 builtin/logical/database/dbs/mssql.go create mode 100644 builtin/logical/database/dbs/mssql_test.go diff --git a/builtin/logical/database/dbs/connectionproducer.go b/builtin/logical/database/dbs/connectionproducer.go index ca9e7250e..b5dc93951 100644 --- a/builtin/logical/database/dbs/connectionproducer.go +++ b/builtin/logical/database/dbs/connectionproducer.go @@ -10,6 +10,7 @@ import ( "time" // Import sql drivers + _ "github.com/denisenkom/go-mssqldb" _ "github.com/go-sql-driver/mysql" _ "github.com/lib/pq" "github.com/mitchellh/mapstructure" @@ -73,6 +74,12 @@ func (c *sqlConnectionProducer) connection() (interface{}, error) { c.db.Close() } + // For mssql backend, switch to sqlserver instead + dbType := c.config.DatabaseType + if c.config.DatabaseType == "mssql" { + dbType = "sqlserver" + } + // Otherwise, attempt to make connection conn := c.ConnectionURL @@ -86,7 +93,7 @@ func (c *sqlConnectionProducer) connection() (interface{}, error) { } var err error - c.db, err = sql.Open(c.config.DatabaseType, conn) + c.db, err = sql.Open(dbType, conn) if err != nil { return nil, err } diff --git a/builtin/logical/database/dbs/db.go b/builtin/logical/database/dbs/db.go index 2637a73d1..cf8f8ee7f 100644 --- a/builtin/logical/database/dbs/db.go +++ b/builtin/logical/database/dbs/db.go @@ -13,6 +13,7 @@ import ( const ( postgreSQLTypeName = "postgres" mySQLTypeName = "mysql" + msSQLTypeName = "mssql" cassandraTypeName = "cassandra" pluginTypeName = "plugin" ) @@ -61,6 +62,20 @@ func BuiltinFactory(conf *DatabaseConfig, sys logical.SystemView, logger log.Log CredentialsProducer: credsProducer, } + case msSQLTypeName: + connProducer := &sqlConnectionProducer{} + connProducer.config = conf + + credsProducer := &sqlCredentialsProducer{ + displayNameLen: 10, + usernameLen: 63, + } + + dbType = &MSSQL{ + ConnectionProducer: connProducer, + CredentialsProducer: credsProducer, + } + case cassandraTypeName: connProducer := &cassandraConnectionProducer{} connProducer.config = conf @@ -163,7 +178,7 @@ func (dc *DatabaseConfig) GetFactory() Factory { return BuiltinFactory } -// Statments set in role creation and passed into the database type's functions. +// Statements set in role creation and passed into the database type's functions. // TODO: Add a way of setting defaults here. type Statements struct { CreationStatements string `json:"creation_statments" mapstructure:"creation_statements" structs:"creation_statments"` diff --git a/builtin/logical/database/dbs/mssql.go b/builtin/logical/database/dbs/mssql.go new file mode 100644 index 000000000..b7439b0a8 --- /dev/null +++ b/builtin/logical/database/dbs/mssql.go @@ -0,0 +1,219 @@ +package dbs + +import ( + "database/sql" + "fmt" + "strings" + + "github.com/hashicorp/vault/helper/strutil" +) + +// MSSQL is an implementation of DatabaseType interface +type MSSQL struct { + ConnectionProducer + CredentialsProducer +} + +// Type returns the TypeName for this backend +func (m *MSSQL) Type() string { + return msSQLTypeName +} + +func (m *MSSQL) getConnection() (*sql.DB, error) { + db, err := m.connection() + if err != nil { + return nil, err + } + + return db.(*sql.DB), nil +} + +// CreateUser generates the username/password on the underlying MSSQL secret backend as instructed by +// the CreationStatement provided. +func (m *MSSQL) CreateUser(statements Statements, username, password, expiration string) error { + // Grab the lock + m.Lock() + defer m.Unlock() + + // Get the connection + db, err := m.getConnection() + if err != nil { + return err + } + + if statements.CreationStatements == "" { + return ErrEmptyCreationStatement + } + + // Start a transaction + tx, err := db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + + // Execute each query + for _, query := range strutil.ParseArbitraryStringSlice(statements.CreationStatements, ";") { + query = strings.TrimSpace(query) + if len(query) == 0 { + continue + } + + stmt, err := tx.Prepare(queryHelper(query, map[string]string{ + "name": username, + "password": password, + })) + if err != nil { + return err + } + defer stmt.Close() + if _, err := stmt.Exec(); err != nil { + return err + } + } + + // Commit the transaction + if err := tx.Commit(); err != nil { + return err + } + + return nil +} + +// RenewUser is not supported on MSSQL, so this is a no-op. +func (m *MSSQL) RenewUser(statements Statements, username, expiration string) error { + // NOOP + return nil +} + +// RevokeUser attempts to drop the specified user. It will first attempt to disable login, +// then kill pending connections from that user, and finally drop the user and login from the +// database instance. +func (m *MSSQL) RevokeUser(statements Statements, username string) error { + // Get connection + db, err := m.getConnection() + if err != nil { + return err + } + + // First disable server login + disableStmt, err := db.Prepare(fmt.Sprintf("ALTER LOGIN [%s] DISABLE;", username)) + if err != nil { + return err + } + defer disableStmt.Close() + if _, err := disableStmt.Exec(); err != nil { + return err + } + + // Query for sessions for the login so that we can kill any outstanding + // sessions. There cannot be any active sessions before we drop the logins + // This isn't done in a transaction because even if we fail along the way, + // we want to remove as much access as possible + sessionStmt, err := db.Prepare(fmt.Sprintf( + "SELECT session_id FROM sys.dm_exec_sessions WHERE login_name = '%s';", username)) + if err != nil { + return err + } + defer sessionStmt.Close() + + sessionRows, err := sessionStmt.Query() + if err != nil { + return err + } + defer sessionRows.Close() + + var revokeStmts []string + for sessionRows.Next() { + var sessionID int + err = sessionRows.Scan(&sessionID) + if err != nil { + return err + } + revokeStmts = append(revokeStmts, fmt.Sprintf("KILL %d;", sessionID)) + } + + // Query for database users using undocumented stored procedure for now since + // it is the easiest way to get this information; + // we need to drop the database users before we can drop the login and the role + // This isn't done in a transaction because even if we fail along the way, + // we want to remove as much access as possible + stmt, err := db.Prepare(fmt.Sprintf("EXEC sp_msloginmappings '%s';", username)) + if err != nil { + return err + } + defer stmt.Close() + + rows, err := stmt.Query() + if err != nil { + return err + } + defer rows.Close() + + for rows.Next() { + var loginName, dbName, qUsername string + var aliasName sql.NullString + err = rows.Scan(&loginName, &dbName, &qUsername, &aliasName) + if err != nil { + return err + } + revokeStmts = append(revokeStmts, fmt.Sprintf(dropUserSQL, dbName, username, username)) + } + + // we do not stop on error, as we want to remove as + // many permissions as possible right now + var lastStmtError error + for _, query := range revokeStmts { + stmt, err := db.Prepare(query) + if err != nil { + lastStmtError = err + continue + } + defer stmt.Close() + _, err = stmt.Exec() + if err != nil { + lastStmtError = err + } + } + + // can't drop if not all database users are dropped + if rows.Err() != nil { + return fmt.Errorf("cound not generate sql statements for all rows: %s", rows.Err()) + } + if lastStmtError != nil { + return fmt.Errorf("could not perform all sql statements: %s", lastStmtError) + } + + // Drop this login + stmt, err = db.Prepare(fmt.Sprintf(dropLoginSQL, username, username)) + if err != nil { + return err + } + defer stmt.Close() + if _, err := stmt.Exec(); err != nil { + return err + } + + return nil +} + +const dropUserSQL = ` +USE [%s] +IF EXISTS + (SELECT name + FROM sys.database_principals + WHERE name = N'%s') +BEGIN + DROP USER [%s] +END +` + +const dropLoginSQL = ` +IF EXISTS + (SELECT name + FROM master.sys.server_principals + WHERE name = N'%s') +BEGIN + DROP LOGIN [%s] +END +` diff --git a/builtin/logical/database/dbs/mssql_test.go b/builtin/logical/database/dbs/mssql_test.go new file mode 100644 index 000000000..f2169299f --- /dev/null +++ b/builtin/logical/database/dbs/mssql_test.go @@ -0,0 +1,221 @@ +package dbs + +import ( + "database/sql" + "fmt" + "os" + "sync" + "testing" + "time" + + _ "github.com/denisenkom/go-mssqldb" + log "github.com/mgutz/logxi/v1" + dockertest "gopkg.in/ory-am/dockertest.v3" +) + +var ( + testMSQLImagePull sync.Once +) + +func prepareMSSQLTestContainer(t *testing.T) (cleanup func(), retURL string) { + if os.Getenv("MSSQL_URL") != "" { + return func() {}, os.Getenv("MSSQL_URL") + } + + pool, err := dockertest.NewPool("") + if err != nil { + t.Fatalf("Failed to connect to docker: %s", err) + } + + resource, err := pool.Run("microsoft/mssql-server-linux", "latest", []string{"ACCEPT_EULA=Y", "SA_PASSWORD=yourStrong(!)Password"}) + if err != nil { + t.Fatalf("Could not start local MSSQL docker container: %s", err) + } + + cleanup = func() { + err := pool.Purge(resource) + if err != nil { + t.Fatalf("Failed to cleanup local DynamoDB: %s", err) + } + } + + retURL = fmt.Sprintf("sqlserver://sa:yourStrong(!)Password@localhost:%s", resource.GetPort("1433/tcp")) + + // exponential backoff-retry, because the mssql container may not be able to accept connections yet + if err = pool.Retry(func() error { + var err error + var db *sql.DB + db, err = sql.Open("mssql", retURL) + if err != nil { + return err + } + return db.Ping() + }); err != nil { + t.Fatalf("Could not connect to MSSQL docker container: %s", err) + } + + return +} + +func TestMSSQL_Initialize(t *testing.T) { + cleanup, connURL := prepareMSSQLTestContainer(t) + defer cleanup() + + conf := &DatabaseConfig{ + DatabaseType: msSQLTypeName, + ConnectionDetails: map[string]interface{}{ + "connection_url": connURL, + }, + } + + dbRaw, err := BuiltinFactory(conf, nil, &log.NullLogger{}) + if err != nil { + t.Fatalf("err: %s", err) + } + + // Deconsturct the middleware chain to get the underlying mssql object + dbTracer := dbRaw.(*databaseTracingMiddleware) + dbMetrics := dbTracer.next.(*databaseMetricsMiddleware) + db := dbMetrics.next.(*MSSQL) + connProducer := db.ConnectionProducer.(*sqlConnectionProducer) + + err = dbRaw.Initialize(conf.ConnectionDetails) + if err != nil { + t.Fatalf("err: %s", err) + } + + if !connProducer.initalized { + t.Fatal("Database should be initalized") + } + + err = dbRaw.Close() + if err != nil { + t.Fatalf("err: %s", err) + } + + if connProducer.db != nil { + t.Fatal("db object should be nil") + } +} + +func TestMSSQL_CreateUser(t *testing.T) { + cleanup, connURL := prepareMSSQLTestContainer(t) + defer cleanup() + + conf := &DatabaseConfig{ + DatabaseType: msSQLTypeName, + ConnectionDetails: map[string]interface{}{ + "connection_url": connURL, + }, + } + + db, err := BuiltinFactory(conf, nil, &log.NullLogger{}) + if err != nil { + t.Fatalf("err: %s", err) + } + err = db.Initialize(conf.ConnectionDetails) + if err != nil { + t.Fatalf("err: %s", err) + } + + username, err := db.GenerateUsername("test") + if err != nil { + t.Fatalf("err: %s", err) + } + + password, err := db.GeneratePassword() + if err != nil { + t.Fatalf("err: %s", err) + } + + expiration, err := db.GenerateExpiration(time.Minute) + if err != nil { + t.Fatalf("err: %s", err) + } + + // Test with no configured Creation Statememt + err = db.CreateUser(Statements{}, username, password, expiration) + if err == nil { + t.Fatal("Expected error when no creation statement is provided") + } + + statements := Statements{ + CreationStatements: testMSSQLRole, + } + + err = db.CreateUser(statements, username, password, expiration) + if err != nil { + t.Fatalf("err: %s", err) + } + + username, err = db.GenerateUsername("test") + if err != nil { + t.Fatalf("err: %s", err) + } + + password, err = db.GeneratePassword() + if err != nil { + t.Fatalf("err: %s", err) + } + + expiration, err = db.GenerateExpiration(time.Minute) + if err != nil { + t.Fatalf("err: %s", err) + } +} + +func TestMSSQL_RevokeUser(t *testing.T) { + cleanup, connURL := prepareMSSQLTestContainer(t) + defer cleanup() + + conf := &DatabaseConfig{ + DatabaseType: msSQLTypeName, + ConnectionDetails: map[string]interface{}{ + "connection_url": connURL, + }, + } + + db, err := BuiltinFactory(conf, nil, &log.NullLogger{}) + if err != nil { + t.Fatalf("err: %s", err) + } + err = db.Initialize(conf.ConnectionDetails) + if err != nil { + t.Fatalf("err: %s", err) + } + + username, err := db.GenerateUsername("test") + if err != nil { + t.Fatalf("err: %s", err) + } + + password, err := db.GeneratePassword() + if err != nil { + t.Fatalf("err: %s", err) + } + + expiration, err := db.GenerateExpiration(time.Minute) + if err != nil { + t.Fatalf("err: %s", err) + } + + statements := Statements{ + CreationStatements: testMSSQLRole, + } + + err = db.CreateUser(statements, username, password, expiration) + if err != nil { + t.Fatalf("err: %s", err) + } + + // Test default revoke statememts + err = db.RevokeUser(statements, username) + if err != nil { + t.Fatalf("err: %s", err) + } +} + +const testMSSQLRole = ` +CREATE LOGIN [{{name}}] WITH PASSWORD = '{{password}}'; +CREATE USER [{{name}}] FOR LOGIN [{{name}}]; +GRANT SELECT, INSERT, UPDATE, DELETE ON SCHEMA::dbo TO [{{name}}];`