More work on refactor and cassandra database
This commit is contained in:
parent
acdcd79af3
commit
2ec5ab5616
|
@ -22,7 +22,6 @@ func Backend(conf *logical.BackendConfig) *databaseBackend {
|
||||||
|
|
||||||
Paths: []*framework.Path{
|
Paths: []*framework.Path{
|
||||||
pathConfigConnection(&b),
|
pathConfigConnection(&b),
|
||||||
pathConfigLease(&b),
|
|
||||||
pathListRoles(&b),
|
pathListRoles(&b),
|
||||||
pathRoles(&b),
|
pathRoles(&b),
|
||||||
pathRoleCreate(&b),
|
pathRoleCreate(&b),
|
||||||
|
@ -61,24 +60,6 @@ func (b *databaseBackend) resetAllDBs() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Lease returns the lease information
|
|
||||||
func (b *databaseBackend) Lease(s logical.Storage) (*configLease, error) {
|
|
||||||
entry, err := s.Get("config/lease")
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if entry == nil {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var result configLease
|
|
||||||
if err := entry.DecodeJSON(&result); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return &result, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *databaseBackend) Role(s logical.Storage, n string) (*roleEntry, error) {
|
func (b *databaseBackend) Role(s logical.Storage, n string) (*roleEntry, error) {
|
||||||
entry, err := s.Get("role/" + n)
|
entry, err := s.Get("role/" + n)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -1,25 +1,20 @@
|
||||||
package dbs
|
package dbs
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/tls"
|
|
||||||
"database/sql"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/gocql/gocql"
|
"github.com/gocql/gocql"
|
||||||
"github.com/hashicorp/vault/helper/certutil"
|
"github.com/hashicorp/vault/helper/strutil"
|
||||||
"github.com/hashicorp/vault/helper/tlsutil"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type Cassandra struct {
|
type Cassandra struct {
|
||||||
// Session is goroutine safe, however, since we reinitialize
|
// Session is goroutine safe, however, since we reinitialize
|
||||||
// it when connection info changes, we want to make sure we
|
// it when connection info changes, we want to make sure we
|
||||||
// can close it and use a new connection; hence the lock
|
// can close it and use a new connection; hence the lock
|
||||||
session *gocql.Session
|
ConnectionProducer
|
||||||
config ConnectionConfig
|
CredentialsProducer
|
||||||
|
|
||||||
sync.RWMutex
|
sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -27,168 +22,85 @@ func (c *Cassandra) Type() string {
|
||||||
return cassandraTypeName
|
return cassandraTypeName
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Cassandra) Connection() (*gocql.Session, error) {
|
func (c *Cassandra) getConnection() (*gocql.Session, error) {
|
||||||
// Grab the write lock
|
session, err := c.Connection()
|
||||||
c.Lock()
|
|
||||||
defer c.Unlock()
|
|
||||||
|
|
||||||
// If we already have a DB, we got it!
|
|
||||||
if c.session != nil {
|
|
||||||
return c.session, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
session, err := createSession(c.config)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Store the session in backend for reuse
|
return session.(*gocql.Session), nil
|
||||||
c.session = session
|
|
||||||
|
|
||||||
return session, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Cassandra) Close() {
|
func (c *Cassandra) CreateUser(createStmt, rollbackStmt, username, password, expiration string) error {
|
||||||
// Grab the write lock
|
|
||||||
p.Lock()
|
|
||||||
defer p.Unlock()
|
|
||||||
|
|
||||||
if p.session != nil {
|
|
||||||
p.session.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
p.session = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Cassandra) Reset(config ConnectionConfig) (*sql.DB, error) {
|
|
||||||
// Grab the write lock
|
|
||||||
p.Lock()
|
|
||||||
p.config = config
|
|
||||||
p.Unlock()
|
|
||||||
|
|
||||||
p.Close()
|
|
||||||
return p.Connection()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Cassandra) CreateUser(createStmt, username, password, expiration string) error {
|
|
||||||
// Get the connection
|
// Get the connection
|
||||||
db, err := p.Connection()
|
session, err := c.getConnection()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: This is racey
|
// TODO: This is racey
|
||||||
// Grab a read lock
|
// Grab a read lock
|
||||||
p.RLock()
|
c.RLock()
|
||||||
defer p.RUnlock()
|
defer c.RUnlock()
|
||||||
|
|
||||||
return nil
|
// Set consistency
|
||||||
}
|
/* if .Consistency != "" {
|
||||||
|
consistencyValue, err := gocql.ParseConsistencyWrapper(role.Consistency)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
func (p *Cassandra) RenewUser(username, expiration string) error {
|
session.SetConsistency(consistencyValue)
|
||||||
db, err := p.Connection()
|
}*/
|
||||||
if err != nil {
|
|
||||||
return err
|
// Execute each query
|
||||||
|
for _, query := range strutil.ParseArbitraryStringSlice(createStmt, ";") {
|
||||||
|
query = strings.TrimSpace(query)
|
||||||
|
if len(query) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
err = session.Query(queryHelper(query, map[string]string{
|
||||||
|
"username": username,
|
||||||
|
"password": password,
|
||||||
|
})).Exec()
|
||||||
|
if err != nil {
|
||||||
|
for _, query := range strutil.ParseArbitraryStringSlice(rollbackStmt, ";") {
|
||||||
|
query = strings.TrimSpace(query)
|
||||||
|
if len(query) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
session.Query(queryHelper(query, map[string]string{
|
||||||
|
"username": username,
|
||||||
|
"password": password,
|
||||||
|
})).Exec()
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
// TODO: This is Racey
|
|
||||||
// Grab the read lock
|
|
||||||
p.RLock()
|
|
||||||
defer p.RUnlock()
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Cassandra) CustomRevokeUser(username, revocationSQL string) error {
|
func (c *Cassandra) RenewUser(username, expiration string) error {
|
||||||
db, err := p.Connection()
|
// NOOP
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Cassandra) RevokeUser(username, revocationSQL string) error {
|
||||||
|
session, err := c.getConnection()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
// TODO: this is Racey
|
// TODO: this is Racey
|
||||||
p.RLock()
|
c.RLock()
|
||||||
defer p.RUnlock()
|
defer c.RUnlock()
|
||||||
|
|
||||||
|
err = session.Query(fmt.Sprintf("DROP USER '%s'", username)).Exec()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error removing user %s", username)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Cassandra) DefaultRevokeUser(username string) error {
|
|
||||||
// Grab the read lock
|
|
||||||
p.RLock()
|
|
||||||
defer p.RUnlock()
|
|
||||||
|
|
||||||
db, err := p.Connection()
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func createSession(cfg *ConnectionConfig) (*gocql.Session, error) {
|
|
||||||
clusterConfig := gocql.NewCluster(strings.Split(cfg.Hosts, ",")...)
|
|
||||||
clusterConfig.Authenticator = gocql.PasswordAuthenticator{
|
|
||||||
Username: cfg.Username,
|
|
||||||
Password: cfg.Password,
|
|
||||||
}
|
|
||||||
|
|
||||||
clusterConfig.ProtoVersion = cfg.ProtocolVersion
|
|
||||||
if clusterConfig.ProtoVersion == 0 {
|
|
||||||
clusterConfig.ProtoVersion = 2
|
|
||||||
}
|
|
||||||
|
|
||||||
clusterConfig.Timeout = time.Duration(cfg.ConnectTimeout) * time.Second
|
|
||||||
|
|
||||||
if cfg.TLS {
|
|
||||||
var tlsConfig *tls.Config
|
|
||||||
if len(cfg.Certificate) > 0 || len(cfg.IssuingCA) > 0 {
|
|
||||||
if len(cfg.Certificate) > 0 && len(cfg.PrivateKey) == 0 {
|
|
||||||
return nil, fmt.Errorf("Found certificate for TLS authentication but no private key")
|
|
||||||
}
|
|
||||||
|
|
||||||
certBundle := &certutil.CertBundle{}
|
|
||||||
if len(cfg.Certificate) > 0 {
|
|
||||||
certBundle.Certificate = cfg.Certificate
|
|
||||||
certBundle.PrivateKey = cfg.PrivateKey
|
|
||||||
}
|
|
||||||
if len(cfg.IssuingCA) > 0 {
|
|
||||||
certBundle.IssuingCA = cfg.IssuingCA
|
|
||||||
}
|
|
||||||
|
|
||||||
parsedCertBundle, err := certBundle.ToParsedCertBundle()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to parse certificate bundle: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
tlsConfig, err = parsedCertBundle.GetTLSConfig(certutil.TLSClient)
|
|
||||||
if err != nil || tlsConfig == nil {
|
|
||||||
return nil, fmt.Errorf("failed to get TLS configuration: tlsConfig:%#v err:%v", tlsConfig, err)
|
|
||||||
}
|
|
||||||
tlsConfig.InsecureSkipVerify = cfg.InsecureTLS
|
|
||||||
|
|
||||||
if cfg.TLSMinVersion != "" {
|
|
||||||
var ok bool
|
|
||||||
tlsConfig.MinVersion, ok = tlsutil.TLSLookup[cfg.TLSMinVersion]
|
|
||||||
if !ok {
|
|
||||||
return nil, fmt.Errorf("invalid 'tls_min_version' in config")
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// MinVersion was not being set earlier. Reset it to
|
|
||||||
// zero to gracefully handle upgrades.
|
|
||||||
tlsConfig.MinVersion = 0
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
clusterConfig.SslOpts = &gocql.SslOptions{
|
|
||||||
Config: *tlsConfig,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
session, err := clusterConfig.CreateSession()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("Error creating session: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify the info
|
|
||||||
err = session.Query(`LIST USERS`).Exec()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("Error validating connection info: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return session, nil
|
|
||||||
}
|
|
||||||
|
|
|
@ -0,0 +1,254 @@
|
||||||
|
package dbs
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gocql/gocql"
|
||||||
|
"github.com/hashicorp/vault/helper/certutil"
|
||||||
|
"github.com/hashicorp/vault/helper/tlsutil"
|
||||||
|
"github.com/mitchellh/mapstructure"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ConnectionProducer interface {
|
||||||
|
Connection() (interface{}, error)
|
||||||
|
Close()
|
||||||
|
// TODO: Should we make this immutable instead?
|
||||||
|
Reset(*DatabaseConfig) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// sqlConnectionProducer impliments ConnectionProducer and provides a generic producer for most sql databases
|
||||||
|
type sqlConnectionDetails struct {
|
||||||
|
ConnectionURL string `json:"connection_url" structs:"connection_url" mapstructure:"connection_url"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type sqlConnectionProducer struct {
|
||||||
|
config *DatabaseConfig
|
||||||
|
// TODO: Should we merge these two structures make it immutable?
|
||||||
|
connDetails *sqlConnectionDetails
|
||||||
|
|
||||||
|
db *sql.DB
|
||||||
|
sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cp *sqlConnectionProducer) Connection() (interface{}, error) {
|
||||||
|
// Grab the write lock
|
||||||
|
cp.Lock()
|
||||||
|
defer cp.Unlock()
|
||||||
|
|
||||||
|
// If we already have a DB, we got it!
|
||||||
|
if cp.db != nil {
|
||||||
|
if err := cp.db.Ping(); err == nil {
|
||||||
|
return cp.db, nil
|
||||||
|
}
|
||||||
|
// If the ping was unsuccessful, close it and ignore errors as we'll be
|
||||||
|
// reestablishing anyways
|
||||||
|
cp.db.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Otherwise, attempt to make connection
|
||||||
|
conn := cp.connDetails.ConnectionURL
|
||||||
|
|
||||||
|
// Ensure timezone is set to UTC for all the conenctions
|
||||||
|
if strings.HasPrefix(conn, "postgres://") || strings.HasPrefix(conn, "postgresql://") {
|
||||||
|
if strings.Contains(conn, "?") {
|
||||||
|
conn += "&timezone=utc"
|
||||||
|
} else {
|
||||||
|
conn += "?timezone=utc"
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
conn += " timezone=utc"
|
||||||
|
}
|
||||||
|
|
||||||
|
var err error
|
||||||
|
cp.db, err = sql.Open(cp.config.DatabaseType, conn)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set some connection pool settings. We don't need much of this,
|
||||||
|
// since the request rate shouldn't be high.
|
||||||
|
cp.db.SetMaxOpenConns(cp.config.MaxOpenConnections)
|
||||||
|
cp.db.SetMaxIdleConns(cp.config.MaxIdleConnections)
|
||||||
|
|
||||||
|
return cp.db, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cp *sqlConnectionProducer) Close() {
|
||||||
|
// Grab the write lock
|
||||||
|
cp.Lock()
|
||||||
|
defer cp.Unlock()
|
||||||
|
|
||||||
|
if cp.db != nil {
|
||||||
|
cp.db.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
cp.db = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cp *sqlConnectionProducer) Reset(config *DatabaseConfig) error {
|
||||||
|
// Grab the write lock
|
||||||
|
cp.Lock()
|
||||||
|
|
||||||
|
var details *sqlConnectionDetails
|
||||||
|
err := mapstructure.Decode(config.ConnectionDetails, &details)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
cp.connDetails = details
|
||||||
|
cp.config = config
|
||||||
|
|
||||||
|
cp.Unlock()
|
||||||
|
|
||||||
|
cp.Close()
|
||||||
|
_, err = cp.Connection()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// cassandraConnectionProducer impliments ConnectionProducer and provides connections for cassandra
|
||||||
|
type cassandraConnectionDetails struct {
|
||||||
|
Hosts string `json:"hosts" structs:"hosts" mapstructure:"hosts"`
|
||||||
|
Username string `json:"username" structs:"username" mapstructure:"username"`
|
||||||
|
Password string `json:"password" structs:"password" mapstructure:"password"`
|
||||||
|
TLS bool `json:"tls" structs:"tls" mapstructure:"tls"`
|
||||||
|
InsecureTLS bool `json:"insecure_tls" structs:"insecure_tls" mapstructure:"insecure_tls"`
|
||||||
|
Certificate string `json:"certificate" structs:"certificate" mapstructure:"certificate"`
|
||||||
|
PrivateKey string `json:"private_key" structs:"private_key" mapstructure:"private_key"`
|
||||||
|
IssuingCA string `json:"issuing_ca" structs:"issuing_ca" mapstructure:"issuing_ca"`
|
||||||
|
ProtocolVersion int `json:"protocol_version" structs:"protocol_version" mapstructure:"protocol_version"`
|
||||||
|
ConnectTimeout int `json:"connect_timeout" structs:"connect_timeout" mapstructure:"connect_timeout"`
|
||||||
|
TLSMinVersion string `json:"tls_min_version" structs:"tls_min_version" mapstructure:"tls_min_version"`
|
||||||
|
Consistancy string `json:"tls_min_version" structs:"tls_min_version" mapstructure:"tls_min_version"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type cassandraConnectionProducer struct {
|
||||||
|
config *DatabaseConfig
|
||||||
|
// TODO: Should we merge these two structures make it immutable?
|
||||||
|
connDetails *cassandraConnectionDetails
|
||||||
|
|
||||||
|
session *gocql.Session
|
||||||
|
sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cp *cassandraConnectionProducer) Connection() (interface{}, error) {
|
||||||
|
// Grab the write lock
|
||||||
|
cp.Lock()
|
||||||
|
defer cp.Unlock()
|
||||||
|
|
||||||
|
// If we already have a DB, we got it!
|
||||||
|
if cp.session != nil {
|
||||||
|
return cp.session, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
session, err := cp.createSession(cp.connDetails)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store the session in backend for reuse
|
||||||
|
cp.session = session
|
||||||
|
|
||||||
|
return session, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cp *cassandraConnectionProducer) Close() {
|
||||||
|
// Grab the write lock
|
||||||
|
cp.Lock()
|
||||||
|
defer cp.Unlock()
|
||||||
|
|
||||||
|
if cp.session != nil {
|
||||||
|
cp.session.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
cp.session = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cp *cassandraConnectionProducer) Reset(config *DatabaseConfig) error {
|
||||||
|
// Grab the write lock
|
||||||
|
cp.Lock()
|
||||||
|
cp.config = config
|
||||||
|
cp.Unlock()
|
||||||
|
|
||||||
|
cp.Close()
|
||||||
|
_, err := cp.Connection()
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cp *cassandraConnectionProducer) createSession(cfg *cassandraConnectionDetails) (*gocql.Session, error) {
|
||||||
|
clusterConfig := gocql.NewCluster(strings.Split(cfg.Hosts, ",")...)
|
||||||
|
clusterConfig.Authenticator = gocql.PasswordAuthenticator{
|
||||||
|
Username: cfg.Username,
|
||||||
|
Password: cfg.Password,
|
||||||
|
}
|
||||||
|
|
||||||
|
clusterConfig.ProtoVersion = cfg.ProtocolVersion
|
||||||
|
if clusterConfig.ProtoVersion == 0 {
|
||||||
|
clusterConfig.ProtoVersion = 2
|
||||||
|
}
|
||||||
|
|
||||||
|
clusterConfig.Timeout = time.Duration(cfg.ConnectTimeout) * time.Second
|
||||||
|
|
||||||
|
if cfg.TLS {
|
||||||
|
var tlsConfig *tls.Config
|
||||||
|
if len(cfg.Certificate) > 0 || len(cfg.IssuingCA) > 0 {
|
||||||
|
if len(cfg.Certificate) > 0 && len(cfg.PrivateKey) == 0 {
|
||||||
|
return nil, fmt.Errorf("Found certificate for TLS authentication but no private key")
|
||||||
|
}
|
||||||
|
|
||||||
|
certBundle := &certutil.CertBundle{}
|
||||||
|
if len(cfg.Certificate) > 0 {
|
||||||
|
certBundle.Certificate = cfg.Certificate
|
||||||
|
certBundle.PrivateKey = cfg.PrivateKey
|
||||||
|
}
|
||||||
|
if len(cfg.IssuingCA) > 0 {
|
||||||
|
certBundle.IssuingCA = cfg.IssuingCA
|
||||||
|
}
|
||||||
|
|
||||||
|
parsedCertBundle, err := certBundle.ToParsedCertBundle()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse certificate bundle: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tlsConfig, err = parsedCertBundle.GetTLSConfig(certutil.TLSClient)
|
||||||
|
if err != nil || tlsConfig == nil {
|
||||||
|
return nil, fmt.Errorf("failed to get TLS configuration: tlsConfig:%#v err:%v", tlsConfig, err)
|
||||||
|
}
|
||||||
|
tlsConfig.InsecureSkipVerify = cfg.InsecureTLS
|
||||||
|
|
||||||
|
if cfg.TLSMinVersion != "" {
|
||||||
|
var ok bool
|
||||||
|
tlsConfig.MinVersion, ok = tlsutil.TLSLookup[cfg.TLSMinVersion]
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("invalid 'tls_min_version' in config")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// MinVersion was not being set earlier. Reset it to
|
||||||
|
// zero to gracefully handle upgrades.
|
||||||
|
tlsConfig.MinVersion = 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
clusterConfig.SslOpts = &gocql.SslOptions{
|
||||||
|
Config: *tlsConfig,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
session, err := clusterConfig.CreateSession()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("Error creating session: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the info
|
||||||
|
err = session.Query(`LIST USERS`).Exec()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("Error validating connection info: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return session, nil
|
||||||
|
}
|
|
@ -0,0 +1,79 @@
|
||||||
|
package dbs
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
uuid "github.com/hashicorp/go-uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
type CredentialsProducer interface {
|
||||||
|
GenerateUsername(displayName string) (string, error)
|
||||||
|
GeneratePassword() (string, error)
|
||||||
|
GenerateExpiration(ttl time.Duration) string
|
||||||
|
}
|
||||||
|
|
||||||
|
// sqlCredentialsProducer impliments CredentialsProducer and provides a generic credentials producer for most sql database types.
|
||||||
|
type sqlCredentialsProducer struct {
|
||||||
|
displayNameLen int
|
||||||
|
usernameLen int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (scg *sqlCredentialsProducer) GenerateUsername(displayName string) (string, error) {
|
||||||
|
// Generate the username, password and expiration. PG limits user to 63 characters
|
||||||
|
if scg.displayNameLen > 0 && len(displayName) > scg.displayNameLen {
|
||||||
|
displayName = displayName[:scg.displayNameLen]
|
||||||
|
}
|
||||||
|
userUUID, err := uuid.GenerateUUID()
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
username := fmt.Sprintf("%s-%s", displayName, userUUID)
|
||||||
|
if scg.usernameLen > 0 && len(username) > scg.usernameLen {
|
||||||
|
username = username[:scg.usernameLen]
|
||||||
|
}
|
||||||
|
|
||||||
|
return username, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (scg *sqlCredentialsProducer) GeneratePassword() (string, error) {
|
||||||
|
password, err := uuid.GenerateUUID()
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return password, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (scg *sqlCredentialsProducer) GenerateExpiration(ttl time.Duration) string {
|
||||||
|
return time.Now().
|
||||||
|
Add(ttl).
|
||||||
|
Format("2006-01-02 15:04:05-0700")
|
||||||
|
}
|
||||||
|
|
||||||
|
type cassandraCredentialsProducer struct{}
|
||||||
|
|
||||||
|
func (ccp *cassandraCredentialsProducer) GenerateUsername(displayName string) (string, error) {
|
||||||
|
userUUID, err := uuid.GenerateUUID()
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
username := fmt.Sprintf("vault_%s_%s_%d", displayName, userUUID, time.Now().Unix())
|
||||||
|
username = strings.Replace(username, "-", "_", -1)
|
||||||
|
|
||||||
|
return username, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ccp *cassandraCredentialsProducer) GeneratePassword() (string, error) {
|
||||||
|
password, err := uuid.GenerateUUID()
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return password, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ccp *cassandraCredentialsProducer) GenerateExpiration(ttl time.Duration) string {
|
||||||
|
return ""
|
||||||
|
}
|
|
@ -1,10 +1,11 @@
|
||||||
package dbs
|
package dbs
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/mitchellh/mapstructure"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -16,11 +17,47 @@ var (
|
||||||
ErrUnsupportedDatabaseType = errors.New("Unsupported database type")
|
ErrUnsupportedDatabaseType = errors.New("Unsupported database type")
|
||||||
)
|
)
|
||||||
|
|
||||||
func Factory(conf ConnectionConfig) (DatabaseType, error) {
|
func Factory(conf *DatabaseConfig) (DatabaseType, error) {
|
||||||
switch conf.ConnectionType {
|
switch conf.DatabaseType {
|
||||||
case postgreSQLTypeName:
|
case postgreSQLTypeName:
|
||||||
|
var details *sqlConnectionDetails
|
||||||
|
err := mapstructure.Decode(conf.ConnectionDetails, &details)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
connProducer := &sqlConnectionProducer{
|
||||||
|
config: conf,
|
||||||
|
connDetails: details,
|
||||||
|
}
|
||||||
|
|
||||||
|
credsProducer := &sqlCredentialsProducer{
|
||||||
|
displayNameLen: 23,
|
||||||
|
usernameLen: 63,
|
||||||
|
}
|
||||||
|
|
||||||
return &PostgreSQL{
|
return &PostgreSQL{
|
||||||
config: conf,
|
ConnectionProducer: connProducer,
|
||||||
|
CredentialsProducer: credsProducer,
|
||||||
|
}, nil
|
||||||
|
|
||||||
|
case cassandraTypeName:
|
||||||
|
var details *cassandraConnectionDetails
|
||||||
|
err := mapstructure.Decode(conf.ConnectionDetails, &details)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
connProducer := &cassandraConnectionProducer{
|
||||||
|
config: conf,
|
||||||
|
connDetails: details,
|
||||||
|
}
|
||||||
|
|
||||||
|
credsProducer := &cassandraCredentialsProducer{}
|
||||||
|
|
||||||
|
return &Cassandra{
|
||||||
|
ConnectionProducer: connProducer,
|
||||||
|
CredentialsProducer: credsProducer,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -29,21 +66,19 @@ func Factory(conf ConnectionConfig) (DatabaseType, error) {
|
||||||
|
|
||||||
type DatabaseType interface {
|
type DatabaseType interface {
|
||||||
Type() string
|
Type() string
|
||||||
Connection() (*sql.DB, error)
|
CreateUser(createStmt, rollbackStmt, username, password, expiration string) error
|
||||||
Close()
|
|
||||||
Reset(ConnectionConfig) (*sql.DB, error)
|
|
||||||
CreateUser(createStmt, username, password, expiration string) error
|
|
||||||
RenewUser(username, expiration string) error
|
RenewUser(username, expiration string) error
|
||||||
CustomRevokeUser(username, revocationSQL string) error
|
RevokeUser(username, revocationStmt string) error
|
||||||
DefaultRevokeUser(username string) error
|
|
||||||
|
ConnectionProducer
|
||||||
|
CredentialsProducer
|
||||||
}
|
}
|
||||||
|
|
||||||
type ConnectionConfig struct {
|
type DatabaseConfig struct {
|
||||||
ConnectionType string `json:"type" structs:"type" mapstructure:"type"`
|
DatabaseType string `json:"type" structs:"type" mapstructure:"type"`
|
||||||
ConnectionURL string `json:"connection_url" structs:"connection_url" mapstructure:"connection_url"`
|
ConnectionDetails map[string]interface{} `json:"connection_details" structs:"connection_details" mapstructure:"connection_details"`
|
||||||
ConnectionDetails map[string]string `json:"connection_details" structs:"connection_details" mapstructure:"connection_details"`
|
MaxOpenConnections int `json:"max_open_connections" structs:"max_open_connections" mapstructure:"max_open_connections"`
|
||||||
MaxOpenConnections int `json:"max_open_connections" structs:"max_open_connections" mapstructure:"max_open_connections"`
|
MaxIdleConnections int `json:"max_idle_connections" structs:"max_idle_connections" mapstructure:"max_idle_connections"`
|
||||||
MaxIdleConnections int `json:"max_idle_connections" structs:"max_idle_connections" mapstructure:"max_idle_connections"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Query templates a query for us.
|
// Query templates a query for us.
|
||||||
|
|
|
@ -11,9 +11,10 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type PostgreSQL struct {
|
type PostgreSQL struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
config ConnectionConfig
|
|
||||||
|
|
||||||
|
ConnectionProducer
|
||||||
|
CredentialsProducer
|
||||||
sync.RWMutex
|
sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -21,74 +22,18 @@ func (p *PostgreSQL) Type() string {
|
||||||
return postgreSQLTypeName
|
return postgreSQLTypeName
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *PostgreSQL) Connection() (*sql.DB, error) {
|
func (p *PostgreSQL) getConnection() (*sql.DB, error) {
|
||||||
// Grab the write lock
|
db, err := p.Connection()
|
||||||
p.Lock()
|
|
||||||
defer p.Unlock()
|
|
||||||
|
|
||||||
// If we already have a DB, we got it!
|
|
||||||
if p.db != nil {
|
|
||||||
if err := p.db.Ping(); err == nil {
|
|
||||||
return p.db, nil
|
|
||||||
}
|
|
||||||
// If the ping was unsuccessful, close it and ignore errors as we'll be
|
|
||||||
// reestablishing anyways
|
|
||||||
p.db.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Otherwise, attempt to make connection
|
|
||||||
conn := p.config.ConnectionURL
|
|
||||||
|
|
||||||
// Ensure timezone is set to UTC for all the conenctions
|
|
||||||
if strings.HasPrefix(conn, "postgres://") || strings.HasPrefix(conn, "postgresql://") {
|
|
||||||
if strings.Contains(conn, "?") {
|
|
||||||
conn += "&timezone=utc"
|
|
||||||
} else {
|
|
||||||
conn += "?timezone=utc"
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
conn += " timezone=utc"
|
|
||||||
}
|
|
||||||
|
|
||||||
var err error
|
|
||||||
p.db, err = sql.Open("postgres", conn)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set some connection pool settings. We don't need much of this,
|
return db.(*sql.DB), nil
|
||||||
// since the request rate shouldn't be high.
|
|
||||||
p.db.SetMaxOpenConns(p.config.MaxOpenConnections)
|
|
||||||
p.db.SetMaxIdleConns(p.config.MaxIdleConnections)
|
|
||||||
|
|
||||||
return p.db, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *PostgreSQL) Close() {
|
func (p *PostgreSQL) CreateUser(createStmt, rollbackStmt, username, password, expiration string) error {
|
||||||
// Grab the write lock
|
|
||||||
p.Lock()
|
|
||||||
defer p.Unlock()
|
|
||||||
|
|
||||||
if p.db != nil {
|
|
||||||
p.db.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
p.db = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *PostgreSQL) Reset(config ConnectionConfig) (*sql.DB, error) {
|
|
||||||
// Grab the write lock
|
|
||||||
p.Lock()
|
|
||||||
p.config = config
|
|
||||||
p.Unlock()
|
|
||||||
|
|
||||||
p.Close()
|
|
||||||
return p.Connection()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *PostgreSQL) CreateUser(createStmt, username, password, expiration string) error {
|
|
||||||
// Get the connection
|
// Get the connection
|
||||||
db, err := p.Connection()
|
db, err := p.getConnection()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -144,7 +89,7 @@ func (p *PostgreSQL) CreateUser(createStmt, username, password, expiration strin
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *PostgreSQL) RenewUser(username, expiration string) error {
|
func (p *PostgreSQL) RenewUser(username, expiration string) error {
|
||||||
db, err := p.Connection()
|
db, err := p.getConnection()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -170,14 +115,23 @@ func (p *PostgreSQL) RenewUser(username, expiration string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *PostgreSQL) CustomRevokeUser(username, revocationSQL string) error {
|
func (p *PostgreSQL) RevokeUser(username, revocationStmt string) error {
|
||||||
db, err := p.Connection()
|
// Grab the read lock
|
||||||
|
p.RLock()
|
||||||
|
defer p.RUnlock()
|
||||||
|
|
||||||
|
if revocationStmt == "" {
|
||||||
|
return p.defaultRevokeUser(username)
|
||||||
|
}
|
||||||
|
|
||||||
|
return p.customRevokeUser(username, revocationStmt)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PostgreSQL) customRevokeUser(username, revocationStmt string) error {
|
||||||
|
db, err := p.getConnection()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
// TODO: this is Racey
|
|
||||||
p.RLock()
|
|
||||||
defer p.RUnlock()
|
|
||||||
|
|
||||||
tx, err := db.Begin()
|
tx, err := db.Begin()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -187,7 +141,7 @@ func (p *PostgreSQL) CustomRevokeUser(username, revocationSQL string) error {
|
||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
for _, query := range strutil.ParseArbitraryStringSlice(revocationSQL, ";") {
|
for _, query := range strutil.ParseArbitraryStringSlice(revocationStmt, ";") {
|
||||||
query = strings.TrimSpace(query)
|
query = strings.TrimSpace(query)
|
||||||
if len(query) == 0 {
|
if len(query) == 0 {
|
||||||
continue
|
continue
|
||||||
|
@ -213,12 +167,8 @@ func (p *PostgreSQL) CustomRevokeUser(username, revocationSQL string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *PostgreSQL) DefaultRevokeUser(username string) error {
|
func (p *PostgreSQL) defaultRevokeUser(username string) error {
|
||||||
// Grab the read lock
|
db, err := p.getConnection()
|
||||||
p.RLock()
|
|
||||||
defer p.RUnlock()
|
|
||||||
|
|
||||||
db, err := p.Connection()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -79,7 +79,7 @@ func (b *databaseBackend) pathConnectionRead(req *logical.Request, data *framewo
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var config dbs.ConnectionConfig
|
var config dbs.DatabaseConfig
|
||||||
if err := entry.DecodeJSON(&config); err != nil {
|
if err := entry.DecodeJSON(&config); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -89,8 +89,8 @@ func (b *databaseBackend) pathConnectionRead(req *logical.Request, data *framewo
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *databaseBackend) pathConnectionWrite(req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
func (b *databaseBackend) pathConnectionWrite(req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||||
connURL := data.Get("connection_url").(string)
|
|
||||||
connType := data.Get("connection_type").(string)
|
connType := data.Get("connection_type").(string)
|
||||||
|
connDetails := data.Get("connection_details").(map[string]interface{})
|
||||||
|
|
||||||
maxOpenConns := data.Get("max_open_connections").(int)
|
maxOpenConns := data.Get("max_open_connections").(int)
|
||||||
if maxOpenConns == 0 {
|
if maxOpenConns == 0 {
|
||||||
|
@ -105,9 +105,9 @@ func (b *databaseBackend) pathConnectionWrite(req *logical.Request, data *framew
|
||||||
maxIdleConns = maxOpenConns
|
maxIdleConns = maxOpenConns
|
||||||
}
|
}
|
||||||
|
|
||||||
config := dbs.ConnectionConfig{
|
config := &dbs.DatabaseConfig{
|
||||||
ConnectionType: connType,
|
DatabaseType: connType,
|
||||||
ConnectionURL: connURL,
|
ConnectionDetails: connDetails,
|
||||||
MaxOpenConnections: maxOpenConns,
|
MaxOpenConnections: maxOpenConns,
|
||||||
MaxIdleConnections: maxIdleConns,
|
MaxIdleConnections: maxIdleConns,
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,103 +0,0 @@
|
||||||
package database
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/hashicorp/vault/logical"
|
|
||||||
"github.com/hashicorp/vault/logical/framework"
|
|
||||||
)
|
|
||||||
|
|
||||||
func pathConfigLease(b *databaseBackend) *framework.Path {
|
|
||||||
return &framework.Path{
|
|
||||||
Pattern: "config/lease",
|
|
||||||
Fields: map[string]*framework.FieldSchema{
|
|
||||||
"lease": &framework.FieldSchema{
|
|
||||||
Type: framework.TypeString,
|
|
||||||
Description: "Default lease for roles.",
|
|
||||||
},
|
|
||||||
|
|
||||||
"lease_max": &framework.FieldSchema{
|
|
||||||
Type: framework.TypeString,
|
|
||||||
Description: "Maximum time a credential is valid for.",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
|
|
||||||
Callbacks: map[logical.Operation]framework.OperationFunc{
|
|
||||||
logical.ReadOperation: b.pathLeaseRead,
|
|
||||||
logical.UpdateOperation: b.pathLeaseWrite,
|
|
||||||
},
|
|
||||||
|
|
||||||
HelpSynopsis: pathConfigLeaseHelpSyn,
|
|
||||||
HelpDescription: pathConfigLeaseHelpDesc,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *databaseBackend) pathLeaseWrite(
|
|
||||||
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
|
||||||
leaseRaw := d.Get("lease").(string)
|
|
||||||
leaseMaxRaw := d.Get("lease_max").(string)
|
|
||||||
|
|
||||||
lease, err := time.ParseDuration(leaseRaw)
|
|
||||||
if err != nil {
|
|
||||||
return logical.ErrorResponse(fmt.Sprintf(
|
|
||||||
"Invalid lease: %s", err)), nil
|
|
||||||
}
|
|
||||||
leaseMax, err := time.ParseDuration(leaseMaxRaw)
|
|
||||||
if err != nil {
|
|
||||||
return logical.ErrorResponse(fmt.Sprintf(
|
|
||||||
"Invalid lease: %s", err)), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Store it
|
|
||||||
entry, err := logical.StorageEntryJSON("config/lease", &configLease{
|
|
||||||
Lease: lease,
|
|
||||||
LeaseMax: leaseMax,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if err := req.Storage.Put(entry); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *databaseBackend) pathLeaseRead(
|
|
||||||
req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
|
||||||
lease, err := b.Lease(req.Storage)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if lease == nil {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return &logical.Response{
|
|
||||||
Data: map[string]interface{}{
|
|
||||||
"lease": lease.Lease.String(),
|
|
||||||
"lease_max": lease.LeaseMax.String(),
|
|
||||||
},
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type configLease struct {
|
|
||||||
Lease time.Duration
|
|
||||||
LeaseMax time.Duration
|
|
||||||
}
|
|
||||||
|
|
||||||
const pathConfigLeaseHelpSyn = `
|
|
||||||
Configure the default lease information for generated credentials.
|
|
||||||
`
|
|
||||||
|
|
||||||
const pathConfigLeaseHelpDesc = `
|
|
||||||
This configures the default lease information used for credentials
|
|
||||||
generated by this backend. The lease specifies the duration that a
|
|
||||||
credential will be valid for, as well as the maximum session for
|
|
||||||
a set of credentials.
|
|
||||||
|
|
||||||
The format for the lease is "1h" or integer and then unit. The longest
|
|
||||||
unit is hour.
|
|
||||||
`
|
|
|
@ -2,9 +2,7 @@ package database
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/hashicorp/go-uuid"
|
|
||||||
"github.com/hashicorp/vault/logical"
|
"github.com/hashicorp/vault/logical"
|
||||||
"github.com/hashicorp/vault/logical/framework"
|
"github.com/hashicorp/vault/logical/framework"
|
||||||
_ "github.com/lib/pq"
|
_ "github.com/lib/pq"
|
||||||
|
@ -45,41 +43,7 @@ func (b *databaseBackend) pathRoleCreateRead(req *logical.Request, data *framewo
|
||||||
return logical.ErrorResponse(fmt.Sprintf("unknown role: %s", name)), nil
|
return logical.ErrorResponse(fmt.Sprintf("unknown role: %s", name)), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Determine if we have a lease
|
|
||||||
b.logger.Trace("postgres/pathRoleCreateRead: getting lease")
|
|
||||||
lease, err := b.Lease(req.Storage)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
// Unlike some other backends we need a lease here (can't leave as 0 and
|
|
||||||
// let core fill it in) because Postgres also expires users as a safety
|
|
||||||
// measure, so cannot be zero
|
|
||||||
if lease == nil {
|
|
||||||
lease = &configLease{
|
|
||||||
Lease: b.System().DefaultLeaseTTL(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Generate the username, password and expiration. PG limits user to 63 characters
|
// Generate the username, password and expiration. PG limits user to 63 characters
|
||||||
displayName := req.DisplayName
|
|
||||||
if len(displayName) > 26 {
|
|
||||||
displayName = displayName[:26]
|
|
||||||
}
|
|
||||||
userUUID, err := uuid.GenerateUUID()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
username := fmt.Sprintf("%s-%s", displayName, userUUID)
|
|
||||||
if len(username) > 63 {
|
|
||||||
username = username[:63]
|
|
||||||
}
|
|
||||||
password, err := uuid.GenerateUUID()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
expiration := time.Now().
|
|
||||||
Add(lease.Lease).
|
|
||||||
Format("2006-01-02 15:04:05-0700")
|
|
||||||
|
|
||||||
// Get our handle
|
// Get our handle
|
||||||
b.logger.Trace("postgres/pathRoleCreateRead: getting database handle")
|
b.logger.Trace("postgres/pathRoleCreateRead: getting database handle")
|
||||||
|
@ -92,7 +56,19 @@ func (b *databaseBackend) pathRoleCreateRead(req *logical.Request, data *framewo
|
||||||
return nil, fmt.Errorf("Cound not find DB with name: %s", role.DBName)
|
return nil, fmt.Errorf("Cound not find DB with name: %s", role.DBName)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = db.CreateUser(role.CreationStatement, username, password, expiration)
|
username, err := db.GenerateUsername(req.DisplayName)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
password, err := db.GeneratePassword()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
expiration := db.GenerateExpiration(role.DefaultTTL)
|
||||||
|
|
||||||
|
err = db.CreateUser(role.CreationStatement, role.RollbackStatement, username, password, expiration)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -105,7 +81,7 @@ func (b *databaseBackend) pathRoleCreateRead(req *logical.Request, data *framewo
|
||||||
"username": username,
|
"username": username,
|
||||||
"role": name,
|
"role": name,
|
||||||
})
|
})
|
||||||
resp.Secret.TTL = lease.Lease
|
resp.Secret.TTL = role.DefaultTTL
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,9 @@
|
||||||
package database
|
package database
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/hashicorp/vault/logical"
|
"github.com/hashicorp/vault/logical"
|
||||||
"github.com/hashicorp/vault/logical/framework"
|
"github.com/hashicorp/vault/logical/framework"
|
||||||
)
|
)
|
||||||
|
@ -44,6 +47,24 @@ func pathRoles(b *databaseBackend) *framework.Path {
|
||||||
array, or a base64-encoded serialized JSON string array. The '{{name}}' value
|
array, or a base64-encoded serialized JSON string array. The '{{name}}' value
|
||||||
will be substituted.`,
|
will be substituted.`,
|
||||||
},
|
},
|
||||||
|
|
||||||
|
"rollback_statement": {
|
||||||
|
Type: framework.TypeString,
|
||||||
|
Description: `SQL statements to be executed to revoke a user. Must be a semicolon-separated
|
||||||
|
string, a base64-encoded semicolon-separated string, a serialized JSON string
|
||||||
|
array, or a base64-encoded serialized JSON string array. The '{{name}}' value
|
||||||
|
will be substituted.`,
|
||||||
|
},
|
||||||
|
|
||||||
|
"default_ttl": {
|
||||||
|
Type: framework.TypeString,
|
||||||
|
Description: "Default ttl for role.",
|
||||||
|
},
|
||||||
|
|
||||||
|
"max_ttl": {
|
||||||
|
Type: framework.TypeString,
|
||||||
|
Description: "Maximum time a credential is valid for",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
|
|
||||||
Callbacks: map[logical.Operation]framework.OperationFunc{
|
Callbacks: map[logical.Operation]framework.OperationFunc{
|
||||||
|
@ -79,6 +100,9 @@ func (b *databaseBackend) pathRoleRead(req *logical.Request, data *framework.Fie
|
||||||
Data: map[string]interface{}{
|
Data: map[string]interface{}{
|
||||||
"creation_statment": role.CreationStatement,
|
"creation_statment": role.CreationStatement,
|
||||||
"revocation_statement": role.RevocationStatement,
|
"revocation_statement": role.RevocationStatement,
|
||||||
|
"rollback_statement": role.RollbackStatement,
|
||||||
|
"default_ttl": role.DefaultTTL.String(),
|
||||||
|
"max_ttl": role.MaxTTL.String(),
|
||||||
},
|
},
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
@ -97,6 +121,20 @@ func (b *databaseBackend) pathRoleCreate(req *logical.Request, data *framework.F
|
||||||
dbName := data.Get("db_name").(string)
|
dbName := data.Get("db_name").(string)
|
||||||
creationStmt := data.Get("creation_statement").(string)
|
creationStmt := data.Get("creation_statement").(string)
|
||||||
revocationStmt := data.Get("revocation_statement").(string)
|
revocationStmt := data.Get("revocation_statement").(string)
|
||||||
|
rollbackStmt := data.Get("rollback_statement").(string)
|
||||||
|
defaultTTLRaw := data.Get("default_ttl").(string)
|
||||||
|
maxTTLRaw := data.Get("max_ttl").(string)
|
||||||
|
|
||||||
|
defaultTTL, err := time.ParseDuration(defaultTTLRaw)
|
||||||
|
if err != nil {
|
||||||
|
return logical.ErrorResponse(fmt.Sprintf(
|
||||||
|
"Invalid default_ttl: %s", err)), nil
|
||||||
|
}
|
||||||
|
maxTTL, err := time.ParseDuration(maxTTLRaw)
|
||||||
|
if err != nil {
|
||||||
|
return logical.ErrorResponse(fmt.Sprintf(
|
||||||
|
"Invalid max_ttl: %s", err)), nil
|
||||||
|
}
|
||||||
|
|
||||||
// TODO: Think about preparing the statments to test.
|
// TODO: Think about preparing the statments to test.
|
||||||
|
|
||||||
|
@ -105,6 +143,9 @@ func (b *databaseBackend) pathRoleCreate(req *logical.Request, data *framework.F
|
||||||
DBName: dbName,
|
DBName: dbName,
|
||||||
CreationStatement: creationStmt,
|
CreationStatement: creationStmt,
|
||||||
RevocationStatement: revocationStmt,
|
RevocationStatement: revocationStmt,
|
||||||
|
RollbackStatement: rollbackStmt,
|
||||||
|
DefaultTTL: defaultTTL,
|
||||||
|
MaxTTL: maxTTL,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -117,9 +158,12 @@ func (b *databaseBackend) pathRoleCreate(req *logical.Request, data *framework.F
|
||||||
}
|
}
|
||||||
|
|
||||||
type roleEntry struct {
|
type roleEntry struct {
|
||||||
DBName string `json:"db_name" mapstructure:"db_name" structs:"db_name"`
|
DBName string `json:"db_name" mapstructure:"db_name" structs:"db_name"`
|
||||||
CreationStatement string `json:"creation_statment" mapstructure:"creation_statement" structs:"creation_statment"`
|
CreationStatement string `json:"creation_statment" mapstructure:"creation_statement" structs:"creation_statment"`
|
||||||
RevocationStatement string `json:"revocation_statement" mapstructure:"revocation_statement" structs:"revocation_statement"`
|
RevocationStatement string `json:"revocation_statement" mapstructure:"revocation_statement" structs:"revocation_statement"`
|
||||||
|
RollbackStatement string `json:"revocation_statement" mapstructure:"revocation_statement" structs:"revocation_statement"`
|
||||||
|
DefaultTTL time.Duration `json:"default_ttl" mapstructure:"default_ttl" structs:"default_ttl"`
|
||||||
|
MaxTTL time.Duration `json:"max_ttl" mapstructure:"max_ttl" structs:"max_ttl"`
|
||||||
}
|
}
|
||||||
|
|
||||||
const pathRoleHelpSyn = `
|
const pathRoleHelpSyn = `
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
package database
|
package database
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/hashicorp/vault/logical"
|
"github.com/hashicorp/vault/logical"
|
||||||
|
@ -31,8 +30,6 @@ func secretCreds(b *databaseBackend) *framework.Secret {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *databaseBackend) secretCredsRenew(req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
func (b *databaseBackend) secretCredsRenew(req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||||
dbName := d.Get("name").(string)
|
|
||||||
|
|
||||||
// Get the username from the internal data
|
// Get the username from the internal data
|
||||||
usernameRaw, ok := req.Secret.InternalData["username"]
|
usernameRaw, ok := req.Secret.InternalData["username"]
|
||||||
if !ok {
|
if !ok {
|
||||||
|
@ -40,27 +37,35 @@ func (b *databaseBackend) secretCredsRenew(req *logical.Request, d *framework.Fi
|
||||||
}
|
}
|
||||||
username, ok := usernameRaw.(string)
|
username, ok := usernameRaw.(string)
|
||||||
|
|
||||||
// Get our connection
|
roleNameRaw, ok := req.Secret.InternalData["role"]
|
||||||
db, ok := b.connections[dbName]
|
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, errors.New(fmt.Sprintf("Could not find connection with name %s", dbName))
|
return nil, fmt.Errorf("Could not find role with name: %s", req.Secret.InternalData["role"])
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the lease information
|
role, err := b.Role(req.Storage, roleNameRaw.(string))
|
||||||
lease, err := b.Lease(req.Storage)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if lease == nil {
|
if role == nil {
|
||||||
lease = &configLease{}
|
return nil, fmt.Errorf("Could not find role with name: %s", req.Secret.InternalData["role"])
|
||||||
}
|
}
|
||||||
|
|
||||||
f := framework.LeaseExtend(lease.Lease, lease.LeaseMax, b.System())
|
f := framework.LeaseExtend(role.DefaultTTL, role.MaxTTL, b.System())
|
||||||
resp, err := f(req, d)
|
resp, err := f(req, d)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Grab the read lock
|
||||||
|
b.RLock()
|
||||||
|
defer b.RUnlock()
|
||||||
|
|
||||||
|
// Get our connection
|
||||||
|
db, ok := b.connections[role.DBName]
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("Could not find connection with name %s", role.DBName)
|
||||||
|
}
|
||||||
|
|
||||||
// Make sure we increase the VALID UNTIL endpoint for this user.
|
// Make sure we increase the VALID UNTIL endpoint for this user.
|
||||||
if expireTime := resp.Secret.ExpirationTime(); !expireTime.IsZero() {
|
if expireTime := resp.Secret.ExpirationTime(); !expireTime.IsZero() {
|
||||||
expiration := expireTime.Format("2006-01-02 15:04:05-0700")
|
expiration := expireTime.Format("2006-01-02 15:04:05-0700")
|
||||||
|
@ -124,23 +129,9 @@ func (b *databaseBackend) secretCredsRevoke(req *logical.Request, d *framework.F
|
||||||
return nil, fmt.Errorf("Could not find database with name: %s", role.DBName)
|
return nil, fmt.Errorf("Could not find database with name: %s", role.DBName)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Maybe move this down into db package?
|
err = db.RevokeUser(username, revocationSQL)
|
||||||
switch revocationSQL {
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
// This is the default revocation logic. If revocation SQL is provided it
|
|
||||||
// is simply executed as-is.
|
|
||||||
case "":
|
|
||||||
err := db.DefaultRevokeUser(username)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// We have revocation SQL, execute directly, within a transaction
|
|
||||||
default:
|
|
||||||
err := db.CustomRevokeUser(username, revocationSQL)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return resp, nil
|
return resp, nil
|
||||||
|
|
Loading…
Reference in New Issue