Mark deprecated builtins Removed (#18039)

* Remove logical database builtins

* Drop removed builtins from registry keys

* Update plugin prediction test

* Remove app-id builtin

* Add changelog
This commit is contained in:
Mike Palmiotto 2023-01-09 09:16:35 -05:00 committed by GitHub
parent 25d0afae23
commit 43a78c85f4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
57 changed files with 67 additions and 8246 deletions

View File

@ -1,184 +0,0 @@
package appId
import (
"context"
"sync"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/salt"
"github.com/hashicorp/vault/sdk/logical"
)
func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
b, err := Backend(conf)
if err != nil {
return nil, err
}
if err := b.Setup(ctx, conf); err != nil {
return nil, err
}
return b, nil
}
func Backend(conf *logical.BackendConfig) (*backend, error) {
var b backend
b.MapAppId = &framework.PolicyMap{
PathMap: framework.PathMap{
Name: "app-id",
Schema: map[string]*framework.FieldSchema{
"display_name": {
Type: framework.TypeString,
Description: "A name to map to this app ID for logs.",
},
"value": {
Type: framework.TypeString,
Description: "Policies for the app ID.",
},
},
},
DefaultKey: "default",
}
b.MapUserId = &framework.PathMap{
Name: "user-id",
Schema: map[string]*framework.FieldSchema{
"cidr_block": {
Type: framework.TypeString,
Description: "If not blank, restricts auth by this CIDR block",
},
"value": {
Type: framework.TypeString,
Description: "App IDs that this user associates with.",
},
},
}
b.Backend = &framework.Backend{
Help: backendHelp,
PathsSpecial: &logical.Paths{
Unauthenticated: []string{
"login",
"login/*",
},
},
Paths: framework.PathAppend([]*framework.Path{
pathLogin(&b),
pathLoginWithAppIDPath(&b),
},
b.MapAppId.Paths(),
b.MapUserId.Paths(),
),
AuthRenew: b.pathLoginRenew,
Invalidate: b.invalidate,
BackendType: logical.TypeCredential,
}
b.view = conf.StorageView
b.MapAppId.SaltFunc = b.Salt
b.MapUserId.SaltFunc = b.Salt
return &b, nil
}
type backend struct {
*framework.Backend
salt *salt.Salt
SaltMutex sync.RWMutex
view logical.Storage
MapAppId *framework.PolicyMap
MapUserId *framework.PathMap
}
func (b *backend) Salt(ctx context.Context) (*salt.Salt, error) {
b.SaltMutex.RLock()
if b.salt != nil {
defer b.SaltMutex.RUnlock()
return b.salt, nil
}
b.SaltMutex.RUnlock()
b.SaltMutex.Lock()
defer b.SaltMutex.Unlock()
if b.salt != nil {
return b.salt, nil
}
salt, err := salt.NewSalt(ctx, b.view, &salt.Config{
HashFunc: salt.SHA1Hash,
Location: salt.DefaultLocation,
})
if err != nil {
return nil, err
}
b.salt = salt
return salt, nil
}
func (b *backend) invalidate(_ context.Context, key string) {
switch key {
case salt.DefaultLocation:
b.SaltMutex.Lock()
defer b.SaltMutex.Unlock()
b.salt = nil
}
}
const backendHelp = `
The App ID credential provider is used to perform authentication from
within applications or machine by pairing together two hard-to-guess
unique pieces of information: a unique app ID, and a unique user ID.
The goal of this credential provider is to allow elastic users
(dynamic machines, containers, etc.) to authenticate with Vault without
having to store passwords outside of Vault. It is a single method of
solving the chicken-and-egg problem of setting up Vault access on a machine.
With this provider, nobody except the machine itself has access to both
pieces of information necessary to authenticate. For example:
configuration management will have the app IDs, but the machine itself
will detect its user ID based on some unique machine property such as a
MAC address (or a hash of it with some salt).
An example, real world process for using this provider:
1. Create unique app IDs (UUIDs work well) and map them to policies.
(Path: map/app-id/<app-id>)
2. Store the app IDs within configuration management systems.
3. An out-of-band process run by security operators map unique user IDs
to these app IDs. Example: when an instance is launched, a cloud-init
system tells security operators a unique ID for this machine. This
process can be scripted, but the key is that it is out-of-band and
out of reach of configuration management.
(Path: map/user-id/<user-id>)
4. A new server is provisioned. Configuration management configures the
app ID, the server itself detects its user ID. With both of these
pieces of information, Vault can be accessed according to the policy
set by the app ID.
More details on this process follow:
The app ID is a unique ID that maps to a set of policies. This ID is
generated by an operator and configured into the backend. The ID itself
is usually a UUID, but any hard-to-guess unique value can be used.
After creating app IDs, an operator authorizes a fixed set of user IDs
with each app ID. When a valid {app ID, user ID} tuple is given to the
"login" path, then the user is authenticated with the configured app
ID policies.
The user ID can be any value (just like the app ID), however it is
generally a value unique to a machine, such as a MAC address or instance ID,
or a value hashed from these unique values.
It is possible to authorize multiple app IDs with each
user ID by writing them as comma-separated values to the map/user-id/<user-id>
path.
It is also possible to renew the auth tokens with 'vault token-renew <token>' command.
Before the token is renewed, the validity of app ID, user ID and the associated
policies are checked again.
`

View File

@ -1,239 +0,0 @@
package appId
import (
"context"
"fmt"
"testing"
logicaltest "github.com/hashicorp/vault/helper/testhelpers/logical"
"github.com/hashicorp/vault/sdk/helper/salt"
"github.com/hashicorp/vault/sdk/logical"
)
func TestBackend_basic(t *testing.T) {
var b *backend
var err error
var storage logical.Storage
factory := func(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
b, err = Backend(conf)
if err != nil {
t.Fatal(err)
}
storage = conf.StorageView
if err := b.Setup(ctx, conf); err != nil {
return nil, err
}
return b, nil
}
logicaltest.Test(t, logicaltest.TestCase{
CredentialFactory: factory,
Steps: []logicaltest.TestStep{
testAccStepMapAppId(t),
testAccStepMapUserId(t),
testAccLogin(t, ""),
testAccLoginAppIDInPath(t, ""),
testAccLoginInvalid(t),
testAccStepDeleteUserId(t),
testAccLoginDeleted(t),
},
})
req := &logical.Request{
Path: "map/app-id",
Operation: logical.ListOperation,
Storage: storage,
}
resp, err := b.HandleRequest(context.Background(), req)
if err != nil {
t.Fatal(err)
}
if resp == nil {
t.Fatal("nil response")
}
keys := resp.Data["keys"].([]string)
if len(keys) != 1 {
t.Fatalf("expected 1 key, got %d", len(keys))
}
bSalt, err := b.Salt(context.Background())
if err != nil {
t.Fatal(err)
}
if keys[0] != "s"+bSalt.SaltIDHashFunc("foo", salt.SHA256Hash) {
t.Fatal("value was improperly salted")
}
}
func TestBackend_cidr(t *testing.T) {
logicaltest.Test(t, logicaltest.TestCase{
CredentialFactory: Factory,
Steps: []logicaltest.TestStep{
testAccStepMapAppIdDisplayName(t),
testAccStepMapUserIdCidr(t, "192.168.1.0/16"),
testAccLoginCidr(t, "192.168.1.5", false),
testAccLoginCidr(t, "10.0.1.5", true),
testAccLoginCidr(t, "", true),
},
})
}
func TestBackend_displayName(t *testing.T) {
logicaltest.Test(t, logicaltest.TestCase{
CredentialFactory: Factory,
Steps: []logicaltest.TestStep{
testAccStepMapAppIdDisplayName(t),
testAccStepMapUserId(t),
testAccLogin(t, "tubbin"),
testAccLoginAppIDInPath(t, "tubbin"),
testAccLoginInvalid(t),
testAccStepDeleteUserId(t),
testAccLoginDeleted(t),
},
})
}
func testAccStepMapAppId(t *testing.T) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.UpdateOperation,
Path: "map/app-id/foo",
Data: map[string]interface{}{
"value": "foo,bar",
},
}
}
func testAccStepMapAppIdDisplayName(t *testing.T) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.UpdateOperation,
Path: "map/app-id/foo",
Data: map[string]interface{}{
"display_name": "tubbin",
"value": "foo,bar",
},
}
}
func testAccStepMapUserId(t *testing.T) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.UpdateOperation,
Path: "map/user-id/42",
Data: map[string]interface{}{
"value": "foo",
},
}
}
func testAccStepDeleteUserId(t *testing.T) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.DeleteOperation,
Path: "map/user-id/42",
}
}
func testAccStepMapUserIdCidr(t *testing.T, cidr string) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.UpdateOperation,
Path: "map/user-id/42",
Data: map[string]interface{}{
"value": "foo",
"cidr_block": cidr,
},
}
}
func testAccLogin(t *testing.T, display string) logicaltest.TestStep {
checkTTL := func(resp *logical.Response) error {
if resp.Auth.LeaseOptions.TTL.String() != "768h0m0s" {
return fmt.Errorf("invalid TTL: got %s", resp.Auth.LeaseOptions.TTL)
}
return nil
}
return logicaltest.TestStep{
Operation: logical.UpdateOperation,
Path: "login",
Data: map[string]interface{}{
"app_id": "foo",
"user_id": "42",
},
Unauthenticated: true,
Check: logicaltest.TestCheckMulti(
logicaltest.TestCheckAuth([]string{"bar", "default", "foo"}),
logicaltest.TestCheckAuthDisplayName(display),
checkTTL,
),
}
}
func testAccLoginAppIDInPath(t *testing.T, display string) logicaltest.TestStep {
checkTTL := func(resp *logical.Response) error {
if resp.Auth.LeaseOptions.TTL.String() != "768h0m0s" {
return fmt.Errorf("invalid TTL: got %s", resp.Auth.LeaseOptions.TTL)
}
return nil
}
return logicaltest.TestStep{
Operation: logical.UpdateOperation,
Path: "login/foo",
Data: map[string]interface{}{
"user_id": "42",
},
Unauthenticated: true,
Check: logicaltest.TestCheckMulti(
logicaltest.TestCheckAuth([]string{"bar", "default", "foo"}),
logicaltest.TestCheckAuthDisplayName(display),
checkTTL,
),
}
}
func testAccLoginCidr(t *testing.T, ip string, err bool) logicaltest.TestStep {
check := logicaltest.TestCheckError()
if !err {
check = logicaltest.TestCheckAuth([]string{"bar", "default", "foo"})
}
return logicaltest.TestStep{
Operation: logical.UpdateOperation,
Path: "login",
Data: map[string]interface{}{
"app_id": "foo",
"user_id": "42",
},
ErrorOk: err,
Unauthenticated: true,
RemoteAddr: ip,
Check: check,
}
}
func testAccLoginInvalid(t *testing.T) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.UpdateOperation,
Path: "login",
Data: map[string]interface{}{
"app_id": "foo",
"user_id": "48",
},
ErrorOk: true,
Unauthenticated: true,
Check: logicaltest.TestCheckError(),
}
}
func testAccLoginDeleted(t *testing.T) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.UpdateOperation,
Path: "login",
Data: map[string]interface{}{
"app_id": "foo",
"user_id": "42",
},
ErrorOk: true,
Unauthenticated: true,
Check: logicaltest.TestCheckError(),
}
}

View File

@ -1,29 +0,0 @@
package main
import (
"os"
hclog "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/api"
appId "github.com/hashicorp/vault/builtin/credential/app-id"
"github.com/hashicorp/vault/sdk/plugin"
)
func main() {
apiClientMeta := &api.PluginAPIClientMeta{}
flags := apiClientMeta.FlagSet()
flags.Parse(os.Args[1:])
tlsConfig := apiClientMeta.GetTLSConfig()
tlsProviderFunc := api.VaultPluginTLSProvider(tlsConfig)
if err := plugin.Serve(&plugin.ServeOpts{
BackendFactoryFunc: appId.Factory,
TLSProviderFunc: tlsProviderFunc,
}); err != nil {
logger := hclog.New(&hclog.LoggerOptions{})
logger.Error("plugin shutting down", "error", err)
os.Exit(1)
}
}

View File

@ -1,229 +0,0 @@
package appId
import (
"context"
"crypto/sha1"
"crypto/subtle"
"encoding/hex"
"fmt"
"net"
"strings"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/policyutil"
"github.com/hashicorp/vault/sdk/logical"
)
func pathLoginWithAppIDPath(b *backend) *framework.Path {
return &framework.Path{
Pattern: "login/(?P<app_id>.+)",
Fields: map[string]*framework.FieldSchema{
"app_id": {
Type: framework.TypeString,
Description: "The unique app ID",
},
"user_id": {
Type: framework.TypeString,
Description: "The unique user ID",
},
},
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.UpdateOperation: b.pathLogin,
},
HelpSynopsis: pathLoginSyn,
HelpDescription: pathLoginDesc,
}
}
func pathLogin(b *backend) *framework.Path {
return &framework.Path{
Pattern: "login$",
Fields: map[string]*framework.FieldSchema{
"app_id": {
Type: framework.TypeString,
Description: "The unique app ID",
},
"user_id": {
Type: framework.TypeString,
Description: "The unique user ID",
},
},
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.UpdateOperation: b.pathLogin,
logical.AliasLookaheadOperation: b.pathLoginAliasLookahead,
},
HelpSynopsis: pathLoginSyn,
HelpDescription: pathLoginDesc,
}
}
func (b *backend) pathLoginAliasLookahead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
appId := data.Get("app_id").(string)
if appId == "" {
return nil, fmt.Errorf("missing app_id")
}
return &logical.Response{
Auth: &logical.Auth{
Alias: &logical.Alias{
Name: appId,
},
},
}, nil
}
func (b *backend) pathLogin(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
appId := data.Get("app_id").(string)
userId := data.Get("user_id").(string)
var displayName string
if dispName, resp, err := b.verifyCredentials(ctx, req, appId, userId); err != nil {
return nil, err
} else if resp != nil {
return resp, nil
} else {
displayName = dispName
}
// Get the policies associated with the app
policies, err := b.MapAppId.Policies(ctx, req.Storage, appId)
if err != nil {
return nil, err
}
// Store hashes of the app ID and user ID for the metadata
appIdHash := sha1.Sum([]byte(appId))
userIdHash := sha1.Sum([]byte(userId))
metadata := map[string]string{
"app-id": "sha1:" + hex.EncodeToString(appIdHash[:]),
"user-id": "sha1:" + hex.EncodeToString(userIdHash[:]),
}
return &logical.Response{
Auth: &logical.Auth{
InternalData: map[string]interface{}{
"app-id": appId,
"user-id": userId,
},
DisplayName: displayName,
Policies: policies,
Metadata: metadata,
LeaseOptions: logical.LeaseOptions{
Renewable: true,
},
Alias: &logical.Alias{
Name: appId,
},
},
}, nil
}
func (b *backend) pathLoginRenew(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
appId := req.Auth.InternalData["app-id"].(string)
userId := req.Auth.InternalData["user-id"].(string)
// Skipping CIDR verification to enable renewal from machines other than
// the ones encompassed by CIDR block.
if _, resp, err := b.verifyCredentials(ctx, req, appId, userId); err != nil {
return nil, err
} else if resp != nil {
return resp, nil
}
// Get the policies associated with the app
mapPolicies, err := b.MapAppId.Policies(ctx, req.Storage, appId)
if err != nil {
return nil, err
}
if !policyutil.EquivalentPolicies(mapPolicies, req.Auth.TokenPolicies) {
return nil, fmt.Errorf("policies do not match")
}
return &logical.Response{Auth: req.Auth}, nil
}
func (b *backend) verifyCredentials(ctx context.Context, req *logical.Request, appId, userId string) (string, *logical.Response, error) {
// Ensure both appId and userId are provided
if appId == "" || userId == "" {
return "", logical.ErrorResponse("missing 'app_id' or 'user_id'"), nil
}
// Look up the apps that this user is allowed to access
appsMap, err := b.MapUserId.Get(ctx, req.Storage, userId)
if err != nil {
return "", nil, err
}
if appsMap == nil {
return "", logical.ErrorResponse("invalid user ID or app ID"), nil
}
// If there is a CIDR block restriction, check that
if raw, ok := appsMap["cidr_block"]; ok {
_, cidr, err := net.ParseCIDR(raw.(string))
if err != nil {
return "", nil, fmt.Errorf("invalid restriction cidr: %w", err)
}
var addr string
if req.Connection != nil {
addr = req.Connection.RemoteAddr
}
if addr == "" || !cidr.Contains(net.ParseIP(addr)) {
return "", logical.ErrorResponse("unauthorized source address"), nil
}
}
appsRaw, ok := appsMap["value"]
if !ok {
appsRaw = ""
}
apps, ok := appsRaw.(string)
if !ok {
return "", nil, fmt.Errorf("mapping is not a string")
}
// Verify that the app is in the list
found := false
appIdBytes := []byte(appId)
for _, app := range strings.Split(apps, ",") {
match := []byte(strings.TrimSpace(app))
// Protect against a timing attack with the app_id comparison
if subtle.ConstantTimeCompare(match, appIdBytes) == 1 {
found = true
}
}
if !found {
return "", logical.ErrorResponse("invalid user ID or app ID"), nil
}
// Get the raw data associated with the app
appRaw, err := b.MapAppId.Get(ctx, req.Storage, appId)
if err != nil {
return "", nil, err
}
if appRaw == nil {
return "", logical.ErrorResponse("invalid user ID or app ID"), nil
}
var displayName string
if raw, ok := appRaw["display_name"]; ok {
displayName = raw.(string)
}
return displayName, nil, nil
}
const pathLoginSyn = `
Log in with an App ID and User ID.
`
const pathLoginDesc = `
This endpoint authenticates using an application ID, user ID and potential the IP address of the connecting client.
`

View File

@ -1,134 +0,0 @@
package cassandra
import (
"context"
"fmt"
"strings"
"sync"
"github.com/gocql/gocql"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical"
)
// Factory creates a new backend
func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
b := Backend()
if err := b.Setup(ctx, conf); err != nil {
return nil, err
}
return b, nil
}
// Backend contains the base information for the backend's functionality
func Backend() *backend {
var b backend
b.Backend = &framework.Backend{
Help: strings.TrimSpace(backendHelp),
PathsSpecial: &logical.Paths{
SealWrapStorage: []string{
"config/connection",
},
},
Paths: []*framework.Path{
pathConfigConnection(&b),
pathRoles(&b),
pathCredsCreate(&b),
},
Secrets: []*framework.Secret{
secretCreds(&b),
},
Invalidate: b.invalidate,
Clean: func(_ context.Context) {
b.ResetDB(nil)
},
BackendType: logical.TypeLogical,
}
return &b
}
type backend struct {
*framework.Backend
// Session is goroutine safe, however, since we reinitialize
// it when connection info changes, we want to make sure we
// can close it and use a new connection; hence the lock
session *gocql.Session
lock sync.Mutex
}
type sessionConfig struct {
Hosts string `json:"hosts" structs:"hosts" mapstructure:"hosts"`
Username string `json:"username" structs:"username" mapstructure:"username"`
Password string `json:"password" structs:"password" mapstructure:"password"`
TLS bool `json:"tls" structs:"tls" mapstructure:"tls"`
InsecureTLS bool `json:"insecure_tls" structs:"insecure_tls" mapstructure:"insecure_tls"`
Certificate string `json:"certificate" structs:"certificate" mapstructure:"certificate"`
PrivateKey string `json:"private_key" structs:"private_key" mapstructure:"private_key"`
IssuingCA string `json:"issuing_ca" structs:"issuing_ca" mapstructure:"issuing_ca"`
ProtocolVersion int `json:"protocol_version" structs:"protocol_version" mapstructure:"protocol_version"`
ConnectTimeout int `json:"connect_timeout" structs:"connect_timeout" mapstructure:"connect_timeout"`
TLSMinVersion string `json:"tls_min_version" structs:"tls_min_version" mapstructure:"tls_min_version"`
}
// DB returns the database connection.
func (b *backend) DB(ctx context.Context, s logical.Storage) (*gocql.Session, error) {
b.lock.Lock()
defer b.lock.Unlock()
// If we already have a DB, we got it!
if b.session != nil {
return b.session, nil
}
entry, err := s.Get(ctx, "config/connection")
if err != nil {
return nil, err
}
if entry == nil {
return nil, fmt.Errorf("configure the DB connection with config/connection first")
}
config := &sessionConfig{}
if err := entry.DecodeJSON(config); err != nil {
return nil, err
}
session, err := createSession(config, s)
// Store the session in backend for reuse
b.session = session
return session, err
}
// ResetDB forces a connection next time DB() is called.
func (b *backend) ResetDB(newSession *gocql.Session) {
b.lock.Lock()
defer b.lock.Unlock()
if b.session != nil {
b.session.Close()
}
b.session = newSession
}
func (b *backend) invalidate(_ context.Context, key string) {
switch key {
case "config/connection":
b.ResetDB(nil)
}
}
const backendHelp = `
The Cassandra backend dynamically generates database users.
After mounting this backend, configure it using the endpoints within
the "config/" path.
`

View File

@ -1,163 +0,0 @@
package cassandra
import (
"context"
"fmt"
"log"
"testing"
"github.com/hashicorp/vault/helper/testhelpers/cassandra"
logicaltest "github.com/hashicorp/vault/helper/testhelpers/logical"
"github.com/hashicorp/vault/sdk/logical"
"github.com/mitchellh/mapstructure"
)
func TestBackend_basic(t *testing.T) {
config := logical.TestBackendConfig()
config.StorageView = &logical.InmemStorage{}
b, err := Factory(context.Background(), config)
if err != nil {
t.Fatal(err)
}
copyFromTo := map[string]string{
"test-fixtures/cassandra.yaml": "/etc/cassandra/cassandra.yaml",
}
host, cleanup := cassandra.PrepareTestContainer(t,
cassandra.CopyFromTo(copyFromTo),
)
defer cleanup()
logicaltest.Test(t, logicaltest.TestCase{
LogicalBackend: b,
Steps: []logicaltest.TestStep{
testAccStepConfig(t, host.ConnectionURL()),
testAccStepRole(t),
testAccStepReadCreds(t, "test"),
},
})
}
func TestBackend_roleCrud(t *testing.T) {
config := logical.TestBackendConfig()
config.StorageView = &logical.InmemStorage{}
b, err := Factory(context.Background(), config)
if err != nil {
t.Fatal(err)
}
copyFromTo := map[string]string{
"test-fixtures/cassandra.yaml": "/etc/cassandra/cassandra.yaml",
}
host, cleanup := cassandra.PrepareTestContainer(t,
cassandra.CopyFromTo(copyFromTo))
defer cleanup()
logicaltest.Test(t, logicaltest.TestCase{
LogicalBackend: b,
Steps: []logicaltest.TestStep{
testAccStepConfig(t, host.ConnectionURL()),
testAccStepRole(t),
testAccStepRoleWithOptions(t),
testAccStepReadRole(t, "test", testRole),
testAccStepReadRole(t, "test2", testRole),
testAccStepDeleteRole(t, "test"),
testAccStepDeleteRole(t, "test2"),
testAccStepReadRole(t, "test", ""),
testAccStepReadRole(t, "test2", ""),
},
})
}
func testAccStepConfig(t *testing.T, hostname string) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.UpdateOperation,
Path: "config/connection",
Data: map[string]interface{}{
"hosts": hostname,
"username": "cassandra",
"password": "cassandra",
"protocol_version": 3,
},
}
}
func testAccStepRole(t *testing.T) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.UpdateOperation,
Path: "roles/test",
Data: map[string]interface{}{
"creation_cql": testRole,
},
}
}
func testAccStepRoleWithOptions(t *testing.T) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.UpdateOperation,
Path: "roles/test2",
Data: map[string]interface{}{
"creation_cql": testRole,
"lease": "30s",
"consistency": "All",
},
}
}
func testAccStepDeleteRole(t *testing.T, n string) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.DeleteOperation,
Path: "roles/" + n,
}
}
func testAccStepReadCreds(t *testing.T, name string) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.ReadOperation,
Path: "creds/" + name,
Check: func(resp *logical.Response) error {
var d struct {
Username string `mapstructure:"username"`
Password string `mapstructure:"password"`
}
if err := mapstructure.Decode(resp.Data, &d); err != nil {
return err
}
log.Printf("[WARN] Generated credentials: %v", d)
return nil
},
}
}
func testAccStepReadRole(t *testing.T, name string, cql string) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.ReadOperation,
Path: "roles/" + name,
Check: func(resp *logical.Response) error {
if resp == nil {
if cql == "" {
return nil
}
return fmt.Errorf("response is nil")
}
var d struct {
CreationCQL string `mapstructure:"creation_cql"`
}
if err := mapstructure.Decode(resp.Data, &d); err != nil {
return err
}
if d.CreationCQL != cql {
return fmt.Errorf("bad: %#v\n%#v\n%#v\n", resp, cql, d.CreationCQL)
}
return nil
},
}
}
const testRole = `CREATE USER '{{username}}' WITH PASSWORD '{{password}}' NOSUPERUSER;
GRANT ALL PERMISSIONS ON ALL KEYSPACES TO {{username}};`

View File

@ -1,29 +0,0 @@
package main
import (
"os"
hclog "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/builtin/logical/cassandra"
"github.com/hashicorp/vault/sdk/plugin"
)
func main() {
apiClientMeta := &api.PluginAPIClientMeta{}
flags := apiClientMeta.FlagSet()
flags.Parse(os.Args[1:])
tlsConfig := apiClientMeta.GetTLSConfig()
tlsProviderFunc := api.VaultPluginTLSProvider(tlsConfig)
if err := plugin.Serve(&plugin.ServeOpts{
BackendFactoryFunc: cassandra.Factory,
TLSProviderFunc: tlsProviderFunc,
}); err != nil {
logger := hclog.New(&hclog.LoggerOptions{})
logger.Error("plugin shutting down", "error", err)
os.Exit(1)
}
}

View File

@ -1,245 +0,0 @@
package cassandra
import (
"context"
"fmt"
"github.com/hashicorp/go-secure-stdlib/tlsutil"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/certutil"
"github.com/hashicorp/vault/sdk/logical"
)
func pathConfigConnection(b *backend) *framework.Path {
return &framework.Path{
Pattern: "config/connection",
Fields: map[string]*framework.FieldSchema{
"hosts": {
Type: framework.TypeString,
Description: "Comma-separated list of hosts",
},
"username": {
Type: framework.TypeString,
Description: "The username to use for connecting to the cluster",
},
"password": {
Type: framework.TypeString,
Description: "The password to use for connecting to the cluster",
},
"tls": {
Type: framework.TypeBool,
Description: `Whether to use TLS. If pem_bundle or pem_json are
set, this is automatically set to true`,
},
"insecure_tls": {
Type: framework.TypeBool,
Description: `Whether to use TLS but skip verification; has no
effect if a CA certificate is provided`,
},
// TLS 1.3 is not supported as this engine is deprecated. Please switch to the Cassandra database secrets engine
"tls_min_version": {
Type: framework.TypeString,
Default: "tls12",
Description: "Minimum TLS version to use. Accepted values are 'tls10', 'tls11' or 'tls12'. Defaults to 'tls12'",
},
"pem_bundle": {
Type: framework.TypeString,
Description: `PEM-format, concatenated unencrypted secret key
and certificate, with optional CA certificate`,
},
"pem_json": {
Type: framework.TypeString,
Description: `JSON containing a PEM-format, unencrypted secret
key and certificate, with optional CA certificate.
The JSON output of a certificate issued with the PKI
backend can be directly passed into this parameter.
If both this and "pem_bundle" are specified, this will
take precedence.`,
},
"protocol_version": {
Type: framework.TypeInt,
Description: `The protocol version to use. Defaults to 2.`,
},
"connect_timeout": {
Type: framework.TypeDurationSecond,
Default: 5,
Description: `The connection timeout to use. Defaults to 5.`,
},
},
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.ReadOperation: b.pathConnectionRead,
logical.UpdateOperation: b.pathConnectionWrite,
},
HelpSynopsis: pathConfigConnectionHelpSyn,
HelpDescription: pathConfigConnectionHelpDesc,
}
}
func (b *backend) pathConnectionRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
entry, err := req.Storage.Get(ctx, "config/connection")
if err != nil {
return nil, err
}
if entry == nil {
return logical.ErrorResponse(fmt.Sprintf("Configure the DB connection with config/connection first")), nil
}
config := &sessionConfig{}
if err := entry.DecodeJSON(config); err != nil {
return nil, err
}
resp := &logical.Response{
Data: map[string]interface{}{
"hosts": config.Hosts,
"username": config.Username,
"tls": config.TLS,
"insecure_tls": config.InsecureTLS,
"certificate": config.Certificate,
"issuing_ca": config.IssuingCA,
"protocol_version": config.ProtocolVersion,
"connect_timeout": config.ConnectTimeout,
"tls_min_version": config.TLSMinVersion,
},
}
return resp, nil
}
func (b *backend) pathConnectionWrite(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
hosts := data.Get("hosts").(string)
username := data.Get("username").(string)
password := data.Get("password").(string)
switch {
case len(hosts) == 0:
return logical.ErrorResponse("Hosts cannot be empty"), nil
case len(username) == 0:
return logical.ErrorResponse("Username cannot be empty"), nil
case len(password) == 0:
return logical.ErrorResponse("Password cannot be empty"), nil
}
config := &sessionConfig{
Hosts: hosts,
Username: username,
Password: password,
TLS: data.Get("tls").(bool),
InsecureTLS: data.Get("insecure_tls").(bool),
ProtocolVersion: data.Get("protocol_version").(int),
ConnectTimeout: data.Get("connect_timeout").(int),
}
config.TLSMinVersion = data.Get("tls_min_version").(string)
if config.TLSMinVersion == "" {
return logical.ErrorResponse("failed to get 'tls_min_version' value"), nil
}
var ok bool
_, ok = tlsutil.TLSLookup[config.TLSMinVersion]
if !ok {
return logical.ErrorResponse("invalid 'tls_min_version'"), nil
}
if config.InsecureTLS {
config.TLS = true
}
pemBundle := data.Get("pem_bundle").(string)
pemJSON := data.Get("pem_json").(string)
var certBundle *certutil.CertBundle
var parsedCertBundle *certutil.ParsedCertBundle
var err error
switch {
case len(pemJSON) != 0:
parsedCertBundle, err = certutil.ParsePKIJSON([]byte(pemJSON))
if err != nil {
return logical.ErrorResponse(fmt.Sprintf("Could not parse given JSON; it must be in the format of the output of the PKI backend certificate issuing command: %s", err)), nil
}
certBundle, err = parsedCertBundle.ToCertBundle()
if err != nil {
return logical.ErrorResponse(fmt.Sprintf("Error marshaling PEM information: %s", err)), nil
}
config.Certificate = certBundle.Certificate
config.PrivateKey = certBundle.PrivateKey
config.IssuingCA = certBundle.IssuingCA
config.TLS = true
case len(pemBundle) != 0:
parsedCertBundle, err = certutil.ParsePEMBundle(pemBundle)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf("Error parsing the given PEM information: %s", err)), nil
}
certBundle, err = parsedCertBundle.ToCertBundle()
if err != nil {
return logical.ErrorResponse(fmt.Sprintf("Error marshaling PEM information: %s", err)), nil
}
config.Certificate = certBundle.Certificate
config.PrivateKey = certBundle.PrivateKey
config.IssuingCA = certBundle.IssuingCA
config.TLS = true
}
session, err := createSession(config, req.Storage)
if err != nil {
return logical.ErrorResponse(err.Error()), nil
}
// Store it
entry, err := logical.StorageEntryJSON("config/connection", config)
if err != nil {
return nil, err
}
if err := req.Storage.Put(ctx, entry); err != nil {
return nil, err
}
// Reset the DB connection
b.ResetDB(session)
return nil, nil
}
const pathConfigConnectionHelpSyn = `
Configure the connection information to talk to Cassandra.
`
const pathConfigConnectionHelpDesc = `
This path configures the connection information used to connect to Cassandra.
"hosts" is a comma-delimited list of hostnames in the Cassandra cluster.
"username" and "password" are self-explanatory, although the given user
must have superuser access within Cassandra. Note that since this backend
issues username/password credentials, Cassandra must be configured to use
PasswordAuthenticator or a similar backend for its authentication. If you wish
to have no authorization in Cassandra and want to use TLS client certificates,
see the PKI backend.
TLS works as follows:
* If "tls" is set to true, the connection will use TLS; this happens automatically if "pem_bundle", "pem_json", or "insecure_tls" is set
* If "insecure_tls" is set to true, the connection will not perform verification of the server certificate; this also sets "tls" to true
* If only "issuing_ca" is set in "pem_json", or the only certificate in "pem_bundle" is a CA certificate, the given CA certificate will be used for server certificate verification; otherwise the system CA certificates will be used
* If "certificate" and "private_key" are set in "pem_bundle" or "pem_json", client auth will be turned on for the connection
"pem_bundle" should be a PEM-concatenated bundle of a private key + client certificate, an issuing CA certificate, or both. "pem_json" should contain the same information; for convenience, the JSON format is the same as that output by the issue command from the PKI backend.
When configuring the connection information, the backend will verify its
validity.
`

View File

@ -1,123 +0,0 @@
package cassandra
import (
"context"
"fmt"
"strings"
"time"
"github.com/gocql/gocql"
"github.com/hashicorp/go-secure-stdlib/strutil"
uuid "github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical"
)
func pathCredsCreate(b *backend) *framework.Path {
return &framework.Path{
Pattern: "creds/" + framework.GenericNameRegex("name"),
Fields: map[string]*framework.FieldSchema{
"name": {
Type: framework.TypeString,
Description: "Name of the role",
},
},
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.ReadOperation: b.pathCredsCreateRead,
},
HelpSynopsis: pathCredsCreateReadHelpSyn,
HelpDescription: pathCredsCreateReadHelpDesc,
}
}
func (b *backend) pathCredsCreateRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
name := data.Get("name").(string)
// Get the role
role, err := getRole(ctx, req.Storage, name)
if err != nil {
return nil, err
}
if role == nil {
return logical.ErrorResponse(fmt.Sprintf("Unknown role: %s", name)), nil
}
displayName := req.DisplayName
userUUID, err := uuid.GenerateUUID()
if err != nil {
return nil, err
}
username := fmt.Sprintf("vault_%s_%s_%s_%d", name, displayName, userUUID, time.Now().Unix())
username = strings.ReplaceAll(username, "-", "_")
password, err := uuid.GenerateUUID()
if err != nil {
return nil, err
}
// Get our connection
session, err := b.DB(ctx, req.Storage)
if err != nil {
return nil, err
}
// Set consistency
if role.Consistency != "" {
consistencyValue, err := gocql.ParseConsistencyWrapper(role.Consistency)
if err != nil {
return nil, err
}
session.SetConsistency(consistencyValue)
}
// Execute each query
for _, query := range strutil.ParseArbitraryStringSlice(role.CreationCQL, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
}
err = session.Query(substQuery(query, map[string]string{
"username": username,
"password": password,
})).Exec()
if err != nil {
for _, query := range strutil.ParseArbitraryStringSlice(role.RollbackCQL, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
}
session.Query(substQuery(query, map[string]string{
"username": username,
"password": password,
})).Exec()
}
return nil, err
}
}
// Return the secret
resp := b.Secret(SecretCredsType).Response(map[string]interface{}{
"username": username,
"password": password,
}, map[string]interface{}{
"username": username,
"role": name,
})
resp.Secret.TTL = role.Lease
return resp, nil
}
const pathCredsCreateReadHelpSyn = `
Request database credentials for a certain role.
`
const pathCredsCreateReadHelpDesc = `
This path creates database credentials for a certain role. The
database credentials will be generated on demand and will be automatically
revoked when the lease is up.
`

View File

@ -1,196 +0,0 @@
package cassandra
import (
"context"
"fmt"
"time"
"github.com/fatih/structs"
"github.com/gocql/gocql"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical"
)
const (
defaultCreationCQL = `CREATE USER '{{username}}' WITH PASSWORD '{{password}}' NOSUPERUSER;`
defaultRollbackCQL = `DROP USER '{{username}}';`
)
func pathRoles(b *backend) *framework.Path {
return &framework.Path{
Pattern: "roles/" + framework.GenericNameRegex("name"),
Fields: map[string]*framework.FieldSchema{
"name": {
Type: framework.TypeString,
Description: "Name of the role",
},
"creation_cql": {
Type: framework.TypeString,
Default: defaultCreationCQL,
Description: `CQL to create a user and optionally grant
authorization. If not supplied, a default that
creates non-superuser accounts with the built-in
password authenticator will be used; no
authorization grants will be configured. Separate
statements by semicolons; use @file to load from a
file. Valid template values are '{{username}}' and
'{{password}}' -- the single quotes are important!`,
},
"rollback_cql": {
Type: framework.TypeString,
Default: defaultRollbackCQL,
Description: `CQL to roll back an account operation. This will
be used if there is an error during execution of a
statement passed in via the "creation_cql" parameter
parameter. The default simply drops the user, which
should generally be sufficient. Separate statements
by semicolons; use @file to load from a file. Valid
template values are '{{username}}' and
'{{password}}' -- the single quotes are important!`,
},
"lease": {
Type: framework.TypeString,
Default: "4h",
Description: "The lease length; defaults to 4 hours",
},
"consistency": {
Type: framework.TypeString,
Default: "Quorum",
Description: "The consistency level for the operations; defaults to Quorum.",
},
},
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.ReadOperation: b.pathRoleRead,
logical.UpdateOperation: b.pathRoleCreate,
logical.DeleteOperation: b.pathRoleDelete,
},
HelpSynopsis: pathRoleHelpSyn,
HelpDescription: pathRoleHelpDesc,
}
}
func getRole(ctx context.Context, s logical.Storage, n string) (*roleEntry, error) {
entry, err := s.Get(ctx, "role/"+n)
if err != nil {
return nil, err
}
if entry == nil {
return nil, nil
}
var result roleEntry
if err := entry.DecodeJSON(&result); err != nil {
return nil, err
}
return &result, nil
}
func (b *backend) pathRoleDelete(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
err := req.Storage.Delete(ctx, "role/"+data.Get("name").(string))
if err != nil {
return nil, err
}
return nil, nil
}
func (b *backend) pathRoleRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
role, err := getRole(ctx, req.Storage, data.Get("name").(string))
if err != nil {
return nil, err
}
if role == nil {
return nil, nil
}
return &logical.Response{
Data: structs.New(role).Map(),
}, nil
}
func (b *backend) pathRoleCreate(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
name := data.Get("name").(string)
creationCQL := data.Get("creation_cql").(string)
rollbackCQL := data.Get("rollback_cql").(string)
leaseRaw := data.Get("lease").(string)
lease, err := time.ParseDuration(leaseRaw)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf(
"Error parsing lease value of %s: %s", leaseRaw, err)), nil
}
consistencyStr := data.Get("consistency").(string)
_, err = gocql.ParseConsistencyWrapper(consistencyStr)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf(
"Error parsing consistency value of %q: %v", consistencyStr, err)), nil
}
entry := &roleEntry{
Lease: lease,
CreationCQL: creationCQL,
RollbackCQL: rollbackCQL,
Consistency: consistencyStr,
}
// Store it
entryJSON, err := logical.StorageEntryJSON("role/"+name, entry)
if err != nil {
return nil, err
}
if err := req.Storage.Put(ctx, entryJSON); err != nil {
return nil, err
}
return nil, nil
}
type roleEntry struct {
CreationCQL string `json:"creation_cql" structs:"creation_cql"`
Lease time.Duration `json:"lease" structs:"lease"`
RollbackCQL string `json:"rollback_cql" structs:"rollback_cql"`
Consistency string `json:"consistency" structs:"consistency"`
}
const pathRoleHelpSyn = `
Manage the roles that can be created with this backend.
`
const pathRoleHelpDesc = `
This path lets you manage the roles that can be created with this backend.
The "creation_cql" parameter customizes the CQL string used to create users
and assign them grants. This can be a sequence of CQL queries separated by
semicolons. Some substitution will be done to the CQL string for certain keys.
The names of the variables must be surrounded by '{{' and '}}' to be replaced.
Note that it is important that single quotes are used, not double quotes.
* "username" - The random username generated for the DB user.
* "password" - The random password generated for the DB user.
If no "creation_cql" parameter is given, a default will be used:
` + defaultCreationCQL + `
This default should be suitable for Cassandra installations using the password
authenticator but not configured to use authorization.
Similarly, the "rollback_cql" is used if user creation fails, in the absence of
Cassandra transactions. The default should be suitable for almost any
instance of Cassandra:
` + defaultRollbackCQL + `
"lease" the lease time; if not set the mount/system defaults are used.
`

View File

@ -1,77 +0,0 @@
package cassandra
import (
"context"
"fmt"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical"
)
// SecretCredsType is the type of creds issued from this backend
const SecretCredsType = "cassandra"
func secretCreds(b *backend) *framework.Secret {
return &framework.Secret{
Type: SecretCredsType,
Fields: map[string]*framework.FieldSchema{
"username": {
Type: framework.TypeString,
Description: "Username",
},
"password": {
Type: framework.TypeString,
Description: "Password",
},
},
Renew: b.secretCredsRenew,
Revoke: b.secretCredsRevoke,
}
}
func (b *backend) secretCredsRenew(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
// Get the lease information
roleRaw, ok := req.Secret.InternalData["role"]
if !ok {
return nil, fmt.Errorf("secret is missing role internal data")
}
roleName, ok := roleRaw.(string)
if !ok {
return nil, fmt.Errorf("error converting role internal data to string")
}
role, err := getRole(ctx, req.Storage, roleName)
if err != nil {
return nil, fmt.Errorf("unable to load role: %w", err)
}
resp := &logical.Response{Secret: req.Secret}
resp.Secret.TTL = role.Lease
return resp, nil
}
func (b *backend) secretCredsRevoke(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
// Get the username from the internal data
usernameRaw, ok := req.Secret.InternalData["username"]
if !ok {
return nil, fmt.Errorf("secret is missing username internal data")
}
username, ok := usernameRaw.(string)
if !ok {
return nil, fmt.Errorf("error converting username internal data to string")
}
session, err := b.DB(ctx, req.Storage)
if err != nil {
return nil, fmt.Errorf("error getting session")
}
err = session.Query(fmt.Sprintf("DROP USER '%s'", username)).Exec()
if err != nil {
return nil, fmt.Errorf("error removing user %q", username)
}
return nil, nil
}

File diff suppressed because it is too large Load Diff

View File

@ -1,95 +0,0 @@
package cassandra
import (
"crypto/tls"
"fmt"
"strings"
"time"
"github.com/gocql/gocql"
"github.com/hashicorp/go-secure-stdlib/tlsutil"
"github.com/hashicorp/vault/sdk/helper/certutil"
"github.com/hashicorp/vault/sdk/logical"
)
// Query templates a query for us.
func substQuery(tpl string, data map[string]string) string {
for k, v := range data {
tpl = strings.ReplaceAll(tpl, fmt.Sprintf("{{%s}}", k), v)
}
return tpl
}
func createSession(cfg *sessionConfig, s logical.Storage) (*gocql.Session, error) {
clusterConfig := gocql.NewCluster(strings.Split(cfg.Hosts, ",")...)
clusterConfig.Authenticator = gocql.PasswordAuthenticator{
Username: cfg.Username,
Password: cfg.Password,
}
clusterConfig.ProtoVersion = cfg.ProtocolVersion
if clusterConfig.ProtoVersion == 0 {
clusterConfig.ProtoVersion = 2
}
clusterConfig.Timeout = time.Duration(cfg.ConnectTimeout) * time.Second
if cfg.TLS {
var tlsConfig *tls.Config
if len(cfg.Certificate) > 0 || len(cfg.IssuingCA) > 0 {
if len(cfg.Certificate) > 0 && len(cfg.PrivateKey) == 0 {
return nil, fmt.Errorf("found certificate for TLS authentication but no private key")
}
certBundle := &certutil.CertBundle{}
if len(cfg.Certificate) > 0 {
certBundle.Certificate = cfg.Certificate
certBundle.PrivateKey = cfg.PrivateKey
}
if len(cfg.IssuingCA) > 0 {
certBundle.IssuingCA = cfg.IssuingCA
}
parsedCertBundle, err := certBundle.ToParsedCertBundle()
if err != nil {
return nil, fmt.Errorf("failed to parse certificate bundle: %w", err)
}
tlsConfig, err = parsedCertBundle.GetTLSConfig(certutil.TLSClient)
if err != nil || tlsConfig == nil {
return nil, fmt.Errorf("failed to get TLS configuration: tlsConfig: %#v; %w", tlsConfig, err)
}
tlsConfig.InsecureSkipVerify = cfg.InsecureTLS
if cfg.TLSMinVersion != "" {
var ok bool
tlsConfig.MinVersion, ok = tlsutil.TLSLookup[cfg.TLSMinVersion]
if !ok {
return nil, fmt.Errorf("invalid 'tls_min_version' in config")
}
} else {
// MinVersion was not being set earlier. Reset it to
// zero to gracefully handle upgrades.
tlsConfig.MinVersion = 0
}
}
clusterConfig.SslOpts = &gocql.SslOptions{
Config: tlsConfig,
}
}
session, err := clusterConfig.CreateSession()
if err != nil {
return nil, fmt.Errorf("error creating session: %w", err)
}
// Verify the info
err = session.Query(`LIST USERS`).Exec()
if err != nil {
return nil, fmt.Errorf("error validating connection info: %w", err)
}
return session, nil
}

View File

@ -1,144 +0,0 @@
package mongodb
import (
"context"
"fmt"
"strings"
"sync"
"time"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical"
mgo "gopkg.in/mgo.v2"
)
func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
b := Backend()
if err := b.Setup(ctx, conf); err != nil {
return nil, err
}
return b, nil
}
func Backend() *framework.Backend {
var b backend
b.Backend = &framework.Backend{
Help: strings.TrimSpace(backendHelp),
PathsSpecial: &logical.Paths{
SealWrapStorage: []string{
"config/connection",
},
},
Paths: []*framework.Path{
pathConfigConnection(&b),
pathConfigLease(&b),
pathListRoles(&b),
pathRoles(&b),
pathCredsCreate(&b),
},
Secrets: []*framework.Secret{
secretCreds(&b),
},
Clean: b.ResetSession,
Invalidate: b.invalidate,
BackendType: logical.TypeLogical,
}
return b.Backend
}
type backend struct {
*framework.Backend
session *mgo.Session
lock sync.Mutex
}
// Session returns the database connection.
func (b *backend) Session(ctx context.Context, s logical.Storage) (*mgo.Session, error) {
b.lock.Lock()
defer b.lock.Unlock()
if b.session != nil {
if err := b.session.Ping(); err == nil {
return b.session, nil
}
b.session.Close()
}
connConfigJSON, err := s.Get(ctx, "config/connection")
if err != nil {
return nil, err
}
if connConfigJSON == nil {
return nil, fmt.Errorf("configure the MongoDB connection with config/connection first")
}
var connConfig connectionConfig
if err := connConfigJSON.DecodeJSON(&connConfig); err != nil {
return nil, err
}
dialInfo, err := parseMongoURI(connConfig.URI)
if err != nil {
return nil, err
}
b.session, err = mgo.DialWithInfo(dialInfo)
if err != nil {
return nil, err
}
b.session.SetSyncTimeout(1 * time.Minute)
b.session.SetSocketTimeout(1 * time.Minute)
return b.session, nil
}
// ResetSession forces creation of a new connection next time Session() is called.
func (b *backend) ResetSession(_ context.Context) {
b.lock.Lock()
defer b.lock.Unlock()
if b.session != nil {
b.session.Close()
}
b.session = nil
}
func (b *backend) invalidate(ctx context.Context, key string) {
switch key {
case "config/connection":
b.ResetSession(ctx)
}
}
// LeaseConfig returns the lease configuration
func (b *backend) LeaseConfig(ctx context.Context, s logical.Storage) (*configLease, error) {
entry, err := s.Get(ctx, "config/lease")
if err != nil {
return nil, err
}
if entry == nil {
return nil, nil
}
var result configLease
if err := entry.DecodeJSON(&result); err != nil {
return nil, err
}
return &result, nil
}
const backendHelp = `
The mongodb backend dynamically generates MongoDB credentials.
After mounting this backend, configure it using the endpoints within
the "config/" path.
`

View File

@ -1,268 +0,0 @@
package mongodb
import (
"context"
"fmt"
"log"
"strings"
"sync"
"testing"
logicaltest "github.com/hashicorp/vault/helper/testhelpers/logical"
"github.com/hashicorp/vault/helper/testhelpers/mongodb"
"github.com/hashicorp/vault/sdk/logical"
"github.com/mitchellh/mapstructure"
)
var testImagePull sync.Once
func TestBackend_config_connection(t *testing.T) {
var resp *logical.Response
var err error
config := logical.TestBackendConfig()
config.StorageView = &logical.InmemStorage{}
b, err := Factory(context.Background(), config)
if err != nil {
t.Fatal(err)
}
configData := map[string]interface{}{
"uri": "sample_connection_uri",
"verify_connection": false,
}
configReq := &logical.Request{
Operation: logical.UpdateOperation,
Path: "config/connection",
Storage: config.StorageView,
Data: configData,
}
resp, err = b.HandleRequest(context.Background(), configReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%s resp:%#v\n", err, resp)
}
configReq.Operation = logical.ReadOperation
resp, err = b.HandleRequest(context.Background(), configReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%s resp:%#v\n", err, resp)
}
}
func TestBackend_basic(t *testing.T) {
config := logical.TestBackendConfig()
config.StorageView = &logical.InmemStorage{}
b, err := Factory(context.Background(), config)
if err != nil {
t.Fatal(err)
}
cleanup, connURI := mongodb.PrepareTestContainer(t, "5.0.10")
defer cleanup()
connData := map[string]interface{}{
"uri": connURI,
}
logicaltest.Test(t, logicaltest.TestCase{
LogicalBackend: b,
Steps: []logicaltest.TestStep{
testAccStepConfig(connData, false),
testAccStepRole(),
testAccStepReadCreds("web"),
},
})
}
func TestBackend_roleCrud(t *testing.T) {
config := logical.TestBackendConfig()
config.StorageView = &logical.InmemStorage{}
b, err := Factory(context.Background(), config)
if err != nil {
t.Fatal(err)
}
cleanup, connURI := mongodb.PrepareTestContainer(t, "5.0.10")
defer cleanup()
connData := map[string]interface{}{
"uri": connURI,
}
logicaltest.Test(t, logicaltest.TestCase{
LogicalBackend: b,
Steps: []logicaltest.TestStep{
testAccStepConfig(connData, false),
testAccStepRole(),
testAccStepReadRole("web", testDb, testMongoDBRoles),
testAccStepDeleteRole("web"),
testAccStepReadRole("web", "", ""),
},
})
}
func TestBackend_leaseWriteRead(t *testing.T) {
config := logical.TestBackendConfig()
config.StorageView = &logical.InmemStorage{}
b, err := Factory(context.Background(), config)
if err != nil {
t.Fatal(err)
}
cleanup, connURI := mongodb.PrepareTestContainer(t, "5.0.10")
defer cleanup()
connData := map[string]interface{}{
"uri": connURI,
}
logicaltest.Test(t, logicaltest.TestCase{
LogicalBackend: b,
Steps: []logicaltest.TestStep{
testAccStepConfig(connData, false),
testAccStepWriteLease(),
testAccStepReadLease(),
},
})
}
func testAccStepConfig(d map[string]interface{}, expectError bool) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.UpdateOperation,
Path: "config/connection",
Data: d,
ErrorOk: true,
Check: func(resp *logical.Response) error {
if expectError {
if resp.Data == nil {
return fmt.Errorf("data is nil")
}
var e struct {
Error string `mapstructure:"error"`
}
if err := mapstructure.Decode(resp.Data, &e); err != nil {
return err
}
if len(e.Error) == 0 {
return fmt.Errorf("expected error, but write succeeded")
}
return nil
} else if resp != nil && resp.IsError() {
return fmt.Errorf("got an error response: %v", resp.Error())
}
return nil
},
}
}
func testAccStepRole() logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.UpdateOperation,
Path: "roles/web",
Data: map[string]interface{}{
"db": testDb,
"roles": testMongoDBRoles,
},
}
}
func testAccStepDeleteRole(n string) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.DeleteOperation,
Path: "roles/" + n,
}
}
func testAccStepReadCreds(name string) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.ReadOperation,
Path: "creds/" + name,
Check: func(resp *logical.Response) error {
var d struct {
DB string `mapstructure:"db"`
Username string `mapstructure:"username"`
Password string `mapstructure:"password"`
}
if err := mapstructure.Decode(resp.Data, &d); err != nil {
return err
}
if d.DB == "" {
return fmt.Errorf("bad: %#v", resp)
}
if d.Username == "" {
return fmt.Errorf("bad: %#v", resp)
}
if !strings.HasPrefix(d.Username, "vault-root-") {
return fmt.Errorf("bad: %#v", resp)
}
if d.Password == "" {
return fmt.Errorf("bad: %#v", resp)
}
log.Printf("[WARN] Generated credentials: %v", d)
return nil
},
}
}
func testAccStepReadRole(name, db, mongoDBRoles string) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.ReadOperation,
Path: "roles/" + name,
Check: func(resp *logical.Response) error {
if resp == nil {
if db == "" && mongoDBRoles == "" {
return nil
}
return fmt.Errorf("bad: %#v", resp)
}
var d struct {
DB string `mapstructure:"db"`
MongoDBRoles string `mapstructure:"roles"`
}
if err := mapstructure.Decode(resp.Data, &d); err != nil {
return err
}
if d.DB != db {
return fmt.Errorf("bad: %#v", resp)
}
if d.MongoDBRoles != mongoDBRoles {
return fmt.Errorf("bad: %#v", resp)
}
return nil
},
}
}
func testAccStepWriteLease() logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.UpdateOperation,
Path: "config/lease",
Data: map[string]interface{}{
"ttl": "1h5m",
"max_ttl": "24h",
},
}
}
func testAccStepReadLease() logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.ReadOperation,
Path: "config/lease",
Check: func(resp *logical.Response) error {
if resp.Data["ttl"].(float64) != 3900 || resp.Data["max_ttl"].(float64) != 86400 {
return fmt.Errorf("bad: %#v", resp)
}
return nil
},
}
}
const (
testDb = "foo"
testMongoDBRoles = `["readWrite",{"role":"read","db":"bar"}]`
)

View File

@ -1,29 +0,0 @@
package main
import (
"os"
hclog "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/builtin/logical/mongodb"
"github.com/hashicorp/vault/sdk/plugin"
)
func main() {
apiClientMeta := &api.PluginAPIClientMeta{}
flags := apiClientMeta.FlagSet()
flags.Parse(os.Args[1:])
tlsConfig := apiClientMeta.GetTLSConfig()
tlsProviderFunc := api.VaultPluginTLSProvider(tlsConfig)
if err := plugin.Serve(&plugin.ServeOpts{
BackendFactoryFunc: mongodb.Factory,
TLSProviderFunc: tlsProviderFunc,
}); err != nil {
logger := hclog.New(&hclog.LoggerOptions{})
logger.Error("plugin shutting down", "error", err)
os.Exit(1)
}
}

View File

@ -1,112 +0,0 @@
package mongodb
import (
"context"
"fmt"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical"
mgo "gopkg.in/mgo.v2"
)
func pathConfigConnection(b *backend) *framework.Path {
return &framework.Path{
Pattern: "config/connection",
Fields: map[string]*framework.FieldSchema{
"uri": {
Type: framework.TypeString,
Description: "MongoDB standard connection string (URI)",
},
"verify_connection": {
Type: framework.TypeBool,
Default: true,
Description: `If set, uri is verified by actually connecting to the database`,
},
},
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.ReadOperation: b.pathConnectionRead,
logical.UpdateOperation: b.pathConnectionWrite,
},
HelpSynopsis: pathConfigConnectionHelpSyn,
HelpDescription: pathConfigConnectionHelpDesc,
}
}
// pathConnectionRead reads out the connection configuration
func (b *backend) pathConnectionRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
entry, err := req.Storage.Get(ctx, "config/connection")
if err != nil {
return nil, fmt.Errorf("failed to read connection configuration")
}
if entry == nil {
return nil, nil
}
return nil, nil
}
func (b *backend) pathConnectionWrite(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
uri := data.Get("uri").(string)
if uri == "" {
return logical.ErrorResponse("uri parameter is required"), nil
}
dialInfo, err := parseMongoURI(uri)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf("invalid uri: %s", err)), nil
}
// Don't check the config if verification is disabled
verifyConnection := data.Get("verify_connection").(bool)
if verifyConnection {
// Verify the config
session, err := mgo.DialWithInfo(dialInfo)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf(
"Error validating connection info: %s", err)), nil
}
defer session.Close()
if err := session.Ping(); err != nil {
return logical.ErrorResponse(fmt.Sprintf(
"Error validating connection info: %s", err)), nil
}
}
// Store it
entry, err := logical.StorageEntryJSON("config/connection", connectionConfig{
URI: uri,
})
if err != nil {
return nil, err
}
if err := req.Storage.Put(ctx, entry); err != nil {
return nil, err
}
// Reset the Session
b.ResetSession(ctx)
resp := &logical.Response{}
resp.AddWarning("Read access to this endpoint should be controlled via ACLs as it will return the connection URI as it is, including passwords, if any.")
return resp, nil
}
type connectionConfig struct {
URI string `json:"uri" structs:"uri" mapstructure:"uri"`
}
const pathConfigConnectionHelpSyn = `
Configure the connection string to talk to MongoDB.
`
const pathConfigConnectionHelpDesc = `
This path configures the standard connection string (URI) used to connect to MongoDB.
A MongoDB URI looks like:
"mongodb://[username:password@]host1[:port1][,host2[:port2],...[,hostN[:portN]]][/[database][?options]]"
See https://docs.mongodb.org/manual/reference/connection-string/ for detailed documentation of the URI format.
When configuring the connection string, the backend will verify its validity.
`

View File

@ -1,89 +0,0 @@
package mongodb
import (
"context"
"time"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical"
)
func pathConfigLease(b *backend) *framework.Path {
return &framework.Path{
Pattern: "config/lease",
Fields: map[string]*framework.FieldSchema{
"ttl": {
Type: framework.TypeDurationSecond,
Description: "Default ttl for credentials.",
},
"max_ttl": {
Type: framework.TypeDurationSecond,
Description: "Maximum time a set of credentials can be valid for.",
},
},
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.ReadOperation: b.pathConfigLeaseRead,
logical.UpdateOperation: b.pathConfigLeaseWrite,
},
HelpSynopsis: pathConfigLeaseHelpSyn,
HelpDescription: pathConfigLeaseHelpDesc,
}
}
func (b *backend) pathConfigLeaseWrite(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
entry, err := logical.StorageEntryJSON("config/lease", &configLease{
TTL: time.Second * time.Duration(d.Get("ttl").(int)),
MaxTTL: time.Second * time.Duration(d.Get("max_ttl").(int)),
})
if err != nil {
return nil, err
}
if err := req.Storage.Put(ctx, entry); err != nil {
return nil, err
}
return nil, nil
}
func (b *backend) pathConfigLeaseRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
leaseConfig, err := b.LeaseConfig(ctx, req.Storage)
if err != nil {
return nil, err
}
if leaseConfig == nil {
return nil, nil
}
return &logical.Response{
Data: map[string]interface{}{
"ttl": leaseConfig.TTL.Seconds(),
"max_ttl": leaseConfig.MaxTTL.Seconds(),
},
}, nil
}
type configLease struct {
TTL time.Duration
MaxTTL time.Duration
}
const pathConfigLeaseHelpSyn = `
Configure the default lease TTL settings for credentials
generated by the mongodb backend.
`
const pathConfigLeaseHelpDesc = `
This configures the default lease TTL settings used for
credentials generated by this backend. The ttl specifies the
duration that a set of credentials will be valid for before
the lease must be renewed (if it is renewable), while the
max_ttl specifies the overall maximum duration that the
credentials will be valid regardless of lease renewals.
The format for the TTL values is an integer and then unit. For
example, the value "1h" specifies a 1-hour TTL. The longest
supported unit is hours.
`

View File

@ -1,119 +0,0 @@
package mongodb
import (
"context"
"fmt"
uuid "github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical"
)
func pathCredsCreate(b *backend) *framework.Path {
return &framework.Path{
Pattern: "creds/" + framework.GenericNameRegex("name"),
Fields: map[string]*framework.FieldSchema{
"name": {
Type: framework.TypeString,
Description: "Name of the role to generate credentials for.",
},
},
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.ReadOperation: b.pathCredsCreateRead,
},
HelpSynopsis: pathCredsCreateReadHelpSyn,
HelpDescription: pathCredsCreateReadHelpDesc,
}
}
func (b *backend) pathCredsCreateRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
name := data.Get("name").(string)
// Get the role
role, err := b.Role(ctx, req.Storage, name)
if err != nil {
return nil, err
}
if role == nil {
return logical.ErrorResponse(fmt.Sprintf("unknown role: %s", name)), nil
}
// Determine if we have a lease configuration
leaseConfig, err := b.LeaseConfig(ctx, req.Storage)
if err != nil {
return nil, err
}
if leaseConfig == nil {
leaseConfig = &configLease{}
}
// Generate the username and password
displayName := req.DisplayName
if displayName != "" {
displayName += "-"
}
userUUID, err := uuid.GenerateUUID()
if err != nil {
return nil, err
}
username := fmt.Sprintf("vault-%s%s", displayName, userUUID)
password, err := uuid.GenerateUUID()
if err != nil {
return nil, err
}
// Build the user creation command
createUserCmd := createUserCommand{
Username: username,
Password: password,
Roles: role.MongoDBRoles.toStandardRolesArray(),
}
// Get our connection
session, err := b.Session(ctx, req.Storage)
if err != nil {
return nil, err
}
// Create the user
err = session.DB(role.DB).Run(createUserCmd, nil)
if err != nil {
return nil, err
}
// Return the secret
resp := b.Secret(SecretCredsType).Response(map[string]interface{}{
"db": role.DB,
"username": username,
"password": password,
}, map[string]interface{}{
"username": username,
"db": role.DB,
})
resp.Secret.TTL = leaseConfig.TTL
resp.Secret.MaxTTL = leaseConfig.MaxTTL
return resp, nil
}
type createUserCommand struct {
Username string `bson:"createUser"`
Password string `bson:"pwd"`
Roles []interface{} `bson:"roles"`
}
const pathCredsCreateReadHelpSyn = `
Request MongoDB database credentials for a particular role.
`
const pathCredsCreateReadHelpDesc = `
This path reads generates MongoDB database credentials for
a particular role. The database credentials will be
generated on demand and will be automatically revoked when
the lease is up.
`

View File

@ -1,224 +0,0 @@
package mongodb
import (
"context"
"encoding/json"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical"
)
func pathListRoles(b *backend) *framework.Path {
return &framework.Path{
Pattern: "roles/?$",
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.ListOperation: b.pathRoleList,
},
HelpSynopsis: pathRoleHelpSyn,
HelpDescription: pathRoleHelpDesc,
}
}
func pathRoles(b *backend) *framework.Path {
return &framework.Path{
Pattern: "roles/" + framework.GenericNameRegex("name"),
Fields: map[string]*framework.FieldSchema{
"name": {
Type: framework.TypeString,
Description: "Name of the role.",
},
"db": {
Type: framework.TypeString,
Description: "Name of the authentication database for users generated for this role.",
},
"roles": {
Type: framework.TypeString,
Description: "MongoDB roles to assign to the users generated for this role.",
},
},
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.ReadOperation: b.pathRoleRead,
logical.UpdateOperation: b.pathRoleCreate,
logical.DeleteOperation: b.pathRoleDelete,
},
HelpSynopsis: pathRoleHelpSyn,
HelpDescription: pathRoleHelpDesc,
}
}
func (b *backend) Role(ctx context.Context, s logical.Storage, n string) (*roleStorageEntry, error) {
entry, err := s.Get(ctx, "role/"+n)
if err != nil {
return nil, err
}
if entry == nil {
return nil, nil
}
var result roleStorageEntry
if err := entry.DecodeJSON(&result); err != nil {
return nil, err
}
return &result, nil
}
func (b *backend) pathRoleDelete(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
err := req.Storage.Delete(ctx, "role/"+data.Get("name").(string))
if err != nil {
return nil, err
}
return nil, nil
}
func (b *backend) pathRoleRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
role, err := b.Role(ctx, req.Storage, data.Get("name").(string))
if err != nil {
return nil, err
}
if role == nil {
return nil, nil
}
rolesJsonBytes, err := json.Marshal(role.MongoDBRoles.toStandardRolesArray())
if err != nil {
return nil, err
}
return &logical.Response{
Data: map[string]interface{}{
"db": role.DB,
"roles": string(rolesJsonBytes),
},
}, nil
}
func (b *backend) pathRoleList(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
entries, err := req.Storage.List(ctx, "role/")
if err != nil {
return nil, err
}
return logical.ListResponse(entries), nil
}
func (b *backend) pathRoleCreate(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
name := data.Get("name").(string)
if name == "" {
return logical.ErrorResponse("Missing name"), nil
}
roleDB := data.Get("db").(string)
if roleDB == "" {
return logical.ErrorResponse("db parameter is required"), nil
}
// Example roles JSON:
//
// [ "readWrite", { "role": "readWrite", "db": "test" } ]
//
// For storage, we convert such an array into a homogeneous array of role documents like:
//
// [ { "role": "readWrite" }, { "role": "readWrite", "db": "test" } ]
//
var roles []mongodbRole
rolesJson := []byte(data.Get("roles").(string))
if len(rolesJson) > 0 {
var rolesArray []interface{}
err := json.Unmarshal(rolesJson, &rolesArray)
if err != nil {
return nil, err
}
for _, rawRole := range rolesArray {
switch role := rawRole.(type) {
case string:
roles = append(roles, mongodbRole{Role: role})
case map[string]interface{}:
if db, ok := role["db"].(string); ok {
if roleName, ok := role["role"].(string); ok {
roles = append(roles, mongodbRole{Role: roleName, DB: db})
}
}
}
}
}
// Store it
entry, err := logical.StorageEntryJSON("role/"+name, &roleStorageEntry{
DB: roleDB,
MongoDBRoles: roles,
})
if err != nil {
return nil, err
}
if err := req.Storage.Put(ctx, entry); err != nil {
return nil, err
}
return nil, nil
}
func (roles mongodbRoles) toStandardRolesArray() []interface{} {
// Convert array of role documents like:
//
// [ { "role": "readWrite" }, { "role": "readWrite", "db": "test" } ]
//
// into a "standard" MongoDB roles array containing both strings and role documents:
//
// [ "readWrite", { "role": "readWrite", "db": "test" } ]
//
// MongoDB's createUser command accepts the latter.
//
var standardRolesArray []interface{}
for _, role := range roles {
if role.DB == "" {
standardRolesArray = append(standardRolesArray, role.Role)
} else {
standardRolesArray = append(standardRolesArray, role)
}
}
return standardRolesArray
}
type roleStorageEntry struct {
DB string `json:"db"`
MongoDBRoles mongodbRoles `json:"roles"`
}
type mongodbRole struct {
Role string `json:"role" bson:"role"`
DB string `json:"db" bson:"db"`
}
type mongodbRoles []mongodbRole
const pathRoleHelpSyn = `
Manage the roles used to generate MongoDB credentials.
`
const pathRoleHelpDesc = `
This path lets you manage the roles used to generate MongoDB credentials.
The "db" parameter specifies the authentication database for users
generated for a given role.
The "roles" parameter specifies the MongoDB roles that should be assigned
to users created for a given role. Just like when creating a user directly
using db.createUser, the roles JSON array can specify both built-in roles
and user-defined roles for both the database the user is created in and
for other databases.
For example, the following roles JSON array grants the "readWrite"
permission on both the user's authentication database and the "test"
database:
[ "readWrite", { "role": "readWrite", "db": "test" } ]
Please consult the MongoDB documentation for more
details on Role-Based Access Control in MongoDB.
`

View File

@ -1,84 +0,0 @@
package mongodb
import (
"context"
"fmt"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical"
mgo "gopkg.in/mgo.v2"
)
const SecretCredsType = "creds"
func secretCreds(b *backend) *framework.Secret {
return &framework.Secret{
Type: SecretCredsType,
Fields: map[string]*framework.FieldSchema{
"username": {
Type: framework.TypeString,
Description: "Username",
},
"password": {
Type: framework.TypeString,
Description: "Password",
},
},
Renew: b.secretCredsRenew,
Revoke: b.secretCredsRevoke,
}
}
func (b *backend) secretCredsRenew(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
// Get the lease information
leaseConfig, err := b.LeaseConfig(ctx, req.Storage)
if err != nil {
return nil, err
}
if leaseConfig == nil {
leaseConfig = &configLease{}
}
resp := &logical.Response{Secret: req.Secret}
resp.Secret.TTL = leaseConfig.TTL
resp.Secret.MaxTTL = leaseConfig.MaxTTL
return resp, nil
}
func (b *backend) secretCredsRevoke(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
// Get the username from the internal data
usernameRaw, ok := req.Secret.InternalData["username"]
if !ok {
return nil, fmt.Errorf("secret is missing username internal data")
}
username, ok := usernameRaw.(string)
if !ok {
return nil, fmt.Errorf("username internal data is not a string")
}
// Get the db from the internal data
dbRaw, ok := req.Secret.InternalData["db"]
if !ok {
return nil, fmt.Errorf("secret is missing db internal data")
}
db, ok := dbRaw.(string)
if !ok {
return nil, fmt.Errorf("db internal data is not a string")
}
// Get our connection
session, err := b.Session(ctx, req.Storage)
if err != nil {
return nil, err
}
// Drop the user
err = session.DB(db).RemoveUser(username)
if err != nil && err != mgo.ErrNotFound {
return nil, err
}
return nil, nil
}

View File

@ -1,81 +0,0 @@
package mongodb
import (
"crypto/tls"
"errors"
"net"
"net/url"
"strconv"
"strings"
"time"
mgo "gopkg.in/mgo.v2"
)
// Unfortunately, mgo doesn't support the ssl parameter in its MongoDB URI parsing logic, so we have to handle that
// ourselves. See https://github.com/go-mgo/mgo/issues/84
func parseMongoURI(rawUri string) (*mgo.DialInfo, error) {
uri, err := url.Parse(rawUri)
if err != nil {
return nil, err
}
info := mgo.DialInfo{
Addrs: strings.Split(uri.Host, ","),
Database: strings.TrimPrefix(uri.Path, "/"),
Timeout: 10 * time.Second,
}
if uri.User != nil {
info.Username = uri.User.Username()
info.Password, _ = uri.User.Password()
}
query := uri.Query()
for key, values := range query {
var value string
if len(values) > 0 {
value = values[0]
}
switch key {
case "authSource":
info.Source = value
case "authMechanism":
info.Mechanism = value
case "gssapiServiceName":
info.Service = value
case "replicaSet":
info.ReplicaSetName = value
case "maxPoolSize":
poolLimit, err := strconv.Atoi(value)
if err != nil {
return nil, errors.New("bad value for maxPoolSize: " + value)
}
info.PoolLimit = poolLimit
case "ssl":
ssl, err := strconv.ParseBool(value)
if err != nil {
return nil, errors.New("bad value for ssl: " + value)
}
if ssl {
info.DialServer = func(addr *mgo.ServerAddr) (net.Conn, error) {
return tls.Dial("tcp", addr.String(), &tls.Config{})
}
}
case "connect":
if value == "direct" {
info.Direct = true
break
}
if value == "replicaSet" {
break
}
fallthrough
default:
return nil, errors.New("unsupported connection URL option: " + key + "=" + value)
}
}
return &info, nil
}

View File

@ -1,160 +0,0 @@
package mssql
import (
"context"
"database/sql"
"fmt"
"strings"
"sync"
_ "github.com/denisenkom/go-mssqldb"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical"
)
func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
b := Backend()
if err := b.Setup(ctx, conf); err != nil {
return nil, err
}
return b, nil
}
func Backend() *backend {
var b backend
b.Backend = &framework.Backend{
Help: strings.TrimSpace(backendHelp),
PathsSpecial: &logical.Paths{
SealWrapStorage: []string{
"config/connection",
},
},
Paths: []*framework.Path{
pathConfigConnection(&b),
pathConfigLease(&b),
pathListRoles(&b),
pathRoles(&b),
pathCredsCreate(&b),
},
Secrets: []*framework.Secret{
secretCreds(&b),
},
Invalidate: b.invalidate,
Clean: b.ResetDB,
BackendType: logical.TypeLogical,
}
return &b
}
type backend struct {
*framework.Backend
db *sql.DB
defaultDb string
lock sync.Mutex
}
// DB returns the default database connection.
func (b *backend) DB(ctx context.Context, s logical.Storage) (*sql.DB, error) {
b.lock.Lock()
defer b.lock.Unlock()
// If we already have a DB, we got it!
if b.db != nil {
if err := b.db.Ping(); err == nil {
return b.db, nil
}
// If the ping was unsuccessful, close it and ignore errors as we'll be
// reestablishing anyways
b.db.Close()
}
// Otherwise, attempt to make connection
entry, err := s.Get(ctx, "config/connection")
if err != nil {
return nil, err
}
if entry == nil {
return nil, fmt.Errorf("configure the DB connection with config/connection first")
}
var connConfig connectionConfig
if err := entry.DecodeJSON(&connConfig); err != nil {
return nil, err
}
connString := connConfig.ConnectionString
db, err := sql.Open("sqlserver", connString)
if err != nil {
return nil, err
}
// Set some connection pool settings. We don't need much of this,
// since the request rate shouldn't be high.
db.SetMaxOpenConns(connConfig.MaxOpenConnections)
stmt, err := db.Prepare("SELECT db_name();")
if err != nil {
return nil, err
}
defer stmt.Close()
err = stmt.QueryRow().Scan(&b.defaultDb)
if err != nil {
return nil, err
}
b.db = db
return b.db, nil
}
// ResetDB forces a connection next time DB() is called.
func (b *backend) ResetDB(_ context.Context) {
b.lock.Lock()
defer b.lock.Unlock()
if b.db != nil {
b.db.Close()
}
b.db = nil
}
func (b *backend) invalidate(ctx context.Context, key string) {
switch key {
case "config/connection":
b.ResetDB(ctx)
}
}
// LeaseConfig returns the lease configuration
func (b *backend) LeaseConfig(ctx context.Context, s logical.Storage) (*configLease, error) {
entry, err := s.Get(ctx, "config/lease")
if err != nil {
return nil, err
}
if entry == nil {
return nil, nil
}
var result configLease
if err := entry.DecodeJSON(&result); err != nil {
return nil, err
}
return &result, nil
}
const backendHelp = `
The MSSQL backend dynamically generates database users.
After mounting this backend, configure it using the endpoints within
the "config/" path.
This backend does not support Azure SQL Databases.
`

View File

@ -1,222 +0,0 @@
package mssql
import (
"context"
"fmt"
"log"
"reflect"
"testing"
_ "github.com/denisenkom/go-mssqldb"
logicaltest "github.com/hashicorp/vault/helper/testhelpers/logical"
mssqlhelper "github.com/hashicorp/vault/helper/testhelpers/mssql"
"github.com/hashicorp/vault/sdk/logical"
"github.com/mitchellh/mapstructure"
)
func Backend_config_connection(t *testing.T) {
var resp *logical.Response
var err error
config := logical.TestBackendConfig()
config.StorageView = &logical.InmemStorage{}
b, err := Factory(context.Background(), config)
if err != nil {
t.Fatal(err)
}
configData := map[string]interface{}{
"connection_string": "sample_connection_string",
"max_open_connections": 7,
"verify_connection": false,
}
configReq := &logical.Request{
Operation: logical.UpdateOperation,
Path: "config/connection",
Storage: config.StorageView,
Data: configData,
}
resp, err = b.HandleRequest(context.Background(), configReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%s resp:%#v\n", err, resp)
}
configReq.Operation = logical.ReadOperation
resp, err = b.HandleRequest(context.Background(), configReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%s resp:%#v\n", err, resp)
}
delete(configData, "verify_connection")
delete(configData, "connection_string")
if !reflect.DeepEqual(configData, resp.Data) {
t.Fatalf("bad: expected:%#v\nactual:%#v\n", configData, resp.Data)
}
}
func TestBackend_basic(t *testing.T) {
b, _ := Factory(context.Background(), logical.TestBackendConfig())
cleanup, connURL := mssqlhelper.PrepareMSSQLTestContainer(t)
defer cleanup()
logicaltest.Test(t, logicaltest.TestCase{
PreCheck: testAccPreCheckFunc(t, connURL),
LogicalBackend: b,
Steps: []logicaltest.TestStep{
testAccStepConfig(t, connURL),
testAccStepRole(t),
testAccStepReadCreds(t, "web"),
},
})
}
func TestBackend_roleCrud(t *testing.T) {
b := Backend()
cleanup, connURL := mssqlhelper.PrepareMSSQLTestContainer(t)
defer cleanup()
logicaltest.Test(t, logicaltest.TestCase{
PreCheck: testAccPreCheckFunc(t, connURL),
LogicalBackend: b,
Steps: []logicaltest.TestStep{
testAccStepConfig(t, connURL),
testAccStepRole(t),
testAccStepReadRole(t, "web", testRoleSQL),
testAccStepDeleteRole(t, "web"),
testAccStepReadRole(t, "web", ""),
},
})
}
func TestBackend_leaseWriteRead(t *testing.T) {
b := Backend()
cleanup, connURL := mssqlhelper.PrepareMSSQLTestContainer(t)
defer cleanup()
logicaltest.Test(t, logicaltest.TestCase{
PreCheck: testAccPreCheckFunc(t, connURL),
LogicalBackend: b,
Steps: []logicaltest.TestStep{
testAccStepConfig(t, connURL),
testAccStepWriteLease(t),
testAccStepReadLease(t),
},
})
}
func testAccPreCheckFunc(t *testing.T, connectionURL string) func() {
return func() {
if connectionURL == "" {
t.Fatal("connection URL must be set for acceptance tests")
}
}
}
func testAccStepConfig(t *testing.T, connURL string) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.UpdateOperation,
Path: "config/connection",
Data: map[string]interface{}{
"connection_string": connURL,
},
}
}
func testAccStepRole(t *testing.T) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.UpdateOperation,
Path: "roles/web",
Data: map[string]interface{}{
"sql": testRoleSQL,
},
}
}
func testAccStepDeleteRole(t *testing.T, n string) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.DeleteOperation,
Path: "roles/" + n,
}
}
func testAccStepReadCreds(t *testing.T, name string) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.ReadOperation,
Path: "creds/" + name,
Check: func(resp *logical.Response) error {
var d struct {
Username string `mapstructure:"username"`
Password string `mapstructure:"password"`
}
if err := mapstructure.Decode(resp.Data, &d); err != nil {
return err
}
log.Printf("[WARN] Generated credentials: %v", d)
return nil
},
}
}
func testAccStepReadRole(t *testing.T, name, sql string) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.ReadOperation,
Path: "roles/" + name,
Check: func(resp *logical.Response) error {
if resp == nil {
if sql == "" {
return nil
}
return fmt.Errorf("bad: %#v", resp)
}
var d struct {
SQL string `mapstructure:"sql"`
}
if err := mapstructure.Decode(resp.Data, &d); err != nil {
return err
}
if d.SQL != sql {
return fmt.Errorf("bad: %#v", resp)
}
return nil
},
}
}
func testAccStepWriteLease(t *testing.T) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.UpdateOperation,
Path: "config/lease",
Data: map[string]interface{}{
"ttl": "1h5m",
"max_ttl": "24h",
},
}
}
func testAccStepReadLease(t *testing.T) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.ReadOperation,
Path: "config/lease",
Check: func(resp *logical.Response) error {
if resp.Data["ttl"] != "1h5m0s" || resp.Data["max_ttl"] != "24h0m0s" {
return fmt.Errorf("bad: %#v", resp)
}
return nil
},
}
}
const testRoleSQL = `
CREATE LOGIN [{{name}}] WITH PASSWORD = '{{password}}';
CREATE USER [{{name}}] FOR LOGIN [{{name}}];
GRANT SELECT ON SCHEMA::dbo TO [{{name}}]
`

View File

@ -1,29 +0,0 @@
package main
import (
"os"
hclog "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/builtin/logical/mssql"
"github.com/hashicorp/vault/sdk/plugin"
)
func main() {
apiClientMeta := &api.PluginAPIClientMeta{}
flags := apiClientMeta.FlagSet()
flags.Parse(os.Args[1:])
tlsConfig := apiClientMeta.GetTLSConfig()
tlsProviderFunc := api.VaultPluginTLSProvider(tlsConfig)
if err := plugin.Serve(&plugin.ServeOpts{
BackendFactoryFunc: mssql.Factory,
TLSProviderFunc: tlsProviderFunc,
}); err != nil {
logger := hclog.New(&hclog.LoggerOptions{})
logger.Error("plugin shutting down", "error", err)
os.Exit(1)
}
}

View File

@ -1,126 +0,0 @@
package mssql
import (
"context"
"database/sql"
"fmt"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical"
)
func pathConfigConnection(b *backend) *framework.Path {
return &framework.Path{
Pattern: "config/connection",
Fields: map[string]*framework.FieldSchema{
"connection_string": {
Type: framework.TypeString,
Description: "DB connection parameters",
},
"max_open_connections": {
Type: framework.TypeInt,
Description: "Maximum number of open connections to database",
},
"verify_connection": {
Type: framework.TypeBool,
Default: true,
Description: "If set, connection_string is verified by actually connecting to the database",
},
},
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.UpdateOperation: b.pathConnectionWrite,
logical.ReadOperation: b.pathConnectionRead,
},
HelpSynopsis: pathConfigConnectionHelpSyn,
HelpDescription: pathConfigConnectionHelpDesc,
}
}
// pathConnectionRead reads out the connection configuration
func (b *backend) pathConnectionRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
entry, err := req.Storage.Get(ctx, "config/connection")
if err != nil {
return nil, fmt.Errorf("failed to read connection configuration")
}
if entry == nil {
return nil, nil
}
var config connectionConfig
if err := entry.DecodeJSON(&config); err != nil {
return nil, err
}
return &logical.Response{
Data: map[string]interface{}{
"max_open_connections": config.MaxOpenConnections,
},
}, nil
}
// pathConnectionWrite stores the connection configuration
func (b *backend) pathConnectionWrite(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
connString := data.Get("connection_string").(string)
maxOpenConns := data.Get("max_open_connections").(int)
if maxOpenConns == 0 {
maxOpenConns = 2
}
// Don't check the connection_string if verification is disabled
verifyConnection := data.Get("verify_connection").(bool)
if verifyConnection {
// Verify the string
db, err := sql.Open("mssql", connString)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf(
"Error validating connection info: %s", err)), nil
}
defer db.Close()
if err := db.Ping(); err != nil {
return logical.ErrorResponse(fmt.Sprintf(
"Error validating connection info: %s", err)), nil
}
}
// Store it
entry, err := logical.StorageEntryJSON("config/connection", connectionConfig{
ConnectionString: connString,
MaxOpenConnections: maxOpenConns,
})
if err != nil {
return nil, err
}
if err := req.Storage.Put(ctx, entry); err != nil {
return nil, err
}
// Reset the DB connection
b.ResetDB(ctx)
resp := &logical.Response{}
resp.AddWarning("Read access to this endpoint should be controlled via ACLs as it will return the connection string as it is, including passwords, if any.")
return resp, nil
}
type connectionConfig struct {
ConnectionString string `json:"connection_string" structs:"connection_string" mapstructure:"connection_string"`
MaxOpenConnections int `json:"max_open_connections" structs:"max_open_connections" mapstructure:"max_open_connections"`
}
const pathConfigConnectionHelpSyn = `
Configure the connection string to talk to Microsoft Sql Server.
`
const pathConfigConnectionHelpDesc = `
This path configures the connection string used to connect to Sql Server.
The value of the string is a Data Source Name (DSN). An example is
using "server=<hostname>;port=<port>;user id=<username>;password=<password>;database=<database>;app name=vault;"
When configuring the connection string, the backend will verify its validity.
If the database is not available when setting the connection string, set the
"verify_connection" option to false.
`

View File

@ -1,114 +0,0 @@
package mssql
import (
"context"
"fmt"
"time"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical"
)
func pathConfigLease(b *backend) *framework.Path {
return &framework.Path{
Pattern: "config/lease",
Fields: map[string]*framework.FieldSchema{
"ttl": {
Type: framework.TypeString,
Description: "Default ttl for roles.",
},
"ttl_max": {
Type: framework.TypeString,
Description: `Deprecated: use "max_ttl" instead. Maximum
time a credential is valid for.`,
},
"max_ttl": {
Type: framework.TypeString,
Description: "Maximum time a credential is valid for.",
},
},
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.ReadOperation: b.pathConfigLeaseRead,
logical.UpdateOperation: b.pathConfigLeaseWrite,
},
HelpSynopsis: pathConfigLeaseHelpSyn,
HelpDescription: pathConfigLeaseHelpDesc,
}
}
func (b *backend) pathConfigLeaseWrite(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
ttlRaw := d.Get("ttl").(string)
ttlMaxRaw := d.Get("max_ttl").(string)
if len(ttlMaxRaw) == 0 {
ttlMaxRaw = d.Get("ttl_max").(string)
}
ttl, err := time.ParseDuration(ttlRaw)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf(
"Invalid ttl: %s", err)), nil
}
ttlMax, err := time.ParseDuration(ttlMaxRaw)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf(
"Invalid max_ttl: %s", err)), nil
}
// Store it
entry, err := logical.StorageEntryJSON("config/lease", &configLease{
TTL: ttl,
TTLMax: ttlMax,
})
if err != nil {
return nil, err
}
if err := req.Storage.Put(ctx, entry); err != nil {
return nil, err
}
return nil, nil
}
func (b *backend) pathConfigLeaseRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
leaseConfig, err := b.LeaseConfig(ctx, req.Storage)
if err != nil {
return nil, err
}
if leaseConfig == nil {
return nil, nil
}
resp := &logical.Response{
Data: map[string]interface{}{
"ttl": leaseConfig.TTL.String(),
"ttl_max": leaseConfig.TTLMax.String(),
"max_ttl": leaseConfig.TTLMax.String(),
},
}
resp.AddWarning("The field ttl_max is deprecated and will be removed in a future release. Use max_ttl instead.")
return resp, nil
}
type configLease struct {
TTL time.Duration
TTLMax time.Duration
}
const pathConfigLeaseHelpSyn = `
Configure the default lease ttl for generated credentials.
`
const pathConfigLeaseHelpDesc = `
This configures the default lease ttl used for credentials
generated by this backend. The ttl specifies the duration that a
credential will be valid for, as well as the maximum session for
a set of credentials.
The format for the ttl is "1h" or integer and then unit. The longest
unit is hour.
`

View File

@ -1,129 +0,0 @@
package mssql
import (
"context"
"fmt"
"strings"
"github.com/hashicorp/go-secure-stdlib/strutil"
uuid "github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/dbtxn"
"github.com/hashicorp/vault/sdk/logical"
)
func pathCredsCreate(b *backend) *framework.Path {
return &framework.Path{
Pattern: "creds/" + framework.GenericNameRegex("name"),
Fields: map[string]*framework.FieldSchema{
"name": {
Type: framework.TypeString,
Description: "Name of the role.",
},
},
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.ReadOperation: b.pathCredsCreateRead,
},
HelpSynopsis: pathCredsCreateHelpSyn,
HelpDescription: pathCredsCreateHelpDesc,
}
}
func (b *backend) pathCredsCreateRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
name := data.Get("name").(string)
// Get the role
role, err := b.Role(ctx, req.Storage, name)
if err != nil {
return nil, err
}
if role == nil {
return logical.ErrorResponse(fmt.Sprintf("unknown role: %s", name)), nil
}
// Determine if we have a lease configuration
leaseConfig, err := b.LeaseConfig(ctx, req.Storage)
if err != nil {
return nil, err
}
if leaseConfig == nil {
leaseConfig = &configLease{}
}
// Generate our username and password
displayName := req.DisplayName
if len(displayName) > 10 {
displayName = displayName[:10]
}
userUUID, err := uuid.GenerateUUID()
if err != nil {
return nil, err
}
username := fmt.Sprintf("%s-%s", displayName, userUUID)
password, err := uuid.GenerateUUID()
if err != nil {
return nil, err
}
// Get our handle
db, err := b.DB(ctx, req.Storage)
if err != nil {
return nil, err
}
// Start a transaction
tx, err := db.Begin()
if err != nil {
return nil, err
}
defer tx.Rollback()
// Always reset database to default db of connection. Since it is in a
// transaction, all statements will be on the same connection in the pool.
roleSQL := fmt.Sprintf("USE [%s]; %s", b.defaultDb, role.SQL)
// Execute each query
for _, query := range strutil.ParseArbitraryStringSlice(roleSQL, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
}
m := map[string]string{
"name": username,
"password": password,
}
if err := dbtxn.ExecuteTxQueryDirect(ctx, tx, m, query); err != nil {
return nil, err
}
}
// Commit the transaction
if err := tx.Commit(); err != nil {
return nil, err
}
// Return the secret
resp := b.Secret(SecretCredsType).Response(map[string]interface{}{
"username": username,
"password": password,
}, map[string]interface{}{
"username": username,
})
resp.Secret.TTL = leaseConfig.TTL
resp.Secret.MaxTTL = leaseConfig.TTLMax
return resp, nil
}
const pathCredsCreateHelpSyn = `
Request database credentials for a certain role.
`
const pathCredsCreateHelpDesc = `
This path reads database credentials for a certain role. The
database credentials will be generated on demand and will be automatically
revoked when the lease is up.
`

View File

@ -1,172 +0,0 @@
package mssql
import (
"context"
"fmt"
"strings"
"github.com/hashicorp/go-secure-stdlib/strutil"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical"
)
func pathListRoles(b *backend) *framework.Path {
return &framework.Path{
Pattern: "roles/?$",
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.ListOperation: b.pathRoleList,
},
HelpSynopsis: pathRoleHelpSyn,
HelpDescription: pathRoleHelpDesc,
}
}
func pathRoles(b *backend) *framework.Path {
return &framework.Path{
Pattern: "roles/" + framework.GenericNameRegex("name"),
Fields: map[string]*framework.FieldSchema{
"name": {
Type: framework.TypeString,
Description: "Name of the role.",
},
"sql": {
Type: framework.TypeString,
Description: "SQL string to create a role. See help for more info.",
},
},
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.ReadOperation: b.pathRoleRead,
logical.UpdateOperation: b.pathRoleCreate,
logical.DeleteOperation: b.pathRoleDelete,
},
HelpSynopsis: pathRoleHelpSyn,
HelpDescription: pathRoleHelpDesc,
}
}
func (b *backend) Role(ctx context.Context, s logical.Storage, n string) (*roleEntry, error) {
entry, err := s.Get(ctx, "role/"+n)
if err != nil {
return nil, err
}
if entry == nil {
return nil, nil
}
var result roleEntry
if err := entry.DecodeJSON(&result); err != nil {
return nil, err
}
return &result, nil
}
func (b *backend) pathRoleDelete(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
err := req.Storage.Delete(ctx, "role/"+data.Get("name").(string))
if err != nil {
return nil, err
}
return nil, nil
}
func (b *backend) pathRoleRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
role, err := b.Role(ctx, req.Storage, data.Get("name").(string))
if err != nil {
return nil, err
}
if role == nil {
return nil, nil
}
return &logical.Response{
Data: map[string]interface{}{
"sql": role.SQL,
},
}, nil
}
func (b *backend) pathRoleList(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
entries, err := req.Storage.List(ctx, "role/")
if err != nil {
return nil, err
}
return logical.ListResponse(entries), nil
}
func (b *backend) pathRoleCreate(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
name := data.Get("name").(string)
sql := data.Get("sql").(string)
// Get our connection
db, err := b.DB(ctx, req.Storage)
if err != nil {
return nil, err
}
// Test the query by trying to prepare it
for _, query := range strutil.ParseArbitraryStringSlice(sql, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
}
stmt, err := db.Prepare(Query(query, map[string]string{
"name": "foo",
"password": "bar",
}))
if err != nil {
return logical.ErrorResponse(fmt.Sprintf(
"Error testing query: %s", err)), nil
}
stmt.Close()
}
// Store it
entry, err := logical.StorageEntryJSON("role/"+name, &roleEntry{
SQL: sql,
})
if err != nil {
return nil, err
}
if err := req.Storage.Put(ctx, entry); err != nil {
return nil, err
}
return nil, nil
}
type roleEntry struct {
SQL string `json:"sql"`
}
const pathRoleHelpSyn = `
Manage the roles that can be created with this backend.
`
const pathRoleHelpDesc = `
This path lets you manage the roles that can be created with this backend.
The "sql" parameter customizes the SQL string used to create the login to
the server. The parameter can be a sequence of SQL queries, each semi-colon
separated. Some substitution will be done to the SQL string for certain keys.
The names of the variables must be surrounded by "{{" and "}}" to be replaced.
* "name" - The random username generated for the DB user.
* "password" - The random password generated for the DB user.
Example SQL query to use:
CREATE LOGIN [{{name}}] WITH PASSWORD = '{{password}}';
CREATE USER [{{name}}] FROM LOGIN [{{name}}];
GRANT SELECT, UPDATE, DELETE, INSERT on SCHEMA::dbo TO [{{name}}];
Please see the Microsoft SQL Server manual on the GRANT command to learn how to
do more fine grained access.
`

View File

@ -1,180 +0,0 @@
package mssql
import (
"context"
"database/sql"
"fmt"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/dbtxn"
"github.com/hashicorp/vault/sdk/logical"
)
const SecretCredsType = "creds"
func secretCreds(b *backend) *framework.Secret {
return &framework.Secret{
Type: SecretCredsType,
Fields: map[string]*framework.FieldSchema{
"username": {
Type: framework.TypeString,
Description: "Username",
},
"password": {
Type: framework.TypeString,
Description: "Password",
},
},
Renew: b.secretCredsRenew,
Revoke: b.secretCredsRevoke,
}
}
func (b *backend) secretCredsRenew(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
// Get the lease information
leaseConfig, err := b.LeaseConfig(ctx, req.Storage)
if err != nil {
return nil, err
}
if leaseConfig == nil {
leaseConfig = &configLease{}
}
resp := &logical.Response{Secret: req.Secret}
resp.Secret.TTL = leaseConfig.TTL
resp.Secret.MaxTTL = leaseConfig.TTLMax
return resp, nil
}
func (b *backend) secretCredsRevoke(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
// Get the username from the internal data
usernameRaw, ok := req.Secret.InternalData["username"]
if !ok {
return nil, fmt.Errorf("secret is missing username internal data")
}
username, ok := usernameRaw.(string)
// Get our connection
db, err := b.DB(ctx, req.Storage)
if err != nil {
return nil, err
}
// First disable server login
disableStmt, err := db.Prepare(fmt.Sprintf("ALTER LOGIN [%s] DISABLE;", username))
if err != nil {
return nil, err
}
defer disableStmt.Close()
if _, err := disableStmt.Exec(); err != nil {
return nil, err
}
// Query for sessions for the login so that we can kill any outstanding
// sessions. There cannot be any active sessions before we drop the logins
// This isn't done in a transaction because even if we fail along the way,
// we want to remove as much access as possible
sessionStmt, err := db.Prepare("SELECT session_id FROM sys.dm_exec_sessions WHERE login_name = @p1;")
if err != nil {
return nil, err
}
defer sessionStmt.Close()
sessionRows, err := sessionStmt.Query(username)
if err != nil {
return nil, err
}
defer sessionRows.Close()
var revokeStmts []string
for sessionRows.Next() {
var sessionID int
err = sessionRows.Scan(&sessionID)
if err != nil {
return nil, err
}
revokeStmts = append(revokeStmts, fmt.Sprintf("KILL %d;", sessionID))
}
// Query for database users using undocumented stored procedure for now since
// it is the easiest way to get this information;
// we need to drop the database users before we can drop the login and the role
// This isn't done in a transaction because even if we fail along the way,
// we want to remove as much access as possible
stmt, err := db.Prepare("EXEC master.dbo.sp_msloginmappings @p1;")
if err != nil {
return nil, err
}
defer stmt.Close()
rows, err := stmt.Query(username)
if err != nil {
return nil, err
}
defer rows.Close()
for rows.Next() {
var loginName, dbName, qUsername, aliasName sql.NullString
err = rows.Scan(&loginName, &dbName, &qUsername, &aliasName)
if err != nil {
return nil, err
}
if !dbName.Valid {
continue
}
revokeStmts = append(revokeStmts, fmt.Sprintf(dropUserSQL, dbName.String, username, username))
}
// we do not stop on error, as we want to remove as
// many permissions as possible right now
var lastStmtError error
for _, query := range revokeStmts {
if err := dbtxn.ExecuteDBQueryDirect(ctx, db, nil, query); err != nil {
lastStmtError = err
continue
}
}
// can't drop if not all database users are dropped
if rows.Err() != nil {
return nil, fmt.Errorf("could not generate sql statements for all rows: %w", rows.Err())
}
if lastStmtError != nil {
return nil, fmt.Errorf("could not perform all sql statements: %w", lastStmtError)
}
// Drop this login
stmt, err = db.Prepare(fmt.Sprintf(dropLoginSQL, username, username))
if err != nil {
return nil, err
}
defer stmt.Close()
if _, err := stmt.Exec(); err != nil {
return nil, err
}
return nil, nil
}
const dropUserSQL = `
USE [%s]
IF EXISTS
(SELECT name
FROM sys.database_principals
WHERE name = N'%s')
BEGIN
DROP USER [%s]
END
`
const dropLoginSQL = `
IF EXISTS
(SELECT name
FROM master.sys.server_principals
WHERE name = N'%s')
BEGIN
DROP LOGIN [%s]
END
`

View File

@ -1,28 +0,0 @@
package mssql
import (
"fmt"
"strings"
)
// SplitSQL is used to split a series of SQL statements
func SplitSQL(sql string) []string {
parts := strings.Split(sql, ";")
out := make([]string, 0, len(parts))
for _, p := range parts {
clean := strings.TrimSpace(p)
if len(clean) > 0 {
out = append(out, clean)
}
}
return out
}
// Query templates a query for us.
func Query(tpl string, data map[string]string) string {
for k, v := range data {
tpl = strings.ReplaceAll(tpl, fmt.Sprintf("{{%s}}", k), v)
}
return tpl
}

View File

@ -1,151 +0,0 @@
package mysql
import (
"context"
"database/sql"
"fmt"
"strings"
"sync"
_ "github.com/go-sql-driver/mysql"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical"
)
func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
b := Backend()
if err := b.Setup(ctx, conf); err != nil {
return nil, err
}
return b, nil
}
func Backend() *backend {
var b backend
b.Backend = &framework.Backend{
Help: strings.TrimSpace(backendHelp),
PathsSpecial: &logical.Paths{
SealWrapStorage: []string{
"config/connection",
},
},
Paths: []*framework.Path{
pathConfigConnection(&b),
pathConfigLease(&b),
pathListRoles(&b),
pathRoles(&b),
pathRoleCreate(&b),
},
Secrets: []*framework.Secret{
secretCreds(&b),
},
Invalidate: b.invalidate,
Clean: b.ResetDB,
BackendType: logical.TypeLogical,
}
return &b
}
type backend struct {
*framework.Backend
db *sql.DB
lock sync.Mutex
}
// DB returns the database connection.
func (b *backend) DB(ctx context.Context, s logical.Storage) (*sql.DB, error) {
b.lock.Lock()
defer b.lock.Unlock()
// If we already have a DB, we got it!
if b.db != nil {
if err := b.db.Ping(); err == nil {
return b.db, nil
}
// If the ping was unsuccessful, close it and ignore errors as we'll be
// reestablishing anyways
b.db.Close()
}
// Otherwise, attempt to make connection
entry, err := s.Get(ctx, "config/connection")
if err != nil {
return nil, err
}
if entry == nil {
return nil,
fmt.Errorf("configure the DB connection with config/connection first")
}
var connConfig connectionConfig
if err := entry.DecodeJSON(&connConfig); err != nil {
return nil, err
}
conn := connConfig.ConnectionURL
if len(conn) == 0 {
conn = connConfig.ConnectionString
}
b.db, err = sql.Open("mysql", conn)
if err != nil {
return nil, err
}
// Set some connection pool settings. We don't need much of this,
// since the request rate shouldn't be high.
b.db.SetMaxOpenConns(connConfig.MaxOpenConnections)
b.db.SetMaxIdleConns(connConfig.MaxIdleConnections)
return b.db, nil
}
// ResetDB forces a connection next time DB() is called.
func (b *backend) ResetDB(_ context.Context) {
b.lock.Lock()
defer b.lock.Unlock()
if b.db != nil {
b.db.Close()
}
b.db = nil
}
func (b *backend) invalidate(ctx context.Context, key string) {
switch key {
case "config/connection":
b.ResetDB(ctx)
}
}
// Lease returns the lease information
func (b *backend) Lease(ctx context.Context, s logical.Storage) (*configLease, error) {
entry, err := s.Get(ctx, "config/lease")
if err != nil {
return nil, err
}
if entry == nil {
return nil, nil
}
var result configLease
if err := entry.DecodeJSON(&result); err != nil {
return nil, err
}
return &result, nil
}
const backendHelp = `
The MySQL backend dynamically generates database users.
After mounting this backend, configure it using the endpoints within
the "config/" path.
`

View File

@ -1,307 +0,0 @@
package mysql
import (
"context"
"fmt"
"log"
"reflect"
"testing"
logicaltest "github.com/hashicorp/vault/helper/testhelpers/logical"
mysqlhelper "github.com/hashicorp/vault/helper/testhelpers/mysql"
"github.com/hashicorp/vault/sdk/logical"
"github.com/mitchellh/mapstructure"
)
func TestBackend_config_connection(t *testing.T) {
var resp *logical.Response
var err error
config := logical.TestBackendConfig()
config.StorageView = &logical.InmemStorage{}
b, err := Factory(context.Background(), config)
if err != nil {
t.Fatal(err)
}
configData := map[string]interface{}{
"connection_url": "sample_connection_url",
"max_open_connections": 9,
"max_idle_connections": 7,
"verify_connection": false,
}
configReq := &logical.Request{
Operation: logical.UpdateOperation,
Path: "config/connection",
Storage: config.StorageView,
Data: configData,
}
resp, err = b.HandleRequest(context.Background(), configReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%s resp:%#v\n", err, resp)
}
configReq.Operation = logical.ReadOperation
resp, err = b.HandleRequest(context.Background(), configReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%s resp:%#v\n", err, resp)
}
delete(configData, "verify_connection")
delete(configData, "connection_url")
if !reflect.DeepEqual(configData, resp.Data) {
t.Fatalf("bad: expected:%#v\nactual:%#v\n", configData, resp.Data)
}
}
func TestBackend_basic(t *testing.T) {
config := logical.TestBackendConfig()
config.StorageView = &logical.InmemStorage{}
b, err := Factory(context.Background(), config)
if err != nil {
t.Fatal(err)
}
cleanup, connURL := mysqlhelper.PrepareTestContainer(t, false, "secret")
defer cleanup()
connData := map[string]interface{}{
"connection_url": connURL,
}
// for wildcard based mysql user
logicaltest.Test(t, logicaltest.TestCase{
LogicalBackend: b,
Steps: []logicaltest.TestStep{
testAccStepConfig(t, connData, false),
testAccStepRole(t, true),
testAccStepReadCreds(t, "web"),
},
})
}
func TestBackend_basicHostRevoke(t *testing.T) {
config := logical.TestBackendConfig()
config.StorageView = &logical.InmemStorage{}
b, err := Factory(context.Background(), config)
if err != nil {
t.Fatal(err)
}
cleanup, connURL := mysqlhelper.PrepareTestContainer(t, false, "secret")
defer cleanup()
connData := map[string]interface{}{
"connection_url": connURL,
}
// for host based mysql user
logicaltest.Test(t, logicaltest.TestCase{
LogicalBackend: b,
Steps: []logicaltest.TestStep{
testAccStepConfig(t, connData, false),
testAccStepRole(t, false),
testAccStepReadCreds(t, "web"),
},
})
}
func TestBackend_roleCrud(t *testing.T) {
config := logical.TestBackendConfig()
config.StorageView = &logical.InmemStorage{}
b, err := Factory(context.Background(), config)
if err != nil {
t.Fatal(err)
}
cleanup, connURL := mysqlhelper.PrepareTestContainer(t, false, "secret")
defer cleanup()
connData := map[string]interface{}{
"connection_url": connURL,
}
logicaltest.Test(t, logicaltest.TestCase{
LogicalBackend: b,
Steps: []logicaltest.TestStep{
testAccStepConfig(t, connData, false),
// test SQL with wildcard based user
testAccStepRole(t, true),
testAccStepReadRole(t, "web", testRoleWildCard),
testAccStepDeleteRole(t, "web"),
// test SQL with host based user
testAccStepRole(t, false),
testAccStepReadRole(t, "web", testRoleHost),
testAccStepDeleteRole(t, "web"),
},
})
}
func TestBackend_leaseWriteRead(t *testing.T) {
config := logical.TestBackendConfig()
config.StorageView = &logical.InmemStorage{}
b, err := Factory(context.Background(), config)
if err != nil {
t.Fatal(err)
}
cleanup, connURL := mysqlhelper.PrepareTestContainer(t, false, "secret")
defer cleanup()
connData := map[string]interface{}{
"connection_url": connURL,
}
logicaltest.Test(t, logicaltest.TestCase{
LogicalBackend: b,
Steps: []logicaltest.TestStep{
testAccStepConfig(t, connData, false),
testAccStepWriteLease(t),
testAccStepReadLease(t),
},
})
}
func testAccStepConfig(t *testing.T, d map[string]interface{}, expectError bool) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.UpdateOperation,
Path: "config/connection",
Data: d,
ErrorOk: true,
Check: func(resp *logical.Response) error {
if expectError {
if resp.Data == nil {
return fmt.Errorf("data is nil")
}
var e struct {
Error string `mapstructure:"error"`
}
if err := mapstructure.Decode(resp.Data, &e); err != nil {
return err
}
if len(e.Error) == 0 {
return fmt.Errorf("expected error, but write succeeded")
}
return nil
} else if resp != nil && resp.IsError() {
return fmt.Errorf("got an error response: %v", resp.Error())
}
return nil
},
}
}
func testAccStepRole(t *testing.T, wildCard bool) logicaltest.TestStep {
pathData := make(map[string]interface{})
if wildCard {
pathData = map[string]interface{}{
"sql": testRoleWildCard,
}
} else {
pathData = map[string]interface{}{
"sql": testRoleHost,
"revocation_sql": testRevocationSQL,
}
}
return logicaltest.TestStep{
Operation: logical.UpdateOperation,
Path: "roles/web",
Data: pathData,
}
}
func testAccStepDeleteRole(t *testing.T, n string) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.DeleteOperation,
Path: "roles/" + n,
}
}
func testAccStepReadCreds(t *testing.T, name string) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.ReadOperation,
Path: "creds/" + name,
Check: func(resp *logical.Response) error {
var d struct {
Username string `mapstructure:"username"`
Password string `mapstructure:"password"`
}
if err := mapstructure.Decode(resp.Data, &d); err != nil {
return err
}
log.Printf("[WARN] Generated credentials: %v", d)
return nil
},
}
}
func testAccStepReadRole(t *testing.T, name string, sql string) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.ReadOperation,
Path: "roles/" + name,
Check: func(resp *logical.Response) error {
if resp == nil {
if sql == "" {
return nil
}
return fmt.Errorf("bad: %#v", resp)
}
var d struct {
SQL string `mapstructure:"sql"`
}
if err := mapstructure.Decode(resp.Data, &d); err != nil {
return err
}
if d.SQL != sql {
return fmt.Errorf("bad: %#v", resp)
}
return nil
},
}
}
func testAccStepWriteLease(t *testing.T) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.UpdateOperation,
Path: "config/lease",
Data: map[string]interface{}{
"lease": "1h5m",
"lease_max": "24h",
},
}
}
func testAccStepReadLease(t *testing.T) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.ReadOperation,
Path: "config/lease",
Check: func(resp *logical.Response) error {
if resp.Data["lease"] != "1h5m0s" || resp.Data["lease_max"] != "24h0m0s" {
return fmt.Errorf("bad: %#v", resp)
}
return nil
},
}
}
const testRoleWildCard = `
CREATE USER '{{name}}'@'%' IDENTIFIED BY '{{password}}';
GRANT SELECT ON *.* TO '{{name}}'@'%';
`
const testRoleHost = `
CREATE USER '{{name}}'@'10.1.1.2' IDENTIFIED BY '{{password}}';
GRANT SELECT ON *.* TO '{{name}}'@'10.1.1.2';
`
const testRevocationSQL = `
REVOKE ALL PRIVILEGES, GRANT OPTION FROM '{{name}}'@'10.1.1.2';
DROP USER '{{name}}'@'10.1.1.2';
`

View File

@ -1,29 +0,0 @@
package main
import (
"os"
hclog "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/builtin/logical/mysql"
"github.com/hashicorp/vault/sdk/plugin"
)
func main() {
apiClientMeta := &api.PluginAPIClientMeta{}
flags := apiClientMeta.FlagSet()
flags.Parse(os.Args[1:])
tlsConfig := apiClientMeta.GetTLSConfig()
tlsProviderFunc := api.VaultPluginTLSProvider(tlsConfig)
if err := plugin.Serve(&plugin.ServeOpts{
BackendFactoryFunc: mysql.Factory,
TLSProviderFunc: tlsProviderFunc,
}); err != nil {
logger := hclog.New(&hclog.LoggerOptions{})
logger.Error("plugin shutting down", "error", err)
os.Exit(1)
}
}

View File

@ -1,159 +0,0 @@
package mysql
import (
"context"
"database/sql"
"fmt"
_ "github.com/go-sql-driver/mysql"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical"
)
func pathConfigConnection(b *backend) *framework.Path {
return &framework.Path{
Pattern: "config/connection",
Fields: map[string]*framework.FieldSchema{
"connection_url": {
Type: framework.TypeString,
Description: "DB connection string",
},
"value": {
Type: framework.TypeString,
Description: `DB connection string. Use 'connection_url' instead.
This name is deprecated.`,
},
"max_open_connections": {
Type: framework.TypeInt,
Description: "Maximum number of open connections to database",
},
"max_idle_connections": {
Type: framework.TypeInt,
Description: "Maximum number of idle connections to the database; a zero uses the value of max_open_connections and a negative value disables idle connections. If larger than max_open_connections it will be reduced to the same size.",
},
"verify_connection": {
Type: framework.TypeBool,
Default: true,
Description: "If set, connection_url is verified by actually connecting to the database",
},
},
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.UpdateOperation: b.pathConnectionWrite,
logical.ReadOperation: b.pathConnectionRead,
},
HelpSynopsis: pathConfigConnectionHelpSyn,
HelpDescription: pathConfigConnectionHelpDesc,
}
}
// pathConnectionRead reads out the connection configuration
func (b *backend) pathConnectionRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
entry, err := req.Storage.Get(ctx, "config/connection")
if err != nil {
return nil, fmt.Errorf("failed to read connection configuration")
}
if entry == nil {
return nil, nil
}
var config connectionConfig
if err := entry.DecodeJSON(&config); err != nil {
return nil, err
}
return &logical.Response{
Data: map[string]interface{}{
"max_open_connections": config.MaxOpenConnections,
"max_idle_connections": config.MaxIdleConnections,
},
}, nil
}
func (b *backend) pathConnectionWrite(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
connValue := data.Get("value").(string)
connURL := data.Get("connection_url").(string)
if connURL == "" {
if connValue == "" {
return logical.ErrorResponse("the connection_url parameter must be supplied"), nil
} else {
connURL = connValue
}
}
maxOpenConns := data.Get("max_open_connections").(int)
if maxOpenConns == 0 {
maxOpenConns = 2
}
maxIdleConns := data.Get("max_idle_connections").(int)
if maxIdleConns == 0 {
maxIdleConns = maxOpenConns
}
if maxIdleConns > maxOpenConns {
maxIdleConns = maxOpenConns
}
// Don't check the connection_url if verification is disabled
verifyConnection := data.Get("verify_connection").(bool)
if verifyConnection {
// Verify the string
db, err := sql.Open("mysql", connURL)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf(
"error validating connection info: %s", err)), nil
}
defer db.Close()
if err := db.Ping(); err != nil {
return logical.ErrorResponse(fmt.Sprintf(
"error validating connection info: %s", err)), nil
}
}
// Store it
entry, err := logical.StorageEntryJSON("config/connection", connectionConfig{
ConnectionURL: connURL,
MaxOpenConnections: maxOpenConns,
MaxIdleConnections: maxIdleConns,
})
if err != nil {
return nil, err
}
if err := req.Storage.Put(ctx, entry); err != nil {
return nil, err
}
// Reset the DB connection
b.ResetDB(ctx)
resp := &logical.Response{}
resp.AddWarning("Read access to this endpoint should be controlled via ACLs as it will return the connection URL as it is, including passwords, if any.")
return resp, nil
}
type connectionConfig struct {
ConnectionURL string `json:"connection_url" structs:"connection_url" mapstructure:"connection_url"`
// Deprecate "value" in coming releases
ConnectionString string `json:"value" structs:"value" mapstructure:"value"`
MaxOpenConnections int `json:"max_open_connections" structs:"max_open_connections" mapstructure:"max_open_connections"`
MaxIdleConnections int `json:"max_idle_connections" structs:"max_idle_connections" mapstructure:"max_idle_connections"`
}
const pathConfigConnectionHelpSyn = `
Configure the connection string to talk to MySQL.
`
const pathConfigConnectionHelpDesc = `
This path configures the connection string used to connect to MySQL. The value
of the string is a Data Source Name (DSN). An example is using
"username:password@protocol(address)/dbname?param=value"
For example, RDS may look like:
"id:password@tcp(your-amazonaws-uri.com:3306)/dbname"
When configuring the connection string, the backend will verify its validity.
If the database is not available when setting the connection URL, set the
"verify_connection" option to false.
`

View File

@ -1,101 +0,0 @@
package mysql
import (
"context"
"fmt"
"time"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical"
)
func pathConfigLease(b *backend) *framework.Path {
return &framework.Path{
Pattern: "config/lease",
Fields: map[string]*framework.FieldSchema{
"lease": {
Type: framework.TypeString,
Description: "Default lease for roles.",
},
"lease_max": {
Type: framework.TypeString,
Description: "Maximum time a credential is valid for.",
},
},
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.ReadOperation: b.pathLeaseRead,
logical.UpdateOperation: b.pathLeaseWrite,
},
HelpSynopsis: pathConfigLeaseHelpSyn,
HelpDescription: pathConfigLeaseHelpDesc,
}
}
func (b *backend) pathLeaseWrite(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
leaseRaw := d.Get("lease").(string)
leaseMaxRaw := d.Get("lease_max").(string)
lease, err := time.ParseDuration(leaseRaw)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf(
"Invalid lease: %s", err)), nil
}
leaseMax, err := time.ParseDuration(leaseMaxRaw)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf(
"Invalid lease: %s", err)), nil
}
// Store it
entry, err := logical.StorageEntryJSON("config/lease", &configLease{
Lease: lease,
LeaseMax: leaseMax,
})
if err != nil {
return nil, err
}
if err := req.Storage.Put(ctx, entry); err != nil {
return nil, err
}
return nil, nil
}
func (b *backend) pathLeaseRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
lease, err := b.Lease(ctx, req.Storage)
if err != nil {
return nil, err
}
if lease == nil {
return nil, nil
}
return &logical.Response{
Data: map[string]interface{}{
"lease": lease.Lease.String(),
"lease_max": lease.LeaseMax.String(),
},
}, nil
}
type configLease struct {
Lease time.Duration
LeaseMax time.Duration
}
const pathConfigLeaseHelpSyn = `
Configure the default lease information for generated credentials.
`
const pathConfigLeaseHelpDesc = `
This configures the default lease information used for credentials
generated by this backend. The lease specifies the duration that a
credential will be valid for, as well as the maximum session for
a set of credentials.
The format for the lease is "1h" or integer and then unit. The longest
unit is hour.
`

View File

@ -1,143 +0,0 @@
package mysql
import (
"context"
"fmt"
"strings"
"github.com/hashicorp/go-secure-stdlib/strutil"
uuid "github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/dbtxn"
"github.com/hashicorp/vault/sdk/logical"
)
func pathRoleCreate(b *backend) *framework.Path {
return &framework.Path{
Pattern: "creds/" + framework.GenericNameRegex("name"),
Fields: map[string]*framework.FieldSchema{
"name": {
Type: framework.TypeString,
Description: "Name of the role.",
},
},
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.ReadOperation: b.pathRoleCreateRead,
},
HelpSynopsis: pathRoleCreateReadHelpSyn,
HelpDescription: pathRoleCreateReadHelpDesc,
}
}
func (b *backend) pathRoleCreateRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
name := data.Get("name").(string)
// Get the role
role, err := b.Role(ctx, req.Storage, name)
if err != nil {
return nil, err
}
if role == nil {
return logical.ErrorResponse(fmt.Sprintf("unknown role: %s", name)), nil
}
// Determine if we have a lease
lease, err := b.Lease(ctx, req.Storage)
if err != nil {
return nil, err
}
if lease == nil {
lease = &configLease{}
}
// Generate our username and password. The username will be a
// concatenation of:
//
// - the role name, truncated to role.rolenameLength (default 4)
// - the token display name, truncated to role.displaynameLength (default 4)
// - a UUID
//
// the entire concatenated string is then truncated to role.usernameLength,
// which by default is 16 due to limitations in older but still-prevalent
// versions of MySQL.
roleName := name
if len(roleName) > role.RolenameLength {
roleName = roleName[:role.RolenameLength]
}
displayName := req.DisplayName
if len(displayName) > role.DisplaynameLength {
displayName = displayName[:role.DisplaynameLength]
}
userUUID, err := uuid.GenerateUUID()
if err != nil {
return nil, err
}
username := fmt.Sprintf("%s-%s-%s", roleName, displayName, userUUID)
if len(username) > role.UsernameLength {
username = username[:role.UsernameLength]
}
password, err := uuid.GenerateUUID()
if err != nil {
return nil, err
}
// Get our handle
db, err := b.DB(ctx, req.Storage)
if err != nil {
return nil, err
}
// Start a transaction
tx, err := db.Begin()
if err != nil {
return nil, err
}
defer tx.Rollback()
// Execute each query
for _, query := range strutil.ParseArbitraryStringSlice(role.SQL, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
}
m := map[string]string{
"name": username,
"password": password,
}
if err := dbtxn.ExecuteTxQueryDirect(ctx, tx, m, query); err != nil {
return nil, err
}
}
// Commit the transaction
if err := tx.Commit(); err != nil {
return nil, err
}
// Return the secret
resp := b.Secret(SecretCredsType).Response(map[string]interface{}{
"username": username,
"password": password,
}, map[string]interface{}{
"username": username,
"role": name,
})
resp.Secret.TTL = lease.Lease
resp.Secret.MaxTTL = lease.LeaseMax
return resp, nil
}
const pathRoleCreateReadHelpSyn = `
Request database credentials for a certain role.
`
const pathRoleCreateReadHelpDesc = `
This path reads database credentials for a certain role. The
database credentials will be generated on demand and will be automatically
revoked when the lease is up.
`

View File

@ -1,230 +0,0 @@
package mysql
import (
"context"
"fmt"
"strings"
_ "github.com/go-sql-driver/mysql"
"github.com/hashicorp/go-secure-stdlib/strutil"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical"
)
func pathListRoles(b *backend) *framework.Path {
return &framework.Path{
Pattern: "roles/?$",
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.ListOperation: b.pathRoleList,
},
HelpSynopsis: pathRoleHelpSyn,
HelpDescription: pathRoleHelpDesc,
}
}
func pathRoles(b *backend) *framework.Path {
return &framework.Path{
Pattern: "roles/" + framework.GenericNameRegex("name"),
Fields: map[string]*framework.FieldSchema{
"name": {
Type: framework.TypeString,
Description: "Name of the role.",
},
"sql": {
Type: framework.TypeString,
Description: "SQL string to create a user. See help for more info.",
},
"revocation_sql": {
Type: framework.TypeString,
Description: "SQL string to revoke a user. See help for more info.",
},
"username_length": {
Type: framework.TypeInt,
Description: "number of characters to truncate generated mysql usernames to (default 16)",
Default: 16,
},
"rolename_length": {
Type: framework.TypeInt,
Description: "number of characters to truncate the rolename portion of generated mysql usernames to (default 4)",
Default: 4,
},
"displayname_length": {
Type: framework.TypeInt,
Description: "number of characters to truncate the displayname portion of generated mysql usernames to (default 4)",
Default: 4,
},
},
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.ReadOperation: b.pathRoleRead,
logical.UpdateOperation: b.pathRoleCreate,
logical.DeleteOperation: b.pathRoleDelete,
},
HelpSynopsis: pathRoleHelpSyn,
HelpDescription: pathRoleHelpDesc,
}
}
func (b *backend) Role(ctx context.Context, s logical.Storage, n string) (*roleEntry, error) {
entry, err := s.Get(ctx, "role/"+n)
if err != nil {
return nil, err
}
if entry == nil {
return nil, nil
}
// Set defaults to handle upgrade cases
result := roleEntry{
UsernameLength: 16,
RolenameLength: 4,
DisplaynameLength: 4,
}
if err := entry.DecodeJSON(&result); err != nil {
return nil, err
}
return &result, nil
}
func (b *backend) pathRoleDelete(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
err := req.Storage.Delete(ctx, "role/"+data.Get("name").(string))
if err != nil {
return nil, err
}
return nil, nil
}
func (b *backend) pathRoleRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
role, err := b.Role(ctx, req.Storage, data.Get("name").(string))
if err != nil {
return nil, err
}
if role == nil {
return nil, nil
}
return &logical.Response{
Data: map[string]interface{}{
"sql": role.SQL,
"revocation_sql": role.RevocationSQL,
},
}, nil
}
func (b *backend) pathRoleList(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
entries, err := req.Storage.List(ctx, "role/")
if err != nil {
return nil, err
}
return logical.ListResponse(entries), nil
}
func (b *backend) pathRoleCreate(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
name := data.Get("name").(string)
// Get our connection
db, err := b.DB(ctx, req.Storage)
if err != nil {
return nil, err
}
// Test the query by trying to prepare it
sql := data.Get("sql").(string)
for _, query := range strutil.ParseArbitraryStringSlice(sql, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
}
stmt, err := db.Prepare(Query(query, map[string]string{
"name": "foo",
"password": "bar",
}))
if err != nil {
return logical.ErrorResponse(fmt.Sprintf(
"Error testing query: %s", err)), nil
}
stmt.Close()
}
// Store it
entry, err := logical.StorageEntryJSON("role/"+name, &roleEntry{
SQL: sql,
RevocationSQL: data.Get("revocation_sql").(string),
UsernameLength: data.Get("username_length").(int),
DisplaynameLength: data.Get("displayname_length").(int),
RolenameLength: data.Get("rolename_length").(int),
})
if err != nil {
return nil, err
}
if err := req.Storage.Put(ctx, entry); err != nil {
return nil, err
}
return nil, nil
}
type roleEntry struct {
SQL string `json:"sql" mapstructure:"sql" structs:"sql"`
RevocationSQL string `json:"revocation_sql" mapstructure:"revocation_sql" structs:"revocation_sql"`
UsernameLength int `json:"username_length" mapstructure:"username_length" structs:"username_length"`
DisplaynameLength int `json:"displayname_length" mapstructure:"displayname_length" structs:"displayname_length"`
RolenameLength int `json:"rolename_length" mapstructure:"rolename_length" structs:"rolename_length"`
}
const pathRoleHelpSyn = `
Manage the roles that can be created with this backend.
`
const pathRoleHelpDesc = `
This path lets you manage the roles that can be created with this backend.
The "sql" parameter customizes the SQL string used to create the role.
This can be a sequence of SQL queries, each semi-colon separated. Some
substitution will be done to the SQL string for certain keys.
The names of the variables must be surrounded by "{{" and "}}" to be replaced.
* "name" - The random username generated for the DB user.
* "password" - The random password generated for the DB user.
Example of a decent SQL query to use:
CREATE USER '{{name}}'@'%' IDENTIFIED BY '{{password}}';
GRANT ALL ON db1.* TO '{{name}}'@'%';
Note the above user would be able to access anything in db1. Please see the MySQL
manual on the GRANT command to learn how to do more fine grained access.
The "rolename_length" parameter determines how many characters of the role name
will be used in creating the generated mysql username; the default is 4.
The "displayname_length" parameter determines how many characters of the token
display name will be used in creating the generated mysql username; the default
is 4.
The "username_length" parameter determines how many total characters the
generated username (including the role name, token display name and the uuid
portion) will be truncated to. Versions of MySQL prior to 5.7.8 are limited to
16 characters total (see
http://dev.mysql.com/doc/refman/5.7/en/user-names.html) so that is the default;
for versions >=5.7.8 it is safe to increase this to 32.
For best readability in MySQL process lists, we recommend using MySQL 5.7.8 or
later, setting "username_length" to 32 and setting both "rolename_length" and
"displayname_length" to 8. However due the the prevalence of older versions of
MySQL in general deployment, the defaults are currently tuned for a
username_length of 16.
`

View File

@ -1,136 +0,0 @@
package mysql
import (
"context"
"fmt"
"strings"
"github.com/hashicorp/go-secure-stdlib/strutil"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical"
)
const SecretCredsType = "creds"
// defaultRevocationSQL is a default SQL statement for revoking a user. Revoking
// permissions for the user is done before the drop, because MySQL explicitly
// documents that open user connections will not be closed. By revoking all
// grants, at least we ensure that the open connection is useless. Dropping the
// user will only affect the next connection.
const defaultRevocationSQL = `
REVOKE ALL PRIVILEGES, GRANT OPTION FROM '{{name}}'@'%';
DROP USER '{{name}}'@'%'
`
func secretCreds(b *backend) *framework.Secret {
return &framework.Secret{
Type: SecretCredsType,
Fields: map[string]*framework.FieldSchema{
"username": {
Type: framework.TypeString,
Description: "Username",
},
"password": {
Type: framework.TypeString,
Description: "Password",
},
},
Renew: b.secretCredsRenew,
Revoke: b.secretCredsRevoke,
}
}
func (b *backend) secretCredsRenew(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
// Get the lease information
lease, err := b.Lease(ctx, req.Storage)
if err != nil {
return nil, err
}
if lease == nil {
lease = &configLease{}
}
resp := &logical.Response{Secret: req.Secret}
resp.Secret.TTL = lease.Lease
resp.Secret.MaxTTL = lease.LeaseMax
return resp, nil
}
func (b *backend) secretCredsRevoke(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
var resp *logical.Response
// Get the username from the internal data
usernameRaw, ok := req.Secret.InternalData["username"]
if !ok {
return nil, fmt.Errorf("secret is missing username internal data")
}
username, ok := usernameRaw.(string)
if !ok {
return nil, fmt.Errorf("usernameRaw is not a string")
}
// Get our connection
db, err := b.DB(ctx, req.Storage)
if err != nil {
return nil, err
}
roleName := ""
roleNameRaw, ok := req.Secret.InternalData["role"]
if ok {
roleName = roleNameRaw.(string)
}
var role *roleEntry
if roleName != "" {
role, err = b.Role(ctx, req.Storage, roleName)
if err != nil {
return nil, err
}
}
// Use a default SQL statement for revocation if one cannot be fetched from the role
revocationSQL := defaultRevocationSQL
if role != nil && role.RevocationSQL != "" {
revocationSQL = role.RevocationSQL
} else {
if resp == nil {
resp = &logical.Response{}
}
resp.AddWarning(fmt.Sprintf("Role %q cannot be found. Using default SQL for revoking user.", roleName))
}
// Start a transaction
tx, err := db.Begin()
if err != nil {
return nil, err
}
defer tx.Rollback()
for _, query := range strutil.ParseArbitraryStringSlice(revocationSQL, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
}
// This is not a prepared statement because not all commands are supported
// 1295: This command is not supported in the prepared statement protocol yet
// Reference https://mariadb.com/kb/en/mariadb/prepare-statement/
query = strings.ReplaceAll(query, "{{name}}", username)
_, err = tx.Exec(query)
if err != nil {
return nil, err
}
}
// Commit the transaction
if err := tx.Commit(); err != nil {
return nil, err
}
return resp, nil
}

View File

@ -1,15 +0,0 @@
package mysql
import (
"fmt"
"strings"
)
// Query templates a query for us.
func Query(tpl string, data map[string]string) string {
for k, v := range data {
tpl = strings.ReplaceAll(tpl, fmt.Sprintf("{{%s}}", k), v)
}
return tpl
}

View File

@ -1,171 +0,0 @@
package postgresql
import (
"context"
"database/sql"
"fmt"
"strings"
"sync"
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical"
)
func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
b := Backend(conf)
if err := b.Setup(ctx, conf); err != nil {
return nil, err
}
return b, nil
}
func Backend(conf *logical.BackendConfig) *backend {
var b backend
b.Backend = &framework.Backend{
Help: strings.TrimSpace(backendHelp),
PathsSpecial: &logical.Paths{
SealWrapStorage: []string{
"config/connection",
},
},
Paths: []*framework.Path{
pathConfigConnection(&b),
pathConfigLease(&b),
pathListRoles(&b),
pathRoles(&b),
pathRoleCreate(&b),
},
Secrets: []*framework.Secret{
secretCreds(&b),
},
Clean: b.ResetDB,
Invalidate: b.invalidate,
BackendType: logical.TypeLogical,
}
b.logger = conf.Logger
return &b
}
type backend struct {
*framework.Backend
db *sql.DB
lock sync.Mutex
logger log.Logger
}
// DB returns the database connection.
func (b *backend) DB(ctx context.Context, s logical.Storage) (*sql.DB, error) {
b.logger.Debug("postgres/db: enter")
defer b.logger.Debug("postgres/db: exit")
b.lock.Lock()
defer b.lock.Unlock()
// If we already have a DB, we got it!
if b.db != nil {
if err := b.db.Ping(); err == nil {
return b.db, nil
}
// If the ping was unsuccessful, close it and ignore errors as we'll be
// reestablishing anyways
b.db.Close()
}
// Otherwise, attempt to make connection
entry, err := s.Get(ctx, "config/connection")
if err != nil {
return nil, err
}
if entry == nil {
return nil,
fmt.Errorf("configure the DB connection with config/connection first")
}
var connConfig connectionConfig
if err := entry.DecodeJSON(&connConfig); err != nil {
return nil, err
}
conn := connConfig.ConnectionURL
if len(conn) == 0 {
conn = connConfig.ConnectionString
}
// Ensure timezone is set to UTC for all the connections
if strings.HasPrefix(conn, "postgres://") || strings.HasPrefix(conn, "postgresql://") {
if strings.Contains(conn, "?") {
conn += "&timezone=utc"
} else {
conn += "?timezone=utc"
}
} else {
conn += "&timezone=utc"
}
b.db, err = sql.Open("pgx", conn)
if err != nil {
return nil, err
}
// Set some connection pool settings. We don't need much of this,
// since the request rate shouldn't be high.
b.db.SetMaxOpenConns(connConfig.MaxOpenConnections)
b.db.SetMaxIdleConns(connConfig.MaxIdleConnections)
return b.db, nil
}
// ResetDB forces a connection next time DB() is called.
func (b *backend) ResetDB(_ context.Context) {
b.logger.Debug("postgres/db: enter")
defer b.logger.Debug("postgres/db: exit")
b.lock.Lock()
defer b.lock.Unlock()
if b.db != nil {
b.db.Close()
}
b.db = nil
}
func (b *backend) invalidate(ctx context.Context, key string) {
switch key {
case "config/connection":
b.ResetDB(ctx)
}
}
// Lease returns the lease information
func (b *backend) Lease(ctx context.Context, s logical.Storage) (*configLease, error) {
entry, err := s.Get(ctx, "config/lease")
if err != nil {
return nil, err
}
if entry == nil {
return nil, nil
}
var result configLease
if err := entry.DecodeJSON(&result); err != nil {
return nil, err
}
return &result, nil
}
const backendHelp = `
The PostgreSQL backend dynamically generates database users.
After mounting this backend, configure it using the endpoints within
the "config/" path.
`

View File

@ -1,532 +0,0 @@
package postgresql
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"log"
"path"
"reflect"
"testing"
logicaltest "github.com/hashicorp/vault/helper/testhelpers/logical"
postgreshelper "github.com/hashicorp/vault/helper/testhelpers/postgresql"
"github.com/hashicorp/vault/sdk/logical"
"github.com/mitchellh/mapstructure"
)
func TestBackend_config_connection(t *testing.T) {
var resp *logical.Response
var err error
config := logical.TestBackendConfig()
config.StorageView = &logical.InmemStorage{}
b, err := Factory(context.Background(), config)
if err != nil {
t.Fatal(err)
}
configData := map[string]interface{}{
"connection_url": "sample_connection_url",
"max_open_connections": 9,
"max_idle_connections": 7,
"verify_connection": false,
}
configReq := &logical.Request{
Operation: logical.UpdateOperation,
Path: "config/connection",
Storage: config.StorageView,
Data: configData,
}
resp, err = b.HandleRequest(context.Background(), configReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%s resp:%#v\n", err, resp)
}
configReq.Operation = logical.ReadOperation
resp, err = b.HandleRequest(context.Background(), configReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%s resp:%#v\n", err, resp)
}
delete(configData, "verify_connection")
delete(configData, "connection_url")
if !reflect.DeepEqual(configData, resp.Data) {
t.Fatalf("bad: expected:%#v\nactual:%#v\n", configData, resp.Data)
}
}
func TestBackend_basic(t *testing.T) {
config := logical.TestBackendConfig()
config.StorageView = &logical.InmemStorage{}
b, err := Factory(context.Background(), config)
if err != nil {
t.Fatal(err)
}
cleanup, connURL := postgreshelper.PrepareTestContainer(t, "")
defer cleanup()
connData := map[string]interface{}{
"connection_url": connURL,
}
logicaltest.Test(t, logicaltest.TestCase{
LogicalBackend: b,
Steps: []logicaltest.TestStep{
testAccStepConfig(t, connData, false),
testAccStepCreateRole(t, "web", testRole, false),
testAccStepReadCreds(t, b, config.StorageView, "web", connURL),
},
})
}
func TestBackend_roleCrud(t *testing.T) {
config := logical.TestBackendConfig()
config.StorageView = &logical.InmemStorage{}
b, err := Factory(context.Background(), config)
if err != nil {
t.Fatal(err)
}
cleanup, connURL := postgreshelper.PrepareTestContainer(t, "")
defer cleanup()
connData := map[string]interface{}{
"connection_url": connURL,
}
logicaltest.Test(t, logicaltest.TestCase{
LogicalBackend: b,
Steps: []logicaltest.TestStep{
testAccStepConfig(t, connData, false),
testAccStepCreateRole(t, "web", testRole, false),
testAccStepReadRole(t, "web", testRole),
testAccStepDeleteRole(t, "web"),
testAccStepReadRole(t, "web", ""),
},
})
}
func TestBackend_BlockStatements(t *testing.T) {
config := logical.TestBackendConfig()
config.StorageView = &logical.InmemStorage{}
b, err := Factory(context.Background(), config)
if err != nil {
t.Fatal(err)
}
cleanup, connURL := postgreshelper.PrepareTestContainer(t, "")
defer cleanup()
connData := map[string]interface{}{
"connection_url": connURL,
}
jsonBlockStatement, err := json.Marshal(testBlockStatementRoleSlice)
if err != nil {
t.Fatal(err)
}
logicaltest.Test(t, logicaltest.TestCase{
LogicalBackend: b,
Steps: []logicaltest.TestStep{
testAccStepConfig(t, connData, false),
// This will also validate the query
testAccStepCreateRole(t, "web-block", testBlockStatementRole, true),
testAccStepCreateRole(t, "web-block", string(jsonBlockStatement), false),
},
})
}
func TestBackend_roleReadOnly(t *testing.T) {
config := logical.TestBackendConfig()
config.StorageView = &logical.InmemStorage{}
b, err := Factory(context.Background(), config)
if err != nil {
t.Fatal(err)
}
cleanup, connURL := postgreshelper.PrepareTestContainer(t, "")
defer cleanup()
connData := map[string]interface{}{
"connection_url": connURL,
}
logicaltest.Test(t, logicaltest.TestCase{
LogicalBackend: b,
Steps: []logicaltest.TestStep{
testAccStepConfig(t, connData, false),
testAccStepCreateRole(t, "web", testRole, false),
testAccStepCreateRole(t, "web-readonly", testReadOnlyRole, false),
testAccStepReadRole(t, "web-readonly", testReadOnlyRole),
testAccStepCreateTable(t, b, config.StorageView, "web", connURL),
testAccStepReadCreds(t, b, config.StorageView, "web-readonly", connURL),
testAccStepDropTable(t, b, config.StorageView, "web", connURL),
testAccStepDeleteRole(t, "web-readonly"),
testAccStepDeleteRole(t, "web"),
testAccStepReadRole(t, "web-readonly", ""),
},
})
}
func TestBackend_roleReadOnly_revocationSQL(t *testing.T) {
config := logical.TestBackendConfig()
config.StorageView = &logical.InmemStorage{}
b, err := Factory(context.Background(), config)
if err != nil {
t.Fatal(err)
}
cleanup, connURL := postgreshelper.PrepareTestContainer(t, "")
defer cleanup()
connData := map[string]interface{}{
"connection_url": connURL,
}
logicaltest.Test(t, logicaltest.TestCase{
LogicalBackend: b,
Steps: []logicaltest.TestStep{
testAccStepConfig(t, connData, false),
testAccStepCreateRoleWithRevocationSQL(t, "web", testRole, defaultRevocationSQL, false),
testAccStepCreateRoleWithRevocationSQL(t, "web-readonly", testReadOnlyRole, defaultRevocationSQL, false),
testAccStepReadRole(t, "web-readonly", testReadOnlyRole),
testAccStepCreateTable(t, b, config.StorageView, "web", connURL),
testAccStepReadCreds(t, b, config.StorageView, "web-readonly", connURL),
testAccStepDropTable(t, b, config.StorageView, "web", connURL),
testAccStepDeleteRole(t, "web-readonly"),
testAccStepDeleteRole(t, "web"),
testAccStepReadRole(t, "web-readonly", ""),
},
})
}
func testAccStepConfig(t *testing.T, d map[string]interface{}, expectError bool) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.UpdateOperation,
Path: "config/connection",
Data: d,
ErrorOk: true,
Check: func(resp *logical.Response) error {
if expectError {
if resp.Data == nil {
return fmt.Errorf("data is nil")
}
var e struct {
Error string `mapstructure:"error"`
}
if err := mapstructure.Decode(resp.Data, &e); err != nil {
return err
}
if len(e.Error) == 0 {
return fmt.Errorf("expected error, but write succeeded")
}
return nil
} else if resp != nil && resp.IsError() {
return fmt.Errorf("got an error response: %v", resp.Error())
}
return nil
},
}
}
func testAccStepCreateRole(t *testing.T, name string, sql string, expectFail bool) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.UpdateOperation,
Path: path.Join("roles", name),
Data: map[string]interface{}{
"sql": sql,
},
ErrorOk: expectFail,
}
}
func testAccStepCreateRoleWithRevocationSQL(t *testing.T, name, sql, revocationSQL string, expectFail bool) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.UpdateOperation,
Path: path.Join("roles", name),
Data: map[string]interface{}{
"sql": sql,
"revocation_sql": revocationSQL,
},
ErrorOk: expectFail,
}
}
func testAccStepDeleteRole(t *testing.T, name string) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.DeleteOperation,
Path: path.Join("roles", name),
}
}
func testAccStepReadCreds(t *testing.T, b logical.Backend, s logical.Storage, name string, connURL string) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.ReadOperation,
Path: path.Join("creds", name),
Check: func(resp *logical.Response) error {
var d struct {
Username string `mapstructure:"username"`
Password string `mapstructure:"password"`
}
if err := mapstructure.Decode(resp.Data, &d); err != nil {
return err
}
log.Printf("[TRACE] Generated credentials: %v", d)
db, err := sql.Open("pgx", connURL+"&timezone=utc")
if err != nil {
t.Fatal(err)
}
returnedRows := func() int {
stmt, err := db.Prepare("SELECT DISTINCT schemaname FROM pg_tables WHERE has_table_privilege($1, 'information_schema.role_column_grants', 'select');")
if err != nil {
return -1
}
defer stmt.Close()
rows, err := stmt.Query(d.Username)
if err != nil {
return -1
}
defer rows.Close()
i := 0
for rows.Next() {
i++
}
return i
}
// minNumPermissions is the minimum number of permissions that will always be present.
const minNumPermissions = 2
userRows := returnedRows()
if userRows < minNumPermissions {
t.Fatalf("did not get expected number of rows, got %d", userRows)
}
resp, err = b.HandleRequest(context.Background(), &logical.Request{
Operation: logical.RevokeOperation,
Storage: s,
Secret: &logical.Secret{
InternalData: map[string]interface{}{
"secret_type": "creds",
"username": d.Username,
"role": name,
},
},
})
if err != nil {
return err
}
if resp != nil {
if resp.IsError() {
return fmt.Errorf("error on resp: %#v", *resp)
}
}
userRows = returnedRows()
// User shouldn't exist so returnedRows() should encounter an error and exit with -1
if userRows != -1 {
t.Fatalf("did not get expected number of rows, got %d", userRows)
}
return nil
},
}
}
func testAccStepCreateTable(t *testing.T, b logical.Backend, s logical.Storage, name string, connURL string) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.ReadOperation,
Path: path.Join("creds", name),
Check: func(resp *logical.Response) error {
var d struct {
Username string `mapstructure:"username"`
Password string `mapstructure:"password"`
}
if err := mapstructure.Decode(resp.Data, &d); err != nil {
return err
}
log.Printf("[TRACE] Generated credentials: %v", d)
db, err := sql.Open("pgx", connURL+"&timezone=utc")
if err != nil {
t.Fatal(err)
}
_, err = db.Exec("CREATE TABLE test (id SERIAL PRIMARY KEY);")
if err != nil {
t.Fatal(err)
}
resp, err = b.HandleRequest(context.Background(), &logical.Request{
Operation: logical.RevokeOperation,
Storage: s,
Secret: &logical.Secret{
InternalData: map[string]interface{}{
"secret_type": "creds",
"username": d.Username,
},
},
})
if err != nil {
return err
}
if resp != nil {
if resp.IsError() {
return fmt.Errorf("error on resp: %#v", *resp)
}
}
return nil
},
}
}
func testAccStepDropTable(t *testing.T, b logical.Backend, s logical.Storage, name string, connURL string) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.ReadOperation,
Path: path.Join("creds", name),
Check: func(resp *logical.Response) error {
var d struct {
Username string `mapstructure:"username"`
Password string `mapstructure:"password"`
}
if err := mapstructure.Decode(resp.Data, &d); err != nil {
return err
}
log.Printf("[TRACE] Generated credentials: %v", d)
db, err := sql.Open("pgx", connURL+"&timezone=utc")
if err != nil {
t.Fatal(err)
}
_, err = db.Exec("DROP TABLE test;")
if err != nil {
t.Fatal(err)
}
resp, err = b.HandleRequest(context.Background(), &logical.Request{
Operation: logical.RevokeOperation,
Storage: s,
Secret: &logical.Secret{
InternalData: map[string]interface{}{
"secret_type": "creds",
"username": d.Username,
},
},
})
if err != nil {
return err
}
if resp != nil {
if resp.IsError() {
return fmt.Errorf("error on resp: %#v", *resp)
}
}
return nil
},
}
}
func testAccStepReadRole(t *testing.T, name string, sql string) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.ReadOperation,
Path: "roles/" + name,
Check: func(resp *logical.Response) error {
if resp == nil {
if sql == "" {
return nil
}
return fmt.Errorf("bad: %#v", resp)
}
var d struct {
SQL string `mapstructure:"sql"`
}
if err := mapstructure.Decode(resp.Data, &d); err != nil {
return err
}
if d.SQL != sql {
return fmt.Errorf("bad: %#v", resp)
}
return nil
},
}
}
const testRole = `
CREATE ROLE "{{name}}" WITH
LOGIN
PASSWORD '{{password}}'
VALID UNTIL '{{expiration}}';
GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{name}}";
`
const testReadOnlyRole = `
CREATE ROLE "{{name}}" WITH
LOGIN
PASSWORD '{{password}}'
VALID UNTIL '{{expiration}}';
GRANT SELECT ON ALL TABLES IN SCHEMA public TO "{{name}}";
GRANT SELECT ON ALL SEQUENCES IN SCHEMA public TO "{{name}}";
`
const testBlockStatementRole = `
DO $$
BEGIN
IF NOT EXISTS (SELECT * FROM pg_catalog.pg_roles WHERE rolname='foo-role') THEN
CREATE ROLE "foo-role";
CREATE SCHEMA IF NOT EXISTS foo AUTHORIZATION "foo-role";
ALTER ROLE "foo-role" SET search_path = foo;
GRANT TEMPORARY ON DATABASE "postgres" TO "foo-role";
GRANT ALL PRIVILEGES ON SCHEMA foo TO "foo-role";
GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA foo TO "foo-role";
GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA foo TO "foo-role";
GRANT ALL PRIVILEGES ON ALL FUNCTIONS IN SCHEMA foo TO "foo-role";
END IF;
END
$$
CREATE ROLE "{{name}}" WITH LOGIN PASSWORD '{{password}}' VALID UNTIL '{{expiration}}';
GRANT "foo-role" TO "{{name}}";
ALTER ROLE "{{name}}" SET search_path = foo;
GRANT CONNECT ON DATABASE "postgres" TO "{{name}}";
`
var testBlockStatementRoleSlice = []string{
`
DO $$
BEGIN
IF NOT EXISTS (SELECT * FROM pg_catalog.pg_roles WHERE rolname='foo-role') THEN
CREATE ROLE "foo-role";
CREATE SCHEMA IF NOT EXISTS foo AUTHORIZATION "foo-role";
ALTER ROLE "foo-role" SET search_path = foo;
GRANT TEMPORARY ON DATABASE "postgres" TO "foo-role";
GRANT ALL PRIVILEGES ON SCHEMA foo TO "foo-role";
GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA foo TO "foo-role";
GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA foo TO "foo-role";
GRANT ALL PRIVILEGES ON ALL FUNCTIONS IN SCHEMA foo TO "foo-role";
END IF;
END
$$
`,
`CREATE ROLE "{{name}}" WITH LOGIN PASSWORD '{{password}}' VALID UNTIL '{{expiration}}';`,
`GRANT "foo-role" TO "{{name}}";`,
`ALTER ROLE "{{name}}" SET search_path = foo;`,
`GRANT CONNECT ON DATABASE "postgres" TO "{{name}}";`,
}
const defaultRevocationSQL = `
REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM {{name}};
REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM {{name}};
REVOKE USAGE ON SCHEMA public FROM {{name}};
DROP ROLE IF EXISTS {{name}};
`

View File

@ -1,29 +0,0 @@
package main
import (
"os"
hclog "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/builtin/logical/postgresql"
"github.com/hashicorp/vault/sdk/plugin"
)
func main() {
apiClientMeta := &api.PluginAPIClientMeta{}
flags := apiClientMeta.FlagSet()
flags.Parse(os.Args[1:])
tlsConfig := apiClientMeta.GetTLSConfig()
tlsProviderFunc := api.VaultPluginTLSProvider(tlsConfig)
if err := plugin.Serve(&plugin.ServeOpts{
BackendFactoryFunc: postgresql.Factory,
TLSProviderFunc: tlsProviderFunc,
}); err != nil {
logger := hclog.New(&hclog.LoggerOptions{})
logger.Error("plugin shutting down", "error", err)
os.Exit(1)
}
}

View File

@ -1,168 +0,0 @@
package postgresql
import (
"context"
"database/sql"
"fmt"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical"
_ "github.com/jackc/pgx/v4/stdlib"
)
func pathConfigConnection(b *backend) *framework.Path {
return &framework.Path{
Pattern: "config/connection",
Fields: map[string]*framework.FieldSchema{
"connection_url": {
Type: framework.TypeString,
Description: "DB connection string",
},
"value": {
Type: framework.TypeString,
Description: `DB connection string. Use 'connection_url' instead.
This will be deprecated.`,
},
"verify_connection": {
Type: framework.TypeBool,
Default: true,
Description: `If set, connection_url is verified by actually connecting to the database`,
},
"max_open_connections": {
Type: framework.TypeInt,
Description: `Maximum number of open connections to the database;
a zero uses the default value of two and a
negative value means unlimited`,
},
"max_idle_connections": {
Type: framework.TypeInt,
Description: `Maximum number of idle connections to the database;
a zero uses the value of max_open_connections
and a negative value disables idle connections.
If larger than max_open_connections it will be
reduced to the same size.`,
},
},
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.UpdateOperation: b.pathConnectionWrite,
logical.ReadOperation: b.pathConnectionRead,
},
HelpSynopsis: pathConfigConnectionHelpSyn,
HelpDescription: pathConfigConnectionHelpDesc,
}
}
// pathConnectionRead reads out the connection configuration
func (b *backend) pathConnectionRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
entry, err := req.Storage.Get(ctx, "config/connection")
if err != nil {
return nil, fmt.Errorf("failed to read connection configuration")
}
if entry == nil {
return nil, nil
}
var config connectionConfig
if err := entry.DecodeJSON(&config); err != nil {
return nil, err
}
return &logical.Response{
Data: map[string]interface{}{
"max_open_connections": config.MaxOpenConnections,
"max_idle_connections": config.MaxIdleConnections,
},
}, nil
}
func (b *backend) pathConnectionWrite(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
connValue := data.Get("value").(string)
connURL := data.Get("connection_url").(string)
if connURL == "" {
if connValue == "" {
return logical.ErrorResponse("connection_url parameter must be supplied"), nil
} else {
connURL = connValue
}
}
maxOpenConns := data.Get("max_open_connections").(int)
if maxOpenConns == 0 {
maxOpenConns = 2
}
maxIdleConns := data.Get("max_idle_connections").(int)
if maxIdleConns == 0 {
maxIdleConns = maxOpenConns
}
if maxIdleConns > maxOpenConns {
maxIdleConns = maxOpenConns
}
// Don't check the connection_url if verification is disabled
verifyConnection := data.Get("verify_connection").(bool)
if verifyConnection {
// Verify the string
db, err := sql.Open("pgx", connURL)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf(
"Error validating connection info: %s", err)), nil
}
defer db.Close()
if err := db.Ping(); err != nil {
return logical.ErrorResponse(fmt.Sprintf(
"Error validating connection info: %s", err)), nil
}
}
// Store it
entry, err := logical.StorageEntryJSON("config/connection", connectionConfig{
ConnectionString: connValue,
ConnectionURL: connURL,
MaxOpenConnections: maxOpenConns,
MaxIdleConnections: maxIdleConns,
})
if err != nil {
return nil, err
}
if err := req.Storage.Put(ctx, entry); err != nil {
return nil, err
}
// Reset the DB connection
b.ResetDB(ctx)
resp := &logical.Response{}
resp.AddWarning("Read access to this endpoint should be controlled via ACLs as it will return the connection string or URL as it is, including passwords, if any.")
return resp, nil
}
type connectionConfig struct {
ConnectionURL string `json:"connection_url" structs:"connection_url" mapstructure:"connection_url"`
// Deprecate "value" in coming releases
ConnectionString string `json:"value" structs:"value" mapstructure:"value"`
MaxOpenConnections int `json:"max_open_connections" structs:"max_open_connections" mapstructure:"max_open_connections"`
MaxIdleConnections int `json:"max_idle_connections" structs:"max_idle_connections" mapstructure:"max_idle_connections"`
}
const pathConfigConnectionHelpSyn = `
Configure the connection string to talk to PostgreSQL.
`
const pathConfigConnectionHelpDesc = `
This path configures the connection string used to connect to PostgreSQL.
The value of the string can be a URL, or a PG style string in the
format of "user=foo host=bar" etc.
The URL looks like:
"postgresql://user:pass@host:port/dbname"
When configuring the connection string, the backend will verify its validity.
`

View File

@ -1,101 +0,0 @@
package postgresql
import (
"context"
"fmt"
"time"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical"
)
func pathConfigLease(b *backend) *framework.Path {
return &framework.Path{
Pattern: "config/lease",
Fields: map[string]*framework.FieldSchema{
"lease": {
Type: framework.TypeString,
Description: "Default lease for roles.",
},
"lease_max": {
Type: framework.TypeString,
Description: "Maximum time a credential is valid for.",
},
},
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.ReadOperation: b.pathLeaseRead,
logical.UpdateOperation: b.pathLeaseWrite,
},
HelpSynopsis: pathConfigLeaseHelpSyn,
HelpDescription: pathConfigLeaseHelpDesc,
}
}
func (b *backend) pathLeaseWrite(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
leaseRaw := d.Get("lease").(string)
leaseMaxRaw := d.Get("lease_max").(string)
lease, err := time.ParseDuration(leaseRaw)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf(
"Invalid lease: %s", err)), nil
}
leaseMax, err := time.ParseDuration(leaseMaxRaw)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf(
"Invalid lease: %s", err)), nil
}
// Store it
entry, err := logical.StorageEntryJSON("config/lease", &configLease{
Lease: lease,
LeaseMax: leaseMax,
})
if err != nil {
return nil, err
}
if err := req.Storage.Put(ctx, entry); err != nil {
return nil, err
}
return nil, nil
}
func (b *backend) pathLeaseRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
lease, err := b.Lease(ctx, req.Storage)
if err != nil {
return nil, err
}
if lease == nil {
return nil, nil
}
return &logical.Response{
Data: map[string]interface{}{
"lease": lease.Lease.String(),
"lease_max": lease.LeaseMax.String(),
},
}, nil
}
type configLease struct {
Lease time.Duration
LeaseMax time.Duration
}
const pathConfigLeaseHelpSyn = `
Configure the default lease information for generated credentials.
`
const pathConfigLeaseHelpDesc = `
This configures the default lease information used for credentials
generated by this backend. The lease specifies the duration that a
credential will be valid for, as well as the maximum session for
a set of credentials.
The format for the lease is "1h" or integer and then unit. The longest
unit is hour.
`

View File

@ -1,149 +0,0 @@
package postgresql
import (
"context"
"fmt"
"strings"
"time"
"github.com/hashicorp/go-secure-stdlib/strutil"
uuid "github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/dbtxn"
"github.com/hashicorp/vault/sdk/logical"
_ "github.com/jackc/pgx/v4/stdlib"
)
func pathRoleCreate(b *backend) *framework.Path {
return &framework.Path{
Pattern: "creds/" + framework.GenericNameRegex("name"),
Fields: map[string]*framework.FieldSchema{
"name": {
Type: framework.TypeString,
Description: "Name of the role.",
},
},
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.ReadOperation: b.pathRoleCreateRead,
},
HelpSynopsis: pathRoleCreateReadHelpSyn,
HelpDescription: pathRoleCreateReadHelpDesc,
}
}
func (b *backend) pathRoleCreateRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
name := data.Get("name").(string)
// Get the role
role, err := b.Role(ctx, req.Storage, name)
if err != nil {
return nil, err
}
if role == nil {
return logical.ErrorResponse(fmt.Sprintf("unknown role: %s", name)), nil
}
// Determine if we have a lease
lease, err := b.Lease(ctx, req.Storage)
if err != nil {
return nil, err
}
// Unlike some other backends we need a lease here (can't leave as 0 and
// let core fill it in) because Postgres also expires users as a safety
// measure, so cannot be zero
if lease == nil {
lease = &configLease{
Lease: b.System().DefaultLeaseTTL(),
}
}
// Generate the username, password and expiration. PG limits user to 63 characters
displayName := req.DisplayName
if len(displayName) > 26 {
displayName = displayName[:26]
}
userUUID, err := uuid.GenerateUUID()
if err != nil {
return nil, err
}
username := fmt.Sprintf("%s-%s", displayName, userUUID)
if len(username) > 63 {
username = username[:63]
}
password, err := uuid.GenerateUUID()
if err != nil {
return nil, err
}
ttl, _, err := framework.CalculateTTL(b.System(), 0, lease.Lease, 0, lease.LeaseMax, 0, time.Time{})
if err != nil {
return nil, err
}
expiration := time.Now().
Add(ttl).
Format("2006-01-02 15:04:05-0700")
// Get our handle
db, err := b.DB(ctx, req.Storage)
if err != nil {
return nil, err
}
// Start a transaction
tx, err := db.Begin()
if err != nil {
return nil, err
}
defer func() {
tx.Rollback()
}()
// Execute each query
for _, query := range strutil.ParseArbitraryStringSlice(role.SQL, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
}
m := map[string]string{
"name": username,
"password": password,
"expiration": expiration,
}
if err := dbtxn.ExecuteTxQueryDirect(ctx, tx, m, query); err != nil {
return nil, err
}
}
// Commit the transaction
if err := tx.Commit(); err != nil {
return nil, err
}
// Return the secret
resp := b.Secret(SecretCredsType).Response(map[string]interface{}{
"username": username,
"password": password,
}, map[string]interface{}{
"username": username,
"role": name,
})
resp.Secret.TTL = lease.Lease
resp.Secret.MaxTTL = lease.LeaseMax
return resp, nil
}
const pathRoleCreateReadHelpSyn = `
Request database credentials for a certain role.
`
const pathRoleCreateReadHelpDesc = `
This path reads database credentials for a certain role. The
database credentials will be generated on demand and will be automatically
revoked when the lease is up.
`

View File

@ -1,197 +0,0 @@
package postgresql
import (
"context"
"fmt"
"strings"
"github.com/hashicorp/go-secure-stdlib/strutil"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical"
)
func pathListRoles(b *backend) *framework.Path {
return &framework.Path{
Pattern: "roles/?$",
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.ListOperation: b.pathRoleList,
},
HelpSynopsis: pathRoleHelpSyn,
HelpDescription: pathRoleHelpDesc,
}
}
func pathRoles(b *backend) *framework.Path {
return &framework.Path{
Pattern: "roles/" + framework.GenericNameRegex("name"),
Fields: map[string]*framework.FieldSchema{
"name": {
Type: framework.TypeString,
Description: "Name of the role.",
},
"sql": {
Type: framework.TypeString,
Description: "SQL string to create a user. See help for more info.",
},
"revocation_sql": {
Type: framework.TypeString,
Description: `SQL statements to be executed to revoke a user. Must be a semicolon-separated
string, a base64-encoded semicolon-separated string, a serialized JSON string
array, or a base64-encoded serialized JSON string array. The '{{name}}' value
will be substituted.`,
},
},
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.ReadOperation: b.pathRoleRead,
logical.UpdateOperation: b.pathRoleCreate,
logical.DeleteOperation: b.pathRoleDelete,
},
HelpSynopsis: pathRoleHelpSyn,
HelpDescription: pathRoleHelpDesc,
}
}
func (b *backend) Role(ctx context.Context, s logical.Storage, n string) (*roleEntry, error) {
entry, err := s.Get(ctx, "role/"+n)
if err != nil {
return nil, err
}
if entry == nil {
return nil, nil
}
var result roleEntry
if err := entry.DecodeJSON(&result); err != nil {
return nil, err
}
return &result, nil
}
func (b *backend) pathRoleDelete(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
err := req.Storage.Delete(ctx, "role/"+data.Get("name").(string))
if err != nil {
return nil, err
}
return nil, nil
}
func (b *backend) pathRoleRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
role, err := b.Role(ctx, req.Storage, data.Get("name").(string))
if err != nil {
return nil, err
}
if role == nil {
return nil, nil
}
return &logical.Response{
Data: map[string]interface{}{
"sql": role.SQL,
"revocation_sql": role.RevocationSQL,
},
}, nil
}
func (b *backend) pathRoleList(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
entries, err := req.Storage.List(ctx, "role/")
if err != nil {
return nil, err
}
return logical.ListResponse(entries), nil
}
func (b *backend) pathRoleCreate(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
name := data.Get("name").(string)
sql := data.Get("sql").(string)
// Get our connection
db, err := b.DB(ctx, req.Storage)
if err != nil {
return nil, err
}
// Test the query by trying to prepare it
for _, query := range strutil.ParseArbitraryStringSlice(sql, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
}
stmt, err := db.Prepare(Query(query, map[string]string{
"name": "foo",
"password": "bar",
"expiration": "",
}))
if err != nil {
return logical.ErrorResponse(fmt.Sprintf(
"Error testing query: %s", err)), nil
}
stmt.Close()
}
// Store it
entry, err := logical.StorageEntryJSON("role/"+name, &roleEntry{
SQL: sql,
RevocationSQL: data.Get("revocation_sql").(string),
})
if err != nil {
return nil, err
}
if err := req.Storage.Put(ctx, entry); err != nil {
return nil, err
}
return nil, nil
}
type roleEntry struct {
SQL string `json:"sql" mapstructure:"sql" structs:"sql"`
RevocationSQL string `json:"revocation_sql" mapstructure:"revocation_sql" structs:"revocation_sql"`
}
const pathRoleHelpSyn = `
Manage the roles that can be created with this backend.
`
const pathRoleHelpDesc = `
This path lets you manage the roles that can be created with this backend.
The "sql" parameter customizes the SQL string used to create the role.
This can be a sequence of SQL queries. Some substitution will be done to the
SQL string for certain keys. The names of the variables must be surrounded
by "{{" and "}}" to be replaced.
* "name" - The random username generated for the DB user.
* "password" - The random password generated for the DB user.
* "expiration" - The timestamp when this user will expire.
Example of a decent SQL query to use:
CREATE ROLE "{{name}}" WITH
LOGIN
PASSWORD '{{password}}'
VALID UNTIL '{{expiration}}';
GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{name}}";
Note the above user would be able to access everything in schema public.
For more complex GRANT clauses, see the PostgreSQL manual.
The "revocation_sql" parameter customizes the SQL string used to revoke a user.
Example of a decent revocation SQL query to use:
REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM {{name}};
REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM {{name}};
REVOKE USAGE ON SCHEMA public FROM {{name}};
DROP ROLE IF EXISTS {{name}};
`

View File

@ -1,15 +0,0 @@
package postgresql
import (
"fmt"
"strings"
)
// Query templates a query for us.
func Query(tpl string, data map[string]string) string {
for k, v := range data {
tpl = strings.ReplaceAll(tpl, fmt.Sprintf("{{%s}}", k), v)
}
return tpl
}

View File

@ -1,269 +0,0 @@
package postgresql
import (
"context"
"database/sql"
"fmt"
"strings"
"time"
"github.com/hashicorp/vault/sdk/database/helper/dbutil"
"github.com/hashicorp/go-secure-stdlib/strutil"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/dbtxn"
"github.com/hashicorp/vault/sdk/logical"
)
const SecretCredsType = "creds"
func secretCreds(b *backend) *framework.Secret {
return &framework.Secret{
Type: SecretCredsType,
Fields: map[string]*framework.FieldSchema{
"username": {
Type: framework.TypeString,
Description: "Username",
},
"password": {
Type: framework.TypeString,
Description: "Password",
},
},
Renew: b.secretCredsRenew,
Revoke: b.secretCredsRevoke,
}
}
func (b *backend) secretCredsRenew(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
// Get the username from the internal data
usernameRaw, ok := req.Secret.InternalData["username"]
if !ok {
return nil, fmt.Errorf("secret is missing username internal data")
}
username, ok := usernameRaw.(string)
if !ok {
return nil, fmt.Errorf("usernameRaw is not a string")
}
// Get our connection
db, err := b.DB(ctx, req.Storage)
if err != nil {
return nil, err
}
// Get the lease information
lease, err := b.Lease(ctx, req.Storage)
if err != nil {
return nil, err
}
if lease == nil {
lease = &configLease{}
}
// Make sure we increase the VALID UNTIL endpoint for this user.
ttl, _, err := framework.CalculateTTL(b.System(), req.Secret.Increment, lease.Lease, 0, lease.LeaseMax, 0, req.Secret.IssueTime)
if err != nil {
return nil, err
}
if ttl > 0 {
expireTime := time.Now().Add(ttl)
// Adding a small buffer since the TTL will be calculated again afeter this call
// to ensure the database credential does not expire before the lease
expireTime = expireTime.Add(5 * time.Second)
expiration := expireTime.Format("2006-01-02 15:04:05-0700")
query := fmt.Sprintf(
"ALTER ROLE %s VALID UNTIL '%s';",
dbutil.QuoteIdentifier(username),
expiration)
stmt, err := db.Prepare(query)
if err != nil {
return nil, err
}
defer stmt.Close()
if _, err := stmt.Exec(); err != nil {
return nil, err
}
}
resp := &logical.Response{Secret: req.Secret}
resp.Secret.TTL = lease.Lease
resp.Secret.MaxTTL = lease.LeaseMax
return resp, nil
}
func (b *backend) secretCredsRevoke(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
// Get the username from the internal data
usernameRaw, ok := req.Secret.InternalData["username"]
if !ok {
return nil, fmt.Errorf("secret is missing username internal data")
}
username, ok := usernameRaw.(string)
if !ok {
return nil, fmt.Errorf("usernameRaw is not a string")
}
var revocationSQL string
var resp *logical.Response
roleNameRaw, ok := req.Secret.InternalData["role"]
if ok {
role, err := b.Role(ctx, req.Storage, roleNameRaw.(string))
if err != nil {
return nil, err
}
if role == nil {
if resp == nil {
resp = &logical.Response{}
}
resp.AddWarning(fmt.Sprintf("Role %q cannot be found. Using default revocation SQL.", roleNameRaw.(string)))
} else {
revocationSQL = role.RevocationSQL
}
}
// Get our connection
db, err := b.DB(ctx, req.Storage)
if err != nil {
return nil, err
}
switch revocationSQL {
// This is the default revocation logic. If revocation SQL is provided it
// is simply executed as-is.
case "":
// Check if the role exists
var exists bool
err = db.QueryRow("SELECT exists (SELECT rolname FROM pg_roles WHERE rolname=$1);", username).Scan(&exists)
if err != nil && err != sql.ErrNoRows {
return nil, err
}
if !exists {
return resp, nil
}
// Query for permissions; we need to revoke permissions before we can drop
// the role
// This isn't done in a transaction because even if we fail along the way,
// we want to remove as much access as possible
stmt, err := db.Prepare("SELECT DISTINCT table_schema FROM information_schema.role_column_grants WHERE grantee=$1;")
if err != nil {
return nil, err
}
defer stmt.Close()
rows, err := stmt.Query(username)
if err != nil {
return nil, err
}
defer rows.Close()
const initialNumRevocations = 16
revocationStmts := make([]string, 0, initialNumRevocations)
for rows.Next() {
var schema string
err = rows.Scan(&schema)
if err != nil {
// keep going; remove as many permissions as possible right now
continue
}
revocationStmts = append(revocationStmts, fmt.Sprintf(
`REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA %s FROM %s;`,
dbutil.QuoteIdentifier(schema),
dbutil.QuoteIdentifier(username)))
revocationStmts = append(revocationStmts, fmt.Sprintf(
`REVOKE USAGE ON SCHEMA %s FROM %s;`,
dbutil.QuoteIdentifier(schema),
dbutil.QuoteIdentifier(username)))
}
// for good measure, revoke all privileges and usage on schema public
revocationStmts = append(revocationStmts, fmt.Sprintf(
`REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM %s;`,
dbutil.QuoteIdentifier(username)))
revocationStmts = append(revocationStmts, fmt.Sprintf(
"REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM %s;",
dbutil.QuoteIdentifier(username)))
revocationStmts = append(revocationStmts, fmt.Sprintf(
"REVOKE USAGE ON SCHEMA public FROM %s;",
dbutil.QuoteIdentifier(username)))
// get the current database name so we can issue a REVOKE CONNECT for
// this username
var dbname sql.NullString
if err := db.QueryRow("SELECT current_database();").Scan(&dbname); err != nil {
return nil, err
}
if dbname.Valid {
revocationStmts = append(revocationStmts, fmt.Sprintf(
`REVOKE CONNECT ON DATABASE %s FROM %s;`,
dbutil.QuoteIdentifier(dbname.String),
dbutil.QuoteIdentifier(username)))
}
// again, here, we do not stop on error, as we want to remove as
// many permissions as possible right now
var lastStmtError error
for _, query := range revocationStmts {
if err := dbtxn.ExecuteDBQueryDirect(ctx, db, nil, query); err != nil {
lastStmtError = err
}
}
// can't drop if not all privileges are revoked
if rows.Err() != nil {
return nil, fmt.Errorf("could not generate revocation statements for all rows: %w", rows.Err())
}
if lastStmtError != nil {
return nil, fmt.Errorf("could not perform all revocation statements: %w", lastStmtError)
}
// Drop this user
stmt, err = db.Prepare(fmt.Sprintf(
`DROP ROLE IF EXISTS %s;`, dbutil.QuoteIdentifier(username)))
if err != nil {
return nil, err
}
defer stmt.Close()
if _, err := stmt.Exec(); err != nil {
return nil, err
}
// We have revocation SQL, execute directly, within a transaction
default:
tx, err := db.Begin()
if err != nil {
return nil, err
}
defer func() {
tx.Rollback()
}()
for _, query := range strutil.ParseArbitraryStringSlice(revocationSQL, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
}
m := map[string]string{
"name": username,
}
if err := dbtxn.ExecuteTxQueryDirect(ctx, tx, m, query); err != nil {
return nil, err
}
}
if err := tx.Commit(); err != nil {
return nil, err
}
}
return resp, nil
}

6
changelog/18039.txt Normal file
View File

@ -0,0 +1,6 @@
```release-note:improvement
plugins: Mark logical database plugins Removed and remove the plugin code.
```
```release-note:improvement
plugins: Mark app-id auth method Removed and remove the plugin code.
```

View File

@ -49,18 +49,6 @@ func TestAuthEnableCommand_Run(t *testing.T) {
"",
2,
},
{
"deprecated builtin with standard mount",
[]string{"app-id"},
"mount entry associated with pending removal builtin",
2,
},
{
"deprecated builtin with different mount",
[]string{"-path=/tmp", "app-id"},
"mount entry associated with pending removal builtin",
2,
},
}
for _, tc := range cases {

View File

@ -343,11 +343,9 @@ func TestPredict_Plugins(t *testing.T) {
[]string{
"ad",
"alicloud",
"app-id",
"approle",
"aws",
"azure",
"cassandra",
"cassandra-database-plugin",
"centrify",
"cert",
@ -367,13 +365,10 @@ func TestPredict_Plugins(t *testing.T) {
"kubernetes",
"kv",
"ldap",
"mongodb",
"mongodb-database-plugin",
"mongodbatlas",
"mongodbatlas-database-plugin",
"mssql",
"mssql-database-plugin",
"mysql",
"mysql-aurora-database-plugin",
"mysql-database-plugin",
"mysql-legacy-database-plugin",
@ -385,7 +380,6 @@ func TestPredict_Plugins(t *testing.T) {
"openldap",
"pcf", // Deprecated.
"pki",
"postgresql",
"postgresql-database-plugin",
"rabbitmq",
"radius",
@ -439,7 +433,7 @@ func TestPredict_Plugins(t *testing.T) {
}
}
if !reflect.DeepEqual(act, tc.exp) {
t.Errorf("expected:%q, got: %q", tc.exp, act)
t.Errorf("expected: %q, got: %q, diff: %v", tc.exp, act, strutil.Difference(act, tc.exp, true))
}
})
}

View File

@ -1,92 +0,0 @@
package command
import (
"testing"
"github.com/hashicorp/vault/api"
vaulthttp "github.com/hashicorp/vault/http"
"github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/vault"
credAppId "github.com/hashicorp/vault/builtin/credential/app-id"
)
func TestPathMap_Upgrade_API(t *testing.T) {
var err error
coreConfig := &vault.CoreConfig{
DisableMlock: true,
DisableCache: true,
CredentialBackends: map[string]logical.Factory{
"app-id": credAppId.Factory,
},
PendingRemovalMountsAllowed: true,
}
cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{
HandlerFunc: vaulthttp.Handler,
})
cluster.Start()
defer cluster.Cleanup()
cores := cluster.Cores
vault.TestWaitActive(t, cores[0].Core)
client := cores[0].Client
// Enable the app-id method
err = client.Sys().EnableAuthWithOptions("app-id", &api.EnableAuthOptions{
Type: "app-id",
})
if err != nil {
t.Fatal(err)
}
// Create an app-id
_, err = client.Logical().Write("auth/app-id/map/app-id/test-app-id", map[string]interface{}{
"policy": "test-policy",
})
if err != nil {
t.Fatal(err)
}
// Create a user-id
_, err = client.Logical().Write("auth/app-id/map/user-id/test-user-id", map[string]interface{}{
"value": "test-app-id",
})
if err != nil {
t.Fatal(err)
}
// Perform a login. It should succeed.
_, err = client.Logical().Write("auth/app-id/login", map[string]interface{}{
"app_id": "test-app-id",
"user_id": "test-user-id",
})
if err != nil {
t.Fatal(err)
}
// List the hashed app-ids in the storage
secret, err := client.Logical().List("auth/app-id/map/app-id")
if err != nil {
t.Fatal(err)
}
hashedAppID := secret.Data["keys"].([]interface{})[0].(string)
// Try reading it. This used to cause an issue which is fixed in [GH-3806].
_, err = client.Logical().Read("auth/app-id/map/app-id/" + hashedAppID)
if err != nil {
t.Fatal(err)
}
// Ensure that there was no issue by performing another login
_, err = client.Logical().Write("auth/app-id/login", map[string]interface{}{
"app_id": "test-app-id",
"user_id": "test-user-id",
})
if err != nil {
t.Fatal(err)
}
}

1
go.mod
View File

@ -202,7 +202,6 @@ require (
google.golang.org/grpc v1.50.1
google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.1.0
google.golang.org/protobuf v1.28.1
gopkg.in/mgo.v2 v2.0.0-20180705113604-9856a29383ce
gopkg.in/ory-am/dockertest.v3 v3.3.4
gopkg.in/square/go-jose.v2 v2.6.0
k8s.io/utils v0.0.0-20220728103510-ee6ede2d64ed

2
go.sum
View File

@ -2599,8 +2599,6 @@ gopkg.in/ini.v1 v1.66.2 h1:XfR1dOYubytKy4Shzc2LHrrGhU0lDCfDGG1yLPmpgsI=
gopkg.in/ini.v1 v1.66.2/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
gopkg.in/jcmturner/goidentity.v3 v3.0.0 h1:1duIyWiTaYvVx3YX2CYtpJbUFd7/UuPYCfgXtQ3VTbI=
gopkg.in/jcmturner/goidentity.v3 v3.0.0/go.mod h1:oG2kH0IvSYNIu80dVAyu/yoefjq1mNfM5bm88whjWx4=
gopkg.in/mgo.v2 v2.0.0-20180705113604-9856a29383ce h1:xcEWjVhvbDy+nHP67nPDDpbYrY+ILlfndk4bRioVHaU=
gopkg.in/mgo.v2 v2.0.0-20180705113604-9856a29383ce/go.mod h1:yeKp02qBN3iKW1OzL3MGk2IdtZzaj7SFntXj72NppTA=
gopkg.in/natefinch/lumberjack.v2 v2.0.0/go.mod h1:l0ndWWf7gzL7RNwBG7wST/UCcT4T24xpD6X8LsfU/+k=
gopkg.in/ory-am/dockertest.v3 v3.3.4 h1:oen8RiwxVNxtQ1pRoV4e4jqh6UjNsOuIZ1NXns6jdcw=
gopkg.in/ory-am/dockertest.v3 v3.3.4/go.mod h1:s9mmoLkaGeAh97qygnNj4xWkiN7e1SKekYC6CovU+ek=

View File

@ -1,6 +1,8 @@
package builtinplugins
import (
"context"
credAliCloud "github.com/hashicorp/vault-plugin-auth-alicloud"
credAzure "github.com/hashicorp/vault-plugin-auth-azure"
credCentrify "github.com/hashicorp/vault-plugin-auth-centrify"
@ -26,7 +28,6 @@ import (
logicalMongoAtlas "github.com/hashicorp/vault-plugin-secrets-mongodbatlas"
logicalLDAP "github.com/hashicorp/vault-plugin-secrets-openldap"
logicalTerraform "github.com/hashicorp/vault-plugin-secrets-terraform"
credAppId "github.com/hashicorp/vault/builtin/credential/app-id"
credAppRole "github.com/hashicorp/vault/builtin/credential/approle"
credAws "github.com/hashicorp/vault/builtin/credential/aws"
credCert "github.com/hashicorp/vault/builtin/credential/cert"
@ -36,14 +37,9 @@ import (
credRadius "github.com/hashicorp/vault/builtin/credential/radius"
credUserpass "github.com/hashicorp/vault/builtin/credential/userpass"
logicalAws "github.com/hashicorp/vault/builtin/logical/aws"
logicalCass "github.com/hashicorp/vault/builtin/logical/cassandra"
logicalConsul "github.com/hashicorp/vault/builtin/logical/consul"
logicalMongo "github.com/hashicorp/vault/builtin/logical/mongodb"
logicalMssql "github.com/hashicorp/vault/builtin/logical/mssql"
logicalMysql "github.com/hashicorp/vault/builtin/logical/mysql"
logicalNomad "github.com/hashicorp/vault/builtin/logical/nomad"
logicalPki "github.com/hashicorp/vault/builtin/logical/pki"
logicalPostgres "github.com/hashicorp/vault/builtin/logical/postgresql"
logicalRabbit "github.com/hashicorp/vault/builtin/logical/rabbitmq"
logicalSsh "github.com/hashicorp/vault/builtin/logical/ssh"
logicalTotp "github.com/hashicorp/vault/builtin/logical/totp"
@ -56,6 +52,7 @@ import (
dbMysql "github.com/hashicorp/vault/plugins/database/mysql"
dbPostgres "github.com/hashicorp/vault/plugins/database/postgresql"
dbRedshift "github.com/hashicorp/vault/plugins/database/redshift"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/consts"
"github.com/hashicorp/vault/sdk/logical"
)
@ -86,13 +83,23 @@ type logicalBackend struct {
consts.DeprecationStatus
}
type removedBackend struct {
*framework.Backend
}
func removedFactory(ctx context.Context, config *logical.BackendConfig) (logical.Backend, error) {
removedBackend := &removedBackend{}
removedBackend.Backend = &framework.Backend{}
return removedBackend, nil
}
func newRegistry() *registry {
reg := &registry{
credentialBackends: map[string]credentialBackend{
"alicloud": {Factory: credAliCloud.Factory},
"app-id": {
Factory: credAppId.Factory,
DeprecationStatus: consts.PendingRemoval,
Factory: removedFactory,
DeprecationStatus: consts.Removed,
},
"approle": {Factory: credAppRole.Factory},
"aws": {Factory: credAws.Factory},
@ -144,8 +151,8 @@ func newRegistry() *registry {
"aws": {Factory: logicalAws.Factory},
"azure": {Factory: logicalAzure.Factory},
"cassandra": {
Factory: logicalCass.Factory,
DeprecationStatus: consts.PendingRemoval,
Factory: removedFactory,
DeprecationStatus: consts.Removed,
},
"consul": {Factory: logicalConsul.Factory},
"gcp": {Factory: logicalGcp.Factory},
@ -153,25 +160,27 @@ func newRegistry() *registry {
"kubernetes": {Factory: logicalKube.Factory},
"kv": {Factory: logicalKv.Factory},
"mongodb": {
Factory: logicalMongo.Factory,
DeprecationStatus: consts.PendingRemoval,
Factory: removedFactory,
DeprecationStatus: consts.Removed,
},
// The mongodbatlas secrets engine is not the same as the database plugin equivalent
// (`mongodbatlas-database-plugin`), and thus will not be deprecated at this time.
"mongodbatlas": {Factory: logicalMongoAtlas.Factory},
"mssql": {
Factory: logicalMssql.Factory,
DeprecationStatus: consts.PendingRemoval,
Factory: removedFactory,
DeprecationStatus: consts.Removed,
},
"mysql": {
Factory: logicalMysql.Factory,
DeprecationStatus: consts.PendingRemoval,
Factory: removedFactory,
DeprecationStatus: consts.Removed,
},
"nomad": {Factory: logicalNomad.Factory},
"openldap": {Factory: logicalLDAP.Factory},
"ldap": {Factory: logicalLDAP.Factory},
"pki": {Factory: logicalPki.Factory},
"postgresql": {
Factory: logicalPostgres.Factory,
DeprecationStatus: consts.PendingRemoval,
Factory: removedFactory,
DeprecationStatus: consts.Removed,
},
"rabbitmq": {Factory: logicalRabbit.Factory},
"ssh": {Factory: logicalSsh.Factory},
@ -222,16 +231,16 @@ func (r *registry) Keys(pluginType consts.PluginType) []string {
var keys []string
switch pluginType {
case consts.PluginTypeDatabase:
for key := range r.databasePlugins {
keys = append(keys, key)
for key, backend := range r.databasePlugins {
keys = appendIfNotRemoved(keys, key, backend.DeprecationStatus)
}
case consts.PluginTypeCredential:
for key := range r.credentialBackends {
keys = append(keys, key)
for key, backend := range r.credentialBackends {
keys = appendIfNotRemoved(keys, key, backend.DeprecationStatus)
}
case consts.PluginTypeSecrets:
for key := range r.logicalBackends {
keys = append(keys, key)
for key, backend := range r.logicalBackends {
keys = appendIfNotRemoved(keys, key, backend.DeprecationStatus)
}
}
return keys
@ -273,3 +282,10 @@ func toFunc(ifc interface{}) func() (interface{}, error) {
return ifc, nil
}
}
func appendIfNotRemoved(keys []string, name string, status consts.DeprecationStatus) []string {
if status != consts.Removed {
return append(keys, name)
}
return keys
}

View File

@ -4,7 +4,7 @@ import (
"reflect"
"testing"
credAppId "github.com/hashicorp/vault/builtin/credential/app-id"
credUserpass "github.com/hashicorp/vault/builtin/credential/userpass"
dbMysql "github.com/hashicorp/vault/plugins/database/mysql"
"github.com/hashicorp/vault/sdk/helper/consts"
)
@ -35,9 +35,16 @@ func Test_RegistryGet(t *testing.T) {
},
{
name: "known builtin lookup",
builtin: "userpass",
pluginType: consts.PluginTypeCredential,
want: toFunc(credUserpass.Factory),
wantOk: true,
},
{
name: "removed builtin lookup",
builtin: "app-id",
pluginType: consts.PluginTypeCredential,
want: toFunc(credAppId.Factory),
want: nil,
wantOk: true,
},
{
@ -81,7 +88,7 @@ func Test_RegistryKeyCounts(t *testing.T) {
{
name: "number of auth plugins",
pluginType: consts.PluginTypeCredential,
want: 20,
want: 19,
},
{
name: "number of database plugins",
@ -91,7 +98,7 @@ func Test_RegistryKeyCounts(t *testing.T) {
{
name: "number of secrets plugins",
pluginType: consts.PluginTypeSecrets,
want: 24,
want: 19,
},
}
for _, tt := range tests {
@ -126,10 +133,16 @@ func Test_RegistryContains(t *testing.T) {
},
{
name: "known builtin lookup",
builtin: "app-id",
builtin: "approle",
pluginType: consts.PluginTypeCredential,
want: true,
},
{
name: "removed builtin lookup",
builtin: "app-id",
pluginType: consts.PluginTypeCredential,
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@ -186,10 +199,10 @@ func Test_RegistryStatus(t *testing.T) {
wantOk: true,
},
{
name: "pending removal builtin lookup",
name: "removed builtin lookup",
builtin: "app-id",
pluginType: consts.PluginTypeCredential,
want: consts.PendingRemoval,
want: consts.Removed,
wantOk: true,
},
}