Add a way to initalize plugins and builtin databases the same way.

This commit is contained in:
Brian Kassouf 2017-03-13 14:39:55 -07:00
parent 71b81aad23
commit 2054fff890
4 changed files with 90 additions and 21 deletions

View File

@ -3,6 +3,7 @@ package dbs
import ( import (
"crypto/tls" "crypto/tls"
"database/sql" "database/sql"
"errors"
"fmt" "fmt"
"strings" "strings"
"sync" "sync"
@ -11,14 +12,20 @@ import (
// Import sql drivers // Import sql drivers
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
_ "github.com/lib/pq" _ "github.com/lib/pq"
"github.com/mitchellh/mapstructure"
"github.com/gocql/gocql" "github.com/gocql/gocql"
"github.com/hashicorp/vault/helper/certutil" "github.com/hashicorp/vault/helper/certutil"
"github.com/hashicorp/vault/helper/tlsutil" "github.com/hashicorp/vault/helper/tlsutil"
) )
var (
errNotInitalized = errors.New("Connection has not been initalized")
)
type ConnectionProducer interface { type ConnectionProducer interface {
Close() Close()
Initialize(map[string]interface{}) error
sync.Locker sync.Locker
connection() (interface{}, error) connection() (interface{}, error)
@ -30,10 +37,28 @@ type sqlConnectionProducer struct {
config *DatabaseConfig config *DatabaseConfig
db *sql.DB initalized bool
db *sql.DB
sync.Mutex sync.Mutex
} }
func (c *sqlConnectionProducer) Initialize(conf map[string]interface{}) error {
c.Lock()
defer c.Unlock()
err := mapstructure.Decode(conf, c)
if err != nil {
return err
}
c.initalized = true
if _, err := c.connection(); err != nil {
return fmt.Errorf("Error Initalizing Connection: %s", err)
}
return nil
}
func (c *sqlConnectionProducer) connection() (interface{}, error) { func (c *sqlConnectionProducer) connection() (interface{}, error) {
// If we already have a DB, test it and return // If we already have a DB, test it and return
if c.db != nil { if c.db != nil {
@ -98,13 +123,34 @@ type cassandraConnectionProducer struct {
TLSMinVersion string `json:"tls_min_version" structs:"tls_min_version" mapstructure:"tls_min_version"` TLSMinVersion string `json:"tls_min_version" structs:"tls_min_version" mapstructure:"tls_min_version"`
Consistency string `json:"consistency" structs:"consistency" mapstructure:"consistency"` Consistency string `json:"consistency" structs:"consistency" mapstructure:"consistency"`
config *DatabaseConfig config *DatabaseConfig
initalized bool
session *gocql.Session session *gocql.Session
sync.Mutex sync.Mutex
} }
func (c *cassandraConnectionProducer) Initialize(conf map[string]interface{}) error {
c.Lock()
defer c.Unlock()
err := mapstructure.Decode(conf, c)
if err != nil {
return err
}
c.initalized = true
if _, err := c.connection(); err != nil {
return fmt.Errorf("Error Initalizing Connection: %s", err)
}
return nil
}
func (c *cassandraConnectionProducer) connection() (interface{}, error) { func (c *cassandraConnectionProducer) connection() (interface{}, error) {
if !c.initalized {
return nil, errNotInitalized
}
// If we already have a DB, return it // If we already have a DB, return it
if c.session != nil { if c.session != nil {
return c.session, nil return c.session, nil

View File

@ -5,8 +5,6 @@ import (
"fmt" "fmt"
"strings" "strings"
"time" "time"
"github.com/mitchellh/mapstructure"
) )
const ( const (
@ -25,11 +23,7 @@ type Factory func(*DatabaseConfig) (DatabaseType, error)
func BuiltinFactory(conf *DatabaseConfig) (DatabaseType, error) { func BuiltinFactory(conf *DatabaseConfig) (DatabaseType, error) {
switch conf.DatabaseType { switch conf.DatabaseType {
case postgreSQLTypeName: case postgreSQLTypeName:
var connProducer *sqlConnectionProducer connProducer := &sqlConnectionProducer{}
err := mapstructure.Decode(conf.ConnectionDetails, &connProducer)
if err != nil {
return nil, err
}
connProducer.config = conf connProducer.config = conf
credsProducer := &sqlCredentialsProducer{ credsProducer := &sqlCredentialsProducer{
@ -43,11 +37,7 @@ func BuiltinFactory(conf *DatabaseConfig) (DatabaseType, error) {
}, nil }, nil
case mySQLTypeName: case mySQLTypeName:
var connProducer *sqlConnectionProducer connProducer := &sqlConnectionProducer{}
err := mapstructure.Decode(conf.ConnectionDetails, &connProducer)
if err != nil {
return nil, err
}
connProducer.config = conf connProducer.config = conf
credsProducer := &sqlCredentialsProducer{ credsProducer := &sqlCredentialsProducer{
@ -61,11 +51,7 @@ func BuiltinFactory(conf *DatabaseConfig) (DatabaseType, error) {
}, nil }, nil
case cassandraTypeName: case cassandraTypeName:
var connProducer *cassandraConnectionProducer connProducer := &cassandraConnectionProducer{}
err := mapstructure.Decode(conf.ConnectionDetails, &connProducer)
if err != nil {
return nil, err
}
connProducer.config = conf connProducer.config = conf
credsProducer := &cassandraCredentialsProducer{} credsProducer := &cassandraCredentialsProducer{}
@ -102,6 +88,7 @@ type DatabaseType interface {
RenewUser(statements Statements, username, expiration string) error RenewUser(statements Statements, username, expiration string) error
RevokeUser(statements Statements, username string) error RevokeUser(statements Statements, username string) error
Initialize(map[string]interface{}) error
Close() Close()
CredentialsProducer CredentialsProducer
} }

View File

@ -140,6 +140,12 @@ func (dr *databasePluginRPCClient) RevokeUser(statements Statements, username st
return err return err
} }
func (dr *databasePluginRPCClient) Initialize(conf map[string]interface{}) error {
err := dr.client.Call("Plugin.Initialize", conf, &struct{}{})
return err
}
func (dr *databasePluginRPCClient) Close() error { func (dr *databasePluginRPCClient) Close() error {
err := dr.client.Call("Plugin.Close", struct{}{}, &struct{}{}) err := dr.client.Call("Plugin.Close", struct{}{}, &struct{}{})
@ -195,6 +201,12 @@ func (ds *databasePluginRPCServer) RevokeUser(args *RevokeUserRequest, _ *struct
return err return err
} }
func (ds *databasePluginRPCServer) Initialize(args map[string]interface{}, _ *struct{}) error {
err := ds.impl.Initialize(args)
return err
}
func (ds *databasePluginRPCServer) Close(_ interface{}, _ *struct{}) error { func (ds *databasePluginRPCServer) Close(_ interface{}, _ *struct{}) error {
ds.impl.Close() ds.impl.Close()
return nil return nil

View File

@ -3,6 +3,7 @@ package database
import ( import (
"errors" "errors"
"fmt" "fmt"
"strings"
"time" "time"
"github.com/fatih/structs" "github.com/fatih/structs"
@ -67,6 +68,11 @@ func (b *databaseBackend) pathConnectionReset(req *logical.Request, data *framew
return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil
} }
err = db.Initialize(config.ConnectionDetails)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil
}
b.connections[name] = db b.connections[name] = db
return nil, nil return nil, nil
@ -207,6 +213,11 @@ func (b *databaseBackend) connectionWriteHandler(factory dbs.Factory) framework.
} }
name := data.Get("name").(string) name := data.Get("name").(string)
if name == "" {
return logical.ErrorResponse("Empty name attribute given"), nil
}
verifyConnection := data.Get("verify_connection").(bool)
// Grab the mutex lock // Grab the mutex lock
b.Lock() b.Lock()
@ -225,6 +236,19 @@ func (b *databaseBackend) connectionWriteHandler(factory dbs.Factory) framework.
return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil
} }
err := db.Initialize(config.ConnectionDetails)
if err != nil {
if !strings.Contains(err.Error(), "Error Initializing Connection") {
return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil
}
if verifyConnection {
return logical.ErrorResponse(err.Error()), nil
}
}
b.connections[name] = db b.connections[name] = db
} }