open-vault/plugins/database/mssql/mssql_test.go
Christopher Swenson aa6d61477e
VAULT-5827 Don't prepare SQL queries before executing them (#15166)
VAULT-5827 Don't prepare SQL queries before executing them

We don't support proper prepared statements, i.e., preparing once and
executing many times since we do our own templating. So preparing our
queries does not really accomplish anything, and can have severe
performance impacts (see
https://github.com/hashicorp/vault-plugin-database-snowflake/issues/13
for example).

This behavior seems to have been copy-pasted for many years but not for
any particular reason that we have been able to find. First use was in
https://github.com/hashicorp/vault/pull/15

So here we switch to new methods suffixed with `Direct` to indicate
that they don't `Prepare` before running `Exec`, and switch everything
here to use those. We maintain the older methods with the existing
behavior (with `Prepare`) for backwards compatibility.
2022-04-26 12:47:06 -07:00

578 lines
15 KiB
Go

package mssql
import (
"context"
"database/sql"
"fmt"
"reflect"
"regexp"
"strings"
"testing"
"time"
mssqlhelper "github.com/hashicorp/vault/helper/testhelpers/mssql"
"github.com/hashicorp/vault/sdk/database/dbplugin/v5"
dbtesting "github.com/hashicorp/vault/sdk/database/dbplugin/v5/testing"
"github.com/hashicorp/vault/sdk/helper/dbtxn"
"github.com/stretchr/testify/assert"
)
// TestInitialize checks that a fresh plugin instance reports itself as
// initialized after being initialized (with connection verification) against
// the MSSQL instance supplied by mssqlhelper.PrepareMSSQLTestContainer, for
// several configuration variants.
func TestInitialize(t *testing.T) {
	cleanup, connURL := mssqlhelper.PrepareMSSQLTestContainer(t)
	defer cleanup()

	// The cases differ only in the plugin config; every request asks the
	// plugin to verify the connection.
	configs := map[string]map[string]interface{}{
		"happy path": {
			"connection_url": connURL,
		},
		"max_open_connections set": {
			"connection_url":       connURL,
			"max_open_connections": "5",
		},
		"contained_db set": {
			"connection_url": connURL,
			"contained_db":   true,
		},
		"contained_db set string": {
			"connection_url": connURL,
			"contained_db":   "true",
		},
	}

	for name, config := range configs {
		initReq := dbplugin.InitializeRequest{
			Config:           config,
			VerifyConnection: true,
		}
		t.Run(name, func(t *testing.T) {
			plugin := new()
			dbtesting.AssertInitializeCircleCiTest(t, plugin, initReq)
			defer dbtesting.AssertClose(t, plugin)

			if plugin.Initialized {
				return
			}
			t.Fatal("Database should be initialized")
		})
	}
}
// TestNewUser exercises the plugin's NewUser call with and without creation
// statements and with a custom username template. For each case it checks
// the error expectation, the shape of the generated username, that the
// response carries no fields other than Username, and finally whether the
// credentials can (or cannot) be used to log in.
func TestNewUser(t *testing.T) {
	cleanup, connURL := mssqlhelper.PrepareMSSQLTestContainer(t)
	defer cleanup()

	const password = "AG4qagho-dsvZ"

	type scenario struct {
		template    string                  // username_template passed at initialization
		request     dbplugin.NewUserRequest // request handed to NewUser
		wantPattern string                  // regex the returned username must match
		wantErr     bool                    // whether NewUser must fail
		verify      func(t testing.TB, connURL, username, password string)
	}

	scenarios := map[string]scenario{
		"no creation statements": {
			request: dbplugin.NewUserRequest{
				UsernameConfig: dbplugin.UsernameMetadata{
					DisplayName: "test",
					RoleName:    "test",
				},
				Statements: dbplugin.Statements{},
				Password:   password,
				Expiration: time.Now().Add(1 * time.Second),
			},
			wantPattern: "^$",
			wantErr:     true,
			verify:      assertCredsDoNotExist,
		},
		"with creation statements": {
			request: dbplugin.NewUserRequest{
				UsernameConfig: dbplugin.UsernameMetadata{
					DisplayName: "test",
					RoleName:    "test",
				},
				Statements: dbplugin.Statements{
					Commands: []string{testMSSQLRole},
				},
				Password:   password,
				Expiration: time.Now().Add(1 * time.Second),
			},
			wantPattern: "^v-test-test-[a-zA-Z0-9]{20}-[0-9]{10}$",
			wantErr:     false,
			verify:      assertCredsExist,
		},
		"custom username template": {
			template: "{{random 10}}_{{.RoleName}}.{{.DisplayName | sha256}}",
			request: dbplugin.NewUserRequest{
				UsernameConfig: dbplugin.UsernameMetadata{
					DisplayName: "tokenwithlotsofextracharactershere",
					RoleName:    "myrolenamewithlotsofextracharacters",
				},
				Statements: dbplugin.Statements{
					Commands: []string{testMSSQLRole},
				},
				Password:   password,
				Expiration: time.Now().Add(1 * time.Second),
			},
			wantPattern: "^[a-zA-Z0-9]{10}_myrolenamewithlotsofextracharacters.80d15d22dba29ddbd4994f8009b5ff7b17922c267eb49fb805a9488bd55d11f9$",
			wantErr:     false,
			verify:      assertCredsExist,
		},
	}

	for name, sc := range scenarios {
		t.Run(name, func(t *testing.T) {
			pattern, err := regexp.Compile(sc.wantPattern)
			if err != nil {
				t.Fatalf("failed to compile username regex %q: %s", sc.wantPattern, err)
			}

			plugin := new()
			dbtesting.AssertInitializeCircleCiTest(t, plugin, dbplugin.InitializeRequest{
				Config: map[string]interface{}{
					"connection_url":    connURL,
					"username_template": sc.template,
				},
				VerifyConnection: true,
			})
			defer dbtesting.AssertClose(t, plugin)

			resp, err := plugin.NewUser(context.Background(), sc.request)
			switch {
			case sc.wantErr && err == nil:
				t.Fatalf("err expected, got nil")
			case !sc.wantErr && err != nil:
				t.Fatalf("no error expected, got: %s", err)
			}

			if !pattern.MatchString(resp.Username) {
				t.Fatalf("Generated username %q did not match regex %q", resp.Username, sc.wantPattern)
			}

			// Protect against future fields that aren't specified: the
			// response must equal one holding nothing but the username.
			if want := (dbplugin.NewUserResponse{Username: resp.Username}); !reflect.DeepEqual(resp, want) {
				t.Fatalf("Fields missing from expected response: Actual: %#v", resp)
			}

			sc.verify(t, connURL, resp.Username, sc.request.Password)
		})
	}
}
// TestUpdateUser_password exercises password changes through the plugin's
// UpdateUser call. Each subtest prepares its own MSSQL instance via
// mssqlhelper, creates a login with the initial password, runs the update,
// checks which password works afterwards, and finally deletes the login
// through the plugin and checks that it can no longer log in.
func TestUpdateUser_password(t *testing.T) {
	const (
		dbUser       = "vaultuser"
		initPassword = "p4$sw0rd"
		newPassword  = "N90gkKLy8$angf"
	)

	type scenario struct {
		request      dbplugin.UpdateUserRequest
		wantErr      bool
		wantPassword string // password expected to work after the update attempt
	}

	scenarios := map[string]scenario{
		"missing password": {
			request: dbplugin.UpdateUserRequest{
				Username: dbUser,
				Password: &dbplugin.ChangePassword{
					NewPassword: "",
					Statements:  dbplugin.Statements{},
				},
			},
			wantErr:      true,
			wantPassword: initPassword,
		},
		"empty rotation statements": {
			request: dbplugin.UpdateUserRequest{
				Username: dbUser,
				Password: &dbplugin.ChangePassword{
					NewPassword: newPassword,
					Statements:  dbplugin.Statements{},
				},
			},
			wantErr:      false,
			wantPassword: newPassword,
		},
		"username rotation": {
			request: dbplugin.UpdateUserRequest{
				Username: dbUser,
				Password: &dbplugin.ChangePassword{
					NewPassword: newPassword,
					Statements: dbplugin.Statements{
						Commands: []string{
							"ALTER LOGIN [{{username}}] WITH PASSWORD = '{{password}}'",
						},
					},
				},
			},
			wantErr:      false,
			wantPassword: newPassword,
		},
		"bad statements": {
			request: dbplugin.UpdateUserRequest{
				Username: dbUser,
				Password: &dbplugin.ChangePassword{
					NewPassword: newPassword,
					Statements: dbplugin.Statements{
						Commands: []string{
							"ahosh98asjdffs",
						},
					},
				},
			},
			wantErr:      true,
			wantPassword: initPassword,
		},
	}

	for name, sc := range scenarios {
		t.Run(name, func(t *testing.T) {
			cleanup, connURL := mssqlhelper.PrepareMSSQLTestContainer(t)
			defer cleanup()

			plugin := new()
			dbtesting.AssertInitializeCircleCiTest(t, plugin, dbplugin.InitializeRequest{
				Config: map[string]interface{}{
					"connection_url": connURL,
				},
				VerifyConnection: true,
			})
			defer dbtesting.AssertClose(t, plugin)

			if err := createTestMSSQLUser(connURL, dbUser, initPassword, testMSSQLLogin); err != nil {
				t.Fatalf("Failed to create user: %s", err)
			}
			assertCredsExist(t, connURL, dbUser, initPassword)

			updateCtx, cancelUpdate := context.WithTimeout(context.Background(), 5*time.Second)
			defer cancelUpdate()

			updateResp, err := plugin.UpdateUser(updateCtx, sc.request)
			switch {
			case sc.wantErr && err == nil:
				t.Fatalf("err expected, got nil")
			case !sc.wantErr && err != nil:
				t.Fatalf("no error expected, got: %s", err)
			}

			// Protect against future fields that aren't specified
			if want := (dbplugin.UpdateUserResponse{}); !reflect.DeepEqual(updateResp, want) {
				t.Fatalf("Fields missing from expected response: Actual: %#v", updateResp)
			}
			assertCredsExist(t, connURL, dbUser, sc.wantPassword)

			// Delete user at the end of each test
			deleteCtx, cancelDelete := context.WithTimeout(context.Background(), 5*time.Second)
			defer cancelDelete()

			deleteResp, err := plugin.DeleteUser(deleteCtx, dbplugin.DeleteUserRequest{
				Username: dbUser,
			})
			if err != nil {
				t.Fatalf("Failed to delete user: %s", err)
			}

			// Protect against future fields that aren't specified
			if want := (dbplugin.DeleteUserResponse{}); !reflect.DeepEqual(deleteResp, want) {
				t.Fatalf("Fields missing from expected response: Actual: %#v", deleteResp)
			}
			assertCredsDoNotExist(t, connURL, dbUser, initPassword)
		})
	}
}
// TestDeleteUser creates a login directly on the server, removes it through
// the plugin's DeleteUser call, and checks that the response is empty and
// that the login can no longer be used.
func TestDeleteUser(t *testing.T) {
	cleanup, connURL := mssqlhelper.PrepareMSSQLTestContainer(t)
	defer cleanup()

	const (
		dbUser       = "vaultuser"
		initPassword = "p4$sw0rd"
	)

	plugin := new()
	dbtesting.AssertInitializeCircleCiTest(t, plugin, dbplugin.InitializeRequest{
		Config: map[string]interface{}{
			"connection_url": connURL,
		},
		VerifyConnection: true,
	})
	defer dbtesting.AssertClose(t, plugin)

	if err := createTestMSSQLUser(connURL, dbUser, initPassword, testMSSQLLogin); err != nil {
		t.Fatalf("Failed to create user: %s", err)
	}
	assertCredsExist(t, connURL, dbUser, initPassword)

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	deleteResp, err := plugin.DeleteUser(ctx, dbplugin.DeleteUserRequest{
		Username: dbUser,
	})
	if err != nil {
		t.Fatalf("Failed to delete user: %s", err)
	}

	// Protect against future fields that aren't specified
	if want := (dbplugin.DeleteUserResponse{}); !reflect.DeepEqual(deleteResp, want) {
		t.Fatalf("Fields missing from expected response: Actual: %#v", deleteResp)
	}
	assertCredsDoNotExist(t, connURL, dbUser, initPassword)
}
// TestDeleteUserContainedDB is the contained_db counterpart of
// TestDeleteUser: the plugin is initialized with "contained_db" set, a login
// plus database user is created directly on the server, the user is removed
// through the plugin's DeleteUser call, and the test then checks that the
// database user is gone.
func TestDeleteUserContainedDB(t *testing.T) {
	cleanup, connURL := mssqlhelper.PrepareMSSQLTestContainer(t)
	defer cleanup()

	const (
		dbUser       = "vaultuser"
		initPassword = "p4$sw0rd"
	)

	plugin := new()
	dbtesting.AssertInitializeCircleCiTest(t, plugin, dbplugin.InitializeRequest{
		Config: map[string]interface{}{
			"connection_url": connURL,
			"contained_db":   true,
		},
		VerifyConnection: true,
	})
	defer dbtesting.AssertClose(t, plugin)

	if err := createTestMSSQLUser(connURL, dbUser, initPassword, testMSSQLContainedLogin); err != nil {
		t.Fatalf("Failed to create user: %s", err)
	}
	assertCredsExist(t, connURL, dbUser, initPassword)

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	deleteResp, err := plugin.DeleteUser(ctx, dbplugin.DeleteUserRequest{
		Username: dbUser,
	})
	if err != nil {
		t.Fatalf("Failed to delete user: %s", err)
	}

	// Protect against future fields that aren't specified
	if want := (dbplugin.DeleteUserResponse{}); !reflect.DeepEqual(deleteResp, want) {
		t.Fatalf("Fields missing from expected response: Actual: %#v", deleteResp)
	}
	assertContainedDBCredsDoNotExist(t, connURL, dbUser)
}
// TestContainedDBSQLSanitization creates a login plus database user, then
// asks the plugin to delete a username carrying a trailing "]" and expects
// the server to reject it with a "Cannot alter the login" error that quotes
// the name unchanged.
//
// NOTE(review): unlike TestDeleteUserContainedDB, the plugin config here does
// not set "contained_db"; the only difference from TestSQLSanitization is the
// creation statement. Confirm whether that is intended.
func TestContainedDBSQLSanitization(t *testing.T) {
	cleanup, connURL := mssqlhelper.PrepareMSSQLTestContainer(t)
	defer cleanup()

	const (
		injectionString = "vaultuser]"
		dbUser          = "vaultuser"
		initPassword    = "p4$sw0rd"
	)

	plugin := new()
	dbtesting.AssertInitializeCircleCiTest(t, plugin, dbplugin.InitializeRequest{
		Config: map[string]interface{}{
			"connection_url": connURL,
		},
		VerifyConnection: true,
	})
	defer dbtesting.AssertClose(t, plugin)

	if err := createTestMSSQLUser(connURL, dbUser, initPassword, testMSSQLContainedLogin); err != nil {
		t.Fatalf("Failed to create user: %s", err)
	}
	assertCredsExist(t, connURL, dbUser, initPassword)

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	_, deleteErr := plugin.DeleteUser(ctx, dbplugin.DeleteUserRequest{
		Username: injectionString,
	})
	assert.EqualError(t, deleteErr, "mssql: Cannot alter the login 'vaultuser]', because it does not exist or you do not have permission.")
}
// TestSQLSanitization creates a login, then asks the plugin to delete a
// username carrying a trailing "]" and expects the server to reject it with
// a "Cannot alter the login" error that quotes the name unchanged.
func TestSQLSanitization(t *testing.T) {
	cleanup, connURL := mssqlhelper.PrepareMSSQLTestContainer(t)
	defer cleanup()

	const (
		injectionString = "vaultuser]"
		dbUser          = "vaultuser"
		initPassword    = "p4$sw0rd"
	)

	plugin := new()
	dbtesting.AssertInitializeCircleCiTest(t, plugin, dbplugin.InitializeRequest{
		Config: map[string]interface{}{
			"connection_url": connURL,
		},
		VerifyConnection: true,
	})
	defer dbtesting.AssertClose(t, plugin)

	if err := createTestMSSQLUser(connURL, dbUser, initPassword, testMSSQLLogin); err != nil {
		t.Fatalf("Failed to create user: %s", err)
	}
	assertCredsExist(t, connURL, dbUser, initPassword)

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	_, deleteErr := plugin.DeleteUser(ctx, dbplugin.DeleteUserRequest{
		Username: injectionString,
	})
	assert.EqualError(t, deleteErr, "mssql: Cannot alter the login 'vaultuser]', because it does not exist or you do not have permission.")
}
// assertCredsExist fails the test immediately unless username/password can
// be used to connect to the server behind connURL (see testCredsExist).
func assertCredsExist(t testing.TB, connURL, username, password string) {
	t.Helper()
	if loginErr := testCredsExist(connURL, username, password); loginErr != nil {
		t.Fatalf("Unable to log in as %q: %s", username, loginErr)
	}
}
// assertCredsDoNotExist fails the test immediately if username/password can
// still be used to connect to the server behind connURL (see
// testCredsExist). Any connection error counts as "credentials do not exist".
func assertCredsDoNotExist(t testing.TB, connURL, username, password string) {
	t.Helper()
	if loginErr := testCredsExist(connURL, username, password); loginErr == nil {
		t.Fatalf("Able to log in when it shouldn't")
	}
}
// assertContainedDBCredsDoNotExist asserts that the database user username
// no longer exists on the server behind connURL. It does so by attempting to
// drop the user (see testContainedDBCredsExist) and requiring the server's
// "Cannot drop the user" error for exactly that name.
//
// The expected message is built from username rather than hard-coding a
// particular user, so the helper works for any name passed in.
func assertContainedDBCredsDoNotExist(t testing.TB, connURL, username string) {
	t.Helper()
	err := testContainedDBCredsExist(connURL, username)
	want := fmt.Sprintf("mssql: Cannot drop the user '%s', because it does not exist or you do not have permission.", username)
	assert.EqualError(t, err, want)
}
// testContainedDBCredsExist probes for the database user username by trying
// to drop it over a connection opened with connURL. Despite the name this is
// destructive: a nil return means the DROP USER statement succeeded (the user
// existed and has now been removed); a non-nil error is whatever opening,
// preparing or executing the statement reported.
//
// The username is interpolated into the statement verbatim, so this helper
// must only be called with trusted, test-controlled names.
func testContainedDBCredsExist(connURL, username string) error {
	// Log in
	conn, err := sql.Open("mssql", connURL)
	if err != nil {
		return err
	}
	defer conn.Close()

	ctx := context.Background()
	dropStmt, err := conn.PrepareContext(ctx, fmt.Sprintf("DROP USER [%s]", username))
	if err != nil {
		return err
	}
	defer dropStmt.Close()

	_, err = dropStmt.ExecContext(ctx)
	return err
}
// testCredsExist reports whether username/password can be used to connect to
// the server addressed by connURL. It returns nil when a connection can be
// opened and pinged with those credentials, and a non-nil error otherwise.
//
// The credentials embedded in connURL are replaced by the ones under test:
// everything after the first "@" is kept as the host portion and prefixed
// with "sqlserver://username:password@". A connURL without an "@" yields an
// error instead of an index-out-of-range panic.
func testCredsExist(connURL, username, password string) error {
	// Log in with the new creds
	parts := strings.SplitN(connURL, "@", 2)
	if len(parts) != 2 {
		// Deliberately omit connURL from the message: it may carry a password.
		return fmt.Errorf("connection URL has no \"@\" separating credentials from host")
	}
	connURL = fmt.Sprintf("sqlserver://%s:%s@%s", username, password, parts[1])

	db, err := sql.Open("mssql", connURL)
	if err != nil {
		return err
	}
	defer db.Close()
	return db.Ping()
}
// createTestMSSQLUser runs query inside a single transaction on the server
// addressed by connURL, supplying username and password to the query under
// the "name" and "password" keys. It returns the first error encountered
// while opening the connection, starting the transaction, executing the
// query, or committing.
func createTestMSSQLUser(connURL string, username, password, query string) error {
	db, err := sql.Open("mssql", connURL)
	if err != nil {
		return err
	}
	// Register Close only after the error check: on failure sql.Open returns
	// a nil *sql.DB, and closing it would panic.
	defer db.Close()

	// Start a transaction
	ctx := context.Background()
	tx, err := db.BeginTx(ctx, nil)
	if err != nil {
		return err
	}
	// Best-effort rollback for the early-return paths; after a successful
	// Commit it has nothing left to undo, so its error is ignored.
	defer func() {
		_ = tx.Rollback()
	}()

	m := map[string]string{
		"name":     username,
		"password": password,
	}
	if err := dbtxn.ExecuteTxQueryDirect(ctx, tx, m, query); err != nil {
		return err
	}

	// Commit the transaction
	return tx.Commit()
}
// SQL statement templates used by the tests in this file. The {{name}} and
// {{password}} placeholders are filled in by whoever executes the template.
const (
	// testMSSQLRole creates a login, a database user for that login, and
	// grants the user SELECT/INSERT/UPDATE/DELETE on the dbo schema.
	testMSSQLRole = `
CREATE LOGIN [{{name}}] WITH PASSWORD = '{{password}}';
CREATE USER [{{name}}] FOR LOGIN [{{name}}];
GRANT SELECT, INSERT, UPDATE, DELETE ON SCHEMA::dbo TO [{{name}}];`

	// testMSSQLLogin creates only a server login.
	testMSSQLLogin = `
CREATE LOGIN [{{name}}] WITH PASSWORD = '{{password}}';
`

	// testMSSQLContainedLogin creates a login and a database user for it.
	testMSSQLContainedLogin = `
CREATE LOGIN [{{name}}] WITH PASSWORD = '{{password}}';
CREATE USER [{{name}}] FOR LOGIN [{{name}}];
`
)