s/Statement/Statements/
This commit is contained in:
parent
46aa7142c1
commit
1f009518cd
|
@ -9,6 +9,11 @@ import (
|
|||
"github.com/hashicorp/vault/helper/strutil"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultCreationCQL = `CREATE USER '{{username}}' WITH PASSWORD '{{password}}' NOSUPERUSER;`
|
||||
defaultRollbackCQL = `DROP USER '{{username}}';`
|
||||
)
|
||||
|
||||
type Cassandra struct {
|
||||
// Session is goroutine safe, however, since we reinitialize
|
||||
// it when connection info changes, we want to make sure we
|
||||
|
@ -31,7 +36,7 @@ func (c *Cassandra) getConnection() (*gocql.Session, error) {
|
|||
return session.(*gocql.Session), nil
|
||||
}
|
||||
|
||||
func (c *Cassandra) CreateUser(createStmt, rollbackStmt, username, password, expiration string) error {
|
||||
func (c *Cassandra) CreateUser(createStmts, rollbackStmts, username, password, expiration string) error {
|
||||
// Get the connection
|
||||
session, err := c.getConnection()
|
||||
if err != nil {
|
||||
|
@ -54,7 +59,7 @@ func (c *Cassandra) CreateUser(createStmt, rollbackStmt, username, password, exp
|
|||
}*/
|
||||
|
||||
// Execute each query
|
||||
for _, query := range strutil.ParseArbitraryStringSlice(createStmt, ";") {
|
||||
for _, query := range strutil.ParseArbitraryStringSlice(createStmts, ";") {
|
||||
query = strings.TrimSpace(query)
|
||||
if len(query) == 0 {
|
||||
continue
|
||||
|
@ -65,7 +70,7 @@ func (c *Cassandra) CreateUser(createStmt, rollbackStmt, username, password, exp
|
|||
"password": password,
|
||||
})).Exec()
|
||||
if err != nil {
|
||||
for _, query := range strutil.ParseArbitraryStringSlice(rollbackStmt, ";") {
|
||||
for _, query := range strutil.ParseArbitraryStringSlice(rollbackStmts, ";") {
|
||||
query = strings.TrimSpace(query)
|
||||
if len(query) == 0 {
|
||||
continue
|
||||
|
@ -88,7 +93,7 @@ func (c *Cassandra) RenewUser(username, expiration string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (c *Cassandra) RevokeUser(username, revocationSQL string) error {
|
||||
func (c *Cassandra) RevokeUser(username, revocationStmts string) error {
|
||||
session, err := c.getConnection()
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
|
@ -8,7 +8,10 @@ import (
|
|||
"sync"
|
||||
"time"
|
||||
|
||||
// Import sql drivers
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
_ "github.com/lib/pq"
|
||||
|
||||
"github.com/gocql/gocql"
|
||||
"github.com/hashicorp/vault/helper/certutil"
|
||||
"github.com/hashicorp/vault/helper/tlsutil"
|
||||
|
|
|
@ -8,7 +8,7 @@ import (
|
|||
"github.com/hashicorp/vault/helper/strutil"
|
||||
)
|
||||
|
||||
const defaultRevocationSQL = `
|
||||
const defaultRevocationStmts = `
|
||||
REVOKE ALL PRIVILEGES, GRANT OPTION FROM '{{name}}'@'%';
|
||||
DROP USER '{{name}}'@'%'
|
||||
`
|
||||
|
@ -34,7 +34,7 @@ func (p *MySQL) getConnection() (*sql.DB, error) {
|
|||
return db.(*sql.DB), nil
|
||||
}
|
||||
|
||||
func (p *MySQL) CreateUser(createStmt, rollbackStmt, username, password, expiration string) error {
|
||||
func (p *MySQL) CreateUser(createStmts, rollbackStmts, username, password, expiration string) error {
|
||||
// Get the connection
|
||||
db, err := p.getConnection()
|
||||
if err != nil {
|
||||
|
@ -54,7 +54,7 @@ func (p *MySQL) CreateUser(createStmt, rollbackStmt, username, password, expirat
|
|||
defer tx.Rollback()
|
||||
|
||||
// Execute each query
|
||||
for _, query := range strutil.ParseArbitraryStringSlice(createStmt, ";") {
|
||||
for _, query := range strutil.ParseArbitraryStringSlice(createStmts, ";") {
|
||||
query = strings.TrimSpace(query)
|
||||
if len(query) == 0 {
|
||||
continue
|
||||
|
@ -86,7 +86,7 @@ func (p *MySQL) RenewUser(username, expiration string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (p *MySQL) RevokeUser(username, revocationStmt string) error {
|
||||
func (p *MySQL) RevokeUser(username, revocationStmts string) error {
|
||||
// Get the connection
|
||||
db, err := p.getConnection()
|
||||
if err != nil {
|
||||
|
@ -99,8 +99,8 @@ func (p *MySQL) RevokeUser(username, revocationStmt string) error {
|
|||
|
||||
// Use a default SQL statement for revocation if one cannot be fetched from the role
|
||||
|
||||
if revocationStmt == "" {
|
||||
revocationStmt = defaultRevocationSQL
|
||||
if revocationStmts == "" {
|
||||
revocationStmts = defaultRevocationStmts
|
||||
}
|
||||
|
||||
// Start a transaction
|
||||
|
@ -110,7 +110,7 @@ func (p *MySQL) RevokeUser(username, revocationStmt string) error {
|
|||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
for _, query := range strutil.ParseArbitraryStringSlice(revocationStmt, ";") {
|
||||
for _, query := range strutil.ParseArbitraryStringSlice(revocationStmts, ";") {
|
||||
query = strings.TrimSpace(query)
|
||||
if len(query) == 0 {
|
||||
continue
|
||||
|
|
|
@ -31,7 +31,7 @@ func (p *PostgreSQL) getConnection() (*sql.DB, error) {
|
|||
return db.(*sql.DB), nil
|
||||
}
|
||||
|
||||
func (p *PostgreSQL) CreateUser(createStmt, rollbackStmt, username, password, expiration string) error {
|
||||
func (p *PostgreSQL) CreateUser(createStmts, rollbackStmts, username, password, expiration string) error {
|
||||
// Get the connection
|
||||
db, err := p.getConnection()
|
||||
if err != nil {
|
||||
|
@ -56,7 +56,7 @@ func (p *PostgreSQL) CreateUser(createStmt, rollbackStmt, username, password, ex
|
|||
// Return the secret
|
||||
|
||||
// Execute each query
|
||||
for _, query := range strutil.ParseArbitraryStringSlice(createStmt, ";") {
|
||||
for _, query := range strutil.ParseArbitraryStringSlice(createStmts, ";") {
|
||||
query = strings.TrimSpace(query)
|
||||
if len(query) == 0 {
|
||||
continue
|
||||
|
@ -115,19 +115,19 @@ func (p *PostgreSQL) RenewUser(username, expiration string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (p *PostgreSQL) RevokeUser(username, revocationStmt string) error {
|
||||
func (p *PostgreSQL) RevokeUser(username, revocationStmts string) error {
|
||||
// Grab the read lock
|
||||
p.RLock()
|
||||
defer p.RUnlock()
|
||||
|
||||
if revocationStmt == "" {
|
||||
if revocationStmts == "" {
|
||||
return p.defaultRevokeUser(username)
|
||||
}
|
||||
|
||||
return p.customRevokeUser(username, revocationStmt)
|
||||
return p.customRevokeUser(username, revocationStmts)
|
||||
}
|
||||
|
||||
func (p *PostgreSQL) customRevokeUser(username, revocationStmt string) error {
|
||||
func (p *PostgreSQL) customRevokeUser(username, revocationStmts string) error {
|
||||
db, err := p.getConnection()
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -141,7 +141,7 @@ func (p *PostgreSQL) customRevokeUser(username, revocationStmt string) error {
|
|||
tx.Rollback()
|
||||
}()
|
||||
|
||||
for _, query := range strutil.ParseArbitraryStringSlice(revocationStmt, ";") {
|
||||
for _, query := range strutil.ParseArbitraryStringSlice(revocationStmts, ";") {
|
||||
query = strings.TrimSpace(query)
|
||||
if len(query) == 0 {
|
||||
continue
|
||||
|
|
|
@ -5,7 +5,6 @@ import (
|
|||
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
_ "github.com/lib/pq"
|
||||
)
|
||||
|
||||
func pathRoleCreate(b *databaseBackend) *framework.Path {
|
||||
|
@ -68,7 +67,7 @@ func (b *databaseBackend) pathRoleCreateRead(req *logical.Request, data *framewo
|
|||
|
||||
expiration := db.GenerateExpiration(role.DefaultTTL)
|
||||
|
||||
err = db.CreateUser(role.CreationStatement, role.RollbackStatement, username, password, expiration)
|
||||
err = db.CreateUser(role.CreationStatements, role.RollbackStatements, username, password, expiration)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -35,12 +35,12 @@ func pathRoles(b *databaseBackend) *framework.Path {
|
|||
Description: "Name of the database this role acts on.",
|
||||
},
|
||||
|
||||
"creation_statement": {
|
||||
"creation_statements": {
|
||||
Type: framework.TypeString,
|
||||
Description: "SQL string to create a user. See help for more info.",
|
||||
},
|
||||
|
||||
"revocation_statement": {
|
||||
"revocation_statements": {
|
||||
Type: framework.TypeString,
|
||||
Description: `SQL statements to be executed to revoke a user. Must be a semicolon-separated
|
||||
string, a base64-encoded semicolon-separated string, a serialized JSON string
|
||||
|
@ -48,7 +48,7 @@ func pathRoles(b *databaseBackend) *framework.Path {
|
|||
will be substituted.`,
|
||||
},
|
||||
|
||||
"rollback_statement": {
|
||||
"rollback_statements": {
|
||||
Type: framework.TypeString,
|
||||
Description: `SQL statements to be executed to revoke a user. Must be a semicolon-separated
|
||||
string, a base64-encoded semicolon-separated string, a serialized JSON string
|
||||
|
@ -98,11 +98,11 @@ func (b *databaseBackend) pathRoleRead(req *logical.Request, data *framework.Fie
|
|||
|
||||
return &logical.Response{
|
||||
Data: map[string]interface{}{
|
||||
"creation_statment": role.CreationStatement,
|
||||
"revocation_statement": role.RevocationStatement,
|
||||
"rollback_statement": role.RollbackStatement,
|
||||
"default_ttl": role.DefaultTTL.String(),
|
||||
"max_ttl": role.MaxTTL.String(),
|
||||
"creation_statments": role.CreationStatements,
|
||||
"revocation_statements": role.RevocationStatements,
|
||||
"rollback_statements": role.RollbackStatements,
|
||||
"default_ttl": role.DefaultTTL.String(),
|
||||
"max_ttl": role.MaxTTL.String(),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
@ -119,9 +119,9 @@ func (b *databaseBackend) pathRoleList(req *logical.Request, d *framework.FieldD
|
|||
func (b *databaseBackend) pathRoleCreate(req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
name := data.Get("name").(string)
|
||||
dbName := data.Get("db_name").(string)
|
||||
creationStmt := data.Get("creation_statement").(string)
|
||||
revocationStmt := data.Get("revocation_statement").(string)
|
||||
rollbackStmt := data.Get("rollback_statement").(string)
|
||||
creationStmts := data.Get("creation_statements").(string)
|
||||
revocationStmts := data.Get("revocation_statements").(string)
|
||||
rollbackStmts := data.Get("rollback_statements").(string)
|
||||
defaultTTLRaw := data.Get("default_ttl").(string)
|
||||
maxTTLRaw := data.Get("max_ttl").(string)
|
||||
|
||||
|
@ -140,12 +140,12 @@ func (b *databaseBackend) pathRoleCreate(req *logical.Request, data *framework.F
|
|||
|
||||
// Store it
|
||||
entry, err := logical.StorageEntryJSON("role/"+name, &roleEntry{
|
||||
DBName: dbName,
|
||||
CreationStatement: creationStmt,
|
||||
RevocationStatement: revocationStmt,
|
||||
RollbackStatement: rollbackStmt,
|
||||
DefaultTTL: defaultTTL,
|
||||
MaxTTL: maxTTL,
|
||||
DBName: dbName,
|
||||
CreationStatements: creationStmts,
|
||||
RevocationStatements: revocationStmts,
|
||||
RollbackStatements: rollbackStmts,
|
||||
DefaultTTL: defaultTTL,
|
||||
MaxTTL: maxTTL,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -158,12 +158,12 @@ func (b *databaseBackend) pathRoleCreate(req *logical.Request, data *framework.F
|
|||
}
|
||||
|
||||
type roleEntry struct {
|
||||
DBName string `json:"db_name" mapstructure:"db_name" structs:"db_name"`
|
||||
CreationStatement string `json:"creation_statment" mapstructure:"creation_statement" structs:"creation_statment"`
|
||||
RevocationStatement string `json:"revocation_statement" mapstructure:"revocation_statement" structs:"revocation_statement"`
|
||||
RollbackStatement string `json:"revocation_statement" mapstructure:"revocation_statement" structs:"revocation_statement"`
|
||||
DefaultTTL time.Duration `json:"default_ttl" mapstructure:"default_ttl" structs:"default_ttl"`
|
||||
MaxTTL time.Duration `json:"max_ttl" mapstructure:"max_ttl" structs:"max_ttl"`
|
||||
DBName string `json:"db_name" mapstructure:"db_name" structs:"db_name"`
|
||||
CreationStatements string `json:"creation_statment" mapstructure:"creation_statement" structs:"creation_statment"`
|
||||
RevocationStatements string `json:"revocation_statement" mapstructure:"revocation_statement" structs:"revocation_statement"`
|
||||
RollbackStatements string `json:"revocation_statement" mapstructure:"revocation_statement" structs:"revocation_statement"`
|
||||
DefaultTTL time.Duration `json:"default_ttl" mapstructure:"default_ttl" structs:"default_ttl"`
|
||||
MaxTTL time.Duration `json:"max_ttl" mapstructure:"max_ttl" structs:"max_ttl"`
|
||||
}
|
||||
|
||||
const pathRoleHelpSyn = `
|
||||
|
|
Loading…
Reference in a new issue