1888323243
This is part 1 of 4 for renaming the `newdbplugin` package. This copies the existing package to the new location but keeps the current one in place so we can migrate the existing references over more easily.
748 lines
21 KiB
Go
748 lines
21 KiB
Go
package postgresql
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"fmt"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/hashicorp/vault/helper/testhelpers/postgresql"
|
|
dbplugin "github.com/hashicorp/vault/sdk/database/dbplugin/v5"
|
|
dbtesting "github.com/hashicorp/vault/sdk/database/dbplugin/v5/testing"
|
|
)
|
|
|
|
func getPostgreSQL(t *testing.T, options map[string]interface{}) (*PostgreSQL, func()) {
|
|
cleanup, connURL := postgresql.PrepareTestContainer(t, "latest")
|
|
|
|
connectionDetails := map[string]interface{}{
|
|
"connection_url": connURL,
|
|
}
|
|
for k, v := range options {
|
|
connectionDetails[k] = v
|
|
}
|
|
|
|
req := dbplugin.InitializeRequest{
|
|
Config: connectionDetails,
|
|
VerifyConnection: true,
|
|
}
|
|
|
|
db := new()
|
|
dbtesting.AssertInitialize(t, db, req)
|
|
|
|
if !db.Initialized {
|
|
t.Fatal("Database should be initialized")
|
|
}
|
|
return db, cleanup
|
|
}
|
|
|
|
func TestPostgreSQL_Initialize(t *testing.T) {
|
|
db, cleanup := getPostgreSQL(t, map[string]interface{}{
|
|
"max_open_connections": 5,
|
|
})
|
|
defer cleanup()
|
|
|
|
if err := db.Close(); err != nil {
|
|
t.Fatalf("err: %s", err)
|
|
}
|
|
}
|
|
|
|
func TestPostgreSQL_InitializeWithStringVals(t *testing.T) {
|
|
db, cleanup := getPostgreSQL(t, map[string]interface{}{
|
|
"max_open_connections": "5",
|
|
})
|
|
defer cleanup()
|
|
|
|
if err := db.Close(); err != nil {
|
|
t.Fatalf("err: %s", err)
|
|
}
|
|
}
|
|
|
|
func TestPostgreSQL_NewUser(t *testing.T) {
|
|
type testCase struct {
|
|
req dbplugin.NewUserRequest
|
|
expectErr bool
|
|
credsAssertion credsAssertion
|
|
}
|
|
|
|
tests := map[string]testCase{
|
|
"no creation statements": {
|
|
req: dbplugin.NewUserRequest{
|
|
UsernameConfig: dbplugin.UsernameMetadata{
|
|
DisplayName: "test",
|
|
RoleName: "test",
|
|
},
|
|
// No statements
|
|
Password: "somesecurepassword",
|
|
Expiration: time.Now().Add(1 * time.Minute),
|
|
},
|
|
expectErr: true,
|
|
credsAssertion: assertCredsDoNotExist,
|
|
},
|
|
"admin name": {
|
|
req: dbplugin.NewUserRequest{
|
|
UsernameConfig: dbplugin.UsernameMetadata{
|
|
DisplayName: "test",
|
|
RoleName: "test",
|
|
},
|
|
Statements: dbplugin.Statements{
|
|
Commands: []string{`
|
|
CREATE ROLE "{{name}}" WITH
|
|
LOGIN
|
|
PASSWORD '{{password}}'
|
|
VALID UNTIL '{{expiration}}';
|
|
GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{name}}";`,
|
|
},
|
|
},
|
|
Password: "somesecurepassword",
|
|
Expiration: time.Now().Add(1 * time.Minute),
|
|
},
|
|
expectErr: false,
|
|
credsAssertion: assertCredsExist,
|
|
},
|
|
"admin username": {
|
|
req: dbplugin.NewUserRequest{
|
|
UsernameConfig: dbplugin.UsernameMetadata{
|
|
DisplayName: "test",
|
|
RoleName: "test",
|
|
},
|
|
Statements: dbplugin.Statements{
|
|
Commands: []string{`
|
|
CREATE ROLE "{{username}}" WITH
|
|
LOGIN
|
|
PASSWORD '{{password}}'
|
|
VALID UNTIL '{{expiration}}';
|
|
GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{username}}";`,
|
|
},
|
|
},
|
|
Password: "somesecurepassword",
|
|
Expiration: time.Now().Add(1 * time.Minute),
|
|
},
|
|
expectErr: false,
|
|
credsAssertion: assertCredsExist,
|
|
},
|
|
"read only name": {
|
|
req: dbplugin.NewUserRequest{
|
|
UsernameConfig: dbplugin.UsernameMetadata{
|
|
DisplayName: "test",
|
|
RoleName: "test",
|
|
},
|
|
Statements: dbplugin.Statements{
|
|
Commands: []string{`
|
|
CREATE ROLE "{{name}}" WITH
|
|
LOGIN
|
|
PASSWORD '{{password}}'
|
|
VALID UNTIL '{{expiration}}';
|
|
GRANT SELECT ON ALL TABLES IN SCHEMA public TO "{{name}}";
|
|
GRANT SELECT ON ALL SEQUENCES IN SCHEMA public TO "{{name}}";`,
|
|
},
|
|
},
|
|
Password: "somesecurepassword",
|
|
Expiration: time.Now().Add(1 * time.Minute),
|
|
},
|
|
expectErr: false,
|
|
credsAssertion: assertCredsExist,
|
|
},
|
|
"read only username": {
|
|
req: dbplugin.NewUserRequest{
|
|
UsernameConfig: dbplugin.UsernameMetadata{
|
|
DisplayName: "test",
|
|
RoleName: "test",
|
|
},
|
|
Statements: dbplugin.Statements{
|
|
Commands: []string{`
|
|
CREATE ROLE "{{username}}" WITH
|
|
LOGIN
|
|
PASSWORD '{{password}}'
|
|
VALID UNTIL '{{expiration}}';
|
|
GRANT SELECT ON ALL TABLES IN SCHEMA public TO "{{username}}";
|
|
GRANT SELECT ON ALL SEQUENCES IN SCHEMA public TO "{{username}}";`,
|
|
},
|
|
},
|
|
Password: "somesecurepassword",
|
|
Expiration: time.Now().Add(1 * time.Minute),
|
|
},
|
|
expectErr: false,
|
|
credsAssertion: assertCredsExist,
|
|
},
|
|
// https://github.com/hashicorp/vault/issues/6098
|
|
"reproduce GH-6098": {
|
|
req: dbplugin.NewUserRequest{
|
|
UsernameConfig: dbplugin.UsernameMetadata{
|
|
DisplayName: "test",
|
|
RoleName: "test",
|
|
},
|
|
Statements: dbplugin.Statements{
|
|
Commands: []string{
|
|
// NOTE: "rolname" in the following line is not a typo.
|
|
"DO $$ BEGIN IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname='my_role') THEN CREATE ROLE my_role; END IF; END $$",
|
|
},
|
|
},
|
|
Password: "somesecurepassword",
|
|
Expiration: time.Now().Add(1 * time.Minute),
|
|
},
|
|
expectErr: false,
|
|
credsAssertion: assertCredsDoNotExist,
|
|
},
|
|
"reproduce issue with template": {
|
|
req: dbplugin.NewUserRequest{
|
|
UsernameConfig: dbplugin.UsernameMetadata{
|
|
DisplayName: "test",
|
|
RoleName: "test",
|
|
},
|
|
Statements: dbplugin.Statements{
|
|
Commands: []string{
|
|
`DO $$ BEGIN IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname='my_role') THEN CREATE ROLE "{{username}}"; END IF; END $$`,
|
|
},
|
|
},
|
|
Password: "somesecurepassword",
|
|
Expiration: time.Now().Add(1 * time.Minute),
|
|
},
|
|
expectErr: false,
|
|
credsAssertion: assertCredsDoNotExist,
|
|
},
|
|
"large block statements": {
|
|
req: dbplugin.NewUserRequest{
|
|
UsernameConfig: dbplugin.UsernameMetadata{
|
|
DisplayName: "test",
|
|
RoleName: "test",
|
|
},
|
|
Statements: dbplugin.Statements{
|
|
Commands: newUserLargeBlockStatements,
|
|
},
|
|
Password: "somesecurepassword",
|
|
Expiration: time.Now().Add(1 * time.Minute),
|
|
},
|
|
expectErr: false,
|
|
credsAssertion: assertCredsExist,
|
|
},
|
|
}
|
|
|
|
// Shared test container for speed - there should not be any overlap between the tests
|
|
db, cleanup := getPostgreSQL(t, nil)
|
|
defer cleanup()
|
|
|
|
for name, test := range tests {
|
|
t.Run(name, func(t *testing.T) {
|
|
// Give a timeout just in case the test decides to be problematic
|
|
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
|
defer cancel()
|
|
|
|
resp, err := db.NewUser(ctx, test.req)
|
|
if test.expectErr && err == nil {
|
|
t.Fatalf("err expected, got nil")
|
|
}
|
|
if !test.expectErr && err != nil {
|
|
t.Fatalf("no error expected, got: %s", err)
|
|
}
|
|
|
|
test.credsAssertion(t, db.ConnectionURL, resp.Username, test.req.Password)
|
|
|
|
// Ensure that the role doesn't expire immediately
|
|
time.Sleep(2 * time.Second)
|
|
|
|
test.credsAssertion(t, db.ConnectionURL, resp.Username, test.req.Password)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestUpdateUser_Password(t *testing.T) {
|
|
type testCase struct {
|
|
statements []string
|
|
expectErr bool
|
|
credsAssertion credsAssertion
|
|
}
|
|
|
|
tests := map[string]testCase{
|
|
"default statements": {
|
|
statements: nil,
|
|
expectErr: false,
|
|
credsAssertion: assertCredsExist,
|
|
},
|
|
"explicit default statements": {
|
|
statements: []string{defaultChangePasswordStatement},
|
|
expectErr: false,
|
|
credsAssertion: assertCredsExist,
|
|
},
|
|
"name instead of username": {
|
|
statements: []string{`ALTER ROLE "{{name}}" WITH PASSWORD '{{password}}';`},
|
|
expectErr: false,
|
|
credsAssertion: assertCredsExist,
|
|
},
|
|
"bad statements": {
|
|
statements: []string{`asdofyas8uf77asoiajv`},
|
|
expectErr: true,
|
|
credsAssertion: assertCredsDoNotExist,
|
|
},
|
|
}
|
|
|
|
// Shared test container for speed - there should not be any overlap between the tests
|
|
db, cleanup := getPostgreSQL(t, nil)
|
|
defer cleanup()
|
|
|
|
for name, test := range tests {
|
|
t.Run(name, func(t *testing.T) {
|
|
initialPass := "myreallysecurepassword"
|
|
createReq := dbplugin.NewUserRequest{
|
|
UsernameConfig: dbplugin.UsernameMetadata{
|
|
DisplayName: "test",
|
|
RoleName: "test",
|
|
},
|
|
Statements: dbplugin.Statements{
|
|
Commands: []string{createAdminUser},
|
|
},
|
|
Password: initialPass,
|
|
Expiration: time.Now().Add(2 * time.Second),
|
|
}
|
|
createResp := dbtesting.AssertNewUser(t, db, createReq)
|
|
|
|
assertCredsExist(t, db.ConnectionURL, createResp.Username, initialPass)
|
|
|
|
newPass := "somenewpassword"
|
|
updateReq := dbplugin.UpdateUserRequest{
|
|
Username: createResp.Username,
|
|
Password: &dbplugin.ChangePassword{
|
|
NewPassword: newPass,
|
|
Statements: dbplugin.Statements{
|
|
Commands: test.statements,
|
|
},
|
|
},
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer cancel()
|
|
_, err := db.UpdateUser(ctx, updateReq)
|
|
if test.expectErr && err == nil {
|
|
t.Fatalf("err expected, got nil")
|
|
}
|
|
if !test.expectErr && err != nil {
|
|
t.Fatalf("no error expected, got: %s", err)
|
|
}
|
|
|
|
test.credsAssertion(t, db.ConnectionURL, createResp.Username, newPass)
|
|
})
|
|
}
|
|
|
|
t.Run("user does not exist", func(t *testing.T) {
|
|
newPass := "somenewpassword"
|
|
updateReq := dbplugin.UpdateUserRequest{
|
|
Username: "missing-user",
|
|
Password: &dbplugin.ChangePassword{
|
|
NewPassword: newPass,
|
|
Statements: dbplugin.Statements{},
|
|
},
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer cancel()
|
|
_, err := db.UpdateUser(ctx, updateReq)
|
|
if err == nil {
|
|
t.Fatalf("err expected, got nil")
|
|
}
|
|
|
|
assertCredsDoNotExist(t, db.ConnectionURL, updateReq.Username, newPass)
|
|
})
|
|
}
|
|
|
|
func TestUpdateUser_Expiration(t *testing.T) {
|
|
type testCase struct {
|
|
initialExpiration time.Time
|
|
newExpiration time.Time
|
|
expectedExpiration time.Time
|
|
statements []string
|
|
expectErr bool
|
|
}
|
|
|
|
now := time.Now()
|
|
tests := map[string]testCase{
|
|
"no statements": {
|
|
initialExpiration: now.Add(1 * time.Minute),
|
|
newExpiration: now.Add(5 * time.Minute),
|
|
expectedExpiration: now.Add(5 * time.Minute),
|
|
statements: nil,
|
|
expectErr: false,
|
|
},
|
|
"default statements with name": {
|
|
initialExpiration: now.Add(1 * time.Minute),
|
|
newExpiration: now.Add(5 * time.Minute),
|
|
expectedExpiration: now.Add(5 * time.Minute),
|
|
statements: []string{defaultExpirationStatement},
|
|
expectErr: false,
|
|
},
|
|
"default statements with username": {
|
|
initialExpiration: now.Add(1 * time.Minute),
|
|
newExpiration: now.Add(5 * time.Minute),
|
|
expectedExpiration: now.Add(5 * time.Minute),
|
|
statements: []string{`ALTER ROLE "{{username}}" VALID UNTIL '{{expiration}}';`},
|
|
expectErr: false,
|
|
},
|
|
"bad statements": {
|
|
initialExpiration: now.Add(1 * time.Minute),
|
|
newExpiration: now.Add(5 * time.Minute),
|
|
expectedExpiration: now.Add(1 * time.Minute),
|
|
statements: []string{"ladshfouay09sgj"},
|
|
expectErr: true,
|
|
},
|
|
}
|
|
|
|
// Shared test container for speed - there should not be any overlap between the tests
|
|
db, cleanup := getPostgreSQL(t, nil)
|
|
defer cleanup()
|
|
|
|
for name, test := range tests {
|
|
t.Run(name, func(t *testing.T) {
|
|
password := "myreallysecurepassword"
|
|
initialExpiration := test.initialExpiration.Truncate(time.Second)
|
|
createReq := dbplugin.NewUserRequest{
|
|
UsernameConfig: dbplugin.UsernameMetadata{
|
|
DisplayName: "test",
|
|
RoleName: "test",
|
|
},
|
|
Statements: dbplugin.Statements{
|
|
Commands: []string{createAdminUser},
|
|
},
|
|
Password: password,
|
|
Expiration: initialExpiration,
|
|
}
|
|
createResp := dbtesting.AssertNewUser(t, db, createReq)
|
|
|
|
assertCredsExist(t, db.ConnectionURL, createResp.Username, password)
|
|
|
|
actualExpiration := getExpiration(t, db, createResp.Username)
|
|
if actualExpiration.IsZero() {
|
|
t.Fatalf("Initial expiration is zero but should be set")
|
|
}
|
|
if !actualExpiration.Equal(initialExpiration) {
|
|
t.Fatalf("Actual expiration: %s Expected expiration: %s", actualExpiration, initialExpiration)
|
|
}
|
|
|
|
newExpiration := test.newExpiration.Truncate(time.Second)
|
|
updateReq := dbplugin.UpdateUserRequest{
|
|
Username: createResp.Username,
|
|
Expiration: &dbplugin.ChangeExpiration{
|
|
NewExpiration: newExpiration,
|
|
Statements: dbplugin.Statements{
|
|
Commands: test.statements,
|
|
},
|
|
},
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer cancel()
|
|
_, err := db.UpdateUser(ctx, updateReq)
|
|
if test.expectErr && err == nil {
|
|
t.Fatalf("err expected, got nil")
|
|
}
|
|
if !test.expectErr && err != nil {
|
|
t.Fatalf("no error expected, got: %s", err)
|
|
}
|
|
|
|
expectedExpiration := test.expectedExpiration.Truncate(time.Second)
|
|
actualExpiration = getExpiration(t, db, createResp.Username)
|
|
if !actualExpiration.Equal(expectedExpiration) {
|
|
t.Fatalf("Actual expiration: %s Expected expiration: %s", actualExpiration, expectedExpiration)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func getExpiration(t testing.TB, db *PostgreSQL, username string) time.Time {
|
|
t.Helper()
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer cancel()
|
|
|
|
query := fmt.Sprintf("select valuntil from pg_catalog.pg_user where usename = '%s'", username)
|
|
conn, err := db.getConnection(ctx)
|
|
if err != nil {
|
|
t.Fatalf("Failed to get connection to database: %s", err)
|
|
}
|
|
|
|
stmt, err := conn.PrepareContext(ctx, query)
|
|
if err != nil {
|
|
t.Fatalf("Failed to prepare statement: %s", err)
|
|
}
|
|
defer stmt.Close()
|
|
|
|
rows, err := stmt.QueryContext(ctx)
|
|
if err != nil {
|
|
t.Fatalf("Failed to execute query to get expiration: %s", err)
|
|
}
|
|
|
|
if !rows.Next() {
|
|
return time.Time{} // No expiration
|
|
}
|
|
rawExp := ""
|
|
err = rows.Scan(&rawExp)
|
|
if err != nil {
|
|
t.Fatalf("Unable to get raw expiration: %s", err)
|
|
}
|
|
if rawExp == "" {
|
|
return time.Time{} // No expiration
|
|
}
|
|
exp, err := time.Parse(time.RFC3339, rawExp)
|
|
if err != nil {
|
|
t.Fatalf("Failed to parse expiration %q: %s", rawExp, err)
|
|
}
|
|
return exp
|
|
}
|
|
|
|
func TestDeleteUser(t *testing.T) {
|
|
type testCase struct {
|
|
revokeStmts []string
|
|
expectErr bool
|
|
credsAssertion credsAssertion
|
|
}
|
|
|
|
tests := map[string]testCase{
|
|
"no statements": {
|
|
revokeStmts: nil,
|
|
expectErr: false,
|
|
// Wait for a short time before failing because postgres takes a moment to finish deleting the user
|
|
credsAssertion: waitUntilCredsDoNotExist(2 * time.Second),
|
|
},
|
|
"statements with name": {
|
|
revokeStmts: []string{`
|
|
REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM "{{name}}";
|
|
REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM "{{name}}";
|
|
REVOKE USAGE ON SCHEMA public FROM "{{name}}";
|
|
|
|
DROP ROLE IF EXISTS "{{name}}";`},
|
|
expectErr: false,
|
|
// Wait for a short time before failing because postgres takes a moment to finish deleting the user
|
|
credsAssertion: waitUntilCredsDoNotExist(2 * time.Second),
|
|
},
|
|
"statements with username": {
|
|
revokeStmts: []string{`
|
|
REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM "{{username}}";
|
|
REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM "{{username}}";
|
|
REVOKE USAGE ON SCHEMA public FROM "{{username}}";
|
|
|
|
DROP ROLE IF EXISTS "{{username}}";`},
|
|
expectErr: false,
|
|
// Wait for a short time before failing because postgres takes a moment to finish deleting the user
|
|
credsAssertion: waitUntilCredsDoNotExist(2 * time.Second),
|
|
},
|
|
"bad statements": {
|
|
revokeStmts: []string{`8a9yhfoiasjff`},
|
|
expectErr: true,
|
|
// Wait for a short time before checking because postgres takes a moment to finish deleting the user
|
|
credsAssertion: assertCredsExistAfter(100 * time.Millisecond),
|
|
},
|
|
}
|
|
|
|
// Shared test container for speed - there should not be any overlap between the tests
|
|
db, cleanup := getPostgreSQL(t, nil)
|
|
defer cleanup()
|
|
|
|
for name, test := range tests {
|
|
t.Run(name, func(t *testing.T) {
|
|
password := "myreallysecurepassword"
|
|
createReq := dbplugin.NewUserRequest{
|
|
UsernameConfig: dbplugin.UsernameMetadata{
|
|
DisplayName: "test",
|
|
RoleName: "test",
|
|
},
|
|
Statements: dbplugin.Statements{
|
|
Commands: []string{createAdminUser},
|
|
},
|
|
Password: password,
|
|
Expiration: time.Now().Add(2 * time.Second),
|
|
}
|
|
createResp := dbtesting.AssertNewUser(t, db, createReq)
|
|
|
|
assertCredsExist(t, db.ConnectionURL, createResp.Username, password)
|
|
|
|
deleteReq := dbplugin.DeleteUserRequest{
|
|
Username: createResp.Username,
|
|
Statements: dbplugin.Statements{
|
|
Commands: test.revokeStmts,
|
|
},
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer cancel()
|
|
|
|
_, err := db.DeleteUser(ctx, deleteReq)
|
|
if test.expectErr && err == nil {
|
|
t.Fatalf("err expected, got nil")
|
|
}
|
|
if !test.expectErr && err != nil {
|
|
t.Fatalf("no error expected, got: %s", err)
|
|
}
|
|
|
|
test.credsAssertion(t, db.ConnectionURL, createResp.Username, password)
|
|
})
|
|
}
|
|
}
|
|
|
|
// credsAssertion is a check applied to a username/password pair against the
// database at connURL. The implementations in this file report a failed check
// through t.
type credsAssertion func(t testing.TB, connURL, username, password string)
|
|
|
|
func assertCredsExist(t testing.TB, connURL, username, password string) {
|
|
t.Helper()
|
|
err := testCredsExist(t, connURL, username, password)
|
|
if err != nil {
|
|
t.Fatalf("user does not exist: %s", err)
|
|
}
|
|
}
|
|
|
|
func assertCredsDoNotExist(t testing.TB, connURL, username, password string) {
|
|
t.Helper()
|
|
err := testCredsExist(t, connURL, username, password)
|
|
if err == nil {
|
|
t.Fatalf("user should not exist but does")
|
|
}
|
|
}
|
|
|
|
func waitUntilCredsDoNotExist(timeout time.Duration) credsAssertion {
|
|
return func(t testing.TB, connURL, username, password string) {
|
|
t.Helper()
|
|
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
|
defer cancel()
|
|
|
|
ticker := time.NewTicker(10 * time.Millisecond)
|
|
defer ticker.Stop()
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
t.Fatalf("Timed out waiting for user %s to be deleted", username)
|
|
case <-ticker.C:
|
|
err := testCredsExist(t, connURL, username, password)
|
|
if err != nil {
|
|
// Happy path
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func assertCredsExistAfter(timeout time.Duration) credsAssertion {
|
|
return func(t testing.TB, connURL, username, password string) {
|
|
t.Helper()
|
|
time.Sleep(timeout)
|
|
assertCredsExist(t, connURL, username, password)
|
|
}
|
|
}
|
|
|
|
// testCredsExist tries to log in to the database at connURL as
// username/password and returns the resulting error, or nil if the ping
// succeeds. The first "postgres:secret" occurrence in connURL is swapped for
// the supplied credentials.
func testCredsExist(t testing.TB, connURL, username, password string) error {
	t.Helper()

	// Log in with the new creds
	creds := fmt.Sprintf("%s:%s", username, password)
	userURL := strings.Replace(connURL, "postgres:secret", creds, 1)

	conn, openErr := sql.Open("postgres", userURL)
	if openErr != nil {
		return openErr
	}
	defer conn.Close()

	return conn.Ping()
}
|
|
|
|
// createAdminUser is the creation statement template used by the update and
// delete tests: it creates a login role with the templated name, password and
// expiration, then grants it all privileges on every table in the public
// schema.
const createAdminUser = `
CREATE ROLE "{{name}}" WITH
  LOGIN
  PASSWORD '{{password}}'
  VALID UNTIL '{{expiration}}';
GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{name}}";
`
|
|
|
|
// newUserLargeBlockStatements is the creation statement list for the "large
// block statements" NewUser case. The first entry is a single DO block that
// sets up the shared "foo-role" and its schema when the role is missing; the
// remaining entries create the templated user and attach it to that role.
var newUserLargeBlockStatements = []string{
	`
DO $$
BEGIN
   IF NOT EXISTS (SELECT * FROM pg_catalog.pg_roles WHERE rolname='foo-role') THEN
      CREATE ROLE "foo-role";
      CREATE SCHEMA IF NOT EXISTS foo AUTHORIZATION "foo-role";
      ALTER ROLE "foo-role" SET search_path = foo;
      GRANT TEMPORARY ON DATABASE "postgres" TO "foo-role";
      GRANT ALL PRIVILEGES ON SCHEMA foo TO "foo-role";
      GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA foo TO "foo-role";
      GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA foo TO "foo-role";
      GRANT ALL PRIVILEGES ON ALL FUNCTIONS IN SCHEMA foo TO "foo-role";
   END IF;
END
$$
`,
	`CREATE ROLE "{{name}}" WITH LOGIN PASSWORD '{{password}}' VALID UNTIL '{{expiration}}';`,
	`GRANT "foo-role" TO "{{name}}";`,
	`ALTER ROLE "{{name}}" SET search_path = foo;`,
	`GRANT CONNECT ON DATABASE "postgres" TO "{{name}}";`,
}
|
|
|
|
func TestContainsMultilineStatement(t *testing.T) {
|
|
type testCase struct {
|
|
Input string
|
|
Expected bool
|
|
}
|
|
|
|
testCases := map[string]*testCase{
|
|
"issue 6098 repro": {
|
|
Input: `DO $$ BEGIN IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname='my_role') THEN CREATE ROLE my_role; END IF; END $$`,
|
|
Expected: true,
|
|
},
|
|
"multiline with template fields": {
|
|
Input: `DO $$ BEGIN IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname="{{name}}") THEN CREATE ROLE {{name}}; END IF; END $$`,
|
|
Expected: true,
|
|
},
|
|
"docs example": {
|
|
Input: `CREATE ROLE "{{name}}" WITH LOGIN PASSWORD '{{password}}' VALID UNTIL '{{expiration}}'; \
|
|
GRANT SELECT ON ALL TABLES IN SCHEMA public TO "{{name}}";`,
|
|
Expected: false,
|
|
},
|
|
}
|
|
|
|
for tName, tCase := range testCases {
|
|
t.Run(tName, func(t *testing.T) {
|
|
if containsMultilineStatement(tCase.Input) != tCase.Expected {
|
|
t.Fatalf("%q should be %t for multiline input", tCase.Input, tCase.Expected)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestExtractQuotedStrings(t *testing.T) {
|
|
type testCase struct {
|
|
Input string
|
|
Expected []string
|
|
}
|
|
|
|
testCases := map[string]*testCase{
|
|
"no quotes": {
|
|
Input: `Five little monkeys jumping on the bed`,
|
|
Expected: []string{},
|
|
},
|
|
"two of both quote types": {
|
|
Input: `"Five" little 'monkeys' "jumping on" the' 'bed`,
|
|
Expected: []string{`"Five"`, `"jumping on"`, `'monkeys'`, `' '`},
|
|
},
|
|
"one single quote": {
|
|
Input: `Five little monkeys 'jumping on the bed`,
|
|
Expected: []string{},
|
|
},
|
|
"empty string": {
|
|
Input: ``,
|
|
Expected: []string{},
|
|
},
|
|
"templated field": {
|
|
Input: `DO $$ BEGIN IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname="{{name}}") THEN CREATE ROLE {{name}}; END IF; END $$`,
|
|
Expected: []string{`"{{name}}"`},
|
|
},
|
|
}
|
|
|
|
for tName, tCase := range testCases {
|
|
t.Run(tName, func(t *testing.T) {
|
|
results, err := extractQuotedStrings(tCase.Input)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if len(results) != len(tCase.Expected) {
|
|
t.Fatalf("%s isn't equal to %s", results, tCase.Expected)
|
|
}
|
|
for i := range results {
|
|
if results[i] != tCase.Expected[i] {
|
|
t.Fatalf(`expected %q but received %q`, tCase.Expected, results[i])
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|