206 lines
4.7 KiB
Go
206 lines
4.7 KiB
Go
package mongodb
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"net/url"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/hashicorp/vault/plugins/helper/database/connutil"
|
|
"github.com/mitchellh/mapstructure"
|
|
|
|
"gopkg.in/mgo.v2"
|
|
)
|
|
|
|
// mongoDBConnectionProducer implements ConnectionProducer and provides an
|
|
// interface for databases to make connections.
|
|
type mongoDBConnectionProducer struct {
|
|
ConnectionURL string `json:"connection_url" structs:"connection_url" mapstructure:"connection_url"`
|
|
WriteConcern string `json:"write_concern" structs:"write_concern" mapstructure:"write_concern"`
|
|
|
|
Initialized bool
|
|
Type string
|
|
session *mgo.Session
|
|
safe *mgo.Safe
|
|
sync.Mutex
|
|
}
|
|
|
|
// Initialize parses connection configuration.
|
|
func (c *mongoDBConnectionProducer) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error {
|
|
c.Lock()
|
|
defer c.Unlock()
|
|
|
|
err := mapstructure.WeakDecode(conf, c)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if len(c.ConnectionURL) == 0 {
|
|
return fmt.Errorf("connection_url cannot be empty")
|
|
}
|
|
|
|
if c.WriteConcern != "" {
|
|
input := c.WriteConcern
|
|
|
|
// Try to base64 decode the input. If successful, consider the decoded
|
|
// value as input.
|
|
inputBytes, err := base64.StdEncoding.DecodeString(input)
|
|
if err == nil {
|
|
input = string(inputBytes)
|
|
}
|
|
|
|
concern := &mgo.Safe{}
|
|
err = json.Unmarshal([]byte(input), concern)
|
|
if err != nil {
|
|
return fmt.Errorf("error mashalling write_concern: %s", err)
|
|
}
|
|
|
|
// Guard against empty, non-nil mgo.Safe object; we don't want to pass that
|
|
// into mgo.SetSafe in Connection().
|
|
if (mgo.Safe{} == *concern) {
|
|
return fmt.Errorf("provided write_concern values did not map to any mgo.Safe fields")
|
|
}
|
|
c.safe = concern
|
|
}
|
|
|
|
// Set initialized to true at this point since all fields are set,
|
|
// and the connection can be established at a later time.
|
|
c.Initialized = true
|
|
|
|
if verifyConnection {
|
|
if _, err := c.Connection(ctx); err != nil {
|
|
return fmt.Errorf("error verifying connection: %s", err)
|
|
}
|
|
|
|
if err := c.session.Ping(); err != nil {
|
|
return fmt.Errorf("error verifying connection: %s", err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Connection creates or returns an existing a database connection. If the session fails
|
|
// on a ping check, the session will be closed and then re-created.
|
|
func (c *mongoDBConnectionProducer) Connection(_ context.Context) (interface{}, error) {
|
|
if !c.Initialized {
|
|
return nil, connutil.ErrNotInitialized
|
|
}
|
|
|
|
if c.session != nil {
|
|
if err := c.session.Ping(); err == nil {
|
|
return c.session, nil
|
|
}
|
|
c.session.Close()
|
|
}
|
|
|
|
dialInfo, err := parseMongoURL(c.ConnectionURL)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
c.session, err = mgo.DialWithInfo(dialInfo)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if c.safe != nil {
|
|
c.session.SetSafe(c.safe)
|
|
}
|
|
|
|
c.session.SetSyncTimeout(1 * time.Minute)
|
|
c.session.SetSocketTimeout(1 * time.Minute)
|
|
|
|
return c.session, nil
|
|
}
|
|
|
|
// Close terminates the database connection.
|
|
func (c *mongoDBConnectionProducer) Close() error {
|
|
c.Lock()
|
|
defer c.Unlock()
|
|
|
|
if c.session != nil {
|
|
c.session.Close()
|
|
}
|
|
|
|
c.session = nil
|
|
|
|
return nil
|
|
}
|
|
|
|
func parseMongoURL(rawURL string) (*mgo.DialInfo, error) {
|
|
url, err := url.Parse(rawURL)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
info := mgo.DialInfo{
|
|
Addrs: strings.Split(url.Host, ","),
|
|
Database: strings.TrimPrefix(url.Path, "/"),
|
|
Timeout: 10 * time.Second,
|
|
}
|
|
|
|
if url.User != nil {
|
|
info.Username = url.User.Username()
|
|
info.Password, _ = url.User.Password()
|
|
}
|
|
|
|
query := url.Query()
|
|
for key, values := range query {
|
|
var value string
|
|
if len(values) > 0 {
|
|
value = values[0]
|
|
}
|
|
|
|
switch key {
|
|
case "authSource":
|
|
info.Source = value
|
|
case "authMechanism":
|
|
info.Mechanism = value
|
|
case "gssapiServiceName":
|
|
info.Service = value
|
|
case "replicaSet":
|
|
info.ReplicaSetName = value
|
|
case "maxPoolSize":
|
|
poolLimit, err := strconv.Atoi(value)
|
|
if err != nil {
|
|
return nil, errors.New("bad value for maxPoolSize: " + value)
|
|
}
|
|
info.PoolLimit = poolLimit
|
|
case "ssl":
|
|
// Unfortunately, mgo doesn't support the ssl parameter in its MongoDB URI parsing logic, so we have to handle that
|
|
// ourselves. See https://github.com/go-mgo/mgo/issues/84
|
|
ssl, err := strconv.ParseBool(value)
|
|
if err != nil {
|
|
return nil, errors.New("bad value for ssl: " + value)
|
|
}
|
|
if ssl {
|
|
info.DialServer = func(addr *mgo.ServerAddr) (net.Conn, error) {
|
|
return tls.Dial("tcp", addr.String(), &tls.Config{})
|
|
}
|
|
}
|
|
case "connect":
|
|
if value == "direct" {
|
|
info.Direct = true
|
|
break
|
|
}
|
|
if value == "replicaSet" {
|
|
break
|
|
}
|
|
fallthrough
|
|
default:
|
|
return nil, errors.New("unsupported connection URL option: " + key + "=" + value)
|
|
}
|
|
}
|
|
|
|
return &info, nil
|
|
}
|