Add a way to initalize plugins and builtin databases the same way.
This commit is contained in:
parent
71b81aad23
commit
2054fff890
|
@ -3,6 +3,7 @@ package dbs
|
|||
import (
|
||||
"crypto/tls"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
|
@ -11,14 +12,20 @@ import (
|
|||
// Import sql drivers
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
_ "github.com/lib/pq"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
|
||||
"github.com/gocql/gocql"
|
||||
"github.com/hashicorp/vault/helper/certutil"
|
||||
"github.com/hashicorp/vault/helper/tlsutil"
|
||||
)
|
||||
|
||||
var (
|
||||
errNotInitalized = errors.New("Connection has not been initalized")
|
||||
)
|
||||
|
||||
type ConnectionProducer interface {
|
||||
Close()
|
||||
Initialize(map[string]interface{}) error
|
||||
|
||||
sync.Locker
|
||||
connection() (interface{}, error)
|
||||
|
@ -30,10 +37,28 @@ type sqlConnectionProducer struct {
|
|||
|
||||
config *DatabaseConfig
|
||||
|
||||
initalized bool
|
||||
db *sql.DB
|
||||
sync.Mutex
|
||||
}
|
||||
|
||||
func (c *sqlConnectionProducer) Initialize(conf map[string]interface{}) error {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
|
||||
err := mapstructure.Decode(conf, c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.initalized = true
|
||||
|
||||
if _, err := c.connection(); err != nil {
|
||||
return fmt.Errorf("Error Initalizing Connection: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *sqlConnectionProducer) connection() (interface{}, error) {
|
||||
// If we already have a DB, test it and return
|
||||
if c.db != nil {
|
||||
|
@ -99,12 +124,33 @@ type cassandraConnectionProducer struct {
|
|||
Consistency string `json:"consistency" structs:"consistency" mapstructure:"consistency"`
|
||||
|
||||
config *DatabaseConfig
|
||||
|
||||
initalized bool
|
||||
session *gocql.Session
|
||||
sync.Mutex
|
||||
}
|
||||
|
||||
func (c *cassandraConnectionProducer) Initialize(conf map[string]interface{}) error {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
|
||||
err := mapstructure.Decode(conf, c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.initalized = true
|
||||
|
||||
if _, err := c.connection(); err != nil {
|
||||
return fmt.Errorf("Error Initalizing Connection: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *cassandraConnectionProducer) connection() (interface{}, error) {
|
||||
if !c.initalized {
|
||||
return nil, errNotInitalized
|
||||
}
|
||||
|
||||
// If we already have a DB, return it
|
||||
if c.session != nil {
|
||||
return c.session, nil
|
||||
|
|
|
@ -5,8 +5,6 @@ import (
|
|||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/mitchellh/mapstructure"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -25,11 +23,7 @@ type Factory func(*DatabaseConfig) (DatabaseType, error)
|
|||
func BuiltinFactory(conf *DatabaseConfig) (DatabaseType, error) {
|
||||
switch conf.DatabaseType {
|
||||
case postgreSQLTypeName:
|
||||
var connProducer *sqlConnectionProducer
|
||||
err := mapstructure.Decode(conf.ConnectionDetails, &connProducer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
connProducer := &sqlConnectionProducer{}
|
||||
connProducer.config = conf
|
||||
|
||||
credsProducer := &sqlCredentialsProducer{
|
||||
|
@ -43,11 +37,7 @@ func BuiltinFactory(conf *DatabaseConfig) (DatabaseType, error) {
|
|||
}, nil
|
||||
|
||||
case mySQLTypeName:
|
||||
var connProducer *sqlConnectionProducer
|
||||
err := mapstructure.Decode(conf.ConnectionDetails, &connProducer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
connProducer := &sqlConnectionProducer{}
|
||||
connProducer.config = conf
|
||||
|
||||
credsProducer := &sqlCredentialsProducer{
|
||||
|
@ -61,11 +51,7 @@ func BuiltinFactory(conf *DatabaseConfig) (DatabaseType, error) {
|
|||
}, nil
|
||||
|
||||
case cassandraTypeName:
|
||||
var connProducer *cassandraConnectionProducer
|
||||
err := mapstructure.Decode(conf.ConnectionDetails, &connProducer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
connProducer := &cassandraConnectionProducer{}
|
||||
connProducer.config = conf
|
||||
|
||||
credsProducer := &cassandraCredentialsProducer{}
|
||||
|
@ -102,6 +88,7 @@ type DatabaseType interface {
|
|||
RenewUser(statements Statements, username, expiration string) error
|
||||
RevokeUser(statements Statements, username string) error
|
||||
|
||||
Initialize(map[string]interface{}) error
|
||||
Close()
|
||||
CredentialsProducer
|
||||
}
|
||||
|
|
|
@ -140,6 +140,12 @@ func (dr *databasePluginRPCClient) RevokeUser(statements Statements, username st
|
|||
return err
|
||||
}
|
||||
|
||||
func (dr *databasePluginRPCClient) Initialize(conf map[string]interface{}) error {
|
||||
err := dr.client.Call("Plugin.Initialize", conf, &struct{}{})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (dr *databasePluginRPCClient) Close() error {
|
||||
err := dr.client.Call("Plugin.Close", struct{}{}, &struct{}{})
|
||||
|
||||
|
@ -195,6 +201,12 @@ func (ds *databasePluginRPCServer) RevokeUser(args *RevokeUserRequest, _ *struct
|
|||
return err
|
||||
}
|
||||
|
||||
func (ds *databasePluginRPCServer) Initialize(args map[string]interface{}, _ *struct{}) error {
|
||||
err := ds.impl.Initialize(args)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (ds *databasePluginRPCServer) Close(_ interface{}, _ *struct{}) error {
|
||||
ds.impl.Close()
|
||||
return nil
|
||||
|
|
|
@ -3,6 +3,7 @@ package database
|
|||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/fatih/structs"
|
||||
|
@ -67,6 +68,11 @@ func (b *databaseBackend) pathConnectionReset(req *logical.Request, data *framew
|
|||
return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil
|
||||
}
|
||||
|
||||
err = db.Initialize(config.ConnectionDetails)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil
|
||||
}
|
||||
|
||||
b.connections[name] = db
|
||||
|
||||
return nil, nil
|
||||
|
@ -207,6 +213,11 @@ func (b *databaseBackend) connectionWriteHandler(factory dbs.Factory) framework.
|
|||
}
|
||||
|
||||
name := data.Get("name").(string)
|
||||
if name == "" {
|
||||
return logical.ErrorResponse("Empty name attribute given"), nil
|
||||
}
|
||||
|
||||
verifyConnection := data.Get("verify_connection").(bool)
|
||||
|
||||
// Grab the mutex lock
|
||||
b.Lock()
|
||||
|
@ -225,6 +236,19 @@ func (b *databaseBackend) connectionWriteHandler(factory dbs.Factory) framework.
|
|||
return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil
|
||||
}
|
||||
|
||||
err := db.Initialize(config.ConnectionDetails)
|
||||
if err != nil {
|
||||
if !strings.Contains(err.Error(), "Error Initializing Connection") {
|
||||
return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil
|
||||
|
||||
}
|
||||
|
||||
if verifyConnection {
|
||||
return logical.ErrorResponse(err.Error()), nil
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
b.connections[name] = db
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in a new issue