Add special path to enforce root on plugin configuration

This commit is contained in:
Brian Kassouf 2017-03-09 21:31:29 -08:00
parent 748c70cfb4
commit fda45f531d
3 changed files with 146 additions and 113 deletions

View File

@ -20,8 +20,15 @@ func Backend(conf *logical.BackendConfig) *databaseBackend {
b.Backend = &framework.Backend{
Help: strings.TrimSpace(backendHelp),
PathsSpecial: &logical.Paths{
Root: []string{
"dbs/plugin/*",
},
},
Paths: []*framework.Path{
pathConfigConnection(&b),
pathConfigureConnection(&b),
pathConfigurePluginConnection(&b),
pathListRoles(&b),
pathRoles(&b),
pathRoleCreate(&b),

View File

@ -20,7 +20,9 @@ var (
ErrUnsupportedDatabaseType = errors.New("Unsupported database type")
)
func Factory(conf *DatabaseConfig) (DatabaseType, error) {
type Factory func(*DatabaseConfig) (DatabaseType, error)
func BuiltinFactory(conf *DatabaseConfig) (DatabaseType, error) {
switch conf.DatabaseType {
case postgreSQLTypeName:
var connProducer *sqlConnectionProducer
@ -72,23 +74,24 @@ func Factory(conf *DatabaseConfig) (DatabaseType, error) {
ConnectionProducer: connProducer,
CredentialsProducer: credsProducer,
}, nil
case pluginTypeName:
if conf.PluginCommand == "" {
return nil, errors.New("ERROR")
}
db, err := newPluginClient(conf.PluginCommand)
if err != nil {
return nil, err
}
return db, nil
}
return nil, ErrUnsupportedDatabaseType
}
func PluginFactory(conf *DatabaseConfig) (DatabaseType, error) {
if conf.PluginCommand == "" {
return nil, errors.New("ERROR")
}
db, err := newPluginClient(conf.PluginCommand)
if err != nil {
return nil, err
}
return db, nil
}
type DatabaseType interface {
Type() string
CreateUser(statements Statements, username, password, expiration string) error
@ -108,6 +111,14 @@ type DatabaseConfig struct {
PluginCommand string `json:"plugin_command" structs:"plugin_command" mapstructure:"plugin_command"`
}
func (dc *DatabaseConfig) GetFactory() Factory {
if dc.DatabaseType == pluginTypeName {
return PluginFactory
}
return BuiltinFactory
}
// Statments set in role creation and passed into the database type's functions.
// TODO: Add a way of setting defaults here.
type Statements struct {

View File

@ -59,7 +59,10 @@ func (b *databaseBackend) pathConnectionReset(req *logical.Request, data *framew
}
db.Close()
db, err = dbs.Factory(&config)
factory := config.GetFactory()
db, err = factory(&config)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil
}
@ -69,9 +72,17 @@ func (b *databaseBackend) pathConnectionReset(req *logical.Request, data *framew
return nil, nil
}
func pathConfigConnection(b *databaseBackend) *framework.Path {
func pathConfigureConnection(b *databaseBackend) *framework.Path {
return buildConfigConnectionPath("dbs/%s", b.connectionWriteHandler(dbs.BuiltinFactory), b.connectionReadHandler())
}
func pathConfigurePluginConnection(b *databaseBackend) *framework.Path {
return buildConfigConnectionPath("dbs/plugin/%s", b.connectionWriteHandler(dbs.PluginFactory), b.connectionReadHandler())
}
func buildConfigConnectionPath(path string, updateOp, readOp framework.OperationFunc) *framework.Path {
return &framework.Path{
Pattern: fmt.Sprintf("dbs/%s", framework.GenericNameRegex("name")),
Pattern: fmt.Sprintf(path, framework.GenericNameRegex("name")),
Fields: map[string]*framework.FieldSchema{
"name": &framework.FieldSchema{
Type: framework.TypeString,
@ -120,8 +131,8 @@ reduced to the same size.`,
},
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.UpdateOperation: b.pathConnectionWrite,
logical.ReadOperation: b.pathConnectionRead,
logical.UpdateOperation: updateOp,
logical.ReadOperation: readOp,
},
HelpSynopsis: pathConfigConnectionHelpSyn,
@ -130,115 +141,119 @@ reduced to the same size.`,
}
// pathConnectionRead reads out the connection configuration
func (b *databaseBackend) pathConnectionRead(req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
name := data.Get("name").(string)
func (b *databaseBackend) connectionReadHandler() framework.OperationFunc {
return func(req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
name := data.Get("name").(string)
entry, err := req.Storage.Get(fmt.Sprintf("dbs/%s", name))
if err != nil {
return nil, fmt.Errorf("failed to read connection configuration")
}
if entry == nil {
return nil, nil
}
entry, err := req.Storage.Get(fmt.Sprintf("dbs/%s", name))
if err != nil {
return nil, fmt.Errorf("failed to read connection configuration")
}
if entry == nil {
return nil, nil
}
var config dbs.DatabaseConfig
if err := entry.DecodeJSON(&config); err != nil {
return nil, err
var config dbs.DatabaseConfig
if err := entry.DecodeJSON(&config); err != nil {
return nil, err
}
return &logical.Response{
Data: structs.New(config).Map(),
}, nil
}
return &logical.Response{
Data: structs.New(config).Map(),
}, nil
}
func (b *databaseBackend) pathConnectionWrite(req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
connType := data.Get("connection_type").(string)
if connType == "" {
return logical.ErrorResponse("connection_type not set"), nil
}
maxOpenConns := data.Get("max_open_connections").(int)
if maxOpenConns == 0 {
maxOpenConns = 2
}
maxIdleConns := data.Get("max_idle_connections").(int)
if maxIdleConns == 0 {
maxIdleConns = maxOpenConns
}
if maxIdleConns > maxOpenConns {
maxIdleConns = maxOpenConns
}
maxConnLifetimeRaw := data.Get("max_connection_lifetime").(string)
maxConnLifetime, err := time.ParseDuration(maxConnLifetimeRaw)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf(
"Invalid max_connection_lifetime: %s", err)), nil
}
config := &dbs.DatabaseConfig{
DatabaseType: connType,
ConnectionDetails: data.Raw,
MaxOpenConnections: maxOpenConns,
MaxIdleConnections: maxIdleConns,
MaxConnectionLifetime: maxConnLifetime,
PluginCommand: data.Get("plugin_command").(string),
}
name := data.Get("name").(string)
// Grab the mutex lock
b.Lock()
defer b.Unlock()
var db dbs.DatabaseType
if _, ok := b.connections[name]; ok {
// Don't allow the connection type to change
if b.connections[name].Type() != connType {
return logical.ErrorResponse("Can not change type of existing connection."), nil
}
} else {
db, err = dbs.Factory(config)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil
func (b *databaseBackend) connectionWriteHandler(factory dbs.Factory) framework.OperationFunc {
return func(req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
connType := data.Get("connection_type").(string)
if connType == "" {
return logical.ErrorResponse("connection_type not set"), nil
}
b.connections[name] = db
}
maxOpenConns := data.Get("max_open_connections").(int)
if maxOpenConns == 0 {
maxOpenConns = 2
}
/* TODO:
// Don't check the connection_url if verification is disabled
verifyConnection := data.Get("verify_connection").(bool)
if verifyConnection {
// Verify the string
db, err := sql.Open("postgres", connURL)
maxIdleConns := data.Get("max_idle_connections").(int)
if maxIdleConns == 0 {
maxIdleConns = maxOpenConns
}
if maxIdleConns > maxOpenConns {
maxIdleConns = maxOpenConns
}
maxConnLifetimeRaw := data.Get("max_connection_lifetime").(string)
maxConnLifetime, err := time.ParseDuration(maxConnLifetimeRaw)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf(
"Error validating connection info: %s", err)), nil
"Invalid max_connection_lifetime: %s", err)), nil
}
defer db.Close()
if err := db.Ping(); err != nil {
return logical.ErrorResponse(fmt.Sprintf(
"Error validating connection info: %s", err)), nil
config := &dbs.DatabaseConfig{
DatabaseType: connType,
ConnectionDetails: data.Raw,
MaxOpenConnections: maxOpenConns,
MaxIdleConnections: maxIdleConns,
MaxConnectionLifetime: maxConnLifetime,
PluginCommand: data.Get("plugin_command").(string),
}
}
*/
// Store it
entry, err := logical.StorageEntryJSON(fmt.Sprintf("dbs/%s", name), config)
if err != nil {
return nil, err
}
if err := req.Storage.Put(entry); err != nil {
return nil, err
}
name := data.Get("name").(string)
// Reset the DB connection
resp := &logical.Response{}
resp.AddWarning("Read access to this endpoint should be controlled via ACLs as it will return the connection string or URL as it is, including passwords, if any.")
// Grab the mutex lock
b.Lock()
defer b.Unlock()
return resp, nil
var db dbs.DatabaseType
if _, ok := b.connections[name]; ok {
// Don't allow the connection type to change
if b.connections[name].Type() != connType {
return logical.ErrorResponse("Can not change type of existing connection."), nil
}
} else {
db, err = factory(config)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil
}
b.connections[name] = db
}
/* TODO:
// Don't check the connection_url if verification is disabled
verifyConnection := data.Get("verify_connection").(bool)
if verifyConnection {
// Verify the string
db, err := sql.Open("postgres", connURL)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf(
"Error validating connection info: %s", err)), nil
}
defer db.Close()
if err := db.Ping(); err != nil {
return logical.ErrorResponse(fmt.Sprintf(
"Error validating connection info: %s", err)), nil
}
}
*/
// Store it
entry, err := logical.StorageEntryJSON(fmt.Sprintf("dbs/%s", name), config)
if err != nil {
return nil, err
}
if err := req.Storage.Put(entry); err != nil {
return nil, err
}
// Reset the DB connection
resp := &logical.Response{}
resp.AddWarning("Read access to this endpoint should be controlled via ACLs as it will return the connection string or URL as it is, including passwords, if any.")
return resp, nil
}
}
const pathConfigConnectionHelpSyn = `