parent
d6e1dec686
commit
85ead99d64
|
@ -0,0 +1,3 @@
|
||||||
|
```release-note:improvement
|
||||||
|
core: validate name identifiers in mssql physical storage backend prior use
|
||||||
|
```
|
|
@ -7,6 +7,7 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"regexp"
|
||||||
"sort"
|
"sort"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
@ -21,6 +22,7 @@ import (
|
||||||
|
|
||||||
// Verify MSSQLBackend satisfies the correct interfaces
|
// Verify MSSQLBackend satisfies the correct interfaces
|
||||||
var _ physical.Backend = (*MSSQLBackend)(nil)
|
var _ physical.Backend = (*MSSQLBackend)(nil)
|
||||||
|
var identifierRegex = regexp.MustCompile(`^[\p{L}_][\p{L}\p{Nd}@#$_]*$`)
|
||||||
|
|
||||||
type MSSQLBackend struct {
|
type MSSQLBackend struct {
|
||||||
dbTable string
|
dbTable string
|
||||||
|
@ -30,6 +32,13 @@ type MSSQLBackend struct {
|
||||||
permitPool *physical.PermitPool
|
permitPool *physical.PermitPool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isInvalidIdentifier(name string) bool {
|
||||||
|
if !identifierRegex.MatchString(name) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
func NewMSSQLBackend(conf map[string]string, logger log.Logger) (physical.Backend, error) {
|
func NewMSSQLBackend(conf map[string]string, logger log.Logger) (physical.Backend, error) {
|
||||||
username, ok := conf["username"]
|
username, ok := conf["username"]
|
||||||
if !ok {
|
if !ok {
|
||||||
|
@ -71,11 +80,19 @@ func NewMSSQLBackend(conf map[string]string, logger log.Logger) (physical.Backen
|
||||||
database = "Vault"
|
database = "Vault"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if isInvalidIdentifier(database) {
|
||||||
|
return nil, fmt.Errorf("invalid database name")
|
||||||
|
}
|
||||||
|
|
||||||
table, ok := conf["table"]
|
table, ok := conf["table"]
|
||||||
if !ok {
|
if !ok {
|
||||||
table = "Vault"
|
table = "Vault"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if isInvalidIdentifier(table) {
|
||||||
|
return nil, fmt.Errorf("invalid table name")
|
||||||
|
}
|
||||||
|
|
||||||
appname, ok := conf["appname"]
|
appname, ok := conf["appname"]
|
||||||
if !ok {
|
if !ok {
|
||||||
appname = "Vault"
|
appname = "Vault"
|
||||||
|
@ -96,6 +113,10 @@ func NewMSSQLBackend(conf map[string]string, logger log.Logger) (physical.Backen
|
||||||
schema = "dbo"
|
schema = "dbo"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if isInvalidIdentifier(schema) {
|
||||||
|
return nil, fmt.Errorf("invalid schema name")
|
||||||
|
}
|
||||||
|
|
||||||
connectionString := fmt.Sprintf("server=%s;app name=%s;connection timeout=%s;log=%s", server, appname, connectionTimeout, logLevel)
|
connectionString := fmt.Sprintf("server=%s;app name=%s;connection timeout=%s;log=%s", server, appname, connectionTimeout, logLevel)
|
||||||
if username != "" {
|
if username != "" {
|
||||||
connectionString += ";user id=" + username
|
connectionString += ";user id=" + username
|
||||||
|
@ -116,18 +137,17 @@ func NewMSSQLBackend(conf map[string]string, logger log.Logger) (physical.Backen
|
||||||
|
|
||||||
db.SetMaxOpenConns(maxParInt)
|
db.SetMaxOpenConns(maxParInt)
|
||||||
|
|
||||||
if _, err := db.Exec("IF NOT EXISTS(SELECT * FROM sys.databases WHERE name = '" + database + "') CREATE DATABASE " + database); err != nil {
|
if _, err := db.Exec("IF NOT EXISTS(SELECT * FROM sys.databases WHERE name = ?) CREATE DATABASE "+database, database); err != nil {
|
||||||
return nil, fmt.Errorf("failed to create mssql database: %w", err)
|
return nil, fmt.Errorf("failed to create mssql database: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
dbTable := database + "." + schema + "." + table
|
dbTable := database + "." + schema + "." + table
|
||||||
createQuery := "IF NOT EXISTS(SELECT 1 FROM " + database + ".INFORMATION_SCHEMA.TABLES WHERE TABLE_TYPE='BASE TABLE' AND TABLE_NAME='" + table + "' AND TABLE_SCHEMA='" + schema +
|
createQuery := "IF NOT EXISTS(SELECT 1 FROM " + database + ".INFORMATION_SCHEMA.TABLES WHERE TABLE_TYPE='BASE TABLE' AND TABLE_NAME=? AND TABLE_SCHEMA=?) CREATE TABLE " + dbTable + " (Path VARCHAR(512) PRIMARY KEY, Value VARBINARY(MAX))"
|
||||||
"') CREATE TABLE " + dbTable + " (Path VARCHAR(512) PRIMARY KEY, Value VARBINARY(MAX))"
|
|
||||||
|
|
||||||
if schema != "dbo" {
|
if schema != "dbo" {
|
||||||
|
|
||||||
var num int
|
var num int
|
||||||
err = db.QueryRow("SELECT 1 FROM " + database + ".sys.schemas WHERE name = '" + schema + "'").Scan(&num)
|
err = db.QueryRow("SELECT 1 FROM "+database+".sys.schemas WHERE name = ?", schema).Scan(&num)
|
||||||
|
|
||||||
switch {
|
switch {
|
||||||
case err == sql.ErrNoRows:
|
case err == sql.ErrNoRows:
|
||||||
|
@ -140,7 +160,7 @@ func NewMSSQLBackend(conf map[string]string, logger log.Logger) (physical.Backen
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := db.Exec(createQuery); err != nil {
|
if _, err := db.Exec(createQuery, table, schema); err != nil {
|
||||||
return nil, fmt.Errorf("failed to create mssql table: %w", err)
|
return nil, fmt.Errorf("failed to create mssql table: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -7,13 +7,53 @@ import (
|
||||||
"os"
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
_ "github.com/denisenkom/go-mssqldb"
|
||||||
log "github.com/hashicorp/go-hclog"
|
log "github.com/hashicorp/go-hclog"
|
||||||
"github.com/hashicorp/vault/sdk/helper/logging"
|
"github.com/hashicorp/vault/sdk/helper/logging"
|
||||||
"github.com/hashicorp/vault/sdk/physical"
|
"github.com/hashicorp/vault/sdk/physical"
|
||||||
|
|
||||||
_ "github.com/denisenkom/go-mssqldb"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// TestInvalidIdentifier checks validity of an identifier
|
||||||
|
func TestInvalidIdentifier(t *testing.T) {
|
||||||
|
testcases := map[string]bool{
|
||||||
|
"name": true,
|
||||||
|
"_name": true,
|
||||||
|
"Name": true,
|
||||||
|
"#name": false,
|
||||||
|
"?Name": false,
|
||||||
|
"9name": false,
|
||||||
|
"@name": false,
|
||||||
|
"$name": false,
|
||||||
|
" name": false,
|
||||||
|
"n ame": false,
|
||||||
|
"n4444444": true,
|
||||||
|
"_4321098765": true,
|
||||||
|
"_##$$@@__": true,
|
||||||
|
"_123name#@": true,
|
||||||
|
"name!": false,
|
||||||
|
"name%": false,
|
||||||
|
"name^": false,
|
||||||
|
"name&": false,
|
||||||
|
"name*": false,
|
||||||
|
"name(": false,
|
||||||
|
"name)": false,
|
||||||
|
"nåame": true,
|
||||||
|
"åname": true,
|
||||||
|
"name'": false,
|
||||||
|
"nam`e": false,
|
||||||
|
"пример": true,
|
||||||
|
"_#Āā@#$_ĂĄąćĈĉĊċ": true,
|
||||||
|
"ÛÜÝÞßàáâ": true,
|
||||||
|
"豈更滑a23$#@": true,
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, expected := range testcases {
|
||||||
|
if !isInvalidIdentifier(i) != expected {
|
||||||
|
t.Fatalf("unexpected identifier %s: expected validity %v", i, expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestMSSQLBackend(t *testing.T) {
|
func TestMSSQLBackend(t *testing.T) {
|
||||||
server := os.Getenv("MSSQL_SERVER")
|
server := os.Getenv("MSSQL_SERVER")
|
||||||
if server == "" {
|
if server == "" {
|
||||||
|
|
Loading…
Reference in New Issue