More work on refactor and cassandra database
This commit is contained in:
parent
acdcd79af3
commit
2ec5ab5616
|
@ -22,7 +22,6 @@ func Backend(conf *logical.BackendConfig) *databaseBackend {
|
|||
|
||||
Paths: []*framework.Path{
|
||||
pathConfigConnection(&b),
|
||||
pathConfigLease(&b),
|
||||
pathListRoles(&b),
|
||||
pathRoles(&b),
|
||||
pathRoleCreate(&b),
|
||||
|
@ -61,24 +60,6 @@ func (b *databaseBackend) resetAllDBs() {
|
|||
}
|
||||
}
|
||||
|
||||
// Lease returns the lease information
|
||||
func (b *databaseBackend) Lease(s logical.Storage) (*configLease, error) {
|
||||
entry, err := s.Get("config/lease")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if entry == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var result configLease
|
||||
if err := entry.DecodeJSON(&result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
func (b *databaseBackend) Role(s logical.Storage, n string) (*roleEntry, error) {
|
||||
entry, err := s.Get("role/" + n)
|
||||
if err != nil {
|
||||
|
|
|
@ -1,25 +1,20 @@
|
|||
package dbs
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gocql/gocql"
|
||||
"github.com/hashicorp/vault/helper/certutil"
|
||||
"github.com/hashicorp/vault/helper/tlsutil"
|
||||
"github.com/hashicorp/vault/helper/strutil"
|
||||
)
|
||||
|
||||
type Cassandra struct {
|
||||
// Session is goroutine safe, however, since we reinitialize
|
||||
// it when connection info changes, we want to make sure we
|
||||
// can close it and use a new connection; hence the lock
|
||||
session *gocql.Session
|
||||
config ConnectionConfig
|
||||
|
||||
ConnectionProducer
|
||||
CredentialsProducer
|
||||
sync.RWMutex
|
||||
}
|
||||
|
||||
|
@ -27,168 +22,85 @@ func (c *Cassandra) Type() string {
|
|||
return cassandraTypeName
|
||||
}
|
||||
|
||||
func (c *Cassandra) Connection() (*gocql.Session, error) {
|
||||
// Grab the write lock
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
|
||||
// If we already have a DB, we got it!
|
||||
if c.session != nil {
|
||||
return c.session, nil
|
||||
}
|
||||
|
||||
session, err := createSession(c.config)
|
||||
func (c *Cassandra) getConnection() (*gocql.Session, error) {
|
||||
session, err := c.Connection()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Store the session in backend for reuse
|
||||
c.session = session
|
||||
|
||||
return session, nil
|
||||
return session.(*gocql.Session), nil
|
||||
}
|
||||
|
||||
func (p *Cassandra) Close() {
|
||||
// Grab the write lock
|
||||
p.Lock()
|
||||
defer p.Unlock()
|
||||
|
||||
if p.session != nil {
|
||||
p.session.Close()
|
||||
}
|
||||
|
||||
p.session = nil
|
||||
}
|
||||
|
||||
func (p *Cassandra) Reset(config ConnectionConfig) (*sql.DB, error) {
|
||||
// Grab the write lock
|
||||
p.Lock()
|
||||
p.config = config
|
||||
p.Unlock()
|
||||
|
||||
p.Close()
|
||||
return p.Connection()
|
||||
}
|
||||
|
||||
func (p *Cassandra) CreateUser(createStmt, username, password, expiration string) error {
|
||||
func (c *Cassandra) CreateUser(createStmt, rollbackStmt, username, password, expiration string) error {
|
||||
// Get the connection
|
||||
db, err := p.Connection()
|
||||
session, err := c.getConnection()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO: This is racey
|
||||
// Grab a read lock
|
||||
p.RLock()
|
||||
defer p.RUnlock()
|
||||
c.RLock()
|
||||
defer c.RUnlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
// Set consistency
|
||||
/* if .Consistency != "" {
|
||||
consistencyValue, err := gocql.ParseConsistencyWrapper(role.Consistency)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
func (p *Cassandra) RenewUser(username, expiration string) error {
|
||||
db, err := p.Connection()
|
||||
if err != nil {
|
||||
return err
|
||||
session.SetConsistency(consistencyValue)
|
||||
}*/
|
||||
|
||||
// Execute each query
|
||||
for _, query := range strutil.ParseArbitraryStringSlice(createStmt, ";") {
|
||||
query = strings.TrimSpace(query)
|
||||
if len(query) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
err = session.Query(queryHelper(query, map[string]string{
|
||||
"username": username,
|
||||
"password": password,
|
||||
})).Exec()
|
||||
if err != nil {
|
||||
for _, query := range strutil.ParseArbitraryStringSlice(rollbackStmt, ";") {
|
||||
query = strings.TrimSpace(query)
|
||||
if len(query) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
session.Query(queryHelper(query, map[string]string{
|
||||
"username": username,
|
||||
"password": password,
|
||||
})).Exec()
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
// TODO: This is Racey
|
||||
// Grab the read lock
|
||||
p.RLock()
|
||||
defer p.RUnlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Cassandra) CustomRevokeUser(username, revocationSQL string) error {
|
||||
db, err := p.Connection()
|
||||
func (c *Cassandra) RenewUser(username, expiration string) error {
|
||||
// NOOP
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Cassandra) RevokeUser(username, revocationSQL string) error {
|
||||
session, err := c.getConnection()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// TODO: this is Racey
|
||||
p.RLock()
|
||||
defer p.RUnlock()
|
||||
c.RLock()
|
||||
defer c.RUnlock()
|
||||
|
||||
err = session.Query(fmt.Sprintf("DROP USER '%s'", username)).Exec()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error removing user %s", username)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Cassandra) DefaultRevokeUser(username string) error {
|
||||
// Grab the read lock
|
||||
p.RLock()
|
||||
defer p.RUnlock()
|
||||
|
||||
db, err := p.Connection()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func createSession(cfg *ConnectionConfig) (*gocql.Session, error) {
|
||||
clusterConfig := gocql.NewCluster(strings.Split(cfg.Hosts, ",")...)
|
||||
clusterConfig.Authenticator = gocql.PasswordAuthenticator{
|
||||
Username: cfg.Username,
|
||||
Password: cfg.Password,
|
||||
}
|
||||
|
||||
clusterConfig.ProtoVersion = cfg.ProtocolVersion
|
||||
if clusterConfig.ProtoVersion == 0 {
|
||||
clusterConfig.ProtoVersion = 2
|
||||
}
|
||||
|
||||
clusterConfig.Timeout = time.Duration(cfg.ConnectTimeout) * time.Second
|
||||
|
||||
if cfg.TLS {
|
||||
var tlsConfig *tls.Config
|
||||
if len(cfg.Certificate) > 0 || len(cfg.IssuingCA) > 0 {
|
||||
if len(cfg.Certificate) > 0 && len(cfg.PrivateKey) == 0 {
|
||||
return nil, fmt.Errorf("Found certificate for TLS authentication but no private key")
|
||||
}
|
||||
|
||||
certBundle := &certutil.CertBundle{}
|
||||
if len(cfg.Certificate) > 0 {
|
||||
certBundle.Certificate = cfg.Certificate
|
||||
certBundle.PrivateKey = cfg.PrivateKey
|
||||
}
|
||||
if len(cfg.IssuingCA) > 0 {
|
||||
certBundle.IssuingCA = cfg.IssuingCA
|
||||
}
|
||||
|
||||
parsedCertBundle, err := certBundle.ToParsedCertBundle()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse certificate bundle: %s", err)
|
||||
}
|
||||
|
||||
tlsConfig, err = parsedCertBundle.GetTLSConfig(certutil.TLSClient)
|
||||
if err != nil || tlsConfig == nil {
|
||||
return nil, fmt.Errorf("failed to get TLS configuration: tlsConfig:%#v err:%v", tlsConfig, err)
|
||||
}
|
||||
tlsConfig.InsecureSkipVerify = cfg.InsecureTLS
|
||||
|
||||
if cfg.TLSMinVersion != "" {
|
||||
var ok bool
|
||||
tlsConfig.MinVersion, ok = tlsutil.TLSLookup[cfg.TLSMinVersion]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid 'tls_min_version' in config")
|
||||
}
|
||||
} else {
|
||||
// MinVersion was not being set earlier. Reset it to
|
||||
// zero to gracefully handle upgrades.
|
||||
tlsConfig.MinVersion = 0
|
||||
}
|
||||
}
|
||||
|
||||
clusterConfig.SslOpts = &gocql.SslOptions{
|
||||
Config: *tlsConfig,
|
||||
}
|
||||
}
|
||||
|
||||
session, err := clusterConfig.CreateSession()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Error creating session: %s", err)
|
||||
}
|
||||
|
||||
// Verify the info
|
||||
err = session.Query(`LIST USERS`).Exec()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Error validating connection info: %s", err)
|
||||
}
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
|
|
@ -0,0 +1,254 @@
|
|||
package dbs
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gocql/gocql"
|
||||
"github.com/hashicorp/vault/helper/certutil"
|
||||
"github.com/hashicorp/vault/helper/tlsutil"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
)
|
||||
|
||||
type ConnectionProducer interface {
|
||||
Connection() (interface{}, error)
|
||||
Close()
|
||||
// TODO: Should we make this immutable instead?
|
||||
Reset(*DatabaseConfig) error
|
||||
}
|
||||
|
||||
// sqlConnectionProducer impliments ConnectionProducer and provides a generic producer for most sql databases
|
||||
type sqlConnectionDetails struct {
|
||||
ConnectionURL string `json:"connection_url" structs:"connection_url" mapstructure:"connection_url"`
|
||||
}
|
||||
|
||||
type sqlConnectionProducer struct {
|
||||
config *DatabaseConfig
|
||||
// TODO: Should we merge these two structures make it immutable?
|
||||
connDetails *sqlConnectionDetails
|
||||
|
||||
db *sql.DB
|
||||
sync.Mutex
|
||||
}
|
||||
|
||||
func (cp *sqlConnectionProducer) Connection() (interface{}, error) {
|
||||
// Grab the write lock
|
||||
cp.Lock()
|
||||
defer cp.Unlock()
|
||||
|
||||
// If we already have a DB, we got it!
|
||||
if cp.db != nil {
|
||||
if err := cp.db.Ping(); err == nil {
|
||||
return cp.db, nil
|
||||
}
|
||||
// If the ping was unsuccessful, close it and ignore errors as we'll be
|
||||
// reestablishing anyways
|
||||
cp.db.Close()
|
||||
}
|
||||
|
||||
// Otherwise, attempt to make connection
|
||||
conn := cp.connDetails.ConnectionURL
|
||||
|
||||
// Ensure timezone is set to UTC for all the conenctions
|
||||
if strings.HasPrefix(conn, "postgres://") || strings.HasPrefix(conn, "postgresql://") {
|
||||
if strings.Contains(conn, "?") {
|
||||
conn += "&timezone=utc"
|
||||
} else {
|
||||
conn += "?timezone=utc"
|
||||
}
|
||||
} else {
|
||||
conn += " timezone=utc"
|
||||
}
|
||||
|
||||
var err error
|
||||
cp.db, err = sql.Open(cp.config.DatabaseType, conn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Set some connection pool settings. We don't need much of this,
|
||||
// since the request rate shouldn't be high.
|
||||
cp.db.SetMaxOpenConns(cp.config.MaxOpenConnections)
|
||||
cp.db.SetMaxIdleConns(cp.config.MaxIdleConnections)
|
||||
|
||||
return cp.db, nil
|
||||
}
|
||||
|
||||
func (cp *sqlConnectionProducer) Close() {
|
||||
// Grab the write lock
|
||||
cp.Lock()
|
||||
defer cp.Unlock()
|
||||
|
||||
if cp.db != nil {
|
||||
cp.db.Close()
|
||||
}
|
||||
|
||||
cp.db = nil
|
||||
}
|
||||
|
||||
func (cp *sqlConnectionProducer) Reset(config *DatabaseConfig) error {
|
||||
// Grab the write lock
|
||||
cp.Lock()
|
||||
|
||||
var details *sqlConnectionDetails
|
||||
err := mapstructure.Decode(config.ConnectionDetails, &details)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cp.connDetails = details
|
||||
cp.config = config
|
||||
|
||||
cp.Unlock()
|
||||
|
||||
cp.Close()
|
||||
_, err = cp.Connection()
|
||||
return err
|
||||
}
|
||||
|
||||
// cassandraConnectionProducer impliments ConnectionProducer and provides connections for cassandra
|
||||
type cassandraConnectionDetails struct {
|
||||
Hosts string `json:"hosts" structs:"hosts" mapstructure:"hosts"`
|
||||
Username string `json:"username" structs:"username" mapstructure:"username"`
|
||||
Password string `json:"password" structs:"password" mapstructure:"password"`
|
||||
TLS bool `json:"tls" structs:"tls" mapstructure:"tls"`
|
||||
InsecureTLS bool `json:"insecure_tls" structs:"insecure_tls" mapstructure:"insecure_tls"`
|
||||
Certificate string `json:"certificate" structs:"certificate" mapstructure:"certificate"`
|
||||
PrivateKey string `json:"private_key" structs:"private_key" mapstructure:"private_key"`
|
||||
IssuingCA string `json:"issuing_ca" structs:"issuing_ca" mapstructure:"issuing_ca"`
|
||||
ProtocolVersion int `json:"protocol_version" structs:"protocol_version" mapstructure:"protocol_version"`
|
||||
ConnectTimeout int `json:"connect_timeout" structs:"connect_timeout" mapstructure:"connect_timeout"`
|
||||
TLSMinVersion string `json:"tls_min_version" structs:"tls_min_version" mapstructure:"tls_min_version"`
|
||||
Consistancy string `json:"tls_min_version" structs:"tls_min_version" mapstructure:"tls_min_version"`
|
||||
}
|
||||
|
||||
type cassandraConnectionProducer struct {
|
||||
config *DatabaseConfig
|
||||
// TODO: Should we merge these two structures make it immutable?
|
||||
connDetails *cassandraConnectionDetails
|
||||
|
||||
session *gocql.Session
|
||||
sync.Mutex
|
||||
}
|
||||
|
||||
func (cp *cassandraConnectionProducer) Connection() (interface{}, error) {
|
||||
// Grab the write lock
|
||||
cp.Lock()
|
||||
defer cp.Unlock()
|
||||
|
||||
// If we already have a DB, we got it!
|
||||
if cp.session != nil {
|
||||
return cp.session, nil
|
||||
}
|
||||
|
||||
session, err := cp.createSession(cp.connDetails)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Store the session in backend for reuse
|
||||
cp.session = session
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func (cp *cassandraConnectionProducer) Close() {
|
||||
// Grab the write lock
|
||||
cp.Lock()
|
||||
defer cp.Unlock()
|
||||
|
||||
if cp.session != nil {
|
||||
cp.session.Close()
|
||||
}
|
||||
|
||||
cp.session = nil
|
||||
}
|
||||
|
||||
func (cp *cassandraConnectionProducer) Reset(config *DatabaseConfig) error {
|
||||
// Grab the write lock
|
||||
cp.Lock()
|
||||
cp.config = config
|
||||
cp.Unlock()
|
||||
|
||||
cp.Close()
|
||||
_, err := cp.Connection()
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (cp *cassandraConnectionProducer) createSession(cfg *cassandraConnectionDetails) (*gocql.Session, error) {
|
||||
clusterConfig := gocql.NewCluster(strings.Split(cfg.Hosts, ",")...)
|
||||
clusterConfig.Authenticator = gocql.PasswordAuthenticator{
|
||||
Username: cfg.Username,
|
||||
Password: cfg.Password,
|
||||
}
|
||||
|
||||
clusterConfig.ProtoVersion = cfg.ProtocolVersion
|
||||
if clusterConfig.ProtoVersion == 0 {
|
||||
clusterConfig.ProtoVersion = 2
|
||||
}
|
||||
|
||||
clusterConfig.Timeout = time.Duration(cfg.ConnectTimeout) * time.Second
|
||||
|
||||
if cfg.TLS {
|
||||
var tlsConfig *tls.Config
|
||||
if len(cfg.Certificate) > 0 || len(cfg.IssuingCA) > 0 {
|
||||
if len(cfg.Certificate) > 0 && len(cfg.PrivateKey) == 0 {
|
||||
return nil, fmt.Errorf("Found certificate for TLS authentication but no private key")
|
||||
}
|
||||
|
||||
certBundle := &certutil.CertBundle{}
|
||||
if len(cfg.Certificate) > 0 {
|
||||
certBundle.Certificate = cfg.Certificate
|
||||
certBundle.PrivateKey = cfg.PrivateKey
|
||||
}
|
||||
if len(cfg.IssuingCA) > 0 {
|
||||
certBundle.IssuingCA = cfg.IssuingCA
|
||||
}
|
||||
|
||||
parsedCertBundle, err := certBundle.ToParsedCertBundle()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse certificate bundle: %s", err)
|
||||
}
|
||||
|
||||
tlsConfig, err = parsedCertBundle.GetTLSConfig(certutil.TLSClient)
|
||||
if err != nil || tlsConfig == nil {
|
||||
return nil, fmt.Errorf("failed to get TLS configuration: tlsConfig:%#v err:%v", tlsConfig, err)
|
||||
}
|
||||
tlsConfig.InsecureSkipVerify = cfg.InsecureTLS
|
||||
|
||||
if cfg.TLSMinVersion != "" {
|
||||
var ok bool
|
||||
tlsConfig.MinVersion, ok = tlsutil.TLSLookup[cfg.TLSMinVersion]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid 'tls_min_version' in config")
|
||||
}
|
||||
} else {
|
||||
// MinVersion was not being set earlier. Reset it to
|
||||
// zero to gracefully handle upgrades.
|
||||
tlsConfig.MinVersion = 0
|
||||
}
|
||||
}
|
||||
|
||||
clusterConfig.SslOpts = &gocql.SslOptions{
|
||||
Config: *tlsConfig,
|
||||
}
|
||||
}
|
||||
|
||||
session, err := clusterConfig.CreateSession()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Error creating session: %s", err)
|
||||
}
|
||||
|
||||
// Verify the info
|
||||
err = session.Query(`LIST USERS`).Exec()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Error validating connection info: %s", err)
|
||||
}
|
||||
|
||||
return session, nil
|
||||
}
|
|
@ -0,0 +1,79 @@
|
|||
package dbs
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
uuid "github.com/hashicorp/go-uuid"
|
||||
)
|
||||
|
||||
type CredentialsProducer interface {
|
||||
GenerateUsername(displayName string) (string, error)
|
||||
GeneratePassword() (string, error)
|
||||
GenerateExpiration(ttl time.Duration) string
|
||||
}
|
||||
|
||||
// sqlCredentialsProducer impliments CredentialsProducer and provides a generic credentials producer for most sql database types.
|
||||
type sqlCredentialsProducer struct {
|
||||
displayNameLen int
|
||||
usernameLen int
|
||||
}
|
||||
|
||||
func (scg *sqlCredentialsProducer) GenerateUsername(displayName string) (string, error) {
|
||||
// Generate the username, password and expiration. PG limits user to 63 characters
|
||||
if scg.displayNameLen > 0 && len(displayName) > scg.displayNameLen {
|
||||
displayName = displayName[:scg.displayNameLen]
|
||||
}
|
||||
userUUID, err := uuid.GenerateUUID()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
username := fmt.Sprintf("%s-%s", displayName, userUUID)
|
||||
if scg.usernameLen > 0 && len(username) > scg.usernameLen {
|
||||
username = username[:scg.usernameLen]
|
||||
}
|
||||
|
||||
return username, nil
|
||||
}
|
||||
|
||||
func (scg *sqlCredentialsProducer) GeneratePassword() (string, error) {
|
||||
password, err := uuid.GenerateUUID()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return password, nil
|
||||
}
|
||||
|
||||
func (scg *sqlCredentialsProducer) GenerateExpiration(ttl time.Duration) string {
|
||||
return time.Now().
|
||||
Add(ttl).
|
||||
Format("2006-01-02 15:04:05-0700")
|
||||
}
|
||||
|
||||
type cassandraCredentialsProducer struct{}
|
||||
|
||||
func (ccp *cassandraCredentialsProducer) GenerateUsername(displayName string) (string, error) {
|
||||
userUUID, err := uuid.GenerateUUID()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
username := fmt.Sprintf("vault_%s_%s_%d", displayName, userUUID, time.Now().Unix())
|
||||
username = strings.Replace(username, "-", "_", -1)
|
||||
|
||||
return username, nil
|
||||
}
|
||||
|
||||
func (ccp *cassandraCredentialsProducer) GeneratePassword() (string, error) {
|
||||
password, err := uuid.GenerateUUID()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return password, nil
|
||||
}
|
||||
|
||||
func (ccp *cassandraCredentialsProducer) GenerateExpiration(ttl time.Duration) string {
|
||||
return ""
|
||||
}
|
|
@ -1,10 +1,11 @@
|
|||
package dbs
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/mitchellh/mapstructure"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -16,11 +17,47 @@ var (
|
|||
ErrUnsupportedDatabaseType = errors.New("Unsupported database type")
|
||||
)
|
||||
|
||||
func Factory(conf ConnectionConfig) (DatabaseType, error) {
|
||||
switch conf.ConnectionType {
|
||||
func Factory(conf *DatabaseConfig) (DatabaseType, error) {
|
||||
switch conf.DatabaseType {
|
||||
case postgreSQLTypeName:
|
||||
var details *sqlConnectionDetails
|
||||
err := mapstructure.Decode(conf.ConnectionDetails, &details)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
connProducer := &sqlConnectionProducer{
|
||||
config: conf,
|
||||
connDetails: details,
|
||||
}
|
||||
|
||||
credsProducer := &sqlCredentialsProducer{
|
||||
displayNameLen: 23,
|
||||
usernameLen: 63,
|
||||
}
|
||||
|
||||
return &PostgreSQL{
|
||||
config: conf,
|
||||
ConnectionProducer: connProducer,
|
||||
CredentialsProducer: credsProducer,
|
||||
}, nil
|
||||
|
||||
case cassandraTypeName:
|
||||
var details *cassandraConnectionDetails
|
||||
err := mapstructure.Decode(conf.ConnectionDetails, &details)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
connProducer := &cassandraConnectionProducer{
|
||||
config: conf,
|
||||
connDetails: details,
|
||||
}
|
||||
|
||||
credsProducer := &cassandraCredentialsProducer{}
|
||||
|
||||
return &Cassandra{
|
||||
ConnectionProducer: connProducer,
|
||||
CredentialsProducer: credsProducer,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
@ -29,21 +66,19 @@ func Factory(conf ConnectionConfig) (DatabaseType, error) {
|
|||
|
||||
type DatabaseType interface {
|
||||
Type() string
|
||||
Connection() (*sql.DB, error)
|
||||
Close()
|
||||
Reset(ConnectionConfig) (*sql.DB, error)
|
||||
CreateUser(createStmt, username, password, expiration string) error
|
||||
CreateUser(createStmt, rollbackStmt, username, password, expiration string) error
|
||||
RenewUser(username, expiration string) error
|
||||
CustomRevokeUser(username, revocationSQL string) error
|
||||
DefaultRevokeUser(username string) error
|
||||
RevokeUser(username, revocationStmt string) error
|
||||
|
||||
ConnectionProducer
|
||||
CredentialsProducer
|
||||
}
|
||||
|
||||
type ConnectionConfig struct {
|
||||
ConnectionType string `json:"type" structs:"type" mapstructure:"type"`
|
||||
ConnectionURL string `json:"connection_url" structs:"connection_url" mapstructure:"connection_url"`
|
||||
ConnectionDetails map[string]string `json:"connection_details" structs:"connection_details" mapstructure:"connection_details"`
|
||||
MaxOpenConnections int `json:"max_open_connections" structs:"max_open_connections" mapstructure:"max_open_connections"`
|
||||
MaxIdleConnections int `json:"max_idle_connections" structs:"max_idle_connections" mapstructure:"max_idle_connections"`
|
||||
type DatabaseConfig struct {
|
||||
DatabaseType string `json:"type" structs:"type" mapstructure:"type"`
|
||||
ConnectionDetails map[string]interface{} `json:"connection_details" structs:"connection_details" mapstructure:"connection_details"`
|
||||
MaxOpenConnections int `json:"max_open_connections" structs:"max_open_connections" mapstructure:"max_open_connections"`
|
||||
MaxIdleConnections int `json:"max_idle_connections" structs:"max_idle_connections" mapstructure:"max_idle_connections"`
|
||||
}
|
||||
|
||||
// Query templates a query for us.
|
||||
|
|
|
@ -11,9 +11,10 @@ import (
|
|||
)
|
||||
|
||||
type PostgreSQL struct {
|
||||
db *sql.DB
|
||||
config ConnectionConfig
|
||||
db *sql.DB
|
||||
|
||||
ConnectionProducer
|
||||
CredentialsProducer
|
||||
sync.RWMutex
|
||||
}
|
||||
|
||||
|
@ -21,74 +22,18 @@ func (p *PostgreSQL) Type() string {
|
|||
return postgreSQLTypeName
|
||||
}
|
||||
|
||||
func (p *PostgreSQL) Connection() (*sql.DB, error) {
|
||||
// Grab the write lock
|
||||
p.Lock()
|
||||
defer p.Unlock()
|
||||
|
||||
// If we already have a DB, we got it!
|
||||
if p.db != nil {
|
||||
if err := p.db.Ping(); err == nil {
|
||||
return p.db, nil
|
||||
}
|
||||
// If the ping was unsuccessful, close it and ignore errors as we'll be
|
||||
// reestablishing anyways
|
||||
p.db.Close()
|
||||
}
|
||||
|
||||
// Otherwise, attempt to make connection
|
||||
conn := p.config.ConnectionURL
|
||||
|
||||
// Ensure timezone is set to UTC for all the conenctions
|
||||
if strings.HasPrefix(conn, "postgres://") || strings.HasPrefix(conn, "postgresql://") {
|
||||
if strings.Contains(conn, "?") {
|
||||
conn += "&timezone=utc"
|
||||
} else {
|
||||
conn += "?timezone=utc"
|
||||
}
|
||||
} else {
|
||||
conn += " timezone=utc"
|
||||
}
|
||||
|
||||
var err error
|
||||
p.db, err = sql.Open("postgres", conn)
|
||||
func (p *PostgreSQL) getConnection() (*sql.DB, error) {
|
||||
db, err := p.Connection()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Set some connection pool settings. We don't need much of this,
|
||||
// since the request rate shouldn't be high.
|
||||
p.db.SetMaxOpenConns(p.config.MaxOpenConnections)
|
||||
p.db.SetMaxIdleConns(p.config.MaxIdleConnections)
|
||||
|
||||
return p.db, nil
|
||||
return db.(*sql.DB), nil
|
||||
}
|
||||
|
||||
func (p *PostgreSQL) Close() {
|
||||
// Grab the write lock
|
||||
p.Lock()
|
||||
defer p.Unlock()
|
||||
|
||||
if p.db != nil {
|
||||
p.db.Close()
|
||||
}
|
||||
|
||||
p.db = nil
|
||||
}
|
||||
|
||||
func (p *PostgreSQL) Reset(config ConnectionConfig) (*sql.DB, error) {
|
||||
// Grab the write lock
|
||||
p.Lock()
|
||||
p.config = config
|
||||
p.Unlock()
|
||||
|
||||
p.Close()
|
||||
return p.Connection()
|
||||
}
|
||||
|
||||
func (p *PostgreSQL) CreateUser(createStmt, username, password, expiration string) error {
|
||||
func (p *PostgreSQL) CreateUser(createStmt, rollbackStmt, username, password, expiration string) error {
|
||||
// Get the connection
|
||||
db, err := p.Connection()
|
||||
db, err := p.getConnection()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -144,7 +89,7 @@ func (p *PostgreSQL) CreateUser(createStmt, username, password, expiration strin
|
|||
}
|
||||
|
||||
func (p *PostgreSQL) RenewUser(username, expiration string) error {
|
||||
db, err := p.Connection()
|
||||
db, err := p.getConnection()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -170,14 +115,23 @@ func (p *PostgreSQL) RenewUser(username, expiration string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (p *PostgreSQL) CustomRevokeUser(username, revocationSQL string) error {
|
||||
db, err := p.Connection()
|
||||
func (p *PostgreSQL) RevokeUser(username, revocationStmt string) error {
|
||||
// Grab the read lock
|
||||
p.RLock()
|
||||
defer p.RUnlock()
|
||||
|
||||
if revocationStmt == "" {
|
||||
return p.defaultRevokeUser(username)
|
||||
}
|
||||
|
||||
return p.customRevokeUser(username, revocationStmt)
|
||||
}
|
||||
|
||||
func (p *PostgreSQL) customRevokeUser(username, revocationStmt string) error {
|
||||
db, err := p.getConnection()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// TODO: this is Racey
|
||||
p.RLock()
|
||||
defer p.RUnlock()
|
||||
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
|
@ -187,7 +141,7 @@ func (p *PostgreSQL) CustomRevokeUser(username, revocationSQL string) error {
|
|||
tx.Rollback()
|
||||
}()
|
||||
|
||||
for _, query := range strutil.ParseArbitraryStringSlice(revocationSQL, ";") {
|
||||
for _, query := range strutil.ParseArbitraryStringSlice(revocationStmt, ";") {
|
||||
query = strings.TrimSpace(query)
|
||||
if len(query) == 0 {
|
||||
continue
|
||||
|
@ -213,12 +167,8 @@ func (p *PostgreSQL) CustomRevokeUser(username, revocationSQL string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (p *PostgreSQL) DefaultRevokeUser(username string) error {
|
||||
// Grab the read lock
|
||||
p.RLock()
|
||||
defer p.RUnlock()
|
||||
|
||||
db, err := p.Connection()
|
||||
func (p *PostgreSQL) defaultRevokeUser(username string) error {
|
||||
db, err := p.getConnection()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -79,7 +79,7 @@ func (b *databaseBackend) pathConnectionRead(req *logical.Request, data *framewo
|
|||
return nil, nil
|
||||
}
|
||||
|
||||
var config dbs.ConnectionConfig
|
||||
var config dbs.DatabaseConfig
|
||||
if err := entry.DecodeJSON(&config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -89,8 +89,8 @@ func (b *databaseBackend) pathConnectionRead(req *logical.Request, data *framewo
|
|||
}
|
||||
|
||||
func (b *databaseBackend) pathConnectionWrite(req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
connURL := data.Get("connection_url").(string)
|
||||
connType := data.Get("connection_type").(string)
|
||||
connDetails := data.Get("connection_details").(map[string]interface{})
|
||||
|
||||
maxOpenConns := data.Get("max_open_connections").(int)
|
||||
if maxOpenConns == 0 {
|
||||
|
@ -105,9 +105,9 @@ func (b *databaseBackend) pathConnectionWrite(req *logical.Request, data *framew
|
|||
maxIdleConns = maxOpenConns
|
||||
}
|
||||
|
||||
config := dbs.ConnectionConfig{
|
||||
ConnectionType: connType,
|
||||
ConnectionURL: connURL,
|
||||
config := &dbs.DatabaseConfig{
|
||||
DatabaseType: connType,
|
||||
ConnectionDetails: connDetails,
|
||||
MaxOpenConnections: maxOpenConns,
|
||||
MaxIdleConnections: maxIdleConns,
|
||||
}
|
||||
|
|
|
@ -1,103 +0,0 @@
|
|||
package database
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
)
|
||||
|
||||
func pathConfigLease(b *databaseBackend) *framework.Path {
|
||||
return &framework.Path{
|
||||
Pattern: "config/lease",
|
||||
Fields: map[string]*framework.FieldSchema{
|
||||
"lease": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
Description: "Default lease for roles.",
|
||||
},
|
||||
|
||||
"lease_max": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
Description: "Maximum time a credential is valid for.",
|
||||
},
|
||||
},
|
||||
|
||||
Callbacks: map[logical.Operation]framework.OperationFunc{
|
||||
logical.ReadOperation: b.pathLeaseRead,
|
||||
logical.UpdateOperation: b.pathLeaseWrite,
|
||||
},
|
||||
|
||||
HelpSynopsis: pathConfigLeaseHelpSyn,
|
||||
HelpDescription: pathConfigLeaseHelpDesc,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *databaseBackend) pathLeaseWrite(
|
||||
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
leaseRaw := d.Get("lease").(string)
|
||||
leaseMaxRaw := d.Get("lease_max").(string)
|
||||
|
||||
lease, err := time.ParseDuration(leaseRaw)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf(
|
||||
"Invalid lease: %s", err)), nil
|
||||
}
|
||||
leaseMax, err := time.ParseDuration(leaseMaxRaw)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf(
|
||||
"Invalid lease: %s", err)), nil
|
||||
}
|
||||
|
||||
// Store it
|
||||
entry, err := logical.StorageEntryJSON("config/lease", &configLease{
|
||||
Lease: lease,
|
||||
LeaseMax: leaseMax,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := req.Storage.Put(entry); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (b *databaseBackend) pathLeaseRead(
|
||||
req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
lease, err := b.Lease(req.Storage)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if lease == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return &logical.Response{
|
||||
Data: map[string]interface{}{
|
||||
"lease": lease.Lease.String(),
|
||||
"lease_max": lease.LeaseMax.String(),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
type configLease struct {
|
||||
Lease time.Duration
|
||||
LeaseMax time.Duration
|
||||
}
|
||||
|
||||
const pathConfigLeaseHelpSyn = `
|
||||
Configure the default lease information for generated credentials.
|
||||
`
|
||||
|
||||
const pathConfigLeaseHelpDesc = `
|
||||
This configures the default lease information used for credentials
|
||||
generated by this backend. The lease specifies the duration that a
|
||||
credential will be valid for, as well as the maximum session for
|
||||
a set of credentials.
|
||||
|
||||
The format for the lease is "1h" or integer and then unit. The longest
|
||||
unit is hour.
|
||||
`
|
|
@ -2,9 +2,7 @@ package database
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-uuid"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
_ "github.com/lib/pq"
|
||||
|
@ -45,41 +43,7 @@ func (b *databaseBackend) pathRoleCreateRead(req *logical.Request, data *framewo
|
|||
return logical.ErrorResponse(fmt.Sprintf("unknown role: %s", name)), nil
|
||||
}
|
||||
|
||||
// Determine if we have a lease
|
||||
b.logger.Trace("postgres/pathRoleCreateRead: getting lease")
|
||||
lease, err := b.Lease(req.Storage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Unlike some other backends we need a lease here (can't leave as 0 and
|
||||
// let core fill it in) because Postgres also expires users as a safety
|
||||
// measure, so cannot be zero
|
||||
if lease == nil {
|
||||
lease = &configLease{
|
||||
Lease: b.System().DefaultLeaseTTL(),
|
||||
}
|
||||
}
|
||||
|
||||
// Generate the username, password and expiration. PG limits user to 63 characters
|
||||
displayName := req.DisplayName
|
||||
if len(displayName) > 26 {
|
||||
displayName = displayName[:26]
|
||||
}
|
||||
userUUID, err := uuid.GenerateUUID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
username := fmt.Sprintf("%s-%s", displayName, userUUID)
|
||||
if len(username) > 63 {
|
||||
username = username[:63]
|
||||
}
|
||||
password, err := uuid.GenerateUUID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
expiration := time.Now().
|
||||
Add(lease.Lease).
|
||||
Format("2006-01-02 15:04:05-0700")
|
||||
|
||||
// Get our handle
|
||||
b.logger.Trace("postgres/pathRoleCreateRead: getting database handle")
|
||||
|
@ -92,7 +56,19 @@ func (b *databaseBackend) pathRoleCreateRead(req *logical.Request, data *framewo
|
|||
return nil, fmt.Errorf("Cound not find DB with name: %s", role.DBName)
|
||||
}
|
||||
|
||||
err = db.CreateUser(role.CreationStatement, username, password, expiration)
|
||||
username, err := db.GenerateUsername(req.DisplayName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
password, err := db.GeneratePassword()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
expiration := db.GenerateExpiration(role.DefaultTTL)
|
||||
|
||||
err = db.CreateUser(role.CreationStatement, role.RollbackStatement, username, password, expiration)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -105,7 +81,7 @@ func (b *databaseBackend) pathRoleCreateRead(req *logical.Request, data *framewo
|
|||
"username": username,
|
||||
"role": name,
|
||||
})
|
||||
resp.Secret.TTL = lease.Lease
|
||||
resp.Secret.TTL = role.DefaultTTL
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -1,6 +1,9 @@
|
|||
package database
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
)
|
||||
|
@ -44,6 +47,24 @@ func pathRoles(b *databaseBackend) *framework.Path {
|
|||
array, or a base64-encoded serialized JSON string array. The '{{name}}' value
|
||||
will be substituted.`,
|
||||
},
|
||||
|
||||
"rollback_statement": {
|
||||
Type: framework.TypeString,
|
||||
Description: `SQL statements to be executed to revoke a user. Must be a semicolon-separated
|
||||
string, a base64-encoded semicolon-separated string, a serialized JSON string
|
||||
array, or a base64-encoded serialized JSON string array. The '{{name}}' value
|
||||
will be substituted.`,
|
||||
},
|
||||
|
||||
"default_ttl": {
|
||||
Type: framework.TypeString,
|
||||
Description: "Default ttl for role.",
|
||||
},
|
||||
|
||||
"max_ttl": {
|
||||
Type: framework.TypeString,
|
||||
Description: "Maximum time a credential is valid for",
|
||||
},
|
||||
},
|
||||
|
||||
Callbacks: map[logical.Operation]framework.OperationFunc{
|
||||
|
@ -79,6 +100,9 @@ func (b *databaseBackend) pathRoleRead(req *logical.Request, data *framework.Fie
|
|||
Data: map[string]interface{}{
|
||||
"creation_statment": role.CreationStatement,
|
||||
"revocation_statement": role.RevocationStatement,
|
||||
"rollback_statement": role.RollbackStatement,
|
||||
"default_ttl": role.DefaultTTL.String(),
|
||||
"max_ttl": role.MaxTTL.String(),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
@ -97,6 +121,20 @@ func (b *databaseBackend) pathRoleCreate(req *logical.Request, data *framework.F
|
|||
dbName := data.Get("db_name").(string)
|
||||
creationStmt := data.Get("creation_statement").(string)
|
||||
revocationStmt := data.Get("revocation_statement").(string)
|
||||
rollbackStmt := data.Get("rollback_statement").(string)
|
||||
defaultTTLRaw := data.Get("default_ttl").(string)
|
||||
maxTTLRaw := data.Get("max_ttl").(string)
|
||||
|
||||
defaultTTL, err := time.ParseDuration(defaultTTLRaw)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf(
|
||||
"Invalid default_ttl: %s", err)), nil
|
||||
}
|
||||
maxTTL, err := time.ParseDuration(maxTTLRaw)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf(
|
||||
"Invalid max_ttl: %s", err)), nil
|
||||
}
|
||||
|
||||
// TODO: Think about preparing the statments to test.
|
||||
|
||||
|
@ -105,6 +143,9 @@ func (b *databaseBackend) pathRoleCreate(req *logical.Request, data *framework.F
|
|||
DBName: dbName,
|
||||
CreationStatement: creationStmt,
|
||||
RevocationStatement: revocationStmt,
|
||||
RollbackStatement: rollbackStmt,
|
||||
DefaultTTL: defaultTTL,
|
||||
MaxTTL: maxTTL,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -117,9 +158,12 @@ func (b *databaseBackend) pathRoleCreate(req *logical.Request, data *framework.F
|
|||
}
|
||||
|
||||
type roleEntry struct {
|
||||
DBName string `json:"db_name" mapstructure:"db_name" structs:"db_name"`
|
||||
CreationStatement string `json:"creation_statment" mapstructure:"creation_statement" structs:"creation_statment"`
|
||||
RevocationStatement string `json:"revocation_statement" mapstructure:"revocation_statement" structs:"revocation_statement"`
|
||||
DBName string `json:"db_name" mapstructure:"db_name" structs:"db_name"`
|
||||
CreationStatement string `json:"creation_statment" mapstructure:"creation_statement" structs:"creation_statment"`
|
||||
RevocationStatement string `json:"revocation_statement" mapstructure:"revocation_statement" structs:"revocation_statement"`
|
||||
RollbackStatement string `json:"revocation_statement" mapstructure:"revocation_statement" structs:"revocation_statement"`
|
||||
DefaultTTL time.Duration `json:"default_ttl" mapstructure:"default_ttl" structs:"default_ttl"`
|
||||
MaxTTL time.Duration `json:"max_ttl" mapstructure:"max_ttl" structs:"max_ttl"`
|
||||
}
|
||||
|
||||
const pathRoleHelpSyn = `
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package database
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/hashicorp/vault/logical"
|
||||
|
@ -31,8 +30,6 @@ func secretCreds(b *databaseBackend) *framework.Secret {
|
|||
}
|
||||
|
||||
func (b *databaseBackend) secretCredsRenew(req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
dbName := d.Get("name").(string)
|
||||
|
||||
// Get the username from the internal data
|
||||
usernameRaw, ok := req.Secret.InternalData["username"]
|
||||
if !ok {
|
||||
|
@ -40,27 +37,35 @@ func (b *databaseBackend) secretCredsRenew(req *logical.Request, d *framework.Fi
|
|||
}
|
||||
username, ok := usernameRaw.(string)
|
||||
|
||||
// Get our connection
|
||||
db, ok := b.connections[dbName]
|
||||
roleNameRaw, ok := req.Secret.InternalData["role"]
|
||||
if !ok {
|
||||
return nil, errors.New(fmt.Sprintf("Could not find connection with name %s", dbName))
|
||||
return nil, fmt.Errorf("Could not find role with name: %s", req.Secret.InternalData["role"])
|
||||
}
|
||||
|
||||
// Get the lease information
|
||||
lease, err := b.Lease(req.Storage)
|
||||
role, err := b.Role(req.Storage, roleNameRaw.(string))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if lease == nil {
|
||||
lease = &configLease{}
|
||||
if role == nil {
|
||||
return nil, fmt.Errorf("Could not find role with name: %s", req.Secret.InternalData["role"])
|
||||
}
|
||||
|
||||
f := framework.LeaseExtend(lease.Lease, lease.LeaseMax, b.System())
|
||||
f := framework.LeaseExtend(role.DefaultTTL, role.MaxTTL, b.System())
|
||||
resp, err := f(req, d)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Grab the read lock
|
||||
b.RLock()
|
||||
defer b.RUnlock()
|
||||
|
||||
// Get our connection
|
||||
db, ok := b.connections[role.DBName]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("Could not find connection with name %s", role.DBName)
|
||||
}
|
||||
|
||||
// Make sure we increase the VALID UNTIL endpoint for this user.
|
||||
if expireTime := resp.Secret.ExpirationTime(); !expireTime.IsZero() {
|
||||
expiration := expireTime.Format("2006-01-02 15:04:05-0700")
|
||||
|
@ -124,23 +129,9 @@ func (b *databaseBackend) secretCredsRevoke(req *logical.Request, d *framework.F
|
|||
return nil, fmt.Errorf("Could not find database with name: %s", role.DBName)
|
||||
}
|
||||
|
||||
// TODO: Maybe move this down into db package?
|
||||
switch revocationSQL {
|
||||
|
||||
// This is the default revocation logic. If revocation SQL is provided it
|
||||
// is simply executed as-is.
|
||||
case "":
|
||||
err := db.DefaultRevokeUser(username)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// We have revocation SQL, execute directly, within a transaction
|
||||
default:
|
||||
err := db.CustomRevokeUser(username, revocationSQL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = db.RevokeUser(username, revocationSQL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
|
|
Loading…
Reference in New Issue