Database refactor mssql (#2562)
* WIP on mssql secret backend refactor * Add RevokeUser test, and use sqlserver driver internally * Remove debug statements * Fix code comment
This commit is contained in:
parent
210fa77e3c
commit
aa15a1d3a9
|
@ -10,6 +10,7 @@ import (
|
|||
"time"
|
||||
|
||||
// Import sql drivers
|
||||
_ "github.com/denisenkom/go-mssqldb"
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
_ "github.com/lib/pq"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
|
@ -73,6 +74,12 @@ func (c *sqlConnectionProducer) connection() (interface{}, error) {
|
|||
c.db.Close()
|
||||
}
|
||||
|
||||
// For mssql backend, switch to sqlserver instead
|
||||
dbType := c.config.DatabaseType
|
||||
if c.config.DatabaseType == "mssql" {
|
||||
dbType = "sqlserver"
|
||||
}
|
||||
|
||||
// Otherwise, attempt to make connection
|
||||
conn := c.ConnectionURL
|
||||
|
||||
|
@ -86,7 +93,7 @@ func (c *sqlConnectionProducer) connection() (interface{}, error) {
|
|||
}
|
||||
|
||||
var err error
|
||||
c.db, err = sql.Open(c.config.DatabaseType, conn)
|
||||
c.db, err = sql.Open(dbType, conn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -13,6 +13,7 @@ import (
|
|||
const (
|
||||
postgreSQLTypeName = "postgres"
|
||||
mySQLTypeName = "mysql"
|
||||
msSQLTypeName = "mssql"
|
||||
cassandraTypeName = "cassandra"
|
||||
pluginTypeName = "plugin"
|
||||
)
|
||||
|
@ -61,6 +62,20 @@ func BuiltinFactory(conf *DatabaseConfig, sys logical.SystemView, logger log.Log
|
|||
CredentialsProducer: credsProducer,
|
||||
}
|
||||
|
||||
case msSQLTypeName:
|
||||
connProducer := &sqlConnectionProducer{}
|
||||
connProducer.config = conf
|
||||
|
||||
credsProducer := &sqlCredentialsProducer{
|
||||
displayNameLen: 10,
|
||||
usernameLen: 63,
|
||||
}
|
||||
|
||||
dbType = &MSSQL{
|
||||
ConnectionProducer: connProducer,
|
||||
CredentialsProducer: credsProducer,
|
||||
}
|
||||
|
||||
case cassandraTypeName:
|
||||
connProducer := &cassandraConnectionProducer{}
|
||||
connProducer.config = conf
|
||||
|
@ -163,7 +178,7 @@ func (dc *DatabaseConfig) GetFactory() Factory {
|
|||
return BuiltinFactory
|
||||
}
|
||||
|
||||
// Statments set in role creation and passed into the database type's functions.
|
||||
// Statements set in role creation and passed into the database type's functions.
|
||||
// TODO: Add a way of setting defaults here.
|
||||
type Statements struct {
|
||||
CreationStatements string `json:"creation_statments" mapstructure:"creation_statements" structs:"creation_statments"`
|
||||
|
|
219
builtin/logical/database/dbs/mssql.go
Normal file
219
builtin/logical/database/dbs/mssql.go
Normal file
|
@ -0,0 +1,219 @@
|
|||
package dbs
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/vault/helper/strutil"
|
||||
)
|
||||
|
||||
// MSSQL is an implementation of DatabaseType interface
|
||||
type MSSQL struct {
|
||||
ConnectionProducer
|
||||
CredentialsProducer
|
||||
}
|
||||
|
||||
// Type returns the TypeName for this backend
|
||||
func (m *MSSQL) Type() string {
|
||||
return msSQLTypeName
|
||||
}
|
||||
|
||||
func (m *MSSQL) getConnection() (*sql.DB, error) {
|
||||
db, err := m.connection()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return db.(*sql.DB), nil
|
||||
}
|
||||
|
||||
// CreateUser generates the username/password on the underlying MSSQL secret backend as instructed by
|
||||
// the CreationStatement provided.
|
||||
func (m *MSSQL) CreateUser(statements Statements, username, password, expiration string) error {
|
||||
// Grab the lock
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
|
||||
// Get the connection
|
||||
db, err := m.getConnection()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if statements.CreationStatements == "" {
|
||||
return ErrEmptyCreationStatement
|
||||
}
|
||||
|
||||
// Start a transaction
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
// Execute each query
|
||||
for _, query := range strutil.ParseArbitraryStringSlice(statements.CreationStatements, ";") {
|
||||
query = strings.TrimSpace(query)
|
||||
if len(query) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
stmt, err := tx.Prepare(queryHelper(query, map[string]string{
|
||||
"name": username,
|
||||
"password": password,
|
||||
}))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
if _, err := stmt.Exec(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Commit the transaction
|
||||
if err := tx.Commit(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RenewUser is not supported on MSSQL, so this is a no-op.
|
||||
func (m *MSSQL) RenewUser(statements Statements, username, expiration string) error {
|
||||
// NOOP
|
||||
return nil
|
||||
}
|
||||
|
||||
// RevokeUser attempts to drop the specified user. It will first attempt to disable login,
|
||||
// then kill pending connections from that user, and finally drop the user and login from the
|
||||
// database instance.
|
||||
func (m *MSSQL) RevokeUser(statements Statements, username string) error {
|
||||
// Get connection
|
||||
db, err := m.getConnection()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// First disable server login
|
||||
disableStmt, err := db.Prepare(fmt.Sprintf("ALTER LOGIN [%s] DISABLE;", username))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer disableStmt.Close()
|
||||
if _, err := disableStmt.Exec(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Query for sessions for the login so that we can kill any outstanding
|
||||
// sessions. There cannot be any active sessions before we drop the logins
|
||||
// This isn't done in a transaction because even if we fail along the way,
|
||||
// we want to remove as much access as possible
|
||||
sessionStmt, err := db.Prepare(fmt.Sprintf(
|
||||
"SELECT session_id FROM sys.dm_exec_sessions WHERE login_name = '%s';", username))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer sessionStmt.Close()
|
||||
|
||||
sessionRows, err := sessionStmt.Query()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer sessionRows.Close()
|
||||
|
||||
var revokeStmts []string
|
||||
for sessionRows.Next() {
|
||||
var sessionID int
|
||||
err = sessionRows.Scan(&sessionID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
revokeStmts = append(revokeStmts, fmt.Sprintf("KILL %d;", sessionID))
|
||||
}
|
||||
|
||||
// Query for database users using undocumented stored procedure for now since
|
||||
// it is the easiest way to get this information;
|
||||
// we need to drop the database users before we can drop the login and the role
|
||||
// This isn't done in a transaction because even if we fail along the way,
|
||||
// we want to remove as much access as possible
|
||||
stmt, err := db.Prepare(fmt.Sprintf("EXEC sp_msloginmappings '%s';", username))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
rows, err := stmt.Query()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
var loginName, dbName, qUsername string
|
||||
var aliasName sql.NullString
|
||||
err = rows.Scan(&loginName, &dbName, &qUsername, &aliasName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
revokeStmts = append(revokeStmts, fmt.Sprintf(dropUserSQL, dbName, username, username))
|
||||
}
|
||||
|
||||
// we do not stop on error, as we want to remove as
|
||||
// many permissions as possible right now
|
||||
var lastStmtError error
|
||||
for _, query := range revokeStmts {
|
||||
stmt, err := db.Prepare(query)
|
||||
if err != nil {
|
||||
lastStmtError = err
|
||||
continue
|
||||
}
|
||||
defer stmt.Close()
|
||||
_, err = stmt.Exec()
|
||||
if err != nil {
|
||||
lastStmtError = err
|
||||
}
|
||||
}
|
||||
|
||||
// can't drop if not all database users are dropped
|
||||
if rows.Err() != nil {
|
||||
return fmt.Errorf("cound not generate sql statements for all rows: %s", rows.Err())
|
||||
}
|
||||
if lastStmtError != nil {
|
||||
return fmt.Errorf("could not perform all sql statements: %s", lastStmtError)
|
||||
}
|
||||
|
||||
// Drop this login
|
||||
stmt, err = db.Prepare(fmt.Sprintf(dropLoginSQL, username, username))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
if _, err := stmt.Exec(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
const dropUserSQL = `
|
||||
USE [%s]
|
||||
IF EXISTS
|
||||
(SELECT name
|
||||
FROM sys.database_principals
|
||||
WHERE name = N'%s')
|
||||
BEGIN
|
||||
DROP USER [%s]
|
||||
END
|
||||
`
|
||||
|
||||
const dropLoginSQL = `
|
||||
IF EXISTS
|
||||
(SELECT name
|
||||
FROM master.sys.server_principals
|
||||
WHERE name = N'%s')
|
||||
BEGIN
|
||||
DROP LOGIN [%s]
|
||||
END
|
||||
`
|
221
builtin/logical/database/dbs/mssql_test.go
Normal file
221
builtin/logical/database/dbs/mssql_test.go
Normal file
|
@ -0,0 +1,221 @@
|
|||
package dbs
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"os"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
_ "github.com/denisenkom/go-mssqldb"
|
||||
log "github.com/mgutz/logxi/v1"
|
||||
dockertest "gopkg.in/ory-am/dockertest.v3"
|
||||
)
|
||||
|
||||
var (
|
||||
testMSQLImagePull sync.Once
|
||||
)
|
||||
|
||||
func prepareMSSQLTestContainer(t *testing.T) (cleanup func(), retURL string) {
|
||||
if os.Getenv("MSSQL_URL") != "" {
|
||||
return func() {}, os.Getenv("MSSQL_URL")
|
||||
}
|
||||
|
||||
pool, err := dockertest.NewPool("")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to connect to docker: %s", err)
|
||||
}
|
||||
|
||||
resource, err := pool.Run("microsoft/mssql-server-linux", "latest", []string{"ACCEPT_EULA=Y", "SA_PASSWORD=yourStrong(!)Password"})
|
||||
if err != nil {
|
||||
t.Fatalf("Could not start local MSSQL docker container: %s", err)
|
||||
}
|
||||
|
||||
cleanup = func() {
|
||||
err := pool.Purge(resource)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to cleanup local DynamoDB: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
retURL = fmt.Sprintf("sqlserver://sa:yourStrong(!)Password@localhost:%s", resource.GetPort("1433/tcp"))
|
||||
|
||||
// exponential backoff-retry, because the mssql container may not be able to accept connections yet
|
||||
if err = pool.Retry(func() error {
|
||||
var err error
|
||||
var db *sql.DB
|
||||
db, err = sql.Open("mssql", retURL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return db.Ping()
|
||||
}); err != nil {
|
||||
t.Fatalf("Could not connect to MSSQL docker container: %s", err)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func TestMSSQL_Initialize(t *testing.T) {
|
||||
cleanup, connURL := prepareMSSQLTestContainer(t)
|
||||
defer cleanup()
|
||||
|
||||
conf := &DatabaseConfig{
|
||||
DatabaseType: msSQLTypeName,
|
||||
ConnectionDetails: map[string]interface{}{
|
||||
"connection_url": connURL,
|
||||
},
|
||||
}
|
||||
|
||||
dbRaw, err := BuiltinFactory(conf, nil, &log.NullLogger{})
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
// Deconsturct the middleware chain to get the underlying mssql object
|
||||
dbTracer := dbRaw.(*databaseTracingMiddleware)
|
||||
dbMetrics := dbTracer.next.(*databaseMetricsMiddleware)
|
||||
db := dbMetrics.next.(*MSSQL)
|
||||
connProducer := db.ConnectionProducer.(*sqlConnectionProducer)
|
||||
|
||||
err = dbRaw.Initialize(conf.ConnectionDetails)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
if !connProducer.initalized {
|
||||
t.Fatal("Database should be initalized")
|
||||
}
|
||||
|
||||
err = dbRaw.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
if connProducer.db != nil {
|
||||
t.Fatal("db object should be nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMSSQL_CreateUser(t *testing.T) {
|
||||
cleanup, connURL := prepareMSSQLTestContainer(t)
|
||||
defer cleanup()
|
||||
|
||||
conf := &DatabaseConfig{
|
||||
DatabaseType: msSQLTypeName,
|
||||
ConnectionDetails: map[string]interface{}{
|
||||
"connection_url": connURL,
|
||||
},
|
||||
}
|
||||
|
||||
db, err := BuiltinFactory(conf, nil, &log.NullLogger{})
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
err = db.Initialize(conf.ConnectionDetails)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
username, err := db.GenerateUsername("test")
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
password, err := db.GeneratePassword()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
expiration, err := db.GenerateExpiration(time.Minute)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
// Test with no configured Creation Statememt
|
||||
err = db.CreateUser(Statements{}, username, password, expiration)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error when no creation statement is provided")
|
||||
}
|
||||
|
||||
statements := Statements{
|
||||
CreationStatements: testMSSQLRole,
|
||||
}
|
||||
|
||||
err = db.CreateUser(statements, username, password, expiration)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
username, err = db.GenerateUsername("test")
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
password, err = db.GeneratePassword()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
expiration, err = db.GenerateExpiration(time.Minute)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMSSQL_RevokeUser(t *testing.T) {
|
||||
cleanup, connURL := prepareMSSQLTestContainer(t)
|
||||
defer cleanup()
|
||||
|
||||
conf := &DatabaseConfig{
|
||||
DatabaseType: msSQLTypeName,
|
||||
ConnectionDetails: map[string]interface{}{
|
||||
"connection_url": connURL,
|
||||
},
|
||||
}
|
||||
|
||||
db, err := BuiltinFactory(conf, nil, &log.NullLogger{})
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
err = db.Initialize(conf.ConnectionDetails)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
username, err := db.GenerateUsername("test")
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
password, err := db.GeneratePassword()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
expiration, err := db.GenerateExpiration(time.Minute)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
statements := Statements{
|
||||
CreationStatements: testMSSQLRole,
|
||||
}
|
||||
|
||||
err = db.CreateUser(statements, username, password, expiration)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
// Test default revoke statememts
|
||||
err = db.RevokeUser(statements, username)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
const testMSSQLRole = `
|
||||
CREATE LOGIN [{{name}}] WITH PASSWORD = '{{password}}';
|
||||
CREATE USER [{{name}}] FOR LOGIN [{{name}}];
|
||||
GRANT SELECT, INSERT, UPDATE, DELETE ON SCHEMA::dbo TO [{{name}}];`
|
Loading…
Reference in a new issue