Pass statements object

This commit is contained in:
Brian Kassouf 2017-03-07 16:48:17 -08:00
parent 843d584254
commit 3976a2a0a6
6 changed files with 62 additions and 38 deletions

View file

@ -34,7 +34,7 @@ func (c *Cassandra) getConnection() (*gocql.Session, error) {
return session.(*gocql.Session), nil
}
func (c *Cassandra) CreateUser(createStmts, rollbackStmts, username, password, expiration string) error {
func (c *Cassandra) CreateUser(statements Statements, username, password, expiration string) error {
// Grab the lock
c.Lock()
defer c.Unlock()
@ -46,7 +46,7 @@ func (c *Cassandra) CreateUser(createStmts, rollbackStmts, username, password, e
}
// Execute each query
for _, query := range strutil.ParseArbitraryStringSlice(createStmts, ";") {
for _, query := range strutil.ParseArbitraryStringSlice(statements.CreationStatements, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
@ -57,7 +57,7 @@ func (c *Cassandra) CreateUser(createStmts, rollbackStmts, username, password, e
"password": password,
})).Exec()
if err != nil {
for _, query := range strutil.ParseArbitraryStringSlice(rollbackStmts, ";") {
for _, query := range strutil.ParseArbitraryStringSlice(statements.RollbackStatements, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
@ -75,12 +75,12 @@ func (c *Cassandra) CreateUser(createStmts, rollbackStmts, username, password, e
return nil
}
func (c *Cassandra) RenewUser(username, expiration string) error {
func (c *Cassandra) RenewUser(statements Statements, username, expiration string) error {
// NOOP
return nil
}
func (c *Cassandra) RevokeUser(username, revocationStmts string) error {
func (c *Cassandra) RevokeUser(statements Statements, username string) error {
// Grab the lock
c.Lock()
defer c.Unlock()

View file

@ -78,9 +78,9 @@ func Factory(conf *DatabaseConfig) (DatabaseType, error) {
type DatabaseType interface {
Type() string
CreateUser(createStmt, rollbackStmt, username, password, expiration string) error
RenewUser(username, expiration string) error
RevokeUser(username, revocationStmt string) error
CreateUser(statements Statements, username, password, expiration string) error
RenewUser(statements Statements, username, expiration string) error
RevokeUser(statements Statements, username string) error
ConnectionProducer
CredentialsProducer
@ -94,6 +94,13 @@ type DatabaseConfig struct {
MaxConnectionLifetime time.Duration `json:"max_connection_lifetime" structs:"max_connection_lifetime" mapstructure:"max_connection_lifetime"`
}
type Statements struct {
CreationStatements string `json:"creation_statments" mapstructure:"creation_statements" structs:"creation_statments"`
RevocationStatements string `json:"revocation_statements" mapstructure:"revocation_statements" structs:"revocation_statements"`
RollbackStatements string `json:"rollback_statements" mapstructure:"rollback_statements" structs:"rollback_statements"`
RenewStatements string `json:"renew_statements" mapstructure:"renew_statements" structs:"renew_statements"`
}
// Query templates a query for us.
func queryHelper(tpl string, data map[string]string) string {
for k, v := range data {

View file

@ -7,7 +7,7 @@ import (
"github.com/hashicorp/vault/helper/strutil"
)
const defaultRevocationStmts = `
const defaultMysqlRevocationStmts = `
REVOKE ALL PRIVILEGES, GRANT OPTION FROM '{{name}}'@'%';
DROP USER '{{name}}'@'%'
`
@ -30,7 +30,7 @@ func (m *MySQL) getConnection() (*sql.DB, error) {
return db.(*sql.DB), nil
}
func (m *MySQL) CreateUser(createStmts, rollbackStmts, username, password, expiration string) error {
func (m *MySQL) CreateUser(statements Statements, username, password, expiration string) error {
// Grab the lock
m.Lock()
defer m.Unlock()
@ -49,7 +49,7 @@ func (m *MySQL) CreateUser(createStmts, rollbackStmts, username, password, expir
defer tx.Rollback()
// Execute each query
for _, query := range strutil.ParseArbitraryStringSlice(createStmts, ";") {
for _, query := range strutil.ParseArbitraryStringSlice(statements.CreationStatements, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
@ -77,11 +77,11 @@ func (m *MySQL) CreateUser(createStmts, rollbackStmts, username, password, expir
}
// NOOP
func (m *MySQL) RenewUser(username, expiration string) error {
func (m *MySQL) RenewUser(statements Statements, username, expiration string) error {
return nil
}
func (m *MySQL) RevokeUser(username, revocationStmts string) error {
func (m *MySQL) RevokeUser(statements Statements, username string) error {
// Grab the read lock
m.Lock()
defer m.Unlock()
@ -92,9 +92,10 @@ func (m *MySQL) RevokeUser(username, revocationStmts string) error {
return err
}
revocationStmts := statements.RevocationStatements
// Use a default SQL statement for revocation if one cannot be fetched from the role
if revocationStmts == "" {
revocationStmts = defaultRevocationStmts
revocationStmts = defaultMysqlRevocationStmts
}
// Start a transaction

View file

@ -27,7 +27,7 @@ func (p *PostgreSQL) getConnection() (*sql.DB, error) {
return db.(*sql.DB), nil
}
func (p *PostgreSQL) CreateUser(createStmts, rollbackStmts, username, password, expiration string) error {
func (p *PostgreSQL) CreateUser(statements Statements, username, password, expiration string) error {
// Grab the lock
p.Lock()
defer p.Unlock()
@ -51,7 +51,7 @@ func (p *PostgreSQL) CreateUser(createStmts, rollbackStmts, username, password,
// Return the secret
// Execute each query
for _, query := range strutil.ParseArbitraryStringSlice(createStmts, ";") {
for _, query := range strutil.ParseArbitraryStringSlice(statements.CreationStatements, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
@ -83,7 +83,7 @@ func (p *PostgreSQL) CreateUser(createStmts, rollbackStmts, username, password,
return nil
}
func (p *PostgreSQL) RenewUser(username, expiration string) error {
func (p *PostgreSQL) RenewUser(statements Statements, username, expiration string) error {
// Grab the lock
p.Lock()
defer p.Unlock()
@ -110,16 +110,16 @@ func (p *PostgreSQL) RenewUser(username, expiration string) error {
return nil
}
func (p *PostgreSQL) RevokeUser(username, revocationStmts string) error {
func (p *PostgreSQL) RevokeUser(statements Statements, username string) error {
// Grab the lock
p.Lock()
defer p.Unlock()
if revocationStmts == "" {
if statements.RevocationStatements == "" {
return p.defaultRevokeUser(username)
}
return p.customRevokeUser(username, revocationStmts)
return p.customRevokeUser(username, statements.RevocationStatements)
}
func (p *PostgreSQL) customRevokeUser(username, revocationStmts string) error {

View file

@ -67,7 +67,7 @@ func (b *databaseBackend) pathRoleCreateRead(req *logical.Request, data *framewo
expiration := db.GenerateExpiration(role.DefaultTTL)
err = db.CreateUser(role.CreationStatements, role.RollbackStatements, username, password, expiration)
err = db.CreateUser(role.Statements, username, password, expiration)
if err != nil {
return nil, err
}

View file

@ -4,6 +4,7 @@ import (
"fmt"
"time"
"github.com/hashicorp/vault/builtin/logical/database/dbs"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
@ -42,12 +43,18 @@ func pathRoles(b *databaseBackend) *framework.Path {
"revocation_statements": {
Type: framework.TypeString,
Description: `SQL statements to be executed to revoke a user. Must be a semicolon-separated
Description: `Statements to be executed to revoke a user. Must be a semicolon-separated
string, a base64-encoded semicolon-separated string, a serialized JSON string
array, or a base64-encoded serialized JSON string array. The '{{name}}' value
will be substituted.`,
},
"renew_statements": {
Type: framework.TypeString,
Description: `Statements to be executed to renew a user. Must be a semicolon-separated
string, a base64-encoded semicolon-separated string, a serialized JSON string
array, or a base64-encoded serialized JSON string array. The '{{name}}' value
will be substituted.`,
},
"rollback_statements": {
Type: framework.TypeString,
Description: `SQL statements to be executed to revoke a user. Must be a semicolon-separated
@ -98,9 +105,10 @@ func (b *databaseBackend) pathRoleRead(req *logical.Request, data *framework.Fie
return &logical.Response{
Data: map[string]interface{}{
"creation_statments": role.CreationStatements,
"revocation_statements": role.RevocationStatements,
"rollback_statements": role.RollbackStatements,
"creation_statments": role.Statements.CreationStatements,
"revocation_statements": role.Statements.RevocationStatements,
"rollback_statements": role.Statements.RollbackStatements,
"renew_statements": role.Statements.RenewStatements,
"default_ttl": role.DefaultTTL.String(),
"max_ttl": role.MaxTTL.String(),
},
@ -119,9 +127,14 @@ func (b *databaseBackend) pathRoleList(req *logical.Request, d *framework.FieldD
func (b *databaseBackend) pathRoleCreate(req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
name := data.Get("name").(string)
dbName := data.Get("db_name").(string)
// Get statements
creationStmts := data.Get("creation_statements").(string)
revocationStmts := data.Get("revocation_statements").(string)
rollbackStmts := data.Get("rollback_statements").(string)
renewStmts := data.Get("renew_statements").(string)
// Get TTLs
defaultTTLRaw := data.Get("default_ttl").(string)
maxTTLRaw := data.Get("max_ttl").(string)
@ -136,14 +149,19 @@ func (b *databaseBackend) pathRoleCreate(req *logical.Request, data *framework.F
"Invalid max_ttl: %s", err)), nil
}
statements := dbs.Statements{
CreationStatements: creationStmts,
RevocationStatements: revocationStmts,
RollbackStatements: rollbackStmts,
RenewStatements: rollbackStmts,
}
// TODO: Think about preparing the statments to test.
// Store it
entry, err := logical.StorageEntryJSON("role/"+name, &roleEntry{
DBName: dbName,
CreationStatements: creationStmts,
RevocationStatements: revocationStmts,
RollbackStatements: rollbackStmts,
Statements: statements,
DefaultTTL: defaultTTL,
MaxTTL: maxTTL,
})
@ -159,9 +177,7 @@ func (b *databaseBackend) pathRoleCreate(req *logical.Request, data *framework.F
type roleEntry struct {
DBName string `json:"db_name" mapstructure:"db_name" structs:"db_name"`
CreationStatements string `json:"creation_statment" mapstructure:"creation_statement" structs:"creation_statment"`
RevocationStatements string `json:"revocation_statement" mapstructure:"revocation_statement" structs:"revocation_statement"`
RollbackStatements string `json:"revocation_statement" mapstructure:"revocation_statement" structs:"revocation_statement"`
Statements dbs.Statements `json:"statments" mapstructure:"statements" structs:"statments"`
DefaultTTL time.Duration `json:"default_ttl" mapstructure:"default_ttl" structs:"default_ttl"`
MaxTTL time.Duration `json:"max_ttl" mapstructure:"max_ttl" structs:"max_ttl"`
}