diff --git a/builtin/logical/database/path_rotate_credentials.go b/builtin/logical/database/path_rotate_credentials.go index 84ed3db8d..5774ea863 100644 --- a/builtin/logical/database/path_rotate_credentials.go +++ b/builtin/logical/database/path_rotate_credentials.go @@ -78,6 +78,14 @@ func (b *databaseBackend) pathRotateRootCredentialsUpdate() framework.OperationF return nil, err } + // Take out the backend lock since we are swapping out the connection + b.Lock() + defer b.Unlock() + + // Take the write lock on the instance + dbi.Lock() + defer dbi.Unlock() + defer func() { // Close the plugin dbi.closed = true @@ -88,14 +96,6 @@ func (b *databaseBackend) pathRotateRootCredentialsUpdate() framework.OperationF delete(b.connections, name) }() - // Take out the backend lock since we are swapping out the connection - b.Lock() - defer b.Unlock() - - // Take the write lock on the instance - dbi.Lock() - defer dbi.Unlock() - // Generate new credentials oldPassword := config.ConnectionDetails["password"].(string) newPassword, err := dbi.database.GeneratePassword(ctx, b.System(), config.PasswordPolicy) diff --git a/changelog/11600.txt b/changelog/11600.txt new file mode 100644 index 000000000..f40d4bc45 --- /dev/null +++ b/changelog/11600.txt @@ -0,0 +1,9 @@ +```release-note:improvement +secrets/database/mongodb: Add ability to customize `SocketTimeout`, `ConnectTimeout`, and `ServerSelectionTimeout` +``` +```release-note:improvement +secrets/database/mongodb: Increased throughput by allowing for multiple request threads to simultaneously update users in MongoDB +``` +```release-note:bug +secrets/database: Fixed minor race condition when rotate-root is called +``` diff --git a/plugins/database/mongodb/connection_producer.go b/plugins/database/mongodb/connection_producer.go index f160c0a04..1f0c312fa 100644 --- a/plugins/database/mongodb/connection_producer.go +++ b/plugins/database/mongodb/connection_producer.go @@ -13,6 +13,7 @@ import ( "github.com/hashicorp/errwrap" "github.com/hashicorp/vault/sdk/database/helper/connutil" "github.com/hashicorp/vault/sdk/database/helper/dbutil" + "github.com/mitchellh/mapstructure" "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/mongo/readpref" @@ -31,6 +32,10 @@ type mongoDBConnectionProducer struct { TLSCertificateKeyData []byte `json:"tls_certificate_key" structs:"-" mapstructure:"tls_certificate_key"` TLSCAData []byte `json:"tls_ca" structs:"-" mapstructure:"tls_ca"` + SocketTimeout time.Duration `json:"socket_timeout" structs:"-" mapstructure:"socket_timeout"` + ConnectTimeout time.Duration `json:"connect_timeout" structs:"-" mapstructure:"connect_timeout"` + ServerSelectionTimeout time.Duration `json:"server_selection_timeout" structs:"-" mapstructure:"server_selection_timeout"` + Initialized bool RawConfig map[string]interface{} Type string @@ -48,15 +53,47 @@ type writeConcern struct { J bool // Sync via the journal if present } +func (c *mongoDBConnectionProducer) loadConfig(cfg map[string]interface{}) error { + err := mapstructure.WeakDecode(cfg, c) + if err != nil { + return err + } + + if len(c.ConnectionURL) == 0 { + return fmt.Errorf("connection_url cannot be empty") + } + + if c.SocketTimeout < 0 { + return fmt.Errorf("socket_timeout must be >= 0") + } + if c.ConnectTimeout < 0 { + return fmt.Errorf("connect_timeout must be >= 0") + } + if c.ServerSelectionTimeout < 0 { + return fmt.Errorf("server_selection_timeout must be >= 0") + } + + opts, err := c.makeClientOpts() + if err != nil { + return err + } + + c.clientOptions = opts + + return nil +} + // Connection creates or returns an existing a database connection. If the session fails // on a ping check, the session will be closed and then re-created. -// This method does not lock the mutex and it is intended that this is the callers -// responsibility. -func (c *mongoDBConnectionProducer) Connection(ctx context.Context) (interface{}, error) { +// This method does locks the mutex on its own. +func (c *mongoDBConnectionProducer) Connection(ctx context.Context) (*mongo.Client, error) { if !c.Initialized { return nil, connutil.ErrNotInitialized } + c.Mutex.Lock() + defer c.Mutex.Unlock() + if c.client != nil { if err := c.client.Ping(ctx, readpref.Primary()); err == nil { return c.client, nil @@ -65,8 +102,7 @@ func (c *mongoDBConnectionProducer) Connection(ctx context.Context) (interface{} _ = c.client.Disconnect(ctx) } - connURL := c.getConnectionURL() - client, err := createClient(ctx, connURL, c.clientOptions) + client, err := c.createClient(ctx) if err != nil { return nil, err } @@ -74,14 +110,14 @@ func (c *mongoDBConnectionProducer) Connection(ctx context.Context) (interface{} return c.client, nil } -func createClient(ctx context.Context, connURL string, clientOptions *options.ClientOptions) (client *mongo.Client, err error) { - if clientOptions == nil { - clientOptions = options.Client() +func (c *mongoDBConnectionProducer) createClient(ctx context.Context) (client *mongo.Client, err error) { + if !c.Initialized { + return nil, fmt.Errorf("failed to create client: connection producer is not initialized") } - clientOptions.SetSocketTimeout(1 * time.Minute) - clientOptions.SetConnectTimeout(1 * time.Minute) - - client, err = mongo.Connect(ctx, options.MergeClientOptions(options.Client().ApplyURI(connURL), clientOptions)) + if c.clientOptions == nil { + return nil, fmt.Errorf("missing client options") + } + client, err = mongo.Connect(ctx, options.MergeClientOptions(options.Client().ApplyURI(c.getConnectionURL()), c.clientOptions)) if err != nil { return nil, err } @@ -120,6 +156,26 @@ func (c *mongoDBConnectionProducer) getConnectionURL() (connURL string) { return connURL } +func (c *mongoDBConnectionProducer) makeClientOpts() (*options.ClientOptions, error) { + writeOpts, err := c.getWriteConcern() + if err != nil { + return nil, err + } + + authOpts, err := c.getTLSAuth() + if err != nil { + return nil, err + } + + timeoutOpts, err := c.timeoutOpts() + if err != nil { + return nil, err + } + + opts := options.MergeClientOptions(writeOpts, authOpts, timeoutOpts) + return opts, nil +} + func (c *mongoDBConnectionProducer) getWriteConcern() (opts *options.ClientOptions, err error) { if c.WriteConcern == "" { return nil, nil @@ -206,3 +262,29 @@ func (c *mongoDBConnectionProducer) getTLSAuth() (opts *options.ClientOptions, e opts.SetTLSConfig(tlsConfig) return opts, nil } + +func (c *mongoDBConnectionProducer) timeoutOpts() (opts *options.ClientOptions, err error) { + opts = options.Client() + + if c.SocketTimeout < 0 { + return nil, fmt.Errorf("socket_timeout must be >= 0") + } + + if c.SocketTimeout == 0 { + opts.SetSocketTimeout(1 * time.Minute) + } else { + opts.SetSocketTimeout(c.SocketTimeout) + } + + if c.ConnectTimeout == 0 { + opts.SetConnectTimeout(1 * time.Minute) + } else { + opts.SetConnectTimeout(c.ConnectTimeout) + } + + if c.ServerSelectionTimeout != 0 { + opts.SetServerSelectionTimeout(c.ServerSelectionTimeout) + } + + return opts, nil +} diff --git a/plugins/database/mongodb/connection_producer_test.go b/plugins/database/mongodb/connection_producer_test.go index c39914cc5..4b0ccaf25 100644 --- a/plugins/database/mongodb/connection_producer_test.go +++ b/plugins/database/mongodb/connection_producer_test.go @@ -103,7 +103,7 @@ net: "connectionStatus": 1, } - client, err := mongo.getConnection(ctx) + client, err := mongo.Connection(ctx) if err != nil { t.Fatalf("Unable to make connection to Mongo: %s", err) } diff --git a/plugins/database/mongodb/mongodb.go b/plugins/database/mongodb/mongodb.go index bfd8d4a3c..884f17dbe 100644 --- a/plugins/database/mongodb/mongodb.go +++ b/plugins/database/mongodb/mongodb.go @@ -7,14 +7,12 @@ import ( "io" "strings" + log "github.com/hashicorp/go-hclog" + dbplugin "github.com/hashicorp/vault/sdk/database/dbplugin/v5" "github.com/hashicorp/vault/sdk/database/helper/dbutil" "github.com/hashicorp/vault/sdk/helper/strutil" "github.com/hashicorp/vault/sdk/helper/template" - - dbplugin "github.com/hashicorp/vault/sdk/database/dbplugin/v5" - "github.com/mitchellh/mapstructure" "go.mongodb.org/mongo-driver/mongo" - "go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/mongo/readpref" "go.mongodb.org/mongo-driver/mongo/writeconcern" "go.mongodb.org/mongo-driver/x/mongo/driver/connstring" @@ -57,15 +55,6 @@ func (m *MongoDB) Type() (string, error) { return mongoDBTypeName, nil } -func (m *MongoDB) getConnection(ctx context.Context) (*mongo.Client, error) { - client, err := m.Connection(ctx) - if err != nil { - return nil, err - } - - return client.(*mongo.Client), nil -} - func (m *MongoDB) Initialize(ctx context.Context, req dbplugin.InitializeRequest) (dbplugin.InitializeResponse, error) { m.Lock() defer m.Unlock() @@ -91,41 +80,27 @@ func (m *MongoDB) Initialize(ctx context.Context, req dbplugin.InitializeRequest return dbplugin.InitializeResponse{}, fmt.Errorf("invalid username template: %w", err) } - err = mapstructure.WeakDecode(req.Config, m.mongoDBConnectionProducer) + err = m.mongoDBConnectionProducer.loadConfig(req.Config) if err != nil { return dbplugin.InitializeResponse{}, err } - if len(m.ConnectionURL) == 0 { - return dbplugin.InitializeResponse{}, fmt.Errorf("connection_url cannot be empty-mongo fail") - } - - writeOpts, err := m.getWriteConcern() - if err != nil { - return dbplugin.InitializeResponse{}, err - } - - authOpts, err := m.getTLSAuth() - if err != nil { - return dbplugin.InitializeResponse{}, err - } - - m.clientOptions = options.MergeClientOptions(writeOpts, authOpts) - // Set initialized to true at this point since all fields are set, // and the connection can be established at a later time. m.Initialized = true if req.VerifyConnection { - _, err := m.Connection(ctx) + client, err := m.mongoDBConnectionProducer.createClient(ctx) if err != nil { return dbplugin.InitializeResponse{}, fmt.Errorf("failed to verify connection: %w", err) } - err = m.client.Ping(ctx, readpref.Primary()) + err = client.Ping(ctx, readpref.Primary()) if err != nil { + _ = client.Disconnect(ctx) // Try to prevent any sort of resource leak return dbplugin.InitializeResponse{}, fmt.Errorf("failed to verify connection: %w", err) } + m.mongoDBConnectionProducer.client = client } resp := dbplugin.InitializeResponse{ @@ -135,10 +110,6 @@ func (m *MongoDB) Initialize(ctx context.Context, req dbplugin.InitializeRequest } func (m *MongoDB) NewUser(ctx context.Context, req dbplugin.NewUserRequest) (dbplugin.NewUserResponse, error) { - // Grab the lock - m.Lock() - defer m.Unlock() - if len(req.Statements.Commands) == 0 { return dbplugin.NewUserResponse{}, dbutil.ErrEmptyCreationStatement } @@ -189,9 +160,6 @@ func (m *MongoDB) UpdateUser(ctx context.Context, req dbplugin.UpdateUserRequest } func (m *MongoDB) changeUserPassword(ctx context.Context, username, password string) error { - m.Lock() - defer m.Unlock() - connURL := m.getConnectionURL() cs, err := connstring.Parse(connURL) if err != nil { @@ -218,9 +186,6 @@ func (m *MongoDB) changeUserPassword(ctx context.Context, username, password str } func (m *MongoDB) DeleteUser(ctx context.Context, req dbplugin.DeleteUserRequest) (dbplugin.DeleteUserResponse, error) { - m.Lock() - defer m.Unlock() - // If no revocation statements provided, pass in empty JSON var revocationStatement string switch len(req.Statements.Commands) { @@ -251,6 +216,12 @@ func (m *MongoDB) DeleteUser(ctx context.Context, req dbplugin.DeleteUserRequest } err = m.runCommandWithRetry(ctx, db, dropUserCmd) + cErr, ok := err.(mongo.CommandError) + if ok && cErr.Name == "UserNotFound" { // User already removed, don't retry needlessly + log.Default().Warn("MongoDB user was deleted prior to lease revocation", "user", req.Username) + return dbplugin.DeleteUserResponse{}, nil + } + return dbplugin.DeleteUserResponse{}, err } @@ -258,7 +229,7 @@ func (m *MongoDB) DeleteUser(ctx context.Context, req dbplugin.DeleteUserRequest // on the first attempt. This should be called with the lock held func (m *MongoDB) runCommandWithRetry(ctx context.Context, db string, cmd interface{}) error { // Get the client - client, err := m.getConnection(ctx) + client, err := m.Connection(ctx) if err != nil { return err } @@ -273,7 +244,7 @@ func (m *MongoDB) runCommandWithRetry(ctx context.Context, db string, cmd interf return nil case err == io.EOF, strings.Contains(err.Error(), "EOF"): // Call getConnection to reset and retry query if we get an EOF error on first attempt. - client, err = m.getConnection(ctx) + client, err = m.Connection(ctx) if err != nil { return err }