Add a way to initalize plugins and builtin databases the same way.
This commit is contained in:
parent
71b81aad23
commit
2054fff890
|
@ -3,6 +3,7 @@ package dbs
|
||||||
import (
|
import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
@ -11,14 +12,20 @@ import (
|
||||||
// Import sql drivers
|
// Import sql drivers
|
||||||
_ "github.com/go-sql-driver/mysql"
|
_ "github.com/go-sql-driver/mysql"
|
||||||
_ "github.com/lib/pq"
|
_ "github.com/lib/pq"
|
||||||
|
"github.com/mitchellh/mapstructure"
|
||||||
|
|
||||||
"github.com/gocql/gocql"
|
"github.com/gocql/gocql"
|
||||||
"github.com/hashicorp/vault/helper/certutil"
|
"github.com/hashicorp/vault/helper/certutil"
|
||||||
"github.com/hashicorp/vault/helper/tlsutil"
|
"github.com/hashicorp/vault/helper/tlsutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
errNotInitalized = errors.New("Connection has not been initalized")
|
||||||
|
)
|
||||||
|
|
||||||
type ConnectionProducer interface {
|
type ConnectionProducer interface {
|
||||||
Close()
|
Close()
|
||||||
|
Initialize(map[string]interface{}) error
|
||||||
|
|
||||||
sync.Locker
|
sync.Locker
|
||||||
connection() (interface{}, error)
|
connection() (interface{}, error)
|
||||||
|
@ -30,10 +37,28 @@ type sqlConnectionProducer struct {
|
||||||
|
|
||||||
config *DatabaseConfig
|
config *DatabaseConfig
|
||||||
|
|
||||||
db *sql.DB
|
initalized bool
|
||||||
|
db *sql.DB
|
||||||
sync.Mutex
|
sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *sqlConnectionProducer) Initialize(conf map[string]interface{}) error {
|
||||||
|
c.Lock()
|
||||||
|
defer c.Unlock()
|
||||||
|
|
||||||
|
err := mapstructure.Decode(conf, c)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
c.initalized = true
|
||||||
|
|
||||||
|
if _, err := c.connection(); err != nil {
|
||||||
|
return fmt.Errorf("Error Initalizing Connection: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (c *sqlConnectionProducer) connection() (interface{}, error) {
|
func (c *sqlConnectionProducer) connection() (interface{}, error) {
|
||||||
// If we already have a DB, test it and return
|
// If we already have a DB, test it and return
|
||||||
if c.db != nil {
|
if c.db != nil {
|
||||||
|
@ -98,13 +123,34 @@ type cassandraConnectionProducer struct {
|
||||||
TLSMinVersion string `json:"tls_min_version" structs:"tls_min_version" mapstructure:"tls_min_version"`
|
TLSMinVersion string `json:"tls_min_version" structs:"tls_min_version" mapstructure:"tls_min_version"`
|
||||||
Consistency string `json:"consistency" structs:"consistency" mapstructure:"consistency"`
|
Consistency string `json:"consistency" structs:"consistency" mapstructure:"consistency"`
|
||||||
|
|
||||||
config *DatabaseConfig
|
config *DatabaseConfig
|
||||||
|
initalized bool
|
||||||
session *gocql.Session
|
session *gocql.Session
|
||||||
sync.Mutex
|
sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *cassandraConnectionProducer) Initialize(conf map[string]interface{}) error {
|
||||||
|
c.Lock()
|
||||||
|
defer c.Unlock()
|
||||||
|
|
||||||
|
err := mapstructure.Decode(conf, c)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
c.initalized = true
|
||||||
|
|
||||||
|
if _, err := c.connection(); err != nil {
|
||||||
|
return fmt.Errorf("Error Initalizing Connection: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (c *cassandraConnectionProducer) connection() (interface{}, error) {
|
func (c *cassandraConnectionProducer) connection() (interface{}, error) {
|
||||||
|
if !c.initalized {
|
||||||
|
return nil, errNotInitalized
|
||||||
|
}
|
||||||
|
|
||||||
// If we already have a DB, return it
|
// If we already have a DB, return it
|
||||||
if c.session != nil {
|
if c.session != nil {
|
||||||
return c.session, nil
|
return c.session, nil
|
||||||
|
|
|
@ -5,8 +5,6 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/mitchellh/mapstructure"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -25,11 +23,7 @@ type Factory func(*DatabaseConfig) (DatabaseType, error)
|
||||||
func BuiltinFactory(conf *DatabaseConfig) (DatabaseType, error) {
|
func BuiltinFactory(conf *DatabaseConfig) (DatabaseType, error) {
|
||||||
switch conf.DatabaseType {
|
switch conf.DatabaseType {
|
||||||
case postgreSQLTypeName:
|
case postgreSQLTypeName:
|
||||||
var connProducer *sqlConnectionProducer
|
connProducer := &sqlConnectionProducer{}
|
||||||
err := mapstructure.Decode(conf.ConnectionDetails, &connProducer)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
connProducer.config = conf
|
connProducer.config = conf
|
||||||
|
|
||||||
credsProducer := &sqlCredentialsProducer{
|
credsProducer := &sqlCredentialsProducer{
|
||||||
|
@ -43,11 +37,7 @@ func BuiltinFactory(conf *DatabaseConfig) (DatabaseType, error) {
|
||||||
}, nil
|
}, nil
|
||||||
|
|
||||||
case mySQLTypeName:
|
case mySQLTypeName:
|
||||||
var connProducer *sqlConnectionProducer
|
connProducer := &sqlConnectionProducer{}
|
||||||
err := mapstructure.Decode(conf.ConnectionDetails, &connProducer)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
connProducer.config = conf
|
connProducer.config = conf
|
||||||
|
|
||||||
credsProducer := &sqlCredentialsProducer{
|
credsProducer := &sqlCredentialsProducer{
|
||||||
|
@ -61,11 +51,7 @@ func BuiltinFactory(conf *DatabaseConfig) (DatabaseType, error) {
|
||||||
}, nil
|
}, nil
|
||||||
|
|
||||||
case cassandraTypeName:
|
case cassandraTypeName:
|
||||||
var connProducer *cassandraConnectionProducer
|
connProducer := &cassandraConnectionProducer{}
|
||||||
err := mapstructure.Decode(conf.ConnectionDetails, &connProducer)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
connProducer.config = conf
|
connProducer.config = conf
|
||||||
|
|
||||||
credsProducer := &cassandraCredentialsProducer{}
|
credsProducer := &cassandraCredentialsProducer{}
|
||||||
|
@ -102,6 +88,7 @@ type DatabaseType interface {
|
||||||
RenewUser(statements Statements, username, expiration string) error
|
RenewUser(statements Statements, username, expiration string) error
|
||||||
RevokeUser(statements Statements, username string) error
|
RevokeUser(statements Statements, username string) error
|
||||||
|
|
||||||
|
Initialize(map[string]interface{}) error
|
||||||
Close()
|
Close()
|
||||||
CredentialsProducer
|
CredentialsProducer
|
||||||
}
|
}
|
||||||
|
|
|
@ -140,6 +140,12 @@ func (dr *databasePluginRPCClient) RevokeUser(statements Statements, username st
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (dr *databasePluginRPCClient) Initialize(conf map[string]interface{}) error {
|
||||||
|
err := dr.client.Call("Plugin.Initialize", conf, &struct{}{})
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
func (dr *databasePluginRPCClient) Close() error {
|
func (dr *databasePluginRPCClient) Close() error {
|
||||||
err := dr.client.Call("Plugin.Close", struct{}{}, &struct{}{})
|
err := dr.client.Call("Plugin.Close", struct{}{}, &struct{}{})
|
||||||
|
|
||||||
|
@ -195,6 +201,12 @@ func (ds *databasePluginRPCServer) RevokeUser(args *RevokeUserRequest, _ *struct
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (ds *databasePluginRPCServer) Initialize(args map[string]interface{}, _ *struct{}) error {
|
||||||
|
err := ds.impl.Initialize(args)
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
func (ds *databasePluginRPCServer) Close(_ interface{}, _ *struct{}) error {
|
func (ds *databasePluginRPCServer) Close(_ interface{}, _ *struct{}) error {
|
||||||
ds.impl.Close()
|
ds.impl.Close()
|
||||||
return nil
|
return nil
|
||||||
|
|
|
@ -3,6 +3,7 @@ package database
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/fatih/structs"
|
"github.com/fatih/structs"
|
||||||
|
@ -67,6 +68,11 @@ func (b *databaseBackend) pathConnectionReset(req *logical.Request, data *framew
|
||||||
return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil
|
return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = db.Initialize(config.ConnectionDetails)
|
||||||
|
if err != nil {
|
||||||
|
return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil
|
||||||
|
}
|
||||||
|
|
||||||
b.connections[name] = db
|
b.connections[name] = db
|
||||||
|
|
||||||
return nil, nil
|
return nil, nil
|
||||||
|
@ -207,6 +213,11 @@ func (b *databaseBackend) connectionWriteHandler(factory dbs.Factory) framework.
|
||||||
}
|
}
|
||||||
|
|
||||||
name := data.Get("name").(string)
|
name := data.Get("name").(string)
|
||||||
|
if name == "" {
|
||||||
|
return logical.ErrorResponse("Empty name attribute given"), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
verifyConnection := data.Get("verify_connection").(bool)
|
||||||
|
|
||||||
// Grab the mutex lock
|
// Grab the mutex lock
|
||||||
b.Lock()
|
b.Lock()
|
||||||
|
@ -225,6 +236,19 @@ func (b *databaseBackend) connectionWriteHandler(factory dbs.Factory) framework.
|
||||||
return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil
|
return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err := db.Initialize(config.ConnectionDetails)
|
||||||
|
if err != nil {
|
||||||
|
if !strings.Contains(err.Error(), "Error Initializing Connection") {
|
||||||
|
return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
if verifyConnection {
|
||||||
|
return logical.ErrorResponse(err.Error()), nil
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
b.connections[name] = db
|
b.connections[name] = db
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue