diff --git a/physical/mssql.go b/physical/mssql.go new file mode 100644 index 000000000..25709a22b --- /dev/null +++ b/physical/mssql.go @@ -0,0 +1,216 @@ +package physical + +import ( + "database/sql" + "fmt" + "sort" + "strings" + "time" + + "github.com/armon/go-metrics" + _ "github.com/denisenkom/go-mssqldb" + log "github.com/mgutz/logxi/v1" +) + +type MsSQLBackend struct { + dbTable string + client *sql.DB + statements map[string]*sql.Stmt + logger log.Logger +} + +func newMsSQLBackend(conf map[string]string, logger log.Logger) (Backend, error) { + username, ok := conf["username"] + if !ok { + username = "" + } + + password, ok := conf["password"] + if !ok { + password = "" + } + + server, ok := conf["server"] + if !ok || server == "" { + return nil, fmt.Errorf("missing server") + } + + database, ok := conf["database"] + if !ok { + database = "Vault" + } + + table, ok := conf["table"] + if !ok { + table = "Vault" + } + + appname, ok := conf["appname"] + if !ok { + appname = "Vault" + } + + connectionTimeout, ok := conf["connectiontimeout"] + if !ok { + connectionTimeout = "30" + } + + logLevel, ok := conf["loglevel"] + if !ok { + logLevel = "0" + } + + schema, ok := conf["schema"] + if !ok || schema == "" { + schema = "dbo" + } + + connectionString := fmt.Sprintf("server=%s;app name=%s;connection timeout=%s;log=%s", server, appname, connectionTimeout, logLevel) + if username != "" { + connectionString += ";user id=" + username + } + + if password != "" { + connectionString += ";password=" + password + } + + db, err := sql.Open("mssql", connectionString) + if err != nil { + return nil, fmt.Errorf("failed to connect to mssql: %v", err) + } + + if _, err := db.Exec("IF NOT EXISTS(SELECT * FROM sys.databases WHERE name = '" + database + "') CREATE DATABASE " + database); err != nil { + return nil, fmt.Errorf("failed to create mssql database: %v", err) + } + + dbTable := database + "." + schema + "." + table + createQuery := "IF NOT EXISTS(SELECT 1 FROM " + database + ".INFORMATION_SCHEMA.TABLES WHERE TABLE_TYPE='BASE TABLE' AND TABLE_NAME='" + table + "' AND TABLE_SCHEMA='" + schema + + "') CREATE TABLE " + dbTable + " (Path VARCHAR(512) PRIMARY KEY, Value VARBINARY(MAX))" + + if schema != "dbo" { + if _, err := db.Exec("USE " + database); err != nil { + return nil, fmt.Errorf("failed to switch mssql database: %v", err) + } + + var num int + err = db.QueryRow("SELECT 1 FROM sys.schemas WHERE name = '" + schema + "'").Scan(&num) + + switch { + case err == sql.ErrNoRows: + if _, err := db.Exec("CREATE SCHEMA " + schema); err != nil { + return nil, fmt.Errorf("failed to create mssql schema: %v", err) + } + + case err != nil: + return nil, fmt.Errorf("failed to check if mssql schema exists: %v", err) + } + } + + if _, err := db.Exec(createQuery); err != nil { + return nil, fmt.Errorf("failed to create mssql table: %v", err) + } + + m := &MsSQLBackend{ + dbTable: dbTable, + client: db, + statements: make(map[string]*sql.Stmt), + logger: logger, + } + + statements := map[string]string{ + "put": "IF EXISTS(SELECT 1 FROM " + dbTable + " WHERE Path = ?) UPDATE " + dbTable + " SET Value = ? WHERE Path = ?" + + " ELSE INSERT INTO " + dbTable + " VALUES(?, ?)", + "get": "SELECT Value FROM " + dbTable + " WHERE Path = ?", + "delete": "DELETE FROM " + dbTable + " WHERE Path = ?", + "list": "SELECT Path FROM " + dbTable + " WHERE Path LIKE ?", + } + + for name, query := range statements { + if err := m.prepare(name, query); err != nil { + return nil, err + } + } + + return m, nil +} + +func (m *MsSQLBackend) prepare(name, query string) error { + stmt, err := m.client.Prepare(query) + if err != nil { + return fmt.Errorf("failed to prepare '%s': %v", name, err) + } + + m.statements[name] = stmt + + return nil +} + +func (m *MsSQLBackend) Put(entry *Entry) error { + defer metrics.MeasureSince([]string{"mssql", "put"}, time.Now()) + + _, err := m.statements["put"].Exec(entry.Key, entry.Value, entry.Key, entry.Key, entry.Value) + if err != nil { + return err + } + + return nil +} + +func (m *MsSQLBackend) Get(key string) (*Entry, error) { + defer metrics.MeasureSince([]string{"mssql", "get"}, time.Now()) + + var result []byte + err := m.statements["get"].QueryRow(key).Scan(&result) + if err == sql.ErrNoRows { + return nil, nil + } + + if err != nil { + return nil, err + } + + ent := &Entry{ + Key: key, + Value: result, + } + + return ent, nil +} + +func (m *MsSQLBackend) Delete(key string) error { + defer metrics.MeasureSince([]string{"mssql", "delete"}, time.Now()) + + _, err := m.statements["delete"].Exec(key) + if err != nil { + return err + } + + return nil +} + +func (m *MsSQLBackend) List(prefix string) ([]string, error) { + defer metrics.MeasureSince([]string{"mssql", "list"}, time.Now()) + + likePrefix := prefix + "%" + rows, err := m.statements["list"].Query(likePrefix) + + var keys []string + for rows.Next() { + var key string + err = rows.Scan(&key) + if err != nil { + return nil, fmt.Errorf("failed to scan rows: %v", err) + } + + key = strings.TrimPrefix(key, prefix) + if i := strings.Index(key, "/"); i == -1 { + keys = append(keys, key) + } else if i != -1 { + keys = appendIfMissing(keys, string(key[:i+1])) + } + } + + sort.Strings(keys) + + return keys, nil +} diff --git a/physical/mssql_test.go b/physical/mssql_test.go new file mode 100644 index 000000000..11f4684ea --- /dev/null +++ b/physical/mssql_test.go @@ -0,0 +1,58 @@ +package physical + +import ( + "os" + "testing" + + "github.com/hashicorp/vault/helper/logformat" + log "github.com/mgutz/logxi/v1" + + _ "github.com/denisenkom/go-mssqldb" +) + +func TestMsSQLBackend(t *testing.T) { + server := os.Getenv("MSSQL_SERVER") + if server == "" { + t.SkipNow() + } + + database := os.Getenv("MSSQL_DB") + if database == "" { + database = "test" + } + + table := os.Getenv("MSSQL_TABLE") + if table == "" { + table = "test" + } + + username := os.Getenv("MSSQL_USERNAME") + password := os.Getenv("MSSQL_PASSWORD") + + // Run vault tests + logger := logformat.NewVaultLogger(log.LevelTrace) + + b, err := NewBackend("mssql", logger, map[string]string{ + "server": server, + "database": database, + "table": table, + "username": username, + "password": password, + }) + + if err != nil { + t.Fatalf("Failed to create new backend: %v", err) + } + + defer func() { + mssql := b.(*MsSQLBackend) + _, err := mssql.client.Exec("DROP TABLE " + mssql.dbTable) + if err != nil { + t.Fatalf("Failed to drop table: %v", err) + } + }() + + testBackend(t, b) + testBackend_ListPrefix(t, b) + +} diff --git a/physical/physical.go b/physical/physical.go index 568ffe917..b35d281ce 100644 --- a/physical/physical.go +++ b/physical/physical.go @@ -148,6 +148,7 @@ var builtinBackends = map[string]Factory{ "azure": newAzureBackend, "dynamodb": newDynamoDBBackend, "etcd": newEtcdBackend, + "mssql": newMsSQLBackend, "mysql": newMySQLBackend, "postgresql": newPostgreSQLBackend, "swift": newSwiftBackend, diff --git a/website/source/docs/configuration/storage/mssql.html.md b/website/source/docs/configuration/storage/mssql.html.md new file mode 100644 index 000000000..7f1fca80d --- /dev/null +++ b/website/source/docs/configuration/storage/mssql.html.md @@ -0,0 +1,78 @@ +--- +layout: "docs" +page_title: "MSSQL - Storage Backends - Configuration" +sidebar_current: "docs-configuration-storage-mssql" +description: |- + The MSSQL storage backend is used to persist Vault's data in a Microsoft SQL Server. +--- + +# MSSQL Storage Backend + +The MSSQL storage backend is used to persist Vault's data in a Microsoft SQL Server. + +- **No High Availability** – the MSSQL storage backend does not support high + availability. + +- **Community Supported** – the MSSQL storage backend is supported by the + community. While it has undergone review by HashiCorp employees, they may not + be as knowledgeable about the technology. If you encounter problems with them, + you may be referred to the original author. + +```hcl +storage "mssql" { + server = "localhost" + username = "user1234" + password = "secret123!" + database = "vault" + table = "vault" + appname = "vault" + schema = "dbo" + connectionTimeout = 30 + logLevel = 0 +} +``` + +## `mssql` Parameters + +- `server` `(string: )` – host or host\instance. + +- `username` `(string: "")` - enter the SQL Server Authentication user id or + the Windows Authentication user id in the DOMAIN\User format. + On Windows, if user id is empty or missing Single-Sign-On is used. + +- `password` `(string: "")` – specifies the MSSQL password to connect to + the database. + +- `database` `(string: "Vault")` – Specifies the name of the database. If the + database does not exist, Vault will attempt to create it. + +- `table` `(string: "Vault")` – Specifies the name of the table. If the table + does not exist, Vault will attempt to create it. + +- `schema` `(string: "dbo")` – Specifies the name of the schema. If the schema + does not exist, Vault will attempt to create it. + +- `appname` `(string: "Vault")` – the application name. + +- `connectionTimeout` `(int: 30)` – in seconds (default is 30). + +- `logLevel` `(int: 0)` – logging flags (default 0/no logging, 63 for full logging) . + +## `mssql` Examples + +### Custom Database, Table and Schema + +This example shows configuring the MSSQL backend to use a custom database and +table name. + +```hcl +storage "mssql" { + database = "my-vault" + table = "vault-data" + schema = "vlt" + username = "user1234" + password = "pass5678" +} +``` + +