diff --git a/plugins/database/mssql/mssql.go b/plugins/database/mssql/mssql.go index 52023fcf2..4d0e46607 100644 --- a/plugins/database/mssql/mssql.go +++ b/plugins/database/mssql/mssql.go @@ -6,34 +6,32 @@ import ( "errors" "fmt" "strings" - "time" _ "github.com/denisenkom/go-mssqldb" "github.com/hashicorp/errwrap" multierror "github.com/hashicorp/go-multierror" "github.com/hashicorp/vault/api" - "github.com/hashicorp/vault/sdk/database/dbplugin" "github.com/hashicorp/vault/sdk/database/helper/connutil" "github.com/hashicorp/vault/sdk/database/helper/credsutil" "github.com/hashicorp/vault/sdk/database/helper/dbutil" + "github.com/hashicorp/vault/sdk/database/newdbplugin" "github.com/hashicorp/vault/sdk/helper/dbtxn" "github.com/hashicorp/vault/sdk/helper/strutil" ) const msSQLTypeName = "mssql" -var _ dbplugin.Database = &MSSQL{} +var _ newdbplugin.Database = &MSSQL{} // MSSQL is an implementation of Database interface type MSSQL struct { *connutil.SQLConnectionProducer - credsutil.CredentialsProducer } func New() (interface{}, error) { db := new() // Wrap the plugin with middleware to sanitize errors - dbType := dbplugin.NewDatabaseErrorSanitizerMiddleware(db, db.SecretValues) + dbType := newdbplugin.NewDatabaseErrorSanitizerMiddleware(db, db.secretValues) return dbType, nil } @@ -42,16 +40,8 @@ func new() *MSSQL { connProducer := &connutil.SQLConnectionProducer{} connProducer.Type = msSQLTypeName - credsProducer := &credsutil.SQLCredentialsProducer{ - DisplayNameLen: 20, - RoleNameLen: 20, - UsernameLen: 128, - Separator: "-", - } - return &MSSQL{ SQLConnectionProducer: connProducer, - CredentialsProducer: credsProducer, } } @@ -62,7 +52,7 @@ func Run(apiTLSConfig *api.TLSConfig) error { return err } - dbplugin.Serve(dbType.(dbplugin.Database), api.VaultPluginTLSProvider(apiTLSConfig)) + newdbplugin.Serve(dbType.(newdbplugin.Database), api.VaultPluginTLSProvider(apiTLSConfig)) return nil } @@ -72,6 +62,12 @@ func (m *MSSQL) Type() (string, error) { return msSQLTypeName, nil } +func (m *MSSQL) secretValues() map[string]string { + return map[string]string{ + m.Password: "[password]", + } +} + func (m *MSSQL) getConnection(ctx context.Context) (*sql.DB, error) { db, err := m.Connection(ctx) if err != nil { @@ -81,49 +77,51 @@ func (m *MSSQL) getConnection(ctx context.Context) (*sql.DB, error) { return db.(*sql.DB), nil } -// CreateUser generates the username/password on the underlying MSSQL secret backend as instructed by -// the CreationStatement provided. -func (m *MSSQL) CreateUser(ctx context.Context, statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) { - // Grab the lock +func (m *MSSQL) Initialize(ctx context.Context, req newdbplugin.InitializeRequest) (newdbplugin.InitializeResponse, error) { + newConf, err := m.SQLConnectionProducer.Init(ctx, req.Config, req.VerifyConnection) + if err != nil { + return newdbplugin.InitializeResponse{}, err + } + resp := newdbplugin.InitializeResponse{ + Config: newConf, + } + return resp, nil +} + +// NewUser generates the username/password on the underlying MSSQL secret backend as instructed by +// the statements provided. +func (m *MSSQL) NewUser(ctx context.Context, req newdbplugin.NewUserRequest) (newdbplugin.NewUserResponse, error) { m.Lock() defer m.Unlock() - statements = dbutil.StatementCompatibilityHelper(statements) - - // Get the connection db, err := m.getConnection(ctx) if err != nil { - return "", "", err + return newdbplugin.NewUserResponse{}, fmt.Errorf("unable to get connection: %w", err) } - if len(statements.Creation) == 0 { - return "", "", dbutil.ErrEmptyCreationStatement + if len(req.Statements.Commands) == 0 { + return newdbplugin.NewUserResponse{}, dbutil.ErrEmptyCreationStatement } - username, err = m.GenerateUsername(usernameConfig) + username, err := credsutil.GenerateUsername( + credsutil.DisplayName(req.UsernameConfig.DisplayName, 20), + credsutil.RoleName(req.UsernameConfig.RoleName, 20), + credsutil.MaxLength(128), + credsutil.Separator("-"), + ) if err != nil { - return "", "", err + return newdbplugin.NewUserResponse{}, err } - password, err = m.GeneratePassword() - if err != nil { - return "", "", err - } + expirationStr := req.Expiration.Format("2006-01-02 15:04:05-0700") - expirationStr, err := m.GenerateExpiration(expiration) - if err != nil { - return "", "", err - } - - // Start a transaction tx, err := db.BeginTx(ctx, nil) if err != nil { - return "", "", err + return newdbplugin.NewUserResponse{}, err } defer tx.Rollback() - // Execute each query - for _, stmt := range statements.Creation { + for _, stmt := range req.Statements.Commands { for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") { query = strings.TrimSpace(query) if len(query) == 0 { @@ -132,50 +130,45 @@ func (m *MSSQL) CreateUser(ctx context.Context, statements dbplugin.Statements, m := map[string]string{ "name": username, - "password": password, + "password": req.Password, "expiration": expirationStr, } if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil { - return "", "", err + return newdbplugin.NewUserResponse{}, err } } } - // Commit the transaction if err := tx.Commit(); err != nil { - return "", "", err + return newdbplugin.NewUserResponse{}, err } - return username, password, nil + resp := newdbplugin.NewUserResponse{ + Username: username, + } + + return resp, nil } -// RenewUser is not supported on MSSQL, so this is a no-op. -func (m *MSSQL) RenewUser(ctx context.Context, statements dbplugin.Statements, username string, expiration time.Time) error { - // NOOP - return nil -} - -// RevokeUser attempts to drop the specified user. It will first attempt to disable login, +// DeleteUser attempts to drop the specified user. It will first attempt to disable login, // then kill pending connections from that user, and finally drop the user and login from the // database instance. -func (m *MSSQL) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error { - statements = dbutil.StatementCompatibilityHelper(statements) - - if len(statements.Revocation) == 0 { - return m.revokeUserDefault(ctx, username) +func (m *MSSQL) DeleteUser(ctx context.Context, req newdbplugin.DeleteUserRequest) (newdbplugin.DeleteUserResponse, error) { + if len(req.Statements.Commands) == 0 { + err := m.revokeUserDefault(ctx, req.Username) + return newdbplugin.DeleteUserResponse{}, err } - // Get connection db, err := m.getConnection(ctx) if err != nil { - return err + return newdbplugin.DeleteUserResponse{}, fmt.Errorf("unable to get connection: %w", err) } - var result error + merr := &multierror.Error{} // Execute each query - for _, stmt := range statements.Revocation { + for _, stmt := range req.Statements.Commands { for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") { query = strings.TrimSpace(query) if len(query) == 0 { @@ -183,15 +176,15 @@ func (m *MSSQL) RevokeUser(ctx context.Context, statements dbplugin.Statements, } m := map[string]string{ - "name": username, + "name": req.Username, } if err := dbtxn.ExecuteDBQuery(ctx, db, m, query); err != nil { - result = multierror.Append(result, err) + merr = multierror.Append(merr, err) } } } - return result + return newdbplugin.DeleteUserResponse{}, merr.ErrorOrNil() } func (m *MSSQL) revokeUserDefault(ctx context.Context, username string) error { @@ -297,76 +290,28 @@ func (m *MSSQL) revokeUserDefault(ctx context.Context, username string) error { return nil } -func (m *MSSQL) RotateRootCredentials(ctx context.Context, statements []string) (map[string]interface{}, error) { - m.Lock() - defer m.Unlock() - - if len(m.Username) == 0 || len(m.Password) == 0 { - return nil, errors.New("username and password are required to rotate") +func (m *MSSQL) UpdateUser(ctx context.Context, req newdbplugin.UpdateUserRequest) (newdbplugin.UpdateUserResponse, error) { + if req.Password == nil && req.Expiration == nil { + return newdbplugin.UpdateUserResponse{}, fmt.Errorf("no changes requested") } - - rotateStatents := statements - if len(rotateStatents) == 0 { - rotateStatents = []string{alterLoginSQL} + if req.Password != nil { + err := m.updateUserPass(ctx, req.Username, req.Password) + return newdbplugin.UpdateUserResponse{}, err } - - db, err := m.getConnection(ctx) - if err != nil { - return nil, err - } - - tx, err := db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer func() { - tx.Rollback() - }() - - password, err := m.GeneratePassword() - if err != nil { - return nil, err - } - - for _, stmt := range rotateStatents { - for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") { - query = strings.TrimSpace(query) - if len(query) == 0 { - continue - } - - m := map[string]string{ - "username": m.Username, - "password": password, - } - if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil { - return nil, err - } - } - } - - if err := tx.Commit(); err != nil { - return nil, err - } - - if err := db.Close(); err != nil { - return nil, err - } - - m.RawConfig["password"] = password - return m.RawConfig, nil + // Expiration is a no-op + return newdbplugin.UpdateUserResponse{}, nil } -func (m *MSSQL) SetCredentials(ctx context.Context, statements dbplugin.Statements, staticUser dbplugin.StaticUserConfig) (username, password string, err error) { - if len(statements.Rotation) == 0 { - statements.Rotation = []string{alterLoginSQL} +func (m *MSSQL) updateUserPass(ctx context.Context, username string, changePass *newdbplugin.ChangePassword) error { + stmts := changePass.Statements.Commands + if len(stmts) == 0 { + stmts = []string{alterLoginSQL} } - username = staticUser.Username - password = staticUser.Password + password := changePass.NewPassword if username == "" || password == "" { - return "", "", errors.New("must provide both username and password") + return errors.New("must provide both username and password") } m.Lock() @@ -374,7 +319,7 @@ func (m *MSSQL) SetCredentials(ctx context.Context, statements dbplugin.Statemen db, err := m.getConnection(ctx) if err != nil { - return "", "", err + return err } var exists bool @@ -382,16 +327,14 @@ func (m *MSSQL) SetCredentials(ctx context.Context, statements dbplugin.Statemen err = db.QueryRowContext(ctx, "SELECT 1 FROM master.sys.server_principals where name = N'$1'", username).Scan(&exists) if err != nil && err != sql.ErrNoRows { - return "", "", err + return err } - stmts := statements.Rotation - - // Start a transaction tx, err := db.BeginTx(ctx, nil) if err != nil { - return "", "", err + return err } + defer func() { _ = tx.Rollback() }() @@ -409,16 +352,16 @@ func (m *MSSQL) SetCredentials(ctx context.Context, statements dbplugin.Statemen "password": password, } if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil { - return "", "", err + return fmt.Errorf("failed to execute query: %w", err) } } } if err := tx.Commit(); err != nil { - return "", "", err + return fmt.Errorf("failed to commit transaction: %w", err) } - return username, password, nil + return nil } const dropUserSQL = ` diff --git a/plugins/database/mssql/mssql_test.go b/plugins/database/mssql/mssql_test.go index e8522ac8b..1bf108dc0 100644 --- a/plugins/database/mssql/mssql_test.go +++ b/plugins/database/mssql/mssql_test.go @@ -4,162 +4,43 @@ import ( "context" "database/sql" "fmt" + "reflect" + "regexp" "strings" "testing" "time" - mssqlhelper "github.com/hashicorp/vault/helper/testhelpers/mssql" - "github.com/hashicorp/vault/sdk/database/dbplugin" + "github.com/hashicorp/vault/sdk/database/newdbplugin" + dbtesting "github.com/hashicorp/vault/sdk/database/newdbplugin/testing" "github.com/hashicorp/vault/sdk/helper/dbtxn" + + mssqlhelper "github.com/hashicorp/vault/helper/testhelpers/mssql" ) -func TestMSSQL_Initialize(t *testing.T) { +func TestInitialize(t *testing.T) { cleanup, connURL := mssqlhelper.PrepareMSSQLTestContainer(t) defer cleanup() - connectionDetails := map[string]interface{}{ - "connection_url": connURL, - } - - db := new() - _, err := db.Init(context.Background(), connectionDetails, true) - if err != nil { - t.Fatalf("err: %s", err) - } - - if !db.Initialized { - t.Fatal("Database should be initialized") - } - - err = db.Close() - if err != nil { - t.Fatalf("err: %s", err) - } - - // Test decoding a string value for max_open_connections - connectionDetails = map[string]interface{}{ - "connection_url": connURL, - "max_open_connections": "5", - } - - _, err = db.Init(context.Background(), connectionDetails, true) - if err != nil { - t.Fatalf("err: %s", err) - } -} - -func TestMSSQL_CreateUser(t *testing.T) { - cleanup, connURL := mssqlhelper.PrepareMSSQLTestContainer(t) - defer cleanup() - - connectionDetails := map[string]interface{}{ - "connection_url": connURL, - } - - db := new() - _, err := db.Init(context.Background(), connectionDetails, true) - if err != nil { - t.Fatalf("err: %s", err) - } - - usernameConfig := dbplugin.UsernameConfig{ - DisplayName: "test", - RoleName: "test", - } - - // Test with no configured Creation Statement - _, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute)) - if err == nil { - t.Fatal("Expected error when no creation statement is provided") - } - - statements := dbplugin.Statements{ - Creation: []string{testMSSQLRole}, - } - - username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute)) - if err != nil { - t.Fatalf("err: %s", err) - } - - if err = testCredsExist(t, connURL, username, password); err != nil { - t.Fatalf("Could not connect with new credentials: %s", err) - } -} - -func TestMSSQL_RotateRootCredentials(t *testing.T) { - cleanup, connURL := mssqlhelper.PrepareMSSQLTestContainer(t) - defer cleanup() - - connectionDetails := map[string]interface{}{ - "connection_url": connURL, - "username": "sa", - "password": "yourStrong(!)Password", - } - - db := new() - - connProducer := db.SQLConnectionProducer - - _, err := db.Init(context.Background(), connectionDetails, true) - if err != nil { - t.Fatalf("err: %s", err) - } - - if !connProducer.Initialized { - t.Fatal("Database should be initialized") - } - - newConf, err := db.RotateRootCredentials(context.Background(), nil) - if err != nil { - t.Fatalf("err: %v", err) - } - if newConf["password"] == "yourStrong(!)Password" { - t.Fatal("password was not updated") - } - - err = db.Close() - if err != nil { - t.Fatalf("err: %s", err) - } -} - -func TestMSSQL_SetCredentials_missingArgs(t *testing.T) { type testCase struct { - statements dbplugin.Statements - userConfig dbplugin.StaticUserConfig + req newdbplugin.InitializeRequest } tests := map[string]testCase{ - "empty rotation statements": { - statements: dbplugin.Statements{ - Rotation: nil, - }, - userConfig: dbplugin.StaticUserConfig{ - Username: "testuser", - Password: "password", + "happy path": { + req: newdbplugin.InitializeRequest{ + Config: map[string]interface{}{ + "connection_url": connURL, + }, + VerifyConnection: true, }, }, - "empty username": { - statements: dbplugin.Statements{ - Rotation: []string{` - ALTER LOGIN [{{username}}] WITH PASSWORD = '{{password}}';`, + "max_open_connections set": { + newdbplugin.InitializeRequest{ + Config: map[string]interface{}{ + "connection_url": connURL, + "max_open_connections": "5", }, - }, - userConfig: dbplugin.StaticUserConfig{ - Username: "", - Password: "password", - }, - }, - "empty password": { - statements: dbplugin.Statements{ - Rotation: []string{` - ALTER LOGIN [{{username}}] WITH PASSWORD = '{{password}}';`, - }, - }, - userConfig: dbplugin.StaticUserConfig{ - Username: "testuser", - Password: "", + VerifyConnection: true, }, }, } @@ -167,33 +48,165 @@ func TestMSSQL_SetCredentials_missingArgs(t *testing.T) { for name, test := range tests { t.Run(name, func(t *testing.T) { db := new() + dbtesting.AssertInitialize(t, db, test.req) + defer dbtesting.AssertClose(t, db) - username, password, err := db.SetCredentials(context.Background(), test.statements, test.userConfig) - if err == nil { - t.Fatalf("expected err, got nil") - } - if username != "" { - t.Fatalf("expected empty username, got [%s]", username) - } - if password != "" { - t.Fatalf("expected empty password, got [%s]", password) + if !db.Initialized { + t.Fatal("Database should be initialized") } }) } } -func TestMSSQL_SetCredentials(t *testing.T) { +func TestNewUser(t *testing.T) { + cleanup, connURL := mssqlhelper.PrepareMSSQLTestContainer(t) + defer cleanup() + type testCase struct { - rotationStmts []string + req newdbplugin.NewUserRequest + usernameRegex string + expectErr bool + assertUser func(t testing.TB, connURL, username, password string) } tests := map[string]testCase{ - "empty rotation statements": { - rotationStmts: []string{}, - }, "username rotation": { - rotationStmts: []string{` - ALTER LOGIN [{{username}}] WITH PASSWORD = '{{password}}';`, + "no creation statements": { + req: newdbplugin.NewUserRequest{ + UsernameConfig: newdbplugin.UsernameMetadata{ + DisplayName: "test", + RoleName: "test", + }, + Statements: newdbplugin.Statements{}, + Password: "AG4qagho-dsvZ", + Expiration: time.Now().Add(1 * time.Second), }, + usernameRegex: "^$", + expectErr: true, + assertUser: assertCredsDoNotExist, + }, + "with creation statements": { + req: newdbplugin.NewUserRequest{ + UsernameConfig: newdbplugin.UsernameMetadata{ + DisplayName: "test", + RoleName: "test", + }, + Statements: newdbplugin.Statements{ + Commands: []string{testMSSQLRole}, + }, + Password: "AG4qagho-dsvZ", + Expiration: time.Now().Add(1 * time.Second), + }, + usernameRegex: "^v-test-test-[a-zA-Z0-9]{20}-[0-9]{10}$", + expectErr: false, + assertUser: assertCredsExist, + }, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + usernameRe, err := regexp.Compile(test.usernameRegex) + if err != nil { + t.Fatalf("failed to compile username regex %q: %s", test.usernameRegex, err) + } + + initReq := newdbplugin.InitializeRequest{ + Config: map[string]interface{}{ + "connection_url": connURL, + }, + VerifyConnection: true, + } + + db := new() + dbtesting.AssertInitialize(t, db, initReq) + defer dbtesting.AssertClose(t, db) + + createResp, err := db.NewUser(context.Background(), test.req) + if test.expectErr && err == nil { + t.Fatalf("err expected, got nil") + } + if !test.expectErr && err != nil { + t.Fatalf("no error expected, got: %s", err) + } + + if !usernameRe.MatchString(createResp.Username) { + t.Fatalf("Generated username %q did not match regex %q", createResp.Username, test.usernameRegex) + } + + // Protect against future fields that aren't specified + expectedResp := newdbplugin.NewUserResponse{ + Username: createResp.Username, + } + if !reflect.DeepEqual(createResp, expectedResp) { + t.Fatalf("Fields missing from expected response: Actual: %#v", createResp) + } + + test.assertUser(t, connURL, createResp.Username, test.req.Password) + }) + } +} + +func TestUpdateUser_password(t *testing.T) { + type testCase struct { + req newdbplugin.UpdateUserRequest + expectErr bool + expectedPassword string + } + + dbUser := "vaultuser" + initPassword := "p4$sw0rd" + + tests := map[string]testCase{ + "missing password": { + req: newdbplugin.UpdateUserRequest{ + Username: dbUser, + Password: &newdbplugin.ChangePassword{ + NewPassword: "", + Statements: newdbplugin.Statements{}, + }, + }, + expectErr: true, + expectedPassword: initPassword, + }, + "empty rotation statements": { + req: newdbplugin.UpdateUserRequest{ + Username: dbUser, + Password: &newdbplugin.ChangePassword{ + NewPassword: "N90gkKLy8$angf", + Statements: newdbplugin.Statements{}, + }, + }, + expectErr: false, + expectedPassword: "N90gkKLy8$angf", + }, + "username rotation": { + req: newdbplugin.UpdateUserRequest{ + Username: dbUser, + Password: &newdbplugin.ChangePassword{ + NewPassword: "N90gkKLy8$angf", + Statements: newdbplugin.Statements{ + Commands: []string{ + "ALTER LOGIN [{{username}}] WITH PASSWORD = '{{password}}'", + }, + }, + }, + }, + expectErr: false, + expectedPassword: "N90gkKLy8$angf", + }, + "bad statements": { + req: newdbplugin.UpdateUserRequest{ + Username: dbUser, + Password: &newdbplugin.ChangePassword{ + NewPassword: "N90gkKLy8$angf", + Statements: newdbplugin.Statements{ + Commands: []string{ + "ahosh98asjdffs", + }, + }, + }, + }, + expectErr: true, + expectedPassword: initPassword, }, } @@ -202,124 +215,101 @@ func TestMSSQL_SetCredentials(t *testing.T) { cleanup, connURL := mssqlhelper.PrepareMSSQLTestContainer(t) defer cleanup() - connectionDetails := map[string]interface{}{ - "connection_url": connURL, + initReq := newdbplugin.InitializeRequest{ + Config: map[string]interface{}{ + "connection_url": connURL, + }, + VerifyConnection: true, } db := new() + dbtesting.AssertInitialize(t, db, initReq) + defer dbtesting.AssertClose(t, db) - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - - _, err := db.Init(ctx, connectionDetails, true) - if err != nil { - t.Fatalf("err: %s", err) - } - - dbUser := "vaultstatictest" - initPassword := "p4$sw0rd" createTestMSSQLUser(t, connURL, dbUser, initPassword, testMSSQLLogin) - if err := testCredsExist(t, connURL, dbUser, initPassword); err != nil { - t.Fatalf("Could not connect with initial credentials: %s", err) + assertCredsExist(t, connURL, dbUser, initPassword) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + updateResp, err := db.UpdateUser(ctx, test.req) + if test.expectErr && err == nil { + t.Fatalf("err expected, got nil") + } + if !test.expectErr && err != nil { + t.Fatalf("no error expected, got: %s", err) } - statements := dbplugin.Statements{ - Rotation: test.rotationStmts, - } - - newPassword, err := db.GenerateCredentials(context.Background()) - if err != nil { - t.Fatal(err) - } - - usernameConfig := dbplugin.StaticUserConfig{ - Username: dbUser, - Password: newPassword, - } - - username, password, err := db.SetCredentials(ctx, statements, usernameConfig) - if err != nil { - t.Fatalf("err: %s", err) - } - - if err := testCredsExist(t, connURL, username, password); err != nil { - t.Fatalf("Could not connect with new credentials: %s", err) - } - - if err := testCredsExist(t, connURL, username, initPassword); err == nil { - t.Fatalf("Should not be able to connect with initial credentials") + // Protect against future fields that aren't specified + expectedResp := newdbplugin.UpdateUserResponse{} + if !reflect.DeepEqual(updateResp, expectedResp) { + t.Fatalf("Fields missing from expected response: Actual: %#v", updateResp) } + assertCredsExist(t, connURL, dbUser, test.expectedPassword) }) } - } -func TestMSSQL_RevokeUser(t *testing.T) { +func TestDeleteUser(t *testing.T) { cleanup, connURL := mssqlhelper.PrepareMSSQLTestContainer(t) defer cleanup() - connectionDetails := map[string]interface{}{ - "connection_url": connURL, + dbUser := "vaultuser" + initPassword := "p4$sw0rd" + + initReq := newdbplugin.InitializeRequest{ + Config: map[string]interface{}{ + "connection_url": connURL, + }, + VerifyConnection: true, } db := new() - _, err := db.Init(context.Background(), connectionDetails, true) + dbtesting.AssertInitialize(t, db, initReq) + defer dbtesting.AssertClose(t, db) + + createTestMSSQLUser(t, connURL, dbUser, initPassword, testMSSQLLogin) + + assertCredsExist(t, connURL, dbUser, initPassword) + + deleteReq := newdbplugin.DeleteUserRequest{ + Username: dbUser, + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + deleteResp, err := db.DeleteUser(ctx, deleteReq) if err != nil { - t.Fatalf("err: %s", err) + t.Fatalf("Failed to delete user: %s", err) } - statements := dbplugin.Statements{ - Creation: []string{testMSSQLRole}, + // Protect against future fields that aren't specified + expectedResp := newdbplugin.DeleteUserResponse{} + if !reflect.DeepEqual(deleteResp, expectedResp) { + t.Fatalf("Fields missing from expected response: Actual: %#v", deleteResp) } - usernameConfig := dbplugin.UsernameConfig{ - DisplayName: "test", - RoleName: "test", - } + assertCredsDoNotExist(t, connURL, dbUser, initPassword) +} - username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second)) +func assertCredsExist(t testing.TB, connURL, username, password string) { + t.Helper() + err := testCredsExist(connURL, username, password) if err != nil { - t.Fatalf("err: %s", err) - } - - if err = testCredsExist(t, connURL, username, password); err != nil { - t.Fatalf("Could not connect with new credentials: %s", err) - } - - // Test default revoke statements - err = db.RevokeUser(context.Background(), statements, username) - if err != nil { - t.Fatalf("err: %s", err) - } - - if err := testCredsExist(t, connURL, username, password); err == nil { - t.Fatal("Credentials were not revoked") - } - - username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second)) - if err != nil { - t.Fatalf("err: %s", err) - } - - if err = testCredsExist(t, connURL, username, password); err != nil { - t.Fatalf("Could not connect with new credentials: %s", err) - } - - // Test custom revoke statement - statements.Revocation = []string{testMSSQLDrop} - err = db.RevokeUser(context.Background(), statements, username) - if err != nil { - t.Fatalf("err: %s", err) - } - - if err := testCredsExist(t, connURL, username, password); err == nil { - t.Fatal("Credentials were not revoked") + t.Fatalf("Unable to log in as %q: %s", username, err) } } -func testCredsExist(t testing.TB, connURL, username, password string) error { +func assertCredsDoNotExist(t testing.TB, connURL, username, password string) { + t.Helper() + err := testCredsExist(connURL, username, password) + if err == nil { + t.Fatalf("Able to log in when it shouldn't") + } +} + +func testCredsExist(connURL, username, password string) error { // Log in with the new creds parts := strings.Split(connURL, "@") connURL = fmt.Sprintf("sqlserver://%s:%s@%s", username, password, parts[1]) @@ -332,7 +322,6 @@ func testCredsExist(t testing.TB, connURL, username, password string) error { } func createTestMSSQLUser(t *testing.T, connURL string, username, password, query string) { - db, err := sql.Open("mssql", connURL) defer db.Close() if err != nil {