More work on refactor and cassandra database

This commit is contained in:
Brian Kassouf 2016-12-20 11:46:20 -08:00 committed by Brian Kassouf
parent acdcd79af3
commit 2ec5ab5616
11 changed files with 553 additions and 434 deletions

View File

@ -22,7 +22,6 @@ func Backend(conf *logical.BackendConfig) *databaseBackend {
Paths: []*framework.Path{
pathConfigConnection(&b),
pathConfigLease(&b),
pathListRoles(&b),
pathRoles(&b),
pathRoleCreate(&b),
@ -61,24 +60,6 @@ func (b *databaseBackend) resetAllDBs() {
}
}
// Lease returns the lease information
func (b *databaseBackend) Lease(s logical.Storage) (*configLease, error) {
entry, err := s.Get("config/lease")
if err != nil {
return nil, err
}
if entry == nil {
return nil, nil
}
var result configLease
if err := entry.DecodeJSON(&result); err != nil {
return nil, err
}
return &result, nil
}
func (b *databaseBackend) Role(s logical.Storage, n string) (*roleEntry, error) {
entry, err := s.Get("role/" + n)
if err != nil {

View File

@ -1,25 +1,20 @@
package dbs
import (
"crypto/tls"
"database/sql"
"fmt"
"strings"
"sync"
"time"
"github.com/gocql/gocql"
"github.com/hashicorp/vault/helper/certutil"
"github.com/hashicorp/vault/helper/tlsutil"
"github.com/hashicorp/vault/helper/strutil"
)
type Cassandra struct {
// Session is goroutine safe, however, since we reinitialize
// it when connection info changes, we want to make sure we
// can close it and use a new connection; hence the lock
session *gocql.Session
config ConnectionConfig
ConnectionProducer
CredentialsProducer
sync.RWMutex
}
@ -27,168 +22,85 @@ func (c *Cassandra) Type() string {
return cassandraTypeName
}
func (c *Cassandra) Connection() (*gocql.Session, error) {
// Grab the write lock
c.Lock()
defer c.Unlock()
// If we already have a DB, we got it!
if c.session != nil {
return c.session, nil
}
session, err := createSession(c.config)
func (c *Cassandra) getConnection() (*gocql.Session, error) {
session, err := c.Connection()
if err != nil {
return nil, err
}
// Store the session in backend for reuse
c.session = session
return session, nil
return session.(*gocql.Session), nil
}
func (p *Cassandra) Close() {
// Grab the write lock
p.Lock()
defer p.Unlock()
if p.session != nil {
p.session.Close()
}
p.session = nil
}
func (p *Cassandra) Reset(config ConnectionConfig) (*sql.DB, error) {
// Grab the write lock
p.Lock()
p.config = config
p.Unlock()
p.Close()
return p.Connection()
}
func (p *Cassandra) CreateUser(createStmt, username, password, expiration string) error {
func (c *Cassandra) CreateUser(createStmt, rollbackStmt, username, password, expiration string) error {
// Get the connection
db, err := p.Connection()
session, err := c.getConnection()
if err != nil {
return err
}
// TODO: This is racey
// Grab a read lock
p.RLock()
defer p.RUnlock()
c.RLock()
defer c.RUnlock()
return nil
}
// Set consistency
/* if .Consistency != "" {
consistencyValue, err := gocql.ParseConsistencyWrapper(role.Consistency)
if err != nil {
return err
}
func (p *Cassandra) RenewUser(username, expiration string) error {
db, err := p.Connection()
if err != nil {
return err
session.SetConsistency(consistencyValue)
}*/
// Execute each query
for _, query := range strutil.ParseArbitraryStringSlice(createStmt, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
}
err = session.Query(queryHelper(query, map[string]string{
"username": username,
"password": password,
})).Exec()
if err != nil {
for _, query := range strutil.ParseArbitraryStringSlice(rollbackStmt, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
}
session.Query(queryHelper(query, map[string]string{
"username": username,
"password": password,
})).Exec()
}
return err
}
}
// TODO: This is Racey
// Grab the read lock
p.RLock()
defer p.RUnlock()
return nil
}
func (p *Cassandra) CustomRevokeUser(username, revocationSQL string) error {
db, err := p.Connection()
func (c *Cassandra) RenewUser(username, expiration string) error {
// NOOP
return nil
}
func (c *Cassandra) RevokeUser(username, revocationSQL string) error {
session, err := c.getConnection()
if err != nil {
return err
}
// TODO: this is Racey
p.RLock()
defer p.RUnlock()
c.RLock()
defer c.RUnlock()
err = session.Query(fmt.Sprintf("DROP USER '%s'", username)).Exec()
if err != nil {
return fmt.Errorf("error removing user %s", username)
}
return nil
}
func (p *Cassandra) DefaultRevokeUser(username string) error {
// Grab the read lock
p.RLock()
defer p.RUnlock()
db, err := p.Connection()
return nil
}
func createSession(cfg *ConnectionConfig) (*gocql.Session, error) {
clusterConfig := gocql.NewCluster(strings.Split(cfg.Hosts, ",")...)
clusterConfig.Authenticator = gocql.PasswordAuthenticator{
Username: cfg.Username,
Password: cfg.Password,
}
clusterConfig.ProtoVersion = cfg.ProtocolVersion
if clusterConfig.ProtoVersion == 0 {
clusterConfig.ProtoVersion = 2
}
clusterConfig.Timeout = time.Duration(cfg.ConnectTimeout) * time.Second
if cfg.TLS {
var tlsConfig *tls.Config
if len(cfg.Certificate) > 0 || len(cfg.IssuingCA) > 0 {
if len(cfg.Certificate) > 0 && len(cfg.PrivateKey) == 0 {
return nil, fmt.Errorf("Found certificate for TLS authentication but no private key")
}
certBundle := &certutil.CertBundle{}
if len(cfg.Certificate) > 0 {
certBundle.Certificate = cfg.Certificate
certBundle.PrivateKey = cfg.PrivateKey
}
if len(cfg.IssuingCA) > 0 {
certBundle.IssuingCA = cfg.IssuingCA
}
parsedCertBundle, err := certBundle.ToParsedCertBundle()
if err != nil {
return nil, fmt.Errorf("failed to parse certificate bundle: %s", err)
}
tlsConfig, err = parsedCertBundle.GetTLSConfig(certutil.TLSClient)
if err != nil || tlsConfig == nil {
return nil, fmt.Errorf("failed to get TLS configuration: tlsConfig:%#v err:%v", tlsConfig, err)
}
tlsConfig.InsecureSkipVerify = cfg.InsecureTLS
if cfg.TLSMinVersion != "" {
var ok bool
tlsConfig.MinVersion, ok = tlsutil.TLSLookup[cfg.TLSMinVersion]
if !ok {
return nil, fmt.Errorf("invalid 'tls_min_version' in config")
}
} else {
// MinVersion was not being set earlier. Reset it to
// zero to gracefully handle upgrades.
tlsConfig.MinVersion = 0
}
}
clusterConfig.SslOpts = &gocql.SslOptions{
Config: *tlsConfig,
}
}
session, err := clusterConfig.CreateSession()
if err != nil {
return nil, fmt.Errorf("Error creating session: %s", err)
}
// Verify the info
err = session.Query(`LIST USERS`).Exec()
if err != nil {
return nil, fmt.Errorf("Error validating connection info: %s", err)
}
return session, nil
}

View File

@ -0,0 +1,254 @@
package dbs
import (
"crypto/tls"
"database/sql"
"fmt"
"strings"
"sync"
"time"
"github.com/gocql/gocql"
"github.com/hashicorp/vault/helper/certutil"
"github.com/hashicorp/vault/helper/tlsutil"
"github.com/mitchellh/mapstructure"
)
type ConnectionProducer interface {
Connection() (interface{}, error)
Close()
// TODO: Should we make this immutable instead?
Reset(*DatabaseConfig) error
}
// sqlConnectionProducer impliments ConnectionProducer and provides a generic producer for most sql databases
type sqlConnectionDetails struct {
ConnectionURL string `json:"connection_url" structs:"connection_url" mapstructure:"connection_url"`
}
type sqlConnectionProducer struct {
config *DatabaseConfig
// TODO: Should we merge these two structures make it immutable?
connDetails *sqlConnectionDetails
db *sql.DB
sync.Mutex
}
func (cp *sqlConnectionProducer) Connection() (interface{}, error) {
// Grab the write lock
cp.Lock()
defer cp.Unlock()
// If we already have a DB, we got it!
if cp.db != nil {
if err := cp.db.Ping(); err == nil {
return cp.db, nil
}
// If the ping was unsuccessful, close it and ignore errors as we'll be
// reestablishing anyways
cp.db.Close()
}
// Otherwise, attempt to make connection
conn := cp.connDetails.ConnectionURL
// Ensure timezone is set to UTC for all the conenctions
if strings.HasPrefix(conn, "postgres://") || strings.HasPrefix(conn, "postgresql://") {
if strings.Contains(conn, "?") {
conn += "&timezone=utc"
} else {
conn += "?timezone=utc"
}
} else {
conn += " timezone=utc"
}
var err error
cp.db, err = sql.Open(cp.config.DatabaseType, conn)
if err != nil {
return nil, err
}
// Set some connection pool settings. We don't need much of this,
// since the request rate shouldn't be high.
cp.db.SetMaxOpenConns(cp.config.MaxOpenConnections)
cp.db.SetMaxIdleConns(cp.config.MaxIdleConnections)
return cp.db, nil
}
func (cp *sqlConnectionProducer) Close() {
// Grab the write lock
cp.Lock()
defer cp.Unlock()
if cp.db != nil {
cp.db.Close()
}
cp.db = nil
}
func (cp *sqlConnectionProducer) Reset(config *DatabaseConfig) error {
// Grab the write lock
cp.Lock()
var details *sqlConnectionDetails
err := mapstructure.Decode(config.ConnectionDetails, &details)
if err != nil {
return err
}
cp.connDetails = details
cp.config = config
cp.Unlock()
cp.Close()
_, err = cp.Connection()
return err
}
// cassandraConnectionProducer impliments ConnectionProducer and provides connections for cassandra
type cassandraConnectionDetails struct {
Hosts string `json:"hosts" structs:"hosts" mapstructure:"hosts"`
Username string `json:"username" structs:"username" mapstructure:"username"`
Password string `json:"password" structs:"password" mapstructure:"password"`
TLS bool `json:"tls" structs:"tls" mapstructure:"tls"`
InsecureTLS bool `json:"insecure_tls" structs:"insecure_tls" mapstructure:"insecure_tls"`
Certificate string `json:"certificate" structs:"certificate" mapstructure:"certificate"`
PrivateKey string `json:"private_key" structs:"private_key" mapstructure:"private_key"`
IssuingCA string `json:"issuing_ca" structs:"issuing_ca" mapstructure:"issuing_ca"`
ProtocolVersion int `json:"protocol_version" structs:"protocol_version" mapstructure:"protocol_version"`
ConnectTimeout int `json:"connect_timeout" structs:"connect_timeout" mapstructure:"connect_timeout"`
TLSMinVersion string `json:"tls_min_version" structs:"tls_min_version" mapstructure:"tls_min_version"`
Consistancy string `json:"tls_min_version" structs:"tls_min_version" mapstructure:"tls_min_version"`
}
type cassandraConnectionProducer struct {
config *DatabaseConfig
// TODO: Should we merge these two structures make it immutable?
connDetails *cassandraConnectionDetails
session *gocql.Session
sync.Mutex
}
func (cp *cassandraConnectionProducer) Connection() (interface{}, error) {
// Grab the write lock
cp.Lock()
defer cp.Unlock()
// If we already have a DB, we got it!
if cp.session != nil {
return cp.session, nil
}
session, err := cp.createSession(cp.connDetails)
if err != nil {
return nil, err
}
// Store the session in backend for reuse
cp.session = session
return session, nil
}
func (cp *cassandraConnectionProducer) Close() {
// Grab the write lock
cp.Lock()
defer cp.Unlock()
if cp.session != nil {
cp.session.Close()
}
cp.session = nil
}
func (cp *cassandraConnectionProducer) Reset(config *DatabaseConfig) error {
// Grab the write lock
cp.Lock()
cp.config = config
cp.Unlock()
cp.Close()
_, err := cp.Connection()
return err
}
func (cp *cassandraConnectionProducer) createSession(cfg *cassandraConnectionDetails) (*gocql.Session, error) {
clusterConfig := gocql.NewCluster(strings.Split(cfg.Hosts, ",")...)
clusterConfig.Authenticator = gocql.PasswordAuthenticator{
Username: cfg.Username,
Password: cfg.Password,
}
clusterConfig.ProtoVersion = cfg.ProtocolVersion
if clusterConfig.ProtoVersion == 0 {
clusterConfig.ProtoVersion = 2
}
clusterConfig.Timeout = time.Duration(cfg.ConnectTimeout) * time.Second
if cfg.TLS {
var tlsConfig *tls.Config
if len(cfg.Certificate) > 0 || len(cfg.IssuingCA) > 0 {
if len(cfg.Certificate) > 0 && len(cfg.PrivateKey) == 0 {
return nil, fmt.Errorf("Found certificate for TLS authentication but no private key")
}
certBundle := &certutil.CertBundle{}
if len(cfg.Certificate) > 0 {
certBundle.Certificate = cfg.Certificate
certBundle.PrivateKey = cfg.PrivateKey
}
if len(cfg.IssuingCA) > 0 {
certBundle.IssuingCA = cfg.IssuingCA
}
parsedCertBundle, err := certBundle.ToParsedCertBundle()
if err != nil {
return nil, fmt.Errorf("failed to parse certificate bundle: %s", err)
}
tlsConfig, err = parsedCertBundle.GetTLSConfig(certutil.TLSClient)
if err != nil || tlsConfig == nil {
return nil, fmt.Errorf("failed to get TLS configuration: tlsConfig:%#v err:%v", tlsConfig, err)
}
tlsConfig.InsecureSkipVerify = cfg.InsecureTLS
if cfg.TLSMinVersion != "" {
var ok bool
tlsConfig.MinVersion, ok = tlsutil.TLSLookup[cfg.TLSMinVersion]
if !ok {
return nil, fmt.Errorf("invalid 'tls_min_version' in config")
}
} else {
// MinVersion was not being set earlier. Reset it to
// zero to gracefully handle upgrades.
tlsConfig.MinVersion = 0
}
}
clusterConfig.SslOpts = &gocql.SslOptions{
Config: *tlsConfig,
}
}
session, err := clusterConfig.CreateSession()
if err != nil {
return nil, fmt.Errorf("Error creating session: %s", err)
}
// Verify the info
err = session.Query(`LIST USERS`).Exec()
if err != nil {
return nil, fmt.Errorf("Error validating connection info: %s", err)
}
return session, nil
}

View File

@ -0,0 +1,79 @@
package dbs
import (
"fmt"
"strings"
"time"
uuid "github.com/hashicorp/go-uuid"
)
type CredentialsProducer interface {
GenerateUsername(displayName string) (string, error)
GeneratePassword() (string, error)
GenerateExpiration(ttl time.Duration) string
}
// sqlCredentialsProducer impliments CredentialsProducer and provides a generic credentials producer for most sql database types.
type sqlCredentialsProducer struct {
displayNameLen int
usernameLen int
}
func (scg *sqlCredentialsProducer) GenerateUsername(displayName string) (string, error) {
// Generate the username, password and expiration. PG limits user to 63 characters
if scg.displayNameLen > 0 && len(displayName) > scg.displayNameLen {
displayName = displayName[:scg.displayNameLen]
}
userUUID, err := uuid.GenerateUUID()
if err != nil {
return "", err
}
username := fmt.Sprintf("%s-%s", displayName, userUUID)
if scg.usernameLen > 0 && len(username) > scg.usernameLen {
username = username[:scg.usernameLen]
}
return username, nil
}
func (scg *sqlCredentialsProducer) GeneratePassword() (string, error) {
password, err := uuid.GenerateUUID()
if err != nil {
return "", err
}
return password, nil
}
func (scg *sqlCredentialsProducer) GenerateExpiration(ttl time.Duration) string {
return time.Now().
Add(ttl).
Format("2006-01-02 15:04:05-0700")
}
type cassandraCredentialsProducer struct{}
func (ccp *cassandraCredentialsProducer) GenerateUsername(displayName string) (string, error) {
userUUID, err := uuid.GenerateUUID()
if err != nil {
return "", err
}
username := fmt.Sprintf("vault_%s_%s_%d", displayName, userUUID, time.Now().Unix())
username = strings.Replace(username, "-", "_", -1)
return username, nil
}
func (ccp *cassandraCredentialsProducer) GeneratePassword() (string, error) {
password, err := uuid.GenerateUUID()
if err != nil {
return "", err
}
return password, nil
}
func (ccp *cassandraCredentialsProducer) GenerateExpiration(ttl time.Duration) string {
return ""
}

View File

@ -1,10 +1,11 @@
package dbs
import (
"database/sql"
"errors"
"fmt"
"strings"
"github.com/mitchellh/mapstructure"
)
const (
@ -16,11 +17,47 @@ var (
ErrUnsupportedDatabaseType = errors.New("Unsupported database type")
)
func Factory(conf ConnectionConfig) (DatabaseType, error) {
switch conf.ConnectionType {
func Factory(conf *DatabaseConfig) (DatabaseType, error) {
switch conf.DatabaseType {
case postgreSQLTypeName:
var details *sqlConnectionDetails
err := mapstructure.Decode(conf.ConnectionDetails, &details)
if err != nil {
return nil, err
}
connProducer := &sqlConnectionProducer{
config: conf,
connDetails: details,
}
credsProducer := &sqlCredentialsProducer{
displayNameLen: 23,
usernameLen: 63,
}
return &PostgreSQL{
config: conf,
ConnectionProducer: connProducer,
CredentialsProducer: credsProducer,
}, nil
case cassandraTypeName:
var details *cassandraConnectionDetails
err := mapstructure.Decode(conf.ConnectionDetails, &details)
if err != nil {
return nil, err
}
connProducer := &cassandraConnectionProducer{
config: conf,
connDetails: details,
}
credsProducer := &cassandraCredentialsProducer{}
return &Cassandra{
ConnectionProducer: connProducer,
CredentialsProducer: credsProducer,
}, nil
}
@ -29,21 +66,19 @@ func Factory(conf ConnectionConfig) (DatabaseType, error) {
type DatabaseType interface {
Type() string
Connection() (*sql.DB, error)
Close()
Reset(ConnectionConfig) (*sql.DB, error)
CreateUser(createStmt, username, password, expiration string) error
CreateUser(createStmt, rollbackStmt, username, password, expiration string) error
RenewUser(username, expiration string) error
CustomRevokeUser(username, revocationSQL string) error
DefaultRevokeUser(username string) error
RevokeUser(username, revocationStmt string) error
ConnectionProducer
CredentialsProducer
}
type ConnectionConfig struct {
ConnectionType string `json:"type" structs:"type" mapstructure:"type"`
ConnectionURL string `json:"connection_url" structs:"connection_url" mapstructure:"connection_url"`
ConnectionDetails map[string]string `json:"connection_details" structs:"connection_details" mapstructure:"connection_details"`
MaxOpenConnections int `json:"max_open_connections" structs:"max_open_connections" mapstructure:"max_open_connections"`
MaxIdleConnections int `json:"max_idle_connections" structs:"max_idle_connections" mapstructure:"max_idle_connections"`
type DatabaseConfig struct {
DatabaseType string `json:"type" structs:"type" mapstructure:"type"`
ConnectionDetails map[string]interface{} `json:"connection_details" structs:"connection_details" mapstructure:"connection_details"`
MaxOpenConnections int `json:"max_open_connections" structs:"max_open_connections" mapstructure:"max_open_connections"`
MaxIdleConnections int `json:"max_idle_connections" structs:"max_idle_connections" mapstructure:"max_idle_connections"`
}
// Query templates a query for us.

View File

@ -11,9 +11,10 @@ import (
)
type PostgreSQL struct {
db *sql.DB
config ConnectionConfig
db *sql.DB
ConnectionProducer
CredentialsProducer
sync.RWMutex
}
@ -21,74 +22,18 @@ func (p *PostgreSQL) Type() string {
return postgreSQLTypeName
}
func (p *PostgreSQL) Connection() (*sql.DB, error) {
// Grab the write lock
p.Lock()
defer p.Unlock()
// If we already have a DB, we got it!
if p.db != nil {
if err := p.db.Ping(); err == nil {
return p.db, nil
}
// If the ping was unsuccessful, close it and ignore errors as we'll be
// reestablishing anyways
p.db.Close()
}
// Otherwise, attempt to make connection
conn := p.config.ConnectionURL
// Ensure timezone is set to UTC for all the conenctions
if strings.HasPrefix(conn, "postgres://") || strings.HasPrefix(conn, "postgresql://") {
if strings.Contains(conn, "?") {
conn += "&timezone=utc"
} else {
conn += "?timezone=utc"
}
} else {
conn += " timezone=utc"
}
var err error
p.db, err = sql.Open("postgres", conn)
func (p *PostgreSQL) getConnection() (*sql.DB, error) {
db, err := p.Connection()
if err != nil {
return nil, err
}
// Set some connection pool settings. We don't need much of this,
// since the request rate shouldn't be high.
p.db.SetMaxOpenConns(p.config.MaxOpenConnections)
p.db.SetMaxIdleConns(p.config.MaxIdleConnections)
return p.db, nil
return db.(*sql.DB), nil
}
func (p *PostgreSQL) Close() {
// Grab the write lock
p.Lock()
defer p.Unlock()
if p.db != nil {
p.db.Close()
}
p.db = nil
}
func (p *PostgreSQL) Reset(config ConnectionConfig) (*sql.DB, error) {
// Grab the write lock
p.Lock()
p.config = config
p.Unlock()
p.Close()
return p.Connection()
}
func (p *PostgreSQL) CreateUser(createStmt, username, password, expiration string) error {
func (p *PostgreSQL) CreateUser(createStmt, rollbackStmt, username, password, expiration string) error {
// Get the connection
db, err := p.Connection()
db, err := p.getConnection()
if err != nil {
return err
}
@ -144,7 +89,7 @@ func (p *PostgreSQL) CreateUser(createStmt, username, password, expiration strin
}
func (p *PostgreSQL) RenewUser(username, expiration string) error {
db, err := p.Connection()
db, err := p.getConnection()
if err != nil {
return err
}
@ -170,14 +115,23 @@ func (p *PostgreSQL) RenewUser(username, expiration string) error {
return nil
}
func (p *PostgreSQL) CustomRevokeUser(username, revocationSQL string) error {
db, err := p.Connection()
func (p *PostgreSQL) RevokeUser(username, revocationStmt string) error {
// Grab the read lock
p.RLock()
defer p.RUnlock()
if revocationStmt == "" {
return p.defaultRevokeUser(username)
}
return p.customRevokeUser(username, revocationStmt)
}
func (p *PostgreSQL) customRevokeUser(username, revocationStmt string) error {
db, err := p.getConnection()
if err != nil {
return err
}
// TODO: this is Racey
p.RLock()
defer p.RUnlock()
tx, err := db.Begin()
if err != nil {
@ -187,7 +141,7 @@ func (p *PostgreSQL) CustomRevokeUser(username, revocationSQL string) error {
tx.Rollback()
}()
for _, query := range strutil.ParseArbitraryStringSlice(revocationSQL, ";") {
for _, query := range strutil.ParseArbitraryStringSlice(revocationStmt, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
@ -213,12 +167,8 @@ func (p *PostgreSQL) CustomRevokeUser(username, revocationSQL string) error {
return nil
}
func (p *PostgreSQL) DefaultRevokeUser(username string) error {
// Grab the read lock
p.RLock()
defer p.RUnlock()
db, err := p.Connection()
func (p *PostgreSQL) defaultRevokeUser(username string) error {
db, err := p.getConnection()
if err != nil {
return err
}

View File

@ -79,7 +79,7 @@ func (b *databaseBackend) pathConnectionRead(req *logical.Request, data *framewo
return nil, nil
}
var config dbs.ConnectionConfig
var config dbs.DatabaseConfig
if err := entry.DecodeJSON(&config); err != nil {
return nil, err
}
@ -89,8 +89,8 @@ func (b *databaseBackend) pathConnectionRead(req *logical.Request, data *framewo
}
func (b *databaseBackend) pathConnectionWrite(req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
connURL := data.Get("connection_url").(string)
connType := data.Get("connection_type").(string)
connDetails := data.Get("connection_details").(map[string]interface{})
maxOpenConns := data.Get("max_open_connections").(int)
if maxOpenConns == 0 {
@ -105,9 +105,9 @@ func (b *databaseBackend) pathConnectionWrite(req *logical.Request, data *framew
maxIdleConns = maxOpenConns
}
config := dbs.ConnectionConfig{
ConnectionType: connType,
ConnectionURL: connURL,
config := &dbs.DatabaseConfig{
DatabaseType: connType,
ConnectionDetails: connDetails,
MaxOpenConnections: maxOpenConns,
MaxIdleConnections: maxIdleConns,
}

View File

@ -1,103 +0,0 @@
package database
import (
"fmt"
"time"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
func pathConfigLease(b *databaseBackend) *framework.Path {
return &framework.Path{
Pattern: "config/lease",
Fields: map[string]*framework.FieldSchema{
"lease": &framework.FieldSchema{
Type: framework.TypeString,
Description: "Default lease for roles.",
},
"lease_max": &framework.FieldSchema{
Type: framework.TypeString,
Description: "Maximum time a credential is valid for.",
},
},
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.ReadOperation: b.pathLeaseRead,
logical.UpdateOperation: b.pathLeaseWrite,
},
HelpSynopsis: pathConfigLeaseHelpSyn,
HelpDescription: pathConfigLeaseHelpDesc,
}
}
func (b *databaseBackend) pathLeaseWrite(
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
leaseRaw := d.Get("lease").(string)
leaseMaxRaw := d.Get("lease_max").(string)
lease, err := time.ParseDuration(leaseRaw)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf(
"Invalid lease: %s", err)), nil
}
leaseMax, err := time.ParseDuration(leaseMaxRaw)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf(
"Invalid lease: %s", err)), nil
}
// Store it
entry, err := logical.StorageEntryJSON("config/lease", &configLease{
Lease: lease,
LeaseMax: leaseMax,
})
if err != nil {
return nil, err
}
if err := req.Storage.Put(entry); err != nil {
return nil, err
}
return nil, nil
}
func (b *databaseBackend) pathLeaseRead(
req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
lease, err := b.Lease(req.Storage)
if err != nil {
return nil, err
}
if lease == nil {
return nil, nil
}
return &logical.Response{
Data: map[string]interface{}{
"lease": lease.Lease.String(),
"lease_max": lease.LeaseMax.String(),
},
}, nil
}
type configLease struct {
Lease time.Duration
LeaseMax time.Duration
}
const pathConfigLeaseHelpSyn = `
Configure the default lease information for generated credentials.
`
const pathConfigLeaseHelpDesc = `
This configures the default lease information used for credentials
generated by this backend. The lease specifies the duration that a
credential will be valid for, as well as the maximum session for
a set of credentials.
The format for the lease is "1h" or integer and then unit. The longest
unit is hour.
`

View File

@ -2,9 +2,7 @@ package database
import (
"fmt"
"time"
"github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
_ "github.com/lib/pq"
@ -45,41 +43,7 @@ func (b *databaseBackend) pathRoleCreateRead(req *logical.Request, data *framewo
return logical.ErrorResponse(fmt.Sprintf("unknown role: %s", name)), nil
}
// Determine if we have a lease
b.logger.Trace("postgres/pathRoleCreateRead: getting lease")
lease, err := b.Lease(req.Storage)
if err != nil {
return nil, err
}
// Unlike some other backends we need a lease here (can't leave as 0 and
// let core fill it in) because Postgres also expires users as a safety
// measure, so cannot be zero
if lease == nil {
lease = &configLease{
Lease: b.System().DefaultLeaseTTL(),
}
}
// Generate the username, password and expiration. PG limits user to 63 characters
displayName := req.DisplayName
if len(displayName) > 26 {
displayName = displayName[:26]
}
userUUID, err := uuid.GenerateUUID()
if err != nil {
return nil, err
}
username := fmt.Sprintf("%s-%s", displayName, userUUID)
if len(username) > 63 {
username = username[:63]
}
password, err := uuid.GenerateUUID()
if err != nil {
return nil, err
}
expiration := time.Now().
Add(lease.Lease).
Format("2006-01-02 15:04:05-0700")
// Get our handle
b.logger.Trace("postgres/pathRoleCreateRead: getting database handle")
@ -92,7 +56,19 @@ func (b *databaseBackend) pathRoleCreateRead(req *logical.Request, data *framewo
return nil, fmt.Errorf("Cound not find DB with name: %s", role.DBName)
}
err = db.CreateUser(role.CreationStatement, username, password, expiration)
username, err := db.GenerateUsername(req.DisplayName)
if err != nil {
return nil, err
}
password, err := db.GeneratePassword()
if err != nil {
return nil, err
}
expiration := db.GenerateExpiration(role.DefaultTTL)
err = db.CreateUser(role.CreationStatement, role.RollbackStatement, username, password, expiration)
if err != nil {
return nil, err
}
@ -105,7 +81,7 @@ func (b *databaseBackend) pathRoleCreateRead(req *logical.Request, data *framewo
"username": username,
"role": name,
})
resp.Secret.TTL = lease.Lease
resp.Secret.TTL = role.DefaultTTL
return resp, nil
}

View File

@ -1,6 +1,9 @@
package database
import (
"fmt"
"time"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
@ -44,6 +47,24 @@ func pathRoles(b *databaseBackend) *framework.Path {
array, or a base64-encoded serialized JSON string array. The '{{name}}' value
will be substituted.`,
},
"rollback_statement": {
Type: framework.TypeString,
Description: `SQL statements to be executed to revoke a user. Must be a semicolon-separated
string, a base64-encoded semicolon-separated string, a serialized JSON string
array, or a base64-encoded serialized JSON string array. The '{{name}}' value
will be substituted.`,
},
"default_ttl": {
Type: framework.TypeString,
Description: "Default ttl for role.",
},
"max_ttl": {
Type: framework.TypeString,
Description: "Maximum time a credential is valid for",
},
},
Callbacks: map[logical.Operation]framework.OperationFunc{
@ -79,6 +100,9 @@ func (b *databaseBackend) pathRoleRead(req *logical.Request, data *framework.Fie
Data: map[string]interface{}{
"creation_statment": role.CreationStatement,
"revocation_statement": role.RevocationStatement,
"rollback_statement": role.RollbackStatement,
"default_ttl": role.DefaultTTL.String(),
"max_ttl": role.MaxTTL.String(),
},
}, nil
}
@ -97,6 +121,20 @@ func (b *databaseBackend) pathRoleCreate(req *logical.Request, data *framework.F
dbName := data.Get("db_name").(string)
creationStmt := data.Get("creation_statement").(string)
revocationStmt := data.Get("revocation_statement").(string)
rollbackStmt := data.Get("rollback_statement").(string)
defaultTTLRaw := data.Get("default_ttl").(string)
maxTTLRaw := data.Get("max_ttl").(string)
defaultTTL, err := time.ParseDuration(defaultTTLRaw)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf(
"Invalid default_ttl: %s", err)), nil
}
maxTTL, err := time.ParseDuration(maxTTLRaw)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf(
"Invalid max_ttl: %s", err)), nil
}
// TODO: Think about preparing the statments to test.
@ -105,6 +143,9 @@ func (b *databaseBackend) pathRoleCreate(req *logical.Request, data *framework.F
DBName: dbName,
CreationStatement: creationStmt,
RevocationStatement: revocationStmt,
RollbackStatement: rollbackStmt,
DefaultTTL: defaultTTL,
MaxTTL: maxTTL,
})
if err != nil {
return nil, err
@ -117,9 +158,12 @@ func (b *databaseBackend) pathRoleCreate(req *logical.Request, data *framework.F
}
type roleEntry struct {
DBName string `json:"db_name" mapstructure:"db_name" structs:"db_name"`
CreationStatement string `json:"creation_statment" mapstructure:"creation_statement" structs:"creation_statment"`
RevocationStatement string `json:"revocation_statement" mapstructure:"revocation_statement" structs:"revocation_statement"`
DBName string `json:"db_name" mapstructure:"db_name" structs:"db_name"`
CreationStatement string `json:"creation_statment" mapstructure:"creation_statement" structs:"creation_statment"`
RevocationStatement string `json:"revocation_statement" mapstructure:"revocation_statement" structs:"revocation_statement"`
RollbackStatement string `json:"revocation_statement" mapstructure:"revocation_statement" structs:"revocation_statement"`
DefaultTTL time.Duration `json:"default_ttl" mapstructure:"default_ttl" structs:"default_ttl"`
MaxTTL time.Duration `json:"max_ttl" mapstructure:"max_ttl" structs:"max_ttl"`
}
const pathRoleHelpSyn = `

View File

@ -1,7 +1,6 @@
package database
import (
"errors"
"fmt"
"github.com/hashicorp/vault/logical"
@ -31,8 +30,6 @@ func secretCreds(b *databaseBackend) *framework.Secret {
}
func (b *databaseBackend) secretCredsRenew(req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
dbName := d.Get("name").(string)
// Get the username from the internal data
usernameRaw, ok := req.Secret.InternalData["username"]
if !ok {
@ -40,27 +37,35 @@ func (b *databaseBackend) secretCredsRenew(req *logical.Request, d *framework.Fi
}
username, ok := usernameRaw.(string)
// Get our connection
db, ok := b.connections[dbName]
roleNameRaw, ok := req.Secret.InternalData["role"]
if !ok {
return nil, errors.New(fmt.Sprintf("Could not find connection with name %s", dbName))
return nil, fmt.Errorf("Could not find role with name: %s", req.Secret.InternalData["role"])
}
// Get the lease information
lease, err := b.Lease(req.Storage)
role, err := b.Role(req.Storage, roleNameRaw.(string))
if err != nil {
return nil, err
}
if lease == nil {
lease = &configLease{}
if role == nil {
return nil, fmt.Errorf("Could not find role with name: %s", req.Secret.InternalData["role"])
}
f := framework.LeaseExtend(lease.Lease, lease.LeaseMax, b.System())
f := framework.LeaseExtend(role.DefaultTTL, role.MaxTTL, b.System())
resp, err := f(req, d)
if err != nil {
return nil, err
}
// Grab the read lock
b.RLock()
defer b.RUnlock()
// Get our connection
db, ok := b.connections[role.DBName]
if !ok {
return nil, fmt.Errorf("Could not find connection with name %s", role.DBName)
}
// Make sure we increase the VALID UNTIL endpoint for this user.
if expireTime := resp.Secret.ExpirationTime(); !expireTime.IsZero() {
expiration := expireTime.Format("2006-01-02 15:04:05-0700")
@ -124,23 +129,9 @@ func (b *databaseBackend) secretCredsRevoke(req *logical.Request, d *framework.F
return nil, fmt.Errorf("Could not find database with name: %s", role.DBName)
}
// TODO: Maybe move this down into db package?
switch revocationSQL {
// This is the default revocation logic. If revocation SQL is provided it
// is simply executed as-is.
case "":
err := db.DefaultRevokeUser(username)
if err != nil {
return nil, err
}
// We have revocation SQL, execute directly, within a transaction
default:
err := db.CustomRevokeUser(username, revocationSQL)
if err != nil {
return nil, err
}
err = db.RevokeUser(username, revocationSQL)
if err != nil {
return nil, err
}
return resp, nil