package postgresql

import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"strings"
	"testing"
	"time"

	"github.com/hashicorp/vault/helper/testhelpers/postgresql"
	dbplugin "github.com/hashicorp/vault/sdk/database/dbplugin/v5"
	dbtesting "github.com/hashicorp/vault/sdk/database/dbplugin/v5/testing"
	"github.com/hashicorp/vault/sdk/helper/template"
	"github.com/stretchr/testify/require"
)

// getPostgreSQL starts a PostgreSQL test container and initializes a new
// plugin instance against it. Entries in options are merged into the
// connection config (overriding "connection_url" if present). The caller must
// invoke the returned cleanup function when done with the container.
func getPostgreSQL(t *testing.T, options map[string]interface{}) (*PostgreSQL, func()) {
	cleanup, connURL := postgresql.PrepareTestContainer(t, "latest")

	connectionDetails := map[string]interface{}{
		"connection_url": connURL,
	}
	for k, v := range options {
		connectionDetails[k] = v
	}

	req := dbplugin.InitializeRequest{
		Config:           connectionDetails,
		VerifyConnection: true,
	}

	db := new()
	dbtesting.AssertInitialize(t, db, req)

	if !db.Initialized {
		// The caller never receives cleanup on this path, so tear the
		// container down here instead of leaking it.
		cleanup()
		t.Fatal("Database should be initialized")
	}
	return db, cleanup
}

func TestPostgreSQL_Initialize(t *testing.T) {
	db, cleanup := getPostgreSQL(t, map[string]interface{}{
		"max_open_connections": 5,
	})
	defer cleanup()

	if err := db.Close(); err != nil {
		t.Fatalf("err: %s", err)
	}
}

func TestPostgreSQL_InitializeWithStringVals(t *testing.T) {
	db, cleanup := getPostgreSQL(t, map[string]interface{}{
		"max_open_connections": "5",
	})
	defer cleanup()

	if err := db.Close(); err != nil {
		t.Fatalf("err: %s", err)
	}
}

func TestPostgreSQL_NewUser(t *testing.T) {
	type testCase struct {
		req            dbplugin.NewUserRequest
		expectErr      bool
		credsAssertion credsAssertion
	}

	tests := map[string]testCase{
		"no creation statements": {
			req: dbplugin.NewUserRequest{
				UsernameConfig: dbplugin.UsernameMetadata{
					DisplayName: "test",
					RoleName:    "test",
				},
				// No statements
				Password:   "somesecurepassword",
				Expiration: time.Now().Add(1 * time.Minute),
			},
			expectErr: true,
			credsAssertion: assertCreds(
				assertUsernameRegex("^$"),
				assertCredsDoNotExist,
			),
		},
		"admin name": {
			req: dbplugin.NewUserRequest{
				UsernameConfig: dbplugin.UsernameMetadata{
					DisplayName: "test",
					RoleName:    "test",
				},
				Statements: dbplugin.Statements{
					Commands: []string{
						` CREATE ROLE "{{name}}" WITH LOGIN PASSWORD '{{password}}' VALID UNTIL '{{expiration}}'; GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{name}}";`,
					},
				},
				Password:   "somesecurepassword",
				Expiration: time.Now().Add(1 * time.Minute),
			},
			expectErr: false,
			credsAssertion: assertCreds(
				assertUsernameRegex("^v-test-test-[a-zA-Z0-9]{20}-[0-9]{10}$"),
				assertCredsExist,
			),
		},
		"admin username": {
			req: dbplugin.NewUserRequest{
				UsernameConfig: dbplugin.UsernameMetadata{
					DisplayName: "test",
					RoleName:    "test",
				},
				Statements: dbplugin.Statements{
					Commands: []string{
						` CREATE ROLE "{{username}}" WITH LOGIN PASSWORD '{{password}}' VALID UNTIL '{{expiration}}'; GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{username}}";`,
					},
				},
				Password:   "somesecurepassword",
				Expiration: time.Now().Add(1 * time.Minute),
			},
			expectErr: false,
			credsAssertion: assertCreds(
				assertUsernameRegex("^v-test-test-[a-zA-Z0-9]{20}-[0-9]{10}$"),
				assertCredsExist,
			),
		},
		"read only name": {
			req: dbplugin.NewUserRequest{
				UsernameConfig: dbplugin.UsernameMetadata{
					DisplayName: "test",
					RoleName:    "test",
				},
				Statements: dbplugin.Statements{
					Commands: []string{
						` CREATE ROLE "{{name}}" WITH LOGIN PASSWORD '{{password}}' VALID UNTIL '{{expiration}}'; GRANT SELECT ON ALL TABLES IN SCHEMA public TO "{{name}}"; GRANT SELECT ON ALL SEQUENCES IN SCHEMA public TO "{{name}}";`,
					},
				},
				Password:   "somesecurepassword",
				Expiration: time.Now().Add(1 * time.Minute),
			},
			expectErr: false,
			credsAssertion: assertCreds(
				assertUsernameRegex("^v-test-test-[a-zA-Z0-9]{20}-[0-9]{10}$"),
				assertCredsExist,
			),
		},
		"read only username": {
			req: dbplugin.NewUserRequest{
				UsernameConfig: dbplugin.UsernameMetadata{
					DisplayName: "test",
					RoleName:    "test",
				},
				Statements: dbplugin.Statements{
					Commands: []string{
						` CREATE ROLE "{{username}}" WITH LOGIN PASSWORD '{{password}}' VALID UNTIL '{{expiration}}'; GRANT SELECT ON ALL TABLES IN SCHEMA public TO "{{username}}"; GRANT SELECT ON ALL SEQUENCES IN SCHEMA public TO "{{username}}";`,
					},
				},
				Password:   "somesecurepassword",
				Expiration: time.Now().Add(1 * time.Minute),
			},
			expectErr: false,
			credsAssertion: assertCreds(
				assertUsernameRegex("^v-test-test-[a-zA-Z0-9]{20}-[0-9]{10}$"),
				assertCredsExist,
			),
		},
		// https://github.com/hashicorp/vault/issues/6098
		"reproduce GH-6098": {
			req: dbplugin.NewUserRequest{
				UsernameConfig: dbplugin.UsernameMetadata{
					DisplayName: "test",
					RoleName:    "test",
				},
				Statements: dbplugin.Statements{
					Commands: []string{
						// NOTE: "rolname" in the following line is not a typo.
						"DO $$ BEGIN IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname='my_role') THEN CREATE ROLE my_role; END IF; END $$",
					},
				},
				Password:   "somesecurepassword",
				Expiration: time.Now().Add(1 * time.Minute),
			},
			expectErr: false,
			credsAssertion: assertCreds(
				assertUsernameRegex("^v-test-test-[a-zA-Z0-9]{20}-[0-9]{10}$"),
				assertCredsDoNotExist,
			),
		},
		"reproduce issue with template": {
			req: dbplugin.NewUserRequest{
				UsernameConfig: dbplugin.UsernameMetadata{
					DisplayName: "test",
					RoleName:    "test",
				},
				Statements: dbplugin.Statements{
					Commands: []string{
						`DO $$ BEGIN IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname='my_role') THEN CREATE ROLE "{{username}}"; END IF; END $$`,
					},
				},
				Password:   "somesecurepassword",
				Expiration: time.Now().Add(1 * time.Minute),
			},
			expectErr: false,
			credsAssertion: assertCreds(
				assertUsernameRegex("^v-test-test-[a-zA-Z0-9]{20}-[0-9]{10}$"),
				assertCredsDoNotExist,
			),
		},
		"large block statements": {
			req: dbplugin.NewUserRequest{
				UsernameConfig: dbplugin.UsernameMetadata{
					DisplayName: "test",
					RoleName:    "test",
				},
				Statements: dbplugin.Statements{
					Commands: newUserLargeBlockStatements,
				},
				Password:   "somesecurepassword",
				Expiration: time.Now().Add(1 * time.Minute),
			},
			expectErr: false,
			credsAssertion: assertCreds(
				assertUsernameRegex("^v-test-test-[a-zA-Z0-9]{20}-[0-9]{10}$"),
				assertCredsExist,
			),
		},
	}

	// Shared test container for speed - there should not be any overlap between the tests
	db, cleanup := getPostgreSQL(t, nil)
	defer cleanup()

	for name, test := range tests {
		t.Run(name, func(t *testing.T) {
			// Give a timeout just in case the test decides to be problematic
			ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
			defer cancel()

			resp, err := db.NewUser(ctx, test.req)
			if test.expectErr && err == nil {
				t.Fatalf("err expected, got nil")
			}
			if !test.expectErr && err != nil {
				t.Fatalf("no error expected, got: %s", err)
			}

			test.credsAssertion(t, db.ConnectionURL, resp.Username, test.req.Password)

			// Ensure that the role doesn't expire immediately
			time.Sleep(2 * time.Second)

			test.credsAssertion(t, db.ConnectionURL, resp.Username, test.req.Password)
		})
	}
}

func TestUpdateUser_Password(t *testing.T) {
	type testCase struct {
		statements     []string
		expectErr      bool
		credsAssertion credsAssertion
	}

	tests := map[string]testCase{
		"default statements": {
			statements:     nil,
			expectErr:      false,
			credsAssertion: assertCredsExist,
		},
		"explicit default statements": {
			statements:     []string{defaultChangePasswordStatement},
			expectErr:      false,
			credsAssertion: assertCredsExist,
		},
		"name instead of username": {
			statements:     []string{`ALTER ROLE "{{name}}" WITH PASSWORD '{{password}}';`},
			expectErr:      false,
			credsAssertion: assertCredsExist,
		},
		"bad statements": {
			statements:     []string{`asdofyas8uf77asoiajv`},
			expectErr:      true,
			credsAssertion: assertCredsDoNotExist,
		},
	}

	// Shared test container for speed - there should not be any overlap between the tests
	db, cleanup := getPostgreSQL(t, nil)
	defer cleanup()

	for name, test := range tests {
		t.Run(name, func(t *testing.T) {
			initialPass := "myreallysecurepassword"
			createReq := dbplugin.NewUserRequest{
				UsernameConfig: dbplugin.UsernameMetadata{
					DisplayName: "test",
					RoleName:    "test",
				},
				Statements: dbplugin.Statements{
					Commands: []string{createAdminUser},
				},
				Password:   initialPass,
				Expiration: time.Now().Add(2 * time.Second),
			}
			createResp := dbtesting.AssertNewUser(t, db, createReq)

			assertCredsExist(t, db.ConnectionURL, createResp.Username, initialPass)

			newPass := "somenewpassword"
			updateReq := dbplugin.UpdateUserRequest{
				Username: createResp.Username,
				Password: &dbplugin.ChangePassword{
					NewPassword: newPass,
					Statements: dbplugin.Statements{
						Commands: test.statements,
					},
				},
			}

			ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
			defer cancel()
			_, err := db.UpdateUser(ctx, updateReq)
			if test.expectErr && err == nil {
				t.Fatalf("err expected, got nil")
			}
			if !test.expectErr && err != nil {
				t.Fatalf("no error expected, got: %s", err)
			}

			test.credsAssertion(t, db.ConnectionURL, createResp.Username, newPass)
		})
	}

	t.Run("user does not exist", func(t *testing.T) {
		newPass := "somenewpassword"
		updateReq := dbplugin.UpdateUserRequest{
			Username: "missing-user",
			Password: &dbplugin.ChangePassword{
				NewPassword: newPass,
				Statements:  dbplugin.Statements{},
			},
		}

		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		_, err := db.UpdateUser(ctx, updateReq)
		if err == nil {
			t.Fatalf("err expected, got nil")
		}

		assertCredsDoNotExist(t, db.ConnectionURL, updateReq.Username, newPass)
	})
}

func TestUpdateUser_Expiration(t *testing.T) {
	type testCase struct {
		initialExpiration  time.Time
		newExpiration      time.Time
		expectedExpiration time.Time
		statements         []string
		expectErr          bool
	}

	now := time.Now()
	tests := map[string]testCase{
		"no statements": {
			initialExpiration:  now.Add(1 * time.Minute),
			newExpiration:      now.Add(5 * time.Minute),
			expectedExpiration: now.Add(5 * time.Minute),
			statements:         nil,
			expectErr:          false,
		},
		"default statements with name": {
			initialExpiration:  now.Add(1 * time.Minute),
			newExpiration:      now.Add(5 * time.Minute),
			expectedExpiration: now.Add(5 * time.Minute),
			statements:         []string{defaultExpirationStatement},
			expectErr:          false,
		},
		"default statements with username": {
			initialExpiration:  now.Add(1 * time.Minute),
			newExpiration:      now.Add(5 * time.Minute),
			expectedExpiration: now.Add(5 * time.Minute),
			statements:         []string{`ALTER ROLE "{{username}}" VALID UNTIL '{{expiration}}';`},
			expectErr:          false,
		},
		"bad statements": {
			initialExpiration:  now.Add(1 * time.Minute),
			newExpiration:      now.Add(5 * time.Minute),
			expectedExpiration: now.Add(1 * time.Minute),
			statements:         []string{"ladshfouay09sgj"},
			expectErr:          true,
		},
	}

	// Shared test container for speed - there should not be any overlap between the tests
	db, cleanup := getPostgreSQL(t, nil)
	defer cleanup()

	for name, test := range tests {
		t.Run(name, func(t *testing.T) {
			password := "myreallysecurepassword"
			// Postgres stores VALID UNTIL with limited precision; compare at
			// whole-second granularity.
			initialExpiration := test.initialExpiration.Truncate(time.Second)
			createReq := dbplugin.NewUserRequest{
				UsernameConfig: dbplugin.UsernameMetadata{
					DisplayName: "test",
					RoleName:    "test",
				},
				Statements: dbplugin.Statements{
					Commands: []string{createAdminUser},
				},
				Password:   password,
				Expiration: initialExpiration,
			}
			createResp := dbtesting.AssertNewUser(t, db, createReq)

			assertCredsExist(t, db.ConnectionURL, createResp.Username, password)

			actualExpiration := getExpiration(t, db, createResp.Username)
			if actualExpiration.IsZero() {
				t.Fatalf("Initial expiration is zero but should be set")
			}
			if !actualExpiration.Equal(initialExpiration) {
				t.Fatalf("Actual expiration: %s Expected expiration: %s", actualExpiration, initialExpiration)
			}

			newExpiration := test.newExpiration.Truncate(time.Second)
			updateReq := dbplugin.UpdateUserRequest{
				Username: createResp.Username,
				Expiration: &dbplugin.ChangeExpiration{
					NewExpiration: newExpiration,
					Statements: dbplugin.Statements{
						Commands: test.statements,
					},
				},
			}

			ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
			defer cancel()
			_, err := db.UpdateUser(ctx, updateReq)
			if test.expectErr && err == nil {
				t.Fatalf("err expected, got nil")
			}
			if !test.expectErr && err != nil {
				t.Fatalf("no error expected, got: %s", err)
			}

			expectedExpiration := test.expectedExpiration.Truncate(time.Second)
			actualExpiration = getExpiration(t, db, createResp.Username)
			if !actualExpiration.Equal(expectedExpiration) {
				t.Fatalf("Actual expiration: %s Expected expiration: %s", actualExpiration, expectedExpiration)
			}
		})
	}
}

// getExpiration reads the VALID UNTIL timestamp of the given role from
// pg_catalog.pg_user. It returns the zero time.Time when the role does not
// exist or has no expiration set, and fails the test on any query/parse error.
func getExpiration(t testing.TB, db *PostgreSQL, username string) time.Time {
	t.Helper()
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	conn, err := db.getConnection(ctx)
	if err != nil {
		t.Fatalf("Failed to get connection to database: %s", err)
	}

	// Bind the username as a parameter instead of splicing it into the SQL
	// text, so quotes in a role name cannot break (or alter) the query.
	const query = "select valuntil from pg_catalog.pg_user where usename = $1"

	// valuntil is NULL for roles without an expiration. Scanning NULL into a
	// plain string fails, so use NullString to detect that case explicitly.
	var rawExp sql.NullString
	err = conn.QueryRowContext(ctx, query, username).Scan(&rawExp)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		return time.Time{} // No such role, so no expiration
	case err != nil:
		t.Fatalf("Failed to execute query to get expiration: %s", err)
	}

	if !rawExp.Valid || rawExp.String == "" {
		return time.Time{} // No expiration
	}

	exp, err := time.Parse(time.RFC3339, rawExp.String)
	if err != nil {
		t.Fatalf("Failed to parse expiration %q: %s", rawExp.String, err)
	}
	return exp
}

func TestDeleteUser(t *testing.T) {
	type testCase struct {
		revokeStmts    []string
		expectErr      bool
		credsAssertion credsAssertion
	}

	tests := map[string]testCase{
		"no statements": {
			revokeStmts: nil,
			expectErr:   false,
			// Wait for a short time before failing because postgres takes a moment to finish deleting the user
			credsAssertion: waitUntilCredsDoNotExist(2 * time.Second),
		},
		"statements with name": {
			revokeStmts: []string{` REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM "{{name}}"; REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM "{{name}}"; REVOKE USAGE ON SCHEMA public FROM "{{name}}"; DROP ROLE IF EXISTS "{{name}}";`},
			expectErr:   false,
			// Wait for a short time before failing because postgres takes a moment to finish deleting the user
			credsAssertion: waitUntilCredsDoNotExist(2 * time.Second),
		},
		"statements with username": {
			revokeStmts: []string{` REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM "{{username}}"; REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM "{{username}}"; REVOKE USAGE ON SCHEMA public FROM "{{username}}"; DROP ROLE IF EXISTS "{{username}}";`},
			expectErr:   false,
			// Wait for a short time before failing because postgres takes a moment to finish deleting the user
			credsAssertion: waitUntilCredsDoNotExist(2 * time.Second),
		},
		"bad statements": {
			revokeStmts: []string{`8a9yhfoiasjff`},
			expectErr:   true,
			// Wait for a short time before checking because postgres takes a moment to finish deleting the user
			credsAssertion: assertCredsExistAfter(100 * time.Millisecond),
		},
	}

	// Shared test container for speed - there should not be any overlap between the tests
	db, cleanup := getPostgreSQL(t, nil)
	defer cleanup()

	for name, test := range tests {
		t.Run(name, func(t *testing.T) {
			password := "myreallysecurepassword"
			createReq := dbplugin.NewUserRequest{
				UsernameConfig: dbplugin.UsernameMetadata{
					DisplayName: "test",
					RoleName:    "test",
				},
				Statements: dbplugin.Statements{
					Commands: []string{createAdminUser},
				},
				Password:   password,
				Expiration: time.Now().Add(2 * time.Second),
			}
			createResp := dbtesting.AssertNewUser(t, db, createReq)

			assertCredsExist(t, db.ConnectionURL, createResp.Username, password)

			deleteReq := dbplugin.DeleteUserRequest{
				Username: createResp.Username,
				Statements: dbplugin.Statements{
					Commands: test.revokeStmts,
				},
			}

			ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
			defer cancel()
			_, err := db.DeleteUser(ctx, deleteReq)
			if test.expectErr && err == nil {
				t.Fatalf("err expected, got nil")
			}
			if !test.expectErr && err != nil {
				t.Fatalf("no error expected, got: %s", err)
			}

			test.credsAssertion(t, db.ConnectionURL, createResp.Username, password)
		})
	}
}

// credsAssertion checks some property of a set of credentials against the
// database reachable at connURL, failing the test if it does not hold.
type credsAssertion func(t testing.TB, connURL, username, password string)

// assertCreds combines several assertions into one that runs them in order.
func assertCreds(assertions ...credsAssertion) credsAssertion {
	return func(t testing.TB, connURL, username, password string) {
		t.Helper()
		for _, assertion := range assertions {
			assertion(t, connURL, username, password)
		}
	}
}

// assertUsernameRegex returns an assertion that the username matches rawRegex.
func assertUsernameRegex(rawRegex string) credsAssertion {
	return func(t testing.TB, _, username, _ string) {
		t.Helper()
		require.Regexp(t, rawRegex, username)
	}
}

// assertCredsExist fails the test if the credentials cannot be used to connect.
func assertCredsExist(t testing.TB, connURL, username, password string) {
	t.Helper()
	err := testCredsExist(t, connURL, username, password)
	if err != nil {
		t.Fatalf("user does not exist: %s", err)
	}
}

// assertCredsDoNotExist fails the test if the credentials can be used to connect.
func assertCredsDoNotExist(t testing.TB, connURL, username, password string) {
	t.Helper()
	err := testCredsExist(t, connURL, username, password)
	if err == nil {
		t.Fatalf("user should not exist but does")
	}
}

// waitUntilCredsDoNotExist returns an assertion that polls every 10ms until the
// credentials stop working, failing the test if that does not happen within
// timeout.
func waitUntilCredsDoNotExist(timeout time.Duration) credsAssertion {
	return func(t testing.TB, connURL, username, password string) {
		t.Helper()
		ctx, cancel := context.WithTimeout(context.Background(), timeout)
		defer cancel()

		ticker := time.NewTicker(10 * time.Millisecond)
		defer ticker.Stop()
		for {
			select {
			case <-ctx.Done():
				t.Fatalf("Timed out waiting for user %s to be deleted", username)
			case <-ticker.C:
				err := testCredsExist(t, connURL, username, password)
				if err != nil {
					// Happy path
					return
				}
			}
		}
	}
}

// assertCredsExistAfter returns an assertion that sleeps for timeout and then
// requires the credentials to still work.
func assertCredsExistAfter(timeout time.Duration) credsAssertion {
	return func(t testing.TB, connURL, username, password string) {
		t.Helper()
		time.Sleep(timeout)
		assertCredsExist(t, connURL, username, password)
	}
}

// testCredsExist attempts to log in with the given credentials by swapping
// them into connURL (which is expected to contain "postgres:secret") and
// pinging the server. It returns nil if the login succeeds.
func testCredsExist(t testing.TB, connURL, username, password string) error {
	t.Helper()
	// Log in with the new creds
	connURL = strings.Replace(connURL, "postgres:secret", fmt.Sprintf("%s:%s", username, password), 1)
	db, err := sql.Open("postgres", connURL)
	if err != nil {
		return err
	}
	defer db.Close()
	return db.Ping()
}

const createAdminUser = ` CREATE ROLE "{{name}}" WITH LOGIN PASSWORD '{{password}}' VALID UNTIL '{{expiration}}'; GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{name}}"; `

var newUserLargeBlockStatements = []string{
	` DO $$ BEGIN IF NOT EXISTS (SELECT * FROM pg_catalog.pg_roles WHERE rolname='foo-role') THEN CREATE ROLE "foo-role"; CREATE SCHEMA IF NOT EXISTS foo AUTHORIZATION "foo-role"; ALTER ROLE "foo-role" SET search_path = foo; GRANT TEMPORARY ON DATABASE "postgres" TO "foo-role"; GRANT ALL PRIVILEGES ON SCHEMA foo TO "foo-role"; GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA foo TO "foo-role"; GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA foo TO "foo-role"; GRANT ALL PRIVILEGES ON ALL FUNCTIONS IN SCHEMA foo TO "foo-role"; END IF; END $$ `,
	`CREATE ROLE "{{name}}" WITH LOGIN PASSWORD '{{password}}' VALID UNTIL '{{expiration}}';`,
	`GRANT "foo-role" TO "{{name}}";`,
	`ALTER ROLE "{{name}}" SET search_path = foo;`,
	`GRANT CONNECT ON DATABASE "postgres" TO "{{name}}";`,
}

func TestContainsMultilineStatement(t *testing.T) {
	type testCase struct {
		Input    string
		Expected bool
	}

	testCases := map[string]*testCase{
		"issue 6098 repro": {
			Input:    `DO $$ BEGIN IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname='my_role') THEN CREATE ROLE my_role; END IF; END $$`,
			Expected: true,
		},
		"multiline with template fields": {
			Input:    `DO $$ BEGIN IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname="{{name}}") THEN CREATE ROLE {{name}}; END IF; END $$`,
			Expected: true,
		},
		"docs example": {
			Input:    `CREATE ROLE "{{name}}" WITH LOGIN PASSWORD '{{password}}' VALID UNTIL '{{expiration}}'; \ GRANT SELECT ON ALL TABLES IN SCHEMA public TO "{{name}}";`,
			Expected: false,
		},
	}

	for tName, tCase := range testCases {
		t.Run(tName, func(t *testing.T) {
			if containsMultilineStatement(tCase.Input) != tCase.Expected {
				t.Fatalf("%q should be %t for multiline input", tCase.Input, tCase.Expected)
			}
		})
	}
}

func TestExtractQuotedStrings(t *testing.T) {
	type testCase struct {
		Input    string
		Expected []string
	}

	testCases := map[string]*testCase{
		"no quotes": {
			Input:    `Five little monkeys jumping on the bed`,
			Expected: []string{},
		},
		"two of both quote types": {
			Input:    `"Five" little 'monkeys' "jumping on" the' 'bed`,
			Expected: []string{`"Five"`, `"jumping on"`, `'monkeys'`, `' '`},
		},
		"one single quote": {
			Input:    `Five little monkeys 'jumping on the bed`,
			Expected: []string{},
		},
		"empty string": {
			Input:    ``,
			Expected: []string{},
		},
		"templated field": {
			Input:    `DO $$ BEGIN IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname="{{name}}") THEN CREATE ROLE {{name}}; END IF; END $$`,
			Expected: []string{`"{{name}}"`},
		},
	}

	for tName, tCase := range testCases {
		t.Run(tName, func(t *testing.T) {
			results, err := extractQuotedStrings(tCase.Input)
			if err != nil {
				t.Fatal(err)
			}
			if len(results) != len(tCase.Expected) {
				t.Fatalf("%q isn't equal to %q", results, tCase.Expected)
			}
			for i := range results {
				if results[i] != tCase.Expected[i] {
					// Report the mismatching element, not the whole expected slice.
					t.Fatalf("expected %q but received %q at index %d", tCase.Expected[i], results[i], i)
				}
			}
		})
	}
}

func TestUsernameGeneration(t *testing.T) {
	type testCase struct {
		data          dbplugin.UsernameMetadata
		expectedRegex string
	}

	tests := map[string]testCase{
		"simple display and role names": {
			data: dbplugin.UsernameMetadata{
				DisplayName: "token",
				RoleName:    "myrole",
			},
			expectedRegex: `v-token-myrole-[a-zA-Z0-9]{20}-[0-9]{10}`,
		},
		"display name has dash": {
			data: dbplugin.UsernameMetadata{
				DisplayName: "token-foo",
				RoleName:    "myrole",
			},
			expectedRegex: `v-token-fo-myrole-[a-zA-Z0-9]{20}-[0-9]{10}`,
		},
		"display name has underscore": {
			data: dbplugin.UsernameMetadata{
				DisplayName: "token_foo",
				RoleName:    "myrole",
			},
			expectedRegex: `v-token_fo-myrole-[a-zA-Z0-9]{20}-[0-9]{10}`,
		},
		"display name has period": {
			data: dbplugin.UsernameMetadata{
				DisplayName: "token.foo",
				RoleName:    "myrole",
			},
			expectedRegex: `v-token.fo-myrole-[a-zA-Z0-9]{20}-[0-9]{10}`,
		},
		"role name has dash": {
			data: dbplugin.UsernameMetadata{
				DisplayName: "token",
				RoleName:    "myrole-foo",
			},
			expectedRegex: `v-token-myrole-f-[a-zA-Z0-9]{20}-[0-9]{10}`,
		},
		"role name has underscore": {
			data: dbplugin.UsernameMetadata{
				DisplayName: "token",
				RoleName:    "myrole_foo",
			},
			expectedRegex: `v-token-myrole_f-[a-zA-Z0-9]{20}-[0-9]{10}`,
		},
		"role name has period": {
			data: dbplugin.UsernameMetadata{
				DisplayName: "token",
				RoleName:    "myrole.foo",
			},
			expectedRegex: `v-token-myrole.f-[a-zA-Z0-9]{20}-[0-9]{10}`,
		},
	}

	for name, test := range tests {
		t.Run(fmt.Sprintf("new-%s", name), func(t *testing.T) {
			up, err := template.NewTemplate(
				template.Template(defaultUserNameTemplate),
			)
			require.NoError(t, err)

			// The template includes random components, so generate many
			// usernames to exercise them.
			for i := 0; i < 1000; i++ {
				username, err := up.Generate(test.data)
				require.NoError(t, err)
				require.Regexp(t, test.expectedRegex, username)
			}
		})
	}
}

func TestNewUser_CustomUsername(t *testing.T) {
	cleanup, connURL := postgresql.PrepareTestContainer(t, "latest")
	defer cleanup()

	type testCase struct {
		usernameTemplate string
		newUserData      dbplugin.UsernameMetadata
		expectedRegex    string
	}

	tests := map[string]testCase{
		"default template": {
			usernameTemplate: "",
			newUserData: dbplugin.UsernameMetadata{
				DisplayName: "displayname",
				RoleName:    "longrolename",
			},
			expectedRegex: "^v-displayn-longrole-[a-zA-Z0-9]{20}-[0-9]{10}$",
		},
		"explicit default template": {
			usernameTemplate: defaultUserNameTemplate,
			newUserData: dbplugin.UsernameMetadata{
				DisplayName: "displayname",
				RoleName:    "longrolename",
			},
			expectedRegex: "^v-displayn-longrole-[a-zA-Z0-9]{20}-[0-9]{10}$",
		},
		"unique template": {
			usernameTemplate: "foo-bar",
			newUserData: dbplugin.UsernameMetadata{
				DisplayName: "displayname",
				RoleName:    "longrolename",
			},
			expectedRegex: "^foo-bar$",
		},
		"custom prefix": {
			usernameTemplate: "foobar-{{.DisplayName | truncate 8}}-{{.RoleName | truncate 8}}-{{random 20}}-{{unix_time}}",
			newUserData: dbplugin.UsernameMetadata{
				DisplayName: "displayname",
				RoleName:    "longrolename",
			},
			expectedRegex: "^foobar-displayn-longrole-[a-zA-Z0-9]{20}-[0-9]{10}$",
		},
		"totally custom template": {
			usernameTemplate: "foobar_{{random 10}}-{{.RoleName | uppercase}}.{{unix_time}}x{{.DisplayName | truncate 5}}",
			newUserData: dbplugin.UsernameMetadata{
				DisplayName: "displayname",
				RoleName:    "longrolename",
			},
			expectedRegex: `^foobar_[a-zA-Z0-9]{10}-LONGROLENAME\.[0-9]{10}xdispl$`,
		},
	}

	for name, test := range tests {
		t.Run(name, func(t *testing.T) {
			initReq := dbplugin.InitializeRequest{
				Config: map[string]interface{}{
					"connection_url":    connURL,
					"username_template": test.usernameTemplate,
				},
				VerifyConnection: true,
			}

			db := new()

			initCtx, initCancel := context.WithTimeout(context.Background(), 2*time.Second)
			defer initCancel()
			_, err := db.Initialize(initCtx, initReq)
			require.NoError(t, err)

			// Each subtest builds its own plugin instance; close it so its
			// connections are released before the shared container is torn down.
			defer func() {
				if err := db.Close(); err != nil {
					t.Errorf("failed to close database: %s", err)
				}
			}()

			newUserReq := dbplugin.NewUserRequest{
				UsernameConfig: test.newUserData,
				Statements: dbplugin.Statements{
					Commands: []string{
						` CREATE ROLE "{{name}}" WITH LOGIN PASSWORD '{{password}}' VALID UNTIL '{{expiration}}'; GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{name}}";`,
					},
				},
				Password:   "myReally-S3curePassword",
				Expiration: time.Now().Add(1 * time.Hour),
			}

			newUserCtx, newUserCancel := context.WithTimeout(context.Background(), 2*time.Second)
			defer newUserCancel()
			newUserResp, err := db.NewUser(newUserCtx, newUserReq)
			require.NoError(t, err)

			require.Regexp(t, test.expectedRegex, newUserResp.Username)
		})
	}
}