diff --git a/changelog/19591.txt b/changelog/19591.txt new file mode 100644 index 000000000..f15d3979a --- /dev/null +++ b/changelog/19591.txt @@ -0,0 +1,3 @@ +```release-note:improvement +core: validate name identifiers in mssql physical storage backend prior use +``` diff --git a/physical/mssql/mssql.go b/physical/mssql/mssql.go index bbcb68332..2859a65ef 100644 --- a/physical/mssql/mssql.go +++ b/physical/mssql/mssql.go @@ -7,6 +7,7 @@ import ( "context" "database/sql" "fmt" + "regexp" "sort" "strconv" "strings" @@ -21,6 +22,7 @@ import ( // Verify MSSQLBackend satisfies the correct interfaces var _ physical.Backend = (*MSSQLBackend)(nil) +var identifierRegex = regexp.MustCompile(`^[\p{L}_][\p{L}\p{Nd}@#$_]*$`) type MSSQLBackend struct { dbTable string @@ -30,6 +32,13 @@ type MSSQLBackend struct { permitPool *physical.PermitPool } +func isInvalidIdentifier(name string) bool { + if !identifierRegex.MatchString(name) { + return true + } + return false +} + func NewMSSQLBackend(conf map[string]string, logger log.Logger) (physical.Backend, error) { username, ok := conf["username"] if !ok { @@ -71,11 +80,19 @@ func NewMSSQLBackend(conf map[string]string, logger log.Logger) (physical.Backen database = "Vault" } + if isInvalidIdentifier(database) { + return nil, fmt.Errorf("invalid database name") + } + table, ok := conf["table"] if !ok { table = "Vault" } + if isInvalidIdentifier(table) { + return nil, fmt.Errorf("invalid table name") + } + appname, ok := conf["appname"] if !ok { appname = "Vault" @@ -96,6 +113,10 @@ func NewMSSQLBackend(conf map[string]string, logger log.Logger) (physical.Backen schema = "dbo" } + if isInvalidIdentifier(schema) { + return nil, fmt.Errorf("invalid schema name") + } + connectionString := fmt.Sprintf("server=%s;app name=%s;connection timeout=%s;log=%s", server, appname, connectionTimeout, logLevel) if username != "" { connectionString += ";user id=" + username @@ -116,18 +137,17 @@ func NewMSSQLBackend(conf map[string]string, logger log.Logger) (physical.Backen db.SetMaxOpenConns(maxParInt) - if _, err := db.Exec("IF NOT EXISTS(SELECT * FROM sys.databases WHERE name = '" + database + "') CREATE DATABASE " + database); err != nil { + if _, err := db.Exec("IF NOT EXISTS(SELECT * FROM sys.databases WHERE name = ?) CREATE DATABASE "+database, database); err != nil { return nil, fmt.Errorf("failed to create mssql database: %w", err) } dbTable := database + "." + schema + "." + table - createQuery := "IF NOT EXISTS(SELECT 1 FROM " + database + ".INFORMATION_SCHEMA.TABLES WHERE TABLE_TYPE='BASE TABLE' AND TABLE_NAME='" + table + "' AND TABLE_SCHEMA='" + schema + - "') CREATE TABLE " + dbTable + " (Path VARCHAR(512) PRIMARY KEY, Value VARBINARY(MAX))" + createQuery := "IF NOT EXISTS(SELECT 1 FROM " + database + ".INFORMATION_SCHEMA.TABLES WHERE TABLE_TYPE='BASE TABLE' AND TABLE_NAME=? AND TABLE_SCHEMA=?) CREATE TABLE " + dbTable + " (Path VARCHAR(512) PRIMARY KEY, Value VARBINARY(MAX))" if schema != "dbo" { var num int - err = db.QueryRow("SELECT 1 FROM " + database + ".sys.schemas WHERE name = '" + schema + "'").Scan(&num) + err = db.QueryRow("SELECT 1 FROM "+database+".sys.schemas WHERE name = ?", schema).Scan(&num) switch { case err == sql.ErrNoRows: @@ -140,7 +160,7 @@ func NewMSSQLBackend(conf map[string]string, logger log.Logger) (physical.Backen } } - if _, err := db.Exec(createQuery); err != nil { + if _, err := db.Exec(createQuery, table, schema); err != nil { return nil, fmt.Errorf("failed to create mssql table: %w", err) } diff --git a/physical/mssql/mssql_test.go b/physical/mssql/mssql_test.go index fc40a7722..2324ff5c0 100644 --- a/physical/mssql/mssql_test.go +++ b/physical/mssql/mssql_test.go @@ -7,13 +7,53 @@ import ( "os" "testing" + _ "github.com/denisenkom/go-mssqldb" log "github.com/hashicorp/go-hclog" "github.com/hashicorp/vault/sdk/helper/logging" "github.com/hashicorp/vault/sdk/physical" - - _ "github.com/denisenkom/go-mssqldb" ) +// TestInvalidIdentifier checks validity of an identifier +func TestInvalidIdentifier(t *testing.T) { + testcases := map[string]bool{ + "name": true, + "_name": true, + "Name": true, + "#name": false, + "?Name": false, + "9name": false, + "@name": false, + "$name": false, + " name": false, + "n ame": false, + "n4444444": true, + "_4321098765": true, + "_##$$@@__": true, + "_123name#@": true, + "name!": false, + "name%": false, + "name^": false, + "name&": false, + "name*": false, + "name(": false, + "name)": false, + "nåame": true, + "åname": true, + "name'": false, + "nam`e": false, + "пример": true, + "_#Āā@#$_ĂĄąćĈĉĊċ": true, + "ÛÜÝÞßàáâ": true, + "豈更滑a23$#@": true, + } + + for i, expected := range testcases { + if !isInvalidIdentifier(i) != expected { + t.Fatalf("unexpected identifier %s: expected validity %v", i, expected) + } + } +} + func TestMSSQLBackend(t *testing.T) { server := os.Getenv("MSSQL_SERVER") if server == "" {