Add max connection lifetime param and set consistancy on cassandra session

This commit is contained in:
Brian Kassouf 2017-01-04 11:28:30 -08:00 committed by Brian Kassouf
parent 1f009518cd
commit 8e8f260d96
3 changed files with 37 additions and 16 deletions

View File

@ -78,6 +78,7 @@ func (cp *sqlConnectionProducer) Connection() (interface{}, error) {
// since the request rate shouldn't be high.
cp.db.SetMaxOpenConns(cp.config.MaxOpenConnections)
cp.db.SetMaxIdleConns(cp.config.MaxIdleConnections)
cp.db.SetConnMaxLifetime(cp.config.MaxConnectionLifetime)
return cp.db, nil
}
@ -127,7 +128,7 @@ type cassandraConnectionDetails struct {
ProtocolVersion int `json:"protocol_version" structs:"protocol_version" mapstructure:"protocol_version"`
ConnectTimeout int `json:"connect_timeout" structs:"connect_timeout" mapstructure:"connect_timeout"`
TLSMinVersion string `json:"tls_min_version" structs:"tls_min_version" mapstructure:"tls_min_version"`
Consistancy string `json:"tls_min_version" structs:"tls_min_version" mapstructure:"tls_min_version"`
Consistency string `json:"consistency" structs:"consistency" mapstructure:"consistency"`
}
type cassandraConnectionProducer struct {
@ -248,6 +249,16 @@ func (cp *cassandraConnectionProducer) createSession(cfg *cassandraConnectionDet
return nil, fmt.Errorf("Error creating session: %s", err)
}
// Set consistency
if cfg.Consistency != "" {
consistencyValue, err := gocql.ParseConsistencyWrapper(cfg.Consistency)
if err != nil {
return nil, err
}
session.SetConsistency(consistencyValue)
}
// Verify the info
err = session.Query(`LIST USERS`).Exec()
if err != nil {

View File

@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"strings"
"time"
"github.com/mitchellh/mapstructure"
)
@ -79,6 +80,7 @@ type DatabaseConfig struct {
ConnectionDetails map[string]interface{} `json:"connection_details" structs:"connection_details" mapstructure:"connection_details"`
MaxOpenConnections int `json:"max_open_connections" structs:"max_open_connections" mapstructure:"max_open_connections"`
MaxIdleConnections int `json:"max_idle_connections" structs:"max_idle_connections" mapstructure:"max_idle_connections"`
MaxConnectionLifetime time.Duration `json:"max_connection_lifetime" structs:"max_connection_lifetime" mapstructure:"max_connection_lifetime"`
}
// Query templates a query for us.

View File

@ -2,12 +2,12 @@ package database
import (
"fmt"
"time"
"github.com/fatih/structs"
"github.com/hashicorp/vault/builtin/logical/database/dbs"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
_ "github.com/lib/pq"
)
func pathConfigConnection(b *databaseBackend) *framework.Path {
@ -24,11 +24,6 @@ func pathConfigConnection(b *databaseBackend) *framework.Path {
Description: "DB type (e.g. postgres)",
},
"connection_url": &framework.FieldSchema{
Type: framework.TypeString,
Description: "DB connection string",
},
"connection_details": &framework.FieldSchema{
Type: framework.TypeMap,
Description: "Connection details for specified connection type.",
@ -55,6 +50,12 @@ and a negative value disables idle connections.
If larger than max_open_connections it will be
reduced to the same size.`,
},
"max_connection_lifetime": &framework.FieldSchema{
Type: framework.TypeInt,
Description: `Maximum amount of time a connection may be reused;
a zero or negative value reuses connections forever.`,
},
},
Callbacks: map[logical.Operation]framework.OperationFunc{
@ -105,11 +106,19 @@ func (b *databaseBackend) pathConnectionWrite(req *logical.Request, data *framew
maxIdleConns = maxOpenConns
}
maxConnLifetimeRaw := data.Get("max_connection_lifetime").(string)
maxConnLifetime, err := time.ParseDuration(maxConnLifetimeRaw)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf(
"Invalid max_connection_lifetime: %s", err)), nil
}
config := &dbs.DatabaseConfig{
DatabaseType: connType,
ConnectionDetails: connDetails,
MaxOpenConnections: maxOpenConns,
MaxIdleConnections: maxIdleConns,
MaxConnectionLifetime: maxConnLifetime,
}
name := data.Get("name").(string)
@ -118,7 +127,6 @@ func (b *databaseBackend) pathConnectionWrite(req *logical.Request, data *framew
b.Lock()
defer b.Unlock()
var err error
var db dbs.DatabaseType
if _, ok := b.connections[name]; ok {