From 2ec5ab56160fbf689cbab72d201dc001123f40c2 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Tue, 20 Dec 2016 11:46:20 -0800 Subject: [PATCH] More work on refactor and cassandra database --- builtin/logical/database/backend.go | 19 -- builtin/logical/database/dbs/cassandra.go | 204 ++++---------- .../database/dbs/connectionproducer.go | 254 ++++++++++++++++++ .../database/dbs/credentialsproducer.go | 79 ++++++ builtin/logical/database/dbs/db.go | 67 +++-- builtin/logical/database/dbs/postgresql.go | 102 ++----- .../database/path_config_connection.go | 10 +- builtin/logical/database/path_config_lease.go | 103 ------- builtin/logical/database/path_role_create.go | 52 +--- builtin/logical/database/path_roles.go | 50 +++- builtin/logical/database/secret_creds.go | 47 ++-- 11 files changed, 553 insertions(+), 434 deletions(-) create mode 100644 builtin/logical/database/dbs/connectionproducer.go create mode 100644 builtin/logical/database/dbs/credentialsproducer.go delete mode 100644 builtin/logical/database/path_config_lease.go diff --git a/builtin/logical/database/backend.go b/builtin/logical/database/backend.go index 8b7fa3670..3d757df1d 100644 --- a/builtin/logical/database/backend.go +++ b/builtin/logical/database/backend.go @@ -22,7 +22,6 @@ func Backend(conf *logical.BackendConfig) *databaseBackend { Paths: []*framework.Path{ pathConfigConnection(&b), - pathConfigLease(&b), pathListRoles(&b), pathRoles(&b), pathRoleCreate(&b), @@ -61,24 +60,6 @@ func (b *databaseBackend) resetAllDBs() { } } -// Lease returns the lease information -func (b *databaseBackend) Lease(s logical.Storage) (*configLease, error) { - entry, err := s.Get("config/lease") - if err != nil { - return nil, err - } - if entry == nil { - return nil, nil - } - - var result configLease - if err := entry.DecodeJSON(&result); err != nil { - return nil, err - } - - return &result, nil -} - func (b *databaseBackend) Role(s logical.Storage, n string) (*roleEntry, error) { entry, err := s.Get("role/" + n) if err != nil { diff --git a/builtin/logical/database/dbs/cassandra.go b/builtin/logical/database/dbs/cassandra.go index 8c7a068be..a8889032f 100644 --- a/builtin/logical/database/dbs/cassandra.go +++ b/builtin/logical/database/dbs/cassandra.go @@ -1,25 +1,20 @@ package dbs import ( - "crypto/tls" - "database/sql" "fmt" "strings" "sync" - "time" "github.com/gocql/gocql" - "github.com/hashicorp/vault/helper/certutil" - "github.com/hashicorp/vault/helper/tlsutil" + "github.com/hashicorp/vault/helper/strutil" ) type Cassandra struct { // Session is goroutine safe, however, since we reinitialize // it when connection info changes, we want to make sure we // can close it and use a new connection; hence the lock - session *gocql.Session - config ConnectionConfig - + ConnectionProducer + CredentialsProducer sync.RWMutex } @@ -27,168 +22,85 @@ func (c *Cassandra) Type() string { return cassandraTypeName } -func (c *Cassandra) Connection() (*gocql.Session, error) { - // Grab the write lock - c.Lock() - defer c.Unlock() - - // If we already have a DB, we got it! - if c.session != nil { - return c.session, nil - } - - session, err := createSession(c.config) +func (c *Cassandra) getConnection() (*gocql.Session, error) { + session, err := c.Connection() if err != nil { return nil, err } - // Store the session in backend for reuse - c.session = session - - return session, nil + return session.(*gocql.Session), nil } -func (p *Cassandra) Close() { - // Grab the write lock - p.Lock() - defer p.Unlock() - - if p.session != nil { - p.session.Close() - } - - p.session = nil -} - -func (p *Cassandra) Reset(config ConnectionConfig) (*sql.DB, error) { - // Grab the write lock - p.Lock() - p.config = config - p.Unlock() - - p.Close() - return p.Connection() -} - -func (p *Cassandra) CreateUser(createStmt, username, password, expiration string) error { +func (c *Cassandra) CreateUser(createStmt, rollbackStmt, username, password, expiration string) error { // Get the connection - db, err := p.Connection() + session, err := c.getConnection() if err != nil { return err } // TODO: This is racey // Grab a read lock - p.RLock() - defer p.RUnlock() + c.RLock() + defer c.RUnlock() - return nil -} + // Set consistency + /* if .Consistency != "" { + consistencyValue, err := gocql.ParseConsistencyWrapper(role.Consistency) + if err != nil { + return err + } -func (p *Cassandra) RenewUser(username, expiration string) error { - db, err := p.Connection() - if err != nil { - return err + session.SetConsistency(consistencyValue) + }*/ + + // Execute each query + for _, query := range strutil.ParseArbitraryStringSlice(createStmt, ";") { + query = strings.TrimSpace(query) + if len(query) == 0 { + continue + } + + err = session.Query(queryHelper(query, map[string]string{ + "username": username, + "password": password, + })).Exec() + if err != nil { + for _, query := range strutil.ParseArbitraryStringSlice(rollbackStmt, ";") { + query = strings.TrimSpace(query) + if len(query) == 0 { + continue + } + + session.Query(queryHelper(query, map[string]string{ + "username": username, + "password": password, + })).Exec() + } + return err + } } - // TODO: This is Racey - // Grab the read lock - p.RLock() - defer p.RUnlock() return nil } -func (p *Cassandra) CustomRevokeUser(username, revocationSQL string) error { - db, err := p.Connection() +func (c *Cassandra) RenewUser(username, expiration string) error { + // NOOP + return nil +} + +func (c *Cassandra) RevokeUser(username, revocationSQL string) error { + session, err := c.getConnection() if err != nil { return err } // TODO: this is Racey - p.RLock() - defer p.RUnlock() + c.RLock() + defer c.RUnlock() + + err = session.Query(fmt.Sprintf("DROP USER '%s'", username)).Exec() + if err != nil { + return fmt.Errorf("error removing user %s", username) + } return nil } - -func (p *Cassandra) DefaultRevokeUser(username string) error { - // Grab the read lock - p.RLock() - defer p.RUnlock() - - db, err := p.Connection() - - return nil -} - -func createSession(cfg *ConnectionConfig) (*gocql.Session, error) { - clusterConfig := gocql.NewCluster(strings.Split(cfg.Hosts, ",")...) - clusterConfig.Authenticator = gocql.PasswordAuthenticator{ - Username: cfg.Username, - Password: cfg.Password, - } - - clusterConfig.ProtoVersion = cfg.ProtocolVersion - if clusterConfig.ProtoVersion == 0 { - clusterConfig.ProtoVersion = 2 - } - - clusterConfig.Timeout = time.Duration(cfg.ConnectTimeout) * time.Second - - if cfg.TLS { - var tlsConfig *tls.Config - if len(cfg.Certificate) > 0 || len(cfg.IssuingCA) > 0 { - if len(cfg.Certificate) > 0 && len(cfg.PrivateKey) == 0 { - return nil, fmt.Errorf("Found certificate for TLS authentication but no private key") - } - - certBundle := &certutil.CertBundle{} - if len(cfg.Certificate) > 0 { - certBundle.Certificate = cfg.Certificate - certBundle.PrivateKey = cfg.PrivateKey - } - if len(cfg.IssuingCA) > 0 { - certBundle.IssuingCA = cfg.IssuingCA - } - - parsedCertBundle, err := certBundle.ToParsedCertBundle() - if err != nil { - return nil, fmt.Errorf("failed to parse certificate bundle: %s", err) - } - - tlsConfig, err = parsedCertBundle.GetTLSConfig(certutil.TLSClient) - if err != nil || tlsConfig == nil { - return nil, fmt.Errorf("failed to get TLS configuration: tlsConfig:%#v err:%v", tlsConfig, err) - } - tlsConfig.InsecureSkipVerify = cfg.InsecureTLS - - if cfg.TLSMinVersion != "" { - var ok bool - tlsConfig.MinVersion, ok = tlsutil.TLSLookup[cfg.TLSMinVersion] - if !ok { - return nil, fmt.Errorf("invalid 'tls_min_version' in config") - } - } else { - // MinVersion was not being set earlier. Reset it to - // zero to gracefully handle upgrades. - tlsConfig.MinVersion = 0 - } - } - - clusterConfig.SslOpts = &gocql.SslOptions{ - Config: *tlsConfig, - } - } - - session, err := clusterConfig.CreateSession() - if err != nil { - return nil, fmt.Errorf("Error creating session: %s", err) - } - - // Verify the info - err = session.Query(`LIST USERS`).Exec() - if err != nil { - return nil, fmt.Errorf("Error validating connection info: %s", err) - } - - return session, nil -} diff --git a/builtin/logical/database/dbs/connectionproducer.go b/builtin/logical/database/dbs/connectionproducer.go new file mode 100644 index 000000000..adecfd55a --- /dev/null +++ b/builtin/logical/database/dbs/connectionproducer.go @@ -0,0 +1,254 @@ +package dbs + +import ( + "crypto/tls" + "database/sql" + "fmt" + "strings" + "sync" + "time" + + "github.com/gocql/gocql" + "github.com/hashicorp/vault/helper/certutil" + "github.com/hashicorp/vault/helper/tlsutil" + "github.com/mitchellh/mapstructure" +) + +type ConnectionProducer interface { + Connection() (interface{}, error) + Close() + // TODO: Should we make this immutable instead? + Reset(*DatabaseConfig) error +} + +// sqlConnectionProducer impliments ConnectionProducer and provides a generic producer for most sql databases +type sqlConnectionDetails struct { + ConnectionURL string `json:"connection_url" structs:"connection_url" mapstructure:"connection_url"` +} + +type sqlConnectionProducer struct { + config *DatabaseConfig + // TODO: Should we merge these two structures make it immutable? + connDetails *sqlConnectionDetails + + db *sql.DB + sync.Mutex +} + +func (cp *sqlConnectionProducer) Connection() (interface{}, error) { + // Grab the write lock + cp.Lock() + defer cp.Unlock() + + // If we already have a DB, we got it! + if cp.db != nil { + if err := cp.db.Ping(); err == nil { + return cp.db, nil + } + // If the ping was unsuccessful, close it and ignore errors as we'll be + // reestablishing anyways + cp.db.Close() + } + + // Otherwise, attempt to make connection + conn := cp.connDetails.ConnectionURL + + // Ensure timezone is set to UTC for all the conenctions + if strings.HasPrefix(conn, "postgres://") || strings.HasPrefix(conn, "postgresql://") { + if strings.Contains(conn, "?") { + conn += "&timezone=utc" + } else { + conn += "?timezone=utc" + } + } else { + conn += " timezone=utc" + } + + var err error + cp.db, err = sql.Open(cp.config.DatabaseType, conn) + if err != nil { + return nil, err + } + + // Set some connection pool settings. We don't need much of this, + // since the request rate shouldn't be high. + cp.db.SetMaxOpenConns(cp.config.MaxOpenConnections) + cp.db.SetMaxIdleConns(cp.config.MaxIdleConnections) + + return cp.db, nil +} + +func (cp *sqlConnectionProducer) Close() { + // Grab the write lock + cp.Lock() + defer cp.Unlock() + + if cp.db != nil { + cp.db.Close() + } + + cp.db = nil +} + +func (cp *sqlConnectionProducer) Reset(config *DatabaseConfig) error { + // Grab the write lock + cp.Lock() + + var details *sqlConnectionDetails + err := mapstructure.Decode(config.ConnectionDetails, &details) + if err != nil { + return err + } + + cp.connDetails = details + cp.config = config + + cp.Unlock() + + cp.Close() + _, err = cp.Connection() + return err +} + +// cassandraConnectionProducer impliments ConnectionProducer and provides connections for cassandra +type cassandraConnectionDetails struct { + Hosts string `json:"hosts" structs:"hosts" mapstructure:"hosts"` + Username string `json:"username" structs:"username" mapstructure:"username"` + Password string `json:"password" structs:"password" mapstructure:"password"` + TLS bool `json:"tls" structs:"tls" mapstructure:"tls"` + InsecureTLS bool `json:"insecure_tls" structs:"insecure_tls" mapstructure:"insecure_tls"` + Certificate string `json:"certificate" structs:"certificate" mapstructure:"certificate"` + PrivateKey string `json:"private_key" structs:"private_key" mapstructure:"private_key"` + IssuingCA string `json:"issuing_ca" structs:"issuing_ca" mapstructure:"issuing_ca"` + ProtocolVersion int `json:"protocol_version" structs:"protocol_version" mapstructure:"protocol_version"` + ConnectTimeout int `json:"connect_timeout" structs:"connect_timeout" mapstructure:"connect_timeout"` + TLSMinVersion string `json:"tls_min_version" structs:"tls_min_version" mapstructure:"tls_min_version"` + Consistancy string `json:"tls_min_version" structs:"tls_min_version" mapstructure:"tls_min_version"` +} + +type cassandraConnectionProducer struct { + config *DatabaseConfig + // TODO: Should we merge these two structures make it immutable? + connDetails *cassandraConnectionDetails + + session *gocql.Session + sync.Mutex +} + +func (cp *cassandraConnectionProducer) Connection() (interface{}, error) { + // Grab the write lock + cp.Lock() + defer cp.Unlock() + + // If we already have a DB, we got it! + if cp.session != nil { + return cp.session, nil + } + + session, err := cp.createSession(cp.connDetails) + if err != nil { + return nil, err + } + + // Store the session in backend for reuse + cp.session = session + + return session, nil +} + +func (cp *cassandraConnectionProducer) Close() { + // Grab the write lock + cp.Lock() + defer cp.Unlock() + + if cp.session != nil { + cp.session.Close() + } + + cp.session = nil +} + +func (cp *cassandraConnectionProducer) Reset(config *DatabaseConfig) error { + // Grab the write lock + cp.Lock() + cp.config = config + cp.Unlock() + + cp.Close() + _, err := cp.Connection() + + return err +} + +func (cp *cassandraConnectionProducer) createSession(cfg *cassandraConnectionDetails) (*gocql.Session, error) { + clusterConfig := gocql.NewCluster(strings.Split(cfg.Hosts, ",")...) + clusterConfig.Authenticator = gocql.PasswordAuthenticator{ + Username: cfg.Username, + Password: cfg.Password, + } + + clusterConfig.ProtoVersion = cfg.ProtocolVersion + if clusterConfig.ProtoVersion == 0 { + clusterConfig.ProtoVersion = 2 + } + + clusterConfig.Timeout = time.Duration(cfg.ConnectTimeout) * time.Second + + if cfg.TLS { + var tlsConfig *tls.Config + if len(cfg.Certificate) > 0 || len(cfg.IssuingCA) > 0 { + if len(cfg.Certificate) > 0 && len(cfg.PrivateKey) == 0 { + return nil, fmt.Errorf("Found certificate for TLS authentication but no private key") + } + + certBundle := &certutil.CertBundle{} + if len(cfg.Certificate) > 0 { + certBundle.Certificate = cfg.Certificate + certBundle.PrivateKey = cfg.PrivateKey + } + if len(cfg.IssuingCA) > 0 { + certBundle.IssuingCA = cfg.IssuingCA + } + + parsedCertBundle, err := certBundle.ToParsedCertBundle() + if err != nil { + return nil, fmt.Errorf("failed to parse certificate bundle: %s", err) + } + + tlsConfig, err = parsedCertBundle.GetTLSConfig(certutil.TLSClient) + if err != nil || tlsConfig == nil { + return nil, fmt.Errorf("failed to get TLS configuration: tlsConfig:%#v err:%v", tlsConfig, err) + } + tlsConfig.InsecureSkipVerify = cfg.InsecureTLS + + if cfg.TLSMinVersion != "" { + var ok bool + tlsConfig.MinVersion, ok = tlsutil.TLSLookup[cfg.TLSMinVersion] + if !ok { + return nil, fmt.Errorf("invalid 'tls_min_version' in config") + } + } else { + // MinVersion was not being set earlier. Reset it to + // zero to gracefully handle upgrades. + tlsConfig.MinVersion = 0 + } + } + + clusterConfig.SslOpts = &gocql.SslOptions{ + Config: *tlsConfig, + } + } + + session, err := clusterConfig.CreateSession() + if err != nil { + return nil, fmt.Errorf("Error creating session: %s", err) + } + + // Verify the info + err = session.Query(`LIST USERS`).Exec() + if err != nil { + return nil, fmt.Errorf("Error validating connection info: %s", err) + } + + return session, nil +} diff --git a/builtin/logical/database/dbs/credentialsproducer.go b/builtin/logical/database/dbs/credentialsproducer.go new file mode 100644 index 000000000..20210c2e8 --- /dev/null +++ b/builtin/logical/database/dbs/credentialsproducer.go @@ -0,0 +1,79 @@ +package dbs + +import ( + "fmt" + "strings" + "time" + + uuid "github.com/hashicorp/go-uuid" +) + +type CredentialsProducer interface { + GenerateUsername(displayName string) (string, error) + GeneratePassword() (string, error) + GenerateExpiration(ttl time.Duration) string +} + +// sqlCredentialsProducer impliments CredentialsProducer and provides a generic credentials producer for most sql database types. +type sqlCredentialsProducer struct { + displayNameLen int + usernameLen int +} + +func (scg *sqlCredentialsProducer) GenerateUsername(displayName string) (string, error) { + // Generate the username, password and expiration. PG limits user to 63 characters + if scg.displayNameLen > 0 && len(displayName) > scg.displayNameLen { + displayName = displayName[:scg.displayNameLen] + } + userUUID, err := uuid.GenerateUUID() + if err != nil { + return "", err + } + username := fmt.Sprintf("%s-%s", displayName, userUUID) + if scg.usernameLen > 0 && len(username) > scg.usernameLen { + username = username[:scg.usernameLen] + } + + return username, nil +} + +func (scg *sqlCredentialsProducer) GeneratePassword() (string, error) { + password, err := uuid.GenerateUUID() + if err != nil { + return "", err + } + + return password, nil +} + +func (scg *sqlCredentialsProducer) GenerateExpiration(ttl time.Duration) string { + return time.Now(). + Add(ttl). + Format("2006-01-02 15:04:05-0700") +} + +type cassandraCredentialsProducer struct{} + +func (ccp *cassandraCredentialsProducer) GenerateUsername(displayName string) (string, error) { + userUUID, err := uuid.GenerateUUID() + if err != nil { + return "", err + } + username := fmt.Sprintf("vault_%s_%s_%d", displayName, userUUID, time.Now().Unix()) + username = strings.Replace(username, "-", "_", -1) + + return username, nil +} + +func (ccp *cassandraCredentialsProducer) GeneratePassword() (string, error) { + password, err := uuid.GenerateUUID() + if err != nil { + return "", err + } + + return password, nil +} + +func (ccp *cassandraCredentialsProducer) GenerateExpiration(ttl time.Duration) string { + return "" +} diff --git a/builtin/logical/database/dbs/db.go b/builtin/logical/database/dbs/db.go index ee7b15b64..9d261ff42 100644 --- a/builtin/logical/database/dbs/db.go +++ b/builtin/logical/database/dbs/db.go @@ -1,10 +1,11 @@ package dbs import ( - "database/sql" "errors" "fmt" "strings" + + "github.com/mitchellh/mapstructure" ) const ( @@ -16,11 +17,47 @@ var ( ErrUnsupportedDatabaseType = errors.New("Unsupported database type") ) -func Factory(conf ConnectionConfig) (DatabaseType, error) { - switch conf.ConnectionType { +func Factory(conf *DatabaseConfig) (DatabaseType, error) { + switch conf.DatabaseType { case postgreSQLTypeName: + var details *sqlConnectionDetails + err := mapstructure.Decode(conf.ConnectionDetails, &details) + if err != nil { + return nil, err + } + + connProducer := &sqlConnectionProducer{ + config: conf, + connDetails: details, + } + + credsProducer := &sqlCredentialsProducer{ + displayNameLen: 23, + usernameLen: 63, + } + return &PostgreSQL{ - config: conf, + ConnectionProducer: connProducer, + CredentialsProducer: credsProducer, + }, nil + + case cassandraTypeName: + var details *cassandraConnectionDetails + err := mapstructure.Decode(conf.ConnectionDetails, &details) + if err != nil { + return nil, err + } + + connProducer := &cassandraConnectionProducer{ + config: conf, + connDetails: details, + } + + credsProducer := &cassandraCredentialsProducer{} + + return &Cassandra{ + ConnectionProducer: connProducer, + CredentialsProducer: credsProducer, }, nil } @@ -29,21 +66,19 @@ func Factory(conf ConnectionConfig) (DatabaseType, error) { type DatabaseType interface { Type() string - Connection() (*sql.DB, error) - Close() - Reset(ConnectionConfig) (*sql.DB, error) - CreateUser(createStmt, username, password, expiration string) error + CreateUser(createStmt, rollbackStmt, username, password, expiration string) error RenewUser(username, expiration string) error - CustomRevokeUser(username, revocationSQL string) error - DefaultRevokeUser(username string) error + RevokeUser(username, revocationStmt string) error + + ConnectionProducer + CredentialsProducer } -type ConnectionConfig struct { - ConnectionType string `json:"type" structs:"type" mapstructure:"type"` - ConnectionURL string `json:"connection_url" structs:"connection_url" mapstructure:"connection_url"` - ConnectionDetails map[string]string `json:"connection_details" structs:"connection_details" mapstructure:"connection_details"` - MaxOpenConnections int `json:"max_open_connections" structs:"max_open_connections" mapstructure:"max_open_connections"` - MaxIdleConnections int `json:"max_idle_connections" structs:"max_idle_connections" mapstructure:"max_idle_connections"` +type DatabaseConfig struct { + DatabaseType string `json:"type" structs:"type" mapstructure:"type"` + ConnectionDetails map[string]interface{} `json:"connection_details" structs:"connection_details" mapstructure:"connection_details"` + MaxOpenConnections int `json:"max_open_connections" structs:"max_open_connections" mapstructure:"max_open_connections"` + MaxIdleConnections int `json:"max_idle_connections" structs:"max_idle_connections" mapstructure:"max_idle_connections"` } // Query templates a query for us. diff --git a/builtin/logical/database/dbs/postgresql.go b/builtin/logical/database/dbs/postgresql.go index ea7d08f8a..e050e30bf 100644 --- a/builtin/logical/database/dbs/postgresql.go +++ b/builtin/logical/database/dbs/postgresql.go @@ -11,9 +11,10 @@ import ( ) type PostgreSQL struct { - db *sql.DB - config ConnectionConfig + db *sql.DB + ConnectionProducer + CredentialsProducer sync.RWMutex } @@ -21,74 +22,18 @@ func (p *PostgreSQL) Type() string { return postgreSQLTypeName } -func (p *PostgreSQL) Connection() (*sql.DB, error) { - // Grab the write lock - p.Lock() - defer p.Unlock() - - // If we already have a DB, we got it! - if p.db != nil { - if err := p.db.Ping(); err == nil { - return p.db, nil - } - // If the ping was unsuccessful, close it and ignore errors as we'll be - // reestablishing anyways - p.db.Close() - } - - // Otherwise, attempt to make connection - conn := p.config.ConnectionURL - - // Ensure timezone is set to UTC for all the conenctions - if strings.HasPrefix(conn, "postgres://") || strings.HasPrefix(conn, "postgresql://") { - if strings.Contains(conn, "?") { - conn += "&timezone=utc" - } else { - conn += "?timezone=utc" - } - } else { - conn += " timezone=utc" - } - - var err error - p.db, err = sql.Open("postgres", conn) +func (p *PostgreSQL) getConnection() (*sql.DB, error) { + db, err := p.Connection() if err != nil { return nil, err } - // Set some connection pool settings. We don't need much of this, - // since the request rate shouldn't be high. - p.db.SetMaxOpenConns(p.config.MaxOpenConnections) - p.db.SetMaxIdleConns(p.config.MaxIdleConnections) - - return p.db, nil + return db.(*sql.DB), nil } -func (p *PostgreSQL) Close() { - // Grab the write lock - p.Lock() - defer p.Unlock() - - if p.db != nil { - p.db.Close() - } - - p.db = nil -} - -func (p *PostgreSQL) Reset(config ConnectionConfig) (*sql.DB, error) { - // Grab the write lock - p.Lock() - p.config = config - p.Unlock() - - p.Close() - return p.Connection() -} - -func (p *PostgreSQL) CreateUser(createStmt, username, password, expiration string) error { +func (p *PostgreSQL) CreateUser(createStmt, rollbackStmt, username, password, expiration string) error { // Get the connection - db, err := p.Connection() + db, err := p.getConnection() if err != nil { return err } @@ -144,7 +89,7 @@ func (p *PostgreSQL) CreateUser(createStmt, username, password, expiration strin } func (p *PostgreSQL) RenewUser(username, expiration string) error { - db, err := p.Connection() + db, err := p.getConnection() if err != nil { return err } @@ -170,14 +115,23 @@ func (p *PostgreSQL) RenewUser(username, expiration string) error { return nil } -func (p *PostgreSQL) CustomRevokeUser(username, revocationSQL string) error { - db, err := p.Connection() +func (p *PostgreSQL) RevokeUser(username, revocationStmt string) error { + // Grab the read lock + p.RLock() + defer p.RUnlock() + + if revocationStmt == "" { + return p.defaultRevokeUser(username) + } + + return p.customRevokeUser(username, revocationStmt) +} + +func (p *PostgreSQL) customRevokeUser(username, revocationStmt string) error { + db, err := p.getConnection() if err != nil { return err } - // TODO: this is Racey - p.RLock() - defer p.RUnlock() tx, err := db.Begin() if err != nil { @@ -187,7 +141,7 @@ func (p *PostgreSQL) CustomRevokeUser(username, revocationSQL string) error { tx.Rollback() }() - for _, query := range strutil.ParseArbitraryStringSlice(revocationSQL, ";") { + for _, query := range strutil.ParseArbitraryStringSlice(revocationStmt, ";") { query = strings.TrimSpace(query) if len(query) == 0 { continue @@ -213,12 +167,8 @@ func (p *PostgreSQL) CustomRevokeUser(username, revocationSQL string) error { return nil } -func (p *PostgreSQL) DefaultRevokeUser(username string) error { - // Grab the read lock - p.RLock() - defer p.RUnlock() - - db, err := p.Connection() +func (p *PostgreSQL) defaultRevokeUser(username string) error { + db, err := p.getConnection() if err != nil { return err } diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index be017ea35..d4a969a69 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -79,7 +79,7 @@ func (b *databaseBackend) pathConnectionRead(req *logical.Request, data *framewo return nil, nil } - var config dbs.ConnectionConfig + var config dbs.DatabaseConfig if err := entry.DecodeJSON(&config); err != nil { return nil, err } @@ -89,8 +89,8 @@ func (b *databaseBackend) pathConnectionRead(req *logical.Request, data *framewo } func (b *databaseBackend) pathConnectionWrite(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { - connURL := data.Get("connection_url").(string) connType := data.Get("connection_type").(string) + connDetails := data.Get("connection_details").(map[string]interface{}) maxOpenConns := data.Get("max_open_connections").(int) if maxOpenConns == 0 { @@ -105,9 +105,9 @@ func (b *databaseBackend) pathConnectionWrite(req *logical.Request, data *framew maxIdleConns = maxOpenConns } - config := dbs.ConnectionConfig{ - ConnectionType: connType, - ConnectionURL: connURL, + config := &dbs.DatabaseConfig{ + DatabaseType: connType, + ConnectionDetails: connDetails, MaxOpenConnections: maxOpenConns, MaxIdleConnections: maxIdleConns, } diff --git a/builtin/logical/database/path_config_lease.go b/builtin/logical/database/path_config_lease.go deleted file mode 100644 index 5cc40a056..000000000 --- a/builtin/logical/database/path_config_lease.go +++ /dev/null @@ -1,103 +0,0 @@ -package database - -import ( - "fmt" - "time" - - "github.com/hashicorp/vault/logical" - "github.com/hashicorp/vault/logical/framework" -) - -func pathConfigLease(b *databaseBackend) *framework.Path { - return &framework.Path{ - Pattern: "config/lease", - Fields: map[string]*framework.FieldSchema{ - "lease": &framework.FieldSchema{ - Type: framework.TypeString, - Description: "Default lease for roles.", - }, - - "lease_max": &framework.FieldSchema{ - Type: framework.TypeString, - Description: "Maximum time a credential is valid for.", - }, - }, - - Callbacks: map[logical.Operation]framework.OperationFunc{ - logical.ReadOperation: b.pathLeaseRead, - logical.UpdateOperation: b.pathLeaseWrite, - }, - - HelpSynopsis: pathConfigLeaseHelpSyn, - HelpDescription: pathConfigLeaseHelpDesc, - } -} - -func (b *databaseBackend) pathLeaseWrite( - req *logical.Request, d *framework.FieldData) (*logical.Response, error) { - leaseRaw := d.Get("lease").(string) - leaseMaxRaw := d.Get("lease_max").(string) - - lease, err := time.ParseDuration(leaseRaw) - if err != nil { - return logical.ErrorResponse(fmt.Sprintf( - "Invalid lease: %s", err)), nil - } - leaseMax, err := time.ParseDuration(leaseMaxRaw) - if err != nil { - return logical.ErrorResponse(fmt.Sprintf( - "Invalid lease: %s", err)), nil - } - - // Store it - entry, err := logical.StorageEntryJSON("config/lease", &configLease{ - Lease: lease, - LeaseMax: leaseMax, - }) - if err != nil { - return nil, err - } - if err := req.Storage.Put(entry); err != nil { - return nil, err - } - - return nil, nil -} - -func (b *databaseBackend) pathLeaseRead( - req *logical.Request, data *framework.FieldData) (*logical.Response, error) { - lease, err := b.Lease(req.Storage) - - if err != nil { - return nil, err - } - if lease == nil { - return nil, nil - } - - return &logical.Response{ - Data: map[string]interface{}{ - "lease": lease.Lease.String(), - "lease_max": lease.LeaseMax.String(), - }, - }, nil -} - -type configLease struct { - Lease time.Duration - LeaseMax time.Duration -} - -const pathConfigLeaseHelpSyn = ` -Configure the default lease information for generated credentials. -` - -const pathConfigLeaseHelpDesc = ` -This configures the default lease information used for credentials -generated by this backend. The lease specifies the duration that a -credential will be valid for, as well as the maximum session for -a set of credentials. - -The format for the lease is "1h" or integer and then unit. The longest -unit is hour. -` diff --git a/builtin/logical/database/path_role_create.go b/builtin/logical/database/path_role_create.go index 2a2386d01..15ca915ba 100644 --- a/builtin/logical/database/path_role_create.go +++ b/builtin/logical/database/path_role_create.go @@ -2,9 +2,7 @@ package database import ( "fmt" - "time" - "github.com/hashicorp/go-uuid" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" _ "github.com/lib/pq" @@ -45,41 +43,7 @@ func (b *databaseBackend) pathRoleCreateRead(req *logical.Request, data *framewo return logical.ErrorResponse(fmt.Sprintf("unknown role: %s", name)), nil } - // Determine if we have a lease - b.logger.Trace("postgres/pathRoleCreateRead: getting lease") - lease, err := b.Lease(req.Storage) - if err != nil { - return nil, err - } - // Unlike some other backends we need a lease here (can't leave as 0 and - // let core fill it in) because Postgres also expires users as a safety - // measure, so cannot be zero - if lease == nil { - lease = &configLease{ - Lease: b.System().DefaultLeaseTTL(), - } - } - // Generate the username, password and expiration. PG limits user to 63 characters - displayName := req.DisplayName - if len(displayName) > 26 { - displayName = displayName[:26] - } - userUUID, err := uuid.GenerateUUID() - if err != nil { - return nil, err - } - username := fmt.Sprintf("%s-%s", displayName, userUUID) - if len(username) > 63 { - username = username[:63] - } - password, err := uuid.GenerateUUID() - if err != nil { - return nil, err - } - expiration := time.Now(). - Add(lease.Lease). - Format("2006-01-02 15:04:05-0700") // Get our handle b.logger.Trace("postgres/pathRoleCreateRead: getting database handle") @@ -92,7 +56,19 @@ func (b *databaseBackend) pathRoleCreateRead(req *logical.Request, data *framewo return nil, fmt.Errorf("Cound not find DB with name: %s", role.DBName) } - err = db.CreateUser(role.CreationStatement, username, password, expiration) + username, err := db.GenerateUsername(req.DisplayName) + if err != nil { + return nil, err + } + + password, err := db.GeneratePassword() + if err != nil { + return nil, err + } + + expiration := db.GenerateExpiration(role.DefaultTTL) + + err = db.CreateUser(role.CreationStatement, role.RollbackStatement, username, password, expiration) if err != nil { return nil, err } @@ -105,7 +81,7 @@ func (b *databaseBackend) pathRoleCreateRead(req *logical.Request, data *framewo "username": username, "role": name, }) - resp.Secret.TTL = lease.Lease + resp.Secret.TTL = role.DefaultTTL return resp, nil } diff --git a/builtin/logical/database/path_roles.go b/builtin/logical/database/path_roles.go index e06518b28..dc8c6805a 100644 --- a/builtin/logical/database/path_roles.go +++ b/builtin/logical/database/path_roles.go @@ -1,6 +1,9 @@ package database import ( + "fmt" + "time" + "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" ) @@ -44,6 +47,24 @@ func pathRoles(b *databaseBackend) *framework.Path { array, or a base64-encoded serialized JSON string array. The '{{name}}' value will be substituted.`, }, + + "rollback_statement": { + Type: framework.TypeString, + Description: `SQL statements to be executed to revoke a user. Must be a semicolon-separated + string, a base64-encoded semicolon-separated string, a serialized JSON string + array, or a base64-encoded serialized JSON string array. The '{{name}}' value + will be substituted.`, + }, + + "default_ttl": { + Type: framework.TypeString, + Description: "Default ttl for role.", + }, + + "max_ttl": { + Type: framework.TypeString, + Description: "Maximum time a credential is valid for", + }, }, Callbacks: map[logical.Operation]framework.OperationFunc{ @@ -79,6 +100,9 @@ func (b *databaseBackend) pathRoleRead(req *logical.Request, data *framework.Fie Data: map[string]interface{}{ "creation_statment": role.CreationStatement, "revocation_statement": role.RevocationStatement, + "rollback_statement": role.RollbackStatement, + "default_ttl": role.DefaultTTL.String(), + "max_ttl": role.MaxTTL.String(), }, }, nil } @@ -97,6 +121,20 @@ func (b *databaseBackend) pathRoleCreate(req *logical.Request, data *framework.F dbName := data.Get("db_name").(string) creationStmt := data.Get("creation_statement").(string) revocationStmt := data.Get("revocation_statement").(string) + rollbackStmt := data.Get("rollback_statement").(string) + defaultTTLRaw := data.Get("default_ttl").(string) + maxTTLRaw := data.Get("max_ttl").(string) + + defaultTTL, err := time.ParseDuration(defaultTTLRaw) + if err != nil { + return logical.ErrorResponse(fmt.Sprintf( + "Invalid default_ttl: %s", err)), nil + } + maxTTL, err := time.ParseDuration(maxTTLRaw) + if err != nil { + return logical.ErrorResponse(fmt.Sprintf( + "Invalid max_ttl: %s", err)), nil + } // TODO: Think about preparing the statments to test. @@ -105,6 +143,9 @@ func (b *databaseBackend) pathRoleCreate(req *logical.Request, data *framework.F DBName: dbName, CreationStatement: creationStmt, RevocationStatement: revocationStmt, + RollbackStatement: rollbackStmt, + DefaultTTL: defaultTTL, + MaxTTL: maxTTL, }) if err != nil { return nil, err @@ -117,9 +158,12 @@ func (b *databaseBackend) pathRoleCreate(req *logical.Request, data *framework.F } type roleEntry struct { - DBName string `json:"db_name" mapstructure:"db_name" structs:"db_name"` - CreationStatement string `json:"creation_statment" mapstructure:"creation_statement" structs:"creation_statment"` - RevocationStatement string `json:"revocation_statement" mapstructure:"revocation_statement" structs:"revocation_statement"` + DBName string `json:"db_name" mapstructure:"db_name" structs:"db_name"` + CreationStatement string `json:"creation_statment" mapstructure:"creation_statement" structs:"creation_statment"` + RevocationStatement string `json:"revocation_statement" mapstructure:"revocation_statement" structs:"revocation_statement"` + RollbackStatement string `json:"revocation_statement" mapstructure:"revocation_statement" structs:"revocation_statement"` + DefaultTTL time.Duration `json:"default_ttl" mapstructure:"default_ttl" structs:"default_ttl"` + MaxTTL time.Duration `json:"max_ttl" mapstructure:"max_ttl" structs:"max_ttl"` } const pathRoleHelpSyn = ` diff --git a/builtin/logical/database/secret_creds.go b/builtin/logical/database/secret_creds.go index 30c4a6430..120804e91 100644 --- a/builtin/logical/database/secret_creds.go +++ b/builtin/logical/database/secret_creds.go @@ -1,7 +1,6 @@ package database import ( - "errors" "fmt" "github.com/hashicorp/vault/logical" @@ -31,8 +30,6 @@ func secretCreds(b *databaseBackend) *framework.Secret { } func (b *databaseBackend) secretCredsRenew(req *logical.Request, d *framework.FieldData) (*logical.Response, error) { - dbName := d.Get("name").(string) - // Get the username from the internal data usernameRaw, ok := req.Secret.InternalData["username"] if !ok { @@ -40,27 +37,35 @@ func (b *databaseBackend) secretCredsRenew(req *logical.Request, d *framework.Fi } username, ok := usernameRaw.(string) - // Get our connection - db, ok := b.connections[dbName] + roleNameRaw, ok := req.Secret.InternalData["role"] if !ok { - return nil, errors.New(fmt.Sprintf("Could not find connection with name %s", dbName)) + return nil, fmt.Errorf("Could not find role with name: %s", req.Secret.InternalData["role"]) } - // Get the lease information - lease, err := b.Lease(req.Storage) + role, err := b.Role(req.Storage, roleNameRaw.(string)) if err != nil { return nil, err } - if lease == nil { - lease = &configLease{} + if role == nil { + return nil, fmt.Errorf("Could not find role with name: %s", req.Secret.InternalData["role"]) } - f := framework.LeaseExtend(lease.Lease, lease.LeaseMax, b.System()) + f := framework.LeaseExtend(role.DefaultTTL, role.MaxTTL, b.System()) resp, err := f(req, d) if err != nil { return nil, err } + // Grab the read lock + b.RLock() + defer b.RUnlock() + + // Get our connection + db, ok := b.connections[role.DBName] + if !ok { + return nil, fmt.Errorf("Could not find connection with name %s", role.DBName) + } + // Make sure we increase the VALID UNTIL endpoint for this user. if expireTime := resp.Secret.ExpirationTime(); !expireTime.IsZero() { expiration := expireTime.Format("2006-01-02 15:04:05-0700") @@ -124,23 +129,9 @@ func (b *databaseBackend) secretCredsRevoke(req *logical.Request, d *framework.F return nil, fmt.Errorf("Could not find database with name: %s", role.DBName) } - // TODO: Maybe move this down into db package? - switch revocationSQL { - - // This is the default revocation logic. If revocation SQL is provided it - // is simply executed as-is. - case "": - err := db.DefaultRevokeUser(username) - if err != nil { - return nil, err - } - - // We have revocation SQL, execute directly, within a transaction - default: - err := db.CustomRevokeUser(username, revocationSQL) - if err != nil { - return nil, err - } + err = db.RevokeUser(username, revocationSQL) + if err != nil { + return nil, err } return resp, nil