Add plugin features

This commit is contained in:
Brian Kassouf 2017-03-09 17:43:37 -08:00
parent b7128f8370
commit 9099231229
4 changed files with 34 additions and 7 deletions

View File

@ -11,7 +11,7 @@ import (
type CredentialsProducer interface {
GenerateUsername(displayName string) (string, error)
GeneratePassword() (string, error)
GenerateExpiration(ttl time.Duration) string
GenerateExpiration(ttl time.Duration) (string, error)
}
// sqlCredentialsProducer impliments CredentialsProducer and provides a generic credentials producer for most sql database types.
@ -46,10 +46,10 @@ func (scp *sqlCredentialsProducer) GeneratePassword() (string, error) {
return password, nil
}
func (scp *sqlCredentialsProducer) GenerateExpiration(ttl time.Duration) string {
func (scp *sqlCredentialsProducer) GenerateExpiration(ttl time.Duration) (string, error) {
return time.Now().
Add(ttl).
Format("2006-01-02 15:04:05-0700")
Format("2006-01-02 15:04:05-0700"), nil
}
type cassandraCredentialsProducer struct{}
@ -74,6 +74,6 @@ func (ccp *cassandraCredentialsProducer) GeneratePassword() (string, error) {
return password, nil
}
func (ccp *cassandraCredentialsProducer) GenerateExpiration(ttl time.Duration) string {
return ""
func (ccp *cassandraCredentialsProducer) GenerateExpiration(ttl time.Duration) (string, error) {
return "", nil
}

View File

@ -13,6 +13,7 @@ const (
postgreSQLTypeName = "postgres"
mySQLTypeName = "mysql"
cassandraTypeName = "cassandra"
pluginTypeName = "plugin"
)
var (
@ -71,6 +72,18 @@ func Factory(conf *DatabaseConfig) (DatabaseType, error) {
ConnectionProducer: connProducer,
CredentialsProducer: credsProducer,
}, nil
case pluginTypeName:
if conf.PluginCommand == "" {
return nil, errors.New("ERROR")
}
db, err := newPluginClient(conf.PluginCommand)
if err != nil {
return nil, err
}
return db, nil
}
return nil, ErrUnsupportedDatabaseType
@ -82,7 +95,7 @@ type DatabaseType interface {
RenewUser(statements Statements, username, expiration string) error
RevokeUser(statements Statements, username string) error
ConnectionProducer
Close()
CredentialsProducer
}
@ -92,6 +105,7 @@ type DatabaseConfig struct {
MaxOpenConnections int `json:"max_open_connections" structs:"max_open_connections" mapstructure:"max_open_connections"`
MaxIdleConnections int `json:"max_idle_connections" structs:"max_idle_connections" mapstructure:"max_idle_connections"`
MaxConnectionLifetime time.Duration `json:"max_connection_lifetime" structs:"max_connection_lifetime" mapstructure:"max_connection_lifetime"`
PluginCommand string `json:"plugin_command" structs:"plugin_command" mapstructure:"plugin_command"`
}
// Statments set in role creation and passed into the database type's functions.

View File

@ -111,6 +111,12 @@ reduced to the same size.`,
Description: `Maximum amount of time a connection may be reused;
a zero or negative value reuses connections forever.`,
},
"plugin_command": &framework.FieldSchema{
Type: framework.TypeString,
Description: `Maximum amount of time a connection may be reused;
a zero or negative value reuses connections forever.`,
},
},
Callbacks: map[logical.Operation]framework.OperationFunc{
@ -146,6 +152,9 @@ func (b *databaseBackend) pathConnectionRead(req *logical.Request, data *framewo
func (b *databaseBackend) pathConnectionWrite(req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
connType := data.Get("connection_type").(string)
if connType == "" {
return logical.ErrorResponse("connection_type not set"), nil
}
maxOpenConns := data.Get("max_open_connections").(int)
if maxOpenConns == 0 {
@ -173,6 +182,7 @@ func (b *databaseBackend) pathConnectionWrite(req *logical.Request, data *framew
MaxOpenConnections: maxOpenConns,
MaxIdleConnections: maxIdleConns,
MaxConnectionLifetime: maxConnLifetime,
PluginCommand: data.Get("plugin_command").(string),
}
name := data.Get("name").(string)

View File

@ -65,7 +65,10 @@ func (b *databaseBackend) pathRoleCreateRead(req *logical.Request, data *framewo
return nil, err
}
expiration := db.GenerateExpiration(role.DefaultTTL)
expiration, err := db.GenerateExpiration(role.DefaultTTL)
if err != nil {
return nil, err
}
err = db.CreateUser(role.Statements, username, password, expiration)
if err != nil {