Merge pull request #6262 from hashicorp/1.1-beta

Merge 1.1 Beta
This commit is contained in:
Brian Kassouf 2019-02-19 12:20:58 -08:00 committed by GitHub
commit bb910f2bb9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
245 changed files with 11051 additions and 4554 deletions

2
.gitignore vendored
View file

@ -48,7 +48,9 @@ Vagrantfile
# Configs
*.hcl
!command/agent/config/test-fixtures/config.hcl
!command/agent/config/test-fixtures/config-cache.hcl
!command/agent/config/test-fixtures/config-embedded-type.hcl
!command/agent/config/test-fixtures/config-cache-embedded-type.hcl
.DS_Store
.idea

View file

@ -1,4 +1,4 @@
## Next
## 1.0.3 (February 12th, 2019)
CHANGES:

View file

@ -25,6 +25,7 @@ import (
"golang.org/x/time/rate"
)
const EnvVaultAgentAddress = "VAULT_AGENT_ADDR"
const EnvVaultAddress = "VAULT_ADDR"
const EnvVaultCACert = "VAULT_CACERT"
const EnvVaultCAPath = "VAULT_CAPATH"
@ -237,6 +238,10 @@ func (c *Config) ReadEnvironment() error {
if v := os.Getenv(EnvVaultAddress); v != "" {
envAddress = v
}
// Agent's address will take precedence over Vault's address
if v := os.Getenv(EnvVaultAgentAddress); v != "" {
envAddress = v
}
if v := os.Getenv(EnvVaultMaxRetries); v != "" {
maxRetries, err := strconv.ParseUint(v, 10, 32)
if err != nil {
@ -366,11 +371,6 @@ func NewClient(c *Config) (*Client, error) {
c.modifyLock.Lock()
defer c.modifyLock.Unlock()
u, err := url.Parse(c.Address)
if err != nil {
return nil, err
}
if c.HttpClient == nil {
c.HttpClient = def.HttpClient
}
@ -378,6 +378,21 @@ func NewClient(c *Config) (*Client, error) {
c.HttpClient.Transport = def.HttpClient.Transport
}
if strings.HasPrefix(c.Address, "unix://") {
socket := strings.TrimPrefix(c.Address, "unix://")
transport := c.HttpClient.Transport.(*http.Transport)
transport.DialContext = func(context.Context, string, string) (net.Conn, error) {
return net.Dial("unix", socket)
}
// TODO: This shouldn't ideally be done. To be fixed post 1.1-beta.
c.Address = "http://unix"
}
u, err := url.Parse(c.Address)
if err != nil {
return nil, err
}
client := &Client{
addr: u,
config: c,

View file

@ -292,6 +292,7 @@ type SecretAuth struct {
TokenPolicies []string `json:"token_policies"`
IdentityPolicies []string `json:"identity_policies"`
Metadata map[string]string `json:"metadata"`
Orphan bool `json:"orphan"`
LeaseDuration int `json:"lease_duration"`
Renewable bool `json:"renewable"`

View file

@ -25,14 +25,17 @@ func pathConfig(b *backend) *framework.Path {
Description: `The API endpoint to use. Useful if you
are running GitHub Enterprise or an
API-compatible authentication server.`,
DisplayName: "Base URL",
},
"ttl": &framework.FieldSchema{
Type: framework.TypeString,
Description: `Duration after which authentication will be expired`,
DisplayName: "TTL",
},
"max_ttl": &framework.FieldSchema{
Type: framework.TypeString,
Description: `Maximum duration after which authentication will be expired`,
DisplayName: "Max TTL",
},
},

View file

@ -25,26 +25,32 @@ func pathConfig(b *backend) *framework.Path {
"organization": &framework.FieldSchema{
Type: framework.TypeString,
Description: "(DEPRECATED) Okta organization to authenticate against. Use org_name instead.",
Deprecated: true,
},
"org_name": &framework.FieldSchema{
Type: framework.TypeString,
Description: "Name of the organization to be used in the Okta API.",
DisplayName: "Organization Name",
},
"token": &framework.FieldSchema{
Type: framework.TypeString,
Description: "(DEPRECATED) Okta admin API token. Use api_token instead.",
Deprecated: true,
},
"api_token": &framework.FieldSchema{
Type: framework.TypeString,
Description: "Okta API key.",
DisplayName: "API Token",
},
"base_url": &framework.FieldSchema{
Type: framework.TypeString,
Description: `The base domain to use for the Okta API. When not specified in the configuration, "okta.com" is used.`,
DisplayName: "Base URL",
},
"production": &framework.FieldSchema{
Type: framework.TypeBool,
Description: `(DEPRECATED) Use base_url.`,
Deprecated: true,
},
"ttl": &framework.FieldSchema{
Type: framework.TypeDurationSecond,
@ -57,6 +63,7 @@ func pathConfig(b *backend) *framework.Path {
"bypass_okta_mfa": &framework.FieldSchema{
Type: framework.TypeBool,
Description: `When set true, requests by Okta for a MFA check will be bypassed. This also disallows certain status checks on the account, such as whether the password is expired.`,
DisplayName: "Bypass Okta MFA",
},
},

View file

@ -15,6 +15,7 @@ func pathConfig(b *backend) *framework.Path {
"host": &framework.FieldSchema{
Type: framework.TypeString,
Description: "RADIUS server host",
DisplayName: "Host",
},
"port": &framework.FieldSchema{
@ -30,6 +31,7 @@ func pathConfig(b *backend) *framework.Path {
Type: framework.TypeString,
Default: "",
Description: "Comma-separated list of policies to grant upon successful RADIUS authentication of an unregistered user (default: empty)",
DisplayName: "Policies for unregistered users",
},
"dial_timeout": &framework.FieldSchema{
Type: framework.TypeDurationSecond,
@ -45,11 +47,13 @@ func pathConfig(b *backend) *framework.Path {
Type: framework.TypeInt,
Default: 10,
Description: "RADIUS NAS port field (default: 10)",
DisplayName: "NAS Port",
},
"nas_identifier": &framework.FieldSchema{
Type: framework.TypeString,
Default: "",
Description: "RADIUS NAS Identifier field (optional)",
DisplayName: "NAS Identifier",
},
},

View file

@ -36,6 +36,7 @@ func pathRoles(b *backend) *framework.Path {
"name": &framework.FieldSchema{
Type: framework.TypeString,
Description: "Name of the policy",
DisplayName: "Policy Name",
},
"credential_type": &framework.FieldSchema{
@ -46,11 +47,13 @@ func pathRoles(b *backend) *framework.Path {
"role_arns": &framework.FieldSchema{
Type: framework.TypeCommaStringSlice,
Description: "ARNs of AWS roles allowed to be assumed. Only valid when credential_type is " + assumedRoleCred,
DisplayName: "Role ARNs",
},
"policy_arns": &framework.FieldSchema{
Type: framework.TypeCommaStringSlice,
Description: "ARNs of AWS policies to attach to IAM users. Only valid when credential_type is " + iamUserCred,
DisplayName: "Policy ARNs",
},
"policy_document": &framework.FieldSchema{
@ -65,22 +68,26 @@ GetFederationToken API call, acting as a filter on permissions available.`,
"default_sts_ttl": &framework.FieldSchema{
Type: framework.TypeDurationSecond,
Description: fmt.Sprintf("Default TTL for %s and %s credential types when no TTL is explicitly requested with the credentials", assumedRoleCred, federationTokenCred),
DisplayName: "Default TTL",
},
"max_sts_ttl": &framework.FieldSchema{
Type: framework.TypeDurationSecond,
Description: fmt.Sprintf("Max allowed TTL for %s and %s credential types", assumedRoleCred, federationTokenCred),
DisplayName: "Max TTL",
},
"arn": &framework.FieldSchema{
Type: framework.TypeString,
Description: `Deprecated; use role_arns or policy_arns instead. ARN Reference to a managed policy
or IAM role to assume`,
Deprecated: true,
},
"policy": &framework.FieldSchema{
Type: framework.TypeString,
Description: "Deprecated; use policy_document instead. IAM policy document",
Deprecated: true,
},
},

View file

@ -35,11 +35,12 @@ func NewPluginClient(ctx context.Context, sys pluginutil.RunnerUtil, pluginRunne
// pluginSets is the map of plugins we can dispense.
pluginSets := map[int]plugin.PluginSet{
// Version 3 supports both protocols
// Version 3 used to support both protocols. We want to keep it around
// since it's possible old plugins built against this version will still
// work with gRPC. There is currently no difference between version 3
// and version 4.
3: plugin.PluginSet{
"database": &DatabasePlugin{
GRPCDatabasePlugin: new(GRPCDatabasePlugin),
},
"database": new(GRPCDatabasePlugin),
},
// Version 4 only supports gRPC
4: plugin.PluginSet{
@ -76,9 +77,6 @@ func NewPluginClient(ctx context.Context, sys pluginutil.RunnerUtil, pluginRunne
switch raw.(type) {
case *gRPCClient:
db = raw.(*gRPCClient)
case *databasePluginRPCClient:
logger.Warn("plugin is using deprecated netRPC transport, recompile plugin to upgrade to gRPC", "plugin", pluginRunner.Name)
db = raw.(*databasePluginRPCClient)
default:
return nil, errors.New("unsupported client type")
}

View file

@ -1,197 +0,0 @@
package dbplugin
import (
"context"
"encoding/json"
"fmt"
"net/rpc"
"strings"
"time"
)
// ---- RPC server domain ----

// databasePluginRPCServer implements an RPC version of Database and is run
// inside a plugin. It wraps an underlying implementation of Database.
type databasePluginRPCServer struct {
	// impl is the concrete Database that every RPC method delegates to.
	impl Database
}
// Type reports the wrapped Database's type over RPC.
func (ds *databasePluginRPCServer) Type(_ struct{}, resp *string) error {
	dbType, err := ds.impl.Type()
	*resp = dbType
	return err
}
// CreateUser forwards the CreateUser RPC to the wrapped Database and returns
// the generated username and password in resp.
func (ds *databasePluginRPCServer) CreateUser(args *CreateUserRequestRPC, resp *CreateUserResponse) error {
	var err error
	resp.Username, resp.Password, err = ds.impl.CreateUser(context.Background(), args.Statements, args.UsernameConfig, args.Expiration)
	return err
}
// RenewUser forwards the RenewUser RPC to the wrapped Database.
func (ds *databasePluginRPCServer) RenewUser(args *RenewUserRequestRPC, _ *struct{}) error {
	return ds.impl.RenewUser(context.Background(), args.Statements, args.Username, args.Expiration)
}
// RevokeUser forwards the RevokeUser RPC to the wrapped Database.
func (ds *databasePluginRPCServer) RevokeUser(args *RevokeUserRequestRPC, _ *struct{}) error {
	return ds.impl.RevokeUser(context.Background(), args.Statements, args.Username)
}
// RotateRootCredentials rotates the root credentials via the wrapped Database
// and returns the resulting connection configuration JSON-encoded in
// resp.Config.
func (ds *databasePluginRPCServer) RotateRootCredentials(args *RotateRootCredentialsRequestRPC, resp *RotateRootCredentialsResponse) error {
	config, err := ds.impl.RotateRootCredentials(context.Background(), args.Statements)
	if err != nil {
		return err
	}
	resp.Config, err = json.Marshal(config)
	return err
}
// Initialize is the deprecated predecessor of Init, kept so that old clients
// can still call it. It delegates to Init and discards the response.
func (ds *databasePluginRPCServer) Initialize(args *InitializeRequestRPC, _ *struct{}) error {
	return ds.Init(&InitRequestRPC{
		Config:           args.Config,
		VerifyConnection: args.VerifyConnection,
	}, &InitResponse{})
}
// Init initializes the wrapped Database and returns the (possibly updated)
// connection configuration JSON-encoded in resp.Config.
func (ds *databasePluginRPCServer) Init(args *InitRequestRPC, resp *InitResponse) error {
	config, err := ds.impl.Init(context.Background(), args.Config, args.VerifyConnection)
	if err != nil {
		return err
	}
	resp.Config, err = json.Marshal(config)
	return err
}
// Close closes the wrapped Database. It always reports success to the RPC
// caller; any result from the underlying Close is discarded.
func (ds *databasePluginRPCServer) Close(_ struct{}, _ *struct{}) error {
	ds.impl.Close()
	return nil
}
// ---- RPC client domain ----

// databasePluginRPCClient implements Database and is used on the client to
// make RPC calls to a plugin.
type databasePluginRPCClient struct {
	// client is the underlying net/rpc connection to the plugin process.
	client *rpc.Client
}
// Type queries the plugin for its database type and prefixes the result with
// "plugin-" so plugin-backed databases are distinguishable from builtins.
func (dr *databasePluginRPCClient) Type() (string, error) {
	dbType := ""
	err := dr.client.Call("Plugin.Type", struct{}{}, &dbType)
	return fmt.Sprintf("plugin-%s", dbType), err
}
// CreateUser asks the plugin to create a user over netRPC and returns the new
// username and password. The context argument is ignored by this transport.
func (dr *databasePluginRPCClient) CreateUser(_ context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) {
	req := CreateUserRequestRPC{
		Statements:     statements,
		UsernameConfig: usernameConfig,
		Expiration:     expiration,
	}

	var resp CreateUserResponse
	err = dr.client.Call("Plugin.CreateUser", req, &resp)

	return resp.Username, resp.Password, err
}
// RenewUser asks the plugin to renew the lease on the given user over netRPC.
// The context argument is ignored by this transport.
func (dr *databasePluginRPCClient) RenewUser(_ context.Context, statements Statements, username string, expiration time.Time) error {
	return dr.client.Call("Plugin.RenewUser", RenewUserRequestRPC{
		Statements: statements,
		Username:   username,
		Expiration: expiration,
	}, &struct{}{})
}
// RevokeUser asks the plugin to revoke the given user over netRPC. The
// context argument is ignored by this transport.
func (dr *databasePluginRPCClient) RevokeUser(_ context.Context, statements Statements, username string) error {
	return dr.client.Call("Plugin.RevokeUser", RevokeUserRequestRPC{
		Statements: statements,
		Username:   username,
	}, &struct{}{})
}
// RotateRootCredentials asks the plugin to rotate its root credentials over
// netRPC and returns the updated, JSON-decoded connection configuration.
func (dr *databasePluginRPCClient) RotateRootCredentials(_ context.Context, statements []string) (saveConf map[string]interface{}, err error) {
	req := RotateRootCredentialsRequestRPC{
		Statements: statements,
	}

	var resp RotateRootCredentialsResponse
	// Check the RPC error before decoding: the previous code overwrote it
	// with the Unmarshal result, which hid transport failures and decoded an
	// empty payload instead.
	if err = dr.client.Call("Plugin.RotateRootCredentials", req, &resp); err != nil {
		return nil, err
	}

	err = json.Unmarshal(resp.Config, &saveConf)
	return saveConf, err
}
// Initialize is the deprecated predecessor of Init; it delegates to Init and
// drops the returned configuration.
func (dr *databasePluginRPCClient) Initialize(_ context.Context, conf map[string]interface{}, verifyConnection bool) error {
	_, err := dr.Init(nil, conf, verifyConnection)
	return err
}
// Init initializes the plugin over netRPC and returns the (possibly updated)
// connection configuration. If the plugin predates the Init RPC, it falls
// back to the legacy Initialize call and returns the caller-supplied config
// unchanged. The context argument is ignored by this transport.
func (dr *databasePluginRPCClient) Init(_ context.Context, conf map[string]interface{}, verifyConnection bool) (saveConf map[string]interface{}, err error) {
	req := InitRequestRPC{
		Config:           conf,
		VerifyConnection: verifyConnection,
	}

	var resp InitResponse
	err = dr.client.Call("Plugin.Init", req, &resp)
	if err != nil {
		// Old plugins don't register Plugin.Init; net/rpc reports that only
		// through the error string, so detect it by substring and retry with
		// the deprecated Initialize method.
		if strings.Contains(err.Error(), "can't find method Plugin.Init") {
			req := InitializeRequestRPC{
				Config:           conf,
				VerifyConnection: verifyConnection,
			}

			err = dr.client.Call("Plugin.Initialize", req, &struct{}{})
			if err == nil {
				return conf, nil
			}
		}
		return nil, err
	}

	err = json.Unmarshal(resp.Config, &saveConf)
	return saveConf, err
}
// Close asks the plugin to close its database connection over netRPC.
func (dr *databasePluginRPCClient) Close() error {
	return dr.client.Call("Plugin.Close", struct{}{}, &struct{}{})
}
// ---- RPC Request Args Domain ----

// InitializeRequestRPC is the request payload for the deprecated
// Plugin.Initialize RPC.
type InitializeRequestRPC struct {
	Config           map[string]interface{}
	VerifyConnection bool
}

// InitRequestRPC is the request payload for the Plugin.Init RPC.
type InitRequestRPC struct {
	Config           map[string]interface{}
	VerifyConnection bool
}

// CreateUserRequestRPC is the request payload for the Plugin.CreateUser RPC.
type CreateUserRequestRPC struct {
	Statements     Statements
	UsernameConfig UsernameConfig
	Expiration     time.Time
}

// RenewUserRequestRPC is the request payload for the Plugin.RenewUser RPC.
type RenewUserRequestRPC struct {
	Statements Statements
	Username   string
	Expiration time.Time
}

// RevokeUserRequestRPC is the request payload for the Plugin.RevokeUser RPC.
type RevokeUserRequestRPC struct {
	Statements Statements
	Username   string
}

// RotateRootCredentialsRequestRPC is the request payload for the
// Plugin.RotateRootCredentials RPC.
type RotateRootCredentialsRequestRPC struct {
	Statements []string
}

View file

@ -3,7 +3,6 @@ package dbplugin
import (
"context"
"fmt"
"net/rpc"
"time"
"google.golang.org/grpc"
@ -72,8 +71,6 @@ func PluginFactory(ctx context.Context, pluginName string, sys pluginutil.LookRu
switch db.(*DatabasePluginClient).Database.(type) {
case *gRPCClient:
transport = "gRPC"
case *databasePluginRPCClient:
transport = "netRPC"
}
}
@ -110,17 +107,9 @@ var handshakeConfig = plugin.HandshakeConfig{
MagicCookieValue: "926a0820-aea2-be28-51d6-83cdf00e8edb",
}
var _ plugin.Plugin = &DatabasePlugin{}
var _ plugin.GRPCPlugin = &DatabasePlugin{}
var _ plugin.Plugin = &GRPCDatabasePlugin{}
var _ plugin.GRPCPlugin = &GRPCDatabasePlugin{}
// DatabasePlugin implements go-plugin's Plugin interface. It has methods for
// retrieving a server and a client instance of the plugin.
type DatabasePlugin struct {
*GRPCDatabasePlugin
}
// GRPCDatabasePlugin is the plugin.Plugin implementation that only supports GRPC
// transport
type GRPCDatabasePlugin struct {
@ -130,17 +119,6 @@ type GRPCDatabasePlugin struct {
plugin.NetRPCUnsupportedPlugin
}
// Server returns the netRPC server implementation of the database plugin,
// wrapping the plugin implementation in the error-sanitizing middleware.
func (d DatabasePlugin) Server(*plugin.MuxBroker) (interface{}, error) {
	impl := &DatabaseErrorSanitizerMiddleware{
		next: d.Impl,
	}
	return &databasePluginRPCServer{impl: impl}, nil
}
// Client returns the netRPC client implementation of the database plugin.
func (DatabasePlugin) Client(b *plugin.MuxBroker, c *rpc.Client) (interface{}, error) {
	return &databasePluginRPCClient{client: c}, nil
}
func (d GRPCDatabasePlugin) GRPCServer(_ *plugin.GRPCBroker, s *grpc.Server) error {
impl := &DatabaseErrorSanitizerMiddleware{
next: d.Impl,

View file

@ -8,7 +8,6 @@ import (
"time"
log "github.com/hashicorp/go-hclog"
plugin "github.com/hashicorp/go-plugin"
"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
"github.com/hashicorp/vault/helper/consts"
"github.com/hashicorp/vault/helper/namespace"
@ -96,7 +95,6 @@ func getCluster(t *testing.T) (*vault.TestCluster, logical.SystemView) {
sys := vault.TestDynamicSystemView(cores[0].Core)
vault.TestAddTestPlugin(t, cores[0].Core, "test-plugin", consts.PluginTypeDatabase, "TestPlugin_GRPC_Main", []string{}, "")
vault.TestAddTestPlugin(t, cores[0].Core, "test-plugin-netRPC", consts.PluginTypeDatabase, "TestPlugin_NetRPC_Main", []string{}, "")
return cluster, sys
}
@ -121,31 +119,6 @@ func TestPlugin_GRPC_Main(t *testing.T) {
plugins.Serve(plugin, apiClientMeta.GetTLSConfig())
}
// This is not an actual test case, it's a helper function that will be executed
// by the go-plugin client via an exec call.
func TestPlugin_NetRPC_Main(t *testing.T) {
	// Only run when re-invoked as a plugin subprocess (the go-plugin client
	// sets the unwrap-token env var); otherwise this is a no-op test.
	if os.Getenv(pluginutil.PluginUnwrapTokenEnv) == "" {
		return
	}

	// NOTE(review): unsetting the Vault version env var appears intended to
	// exercise the legacy (pre-gRPC) code path — confirm against pluginutil.
	os.Unsetenv(pluginutil.PluginVaultVersionEnv)

	p := &mockPlugin{
		users: make(map[string][]string),
	}

	args := []string{"--tls-skip-verify=true"}

	apiClientMeta := &pluginutil.APIClientMeta{}
	flags := apiClientMeta.FlagSet()
	flags.Parse(args)

	tlsProvider := pluginutil.VaultPluginTLSProvider(apiClientMeta.GetTLSConfig())
	serveConf := dbplugin.ServeConfig(p, tlsProvider)
	// Disable the gRPC server so the plugin serves over netRPC only.
	serveConf.GRPCServer = nil

	plugin.Serve(serveConf)
}
func TestPlugin_Init(t *testing.T) {
cluster, sys := getCluster(t)
defer cluster.Cleanup()
@ -284,143 +257,3 @@ func TestPlugin_RevokeUser(t *testing.T) {
t.Fatalf("err: %s", err)
}
}
// Test the code is still compatible with an old netRPC plugin
func TestPlugin_NetRPC_Init(t *testing.T) {
	cluster, sys := getCluster(t)
	defer cluster.Cleanup()

	dbRaw, err := dbplugin.PluginFactory(namespace.RootContext(nil), "test-plugin-netRPC", sys, log.NewNullLogger())
	if err != nil {
		t.Fatalf("err: %s", err)
	}

	connectionDetails := map[string]interface{}{
		"test": 1,
	}

	// Init should round-trip the connection details through the netRPC
	// transport without error.
	_, err = dbRaw.Init(context.Background(), connectionDetails, true)
	if err != nil {
		t.Fatalf("err: %s", err)
	}

	err = dbRaw.Close()
	if err != nil {
		t.Fatalf("err: %s", err)
	}
}
// TestPlugin_NetRPC_CreateUser verifies user creation through a netRPC
// plugin, including that duplicate creation is rejected.
func TestPlugin_NetRPC_CreateUser(t *testing.T) {
	cluster, sys := getCluster(t)
	defer cluster.Cleanup()

	db, err := dbplugin.PluginFactory(namespace.RootContext(nil), "test-plugin-netRPC", sys, log.NewNullLogger())
	if err != nil {
		t.Fatalf("err: %s", err)
	}
	defer db.Close()

	connectionDetails := map[string]interface{}{
		"test": 1,
	}

	_, err = db.Init(context.Background(), connectionDetails, true)
	if err != nil {
		t.Fatalf("err: %s", err)
	}

	usernameConf := dbplugin.UsernameConfig{
		DisplayName: "test",
		RoleName:    "test",
	}

	// The mock plugin echoes "test"/"test" for the generated credentials.
	us, pw, err := db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
	if err != nil {
		t.Fatalf("err: %s", err)
	}
	if us != "test" || pw != "test" {
		t.Fatal("expected username and password to be 'test'")
	}

	// try and save the same user again to verify it saved the first time, this
	// should return an error
	_, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
	if err == nil {
		t.Fatal("expected an error, user wasn't created correctly")
	}
}
// TestPlugin_NetRPC_RenewUser verifies that a user created through a netRPC
// plugin can be renewed without error.
func TestPlugin_NetRPC_RenewUser(t *testing.T) {
	cluster, sys := getCluster(t)
	defer cluster.Cleanup()

	db, err := dbplugin.PluginFactory(namespace.RootContext(nil), "test-plugin-netRPC", sys, log.NewNullLogger())
	if err != nil {
		t.Fatalf("err: %s", err)
	}
	defer db.Close()

	connectionDetails := map[string]interface{}{
		"test": 1,
	}

	_, err = db.Init(context.Background(), connectionDetails, true)
	if err != nil {
		t.Fatalf("err: %s", err)
	}

	usernameConf := dbplugin.UsernameConfig{
		DisplayName: "test",
		RoleName:    "test",
	}

	us, _, err := db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
	if err != nil {
		t.Fatalf("err: %s", err)
	}

	err = db.RenewUser(context.Background(), dbplugin.Statements{}, us, time.Now().Add(time.Minute))
	if err != nil {
		t.Fatalf("err: %s", err)
	}
}
// TestPlugin_NetRPC_RevokeUser verifies that revoking a user through a netRPC
// plugin actually removes it (re-creating the same user must then succeed).
func TestPlugin_NetRPC_RevokeUser(t *testing.T) {
	cluster, sys := getCluster(t)
	defer cluster.Cleanup()

	db, err := dbplugin.PluginFactory(namespace.RootContext(nil), "test-plugin-netRPC", sys, log.NewNullLogger())
	if err != nil {
		t.Fatalf("err: %s", err)
	}
	defer db.Close()

	connectionDetails := map[string]interface{}{
		"test": 1,
	}

	_, err = db.Init(context.Background(), connectionDetails, true)
	if err != nil {
		t.Fatalf("err: %s", err)
	}

	usernameConf := dbplugin.UsernameConfig{
		DisplayName: "test",
		RoleName:    "test",
	}

	us, _, err := db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
	if err != nil {
		t.Fatalf("err: %s", err)
	}

	// Test default revoke statements
	err = db.RevokeUser(context.Background(), dbplugin.Statements{}, us)
	if err != nil {
		t.Fatalf("err: %s", err)
	}

	// Try adding the same username back so we can verify it was removed
	_, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute))
	if err != nil {
		t.Fatalf("err: %s", err)
	}
}

View file

@ -4,7 +4,6 @@ import (
"crypto/tls"
plugin "github.com/hashicorp/go-plugin"
"github.com/hashicorp/vault/helper/pluginutil"
)
// Serve is called from within a plugin and wraps the provided
@ -17,11 +16,13 @@ func Serve(db Database, tlsProvider func() (*tls.Config, error)) {
func ServeConfig(db Database, tlsProvider func() (*tls.Config, error)) *plugin.ServeConfig {
// pluginSets is the map of plugins we can dispense.
pluginSets := map[int]plugin.PluginSet{
// Version 3 used to support both protocols. We want to keep it around
// since it's possible old plugins built against this version will still
// work with gRPC. There is currently no difference between version 3
// and version 4.
3: plugin.PluginSet{
"database": &DatabasePlugin{
GRPCDatabasePlugin: &GRPCDatabasePlugin{
Impl: db,
},
"database": &GRPCDatabasePlugin{
Impl: db,
},
},
4: plugin.PluginSet{
@ -38,12 +39,5 @@ func ServeConfig(db Database, tlsProvider func() (*tls.Config, error)) *plugin.S
GRPCServer: plugin.DefaultGRPCServer,
}
// If we do not have gRPC support fallback to version 3
// Remove this block in 0.13
if !pluginutil.GRPCSupport() {
conf.GRPCServer = nil
delete(conf.VersionedPlugins, 4)
}
return conf
}

View file

@ -11,6 +11,7 @@ func addIssueAndSignCommonFields(fields map[string]*framework.FieldSchema) map[s
Description: `If true, the Common Name will not be
included in DNS or Email Subject Alternate Names.
Defaults to false (CN is included).`,
DisplayName: "Exclude Common Name from Subject Alternative Names (SANs)",
}
fields["format"] = &framework.FieldSchema{
@ -20,6 +21,7 @@ Defaults to false (CN is included).`,
or "pem_bundle". If "pem_bundle" any private
key and issuing cert will be appended to the
certificate pem. Defaults to "pem".`,
AllowedValues: []interface{}{"pem", "der", "pem_bundle"},
}
fields["private_key_format"] = &framework.FieldSchema{
@ -31,24 +33,28 @@ parameter as either base64-encoded DER or PEM-encoded DER.
However, this can be set to "pkcs8" to have the returned
private key contain base64-encoded pkcs8 or PEM-encoded
pkcs8 instead. Defaults to "der".`,
AllowedValues: []interface{}{"", "der", "pem", "pkcs8"},
}
fields["ip_sans"] = &framework.FieldSchema{
Type: framework.TypeCommaStringSlice,
Description: `The requested IP SANs, if any, in a
comma-delimited list`,
DisplayName: "IP Subject Alternative Names (SANs)",
}
fields["uri_sans"] = &framework.FieldSchema{
Type: framework.TypeCommaStringSlice,
Description: `The requested URI SANs, if any, in a
comma-delimited list.`,
DisplayName: "URI Subject Alternative Names (SANs)",
}
fields["other_sans"] = &framework.FieldSchema{
Type: framework.TypeCommaStringSlice,
Description: `Requested other SANs, in an array with the format
<oid>;UTF8:<utf8 string value> for each entry.`,
DisplayName: "Other SANs",
}
return fields
@ -79,6 +85,7 @@ in the role, this may be an email address.`,
in a comma-delimited list. If email protection
is enabled for the role, this may contain
email addresses.`,
DisplayName: "DNS/Email Subject Alternative Names (SANs)",
}
fields["serial_number"] = &framework.FieldSchema{
@ -95,6 +102,7 @@ sets the expiration date. If not specified
the role default, backend default, or system
default TTL is used, in that order. Cannot
be larger than the role max TTL.`,
DisplayName: "TTL",
}
return fields
@ -110,6 +118,7 @@ func addCACommonFields(fields map[string]*framework.FieldSchema) map[string]*fra
Description: `The requested Subject Alternative Names, if any,
in a comma-delimited list. May contain both
DNS names and email addresses.`,
DisplayName: "DNS/Email Subject Alternative Names (SANs)",
}
fields["common_name"] = &framework.FieldSchema{
@ -131,12 +140,14 @@ be larger than the mount max TTL. Note:
this only has an effect when generating
a CA cert or signing a CA cert, not when
generating a CSR for an intermediate CA.`,
DisplayName: "TTL",
}
fields["ou"] = &framework.FieldSchema{
Type: framework.TypeCommaStringSlice,
Description: `If set, OU (OrganizationalUnit) will be set to
this value.`,
DisplayName: "OU (Organizational Unit)",
}
fields["organization"] = &framework.FieldSchema{
@ -155,24 +166,28 @@ this value.`,
Type: framework.TypeCommaStringSlice,
Description: `If set, Locality will be set to
this value.`,
DisplayName: "Locality/City",
}
fields["province"] = &framework.FieldSchema{
Type: framework.TypeCommaStringSlice,
Description: `If set, Province will be set to
this value.`,
DisplayName: "Province/State",
}
fields["street_address"] = &framework.FieldSchema{
Type: framework.TypeCommaStringSlice,
Description: `If set, Street Address will be set to
this value.`,
DisplayName: "Street Address",
}
fields["postal_code"] = &framework.FieldSchema{
Type: framework.TypeCommaStringSlice,
Description: `If set, Postal Code will be set to
this value.`,
DisplayName: "Postal Code",
}
fields["serial_number"] = &framework.FieldSchema{
@ -209,8 +224,8 @@ the key_type.`,
Default: "rsa",
Description: `The type of key to use; defaults to RSA. "rsa"
and "ec" are the only valid values.`,
AllowedValues: []interface{}{"rsa", "ec"},
}
return fields
}
@ -226,6 +241,7 @@ func addCAIssueFields(fields map[string]*framework.FieldSchema) map[string]*fram
fields["permitted_dns_domains"] = &framework.FieldSchema{
Type: framework.TypeCommaStringSlice,
Description: `Domains for which this certificate is allowed to sign or issue child certificates. If set, all DNS names (subject and alt) on child certs must be exact matches or subsets of the given domains (see https://tools.ietf.org/html/rfc5280#section-4.2.1.10).`,
DisplayName: "Permitted DNS Domains",
}
return fields

View file

@ -31,6 +31,11 @@ func pathRoles(b *backend) *framework.Path {
return &framework.Path{
Pattern: "roles/" + framework.GenericNameRegex("name"),
Fields: map[string]*framework.FieldSchema{
"backend": &framework.FieldSchema{
Type: framework.TypeString,
Description: "Backend Type",
},
"name": &framework.FieldSchema{
Type: framework.TypeString,
Description: "Name of the role",
@ -42,11 +47,13 @@ func pathRoles(b *backend) *framework.Path {
requested. The lease duration controls the expiration
of certificates issued by this backend. Defaults to
the value of max_ttl.`,
DisplayName: "TTL",
},
"max_ttl": &framework.FieldSchema{
Type: framework.TypeDurationSecond,
Description: "The maximum allowed lease duration",
DisplayName: "Max TTL",
},
"allow_localhost": &framework.FieldSchema{
@ -107,17 +114,20 @@ CN and SANs. Defaults to true.`,
Default: true,
Description: `If set, IP Subject Alternative Names are allowed.
Any valid IP is accepted.`,
DisplayName: "Allow IP Subject Alternative Names",
},
"allowed_uri_sans": &framework.FieldSchema{
Type: framework.TypeCommaStringSlice,
Description: `If set, an array of allowed URIs to put in the URI Subject Alternative Names.
Any valid URI is accepted, these values support globbing.`,
DisplayName: "Allowed URI Subject Alternative Names",
},
"allowed_other_sans": &framework.FieldSchema{
Type: framework.TypeCommaStringSlice,
Description: `If set, an array of allowed other names to put in SANs. These values support globbing and must be in the format <oid>;<type>:<value>. Currently only "utf8" is a valid type. All values, including globbing values, must use this syntax, with the exception being a single "*" which allows any OID and any value (but type must still be utf8).`,
DisplayName: "Allowed Other Subject Alternative Names",
},
"allowed_serial_numbers": &framework.FieldSchema{
@ -156,6 +166,7 @@ protection use. Defaults to false.`,
Default: "rsa",
Description: `The type of key to use; defaults to RSA. "rsa"
and "ec" are the only valid values.`,
AllowedValues: []interface{}{"rsa", "ec"},
},
"key_bits": &framework.FieldSchema{
@ -175,6 +186,7 @@ https://golang.org/pkg/crypto/x509/#KeyUsage
-- simply drop the "KeyUsage" part of the name.
To remove all key usages from being set, set
this value to an empty list.`,
DisplayValue: "DigitalSignature,KeyAgreement,KeyEncipherment",
},
"ext_key_usage": &framework.FieldSchema{
@ -185,11 +197,13 @@ https://golang.org/pkg/crypto/x509/#ExtKeyUsage
-- simply drop the "ExtKeyUsage" part of the name.
To remove all key usages from being set, set
this value to an empty list.`,
DisplayName: "Extended Key Usage",
},
"ext_key_usage_oids": &framework.FieldSchema{
Type: framework.TypeCommaStringSlice,
Description: `A comma-separated string or list of extended key usage oids.`,
DisplayName: "Extended Key Usage OIDs",
},
"use_csr_common_name": &framework.FieldSchema{
@ -199,6 +213,7 @@ this value to an empty list.`,
the common name in the CSR will be used. This
does *not* include any requested Subject Alternative
Names. Defaults to true.`,
DisplayName: "Use CSR Common Name",
},
"use_csr_sans": &framework.FieldSchema{
@ -207,12 +222,14 @@ Names. Defaults to true.`,
Description: `If set, when used with a signing profile,
the SANs in the CSR will be used. This does *not*
include the Common Name (cn). Defaults to true.`,
DisplayName: "Use CSR Subject Alternative Names",
},
"ou": &framework.FieldSchema{
Type: framework.TypeCommaStringSlice,
Description: `If set, OU (OrganizationalUnit) will be set to
this value in certificates issued by this role.`,
DisplayName: "Organizational Unit",
},
"organization": &framework.FieldSchema{
@ -231,12 +248,14 @@ this value in certificates issued by this role.`,
Type: framework.TypeCommaStringSlice,
Description: `If set, Locality will be set to
this value in certificates issued by this role.`,
DisplayName: "Locality/City",
},
"province": &framework.FieldSchema{
Type: framework.TypeCommaStringSlice,
Description: `If set, Province will be set to
this value in certificates issued by this role.`,
DisplayName: "Province/State",
},
"street_address": &framework.FieldSchema{
@ -263,6 +282,7 @@ to the CRL. When large number of certificates are generated with long
lifetimes, it is recommended that lease generation be disabled, as large amount of
leases adversely affect the startup time of Vault.`,
},
"no_store": &framework.FieldSchema{
Type: framework.TypeBool,
Description: `
@ -273,18 +293,23 @@ or revoked, so this option is recommended only for certificates that are
non-sensitive, or extremely short-lived. This option implies a value of "false"
for "generate_lease".`,
},
"require_cn": &framework.FieldSchema{
Type: framework.TypeBool,
Default: true,
Description: `If set to false, makes the 'common_name' field optional while generating a certificate.`,
DisplayName: "Use CSR Common Name",
},
"policy_identifiers": &framework.FieldSchema{
Type: framework.TypeCommaStringSlice,
Description: `A comma-separated string or list of policy oids.`,
},
"basic_constraints_valid_for_non_ca": &framework.FieldSchema{
Type: framework.TypeBool,
Description: `Mark Basic Constraints valid when issuing non-CA certificates.`,
DisplayName: "Basic Constraints Valid for Non-CA",
},
"not_before_duration": &framework.FieldSchema{
Type: framework.TypeDurationSecond,

View file

@ -93,6 +93,7 @@ func pathRoles(b *backend) *framework.Path {
credential is being generated for other users, Vault uses this admin
username to login to remote host and install the generated credential
for the other user.`,
DisplayName: "Admin Username",
},
"default_user": &framework.FieldSchema{
Type: framework.TypeString,
@ -101,6 +102,7 @@ func pathRoles(b *backend) *framework.Path {
Default username for which a credential will be generated.
When the endpoint 'creds/' is used without a username, this
value will be used as default username.`,
DisplayName: "Default Username",
},
"cidr_list": &framework.FieldSchema{
Type: framework.TypeString,
@ -108,6 +110,7 @@ func pathRoles(b *backend) *framework.Path {
[Optional for Dynamic type] [Optional for OTP type] [Not applicable for CA type]
Comma separated list of CIDR blocks for which the role is applicable for.
CIDR blocks can belong to more than one role.`,
DisplayName: "CIDR List",
},
"exclude_cidr_list": &framework.FieldSchema{
Type: framework.TypeString,
@ -116,6 +119,7 @@ func pathRoles(b *backend) *framework.Path {
Comma separated list of CIDR blocks. IP addresses belonging to these blocks are not
accepted by the role. This is particularly useful when big CIDR blocks are being used
by the role and certain parts of it needs to be kept out.`,
DisplayName: "Exclude CIDR List",
},
"port": &framework.FieldSchema{
Type: framework.TypeInt,
@ -125,6 +129,7 @@ func pathRoles(b *backend) *framework.Path {
play any role in creation of OTP. For 'otp' type, this is just a way
to inform client about the port number to use. Port number will be
returned to client by Vault server along with OTP.`,
DisplayValue: 22,
},
"key_type": &framework.FieldSchema{
Type: framework.TypeString,
@ -132,6 +137,8 @@ func pathRoles(b *backend) *framework.Path {
[Required for all types]
Type of key used to login to hosts. It can be either 'otp', 'dynamic' or 'ca'.
'otp' type requires agent to be installed in remote hosts.`,
AllowedValues: []interface{}{"otp", "dynamic","ca"},
DisplayValue: "ca",
},
"key_bits": &framework.FieldSchema{
Type: framework.TypeInt,
@ -188,6 +195,7 @@ func pathRoles(b *backend) *framework.Path {
requested. The lease duration controls the expiration
of certificates issued by this backend. Defaults to
the value of max_ttl.`,
DisplayName: "TTL",
},
"max_ttl": &framework.FieldSchema{
Type: framework.TypeDurationSecond,
@ -195,6 +203,7 @@ func pathRoles(b *backend) *framework.Path {
[Not applicable for Dynamic type] [Not applicable for OTP type] [Optional for CA type]
The maximum allowed lease duration
`,
DisplayName: "Max TTL",
},
"allowed_critical_options": &framework.FieldSchema{
Type: framework.TypeString,
@ -202,7 +211,7 @@ func pathRoles(b *backend) *framework.Path {
[Not applicable for Dynamic type] [Not applicable for OTP type] [Optional for CA type]
A comma-separated list of critical options that certificates can have when signed.
To allow any critical options, set this to an empty string.
`,
`,
},
"allowed_extensions": &framework.FieldSchema{
Type: framework.TypeString,
@ -238,7 +247,7 @@ func pathRoles(b *backend) *framework.Path {
[Not applicable for Dynamic type] [Not applicable for OTP type] [Optional for CA type]
If set, certificates are allowed to be signed for use as a 'user'.
`,
Default: false,
Default: false,
},
"allow_host_certificates": &framework.FieldSchema{
Type: framework.TypeBool,
@ -246,7 +255,7 @@ func pathRoles(b *backend) *framework.Path {
[Not applicable for Dynamic type] [Not applicable for OTP type] [Optional for CA type]
If set, certificates are allowed to be signed for use as a 'host'.
`,
Default: false,
Default: false,
},
"allow_bare_domains": &framework.FieldSchema{
Type: framework.TypeBool,
@ -272,6 +281,7 @@ func pathRoles(b *backend) *framework.Path {
When false, the key ID will always be the token display name.
The key ID is logged by the SSH server and can be useful for auditing.
`,
DisplayName: "Allow User Key IDs",
},
"key_id_format": &framework.FieldSchema{
Type: framework.TypeString,
@ -282,6 +292,7 @@ func pathRoles(b *backend) *framework.Path {
the token used to make the request. '{{role_name}}' - The name of the role signing the request.
'{{public_key_hash}}' - A SHA256 checksum of the public key that is being signed.
`,
DisplayName: "Key ID Format",
},
"allowed_user_key_lengths": &framework.FieldSchema{
Type: framework.TypeMap,

View file

@ -4,6 +4,10 @@ import (
"context"
"fmt"
"io"
"net"
"net/http"
"time"
"os"
"sort"
"strings"
@ -23,6 +27,7 @@ import (
"github.com/hashicorp/vault/command/agent/auth/gcp"
"github.com/hashicorp/vault/command/agent/auth/jwt"
"github.com/hashicorp/vault/command/agent/auth/kubernetes"
"github.com/hashicorp/vault/command/agent/cache"
"github.com/hashicorp/vault/command/agent/config"
"github.com/hashicorp/vault/command/agent/sink"
"github.com/hashicorp/vault/command/agent/sink/file"
@ -218,19 +223,6 @@ func (c *AgentCommand) Run(args []string) int {
info["cgo"] = "enabled"
}
// Server configuration output
padding := 24
sort.Strings(infoKeys)
c.UI.Output("==> Vault agent configuration:\n")
for _, k := range infoKeys {
c.UI.Output(fmt.Sprintf(
"%s%s: %s",
strings.Repeat(" ", padding-len(k)),
strings.Title(k),
info[k]))
}
c.UI.Output("")
// Tests might not want to start a vault server and just want to verify
// the configuration.
if c.flagTestVerifyOnly {
@ -332,10 +324,92 @@ func (c *AgentCommand) Run(args []string) int {
EnableReauthOnNewCredentials: config.AutoAuth.EnableReauthOnNewCredentials,
})
// Start things running
// Start auto-auth and sink servers
go ah.Run(ctx, method)
go ss.Run(ctx, ah.OutputCh, sinks)
// Parse agent listener configurations
if config.Cache != nil && len(config.Cache.Listeners) != 0 {
cacheLogger := c.logger.Named("cache")
// Create the API proxier
apiProxy := cache.NewAPIProxy(&cache.APIProxyConfig{
Logger: cacheLogger.Named("apiproxy"),
})
// Create the lease cache proxier and set its underlying proxier to
// the API proxier.
leaseCache, err := cache.NewLeaseCache(&cache.LeaseCacheConfig{
BaseContext: ctx,
Proxier: apiProxy,
Logger: cacheLogger.Named("leasecache"),
})
if err != nil {
c.UI.Error(fmt.Sprintf("Error creating lease cache: %v", err))
return 1
}
// Create a muxer and add paths relevant for the lease cache layer
mux := http.NewServeMux()
mux.Handle("/v1/agent/cache-clear", leaseCache.HandleCacheClear(ctx))
mux.Handle("/", cache.Handler(ctx, cacheLogger, leaseCache, config.Cache.UseAutoAuthToken, c.client))
var listeners []net.Listener
for i, lnConfig := range config.Cache.Listeners {
listener, props, _, err := cache.ServerListener(lnConfig, c.logWriter, c.UI)
if err != nil {
c.UI.Error(fmt.Sprintf("Error parsing listener configuration: %v", err))
return 1
}
listeners = append(listeners, listener)
scheme := "https://"
if props["tls"] == "disabled" {
scheme = "http://"
}
if lnConfig.Type == "unix" {
scheme = "unix://"
}
infoKey := fmt.Sprintf("api address %d", i+1)
info[infoKey] = scheme + listener.Addr().String()
infoKeys = append(infoKeys, infoKey)
cacheLogger.Info("starting listener", "addr", listener.Addr().String())
server := &http.Server{
Handler: mux,
ReadHeaderTimeout: 10 * time.Second,
ReadTimeout: 30 * time.Second,
IdleTimeout: 5 * time.Minute,
ErrorLog: cacheLogger.StandardLogger(nil),
}
go server.Serve(listener)
}
// Ensure that listeners are closed at all the exits
listenerCloseFunc := func() {
for _, ln := range listeners {
ln.Close()
}
}
defer c.cleanupGuard.Do(listenerCloseFunc)
}
// Server configuration output
padding := 24
sort.Strings(infoKeys)
c.UI.Output("==> Vault agent configuration:\n")
for _, k := range infoKeys {
c.UI.Output(fmt.Sprintf(
"%s%s: %s",
strings.Repeat(" ", padding-len(k)),
strings.Title(k),
info[k]))
}
c.UI.Output("")
// Release the log gate.
c.logGate.Flush()

61
command/agent/cache/api_proxy.go vendored Normal file
View file

@ -0,0 +1,61 @@
package cache
import (
"bytes"
"context"
"io/ioutil"
hclog "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/api"
)
// APIProxy is an implementation of the proxier interface that is used to
// forward the request to Vault and get the response.
type APIProxy struct {
	// logger records request-forwarding activity and read errors.
	logger hclog.Logger
}
// APIProxyConfig carries the dependencies used to construct an APIProxy.
type APIProxyConfig struct {
	// Logger becomes the proxy's logger.
	Logger hclog.Logger
}
// NewAPIProxy returns a Proxier implementation that forwards requests
// directly to Vault and relays the response back.
func NewAPIProxy(config *APIProxyConfig) Proxier {
	proxy := &APIProxy{logger: config.Logger}
	return proxy
}
// Send forwards the given request to Vault and returns the response with its
// body fully buffered, so callers can read it more than once. The request's
// token, headers, method, path and body bytes are copied onto the outgoing
// Vault request.
func (ap *APIProxy) Send(ctx context.Context, req *SendRequest) (*SendResponse, error) {
	// NOTE: a fresh client is built per call; its configuration comes from
	// the process environment via api.DefaultConfig().
	client, err := api.NewClient(api.DefaultConfig())
	if err != nil {
		return nil, err
	}
	client.SetToken(req.Token)
	client.SetHeaders(req.Request.Header)

	fwReq := client.NewRequest(req.Request.Method, req.Request.URL.Path)
	fwReq.BodyBytes = req.RequestBody

	// Make the request to Vault and get the response
	ap.logger.Info("forwarding request", "path", req.Request.URL.Path, "method", req.Request.Method)
	resp, err := client.RawRequestWithContext(ctx, fwReq)
	if err != nil {
		return nil, err
	}

	// Drain and close the response body, then replace it with a re-readable
	// buffer. The close happens before the error check since the body must be
	// closed regardless of whether the read succeeded.
	respBody, err := ioutil.ReadAll(resp.Body)
	resp.Body.Close()
	if err != nil {
		ap.logger.Error("failed to read response body", "error", err)
		return nil, err
	}
	resp.Body = ioutil.NopCloser(bytes.NewBuffer(respBody))

	return &SendResponse{
		Response:     resp,
		ResponseBody: respBody,
	}, nil
}

43
command/agent/cache/api_proxy_test.go vendored Normal file
View file

@ -0,0 +1,43 @@
package cache
import (
"testing"
hclog "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/helper/jsonutil"
"github.com/hashicorp/vault/helper/logging"
"github.com/hashicorp/vault/helper/namespace"
)
// TestCache_APIProxy verifies that the API proxier forwards a request to the
// Vault cluster and hands back a well-formed sys/health response.
func TestCache_APIProxy(t *testing.T) {
	cleanup, client, _, _ := setupClusterAndAgent(namespace.RootContext(nil), t, nil)
	defer cleanup()

	proxier := NewAPIProxy(&APIProxyConfig{
		Logger: logging.NewVaultLogger(hclog.Trace),
	})

	req, err := client.NewRequest("GET", "/v1/sys/health").ToHTTP()
	if err != nil {
		t.Fatal(err)
	}

	resp, err := proxier.Send(namespace.RootContext(nil), &SendRequest{
		Request: req,
	})
	if err != nil {
		t.Fatal(err)
	}

	var result api.HealthResponse
	if err := jsonutil.DecodeJSONFromReader(resp.Response.Body, &result); err != nil {
		t.Fatal(err)
	}

	// An active, unsealed node should report healthy.
	if !result.Initialized || result.Sealed || result.Standby {
		t.Fatalf("bad sys/health response")
	}
}

926
command/agent/cache/cache_test.go vendored Normal file
View file

@ -0,0 +1,926 @@
package cache
import (
"context"
"fmt"
"net"
"net/http"
"os"
"testing"
"time"
"github.com/hashicorp/vault/logical"
"github.com/go-test/deep"
hclog "github.com/hashicorp/go-hclog"
kv "github.com/hashicorp/vault-plugin-secrets-kv"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/builtin/credential/userpass"
"github.com/hashicorp/vault/helper/logging"
"github.com/hashicorp/vault/helper/namespace"
vaulthttp "github.com/hashicorp/vault/http"
"github.com/hashicorp/vault/vault"
)
// policyAdmin is a fully permissive policy assigned to the test user so the
// agent-proxied client can perform any operation against the cluster.
const policyAdmin = `
path "*" {
capabilities = ["sudo", "create", "read", "update", "delete", "list"]
}
`
// setupClusterAndAgent is a helper func used to set up a test cluster and
// caching agent. It returns a cleanup func that should be deferred immediately
// along with two clients, one for direct cluster communication and another to
// talk to the caching agent.
func setupClusterAndAgent(ctx context.Context, t *testing.T, coreConfig *vault.CoreConfig) (func(), *api.Client, *api.Client, *LeaseCache) {
	t.Helper()

	if ctx == nil {
		ctx = context.Background()
	}

	// Handle sane defaults
	if coreConfig == nil {
		coreConfig = &vault.CoreConfig{
			DisableMlock: true,
			DisableCache: true,
			Logger:       logging.NewVaultLogger(hclog.Trace),
			CredentialBackends: map[string]logical.Factory{
				"userpass": userpass.Factory,
			},
		}
	}
	// A caller-supplied config still needs the userpass backend for the
	// agent login below.
	if coreConfig.CredentialBackends == nil {
		coreConfig.CredentialBackends = map[string]logical.Factory{
			"userpass": userpass.Factory,
		}
	}

	// Init new test cluster
	cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{
		HandlerFunc: vaulthttp.Handler,
	})
	cluster.Start()

	cores := cluster.Cores
	vault.TestWaitActive(t, cores[0].Core)

	// clusterClient is the client that is used to talk directly to the cluster.
	clusterClient := cores[0].Client

	// Add an admin policy
	if err := clusterClient.Sys().PutPolicy("admin", policyAdmin); err != nil {
		t.Fatal(err)
	}

	// Set up the userpass auth backend and an admin user. Used for getting a token
	// for the agent later down in this func.
	// NOTE(review): the error from EnableAuthWithOptions is ignored here; a
	// failure would surface as an error on the Write below.
	clusterClient.Sys().EnableAuthWithOptions("userpass", &api.EnableAuthOptions{
		Type: "userpass",
	})

	_, err := clusterClient.Logical().Write("auth/userpass/users/foo", map[string]interface{}{
		"password": "bar",
		"policies": []string{"admin"},
	})
	if err != nil {
		t.Fatal(err)
	}

	// Set up env vars for agent consumption; originals are restored by the
	// returned cleanup func.
	origEnvVaultAddress := os.Getenv(api.EnvVaultAddress)
	os.Setenv(api.EnvVaultAddress, clusterClient.Address())

	origEnvVaultCACert := os.Getenv(api.EnvVaultCACert)
	os.Setenv(api.EnvVaultCACert, fmt.Sprintf("%s/ca_cert.pem", cluster.TempDir))

	cacheLogger := logging.NewVaultLogger(hclog.Trace).Named("cache")

	// Listen on a random free local port for the caching agent.
	listener, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatal(err)
	}

	// Create the API proxier
	apiProxy := NewAPIProxy(&APIProxyConfig{
		Logger: cacheLogger.Named("apiproxy"),
	})

	// Create the lease cache proxier and set its underlying proxier to
	// the API proxier.
	leaseCache, err := NewLeaseCache(&LeaseCacheConfig{
		BaseContext: ctx,
		Proxier:     apiProxy,
		Logger:      cacheLogger.Named("leasecache"),
	})
	if err != nil {
		t.Fatal(err)
	}

	// Create a muxer and add paths relevant for the lease cache layer
	mux := http.NewServeMux()
	mux.Handle("/v1/agent/cache-clear", leaseCache.HandleCacheClear(ctx))
	mux.Handle("/", Handler(ctx, cacheLogger, leaseCache, false, clusterClient))
	server := &http.Server{
		Handler:           mux,
		ReadHeaderTimeout: 10 * time.Second,
		ReadTimeout:       30 * time.Second,
		IdleTimeout:       5 * time.Minute,
		ErrorLog:          cacheLogger.StandardLogger(nil),
	}
	go server.Serve(listener)

	// testClient is the client that is used to talk to the agent for proxying/caching behavior.
	testClient, err := clusterClient.Clone()
	if err != nil {
		t.Fatal(err)
	}

	if err := testClient.SetAddress("http://" + listener.Addr().String()); err != nil {
		t.Fatal(err)
	}

	// Login via userpass method to derive a managed token. Set that token as the
	// testClient's token
	resp, err := testClient.Logical().Write("auth/userpass/login/foo", map[string]interface{}{
		"password": "bar",
	})
	if err != nil {
		t.Fatal(err)
	}
	testClient.SetToken(resp.Auth.ClientToken)

	cleanup := func() {
		cluster.Cleanup()
		os.Setenv(api.EnvVaultAddress, origEnvVaultAddress)
		os.Setenv(api.EnvVaultCACert, origEnvVaultCACert)
		listener.Close()
	}

	return cleanup, clusterClient, testClient, leaseCache
}
// tokenRevocationValidation checks that exactly the entries listed in
// expected remain in the lease cache: every sampleSpace value missing from
// expected must have been evicted, and every value present must remain.
func tokenRevocationValidation(t *testing.T, sampleSpace map[string]string, expected map[string]string, leaseCache *LeaseCache) {
	t.Helper()
	for value, entryType := range sampleSpace {
		index, err := leaseCache.db.Get(entryType, value)
		if err != nil {
			t.Fatal(err)
		}
		shouldExist := expected[value] != ""
		switch {
		case !shouldExist && index != nil:
			t.Fatalf("failed to evict index from the cache: type: %q, value: %q", entryType, value)
		case shouldExist && index == nil:
			t.Fatalf("evicted an undesired index from cache: type: %q, value: %q", entryType, value)
		}
	}
}
// TestCache_TokenRevocations_RevokeOrphan builds a three-deep token chain
// (token1 -> token2 -> token3), each with its own kv lease, then revoke-orphans
// the middle token and verifies only its own cache entries are evicted.
func TestCache_TokenRevocations_RevokeOrphan(t *testing.T) {
	coreConfig := &vault.CoreConfig{
		DisableMlock: true,
		DisableCache: true,
		Logger:       hclog.NewNullLogger(),
		LogicalBackends: map[string]logical.Factory{
			"kv": vault.LeasedPassthroughBackendFactory,
		},
	}

	// sampleSpace tracks every cached token/lease this test creates,
	// keyed by value with its cache index type.
	sampleSpace := make(map[string]string)

	cleanup, _, testClient, leaseCache := setupClusterAndAgent(namespace.RootContext(nil), t, coreConfig)
	defer cleanup()

	token1 := testClient.Token()
	sampleSpace[token1] = "token"

	// Mount the kv backend
	err := testClient.Sys().Mount("kv", &api.MountInput{
		Type: "kv",
	})
	if err != nil {
		t.Fatal(err)
	}

	// Create a secret in the backend
	_, err = testClient.Logical().Write("kv/foo", map[string]interface{}{
		"value": "bar",
		"ttl":   "1h",
	})
	if err != nil {
		t.Fatal(err)
	}

	// Read the secret and create a lease
	leaseResp, err := testClient.Logical().Read("kv/foo")
	if err != nil {
		t.Fatal(err)
	}
	lease1 := leaseResp.LeaseID
	sampleSpace[lease1] = "lease"

	resp, err := testClient.Logical().Write("auth/token/create", nil)
	if err != nil {
		t.Fatal(err)
	}
	token2 := resp.Auth.ClientToken
	sampleSpace[token2] = "token"

	testClient.SetToken(token2)
	leaseResp, err = testClient.Logical().Read("kv/foo")
	if err != nil {
		t.Fatal(err)
	}
	lease2 := leaseResp.LeaseID
	sampleSpace[lease2] = "lease"

	resp, err = testClient.Logical().Write("auth/token/create", nil)
	if err != nil {
		t.Fatal(err)
	}
	token3 := resp.Auth.ClientToken
	sampleSpace[token3] = "token"

	testClient.SetToken(token3)
	leaseResp, err = testClient.Logical().Read("kv/foo")
	if err != nil {
		t.Fatal(err)
	}
	lease3 := leaseResp.LeaseID
	sampleSpace[lease3] = "lease"

	// Initially every entry should be present in the cache.
	expected := make(map[string]string)
	for k, v := range sampleSpace {
		expected[k] = v
	}
	tokenRevocationValidation(t, sampleSpace, expected, leaseCache)

	// Revoke-orphan the intermediate token. This should result in its own
	// eviction and evictions of the revoked token's leases. All other things
	// including the child tokens and leases of the child tokens should be
	// untouched.
	testClient.SetToken(token2)
	err = testClient.Auth().Token().RevokeOrphan(token2)
	if err != nil {
		t.Fatal(err)
	}
	time.Sleep(1 * time.Second)

	expected = map[string]string{
		token1: "token",
		lease1: "lease",
		token3: "token",
		lease3: "lease",
	}
	tokenRevocationValidation(t, sampleSpace, expected, leaseCache)
}
// TestCache_TokenRevocations_LeafLevelToken builds a three-deep token chain,
// each token with its own kv lease, then self-revokes the leaf token and
// verifies only the leaf's entries are evicted from the cache.
func TestCache_TokenRevocations_LeafLevelToken(t *testing.T) {
	coreConfig := &vault.CoreConfig{
		DisableMlock: true,
		DisableCache: true,
		Logger:       hclog.NewNullLogger(),
		LogicalBackends: map[string]logical.Factory{
			"kv": vault.LeasedPassthroughBackendFactory,
		},
	}

	// sampleSpace tracks every cached token/lease this test creates,
	// keyed by value with its cache index type.
	sampleSpace := make(map[string]string)

	cleanup, _, testClient, leaseCache := setupClusterAndAgent(namespace.RootContext(nil), t, coreConfig)
	defer cleanup()

	token1 := testClient.Token()
	sampleSpace[token1] = "token"

	// Mount the kv backend
	err := testClient.Sys().Mount("kv", &api.MountInput{
		Type: "kv",
	})
	if err != nil {
		t.Fatal(err)
	}

	// Create a secret in the backend
	_, err = testClient.Logical().Write("kv/foo", map[string]interface{}{
		"value": "bar",
		"ttl":   "1h",
	})
	if err != nil {
		t.Fatal(err)
	}

	// Read the secret and create a lease
	leaseResp, err := testClient.Logical().Read("kv/foo")
	if err != nil {
		t.Fatal(err)
	}
	lease1 := leaseResp.LeaseID
	sampleSpace[lease1] = "lease"

	resp, err := testClient.Logical().Write("auth/token/create", nil)
	if err != nil {
		t.Fatal(err)
	}
	token2 := resp.Auth.ClientToken
	sampleSpace[token2] = "token"

	testClient.SetToken(token2)
	leaseResp, err = testClient.Logical().Read("kv/foo")
	if err != nil {
		t.Fatal(err)
	}
	lease2 := leaseResp.LeaseID
	sampleSpace[lease2] = "lease"

	resp, err = testClient.Logical().Write("auth/token/create", nil)
	if err != nil {
		t.Fatal(err)
	}
	token3 := resp.Auth.ClientToken
	sampleSpace[token3] = "token"

	testClient.SetToken(token3)
	leaseResp, err = testClient.Logical().Read("kv/foo")
	if err != nil {
		t.Fatal(err)
	}
	lease3 := leaseResp.LeaseID
	sampleSpace[lease3] = "lease"

	// Initially every entry should be present in the cache.
	expected := make(map[string]string)
	for k, v := range sampleSpace {
		expected[k] = v
	}
	tokenRevocationValidation(t, sampleSpace, expected, leaseCache)

	// Revoke the leaf token. This should evict all the leases belonging to this
	// token, evict entries for all the child tokens and their respective
	// leases.
	testClient.SetToken(token3)
	err = testClient.Auth().Token().RevokeSelf("")
	if err != nil {
		t.Fatal(err)
	}
	time.Sleep(1 * time.Second)

	expected = map[string]string{
		token1: "token",
		lease1: "lease",
		token2: "token",
		lease2: "lease",
	}
	tokenRevocationValidation(t, sampleSpace, expected, leaseCache)
}
// TestCache_TokenRevocations_IntermediateLevelToken builds a three-deep token
// chain, each token with its own kv lease, then self-revokes the middle token
// and verifies its entries and all of its descendants' entries are evicted.
func TestCache_TokenRevocations_IntermediateLevelToken(t *testing.T) {
	coreConfig := &vault.CoreConfig{
		DisableMlock: true,
		DisableCache: true,
		Logger:       hclog.NewNullLogger(),
		LogicalBackends: map[string]logical.Factory{
			"kv": vault.LeasedPassthroughBackendFactory,
		},
	}

	// sampleSpace tracks every cached token/lease this test creates,
	// keyed by value with its cache index type.
	sampleSpace := make(map[string]string)

	cleanup, _, testClient, leaseCache := setupClusterAndAgent(namespace.RootContext(nil), t, coreConfig)
	defer cleanup()

	token1 := testClient.Token()
	sampleSpace[token1] = "token"

	// Mount the kv backend
	err := testClient.Sys().Mount("kv", &api.MountInput{
		Type: "kv",
	})
	if err != nil {
		t.Fatal(err)
	}

	// Create a secret in the backend
	_, err = testClient.Logical().Write("kv/foo", map[string]interface{}{
		"value": "bar",
		"ttl":   "1h",
	})
	if err != nil {
		t.Fatal(err)
	}

	// Read the secret and create a lease
	leaseResp, err := testClient.Logical().Read("kv/foo")
	if err != nil {
		t.Fatal(err)
	}
	lease1 := leaseResp.LeaseID
	sampleSpace[lease1] = "lease"

	resp, err := testClient.Logical().Write("auth/token/create", nil)
	if err != nil {
		t.Fatal(err)
	}
	token2 := resp.Auth.ClientToken
	sampleSpace[token2] = "token"

	testClient.SetToken(token2)
	leaseResp, err = testClient.Logical().Read("kv/foo")
	if err != nil {
		t.Fatal(err)
	}
	lease2 := leaseResp.LeaseID
	sampleSpace[lease2] = "lease"

	resp, err = testClient.Logical().Write("auth/token/create", nil)
	if err != nil {
		t.Fatal(err)
	}
	token3 := resp.Auth.ClientToken
	sampleSpace[token3] = "token"

	testClient.SetToken(token3)
	leaseResp, err = testClient.Logical().Read("kv/foo")
	if err != nil {
		t.Fatal(err)
	}
	lease3 := leaseResp.LeaseID
	sampleSpace[lease3] = "lease"

	// Initially every entry should be present in the cache.
	expected := make(map[string]string)
	for k, v := range sampleSpace {
		expected[k] = v
	}
	tokenRevocationValidation(t, sampleSpace, expected, leaseCache)

	// Revoke the second level token. This should evict all the leases
	// belonging to this token, evict entries for all the child tokens and
	// their respective leases.
	testClient.SetToken(token2)
	err = testClient.Auth().Token().RevokeSelf("")
	if err != nil {
		t.Fatal(err)
	}
	time.Sleep(1 * time.Second)

	expected = map[string]string{
		token1: "token",
		lease1: "lease",
	}
	tokenRevocationValidation(t, sampleSpace, expected, leaseCache)
}
// TestCache_TokenRevocations_TopLevelToken builds a three-deep token chain,
// each token with its own kv lease, then self-revokes the root of the chain
// and verifies every cached entry is evicted.
func TestCache_TokenRevocations_TopLevelToken(t *testing.T) {
	coreConfig := &vault.CoreConfig{
		DisableMlock: true,
		DisableCache: true,
		Logger:       hclog.NewNullLogger(),
		LogicalBackends: map[string]logical.Factory{
			"kv": vault.LeasedPassthroughBackendFactory,
		},
	}

	// sampleSpace tracks every cached token/lease this test creates,
	// keyed by value with its cache index type.
	sampleSpace := make(map[string]string)

	cleanup, _, testClient, leaseCache := setupClusterAndAgent(namespace.RootContext(nil), t, coreConfig)
	defer cleanup()

	token1 := testClient.Token()
	sampleSpace[token1] = "token"

	// Mount the kv backend
	err := testClient.Sys().Mount("kv", &api.MountInput{
		Type: "kv",
	})
	if err != nil {
		t.Fatal(err)
	}

	// Create a secret in the backend
	_, err = testClient.Logical().Write("kv/foo", map[string]interface{}{
		"value": "bar",
		"ttl":   "1h",
	})
	if err != nil {
		t.Fatal(err)
	}

	// Read the secret and create a lease
	leaseResp, err := testClient.Logical().Read("kv/foo")
	if err != nil {
		t.Fatal(err)
	}
	lease1 := leaseResp.LeaseID
	sampleSpace[lease1] = "lease"

	resp, err := testClient.Logical().Write("auth/token/create", nil)
	if err != nil {
		t.Fatal(err)
	}
	token2 := resp.Auth.ClientToken
	sampleSpace[token2] = "token"

	testClient.SetToken(token2)
	leaseResp, err = testClient.Logical().Read("kv/foo")
	if err != nil {
		t.Fatal(err)
	}
	lease2 := leaseResp.LeaseID
	sampleSpace[lease2] = "lease"

	resp, err = testClient.Logical().Write("auth/token/create", nil)
	if err != nil {
		t.Fatal(err)
	}
	token3 := resp.Auth.ClientToken
	sampleSpace[token3] = "token"

	testClient.SetToken(token3)
	leaseResp, err = testClient.Logical().Read("kv/foo")
	if err != nil {
		t.Fatal(err)
	}
	lease3 := leaseResp.LeaseID
	sampleSpace[lease3] = "lease"

	// Initially every entry should be present in the cache.
	expected := make(map[string]string)
	for k, v := range sampleSpace {
		expected[k] = v
	}
	tokenRevocationValidation(t, sampleSpace, expected, leaseCache)

	// Revoke the top level token. This should evict all the leases belonging
	// to this token, evict entries for all the child tokens and their
	// respective leases.
	testClient.SetToken(token1)
	err = testClient.Auth().Token().RevokeSelf("")
	if err != nil {
		t.Fatal(err)
	}
	time.Sleep(1 * time.Second)

	expected = make(map[string]string)
	tokenRevocationValidation(t, sampleSpace, expected, leaseCache)
}
// TestCache_TokenRevocations_Shutdown builds a three-deep token chain with
// leases, then cancels the context passed into the agent setup and verifies
// that shutdown evicts every entry from the cache.
func TestCache_TokenRevocations_Shutdown(t *testing.T) {
	coreConfig := &vault.CoreConfig{
		DisableMlock: true,
		DisableCache: true,
		Logger:       hclog.NewNullLogger(),
		LogicalBackends: map[string]logical.Factory{
			"kv": vault.LeasedPassthroughBackendFactory,
		},
	}

	// sampleSpace tracks every cached token/lease this test creates,
	// keyed by value with its cache index type.
	sampleSpace := make(map[string]string)

	// Hold on to the cancel func so the root context can be torn down below.
	ctx, rootCancelFunc := context.WithCancel(namespace.RootContext(nil))
	cleanup, _, testClient, leaseCache := setupClusterAndAgent(ctx, t, coreConfig)
	defer cleanup()

	token1 := testClient.Token()
	sampleSpace[token1] = "token"

	// Mount the kv backend
	err := testClient.Sys().Mount("kv", &api.MountInput{
		Type: "kv",
	})
	if err != nil {
		t.Fatal(err)
	}

	// Create a secret in the backend
	_, err = testClient.Logical().Write("kv/foo", map[string]interface{}{
		"value": "bar",
		"ttl":   "1h",
	})
	if err != nil {
		t.Fatal(err)
	}

	// Read the secret and create a lease
	leaseResp, err := testClient.Logical().Read("kv/foo")
	if err != nil {
		t.Fatal(err)
	}
	lease1 := leaseResp.LeaseID
	sampleSpace[lease1] = "lease"

	resp, err := testClient.Logical().Write("auth/token/create", nil)
	if err != nil {
		t.Fatal(err)
	}
	token2 := resp.Auth.ClientToken
	sampleSpace[token2] = "token"

	testClient.SetToken(token2)
	leaseResp, err = testClient.Logical().Read("kv/foo")
	if err != nil {
		t.Fatal(err)
	}
	lease2 := leaseResp.LeaseID
	sampleSpace[lease2] = "lease"

	resp, err = testClient.Logical().Write("auth/token/create", nil)
	if err != nil {
		t.Fatal(err)
	}
	token3 := resp.Auth.ClientToken
	sampleSpace[token3] = "token"

	testClient.SetToken(token3)
	leaseResp, err = testClient.Logical().Read("kv/foo")
	if err != nil {
		t.Fatal(err)
	}
	lease3 := leaseResp.LeaseID
	sampleSpace[lease3] = "lease"

	// Initially every entry should be present in the cache.
	expected := make(map[string]string)
	for k, v := range sampleSpace {
		expected[k] = v
	}
	tokenRevocationValidation(t, sampleSpace, expected, leaseCache)

	// Simulate agent shutdown by cancelling the root context.
	rootCancelFunc()
	time.Sleep(1 * time.Second)

	// Ensure that all the entries are now gone
	expected = make(map[string]string)
	tokenRevocationValidation(t, sampleSpace, expected, leaseCache)
}
// TestCache_TokenRevocations_BaseContextCancellation builds a three-deep token
// chain with leases, then cancels the lease cache's own base context directly
// and verifies that every entry is evicted from the cache.
func TestCache_TokenRevocations_BaseContextCancellation(t *testing.T) {
	coreConfig := &vault.CoreConfig{
		DisableMlock: true,
		DisableCache: true,
		Logger:       hclog.NewNullLogger(),
		LogicalBackends: map[string]logical.Factory{
			"kv": vault.LeasedPassthroughBackendFactory,
		},
	}

	// sampleSpace tracks every cached token/lease this test creates,
	// keyed by value with its cache index type.
	sampleSpace := make(map[string]string)

	cleanup, _, testClient, leaseCache := setupClusterAndAgent(namespace.RootContext(nil), t, coreConfig)
	defer cleanup()

	token1 := testClient.Token()
	sampleSpace[token1] = "token"

	// Mount the kv backend
	err := testClient.Sys().Mount("kv", &api.MountInput{
		Type: "kv",
	})
	if err != nil {
		t.Fatal(err)
	}

	// Create a secret in the backend
	_, err = testClient.Logical().Write("kv/foo", map[string]interface{}{
		"value": "bar",
		"ttl":   "1h",
	})
	if err != nil {
		t.Fatal(err)
	}

	// Read the secret and create a lease
	leaseResp, err := testClient.Logical().Read("kv/foo")
	if err != nil {
		t.Fatal(err)
	}
	lease1 := leaseResp.LeaseID
	sampleSpace[lease1] = "lease"

	resp, err := testClient.Logical().Write("auth/token/create", nil)
	if err != nil {
		t.Fatal(err)
	}
	token2 := resp.Auth.ClientToken
	sampleSpace[token2] = "token"

	testClient.SetToken(token2)
	leaseResp, err = testClient.Logical().Read("kv/foo")
	if err != nil {
		t.Fatal(err)
	}
	lease2 := leaseResp.LeaseID
	sampleSpace[lease2] = "lease"

	resp, err = testClient.Logical().Write("auth/token/create", nil)
	if err != nil {
		t.Fatal(err)
	}
	token3 := resp.Auth.ClientToken
	sampleSpace[token3] = "token"

	testClient.SetToken(token3)
	leaseResp, err = testClient.Logical().Read("kv/foo")
	if err != nil {
		t.Fatal(err)
	}
	lease3 := leaseResp.LeaseID
	sampleSpace[lease3] = "lease"

	// Initially every entry should be present in the cache.
	expected := make(map[string]string)
	for k, v := range sampleSpace {
		expected[k] = v
	}
	tokenRevocationValidation(t, sampleSpace, expected, leaseCache)

	// Cancel the base context of the lease cache. This should trigger
	// evictions of all the entries from the cache.
	leaseCache.baseCtxInfo.CancelFunc()
	time.Sleep(1 * time.Second)

	// Ensure that all the entries are now gone
	expected = make(map[string]string)
	tokenRevocationValidation(t, sampleSpace, expected, leaseCache)
}
// TestCache_NonCacheable verifies that non-cacheable requests (mount listings)
// are always proxied to Vault rather than served from the cache.
func TestCache_NonCacheable(t *testing.T) {
	coreConfig := &vault.CoreConfig{
		DisableMlock: true,
		DisableCache: true,
		Logger:       hclog.NewNullLogger(),
		LogicalBackends: map[string]logical.Factory{
			"kv": kv.Factory,
		},
	}

	cleanup, _, testClient, _ := setupClusterAndAgent(namespace.RootContext(nil), t, coreConfig)
	defer cleanup()

	// Capture the mount table before making any changes.
	before, err := testClient.Sys().ListMounts()
	if err != nil {
		t.Fatal(err)
	}

	// Mutate the mount table by mounting a kv backend.
	mountInput := &api.MountInput{
		Type: "kv",
		Options: map[string]string{
			"version": "2",
		},
	}
	if err := testClient.Sys().Mount("kv", mountInput); err != nil {
		t.Fatal(err)
	}

	// A second listing must reflect the mutation; an identical response
	// would mean it was wrongly served from the cache.
	after, err := testClient.Sys().ListMounts()
	if err != nil {
		t.Fatal(err)
	}
	if diff := deep.Equal(before, after); diff == nil {
		t.Logf("response #1: %#v", before)
		t.Logf("response #2: %#v", after)
		t.Fatal("expected requests to be not cached")
	}
}
// TestCache_AuthResponse verifies caching behavior for auth responses by
// issuing identical token-create requests twice and comparing the tokens.
func TestCache_AuthResponse(t *testing.T) {
	cleanup, _, testClient, _ := setupClusterAndAgent(namespace.RootContext(nil), t, nil)
	defer cleanup()

	resp, err := testClient.Logical().Write("auth/token/create", nil)
	if err != nil {
		t.Fatal(err)
	}
	testClient.SetToken(resp.Auth.ClientToken)

	// createChildToken issues a token-create request through the agent and
	// sanity-checks that a client token came back.
	createChildToken := func(t *testing.T, body map[string]interface{}) *api.Secret {
		secret, err := testClient.Logical().Write("auth/token/create", body)
		if err != nil {
			t.Fatal(err)
		}
		if secret.Auth == nil || secret.Auth.ClientToken == "" {
			t.Fatalf("expected a valid client token in the response, got = %#v", secret)
		}
		return secret
	}

	// Test on auth response by creating a child token
	{
		proxied := createChildToken(t, map[string]interface{}{
			"policies": "default",
		})
		cached := createChildToken(t, map[string]interface{}{
			"policies": "default",
		})
		if diff := deep.Equal(proxied.Auth.ClientToken, cached.Auth.ClientToken); diff != nil {
			t.Fatal(diff)
		}
	}

	// Test on *non-renewable* auth response by creating a child root token
	{
		proxied := createChildToken(t, nil)
		cached := createChildToken(t, nil)
		if diff := deep.Equal(proxied.Auth.ClientToken, cached.Auth.ClientToken); diff != nil {
			t.Fatal(diff)
		}
	}
}
// TestCache_LeaseResponse verifies lease-response handling: distinct requests
// must be proxied individually, while a repeated identical request must be
// served from the cache.
func TestCache_LeaseResponse(t *testing.T) {
	coreConfig := &vault.CoreConfig{
		DisableMlock: true,
		DisableCache: true,
		Logger:       hclog.NewNullLogger(),
		LogicalBackends: map[string]logical.Factory{
			"kv": vault.LeasedPassthroughBackendFactory,
		},
	}

	cleanup, client, testClient, _ := setupClusterAndAgent(namespace.RootContext(nil), t, coreConfig)
	defer cleanup()

	err := client.Sys().Mount("kv", &api.MountInput{
		Type: "kv",
	})
	if err != nil {
		t.Fatal(err)
	}

	// Test proxy by issuing two different requests
	{
		// Write data to the lease-kv backend
		_, err := testClient.Logical().Write("kv/foo", map[string]interface{}{
			"value": "bar",
			"ttl":   "1h",
		})
		if err != nil {
			t.Fatal(err)
		}
		_, err = testClient.Logical().Write("kv/foobar", map[string]interface{}{
			"value": "bar",
			"ttl":   "1h",
		})
		if err != nil {
			t.Fatal(err)
		}
		firstResp, err := testClient.Logical().Read("kv/foo")
		if err != nil {
			t.Fatal(err)
		}
		secondResp, err := testClient.Logical().Read("kv/foobar")
		if err != nil {
			t.Fatal(err)
		}
		// Different paths must yield different (proxied) responses.
		if diff := deep.Equal(firstResp, secondResp); diff == nil {
			t.Logf("response: %#v", firstResp)
			t.Fatal("expected proxied responses, got cached response on second request")
		}
	}

	// Test caching behavior by issuing the same request twice
	{
		_, err := testClient.Logical().Write("kv/baz", map[string]interface{}{
			"value": "foo",
			"ttl":   "1h",
		})
		if err != nil {
			t.Fatal(err)
		}
		proxiedResp, err := testClient.Logical().Read("kv/baz")
		if err != nil {
			t.Fatal(err)
		}
		cachedResp, err := testClient.Logical().Read("kv/baz")
		if err != nil {
			t.Fatal(err)
		}
		// The second read must be an exact replay of the first (cache hit).
		if diff := deep.Equal(proxiedResp, cachedResp); diff != nil {
			t.Fatal(diff)
		}
	}
}

View file

@ -0,0 +1,265 @@
package cachememdb
import (
"errors"
"fmt"
memdb "github.com/hashicorp/go-memdb"
)
const (
	// tableNameIndexer is the name of the single table in the cache's
	// memdb schema.
	tableNameIndexer = "indexer"
)
// CacheMemDB is the underlying cache database for storing indexes.
type CacheMemDB struct {
	// db is the in-memory database holding the indexer table.
	db *memdb.MemDB
}
// New creates a new instance of CacheMemDB backed by a freshly built
// in-memory database.
func New() (*CacheMemDB, error) {
	db, err := newDB()
	if err != nil {
		return nil, err
	}
	c := &CacheMemDB{db: db}
	return c, nil
}
// newDB builds the memdb schema for the cache — a single indexer table with
// one index per lookup pattern the lease cache needs — and returns a fresh
// in-memory database using that schema.
func newDB() (*memdb.MemDB, error) {
	cacheSchema := &memdb.DBSchema{
		Tables: map[string]*memdb.TableSchema{
			tableNameIndexer: &memdb.TableSchema{
				Name: tableNameIndexer,
				Indexes: map[string]*memdb.IndexSchema{
					// This index enables fetching the cached item based on the
					// identifier of the index.
					IndexNameID: &memdb.IndexSchema{
						Name:   IndexNameID,
						Unique: true,
						Indexer: &memdb.StringFieldIndex{
							Field: "ID",
						},
					},
					// This index enables fetching all the entries in cache for
					// a given request path, in a given namespace.
					IndexNameRequestPath: &memdb.IndexSchema{
						Name:   IndexNameRequestPath,
						Unique: false,
						Indexer: &memdb.CompoundIndex{
							Indexes: []memdb.Indexer{
								&memdb.StringFieldIndex{
									Field: "Namespace",
								},
								&memdb.StringFieldIndex{
									Field: "RequestPath",
								},
							},
						},
					},
					// This index enables fetching all the entries in cache
					// belonging to the leases of a given token.
					IndexNameLeaseToken: &memdb.IndexSchema{
						Name:         IndexNameLeaseToken,
						Unique:       false,
						AllowMissing: true,
						Indexer: &memdb.StringFieldIndex{
							Field: "LeaseToken",
						},
					},
					// This index enables fetching all the entries in cache
					// that are tied to the given token, regardless of the
					// entries belonging to the token or belonging to the
					// lease.
					IndexNameToken: &memdb.IndexSchema{
						Name:         IndexNameToken,
						Unique:       true,
						AllowMissing: true,
						Indexer: &memdb.StringFieldIndex{
							Field: "Token",
						},
					},
					// This index enables fetching all the entries in cache for
					// the given parent token.
					IndexNameTokenParent: &memdb.IndexSchema{
						Name:         IndexNameTokenParent,
						Unique:       false,
						AllowMissing: true,
						Indexer: &memdb.StringFieldIndex{
							Field: "TokenParent",
						},
					},
					// This index enables fetching all the entries in cache for
					// the given accessor.
					IndexNameTokenAccessor: &memdb.IndexSchema{
						Name:         IndexNameTokenAccessor,
						Unique:       true,
						AllowMissing: true,
						Indexer: &memdb.StringFieldIndex{
							Field: "TokenAccessor",
						},
					},
					// This index enables fetching all the entries in cache for
					// the given lease identifier.
					IndexNameLease: &memdb.IndexSchema{
						Name:         IndexNameLease,
						Unique:       true,
						AllowMissing: true,
						Indexer: &memdb.StringFieldIndex{
							Field: "Lease",
						},
					},
				},
			},
		},
	}

	db, err := memdb.NewMemDB(cacheSchema)
	if err != nil {
		return nil, err
	}
	return db, nil
}
// Get returns the index based on the indexer and the index values provided.
// A cache miss yields (nil, nil) rather than an error.
func (c *CacheMemDB) Get(indexName string, indexValues ...interface{}) (*Index, error) {
	if !validIndexName(indexName) {
		return nil, fmt.Errorf("invalid index name %q", indexName)
	}

	// Read-only transaction; First returns the single matching object.
	txn := c.db.Txn(false)
	raw, err := txn.First(tableNameIndexer, indexName, indexValues...)
	if err != nil {
		return nil, err
	}
	if raw == nil {
		return nil, nil
	}

	cached, ok := raw.(*Index)
	if !ok {
		return nil, errors.New("unable to parse index value from the cache")
	}
	return cached, nil
}
// Set stores the index into the cache, replacing any existing entry that
// shares its unique index values.
func (c *CacheMemDB) Set(index *Index) error {
	if index == nil {
		return errors.New("nil index provided")
	}

	txn := c.db.Txn(true)
	defer txn.Abort()

	if insertErr := txn.Insert(tableNameIndexer, index); insertErr != nil {
		return fmt.Errorf("unable to insert index into cache: %v", insertErr)
	}
	txn.Commit()

	return nil
}
// GetByPrefix returns all the cached indexes based on the index name and the
// value prefix.
func (c *CacheMemDB) GetByPrefix(indexName string, indexValues ...interface{}) ([]*Index, error) {
	if !validIndexName(indexName) {
		return nil, fmt.Errorf("invalid index name %q", indexName)
	}

	// go-memdb exposes prefix scans via the "<index>_prefix" index variant.
	prefixIndex := indexName + "_prefix"

	// Iterate over every matching object in a read-only transaction.
	iter, err := c.db.Txn(false).Get(tableNameIndexer, prefixIndex, indexValues...)
	if err != nil {
		return nil, err
	}

	var results []*Index
	for obj := iter.Next(); obj != nil; obj = iter.Next() {
		idx, ok := obj.(*Index)
		if !ok {
			return nil, fmt.Errorf("failed to cast cached index")
		}
		results = append(results, idx)
	}
	return results, nil
}
// Evict removes an index from the cache based on index name and value.
// Evicting an entry that does not exist is a no-op, not an error.
func (c *CacheMemDB) Evict(indexName string, indexValues ...interface{}) error {
	cached, err := c.Get(indexName, indexValues...)
	if err != nil {
		return fmt.Errorf("unable to fetch index on cache deletion: %v", err)
	}
	if cached == nil {
		return nil
	}

	txn := c.db.Txn(true)
	defer txn.Abort()

	if delErr := txn.Delete(tableNameIndexer, cached); delErr != nil {
		return fmt.Errorf("unable to delete index from cache: %v", delErr)
	}
	txn.Commit()

	return nil
}
// EvictAll removes all matching indexes from the cache based on index name and value.
// It is a thin wrapper over batchEvict with prefix matching disabled.
func (c *CacheMemDB) EvictAll(indexName, indexValue string) error {
	return c.batchEvict(false, indexName, indexValue)
}
// EvictByPrefix removes all matching prefix indexes from the cache based on index name and prefix.
// It is a thin wrapper over batchEvict with prefix matching enabled.
func (c *CacheMemDB) EvictByPrefix(indexName, indexPrefix string) error {
	return c.batchEvict(true, indexName, indexPrefix)
}
// batchEvict is a helper that supports eviction based on absolute and
// prefixed index values, deleting every matching entry in one transaction.
func (c *CacheMemDB) batchEvict(isPrefix bool, indexName string, indexValues ...interface{}) error {
	if !validIndexName(indexName) {
		return fmt.Errorf("invalid index name %q", indexName)
	}

	// Prefix deletion goes through the "_prefix" index variant.
	target := indexName
	if isPrefix {
		target += "_prefix"
	}

	txn := c.db.Txn(true)
	defer txn.Abort()

	if _, err := txn.DeleteAll(tableNameIndexer, target, indexValues...); err != nil {
		return err
	}
	txn.Commit()

	return nil
}
// Flush resets the underlying cache object by replacing the go-memdb
// instance with a freshly built (empty) one.
//
// NOTE(review): the swap of c.db is not synchronized; if Flush can race with
// concurrent Get/Set callers this is a data race — confirm that callers
// serialize access to the cache.
func (c *CacheMemDB) Flush() error {
	newDB, err := newDB()
	if err != nil {
		return err
	}
	c.db = newDB

	return nil
}

View file

@ -0,0 +1,392 @@
package cachememdb
import (
"context"
"testing"
"github.com/go-test/deep"
)
// testContextInfo builds a ContextInfo fixture backed by a fresh cancelable
// context, for use in the tests below.
func testContextInfo() *ContextInfo {
	ctx, cancel := context.WithCancel(context.Background())

	info := &ContextInfo{
		Ctx:        ctx,
		CancelFunc: cancel,
	}
	return info
}
// TestNew verifies that a CacheMemDB can be constructed without error.
func TestNew(t *testing.T) {
	if _, err := New(); err != nil {
		t.Fatal(err)
	}
}
// TestCacheMemDB_Get inserts one fully populated Index and verifies that the
// same entry is retrievable through every non-prefix index, and that invalid
// index names and cache misses are handled.
func TestCacheMemDB_Get(t *testing.T) {
	cache, err := New()
	if err != nil {
		t.Fatal(err)
	}

	// Test invalid index name
	_, err = cache.Get("foo", "bar")
	if err == nil {
		t.Fatal("expected error")
	}

	// Test on empty cache
	index, err := cache.Get(IndexNameID, "foo")
	if err != nil {
		t.Fatal(err)
	}
	if index != nil {
		t.Fatalf("expected nil index, got: %v", index)
	}

	// Populate cache
	in := &Index{
		ID:            "test_id",
		Namespace:     "test_ns/",
		RequestPath:   "/v1/request/path",
		Token:         "test_token",
		TokenAccessor: "test_accessor",
		Lease:         "test_lease",
		Response:      []byte("hello world"),
	}

	if err := cache.Set(in); err != nil {
		t.Fatal(err)
	}

	// Each case looks the entry up through a different index; all must yield
	// the same stored value.
	testCases := []struct {
		name        string
		indexName   string
		indexValues []interface{}
	}{
		{
			"by_index_id",
			"id",
			[]interface{}{in.ID},
		},
		{
			"by_request_path",
			"request_path",
			[]interface{}{in.Namespace, in.RequestPath},
		},
		{
			"by_lease",
			"lease",
			[]interface{}{in.Lease},
		},
		{
			"by_token",
			"token",
			[]interface{}{in.Token},
		},
		{
			"by_token_accessor",
			"token_accessor",
			[]interface{}{in.TokenAccessor},
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			out, err := cache.Get(tc.indexName, tc.indexValues...)
			if err != nil {
				t.Fatal(err)
			}
			if diff := deep.Equal(in, out); diff != nil {
				t.Fatal(diff)
			}
		})
	}
}
// TestCacheMemDB_GetByPrefix inserts two entries sharing common prefixes and
// verifies prefix scans over each prefix-capable index return both, in
// insertion-key order.
func TestCacheMemDB_GetByPrefix(t *testing.T) {
	cache, err := New()
	if err != nil {
		t.Fatal(err)
	}

	// Test invalid index name
	_, err = cache.GetByPrefix("foo", "bar", "baz")
	if err == nil {
		t.Fatal("expected error")
	}

	// Test on empty cache
	index, err := cache.GetByPrefix(IndexNameRequestPath, "foo", "bar")
	if err != nil {
		t.Fatal(err)
	}
	if index != nil {
		t.Fatalf("expected nil index, got: %v", index)
	}

	// Populate cache
	in := &Index{
		ID:            "test_id",
		Namespace:     "test_ns/",
		RequestPath:   "/v1/request/path/1",
		Token:         "test_token",
		TokenParent:   "test_token_parent",
		TokenAccessor: "test_accessor",
		Lease:         "path/to/test_lease/1",
		LeaseToken:    "test_lease_token",
		Response:      []byte("hello world"),
	}

	if err := cache.Set(in); err != nil {
		t.Fatal(err)
	}

	// Populate cache with a second entry sharing the same prefixes.
	in2 := &Index{
		ID:            "test_id_2",
		Namespace:     "test_ns/",
		RequestPath:   "/v1/request/path/2",
		Token:         "test_token2",
		TokenParent:   "test_token_parent",
		TokenAccessor: "test_accessor2",
		Lease:         "path/to/test_lease/2",
		LeaseToken:    "test_lease_token",
		Response:      []byte("hello world"),
	}

	if err := cache.Set(in2); err != nil {
		t.Fatal(err)
	}

	testCases := []struct {
		name        string
		indexName   string
		indexValues []interface{}
	}{
		{
			"by_request_path",
			"request_path",
			[]interface{}{"test_ns/", "/v1/request/path"},
		},
		{
			"by_lease",
			"lease",
			[]interface{}{"path/to/test_lease"},
		},
		{
			"by_token_parent",
			"token_parent",
			[]interface{}{"test_token_parent"},
		},
		{
			"by_lease_token",
			"lease_token",
			[]interface{}{"test_lease_token"},
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			// The expected order [in, in2] relies on go-memdb returning
			// entries in index-key order.
			out, err := cache.GetByPrefix(tc.indexName, tc.indexValues...)
			if err != nil {
				t.Fatal(err)
			}
			if diff := deep.Equal([]*Index{in, in2}, out); diff != nil {
				t.Fatal(diff)
			}
		})
	}
}
// TestCacheMemDB_Set covers Set's error handling: nil input, and inputs
// missing values for indexes that do not allow missing fields (e.g. ID),
// plus a fully populated success case.
func TestCacheMemDB_Set(t *testing.T) {
	cache, err := New()
	if err != nil {
		t.Fatal(err)
	}

	testCases := []struct {
		name    string
		index   *Index
		wantErr bool
	}{
		{
			"nil",
			nil,
			true,
		},
		{
			"empty_fields",
			&Index{},
			true,
		},
		{
			// Lease alone is insufficient: required indexed fields (ID, etc.)
			// are still empty, so the insert must fail.
			"missing_required_fields",
			&Index{
				Lease: "foo",
			},
			true,
		},
		{
			"all_fields",
			&Index{
				ID:            "test_id",
				Namespace:     "test_ns/",
				RequestPath:   "/v1/request/path",
				Token:         "test_token",
				TokenAccessor: "test_accessor",
				Lease:         "test_lease",
				RenewCtxInfo:  testContextInfo(),
			},
			false,
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			if err := cache.Set(tc.index); (err != nil) != tc.wantErr {
				t.Fatalf("CacheMemDB.Set() error = %v, wantErr = %v", err, tc.wantErr)
			}
		})
	}
}
// TestCacheMemDB_Evict verifies eviction through each supported index, plus
// error handling for empty/invalid index names and evicting from an empty
// cache.
func TestCacheMemDB_Evict(t *testing.T) {
	cache, err := New()
	if err != nil {
		t.Fatal(err)
	}

	// Test on empty cache
	if err := cache.Evict(IndexNameID, "foo"); err != nil {
		t.Fatal(err)
	}

	testIndex := &Index{
		ID:            "test_id",
		Namespace:     "test_ns/",
		RequestPath:   "/v1/request/path",
		Token:         "test_token",
		TokenAccessor: "test_token_accessor",
		Lease:         "test_lease",
		RenewCtxInfo:  testContextInfo(),
	}

	testCases := []struct {
		name        string
		indexName   string
		indexValues []interface{}
		insertIndex *Index
		wantErr     bool
	}{
		{
			"empty_params",
			"",
			[]interface{}{""},
			nil,
			true,
		},
		{
			"invalid_params",
			"foo",
			[]interface{}{"bar"},
			nil,
			true,
		},
		{
			"by_id",
			"id",
			[]interface{}{"test_id"},
			testIndex,
			false,
		},
		{
			"by_request_path",
			"request_path",
			[]interface{}{"test_ns/", "/v1/request/path"},
			testIndex,
			false,
		},
		{
			"by_token",
			"token",
			[]interface{}{"test_token"},
			testIndex,
			false,
		},
		{
			// Bug fix: this previously evicted "test_accessor", which does
			// not match testIndex.TokenAccessor ("test_token_accessor"), so
			// accessor-based eviction passed vacuously without ever deleting
			// the entry.
			"by_token_accessor",
			"token_accessor",
			[]interface{}{"test_token_accessor"},
			testIndex,
			false,
		},
		{
			"by_lease",
			"lease",
			[]interface{}{"test_lease"},
			testIndex,
			false,
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			if tc.insertIndex != nil {
				if err := cache.Set(tc.insertIndex); err != nil {
					t.Fatal(err)
				}
			}

			if err := cache.Evict(tc.indexName, tc.indexValues...); (err != nil) != tc.wantErr {
				t.Fatal(err)
			}

			// Verify that the cache doesn't contain the entry any more
			index, err := cache.Get(tc.indexName, tc.indexValues...)
			if (err != nil) != tc.wantErr {
				t.Fatal(err)
			}
			if index != nil {
				t.Fatalf("expected nil entry, got = %#v", index)
			}
		})
	}
}
// TestCacheMemDB_Flush verifies that Flush drops all previously stored
// entries.
func TestCacheMemDB_Flush(t *testing.T) {
	cache, err := New()
	if err != nil {
		t.Fatal(err)
	}

	// Store a single entry.
	entry := &Index{
		ID:          "test_id",
		Token:       "test_token",
		Lease:       "test_lease",
		Namespace:   "test_ns/",
		RequestPath: "/v1/request/path",
		Response:    []byte("hello world"),
	}
	if err := cache.Set(entry); err != nil {
		t.Fatal(err)
	}

	// Reset the cache.
	if err := cache.Flush(); err != nil {
		t.Fatal(err)
	}

	// The entry must be gone after the flush.
	got, err := cache.Get(IndexNameID, "test_id")
	if err != nil {
		t.Fatal(err)
	}
	if got != nil {
		t.Fatalf("expected cache to be empty, got = %v", got)
	}
}

97
command/agent/cache/cachememdb/index.go vendored Normal file
View file

@ -0,0 +1,97 @@
package cachememdb
import "context"
// ContextInfo holds a cancelable context (and an optional done channel) that
// controls the lifetime of the renewal goroutine attached to a cached entry.
type ContextInfo struct {
	// Ctx is the context governing the entry's renewal goroutine.
	Ctx context.Context
	// CancelFunc cancels Ctx, halting renewal.
	CancelFunc context.CancelFunc
	// DoneCh, when closed, signals the renewal goroutine to stop and evict
	// this entry without cancelling contexts derived from Ctx.
	DoneCh chan struct{}
}
// Index holds the response to be cached along with multiple other values that
// serve as pointers to refer back to this index.
type Index struct {
	// ID is a value that uniquely represents the request held by this
	// index. This is computed by serializing and hashing the response object.
	// Required: true, Unique: true
	ID string

	// Token is the token that fetched the response held by this index
	// Required: true, Unique: true
	Token string

	// TokenParent is the parent token of the token held by this index
	// Required: false, Unique: false
	TokenParent string

	// TokenAccessor is the accessor of the token being cached in this index
	// Required: true, Unique: true
	TokenAccessor string

	// Namespace is the namespace that was provided in the request path as the
	// Vault namespace to query
	// NOTE(review): effectively required — it is part of the compound
	// request_path index, which does not allow missing values.
	Namespace string

	// RequestPath is the path of the request that resulted in the response
	// held by this index.
	// Required: true, Unique: false
	RequestPath string

	// Lease is the identifier of the lease in Vault, that belongs to the
	// response held by this index.
	// Required: false, Unique: true
	Lease string

	// LeaseToken is the identifier of the token that created the lease held by
	// this index.
	// Required: false, Unique: false
	LeaseToken string

	// Response is the serialized response object that the agent is caching.
	Response []byte

	// RenewCtxInfo holds the context and the corresponding cancel func for the
	// goroutine that manages the renewal of the secret belonging to the
	// response in this index.
	RenewCtxInfo *ContextInfo
}
// IndexName enumerates the cache's index identifiers.
// NOTE(review): this type appears unused — the IndexName* constants below are
// untyped strings, not IndexName values. Kept as-is to avoid changing the
// exported API; consider removing it or typing the constants.
type IndexName uint32

// Names of the indexes declared in the cachememdb schema (see newDB). These
// are the only values accepted by validIndexName.
const (
	// IndexNameID is the ID of the index constructed from the serialized request.
	IndexNameID = "id"

	// IndexNameLease is the lease of the index.
	IndexNameLease = "lease"

	// IndexNameRequestPath is the request path of the index.
	IndexNameRequestPath = "request_path"

	// IndexNameToken is the token of the index.
	IndexNameToken = "token"

	// IndexNameTokenAccessor is the token accessor of the index.
	IndexNameTokenAccessor = "token_accessor"

	// IndexNameTokenParent is the token parent of the index.
	IndexNameTokenParent = "token_parent"

	// IndexNameLeaseToken is the token that created the lease.
	IndexNameLeaseToken = "lease_token"
)
// validIndexName reports whether indexName names one of the indexes declared
// in the cache schema; it guards lookups and evictions against unknown index
// names.
func validIndexName(indexName string) bool {
	// A single multi-value case replaces the original cascade of empty
	// cases — same behavior, idiomatic form.
	switch indexName {
	case "id", "lease", "request_path", "token", "token_accessor", "token_parent", "lease_token":
		return true
	default:
		return false
	}
}

155
command/agent/cache/handler.go vendored Normal file
View file

@ -0,0 +1,155 @@
package cache
import (
"bufio"
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"io/ioutil"
"net/http"
"github.com/hashicorp/errwrap"
hclog "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/helper/consts"
vaulthttp "github.com/hashicorp/vault/http"
"github.com/hashicorp/vault/logical"
)
// Handler returns an http.Handler that forwards incoming requests through the
// given Proxier chain. When the request carries no Vault token and
// useAutoAuthToken is set, the auto-auth token from client is injected.
func Handler(ctx context.Context, logger hclog.Logger, proxier Proxier, useAutoAuthToken bool, client *api.Client) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		logger.Info("received request", "path", r.URL.Path, "method", r.Method)

		token := r.Header.Get(consts.AuthHeaderName)
		if token == "" && useAutoAuthToken {
			logger.Debug("using auto auth token")
			token = client.Token()
		}

		// Parse and reset body.
		reqBody, err := ioutil.ReadAll(r.Body)
		if err != nil {
			logger.Error("failed to read request body")
			respondError(w, http.StatusInternalServerError, errors.New("failed to read request body"))
			// Bug fix: previously this fell through and kept processing the
			// request (with an empty body) after already writing a 500.
			return
		}
		if r.Body != nil {
			r.Body.Close()
		}
		r.Body = ioutil.NopCloser(bytes.NewBuffer(reqBody))

		req := &SendRequest{
			Token:       token,
			Request:     r,
			RequestBody: reqBody,
		}

		resp, err := proxier.Send(ctx, req)
		if err != nil {
			respondError(w, http.StatusInternalServerError, errwrap.Wrapf("failed to get the response: {{err}}", err))
			return
		}

		// Strip the auto-auth token's id/accessor from lookup-self responses
		// so the agent's own token is never leaked to callers.
		err = processTokenLookupResponse(ctx, logger, useAutoAuthToken, client, req, resp)
		if err != nil {
			respondError(w, http.StatusInternalServerError, errwrap.Wrapf("failed to process token lookup response: {{err}}", err))
			return
		}

		defer resp.Response.Body.Close()

		copyHeader(w.Header(), resp.Response.Header)
		w.WriteHeader(resp.Response.StatusCode)
		io.Copy(w, resp.Response.Body)
		return
	})
}
// processTokenLookupResponse checks if the request was a token lookup-self.
// If the auto-auth token was used to perform the lookup-self, the identifier
// of the token and its accessor will be stripped off of the response before
// it is returned to the caller, so the agent's own token is not exposed.
func processTokenLookupResponse(ctx context.Context, logger hclog.Logger, useAutoAuthToken bool, client *api.Client, req *SendRequest, resp *SendResponse) error {
	// If auto-auth token is not being used, there is nothing to do.
	if !useAutoAuthToken {
		return nil
	}

	// If lookup responded with non 200 status, there is nothing to do.
	if resp.Response.StatusCode != http.StatusOK {
		return nil
	}

	// Strip-off namespace related information from the request and get the
	// relative path of the request.
	_, path := deriveNamespaceAndRevocationPath(req)
	if path == vaultPathTokenLookupSelf {
		logger.Info("stripping auto-auth token from the response", "path", req.Request.URL.Path, "method", req.Request.Method)
		secret, err := api.ParseSecret(bytes.NewBuffer(resp.ResponseBody))
		if err != nil {
			return fmt.Errorf("failed to parse token lookup response: %v", err)
		}
		if secret != nil && secret.Data != nil && secret.Data["id"] != nil {
			token, ok := secret.Data["id"].(string)
			if !ok {
				return fmt.Errorf("failed to type assert the token id in the response")
			}
			// Only strip the fields when the looked-up token is actually the
			// agent's own (auto-auth) token.
			if token == client.Token() {
				delete(secret.Data, "id")
				delete(secret.Data, "accessor")
			}
			bodyBytes, err := json.Marshal(secret)
			if err != nil {
				return err
			}
			if resp.Response.Body != nil {
				resp.Response.Body.Close()
			}
			resp.Response.Body = ioutil.NopCloser(bytes.NewBuffer(bodyBytes))
			resp.Response.ContentLength = int64(len(bodyBytes))

			// Serialize and re-read the response so the http.Response object
			// stays consistent with the modified body.
			var respBytes bytes.Buffer
			err = resp.Response.Write(&respBytes)
			if err != nil {
				return fmt.Errorf("failed to serialize the updated response: %v", err)
			}

			updatedResponse, err := http.ReadResponse(bufio.NewReader(bytes.NewReader(respBytes.Bytes())), nil)
			if err != nil {
				return fmt.Errorf("failed to deserialize the updated response: %v", err)
			}

			resp.Response = &api.Response{
				Response: updatedResponse,
			}
			resp.ResponseBody = bodyBytes
		}
	}

	return nil
}
func copyHeader(dst, src http.Header) {
for k, vv := range src {
for _, v := range vv {
dst.Add(k, v)
}
}
}
// respondError writes err as a JSON-encoded Vault error response with the
// given HTTP status, after letting the logical package adjust the status for
// special error types.
func respondError(w http.ResponseWriter, status int, err error) {
	logical.AdjustErrorStatusCode(&status, err)

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(status)

	errResp := &vaulthttp.ErrorResponse{Errors: make([]string, 0, 1)}
	if err != nil {
		errResp.Errors = append(errResp.Errors, err.Error())
	}
	json.NewEncoder(w).Encode(errResp)
}

810
command/agent/cache/lease_cache.go vendored Normal file
View file

@ -0,0 +1,810 @@
package cache
import (
"bufio"
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"io/ioutil"
"net/http"
"strings"
"github.com/hashicorp/errwrap"
hclog "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/api"
cachememdb "github.com/hashicorp/vault/command/agent/cache/cachememdb"
"github.com/hashicorp/vault/helper/consts"
"github.com/hashicorp/vault/helper/jsonutil"
"github.com/hashicorp/vault/helper/namespace"
nshelper "github.com/hashicorp/vault/helper/namespace"
)
// Vault API paths that receive special handling in the lease cache:
// revocation endpoints bypass caching and trigger eviction, token create
// marks child-token handling, and lookup-self is inspected for auto-auth
// token stripping.
const (
	vaultPathTokenCreate         = "/v1/auth/token/create"
	vaultPathTokenRevoke         = "/v1/auth/token/revoke"
	vaultPathTokenRevokeSelf     = "/v1/auth/token/revoke-self"
	vaultPathTokenRevokeAccessor = "/v1/auth/token/revoke-accessor"
	vaultPathTokenRevokeOrphan   = "/v1/auth/token/revoke-orphan"
	vaultPathTokenLookupSelf     = "/v1/auth/token/lookup-self"
	vaultPathLeaseRevoke         = "/v1/sys/leases/revoke"
	vaultPathLeaseRevokeForce    = "/v1/sys/leases/revoke-force"
	vaultPathLeaseRevokePrefix   = "/v1/sys/leases/revoke-prefix"
)
var (
	// contextIndexID is the context key under which a cache index ID is
	// stored in renewer contexts (see startRenewing/updateResponse).
	contextIndexID = contextIndex{}

	// errInvalidType is returned when a cache-clear request names an unknown
	// clear type; HandleCacheClear maps it to a 400 response.
	errInvalidType = errors.New("invalid type provided")

	// revocationPaths lists the namespace-relative (sans "/v1") paths that
	// trigger cache eviction instead of caching.
	revocationPaths = []string{
		strings.TrimPrefix(vaultPathTokenRevoke, "/v1"),
		strings.TrimPrefix(vaultPathTokenRevokeSelf, "/v1"),
		strings.TrimPrefix(vaultPathTokenRevokeAccessor, "/v1"),
		strings.TrimPrefix(vaultPathTokenRevokeOrphan, "/v1"),
		strings.TrimPrefix(vaultPathLeaseRevoke, "/v1"),
		strings.TrimPrefix(vaultPathLeaseRevokeForce, "/v1"),
		strings.TrimPrefix(vaultPathLeaseRevokePrefix, "/v1"),
	}
)
// contextIndex is an unexported key type for context values, avoiding
// collisions with keys from other packages.
type contextIndex struct{}

// cacheClearRequest is the JSON body accepted by the cache-clear endpoint.
type cacheClearRequest struct {
	// Type selects which index to clear by (e.g. "token", "lease",
	// "request_path", "token_accessor").
	Type string `json:"type"`
	// Value is the index value to match.
	Value string `json:"value"`
	// Namespace qualifies request_path clears.
	Namespace string `json:"namespace"`
}
// LeaseCache is an implementation of Proxier that handles
// the caching of responses. It passes the incoming request
// to an underlying Proxier implementation.
type LeaseCache struct {
	// proxier is the next hop that actually talks to Vault.
	proxier Proxier
	logger hclog.Logger
	// db stores the cached responses and their renewal contexts.
	db *cachememdb.CacheMemDB
	// baseCtxInfo is the root context from which all renewal contexts derive;
	// cancelling it stops every renewer.
	baseCtxInfo *ContextInfo
}
// LeaseCacheConfig is the configuration for initializing a new
// LeaseCache. All fields are required.
type LeaseCacheConfig struct {
	BaseContext context.Context
	Proxier     Proxier
	Logger      hclog.Logger
}
// ContextInfo holds a derived context and cancelFunc pair.
// NOTE(review): this duplicates cachememdb.ContextInfo field-for-field;
// consider consolidating on one type.
type ContextInfo struct {
	Ctx context.Context
	CancelFunc context.CancelFunc
	// DoneCh, when closed, tells a renewal goroutine to stop and evict its
	// entry without cancelling derived contexts.
	DoneCh chan struct{}
}
// NewLeaseCache creates a new instance of a LeaseCache. All fields of conf
// are required.
func NewLeaseCache(conf *LeaseCacheConfig) (*LeaseCache, error) {
	if conf == nil {
		return nil, errors.New("nil configuration provided")
	}

	// BaseContext is validated here too: context.WithCancel panics on a nil
	// parent, so fail early with a descriptive error instead.
	if conf.Proxier == nil || conf.Logger == nil || conf.BaseContext == nil {
		return nil, fmt.Errorf("missing configuration required params: %v", conf)
	}

	db, err := cachememdb.New()
	if err != nil {
		return nil, err
	}

	// Create a base context for the lease cache layer
	baseCtx, baseCancelFunc := context.WithCancel(conf.BaseContext)
	baseCtxInfo := &ContextInfo{
		Ctx:        baseCtx,
		CancelFunc: baseCancelFunc,
	}

	return &LeaseCache{
		proxier:     conf.Proxier,
		logger:      conf.Logger,
		db:          db,
		baseCtxInfo: baseCtxInfo,
	}, nil
}
// Send performs a cache lookup on the incoming request. If it's a cache hit,
// it will return the cached response, otherwise it will delegate to the
// underlying Proxier and cache the received response. Cacheable responses
// (renewable secrets and tokens) also get a background renewal goroutine
// whose context cancellation evicts the entry.
func (c *LeaseCache) Send(ctx context.Context, req *SendRequest) (*SendResponse, error) {
	// Compute the index ID
	id, err := computeIndexID(req)
	if err != nil {
		c.logger.Error("failed to compute cache key", "error", err)
		return nil, err
	}

	// Check if the response for this request is already in the cache
	index, err := c.db.Get(cachememdb.IndexNameID, id)
	if err != nil {
		return nil, err
	}

	// Cached request is found, deserialize the response and return early
	if index != nil {
		c.logger.Debug("returning cached response", "path", req.Request.URL.Path)

		reader := bufio.NewReader(bytes.NewReader(index.Response))
		resp, err := http.ReadResponse(reader, nil)
		if err != nil {
			c.logger.Error("failed to deserialize response", "error", err)
			return nil, err
		}

		return &SendResponse{
			Response: &api.Response{
				Response: resp,
			},
			ResponseBody: index.Response,
		}, nil
	}

	c.logger.Debug("forwarding request", "path", req.Request.URL.Path, "method", req.Request.Method)

	// Pass the request down and get a response
	resp, err := c.proxier.Send(ctx, req)
	if err != nil {
		return nil, err
	}

	// Get the namespace from the request header
	namespace := req.Request.Header.Get(consts.NamespaceHeaderName)
	// We need to populate an empty value since go-memdb will skip over indexes
	// that contain empty values.
	if namespace == "" {
		namespace = "root/"
	}

	// Build the index to cache based on the response received
	index = &cachememdb.Index{
		ID:          id,
		Namespace:   namespace,
		RequestPath: req.Request.URL.Path,
	}

	secret, err := api.ParseSecret(bytes.NewBuffer(resp.ResponseBody))
	if err != nil {
		c.logger.Error("failed to parse response as secret", "error", err)
		return nil, err
	}

	isRevocation, err := c.handleRevocationRequest(ctx, req, resp)
	if err != nil {
		c.logger.Error("failed to process the response", "error", err)
		return nil, err
	}

	// If this is a revocation request, do not go through cache logic.
	if isRevocation {
		return resp, nil
	}

	// Fast path for responses with no secrets
	if secret == nil {
		c.logger.Debug("pass-through response; no secret in response", "path", req.Request.URL.Path, "method", req.Request.Method)
		return resp, nil
	}

	// Short-circuit if the secret is not renewable
	tokenRenewable, err := secret.TokenIsRenewable()
	if err != nil {
		c.logger.Error("failed to parse renewable param", "error", err)
		return nil, err
	}
	if !secret.Renewable && !tokenRenewable {
		c.logger.Debug("pass-through response; secret not renewable", "path", req.Request.URL.Path, "method", req.Request.Method)
		return resp, nil
	}

	// Decide which context the renewal goroutine derives from: a lease's
	// renewer derives from its owning token's context; a new token's renewer
	// derives from its parent token's context (or the base context).
	var renewCtxInfo *ContextInfo
	switch {
	case secret.LeaseID != "":
		c.logger.Debug("processing lease response", "path", req.Request.URL.Path, "method", req.Request.Method)
		entry, err := c.db.Get(cachememdb.IndexNameToken, req.Token)
		if err != nil {
			return nil, err
		}
		// If the lease belongs to a token that is not managed by the agent,
		// return the response without caching it.
		if entry == nil {
			c.logger.Debug("pass-through lease response; token not managed by agent", "path", req.Request.URL.Path, "method", req.Request.Method)
			return resp, nil
		}

		// Derive a context for renewal using the token's context
		newCtxInfo := new(ContextInfo)
		newCtxInfo.Ctx, newCtxInfo.CancelFunc = context.WithCancel(entry.RenewCtxInfo.Ctx)
		newCtxInfo.DoneCh = make(chan struct{})
		renewCtxInfo = newCtxInfo

		index.Lease = secret.LeaseID
		index.LeaseToken = req.Token

	case secret.Auth != nil:
		c.logger.Debug("processing auth response", "path", req.Request.URL.Path, "method", req.Request.Method)
		isNonOrphanNewToken := strings.HasPrefix(req.Request.URL.Path, vaultPathTokenCreate) && resp.Response.StatusCode == http.StatusOK && !secret.Auth.Orphan

		// If the new token is a result of token creation endpoints (not from
		// login endpoints), and if its a non-orphan, then the new token's
		// context should be derived from the context of the parent token.
		var parentCtx context.Context
		if isNonOrphanNewToken {
			entry, err := c.db.Get(cachememdb.IndexNameToken, req.Token)
			if err != nil {
				return nil, err
			}
			// If parent token is not managed by the agent, child shouldn't be
			// either.
			if entry == nil {
				c.logger.Debug("pass-through auth response; parent token not managed by agent", "path", req.Request.URL.Path, "method", req.Request.Method)
				return resp, nil
			}

			c.logger.Debug("setting parent context", "path", req.Request.URL.Path, "method", req.Request.Method)
			parentCtx = entry.RenewCtxInfo.Ctx

			entry.TokenParent = req.Token
		}

		renewCtxInfo = c.createCtxInfo(parentCtx, secret.Auth.ClientToken)
		index.Token = secret.Auth.ClientToken
		index.TokenAccessor = secret.Auth.Accessor

	default:
		// We shouldn't be hitting this, but will err on the side of caution and
		// simply proxy.
		c.logger.Debug("pass-through response; secret without lease and token", "path", req.Request.URL.Path, "method", req.Request.Method)
		return resp, nil
	}

	// Serialize the response to store it in the cached index
	var respBytes bytes.Buffer
	err = resp.Response.Write(&respBytes)
	if err != nil {
		c.logger.Error("failed to serialize response", "error", err)
		return nil, err
	}

	// Reset the response body for upper layers to read
	if resp.Response.Body != nil {
		resp.Response.Body.Close()
	}
	resp.Response.Body = ioutil.NopCloser(bytes.NewBuffer(resp.ResponseBody))

	// Set the index's Response
	index.Response = respBytes.Bytes()

	// Store the index ID in the renewer context
	renewCtx := context.WithValue(renewCtxInfo.Ctx, contextIndexID, index.ID)

	// Store the renewer context in the index
	index.RenewCtxInfo = &cachememdb.ContextInfo{
		Ctx:        renewCtx,
		CancelFunc: renewCtxInfo.CancelFunc,
		DoneCh:     renewCtxInfo.DoneCh,
	}

	// Store the index in the cache
	c.logger.Debug("storing response into the cache", "path", req.Request.URL.Path, "method", req.Request.Method)
	err = c.db.Set(index)
	if err != nil {
		c.logger.Error("failed to cache the proxied response", "error", err)
		return nil, err
	}

	// Start renewing the secret in the response
	go c.startRenewing(renewCtx, index, req, secret)

	return resp, nil
}
// createCtxInfo derives a new cancelable renewal context from ctx, falling
// back to the lease cache's base context when ctx is nil. The token argument
// is currently unused but kept for interface stability.
func (c *LeaseCache) createCtxInfo(ctx context.Context, token string) *ContextInfo {
	parent := ctx
	if parent == nil {
		parent = c.baseCtxInfo.Ctx
	}

	info := &ContextInfo{
		DoneCh: make(chan struct{}),
	}
	info.Ctx, info.CancelFunc = context.WithCancel(parent)
	return info
}
// startRenewing runs (in its own goroutine) the renewal loop for a cached
// secret. On any exit path — context cancellation, renewal completion/error,
// or a DoneCh close — the cached index is evicted via the deferred cleanup.
func (c *LeaseCache) startRenewing(ctx context.Context, index *cachememdb.Index, req *SendRequest, secret *api.Secret) {
	defer func() {
		id := ctx.Value(contextIndexID).(string)
		c.logger.Debug("evicting index from cache", "id", id, "path", req.Request.URL.Path, "method", req.Request.Method)
		err := c.db.Evict(cachememdb.IndexNameID, id)
		if err != nil {
			c.logger.Error("failed to evict index", "id", id, "error", err)
			return
		}
	}()

	// NOTE(review): this builds a fresh client from api.DefaultConfig()
	// (environment-derived address); confirm this matches the agent's
	// configured Vault address in all deployments.
	client, err := api.NewClient(api.DefaultConfig())
	if err != nil {
		c.logger.Error("failed to create API client in the renewer", "error", err)
		return
	}
	client.SetToken(req.Token)
	client.SetHeaders(req.Request.Header)

	renewer, err := client.NewRenewer(&api.RenewerInput{
		Secret: secret,
	})
	if err != nil {
		c.logger.Error("failed to create secret renewer", "error", err)
		return
	}

	c.logger.Debug("initiating renewal", "path", req.Request.URL.Path, "method", req.Request.Method)
	go renewer.Renew()
	defer renewer.Stop()

	for {
		select {
		case <-ctx.Done():
			// This is the case which captures context cancellations from token
			// and leases. Since all the contexts are derived from the agent's
			// context, this will also cover the shutdown scenario.
			c.logger.Debug("context cancelled; stopping renewer", "path", req.Request.URL.Path)
			return
		case err := <-renewer.DoneCh():
			// This case covers renewal completion and renewal errors
			if err != nil {
				c.logger.Error("failed to renew secret", "error", err)
				return
			}
			c.logger.Debug("renewal halted; evicting from cache", "path", req.Request.URL.Path)
			return
		case renewal := <-renewer.RenewCh():
			// This case captures secret renewals. Renewed secret is updated in
			// the cached index.
			c.logger.Debug("renewal received; updating cache", "path", req.Request.URL.Path)
			err = c.updateResponse(ctx, renewal)
			if err != nil {
				c.logger.Error("failed to handle renewal", "error", err)
				return
			}
		case <-index.RenewCtxInfo.DoneCh:
			// This case indicates the renewal process to shutdown and evict
			// the cache entry. This is triggered when a specific secret
			// renewal needs to be killed without affecting any of the derived
			// context renewals.
			c.logger.Debug("done channel closed")
			return
		}
	}
}
// updateResponse rewrites the cached response body for the index identified
// by the renewer context with the renewed secret, keeping the serialized
// http.Response representation consistent.
func (c *LeaseCache) updateResponse(ctx context.Context, renewal *api.RenewOutput) error {
	id := ctx.Value(contextIndexID).(string)

	// Get the cached index using the id in the context
	index, err := c.db.Get(cachememdb.IndexNameID, id)
	if err != nil {
		return err
	}
	if index == nil {
		return fmt.Errorf("missing cache entry for id: %q", id)
	}

	// Read the response from the index
	resp, err := http.ReadResponse(bufio.NewReader(bytes.NewReader(index.Response)), nil)
	if err != nil {
		c.logger.Error("failed to deserialize response", "error", err)
		return err
	}

	// Update the body in the response by the renewed secret
	bodyBytes, err := json.Marshal(renewal.Secret)
	if err != nil {
		return err
	}
	if resp.Body != nil {
		resp.Body.Close()
	}
	resp.Body = ioutil.NopCloser(bytes.NewBuffer(bodyBytes))
	resp.ContentLength = int64(len(bodyBytes))

	// Serialize the response
	var respBytes bytes.Buffer
	err = resp.Write(&respBytes)
	if err != nil {
		c.logger.Error("failed to serialize updated response", "error", err)
		return err
	}

	// Update the response in the index and set it in the cache
	index.Response = respBytes.Bytes()
	err = c.db.Set(index)
	if err != nil {
		c.logger.Error("failed to cache the proxied response", "error", err)
		return err
	}

	return nil
}
// computeIndexID results in a value that uniquely identifies a request
// received by the agent. It does so by SHA256 hashing the serialized request
// object containing the request path, query parameters and body parameters.
func computeIndexID(req *SendRequest) (string, error) {
	var buf bytes.Buffer

	// Serialize the full request into the buffer.
	if err := req.Request.Write(&buf); err != nil {
		return "", fmt.Errorf("failed to serialize request: %v", err)
	}

	// Request.Write consumed the body; restore it for downstream readers.
	req.Request.Body = ioutil.NopCloser(bytes.NewBuffer(req.RequestBody))

	// Fold in the explicit token, since auto-auth'ed requests carry it in
	// SendRequest.Token rather than in a request header.
	buf.WriteString(req.Token)

	sum := sha256.Sum256(buf.Bytes())
	return hex.EncodeToString(sum[:]), nil
}
// HandleCacheClear returns a handlerFunc that can perform cache clearing
// operations, decoding a cacheClearRequest from the body and delegating to
// handleCacheClear.
func (c *LeaseCache) HandleCacheClear(ctx context.Context) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Decode the clear request; an empty body is reported explicitly.
		req := new(cacheClearRequest)
		if err := jsonutil.DecodeJSONFromReader(r.Body, req); err != nil {
			if err == io.EOF {
				err = errors.New("empty JSON provided")
			}
			respondError(w, http.StatusBadRequest, errwrap.Wrapf("failed to parse JSON input: {{err}}", err))
			return
		}

		c.logger.Debug("received cache-clear request", "type", req.Type, "namespace", req.Namespace, "value", req.Value)

		if err := c.handleCacheClear(ctx, req.Type, req.Namespace, req.Value); err != nil {
			// Default to 500 on error, unless the user provided an invalid type,
			// which would then be a 400.
			httpStatus := http.StatusInternalServerError
			if err == errInvalidType {
				httpStatus = http.StatusBadRequest
			}
			respondError(w, httpStatus, errwrap.Wrapf("failed to clear cache: {{err}}", err))
			return
		}
	})
}
// handleCacheClear clears cached entries matching the given clear type and
// value(s). Supported types are "request_path", "token", "token_accessor",
// "lease" and "all". Eviction is performed by cancelling the renewer
// contexts attached to the matching indexes (or the base context for "all");
// the goroutines watching those contexts then remove the entries.
func (c *LeaseCache) handleCacheClear(ctx context.Context, clearType string, clearValues ...interface{}) error {
	if len(clearValues) == 0 {
		return errors.New("no value(s) provided to clear corresponding cache entries")
	}

	// The value that we want to clear, for most cases, is the last one provided.
	clearValue, ok := clearValues[len(clearValues)-1].(string)
	if !ok {
		// Report the offending value itself; clearValue only holds the string
		// zero value when the assertion above fails.
		return fmt.Errorf("unable to convert %v to type string", clearValues[len(clearValues)-1])
	}

	switch clearType {
	case "request_path":
		// For this particular case, we need to ensure that there are 2 provided
		// indexers for the proper lookup.
		if len(clearValues) != 2 {
			return fmt.Errorf("clearing cache by request path requires 2 indexers, got %d", len(clearValues))
		}

		// The first value provided for this case will be the namespace, but if it's
		// an empty value we need to overwrite it with "root/" to ensure proper
		// cache lookup.
		if clearValues[0].(string) == "" {
			clearValues[0] = "root/"
		}

		// Find all the cached entries which has the given request path and
		// cancel the contexts of all the respective renewers
		indexes, err := c.db.GetByPrefix(clearType, clearValues...)
		if err != nil {
			return err
		}
		for _, index := range indexes {
			index.RenewCtxInfo.CancelFunc()
		}

	case "token":
		if clearValue == "" {
			return nil
		}

		// Get the context for the given token and cancel its context
		index, err := c.db.Get(cachememdb.IndexNameToken, clearValue)
		if err != nil {
			return err
		}
		if index == nil {
			return nil
		}

		c.logger.Debug("cancelling context of index attached to token")

		index.RenewCtxInfo.CancelFunc()

	case "token_accessor", "lease":
		// Get the cached index and cancel the corresponding renewer context
		index, err := c.db.Get(clearType, clearValue)
		if err != nil {
			return err
		}
		if index == nil {
			return nil
		}

		c.logger.Debug("cancelling context of index attached to accessor")

		index.RenewCtxInfo.CancelFunc()

	case "all":
		// Cancel the base context which triggers all the goroutines to
		// stop and evict entries from cache.
		c.logger.Debug("cancelling base context")
		c.baseCtxInfo.CancelFunc()

		// Reset the base context
		baseCtx, baseCancel := context.WithCancel(ctx)
		c.baseCtxInfo = &ContextInfo{
			Ctx:        baseCtx,
			CancelFunc: baseCancel,
		}

		// Reset the memdb instance
		if err := c.db.Flush(); err != nil {
			return err
		}

	default:
		return errInvalidType
	}

	c.logger.Debug("successfully cleared matching cache entries")

	return nil
}
// handleRevocationRequest checks whether the originating request is a
// revocation request, and if so performs the applicable cache cleanups.
// Returns true if this is a revocation request.
func (c *LeaseCache) handleRevocationRequest(ctx context.Context, req *SendRequest, resp *SendResponse) (bool, error) {
	// Lease and token revocations return 204's on success. Fast-path if that's
	// not the case.
	if resp.Response.StatusCode != http.StatusNoContent {
		return false, nil
	}

	_, path := deriveNamespaceAndRevocationPath(req)

	switch {
	case path == vaultPathTokenRevoke:
		// Get the token from the request body
		jsonBody := map[string]interface{}{}
		if err := json.Unmarshal(req.RequestBody, &jsonBody); err != nil {
			return false, err
		}
		tokenRaw, ok := jsonBody["token"]
		if !ok {
			return false, fmt.Errorf("failed to get token from request body")
		}
		token, ok := tokenRaw.(string)
		if !ok {
			return false, fmt.Errorf("expected token in the request body to be string")
		}

		// Clear the cache entry associated with the token and all the other
		// entries belonging to the leases derived from this token.
		if err := c.handleCacheClear(ctx, "token", token); err != nil {
			return false, err
		}

	case path == vaultPathTokenRevokeSelf:
		// Clear the cache entry associated with the token and all the other
		// entries belonging to the leases derived from this token.
		if err := c.handleCacheClear(ctx, "token", req.Token); err != nil {
			return false, err
		}

	case path == vaultPathTokenRevokeAccessor:
		jsonBody := map[string]interface{}{}
		if err := json.Unmarshal(req.RequestBody, &jsonBody); err != nil {
			return false, err
		}
		accessorRaw, ok := jsonBody["accessor"]
		if !ok {
			return false, fmt.Errorf("failed to get accessor from request body")
		}
		accessor, ok := accessorRaw.(string)
		if !ok {
			return false, fmt.Errorf("expected accessor in the request body to be string")
		}

		if err := c.handleCacheClear(ctx, "token_accessor", accessor); err != nil {
			return false, err
		}

	case path == vaultPathTokenRevokeOrphan:
		jsonBody := map[string]interface{}{}
		if err := json.Unmarshal(req.RequestBody, &jsonBody); err != nil {
			return false, err
		}
		tokenRaw, ok := jsonBody["token"]
		if !ok {
			return false, fmt.Errorf("failed to get token from request body")
		}
		token, ok := tokenRaw.(string)
		if !ok {
			return false, fmt.Errorf("expected token in the request body to be string")
		}

		// Kill the renewers of all the leases attached to the revoked token
		indexes, err := c.db.GetByPrefix(cachememdb.IndexNameLeaseToken, token)
		if err != nil {
			return false, err
		}
		for _, index := range indexes {
			index.RenewCtxInfo.CancelFunc()
		}

		// Kill the renewer of the revoked token
		index, err := c.db.Get(cachememdb.IndexNameToken, token)
		if err != nil {
			return false, err
		}
		if index == nil {
			return true, nil
		}

		// Indicate the renewer goroutine for this index to return. This will
		// not affect the child tokens because the context is not getting
		// cancelled.
		close(index.RenewCtxInfo.DoneCh)

		// Clear the parent references of the revoked token in the entries
		// belonging to the child tokens of the revoked token.
		indexes, err = c.db.GetByPrefix(cachememdb.IndexNameTokenParent, token)
		if err != nil {
			return false, err
		}
		for _, index := range indexes {
			index.TokenParent = ""
			err = c.db.Set(index)
			if err != nil {
				c.logger.Error("failed to persist index", "error", err)
				return false, err
			}
		}

	case path == vaultPathLeaseRevoke:
		// TODO: Should lease present in the URL itself be considered here?
		// Get the lease from the request body
		jsonBody := map[string]interface{}{}
		if err := json.Unmarshal(req.RequestBody, &jsonBody); err != nil {
			return false, err
		}
		leaseIDRaw, ok := jsonBody["lease_id"]
		if !ok {
			return false, fmt.Errorf("failed to get lease_id from request body")
		}
		leaseID, ok := leaseIDRaw.(string)
		if !ok {
			// Fixed error message: was missing the word "in".
			return false, fmt.Errorf("expected lease_id in the request body to be string")
		}

		if err := c.handleCacheClear(ctx, "lease", leaseID); err != nil {
			return false, err
		}

	case strings.HasPrefix(path, vaultPathLeaseRevokeForce):
		// Trim the URL path to get the request path prefix
		prefix := strings.TrimPrefix(path, vaultPathLeaseRevokeForce)
		// Get all the cache indexes that use the request path containing the
		// prefix and cancel the renewer context of each.
		indexes, err := c.db.GetByPrefix(cachememdb.IndexNameLease, prefix)
		if err != nil {
			return false, err
		}

		_, tokenNSID := namespace.SplitIDFromString(req.Token)
		for _, index := range indexes {
			_, leaseNSID := namespace.SplitIDFromString(index.Lease)
			// Only evict leases that match the token's namespace
			if tokenNSID == leaseNSID {
				index.RenewCtxInfo.CancelFunc()
			}
		}

	case strings.HasPrefix(path, vaultPathLeaseRevokePrefix):
		// Trim the URL path to get the request path prefix
		prefix := strings.TrimPrefix(path, vaultPathLeaseRevokePrefix)
		// Get all the cache indexes that use the request path containing the
		// prefix and cancel the renewer context of each.
		indexes, err := c.db.GetByPrefix(cachememdb.IndexNameLease, prefix)
		if err != nil {
			return false, err
		}

		_, tokenNSID := namespace.SplitIDFromString(req.Token)
		for _, index := range indexes {
			_, leaseNSID := namespace.SplitIDFromString(index.Lease)
			// Only evict leases that match the token's namespace
			if tokenNSID == leaseNSID {
				index.RenewCtxInfo.CancelFunc()
			}
		}

	default:
		return false, nil
	}

	c.logger.Debug("triggered caching eviction from revocation request")

	return true, nil
}
// deriveNamespaceAndRevocationPath returns the namespace and the relative
// path for revocation requests.
//
// If the path contains a namespace but is not a revocation path, it is
// returned as-is, since there is no way to tell where the namespace ends and
// where the request path begins purely based off a string.
//
// Case 1: /v1/ns1/leases/revoke -> ns1/, /v1/leases/revoke
// Case 2: ns1/ /v1/leases/revoke -> ns1/, /v1/leases/revoke
// Case 3: /v1/ns1/foo/bar -> root/, /v1/ns1/foo/bar
// Case 4: ns1/ /v1/foo/bar -> ns1/, /v1/foo/bar
func deriveNamespaceAndRevocationPath(req *SendRequest) (string, string) {
	ns := "root/"
	if header := req.Request.Header.Get(consts.NamespaceHeaderName); header != "" {
		ns = header
	}

	trimmedPath := strings.TrimPrefix(req.Request.URL.Path, "/v1")

	for _, revocationPath := range revocationPaths {
		// strings.Index (rather than an exact match) handles paths that carry
		// variables, e.g. /v1/lease/revoke-prefix/:prefix.
		idx := strings.Index(trimmedPath, revocationPath)
		if idx == -1 {
			continue
		}

		// Index 0 means a relative path with no namespace prepended, so there
		// is nothing more to derive.
		if idx == 0 {
			break
		}

		// Turn the leading "/ns1" segment into its canonical "ns1/" form.
		nsInPath := nshelper.Canonicalize(trimmedPath[:idx])

		// Replace the root namespace; otherwise join onto the header value.
		if ns == "root/" {
			ns = nsInPath
		} else {
			ns += nsInPath
		}

		return ns, fmt.Sprintf("/v1%s", trimmedPath[idx:])
	}

	return ns, fmt.Sprintf("/v1%s", trimmedPath)
}

507
command/agent/cache/lease_cache_test.go vendored Normal file
View file

@ -0,0 +1,507 @@
package cache
import (
"context"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"net/url"
"reflect"
"strings"
"testing"
"github.com/go-test/deep"
hclog "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/helper/consts"
"github.com/hashicorp/vault/helper/logging"
)
// testNewLeaseCache builds a LeaseCache backed by a mock proxier that replays
// the supplied canned responses, failing the test on construction errors.
func testNewLeaseCache(t *testing.T, responses []*SendResponse) *LeaseCache {
	t.Helper()

	leaseCache, err := NewLeaseCache(&LeaseCacheConfig{
		BaseContext: context.Background(),
		Proxier:     newMockProxier(responses),
		Logger:      logging.NewVaultLogger(hclog.Trace).Named("cache.leasecache"),
	})
	if err != nil {
		t.Fatal(err)
	}

	return leaseCache
}
// TestCache_ComputeIndexID verifies that computeIndexID derives the expected,
// stable identifier from a request.
func TestCache_ComputeIndexID(t *testing.T) {
	tests := []struct {
		name    string
		req     *SendRequest
		want    string
		wantErr bool
	}{
		{
			"basic",
			&SendRequest{
				Request: &http.Request{
					URL: &url.URL{
						Path: "test",
					},
				},
			},
			"2edc7e965c3e1bdce3b1d5f79a52927842569c0734a86544d222753f11ae4847",
			false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, err := computeIndexID(tt.req)
			if (err != nil) != tt.wantErr {
				t.Errorf("actual_error: %v, expected_error: %v", err, tt.wantErr)
				return
			}
			// want is already a string; compare directly instead of via
			// reflect.DeepEqual plus a redundant string() conversion.
			if got != tt.want {
				t.Errorf("bad: index id; actual: %q, expected: %q", got, tt.want)
			}
		})
	}
}
// TestCache_LeaseCache_EmptyToken verifies that a cacheable response is
// handled even when the incoming request carries no token: the index gets
// populated from the response and memdb insertion succeeds.
func TestCache_LeaseCache_EmptyToken(t *testing.T) {
	responses := []*SendResponse{
		{
			Response: &api.Response{
				Response: &http.Response{
					StatusCode: http.StatusCreated,
					Body:       ioutil.NopCloser(strings.NewReader(`{"value": "invalid", "auth": {"client_token": "testtoken"}}`)),
				},
			},
			ResponseBody: []byte(`{"value": "invalid", "auth": {"client_token": "testtoken"}}`),
		},
	}
	lc := testNewLeaseCache(t, responses)

	// Even if the send request doesn't have a token on it, a successful
	// cacheable response should result in the index properly getting populated
	// with a token and memdb shouldn't complain while inserting the index.
	urlPath := "http://example.com/v1/sample/api"
	resp, err := lc.Send(context.Background(), &SendRequest{
		Request: httptest.NewRequest("GET", urlPath, strings.NewReader(`{"value": "input"}`)),
	})
	if err != nil {
		t.Fatal(err)
	}
	if resp == nil {
		t.Fatalf("expected a non empty response")
	}
}
// TestCache_LeaseCache_SendCacheable verifies that cacheable responses (a new
// token, then a lease) are cached, and that repeated identical requests are
// served from the cache rather than the underlying proxier.
func TestCache_LeaseCache_SendCacheable(t *testing.T) {
	// Emulate 2 responses from the api proxy. One returns a new token and the
	// other returns a lease.
	responses := []*SendResponse{
		{
			Response: &api.Response{
				Response: &http.Response{
					StatusCode: http.StatusCreated,
					Body:       ioutil.NopCloser(strings.NewReader(`{"value": "invalid", "auth": {"client_token": "testtoken", "renewable": true}}`)),
				},
			},
			ResponseBody: []byte(`{"value": "invalid", "auth": {"client_token": "testtoken", "renewable": true}}`),
		},
		{
			Response: &api.Response{
				Response: &http.Response{
					StatusCode: http.StatusOK,
					Body:       ioutil.NopCloser(strings.NewReader(`{"value": "output", "lease_id": "foo", "renewable": true}`)),
				},
			},
			ResponseBody: []byte(`{"value": "output", "lease_id": "foo", "renewable": true}`),
		},
	}
	lc := testNewLeaseCache(t, responses)

	urlPath := "http://example.com/v1/sample/api"

	// newRequest builds a SendRequest against the test URL with the given
	// token and body.
	newRequest := func(token, body string) *SendRequest {
		return &SendRequest{
			Token:   token,
			Request: httptest.NewRequest("GET", urlPath, strings.NewReader(body)),
		}
	}

	// sendAndCheck proxies a request and verifies the status code against the
	// expected canned response.
	sendAndCheck := func(req *SendRequest, want *SendResponse) {
		t.Helper()
		resp, err := lc.Send(context.Background(), req)
		if err != nil {
			t.Fatal(err)
		}
		if diff := deep.Equal(resp.Response.StatusCode, want.Response.StatusCode); diff != nil {
			t.Fatalf("expected getting proxied response: got %v", diff)
		}
	}

	// A response with a new token is returned to the lease cache and cached.
	sendAndCheck(newRequest("", `{"value": "input"}`), responses[0])

	// The same request again should be served from the cache.
	sendAndCheck(newRequest("", `{"value": "input"}`), responses[0])

	// Modify the request a little bit to ensure the second response is
	// returned to the lease cache. But make sure that the token in the request
	// is valid.
	sendAndCheck(newRequest("testtoken", `{"value": "input_changed"}`), responses[1])

	// Make the same request again and ensure that the same response is
	// returned again.
	sendAndCheck(newRequest("testtoken", `{"value": "input_changed"}`), responses[1])
}
// TestCache_LeaseCache_SendNonCacheable verifies that responses carrying no
// lease or auth information are not cached: a second request reaches the
// underlying proxier and receives the next canned response.
func TestCache_LeaseCache_SendNonCacheable(t *testing.T) {
	responses := []*SendResponse{
		{
			Response: &api.Response{
				Response: &http.Response{
					StatusCode: http.StatusOK,
					Body:       ioutil.NopCloser(strings.NewReader(`{"value": "output"}`)),
				},
			},
		},
		{
			Response: &api.Response{
				Response: &http.Response{
					StatusCode: http.StatusNotFound,
					Body:       ioutil.NopCloser(strings.NewReader(`{"value": "invalid"}`)),
				},
			},
		},
	}
	lc := testNewLeaseCache(t, responses)

	// Send a request through the lease cache which is not cacheable (there is
	// no lease information or auth information in the response).
	resp, err := lc.Send(context.Background(), &SendRequest{
		Request: httptest.NewRequest("GET", "http://example.com", strings.NewReader(`{"value": "input"}`)),
	})
	if err != nil {
		t.Fatal(err)
	}
	if diff := deep.Equal(resp.Response.StatusCode, responses[0].Response.StatusCode); diff != nil {
		t.Fatalf("expected getting proxied response: got %v", diff)
	}

	// Since the response is non-cacheable, the second response will be
	// returned.
	resp, err = lc.Send(context.Background(), &SendRequest{
		Token:   "foo",
		Request: httptest.NewRequest("GET", "http://example.com", strings.NewReader(`{"value": "input"}`)),
	})
	if err != nil {
		t.Fatal(err)
	}
	if diff := deep.Equal(resp.Response.StatusCode, responses[1].Response.StatusCode); diff != nil {
		t.Fatalf("expected getting proxied response: got %v", diff)
	}
}
// TestCache_LeaseCache_SendNonCacheableNonTokenLease verifies that a response
// carrying a lease_id is NOT cached when the lease does not belong to a token
// managed by the lease cache.
func TestCache_LeaseCache_SendNonCacheableNonTokenLease(t *testing.T) {
	// Create the cache
	responses := []*SendResponse{
		{
			Response: &api.Response{
				Response: &http.Response{
					StatusCode: http.StatusOK,
					Body:       ioutil.NopCloser(strings.NewReader(`{"value": "output", "lease_id": "foo"}`)),
				},
			},
			ResponseBody: []byte(`{"value": "output", "lease_id": "foo"}`),
		},
		{
			Response: &api.Response{
				Response: &http.Response{
					StatusCode: http.StatusCreated,
					Body:       ioutil.NopCloser(strings.NewReader(`{"value": "invalid", "auth": {"client_token": "testtoken"}}`)),
				},
			},
			ResponseBody: []byte(`{"value": "invalid", "auth": {"client_token": "testtoken"}}`),
		},
	}
	lc := testNewLeaseCache(t, responses)

	// Send a request through lease cache which returns a response containing
	// lease_id. Response will not be cached because it doesn't belong to a
	// token that is managed by the lease cache.
	urlPath := "http://example.com/v1/sample/api"
	sendReq := &SendRequest{
		Token:   "foo",
		Request: httptest.NewRequest("GET", urlPath, strings.NewReader(`{"value": "input"}`)),
	}
	resp, err := lc.Send(context.Background(), sendReq)
	if err != nil {
		t.Fatal(err)
	}
	if diff := deep.Equal(resp.Response.StatusCode, responses[0].Response.StatusCode); diff != nil {
		t.Fatalf("expected getting proxied response: got %v", diff)
	}

	// Verify that the response is not cached by sending the same request and
	// by expecting a different response.
	sendReq = &SendRequest{
		Token:   "foo",
		Request: httptest.NewRequest("GET", urlPath, strings.NewReader(`{"value": "input"}`)),
	}
	resp, err = lc.Send(context.Background(), sendReq)
	if err != nil {
		t.Fatal(err)
	}
	if diff := deep.Equal(resp.Response.StatusCode, responses[0].Response.StatusCode); diff == nil {
		// The original message here described the opposite failure and always
		// printed a nil diff; say what actually went wrong.
		t.Fatalf("expected a different response for the second request, but the first response was served again")
	}
}
// TestCache_LeaseCache_HandleCacheClear exercises the cache-clear HTTP
// handler: a missing body must yield a 400, and a table of type/value
// combinations must yield the expected status codes.
func TestCache_LeaseCache_HandleCacheClear(t *testing.T) {
	lc := testNewLeaseCache(t, nil)

	handler := lc.HandleCacheClear(context.Background())
	ts := httptest.NewServer(handler)
	defer ts.Close()

	// Test missing body, should return 400
	resp, err := http.Post(ts.URL, "application/json", nil)
	if err != nil {
		// Previously t.Fatal() was called with no arguments, dropping the
		// error; include it so failures are diagnosable.
		t.Fatal(err)
	}
	if resp.StatusCode != http.StatusBadRequest {
		t.Fatalf("status code mismatch: expected = %v, got = %v", http.StatusBadRequest, resp.StatusCode)
	}

	testCases := []struct {
		name               string
		reqType            string
		reqValue           string
		expectedStatusCode int
	}{
		{
			"invalid_type",
			"foo",
			"",
			http.StatusBadRequest,
		},
		{
			"invalid_value",
			"",
			"bar",
			http.StatusBadRequest,
		},
		{
			"all",
			"all",
			"",
			http.StatusOK,
		},
		{
			"by_request_path",
			"request_path",
			"foo",
			http.StatusOK,
		},
		{
			"by_token",
			"token",
			"foo",
			http.StatusOK,
		},
		{
			"by_lease",
			"lease",
			"foo",
			http.StatusOK,
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			reqBody := fmt.Sprintf("{\"type\": \"%s\", \"value\": \"%s\"}", tc.reqType, tc.reqValue)
			resp, err := http.Post(ts.URL, "application/json", strings.NewReader(reqBody))
			if err != nil {
				t.Fatal(err)
			}
			if tc.expectedStatusCode != resp.StatusCode {
				t.Fatalf("status code mismatch: expected = %v, got = %v", tc.expectedStatusCode, resp.StatusCode)
			}
		})
	}
}
// TestCache_DeriveNamespaceAndRevocationPath verifies namespace/path
// derivation for both revocation and non-revocation requests, with the
// namespace supplied in the URL, the header, or both.
func TestCache_DeriveNamespaceAndRevocationPath(t *testing.T) {
	tests := []struct {
		name             string
		req              *SendRequest
		wantNamespace    string
		wantRelativePath string
	}{
		{
			"non_revocation_full_path",
			&SendRequest{
				Request: &http.Request{
					URL: &url.URL{
						Path: "/v1/ns1/sys/mounts",
					},
				},
			},
			"root/",
			"/v1/ns1/sys/mounts",
		},
		{
			"non_revocation_relative_path",
			&SendRequest{
				Request: &http.Request{
					URL: &url.URL{
						Path: "/v1/sys/mounts",
					},
					Header: http.Header{
						consts.NamespaceHeaderName: []string{"ns1/"},
					},
				},
			},
			"ns1/",
			"/v1/sys/mounts",
		},
		{
			// Renamed: this case previously reused the name of the previous
			// subtest, making the two indistinguishable in test output.
			"non_revocation_relative_path_with_ns_in_url",
			&SendRequest{
				Request: &http.Request{
					URL: &url.URL{
						Path: "/v1/ns2/sys/mounts",
					},
					Header: http.Header{
						consts.NamespaceHeaderName: []string{"ns1/"},
					},
				},
			},
			"ns1/",
			"/v1/ns2/sys/mounts",
		},
		{
			"revocation_full_path",
			&SendRequest{
				Request: &http.Request{
					URL: &url.URL{
						Path: "/v1/ns1/sys/leases/revoke",
					},
				},
			},
			"ns1/",
			"/v1/sys/leases/revoke",
		},
		{
			"revocation_relative_path",
			&SendRequest{
				Request: &http.Request{
					URL: &url.URL{
						Path: "/v1/sys/leases/revoke",
					},
					Header: http.Header{
						consts.NamespaceHeaderName: []string{"ns1/"},
					},
				},
			},
			"ns1/",
			"/v1/sys/leases/revoke",
		},
		{
			"revocation_relative_partial_ns",
			&SendRequest{
				Request: &http.Request{
					URL: &url.URL{
						Path: "/v1/ns2/sys/leases/revoke",
					},
					Header: http.Header{
						consts.NamespaceHeaderName: []string{"ns1/"},
					},
				},
			},
			"ns1/ns2/",
			"/v1/sys/leases/revoke",
		},
		{
			"revocation_prefix_full_path",
			&SendRequest{
				Request: &http.Request{
					URL: &url.URL{
						Path: "/v1/ns1/sys/leases/revoke-prefix/foo",
					},
				},
			},
			"ns1/",
			"/v1/sys/leases/revoke-prefix/foo",
		},
		{
			"revocation_prefix_relative_path",
			&SendRequest{
				Request: &http.Request{
					URL: &url.URL{
						Path: "/v1/sys/leases/revoke-prefix/foo",
					},
					Header: http.Header{
						consts.NamespaceHeaderName: []string{"ns1/"},
					},
				},
			},
			"ns1/",
			"/v1/sys/leases/revoke-prefix/foo",
		},
		{
			"revocation_prefix_partial_ns",
			&SendRequest{
				Request: &http.Request{
					URL: &url.URL{
						Path: "/v1/ns2/sys/leases/revoke-prefix/foo",
					},
					Header: http.Header{
						consts.NamespaceHeaderName: []string{"ns1/"},
					},
				},
			},
			"ns1/ns2/",
			"/v1/sys/leases/revoke-prefix/foo",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			gotNamespace, gotRelativePath := deriveNamespaceAndRevocationPath(tt.req)
			if gotNamespace != tt.wantNamespace {
				t.Errorf("deriveNamespaceAndRevocationPath() gotNamespace = %v, want %v", gotNamespace, tt.wantNamespace)
			}
			if gotRelativePath != tt.wantRelativePath {
				t.Errorf("deriveNamespaceAndRevocationPath() gotRelativePath = %v, want %v", gotRelativePath, tt.wantRelativePath)
			}
		})
	}
}

105
command/agent/cache/listener.go vendored Normal file
View file

@ -0,0 +1,105 @@
package cache
import (
"fmt"
"io"
"net"
"os"
"strings"
"github.com/hashicorp/vault/command/agent/config"
"github.com/hashicorp/vault/command/server"
"github.com/hashicorp/vault/helper/reload"
"github.com/mitchellh/cli"
)
// ServerListener creates a net.Listener for the given agent listener
// configuration, dispatching on the listener type ("tcp" or "unix"). It
// returns the listener, its properties map, and an optional TLS reload
// function.
func ServerListener(lnConfig *config.Listener, logger io.Writer, ui cli.Ui) (net.Listener, map[string]string, reload.ReloadFunc, error) {
	switch lnConfig.Type {
	case "tcp":
		return tcpListener(lnConfig.Config, logger, ui)
	case "unix":
		return unixSocketListener(lnConfig.Config, logger, ui)
	default:
		return nil, nil, nil, fmt.Errorf("unsupported listener type: %q", lnConfig.Type)
	}
}
// unixSocketListener creates a listener on a Unix domain socket at the path
// given by config["address"], removing any stale socket file first, and wraps
// the listener so the socket file is removed again on Close.
func unixSocketListener(config map[string]interface{}, _ io.Writer, ui cli.Ui) (net.Listener, map[string]string, reload.ReloadFunc, error) {
	addr, ok := config["address"].(string)
	if !ok {
		return nil, nil, nil, fmt.Errorf("invalid address: %v", config["address"])
	}
	if addr == "" {
		return nil, nil, nil, fmt.Errorf("address field should point to socket file path")
	}

	// A leftover socket file prevents binding; remove it if present.
	if err := os.Remove(addr); err != nil && !os.IsNotExist(err) {
		return nil, nil, nil, fmt.Errorf("failed to remove the socket file: %v", err)
	}

	ln, err := net.Listen("unix", addr)
	if err != nil {
		return nil, nil, nil, err
	}

	// Wrap the listener in rmListener so that the Unix domain socket file is
	// removed on close.
	wrapped := &rmListener{
		Listener: ln,
		Path:     addr,
	}

	props := map[string]string{"addr": addr, "tls": "disabled"}

	return server.ListenerWrapTLS(wrapped, props, config, ui)
}
// tcpListener creates a TCP listener on config["address"] (defaulting to
// 127.0.0.1:8300) with keep-alives enabled, optionally wrapped in TLS.
func tcpListener(config map[string]interface{}, _ io.Writer, ui cli.Ui) (net.Listener, map[string]string, reload.ReloadFunc, error) {
	bindProto := "tcp"

	addr := "127.0.0.1:8300"
	if addrRaw, ok := config["address"]; ok {
		// Previously this was an unchecked type assertion that panicked on a
		// non-string value; report a configuration error instead.
		addr, ok = addrRaw.(string)
		if !ok {
			return nil, nil, nil, fmt.Errorf("invalid address: %v", addrRaw)
		}
	}

	// If they've passed 0.0.0.0, we only want to bind on IPv4
	// rather than golang's dual stack default
	if strings.HasPrefix(addr, "0.0.0.0:") {
		bindProto = "tcp4"
	}

	ln, err := net.Listen(bindProto, addr)
	if err != nil {
		return nil, nil, nil, err
	}

	ln = server.TCPKeepAliveListener{ln.(*net.TCPListener)}

	props := map[string]string{"addr": addr}

	return server.ListenerWrapTLS(ln, props, config, ui)
}
// rmListener is an implementation of net.Listener that forwards most
// calls to the listener but also removes a file as part of the close. We
// use this to cleanup the unix domain socket on close.
type rmListener struct {
	net.Listener

	// Path is the file removed when the listener is closed.
	Path string
}

// Close closes the embedded listener and then removes the file at Path,
// returning the first error encountered.
func (l *rmListener) Close() error {
	// Close the listener itself
	if err := l.Listener.Close(); err != nil {
		return err
	}

	// Remove the file
	return os.Remove(l.Path)
}

28
command/agent/cache/proxy.go vendored Normal file
View file

@ -0,0 +1,28 @@
package cache
import (
"context"
"net/http"
"github.com/hashicorp/vault/api"
)
// SendRequest is the input for Proxier.Send.
type SendRequest struct {
	// Token is the Vault token for the request; auto-auth'ed requests set
	// it here directly rather than on the HTTP request.
	Token string

	// Request is the HTTP request being proxied.
	Request *http.Request

	// RequestBody holds the raw request body so it can be re-read after
	// Request.Body has been consumed (e.g. during serialization).
	RequestBody []byte
}
// SendResponse is the output from Proxier.Send.
type SendResponse struct {
	// Response is the API response returned by the proxied request.
	Response *api.Response

	// ResponseBody holds the raw response body so it can be read without
	// consuming Response's body.
	ResponseBody []byte
}
// Proxier is the interface implemented by different components that are
// responsible for performing specific tasks, such as caching and proxying. All
// these tasks combined together would serve the request received by the agent.
type Proxier interface {
	// Send processes the given request and returns the resulting response.
	Send(ctx context.Context, req *SendRequest) (*SendResponse, error)
}

36
command/agent/cache/testing.go vendored Normal file
View file

@ -0,0 +1,36 @@
package cache
import (
"context"
"fmt"
)
// mockProxier is a mock implementation of the Proxier interface, used for testing purposes.
// The mock will return the provided responses every time it reaches its Send method, up to
// the last provided response. This lets tests control what the next/underlying Proxier layer
// might expect to return.
type mockProxier struct {
	// proxiedResponses are the canned responses returned by Send, in order.
	proxiedResponses []*SendResponse

	// responseIndex is the position of the next response to return.
	responseIndex int
}
// newMockProxier returns a mockProxier that replays the given responses in
// order, starting from the first.
func newMockProxier(responses []*SendResponse) *mockProxier {
	m := &mockProxier{
		proxiedResponses: responses,
	}
	return m
}
// Send returns the next canned response, erroring once all responses have
// been consumed.
func (p *mockProxier) Send(ctx context.Context, req *SendRequest) (*SendResponse, error) {
	idx := p.responseIndex
	if idx >= len(p.proxiedResponses) {
		return nil, fmt.Errorf("index out of bounds: responseIndex = %d, responses = %d", idx, len(p.proxiedResponses))
	}

	p.responseIndex = idx + 1
	return p.proxiedResponses[idx], nil
}
// ResponseIndex reports how many canned responses have been served so far.
func (p *mockProxier) ResponseIndex() int {
	return p.responseIndex
}

View file

@ -0,0 +1,280 @@
package agent
import (
"context"
"fmt"
"io/ioutil"
"net"
"net/http"
"os"
"testing"
"time"
hclog "github.com/hashicorp/go-hclog"
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/api"
credAppRole "github.com/hashicorp/vault/builtin/credential/approle"
"github.com/hashicorp/vault/command/agent/auth"
agentapprole "github.com/hashicorp/vault/command/agent/auth/approle"
"github.com/hashicorp/vault/command/agent/cache"
"github.com/hashicorp/vault/command/agent/sink"
"github.com/hashicorp/vault/command/agent/sink/file"
"github.com/hashicorp/vault/helper/logging"
vaulthttp "github.com/hashicorp/vault/http"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/vault"
)
// TestCache_UsingAutoAuthToken is an end-to-end test: it stands up a Vault
// test cluster with approle auth, runs the agent's auto-auth handler and file
// sink to obtain a token, then fronts the cluster with the agent cache
// handler and verifies that a tokenless client request succeeds using the
// auto-auth token.
func TestCache_UsingAutoAuthToken(t *testing.T) {
	var err error
	logger := logging.NewVaultLogger(log.Trace)
	coreConfig := &vault.CoreConfig{
		DisableMlock: true,
		DisableCache: true,
		Logger:       log.NewNullLogger(),
		CredentialBackends: map[string]logical.Factory{
			"approle": credAppRole.Factory,
		},
	}

	cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{
		HandlerFunc: vaulthttp.Handler,
	})
	cluster.Start()
	defer cluster.Cleanup()

	cores := cluster.Cores
	vault.TestWaitActive(t, cores[0].Core)
	client := cores[0].Client

	// Point the api package at the test cluster. The deferred os.Setenv calls
	// restore the previous values: defer arguments are evaluated immediately,
	// so os.Getenv captures the pre-test value here.
	defer os.Setenv(api.EnvVaultAddress, os.Getenv(api.EnvVaultAddress))
	os.Setenv(api.EnvVaultAddress, client.Address())

	defer os.Setenv(api.EnvVaultCACert, os.Getenv(api.EnvVaultCACert))
	os.Setenv(api.EnvVaultCACert, fmt.Sprintf("%s/ca_cert.pem", cluster.TempDir))

	// Mount an approle backend and create a role with short token TTLs.
	err = client.Sys().EnableAuthWithOptions("approle", &api.EnableAuthOptions{
		Type: "approle",
	})
	if err != nil {
		t.Fatal(err)
	}

	_, err = client.Logical().Write("auth/approle/role/test1", map[string]interface{}{
		"bind_secret_id": "true",
		"token_ttl":      "3s",
		"token_max_ttl":  "10s",
	})
	if err != nil {
		t.Fatal(err)
	}

	resp, err := client.Logical().Write("auth/approle/role/test1/secret-id", nil)
	if err != nil {
		t.Fatal(err)
	}
	secretID1 := resp.Data["secret_id"].(string)

	resp, err = client.Logical().Read("auth/approle/role/test1/role-id")
	if err != nil {
		t.Fatal(err)
	}
	roleID1 := resp.Data["role_id"].(string)

	rolef, err := ioutil.TempFile("", "auth.role-id.test.")
	if err != nil {
		t.Fatal(err)
	}
	role := rolef.Name()
	rolef.Close() // WriteFile doesn't need it open
	defer os.Remove(role)
	t.Logf("input role_id_file_path: %s", role)

	secretf, err := ioutil.TempFile("", "auth.secret-id.test.")
	if err != nil {
		t.Fatal(err)
	}
	secret := secretf.Name()
	secretf.Close()
	defer os.Remove(secret)
	t.Logf("input secret_id_file_path: %s", secret)

	// We close these right away because we're just basically testing
	// permissions and finding a usable file name
	ouf, err := ioutil.TempFile("", "auth.tokensink.test.")
	if err != nil {
		t.Fatal(err)
	}
	out := ouf.Name()
	ouf.Close()
	os.Remove(out)
	t.Logf("output: %s", out)

	ctx, cancelFunc := context.WithCancel(context.Background())
	timer := time.AfterFunc(30*time.Second, func() {
		cancelFunc()
	})
	defer timer.Stop()

	conf := map[string]interface{}{
		"role_id_file_path":                   role,
		"secret_id_file_path":                 secret,
		"remove_secret_id_file_after_reading": true,
	}

	// Run the auto-auth handler against the approle method.
	am, err := agentapprole.NewApproleAuthMethod(&auth.AuthConfig{
		Logger:    logger.Named("auth.approle"),
		MountPath: "auth/approle",
		Config:    conf,
	})
	if err != nil {
		t.Fatal(err)
	}

	ahConfig := &auth.AuthHandlerConfig{
		Logger: logger.Named("auth.handler"),
		Client: client,
	}
	ah := auth.NewAuthHandler(ahConfig)
	go ah.Run(ctx, am)
	defer func() {
		<-ah.DoneCh
	}()

	// Wire a file sink to receive the auto-auth token.
	config := &sink.SinkConfig{
		Logger: logger.Named("sink.file"),
		Config: map[string]interface{}{
			"path": out,
		},
	}
	fs, err := file.NewFileSink(config)
	if err != nil {
		t.Fatal(err)
	}
	config.Sink = fs

	ss := sink.NewSinkServer(&sink.SinkServerConfig{
		Logger: logger.Named("sink.server"),
		Client: client,
	})
	go ss.Run(ctx, ah.OutputCh, []*sink.SinkConfig{config})
	defer func() {
		<-ss.DoneCh
	}()

	// This has to be after the other defers so it happens first
	defer cancelFunc()

	// Check that no sink file exists
	_, err = os.Lstat(out)
	if err == nil {
		t.Fatal("expected err")
	}
	if !os.IsNotExist(err) {
		t.Fatal("expected notexist err")
	}

	if err := ioutil.WriteFile(role, []byte(roleID1), 0600); err != nil {
		t.Fatal(err)
	} else {
		logger.Trace("wrote test role 1", "path", role)
	}

	if err := ioutil.WriteFile(secret, []byte(secretID1), 0600); err != nil {
		t.Fatal(err)
	} else {
		logger.Trace("wrote test secret 1", "path", secret)
	}

	// getToken polls the sink file until the auto-auth token is written,
	// validates it against Vault, and returns it.
	getToken := func() string {
		timeout := time.Now().Add(10 * time.Second)
		for {
			if time.Now().After(timeout) {
				t.Fatal("did not find a written token after timeout")
			}
			val, err := ioutil.ReadFile(out)
			if err == nil {
				os.Remove(out)
				if len(val) == 0 {
					t.Fatal("written token was empty")
				}

				_, err = os.Stat(secret)
				if err == nil {
					t.Fatal("secret file exists but was supposed to be removed")
				}

				client.SetToken(string(val))
				_, err := client.Auth().Token().LookupSelf()
				if err != nil {
					t.Fatal(err)
				}
				return string(val)
			}

			time.Sleep(250 * time.Millisecond)
		}
	}

	t.Logf("auto-auth token: %q", getToken())

	listener, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatal(err)
	}
	defer listener.Close()

	cacheLogger := logging.NewVaultLogger(hclog.Trace).Named("cache")

	// Create the API proxier
	apiProxy := cache.NewAPIProxy(&cache.APIProxyConfig{
		Logger: cacheLogger.Named("apiproxy"),
	})

	// Create the lease cache proxier and set its underlying proxier to
	// the API proxier.
	leaseCache, err := cache.NewLeaseCache(&cache.LeaseCacheConfig{
		BaseContext: ctx,
		Proxier:     apiProxy,
		Logger:      cacheLogger.Named("leasecache"),
	})
	if err != nil {
		t.Fatal(err)
	}

	// Create a muxer and add paths relevant for the lease cache layer
	mux := http.NewServeMux()
	mux.Handle("/v1/agent/cache-clear", leaseCache.HandleCacheClear(ctx))
	mux.Handle("/", cache.Handler(ctx, cacheLogger, leaseCache, true, client))
	server := &http.Server{
		Handler:           mux,
		ReadHeaderTimeout: 10 * time.Second,
		ReadTimeout:       30 * time.Second,
		IdleTimeout:       5 * time.Minute,
		ErrorLog:          cacheLogger.StandardLogger(nil),
	}
	go server.Serve(listener)

	testClient, err := api.NewClient(api.DefaultConfig())
	if err != nil {
		t.Fatal(err)
	}

	if err := testClient.SetAddress("http://" + listener.Addr().String()); err != nil {
		t.Fatal(err)
	}

	// Wait for listeners to come up
	time.Sleep(2 * time.Second)

	// The client has no token set; a successful lookup-self means the cache
	// handler supplied the auto-auth token on its behalf.
	resp, err = testClient.Logical().Read("auth/token/lookup-self")
	if err != nil {
		t.Fatal(err)
	}

	if resp == nil {
		t.Fatalf("failed to use the auto-auth token to perform lookup-self")
	}
}

View file

@ -22,6 +22,17 @@ type Config struct {
AutoAuth *AutoAuth `hcl:"auto_auth"`
ExitAfterAuth bool `hcl:"exit_after_auth"`
PidFile string `hcl:"pid_file"`
Cache *Cache `hcl:"cache"`
}
// Cache is the agent-side cache configuration, populated from the
// top-level "cache" stanza of the agent config file.
type Cache struct {
// UseAutoAuthToken controls whether the cache uses the token obtained
// via auto-auth when proxying requests.
UseAutoAuthToken bool `hcl:"use_auto_auth_token"`
// Listeners holds the endpoints the cache listens on, parsed from the
// nested "listener" blocks (only "unix" and "tcp" types are accepted
// by parseListeners).
Listeners []*Listener `hcl:"listeners"`
}

// Listener describes one cache listener: its type ("unix" or "tcp") and
// the remaining raw HCL attributes (address, TLS settings, ...), kept as
// an untyped map for the listener factory to interpret.
type Listener struct {
Type string
Config map[string]interface{}
}
type AutoAuth struct {
@ -91,9 +102,94 @@ func LoadConfig(path string, logger log.Logger) (*Config, error) {
return nil, errwrap.Wrapf("error parsing 'auto_auth': {{err}}", err)
}
err = parseCache(&result, list)
if err != nil {
return nil, errwrap.Wrapf("error parsing 'cache':{{err}}", err)
}
return &result, nil
}
// parseCache extracts the optional top-level "cache" stanza from the parsed
// HCL object list and stores it on result. At most one such stanza is
// allowed; when present, its nested "listener" blocks are parsed as well.
func parseCache(result *Config, list *ast.ObjectList) error {
	name := "cache"

	matches := list.Filter(name)
	switch len(matches.Items) {
	case 0:
		// The cache stanza is optional.
		return nil
	case 1:
		// Exactly one stanza; decoded below.
	default:
		return fmt.Errorf("one and only one %q block is required", name)
	}

	stanza := matches.Items[0]

	var parsed Cache
	if err := hcl.DecodeObject(&parsed, stanza.Val); err != nil {
		return err
	}
	result.Cache = &parsed

	// Listener blocks are nested objects, so we need the raw object form
	// of the stanza to walk them.
	obj, ok := stanza.Val.(*ast.ObjectType)
	if !ok {
		return fmt.Errorf("could not parse %q as an object", name)
	}

	if err := parseListeners(result, obj.List); err != nil {
		return errwrap.Wrapf("error parsing 'listener' stanzas: {{err}}", err)
	}

	return nil
}
// parseListeners parses the "listener" stanzas nested inside the agent's
// "cache" block and stores them on result.Cache (the caller, parseCache,
// sets result.Cache before invoking this). At least one listener is
// required. The listener type may be supplied either as a "type" attribute
// within the block or as a block label (listener "tcp" { ... }); only
// "unix" and "tcp" are accepted.
func parseListeners(result *Config, list *ast.ObjectList) error {
	name := "listener"

	listenerList := list.Filter(name)
	if len(listenerList.Items) < 1 {
		return fmt.Errorf("at least one %q block is required", name)
	}

	var listeners []*Listener
	for _, item := range listenerList.Items {
		var lnConfig map[string]interface{}
		err := hcl.DecodeObject(&lnConfig, item.Val)
		if err != nil {
			return err
		}

		var lnType string
		switch {
		case lnConfig["type"] != nil:
			// Check the assertion explicitly: a non-string "type" value
			// (e.g. a number) would otherwise panic instead of returning
			// a parse error.
			typ, ok := lnConfig["type"].(string)
			if !ok {
				return fmt.Errorf("listener type must be a string, got %T", lnConfig["type"])
			}
			lnType = typ
			delete(lnConfig, "type")
		case len(item.Keys) == 1:
			// Block-label form: listener "tcp" { ... }
			lnType = strings.ToLower(item.Keys[0].Token.Value().(string))
		default:
			return errors.New("listener type must be specified")
		}

		switch lnType {
		case "unix", "tcp":
		default:
			return fmt.Errorf("invalid listener type %q", lnType)
		}

		listeners = append(listeners, &Listener{
			Type:   lnType,
			Config: lnConfig,
		})
	}

	result.Cache.Listeners = listeners

	return nil
}
func parseAutoAuth(result *Config, list *ast.ObjectList) error {
name := "auto_auth"

View file

@ -10,6 +10,80 @@ import (
"github.com/hashicorp/vault/helper/logging"
)
// TestLoadConfigFile_AgentCache verifies that an agent config containing a
// cache stanza with listeners loads correctly. The two fixtures express the
// same configuration in the two supported listener syntaxes ("type"
// attribute vs. embedded block label) and must parse identically, so they
// are checked against a single expected Config in a loop; failures report
// which fixture mismatched.
func TestLoadConfigFile_AgentCache(t *testing.T) {
	logger := logging.NewVaultLogger(log.Debug)

	expected := &Config{
		AutoAuth: &AutoAuth{
			Method: &Method{
				Type:      "aws",
				WrapTTL:   300 * time.Second,
				MountPath: "auth/aws",
				Config: map[string]interface{}{
					"role": "foobar",
				},
			},
			Sinks: []*Sink{
				&Sink{
					Type:   "file",
					DHType: "curve25519",
					DHPath: "/tmp/file-foo-dhpath",
					AAD:    "foobar",
					Config: map[string]interface{}{
						"path": "/tmp/file-foo",
					},
				},
			},
		},
		Cache: &Cache{
			UseAutoAuthToken: true,
			Listeners: []*Listener{
				&Listener{
					Type: "unix",
					Config: map[string]interface{}{
						"address":     "/path/to/socket",
						"tls_disable": true,
					},
				},
				&Listener{
					Type: "tcp",
					Config: map[string]interface{}{
						"address":     "127.0.0.1:8300",
						"tls_disable": true,
					},
				},
				&Listener{
					Type: "tcp",
					Config: map[string]interface{}{
						"address":       "127.0.0.1:8400",
						"tls_key_file":  "/path/to/cakey.pem",
						"tls_cert_file": "/path/to/cacert.pem",
					},
				},
			},
		},
		PidFile: "./pidfile",
	}

	for _, fixture := range []string{
		"./test-fixtures/config-cache.hcl",
		"./test-fixtures/config-cache-embedded-type.hcl",
	} {
		config, err := LoadConfig(fixture, logger)
		if err != nil {
			t.Fatalf("%s: %v", fixture, err)
		}
		if diff := deep.Equal(config, expected); diff != nil {
			t.Fatalf("%s: %v", fixture, diff)
		}
	}
}
func TestLoadConfigFile(t *testing.T) {
logger := logging.NewVaultLogger(log.Debug)

View file

@ -0,0 +1,44 @@
# Agent config test fixture (TestLoadConfigFile_AgentCache): a cache stanza
# whose listeners specify their kind via a "type" attribute inside the block.
# Comments are ignored by the HCL parser and do not affect the parsed config.
pid_file = "./pidfile"
auto_auth {
method {
type = "aws"
wrap_ttl = 300
config = {
role = "foobar"
}
}
sink {
type = "file"
config = {
path = "/tmp/file-foo"
}
aad = "foobar"
dh_type = "curve25519"
dh_path = "/tmp/file-foo-dhpath"
}
}
cache {
use_auto_auth_token = true
listener {
type = "unix"
address = "/path/to/socket"
tls_disable = true
}
listener {
type = "tcp"
address = "127.0.0.1:8300"
tls_disable = true
}
listener {
type = "tcp"
address = "127.0.0.1:8400"
tls_key_file = "/path/to/cakey.pem"
tls_cert_file = "/path/to/cacert.pem"
}
}

View file

@ -0,0 +1,41 @@
# Agent config test fixture (TestLoadConfigFile_AgentCache): the same
# configuration as config-cache.hcl, but each listener's kind is given as an
# embedded block label (listener "tcp" { ... }) instead of a "type" attribute.
# Comments are ignored by the HCL parser and do not affect the parsed config.
pid_file = "./pidfile"
auto_auth {
method {
type = "aws"
wrap_ttl = 300
config = {
role = "foobar"
}
}
sink {
type = "file"
config = {
path = "/tmp/file-foo"
}
aad = "foobar"
dh_type = "curve25519"
dh_path = "/tmp/file-foo-dhpath"
}
}
cache {
use_auto_auth_token = true
listener "unix" {
address = "/path/to/socket"
tls_disable = true
}
listener "tcp" {
address = "127.0.0.1:8300"
tls_disable = true
}
listener "tcp" {
address = "127.0.0.1:8400"
tls_key_file = "/path/to/cakey.pem"
tls_cert_file = "/path/to/cacert.pem"
}
}

View file

@ -62,6 +62,7 @@ func testJWTEndToEnd(t *testing.T, ahWrapping bool) {
}
_, err = client.Logical().Write("auth/jwt/role/test", map[string]interface{}{
"role_type": "jwt",
"bound_subject": "r3qXcK2bix9eFECzsU3Sbmh0K16fatW6@clients",
"bound_audiences": "https://vault.plugin.auth.jwt.test",
"user_claim": "https://vault/user",

View file

@ -30,6 +30,191 @@ func testAgentCommand(tb testing.TB, logger hclog.Logger) (*cli.MockUi, *AgentCo
}
}
// NOTE(review): the end-to-end agent cache unix-listener test below is
// commented out with no stated reason — presumably pending completion of the
// cache listener work; confirm, then re-enable or remove it.
/*
func TestAgent_Cache_UnixListener(t *testing.T) {
logger := logging.NewVaultLogger(hclog.Trace)
coreConfig := &vault.CoreConfig{
Logger: logger.Named("core"),
CredentialBackends: map[string]logical.Factory{
"jwt": vaultjwt.Factory,
},
}
cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{
HandlerFunc: vaulthttp.Handler,
})
cluster.Start()
defer cluster.Cleanup()
vault.TestWaitActive(t, cluster.Cores[0].Core)
client := cluster.Cores[0].Client
defer os.Setenv(api.EnvVaultAddress, os.Getenv(api.EnvVaultAddress))
os.Setenv(api.EnvVaultAddress, client.Address())
defer os.Setenv(api.EnvVaultCACert, os.Getenv(api.EnvVaultCACert))
os.Setenv(api.EnvVaultCACert, fmt.Sprintf("%s/ca_cert.pem", cluster.TempDir))
// Setup Vault
err := client.Sys().EnableAuthWithOptions("jwt", &api.EnableAuthOptions{
Type: "jwt",
})
if err != nil {
t.Fatal(err)
}
_, err = client.Logical().Write("auth/jwt/config", map[string]interface{}{
"bound_issuer": "https://team-vault.auth0.com/",
"jwt_validation_pubkeys": agent.TestECDSAPubKey,
})
if err != nil {
t.Fatal(err)
}
_, err = client.Logical().Write("auth/jwt/role/test", map[string]interface{}{
"role_type": "jwt",
"bound_subject": "r3qXcK2bix9eFECzsU3Sbmh0K16fatW6@clients",
"bound_audiences": "https://vault.plugin.auth.jwt.test",
"user_claim": "https://vault/user",
"groups_claim": "https://vault/groups",
"policies": "test",
"period": "3s",
})
if err != nil {
t.Fatal(err)
}
inf, err := ioutil.TempFile("", "auth.jwt.test.")
if err != nil {
t.Fatal(err)
}
in := inf.Name()
inf.Close()
os.Remove(in)
t.Logf("input: %s", in)
sink1f, err := ioutil.TempFile("", "sink1.jwt.test.")
if err != nil {
t.Fatal(err)
}
sink1 := sink1f.Name()
sink1f.Close()
os.Remove(sink1)
t.Logf("sink1: %s", sink1)
sink2f, err := ioutil.TempFile("", "sink2.jwt.test.")
if err != nil {
t.Fatal(err)
}
sink2 := sink2f.Name()
sink2f.Close()
os.Remove(sink2)
t.Logf("sink2: %s", sink2)
conff, err := ioutil.TempFile("", "conf.jwt.test.")
if err != nil {
t.Fatal(err)
}
conf := conff.Name()
conff.Close()
os.Remove(conf)
t.Logf("config: %s", conf)
jwtToken, _ := agent.GetTestJWT(t)
if err := ioutil.WriteFile(in, []byte(jwtToken), 0600); err != nil {
t.Fatal(err)
} else {
logger.Trace("wrote test jwt", "path", in)
}
socketff, err := ioutil.TempFile("", "cache.socket.")
if err != nil {
t.Fatal(err)
}
socketf := socketff.Name()
socketff.Close()
os.Remove(socketf)
t.Logf("socketf: %s", socketf)
config := `
auto_auth {
method {
type = "jwt"
config = {
role = "test"
path = "%s"
}
}
sink {
type = "file"
config = {
path = "%s"
}
}
sink "file" {
config = {
path = "%s"
}
}
}
cache {
use_auto_auth_token = true
listener "unix" {
address = "%s"
tls_disable = true
}
}
`
config = fmt.Sprintf(config, in, sink1, sink2, socketf)
if err := ioutil.WriteFile(conf, []byte(config), 0600); err != nil {
t.Fatal(err)
} else {
logger.Trace("wrote test config", "path", conf)
}
_, cmd := testAgentCommand(t, logger)
cmd.client = client
// Kill the command 5 seconds after it starts
go func() {
select {
case <-cmd.ShutdownCh:
case <-time.After(5 * time.Second):
cmd.ShutdownCh <- struct{}{}
}
}()
originalVaultAgentAddress := os.Getenv(api.EnvVaultAgentAddress)
// Create a client that talks to the agent
os.Setenv(api.EnvVaultAgentAddress, socketf)
testClient, err := api.NewClient(api.DefaultConfig())
if err != nil {
t.Fatal(err)
}
os.Setenv(api.EnvVaultAgentAddress, originalVaultAgentAddress)
// Start the agent
go cmd.Run([]string{"-config", conf})
// Give some time for the auto-auth to complete
time.Sleep(1 * time.Second)
// Invoke lookup self through the agent
secret, err := testClient.Auth().Token().LookupSelf()
if err != nil {
t.Fatal(err)
}
if secret == nil || secret.Data == nil || secret.Data["id"].(string) == "" {
t.Fatalf("failed to perform lookup self through agent")
}
}
*/
func TestExitAfterAuth(t *testing.T) {
logger := logging.NewVaultLogger(hclog.Trace)
coreConfig := &vault.CoreConfig{
@ -64,6 +249,7 @@ func TestExitAfterAuth(t *testing.T) {
}
_, err = client.Logical().Write("auth/jwt/role/test", map[string]interface{}{
"role_type": "jwt",
"bound_subject": "r3qXcK2bix9eFECzsU3Sbmh0K16fatW6@clients",
"bound_audiences": "https://vault.plugin.auth.jwt.test",
"user_claim": "https://vault/user",

View file

@ -1,9 +1,6 @@
package command
import (
"flag"
"io"
"io/ioutil"
"strings"
"github.com/mitchellh/cli"
@ -13,10 +10,6 @@ var _ cli.Command = (*AuthCommand)(nil)
type AuthCommand struct {
*BaseCommand
Handlers map[string]LoginHandler
testStdin io.Reader // for tests
}
func (c *AuthCommand) Synopsis() string {
@ -52,77 +45,5 @@ Usage: vault auth <subcommand> [options] [args]
}
func (c *AuthCommand) Run(args []string) int {
// If we entered the run method, none of the subcommands picked up. This
// means the user is still trying to use auth as "vault auth TOKEN" or
// similar, so direct them to vault login instead.
//
// This run command is a bit messy to maintain BC for a bit. In the future,
// it will just be a tiny function, but for now we have to maintain bc.
//
// Deprecation
// TODO: remove in 0.9.0
if len(args) == 0 {
return cli.RunResultHelp
}
// Parse the args for our deprecations and defer to the proper areas.
for _, arg := range args {
switch {
case strings.HasPrefix(arg, "-methods"):
if Format(c.UI) == "table" {
c.UI.Warn(wrapAtLength(
"WARNING! The -methods flag is deprecated. Please use "+
"\"vault auth list\" instead. This flag will be removed in "+
"Vault 1.1.") + "\n")
}
return (&AuthListCommand{
BaseCommand: &BaseCommand{
UI: c.UI,
client: c.client,
},
}).Run(nil)
case strings.HasPrefix(arg, "-method-help"):
if Format(c.UI) == "table" {
c.UI.Warn(wrapAtLength(
"WARNING! The -method-help flag is deprecated. Please use "+
"\"vault auth help\" instead. This flag will be removed in "+
"Vault 1.1.") + "\n")
}
// Parse the args to pull out the method, suppressing any errors because
// there could be other flags that we don't care about.
f := flag.NewFlagSet("", flag.ContinueOnError)
f.Usage = func() {}
f.SetOutput(ioutil.Discard)
flagMethod := f.String("method", "", "")
f.Parse(args)
return (&AuthHelpCommand{
BaseCommand: &BaseCommand{
UI: c.UI,
client: c.client,
},
Handlers: c.Handlers,
}).Run([]string{*flagMethod})
}
}
// If we got this far, we have an arg or a series of args that should be
// passed directly to the new "vault login" command.
if Format(c.UI) == "table" {
c.UI.Warn(wrapAtLength(
"WARNING! The \"vault auth ARG\" command is deprecated and is now a "+
"subcommand for interacting with auth methods. To authenticate "+
"locally to Vault, use \"vault login\" instead. This backwards "+
"compatibility will be removed in Vault 1.1.") + "\n")
}
return (&LoginCommand{
BaseCommand: &BaseCommand{
UI: c.UI,
client: c.client,
tokenHelper: c.tokenHelper,
flagAddress: c.flagAddress,
},
Handlers: c.Handlers,
}).Run(args)
return cli.RunResultHelp
}

View file

@ -175,7 +175,8 @@ func TestAuthEnableCommand_Run(t *testing.T) {
// Add 1 to account for the "token" backend, which is visible when you walk the filesystem but
// is treated as special and excluded from the registry.
expected := len(builtinplugins.Registry.Keys(consts.PluginTypeCredential)) + 1
// Subtract 1 to account for "oidc" which is an alias of "jwt" and not a separate plugin.
expected := len(builtinplugins.Registry.Keys(consts.PluginTypeCredential))
if len(backends) != expected {
t.Fatalf("expected %d credential backends, got %d", expected, len(backends))
}

View file

@ -1,13 +1,10 @@
package command
import (
"strings"
"testing"
"github.com/mitchellh/cli"
credToken "github.com/hashicorp/vault/builtin/credential/token"
credUserpass "github.com/hashicorp/vault/builtin/credential/userpass"
"github.com/hashicorp/vault/command/token"
)
@ -22,110 +19,12 @@ func testAuthCommand(tb testing.TB) (*cli.MockUi, *AuthCommand) {
// Override to our own token helper
tokenHelper: token.NewTestingTokenHelper(),
},
Handlers: map[string]LoginHandler{
"token": &credToken.CLIHandler{},
"userpass": &credUserpass.CLIHandler{},
},
}
}
func TestAuthCommand_Run(t *testing.T) {
t.Parallel()
// TODO: remove in 0.9.0
t.Run("deprecated_methods", func(t *testing.T) {
t.Parallel()
client, closer := testVaultServer(t)
defer closer()
ui, cmd := testAuthCommand(t)
cmd.client = client
// vault auth -methods -> vault auth list
code := cmd.Run([]string{"-methods"})
if exp := 0; code != exp {
t.Errorf("expected %d to be %d: %s", code, exp, ui.ErrorWriter.String())
}
stdout, stderr := ui.OutputWriter.String(), ui.ErrorWriter.String()
if expected := "WARNING!"; !strings.Contains(stderr, expected) {
t.Errorf("expected %q to contain %q", stderr, expected)
}
if expected := "token/"; !strings.Contains(stdout, expected) {
t.Errorf("expected %q to contain %q", stdout, expected)
}
})
t.Run("deprecated_method_help", func(t *testing.T) {
t.Parallel()
client, closer := testVaultServer(t)
defer closer()
ui, cmd := testAuthCommand(t)
cmd.client = client
// vault auth -method=foo -method-help -> vault auth help foo
code := cmd.Run([]string{
"-method=userpass",
"-method-help",
})
if exp := 0; code != exp {
t.Errorf("expected %d to be %d: %s", code, exp, ui.ErrorWriter.String())
}
stdout, stderr := ui.OutputWriter.String(), ui.ErrorWriter.String()
if expected := "WARNING!"; !strings.Contains(stderr, expected) {
t.Errorf("expected %q to contain %q", stderr, expected)
}
if expected := "vault login"; !strings.Contains(stdout, expected) {
t.Errorf("expected %q to contain %q", stdout, expected)
}
})
t.Run("deprecated_login", func(t *testing.T) {
t.Parallel()
client, closer := testVaultServer(t)
defer closer()
if err := client.Sys().EnableAuth("my-auth", "userpass", ""); err != nil {
t.Fatal(err)
}
if _, err := client.Logical().Write("auth/my-auth/users/test", map[string]interface{}{
"password": "test",
"policies": "default",
}); err != nil {
t.Fatal(err)
}
ui, cmd := testAuthCommand(t)
cmd.client = client
// vault auth ARGS -> vault login ARGS
code := cmd.Run([]string{
"-method", "userpass",
"-path", "my-auth",
"username=test",
"password=test",
})
if exp := 0; code != exp {
t.Errorf("expected %d to be %d: %s", code, exp, ui.ErrorWriter.String())
}
stdout, stderr := ui.OutputWriter.String(), ui.ErrorWriter.String()
if expected := "WARNING!"; !strings.Contains(stderr, expected) {
t.Errorf("expected %q to contain %q", stderr, expected)
}
if expected := "Success! You are now authenticated."; !strings.Contains(stdout, expected) {
t.Errorf("expected %q to contain %q", stdout, expected)
}
})
t.Run("no_tabs", func(t *testing.T) {
t.Parallel()

View file

@ -39,6 +39,7 @@ type BaseCommand struct {
flagsOnce sync.Once
flagAddress string
flagAgentAddress string
flagCACert string
flagCAPath string
flagClientCert string
@ -78,6 +79,9 @@ func (c *BaseCommand) Client() (*api.Client, error) {
if c.flagAddress != "" {
config.Address = c.flagAddress
}
if c.flagAgentAddress != "" {
config.Address = c.flagAgentAddress
}
if c.flagOutputCurlString {
config.OutputCurlString = c.flagOutputCurlString
@ -220,6 +224,15 @@ func (c *BaseCommand) flagSet(bit FlagSetBit) *FlagSets {
}
f.StringVar(addrStringVar)
agentAddrStringVar := &StringVar{
Name: "agent-address",
Target: &c.flagAgentAddress,
EnvVar: "VAULT_AGENT_ADDR",
Completion: complete.PredictAnything,
Usage: "Address of the Agent.",
}
f.StringVar(agentAddrStringVar)
f.StringVar(&StringVar{
Name: "ca-cert",
Target: &c.flagCACert,

View file

@ -352,6 +352,7 @@ func TestPredict_Plugins(t *testing.T) {
"mysql-legacy-database-plugin",
"mysql-rds-database-plugin",
"nomad",
"oidc",
"okta",
"pki",
"postgresql",

View file

@ -1,7 +1,6 @@
package command
import (
"fmt"
"os"
"os/signal"
"syscall"
@ -27,6 +26,7 @@ import (
credAliCloud "github.com/hashicorp/vault-plugin-auth-alicloud"
credCentrify "github.com/hashicorp/vault-plugin-auth-centrify"
credGcp "github.com/hashicorp/vault-plugin-auth-gcp/plugin"
credOIDC "github.com/hashicorp/vault-plugin-auth-jwt"
credAws "github.com/hashicorp/vault/builtin/credential/aws"
credCert "github.com/hashicorp/vault/builtin/credential/cert"
credGitHub "github.com/hashicorp/vault/builtin/credential/github"
@ -130,43 +130,8 @@ var (
}
)
// DeprecatedCommand is a command that wraps an existing command and prints a
// deprecation notice and points the user to the new command. Deprecated
// commands are always hidden from help output.
type DeprecatedCommand struct {
cli.Command
UI cli.Ui
// Old is the old command name, New is the new command name.
Old, New string
}
// Help wraps the embedded Help command and prints a warning about deprecations.
func (c *DeprecatedCommand) Help() string {
c.warn()
return c.Command.Help()
}
// Run wraps the embedded Run command and prints a warning about deprecation.
func (c *DeprecatedCommand) Run(args []string) int {
if Format(c.UI) == "table" {
c.warn()
}
return c.Command.Run(args)
}
func (c *DeprecatedCommand) warn() {
c.UI.Warn(wrapAtLength(fmt.Sprintf(
"WARNING! The \"vault %s\" command is deprecated. Please use \"vault %s\" "+
"instead. This command will be removed in Vault 1.1.",
c.Old,
c.New)))
c.UI.Warn("")
}
// Commands is the mapping of all the available commands.
var Commands map[string]cli.CommandFactory
var DeprecatedCommands map[string]cli.CommandFactory
func initCommands(ui, serverCmdUi cli.Ui, runOpts *RunOptions) {
loginHandlers := map[string]LoginHandler{
@ -177,6 +142,7 @@ func initCommands(ui, serverCmdUi cli.Ui, runOpts *RunOptions) {
"gcp": &credGcp.CLIHandler{},
"github": &credGitHub.CLIHandler{},
"ldap": &credLdap.CLIHandler{},
"oidc": &credOIDC.CLIHandler{},
"okta": &credOkta.CLIHandler{},
"radius": &credUserpass.CLIHandler{
DefaultMount: "radius",
@ -233,7 +199,6 @@ func initCommands(ui, serverCmdUi cli.Ui, runOpts *RunOptions) {
"auth": func() (cli.Command, error) {
return &AuthCommand{
BaseCommand: getBaseCommand(),
Handlers: loginHandlers,
}, nil
},
"auth disable": func() (cli.Command, error) {
@ -612,328 +577,6 @@ func initCommands(ui, serverCmdUi cli.Ui, runOpts *RunOptions) {
}, nil
},
}
// Deprecated commands
//
// TODO: Remove not before 0.11.0
DeprecatedCommands = map[string]cli.CommandFactory{
"audit-disable": func() (cli.Command, error) {
return &DeprecatedCommand{
Old: "audit-disable",
New: "audit disable",
UI: ui,
Command: &AuditDisableCommand{
BaseCommand: getBaseCommand(),
},
}, nil
},
"audit-enable": func() (cli.Command, error) {
return &DeprecatedCommand{
Old: "audit-enable",
New: "audit enable",
UI: ui,
Command: &AuditEnableCommand{
BaseCommand: getBaseCommand(),
},
}, nil
},
"audit-list": func() (cli.Command, error) {
return &DeprecatedCommand{
Old: "audit-list",
New: "audit list",
UI: ui,
Command: &AuditListCommand{
BaseCommand: getBaseCommand(),
},
}, nil
},
"auth-disable": func() (cli.Command, error) {
return &DeprecatedCommand{
Old: "auth-disable",
New: "auth disable",
UI: ui,
Command: &AuthDisableCommand{
BaseCommand: getBaseCommand(),
},
}, nil
},
"auth-enable": func() (cli.Command, error) {
return &DeprecatedCommand{
Old: "auth-enable",
New: "auth enable",
UI: ui,
Command: &AuthEnableCommand{
BaseCommand: getBaseCommand(),
},
}, nil
},
"capabilities": func() (cli.Command, error) {
return &DeprecatedCommand{
Old: "capabilities",
New: "token capabilities",
UI: ui,
Command: &TokenCapabilitiesCommand{
BaseCommand: getBaseCommand(),
},
}, nil
},
"generate-root": func() (cli.Command, error) {
return &DeprecatedCommand{
Old: "generate-root",
New: "operator generate-root",
UI: ui,
Command: &OperatorGenerateRootCommand{
BaseCommand: getBaseCommand(),
},
}, nil
},
"init": func() (cli.Command, error) {
return &DeprecatedCommand{
Old: "init",
New: "operator init",
UI: ui,
Command: &OperatorInitCommand{
BaseCommand: getBaseCommand(),
},
}, nil
},
"key-status": func() (cli.Command, error) {
return &DeprecatedCommand{
Old: "key-status",
New: "operator key-status",
UI: ui,
Command: &OperatorKeyStatusCommand{
BaseCommand: getBaseCommand(),
},
}, nil
},
"renew": func() (cli.Command, error) {
return &DeprecatedCommand{
Old: "renew",
New: "lease renew",
UI: ui,
Command: &LeaseRenewCommand{
BaseCommand: getBaseCommand(),
},
}, nil
},
"revoke": func() (cli.Command, error) {
return &DeprecatedCommand{
Old: "revoke",
New: "lease revoke",
UI: ui,
Command: &LeaseRevokeCommand{
BaseCommand: getBaseCommand(),
},
}, nil
},
"mount": func() (cli.Command, error) {
return &DeprecatedCommand{
Old: "mount",
New: "secrets enable",
UI: ui,
Command: &SecretsEnableCommand{
BaseCommand: getBaseCommand(),
},
}, nil
},
"mount-tune": func() (cli.Command, error) {
return &DeprecatedCommand{
Old: "mount-tune",
New: "secrets tune",
UI: ui,
Command: &SecretsTuneCommand{
BaseCommand: getBaseCommand(),
},
}, nil
},
"mounts": func() (cli.Command, error) {
return &DeprecatedCommand{
Old: "mounts",
New: "secrets list",
UI: ui,
Command: &SecretsListCommand{
BaseCommand: getBaseCommand(),
},
}, nil
},
"policies": func() (cli.Command, error) {
return &DeprecatedCommand{
Old: "policies",
New: "policy read\" or \"vault policy list", // lol
UI: ui,
Command: &PoliciesDeprecatedCommand{
BaseCommand: getBaseCommand(),
},
}, nil
},
"policy-delete": func() (cli.Command, error) {
return &DeprecatedCommand{
Old: "policy-delete",
New: "policy delete",
UI: ui,
Command: &PolicyDeleteCommand{
BaseCommand: getBaseCommand(),
},
}, nil
},
"policy-write": func() (cli.Command, error) {
return &DeprecatedCommand{
Old: "policy-write",
New: "policy write",
UI: ui,
Command: &PolicyWriteCommand{
BaseCommand: getBaseCommand(),
},
}, nil
},
"rekey": func() (cli.Command, error) {
return &DeprecatedCommand{
Old: "rekey",
New: "operator rekey",
UI: ui,
Command: &OperatorRekeyCommand{
BaseCommand: getBaseCommand(),
},
}, nil
},
"remount": func() (cli.Command, error) {
return &DeprecatedCommand{
Old: "remount",
New: "secrets move",
UI: ui,
Command: &SecretsMoveCommand{
BaseCommand: getBaseCommand(),
},
}, nil
},
"rotate": func() (cli.Command, error) {
return &DeprecatedCommand{
Old: "rotate",
New: "operator rotate",
UI: ui,
Command: &OperatorRotateCommand{
BaseCommand: getBaseCommand(),
},
}, nil
},
"seal": func() (cli.Command, error) {
return &DeprecatedCommand{
Old: "seal",
New: "operator seal",
UI: ui,
Command: &OperatorSealCommand{
BaseCommand: getBaseCommand(),
},
}, nil
},
"step-down": func() (cli.Command, error) {
return &DeprecatedCommand{
Old: "step-down",
New: "operator step-down",
UI: ui,
Command: &OperatorStepDownCommand{
BaseCommand: getBaseCommand(),
},
}, nil
},
"token-create": func() (cli.Command, error) {
return &DeprecatedCommand{
Old: "token-create",
New: "token create",
UI: ui,
Command: &TokenCreateCommand{
BaseCommand: getBaseCommand(),
},
}, nil
},
"token-lookup": func() (cli.Command, error) {
return &DeprecatedCommand{
Old: "token-lookup",
New: "token lookup",
UI: ui,
Command: &TokenLookupCommand{
BaseCommand: getBaseCommand(),
},
}, nil
},
"token-renew": func() (cli.Command, error) {
return &DeprecatedCommand{
Old: "token-renew",
New: "token renew",
UI: ui,
Command: &TokenRenewCommand{
BaseCommand: getBaseCommand(),
},
}, nil
},
"token-revoke": func() (cli.Command, error) {
return &DeprecatedCommand{
Old: "token-revoke",
New: "token revoke",
UI: ui,
Command: &TokenRevokeCommand{
BaseCommand: getBaseCommand(),
},
}, nil
},
"unmount": func() (cli.Command, error) {
return &DeprecatedCommand{
Old: "unmount",
New: "secrets disable",
UI: ui,
Command: &SecretsDisableCommand{
BaseCommand: getBaseCommand(),
},
}, nil
},
"unseal": func() (cli.Command, error) {
return &DeprecatedCommand{
Old: "unseal",
New: "operator unseal",
UI: ui,
Command: &OperatorUnsealCommand{
BaseCommand: getBaseCommand(),
},
}, nil
},
}
// Add deprecated commands back to the main commands so they parse.
for k, v := range DeprecatedCommands {
if _, ok := Commands[k]; ok {
// Can't deprecate an existing command...
panic(fmt.Sprintf("command %q defined as deprecated and not at the same time!", k))
}
Commands[k] = v
}
}
// MakeShutdownCh returns a channel that can be used for shutdown

View file

@ -28,10 +28,6 @@ type LoginCommand struct {
flagNoPrint bool
flagTokenOnly bool
// Deprecations
// TODO: remove in 0.9.0
flagNoVerify bool
testStdin io.Reader // for tests
}
@ -132,16 +128,6 @@ func (c *LoginCommand) Flags() *FlagSets {
"values will have no affect.",
})
// Deprecations
// TODO: remove in 0.9.0
f.BoolVar(&BoolVar{
Name: "no-verify",
Target: &c.flagNoVerify,
Hidden: true,
Default: false,
Usage: "",
})
return set
}
@ -163,39 +149,6 @@ func (c *LoginCommand) Run(args []string) int {
args = f.Args()
// Deprecations
// TODO: remove in 0.10.0
switch {
case c.flagNoVerify:
if Format(c.UI) == "table" {
c.UI.Warn(wrapAtLength(
"WARNING! The -no-verify flag is deprecated. In the past, Vault " +
"performed a lookup on a token after authentication. This is no " +
"longer the case for all auth methods except \"token\". Vault will " +
"still attempt to perform a lookup when given a token directly " +
"because that is how it gets the list of policies, ttl, and other " +
"metadata. To disable this lookup, specify \"lookup=false\" as a " +
"configuration option to the token auth method, like this:"))
c.UI.Warn("")
c.UI.Warn(" $ vault auth token=ABCD lookup=false")
c.UI.Warn("")
c.UI.Warn("Or omit the token and Vault will prompt for it:")
c.UI.Warn("")
c.UI.Warn(" $ vault auth lookup=false")
c.UI.Warn(" Token (will be hidden): ...")
c.UI.Warn("")
c.UI.Warn(wrapAtLength(
"If you are not using token authentication, you can safely omit this " +
"flag. Vault will not perform a lookup after authentication."))
c.UI.Warn("")
}
// There's no point in passing this to other auth handlers...
if c.flagMethod == "token" {
args = append(args, "lookup=false")
}
}
// Set the right flags if the user requested token-only - this overrides
// any previously configured values, as documented.
if c.flagTokenOnly {

View file

@ -443,51 +443,6 @@ func TestLoginCommand_Run(t *testing.T) {
}
})
// Deprecations
// TODO: remove in 0.9.0
t.Run("deprecated_no_verify", func(t *testing.T) {
t.Parallel()
client, closer := testVaultServer(t)
defer closer()
secret, err := client.Auth().Token().Create(&api.TokenCreateRequest{
Policies: []string{"default"},
TTL: "30m",
NumUses: 1,
})
if err != nil {
t.Fatal(err)
}
token := secret.Auth.ClientToken
_, cmd := testLoginCommand(t)
cmd.client = client
code := cmd.Run([]string{
"-no-verify",
token,
})
if exp := 0; code != exp {
t.Errorf("expected %d to be %d", code, exp)
}
lookup, err := client.Auth().Token().Lookup(token)
if err != nil {
t.Fatal(err)
}
// There was 1 use to start, make sure we didn't use it (verifying would
// use it).
uses, err := lookup.TokenRemainingUses()
if err != nil {
t.Fatal(err)
}
if uses != 1 {
t.Errorf("expected %d to be %d", uses, 1)
}
})
t.Run("no_tabs", func(t *testing.T) {
t.Parallel()

View file

@ -159,12 +159,7 @@ func RunCustom(args []string, runOpts *RunOptions) int {
initCommands(ui, serverCmdUi, runOpts)
// Calculate hidden commands from the deprecated ones
hiddenCommands := make([]string, 0, len(DeprecatedCommands)+1)
for k := range DeprecatedCommands {
hiddenCommands = append(hiddenCommands, k)
}
hiddenCommands = append(hiddenCommands, "version")
hiddenCommands := []string{"version"}
cli := &cli.CLI{
Name: "vault",

View file

@ -36,10 +36,6 @@ type OperatorGenerateRootCommand struct {
flagGenerateOTP bool
flagDRToken bool
// Deprecation
// TODO: remove in 0.9.0
flagGenOTP bool
testStdin io.Reader // for tests
}
@ -179,15 +175,6 @@ func (c *OperatorGenerateRootCommand) Flags() *FlagSets {
"must be provided with each unseal key.",
})
// Deprecations: prefer longer-form, descriptive flags
// TODO: remove in 0.9.0
f.BoolVar(&BoolVar{
Name: "genotp", // -generate-otp
Target: &c.flagGenOTP,
Default: false,
Hidden: true,
})
return set
}
@ -213,18 +200,6 @@ func (c *OperatorGenerateRootCommand) Run(args []string) int {
return 1
}
// Deprecations
// TODO: remove in 0.9.0
switch {
case c.flagGenOTP:
if Format(c.UI) == "table" {
c.UI.Warn(wrapAtLength(
"WARNING! The -gen-otp flag is deprecated. Please use the -generate-otp flag " +
"instead."))
}
c.flagGenerateOTP = c.flagGenOTP
}
client, err := c.Client()
if err != nil {
c.UI.Error(err.Error())

View file

@ -27,7 +27,6 @@ type OperatorInitCommand struct {
flagRootTokenPGPKey string
// HSM
flagStoredShares int
flagRecoveryShares int
flagRecoveryThreshold int
flagRecoveryPGPKeys []string
@ -35,11 +34,6 @@ type OperatorInitCommand struct {
// Consul
flagConsulAuto bool
flagConsulService string
// Deprecations
// TODO: remove in 0.9.0
flagAuto bool
flagCheck bool
}
func (c *OperatorInitCommand) Synopsis() string {
@ -196,32 +190,6 @@ func (c *OperatorInitCommand) Flags() *FlagSets {
"is only used in HSM mode.",
})
// Deprecations
// TODO: remove in 0.9.0
f.BoolVar(&BoolVar{
Name: "check", // prefer -status
Target: &c.flagCheck,
Default: false,
Hidden: true,
Usage: "",
})
f.BoolVar(&BoolVar{
Name: "auto", // prefer -consul-auto
Target: &c.flagAuto,
Default: false,
Hidden: true,
Usage: "",
})
// Kept to keep scripts passing the flag working, but not used
f.IntVar(&IntVar{
Name: "stored-shares",
Target: &c.flagStoredShares,
Default: 0,
Hidden: true,
Usage: "",
})
return set
}
@ -241,23 +209,6 @@ func (c *OperatorInitCommand) Run(args []string) int {
return 1
}
// Deprecations
// TODO: remove in 0.9.0
if c.flagAuto {
if Format(c.UI) == "table" {
c.UI.Warn(wrapAtLength("WARNING! -auto is deprecated. Please use " +
"-consul-auto instead. This will be removed in Vault 1.1."))
}
c.flagConsulAuto = true
}
if c.flagCheck {
if Format(c.UI) == "table" {
c.UI.Warn(wrapAtLength("WARNING! -check is deprecated. Please use " +
"-status instead. This will be removed in Vault 1.1."))
}
c.flagStatus = true
}
args = f.Args()
if len(args) > 0 {
c.UI.Error(fmt.Sprintf("Too many arguments (expected 0, got %d)", len(args)))
@ -271,7 +222,6 @@ func (c *OperatorInitCommand) Run(args []string) int {
PGPKeys: c.flagPGPKeys,
RootTokenPGPKey: c.flagRootTokenPGPKey,
StoredShares: c.flagStoredShares,
RecoveryShares: c.flagRecoveryShares,
RecoveryThreshold: c.flagRecoveryThreshold,
RecoveryPGPKeys: c.flagRecoveryPGPKeys,

View file

@ -36,13 +36,6 @@ type OperatorRekeyCommand struct {
flagBackupDelete bool
flagBackupRetrieve bool
// Deprecations
// TODO: remove in 0.9.0
flagDelete bool
flagRecoveryKey bool
flagRetrieve bool
flagStoredShares int
testStdin io.Reader // for tests
}
@ -216,41 +209,6 @@ func (c *OperatorRekeyCommand) Flags() *FlagSets {
"if the PGP keys were provided and the backup has not been deleted.",
})
// Deprecations
// TODO: remove in 0.9.0
f.BoolVar(&BoolVar{
Name: "delete", // prefer -backup-delete
Target: &c.flagDelete,
Default: false,
Hidden: true,
Usage: "",
})
f.BoolVar(&BoolVar{
Name: "retrieve", // prefer -backup-retrieve
Target: &c.flagRetrieve,
Default: false,
Hidden: true,
Usage: "",
})
f.BoolVar(&BoolVar{
Name: "recovery-key", // prefer -target=recovery
Target: &c.flagRecoveryKey,
Default: false,
Hidden: true,
Usage: "",
})
// Kept to keep scripts passing the flag working, but not used
f.IntVar(&IntVar{
Name: "stored-shares",
Target: &c.flagStoredShares,
Default: 0,
Hidden: true,
Usage: "",
})
return set
}
@ -276,33 +234,6 @@ func (c *OperatorRekeyCommand) Run(args []string) int {
return 1
}
// Deprecations
// TODO: remove in 0.9.0
if c.flagDelete {
if Format(c.UI) == "table" {
c.UI.Warn(wrapAtLength(
"WARNING! The -delete flag is deprecated. Please use -backup-delete " +
"instead. This flag will be removed in Vault 1.1."))
}
c.flagBackupDelete = true
}
if c.flagRetrieve {
if Format(c.UI) == "table" {
c.UI.Warn(wrapAtLength(
"WARNING! The -retrieve flag is deprecated. Please use -backup-retrieve " +
"instead. This flag will be removed in Vault 1.1."))
}
c.flagBackupRetrieve = true
}
if c.flagRecoveryKey {
if Format(c.UI) == "table" {
c.UI.Warn(wrapAtLength(
"WARNING! The -recovery-key flag is deprecated. Please use -target=recovery " +
"instead. This flag will be removed in Vault 1.1."))
}
c.flagTarget = "recovery"
}
// Create the client
client, err := c.Client()
if err != nil {
@ -349,7 +280,6 @@ func (c *OperatorRekeyCommand) init(client *api.Client) int {
status, err := fn(&api.RekeyInitRequest{
SecretShares: c.flagKeyShares,
SecretThreshold: c.flagKeyThreshold,
StoredShares: c.flagStoredShares,
PGPKeys: c.flagPGPKeys,
Backup: c.flagBackup,
RequireVerification: c.flagVerify,

View file

@ -167,7 +167,7 @@ func TestOperatorUnsealCommand_Format(t *testing.T) {
Client: client,
}
args, format, _ := setupEnv([]string{"unseal", "-format", "json"})
args, format, _ := setupEnv([]string{"operator", "unseal", "-format", "json"})
if format != "json" {
t.Fatalf("expected %q, got %q", "json", format)
}

View file

@ -1,57 +0,0 @@
package command
import (
"github.com/mitchellh/cli"
)
// Deprecation
// TODO: remove in 0.9.0
// Compile-time check that PoliciesDeprecatedCommand satisfies cli.Command.
var _ cli.Command = (*PoliciesDeprecatedCommand)(nil)

// PoliciesDeprecatedCommand is the legacy "vault policies" command, kept
// only for backwards compatibility; it forwards to the newer
// "vault policy read" and "vault policy list" commands.
type PoliciesDeprecatedCommand struct {
	*BaseCommand
}
// Synopsis implements cli.Command; it is intentionally empty for this
// deprecated command.
func (c *PoliciesDeprecatedCommand) Synopsis() string { return "" }
// Help returns the help text of the replacement "vault policy list"
// command, since this deprecated command simply forwards to it.
func (c *PoliciesDeprecatedCommand) Help() string {
	listCmd := &PolicyListCommand{BaseCommand: c.BaseCommand}
	return listCmd.Help()
}
// Run dispatches the deprecated "vault policies" invocation to its
// replacement: with an argument it behaves like "vault policy read ARG",
// and with no arguments it behaves like "vault policy list". The original
// argument list (including flags) is forwarded unchanged.
func (c *PoliciesDeprecatedCommand) Run(args []string) int {
	originalArgs := args

	flags := c.flagSet(FlagSetHTTP)
	if err := flags.Parse(args); err != nil {
		c.UI.Error(err.Error())
		return 1
	}
	args = flags.Args()

	// Both replacement commands share the same base configuration.
	base := &BaseCommand{
		UI:          c.UI,
		client:      c.client,
		tokenHelper: c.tokenHelper,
		flagAddress: c.flagAddress,
	}

	// Got an arg, this is trying to read a policy.
	if len(args) > 0 {
		return (&PolicyReadCommand{BaseCommand: base}).Run(originalArgs)
	}

	// No args, probably ran "vault policies" and we want to translate that
	// to "vault policy list".
	return (&PolicyListCommand{BaseCommand: base}).Run(originalArgs)
}

View file

@ -1,96 +0,0 @@
package command
import (
"strings"
"testing"
"github.com/mitchellh/cli"
)
// testPoliciesDeprecatedCommand builds a PoliciesDeprecatedCommand wired
// to a fresh mock UI so tests can inspect its output and error streams.
func testPoliciesDeprecatedCommand(tb testing.TB) (*cli.MockUi, *PoliciesDeprecatedCommand) {
	tb.Helper()

	mockUI := cli.NewMockUi()
	cmd := &PoliciesDeprecatedCommand{
		BaseCommand: &BaseCommand{
			UI: mockUI,
		},
	}
	return mockUI, cmd
}
// TestPoliciesDeprecatedCommand_Run covers the deprecated "vault policies"
// entry point: reading a named policy, listing with no arguments, and
// listing when only flags are supplied.
func TestPoliciesDeprecatedCommand_Run(t *testing.T) {
	t.Parallel()

	// TODO: remove in 0.9.0

	t.Run("deprecated_arg", func(t *testing.T) {
		t.Parallel()

		client, closer := testVaultServer(t)
		defer closer()

		ui, cmd := testPoliciesDeprecatedCommand(t)
		cmd.client = client

		// vault policies ARG -> vault policy read ARG
		exitCode := cmd.Run([]string{"default"})
		if want := 0; exitCode != want {
			t.Errorf("expected %d to be %d: %s", exitCode, want, ui.ErrorWriter.String())
		}

		output := ui.OutputWriter.String()
		if want := "token/"; !strings.Contains(output, want) {
			t.Errorf("expected %q to contain %q", output, want)
		}
	})

	t.Run("deprecated_no_args", func(t *testing.T) {
		t.Parallel()

		client, closer := testVaultServer(t)
		defer closer()

		ui, cmd := testPoliciesDeprecatedCommand(t)
		cmd.client = client

		// vault policies -> vault policy list
		exitCode := cmd.Run([]string{})
		if want := 0; exitCode != want {
			t.Errorf("expected %d to be %d: %s", exitCode, want, ui.ErrorWriter.String())
		}

		output := ui.OutputWriter.String()
		if want := "root"; !strings.Contains(output, want) {
			t.Errorf("expected %q to contain %q", output, want)
		}
	})

	t.Run("deprecated_with_flags", func(t *testing.T) {
		t.Parallel()

		client, closer := testVaultServer(t)
		defer closer()

		ui, cmd := testPoliciesDeprecatedCommand(t)
		cmd.client = client

		// vault policies -flag -> vault policy list
		exitCode := cmd.Run([]string{
			"-address", client.Address(),
		})
		if want := 0; exitCode != want {
			t.Errorf("expected %d to be %d: %s", exitCode, want, ui.ErrorWriter.String())
		}

		output := ui.OutputWriter.String()
		if want := "root"; !strings.Contains(output, want) {
			t.Errorf("expected %q to contain %q", output, want)
		}
	})

	t.Run("no_tabs", func(t *testing.T) {
		t.Parallel()

		_, cmd := testPoliciesDeprecatedCommand(t)
		assertNoTabs(t, cmd)
	})
}

View file

@ -6,6 +6,7 @@ import (
"encoding/base64"
"encoding/hex"
"fmt"
"github.com/hashicorp/vault/helper/metricsutil"
"io"
"io/ioutil"
"net"
@ -23,6 +24,7 @@ import (
metrics "github.com/armon/go-metrics"
"github.com/armon/go-metrics/circonus"
"github.com/armon/go-metrics/datadog"
"github.com/armon/go-metrics/prometheus"
"github.com/hashicorp/errwrap"
log "github.com/hashicorp/go-hclog"
multierror "github.com/hashicorp/go-multierror"
@ -469,7 +471,8 @@ func (c *ServerCommand) Run(args []string) int {
"in a Docker container, provide the IPC_LOCK cap to the container."))
}
if err := c.setupTelemetry(config); err != nil {
metricsHelper, err := c.setupTelemetry(config)
if err != nil {
c.UI.Error(fmt.Sprintf("Error initializing telemetry: %s", err))
return 1
}
@ -563,6 +566,7 @@ func (c *ServerCommand) Run(args []string) int {
AllLoggers: allLoggers,
BuiltinRegistry: builtinplugins.Registry,
DisableKeyEncodingChecks: config.DisablePrintableCheck,
MetricsHelper: metricsHelper,
}
if c.flagDev {
coreConfig.DevToken = c.flagDevRootTokenID
@ -1363,24 +1367,29 @@ func (c *ServerCommand) enableDev(core *vault.Core, coreConfig *vault.CoreConfig
}
// Upgrade the default K/V store
if !c.flagDevLeasedKV && !c.flagDevKVV1 {
req := &logical.Request{
Operation: logical.UpdateOperation,
ClientToken: init.RootToken,
Path: "sys/mounts/secret/tune",
Data: map[string]interface{}{
"options": map[string]string{
"version": "2",
},
kvVer := "2"
if c.flagDevKVV1 {
kvVer = "1"
}
req := &logical.Request{
Operation: logical.UpdateOperation,
ClientToken: init.RootToken,
Path: "sys/mounts/secret",
Data: map[string]interface{}{
"type": "kv",
"path": "secret/",
"description": "key/value secret storage",
"options": map[string]string{
"version": kvVer,
},
}
resp, err := core.HandleRequest(ctx, req)
if err != nil {
return nil, errwrap.Wrapf("error upgrading default K/V store: {{err}}", err)
}
if resp.IsError() {
return nil, errwrap.Wrapf("failed to upgrade default K/V store: {{err}}", resp.Error())
}
},
}
resp, err := core.HandleRequest(ctx, req)
if err != nil {
return nil, errwrap.Wrapf("error creating default K/V store: {{err}}", err)
}
if resp.IsError() {
return nil, errwrap.Wrapf("failed to create default K/V store: {{err}}", resp.Error())
}
return init, nil
@ -1686,8 +1695,8 @@ func (c *ServerCommand) detectRedirect(detect physical.RedirectDetect,
return url.String(), nil
}
// setupTelemetry is used to setup the telemetry sub-systems
func (c *ServerCommand) setupTelemetry(config *server.Config) error {
// setupTelemetry is used to setup the telemetry sub-systems and returns the in-memory sink to be used in http configuration
func (c *ServerCommand) setupTelemetry(config *server.Config) (*metricsutil.MetricsHelper, error) {
/* Setup telemetry
Aggregate on 10 second intervals for 1 minute. Expose the
metrics over stderr when there is a SIGUSR1 received.
@ -1696,10 +1705,10 @@ func (c *ServerCommand) setupTelemetry(config *server.Config) error {
metrics.DefaultInmemSignal(inm)
var telConfig *server.Telemetry
if config.Telemetry == nil {
telConfig = &server.Telemetry{}
} else {
if config.Telemetry != nil {
telConfig = config.Telemetry
} else {
telConfig = &server.Telemetry{}
}
metricsConf := metrics.DefaultConfig("vault")
@ -1707,10 +1716,29 @@ func (c *ServerCommand) setupTelemetry(config *server.Config) error {
// Configure the statsite sink
var fanout metrics.FanoutSink
var prometheusEnabled bool
// Configure the Prometheus sink
if telConfig.PrometheusRetentionTime != 0 {
prometheusEnabled = true
prometheusOpts := prometheus.PrometheusOpts{
Expiration: telConfig.PrometheusRetentionTime,
}
sink, err := prometheus.NewPrometheusSinkFrom(prometheusOpts)
if err != nil {
return nil, err
}
fanout = append(fanout, sink)
}
metricHelper := metricsutil.NewMetricsHelper(inm, prometheusEnabled)
if telConfig.StatsiteAddr != "" {
sink, err := metrics.NewStatsiteSink(telConfig.StatsiteAddr)
if err != nil {
return err
return nil, err
}
fanout = append(fanout, sink)
}
@ -1719,7 +1747,7 @@ func (c *ServerCommand) setupTelemetry(config *server.Config) error {
if telConfig.StatsdAddr != "" {
sink, err := metrics.NewStatsdSink(telConfig.StatsdAddr)
if err != nil {
return err
return nil, err
}
fanout = append(fanout, sink)
}
@ -1755,7 +1783,7 @@ func (c *ServerCommand) setupTelemetry(config *server.Config) error {
sink, err := circonus.NewCirconusSink(cfg)
if err != nil {
return err
return nil, err
}
sink.Start()
fanout = append(fanout, sink)
@ -1770,21 +1798,29 @@ func (c *ServerCommand) setupTelemetry(config *server.Config) error {
sink, err := datadog.NewDogStatsdSink(telConfig.DogStatsDAddr, metricsConf.HostName)
if err != nil {
return errwrap.Wrapf("failed to start DogStatsD sink: {{err}}", err)
return nil, errwrap.Wrapf("failed to start DogStatsD sink: {{err}}", err)
}
sink.SetTags(tags)
fanout = append(fanout, sink)
}
// Initialize the global sink
if len(fanout) > 0 {
fanout = append(fanout, inm)
metrics.NewGlobal(metricsConf, fanout)
if len(fanout) > 1 {
// Hostname enabled will create poor quality metrics name for prometheus
if !telConfig.DisableHostname {
c.UI.Warn("telemetry.disable_hostname has been set to false. Recommended setting is true for Prometheus to avoid poorly named metrics.")
}
} else {
metricsConf.EnableHostname = false
metrics.NewGlobal(metricsConf, inm)
}
return nil
fanout = append(fanout, inm)
_, err := metrics.NewGlobal(metricsConf, fanout)
if err != nil {
return nil, err
}
return metricHelper, nil
}
func (c *ServerCommand) Reload(lock *sync.RWMutex, reloadFuncs *map[string][]reload.ReloadFunc, configPath []string) error {

View file

@ -19,6 +19,10 @@ import (
"github.com/hashicorp/vault/helper/parseutil"
)
const (
prometheusDefaultRetentionTime = 24 * time.Hour
)
// Config is the configuration for the vault server.
type Config struct {
Listeners []*Listener `hcl:"-"`
@ -98,7 +102,10 @@ func DevConfig(ha, transactional bool) *Config {
EnableUI: true,
Telemetry: &Telemetry{},
Telemetry: &Telemetry{
PrometheusRetentionTime: prometheusDefaultRetentionTime,
DisableHostname: true,
},
}
switch {
@ -233,6 +240,12 @@ type Telemetry struct {
// DogStatsdTags are the global tags that should be sent with each packet to dogstatsd
// It is a list of strings, where each string looks like "my_tag_name:my_tag_value"
DogStatsDTags []string `hcl:"dogstatsd_tags"`
// Prometheus:
// PrometheusRetentionTime is the retention time for prometheus metrics if greater than 0.
// Default: 24h
PrometheusRetentionTime time.Duration `hcl:-`
PrometheusRetentionTimeRaw interface{} `hcl:"prometheus_retention_time"`
}
func (s *Telemetry) GoString() string {
@ -864,5 +877,15 @@ func parseTelemetry(result *Config, list *ast.ObjectList) error {
if err := hcl.DecodeObject(&result.Telemetry, item.Val); err != nil {
return multierror.Prefix(err, "telemetry:")
}
if result.Telemetry.PrometheusRetentionTimeRaw != nil {
var err error
if result.Telemetry.PrometheusRetentionTime, err = parseutil.ParseDurationSecond(result.Telemetry.PrometheusRetentionTimeRaw); err != nil {
return err
}
} else {
result.Telemetry.PrometheusRetentionTime = prometheusDefaultRetentionTime
}
return nil
}

View file

@ -48,11 +48,12 @@ func TestLoadConfigFile(t *testing.T) {
},
Telemetry: &Telemetry{
StatsdAddr: "bar",
StatsiteAddr: "foo",
DisableHostname: false,
DogStatsDAddr: "127.0.0.1:7254",
DogStatsDTags: []string{"tag_1:val_1", "tag_2:val_2"},
StatsdAddr: "bar",
StatsiteAddr: "foo",
DisableHostname: false,
DogStatsDAddr: "127.0.0.1:7254",
DogStatsDTags: []string{"tag_1:val_1", "tag_2:val_2"},
PrometheusRetentionTime: prometheusDefaultRetentionTime,
},
DisableCache: true,
@ -121,11 +122,13 @@ func TestLoadConfigFile_topLevel(t *testing.T) {
},
Telemetry: &Telemetry{
StatsdAddr: "bar",
StatsiteAddr: "foo",
DisableHostname: false,
DogStatsDAddr: "127.0.0.1:7254",
DogStatsDTags: []string{"tag_1:val_1", "tag_2:val_2"},
StatsdAddr: "bar",
StatsiteAddr: "foo",
DisableHostname: false,
DogStatsDAddr: "127.0.0.1:7254",
DogStatsDTags: []string{"tag_1:val_1", "tag_2:val_2"},
PrometheusRetentionTime: 30 * time.Second,
PrometheusRetentionTimeRaw: "30s",
},
DisableCache: true,
@ -202,6 +205,7 @@ func TestLoadConfigFile_json(t *testing.T) {
CirconusCheckTags: "",
CirconusBrokerID: "",
CirconusBrokerSelectTag: "",
PrometheusRetentionTime: prometheusDefaultRetentionTime,
},
MaxLeaseTTL: 10 * time.Hour,
@ -288,6 +292,8 @@ func TestLoadConfigFile_json2(t *testing.T) {
CirconusCheckTags: "cat1:tag1,cat2:tag2",
CirconusBrokerID: "0",
CirconusBrokerSelectTag: "dc:sfo",
PrometheusRetentionTime: 30 * time.Second,
PrometheusRetentionTimeRaw: "30s",
},
}
if !reflect.DeepEqual(config, expected) {
@ -336,9 +342,10 @@ func TestLoadConfigDir(t *testing.T) {
EnableRawEndpoint: true,
Telemetry: &Telemetry{
StatsiteAddr: "qux",
StatsdAddr: "baz",
DisableHostname: true,
StatsiteAddr: "qux",
StatsdAddr: "baz",
DisableHostname: true,
PrometheusRetentionTime: prometheusDefaultRetentionTime,
},
MaxLeaseTTL: 10 * time.Hour,

View file

@ -72,7 +72,7 @@ func listenerWrapProxy(ln net.Listener, config map[string]interface{}) (net.List
return newLn, nil
}
func listenerWrapTLS(
func ListenerWrapTLS(
ln net.Listener,
props map[string]string,
config map[string]interface{},

View file

@ -35,7 +35,7 @@ func tcpListenerFactory(config map[string]interface{}, _ io.Writer, ui cli.Ui) (
return nil, nil, nil, err
}
ln = tcpKeepAliveListener{ln.(*net.TCPListener)}
ln = TCPKeepAliveListener{ln.(*net.TCPListener)}
ln, err = listenerWrapProxy(ln, config)
if err != nil {
@ -94,20 +94,20 @@ func tcpListenerFactory(config map[string]interface{}, _ io.Writer, ui cli.Ui) (
config["x_forwarded_for_reject_not_authorized"] = true
}
return listenerWrapTLS(ln, props, config, ui)
return ListenerWrapTLS(ln, props, config, ui)
}
// tcpKeepAliveListener sets TCP keep-alive timeouts on accepted
// TCPKeepAliveListener sets TCP keep-alive timeouts on accepted
// connections. It's used by ListenAndServe and ListenAndServeTLS so
// dead TCP connections (e.g. closing laptop mid-download) eventually
// go away.
//
// This is copied directly from the Go source code.
type tcpKeepAliveListener struct {
type TCPKeepAliveListener struct {
*net.TCPListener
}
func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) {
func (ln TCPKeepAliveListener) Accept() (c net.Conn, err error) {
tc, err := ln.AcceptTCP()
if err != nil {
return

View file

@ -26,6 +26,7 @@ telemetry {
statsite_address = "foo"
dogstatsd_addr = "127.0.0.1:7254"
dogstatsd_tags = ["tag_1:val_1", "tag_2:val_2"]
prometheus_retention_time = "30s"
}
max_lease_ttl = "10h"

View file

@ -42,6 +42,7 @@
"circonus_check_display_name": "node1:vault",
"circonus_check_tags": "cat1:tag1,cat2:tag2",
"circonus_broker_id": "0",
"circonus_broker_select_tag": "dc:sfo"
"circonus_broker_select_tag": "dc:sfo",
"prometheus_retention_time": "30s"
}
}

View file

@ -29,9 +29,6 @@ type TokenCreateCommand struct {
flagType string
flagMetadata map[string]string
flagPolicies []string
// Deprecated flags
flagLease time.Duration
}
func (c *TokenCreateCommand) Synopsis() string {
@ -179,15 +176,6 @@ func (c *TokenCreateCommand) Flags() *FlagSets {
"specified multiple times to attach multiple policies.",
})
// Deprecated flags
// TODO: remove in 0.9.0
f.DurationVar(&DurationVar{
Name: "lease", // prefer -ttl
Target: &c.flagLease,
Default: 0,
Hidden: true,
})
return set
}
@ -213,14 +201,6 @@ func (c *TokenCreateCommand) Run(args []string) int {
return 1
}
// TODO: remove in 0.9.0
if c.flagLease != 0 {
if Format(c.UI) == "table" {
c.UI.Warn("The -lease flag is deprecated. Please use -ttl instead.")
c.flagTTL = c.flagLease
}
}
if c.flagType == "batch" {
c.flagRenewable = false
}

View file

@ -97,19 +97,6 @@ func (c *TokenRenewCommand) Run(args []string) int {
// Use the local token
case 1:
token = strings.TrimSpace(args[0])
case 2:
// TODO: remove in 0.9.0 - backwards compat
if Format(c.UI) == "table" {
c.UI.Warn("Specifying increment as a second argument is deprecated. " +
"Please use -increment instead.")
}
token = strings.TrimSpace(args[0])
parsed, err := time.ParseDuration(appendDurationSuffix(args[1]))
if err != nil {
c.UI.Error(fmt.Sprintf("Invalid increment: %s", err))
return 1
}
increment = parsed
default:
c.UI.Error(fmt.Sprintf("Too many arguments (expected 1, got %d)", len(args)))
return 1

View file

@ -73,6 +73,7 @@ func newRegistry() *registry {
"jwt": credJWT.Factory,
"kubernetes": credKube.Factory,
"ldap": credLdap.Factory,
"oidc": credJWT.Factory,
"okta": credOkta.Factory,
"radius": credRadius.Factory,
"userpass": credUserpass.Factory,

View file

@ -17,12 +17,21 @@ import (
"crypto/x509"
"encoding/pem"
"fmt"
"math/big"
"strings"
"github.com/hashicorp/errwrap"
"github.com/hashicorp/vault/helper/errutil"
)
// This can be one of a few key types so the different params may or may not be filled
type ClusterKeyParams struct {
Type string `json:"type" structs:"type" mapstructure:"type"`
X *big.Int `json:"x" structs:"x" mapstructure:"x"`
Y *big.Int `json:"y" structs:"y" mapstructure:"y"`
D *big.Int `json:"d" structs:"d" mapstructure:"d"`
}
// Secret is used to attempt to unmarshal a Vault secret
// JSON response, as a convenience
type Secret struct {

View file

@ -22,26 +22,31 @@ func ConfigFields() map[string]*framework.FieldSchema {
Type: framework.TypeString,
Default: "ldap://127.0.0.1",
Description: "LDAP URL to connect to (default: ldap://127.0.0.1). Multiple URLs can be specified by concatenating them with commas; they will be tried in-order.",
DisplayName: "URL",
},
"userdn": {
Type: framework.TypeString,
Description: "LDAP domain to use for users (eg: ou=People,dc=example,dc=org)",
DisplayName: "User DN",
},
"binddn": {
Type: framework.TypeString,
Description: "LDAP DN for searching for the user DN (optional)",
DisplayName: "Name of Object to bind (binddn)",
},
"bindpass": {
Type: framework.TypeString,
Description: "LDAP password for searching for the user DN (optional)",
Type: framework.TypeString,
Description: "LDAP password for searching for the user DN (optional)",
DisplaySensitive: true,
},
"groupdn": {
Type: framework.TypeString,
Description: "LDAP search base to use for group membership search (eg: ou=Groups,dc=example,dc=org)",
DisplayName: "Group DN",
},
"groupfilter": {
@ -60,17 +65,20 @@ Default: (|(memberUid={{.Username}})(member={{.UserDN}})(uniqueMember={{.UserDN}
in order to enumerate user group membership.
Examples: "cn" or "memberOf", etc.
Default: cn`,
DisplayName: "Group Attribute",
},
"upndomain": {
Type: framework.TypeString,
Description: "Enables userPrincipalDomain login with [username]@UPNDomain (optional)",
DisplayName: "User Principal (UPN) Domain",
},
"userattr": {
Type: framework.TypeString,
Default: "cn",
Description: "Attribute used for users (default: cn)",
DisplayName: "User Attribute",
},
"certificate": {
@ -81,28 +89,35 @@ Default: cn`,
"discoverdn": {
Type: framework.TypeBool,
Description: "Use anonymous bind to discover the bind DN of a user (optional)",
DisplayName: "Discover DN",
},
"insecure_tls": {
Type: framework.TypeBool,
Description: "Skip LDAP server SSL Certificate verification - VERY insecure (optional)",
DisplayName: "Insecure TLS",
},
"starttls": {
Type: framework.TypeBool,
Description: "Issue a StartTLS command after establishing unencrypted connection (optional)",
DisplayName: "Issue StartTLS command after establishing an unencrypted connection",
},
"tls_min_version": {
Type: framework.TypeString,
Default: "tls12",
Description: "Minimum TLS version to use. Accepted values are 'tls10', 'tls11' or 'tls12'. Defaults to 'tls12'",
Type: framework.TypeString,
Default: "tls12",
Description: "Minimum TLS version to use. Accepted values are 'tls10', 'tls11' or 'tls12'. Defaults to 'tls12'",
DisplayName: "Minimum TLS Version",
AllowedValues: []interface{}{"tls10", "tls11", "tls12"},
},
"tls_max_version": {
Type: framework.TypeString,
Default: "tls12",
Description: "Maximum TLS version to use. Accepted values are 'tls10', 'tls11' or 'tls12'. Defaults to 'tls12'",
Type: framework.TypeString,
Default: "tls12",
Description: "Maximum TLS version to use. Accepted values are 'tls10', 'tls11' or 'tls12'. Defaults to 'tls12'",
DisplayName: "Maxumum TLS Version",
AllowedValues: []interface{}{"tls10", "tls11", "tls12"},
},
"deny_null_bind": {

View file

@ -0,0 +1,104 @@
package metricsutil
import (
"bytes"
"encoding/json"
"fmt"
"github.com/armon/go-metrics"
"github.com/hashicorp/vault/logical"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/common/expfmt"
"strings"
)
const (
	// OpenMetricsMIMEType is the Accept-header prefix that selects the
	// Prometheus/OpenMetrics text exposition format.
	OpenMetricsMIMEType = "application/openmetrics-text"

	// PrometheusMetricFormat names the Prometheus response format used by
	// MetricsHelper.ResponseForFormat.
	PrometheusMetricFormat = "prometheus"
)
// MetricsHelper bundles the telemetry state needed to serve metrics over
// the API: the in-memory sink plus a flag recording whether a Prometheus
// sink was configured.
type MetricsHelper struct {
	inMemSink         *metrics.InmemSink
	PrometheusEnabled bool
}
// NewMetricsHelper constructs a MetricsHelper around the given in-memory
// sink, recording whether a Prometheus sink was also configured.
func NewMetricsHelper(inMem *metrics.InmemSink, enablePrometheus bool) *MetricsHelper {
	return &MetricsHelper{
		inMemSink:         inMem,
		PrometheusEnabled: enablePrometheus,
	}
}
// FormatFromRequest inspects the request's Accept header and returns the
// metric format it selects: PrometheusMetricFormat when the first Accept
// value has the OpenMetrics MIME type prefix, otherwise the empty string
// (meaning the generic JSON format).
func FormatFromRequest(req *logical.Request) string {
	acceptHeaders := req.Headers["Accept"]
	if len(acceptHeaders) > 0 {
		acceptHeader := acceptHeaders[0]
		if strings.HasPrefix(acceptHeader, OpenMetricsMIMEType) {
			// Use the named constant rather than a duplicate literal so the
			// value stays in sync with ResponseForFormat's switch.
			return PrometheusMetricFormat
		}
	}
	return ""
}
// ResponseForFormat renders the current metrics in the requested format.
// The empty string selects the generic JSON rendering; any format other
// than PrometheusMetricFormat is an error.
func (m *MetricsHelper) ResponseForFormat(format string) (*logical.Response, error) {
	switch format {
	case PrometheusMetricFormat:
		return m.PrometheusResponse()
	case "":
		return m.GenericResponse()
	default:
		// %q quotes the format safely (including empty/odd values) without
		// hand-escaped quotes in the format string.
		return nil, fmt.Errorf("metric response format %q unknown", format)
	}
}
// PrometheusResponse gathers all registered Prometheus metric families and
// renders them in the Prometheus text exposition format. When no
// Prometheus sink is enabled it returns a 400 response instead of an
// error so callers can pass it straight back to the client.
func (m *MetricsHelper) PrometheusResponse() (*logical.Response, error) {
	if !m.PrometheusEnabled {
		return &logical.Response{
			Data: map[string]interface{}{
				logical.HTTPContentType: "text/plain",
				logical.HTTPRawBody:     "prometheus is not enabled",
				logical.HTTPStatusCode:  400,
			},
		}, nil
	}
	// A gather error is only fatal if it left us with nothing to encode;
	// partial results are still served.
	metricsFamilies, err := prometheus.DefaultGatherer.Gather()
	if err != nil && len(metricsFamilies) == 0 {
		return nil, fmt.Errorf("no prometheus metrics could be decoded: %s", err)
	}

	// Encode the families into a local buffer. Note: the buffer's bytes are
	// returned below via buf.Bytes(), so the buffer must not be reset or
	// reused before the response is consumed (the original's
	// `defer buf.Reset()` was dropped for exactly that reason).
	buf := &bytes.Buffer{}
	e := expfmt.NewEncoder(buf, expfmt.FmtText)
	for _, mf := range metricsFamilies {
		if err := e.Encode(mf); err != nil {
			return nil, fmt.Errorf("error during the encoding of metrics: %s", err)
		}
	}
	return &logical.Response{
		Data: map[string]interface{}{
			logical.HTTPContentType: string(expfmt.FmtText),
			logical.HTTPRawBody:     buf.Bytes(),
			logical.HTTPStatusCode:  200,
		},
	}, nil
}
// GenericResponse renders a JSON summary of the in-memory metrics sink,
// suitable for clients that did not request the Prometheus format.
func (m *MetricsHelper) GenericResponse() (*logical.Response, error) {
	summary, err := m.inMemSink.DisplayMetrics(nil, nil)
	if err != nil {
		return nil, fmt.Errorf("error while fetching the in-memory metrics: %s", err)
	}
	content, err := json.Marshal(summary)
	if err != nil {
		return nil, fmt.Errorf("error while marshalling the in-memory metrics: %s", err)
	}
	return &logical.Response{
		Data: map[string]interface{}{
			logical.HTTPContentType: "application/json",
			logical.HTTPRawBody:     content,
			logical.HTTPStatusCode:  200,
		},
	}, nil
}

View file

@ -453,3 +453,18 @@ func WaitForNCoresSealed(t testing.T, cluster *vault.TestCluster, n int) {
t.Fatalf("%d cores were not sealed", n)
}
// WaitForActiveNode polls the cluster for up to ~10 seconds (10 attempts,
// one second apart) until some core reports that it is no longer a
// standby, returning that core. The test fails fatally if no node becomes
// active in time.
func WaitForActiveNode(t testing.T, cluster *vault.TestCluster) *vault.TestClusterCore {
	for attempt := 0; attempt < 10; attempt++ {
		for _, candidate := range cluster.Cores {
			standby, _ := candidate.Core.Standby()
			if !standby {
				return candidate
			}
		}
		time.Sleep(time.Second)
	}

	t.Fatalf("node did not become active")
	return nil
}

View file

@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"errors"
"github.com/go-test/deep"
"net/http"
"net/http/httptest"
"net/textproto"
@ -285,6 +286,7 @@ func TestSysMounts_headerAuth(t *testing.T) {
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
"force_no_cache": false,
"passthrough_request_headers": []interface{}{"Accept"},
},
"local": false,
"seal_wrap": false,
@ -334,6 +336,7 @@ func TestSysMounts_headerAuth(t *testing.T) {
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
"force_no_cache": false,
"passthrough_request_headers": []interface{}{"Accept"},
},
"local": false,
"seal_wrap": false,
@ -376,8 +379,8 @@ func TestSysMounts_headerAuth(t *testing.T) {
expected["data"].(map[string]interface{})[k].(map[string]interface{})["accessor"] = v.(map[string]interface{})["accessor"]
}
if !reflect.DeepEqual(actual, expected) {
t.Fatalf("bad:\nExpected: %#v\nActual: %#v\n", expected, actual)
if diff := deep.Equal(actual, expected); len(diff) > 0 {
t.Fatalf("bad, diff: %#v", diff)
}
}

View file

@ -2,6 +2,7 @@ package http
import (
"encoding/json"
"github.com/go-test/deep"
"reflect"
"testing"
@ -45,6 +46,7 @@ func TestSysMounts(t *testing.T) {
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
"force_no_cache": false,
"passthrough_request_headers": []interface{}{"Accept"},
},
"local": false,
"seal_wrap": false,
@ -94,6 +96,7 @@ func TestSysMounts(t *testing.T) {
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
"force_no_cache": false,
"passthrough_request_headers": []interface{}{"Accept"},
},
"local": false,
"seal_wrap": false,
@ -135,8 +138,8 @@ func TestSysMounts(t *testing.T) {
expected["data"].(map[string]interface{})[k].(map[string]interface{})["accessor"] = v.(map[string]interface{})["accessor"]
}
if !reflect.DeepEqual(actual, expected) {
t.Fatalf("bad: expected: %#v\nactual: %#v\n", expected, actual)
if diff := deep.Equal(actual, expected); len(diff) > 0 {
t.Fatalf("bad, diff: %#v", diff)
}
}
@ -197,6 +200,7 @@ func TestSysMount(t *testing.T) {
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
"force_no_cache": false,
"passthrough_request_headers": []interface{}{"Accept"},
},
"local": false,
"seal_wrap": false,
@ -258,6 +262,7 @@ func TestSysMount(t *testing.T) {
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
"force_no_cache": false,
"passthrough_request_headers": []interface{}{"Accept"},
},
"local": false,
"seal_wrap": false,
@ -299,9 +304,10 @@ func TestSysMount(t *testing.T) {
expected["data"].(map[string]interface{})[k].(map[string]interface{})["accessor"] = v.(map[string]interface{})["accessor"]
}
if !reflect.DeepEqual(actual, expected) {
t.Fatalf("bad: expected: %#v\nactual: %#v\n", expected, actual)
if diff := deep.Equal(actual, expected); len(diff) > 0 {
t.Fatalf("bad, diff: %#v", diff)
}
}
func TestSysMount_put(t *testing.T) {
@ -380,6 +386,7 @@ func TestSysRemount(t *testing.T) {
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
"force_no_cache": false,
"passthrough_request_headers": []interface{}{"Accept"},
},
"local": false,
"seal_wrap": false,
@ -441,6 +448,7 @@ func TestSysRemount(t *testing.T) {
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
"force_no_cache": false,
"passthrough_request_headers": []interface{}{"Accept"},
},
"local": false,
"seal_wrap": false,
@ -483,7 +491,7 @@ func TestSysRemount(t *testing.T) {
}
if !reflect.DeepEqual(actual, expected) {
t.Fatalf("bad:\ngot\n%#v\nexpected\n%#v\n", actual, expected)
t.Fatalf("bad:\nExpected: %#v\nActual: %#v\n", expected, actual)
}
}
@ -532,6 +540,7 @@ func TestSysUnmount(t *testing.T) {
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
"force_no_cache": false,
"passthrough_request_headers": []interface{}{"Accept"},
},
"local": false,
"seal_wrap": false,
@ -581,6 +590,7 @@ func TestSysUnmount(t *testing.T) {
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
"force_no_cache": false,
"passthrough_request_headers": []interface{}{"Accept"},
},
"local": false,
"seal_wrap": false,
@ -622,8 +632,8 @@ func TestSysUnmount(t *testing.T) {
expected["data"].(map[string]interface{})[k].(map[string]interface{})["accessor"] = v.(map[string]interface{})["accessor"]
}
if !reflect.DeepEqual(actual, expected) {
t.Fatalf("bad: %#v", actual)
if diff := deep.Equal(actual, expected); len(diff) > 0 {
t.Fatalf("bad, diff: %#v", diff)
}
}
@ -766,6 +776,7 @@ func TestSysTuneMount(t *testing.T) {
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
"force_no_cache": false,
"passthrough_request_headers": []interface{}{"Accept"},
},
"local": false,
"seal_wrap": false,
@ -827,6 +838,7 @@ func TestSysTuneMount(t *testing.T) {
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
"force_no_cache": false,
"passthrough_request_headers": []interface{}{"Accept"},
},
"local": false,
"seal_wrap": false,
@ -868,8 +880,8 @@ func TestSysTuneMount(t *testing.T) {
expected["data"].(map[string]interface{})[k].(map[string]interface{})["accessor"] = v.(map[string]interface{})["accessor"]
}
if !reflect.DeepEqual(actual, expected) {
t.Fatalf("bad: %#v", actual)
if diff := deep.Equal(actual, expected); len(diff) > 0 {
t.Fatalf("bad, diff: %#v", diff)
}
// Shorter than system default
@ -956,6 +968,7 @@ func TestSysTuneMount(t *testing.T) {
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
"force_no_cache": false,
"passthrough_request_headers": []interface{}{"Accept"},
},
"local": false,
"seal_wrap": false,
@ -1017,6 +1030,7 @@ func TestSysTuneMount(t *testing.T) {
"default_lease_ttl": json.Number("0"),
"max_lease_ttl": json.Number("0"),
"force_no_cache": false,
"passthrough_request_headers": []interface{}{"Accept"},
},
"local": false,
"seal_wrap": false,
@ -1059,8 +1073,8 @@ func TestSysTuneMount(t *testing.T) {
expected["data"].(map[string]interface{})[k].(map[string]interface{})["accessor"] = v.(map[string]interface{})["accessor"]
}
if !reflect.DeepEqual(actual, expected) {
t.Fatalf("bad:\nExpected: %#v\nActual:%#v", expected, actual)
if diff := deep.Equal(actual, expected); len(diff) > 0 {
t.Fatalf("bad, diff: %#v", diff)
}
// Check simple configuration endpoint

View file

@ -155,6 +155,7 @@ type OASSchema struct {
Format string `json:"format,omitempty"`
Pattern string `json:"pattern,omitempty"`
Enum []interface{} `json:"enum,omitempty"`
Default interface{} `json:"default,omitempty"`
Example interface{} `json:"example,omitempty"`
Deprecated bool `json:"deprecated,omitempty"`
Required bool `json:"required,omitempty"`
@ -263,6 +264,7 @@ func documentPath(p *Path, specialPaths *logical.Paths, backendType logical.Back
Type: t.baseType,
Pattern: t.pattern,
Enum: field.AllowedValues,
Default: field.Default,
DisplayName: field.DisplayName,
DisplayValue: field.DisplayValue,
DisplaySensitive: field.DisplaySensitive,
@ -321,6 +323,7 @@ func documentPath(p *Path, specialPaths *logical.Paths, backendType logical.Back
Format: openapiField.format,
Pattern: openapiField.pattern,
Enum: field.AllowedValues,
Default: field.Default,
Required: field.Required,
Deprecated: field.Deprecated,
DisplayName: field.DisplayName,

View file

@ -326,6 +326,7 @@ func TestOpenAPI_Paths(t *testing.T) {
},
"name": {
Type: TypeNameString,
Default: "Larry",
Description: "the name",
},
"age": {

View file

@ -85,6 +85,7 @@
"name": {
"type": "string",
"description": "the name",
"default": "Larry",
"pattern": "\\w([\\w-.]*\\w)?"
}
}

View file

@ -2,7 +2,6 @@ package plugin
import (
"context"
"net/rpc"
"sync/atomic"
"google.golang.org/grpc"
@ -13,16 +12,9 @@ import (
"github.com/hashicorp/vault/logical/plugin/pb"
)
var _ plugin.Plugin = (*BackendPlugin)(nil)
var _ plugin.GRPCPlugin = (*BackendPlugin)(nil)
var _ plugin.Plugin = (*GRPCBackendPlugin)(nil)
var _ plugin.GRPCPlugin = (*GRPCBackendPlugin)(nil)
// BackendPlugin is the plugin.Plugin implementation
type BackendPlugin struct {
*GRPCBackendPlugin
}
// GRPCBackendPlugin is the plugin.Plugin implementation that only supports GRPC
// transport
type GRPCBackendPlugin struct {
@ -34,26 +26,6 @@ type GRPCBackendPlugin struct {
plugin.NetRPCUnsupportedPlugin
}
// Server gets called when on plugin.Serve()
func (b *BackendPlugin) Server(broker *plugin.MuxBroker) (interface{}, error) {
return &backendPluginServer{
factory: b.Factory,
broker: broker,
// We pass the logger down into the backend so go-plugin will forward
// logs for us.
logger: b.Logger,
}, nil
}
// Client gets called on plugin.NewClient()
func (b BackendPlugin) Client(broker *plugin.MuxBroker, c *rpc.Client) (interface{}, error) {
return &backendPluginClient{
client: c,
broker: broker,
metadataMode: b.MetadataMode,
}, nil
}
func (b GRPCBackendPlugin) GRPCServer(broker *plugin.GRPCBroker, s *grpc.Server) error {
pb.RegisterBackendServer(s, &backendGRPCPluginServer{
broker: broker,

View file

@ -1,248 +0,0 @@
package plugin

import (
	"context"
	"errors"
	"net/rpc"

	log "github.com/hashicorp/go-hclog"
	plugin "github.com/hashicorp/go-plugin"
	"github.com/hashicorp/vault/logical"
)

var (
	// ErrClientInMetadataMode is returned by client methods that cannot be
	// performed while the plugin is running in metadata mode.
	ErrClientInMetadataMode = errors.New("plugin client can not perform action while in metadata mode")
)

// backendPluginClient implements logical.Backend and is the
// go-plugin client.
type backendPluginClient struct {
	broker       *plugin.MuxBroker
	client       *rpc.Client
	metadataMode bool

	// system and logger are stored during Setup so the getter methods can
	// return them without a round trip to the plugin server.
	system logical.SystemView
	logger log.Logger
}

// HandleRequestArgs is the args for HandleRequest method.
type HandleRequestArgs struct {
	StorageID uint32
	Request   *logical.Request
}

// HandleRequestReply is the reply for HandleRequest method.
type HandleRequestReply struct {
	Response *logical.Response
	Error    error
}

// SpecialPathsReply is the reply for SpecialPaths method.
type SpecialPathsReply struct {
	Paths *logical.Paths
}

// SystemReply is the reply for System method.
type SystemReply struct {
	SystemView logical.SystemView
	Error      error
}

// HandleExistenceCheckArgs is the args for HandleExistenceCheck method.
type HandleExistenceCheckArgs struct {
	StorageID uint32
	Request   *logical.Request
}

// HandleExistenceCheckReply is the reply for HandleExistenceCheck method.
type HandleExistenceCheckReply struct {
	CheckFound bool
	Exists     bool
	Error      error
}

// SetupArgs is the args for Setup method.
type SetupArgs struct {
	StorageID   uint32
	LoggerID    uint32
	SysViewID   uint32
	Config      map[string]string
	BackendUUID string
}

// SetupReply is the reply for Setup method.
type SetupReply struct {
	Error error
}

// TypeReply is the reply for the Type method.
type TypeReply struct {
	Type logical.BackendType
}
// HandleRequest forwards the request to the plugin server over net/rpc and
// returns its response. It fails immediately in metadata mode.
func (b *backendPluginClient) HandleRequest(ctx context.Context, req *logical.Request) (*logical.Response, error) {
	if b.metadataMode {
		return nil, ErrClientInMetadataMode
	}

	// Do not send the storage, since go-plugin cannot serialize
	// interfaces. The server will pick up the storage from the shim.
	req.Storage = nil
	args := &HandleRequestArgs{
		Request: req,
	}
	var reply HandleRequestReply

	if req.Connection != nil {
		// Strip ConnState for the wire call — presumably it cannot be
		// serialized — and restore it afterwards so the caller's request is
		// left unchanged.
		oldConnState := req.Connection.ConnState
		req.Connection.ConnState = nil
		defer func() {
			req.Connection.ConnState = oldConnState
		}()
	}

	err := b.client.Call("Plugin.HandleRequest", args, &reply)
	if err != nil {
		return nil, err
	}
	if reply.Error != nil {
		// Map the wire representation of ErrUnsupportedOperation back to the
		// canonical sentinel so callers can compare against it directly.
		if reply.Error.Error() == logical.ErrUnsupportedOperation.Error() {
			return nil, logical.ErrUnsupportedOperation
		}

		return reply.Response, reply.Error
	}

	return reply.Response, nil
}
// SpecialPaths asks the plugin server for the backend's special path
// configuration. A transport failure is reported as a nil result.
func (b *backendPluginClient) SpecialPaths() *logical.Paths {
	var reply SpecialPathsReply
	if err := b.client.Call("Plugin.SpecialPaths", new(interface{}), &reply); err != nil {
		return nil
	}
	return reply.Paths
}
// System returns vault's system view. The backend client stores the view during
// Setup, so there is no need to shim the system just to get it back.
func (b *backendPluginClient) System() logical.SystemView {
	return b.system
}

// Logger returns vault's logger. The backend client stores the logger during
// Setup, so there is no need to shim the logger just to get it back.
func (b *backendPluginClient) Logger() log.Logger {
	return b.logger
}
// HandleExistenceCheck forwards an existence check to the plugin server over
// net/rpc. It fails immediately in metadata mode.
func (b *backendPluginClient) HandleExistenceCheck(ctx context.Context, req *logical.Request) (bool, bool, error) {
	if b.metadataMode {
		return false, false, ErrClientInMetadataMode
	}

	// Do not send the storage, since go-plugin cannot serialize
	// interfaces. The server will pick up the storage from the shim.
	req.Storage = nil
	args := &HandleExistenceCheckArgs{
		Request: req,
	}
	var reply HandleExistenceCheckReply

	if req.Connection != nil {
		// Strip ConnState for the wire call and restore it afterwards so the
		// caller's request is left unchanged.
		oldConnState := req.Connection.ConnState
		req.Connection.ConnState = nil
		defer func() {
			req.Connection.ConnState = oldConnState
		}()
	}

	err := b.client.Call("Plugin.HandleExistenceCheck", args, &reply)
	if err != nil {
		return false, false, err
	}
	if reply.Error != nil {
		// THINKING: Should this be a switch on all error types?
		if reply.Error.Error() == logical.ErrUnsupportedPath.Error() {
			return false, false, logical.ErrUnsupportedPath
		}
		return false, false, reply.Error
	}

	return reply.CheckFound, reply.Exists, nil
}
// Cleanup tells the plugin server to release its resources. The RPC error is
// intentionally ignored: cleanup is best-effort.
func (b *backendPluginClient) Cleanup(ctx context.Context) {
	b.client.Call("Plugin.Cleanup", new(interface{}), &struct{}{})
}

// Initialize forwards initialization to the plugin server. It fails
// immediately in metadata mode.
func (b *backendPluginClient) Initialize(ctx context.Context) error {
	if b.metadataMode {
		return ErrClientInMetadataMode
	}
	err := b.client.Call("Plugin.Initialize", new(interface{}), &struct{}{})
	return err
}

// InvalidateKey notifies the plugin server that the given key was
// invalidated. It is a silent no-op in metadata mode, and the RPC error is
// intentionally ignored.
func (b *backendPluginClient) InvalidateKey(ctx context.Context, key string) {
	if b.metadataMode {
		return
	}
	b.client.Call("Plugin.InvalidateKey", key, &struct{}{})
}
// Setup serves storage and system-view shims over the plugin broker, then
// calls Setup on the plugin server so it can construct the real backend. In
// metadata mode the shims are replaced with no-op/static implementations so
// the plugin cannot touch real storage or system state.
func (b *backendPluginClient) Setup(ctx context.Context, config *logical.BackendConfig) error {
	// Shim logical.Storage
	storageImpl := config.StorageView
	if b.metadataMode {
		storageImpl = &NOOPStorage{}
	}
	storageID := b.broker.NextId()
	go b.broker.AcceptAndServe(storageID, &StorageServer{
		impl: storageImpl,
	})

	// Shim logical.SystemView
	sysViewImpl := config.System
	if b.metadataMode {
		sysViewImpl = &logical.StaticSystemView{}
	}
	sysViewID := b.broker.NextId()
	go b.broker.AcceptAndServe(sysViewID, &SystemViewServer{
		impl: sysViewImpl,
	})

	args := &SetupArgs{
		StorageID:   storageID,
		SysViewID:   sysViewID,
		Config:      config.Config,
		BackendUUID: config.BackendUUID,
	}
	var reply SetupReply

	err := b.client.Call("Plugin.Setup", args, &reply)
	if err != nil {
		return err
	}
	if reply.Error != nil {
		return reply.Error
	}

	// Set system and logger for getter methods
	b.system = config.System
	b.logger = config.Logger

	return nil
}

// Type queries the plugin server for its backend type, reporting
// logical.TypeUnknown if the RPC call fails.
func (b *backendPluginClient) Type() logical.BackendType {
	var reply TypeReply
	err := b.client.Call("Plugin.Type", new(interface{}), &reply)
	if err != nil {
		return logical.TypeUnknown
	}

	return logical.BackendType(reply.Type)
}

View file

@ -1,147 +0,0 @@
package plugin

import (
	"context"
	"errors"
	"net/rpc"

	hclog "github.com/hashicorp/go-hclog"
	plugin "github.com/hashicorp/go-plugin"
	"github.com/hashicorp/vault/helper/pluginutil"
	"github.com/hashicorp/vault/logical"
)

var (
	// ErrServerInMetadataMode is returned by server methods that cannot be
	// performed while the plugin is running in metadata mode.
	ErrServerInMetadataMode = errors.New("plugin server can not perform action while in metadata mode")
)

// backendPluginServer is the RPC server that backendPluginClient talks to,
// its methods conforming to requirements by net/rpc
type backendPluginServer struct {
	broker  *plugin.MuxBroker
	backend logical.Backend
	factory logical.Factory

	logger hclog.Logger

	// RPC clients for the storage and system-view shims served by the plugin
	// client; dialed during Setup and closed during Cleanup.
	sysViewClient *rpc.Client
	storageClient *rpc.Client
}
// HandleRequest runs the request against the real backend, attaching the
// shimmed storage client first. It fails immediately in metadata mode.
func (b *backendPluginServer) HandleRequest(args *HandleRequestArgs, reply *HandleRequestReply) error {
	if pluginutil.InMetadataMode() {
		return ErrServerInMetadataMode
	}

	storage := &StorageClient{client: b.storageClient}
	args.Request.Storage = storage

	resp, err := b.backend.HandleRequest(context.Background(), args.Request)
	*reply = HandleRequestReply{
		Response: resp,
		Error:    wrapError(err),
	}

	return nil
}

// SpecialPaths serves the backend's special path configuration over RPC.
func (b *backendPluginServer) SpecialPaths(_ interface{}, reply *SpecialPathsReply) error {
	*reply = SpecialPathsReply{
		Paths: b.backend.SpecialPaths(),
	}
	return nil
}

// HandleExistenceCheck runs the existence check against the real backend,
// attaching the shimmed storage client first. It fails immediately in
// metadata mode.
func (b *backendPluginServer) HandleExistenceCheck(args *HandleExistenceCheckArgs, reply *HandleExistenceCheckReply) error {
	if pluginutil.InMetadataMode() {
		return ErrServerInMetadataMode
	}

	storage := &StorageClient{client: b.storageClient}
	args.Request.Storage = storage

	checkFound, exists, err := b.backend.HandleExistenceCheck(context.TODO(), args.Request)
	*reply = HandleExistenceCheckReply{
		CheckFound: checkFound,
		Exists:     exists,
		Error:      wrapError(err),
	}

	return nil
}

// Cleanup cleans up the real backend and closes the RPC clients created
// during Setup.
func (b *backendPluginServer) Cleanup(_ interface{}, _ *struct{}) error {
	b.backend.Cleanup(context.Background())

	// Close rpc clients
	b.sysViewClient.Close()
	b.storageClient.Close()
	return nil
}

// InvalidateKey forwards key invalidation to the real backend. It fails
// immediately in metadata mode.
func (b *backendPluginServer) InvalidateKey(args string, _ *struct{}) error {
	if pluginutil.InMetadataMode() {
		return ErrServerInMetadataMode
	}

	b.backend.InvalidateKey(context.Background(), args)
	return nil
}
// Setup dials into the plugin's broker to get a shimmed storage, logger, and
// system view of the backend. This method also instantiates the underlying
// backend through its factory func for the server side of the plugin.
//
// Errors are reported through reply.Error rather than the method's return
// value so they travel back over net/rpc intact.
func (b *backendPluginServer) Setup(args *SetupArgs, reply *SetupReply) error {
	// Dial for storage
	storageConn, err := b.broker.Dial(args.StorageID)
	if err != nil {
		*reply = SetupReply{
			Error: wrapError(err),
		}
		return nil
	}
	rawStorageClient := rpc.NewClient(storageConn)
	b.storageClient = rawStorageClient

	storage := &StorageClient{client: rawStorageClient}

	// Dial for sys view
	sysViewConn, err := b.broker.Dial(args.SysViewID)
	if err != nil {
		*reply = SetupReply{
			Error: wrapError(err),
		}
		return nil
	}
	rawSysViewClient := rpc.NewClient(sysViewConn)
	b.sysViewClient = rawSysViewClient

	sysView := &SystemViewClient{client: rawSysViewClient}

	config := &logical.BackendConfig{
		StorageView: storage,
		Logger:      b.logger,
		System:      sysView,
		Config:      args.Config,
		BackendUUID: args.BackendUUID,
	}

	// Call the underlying backend factory after shims have been created
	// to set b.backend
	backend, err := b.factory(context.Background(), config)
	if err != nil {
		*reply = SetupReply{
			Error: wrapError(err),
		}
		// Fix: return early like the other error paths above instead of
		// falling through and assigning a nil backend to b.backend.
		return nil
	}
	b.backend = backend

	return nil
}
// Type serves the backend type of the underlying backend over RPC.
func (b *backendPluginServer) Type(_ interface{}, reply *TypeReply) error {
	*reply = TypeReply{
		Type: b.backend.Type(),
	}
	return nil
}

View file

@ -1,173 +0,0 @@
package plugin

import (
	"context"
	"testing"
	"time"

	log "github.com/hashicorp/go-hclog"
	gplugin "github.com/hashicorp/go-plugin"
	"github.com/hashicorp/vault/helper/logging"
	"github.com/hashicorp/vault/logical"
	"github.com/hashicorp/vault/logical/plugin/mock"
)

// TestBackendPlugin_impl statically asserts the plugin wrapper and RPC client
// satisfy their respective interfaces.
func TestBackendPlugin_impl(t *testing.T) {
	var _ gplugin.Plugin = new(BackendPlugin)
	var _ logical.Backend = new(backendPluginClient)
}

// TestBackendPlugin_HandleRequest writes a value through the plugin and
// checks it is echoed back in the response.
func TestBackendPlugin_HandleRequest(t *testing.T) {
	b, cleanup := testBackend(t)
	defer cleanup()

	resp, err := b.HandleRequest(context.Background(), &logical.Request{
		Operation: logical.CreateOperation,
		Path:      "kv/foo",
		Data: map[string]interface{}{
			"value": "bar",
		},
	})
	if err != nil {
		t.Fatal(err)
	}
	if resp.Data["value"] != "bar" {
		t.Fatalf("bad: %#v", resp)
	}
}

// TestBackendPlugin_SpecialPaths checks the special paths round trip returns
// a non-nil result.
func TestBackendPlugin_SpecialPaths(t *testing.T) {
	b, cleanup := testBackend(t)
	defer cleanup()

	paths := b.SpecialPaths()
	if paths == nil {
		t.Fatal("SpecialPaths() returned nil")
	}
}

// TestBackendPlugin_System verifies the system view stored during Setup is
// returned and carries the expected default lease TTL.
func TestBackendPlugin_System(t *testing.T) {
	b, cleanup := testBackend(t)
	defer cleanup()

	sys := b.System()
	if sys == nil {
		t.Fatal("System() returned nil")
	}

	actual := sys.DefaultLeaseTTL()
	expected := 300 * time.Second

	if actual != expected {
		t.Fatalf("bad: %v, expected %v", actual, expected)
	}
}

// TestBackendPlugin_Logger verifies the logger stored during Setup is
// returned.
func TestBackendPlugin_Logger(t *testing.T) {
	b, cleanup := testBackend(t)
	defer cleanup()

	logger := b.Logger()
	if logger == nil {
		t.Fatal("Logger() returned nil")
	}
}
// TestBackendPlugin_HandleExistenceCheck exercises the existence-check RPC
// round trip against the mock backend.
func TestBackendPlugin_HandleExistenceCheck(t *testing.T) {
	b, cleanup := testBackend(t)
	defer cleanup()

	checkFound, exists, err := b.HandleExistenceCheck(context.Background(), &logical.Request{
		Operation: logical.CreateOperation,
		Path:      "kv/foo",
		Data:      map[string]interface{}{"value": "bar"},
	})
	if err != nil {
		t.Fatal(err)
	}
	if !checkFound {
		// Fix: the path in this message was missing its closing quote.
		t.Fatal("existence check not found for path 'kv/foo'")
	}
	if exists {
		t.Fatal("existence check should have returned 'false' for 'kv/foo'")
	}
}
// TestBackendPlugin_Cleanup checks Cleanup completes without panicking.
func TestBackendPlugin_Cleanup(t *testing.T) {
	b, cleanup := testBackend(t)
	defer cleanup()

	b.Cleanup(context.Background())
}

// TestBackendPlugin_InvalidateKey reads a value, invalidates its key, and
// verifies a subsequent read no longer returns it.
func TestBackendPlugin_InvalidateKey(t *testing.T) {
	b, cleanup := testBackend(t)
	defer cleanup()

	ctx := context.Background()

	resp, err := b.HandleRequest(ctx, &logical.Request{
		Operation: logical.ReadOperation,
		Path:      "internal",
	})
	if err != nil {
		t.Fatal(err)
	}
	if resp.Data["value"] == "" {
		t.Fatalf("bad: %#v, expected non-empty value", resp)
	}

	b.InvalidateKey(ctx, "internal")

	resp, err = b.HandleRequest(ctx, &logical.Request{
		Operation: logical.ReadOperation,
		Path:      "internal",
	})
	if err != nil {
		t.Fatal(err)
	}
	if resp.Data["value"] != "" {
		t.Fatalf("bad: expected empty response data, got %#v", resp)
	}
}

// TestBackendPlugin_Setup checks the backend can be set up and torn down.
func TestBackendPlugin_Setup(t *testing.T) {
	_, cleanup := testBackend(t)
	defer cleanup()
}

// testBackend starts the mock backend plugin over an in-process RPC
// connection, runs Setup on it, and returns the backend together with a
// cleanup func that closes the connection.
func testBackend(t *testing.T) (logical.Backend, func()) {
	// Create a mock provider
	pluginMap := map[string]gplugin.Plugin{
		"backend": &BackendPlugin{
			GRPCBackendPlugin: &GRPCBackendPlugin{
				Factory: mock.Factory,
			},
		},
	}
	client, _ := gplugin.TestPluginRPCConn(t, pluginMap, nil)
	cleanup := func() {
		client.Close()
	}

	// Request the backend
	raw, err := client.Dispense(BackendPluginName)
	if err != nil {
		t.Fatal(err)
	}
	b := raw.(logical.Backend)

	err = b.Setup(context.Background(), &logical.BackendConfig{
		Logger: logging.NewVaultLogger(log.Debug),
		System: &logical.StaticSystemView{
			DefaultLeaseTTLVal: 300 * time.Second,
			MaxLeaseTTLVal:     1800 * time.Second,
		},
		StorageView: &logical.InmemStorage{},
	})
	if err != nil {
		t.Fatal(err)
	}

	return b, cleanup
}

View file

@ -16,6 +16,7 @@ import (
)
var ErrPluginShutdown = errors.New("plugin is shut down")
var ErrClientInMetadataMode = errors.New("plugin client can not perform action while in metadata mode")
// Validate backendGRPCPluginClient satisfies the logical.Backend interface
var _ logical.Backend = &backendGRPCPluginClient{}

View file

@ -2,6 +2,7 @@ package plugin
import (
"context"
"errors"
log "github.com/hashicorp/go-hclog"
plugin "github.com/hashicorp/go-plugin"
@ -11,6 +12,8 @@ import (
"google.golang.org/grpc"
)
var ErrServerInMetadataMode = errors.New("plugin server can not perform action while in metadata mode")
type backendGRPCPluginServer struct {
broker *plugin.GRPCBroker
backend logical.Backend

View file

@ -14,8 +14,8 @@ import (
)
func TestGRPCBackendPlugin_impl(t *testing.T) {
var _ gplugin.Plugin = new(BackendPlugin)
var _ logical.Backend = new(backendPluginClient)
var _ gplugin.Plugin = new(GRPCBackendPlugin)
var _ logical.Backend = new(backendGRPCPluginClient)
}
func TestGRPCBackendPlugin_HandleRequest(t *testing.T) {
@ -140,15 +140,13 @@ func TestGRPCBackendPlugin_Setup(t *testing.T) {
func testGRPCBackend(t *testing.T) (logical.Backend, func()) {
// Create a mock provider
pluginMap := map[string]gplugin.Plugin{
"backend": &BackendPlugin{
GRPCBackendPlugin: &GRPCBackendPlugin{
Factory: mock.Factory,
Logger: log.New(&log.LoggerOptions{
Level: log.Debug,
Output: os.Stderr,
JSONFormat: true,
}),
},
"backend": &GRPCBackendPlugin{
Factory: mock.Factory,
Logger: log.New(&log.LoggerOptions{
Level: log.Debug,
Output: os.Stderr,
JSONFormat: true,
}),
},
}
client, _ := gplugin.TestPluginGRPCConn(t, pluginMap)

View file

@ -108,3 +108,23 @@ func (s *GRPCStorageServer) Delete(ctx context.Context, args *pb.StorageDeleteAr
Err: pb.ErrToString(err),
}, nil
}
// NOOPStorage denies access to the storage interface; it is installed while a
// backend plugin runs in metadata mode.
type NOOPStorage struct{}

// List reports that no keys exist under any prefix.
func (n *NOOPStorage) List(_ context.Context, _ string) ([]string, error) {
	return []string{}, nil
}

// Get reports that no entry exists for any key.
func (n *NOOPStorage) Get(_ context.Context, _ string) (*logical.StorageEntry, error) {
	return nil, nil
}

// Put silently discards the entry.
func (n *NOOPStorage) Put(_ context.Context, _ *logical.StorageEntry) error {
	return nil
}

// Delete silently succeeds without deleting anything.
func (n *NOOPStorage) Delete(_ context.Context, _ string) error {
	return nil
}

View file

@ -2,13 +2,9 @@ package plugin
import (
"context"
"crypto/ecdsa"
"crypto/rsa"
"encoding/gob"
"errors"
"fmt"
"sync"
"time"
"github.com/hashicorp/errwrap"
log "github.com/hashicorp/go-hclog"
@ -18,28 +14,6 @@ import (
"github.com/hashicorp/vault/logical"
)
// init registers basic structs with gob which will be used to transport complex
// types through the plugin server and client.
func init() {
// Common basic structs
gob.Register([]interface{}{})
gob.Register(map[string]interface{}{})
gob.Register(map[string]string{})
gob.Register(map[string]int{})
// Register these types since we have to serialize and de-serialize
// tls.ConnectionState over the wire as part of logical.Request.Connection.
gob.Register(rsa.PublicKey{})
gob.Register(ecdsa.PublicKey{})
gob.Register(time.Duration(0))
// Custom common error types for requests. If you add something here, you must
// also add it to the switch statement in `wrapError`!
gob.Register(&plugin.BasicError{})
gob.Register(logical.CodedError(0, ""))
gob.Register(&logical.StatusBadRequest{})
}
// BackendPluginClient is a wrapper around backendPluginClient
// that also contains its plugin.Client instance. It's primarily
// used to cleanly kill the client on Cleanup()
@ -98,11 +72,13 @@ func NewBackend(ctx context.Context, pluginName string, pluginType consts.Plugin
func NewPluginClient(ctx context.Context, sys pluginutil.RunnerUtil, pluginRunner *pluginutil.PluginRunner, logger log.Logger, isMetadataMode bool) (logical.Backend, error) {
// pluginMap is the map of plugins we can dispense.
pluginSet := map[int]plugin.PluginSet{
// Version 3 used to supports both protocols. We want to keep it around
// since it's possible old plugins built against this version will still
// work with gRPC. There is currently no difference between version 3
// and version 4.
3: plugin.PluginSet{
"backend": &BackendPlugin{
GRPCBackendPlugin: &GRPCBackendPlugin{
MetadataMode: isMetadataMode,
},
"backend": &GRPCBackendPlugin{
MetadataMode: isMetadataMode,
},
},
4: plugin.PluginSet{
@ -142,10 +118,6 @@ func NewPluginClient(ctx context.Context, sys pluginutil.RunnerUtil, pluginRunne
// We should have a logical backend type now. This feels like a normal interface
// implementation but is in fact over an RPC connection.
switch raw.(type) {
case *backendPluginClient:
logger.Warn("plugin is using deprecated netRPC transport, recompile plugin to upgrade to gRPC", "plugin", pluginRunner.Name)
backend = raw.(*backendPluginClient)
transport = "netRPC"
case *backendGRPCPluginClient:
backend = raw.(*backendGRPCPluginClient)
transport = "gRPC"

View file

@ -39,12 +39,14 @@ func Serve(opts *ServeOpts) error {
// pluginMap is the map of plugins we can dispense.
pluginSets := map[int]plugin.PluginSet{
// Version 3 used to supports both protocols. We want to keep it around
// since it's possible old plugins built against this version will still
// work with gRPC. There is currently no difference between version 3
// and version 4.
3: plugin.PluginSet{
"backend": &BackendPlugin{
GRPCBackendPlugin: &GRPCBackendPlugin{
Factory: opts.BackendFactoryFunc,
Logger: logger,
},
"backend": &GRPCBackendPlugin{
Factory: opts.BackendFactoryFunc,
Logger: logger,
},
},
4: plugin.PluginSet{
@ -74,13 +76,6 @@ func Serve(opts *ServeOpts) error {
},
}
// If we do not have gRPC support fallback to version 3
// Remove this block in 0.13
if !pluginutil.GRPCSupport() {
serveOpts.GRPCServer = nil
delete(pluginSets, 4)
}
plugin.Serve(serveOpts)
return nil

View file

@ -1,139 +0,0 @@
package plugin

import (
	"context"
	"net/rpc"

	"github.com/hashicorp/vault/logical"
)

// StorageClient is an implementation of logical.Storage that communicates
// over RPC.
type StorageClient struct {
	client *rpc.Client
}

// List returns the keys under prefix from the remote storage.
func (s *StorageClient) List(_ context.Context, prefix string) ([]string, error) {
	var reply StorageListReply
	err := s.client.Call("Plugin.List", prefix, &reply)
	if err != nil {
		return reply.Keys, err
	}
	if reply.Error != nil {
		return reply.Keys, reply.Error
	}
	return reply.Keys, nil
}

// Get fetches the entry for key from the remote storage.
func (s *StorageClient) Get(_ context.Context, key string) (*logical.StorageEntry, error) {
	var reply StorageGetReply
	err := s.client.Call("Plugin.Get", key, &reply)
	if err != nil {
		return nil, err
	}
	if reply.Error != nil {
		return nil, reply.Error
	}
	return reply.StorageEntry, nil
}

// Put writes the entry to the remote storage.
func (s *StorageClient) Put(_ context.Context, entry *logical.StorageEntry) error {
	var reply StoragePutReply
	err := s.client.Call("Plugin.Put", entry, &reply)
	if err != nil {
		return err
	}
	if reply.Error != nil {
		return reply.Error
	}
	return nil
}

// Delete removes the entry for key from the remote storage.
func (s *StorageClient) Delete(_ context.Context, key string) error {
	var reply StorageDeleteReply
	err := s.client.Call("Plugin.Delete", key, &reply)
	if err != nil {
		return err
	}
	if reply.Error != nil {
		return reply.Error
	}
	return nil
}
// StorageServer is a net/rpc compatible structure for serving
type StorageServer struct {
	impl logical.Storage
}

// List serves logical.Storage.List over RPC, wrapping any error for the wire.
func (s *StorageServer) List(prefix string, reply *StorageListReply) error {
	keys, err := s.impl.List(context.Background(), prefix)
	*reply = StorageListReply{
		Keys:  keys,
		Error: wrapError(err),
	}
	return nil
}

// Get serves logical.Storage.Get over RPC, wrapping any error for the wire.
func (s *StorageServer) Get(key string, reply *StorageGetReply) error {
	storageEntry, err := s.impl.Get(context.Background(), key)
	*reply = StorageGetReply{
		StorageEntry: storageEntry,
		Error:        wrapError(err),
	}
	return nil
}

// Put serves logical.Storage.Put over RPC, wrapping any error for the wire.
func (s *StorageServer) Put(entry *logical.StorageEntry, reply *StoragePutReply) error {
	err := s.impl.Put(context.Background(), entry)
	*reply = StoragePutReply{
		Error: wrapError(err),
	}
	return nil
}

// Delete serves logical.Storage.Delete over RPC, wrapping any error for the
// wire.
func (s *StorageServer) Delete(key string, reply *StorageDeleteReply) error {
	err := s.impl.Delete(context.Background(), key)
	*reply = StorageDeleteReply{
		Error: wrapError(err),
	}
	return nil
}

// StorageListReply is the reply for the List method.
type StorageListReply struct {
	Keys  []string
	Error error
}

// StorageGetReply is the reply for the Get method.
type StorageGetReply struct {
	StorageEntry *logical.StorageEntry
	Error        error
}

// StoragePutReply is the reply for the Put method.
type StoragePutReply struct {
	Error error
}

// StorageDeleteReply is the reply for the Delete method.
type StorageDeleteReply struct {
	Error error
}
// NOOPStorage denies all access to the storage interface; it stands in for
// real storage while a backend plugin runs in metadata mode.
type NOOPStorage struct{}

// List reports that no keys exist under any prefix.
func (n *NOOPStorage) List(_ context.Context, _ string) ([]string, error) {
	return []string{}, nil
}

// Get reports that no entry exists for any key.
func (n *NOOPStorage) Get(_ context.Context, _ string) (*logical.StorageEntry, error) {
	return nil, nil
}

// Put silently discards the entry.
func (n *NOOPStorage) Put(_ context.Context, _ *logical.StorageEntry) error {
	return nil
}

// Delete silently succeeds without deleting anything.
func (n *NOOPStorage) Delete(_ context.Context, _ string) error {
	return nil
}

View file

@ -11,22 +11,7 @@ import (
)
func TestStorage_impl(t *testing.T) {
var _ logical.Storage = new(StorageClient)
}
func TestStorage_RPC(t *testing.T) {
client, server := plugin.TestRPCConn(t)
defer client.Close()
storage := &logical.InmemStorage{}
server.RegisterName("Plugin", &StorageServer{
impl: storage,
})
testStorage := &StorageClient{client: client}
logical.TestStorage(t, testStorage)
var _ logical.Storage = new(GRPCStorageClient)
}
func TestStorage_GRPC(t *testing.T) {

View file

@ -1,351 +0,0 @@
package plugin

import (
	"context"
	"net/rpc"
	"time"

	"fmt"

	"github.com/hashicorp/vault/helper/consts"
	"github.com/hashicorp/vault/helper/license"
	"github.com/hashicorp/vault/helper/pluginutil"
	"github.com/hashicorp/vault/helper/wrapping"
	"github.com/hashicorp/vault/logical"
)

// SystemViewClient is a logical.SystemView implementation that proxies each
// query to a SystemViewServer over net/rpc.
type SystemViewClient struct {
	client *rpc.Client
}

// DefaultLeaseTTL returns the default lease TTL, or 0 if the RPC call fails.
func (s *SystemViewClient) DefaultLeaseTTL() time.Duration {
	var reply DefaultLeaseTTLReply
	err := s.client.Call("Plugin.DefaultLeaseTTL", new(interface{}), &reply)
	if err != nil {
		return 0
	}

	return reply.DefaultLeaseTTL
}

// MaxLeaseTTL returns the max lease TTL, or 0 if the RPC call fails.
func (s *SystemViewClient) MaxLeaseTTL() time.Duration {
	var reply MaxLeaseTTLReply
	err := s.client.Call("Plugin.MaxLeaseTTL", new(interface{}), &reply)
	if err != nil {
		return 0
	}

	return reply.MaxLeaseTTL
}

// SudoPrivilege reports whether the token has sudo privilege on path; an RPC
// failure is reported as false.
func (s *SystemViewClient) SudoPrivilege(ctx context.Context, path string, token string) bool {
	var reply SudoPrivilegeReply
	args := &SudoPrivilegeArgs{
		Path:  path,
		Token: token,
	}

	err := s.client.Call("Plugin.SudoPrivilege", args, &reply)
	if err != nil {
		return false
	}

	return reply.Sudo
}

// Tainted reports the tainted flag; an RPC failure is reported as false.
func (s *SystemViewClient) Tainted() bool {
	var reply TaintedReply

	err := s.client.Call("Plugin.Tainted", new(interface{}), &reply)
	if err != nil {
		return false
	}

	return reply.Tainted
}

// CachingDisabled reports the caching-disabled flag; an RPC failure is
// reported as false.
func (s *SystemViewClient) CachingDisabled() bool {
	var reply CachingDisabledReply

	err := s.client.Call("Plugin.CachingDisabled", new(interface{}), &reply)
	if err != nil {
		return false
	}

	return reply.CachingDisabled
}

// ReplicationState returns the replication state, or ReplicationUnknown if
// the RPC call fails.
func (s *SystemViewClient) ReplicationState() consts.ReplicationState {
	var reply ReplicationStateReply

	err := s.client.Call("Plugin.ReplicationState", new(interface{}), &reply)
	if err != nil {
		return consts.ReplicationUnknown
	}

	return reply.ReplicationState
}

// ResponseWrapData asks the server to response-wrap the given data. The jwt
// argument is deliberately ignored: JWTs are never requested from a plugin.
func (s *SystemViewClient) ResponseWrapData(ctx context.Context, data map[string]interface{}, ttl time.Duration, jwt bool) (*wrapping.ResponseWrapInfo, error) {
	var reply ResponseWrapDataReply
	// Do not allow JWTs to be returned
	args := &ResponseWrapDataArgs{
		Data: data,
		TTL:  ttl,
		JWT:  false,
	}

	err := s.client.Call("Plugin.ResponseWrapData", args, &reply)
	if err != nil {
		return nil, err
	}
	if reply.Error != nil {
		return nil, reply.Error
	}

	return reply.ResponseWrapInfo, nil
}

// LookupPlugin is not supported from within a plugin backend and always
// returns an error.
func (s *SystemViewClient) LookupPlugin(_ context.Context, _ string, _ consts.PluginType) (*pluginutil.PluginRunner, error) {
	return nil, fmt.Errorf("cannot call LookupPlugin from a plugin backend")
}

// HasFeature always reports false; license features are not exposed over
// this transport.
func (s *SystemViewClient) HasFeature(feature license.Features) bool {
	// Not implemented
	return false
}

// MlockEnabled reports the mlock-enabled flag; an RPC failure is reported as
// false.
func (s *SystemViewClient) MlockEnabled() bool {
	var reply MlockEnabledReply
	err := s.client.Call("Plugin.MlockEnabled", new(interface{}), &reply)
	if err != nil {
		return false
	}

	return reply.MlockEnabled
}

// LocalMount reports the local-mount flag; an RPC failure is reported as
// false.
func (s *SystemViewClient) LocalMount() bool {
	var reply LocalMountReply
	err := s.client.Call("Plugin.LocalMount", new(interface{}), &reply)
	if err != nil {
		return false
	}

	return reply.Local
}

// EntityInfo looks up entity information for the given entity ID over RPC.
func (s *SystemViewClient) EntityInfo(entityID string) (*logical.Entity, error) {
	var reply EntityInfoReply
	args := &EntityInfoArgs{
		EntityID: entityID,
	}

	err := s.client.Call("Plugin.EntityInfo", args, &reply)
	if err != nil {
		return nil, err
	}
	if reply.Error != nil {
		return nil, reply.Error
	}

	return reply.Entity, nil
}

// PluginEnv returns the plugin environment information from the server.
func (s *SystemViewClient) PluginEnv(_ context.Context) (*logical.PluginEnvironment, error) {
	var reply PluginEnvReply

	err := s.client.Call("Plugin.PluginEnv", new(interface{}), &reply)
	if err != nil {
		return nil, err
	}
	if reply.Error != nil {
		return nil, reply.Error
	}

	return reply.PluginEnvironment, nil
}
// SystemViewServer is the net/rpc server that exposes a concrete
// logical.SystemView to the SystemViewClient shim.
type SystemViewServer struct {
	impl logical.SystemView
}

// DefaultLeaseTTL serves the default lease TTL over RPC.
func (s *SystemViewServer) DefaultLeaseTTL(_ interface{}, reply *DefaultLeaseTTLReply) error {
	ttl := s.impl.DefaultLeaseTTL()
	*reply = DefaultLeaseTTLReply{
		DefaultLeaseTTL: ttl,
	}

	return nil
}

// MaxLeaseTTL serves the max lease TTL over RPC.
func (s *SystemViewServer) MaxLeaseTTL(_ interface{}, reply *MaxLeaseTTLReply) error {
	ttl := s.impl.MaxLeaseTTL()
	*reply = MaxLeaseTTLReply{
		MaxLeaseTTL: ttl,
	}

	return nil
}

// SudoPrivilege serves the sudo-privilege check over RPC.
func (s *SystemViewServer) SudoPrivilege(args *SudoPrivilegeArgs, reply *SudoPrivilegeReply) error {
	sudo := s.impl.SudoPrivilege(context.Background(), args.Path, args.Token)
	*reply = SudoPrivilegeReply{
		Sudo: sudo,
	}

	return nil
}

// Tainted serves the tainted flag over RPC.
func (s *SystemViewServer) Tainted(_ interface{}, reply *TaintedReply) error {
	tainted := s.impl.Tainted()
	*reply = TaintedReply{
		Tainted: tainted,
	}

	return nil
}

// CachingDisabled serves the caching-disabled flag over RPC.
func (s *SystemViewServer) CachingDisabled(_ interface{}, reply *CachingDisabledReply) error {
	cachingDisabled := s.impl.CachingDisabled()
	*reply = CachingDisabledReply{
		CachingDisabled: cachingDisabled,
	}

	return nil
}

// ReplicationState serves the replication state over RPC.
func (s *SystemViewServer) ReplicationState(_ interface{}, reply *ReplicationStateReply) error {
	replicationState := s.impl.ReplicationState()
	*reply = ReplicationStateReply{
		ReplicationState: replicationState,
	}

	return nil
}

// ResponseWrapData response-wraps the given data, always refusing to create
// JWT tokens on this path.
func (s *SystemViewServer) ResponseWrapData(args *ResponseWrapDataArgs, reply *ResponseWrapDataReply) error {
	// Do not allow JWTs to be returned
	info, err := s.impl.ResponseWrapData(context.Background(), args.Data, args.TTL, false)
	if err != nil {
		*reply = ResponseWrapDataReply{
			Error: wrapError(err),
		}
		return nil
	}
	*reply = ResponseWrapDataReply{
		ResponseWrapInfo: info,
	}

	return nil
}

// MlockEnabled serves the mlock-enabled flag over RPC.
func (s *SystemViewServer) MlockEnabled(_ interface{}, reply *MlockEnabledReply) error {
	enabled := s.impl.MlockEnabled()
	*reply = MlockEnabledReply{
		MlockEnabled: enabled,
	}

	return nil
}

// LocalMount serves the local-mount flag over RPC.
func (s *SystemViewServer) LocalMount(_ interface{}, reply *LocalMountReply) error {
	local := s.impl.LocalMount()
	*reply = LocalMountReply{
		Local: local,
	}

	return nil
}

// EntityInfo serves entity lookups over RPC, wrapping any error for the wire.
func (s *SystemViewServer) EntityInfo(args *EntityInfoArgs, reply *EntityInfoReply) error {
	entity, err := s.impl.EntityInfo(args.EntityID)
	if err != nil {
		*reply = EntityInfoReply{
			Error: wrapError(err),
		}
		return nil
	}
	*reply = EntityInfoReply{
		Entity: entity,
	}

	return nil
}

// PluginEnv serves the plugin environment over RPC, wrapping any error for
// the wire.
func (s *SystemViewServer) PluginEnv(_ interface{}, reply *PluginEnvReply) error {
	pluginEnv, err := s.impl.PluginEnv(context.Background())
	if err != nil {
		*reply = PluginEnvReply{
			Error: wrapError(err),
		}
		return nil
	}
	*reply = PluginEnvReply{
		PluginEnvironment: pluginEnv,
	}

	return nil
}

// DefaultLeaseTTLReply is the reply for the DefaultLeaseTTL method.
type DefaultLeaseTTLReply struct {
	DefaultLeaseTTL time.Duration
}

// MaxLeaseTTLReply is the reply for the MaxLeaseTTL method.
type MaxLeaseTTLReply struct {
	MaxLeaseTTL time.Duration
}

// SudoPrivilegeArgs is the args for the SudoPrivilege method.
type SudoPrivilegeArgs struct {
	Path  string
	Token string
}

// SudoPrivilegeReply is the reply for the SudoPrivilege method.
type SudoPrivilegeReply struct {
	Sudo bool
}

// TaintedReply is the reply for the Tainted method.
type TaintedReply struct {
	Tainted bool
}

// CachingDisabledReply is the reply for the CachingDisabled method.
type CachingDisabledReply struct {
	CachingDisabled bool
}

// ReplicationStateReply is the reply for the ReplicationState method.
type ReplicationStateReply struct {
	ReplicationState consts.ReplicationState
}

// ResponseWrapDataArgs is the args for the ResponseWrapData method.
type ResponseWrapDataArgs struct {
	Data map[string]interface{}
	TTL  time.Duration
	JWT  bool
}

// ResponseWrapDataReply is the reply for the ResponseWrapData method.
type ResponseWrapDataReply struct {
	ResponseWrapInfo *wrapping.ResponseWrapInfo
	Error            error
}

// MlockEnabledReply is the reply for the MlockEnabled method.
type MlockEnabledReply struct {
	MlockEnabled bool
}

// LocalMountReply is the reply for the LocalMount method.
type LocalMountReply struct {
	Local bool
}

// EntityInfoArgs is the args for the EntityInfo method.
type EntityInfoArgs struct {
	EntityID string
}

// EntityInfoReply is the reply for the EntityInfo method.
type EntityInfoReply struct {
	Entity *logical.Entity
	Error  error
}

// PluginEnvReply is the reply for the PluginEnv method.
type PluginEnvReply struct {
	PluginEnvironment *logical.PluginEnvironment
	Error             error
}

View file

@ -1,231 +0,0 @@
package plugin
import (
"context"
"testing"
"reflect"
plugin "github.com/hashicorp/go-plugin"
"github.com/hashicorp/vault/helper/consts"
"github.com/hashicorp/vault/logical"
)
func Test_impl(t *testing.T) {
var _ logical.SystemView = new(SystemViewClient)
}
func TestSystem_defaultLeaseTTL(t *testing.T) {
client, server := plugin.TestRPCConn(t)
defer client.Close()
sys := logical.TestSystemView()
server.RegisterName("Plugin", &SystemViewServer{
impl: sys,
})
testSystemView := &SystemViewClient{client: client}
expected := sys.DefaultLeaseTTL()
actual := testSystemView.DefaultLeaseTTL()
if !reflect.DeepEqual(expected, actual) {
t.Fatalf("expected: %v, got: %v", expected, actual)
}
}
// TestSystem_maxLeaseTTL verifies that MaxLeaseTTL round-trips unchanged
// through the plugin RPC client/server pair.
func TestSystem_maxLeaseTTL(t *testing.T) {
	rpcClient, rpcServer := plugin.TestRPCConn(t)
	defer rpcClient.Close()

	sysView := logical.TestSystemView()
	rpcServer.RegisterName("Plugin", &SystemViewServer{impl: sysView})
	remote := &SystemViewClient{client: rpcClient}

	want := sysView.MaxLeaseTTL()
	got := remote.MaxLeaseTTL()
	if !reflect.DeepEqual(want, got) {
		t.Fatalf("expected: %v, got: %v", want, got)
	}
}
// TestSystem_sudoPrivilege verifies that SudoPrivilege round-trips unchanged
// through the plugin RPC client/server pair.
func TestSystem_sudoPrivilege(t *testing.T) {
	rpcClient, rpcServer := plugin.TestRPCConn(t)
	defer rpcClient.Close()

	sysView := logical.TestSystemView()
	sysView.SudoPrivilegeVal = true
	rpcServer.RegisterName("Plugin", &SystemViewServer{impl: sysView})
	remote := &SystemViewClient{client: rpcClient}

	ctx := context.Background()
	want := sysView.SudoPrivilege(ctx, "foo", "bar")
	got := remote.SudoPrivilege(ctx, "foo", "bar")
	if !reflect.DeepEqual(want, got) {
		t.Fatalf("expected: %v, got: %v", want, got)
	}
}
// TestSystem_tainted verifies that Tainted round-trips unchanged through the
// plugin RPC client/server pair.
func TestSystem_tainted(t *testing.T) {
	rpcClient, rpcServer := plugin.TestRPCConn(t)
	defer rpcClient.Close()

	sysView := logical.TestSystemView()
	sysView.TaintedVal = true
	rpcServer.RegisterName("Plugin", &SystemViewServer{impl: sysView})
	remote := &SystemViewClient{client: rpcClient}

	want := sysView.Tainted()
	got := remote.Tainted()
	if !reflect.DeepEqual(want, got) {
		t.Fatalf("expected: %v, got: %v", want, got)
	}
}
// TestSystem_cachingDisabled verifies that CachingDisabled round-trips
// unchanged through the plugin RPC client/server pair.
func TestSystem_cachingDisabled(t *testing.T) {
	rpcClient, rpcServer := plugin.TestRPCConn(t)
	defer rpcClient.Close()

	sysView := logical.TestSystemView()
	sysView.CachingDisabledVal = true
	rpcServer.RegisterName("Plugin", &SystemViewServer{impl: sysView})
	remote := &SystemViewClient{client: rpcClient}

	want := sysView.CachingDisabled()
	got := remote.CachingDisabled()
	if !reflect.DeepEqual(want, got) {
		t.Fatalf("expected: %v, got: %v", want, got)
	}
}
// TestSystem_replicationState verifies that ReplicationState round-trips
// unchanged through the plugin RPC client/server pair.
func TestSystem_replicationState(t *testing.T) {
	rpcClient, rpcServer := plugin.TestRPCConn(t)
	defer rpcClient.Close()

	sysView := logical.TestSystemView()
	sysView.ReplicationStateVal = consts.ReplicationPerformancePrimary
	rpcServer.RegisterName("Plugin", &SystemViewServer{impl: sysView})
	remote := &SystemViewClient{client: rpcClient}

	want := sysView.ReplicationState()
	got := remote.ReplicationState()
	if !reflect.DeepEqual(want, got) {
		t.Fatalf("expected: %v, got: %v", want, got)
	}
}
// TestSystem_responseWrapData is intentionally skipped. Using t.Skip with a
// message (instead of the bare t.SkipNow) makes the skip visible and
// explained in test output.
// NOTE(review): the original gave no reason for the skip — confirm why this
// path is untested before filling the test in.
func TestSystem_responseWrapData(t *testing.T) {
	t.Skip("not implemented")
}
// TestSystem_lookupPlugin verifies that LookupPlugin returns an error when
// invoked from the plugin side of the RPC connection.
func TestSystem_lookupPlugin(t *testing.T) {
	client, server := plugin.TestRPCConn(t)
	defer client.Close()

	sys := logical.TestSystemView()
	server.RegisterName("Plugin", &SystemViewServer{
		impl: sys,
	})
	testSystemView := &SystemViewClient{client: client}

	if _, err := testSystemView.LookupPlugin(context.Background(), "foo", consts.PluginTypeDatabase); err == nil {
		// Fixed garbled failure message (was "LookPlugin(): expected error on
		// due to unsupported call from plugin").
		t.Fatal("LookupPlugin(): expected error due to unsupported call from plugin")
	}
}
// TestSystem_mlockEnabled verifies that MlockEnabled round-trips unchanged
// through the plugin RPC client/server pair.
func TestSystem_mlockEnabled(t *testing.T) {
	rpcClient, rpcServer := plugin.TestRPCConn(t)
	defer rpcClient.Close()

	sysView := logical.TestSystemView()
	sysView.EnableMlock = true
	rpcServer.RegisterName("Plugin", &SystemViewServer{impl: sysView})
	remote := &SystemViewClient{client: rpcClient}

	want := sysView.MlockEnabled()
	got := remote.MlockEnabled()
	if !reflect.DeepEqual(want, got) {
		t.Fatalf("expected: %v, got: %v", want, got)
	}
}
// TestSystem_entityInfo verifies that EntityInfo returns the server-side
// entity unchanged over the plugin RPC client/server pair.
func TestSystem_entityInfo(t *testing.T) {
	rpcClient, rpcServer := plugin.TestRPCConn(t)
	defer rpcClient.Close()

	sysView := logical.TestSystemView()
	sysView.EntityVal = &logical.Entity{
		ID:   "test",
		Name: "name",
	}
	rpcServer.RegisterName("Plugin", &SystemViewServer{impl: sysView})
	remote := &SystemViewClient{client: rpcClient}

	got, err := remote.EntityInfo("")
	if err != nil {
		t.Fatal(err)
	}
	if !reflect.DeepEqual(sysView.EntityVal, got) {
		t.Fatalf("expected: %v, got: %v", sysView.EntityVal, got)
	}
}
// TestSystem_pluginEnv verifies that PluginEnv round-trips unchanged through
// the plugin RPC client/server pair.
func TestSystem_pluginEnv(t *testing.T) {
	rpcClient, rpcServer := plugin.TestRPCConn(t)
	defer rpcClient.Close()

	sysView := logical.TestSystemView()
	sysView.PluginEnvironment = &logical.PluginEnvironment{
		VaultVersion: "0.10.42",
	}
	rpcServer.RegisterName("Plugin", &SystemViewServer{impl: sysView})
	remote := &SystemViewClient{client: rpcClient}

	want, err := sysView.PluginEnv(context.Background())
	if err != nil {
		t.Fatal(err)
	}
	got, err := remote.PluginEnv(context.Background())
	if err != nil {
		t.Fatal(err)
	}
	if !reflect.DeepEqual(want, got) {
		t.Fatalf("expected: %v, got: %v", want, got)
	}
}

View file

@ -35,9 +35,9 @@ export default DS.RESTAdapter.extend({
let headers = {};
if (token && !options.unauthenticated) {
headers['X-Vault-Token'] = token;
if (options.wrapTTL) {
headers['X-Vault-Wrap-TTL'] = options.wrapTTL;
}
}
if (options.wrapTTL) {
headers['X-Vault-Wrap-TTL'] = options.wrapTTL;
}
let namespace =
typeof options.namespace === 'undefined' ? this.get('namespaceService.path') : options.namespace;

View file

@ -1,2 +1,3 @@
// Ember Data adapter for this auth method's configuration; inherits all
// behavior from the shared auth-config base adapter.
import AuthConfig from './_base';
export default AuthConfig.extend();

View file

@ -1,2 +1,3 @@
// Ember Data adapter for this auth method's configuration; inherits all
// behavior from the shared auth-config base adapter.
import AuthConfig from './_base';
export default AuthConfig.extend();

View file

@ -1,2 +1,3 @@
// Ember Data adapter for this auth method's configuration; inherits all
// behavior from the shared auth-config base adapter.
import AuthConfig from './_base';
export default AuthConfig.extend();

View file

@ -1,2 +1,3 @@
// Ember Data adapter for this auth method's configuration; inherits all
// behavior from the shared auth-config base adapter.
import AuthConfig from './_base';
export default AuthConfig.extend();

View file

@ -1,2 +1,3 @@
// Ember Data adapter for this auth method's configuration; inherits all
// behavior from the shared auth-config base adapter.
import AuthConfig from './_base';
export default AuthConfig.extend();

View file

@ -1,2 +1,3 @@
// Ember Data adapter for this auth method's configuration; inherits all
// behavior from the shared auth-config base adapter.
import AuthConfig from './_base';
export default AuthConfig.extend();

View file

@ -1,2 +1,3 @@
// Ember Data adapter for this auth method's configuration; inherits all
// behavior from the shared auth-config base adapter.
import AuthConfig from './_base';
export default AuthConfig.extend();

View file

@ -56,4 +56,8 @@ export default ApplicationAdapter.extend({
// Ember Data hook: DELETE requests are routed to the adapter's url() built
// from the snapshot's id rather than the default type/id URL.
urlForDeleteRecord(id, modelName, snapshot) {
  return this.url(snapshot.id);
},
exchangeOIDC(path, state, code) {
return this.ajax(`/v1/auth/${path}/oidc/callback`, 'GET', { data: { state, code } });
},
});

View file

@ -109,7 +109,7 @@ export default ApplicationAdapter.extend({
},
authenticate({ backend, data }) {
const { token, password, username, path } = data;
const { role, jwt, token, password, username, path } = data;
const url = this.urlForAuth(backend, username, path);
const verb = backend === 'token' ? 'GET' : 'POST';
let options = {
@ -119,6 +119,8 @@ export default ApplicationAdapter.extend({
options.headers = {
'X-Vault-Token': token,
};
} else if (backend === 'jwt') {
options.data = { role, jwt };
} else {
options.data = token ? { token, password } : { password };
}
@ -139,6 +141,7 @@ export default ApplicationAdapter.extend({
const authBackend = type.toLowerCase();
const authURLs = {
github: 'login',
jwt: 'login',
userpass: `login/${encodeURIComponent(username)}`,
ldap: `login/${encodeURIComponent(username)}`,
okta: `login/${encodeURIComponent(username)}`,

View file

@ -7,4 +7,8 @@ export default Adapter.extend({
}
return `/v1/${role.backend}/sign/${role.name}`;
},
// Ember Data hook: records of this type live under the 'sign' path.
pathForType() {
  return 'sign';
},
});

View file

@ -13,7 +13,6 @@ export default Adapter.extend({
}
return url;
},
optionsForQuery(id) {
let data = {};
if (!id) {

View file

@ -0,0 +1,32 @@
import ApplicationAdapter from './application';
import { inject as service } from '@ember/service';
import { get } from '@ember/object';
// Adapter that fetches the OIDC authorization URL for a given auth mount
// path and role, passing the UI's callback route as the redirect_uri.
export default ApplicationAdapter.extend({
  router: service(),

  // `id` is a JSON-encoded [path, role] pair; an optional namespace may be
  // supplied via adapterOptions.
  findRecord(store, type, id, snapshot) {
    let [path, role] = JSON.parse(id);
    let namespace = get(snapshot, 'adapterOptions.namespace');

    // Build the router args once and only append queryParams when a
    // namespace is present — this removes the duplicated urlFor/origin
    // construction the two branches previously carried.
    let urlForArgs = ['vault.cluster.oidc-callback', { auth_path: path }];
    if (namespace) {
      urlForArgs.push({ queryParams: { namespace } });
    }
    let redirect_uri = `${window.location.origin}${this.router.urlFor(...urlForArgs)}`;

    return this.ajax(`/v1/auth/${path}/oidc/auth_url`, 'POST', {
      data: {
        role,
        redirect_uri,
      },
    });
  },
});

View file

@ -7,16 +7,20 @@ export default ApplicationAdapter.extend({
return path ? url + '/' + path : url;
},
internalURL(path) {
let url = `/${this.urlPrefix()}/internal/ui/mounts`;
if (path) {
url = `${url}/${path}`;
}
return url;
},
// Ember Data hook: records of this type live under the 'mounts' path.
pathForType() {
  return 'mounts';
},
query(store, type, query) {
let url = `/${this.urlPrefix()}/internal/ui/mounts`;
if (query.path) {
url = `${url}/${query.path}`;
}
return this.ajax(url, 'GET');
return this.ajax(this.internalURL(query.path), 'GET');
},
createRecord(store, type, snapshot) {

View file

@ -34,6 +34,10 @@ export default ApplicationAdapter.extend({
return url;
},
// Ember Data hook: records of this type live under the 'mounts' path.
pathForType() {
  return 'mounts';
},
optionsForQuery(id, action, wrapTTL) {
let data = {};
if (action === 'query') {

View file

@ -5,6 +5,7 @@ import { messageTypes } from 'vault/helpers/message-types';
export default Component.extend({
type: null,
message: null,
classNames: ['message-inline'],

Some files were not shown because too many files have changed in this diff Show more