From 2054fff89063dd3aa51c404e5bc59db73dff2efc Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Mon, 13 Mar 2017 14:39:55 -0700 Subject: [PATCH] Add a way to initalize plugins and builtin databases the same way. --- .../database/dbs/connectionproducer.go | 54 +++++++++++++++++-- builtin/logical/database/dbs/db.go | 21 ++------ builtin/logical/database/dbs/plugin.go | 12 +++++ .../database/path_config_connection.go | 24 +++++++++ 4 files changed, 90 insertions(+), 21 deletions(-) diff --git a/builtin/logical/database/dbs/connectionproducer.go b/builtin/logical/database/dbs/connectionproducer.go index 82da37cc7..8d05e5d9e 100644 --- a/builtin/logical/database/dbs/connectionproducer.go +++ b/builtin/logical/database/dbs/connectionproducer.go @@ -3,6 +3,7 @@ package dbs import ( "crypto/tls" "database/sql" + "errors" "fmt" "strings" "sync" @@ -11,14 +12,20 @@ import ( // Import sql drivers _ "github.com/go-sql-driver/mysql" _ "github.com/lib/pq" + "github.com/mitchellh/mapstructure" "github.com/gocql/gocql" "github.com/hashicorp/vault/helper/certutil" "github.com/hashicorp/vault/helper/tlsutil" ) +var ( + errNotInitalized = errors.New("Connection has not been initalized") +) + type ConnectionProducer interface { Close() + Initialize(map[string]interface{}) error sync.Locker connection() (interface{}, error) @@ -30,10 +37,28 @@ type sqlConnectionProducer struct { config *DatabaseConfig - db *sql.DB + initalized bool + db *sql.DB sync.Mutex } +func (c *sqlConnectionProducer) Initialize(conf map[string]interface{}) error { + c.Lock() + defer c.Unlock() + + err := mapstructure.Decode(conf, c) + if err != nil { + return err + } + c.initalized = true + + if _, err := c.connection(); err != nil { + return fmt.Errorf("Error Initalizing Connection: %s", err) + } + + return nil +} + func (c *sqlConnectionProducer) connection() (interface{}, error) { // If we already have a DB, test it and return if c.db != nil { @@ -98,13 +123,34 @@ type cassandraConnectionProducer struct { TLSMinVersion string `json:"tls_min_version" structs:"tls_min_version" mapstructure:"tls_min_version"` Consistency string `json:"consistency" structs:"consistency" mapstructure:"consistency"` - config *DatabaseConfig - - session *gocql.Session + config *DatabaseConfig + initalized bool + session *gocql.Session sync.Mutex } +func (c *cassandraConnectionProducer) Initialize(conf map[string]interface{}) error { + c.Lock() + defer c.Unlock() + + err := mapstructure.Decode(conf, c) + if err != nil { + return err + } + c.initalized = true + + if _, err := c.connection(); err != nil { + return fmt.Errorf("Error Initalizing Connection: %s", err) + } + + return nil +} + func (c *cassandraConnectionProducer) connection() (interface{}, error) { + if !c.initalized { + return nil, errNotInitalized + } + // If we already have a DB, return it if c.session != nil { return c.session, nil diff --git a/builtin/logical/database/dbs/db.go b/builtin/logical/database/dbs/db.go index 33cf7361a..98443f8f2 100644 --- a/builtin/logical/database/dbs/db.go +++ b/builtin/logical/database/dbs/db.go @@ -5,8 +5,6 @@ import ( "fmt" "strings" "time" - - "github.com/mitchellh/mapstructure" ) const ( @@ -25,11 +23,7 @@ type Factory func(*DatabaseConfig) (DatabaseType, error) func BuiltinFactory(conf *DatabaseConfig) (DatabaseType, error) { switch conf.DatabaseType { case postgreSQLTypeName: - var connProducer *sqlConnectionProducer - err := mapstructure.Decode(conf.ConnectionDetails, &connProducer) - if err != nil { - return nil, err - } + connProducer := &sqlConnectionProducer{} connProducer.config = conf credsProducer := &sqlCredentialsProducer{ @@ -43,11 +37,7 @@ func BuiltinFactory(conf *DatabaseConfig) (DatabaseType, error) { }, nil case mySQLTypeName: - var connProducer *sqlConnectionProducer - err := mapstructure.Decode(conf.ConnectionDetails, &connProducer) - if err != nil { - return nil, err - } + connProducer := &sqlConnectionProducer{} connProducer.config = conf credsProducer := &sqlCredentialsProducer{ @@ -61,11 +51,7 @@ func BuiltinFactory(conf *DatabaseConfig) (DatabaseType, error) { }, nil case cassandraTypeName: - var connProducer *cassandraConnectionProducer - err := mapstructure.Decode(conf.ConnectionDetails, &connProducer) - if err != nil { - return nil, err - } + connProducer := &cassandraConnectionProducer{} connProducer.config = conf credsProducer := &cassandraCredentialsProducer{} @@ -102,6 +88,7 @@ type DatabaseType interface { RenewUser(statements Statements, username, expiration string) error RevokeUser(statements Statements, username string) error + Initialize(map[string]interface{}) error Close() CredentialsProducer } diff --git a/builtin/logical/database/dbs/plugin.go b/builtin/logical/database/dbs/plugin.go index bbd8d4ce4..b244a33fc 100644 --- a/builtin/logical/database/dbs/plugin.go +++ b/builtin/logical/database/dbs/plugin.go @@ -140,6 +140,12 @@ func (dr *databasePluginRPCClient) RevokeUser(statements Statements, username st return err } +func (dr *databasePluginRPCClient) Initialize(conf map[string]interface{}) error { + err := dr.client.Call("Plugin.Initialize", conf, &struct{}{}) + + return err +} + func (dr *databasePluginRPCClient) Close() error { err := dr.client.Call("Plugin.Close", struct{}{}, &struct{}{}) @@ -195,6 +201,12 @@ func (ds *databasePluginRPCServer) RevokeUser(args *RevokeUserRequest, _ *struct return err } +func (ds *databasePluginRPCServer) Initialize(args map[string]interface{}, _ *struct{}) error { + err := ds.impl.Initialize(args) + + return err +} + func (ds *databasePluginRPCServer) Close(_ interface{}, _ *struct{}) error { ds.impl.Close() return nil diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index 31f618281..6c0a63a11 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -3,6 +3,7 @@ package database import ( "errors" "fmt" + "strings" "time" "github.com/fatih/structs" @@ -67,6 +68,11 @@ func (b *databaseBackend) pathConnectionReset(req *logical.Request, data *framew return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil } + err = db.Initialize(config.ConnectionDetails) + if err != nil { + return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil + } + b.connections[name] = db return nil, nil @@ -207,6 +213,11 @@ func (b *databaseBackend) connectionWriteHandler(factory dbs.Factory) framework. } name := data.Get("name").(string) + if name == "" { + return logical.ErrorResponse("Empty name attribute given"), nil + } + + verifyConnection := data.Get("verify_connection").(bool) // Grab the mutex lock b.Lock() @@ -225,6 +236,19 @@ func (b *databaseBackend) connectionWriteHandler(factory dbs.Factory) framework. return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil } + err := db.Initialize(config.ConnectionDetails) + if err != nil { + if !strings.Contains(err.Error(), "Error Initializing Connection") { + return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil + + } + + if verifyConnection { + return logical.ErrorResponse(err.Error()), nil + + } + } + b.connections[name] = db }