DBPW - Update MSSQL to adhere to v5 Database interface (#10128)

This commit is contained in:
Michael Golowka 2020-10-13 11:11:00 -06:00 committed by GitHub
parent 24d1f33c9c
commit a62ffcab2a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 310 additions and 378 deletions

View File

@ -6,34 +6,32 @@ import (
"errors"
"fmt"
"strings"
"time"
_ "github.com/denisenkom/go-mssqldb"
"github.com/hashicorp/errwrap"
multierror "github.com/hashicorp/go-multierror"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/sdk/database/dbplugin"
"github.com/hashicorp/vault/sdk/database/helper/connutil"
"github.com/hashicorp/vault/sdk/database/helper/credsutil"
"github.com/hashicorp/vault/sdk/database/helper/dbutil"
"github.com/hashicorp/vault/sdk/database/newdbplugin"
"github.com/hashicorp/vault/sdk/helper/dbtxn"
"github.com/hashicorp/vault/sdk/helper/strutil"
)
const msSQLTypeName = "mssql"
var _ dbplugin.Database = &MSSQL{}
var _ newdbplugin.Database = &MSSQL{}
// MSSQL is an implementation of Database interface
type MSSQL struct {
*connutil.SQLConnectionProducer
credsutil.CredentialsProducer
}
func New() (interface{}, error) {
db := new()
// Wrap the plugin with middleware to sanitize errors
dbType := dbplugin.NewDatabaseErrorSanitizerMiddleware(db, db.SecretValues)
dbType := newdbplugin.NewDatabaseErrorSanitizerMiddleware(db, db.secretValues)
return dbType, nil
}
@ -42,16 +40,8 @@ func new() *MSSQL {
connProducer := &connutil.SQLConnectionProducer{}
connProducer.Type = msSQLTypeName
credsProducer := &credsutil.SQLCredentialsProducer{
DisplayNameLen: 20,
RoleNameLen: 20,
UsernameLen: 128,
Separator: "-",
}
return &MSSQL{
SQLConnectionProducer: connProducer,
CredentialsProducer: credsProducer,
}
}
@ -62,7 +52,7 @@ func Run(apiTLSConfig *api.TLSConfig) error {
return err
}
dbplugin.Serve(dbType.(dbplugin.Database), api.VaultPluginTLSProvider(apiTLSConfig))
newdbplugin.Serve(dbType.(newdbplugin.Database), api.VaultPluginTLSProvider(apiTLSConfig))
return nil
}
@ -72,6 +62,12 @@ func (m *MSSQL) Type() (string, error) {
return msSQLTypeName, nil
}
func (m *MSSQL) secretValues() map[string]string {
return map[string]string{
m.Password: "[password]",
}
}
func (m *MSSQL) getConnection(ctx context.Context) (*sql.DB, error) {
db, err := m.Connection(ctx)
if err != nil {
@ -81,49 +77,51 @@ func (m *MSSQL) getConnection(ctx context.Context) (*sql.DB, error) {
return db.(*sql.DB), nil
}
// CreateUser generates the username/password on the underlying MSSQL secret backend as instructed by
// the CreationStatement provided.
func (m *MSSQL) CreateUser(ctx context.Context, statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
// Grab the lock
func (m *MSSQL) Initialize(ctx context.Context, req newdbplugin.InitializeRequest) (newdbplugin.InitializeResponse, error) {
newConf, err := m.SQLConnectionProducer.Init(ctx, req.Config, req.VerifyConnection)
if err != nil {
return newdbplugin.InitializeResponse{}, err
}
resp := newdbplugin.InitializeResponse{
Config: newConf,
}
return resp, nil
}
// NewUser generates the username/password on the underlying MSSQL secret backend as instructed by
// the statements provided.
func (m *MSSQL) NewUser(ctx context.Context, req newdbplugin.NewUserRequest) (newdbplugin.NewUserResponse, error) {
m.Lock()
defer m.Unlock()
statements = dbutil.StatementCompatibilityHelper(statements)
// Get the connection
db, err := m.getConnection(ctx)
if err != nil {
return "", "", err
return newdbplugin.NewUserResponse{}, fmt.Errorf("unable to get connection: %w", err)
}
if len(statements.Creation) == 0 {
return "", "", dbutil.ErrEmptyCreationStatement
if len(req.Statements.Commands) == 0 {
return newdbplugin.NewUserResponse{}, dbutil.ErrEmptyCreationStatement
}
username, err = m.GenerateUsername(usernameConfig)
username, err := credsutil.GenerateUsername(
credsutil.DisplayName(req.UsernameConfig.DisplayName, 20),
credsutil.RoleName(req.UsernameConfig.RoleName, 20),
credsutil.MaxLength(128),
credsutil.Separator("-"),
)
if err != nil {
return "", "", err
return newdbplugin.NewUserResponse{}, err
}
password, err = m.GeneratePassword()
if err != nil {
return "", "", err
}
expirationStr := req.Expiration.Format("2006-01-02 15:04:05-0700")
expirationStr, err := m.GenerateExpiration(expiration)
if err != nil {
return "", "", err
}
// Start a transaction
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return "", "", err
return newdbplugin.NewUserResponse{}, err
}
defer tx.Rollback()
// Execute each query
for _, stmt := range statements.Creation {
for _, stmt := range req.Statements.Commands {
for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
@ -132,50 +130,45 @@ func (m *MSSQL) CreateUser(ctx context.Context, statements dbplugin.Statements,
m := map[string]string{
"name": username,
"password": password,
"password": req.Password,
"expiration": expirationStr,
}
if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil {
return "", "", err
return newdbplugin.NewUserResponse{}, err
}
}
}
// Commit the transaction
if err := tx.Commit(); err != nil {
return "", "", err
return newdbplugin.NewUserResponse{}, err
}
return username, password, nil
resp := newdbplugin.NewUserResponse{
Username: username,
}
return resp, nil
}
// RenewUser is not supported on MSSQL, so this is a no-op.
func (m *MSSQL) RenewUser(ctx context.Context, statements dbplugin.Statements, username string, expiration time.Time) error {
// NOOP
return nil
}
// RevokeUser attempts to drop the specified user. It will first attempt to disable login,
// DeleteUser attempts to drop the specified user. It will first attempt to disable login,
// then kill pending connections from that user, and finally drop the user and login from the
// database instance.
func (m *MSSQL) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error {
statements = dbutil.StatementCompatibilityHelper(statements)
if len(statements.Revocation) == 0 {
return m.revokeUserDefault(ctx, username)
func (m *MSSQL) DeleteUser(ctx context.Context, req newdbplugin.DeleteUserRequest) (newdbplugin.DeleteUserResponse, error) {
if len(req.Statements.Commands) == 0 {
err := m.revokeUserDefault(ctx, req.Username)
return newdbplugin.DeleteUserResponse{}, err
}
// Get connection
db, err := m.getConnection(ctx)
if err != nil {
return err
return newdbplugin.DeleteUserResponse{}, fmt.Errorf("unable to get connection: %w", err)
}
var result error
merr := &multierror.Error{}
// Execute each query
for _, stmt := range statements.Revocation {
for _, stmt := range req.Statements.Commands {
for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
@ -183,15 +176,15 @@ func (m *MSSQL) RevokeUser(ctx context.Context, statements dbplugin.Statements,
}
m := map[string]string{
"name": username,
"name": req.Username,
}
if err := dbtxn.ExecuteDBQuery(ctx, db, m, query); err != nil {
result = multierror.Append(result, err)
merr = multierror.Append(merr, err)
}
}
}
return result
return newdbplugin.DeleteUserResponse{}, merr.ErrorOrNil()
}
func (m *MSSQL) revokeUserDefault(ctx context.Context, username string) error {
@ -297,76 +290,28 @@ func (m *MSSQL) revokeUserDefault(ctx context.Context, username string) error {
return nil
}
func (m *MSSQL) RotateRootCredentials(ctx context.Context, statements []string) (map[string]interface{}, error) {
m.Lock()
defer m.Unlock()
if len(m.Username) == 0 || len(m.Password) == 0 {
return nil, errors.New("username and password are required to rotate")
func (m *MSSQL) UpdateUser(ctx context.Context, req newdbplugin.UpdateUserRequest) (newdbplugin.UpdateUserResponse, error) {
if req.Password == nil && req.Expiration == nil {
return newdbplugin.UpdateUserResponse{}, fmt.Errorf("no changes requested")
}
rotateStatents := statements
if len(rotateStatents) == 0 {
rotateStatents = []string{alterLoginSQL}
if req.Password != nil {
err := m.updateUserPass(ctx, req.Username, req.Password)
return newdbplugin.UpdateUserResponse{}, err
}
db, err := m.getConnection(ctx)
if err != nil {
return nil, err
}
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer func() {
tx.Rollback()
}()
password, err := m.GeneratePassword()
if err != nil {
return nil, err
}
for _, stmt := range rotateStatents {
for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
}
m := map[string]string{
"username": m.Username,
"password": password,
}
if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil {
return nil, err
}
}
}
if err := tx.Commit(); err != nil {
return nil, err
}
if err := db.Close(); err != nil {
return nil, err
}
m.RawConfig["password"] = password
return m.RawConfig, nil
// Expiration is a no-op
return newdbplugin.UpdateUserResponse{}, nil
}
func (m *MSSQL) SetCredentials(ctx context.Context, statements dbplugin.Statements, staticUser dbplugin.StaticUserConfig) (username, password string, err error) {
if len(statements.Rotation) == 0 {
statements.Rotation = []string{alterLoginSQL}
func (m *MSSQL) updateUserPass(ctx context.Context, username string, changePass *newdbplugin.ChangePassword) error {
stmts := changePass.Statements.Commands
if len(stmts) == 0 {
stmts = []string{alterLoginSQL}
}
username = staticUser.Username
password = staticUser.Password
password := changePass.NewPassword
if username == "" || password == "" {
return "", "", errors.New("must provide both username and password")
return errors.New("must provide both username and password")
}
m.Lock()
@ -374,7 +319,7 @@ func (m *MSSQL) SetCredentials(ctx context.Context, statements dbplugin.Statemen
db, err := m.getConnection(ctx)
if err != nil {
return "", "", err
return err
}
var exists bool
@ -382,16 +327,14 @@ func (m *MSSQL) SetCredentials(ctx context.Context, statements dbplugin.Statemen
err = db.QueryRowContext(ctx, "SELECT 1 FROM master.sys.server_principals where name = N'$1'", username).Scan(&exists)
if err != nil && err != sql.ErrNoRows {
return "", "", err
return err
}
stmts := statements.Rotation
// Start a transaction
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return "", "", err
return err
}
defer func() {
_ = tx.Rollback()
}()
@ -409,16 +352,16 @@ func (m *MSSQL) SetCredentials(ctx context.Context, statements dbplugin.Statemen
"password": password,
}
if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil {
return "", "", err
return fmt.Errorf("failed to execute query: %w", err)
}
}
}
if err := tx.Commit(); err != nil {
return "", "", err
return fmt.Errorf("failed to commit transaction: %w", err)
}
return username, password, nil
return nil
}
const dropUserSQL = `

View File

@ -4,162 +4,43 @@ import (
"context"
"database/sql"
"fmt"
"reflect"
"regexp"
"strings"
"testing"
"time"
mssqlhelper "github.com/hashicorp/vault/helper/testhelpers/mssql"
"github.com/hashicorp/vault/sdk/database/dbplugin"
"github.com/hashicorp/vault/sdk/database/newdbplugin"
dbtesting "github.com/hashicorp/vault/sdk/database/newdbplugin/testing"
"github.com/hashicorp/vault/sdk/helper/dbtxn"
mssqlhelper "github.com/hashicorp/vault/helper/testhelpers/mssql"
)
func TestMSSQL_Initialize(t *testing.T) {
func TestInitialize(t *testing.T) {
cleanup, connURL := mssqlhelper.PrepareMSSQLTestContainer(t)
defer cleanup()
connectionDetails := map[string]interface{}{
"connection_url": connURL,
}
db := new()
_, err := db.Init(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
if !db.Initialized {
t.Fatal("Database should be initialized")
}
err = db.Close()
if err != nil {
t.Fatalf("err: %s", err)
}
// Test decoding a string value for max_open_connections
connectionDetails = map[string]interface{}{
"connection_url": connURL,
"max_open_connections": "5",
}
_, err = db.Init(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
}
func TestMSSQL_CreateUser(t *testing.T) {
cleanup, connURL := mssqlhelper.PrepareMSSQLTestContainer(t)
defer cleanup()
connectionDetails := map[string]interface{}{
"connection_url": connURL,
}
db := new()
_, err := db.Init(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
usernameConfig := dbplugin.UsernameConfig{
DisplayName: "test",
RoleName: "test",
}
// Test with no configured Creation Statement
_, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute))
if err == nil {
t.Fatal("Expected error when no creation statement is provided")
}
statements := dbplugin.Statements{
Creation: []string{testMSSQLRole},
}
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
if err != nil {
t.Fatalf("err: %s", err)
}
if err = testCredsExist(t, connURL, username, password); err != nil {
t.Fatalf("Could not connect with new credentials: %s", err)
}
}
func TestMSSQL_RotateRootCredentials(t *testing.T) {
cleanup, connURL := mssqlhelper.PrepareMSSQLTestContainer(t)
defer cleanup()
connectionDetails := map[string]interface{}{
"connection_url": connURL,
"username": "sa",
"password": "yourStrong(!)Password",
}
db := new()
connProducer := db.SQLConnectionProducer
_, err := db.Init(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
if !connProducer.Initialized {
t.Fatal("Database should be initialized")
}
newConf, err := db.RotateRootCredentials(context.Background(), nil)
if err != nil {
t.Fatalf("err: %v", err)
}
if newConf["password"] == "yourStrong(!)Password" {
t.Fatal("password was not updated")
}
err = db.Close()
if err != nil {
t.Fatalf("err: %s", err)
}
}
func TestMSSQL_SetCredentials_missingArgs(t *testing.T) {
type testCase struct {
statements dbplugin.Statements
userConfig dbplugin.StaticUserConfig
req newdbplugin.InitializeRequest
}
tests := map[string]testCase{
"empty rotation statements": {
statements: dbplugin.Statements{
Rotation: nil,
},
userConfig: dbplugin.StaticUserConfig{
Username: "testuser",
Password: "password",
"happy path": {
req: newdbplugin.InitializeRequest{
Config: map[string]interface{}{
"connection_url": connURL,
},
VerifyConnection: true,
},
},
"empty username": {
statements: dbplugin.Statements{
Rotation: []string{`
ALTER LOGIN [{{username}}] WITH PASSWORD = '{{password}}';`,
"max_open_connections set": {
newdbplugin.InitializeRequest{
Config: map[string]interface{}{
"connection_url": connURL,
"max_open_connections": "5",
},
},
userConfig: dbplugin.StaticUserConfig{
Username: "",
Password: "password",
},
},
"empty password": {
statements: dbplugin.Statements{
Rotation: []string{`
ALTER LOGIN [{{username}}] WITH PASSWORD = '{{password}}';`,
},
},
userConfig: dbplugin.StaticUserConfig{
Username: "testuser",
Password: "",
VerifyConnection: true,
},
},
}
@ -167,33 +48,165 @@ func TestMSSQL_SetCredentials_missingArgs(t *testing.T) {
for name, test := range tests {
t.Run(name, func(t *testing.T) {
db := new()
dbtesting.AssertInitialize(t, db, test.req)
defer dbtesting.AssertClose(t, db)
username, password, err := db.SetCredentials(context.Background(), test.statements, test.userConfig)
if err == nil {
t.Fatalf("expected err, got nil")
}
if username != "" {
t.Fatalf("expected empty username, got [%s]", username)
}
if password != "" {
t.Fatalf("expected empty password, got [%s]", password)
if !db.Initialized {
t.Fatal("Database should be initialized")
}
})
}
}
func TestMSSQL_SetCredentials(t *testing.T) {
func TestNewUser(t *testing.T) {
cleanup, connURL := mssqlhelper.PrepareMSSQLTestContainer(t)
defer cleanup()
type testCase struct {
rotationStmts []string
req newdbplugin.NewUserRequest
usernameRegex string
expectErr bool
assertUser func(t testing.TB, connURL, username, password string)
}
tests := map[string]testCase{
"empty rotation statements": {
rotationStmts: []string{},
}, "username rotation": {
rotationStmts: []string{`
ALTER LOGIN [{{username}}] WITH PASSWORD = '{{password}}';`,
"no creation statements": {
req: newdbplugin.NewUserRequest{
UsernameConfig: newdbplugin.UsernameMetadata{
DisplayName: "test",
RoleName: "test",
},
Statements: newdbplugin.Statements{},
Password: "AG4qagho-dsvZ",
Expiration: time.Now().Add(1 * time.Second),
},
usernameRegex: "^$",
expectErr: true,
assertUser: assertCredsDoNotExist,
},
"with creation statements": {
req: newdbplugin.NewUserRequest{
UsernameConfig: newdbplugin.UsernameMetadata{
DisplayName: "test",
RoleName: "test",
},
Statements: newdbplugin.Statements{
Commands: []string{testMSSQLRole},
},
Password: "AG4qagho-dsvZ",
Expiration: time.Now().Add(1 * time.Second),
},
usernameRegex: "^v-test-test-[a-zA-Z0-9]{20}-[0-9]{10}$",
expectErr: false,
assertUser: assertCredsExist,
},
}
for name, test := range tests {
t.Run(name, func(t *testing.T) {
usernameRe, err := regexp.Compile(test.usernameRegex)
if err != nil {
t.Fatalf("failed to compile username regex %q: %s", test.usernameRegex, err)
}
initReq := newdbplugin.InitializeRequest{
Config: map[string]interface{}{
"connection_url": connURL,
},
VerifyConnection: true,
}
db := new()
dbtesting.AssertInitialize(t, db, initReq)
defer dbtesting.AssertClose(t, db)
createResp, err := db.NewUser(context.Background(), test.req)
if test.expectErr && err == nil {
t.Fatalf("err expected, got nil")
}
if !test.expectErr && err != nil {
t.Fatalf("no error expected, got: %s", err)
}
if !usernameRe.MatchString(createResp.Username) {
t.Fatalf("Generated username %q did not match regex %q", createResp.Username, test.usernameRegex)
}
// Protect against future fields that aren't specified
expectedResp := newdbplugin.NewUserResponse{
Username: createResp.Username,
}
if !reflect.DeepEqual(createResp, expectedResp) {
t.Fatalf("Fields missing from expected response: Actual: %#v", createResp)
}
test.assertUser(t, connURL, createResp.Username, test.req.Password)
})
}
}
func TestUpdateUser_password(t *testing.T) {
type testCase struct {
req newdbplugin.UpdateUserRequest
expectErr bool
expectedPassword string
}
dbUser := "vaultuser"
initPassword := "p4$sw0rd"
tests := map[string]testCase{
"missing password": {
req: newdbplugin.UpdateUserRequest{
Username: dbUser,
Password: &newdbplugin.ChangePassword{
NewPassword: "",
Statements: newdbplugin.Statements{},
},
},
expectErr: true,
expectedPassword: initPassword,
},
"empty rotation statements": {
req: newdbplugin.UpdateUserRequest{
Username: dbUser,
Password: &newdbplugin.ChangePassword{
NewPassword: "N90gkKLy8$angf",
Statements: newdbplugin.Statements{},
},
},
expectErr: false,
expectedPassword: "N90gkKLy8$angf",
},
"username rotation": {
req: newdbplugin.UpdateUserRequest{
Username: dbUser,
Password: &newdbplugin.ChangePassword{
NewPassword: "N90gkKLy8$angf",
Statements: newdbplugin.Statements{
Commands: []string{
"ALTER LOGIN [{{username}}] WITH PASSWORD = '{{password}}'",
},
},
},
},
expectErr: false,
expectedPassword: "N90gkKLy8$angf",
},
"bad statements": {
req: newdbplugin.UpdateUserRequest{
Username: dbUser,
Password: &newdbplugin.ChangePassword{
NewPassword: "N90gkKLy8$angf",
Statements: newdbplugin.Statements{
Commands: []string{
"ahosh98asjdffs",
},
},
},
},
expectErr: true,
expectedPassword: initPassword,
},
}
@ -202,124 +215,101 @@ func TestMSSQL_SetCredentials(t *testing.T) {
cleanup, connURL := mssqlhelper.PrepareMSSQLTestContainer(t)
defer cleanup()
connectionDetails := map[string]interface{}{
"connection_url": connURL,
initReq := newdbplugin.InitializeRequest{
Config: map[string]interface{}{
"connection_url": connURL,
},
VerifyConnection: true,
}
db := new()
dbtesting.AssertInitialize(t, db, initReq)
defer dbtesting.AssertClose(t, db)
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
_, err := db.Init(ctx, connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
dbUser := "vaultstatictest"
initPassword := "p4$sw0rd"
createTestMSSQLUser(t, connURL, dbUser, initPassword, testMSSQLLogin)
if err := testCredsExist(t, connURL, dbUser, initPassword); err != nil {
t.Fatalf("Could not connect with initial credentials: %s", err)
assertCredsExist(t, connURL, dbUser, initPassword)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
updateResp, err := db.UpdateUser(ctx, test.req)
if test.expectErr && err == nil {
t.Fatalf("err expected, got nil")
}
if !test.expectErr && err != nil {
t.Fatalf("no error expected, got: %s", err)
}
statements := dbplugin.Statements{
Rotation: test.rotationStmts,
}
newPassword, err := db.GenerateCredentials(context.Background())
if err != nil {
t.Fatal(err)
}
usernameConfig := dbplugin.StaticUserConfig{
Username: dbUser,
Password: newPassword,
}
username, password, err := db.SetCredentials(ctx, statements, usernameConfig)
if err != nil {
t.Fatalf("err: %s", err)
}
if err := testCredsExist(t, connURL, username, password); err != nil {
t.Fatalf("Could not connect with new credentials: %s", err)
}
if err := testCredsExist(t, connURL, username, initPassword); err == nil {
t.Fatalf("Should not be able to connect with initial credentials")
// Protect against future fields that aren't specified
expectedResp := newdbplugin.UpdateUserResponse{}
if !reflect.DeepEqual(updateResp, expectedResp) {
t.Fatalf("Fields missing from expected response: Actual: %#v", updateResp)
}
assertCredsExist(t, connURL, dbUser, test.expectedPassword)
})
}
}
func TestMSSQL_RevokeUser(t *testing.T) {
func TestDeleteUser(t *testing.T) {
cleanup, connURL := mssqlhelper.PrepareMSSQLTestContainer(t)
defer cleanup()
connectionDetails := map[string]interface{}{
"connection_url": connURL,
dbUser := "vaultuser"
initPassword := "p4$sw0rd"
initReq := newdbplugin.InitializeRequest{
Config: map[string]interface{}{
"connection_url": connURL,
},
VerifyConnection: true,
}
db := new()
_, err := db.Init(context.Background(), connectionDetails, true)
dbtesting.AssertInitialize(t, db, initReq)
defer dbtesting.AssertClose(t, db)
createTestMSSQLUser(t, connURL, dbUser, initPassword, testMSSQLLogin)
assertCredsExist(t, connURL, dbUser, initPassword)
deleteReq := newdbplugin.DeleteUserRequest{
Username: dbUser,
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
deleteResp, err := db.DeleteUser(ctx, deleteReq)
if err != nil {
t.Fatalf("err: %s", err)
t.Fatalf("Failed to delete user: %s", err)
}
statements := dbplugin.Statements{
Creation: []string{testMSSQLRole},
// Protect against future fields that aren't specified
expectedResp := newdbplugin.DeleteUserResponse{}
if !reflect.DeepEqual(deleteResp, expectedResp) {
t.Fatalf("Fields missing from expected response: Actual: %#v", deleteResp)
}
usernameConfig := dbplugin.UsernameConfig{
DisplayName: "test",
RoleName: "test",
}
assertCredsDoNotExist(t, connURL, dbUser, initPassword)
}
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second))
func assertCredsExist(t testing.TB, connURL, username, password string) {
t.Helper()
err := testCredsExist(connURL, username, password)
if err != nil {
t.Fatalf("err: %s", err)
}
if err = testCredsExist(t, connURL, username, password); err != nil {
t.Fatalf("Could not connect with new credentials: %s", err)
}
// Test default revoke statements
err = db.RevokeUser(context.Background(), statements, username)
if err != nil {
t.Fatalf("err: %s", err)
}
if err := testCredsExist(t, connURL, username, password); err == nil {
t.Fatal("Credentials were not revoked")
}
username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second))
if err != nil {
t.Fatalf("err: %s", err)
}
if err = testCredsExist(t, connURL, username, password); err != nil {
t.Fatalf("Could not connect with new credentials: %s", err)
}
// Test custom revoke statement
statements.Revocation = []string{testMSSQLDrop}
err = db.RevokeUser(context.Background(), statements, username)
if err != nil {
t.Fatalf("err: %s", err)
}
if err := testCredsExist(t, connURL, username, password); err == nil {
t.Fatal("Credentials were not revoked")
t.Fatalf("Unable to log in as %q: %s", username, err)
}
}
func testCredsExist(t testing.TB, connURL, username, password string) error {
func assertCredsDoNotExist(t testing.TB, connURL, username, password string) {
t.Helper()
err := testCredsExist(connURL, username, password)
if err == nil {
t.Fatalf("Able to log in when it shouldn't")
}
}
func testCredsExist(connURL, username, password string) error {
// Log in with the new creds
parts := strings.Split(connURL, "@")
connURL = fmt.Sprintf("sqlserver://%s:%s@%s", username, password, parts[1])
@ -332,7 +322,6 @@ func testCredsExist(t testing.TB, connURL, username, password string) error {
}
func createTestMSSQLUser(t *testing.T, connURL string, username, password, query string) {
db, err := sql.Open("mssql", connURL)
defer db.Close()
if err != nil {