From 1f009518cdda280cb41d4e230a592d98fbbbac32 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 4 Jan 2017 10:53:39 -0800 Subject: [PATCH] s/Statement/Statements/ --- builtin/logical/database/dbs/cassandra.go | 13 ++++-- .../database/dbs/connectionproducer.go | 3 ++ builtin/logical/database/dbs/mysql.go | 14 +++--- builtin/logical/database/dbs/postgresql.go | 14 +++--- builtin/logical/database/path_role_create.go | 3 +- builtin/logical/database/path_roles.go | 46 +++++++++---------- 6 files changed, 50 insertions(+), 43 deletions(-) diff --git a/builtin/logical/database/dbs/cassandra.go b/builtin/logical/database/dbs/cassandra.go index a8889032f..7a06e1314 100644 --- a/builtin/logical/database/dbs/cassandra.go +++ b/builtin/logical/database/dbs/cassandra.go @@ -9,6 +9,11 @@ import ( "github.com/hashicorp/vault/helper/strutil" ) +const ( + defaultCreationCQL = `CREATE USER '{{username}}' WITH PASSWORD '{{password}}' NOSUPERUSER;` + defaultRollbackCQL = `DROP USER '{{username}}';` +) + type Cassandra struct { // Session is goroutine safe, however, since we reinitialize // it when connection info changes, we want to make sure we @@ -31,7 +36,7 @@ func (c *Cassandra) getConnection() (*gocql.Session, error) { return session.(*gocql.Session), nil } -func (c *Cassandra) CreateUser(createStmt, rollbackStmt, username, password, expiration string) error { +func (c *Cassandra) CreateUser(createStmts, rollbackStmts, username, password, expiration string) error { // Get the connection session, err := c.getConnection() if err != nil { @@ -54,7 +59,7 @@ func (c *Cassandra) CreateUser(createStmt, rollbackStmt, username, password, exp }*/ // Execute each query - for _, query := range strutil.ParseArbitraryStringSlice(createStmt, ";") { + for _, query := range strutil.ParseArbitraryStringSlice(createStmts, ";") { query = strings.TrimSpace(query) if len(query) == 0 { continue @@ -65,7 +70,7 @@ func (c *Cassandra) CreateUser(createStmt, rollbackStmt, username, password, exp "password": password, })).Exec() if err != nil { - for _, query := range strutil.ParseArbitraryStringSlice(rollbackStmt, ";") { + for _, query := range strutil.ParseArbitraryStringSlice(rollbackStmts, ";") { query = strings.TrimSpace(query) if len(query) == 0 { continue @@ -88,7 +93,7 @@ func (c *Cassandra) RenewUser(username, expiration string) error { return nil } -func (c *Cassandra) RevokeUser(username, revocationSQL string) error { +func (c *Cassandra) RevokeUser(username, revocationStmts string) error { session, err := c.getConnection() if err != nil { return err diff --git a/builtin/logical/database/dbs/connectionproducer.go b/builtin/logical/database/dbs/connectionproducer.go index dc8f6c82c..5c606996d 100644 --- a/builtin/logical/database/dbs/connectionproducer.go +++ b/builtin/logical/database/dbs/connectionproducer.go @@ -8,7 +8,10 @@ import ( "sync" "time" + // Import sql drivers _ "github.com/go-sql-driver/mysql" + _ "github.com/lib/pq" + "github.com/gocql/gocql" "github.com/hashicorp/vault/helper/certutil" "github.com/hashicorp/vault/helper/tlsutil" diff --git a/builtin/logical/database/dbs/mysql.go b/builtin/logical/database/dbs/mysql.go index 314d4c929..0a18683ea 100644 --- a/builtin/logical/database/dbs/mysql.go +++ b/builtin/logical/database/dbs/mysql.go @@ -8,7 +8,7 @@ import ( "github.com/hashicorp/vault/helper/strutil" ) -const defaultRevocationSQL = ` +const defaultRevocationStmts = ` REVOKE ALL PRIVILEGES, GRANT OPTION FROM '{{name}}'@'%'; DROP USER '{{name}}'@'%' ` @@ -34,7 +34,7 @@ func (p *MySQL) getConnection() (*sql.DB, error) { return db.(*sql.DB), nil } -func (p *MySQL) CreateUser(createStmt, rollbackStmt, username, password, expiration string) error { +func (p *MySQL) CreateUser(createStmts, rollbackStmts, username, password, expiration string) error { // Get the connection db, err := p.getConnection() if err != nil { @@ -54,7 +54,7 @@ func (p *MySQL) CreateUser(createStmt, rollbackStmt, username, password, expirat defer tx.Rollback() // Execute each query - for _, query := range strutil.ParseArbitraryStringSlice(createStmt, ";") { + for _, query := range strutil.ParseArbitraryStringSlice(createStmts, ";") { query = strings.TrimSpace(query) if len(query) == 0 { continue @@ -86,7 +86,7 @@ func (p *MySQL) RenewUser(username, expiration string) error { return nil } -func (p *MySQL) RevokeUser(username, revocationStmt string) error { +func (p *MySQL) RevokeUser(username, revocationStmts string) error { // Get the connection db, err := p.getConnection() if err != nil { @@ -99,8 +99,8 @@ func (p *MySQL) RevokeUser(username, revocationStmt string) error { // Use a default SQL statement for revocation if one cannot be fetched from the role - if revocationStmt == "" { - revocationStmt = defaultRevocationSQL + if revocationStmts == "" { + revocationStmts = defaultRevocationStmts } // Start a transaction @@ -110,7 +110,7 @@ func (p *MySQL) RevokeUser(username, revocationStmt string) error { } defer tx.Rollback() - for _, query := range strutil.ParseArbitraryStringSlice(revocationStmt, ";") { + for _, query := range strutil.ParseArbitraryStringSlice(revocationStmts, ";") { query = strings.TrimSpace(query) if len(query) == 0 { continue diff --git a/builtin/logical/database/dbs/postgresql.go b/builtin/logical/database/dbs/postgresql.go index e050e30bf..01fb3cd70 100644 --- a/builtin/logical/database/dbs/postgresql.go +++ b/builtin/logical/database/dbs/postgresql.go @@ -31,7 +31,7 @@ func (p *PostgreSQL) getConnection() (*sql.DB, error) { return db.(*sql.DB), nil } -func (p *PostgreSQL) CreateUser(createStmt, rollbackStmt, username, password, expiration string) error { +func (p *PostgreSQL) CreateUser(createStmts, rollbackStmts, username, password, expiration string) error { // Get the connection db, err := p.getConnection() if err != nil { @@ -56,7 +56,7 @@ func (p *PostgreSQL) CreateUser(createStmt, rollbackStmt, username, password, ex // Return the secret // Execute each query - for _, query := range strutil.ParseArbitraryStringSlice(createStmt, ";") { + for _, query := range strutil.ParseArbitraryStringSlice(createStmts, ";") { query = strings.TrimSpace(query) if len(query) == 0 { continue @@ -115,19 +115,19 @@ func (p *PostgreSQL) RenewUser(username, expiration string) error { return nil } -func (p *PostgreSQL) RevokeUser(username, revocationStmt string) error { +func (p *PostgreSQL) RevokeUser(username, revocationStmts string) error { // Grab the read lock p.RLock() defer p.RUnlock() - if revocationStmt == "" { + if revocationStmts == "" { return p.defaultRevokeUser(username) } - return p.customRevokeUser(username, revocationStmt) + return p.customRevokeUser(username, revocationStmts) } -func (p *PostgreSQL) customRevokeUser(username, revocationStmt string) error { +func (p *PostgreSQL) customRevokeUser(username, revocationStmts string) error { db, err := p.getConnection() if err != nil { return err @@ -141,7 +141,7 @@ func (p *PostgreSQL) customRevokeUser(username, revocationStmt string) error { tx.Rollback() }() - for _, query := range strutil.ParseArbitraryStringSlice(revocationStmt, ";") { + for _, query := range strutil.ParseArbitraryStringSlice(revocationStmts, ";") { query = strings.TrimSpace(query) if len(query) == 0 { continue diff --git a/builtin/logical/database/path_role_create.go b/builtin/logical/database/path_role_create.go index 15ca915ba..b1cce97f3 100644 --- a/builtin/logical/database/path_role_create.go +++ b/builtin/logical/database/path_role_create.go @@ -5,7 +5,6 @@ import ( "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" - _ "github.com/lib/pq" ) func pathRoleCreate(b *databaseBackend) *framework.Path { @@ -68,7 +67,7 @@ func (b *databaseBackend) pathRoleCreateRead(req *logical.Request, data *framewo expiration := db.GenerateExpiration(role.DefaultTTL) - err = db.CreateUser(role.CreationStatement, role.RollbackStatement, username, password, expiration) + err = db.CreateUser(role.CreationStatements, role.RollbackStatements, username, password, expiration) if err != nil { return nil, err } diff --git a/builtin/logical/database/path_roles.go b/builtin/logical/database/path_roles.go index dc8c6805a..994d084f0 100644 --- a/builtin/logical/database/path_roles.go +++ b/builtin/logical/database/path_roles.go @@ -35,12 +35,12 @@ func pathRoles(b *databaseBackend) *framework.Path { Description: "Name of the database this role acts on.", }, - "creation_statement": { + "creation_statements": { Type: framework.TypeString, Description: "SQL string to create a user. See help for more info.", }, - "revocation_statement": { + "revocation_statements": { Type: framework.TypeString, Description: `SQL statements to be executed to revoke a user. Must be a semicolon-separated string, a base64-encoded semicolon-separated string, a serialized JSON string @@ -48,7 +48,7 @@ func pathRoles(b *databaseBackend) *framework.Path { will be substituted.`, }, - "rollback_statement": { + "rollback_statements": { Type: framework.TypeString, Description: `SQL statements to be executed to revoke a user. Must be a semicolon-separated string, a base64-encoded semicolon-separated string, a serialized JSON string @@ -98,11 +98,11 @@ func (b *databaseBackend) pathRoleRead(req *logical.Request, data *framework.Fie return &logical.Response{ Data: map[string]interface{}{ - "creation_statment": role.CreationStatement, - "revocation_statement": role.RevocationStatement, - "rollback_statement": role.RollbackStatement, - "default_ttl": role.DefaultTTL.String(), - "max_ttl": role.MaxTTL.String(), + "creation_statments": role.CreationStatements, + "revocation_statements": role.RevocationStatements, + "rollback_statements": role.RollbackStatements, + "default_ttl": role.DefaultTTL.String(), + "max_ttl": role.MaxTTL.String(), }, }, nil } @@ -119,9 +119,9 @@ func (b *databaseBackend) pathRoleList(req *logical.Request, d *framework.FieldD func (b *databaseBackend) pathRoleCreate(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { name := data.Get("name").(string) dbName := data.Get("db_name").(string) - creationStmt := data.Get("creation_statement").(string) - revocationStmt := data.Get("revocation_statement").(string) - rollbackStmt := data.Get("rollback_statement").(string) + creationStmts := data.Get("creation_statements").(string) + revocationStmts := data.Get("revocation_statements").(string) + rollbackStmts := data.Get("rollback_statements").(string) defaultTTLRaw := data.Get("default_ttl").(string) maxTTLRaw := data.Get("max_ttl").(string) @@ -140,12 +140,12 @@ func (b *databaseBackend) pathRoleCreate(req *logical.Request, data *framework.F // Store it entry, err := logical.StorageEntryJSON("role/"+name, &roleEntry{ - DBName: dbName, - CreationStatement: creationStmt, - RevocationStatement: revocationStmt, - RollbackStatement: rollbackStmt, - DefaultTTL: defaultTTL, - MaxTTL: maxTTL, + DBName: dbName, + CreationStatements: creationStmts, + RevocationStatements: revocationStmts, + RollbackStatements: rollbackStmts, + DefaultTTL: defaultTTL, + MaxTTL: maxTTL, }) if err != nil { return nil, err @@ -158,12 +158,12 @@ func (b *databaseBackend) pathRoleCreate(req *logical.Request, data *framework.F } type roleEntry struct { - DBName string `json:"db_name" mapstructure:"db_name" structs:"db_name"` - CreationStatement string `json:"creation_statment" mapstructure:"creation_statement" structs:"creation_statment"` - RevocationStatement string `json:"revocation_statement" mapstructure:"revocation_statement" structs:"revocation_statement"` - RollbackStatement string `json:"revocation_statement" mapstructure:"revocation_statement" structs:"revocation_statement"` - DefaultTTL time.Duration `json:"default_ttl" mapstructure:"default_ttl" structs:"default_ttl"` - MaxTTL time.Duration `json:"max_ttl" mapstructure:"max_ttl" structs:"max_ttl"` + DBName string `json:"db_name" mapstructure:"db_name" structs:"db_name"` + CreationStatements string `json:"creation_statment" mapstructure:"creation_statement" structs:"creation_statment"` + RevocationStatements string `json:"revocation_statement" mapstructure:"revocation_statement" structs:"revocation_statement"` + RollbackStatements string `json:"revocation_statement" mapstructure:"revocation_statement" structs:"revocation_statement"` + DefaultTTL time.Duration `json:"default_ttl" mapstructure:"default_ttl" structs:"default_ttl"` + MaxTTL time.Duration `json:"max_ttl" mapstructure:"max_ttl" structs:"max_ttl"` } const pathRoleHelpSyn = `