Combined Database backend: Add Static Account support to MongoDB (#7003)
* Implement SetCredentials for MongoDB, adding support for static accounts * rework SetCredentials to split from CreateUser, and to parse the url for database * Add integration test for mongodb static account rotation * check the length of the password results to avoid out-of-bounds * remove unused method * use the pre-existing test helper for this. Add parse method to helper * remove unused command
This commit is contained in:
parent
28447e00a3
commit
f27dc7d5f8
|
@ -17,6 +17,7 @@ import (
|
||||||
"github.com/hashicorp/vault/helper/namespace"
|
"github.com/hashicorp/vault/helper/namespace"
|
||||||
"github.com/hashicorp/vault/helper/testhelpers/docker"
|
"github.com/hashicorp/vault/helper/testhelpers/docker"
|
||||||
vaulthttp "github.com/hashicorp/vault/http"
|
vaulthttp "github.com/hashicorp/vault/http"
|
||||||
|
"github.com/hashicorp/vault/plugins/database/mongodb"
|
||||||
"github.com/hashicorp/vault/plugins/database/postgresql"
|
"github.com/hashicorp/vault/plugins/database/postgresql"
|
||||||
"github.com/hashicorp/vault/sdk/database/dbplugin"
|
"github.com/hashicorp/vault/sdk/database/dbplugin"
|
||||||
"github.com/hashicorp/vault/sdk/database/helper/dbutil"
|
"github.com/hashicorp/vault/sdk/database/helper/dbutil"
|
||||||
|
@ -102,6 +103,7 @@ func getCluster(t *testing.T) (*vault.TestCluster, logical.SystemView) {
|
||||||
|
|
||||||
sys := vault.TestDynamicSystemView(cores[0].Core)
|
sys := vault.TestDynamicSystemView(cores[0].Core)
|
||||||
vault.TestAddTestPlugin(t, cores[0].Core, "postgresql-database-plugin", consts.PluginTypeDatabase, "TestBackend_PluginMain_Postgres", []string{}, "")
|
vault.TestAddTestPlugin(t, cores[0].Core, "postgresql-database-plugin", consts.PluginTypeDatabase, "TestBackend_PluginMain_Postgres", []string{}, "")
|
||||||
|
vault.TestAddTestPlugin(t, cores[0].Core, "mongodb-database-plugin", consts.PluginTypeDatabase, "TestBackend_PluginMain_Mongo", []string{}, "")
|
||||||
|
|
||||||
return cluster, sys
|
return cluster, sys
|
||||||
}
|
}
|
||||||
|
@ -125,6 +127,28 @@ func TestBackend_PluginMain_Postgres(t *testing.T) {
|
||||||
postgresql.Run(apiClientMeta.GetTLSConfig())
|
postgresql.Run(apiClientMeta.GetTLSConfig())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestBackend_PluginMain_Mongo(t *testing.T) {
|
||||||
|
if os.Getenv(pluginutil.PluginUnwrapTokenEnv) == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
caPEM := os.Getenv(pluginutil.PluginCACertPEMEnv)
|
||||||
|
if caPEM == "" {
|
||||||
|
t.Fatal("CA cert not passed in")
|
||||||
|
}
|
||||||
|
|
||||||
|
args := []string{"--ca-cert=" + caPEM}
|
||||||
|
|
||||||
|
apiClientMeta := &api.PluginAPIClientMeta{}
|
||||||
|
flags := apiClientMeta.FlagSet()
|
||||||
|
flags.Parse(args)
|
||||||
|
|
||||||
|
err := mongodb.Run(apiClientMeta.GetTLSConfig())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestBackend_RoleUpgrade(t *testing.T) {
|
func TestBackend_RoleUpgrade(t *testing.T) {
|
||||||
|
|
||||||
storage := &logical.InmemStorage{}
|
storage := &logical.InmemStorage{}
|
||||||
|
|
|
@ -10,6 +10,7 @@ import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
|
||||||
"github.com/hashicorp/vault/helper/namespace"
|
"github.com/hashicorp/vault/helper/namespace"
|
||||||
|
"github.com/hashicorp/vault/helper/testhelpers/mongodb"
|
||||||
"github.com/hashicorp/vault/sdk/framework"
|
"github.com/hashicorp/vault/sdk/framework"
|
||||||
"github.com/hashicorp/vault/sdk/helper/dbtxn"
|
"github.com/hashicorp/vault/sdk/helper/dbtxn"
|
||||||
"github.com/hashicorp/vault/sdk/logical"
|
"github.com/hashicorp/vault/sdk/logical"
|
||||||
|
@ -19,6 +20,8 @@ import (
|
||||||
|
|
||||||
const dbUser = "vaultstatictest"
|
const dbUser = "vaultstatictest"
|
||||||
|
|
||||||
|
const testMongoDBRole = `{ "db": "admin", "roles": [ { "role": "readWrite" } ] }`
|
||||||
|
|
||||||
func TestBackend_StaticRole_Rotate_basic(t *testing.T) {
|
func TestBackend_StaticRole_Rotate_basic(t *testing.T) {
|
||||||
cluster, sys := getCluster(t)
|
cluster, sys := getCluster(t)
|
||||||
defer cluster.Cleanup()
|
defer cluster.Cleanup()
|
||||||
|
@ -814,6 +817,142 @@ func TestBackend_StaticRole_Rotations_PostgreSQL(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestBackend_StaticRole_Rotations_MongoDB(t *testing.T) {
|
||||||
|
cluster, sys := getCluster(t)
|
||||||
|
defer cluster.Cleanup()
|
||||||
|
|
||||||
|
config := logical.TestBackendConfig()
|
||||||
|
config.StorageView = &logical.InmemStorage{}
|
||||||
|
config.System = sys
|
||||||
|
|
||||||
|
b, err := Factory(context.Background(), config)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer b.Cleanup(context.Background())
|
||||||
|
|
||||||
|
// allow initQueue to finish
|
||||||
|
bd := b.(*databaseBackend)
|
||||||
|
if bd.credRotationQueue == nil {
|
||||||
|
t.Fatal("database backend had no credential rotation queue")
|
||||||
|
}
|
||||||
|
|
||||||
|
// configure backend, add item and confirm length
|
||||||
|
cleanup, connURL := mongodb.PrepareTestContainerWithDatabase(t, "latest", "vaulttestdb")
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
// Configure a connection
|
||||||
|
data := map[string]interface{}{
|
||||||
|
"connection_url": connURL,
|
||||||
|
"plugin_name": "mongodb-database-plugin",
|
||||||
|
"verify_connection": false,
|
||||||
|
"allowed_roles": []string{"*"},
|
||||||
|
"name": "plugin-mongo-test",
|
||||||
|
}
|
||||||
|
|
||||||
|
req := &logical.Request{
|
||||||
|
Operation: logical.UpdateOperation,
|
||||||
|
Path: "config/plugin-mongo-test",
|
||||||
|
Storage: config.StorageView,
|
||||||
|
Data: data,
|
||||||
|
}
|
||||||
|
resp, err := b.HandleRequest(namespace.RootContext(nil), req)
|
||||||
|
if err != nil || (resp != nil && resp.IsError()) {
|
||||||
|
t.Fatalf("err:%s resp:%#v\n", err, resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
// create three static roles with different rotation periods
|
||||||
|
testCases := []string{"65", "130", "5400"}
|
||||||
|
for _, tc := range testCases {
|
||||||
|
roleName := "plugin-static-role-" + tc
|
||||||
|
data = map[string]interface{}{
|
||||||
|
"name": roleName,
|
||||||
|
"db_name": "plugin-mongo-test",
|
||||||
|
"username": "statictestMongo" + tc,
|
||||||
|
"rotation_period": tc,
|
||||||
|
}
|
||||||
|
|
||||||
|
req = &logical.Request{
|
||||||
|
Operation: logical.CreateOperation,
|
||||||
|
Path: "static-roles/" + roleName,
|
||||||
|
Storage: config.StorageView,
|
||||||
|
Data: data,
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err = b.HandleRequest(namespace.RootContext(nil), req)
|
||||||
|
if err != nil || (resp != nil && resp.IsError()) {
|
||||||
|
t.Fatalf("err:%s resp:%#v\n", err, resp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// verify the queue has 3 items in it
|
||||||
|
if bd.credRotationQueue.Len() != 3 {
|
||||||
|
t.Fatalf("expected 3 items in the rotation queue, got: (%d)", bd.credRotationQueue.Len())
|
||||||
|
}
|
||||||
|
|
||||||
|
// List the roles
|
||||||
|
data = map[string]interface{}{}
|
||||||
|
req = &logical.Request{
|
||||||
|
Operation: logical.ListOperation,
|
||||||
|
Path: "static-roles/",
|
||||||
|
Storage: config.StorageView,
|
||||||
|
Data: data,
|
||||||
|
}
|
||||||
|
resp, err = b.HandleRequest(namespace.RootContext(nil), req)
|
||||||
|
if err != nil || (resp != nil && resp.IsError()) {
|
||||||
|
t.Fatalf("err:%s resp:%#v\n", err, resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
keys := resp.Data["keys"].([]string)
|
||||||
|
if len(keys) != 3 {
|
||||||
|
t.Fatalf("expected 3 roles, got: (%d)", len(keys))
|
||||||
|
}
|
||||||
|
|
||||||
|
// capture initial passwords, before the periodic function is triggered
|
||||||
|
pws := make(map[string][]string, 0)
|
||||||
|
pws = capturePasswords(t, b, config, testCases, pws)
|
||||||
|
|
||||||
|
// sleep to make sure the 65s role will be up for rotation by the time the
|
||||||
|
// periodic function ticks
|
||||||
|
time.Sleep(7 * time.Second)
|
||||||
|
|
||||||
|
// sleep 75 to make sure the periodic func has time to actually run
|
||||||
|
time.Sleep(75 * time.Second)
|
||||||
|
pws = capturePasswords(t, b, config, testCases, pws)
|
||||||
|
|
||||||
|
// sleep more, this should allow both sr65 and sr130 to rotate
|
||||||
|
time.Sleep(140 * time.Second)
|
||||||
|
pws = capturePasswords(t, b, config, testCases, pws)
|
||||||
|
|
||||||
|
// verify all pws are as they should
|
||||||
|
pass := true
|
||||||
|
for k, v := range pws {
|
||||||
|
if len(v) < 3 {
|
||||||
|
t.Fatalf("expected to find 3 passwords for (%s), only found (%d)", k, len(v))
|
||||||
|
}
|
||||||
|
switch {
|
||||||
|
case k == "plugin-static-role-65":
|
||||||
|
// expect all passwords to be different
|
||||||
|
if v[0] == v[1] || v[1] == v[2] || v[0] == v[2] {
|
||||||
|
pass = false
|
||||||
|
}
|
||||||
|
case k == "plugin-static-role-130":
|
||||||
|
// expect the first two to be equal, but different from the third
|
||||||
|
if v[0] != v[1] || v[0] == v[2] {
|
||||||
|
pass = false
|
||||||
|
}
|
||||||
|
case k == "plugin-static-role-5400":
|
||||||
|
// expect all passwords to be equal
|
||||||
|
if v[0] != v[1] || v[1] != v[2] {
|
||||||
|
pass = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !pass {
|
||||||
|
t.Fatalf("password rotations did not match expected: %#v", pws)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// capturePasswords captures the current passwords at the time of calling, and
|
// capturePasswords captures the current passwords at the time of calling, and
|
||||||
// returns a map of username / passwords building off of the input map
|
// returns a map of username / passwords building off of the input map
|
||||||
func capturePasswords(t *testing.T, b logical.Backend, config *logical.BackendConfig, testCases []string, pws map[string][]string) map[string][]string {
|
func capturePasswords(t *testing.T, b logical.Backend, config *logical.BackendConfig, testCases []string, pws map[string][]string) map[string][]string {
|
||||||
|
|
|
@ -1,17 +1,30 @@
|
||||||
package mongodb
|
package mongodb
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/hashicorp/vault/helper/testhelpers/docker"
|
|
||||||
"github.com/ory/dockertest"
|
"github.com/ory/dockertest"
|
||||||
"gopkg.in/mgo.v2"
|
"gopkg.in/mgo.v2"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// PrepareTestContainer calls PrepareTestContainerWithDatabase without a
|
||||||
|
// database name value, which results in configuring a database named "test"
|
||||||
func PrepareTestContainer(t *testing.T, version string) (cleanup func(), retURL string) {
|
func PrepareTestContainer(t *testing.T, version string) (cleanup func(), retURL string) {
|
||||||
|
return PrepareTestContainerWithDatabase(t, version, "")
|
||||||
|
}
|
||||||
|
|
||||||
|
// PrepareTestContainerWithDatabase configures a test container with a given
|
||||||
|
// database name, to test non-test/admin database configurations
|
||||||
|
func PrepareTestContainerWithDatabase(t *testing.T, version, dbName string) (cleanup func(), retURL string) {
|
||||||
if os.Getenv("MONGODB_URL") != "" {
|
if os.Getenv("MONGODB_URL") != "" {
|
||||||
return func() {}, os.Getenv("MONGODB_URL")
|
return func() {}, os.Getenv("MONGODB_URL")
|
||||||
}
|
}
|
||||||
|
@ -21,29 +34,36 @@ func PrepareTestContainer(t *testing.T, version string) (cleanup func(), retURL
|
||||||
t.Fatalf("Failed to connect to docker: %s", err)
|
t.Fatalf("Failed to connect to docker: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
resource, err := pool.Run("mongo", "latest", []string{})
|
resource, err := pool.Run("mongo", version, []string{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Could not start local mongo docker container: %s", err)
|
t.Fatalf("Could not start local mongo docker container: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
cleanup = func() {
|
cleanup = func() {
|
||||||
docker.CleanupResource(t, pool, resource)
|
err := pool.Purge(resource)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to cleanup local container: %s", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
addr := fmt.Sprintf("localhost:%s", resource.GetPort("27017/tcp"))
|
retURL = fmt.Sprintf("mongodb://localhost:%s", resource.GetPort("27017/tcp"))
|
||||||
retURL = "mongodb://" + addr
|
if dbName != "" {
|
||||||
|
retURL = fmt.Sprintf("%s/%s", retURL, dbName)
|
||||||
|
}
|
||||||
|
|
||||||
// exponential backoff-retry
|
// exponential backoff-retry
|
||||||
if err = pool.Retry(func() error {
|
if err = pool.Retry(func() error {
|
||||||
session, err := mgo.DialWithInfo(&mgo.DialInfo{
|
var err error
|
||||||
Addrs: []string{addr},
|
dialInfo, err := parseMongoURL(retURL)
|
||||||
Timeout: 10 * time.Second,
|
if err != nil {
|
||||||
})
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
session, err := mgo.DialWithInfo(dialInfo)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer session.Close()
|
defer session.Close()
|
||||||
|
|
||||||
session.SetSyncTimeout(1 * time.Minute)
|
session.SetSyncTimeout(1 * time.Minute)
|
||||||
session.SetSocketTimeout(1 * time.Minute)
|
session.SetSocketTimeout(1 * time.Minute)
|
||||||
return session.Ping()
|
return session.Ping()
|
||||||
|
@ -54,3 +74,72 @@ func PrepareTestContainer(t *testing.T, version string) (cleanup func(), retURL
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// parseMongoURL will parse a connection string and return a configured dialer
|
||||||
|
func parseMongoURL(rawURL string) (*mgo.DialInfo, error) {
|
||||||
|
url, err := url.Parse(rawURL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
info := mgo.DialInfo{
|
||||||
|
Addrs: strings.Split(url.Host, ","),
|
||||||
|
Database: strings.TrimPrefix(url.Path, "/"),
|
||||||
|
Timeout: 10 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
if url.User != nil {
|
||||||
|
info.Username = url.User.Username()
|
||||||
|
info.Password, _ = url.User.Password()
|
||||||
|
}
|
||||||
|
|
||||||
|
query := url.Query()
|
||||||
|
for key, values := range query {
|
||||||
|
var value string
|
||||||
|
if len(values) > 0 {
|
||||||
|
value = values[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
switch key {
|
||||||
|
case "authSource":
|
||||||
|
info.Source = value
|
||||||
|
case "authMechanism":
|
||||||
|
info.Mechanism = value
|
||||||
|
case "gssapiServiceName":
|
||||||
|
info.Service = value
|
||||||
|
case "replicaSet":
|
||||||
|
info.ReplicaSetName = value
|
||||||
|
case "maxPoolSize":
|
||||||
|
poolLimit, err := strconv.Atoi(value)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.New("bad value for maxPoolSize: " + value)
|
||||||
|
}
|
||||||
|
info.PoolLimit = poolLimit
|
||||||
|
case "ssl":
|
||||||
|
// Unfortunately, mgo doesn't support the ssl parameter in its MongoDB URI parsing logic, so we have to handle that
|
||||||
|
// ourselves. See https://github.com/go-mgo/mgo/issues/84
|
||||||
|
ssl, err := strconv.ParseBool(value)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.New("bad value for ssl: " + value)
|
||||||
|
}
|
||||||
|
if ssl {
|
||||||
|
info.DialServer = func(addr *mgo.ServerAddr) (net.Conn, error) {
|
||||||
|
return tls.Dial("tcp", addr.String(), &tls.Config{})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case "connect":
|
||||||
|
if value == "direct" {
|
||||||
|
info.Direct = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if value == "replicaSet" {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
fallthrough
|
||||||
|
default:
|
||||||
|
return nil, errors.New("unsupported connection URL option: " + key + "=" + value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &info, nil
|
||||||
|
}
|
||||||
|
|
|
@ -15,11 +15,9 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/hashicorp/errwrap"
|
"github.com/hashicorp/errwrap"
|
||||||
"github.com/hashicorp/vault/sdk/database/dbplugin"
|
|
||||||
"github.com/hashicorp/vault/sdk/database/helper/connutil"
|
"github.com/hashicorp/vault/sdk/database/helper/connutil"
|
||||||
"github.com/hashicorp/vault/sdk/database/helper/dbutil"
|
"github.com/hashicorp/vault/sdk/database/helper/dbutil"
|
||||||
"github.com/mitchellh/mapstructure"
|
"github.com/mitchellh/mapstructure"
|
||||||
|
|
||||||
mgo "gopkg.in/mgo.v2"
|
mgo "gopkg.in/mgo.v2"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -154,16 +152,6 @@ func (c *mongoDBConnectionProducer) Close() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetCredentials uses provided information to set/create a user in the
|
|
||||||
// database. Unlike CreateUser, this method requires a username be provided and
|
|
||||||
// uses the name given, instead of generating a name. This is used for creating
|
|
||||||
// and setting the password of static accounts, as well as rolling back
|
|
||||||
// passwords in the database in the event an updated database fails to save in
|
|
||||||
// Vault's storage.
|
|
||||||
func (c *mongoDBConnectionProducer) SetCredentials(ctx context.Context, statements dbplugin.Statements, staticUser dbplugin.StaticUserConfig) (username, password string, err error) {
|
|
||||||
return "", "", dbutil.Unimplemented()
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseMongoURL(rawURL string) (*mgo.DialInfo, error) {
|
func parseMongoURL(rawURL string) (*mgo.DialInfo, error) {
|
||||||
url, err := url.Parse(rawURL)
|
url, err := url.Parse(rawURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -155,6 +155,56 @@ func (m *MongoDB) CreateUser(ctx context.Context, statements dbplugin.Statements
|
||||||
return username, password, nil
|
return username, password, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetCredentials uses provided information to set/create a user in the
|
||||||
|
// database. Unlike CreateUser, this method requires a username be provided and
|
||||||
|
// uses the name given, instead of generating a name. This is used for creating
|
||||||
|
// and setting the password of static accounts, as well as rolling back
|
||||||
|
// passwords in the database in the event an updated database fails to save in
|
||||||
|
// Vault's storage.
|
||||||
|
func (m *MongoDB) SetCredentials(ctx context.Context, statements dbplugin.Statements, staticUser dbplugin.StaticUserConfig) (username, password string, err error) {
|
||||||
|
// Grab the lock
|
||||||
|
m.Lock()
|
||||||
|
defer m.Unlock()
|
||||||
|
|
||||||
|
session, err := m.getConnection(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return "", "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
username = staticUser.Username
|
||||||
|
password = staticUser.Password
|
||||||
|
|
||||||
|
dialInfo, err := parseMongoURL(m.ConnectionURL)
|
||||||
|
if err != nil {
|
||||||
|
return "", "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
mongoUser := mgo.User{
|
||||||
|
Username: username,
|
||||||
|
Password: password,
|
||||||
|
}
|
||||||
|
|
||||||
|
err = session.DB(dialInfo.Database).UpsertUser(&mongoUser)
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case err == nil:
|
||||||
|
case err == io.EOF, strings.Contains(err.Error(), "EOF"):
|
||||||
|
// Call getConnection to reset and retry query if we get an EOF error on first attempt.
|
||||||
|
session, err := m.getConnection(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return "", "", err
|
||||||
|
}
|
||||||
|
err = session.DB(dialInfo.Database).UpsertUser(&mongoUser)
|
||||||
|
if err != nil {
|
||||||
|
return "", "", err
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return "", "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return username, password, nil
|
||||||
|
}
|
||||||
|
|
||||||
// RenewUser is not supported on MongoDB, so this is a no-op.
|
// RenewUser is not supported on MongoDB, so this is a no-op.
|
||||||
func (m *MongoDB) RenewUser(ctx context.Context, statements dbplugin.Statements, username string, expiration time.Time) error {
|
func (m *MongoDB) RenewUser(ctx context.Context, statements dbplugin.Statements, username string, expiration time.Time) error {
|
||||||
// NOOP
|
// NOOP
|
||||||
|
|
|
@ -165,3 +165,78 @@ func testCredsExist(t testing.TB, connURL, username, password string) error {
|
||||||
session.SetSocketTimeout(1 * time.Minute)
|
session.SetSocketTimeout(1 * time.Minute)
|
||||||
return session.Ping()
|
return session.Ping()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestMongoDB_SetCredentials(t *testing.T) {
|
||||||
|
cleanup, connURL := mongodb.PrepareTestContainer(t, "latest")
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
// The docker test method PrepareTestContainer defaults to a database "test"
|
||||||
|
// if none is provided
|
||||||
|
connURL = connURL + "/test"
|
||||||
|
connectionDetails := map[string]interface{}{
|
||||||
|
"connection_url": connURL,
|
||||||
|
}
|
||||||
|
|
||||||
|
db := new()
|
||||||
|
_, err := db.Init(context.Background(), connectionDetails, true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("err: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// create the database user in advance, and test the connection
|
||||||
|
dbUser := "testmongouser"
|
||||||
|
startingPassword := "password"
|
||||||
|
testCreateDBUser(t, connURL, dbUser, startingPassword)
|
||||||
|
if err := testCredsExist(t, connURL, dbUser, startingPassword); err != nil {
|
||||||
|
t.Fatalf("Could not connect with new credentials: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
newPassword, err := db.GenerateCredentials(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
usernameConfig := dbplugin.StaticUserConfig{
|
||||||
|
Username: dbUser,
|
||||||
|
Password: newPassword,
|
||||||
|
}
|
||||||
|
|
||||||
|
username, password, err := db.SetCredentials(context.Background(), dbplugin.Statements{}, usernameConfig)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("err: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := testCredsExist(t, connURL, username, password); err != nil {
|
||||||
|
t.Fatalf("Could not connect with new credentials: %s", err)
|
||||||
|
}
|
||||||
|
// confirm the original creds used to set still work (should be the same)
|
||||||
|
if err := testCredsExist(t, connURL, dbUser, newPassword); err != nil {
|
||||||
|
t.Fatalf("Could not connect with new credentials: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if (dbUser != username) || (newPassword != password) {
|
||||||
|
t.Fatalf("username/password mismatch: (%s)/(%s) vs (%s)/(%s)", dbUser, username, newPassword, password)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func testCreateDBUser(t testing.TB, connURL, username, password string) {
|
||||||
|
dialInfo, err := parseMongoURL(connURL)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
session, err := mgo.DialWithInfo(dialInfo)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
session.SetSyncTimeout(1 * time.Minute)
|
||||||
|
session.SetSocketTimeout(1 * time.Minute)
|
||||||
|
mUser := mgo.User{
|
||||||
|
Username: username,
|
||||||
|
Password: password,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := session.DB(dialInfo.Database).UpsertUser(&mUser); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -5,6 +5,7 @@ type createUserCommand struct {
|
||||||
Password string `bson:"pwd"`
|
Password string `bson:"pwd"`
|
||||||
Roles []interface{} `bson:"roles"`
|
Roles []interface{} `bson:"roles"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type mongodbRole struct {
|
type mongodbRole struct {
|
||||||
Role string `json:"role" bson:"role"`
|
Role string `json:"role" bson:"role"`
|
||||||
DB string `json:"db" bson:"db"`
|
DB string `json:"db" bson:"db"`
|
||||||
|
|
Loading…
Reference in New Issue