Add a way to initalize plugins and builtin databases the same way.

This commit is contained in:
Brian Kassouf 2017-03-13 14:39:55 -07:00
parent 71b81aad23
commit 2054fff890
4 changed files with 90 additions and 21 deletions

View File

@ -3,6 +3,7 @@ package dbs
import (
"crypto/tls"
"database/sql"
"errors"
"fmt"
"strings"
"sync"
@ -11,14 +12,20 @@ import (
// Import sql drivers
_ "github.com/go-sql-driver/mysql"
_ "github.com/lib/pq"
"github.com/mitchellh/mapstructure"
"github.com/gocql/gocql"
"github.com/hashicorp/vault/helper/certutil"
"github.com/hashicorp/vault/helper/tlsutil"
)
var (
errNotInitalized = errors.New("Connection has not been initalized")
)
type ConnectionProducer interface {
Close()
Initialize(map[string]interface{}) error
sync.Locker
connection() (interface{}, error)
@ -30,10 +37,28 @@ type sqlConnectionProducer struct {
config *DatabaseConfig
db *sql.DB
initalized bool
db *sql.DB
sync.Mutex
}
func (c *sqlConnectionProducer) Initialize(conf map[string]interface{}) error {
c.Lock()
defer c.Unlock()
err := mapstructure.Decode(conf, c)
if err != nil {
return err
}
c.initalized = true
if _, err := c.connection(); err != nil {
return fmt.Errorf("Error Initalizing Connection: %s", err)
}
return nil
}
func (c *sqlConnectionProducer) connection() (interface{}, error) {
// If we already have a DB, test it and return
if c.db != nil {
@ -98,13 +123,34 @@ type cassandraConnectionProducer struct {
TLSMinVersion string `json:"tls_min_version" structs:"tls_min_version" mapstructure:"tls_min_version"`
Consistency string `json:"consistency" structs:"consistency" mapstructure:"consistency"`
config *DatabaseConfig
session *gocql.Session
config *DatabaseConfig
initalized bool
session *gocql.Session
sync.Mutex
}
func (c *cassandraConnectionProducer) Initialize(conf map[string]interface{}) error {
c.Lock()
defer c.Unlock()
err := mapstructure.Decode(conf, c)
if err != nil {
return err
}
c.initalized = true
if _, err := c.connection(); err != nil {
return fmt.Errorf("Error Initalizing Connection: %s", err)
}
return nil
}
func (c *cassandraConnectionProducer) connection() (interface{}, error) {
if !c.initalized {
return nil, errNotInitalized
}
// If we already have a DB, return it
if c.session != nil {
return c.session, nil

View File

@ -5,8 +5,6 @@ import (
"fmt"
"strings"
"time"
"github.com/mitchellh/mapstructure"
)
const (
@ -25,11 +23,7 @@ type Factory func(*DatabaseConfig) (DatabaseType, error)
func BuiltinFactory(conf *DatabaseConfig) (DatabaseType, error) {
switch conf.DatabaseType {
case postgreSQLTypeName:
var connProducer *sqlConnectionProducer
err := mapstructure.Decode(conf.ConnectionDetails, &connProducer)
if err != nil {
return nil, err
}
connProducer := &sqlConnectionProducer{}
connProducer.config = conf
credsProducer := &sqlCredentialsProducer{
@ -43,11 +37,7 @@ func BuiltinFactory(conf *DatabaseConfig) (DatabaseType, error) {
}, nil
case mySQLTypeName:
var connProducer *sqlConnectionProducer
err := mapstructure.Decode(conf.ConnectionDetails, &connProducer)
if err != nil {
return nil, err
}
connProducer := &sqlConnectionProducer{}
connProducer.config = conf
credsProducer := &sqlCredentialsProducer{
@ -61,11 +51,7 @@ func BuiltinFactory(conf *DatabaseConfig) (DatabaseType, error) {
}, nil
case cassandraTypeName:
var connProducer *cassandraConnectionProducer
err := mapstructure.Decode(conf.ConnectionDetails, &connProducer)
if err != nil {
return nil, err
}
connProducer := &cassandraConnectionProducer{}
connProducer.config = conf
credsProducer := &cassandraCredentialsProducer{}
@ -102,6 +88,7 @@ type DatabaseType interface {
RenewUser(statements Statements, username, expiration string) error
RevokeUser(statements Statements, username string) error
Initialize(map[string]interface{}) error
Close()
CredentialsProducer
}

View File

@ -140,6 +140,12 @@ func (dr *databasePluginRPCClient) RevokeUser(statements Statements, username st
return err
}
func (dr *databasePluginRPCClient) Initialize(conf map[string]interface{}) error {
err := dr.client.Call("Plugin.Initialize", conf, &struct{}{})
return err
}
func (dr *databasePluginRPCClient) Close() error {
err := dr.client.Call("Plugin.Close", struct{}{}, &struct{}{})
@ -195,6 +201,12 @@ func (ds *databasePluginRPCServer) RevokeUser(args *RevokeUserRequest, _ *struct
return err
}
func (ds *databasePluginRPCServer) Initialize(args map[string]interface{}, _ *struct{}) error {
err := ds.impl.Initialize(args)
return err
}
func (ds *databasePluginRPCServer) Close(_ interface{}, _ *struct{}) error {
ds.impl.Close()
return nil

View File

@ -3,6 +3,7 @@ package database
import (
"errors"
"fmt"
"strings"
"time"
"github.com/fatih/structs"
@ -67,6 +68,11 @@ func (b *databaseBackend) pathConnectionReset(req *logical.Request, data *framew
return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil
}
err = db.Initialize(config.ConnectionDetails)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil
}
b.connections[name] = db
return nil, nil
@ -207,6 +213,11 @@ func (b *databaseBackend) connectionWriteHandler(factory dbs.Factory) framework.
}
name := data.Get("name").(string)
if name == "" {
return logical.ErrorResponse("Empty name attribute given"), nil
}
verifyConnection := data.Get("verify_connection").(bool)
// Grab the mutex lock
b.Lock()
@ -225,6 +236,19 @@ func (b *databaseBackend) connectionWriteHandler(factory dbs.Factory) framework.
return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil
}
err := db.Initialize(config.ConnectionDetails)
if err != nil {
if !strings.Contains(err.Error(), "Error Initializing Connection") {
return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil
}
if verifyConnection {
return logical.ErrorResponse(err.Error()), nil
}
}
b.connections[name] = db
}