VAULT-15547 Agent/proxy decoupling, take two (#20634)

* VAULT-15547 Additional tests, refactoring, for proxy split

* VAULT-15547 Additional tests, refactoring, for proxy split

* VAULT-15547 Import reorganization

* VAULT-15547 Some missed updates for PersistConfig

* VAULT-15547 address comments

* VAULT-15547 address comments
This commit is contained in:
Violet Hynes 2023-05-19 13:17:48 -04:00 committed by GitHub
parent a47c0c7073
commit 92dc054bb3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 1179 additions and 439 deletions

View File

@ -92,6 +92,7 @@ test_packages[6]+=" $base/command/agentproxyshared/auth/jwt"
test_packages[6]+=" $base/command/agentproxyshared/auth/kerberos" test_packages[6]+=" $base/command/agentproxyshared/auth/kerberos"
test_packages[6]+=" $base/command/agentproxyshared/auth/kubernetes" test_packages[6]+=" $base/command/agentproxyshared/auth/kubernetes"
test_packages[6]+=" $base/command/agentproxyshared/auth/token-file" test_packages[6]+=" $base/command/agentproxyshared/auth/token-file"
test_packages[6]+=" $base/command/agentproxyshared"
test_packages[6]+=" $base/command/agentproxyshared/cache" test_packages[6]+=" $base/command/agentproxyshared/cache"
test_packages[6]+=" $base/command/agentproxyshared/cache/cacheboltdb" test_packages[6]+=" $base/command/agentproxyshared/cache/cacheboltdb"
test_packages[6]+=" $base/command/agentproxyshared/cache/cachememdb" test_packages[6]+=" $base/command/agentproxyshared/cache/cachememdb"

View File

@ -9,11 +9,9 @@ import (
"flag" "flag"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"net" "net"
"net/http" "net/http"
"os" "os"
"path/filepath"
"sort" "sort"
"strings" "strings"
"sync" "sync"
@ -29,23 +27,9 @@ import (
"github.com/hashicorp/vault/api" "github.com/hashicorp/vault/api"
agentConfig "github.com/hashicorp/vault/command/agent/config" agentConfig "github.com/hashicorp/vault/command/agent/config"
"github.com/hashicorp/vault/command/agent/template" "github.com/hashicorp/vault/command/agent/template"
"github.com/hashicorp/vault/command/agentproxyshared"
"github.com/hashicorp/vault/command/agentproxyshared/auth" "github.com/hashicorp/vault/command/agentproxyshared/auth"
"github.com/hashicorp/vault/command/agentproxyshared/auth/alicloud"
"github.com/hashicorp/vault/command/agentproxyshared/auth/approle"
"github.com/hashicorp/vault/command/agentproxyshared/auth/aws"
"github.com/hashicorp/vault/command/agentproxyshared/auth/azure"
"github.com/hashicorp/vault/command/agentproxyshared/auth/cert"
"github.com/hashicorp/vault/command/agentproxyshared/auth/cf"
"github.com/hashicorp/vault/command/agentproxyshared/auth/gcp"
"github.com/hashicorp/vault/command/agentproxyshared/auth/jwt"
"github.com/hashicorp/vault/command/agentproxyshared/auth/kerberos"
"github.com/hashicorp/vault/command/agentproxyshared/auth/kubernetes"
"github.com/hashicorp/vault/command/agentproxyshared/auth/oci"
token_file "github.com/hashicorp/vault/command/agentproxyshared/auth/token-file"
cache "github.com/hashicorp/vault/command/agentproxyshared/cache" cache "github.com/hashicorp/vault/command/agentproxyshared/cache"
"github.com/hashicorp/vault/command/agentproxyshared/cache/cacheboltdb"
"github.com/hashicorp/vault/command/agentproxyshared/cache/cachememdb"
"github.com/hashicorp/vault/command/agentproxyshared/cache/keymanager"
"github.com/hashicorp/vault/command/agentproxyshared/sink" "github.com/hashicorp/vault/command/agentproxyshared/sink"
"github.com/hashicorp/vault/command/agentproxyshared/sink/file" "github.com/hashicorp/vault/command/agentproxyshared/sink/file"
"github.com/hashicorp/vault/command/agentproxyshared/sink/inmem" "github.com/hashicorp/vault/command/agentproxyshared/sink/inmem"
@ -277,6 +261,16 @@ func (c *AgentCommand) Run(args []string) int {
} }
} }
if config.IsDefaultListerDefined() {
// Notably, we cannot know for sure if they are using the API proxy functionality unless
// we log on each API proxy call, which would be too noisy.
// A customer could have a listener defined but only be using e.g. the cache-clear API,
// even though the API proxy is something they have available.
c.UI.Warn("==> Note: Vault Agent will be deprecating API proxy functionality in a future " +
"release, and this functionality has moved to a new subcommand, vault proxy. If you rely on this " +
"functionality, plan to move to Vault Proxy instead.")
}
// ctx and cancelFunc are passed to the AuthHandler, SinkServer, and // ctx and cancelFunc are passed to the AuthHandler, SinkServer, and
// TemplateServer that periodically listen for ctx.Done() to fire and shut // TemplateServer that periodically listen for ctx.Done() to fire and shut
// down accordingly. // down accordingly.
@ -352,39 +346,9 @@ func (c *AgentCommand) Run(args []string) int {
MountPath: config.AutoAuth.Method.MountPath, MountPath: config.AutoAuth.Method.MountPath,
Config: config.AutoAuth.Method.Config, Config: config.AutoAuth.Method.Config,
} }
switch config.AutoAuth.Method.Type { method, err = agentproxyshared.GetAutoAuthMethodFromConfig(config.AutoAuth.Method.Type, authConfig, config.Vault.Address)
case "alicloud":
method, err = alicloud.NewAliCloudAuthMethod(authConfig)
case "aws":
method, err = aws.NewAWSAuthMethod(authConfig)
case "azure":
method, err = azure.NewAzureAuthMethod(authConfig)
case "cert":
method, err = cert.NewCertAuthMethod(authConfig)
case "cf":
method, err = cf.NewCFAuthMethod(authConfig)
case "gcp":
method, err = gcp.NewGCPAuthMethod(authConfig)
case "jwt":
method, err = jwt.NewJWTAuthMethod(authConfig)
case "kerberos":
method, err = kerberos.NewKerberosAuthMethod(authConfig)
case "kubernetes":
method, err = kubernetes.NewKubernetesAuthMethod(authConfig)
case "approle":
method, err = approle.NewApproleAuthMethod(authConfig)
case "oci":
method, err = oci.NewOCIAuthMethod(authConfig, config.Vault.Address)
case "token_file":
method, err = token_file.NewTokenFileAuthMethod(authConfig)
case "pcf": // Deprecated.
method, err = cf.NewCFAuthMethod(authConfig)
default:
c.UI.Error(fmt.Sprintf("Unknown auth method %q", config.AutoAuth.Method.Type))
return 1
}
if err != nil { if err != nil {
c.UI.Error(fmt.Errorf("Error creating %s auth method: %w", config.AutoAuth.Method.Type, err).Error()) c.UI.Error(fmt.Sprintf("Error creating %s auth method: %v", config.AutoAuth.Method.Type, err))
return 1 return 1
} }
} }
@ -535,147 +499,14 @@ func (c *AgentCommand) Run(args []string) int {
// Configure persistent storage and add to LeaseCache // Configure persistent storage and add to LeaseCache
if config.Cache.Persist != nil { if config.Cache.Persist != nil {
if config.Cache.Persist.Path == "" { deferFunc, oldToken, err := agentproxyshared.AddPersistentStorageToLeaseCache(ctx, leaseCache, config.Cache.Persist, cacheLogger)
c.UI.Error("must specify persistent cache path")
return 1
}
// Set AAD based on key protection type
var aad string
switch config.Cache.Persist.Type {
case "kubernetes":
aad, err = getServiceAccountJWT(config.Cache.Persist.ServiceAccountTokenFile)
if err != nil {
c.UI.Error(fmt.Sprintf("failed to read service account token from %s: %s", config.Cache.Persist.ServiceAccountTokenFile, err))
return 1
}
default:
c.UI.Error(fmt.Sprintf("persistent key protection type %q not supported", config.Cache.Persist.Type))
return 1
}
// Check if bolt file exists already
dbFileExists, err := cacheboltdb.DBFileExists(config.Cache.Persist.Path)
if err != nil { if err != nil {
c.UI.Error(fmt.Sprintf("failed to check if bolt file exists at path %s: %s", config.Cache.Persist.Path, err)) c.UI.Error(fmt.Sprintf("Error creating persistent cache: %v", err))
return 1 return 1
} }
if dbFileExists { previousToken = oldToken
// Open the bolt file, but wait to setup Encryption if deferFunc != nil {
ps, err := cacheboltdb.NewBoltStorage(&cacheboltdb.BoltStorageConfig{ defer deferFunc()
Path: config.Cache.Persist.Path,
Logger: cacheLogger.Named("cacheboltdb"),
})
if err != nil {
c.UI.Error(fmt.Sprintf("Error opening persistent cache: %v", err))
return 1
}
// Get the token from bolt for retrieving the encryption key,
// then setup encryption so that restore is possible
token, err := ps.GetRetrievalToken()
if err != nil {
c.UI.Error(fmt.Sprintf("Error getting retrieval token from persistent cache: %v", err))
}
if err := ps.Close(); err != nil {
c.UI.Warn(fmt.Sprintf("Failed to close persistent cache file after getting retrieval token: %s", err))
}
km, err := keymanager.NewPassthroughKeyManager(ctx, token)
if err != nil {
c.UI.Error(fmt.Sprintf("failed to configure persistence encryption for cache: %s", err))
return 1
}
// Open the bolt file with the wrapper provided
ps, err = cacheboltdb.NewBoltStorage(&cacheboltdb.BoltStorageConfig{
Path: config.Cache.Persist.Path,
Logger: cacheLogger.Named("cacheboltdb"),
Wrapper: km.Wrapper(),
AAD: aad,
})
if err != nil {
c.UI.Error(fmt.Sprintf("Error opening persistent cache with wrapper: %v", err))
return 1
}
// Restore anything in the persistent cache to the memory cache
if err := leaseCache.Restore(ctx, ps); err != nil {
c.UI.Error(fmt.Sprintf("Error restoring in-memory cache from persisted file: %v", err))
if config.Cache.Persist.ExitOnErr {
return 1
}
}
cacheLogger.Info("loaded memcache from persistent storage")
// Check for previous auto-auth token
oldTokenBytes, err := ps.GetAutoAuthToken(ctx)
if err != nil {
c.UI.Error(fmt.Sprintf("Error in fetching previous auto-auth token: %s", err))
if config.Cache.Persist.ExitOnErr {
return 1
}
}
if len(oldTokenBytes) > 0 {
oldToken, err := cachememdb.Deserialize(oldTokenBytes)
if err != nil {
c.UI.Error(fmt.Sprintf("Error in deserializing previous auto-auth token cache entry: %s", err))
if config.Cache.Persist.ExitOnErr {
return 1
}
}
previousToken = oldToken.Token
}
// If keep_after_import true, set persistent storage layer in
// leaseCache, else remove db file
if config.Cache.Persist.KeepAfterImport {
defer ps.Close()
leaseCache.SetPersistentStorage(ps)
} else {
if err := ps.Close(); err != nil {
c.UI.Warn(fmt.Sprintf("failed to close persistent cache file: %s", err))
}
dbFile := filepath.Join(config.Cache.Persist.Path, cacheboltdb.DatabaseFileName)
if err := os.Remove(dbFile); err != nil {
c.UI.Error(fmt.Sprintf("failed to remove persistent storage file %s: %s", dbFile, err))
if config.Cache.Persist.ExitOnErr {
return 1
}
}
}
} else {
km, err := keymanager.NewPassthroughKeyManager(ctx, nil)
if err != nil {
c.UI.Error(fmt.Sprintf("failed to configure persistence encryption for cache: %s", err))
return 1
}
ps, err := cacheboltdb.NewBoltStorage(&cacheboltdb.BoltStorageConfig{
Path: config.Cache.Persist.Path,
Logger: cacheLogger.Named("cacheboltdb"),
Wrapper: km.Wrapper(),
AAD: aad,
})
if err != nil {
c.UI.Error(fmt.Sprintf("Error creating persistent cache: %v", err))
return 1
}
cacheLogger.Info("configured persistent storage", "path", config.Cache.Persist.Path)
// Stash the key material in bolt
token, err := km.RetrievalToken(ctx)
if err != nil {
c.UI.Error(fmt.Sprintf("Error getting persistent key: %s", err))
return 1
}
if err := ps.StoreRetrievalToken(token); err != nil {
c.UI.Error(fmt.Sprintf("Error setting key in persistent cache: %v", err))
return 1
}
defer ps.Close()
leaseCache.SetPersistentStorage(ps)
} }
} }
} }
@ -1166,19 +997,6 @@ func (c *AgentCommand) removePidFile(pidPath string) error {
return os.Remove(pidPath) return os.Remove(pidPath)
} }
// GetServiceAccountJWT reads the service account jwt from `tokenFile`. Default is
// the default service account file path in kubernetes.
func getServiceAccountJWT(tokenFile string) (string, error) {
if len(tokenFile) == 0 {
tokenFile = "/var/run/secrets/kubernetes.io/serviceaccount/token"
}
token, err := ioutil.ReadFile(tokenFile)
if err != nil {
return "", err
}
return strings.TrimSpace(string(token)), nil
}
func (c *AgentCommand) handleMetrics() http.Handler { func (c *AgentCommand) handleMetrics() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet { if r.Method != http.MethodGet {

View File

@ -170,7 +170,7 @@ func TestCache_UsingAutoAuthToken(t *testing.T) {
Client: client, Client: client,
Logger: cacheLogger.Named("apiproxy"), Logger: cacheLogger.Named("apiproxy"),
UserAgentStringFunction: useragent.ProxyStringWithProxiedUserAgent, UserAgentStringFunction: useragent.ProxyStringWithProxiedUserAgent,
UserAgentString: useragent.ProxyString(), UserAgentString: useragent.ProxyAPIProxyString(),
}) })
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)

View File

@ -19,6 +19,7 @@ import (
"github.com/hashicorp/go-secure-stdlib/parseutil" "github.com/hashicorp/go-secure-stdlib/parseutil"
"github.com/hashicorp/hcl" "github.com/hashicorp/hcl"
"github.com/hashicorp/hcl/hcl/ast" "github.com/hashicorp/hcl/hcl/ast"
"github.com/hashicorp/vault/command/agentproxyshared"
"github.com/hashicorp/vault/helper/namespace" "github.com/hashicorp/vault/helper/namespace"
"github.com/hashicorp/vault/internalshared/configutil" "github.com/hashicorp/vault/internalshared/configutil"
"github.com/mitchellh/mapstructure" "github.com/mitchellh/mapstructure"
@ -105,22 +106,13 @@ type APIProxy struct {
// Cache contains any configuration needed for Cache mode // Cache contains any configuration needed for Cache mode
type Cache struct { type Cache struct {
UseAutoAuthTokenRaw interface{} `hcl:"use_auto_auth_token"` UseAutoAuthTokenRaw interface{} `hcl:"use_auto_auth_token"`
UseAutoAuthToken bool `hcl:"-"` UseAutoAuthToken bool `hcl:"-"`
ForceAutoAuthToken bool `hcl:"-"` ForceAutoAuthToken bool `hcl:"-"`
EnforceConsistency string `hcl:"enforce_consistency"` EnforceConsistency string `hcl:"enforce_consistency"`
WhenInconsistent string `hcl:"when_inconsistent"` WhenInconsistent string `hcl:"when_inconsistent"`
Persist *Persist `hcl:"persist"` Persist *agentproxyshared.PersistConfig `hcl:"persist"`
InProcDialer transportDialer `hcl:"-"` InProcDialer transportDialer `hcl:"-"`
}
// Persist contains configuration needed for persistent caching
type Persist struct {
Type string
Path string `hcl:"path"`
KeepAfterImport bool `hcl:"keep_after_import"`
ExitOnErr bool `hcl:"exit_on_err"`
ServiceAccountTokenFile string `hcl:"service_account_token_file"`
} }
// AutoAuth is the configured authentication method and sinks // AutoAuth is the configured authentication method and sinks
@ -268,6 +260,17 @@ func (c *Config) Merge(c2 *Config) *Config {
return result return result
} }
// IsDefaultListerDefined returns true if a default listener has been defined
// in this config
func (c *Config) IsDefaultListerDefined() bool {
for _, l := range c.Listeners {
if l.Role != "metrics_only" {
return true
}
}
return false
}
// ValidateConfig validates an Agent configuration after it has been fully merged together, to // ValidateConfig validates an Agent configuration after it has been fully merged together, to
// ensure that required combinations of configs are there // ensure that required combinations of configs are there
func (c *Config) ValidateConfig() error { func (c *Config) ValidateConfig() error {
@ -737,7 +740,7 @@ func parsePersist(result *Config, list *ast.ObjectList) error {
item := persistList.Items[0] item := persistList.Items[0]
var p Persist var p agentproxyshared.PersistConfig
err := hcl.DecodeObject(&p, item.Val) err := hcl.DecodeObject(&p, item.Val)
if err != nil { if err != nil {
return err return err

View File

@ -10,6 +10,7 @@ import (
"github.com/go-test/deep" "github.com/go-test/deep"
ctconfig "github.com/hashicorp/consul-template/config" ctconfig "github.com/hashicorp/consul-template/config"
"github.com/hashicorp/vault/command/agentproxyshared"
"github.com/hashicorp/vault/internalshared/configutil" "github.com/hashicorp/vault/internalshared/configutil"
"github.com/hashicorp/vault/sdk/helper/pointerutil" "github.com/hashicorp/vault/sdk/helper/pointerutil"
) )
@ -80,7 +81,7 @@ func TestLoadConfigFile_AgentCache(t *testing.T) {
UseAutoAuthToken: true, UseAutoAuthToken: true,
UseAutoAuthTokenRaw: true, UseAutoAuthTokenRaw: true,
ForceAutoAuthToken: false, ForceAutoAuthToken: false,
Persist: &Persist{ Persist: &agentproxyshared.PersistConfig{
Type: "kubernetes", Type: "kubernetes",
Path: "/vault/agent-cache/", Path: "/vault/agent-cache/",
KeepAfterImport: true, KeepAfterImport: true,
@ -185,7 +186,7 @@ func TestLoadConfigDir_AgentCache(t *testing.T) {
UseAutoAuthToken: true, UseAutoAuthToken: true,
UseAutoAuthTokenRaw: true, UseAutoAuthTokenRaw: true,
ForceAutoAuthToken: false, ForceAutoAuthToken: false,
Persist: &Persist{ Persist: &agentproxyshared.PersistConfig{
Type: "kubernetes", Type: "kubernetes",
Path: "/vault/agent-cache/", Path: "/vault/agent-cache/",
KeepAfterImport: true, KeepAfterImport: true,
@ -385,7 +386,7 @@ func TestLoadConfigFile_AgentCache_NoListeners(t *testing.T) {
UseAutoAuthToken: true, UseAutoAuthToken: true,
UseAutoAuthTokenRaw: true, UseAutoAuthTokenRaw: true,
ForceAutoAuthToken: false, ForceAutoAuthToken: false,
Persist: &Persist{ Persist: &agentproxyshared.PersistConfig{
Type: "kubernetes", Type: "kubernetes",
Path: "/vault/agent-cache/", Path: "/vault/agent-cache/",
KeepAfterImport: true, KeepAfterImport: true,
@ -957,7 +958,7 @@ func TestLoadConfigFile_AgentCache_Persist(t *testing.T) {
expected := &Config{ expected := &Config{
APIProxy: &APIProxy{}, APIProxy: &APIProxy{},
Cache: &Cache{ Cache: &Cache{
Persist: &Persist{ Persist: &agentproxyshared.PersistConfig{
Type: "kubernetes", Type: "kubernetes",
Path: "/vault/agent-cache/", Path: "/vault/agent-cache/",
KeepAfterImport: false, KeepAfterImport: false,

View File

@ -16,16 +16,16 @@ import (
ctconfig "github.com/hashicorp/consul-template/config" ctconfig "github.com/hashicorp/consul-template/config"
"github.com/hashicorp/go-hclog" "github.com/hashicorp/go-hclog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc/test/bufconn"
"github.com/hashicorp/vault/command/agent/config" "github.com/hashicorp/vault/command/agent/config"
"github.com/hashicorp/vault/command/agent/internal/ctmanager" "github.com/hashicorp/vault/command/agent/internal/ctmanager"
"github.com/hashicorp/vault/command/agentproxyshared"
"github.com/hashicorp/vault/internalshared/configutil" "github.com/hashicorp/vault/internalshared/configutil"
"github.com/hashicorp/vault/internalshared/listenerutil" "github.com/hashicorp/vault/internalshared/listenerutil"
"github.com/hashicorp/vault/sdk/helper/logging" "github.com/hashicorp/vault/sdk/helper/logging"
"github.com/hashicorp/vault/sdk/helper/pointerutil" "github.com/hashicorp/vault/sdk/helper/pointerutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc/test/bufconn"
) )
func newRunnerConfig(s *ServerConfig, configs ctconfig.TemplateConfigs) (*ctconfig.Config, error) { func newRunnerConfig(s *ServerConfig, configs ctconfig.TemplateConfigs) (*ctconfig.Config, error) {
@ -88,7 +88,7 @@ func newAgentConfig(listeners []*configutil.Listener, enableCache, enablePersise
} }
if enablePersisentCache { if enablePersisentCache {
agentConfig.Cache.Persist = &config.Persist{Type: "kubernetes"} agentConfig.Cache.Persist = &agentproxyshared.PersistConfig{Type: "kubernetes"}
} }
return agentConfig return agentConfig

View File

@ -40,7 +40,7 @@ func TestAPIProxy(t *testing.T) {
Client: client, Client: client,
Logger: logging.NewVaultLogger(hclog.Trace), Logger: logging.NewVaultLogger(hclog.Trace),
UserAgentStringFunction: useragent.ProxyStringWithProxiedUserAgent, UserAgentStringFunction: useragent.ProxyStringWithProxiedUserAgent,
UserAgentString: useragent.ProxyString(), UserAgentString: useragent.ProxyAPIProxyString(),
}) })
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -78,7 +78,7 @@ func TestAPIProxyNoCache(t *testing.T) {
Client: client, Client: client,
Logger: logging.NewVaultLogger(hclog.Trace), Logger: logging.NewVaultLogger(hclog.Trace),
UserAgentStringFunction: useragent.ProxyStringWithProxiedUserAgent, UserAgentStringFunction: useragent.ProxyStringWithProxiedUserAgent,
UserAgentString: useragent.ProxyString(), UserAgentString: useragent.ProxyAPIProxyString(),
}) })
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -118,7 +118,7 @@ func TestAPIProxy_queryParams(t *testing.T) {
Client: client, Client: client,
Logger: logging.NewVaultLogger(hclog.Trace), Logger: logging.NewVaultLogger(hclog.Trace),
UserAgentStringFunction: useragent.ProxyStringWithProxiedUserAgent, UserAgentStringFunction: useragent.ProxyStringWithProxiedUserAgent,
UserAgentString: useragent.ProxyString(), UserAgentString: useragent.ProxyAPIProxyString(),
}) })
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -264,7 +264,7 @@ func setupClusterAndAgentCommon(ctx context.Context, t *testing.T, coreConfig *v
Client: clienToUse, Client: clienToUse,
Logger: apiProxyLogger, Logger: apiProxyLogger,
UserAgentStringFunction: useragent.ProxyStringWithProxiedUserAgent, UserAgentStringFunction: useragent.ProxyStringWithProxiedUserAgent,
UserAgentString: useragent.ProxyString(), UserAgentString: useragent.ProxyAPIProxyString(),
}) })
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)

View File

@ -174,6 +174,12 @@ func (c *LeaseCache) SetPersistentStorage(storageIn *cacheboltdb.BoltStorage) {
c.ps = storageIn c.ps = storageIn
} }
// PersistentStorage is a getter for the persistent storage field in
// LeaseCache
func (c *LeaseCache) PersistentStorage() *cacheboltdb.BoltStorage {
return c.ps
}
// checkCacheForRequest checks the cache for a particular request based on its // checkCacheForRequest checks the cache for a particular request based on its
// computed ID. It returns a non-nil *SendResponse if an entry is found. // computed ID. It returns a non-nil *SendResponse if an entry is found.
func (c *LeaseCache) checkCacheForRequest(id string) (*SendResponse, error) { func (c *LeaseCache) checkCacheForRequest(id string) (*SendResponse, error) {

View File

@ -43,7 +43,7 @@ func testNewLeaseCache(t *testing.T, responses []*SendResponse) *LeaseCache {
lc, err := NewLeaseCache(&LeaseCacheConfig{ lc, err := NewLeaseCache(&LeaseCacheConfig{
Client: client, Client: client,
BaseContext: context.Background(), BaseContext: context.Background(),
Proxier: newMockProxier(responses), Proxier: NewMockProxier(responses),
Logger: logging.NewVaultLogger(hclog.Trace).Named("cache.leasecache"), Logger: logging.NewVaultLogger(hclog.Trace).Named("cache.leasecache"),
}) })
if err != nil { if err != nil {
@ -82,7 +82,7 @@ func testNewLeaseCacheWithPersistence(t *testing.T, responses []*SendResponse, s
lc, err := NewLeaseCache(&LeaseCacheConfig{ lc, err := NewLeaseCache(&LeaseCacheConfig{
Client: client, Client: client,
BaseContext: context.Background(), BaseContext: context.Background(),
Proxier: newMockProxier(responses), Proxier: NewMockProxier(responses),
Logger: logging.NewVaultLogger(hclog.Trace).Named("cache.leasecache"), Logger: logging.NewVaultLogger(hclog.Trace).Named("cache.leasecache"),
Storage: storage, Storage: storage,
}) })

View File

@ -27,7 +27,7 @@ type mockProxier struct {
responseIndex int responseIndex int
} }
func newMockProxier(responses []*SendResponse) *mockProxier { func NewMockProxier(responses []*SendResponse) *mockProxier {
return &mockProxier{ return &mockProxier{
proxiedResponses: responses, proxiedResponses: responses,
} }

View File

@ -0,0 +1,237 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package agentproxyshared
import (
"context"
"errors"
"fmt"
"os"
"path/filepath"
"strings"
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/command/agentproxyshared/auth"
"github.com/hashicorp/vault/command/agentproxyshared/auth/alicloud"
"github.com/hashicorp/vault/command/agentproxyshared/auth/approle"
"github.com/hashicorp/vault/command/agentproxyshared/auth/aws"
"github.com/hashicorp/vault/command/agentproxyshared/auth/azure"
"github.com/hashicorp/vault/command/agentproxyshared/auth/cert"
"github.com/hashicorp/vault/command/agentproxyshared/auth/cf"
"github.com/hashicorp/vault/command/agentproxyshared/auth/gcp"
"github.com/hashicorp/vault/command/agentproxyshared/auth/jwt"
"github.com/hashicorp/vault/command/agentproxyshared/auth/kerberos"
"github.com/hashicorp/vault/command/agentproxyshared/auth/kubernetes"
"github.com/hashicorp/vault/command/agentproxyshared/auth/oci"
token_file "github.com/hashicorp/vault/command/agentproxyshared/auth/token-file"
"github.com/hashicorp/vault/command/agentproxyshared/cache"
"github.com/hashicorp/vault/command/agentproxyshared/cache/cacheboltdb"
"github.com/hashicorp/vault/command/agentproxyshared/cache/cachememdb"
"github.com/hashicorp/vault/command/agentproxyshared/cache/keymanager"
)
// GetAutoAuthMethodFromConfig Calls the appropriate NewAutoAuthMethod function, initializing
// the auto-auth method, based on the auto-auth method type. Returns an error if one happens or
// the method type is invalid.
func GetAutoAuthMethodFromConfig(autoAuthMethodType string, authConfig *auth.AuthConfig, vaultAddress string) (auth.AuthMethod, error) {
switch autoAuthMethodType {
case "alicloud":
return alicloud.NewAliCloudAuthMethod(authConfig)
case "aws":
return aws.NewAWSAuthMethod(authConfig)
case "azure":
return azure.NewAzureAuthMethod(authConfig)
case "cert":
return cert.NewCertAuthMethod(authConfig)
case "cf":
return cf.NewCFAuthMethod(authConfig)
case "gcp":
return gcp.NewGCPAuthMethod(authConfig)
case "jwt":
return jwt.NewJWTAuthMethod(authConfig)
case "kerberos":
return kerberos.NewKerberosAuthMethod(authConfig)
case "kubernetes":
return kubernetes.NewKubernetesAuthMethod(authConfig)
case "approle":
return approle.NewApproleAuthMethod(authConfig)
case "oci":
return oci.NewOCIAuthMethod(authConfig, vaultAddress)
case "token_file":
return token_file.NewTokenFileAuthMethod(authConfig)
case "pcf": // Deprecated.
return cf.NewCFAuthMethod(authConfig)
default:
return nil, errors.New(fmt.Sprintf("unknown auth method %q", autoAuthMethodType))
}
}
// PersistConfig contains configuration needed for persistent caching
type PersistConfig struct {
Type string
Path string `hcl:"path"`
KeepAfterImport bool `hcl:"keep_after_import"`
ExitOnErr bool `hcl:"exit_on_err"`
ServiceAccountTokenFile string `hcl:"service_account_token_file"`
}
// AddPersistentStorageToLeaseCache adds persistence to a lease cache, based on a given PersistConfig
// Returns a close function to be deferred and the old token, if found, or an error
func AddPersistentStorageToLeaseCache(ctx context.Context, leaseCache *cache.LeaseCache, persistConfig *PersistConfig, logger log.Logger) (func() error, string, error) {
if persistConfig == nil {
return nil, "", errors.New("persist config was nil")
}
if persistConfig.Path == "" {
return nil, "", errors.New("must specify persistent cache path")
}
// Set AAD based on key protection type
var aad string
var err error
switch persistConfig.Type {
case "kubernetes":
aad, err = getServiceAccountJWT(persistConfig.ServiceAccountTokenFile)
if err != nil {
tokenFileName := persistConfig.ServiceAccountTokenFile
if len(tokenFileName) == 0 {
tokenFileName = "/var/run/secrets/kubernetes.io/serviceaccount/token"
}
return nil, "", fmt.Errorf("failed to read service account token from %s: %w", tokenFileName, err)
}
default:
return nil, "", fmt.Errorf("persistent key protection type %q not supported", persistConfig.Type)
}
// Check if bolt file exists already
dbFileExists, err := cacheboltdb.DBFileExists(persistConfig.Path)
if err != nil {
return nil, "", fmt.Errorf("failed to check if bolt file exists at path %s: %w", persistConfig.Path, err)
}
if dbFileExists {
// Open the bolt file, but wait to setup Encryption
ps, err := cacheboltdb.NewBoltStorage(&cacheboltdb.BoltStorageConfig{
Path: persistConfig.Path,
Logger: logger.Named("cacheboltdb"),
})
if err != nil {
return nil, "", fmt.Errorf("error opening persistent cache %v", err)
}
// Get the token from bolt for retrieving the encryption key,
// then setup encryption so that restore is possible
token, err := ps.GetRetrievalToken()
if err != nil {
return nil, "", fmt.Errorf("error getting retrieval token from persistent cache: %w", err)
}
if err := ps.Close(); err != nil {
return nil, "", fmt.Errorf("failed to close persistent cache file after getting retrieval token: %w", err)
}
km, err := keymanager.NewPassthroughKeyManager(ctx, token)
if err != nil {
return nil, "", fmt.Errorf("failed to configure persistence encryption for cache: %w", err)
}
// Open the bolt file with the wrapper provided
ps, err = cacheboltdb.NewBoltStorage(&cacheboltdb.BoltStorageConfig{
Path: persistConfig.Path,
Logger: logger.Named("cacheboltdb"),
Wrapper: km.Wrapper(),
AAD: aad,
})
if err != nil {
return nil, "", fmt.Errorf("error opening persistent cache with wrapper: %w", err)
}
// Restore anything in the persistent cache to the memory cache
if err := leaseCache.Restore(ctx, ps); err != nil {
logger.Error(fmt.Sprintf("error restoring in-memory cache from persisted file: %v", err))
if persistConfig.ExitOnErr {
return nil, "", fmt.Errorf("exiting with error as exit_on_err is set to true")
}
}
logger.Info("loaded memcache from persistent storage")
// Check for previous auto-auth token
oldTokenBytes, err := ps.GetAutoAuthToken(ctx)
if err != nil {
logger.Error(fmt.Sprintf("error in fetching previous auto-auth token: %v", err))
if persistConfig.ExitOnErr {
return nil, "", fmt.Errorf("exiting with error as exit_on_err is set to true")
}
}
var previousToken string
if len(oldTokenBytes) > 0 {
oldToken, err := cachememdb.Deserialize(oldTokenBytes)
if err != nil {
logger.Error(fmt.Sprintf("error in deserializing previous auto-auth token cache entryn: %v", err))
if persistConfig.ExitOnErr {
return nil, "", fmt.Errorf("exiting with error as exit_on_err is set to true")
}
}
previousToken = oldToken.Token
}
// If keep_after_import true, set persistent storage layer in
// leaseCache, else remove db file
if persistConfig.KeepAfterImport {
leaseCache.SetPersistentStorage(ps)
return ps.Close, previousToken, nil
} else {
if err := ps.Close(); err != nil {
logger.Warn(fmt.Sprintf("failed to close persistent cache file: %s", err))
}
dbFile := filepath.Join(persistConfig.Path, cacheboltdb.DatabaseFileName)
if err := os.Remove(dbFile); err != nil {
logger.Error(fmt.Sprintf("failed to remove persistent storage file %s: %v", dbFile, err))
if persistConfig.ExitOnErr {
return nil, "", fmt.Errorf("exiting with error as exit_on_err is set to true")
}
}
return nil, previousToken, nil
}
} else {
km, err := keymanager.NewPassthroughKeyManager(ctx, nil)
if err != nil {
return nil, "", fmt.Errorf("failed to configure persistence encryption for cache: %w", err)
}
ps, err := cacheboltdb.NewBoltStorage(&cacheboltdb.BoltStorageConfig{
Path: persistConfig.Path,
Logger: logger.Named("cacheboltdb"),
Wrapper: km.Wrapper(),
AAD: aad,
})
if err != nil {
return nil, "", fmt.Errorf("error creating persistent cache: %w", err)
}
logger.Info("configured persistent storage", "path", persistConfig.Path)
// Stash the key material in bolt
token, err := km.RetrievalToken(ctx)
if err != nil {
return nil, "", fmt.Errorf("error getting persistence key: %w", err)
}
if err := ps.StoreRetrievalToken(token); err != nil {
return nil, "", fmt.Errorf("error setting key in persistent cache: %w", err)
}
leaseCache.SetPersistentStorage(ps)
return ps.Close, "", nil
}
}
// getServiceAccountJWT attempts to read the service account JWT from the specified token file path.
// Defaults to using the Kubernetes default service account file path if token file path is empty.
func getServiceAccountJWT(tokenFile string) (string, error) {
if len(tokenFile) == 0 {
tokenFile = "/var/run/secrets/kubernetes.io/serviceaccount/token"
}
token, err := os.ReadFile(tokenFile)
if err != nil {
return "", err
}
return strings.TrimSpace(string(token)), nil
}

View File

@ -0,0 +1,92 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package agentproxyshared
import (
"context"
"os"
"testing"
hclog "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/command/agentproxyshared/cache"
"github.com/hashicorp/vault/sdk/helper/logging"
)
func testNewLeaseCache(t *testing.T, responses []*cache.SendResponse) *cache.LeaseCache {
t.Helper()
client, err := api.NewClient(api.DefaultConfig())
if err != nil {
t.Fatal(err)
}
lc, err := cache.NewLeaseCache(&cache.LeaseCacheConfig{
Client: client,
BaseContext: context.Background(),
Proxier: cache.NewMockProxier(responses),
Logger: logging.NewVaultLogger(hclog.Trace).Named("cache.leasecache"),
})
if err != nil {
t.Fatal(err)
}
return lc
}
func populateTempFile(t *testing.T, name, contents string) *os.File {
t.Helper()
file, err := os.CreateTemp(t.TempDir(), name)
if err != nil {
t.Fatal(err)
}
_, err = file.WriteString(contents)
if err != nil {
t.Fatal(err)
}
err = file.Close()
if err != nil {
t.Fatal(err)
}
return file
}
// Test_AddPersistentStorageToLeaseCache Tests that AddPersistentStorageToLeaseCache() correctly
// adds persistent storage to a lease cache
func Test_AddPersistentStorageToLeaseCache(t *testing.T) {
tempDir := t.TempDir()
serviceAccountTokenFile := populateTempFile(t, "proxy-config.hcl", "token")
persistConfig := &PersistConfig{
Type: "kubernetes",
Path: tempDir,
KeepAfterImport: false,
ExitOnErr: false,
ServiceAccountTokenFile: serviceAccountTokenFile.Name(),
}
leaseCache := testNewLeaseCache(t, nil)
if leaseCache.PersistentStorage() != nil {
t.Fatal("persistent storage was available before ours was added")
}
deferFunc, token, err := AddPersistentStorageToLeaseCache(context.Background(), leaseCache, persistConfig, logging.NewVaultLogger(hclog.Info))
if err != nil {
t.Fatal(err)
}
if leaseCache.PersistentStorage() == nil {
t.Fatal("persistent storage was not added")
}
if token != "" {
t.Fatal("expected token to be empty")
}
if deferFunc == nil {
t.Fatal("expected deferFunc to not be nil")
}
}

View File

@ -12,7 +12,6 @@ import (
"net" "net"
"net/http" "net/http"
"os" "os"
"path/filepath"
"sort" "sort"
"strings" "strings"
"sync" "sync"
@ -26,23 +25,9 @@ import (
"github.com/hashicorp/go-secure-stdlib/parseutil" "github.com/hashicorp/go-secure-stdlib/parseutil"
"github.com/hashicorp/go-secure-stdlib/reloadutil" "github.com/hashicorp/go-secure-stdlib/reloadutil"
"github.com/hashicorp/vault/api" "github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/command/agentproxyshared"
"github.com/hashicorp/vault/command/agentproxyshared/auth" "github.com/hashicorp/vault/command/agentproxyshared/auth"
"github.com/hashicorp/vault/command/agentproxyshared/auth/alicloud"
"github.com/hashicorp/vault/command/agentproxyshared/auth/approle"
"github.com/hashicorp/vault/command/agentproxyshared/auth/aws"
"github.com/hashicorp/vault/command/agentproxyshared/auth/azure"
"github.com/hashicorp/vault/command/agentproxyshared/auth/cert"
"github.com/hashicorp/vault/command/agentproxyshared/auth/cf"
"github.com/hashicorp/vault/command/agentproxyshared/auth/gcp"
"github.com/hashicorp/vault/command/agentproxyshared/auth/jwt"
"github.com/hashicorp/vault/command/agentproxyshared/auth/kerberos"
"github.com/hashicorp/vault/command/agentproxyshared/auth/kubernetes"
"github.com/hashicorp/vault/command/agentproxyshared/auth/oci"
token_file "github.com/hashicorp/vault/command/agentproxyshared/auth/token-file"
cache "github.com/hashicorp/vault/command/agentproxyshared/cache" cache "github.com/hashicorp/vault/command/agentproxyshared/cache"
"github.com/hashicorp/vault/command/agentproxyshared/cache/cacheboltdb"
"github.com/hashicorp/vault/command/agentproxyshared/cache/cachememdb"
"github.com/hashicorp/vault/command/agentproxyshared/cache/keymanager"
"github.com/hashicorp/vault/command/agentproxyshared/sink" "github.com/hashicorp/vault/command/agentproxyshared/sink"
"github.com/hashicorp/vault/command/agentproxyshared/sink/file" "github.com/hashicorp/vault/command/agentproxyshared/sink/file"
"github.com/hashicorp/vault/command/agentproxyshared/sink/inmem" "github.com/hashicorp/vault/command/agentproxyshared/sink/inmem"
@ -345,39 +330,9 @@ func (c *ProxyCommand) Run(args []string) int {
MountPath: config.AutoAuth.Method.MountPath, MountPath: config.AutoAuth.Method.MountPath,
Config: config.AutoAuth.Method.Config, Config: config.AutoAuth.Method.Config,
} }
switch config.AutoAuth.Method.Type { method, err = agentproxyshared.GetAutoAuthMethodFromConfig(config.AutoAuth.Method.Type, authConfig, config.Vault.Address)
case "alicloud":
method, err = alicloud.NewAliCloudAuthMethod(authConfig)
case "aws":
method, err = aws.NewAWSAuthMethod(authConfig)
case "azure":
method, err = azure.NewAzureAuthMethod(authConfig)
case "cert":
method, err = cert.NewCertAuthMethod(authConfig)
case "cf":
method, err = cf.NewCFAuthMethod(authConfig)
case "gcp":
method, err = gcp.NewGCPAuthMethod(authConfig)
case "jwt":
method, err = jwt.NewJWTAuthMethod(authConfig)
case "kerberos":
method, err = kerberos.NewKerberosAuthMethod(authConfig)
case "kubernetes":
method, err = kubernetes.NewKubernetesAuthMethod(authConfig)
case "approle":
method, err = approle.NewApproleAuthMethod(authConfig)
case "oci":
method, err = oci.NewOCIAuthMethod(authConfig, config.Vault.Address)
case "token_file":
method, err = token_file.NewTokenFileAuthMethod(authConfig)
case "pcf": // Deprecated.
method, err = cf.NewCFAuthMethod(authConfig)
default:
c.UI.Error(fmt.Sprintf("Unknown auth method %q", config.AutoAuth.Method.Type))
return 1
}
if err != nil { if err != nil {
c.UI.Error(fmt.Errorf("Error creating %s auth method: %w", config.AutoAuth.Method.Type, err).Error()) c.UI.Error(fmt.Sprintf("Error creating %s auth method: %v", config.AutoAuth.Method.Type, err))
return 1 return 1
} }
} }
@ -465,7 +420,7 @@ func (c *ProxyCommand) Run(args []string) int {
EnforceConsistency: enforceConsistency, EnforceConsistency: enforceConsistency,
WhenInconsistentAction: whenInconsistent, WhenInconsistentAction: whenInconsistent,
UserAgentStringFunction: useragent.ProxyStringWithProxiedUserAgent, UserAgentStringFunction: useragent.ProxyStringWithProxiedUserAgent,
UserAgentString: useragent.ProxyString(), UserAgentString: useragent.ProxyAPIProxyString(),
}) })
if err != nil { if err != nil {
c.UI.Error(fmt.Sprintf("Error creating API proxy: %v", err)) c.UI.Error(fmt.Sprintf("Error creating API proxy: %v", err))
@ -497,147 +452,14 @@ func (c *ProxyCommand) Run(args []string) int {
// Configure persistent storage and add to LeaseCache // Configure persistent storage and add to LeaseCache
if config.Cache.Persist != nil { if config.Cache.Persist != nil {
if config.Cache.Persist.Path == "" { deferFunc, oldToken, err := agentproxyshared.AddPersistentStorageToLeaseCache(ctx, leaseCache, config.Cache.Persist, cacheLogger)
c.UI.Error("must specify persistent cache path")
return 1
}
// Set AAD based on key protection type
var aad string
switch config.Cache.Persist.Type {
case "kubernetes":
aad, err = getServiceAccountJWT(config.Cache.Persist.ServiceAccountTokenFile)
if err != nil {
c.UI.Error(fmt.Sprintf("failed to read service account token from %s: %s", config.Cache.Persist.ServiceAccountTokenFile, err))
return 1
}
default:
c.UI.Error(fmt.Sprintf("persistent key protection type %q not supported", config.Cache.Persist.Type))
return 1
}
// Check if bolt file exists already
dbFileExists, err := cacheboltdb.DBFileExists(config.Cache.Persist.Path)
if err != nil { if err != nil {
c.UI.Error(fmt.Sprintf("failed to check if bolt file exists at path %s: %s", config.Cache.Persist.Path, err)) c.UI.Error(fmt.Sprintf("Error creating persistent cache: %v", err))
return 1 return 1
} }
if dbFileExists { previousToken = oldToken
// Open the bolt file, but wait to setup Encryption if deferFunc != nil {
ps, err := cacheboltdb.NewBoltStorage(&cacheboltdb.BoltStorageConfig{ defer deferFunc()
Path: config.Cache.Persist.Path,
Logger: cacheLogger.Named("cacheboltdb"),
})
if err != nil {
c.UI.Error(fmt.Sprintf("Error opening persistent cache: %v", err))
return 1
}
// Get the token from bolt for retrieving the encryption key,
// then setup encryption so that restore is possible
token, err := ps.GetRetrievalToken()
if err != nil {
c.UI.Error(fmt.Sprintf("Error getting retrieval token from persistent cache: %v", err))
}
if err := ps.Close(); err != nil {
c.UI.Warn(fmt.Sprintf("Failed to close persistent cache file after getting retrieval token: %s", err))
}
km, err := keymanager.NewPassthroughKeyManager(ctx, token)
if err != nil {
c.UI.Error(fmt.Sprintf("failed to configure persistence encryption for cache: %s", err))
return 1
}
// Open the bolt file with the wrapper provided
ps, err = cacheboltdb.NewBoltStorage(&cacheboltdb.BoltStorageConfig{
Path: config.Cache.Persist.Path,
Logger: cacheLogger.Named("cacheboltdb"),
Wrapper: km.Wrapper(),
AAD: aad,
})
if err != nil {
c.UI.Error(fmt.Sprintf("Error opening persistent cache with wrapper: %v", err))
return 1
}
// Restore anything in the persistent cache to the memory cache
if err := leaseCache.Restore(ctx, ps); err != nil {
c.UI.Error(fmt.Sprintf("Error restoring in-memory cache from persisted file: %v", err))
if config.Cache.Persist.ExitOnErr {
return 1
}
}
cacheLogger.Info("loaded memcache from persistent storage")
// Check for previous auto-auth token
oldTokenBytes, err := ps.GetAutoAuthToken(ctx)
if err != nil {
c.UI.Error(fmt.Sprintf("Error in fetching previous auto-auth token: %s", err))
if config.Cache.Persist.ExitOnErr {
return 1
}
}
if len(oldTokenBytes) > 0 {
oldToken, err := cachememdb.Deserialize(oldTokenBytes)
if err != nil {
c.UI.Error(fmt.Sprintf("Error in deserializing previous auto-auth token cache entry: %s", err))
if config.Cache.Persist.ExitOnErr {
return 1
}
}
previousToken = oldToken.Token
}
// If keep_after_import true, set persistent storage layer in
// leaseCache, else remove db file
if config.Cache.Persist.KeepAfterImport {
defer ps.Close()
leaseCache.SetPersistentStorage(ps)
} else {
if err := ps.Close(); err != nil {
c.UI.Warn(fmt.Sprintf("failed to close persistent cache file: %s", err))
}
dbFile := filepath.Join(config.Cache.Persist.Path, cacheboltdb.DatabaseFileName)
if err := os.Remove(dbFile); err != nil {
c.UI.Error(fmt.Sprintf("failed to remove persistent storage file %s: %s", dbFile, err))
if config.Cache.Persist.ExitOnErr {
return 1
}
}
}
} else {
km, err := keymanager.NewPassthroughKeyManager(ctx, nil)
if err != nil {
c.UI.Error(fmt.Sprintf("failed to configure persistence encryption for cache: %s", err))
return 1
}
ps, err := cacheboltdb.NewBoltStorage(&cacheboltdb.BoltStorageConfig{
Path: config.Cache.Persist.Path,
Logger: cacheLogger.Named("cacheboltdb"),
Wrapper: km.Wrapper(),
AAD: aad,
})
if err != nil {
c.UI.Error(fmt.Sprintf("Error creating persistent cache: %v", err))
return 1
}
cacheLogger.Info("configured persistent storage", "path", config.Cache.Persist.Path)
// Stash the key material in bolt
token, err := km.RetrievalToken(ctx)
if err != nil {
c.UI.Error(fmt.Sprintf("Error getting persistent key: %s", err))
return 1
}
if err := ps.StoreRetrievalToken(token); err != nil {
c.UI.Error(fmt.Sprintf("Error setting key in persistent cache: %v", err))
return 1
}
defer ps.Close()
leaseCache.SetPersistentStorage(ps)
} }
} }
} }

View File

@ -19,6 +19,7 @@ import (
"github.com/hashicorp/go-secure-stdlib/parseutil" "github.com/hashicorp/go-secure-stdlib/parseutil"
"github.com/hashicorp/hcl" "github.com/hashicorp/hcl"
"github.com/hashicorp/hcl/hcl/ast" "github.com/hashicorp/hcl/hcl/ast"
"github.com/hashicorp/vault/command/agentproxyshared"
"github.com/hashicorp/vault/helper/namespace" "github.com/hashicorp/vault/helper/namespace"
"github.com/hashicorp/vault/internalshared/configutil" "github.com/hashicorp/vault/internalshared/configutil"
) )
@ -100,17 +101,8 @@ type APIProxy struct {
// Cache contains any configuration needed for Cache mode // Cache contains any configuration needed for Cache mode
type Cache struct { type Cache struct {
Persist *Persist `hcl:"persist"` Persist *agentproxyshared.PersistConfig `hcl:"persist"`
InProcDialer transportDialer `hcl:"-"` InProcDialer transportDialer `hcl:"-"`
}
// Persist contains configuration needed for persistent caching
type Persist struct {
Type string
Path string `hcl:"path"`
KeepAfterImport bool `hcl:"keep_after_import"`
ExitOnErr bool `hcl:"exit_on_err"`
ServiceAccountTokenFile string `hcl:"service_account_token_file"`
} }
// AutoAuth is the configured authentication method and sinks // AutoAuth is the configured authentication method and sinks
@ -640,7 +632,7 @@ func parsePersist(result *Config, list *ast.ObjectList) error {
item := persistList.Items[0] item := persistList.Items[0]
var p Persist var p agentproxyshared.PersistConfig
err := hcl.DecodeObject(&p, item.Val) err := hcl.DecodeObject(&p, item.Val)
if err != nil { if err != nil {
return err return err

View File

@ -7,6 +7,7 @@ import (
"testing" "testing"
"github.com/go-test/deep" "github.com/go-test/deep"
"github.com/hashicorp/vault/command/agentproxyshared"
"github.com/hashicorp/vault/internalshared/configutil" "github.com/hashicorp/vault/internalshared/configutil"
) )
@ -78,7 +79,7 @@ func TestLoadConfigFile_ProxyCache(t *testing.T) {
ForceAutoAuthToken: false, ForceAutoAuthToken: false,
}, },
Cache: &Cache{ Cache: &Cache{
Persist: &Persist{ Persist: &agentproxyshared.PersistConfig{
Type: "kubernetes", Type: "kubernetes",
Path: "/vault/agent-cache/", Path: "/vault/agent-cache/",
KeepAfterImport: true, KeepAfterImport: true,

View File

@ -0,0 +1,27 @@
-----BEGIN RSA PRIVATE KEY-----
MIIEowIBAAKCAQEAwF7sRAyUiLcd6es6VeaTRUBOusFFGkmKJ5lU351waCJqXFju
Z6i/SQYNAAnnRgotXSTE1fIPjE2kZNH1hvqE5IpTGgAwy50xpjJrrBBI6e9lyKqj
7T8gLVNBvtC0cpQi+pGrszEI0ckDQCSZHqi/PAzcpmLUgh2KMrgagT+YlN35KHtl
/bQ/Fsn+kqykVqNw69n/CDKNKdDHn1qPwiX9q/fTMj3EG6g+3ntKrUOh8V/gHKPz
q8QGP/wIud2K+tTSorVXr/4zx7xgzlbJkCakzcQQiP6K+paPnDRlE8fK+1gRRyR7
XCzyp0irUl8G1NjYAR/tVWxiUhlk/jZutb8PpwIDAQABAoIBAEOzJELuindyujxQ
ZD9G3h1I/GwNCFyv9Mbq10u7BIwhUH0fbwdcA7WXQ4v38ERd4IkfH4aLoZ0m1ewF
V/sgvxQO+h/0YTfHImny5KGxOXfaoF92bipYROKuojydBmQsbgLwsRRm9UufCl3Q
g3KewG5JuH112oPQEYq379v8nZ4FxC3Ano1OFBTm9UhHIAX1Dn22kcHOIIw8jCsQ
zp7TZOW+nwtkS41cBwhvV4VIeL6yse2UgbOfRVRwI7B0OtswS5VgW3wysO2mTDKt
V/WCmeht1il/6ZogEHgi/mvDCKpj20wQ1EzGnPdFLdiFJFylf0oufQD/7N/uezbC
is0qJEECgYEA3AE7SeLpe3SZApj2RmE2lcD9/Saj1Y30PznxB7M7hK0sZ1yXEbtS
Qf894iDDD/Cn3ufA4xk/K52CXgAcqvH/h2geG4pWLYsT1mdWhGftprtOMCIvJvzU
8uWJzKdOGVMG7R59wNgEpPDZDpBISjexwQsFo3aw1L/H1/Sa8cdY3a0CgYEA39hB
1oLmGRyE32Q4GF/srG4FqKL1EsbISGDUEYTnaYg2XiM43gu3tC/ikfclk27Jwc2L
m7cA5FxxaEyfoOgfAizfU/uWTAbx9GoXgWsO0hWSN9+YNq61gc5WKoHyrJ/rfrti
y5d7k0OCeBxckLqGDuJqICQ0myiz0El6FU8h5SMCgYEAuhigmiNC9JbwRu40g9v/
XDVfox9oPmBRVpogdC78DYKeqN/9OZaGQiUxp3GnDni2xyqqUm8srCwT9oeJuF/z
kgpUTV96/hNCuH25BU8UC5Es1jJUSFpdlwjqwx5SRcGhfjnojZMseojwUg1h2MW7
qls0bc0cTxnaZaYW2qWRWhECgYBrT0cwyQv6GdvxJCBoPwQ9HXmFAKowWC+H0zOX
Onmd8/jsZEJM4J0uuo4Jn8vZxBDg4eL9wVuiHlcXwzP7dYv4BP8DSechh2rS21Ft
b59pQ4IXWw+jl1nYYsyYEDgAXaIN3VNder95N7ICVsZhc6n01MI/qlu1zmt1fOQT
9x2utQKBgHI9SbsfWfbGiu6oLS3+9V1t4dORhj8D8b7z3trvECrD6tPhxoZqtfrH
4apKr3OKRSXk3K+1K6pkMHJHunspucnA1ChXLhzfNF08BSRJkQDGYuaRLS6VGgab
JZTl54bGvO1GkszEBE/9QFcqNVtWGMWXnUPwNNv8t//yJT5rvQil
-----END RSA PRIVATE KEY-----

View File

@ -0,0 +1,20 @@
-----BEGIN CERTIFICATE-----
MIIDQzCCAiugAwIBAgIULLCz3mZKmg2xy3rWCud0f1zcmBwwDQYJKoZIhvcNAQEL
BQAwFjEUMBIGA1UEAxMLZXhhbXBsZS5jb20wHhcNMTYwMzEwMDIzNjQ0WhcNMzYw
MzA1MDEzNzE0WjAaMRgwFgYDVQQDEw9iYXIuZXhhbXBsZS5jb20wggEiMA0GCSqG
SIb3DQEBAQUAA4IBDwAwggEKAoIBAQDAXuxEDJSItx3p6zpV5pNFQE66wUUaSYon
mVTfnXBoImpcWO5nqL9JBg0ACedGCi1dJMTV8g+MTaRk0fWG+oTkilMaADDLnTGm
MmusEEjp72XIqqPtPyAtU0G+0LRylCL6kauzMQjRyQNAJJkeqL88DNymYtSCHYoy
uBqBP5iU3fkoe2X9tD8Wyf6SrKRWo3Dr2f8IMo0p0MefWo/CJf2r99MyPcQbqD7e
e0qtQ6HxX+Aco/OrxAY//Ai53Yr61NKitVev/jPHvGDOVsmQJqTNxBCI/or6lo+c
NGUTx8r7WBFHJHtcLPKnSKtSXwbU2NgBH+1VbGJSGWT+Nm61vw+nAgMBAAGjgYQw
gYEwHQYDVR0lBBYwFAYIKwYBBQUHAwEGCCsGAQUFBwMCMB0GA1UdDgQWBBSVoF8F
7qbzSryIFrldurAG78LvSjAfBgNVHSMEGDAWgBRzDNvqF/Tq21OgWs13B5YydZjl
vzAgBgNVHREEGTAXgg9iYXIuZXhhbXBsZS5jb22HBH8AAAEwDQYJKoZIhvcNAQEL
BQADggEBAGmz2N282iT2IaEZvOmzIE4znHGkvoxZmrr/2byq5PskBg9ysyCHfUvw
SFA8U7jWjezKTnGRUu5blB+yZdjrMtB4AePWyEqtkJwVsZ2SPeP+9V2gNYK4iktP
UF3aIgBbAbw8rNuGIIB0T4D+6Zyo9Y3MCygs6/N4bRPZgLhewWn1ilklfnl3eqaC
a+JY1NBuTgCMa28NuC+Hy3mCveqhI8tFNiOthlLdgAEbuQaOuNutAG73utZ2aq6Q
W4pajFm3lEf5zt7Lo6ZCFtY/Q8jjURJ9e4O7VjXcqIhBM5bSMI6+fgQyOH0SLboj
RNanJ2bcyF1iPVyPBGzV3dF0ngYzxEY=
-----END CERTIFICATE-----

View File

@ -0,0 +1,20 @@
-----BEGIN CERTIFICATE-----
MIIDNTCCAh2gAwIBAgIUBeVo+Ce2BrdRT1cogKvJLtdOky8wDQYJKoZIhvcNAQEL
BQAwFjEUMBIGA1UEAxMLZXhhbXBsZS5jb20wHhcNMTYwMzEwMDIzNTM4WhcNMzYw
MzA1MDIzNjA4WjAWMRQwEgYDVQQDEwtleGFtcGxlLmNvbTCCASIwDQYJKoZIhvcN
AQEBBQADggEPADCCAQoCggEBAPTQGWPRIOECGeJB6tR/ftvvtioC9f84fY2QdJ5k
JBupXjPAGYKgS4MGzyT5bz9yY400tCtmh6h7p9tZwHl/TElTugtLQ/8ilMbJTiOM
SiyaMDPHiMJJYKTjm9bu6bKeU1qPZ0Cryes4rygbqs7w2XPgA2RxNmDh7JdX7/h+
VB5onBmv8g4WFSayowGyDcJWWCbu5yv6ZdH1bqQjgRzQ5xp17WXNmvlzdp2vate/
9UqPdA8sdJzW/91Gvmros0o/FnG7c2pULhk22wFqO8t2HRjKb3nuxALEJvqoPvad
KjpDTaq1L1ZzxcB7wvWyhy/lNLZL7jiNWy0mN1YB0UpSWdECAwEAAaN7MHkwDgYD
VR0PAQH/BAQDAgEGMA8GA1UdEwEB/wQFMAMBAf8wHQYDVR0OBBYEFHMM2+oX9Orb
U6BazXcHljJ1mOW/MB8GA1UdIwQYMBaAFHMM2+oX9OrbU6BazXcHljJ1mOW/MBYG
A1UdEQQPMA2CC2V4YW1wbGUuY29tMA0GCSqGSIb3DQEBCwUAA4IBAQAp17XsOaT9
hculRqrFptn3+zkH3HrIckHm+28R5xYT8ASFXFcLFugGizJAXVL5lvsRVRIwCoOX
Nhi8XSNEFP640VbHcEl81I84bbRIIDS+Yheu6JDZGemTaDYLv1J3D5SHwgoM+nyf
oTRgotUCIXcwJHmTpWEUkZFKuqBxsoTGzk0jO8wOP6xoJkzxVVG5PvNxs924rxY8
Y8iaLdDfMeT7Pi0XIliBa/aSp/iqSW8XKyJl5R5vXg9+DOgZUrVzIxObaF5RBl/a
mJOeklJBdNVzQm5+iMpO42lu0TA9eWtpP+YiUEXU17XDvFeQWOocFbQ1Peo0W895
XRz2GCwCNyvW
-----END CERTIFICATE-----

View File

@ -0,0 +1,27 @@
-----BEGIN RSA PRIVATE KEY-----
MIIEpgIBAAKCAQEAzNyVieSti9XBb5/celB5u8YKRJv3mQS9A4/X0mqY1ePznt1i
ilG7OmG0yM2VAk0ceIAQac3Bsn74jxn2cDlrrVniPXcNgYtMtW0kRqNEo4doo4EX
xZguS9vNBu29useHhif1TGX/pA3dgvaVycUCjzTEVk6qI8UEehMK6gEGZb7nOr0A
A9nipSqoeHpDLe3a4KVqj1vtlJKUvD2i1MuBuQ130cB1K9rufLCShGu7mEgzEosc
gr+K3Bf03IejbeVRyIfLtgj1zuvV1katec75UqRA/bsvt5G9JfJqiZ9mwFN0vp3g
Cr7pdQBSBQ2q4yf9s8CuY5c5w9fl3F8f5QFQoQIDAQABAoIBAQCbCb1qNFRa5ZSV
I8i6ELlwMDqJHfhOJ9XcIjpVljLAfNlcu3Ld92jYkCU/asaAjVckotbJG9yhd5Io
yp9E40/oS4P6vGTOS1vsWgMAKoPBtrKsOwCAm+E9q8UIn1fdSS/5ibgM74x+3bds
a62Em8KKGocUQkhk9a+jq1GxMsFisbHRxEHvClLmDMgGnW3FyGmWwT6yZLPSC0ey
szmmjt3ouP8cLAOmSjzcQBMmEZpQMCgR6Qckg6nrLQAGzZyTdCd875wbGA57DpWX
Lssn95+A5EFvr/6b7DkXeIFCrYBFFa+UQN3PWGEQ6Zjmiw4VgV2vO8yX2kCLlUhU
02bL393ZAoGBAPXPD/0yWINbKUPcRlx/WfWQxfz0bu50ytwIXzVK+pRoAMuNqehK
BJ6kNzTTBq40u+IZ4f5jbLDulymR+4zSkirLE7CyWFJOLNI/8K4Pf5DJUgNdrZjJ
LCtP9XRdxiPatQF0NGfdgHlSJh+/CiRJP4AgB17AnB/4z9/M0ZlJGVrzAoGBANVa
69P3Rp/WPBQv0wx6f0tWppJolWekAHKcDIdQ5HdOZE5CPAYSlTrTUW3uJuqMwU2L
M0Er2gIPKWIR5X+9r7Fvu9hQW6l2v3xLlcrGPiapp3STJvuMxzhRAmXmu3bZfVn1
Vn7Vf1jPULHtTFSlNFEvYG5UJmygK9BeyyVO5KMbAoGBAMCyAibLQPg4jrDUDZSV
gUAwrgUO2ae1hxHWvkxY6vdMUNNByuB+pgB3W4/dnm8Sh/dHsxJpftt1Lqs39ar/
p/ZEHLt4FCTxg9GOrm7FV4t5RwG8fko36phJpnIC0UFqQltRbYO+8OgqrhhU+u5X
PaCDe0OcWsf1lYAsYGN6GpZhAoGBAMJ5Ksa9+YEODRs1cIFKUyd/5ztC2xRqOAI/
3WemQ2nAacuvsfizDZVeMzYpww0+maAuBt0btI719PmwaGmkpDXvK+EDdlmkpOwO
FY6MXvBs6fdnfjwCWUErDi2GQFAX9Jt/9oSL5JU1+08DhvUM1QA/V/2Y9KFE6kr3
bOIn5F4LAoGBAKQzH/AThDGhT3hwr4ktmReF3qKxBgxzjVa8veXtkY5VWwyN09iT
jnTTt6N1CchZoK5WCETjdzNYP7cuBTcV4d3bPNRiJmxXaNVvx3Tlrk98OiffT8Qa
5DO/Wfb43rNHYXBjU6l0n2zWcQ4PUSSbu0P0bM2JTQPRCqSthXvSHw2P
-----END RSA PRIVATE KEY-----

View File

@ -0,0 +1,20 @@
-----BEGIN CERTIFICATE-----
MIIDQzCCAiugAwIBAgIUFVW6i/M+yJUsDrXWgRKO/Dnb+L4wDQYJKoZIhvcNAQEL
BQAwFjEUMBIGA1UEAxMLZXhhbXBsZS5jb20wHhcNMTYwMzEwMDIzNjA1WhcNMzYw
MzA1MDEzNjM1WjAaMRgwFgYDVQQDEw9mb28uZXhhbXBsZS5jb20wggEiMA0GCSqG
SIb3DQEBAQUAA4IBDwAwggEKAoIBAQDM3JWJ5K2L1cFvn9x6UHm7xgpEm/eZBL0D
j9fSapjV4/Oe3WKKUbs6YbTIzZUCTRx4gBBpzcGyfviPGfZwOWutWeI9dw2Bi0y1
bSRGo0Sjh2ijgRfFmC5L280G7b26x4eGJ/VMZf+kDd2C9pXJxQKPNMRWTqojxQR6
EwrqAQZlvuc6vQAD2eKlKqh4ekMt7drgpWqPW+2UkpS8PaLUy4G5DXfRwHUr2u58
sJKEa7uYSDMSixyCv4rcF/Tch6Nt5VHIh8u2CPXO69XWRq15zvlSpED9uy+3kb0l
8mqJn2bAU3S+neAKvul1AFIFDarjJ/2zwK5jlznD1+XcXx/lAVChAgMBAAGjgYQw
gYEwHQYDVR0lBBYwFAYIKwYBBQUHAwEGCCsGAQUFBwMCMB0GA1UdDgQWBBRNJoOJ
dnazDiuqLhV6truQ4cRe9jAfBgNVHSMEGDAWgBRzDNvqF/Tq21OgWs13B5YydZjl
vzAgBgNVHREEGTAXgg9mb28uZXhhbXBsZS5jb22HBH8AAAEwDQYJKoZIhvcNAQEL
BQADggEBAHzv67mtbxMWcuMsxCFBN1PJNAyUDZVCB+1gWhk59EySbVg81hWJDCBy
fl3TKjz3i7wBGAv+C2iTxmwsSJbda22v8JQbuscXIfLFbNALsPzF+J0vxAgJs5Gc
sDbfJ7EQOIIOVKQhHLYnQoLnigSSPc1kd0JjYyHEBjgIaSuXgRRTBAeqLiBMx0yh
RKL1lQ+WoBU/9SXUZZkwokqWt5G7khi5qZkNxVXZCm8VGPg0iywf6gGyhI1SU5S2
oR219S6kA4JY/stw1qne85/EmHmoImHGt08xex3GoU72jKAjsIpqRWopcD/+uene
Tc9nn3fTQW/Z9fsoJ5iF5OdJnDEswqE=
-----END CERTIFICATE-----

View File

@ -4,24 +4,33 @@
package command package command
import ( import (
"crypto/tls"
"crypto/x509"
"fmt" "fmt"
"net/http" "net/http"
"os" "os"
"path/filepath"
"reflect"
"strings"
"sync" "sync"
"testing" "testing"
"time" "time"
"github.com/hashicorp/go-hclog" "github.com/hashicorp/go-hclog"
vaultjwt "github.com/hashicorp/vault-plugin-auth-jwt" vaultjwt "github.com/hashicorp/vault-plugin-auth-jwt"
logicalKv "github.com/hashicorp/vault-plugin-secrets-kv"
"github.com/hashicorp/vault/api" "github.com/hashicorp/vault/api"
credAppRole "github.com/hashicorp/vault/builtin/credential/approle" credAppRole "github.com/hashicorp/vault/builtin/credential/approle"
"github.com/hashicorp/vault/command/agent" "github.com/hashicorp/vault/command/agent"
proxyConfig "github.com/hashicorp/vault/command/proxy/config"
"github.com/hashicorp/vault/helper/useragent" "github.com/hashicorp/vault/helper/useragent"
vaulthttp "github.com/hashicorp/vault/http" vaulthttp "github.com/hashicorp/vault/http"
"github.com/hashicorp/vault/sdk/helper/logging" "github.com/hashicorp/vault/sdk/helper/logging"
"github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/vault" "github.com/hashicorp/vault/vault"
"github.com/mitchellh/cli" "github.com/mitchellh/cli"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func testProxyCommand(tb testing.TB, logger hclog.Logger) (*cli.MockUi, *ProxyCommand) { func testProxyCommand(tb testing.TB, logger hclog.Logger) (*cli.MockUi, *ProxyCommand) {
@ -440,7 +449,7 @@ vault {
configPath := makeTempFile(t, "config.hcl", config) configPath := makeTempFile(t, "config.hcl", config)
defer os.Remove(configPath) defer os.Remove(configPath)
// Start the agent // Start the proxy
_, cmd := testProxyCommand(t, logger) _, cmd := testProxyCommand(t, logger)
cmd.startedCh = make(chan struct{}) cmd.startedCh = make(chan struct{})
@ -532,7 +541,7 @@ vault {
configPath := makeTempFile(t, "config.hcl", config) configPath := makeTempFile(t, "config.hcl", config)
defer os.Remove(configPath) defer os.Remove(configPath)
// Start the agent // Start the proxy
_, cmd := testProxyCommand(t, logger) _, cmd := testProxyCommand(t, logger)
cmd.startedCh = make(chan struct{}) cmd.startedCh = make(chan struct{})
@ -582,7 +591,7 @@ func TestProxy_Cache_DynamicSecret(t *testing.T) {
serverClient := cluster.Cores[0].Client serverClient := cluster.Cores[0].Client
// Unset the environment variable so that agent picks up the right test // Unset the environment variable so that proxy picks up the right test
// cluster address // cluster address
defer os.Setenv(api.EnvVaultAddress, os.Getenv(api.EnvVaultAddress)) defer os.Setenv(api.EnvVaultAddress, os.Getenv(api.EnvVaultAddress))
os.Unsetenv(api.EnvVaultAddress) os.Unsetenv(api.EnvVaultAddress)
@ -676,3 +685,586 @@ vault {
close(cmd.ShutdownCh) close(cmd.ShutdownCh)
wg.Wait() wg.Wait()
} }
// TestProxy_ApiProxy_Retry Tests the retry functionalities of Vault Proxy's API Proxy
func TestProxy_ApiProxy_Retry(t *testing.T) {
//----------------------------------------------------
// Start the server and proxy
//----------------------------------------------------
logger := logging.NewVaultLogger(hclog.Trace)
var h handler
cluster := vault.NewTestCluster(t,
&vault.CoreConfig{
Logger: logger,
CredentialBackends: map[string]logical.Factory{
"approle": credAppRole.Factory,
},
LogicalBackends: map[string]logical.Factory{
"kv": logicalKv.Factory,
},
},
&vault.TestClusterOptions{
NumCores: 1,
HandlerFunc: vaulthttp.HandlerFunc(func(properties *vault.HandlerProperties) http.Handler {
h.props = properties
h.t = t
return &h
}),
})
cluster.Start()
defer cluster.Cleanup()
vault.TestWaitActive(t, cluster.Cores[0].Core)
serverClient := cluster.Cores[0].Client
// Unset the environment variable so that proxy picks up the right test
// cluster address
defer os.Setenv(api.EnvVaultAddress, os.Getenv(api.EnvVaultAddress))
os.Unsetenv(api.EnvVaultAddress)
_, err := serverClient.Logical().Write("secret/foo", map[string]interface{}{
"bar": "baz",
})
if err != nil {
t.Fatal(err)
}
intRef := func(i int) *int {
return &i
}
// start test cases here
testCases := map[string]struct {
retries *int
expectError bool
}{
"none": {
retries: intRef(-1),
expectError: true,
},
"one": {
retries: intRef(1),
expectError: true,
},
"two": {
retries: intRef(2),
expectError: false,
},
"missing": {
retries: nil,
expectError: false,
},
"default": {
retries: intRef(0),
expectError: false,
},
}
for tcname, tc := range testCases {
t.Run(tcname, func(t *testing.T) {
h.failCount = 2
cacheConfig := `
cache {
}
`
listenAddr := generateListenerAddress(t)
listenConfig := fmt.Sprintf(`
listener "tcp" {
address = "%s"
tls_disable = true
}
`, listenAddr)
var retryConf string
if tc.retries != nil {
retryConf = fmt.Sprintf("retry { num_retries = %d }", *tc.retries)
}
config := fmt.Sprintf(`
vault {
address = "%s"
%s
tls_skip_verify = true
}
%s
%s
`, serverClient.Address(), retryConf, cacheConfig, listenConfig)
configPath := makeTempFile(t, "config.hcl", config)
defer os.Remove(configPath)
_, cmd := testProxyCommand(t, logger)
cmd.startedCh = make(chan struct{})
wg := &sync.WaitGroup{}
wg.Add(1)
go func() {
cmd.Run([]string{"-config", configPath})
wg.Done()
}()
select {
case <-cmd.startedCh:
case <-time.After(5 * time.Second):
t.Errorf("timeout")
}
client, err := api.NewClient(api.DefaultConfig())
if err != nil {
t.Fatal(err)
}
client.SetToken(serverClient.Token())
client.SetMaxRetries(0)
err = client.SetAddress("http://" + listenAddr)
if err != nil {
t.Fatal(err)
}
secret, err := client.Logical().Read("secret/foo")
switch {
case (err != nil || secret == nil) && tc.expectError:
case (err == nil || secret != nil) && !tc.expectError:
default:
t.Fatalf("%s expectError=%v error=%v secret=%v", tcname, tc.expectError, err, secret)
}
if secret != nil && secret.Data["foo"] != nil {
val := secret.Data["foo"].(map[string]interface{})
if !reflect.DeepEqual(val, map[string]interface{}{"bar": "baz"}) {
t.Fatalf("expected key 'foo' to yield bar=baz, got: %v", val)
}
}
time.Sleep(time.Second)
close(cmd.ShutdownCh)
wg.Wait()
})
}
}
// TestProxy_Metrics tests that metrics are being properly reported.
func TestProxy_Metrics(t *testing.T) {
// Start a vault server
logger := logging.NewVaultLogger(hclog.Trace)
cluster := vault.NewTestCluster(t,
&vault.CoreConfig{
Logger: logger,
},
&vault.TestClusterOptions{
HandlerFunc: vaulthttp.Handler,
})
cluster.Start()
defer cluster.Cleanup()
vault.TestWaitActive(t, cluster.Cores[0].Core)
serverClient := cluster.Cores[0].Client
// Create a config file
listenAddr := generateListenerAddress(t)
config := fmt.Sprintf(`
cache {}
listener "tcp" {
address = "%s"
tls_disable = true
}
`, listenAddr)
configPath := makeTempFile(t, "config.hcl", config)
defer os.Remove(configPath)
ui, cmd := testProxyCommand(t, logger)
cmd.client = serverClient
cmd.startedCh = make(chan struct{})
wg := &sync.WaitGroup{}
wg.Add(1)
go func() {
code := cmd.Run([]string{"-config", configPath})
if code != 0 {
t.Errorf("non-zero return code when running proxy: %d", code)
t.Logf("STDOUT from proxy:\n%s", ui.OutputWriter.String())
t.Logf("STDERR from proxy:\n%s", ui.ErrorWriter.String())
}
wg.Done()
}()
select {
case <-cmd.startedCh:
case <-time.After(5 * time.Second):
t.Errorf("timeout")
}
// defer proxy shutdown
defer func() {
cmd.ShutdownCh <- struct{}{}
wg.Wait()
}()
conf := api.DefaultConfig()
conf.Address = "http://" + listenAddr
proxyClient, err := api.NewClient(conf)
if err != nil {
t.Fatalf("err: %s", err)
}
req := proxyClient.NewRequest("GET", "/proxy/v1/metrics")
body := request(t, proxyClient, req, 200)
keys := []string{}
for k := range body {
keys = append(keys, k)
}
require.ElementsMatch(t, keys, []string{
"Counters",
"Samples",
"Timestamp",
"Gauges",
"Points",
})
}
// TestProxy_QuitAPI Tests the /proxy/v1/quit API that can be enabled for the proxy.
func TestProxy_QuitAPI(t *testing.T) {
logger := logging.NewVaultLogger(hclog.Error)
cluster := vault.NewTestCluster(t,
&vault.CoreConfig{
Logger: logger,
CredentialBackends: map[string]logical.Factory{
"approle": credAppRole.Factory,
},
LogicalBackends: map[string]logical.Factory{
"kv": logicalKv.Factory,
},
},
&vault.TestClusterOptions{
NumCores: 1,
})
cluster.Start()
defer cluster.Cleanup()
vault.TestWaitActive(t, cluster.Cores[0].Core)
serverClient := cluster.Cores[0].Client
// Unset the environment variable so that proxy picks up the right test
// cluster address
defer os.Setenv(api.EnvVaultAddress, os.Getenv(api.EnvVaultAddress))
err := os.Unsetenv(api.EnvVaultAddress)
if err != nil {
t.Fatal(err)
}
listenAddr := generateListenerAddress(t)
listenAddr2 := generateListenerAddress(t)
config := fmt.Sprintf(`
vault {
address = "%s"
tls_skip_verify = true
}
listener "tcp" {
address = "%s"
tls_disable = true
}
listener "tcp" {
address = "%s"
tls_disable = true
proxy_api {
enable_quit = true
}
}
cache {}
`, serverClient.Address(), listenAddr, listenAddr2)
configPath := makeTempFile(t, "config.hcl", config)
defer os.Remove(configPath)
_, cmd := testProxyCommand(t, logger)
cmd.startedCh = make(chan struct{})
wg := &sync.WaitGroup{}
wg.Add(1)
go func() {
cmd.Run([]string{"-config", configPath})
wg.Done()
}()
select {
case <-cmd.startedCh:
case <-time.After(5 * time.Second):
t.Errorf("timeout")
}
client, err := api.NewClient(api.DefaultConfig())
if err != nil {
t.Fatal(err)
}
client.SetToken(serverClient.Token())
client.SetMaxRetries(0)
err = client.SetAddress("http://" + listenAddr)
if err != nil {
t.Fatal(err)
}
// First try on listener 1 where the API should be disabled.
resp, err := client.RawRequest(client.NewRequest(http.MethodPost, "/proxy/v1/quit"))
if err == nil {
t.Fatalf("expected error")
}
if resp != nil && resp.StatusCode != http.StatusNotFound {
t.Fatalf("expected %d but got: %d", http.StatusNotFound, resp.StatusCode)
}
// Now try on listener 2 where the quit API should be enabled.
err = client.SetAddress("http://" + listenAddr2)
if err != nil {
t.Fatal(err)
}
_, err = client.RawRequest(client.NewRequest(http.MethodPost, "/proxy/v1/quit"))
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
select {
case <-cmd.ShutdownCh:
case <-time.After(5 * time.Second):
t.Errorf("timeout")
}
wg.Wait()
}
// TestProxy_LogFile_CliOverridesConfig tests that the CLI values
// override the config for log files
func TestProxy_LogFile_CliOverridesConfig(t *testing.T) {
// Create basic config
configFile := populateTempFile(t, "proxy-config.hcl", BasicHclConfig)
cfg, err := proxyConfig.LoadConfigFile(configFile.Name())
if err != nil {
t.Fatal("Cannot load config to test update/merge", err)
}
// Sanity check that the config value is the current value
assert.Equal(t, "TMPDIR/juan.log", cfg.LogFile)
// Initialize the command and parse any flags
cmd := &ProxyCommand{BaseCommand: &BaseCommand{}}
f := cmd.Flags()
// Simulate the flag being specified
err = f.Parse([]string{"-log-file=/foo/bar/test.log"})
if err != nil {
t.Fatal(err)
}
// Update the config based on the inputs.
cmd.applyConfigOverrides(f, cfg)
assert.NotEqual(t, "TMPDIR/juan.log", cfg.LogFile)
assert.NotEqual(t, "/squiggle/logs.txt", cfg.LogFile)
assert.Equal(t, "/foo/bar/test.log", cfg.LogFile)
}
// TestProxy_LogFile_Config tests log file config when loaded from config
func TestProxy_LogFile_Config(t *testing.T) {
configFile := populateTempFile(t, "proxy-config.hcl", BasicHclConfig)
cfg, err := proxyConfig.LoadConfigFile(configFile.Name())
if err != nil {
t.Fatal("Cannot load config to test update/merge", err)
}
// Sanity check that the config value is the current value
assert.Equal(t, "TMPDIR/juan.log", cfg.LogFile, "sanity check on log config failed")
assert.Equal(t, 2, cfg.LogRotateMaxFiles)
assert.Equal(t, 1048576, cfg.LogRotateBytes)
// Parse the cli flags (but we pass in an empty slice)
cmd := &ProxyCommand{BaseCommand: &BaseCommand{}}
f := cmd.Flags()
err = f.Parse([]string{})
if err != nil {
t.Fatal(err)
}
// Should change nothing...
cmd.applyConfigOverrides(f, cfg)
assert.Equal(t, "TMPDIR/juan.log", cfg.LogFile, "actual config check")
assert.Equal(t, 2, cfg.LogRotateMaxFiles)
assert.Equal(t, 1048576, cfg.LogRotateBytes)
}
// TestProxy_Config_NewLogger_Default Tests defaults for log level and
// specifically cmd.newLogger()
func TestProxy_Config_NewLogger_Default(t *testing.T) {
cmd := &ProxyCommand{BaseCommand: &BaseCommand{}}
cmd.config = proxyConfig.NewConfig()
logger, err := cmd.newLogger()
assert.NoError(t, err)
assert.NotNil(t, logger)
assert.Equal(t, hclog.Info.String(), logger.GetLevel().String())
}
// TestProxy_Config_ReloadLogLevel Tests reloading updates the log
// level as expected.
func TestProxy_Config_ReloadLogLevel(t *testing.T) {
cmd := &ProxyCommand{BaseCommand: &BaseCommand{}}
var err error
tempDir := t.TempDir()
// Load an initial config
hcl := strings.ReplaceAll(BasicHclConfig, "TMPDIR", tempDir)
configFile := populateTempFile(t, "proxy-config.hcl", hcl)
cmd.config, err = proxyConfig.LoadConfigFile(configFile.Name())
if err != nil {
t.Fatal("Cannot load config to test update/merge", err)
}
// Tweak the loaded config to make sure we can put log files into a temp dir
// and systemd log attempts work fine, this would usually happen during Run.
cmd.logWriter = os.Stdout
cmd.logger, err = cmd.newLogger()
if err != nil {
t.Fatal("logger required for systemd log messages", err)
}
// Sanity check
assert.Equal(t, "warn", cmd.config.LogLevel)
// Load a new config
hcl = strings.ReplaceAll(BasicHclConfig2, "TMPDIR", tempDir)
configFile = populateTempFile(t, "proxy-config.hcl", hcl)
err = cmd.reloadConfig([]string{configFile.Name()})
assert.NoError(t, err)
assert.Equal(t, "debug", cmd.config.LogLevel)
}
// TestProxy_Config_ReloadTls Tests that the TLS certs for the listener are
// correctly reloaded.
func TestProxy_Config_ReloadTls(t *testing.T) {
var wg sync.WaitGroup
wd, err := os.Getwd()
if err != nil {
t.Fatal("unable to get current working directory")
}
workingDir := filepath.Join(wd, "/proxy/test-fixtures/reload")
fooCert := "reload_foo.pem"
fooKey := "reload_foo.key"
barCert := "reload_bar.pem"
barKey := "reload_bar.key"
reloadCert := "reload_cert.pem"
reloadKey := "reload_key.pem"
caPem := "reload_ca.pem"
tempDir := t.TempDir()
// Set up initial 'foo' certs
inBytes, err := os.ReadFile(filepath.Join(workingDir, fooCert))
if err != nil {
t.Fatal("unable to read cert required for test", fooCert, err)
}
err = os.WriteFile(filepath.Join(tempDir, reloadCert), inBytes, 0o777)
if err != nil {
t.Fatal("unable to write temp cert required for test", reloadCert, err)
}
inBytes, err = os.ReadFile(filepath.Join(workingDir, fooKey))
if err != nil {
t.Fatal("unable to read cert key required for test", fooKey, err)
}
err = os.WriteFile(filepath.Join(tempDir, reloadKey), inBytes, 0o777)
if err != nil {
t.Fatal("unable to write temp cert key required for test", reloadKey, err)
}
inBytes, err = os.ReadFile(filepath.Join(workingDir, caPem))
if err != nil {
t.Fatal("unable to read CA pem required for test", caPem, err)
}
certPool := x509.NewCertPool()
ok := certPool.AppendCertsFromPEM(inBytes)
if !ok {
t.Fatal("not ok when appending CA cert")
}
replacedHcl := strings.ReplaceAll(BasicHclConfig, "TMPDIR", tempDir)
configFile := populateTempFile(t, "proxy-config.hcl", replacedHcl)
// Set up Proxy
logger := logging.NewVaultLogger(hclog.Trace)
ui, cmd := testProxyCommand(t, logger)
wg.Add(1)
args := []string{"-config", configFile.Name()}
go func() {
if code := cmd.Run(args); code != 0 {
output := ui.ErrorWriter.String() + ui.OutputWriter.String()
t.Errorf("got a non-zero exit status: %s", output)
}
wg.Done()
}()
testCertificateName := func(cn string) error {
conn, err := tls.Dial("tcp", "127.0.0.1:8100", &tls.Config{
RootCAs: certPool,
})
if err != nil {
return err
}
defer conn.Close()
if err = conn.Handshake(); err != nil {
return err
}
servName := conn.ConnectionState().PeerCertificates[0].Subject.CommonName
if servName != cn {
return fmt.Errorf("expected %s, got %s", cn, servName)
}
return nil
}
// Start
select {
case <-cmd.startedCh:
case <-time.After(5 * time.Second):
t.Fatalf("timeout")
}
if err := testCertificateName("foo.example.com"); err != nil {
t.Fatalf("certificate name didn't check out: %s", err)
}
// Swap out certs
inBytes, err = os.ReadFile(filepath.Join(workingDir, barCert))
if err != nil {
t.Fatal("unable to read cert required for test", barCert, err)
}
err = os.WriteFile(filepath.Join(tempDir, reloadCert), inBytes, 0o777)
if err != nil {
t.Fatal("unable to write temp cert required for test", reloadCert, err)
}
inBytes, err = os.ReadFile(filepath.Join(workingDir, barKey))
if err != nil {
t.Fatal("unable to read cert key required for test", barKey, err)
}
err = os.WriteFile(filepath.Join(tempDir, reloadKey), inBytes, 0o777)
if err != nil {
t.Fatal("unable to write temp cert key required for test", reloadKey, err)
}
// Reload
cmd.SighupCh <- struct{}{}
select {
case <-cmd.reloadedCh:
case <-time.After(5 * time.Second):
t.Fatalf("timeout")
}
if err := testCertificateName("bar.example.com"); err != nil {
t.Fatalf("certificate name didn't check out: %s", err)
}
// Shut down
cmd.ShutdownCh <- struct{}{}
wg.Wait()
}

View File

@ -75,10 +75,18 @@ func AgentAutoAuthString() string {
versionFunc(), projectURL, rt) versionFunc(), projectURL, rt)
} }
// ProxyString returns the consistent user-agent string for Vault Proxy API Proxying. // ProxyString returns the consistent user-agent string for Vault Proxy.
//
// e.g. Vault Proxy/0.10.4 (+https://www.vaultproject.io/; go1.10.1)
func ProxyString() string {
return fmt.Sprintf("Vault Proxy/%s (+%s; %s)",
versionFunc(), projectURL, rt)
}
// ProxyAPIProxyString returns the consistent user-agent string for Vault Proxy API Proxying.
// //
// e.g. Vault Proxy API Proxy/0.10.4 (+https://www.vaultproject.io/; go1.10.1) // e.g. Vault Proxy API Proxy/0.10.4 (+https://www.vaultproject.io/; go1.10.1)
func ProxyString() string { func ProxyAPIProxyString() string {
return fmt.Sprintf("Vault Proxy API Proxy/%s (+%s; %s)", return fmt.Sprintf("Vault Proxy API Proxy/%s (+%s; %s)",
versionFunc(), projectURL, rt) versionFunc(), projectURL, rt)
} }

View File

@ -85,3 +85,56 @@ func TestUserAgent_VaultAgentAutoAuth(t *testing.T) {
exp := "Vault Agent Auto-Auth/1.2.3 (+https://vault-test.com; go5.0)" exp := "Vault Agent Auto-Auth/1.2.3 (+https://vault-test.com; go5.0)"
require.Equal(t, exp, act) require.Equal(t, exp, act)
} }
// TestUserAgent_VaultProxy tests the ProxyString() function works
// as expected
func TestUserAgent_VaultProxy(t *testing.T) {
projectURL = "https://vault-test.com"
rt = "go5.0"
versionFunc = func() string { return "1.2.3" }
act := ProxyString()
exp := "Vault Proxy/1.2.3 (+https://vault-test.com; go5.0)"
require.Equal(t, exp, act)
}
// TestUserAgent_VaultProxyAPIProxy tests the ProxyAPIProxyString() function works
// as expected
func TestUserAgent_VaultProxyAPIProxy(t *testing.T) {
projectURL = "https://vault-test.com"
rt = "go5.0"
versionFunc = func() string { return "1.2.3" }
act := ProxyAPIProxyString()
exp := "Vault Proxy API Proxy/1.2.3 (+https://vault-test.com; go5.0)"
require.Equal(t, exp, act)
}
// TestUserAgent_VaultProxyWithProxiedUserAgent tests the ProxyStringWithProxiedUserAgent()
// function works as expected
func TestUserAgent_VaultProxyWithProxiedUserAgent(t *testing.T) {
projectURL = "https://vault-test.com"
rt = "go5.0"
versionFunc = func() string { return "1.2.3" }
userAgent := "my-user-agent"
act := ProxyStringWithProxiedUserAgent(userAgent)
exp := "Vault Proxy API Proxy/1.2.3 (+https://vault-test.com; go5.0); my-user-agent"
require.Equal(t, exp, act)
}
// TestUserAgent_VaultProxyAutoAuth tests the ProxyAPIProxyString() function works
// as expected
func TestUserAgent_VaultProxyAutoAuth(t *testing.T) {
projectURL = "https://vault-test.com"
rt = "go5.0"
versionFunc = func() string { return "1.2.3" }
act := ProxyAutoAuthString()
exp := "Vault Proxy Auto-Auth/1.2.3 (+https://vault-test.com; go5.0)"
require.Equal(t, exp, act)
}