Pass statements object
This commit is contained in:
parent
843d584254
commit
3976a2a0a6
|
@ -34,7 +34,7 @@ func (c *Cassandra) getConnection() (*gocql.Session, error) {
|
|||
return session.(*gocql.Session), nil
|
||||
}
|
||||
|
||||
func (c *Cassandra) CreateUser(createStmts, rollbackStmts, username, password, expiration string) error {
|
||||
func (c *Cassandra) CreateUser(statements Statements, username, password, expiration string) error {
|
||||
// Grab the lock
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
|
@ -46,7 +46,7 @@ func (c *Cassandra) CreateUser(createStmts, rollbackStmts, username, password, e
|
|||
}
|
||||
|
||||
// Execute each query
|
||||
for _, query := range strutil.ParseArbitraryStringSlice(createStmts, ";") {
|
||||
for _, query := range strutil.ParseArbitraryStringSlice(statements.CreationStatements, ";") {
|
||||
query = strings.TrimSpace(query)
|
||||
if len(query) == 0 {
|
||||
continue
|
||||
|
@ -57,7 +57,7 @@ func (c *Cassandra) CreateUser(createStmts, rollbackStmts, username, password, e
|
|||
"password": password,
|
||||
})).Exec()
|
||||
if err != nil {
|
||||
for _, query := range strutil.ParseArbitraryStringSlice(rollbackStmts, ";") {
|
||||
for _, query := range strutil.ParseArbitraryStringSlice(statements.RollbackStatements, ";") {
|
||||
query = strings.TrimSpace(query)
|
||||
if len(query) == 0 {
|
||||
continue
|
||||
|
@ -75,12 +75,12 @@ func (c *Cassandra) CreateUser(createStmts, rollbackStmts, username, password, e
|
|||
return nil
|
||||
}
|
||||
|
||||
func (c *Cassandra) RenewUser(username, expiration string) error {
|
||||
func (c *Cassandra) RenewUser(statements Statements, username, expiration string) error {
|
||||
// NOOP
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Cassandra) RevokeUser(username, revocationStmts string) error {
|
||||
func (c *Cassandra) RevokeUser(statements Statements, username string) error {
|
||||
// Grab the lock
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
|
|
|
@ -78,9 +78,9 @@ func Factory(conf *DatabaseConfig) (DatabaseType, error) {
|
|||
|
||||
type DatabaseType interface {
|
||||
Type() string
|
||||
CreateUser(createStmt, rollbackStmt, username, password, expiration string) error
|
||||
RenewUser(username, expiration string) error
|
||||
RevokeUser(username, revocationStmt string) error
|
||||
CreateUser(statements Statements, username, password, expiration string) error
|
||||
RenewUser(statements Statements, username, expiration string) error
|
||||
RevokeUser(statements Statements, username string) error
|
||||
|
||||
ConnectionProducer
|
||||
CredentialsProducer
|
||||
|
@ -94,6 +94,13 @@ type DatabaseConfig struct {
|
|||
MaxConnectionLifetime time.Duration `json:"max_connection_lifetime" structs:"max_connection_lifetime" mapstructure:"max_connection_lifetime"`
|
||||
}
|
||||
|
||||
type Statements struct {
|
||||
CreationStatements string `json:"creation_statments" mapstructure:"creation_statements" structs:"creation_statments"`
|
||||
RevocationStatements string `json:"revocation_statements" mapstructure:"revocation_statements" structs:"revocation_statements"`
|
||||
RollbackStatements string `json:"rollback_statements" mapstructure:"rollback_statements" structs:"rollback_statements"`
|
||||
RenewStatements string `json:"renew_statements" mapstructure:"renew_statements" structs:"renew_statements"`
|
||||
}
|
||||
|
||||
// Query templates a query for us.
|
||||
func queryHelper(tpl string, data map[string]string) string {
|
||||
for k, v := range data {
|
||||
|
|
|
@ -7,7 +7,7 @@ import (
|
|||
"github.com/hashicorp/vault/helper/strutil"
|
||||
)
|
||||
|
||||
const defaultRevocationStmts = `
|
||||
const defaultMysqlRevocationStmts = `
|
||||
REVOKE ALL PRIVILEGES, GRANT OPTION FROM '{{name}}'@'%';
|
||||
DROP USER '{{name}}'@'%'
|
||||
`
|
||||
|
@ -30,7 +30,7 @@ func (m *MySQL) getConnection() (*sql.DB, error) {
|
|||
return db.(*sql.DB), nil
|
||||
}
|
||||
|
||||
func (m *MySQL) CreateUser(createStmts, rollbackStmts, username, password, expiration string) error {
|
||||
func (m *MySQL) CreateUser(statements Statements, username, password, expiration string) error {
|
||||
// Grab the lock
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
|
@ -49,7 +49,7 @@ func (m *MySQL) CreateUser(createStmts, rollbackStmts, username, password, expir
|
|||
defer tx.Rollback()
|
||||
|
||||
// Execute each query
|
||||
for _, query := range strutil.ParseArbitraryStringSlice(createStmts, ";") {
|
||||
for _, query := range strutil.ParseArbitraryStringSlice(statements.CreationStatements, ";") {
|
||||
query = strings.TrimSpace(query)
|
||||
if len(query) == 0 {
|
||||
continue
|
||||
|
@ -77,11 +77,11 @@ func (m *MySQL) CreateUser(createStmts, rollbackStmts, username, password, expir
|
|||
}
|
||||
|
||||
// NOOP
|
||||
func (m *MySQL) RenewUser(username, expiration string) error {
|
||||
func (m *MySQL) RenewUser(statements Statements, username, expiration string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MySQL) RevokeUser(username, revocationStmts string) error {
|
||||
func (m *MySQL) RevokeUser(statements Statements, username string) error {
|
||||
// Grab the read lock
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
|
@ -92,9 +92,10 @@ func (m *MySQL) RevokeUser(username, revocationStmts string) error {
|
|||
return err
|
||||
}
|
||||
|
||||
revocationStmts := statements.RevocationStatements
|
||||
// Use a default SQL statement for revocation if one cannot be fetched from the role
|
||||
if revocationStmts == "" {
|
||||
revocationStmts = defaultRevocationStmts
|
||||
revocationStmts = defaultMysqlRevocationStmts
|
||||
}
|
||||
|
||||
// Start a transaction
|
||||
|
|
|
@ -27,7 +27,7 @@ func (p *PostgreSQL) getConnection() (*sql.DB, error) {
|
|||
return db.(*sql.DB), nil
|
||||
}
|
||||
|
||||
func (p *PostgreSQL) CreateUser(createStmts, rollbackStmts, username, password, expiration string) error {
|
||||
func (p *PostgreSQL) CreateUser(statements Statements, username, password, expiration string) error {
|
||||
// Grab the lock
|
||||
p.Lock()
|
||||
defer p.Unlock()
|
||||
|
@ -51,7 +51,7 @@ func (p *PostgreSQL) CreateUser(createStmts, rollbackStmts, username, password,
|
|||
// Return the secret
|
||||
|
||||
// Execute each query
|
||||
for _, query := range strutil.ParseArbitraryStringSlice(createStmts, ";") {
|
||||
for _, query := range strutil.ParseArbitraryStringSlice(statements.CreationStatements, ";") {
|
||||
query = strings.TrimSpace(query)
|
||||
if len(query) == 0 {
|
||||
continue
|
||||
|
@ -83,7 +83,7 @@ func (p *PostgreSQL) CreateUser(createStmts, rollbackStmts, username, password,
|
|||
return nil
|
||||
}
|
||||
|
||||
func (p *PostgreSQL) RenewUser(username, expiration string) error {
|
||||
func (p *PostgreSQL) RenewUser(statements Statements, username, expiration string) error {
|
||||
// Grab the lock
|
||||
p.Lock()
|
||||
defer p.Unlock()
|
||||
|
@ -110,16 +110,16 @@ func (p *PostgreSQL) RenewUser(username, expiration string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (p *PostgreSQL) RevokeUser(username, revocationStmts string) error {
|
||||
func (p *PostgreSQL) RevokeUser(statements Statements, username string) error {
|
||||
// Grab the lock
|
||||
p.Lock()
|
||||
defer p.Unlock()
|
||||
|
||||
if revocationStmts == "" {
|
||||
if statements.RevocationStatements == "" {
|
||||
return p.defaultRevokeUser(username)
|
||||
}
|
||||
|
||||
return p.customRevokeUser(username, revocationStmts)
|
||||
return p.customRevokeUser(username, statements.RevocationStatements)
|
||||
}
|
||||
|
||||
func (p *PostgreSQL) customRevokeUser(username, revocationStmts string) error {
|
||||
|
|
|
@ -67,7 +67,7 @@ func (b *databaseBackend) pathRoleCreateRead(req *logical.Request, data *framewo
|
|||
|
||||
expiration := db.GenerateExpiration(role.DefaultTTL)
|
||||
|
||||
err = db.CreateUser(role.CreationStatements, role.RollbackStatements, username, password, expiration)
|
||||
err = db.CreateUser(role.Statements, username, password, expiration)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/builtin/logical/database/dbs"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
)
|
||||
|
@ -42,12 +43,18 @@ func pathRoles(b *databaseBackend) *framework.Path {
|
|||
|
||||
"revocation_statements": {
|
||||
Type: framework.TypeString,
|
||||
Description: `SQL statements to be executed to revoke a user. Must be a semicolon-separated
|
||||
Description: `Statements to be executed to revoke a user. Must be a semicolon-separated
|
||||
string, a base64-encoded semicolon-separated string, a serialized JSON string
|
||||
array, or a base64-encoded serialized JSON string array. The '{{name}}' value
|
||||
will be substituted.`,
|
||||
},
|
||||
"renew_statements": {
|
||||
Type: framework.TypeString,
|
||||
Description: `Statements to be executed to renew a user. Must be a semicolon-separated
|
||||
string, a base64-encoded semicolon-separated string, a serialized JSON string
|
||||
array, or a base64-encoded serialized JSON string array. The '{{name}}' value
|
||||
will be substituted.`,
|
||||
},
|
||||
|
||||
"rollback_statements": {
|
||||
Type: framework.TypeString,
|
||||
Description: `SQL statements to be executed to revoke a user. Must be a semicolon-separated
|
||||
|
@ -98,9 +105,10 @@ func (b *databaseBackend) pathRoleRead(req *logical.Request, data *framework.Fie
|
|||
|
||||
return &logical.Response{
|
||||
Data: map[string]interface{}{
|
||||
"creation_statments": role.CreationStatements,
|
||||
"revocation_statements": role.RevocationStatements,
|
||||
"rollback_statements": role.RollbackStatements,
|
||||
"creation_statments": role.Statements.CreationStatements,
|
||||
"revocation_statements": role.Statements.RevocationStatements,
|
||||
"rollback_statements": role.Statements.RollbackStatements,
|
||||
"renew_statements": role.Statements.RenewStatements,
|
||||
"default_ttl": role.DefaultTTL.String(),
|
||||
"max_ttl": role.MaxTTL.String(),
|
||||
},
|
||||
|
@ -119,9 +127,14 @@ func (b *databaseBackend) pathRoleList(req *logical.Request, d *framework.FieldD
|
|||
func (b *databaseBackend) pathRoleCreate(req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
name := data.Get("name").(string)
|
||||
dbName := data.Get("db_name").(string)
|
||||
|
||||
// Get statements
|
||||
creationStmts := data.Get("creation_statements").(string)
|
||||
revocationStmts := data.Get("revocation_statements").(string)
|
||||
rollbackStmts := data.Get("rollback_statements").(string)
|
||||
renewStmts := data.Get("renew_statements").(string)
|
||||
|
||||
// Get TTLs
|
||||
defaultTTLRaw := data.Get("default_ttl").(string)
|
||||
maxTTLRaw := data.Get("max_ttl").(string)
|
||||
|
||||
|
@ -136,14 +149,19 @@ func (b *databaseBackend) pathRoleCreate(req *logical.Request, data *framework.F
|
|||
"Invalid max_ttl: %s", err)), nil
|
||||
}
|
||||
|
||||
statements := dbs.Statements{
|
||||
CreationStatements: creationStmts,
|
||||
RevocationStatements: revocationStmts,
|
||||
RollbackStatements: rollbackStmts,
|
||||
RenewStatements: rollbackStmts,
|
||||
}
|
||||
|
||||
// TODO: Think about preparing the statments to test.
|
||||
|
||||
// Store it
|
||||
entry, err := logical.StorageEntryJSON("role/"+name, &roleEntry{
|
||||
DBName: dbName,
|
||||
CreationStatements: creationStmts,
|
||||
RevocationStatements: revocationStmts,
|
||||
RollbackStatements: rollbackStmts,
|
||||
Statements: statements,
|
||||
DefaultTTL: defaultTTL,
|
||||
MaxTTL: maxTTL,
|
||||
})
|
||||
|
@ -159,9 +177,7 @@ func (b *databaseBackend) pathRoleCreate(req *logical.Request, data *framework.F
|
|||
|
||||
type roleEntry struct {
|
||||
DBName string `json:"db_name" mapstructure:"db_name" structs:"db_name"`
|
||||
CreationStatements string `json:"creation_statment" mapstructure:"creation_statement" structs:"creation_statment"`
|
||||
RevocationStatements string `json:"revocation_statement" mapstructure:"revocation_statement" structs:"revocation_statement"`
|
||||
RollbackStatements string `json:"revocation_statement" mapstructure:"revocation_statement" structs:"revocation_statement"`
|
||||
Statements dbs.Statements `json:"statments" mapstructure:"statements" structs:"statments"`
|
||||
DefaultTTL time.Duration `json:"default_ttl" mapstructure:"default_ttl" structs:"default_ttl"`
|
||||
MaxTTL time.Duration `json:"max_ttl" mapstructure:"max_ttl" structs:"max_ttl"`
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue