Validate physical MySQL database and table config values before using them (#9189)

* Validate database & table names prior to using it in SQL
This commit is contained in:
Michael Golowka 2020-06-12 11:08:56 -06:00 committed by GitHub
parent 889c9d6f06
commit 8d022cbe9c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 173 additions and 27 deletions

View File

@ -15,8 +15,10 @@ import (
"strings"
"sync"
"time"
"unicode"
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-multierror"
metrics "github.com/armon/go-metrics"
mysql "github.com/go-sql-driver/mysql"
@ -59,15 +61,21 @@ func NewMySQLBackend(conf map[string]string, logger log.Logger) (physical.Backen
return nil, err
}
database, ok := conf["database"]
if !ok {
database := conf["database"]
if database == "" {
database = "vault"
}
table, ok := conf["table"]
if !ok {
table := conf["table"]
if table == "" {
table = "vault"
}
dbTable := "`" + database + "`.`" + table + "`"
err = validateDBTable(database, table)
if err != nil {
return nil, err
}
dbTable := fmt.Sprintf("`%s`.`%s`", database, table)
maxParStr, ok := conf["max_parallel"]
var maxParInt int
@ -193,6 +201,67 @@ func NewMySQLBackend(conf map[string]string, logger log.Logger) (physical.Backen
return m, nil
}
// validateDBTable to prevent SQL injection attacks. This ensures that the database and table names only have valid
// characters in them. MySQL allows for more characters that this will allow, but there isn't an easy way of
// representing the full Unicode Basic Multilingual Plane to check against.
// https://dev.mysql.com/doc/refman/5.7/en/identifiers.html
func validateDBTable(db, table string) (err error) {
merr := &multierror.Error{}
merr = multierror.Append(merr, wrapErr("invalid database: %w", validate(db)))
merr = multierror.Append(merr, wrapErr("invalid table: %w", validate(table)))
return merr.ErrorOrNil()
}
func validate(name string) (err error) {
if name == "" {
return fmt.Errorf("missing name")
}
// From: https://dev.mysql.com/doc/refman/5.7/en/identifiers.html
// - Permitted characters in quoted identifiers include the full Unicode Basic Multilingual Plane (BMP), except U+0000:
// ASCII: U+0001 .. U+007F
// Extended: U+0080 .. U+FFFF
// - ASCII NUL (U+0000) and supplementary characters (U+10000 and higher) are not permitted in quoted or unquoted identifiers.
// - Identifiers may begin with a digit but unless quoted may not consist solely of digits.
// - Database, table, and column names cannot end with space characters.
//
// We are explicitly excluding all space characters (it's easier to deal with)
// The name will be quoted, so the all-digit requirement doesn't apply
runes := []rune(name)
validationErr := fmt.Errorf("invalid character found: can only include printable, non-space characters between [0x0001-0xFFFF]")
for _, r := range runes {
// U+0000 Explicitly disallowed
if r == 0x0000 {
return fmt.Errorf("invalid character: cannot include 0x0000")
}
// Cannot be above 0xFFFF
if r > 0xFFFF {
return fmt.Errorf("invalid character: cannot include any characters above 0xFFFF")
}
if r == '`' {
return fmt.Errorf("invalid character: cannot include '`' character")
}
if r == '\'' || r == '"' {
return fmt.Errorf("invalid character: cannot include quotes")
}
// We are excluding non-printable characters (not mentioned in the docs)
if !unicode.IsPrint(r) {
return validationErr
}
// We are excluding space characters (not mentioned in the docs)
if unicode.IsSpace(r) {
return validationErr
}
}
return nil
}
func wrapErr(message string, err error) error {
if err == nil {
return nil
}
return fmt.Errorf(message, err)
}
func NewMySQLClient(conf map[string]string, logger log.Logger) (*sql.DB, error) {
var err error

View File

@ -43,11 +43,11 @@ func TestMySQLPlaintextCatch(t *testing.T) {
logger := logging.NewVaultLogger(log.Debug)
NewMySQLBackend(map[string]string{
"address": address,
"database": database,
"table": table,
"username": username,
"password": password,
"address": address,
"database": database,
"table": table,
"username": username,
"password": password,
"plaintext_connection_allowed": "false",
}, logger)
@ -82,11 +82,11 @@ func TestMySQLBackend(t *testing.T) {
logger := logging.NewVaultLogger(log.Debug)
b, err := NewMySQLBackend(map[string]string{
"address": address,
"database": database,
"table": table,
"username": username,
"password": password,
"address": address,
"database": database,
"table": table,
"username": username,
"password": password,
"plaintext_connection_allowed": "true",
}, logger)
@ -128,12 +128,12 @@ func TestMySQLHABackend(t *testing.T) {
// Run vault tests
logger := logging.NewVaultLogger(log.Debug)
config := map[string]string{
"address": address,
"database": database,
"table": table,
"username": username,
"password": password,
"ha_enabled": "true",
"address": address,
"database": database,
"table": table,
"username": username,
"password": password,
"ha_enabled": "true",
"plaintext_connection_allowed": "true",
}
@ -176,12 +176,12 @@ func TestMySQLHABackend_LockFailPanic(t *testing.T) {
table := "test"
logger := logging.NewVaultLogger(log.Debug)
config := map[string]string{
"address": cfg.Addr,
"database": cfg.DBName,
"table": table,
"username": cfg.User,
"password": cfg.Passwd,
"ha_enabled": "true",
"address": cfg.Addr,
"database": cfg.DBName,
"table": table,
"username": cfg.User,
"password": cfg.Passwd,
"ha_enabled": "true",
"plaintext_connection_allowed": "true",
}
@ -265,3 +265,80 @@ func TestMySQLHABackend_LockFailPanic(t *testing.T) {
t.Fatalf("expected error, got none")
}
}
func TestValidateDBTable(t *testing.T) {
type testCase struct {
database string
table string
expectErr bool
}
tests := map[string]testCase{
"empty database & table": {"", "", true},
"empty database": {"", "a", true},
"empty table": {"a", "", true},
"ascii database": {"abcde", "a", false},
"ascii table": {"a", "abcde", false},
"ascii database & table": {"abcde", "abcde", false},
"only whitespace db": {" ", "a", true},
"only whitespace table": {"a", " ", true},
"whitespace prefix db": {" bcde", "a", true},
"whitespace middle db": {"ab de", "a", true},
"whitespace suffix db": {"abcd ", "a", true},
"whitespace prefix table": {"a", " bcde", true},
"whitespace middle table": {"a", "ab de", true},
"whitespace suffix table": {"a", "abcd ", true},
"backtick prefix db": {"`bcde", "a", true},
"backtick middle db": {"ab`de", "a", true},
"backtick suffix db": {"abcd`", "a", true},
"backtick prefix table": {"a", "`bcde", true},
"backtick middle table": {"a", "ab`de", true},
"backtick suffix table": {"a", "abcd`", true},
"single quote prefix db": {"'bcde", "a", true},
"single quote middle db": {"ab'de", "a", true},
"single quote suffix db": {"abcd'", "a", true},
"single quote prefix table": {"a", "'bcde", true},
"single quote middle table": {"a", "ab'de", true},
"single quote suffix table": {"a", "abcd'", true},
"double quote prefix db": {`"bcde`, "a", true},
"double quote middle db": {`ab"de`, "a", true},
"double quote suffix db": {`abcd"`, "a", true},
"double quote prefix table": {"a", `"bcde`, true},
"double quote middle table": {"a", `ab"de`, true},
"double quote suffix table": {"a", `abcd"`, true},
"0x0000 prefix db": {str(0x0000, 'b', 'c'), "a", true},
"0x0000 middle db": {str('a', 0x0000, 'c'), "a", true},
"0x0000 suffix db": {str('a', 'b', 0x0000), "a", true},
"0x0000 prefix table": {"a", str(0x0000, 'b', 'c'), true},
"0x0000 middle table": {"a", str('a', 0x0000, 'c'), true},
"0x0000 suffix table": {"a", str('a', 'b', 0x0000), true},
"unicode > 0xFFFF prefix db": {str(0x10000, 'b', 'c'), "a", true},
"unicode > 0xFFFF middle db": {str('a', 0x10000, 'c'), "a", true},
"unicode > 0xFFFF suffix db": {str('a', 'b', 0x10000), "a", true},
"unicode > 0xFFFF prefix table": {"a", str(0x10000, 'b', 'c'), true},
"unicode > 0xFFFF middle table": {"a", str('a', 0x10000, 'c'), true},
"unicode > 0xFFFF suffix table": {"a", str('a', 'b', 0x10000), true},
"non-printable prefix db": {str(0x0001, 'b', 'c'), "a", true},
"non-printable middle db": {str('a', 0x0001, 'c'), "a", true},
"non-printable suffix db": {str('a', 'b', 0x0001), "a", true},
"non-printable prefix table": {"a", str(0x0001, 'b', 'c'), true},
"non-printable middle table": {"a", str('a', 0x0001, 'c'), true},
"non-printable suffix table": {"a", str('a', 'b', 0x0001), true},
}
for name, test := range tests {
t.Run(name, func(t *testing.T) {
err := validateDBTable(test.database, test.table)
if test.expectErr && err == nil {
t.Fatalf("err expected, got nil")
}
if !test.expectErr && err != nil {
t.Fatalf("no error expected, got: %s", err)
}
})
}
}
func str(r ...rune) string {
return string(r)
}