Make db instances immutable and add a reset path to tear down and create a new database instance with an updated config

This commit is contained in:
Brian Kassouf 2017-02-15 16:51:59 -08:00 committed by Brian Kassouf
parent 29e07ac9e8
commit bba832e6bf
4 changed files with 125 additions and 124 deletions

View File

@ -25,6 +25,7 @@ func Backend(conf *logical.BackendConfig) *databaseBackend {
pathListRoles(&b), pathListRoles(&b),
pathRoles(&b), pathRoles(&b),
pathRoleCreate(&b), pathRoleCreate(&b),
pathResetConnection(&b),
}, },
Secrets: []*framework.Secret{ Secrets: []*framework.Secret{

View File

@ -15,47 +15,40 @@ import (
"github.com/gocql/gocql" "github.com/gocql/gocql"
"github.com/hashicorp/vault/helper/certutil" "github.com/hashicorp/vault/helper/certutil"
"github.com/hashicorp/vault/helper/tlsutil" "github.com/hashicorp/vault/helper/tlsutil"
"github.com/mitchellh/mapstructure"
) )
type ConnectionProducer interface { type ConnectionProducer interface {
Connection() (interface{}, error) Connection() (interface{}, error)
Close() Close()
// TODO: Should we make this immutable instead?
Reset(*DatabaseConfig) error
} }
// sqlConnectionProducer impliments ConnectionProducer and provides a generic producer for most sql databases // sqlConnectionProducer impliments ConnectionProducer and provides a generic producer for most sql databases
type sqlConnectionDetails struct {
ConnectionURL string `json:"connection_url" structs:"connection_url" mapstructure:"connection_url"`
}
type sqlConnectionProducer struct { type sqlConnectionProducer struct {
ConnectionURL string `json:"connection_url" structs:"connection_url" mapstructure:"connection_url"`
config *DatabaseConfig config *DatabaseConfig
// TODO: Should we merge these two structures make it immutable?
connDetails *sqlConnectionDetails
db *sql.DB db *sql.DB
sync.Mutex sync.Mutex
} }
func (cp *sqlConnectionProducer) Connection() (interface{}, error) { func (c *sqlConnectionProducer) Connection() (interface{}, error) {
// Grab the write lock // Grab the write lock
cp.Lock() c.Lock()
defer cp.Unlock() defer c.Unlock()
// If we already have a DB, we got it! // If we already have a DB, we got it!
if cp.db != nil { if c.db != nil {
if err := cp.db.Ping(); err == nil { if err := c.db.Ping(); err == nil {
return cp.db, nil return c.db, nil
} }
// If the ping was unsuccessful, close it and ignore errors as we'll be // If the ping was unsuccessful, close it and ignore errors as we'll be
// reestablishing anyways // reestablishing anyways
cp.db.Close() c.db.Close()
} }
// Otherwise, attempt to make connection // Otherwise, attempt to make connection
conn := cp.connDetails.ConnectionURL conn := c.ConnectionURL
// Ensure timezone is set to UTC for all the conenctions // Ensure timezone is set to UTC for all the conenctions
if strings.HasPrefix(conn, "postgres://") || strings.HasPrefix(conn, "postgresql://") { if strings.HasPrefix(conn, "postgres://") || strings.HasPrefix(conn, "postgresql://") {
@ -67,54 +60,33 @@ func (cp *sqlConnectionProducer) Connection() (interface{}, error) {
} }
var err error var err error
cp.db, err = sql.Open(cp.config.DatabaseType, conn) c.db, err = sql.Open(c.config.DatabaseType, conn)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// Set some connection pool settings. We don't need much of this, // Set some connection pool settings. We don't need much of this,
// since the request rate shouldn't be high. // since the request rate shouldn't be high.
cp.db.SetMaxOpenConns(cp.config.MaxOpenConnections) c.db.SetMaxOpenConns(c.config.MaxOpenConnections)
cp.db.SetMaxIdleConns(cp.config.MaxIdleConnections) c.db.SetMaxIdleConns(c.config.MaxIdleConnections)
cp.db.SetConnMaxLifetime(cp.config.MaxConnectionLifetime) c.db.SetConnMaxLifetime(c.config.MaxConnectionLifetime)
return cp.db, nil return c.db, nil
} }
func (cp *sqlConnectionProducer) Close() { func (c *sqlConnectionProducer) Close() {
// Grab the write lock // Grab the write lock
cp.Lock() c.Lock()
defer cp.Unlock() defer c.Unlock()
if cp.db != nil { if c.db != nil {
cp.db.Close() c.db.Close()
} }
cp.db = nil c.db = nil
} }
func (cp *sqlConnectionProducer) Reset(config *DatabaseConfig) error { type cassandraConnectionProducer struct {
// Grab the write lock
cp.Lock()
var details *sqlConnectionDetails
err := mapstructure.Decode(config.ConnectionDetails, &details)
if err != nil {
return err
}
cp.connDetails = details
cp.config = config
cp.Unlock()
cp.Close()
_, err = cp.Connection()
return err
}
// cassandraConnectionProducer impliments ConnectionProducer and provides connections for cassandra
type cassandraConnectionDetails struct {
Hosts string `json:"hosts" structs:"hosts" mapstructure:"hosts"` Hosts string `json:"hosts" structs:"hosts" mapstructure:"hosts"`
Username string `json:"username" structs:"username" mapstructure:"username"` Username string `json:"username" structs:"username" mapstructure:"username"`
Password string `json:"password" structs:"password" mapstructure:"password"` Password string `json:"password" structs:"password" mapstructure:"password"`
@ -127,90 +99,74 @@ type cassandraConnectionDetails struct {
ConnectTimeout int `json:"connect_timeout" structs:"connect_timeout" mapstructure:"connect_timeout"` ConnectTimeout int `json:"connect_timeout" structs:"connect_timeout" mapstructure:"connect_timeout"`
TLSMinVersion string `json:"tls_min_version" structs:"tls_min_version" mapstructure:"tls_min_version"` TLSMinVersion string `json:"tls_min_version" structs:"tls_min_version" mapstructure:"tls_min_version"`
Consistency string `json:"consistency" structs:"consistency" mapstructure:"consistency"` Consistency string `json:"consistency" structs:"consistency" mapstructure:"consistency"`
}
type cassandraConnectionProducer struct {
config *DatabaseConfig config *DatabaseConfig
// TODO: Should we merge these two structures make it immutable?
connDetails *cassandraConnectionDetails
session *gocql.Session session *gocql.Session
sync.Mutex sync.Mutex
} }
func (cp *cassandraConnectionProducer) Connection() (interface{}, error) { func (c *cassandraConnectionProducer) Connection() (interface{}, error) {
// Grab the write lock // Grab the write lock
cp.Lock() c.Lock()
defer cp.Unlock() defer c.Unlock()
// If we already have a DB, we got it! // If we already have a DB, we got it!
if cp.session != nil { if c.session != nil {
return cp.session, nil return c.session, nil
} }
session, err := cp.createSession(cp.connDetails) session, err := c.createSession()
if err != nil { if err != nil {
return nil, err return nil, err
} }
// Store the session in backend for reuse // Store the session in backend for reuse
cp.session = session c.session = session
return session, nil return session, nil
} }
func (cp *cassandraConnectionProducer) Close() { func (c *cassandraConnectionProducer) Close() {
// Grab the write lock // Grab the write lock
cp.Lock() c.Lock()
defer cp.Unlock() defer c.Unlock()
if cp.session != nil { if c.session != nil {
cp.session.Close() c.session.Close()
} }
cp.session = nil c.session = nil
} }
func (cp *cassandraConnectionProducer) Reset(config *DatabaseConfig) error { func (c *cassandraConnectionProducer) createSession() (*gocql.Session, error) {
// Grab the write lock clusterConfig := gocql.NewCluster(strings.Split(c.Hosts, ",")...)
cp.Lock()
cp.config = config
cp.Unlock()
cp.Close()
_, err := cp.Connection()
return err
}
func (cp *cassandraConnectionProducer) createSession(cfg *cassandraConnectionDetails) (*gocql.Session, error) {
clusterConfig := gocql.NewCluster(strings.Split(cfg.Hosts, ",")...)
clusterConfig.Authenticator = gocql.PasswordAuthenticator{ clusterConfig.Authenticator = gocql.PasswordAuthenticator{
Username: cfg.Username, Username: c.Username,
Password: cfg.Password, Password: c.Password,
} }
clusterConfig.ProtoVersion = cfg.ProtocolVersion clusterConfig.ProtoVersion = c.ProtocolVersion
if clusterConfig.ProtoVersion == 0 { if clusterConfig.ProtoVersion == 0 {
clusterConfig.ProtoVersion = 2 clusterConfig.ProtoVersion = 2
} }
clusterConfig.Timeout = time.Duration(cfg.ConnectTimeout) * time.Second clusterConfig.Timeout = time.Duration(c.ConnectTimeout) * time.Second
if cfg.TLS { if c.TLS {
var tlsConfig *tls.Config var tlsConfig *tls.Config
if len(cfg.Certificate) > 0 || len(cfg.IssuingCA) > 0 { if len(c.Certificate) > 0 || len(c.IssuingCA) > 0 {
if len(cfg.Certificate) > 0 && len(cfg.PrivateKey) == 0 { if len(c.Certificate) > 0 && len(c.PrivateKey) == 0 {
return nil, fmt.Errorf("Found certificate for TLS authentication but no private key") return nil, fmt.Errorf("Found certificate for TLS authentication but no private key")
} }
certBundle := &certutil.CertBundle{} certBundle := &certutil.CertBundle{}
if len(cfg.Certificate) > 0 { if len(c.Certificate) > 0 {
certBundle.Certificate = cfg.Certificate certBundle.Certificate = c.Certificate
certBundle.PrivateKey = cfg.PrivateKey certBundle.PrivateKey = c.PrivateKey
} }
if len(cfg.IssuingCA) > 0 { if len(c.IssuingCA) > 0 {
certBundle.IssuingCA = cfg.IssuingCA certBundle.IssuingCA = c.IssuingCA
} }
parsedCertBundle, err := certBundle.ToParsedCertBundle() parsedCertBundle, err := certBundle.ToParsedCertBundle()
@ -222,11 +178,11 @@ func (cp *cassandraConnectionProducer) createSession(cfg *cassandraConnectionDet
if err != nil || tlsConfig == nil { if err != nil || tlsConfig == nil {
return nil, fmt.Errorf("failed to get TLS configuration: tlsConfig:%#v err:%v", tlsConfig, err) return nil, fmt.Errorf("failed to get TLS configuration: tlsConfig:%#v err:%v", tlsConfig, err)
} }
tlsConfig.InsecureSkipVerify = cfg.InsecureTLS tlsConfig.InsecureSkipVerify = c.InsecureTLS
if cfg.TLSMinVersion != "" { if c.TLSMinVersion != "" {
var ok bool var ok bool
tlsConfig.MinVersion, ok = tlsutil.TLSLookup[cfg.TLSMinVersion] tlsConfig.MinVersion, ok = tlsutil.TLSLookup[c.TLSMinVersion]
if !ok { if !ok {
return nil, fmt.Errorf("invalid 'tls_min_version' in config") return nil, fmt.Errorf("invalid 'tls_min_version' in config")
} }
@ -248,8 +204,8 @@ func (cp *cassandraConnectionProducer) createSession(cfg *cassandraConnectionDet
} }
// Set consistency // Set consistency
if cfg.Consistency != "" { if c.Consistency != "" {
consistencyValue, err := gocql.ParseConsistencyWrapper(cfg.Consistency) consistencyValue, err := gocql.ParseConsistencyWrapper(c.Consistency)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -22,16 +22,12 @@ var (
func Factory(conf *DatabaseConfig) (DatabaseType, error) { func Factory(conf *DatabaseConfig) (DatabaseType, error) {
switch conf.DatabaseType { switch conf.DatabaseType {
case postgreSQLTypeName: case postgreSQLTypeName:
var details *sqlConnectionDetails var connProducer *sqlConnectionProducer
err := mapstructure.Decode(conf.ConnectionDetails, &details) err := mapstructure.Decode(conf.ConnectionDetails, &connProducer)
if err != nil { if err != nil {
return nil, err return nil, err
} }
connProducer.config = conf
connProducer := &sqlConnectionProducer{
config: conf,
connDetails: details,
}
credsProducer := &sqlCredentialsProducer{ credsProducer := &sqlCredentialsProducer{
displayNameLen: 23, displayNameLen: 23,
@ -44,16 +40,12 @@ func Factory(conf *DatabaseConfig) (DatabaseType, error) {
}, nil }, nil
case mySQLTypeName: case mySQLTypeName:
var details *sqlConnectionDetails var connProducer *sqlConnectionProducer
err := mapstructure.Decode(conf.ConnectionDetails, &details) err := mapstructure.Decode(conf.ConnectionDetails, &connProducer)
if err != nil { if err != nil {
return nil, err return nil, err
} }
connProducer.config = conf
connProducer := &sqlConnectionProducer{
config: conf,
connDetails: details,
}
credsProducer := &sqlCredentialsProducer{ credsProducer := &sqlCredentialsProducer{
displayNameLen: 4, displayNameLen: 4,
@ -66,16 +58,12 @@ func Factory(conf *DatabaseConfig) (DatabaseType, error) {
}, nil }, nil
case cassandraTypeName: case cassandraTypeName:
var details *cassandraConnectionDetails var connProducer *cassandraConnectionProducer
err := mapstructure.Decode(conf.ConnectionDetails, &details) err := mapstructure.Decode(conf.ConnectionDetails, &connProducer)
if err != nil { if err != nil {
return nil, err return nil, err
} }
connProducer.config = conf
connProducer := &cassandraConnectionProducer{
config: conf,
connDetails: details,
}
credsProducer := &cassandraCredentialsProducer{} credsProducer := &cassandraCredentialsProducer{}

View File

@ -1,6 +1,7 @@
package database package database
import ( import (
"errors"
"fmt" "fmt"
"time" "time"
@ -10,6 +11,64 @@ import (
"github.com/hashicorp/vault/logical/framework" "github.com/hashicorp/vault/logical/framework"
) )
func pathResetConnection(b *databaseBackend) *framework.Path {
return &framework.Path{
Pattern: fmt.Sprintf("reset/%s", framework.GenericNameRegex("name")),
Fields: map[string]*framework.FieldSchema{
"name": &framework.FieldSchema{
Type: framework.TypeString,
Description: "Name of this DB type",
},
},
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.UpdateOperation: b.pathConnectionReset,
},
HelpSynopsis: pathConfigConnectionHelpSyn,
HelpDescription: pathConfigConnectionHelpDesc,
}
}
func (b *databaseBackend) pathConnectionReset(req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
name := data.Get("name").(string)
if name == "" {
return nil, errors.New("No database name set")
}
// Grab the mutex lock
b.Lock()
defer b.Unlock()
entry, err := req.Storage.Get(fmt.Sprintf("dbs/%s", name))
if err != nil {
return nil, fmt.Errorf("failed to read connection configuration")
}
if entry == nil {
return nil, nil
}
var config dbs.DatabaseConfig
if err := entry.DecodeJSON(&config); err != nil {
return nil, err
}
db, ok := b.connections[name]
if !ok {
return logical.ErrorResponse("Can not change type of existing connection."), nil
}
db.Close()
db, err = dbs.Factory(&config)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil
}
b.connections[name] = db
return nil, nil
}
func pathConfigConnection(b *databaseBackend) *framework.Path { func pathConfigConnection(b *databaseBackend) *framework.Path {
return &framework.Path{ return &framework.Path{
Pattern: fmt.Sprintf("dbs/%s", framework.GenericNameRegex("name")), Pattern: fmt.Sprintf("dbs/%s", framework.GenericNameRegex("name")),
@ -129,13 +188,13 @@ func (b *databaseBackend) pathConnectionWrite(req *logical.Request, data *framew
if b.connections[name].Type() != connType { if b.connections[name].Type() != connType {
return logical.ErrorResponse("Can not change type of existing connection."), nil return logical.ErrorResponse("Can not change type of existing connection."), nil
} }
db = b.connections[name]
} else { } else {
db, err = dbs.Factory(config) db, err = dbs.Factory(config)
if err != nil { if err != nil {
return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil
} }
b.connections[name] = db
} }
/* /*
@ -166,9 +225,6 @@ func (b *databaseBackend) pathConnectionWrite(req *logical.Request, data *framew
} }
// Reset the DB connection // Reset the DB connection
db.Reset(config)
b.connections[name] = db
resp := &logical.Response{} resp := &logical.Response{}
resp.AddWarning("Read access to this endpoint should be controlled via ACLs as it will return the connection string or URL as it is, including passwords, if any.") resp.AddWarning("Read access to this endpoint should be controlled via ACLs as it will return the connection string or URL as it is, including passwords, if any.")