2017-08-03 17:24:27 +00:00
package mssql
2017-04-06 13:33:49 +00:00
import (
2018-01-19 06:44:44 +00:00
"context"
2017-04-06 13:33:49 +00:00
"database/sql"
"fmt"
"sort"
2017-07-17 17:04:49 +00:00
"strconv"
2017-04-06 13:33:49 +00:00
"strings"
"time"
"github.com/armon/go-metrics"
_ "github.com/denisenkom/go-mssqldb"
2017-07-17 17:04:49 +00:00
"github.com/hashicorp/errwrap"
2017-06-16 15:09:15 +00:00
"github.com/hashicorp/vault/helper/strutil"
2017-08-03 17:24:27 +00:00
"github.com/hashicorp/vault/physical"
2017-04-06 13:33:49 +00:00
log "github.com/mgutz/logxi/v1"
)
2018-01-20 01:44:24 +00:00
// Compile-time assertion: *MSSQLBackend must implement physical.Backend.
// The nil pointer is never used at runtime; only the conversion is checked.
var (
	_ physical.Backend = (*MSSQLBackend)(nil)
)
2017-08-03 17:24:27 +00:00
type MSSQLBackend struct {
2017-04-06 13:33:49 +00:00
dbTable string
client * sql . DB
statements map [ string ] * sql . Stmt
logger log . Logger
2017-08-03 17:24:27 +00:00
permitPool * physical . PermitPool
2017-04-06 13:33:49 +00:00
}
2017-08-03 17:24:27 +00:00
func NewMSSQLBackend ( conf map [ string ] string , logger log . Logger ) ( physical . Backend , error ) {
2017-04-06 13:33:49 +00:00
username , ok := conf [ "username" ]
if ! ok {
username = ""
}
password , ok := conf [ "password" ]
if ! ok {
password = ""
}
server , ok := conf [ "server" ]
if ! ok || server == "" {
return nil , fmt . Errorf ( "missing server" )
}
2017-07-17 17:04:49 +00:00
maxParStr , ok := conf [ "max_parallel" ]
var maxParInt int
var err error
if ok {
maxParInt , err = strconv . Atoi ( maxParStr )
if err != nil {
return nil , errwrap . Wrapf ( "failed parsing max_parallel parameter: {{err}}" , err )
}
if logger . IsDebug ( ) {
logger . Debug ( "mysql: max_parallel set" , "max_parallel" , maxParInt )
}
} else {
2017-08-03 17:24:27 +00:00
maxParInt = physical . DefaultParallelOperations
2017-07-17 17:04:49 +00:00
}
2017-04-06 13:33:49 +00:00
database , ok := conf [ "database" ]
if ! ok {
database = "Vault"
}
table , ok := conf [ "table" ]
if ! ok {
table = "Vault"
}
appname , ok := conf [ "appname" ]
if ! ok {
appname = "Vault"
}
connectionTimeout , ok := conf [ "connectiontimeout" ]
if ! ok {
connectionTimeout = "30"
}
logLevel , ok := conf [ "loglevel" ]
if ! ok {
logLevel = "0"
}
schema , ok := conf [ "schema" ]
if ! ok || schema == "" {
schema = "dbo"
}
connectionString := fmt . Sprintf ( "server=%s;app name=%s;connection timeout=%s;log=%s" , server , appname , connectionTimeout , logLevel )
if username != "" {
connectionString += ";user id=" + username
}
if password != "" {
connectionString += ";password=" + password
}
db , err := sql . Open ( "mssql" , connectionString )
if err != nil {
return nil , fmt . Errorf ( "failed to connect to mssql: %v" , err )
}
2017-07-17 17:04:49 +00:00
db . SetMaxOpenConns ( maxParInt )
2017-04-06 13:33:49 +00:00
if _ , err := db . Exec ( "IF NOT EXISTS(SELECT * FROM sys.databases WHERE name = '" + database + "') CREATE DATABASE " + database ) ; err != nil {
return nil , fmt . Errorf ( "failed to create mssql database: %v" , err )
}
dbTable := database + "." + schema + "." + table
createQuery := "IF NOT EXISTS(SELECT 1 FROM " + database + ".INFORMATION_SCHEMA.TABLES WHERE TABLE_TYPE='BASE TABLE' AND TABLE_NAME='" + table + "' AND TABLE_SCHEMA='" + schema +
"') CREATE TABLE " + dbTable + " (Path VARCHAR(512) PRIMARY KEY, Value VARBINARY(MAX))"
if schema != "dbo" {
if _ , err := db . Exec ( "USE " + database ) ; err != nil {
return nil , fmt . Errorf ( "failed to switch mssql database: %v" , err )
}
var num int
err = db . QueryRow ( "SELECT 1 FROM sys.schemas WHERE name = '" + schema + "'" ) . Scan ( & num )
switch {
case err == sql . ErrNoRows :
if _ , err := db . Exec ( "CREATE SCHEMA " + schema ) ; err != nil {
return nil , fmt . Errorf ( "failed to create mssql schema: %v" , err )
}
case err != nil :
return nil , fmt . Errorf ( "failed to check if mssql schema exists: %v" , err )
}
}
if _ , err := db . Exec ( createQuery ) ; err != nil {
return nil , fmt . Errorf ( "failed to create mssql table: %v" , err )
}
2017-08-03 17:24:27 +00:00
m := & MSSQLBackend {
2017-04-06 13:33:49 +00:00
dbTable : dbTable ,
client : db ,
statements : make ( map [ string ] * sql . Stmt ) ,
logger : logger ,
2017-08-03 17:24:27 +00:00
permitPool : physical . NewPermitPool ( maxParInt ) ,
2017-04-06 13:33:49 +00:00
}
statements := map [ string ] string {
"put" : "IF EXISTS(SELECT 1 FROM " + dbTable + " WHERE Path = ?) UPDATE " + dbTable + " SET Value = ? WHERE Path = ?" +
" ELSE INSERT INTO " + dbTable + " VALUES(?, ?)" ,
"get" : "SELECT Value FROM " + dbTable + " WHERE Path = ?" ,
"delete" : "DELETE FROM " + dbTable + " WHERE Path = ?" ,
"list" : "SELECT Path FROM " + dbTable + " WHERE Path LIKE ?" ,
}
for name , query := range statements {
if err := m . prepare ( name , query ) ; err != nil {
return nil , err
}
}
return m , nil
}
2017-08-03 17:24:27 +00:00
func ( m * MSSQLBackend ) prepare ( name , query string ) error {
2017-04-06 13:33:49 +00:00
stmt , err := m . client . Prepare ( query )
if err != nil {
return fmt . Errorf ( "failed to prepare '%s': %v" , name , err )
}
m . statements [ name ] = stmt
return nil
}
2018-01-19 06:44:44 +00:00
func ( m * MSSQLBackend ) Put ( ctx context . Context , entry * physical . Entry ) error {
2017-04-06 13:33:49 +00:00
defer metrics . MeasureSince ( [ ] string { "mssql" , "put" } , time . Now ( ) )
2017-07-17 17:04:49 +00:00
m . permitPool . Acquire ( )
defer m . permitPool . Release ( )
2017-04-06 13:33:49 +00:00
_ , err := m . statements [ "put" ] . Exec ( entry . Key , entry . Value , entry . Key , entry . Key , entry . Value )
if err != nil {
return err
}
return nil
}
2018-01-19 06:44:44 +00:00
func ( m * MSSQLBackend ) Get ( ctx context . Context , key string ) ( * physical . Entry , error ) {
2017-04-06 13:33:49 +00:00
defer metrics . MeasureSince ( [ ] string { "mssql" , "get" } , time . Now ( ) )
2017-07-17 17:04:49 +00:00
m . permitPool . Acquire ( )
defer m . permitPool . Release ( )
2017-04-06 13:33:49 +00:00
var result [ ] byte
err := m . statements [ "get" ] . QueryRow ( key ) . Scan ( & result )
if err == sql . ErrNoRows {
return nil , nil
}
if err != nil {
return nil , err
}
2017-08-03 17:24:27 +00:00
ent := & physical . Entry {
2017-04-06 13:33:49 +00:00
Key : key ,
Value : result ,
}
return ent , nil
}
2018-01-19 06:44:44 +00:00
func ( m * MSSQLBackend ) Delete ( ctx context . Context , key string ) error {
2017-04-06 13:33:49 +00:00
defer metrics . MeasureSince ( [ ] string { "mssql" , "delete" } , time . Now ( ) )
2017-07-17 17:04:49 +00:00
m . permitPool . Acquire ( )
defer m . permitPool . Release ( )
2017-04-06 13:33:49 +00:00
_ , err := m . statements [ "delete" ] . Exec ( key )
if err != nil {
return err
}
return nil
}
2018-01-19 06:44:44 +00:00
func ( m * MSSQLBackend ) List ( ctx context . Context , prefix string ) ( [ ] string , error ) {
2017-04-06 13:33:49 +00:00
defer metrics . MeasureSince ( [ ] string { "mssql" , "list" } , time . Now ( ) )
2017-07-17 17:04:49 +00:00
m . permitPool . Acquire ( )
defer m . permitPool . Release ( )
2017-04-06 13:33:49 +00:00
likePrefix := prefix + "%"
rows , err := m . statements [ "list" ] . Query ( likePrefix )
2017-07-07 12:15:59 +00:00
if err != nil {
return nil , err
}
2017-04-06 13:33:49 +00:00
var keys [ ] string
for rows . Next ( ) {
var key string
err = rows . Scan ( & key )
if err != nil {
return nil , fmt . Errorf ( "failed to scan rows: %v" , err )
}
key = strings . TrimPrefix ( key , prefix )
if i := strings . Index ( key , "/" ) ; i == - 1 {
keys = append ( keys , key )
} else if i != - 1 {
2017-06-16 15:09:15 +00:00
keys = strutil . AppendIfMissing ( keys , string ( key [ : i + 1 ] ) )
2017-04-06 13:33:49 +00:00
}
}
sort . Strings ( keys )
return keys , nil
}