27bb03bbc0
* adding copyright header * fix fmt and a test
274 lines
7.4 KiB
Go
274 lines
7.4 KiB
Go
// Copyright (c) HashiCorp, Inc.
|
|
// SPDX-License-Identifier: MPL-2.0
|
|
|
|
package mongodb
|
|
|
|
import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"strings"

	log "github.com/hashicorp/go-hclog"
	"github.com/hashicorp/go-secure-stdlib/strutil"
	dbplugin "github.com/hashicorp/vault/sdk/database/dbplugin/v5"
	"github.com/hashicorp/vault/sdk/database/helper/dbutil"
	"github.com/hashicorp/vault/sdk/helper/template"
	"go.mongodb.org/mongo-driver/mongo"
	"go.mongodb.org/mongo-driver/mongo/readpref"
	"go.mongodb.org/mongo-driver/mongo/writeconcern"
	"go.mongodb.org/mongo-driver/x/mongo/driver/connstring"
)
|
|
|
|
const (
	// mongoDBTypeName is the database type name reported by Type and stored
	// on the connection producer created by new.
	mongoDBTypeName = "mongodb"

	// defaultUserNameTemplate is the username template Initialize falls back
	// to when the configuration supplies an empty "username_template".
	defaultUserNameTemplate = `{{ printf "v-%s-%s-%s-%s" (.DisplayName | truncate 15) (.RoleName | truncate 15) (random 20) (unix_time) | replace "." "-" | truncate 100 }}`
)
|
|
|
|
// MongoDB is an implementation of Database interface
|
|
type MongoDB struct {
|
|
*mongoDBConnectionProducer
|
|
|
|
usernameProducer template.StringTemplate
|
|
}
|
|
|
|
// Compile-time check that *MongoDB satisfies dbplugin.Database.
var _ dbplugin.Database = (*MongoDB)(nil)
|
|
|
|
// New returns a new MongoDB instance
|
|
func New() (interface{}, error) {
|
|
db := new()
|
|
dbType := dbplugin.NewDatabaseErrorSanitizerMiddleware(db, db.secretValues)
|
|
return dbType, nil
|
|
}
|
|
|
|
func new() *MongoDB {
|
|
connProducer := &mongoDBConnectionProducer{
|
|
Type: mongoDBTypeName,
|
|
}
|
|
|
|
return &MongoDB{
|
|
mongoDBConnectionProducer: connProducer,
|
|
}
|
|
}
|
|
|
|
// Type returns the TypeName for this backend
|
|
func (m *MongoDB) Type() (string, error) {
|
|
return mongoDBTypeName, nil
|
|
}
|
|
|
|
func (m *MongoDB) Initialize(ctx context.Context, req dbplugin.InitializeRequest) (dbplugin.InitializeResponse, error) {
|
|
m.Lock()
|
|
defer m.Unlock()
|
|
|
|
m.RawConfig = req.Config
|
|
|
|
usernameTemplate, err := strutil.GetString(req.Config, "username_template")
|
|
if err != nil {
|
|
return dbplugin.InitializeResponse{}, fmt.Errorf("failed to retrieve username_template: %w", err)
|
|
}
|
|
if usernameTemplate == "" {
|
|
usernameTemplate = defaultUserNameTemplate
|
|
}
|
|
|
|
up, err := template.NewTemplate(template.Template(usernameTemplate))
|
|
if err != nil {
|
|
return dbplugin.InitializeResponse{}, fmt.Errorf("unable to initialize username template: %w", err)
|
|
}
|
|
m.usernameProducer = up
|
|
|
|
_, err = m.usernameProducer.Generate(dbplugin.UsernameMetadata{})
|
|
if err != nil {
|
|
return dbplugin.InitializeResponse{}, fmt.Errorf("invalid username template: %w", err)
|
|
}
|
|
|
|
err = m.mongoDBConnectionProducer.loadConfig(req.Config)
|
|
if err != nil {
|
|
return dbplugin.InitializeResponse{}, err
|
|
}
|
|
|
|
// Set initialized to true at this point since all fields are set,
|
|
// and the connection can be established at a later time.
|
|
m.Initialized = true
|
|
|
|
if req.VerifyConnection {
|
|
client, err := m.mongoDBConnectionProducer.createClient(ctx)
|
|
if err != nil {
|
|
return dbplugin.InitializeResponse{}, fmt.Errorf("failed to verify connection: %w", err)
|
|
}
|
|
|
|
err = client.Ping(ctx, readpref.Primary())
|
|
if err != nil {
|
|
_ = client.Disconnect(ctx) // Try to prevent any sort of resource leak
|
|
return dbplugin.InitializeResponse{}, fmt.Errorf("failed to verify connection: %w", err)
|
|
}
|
|
m.mongoDBConnectionProducer.client = client
|
|
}
|
|
|
|
resp := dbplugin.InitializeResponse{
|
|
Config: req.Config,
|
|
}
|
|
return resp, nil
|
|
}
|
|
|
|
func (m *MongoDB) NewUser(ctx context.Context, req dbplugin.NewUserRequest) (dbplugin.NewUserResponse, error) {
|
|
if len(req.Statements.Commands) == 0 {
|
|
return dbplugin.NewUserResponse{}, dbutil.ErrEmptyCreationStatement
|
|
}
|
|
|
|
username, err := m.usernameProducer.Generate(req.UsernameConfig)
|
|
if err != nil {
|
|
return dbplugin.NewUserResponse{}, err
|
|
}
|
|
|
|
// Unmarshal statements.CreationStatements into mongodbRoles
|
|
var mongoCS mongoDBStatement
|
|
err = json.Unmarshal([]byte(req.Statements.Commands[0]), &mongoCS)
|
|
if err != nil {
|
|
return dbplugin.NewUserResponse{}, err
|
|
}
|
|
|
|
// Default to "admin" if no db provided
|
|
if mongoCS.DB == "" {
|
|
mongoCS.DB = "admin"
|
|
}
|
|
|
|
if len(mongoCS.Roles) == 0 {
|
|
return dbplugin.NewUserResponse{}, fmt.Errorf("roles array is required in creation statement")
|
|
}
|
|
|
|
createUserCmd := createUserCommand{
|
|
Username: username,
|
|
Password: req.Password,
|
|
Roles: mongoCS.Roles.toStandardRolesArray(),
|
|
}
|
|
|
|
if err := m.runCommandWithRetry(ctx, mongoCS.DB, createUserCmd); err != nil {
|
|
return dbplugin.NewUserResponse{}, err
|
|
}
|
|
|
|
resp := dbplugin.NewUserResponse{
|
|
Username: username,
|
|
}
|
|
return resp, nil
|
|
}
|
|
|
|
func (m *MongoDB) UpdateUser(ctx context.Context, req dbplugin.UpdateUserRequest) (dbplugin.UpdateUserResponse, error) {
|
|
if req.Password != nil {
|
|
err := m.changeUserPassword(ctx, req.Username, req.Password.NewPassword)
|
|
return dbplugin.UpdateUserResponse{}, err
|
|
}
|
|
return dbplugin.UpdateUserResponse{}, nil
|
|
}
|
|
|
|
func (m *MongoDB) changeUserPassword(ctx context.Context, username, password string) error {
|
|
connURL := m.getConnectionURL()
|
|
cs, err := connstring.Parse(connURL)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Currently doesn't support custom statements for changing the user's password
|
|
changeUserCmd := &updateUserCommand{
|
|
Username: username,
|
|
Password: password,
|
|
}
|
|
|
|
database := cs.Database
|
|
if username == m.Username || database == "" {
|
|
database = "admin"
|
|
}
|
|
|
|
err = m.runCommandWithRetry(ctx, database, changeUserCmd)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (m *MongoDB) DeleteUser(ctx context.Context, req dbplugin.DeleteUserRequest) (dbplugin.DeleteUserResponse, error) {
|
|
// If no revocation statements provided, pass in empty JSON
|
|
var revocationStatement string
|
|
switch len(req.Statements.Commands) {
|
|
case 0:
|
|
revocationStatement = `{}`
|
|
case 1:
|
|
revocationStatement = req.Statements.Commands[0]
|
|
default:
|
|
return dbplugin.DeleteUserResponse{}, fmt.Errorf("expected 0 or 1 revocation statements, got %d", len(req.Statements.Commands))
|
|
}
|
|
|
|
// Unmarshal revocation statements into mongodbRoles
|
|
var mongoCS mongoDBStatement
|
|
err := json.Unmarshal([]byte(revocationStatement), &mongoCS)
|
|
if err != nil {
|
|
return dbplugin.DeleteUserResponse{}, err
|
|
}
|
|
|
|
db := mongoCS.DB
|
|
// If db is not specified, use the default authenticationDatabase "admin"
|
|
if db == "" {
|
|
db = "admin"
|
|
}
|
|
|
|
// Set the write concern. The default is majority.
|
|
writeConcern := writeconcern.New(writeconcern.WMajority())
|
|
opts, err := m.getWriteConcern()
|
|
if err != nil {
|
|
return dbplugin.DeleteUserResponse{}, err
|
|
}
|
|
if opts != nil {
|
|
writeConcern = opts.WriteConcern
|
|
}
|
|
|
|
dropUserCmd := &dropUserCommand{
|
|
Username: req.Username,
|
|
WriteConcern: writeConcern,
|
|
}
|
|
|
|
err = m.runCommandWithRetry(ctx, db, dropUserCmd)
|
|
cErr, ok := err.(mongo.CommandError)
|
|
if ok && cErr.Name == "UserNotFound" { // User already removed, don't retry needlessly
|
|
log.Default().Warn("MongoDB user was deleted prior to lease revocation", "user", req.Username)
|
|
return dbplugin.DeleteUserResponse{}, nil
|
|
}
|
|
|
|
return dbplugin.DeleteUserResponse{}, err
|
|
}
|
|
|
|
// runCommandWithRetry runs a command and retries once more if there's a failure
|
|
// on the first attempt. This should be called with the lock held
|
|
func (m *MongoDB) runCommandWithRetry(ctx context.Context, db string, cmd interface{}) error {
|
|
// Get the client
|
|
client, err := m.Connection(ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Run command
|
|
result := client.Database(db).RunCommand(ctx, cmd, nil)
|
|
|
|
// Error check on the first attempt
|
|
err = result.Err()
|
|
switch {
|
|
case err == nil:
|
|
return nil
|
|
case err == io.EOF, strings.Contains(err.Error(), "EOF"):
|
|
// Call getConnection to reset and retry query if we get an EOF error on first attempt.
|
|
client, err = m.Connection(ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
result = client.Database(db).RunCommand(ctx, cmd, nil)
|
|
if err := result.Err(); err != nil {
|
|
return err
|
|
}
|
|
default:
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|