diff --git a/builtin/logical/database/backend.go b/builtin/logical/database/backend.go index 2ce759526..e57fa19c1 100644 --- a/builtin/logical/database/backend.go +++ b/builtin/logical/database/backend.go @@ -14,15 +14,6 @@ import ( const databaseConfigPath = "database/config/" -// DatabaseConfig is used by the Factory function to configure a DatabaseType -// object. -type DatabaseConfig struct { - PluginName string `json:"plugin_name" structs:"plugin_name" mapstructure:"plugin_name"` - // ConnectionDetails stores the database specific connection settings needed - // by each database type. - ConnectionDetails map[string]interface{} `json:"connection_details" structs:"connection_details" mapstructure:"connection_details"` -} - func Factory(conf *logical.BackendConfig) (logical.Backend, error) { return Backend(conf).Setup(conf) } @@ -84,16 +75,8 @@ func (b *databaseBackend) getOrCreateDBObj(s logical.Storage, name string) (dbpl return db, nil } - entry, err := s.Get(fmt.Sprintf("config/%s", name)) + config, err := b.DatabaseConfig(s, name) if err != nil { - return nil, fmt.Errorf("failed to read connection configuration with name: %s", name) - } - if entry == nil { - return nil, fmt.Errorf("failed to find entry for connection with name: %s", name) - } - - var config DatabaseConfig - if err := entry.DecodeJSON(&config); err != nil { return nil, err } @@ -112,6 +95,23 @@ func (b *databaseBackend) getOrCreateDBObj(s logical.Storage, name string) (dbpl return db, nil } +func (b *databaseBackend) DatabaseConfig(s logical.Storage, name string) (*DatabaseConfig, error) { + entry, err := s.Get(fmt.Sprintf("config/%s", name)) + if err != nil { + return nil, fmt.Errorf("failed to read connection configuration with name: %s", name) + } + if entry == nil { + return nil, fmt.Errorf("failed to find entry for connection with name: %s", name) + } + + var config DatabaseConfig + if err := entry.DecodeJSON(&config); err != nil { + return nil, err + } + + return &config, nil +} + func (b *databaseBackend) Role(s logical.Storage, n string) (*roleEntry, error) { entry, err := s.Get("role/" + n) if err != nil { diff --git a/builtin/logical/database/backend_test.go b/builtin/logical/database/backend_test.go index 5b3a0db42..2615577fd 100644 --- a/builtin/logical/database/backend_test.go +++ b/builtin/logical/database/backend_test.go @@ -130,6 +130,7 @@ func TestBackend_config_connection(t *testing.T) { expected := map[string]interface{}{ "plugin_name": "postgresql-database-plugin", "connection_details": configData, + "allowed_roles": []string{}, } configReq.Operation = logical.ReadOperation resp, err = b.HandleRequest(configReq) @@ -306,6 +307,7 @@ func TestBackend_connectionCrud(t *testing.T) { expected := map[string]interface{}{ "plugin_name": "postgresql-database-plugin", "connection_details": data, + "allowed_roles": []string{}, } req.Operation = logical.ReadOperation resp, err = b.HandleRequest(req) @@ -484,6 +486,105 @@ func TestBackend_roleCrud(t *testing.T) { t.Fatal("Expected response to be nil") } } +func TestBackend_allowedRoles(t *testing.T) { + _, ln, sys, _ := getCore(t) + defer ln.Close() + + config := logical.TestBackendConfig() + config.StorageView = &logical.InmemStorage{} + config.System = sys + + b, err := Factory(config) + if err != nil { + t.Fatal(err) + } + defer b.Cleanup() + + cleanup, connURL := preparePostgresTestContainer(t, config.StorageView, b) + defer cleanup() + + // Configure a connection + data := map[string]interface{}{ + "connection_url": connURL, + "plugin_name": "postgresql-database-plugin", + "allowed_roles": "allow, allowed", + } + req := &logical.Request{ + Operation: logical.UpdateOperation, + Path: "config/plugin-test", + Storage: config.StorageView, + Data: data, + } + resp, err := b.HandleRequest(req) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + // Create a denied and an allowed role + data = map[string]interface{}{ + "db_name": "plugin-test", + "creation_statements": testRole, + "default_ttl": "5m", + "max_ttl": "10m", + } + req = &logical.Request{ + Operation: logical.UpdateOperation, + Path: "roles/denied", + Storage: config.StorageView, + Data: data, + } + resp, err = b.HandleRequest(req) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + data = map[string]interface{}{ + "db_name": "plugin-test", + "creation_statements": testRole, + "default_ttl": "5m", + "max_ttl": "10m", + } + req = &logical.Request{ + Operation: logical.UpdateOperation, + Path: "roles/allowed", + Storage: config.StorageView, + Data: data, + } + resp, err = b.HandleRequest(req) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + // Get creds from denied role, should fail + data = map[string]interface{}{} + req = &logical.Request{ + Operation: logical.ReadOperation, + Path: "creds/denied", + Storage: config.StorageView, + Data: data, + } + credsResp, err := b.HandleRequest(req) + if err != logical.ErrPermissionDenied { + t.Fatalf("expected error to be:%s got:%#v\n", logical.ErrPermissionDenied, err) + } + + // Get creds from allowed role, should work. + data = map[string]interface{}{} + req = &logical.Request{ + Operation: logical.ReadOperation, + Path: "creds/allowed", + Storage: config.StorageView, + Data: data, + } + credsResp, err = b.HandleRequest(req) + if err != nil || (credsResp != nil && credsResp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, credsResp) + } + + if !testCredsExist(t, credsResp, connURL) { + t.Fatalf("Creds should exist") + } +} func testCredsExist(t *testing.T, resp *logical.Response, connURL string) bool { var d struct { diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index f69c7761b..2a0022b4d 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -6,15 +6,26 @@ import ( "github.com/fatih/structs" "github.com/hashicorp/vault/builtin/logical/database/dbplugin" + "github.com/hashicorp/vault/helper/strutil" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" ) var ( - respErrEmptyPluginName = logical.ErrorResponse("empty plugin name") + respErrEmptyPluginName = logical.ErrorResponse("Empty plugin name") respErrEmptyName = logical.ErrorResponse("Empty name attribute given") ) +// DatabaseConfig is used by the Factory function to configure a DatabaseType +// object. +type DatabaseConfig struct { + PluginName string `json:"plugin_name" structs:"plugin_name" mapstructure:"plugin_name"` + // ConnectionDetails stores the database specific connection settings needed + // by each database type. + ConnectionDetails map[string]interface{} `json:"connection_details" structs:"connection_details" mapstructure:"connection_details"` + AllowedRoles []string `json:"allowed_roles" structs:"allowed_roles" mapstructure:"allowed_roles"` +} + // pathResetConnection configures a path to reset a plugin. func pathResetConnection(b *databaseBackend) *framework.Path { return &framework.Path{ @@ -75,15 +86,22 @@ func pathConfigurePluginConnection(b *databaseBackend) *framework.Path { "plugin_name": &framework.FieldSchema{ Type: framework.TypeString, Description: `The name of a builtin or previously registered - plugin known to vault. This endpoint will create an instance of - that plugin type.`, + plugin known to vault. This endpoint will create an instance of + that plugin type.`, }, "verify_connection": &framework.FieldSchema{ Type: framework.TypeBool, Default: true, Description: `If true, the connection details are verified by - actually connecting to the database. Defaults to true.`, + actually connecting to the database. Defaults to true.`, + }, + + "allowed_roles": &framework.FieldSchema{ + Type: framework.TypeString, + Description: `Comma separated list of the role names allowed to + get creds from this database connection. If not set all roles + are allowed.`, }, }, @@ -169,9 +187,14 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc { verifyConnection := data.Get("verify_connection").(bool) + // Pasrse and dedupe allowed roles from a comma separated string. + allowedRolesRaw := data.Get("allowed_roles").(string) + allowedRoles := strutil.ParseDedupAndSortStrings(allowedRolesRaw, ",") + config := &DatabaseConfig{ ConnectionDetails: data.Raw, PluginName: pluginName, + AllowedRoles: allowedRoles, } db, err := dbplugin.PluginFactory(config.PluginName, b.System(), b.logger) diff --git a/builtin/logical/database/path_role_create.go b/builtin/logical/database/path_role_create.go index 59584e943..631802dff 100644 --- a/builtin/logical/database/path_role_create.go +++ b/builtin/logical/database/path_role_create.go @@ -4,6 +4,7 @@ import ( "fmt" "time" + "github.com/hashicorp/vault/helper/strutil" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" ) @@ -40,6 +41,17 @@ func (b *databaseBackend) pathRoleCreateRead() framework.OperationFunc { return logical.ErrorResponse(fmt.Sprintf("Unknown role: %s", name)), nil } + dbConfig, err := b.DatabaseConfig(req.Storage, role.DBName) + if err != nil { + return nil, err + } + + // If role name isn't in the database's allowed roles, send back a + // permission denied. + if len(dbConfig.AllowedRoles) > 0 && !strutil.StrListContains(dbConfig.AllowedRoles, name) { + return nil, logical.ErrPermissionDenied + } + b.Lock() defer b.Unlock()