Mark deprecated builtins Removed (#18039)
* Remove logical database builtins * Drop removed builtins from registry keys * Update plugin prediction test * Remove app-id builtin * Add changelog
This commit is contained in:
parent
25d0afae23
commit
43a78c85f4
|
@ -1,184 +0,0 @@
|
|||
package appId
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/hashicorp/vault/sdk/framework"
|
||||
"github.com/hashicorp/vault/sdk/helper/salt"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
)
|
||||
|
||||
func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
|
||||
b, err := Backend(conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := b.Setup(ctx, conf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return b, nil
|
||||
}
|
||||
|
||||
func Backend(conf *logical.BackendConfig) (*backend, error) {
|
||||
var b backend
|
||||
b.MapAppId = &framework.PolicyMap{
|
||||
PathMap: framework.PathMap{
|
||||
Name: "app-id",
|
||||
Schema: map[string]*framework.FieldSchema{
|
||||
"display_name": {
|
||||
Type: framework.TypeString,
|
||||
Description: "A name to map to this app ID for logs.",
|
||||
},
|
||||
|
||||
"value": {
|
||||
Type: framework.TypeString,
|
||||
Description: "Policies for the app ID.",
|
||||
},
|
||||
},
|
||||
},
|
||||
DefaultKey: "default",
|
||||
}
|
||||
|
||||
b.MapUserId = &framework.PathMap{
|
||||
Name: "user-id",
|
||||
Schema: map[string]*framework.FieldSchema{
|
||||
"cidr_block": {
|
||||
Type: framework.TypeString,
|
||||
Description: "If not blank, restricts auth by this CIDR block",
|
||||
},
|
||||
|
||||
"value": {
|
||||
Type: framework.TypeString,
|
||||
Description: "App IDs that this user associates with.",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
b.Backend = &framework.Backend{
|
||||
Help: backendHelp,
|
||||
|
||||
PathsSpecial: &logical.Paths{
|
||||
Unauthenticated: []string{
|
||||
"login",
|
||||
"login/*",
|
||||
},
|
||||
},
|
||||
Paths: framework.PathAppend([]*framework.Path{
|
||||
pathLogin(&b),
|
||||
pathLoginWithAppIDPath(&b),
|
||||
},
|
||||
b.MapAppId.Paths(),
|
||||
b.MapUserId.Paths(),
|
||||
),
|
||||
AuthRenew: b.pathLoginRenew,
|
||||
Invalidate: b.invalidate,
|
||||
BackendType: logical.TypeCredential,
|
||||
}
|
||||
|
||||
b.view = conf.StorageView
|
||||
b.MapAppId.SaltFunc = b.Salt
|
||||
b.MapUserId.SaltFunc = b.Salt
|
||||
|
||||
return &b, nil
|
||||
}
|
||||
|
||||
type backend struct {
|
||||
*framework.Backend
|
||||
|
||||
salt *salt.Salt
|
||||
SaltMutex sync.RWMutex
|
||||
view logical.Storage
|
||||
MapAppId *framework.PolicyMap
|
||||
MapUserId *framework.PathMap
|
||||
}
|
||||
|
||||
func (b *backend) Salt(ctx context.Context) (*salt.Salt, error) {
|
||||
b.SaltMutex.RLock()
|
||||
if b.salt != nil {
|
||||
defer b.SaltMutex.RUnlock()
|
||||
return b.salt, nil
|
||||
}
|
||||
b.SaltMutex.RUnlock()
|
||||
b.SaltMutex.Lock()
|
||||
defer b.SaltMutex.Unlock()
|
||||
if b.salt != nil {
|
||||
return b.salt, nil
|
||||
}
|
||||
salt, err := salt.NewSalt(ctx, b.view, &salt.Config{
|
||||
HashFunc: salt.SHA1Hash,
|
||||
Location: salt.DefaultLocation,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
b.salt = salt
|
||||
return salt, nil
|
||||
}
|
||||
|
||||
func (b *backend) invalidate(_ context.Context, key string) {
|
||||
switch key {
|
||||
case salt.DefaultLocation:
|
||||
b.SaltMutex.Lock()
|
||||
defer b.SaltMutex.Unlock()
|
||||
b.salt = nil
|
||||
}
|
||||
}
|
||||
|
||||
const backendHelp = `
|
||||
The App ID credential provider is used to perform authentication from
|
||||
within applications or machine by pairing together two hard-to-guess
|
||||
unique pieces of information: a unique app ID, and a unique user ID.
|
||||
|
||||
The goal of this credential provider is to allow elastic users
|
||||
(dynamic machines, containers, etc.) to authenticate with Vault without
|
||||
having to store passwords outside of Vault. It is a single method of
|
||||
solving the chicken-and-egg problem of setting up Vault access on a machine.
|
||||
With this provider, nobody except the machine itself has access to both
|
||||
pieces of information necessary to authenticate. For example:
|
||||
configuration management will have the app IDs, but the machine itself
|
||||
will detect its user ID based on some unique machine property such as a
|
||||
MAC address (or a hash of it with some salt).
|
||||
|
||||
An example, real world process for using this provider:
|
||||
|
||||
1. Create unique app IDs (UUIDs work well) and map them to policies.
|
||||
(Path: map/app-id/<app-id>)
|
||||
|
||||
2. Store the app IDs within configuration management systems.
|
||||
|
||||
3. An out-of-band process run by security operators map unique user IDs
|
||||
to these app IDs. Example: when an instance is launched, a cloud-init
|
||||
system tells security operators a unique ID for this machine. This
|
||||
process can be scripted, but the key is that it is out-of-band and
|
||||
out of reach of configuration management.
|
||||
(Path: map/user-id/<user-id>)
|
||||
|
||||
4. A new server is provisioned. Configuration management configures the
|
||||
app ID, the server itself detects its user ID. With both of these
|
||||
pieces of information, Vault can be accessed according to the policy
|
||||
set by the app ID.
|
||||
|
||||
More details on this process follow:
|
||||
|
||||
The app ID is a unique ID that maps to a set of policies. This ID is
|
||||
generated by an operator and configured into the backend. The ID itself
|
||||
is usually a UUID, but any hard-to-guess unique value can be used.
|
||||
|
||||
After creating app IDs, an operator authorizes a fixed set of user IDs
|
||||
with each app ID. When a valid {app ID, user ID} tuple is given to the
|
||||
"login" path, then the user is authenticated with the configured app
|
||||
ID policies.
|
||||
|
||||
The user ID can be any value (just like the app ID), however it is
|
||||
generally a value unique to a machine, such as a MAC address or instance ID,
|
||||
or a value hashed from these unique values.
|
||||
|
||||
It is possible to authorize multiple app IDs with each
|
||||
user ID by writing them as comma-separated values to the map/user-id/<user-id>
|
||||
path.
|
||||
|
||||
It is also possible to renew the auth tokens with 'vault token-renew <token>' command.
|
||||
Before the token is renewed, the validity of app ID, user ID and the associated
|
||||
policies are checked again.
|
||||
`
|
|
@ -1,239 +0,0 @@
|
|||
package appId
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
logicaltest "github.com/hashicorp/vault/helper/testhelpers/logical"
|
||||
"github.com/hashicorp/vault/sdk/helper/salt"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
)
|
||||
|
||||
func TestBackend_basic(t *testing.T) {
|
||||
var b *backend
|
||||
var err error
|
||||
var storage logical.Storage
|
||||
factory := func(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
|
||||
b, err = Backend(conf)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
storage = conf.StorageView
|
||||
if err := b.Setup(ctx, conf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return b, nil
|
||||
}
|
||||
logicaltest.Test(t, logicaltest.TestCase{
|
||||
CredentialFactory: factory,
|
||||
Steps: []logicaltest.TestStep{
|
||||
testAccStepMapAppId(t),
|
||||
testAccStepMapUserId(t),
|
||||
testAccLogin(t, ""),
|
||||
testAccLoginAppIDInPath(t, ""),
|
||||
testAccLoginInvalid(t),
|
||||
testAccStepDeleteUserId(t),
|
||||
testAccLoginDeleted(t),
|
||||
},
|
||||
})
|
||||
|
||||
req := &logical.Request{
|
||||
Path: "map/app-id",
|
||||
Operation: logical.ListOperation,
|
||||
Storage: storage,
|
||||
}
|
||||
resp, err := b.HandleRequest(context.Background(), req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp == nil {
|
||||
t.Fatal("nil response")
|
||||
}
|
||||
keys := resp.Data["keys"].([]string)
|
||||
if len(keys) != 1 {
|
||||
t.Fatalf("expected 1 key, got %d", len(keys))
|
||||
}
|
||||
bSalt, err := b.Salt(context.Background())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if keys[0] != "s"+bSalt.SaltIDHashFunc("foo", salt.SHA256Hash) {
|
||||
t.Fatal("value was improperly salted")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackend_cidr(t *testing.T) {
|
||||
logicaltest.Test(t, logicaltest.TestCase{
|
||||
CredentialFactory: Factory,
|
||||
Steps: []logicaltest.TestStep{
|
||||
testAccStepMapAppIdDisplayName(t),
|
||||
testAccStepMapUserIdCidr(t, "192.168.1.0/16"),
|
||||
testAccLoginCidr(t, "192.168.1.5", false),
|
||||
testAccLoginCidr(t, "10.0.1.5", true),
|
||||
testAccLoginCidr(t, "", true),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func TestBackend_displayName(t *testing.T) {
|
||||
logicaltest.Test(t, logicaltest.TestCase{
|
||||
CredentialFactory: Factory,
|
||||
Steps: []logicaltest.TestStep{
|
||||
testAccStepMapAppIdDisplayName(t),
|
||||
testAccStepMapUserId(t),
|
||||
testAccLogin(t, "tubbin"),
|
||||
testAccLoginAppIDInPath(t, "tubbin"),
|
||||
testAccLoginInvalid(t),
|
||||
testAccStepDeleteUserId(t),
|
||||
testAccLoginDeleted(t),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func testAccStepMapAppId(t *testing.T) logicaltest.TestStep {
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "map/app-id/foo",
|
||||
Data: map[string]interface{}{
|
||||
"value": "foo,bar",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func testAccStepMapAppIdDisplayName(t *testing.T) logicaltest.TestStep {
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "map/app-id/foo",
|
||||
Data: map[string]interface{}{
|
||||
"display_name": "tubbin",
|
||||
"value": "foo,bar",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func testAccStepMapUserId(t *testing.T) logicaltest.TestStep {
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "map/user-id/42",
|
||||
Data: map[string]interface{}{
|
||||
"value": "foo",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func testAccStepDeleteUserId(t *testing.T) logicaltest.TestStep {
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.DeleteOperation,
|
||||
Path: "map/user-id/42",
|
||||
}
|
||||
}
|
||||
|
||||
func testAccStepMapUserIdCidr(t *testing.T, cidr string) logicaltest.TestStep {
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "map/user-id/42",
|
||||
Data: map[string]interface{}{
|
||||
"value": "foo",
|
||||
"cidr_block": cidr,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func testAccLogin(t *testing.T, display string) logicaltest.TestStep {
|
||||
checkTTL := func(resp *logical.Response) error {
|
||||
if resp.Auth.LeaseOptions.TTL.String() != "768h0m0s" {
|
||||
return fmt.Errorf("invalid TTL: got %s", resp.Auth.LeaseOptions.TTL)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "login",
|
||||
Data: map[string]interface{}{
|
||||
"app_id": "foo",
|
||||
"user_id": "42",
|
||||
},
|
||||
Unauthenticated: true,
|
||||
|
||||
Check: logicaltest.TestCheckMulti(
|
||||
logicaltest.TestCheckAuth([]string{"bar", "default", "foo"}),
|
||||
logicaltest.TestCheckAuthDisplayName(display),
|
||||
checkTTL,
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
func testAccLoginAppIDInPath(t *testing.T, display string) logicaltest.TestStep {
|
||||
checkTTL := func(resp *logical.Response) error {
|
||||
if resp.Auth.LeaseOptions.TTL.String() != "768h0m0s" {
|
||||
return fmt.Errorf("invalid TTL: got %s", resp.Auth.LeaseOptions.TTL)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "login/foo",
|
||||
Data: map[string]interface{}{
|
||||
"user_id": "42",
|
||||
},
|
||||
Unauthenticated: true,
|
||||
|
||||
Check: logicaltest.TestCheckMulti(
|
||||
logicaltest.TestCheckAuth([]string{"bar", "default", "foo"}),
|
||||
logicaltest.TestCheckAuthDisplayName(display),
|
||||
checkTTL,
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
func testAccLoginCidr(t *testing.T, ip string, err bool) logicaltest.TestStep {
|
||||
check := logicaltest.TestCheckError()
|
||||
if !err {
|
||||
check = logicaltest.TestCheckAuth([]string{"bar", "default", "foo"})
|
||||
}
|
||||
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "login",
|
||||
Data: map[string]interface{}{
|
||||
"app_id": "foo",
|
||||
"user_id": "42",
|
||||
},
|
||||
ErrorOk: err,
|
||||
Unauthenticated: true,
|
||||
RemoteAddr: ip,
|
||||
|
||||
Check: check,
|
||||
}
|
||||
}
|
||||
|
||||
func testAccLoginInvalid(t *testing.T) logicaltest.TestStep {
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "login",
|
||||
Data: map[string]interface{}{
|
||||
"app_id": "foo",
|
||||
"user_id": "48",
|
||||
},
|
||||
ErrorOk: true,
|
||||
Unauthenticated: true,
|
||||
|
||||
Check: logicaltest.TestCheckError(),
|
||||
}
|
||||
}
|
||||
|
||||
func testAccLoginDeleted(t *testing.T) logicaltest.TestStep {
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "login",
|
||||
Data: map[string]interface{}{
|
||||
"app_id": "foo",
|
||||
"user_id": "42",
|
||||
},
|
||||
ErrorOk: true,
|
||||
Unauthenticated: true,
|
||||
|
||||
Check: logicaltest.TestCheckError(),
|
||||
}
|
||||
}
|
|
@ -1,29 +0,0 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
hclog "github.com/hashicorp/go-hclog"
|
||||
"github.com/hashicorp/vault/api"
|
||||
appId "github.com/hashicorp/vault/builtin/credential/app-id"
|
||||
"github.com/hashicorp/vault/sdk/plugin"
|
||||
)
|
||||
|
||||
func main() {
|
||||
apiClientMeta := &api.PluginAPIClientMeta{}
|
||||
flags := apiClientMeta.FlagSet()
|
||||
flags.Parse(os.Args[1:])
|
||||
|
||||
tlsConfig := apiClientMeta.GetTLSConfig()
|
||||
tlsProviderFunc := api.VaultPluginTLSProvider(tlsConfig)
|
||||
|
||||
if err := plugin.Serve(&plugin.ServeOpts{
|
||||
BackendFactoryFunc: appId.Factory,
|
||||
TLSProviderFunc: tlsProviderFunc,
|
||||
}); err != nil {
|
||||
logger := hclog.New(&hclog.LoggerOptions{})
|
||||
|
||||
logger.Error("plugin shutting down", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
|
@ -1,229 +0,0 @@
|
|||
package appId
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha1"
|
||||
"crypto/subtle"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/vault/sdk/framework"
|
||||
"github.com/hashicorp/vault/sdk/helper/policyutil"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
)
|
||||
|
||||
func pathLoginWithAppIDPath(b *backend) *framework.Path {
|
||||
return &framework.Path{
|
||||
Pattern: "login/(?P<app_id>.+)",
|
||||
Fields: map[string]*framework.FieldSchema{
|
||||
"app_id": {
|
||||
Type: framework.TypeString,
|
||||
Description: "The unique app ID",
|
||||
},
|
||||
|
||||
"user_id": {
|
||||
Type: framework.TypeString,
|
||||
Description: "The unique user ID",
|
||||
},
|
||||
},
|
||||
|
||||
Callbacks: map[logical.Operation]framework.OperationFunc{
|
||||
logical.UpdateOperation: b.pathLogin,
|
||||
},
|
||||
|
||||
HelpSynopsis: pathLoginSyn,
|
||||
HelpDescription: pathLoginDesc,
|
||||
}
|
||||
}
|
||||
|
||||
func pathLogin(b *backend) *framework.Path {
|
||||
return &framework.Path{
|
||||
Pattern: "login$",
|
||||
Fields: map[string]*framework.FieldSchema{
|
||||
"app_id": {
|
||||
Type: framework.TypeString,
|
||||
Description: "The unique app ID",
|
||||
},
|
||||
|
||||
"user_id": {
|
||||
Type: framework.TypeString,
|
||||
Description: "The unique user ID",
|
||||
},
|
||||
},
|
||||
|
||||
Callbacks: map[logical.Operation]framework.OperationFunc{
|
||||
logical.UpdateOperation: b.pathLogin,
|
||||
logical.AliasLookaheadOperation: b.pathLoginAliasLookahead,
|
||||
},
|
||||
|
||||
HelpSynopsis: pathLoginSyn,
|
||||
HelpDescription: pathLoginDesc,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *backend) pathLoginAliasLookahead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
appId := data.Get("app_id").(string)
|
||||
|
||||
if appId == "" {
|
||||
return nil, fmt.Errorf("missing app_id")
|
||||
}
|
||||
|
||||
return &logical.Response{
|
||||
Auth: &logical.Auth{
|
||||
Alias: &logical.Alias{
|
||||
Name: appId,
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (b *backend) pathLogin(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
appId := data.Get("app_id").(string)
|
||||
userId := data.Get("user_id").(string)
|
||||
|
||||
var displayName string
|
||||
if dispName, resp, err := b.verifyCredentials(ctx, req, appId, userId); err != nil {
|
||||
return nil, err
|
||||
} else if resp != nil {
|
||||
return resp, nil
|
||||
} else {
|
||||
displayName = dispName
|
||||
}
|
||||
|
||||
// Get the policies associated with the app
|
||||
policies, err := b.MapAppId.Policies(ctx, req.Storage, appId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Store hashes of the app ID and user ID for the metadata
|
||||
appIdHash := sha1.Sum([]byte(appId))
|
||||
userIdHash := sha1.Sum([]byte(userId))
|
||||
metadata := map[string]string{
|
||||
"app-id": "sha1:" + hex.EncodeToString(appIdHash[:]),
|
||||
"user-id": "sha1:" + hex.EncodeToString(userIdHash[:]),
|
||||
}
|
||||
|
||||
return &logical.Response{
|
||||
Auth: &logical.Auth{
|
||||
InternalData: map[string]interface{}{
|
||||
"app-id": appId,
|
||||
"user-id": userId,
|
||||
},
|
||||
DisplayName: displayName,
|
||||
Policies: policies,
|
||||
Metadata: metadata,
|
||||
LeaseOptions: logical.LeaseOptions{
|
||||
Renewable: true,
|
||||
},
|
||||
Alias: &logical.Alias{
|
||||
Name: appId,
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (b *backend) pathLoginRenew(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
appId := req.Auth.InternalData["app-id"].(string)
|
||||
userId := req.Auth.InternalData["user-id"].(string)
|
||||
|
||||
// Skipping CIDR verification to enable renewal from machines other than
|
||||
// the ones encompassed by CIDR block.
|
||||
if _, resp, err := b.verifyCredentials(ctx, req, appId, userId); err != nil {
|
||||
return nil, err
|
||||
} else if resp != nil {
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// Get the policies associated with the app
|
||||
mapPolicies, err := b.MapAppId.Policies(ctx, req.Storage, appId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !policyutil.EquivalentPolicies(mapPolicies, req.Auth.TokenPolicies) {
|
||||
return nil, fmt.Errorf("policies do not match")
|
||||
}
|
||||
|
||||
return &logical.Response{Auth: req.Auth}, nil
|
||||
}
|
||||
|
||||
func (b *backend) verifyCredentials(ctx context.Context, req *logical.Request, appId, userId string) (string, *logical.Response, error) {
|
||||
// Ensure both appId and userId are provided
|
||||
if appId == "" || userId == "" {
|
||||
return "", logical.ErrorResponse("missing 'app_id' or 'user_id'"), nil
|
||||
}
|
||||
|
||||
// Look up the apps that this user is allowed to access
|
||||
appsMap, err := b.MapUserId.Get(ctx, req.Storage, userId)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
if appsMap == nil {
|
||||
return "", logical.ErrorResponse("invalid user ID or app ID"), nil
|
||||
}
|
||||
|
||||
// If there is a CIDR block restriction, check that
|
||||
if raw, ok := appsMap["cidr_block"]; ok {
|
||||
_, cidr, err := net.ParseCIDR(raw.(string))
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("invalid restriction cidr: %w", err)
|
||||
}
|
||||
|
||||
var addr string
|
||||
if req.Connection != nil {
|
||||
addr = req.Connection.RemoteAddr
|
||||
}
|
||||
if addr == "" || !cidr.Contains(net.ParseIP(addr)) {
|
||||
return "", logical.ErrorResponse("unauthorized source address"), nil
|
||||
}
|
||||
}
|
||||
|
||||
appsRaw, ok := appsMap["value"]
|
||||
if !ok {
|
||||
appsRaw = ""
|
||||
}
|
||||
|
||||
apps, ok := appsRaw.(string)
|
||||
if !ok {
|
||||
return "", nil, fmt.Errorf("mapping is not a string")
|
||||
}
|
||||
|
||||
// Verify that the app is in the list
|
||||
found := false
|
||||
appIdBytes := []byte(appId)
|
||||
for _, app := range strings.Split(apps, ",") {
|
||||
match := []byte(strings.TrimSpace(app))
|
||||
// Protect against a timing attack with the app_id comparison
|
||||
if subtle.ConstantTimeCompare(match, appIdBytes) == 1 {
|
||||
found = true
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return "", logical.ErrorResponse("invalid user ID or app ID"), nil
|
||||
}
|
||||
|
||||
// Get the raw data associated with the app
|
||||
appRaw, err := b.MapAppId.Get(ctx, req.Storage, appId)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
if appRaw == nil {
|
||||
return "", logical.ErrorResponse("invalid user ID or app ID"), nil
|
||||
}
|
||||
var displayName string
|
||||
if raw, ok := appRaw["display_name"]; ok {
|
||||
displayName = raw.(string)
|
||||
}
|
||||
|
||||
return displayName, nil, nil
|
||||
}
|
||||
|
||||
const pathLoginSyn = `
|
||||
Log in with an App ID and User ID.
|
||||
`
|
||||
|
||||
const pathLoginDesc = `
|
||||
This endpoint authenticates using an application ID, user ID and potential the IP address of the connecting client.
|
||||
`
|
|
@ -1,134 +0,0 @@
|
|||
package cassandra
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/gocql/gocql"
|
||||
"github.com/hashicorp/vault/sdk/framework"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
)
|
||||
|
||||
// Factory creates a new backend
|
||||
func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
|
||||
b := Backend()
|
||||
if err := b.Setup(ctx, conf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return b, nil
|
||||
}
|
||||
|
||||
// Backend contains the base information for the backend's functionality
|
||||
func Backend() *backend {
|
||||
var b backend
|
||||
b.Backend = &framework.Backend{
|
||||
Help: strings.TrimSpace(backendHelp),
|
||||
|
||||
PathsSpecial: &logical.Paths{
|
||||
SealWrapStorage: []string{
|
||||
"config/connection",
|
||||
},
|
||||
},
|
||||
|
||||
Paths: []*framework.Path{
|
||||
pathConfigConnection(&b),
|
||||
pathRoles(&b),
|
||||
pathCredsCreate(&b),
|
||||
},
|
||||
|
||||
Secrets: []*framework.Secret{
|
||||
secretCreds(&b),
|
||||
},
|
||||
|
||||
Invalidate: b.invalidate,
|
||||
|
||||
Clean: func(_ context.Context) {
|
||||
b.ResetDB(nil)
|
||||
},
|
||||
BackendType: logical.TypeLogical,
|
||||
}
|
||||
|
||||
return &b
|
||||
}
|
||||
|
||||
type backend struct {
|
||||
*framework.Backend
|
||||
|
||||
// Session is goroutine safe, however, since we reinitialize
|
||||
// it when connection info changes, we want to make sure we
|
||||
// can close it and use a new connection; hence the lock
|
||||
session *gocql.Session
|
||||
lock sync.Mutex
|
||||
}
|
||||
|
||||
type sessionConfig struct {
|
||||
Hosts string `json:"hosts" structs:"hosts" mapstructure:"hosts"`
|
||||
Username string `json:"username" structs:"username" mapstructure:"username"`
|
||||
Password string `json:"password" structs:"password" mapstructure:"password"`
|
||||
TLS bool `json:"tls" structs:"tls" mapstructure:"tls"`
|
||||
InsecureTLS bool `json:"insecure_tls" structs:"insecure_tls" mapstructure:"insecure_tls"`
|
||||
Certificate string `json:"certificate" structs:"certificate" mapstructure:"certificate"`
|
||||
PrivateKey string `json:"private_key" structs:"private_key" mapstructure:"private_key"`
|
||||
IssuingCA string `json:"issuing_ca" structs:"issuing_ca" mapstructure:"issuing_ca"`
|
||||
ProtocolVersion int `json:"protocol_version" structs:"protocol_version" mapstructure:"protocol_version"`
|
||||
ConnectTimeout int `json:"connect_timeout" structs:"connect_timeout" mapstructure:"connect_timeout"`
|
||||
TLSMinVersion string `json:"tls_min_version" structs:"tls_min_version" mapstructure:"tls_min_version"`
|
||||
}
|
||||
|
||||
// DB returns the database connection.
|
||||
func (b *backend) DB(ctx context.Context, s logical.Storage) (*gocql.Session, error) {
|
||||
b.lock.Lock()
|
||||
defer b.lock.Unlock()
|
||||
|
||||
// If we already have a DB, we got it!
|
||||
if b.session != nil {
|
||||
return b.session, nil
|
||||
}
|
||||
|
||||
entry, err := s.Get(ctx, "config/connection")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if entry == nil {
|
||||
return nil, fmt.Errorf("configure the DB connection with config/connection first")
|
||||
}
|
||||
|
||||
config := &sessionConfig{}
|
||||
if err := entry.DecodeJSON(config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
session, err := createSession(config, s)
|
||||
// Store the session in backend for reuse
|
||||
b.session = session
|
||||
|
||||
return session, err
|
||||
}
|
||||
|
||||
// ResetDB forces a connection next time DB() is called.
|
||||
func (b *backend) ResetDB(newSession *gocql.Session) {
|
||||
b.lock.Lock()
|
||||
defer b.lock.Unlock()
|
||||
|
||||
if b.session != nil {
|
||||
b.session.Close()
|
||||
}
|
||||
|
||||
b.session = newSession
|
||||
}
|
||||
|
||||
func (b *backend) invalidate(_ context.Context, key string) {
|
||||
switch key {
|
||||
case "config/connection":
|
||||
b.ResetDB(nil)
|
||||
}
|
||||
}
|
||||
|
||||
const backendHelp = `
|
||||
The Cassandra backend dynamically generates database users.
|
||||
|
||||
After mounting this backend, configure it using the endpoints within
|
||||
the "config/" path.
|
||||
`
|
|
@ -1,163 +0,0 @@
|
|||
package cassandra
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/vault/helper/testhelpers/cassandra"
|
||||
logicaltest "github.com/hashicorp/vault/helper/testhelpers/logical"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
)
|
||||
|
||||
func TestBackend_basic(t *testing.T) {
|
||||
config := logical.TestBackendConfig()
|
||||
config.StorageView = &logical.InmemStorage{}
|
||||
b, err := Factory(context.Background(), config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
copyFromTo := map[string]string{
|
||||
"test-fixtures/cassandra.yaml": "/etc/cassandra/cassandra.yaml",
|
||||
}
|
||||
host, cleanup := cassandra.PrepareTestContainer(t,
|
||||
cassandra.CopyFromTo(copyFromTo),
|
||||
)
|
||||
defer cleanup()
|
||||
|
||||
logicaltest.Test(t, logicaltest.TestCase{
|
||||
LogicalBackend: b,
|
||||
Steps: []logicaltest.TestStep{
|
||||
testAccStepConfig(t, host.ConnectionURL()),
|
||||
testAccStepRole(t),
|
||||
testAccStepReadCreds(t, "test"),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func TestBackend_roleCrud(t *testing.T) {
|
||||
config := logical.TestBackendConfig()
|
||||
config.StorageView = &logical.InmemStorage{}
|
||||
b, err := Factory(context.Background(), config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
copyFromTo := map[string]string{
|
||||
"test-fixtures/cassandra.yaml": "/etc/cassandra/cassandra.yaml",
|
||||
}
|
||||
host, cleanup := cassandra.PrepareTestContainer(t,
|
||||
cassandra.CopyFromTo(copyFromTo))
|
||||
defer cleanup()
|
||||
|
||||
logicaltest.Test(t, logicaltest.TestCase{
|
||||
LogicalBackend: b,
|
||||
Steps: []logicaltest.TestStep{
|
||||
testAccStepConfig(t, host.ConnectionURL()),
|
||||
testAccStepRole(t),
|
||||
testAccStepRoleWithOptions(t),
|
||||
testAccStepReadRole(t, "test", testRole),
|
||||
testAccStepReadRole(t, "test2", testRole),
|
||||
testAccStepDeleteRole(t, "test"),
|
||||
testAccStepDeleteRole(t, "test2"),
|
||||
testAccStepReadRole(t, "test", ""),
|
||||
testAccStepReadRole(t, "test2", ""),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func testAccStepConfig(t *testing.T, hostname string) logicaltest.TestStep {
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "config/connection",
|
||||
Data: map[string]interface{}{
|
||||
"hosts": hostname,
|
||||
"username": "cassandra",
|
||||
"password": "cassandra",
|
||||
"protocol_version": 3,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func testAccStepRole(t *testing.T) logicaltest.TestStep {
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "roles/test",
|
||||
Data: map[string]interface{}{
|
||||
"creation_cql": testRole,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func testAccStepRoleWithOptions(t *testing.T) logicaltest.TestStep {
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "roles/test2",
|
||||
Data: map[string]interface{}{
|
||||
"creation_cql": testRole,
|
||||
"lease": "30s",
|
||||
"consistency": "All",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func testAccStepDeleteRole(t *testing.T, n string) logicaltest.TestStep {
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.DeleteOperation,
|
||||
Path: "roles/" + n,
|
||||
}
|
||||
}
|
||||
|
||||
func testAccStepReadCreds(t *testing.T, name string) logicaltest.TestStep {
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.ReadOperation,
|
||||
Path: "creds/" + name,
|
||||
Check: func(resp *logical.Response) error {
|
||||
var d struct {
|
||||
Username string `mapstructure:"username"`
|
||||
Password string `mapstructure:"password"`
|
||||
}
|
||||
if err := mapstructure.Decode(resp.Data, &d); err != nil {
|
||||
return err
|
||||
}
|
||||
log.Printf("[WARN] Generated credentials: %v", d)
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func testAccStepReadRole(t *testing.T, name string, cql string) logicaltest.TestStep {
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.ReadOperation,
|
||||
Path: "roles/" + name,
|
||||
Check: func(resp *logical.Response) error {
|
||||
if resp == nil {
|
||||
if cql == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("response is nil")
|
||||
}
|
||||
|
||||
var d struct {
|
||||
CreationCQL string `mapstructure:"creation_cql"`
|
||||
}
|
||||
if err := mapstructure.Decode(resp.Data, &d); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if d.CreationCQL != cql {
|
||||
return fmt.Errorf("bad: %#v\n%#v\n%#v\n", resp, cql, d.CreationCQL)
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
const testRole = `CREATE USER '{{username}}' WITH PASSWORD '{{password}}' NOSUPERUSER;
|
||||
GRANT ALL PERMISSIONS ON ALL KEYSPACES TO {{username}};`
|
|
@ -1,29 +0,0 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
hclog "github.com/hashicorp/go-hclog"
|
||||
"github.com/hashicorp/vault/api"
|
||||
"github.com/hashicorp/vault/builtin/logical/cassandra"
|
||||
"github.com/hashicorp/vault/sdk/plugin"
|
||||
)
|
||||
|
||||
func main() {
|
||||
apiClientMeta := &api.PluginAPIClientMeta{}
|
||||
flags := apiClientMeta.FlagSet()
|
||||
flags.Parse(os.Args[1:])
|
||||
|
||||
tlsConfig := apiClientMeta.GetTLSConfig()
|
||||
tlsProviderFunc := api.VaultPluginTLSProvider(tlsConfig)
|
||||
|
||||
if err := plugin.Serve(&plugin.ServeOpts{
|
||||
BackendFactoryFunc: cassandra.Factory,
|
||||
TLSProviderFunc: tlsProviderFunc,
|
||||
}); err != nil {
|
||||
logger := hclog.New(&hclog.LoggerOptions{})
|
||||
|
||||
logger.Error("plugin shutting down", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
|
@ -1,245 +0,0 @@
|
|||
package cassandra
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/hashicorp/go-secure-stdlib/tlsutil"
|
||||
"github.com/hashicorp/vault/sdk/framework"
|
||||
"github.com/hashicorp/vault/sdk/helper/certutil"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
)
|
||||
|
||||
func pathConfigConnection(b *backend) *framework.Path {
|
||||
return &framework.Path{
|
||||
Pattern: "config/connection",
|
||||
Fields: map[string]*framework.FieldSchema{
|
||||
"hosts": {
|
||||
Type: framework.TypeString,
|
||||
Description: "Comma-separated list of hosts",
|
||||
},
|
||||
|
||||
"username": {
|
||||
Type: framework.TypeString,
|
||||
Description: "The username to use for connecting to the cluster",
|
||||
},
|
||||
|
||||
"password": {
|
||||
Type: framework.TypeString,
|
||||
Description: "The password to use for connecting to the cluster",
|
||||
},
|
||||
|
||||
"tls": {
|
||||
Type: framework.TypeBool,
|
||||
Description: `Whether to use TLS. If pem_bundle or pem_json are
|
||||
set, this is automatically set to true`,
|
||||
},
|
||||
|
||||
"insecure_tls": {
|
||||
Type: framework.TypeBool,
|
||||
Description: `Whether to use TLS but skip verification; has no
|
||||
effect if a CA certificate is provided`,
|
||||
},
|
||||
|
||||
// TLS 1.3 is not supported as this engine is deprecated. Please switch to the Cassandra database secrets engine
|
||||
"tls_min_version": {
|
||||
Type: framework.TypeString,
|
||||
Default: "tls12",
|
||||
Description: "Minimum TLS version to use. Accepted values are 'tls10', 'tls11' or 'tls12'. Defaults to 'tls12'",
|
||||
},
|
||||
|
||||
"pem_bundle": {
|
||||
Type: framework.TypeString,
|
||||
Description: `PEM-format, concatenated unencrypted secret key
|
||||
and certificate, with optional CA certificate`,
|
||||
},
|
||||
|
||||
"pem_json": {
|
||||
Type: framework.TypeString,
|
||||
Description: `JSON containing a PEM-format, unencrypted secret
|
||||
key and certificate, with optional CA certificate.
|
||||
The JSON output of a certificate issued with the PKI
|
||||
backend can be directly passed into this parameter.
|
||||
If both this and "pem_bundle" are specified, this will
|
||||
take precedence.`,
|
||||
},
|
||||
|
||||
"protocol_version": {
|
||||
Type: framework.TypeInt,
|
||||
Description: `The protocol version to use. Defaults to 2.`,
|
||||
},
|
||||
|
||||
"connect_timeout": {
|
||||
Type: framework.TypeDurationSecond,
|
||||
Default: 5,
|
||||
Description: `The connection timeout to use. Defaults to 5.`,
|
||||
},
|
||||
},
|
||||
|
||||
Callbacks: map[logical.Operation]framework.OperationFunc{
|
||||
logical.ReadOperation: b.pathConnectionRead,
|
||||
logical.UpdateOperation: b.pathConnectionWrite,
|
||||
},
|
||||
|
||||
HelpSynopsis: pathConfigConnectionHelpSyn,
|
||||
HelpDescription: pathConfigConnectionHelpDesc,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *backend) pathConnectionRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
entry, err := req.Storage.Get(ctx, "config/connection")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if entry == nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf("Configure the DB connection with config/connection first")), nil
|
||||
}
|
||||
|
||||
config := &sessionConfig{}
|
||||
if err := entry.DecodeJSON(config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp := &logical.Response{
|
||||
Data: map[string]interface{}{
|
||||
"hosts": config.Hosts,
|
||||
"username": config.Username,
|
||||
"tls": config.TLS,
|
||||
"insecure_tls": config.InsecureTLS,
|
||||
"certificate": config.Certificate,
|
||||
"issuing_ca": config.IssuingCA,
|
||||
"protocol_version": config.ProtocolVersion,
|
||||
"connect_timeout": config.ConnectTimeout,
|
||||
"tls_min_version": config.TLSMinVersion,
|
||||
},
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (b *backend) pathConnectionWrite(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
hosts := data.Get("hosts").(string)
|
||||
username := data.Get("username").(string)
|
||||
password := data.Get("password").(string)
|
||||
|
||||
switch {
|
||||
case len(hosts) == 0:
|
||||
return logical.ErrorResponse("Hosts cannot be empty"), nil
|
||||
case len(username) == 0:
|
||||
return logical.ErrorResponse("Username cannot be empty"), nil
|
||||
case len(password) == 0:
|
||||
return logical.ErrorResponse("Password cannot be empty"), nil
|
||||
}
|
||||
|
||||
config := &sessionConfig{
|
||||
Hosts: hosts,
|
||||
Username: username,
|
||||
Password: password,
|
||||
TLS: data.Get("tls").(bool),
|
||||
InsecureTLS: data.Get("insecure_tls").(bool),
|
||||
ProtocolVersion: data.Get("protocol_version").(int),
|
||||
ConnectTimeout: data.Get("connect_timeout").(int),
|
||||
}
|
||||
|
||||
config.TLSMinVersion = data.Get("tls_min_version").(string)
|
||||
if config.TLSMinVersion == "" {
|
||||
return logical.ErrorResponse("failed to get 'tls_min_version' value"), nil
|
||||
}
|
||||
|
||||
var ok bool
|
||||
_, ok = tlsutil.TLSLookup[config.TLSMinVersion]
|
||||
if !ok {
|
||||
return logical.ErrorResponse("invalid 'tls_min_version'"), nil
|
||||
}
|
||||
|
||||
if config.InsecureTLS {
|
||||
config.TLS = true
|
||||
}
|
||||
|
||||
pemBundle := data.Get("pem_bundle").(string)
|
||||
pemJSON := data.Get("pem_json").(string)
|
||||
|
||||
var certBundle *certutil.CertBundle
|
||||
var parsedCertBundle *certutil.ParsedCertBundle
|
||||
var err error
|
||||
|
||||
switch {
|
||||
case len(pemJSON) != 0:
|
||||
parsedCertBundle, err = certutil.ParsePKIJSON([]byte(pemJSON))
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf("Could not parse given JSON; it must be in the format of the output of the PKI backend certificate issuing command: %s", err)), nil
|
||||
}
|
||||
certBundle, err = parsedCertBundle.ToCertBundle()
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf("Error marshaling PEM information: %s", err)), nil
|
||||
}
|
||||
config.Certificate = certBundle.Certificate
|
||||
config.PrivateKey = certBundle.PrivateKey
|
||||
config.IssuingCA = certBundle.IssuingCA
|
||||
config.TLS = true
|
||||
|
||||
case len(pemBundle) != 0:
|
||||
parsedCertBundle, err = certutil.ParsePEMBundle(pemBundle)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf("Error parsing the given PEM information: %s", err)), nil
|
||||
}
|
||||
certBundle, err = parsedCertBundle.ToCertBundle()
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf("Error marshaling PEM information: %s", err)), nil
|
||||
}
|
||||
config.Certificate = certBundle.Certificate
|
||||
config.PrivateKey = certBundle.PrivateKey
|
||||
config.IssuingCA = certBundle.IssuingCA
|
||||
config.TLS = true
|
||||
}
|
||||
|
||||
session, err := createSession(config, req.Storage)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(err.Error()), nil
|
||||
}
|
||||
|
||||
// Store it
|
||||
entry, err := logical.StorageEntryJSON("config/connection", config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := req.Storage.Put(ctx, entry); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Reset the DB connection
|
||||
b.ResetDB(session)
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
const pathConfigConnectionHelpSyn = `
|
||||
Configure the connection information to talk to Cassandra.
|
||||
`
|
||||
|
||||
const pathConfigConnectionHelpDesc = `
|
||||
This path configures the connection information used to connect to Cassandra.
|
||||
|
||||
"hosts" is a comma-delimited list of hostnames in the Cassandra cluster.
|
||||
|
||||
"username" and "password" are self-explanatory, although the given user
|
||||
must have superuser access within Cassandra. Note that since this backend
|
||||
issues username/password credentials, Cassandra must be configured to use
|
||||
PasswordAuthenticator or a similar backend for its authentication. If you wish
|
||||
to have no authorization in Cassandra and want to use TLS client certificates,
|
||||
see the PKI backend.
|
||||
|
||||
TLS works as follows:
|
||||
|
||||
* If "tls" is set to true, the connection will use TLS; this happens automatically if "pem_bundle", "pem_json", or "insecure_tls" is set
|
||||
|
||||
* If "insecure_tls" is set to true, the connection will not perform verification of the server certificate; this also sets "tls" to true
|
||||
|
||||
* If only "issuing_ca" is set in "pem_json", or the only certificate in "pem_bundle" is a CA certificate, the given CA certificate will be used for server certificate verification; otherwise the system CA certificates will be used
|
||||
|
||||
* If "certificate" and "private_key" are set in "pem_bundle" or "pem_json", client auth will be turned on for the connection
|
||||
|
||||
"pem_bundle" should be a PEM-concatenated bundle of a private key + client certificate, an issuing CA certificate, or both. "pem_json" should contain the same information; for convenience, the JSON format is the same as that output by the issue command from the PKI backend.
|
||||
|
||||
When configuring the connection information, the backend will verify its
|
||||
validity.
|
||||
`
|
|
@ -1,123 +0,0 @@
|
|||
package cassandra
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gocql/gocql"
|
||||
"github.com/hashicorp/go-secure-stdlib/strutil"
|
||||
uuid "github.com/hashicorp/go-uuid"
|
||||
"github.com/hashicorp/vault/sdk/framework"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
)
|
||||
|
||||
func pathCredsCreate(b *backend) *framework.Path {
|
||||
return &framework.Path{
|
||||
Pattern: "creds/" + framework.GenericNameRegex("name"),
|
||||
Fields: map[string]*framework.FieldSchema{
|
||||
"name": {
|
||||
Type: framework.TypeString,
|
||||
Description: "Name of the role",
|
||||
},
|
||||
},
|
||||
|
||||
Callbacks: map[logical.Operation]framework.OperationFunc{
|
||||
logical.ReadOperation: b.pathCredsCreateRead,
|
||||
},
|
||||
|
||||
HelpSynopsis: pathCredsCreateReadHelpSyn,
|
||||
HelpDescription: pathCredsCreateReadHelpDesc,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *backend) pathCredsCreateRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
name := data.Get("name").(string)
|
||||
|
||||
// Get the role
|
||||
role, err := getRole(ctx, req.Storage, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if role == nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf("Unknown role: %s", name)), nil
|
||||
}
|
||||
|
||||
displayName := req.DisplayName
|
||||
userUUID, err := uuid.GenerateUUID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
username := fmt.Sprintf("vault_%s_%s_%s_%d", name, displayName, userUUID, time.Now().Unix())
|
||||
username = strings.ReplaceAll(username, "-", "_")
|
||||
password, err := uuid.GenerateUUID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get our connection
|
||||
session, err := b.DB(ctx, req.Storage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Set consistency
|
||||
if role.Consistency != "" {
|
||||
consistencyValue, err := gocql.ParseConsistencyWrapper(role.Consistency)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
session.SetConsistency(consistencyValue)
|
||||
}
|
||||
|
||||
// Execute each query
|
||||
for _, query := range strutil.ParseArbitraryStringSlice(role.CreationCQL, ";") {
|
||||
query = strings.TrimSpace(query)
|
||||
if len(query) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
err = session.Query(substQuery(query, map[string]string{
|
||||
"username": username,
|
||||
"password": password,
|
||||
})).Exec()
|
||||
if err != nil {
|
||||
for _, query := range strutil.ParseArbitraryStringSlice(role.RollbackCQL, ";") {
|
||||
query = strings.TrimSpace(query)
|
||||
if len(query) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
session.Query(substQuery(query, map[string]string{
|
||||
"username": username,
|
||||
"password": password,
|
||||
})).Exec()
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Return the secret
|
||||
resp := b.Secret(SecretCredsType).Response(map[string]interface{}{
|
||||
"username": username,
|
||||
"password": password,
|
||||
}, map[string]interface{}{
|
||||
"username": username,
|
||||
"role": name,
|
||||
})
|
||||
resp.Secret.TTL = role.Lease
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
const pathCredsCreateReadHelpSyn = `
|
||||
Request database credentials for a certain role.
|
||||
`
|
||||
|
||||
const pathCredsCreateReadHelpDesc = `
|
||||
This path creates database credentials for a certain role. The
|
||||
database credentials will be generated on demand and will be automatically
|
||||
revoked when the lease is up.
|
||||
`
|
|
@ -1,196 +0,0 @@
|
|||
package cassandra
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/fatih/structs"
|
||||
"github.com/gocql/gocql"
|
||||
"github.com/hashicorp/vault/sdk/framework"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultCreationCQL = `CREATE USER '{{username}}' WITH PASSWORD '{{password}}' NOSUPERUSER;`
|
||||
defaultRollbackCQL = `DROP USER '{{username}}';`
|
||||
)
|
||||
|
||||
func pathRoles(b *backend) *framework.Path {
|
||||
return &framework.Path{
|
||||
Pattern: "roles/" + framework.GenericNameRegex("name"),
|
||||
Fields: map[string]*framework.FieldSchema{
|
||||
"name": {
|
||||
Type: framework.TypeString,
|
||||
Description: "Name of the role",
|
||||
},
|
||||
|
||||
"creation_cql": {
|
||||
Type: framework.TypeString,
|
||||
Default: defaultCreationCQL,
|
||||
Description: `CQL to create a user and optionally grant
|
||||
authorization. If not supplied, a default that
|
||||
creates non-superuser accounts with the built-in
|
||||
password authenticator will be used; no
|
||||
authorization grants will be configured. Separate
|
||||
statements by semicolons; use @file to load from a
|
||||
file. Valid template values are '{{username}}' and
|
||||
'{{password}}' -- the single quotes are important!`,
|
||||
},
|
||||
|
||||
"rollback_cql": {
|
||||
Type: framework.TypeString,
|
||||
Default: defaultRollbackCQL,
|
||||
Description: `CQL to roll back an account operation. This will
|
||||
be used if there is an error during execution of a
|
||||
statement passed in via the "creation_cql" parameter
|
||||
parameter. The default simply drops the user, which
|
||||
should generally be sufficient. Separate statements
|
||||
by semicolons; use @file to load from a file. Valid
|
||||
template values are '{{username}}' and
|
||||
'{{password}}' -- the single quotes are important!`,
|
||||
},
|
||||
|
||||
"lease": {
|
||||
Type: framework.TypeString,
|
||||
Default: "4h",
|
||||
Description: "The lease length; defaults to 4 hours",
|
||||
},
|
||||
|
||||
"consistency": {
|
||||
Type: framework.TypeString,
|
||||
Default: "Quorum",
|
||||
Description: "The consistency level for the operations; defaults to Quorum.",
|
||||
},
|
||||
},
|
||||
|
||||
Callbacks: map[logical.Operation]framework.OperationFunc{
|
||||
logical.ReadOperation: b.pathRoleRead,
|
||||
logical.UpdateOperation: b.pathRoleCreate,
|
||||
logical.DeleteOperation: b.pathRoleDelete,
|
||||
},
|
||||
|
||||
HelpSynopsis: pathRoleHelpSyn,
|
||||
HelpDescription: pathRoleHelpDesc,
|
||||
}
|
||||
}
|
||||
|
||||
func getRole(ctx context.Context, s logical.Storage, n string) (*roleEntry, error) {
|
||||
entry, err := s.Get(ctx, "role/"+n)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if entry == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var result roleEntry
|
||||
if err := entry.DecodeJSON(&result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
func (b *backend) pathRoleDelete(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
err := req.Storage.Delete(ctx, "role/"+data.Get("name").(string))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (b *backend) pathRoleRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
role, err := getRole(ctx, req.Storage, data.Get("name").(string))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if role == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return &logical.Response{
|
||||
Data: structs.New(role).Map(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (b *backend) pathRoleCreate(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
name := data.Get("name").(string)
|
||||
|
||||
creationCQL := data.Get("creation_cql").(string)
|
||||
|
||||
rollbackCQL := data.Get("rollback_cql").(string)
|
||||
|
||||
leaseRaw := data.Get("lease").(string)
|
||||
lease, err := time.ParseDuration(leaseRaw)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf(
|
||||
"Error parsing lease value of %s: %s", leaseRaw, err)), nil
|
||||
}
|
||||
|
||||
consistencyStr := data.Get("consistency").(string)
|
||||
_, err = gocql.ParseConsistencyWrapper(consistencyStr)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf(
|
||||
"Error parsing consistency value of %q: %v", consistencyStr, err)), nil
|
||||
}
|
||||
|
||||
entry := &roleEntry{
|
||||
Lease: lease,
|
||||
CreationCQL: creationCQL,
|
||||
RollbackCQL: rollbackCQL,
|
||||
Consistency: consistencyStr,
|
||||
}
|
||||
|
||||
// Store it
|
||||
entryJSON, err := logical.StorageEntryJSON("role/"+name, entry)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := req.Storage.Put(ctx, entryJSON); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
type roleEntry struct {
|
||||
CreationCQL string `json:"creation_cql" structs:"creation_cql"`
|
||||
Lease time.Duration `json:"lease" structs:"lease"`
|
||||
RollbackCQL string `json:"rollback_cql" structs:"rollback_cql"`
|
||||
Consistency string `json:"consistency" structs:"consistency"`
|
||||
}
|
||||
|
||||
const pathRoleHelpSyn = `
|
||||
Manage the roles that can be created with this backend.
|
||||
`
|
||||
|
||||
const pathRoleHelpDesc = `
|
||||
This path lets you manage the roles that can be created with this backend.
|
||||
|
||||
The "creation_cql" parameter customizes the CQL string used to create users
|
||||
and assign them grants. This can be a sequence of CQL queries separated by
|
||||
semicolons. Some substitution will be done to the CQL string for certain keys.
|
||||
The names of the variables must be surrounded by '{{' and '}}' to be replaced.
|
||||
Note that it is important that single quotes are used, not double quotes.
|
||||
|
||||
* "username" - The random username generated for the DB user.
|
||||
|
||||
* "password" - The random password generated for the DB user.
|
||||
|
||||
If no "creation_cql" parameter is given, a default will be used:
|
||||
|
||||
` + defaultCreationCQL + `
|
||||
|
||||
This default should be suitable for Cassandra installations using the password
|
||||
authenticator but not configured to use authorization.
|
||||
|
||||
Similarly, the "rollback_cql" is used if user creation fails, in the absence of
|
||||
Cassandra transactions. The default should be suitable for almost any
|
||||
instance of Cassandra:
|
||||
|
||||
` + defaultRollbackCQL + `
|
||||
|
||||
"lease" the lease time; if not set the mount/system defaults are used.
|
||||
`
|
|
@ -1,77 +0,0 @@
|
|||
package cassandra
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/hashicorp/vault/sdk/framework"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
)
|
||||
|
||||
// SecretCredsType is the type of creds issued from this backend
|
||||
const SecretCredsType = "cassandra"
|
||||
|
||||
func secretCreds(b *backend) *framework.Secret {
|
||||
return &framework.Secret{
|
||||
Type: SecretCredsType,
|
||||
Fields: map[string]*framework.FieldSchema{
|
||||
"username": {
|
||||
Type: framework.TypeString,
|
||||
Description: "Username",
|
||||
},
|
||||
|
||||
"password": {
|
||||
Type: framework.TypeString,
|
||||
Description: "Password",
|
||||
},
|
||||
},
|
||||
|
||||
Renew: b.secretCredsRenew,
|
||||
Revoke: b.secretCredsRevoke,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *backend) secretCredsRenew(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
// Get the lease information
|
||||
roleRaw, ok := req.Secret.InternalData["role"]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("secret is missing role internal data")
|
||||
}
|
||||
roleName, ok := roleRaw.(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("error converting role internal data to string")
|
||||
}
|
||||
|
||||
role, err := getRole(ctx, req.Storage, roleName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to load role: %w", err)
|
||||
}
|
||||
|
||||
resp := &logical.Response{Secret: req.Secret}
|
||||
resp.Secret.TTL = role.Lease
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (b *backend) secretCredsRevoke(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
// Get the username from the internal data
|
||||
usernameRaw, ok := req.Secret.InternalData["username"]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("secret is missing username internal data")
|
||||
}
|
||||
username, ok := usernameRaw.(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("error converting username internal data to string")
|
||||
}
|
||||
|
||||
session, err := b.DB(ctx, req.Storage)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting session")
|
||||
}
|
||||
|
||||
err = session.Query(fmt.Sprintf("DROP USER '%s'", username)).Exec()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error removing user %q", username)
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
File diff suppressed because it is too large
Load Diff
|
@ -1,95 +0,0 @@
|
|||
package cassandra
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gocql/gocql"
|
||||
"github.com/hashicorp/go-secure-stdlib/tlsutil"
|
||||
"github.com/hashicorp/vault/sdk/helper/certutil"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
)
|
||||
|
||||
// Query templates a query for us.
|
||||
func substQuery(tpl string, data map[string]string) string {
|
||||
for k, v := range data {
|
||||
tpl = strings.ReplaceAll(tpl, fmt.Sprintf("{{%s}}", k), v)
|
||||
}
|
||||
|
||||
return tpl
|
||||
}
|
||||
|
||||
func createSession(cfg *sessionConfig, s logical.Storage) (*gocql.Session, error) {
|
||||
clusterConfig := gocql.NewCluster(strings.Split(cfg.Hosts, ",")...)
|
||||
clusterConfig.Authenticator = gocql.PasswordAuthenticator{
|
||||
Username: cfg.Username,
|
||||
Password: cfg.Password,
|
||||
}
|
||||
|
||||
clusterConfig.ProtoVersion = cfg.ProtocolVersion
|
||||
if clusterConfig.ProtoVersion == 0 {
|
||||
clusterConfig.ProtoVersion = 2
|
||||
}
|
||||
|
||||
clusterConfig.Timeout = time.Duration(cfg.ConnectTimeout) * time.Second
|
||||
|
||||
if cfg.TLS {
|
||||
var tlsConfig *tls.Config
|
||||
if len(cfg.Certificate) > 0 || len(cfg.IssuingCA) > 0 {
|
||||
if len(cfg.Certificate) > 0 && len(cfg.PrivateKey) == 0 {
|
||||
return nil, fmt.Errorf("found certificate for TLS authentication but no private key")
|
||||
}
|
||||
|
||||
certBundle := &certutil.CertBundle{}
|
||||
if len(cfg.Certificate) > 0 {
|
||||
certBundle.Certificate = cfg.Certificate
|
||||
certBundle.PrivateKey = cfg.PrivateKey
|
||||
}
|
||||
if len(cfg.IssuingCA) > 0 {
|
||||
certBundle.IssuingCA = cfg.IssuingCA
|
||||
}
|
||||
|
||||
parsedCertBundle, err := certBundle.ToParsedCertBundle()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse certificate bundle: %w", err)
|
||||
}
|
||||
|
||||
tlsConfig, err = parsedCertBundle.GetTLSConfig(certutil.TLSClient)
|
||||
if err != nil || tlsConfig == nil {
|
||||
return nil, fmt.Errorf("failed to get TLS configuration: tlsConfig: %#v; %w", tlsConfig, err)
|
||||
}
|
||||
tlsConfig.InsecureSkipVerify = cfg.InsecureTLS
|
||||
|
||||
if cfg.TLSMinVersion != "" {
|
||||
var ok bool
|
||||
tlsConfig.MinVersion, ok = tlsutil.TLSLookup[cfg.TLSMinVersion]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid 'tls_min_version' in config")
|
||||
}
|
||||
} else {
|
||||
// MinVersion was not being set earlier. Reset it to
|
||||
// zero to gracefully handle upgrades.
|
||||
tlsConfig.MinVersion = 0
|
||||
}
|
||||
}
|
||||
|
||||
clusterConfig.SslOpts = &gocql.SslOptions{
|
||||
Config: tlsConfig,
|
||||
}
|
||||
}
|
||||
|
||||
session, err := clusterConfig.CreateSession()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating session: %w", err)
|
||||
}
|
||||
|
||||
// Verify the info
|
||||
err = session.Query(`LIST USERS`).Exec()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error validating connection info: %w", err)
|
||||
}
|
||||
|
||||
return session, nil
|
||||
}
|
|
@ -1,144 +0,0 @@
|
|||
package mongodb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/sdk/framework"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
mgo "gopkg.in/mgo.v2"
|
||||
)
|
||||
|
||||
func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
|
||||
b := Backend()
|
||||
if err := b.Setup(ctx, conf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return b, nil
|
||||
}
|
||||
|
||||
func Backend() *framework.Backend {
|
||||
var b backend
|
||||
b.Backend = &framework.Backend{
|
||||
Help: strings.TrimSpace(backendHelp),
|
||||
|
||||
PathsSpecial: &logical.Paths{
|
||||
SealWrapStorage: []string{
|
||||
"config/connection",
|
||||
},
|
||||
},
|
||||
|
||||
Paths: []*framework.Path{
|
||||
pathConfigConnection(&b),
|
||||
pathConfigLease(&b),
|
||||
pathListRoles(&b),
|
||||
pathRoles(&b),
|
||||
pathCredsCreate(&b),
|
||||
},
|
||||
|
||||
Secrets: []*framework.Secret{
|
||||
secretCreds(&b),
|
||||
},
|
||||
|
||||
Clean: b.ResetSession,
|
||||
|
||||
Invalidate: b.invalidate,
|
||||
BackendType: logical.TypeLogical,
|
||||
}
|
||||
|
||||
return b.Backend
|
||||
}
|
||||
|
||||
type backend struct {
|
||||
*framework.Backend
|
||||
|
||||
session *mgo.Session
|
||||
lock sync.Mutex
|
||||
}
|
||||
|
||||
// Session returns the database connection.
|
||||
func (b *backend) Session(ctx context.Context, s logical.Storage) (*mgo.Session, error) {
|
||||
b.lock.Lock()
|
||||
defer b.lock.Unlock()
|
||||
|
||||
if b.session != nil {
|
||||
if err := b.session.Ping(); err == nil {
|
||||
return b.session, nil
|
||||
}
|
||||
b.session.Close()
|
||||
}
|
||||
|
||||
connConfigJSON, err := s.Get(ctx, "config/connection")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if connConfigJSON == nil {
|
||||
return nil, fmt.Errorf("configure the MongoDB connection with config/connection first")
|
||||
}
|
||||
|
||||
var connConfig connectionConfig
|
||||
if err := connConfigJSON.DecodeJSON(&connConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dialInfo, err := parseMongoURI(connConfig.URI)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
b.session, err = mgo.DialWithInfo(dialInfo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
b.session.SetSyncTimeout(1 * time.Minute)
|
||||
b.session.SetSocketTimeout(1 * time.Minute)
|
||||
|
||||
return b.session, nil
|
||||
}
|
||||
|
||||
// ResetSession forces creation of a new connection next time Session() is called.
|
||||
func (b *backend) ResetSession(_ context.Context) {
|
||||
b.lock.Lock()
|
||||
defer b.lock.Unlock()
|
||||
|
||||
if b.session != nil {
|
||||
b.session.Close()
|
||||
}
|
||||
|
||||
b.session = nil
|
||||
}
|
||||
|
||||
func (b *backend) invalidate(ctx context.Context, key string) {
|
||||
switch key {
|
||||
case "config/connection":
|
||||
b.ResetSession(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
// LeaseConfig returns the lease configuration
|
||||
func (b *backend) LeaseConfig(ctx context.Context, s logical.Storage) (*configLease, error) {
|
||||
entry, err := s.Get(ctx, "config/lease")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if entry == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var result configLease
|
||||
if err := entry.DecodeJSON(&result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
const backendHelp = `
|
||||
The mongodb backend dynamically generates MongoDB credentials.
|
||||
|
||||
After mounting this backend, configure it using the endpoints within
|
||||
the "config/" path.
|
||||
`
|
|
@ -1,268 +0,0 @@
|
|||
package mongodb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
logicaltest "github.com/hashicorp/vault/helper/testhelpers/logical"
|
||||
"github.com/hashicorp/vault/helper/testhelpers/mongodb"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
)
|
||||
|
||||
var testImagePull sync.Once
|
||||
|
||||
func TestBackend_config_connection(t *testing.T) {
|
||||
var resp *logical.Response
|
||||
var err error
|
||||
config := logical.TestBackendConfig()
|
||||
config.StorageView = &logical.InmemStorage{}
|
||||
b, err := Factory(context.Background(), config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
configData := map[string]interface{}{
|
||||
"uri": "sample_connection_uri",
|
||||
"verify_connection": false,
|
||||
}
|
||||
|
||||
configReq := &logical.Request{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "config/connection",
|
||||
Storage: config.StorageView,
|
||||
Data: configData,
|
||||
}
|
||||
resp, err = b.HandleRequest(context.Background(), configReq)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("err:%s resp:%#v\n", err, resp)
|
||||
}
|
||||
|
||||
configReq.Operation = logical.ReadOperation
|
||||
resp, err = b.HandleRequest(context.Background(), configReq)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("err:%s resp:%#v\n", err, resp)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackend_basic(t *testing.T) {
|
||||
config := logical.TestBackendConfig()
|
||||
config.StorageView = &logical.InmemStorage{}
|
||||
b, err := Factory(context.Background(), config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cleanup, connURI := mongodb.PrepareTestContainer(t, "5.0.10")
|
||||
defer cleanup()
|
||||
connData := map[string]interface{}{
|
||||
"uri": connURI,
|
||||
}
|
||||
|
||||
logicaltest.Test(t, logicaltest.TestCase{
|
||||
LogicalBackend: b,
|
||||
Steps: []logicaltest.TestStep{
|
||||
testAccStepConfig(connData, false),
|
||||
testAccStepRole(),
|
||||
testAccStepReadCreds("web"),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func TestBackend_roleCrud(t *testing.T) {
|
||||
config := logical.TestBackendConfig()
|
||||
config.StorageView = &logical.InmemStorage{}
|
||||
b, err := Factory(context.Background(), config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cleanup, connURI := mongodb.PrepareTestContainer(t, "5.0.10")
|
||||
defer cleanup()
|
||||
connData := map[string]interface{}{
|
||||
"uri": connURI,
|
||||
}
|
||||
|
||||
logicaltest.Test(t, logicaltest.TestCase{
|
||||
LogicalBackend: b,
|
||||
Steps: []logicaltest.TestStep{
|
||||
testAccStepConfig(connData, false),
|
||||
testAccStepRole(),
|
||||
testAccStepReadRole("web", testDb, testMongoDBRoles),
|
||||
testAccStepDeleteRole("web"),
|
||||
testAccStepReadRole("web", "", ""),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func TestBackend_leaseWriteRead(t *testing.T) {
|
||||
config := logical.TestBackendConfig()
|
||||
config.StorageView = &logical.InmemStorage{}
|
||||
b, err := Factory(context.Background(), config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cleanup, connURI := mongodb.PrepareTestContainer(t, "5.0.10")
|
||||
defer cleanup()
|
||||
connData := map[string]interface{}{
|
||||
"uri": connURI,
|
||||
}
|
||||
|
||||
logicaltest.Test(t, logicaltest.TestCase{
|
||||
LogicalBackend: b,
|
||||
Steps: []logicaltest.TestStep{
|
||||
testAccStepConfig(connData, false),
|
||||
testAccStepWriteLease(),
|
||||
testAccStepReadLease(),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func testAccStepConfig(d map[string]interface{}, expectError bool) logicaltest.TestStep {
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "config/connection",
|
||||
Data: d,
|
||||
ErrorOk: true,
|
||||
Check: func(resp *logical.Response) error {
|
||||
if expectError {
|
||||
if resp.Data == nil {
|
||||
return fmt.Errorf("data is nil")
|
||||
}
|
||||
var e struct {
|
||||
Error string `mapstructure:"error"`
|
||||
}
|
||||
if err := mapstructure.Decode(resp.Data, &e); err != nil {
|
||||
return err
|
||||
}
|
||||
if len(e.Error) == 0 {
|
||||
return fmt.Errorf("expected error, but write succeeded")
|
||||
}
|
||||
return nil
|
||||
} else if resp != nil && resp.IsError() {
|
||||
return fmt.Errorf("got an error response: %v", resp.Error())
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func testAccStepRole() logicaltest.TestStep {
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "roles/web",
|
||||
Data: map[string]interface{}{
|
||||
"db": testDb,
|
||||
"roles": testMongoDBRoles,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func testAccStepDeleteRole(n string) logicaltest.TestStep {
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.DeleteOperation,
|
||||
Path: "roles/" + n,
|
||||
}
|
||||
}
|
||||
|
||||
func testAccStepReadCreds(name string) logicaltest.TestStep {
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.ReadOperation,
|
||||
Path: "creds/" + name,
|
||||
Check: func(resp *logical.Response) error {
|
||||
var d struct {
|
||||
DB string `mapstructure:"db"`
|
||||
Username string `mapstructure:"username"`
|
||||
Password string `mapstructure:"password"`
|
||||
}
|
||||
if err := mapstructure.Decode(resp.Data, &d); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if d.DB == "" {
|
||||
return fmt.Errorf("bad: %#v", resp)
|
||||
}
|
||||
if d.Username == "" {
|
||||
return fmt.Errorf("bad: %#v", resp)
|
||||
}
|
||||
if !strings.HasPrefix(d.Username, "vault-root-") {
|
||||
return fmt.Errorf("bad: %#v", resp)
|
||||
}
|
||||
if d.Password == "" {
|
||||
return fmt.Errorf("bad: %#v", resp)
|
||||
}
|
||||
|
||||
log.Printf("[WARN] Generated credentials: %v", d)
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func testAccStepReadRole(name, db, mongoDBRoles string) logicaltest.TestStep {
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.ReadOperation,
|
||||
Path: "roles/" + name,
|
||||
Check: func(resp *logical.Response) error {
|
||||
if resp == nil {
|
||||
if db == "" && mongoDBRoles == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("bad: %#v", resp)
|
||||
}
|
||||
|
||||
var d struct {
|
||||
DB string `mapstructure:"db"`
|
||||
MongoDBRoles string `mapstructure:"roles"`
|
||||
}
|
||||
if err := mapstructure.Decode(resp.Data, &d); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if d.DB != db {
|
||||
return fmt.Errorf("bad: %#v", resp)
|
||||
}
|
||||
if d.MongoDBRoles != mongoDBRoles {
|
||||
return fmt.Errorf("bad: %#v", resp)
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func testAccStepWriteLease() logicaltest.TestStep {
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "config/lease",
|
||||
Data: map[string]interface{}{
|
||||
"ttl": "1h5m",
|
||||
"max_ttl": "24h",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func testAccStepReadLease() logicaltest.TestStep {
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.ReadOperation,
|
||||
Path: "config/lease",
|
||||
Check: func(resp *logical.Response) error {
|
||||
if resp.Data["ttl"].(float64) != 3900 || resp.Data["max_ttl"].(float64) != 86400 {
|
||||
return fmt.Errorf("bad: %#v", resp)
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
testDb = "foo"
|
||||
testMongoDBRoles = `["readWrite",{"role":"read","db":"bar"}]`
|
||||
)
|
|
@ -1,29 +0,0 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
hclog "github.com/hashicorp/go-hclog"
|
||||
"github.com/hashicorp/vault/api"
|
||||
"github.com/hashicorp/vault/builtin/logical/mongodb"
|
||||
"github.com/hashicorp/vault/sdk/plugin"
|
||||
)
|
||||
|
||||
func main() {
|
||||
apiClientMeta := &api.PluginAPIClientMeta{}
|
||||
flags := apiClientMeta.FlagSet()
|
||||
flags.Parse(os.Args[1:])
|
||||
|
||||
tlsConfig := apiClientMeta.GetTLSConfig()
|
||||
tlsProviderFunc := api.VaultPluginTLSProvider(tlsConfig)
|
||||
|
||||
if err := plugin.Serve(&plugin.ServeOpts{
|
||||
BackendFactoryFunc: mongodb.Factory,
|
||||
TLSProviderFunc: tlsProviderFunc,
|
||||
}); err != nil {
|
||||
logger := hclog.New(&hclog.LoggerOptions{})
|
||||
|
||||
logger.Error("plugin shutting down", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
|
@ -1,112 +0,0 @@
|
|||
package mongodb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/hashicorp/vault/sdk/framework"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
mgo "gopkg.in/mgo.v2"
|
||||
)
|
||||
|
||||
func pathConfigConnection(b *backend) *framework.Path {
|
||||
return &framework.Path{
|
||||
Pattern: "config/connection",
|
||||
Fields: map[string]*framework.FieldSchema{
|
||||
"uri": {
|
||||
Type: framework.TypeString,
|
||||
Description: "MongoDB standard connection string (URI)",
|
||||
},
|
||||
"verify_connection": {
|
||||
Type: framework.TypeBool,
|
||||
Default: true,
|
||||
Description: `If set, uri is verified by actually connecting to the database`,
|
||||
},
|
||||
},
|
||||
Callbacks: map[logical.Operation]framework.OperationFunc{
|
||||
logical.ReadOperation: b.pathConnectionRead,
|
||||
logical.UpdateOperation: b.pathConnectionWrite,
|
||||
},
|
||||
HelpSynopsis: pathConfigConnectionHelpSyn,
|
||||
HelpDescription: pathConfigConnectionHelpDesc,
|
||||
}
|
||||
}
|
||||
|
||||
// pathConnectionRead reads out the connection configuration
|
||||
func (b *backend) pathConnectionRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
entry, err := req.Storage.Get(ctx, "config/connection")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read connection configuration")
|
||||
}
|
||||
if entry == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (b *backend) pathConnectionWrite(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
uri := data.Get("uri").(string)
|
||||
if uri == "" {
|
||||
return logical.ErrorResponse("uri parameter is required"), nil
|
||||
}
|
||||
|
||||
dialInfo, err := parseMongoURI(uri)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf("invalid uri: %s", err)), nil
|
||||
}
|
||||
|
||||
// Don't check the config if verification is disabled
|
||||
verifyConnection := data.Get("verify_connection").(bool)
|
||||
if verifyConnection {
|
||||
// Verify the config
|
||||
session, err := mgo.DialWithInfo(dialInfo)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf(
|
||||
"Error validating connection info: %s", err)), nil
|
||||
}
|
||||
defer session.Close()
|
||||
if err := session.Ping(); err != nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf(
|
||||
"Error validating connection info: %s", err)), nil
|
||||
}
|
||||
}
|
||||
|
||||
// Store it
|
||||
entry, err := logical.StorageEntryJSON("config/connection", connectionConfig{
|
||||
URI: uri,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := req.Storage.Put(ctx, entry); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Reset the Session
|
||||
b.ResetSession(ctx)
|
||||
|
||||
resp := &logical.Response{}
|
||||
resp.AddWarning("Read access to this endpoint should be controlled via ACLs as it will return the connection URI as it is, including passwords, if any.")
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
type connectionConfig struct {
|
||||
URI string `json:"uri" structs:"uri" mapstructure:"uri"`
|
||||
}
|
||||
|
||||
const pathConfigConnectionHelpSyn = `
|
||||
Configure the connection string to talk to MongoDB.
|
||||
`
|
||||
|
||||
const pathConfigConnectionHelpDesc = `
|
||||
This path configures the standard connection string (URI) used to connect to MongoDB.
|
||||
|
||||
A MongoDB URI looks like:
|
||||
"mongodb://[username:password@]host1[:port1][,host2[:port2],...[,hostN[:portN]]][/[database][?options]]"
|
||||
|
||||
See https://docs.mongodb.org/manual/reference/connection-string/ for detailed documentation of the URI format.
|
||||
|
||||
When configuring the connection string, the backend will verify its validity.
|
||||
`
|
|
@ -1,89 +0,0 @@
|
|||
package mongodb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/sdk/framework"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
)
|
||||
|
||||
func pathConfigLease(b *backend) *framework.Path {
|
||||
return &framework.Path{
|
||||
Pattern: "config/lease",
|
||||
Fields: map[string]*framework.FieldSchema{
|
||||
"ttl": {
|
||||
Type: framework.TypeDurationSecond,
|
||||
Description: "Default ttl for credentials.",
|
||||
},
|
||||
|
||||
"max_ttl": {
|
||||
Type: framework.TypeDurationSecond,
|
||||
Description: "Maximum time a set of credentials can be valid for.",
|
||||
},
|
||||
},
|
||||
|
||||
Callbacks: map[logical.Operation]framework.OperationFunc{
|
||||
logical.ReadOperation: b.pathConfigLeaseRead,
|
||||
logical.UpdateOperation: b.pathConfigLeaseWrite,
|
||||
},
|
||||
|
||||
HelpSynopsis: pathConfigLeaseHelpSyn,
|
||||
HelpDescription: pathConfigLeaseHelpDesc,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *backend) pathConfigLeaseWrite(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
entry, err := logical.StorageEntryJSON("config/lease", &configLease{
|
||||
TTL: time.Second * time.Duration(d.Get("ttl").(int)),
|
||||
MaxTTL: time.Second * time.Duration(d.Get("max_ttl").(int)),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := req.Storage.Put(ctx, entry); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (b *backend) pathConfigLeaseRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
leaseConfig, err := b.LeaseConfig(ctx, req.Storage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if leaseConfig == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return &logical.Response{
|
||||
Data: map[string]interface{}{
|
||||
"ttl": leaseConfig.TTL.Seconds(),
|
||||
"max_ttl": leaseConfig.MaxTTL.Seconds(),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
type configLease struct {
|
||||
TTL time.Duration
|
||||
MaxTTL time.Duration
|
||||
}
|
||||
|
||||
const pathConfigLeaseHelpSyn = `
|
||||
Configure the default lease TTL settings for credentials
|
||||
generated by the mongodb backend.
|
||||
`
|
||||
|
||||
const pathConfigLeaseHelpDesc = `
|
||||
This configures the default lease TTL settings used for
|
||||
credentials generated by this backend. The ttl specifies the
|
||||
duration that a set of credentials will be valid for before
|
||||
the lease must be renewed (if it is renewable), while the
|
||||
max_ttl specifies the overall maximum duration that the
|
||||
credentials will be valid regardless of lease renewals.
|
||||
|
||||
The format for the TTL values is an integer and then unit. For
|
||||
example, the value "1h" specifies a 1-hour TTL. The longest
|
||||
supported unit is hours.
|
||||
`
|
|
@ -1,119 +0,0 @@
|
|||
package mongodb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
uuid "github.com/hashicorp/go-uuid"
|
||||
"github.com/hashicorp/vault/sdk/framework"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
)
|
||||
|
||||
func pathCredsCreate(b *backend) *framework.Path {
|
||||
return &framework.Path{
|
||||
Pattern: "creds/" + framework.GenericNameRegex("name"),
|
||||
Fields: map[string]*framework.FieldSchema{
|
||||
"name": {
|
||||
Type: framework.TypeString,
|
||||
Description: "Name of the role to generate credentials for.",
|
||||
},
|
||||
},
|
||||
|
||||
Callbacks: map[logical.Operation]framework.OperationFunc{
|
||||
logical.ReadOperation: b.pathCredsCreateRead,
|
||||
},
|
||||
|
||||
HelpSynopsis: pathCredsCreateReadHelpSyn,
|
||||
HelpDescription: pathCredsCreateReadHelpDesc,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *backend) pathCredsCreateRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
name := data.Get("name").(string)
|
||||
|
||||
// Get the role
|
||||
role, err := b.Role(ctx, req.Storage, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if role == nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf("unknown role: %s", name)), nil
|
||||
}
|
||||
|
||||
// Determine if we have a lease configuration
|
||||
leaseConfig, err := b.LeaseConfig(ctx, req.Storage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if leaseConfig == nil {
|
||||
leaseConfig = &configLease{}
|
||||
}
|
||||
|
||||
// Generate the username and password
|
||||
displayName := req.DisplayName
|
||||
if displayName != "" {
|
||||
displayName += "-"
|
||||
}
|
||||
|
||||
userUUID, err := uuid.GenerateUUID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
username := fmt.Sprintf("vault-%s%s", displayName, userUUID)
|
||||
|
||||
password, err := uuid.GenerateUUID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Build the user creation command
|
||||
createUserCmd := createUserCommand{
|
||||
Username: username,
|
||||
Password: password,
|
||||
Roles: role.MongoDBRoles.toStandardRolesArray(),
|
||||
}
|
||||
|
||||
// Get our connection
|
||||
session, err := b.Session(ctx, req.Storage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create the user
|
||||
err = session.DB(role.DB).Run(createUserCmd, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Return the secret
|
||||
resp := b.Secret(SecretCredsType).Response(map[string]interface{}{
|
||||
"db": role.DB,
|
||||
"username": username,
|
||||
"password": password,
|
||||
}, map[string]interface{}{
|
||||
"username": username,
|
||||
"db": role.DB,
|
||||
})
|
||||
resp.Secret.TTL = leaseConfig.TTL
|
||||
resp.Secret.MaxTTL = leaseConfig.MaxTTL
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
type createUserCommand struct {
|
||||
Username string `bson:"createUser"`
|
||||
Password string `bson:"pwd"`
|
||||
Roles []interface{} `bson:"roles"`
|
||||
}
|
||||
|
||||
const pathCredsCreateReadHelpSyn = `
|
||||
Request MongoDB database credentials for a particular role.
|
||||
`
|
||||
|
||||
const pathCredsCreateReadHelpDesc = `
|
||||
This path reads generates MongoDB database credentials for
|
||||
a particular role. The database credentials will be
|
||||
generated on demand and will be automatically revoked when
|
||||
the lease is up.
|
||||
`
|
|
@ -1,224 +0,0 @@
|
|||
package mongodb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/hashicorp/vault/sdk/framework"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
)
|
||||
|
||||
func pathListRoles(b *backend) *framework.Path {
|
||||
return &framework.Path{
|
||||
Pattern: "roles/?$",
|
||||
|
||||
Callbacks: map[logical.Operation]framework.OperationFunc{
|
||||
logical.ListOperation: b.pathRoleList,
|
||||
},
|
||||
|
||||
HelpSynopsis: pathRoleHelpSyn,
|
||||
HelpDescription: pathRoleHelpDesc,
|
||||
}
|
||||
}
|
||||
|
||||
func pathRoles(b *backend) *framework.Path {
|
||||
return &framework.Path{
|
||||
Pattern: "roles/" + framework.GenericNameRegex("name"),
|
||||
Fields: map[string]*framework.FieldSchema{
|
||||
"name": {
|
||||
Type: framework.TypeString,
|
||||
Description: "Name of the role.",
|
||||
},
|
||||
"db": {
|
||||
Type: framework.TypeString,
|
||||
Description: "Name of the authentication database for users generated for this role.",
|
||||
},
|
||||
"roles": {
|
||||
Type: framework.TypeString,
|
||||
Description: "MongoDB roles to assign to the users generated for this role.",
|
||||
},
|
||||
},
|
||||
|
||||
Callbacks: map[logical.Operation]framework.OperationFunc{
|
||||
logical.ReadOperation: b.pathRoleRead,
|
||||
logical.UpdateOperation: b.pathRoleCreate,
|
||||
logical.DeleteOperation: b.pathRoleDelete,
|
||||
},
|
||||
|
||||
HelpSynopsis: pathRoleHelpSyn,
|
||||
HelpDescription: pathRoleHelpDesc,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *backend) Role(ctx context.Context, s logical.Storage, n string) (*roleStorageEntry, error) {
|
||||
entry, err := s.Get(ctx, "role/"+n)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if entry == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var result roleStorageEntry
|
||||
if err := entry.DecodeJSON(&result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
func (b *backend) pathRoleDelete(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
err := req.Storage.Delete(ctx, "role/"+data.Get("name").(string))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (b *backend) pathRoleRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
role, err := b.Role(ctx, req.Storage, data.Get("name").(string))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if role == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
rolesJsonBytes, err := json.Marshal(role.MongoDBRoles.toStandardRolesArray())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &logical.Response{
|
||||
Data: map[string]interface{}{
|
||||
"db": role.DB,
|
||||
"roles": string(rolesJsonBytes),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (b *backend) pathRoleList(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
entries, err := req.Storage.List(ctx, "role/")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return logical.ListResponse(entries), nil
|
||||
}
|
||||
|
||||
func (b *backend) pathRoleCreate(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
name := data.Get("name").(string)
|
||||
if name == "" {
|
||||
return logical.ErrorResponse("Missing name"), nil
|
||||
}
|
||||
|
||||
roleDB := data.Get("db").(string)
|
||||
if roleDB == "" {
|
||||
return logical.ErrorResponse("db parameter is required"), nil
|
||||
}
|
||||
|
||||
// Example roles JSON:
|
||||
//
|
||||
// [ "readWrite", { "role": "readWrite", "db": "test" } ]
|
||||
//
|
||||
// For storage, we convert such an array into a homogeneous array of role documents like:
|
||||
//
|
||||
// [ { "role": "readWrite" }, { "role": "readWrite", "db": "test" } ]
|
||||
//
|
||||
var roles []mongodbRole
|
||||
rolesJson := []byte(data.Get("roles").(string))
|
||||
if len(rolesJson) > 0 {
|
||||
var rolesArray []interface{}
|
||||
err := json.Unmarshal(rolesJson, &rolesArray)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, rawRole := range rolesArray {
|
||||
switch role := rawRole.(type) {
|
||||
case string:
|
||||
roles = append(roles, mongodbRole{Role: role})
|
||||
case map[string]interface{}:
|
||||
if db, ok := role["db"].(string); ok {
|
||||
if roleName, ok := role["role"].(string); ok {
|
||||
roles = append(roles, mongodbRole{Role: roleName, DB: db})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Store it
|
||||
entry, err := logical.StorageEntryJSON("role/"+name, &roleStorageEntry{
|
||||
DB: roleDB,
|
||||
MongoDBRoles: roles,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := req.Storage.Put(ctx, entry); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (roles mongodbRoles) toStandardRolesArray() []interface{} {
|
||||
// Convert array of role documents like:
|
||||
//
|
||||
// [ { "role": "readWrite" }, { "role": "readWrite", "db": "test" } ]
|
||||
//
|
||||
// into a "standard" MongoDB roles array containing both strings and role documents:
|
||||
//
|
||||
// [ "readWrite", { "role": "readWrite", "db": "test" } ]
|
||||
//
|
||||
// MongoDB's createUser command accepts the latter.
|
||||
//
|
||||
var standardRolesArray []interface{}
|
||||
for _, role := range roles {
|
||||
if role.DB == "" {
|
||||
standardRolesArray = append(standardRolesArray, role.Role)
|
||||
} else {
|
||||
standardRolesArray = append(standardRolesArray, role)
|
||||
}
|
||||
}
|
||||
return standardRolesArray
|
||||
}
|
||||
|
||||
type roleStorageEntry struct {
|
||||
DB string `json:"db"`
|
||||
MongoDBRoles mongodbRoles `json:"roles"`
|
||||
}
|
||||
|
||||
type mongodbRole struct {
|
||||
Role string `json:"role" bson:"role"`
|
||||
DB string `json:"db" bson:"db"`
|
||||
}
|
||||
|
||||
type mongodbRoles []mongodbRole
|
||||
|
||||
const pathRoleHelpSyn = `
|
||||
Manage the roles used to generate MongoDB credentials.
|
||||
`
|
||||
|
||||
const pathRoleHelpDesc = `
|
||||
This path lets you manage the roles used to generate MongoDB credentials.
|
||||
|
||||
The "db" parameter specifies the authentication database for users
|
||||
generated for a given role.
|
||||
|
||||
The "roles" parameter specifies the MongoDB roles that should be assigned
|
||||
to users created for a given role. Just like when creating a user directly
|
||||
using db.createUser, the roles JSON array can specify both built-in roles
|
||||
and user-defined roles for both the database the user is created in and
|
||||
for other databases.
|
||||
|
||||
For example, the following roles JSON array grants the "readWrite"
|
||||
permission on both the user's authentication database and the "test"
|
||||
database:
|
||||
|
||||
[ "readWrite", { "role": "readWrite", "db": "test" } ]
|
||||
|
||||
Please consult the MongoDB documentation for more
|
||||
details on Role-Based Access Control in MongoDB.
|
||||
`
|
|
@ -1,84 +0,0 @@
|
|||
package mongodb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/hashicorp/vault/sdk/framework"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
mgo "gopkg.in/mgo.v2"
|
||||
)
|
||||
|
||||
const SecretCredsType = "creds"
|
||||
|
||||
func secretCreds(b *backend) *framework.Secret {
|
||||
return &framework.Secret{
|
||||
Type: SecretCredsType,
|
||||
Fields: map[string]*framework.FieldSchema{
|
||||
"username": {
|
||||
Type: framework.TypeString,
|
||||
Description: "Username",
|
||||
},
|
||||
|
||||
"password": {
|
||||
Type: framework.TypeString,
|
||||
Description: "Password",
|
||||
},
|
||||
},
|
||||
|
||||
Renew: b.secretCredsRenew,
|
||||
Revoke: b.secretCredsRevoke,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *backend) secretCredsRenew(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
// Get the lease information
|
||||
leaseConfig, err := b.LeaseConfig(ctx, req.Storage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if leaseConfig == nil {
|
||||
leaseConfig = &configLease{}
|
||||
}
|
||||
|
||||
resp := &logical.Response{Secret: req.Secret}
|
||||
resp.Secret.TTL = leaseConfig.TTL
|
||||
resp.Secret.MaxTTL = leaseConfig.MaxTTL
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (b *backend) secretCredsRevoke(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
// Get the username from the internal data
|
||||
usernameRaw, ok := req.Secret.InternalData["username"]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("secret is missing username internal data")
|
||||
}
|
||||
username, ok := usernameRaw.(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("username internal data is not a string")
|
||||
}
|
||||
|
||||
// Get the db from the internal data
|
||||
dbRaw, ok := req.Secret.InternalData["db"]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("secret is missing db internal data")
|
||||
}
|
||||
db, ok := dbRaw.(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("db internal data is not a string")
|
||||
}
|
||||
|
||||
// Get our connection
|
||||
session, err := b.Session(ctx, req.Storage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Drop the user
|
||||
err = session.DB(db).RemoveUser(username)
|
||||
if err != nil && err != mgo.ErrNotFound {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
|
@ -1,81 +0,0 @@
|
|||
package mongodb
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"net"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
mgo "gopkg.in/mgo.v2"
|
||||
)
|
||||
|
||||
// Unfortunately, mgo doesn't support the ssl parameter in its MongoDB URI parsing logic, so we have to handle that
|
||||
// ourselves. See https://github.com/go-mgo/mgo/issues/84
|
||||
func parseMongoURI(rawUri string) (*mgo.DialInfo, error) {
|
||||
uri, err := url.Parse(rawUri)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
info := mgo.DialInfo{
|
||||
Addrs: strings.Split(uri.Host, ","),
|
||||
Database: strings.TrimPrefix(uri.Path, "/"),
|
||||
Timeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
if uri.User != nil {
|
||||
info.Username = uri.User.Username()
|
||||
info.Password, _ = uri.User.Password()
|
||||
}
|
||||
|
||||
query := uri.Query()
|
||||
for key, values := range query {
|
||||
var value string
|
||||
if len(values) > 0 {
|
||||
value = values[0]
|
||||
}
|
||||
|
||||
switch key {
|
||||
case "authSource":
|
||||
info.Source = value
|
||||
case "authMechanism":
|
||||
info.Mechanism = value
|
||||
case "gssapiServiceName":
|
||||
info.Service = value
|
||||
case "replicaSet":
|
||||
info.ReplicaSetName = value
|
||||
case "maxPoolSize":
|
||||
poolLimit, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
return nil, errors.New("bad value for maxPoolSize: " + value)
|
||||
}
|
||||
info.PoolLimit = poolLimit
|
||||
case "ssl":
|
||||
ssl, err := strconv.ParseBool(value)
|
||||
if err != nil {
|
||||
return nil, errors.New("bad value for ssl: " + value)
|
||||
}
|
||||
if ssl {
|
||||
info.DialServer = func(addr *mgo.ServerAddr) (net.Conn, error) {
|
||||
return tls.Dial("tcp", addr.String(), &tls.Config{})
|
||||
}
|
||||
}
|
||||
case "connect":
|
||||
if value == "direct" {
|
||||
info.Direct = true
|
||||
break
|
||||
}
|
||||
if value == "replicaSet" {
|
||||
break
|
||||
}
|
||||
fallthrough
|
||||
default:
|
||||
return nil, errors.New("unsupported connection URL option: " + key + "=" + value)
|
||||
}
|
||||
}
|
||||
|
||||
return &info, nil
|
||||
}
|
|
@ -1,160 +0,0 @@
|
|||
package mssql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
_ "github.com/denisenkom/go-mssqldb"
|
||||
"github.com/hashicorp/vault/sdk/framework"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
)
|
||||
|
||||
func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
|
||||
b := Backend()
|
||||
if err := b.Setup(ctx, conf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return b, nil
|
||||
}
|
||||
|
||||
func Backend() *backend {
|
||||
var b backend
|
||||
b.Backend = &framework.Backend{
|
||||
Help: strings.TrimSpace(backendHelp),
|
||||
|
||||
PathsSpecial: &logical.Paths{
|
||||
SealWrapStorage: []string{
|
||||
"config/connection",
|
||||
},
|
||||
},
|
||||
|
||||
Paths: []*framework.Path{
|
||||
pathConfigConnection(&b),
|
||||
pathConfigLease(&b),
|
||||
pathListRoles(&b),
|
||||
pathRoles(&b),
|
||||
pathCredsCreate(&b),
|
||||
},
|
||||
|
||||
Secrets: []*framework.Secret{
|
||||
secretCreds(&b),
|
||||
},
|
||||
|
||||
Invalidate: b.invalidate,
|
||||
Clean: b.ResetDB,
|
||||
BackendType: logical.TypeLogical,
|
||||
}
|
||||
|
||||
return &b
|
||||
}
|
||||
|
||||
type backend struct {
|
||||
*framework.Backend
|
||||
|
||||
db *sql.DB
|
||||
defaultDb string
|
||||
lock sync.Mutex
|
||||
}
|
||||
|
||||
// DB returns the default database connection.
|
||||
func (b *backend) DB(ctx context.Context, s logical.Storage) (*sql.DB, error) {
|
||||
b.lock.Lock()
|
||||
defer b.lock.Unlock()
|
||||
|
||||
// If we already have a DB, we got it!
|
||||
if b.db != nil {
|
||||
if err := b.db.Ping(); err == nil {
|
||||
return b.db, nil
|
||||
}
|
||||
// If the ping was unsuccessful, close it and ignore errors as we'll be
|
||||
// reestablishing anyways
|
||||
b.db.Close()
|
||||
}
|
||||
|
||||
// Otherwise, attempt to make connection
|
||||
entry, err := s.Get(ctx, "config/connection")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if entry == nil {
|
||||
return nil, fmt.Errorf("configure the DB connection with config/connection first")
|
||||
}
|
||||
|
||||
var connConfig connectionConfig
|
||||
if err := entry.DecodeJSON(&connConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
connString := connConfig.ConnectionString
|
||||
|
||||
db, err := sql.Open("sqlserver", connString)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Set some connection pool settings. We don't need much of this,
|
||||
// since the request rate shouldn't be high.
|
||||
db.SetMaxOpenConns(connConfig.MaxOpenConnections)
|
||||
|
||||
stmt, err := db.Prepare("SELECT db_name();")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
err = stmt.QueryRow().Scan(&b.defaultDb)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
b.db = db
|
||||
return b.db, nil
|
||||
}
|
||||
|
||||
// ResetDB forces a connection next time DB() is called.
|
||||
func (b *backend) ResetDB(_ context.Context) {
|
||||
b.lock.Lock()
|
||||
defer b.lock.Unlock()
|
||||
|
||||
if b.db != nil {
|
||||
b.db.Close()
|
||||
}
|
||||
|
||||
b.db = nil
|
||||
}
|
||||
|
||||
func (b *backend) invalidate(ctx context.Context, key string) {
|
||||
switch key {
|
||||
case "config/connection":
|
||||
b.ResetDB(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
// LeaseConfig returns the lease configuration
|
||||
func (b *backend) LeaseConfig(ctx context.Context, s logical.Storage) (*configLease, error) {
|
||||
entry, err := s.Get(ctx, "config/lease")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if entry == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var result configLease
|
||||
if err := entry.DecodeJSON(&result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
const backendHelp = `
|
||||
The MSSQL backend dynamically generates database users.
|
||||
|
||||
After mounting this backend, configure it using the endpoints within
|
||||
the "config/" path.
|
||||
|
||||
This backend does not support Azure SQL Databases.
|
||||
`
|
|
@ -1,222 +0,0 @@
|
|||
package mssql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
_ "github.com/denisenkom/go-mssqldb"
|
||||
logicaltest "github.com/hashicorp/vault/helper/testhelpers/logical"
|
||||
mssqlhelper "github.com/hashicorp/vault/helper/testhelpers/mssql"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
)
|
||||
|
||||
func Backend_config_connection(t *testing.T) {
|
||||
var resp *logical.Response
|
||||
var err error
|
||||
config := logical.TestBackendConfig()
|
||||
config.StorageView = &logical.InmemStorage{}
|
||||
b, err := Factory(context.Background(), config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
configData := map[string]interface{}{
|
||||
"connection_string": "sample_connection_string",
|
||||
"max_open_connections": 7,
|
||||
"verify_connection": false,
|
||||
}
|
||||
|
||||
configReq := &logical.Request{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "config/connection",
|
||||
Storage: config.StorageView,
|
||||
Data: configData,
|
||||
}
|
||||
resp, err = b.HandleRequest(context.Background(), configReq)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("err:%s resp:%#v\n", err, resp)
|
||||
}
|
||||
|
||||
configReq.Operation = logical.ReadOperation
|
||||
resp, err = b.HandleRequest(context.Background(), configReq)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("err:%s resp:%#v\n", err, resp)
|
||||
}
|
||||
|
||||
delete(configData, "verify_connection")
|
||||
delete(configData, "connection_string")
|
||||
if !reflect.DeepEqual(configData, resp.Data) {
|
||||
t.Fatalf("bad: expected:%#v\nactual:%#v\n", configData, resp.Data)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackend_basic(t *testing.T) {
|
||||
b, _ := Factory(context.Background(), logical.TestBackendConfig())
|
||||
|
||||
cleanup, connURL := mssqlhelper.PrepareMSSQLTestContainer(t)
|
||||
defer cleanup()
|
||||
|
||||
logicaltest.Test(t, logicaltest.TestCase{
|
||||
PreCheck: testAccPreCheckFunc(t, connURL),
|
||||
LogicalBackend: b,
|
||||
Steps: []logicaltest.TestStep{
|
||||
testAccStepConfig(t, connURL),
|
||||
testAccStepRole(t),
|
||||
testAccStepReadCreds(t, "web"),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func TestBackend_roleCrud(t *testing.T) {
|
||||
b := Backend()
|
||||
|
||||
cleanup, connURL := mssqlhelper.PrepareMSSQLTestContainer(t)
|
||||
defer cleanup()
|
||||
|
||||
logicaltest.Test(t, logicaltest.TestCase{
|
||||
PreCheck: testAccPreCheckFunc(t, connURL),
|
||||
LogicalBackend: b,
|
||||
Steps: []logicaltest.TestStep{
|
||||
testAccStepConfig(t, connURL),
|
||||
testAccStepRole(t),
|
||||
testAccStepReadRole(t, "web", testRoleSQL),
|
||||
testAccStepDeleteRole(t, "web"),
|
||||
testAccStepReadRole(t, "web", ""),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func TestBackend_leaseWriteRead(t *testing.T) {
|
||||
b := Backend()
|
||||
|
||||
cleanup, connURL := mssqlhelper.PrepareMSSQLTestContainer(t)
|
||||
defer cleanup()
|
||||
|
||||
logicaltest.Test(t, logicaltest.TestCase{
|
||||
PreCheck: testAccPreCheckFunc(t, connURL),
|
||||
LogicalBackend: b,
|
||||
Steps: []logicaltest.TestStep{
|
||||
testAccStepConfig(t, connURL),
|
||||
testAccStepWriteLease(t),
|
||||
testAccStepReadLease(t),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func testAccPreCheckFunc(t *testing.T, connectionURL string) func() {
|
||||
return func() {
|
||||
if connectionURL == "" {
|
||||
t.Fatal("connection URL must be set for acceptance tests")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func testAccStepConfig(t *testing.T, connURL string) logicaltest.TestStep {
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "config/connection",
|
||||
Data: map[string]interface{}{
|
||||
"connection_string": connURL,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func testAccStepRole(t *testing.T) logicaltest.TestStep {
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "roles/web",
|
||||
Data: map[string]interface{}{
|
||||
"sql": testRoleSQL,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func testAccStepDeleteRole(t *testing.T, n string) logicaltest.TestStep {
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.DeleteOperation,
|
||||
Path: "roles/" + n,
|
||||
}
|
||||
}
|
||||
|
||||
func testAccStepReadCreds(t *testing.T, name string) logicaltest.TestStep {
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.ReadOperation,
|
||||
Path: "creds/" + name,
|
||||
Check: func(resp *logical.Response) error {
|
||||
var d struct {
|
||||
Username string `mapstructure:"username"`
|
||||
Password string `mapstructure:"password"`
|
||||
}
|
||||
if err := mapstructure.Decode(resp.Data, &d); err != nil {
|
||||
return err
|
||||
}
|
||||
log.Printf("[WARN] Generated credentials: %v", d)
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func testAccStepReadRole(t *testing.T, name, sql string) logicaltest.TestStep {
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.ReadOperation,
|
||||
Path: "roles/" + name,
|
||||
Check: func(resp *logical.Response) error {
|
||||
if resp == nil {
|
||||
if sql == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("bad: %#v", resp)
|
||||
}
|
||||
|
||||
var d struct {
|
||||
SQL string `mapstructure:"sql"`
|
||||
}
|
||||
if err := mapstructure.Decode(resp.Data, &d); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if d.SQL != sql {
|
||||
return fmt.Errorf("bad: %#v", resp)
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func testAccStepWriteLease(t *testing.T) logicaltest.TestStep {
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "config/lease",
|
||||
Data: map[string]interface{}{
|
||||
"ttl": "1h5m",
|
||||
"max_ttl": "24h",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func testAccStepReadLease(t *testing.T) logicaltest.TestStep {
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.ReadOperation,
|
||||
Path: "config/lease",
|
||||
Check: func(resp *logical.Response) error {
|
||||
if resp.Data["ttl"] != "1h5m0s" || resp.Data["max_ttl"] != "24h0m0s" {
|
||||
return fmt.Errorf("bad: %#v", resp)
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
const testRoleSQL = `
|
||||
CREATE LOGIN [{{name}}] WITH PASSWORD = '{{password}}';
|
||||
CREATE USER [{{name}}] FOR LOGIN [{{name}}];
|
||||
GRANT SELECT ON SCHEMA::dbo TO [{{name}}]
|
||||
`
|
|
@ -1,29 +0,0 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
hclog "github.com/hashicorp/go-hclog"
|
||||
"github.com/hashicorp/vault/api"
|
||||
"github.com/hashicorp/vault/builtin/logical/mssql"
|
||||
"github.com/hashicorp/vault/sdk/plugin"
|
||||
)
|
||||
|
||||
func main() {
|
||||
apiClientMeta := &api.PluginAPIClientMeta{}
|
||||
flags := apiClientMeta.FlagSet()
|
||||
flags.Parse(os.Args[1:])
|
||||
|
||||
tlsConfig := apiClientMeta.GetTLSConfig()
|
||||
tlsProviderFunc := api.VaultPluginTLSProvider(tlsConfig)
|
||||
|
||||
if err := plugin.Serve(&plugin.ServeOpts{
|
||||
BackendFactoryFunc: mssql.Factory,
|
||||
TLSProviderFunc: tlsProviderFunc,
|
||||
}); err != nil {
|
||||
logger := hclog.New(&hclog.LoggerOptions{})
|
||||
|
||||
logger.Error("plugin shutting down", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
|
@ -1,126 +0,0 @@
|
|||
package mssql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/hashicorp/vault/sdk/framework"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
)
|
||||
|
||||
func pathConfigConnection(b *backend) *framework.Path {
|
||||
return &framework.Path{
|
||||
Pattern: "config/connection",
|
||||
Fields: map[string]*framework.FieldSchema{
|
||||
"connection_string": {
|
||||
Type: framework.TypeString,
|
||||
Description: "DB connection parameters",
|
||||
},
|
||||
"max_open_connections": {
|
||||
Type: framework.TypeInt,
|
||||
Description: "Maximum number of open connections to database",
|
||||
},
|
||||
"verify_connection": {
|
||||
Type: framework.TypeBool,
|
||||
Default: true,
|
||||
Description: "If set, connection_string is verified by actually connecting to the database",
|
||||
},
|
||||
},
|
||||
|
||||
Callbacks: map[logical.Operation]framework.OperationFunc{
|
||||
logical.UpdateOperation: b.pathConnectionWrite,
|
||||
logical.ReadOperation: b.pathConnectionRead,
|
||||
},
|
||||
|
||||
HelpSynopsis: pathConfigConnectionHelpSyn,
|
||||
HelpDescription: pathConfigConnectionHelpDesc,
|
||||
}
|
||||
}
|
||||
|
||||
// pathConnectionRead reads out the connection configuration
|
||||
func (b *backend) pathConnectionRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
entry, err := req.Storage.Get(ctx, "config/connection")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read connection configuration")
|
||||
}
|
||||
if entry == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var config connectionConfig
|
||||
if err := entry.DecodeJSON(&config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &logical.Response{
|
||||
Data: map[string]interface{}{
|
||||
"max_open_connections": config.MaxOpenConnections,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// pathConnectionWrite stores the connection configuration
|
||||
func (b *backend) pathConnectionWrite(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
connString := data.Get("connection_string").(string)
|
||||
|
||||
maxOpenConns := data.Get("max_open_connections").(int)
|
||||
if maxOpenConns == 0 {
|
||||
maxOpenConns = 2
|
||||
}
|
||||
|
||||
// Don't check the connection_string if verification is disabled
|
||||
verifyConnection := data.Get("verify_connection").(bool)
|
||||
if verifyConnection {
|
||||
// Verify the string
|
||||
db, err := sql.Open("mssql", connString)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf(
|
||||
"Error validating connection info: %s", err)), nil
|
||||
}
|
||||
defer db.Close()
|
||||
if err := db.Ping(); err != nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf(
|
||||
"Error validating connection info: %s", err)), nil
|
||||
}
|
||||
}
|
||||
|
||||
// Store it
|
||||
entry, err := logical.StorageEntryJSON("config/connection", connectionConfig{
|
||||
ConnectionString: connString,
|
||||
MaxOpenConnections: maxOpenConns,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := req.Storage.Put(ctx, entry); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Reset the DB connection
|
||||
b.ResetDB(ctx)
|
||||
|
||||
resp := &logical.Response{}
|
||||
resp.AddWarning("Read access to this endpoint should be controlled via ACLs as it will return the connection string as it is, including passwords, if any.")
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
type connectionConfig struct {
|
||||
ConnectionString string `json:"connection_string" structs:"connection_string" mapstructure:"connection_string"`
|
||||
MaxOpenConnections int `json:"max_open_connections" structs:"max_open_connections" mapstructure:"max_open_connections"`
|
||||
}
|
||||
|
||||
const pathConfigConnectionHelpSyn = `
|
||||
Configure the connection string to talk to Microsoft Sql Server.
|
||||
`
|
||||
|
||||
const pathConfigConnectionHelpDesc = `
|
||||
This path configures the connection string used to connect to Sql Server.
|
||||
The value of the string is a Data Source Name (DSN). An example is
|
||||
using "server=<hostname>;port=<port>;user id=<username>;password=<password>;database=<database>;app name=vault;"
|
||||
|
||||
When configuring the connection string, the backend will verify its validity.
|
||||
If the database is not available when setting the connection string, set the
|
||||
"verify_connection" option to false.
|
||||
`
|
|
@ -1,114 +0,0 @@
|
|||
package mssql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/sdk/framework"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
)
|
||||
|
||||
func pathConfigLease(b *backend) *framework.Path {
|
||||
return &framework.Path{
|
||||
Pattern: "config/lease",
|
||||
Fields: map[string]*framework.FieldSchema{
|
||||
"ttl": {
|
||||
Type: framework.TypeString,
|
||||
Description: "Default ttl for roles.",
|
||||
},
|
||||
|
||||
"ttl_max": {
|
||||
Type: framework.TypeString,
|
||||
Description: `Deprecated: use "max_ttl" instead. Maximum
|
||||
time a credential is valid for.`,
|
||||
},
|
||||
|
||||
"max_ttl": {
|
||||
Type: framework.TypeString,
|
||||
Description: "Maximum time a credential is valid for.",
|
||||
},
|
||||
},
|
||||
|
||||
Callbacks: map[logical.Operation]framework.OperationFunc{
|
||||
logical.ReadOperation: b.pathConfigLeaseRead,
|
||||
logical.UpdateOperation: b.pathConfigLeaseWrite,
|
||||
},
|
||||
|
||||
HelpSynopsis: pathConfigLeaseHelpSyn,
|
||||
HelpDescription: pathConfigLeaseHelpDesc,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *backend) pathConfigLeaseWrite(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
ttlRaw := d.Get("ttl").(string)
|
||||
ttlMaxRaw := d.Get("max_ttl").(string)
|
||||
if len(ttlMaxRaw) == 0 {
|
||||
ttlMaxRaw = d.Get("ttl_max").(string)
|
||||
}
|
||||
|
||||
ttl, err := time.ParseDuration(ttlRaw)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf(
|
||||
"Invalid ttl: %s", err)), nil
|
||||
}
|
||||
ttlMax, err := time.ParseDuration(ttlMaxRaw)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf(
|
||||
"Invalid max_ttl: %s", err)), nil
|
||||
}
|
||||
|
||||
// Store it
|
||||
entry, err := logical.StorageEntryJSON("config/lease", &configLease{
|
||||
TTL: ttl,
|
||||
TTLMax: ttlMax,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := req.Storage.Put(ctx, entry); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (b *backend) pathConfigLeaseRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
leaseConfig, err := b.LeaseConfig(ctx, req.Storage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if leaseConfig == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
resp := &logical.Response{
|
||||
Data: map[string]interface{}{
|
||||
"ttl": leaseConfig.TTL.String(),
|
||||
"ttl_max": leaseConfig.TTLMax.String(),
|
||||
"max_ttl": leaseConfig.TTLMax.String(),
|
||||
},
|
||||
}
|
||||
resp.AddWarning("The field ttl_max is deprecated and will be removed in a future release. Use max_ttl instead.")
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
type configLease struct {
|
||||
TTL time.Duration
|
||||
TTLMax time.Duration
|
||||
}
|
||||
|
||||
const pathConfigLeaseHelpSyn = `
|
||||
Configure the default lease ttl for generated credentials.
|
||||
`
|
||||
|
||||
const pathConfigLeaseHelpDesc = `
|
||||
This configures the default lease ttl used for credentials
|
||||
generated by this backend. The ttl specifies the duration that a
|
||||
credential will be valid for, as well as the maximum session for
|
||||
a set of credentials.
|
||||
|
||||
The format for the ttl is "1h" or integer and then unit. The longest
|
||||
unit is hour.
|
||||
`
|
|
@ -1,129 +0,0 @@
|
|||
package mssql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/go-secure-stdlib/strutil"
|
||||
uuid "github.com/hashicorp/go-uuid"
|
||||
"github.com/hashicorp/vault/sdk/framework"
|
||||
"github.com/hashicorp/vault/sdk/helper/dbtxn"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
)
|
||||
|
||||
func pathCredsCreate(b *backend) *framework.Path {
|
||||
return &framework.Path{
|
||||
Pattern: "creds/" + framework.GenericNameRegex("name"),
|
||||
Fields: map[string]*framework.FieldSchema{
|
||||
"name": {
|
||||
Type: framework.TypeString,
|
||||
Description: "Name of the role.",
|
||||
},
|
||||
},
|
||||
|
||||
Callbacks: map[logical.Operation]framework.OperationFunc{
|
||||
logical.ReadOperation: b.pathCredsCreateRead,
|
||||
},
|
||||
|
||||
HelpSynopsis: pathCredsCreateHelpSyn,
|
||||
HelpDescription: pathCredsCreateHelpDesc,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *backend) pathCredsCreateRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
name := data.Get("name").(string)
|
||||
|
||||
// Get the role
|
||||
role, err := b.Role(ctx, req.Storage, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if role == nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf("unknown role: %s", name)), nil
|
||||
}
|
||||
|
||||
// Determine if we have a lease configuration
|
||||
leaseConfig, err := b.LeaseConfig(ctx, req.Storage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if leaseConfig == nil {
|
||||
leaseConfig = &configLease{}
|
||||
}
|
||||
|
||||
// Generate our username and password
|
||||
displayName := req.DisplayName
|
||||
if len(displayName) > 10 {
|
||||
displayName = displayName[:10]
|
||||
}
|
||||
userUUID, err := uuid.GenerateUUID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
username := fmt.Sprintf("%s-%s", displayName, userUUID)
|
||||
password, err := uuid.GenerateUUID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get our handle
|
||||
db, err := b.DB(ctx, req.Storage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Start a transaction
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
// Always reset database to default db of connection. Since it is in a
|
||||
// transaction, all statements will be on the same connection in the pool.
|
||||
roleSQL := fmt.Sprintf("USE [%s]; %s", b.defaultDb, role.SQL)
|
||||
|
||||
// Execute each query
|
||||
for _, query := range strutil.ParseArbitraryStringSlice(roleSQL, ";") {
|
||||
query = strings.TrimSpace(query)
|
||||
if len(query) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
m := map[string]string{
|
||||
"name": username,
|
||||
"password": password,
|
||||
}
|
||||
if err := dbtxn.ExecuteTxQueryDirect(ctx, tx, m, query); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Commit the transaction
|
||||
if err := tx.Commit(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Return the secret
|
||||
resp := b.Secret(SecretCredsType).Response(map[string]interface{}{
|
||||
"username": username,
|
||||
"password": password,
|
||||
}, map[string]interface{}{
|
||||
"username": username,
|
||||
})
|
||||
resp.Secret.TTL = leaseConfig.TTL
|
||||
resp.Secret.MaxTTL = leaseConfig.TTLMax
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
const pathCredsCreateHelpSyn = `
|
||||
Request database credentials for a certain role.
|
||||
`
|
||||
|
||||
const pathCredsCreateHelpDesc = `
|
||||
This path reads database credentials for a certain role. The
|
||||
database credentials will be generated on demand and will be automatically
|
||||
revoked when the lease is up.
|
||||
`
|
|
@ -1,172 +0,0 @@
|
|||
package mssql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/go-secure-stdlib/strutil"
|
||||
"github.com/hashicorp/vault/sdk/framework"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
)
|
||||
|
||||
func pathListRoles(b *backend) *framework.Path {
|
||||
return &framework.Path{
|
||||
Pattern: "roles/?$",
|
||||
|
||||
Callbacks: map[logical.Operation]framework.OperationFunc{
|
||||
logical.ListOperation: b.pathRoleList,
|
||||
},
|
||||
|
||||
HelpSynopsis: pathRoleHelpSyn,
|
||||
HelpDescription: pathRoleHelpDesc,
|
||||
}
|
||||
}
|
||||
|
||||
func pathRoles(b *backend) *framework.Path {
|
||||
return &framework.Path{
|
||||
Pattern: "roles/" + framework.GenericNameRegex("name"),
|
||||
Fields: map[string]*framework.FieldSchema{
|
||||
"name": {
|
||||
Type: framework.TypeString,
|
||||
Description: "Name of the role.",
|
||||
},
|
||||
|
||||
"sql": {
|
||||
Type: framework.TypeString,
|
||||
Description: "SQL string to create a role. See help for more info.",
|
||||
},
|
||||
},
|
||||
|
||||
Callbacks: map[logical.Operation]framework.OperationFunc{
|
||||
logical.ReadOperation: b.pathRoleRead,
|
||||
logical.UpdateOperation: b.pathRoleCreate,
|
||||
logical.DeleteOperation: b.pathRoleDelete,
|
||||
},
|
||||
|
||||
HelpSynopsis: pathRoleHelpSyn,
|
||||
HelpDescription: pathRoleHelpDesc,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *backend) Role(ctx context.Context, s logical.Storage, n string) (*roleEntry, error) {
|
||||
entry, err := s.Get(ctx, "role/"+n)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if entry == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var result roleEntry
|
||||
if err := entry.DecodeJSON(&result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
func (b *backend) pathRoleDelete(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
err := req.Storage.Delete(ctx, "role/"+data.Get("name").(string))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (b *backend) pathRoleRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
role, err := b.Role(ctx, req.Storage, data.Get("name").(string))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if role == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return &logical.Response{
|
||||
Data: map[string]interface{}{
|
||||
"sql": role.SQL,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (b *backend) pathRoleList(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
entries, err := req.Storage.List(ctx, "role/")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return logical.ListResponse(entries), nil
|
||||
}
|
||||
|
||||
func (b *backend) pathRoleCreate(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
name := data.Get("name").(string)
|
||||
sql := data.Get("sql").(string)
|
||||
|
||||
// Get our connection
|
||||
db, err := b.DB(ctx, req.Storage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Test the query by trying to prepare it
|
||||
for _, query := range strutil.ParseArbitraryStringSlice(sql, ";") {
|
||||
query = strings.TrimSpace(query)
|
||||
if len(query) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
stmt, err := db.Prepare(Query(query, map[string]string{
|
||||
"name": "foo",
|
||||
"password": "bar",
|
||||
}))
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf(
|
||||
"Error testing query: %s", err)), nil
|
||||
}
|
||||
stmt.Close()
|
||||
}
|
||||
|
||||
// Store it
|
||||
entry, err := logical.StorageEntryJSON("role/"+name, &roleEntry{
|
||||
SQL: sql,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := req.Storage.Put(ctx, entry); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
type roleEntry struct {
|
||||
SQL string `json:"sql"`
|
||||
}
|
||||
|
||||
const pathRoleHelpSyn = `
|
||||
Manage the roles that can be created with this backend.
|
||||
`
|
||||
|
||||
const pathRoleHelpDesc = `
|
||||
This path lets you manage the roles that can be created with this backend.
|
||||
|
||||
The "sql" parameter customizes the SQL string used to create the login to
|
||||
the server. The parameter can be a sequence of SQL queries, each semi-colon
|
||||
separated. Some substitution will be done to the SQL string for certain keys.
|
||||
The names of the variables must be surrounded by "{{" and "}}" to be replaced.
|
||||
|
||||
* "name" - The random username generated for the DB user.
|
||||
|
||||
* "password" - The random password generated for the DB user.
|
||||
|
||||
Example SQL query to use:
|
||||
|
||||
CREATE LOGIN [{{name}}] WITH PASSWORD = '{{password}}';
|
||||
CREATE USER [{{name}}] FROM LOGIN [{{name}}];
|
||||
GRANT SELECT, UPDATE, DELETE, INSERT on SCHEMA::dbo TO [{{name}}];
|
||||
|
||||
Please see the Microsoft SQL Server manual on the GRANT command to learn how to
|
||||
do more fine grained access.
|
||||
`
|
|
@ -1,180 +0,0 @@
|
|||
package mssql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/hashicorp/vault/sdk/framework"
|
||||
"github.com/hashicorp/vault/sdk/helper/dbtxn"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
)
|
||||
|
||||
const SecretCredsType = "creds"
|
||||
|
||||
func secretCreds(b *backend) *framework.Secret {
|
||||
return &framework.Secret{
|
||||
Type: SecretCredsType,
|
||||
Fields: map[string]*framework.FieldSchema{
|
||||
"username": {
|
||||
Type: framework.TypeString,
|
||||
Description: "Username",
|
||||
},
|
||||
|
||||
"password": {
|
||||
Type: framework.TypeString,
|
||||
Description: "Password",
|
||||
},
|
||||
},
|
||||
|
||||
Renew: b.secretCredsRenew,
|
||||
Revoke: b.secretCredsRevoke,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *backend) secretCredsRenew(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
// Get the lease information
|
||||
leaseConfig, err := b.LeaseConfig(ctx, req.Storage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if leaseConfig == nil {
|
||||
leaseConfig = &configLease{}
|
||||
}
|
||||
|
||||
resp := &logical.Response{Secret: req.Secret}
|
||||
resp.Secret.TTL = leaseConfig.TTL
|
||||
resp.Secret.MaxTTL = leaseConfig.TTLMax
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (b *backend) secretCredsRevoke(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
// Get the username from the internal data
|
||||
usernameRaw, ok := req.Secret.InternalData["username"]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("secret is missing username internal data")
|
||||
}
|
||||
username, ok := usernameRaw.(string)
|
||||
|
||||
// Get our connection
|
||||
db, err := b.DB(ctx, req.Storage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// First disable server login
|
||||
disableStmt, err := db.Prepare(fmt.Sprintf("ALTER LOGIN [%s] DISABLE;", username))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer disableStmt.Close()
|
||||
if _, err := disableStmt.Exec(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Query for sessions for the login so that we can kill any outstanding
|
||||
// sessions. There cannot be any active sessions before we drop the logins
|
||||
// This isn't done in a transaction because even if we fail along the way,
|
||||
// we want to remove as much access as possible
|
||||
sessionStmt, err := db.Prepare("SELECT session_id FROM sys.dm_exec_sessions WHERE login_name = @p1;")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer sessionStmt.Close()
|
||||
|
||||
sessionRows, err := sessionStmt.Query(username)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer sessionRows.Close()
|
||||
|
||||
var revokeStmts []string
|
||||
for sessionRows.Next() {
|
||||
var sessionID int
|
||||
err = sessionRows.Scan(&sessionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
revokeStmts = append(revokeStmts, fmt.Sprintf("KILL %d;", sessionID))
|
||||
}
|
||||
|
||||
// Query for database users using undocumented stored procedure for now since
|
||||
// it is the easiest way to get this information;
|
||||
// we need to drop the database users before we can drop the login and the role
|
||||
// This isn't done in a transaction because even if we fail along the way,
|
||||
// we want to remove as much access as possible
|
||||
stmt, err := db.Prepare("EXEC master.dbo.sp_msloginmappings @p1;")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
rows, err := stmt.Query(username)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
var loginName, dbName, qUsername, aliasName sql.NullString
|
||||
err = rows.Scan(&loginName, &dbName, &qUsername, &aliasName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !dbName.Valid {
|
||||
continue
|
||||
}
|
||||
revokeStmts = append(revokeStmts, fmt.Sprintf(dropUserSQL, dbName.String, username, username))
|
||||
}
|
||||
|
||||
// we do not stop on error, as we want to remove as
|
||||
// many permissions as possible right now
|
||||
var lastStmtError error
|
||||
for _, query := range revokeStmts {
|
||||
if err := dbtxn.ExecuteDBQueryDirect(ctx, db, nil, query); err != nil {
|
||||
lastStmtError = err
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// can't drop if not all database users are dropped
|
||||
if rows.Err() != nil {
|
||||
return nil, fmt.Errorf("could not generate sql statements for all rows: %w", rows.Err())
|
||||
}
|
||||
if lastStmtError != nil {
|
||||
return nil, fmt.Errorf("could not perform all sql statements: %w", lastStmtError)
|
||||
}
|
||||
|
||||
// Drop this login
|
||||
stmt, err = db.Prepare(fmt.Sprintf(dropLoginSQL, username, username))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer stmt.Close()
|
||||
if _, err := stmt.Exec(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
const dropUserSQL = `
|
||||
USE [%s]
|
||||
IF EXISTS
|
||||
(SELECT name
|
||||
FROM sys.database_principals
|
||||
WHERE name = N'%s')
|
||||
BEGIN
|
||||
DROP USER [%s]
|
||||
END
|
||||
`
|
||||
|
||||
const dropLoginSQL = `
|
||||
IF EXISTS
|
||||
(SELECT name
|
||||
FROM master.sys.server_principals
|
||||
WHERE name = N'%s')
|
||||
BEGIN
|
||||
DROP LOGIN [%s]
|
||||
END
|
||||
`
|
|
@ -1,28 +0,0 @@
|
|||
package mssql
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// SplitSQL is used to split a series of SQL statements
|
||||
func SplitSQL(sql string) []string {
|
||||
parts := strings.Split(sql, ";")
|
||||
out := make([]string, 0, len(parts))
|
||||
for _, p := range parts {
|
||||
clean := strings.TrimSpace(p)
|
||||
if len(clean) > 0 {
|
||||
out = append(out, clean)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// Query templates a query for us.
|
||||
func Query(tpl string, data map[string]string) string {
|
||||
for k, v := range data {
|
||||
tpl = strings.ReplaceAll(tpl, fmt.Sprintf("{{%s}}", k), v)
|
||||
}
|
||||
|
||||
return tpl
|
||||
}
|
|
@ -1,151 +0,0 @@
|
|||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
"github.com/hashicorp/vault/sdk/framework"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
)
|
||||
|
||||
func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
|
||||
b := Backend()
|
||||
if err := b.Setup(ctx, conf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return b, nil
|
||||
}
|
||||
|
||||
func Backend() *backend {
|
||||
var b backend
|
||||
b.Backend = &framework.Backend{
|
||||
Help: strings.TrimSpace(backendHelp),
|
||||
|
||||
PathsSpecial: &logical.Paths{
|
||||
SealWrapStorage: []string{
|
||||
"config/connection",
|
||||
},
|
||||
},
|
||||
|
||||
Paths: []*framework.Path{
|
||||
pathConfigConnection(&b),
|
||||
pathConfigLease(&b),
|
||||
pathListRoles(&b),
|
||||
pathRoles(&b),
|
||||
pathRoleCreate(&b),
|
||||
},
|
||||
|
||||
Secrets: []*framework.Secret{
|
||||
secretCreds(&b),
|
||||
},
|
||||
|
||||
Invalidate: b.invalidate,
|
||||
Clean: b.ResetDB,
|
||||
BackendType: logical.TypeLogical,
|
||||
}
|
||||
|
||||
return &b
|
||||
}
|
||||
|
||||
type backend struct {
|
||||
*framework.Backend
|
||||
|
||||
db *sql.DB
|
||||
lock sync.Mutex
|
||||
}
|
||||
|
||||
// DB returns the database connection.
|
||||
func (b *backend) DB(ctx context.Context, s logical.Storage) (*sql.DB, error) {
|
||||
b.lock.Lock()
|
||||
defer b.lock.Unlock()
|
||||
|
||||
// If we already have a DB, we got it!
|
||||
if b.db != nil {
|
||||
if err := b.db.Ping(); err == nil {
|
||||
return b.db, nil
|
||||
}
|
||||
// If the ping was unsuccessful, close it and ignore errors as we'll be
|
||||
// reestablishing anyways
|
||||
b.db.Close()
|
||||
}
|
||||
|
||||
// Otherwise, attempt to make connection
|
||||
entry, err := s.Get(ctx, "config/connection")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if entry == nil {
|
||||
return nil,
|
||||
fmt.Errorf("configure the DB connection with config/connection first")
|
||||
}
|
||||
|
||||
var connConfig connectionConfig
|
||||
if err := entry.DecodeJSON(&connConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
conn := connConfig.ConnectionURL
|
||||
if len(conn) == 0 {
|
||||
conn = connConfig.ConnectionString
|
||||
}
|
||||
|
||||
b.db, err = sql.Open("mysql", conn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Set some connection pool settings. We don't need much of this,
|
||||
// since the request rate shouldn't be high.
|
||||
b.db.SetMaxOpenConns(connConfig.MaxOpenConnections)
|
||||
b.db.SetMaxIdleConns(connConfig.MaxIdleConnections)
|
||||
|
||||
return b.db, nil
|
||||
}
|
||||
|
||||
// ResetDB forces a connection next time DB() is called.
|
||||
func (b *backend) ResetDB(_ context.Context) {
|
||||
b.lock.Lock()
|
||||
defer b.lock.Unlock()
|
||||
|
||||
if b.db != nil {
|
||||
b.db.Close()
|
||||
}
|
||||
|
||||
b.db = nil
|
||||
}
|
||||
|
||||
func (b *backend) invalidate(ctx context.Context, key string) {
|
||||
switch key {
|
||||
case "config/connection":
|
||||
b.ResetDB(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
// Lease returns the lease information
|
||||
func (b *backend) Lease(ctx context.Context, s logical.Storage) (*configLease, error) {
|
||||
entry, err := s.Get(ctx, "config/lease")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if entry == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var result configLease
|
||||
if err := entry.DecodeJSON(&result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
const backendHelp = `
|
||||
The MySQL backend dynamically generates database users.
|
||||
|
||||
After mounting this backend, configure it using the endpoints within
|
||||
the "config/" path.
|
||||
`
|
|
@ -1,307 +0,0 @@
|
|||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
logicaltest "github.com/hashicorp/vault/helper/testhelpers/logical"
|
||||
mysqlhelper "github.com/hashicorp/vault/helper/testhelpers/mysql"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
)
|
||||
|
||||
func TestBackend_config_connection(t *testing.T) {
|
||||
var resp *logical.Response
|
||||
var err error
|
||||
config := logical.TestBackendConfig()
|
||||
config.StorageView = &logical.InmemStorage{}
|
||||
b, err := Factory(context.Background(), config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
configData := map[string]interface{}{
|
||||
"connection_url": "sample_connection_url",
|
||||
"max_open_connections": 9,
|
||||
"max_idle_connections": 7,
|
||||
"verify_connection": false,
|
||||
}
|
||||
|
||||
configReq := &logical.Request{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "config/connection",
|
||||
Storage: config.StorageView,
|
||||
Data: configData,
|
||||
}
|
||||
resp, err = b.HandleRequest(context.Background(), configReq)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("err:%s resp:%#v\n", err, resp)
|
||||
}
|
||||
|
||||
configReq.Operation = logical.ReadOperation
|
||||
resp, err = b.HandleRequest(context.Background(), configReq)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("err:%s resp:%#v\n", err, resp)
|
||||
}
|
||||
|
||||
delete(configData, "verify_connection")
|
||||
delete(configData, "connection_url")
|
||||
if !reflect.DeepEqual(configData, resp.Data) {
|
||||
t.Fatalf("bad: expected:%#v\nactual:%#v\n", configData, resp.Data)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackend_basic(t *testing.T) {
|
||||
config := logical.TestBackendConfig()
|
||||
config.StorageView = &logical.InmemStorage{}
|
||||
b, err := Factory(context.Background(), config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cleanup, connURL := mysqlhelper.PrepareTestContainer(t, false, "secret")
|
||||
defer cleanup()
|
||||
|
||||
connData := map[string]interface{}{
|
||||
"connection_url": connURL,
|
||||
}
|
||||
|
||||
// for wildcard based mysql user
|
||||
logicaltest.Test(t, logicaltest.TestCase{
|
||||
LogicalBackend: b,
|
||||
Steps: []logicaltest.TestStep{
|
||||
testAccStepConfig(t, connData, false),
|
||||
testAccStepRole(t, true),
|
||||
testAccStepReadCreds(t, "web"),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func TestBackend_basicHostRevoke(t *testing.T) {
|
||||
config := logical.TestBackendConfig()
|
||||
config.StorageView = &logical.InmemStorage{}
|
||||
b, err := Factory(context.Background(), config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cleanup, connURL := mysqlhelper.PrepareTestContainer(t, false, "secret")
|
||||
defer cleanup()
|
||||
|
||||
connData := map[string]interface{}{
|
||||
"connection_url": connURL,
|
||||
}
|
||||
|
||||
// for host based mysql user
|
||||
logicaltest.Test(t, logicaltest.TestCase{
|
||||
LogicalBackend: b,
|
||||
Steps: []logicaltest.TestStep{
|
||||
testAccStepConfig(t, connData, false),
|
||||
testAccStepRole(t, false),
|
||||
testAccStepReadCreds(t, "web"),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func TestBackend_roleCrud(t *testing.T) {
|
||||
config := logical.TestBackendConfig()
|
||||
config.StorageView = &logical.InmemStorage{}
|
||||
b, err := Factory(context.Background(), config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cleanup, connURL := mysqlhelper.PrepareTestContainer(t, false, "secret")
|
||||
defer cleanup()
|
||||
|
||||
connData := map[string]interface{}{
|
||||
"connection_url": connURL,
|
||||
}
|
||||
|
||||
logicaltest.Test(t, logicaltest.TestCase{
|
||||
LogicalBackend: b,
|
||||
Steps: []logicaltest.TestStep{
|
||||
testAccStepConfig(t, connData, false),
|
||||
// test SQL with wildcard based user
|
||||
testAccStepRole(t, true),
|
||||
testAccStepReadRole(t, "web", testRoleWildCard),
|
||||
testAccStepDeleteRole(t, "web"),
|
||||
// test SQL with host based user
|
||||
testAccStepRole(t, false),
|
||||
testAccStepReadRole(t, "web", testRoleHost),
|
||||
testAccStepDeleteRole(t, "web"),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func TestBackend_leaseWriteRead(t *testing.T) {
|
||||
config := logical.TestBackendConfig()
|
||||
config.StorageView = &logical.InmemStorage{}
|
||||
b, err := Factory(context.Background(), config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cleanup, connURL := mysqlhelper.PrepareTestContainer(t, false, "secret")
|
||||
defer cleanup()
|
||||
|
||||
connData := map[string]interface{}{
|
||||
"connection_url": connURL,
|
||||
}
|
||||
|
||||
logicaltest.Test(t, logicaltest.TestCase{
|
||||
LogicalBackend: b,
|
||||
Steps: []logicaltest.TestStep{
|
||||
testAccStepConfig(t, connData, false),
|
||||
testAccStepWriteLease(t),
|
||||
testAccStepReadLease(t),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func testAccStepConfig(t *testing.T, d map[string]interface{}, expectError bool) logicaltest.TestStep {
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "config/connection",
|
||||
Data: d,
|
||||
ErrorOk: true,
|
||||
Check: func(resp *logical.Response) error {
|
||||
if expectError {
|
||||
if resp.Data == nil {
|
||||
return fmt.Errorf("data is nil")
|
||||
}
|
||||
var e struct {
|
||||
Error string `mapstructure:"error"`
|
||||
}
|
||||
if err := mapstructure.Decode(resp.Data, &e); err != nil {
|
||||
return err
|
||||
}
|
||||
if len(e.Error) == 0 {
|
||||
return fmt.Errorf("expected error, but write succeeded")
|
||||
}
|
||||
return nil
|
||||
} else if resp != nil && resp.IsError() {
|
||||
return fmt.Errorf("got an error response: %v", resp.Error())
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func testAccStepRole(t *testing.T, wildCard bool) logicaltest.TestStep {
|
||||
pathData := make(map[string]interface{})
|
||||
if wildCard {
|
||||
pathData = map[string]interface{}{
|
||||
"sql": testRoleWildCard,
|
||||
}
|
||||
} else {
|
||||
pathData = map[string]interface{}{
|
||||
"sql": testRoleHost,
|
||||
"revocation_sql": testRevocationSQL,
|
||||
}
|
||||
}
|
||||
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "roles/web",
|
||||
Data: pathData,
|
||||
}
|
||||
}
|
||||
|
||||
func testAccStepDeleteRole(t *testing.T, n string) logicaltest.TestStep {
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.DeleteOperation,
|
||||
Path: "roles/" + n,
|
||||
}
|
||||
}
|
||||
|
||||
func testAccStepReadCreds(t *testing.T, name string) logicaltest.TestStep {
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.ReadOperation,
|
||||
Path: "creds/" + name,
|
||||
Check: func(resp *logical.Response) error {
|
||||
var d struct {
|
||||
Username string `mapstructure:"username"`
|
||||
Password string `mapstructure:"password"`
|
||||
}
|
||||
if err := mapstructure.Decode(resp.Data, &d); err != nil {
|
||||
return err
|
||||
}
|
||||
log.Printf("[WARN] Generated credentials: %v", d)
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func testAccStepReadRole(t *testing.T, name string, sql string) logicaltest.TestStep {
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.ReadOperation,
|
||||
Path: "roles/" + name,
|
||||
Check: func(resp *logical.Response) error {
|
||||
if resp == nil {
|
||||
if sql == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("bad: %#v", resp)
|
||||
}
|
||||
|
||||
var d struct {
|
||||
SQL string `mapstructure:"sql"`
|
||||
}
|
||||
if err := mapstructure.Decode(resp.Data, &d); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if d.SQL != sql {
|
||||
return fmt.Errorf("bad: %#v", resp)
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func testAccStepWriteLease(t *testing.T) logicaltest.TestStep {
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "config/lease",
|
||||
Data: map[string]interface{}{
|
||||
"lease": "1h5m",
|
||||
"lease_max": "24h",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func testAccStepReadLease(t *testing.T) logicaltest.TestStep {
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.ReadOperation,
|
||||
Path: "config/lease",
|
||||
Check: func(resp *logical.Response) error {
|
||||
if resp.Data["lease"] != "1h5m0s" || resp.Data["lease_max"] != "24h0m0s" {
|
||||
return fmt.Errorf("bad: %#v", resp)
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
const testRoleWildCard = `
|
||||
CREATE USER '{{name}}'@'%' IDENTIFIED BY '{{password}}';
|
||||
GRANT SELECT ON *.* TO '{{name}}'@'%';
|
||||
`
|
||||
|
||||
const testRoleHost = `
|
||||
CREATE USER '{{name}}'@'10.1.1.2' IDENTIFIED BY '{{password}}';
|
||||
GRANT SELECT ON *.* TO '{{name}}'@'10.1.1.2';
|
||||
`
|
||||
|
||||
const testRevocationSQL = `
|
||||
REVOKE ALL PRIVILEGES, GRANT OPTION FROM '{{name}}'@'10.1.1.2';
|
||||
DROP USER '{{name}}'@'10.1.1.2';
|
||||
`
|
|
@ -1,29 +0,0 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
hclog "github.com/hashicorp/go-hclog"
|
||||
"github.com/hashicorp/vault/api"
|
||||
"github.com/hashicorp/vault/builtin/logical/mysql"
|
||||
"github.com/hashicorp/vault/sdk/plugin"
|
||||
)
|
||||
|
||||
func main() {
|
||||
apiClientMeta := &api.PluginAPIClientMeta{}
|
||||
flags := apiClientMeta.FlagSet()
|
||||
flags.Parse(os.Args[1:])
|
||||
|
||||
tlsConfig := apiClientMeta.GetTLSConfig()
|
||||
tlsProviderFunc := api.VaultPluginTLSProvider(tlsConfig)
|
||||
|
||||
if err := plugin.Serve(&plugin.ServeOpts{
|
||||
BackendFactoryFunc: mysql.Factory,
|
||||
TLSProviderFunc: tlsProviderFunc,
|
||||
}); err != nil {
|
||||
logger := hclog.New(&hclog.LoggerOptions{})
|
||||
|
||||
logger.Error("plugin shutting down", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
|
@ -1,159 +0,0 @@
|
|||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
"github.com/hashicorp/vault/sdk/framework"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
)
|
||||
|
||||
func pathConfigConnection(b *backend) *framework.Path {
|
||||
return &framework.Path{
|
||||
Pattern: "config/connection",
|
||||
Fields: map[string]*framework.FieldSchema{
|
||||
"connection_url": {
|
||||
Type: framework.TypeString,
|
||||
Description: "DB connection string",
|
||||
},
|
||||
"value": {
|
||||
Type: framework.TypeString,
|
||||
Description: `DB connection string. Use 'connection_url' instead.
|
||||
This name is deprecated.`,
|
||||
},
|
||||
"max_open_connections": {
|
||||
Type: framework.TypeInt,
|
||||
Description: "Maximum number of open connections to database",
|
||||
},
|
||||
"max_idle_connections": {
|
||||
Type: framework.TypeInt,
|
||||
Description: "Maximum number of idle connections to the database; a zero uses the value of max_open_connections and a negative value disables idle connections. If larger than max_open_connections it will be reduced to the same size.",
|
||||
},
|
||||
"verify_connection": {
|
||||
Type: framework.TypeBool,
|
||||
Default: true,
|
||||
Description: "If set, connection_url is verified by actually connecting to the database",
|
||||
},
|
||||
},
|
||||
|
||||
Callbacks: map[logical.Operation]framework.OperationFunc{
|
||||
logical.UpdateOperation: b.pathConnectionWrite,
|
||||
logical.ReadOperation: b.pathConnectionRead,
|
||||
},
|
||||
|
||||
HelpSynopsis: pathConfigConnectionHelpSyn,
|
||||
HelpDescription: pathConfigConnectionHelpDesc,
|
||||
}
|
||||
}
|
||||
|
||||
// pathConnectionRead reads out the connection configuration
|
||||
func (b *backend) pathConnectionRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
entry, err := req.Storage.Get(ctx, "config/connection")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read connection configuration")
|
||||
}
|
||||
if entry == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var config connectionConfig
|
||||
if err := entry.DecodeJSON(&config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &logical.Response{
|
||||
Data: map[string]interface{}{
|
||||
"max_open_connections": config.MaxOpenConnections,
|
||||
"max_idle_connections": config.MaxIdleConnections,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (b *backend) pathConnectionWrite(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
connValue := data.Get("value").(string)
|
||||
connURL := data.Get("connection_url").(string)
|
||||
if connURL == "" {
|
||||
if connValue == "" {
|
||||
return logical.ErrorResponse("the connection_url parameter must be supplied"), nil
|
||||
} else {
|
||||
connURL = connValue
|
||||
}
|
||||
}
|
||||
|
||||
maxOpenConns := data.Get("max_open_connections").(int)
|
||||
if maxOpenConns == 0 {
|
||||
maxOpenConns = 2
|
||||
}
|
||||
|
||||
maxIdleConns := data.Get("max_idle_connections").(int)
|
||||
if maxIdleConns == 0 {
|
||||
maxIdleConns = maxOpenConns
|
||||
}
|
||||
if maxIdleConns > maxOpenConns {
|
||||
maxIdleConns = maxOpenConns
|
||||
}
|
||||
|
||||
// Don't check the connection_url if verification is disabled
|
||||
verifyConnection := data.Get("verify_connection").(bool)
|
||||
if verifyConnection {
|
||||
// Verify the string
|
||||
db, err := sql.Open("mysql", connURL)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf(
|
||||
"error validating connection info: %s", err)), nil
|
||||
}
|
||||
defer db.Close()
|
||||
if err := db.Ping(); err != nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf(
|
||||
"error validating connection info: %s", err)), nil
|
||||
}
|
||||
}
|
||||
|
||||
// Store it
|
||||
entry, err := logical.StorageEntryJSON("config/connection", connectionConfig{
|
||||
ConnectionURL: connURL,
|
||||
MaxOpenConnections: maxOpenConns,
|
||||
MaxIdleConnections: maxIdleConns,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := req.Storage.Put(ctx, entry); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Reset the DB connection
|
||||
b.ResetDB(ctx)
|
||||
|
||||
resp := &logical.Response{}
|
||||
resp.AddWarning("Read access to this endpoint should be controlled via ACLs as it will return the connection URL as it is, including passwords, if any.")
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
type connectionConfig struct {
|
||||
ConnectionURL string `json:"connection_url" structs:"connection_url" mapstructure:"connection_url"`
|
||||
// Deprecate "value" in coming releases
|
||||
ConnectionString string `json:"value" structs:"value" mapstructure:"value"`
|
||||
MaxOpenConnections int `json:"max_open_connections" structs:"max_open_connections" mapstructure:"max_open_connections"`
|
||||
MaxIdleConnections int `json:"max_idle_connections" structs:"max_idle_connections" mapstructure:"max_idle_connections"`
|
||||
}
|
||||
|
||||
const pathConfigConnectionHelpSyn = `
|
||||
Configure the connection string to talk to MySQL.
|
||||
`
|
||||
|
||||
const pathConfigConnectionHelpDesc = `
|
||||
This path configures the connection string used to connect to MySQL. The value
|
||||
of the string is a Data Source Name (DSN). An example is using
|
||||
"username:password@protocol(address)/dbname?param=value"
|
||||
|
||||
For example, RDS may look like:
|
||||
"id:password@tcp(your-amazonaws-uri.com:3306)/dbname"
|
||||
|
||||
When configuring the connection string, the backend will verify its validity.
|
||||
If the database is not available when setting the connection URL, set the
|
||||
"verify_connection" option to false.
|
||||
`
|
|
@ -1,101 +0,0 @@
|
|||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/sdk/framework"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
)
|
||||
|
||||
func pathConfigLease(b *backend) *framework.Path {
|
||||
return &framework.Path{
|
||||
Pattern: "config/lease",
|
||||
Fields: map[string]*framework.FieldSchema{
|
||||
"lease": {
|
||||
Type: framework.TypeString,
|
||||
Description: "Default lease for roles.",
|
||||
},
|
||||
|
||||
"lease_max": {
|
||||
Type: framework.TypeString,
|
||||
Description: "Maximum time a credential is valid for.",
|
||||
},
|
||||
},
|
||||
|
||||
Callbacks: map[logical.Operation]framework.OperationFunc{
|
||||
logical.ReadOperation: b.pathLeaseRead,
|
||||
logical.UpdateOperation: b.pathLeaseWrite,
|
||||
},
|
||||
|
||||
HelpSynopsis: pathConfigLeaseHelpSyn,
|
||||
HelpDescription: pathConfigLeaseHelpDesc,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *backend) pathLeaseWrite(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
leaseRaw := d.Get("lease").(string)
|
||||
leaseMaxRaw := d.Get("lease_max").(string)
|
||||
|
||||
lease, err := time.ParseDuration(leaseRaw)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf(
|
||||
"Invalid lease: %s", err)), nil
|
||||
}
|
||||
leaseMax, err := time.ParseDuration(leaseMaxRaw)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf(
|
||||
"Invalid lease: %s", err)), nil
|
||||
}
|
||||
|
||||
// Store it
|
||||
entry, err := logical.StorageEntryJSON("config/lease", &configLease{
|
||||
Lease: lease,
|
||||
LeaseMax: leaseMax,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := req.Storage.Put(ctx, entry); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (b *backend) pathLeaseRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
lease, err := b.Lease(ctx, req.Storage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if lease == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return &logical.Response{
|
||||
Data: map[string]interface{}{
|
||||
"lease": lease.Lease.String(),
|
||||
"lease_max": lease.LeaseMax.String(),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
type configLease struct {
|
||||
Lease time.Duration
|
||||
LeaseMax time.Duration
|
||||
}
|
||||
|
||||
const pathConfigLeaseHelpSyn = `
|
||||
Configure the default lease information for generated credentials.
|
||||
`
|
||||
|
||||
const pathConfigLeaseHelpDesc = `
|
||||
This configures the default lease information used for credentials
|
||||
generated by this backend. The lease specifies the duration that a
|
||||
credential will be valid for, as well as the maximum session for
|
||||
a set of credentials.
|
||||
|
||||
The format for the lease is "1h" or integer and then unit. The longest
|
||||
unit is hour.
|
||||
`
|
|
@ -1,143 +0,0 @@
|
|||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/go-secure-stdlib/strutil"
|
||||
uuid "github.com/hashicorp/go-uuid"
|
||||
"github.com/hashicorp/vault/sdk/framework"
|
||||
"github.com/hashicorp/vault/sdk/helper/dbtxn"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
)
|
||||
|
||||
func pathRoleCreate(b *backend) *framework.Path {
|
||||
return &framework.Path{
|
||||
Pattern: "creds/" + framework.GenericNameRegex("name"),
|
||||
Fields: map[string]*framework.FieldSchema{
|
||||
"name": {
|
||||
Type: framework.TypeString,
|
||||
Description: "Name of the role.",
|
||||
},
|
||||
},
|
||||
|
||||
Callbacks: map[logical.Operation]framework.OperationFunc{
|
||||
logical.ReadOperation: b.pathRoleCreateRead,
|
||||
},
|
||||
|
||||
HelpSynopsis: pathRoleCreateReadHelpSyn,
|
||||
HelpDescription: pathRoleCreateReadHelpDesc,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *backend) pathRoleCreateRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
name := data.Get("name").(string)
|
||||
|
||||
// Get the role
|
||||
role, err := b.Role(ctx, req.Storage, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if role == nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf("unknown role: %s", name)), nil
|
||||
}
|
||||
|
||||
// Determine if we have a lease
|
||||
lease, err := b.Lease(ctx, req.Storage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if lease == nil {
|
||||
lease = &configLease{}
|
||||
}
|
||||
|
||||
// Generate our username and password. The username will be a
|
||||
// concatenation of:
|
||||
//
|
||||
// - the role name, truncated to role.rolenameLength (default 4)
|
||||
// - the token display name, truncated to role.displaynameLength (default 4)
|
||||
// - a UUID
|
||||
//
|
||||
// the entire concatenated string is then truncated to role.usernameLength,
|
||||
// which by default is 16 due to limitations in older but still-prevalent
|
||||
// versions of MySQL.
|
||||
roleName := name
|
||||
if len(roleName) > role.RolenameLength {
|
||||
roleName = roleName[:role.RolenameLength]
|
||||
}
|
||||
displayName := req.DisplayName
|
||||
if len(displayName) > role.DisplaynameLength {
|
||||
displayName = displayName[:role.DisplaynameLength]
|
||||
}
|
||||
userUUID, err := uuid.GenerateUUID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
username := fmt.Sprintf("%s-%s-%s", roleName, displayName, userUUID)
|
||||
if len(username) > role.UsernameLength {
|
||||
username = username[:role.UsernameLength]
|
||||
}
|
||||
password, err := uuid.GenerateUUID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get our handle
|
||||
db, err := b.DB(ctx, req.Storage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Start a transaction
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
// Execute each query
|
||||
for _, query := range strutil.ParseArbitraryStringSlice(role.SQL, ";") {
|
||||
query = strings.TrimSpace(query)
|
||||
if len(query) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
m := map[string]string{
|
||||
"name": username,
|
||||
"password": password,
|
||||
}
|
||||
if err := dbtxn.ExecuteTxQueryDirect(ctx, tx, m, query); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Commit the transaction
|
||||
if err := tx.Commit(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Return the secret
|
||||
resp := b.Secret(SecretCredsType).Response(map[string]interface{}{
|
||||
"username": username,
|
||||
"password": password,
|
||||
}, map[string]interface{}{
|
||||
"username": username,
|
||||
"role": name,
|
||||
})
|
||||
|
||||
resp.Secret.TTL = lease.Lease
|
||||
resp.Secret.MaxTTL = lease.LeaseMax
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
const pathRoleCreateReadHelpSyn = `
|
||||
Request database credentials for a certain role.
|
||||
`
|
||||
|
||||
const pathRoleCreateReadHelpDesc = `
|
||||
This path reads database credentials for a certain role. The
|
||||
database credentials will be generated on demand and will be automatically
|
||||
revoked when the lease is up.
|
||||
`
|
|
@ -1,230 +0,0 @@
|
|||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
"github.com/hashicorp/go-secure-stdlib/strutil"
|
||||
"github.com/hashicorp/vault/sdk/framework"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
)
|
||||
|
||||
func pathListRoles(b *backend) *framework.Path {
|
||||
return &framework.Path{
|
||||
Pattern: "roles/?$",
|
||||
|
||||
Callbacks: map[logical.Operation]framework.OperationFunc{
|
||||
logical.ListOperation: b.pathRoleList,
|
||||
},
|
||||
|
||||
HelpSynopsis: pathRoleHelpSyn,
|
||||
HelpDescription: pathRoleHelpDesc,
|
||||
}
|
||||
}
|
||||
|
||||
func pathRoles(b *backend) *framework.Path {
|
||||
return &framework.Path{
|
||||
Pattern: "roles/" + framework.GenericNameRegex("name"),
|
||||
Fields: map[string]*framework.FieldSchema{
|
||||
"name": {
|
||||
Type: framework.TypeString,
|
||||
Description: "Name of the role.",
|
||||
},
|
||||
|
||||
"sql": {
|
||||
Type: framework.TypeString,
|
||||
Description: "SQL string to create a user. See help for more info.",
|
||||
},
|
||||
|
||||
"revocation_sql": {
|
||||
Type: framework.TypeString,
|
||||
Description: "SQL string to revoke a user. See help for more info.",
|
||||
},
|
||||
|
||||
"username_length": {
|
||||
Type: framework.TypeInt,
|
||||
Description: "number of characters to truncate generated mysql usernames to (default 16)",
|
||||
Default: 16,
|
||||
},
|
||||
|
||||
"rolename_length": {
|
||||
Type: framework.TypeInt,
|
||||
Description: "number of characters to truncate the rolename portion of generated mysql usernames to (default 4)",
|
||||
Default: 4,
|
||||
},
|
||||
|
||||
"displayname_length": {
|
||||
Type: framework.TypeInt,
|
||||
Description: "number of characters to truncate the displayname portion of generated mysql usernames to (default 4)",
|
||||
Default: 4,
|
||||
},
|
||||
},
|
||||
|
||||
Callbacks: map[logical.Operation]framework.OperationFunc{
|
||||
logical.ReadOperation: b.pathRoleRead,
|
||||
logical.UpdateOperation: b.pathRoleCreate,
|
||||
logical.DeleteOperation: b.pathRoleDelete,
|
||||
},
|
||||
|
||||
HelpSynopsis: pathRoleHelpSyn,
|
||||
HelpDescription: pathRoleHelpDesc,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *backend) Role(ctx context.Context, s logical.Storage, n string) (*roleEntry, error) {
|
||||
entry, err := s.Get(ctx, "role/"+n)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if entry == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Set defaults to handle upgrade cases
|
||||
result := roleEntry{
|
||||
UsernameLength: 16,
|
||||
RolenameLength: 4,
|
||||
DisplaynameLength: 4,
|
||||
}
|
||||
|
||||
if err := entry.DecodeJSON(&result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
func (b *backend) pathRoleDelete(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
err := req.Storage.Delete(ctx, "role/"+data.Get("name").(string))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (b *backend) pathRoleRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
role, err := b.Role(ctx, req.Storage, data.Get("name").(string))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if role == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return &logical.Response{
|
||||
Data: map[string]interface{}{
|
||||
"sql": role.SQL,
|
||||
"revocation_sql": role.RevocationSQL,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (b *backend) pathRoleList(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
entries, err := req.Storage.List(ctx, "role/")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return logical.ListResponse(entries), nil
|
||||
}
|
||||
|
||||
func (b *backend) pathRoleCreate(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
name := data.Get("name").(string)
|
||||
|
||||
// Get our connection
|
||||
db, err := b.DB(ctx, req.Storage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Test the query by trying to prepare it
|
||||
sql := data.Get("sql").(string)
|
||||
for _, query := range strutil.ParseArbitraryStringSlice(sql, ";") {
|
||||
query = strings.TrimSpace(query)
|
||||
if len(query) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
stmt, err := db.Prepare(Query(query, map[string]string{
|
||||
"name": "foo",
|
||||
"password": "bar",
|
||||
}))
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf(
|
||||
"Error testing query: %s", err)), nil
|
||||
}
|
||||
stmt.Close()
|
||||
}
|
||||
|
||||
// Store it
|
||||
entry, err := logical.StorageEntryJSON("role/"+name, &roleEntry{
|
||||
SQL: sql,
|
||||
RevocationSQL: data.Get("revocation_sql").(string),
|
||||
UsernameLength: data.Get("username_length").(int),
|
||||
DisplaynameLength: data.Get("displayname_length").(int),
|
||||
RolenameLength: data.Get("rolename_length").(int),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := req.Storage.Put(ctx, entry); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
type roleEntry struct {
|
||||
SQL string `json:"sql" mapstructure:"sql" structs:"sql"`
|
||||
RevocationSQL string `json:"revocation_sql" mapstructure:"revocation_sql" structs:"revocation_sql"`
|
||||
UsernameLength int `json:"username_length" mapstructure:"username_length" structs:"username_length"`
|
||||
DisplaynameLength int `json:"displayname_length" mapstructure:"displayname_length" structs:"displayname_length"`
|
||||
RolenameLength int `json:"rolename_length" mapstructure:"rolename_length" structs:"rolename_length"`
|
||||
}
|
||||
|
||||
const pathRoleHelpSyn = `
|
||||
Manage the roles that can be created with this backend.
|
||||
`
|
||||
|
||||
const pathRoleHelpDesc = `
|
||||
This path lets you manage the roles that can be created with this backend.
|
||||
|
||||
The "sql" parameter customizes the SQL string used to create the role.
|
||||
This can be a sequence of SQL queries, each semi-colon separated. Some
|
||||
substitution will be done to the SQL string for certain keys.
|
||||
The names of the variables must be surrounded by "{{" and "}}" to be replaced.
|
||||
|
||||
* "name" - The random username generated for the DB user.
|
||||
|
||||
* "password" - The random password generated for the DB user.
|
||||
|
||||
Example of a decent SQL query to use:
|
||||
|
||||
CREATE USER '{{name}}'@'%' IDENTIFIED BY '{{password}}';
|
||||
GRANT ALL ON db1.* TO '{{name}}'@'%';
|
||||
|
||||
Note the above user would be able to access anything in db1. Please see the MySQL
|
||||
manual on the GRANT command to learn how to do more fine grained access.
|
||||
|
||||
The "rolename_length" parameter determines how many characters of the role name
|
||||
will be used in creating the generated mysql username; the default is 4.
|
||||
|
||||
The "displayname_length" parameter determines how many characters of the token
|
||||
display name will be used in creating the generated mysql username; the default
|
||||
is 4.
|
||||
|
||||
The "username_length" parameter determines how many total characters the
|
||||
generated username (including the role name, token display name and the uuid
|
||||
portion) will be truncated to. Versions of MySQL prior to 5.7.8 are limited to
|
||||
16 characters total (see
|
||||
http://dev.mysql.com/doc/refman/5.7/en/user-names.html) so that is the default;
|
||||
for versions >=5.7.8 it is safe to increase this to 32.
|
||||
|
||||
For best readability in MySQL process lists, we recommend using MySQL 5.7.8 or
|
||||
later, setting "username_length" to 32 and setting both "rolename_length" and
|
||||
"displayname_length" to 8. However due the the prevalence of older versions of
|
||||
MySQL in general deployment, the defaults are currently tuned for a
|
||||
username_length of 16.
|
||||
`
|
|
@ -1,136 +0,0 @@
|
|||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/go-secure-stdlib/strutil"
|
||||
"github.com/hashicorp/vault/sdk/framework"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
)
|
||||
|
||||
const SecretCredsType = "creds"
|
||||
|
||||
// defaultRevocationSQL is a default SQL statement for revoking a user. Revoking
|
||||
// permissions for the user is done before the drop, because MySQL explicitly
|
||||
// documents that open user connections will not be closed. By revoking all
|
||||
// grants, at least we ensure that the open connection is useless. Dropping the
|
||||
// user will only affect the next connection.
|
||||
const defaultRevocationSQL = `
|
||||
REVOKE ALL PRIVILEGES, GRANT OPTION FROM '{{name}}'@'%';
|
||||
DROP USER '{{name}}'@'%'
|
||||
`
|
||||
|
||||
func secretCreds(b *backend) *framework.Secret {
|
||||
return &framework.Secret{
|
||||
Type: SecretCredsType,
|
||||
Fields: map[string]*framework.FieldSchema{
|
||||
"username": {
|
||||
Type: framework.TypeString,
|
||||
Description: "Username",
|
||||
},
|
||||
|
||||
"password": {
|
||||
Type: framework.TypeString,
|
||||
Description: "Password",
|
||||
},
|
||||
},
|
||||
|
||||
Renew: b.secretCredsRenew,
|
||||
Revoke: b.secretCredsRevoke,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *backend) secretCredsRenew(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
// Get the lease information
|
||||
lease, err := b.Lease(ctx, req.Storage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if lease == nil {
|
||||
lease = &configLease{}
|
||||
}
|
||||
|
||||
resp := &logical.Response{Secret: req.Secret}
|
||||
resp.Secret.TTL = lease.Lease
|
||||
resp.Secret.MaxTTL = lease.LeaseMax
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (b *backend) secretCredsRevoke(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
var resp *logical.Response
|
||||
|
||||
// Get the username from the internal data
|
||||
usernameRaw, ok := req.Secret.InternalData["username"]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("secret is missing username internal data")
|
||||
}
|
||||
username, ok := usernameRaw.(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("usernameRaw is not a string")
|
||||
}
|
||||
|
||||
// Get our connection
|
||||
db, err := b.DB(ctx, req.Storage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
roleName := ""
|
||||
roleNameRaw, ok := req.Secret.InternalData["role"]
|
||||
if ok {
|
||||
roleName = roleNameRaw.(string)
|
||||
}
|
||||
|
||||
var role *roleEntry
|
||||
if roleName != "" {
|
||||
role, err = b.Role(ctx, req.Storage, roleName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Use a default SQL statement for revocation if one cannot be fetched from the role
|
||||
revocationSQL := defaultRevocationSQL
|
||||
|
||||
if role != nil && role.RevocationSQL != "" {
|
||||
revocationSQL = role.RevocationSQL
|
||||
} else {
|
||||
if resp == nil {
|
||||
resp = &logical.Response{}
|
||||
}
|
||||
resp.AddWarning(fmt.Sprintf("Role %q cannot be found. Using default SQL for revoking user.", roleName))
|
||||
}
|
||||
|
||||
// Start a transaction
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
for _, query := range strutil.ParseArbitraryStringSlice(revocationSQL, ";") {
|
||||
query = strings.TrimSpace(query)
|
||||
if len(query) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// This is not a prepared statement because not all commands are supported
|
||||
// 1295: This command is not supported in the prepared statement protocol yet
|
||||
// Reference https://mariadb.com/kb/en/mariadb/prepare-statement/
|
||||
query = strings.ReplaceAll(query, "{{name}}", username)
|
||||
_, err = tx.Exec(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// Commit the transaction
|
||||
if err := tx.Commit(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
|
@ -1,15 +0,0 @@
|
|||
package mysql
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Query templates a query for us.
|
||||
func Query(tpl string, data map[string]string) string {
|
||||
for k, v := range data {
|
||||
tpl = strings.ReplaceAll(tpl, fmt.Sprintf("{{%s}}", k), v)
|
||||
}
|
||||
|
||||
return tpl
|
||||
}
|
|
@ -1,171 +0,0 @@
|
|||
package postgresql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
log "github.com/hashicorp/go-hclog"
|
||||
"github.com/hashicorp/vault/sdk/framework"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
)
|
||||
|
||||
func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
|
||||
b := Backend(conf)
|
||||
if err := b.Setup(ctx, conf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return b, nil
|
||||
}
|
||||
|
||||
func Backend(conf *logical.BackendConfig) *backend {
|
||||
var b backend
|
||||
b.Backend = &framework.Backend{
|
||||
Help: strings.TrimSpace(backendHelp),
|
||||
|
||||
PathsSpecial: &logical.Paths{
|
||||
SealWrapStorage: []string{
|
||||
"config/connection",
|
||||
},
|
||||
},
|
||||
|
||||
Paths: []*framework.Path{
|
||||
pathConfigConnection(&b),
|
||||
pathConfigLease(&b),
|
||||
pathListRoles(&b),
|
||||
pathRoles(&b),
|
||||
pathRoleCreate(&b),
|
||||
},
|
||||
|
||||
Secrets: []*framework.Secret{
|
||||
secretCreds(&b),
|
||||
},
|
||||
|
||||
Clean: b.ResetDB,
|
||||
Invalidate: b.invalidate,
|
||||
BackendType: logical.TypeLogical,
|
||||
}
|
||||
|
||||
b.logger = conf.Logger
|
||||
return &b
|
||||
}
|
||||
|
||||
type backend struct {
|
||||
*framework.Backend
|
||||
|
||||
db *sql.DB
|
||||
lock sync.Mutex
|
||||
|
||||
logger log.Logger
|
||||
}
|
||||
|
||||
// DB returns the database connection.
|
||||
func (b *backend) DB(ctx context.Context, s logical.Storage) (*sql.DB, error) {
|
||||
b.logger.Debug("postgres/db: enter")
|
||||
defer b.logger.Debug("postgres/db: exit")
|
||||
|
||||
b.lock.Lock()
|
||||
defer b.lock.Unlock()
|
||||
|
||||
// If we already have a DB, we got it!
|
||||
if b.db != nil {
|
||||
if err := b.db.Ping(); err == nil {
|
||||
return b.db, nil
|
||||
}
|
||||
// If the ping was unsuccessful, close it and ignore errors as we'll be
|
||||
// reestablishing anyways
|
||||
b.db.Close()
|
||||
}
|
||||
|
||||
// Otherwise, attempt to make connection
|
||||
entry, err := s.Get(ctx, "config/connection")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if entry == nil {
|
||||
return nil,
|
||||
fmt.Errorf("configure the DB connection with config/connection first")
|
||||
}
|
||||
|
||||
var connConfig connectionConfig
|
||||
if err := entry.DecodeJSON(&connConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
conn := connConfig.ConnectionURL
|
||||
if len(conn) == 0 {
|
||||
conn = connConfig.ConnectionString
|
||||
}
|
||||
|
||||
// Ensure timezone is set to UTC for all the connections
|
||||
if strings.HasPrefix(conn, "postgres://") || strings.HasPrefix(conn, "postgresql://") {
|
||||
if strings.Contains(conn, "?") {
|
||||
conn += "&timezone=utc"
|
||||
} else {
|
||||
conn += "?timezone=utc"
|
||||
}
|
||||
} else {
|
||||
conn += "&timezone=utc"
|
||||
}
|
||||
|
||||
b.db, err = sql.Open("pgx", conn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Set some connection pool settings. We don't need much of this,
|
||||
// since the request rate shouldn't be high.
|
||||
b.db.SetMaxOpenConns(connConfig.MaxOpenConnections)
|
||||
b.db.SetMaxIdleConns(connConfig.MaxIdleConnections)
|
||||
|
||||
return b.db, nil
|
||||
}
|
||||
|
||||
// ResetDB forces a connection next time DB() is called.
|
||||
func (b *backend) ResetDB(_ context.Context) {
|
||||
b.logger.Debug("postgres/db: enter")
|
||||
defer b.logger.Debug("postgres/db: exit")
|
||||
|
||||
b.lock.Lock()
|
||||
defer b.lock.Unlock()
|
||||
|
||||
if b.db != nil {
|
||||
b.db.Close()
|
||||
}
|
||||
|
||||
b.db = nil
|
||||
}
|
||||
|
||||
func (b *backend) invalidate(ctx context.Context, key string) {
|
||||
switch key {
|
||||
case "config/connection":
|
||||
b.ResetDB(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
// Lease returns the lease information
|
||||
func (b *backend) Lease(ctx context.Context, s logical.Storage) (*configLease, error) {
|
||||
entry, err := s.Get(ctx, "config/lease")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if entry == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var result configLease
|
||||
if err := entry.DecodeJSON(&result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
const backendHelp = `
|
||||
The PostgreSQL backend dynamically generates database users.
|
||||
|
||||
After mounting this backend, configure it using the endpoints within
|
||||
the "config/" path.
|
||||
`
|
|
@ -1,532 +0,0 @@
|
|||
package postgresql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"path"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
logicaltest "github.com/hashicorp/vault/helper/testhelpers/logical"
|
||||
postgreshelper "github.com/hashicorp/vault/helper/testhelpers/postgresql"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
)
|
||||
|
||||
func TestBackend_config_connection(t *testing.T) {
|
||||
var resp *logical.Response
|
||||
var err error
|
||||
config := logical.TestBackendConfig()
|
||||
config.StorageView = &logical.InmemStorage{}
|
||||
b, err := Factory(context.Background(), config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
configData := map[string]interface{}{
|
||||
"connection_url": "sample_connection_url",
|
||||
"max_open_connections": 9,
|
||||
"max_idle_connections": 7,
|
||||
"verify_connection": false,
|
||||
}
|
||||
|
||||
configReq := &logical.Request{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "config/connection",
|
||||
Storage: config.StorageView,
|
||||
Data: configData,
|
||||
}
|
||||
resp, err = b.HandleRequest(context.Background(), configReq)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("err:%s resp:%#v\n", err, resp)
|
||||
}
|
||||
|
||||
configReq.Operation = logical.ReadOperation
|
||||
resp, err = b.HandleRequest(context.Background(), configReq)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("err:%s resp:%#v\n", err, resp)
|
||||
}
|
||||
|
||||
delete(configData, "verify_connection")
|
||||
delete(configData, "connection_url")
|
||||
if !reflect.DeepEqual(configData, resp.Data) {
|
||||
t.Fatalf("bad: expected:%#v\nactual:%#v\n", configData, resp.Data)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackend_basic(t *testing.T) {
|
||||
config := logical.TestBackendConfig()
|
||||
config.StorageView = &logical.InmemStorage{}
|
||||
b, err := Factory(context.Background(), config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cleanup, connURL := postgreshelper.PrepareTestContainer(t, "")
|
||||
defer cleanup()
|
||||
|
||||
connData := map[string]interface{}{
|
||||
"connection_url": connURL,
|
||||
}
|
||||
logicaltest.Test(t, logicaltest.TestCase{
|
||||
LogicalBackend: b,
|
||||
Steps: []logicaltest.TestStep{
|
||||
testAccStepConfig(t, connData, false),
|
||||
testAccStepCreateRole(t, "web", testRole, false),
|
||||
testAccStepReadCreds(t, b, config.StorageView, "web", connURL),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func TestBackend_roleCrud(t *testing.T) {
|
||||
config := logical.TestBackendConfig()
|
||||
config.StorageView = &logical.InmemStorage{}
|
||||
b, err := Factory(context.Background(), config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cleanup, connURL := postgreshelper.PrepareTestContainer(t, "")
|
||||
defer cleanup()
|
||||
|
||||
connData := map[string]interface{}{
|
||||
"connection_url": connURL,
|
||||
}
|
||||
logicaltest.Test(t, logicaltest.TestCase{
|
||||
LogicalBackend: b,
|
||||
Steps: []logicaltest.TestStep{
|
||||
testAccStepConfig(t, connData, false),
|
||||
testAccStepCreateRole(t, "web", testRole, false),
|
||||
testAccStepReadRole(t, "web", testRole),
|
||||
testAccStepDeleteRole(t, "web"),
|
||||
testAccStepReadRole(t, "web", ""),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func TestBackend_BlockStatements(t *testing.T) {
|
||||
config := logical.TestBackendConfig()
|
||||
config.StorageView = &logical.InmemStorage{}
|
||||
b, err := Factory(context.Background(), config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cleanup, connURL := postgreshelper.PrepareTestContainer(t, "")
|
||||
defer cleanup()
|
||||
|
||||
connData := map[string]interface{}{
|
||||
"connection_url": connURL,
|
||||
}
|
||||
jsonBlockStatement, err := json.Marshal(testBlockStatementRoleSlice)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
logicaltest.Test(t, logicaltest.TestCase{
|
||||
LogicalBackend: b,
|
||||
Steps: []logicaltest.TestStep{
|
||||
testAccStepConfig(t, connData, false),
|
||||
// This will also validate the query
|
||||
testAccStepCreateRole(t, "web-block", testBlockStatementRole, true),
|
||||
testAccStepCreateRole(t, "web-block", string(jsonBlockStatement), false),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func TestBackend_roleReadOnly(t *testing.T) {
|
||||
config := logical.TestBackendConfig()
|
||||
config.StorageView = &logical.InmemStorage{}
|
||||
b, err := Factory(context.Background(), config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cleanup, connURL := postgreshelper.PrepareTestContainer(t, "")
|
||||
defer cleanup()
|
||||
|
||||
connData := map[string]interface{}{
|
||||
"connection_url": connURL,
|
||||
}
|
||||
logicaltest.Test(t, logicaltest.TestCase{
|
||||
LogicalBackend: b,
|
||||
Steps: []logicaltest.TestStep{
|
||||
testAccStepConfig(t, connData, false),
|
||||
testAccStepCreateRole(t, "web", testRole, false),
|
||||
testAccStepCreateRole(t, "web-readonly", testReadOnlyRole, false),
|
||||
testAccStepReadRole(t, "web-readonly", testReadOnlyRole),
|
||||
testAccStepCreateTable(t, b, config.StorageView, "web", connURL),
|
||||
testAccStepReadCreds(t, b, config.StorageView, "web-readonly", connURL),
|
||||
testAccStepDropTable(t, b, config.StorageView, "web", connURL),
|
||||
testAccStepDeleteRole(t, "web-readonly"),
|
||||
testAccStepDeleteRole(t, "web"),
|
||||
testAccStepReadRole(t, "web-readonly", ""),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func TestBackend_roleReadOnly_revocationSQL(t *testing.T) {
|
||||
config := logical.TestBackendConfig()
|
||||
config.StorageView = &logical.InmemStorage{}
|
||||
b, err := Factory(context.Background(), config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cleanup, connURL := postgreshelper.PrepareTestContainer(t, "")
|
||||
defer cleanup()
|
||||
|
||||
connData := map[string]interface{}{
|
||||
"connection_url": connURL,
|
||||
}
|
||||
logicaltest.Test(t, logicaltest.TestCase{
|
||||
LogicalBackend: b,
|
||||
Steps: []logicaltest.TestStep{
|
||||
testAccStepConfig(t, connData, false),
|
||||
testAccStepCreateRoleWithRevocationSQL(t, "web", testRole, defaultRevocationSQL, false),
|
||||
testAccStepCreateRoleWithRevocationSQL(t, "web-readonly", testReadOnlyRole, defaultRevocationSQL, false),
|
||||
testAccStepReadRole(t, "web-readonly", testReadOnlyRole),
|
||||
testAccStepCreateTable(t, b, config.StorageView, "web", connURL),
|
||||
testAccStepReadCreds(t, b, config.StorageView, "web-readonly", connURL),
|
||||
testAccStepDropTable(t, b, config.StorageView, "web", connURL),
|
||||
testAccStepDeleteRole(t, "web-readonly"),
|
||||
testAccStepDeleteRole(t, "web"),
|
||||
testAccStepReadRole(t, "web-readonly", ""),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func testAccStepConfig(t *testing.T, d map[string]interface{}, expectError bool) logicaltest.TestStep {
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "config/connection",
|
||||
Data: d,
|
||||
ErrorOk: true,
|
||||
Check: func(resp *logical.Response) error {
|
||||
if expectError {
|
||||
if resp.Data == nil {
|
||||
return fmt.Errorf("data is nil")
|
||||
}
|
||||
var e struct {
|
||||
Error string `mapstructure:"error"`
|
||||
}
|
||||
if err := mapstructure.Decode(resp.Data, &e); err != nil {
|
||||
return err
|
||||
}
|
||||
if len(e.Error) == 0 {
|
||||
return fmt.Errorf("expected error, but write succeeded")
|
||||
}
|
||||
return nil
|
||||
} else if resp != nil && resp.IsError() {
|
||||
return fmt.Errorf("got an error response: %v", resp.Error())
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func testAccStepCreateRole(t *testing.T, name string, sql string, expectFail bool) logicaltest.TestStep {
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: path.Join("roles", name),
|
||||
Data: map[string]interface{}{
|
||||
"sql": sql,
|
||||
},
|
||||
ErrorOk: expectFail,
|
||||
}
|
||||
}
|
||||
|
||||
func testAccStepCreateRoleWithRevocationSQL(t *testing.T, name, sql, revocationSQL string, expectFail bool) logicaltest.TestStep {
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: path.Join("roles", name),
|
||||
Data: map[string]interface{}{
|
||||
"sql": sql,
|
||||
"revocation_sql": revocationSQL,
|
||||
},
|
||||
ErrorOk: expectFail,
|
||||
}
|
||||
}
|
||||
|
||||
func testAccStepDeleteRole(t *testing.T, name string) logicaltest.TestStep {
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.DeleteOperation,
|
||||
Path: path.Join("roles", name),
|
||||
}
|
||||
}
|
||||
|
||||
func testAccStepReadCreds(t *testing.T, b logical.Backend, s logical.Storage, name string, connURL string) logicaltest.TestStep {
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.ReadOperation,
|
||||
Path: path.Join("creds", name),
|
||||
Check: func(resp *logical.Response) error {
|
||||
var d struct {
|
||||
Username string `mapstructure:"username"`
|
||||
Password string `mapstructure:"password"`
|
||||
}
|
||||
if err := mapstructure.Decode(resp.Data, &d); err != nil {
|
||||
return err
|
||||
}
|
||||
log.Printf("[TRACE] Generated credentials: %v", d)
|
||||
|
||||
db, err := sql.Open("pgx", connURL+"&timezone=utc")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
returnedRows := func() int {
|
||||
stmt, err := db.Prepare("SELECT DISTINCT schemaname FROM pg_tables WHERE has_table_privilege($1, 'information_schema.role_column_grants', 'select');")
|
||||
if err != nil {
|
||||
return -1
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
rows, err := stmt.Query(d.Username)
|
||||
if err != nil {
|
||||
return -1
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
i := 0
|
||||
for rows.Next() {
|
||||
i++
|
||||
}
|
||||
return i
|
||||
}
|
||||
|
||||
// minNumPermissions is the minimum number of permissions that will always be present.
|
||||
const minNumPermissions = 2
|
||||
|
||||
userRows := returnedRows()
|
||||
if userRows < minNumPermissions {
|
||||
t.Fatalf("did not get expected number of rows, got %d", userRows)
|
||||
}
|
||||
|
||||
resp, err = b.HandleRequest(context.Background(), &logical.Request{
|
||||
Operation: logical.RevokeOperation,
|
||||
Storage: s,
|
||||
Secret: &logical.Secret{
|
||||
InternalData: map[string]interface{}{
|
||||
"secret_type": "creds",
|
||||
"username": d.Username,
|
||||
"role": name,
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if resp != nil {
|
||||
if resp.IsError() {
|
||||
return fmt.Errorf("error on resp: %#v", *resp)
|
||||
}
|
||||
}
|
||||
|
||||
userRows = returnedRows()
|
||||
// User shouldn't exist so returnedRows() should encounter an error and exit with -1
|
||||
if userRows != -1 {
|
||||
t.Fatalf("did not get expected number of rows, got %d", userRows)
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func testAccStepCreateTable(t *testing.T, b logical.Backend, s logical.Storage, name string, connURL string) logicaltest.TestStep {
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.ReadOperation,
|
||||
Path: path.Join("creds", name),
|
||||
Check: func(resp *logical.Response) error {
|
||||
var d struct {
|
||||
Username string `mapstructure:"username"`
|
||||
Password string `mapstructure:"password"`
|
||||
}
|
||||
if err := mapstructure.Decode(resp.Data, &d); err != nil {
|
||||
return err
|
||||
}
|
||||
log.Printf("[TRACE] Generated credentials: %v", d)
|
||||
|
||||
db, err := sql.Open("pgx", connURL+"&timezone=utc")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = db.Exec("CREATE TABLE test (id SERIAL PRIMARY KEY);")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
resp, err = b.HandleRequest(context.Background(), &logical.Request{
|
||||
Operation: logical.RevokeOperation,
|
||||
Storage: s,
|
||||
Secret: &logical.Secret{
|
||||
InternalData: map[string]interface{}{
|
||||
"secret_type": "creds",
|
||||
"username": d.Username,
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if resp != nil {
|
||||
if resp.IsError() {
|
||||
return fmt.Errorf("error on resp: %#v", *resp)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func testAccStepDropTable(t *testing.T, b logical.Backend, s logical.Storage, name string, connURL string) logicaltest.TestStep {
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.ReadOperation,
|
||||
Path: path.Join("creds", name),
|
||||
Check: func(resp *logical.Response) error {
|
||||
var d struct {
|
||||
Username string `mapstructure:"username"`
|
||||
Password string `mapstructure:"password"`
|
||||
}
|
||||
if err := mapstructure.Decode(resp.Data, &d); err != nil {
|
||||
return err
|
||||
}
|
||||
log.Printf("[TRACE] Generated credentials: %v", d)
|
||||
|
||||
db, err := sql.Open("pgx", connURL+"&timezone=utc")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = db.Exec("DROP TABLE test;")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
resp, err = b.HandleRequest(context.Background(), &logical.Request{
|
||||
Operation: logical.RevokeOperation,
|
||||
Storage: s,
|
||||
Secret: &logical.Secret{
|
||||
InternalData: map[string]interface{}{
|
||||
"secret_type": "creds",
|
||||
"username": d.Username,
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if resp != nil {
|
||||
if resp.IsError() {
|
||||
return fmt.Errorf("error on resp: %#v", *resp)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func testAccStepReadRole(t *testing.T, name string, sql string) logicaltest.TestStep {
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.ReadOperation,
|
||||
Path: "roles/" + name,
|
||||
Check: func(resp *logical.Response) error {
|
||||
if resp == nil {
|
||||
if sql == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("bad: %#v", resp)
|
||||
}
|
||||
|
||||
var d struct {
|
||||
SQL string `mapstructure:"sql"`
|
||||
}
|
||||
if err := mapstructure.Decode(resp.Data, &d); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if d.SQL != sql {
|
||||
return fmt.Errorf("bad: %#v", resp)
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
const testRole = `
|
||||
CREATE ROLE "{{name}}" WITH
|
||||
LOGIN
|
||||
PASSWORD '{{password}}'
|
||||
VALID UNTIL '{{expiration}}';
|
||||
GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{name}}";
|
||||
`
|
||||
|
||||
const testReadOnlyRole = `
|
||||
CREATE ROLE "{{name}}" WITH
|
||||
LOGIN
|
||||
PASSWORD '{{password}}'
|
||||
VALID UNTIL '{{expiration}}';
|
||||
GRANT SELECT ON ALL TABLES IN SCHEMA public TO "{{name}}";
|
||||
GRANT SELECT ON ALL SEQUENCES IN SCHEMA public TO "{{name}}";
|
||||
`
|
||||
|
||||
const testBlockStatementRole = `
|
||||
DO $$
|
||||
BEGIN
|
||||
IF NOT EXISTS (SELECT * FROM pg_catalog.pg_roles WHERE rolname='foo-role') THEN
|
||||
CREATE ROLE "foo-role";
|
||||
CREATE SCHEMA IF NOT EXISTS foo AUTHORIZATION "foo-role";
|
||||
ALTER ROLE "foo-role" SET search_path = foo;
|
||||
GRANT TEMPORARY ON DATABASE "postgres" TO "foo-role";
|
||||
GRANT ALL PRIVILEGES ON SCHEMA foo TO "foo-role";
|
||||
GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA foo TO "foo-role";
|
||||
GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA foo TO "foo-role";
|
||||
GRANT ALL PRIVILEGES ON ALL FUNCTIONS IN SCHEMA foo TO "foo-role";
|
||||
END IF;
|
||||
END
|
||||
$$
|
||||
|
||||
CREATE ROLE "{{name}}" WITH LOGIN PASSWORD '{{password}}' VALID UNTIL '{{expiration}}';
|
||||
GRANT "foo-role" TO "{{name}}";
|
||||
ALTER ROLE "{{name}}" SET search_path = foo;
|
||||
GRANT CONNECT ON DATABASE "postgres" TO "{{name}}";
|
||||
`
|
||||
|
||||
var testBlockStatementRoleSlice = []string{
|
||||
`
|
||||
DO $$
|
||||
BEGIN
|
||||
IF NOT EXISTS (SELECT * FROM pg_catalog.pg_roles WHERE rolname='foo-role') THEN
|
||||
CREATE ROLE "foo-role";
|
||||
CREATE SCHEMA IF NOT EXISTS foo AUTHORIZATION "foo-role";
|
||||
ALTER ROLE "foo-role" SET search_path = foo;
|
||||
GRANT TEMPORARY ON DATABASE "postgres" TO "foo-role";
|
||||
GRANT ALL PRIVILEGES ON SCHEMA foo TO "foo-role";
|
||||
GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA foo TO "foo-role";
|
||||
GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA foo TO "foo-role";
|
||||
GRANT ALL PRIVILEGES ON ALL FUNCTIONS IN SCHEMA foo TO "foo-role";
|
||||
END IF;
|
||||
END
|
||||
$$
|
||||
`,
|
||||
`CREATE ROLE "{{name}}" WITH LOGIN PASSWORD '{{password}}' VALID UNTIL '{{expiration}}';`,
|
||||
`GRANT "foo-role" TO "{{name}}";`,
|
||||
`ALTER ROLE "{{name}}" SET search_path = foo;`,
|
||||
`GRANT CONNECT ON DATABASE "postgres" TO "{{name}}";`,
|
||||
}
|
||||
|
||||
const defaultRevocationSQL = `
|
||||
REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM {{name}};
|
||||
REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM {{name}};
|
||||
REVOKE USAGE ON SCHEMA public FROM {{name}};
|
||||
|
||||
DROP ROLE IF EXISTS {{name}};
|
||||
`
|
|
@ -1,29 +0,0 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
hclog "github.com/hashicorp/go-hclog"
|
||||
"github.com/hashicorp/vault/api"
|
||||
"github.com/hashicorp/vault/builtin/logical/postgresql"
|
||||
"github.com/hashicorp/vault/sdk/plugin"
|
||||
)
|
||||
|
||||
func main() {
|
||||
apiClientMeta := &api.PluginAPIClientMeta{}
|
||||
flags := apiClientMeta.FlagSet()
|
||||
flags.Parse(os.Args[1:])
|
||||
|
||||
tlsConfig := apiClientMeta.GetTLSConfig()
|
||||
tlsProviderFunc := api.VaultPluginTLSProvider(tlsConfig)
|
||||
|
||||
if err := plugin.Serve(&plugin.ServeOpts{
|
||||
BackendFactoryFunc: postgresql.Factory,
|
||||
TLSProviderFunc: tlsProviderFunc,
|
||||
}); err != nil {
|
||||
logger := hclog.New(&hclog.LoggerOptions{})
|
||||
|
||||
logger.Error("plugin shutting down", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
|
@ -1,168 +0,0 @@
|
|||
package postgresql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/hashicorp/vault/sdk/framework"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
_ "github.com/jackc/pgx/v4/stdlib"
|
||||
)
|
||||
|
||||
func pathConfigConnection(b *backend) *framework.Path {
|
||||
return &framework.Path{
|
||||
Pattern: "config/connection",
|
||||
Fields: map[string]*framework.FieldSchema{
|
||||
"connection_url": {
|
||||
Type: framework.TypeString,
|
||||
Description: "DB connection string",
|
||||
},
|
||||
|
||||
"value": {
|
||||
Type: framework.TypeString,
|
||||
Description: `DB connection string. Use 'connection_url' instead.
|
||||
This will be deprecated.`,
|
||||
},
|
||||
|
||||
"verify_connection": {
|
||||
Type: framework.TypeBool,
|
||||
Default: true,
|
||||
Description: `If set, connection_url is verified by actually connecting to the database`,
|
||||
},
|
||||
|
||||
"max_open_connections": {
|
||||
Type: framework.TypeInt,
|
||||
Description: `Maximum number of open connections to the database;
|
||||
a zero uses the default value of two and a
|
||||
negative value means unlimited`,
|
||||
},
|
||||
|
||||
"max_idle_connections": {
|
||||
Type: framework.TypeInt,
|
||||
Description: `Maximum number of idle connections to the database;
|
||||
a zero uses the value of max_open_connections
|
||||
and a negative value disables idle connections.
|
||||
If larger than max_open_connections it will be
|
||||
reduced to the same size.`,
|
||||
},
|
||||
},
|
||||
|
||||
Callbacks: map[logical.Operation]framework.OperationFunc{
|
||||
logical.UpdateOperation: b.pathConnectionWrite,
|
||||
logical.ReadOperation: b.pathConnectionRead,
|
||||
},
|
||||
|
||||
HelpSynopsis: pathConfigConnectionHelpSyn,
|
||||
HelpDescription: pathConfigConnectionHelpDesc,
|
||||
}
|
||||
}
|
||||
|
||||
// pathConnectionRead reads out the connection configuration
|
||||
func (b *backend) pathConnectionRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
entry, err := req.Storage.Get(ctx, "config/connection")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read connection configuration")
|
||||
}
|
||||
if entry == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var config connectionConfig
|
||||
if err := entry.DecodeJSON(&config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &logical.Response{
|
||||
Data: map[string]interface{}{
|
||||
"max_open_connections": config.MaxOpenConnections,
|
||||
"max_idle_connections": config.MaxIdleConnections,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (b *backend) pathConnectionWrite(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
connValue := data.Get("value").(string)
|
||||
connURL := data.Get("connection_url").(string)
|
||||
if connURL == "" {
|
||||
if connValue == "" {
|
||||
return logical.ErrorResponse("connection_url parameter must be supplied"), nil
|
||||
} else {
|
||||
connURL = connValue
|
||||
}
|
||||
}
|
||||
|
||||
maxOpenConns := data.Get("max_open_connections").(int)
|
||||
if maxOpenConns == 0 {
|
||||
maxOpenConns = 2
|
||||
}
|
||||
|
||||
maxIdleConns := data.Get("max_idle_connections").(int)
|
||||
if maxIdleConns == 0 {
|
||||
maxIdleConns = maxOpenConns
|
||||
}
|
||||
if maxIdleConns > maxOpenConns {
|
||||
maxIdleConns = maxOpenConns
|
||||
}
|
||||
|
||||
// Don't check the connection_url if verification is disabled
|
||||
verifyConnection := data.Get("verify_connection").(bool)
|
||||
if verifyConnection {
|
||||
// Verify the string
|
||||
db, err := sql.Open("pgx", connURL)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf(
|
||||
"Error validating connection info: %s", err)), nil
|
||||
}
|
||||
defer db.Close()
|
||||
if err := db.Ping(); err != nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf(
|
||||
"Error validating connection info: %s", err)), nil
|
||||
}
|
||||
}
|
||||
|
||||
// Store it
|
||||
entry, err := logical.StorageEntryJSON("config/connection", connectionConfig{
|
||||
ConnectionString: connValue,
|
||||
ConnectionURL: connURL,
|
||||
MaxOpenConnections: maxOpenConns,
|
||||
MaxIdleConnections: maxIdleConns,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := req.Storage.Put(ctx, entry); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Reset the DB connection
|
||||
b.ResetDB(ctx)
|
||||
|
||||
resp := &logical.Response{}
|
||||
resp.AddWarning("Read access to this endpoint should be controlled via ACLs as it will return the connection string or URL as it is, including passwords, if any.")
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
type connectionConfig struct {
|
||||
ConnectionURL string `json:"connection_url" structs:"connection_url" mapstructure:"connection_url"`
|
||||
// Deprecate "value" in coming releases
|
||||
ConnectionString string `json:"value" structs:"value" mapstructure:"value"`
|
||||
MaxOpenConnections int `json:"max_open_connections" structs:"max_open_connections" mapstructure:"max_open_connections"`
|
||||
MaxIdleConnections int `json:"max_idle_connections" structs:"max_idle_connections" mapstructure:"max_idle_connections"`
|
||||
}
|
||||
|
||||
const pathConfigConnectionHelpSyn = `
|
||||
Configure the connection string to talk to PostgreSQL.
|
||||
`
|
||||
|
||||
const pathConfigConnectionHelpDesc = `
|
||||
This path configures the connection string used to connect to PostgreSQL.
|
||||
The value of the string can be a URL, or a PG style string in the
|
||||
format of "user=foo host=bar" etc.
|
||||
|
||||
The URL looks like:
|
||||
"postgresql://user:pass@host:port/dbname"
|
||||
|
||||
When configuring the connection string, the backend will verify its validity.
|
||||
`
|
|
@ -1,101 +0,0 @@
|
|||
package postgresql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/sdk/framework"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
)
|
||||
|
||||
func pathConfigLease(b *backend) *framework.Path {
|
||||
return &framework.Path{
|
||||
Pattern: "config/lease",
|
||||
Fields: map[string]*framework.FieldSchema{
|
||||
"lease": {
|
||||
Type: framework.TypeString,
|
||||
Description: "Default lease for roles.",
|
||||
},
|
||||
|
||||
"lease_max": {
|
||||
Type: framework.TypeString,
|
||||
Description: "Maximum time a credential is valid for.",
|
||||
},
|
||||
},
|
||||
|
||||
Callbacks: map[logical.Operation]framework.OperationFunc{
|
||||
logical.ReadOperation: b.pathLeaseRead,
|
||||
logical.UpdateOperation: b.pathLeaseWrite,
|
||||
},
|
||||
|
||||
HelpSynopsis: pathConfigLeaseHelpSyn,
|
||||
HelpDescription: pathConfigLeaseHelpDesc,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *backend) pathLeaseWrite(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
leaseRaw := d.Get("lease").(string)
|
||||
leaseMaxRaw := d.Get("lease_max").(string)
|
||||
|
||||
lease, err := time.ParseDuration(leaseRaw)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf(
|
||||
"Invalid lease: %s", err)), nil
|
||||
}
|
||||
leaseMax, err := time.ParseDuration(leaseMaxRaw)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf(
|
||||
"Invalid lease: %s", err)), nil
|
||||
}
|
||||
|
||||
// Store it
|
||||
entry, err := logical.StorageEntryJSON("config/lease", &configLease{
|
||||
Lease: lease,
|
||||
LeaseMax: leaseMax,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := req.Storage.Put(ctx, entry); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (b *backend) pathLeaseRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
lease, err := b.Lease(ctx, req.Storage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if lease == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return &logical.Response{
|
||||
Data: map[string]interface{}{
|
||||
"lease": lease.Lease.String(),
|
||||
"lease_max": lease.LeaseMax.String(),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
type configLease struct {
|
||||
Lease time.Duration
|
||||
LeaseMax time.Duration
|
||||
}
|
||||
|
||||
const pathConfigLeaseHelpSyn = `
|
||||
Configure the default lease information for generated credentials.
|
||||
`
|
||||
|
||||
const pathConfigLeaseHelpDesc = `
|
||||
This configures the default lease information used for credentials
|
||||
generated by this backend. The lease specifies the duration that a
|
||||
credential will be valid for, as well as the maximum session for
|
||||
a set of credentials.
|
||||
|
||||
The format for the lease is "1h" or integer and then unit. The longest
|
||||
unit is hour.
|
||||
`
|
|
@ -1,149 +0,0 @@
|
|||
package postgresql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-secure-stdlib/strutil"
|
||||
uuid "github.com/hashicorp/go-uuid"
|
||||
"github.com/hashicorp/vault/sdk/framework"
|
||||
"github.com/hashicorp/vault/sdk/helper/dbtxn"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
_ "github.com/jackc/pgx/v4/stdlib"
|
||||
)
|
||||
|
||||
func pathRoleCreate(b *backend) *framework.Path {
|
||||
return &framework.Path{
|
||||
Pattern: "creds/" + framework.GenericNameRegex("name"),
|
||||
Fields: map[string]*framework.FieldSchema{
|
||||
"name": {
|
||||
Type: framework.TypeString,
|
||||
Description: "Name of the role.",
|
||||
},
|
||||
},
|
||||
|
||||
Callbacks: map[logical.Operation]framework.OperationFunc{
|
||||
logical.ReadOperation: b.pathRoleCreateRead,
|
||||
},
|
||||
|
||||
HelpSynopsis: pathRoleCreateReadHelpSyn,
|
||||
HelpDescription: pathRoleCreateReadHelpDesc,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *backend) pathRoleCreateRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
name := data.Get("name").(string)
|
||||
|
||||
// Get the role
|
||||
role, err := b.Role(ctx, req.Storage, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if role == nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf("unknown role: %s", name)), nil
|
||||
}
|
||||
|
||||
// Determine if we have a lease
|
||||
lease, err := b.Lease(ctx, req.Storage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Unlike some other backends we need a lease here (can't leave as 0 and
|
||||
// let core fill it in) because Postgres also expires users as a safety
|
||||
// measure, so cannot be zero
|
||||
if lease == nil {
|
||||
lease = &configLease{
|
||||
Lease: b.System().DefaultLeaseTTL(),
|
||||
}
|
||||
}
|
||||
|
||||
// Generate the username, password and expiration. PG limits user to 63 characters
|
||||
displayName := req.DisplayName
|
||||
if len(displayName) > 26 {
|
||||
displayName = displayName[:26]
|
||||
}
|
||||
userUUID, err := uuid.GenerateUUID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
username := fmt.Sprintf("%s-%s", displayName, userUUID)
|
||||
if len(username) > 63 {
|
||||
username = username[:63]
|
||||
}
|
||||
password, err := uuid.GenerateUUID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ttl, _, err := framework.CalculateTTL(b.System(), 0, lease.Lease, 0, lease.LeaseMax, 0, time.Time{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
expiration := time.Now().
|
||||
Add(ttl).
|
||||
Format("2006-01-02 15:04:05-0700")
|
||||
|
||||
// Get our handle
|
||||
db, err := b.DB(ctx, req.Storage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Start a transaction
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
// Execute each query
|
||||
for _, query := range strutil.ParseArbitraryStringSlice(role.SQL, ";") {
|
||||
query = strings.TrimSpace(query)
|
||||
if len(query) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
m := map[string]string{
|
||||
"name": username,
|
||||
"password": password,
|
||||
"expiration": expiration,
|
||||
}
|
||||
|
||||
if err := dbtxn.ExecuteTxQueryDirect(ctx, tx, m, query); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Commit the transaction
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Return the secret
|
||||
|
||||
resp := b.Secret(SecretCredsType).Response(map[string]interface{}{
|
||||
"username": username,
|
||||
"password": password,
|
||||
}, map[string]interface{}{
|
||||
"username": username,
|
||||
"role": name,
|
||||
})
|
||||
resp.Secret.TTL = lease.Lease
|
||||
resp.Secret.MaxTTL = lease.LeaseMax
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
const pathRoleCreateReadHelpSyn = `
|
||||
Request database credentials for a certain role.
|
||||
`
|
||||
|
||||
const pathRoleCreateReadHelpDesc = `
|
||||
This path reads database credentials for a certain role. The
|
||||
database credentials will be generated on demand and will be automatically
|
||||
revoked when the lease is up.
|
||||
`
|
|
@ -1,197 +0,0 @@
|
|||
package postgresql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/go-secure-stdlib/strutil"
|
||||
"github.com/hashicorp/vault/sdk/framework"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
)
|
||||
|
||||
func pathListRoles(b *backend) *framework.Path {
|
||||
return &framework.Path{
|
||||
Pattern: "roles/?$",
|
||||
|
||||
Callbacks: map[logical.Operation]framework.OperationFunc{
|
||||
logical.ListOperation: b.pathRoleList,
|
||||
},
|
||||
|
||||
HelpSynopsis: pathRoleHelpSyn,
|
||||
HelpDescription: pathRoleHelpDesc,
|
||||
}
|
||||
}
|
||||
|
||||
func pathRoles(b *backend) *framework.Path {
|
||||
return &framework.Path{
|
||||
Pattern: "roles/" + framework.GenericNameRegex("name"),
|
||||
Fields: map[string]*framework.FieldSchema{
|
||||
"name": {
|
||||
Type: framework.TypeString,
|
||||
Description: "Name of the role.",
|
||||
},
|
||||
|
||||
"sql": {
|
||||
Type: framework.TypeString,
|
||||
Description: "SQL string to create a user. See help for more info.",
|
||||
},
|
||||
|
||||
"revocation_sql": {
|
||||
Type: framework.TypeString,
|
||||
Description: `SQL statements to be executed to revoke a user. Must be a semicolon-separated
|
||||
string, a base64-encoded semicolon-separated string, a serialized JSON string
|
||||
array, or a base64-encoded serialized JSON string array. The '{{name}}' value
|
||||
will be substituted.`,
|
||||
},
|
||||
},
|
||||
|
||||
Callbacks: map[logical.Operation]framework.OperationFunc{
|
||||
logical.ReadOperation: b.pathRoleRead,
|
||||
logical.UpdateOperation: b.pathRoleCreate,
|
||||
logical.DeleteOperation: b.pathRoleDelete,
|
||||
},
|
||||
|
||||
HelpSynopsis: pathRoleHelpSyn,
|
||||
HelpDescription: pathRoleHelpDesc,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *backend) Role(ctx context.Context, s logical.Storage, n string) (*roleEntry, error) {
|
||||
entry, err := s.Get(ctx, "role/"+n)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if entry == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var result roleEntry
|
||||
if err := entry.DecodeJSON(&result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
func (b *backend) pathRoleDelete(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
err := req.Storage.Delete(ctx, "role/"+data.Get("name").(string))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (b *backend) pathRoleRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
role, err := b.Role(ctx, req.Storage, data.Get("name").(string))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if role == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return &logical.Response{
|
||||
Data: map[string]interface{}{
|
||||
"sql": role.SQL,
|
||||
"revocation_sql": role.RevocationSQL,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (b *backend) pathRoleList(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
entries, err := req.Storage.List(ctx, "role/")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return logical.ListResponse(entries), nil
|
||||
}
|
||||
|
||||
func (b *backend) pathRoleCreate(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
name := data.Get("name").(string)
|
||||
sql := data.Get("sql").(string)
|
||||
|
||||
// Get our connection
|
||||
db, err := b.DB(ctx, req.Storage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Test the query by trying to prepare it
|
||||
for _, query := range strutil.ParseArbitraryStringSlice(sql, ";") {
|
||||
query = strings.TrimSpace(query)
|
||||
if len(query) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
stmt, err := db.Prepare(Query(query, map[string]string{
|
||||
"name": "foo",
|
||||
"password": "bar",
|
||||
"expiration": "",
|
||||
}))
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf(
|
||||
"Error testing query: %s", err)), nil
|
||||
}
|
||||
stmt.Close()
|
||||
}
|
||||
|
||||
// Store it
|
||||
entry, err := logical.StorageEntryJSON("role/"+name, &roleEntry{
|
||||
SQL: sql,
|
||||
RevocationSQL: data.Get("revocation_sql").(string),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := req.Storage.Put(ctx, entry); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
type roleEntry struct {
|
||||
SQL string `json:"sql" mapstructure:"sql" structs:"sql"`
|
||||
RevocationSQL string `json:"revocation_sql" mapstructure:"revocation_sql" structs:"revocation_sql"`
|
||||
}
|
||||
|
||||
const pathRoleHelpSyn = `
|
||||
Manage the roles that can be created with this backend.
|
||||
`
|
||||
|
||||
const pathRoleHelpDesc = `
|
||||
This path lets you manage the roles that can be created with this backend.
|
||||
|
||||
The "sql" parameter customizes the SQL string used to create the role.
|
||||
This can be a sequence of SQL queries. Some substitution will be done to the
|
||||
SQL string for certain keys. The names of the variables must be surrounded
|
||||
by "{{" and "}}" to be replaced.
|
||||
|
||||
* "name" - The random username generated for the DB user.
|
||||
|
||||
* "password" - The random password generated for the DB user.
|
||||
|
||||
* "expiration" - The timestamp when this user will expire.
|
||||
|
||||
Example of a decent SQL query to use:
|
||||
|
||||
CREATE ROLE "{{name}}" WITH
|
||||
LOGIN
|
||||
PASSWORD '{{password}}'
|
||||
VALID UNTIL '{{expiration}}';
|
||||
GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{name}}";
|
||||
|
||||
Note the above user would be able to access everything in schema public.
|
||||
For more complex GRANT clauses, see the PostgreSQL manual.
|
||||
|
||||
The "revocation_sql" parameter customizes the SQL string used to revoke a user.
|
||||
Example of a decent revocation SQL query to use:
|
||||
|
||||
REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM {{name}};
|
||||
REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM {{name}};
|
||||
REVOKE USAGE ON SCHEMA public FROM {{name}};
|
||||
DROP ROLE IF EXISTS {{name}};
|
||||
`
|
|
@ -1,15 +0,0 @@
|
|||
package postgresql
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Query templates a query for us.
|
||||
func Query(tpl string, data map[string]string) string {
|
||||
for k, v := range data {
|
||||
tpl = strings.ReplaceAll(tpl, fmt.Sprintf("{{%s}}", k), v)
|
||||
}
|
||||
|
||||
return tpl
|
||||
}
|
|
@ -1,269 +0,0 @@
|
|||
package postgresql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/sdk/database/helper/dbutil"
|
||||
|
||||
"github.com/hashicorp/go-secure-stdlib/strutil"
|
||||
"github.com/hashicorp/vault/sdk/framework"
|
||||
"github.com/hashicorp/vault/sdk/helper/dbtxn"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
)
|
||||
|
||||
const SecretCredsType = "creds"
|
||||
|
||||
func secretCreds(b *backend) *framework.Secret {
|
||||
return &framework.Secret{
|
||||
Type: SecretCredsType,
|
||||
Fields: map[string]*framework.FieldSchema{
|
||||
"username": {
|
||||
Type: framework.TypeString,
|
||||
Description: "Username",
|
||||
},
|
||||
|
||||
"password": {
|
||||
Type: framework.TypeString,
|
||||
Description: "Password",
|
||||
},
|
||||
},
|
||||
|
||||
Renew: b.secretCredsRenew,
|
||||
Revoke: b.secretCredsRevoke,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *backend) secretCredsRenew(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
// Get the username from the internal data
|
||||
usernameRaw, ok := req.Secret.InternalData["username"]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("secret is missing username internal data")
|
||||
}
|
||||
username, ok := usernameRaw.(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("usernameRaw is not a string")
|
||||
}
|
||||
// Get our connection
|
||||
db, err := b.DB(ctx, req.Storage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get the lease information
|
||||
lease, err := b.Lease(ctx, req.Storage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if lease == nil {
|
||||
lease = &configLease{}
|
||||
}
|
||||
|
||||
// Make sure we increase the VALID UNTIL endpoint for this user.
|
||||
ttl, _, err := framework.CalculateTTL(b.System(), req.Secret.Increment, lease.Lease, 0, lease.LeaseMax, 0, req.Secret.IssueTime)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if ttl > 0 {
|
||||
expireTime := time.Now().Add(ttl)
|
||||
// Adding a small buffer since the TTL will be calculated again afeter this call
|
||||
// to ensure the database credential does not expire before the lease
|
||||
expireTime = expireTime.Add(5 * time.Second)
|
||||
expiration := expireTime.Format("2006-01-02 15:04:05-0700")
|
||||
|
||||
query := fmt.Sprintf(
|
||||
"ALTER ROLE %s VALID UNTIL '%s';",
|
||||
dbutil.QuoteIdentifier(username),
|
||||
expiration)
|
||||
stmt, err := db.Prepare(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer stmt.Close()
|
||||
if _, err := stmt.Exec(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
resp := &logical.Response{Secret: req.Secret}
|
||||
resp.Secret.TTL = lease.Lease
|
||||
resp.Secret.MaxTTL = lease.LeaseMax
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (b *backend) secretCredsRevoke(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
// Get the username from the internal data
|
||||
usernameRaw, ok := req.Secret.InternalData["username"]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("secret is missing username internal data")
|
||||
}
|
||||
username, ok := usernameRaw.(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("usernameRaw is not a string")
|
||||
}
|
||||
var revocationSQL string
|
||||
var resp *logical.Response
|
||||
|
||||
roleNameRaw, ok := req.Secret.InternalData["role"]
|
||||
if ok {
|
||||
role, err := b.Role(ctx, req.Storage, roleNameRaw.(string))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if role == nil {
|
||||
if resp == nil {
|
||||
resp = &logical.Response{}
|
||||
}
|
||||
resp.AddWarning(fmt.Sprintf("Role %q cannot be found. Using default revocation SQL.", roleNameRaw.(string)))
|
||||
} else {
|
||||
revocationSQL = role.RevocationSQL
|
||||
}
|
||||
}
|
||||
|
||||
// Get our connection
|
||||
db, err := b.DB(ctx, req.Storage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch revocationSQL {
|
||||
|
||||
// This is the default revocation logic. If revocation SQL is provided it
|
||||
// is simply executed as-is.
|
||||
case "":
|
||||
// Check if the role exists
|
||||
var exists bool
|
||||
err = db.QueryRow("SELECT exists (SELECT rolname FROM pg_roles WHERE rolname=$1);", username).Scan(&exists)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !exists {
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// Query for permissions; we need to revoke permissions before we can drop
|
||||
// the role
|
||||
// This isn't done in a transaction because even if we fail along the way,
|
||||
// we want to remove as much access as possible
|
||||
stmt, err := db.Prepare("SELECT DISTINCT table_schema FROM information_schema.role_column_grants WHERE grantee=$1;")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
rows, err := stmt.Query(username)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
const initialNumRevocations = 16
|
||||
revocationStmts := make([]string, 0, initialNumRevocations)
|
||||
for rows.Next() {
|
||||
var schema string
|
||||
err = rows.Scan(&schema)
|
||||
if err != nil {
|
||||
// keep going; remove as many permissions as possible right now
|
||||
continue
|
||||
}
|
||||
revocationStmts = append(revocationStmts, fmt.Sprintf(
|
||||
`REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA %s FROM %s;`,
|
||||
dbutil.QuoteIdentifier(schema),
|
||||
dbutil.QuoteIdentifier(username)))
|
||||
|
||||
revocationStmts = append(revocationStmts, fmt.Sprintf(
|
||||
`REVOKE USAGE ON SCHEMA %s FROM %s;`,
|
||||
dbutil.QuoteIdentifier(schema),
|
||||
dbutil.QuoteIdentifier(username)))
|
||||
}
|
||||
|
||||
// for good measure, revoke all privileges and usage on schema public
|
||||
revocationStmts = append(revocationStmts, fmt.Sprintf(
|
||||
`REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM %s;`,
|
||||
dbutil.QuoteIdentifier(username)))
|
||||
|
||||
revocationStmts = append(revocationStmts, fmt.Sprintf(
|
||||
"REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM %s;",
|
||||
dbutil.QuoteIdentifier(username)))
|
||||
|
||||
revocationStmts = append(revocationStmts, fmt.Sprintf(
|
||||
"REVOKE USAGE ON SCHEMA public FROM %s;",
|
||||
dbutil.QuoteIdentifier(username)))
|
||||
|
||||
// get the current database name so we can issue a REVOKE CONNECT for
|
||||
// this username
|
||||
var dbname sql.NullString
|
||||
if err := db.QueryRow("SELECT current_database();").Scan(&dbname); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if dbname.Valid {
|
||||
revocationStmts = append(revocationStmts, fmt.Sprintf(
|
||||
`REVOKE CONNECT ON DATABASE %s FROM %s;`,
|
||||
dbutil.QuoteIdentifier(dbname.String),
|
||||
dbutil.QuoteIdentifier(username)))
|
||||
}
|
||||
|
||||
// again, here, we do not stop on error, as we want to remove as
|
||||
// many permissions as possible right now
|
||||
var lastStmtError error
|
||||
for _, query := range revocationStmts {
|
||||
if err := dbtxn.ExecuteDBQueryDirect(ctx, db, nil, query); err != nil {
|
||||
lastStmtError = err
|
||||
}
|
||||
}
|
||||
|
||||
// can't drop if not all privileges are revoked
|
||||
if rows.Err() != nil {
|
||||
return nil, fmt.Errorf("could not generate revocation statements for all rows: %w", rows.Err())
|
||||
}
|
||||
if lastStmtError != nil {
|
||||
return nil, fmt.Errorf("could not perform all revocation statements: %w", lastStmtError)
|
||||
}
|
||||
|
||||
// Drop this user
|
||||
stmt, err = db.Prepare(fmt.Sprintf(
|
||||
`DROP ROLE IF EXISTS %s;`, dbutil.QuoteIdentifier(username)))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer stmt.Close()
|
||||
if _, err := stmt.Exec(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// We have revocation SQL, execute directly, within a transaction
|
||||
default:
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
for _, query := range strutil.ParseArbitraryStringSlice(revocationSQL, ";") {
|
||||
query = strings.TrimSpace(query)
|
||||
if len(query) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
m := map[string]string{
|
||||
"name": username,
|
||||
}
|
||||
if err := dbtxn.ExecuteTxQueryDirect(ctx, tx, m, query); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
|
@ -0,0 +1,6 @@
|
|||
```release-note:improvement
|
||||
plugins: Mark logical database plugins Removed and remove the plugin code.
|
||||
```
|
||||
```release-note:improvement
|
||||
plugins: Mark app-id auth method Removed and remove the plugin code.
|
||||
```
|
|
@ -49,18 +49,6 @@ func TestAuthEnableCommand_Run(t *testing.T) {
|
|||
"",
|
||||
2,
|
||||
},
|
||||
{
|
||||
"deprecated builtin with standard mount",
|
||||
[]string{"app-id"},
|
||||
"mount entry associated with pending removal builtin",
|
||||
2,
|
||||
},
|
||||
{
|
||||
"deprecated builtin with different mount",
|
||||
[]string{"-path=/tmp", "app-id"},
|
||||
"mount entry associated with pending removal builtin",
|
||||
2,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
|
|
|
@ -343,11 +343,9 @@ func TestPredict_Plugins(t *testing.T) {
|
|||
[]string{
|
||||
"ad",
|
||||
"alicloud",
|
||||
"app-id",
|
||||
"approle",
|
||||
"aws",
|
||||
"azure",
|
||||
"cassandra",
|
||||
"cassandra-database-plugin",
|
||||
"centrify",
|
||||
"cert",
|
||||
|
@ -367,13 +365,10 @@ func TestPredict_Plugins(t *testing.T) {
|
|||
"kubernetes",
|
||||
"kv",
|
||||
"ldap",
|
||||
"mongodb",
|
||||
"mongodb-database-plugin",
|
||||
"mongodbatlas",
|
||||
"mongodbatlas-database-plugin",
|
||||
"mssql",
|
||||
"mssql-database-plugin",
|
||||
"mysql",
|
||||
"mysql-aurora-database-plugin",
|
||||
"mysql-database-plugin",
|
||||
"mysql-legacy-database-plugin",
|
||||
|
@ -385,7 +380,6 @@ func TestPredict_Plugins(t *testing.T) {
|
|||
"openldap",
|
||||
"pcf", // Deprecated.
|
||||
"pki",
|
||||
"postgresql",
|
||||
"postgresql-database-plugin",
|
||||
"rabbitmq",
|
||||
"radius",
|
||||
|
@ -439,7 +433,7 @@ func TestPredict_Plugins(t *testing.T) {
|
|||
}
|
||||
}
|
||||
if !reflect.DeepEqual(act, tc.exp) {
|
||||
t.Errorf("expected:%q, got: %q", tc.exp, act)
|
||||
t.Errorf("expected: %q, got: %q, diff: %v", tc.exp, act, strutil.Difference(act, tc.exp, true))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
@ -1,92 +0,0 @@
|
|||
package command
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/vault/api"
|
||||
vaulthttp "github.com/hashicorp/vault/http"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
"github.com/hashicorp/vault/vault"
|
||||
|
||||
credAppId "github.com/hashicorp/vault/builtin/credential/app-id"
|
||||
)
|
||||
|
||||
func TestPathMap_Upgrade_API(t *testing.T) {
|
||||
var err error
|
||||
coreConfig := &vault.CoreConfig{
|
||||
DisableMlock: true,
|
||||
DisableCache: true,
|
||||
CredentialBackends: map[string]logical.Factory{
|
||||
"app-id": credAppId.Factory,
|
||||
},
|
||||
PendingRemovalMountsAllowed: true,
|
||||
}
|
||||
|
||||
cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{
|
||||
HandlerFunc: vaulthttp.Handler,
|
||||
})
|
||||
|
||||
cluster.Start()
|
||||
defer cluster.Cleanup()
|
||||
|
||||
cores := cluster.Cores
|
||||
|
||||
vault.TestWaitActive(t, cores[0].Core)
|
||||
|
||||
client := cores[0].Client
|
||||
|
||||
// Enable the app-id method
|
||||
err = client.Sys().EnableAuthWithOptions("app-id", &api.EnableAuthOptions{
|
||||
Type: "app-id",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Create an app-id
|
||||
_, err = client.Logical().Write("auth/app-id/map/app-id/test-app-id", map[string]interface{}{
|
||||
"policy": "test-policy",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Create a user-id
|
||||
_, err = client.Logical().Write("auth/app-id/map/user-id/test-user-id", map[string]interface{}{
|
||||
"value": "test-app-id",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Perform a login. It should succeed.
|
||||
_, err = client.Logical().Write("auth/app-id/login", map[string]interface{}{
|
||||
"app_id": "test-app-id",
|
||||
"user_id": "test-user-id",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// List the hashed app-ids in the storage
|
||||
secret, err := client.Logical().List("auth/app-id/map/app-id")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
hashedAppID := secret.Data["keys"].([]interface{})[0].(string)
|
||||
|
||||
// Try reading it. This used to cause an issue which is fixed in [GH-3806].
|
||||
_, err = client.Logical().Read("auth/app-id/map/app-id/" + hashedAppID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Ensure that there was no issue by performing another login
|
||||
_, err = client.Logical().Write("auth/app-id/login", map[string]interface{}{
|
||||
"app_id": "test-app-id",
|
||||
"user_id": "test-user-id",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
1
go.mod
1
go.mod
|
@ -202,7 +202,6 @@ require (
|
|||
google.golang.org/grpc v1.50.1
|
||||
google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.1.0
|
||||
google.golang.org/protobuf v1.28.1
|
||||
gopkg.in/mgo.v2 v2.0.0-20180705113604-9856a29383ce
|
||||
gopkg.in/ory-am/dockertest.v3 v3.3.4
|
||||
gopkg.in/square/go-jose.v2 v2.6.0
|
||||
k8s.io/utils v0.0.0-20220728103510-ee6ede2d64ed
|
||||
|
|
2
go.sum
2
go.sum
|
@ -2599,8 +2599,6 @@ gopkg.in/ini.v1 v1.66.2 h1:XfR1dOYubytKy4Shzc2LHrrGhU0lDCfDGG1yLPmpgsI=
|
|||
gopkg.in/ini.v1 v1.66.2/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
|
||||
gopkg.in/jcmturner/goidentity.v3 v3.0.0 h1:1duIyWiTaYvVx3YX2CYtpJbUFd7/UuPYCfgXtQ3VTbI=
|
||||
gopkg.in/jcmturner/goidentity.v3 v3.0.0/go.mod h1:oG2kH0IvSYNIu80dVAyu/yoefjq1mNfM5bm88whjWx4=
|
||||
gopkg.in/mgo.v2 v2.0.0-20180705113604-9856a29383ce h1:xcEWjVhvbDy+nHP67nPDDpbYrY+ILlfndk4bRioVHaU=
|
||||
gopkg.in/mgo.v2 v2.0.0-20180705113604-9856a29383ce/go.mod h1:yeKp02qBN3iKW1OzL3MGk2IdtZzaj7SFntXj72NppTA=
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.0.0/go.mod h1:l0ndWWf7gzL7RNwBG7wST/UCcT4T24xpD6X8LsfU/+k=
|
||||
gopkg.in/ory-am/dockertest.v3 v3.3.4 h1:oen8RiwxVNxtQ1pRoV4e4jqh6UjNsOuIZ1NXns6jdcw=
|
||||
gopkg.in/ory-am/dockertest.v3 v3.3.4/go.mod h1:s9mmoLkaGeAh97qygnNj4xWkiN7e1SKekYC6CovU+ek=
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
package builtinplugins
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
credAliCloud "github.com/hashicorp/vault-plugin-auth-alicloud"
|
||||
credAzure "github.com/hashicorp/vault-plugin-auth-azure"
|
||||
credCentrify "github.com/hashicorp/vault-plugin-auth-centrify"
|
||||
|
@ -26,7 +28,6 @@ import (
|
|||
logicalMongoAtlas "github.com/hashicorp/vault-plugin-secrets-mongodbatlas"
|
||||
logicalLDAP "github.com/hashicorp/vault-plugin-secrets-openldap"
|
||||
logicalTerraform "github.com/hashicorp/vault-plugin-secrets-terraform"
|
||||
credAppId "github.com/hashicorp/vault/builtin/credential/app-id"
|
||||
credAppRole "github.com/hashicorp/vault/builtin/credential/approle"
|
||||
credAws "github.com/hashicorp/vault/builtin/credential/aws"
|
||||
credCert "github.com/hashicorp/vault/builtin/credential/cert"
|
||||
|
@ -36,14 +37,9 @@ import (
|
|||
credRadius "github.com/hashicorp/vault/builtin/credential/radius"
|
||||
credUserpass "github.com/hashicorp/vault/builtin/credential/userpass"
|
||||
logicalAws "github.com/hashicorp/vault/builtin/logical/aws"
|
||||
logicalCass "github.com/hashicorp/vault/builtin/logical/cassandra"
|
||||
logicalConsul "github.com/hashicorp/vault/builtin/logical/consul"
|
||||
logicalMongo "github.com/hashicorp/vault/builtin/logical/mongodb"
|
||||
logicalMssql "github.com/hashicorp/vault/builtin/logical/mssql"
|
||||
logicalMysql "github.com/hashicorp/vault/builtin/logical/mysql"
|
||||
logicalNomad "github.com/hashicorp/vault/builtin/logical/nomad"
|
||||
logicalPki "github.com/hashicorp/vault/builtin/logical/pki"
|
||||
logicalPostgres "github.com/hashicorp/vault/builtin/logical/postgresql"
|
||||
logicalRabbit "github.com/hashicorp/vault/builtin/logical/rabbitmq"
|
||||
logicalSsh "github.com/hashicorp/vault/builtin/logical/ssh"
|
||||
logicalTotp "github.com/hashicorp/vault/builtin/logical/totp"
|
||||
|
@ -56,6 +52,7 @@ import (
|
|||
dbMysql "github.com/hashicorp/vault/plugins/database/mysql"
|
||||
dbPostgres "github.com/hashicorp/vault/plugins/database/postgresql"
|
||||
dbRedshift "github.com/hashicorp/vault/plugins/database/redshift"
|
||||
"github.com/hashicorp/vault/sdk/framework"
|
||||
"github.com/hashicorp/vault/sdk/helper/consts"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
)
|
||||
|
@ -86,13 +83,23 @@ type logicalBackend struct {
|
|||
consts.DeprecationStatus
|
||||
}
|
||||
|
||||
type removedBackend struct {
|
||||
*framework.Backend
|
||||
}
|
||||
|
||||
func removedFactory(ctx context.Context, config *logical.BackendConfig) (logical.Backend, error) {
|
||||
removedBackend := &removedBackend{}
|
||||
removedBackend.Backend = &framework.Backend{}
|
||||
return removedBackend, nil
|
||||
}
|
||||
|
||||
func newRegistry() *registry {
|
||||
reg := ®istry{
|
||||
credentialBackends: map[string]credentialBackend{
|
||||
"alicloud": {Factory: credAliCloud.Factory},
|
||||
"app-id": {
|
||||
Factory: credAppId.Factory,
|
||||
DeprecationStatus: consts.PendingRemoval,
|
||||
Factory: removedFactory,
|
||||
DeprecationStatus: consts.Removed,
|
||||
},
|
||||
"approle": {Factory: credAppRole.Factory},
|
||||
"aws": {Factory: credAws.Factory},
|
||||
|
@ -144,8 +151,8 @@ func newRegistry() *registry {
|
|||
"aws": {Factory: logicalAws.Factory},
|
||||
"azure": {Factory: logicalAzure.Factory},
|
||||
"cassandra": {
|
||||
Factory: logicalCass.Factory,
|
||||
DeprecationStatus: consts.PendingRemoval,
|
||||
Factory: removedFactory,
|
||||
DeprecationStatus: consts.Removed,
|
||||
},
|
||||
"consul": {Factory: logicalConsul.Factory},
|
||||
"gcp": {Factory: logicalGcp.Factory},
|
||||
|
@ -153,25 +160,27 @@ func newRegistry() *registry {
|
|||
"kubernetes": {Factory: logicalKube.Factory},
|
||||
"kv": {Factory: logicalKv.Factory},
|
||||
"mongodb": {
|
||||
Factory: logicalMongo.Factory,
|
||||
DeprecationStatus: consts.PendingRemoval,
|
||||
Factory: removedFactory,
|
||||
DeprecationStatus: consts.Removed,
|
||||
},
|
||||
// The mongodbatlas secrets engine is not the same as the database plugin equivalent
|
||||
// (`mongodbatlas-database-plugin`), and thus will not be deprecated at this time.
|
||||
"mongodbatlas": {Factory: logicalMongoAtlas.Factory},
|
||||
"mssql": {
|
||||
Factory: logicalMssql.Factory,
|
||||
DeprecationStatus: consts.PendingRemoval,
|
||||
Factory: removedFactory,
|
||||
DeprecationStatus: consts.Removed,
|
||||
},
|
||||
"mysql": {
|
||||
Factory: logicalMysql.Factory,
|
||||
DeprecationStatus: consts.PendingRemoval,
|
||||
Factory: removedFactory,
|
||||
DeprecationStatus: consts.Removed,
|
||||
},
|
||||
"nomad": {Factory: logicalNomad.Factory},
|
||||
"openldap": {Factory: logicalLDAP.Factory},
|
||||
"ldap": {Factory: logicalLDAP.Factory},
|
||||
"pki": {Factory: logicalPki.Factory},
|
||||
"postgresql": {
|
||||
Factory: logicalPostgres.Factory,
|
||||
DeprecationStatus: consts.PendingRemoval,
|
||||
Factory: removedFactory,
|
||||
DeprecationStatus: consts.Removed,
|
||||
},
|
||||
"rabbitmq": {Factory: logicalRabbit.Factory},
|
||||
"ssh": {Factory: logicalSsh.Factory},
|
||||
|
@ -222,16 +231,16 @@ func (r *registry) Keys(pluginType consts.PluginType) []string {
|
|||
var keys []string
|
||||
switch pluginType {
|
||||
case consts.PluginTypeDatabase:
|
||||
for key := range r.databasePlugins {
|
||||
keys = append(keys, key)
|
||||
for key, backend := range r.databasePlugins {
|
||||
keys = appendIfNotRemoved(keys, key, backend.DeprecationStatus)
|
||||
}
|
||||
case consts.PluginTypeCredential:
|
||||
for key := range r.credentialBackends {
|
||||
keys = append(keys, key)
|
||||
for key, backend := range r.credentialBackends {
|
||||
keys = appendIfNotRemoved(keys, key, backend.DeprecationStatus)
|
||||
}
|
||||
case consts.PluginTypeSecrets:
|
||||
for key := range r.logicalBackends {
|
||||
keys = append(keys, key)
|
||||
for key, backend := range r.logicalBackends {
|
||||
keys = appendIfNotRemoved(keys, key, backend.DeprecationStatus)
|
||||
}
|
||||
}
|
||||
return keys
|
||||
|
@ -273,3 +282,10 @@ func toFunc(ifc interface{}) func() (interface{}, error) {
|
|||
return ifc, nil
|
||||
}
|
||||
}
|
||||
|
||||
func appendIfNotRemoved(keys []string, name string, status consts.DeprecationStatus) []string {
|
||||
if status != consts.Removed {
|
||||
return append(keys, name)
|
||||
}
|
||||
return keys
|
||||
}
|
||||
|
|
|
@ -4,7 +4,7 @@ import (
|
|||
"reflect"
|
||||
"testing"
|
||||
|
||||
credAppId "github.com/hashicorp/vault/builtin/credential/app-id"
|
||||
credUserpass "github.com/hashicorp/vault/builtin/credential/userpass"
|
||||
dbMysql "github.com/hashicorp/vault/plugins/database/mysql"
|
||||
"github.com/hashicorp/vault/sdk/helper/consts"
|
||||
)
|
||||
|
@ -35,9 +35,16 @@ func Test_RegistryGet(t *testing.T) {
|
|||
},
|
||||
{
|
||||
name: "known builtin lookup",
|
||||
builtin: "userpass",
|
||||
pluginType: consts.PluginTypeCredential,
|
||||
want: toFunc(credUserpass.Factory),
|
||||
wantOk: true,
|
||||
},
|
||||
{
|
||||
name: "removed builtin lookup",
|
||||
builtin: "app-id",
|
||||
pluginType: consts.PluginTypeCredential,
|
||||
want: toFunc(credAppId.Factory),
|
||||
want: nil,
|
||||
wantOk: true,
|
||||
},
|
||||
{
|
||||
|
@ -81,7 +88,7 @@ func Test_RegistryKeyCounts(t *testing.T) {
|
|||
{
|
||||
name: "number of auth plugins",
|
||||
pluginType: consts.PluginTypeCredential,
|
||||
want: 20,
|
||||
want: 19,
|
||||
},
|
||||
{
|
||||
name: "number of database plugins",
|
||||
|
@ -91,7 +98,7 @@ func Test_RegistryKeyCounts(t *testing.T) {
|
|||
{
|
||||
name: "number of secrets plugins",
|
||||
pluginType: consts.PluginTypeSecrets,
|
||||
want: 24,
|
||||
want: 19,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
|
@ -126,10 +133,16 @@ func Test_RegistryContains(t *testing.T) {
|
|||
},
|
||||
{
|
||||
name: "known builtin lookup",
|
||||
builtin: "app-id",
|
||||
builtin: "approle",
|
||||
pluginType: consts.PluginTypeCredential,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "removed builtin lookup",
|
||||
builtin: "app-id",
|
||||
pluginType: consts.PluginTypeCredential,
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
@ -186,10 +199,10 @@ func Test_RegistryStatus(t *testing.T) {
|
|||
wantOk: true,
|
||||
},
|
||||
{
|
||||
name: "pending removal builtin lookup",
|
||||
name: "removed builtin lookup",
|
||||
builtin: "app-id",
|
||||
pluginType: consts.PluginTypeCredential,
|
||||
want: consts.PendingRemoval,
|
||||
want: consts.Removed,
|
||||
wantOk: true,
|
||||
},
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue