27bb03bbc0
* adding copyright header * fix fmt and a test
347 lines
10 KiB
Go
347 lines
10 KiB
Go
// Copyright (c) HashiCorp, Inc.
|
|
// SPDX-License-Identifier: MPL-2.0
|
|
|
|
package mysql
|
|
|
|
import (
|
|
"bytes"
|
|
"os"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
log "github.com/hashicorp/go-hclog"
|
|
"github.com/hashicorp/vault/sdk/helper/logging"
|
|
"github.com/hashicorp/vault/sdk/physical"
|
|
|
|
_ "github.com/go-sql-driver/mysql"
|
|
mysql "github.com/go-sql-driver/mysql"
|
|
|
|
mysqlhelper "github.com/hashicorp/vault/helper/testhelpers/mysql"
|
|
)
|
|
|
|
func TestMySQLPlaintextCatch(t *testing.T) {
|
|
address := os.Getenv("MYSQL_ADDR")
|
|
if address == "" {
|
|
t.SkipNow()
|
|
}
|
|
|
|
database := os.Getenv("MYSQL_DB")
|
|
if database == "" {
|
|
database = "test"
|
|
}
|
|
|
|
table := os.Getenv("MYSQL_TABLE")
|
|
if table == "" {
|
|
table = "test"
|
|
}
|
|
|
|
username := os.Getenv("MYSQL_USERNAME")
|
|
password := os.Getenv("MYSQL_PASSWORD")
|
|
|
|
// Run vault tests
|
|
var buf bytes.Buffer
|
|
log.DefaultOutput = &buf
|
|
|
|
logger := logging.NewVaultLogger(log.Debug)
|
|
|
|
NewMySQLBackend(map[string]string{
|
|
"address": address,
|
|
"database": database,
|
|
"table": table,
|
|
"username": username,
|
|
"password": password,
|
|
"plaintext_connection_allowed": "false",
|
|
}, logger)
|
|
|
|
str := buf.String()
|
|
dataIdx := strings.IndexByte(str, ' ')
|
|
rest := str[dataIdx+1:]
|
|
|
|
if !strings.Contains(rest, "credentials will be sent in plaintext") {
|
|
t.Fatalf("No warning of plaintext credentials occurred")
|
|
}
|
|
}
|
|
|
|
func TestMySQLBackend(t *testing.T) {
|
|
address := os.Getenv("MYSQL_ADDR")
|
|
if address == "" {
|
|
t.SkipNow()
|
|
}
|
|
|
|
database := os.Getenv("MYSQL_DB")
|
|
if database == "" {
|
|
database = "test"
|
|
}
|
|
|
|
table := os.Getenv("MYSQL_TABLE")
|
|
if table == "" {
|
|
table = "test"
|
|
}
|
|
|
|
username := os.Getenv("MYSQL_USERNAME")
|
|
password := os.Getenv("MYSQL_PASSWORD")
|
|
|
|
// Run vault tests
|
|
logger := logging.NewVaultLogger(log.Debug)
|
|
|
|
b, err := NewMySQLBackend(map[string]string{
|
|
"address": address,
|
|
"database": database,
|
|
"table": table,
|
|
"username": username,
|
|
"password": password,
|
|
"plaintext_connection_allowed": "true",
|
|
"max_connection_lifetime": "1",
|
|
}, logger)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create new backend: %v", err)
|
|
}
|
|
|
|
defer func() {
|
|
mysql := b.(*MySQLBackend)
|
|
_, err := mysql.client.Exec("DROP TABLE IF EXISTS " + mysql.dbTable + " ," + mysql.dbLockTable)
|
|
if err != nil {
|
|
t.Fatalf("Failed to drop table: %v", err)
|
|
}
|
|
}()
|
|
|
|
physical.ExerciseBackend(t, b)
|
|
physical.ExerciseBackend_ListPrefix(t, b)
|
|
}
|
|
|
|
func TestMySQLHABackend(t *testing.T) {
|
|
address := os.Getenv("MYSQL_ADDR")
|
|
if address == "" {
|
|
t.SkipNow()
|
|
}
|
|
|
|
database := os.Getenv("MYSQL_DB")
|
|
if database == "" {
|
|
database = "test"
|
|
}
|
|
|
|
table := os.Getenv("MYSQL_TABLE")
|
|
if table == "" {
|
|
table = "test"
|
|
}
|
|
|
|
username := os.Getenv("MYSQL_USERNAME")
|
|
password := os.Getenv("MYSQL_PASSWORD")
|
|
|
|
// Run vault tests
|
|
logger := logging.NewVaultLogger(log.Debug)
|
|
config := map[string]string{
|
|
"address": address,
|
|
"database": database,
|
|
"table": table,
|
|
"username": username,
|
|
"password": password,
|
|
"ha_enabled": "true",
|
|
"plaintext_connection_allowed": "true",
|
|
}
|
|
|
|
b, err := NewMySQLBackend(config, logger)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create new backend: %v", err)
|
|
}
|
|
|
|
defer func() {
|
|
mysql := b.(*MySQLBackend)
|
|
_, err := mysql.client.Exec("DROP TABLE IF EXISTS " + mysql.dbTable + " ," + mysql.dbLockTable)
|
|
if err != nil {
|
|
t.Fatalf("Failed to drop table: %v", err)
|
|
}
|
|
}()
|
|
|
|
b2, err := NewMySQLBackend(config, logger)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create new backend: %v", err)
|
|
}
|
|
|
|
physical.ExerciseHABackend(t, b.(physical.HABackend), b2.(physical.HABackend))
|
|
}
|
|
|
|
// TestMySQLHABackend_LockFailPanic is a regression test for the panic shown in
|
|
// https://github.com/hashicorp/vault/issues/8203 and patched in
|
|
// https://github.com/hashicorp/vault/pull/8229
|
|
func TestMySQLHABackend_LockFailPanic(t *testing.T) {
|
|
cleanup, connURL := mysqlhelper.PrepareTestContainer(t, false, "secret")
|
|
|
|
cfg, err := mysql.ParseDSN(connURL)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
if err := mysqlhelper.TestCredsExist(t, connURL, cfg.User, cfg.Passwd); err != nil {
|
|
t.Fatalf("Could not connect with new credentials: %s", err)
|
|
}
|
|
|
|
table := "test"
|
|
logger := logging.NewVaultLogger(log.Debug)
|
|
config := map[string]string{
|
|
"address": cfg.Addr,
|
|
"database": cfg.DBName,
|
|
"table": table,
|
|
"username": cfg.User,
|
|
"password": cfg.Passwd,
|
|
"ha_enabled": "true",
|
|
"plaintext_connection_allowed": "true",
|
|
}
|
|
|
|
b, err := NewMySQLBackend(config, logger)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create new backend: %v", err)
|
|
}
|
|
|
|
b2, err := NewMySQLBackend(config, logger)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create new backend: %v", err)
|
|
}
|
|
|
|
b1ha := b.(physical.HABackend)
|
|
b2ha := b2.(physical.HABackend)
|
|
|
|
// Copied from ExerciseHABackend - ensuring things are normal at this point
|
|
// Get the lock
|
|
lock, err := b1ha.LockWith("foo", "bar")
|
|
if err != nil {
|
|
t.Fatalf("initial lock: %v", err)
|
|
}
|
|
|
|
// Attempt to lock
|
|
leaderCh, err := lock.Lock(nil)
|
|
if err != nil {
|
|
t.Fatalf("lock attempt 1: %v", err)
|
|
}
|
|
if leaderCh == nil {
|
|
t.Fatalf("missing leaderCh")
|
|
}
|
|
|
|
// Check the value
|
|
held, val, err := lock.Value()
|
|
if err != nil {
|
|
t.Fatalf("err: %v", err)
|
|
}
|
|
if !held {
|
|
t.Errorf("should be held")
|
|
}
|
|
if val != "bar" {
|
|
t.Errorf("expected value bar: %v", err)
|
|
}
|
|
|
|
// Second acquisition should fail
|
|
lock2, err := b2ha.LockWith("foo", "baz")
|
|
if err != nil {
|
|
t.Fatalf("lock 2: %v", err)
|
|
}
|
|
|
|
stopCh := make(chan struct{})
|
|
time.AfterFunc(10*time.Second, func() {
|
|
close(stopCh)
|
|
})
|
|
|
|
// Attempt to lock - can't lock because lock1 is held - this is normal
|
|
leaderCh2, err := lock2.Lock(stopCh)
|
|
if err != nil {
|
|
t.Fatalf("stop lock 2: %v", err)
|
|
}
|
|
if leaderCh2 != nil {
|
|
t.Errorf("should not have gotten leaderCh: %v", leaderCh2)
|
|
}
|
|
// end normal
|
|
|
|
// Clean up the database. When Lock() is called, a new connection is created
|
|
// using the configuration. If that connection cannot be created, there was a
|
|
// panic due to not returning with the connection error. Here we intentionally
|
|
// break the config for b2, so a new connection can't be made, which would
|
|
// trigger the panic shown in https://github.com/hashicorp/vault/issues/8203
|
|
cleanup()
|
|
|
|
stopCh2 := make(chan struct{})
|
|
time.AfterFunc(10*time.Second, func() {
|
|
close(stopCh2)
|
|
})
|
|
leaderCh2, err = lock2.Lock(stopCh2)
|
|
if err == nil {
|
|
t.Fatalf("expected error, got none, leaderCh2=%v", leaderCh2)
|
|
}
|
|
}
|
|
|
|
func TestValidateDBTable(t *testing.T) {
|
|
type testCase struct {
|
|
database string
|
|
table string
|
|
expectErr bool
|
|
}
|
|
|
|
tests := map[string]testCase{
|
|
"empty database & table": {"", "", true},
|
|
"empty database": {"", "a", true},
|
|
"empty table": {"a", "", true},
|
|
"ascii database": {"abcde", "a", false},
|
|
"ascii table": {"a", "abcde", false},
|
|
"ascii database & table": {"abcde", "abcde", false},
|
|
"only whitespace db": {" ", "a", true},
|
|
"only whitespace table": {"a", " ", true},
|
|
"whitespace prefix db": {" bcde", "a", true},
|
|
"whitespace middle db": {"ab de", "a", true},
|
|
"whitespace suffix db": {"abcd ", "a", true},
|
|
"whitespace prefix table": {"a", " bcde", true},
|
|
"whitespace middle table": {"a", "ab de", true},
|
|
"whitespace suffix table": {"a", "abcd ", true},
|
|
"backtick prefix db": {"`bcde", "a", true},
|
|
"backtick middle db": {"ab`de", "a", true},
|
|
"backtick suffix db": {"abcd`", "a", true},
|
|
"backtick prefix table": {"a", "`bcde", true},
|
|
"backtick middle table": {"a", "ab`de", true},
|
|
"backtick suffix table": {"a", "abcd`", true},
|
|
"single quote prefix db": {"'bcde", "a", true},
|
|
"single quote middle db": {"ab'de", "a", true},
|
|
"single quote suffix db": {"abcd'", "a", true},
|
|
"single quote prefix table": {"a", "'bcde", true},
|
|
"single quote middle table": {"a", "ab'de", true},
|
|
"single quote suffix table": {"a", "abcd'", true},
|
|
"double quote prefix db": {`"bcde`, "a", true},
|
|
"double quote middle db": {`ab"de`, "a", true},
|
|
"double quote suffix db": {`abcd"`, "a", true},
|
|
"double quote prefix table": {"a", `"bcde`, true},
|
|
"double quote middle table": {"a", `ab"de`, true},
|
|
"double quote suffix table": {"a", `abcd"`, true},
|
|
"0x0000 prefix db": {str(0x0000, 'b', 'c'), "a", true},
|
|
"0x0000 middle db": {str('a', 0x0000, 'c'), "a", true},
|
|
"0x0000 suffix db": {str('a', 'b', 0x0000), "a", true},
|
|
"0x0000 prefix table": {"a", str(0x0000, 'b', 'c'), true},
|
|
"0x0000 middle table": {"a", str('a', 0x0000, 'c'), true},
|
|
"0x0000 suffix table": {"a", str('a', 'b', 0x0000), true},
|
|
"unicode > 0xFFFF prefix db": {str(0x10000, 'b', 'c'), "a", true},
|
|
"unicode > 0xFFFF middle db": {str('a', 0x10000, 'c'), "a", true},
|
|
"unicode > 0xFFFF suffix db": {str('a', 'b', 0x10000), "a", true},
|
|
"unicode > 0xFFFF prefix table": {"a", str(0x10000, 'b', 'c'), true},
|
|
"unicode > 0xFFFF middle table": {"a", str('a', 0x10000, 'c'), true},
|
|
"unicode > 0xFFFF suffix table": {"a", str('a', 'b', 0x10000), true},
|
|
"non-printable prefix db": {str(0x0001, 'b', 'c'), "a", true},
|
|
"non-printable middle db": {str('a', 0x0001, 'c'), "a", true},
|
|
"non-printable suffix db": {str('a', 'b', 0x0001), "a", true},
|
|
"non-printable prefix table": {"a", str(0x0001, 'b', 'c'), true},
|
|
"non-printable middle table": {"a", str('a', 0x0001, 'c'), true},
|
|
"non-printable suffix table": {"a", str('a', 'b', 0x0001), true},
|
|
}
|
|
|
|
for name, test := range tests {
|
|
t.Run(name, func(t *testing.T) {
|
|
err := validateDBTable(test.database, test.table)
|
|
if test.expectErr && err == nil {
|
|
t.Fatalf("err expected, got nil")
|
|
}
|
|
if !test.expectErr && err != nil {
|
|
t.Fatalf("no error expected, got: %s", err)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// str builds a string from the given runes, in order. It lets test cases embed
// arbitrary code points (such as 0x0000 or 0x10000) in a name.
func str(runes ...rune) string {
	return string(runes)
}
|