Use ParseDurationSecond to parse the timeouts in connutil

This commit is contained in:
Brian Kassouf 2017-05-03 13:11:30 -07:00
parent 37bd3ed76e
commit cf15c023df
2 changed files with 31 additions and 24 deletions

View file

@ -11,28 +11,30 @@ import (
"github.com/gocql/gocql"
"github.com/hashicorp/vault/helper/certutil"
"github.com/hashicorp/vault/helper/parseutil"
"github.com/hashicorp/vault/helper/tlsutil"
)
// CassandraConnectionProducer implements ConnectionProducer and provides an
// interface for cassandra databases to make connections.
type CassandraConnectionProducer struct {
Hosts string `json:"hosts" structs:"hosts" mapstructure:"hosts"`
Username string `json:"username" structs:"username" mapstructure:"username"`
Password string `json:"password" structs:"password" mapstructure:"password"`
TLS bool `json:"tls" structs:"tls" mapstructure:"tls"`
InsecureTLS bool `json:"insecure_tls" structs:"insecure_tls" mapstructure:"insecure_tls"`
Certificate string `json:"certificate" structs:"certificate" mapstructure:"certificate"`
PrivateKey string `json:"private_key" structs:"private_key" mapstructure:"private_key"`
IssuingCA string `json:"issuing_ca" structs:"issuing_ca" mapstructure:"issuing_ca"`
ProtocolVersion int `json:"protocol_version" structs:"protocol_version" mapstructure:"protocol_version"`
ConnectTimeout int `json:"connect_timeout" structs:"connect_timeout" mapstructure:"connect_timeout"`
TLSMinVersion string `json:"tls_min_version" structs:"tls_min_version" mapstructure:"tls_min_version"`
Consistency string `json:"consistency" structs:"consistency" mapstructure:"consistency"`
Hosts string `json:"hosts" structs:"hosts" mapstructure:"hosts"`
Username string `json:"username" structs:"username" mapstructure:"username"`
Password string `json:"password" structs:"password" mapstructure:"password"`
TLS bool `json:"tls" structs:"tls" mapstructure:"tls"`
InsecureTLS bool `json:"insecure_tls" structs:"insecure_tls" mapstructure:"insecure_tls"`
Certificate string `json:"certificate" structs:"certificate" mapstructure:"certificate"`
PrivateKey string `json:"private_key" structs:"private_key" mapstructure:"private_key"`
IssuingCA string `json:"issuing_ca" structs:"issuing_ca" mapstructure:"issuing_ca"`
ProtocolVersion int `json:"protocol_version" structs:"protocol_version" mapstructure:"protocol_version"`
ConnectTimeoutRaw interface{} `json:"connect_timeout" structs:"connect_timeout" mapstructure:"connect_timeout"`
TLSMinVersion string `json:"tls_min_version" structs:"tls_min_version" mapstructure:"tls_min_version"`
Consistency string `json:"consistency" structs:"consistency" mapstructure:"consistency"`
Initialized bool
Type string
session *gocql.Session
connectTimeout time.Duration
Initialized bool
Type string
session *gocql.Session
sync.Mutex
}
@ -46,6 +48,11 @@ func (c *CassandraConnectionProducer) Initialize(conf map[string]interface{}, ve
}
c.Initialized = true
c.connectTimeout, err = parseutil.ParseDurationSecond(c.ConnectTimeoutRaw)
if err != nil {
return fmt.Errorf("invalid connect_timeout: %s", err)
}
if verifyConnection {
if _, err := c.Connection(); err != nil {
return fmt.Errorf("error Initalizing Connection: %s", err)
@ -101,8 +108,7 @@ func (c *CassandraConnectionProducer) createSession() (*gocql.Session, error) {
clusterConfig.ProtoVersion = 2
}
clusterConfig.Timeout = time.Duration(c.ConnectTimeout) * time.Second
clusterConfig.Timeout = c.connectTimeout
if c.TLS {
var tlsConfig *tls.Config
if len(c.Certificate) > 0 || len(c.IssuingCA) > 0 {

View file

@ -10,19 +10,20 @@ import (
// Import sql drivers
_ "github.com/denisenkom/go-mssqldb"
_ "github.com/go-sql-driver/mysql"
"github.com/hashicorp/vault/helper/parseutil"
_ "github.com/lib/pq"
"github.com/mitchellh/mapstructure"
)
// SQLConnectionProducer implements ConnectionProducer and provides a generic producer for most sql databases
type SQLConnectionProducer struct {
ConnectionURL string `json:"connection_url" structs:"connection_url" mapstructure:"connection_url"`
MaxOpenConnections int `json:"max_open_connections" structs:"max_open_connections" mapstructure:"max_open_connections"`
MaxIdleConnections int `json:"max_idle_connections" structs:"max_idle_connections" mapstructure:"max_idle_connections"`
MaxConnectionLifetimeRaw string `json:"max_connection_lifetime" structs:"max_connection_lifetime" mapstructure:"max_connection_lifetime"`
ConnectionURL string `json:"connection_url" structs:"connection_url" mapstructure:"connection_url"`
MaxOpenConnections int `json:"max_open_connections" structs:"max_open_connections" mapstructure:"max_open_connections"`
MaxIdleConnections int `json:"max_idle_connections" structs:"max_idle_connections" mapstructure:"max_idle_connections"`
MaxConnectionLifetimeRaw interface{} `json:"max_connection_lifetime" structs:"max_connection_lifetime" mapstructure:"max_connection_lifetime"`
Type string
MaxConnectionLifetime time.Duration
maxConnectionLifetime time.Duration
Initialized bool
db *sql.DB
sync.Mutex
@ -51,7 +52,7 @@ func (c *SQLConnectionProducer) Initialize(conf map[string]interface{}, verifyCo
c.MaxConnectionLifetimeRaw = "0s"
}
c.MaxConnectionLifetime, err = time.ParseDuration(c.MaxConnectionLifetimeRaw)
c.maxConnectionLifetime, err = parseutil.ParseDurationSecond(c.MaxConnectionLifetimeRaw)
if err != nil {
return fmt.Errorf("invalid max_connection_lifetime: %s", err)
}
@ -110,7 +111,7 @@ func (c *SQLConnectionProducer) Connection() (interface{}, error) {
// since the request rate shouldn't be high.
c.db.SetMaxOpenConns(c.MaxOpenConnections)
c.db.SetMaxIdleConns(c.MaxIdleConnections)
c.db.SetConnMaxLifetime(c.MaxConnectionLifetime)
c.db.SetConnMaxLifetime(c.maxConnectionLifetime)
return c.db, nil
}