commit
bb910f2bb9
2
.gitignore
vendored
2
.gitignore
vendored
|
@ -48,7 +48,9 @@ Vagrantfile
|
|||
# Configs
|
||||
*.hcl
|
||||
!command/agent/config/test-fixtures/config.hcl
|
||||
!command/agent/config/test-fixtures/config-cache.hcl
|
||||
!command/agent/config/test-fixtures/config-embedded-type.hcl
|
||||
!command/agent/config/test-fixtures/config-cache-embedded-type.hcl
|
||||
|
||||
.DS_Store
|
||||
.idea
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
## Next
|
||||
## 1.0.3 (February 12th, 2019)
|
||||
|
||||
CHANGES:
|
||||
|
||||
|
|
|
@ -25,6 +25,7 @@ import (
|
|||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
const EnvVaultAgentAddress = "VAULT_AGENT_ADDR"
|
||||
const EnvVaultAddress = "VAULT_ADDR"
|
||||
const EnvVaultCACert = "VAULT_CACERT"
|
||||
const EnvVaultCAPath = "VAULT_CAPATH"
|
||||
|
@ -237,6 +238,10 @@ func (c *Config) ReadEnvironment() error {
|
|||
if v := os.Getenv(EnvVaultAddress); v != "" {
|
||||
envAddress = v
|
||||
}
|
||||
// Agent's address will take precedence over Vault's address
|
||||
if v := os.Getenv(EnvVaultAgentAddress); v != "" {
|
||||
envAddress = v
|
||||
}
|
||||
if v := os.Getenv(EnvVaultMaxRetries); v != "" {
|
||||
maxRetries, err := strconv.ParseUint(v, 10, 32)
|
||||
if err != nil {
|
||||
|
@ -366,11 +371,6 @@ func NewClient(c *Config) (*Client, error) {
|
|||
c.modifyLock.Lock()
|
||||
defer c.modifyLock.Unlock()
|
||||
|
||||
u, err := url.Parse(c.Address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if c.HttpClient == nil {
|
||||
c.HttpClient = def.HttpClient
|
||||
}
|
||||
|
@ -378,6 +378,21 @@ func NewClient(c *Config) (*Client, error) {
|
|||
c.HttpClient.Transport = def.HttpClient.Transport
|
||||
}
|
||||
|
||||
if strings.HasPrefix(c.Address, "unix://") {
|
||||
socket := strings.TrimPrefix(c.Address, "unix://")
|
||||
transport := c.HttpClient.Transport.(*http.Transport)
|
||||
transport.DialContext = func(context.Context, string, string) (net.Conn, error) {
|
||||
return net.Dial("unix", socket)
|
||||
}
|
||||
// TODO: This shouldn't ideally be done. To be fixed post 1.1-beta.
|
||||
c.Address = "http://unix"
|
||||
}
|
||||
|
||||
u, err := url.Parse(c.Address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
client := &Client{
|
||||
addr: u,
|
||||
config: c,
|
||||
|
|
|
@ -292,6 +292,7 @@ type SecretAuth struct {
|
|||
TokenPolicies []string `json:"token_policies"`
|
||||
IdentityPolicies []string `json:"identity_policies"`
|
||||
Metadata map[string]string `json:"metadata"`
|
||||
Orphan bool `json:"orphan"`
|
||||
|
||||
LeaseDuration int `json:"lease_duration"`
|
||||
Renewable bool `json:"renewable"`
|
||||
|
|
|
@ -25,14 +25,17 @@ func pathConfig(b *backend) *framework.Path {
|
|||
Description: `The API endpoint to use. Useful if you
|
||||
are running GitHub Enterprise or an
|
||||
API-compatible authentication server.`,
|
||||
DisplayName: "Base URL",
|
||||
},
|
||||
"ttl": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
Description: `Duration after which authentication will be expired`,
|
||||
DisplayName: "TTL",
|
||||
},
|
||||
"max_ttl": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
Description: `Maximum duration after which authentication will be expired`,
|
||||
DisplayName: "Max TTL",
|
||||
},
|
||||
},
|
||||
|
||||
|
|
|
@ -25,26 +25,32 @@ func pathConfig(b *backend) *framework.Path {
|
|||
"organization": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
Description: "(DEPRECATED) Okta organization to authenticate against. Use org_name instead.",
|
||||
Deprecated: true,
|
||||
},
|
||||
"org_name": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
Description: "Name of the organization to be used in the Okta API.",
|
||||
DisplayName: "Organization Name",
|
||||
},
|
||||
"token": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
Description: "(DEPRECATED) Okta admin API token. Use api_token instead.",
|
||||
Deprecated: true,
|
||||
},
|
||||
"api_token": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
Description: "Okta API key.",
|
||||
DisplayName: "API Token",
|
||||
},
|
||||
"base_url": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
Description: `The base domain to use for the Okta API. When not specified in the configuration, "okta.com" is used.`,
|
||||
DisplayName: "Base URL",
|
||||
},
|
||||
"production": &framework.FieldSchema{
|
||||
Type: framework.TypeBool,
|
||||
Description: `(DEPRECATED) Use base_url.`,
|
||||
Deprecated: true,
|
||||
},
|
||||
"ttl": &framework.FieldSchema{
|
||||
Type: framework.TypeDurationSecond,
|
||||
|
@ -57,6 +63,7 @@ func pathConfig(b *backend) *framework.Path {
|
|||
"bypass_okta_mfa": &framework.FieldSchema{
|
||||
Type: framework.TypeBool,
|
||||
Description: `When set true, requests by Okta for a MFA check will be bypassed. This also disallows certain status checks on the account, such as whether the password is expired.`,
|
||||
DisplayName: "Bypass Okta MFA",
|
||||
},
|
||||
},
|
||||
|
||||
|
|
|
@ -15,6 +15,7 @@ func pathConfig(b *backend) *framework.Path {
|
|||
"host": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
Description: "RADIUS server host",
|
||||
DisplayName: "Host",
|
||||
},
|
||||
|
||||
"port": &framework.FieldSchema{
|
||||
|
@ -30,6 +31,7 @@ func pathConfig(b *backend) *framework.Path {
|
|||
Type: framework.TypeString,
|
||||
Default: "",
|
||||
Description: "Comma-separated list of policies to grant upon successful RADIUS authentication of an unregisted user (default: emtpy)",
|
||||
DisplayName: "Policies for unregistered users",
|
||||
},
|
||||
"dial_timeout": &framework.FieldSchema{
|
||||
Type: framework.TypeDurationSecond,
|
||||
|
@ -45,11 +47,13 @@ func pathConfig(b *backend) *framework.Path {
|
|||
Type: framework.TypeInt,
|
||||
Default: 10,
|
||||
Description: "RADIUS NAS port field (default: 10)",
|
||||
DisplayName: "NAS Port",
|
||||
},
|
||||
"nas_identifier": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
Default: "",
|
||||
Description: "RADIUS NAS Identifier field (optional)",
|
||||
DisplayName: "NAS Identifier",
|
||||
},
|
||||
},
|
||||
|
||||
|
|
|
@ -36,6 +36,7 @@ func pathRoles(b *backend) *framework.Path {
|
|||
"name": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
Description: "Name of the policy",
|
||||
DisplayName: "Policy Name",
|
||||
},
|
||||
|
||||
"credential_type": &framework.FieldSchema{
|
||||
|
@ -46,11 +47,13 @@ func pathRoles(b *backend) *framework.Path {
|
|||
"role_arns": &framework.FieldSchema{
|
||||
Type: framework.TypeCommaStringSlice,
|
||||
Description: "ARNs of AWS roles allowed to be assumed. Only valid when credential_type is " + assumedRoleCred,
|
||||
DisplayName: "Role ARNs",
|
||||
},
|
||||
|
||||
"policy_arns": &framework.FieldSchema{
|
||||
Type: framework.TypeCommaStringSlice,
|
||||
Description: "ARNs of AWS policies to attach to IAM users. Only valid when credential_type is " + iamUserCred,
|
||||
DisplayName: "Policy ARNs",
|
||||
},
|
||||
|
||||
"policy_document": &framework.FieldSchema{
|
||||
|
@ -65,22 +68,26 @@ GetFederationToken API call, acting as a filter on permissions available.`,
|
|||
"default_sts_ttl": &framework.FieldSchema{
|
||||
Type: framework.TypeDurationSecond,
|
||||
Description: fmt.Sprintf("Default TTL for %s and %s credential types when no TTL is explicitly requested with the credentials", assumedRoleCred, federationTokenCred),
|
||||
DisplayName: "Default TTL",
|
||||
},
|
||||
|
||||
"max_sts_ttl": &framework.FieldSchema{
|
||||
Type: framework.TypeDurationSecond,
|
||||
Description: fmt.Sprintf("Max allowed TTL for %s and %s credential types", assumedRoleCred, federationTokenCred),
|
||||
DisplayName: "Max TTL",
|
||||
},
|
||||
|
||||
"arn": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
Description: `Deprecated; use role_arns or policy_arns instead. ARN Reference to a managed policy
|
||||
or IAM role to assume`,
|
||||
Deprecated: true,
|
||||
},
|
||||
|
||||
"policy": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
Description: "Deprecated; use policy_document instead. IAM policy document",
|
||||
Deprecated: true,
|
||||
},
|
||||
},
|
||||
|
||||
|
|
|
@ -35,11 +35,12 @@ func NewPluginClient(ctx context.Context, sys pluginutil.RunnerUtil, pluginRunne
|
|||
|
||||
// pluginSets is the map of plugins we can dispense.
|
||||
pluginSets := map[int]plugin.PluginSet{
|
||||
// Version 3 supports both protocols
|
||||
// Version 3 used to supports both protocols. We want to keep it around
|
||||
// since it's possible old plugins built against this version will still
|
||||
// work with gRPC. There is currently no difference between version 3
|
||||
// and version 4.
|
||||
3: plugin.PluginSet{
|
||||
"database": &DatabasePlugin{
|
||||
GRPCDatabasePlugin: new(GRPCDatabasePlugin),
|
||||
},
|
||||
"database": new(GRPCDatabasePlugin),
|
||||
},
|
||||
// Version 4 only supports gRPC
|
||||
4: plugin.PluginSet{
|
||||
|
@ -76,9 +77,6 @@ func NewPluginClient(ctx context.Context, sys pluginutil.RunnerUtil, pluginRunne
|
|||
switch raw.(type) {
|
||||
case *gRPCClient:
|
||||
db = raw.(*gRPCClient)
|
||||
case *databasePluginRPCClient:
|
||||
logger.Warn("plugin is using deprecated netRPC transport, recompile plugin to upgrade to gRPC", "plugin", pluginRunner.Name)
|
||||
db = raw.(*databasePluginRPCClient)
|
||||
default:
|
||||
return nil, errors.New("unsupported client type")
|
||||
}
|
||||
|
|
|
@ -1,197 +0,0 @@
|
|||
package dbplugin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/rpc"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ---- RPC server domain ----
|
||||
|
||||
// databasePluginRPCServer implements an RPC version of Database and is run
|
||||
// inside a plugin. It wraps an underlying implementation of Database.
|
||||
type databasePluginRPCServer struct {
|
||||
impl Database
|
||||
}
|
||||
|
||||
func (ds *databasePluginRPCServer) Type(_ struct{}, resp *string) error {
|
||||
var err error
|
||||
*resp, err = ds.impl.Type()
|
||||
return err
|
||||
}
|
||||
|
||||
func (ds *databasePluginRPCServer) CreateUser(args *CreateUserRequestRPC, resp *CreateUserResponse) error {
|
||||
var err error
|
||||
resp.Username, resp.Password, err = ds.impl.CreateUser(context.Background(), args.Statements, args.UsernameConfig, args.Expiration)
|
||||
return err
|
||||
}
|
||||
|
||||
func (ds *databasePluginRPCServer) RenewUser(args *RenewUserRequestRPC, _ *struct{}) error {
|
||||
err := ds.impl.RenewUser(context.Background(), args.Statements, args.Username, args.Expiration)
|
||||
return err
|
||||
}
|
||||
|
||||
func (ds *databasePluginRPCServer) RevokeUser(args *RevokeUserRequestRPC, _ *struct{}) error {
|
||||
err := ds.impl.RevokeUser(context.Background(), args.Statements, args.Username)
|
||||
return err
|
||||
}
|
||||
|
||||
func (ds *databasePluginRPCServer) RotateRootCredentials(args *RotateRootCredentialsRequestRPC, resp *RotateRootCredentialsResponse) error {
|
||||
config, err := ds.impl.RotateRootCredentials(context.Background(), args.Statements)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resp.Config, err = json.Marshal(config)
|
||||
return err
|
||||
}
|
||||
|
||||
func (ds *databasePluginRPCServer) Initialize(args *InitializeRequestRPC, _ *struct{}) error {
|
||||
return ds.Init(&InitRequestRPC{
|
||||
Config: args.Config,
|
||||
VerifyConnection: args.VerifyConnection,
|
||||
}, &InitResponse{})
|
||||
}
|
||||
|
||||
func (ds *databasePluginRPCServer) Init(args *InitRequestRPC, resp *InitResponse) error {
|
||||
config, err := ds.impl.Init(context.Background(), args.Config, args.VerifyConnection)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resp.Config, err = json.Marshal(config)
|
||||
return err
|
||||
}
|
||||
|
||||
func (ds *databasePluginRPCServer) Close(_ struct{}, _ *struct{}) error {
|
||||
ds.impl.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
// ---- RPC client domain ----
|
||||
// databasePluginRPCClient implements Database and is used on the client to
|
||||
// make RPC calls to a plugin.
|
||||
type databasePluginRPCClient struct {
|
||||
client *rpc.Client
|
||||
}
|
||||
|
||||
func (dr *databasePluginRPCClient) Type() (string, error) {
|
||||
var dbType string
|
||||
err := dr.client.Call("Plugin.Type", struct{}{}, &dbType)
|
||||
|
||||
return fmt.Sprintf("plugin-%s", dbType), err
|
||||
}
|
||||
|
||||
func (dr *databasePluginRPCClient) CreateUser(_ context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) {
|
||||
req := CreateUserRequestRPC{
|
||||
Statements: statements,
|
||||
UsernameConfig: usernameConfig,
|
||||
Expiration: expiration,
|
||||
}
|
||||
|
||||
var resp CreateUserResponse
|
||||
err = dr.client.Call("Plugin.CreateUser", req, &resp)
|
||||
|
||||
return resp.Username, resp.Password, err
|
||||
}
|
||||
|
||||
func (dr *databasePluginRPCClient) RenewUser(_ context.Context, statements Statements, username string, expiration time.Time) error {
|
||||
req := RenewUserRequestRPC{
|
||||
Statements: statements,
|
||||
Username: username,
|
||||
Expiration: expiration,
|
||||
}
|
||||
|
||||
return dr.client.Call("Plugin.RenewUser", req, &struct{}{})
|
||||
}
|
||||
|
||||
func (dr *databasePluginRPCClient) RevokeUser(_ context.Context, statements Statements, username string) error {
|
||||
req := RevokeUserRequestRPC{
|
||||
Statements: statements,
|
||||
Username: username,
|
||||
}
|
||||
|
||||
return dr.client.Call("Plugin.RevokeUser", req, &struct{}{})
|
||||
}
|
||||
|
||||
func (dr *databasePluginRPCClient) RotateRootCredentials(_ context.Context, statements []string) (saveConf map[string]interface{}, err error) {
|
||||
req := RotateRootCredentialsRequestRPC{
|
||||
Statements: statements,
|
||||
}
|
||||
|
||||
var resp RotateRootCredentialsResponse
|
||||
err = dr.client.Call("Plugin.RotateRootCredentials", req, &resp)
|
||||
|
||||
err = json.Unmarshal(resp.Config, &saveConf)
|
||||
return saveConf, err
|
||||
}
|
||||
|
||||
func (dr *databasePluginRPCClient) Initialize(_ context.Context, conf map[string]interface{}, verifyConnection bool) error {
|
||||
_, err := dr.Init(nil, conf, verifyConnection)
|
||||
return err
|
||||
}
|
||||
|
||||
func (dr *databasePluginRPCClient) Init(_ context.Context, conf map[string]interface{}, verifyConnection bool) (saveConf map[string]interface{}, err error) {
|
||||
req := InitRequestRPC{
|
||||
Config: conf,
|
||||
VerifyConnection: verifyConnection,
|
||||
}
|
||||
|
||||
var resp InitResponse
|
||||
err = dr.client.Call("Plugin.Init", req, &resp)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "can't find method Plugin.Init") {
|
||||
req := InitializeRequestRPC{
|
||||
Config: conf,
|
||||
VerifyConnection: verifyConnection,
|
||||
}
|
||||
|
||||
err = dr.client.Call("Plugin.Initialize", req, &struct{}{})
|
||||
if err == nil {
|
||||
return conf, nil
|
||||
}
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = json.Unmarshal(resp.Config, &saveConf)
|
||||
return saveConf, err
|
||||
}
|
||||
|
||||
func (dr *databasePluginRPCClient) Close() error {
|
||||
return dr.client.Call("Plugin.Close", struct{}{}, &struct{}{})
|
||||
}
|
||||
|
||||
// ---- RPC Request Args Domain ----
|
||||
|
||||
type InitializeRequestRPC struct {
|
||||
Config map[string]interface{}
|
||||
VerifyConnection bool
|
||||
}
|
||||
|
||||
type InitRequestRPC struct {
|
||||
Config map[string]interface{}
|
||||
VerifyConnection bool
|
||||
}
|
||||
|
||||
type CreateUserRequestRPC struct {
|
||||
Statements Statements
|
||||
UsernameConfig UsernameConfig
|
||||
Expiration time.Time
|
||||
}
|
||||
|
||||
type RenewUserRequestRPC struct {
|
||||
Statements Statements
|
||||
Username string
|
||||
Expiration time.Time
|
||||
}
|
||||
|
||||
type RevokeUserRequestRPC struct {
|
||||
Statements Statements
|
||||
Username string
|
||||
}
|
||||
|
||||
type RotateRootCredentialsRequestRPC struct {
|
||||
Statements []string
|
||||
}
|
|
@ -3,7 +3,6 @@ package dbplugin
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/rpc"
|
||||
"time"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
|
@ -72,8 +71,6 @@ func PluginFactory(ctx context.Context, pluginName string, sys pluginutil.LookRu
|
|||
switch db.(*DatabasePluginClient).Database.(type) {
|
||||
case *gRPCClient:
|
||||
transport = "gRPC"
|
||||
case *databasePluginRPCClient:
|
||||
transport = "netRPC"
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -110,17 +107,9 @@ var handshakeConfig = plugin.HandshakeConfig{
|
|||
MagicCookieValue: "926a0820-aea2-be28-51d6-83cdf00e8edb",
|
||||
}
|
||||
|
||||
var _ plugin.Plugin = &DatabasePlugin{}
|
||||
var _ plugin.GRPCPlugin = &DatabasePlugin{}
|
||||
var _ plugin.Plugin = &GRPCDatabasePlugin{}
|
||||
var _ plugin.GRPCPlugin = &GRPCDatabasePlugin{}
|
||||
|
||||
// DatabasePlugin implements go-plugin's Plugin interface. It has methods for
|
||||
// retrieving a server and a client instance of the plugin.
|
||||
type DatabasePlugin struct {
|
||||
*GRPCDatabasePlugin
|
||||
}
|
||||
|
||||
// GRPCDatabasePlugin is the plugin.Plugin implementation that only supports GRPC
|
||||
// transport
|
||||
type GRPCDatabasePlugin struct {
|
||||
|
@ -130,17 +119,6 @@ type GRPCDatabasePlugin struct {
|
|||
plugin.NetRPCUnsupportedPlugin
|
||||
}
|
||||
|
||||
func (d DatabasePlugin) Server(*plugin.MuxBroker) (interface{}, error) {
|
||||
impl := &DatabaseErrorSanitizerMiddleware{
|
||||
next: d.Impl,
|
||||
}
|
||||
return &databasePluginRPCServer{impl: impl}, nil
|
||||
}
|
||||
|
||||
func (DatabasePlugin) Client(b *plugin.MuxBroker, c *rpc.Client) (interface{}, error) {
|
||||
return &databasePluginRPCClient{client: c}, nil
|
||||
}
|
||||
|
||||
func (d GRPCDatabasePlugin) GRPCServer(_ *plugin.GRPCBroker, s *grpc.Server) error {
|
||||
impl := &DatabaseErrorSanitizerMiddleware{
|
||||
next: d.Impl,
|
||||
|
|
|
@ -8,7 +8,6 @@ import (
|
|||
"time"
|
||||
|
||||
log "github.com/hashicorp/go-hclog"
|
||||
plugin "github.com/hashicorp/go-plugin"
|
||||
"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
|
||||
"github.com/hashicorp/vault/helper/consts"
|
||||
"github.com/hashicorp/vault/helper/namespace"
|
||||
|
@ -96,7 +95,6 @@ func getCluster(t *testing.T) (*vault.TestCluster, logical.SystemView) {
|
|||
|
||||
sys := vault.TestDynamicSystemView(cores[0].Core)
|
||||
vault.TestAddTestPlugin(t, cores[0].Core, "test-plugin", consts.PluginTypeDatabase, "TestPlugin_GRPC_Main", []string{}, "")
|
||||
vault.TestAddTestPlugin(t, cores[0].Core, "test-plugin-netRPC", consts.PluginTypeDatabase, "TestPlugin_NetRPC_Main", []string{}, "")
|
||||
|
||||
return cluster, sys
|
||||
}
|
||||
|
@ -121,31 +119,6 @@ func TestPlugin_GRPC_Main(t *testing.T) {
|
|||
plugins.Serve(plugin, apiClientMeta.GetTLSConfig())
|
||||
}
|
||||
|
||||
// This is not an actual test case, it's a helper function that will be executed
|
||||
// by the go-plugin client via an exec call.
|
||||
func TestPlugin_NetRPC_Main(t *testing.T) {
|
||||
if os.Getenv(pluginutil.PluginUnwrapTokenEnv) == "" {
|
||||
return
|
||||
}
|
||||
|
||||
os.Unsetenv(pluginutil.PluginVaultVersionEnv)
|
||||
p := &mockPlugin{
|
||||
users: make(map[string][]string),
|
||||
}
|
||||
|
||||
args := []string{"--tls-skip-verify=true"}
|
||||
|
||||
apiClientMeta := &pluginutil.APIClientMeta{}
|
||||
flags := apiClientMeta.FlagSet()
|
||||
flags.Parse(args)
|
||||
|
||||
tlsProvider := pluginutil.VaultPluginTLSProvider(apiClientMeta.GetTLSConfig())
|
||||
serveConf := dbplugin.ServeConfig(p, tlsProvider)
|
||||
serveConf.GRPCServer = nil
|
||||
|
||||
plugin.Serve(serveConf)
|
||||
}
|
||||
|
||||
func TestPlugin_Init(t *testing.T) {
|
||||
cluster, sys := getCluster(t)
|
||||
defer cluster.Cleanup()
|
||||
|
@ -284,143 +257,3 @@ func TestPlugin_RevokeUser(t *testing.T) {
|
|||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Test the code is still compatible with an old netRPC plugin
|
||||
func TestPlugin_NetRPC_Init(t *testing.T) {
|
||||
cluster, sys := getCluster(t)
|
||||
defer cluster.Cleanup()
|
||||
|
||||
dbRaw, err := dbplugin.PluginFactory(namespace.RootContext(nil), "test-plugin-netRPC", sys, log.NewNullLogger())
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
connectionDetails := map[string]interface{}{
|
||||
"test": 1,
|
||||
}
|
||||
|
||||
_, err = dbRaw.Init(context.Background(), connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
err = dbRaw.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPlugin_NetRPC_CreateUser(t *testing.T) {
|
||||
cluster, sys := getCluster(t)
|
||||
defer cluster.Cleanup()
|
||||
|
||||
db, err := dbplugin.PluginFactory(namespace.RootContext(nil), "test-plugin-netRPC", sys, log.NewNullLogger())
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
connectionDetails := map[string]interface{}{
|
||||
"test": 1,
|
||||
}
|
||||
|
||||
_, err = db.Init(context.Background(), connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
usernameConf := dbplugin.UsernameConfig{
|
||||
DisplayName: "test",
|
||||
RoleName: "test",
|
||||
}
|
||||
|
||||
us, pw, err := db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
if us != "test" || pw != "test" {
|
||||
t.Fatal("expected username and password to be 'test'")
|
||||
}
|
||||
|
||||
// try and save the same user again to verify it saved the first time, this
|
||||
// should return an error
|
||||
_, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
|
||||
if err == nil {
|
||||
t.Fatal("expected an error, user wasn't created correctly")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPlugin_NetRPC_RenewUser(t *testing.T) {
|
||||
cluster, sys := getCluster(t)
|
||||
defer cluster.Cleanup()
|
||||
|
||||
db, err := dbplugin.PluginFactory(namespace.RootContext(nil), "test-plugin-netRPC", sys, log.NewNullLogger())
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
connectionDetails := map[string]interface{}{
|
||||
"test": 1,
|
||||
}
|
||||
_, err = db.Init(context.Background(), connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
usernameConf := dbplugin.UsernameConfig{
|
||||
DisplayName: "test",
|
||||
RoleName: "test",
|
||||
}
|
||||
|
||||
us, _, err := db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
err = db.RenewUser(context.Background(), dbplugin.Statements{}, us, time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPlugin_NetRPC_RevokeUser(t *testing.T) {
|
||||
cluster, sys := getCluster(t)
|
||||
defer cluster.Cleanup()
|
||||
|
||||
db, err := dbplugin.PluginFactory(namespace.RootContext(nil), "test-plugin-netRPC", sys, log.NewNullLogger())
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
connectionDetails := map[string]interface{}{
|
||||
"test": 1,
|
||||
}
|
||||
_, err = db.Init(context.Background(), connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
usernameConf := dbplugin.UsernameConfig{
|
||||
DisplayName: "test",
|
||||
RoleName: "test",
|
||||
}
|
||||
|
||||
us, _, err := db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
// Test default revoke statements
|
||||
err = db.RevokeUser(context.Background(), dbplugin.Statements{}, us)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
// Try adding the same username back so we can verify it was removed
|
||||
_, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -4,7 +4,6 @@ import (
|
|||
"crypto/tls"
|
||||
|
||||
plugin "github.com/hashicorp/go-plugin"
|
||||
"github.com/hashicorp/vault/helper/pluginutil"
|
||||
)
|
||||
|
||||
// Serve is called from within a plugin and wraps the provided
|
||||
|
@ -17,11 +16,13 @@ func Serve(db Database, tlsProvider func() (*tls.Config, error)) {
|
|||
func ServeConfig(db Database, tlsProvider func() (*tls.Config, error)) *plugin.ServeConfig {
|
||||
// pluginSets is the map of plugins we can dispense.
|
||||
pluginSets := map[int]plugin.PluginSet{
|
||||
// Version 3 used to supports both protocols. We want to keep it around
|
||||
// since it's possible old plugins built against this version will still
|
||||
// work with gRPC. There is currently no difference between version 3
|
||||
// and version 4.
|
||||
3: plugin.PluginSet{
|
||||
"database": &DatabasePlugin{
|
||||
GRPCDatabasePlugin: &GRPCDatabasePlugin{
|
||||
Impl: db,
|
||||
},
|
||||
"database": &GRPCDatabasePlugin{
|
||||
Impl: db,
|
||||
},
|
||||
},
|
||||
4: plugin.PluginSet{
|
||||
|
@ -38,12 +39,5 @@ func ServeConfig(db Database, tlsProvider func() (*tls.Config, error)) *plugin.S
|
|||
GRPCServer: plugin.DefaultGRPCServer,
|
||||
}
|
||||
|
||||
// If we do not have gRPC support fallback to version 3
|
||||
// Remove this block in 0.13
|
||||
if !pluginutil.GRPCSupport() {
|
||||
conf.GRPCServer = nil
|
||||
delete(conf.VersionedPlugins, 4)
|
||||
}
|
||||
|
||||
return conf
|
||||
}
|
||||
|
|
|
@ -11,6 +11,7 @@ func addIssueAndSignCommonFields(fields map[string]*framework.FieldSchema) map[s
|
|||
Description: `If true, the Common Name will not be
|
||||
included in DNS or Email Subject Alternate Names.
|
||||
Defaults to false (CN is included).`,
|
||||
DisplayName: "Exclude Common Name from Subject Alternative Names (SANs)",
|
||||
}
|
||||
|
||||
fields["format"] = &framework.FieldSchema{
|
||||
|
@ -20,6 +21,7 @@ Defaults to false (CN is included).`,
|
|||
or "pem_bundle". If "pem_bundle" any private
|
||||
key and issuing cert will be appended to the
|
||||
certificate pem. Defaults to "pem".`,
|
||||
AllowedValues: []interface{}{"pem", "der", "pem_bundle"},
|
||||
}
|
||||
|
||||
fields["private_key_format"] = &framework.FieldSchema{
|
||||
|
@ -31,24 +33,28 @@ parameter as either base64-encoded DER or PEM-encoded DER.
|
|||
However, this can be set to "pkcs8" to have the returned
|
||||
private key contain base64-encoded pkcs8 or PEM-encoded
|
||||
pkcs8 instead. Defaults to "der".`,
|
||||
AllowedValues: []interface{}{"", "der", "pem", "pkcs8"},
|
||||
}
|
||||
|
||||
fields["ip_sans"] = &framework.FieldSchema{
|
||||
Type: framework.TypeCommaStringSlice,
|
||||
Description: `The requested IP SANs, if any, in a
|
||||
comma-delimited list`,
|
||||
DisplayName: "IP Subject Alternative Names (SANs)",
|
||||
}
|
||||
|
||||
fields["uri_sans"] = &framework.FieldSchema{
|
||||
Type: framework.TypeCommaStringSlice,
|
||||
Description: `The requested URI SANs, if any, in a
|
||||
comma-delimited list.`,
|
||||
DisplayName: "URI Subject Alternative Names (SANs)",
|
||||
}
|
||||
|
||||
fields["other_sans"] = &framework.FieldSchema{
|
||||
Type: framework.TypeCommaStringSlice,
|
||||
Description: `Requested other SANs, in an array with the format
|
||||
<oid>;UTF8:<utf8 string value> for each entry.`,
|
||||
DisplayName: "Other SANs",
|
||||
}
|
||||
|
||||
return fields
|
||||
|
@ -79,6 +85,7 @@ in the role, this may be an email address.`,
|
|||
in a comma-delimited list. If email protection
|
||||
is enabled for the role, this may contain
|
||||
email addresses.`,
|
||||
DisplayName: "DNS/Email Subject Alternative Names (SANs)",
|
||||
}
|
||||
|
||||
fields["serial_number"] = &framework.FieldSchema{
|
||||
|
@ -95,6 +102,7 @@ sets the expiration date. If not specified
|
|||
the role default, backend default, or system
|
||||
default TTL is used, in that order. Cannot
|
||||
be larger than the role max TTL.`,
|
||||
DisplayName: "TTL",
|
||||
}
|
||||
|
||||
return fields
|
||||
|
@ -110,6 +118,7 @@ func addCACommonFields(fields map[string]*framework.FieldSchema) map[string]*fra
|
|||
Description: `The requested Subject Alternative Names, if any,
|
||||
in a comma-delimited list. May contain both
|
||||
DNS names and email addresses.`,
|
||||
DisplayName: "DNS/Email Subject Alternative Names (SANs)",
|
||||
}
|
||||
|
||||
fields["common_name"] = &framework.FieldSchema{
|
||||
|
@ -131,12 +140,14 @@ be larger than the mount max TTL. Note:
|
|||
this only has an effect when generating
|
||||
a CA cert or signing a CA cert, not when
|
||||
generating a CSR for an intermediate CA.`,
|
||||
DisplayName: "TTL",
|
||||
}
|
||||
|
||||
fields["ou"] = &framework.FieldSchema{
|
||||
Type: framework.TypeCommaStringSlice,
|
||||
Description: `If set, OU (OrganizationalUnit) will be set to
|
||||
this value.`,
|
||||
DisplayName: "OU (Organizational Unit)",
|
||||
}
|
||||
|
||||
fields["organization"] = &framework.FieldSchema{
|
||||
|
@ -155,24 +166,28 @@ this value.`,
|
|||
Type: framework.TypeCommaStringSlice,
|
||||
Description: `If set, Locality will be set to
|
||||
this value.`,
|
||||
DisplayName: "Locality/City",
|
||||
}
|
||||
|
||||
fields["province"] = &framework.FieldSchema{
|
||||
Type: framework.TypeCommaStringSlice,
|
||||
Description: `If set, Province will be set to
|
||||
this value.`,
|
||||
DisplayName: "Province/State",
|
||||
}
|
||||
|
||||
fields["street_address"] = &framework.FieldSchema{
|
||||
Type: framework.TypeCommaStringSlice,
|
||||
Description: `If set, Street Address will be set to
|
||||
this value.`,
|
||||
DisplayName: "Street Address",
|
||||
}
|
||||
|
||||
fields["postal_code"] = &framework.FieldSchema{
|
||||
Type: framework.TypeCommaStringSlice,
|
||||
Description: `If set, Postal Code will be set to
|
||||
this value.`,
|
||||
DisplayName: "Postal Code",
|
||||
}
|
||||
|
||||
fields["serial_number"] = &framework.FieldSchema{
|
||||
|
@ -209,8 +224,8 @@ the key_type.`,
|
|||
Default: "rsa",
|
||||
Description: `The type of key to use; defaults to RSA. "rsa"
|
||||
and "ec" are the only valid values.`,
|
||||
AllowedValues: []interface{}{"rsa", "ec"},
|
||||
}
|
||||
|
||||
return fields
|
||||
}
|
||||
|
||||
|
@ -226,6 +241,7 @@ func addCAIssueFields(fields map[string]*framework.FieldSchema) map[string]*fram
|
|||
fields["permitted_dns_domains"] = &framework.FieldSchema{
|
||||
Type: framework.TypeCommaStringSlice,
|
||||
Description: `Domains for which this certificate is allowed to sign or issue child certificates. If set, all DNS names (subject and alt) on child certs must be exact matches or subsets of the given domains (see https://tools.ietf.org/html/rfc5280#section-4.2.1.10).`,
|
||||
DisplayName: "Permitted DNS Domains",
|
||||
}
|
||||
|
||||
return fields
|
||||
|
|
|
@ -31,6 +31,11 @@ func pathRoles(b *backend) *framework.Path {
|
|||
return &framework.Path{
|
||||
Pattern: "roles/" + framework.GenericNameRegex("name"),
|
||||
Fields: map[string]*framework.FieldSchema{
|
||||
"backend": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
Description: "Backend Type",
|
||||
},
|
||||
|
||||
"name": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
Description: "Name of the role",
|
||||
|
@ -42,11 +47,13 @@ func pathRoles(b *backend) *framework.Path {
|
|||
requested. The lease duration controls the expiration
|
||||
of certificates issued by this backend. Defaults to
|
||||
the value of max_ttl.`,
|
||||
DisplayName: "TTL",
|
||||
},
|
||||
|
||||
"max_ttl": &framework.FieldSchema{
|
||||
Type: framework.TypeDurationSecond,
|
||||
Description: "The maximum allowed lease duration",
|
||||
DisplayName: "Max TTL",
|
||||
},
|
||||
|
||||
"allow_localhost": &framework.FieldSchema{
|
||||
|
@ -107,17 +114,20 @@ CN and SANs. Defaults to true.`,
|
|||
Default: true,
|
||||
Description: `If set, IP Subject Alternative Names are allowed.
|
||||
Any valid IP is accepted.`,
|
||||
DisplayName: "Allow IP Subject Alternative Names",
|
||||
},
|
||||
|
||||
"allowed_uri_sans": &framework.FieldSchema{
|
||||
Type: framework.TypeCommaStringSlice,
|
||||
Description: `If set, an array of allowed URIs to put in the URI Subject Alternative Names.
|
||||
Any valid URI is accepted, these values support globbing.`,
|
||||
DisplayName: "Allowed URI Subject Alternative Names",
|
||||
},
|
||||
|
||||
"allowed_other_sans": &framework.FieldSchema{
|
||||
Type: framework.TypeCommaStringSlice,
|
||||
Description: `If set, an array of allowed other names to put in SANs. These values support globbing and must be in the format <oid>;<type>:<value>. Currently only "utf8" is a valid type. All values, including globbing values, must use this syntax, with the exception being a single "*" which allows any OID and any value (but type must still be utf8).`,
|
||||
DisplayName: "Allowed Other Subject Alternative Names",
|
||||
},
|
||||
|
||||
"allowed_serial_numbers": &framework.FieldSchema{
|
||||
|
@ -156,6 +166,7 @@ protection use. Defaults to false.`,
|
|||
Default: "rsa",
|
||||
Description: `The type of key to use; defaults to RSA. "rsa"
|
||||
and "ec" are the only valid values.`,
|
||||
AllowedValues: []interface{}{"rsa", "ec"},
|
||||
},
|
||||
|
||||
"key_bits": &framework.FieldSchema{
|
||||
|
@ -175,6 +186,7 @@ https://golang.org/pkg/crypto/x509/#KeyUsage
|
|||
-- simply drop the "KeyUsage" part of the name.
|
||||
To remove all key usages from being set, set
|
||||
this value to an empty list.`,
|
||||
DisplayValue: "DigitalSignature,KeyAgreement,KeyEncipherment",
|
||||
},
|
||||
|
||||
"ext_key_usage": &framework.FieldSchema{
|
||||
|
@ -185,11 +197,13 @@ https://golang.org/pkg/crypto/x509/#ExtKeyUsage
|
|||
-- simply drop the "ExtKeyUsage" part of the name.
|
||||
To remove all key usages from being set, set
|
||||
this value to an empty list.`,
|
||||
DisplayName: "Extended Key Usage",
|
||||
},
|
||||
|
||||
"ext_key_usage_oids": &framework.FieldSchema{
|
||||
Type: framework.TypeCommaStringSlice,
|
||||
Description: `A comma-separated string or list of extended key usage oids.`,
|
||||
DisplayName: "Extended Key Usage OIDs",
|
||||
},
|
||||
|
||||
"use_csr_common_name": &framework.FieldSchema{
|
||||
|
@ -199,6 +213,7 @@ this value to an empty list.`,
|
|||
the common name in the CSR will be used. This
|
||||
does *not* include any requested Subject Alternative
|
||||
Names. Defaults to true.`,
|
||||
DisplayName: "Use CSR Common Name",
|
||||
},
|
||||
|
||||
"use_csr_sans": &framework.FieldSchema{
|
||||
|
@ -207,12 +222,14 @@ Names. Defaults to true.`,
|
|||
Description: `If set, when used with a signing profile,
|
||||
the SANs in the CSR will be used. This does *not*
|
||||
include the Common Name (cn). Defaults to true.`,
|
||||
DisplayName: "Use CSR Subject Alternative Names",
|
||||
},
|
||||
|
||||
"ou": &framework.FieldSchema{
|
||||
Type: framework.TypeCommaStringSlice,
|
||||
Description: `If set, OU (OrganizationalUnit) will be set to
|
||||
this value in certificates issued by this role.`,
|
||||
DisplayName: "Organizational Unit",
|
||||
},
|
||||
|
||||
"organization": &framework.FieldSchema{
|
||||
|
@ -231,12 +248,14 @@ this value in certificates issued by this role.`,
|
|||
Type: framework.TypeCommaStringSlice,
|
||||
Description: `If set, Locality will be set to
|
||||
this value in certificates issued by this role.`,
|
||||
DisplayName: "Locality/City",
|
||||
},
|
||||
|
||||
"province": &framework.FieldSchema{
|
||||
Type: framework.TypeCommaStringSlice,
|
||||
Description: `If set, Province will be set to
|
||||
this value in certificates issued by this role.`,
|
||||
DisplayName: "Province/State",
|
||||
},
|
||||
|
||||
"street_address": &framework.FieldSchema{
|
||||
|
@ -263,6 +282,7 @@ to the CRL. When large number of certificates are generated with long
|
|||
lifetimes, it is recommended that lease generation be disabled, as large amount of
|
||||
leases adversely affect the startup time of Vault.`,
|
||||
},
|
||||
|
||||
"no_store": &framework.FieldSchema{
|
||||
Type: framework.TypeBool,
|
||||
Description: `
|
||||
|
@ -273,18 +293,23 @@ or revoked, so this option is recommended only for certificates that are
|
|||
non-sensitive, or extremely short-lived. This option implies a value of "false"
|
||||
for "generate_lease".`,
|
||||
},
|
||||
|
||||
"require_cn": &framework.FieldSchema{
|
||||
Type: framework.TypeBool,
|
||||
Default: true,
|
||||
Description: `If set to false, makes the 'common_name' field optional while generating a certificate.`,
|
||||
DisplayName: "Use CSR Common Name",
|
||||
},
|
||||
|
||||
"policy_identifiers": &framework.FieldSchema{
|
||||
Type: framework.TypeCommaStringSlice,
|
||||
Description: `A comma-separated string or list of policy oids.`,
|
||||
},
|
||||
|
||||
"basic_constraints_valid_for_non_ca": &framework.FieldSchema{
|
||||
Type: framework.TypeBool,
|
||||
Description: `Mark Basic Constraints valid when issuing non-CA certificates.`,
|
||||
DisplayName: "Basic Constraints Valid for Non-CA",
|
||||
},
|
||||
"not_before_duration": &framework.FieldSchema{
|
||||
Type: framework.TypeDurationSecond,
|
||||
|
|
|
@ -93,6 +93,7 @@ func pathRoles(b *backend) *framework.Path {
|
|||
credential is being generated for other users, Vault uses this admin
|
||||
username to login to remote host and install the generated credential
|
||||
for the other user.`,
|
||||
DisplayName: "Admin Username",
|
||||
},
|
||||
"default_user": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
|
@ -101,6 +102,7 @@ func pathRoles(b *backend) *framework.Path {
|
|||
Default username for which a credential will be generated.
|
||||
When the endpoint 'creds/' is used without a username, this
|
||||
value will be used as default username.`,
|
||||
DisplayName: "Default Username",
|
||||
},
|
||||
"cidr_list": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
|
@ -108,6 +110,7 @@ func pathRoles(b *backend) *framework.Path {
|
|||
[Optional for Dynamic type] [Optional for OTP type] [Not applicable for CA type]
|
||||
Comma separated list of CIDR blocks for which the role is applicable for.
|
||||
CIDR blocks can belong to more than one role.`,
|
||||
DisplayName: "CIDR List",
|
||||
},
|
||||
"exclude_cidr_list": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
|
@ -116,6 +119,7 @@ func pathRoles(b *backend) *framework.Path {
|
|||
Comma separated list of CIDR blocks. IP addresses belonging to these blocks are not
|
||||
accepted by the role. This is particularly useful when big CIDR blocks are being used
|
||||
by the role and certain parts of it needs to be kept out.`,
|
||||
DisplayName: "Exclude CIDR List",
|
||||
},
|
||||
"port": &framework.FieldSchema{
|
||||
Type: framework.TypeInt,
|
||||
|
@ -125,6 +129,7 @@ func pathRoles(b *backend) *framework.Path {
|
|||
play any role in creation of OTP. For 'otp' type, this is just a way
|
||||
to inform client about the port number to use. Port number will be
|
||||
returned to client by Vault server along with OTP.`,
|
||||
DisplayValue: 22,
|
||||
},
|
||||
"key_type": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
|
@ -132,6 +137,8 @@ func pathRoles(b *backend) *framework.Path {
|
|||
[Required for all types]
|
||||
Type of key used to login to hosts. It can be either 'otp', 'dynamic' or 'ca'.
|
||||
'otp' type requires agent to be installed in remote hosts.`,
|
||||
AllowedValues: []interface{}{"otp", "dynamic","ca"},
|
||||
DisplayValue: "ca",
|
||||
},
|
||||
"key_bits": &framework.FieldSchema{
|
||||
Type: framework.TypeInt,
|
||||
|
@ -188,6 +195,7 @@ func pathRoles(b *backend) *framework.Path {
|
|||
requested. The lease duration controls the expiration
|
||||
of certificates issued by this backend. Defaults to
|
||||
the value of max_ttl.`,
|
||||
DisplayName: "TTL",
|
||||
},
|
||||
"max_ttl": &framework.FieldSchema{
|
||||
Type: framework.TypeDurationSecond,
|
||||
|
@ -195,6 +203,7 @@ func pathRoles(b *backend) *framework.Path {
|
|||
[Not applicable for Dynamic type] [Not applicable for OTP type] [Optional for CA type]
|
||||
The maximum allowed lease duration
|
||||
`,
|
||||
DisplayName: "Max TTL",
|
||||
},
|
||||
"allowed_critical_options": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
|
@ -202,7 +211,7 @@ func pathRoles(b *backend) *framework.Path {
|
|||
[Not applicable for Dynamic type] [Not applicable for OTP type] [Optional for CA type]
|
||||
A comma-separated list of critical options that certificates can have when signed.
|
||||
To allow any critical options, set this to an empty string.
|
||||
`,
|
||||
`,
|
||||
},
|
||||
"allowed_extensions": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
|
@ -238,7 +247,7 @@ func pathRoles(b *backend) *framework.Path {
|
|||
[Not applicable for Dynamic type] [Not applicable for OTP type] [Optional for CA type]
|
||||
If set, certificates are allowed to be signed for use as a 'user'.
|
||||
`,
|
||||
Default: false,
|
||||
Default: false,
|
||||
},
|
||||
"allow_host_certificates": &framework.FieldSchema{
|
||||
Type: framework.TypeBool,
|
||||
|
@ -246,7 +255,7 @@ func pathRoles(b *backend) *framework.Path {
|
|||
[Not applicable for Dynamic type] [Not applicable for OTP type] [Optional for CA type]
|
||||
If set, certificates are allowed to be signed for use as a 'host'.
|
||||
`,
|
||||
Default: false,
|
||||
Default: false,
|
||||
},
|
||||
"allow_bare_domains": &framework.FieldSchema{
|
||||
Type: framework.TypeBool,
|
||||
|
@ -272,6 +281,7 @@ func pathRoles(b *backend) *framework.Path {
|
|||
When false, the key ID will always be the token display name.
|
||||
The key ID is logged by the SSH server and can be useful for auditing.
|
||||
`,
|
||||
DisplayName: "Allow User Key IDs",
|
||||
},
|
||||
"key_id_format": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
|
@ -282,6 +292,7 @@ func pathRoles(b *backend) *framework.Path {
|
|||
the token used to make the request. '{{role_name}}' - The name of the role signing the request.
|
||||
'{{public_key_hash}}' - A SHA256 checksum of the public key that is being signed.
|
||||
`,
|
||||
DisplayName: "Key ID Format",
|
||||
},
|
||||
"allowed_user_key_lengths": &framework.FieldSchema{
|
||||
Type: framework.TypeMap,
|
||||
|
|
102
command/agent.go
102
command/agent.go
|
@ -4,6 +4,10 @@ import (
|
|||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"os"
|
||||
"sort"
|
||||
"strings"
|
||||
|
@ -23,6 +27,7 @@ import (
|
|||
"github.com/hashicorp/vault/command/agent/auth/gcp"
|
||||
"github.com/hashicorp/vault/command/agent/auth/jwt"
|
||||
"github.com/hashicorp/vault/command/agent/auth/kubernetes"
|
||||
"github.com/hashicorp/vault/command/agent/cache"
|
||||
"github.com/hashicorp/vault/command/agent/config"
|
||||
"github.com/hashicorp/vault/command/agent/sink"
|
||||
"github.com/hashicorp/vault/command/agent/sink/file"
|
||||
|
@ -218,19 +223,6 @@ func (c *AgentCommand) Run(args []string) int {
|
|||
info["cgo"] = "enabled"
|
||||
}
|
||||
|
||||
// Server configuration output
|
||||
padding := 24
|
||||
sort.Strings(infoKeys)
|
||||
c.UI.Output("==> Vault agent configuration:\n")
|
||||
for _, k := range infoKeys {
|
||||
c.UI.Output(fmt.Sprintf(
|
||||
"%s%s: %s",
|
||||
strings.Repeat(" ", padding-len(k)),
|
||||
strings.Title(k),
|
||||
info[k]))
|
||||
}
|
||||
c.UI.Output("")
|
||||
|
||||
// Tests might not want to start a vault server and just want to verify
|
||||
// the configuration.
|
||||
if c.flagTestVerifyOnly {
|
||||
|
@ -332,10 +324,92 @@ func (c *AgentCommand) Run(args []string) int {
|
|||
EnableReauthOnNewCredentials: config.AutoAuth.EnableReauthOnNewCredentials,
|
||||
})
|
||||
|
||||
// Start things running
|
||||
// Start auto-auth and sink servers
|
||||
go ah.Run(ctx, method)
|
||||
go ss.Run(ctx, ah.OutputCh, sinks)
|
||||
|
||||
// Parse agent listener configurations
|
||||
if config.Cache != nil && len(config.Cache.Listeners) != 0 {
|
||||
cacheLogger := c.logger.Named("cache")
|
||||
|
||||
// Create the API proxier
|
||||
apiProxy := cache.NewAPIProxy(&cache.APIProxyConfig{
|
||||
Logger: cacheLogger.Named("apiproxy"),
|
||||
})
|
||||
|
||||
// Create the lease cache proxier and set its underlying proxier to
|
||||
// the API proxier.
|
||||
leaseCache, err := cache.NewLeaseCache(&cache.LeaseCacheConfig{
|
||||
BaseContext: ctx,
|
||||
Proxier: apiProxy,
|
||||
Logger: cacheLogger.Named("leasecache"),
|
||||
})
|
||||
if err != nil {
|
||||
c.UI.Error(fmt.Sprintf("Error creating lease cache: %v", err))
|
||||
return 1
|
||||
}
|
||||
|
||||
// Create a muxer and add paths relevant for the lease cache layer
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle("/v1/agent/cache-clear", leaseCache.HandleCacheClear(ctx))
|
||||
|
||||
mux.Handle("/", cache.Handler(ctx, cacheLogger, leaseCache, config.Cache.UseAutoAuthToken, c.client))
|
||||
|
||||
var listeners []net.Listener
|
||||
for i, lnConfig := range config.Cache.Listeners {
|
||||
listener, props, _, err := cache.ServerListener(lnConfig, c.logWriter, c.UI)
|
||||
if err != nil {
|
||||
c.UI.Error(fmt.Sprintf("Error parsing listener configuration: %v", err))
|
||||
return 1
|
||||
}
|
||||
|
||||
listeners = append(listeners, listener)
|
||||
|
||||
scheme := "https://"
|
||||
if props["tls"] == "disabled" {
|
||||
scheme = "http://"
|
||||
}
|
||||
if lnConfig.Type == "unix" {
|
||||
scheme = "unix://"
|
||||
}
|
||||
|
||||
infoKey := fmt.Sprintf("api address %d", i+1)
|
||||
info[infoKey] = scheme + listener.Addr().String()
|
||||
infoKeys = append(infoKeys, infoKey)
|
||||
|
||||
cacheLogger.Info("starting listener", "addr", listener.Addr().String())
|
||||
server := &http.Server{
|
||||
Handler: mux,
|
||||
ReadHeaderTimeout: 10 * time.Second,
|
||||
ReadTimeout: 30 * time.Second,
|
||||
IdleTimeout: 5 * time.Minute,
|
||||
ErrorLog: cacheLogger.StandardLogger(nil),
|
||||
}
|
||||
go server.Serve(listener)
|
||||
}
|
||||
|
||||
// Ensure that listeners are closed at all the exits
|
||||
listenerCloseFunc := func() {
|
||||
for _, ln := range listeners {
|
||||
ln.Close()
|
||||
}
|
||||
}
|
||||
defer c.cleanupGuard.Do(listenerCloseFunc)
|
||||
}
|
||||
|
||||
// Server configuration output
|
||||
padding := 24
|
||||
sort.Strings(infoKeys)
|
||||
c.UI.Output("==> Vault agent configuration:\n")
|
||||
for _, k := range infoKeys {
|
||||
c.UI.Output(fmt.Sprintf(
|
||||
"%s%s: %s",
|
||||
strings.Repeat(" ", padding-len(k)),
|
||||
strings.Title(k),
|
||||
info[k]))
|
||||
}
|
||||
c.UI.Output("")
|
||||
|
||||
// Release the log gate.
|
||||
c.logGate.Flush()
|
||||
|
||||
|
|
61
command/agent/cache/api_proxy.go
vendored
Normal file
61
command/agent/cache/api_proxy.go
vendored
Normal file
|
@ -0,0 +1,61 @@
|
|||
package cache
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io/ioutil"
|
||||
|
||||
hclog "github.com/hashicorp/go-hclog"
|
||||
"github.com/hashicorp/vault/api"
|
||||
)
|
||||
|
||||
// APIProxy is an implementation of the proxier interface that is used to
|
||||
// forward the request to Vault and get the response.
|
||||
type APIProxy struct {
|
||||
logger hclog.Logger
|
||||
}
|
||||
|
||||
type APIProxyConfig struct {
|
||||
Logger hclog.Logger
|
||||
}
|
||||
|
||||
func NewAPIProxy(config *APIProxyConfig) Proxier {
|
||||
return &APIProxy{
|
||||
logger: config.Logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (ap *APIProxy) Send(ctx context.Context, req *SendRequest) (*SendResponse, error) {
|
||||
client, err := api.NewClient(api.DefaultConfig())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
client.SetToken(req.Token)
|
||||
client.SetHeaders(req.Request.Header)
|
||||
|
||||
fwReq := client.NewRequest(req.Request.Method, req.Request.URL.Path)
|
||||
fwReq.BodyBytes = req.RequestBody
|
||||
|
||||
// Make the request to Vault and get the response
|
||||
ap.logger.Info("forwarding request", "path", req.Request.URL.Path, "method", req.Request.Method)
|
||||
resp, err := client.RawRequestWithContext(ctx, fwReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Parse and reset response body
|
||||
respBody, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
ap.logger.Error("failed to read request body", "error", err)
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
resp.Body.Close()
|
||||
}
|
||||
resp.Body = ioutil.NopCloser(bytes.NewBuffer(respBody))
|
||||
|
||||
return &SendResponse{
|
||||
Response: resp,
|
||||
ResponseBody: respBody,
|
||||
}, nil
|
||||
}
|
43
command/agent/cache/api_proxy_test.go
vendored
Normal file
43
command/agent/cache/api_proxy_test.go
vendored
Normal file
|
@ -0,0 +1,43 @@
|
|||
package cache
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
hclog "github.com/hashicorp/go-hclog"
|
||||
"github.com/hashicorp/vault/api"
|
||||
"github.com/hashicorp/vault/helper/jsonutil"
|
||||
"github.com/hashicorp/vault/helper/logging"
|
||||
"github.com/hashicorp/vault/helper/namespace"
|
||||
)
|
||||
|
||||
func TestCache_APIProxy(t *testing.T) {
|
||||
cleanup, client, _, _ := setupClusterAndAgent(namespace.RootContext(nil), t, nil)
|
||||
defer cleanup()
|
||||
|
||||
proxier := NewAPIProxy(&APIProxyConfig{
|
||||
Logger: logging.NewVaultLogger(hclog.Trace),
|
||||
})
|
||||
|
||||
r := client.NewRequest("GET", "/v1/sys/health")
|
||||
req, err := r.ToHTTP()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
resp, err := proxier.Send(namespace.RootContext(nil), &SendRequest{
|
||||
Request: req,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var result api.HealthResponse
|
||||
err = jsonutil.DecodeJSONFromReader(resp.Response.Body, &result)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !result.Initialized || result.Sealed || result.Standby {
|
||||
t.Fatalf("bad sys/health response")
|
||||
}
|
||||
}
|
926
command/agent/cache/cache_test.go
vendored
Normal file
926
command/agent/cache/cache_test.go
vendored
Normal file
|
@ -0,0 +1,926 @@
|
|||
package cache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/logical"
|
||||
|
||||
"github.com/go-test/deep"
|
||||
hclog "github.com/hashicorp/go-hclog"
|
||||
kv "github.com/hashicorp/vault-plugin-secrets-kv"
|
||||
"github.com/hashicorp/vault/api"
|
||||
"github.com/hashicorp/vault/builtin/credential/userpass"
|
||||
"github.com/hashicorp/vault/helper/logging"
|
||||
"github.com/hashicorp/vault/helper/namespace"
|
||||
vaulthttp "github.com/hashicorp/vault/http"
|
||||
"github.com/hashicorp/vault/vault"
|
||||
)
|
||||
|
||||
const policyAdmin = `
|
||||
path "*" {
|
||||
capabilities = ["sudo", "create", "read", "update", "delete", "list"]
|
||||
}
|
||||
`
|
||||
|
||||
// setupClusterAndAgent is a helper func used to set up a test cluster and
|
||||
// caching agent. It returns a cleanup func that should be deferred immediately
|
||||
// along with two clients, one for direct cluster communication and another to
|
||||
// talk to the caching agent.
|
||||
func setupClusterAndAgent(ctx context.Context, t *testing.T, coreConfig *vault.CoreConfig) (func(), *api.Client, *api.Client, *LeaseCache) {
|
||||
t.Helper()
|
||||
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
// Handle sane defaults
|
||||
if coreConfig == nil {
|
||||
coreConfig = &vault.CoreConfig{
|
||||
DisableMlock: true,
|
||||
DisableCache: true,
|
||||
Logger: logging.NewVaultLogger(hclog.Trace),
|
||||
CredentialBackends: map[string]logical.Factory{
|
||||
"userpass": userpass.Factory,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
if coreConfig.CredentialBackends == nil {
|
||||
coreConfig.CredentialBackends = map[string]logical.Factory{
|
||||
"userpass": userpass.Factory,
|
||||
}
|
||||
}
|
||||
|
||||
// Init new test cluster
|
||||
cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{
|
||||
HandlerFunc: vaulthttp.Handler,
|
||||
})
|
||||
cluster.Start()
|
||||
|
||||
cores := cluster.Cores
|
||||
vault.TestWaitActive(t, cores[0].Core)
|
||||
|
||||
// clusterClient is the client that is used to talk directly to the cluster.
|
||||
clusterClient := cores[0].Client
|
||||
|
||||
// Add an admin policy
|
||||
if err := clusterClient.Sys().PutPolicy("admin", policyAdmin); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Set up the userpass auth backend and an admin user. Used for getting a token
|
||||
// for the agent later down in this func.
|
||||
clusterClient.Sys().EnableAuthWithOptions("userpass", &api.EnableAuthOptions{
|
||||
Type: "userpass",
|
||||
})
|
||||
|
||||
_, err := clusterClient.Logical().Write("auth/userpass/users/foo", map[string]interface{}{
|
||||
"password": "bar",
|
||||
"policies": []string{"admin"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Set up env vars for agent consumption
|
||||
origEnvVaultAddress := os.Getenv(api.EnvVaultAddress)
|
||||
os.Setenv(api.EnvVaultAddress, clusterClient.Address())
|
||||
|
||||
origEnvVaultCACert := os.Getenv(api.EnvVaultCACert)
|
||||
os.Setenv(api.EnvVaultCACert, fmt.Sprintf("%s/ca_cert.pem", cluster.TempDir))
|
||||
|
||||
cacheLogger := logging.NewVaultLogger(hclog.Trace).Named("cache")
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Create the API proxier
|
||||
apiProxy := NewAPIProxy(&APIProxyConfig{
|
||||
Logger: cacheLogger.Named("apiproxy"),
|
||||
})
|
||||
|
||||
// Create the lease cache proxier and set its underlying proxier to
|
||||
// the API proxier.
|
||||
leaseCache, err := NewLeaseCache(&LeaseCacheConfig{
|
||||
BaseContext: ctx,
|
||||
Proxier: apiProxy,
|
||||
Logger: cacheLogger.Named("leasecache"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Create a muxer and add paths relevant for the lease cache layer
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle("/v1/agent/cache-clear", leaseCache.HandleCacheClear(ctx))
|
||||
|
||||
mux.Handle("/", Handler(ctx, cacheLogger, leaseCache, false, clusterClient))
|
||||
server := &http.Server{
|
||||
Handler: mux,
|
||||
ReadHeaderTimeout: 10 * time.Second,
|
||||
ReadTimeout: 30 * time.Second,
|
||||
IdleTimeout: 5 * time.Minute,
|
||||
ErrorLog: cacheLogger.StandardLogger(nil),
|
||||
}
|
||||
go server.Serve(listener)
|
||||
|
||||
// testClient is the client that is used to talk to the agent for proxying/caching behavior.
|
||||
testClient, err := clusterClient.Clone()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := testClient.SetAddress("http://" + listener.Addr().String()); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Login via userpass method to derive a managed token. Set that token as the
|
||||
// testClient's token
|
||||
resp, err := testClient.Logical().Write("auth/userpass/login/foo", map[string]interface{}{
|
||||
"password": "bar",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
testClient.SetToken(resp.Auth.ClientToken)
|
||||
|
||||
cleanup := func() {
|
||||
cluster.Cleanup()
|
||||
os.Setenv(api.EnvVaultAddress, origEnvVaultAddress)
|
||||
os.Setenv(api.EnvVaultCACert, origEnvVaultCACert)
|
||||
listener.Close()
|
||||
}
|
||||
|
||||
return cleanup, clusterClient, testClient, leaseCache
|
||||
}
|
||||
|
||||
func tokenRevocationValidation(t *testing.T, sampleSpace map[string]string, expected map[string]string, leaseCache *LeaseCache) {
|
||||
t.Helper()
|
||||
for val, valType := range sampleSpace {
|
||||
index, err := leaseCache.db.Get(valType, val)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if expected[val] == "" && index != nil {
|
||||
t.Fatalf("failed to evict index from the cache: type: %q, value: %q", valType, val)
|
||||
}
|
||||
if expected[val] != "" && index == nil {
|
||||
t.Fatalf("evicted an undesired index from cache: type: %q, value: %q", valType, val)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCache_TokenRevocations_RevokeOrphan(t *testing.T) {
|
||||
coreConfig := &vault.CoreConfig{
|
||||
DisableMlock: true,
|
||||
DisableCache: true,
|
||||
Logger: hclog.NewNullLogger(),
|
||||
LogicalBackends: map[string]logical.Factory{
|
||||
"kv": vault.LeasedPassthroughBackendFactory,
|
||||
},
|
||||
}
|
||||
|
||||
sampleSpace := make(map[string]string)
|
||||
|
||||
cleanup, _, testClient, leaseCache := setupClusterAndAgent(namespace.RootContext(nil), t, coreConfig)
|
||||
defer cleanup()
|
||||
|
||||
token1 := testClient.Token()
|
||||
sampleSpace[token1] = "token"
|
||||
|
||||
// Mount the kv backend
|
||||
err := testClient.Sys().Mount("kv", &api.MountInput{
|
||||
Type: "kv",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Create a secret in the backend
|
||||
_, err = testClient.Logical().Write("kv/foo", map[string]interface{}{
|
||||
"value": "bar",
|
||||
"ttl": "1h",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Read the secret and create a lease
|
||||
leaseResp, err := testClient.Logical().Read("kv/foo")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
lease1 := leaseResp.LeaseID
|
||||
sampleSpace[lease1] = "lease"
|
||||
|
||||
resp, err := testClient.Logical().Write("auth/token/create", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
token2 := resp.Auth.ClientToken
|
||||
sampleSpace[token2] = "token"
|
||||
|
||||
testClient.SetToken(token2)
|
||||
|
||||
leaseResp, err = testClient.Logical().Read("kv/foo")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
lease2 := leaseResp.LeaseID
|
||||
sampleSpace[lease2] = "lease"
|
||||
|
||||
resp, err = testClient.Logical().Write("auth/token/create", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
token3 := resp.Auth.ClientToken
|
||||
sampleSpace[token3] = "token"
|
||||
|
||||
testClient.SetToken(token3)
|
||||
|
||||
leaseResp, err = testClient.Logical().Read("kv/foo")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
lease3 := leaseResp.LeaseID
|
||||
sampleSpace[lease3] = "lease"
|
||||
|
||||
expected := make(map[string]string)
|
||||
for k, v := range sampleSpace {
|
||||
expected[k] = v
|
||||
}
|
||||
tokenRevocationValidation(t, sampleSpace, expected, leaseCache)
|
||||
|
||||
// Revoke-orphan the intermediate token. This should result in its own
|
||||
// eviction and evictions of the revoked token's leases. All other things
|
||||
// including the child tokens and leases of the child tokens should be
|
||||
// untouched.
|
||||
testClient.SetToken(token2)
|
||||
err = testClient.Auth().Token().RevokeOrphan(token2)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
expected = map[string]string{
|
||||
token1: "token",
|
||||
lease1: "lease",
|
||||
token3: "token",
|
||||
lease3: "lease",
|
||||
}
|
||||
tokenRevocationValidation(t, sampleSpace, expected, leaseCache)
|
||||
}
|
||||
|
||||
func TestCache_TokenRevocations_LeafLevelToken(t *testing.T) {
|
||||
coreConfig := &vault.CoreConfig{
|
||||
DisableMlock: true,
|
||||
DisableCache: true,
|
||||
Logger: hclog.NewNullLogger(),
|
||||
LogicalBackends: map[string]logical.Factory{
|
||||
"kv": vault.LeasedPassthroughBackendFactory,
|
||||
},
|
||||
}
|
||||
|
||||
sampleSpace := make(map[string]string)
|
||||
|
||||
cleanup, _, testClient, leaseCache := setupClusterAndAgent(namespace.RootContext(nil), t, coreConfig)
|
||||
defer cleanup()
|
||||
|
||||
token1 := testClient.Token()
|
||||
sampleSpace[token1] = "token"
|
||||
|
||||
// Mount the kv backend
|
||||
err := testClient.Sys().Mount("kv", &api.MountInput{
|
||||
Type: "kv",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Create a secret in the backend
|
||||
_, err = testClient.Logical().Write("kv/foo", map[string]interface{}{
|
||||
"value": "bar",
|
||||
"ttl": "1h",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Read the secret and create a lease
|
||||
leaseResp, err := testClient.Logical().Read("kv/foo")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
lease1 := leaseResp.LeaseID
|
||||
sampleSpace[lease1] = "lease"
|
||||
|
||||
resp, err := testClient.Logical().Write("auth/token/create", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
token2 := resp.Auth.ClientToken
|
||||
sampleSpace[token2] = "token"
|
||||
|
||||
testClient.SetToken(token2)
|
||||
|
||||
leaseResp, err = testClient.Logical().Read("kv/foo")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
lease2 := leaseResp.LeaseID
|
||||
sampleSpace[lease2] = "lease"
|
||||
|
||||
resp, err = testClient.Logical().Write("auth/token/create", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
token3 := resp.Auth.ClientToken
|
||||
sampleSpace[token3] = "token"
|
||||
|
||||
testClient.SetToken(token3)
|
||||
|
||||
leaseResp, err = testClient.Logical().Read("kv/foo")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
lease3 := leaseResp.LeaseID
|
||||
sampleSpace[lease3] = "lease"
|
||||
|
||||
expected := make(map[string]string)
|
||||
for k, v := range sampleSpace {
|
||||
expected[k] = v
|
||||
}
|
||||
tokenRevocationValidation(t, sampleSpace, expected, leaseCache)
|
||||
|
||||
// Revoke the lef token. This should evict all the leases belonging to this
|
||||
// token, evict entries for all the child tokens and their respective
|
||||
// leases.
|
||||
testClient.SetToken(token3)
|
||||
err = testClient.Auth().Token().RevokeSelf("")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
expected = map[string]string{
|
||||
token1: "token",
|
||||
lease1: "lease",
|
||||
token2: "token",
|
||||
lease2: "lease",
|
||||
}
|
||||
tokenRevocationValidation(t, sampleSpace, expected, leaseCache)
|
||||
}
|
||||
|
||||
func TestCache_TokenRevocations_IntermediateLevelToken(t *testing.T) {
|
||||
coreConfig := &vault.CoreConfig{
|
||||
DisableMlock: true,
|
||||
DisableCache: true,
|
||||
Logger: hclog.NewNullLogger(),
|
||||
LogicalBackends: map[string]logical.Factory{
|
||||
"kv": vault.LeasedPassthroughBackendFactory,
|
||||
},
|
||||
}
|
||||
|
||||
sampleSpace := make(map[string]string)
|
||||
|
||||
cleanup, _, testClient, leaseCache := setupClusterAndAgent(namespace.RootContext(nil), t, coreConfig)
|
||||
defer cleanup()
|
||||
|
||||
token1 := testClient.Token()
|
||||
sampleSpace[token1] = "token"
|
||||
|
||||
// Mount the kv backend
|
||||
err := testClient.Sys().Mount("kv", &api.MountInput{
|
||||
Type: "kv",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Create a secret in the backend
|
||||
_, err = testClient.Logical().Write("kv/foo", map[string]interface{}{
|
||||
"value": "bar",
|
||||
"ttl": "1h",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Read the secret and create a lease
|
||||
leaseResp, err := testClient.Logical().Read("kv/foo")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
lease1 := leaseResp.LeaseID
|
||||
sampleSpace[lease1] = "lease"
|
||||
|
||||
resp, err := testClient.Logical().Write("auth/token/create", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
token2 := resp.Auth.ClientToken
|
||||
sampleSpace[token2] = "token"
|
||||
|
||||
testClient.SetToken(token2)
|
||||
|
||||
leaseResp, err = testClient.Logical().Read("kv/foo")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
lease2 := leaseResp.LeaseID
|
||||
sampleSpace[lease2] = "lease"
|
||||
|
||||
resp, err = testClient.Logical().Write("auth/token/create", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
token3 := resp.Auth.ClientToken
|
||||
sampleSpace[token3] = "token"
|
||||
|
||||
testClient.SetToken(token3)
|
||||
|
||||
leaseResp, err = testClient.Logical().Read("kv/foo")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
lease3 := leaseResp.LeaseID
|
||||
sampleSpace[lease3] = "lease"
|
||||
|
||||
expected := make(map[string]string)
|
||||
for k, v := range sampleSpace {
|
||||
expected[k] = v
|
||||
}
|
||||
tokenRevocationValidation(t, sampleSpace, expected, leaseCache)
|
||||
|
||||
// Revoke the second level token. This should evict all the leases
|
||||
// belonging to this token, evict entries for all the child tokens and
|
||||
// their respective leases.
|
||||
testClient.SetToken(token2)
|
||||
err = testClient.Auth().Token().RevokeSelf("")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
expected = map[string]string{
|
||||
token1: "token",
|
||||
lease1: "lease",
|
||||
}
|
||||
tokenRevocationValidation(t, sampleSpace, expected, leaseCache)
|
||||
}
|
||||
|
||||
func TestCache_TokenRevocations_TopLevelToken(t *testing.T) {
|
||||
coreConfig := &vault.CoreConfig{
|
||||
DisableMlock: true,
|
||||
DisableCache: true,
|
||||
Logger: hclog.NewNullLogger(),
|
||||
LogicalBackends: map[string]logical.Factory{
|
||||
"kv": vault.LeasedPassthroughBackendFactory,
|
||||
},
|
||||
}
|
||||
|
||||
sampleSpace := make(map[string]string)
|
||||
|
||||
cleanup, _, testClient, leaseCache := setupClusterAndAgent(namespace.RootContext(nil), t, coreConfig)
|
||||
defer cleanup()
|
||||
|
||||
token1 := testClient.Token()
|
||||
sampleSpace[token1] = "token"
|
||||
|
||||
// Mount the kv backend
|
||||
err := testClient.Sys().Mount("kv", &api.MountInput{
|
||||
Type: "kv",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Create a secret in the backend
|
||||
_, err = testClient.Logical().Write("kv/foo", map[string]interface{}{
|
||||
"value": "bar",
|
||||
"ttl": "1h",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Read the secret and create a lease
|
||||
leaseResp, err := testClient.Logical().Read("kv/foo")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
lease1 := leaseResp.LeaseID
|
||||
sampleSpace[lease1] = "lease"
|
||||
|
||||
resp, err := testClient.Logical().Write("auth/token/create", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
token2 := resp.Auth.ClientToken
|
||||
sampleSpace[token2] = "token"
|
||||
|
||||
testClient.SetToken(token2)
|
||||
|
||||
leaseResp, err = testClient.Logical().Read("kv/foo")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
lease2 := leaseResp.LeaseID
|
||||
sampleSpace[lease2] = "lease"
|
||||
|
||||
resp, err = testClient.Logical().Write("auth/token/create", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
token3 := resp.Auth.ClientToken
|
||||
sampleSpace[token3] = "token"
|
||||
|
||||
testClient.SetToken(token3)
|
||||
|
||||
leaseResp, err = testClient.Logical().Read("kv/foo")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
lease3 := leaseResp.LeaseID
|
||||
sampleSpace[lease3] = "lease"
|
||||
|
||||
expected := make(map[string]string)
|
||||
for k, v := range sampleSpace {
|
||||
expected[k] = v
|
||||
}
|
||||
tokenRevocationValidation(t, sampleSpace, expected, leaseCache)
|
||||
|
||||
// Revoke the top level token. This should evict all the leases belonging
|
||||
// to this token, evict entries for all the child tokens and their
|
||||
// respective leases.
|
||||
testClient.SetToken(token1)
|
||||
err = testClient.Auth().Token().RevokeSelf("")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
expected = make(map[string]string)
|
||||
tokenRevocationValidation(t, sampleSpace, expected, leaseCache)
|
||||
}
|
||||
|
||||
func TestCache_TokenRevocations_Shutdown(t *testing.T) {
|
||||
coreConfig := &vault.CoreConfig{
|
||||
DisableMlock: true,
|
||||
DisableCache: true,
|
||||
Logger: hclog.NewNullLogger(),
|
||||
LogicalBackends: map[string]logical.Factory{
|
||||
"kv": vault.LeasedPassthroughBackendFactory,
|
||||
},
|
||||
}
|
||||
|
||||
sampleSpace := make(map[string]string)
|
||||
|
||||
ctx, rootCancelFunc := context.WithCancel(namespace.RootContext(nil))
|
||||
cleanup, _, testClient, leaseCache := setupClusterAndAgent(ctx, t, coreConfig)
|
||||
defer cleanup()
|
||||
|
||||
token1 := testClient.Token()
|
||||
sampleSpace[token1] = "token"
|
||||
|
||||
// Mount the kv backend
|
||||
err := testClient.Sys().Mount("kv", &api.MountInput{
|
||||
Type: "kv",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Create a secret in the backend
|
||||
_, err = testClient.Logical().Write("kv/foo", map[string]interface{}{
|
||||
"value": "bar",
|
||||
"ttl": "1h",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Read the secret and create a lease
|
||||
leaseResp, err := testClient.Logical().Read("kv/foo")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
lease1 := leaseResp.LeaseID
|
||||
sampleSpace[lease1] = "lease"
|
||||
|
||||
resp, err := testClient.Logical().Write("auth/token/create", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
token2 := resp.Auth.ClientToken
|
||||
sampleSpace[token2] = "token"
|
||||
|
||||
testClient.SetToken(token2)
|
||||
|
||||
leaseResp, err = testClient.Logical().Read("kv/foo")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
lease2 := leaseResp.LeaseID
|
||||
sampleSpace[lease2] = "lease"
|
||||
|
||||
resp, err = testClient.Logical().Write("auth/token/create", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
token3 := resp.Auth.ClientToken
|
||||
sampleSpace[token3] = "token"
|
||||
|
||||
testClient.SetToken(token3)
|
||||
|
||||
leaseResp, err = testClient.Logical().Read("kv/foo")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
lease3 := leaseResp.LeaseID
|
||||
sampleSpace[lease3] = "lease"
|
||||
|
||||
expected := make(map[string]string)
|
||||
for k, v := range sampleSpace {
|
||||
expected[k] = v
|
||||
}
|
||||
tokenRevocationValidation(t, sampleSpace, expected, leaseCache)
|
||||
|
||||
rootCancelFunc()
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
// Ensure that all the entries are now gone
|
||||
expected = make(map[string]string)
|
||||
tokenRevocationValidation(t, sampleSpace, expected, leaseCache)
|
||||
}
|
||||
|
||||
func TestCache_TokenRevocations_BaseContextCancellation(t *testing.T) {
|
||||
coreConfig := &vault.CoreConfig{
|
||||
DisableMlock: true,
|
||||
DisableCache: true,
|
||||
Logger: hclog.NewNullLogger(),
|
||||
LogicalBackends: map[string]logical.Factory{
|
||||
"kv": vault.LeasedPassthroughBackendFactory,
|
||||
},
|
||||
}
|
||||
|
||||
sampleSpace := make(map[string]string)
|
||||
|
||||
cleanup, _, testClient, leaseCache := setupClusterAndAgent(namespace.RootContext(nil), t, coreConfig)
|
||||
defer cleanup()
|
||||
|
||||
token1 := testClient.Token()
|
||||
sampleSpace[token1] = "token"
|
||||
|
||||
// Mount the kv backend
|
||||
err := testClient.Sys().Mount("kv", &api.MountInput{
|
||||
Type: "kv",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Create a secret in the backend
|
||||
_, err = testClient.Logical().Write("kv/foo", map[string]interface{}{
|
||||
"value": "bar",
|
||||
"ttl": "1h",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Read the secret and create a lease
|
||||
leaseResp, err := testClient.Logical().Read("kv/foo")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
lease1 := leaseResp.LeaseID
|
||||
sampleSpace[lease1] = "lease"
|
||||
|
||||
resp, err := testClient.Logical().Write("auth/token/create", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
token2 := resp.Auth.ClientToken
|
||||
sampleSpace[token2] = "token"
|
||||
|
||||
testClient.SetToken(token2)
|
||||
|
||||
leaseResp, err = testClient.Logical().Read("kv/foo")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
lease2 := leaseResp.LeaseID
|
||||
sampleSpace[lease2] = "lease"
|
||||
|
||||
resp, err = testClient.Logical().Write("auth/token/create", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
token3 := resp.Auth.ClientToken
|
||||
sampleSpace[token3] = "token"
|
||||
|
||||
testClient.SetToken(token3)
|
||||
|
||||
leaseResp, err = testClient.Logical().Read("kv/foo")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
lease3 := leaseResp.LeaseID
|
||||
sampleSpace[lease3] = "lease"
|
||||
|
||||
expected := make(map[string]string)
|
||||
for k, v := range sampleSpace {
|
||||
expected[k] = v
|
||||
}
|
||||
tokenRevocationValidation(t, sampleSpace, expected, leaseCache)
|
||||
|
||||
// Cancel the base context of the lease cache. This should trigger
|
||||
// evictions of all the entries from the cache.
|
||||
leaseCache.baseCtxInfo.CancelFunc()
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
// Ensure that all the entries are now gone
|
||||
expected = make(map[string]string)
|
||||
tokenRevocationValidation(t, sampleSpace, expected, leaseCache)
|
||||
}
|
||||
|
||||
func TestCache_NonCacheable(t *testing.T) {
|
||||
coreConfig := &vault.CoreConfig{
|
||||
DisableMlock: true,
|
||||
DisableCache: true,
|
||||
Logger: hclog.NewNullLogger(),
|
||||
LogicalBackends: map[string]logical.Factory{
|
||||
"kv": kv.Factory,
|
||||
},
|
||||
}
|
||||
|
||||
cleanup, _, testClient, _ := setupClusterAndAgent(namespace.RootContext(nil), t, coreConfig)
|
||||
defer cleanup()
|
||||
|
||||
// Query mounts first
|
||||
origMounts, err := testClient.Sys().ListMounts()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Mount a kv backend
|
||||
if err := testClient.Sys().Mount("kv", &api.MountInput{
|
||||
Type: "kv",
|
||||
Options: map[string]string{
|
||||
"version": "2",
|
||||
},
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Query mounts again
|
||||
newMounts, err := testClient.Sys().ListMounts()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if diff := deep.Equal(origMounts, newMounts); diff == nil {
|
||||
t.Logf("response #1: %#v", origMounts)
|
||||
t.Logf("response #2: %#v", newMounts)
|
||||
t.Fatal("expected requests to be not cached")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCache_AuthResponse(t *testing.T) {
|
||||
cleanup, _, testClient, _ := setupClusterAndAgent(namespace.RootContext(nil), t, nil)
|
||||
defer cleanup()
|
||||
|
||||
resp, err := testClient.Logical().Write("auth/token/create", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
token := resp.Auth.ClientToken
|
||||
testClient.SetToken(token)
|
||||
|
||||
authTokeCreateReq := func(t *testing.T, policies map[string]interface{}) *api.Secret {
|
||||
resp, err := testClient.Logical().Write("auth/token/create", policies)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp.Auth == nil || resp.Auth.ClientToken == "" {
|
||||
t.Fatalf("expected a valid client token in the response, got = %#v", resp)
|
||||
}
|
||||
|
||||
return resp
|
||||
}
|
||||
|
||||
// Test on auth response by creating a child token
|
||||
{
|
||||
proxiedResp := authTokeCreateReq(t, map[string]interface{}{
|
||||
"policies": "default",
|
||||
})
|
||||
|
||||
cachedResp := authTokeCreateReq(t, map[string]interface{}{
|
||||
"policies": "default",
|
||||
})
|
||||
|
||||
if diff := deep.Equal(proxiedResp.Auth.ClientToken, cachedResp.Auth.ClientToken); diff != nil {
|
||||
t.Fatal(diff)
|
||||
}
|
||||
}
|
||||
|
||||
// Test on *non-renewable* auth response by creating a child root token
|
||||
{
|
||||
proxiedResp := authTokeCreateReq(t, nil)
|
||||
|
||||
cachedResp := authTokeCreateReq(t, nil)
|
||||
|
||||
if diff := deep.Equal(proxiedResp.Auth.ClientToken, cachedResp.Auth.ClientToken); diff != nil {
|
||||
t.Fatal(diff)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCache_LeaseResponse(t *testing.T) {
|
||||
coreConfig := &vault.CoreConfig{
|
||||
DisableMlock: true,
|
||||
DisableCache: true,
|
||||
Logger: hclog.NewNullLogger(),
|
||||
LogicalBackends: map[string]logical.Factory{
|
||||
"kv": vault.LeasedPassthroughBackendFactory,
|
||||
},
|
||||
}
|
||||
|
||||
cleanup, client, testClient, _ := setupClusterAndAgent(namespace.RootContext(nil), t, coreConfig)
|
||||
defer cleanup()
|
||||
|
||||
err := client.Sys().Mount("kv", &api.MountInput{
|
||||
Type: "kv",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Test proxy by issuing two different requests
|
||||
{
|
||||
// Write data to the lease-kv backend
|
||||
_, err := testClient.Logical().Write("kv/foo", map[string]interface{}{
|
||||
"value": "bar",
|
||||
"ttl": "1h",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, err = testClient.Logical().Write("kv/foobar", map[string]interface{}{
|
||||
"value": "bar",
|
||||
"ttl": "1h",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
firstResp, err := testClient.Logical().Read("kv/foo")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
secondResp, err := testClient.Logical().Read("kv/foobar")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if diff := deep.Equal(firstResp, secondResp); diff == nil {
|
||||
t.Logf("response: %#v", firstResp)
|
||||
t.Fatal("expected proxied responses, got cached response on second request")
|
||||
}
|
||||
}
|
||||
|
||||
// Test caching behavior by issue the same request twice
|
||||
{
|
||||
_, err := testClient.Logical().Write("kv/baz", map[string]interface{}{
|
||||
"value": "foo",
|
||||
"ttl": "1h",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
proxiedResp, err := testClient.Logical().Read("kv/baz")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cachedResp, err := testClient.Logical().Read("kv/baz")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if diff := deep.Equal(proxiedResp, cachedResp); diff != nil {
|
||||
t.Fatal(diff)
|
||||
}
|
||||
}
|
||||
}
|
265
command/agent/cache/cachememdb/cache_memdb.go
vendored
Normal file
265
command/agent/cache/cachememdb/cache_memdb.go
vendored
Normal file
|
@ -0,0 +1,265 @@
|
|||
package cachememdb
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
memdb "github.com/hashicorp/go-memdb"
|
||||
)
|
||||
|
||||
const (
|
||||
tableNameIndexer = "indexer"
|
||||
)
|
||||
|
||||
// CacheMemDB is the underlying cache database for storing indexes.
|
||||
type CacheMemDB struct {
|
||||
db *memdb.MemDB
|
||||
}
|
||||
|
||||
// New creates a new instance of CacheMemDB.
|
||||
func New() (*CacheMemDB, error) {
|
||||
db, err := newDB()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &CacheMemDB{
|
||||
db: db,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func newDB() (*memdb.MemDB, error) {
|
||||
cacheSchema := &memdb.DBSchema{
|
||||
Tables: map[string]*memdb.TableSchema{
|
||||
tableNameIndexer: &memdb.TableSchema{
|
||||
Name: tableNameIndexer,
|
||||
Indexes: map[string]*memdb.IndexSchema{
|
||||
// This index enables fetching the cached item based on the
|
||||
// identifier of the index.
|
||||
IndexNameID: &memdb.IndexSchema{
|
||||
Name: IndexNameID,
|
||||
Unique: true,
|
||||
Indexer: &memdb.StringFieldIndex{
|
||||
Field: "ID",
|
||||
},
|
||||
},
|
||||
// This index enables fetching all the entries in cache for
|
||||
// a given request path, in a given namespace.
|
||||
IndexNameRequestPath: &memdb.IndexSchema{
|
||||
Name: IndexNameRequestPath,
|
||||
Unique: false,
|
||||
Indexer: &memdb.CompoundIndex{
|
||||
Indexes: []memdb.Indexer{
|
||||
&memdb.StringFieldIndex{
|
||||
Field: "Namespace",
|
||||
},
|
||||
&memdb.StringFieldIndex{
|
||||
Field: "RequestPath",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
// This index enables fetching all the entries in cache
|
||||
// belonging to the leases of a given token.
|
||||
IndexNameLeaseToken: &memdb.IndexSchema{
|
||||
Name: IndexNameLeaseToken,
|
||||
Unique: false,
|
||||
AllowMissing: true,
|
||||
Indexer: &memdb.StringFieldIndex{
|
||||
Field: "LeaseToken",
|
||||
},
|
||||
},
|
||||
// This index enables fetching all the entries in cache
|
||||
// that are tied to the given token, regardless of the
|
||||
// entries belonging to the token or belonging to the
|
||||
// lease.
|
||||
IndexNameToken: &memdb.IndexSchema{
|
||||
Name: IndexNameToken,
|
||||
Unique: true,
|
||||
AllowMissing: true,
|
||||
Indexer: &memdb.StringFieldIndex{
|
||||
Field: "Token",
|
||||
},
|
||||
},
|
||||
// This index enables fetching all the entries in cache for
|
||||
// the given parent token.
|
||||
IndexNameTokenParent: &memdb.IndexSchema{
|
||||
Name: IndexNameTokenParent,
|
||||
Unique: false,
|
||||
AllowMissing: true,
|
||||
Indexer: &memdb.StringFieldIndex{
|
||||
Field: "TokenParent",
|
||||
},
|
||||
},
|
||||
// This index enables fetching all the entries in cache for
|
||||
// the given accessor.
|
||||
IndexNameTokenAccessor: &memdb.IndexSchema{
|
||||
Name: IndexNameTokenAccessor,
|
||||
Unique: true,
|
||||
AllowMissing: true,
|
||||
Indexer: &memdb.StringFieldIndex{
|
||||
Field: "TokenAccessor",
|
||||
},
|
||||
},
|
||||
// This index enables fetching all the entries in cache for
|
||||
// the given lease identifier.
|
||||
IndexNameLease: &memdb.IndexSchema{
|
||||
Name: IndexNameLease,
|
||||
Unique: true,
|
||||
AllowMissing: true,
|
||||
Indexer: &memdb.StringFieldIndex{
|
||||
Field: "Lease",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
db, err := memdb.NewMemDB(cacheSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return db, nil
|
||||
}
|
||||
|
||||
// Get returns the index based on the indexer and the index values provided.
|
||||
func (c *CacheMemDB) Get(indexName string, indexValues ...interface{}) (*Index, error) {
|
||||
if !validIndexName(indexName) {
|
||||
return nil, fmt.Errorf("invalid index name %q", indexName)
|
||||
}
|
||||
|
||||
raw, err := c.db.Txn(false).First(tableNameIndexer, indexName, indexValues...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if raw == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
index, ok := raw.(*Index)
|
||||
if !ok {
|
||||
return nil, errors.New("unable to parse index value from the cache")
|
||||
}
|
||||
|
||||
return index, nil
|
||||
}
|
||||
|
||||
// Set stores the index into the cache.
|
||||
func (c *CacheMemDB) Set(index *Index) error {
|
||||
if index == nil {
|
||||
return errors.New("nil index provided")
|
||||
}
|
||||
|
||||
txn := c.db.Txn(true)
|
||||
defer txn.Abort()
|
||||
|
||||
if err := txn.Insert(tableNameIndexer, index); err != nil {
|
||||
return fmt.Errorf("unable to insert index into cache: %v", err)
|
||||
}
|
||||
|
||||
txn.Commit()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetByPrefix returns all the cached indexes based on the index name and the
|
||||
// value prefix.
|
||||
func (c *CacheMemDB) GetByPrefix(indexName string, indexValues ...interface{}) ([]*Index, error) {
|
||||
if !validIndexName(indexName) {
|
||||
return nil, fmt.Errorf("invalid index name %q", indexName)
|
||||
}
|
||||
|
||||
indexName = indexName + "_prefix"
|
||||
|
||||
// Get all the objects
|
||||
iter, err := c.db.Txn(false).Get(tableNameIndexer, indexName, indexValues...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var indexes []*Index
|
||||
for {
|
||||
obj := iter.Next()
|
||||
if obj == nil {
|
||||
break
|
||||
}
|
||||
index, ok := obj.(*Index)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to cast cached index")
|
||||
}
|
||||
|
||||
indexes = append(indexes, index)
|
||||
}
|
||||
|
||||
return indexes, nil
|
||||
}
|
||||
|
||||
// Evict removes an index from the cache based on index name and value.
|
||||
func (c *CacheMemDB) Evict(indexName string, indexValues ...interface{}) error {
|
||||
index, err := c.Get(indexName, indexValues...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to fetch index on cache deletion: %v", err)
|
||||
}
|
||||
|
||||
if index == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
txn := c.db.Txn(true)
|
||||
defer txn.Abort()
|
||||
|
||||
if err := txn.Delete(tableNameIndexer, index); err != nil {
|
||||
return fmt.Errorf("unable to delete index from cache: %v", err)
|
||||
}
|
||||
|
||||
txn.Commit()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// EvictAll removes all matching indexes from the cache based on index name and value.
|
||||
func (c *CacheMemDB) EvictAll(indexName, indexValue string) error {
|
||||
return c.batchEvict(false, indexName, indexValue)
|
||||
}
|
||||
|
||||
// EvictByPrefix removes all matching prefix indexes from the cache based on index name and prefix.
|
||||
func (c *CacheMemDB) EvictByPrefix(indexName, indexPrefix string) error {
|
||||
return c.batchEvict(true, indexName, indexPrefix)
|
||||
}
|
||||
|
||||
// batchEvict is a helper that supports eviction based on absolute and prefixed index values.
|
||||
func (c *CacheMemDB) batchEvict(isPrefix bool, indexName string, indexValues ...interface{}) error {
|
||||
if !validIndexName(indexName) {
|
||||
return fmt.Errorf("invalid index name %q", indexName)
|
||||
}
|
||||
|
||||
if isPrefix {
|
||||
indexName = indexName + "_prefix"
|
||||
}
|
||||
|
||||
txn := c.db.Txn(true)
|
||||
defer txn.Abort()
|
||||
|
||||
_, err := txn.DeleteAll(tableNameIndexer, indexName, indexValues...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
txn.Commit()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Flush resets the underlying cache object.
|
||||
func (c *CacheMemDB) Flush() error {
|
||||
newDB, err := newDB()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.db = newDB
|
||||
|
||||
return nil
|
||||
}
|
392
command/agent/cache/cachememdb/cache_memdb_test.go
vendored
Normal file
392
command/agent/cache/cachememdb/cache_memdb_test.go
vendored
Normal file
|
@ -0,0 +1,392 @@
|
|||
package cachememdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/go-test/deep"
|
||||
)
|
||||
|
||||
func testContextInfo() *ContextInfo {
|
||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
|
||||
return &ContextInfo{
|
||||
Ctx: ctx,
|
||||
CancelFunc: cancelFunc,
|
||||
}
|
||||
}
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
_, err := New()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheMemDB_Get(t *testing.T) {
|
||||
cache, err := New()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Test invalid index name
|
||||
_, err = cache.Get("foo", "bar")
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
|
||||
// Test on empty cache
|
||||
index, err := cache.Get(IndexNameID, "foo")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if index != nil {
|
||||
t.Fatalf("expected nil index, got: %v", index)
|
||||
}
|
||||
|
||||
// Populate cache
|
||||
in := &Index{
|
||||
ID: "test_id",
|
||||
Namespace: "test_ns/",
|
||||
RequestPath: "/v1/request/path",
|
||||
Token: "test_token",
|
||||
TokenAccessor: "test_accessor",
|
||||
Lease: "test_lease",
|
||||
Response: []byte("hello world"),
|
||||
}
|
||||
|
||||
if err := cache.Set(in); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
indexName string
|
||||
indexValues []interface{}
|
||||
}{
|
||||
{
|
||||
"by_index_id",
|
||||
"id",
|
||||
[]interface{}{in.ID},
|
||||
},
|
||||
{
|
||||
"by_request_path",
|
||||
"request_path",
|
||||
[]interface{}{in.Namespace, in.RequestPath},
|
||||
},
|
||||
{
|
||||
"by_lease",
|
||||
"lease",
|
||||
[]interface{}{in.Lease},
|
||||
},
|
||||
{
|
||||
"by_token",
|
||||
"token",
|
||||
[]interface{}{in.Token},
|
||||
},
|
||||
{
|
||||
"by_token_accessor",
|
||||
"token_accessor",
|
||||
[]interface{}{in.TokenAccessor},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
out, err := cache.Get(tc.indexName, tc.indexValues...)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if diff := deep.Equal(in, out); diff != nil {
|
||||
t.Fatal(diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheMemDB_GetByPrefix(t *testing.T) {
|
||||
cache, err := New()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Test invalid index name
|
||||
_, err = cache.GetByPrefix("foo", "bar", "baz")
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
|
||||
// Test on empty cache
|
||||
index, err := cache.GetByPrefix(IndexNameRequestPath, "foo", "bar")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if index != nil {
|
||||
t.Fatalf("expected nil index, got: %v", index)
|
||||
}
|
||||
|
||||
// Populate cache
|
||||
in := &Index{
|
||||
ID: "test_id",
|
||||
Namespace: "test_ns/",
|
||||
RequestPath: "/v1/request/path/1",
|
||||
Token: "test_token",
|
||||
TokenParent: "test_token_parent",
|
||||
TokenAccessor: "test_accessor",
|
||||
Lease: "path/to/test_lease/1",
|
||||
LeaseToken: "test_lease_token",
|
||||
Response: []byte("hello world"),
|
||||
}
|
||||
|
||||
if err := cache.Set(in); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Populate cache
|
||||
in2 := &Index{
|
||||
ID: "test_id_2",
|
||||
Namespace: "test_ns/",
|
||||
RequestPath: "/v1/request/path/2",
|
||||
Token: "test_token2",
|
||||
TokenParent: "test_token_parent",
|
||||
TokenAccessor: "test_accessor2",
|
||||
Lease: "path/to/test_lease/2",
|
||||
LeaseToken: "test_lease_token",
|
||||
Response: []byte("hello world"),
|
||||
}
|
||||
|
||||
if err := cache.Set(in2); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
indexName string
|
||||
indexValues []interface{}
|
||||
}{
|
||||
{
|
||||
"by_request_path",
|
||||
"request_path",
|
||||
[]interface{}{"test_ns/", "/v1/request/path"},
|
||||
},
|
||||
{
|
||||
"by_lease",
|
||||
"lease",
|
||||
[]interface{}{"path/to/test_lease"},
|
||||
},
|
||||
{
|
||||
"by_token_parent",
|
||||
"token_parent",
|
||||
[]interface{}{"test_token_parent"},
|
||||
},
|
||||
{
|
||||
"by_lease_token",
|
||||
"lease_token",
|
||||
[]interface{}{"test_lease_token"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
out, err := cache.GetByPrefix(tc.indexName, tc.indexValues...)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if diff := deep.Equal([]*Index{in, in2}, out); diff != nil {
|
||||
t.Fatal(diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheMemDB_Set(t *testing.T) {
|
||||
cache, err := New()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
index *Index
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
"nil",
|
||||
nil,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"empty_fields",
|
||||
&Index{},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"missing_required_fields",
|
||||
&Index{
|
||||
Lease: "foo",
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"all_fields",
|
||||
&Index{
|
||||
ID: "test_id",
|
||||
Namespace: "test_ns/",
|
||||
RequestPath: "/v1/request/path",
|
||||
Token: "test_token",
|
||||
TokenAccessor: "test_accessor",
|
||||
Lease: "test_lease",
|
||||
RenewCtxInfo: testContextInfo(),
|
||||
},
|
||||
false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if err := cache.Set(tc.index); (err != nil) != tc.wantErr {
|
||||
t.Fatalf("CacheMemDB.Set() error = %v, wantErr = %v", err, tc.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheMemDB_Evict(t *testing.T) {
|
||||
cache, err := New()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Test on empty cache
|
||||
if err := cache.Evict(IndexNameID, "foo"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
testIndex := &Index{
|
||||
ID: "test_id",
|
||||
Namespace: "test_ns/",
|
||||
RequestPath: "/v1/request/path",
|
||||
Token: "test_token",
|
||||
TokenAccessor: "test_token_accessor",
|
||||
Lease: "test_lease",
|
||||
RenewCtxInfo: testContextInfo(),
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
indexName string
|
||||
indexValues []interface{}
|
||||
insertIndex *Index
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
"empty_params",
|
||||
"",
|
||||
[]interface{}{""},
|
||||
nil,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"invalid_params",
|
||||
"foo",
|
||||
[]interface{}{"bar"},
|
||||
nil,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"by_id",
|
||||
"id",
|
||||
[]interface{}{"test_id"},
|
||||
testIndex,
|
||||
false,
|
||||
},
|
||||
{
|
||||
"by_request_path",
|
||||
"request_path",
|
||||
[]interface{}{"test_ns/", "/v1/request/path"},
|
||||
testIndex,
|
||||
false,
|
||||
},
|
||||
{
|
||||
"by_token",
|
||||
"token",
|
||||
[]interface{}{"test_token"},
|
||||
testIndex,
|
||||
false,
|
||||
},
|
||||
{
|
||||
"by_token_accessor",
|
||||
"token_accessor",
|
||||
[]interface{}{"test_accessor"},
|
||||
testIndex,
|
||||
false,
|
||||
},
|
||||
{
|
||||
"by_lease",
|
||||
"lease",
|
||||
[]interface{}{"test_lease"},
|
||||
testIndex,
|
||||
false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if tc.insertIndex != nil {
|
||||
if err := cache.Set(tc.insertIndex); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := cache.Evict(tc.indexName, tc.indexValues...); (err != nil) != tc.wantErr {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Verify that the cache doesn't contain the entry any more
|
||||
index, err := cache.Get(tc.indexName, tc.indexValues...)
|
||||
if (err != nil) != tc.wantErr {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if index != nil {
|
||||
t.Fatalf("expected nil entry, got = %#v", index)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheMemDB_Flush(t *testing.T) {
|
||||
cache, err := New()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Populate cache
|
||||
in := &Index{
|
||||
ID: "test_id",
|
||||
Token: "test_token",
|
||||
Lease: "test_lease",
|
||||
Namespace: "test_ns/",
|
||||
RequestPath: "/v1/request/path",
|
||||
Response: []byte("hello world"),
|
||||
}
|
||||
|
||||
if err := cache.Set(in); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Reset the cache
|
||||
if err := cache.Flush(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Check the cache doesn't contain inserted index
|
||||
out, err := cache.Get(IndexNameID, "test_id")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if out != nil {
|
||||
t.Fatalf("expected cache to be empty, got = %v", out)
|
||||
}
|
||||
}
|
97
command/agent/cache/cachememdb/index.go
vendored
Normal file
97
command/agent/cache/cachememdb/index.go
vendored
Normal file
|
@ -0,0 +1,97 @@
|
|||
package cachememdb
|
||||
|
||||
import "context"
|
||||
|
||||
type ContextInfo struct {
|
||||
Ctx context.Context
|
||||
CancelFunc context.CancelFunc
|
||||
DoneCh chan struct{}
|
||||
}
|
||||
|
||||
// Index holds the response to be cached along with multiple other values that
|
||||
// serve as pointers to refer back to this index.
|
||||
type Index struct {
|
||||
// ID is a value that uniquely represents the request held by this
|
||||
// index. This is computed by serializing and hashing the response object.
|
||||
// Required: true, Unique: true
|
||||
ID string
|
||||
|
||||
// Token is the token that fetched the response held by this index
|
||||
// Required: true, Unique: true
|
||||
Token string
|
||||
|
||||
// TokenParent is the parent token of the token held by this index
|
||||
// Required: false, Unique: false
|
||||
TokenParent string
|
||||
|
||||
// TokenAccessor is the accessor of the token being cached in this index
|
||||
// Required: true, Unique: true
|
||||
TokenAccessor string
|
||||
|
||||
// Namespace is the namespace that was provided in the request path as the
|
||||
// Vault namespace to query
|
||||
Namespace string
|
||||
|
||||
// RequestPath is the path of the request that resulted in the response
|
||||
// held by this index.
|
||||
// Required: true, Unique: false
|
||||
RequestPath string
|
||||
|
||||
// Lease is the identifier of the lease in Vault, that belongs to the
|
||||
// response held by this index.
|
||||
// Required: false, Unique: true
|
||||
Lease string
|
||||
|
||||
// LeaseToken is the identifier of the token that created the lease held by
|
||||
// this index.
|
||||
// Required: false, Unique: false
|
||||
LeaseToken string
|
||||
|
||||
// Response is the serialized response object that the agent is caching.
|
||||
Response []byte
|
||||
|
||||
// RenewCtxInfo holds the context and the corresponding cancel func for the
|
||||
// goroutine that manages the renewal of the secret belonging to the
|
||||
// response in this index.
|
||||
RenewCtxInfo *ContextInfo
|
||||
}
|
||||
|
||||
type IndexName uint32
|
||||
|
||||
const (
|
||||
// IndexNameID is the ID of the index constructed from the serialized request.
|
||||
IndexNameID = "id"
|
||||
|
||||
// IndexNameLease is the lease of the index.
|
||||
IndexNameLease = "lease"
|
||||
|
||||
// IndexNameRequestPath is the request path of the index.
|
||||
IndexNameRequestPath = "request_path"
|
||||
|
||||
// IndexNameToken is the token of the index.
|
||||
IndexNameToken = "token"
|
||||
|
||||
// IndexNameTokenAccessor is the token accessor of the index.
|
||||
IndexNameTokenAccessor = "token_accessor"
|
||||
|
||||
// IndexNameTokenParent is the token parent of the index.
|
||||
IndexNameTokenParent = "token_parent"
|
||||
|
||||
// IndexNameLeaseToken is the token that created the lease.
|
||||
IndexNameLeaseToken = "lease_token"
|
||||
)
|
||||
|
||||
func validIndexName(indexName string) bool {
|
||||
switch indexName {
|
||||
case "id":
|
||||
case "lease":
|
||||
case "request_path":
|
||||
case "token":
|
||||
case "token_accessor":
|
||||
case "token_parent":
|
||||
case "lease_token":
|
||||
default:
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
155
command/agent/cache/handler.go
vendored
Normal file
155
command/agent/cache/handler.go
vendored
Normal file
|
@ -0,0 +1,155 @@
|
|||
package cache
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
|
||||
"github.com/hashicorp/errwrap"
|
||||
hclog "github.com/hashicorp/go-hclog"
|
||||
"github.com/hashicorp/vault/api"
|
||||
"github.com/hashicorp/vault/helper/consts"
|
||||
vaulthttp "github.com/hashicorp/vault/http"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
)
|
||||
|
||||
func Handler(ctx context.Context, logger hclog.Logger, proxier Proxier, useAutoAuthToken bool, client *api.Client) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
logger.Info("received request", "path", r.URL.Path, "method", r.Method)
|
||||
|
||||
token := r.Header.Get(consts.AuthHeaderName)
|
||||
if token == "" && useAutoAuthToken {
|
||||
logger.Debug("using auto auth token")
|
||||
token = client.Token()
|
||||
}
|
||||
|
||||
// Parse and reset body.
|
||||
reqBody, err := ioutil.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
logger.Error("failed to read request body")
|
||||
respondError(w, http.StatusInternalServerError, errors.New("failed to read request body"))
|
||||
}
|
||||
if r.Body != nil {
|
||||
r.Body.Close()
|
||||
}
|
||||
r.Body = ioutil.NopCloser(bytes.NewBuffer(reqBody))
|
||||
req := &SendRequest{
|
||||
Token: token,
|
||||
Request: r,
|
||||
RequestBody: reqBody,
|
||||
}
|
||||
|
||||
resp, err := proxier.Send(ctx, req)
|
||||
if err != nil {
|
||||
respondError(w, http.StatusInternalServerError, errwrap.Wrapf("failed to get the response: {{err}}", err))
|
||||
return
|
||||
}
|
||||
|
||||
err = processTokenLookupResponse(ctx, logger, useAutoAuthToken, client, req, resp)
|
||||
if err != nil {
|
||||
respondError(w, http.StatusInternalServerError, errwrap.Wrapf("failed to process token lookup response: {{err}}", err))
|
||||
return
|
||||
}
|
||||
|
||||
defer resp.Response.Body.Close()
|
||||
|
||||
copyHeader(w.Header(), resp.Response.Header)
|
||||
w.WriteHeader(resp.Response.StatusCode)
|
||||
io.Copy(w, resp.Response.Body)
|
||||
return
|
||||
})
|
||||
}
|
||||
|
||||
// processTokenLookupResponse checks if the request was one of token
|
||||
// lookup-self. If the auto-auth token was used to perform lookup-self, the
|
||||
// identifier of the token and its accessor same will be stripped off of the
|
||||
// response.
|
||||
func processTokenLookupResponse(ctx context.Context, logger hclog.Logger, useAutoAuthToken bool, client *api.Client, req *SendRequest, resp *SendResponse) error {
|
||||
// If auto-auth token is not being used, there is nothing to do.
|
||||
if !useAutoAuthToken {
|
||||
return nil
|
||||
}
|
||||
|
||||
// If lookup responded with non 200 status, there is nothing to do.
|
||||
if resp.Response.StatusCode != http.StatusOK {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Strip-off namespace related information from the request and get the
|
||||
// relative path of the request.
|
||||
_, path := deriveNamespaceAndRevocationPath(req)
|
||||
if path == vaultPathTokenLookupSelf {
|
||||
logger.Info("stripping auto-auth token from the response", "path", req.Request.URL.Path, "method", req.Request.Method)
|
||||
secret, err := api.ParseSecret(bytes.NewBuffer(resp.ResponseBody))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse token lookup response: %v", err)
|
||||
}
|
||||
if secret != nil && secret.Data != nil && secret.Data["id"] != nil {
|
||||
token, ok := secret.Data["id"].(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("failed to type assert the token id in the response")
|
||||
}
|
||||
if token == client.Token() {
|
||||
delete(secret.Data, "id")
|
||||
delete(secret.Data, "accessor")
|
||||
}
|
||||
|
||||
bodyBytes, err := json.Marshal(secret)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if resp.Response.Body != nil {
|
||||
resp.Response.Body.Close()
|
||||
}
|
||||
resp.Response.Body = ioutil.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
resp.Response.ContentLength = int64(len(bodyBytes))
|
||||
|
||||
// Serialize and re-read the reponse
|
||||
var respBytes bytes.Buffer
|
||||
err = resp.Response.Write(&respBytes)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to serialize the updated response: %v", err)
|
||||
}
|
||||
|
||||
updatedResponse, err := http.ReadResponse(bufio.NewReader(bytes.NewReader(respBytes.Bytes())), nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to deserialize the updated response: %v", err)
|
||||
}
|
||||
|
||||
resp.Response = &api.Response{
|
||||
Response: updatedResponse,
|
||||
}
|
||||
resp.ResponseBody = bodyBytes
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func copyHeader(dst, src http.Header) {
|
||||
for k, vv := range src {
|
||||
for _, v := range vv {
|
||||
dst.Add(k, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func respondError(w http.ResponseWriter, status int, err error) {
|
||||
logical.AdjustErrorStatusCode(&status, err)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
|
||||
resp := &vaulthttp.ErrorResponse{Errors: make([]string, 0, 1)}
|
||||
if err != nil {
|
||||
resp.Errors = append(resp.Errors, err.Error())
|
||||
}
|
||||
|
||||
enc := json.NewEncoder(w)
|
||||
enc.Encode(resp)
|
||||
}
|
810
command/agent/cache/lease_cache.go
vendored
Normal file
810
command/agent/cache/lease_cache.go
vendored
Normal file
|
@ -0,0 +1,810 @@
|
|||
package cache
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/errwrap"
|
||||
hclog "github.com/hashicorp/go-hclog"
|
||||
"github.com/hashicorp/vault/api"
|
||||
cachememdb "github.com/hashicorp/vault/command/agent/cache/cachememdb"
|
||||
"github.com/hashicorp/vault/helper/consts"
|
||||
"github.com/hashicorp/vault/helper/jsonutil"
|
||||
"github.com/hashicorp/vault/helper/namespace"
|
||||
nshelper "github.com/hashicorp/vault/helper/namespace"
|
||||
)
|
||||
|
||||
const (
|
||||
vaultPathTokenCreate = "/v1/auth/token/create"
|
||||
vaultPathTokenRevoke = "/v1/auth/token/revoke"
|
||||
vaultPathTokenRevokeSelf = "/v1/auth/token/revoke-self"
|
||||
vaultPathTokenRevokeAccessor = "/v1/auth/token/revoke-accessor"
|
||||
vaultPathTokenRevokeOrphan = "/v1/auth/token/revoke-orphan"
|
||||
vaultPathTokenLookupSelf = "/v1/auth/token/lookup-self"
|
||||
vaultPathLeaseRevoke = "/v1/sys/leases/revoke"
|
||||
vaultPathLeaseRevokeForce = "/v1/sys/leases/revoke-force"
|
||||
vaultPathLeaseRevokePrefix = "/v1/sys/leases/revoke-prefix"
|
||||
)
|
||||
|
||||
var (
|
||||
contextIndexID = contextIndex{}
|
||||
errInvalidType = errors.New("invalid type provided")
|
||||
revocationPaths = []string{
|
||||
strings.TrimPrefix(vaultPathTokenRevoke, "/v1"),
|
||||
strings.TrimPrefix(vaultPathTokenRevokeSelf, "/v1"),
|
||||
strings.TrimPrefix(vaultPathTokenRevokeAccessor, "/v1"),
|
||||
strings.TrimPrefix(vaultPathTokenRevokeOrphan, "/v1"),
|
||||
strings.TrimPrefix(vaultPathLeaseRevoke, "/v1"),
|
||||
strings.TrimPrefix(vaultPathLeaseRevokeForce, "/v1"),
|
||||
strings.TrimPrefix(vaultPathLeaseRevokePrefix, "/v1"),
|
||||
}
|
||||
)
|
||||
|
||||
type contextIndex struct{}
|
||||
|
||||
type cacheClearRequest struct {
|
||||
Type string `json:"type"`
|
||||
Value string `json:"value"`
|
||||
Namespace string `json:"namespace"`
|
||||
}
|
||||
|
||||
// LeaseCache is an implementation of Proxier that handles
|
||||
// the caching of responses. It passes the incoming request
|
||||
// to an underlying Proxier implementation.
|
||||
type LeaseCache struct {
|
||||
proxier Proxier
|
||||
logger hclog.Logger
|
||||
db *cachememdb.CacheMemDB
|
||||
baseCtxInfo *ContextInfo
|
||||
}
|
||||
|
||||
// LeaseCacheConfig is the configuration for initializing a new
|
||||
// Lease.
|
||||
type LeaseCacheConfig struct {
|
||||
BaseContext context.Context
|
||||
Proxier Proxier
|
||||
Logger hclog.Logger
|
||||
}
|
||||
|
||||
// ContextInfo holds a derived context and cancelFunc pair.
|
||||
type ContextInfo struct {
|
||||
Ctx context.Context
|
||||
CancelFunc context.CancelFunc
|
||||
DoneCh chan struct{}
|
||||
}
|
||||
|
||||
// NewLeaseCache creates a new instance of a LeaseCache.
|
||||
func NewLeaseCache(conf *LeaseCacheConfig) (*LeaseCache, error) {
|
||||
if conf == nil {
|
||||
return nil, errors.New("nil configuration provided")
|
||||
}
|
||||
|
||||
if conf.Proxier == nil || conf.Logger == nil {
|
||||
return nil, fmt.Errorf("missing configuration required params: %v", conf)
|
||||
}
|
||||
|
||||
db, err := cachememdb.New()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create a base context for the lease cache layer
|
||||
baseCtx, baseCancelFunc := context.WithCancel(conf.BaseContext)
|
||||
baseCtxInfo := &ContextInfo{
|
||||
Ctx: baseCtx,
|
||||
CancelFunc: baseCancelFunc,
|
||||
}
|
||||
|
||||
return &LeaseCache{
|
||||
proxier: conf.Proxier,
|
||||
logger: conf.Logger,
|
||||
db: db,
|
||||
baseCtxInfo: baseCtxInfo,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Send performs a cache lookup on the incoming request. If it's a cache hit,
|
||||
// it will return the cached response, otherwise it will delegate to the
|
||||
// underlying Proxier and cache the received response.
|
||||
func (c *LeaseCache) Send(ctx context.Context, req *SendRequest) (*SendResponse, error) {
|
||||
// Compute the index ID
|
||||
id, err := computeIndexID(req)
|
||||
if err != nil {
|
||||
c.logger.Error("failed to compute cache key", "error", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Check if the response for this request is already in the cache
|
||||
index, err := c.db.Get(cachememdb.IndexNameID, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Cached request is found, deserialize the response and return early
|
||||
if index != nil {
|
||||
c.logger.Debug("returning cached response", "path", req.Request.URL.Path)
|
||||
|
||||
reader := bufio.NewReader(bytes.NewReader(index.Response))
|
||||
resp, err := http.ReadResponse(reader, nil)
|
||||
if err != nil {
|
||||
c.logger.Error("failed to deserialize response", "error", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &SendResponse{
|
||||
Response: &api.Response{
|
||||
Response: resp,
|
||||
},
|
||||
ResponseBody: index.Response,
|
||||
}, nil
|
||||
}
|
||||
|
||||
c.logger.Debug("forwarding request", "path", req.Request.URL.Path, "method", req.Request.Method)
|
||||
|
||||
// Pass the request down and get a response
|
||||
resp, err := c.proxier.Send(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get the namespace from the request header
|
||||
namespace := req.Request.Header.Get(consts.NamespaceHeaderName)
|
||||
// We need to populate an empty value since go-memdb will skip over indexes
|
||||
// that contain empty values.
|
||||
if namespace == "" {
|
||||
namespace = "root/"
|
||||
}
|
||||
|
||||
// Build the index to cache based on the response received
|
||||
index = &cachememdb.Index{
|
||||
ID: id,
|
||||
Namespace: namespace,
|
||||
RequestPath: req.Request.URL.Path,
|
||||
}
|
||||
|
||||
secret, err := api.ParseSecret(bytes.NewBuffer(resp.ResponseBody))
|
||||
if err != nil {
|
||||
c.logger.Error("failed to parse response as secret", "error", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
isRevocation, err := c.handleRevocationRequest(ctx, req, resp)
|
||||
if err != nil {
|
||||
c.logger.Error("failed to process the response", "error", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// If this is a revocation request, do not go through cache logic.
|
||||
if isRevocation {
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// Fast path for responses with no secrets
|
||||
if secret == nil {
|
||||
c.logger.Debug("pass-through response; no secret in response", "path", req.Request.URL.Path, "method", req.Request.Method)
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// Short-circuit if the secret is not renewable
|
||||
tokenRenewable, err := secret.TokenIsRenewable()
|
||||
if err != nil {
|
||||
c.logger.Error("failed to parse renewable param", "error", err)
|
||||
return nil, err
|
||||
}
|
||||
if !secret.Renewable && !tokenRenewable {
|
||||
c.logger.Debug("pass-through response; secret not renewable", "path", req.Request.URL.Path, "method", req.Request.Method)
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
var renewCtxInfo *ContextInfo
|
||||
switch {
|
||||
case secret.LeaseID != "":
|
||||
c.logger.Debug("processing lease response", "path", req.Request.URL.Path, "method", req.Request.Method)
|
||||
entry, err := c.db.Get(cachememdb.IndexNameToken, req.Token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// If the lease belongs to a token that is not managed by the agent,
|
||||
// return the response without caching it.
|
||||
if entry == nil {
|
||||
c.logger.Debug("pass-through lease response; token not managed by agent", "path", req.Request.URL.Path, "method", req.Request.Method)
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// Derive a context for renewal using the token's context
|
||||
newCtxInfo := new(ContextInfo)
|
||||
newCtxInfo.Ctx, newCtxInfo.CancelFunc = context.WithCancel(entry.RenewCtxInfo.Ctx)
|
||||
newCtxInfo.DoneCh = make(chan struct{})
|
||||
renewCtxInfo = newCtxInfo
|
||||
|
||||
index.Lease = secret.LeaseID
|
||||
index.LeaseToken = req.Token
|
||||
|
||||
case secret.Auth != nil:
|
||||
c.logger.Debug("processing auth response", "path", req.Request.URL.Path, "method", req.Request.Method)
|
||||
isNonOrphanNewToken := strings.HasPrefix(req.Request.URL.Path, vaultPathTokenCreate) && resp.Response.StatusCode == http.StatusOK && !secret.Auth.Orphan
|
||||
|
||||
// If the new token is a result of token creation endpoints (not from
|
||||
// login endpoints), and if its a non-orphan, then the new token's
|
||||
// context should be derived from the context of the parent token.
|
||||
var parentCtx context.Context
|
||||
if isNonOrphanNewToken {
|
||||
entry, err := c.db.Get(cachememdb.IndexNameToken, req.Token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// If parent token is not managed by the agent, child shouldn't be
|
||||
// either.
|
||||
if entry == nil {
|
||||
c.logger.Debug("pass-through auth response; parent token not managed by agent", "path", req.Request.URL.Path, "method", req.Request.Method)
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
c.logger.Debug("setting parent context", "path", req.Request.URL.Path, "method", req.Request.Method)
|
||||
parentCtx = entry.RenewCtxInfo.Ctx
|
||||
|
||||
entry.TokenParent = req.Token
|
||||
}
|
||||
|
||||
renewCtxInfo = c.createCtxInfo(parentCtx, secret.Auth.ClientToken)
|
||||
index.Token = secret.Auth.ClientToken
|
||||
index.TokenAccessor = secret.Auth.Accessor
|
||||
|
||||
default:
|
||||
// We shouldn't be hitting this, but will err on the side of caution and
|
||||
// simply proxy.
|
||||
c.logger.Debug("pass-through response; secret without lease and token", "path", req.Request.URL.Path, "method", req.Request.Method)
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// Serialize the response to store it in the cached index
|
||||
var respBytes bytes.Buffer
|
||||
err = resp.Response.Write(&respBytes)
|
||||
if err != nil {
|
||||
c.logger.Error("failed to serialize response", "error", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Reset the response body for upper layers to read
|
||||
if resp.Response.Body != nil {
|
||||
resp.Response.Body.Close()
|
||||
}
|
||||
resp.Response.Body = ioutil.NopCloser(bytes.NewBuffer(resp.ResponseBody))
|
||||
|
||||
// Set the index's Response
|
||||
index.Response = respBytes.Bytes()
|
||||
|
||||
// Store the index ID in the renewer context
|
||||
renewCtx := context.WithValue(renewCtxInfo.Ctx, contextIndexID, index.ID)
|
||||
|
||||
// Store the renewer context in the index
|
||||
index.RenewCtxInfo = &cachememdb.ContextInfo{
|
||||
Ctx: renewCtx,
|
||||
CancelFunc: renewCtxInfo.CancelFunc,
|
||||
DoneCh: renewCtxInfo.DoneCh,
|
||||
}
|
||||
|
||||
// Store the index in the cache
|
||||
c.logger.Debug("storing response into the cache", "path", req.Request.URL.Path, "method", req.Request.Method)
|
||||
err = c.db.Set(index)
|
||||
if err != nil {
|
||||
c.logger.Error("failed to cache the proxied response", "error", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Start renewing the secret in the response
|
||||
go c.startRenewing(renewCtx, index, req, secret)
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (c *LeaseCache) createCtxInfo(ctx context.Context, token string) *ContextInfo {
|
||||
if ctx == nil {
|
||||
ctx = c.baseCtxInfo.Ctx
|
||||
}
|
||||
ctxInfo := new(ContextInfo)
|
||||
ctxInfo.Ctx, ctxInfo.CancelFunc = context.WithCancel(ctx)
|
||||
ctxInfo.DoneCh = make(chan struct{})
|
||||
return ctxInfo
|
||||
}
|
||||
|
||||
func (c *LeaseCache) startRenewing(ctx context.Context, index *cachememdb.Index, req *SendRequest, secret *api.Secret) {
|
||||
defer func() {
|
||||
id := ctx.Value(contextIndexID).(string)
|
||||
c.logger.Debug("evicting index from cache", "id", id, "path", req.Request.URL.Path, "method", req.Request.Method)
|
||||
err := c.db.Evict(cachememdb.IndexNameID, id)
|
||||
if err != nil {
|
||||
c.logger.Error("failed to evict index", "id", id, "error", err)
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
client, err := api.NewClient(api.DefaultConfig())
|
||||
if err != nil {
|
||||
c.logger.Error("failed to create API client in the renewer", "error", err)
|
||||
return
|
||||
}
|
||||
client.SetToken(req.Token)
|
||||
client.SetHeaders(req.Request.Header)
|
||||
|
||||
renewer, err := client.NewRenewer(&api.RenewerInput{
|
||||
Secret: secret,
|
||||
})
|
||||
if err != nil {
|
||||
c.logger.Error("failed to create secret renewer", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
c.logger.Debug("initiating renewal", "path", req.Request.URL.Path, "method", req.Request.Method)
|
||||
go renewer.Renew()
|
||||
defer renewer.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// This is the case which captures context cancellations from token
|
||||
// and leases. Since all the contexts are derived from the agent's
|
||||
// context, this will also cover the shutdown scenario.
|
||||
c.logger.Debug("context cancelled; stopping renewer", "path", req.Request.URL.Path)
|
||||
return
|
||||
case err := <-renewer.DoneCh():
|
||||
// This case covers renewal completion and renewal errors
|
||||
if err != nil {
|
||||
c.logger.Error("failed to renew secret", "error", err)
|
||||
return
|
||||
}
|
||||
c.logger.Debug("renewal halted; evicting from cache", "path", req.Request.URL.Path)
|
||||
return
|
||||
case renewal := <-renewer.RenewCh():
|
||||
// This case captures secret renewals. Renewed secret is updated in
|
||||
// the cached index.
|
||||
c.logger.Debug("renewal received; updating cache", "path", req.Request.URL.Path)
|
||||
err = c.updateResponse(ctx, renewal)
|
||||
if err != nil {
|
||||
c.logger.Error("failed to handle renewal", "error", err)
|
||||
return
|
||||
}
|
||||
case <-index.RenewCtxInfo.DoneCh:
|
||||
// This case indicates the renewal process to shutdown and evict
|
||||
// the cache entry. This is triggered when a specific secret
|
||||
// renewal needs to be killed without affecting any of the derived
|
||||
// context renewals.
|
||||
c.logger.Debug("done channel closed")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *LeaseCache) updateResponse(ctx context.Context, renewal *api.RenewOutput) error {
|
||||
id := ctx.Value(contextIndexID).(string)
|
||||
|
||||
// Get the cached index using the id in the context
|
||||
index, err := c.db.Get(cachememdb.IndexNameID, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if index == nil {
|
||||
return fmt.Errorf("missing cache entry for id: %q", id)
|
||||
}
|
||||
|
||||
// Read the response from the index
|
||||
resp, err := http.ReadResponse(bufio.NewReader(bytes.NewReader(index.Response)), nil)
|
||||
if err != nil {
|
||||
c.logger.Error("failed to deserialize response", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// Update the body in the reponse by the renewed secret
|
||||
bodyBytes, err := json.Marshal(renewal.Secret)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
resp.Body.Close()
|
||||
}
|
||||
resp.Body = ioutil.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
resp.ContentLength = int64(len(bodyBytes))
|
||||
|
||||
// Serialize the response
|
||||
var respBytes bytes.Buffer
|
||||
err = resp.Write(&respBytes)
|
||||
if err != nil {
|
||||
c.logger.Error("failed to serialize updated response", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// Update the response in the index and set it in the cache
|
||||
index.Response = respBytes.Bytes()
|
||||
err = c.db.Set(index)
|
||||
if err != nil {
|
||||
c.logger.Error("failed to cache the proxied response", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// computeIndexID results in a value that uniquely identifies a request
|
||||
// received by the agent. It does so by SHA256 hashing the serialized request
|
||||
// object containing the request path, query parameters and body parameters.
|
||||
func computeIndexID(req *SendRequest) (string, error) {
|
||||
var b bytes.Buffer
|
||||
|
||||
// Serialze the request
|
||||
if err := req.Request.Write(&b); err != nil {
|
||||
return "", fmt.Errorf("failed to serialize request: %v", err)
|
||||
}
|
||||
|
||||
// Reset the request body after it has been closed by Write
|
||||
req.Request.Body = ioutil.NopCloser(bytes.NewBuffer(req.RequestBody))
|
||||
|
||||
// Append req.Token into the byte slice. This is needed since auto-auth'ed
|
||||
// requests sets the token directly into SendRequest.Token
|
||||
b.Write([]byte(req.Token))
|
||||
|
||||
sum := sha256.Sum256(b.Bytes())
|
||||
return hex.EncodeToString(sum[:]), nil
|
||||
}
|
||||
|
||||
// HandleCacheClear returns a handlerFunc that can perform cache clearing operations.
|
||||
func (c *LeaseCache) HandleCacheClear(ctx context.Context) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
req := new(cacheClearRequest)
|
||||
if err := jsonutil.DecodeJSONFromReader(r.Body, req); err != nil {
|
||||
if err == io.EOF {
|
||||
err = errors.New("empty JSON provided")
|
||||
}
|
||||
respondError(w, http.StatusBadRequest, errwrap.Wrapf("failed to parse JSON input: {{err}}", err))
|
||||
return
|
||||
}
|
||||
|
||||
c.logger.Debug("received cache-clear request", "type", req.Type, "namespace", req.Namespace, "value", req.Value)
|
||||
|
||||
if err := c.handleCacheClear(ctx, req.Type, req.Namespace, req.Value); err != nil {
|
||||
// Default to 500 on error, unless the user provided an invalid type,
|
||||
// which would then be a 400.
|
||||
httpStatus := http.StatusInternalServerError
|
||||
if err == errInvalidType {
|
||||
httpStatus = http.StatusBadRequest
|
||||
}
|
||||
respondError(w, httpStatus, errwrap.Wrapf("failed to clear cache: {{err}}", err))
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
})
|
||||
}
|
||||
|
||||
func (c *LeaseCache) handleCacheClear(ctx context.Context, clearType string, clearValues ...interface{}) error {
|
||||
if len(clearValues) == 0 {
|
||||
return errors.New("no value(s) provided to clear corresponding cache entries")
|
||||
}
|
||||
|
||||
// The value that we want to clear, for most cases, is the last one provided.
|
||||
clearValue, ok := clearValues[len(clearValues)-1].(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("unable to convert %v to type string", clearValue)
|
||||
}
|
||||
|
||||
switch clearType {
|
||||
case "request_path":
|
||||
// For this particular case, we need to ensure that there are 2 provided
|
||||
// indexers for the proper lookup.
|
||||
if len(clearValues) != 2 {
|
||||
return fmt.Errorf("clearing cache by request path requires 2 indexers, got %d", len(clearValues))
|
||||
}
|
||||
|
||||
// The first value provided for this case will be the namespace, but if it's
|
||||
// an empty value we need to overwrite it with "root/" to ensure proper
|
||||
// cache lookup.
|
||||
if clearValues[0].(string) == "" {
|
||||
clearValues[0] = "root/"
|
||||
}
|
||||
|
||||
// Find all the cached entries which has the given request path and
|
||||
// cancel the contexts of all the respective renewers
|
||||
indexes, err := c.db.GetByPrefix(clearType, clearValues...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, index := range indexes {
|
||||
index.RenewCtxInfo.CancelFunc()
|
||||
}
|
||||
|
||||
case "token":
|
||||
if clearValue == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get the context for the given token and cancel its context
|
||||
index, err := c.db.Get(cachememdb.IndexNameToken, clearValue)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if index == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
c.logger.Debug("cancelling context of index attached to token")
|
||||
|
||||
index.RenewCtxInfo.CancelFunc()
|
||||
|
||||
case "token_accessor", "lease":
|
||||
// Get the cached index and cancel the corresponding renewer context
|
||||
index, err := c.db.Get(clearType, clearValue)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if index == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
c.logger.Debug("cancelling context of index attached to accessor")
|
||||
|
||||
index.RenewCtxInfo.CancelFunc()
|
||||
|
||||
case "all":
|
||||
// Cancel the base context which triggers all the goroutines to
|
||||
// stop and evict entries from cache.
|
||||
c.logger.Debug("cancelling base context")
|
||||
c.baseCtxInfo.CancelFunc()
|
||||
|
||||
// Reset the base context
|
||||
baseCtx, baseCancel := context.WithCancel(ctx)
|
||||
c.baseCtxInfo = &ContextInfo{
|
||||
Ctx: baseCtx,
|
||||
CancelFunc: baseCancel,
|
||||
}
|
||||
|
||||
// Reset the memdb instance
|
||||
if err := c.db.Flush(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
default:
|
||||
return errInvalidType
|
||||
}
|
||||
|
||||
c.logger.Debug("successfully cleared matching cache entries")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleRevocationRequest checks whether the originating request is a
|
||||
// revocation request, and if so perform applicable cache cleanups.
|
||||
// Returns true is this is a revocation request.
|
||||
func (c *LeaseCache) handleRevocationRequest(ctx context.Context, req *SendRequest, resp *SendResponse) (bool, error) {
|
||||
// Lease and token revocations return 204's on success. Fast-path if that's
|
||||
// not the case.
|
||||
if resp.Response.StatusCode != http.StatusNoContent {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
_, path := deriveNamespaceAndRevocationPath(req)
|
||||
|
||||
switch {
|
||||
case path == vaultPathTokenRevoke:
|
||||
// Get the token from the request body
|
||||
jsonBody := map[string]interface{}{}
|
||||
if err := json.Unmarshal(req.RequestBody, &jsonBody); err != nil {
|
||||
return false, err
|
||||
}
|
||||
tokenRaw, ok := jsonBody["token"]
|
||||
if !ok {
|
||||
return false, fmt.Errorf("failed to get token from request body")
|
||||
}
|
||||
token, ok := tokenRaw.(string)
|
||||
if !ok {
|
||||
return false, fmt.Errorf("expected token in the request body to be string")
|
||||
}
|
||||
|
||||
// Clear the cache entry associated with the token and all the other
|
||||
// entries belonging to the leases derived from this token.
|
||||
if err := c.handleCacheClear(ctx, "token", token); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
case path == vaultPathTokenRevokeSelf:
|
||||
// Clear the cache entry associated with the token and all the other
|
||||
// entries belonging to the leases derived from this token.
|
||||
if err := c.handleCacheClear(ctx, "token", req.Token); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
case path == vaultPathTokenRevokeAccessor:
|
||||
jsonBody := map[string]interface{}{}
|
||||
if err := json.Unmarshal(req.RequestBody, &jsonBody); err != nil {
|
||||
return false, err
|
||||
}
|
||||
accessorRaw, ok := jsonBody["accessor"]
|
||||
if !ok {
|
||||
return false, fmt.Errorf("failed to get accessor from request body")
|
||||
}
|
||||
accessor, ok := accessorRaw.(string)
|
||||
if !ok {
|
||||
return false, fmt.Errorf("expected accessor in the request body to be string")
|
||||
}
|
||||
|
||||
if err := c.handleCacheClear(ctx, "token_accessor", accessor); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
case path == vaultPathTokenRevokeOrphan:
|
||||
jsonBody := map[string]interface{}{}
|
||||
if err := json.Unmarshal(req.RequestBody, &jsonBody); err != nil {
|
||||
return false, err
|
||||
}
|
||||
tokenRaw, ok := jsonBody["token"]
|
||||
if !ok {
|
||||
return false, fmt.Errorf("failed to get token from request body")
|
||||
}
|
||||
token, ok := tokenRaw.(string)
|
||||
if !ok {
|
||||
return false, fmt.Errorf("expected token in the request body to be string")
|
||||
}
|
||||
|
||||
// Kill the renewers of all the leases attached to the revoked token
|
||||
indexes, err := c.db.GetByPrefix(cachememdb.IndexNameLeaseToken, token)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
for _, index := range indexes {
|
||||
index.RenewCtxInfo.CancelFunc()
|
||||
}
|
||||
|
||||
// Kill the renewer of the revoked token
|
||||
index, err := c.db.Get(cachememdb.IndexNameToken, token)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if index == nil {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// Indicate the renewer goroutine for this index to return. This will
|
||||
// not affect the child tokens because the context is not getting
|
||||
// cancelled.
|
||||
close(index.RenewCtxInfo.DoneCh)
|
||||
|
||||
// Clear the parent references of the revoked token in the entries
|
||||
// belonging to the child tokens of the revoked token.
|
||||
indexes, err = c.db.GetByPrefix(cachememdb.IndexNameTokenParent, token)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
for _, index := range indexes {
|
||||
index.TokenParent = ""
|
||||
err = c.db.Set(index)
|
||||
if err != nil {
|
||||
c.logger.Error("failed to persist index", "error", err)
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
|
||||
case path == vaultPathLeaseRevoke:
|
||||
// TODO: Should lease present in the URL itself be considered here?
|
||||
// Get the lease from the request body
|
||||
jsonBody := map[string]interface{}{}
|
||||
if err := json.Unmarshal(req.RequestBody, &jsonBody); err != nil {
|
||||
return false, err
|
||||
}
|
||||
leaseIDRaw, ok := jsonBody["lease_id"]
|
||||
if !ok {
|
||||
return false, fmt.Errorf("failed to get lease_id from request body")
|
||||
}
|
||||
leaseID, ok := leaseIDRaw.(string)
|
||||
if !ok {
|
||||
return false, fmt.Errorf("expected lease_id the request body to be string")
|
||||
}
|
||||
if err := c.handleCacheClear(ctx, "lease", leaseID); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
case strings.HasPrefix(path, vaultPathLeaseRevokeForce):
|
||||
// Trim the URL path to get the request path prefix
|
||||
prefix := strings.TrimPrefix(path, vaultPathLeaseRevokeForce)
|
||||
// Get all the cache indexes that use the request path containing the
|
||||
// prefix and cancel the renewer context of each.
|
||||
indexes, err := c.db.GetByPrefix(cachememdb.IndexNameLease, prefix)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
_, tokenNSID := namespace.SplitIDFromString(req.Token)
|
||||
for _, index := range indexes {
|
||||
_, leaseNSID := namespace.SplitIDFromString(index.Lease)
|
||||
// Only evict leases that match the token's namespace
|
||||
if tokenNSID == leaseNSID {
|
||||
index.RenewCtxInfo.CancelFunc()
|
||||
}
|
||||
}
|
||||
|
||||
case strings.HasPrefix(path, vaultPathLeaseRevokePrefix):
|
||||
// Trim the URL path to get the request path prefix
|
||||
prefix := strings.TrimPrefix(path, vaultPathLeaseRevokePrefix)
|
||||
// Get all the cache indexes that use the request path containing the
|
||||
// prefix and cancel the renewer context of each.
|
||||
indexes, err := c.db.GetByPrefix(cachememdb.IndexNameLease, prefix)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
_, tokenNSID := namespace.SplitIDFromString(req.Token)
|
||||
for _, index := range indexes {
|
||||
_, leaseNSID := namespace.SplitIDFromString(index.Lease)
|
||||
// Only evict leases that match the token's namespace
|
||||
if tokenNSID == leaseNSID {
|
||||
index.RenewCtxInfo.CancelFunc()
|
||||
}
|
||||
}
|
||||
|
||||
default:
|
||||
return false, nil
|
||||
}
|
||||
|
||||
c.logger.Debug("triggered caching eviction from revocation request")
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// deriveNamespaceAndRevocationPath returns the namespace and relative path for
|
||||
// revocation paths.
|
||||
//
|
||||
// If the path contains a namespace, but it's not a revocation path, it will be
|
||||
// returned as-is, since there's no way to tell where the namespace ends and
|
||||
// where the request path begins purely based off a string.
|
||||
//
|
||||
// Case 1: /v1/ns1/leases/revoke -> ns1/, /v1/leases/revoke
|
||||
// Case 2: ns1/ /v1/leases/revoke -> ns1/, /v1/leases/revoke
|
||||
// Case 3: /v1/ns1/foo/bar -> root/, /v1/ns1/foo/bar
|
||||
// Case 4: ns1/ /v1/foo/bar -> ns1/, /v1/foo/bar
|
||||
func deriveNamespaceAndRevocationPath(req *SendRequest) (string, string) {
|
||||
namespace := "root/"
|
||||
nsHeader := req.Request.Header.Get(consts.NamespaceHeaderName)
|
||||
if nsHeader != "" {
|
||||
namespace = nsHeader
|
||||
}
|
||||
|
||||
fullPath := req.Request.URL.Path
|
||||
nonVersionedPath := strings.TrimPrefix(fullPath, "/v1")
|
||||
|
||||
for _, pathToCheck := range revocationPaths {
|
||||
// We use strings.Contains here for paths that can contain
|
||||
// vars in the path, e.g. /v1/lease/revoke-prefix/:prefix
|
||||
i := strings.Index(nonVersionedPath, pathToCheck)
|
||||
// If there's no match, move on to the next check
|
||||
if i == -1 {
|
||||
continue
|
||||
}
|
||||
|
||||
// If the index is 0, this is a relative path with no namespace preppended,
|
||||
// so we can break early
|
||||
if i == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
// We need to turn /ns1 into ns1/, this makes it easy
|
||||
namespaceInPath := nshelper.Canonicalize(nonVersionedPath[:i])
|
||||
|
||||
// If it's root, we replace, otherwise we join
|
||||
if namespace == "root/" {
|
||||
namespace = namespaceInPath
|
||||
} else {
|
||||
namespace = namespace + namespaceInPath
|
||||
}
|
||||
|
||||
return namespace, fmt.Sprintf("/v1%s", nonVersionedPath[i:])
|
||||
}
|
||||
|
||||
return namespace, fmt.Sprintf("/v1%s", nonVersionedPath)
|
||||
}
|
507
command/agent/cache/lease_cache_test.go
vendored
Normal file
507
command/agent/cache/lease_cache_test.go
vendored
Normal file
|
@ -0,0 +1,507 @@
|
|||
package cache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/go-test/deep"
|
||||
hclog "github.com/hashicorp/go-hclog"
|
||||
"github.com/hashicorp/vault/api"
|
||||
"github.com/hashicorp/vault/helper/consts"
|
||||
"github.com/hashicorp/vault/helper/logging"
|
||||
)
|
||||
|
||||
func testNewLeaseCache(t *testing.T, responses []*SendResponse) *LeaseCache {
|
||||
t.Helper()
|
||||
|
||||
lc, err := NewLeaseCache(&LeaseCacheConfig{
|
||||
BaseContext: context.Background(),
|
||||
Proxier: newMockProxier(responses),
|
||||
Logger: logging.NewVaultLogger(hclog.Trace).Named("cache.leasecache"),
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return lc
|
||||
}
|
||||
|
||||
func TestCache_ComputeIndexID(t *testing.T) {
|
||||
type args struct {
|
||||
req *http.Request
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
req *SendRequest
|
||||
want string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
"basic",
|
||||
&SendRequest{
|
||||
Request: &http.Request{
|
||||
URL: &url.URL{
|
||||
Path: "test",
|
||||
},
|
||||
},
|
||||
},
|
||||
"2edc7e965c3e1bdce3b1d5f79a52927842569c0734a86544d222753f11ae4847",
|
||||
false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := computeIndexID(tt.req)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("actual_error: %v, expected_error: %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, string(tt.want)) {
|
||||
t.Errorf("bad: index id; actual: %q, expected: %q", got, string(tt.want))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCache_LeaseCache_EmptyToken(t *testing.T) {
|
||||
responses := []*SendResponse{
|
||||
&SendResponse{
|
||||
Response: &api.Response{
|
||||
Response: &http.Response{
|
||||
StatusCode: http.StatusCreated,
|
||||
Body: ioutil.NopCloser(strings.NewReader(`{"value": "invalid", "auth": {"client_token": "testtoken"}}`)),
|
||||
},
|
||||
},
|
||||
ResponseBody: []byte(`{"value": "invalid", "auth": {"client_token": "testtoken"}}`),
|
||||
},
|
||||
}
|
||||
lc := testNewLeaseCache(t, responses)
|
||||
|
||||
// Even if the send request doesn't have a token on it, a successful
|
||||
// cacheable response should result in the index properly getting populated
|
||||
// with a token and memdb shouldn't complain while inserting the index.
|
||||
urlPath := "http://example.com/v1/sample/api"
|
||||
sendReq := &SendRequest{
|
||||
Request: httptest.NewRequest("GET", urlPath, strings.NewReader(`{"value": "input"}`)),
|
||||
}
|
||||
resp, err := lc.Send(context.Background(), sendReq)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp == nil {
|
||||
t.Fatalf("expected a non empty response")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCache_LeaseCache_SendCacheable(t *testing.T) {
|
||||
// Emulate 2 responses from the api proxy. One returns a new token and the
|
||||
// other returns a lease.
|
||||
responses := []*SendResponse{
|
||||
&SendResponse{
|
||||
Response: &api.Response{
|
||||
Response: &http.Response{
|
||||
StatusCode: http.StatusCreated,
|
||||
Body: ioutil.NopCloser(strings.NewReader(`{"value": "invalid", "auth": {"client_token": "testtoken", "renewable": true}}`)),
|
||||
},
|
||||
},
|
||||
ResponseBody: []byte(`{"value": "invalid", "auth": {"client_token": "testtoken", "renewable": true}}`),
|
||||
},
|
||||
&SendResponse{
|
||||
Response: &api.Response{
|
||||
Response: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: ioutil.NopCloser(strings.NewReader(`{"value": "output", "lease_id": "foo", "renewable": true}`)),
|
||||
},
|
||||
},
|
||||
ResponseBody: []byte(`{"value": "output", "lease_id": "foo", "renewable": true}`),
|
||||
},
|
||||
}
|
||||
lc := testNewLeaseCache(t, responses)
|
||||
|
||||
// Make a request. A response with a new token is returned to the lease
|
||||
// cache and that will be cached.
|
||||
urlPath := "http://example.com/v1/sample/api"
|
||||
sendReq := &SendRequest{
|
||||
Request: httptest.NewRequest("GET", urlPath, strings.NewReader(`{"value": "input"}`)),
|
||||
}
|
||||
resp, err := lc.Send(context.Background(), sendReq)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if diff := deep.Equal(resp.Response.StatusCode, responses[0].Response.StatusCode); diff != nil {
|
||||
t.Fatalf("expected getting proxied response: got %v", diff)
|
||||
}
|
||||
|
||||
// Send the same request again to get the cached response
|
||||
sendReq = &SendRequest{
|
||||
Request: httptest.NewRequest("GET", urlPath, strings.NewReader(`{"value": "input"}`)),
|
||||
}
|
||||
resp, err = lc.Send(context.Background(), sendReq)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if diff := deep.Equal(resp.Response.StatusCode, responses[0].Response.StatusCode); diff != nil {
|
||||
t.Fatalf("expected getting proxied response: got %v", diff)
|
||||
}
|
||||
|
||||
// Modify the request a little bit to ensure the second response is
|
||||
// returned to the lease cache. But make sure that the token in the request
|
||||
// is valid.
|
||||
sendReq = &SendRequest{
|
||||
Token: "testtoken",
|
||||
Request: httptest.NewRequest("GET", urlPath, strings.NewReader(`{"value": "input_changed"}`)),
|
||||
}
|
||||
resp, err = lc.Send(context.Background(), sendReq)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if diff := deep.Equal(resp.Response.StatusCode, responses[1].Response.StatusCode); diff != nil {
|
||||
t.Fatalf("expected getting proxied response: got %v", diff)
|
||||
}
|
||||
|
||||
// Make the same request again and ensure that the same reponse is returned
|
||||
// again.
|
||||
sendReq = &SendRequest{
|
||||
Token: "testtoken",
|
||||
Request: httptest.NewRequest("GET", urlPath, strings.NewReader(`{"value": "input_changed"}`)),
|
||||
}
|
||||
resp, err = lc.Send(context.Background(), sendReq)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if diff := deep.Equal(resp.Response.StatusCode, responses[1].Response.StatusCode); diff != nil {
|
||||
t.Fatalf("expected getting proxied response: got %v", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCache_LeaseCache_SendNonCacheable(t *testing.T) {
|
||||
responses := []*SendResponse{
|
||||
&SendResponse{
|
||||
Response: &api.Response{
|
||||
Response: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: ioutil.NopCloser(strings.NewReader(`{"value": "output"}`)),
|
||||
},
|
||||
},
|
||||
},
|
||||
&SendResponse{
|
||||
Response: &api.Response{
|
||||
Response: &http.Response{
|
||||
StatusCode: http.StatusNotFound,
|
||||
Body: ioutil.NopCloser(strings.NewReader(`{"value": "invalid"}`)),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
lc := testNewLeaseCache(t, responses)
|
||||
|
||||
// Send a request through the lease cache which is not cacheable (there is
|
||||
// no lease information or auth information in the response)
|
||||
sendReq := &SendRequest{
|
||||
Request: httptest.NewRequest("GET", "http://example.com", strings.NewReader(`{"value": "input"}`)),
|
||||
}
|
||||
resp, err := lc.Send(context.Background(), sendReq)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if diff := deep.Equal(resp.Response.StatusCode, responses[0].Response.StatusCode); diff != nil {
|
||||
t.Fatalf("expected getting proxied response: got %v", diff)
|
||||
}
|
||||
|
||||
// Since the response is non-cacheable, the second response will be
|
||||
// returned.
|
||||
sendReq = &SendRequest{
|
||||
Token: "foo",
|
||||
Request: httptest.NewRequest("GET", "http://example.com", strings.NewReader(`{"value": "input"}`)),
|
||||
}
|
||||
resp, err = lc.Send(context.Background(), sendReq)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if diff := deep.Equal(resp.Response.StatusCode, responses[1].Response.StatusCode); diff != nil {
|
||||
t.Fatalf("expected getting proxied response: got %v", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCache_LeaseCache_SendNonCacheableNonTokenLease(t *testing.T) {
|
||||
// Create the cache
|
||||
responses := []*SendResponse{
|
||||
&SendResponse{
|
||||
Response: &api.Response{
|
||||
Response: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: ioutil.NopCloser(strings.NewReader(`{"value": "output", "lease_id": "foo"}`)),
|
||||
},
|
||||
},
|
||||
ResponseBody: []byte(`{"value": "output", "lease_id": "foo"}`),
|
||||
},
|
||||
&SendResponse{
|
||||
Response: &api.Response{
|
||||
Response: &http.Response{
|
||||
StatusCode: http.StatusCreated,
|
||||
Body: ioutil.NopCloser(strings.NewReader(`{"value": "invalid", "auth": {"client_token": "testtoken"}}`)),
|
||||
},
|
||||
},
|
||||
ResponseBody: []byte(`{"value": "invalid", "auth": {"client_token": "testtoken"}}`),
|
||||
},
|
||||
}
|
||||
lc := testNewLeaseCache(t, responses)
|
||||
|
||||
// Send a request through lease cache which returns a response containing
|
||||
// lease_id. Response will not be cached because it doesn't belong to a
|
||||
// token that is managed by the lease cache.
|
||||
urlPath := "http://example.com/v1/sample/api"
|
||||
sendReq := &SendRequest{
|
||||
Token: "foo",
|
||||
Request: httptest.NewRequest("GET", urlPath, strings.NewReader(`{"value": "input"}`)),
|
||||
}
|
||||
resp, err := lc.Send(context.Background(), sendReq)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if diff := deep.Equal(resp.Response.StatusCode, responses[0].Response.StatusCode); diff != nil {
|
||||
t.Fatalf("expected getting proxied response: got %v", diff)
|
||||
}
|
||||
|
||||
// Verify that the response is not cached by sending the same request and
|
||||
// by expecting a different response.
|
||||
sendReq = &SendRequest{
|
||||
Token: "foo",
|
||||
Request: httptest.NewRequest("GET", urlPath, strings.NewReader(`{"value": "input"}`)),
|
||||
}
|
||||
resp, err = lc.Send(context.Background(), sendReq)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if diff := deep.Equal(resp.Response.StatusCode, responses[0].Response.StatusCode); diff == nil {
|
||||
t.Fatalf("expected getting proxied response: got %v", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCache_LeaseCache_HandleCacheClear(t *testing.T) {
|
||||
lc := testNewLeaseCache(t, nil)
|
||||
|
||||
handler := lc.HandleCacheClear(context.Background())
|
||||
ts := httptest.NewServer(handler)
|
||||
defer ts.Close()
|
||||
|
||||
// Test missing body, should return 400
|
||||
resp, err := http.Post(ts.URL, "application/json", nil)
|
||||
if err != nil {
|
||||
t.Fatal()
|
||||
}
|
||||
if resp.StatusCode != http.StatusBadRequest {
|
||||
t.Fatalf("status code mismatch: expected = %v, got = %v", http.StatusBadRequest, resp.StatusCode)
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
reqType string
|
||||
reqValue string
|
||||
expectedStatusCode int
|
||||
}{
|
||||
{
|
||||
"invalid_type",
|
||||
"foo",
|
||||
"",
|
||||
http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
"invalid_value",
|
||||
"",
|
||||
"bar",
|
||||
http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
"all",
|
||||
"all",
|
||||
"",
|
||||
http.StatusOK,
|
||||
},
|
||||
{
|
||||
"by_request_path",
|
||||
"request_path",
|
||||
"foo",
|
||||
http.StatusOK,
|
||||
},
|
||||
{
|
||||
"by_token",
|
||||
"token",
|
||||
"foo",
|
||||
http.StatusOK,
|
||||
},
|
||||
{
|
||||
"by_lease",
|
||||
"lease",
|
||||
"foo",
|
||||
http.StatusOK,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
reqBody := fmt.Sprintf("{\"type\": \"%s\", \"value\": \"%s\"}", tc.reqType, tc.reqValue)
|
||||
resp, err := http.Post(ts.URL, "application/json", strings.NewReader(reqBody))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if tc.expectedStatusCode != resp.StatusCode {
|
||||
t.Fatalf("status code mismatch: expected = %v, got = %v", tc.expectedStatusCode, resp.StatusCode)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCache_DeriveNamespaceAndRevocationPath(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
req *SendRequest
|
||||
wantNamespace string
|
||||
wantRelativePath string
|
||||
}{
|
||||
{
|
||||
"non_revocation_full_path",
|
||||
&SendRequest{
|
||||
Request: &http.Request{
|
||||
URL: &url.URL{
|
||||
Path: "/v1/ns1/sys/mounts",
|
||||
},
|
||||
},
|
||||
},
|
||||
"root/",
|
||||
"/v1/ns1/sys/mounts",
|
||||
},
|
||||
{
|
||||
"non_revocation_relative_path",
|
||||
&SendRequest{
|
||||
Request: &http.Request{
|
||||
URL: &url.URL{
|
||||
Path: "/v1/sys/mounts",
|
||||
},
|
||||
Header: http.Header{
|
||||
consts.NamespaceHeaderName: []string{"ns1/"},
|
||||
},
|
||||
},
|
||||
},
|
||||
"ns1/",
|
||||
"/v1/sys/mounts",
|
||||
},
|
||||
{
|
||||
"non_revocation_relative_path",
|
||||
&SendRequest{
|
||||
Request: &http.Request{
|
||||
URL: &url.URL{
|
||||
Path: "/v1/ns2/sys/mounts",
|
||||
},
|
||||
Header: http.Header{
|
||||
consts.NamespaceHeaderName: []string{"ns1/"},
|
||||
},
|
||||
},
|
||||
},
|
||||
"ns1/",
|
||||
"/v1/ns2/sys/mounts",
|
||||
},
|
||||
{
|
||||
"revocation_full_path",
|
||||
&SendRequest{
|
||||
Request: &http.Request{
|
||||
URL: &url.URL{
|
||||
Path: "/v1/ns1/sys/leases/revoke",
|
||||
},
|
||||
},
|
||||
},
|
||||
"ns1/",
|
||||
"/v1/sys/leases/revoke",
|
||||
},
|
||||
{
|
||||
"revocation_relative_path",
|
||||
&SendRequest{
|
||||
Request: &http.Request{
|
||||
URL: &url.URL{
|
||||
Path: "/v1/sys/leases/revoke",
|
||||
},
|
||||
Header: http.Header{
|
||||
consts.NamespaceHeaderName: []string{"ns1/"},
|
||||
},
|
||||
},
|
||||
},
|
||||
"ns1/",
|
||||
"/v1/sys/leases/revoke",
|
||||
},
|
||||
{
|
||||
"revocation_relative_partial_ns",
|
||||
&SendRequest{
|
||||
Request: &http.Request{
|
||||
URL: &url.URL{
|
||||
Path: "/v1/ns2/sys/leases/revoke",
|
||||
},
|
||||
Header: http.Header{
|
||||
consts.NamespaceHeaderName: []string{"ns1/"},
|
||||
},
|
||||
},
|
||||
},
|
||||
"ns1/ns2/",
|
||||
"/v1/sys/leases/revoke",
|
||||
},
|
||||
{
|
||||
"revocation_prefix_full_path",
|
||||
&SendRequest{
|
||||
Request: &http.Request{
|
||||
URL: &url.URL{
|
||||
Path: "/v1/ns1/sys/leases/revoke-prefix/foo",
|
||||
},
|
||||
},
|
||||
},
|
||||
"ns1/",
|
||||
"/v1/sys/leases/revoke-prefix/foo",
|
||||
},
|
||||
{
|
||||
"revocation_prefix_relative_path",
|
||||
&SendRequest{
|
||||
Request: &http.Request{
|
||||
URL: &url.URL{
|
||||
Path: "/v1/sys/leases/revoke-prefix/foo",
|
||||
},
|
||||
Header: http.Header{
|
||||
consts.NamespaceHeaderName: []string{"ns1/"},
|
||||
},
|
||||
},
|
||||
},
|
||||
"ns1/",
|
||||
"/v1/sys/leases/revoke-prefix/foo",
|
||||
},
|
||||
{
|
||||
"revocation_prefix_partial_ns",
|
||||
&SendRequest{
|
||||
Request: &http.Request{
|
||||
URL: &url.URL{
|
||||
Path: "/v1/ns2/sys/leases/revoke-prefix/foo",
|
||||
},
|
||||
Header: http.Header{
|
||||
consts.NamespaceHeaderName: []string{"ns1/"},
|
||||
},
|
||||
},
|
||||
},
|
||||
"ns1/ns2/",
|
||||
"/v1/sys/leases/revoke-prefix/foo",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotNamespace, gotRelativePath := deriveNamespaceAndRevocationPath(tt.req)
|
||||
if gotNamespace != tt.wantNamespace {
|
||||
t.Errorf("deriveNamespaceAndRevocationPath() gotNamespace = %v, want %v", gotNamespace, tt.wantNamespace)
|
||||
}
|
||||
if gotRelativePath != tt.wantRelativePath {
|
||||
t.Errorf("deriveNamespaceAndRevocationPath() gotRelativePath = %v, want %v", gotRelativePath, tt.wantRelativePath)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
105
command/agent/cache/listener.go
vendored
Normal file
105
command/agent/cache/listener.go
vendored
Normal file
|
@ -0,0 +1,105 @@
|
|||
package cache
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/vault/command/agent/config"
|
||||
"github.com/hashicorp/vault/command/server"
|
||||
"github.com/hashicorp/vault/helper/reload"
|
||||
"github.com/mitchellh/cli"
|
||||
)
|
||||
|
||||
func ServerListener(lnConfig *config.Listener, logger io.Writer, ui cli.Ui) (net.Listener, map[string]string, reload.ReloadFunc, error) {
|
||||
switch lnConfig.Type {
|
||||
case "unix":
|
||||
return unixSocketListener(lnConfig.Config, logger, ui)
|
||||
case "tcp":
|
||||
return tcpListener(lnConfig.Config, logger, ui)
|
||||
default:
|
||||
return nil, nil, nil, fmt.Errorf("unsupported listener type: %q", lnConfig.Type)
|
||||
}
|
||||
}
|
||||
|
||||
func unixSocketListener(config map[string]interface{}, _ io.Writer, ui cli.Ui) (net.Listener, map[string]string, reload.ReloadFunc, error) {
|
||||
addr, ok := config["address"].(string)
|
||||
if !ok {
|
||||
return nil, nil, nil, fmt.Errorf("invalid address: %v", config["address"])
|
||||
}
|
||||
|
||||
if addr == "" {
|
||||
return nil, nil, nil, fmt.Errorf("address field should point to socket file path")
|
||||
}
|
||||
|
||||
// Remove the socket file as it shouldn't exist for the domain socket to
|
||||
// work
|
||||
err := os.Remove(addr)
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
return nil, nil, nil, fmt.Errorf("failed to remove the socket file: %v", err)
|
||||
}
|
||||
|
||||
listener, err := net.Listen("unix", addr)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
// Wrap the listener in rmListener so that the Unix domain socket file is
|
||||
// removed on close.
|
||||
listener = &rmListener{
|
||||
Listener: listener,
|
||||
Path: addr,
|
||||
}
|
||||
|
||||
props := map[string]string{"addr": addr, "tls": "disabled"}
|
||||
|
||||
return server.ListenerWrapTLS(listener, props, config, ui)
|
||||
}
|
||||
|
||||
func tcpListener(config map[string]interface{}, _ io.Writer, ui cli.Ui) (net.Listener, map[string]string, reload.ReloadFunc, error) {
|
||||
bindProto := "tcp"
|
||||
var addr string
|
||||
addrRaw, ok := config["address"]
|
||||
if !ok {
|
||||
addr = "127.0.0.1:8300"
|
||||
} else {
|
||||
addr = addrRaw.(string)
|
||||
}
|
||||
|
||||
// If they've passed 0.0.0.0, we only want to bind on IPv4
|
||||
// rather than golang's dual stack default
|
||||
if strings.HasPrefix(addr, "0.0.0.0:") {
|
||||
bindProto = "tcp4"
|
||||
}
|
||||
|
||||
ln, err := net.Listen(bindProto, addr)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
ln = server.TCPKeepAliveListener{ln.(*net.TCPListener)}
|
||||
|
||||
props := map[string]string{"addr": addr}
|
||||
|
||||
return server.ListenerWrapTLS(ln, props, config, ui)
|
||||
}
|
||||
|
||||
// rmListener is an implementation of net.Listener that forwards most
|
||||
// calls to the listener but also removes a file as part of the close. We
|
||||
// use this to cleanup the unix domain socket on close.
|
||||
type rmListener struct {
|
||||
net.Listener
|
||||
Path string
|
||||
}
|
||||
|
||||
func (l *rmListener) Close() error {
|
||||
// Close the listener itself
|
||||
if err := l.Listener.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Remove the file
|
||||
return os.Remove(l.Path)
|
||||
}
|
28
command/agent/cache/proxy.go
vendored
Normal file
28
command/agent/cache/proxy.go
vendored
Normal file
|
@ -0,0 +1,28 @@
|
|||
package cache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/hashicorp/vault/api"
|
||||
)
|
||||
|
||||
// SendRequest is the input for Proxier.Send.
|
||||
type SendRequest struct {
|
||||
Token string
|
||||
Request *http.Request
|
||||
RequestBody []byte
|
||||
}
|
||||
|
||||
// SendResponse is the output from Proxier.Send.
|
||||
type SendResponse struct {
|
||||
Response *api.Response
|
||||
ResponseBody []byte
|
||||
}
|
||||
|
||||
// Proxier is the interface implemented by different components that are
|
||||
// responsible for performing specific tasks, such as caching and proxying. All
|
||||
// these tasks combined together would serve the request received by the agent.
|
||||
type Proxier interface {
|
||||
Send(ctx context.Context, req *SendRequest) (*SendResponse, error)
|
||||
}
|
36
command/agent/cache/testing.go
vendored
Normal file
36
command/agent/cache/testing.go
vendored
Normal file
|
@ -0,0 +1,36 @@
|
|||
package cache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// mockProxier is a mock implementation of the Proxier interface, used for testing purposes.
|
||||
// The mock will return the provided responses every time it reaches its Send method, up to
|
||||
// the last provided response. This lets tests control what the next/underlying Proxier layer
|
||||
// might expect to return.
|
||||
type mockProxier struct {
|
||||
proxiedResponses []*SendResponse
|
||||
responseIndex int
|
||||
}
|
||||
|
||||
func newMockProxier(responses []*SendResponse) *mockProxier {
|
||||
return &mockProxier{
|
||||
proxiedResponses: responses,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *mockProxier) Send(ctx context.Context, req *SendRequest) (*SendResponse, error) {
|
||||
if p.responseIndex >= len(p.proxiedResponses) {
|
||||
return nil, fmt.Errorf("index out of bounds: responseIndex = %d, responses = %d", p.responseIndex, len(p.proxiedResponses))
|
||||
}
|
||||
resp := p.proxiedResponses[p.responseIndex]
|
||||
|
||||
p.responseIndex++
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (p *mockProxier) ResponseIndex() int {
|
||||
return p.responseIndex
|
||||
}
|
280
command/agent/cache_end_to_end_test.go
Normal file
280
command/agent/cache_end_to_end_test.go
Normal file
|
@ -0,0 +1,280 @@
|
|||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
hclog "github.com/hashicorp/go-hclog"
|
||||
log "github.com/hashicorp/go-hclog"
|
||||
"github.com/hashicorp/vault/api"
|
||||
credAppRole "github.com/hashicorp/vault/builtin/credential/approle"
|
||||
"github.com/hashicorp/vault/command/agent/auth"
|
||||
agentapprole "github.com/hashicorp/vault/command/agent/auth/approle"
|
||||
"github.com/hashicorp/vault/command/agent/cache"
|
||||
"github.com/hashicorp/vault/command/agent/sink"
|
||||
"github.com/hashicorp/vault/command/agent/sink/file"
|
||||
"github.com/hashicorp/vault/helper/logging"
|
||||
vaulthttp "github.com/hashicorp/vault/http"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/vault"
|
||||
)
|
||||
|
||||
func TestCache_UsingAutoAuthToken(t *testing.T) {
|
||||
var err error
|
||||
logger := logging.NewVaultLogger(log.Trace)
|
||||
coreConfig := &vault.CoreConfig{
|
||||
DisableMlock: true,
|
||||
DisableCache: true,
|
||||
Logger: log.NewNullLogger(),
|
||||
CredentialBackends: map[string]logical.Factory{
|
||||
"approle": credAppRole.Factory,
|
||||
},
|
||||
}
|
||||
|
||||
cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{
|
||||
HandlerFunc: vaulthttp.Handler,
|
||||
})
|
||||
|
||||
cluster.Start()
|
||||
defer cluster.Cleanup()
|
||||
|
||||
cores := cluster.Cores
|
||||
|
||||
vault.TestWaitActive(t, cores[0].Core)
|
||||
|
||||
client := cores[0].Client
|
||||
|
||||
defer os.Setenv(api.EnvVaultAddress, os.Getenv(api.EnvVaultAddress))
|
||||
os.Setenv(api.EnvVaultAddress, client.Address())
|
||||
|
||||
defer os.Setenv(api.EnvVaultCACert, os.Getenv(api.EnvVaultCACert))
|
||||
os.Setenv(api.EnvVaultCACert, fmt.Sprintf("%s/ca_cert.pem", cluster.TempDir))
|
||||
|
||||
err = client.Sys().EnableAuthWithOptions("approle", &api.EnableAuthOptions{
|
||||
Type: "approle",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = client.Logical().Write("auth/approle/role/test1", map[string]interface{}{
|
||||
"bind_secret_id": "true",
|
||||
"token_ttl": "3s",
|
||||
"token_max_ttl": "10s",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
resp, err := client.Logical().Write("auth/approle/role/test1/secret-id", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
secretID1 := resp.Data["secret_id"].(string)
|
||||
|
||||
resp, err = client.Logical().Read("auth/approle/role/test1/role-id")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
roleID1 := resp.Data["role_id"].(string)
|
||||
|
||||
rolef, err := ioutil.TempFile("", "auth.role-id.test.")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
role := rolef.Name()
|
||||
rolef.Close() // WriteFile doesn't need it open
|
||||
defer os.Remove(role)
|
||||
t.Logf("input role_id_file_path: %s", role)
|
||||
|
||||
secretf, err := ioutil.TempFile("", "auth.secret-id.test.")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
secret := secretf.Name()
|
||||
secretf.Close()
|
||||
defer os.Remove(secret)
|
||||
t.Logf("input secret_id_file_path: %s", secret)
|
||||
|
||||
// We close these right away because we're just basically testing
|
||||
// permissions and finding a usable file name
|
||||
ouf, err := ioutil.TempFile("", "auth.tokensink.test.")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
out := ouf.Name()
|
||||
ouf.Close()
|
||||
os.Remove(out)
|
||||
t.Logf("output: %s", out)
|
||||
|
||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
timer := time.AfterFunc(30*time.Second, func() {
|
||||
cancelFunc()
|
||||
})
|
||||
defer timer.Stop()
|
||||
|
||||
conf := map[string]interface{}{
|
||||
"role_id_file_path": role,
|
||||
"secret_id_file_path": secret,
|
||||
"remove_secret_id_file_after_reading": true,
|
||||
}
|
||||
|
||||
am, err := agentapprole.NewApproleAuthMethod(&auth.AuthConfig{
|
||||
Logger: logger.Named("auth.approle"),
|
||||
MountPath: "auth/approle",
|
||||
Config: conf,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ahConfig := &auth.AuthHandlerConfig{
|
||||
Logger: logger.Named("auth.handler"),
|
||||
Client: client,
|
||||
}
|
||||
ah := auth.NewAuthHandler(ahConfig)
|
||||
go ah.Run(ctx, am)
|
||||
defer func() {
|
||||
<-ah.DoneCh
|
||||
}()
|
||||
|
||||
config := &sink.SinkConfig{
|
||||
Logger: logger.Named("sink.file"),
|
||||
Config: map[string]interface{}{
|
||||
"path": out,
|
||||
},
|
||||
}
|
||||
fs, err := file.NewFileSink(config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
config.Sink = fs
|
||||
|
||||
ss := sink.NewSinkServer(&sink.SinkServerConfig{
|
||||
Logger: logger.Named("sink.server"),
|
||||
Client: client,
|
||||
})
|
||||
go ss.Run(ctx, ah.OutputCh, []*sink.SinkConfig{config})
|
||||
defer func() {
|
||||
<-ss.DoneCh
|
||||
}()
|
||||
|
||||
// This has to be after the other defers so it happens first
|
||||
defer cancelFunc()
|
||||
|
||||
// Check that no sink file exists
|
||||
_, err = os.Lstat(out)
|
||||
if err == nil {
|
||||
t.Fatal("expected err")
|
||||
}
|
||||
if !os.IsNotExist(err) {
|
||||
t.Fatal("expected notexist err")
|
||||
}
|
||||
|
||||
if err := ioutil.WriteFile(role, []byte(roleID1), 0600); err != nil {
|
||||
t.Fatal(err)
|
||||
} else {
|
||||
logger.Trace("wrote test role 1", "path", role)
|
||||
}
|
||||
|
||||
if err := ioutil.WriteFile(secret, []byte(secretID1), 0600); err != nil {
|
||||
t.Fatal(err)
|
||||
} else {
|
||||
logger.Trace("wrote test secret 1", "path", secret)
|
||||
}
|
||||
|
||||
getToken := func() string {
|
||||
timeout := time.Now().Add(10 * time.Second)
|
||||
for {
|
||||
if time.Now().After(timeout) {
|
||||
t.Fatal("did not find a written token after timeout")
|
||||
}
|
||||
val, err := ioutil.ReadFile(out)
|
||||
if err == nil {
|
||||
os.Remove(out)
|
||||
if len(val) == 0 {
|
||||
t.Fatal("written token was empty")
|
||||
}
|
||||
|
||||
_, err = os.Stat(secret)
|
||||
if err == nil {
|
||||
t.Fatal("secret file exists but was supposed to be removed")
|
||||
}
|
||||
|
||||
client.SetToken(string(val))
|
||||
_, err := client.Auth().Token().LookupSelf()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return string(val)
|
||||
}
|
||||
time.Sleep(250 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
|
||||
t.Logf("auto-auth token: %q", getToken())
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
defer listener.Close()
|
||||
|
||||
cacheLogger := logging.NewVaultLogger(hclog.Trace).Named("cache")
|
||||
|
||||
// Create the API proxier
|
||||
apiProxy := cache.NewAPIProxy(&cache.APIProxyConfig{
|
||||
Logger: cacheLogger.Named("apiproxy"),
|
||||
})
|
||||
|
||||
// Create the lease cache proxier and set its underlying proxier to
|
||||
// the API proxier.
|
||||
leaseCache, err := cache.NewLeaseCache(&cache.LeaseCacheConfig{
|
||||
BaseContext: ctx,
|
||||
Proxier: apiProxy,
|
||||
Logger: cacheLogger.Named("leasecache"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Create a muxer and add paths relevant for the lease cache layer
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle("/v1/agent/cache-clear", leaseCache.HandleCacheClear(ctx))
|
||||
|
||||
mux.Handle("/", cache.Handler(ctx, cacheLogger, leaseCache, true, client))
|
||||
server := &http.Server{
|
||||
Handler: mux,
|
||||
ReadHeaderTimeout: 10 * time.Second,
|
||||
ReadTimeout: 30 * time.Second,
|
||||
IdleTimeout: 5 * time.Minute,
|
||||
ErrorLog: cacheLogger.StandardLogger(nil),
|
||||
}
|
||||
go server.Serve(listener)
|
||||
|
||||
testClient, err := api.NewClient(api.DefaultConfig())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := testClient.SetAddress("http://" + listener.Addr().String()); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Wait for listeners to come up
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
resp, err = testClient.Logical().Read("auth/token/lookup-self")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp == nil {
|
||||
t.Fatalf("failed to use the auto-auth token to perform lookup-self")
|
||||
}
|
||||
}
|
|
@ -22,6 +22,17 @@ type Config struct {
|
|||
AutoAuth *AutoAuth `hcl:"auto_auth"`
|
||||
ExitAfterAuth bool `hcl:"exit_after_auth"`
|
||||
PidFile string `hcl:"pid_file"`
|
||||
Cache *Cache `hcl:"cache"`
|
||||
}
|
||||
|
||||
type Cache struct {
|
||||
UseAutoAuthToken bool `hcl:"use_auto_auth_token"`
|
||||
Listeners []*Listener `hcl:"listeners"`
|
||||
}
|
||||
|
||||
type Listener struct {
|
||||
Type string
|
||||
Config map[string]interface{}
|
||||
}
|
||||
|
||||
type AutoAuth struct {
|
||||
|
@ -91,9 +102,94 @@ func LoadConfig(path string, logger log.Logger) (*Config, error) {
|
|||
return nil, errwrap.Wrapf("error parsing 'auto_auth': {{err}}", err)
|
||||
}
|
||||
|
||||
err = parseCache(&result, list)
|
||||
if err != nil {
|
||||
return nil, errwrap.Wrapf("error parsing 'cache':{{err}}", err)
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
func parseCache(result *Config, list *ast.ObjectList) error {
|
||||
name := "cache"
|
||||
|
||||
cacheList := list.Filter(name)
|
||||
if len(cacheList.Items) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(cacheList.Items) > 1 {
|
||||
return fmt.Errorf("one and only one %q block is required", name)
|
||||
}
|
||||
|
||||
item := cacheList.Items[0]
|
||||
|
||||
var c Cache
|
||||
err := hcl.DecodeObject(&c, item.Val)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
result.Cache = &c
|
||||
|
||||
subs, ok := item.Val.(*ast.ObjectType)
|
||||
if !ok {
|
||||
return fmt.Errorf("could not parse %q as an object", name)
|
||||
}
|
||||
subList := subs.List
|
||||
|
||||
err = parseListeners(result, subList)
|
||||
if err != nil {
|
||||
return errwrap.Wrapf("error parsing 'listener' stanzas: {{err}}", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseListeners(result *Config, list *ast.ObjectList) error {
|
||||
name := "listener"
|
||||
|
||||
listenerList := list.Filter(name)
|
||||
if len(listenerList.Items) < 1 {
|
||||
return fmt.Errorf("at least one %q block is required", name)
|
||||
}
|
||||
|
||||
var listeners []*Listener
|
||||
for _, item := range listenerList.Items {
|
||||
var lnConfig map[string]interface{}
|
||||
err := hcl.DecodeObject(&lnConfig, item.Val)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var lnType string
|
||||
switch {
|
||||
case lnConfig["type"] != nil:
|
||||
lnType = lnConfig["type"].(string)
|
||||
delete(lnConfig, "type")
|
||||
case len(item.Keys) == 1:
|
||||
lnType = strings.ToLower(item.Keys[0].Token.Value().(string))
|
||||
default:
|
||||
return errors.New("listener type must be specified")
|
||||
}
|
||||
|
||||
switch lnType {
|
||||
case "unix", "tcp":
|
||||
default:
|
||||
return fmt.Errorf("invalid listener type %q", lnType)
|
||||
}
|
||||
|
||||
listeners = append(listeners, &Listener{
|
||||
Type: lnType,
|
||||
Config: lnConfig,
|
||||
})
|
||||
}
|
||||
|
||||
result.Cache.Listeners = listeners
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseAutoAuth(result *Config, list *ast.ObjectList) error {
|
||||
name := "auto_auth"
|
||||
|
||||
|
|
|
@ -10,6 +10,80 @@ import (
|
|||
"github.com/hashicorp/vault/helper/logging"
|
||||
)
|
||||
|
||||
func TestLoadConfigFile_AgentCache(t *testing.T) {
|
||||
logger := logging.NewVaultLogger(log.Debug)
|
||||
|
||||
config, err := LoadConfig("./test-fixtures/config-cache.hcl", logger)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
expected := &Config{
|
||||
AutoAuth: &AutoAuth{
|
||||
Method: &Method{
|
||||
Type: "aws",
|
||||
WrapTTL: 300 * time.Second,
|
||||
MountPath: "auth/aws",
|
||||
Config: map[string]interface{}{
|
||||
"role": "foobar",
|
||||
},
|
||||
},
|
||||
Sinks: []*Sink{
|
||||
&Sink{
|
||||
Type: "file",
|
||||
DHType: "curve25519",
|
||||
DHPath: "/tmp/file-foo-dhpath",
|
||||
AAD: "foobar",
|
||||
Config: map[string]interface{}{
|
||||
"path": "/tmp/file-foo",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Cache: &Cache{
|
||||
UseAutoAuthToken: true,
|
||||
Listeners: []*Listener{
|
||||
&Listener{
|
||||
Type: "unix",
|
||||
Config: map[string]interface{}{
|
||||
"address": "/path/to/socket",
|
||||
"tls_disable": true,
|
||||
},
|
||||
},
|
||||
&Listener{
|
||||
Type: "tcp",
|
||||
Config: map[string]interface{}{
|
||||
"address": "127.0.0.1:8300",
|
||||
"tls_disable": true,
|
||||
},
|
||||
},
|
||||
&Listener{
|
||||
Type: "tcp",
|
||||
Config: map[string]interface{}{
|
||||
"address": "127.0.0.1:8400",
|
||||
"tls_key_file": "/path/to/cakey.pem",
|
||||
"tls_cert_file": "/path/to/cacert.pem",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
PidFile: "./pidfile",
|
||||
}
|
||||
|
||||
if diff := deep.Equal(config, expected); diff != nil {
|
||||
t.Fatal(diff)
|
||||
}
|
||||
|
||||
config, err = LoadConfig("./test-fixtures/config-cache-embedded-type.hcl", logger)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if diff := deep.Equal(config, expected); diff != nil {
|
||||
t.Fatal(diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfigFile(t *testing.T) {
|
||||
logger := logging.NewVaultLogger(log.Debug)
|
||||
|
||||
|
|
|
@ -0,0 +1,44 @@
|
|||
pid_file = "./pidfile"
|
||||
|
||||
auto_auth {
|
||||
method {
|
||||
type = "aws"
|
||||
wrap_ttl = 300
|
||||
config = {
|
||||
role = "foobar"
|
||||
}
|
||||
}
|
||||
|
||||
sink {
|
||||
type = "file"
|
||||
config = {
|
||||
path = "/tmp/file-foo"
|
||||
}
|
||||
aad = "foobar"
|
||||
dh_type = "curve25519"
|
||||
dh_path = "/tmp/file-foo-dhpath"
|
||||
}
|
||||
}
|
||||
|
||||
cache {
|
||||
use_auto_auth_token = true
|
||||
|
||||
listener {
|
||||
type = "unix"
|
||||
address = "/path/to/socket"
|
||||
tls_disable = true
|
||||
}
|
||||
|
||||
listener {
|
||||
type = "tcp"
|
||||
address = "127.0.0.1:8300"
|
||||
tls_disable = true
|
||||
}
|
||||
|
||||
listener {
|
||||
type = "tcp"
|
||||
address = "127.0.0.1:8400"
|
||||
tls_key_file = "/path/to/cakey.pem"
|
||||
tls_cert_file = "/path/to/cacert.pem"
|
||||
}
|
||||
}
|
41
command/agent/config/test-fixtures/config-cache.hcl
Normal file
41
command/agent/config/test-fixtures/config-cache.hcl
Normal file
|
@ -0,0 +1,41 @@
|
|||
pid_file = "./pidfile"
|
||||
|
||||
auto_auth {
|
||||
method {
|
||||
type = "aws"
|
||||
wrap_ttl = 300
|
||||
config = {
|
||||
role = "foobar"
|
||||
}
|
||||
}
|
||||
|
||||
sink {
|
||||
type = "file"
|
||||
config = {
|
||||
path = "/tmp/file-foo"
|
||||
}
|
||||
aad = "foobar"
|
||||
dh_type = "curve25519"
|
||||
dh_path = "/tmp/file-foo-dhpath"
|
||||
}
|
||||
}
|
||||
|
||||
cache {
|
||||
use_auto_auth_token = true
|
||||
|
||||
listener "unix" {
|
||||
address = "/path/to/socket"
|
||||
tls_disable = true
|
||||
}
|
||||
|
||||
listener "tcp" {
|
||||
address = "127.0.0.1:8300"
|
||||
tls_disable = true
|
||||
}
|
||||
|
||||
listener "tcp" {
|
||||
address = "127.0.0.1:8400"
|
||||
tls_key_file = "/path/to/cakey.pem"
|
||||
tls_cert_file = "/path/to/cacert.pem"
|
||||
}
|
||||
}
|
|
@ -62,6 +62,7 @@ func testJWTEndToEnd(t *testing.T, ahWrapping bool) {
|
|||
}
|
||||
|
||||
_, err = client.Logical().Write("auth/jwt/role/test", map[string]interface{}{
|
||||
"role_type": "jwt",
|
||||
"bound_subject": "r3qXcK2bix9eFECzsU3Sbmh0K16fatW6@clients",
|
||||
"bound_audiences": "https://vault.plugin.auth.jwt.test",
|
||||
"user_claim": "https://vault/user",
|
||||
|
|
|
@ -30,6 +30,191 @@ func testAgentCommand(tb testing.TB, logger hclog.Logger) (*cli.MockUi, *AgentCo
|
|||
}
|
||||
}
|
||||
|
||||
/*
|
||||
func TestAgent_Cache_UnixListener(t *testing.T) {
|
||||
logger := logging.NewVaultLogger(hclog.Trace)
|
||||
coreConfig := &vault.CoreConfig{
|
||||
Logger: logger.Named("core"),
|
||||
CredentialBackends: map[string]logical.Factory{
|
||||
"jwt": vaultjwt.Factory,
|
||||
},
|
||||
}
|
||||
cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{
|
||||
HandlerFunc: vaulthttp.Handler,
|
||||
})
|
||||
cluster.Start()
|
||||
defer cluster.Cleanup()
|
||||
|
||||
vault.TestWaitActive(t, cluster.Cores[0].Core)
|
||||
client := cluster.Cores[0].Client
|
||||
|
||||
defer os.Setenv(api.EnvVaultAddress, os.Getenv(api.EnvVaultAddress))
|
||||
os.Setenv(api.EnvVaultAddress, client.Address())
|
||||
|
||||
defer os.Setenv(api.EnvVaultCACert, os.Getenv(api.EnvVaultCACert))
|
||||
os.Setenv(api.EnvVaultCACert, fmt.Sprintf("%s/ca_cert.pem", cluster.TempDir))
|
||||
|
||||
// Setup Vault
|
||||
err := client.Sys().EnableAuthWithOptions("jwt", &api.EnableAuthOptions{
|
||||
Type: "jwt",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = client.Logical().Write("auth/jwt/config", map[string]interface{}{
|
||||
"bound_issuer": "https://team-vault.auth0.com/",
|
||||
"jwt_validation_pubkeys": agent.TestECDSAPubKey,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = client.Logical().Write("auth/jwt/role/test", map[string]interface{}{
|
||||
"role_type": "jwt",
|
||||
"bound_subject": "r3qXcK2bix9eFECzsU3Sbmh0K16fatW6@clients",
|
||||
"bound_audiences": "https://vault.plugin.auth.jwt.test",
|
||||
"user_claim": "https://vault/user",
|
||||
"groups_claim": "https://vault/groups",
|
||||
"policies": "test",
|
||||
"period": "3s",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
inf, err := ioutil.TempFile("", "auth.jwt.test.")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
in := inf.Name()
|
||||
inf.Close()
|
||||
os.Remove(in)
|
||||
t.Logf("input: %s", in)
|
||||
|
||||
sink1f, err := ioutil.TempFile("", "sink1.jwt.test.")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
sink1 := sink1f.Name()
|
||||
sink1f.Close()
|
||||
os.Remove(sink1)
|
||||
t.Logf("sink1: %s", sink1)
|
||||
|
||||
sink2f, err := ioutil.TempFile("", "sink2.jwt.test.")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
sink2 := sink2f.Name()
|
||||
sink2f.Close()
|
||||
os.Remove(sink2)
|
||||
t.Logf("sink2: %s", sink2)
|
||||
|
||||
conff, err := ioutil.TempFile("", "conf.jwt.test.")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
conf := conff.Name()
|
||||
conff.Close()
|
||||
os.Remove(conf)
|
||||
t.Logf("config: %s", conf)
|
||||
|
||||
jwtToken, _ := agent.GetTestJWT(t)
|
||||
if err := ioutil.WriteFile(in, []byte(jwtToken), 0600); err != nil {
|
||||
t.Fatal(err)
|
||||
} else {
|
||||
logger.Trace("wrote test jwt", "path", in)
|
||||
}
|
||||
|
||||
socketff, err := ioutil.TempFile("", "cache.socket.")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
socketf := socketff.Name()
|
||||
socketff.Close()
|
||||
os.Remove(socketf)
|
||||
t.Logf("socketf: %s", socketf)
|
||||
|
||||
config := `
|
||||
auto_auth {
|
||||
method {
|
||||
type = "jwt"
|
||||
config = {
|
||||
role = "test"
|
||||
path = "%s"
|
||||
}
|
||||
}
|
||||
|
||||
sink {
|
||||
type = "file"
|
||||
config = {
|
||||
path = "%s"
|
||||
}
|
||||
}
|
||||
|
||||
sink "file" {
|
||||
config = {
|
||||
path = "%s"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
cache {
|
||||
use_auto_auth_token = true
|
||||
|
||||
listener "unix" {
|
||||
address = "%s"
|
||||
tls_disable = true
|
||||
}
|
||||
}
|
||||
`
|
||||
|
||||
config = fmt.Sprintf(config, in, sink1, sink2, socketf)
|
||||
if err := ioutil.WriteFile(conf, []byte(config), 0600); err != nil {
|
||||
t.Fatal(err)
|
||||
} else {
|
||||
logger.Trace("wrote test config", "path", conf)
|
||||
}
|
||||
|
||||
_, cmd := testAgentCommand(t, logger)
|
||||
cmd.client = client
|
||||
|
||||
// Kill the command 5 seconds after it starts
|
||||
go func() {
|
||||
select {
|
||||
case <-cmd.ShutdownCh:
|
||||
case <-time.After(5 * time.Second):
|
||||
cmd.ShutdownCh <- struct{}{}
|
||||
}
|
||||
}()
|
||||
|
||||
originalVaultAgentAddress := os.Getenv(api.EnvVaultAgentAddress)
|
||||
|
||||
// Create a client that talks to the agent
|
||||
os.Setenv(api.EnvVaultAgentAddress, socketf)
|
||||
testClient, err := api.NewClient(api.DefaultConfig())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
os.Setenv(api.EnvVaultAgentAddress, originalVaultAgentAddress)
|
||||
|
||||
// Start the agent
|
||||
go cmd.Run([]string{"-config", conf})
|
||||
|
||||
// Give some time for the auto-auth to complete
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
// Invoke lookup self through the agent
|
||||
secret, err := testClient.Auth().Token().LookupSelf()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if secret == nil || secret.Data == nil || secret.Data["id"].(string) == "" {
|
||||
t.Fatalf("failed to perform lookup self through agent")
|
||||
}
|
||||
}
|
||||
*/
|
||||
|
||||
func TestExitAfterAuth(t *testing.T) {
|
||||
logger := logging.NewVaultLogger(hclog.Trace)
|
||||
coreConfig := &vault.CoreConfig{
|
||||
|
@ -64,6 +249,7 @@ func TestExitAfterAuth(t *testing.T) {
|
|||
}
|
||||
|
||||
_, err = client.Logical().Write("auth/jwt/role/test", map[string]interface{}{
|
||||
"role_type": "jwt",
|
||||
"bound_subject": "r3qXcK2bix9eFECzsU3Sbmh0K16fatW6@clients",
|
||||
"bound_audiences": "https://vault.plugin.auth.jwt.test",
|
||||
"user_claim": "https://vault/user",
|
||||
|
|
|
@ -1,9 +1,6 @@
|
|||
package command
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"strings"
|
||||
|
||||
"github.com/mitchellh/cli"
|
||||
|
@ -13,10 +10,6 @@ var _ cli.Command = (*AuthCommand)(nil)
|
|||
|
||||
type AuthCommand struct {
|
||||
*BaseCommand
|
||||
|
||||
Handlers map[string]LoginHandler
|
||||
|
||||
testStdin io.Reader // for tests
|
||||
}
|
||||
|
||||
func (c *AuthCommand) Synopsis() string {
|
||||
|
@ -52,77 +45,5 @@ Usage: vault auth <subcommand> [options] [args]
|
|||
}
|
||||
|
||||
func (c *AuthCommand) Run(args []string) int {
|
||||
// If we entered the run method, none of the subcommands picked up. This
|
||||
// means the user is still trying to use auth as "vault auth TOKEN" or
|
||||
// similar, so direct them to vault login instead.
|
||||
//
|
||||
// This run command is a bit messy to maintain BC for a bit. In the future,
|
||||
// it will just be a tiny function, but for now we have to maintain bc.
|
||||
//
|
||||
// Deprecation
|
||||
// TODO: remove in 0.9.0
|
||||
|
||||
if len(args) == 0 {
|
||||
return cli.RunResultHelp
|
||||
}
|
||||
|
||||
// Parse the args for our deprecations and defer to the proper areas.
|
||||
for _, arg := range args {
|
||||
switch {
|
||||
case strings.HasPrefix(arg, "-methods"):
|
||||
if Format(c.UI) == "table" {
|
||||
c.UI.Warn(wrapAtLength(
|
||||
"WARNING! The -methods flag is deprecated. Please use "+
|
||||
"\"vault auth list\" instead. This flag will be removed in "+
|
||||
"Vault 1.1.") + "\n")
|
||||
}
|
||||
return (&AuthListCommand{
|
||||
BaseCommand: &BaseCommand{
|
||||
UI: c.UI,
|
||||
client: c.client,
|
||||
},
|
||||
}).Run(nil)
|
||||
case strings.HasPrefix(arg, "-method-help"):
|
||||
if Format(c.UI) == "table" {
|
||||
c.UI.Warn(wrapAtLength(
|
||||
"WARNING! The -method-help flag is deprecated. Please use "+
|
||||
"\"vault auth help\" instead. This flag will be removed in "+
|
||||
"Vault 1.1.") + "\n")
|
||||
}
|
||||
// Parse the args to pull out the method, suppressing any errors because
|
||||
// there could be other flags that we don't care about.
|
||||
f := flag.NewFlagSet("", flag.ContinueOnError)
|
||||
f.Usage = func() {}
|
||||
f.SetOutput(ioutil.Discard)
|
||||
flagMethod := f.String("method", "", "")
|
||||
f.Parse(args)
|
||||
|
||||
return (&AuthHelpCommand{
|
||||
BaseCommand: &BaseCommand{
|
||||
UI: c.UI,
|
||||
client: c.client,
|
||||
},
|
||||
Handlers: c.Handlers,
|
||||
}).Run([]string{*flagMethod})
|
||||
}
|
||||
}
|
||||
|
||||
// If we got this far, we have an arg or a series of args that should be
|
||||
// passed directly to the new "vault login" command.
|
||||
if Format(c.UI) == "table" {
|
||||
c.UI.Warn(wrapAtLength(
|
||||
"WARNING! The \"vault auth ARG\" command is deprecated and is now a "+
|
||||
"subcommand for interacting with auth methods. To authenticate "+
|
||||
"locally to Vault, use \"vault login\" instead. This backwards "+
|
||||
"compatibility will be removed in Vault 1.1.") + "\n")
|
||||
}
|
||||
return (&LoginCommand{
|
||||
BaseCommand: &BaseCommand{
|
||||
UI: c.UI,
|
||||
client: c.client,
|
||||
tokenHelper: c.tokenHelper,
|
||||
flagAddress: c.flagAddress,
|
||||
},
|
||||
Handlers: c.Handlers,
|
||||
}).Run(args)
|
||||
return cli.RunResultHelp
|
||||
}
|
||||
|
|
|
@ -175,7 +175,8 @@ func TestAuthEnableCommand_Run(t *testing.T) {
|
|||
|
||||
// Add 1 to account for the "token" backend, which is visible when you walk the filesystem but
|
||||
// is treated as special and excluded from the registry.
|
||||
expected := len(builtinplugins.Registry.Keys(consts.PluginTypeCredential)) + 1
|
||||
// Subtract 1 to account for "oidc" which is an alias of "jwt" and not a separate plugin.
|
||||
expected := len(builtinplugins.Registry.Keys(consts.PluginTypeCredential))
|
||||
if len(backends) != expected {
|
||||
t.Fatalf("expected %d credential backends, got %d", expected, len(backends))
|
||||
}
|
||||
|
|
|
@ -1,13 +1,10 @@
|
|||
package command
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/mitchellh/cli"
|
||||
|
||||
credToken "github.com/hashicorp/vault/builtin/credential/token"
|
||||
credUserpass "github.com/hashicorp/vault/builtin/credential/userpass"
|
||||
"github.com/hashicorp/vault/command/token"
|
||||
)
|
||||
|
||||
|
@ -22,110 +19,12 @@ func testAuthCommand(tb testing.TB) (*cli.MockUi, *AuthCommand) {
|
|||
// Override to our own token helper
|
||||
tokenHelper: token.NewTestingTokenHelper(),
|
||||
},
|
||||
Handlers: map[string]LoginHandler{
|
||||
"token": &credToken.CLIHandler{},
|
||||
"userpass": &credUserpass.CLIHandler{},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthCommand_Run(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// TODO: remove in 0.9.0
|
||||
t.Run("deprecated_methods", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client, closer := testVaultServer(t)
|
||||
defer closer()
|
||||
|
||||
ui, cmd := testAuthCommand(t)
|
||||
cmd.client = client
|
||||
|
||||
// vault auth -methods -> vault auth list
|
||||
code := cmd.Run([]string{"-methods"})
|
||||
if exp := 0; code != exp {
|
||||
t.Errorf("expected %d to be %d: %s", code, exp, ui.ErrorWriter.String())
|
||||
}
|
||||
stdout, stderr := ui.OutputWriter.String(), ui.ErrorWriter.String()
|
||||
|
||||
if expected := "WARNING!"; !strings.Contains(stderr, expected) {
|
||||
t.Errorf("expected %q to contain %q", stderr, expected)
|
||||
}
|
||||
|
||||
if expected := "token/"; !strings.Contains(stdout, expected) {
|
||||
t.Errorf("expected %q to contain %q", stdout, expected)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("deprecated_method_help", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client, closer := testVaultServer(t)
|
||||
defer closer()
|
||||
|
||||
ui, cmd := testAuthCommand(t)
|
||||
cmd.client = client
|
||||
|
||||
// vault auth -method=foo -method-help -> vault auth help foo
|
||||
code := cmd.Run([]string{
|
||||
"-method=userpass",
|
||||
"-method-help",
|
||||
})
|
||||
if exp := 0; code != exp {
|
||||
t.Errorf("expected %d to be %d: %s", code, exp, ui.ErrorWriter.String())
|
||||
}
|
||||
stdout, stderr := ui.OutputWriter.String(), ui.ErrorWriter.String()
|
||||
|
||||
if expected := "WARNING!"; !strings.Contains(stderr, expected) {
|
||||
t.Errorf("expected %q to contain %q", stderr, expected)
|
||||
}
|
||||
|
||||
if expected := "vault login"; !strings.Contains(stdout, expected) {
|
||||
t.Errorf("expected %q to contain %q", stdout, expected)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("deprecated_login", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client, closer := testVaultServer(t)
|
||||
defer closer()
|
||||
|
||||
if err := client.Sys().EnableAuth("my-auth", "userpass", ""); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := client.Logical().Write("auth/my-auth/users/test", map[string]interface{}{
|
||||
"password": "test",
|
||||
"policies": "default",
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ui, cmd := testAuthCommand(t)
|
||||
cmd.client = client
|
||||
|
||||
// vault auth ARGS -> vault login ARGS
|
||||
code := cmd.Run([]string{
|
||||
"-method", "userpass",
|
||||
"-path", "my-auth",
|
||||
"username=test",
|
||||
"password=test",
|
||||
})
|
||||
if exp := 0; code != exp {
|
||||
t.Errorf("expected %d to be %d: %s", code, exp, ui.ErrorWriter.String())
|
||||
}
|
||||
stdout, stderr := ui.OutputWriter.String(), ui.ErrorWriter.String()
|
||||
|
||||
if expected := "WARNING!"; !strings.Contains(stderr, expected) {
|
||||
t.Errorf("expected %q to contain %q", stderr, expected)
|
||||
}
|
||||
|
||||
if expected := "Success! You are now authenticated."; !strings.Contains(stdout, expected) {
|
||||
t.Errorf("expected %q to contain %q", stdout, expected)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("no_tabs", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
|
|
@ -39,6 +39,7 @@ type BaseCommand struct {
|
|||
flagsOnce sync.Once
|
||||
|
||||
flagAddress string
|
||||
flagAgentAddress string
|
||||
flagCACert string
|
||||
flagCAPath string
|
||||
flagClientCert string
|
||||
|
@ -78,6 +79,9 @@ func (c *BaseCommand) Client() (*api.Client, error) {
|
|||
if c.flagAddress != "" {
|
||||
config.Address = c.flagAddress
|
||||
}
|
||||
if c.flagAgentAddress != "" {
|
||||
config.Address = c.flagAgentAddress
|
||||
}
|
||||
|
||||
if c.flagOutputCurlString {
|
||||
config.OutputCurlString = c.flagOutputCurlString
|
||||
|
@ -220,6 +224,15 @@ func (c *BaseCommand) flagSet(bit FlagSetBit) *FlagSets {
|
|||
}
|
||||
f.StringVar(addrStringVar)
|
||||
|
||||
agentAddrStringVar := &StringVar{
|
||||
Name: "agent-address",
|
||||
Target: &c.flagAgentAddress,
|
||||
EnvVar: "VAULT_AGENT_ADDR",
|
||||
Completion: complete.PredictAnything,
|
||||
Usage: "Address of the Agent.",
|
||||
}
|
||||
f.StringVar(agentAddrStringVar)
|
||||
|
||||
f.StringVar(&StringVar{
|
||||
Name: "ca-cert",
|
||||
Target: &c.flagCACert,
|
||||
|
|
|
@ -352,6 +352,7 @@ func TestPredict_Plugins(t *testing.T) {
|
|||
"mysql-legacy-database-plugin",
|
||||
"mysql-rds-database-plugin",
|
||||
"nomad",
|
||||
"oidc",
|
||||
"okta",
|
||||
"pki",
|
||||
"postgresql",
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package command
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
@ -27,6 +26,7 @@ import (
|
|||
credAliCloud "github.com/hashicorp/vault-plugin-auth-alicloud"
|
||||
credCentrify "github.com/hashicorp/vault-plugin-auth-centrify"
|
||||
credGcp "github.com/hashicorp/vault-plugin-auth-gcp/plugin"
|
||||
credOIDC "github.com/hashicorp/vault-plugin-auth-jwt"
|
||||
credAws "github.com/hashicorp/vault/builtin/credential/aws"
|
||||
credCert "github.com/hashicorp/vault/builtin/credential/cert"
|
||||
credGitHub "github.com/hashicorp/vault/builtin/credential/github"
|
||||
|
@ -130,43 +130,8 @@ var (
|
|||
}
|
||||
)
|
||||
|
||||
// DeprecatedCommand is a command that wraps an existing command and prints a
|
||||
// deprecation notice and points the user to the new command. Deprecated
|
||||
// commands are always hidden from help output.
|
||||
type DeprecatedCommand struct {
|
||||
cli.Command
|
||||
UI cli.Ui
|
||||
|
||||
// Old is the old command name, New is the new command name.
|
||||
Old, New string
|
||||
}
|
||||
|
||||
// Help wraps the embedded Help command and prints a warning about deprecations.
|
||||
func (c *DeprecatedCommand) Help() string {
|
||||
c.warn()
|
||||
return c.Command.Help()
|
||||
}
|
||||
|
||||
// Run wraps the embedded Run command and prints a warning about deprecation.
|
||||
func (c *DeprecatedCommand) Run(args []string) int {
|
||||
if Format(c.UI) == "table" {
|
||||
c.warn()
|
||||
}
|
||||
return c.Command.Run(args)
|
||||
}
|
||||
|
||||
func (c *DeprecatedCommand) warn() {
|
||||
c.UI.Warn(wrapAtLength(fmt.Sprintf(
|
||||
"WARNING! The \"vault %s\" command is deprecated. Please use \"vault %s\" "+
|
||||
"instead. This command will be removed in Vault 1.1.",
|
||||
c.Old,
|
||||
c.New)))
|
||||
c.UI.Warn("")
|
||||
}
|
||||
|
||||
// Commands is the mapping of all the available commands.
|
||||
var Commands map[string]cli.CommandFactory
|
||||
var DeprecatedCommands map[string]cli.CommandFactory
|
||||
|
||||
func initCommands(ui, serverCmdUi cli.Ui, runOpts *RunOptions) {
|
||||
loginHandlers := map[string]LoginHandler{
|
||||
|
@ -177,6 +142,7 @@ func initCommands(ui, serverCmdUi cli.Ui, runOpts *RunOptions) {
|
|||
"gcp": &credGcp.CLIHandler{},
|
||||
"github": &credGitHub.CLIHandler{},
|
||||
"ldap": &credLdap.CLIHandler{},
|
||||
"oidc": &credOIDC.CLIHandler{},
|
||||
"okta": &credOkta.CLIHandler{},
|
||||
"radius": &credUserpass.CLIHandler{
|
||||
DefaultMount: "radius",
|
||||
|
@ -233,7 +199,6 @@ func initCommands(ui, serverCmdUi cli.Ui, runOpts *RunOptions) {
|
|||
"auth": func() (cli.Command, error) {
|
||||
return &AuthCommand{
|
||||
BaseCommand: getBaseCommand(),
|
||||
Handlers: loginHandlers,
|
||||
}, nil
|
||||
},
|
||||
"auth disable": func() (cli.Command, error) {
|
||||
|
@ -612,328 +577,6 @@ func initCommands(ui, serverCmdUi cli.Ui, runOpts *RunOptions) {
|
|||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
// Deprecated commands
|
||||
//
|
||||
// TODO: Remove not before 0.11.0
|
||||
DeprecatedCommands = map[string]cli.CommandFactory{
|
||||
"audit-disable": func() (cli.Command, error) {
|
||||
return &DeprecatedCommand{
|
||||
Old: "audit-disable",
|
||||
New: "audit disable",
|
||||
UI: ui,
|
||||
Command: &AuditDisableCommand{
|
||||
BaseCommand: getBaseCommand(),
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
|
||||
"audit-enable": func() (cli.Command, error) {
|
||||
return &DeprecatedCommand{
|
||||
Old: "audit-enable",
|
||||
New: "audit enable",
|
||||
UI: ui,
|
||||
Command: &AuditEnableCommand{
|
||||
BaseCommand: getBaseCommand(),
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
|
||||
"audit-list": func() (cli.Command, error) {
|
||||
return &DeprecatedCommand{
|
||||
Old: "audit-list",
|
||||
New: "audit list",
|
||||
UI: ui,
|
||||
Command: &AuditListCommand{
|
||||
BaseCommand: getBaseCommand(),
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
|
||||
"auth-disable": func() (cli.Command, error) {
|
||||
return &DeprecatedCommand{
|
||||
Old: "auth-disable",
|
||||
New: "auth disable",
|
||||
UI: ui,
|
||||
Command: &AuthDisableCommand{
|
||||
BaseCommand: getBaseCommand(),
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
|
||||
"auth-enable": func() (cli.Command, error) {
|
||||
return &DeprecatedCommand{
|
||||
Old: "auth-enable",
|
||||
New: "auth enable",
|
||||
UI: ui,
|
||||
Command: &AuthEnableCommand{
|
||||
BaseCommand: getBaseCommand(),
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
|
||||
"capabilities": func() (cli.Command, error) {
|
||||
return &DeprecatedCommand{
|
||||
Old: "capabilities",
|
||||
New: "token capabilities",
|
||||
UI: ui,
|
||||
Command: &TokenCapabilitiesCommand{
|
||||
BaseCommand: getBaseCommand(),
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
|
||||
"generate-root": func() (cli.Command, error) {
|
||||
return &DeprecatedCommand{
|
||||
Old: "generate-root",
|
||||
New: "operator generate-root",
|
||||
UI: ui,
|
||||
Command: &OperatorGenerateRootCommand{
|
||||
BaseCommand: getBaseCommand(),
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
|
||||
"init": func() (cli.Command, error) {
|
||||
return &DeprecatedCommand{
|
||||
Old: "init",
|
||||
New: "operator init",
|
||||
UI: ui,
|
||||
Command: &OperatorInitCommand{
|
||||
BaseCommand: getBaseCommand(),
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
|
||||
"key-status": func() (cli.Command, error) {
|
||||
return &DeprecatedCommand{
|
||||
Old: "key-status",
|
||||
New: "operator key-status",
|
||||
UI: ui,
|
||||
Command: &OperatorKeyStatusCommand{
|
||||
BaseCommand: getBaseCommand(),
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
|
||||
"renew": func() (cli.Command, error) {
|
||||
return &DeprecatedCommand{
|
||||
Old: "renew",
|
||||
New: "lease renew",
|
||||
UI: ui,
|
||||
Command: &LeaseRenewCommand{
|
||||
BaseCommand: getBaseCommand(),
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
|
||||
"revoke": func() (cli.Command, error) {
|
||||
return &DeprecatedCommand{
|
||||
Old: "revoke",
|
||||
New: "lease revoke",
|
||||
UI: ui,
|
||||
Command: &LeaseRevokeCommand{
|
||||
BaseCommand: getBaseCommand(),
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
|
||||
"mount": func() (cli.Command, error) {
|
||||
return &DeprecatedCommand{
|
||||
Old: "mount",
|
||||
New: "secrets enable",
|
||||
UI: ui,
|
||||
Command: &SecretsEnableCommand{
|
||||
BaseCommand: getBaseCommand(),
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
|
||||
"mount-tune": func() (cli.Command, error) {
|
||||
return &DeprecatedCommand{
|
||||
Old: "mount-tune",
|
||||
New: "secrets tune",
|
||||
UI: ui,
|
||||
Command: &SecretsTuneCommand{
|
||||
BaseCommand: getBaseCommand(),
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
|
||||
"mounts": func() (cli.Command, error) {
|
||||
return &DeprecatedCommand{
|
||||
Old: "mounts",
|
||||
New: "secrets list",
|
||||
UI: ui,
|
||||
Command: &SecretsListCommand{
|
||||
BaseCommand: getBaseCommand(),
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
|
||||
"policies": func() (cli.Command, error) {
|
||||
return &DeprecatedCommand{
|
||||
Old: "policies",
|
||||
New: "policy read\" or \"vault policy list", // lol
|
||||
UI: ui,
|
||||
Command: &PoliciesDeprecatedCommand{
|
||||
BaseCommand: getBaseCommand(),
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
|
||||
"policy-delete": func() (cli.Command, error) {
|
||||
return &DeprecatedCommand{
|
||||
Old: "policy-delete",
|
||||
New: "policy delete",
|
||||
UI: ui,
|
||||
Command: &PolicyDeleteCommand{
|
||||
BaseCommand: getBaseCommand(),
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
|
||||
"policy-write": func() (cli.Command, error) {
|
||||
return &DeprecatedCommand{
|
||||
Old: "policy-write",
|
||||
New: "policy write",
|
||||
UI: ui,
|
||||
Command: &PolicyWriteCommand{
|
||||
BaseCommand: getBaseCommand(),
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
|
||||
"rekey": func() (cli.Command, error) {
|
||||
return &DeprecatedCommand{
|
||||
Old: "rekey",
|
||||
New: "operator rekey",
|
||||
UI: ui,
|
||||
Command: &OperatorRekeyCommand{
|
||||
BaseCommand: getBaseCommand(),
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
|
||||
"remount": func() (cli.Command, error) {
|
||||
return &DeprecatedCommand{
|
||||
Old: "remount",
|
||||
New: "secrets move",
|
||||
UI: ui,
|
||||
Command: &SecretsMoveCommand{
|
||||
BaseCommand: getBaseCommand(),
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
|
||||
"rotate": func() (cli.Command, error) {
|
||||
return &DeprecatedCommand{
|
||||
Old: "rotate",
|
||||
New: "operator rotate",
|
||||
UI: ui,
|
||||
Command: &OperatorRotateCommand{
|
||||
BaseCommand: getBaseCommand(),
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
|
||||
"seal": func() (cli.Command, error) {
|
||||
return &DeprecatedCommand{
|
||||
Old: "seal",
|
||||
New: "operator seal",
|
||||
UI: ui,
|
||||
Command: &OperatorSealCommand{
|
||||
BaseCommand: getBaseCommand(),
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
|
||||
"step-down": func() (cli.Command, error) {
|
||||
return &DeprecatedCommand{
|
||||
Old: "step-down",
|
||||
New: "operator step-down",
|
||||
UI: ui,
|
||||
Command: &OperatorStepDownCommand{
|
||||
BaseCommand: getBaseCommand(),
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
|
||||
"token-create": func() (cli.Command, error) {
|
||||
return &DeprecatedCommand{
|
||||
Old: "token-create",
|
||||
New: "token create",
|
||||
UI: ui,
|
||||
Command: &TokenCreateCommand{
|
||||
BaseCommand: getBaseCommand(),
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
|
||||
"token-lookup": func() (cli.Command, error) {
|
||||
return &DeprecatedCommand{
|
||||
Old: "token-lookup",
|
||||
New: "token lookup",
|
||||
UI: ui,
|
||||
Command: &TokenLookupCommand{
|
||||
BaseCommand: getBaseCommand(),
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
|
||||
"token-renew": func() (cli.Command, error) {
|
||||
return &DeprecatedCommand{
|
||||
Old: "token-renew",
|
||||
New: "token renew",
|
||||
UI: ui,
|
||||
Command: &TokenRenewCommand{
|
||||
BaseCommand: getBaseCommand(),
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
|
||||
"token-revoke": func() (cli.Command, error) {
|
||||
return &DeprecatedCommand{
|
||||
Old: "token-revoke",
|
||||
New: "token revoke",
|
||||
UI: ui,
|
||||
Command: &TokenRevokeCommand{
|
||||
BaseCommand: getBaseCommand(),
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
|
||||
"unmount": func() (cli.Command, error) {
|
||||
return &DeprecatedCommand{
|
||||
Old: "unmount",
|
||||
New: "secrets disable",
|
||||
UI: ui,
|
||||
Command: &SecretsDisableCommand{
|
||||
BaseCommand: getBaseCommand(),
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
|
||||
"unseal": func() (cli.Command, error) {
|
||||
return &DeprecatedCommand{
|
||||
Old: "unseal",
|
||||
New: "operator unseal",
|
||||
UI: ui,
|
||||
Command: &OperatorUnsealCommand{
|
||||
BaseCommand: getBaseCommand(),
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
// Add deprecated commands back to the main commands so they parse.
|
||||
for k, v := range DeprecatedCommands {
|
||||
if _, ok := Commands[k]; ok {
|
||||
// Can't deprecate an existing command...
|
||||
panic(fmt.Sprintf("command %q defined as deprecated and not at the same time!", k))
|
||||
}
|
||||
Commands[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
// MakeShutdownCh returns a channel that can be used for shutdown
|
||||
|
|
|
@ -28,10 +28,6 @@ type LoginCommand struct {
|
|||
flagNoPrint bool
|
||||
flagTokenOnly bool
|
||||
|
||||
// Deprecations
|
||||
// TODO: remove in 0.9.0
|
||||
flagNoVerify bool
|
||||
|
||||
testStdin io.Reader // for tests
|
||||
}
|
||||
|
||||
|
@ -132,16 +128,6 @@ func (c *LoginCommand) Flags() *FlagSets {
|
|||
"values will have no affect.",
|
||||
})
|
||||
|
||||
// Deprecations
|
||||
// TODO: remove in 0.9.0
|
||||
f.BoolVar(&BoolVar{
|
||||
Name: "no-verify",
|
||||
Target: &c.flagNoVerify,
|
||||
Hidden: true,
|
||||
Default: false,
|
||||
Usage: "",
|
||||
})
|
||||
|
||||
return set
|
||||
}
|
||||
|
||||
|
@ -163,39 +149,6 @@ func (c *LoginCommand) Run(args []string) int {
|
|||
|
||||
args = f.Args()
|
||||
|
||||
// Deprecations
|
||||
// TODO: remove in 0.10.0
|
||||
switch {
|
||||
case c.flagNoVerify:
|
||||
if Format(c.UI) == "table" {
|
||||
c.UI.Warn(wrapAtLength(
|
||||
"WARNING! The -no-verify flag is deprecated. In the past, Vault " +
|
||||
"performed a lookup on a token after authentication. This is no " +
|
||||
"longer the case for all auth methods except \"token\". Vault will " +
|
||||
"still attempt to perform a lookup when given a token directly " +
|
||||
"because that is how it gets the list of policies, ttl, and other " +
|
||||
"metadata. To disable this lookup, specify \"lookup=false\" as a " +
|
||||
"configuration option to the token auth method, like this:"))
|
||||
c.UI.Warn("")
|
||||
c.UI.Warn(" $ vault auth token=ABCD lookup=false")
|
||||
c.UI.Warn("")
|
||||
c.UI.Warn("Or omit the token and Vault will prompt for it:")
|
||||
c.UI.Warn("")
|
||||
c.UI.Warn(" $ vault auth lookup=false")
|
||||
c.UI.Warn(" Token (will be hidden): ...")
|
||||
c.UI.Warn("")
|
||||
c.UI.Warn(wrapAtLength(
|
||||
"If you are not using token authentication, you can safely omit this " +
|
||||
"flag. Vault will not perform a lookup after authentication."))
|
||||
c.UI.Warn("")
|
||||
}
|
||||
|
||||
// There's no point in passing this to other auth handlers...
|
||||
if c.flagMethod == "token" {
|
||||
args = append(args, "lookup=false")
|
||||
}
|
||||
}
|
||||
|
||||
// Set the right flags if the user requested token-only - this overrides
|
||||
// any previously configured values, as documented.
|
||||
if c.flagTokenOnly {
|
||||
|
|
|
@ -443,51 +443,6 @@ func TestLoginCommand_Run(t *testing.T) {
|
|||
}
|
||||
})
|
||||
|
||||
// Deprecations
|
||||
// TODO: remove in 0.9.0
|
||||
t.Run("deprecated_no_verify", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client, closer := testVaultServer(t)
|
||||
defer closer()
|
||||
|
||||
secret, err := client.Auth().Token().Create(&api.TokenCreateRequest{
|
||||
Policies: []string{"default"},
|
||||
TTL: "30m",
|
||||
NumUses: 1,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
token := secret.Auth.ClientToken
|
||||
|
||||
_, cmd := testLoginCommand(t)
|
||||
cmd.client = client
|
||||
|
||||
code := cmd.Run([]string{
|
||||
"-no-verify",
|
||||
token,
|
||||
})
|
||||
if exp := 0; code != exp {
|
||||
t.Errorf("expected %d to be %d", code, exp)
|
||||
}
|
||||
|
||||
lookup, err := client.Auth().Token().Lookup(token)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// There was 1 use to start, make sure we didn't use it (verifying would
|
||||
// use it).
|
||||
uses, err := lookup.TokenRemainingUses()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if uses != 1 {
|
||||
t.Errorf("expected %d to be %d", uses, 1)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("no_tabs", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
|
|
@ -159,12 +159,7 @@ func RunCustom(args []string, runOpts *RunOptions) int {
|
|||
|
||||
initCommands(ui, serverCmdUi, runOpts)
|
||||
|
||||
// Calculate hidden commands from the deprecated ones
|
||||
hiddenCommands := make([]string, 0, len(DeprecatedCommands)+1)
|
||||
for k := range DeprecatedCommands {
|
||||
hiddenCommands = append(hiddenCommands, k)
|
||||
}
|
||||
hiddenCommands = append(hiddenCommands, "version")
|
||||
hiddenCommands := []string{"version"}
|
||||
|
||||
cli := &cli.CLI{
|
||||
Name: "vault",
|
||||
|
|
|
@ -36,10 +36,6 @@ type OperatorGenerateRootCommand struct {
|
|||
flagGenerateOTP bool
|
||||
flagDRToken bool
|
||||
|
||||
// Deprecation
|
||||
// TODO: remove in 0.9.0
|
||||
flagGenOTP bool
|
||||
|
||||
testStdin io.Reader // for tests
|
||||
}
|
||||
|
||||
|
@ -179,15 +175,6 @@ func (c *OperatorGenerateRootCommand) Flags() *FlagSets {
|
|||
"must be provided with each unseal key.",
|
||||
})
|
||||
|
||||
// Deprecations: prefer longer-form, descriptive flags
|
||||
// TODO: remove in 0.9.0
|
||||
f.BoolVar(&BoolVar{
|
||||
Name: "genotp", // -generate-otp
|
||||
Target: &c.flagGenOTP,
|
||||
Default: false,
|
||||
Hidden: true,
|
||||
})
|
||||
|
||||
return set
|
||||
}
|
||||
|
||||
|
@ -213,18 +200,6 @@ func (c *OperatorGenerateRootCommand) Run(args []string) int {
|
|||
return 1
|
||||
}
|
||||
|
||||
// Deprecations
|
||||
// TODO: remove in 0.9.0
|
||||
switch {
|
||||
case c.flagGenOTP:
|
||||
if Format(c.UI) == "table" {
|
||||
c.UI.Warn(wrapAtLength(
|
||||
"WARNING! The -gen-otp flag is deprecated. Please use the -generate-otp flag " +
|
||||
"instead."))
|
||||
}
|
||||
c.flagGenerateOTP = c.flagGenOTP
|
||||
}
|
||||
|
||||
client, err := c.Client()
|
||||
if err != nil {
|
||||
c.UI.Error(err.Error())
|
||||
|
|
|
@ -27,7 +27,6 @@ type OperatorInitCommand struct {
|
|||
flagRootTokenPGPKey string
|
||||
|
||||
// HSM
|
||||
flagStoredShares int
|
||||
flagRecoveryShares int
|
||||
flagRecoveryThreshold int
|
||||
flagRecoveryPGPKeys []string
|
||||
|
@ -35,11 +34,6 @@ type OperatorInitCommand struct {
|
|||
// Consul
|
||||
flagConsulAuto bool
|
||||
flagConsulService string
|
||||
|
||||
// Deprecations
|
||||
// TODO: remove in 0.9.0
|
||||
flagAuto bool
|
||||
flagCheck bool
|
||||
}
|
||||
|
||||
func (c *OperatorInitCommand) Synopsis() string {
|
||||
|
@ -196,32 +190,6 @@ func (c *OperatorInitCommand) Flags() *FlagSets {
|
|||
"is only used in HSM mode.",
|
||||
})
|
||||
|
||||
// Deprecations
|
||||
// TODO: remove in 0.9.0
|
||||
f.BoolVar(&BoolVar{
|
||||
Name: "check", // prefer -status
|
||||
Target: &c.flagCheck,
|
||||
Default: false,
|
||||
Hidden: true,
|
||||
Usage: "",
|
||||
})
|
||||
f.BoolVar(&BoolVar{
|
||||
Name: "auto", // prefer -consul-auto
|
||||
Target: &c.flagAuto,
|
||||
Default: false,
|
||||
Hidden: true,
|
||||
Usage: "",
|
||||
})
|
||||
|
||||
// Kept to keep scripts passing the flag working, but not used
|
||||
f.IntVar(&IntVar{
|
||||
Name: "stored-shares",
|
||||
Target: &c.flagStoredShares,
|
||||
Default: 0,
|
||||
Hidden: true,
|
||||
Usage: "",
|
||||
})
|
||||
|
||||
return set
|
||||
}
|
||||
|
||||
|
@ -241,23 +209,6 @@ func (c *OperatorInitCommand) Run(args []string) int {
|
|||
return 1
|
||||
}
|
||||
|
||||
// Deprecations
|
||||
// TODO: remove in 0.9.0
|
||||
if c.flagAuto {
|
||||
if Format(c.UI) == "table" {
|
||||
c.UI.Warn(wrapAtLength("WARNING! -auto is deprecated. Please use " +
|
||||
"-consul-auto instead. This will be removed in Vault 1.1."))
|
||||
}
|
||||
c.flagConsulAuto = true
|
||||
}
|
||||
if c.flagCheck {
|
||||
if Format(c.UI) == "table" {
|
||||
c.UI.Warn(wrapAtLength("WARNING! -check is deprecated. Please use " +
|
||||
"-status instead. This will be removed in Vault 1.1."))
|
||||
}
|
||||
c.flagStatus = true
|
||||
}
|
||||
|
||||
args = f.Args()
|
||||
if len(args) > 0 {
|
||||
c.UI.Error(fmt.Sprintf("Too many arguments (expected 0, got %d)", len(args)))
|
||||
|
@ -271,7 +222,6 @@ func (c *OperatorInitCommand) Run(args []string) int {
|
|||
PGPKeys: c.flagPGPKeys,
|
||||
RootTokenPGPKey: c.flagRootTokenPGPKey,
|
||||
|
||||
StoredShares: c.flagStoredShares,
|
||||
RecoveryShares: c.flagRecoveryShares,
|
||||
RecoveryThreshold: c.flagRecoveryThreshold,
|
||||
RecoveryPGPKeys: c.flagRecoveryPGPKeys,
|
||||
|
|
|
@ -36,13 +36,6 @@ type OperatorRekeyCommand struct {
|
|||
flagBackupDelete bool
|
||||
flagBackupRetrieve bool
|
||||
|
||||
// Deprecations
|
||||
// TODO: remove in 0.9.0
|
||||
flagDelete bool
|
||||
flagRecoveryKey bool
|
||||
flagRetrieve bool
|
||||
flagStoredShares int
|
||||
|
||||
testStdin io.Reader // for tests
|
||||
}
|
||||
|
||||
|
@ -216,41 +209,6 @@ func (c *OperatorRekeyCommand) Flags() *FlagSets {
|
|||
"if the PGP keys were provided and the backup has not been deleted.",
|
||||
})
|
||||
|
||||
// Deprecations
|
||||
// TODO: remove in 0.9.0
|
||||
f.BoolVar(&BoolVar{
|
||||
Name: "delete", // prefer -backup-delete
|
||||
Target: &c.flagDelete,
|
||||
Default: false,
|
||||
Hidden: true,
|
||||
Usage: "",
|
||||
})
|
||||
|
||||
f.BoolVar(&BoolVar{
|
||||
Name: "retrieve", // prefer -backup-retrieve
|
||||
Target: &c.flagRetrieve,
|
||||
Default: false,
|
||||
Hidden: true,
|
||||
Usage: "",
|
||||
})
|
||||
|
||||
f.BoolVar(&BoolVar{
|
||||
Name: "recovery-key", // prefer -target=recovery
|
||||
Target: &c.flagRecoveryKey,
|
||||
Default: false,
|
||||
Hidden: true,
|
||||
Usage: "",
|
||||
})
|
||||
|
||||
// Kept to keep scripts passing the flag working, but not used
|
||||
f.IntVar(&IntVar{
|
||||
Name: "stored-shares",
|
||||
Target: &c.flagStoredShares,
|
||||
Default: 0,
|
||||
Hidden: true,
|
||||
Usage: "",
|
||||
})
|
||||
|
||||
return set
|
||||
}
|
||||
|
||||
|
@ -276,33 +234,6 @@ func (c *OperatorRekeyCommand) Run(args []string) int {
|
|||
return 1
|
||||
}
|
||||
|
||||
// Deprecations
|
||||
// TODO: remove in 0.9.0
|
||||
if c.flagDelete {
|
||||
if Format(c.UI) == "table" {
|
||||
c.UI.Warn(wrapAtLength(
|
||||
"WARNING! The -delete flag is deprecated. Please use -backup-delete " +
|
||||
"instead. This flag will be removed in Vault 1.1."))
|
||||
}
|
||||
c.flagBackupDelete = true
|
||||
}
|
||||
if c.flagRetrieve {
|
||||
if Format(c.UI) == "table" {
|
||||
c.UI.Warn(wrapAtLength(
|
||||
"WARNING! The -retrieve flag is deprecated. Please use -backup-retrieve " +
|
||||
"instead. This flag will be removed in Vault 1.1."))
|
||||
}
|
||||
c.flagBackupRetrieve = true
|
||||
}
|
||||
if c.flagRecoveryKey {
|
||||
if Format(c.UI) == "table" {
|
||||
c.UI.Warn(wrapAtLength(
|
||||
"WARNING! The -recovery-key flag is deprecated. Please use -target=recovery " +
|
||||
"instead. This flag will be removed in Vault 1.1."))
|
||||
}
|
||||
c.flagTarget = "recovery"
|
||||
}
|
||||
|
||||
// Create the client
|
||||
client, err := c.Client()
|
||||
if err != nil {
|
||||
|
@ -349,7 +280,6 @@ func (c *OperatorRekeyCommand) init(client *api.Client) int {
|
|||
status, err := fn(&api.RekeyInitRequest{
|
||||
SecretShares: c.flagKeyShares,
|
||||
SecretThreshold: c.flagKeyThreshold,
|
||||
StoredShares: c.flagStoredShares,
|
||||
PGPKeys: c.flagPGPKeys,
|
||||
Backup: c.flagBackup,
|
||||
RequireVerification: c.flagVerify,
|
||||
|
|
|
@ -167,7 +167,7 @@ func TestOperatorUnsealCommand_Format(t *testing.T) {
|
|||
Client: client,
|
||||
}
|
||||
|
||||
args, format, _ := setupEnv([]string{"unseal", "-format", "json"})
|
||||
args, format, _ := setupEnv([]string{"operator", "unseal", "-format", "json"})
|
||||
if format != "json" {
|
||||
t.Fatalf("expected %q, got %q", "json", format)
|
||||
}
|
||||
|
|
|
@ -1,57 +0,0 @@
|
|||
package command
|
||||
|
||||
import (
|
||||
"github.com/mitchellh/cli"
|
||||
)
|
||||
|
||||
// Deprecation
|
||||
// TODO: remove in 0.9.0
|
||||
|
||||
var _ cli.Command = (*PoliciesDeprecatedCommand)(nil)
|
||||
|
||||
type PoliciesDeprecatedCommand struct {
|
||||
*BaseCommand
|
||||
}
|
||||
|
||||
func (c *PoliciesDeprecatedCommand) Synopsis() string { return "" }
|
||||
|
||||
func (c *PoliciesDeprecatedCommand) Help() string {
|
||||
return (&PolicyListCommand{
|
||||
BaseCommand: c.BaseCommand,
|
||||
}).Help()
|
||||
}
|
||||
|
||||
func (c *PoliciesDeprecatedCommand) Run(args []string) int {
|
||||
oargs := args
|
||||
|
||||
f := c.flagSet(FlagSetHTTP)
|
||||
if err := f.Parse(args); err != nil {
|
||||
c.UI.Error(err.Error())
|
||||
return 1
|
||||
}
|
||||
|
||||
args = f.Args()
|
||||
|
||||
// Got an arg, this is trying to read a policy
|
||||
if len(args) > 0 {
|
||||
return (&PolicyReadCommand{
|
||||
BaseCommand: &BaseCommand{
|
||||
UI: c.UI,
|
||||
client: c.client,
|
||||
tokenHelper: c.tokenHelper,
|
||||
flagAddress: c.flagAddress,
|
||||
},
|
||||
}).Run(oargs)
|
||||
}
|
||||
|
||||
// No args, probably ran "vault policies" and we want to translate that to
|
||||
// "vault policy list"
|
||||
return (&PolicyListCommand{
|
||||
BaseCommand: &BaseCommand{
|
||||
UI: c.UI,
|
||||
client: c.client,
|
||||
tokenHelper: c.tokenHelper,
|
||||
flagAddress: c.flagAddress,
|
||||
},
|
||||
}).Run(oargs)
|
||||
}
|
|
@ -1,96 +0,0 @@
|
|||
package command
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/mitchellh/cli"
|
||||
)
|
||||
|
||||
func testPoliciesDeprecatedCommand(tb testing.TB) (*cli.MockUi, *PoliciesDeprecatedCommand) {
|
||||
tb.Helper()
|
||||
|
||||
ui := cli.NewMockUi()
|
||||
return ui, &PoliciesDeprecatedCommand{
|
||||
BaseCommand: &BaseCommand{
|
||||
UI: ui,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func TestPoliciesDeprecatedCommand_Run(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// TODO: remove in 0.9.0
|
||||
t.Run("deprecated_arg", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client, closer := testVaultServer(t)
|
||||
defer closer()
|
||||
|
||||
ui, cmd := testPoliciesDeprecatedCommand(t)
|
||||
cmd.client = client
|
||||
|
||||
// vault policies ARG -> vault policy read ARG
|
||||
code := cmd.Run([]string{"default"})
|
||||
if exp := 0; code != exp {
|
||||
t.Errorf("expected %d to be %d: %s", code, exp, ui.ErrorWriter.String())
|
||||
}
|
||||
stdout := ui.OutputWriter.String()
|
||||
|
||||
if expected := "token/"; !strings.Contains(stdout, expected) {
|
||||
t.Errorf("expected %q to contain %q", stdout, expected)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("deprecated_no_args", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client, closer := testVaultServer(t)
|
||||
defer closer()
|
||||
|
||||
ui, cmd := testPoliciesDeprecatedCommand(t)
|
||||
cmd.client = client
|
||||
|
||||
// vault policies -> vault policy list
|
||||
code := cmd.Run([]string{})
|
||||
if exp := 0; code != exp {
|
||||
t.Errorf("expected %d to be %d: %s", code, exp, ui.ErrorWriter.String())
|
||||
}
|
||||
stdout := ui.OutputWriter.String()
|
||||
|
||||
if expected := "root"; !strings.Contains(stdout, expected) {
|
||||
t.Errorf("expected %q to contain %q", stdout, expected)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("deprecated_with_flags", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client, closer := testVaultServer(t)
|
||||
defer closer()
|
||||
|
||||
ui, cmd := testPoliciesDeprecatedCommand(t)
|
||||
cmd.client = client
|
||||
|
||||
// vault policies -flag -> vault policy list
|
||||
code := cmd.Run([]string{
|
||||
"-address", client.Address(),
|
||||
})
|
||||
if exp := 0; code != exp {
|
||||
t.Errorf("expected %d to be %d: %s", code, exp, ui.ErrorWriter.String())
|
||||
}
|
||||
stdout := ui.OutputWriter.String()
|
||||
|
||||
if expected := "root"; !strings.Contains(stdout, expected) {
|
||||
t.Errorf("expected %q to contain %q", stdout, expected)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("no_tabs", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, cmd := testPoliciesDeprecatedCommand(t)
|
||||
assertNoTabs(t, cmd)
|
||||
})
|
||||
}
|
|
@ -6,6 +6,7 @@ import (
|
|||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"github.com/hashicorp/vault/helper/metricsutil"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
|
@ -23,6 +24,7 @@ import (
|
|||
metrics "github.com/armon/go-metrics"
|
||||
"github.com/armon/go-metrics/circonus"
|
||||
"github.com/armon/go-metrics/datadog"
|
||||
"github.com/armon/go-metrics/prometheus"
|
||||
"github.com/hashicorp/errwrap"
|
||||
log "github.com/hashicorp/go-hclog"
|
||||
multierror "github.com/hashicorp/go-multierror"
|
||||
|
@ -469,7 +471,8 @@ func (c *ServerCommand) Run(args []string) int {
|
|||
"in a Docker container, provide the IPC_LOCK cap to the container."))
|
||||
}
|
||||
|
||||
if err := c.setupTelemetry(config); err != nil {
|
||||
metricsHelper, err := c.setupTelemetry(config)
|
||||
if err != nil {
|
||||
c.UI.Error(fmt.Sprintf("Error initializing telemetry: %s", err))
|
||||
return 1
|
||||
}
|
||||
|
@ -563,6 +566,7 @@ func (c *ServerCommand) Run(args []string) int {
|
|||
AllLoggers: allLoggers,
|
||||
BuiltinRegistry: builtinplugins.Registry,
|
||||
DisableKeyEncodingChecks: config.DisablePrintableCheck,
|
||||
MetricsHelper: metricsHelper,
|
||||
}
|
||||
if c.flagDev {
|
||||
coreConfig.DevToken = c.flagDevRootTokenID
|
||||
|
@ -1363,24 +1367,29 @@ func (c *ServerCommand) enableDev(core *vault.Core, coreConfig *vault.CoreConfig
|
|||
}
|
||||
|
||||
// Upgrade the default K/V store
|
||||
if !c.flagDevLeasedKV && !c.flagDevKVV1 {
|
||||
req := &logical.Request{
|
||||
Operation: logical.UpdateOperation,
|
||||
ClientToken: init.RootToken,
|
||||
Path: "sys/mounts/secret/tune",
|
||||
Data: map[string]interface{}{
|
||||
"options": map[string]string{
|
||||
"version": "2",
|
||||
},
|
||||
kvVer := "2"
|
||||
if c.flagDevKVV1 {
|
||||
kvVer = "1"
|
||||
}
|
||||
req := &logical.Request{
|
||||
Operation: logical.UpdateOperation,
|
||||
ClientToken: init.RootToken,
|
||||
Path: "sys/mounts/secret",
|
||||
Data: map[string]interface{}{
|
||||
"type": "kv",
|
||||
"path": "secret/",
|
||||
"description": "key/value secret storage",
|
||||
"options": map[string]string{
|
||||
"version": kvVer,
|
||||
},
|
||||
}
|
||||
resp, err := core.HandleRequest(ctx, req)
|
||||
if err != nil {
|
||||
return nil, errwrap.Wrapf("error upgrading default K/V store: {{err}}", err)
|
||||
}
|
||||
if resp.IsError() {
|
||||
return nil, errwrap.Wrapf("failed to upgrade default K/V store: {{err}}", resp.Error())
|
||||
}
|
||||
},
|
||||
}
|
||||
resp, err := core.HandleRequest(ctx, req)
|
||||
if err != nil {
|
||||
return nil, errwrap.Wrapf("error creating default K/V store: {{err}}", err)
|
||||
}
|
||||
if resp.IsError() {
|
||||
return nil, errwrap.Wrapf("failed to create default K/V store: {{err}}", resp.Error())
|
||||
}
|
||||
|
||||
return init, nil
|
||||
|
@ -1686,8 +1695,8 @@ func (c *ServerCommand) detectRedirect(detect physical.RedirectDetect,
|
|||
return url.String(), nil
|
||||
}
|
||||
|
||||
// setupTelemetry is used to setup the telemetry sub-systems
|
||||
func (c *ServerCommand) setupTelemetry(config *server.Config) error {
|
||||
// setupTelemetry is used to setup the telemetry sub-systems and returns the in-memory sink to be used in http configuration
|
||||
func (c *ServerCommand) setupTelemetry(config *server.Config) (*metricsutil.MetricsHelper, error) {
|
||||
/* Setup telemetry
|
||||
Aggregate on 10 second intervals for 1 minute. Expose the
|
||||
metrics over stderr when there is a SIGUSR1 received.
|
||||
|
@ -1696,10 +1705,10 @@ func (c *ServerCommand) setupTelemetry(config *server.Config) error {
|
|||
metrics.DefaultInmemSignal(inm)
|
||||
|
||||
var telConfig *server.Telemetry
|
||||
if config.Telemetry == nil {
|
||||
telConfig = &server.Telemetry{}
|
||||
} else {
|
||||
if config.Telemetry != nil {
|
||||
telConfig = config.Telemetry
|
||||
} else {
|
||||
telConfig = &server.Telemetry{}
|
||||
}
|
||||
|
||||
metricsConf := metrics.DefaultConfig("vault")
|
||||
|
@ -1707,10 +1716,29 @@ func (c *ServerCommand) setupTelemetry(config *server.Config) error {
|
|||
|
||||
// Configure the statsite sink
|
||||
var fanout metrics.FanoutSink
|
||||
var prometheusEnabled bool
|
||||
|
||||
// Configure the Prometheus sink
|
||||
if telConfig.PrometheusRetentionTime != 0 {
|
||||
prometheusEnabled = true
|
||||
prometheusOpts := prometheus.PrometheusOpts{
|
||||
Expiration: telConfig.PrometheusRetentionTime,
|
||||
}
|
||||
|
||||
sink, err := prometheus.NewPrometheusSinkFrom(prometheusOpts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fanout = append(fanout, sink)
|
||||
}
|
||||
|
||||
metricHelper := metricsutil.NewMetricsHelper(inm, prometheusEnabled)
|
||||
|
||||
|
||||
if telConfig.StatsiteAddr != "" {
|
||||
sink, err := metrics.NewStatsiteSink(telConfig.StatsiteAddr)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
fanout = append(fanout, sink)
|
||||
}
|
||||
|
@ -1719,7 +1747,7 @@ func (c *ServerCommand) setupTelemetry(config *server.Config) error {
|
|||
if telConfig.StatsdAddr != "" {
|
||||
sink, err := metrics.NewStatsdSink(telConfig.StatsdAddr)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
fanout = append(fanout, sink)
|
||||
}
|
||||
|
@ -1755,7 +1783,7 @@ func (c *ServerCommand) setupTelemetry(config *server.Config) error {
|
|||
|
||||
sink, err := circonus.NewCirconusSink(cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
sink.Start()
|
||||
fanout = append(fanout, sink)
|
||||
|
@ -1770,21 +1798,29 @@ func (c *ServerCommand) setupTelemetry(config *server.Config) error {
|
|||
|
||||
sink, err := datadog.NewDogStatsdSink(telConfig.DogStatsDAddr, metricsConf.HostName)
|
||||
if err != nil {
|
||||
return errwrap.Wrapf("failed to start DogStatsD sink: {{err}}", err)
|
||||
return nil, errwrap.Wrapf("failed to start DogStatsD sink: {{err}}", err)
|
||||
}
|
||||
sink.SetTags(tags)
|
||||
fanout = append(fanout, sink)
|
||||
}
|
||||
|
||||
// Initialize the global sink
|
||||
if len(fanout) > 0 {
|
||||
fanout = append(fanout, inm)
|
||||
metrics.NewGlobal(metricsConf, fanout)
|
||||
if len(fanout) > 1 {
|
||||
// Hostname enabled will create poor quality metrics name for prometheus
|
||||
if !telConfig.DisableHostname {
|
||||
c.UI.Warn("telemetry.disable_hostname has been set to false. Recommended setting is true for Prometheus to avoid poorly named metrics.")
|
||||
}
|
||||
} else {
|
||||
metricsConf.EnableHostname = false
|
||||
metrics.NewGlobal(metricsConf, inm)
|
||||
}
|
||||
return nil
|
||||
fanout = append(fanout, inm)
|
||||
_, err := metrics.NewGlobal(metricsConf, fanout)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return metricHelper, nil
|
||||
}
|
||||
|
||||
func (c *ServerCommand) Reload(lock *sync.RWMutex, reloadFuncs *map[string][]reload.ReloadFunc, configPath []string) error {
|
||||
|
|
|
@ -19,6 +19,10 @@ import (
|
|||
"github.com/hashicorp/vault/helper/parseutil"
|
||||
)
|
||||
|
||||
const (
|
||||
prometheusDefaultRetentionTime = 24 * time.Hour
|
||||
)
|
||||
|
||||
// Config is the configuration for the vault server.
|
||||
type Config struct {
|
||||
Listeners []*Listener `hcl:"-"`
|
||||
|
@ -98,7 +102,10 @@ func DevConfig(ha, transactional bool) *Config {
|
|||
|
||||
EnableUI: true,
|
||||
|
||||
Telemetry: &Telemetry{},
|
||||
Telemetry: &Telemetry{
|
||||
PrometheusRetentionTime: prometheusDefaultRetentionTime,
|
||||
DisableHostname: true,
|
||||
},
|
||||
}
|
||||
|
||||
switch {
|
||||
|
@ -233,6 +240,12 @@ type Telemetry struct {
|
|||
// DogStatsdTags are the global tags that should be sent with each packet to dogstatsd
|
||||
// It is a list of strings, where each string looks like "my_tag_name:my_tag_value"
|
||||
DogStatsDTags []string `hcl:"dogstatsd_tags"`
|
||||
|
||||
// Prometheus:
|
||||
// PrometheusRetentionTime is the retention time for prometheus metrics if greater than 0.
|
||||
// Default: 24h
|
||||
PrometheusRetentionTime time.Duration `hcl:-`
|
||||
PrometheusRetentionTimeRaw interface{} `hcl:"prometheus_retention_time"`
|
||||
}
|
||||
|
||||
func (s *Telemetry) GoString() string {
|
||||
|
@ -864,5 +877,15 @@ func parseTelemetry(result *Config, list *ast.ObjectList) error {
|
|||
if err := hcl.DecodeObject(&result.Telemetry, item.Val); err != nil {
|
||||
return multierror.Prefix(err, "telemetry:")
|
||||
}
|
||||
|
||||
if result.Telemetry.PrometheusRetentionTimeRaw != nil {
|
||||
var err error
|
||||
if result.Telemetry.PrometheusRetentionTime, err = parseutil.ParseDurationSecond(result.Telemetry.PrometheusRetentionTimeRaw); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
result.Telemetry.PrometheusRetentionTime = prometheusDefaultRetentionTime
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -48,11 +48,12 @@ func TestLoadConfigFile(t *testing.T) {
|
|||
},
|
||||
|
||||
Telemetry: &Telemetry{
|
||||
StatsdAddr: "bar",
|
||||
StatsiteAddr: "foo",
|
||||
DisableHostname: false,
|
||||
DogStatsDAddr: "127.0.0.1:7254",
|
||||
DogStatsDTags: []string{"tag_1:val_1", "tag_2:val_2"},
|
||||
StatsdAddr: "bar",
|
||||
StatsiteAddr: "foo",
|
||||
DisableHostname: false,
|
||||
DogStatsDAddr: "127.0.0.1:7254",
|
||||
DogStatsDTags: []string{"tag_1:val_1", "tag_2:val_2"},
|
||||
PrometheusRetentionTime: prometheusDefaultRetentionTime,
|
||||
},
|
||||
|
||||
DisableCache: true,
|
||||
|
@ -121,11 +122,13 @@ func TestLoadConfigFile_topLevel(t *testing.T) {
|
|||
},
|
||||
|
||||
Telemetry: &Telemetry{
|
||||
StatsdAddr: "bar",
|
||||
StatsiteAddr: "foo",
|
||||
DisableHostname: false,
|
||||
DogStatsDAddr: "127.0.0.1:7254",
|
||||
DogStatsDTags: []string{"tag_1:val_1", "tag_2:val_2"},
|
||||
StatsdAddr: "bar",
|
||||
StatsiteAddr: "foo",
|
||||
DisableHostname: false,
|
||||
DogStatsDAddr: "127.0.0.1:7254",
|
||||
DogStatsDTags: []string{"tag_1:val_1", "tag_2:val_2"},
|
||||
PrometheusRetentionTime: 30 * time.Second,
|
||||
PrometheusRetentionTimeRaw: "30s",
|
||||
},
|
||||
|
||||
DisableCache: true,
|
||||
|
@ -202,6 +205,7 @@ func TestLoadConfigFile_json(t *testing.T) {
|
|||
CirconusCheckTags: "",
|
||||
CirconusBrokerID: "",
|
||||
CirconusBrokerSelectTag: "",
|
||||
PrometheusRetentionTime: prometheusDefaultRetentionTime,
|
||||
},
|
||||
|
||||
MaxLeaseTTL: 10 * time.Hour,
|
||||
|
@ -288,6 +292,8 @@ func TestLoadConfigFile_json2(t *testing.T) {
|
|||
CirconusCheckTags: "cat1:tag1,cat2:tag2",
|
||||
CirconusBrokerID: "0",
|
||||
CirconusBrokerSelectTag: "dc:sfo",
|
||||
PrometheusRetentionTime: 30 * time.Second,
|
||||
PrometheusRetentionTimeRaw: "30s",
|
||||
},
|
||||
}
|
||||
if !reflect.DeepEqual(config, expected) {
|
||||
|
@ -336,9 +342,10 @@ func TestLoadConfigDir(t *testing.T) {
|
|||
EnableRawEndpoint: true,
|
||||
|
||||
Telemetry: &Telemetry{
|
||||
StatsiteAddr: "qux",
|
||||
StatsdAddr: "baz",
|
||||
DisableHostname: true,
|
||||
StatsiteAddr: "qux",
|
||||
StatsdAddr: "baz",
|
||||
DisableHostname: true,
|
||||
PrometheusRetentionTime: prometheusDefaultRetentionTime,
|
||||
},
|
||||
|
||||
MaxLeaseTTL: 10 * time.Hour,
|
||||
|
|
|
@ -72,7 +72,7 @@ func listenerWrapProxy(ln net.Listener, config map[string]interface{}) (net.List
|
|||
return newLn, nil
|
||||
}
|
||||
|
||||
func listenerWrapTLS(
|
||||
func ListenerWrapTLS(
|
||||
ln net.Listener,
|
||||
props map[string]string,
|
||||
config map[string]interface{},
|
||||
|
|
|
@ -35,7 +35,7 @@ func tcpListenerFactory(config map[string]interface{}, _ io.Writer, ui cli.Ui) (
|
|||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
ln = tcpKeepAliveListener{ln.(*net.TCPListener)}
|
||||
ln = TCPKeepAliveListener{ln.(*net.TCPListener)}
|
||||
|
||||
ln, err = listenerWrapProxy(ln, config)
|
||||
if err != nil {
|
||||
|
@ -94,20 +94,20 @@ func tcpListenerFactory(config map[string]interface{}, _ io.Writer, ui cli.Ui) (
|
|||
config["x_forwarded_for_reject_not_authorized"] = true
|
||||
}
|
||||
|
||||
return listenerWrapTLS(ln, props, config, ui)
|
||||
return ListenerWrapTLS(ln, props, config, ui)
|
||||
}
|
||||
|
||||
// tcpKeepAliveListener sets TCP keep-alive timeouts on accepted
|
||||
// TCPKeepAliveListener sets TCP keep-alive timeouts on accepted
|
||||
// connections. It's used by ListenAndServe and ListenAndServeTLS so
|
||||
// dead TCP connections (e.g. closing laptop mid-download) eventually
|
||||
// go away.
|
||||
//
|
||||
// This is copied directly from the Go source code.
|
||||
type tcpKeepAliveListener struct {
|
||||
type TCPKeepAliveListener struct {
|
||||
*net.TCPListener
|
||||
}
|
||||
|
||||
func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) {
|
||||
func (ln TCPKeepAliveListener) Accept() (c net.Conn, err error) {
|
||||
tc, err := ln.AcceptTCP()
|
||||
if err != nil {
|
||||
return
|
||||
|
|
|
@ -26,6 +26,7 @@ telemetry {
|
|||
statsite_address = "foo"
|
||||
dogstatsd_addr = "127.0.0.1:7254"
|
||||
dogstatsd_tags = ["tag_1:val_1", "tag_2:val_2"]
|
||||
prometheus_retention_time = "30s"
|
||||
}
|
||||
|
||||
max_lease_ttl = "10h"
|
||||
|
|
|
@ -42,6 +42,7 @@
|
|||
"circonus_check_display_name": "node1:vault",
|
||||
"circonus_check_tags": "cat1:tag1,cat2:tag2",
|
||||
"circonus_broker_id": "0",
|
||||
"circonus_broker_select_tag": "dc:sfo"
|
||||
"circonus_broker_select_tag": "dc:sfo",
|
||||
"prometheus_retention_time": "30s"
|
||||
}
|
||||
}
|
||||
|
|
|
@ -29,9 +29,6 @@ type TokenCreateCommand struct {
|
|||
flagType string
|
||||
flagMetadata map[string]string
|
||||
flagPolicies []string
|
||||
|
||||
// Deprecated flags
|
||||
flagLease time.Duration
|
||||
}
|
||||
|
||||
func (c *TokenCreateCommand) Synopsis() string {
|
||||
|
@ -179,15 +176,6 @@ func (c *TokenCreateCommand) Flags() *FlagSets {
|
|||
"specified multiple times to attach multiple policies.",
|
||||
})
|
||||
|
||||
// Deprecated flags
|
||||
// TODO: remove in 0.9.0
|
||||
f.DurationVar(&DurationVar{
|
||||
Name: "lease", // prefer -ttl
|
||||
Target: &c.flagLease,
|
||||
Default: 0,
|
||||
Hidden: true,
|
||||
})
|
||||
|
||||
return set
|
||||
}
|
||||
|
||||
|
@ -213,14 +201,6 @@ func (c *TokenCreateCommand) Run(args []string) int {
|
|||
return 1
|
||||
}
|
||||
|
||||
// TODO: remove in 0.9.0
|
||||
if c.flagLease != 0 {
|
||||
if Format(c.UI) == "table" {
|
||||
c.UI.Warn("The -lease flag is deprecated. Please use -ttl instead.")
|
||||
c.flagTTL = c.flagLease
|
||||
}
|
||||
}
|
||||
|
||||
if c.flagType == "batch" {
|
||||
c.flagRenewable = false
|
||||
}
|
||||
|
|
|
@ -97,19 +97,6 @@ func (c *TokenRenewCommand) Run(args []string) int {
|
|||
// Use the local token
|
||||
case 1:
|
||||
token = strings.TrimSpace(args[0])
|
||||
case 2:
|
||||
// TODO: remove in 0.9.0 - backwards compat
|
||||
if Format(c.UI) == "table" {
|
||||
c.UI.Warn("Specifying increment as a second argument is deprecated. " +
|
||||
"Please use -increment instead.")
|
||||
}
|
||||
token = strings.TrimSpace(args[0])
|
||||
parsed, err := time.ParseDuration(appendDurationSuffix(args[1]))
|
||||
if err != nil {
|
||||
c.UI.Error(fmt.Sprintf("Invalid increment: %s", err))
|
||||
return 1
|
||||
}
|
||||
increment = parsed
|
||||
default:
|
||||
c.UI.Error(fmt.Sprintf("Too many arguments (expected 1, got %d)", len(args)))
|
||||
return 1
|
||||
|
|
|
@ -73,6 +73,7 @@ func newRegistry() *registry {
|
|||
"jwt": credJWT.Factory,
|
||||
"kubernetes": credKube.Factory,
|
||||
"ldap": credLdap.Factory,
|
||||
"oidc": credJWT.Factory,
|
||||
"okta": credOkta.Factory,
|
||||
"radius": credRadius.Factory,
|
||||
"userpass": credUserpass.Factory,
|
||||
|
|
|
@ -17,12 +17,21 @@ import (
|
|||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/errwrap"
|
||||
"github.com/hashicorp/vault/helper/errutil"
|
||||
)
|
||||
|
||||
// This can be one of a few key types so the different params may or may not be filled
|
||||
type ClusterKeyParams struct {
|
||||
Type string `json:"type" structs:"type" mapstructure:"type"`
|
||||
X *big.Int `json:"x" structs:"x" mapstructure:"x"`
|
||||
Y *big.Int `json:"y" structs:"y" mapstructure:"y"`
|
||||
D *big.Int `json:"d" structs:"d" mapstructure:"d"`
|
||||
}
|
||||
|
||||
// Secret is used to attempt to unmarshal a Vault secret
|
||||
// JSON response, as a convenience
|
||||
type Secret struct {
|
||||
|
|
|
@ -22,26 +22,31 @@ func ConfigFields() map[string]*framework.FieldSchema {
|
|||
Type: framework.TypeString,
|
||||
Default: "ldap://127.0.0.1",
|
||||
Description: "LDAP URL to connect to (default: ldap://127.0.0.1). Multiple URLs can be specified by concatenating them with commas; they will be tried in-order.",
|
||||
DisplayName: "URL",
|
||||
},
|
||||
|
||||
"userdn": {
|
||||
Type: framework.TypeString,
|
||||
Description: "LDAP domain to use for users (eg: ou=People,dc=example,dc=org)",
|
||||
DisplayName: "User DN",
|
||||
},
|
||||
|
||||
"binddn": {
|
||||
Type: framework.TypeString,
|
||||
Description: "LDAP DN for searching for the user DN (optional)",
|
||||
DisplayName: "Name of Object to bind (binddn)",
|
||||
},
|
||||
|
||||
"bindpass": {
|
||||
Type: framework.TypeString,
|
||||
Description: "LDAP password for searching for the user DN (optional)",
|
||||
Type: framework.TypeString,
|
||||
Description: "LDAP password for searching for the user DN (optional)",
|
||||
DisplaySensitive: true,
|
||||
},
|
||||
|
||||
"groupdn": {
|
||||
Type: framework.TypeString,
|
||||
Description: "LDAP search base to use for group membership search (eg: ou=Groups,dc=example,dc=org)",
|
||||
DisplayName: "Group DN",
|
||||
},
|
||||
|
||||
"groupfilter": {
|
||||
|
@ -60,17 +65,20 @@ Default: (|(memberUid={{.Username}})(member={{.UserDN}})(uniqueMember={{.UserDN}
|
|||
in order to enumerate user group membership.
|
||||
Examples: "cn" or "memberOf", etc.
|
||||
Default: cn`,
|
||||
DisplayName: "Group Attribute",
|
||||
},
|
||||
|
||||
"upndomain": {
|
||||
Type: framework.TypeString,
|
||||
Description: "Enables userPrincipalDomain login with [username]@UPNDomain (optional)",
|
||||
DisplayName: "User Principal (UPN) Domain",
|
||||
},
|
||||
|
||||
"userattr": {
|
||||
Type: framework.TypeString,
|
||||
Default: "cn",
|
||||
Description: "Attribute used for users (default: cn)",
|
||||
DisplayName: "User Attribute",
|
||||
},
|
||||
|
||||
"certificate": {
|
||||
|
@ -81,28 +89,35 @@ Default: cn`,
|
|||
"discoverdn": {
|
||||
Type: framework.TypeBool,
|
||||
Description: "Use anonymous bind to discover the bind DN of a user (optional)",
|
||||
DisplayName: "Discover DN",
|
||||
},
|
||||
|
||||
"insecure_tls": {
|
||||
Type: framework.TypeBool,
|
||||
Description: "Skip LDAP server SSL Certificate verification - VERY insecure (optional)",
|
||||
DisplayName: "Insecure TLS",
|
||||
},
|
||||
|
||||
"starttls": {
|
||||
Type: framework.TypeBool,
|
||||
Description: "Issue a StartTLS command after establishing unencrypted connection (optional)",
|
||||
DisplayName: "Issue StartTLS command after establishing an unencrypted connection",
|
||||
},
|
||||
|
||||
"tls_min_version": {
|
||||
Type: framework.TypeString,
|
||||
Default: "tls12",
|
||||
Description: "Minimum TLS version to use. Accepted values are 'tls10', 'tls11' or 'tls12'. Defaults to 'tls12'",
|
||||
Type: framework.TypeString,
|
||||
Default: "tls12",
|
||||
Description: "Minimum TLS version to use. Accepted values are 'tls10', 'tls11' or 'tls12'. Defaults to 'tls12'",
|
||||
DisplayName: "Minimum TLS Version",
|
||||
AllowedValues: []interface{}{"tls10", "tls11", "tls12"},
|
||||
},
|
||||
|
||||
"tls_max_version": {
|
||||
Type: framework.TypeString,
|
||||
Default: "tls12",
|
||||
Description: "Maximum TLS version to use. Accepted values are 'tls10', 'tls11' or 'tls12'. Defaults to 'tls12'",
|
||||
Type: framework.TypeString,
|
||||
Default: "tls12",
|
||||
Description: "Maximum TLS version to use. Accepted values are 'tls10', 'tls11' or 'tls12'. Defaults to 'tls12'",
|
||||
DisplayName: "Maxumum TLS Version",
|
||||
AllowedValues: []interface{}{"tls10", "tls11", "tls12"},
|
||||
},
|
||||
|
||||
"deny_null_bind": {
|
||||
|
|
104
helper/metricsutil/metricsutil.go
Normal file
104
helper/metricsutil/metricsutil.go
Normal file
|
@ -0,0 +1,104 @@
|
|||
package metricsutil
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/armon/go-metrics"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/common/expfmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
OpenMetricsMIMEType = "application/openmetrics-text"
|
||||
)
|
||||
|
||||
const (
|
||||
PrometheusMetricFormat = "prometheus"
|
||||
)
|
||||
|
||||
type MetricsHelper struct {
|
||||
inMemSink *metrics.InmemSink
|
||||
PrometheusEnabled bool
|
||||
}
|
||||
|
||||
func NewMetricsHelper(inMem *metrics.InmemSink, enablePrometheus bool) *MetricsHelper{
|
||||
return &MetricsHelper{inMem, enablePrometheus}
|
||||
}
|
||||
|
||||
func FormatFromRequest(req *logical.Request) (string) {
|
||||
acceptHeaders := req.Headers["Accept"]
|
||||
if len(acceptHeaders) > 0 {
|
||||
acceptHeader := acceptHeaders[0]
|
||||
if strings.HasPrefix(acceptHeader, OpenMetricsMIMEType) {
|
||||
return "prometheus"
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (m *MetricsHelper) ResponseForFormat(format string) (*logical.Response, error) {
|
||||
switch format {
|
||||
case PrometheusMetricFormat:
|
||||
return m.PrometheusResponse()
|
||||
case "":
|
||||
return m.GenericResponse()
|
||||
default:
|
||||
return nil, fmt.Errorf("metric response format \"%s\" unknown", format)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MetricsHelper) PrometheusResponse() (*logical.Response, error) {
|
||||
if !m.PrometheusEnabled {
|
||||
return &logical.Response{
|
||||
Data: map[string]interface{}{
|
||||
logical.HTTPContentType: "text/plain",
|
||||
logical.HTTPRawBody: "prometheus is not enabled",
|
||||
logical.HTTPStatusCode: 400,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
metricsFamilies, err := prometheus.DefaultGatherer.Gather()
|
||||
if err != nil && len(metricsFamilies) == 0 {
|
||||
return nil, fmt.Errorf("no prometheus metrics could be decoded: %s", err)
|
||||
}
|
||||
|
||||
// Initialize a byte buffer.
|
||||
buf := &bytes.Buffer{}
|
||||
defer buf.Reset()
|
||||
|
||||
e := expfmt.NewEncoder(buf, expfmt.FmtText)
|
||||
for _, mf := range metricsFamilies {
|
||||
err := e.Encode(mf)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error during the encoding of metrics: %s", err)
|
||||
}
|
||||
}
|
||||
return &logical.Response{
|
||||
Data: map[string]interface{}{
|
||||
logical.HTTPContentType: string(expfmt.FmtText),
|
||||
logical.HTTPRawBody: buf.Bytes(),
|
||||
logical.HTTPStatusCode: 200,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *MetricsHelper) GenericResponse() (*logical.Response, error) {
|
||||
summary, err := m.inMemSink.DisplayMetrics(nil,nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error while fetching the in-memory metrics: %s", err)
|
||||
}
|
||||
content, err := json.Marshal(summary)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error while marshalling the in-memory metrics: %s", err)
|
||||
}
|
||||
return &logical.Response{
|
||||
Data: map[string]interface{}{
|
||||
logical.HTTPContentType: "application/json",
|
||||
logical.HTTPRawBody: content,
|
||||
logical.HTTPStatusCode: 200,
|
||||
},
|
||||
}, nil
|
||||
}
|
|
@ -453,3 +453,18 @@ func WaitForNCoresSealed(t testing.T, cluster *vault.TestCluster, n int) {
|
|||
|
||||
t.Fatalf("%d cores were not sealed", n)
|
||||
}
|
||||
|
||||
func WaitForActiveNode(t testing.T, cluster *vault.TestCluster) *vault.TestClusterCore {
|
||||
for i := 0; i < 10; i++ {
|
||||
for _, core := range cluster.Cores {
|
||||
if standby, _ := core.Core.Standby(); !standby {
|
||||
return core
|
||||
}
|
||||
}
|
||||
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
|
||||
t.Fatalf("node did not become active")
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"github.com/go-test/deep"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/textproto"
|
||||
|
@ -285,6 +286,7 @@ func TestSysMounts_headerAuth(t *testing.T) {
|
|||
"default_lease_ttl": json.Number("0"),
|
||||
"max_lease_ttl": json.Number("0"),
|
||||
"force_no_cache": false,
|
||||
"passthrough_request_headers": []interface{}{"Accept"},
|
||||
},
|
||||
"local": false,
|
||||
"seal_wrap": false,
|
||||
|
@ -334,6 +336,7 @@ func TestSysMounts_headerAuth(t *testing.T) {
|
|||
"default_lease_ttl": json.Number("0"),
|
||||
"max_lease_ttl": json.Number("0"),
|
||||
"force_no_cache": false,
|
||||
"passthrough_request_headers": []interface{}{"Accept"},
|
||||
},
|
||||
"local": false,
|
||||
"seal_wrap": false,
|
||||
|
@ -376,8 +379,8 @@ func TestSysMounts_headerAuth(t *testing.T) {
|
|||
expected["data"].(map[string]interface{})[k].(map[string]interface{})["accessor"] = v.(map[string]interface{})["accessor"]
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(actual, expected) {
|
||||
t.Fatalf("bad:\nExpected: %#v\nActual: %#v\n", expected, actual)
|
||||
if diff := deep.Equal(actual, expected); len(diff) > 0 {
|
||||
t.Fatalf("bad, diff: %#v", diff)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -2,6 +2,7 @@ package http
|
|||
|
||||
import (
|
||||
"encoding/json"
|
||||
"github.com/go-test/deep"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
|
@ -45,6 +46,7 @@ func TestSysMounts(t *testing.T) {
|
|||
"default_lease_ttl": json.Number("0"),
|
||||
"max_lease_ttl": json.Number("0"),
|
||||
"force_no_cache": false,
|
||||
"passthrough_request_headers": []interface{}{"Accept"},
|
||||
},
|
||||
"local": false,
|
||||
"seal_wrap": false,
|
||||
|
@ -94,6 +96,7 @@ func TestSysMounts(t *testing.T) {
|
|||
"default_lease_ttl": json.Number("0"),
|
||||
"max_lease_ttl": json.Number("0"),
|
||||
"force_no_cache": false,
|
||||
"passthrough_request_headers": []interface{}{"Accept"},
|
||||
},
|
||||
"local": false,
|
||||
"seal_wrap": false,
|
||||
|
@ -135,8 +138,8 @@ func TestSysMounts(t *testing.T) {
|
|||
expected["data"].(map[string]interface{})[k].(map[string]interface{})["accessor"] = v.(map[string]interface{})["accessor"]
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(actual, expected) {
|
||||
t.Fatalf("bad: expected: %#v\nactual: %#v\n", expected, actual)
|
||||
if diff := deep.Equal(actual, expected); len(diff) > 0 {
|
||||
t.Fatalf("bad, diff: %#v", diff)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -197,6 +200,7 @@ func TestSysMount(t *testing.T) {
|
|||
"default_lease_ttl": json.Number("0"),
|
||||
"max_lease_ttl": json.Number("0"),
|
||||
"force_no_cache": false,
|
||||
"passthrough_request_headers": []interface{}{"Accept"},
|
||||
},
|
||||
"local": false,
|
||||
"seal_wrap": false,
|
||||
|
@ -258,6 +262,7 @@ func TestSysMount(t *testing.T) {
|
|||
"default_lease_ttl": json.Number("0"),
|
||||
"max_lease_ttl": json.Number("0"),
|
||||
"force_no_cache": false,
|
||||
"passthrough_request_headers": []interface{}{"Accept"},
|
||||
},
|
||||
"local": false,
|
||||
"seal_wrap": false,
|
||||
|
@ -299,9 +304,10 @@ func TestSysMount(t *testing.T) {
|
|||
expected["data"].(map[string]interface{})[k].(map[string]interface{})["accessor"] = v.(map[string]interface{})["accessor"]
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(actual, expected) {
|
||||
t.Fatalf("bad: expected: %#v\nactual: %#v\n", expected, actual)
|
||||
if diff := deep.Equal(actual, expected); len(diff) > 0 {
|
||||
t.Fatalf("bad, diff: %#v", diff)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestSysMount_put(t *testing.T) {
|
||||
|
@ -380,6 +386,7 @@ func TestSysRemount(t *testing.T) {
|
|||
"default_lease_ttl": json.Number("0"),
|
||||
"max_lease_ttl": json.Number("0"),
|
||||
"force_no_cache": false,
|
||||
"passthrough_request_headers": []interface{}{"Accept"},
|
||||
},
|
||||
"local": false,
|
||||
"seal_wrap": false,
|
||||
|
@ -441,6 +448,7 @@ func TestSysRemount(t *testing.T) {
|
|||
"default_lease_ttl": json.Number("0"),
|
||||
"max_lease_ttl": json.Number("0"),
|
||||
"force_no_cache": false,
|
||||
"passthrough_request_headers": []interface{}{"Accept"},
|
||||
},
|
||||
"local": false,
|
||||
"seal_wrap": false,
|
||||
|
@ -483,7 +491,7 @@ func TestSysRemount(t *testing.T) {
|
|||
}
|
||||
|
||||
if !reflect.DeepEqual(actual, expected) {
|
||||
t.Fatalf("bad:\ngot\n%#v\nexpected\n%#v\n", actual, expected)
|
||||
t.Fatalf("bad:\nExpected: %#v\nActual: %#v\n", expected, actual)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -532,6 +540,7 @@ func TestSysUnmount(t *testing.T) {
|
|||
"default_lease_ttl": json.Number("0"),
|
||||
"max_lease_ttl": json.Number("0"),
|
||||
"force_no_cache": false,
|
||||
"passthrough_request_headers": []interface{}{"Accept"},
|
||||
},
|
||||
"local": false,
|
||||
"seal_wrap": false,
|
||||
|
@ -581,6 +590,7 @@ func TestSysUnmount(t *testing.T) {
|
|||
"default_lease_ttl": json.Number("0"),
|
||||
"max_lease_ttl": json.Number("0"),
|
||||
"force_no_cache": false,
|
||||
"passthrough_request_headers": []interface{}{"Accept"},
|
||||
},
|
||||
"local": false,
|
||||
"seal_wrap": false,
|
||||
|
@ -622,8 +632,8 @@ func TestSysUnmount(t *testing.T) {
|
|||
expected["data"].(map[string]interface{})[k].(map[string]interface{})["accessor"] = v.(map[string]interface{})["accessor"]
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(actual, expected) {
|
||||
t.Fatalf("bad: %#v", actual)
|
||||
if diff := deep.Equal(actual, expected); len(diff) > 0 {
|
||||
t.Fatalf("bad, diff: %#v", diff)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -766,6 +776,7 @@ func TestSysTuneMount(t *testing.T) {
|
|||
"default_lease_ttl": json.Number("0"),
|
||||
"max_lease_ttl": json.Number("0"),
|
||||
"force_no_cache": false,
|
||||
"passthrough_request_headers": []interface{}{"Accept"},
|
||||
},
|
||||
"local": false,
|
||||
"seal_wrap": false,
|
||||
|
@ -827,6 +838,7 @@ func TestSysTuneMount(t *testing.T) {
|
|||
"default_lease_ttl": json.Number("0"),
|
||||
"max_lease_ttl": json.Number("0"),
|
||||
"force_no_cache": false,
|
||||
"passthrough_request_headers": []interface{}{"Accept"},
|
||||
},
|
||||
"local": false,
|
||||
"seal_wrap": false,
|
||||
|
@ -868,8 +880,8 @@ func TestSysTuneMount(t *testing.T) {
|
|||
expected["data"].(map[string]interface{})[k].(map[string]interface{})["accessor"] = v.(map[string]interface{})["accessor"]
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(actual, expected) {
|
||||
t.Fatalf("bad: %#v", actual)
|
||||
if diff := deep.Equal(actual, expected); len(diff) > 0 {
|
||||
t.Fatalf("bad, diff: %#v", diff)
|
||||
}
|
||||
|
||||
// Shorter than system default
|
||||
|
@ -956,6 +968,7 @@ func TestSysTuneMount(t *testing.T) {
|
|||
"default_lease_ttl": json.Number("0"),
|
||||
"max_lease_ttl": json.Number("0"),
|
||||
"force_no_cache": false,
|
||||
"passthrough_request_headers": []interface{}{"Accept"},
|
||||
},
|
||||
"local": false,
|
||||
"seal_wrap": false,
|
||||
|
@ -1017,6 +1030,7 @@ func TestSysTuneMount(t *testing.T) {
|
|||
"default_lease_ttl": json.Number("0"),
|
||||
"max_lease_ttl": json.Number("0"),
|
||||
"force_no_cache": false,
|
||||
"passthrough_request_headers": []interface{}{"Accept"},
|
||||
},
|
||||
"local": false,
|
||||
"seal_wrap": false,
|
||||
|
@ -1059,8 +1073,8 @@ func TestSysTuneMount(t *testing.T) {
|
|||
expected["data"].(map[string]interface{})[k].(map[string]interface{})["accessor"] = v.(map[string]interface{})["accessor"]
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(actual, expected) {
|
||||
t.Fatalf("bad:\nExpected: %#v\nActual:%#v", expected, actual)
|
||||
if diff := deep.Equal(actual, expected); len(diff) > 0 {
|
||||
t.Fatalf("bad, diff: %#v", diff)
|
||||
}
|
||||
|
||||
// Check simple configuration endpoint
|
||||
|
|
|
@ -155,6 +155,7 @@ type OASSchema struct {
|
|||
Format string `json:"format,omitempty"`
|
||||
Pattern string `json:"pattern,omitempty"`
|
||||
Enum []interface{} `json:"enum,omitempty"`
|
||||
Default interface{} `json:"default,omitempty"`
|
||||
Example interface{} `json:"example,omitempty"`
|
||||
Deprecated bool `json:"deprecated,omitempty"`
|
||||
Required bool `json:"required,omitempty"`
|
||||
|
@ -263,6 +264,7 @@ func documentPath(p *Path, specialPaths *logical.Paths, backendType logical.Back
|
|||
Type: t.baseType,
|
||||
Pattern: t.pattern,
|
||||
Enum: field.AllowedValues,
|
||||
Default: field.Default,
|
||||
DisplayName: field.DisplayName,
|
||||
DisplayValue: field.DisplayValue,
|
||||
DisplaySensitive: field.DisplaySensitive,
|
||||
|
@ -321,6 +323,7 @@ func documentPath(p *Path, specialPaths *logical.Paths, backendType logical.Back
|
|||
Format: openapiField.format,
|
||||
Pattern: openapiField.pattern,
|
||||
Enum: field.AllowedValues,
|
||||
Default: field.Default,
|
||||
Required: field.Required,
|
||||
Deprecated: field.Deprecated,
|
||||
DisplayName: field.DisplayName,
|
||||
|
|
|
@ -326,6 +326,7 @@ func TestOpenAPI_Paths(t *testing.T) {
|
|||
},
|
||||
"name": {
|
||||
Type: TypeNameString,
|
||||
Default: "Larry",
|
||||
Description: "the name",
|
||||
},
|
||||
"age": {
|
||||
|
|
1
logical/framework/testdata/operations.json
vendored
1
logical/framework/testdata/operations.json
vendored
|
@ -85,6 +85,7 @@
|
|||
"name": {
|
||||
"type": "string",
|
||||
"description": "the name",
|
||||
"default": "Larry",
|
||||
"pattern": "\\w([\\w-.]*\\w)?"
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,7 +2,6 @@ package plugin
|
|||
|
||||
import (
|
||||
"context"
|
||||
"net/rpc"
|
||||
"sync/atomic"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
|
@ -13,16 +12,9 @@ import (
|
|||
"github.com/hashicorp/vault/logical/plugin/pb"
|
||||
)
|
||||
|
||||
var _ plugin.Plugin = (*BackendPlugin)(nil)
|
||||
var _ plugin.GRPCPlugin = (*BackendPlugin)(nil)
|
||||
var _ plugin.Plugin = (*GRPCBackendPlugin)(nil)
|
||||
var _ plugin.GRPCPlugin = (*GRPCBackendPlugin)(nil)
|
||||
|
||||
// BackendPlugin is the plugin.Plugin implementation
|
||||
type BackendPlugin struct {
|
||||
*GRPCBackendPlugin
|
||||
}
|
||||
|
||||
// GRPCBackendPlugin is the plugin.Plugin implementation that only supports GRPC
|
||||
// transport
|
||||
type GRPCBackendPlugin struct {
|
||||
|
@ -34,26 +26,6 @@ type GRPCBackendPlugin struct {
|
|||
plugin.NetRPCUnsupportedPlugin
|
||||
}
|
||||
|
||||
// Server gets called when on plugin.Serve()
|
||||
func (b *BackendPlugin) Server(broker *plugin.MuxBroker) (interface{}, error) {
|
||||
return &backendPluginServer{
|
||||
factory: b.Factory,
|
||||
broker: broker,
|
||||
// We pass the logger down into the backend so go-plugin will forward
|
||||
// logs for us.
|
||||
logger: b.Logger,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Client gets called on plugin.NewClient()
|
||||
func (b BackendPlugin) Client(broker *plugin.MuxBroker, c *rpc.Client) (interface{}, error) {
|
||||
return &backendPluginClient{
|
||||
client: c,
|
||||
broker: broker,
|
||||
metadataMode: b.MetadataMode,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (b GRPCBackendPlugin) GRPCServer(broker *plugin.GRPCBroker, s *grpc.Server) error {
|
||||
pb.RegisterBackendServer(s, &backendGRPCPluginServer{
|
||||
broker: broker,
|
||||
|
|
|
@ -1,248 +0,0 @@
|
|||
package plugin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/rpc"
|
||||
|
||||
log "github.com/hashicorp/go-hclog"
|
||||
plugin "github.com/hashicorp/go-plugin"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrClientInMetadataMode = errors.New("plugin client can not perform action while in metadata mode")
|
||||
)
|
||||
|
||||
// backendPluginClient implements logical.Backend and is the
|
||||
// go-plugin client.
|
||||
type backendPluginClient struct {
|
||||
broker *plugin.MuxBroker
|
||||
client *rpc.Client
|
||||
metadataMode bool
|
||||
|
||||
system logical.SystemView
|
||||
logger log.Logger
|
||||
}
|
||||
|
||||
// HandleRequestArgs is the args for HandleRequest method.
|
||||
type HandleRequestArgs struct {
|
||||
StorageID uint32
|
||||
Request *logical.Request
|
||||
}
|
||||
|
||||
// HandleRequestReply is the reply for HandleRequest method.
|
||||
type HandleRequestReply struct {
|
||||
Response *logical.Response
|
||||
Error error
|
||||
}
|
||||
|
||||
// SpecialPathsReply is the reply for SpecialPaths method.
|
||||
type SpecialPathsReply struct {
|
||||
Paths *logical.Paths
|
||||
}
|
||||
|
||||
// SystemReply is the reply for System method.
|
||||
type SystemReply struct {
|
||||
SystemView logical.SystemView
|
||||
Error error
|
||||
}
|
||||
|
||||
// HandleExistenceCheckArgs is the args for HandleExistenceCheck method.
|
||||
type HandleExistenceCheckArgs struct {
|
||||
StorageID uint32
|
||||
Request *logical.Request
|
||||
}
|
||||
|
||||
// HandleExistenceCheckReply is the reply for HandleExistenceCheck method.
|
||||
type HandleExistenceCheckReply struct {
|
||||
CheckFound bool
|
||||
Exists bool
|
||||
Error error
|
||||
}
|
||||
|
||||
// SetupArgs is the args for Setup method.
|
||||
type SetupArgs struct {
|
||||
StorageID uint32
|
||||
LoggerID uint32
|
||||
SysViewID uint32
|
||||
Config map[string]string
|
||||
BackendUUID string
|
||||
}
|
||||
|
||||
// SetupReply is the reply for Setup method.
|
||||
type SetupReply struct {
|
||||
Error error
|
||||
}
|
||||
|
||||
// TypeReply is the reply for the Type method.
|
||||
type TypeReply struct {
|
||||
Type logical.BackendType
|
||||
}
|
||||
|
||||
func (b *backendPluginClient) HandleRequest(ctx context.Context, req *logical.Request) (*logical.Response, error) {
|
||||
if b.metadataMode {
|
||||
return nil, ErrClientInMetadataMode
|
||||
}
|
||||
|
||||
// Do not send the storage, since go-plugin cannot serialize
|
||||
// interfaces. The server will pick up the storage from the shim.
|
||||
req.Storage = nil
|
||||
args := &HandleRequestArgs{
|
||||
Request: req,
|
||||
}
|
||||
var reply HandleRequestReply
|
||||
|
||||
if req.Connection != nil {
|
||||
oldConnState := req.Connection.ConnState
|
||||
req.Connection.ConnState = nil
|
||||
defer func() {
|
||||
req.Connection.ConnState = oldConnState
|
||||
}()
|
||||
}
|
||||
|
||||
err := b.client.Call("Plugin.HandleRequest", args, &reply)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if reply.Error != nil {
|
||||
if reply.Error.Error() == logical.ErrUnsupportedOperation.Error() {
|
||||
return nil, logical.ErrUnsupportedOperation
|
||||
}
|
||||
|
||||
return reply.Response, reply.Error
|
||||
}
|
||||
|
||||
return reply.Response, nil
|
||||
}
|
||||
|
||||
func (b *backendPluginClient) SpecialPaths() *logical.Paths {
|
||||
var reply SpecialPathsReply
|
||||
err := b.client.Call("Plugin.SpecialPaths", new(interface{}), &reply)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return reply.Paths
|
||||
}
|
||||
|
||||
// System returns vault's system view. The backend client stores the view during
|
||||
// Setup, so there is no need to shim the system just to get it back.
|
||||
func (b *backendPluginClient) System() logical.SystemView {
|
||||
return b.system
|
||||
}
|
||||
|
||||
// Logger returns vault's logger. The backend client stores the logger during
|
||||
// Setup, so there is no need to shim the logger just to get it back.
|
||||
func (b *backendPluginClient) Logger() log.Logger {
|
||||
return b.logger
|
||||
}
|
||||
|
||||
func (b *backendPluginClient) HandleExistenceCheck(ctx context.Context, req *logical.Request) (bool, bool, error) {
|
||||
if b.metadataMode {
|
||||
return false, false, ErrClientInMetadataMode
|
||||
}
|
||||
|
||||
// Do not send the storage, since go-plugin cannot serialize
|
||||
// interfaces. The server will pick up the storage from the shim.
|
||||
req.Storage = nil
|
||||
args := &HandleExistenceCheckArgs{
|
||||
Request: req,
|
||||
}
|
||||
var reply HandleExistenceCheckReply
|
||||
|
||||
if req.Connection != nil {
|
||||
oldConnState := req.Connection.ConnState
|
||||
req.Connection.ConnState = nil
|
||||
defer func() {
|
||||
req.Connection.ConnState = oldConnState
|
||||
}()
|
||||
}
|
||||
|
||||
err := b.client.Call("Plugin.HandleExistenceCheck", args, &reply)
|
||||
if err != nil {
|
||||
return false, false, err
|
||||
}
|
||||
if reply.Error != nil {
|
||||
// THINKING: Should be be a switch on all error types?
|
||||
if reply.Error.Error() == logical.ErrUnsupportedPath.Error() {
|
||||
return false, false, logical.ErrUnsupportedPath
|
||||
}
|
||||
return false, false, reply.Error
|
||||
}
|
||||
|
||||
return reply.CheckFound, reply.Exists, nil
|
||||
}
|
||||
|
||||
func (b *backendPluginClient) Cleanup(ctx context.Context) {
|
||||
b.client.Call("Plugin.Cleanup", new(interface{}), &struct{}{})
|
||||
}
|
||||
|
||||
func (b *backendPluginClient) Initialize(ctx context.Context) error {
|
||||
if b.metadataMode {
|
||||
return ErrClientInMetadataMode
|
||||
}
|
||||
err := b.client.Call("Plugin.Initialize", new(interface{}), &struct{}{})
|
||||
return err
|
||||
}
|
||||
|
||||
func (b *backendPluginClient) InvalidateKey(ctx context.Context, key string) {
|
||||
if b.metadataMode {
|
||||
return
|
||||
}
|
||||
b.client.Call("Plugin.InvalidateKey", key, &struct{}{})
|
||||
}
|
||||
|
||||
func (b *backendPluginClient) Setup(ctx context.Context, config *logical.BackendConfig) error {
|
||||
// Shim logical.Storage
|
||||
storageImpl := config.StorageView
|
||||
if b.metadataMode {
|
||||
storageImpl = &NOOPStorage{}
|
||||
}
|
||||
storageID := b.broker.NextId()
|
||||
go b.broker.AcceptAndServe(storageID, &StorageServer{
|
||||
impl: storageImpl,
|
||||
})
|
||||
|
||||
// Shim logical.SystemView
|
||||
sysViewImpl := config.System
|
||||
if b.metadataMode {
|
||||
sysViewImpl = &logical.StaticSystemView{}
|
||||
}
|
||||
sysViewID := b.broker.NextId()
|
||||
go b.broker.AcceptAndServe(sysViewID, &SystemViewServer{
|
||||
impl: sysViewImpl,
|
||||
})
|
||||
|
||||
args := &SetupArgs{
|
||||
StorageID: storageID,
|
||||
SysViewID: sysViewID,
|
||||
Config: config.Config,
|
||||
BackendUUID: config.BackendUUID,
|
||||
}
|
||||
var reply SetupReply
|
||||
|
||||
err := b.client.Call("Plugin.Setup", args, &reply)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if reply.Error != nil {
|
||||
return reply.Error
|
||||
}
|
||||
|
||||
// Set system and logger for getter methods
|
||||
b.system = config.System
|
||||
b.logger = config.Logger
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *backendPluginClient) Type() logical.BackendType {
|
||||
var reply TypeReply
|
||||
err := b.client.Call("Plugin.Type", new(interface{}), &reply)
|
||||
if err != nil {
|
||||
return logical.TypeUnknown
|
||||
}
|
||||
|
||||
return logical.BackendType(reply.Type)
|
||||
}
|
|
@ -1,147 +0,0 @@
|
|||
package plugin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/rpc"
|
||||
|
||||
hclog "github.com/hashicorp/go-hclog"
|
||||
plugin "github.com/hashicorp/go-plugin"
|
||||
"github.com/hashicorp/vault/helper/pluginutil"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrServerInMetadataMode = errors.New("plugin server can not perform action while in metadata mode")
|
||||
)
|
||||
|
||||
// backendPluginServer is the RPC server that backendPluginClient talks to,
|
||||
// it methods conforming to requirements by net/rpc
|
||||
type backendPluginServer struct {
|
||||
broker *plugin.MuxBroker
|
||||
backend logical.Backend
|
||||
factory logical.Factory
|
||||
|
||||
logger hclog.Logger
|
||||
sysViewClient *rpc.Client
|
||||
storageClient *rpc.Client
|
||||
}
|
||||
|
||||
func (b *backendPluginServer) HandleRequest(args *HandleRequestArgs, reply *HandleRequestReply) error {
|
||||
if pluginutil.InMetadataMode() {
|
||||
return ErrServerInMetadataMode
|
||||
}
|
||||
|
||||
storage := &StorageClient{client: b.storageClient}
|
||||
args.Request.Storage = storage
|
||||
|
||||
resp, err := b.backend.HandleRequest(context.Background(), args.Request)
|
||||
*reply = HandleRequestReply{
|
||||
Response: resp,
|
||||
Error: wrapError(err),
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *backendPluginServer) SpecialPaths(_ interface{}, reply *SpecialPathsReply) error {
|
||||
*reply = SpecialPathsReply{
|
||||
Paths: b.backend.SpecialPaths(),
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *backendPluginServer) HandleExistenceCheck(args *HandleExistenceCheckArgs, reply *HandleExistenceCheckReply) error {
|
||||
if pluginutil.InMetadataMode() {
|
||||
return ErrServerInMetadataMode
|
||||
}
|
||||
|
||||
storage := &StorageClient{client: b.storageClient}
|
||||
args.Request.Storage = storage
|
||||
|
||||
checkFound, exists, err := b.backend.HandleExistenceCheck(context.TODO(), args.Request)
|
||||
*reply = HandleExistenceCheckReply{
|
||||
CheckFound: checkFound,
|
||||
Exists: exists,
|
||||
Error: wrapError(err),
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *backendPluginServer) Cleanup(_ interface{}, _ *struct{}) error {
|
||||
b.backend.Cleanup(context.Background())
|
||||
|
||||
// Close rpc clients
|
||||
b.sysViewClient.Close()
|
||||
b.storageClient.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *backendPluginServer) InvalidateKey(args string, _ *struct{}) error {
|
||||
if pluginutil.InMetadataMode() {
|
||||
return ErrServerInMetadataMode
|
||||
}
|
||||
|
||||
b.backend.InvalidateKey(context.Background(), args)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Setup dials into the plugin's broker to get a shimmed storage, logger, and
|
||||
// system view of the backend. This method also instantiates the underlying
|
||||
// backend through its factory func for the server side of the plugin.
|
||||
func (b *backendPluginServer) Setup(args *SetupArgs, reply *SetupReply) error {
|
||||
// Dial for storage
|
||||
storageConn, err := b.broker.Dial(args.StorageID)
|
||||
if err != nil {
|
||||
*reply = SetupReply{
|
||||
Error: wrapError(err),
|
||||
}
|
||||
return nil
|
||||
}
|
||||
rawStorageClient := rpc.NewClient(storageConn)
|
||||
b.storageClient = rawStorageClient
|
||||
|
||||
storage := &StorageClient{client: rawStorageClient}
|
||||
|
||||
// Dial for sys view
|
||||
sysViewConn, err := b.broker.Dial(args.SysViewID)
|
||||
if err != nil {
|
||||
*reply = SetupReply{
|
||||
Error: wrapError(err),
|
||||
}
|
||||
return nil
|
||||
}
|
||||
rawSysViewClient := rpc.NewClient(sysViewConn)
|
||||
b.sysViewClient = rawSysViewClient
|
||||
|
||||
sysView := &SystemViewClient{client: rawSysViewClient}
|
||||
|
||||
config := &logical.BackendConfig{
|
||||
StorageView: storage,
|
||||
Logger: b.logger,
|
||||
System: sysView,
|
||||
Config: args.Config,
|
||||
BackendUUID: args.BackendUUID,
|
||||
}
|
||||
|
||||
// Call the underlying backend factory after shims have been created
|
||||
// to set b.backend
|
||||
backend, err := b.factory(context.Background(), config)
|
||||
if err != nil {
|
||||
*reply = SetupReply{
|
||||
Error: wrapError(err),
|
||||
}
|
||||
}
|
||||
b.backend = backend
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *backendPluginServer) Type(_ interface{}, reply *TypeReply) error {
|
||||
*reply = TypeReply{
|
||||
Type: b.backend.Type(),
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
|
@ -1,173 +0,0 @@
|
|||
package plugin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
log "github.com/hashicorp/go-hclog"
|
||||
gplugin "github.com/hashicorp/go-plugin"
|
||||
"github.com/hashicorp/vault/helper/logging"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/plugin/mock"
|
||||
)
|
||||
|
||||
func TestBackendPlugin_impl(t *testing.T) {
|
||||
var _ gplugin.Plugin = new(BackendPlugin)
|
||||
var _ logical.Backend = new(backendPluginClient)
|
||||
}
|
||||
|
||||
func TestBackendPlugin_HandleRequest(t *testing.T) {
|
||||
b, cleanup := testBackend(t)
|
||||
defer cleanup()
|
||||
|
||||
resp, err := b.HandleRequest(context.Background(), &logical.Request{
|
||||
Operation: logical.CreateOperation,
|
||||
Path: "kv/foo",
|
||||
Data: map[string]interface{}{
|
||||
"value": "bar",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp.Data["value"] != "bar" {
|
||||
t.Fatalf("bad: %#v", resp)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackendPlugin_SpecialPaths(t *testing.T) {
|
||||
b, cleanup := testBackend(t)
|
||||
defer cleanup()
|
||||
|
||||
paths := b.SpecialPaths()
|
||||
if paths == nil {
|
||||
t.Fatal("SpecialPaths() returned nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackendPlugin_System(t *testing.T) {
|
||||
b, cleanup := testBackend(t)
|
||||
defer cleanup()
|
||||
|
||||
sys := b.System()
|
||||
if sys == nil {
|
||||
t.Fatal("System() returned nil")
|
||||
}
|
||||
|
||||
actual := sys.DefaultLeaseTTL()
|
||||
expected := 300 * time.Second
|
||||
|
||||
if actual != expected {
|
||||
t.Fatalf("bad: %v, expected %v", actual, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackendPlugin_Logger(t *testing.T) {
|
||||
b, cleanup := testBackend(t)
|
||||
defer cleanup()
|
||||
|
||||
logger := b.Logger()
|
||||
if logger == nil {
|
||||
t.Fatal("Logger() returned nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackendPlugin_HandleExistenceCheck(t *testing.T) {
|
||||
b, cleanup := testBackend(t)
|
||||
defer cleanup()
|
||||
|
||||
checkFound, exists, err := b.HandleExistenceCheck(context.Background(), &logical.Request{
|
||||
Operation: logical.CreateOperation,
|
||||
Path: "kv/foo",
|
||||
Data: map[string]interface{}{"value": "bar"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !checkFound {
|
||||
t.Fatal("existence check not found for path 'kv/foo")
|
||||
}
|
||||
if exists {
|
||||
t.Fatal("existence check should have returned 'false' for 'kv/foo'")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackendPlugin_Cleanup(t *testing.T) {
|
||||
b, cleanup := testBackend(t)
|
||||
defer cleanup()
|
||||
|
||||
b.Cleanup(context.Background())
|
||||
}
|
||||
|
||||
func TestBackendPlugin_InvalidateKey(t *testing.T) {
|
||||
b, cleanup := testBackend(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
resp, err := b.HandleRequest(ctx, &logical.Request{
|
||||
Operation: logical.ReadOperation,
|
||||
Path: "internal",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp.Data["value"] == "" {
|
||||
t.Fatalf("bad: %#v, expected non-empty value", resp)
|
||||
}
|
||||
|
||||
b.InvalidateKey(ctx, "internal")
|
||||
|
||||
resp, err = b.HandleRequest(ctx, &logical.Request{
|
||||
Operation: logical.ReadOperation,
|
||||
Path: "internal",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp.Data["value"] != "" {
|
||||
t.Fatalf("bad: expected empty response data, got %#v", resp)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackendPlugin_Setup(t *testing.T) {
|
||||
_, cleanup := testBackend(t)
|
||||
defer cleanup()
|
||||
}
|
||||
|
||||
func testBackend(t *testing.T) (logical.Backend, func()) {
|
||||
// Create a mock provider
|
||||
pluginMap := map[string]gplugin.Plugin{
|
||||
"backend": &BackendPlugin{
|
||||
GRPCBackendPlugin: &GRPCBackendPlugin{
|
||||
Factory: mock.Factory,
|
||||
},
|
||||
},
|
||||
}
|
||||
client, _ := gplugin.TestPluginRPCConn(t, pluginMap, nil)
|
||||
cleanup := func() {
|
||||
client.Close()
|
||||
}
|
||||
|
||||
// Request the backend
|
||||
raw, err := client.Dispense(BackendPluginName)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
b := raw.(logical.Backend)
|
||||
|
||||
err = b.Setup(context.Background(), &logical.BackendConfig{
|
||||
Logger: logging.NewVaultLogger(log.Debug),
|
||||
System: &logical.StaticSystemView{
|
||||
DefaultLeaseTTLVal: 300 * time.Second,
|
||||
MaxLeaseTTLVal: 1800 * time.Second,
|
||||
},
|
||||
StorageView: &logical.InmemStorage{},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return b, cleanup
|
||||
}
|
|
@ -16,6 +16,7 @@ import (
|
|||
)
|
||||
|
||||
var ErrPluginShutdown = errors.New("plugin is shut down")
|
||||
var ErrClientInMetadataMode = errors.New("plugin client can not perform action while in metadata mode")
|
||||
|
||||
// Validate backendGRPCPluginClient satisfies the logical.Backend interface
|
||||
var _ logical.Backend = &backendGRPCPluginClient{}
|
||||
|
|
|
@ -2,6 +2,7 @@ package plugin
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
log "github.com/hashicorp/go-hclog"
|
||||
plugin "github.com/hashicorp/go-plugin"
|
||||
|
@ -11,6 +12,8 @@ import (
|
|||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
var ErrServerInMetadataMode = errors.New("plugin server can not perform action while in metadata mode")
|
||||
|
||||
type backendGRPCPluginServer struct {
|
||||
broker *plugin.GRPCBroker
|
||||
backend logical.Backend
|
||||
|
|
|
@ -14,8 +14,8 @@ import (
|
|||
)
|
||||
|
||||
func TestGRPCBackendPlugin_impl(t *testing.T) {
|
||||
var _ gplugin.Plugin = new(BackendPlugin)
|
||||
var _ logical.Backend = new(backendPluginClient)
|
||||
var _ gplugin.Plugin = new(GRPCBackendPlugin)
|
||||
var _ logical.Backend = new(backendGRPCPluginClient)
|
||||
}
|
||||
|
||||
func TestGRPCBackendPlugin_HandleRequest(t *testing.T) {
|
||||
|
@ -140,15 +140,13 @@ func TestGRPCBackendPlugin_Setup(t *testing.T) {
|
|||
func testGRPCBackend(t *testing.T) (logical.Backend, func()) {
|
||||
// Create a mock provider
|
||||
pluginMap := map[string]gplugin.Plugin{
|
||||
"backend": &BackendPlugin{
|
||||
GRPCBackendPlugin: &GRPCBackendPlugin{
|
||||
Factory: mock.Factory,
|
||||
Logger: log.New(&log.LoggerOptions{
|
||||
Level: log.Debug,
|
||||
Output: os.Stderr,
|
||||
JSONFormat: true,
|
||||
}),
|
||||
},
|
||||
"backend": &GRPCBackendPlugin{
|
||||
Factory: mock.Factory,
|
||||
Logger: log.New(&log.LoggerOptions{
|
||||
Level: log.Debug,
|
||||
Output: os.Stderr,
|
||||
JSONFormat: true,
|
||||
}),
|
||||
},
|
||||
}
|
||||
client, _ := gplugin.TestPluginGRPCConn(t, pluginMap)
|
||||
|
|
|
@ -108,3 +108,23 @@ func (s *GRPCStorageServer) Delete(ctx context.Context, args *pb.StorageDeleteAr
|
|||
Err: pb.ErrToString(err),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// NOOPStorage is used to deny access to the storage interface while running a
|
||||
// backend plugin in metadata mode.
|
||||
type NOOPStorage struct{}
|
||||
|
||||
func (s *NOOPStorage) List(_ context.Context, prefix string) ([]string, error) {
|
||||
return []string{}, nil
|
||||
}
|
||||
|
||||
func (s *NOOPStorage) Get(_ context.Context, key string) (*logical.StorageEntry, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *NOOPStorage) Put(_ context.Context, entry *logical.StorageEntry) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *NOOPStorage) Delete(_ context.Context, key string) error {
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -2,13 +2,9 @@ package plugin
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"crypto/rsa"
|
||||
"encoding/gob"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/errwrap"
|
||||
log "github.com/hashicorp/go-hclog"
|
||||
|
@ -18,28 +14,6 @@ import (
|
|||
"github.com/hashicorp/vault/logical"
|
||||
)
|
||||
|
||||
// init registers basic structs with gob which will be used to transport complex
|
||||
// types through the plugin server and client.
|
||||
func init() {
|
||||
// Common basic structs
|
||||
gob.Register([]interface{}{})
|
||||
gob.Register(map[string]interface{}{})
|
||||
gob.Register(map[string]string{})
|
||||
gob.Register(map[string]int{})
|
||||
|
||||
// Register these types since we have to serialize and de-serialize
|
||||
// tls.ConnectionState over the wire as part of logical.Request.Connection.
|
||||
gob.Register(rsa.PublicKey{})
|
||||
gob.Register(ecdsa.PublicKey{})
|
||||
gob.Register(time.Duration(0))
|
||||
|
||||
// Custom common error types for requests. If you add something here, you must
|
||||
// also add it to the switch statement in `wrapError`!
|
||||
gob.Register(&plugin.BasicError{})
|
||||
gob.Register(logical.CodedError(0, ""))
|
||||
gob.Register(&logical.StatusBadRequest{})
|
||||
}
|
||||
|
||||
// BackendPluginClient is a wrapper around backendPluginClient
|
||||
// that also contains its plugin.Client instance. It's primarily
|
||||
// used to cleanly kill the client on Cleanup()
|
||||
|
@ -98,11 +72,13 @@ func NewBackend(ctx context.Context, pluginName string, pluginType consts.Plugin
|
|||
func NewPluginClient(ctx context.Context, sys pluginutil.RunnerUtil, pluginRunner *pluginutil.PluginRunner, logger log.Logger, isMetadataMode bool) (logical.Backend, error) {
|
||||
// pluginMap is the map of plugins we can dispense.
|
||||
pluginSet := map[int]plugin.PluginSet{
|
||||
// Version 3 used to supports both protocols. We want to keep it around
|
||||
// since it's possible old plugins built against this version will still
|
||||
// work with gRPC. There is currently no difference between version 3
|
||||
// and version 4.
|
||||
3: plugin.PluginSet{
|
||||
"backend": &BackendPlugin{
|
||||
GRPCBackendPlugin: &GRPCBackendPlugin{
|
||||
MetadataMode: isMetadataMode,
|
||||
},
|
||||
"backend": &GRPCBackendPlugin{
|
||||
MetadataMode: isMetadataMode,
|
||||
},
|
||||
},
|
||||
4: plugin.PluginSet{
|
||||
|
@ -142,10 +118,6 @@ func NewPluginClient(ctx context.Context, sys pluginutil.RunnerUtil, pluginRunne
|
|||
// We should have a logical backend type now. This feels like a normal interface
|
||||
// implementation but is in fact over an RPC connection.
|
||||
switch raw.(type) {
|
||||
case *backendPluginClient:
|
||||
logger.Warn("plugin is using deprecated netRPC transport, recompile plugin to upgrade to gRPC", "plugin", pluginRunner.Name)
|
||||
backend = raw.(*backendPluginClient)
|
||||
transport = "netRPC"
|
||||
case *backendGRPCPluginClient:
|
||||
backend = raw.(*backendGRPCPluginClient)
|
||||
transport = "gRPC"
|
||||
|
|
|
@ -39,12 +39,14 @@ func Serve(opts *ServeOpts) error {
|
|||
|
||||
// pluginMap is the map of plugins we can dispense.
|
||||
pluginSets := map[int]plugin.PluginSet{
|
||||
// Version 3 used to supports both protocols. We want to keep it around
|
||||
// since it's possible old plugins built against this version will still
|
||||
// work with gRPC. There is currently no difference between version 3
|
||||
// and version 4.
|
||||
3: plugin.PluginSet{
|
||||
"backend": &BackendPlugin{
|
||||
GRPCBackendPlugin: &GRPCBackendPlugin{
|
||||
Factory: opts.BackendFactoryFunc,
|
||||
Logger: logger,
|
||||
},
|
||||
"backend": &GRPCBackendPlugin{
|
||||
Factory: opts.BackendFactoryFunc,
|
||||
Logger: logger,
|
||||
},
|
||||
},
|
||||
4: plugin.PluginSet{
|
||||
|
@ -74,13 +76,6 @@ func Serve(opts *ServeOpts) error {
|
|||
},
|
||||
}
|
||||
|
||||
// If we do not have gRPC support fallback to version 3
|
||||
// Remove this block in 0.13
|
||||
if !pluginutil.GRPCSupport() {
|
||||
serveOpts.GRPCServer = nil
|
||||
delete(pluginSets, 4)
|
||||
}
|
||||
|
||||
plugin.Serve(serveOpts)
|
||||
|
||||
return nil
|
||||
|
|
|
@ -1,139 +0,0 @@
|
|||
package plugin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/rpc"
|
||||
|
||||
"github.com/hashicorp/vault/logical"
|
||||
)
|
||||
|
||||
// StorageClient is an implementation of logical.Storage that communicates
|
||||
// over RPC.
|
||||
type StorageClient struct {
|
||||
client *rpc.Client
|
||||
}
|
||||
|
||||
func (s *StorageClient) List(_ context.Context, prefix string) ([]string, error) {
|
||||
var reply StorageListReply
|
||||
err := s.client.Call("Plugin.List", prefix, &reply)
|
||||
if err != nil {
|
||||
return reply.Keys, err
|
||||
}
|
||||
if reply.Error != nil {
|
||||
return reply.Keys, reply.Error
|
||||
}
|
||||
return reply.Keys, nil
|
||||
}
|
||||
|
||||
func (s *StorageClient) Get(_ context.Context, key string) (*logical.StorageEntry, error) {
|
||||
var reply StorageGetReply
|
||||
err := s.client.Call("Plugin.Get", key, &reply)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if reply.Error != nil {
|
||||
return nil, reply.Error
|
||||
}
|
||||
return reply.StorageEntry, nil
|
||||
}
|
||||
|
||||
func (s *StorageClient) Put(_ context.Context, entry *logical.StorageEntry) error {
|
||||
var reply StoragePutReply
|
||||
err := s.client.Call("Plugin.Put", entry, &reply)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if reply.Error != nil {
|
||||
return reply.Error
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *StorageClient) Delete(_ context.Context, key string) error {
|
||||
var reply StorageDeleteReply
|
||||
err := s.client.Call("Plugin.Delete", key, &reply)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if reply.Error != nil {
|
||||
return reply.Error
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// StorageServer is a net/rpc compatible structure for serving
|
||||
type StorageServer struct {
|
||||
impl logical.Storage
|
||||
}
|
||||
|
||||
func (s *StorageServer) List(prefix string, reply *StorageListReply) error {
|
||||
keys, err := s.impl.List(context.Background(), prefix)
|
||||
*reply = StorageListReply{
|
||||
Keys: keys,
|
||||
Error: wrapError(err),
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *StorageServer) Get(key string, reply *StorageGetReply) error {
|
||||
storageEntry, err := s.impl.Get(context.Background(), key)
|
||||
*reply = StorageGetReply{
|
||||
StorageEntry: storageEntry,
|
||||
Error: wrapError(err),
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *StorageServer) Put(entry *logical.StorageEntry, reply *StoragePutReply) error {
|
||||
err := s.impl.Put(context.Background(), entry)
|
||||
*reply = StoragePutReply{
|
||||
Error: wrapError(err),
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *StorageServer) Delete(key string, reply *StorageDeleteReply) error {
|
||||
err := s.impl.Delete(context.Background(), key)
|
||||
*reply = StorageDeleteReply{
|
||||
Error: wrapError(err),
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type StorageListReply struct {
|
||||
Keys []string
|
||||
Error error
|
||||
}
|
||||
|
||||
type StorageGetReply struct {
|
||||
StorageEntry *logical.StorageEntry
|
||||
Error error
|
||||
}
|
||||
|
||||
type StoragePutReply struct {
|
||||
Error error
|
||||
}
|
||||
|
||||
type StorageDeleteReply struct {
|
||||
Error error
|
||||
}
|
||||
|
||||
// NOOPStorage is used to deny access to the storage interface while running a
|
||||
// backend plugin in metadata mode.
|
||||
type NOOPStorage struct{}
|
||||
|
||||
func (s *NOOPStorage) List(_ context.Context, prefix string) ([]string, error) {
|
||||
return []string{}, nil
|
||||
}
|
||||
|
||||
func (s *NOOPStorage) Get(_ context.Context, key string) (*logical.StorageEntry, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *NOOPStorage) Put(_ context.Context, entry *logical.StorageEntry) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *NOOPStorage) Delete(_ context.Context, key string) error {
|
||||
return nil
|
||||
}
|
|
@ -11,22 +11,7 @@ import (
|
|||
)
|
||||
|
||||
func TestStorage_impl(t *testing.T) {
|
||||
var _ logical.Storage = new(StorageClient)
|
||||
}
|
||||
|
||||
func TestStorage_RPC(t *testing.T) {
|
||||
client, server := plugin.TestRPCConn(t)
|
||||
defer client.Close()
|
||||
|
||||
storage := &logical.InmemStorage{}
|
||||
|
||||
server.RegisterName("Plugin", &StorageServer{
|
||||
impl: storage,
|
||||
})
|
||||
|
||||
testStorage := &StorageClient{client: client}
|
||||
|
||||
logical.TestStorage(t, testStorage)
|
||||
var _ logical.Storage = new(GRPCStorageClient)
|
||||
}
|
||||
|
||||
func TestStorage_GRPC(t *testing.T) {
|
||||
|
|
|
@ -1,351 +0,0 @@
|
|||
package plugin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/rpc"
|
||||
"time"
|
||||
|
||||
"fmt"
|
||||
|
||||
"github.com/hashicorp/vault/helper/consts"
|
||||
"github.com/hashicorp/vault/helper/license"
|
||||
"github.com/hashicorp/vault/helper/pluginutil"
|
||||
"github.com/hashicorp/vault/helper/wrapping"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
)
|
||||
|
||||
type SystemViewClient struct {
|
||||
client *rpc.Client
|
||||
}
|
||||
|
||||
func (s *SystemViewClient) DefaultLeaseTTL() time.Duration {
|
||||
var reply DefaultLeaseTTLReply
|
||||
err := s.client.Call("Plugin.DefaultLeaseTTL", new(interface{}), &reply)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
return reply.DefaultLeaseTTL
|
||||
}
|
||||
|
||||
func (s *SystemViewClient) MaxLeaseTTL() time.Duration {
|
||||
var reply MaxLeaseTTLReply
|
||||
err := s.client.Call("Plugin.MaxLeaseTTL", new(interface{}), &reply)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
return reply.MaxLeaseTTL
|
||||
}
|
||||
|
||||
func (s *SystemViewClient) SudoPrivilege(ctx context.Context, path string, token string) bool {
|
||||
var reply SudoPrivilegeReply
|
||||
args := &SudoPrivilegeArgs{
|
||||
Path: path,
|
||||
Token: token,
|
||||
}
|
||||
|
||||
err := s.client.Call("Plugin.SudoPrivilege", args, &reply)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return reply.Sudo
|
||||
}
|
||||
|
||||
func (s *SystemViewClient) Tainted() bool {
|
||||
var reply TaintedReply
|
||||
|
||||
err := s.client.Call("Plugin.Tainted", new(interface{}), &reply)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return reply.Tainted
|
||||
}
|
||||
|
||||
func (s *SystemViewClient) CachingDisabled() bool {
|
||||
var reply CachingDisabledReply
|
||||
|
||||
err := s.client.Call("Plugin.CachingDisabled", new(interface{}), &reply)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return reply.CachingDisabled
|
||||
}
|
||||
|
||||
func (s *SystemViewClient) ReplicationState() consts.ReplicationState {
|
||||
var reply ReplicationStateReply
|
||||
|
||||
err := s.client.Call("Plugin.ReplicationState", new(interface{}), &reply)
|
||||
if err != nil {
|
||||
return consts.ReplicationUnknown
|
||||
}
|
||||
|
||||
return reply.ReplicationState
|
||||
}
|
||||
|
||||
func (s *SystemViewClient) ResponseWrapData(ctx context.Context, data map[string]interface{}, ttl time.Duration, jwt bool) (*wrapping.ResponseWrapInfo, error) {
|
||||
var reply ResponseWrapDataReply
|
||||
// Do not allow JWTs to be returned
|
||||
args := &ResponseWrapDataArgs{
|
||||
Data: data,
|
||||
TTL: ttl,
|
||||
JWT: false,
|
||||
}
|
||||
|
||||
err := s.client.Call("Plugin.ResponseWrapData", args, &reply)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if reply.Error != nil {
|
||||
return nil, reply.Error
|
||||
}
|
||||
|
||||
return reply.ResponseWrapInfo, nil
|
||||
}
|
||||
|
||||
func (s *SystemViewClient) LookupPlugin(_ context.Context, _ string, _ consts.PluginType) (*pluginutil.PluginRunner, error) {
|
||||
return nil, fmt.Errorf("cannot call LookupPlugin from a plugin backend")
|
||||
}
|
||||
|
||||
func (s *SystemViewClient) HasFeature(feature license.Features) bool {
|
||||
// Not implemented
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *SystemViewClient) MlockEnabled() bool {
|
||||
var reply MlockEnabledReply
|
||||
err := s.client.Call("Plugin.MlockEnabled", new(interface{}), &reply)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return reply.MlockEnabled
|
||||
}
|
||||
|
||||
func (s *SystemViewClient) LocalMount() bool {
|
||||
var reply LocalMountReply
|
||||
err := s.client.Call("Plugin.LocalMount", new(interface{}), &reply)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return reply.Local
|
||||
}
|
||||
|
||||
func (s *SystemViewClient) EntityInfo(entityID string) (*logical.Entity, error) {
|
||||
var reply EntityInfoReply
|
||||
args := &EntityInfoArgs{
|
||||
EntityID: entityID,
|
||||
}
|
||||
|
||||
err := s.client.Call("Plugin.EntityInfo", args, &reply)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if reply.Error != nil {
|
||||
return nil, reply.Error
|
||||
}
|
||||
|
||||
return reply.Entity, nil
|
||||
}
|
||||
|
||||
func (s *SystemViewClient) PluginEnv(_ context.Context) (*logical.PluginEnvironment, error) {
|
||||
var reply PluginEnvReply
|
||||
|
||||
err := s.client.Call("Plugin.PluginEnv", new(interface{}), &reply)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if reply.Error != nil {
|
||||
return nil, reply.Error
|
||||
}
|
||||
|
||||
return reply.PluginEnvironment, nil
|
||||
}
|
||||
|
||||
type SystemViewServer struct {
|
||||
impl logical.SystemView
|
||||
}
|
||||
|
||||
func (s *SystemViewServer) DefaultLeaseTTL(_ interface{}, reply *DefaultLeaseTTLReply) error {
|
||||
ttl := s.impl.DefaultLeaseTTL()
|
||||
*reply = DefaultLeaseTTLReply{
|
||||
DefaultLeaseTTL: ttl,
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SystemViewServer) MaxLeaseTTL(_ interface{}, reply *MaxLeaseTTLReply) error {
|
||||
ttl := s.impl.MaxLeaseTTL()
|
||||
*reply = MaxLeaseTTLReply{
|
||||
MaxLeaseTTL: ttl,
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SystemViewServer) SudoPrivilege(args *SudoPrivilegeArgs, reply *SudoPrivilegeReply) error {
|
||||
sudo := s.impl.SudoPrivilege(context.Background(), args.Path, args.Token)
|
||||
*reply = SudoPrivilegeReply{
|
||||
Sudo: sudo,
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SystemViewServer) Tainted(_ interface{}, reply *TaintedReply) error {
|
||||
tainted := s.impl.Tainted()
|
||||
*reply = TaintedReply{
|
||||
Tainted: tainted,
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SystemViewServer) CachingDisabled(_ interface{}, reply *CachingDisabledReply) error {
|
||||
cachingDisabled := s.impl.CachingDisabled()
|
||||
*reply = CachingDisabledReply{
|
||||
CachingDisabled: cachingDisabled,
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SystemViewServer) ReplicationState(_ interface{}, reply *ReplicationStateReply) error {
|
||||
replicationState := s.impl.ReplicationState()
|
||||
*reply = ReplicationStateReply{
|
||||
ReplicationState: replicationState,
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SystemViewServer) ResponseWrapData(args *ResponseWrapDataArgs, reply *ResponseWrapDataReply) error {
|
||||
// Do not allow JWTs to be returned
|
||||
info, err := s.impl.ResponseWrapData(context.Background(), args.Data, args.TTL, false)
|
||||
if err != nil {
|
||||
*reply = ResponseWrapDataReply{
|
||||
Error: wrapError(err),
|
||||
}
|
||||
return nil
|
||||
}
|
||||
*reply = ResponseWrapDataReply{
|
||||
ResponseWrapInfo: info,
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SystemViewServer) MlockEnabled(_ interface{}, reply *MlockEnabledReply) error {
|
||||
enabled := s.impl.MlockEnabled()
|
||||
*reply = MlockEnabledReply{
|
||||
MlockEnabled: enabled,
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SystemViewServer) LocalMount(_ interface{}, reply *LocalMountReply) error {
|
||||
local := s.impl.LocalMount()
|
||||
*reply = LocalMountReply{
|
||||
Local: local,
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SystemViewServer) EntityInfo(args *EntityInfoArgs, reply *EntityInfoReply) error {
|
||||
entity, err := s.impl.EntityInfo(args.EntityID)
|
||||
if err != nil {
|
||||
*reply = EntityInfoReply{
|
||||
Error: wrapError(err),
|
||||
}
|
||||
return nil
|
||||
}
|
||||
*reply = EntityInfoReply{
|
||||
Entity: entity,
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SystemViewServer) PluginEnv(_ interface{}, reply *PluginEnvReply) error {
|
||||
pluginEnv, err := s.impl.PluginEnv(context.Background())
|
||||
if err != nil {
|
||||
*reply = PluginEnvReply{
|
||||
Error: wrapError(err),
|
||||
}
|
||||
return nil
|
||||
}
|
||||
*reply = PluginEnvReply{
|
||||
PluginEnvironment: pluginEnv,
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type DefaultLeaseTTLReply struct {
|
||||
DefaultLeaseTTL time.Duration
|
||||
}
|
||||
|
||||
type MaxLeaseTTLReply struct {
|
||||
MaxLeaseTTL time.Duration
|
||||
}
|
||||
|
||||
type SudoPrivilegeArgs struct {
|
||||
Path string
|
||||
Token string
|
||||
}
|
||||
|
||||
type SudoPrivilegeReply struct {
|
||||
Sudo bool
|
||||
}
|
||||
|
||||
type TaintedReply struct {
|
||||
Tainted bool
|
||||
}
|
||||
|
||||
type CachingDisabledReply struct {
|
||||
CachingDisabled bool
|
||||
}
|
||||
|
||||
type ReplicationStateReply struct {
|
||||
ReplicationState consts.ReplicationState
|
||||
}
|
||||
|
||||
type ResponseWrapDataArgs struct {
|
||||
Data map[string]interface{}
|
||||
TTL time.Duration
|
||||
JWT bool
|
||||
}
|
||||
|
||||
type ResponseWrapDataReply struct {
|
||||
ResponseWrapInfo *wrapping.ResponseWrapInfo
|
||||
Error error
|
||||
}
|
||||
|
||||
type MlockEnabledReply struct {
|
||||
MlockEnabled bool
|
||||
}
|
||||
|
||||
type LocalMountReply struct {
|
||||
Local bool
|
||||
}
|
||||
|
||||
type EntityInfoArgs struct {
|
||||
EntityID string
|
||||
}
|
||||
|
||||
type EntityInfoReply struct {
|
||||
Entity *logical.Entity
|
||||
Error error
|
||||
}
|
||||
|
||||
type PluginEnvReply struct {
|
||||
PluginEnvironment *logical.PluginEnvironment
|
||||
Error error
|
||||
}
|
|
@ -1,231 +0,0 @@
|
|||
package plugin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"reflect"
|
||||
|
||||
plugin "github.com/hashicorp/go-plugin"
|
||||
"github.com/hashicorp/vault/helper/consts"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
)
|
||||
|
||||
func Test_impl(t *testing.T) {
|
||||
var _ logical.SystemView = new(SystemViewClient)
|
||||
}
|
||||
|
||||
func TestSystem_defaultLeaseTTL(t *testing.T) {
|
||||
client, server := plugin.TestRPCConn(t)
|
||||
defer client.Close()
|
||||
|
||||
sys := logical.TestSystemView()
|
||||
|
||||
server.RegisterName("Plugin", &SystemViewServer{
|
||||
impl: sys,
|
||||
})
|
||||
|
||||
testSystemView := &SystemViewClient{client: client}
|
||||
|
||||
expected := sys.DefaultLeaseTTL()
|
||||
actual := testSystemView.DefaultLeaseTTL()
|
||||
if !reflect.DeepEqual(expected, actual) {
|
||||
t.Fatalf("expected: %v, got: %v", expected, actual)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSystem_maxLeaseTTL(t *testing.T) {
|
||||
client, server := plugin.TestRPCConn(t)
|
||||
defer client.Close()
|
||||
|
||||
sys := logical.TestSystemView()
|
||||
|
||||
server.RegisterName("Plugin", &SystemViewServer{
|
||||
impl: sys,
|
||||
})
|
||||
|
||||
testSystemView := &SystemViewClient{client: client}
|
||||
|
||||
expected := sys.MaxLeaseTTL()
|
||||
actual := testSystemView.MaxLeaseTTL()
|
||||
if !reflect.DeepEqual(expected, actual) {
|
||||
t.Fatalf("expected: %v, got: %v", expected, actual)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSystem_sudoPrivilege(t *testing.T) {
|
||||
client, server := plugin.TestRPCConn(t)
|
||||
defer client.Close()
|
||||
|
||||
sys := logical.TestSystemView()
|
||||
sys.SudoPrivilegeVal = true
|
||||
|
||||
server.RegisterName("Plugin", &SystemViewServer{
|
||||
impl: sys,
|
||||
})
|
||||
|
||||
testSystemView := &SystemViewClient{client: client}
|
||||
ctx := context.Background()
|
||||
|
||||
expected := sys.SudoPrivilege(ctx, "foo", "bar")
|
||||
actual := testSystemView.SudoPrivilege(ctx, "foo", "bar")
|
||||
if !reflect.DeepEqual(expected, actual) {
|
||||
t.Fatalf("expected: %v, got: %v", expected, actual)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSystem_tainted(t *testing.T) {
|
||||
client, server := plugin.TestRPCConn(t)
|
||||
defer client.Close()
|
||||
|
||||
sys := logical.TestSystemView()
|
||||
sys.TaintedVal = true
|
||||
|
||||
server.RegisterName("Plugin", &SystemViewServer{
|
||||
impl: sys,
|
||||
})
|
||||
|
||||
testSystemView := &SystemViewClient{client: client}
|
||||
|
||||
expected := sys.Tainted()
|
||||
actual := testSystemView.Tainted()
|
||||
if !reflect.DeepEqual(expected, actual) {
|
||||
t.Fatalf("expected: %v, got: %v", expected, actual)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSystem_cachingDisabled(t *testing.T) {
|
||||
client, server := plugin.TestRPCConn(t)
|
||||
defer client.Close()
|
||||
|
||||
sys := logical.TestSystemView()
|
||||
sys.CachingDisabledVal = true
|
||||
|
||||
server.RegisterName("Plugin", &SystemViewServer{
|
||||
impl: sys,
|
||||
})
|
||||
|
||||
testSystemView := &SystemViewClient{client: client}
|
||||
|
||||
expected := sys.CachingDisabled()
|
||||
actual := testSystemView.CachingDisabled()
|
||||
if !reflect.DeepEqual(expected, actual) {
|
||||
t.Fatalf("expected: %v, got: %v", expected, actual)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSystem_replicationState(t *testing.T) {
|
||||
client, server := plugin.TestRPCConn(t)
|
||||
defer client.Close()
|
||||
|
||||
sys := logical.TestSystemView()
|
||||
sys.ReplicationStateVal = consts.ReplicationPerformancePrimary
|
||||
|
||||
server.RegisterName("Plugin", &SystemViewServer{
|
||||
impl: sys,
|
||||
})
|
||||
|
||||
testSystemView := &SystemViewClient{client: client}
|
||||
|
||||
expected := sys.ReplicationState()
|
||||
actual := testSystemView.ReplicationState()
|
||||
if !reflect.DeepEqual(expected, actual) {
|
||||
t.Fatalf("expected: %v, got: %v", expected, actual)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSystem_responseWrapData(t *testing.T) {
|
||||
t.SkipNow()
|
||||
}
|
||||
|
||||
func TestSystem_lookupPlugin(t *testing.T) {
|
||||
client, server := plugin.TestRPCConn(t)
|
||||
defer client.Close()
|
||||
|
||||
sys := logical.TestSystemView()
|
||||
|
||||
server.RegisterName("Plugin", &SystemViewServer{
|
||||
impl: sys,
|
||||
})
|
||||
|
||||
testSystemView := &SystemViewClient{client: client}
|
||||
|
||||
if _, err := testSystemView.LookupPlugin(context.Background(), "foo", consts.PluginTypeDatabase); err == nil {
|
||||
t.Fatal("LookPlugin(): expected error on due to unsupported call from plugin")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSystem_mlockEnabled(t *testing.T) {
|
||||
client, server := plugin.TestRPCConn(t)
|
||||
defer client.Close()
|
||||
|
||||
sys := logical.TestSystemView()
|
||||
sys.EnableMlock = true
|
||||
|
||||
server.RegisterName("Plugin", &SystemViewServer{
|
||||
impl: sys,
|
||||
})
|
||||
|
||||
testSystemView := &SystemViewClient{client: client}
|
||||
|
||||
expected := sys.MlockEnabled()
|
||||
actual := testSystemView.MlockEnabled()
|
||||
if !reflect.DeepEqual(expected, actual) {
|
||||
t.Fatalf("expected: %v, got: %v", expected, actual)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSystem_entityInfo(t *testing.T) {
|
||||
client, server := plugin.TestRPCConn(t)
|
||||
defer client.Close()
|
||||
|
||||
sys := logical.TestSystemView()
|
||||
sys.EntityVal = &logical.Entity{
|
||||
ID: "test",
|
||||
Name: "name",
|
||||
}
|
||||
|
||||
server.RegisterName("Plugin", &SystemViewServer{
|
||||
impl: sys,
|
||||
})
|
||||
|
||||
testSystemView := &SystemViewClient{client: client}
|
||||
|
||||
actual, err := testSystemView.EntityInfo("")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !reflect.DeepEqual(sys.EntityVal, actual) {
|
||||
t.Fatalf("expected: %v, got: %v", sys.EntityVal, actual)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSystem_pluginEnv(t *testing.T) {
|
||||
client, server := plugin.TestRPCConn(t)
|
||||
defer client.Close()
|
||||
|
||||
sys := logical.TestSystemView()
|
||||
sys.PluginEnvironment = &logical.PluginEnvironment{
|
||||
VaultVersion: "0.10.42",
|
||||
}
|
||||
|
||||
server.RegisterName("Plugin", &SystemViewServer{
|
||||
impl: sys,
|
||||
})
|
||||
|
||||
testSystemView := &SystemViewClient{client: client}
|
||||
|
||||
expected, err := sys.PluginEnv(context.Background())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
actual, err := testSystemView.PluginEnv(context.Background())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(expected, actual) {
|
||||
t.Fatalf("expected: %v, got: %v", expected, actual)
|
||||
}
|
||||
}
|
|
@ -35,9 +35,9 @@ export default DS.RESTAdapter.extend({
|
|||
let headers = {};
|
||||
if (token && !options.unauthenticated) {
|
||||
headers['X-Vault-Token'] = token;
|
||||
if (options.wrapTTL) {
|
||||
headers['X-Vault-Wrap-TTL'] = options.wrapTTL;
|
||||
}
|
||||
}
|
||||
if (options.wrapTTL) {
|
||||
headers['X-Vault-Wrap-TTL'] = options.wrapTTL;
|
||||
}
|
||||
let namespace =
|
||||
typeof options.namespace === 'undefined' ? this.get('namespaceService.path') : options.namespace;
|
||||
|
|
|
@ -1,2 +1,3 @@
|
|||
import AuthConfig from './_base';
|
||||
|
||||
export default AuthConfig.extend();
|
||||
|
|
|
@ -1,2 +1,3 @@
|
|||
import AuthConfig from './_base';
|
||||
|
||||
export default AuthConfig.extend();
|
||||
|
|
|
@ -1,2 +1,3 @@
|
|||
import AuthConfig from './_base';
|
||||
|
||||
export default AuthConfig.extend();
|
||||
|
|
|
@ -1,2 +1,3 @@
|
|||
import AuthConfig from './_base';
|
||||
|
||||
export default AuthConfig.extend();
|
||||
|
|
|
@ -1,2 +1,3 @@
|
|||
import AuthConfig from './_base';
|
||||
|
||||
export default AuthConfig.extend();
|
||||
|
|
|
@ -1,2 +1,3 @@
|
|||
import AuthConfig from './_base';
|
||||
|
||||
export default AuthConfig.extend();
|
||||
|
|
|
@ -1,2 +1,3 @@
|
|||
import AuthConfig from './_base';
|
||||
|
||||
export default AuthConfig.extend();
|
||||
|
|
|
@ -56,4 +56,8 @@ export default ApplicationAdapter.extend({
|
|||
urlForDeleteRecord(id, modelName, snapshot) {
|
||||
return this.url(snapshot.id);
|
||||
},
|
||||
|
||||
exchangeOIDC(path, state, code) {
|
||||
return this.ajax(`/v1/auth/${path}/oidc/callback`, 'GET', { data: { state, code } });
|
||||
},
|
||||
});
|
||||
|
|
|
@ -109,7 +109,7 @@ export default ApplicationAdapter.extend({
|
|||
},
|
||||
|
||||
authenticate({ backend, data }) {
|
||||
const { token, password, username, path } = data;
|
||||
const { role, jwt, token, password, username, path } = data;
|
||||
const url = this.urlForAuth(backend, username, path);
|
||||
const verb = backend === 'token' ? 'GET' : 'POST';
|
||||
let options = {
|
||||
|
@ -119,6 +119,8 @@ export default ApplicationAdapter.extend({
|
|||
options.headers = {
|
||||
'X-Vault-Token': token,
|
||||
};
|
||||
} else if (backend === 'jwt') {
|
||||
options.data = { role, jwt };
|
||||
} else {
|
||||
options.data = token ? { token, password } : { password };
|
||||
}
|
||||
|
@ -139,6 +141,7 @@ export default ApplicationAdapter.extend({
|
|||
const authBackend = type.toLowerCase();
|
||||
const authURLs = {
|
||||
github: 'login',
|
||||
jwt: 'login',
|
||||
userpass: `login/${encodeURIComponent(username)}`,
|
||||
ldap: `login/${encodeURIComponent(username)}`,
|
||||
okta: `login/${encodeURIComponent(username)}`,
|
||||
|
|
|
@ -7,4 +7,8 @@ export default Adapter.extend({
|
|||
}
|
||||
return `/v1/${role.backend}/sign/${role.name}`;
|
||||
},
|
||||
|
||||
pathForType() {
|
||||
return 'sign';
|
||||
},
|
||||
});
|
||||
|
|
|
@ -13,7 +13,6 @@ export default Adapter.extend({
|
|||
}
|
||||
return url;
|
||||
},
|
||||
|
||||
optionsForQuery(id) {
|
||||
let data = {};
|
||||
if (!id) {
|
||||
|
|
32
ui/app/adapters/role-jwt.js
Normal file
32
ui/app/adapters/role-jwt.js
Normal file
|
@ -0,0 +1,32 @@
|
|||
import ApplicationAdapter from './application';
|
||||
import { inject as service } from '@ember/service';
|
||||
import { get } from '@ember/object';
|
||||
|
||||
export default ApplicationAdapter.extend({
|
||||
router: service(),
|
||||
|
||||
findRecord(store, type, id, snapshot) {
|
||||
let [path, role] = JSON.parse(id);
|
||||
|
||||
let namespace = get(snapshot, 'adapterOptions.namespace');
|
||||
let url = `/v1/auth/${path}/oidc/auth_url`;
|
||||
let redirect_uri = `${window.location.origin}${this.router.urlFor('vault.cluster.oidc-callback', {
|
||||
auth_path: path,
|
||||
})}`;
|
||||
|
||||
if (namespace) {
|
||||
redirect_uri = `${window.location.origin}${this.router.urlFor(
|
||||
'vault.cluster.oidc-callback',
|
||||
{ auth_path: path },
|
||||
{ queryParams: { namespace } }
|
||||
)}`;
|
||||
}
|
||||
|
||||
return this.ajax(url, 'POST', {
|
||||
data: {
|
||||
role,
|
||||
redirect_uri,
|
||||
},
|
||||
});
|
||||
},
|
||||
});
|
|
@ -7,16 +7,20 @@ export default ApplicationAdapter.extend({
|
|||
return path ? url + '/' + path : url;
|
||||
},
|
||||
|
||||
internalURL(path) {
|
||||
let url = `/${this.urlPrefix()}/internal/ui/mounts`;
|
||||
if (path) {
|
||||
url = `${url}/${path}`;
|
||||
}
|
||||
return url;
|
||||
},
|
||||
|
||||
pathForType() {
|
||||
return 'mounts';
|
||||
},
|
||||
|
||||
query(store, type, query) {
|
||||
let url = `/${this.urlPrefix()}/internal/ui/mounts`;
|
||||
if (query.path) {
|
||||
url = `${url}/${query.path}`;
|
||||
}
|
||||
return this.ajax(url, 'GET');
|
||||
return this.ajax(this.internalURL(query.path), 'GET');
|
||||
},
|
||||
|
||||
createRecord(store, type, snapshot) {
|
||||
|
|
|
@ -34,6 +34,10 @@ export default ApplicationAdapter.extend({
|
|||
return url;
|
||||
},
|
||||
|
||||
pathForType() {
|
||||
return 'mounts';
|
||||
},
|
||||
|
||||
optionsForQuery(id, action, wrapTTL) {
|
||||
let data = {};
|
||||
if (action === 'query') {
|
||||
|
|
|
@ -5,6 +5,7 @@ import { messageTypes } from 'vault/helpers/message-types';
|
|||
|
||||
export default Component.extend({
|
||||
type: null,
|
||||
message: null,
|
||||
|
||||
classNames: ['message-inline'],
|
||||
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show more
Loading…
Reference in a new issue