From fda45f531d5330acdcba68abd1f4bde9781d2678 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Thu, 9 Mar 2017 21:31:29 -0800 Subject: [PATCH] Add special path to enforce root on plugin configuration --- builtin/logical/database/backend.go | 9 +- builtin/logical/database/dbs/db.go | 37 +-- .../database/path_config_connection.go | 213 ++++++++++-------- 3 files changed, 146 insertions(+), 113 deletions(-) diff --git a/builtin/logical/database/backend.go b/builtin/logical/database/backend.go index fe853d3fb..e06e7b381 100644 --- a/builtin/logical/database/backend.go +++ b/builtin/logical/database/backend.go @@ -20,8 +20,15 @@ func Backend(conf *logical.BackendConfig) *databaseBackend { b.Backend = &framework.Backend{ Help: strings.TrimSpace(backendHelp), + PathsSpecial: &logical.Paths{ + Root: []string{ + "dbs/plugin/*", + }, + }, + Paths: []*framework.Path{ - pathConfigConnection(&b), + pathConfigureConnection(&b), + pathConfigurePluginConnection(&b), pathListRoles(&b), pathRoles(&b), pathRoleCreate(&b), diff --git a/builtin/logical/database/dbs/db.go b/builtin/logical/database/dbs/db.go index 063cc89cf..bf78d29e6 100644 --- a/builtin/logical/database/dbs/db.go +++ b/builtin/logical/database/dbs/db.go @@ -20,7 +20,9 @@ var ( ErrUnsupportedDatabaseType = errors.New("Unsupported database type") ) -func Factory(conf *DatabaseConfig) (DatabaseType, error) { +type Factory func(*DatabaseConfig) (DatabaseType, error) + +func BuiltinFactory(conf *DatabaseConfig) (DatabaseType, error) { switch conf.DatabaseType { case postgreSQLTypeName: var connProducer *sqlConnectionProducer @@ -72,23 +74,24 @@ func Factory(conf *DatabaseConfig) (DatabaseType, error) { ConnectionProducer: connProducer, CredentialsProducer: credsProducer, }, nil - - case pluginTypeName: - if conf.PluginCommand == "" { - return nil, errors.New("ERROR") - } - - db, err := newPluginClient(conf.PluginCommand) - if err != nil { - return nil, err - } - - return db, nil } return nil, ErrUnsupportedDatabaseType } +func PluginFactory(conf *DatabaseConfig) (DatabaseType, error) { + if conf.PluginCommand == "" { + return nil, errors.New("ERROR") + } + + db, err := newPluginClient(conf.PluginCommand) + if err != nil { + return nil, err + } + + return db, nil +} + type DatabaseType interface { Type() string CreateUser(statements Statements, username, password, expiration string) error @@ -108,6 +111,14 @@ type DatabaseConfig struct { PluginCommand string `json:"plugin_command" structs:"plugin_command" mapstructure:"plugin_command"` } +func (dc *DatabaseConfig) GetFactory() Factory { + if dc.DatabaseType == pluginTypeName { + return PluginFactory + } + + return BuiltinFactory +} + // Statments set in role creation and passed into the database type's functions. // TODO: Add a way of setting defaults here. type Statements struct { diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index 4e1da240c..4780dc492 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -59,7 +59,10 @@ func (b *databaseBackend) pathConnectionReset(req *logical.Request, data *framew } db.Close() - db, err = dbs.Factory(&config) + + factory := config.GetFactory() + + db, err = factory(&config) if err != nil { return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil } @@ -69,9 +72,17 @@ func (b *databaseBackend) pathConnectionReset(req *logical.Request, data *framew return nil, nil } -func pathConfigConnection(b *databaseBackend) *framework.Path { +func pathConfigureConnection(b *databaseBackend) *framework.Path { + return buildConfigConnectionPath("dbs/%s", b.connectionWriteHandler(dbs.BuiltinFactory), b.connectionReadHandler()) +} + +func pathConfigurePluginConnection(b *databaseBackend) *framework.Path { + return buildConfigConnectionPath("dbs/plugin/%s", b.connectionWriteHandler(dbs.PluginFactory), b.connectionReadHandler()) +} + +func buildConfigConnectionPath(path string, updateOp, readOp framework.OperationFunc) *framework.Path { return &framework.Path{ - Pattern: fmt.Sprintf("dbs/%s", framework.GenericNameRegex("name")), + Pattern: fmt.Sprintf(path, framework.GenericNameRegex("name")), Fields: map[string]*framework.FieldSchema{ "name": &framework.FieldSchema{ Type: framework.TypeString, @@ -120,8 +131,8 @@ reduced to the same size.`, }, Callbacks: map[logical.Operation]framework.OperationFunc{ - logical.UpdateOperation: b.pathConnectionWrite, - logical.ReadOperation: b.pathConnectionRead, + logical.UpdateOperation: updateOp, + logical.ReadOperation: readOp, }, HelpSynopsis: pathConfigConnectionHelpSyn, @@ -130,115 +141,119 @@ reduced to the same size.`, } // pathConnectionRead reads out the connection configuration -func (b *databaseBackend) pathConnectionRead(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { - name := data.Get("name").(string) +func (b *databaseBackend) connectionReadHandler() framework.OperationFunc { + return func(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + name := data.Get("name").(string) - entry, err := req.Storage.Get(fmt.Sprintf("dbs/%s", name)) - if err != nil { - return nil, fmt.Errorf("failed to read connection configuration") - } - if entry == nil { - return nil, nil - } + entry, err := req.Storage.Get(fmt.Sprintf("dbs/%s", name)) + if err != nil { + return nil, fmt.Errorf("failed to read connection configuration") + } + if entry == nil { + return nil, nil + } - var config dbs.DatabaseConfig - if err := entry.DecodeJSON(&config); err != nil { - return nil, err + var config dbs.DatabaseConfig + if err := entry.DecodeJSON(&config); err != nil { + return nil, err + } + return &logical.Response{ + Data: structs.New(config).Map(), + }, nil } - return &logical.Response{ - Data: structs.New(config).Map(), - }, nil } -func (b *databaseBackend) pathConnectionWrite(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { - connType := data.Get("connection_type").(string) - if connType == "" { - return logical.ErrorResponse("connection_type not set"), nil - } - - maxOpenConns := data.Get("max_open_connections").(int) - if maxOpenConns == 0 { - maxOpenConns = 2 - } - - maxIdleConns := data.Get("max_idle_connections").(int) - if maxIdleConns == 0 { - maxIdleConns = maxOpenConns - } - if maxIdleConns > maxOpenConns { - maxIdleConns = maxOpenConns - } - - maxConnLifetimeRaw := data.Get("max_connection_lifetime").(string) - maxConnLifetime, err := time.ParseDuration(maxConnLifetimeRaw) - if err != nil { - return logical.ErrorResponse(fmt.Sprintf( - "Invalid max_connection_lifetime: %s", err)), nil - } - - config := &dbs.DatabaseConfig{ - DatabaseType: connType, - ConnectionDetails: data.Raw, - MaxOpenConnections: maxOpenConns, - MaxIdleConnections: maxIdleConns, - MaxConnectionLifetime: maxConnLifetime, - PluginCommand: data.Get("plugin_command").(string), - } - - name := data.Get("name").(string) - - // Grab the mutex lock - b.Lock() - defer b.Unlock() - - var db dbs.DatabaseType - if _, ok := b.connections[name]; ok { - - // Don't allow the connection type to change - if b.connections[name].Type() != connType { - return logical.ErrorResponse("Can not change type of existing connection."), nil - } - } else { - db, err = dbs.Factory(config) - if err != nil { - return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil +func (b *databaseBackend) connectionWriteHandler(factory dbs.Factory) framework.OperationFunc { + return func(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + connType := data.Get("connection_type").(string) + if connType == "" { + return logical.ErrorResponse("connection_type not set"), nil } - b.connections[name] = db - } + maxOpenConns := data.Get("max_open_connections").(int) + if maxOpenConns == 0 { + maxOpenConns = 2 + } - /* TODO: - // Don't check the connection_url if verification is disabled - verifyConnection := data.Get("verify_connection").(bool) - if verifyConnection { - // Verify the string - db, err := sql.Open("postgres", connURL) + maxIdleConns := data.Get("max_idle_connections").(int) + if maxIdleConns == 0 { + maxIdleConns = maxOpenConns + } + if maxIdleConns > maxOpenConns { + maxIdleConns = maxOpenConns + } + + maxConnLifetimeRaw := data.Get("max_connection_lifetime").(string) + maxConnLifetime, err := time.ParseDuration(maxConnLifetimeRaw) if err != nil { return logical.ErrorResponse(fmt.Sprintf( - "Error validating connection info: %s", err)), nil + "Invalid max_connection_lifetime: %s", err)), nil } - defer db.Close() - if err := db.Ping(); err != nil { - return logical.ErrorResponse(fmt.Sprintf( - "Error validating connection info: %s", err)), nil + + config := &dbs.DatabaseConfig{ + DatabaseType: connType, + ConnectionDetails: data.Raw, + MaxOpenConnections: maxOpenConns, + MaxIdleConnections: maxIdleConns, + MaxConnectionLifetime: maxConnLifetime, + PluginCommand: data.Get("plugin_command").(string), } - } - */ - // Store it - entry, err := logical.StorageEntryJSON(fmt.Sprintf("dbs/%s", name), config) - if err != nil { - return nil, err - } - if err := req.Storage.Put(entry); err != nil { - return nil, err - } + name := data.Get("name").(string) - // Reset the DB connection - resp := &logical.Response{} - resp.AddWarning("Read access to this endpoint should be controlled via ACLs as it will return the connection string or URL as it is, including passwords, if any.") + // Grab the mutex lock + b.Lock() + defer b.Unlock() - return resp, nil + var db dbs.DatabaseType + if _, ok := b.connections[name]; ok { + + // Don't allow the connection type to change + if b.connections[name].Type() != connType { + return logical.ErrorResponse("Can not change type of existing connection."), nil + } + } else { + db, err = factory(config) + if err != nil { + return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil + } + + b.connections[name] = db + } + + /* TODO: + // Don't check the connection_url if verification is disabled + verifyConnection := data.Get("verify_connection").(bool) + if verifyConnection { + // Verify the string + db, err := sql.Open("postgres", connURL) + if err != nil { + return logical.ErrorResponse(fmt.Sprintf( + "Error validating connection info: %s", err)), nil + } + defer db.Close() + if err := db.Ping(); err != nil { + return logical.ErrorResponse(fmt.Sprintf( + "Error validating connection info: %s", err)), nil + } + } + */ + + // Store it + entry, err := logical.StorageEntryJSON(fmt.Sprintf("dbs/%s", name), config) + if err != nil { + return nil, err + } + if err := req.Storage.Put(entry); err != nil { + return nil, err + } + + // Reset the DB connection + resp := &logical.Response{} + resp.AddWarning("Read access to this endpoint should be controlled via ACLs as it will return the connection string or URL as it is, including passwords, if any.") + + return resp, nil + } } const pathConfigConnectionHelpSyn = `