From cf15c023dfc5723fe2c0e71d29e0513196e5eb24 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 3 May 2017 13:11:30 -0700 Subject: [PATCH] Use ParseDurationSecond to parse the timeouts in connutil --- plugins/helper/database/connutil/cassandra.go | 40 +++++++++++-------- plugins/helper/database/connutil/sql.go | 15 +++---- 2 files changed, 31 insertions(+), 24 deletions(-) diff --git a/plugins/helper/database/connutil/cassandra.go b/plugins/helper/database/connutil/cassandra.go index 1babc3cbd..27fb25195 100644 --- a/plugins/helper/database/connutil/cassandra.go +++ b/plugins/helper/database/connutil/cassandra.go @@ -11,28 +11,30 @@ import ( "github.com/gocql/gocql" "github.com/hashicorp/vault/helper/certutil" + "github.com/hashicorp/vault/helper/parseutil" "github.com/hashicorp/vault/helper/tlsutil" ) // CassandraConnectionProducer implements ConnectionProducer and provides an // interface for cassandra databases to make connections. type CassandraConnectionProducer struct { - Hosts string `json:"hosts" structs:"hosts" mapstructure:"hosts"` - Username string `json:"username" structs:"username" mapstructure:"username"` - Password string `json:"password" structs:"password" mapstructure:"password"` - TLS bool `json:"tls" structs:"tls" mapstructure:"tls"` - InsecureTLS bool `json:"insecure_tls" structs:"insecure_tls" mapstructure:"insecure_tls"` - Certificate string `json:"certificate" structs:"certificate" mapstructure:"certificate"` - PrivateKey string `json:"private_key" structs:"private_key" mapstructure:"private_key"` - IssuingCA string `json:"issuing_ca" structs:"issuing_ca" mapstructure:"issuing_ca"` - ProtocolVersion int `json:"protocol_version" structs:"protocol_version" mapstructure:"protocol_version"` - ConnectTimeout int `json:"connect_timeout" structs:"connect_timeout" mapstructure:"connect_timeout"` - TLSMinVersion string `json:"tls_min_version" structs:"tls_min_version" mapstructure:"tls_min_version"` - Consistency string `json:"consistency" structs:"consistency" mapstructure:"consistency"` + Hosts string `json:"hosts" structs:"hosts" mapstructure:"hosts"` + Username string `json:"username" structs:"username" mapstructure:"username"` + Password string `json:"password" structs:"password" mapstructure:"password"` + TLS bool `json:"tls" structs:"tls" mapstructure:"tls"` + InsecureTLS bool `json:"insecure_tls" structs:"insecure_tls" mapstructure:"insecure_tls"` + Certificate string `json:"certificate" structs:"certificate" mapstructure:"certificate"` + PrivateKey string `json:"private_key" structs:"private_key" mapstructure:"private_key"` + IssuingCA string `json:"issuing_ca" structs:"issuing_ca" mapstructure:"issuing_ca"` + ProtocolVersion int `json:"protocol_version" structs:"protocol_version" mapstructure:"protocol_version"` + ConnectTimeoutRaw interface{} `json:"connect_timeout" structs:"connect_timeout" mapstructure:"connect_timeout"` + TLSMinVersion string `json:"tls_min_version" structs:"tls_min_version" mapstructure:"tls_min_version"` + Consistency string `json:"consistency" structs:"consistency" mapstructure:"consistency"` - Initialized bool - Type string - session *gocql.Session + connectTimeout time.Duration + Initialized bool + Type string + session *gocql.Session sync.Mutex } @@ -46,6 +48,11 @@ func (c *CassandraConnectionProducer) Initialize(conf map[string]interface{}, ve } c.Initialized = true + c.connectTimeout, err = parseutil.ParseDurationSecond(c.ConnectTimeoutRaw) + if err != nil { + return fmt.Errorf("invalid connect_timeout: %s", err) + } + if verifyConnection { if _, err := c.Connection(); err != nil { return fmt.Errorf("error Initalizing Connection: %s", err) @@ -101,8 +108,7 @@ func (c *CassandraConnectionProducer) createSession() (*gocql.Session, error) { clusterConfig.ProtoVersion = 2 } - clusterConfig.Timeout = time.Duration(c.ConnectTimeout) * time.Second - + clusterConfig.Timeout = c.connectTimeout if c.TLS { var tlsConfig *tls.Config if len(c.Certificate) > 0 || len(c.IssuingCA) > 0 { diff --git a/plugins/helper/database/connutil/sql.go b/plugins/helper/database/connutil/sql.go index 0bfc5f9f6..4a6368560 100644 --- a/plugins/helper/database/connutil/sql.go +++ b/plugins/helper/database/connutil/sql.go @@ -10,19 +10,20 @@ import ( // Import sql drivers _ "github.com/denisenkom/go-mssqldb" _ "github.com/go-sql-driver/mysql" + "github.com/hashicorp/vault/helper/parseutil" _ "github.com/lib/pq" "github.com/mitchellh/mapstructure" ) // SQLConnectionProducer implements ConnectionProducer and provides a generic producer for most sql databases type SQLConnectionProducer struct { - ConnectionURL string `json:"connection_url" structs:"connection_url" mapstructure:"connection_url"` - MaxOpenConnections int `json:"max_open_connections" structs:"max_open_connections" mapstructure:"max_open_connections"` - MaxIdleConnections int `json:"max_idle_connections" structs:"max_idle_connections" mapstructure:"max_idle_connections"` - MaxConnectionLifetimeRaw string `json:"max_connection_lifetime" structs:"max_connection_lifetime" mapstructure:"max_connection_lifetime"` + ConnectionURL string `json:"connection_url" structs:"connection_url" mapstructure:"connection_url"` + MaxOpenConnections int `json:"max_open_connections" structs:"max_open_connections" mapstructure:"max_open_connections"` + MaxIdleConnections int `json:"max_idle_connections" structs:"max_idle_connections" mapstructure:"max_idle_connections"` + MaxConnectionLifetimeRaw interface{} `json:"max_connection_lifetime" structs:"max_connection_lifetime" mapstructure:"max_connection_lifetime"` Type string - MaxConnectionLifetime time.Duration + maxConnectionLifetime time.Duration Initialized bool db *sql.DB sync.Mutex @@ -51,7 +52,7 @@ func (c *SQLConnectionProducer) Initialize(conf map[string]interface{}, verifyCo c.MaxConnectionLifetimeRaw = "0s" } - c.MaxConnectionLifetime, err = time.ParseDuration(c.MaxConnectionLifetimeRaw) + c.maxConnectionLifetime, err = parseutil.ParseDurationSecond(c.MaxConnectionLifetimeRaw) if err != nil { return fmt.Errorf("invalid max_connection_lifetime: %s", err) } @@ -110,7 +111,7 @@ func (c *SQLConnectionProducer) Connection() (interface{}, error) { // since the request rate shouldn't be high. c.db.SetMaxOpenConns(c.MaxOpenConnections) c.db.SetMaxIdleConns(c.MaxIdleConnections) - c.db.SetConnMaxLifetime(c.MaxConnectionLifetime) + c.db.SetConnMaxLifetime(c.maxConnectionLifetime) return c.db, nil }