agent: persistent caching support (#10938)

Adds the option of a write-through cache, backed by boltdb

Co-authored-by: Theron Voran <tvoran@users.noreply.github.com>
Co-authored-by: Jason O'Donnell <2160810+jasonodonnell@users.noreply.github.com>
Co-authored-by: Calvin Leung Huang <cleung2010@gmail.com>
Theron Voran 2021-03-03 14:01:33 -08:00 committed by GitHub
parent 910b45413b
commit 1fdf08b149
23 changed files with 1900 additions and 290 deletions
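For orientation, the sketch below shows how the new pieces fit together on a fresh start: a passthrough key manager supplies the encryption wrapper, boltdb provides the on-disk store, and the lease cache writes through to it. This is a simplified illustration of the flow in the command.go changes below, not the exact agent code; the helper name newPersistentLeaseCache, the example path, and the omission of the Kubernetes AAD handling are assumptions made for the example.

```go
package example

import (
	"context"

	hclog "github.com/hashicorp/go-hclog"
	"github.com/hashicorp/vault/api"
	"github.com/hashicorp/vault/command/agent/cache"
	"github.com/hashicorp/vault/command/agent/cache/cacheboltdb"
	"github.com/hashicorp/vault/command/agent/cache/keymanager"
)

// newPersistentLeaseCache is a hypothetical helper showing the wiring only.
func newPersistentLeaseCache(ctx context.Context, client *api.Client, proxier cache.Proxier, logger hclog.Logger) (*cache.LeaseCache, error) {
	// The passthrough key manager generates a fresh 32-byte AES-GCM key when given nil.
	km, err := keymanager.NewPassthroughKeyManager(nil)
	if err != nil {
		return nil, err
	}

	// The bolt file (vault-agent-cache.db) is created inside this directory.
	ps, err := cacheboltdb.NewBoltStorage(&cacheboltdb.BoltStorageConfig{
		Path:    "/tmp/vault-agent-cache", // example path
		Logger:  logger.Named("cacheboltdb"),
		Wrapper: km.Wrapper(),
	})
	if err != nil {
		return nil, err
	}

	// Stash the key material so a later restart can decrypt the persisted entries.
	token, err := km.RetrievalToken()
	if err != nil {
		return nil, err
	}
	if err := ps.StoreRetrievalToken(token); err != nil {
		return nil, err
	}

	// With Storage set, every cached token and lease is also written through to bolt.
	return cache.NewLeaseCache(&cache.LeaseCacheConfig{
		Client:      client,
		BaseContext: ctx,
		Proxier:     proxier,
		Logger:      logger.Named("leasecache"),
		Storage:     ps,
	})
}
```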


@ -1,3 +1,3 @@
```release-note:improvement
agent: Route templating server through cache when persistent cache is enabled.
```

changelog/10938.txt Normal file

@ -0,0 +1,3 @@
```release-note:feature
agent: Support for persisting the agent cache to disk
```


@ -6,10 +6,12 @@ import (
"flag" "flag"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"net" "net"
"net/http" "net/http"
"os" "os"
"path" "path"
"path/filepath"
"sort" "sort"
"strings" "strings"
"sync" "sync"
@ -30,6 +32,9 @@ import (
"github.com/hashicorp/vault/command/agent/auth/kerberos" "github.com/hashicorp/vault/command/agent/auth/kerberos"
"github.com/hashicorp/vault/command/agent/auth/kubernetes" "github.com/hashicorp/vault/command/agent/auth/kubernetes"
"github.com/hashicorp/vault/command/agent/cache" "github.com/hashicorp/vault/command/agent/cache"
"github.com/hashicorp/vault/command/agent/cache/cacheboltdb"
"github.com/hashicorp/vault/command/agent/cache/cachememdb"
"github.com/hashicorp/vault/command/agent/cache/keymanager"
agentConfig "github.com/hashicorp/vault/command/agent/config" agentConfig "github.com/hashicorp/vault/command/agent/config"
"github.com/hashicorp/vault/command/agent/sink" "github.com/hashicorp/vault/command/agent/sink"
"github.com/hashicorp/vault/command/agent/sink/file" "github.com/hashicorp/vault/command/agent/sink/file"
@ -461,6 +466,8 @@ func (c *AgentCommand) Run(args []string) int {
c.UI.Output("==> Vault agent started! Log data will stream in below:\n") c.UI.Output("==> Vault agent started! Log data will stream in below:\n")
} }
var leaseCache *cache.LeaseCache
var previousToken string
// Parse agent listener configurations // Parse agent listener configurations
if config.Cache != nil && len(config.Listeners) != 0 { if config.Cache != nil && len(config.Listeners) != 0 {
cacheLogger := c.logger.Named("cache") cacheLogger := c.logger.Named("cache")
@ -479,7 +486,7 @@ func (c *AgentCommand) Run(args []string) int {
// Create the lease cache proxier and set its underlying proxier to
// the API proxier.
leaseCache, err = cache.NewLeaseCache(&cache.LeaseCacheConfig{
Client: client,
BaseContext: ctx,
Proxier: apiProxy,
@ -490,6 +497,152 @@ func (c *AgentCommand) Run(args []string) int {
return 1
}
// Configure persistent storage and add to LeaseCache
if config.Cache.Persist != nil {
if config.Cache.Persist.Path == "" {
c.UI.Error("must specify persistent cache path")
return 1
}
// Set AAD based on key protection type
var aad string
switch config.Cache.Persist.Type {
case "kubernetes":
aad, err = getServiceAccountJWT(config.Cache.Persist.ServiceAccountTokenFile)
if err != nil {
c.UI.Error(fmt.Sprintf("failed to read service account token from %s: %s", config.Cache.Persist.ServiceAccountTokenFile, err))
return 1
}
default:
c.UI.Error(fmt.Sprintf("persistent key protection type %q not supported", config.Cache.Persist.Type))
return 1
}
// Check if bolt file exists already
dbFileExists, err := cacheboltdb.DBFileExists(config.Cache.Persist.Path)
if err != nil {
c.UI.Error(fmt.Sprintf("failed to check if bolt file exists at path %s: %s", config.Cache.Persist.Path, err))
return 1
}
if dbFileExists {
// Open the bolt file, but wait to set up encryption
ps, err := cacheboltdb.NewBoltStorage(&cacheboltdb.BoltStorageConfig{
Path: config.Cache.Persist.Path,
Logger: cacheLogger.Named("cacheboltdb"),
})
if err != nil {
c.UI.Error(fmt.Sprintf("Error opening persistent cache: %v", err))
return 1
}
// Get the token from bolt for retrieving the encryption key,
// then set up encryption so that restore is possible
token, err := ps.GetRetrievalToken()
if err != nil {
c.UI.Error(fmt.Sprintf("Error getting retrieval token from persistent cache: %v", err))
}
if err := ps.Close(); err != nil {
c.UI.Warn(fmt.Sprintf("Failed to close persistent cache file after getting retrieval token: %s", err))
}
km, err := keymanager.NewPassthroughKeyManager(token)
if err != nil {
c.UI.Error(fmt.Sprintf("failed to configure persistence encryption for cache: %s", err))
return 1
}
// Open the bolt file with the wrapper provided
ps, err = cacheboltdb.NewBoltStorage(&cacheboltdb.BoltStorageConfig{
Path: config.Cache.Persist.Path,
Logger: cacheLogger.Named("cacheboltdb"),
Wrapper: km.Wrapper(),
AAD: aad,
})
if err != nil {
c.UI.Error(fmt.Sprintf("Error opening persistent cache: %v", err))
return 1
}
// Restore anything in the persistent cache to the memory cache
if err := leaseCache.Restore(ctx, ps); err != nil {
c.UI.Error(fmt.Sprintf("Error restoring in-memory cache from persisted file: %v", err))
if config.Cache.Persist.ExitOnErr {
return 1
}
}
cacheLogger.Info("loaded memcache from persistent storage")
// Check for previous auto-auth token
oldTokenBytes, err := ps.GetAutoAuthToken(ctx)
if err != nil {
c.UI.Error(fmt.Sprintf("Error in fetching previous auto-auth token: %s", err))
if config.Cache.Persist.ExitOnErr {
return 1
}
}
if len(oldTokenBytes) > 0 {
oldToken, err := cachememdb.Deserialize(oldTokenBytes)
if err != nil {
c.UI.Error(fmt.Sprintf("Error in deserializing previous auto-auth token cache entry: %s", err))
if config.Cache.Persist.ExitOnErr {
return 1
}
}
previousToken = oldToken.Token
}
// If keep_after_import is true, set the persistent storage layer in
// leaseCache; otherwise remove the db file
if config.Cache.Persist.KeepAfterImport {
defer ps.Close()
leaseCache.SetPersistentStorage(ps)
} else {
if err := ps.Close(); err != nil {
c.UI.Warn(fmt.Sprintf("failed to close persistent cache file: %s", err))
}
dbFile := filepath.Join(config.Cache.Persist.Path, cacheboltdb.DatabaseFileName)
if err := os.Remove(dbFile); err != nil {
c.UI.Error(fmt.Sprintf("failed to remove persistent storage file %s: %s", dbFile, err))
if config.Cache.Persist.ExitOnErr {
return 1
}
}
}
} else {
km, err := keymanager.NewPassthroughKeyManager(nil)
if err != nil {
c.UI.Error(fmt.Sprintf("failed to configure persistence encryption for cache: %s", err))
return 1
}
ps, err := cacheboltdb.NewBoltStorage(&cacheboltdb.BoltStorageConfig{
Path: config.Cache.Persist.Path,
Logger: cacheLogger.Named("cacheboltdb"),
Wrapper: km.Wrapper(),
AAD: aad,
})
if err != nil {
c.UI.Error(fmt.Sprintf("Error creating persistent cache: %v", err))
return 1
}
cacheLogger.Info("configured persistent storage", "path", config.Cache.Persist.Path)
// Stash the key material in bolt
token, err := km.RetrievalToken()
if err != nil {
c.UI.Error(fmt.Sprintf("Error getting persistent key: %s", err))
return 1
}
if err := ps.StoreRetrievalToken(token); err != nil {
c.UI.Error(fmt.Sprintf("Error setting key in persistent cache: %v", err))
return 1
}
defer ps.Close()
leaseCache.SetPersistentStorage(ps)
}
}
var inmemSink sink.Sink
if config.Cache.UseAutoAuthToken {
cacheLogger.Debug("auto-auth token is allowed to be used; configuring inmem sink")
@ -585,6 +738,11 @@ func (c *AgentCommand) Run(args []string) int {
select {
case <-c.ShutdownCh:
c.UI.Output("==> Vault agent shutdown triggered")
// Let the lease cache know this is a shutdown; no need to evict
// everything
if leaseCache != nil {
leaseCache.SetShuttingDown(true)
}
return nil
case <-ctx.Done():
return nil
@ -604,6 +762,7 @@ func (c *AgentCommand) Run(args []string) int {
MaxBackoff: config.AutoAuth.Method.MaxBackoff,
EnableReauthOnNewCredentials: config.AutoAuth.EnableReauthOnNewCredentials,
EnableTemplateTokenCh: enableTokenCh,
Token: previousToken,
})
ss := sink.NewSinkServer(&sink.SinkServerConfig{
@ -624,6 +783,11 @@ func (c *AgentCommand) Run(args []string) int {
g.Add(func() error {
return ah.Run(ctx, method)
}, func(error) {
// Let the lease cache know this is a shutdown; no need to evict
// everything
if leaseCache != nil {
leaseCache.SetShuttingDown(true)
}
cancelFunc()
})
@ -650,12 +814,22 @@ func (c *AgentCommand) Run(args []string) int {
return err
}, func(error) {
// Let the lease cache know this is a shutdown; no need to evict
// everything
if leaseCache != nil {
leaseCache.SetShuttingDown(true)
}
cancelFunc()
})
g.Add(func() error {
return ts.Run(ctx, ah.TemplateTokenCh, config.Templates)
}, func(error) {
// Let the lease cache know this is a shutdown; no need to evict
// everything
if leaseCache != nil {
leaseCache.SetShuttingDown(true)
}
cancelFunc()
ts.Stop()
})
@ -793,3 +967,16 @@ func (c *AgentCommand) removePidFile(pidPath string) error {
}
return os.Remove(pidPath)
}
// getServiceAccountJWT reads the service account JWT from `tokenFile`. If the
// path is empty, the default Kubernetes service account token path is used.
func getServiceAccountJWT(tokenFile string) (string, error) {
if len(tokenFile) == 0 {
tokenFile = "/var/run/secrets/kubernetes.io/serviceaccount/token"
}
token, err := ioutil.ReadFile(tokenFile)
if err != nil {
return "", err
}
return strings.TrimSpace(string(token)), nil
}

command/agent/cache/cacheboltdb/bolt.go vendored Normal file

@ -0,0 +1,303 @@
package cacheboltdb
import (
"context"
"fmt"
"os"
"path/filepath"
"time"
"github.com/golang/protobuf/proto"
"github.com/hashicorp/go-hclog"
wrapping "github.com/hashicorp/go-kms-wrapping"
"github.com/hashicorp/go-multierror"
bolt "go.etcd.io/bbolt"
)
const (
// Keep track of schema version for future migrations
storageVersionKey = "version"
storageVersion = "1"
// DatabaseFileName - filename for the persistent cache file
DatabaseFileName = "vault-agent-cache.db"
// metaBucketName - naming the meta bucket that holds the version and
// bootstrapping keys
metaBucketName = "meta"
// SecretLeaseType - Bucket/type for leases with secret info
SecretLeaseType = "secret-lease"
// AuthLeaseType - Bucket/type for leases with auth info
AuthLeaseType = "auth-lease"
// TokenType - Bucket/type for auto-auth tokens
TokenType = "token"
// AutoAuthToken - key for the latest auto-auth token
AutoAuthToken = "auto-auth-token"
// RetrievalTokenMaterial is the actual key or token in the key bucket
RetrievalTokenMaterial = "retrieval-token-material"
)
// BoltStorage is a persistent cache using a bolt db. Items are organized with
// the version and bootstrapping items in the "meta" bucket, and tokens, auth
// leases, and secret leases in their own buckets.
type BoltStorage struct {
db *bolt.DB
logger hclog.Logger
wrapper wrapping.Wrapper
aad string
}
// BoltStorageConfig is the collection of input parameters for setting up bolt
// storage
type BoltStorageConfig struct {
Path string
Logger hclog.Logger
Wrapper wrapping.Wrapper
AAD string
}
// NewBoltStorage opens a new bolt db file under the specified directory and
// returns it. If the db already exists, any missing buckets are created.
func NewBoltStorage(config *BoltStorageConfig) (*BoltStorage, error) {
dbPath := filepath.Join(config.Path, DatabaseFileName)
db, err := bolt.Open(dbPath, 0600, &bolt.Options{Timeout: 1 * time.Second})
if err != nil {
return nil, err
}
err = db.Update(func(tx *bolt.Tx) error {
return createBoltSchema(tx)
})
if err != nil {
return nil, err
}
bs := &BoltStorage{
db: db,
logger: config.Logger,
wrapper: config.Wrapper,
aad: config.AAD,
}
return bs, nil
}
func createBoltSchema(tx *bolt.Tx) error {
// create the meta bucket at the top level
meta, err := tx.CreateBucketIfNotExists([]byte(metaBucketName))
if err != nil {
return fmt.Errorf("failed to create bucket %s: %w", metaBucketName, err)
}
// check and set file version in the meta bucket
version := meta.Get([]byte(storageVersionKey))
switch {
case version == nil:
err = meta.Put([]byte(storageVersionKey), []byte(storageVersion))
if err != nil {
return fmt.Errorf("failed to set storage version: %w", err)
}
case string(version) != storageVersion:
return fmt.Errorf("storage migration from %s to %s not implemented", string(version), storageVersion)
}
// create the buckets for tokens and leases
_, err = tx.CreateBucketIfNotExists([]byte(TokenType))
if err != nil {
return fmt.Errorf("failed to create token bucket: %w", err)
}
_, err = tx.CreateBucketIfNotExists([]byte(AuthLeaseType))
if err != nil {
return fmt.Errorf("failed to create auth lease bucket: %w", err)
}
_, err = tx.CreateBucketIfNotExists([]byte(SecretLeaseType))
if err != nil {
return fmt.Errorf("failed to create secret lease bucket: %w", err)
}
return nil
}
// Set an index (token or lease) in bolt storage
func (b *BoltStorage) Set(ctx context.Context, id string, plaintext []byte, indexType string) error {
blob, err := b.wrapper.Encrypt(ctx, plaintext, []byte(b.aad))
if err != nil {
return fmt.Errorf("error encrypting %s index: %w", indexType, err)
}
protoBlob, err := proto.Marshal(blob)
if err != nil {
return err
}
return b.db.Update(func(tx *bolt.Tx) error {
s := tx.Bucket([]byte(indexType))
if s == nil {
return fmt.Errorf("bucket %q not found", indexType)
}
// If this is an auto-auth token, also stash it in the meta bucket for
// easy retrieval upon restore
if indexType == TokenType {
meta := tx.Bucket([]byte(metaBucketName))
if err := meta.Put([]byte(AutoAuthToken), protoBlob); err != nil {
return fmt.Errorf("failed to set latest auto-auth token: %w", err)
}
}
return s.Put([]byte(id), protoBlob)
})
}
func getBucketIDs(b *bolt.Bucket) ([][]byte, error) {
ids := [][]byte{}
err := b.ForEach(func(k, v []byte) error {
ids = append(ids, k)
return nil
})
return ids, err
}
// Delete an index (token or lease) by id from bolt storage
func (b *BoltStorage) Delete(id string) error {
return b.db.Update(func(tx *bolt.Tx) error {
// Since Delete returns a nil error if the key doesn't exist, just call
// delete in all three index buckets without checking existence first
if err := tx.Bucket([]byte(TokenType)).Delete([]byte(id)); err != nil {
return fmt.Errorf("failed to delete %q from token bucket: %w", id, err)
}
if err := tx.Bucket([]byte(AuthLeaseType)).Delete([]byte(id)); err != nil {
return fmt.Errorf("failed to delete %q from auth lease bucket: %w", id, err)
}
if err := tx.Bucket([]byte(SecretLeaseType)).Delete([]byte(id)); err != nil {
return fmt.Errorf("failed to delete %q from secret lease bucket: %w", id, err)
}
b.logger.Trace("deleted index from bolt db", "id", id)
return nil
})
}
func (b *BoltStorage) decrypt(ctx context.Context, ciphertext []byte) ([]byte, error) {
var blob wrapping.EncryptedBlobInfo
if err := proto.Unmarshal(ciphertext, &blob); err != nil {
return nil, err
}
return b.wrapper.Decrypt(ctx, &blob, []byte(b.aad))
}
// GetByType returns a list of stored items of the specified type
func (b *BoltStorage) GetByType(ctx context.Context, indexType string) ([][]byte, error) {
var returnBytes [][]byte
err := b.db.View(func(tx *bolt.Tx) error {
var errors *multierror.Error
tx.Bucket([]byte(indexType)).ForEach(func(id, ciphertext []byte) error {
plaintext, err := b.decrypt(ctx, ciphertext)
if err != nil {
errors = multierror.Append(errors, fmt.Errorf("error decrypting index id %s: %w", id, err))
return nil
}
returnBytes = append(returnBytes, plaintext)
return nil
})
return errors.ErrorOrNil()
})
return returnBytes, err
}
// GetAutoAuthToken retrieves the latest auto-auth token, and returns nil if none
// exists yet
func (b *BoltStorage) GetAutoAuthToken(ctx context.Context) ([]byte, error) {
var encryptedToken []byte
err := b.db.View(func(tx *bolt.Tx) error {
meta := tx.Bucket([]byte(metaBucketName))
if meta == nil {
return fmt.Errorf("bucket %q not found", metaBucketName)
}
encryptedToken = meta.Get([]byte(AutoAuthToken))
return nil
})
if err != nil {
return nil, err
}
if encryptedToken == nil {
return nil, nil
}
plaintext, err := b.decrypt(ctx, encryptedToken)
if err != nil {
return nil, fmt.Errorf("failed to decrypt auto-auth token: %w", err)
}
return plaintext, nil
}
// GetRetrievalToken retrieves the plaintext token from the meta bucket that the
// key manager uses to retrieve the encryption key, or nil if none is set
func (b *BoltStorage) GetRetrievalToken() ([]byte, error) {
var token []byte
err := b.db.View(func(tx *bolt.Tx) error {
keyBucket := tx.Bucket([]byte(metaBucketName))
if keyBucket == nil {
return fmt.Errorf("bucket %q not found", metaBucketName)
}
token = keyBucket.Get([]byte(RetrievalTokenMaterial))
return nil
})
if err != nil {
return nil, err
}
return token, err
}
// StoreRetrievalToken stores the plaintext token material in the meta bucket
func (b *BoltStorage) StoreRetrievalToken(token []byte) error {
return b.db.Update(func(tx *bolt.Tx) error {
bucket := tx.Bucket([]byte(metaBucketName))
if bucket == nil {
return fmt.Errorf("bucket %q not found", metaBucketName)
}
return bucket.Put([]byte(RetrievalTokenMaterial), token)
})
}
// Close the boltdb
func (b *BoltStorage) Close() error {
b.logger.Trace("closing bolt db", "path", b.db.Path())
return b.db.Close()
}
// Clear the boltdb by deleting all the token and lease buckets and recreating
// the schema/layout
func (b *BoltStorage) Clear() error {
return b.db.Update(func(tx *bolt.Tx) error {
for _, name := range []string{AuthLeaseType, SecretLeaseType, TokenType} {
b.logger.Trace("deleting bolt bucket", "name", name)
if err := tx.DeleteBucket([]byte(name)); err != nil {
return err
}
}
return createBoltSchema(tx)
})
}
// DBFileExists checks whether the vault agent cache file exists in the directory `path`
func DBFileExists(path string) (bool, error) {
checkFile, err := os.OpenFile(filepath.Join(path, DatabaseFileName), os.O_RDWR, 0600)
defer checkFile.Close()
switch {
case err == nil:
return true, nil
case os.IsNotExist(err):
return false, nil
default:
return false, fmt.Errorf("failed to check if bolt file exists at path %s: %w", path, err)
}
}
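To make the storage API above concrete, here is a short usage sketch. It assumes the keymanager package added later in this change; the directory argument and the helper name storeAndList are illustrative. The tests that follow exercise the same calls in more detail.

```go
package example

import (
	"context"

	hclog "github.com/hashicorp/go-hclog"
	"github.com/hashicorp/vault/command/agent/cache/cacheboltdb"
	"github.com/hashicorp/vault/command/agent/cache/keymanager"
)

// storeAndList is a hypothetical helper: it writes one encrypted entry into the
// secret-lease bucket and one into the token bucket, then reads the former back.
func storeAndList(ctx context.Context, dir string) ([][]byte, error) {
	km, err := keymanager.NewPassthroughKeyManager(nil)
	if err != nil {
		return nil, err
	}
	ps, err := cacheboltdb.NewBoltStorage(&cacheboltdb.BoltStorageConfig{
		Path:    dir, // directory; the file inside it is cacheboltdb.DatabaseFileName
		Logger:  hclog.Default(),
		Wrapper: km.Wrapper(),
	})
	if err != nil {
		return nil, err
	}
	defer ps.Close()

	// Entries are encrypted with the wrapper before they are written to their bucket.
	if err := ps.Set(ctx, "lease-id-1", []byte(`{"example":"serialized index"}`), cacheboltdb.SecretLeaseType); err != nil {
		return nil, err
	}
	// Token entries are additionally mirrored into the meta bucket so the latest
	// auto-auth token can be found quickly on restore (see Set above).
	if err := ps.Set(ctx, "token-id-1", []byte(`{"example":"token index"}`), cacheboltdb.TokenType); err != nil {
		return nil, err
	}

	// GetByType decrypts and returns every entry in the requested bucket.
	return ps.GetByType(ctx, cacheboltdb.SecretLeaseType)
}
```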


@ -0,0 +1,263 @@
package cacheboltdb
import (
"context"
"io/ioutil"
"os"
"path"
"testing"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/command/agent/cache/keymanager"
"github.com/ory/dockertest/v3/docker/pkg/ioutils"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func getTestKeyManager(t *testing.T) keymanager.KeyManager {
t.Helper()
km, err := keymanager.NewPassthroughKeyManager(nil)
require.NoError(t, err)
return km
}
func TestBolt_SetGet(t *testing.T) {
ctx := context.Background()
path, err := ioutils.TempDir("", "bolt-test")
require.NoError(t, err)
defer os.RemoveAll(path)
b, err := NewBoltStorage(&BoltStorageConfig{
Path: path,
Logger: hclog.Default(),
Wrapper: getTestKeyManager(t).Wrapper(),
})
require.NoError(t, err)
secrets, err := b.GetByType(ctx, SecretLeaseType)
assert.NoError(t, err)
require.Len(t, secrets, 0)
err = b.Set(ctx, "test1", []byte("hello"), SecretLeaseType)
assert.NoError(t, err)
secrets, err = b.GetByType(ctx, SecretLeaseType)
assert.NoError(t, err)
require.Len(t, secrets, 1)
assert.Equal(t, []byte("hello"), secrets[0])
}
func TestBoltDelete(t *testing.T) {
ctx := context.Background()
path, err := ioutils.TempDir("", "bolt-test")
require.NoError(t, err)
defer os.RemoveAll(path)
b, err := NewBoltStorage(&BoltStorageConfig{
Path: path,
Logger: hclog.Default(),
Wrapper: getTestKeyManager(t).Wrapper(),
})
require.NoError(t, err)
err = b.Set(ctx, "secret-test1", []byte("hello1"), SecretLeaseType)
require.NoError(t, err)
err = b.Set(ctx, "secret-test2", []byte("hello2"), SecretLeaseType)
require.NoError(t, err)
secrets, err := b.GetByType(ctx, SecretLeaseType)
require.NoError(t, err)
assert.Len(t, secrets, 2)
assert.ElementsMatch(t, [][]byte{[]byte("hello1"), []byte("hello2")}, secrets)
err = b.Delete("secret-test1")
require.NoError(t, err)
secrets, err = b.GetByType(ctx, SecretLeaseType)
require.NoError(t, err)
require.Len(t, secrets, 1)
assert.Equal(t, []byte("hello2"), secrets[0])
}
func TestBoltClear(t *testing.T) {
ctx := context.Background()
path, err := ioutils.TempDir("", "bolt-test")
require.NoError(t, err)
defer os.RemoveAll(path)
b, err := NewBoltStorage(&BoltStorageConfig{
Path: path,
Logger: hclog.Default(),
Wrapper: getTestKeyManager(t).Wrapper(),
})
require.NoError(t, err)
// Populate the bolt db
err = b.Set(ctx, "secret-test1", []byte("hello"), SecretLeaseType)
require.NoError(t, err)
secrets, err := b.GetByType(ctx, SecretLeaseType)
require.NoError(t, err)
require.Len(t, secrets, 1)
assert.Equal(t, []byte("hello"), secrets[0])
err = b.Set(ctx, "auth-test1", []byte("hello"), AuthLeaseType)
require.NoError(t, err)
auths, err := b.GetByType(ctx, AuthLeaseType)
require.NoError(t, err)
require.Len(t, auths, 1)
assert.Equal(t, []byte("hello"), auths[0])
err = b.Set(ctx, "token-test1", []byte("hello"), TokenType)
require.NoError(t, err)
tokens, err := b.GetByType(ctx, TokenType)
require.NoError(t, err)
require.Len(t, tokens, 1)
assert.Equal(t, []byte("hello"), tokens[0])
// Clear the bolt db, and check that it's indeed clear
err = b.Clear()
require.NoError(t, err)
secrets, err = b.GetByType(ctx, SecretLeaseType)
require.NoError(t, err)
assert.Len(t, secrets, 0)
auths, err = b.GetByType(ctx, AuthLeaseType)
require.NoError(t, err)
assert.Len(t, auths, 0)
tokens, err = b.GetByType(ctx, TokenType)
require.NoError(t, err)
assert.Len(t, tokens, 0)
}
func TestBoltSetAutoAuthToken(t *testing.T) {
ctx := context.Background()
path, err := ioutils.TempDir("", "bolt-test")
require.NoError(t, err)
defer os.RemoveAll(path)
b, err := NewBoltStorage(&BoltStorageConfig{
Path: path,
Logger: hclog.Default(),
Wrapper: getTestKeyManager(t).Wrapper(),
})
require.NoError(t, err)
token, err := b.GetAutoAuthToken(ctx)
assert.NoError(t, err)
assert.Nil(t, token)
// set first token
err = b.Set(ctx, "token-test1", []byte("hello 1"), TokenType)
require.NoError(t, err)
secrets, err := b.GetByType(ctx, TokenType)
require.NoError(t, err)
require.Len(t, secrets, 1)
assert.Equal(t, []byte("hello 1"), secrets[0])
token, err = b.GetAutoAuthToken(ctx)
assert.NoError(t, err)
assert.Equal(t, []byte("hello 1"), token)
// set second token
err = b.Set(ctx, "token-test2", []byte("hello 2"), TokenType)
require.NoError(t, err)
secrets, err = b.GetByType(ctx, TokenType)
require.NoError(t, err)
require.Len(t, secrets, 2)
assert.ElementsMatch(t, [][]byte{[]byte("hello 1"), []byte("hello 2")}, secrets)
token, err = b.GetAutoAuthToken(ctx)
assert.NoError(t, err)
assert.Equal(t, []byte("hello 2"), token)
}
func TestDBFileExists(t *testing.T) {
testCases := []struct {
name string
mkDir bool
createFile bool
expectExist bool
}{
{
name: "all exists",
mkDir: true,
createFile: true,
expectExist: true,
},
{
name: "dir exist, file missing",
mkDir: true,
createFile: false,
expectExist: false,
},
{
name: "all missing",
mkDir: false,
createFile: false,
expectExist: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
var tmpPath string
var err error
if tc.mkDir {
tmpPath, err = ioutil.TempDir("", "test-db-path")
require.NoError(t, err)
}
if tc.createFile {
err = ioutil.WriteFile(path.Join(tmpPath, DatabaseFileName), []byte("test-db-path"), 0600)
require.NoError(t, err)
}
exists, err := DBFileExists(tmpPath)
assert.NoError(t, err)
assert.Equal(t, tc.expectExist, exists)
})
}
}
func Test_SetGetRetrievalToken(t *testing.T) {
testCases := []struct {
name string
tokenToSet []byte
expectedToken []byte
}{
{
name: "normal set and get",
tokenToSet: []byte("test token"),
expectedToken: []byte("test token"),
},
{
name: "no token set",
tokenToSet: nil,
expectedToken: nil,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
path, err := ioutils.TempDir("", "bolt-test")
require.NoError(t, err)
defer os.RemoveAll(path)
b, err := NewBoltStorage(&BoltStorageConfig{
Path: path,
Logger: hclog.Default(),
Wrapper: getTestKeyManager(t).Wrapper(),
})
require.NoError(t, err)
defer b.Close()
if tc.tokenToSet != nil {
err := b.StoreRetrievalToken(tc.tokenToSet)
require.NoError(t, err)
}
gotKey, err := b.GetRetrievalToken()
assert.NoError(t, err)
assert.Equal(t, tc.expectedToken, gotKey)
})
}
}


@ -1,6 +1,11 @@
package cachememdb
import (
"context"
"encoding/json"
"net/http"
"time"
)
// Index holds the response to be cached along with multiple other values that
// serve as pointers to refer back to this index.
@ -48,6 +53,21 @@ type Index struct {
// goroutine that manages the renewal of the secret belonging to the
// response in this index.
RenewCtxInfo *ContextInfo
// RequestMethod is the HTTP method of the request
RequestMethod string
// RequestToken is the token used in the request
RequestToken string
// RequestHeader is the header used in the request
RequestHeader http.Header
// LastRenewed is the timestamp of last renewal
LastRenewed time.Time
// Type is the index type (token, auth-lease, secret-lease)
Type string
}
type IndexName uint32
@ -106,3 +126,25 @@ func NewContextInfo(ctx context.Context) *ContextInfo {
ctxInfo.DoneCh = make(chan struct{})
return ctxInfo
}
// Serialize returns a JSON-marshaled copy of the Index, without the RenewCtxInfo
func (i Index) Serialize() ([]byte, error) {
i.RenewCtxInfo = nil
indexBytes, err := json.Marshal(i)
if err != nil {
return nil, err
}
return indexBytes, nil
}
// Deserialize converts json bytes to an Index object
// Note: RenewCtxInfo will need to be reconstructed elsewhere.
func Deserialize(indexBytes []byte) (*Index, error) {
index := new(Index)
if err := json.Unmarshal(indexBytes, index); err != nil {
return nil, err
}
return index, nil
}
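As the Deserialize note says, RenewCtxInfo is not persisted, so anything read back from storage needs a fresh renewal context before it can be used; the lease cache's Restore path later in this change does that via createCtxInfo. A minimal sketch, with restoreIndex as a hypothetical helper:

```go
package example

import (
	"context"

	"github.com/hashicorp/vault/command/agent/cache/cachememdb"
)

// restoreIndex is a hypothetical helper, not part of this change: it turns a
// persisted entry back into an Index and attaches a fresh renewal context.
func restoreIndex(baseCtx context.Context, raw []byte) (*cachememdb.Index, error) {
	idx, err := cachememdb.Deserialize(raw)
	if err != nil {
		return nil, err
	}
	// RenewCtxInfo is deliberately dropped by Serialize, so rebuild it here.
	idx.RenewCtxInfo = cachememdb.NewContextInfo(baseCtx)
	return idx, nil
}
```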


@ -0,0 +1,43 @@
package cachememdb
import (
"context"
"net/http"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestSerializeDeserialize(t *testing.T) {
testIndex := &Index{
ID: "testid",
Token: "testtoken",
TokenParent: "parent token",
TokenAccessor: "test accessor",
Namespace: "test namespace",
RequestPath: "/test/path",
Lease: "lease id",
LeaseToken: "lease token id",
Response: []byte(`{"something": "here"}`),
RenewCtxInfo: NewContextInfo(context.Background()),
RequestMethod: "GET",
RequestToken: "request token",
RequestHeader: http.Header{
"X-Test": []string{"vault", "agent"},
},
LastRenewed: time.Now().UTC(),
}
indexBytes, err := testIndex.Serialize()
require.NoError(t, err)
assert.True(t, len(indexBytes) > 0)
assert.NotNil(t, testIndex.RenewCtxInfo, "Serialize should not modify original Index object")
restoredIndex, err := Deserialize(indexBytes)
require.NoError(t, err)
testIndex.RenewCtxInfo = nil
assert.Equal(t, testIndex, restoredIndex, "They should be equal without RenewCtxInfo set on the original")
}


@ -1,18 +0,0 @@
package crypto
import (
"context"
)
const (
KeyID = "root"
)
type KeyManager interface {
GetKey() []byte
GetPersistentKey() ([]byte, error)
Renewable() bool
Renewer(context.Context) error
Encrypt(context.Context, []byte, []byte) ([]byte, error)
Decrypt(context.Context, []byte, []byte) ([]byte, error)
}


@ -1,97 +0,0 @@
package crypto
import (
"context"
"crypto/rand"
"fmt"
wrapping "github.com/hashicorp/go-kms-wrapping"
"github.com/hashicorp/go-kms-wrapping/wrappers/aead"
)
var _ KeyManager = (*KubeEncryptionKey)(nil)
type KubeEncryptionKey struct {
renewable bool
wrapper *aead.Wrapper
}
// NewK8s returns a new instance of the Kube encryption key. Kubernetes
// encryption keys aren't renewable.
func NewK8s(existingKey []byte) (*KubeEncryptionKey, error) {
k := &KubeEncryptionKey{
renewable: false,
wrapper: aead.NewWrapper(nil),
}
k.wrapper.SetConfig(map[string]string{"key_id": KeyID})
var rootKey []byte = nil
if len(existingKey) != 0 {
if len(existingKey) != 32 {
return k, fmt.Errorf("invalid key size, should be 32, got %d", len(existingKey))
}
rootKey = existingKey
}
if rootKey == nil {
newKey := make([]byte, 32)
_, err := rand.Read(newKey)
if err != nil {
return k, err
}
rootKey = newKey
}
if err := k.wrapper.SetAESGCMKeyBytes(rootKey); err != nil {
return k, err
}
return k, nil
}
// GetKey returns the encryption key in a format optimized for storage.
// In k8s we store the key as is, so just return the key stored.
func (k *KubeEncryptionKey) GetKey() []byte {
return k.wrapper.GetKeyBytes()
}
// GetPersistentKey returns the key which should be stored in the persisent
// cache. In k8s we store the key as is, so just return the key stored.
func (k *KubeEncryptionKey) GetPersistentKey() ([]byte, error) {
return k.wrapper.GetKeyBytes(), nil
}
// Renewable lets the caller know if this encryption key type is
// renewable. In Kubernetes the key isn't renewable.
func (k *KubeEncryptionKey) Renewable() bool {
return k.renewable
}
// Renewer is used when the encryption key type is renewable. Since Kubernetes
// keys aren't renewable, returning nothing.
func (k *KubeEncryptionKey) Renewer(ctx context.Context) error {
return nil
}
// Encrypt takes plaintext values and encrypts them using the store key and additional
// data. For Kubernetes the AAD should be the service account JWT.
func (k *KubeEncryptionKey) Encrypt(ctx context.Context, plaintext, aad []byte) ([]byte, error) {
blob, err := k.wrapper.Encrypt(ctx, plaintext, aad)
if err != nil {
return nil, err
}
return blob.Ciphertext, nil
}
// Decrypt takes ciphertext and AAD values and returns the decrypted value. For Kubernetes the AAD
// should be the service account JWT.
func (k *KubeEncryptionKey) Decrypt(ctx context.Context, ciphertext, aad []byte) ([]byte, error) {
blob := &wrapping.EncryptedBlobInfo{
Ciphertext: ciphertext,
KeyInfo: &wrapping.KeyInfo{
KeyID: KeyID,
},
}
return k.wrapper.Decrypt(ctx, blob, aad)
}


@ -1,155 +0,0 @@
package crypto
import (
"fmt"
"math/rand"
"testing"
)
func TestCrypto_KubernetesNewKey(t *testing.T) {
k8sKey, err := NewK8s([]byte{})
if err != nil {
t.Fatalf(fmt.Sprintf("unexpected error: %s", err))
}
key := k8sKey.GetKey()
if key == nil {
t.Fatalf(fmt.Sprintf("key is nil, it shouldn't be: %s", key))
}
persistentKey, _ := k8sKey.GetPersistentKey()
if persistentKey == nil {
t.Fatalf(fmt.Sprintf("key is nil, it shouldn't be: %s", persistentKey))
}
if string(key) != string(persistentKey) {
t.Fatalf("keys don't match, they should: key: %s, persistentKey: %s", key, persistentKey)
}
plaintextInput := []byte("test")
aad := []byte("kubernetes")
ciphertext, err := k8sKey.Encrypt(nil, plaintextInput, aad)
if err != nil {
t.Fatalf(err.Error())
}
if ciphertext == nil {
t.Fatalf("ciphertext nil, it shouldn't be")
}
plaintext, err := k8sKey.Decrypt(nil, ciphertext, aad)
if err != nil {
t.Fatalf(err.Error())
}
if string(plaintext) != string(plaintextInput) {
t.Fatalf("expected %s, got %s", plaintextInput, plaintext)
}
}
func TestCrypto_KubernetesExistingKey(t *testing.T) {
rootKey := make([]byte, 32)
n, err := rand.Read(rootKey)
if err != nil {
t.Fatal(err)
}
if n != 32 {
t.Fatal(n)
}
k8sKey, err := NewK8s(rootKey)
if err != nil {
t.Fatalf(fmt.Sprintf("unexpected error: %s", err))
}
key := k8sKey.GetKey()
if key == nil {
t.Fatalf(fmt.Sprintf("key is nil, it shouldn't be: %s", key))
}
if string(key) != string(rootKey) {
t.Fatalf(fmt.Sprintf("expected keys to be the same, they weren't: expected: %s, got: %s", rootKey, key))
}
persistentKey, _ := k8sKey.GetPersistentKey()
if persistentKey == nil {
t.Fatalf("key is nil, it shouldn't be")
}
if string(persistentKey) != string(rootKey) {
t.Fatalf(fmt.Sprintf("expected keys to be the same, they weren't: expected: %s, got: %s", rootKey, persistentKey))
}
if string(key) != string(persistentKey) {
t.Fatalf(fmt.Sprintf("expected keys to be the same, they weren't: %s %s", rootKey, persistentKey))
}
plaintextInput := []byte("test")
aad := []byte("kubernetes")
ciphertext, err := k8sKey.Encrypt(nil, plaintextInput, aad)
if err != nil {
t.Fatalf(err.Error())
}
if ciphertext == nil {
t.Fatalf("ciphertext nil, it shouldn't be")
}
plaintext, err := k8sKey.Decrypt(nil, ciphertext, aad)
if err != nil {
t.Fatalf(err.Error())
}
if string(plaintext) != string(plaintextInput) {
t.Fatalf("expected %s, got %s", plaintextInput, plaintext)
}
}
func TestCrypto_KubernetesPassGeneratedKey(t *testing.T) {
k8sFirstKey, err := NewK8s([]byte{})
if err != nil {
t.Fatalf(fmt.Sprintf("unexpected error: %s", err))
}
firstPersistentKey := k8sFirstKey.GetKey()
if firstPersistentKey == nil {
t.Fatalf(fmt.Sprintf("key is nil, it shouldn't be: %s", firstPersistentKey))
}
plaintextInput := []byte("test")
aad := []byte("kubernetes")
ciphertext, err := k8sFirstKey.Encrypt(nil, plaintextInput, aad)
if err != nil {
t.Fatalf(err.Error())
}
if ciphertext == nil {
t.Fatalf("ciphertext nil, it shouldn't be")
}
k8sLoadedKey, err := NewK8s(firstPersistentKey)
if err != nil {
t.Fatalf(fmt.Sprintf("unexpected error: %s", err))
}
loadedKey, _ := k8sLoadedKey.GetPersistentKey()
if loadedKey == nil {
t.Fatalf(fmt.Sprintf("key is nil, it shouldn't be: %s", loadedKey))
}
if string(loadedKey) != string(firstPersistentKey) {
t.Fatalf(fmt.Sprintf("keys do not match"))
}
plaintext, err := k8sLoadedKey.Decrypt(nil, ciphertext, aad)
if err != nil {
t.Fatalf(err.Error())
}
if string(plaintext) != string(plaintextInput) {
t.Fatalf("expected %s, got %s", plaintextInput, plaintext)
}
}


@ -0,0 +1,16 @@
package keymanager
import wrapping "github.com/hashicorp/go-kms-wrapping"
const (
KeyID = "root"
)
type KeyManager interface {
// Returns a wrapping.Wrapper which can be used to perform key-related operations.
Wrapper() wrapping.Wrapper
// RetrievalToken returns material that can be used to recover the encryption
// key. Depending on the implementation, the token may be the encryption key
// itself or a token/identifier that can be exchanged for the key.
RetrievalToken() ([]byte, error)
}


@ -0,0 +1,68 @@
package keymanager
import (
"crypto/rand"
"fmt"
wrapping "github.com/hashicorp/go-kms-wrapping"
"github.com/hashicorp/go-kms-wrapping/wrappers/aead"
)
var _ KeyManager = (*PassthroughKeyManager)(nil)
type PassthroughKeyManager struct {
wrapper *aead.Wrapper
}
// NewPassthroughKeyManager returns a new passthrough key manager. If a key is
// provided, it is used as the encryption key for the wrapper; otherwise a new
// key is generated.
func NewPassthroughKeyManager(key []byte) (*PassthroughKeyManager, error) {
var rootKey []byte = nil
switch len(key) {
case 0:
newKey := make([]byte, 32)
_, err := rand.Read(newKey)
if err != nil {
return nil, err
}
rootKey = newKey
case 32:
rootKey = key
default:
return nil, fmt.Errorf("invalid key size, should be 32, got %d", len(key))
}
wrapper := aead.NewWrapper(nil)
if _, err := wrapper.SetConfig(map[string]string{"key_id": KeyID}); err != nil {
return nil, err
}
if err := wrapper.SetAESGCMKeyBytes(rootKey); err != nil {
return nil, err
}
k := &PassthroughKeyManager{
wrapper: wrapper,
}
return k, nil
}
// Wrapper returns the manager's wrapper for key operations.
func (w *PassthroughKeyManager) Wrapper() wrapping.Wrapper {
return w.wrapper
}
// RetrievalToken returns the key that was used on the wrapper since this key
// manager is simply a passthrough and does not provide a mechanism to abstract
// this key.
func (w *PassthroughKeyManager) RetrievalToken() ([]byte, error) {
if w.wrapper == nil {
return nil, fmt.Errorf("unable to get wrapper for token retrieval")
}
return w.wrapper.GetKeyBytes(), nil
}
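Since this key manager is a passthrough, the retrieval token is the AES-GCM key itself, which is why the agent can stash it in the bolt meta bucket on the first run and feed it back into NewPassthroughKeyManager after a restart. A hedged sketch of that round trip (the helper name is hypothetical):

```go
package example

import (
	"bytes"
	"fmt"

	"github.com/hashicorp/vault/command/agent/cache/keymanager"
)

// roundTrip illustrates that the token handed out on the first run rebuilds an
// equivalent wrapper on a later run.
func roundTrip() error {
	// First run: no key material yet, so a new 32-byte key is generated.
	first, err := keymanager.NewPassthroughKeyManager(nil)
	if err != nil {
		return err
	}
	token, err := first.RetrievalToken() // for the passthrough manager this is the key itself
	if err != nil {
		return err
	}

	// Later run: rebuild the key manager from the persisted token.
	second, err := keymanager.NewPassthroughKeyManager(token)
	if err != nil {
		return err
	}
	restored, err := second.RetrievalToken()
	if err != nil {
		return err
	}
	if !bytes.Equal(token, restored) {
		return fmt.Errorf("expected identical key material")
	}
	return nil
}
```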


@ -0,0 +1,56 @@
package keymanager
import (
"bytes"
"testing"
"github.com/stretchr/testify/require"
)
func TestKeyManager_PassthrougKeyManager(t *testing.T) {
tests := []struct {
name string
key []byte
wantErr bool
}{
{
"new key",
nil,
false,
},
{
"existing valid key",
[]byte("e679e2f3d8d0e489d408bc617c6890d6"),
false,
},
{
"invalid key length",
[]byte("foobar"),
true,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
m, err := NewPassthroughKeyManager(tc.key)
if tc.wantErr {
require.Error(t, err)
return
}
require.NoError(t, err)
if w := m.Wrapper(); w == nil {
t.Fatalf("expected non-nil wrapper from the key manager")
}
token, err := m.RetrievalToken()
if err != nil {
t.Fatalf("unable to retrieve token: %s", err)
}
if len(tc.key) != 0 && !bytes.Equal(tc.key, token) {
t.Fatalf("expected key bytes: %x, got: %x", tc.key, token)
}
})
}
}


@ -11,6 +11,7 @@ import (
"io" "io"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"net/url"
"strings" "strings"
"sync" "sync"
"time" "time"
@ -18,6 +19,7 @@ import (
"github.com/hashicorp/errwrap" "github.com/hashicorp/errwrap"
hclog "github.com/hashicorp/go-hclog" hclog "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/api" "github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/command/agent/cache/cacheboltdb"
cachememdb "github.com/hashicorp/vault/command/agent/cache/cachememdb" cachememdb "github.com/hashicorp/vault/command/agent/cache/cachememdb"
"github.com/hashicorp/vault/helper/namespace" "github.com/hashicorp/vault/helper/namespace"
nshelper "github.com/hashicorp/vault/helper/namespace" nshelper "github.com/hashicorp/vault/helper/namespace"
@ -83,6 +85,13 @@ type LeaseCache struct {
// inflightCache keeps track of inflight requests // inflightCache keeps track of inflight requests
inflightCache *gocache.Cache inflightCache *gocache.Cache
// ps is the persistent storage for tokens and leases
ps *cacheboltdb.BoltStorage
// shuttingDown is used to determine if cache needs to be evicted or not
// when the context is cancelled
shuttingDown atomic.Bool
}
// LeaseCacheConfig is the configuration for initializing a new
@ -92,6 +101,7 @@ type LeaseCacheConfig struct {
BaseContext context.Context
Proxier Proxier
Logger hclog.Logger
Storage *cacheboltdb.BoltStorage
}
type inflightRequest struct {
@ -141,9 +151,21 @@ func NewLeaseCache(conf *LeaseCacheConfig) (*LeaseCache, error) {
l: &sync.RWMutex{},
idLocks: locksutil.CreateLocks(),
inflightCache: gocache.New(gocache.NoExpiration, gocache.NoExpiration),
ps: conf.Storage,
}, nil
}
// SetShuttingDown is a setter for the shuttingDown field
func (c *LeaseCache) SetShuttingDown(in bool) {
c.shuttingDown.Store(in)
}
// SetPersistentStorage is a setter for the persistent storage field in
// LeaseCache
func (c *LeaseCache) SetPersistentStorage(storageIn *cacheboltdb.BoltStorage) {
c.ps = storageIn
}
// checkCacheForRequest checks the cache for a particular request based on its
// computed ID. It returns a non-nil *SendResponse if an entry is found.
func (c *LeaseCache) checkCacheForRequest(id string) (*SendResponse, error) {
@ -275,6 +297,7 @@ func (c *LeaseCache) Send(ctx context.Context, req *SendRequest) (*SendResponse,
ID: id,
Namespace: namespace,
RequestPath: req.Request.URL.Path,
LastRenewed: time.Now().UTC(),
}
secret, err := api.ParseSecret(bytes.NewReader(resp.ResponseBody))
@ -332,6 +355,8 @@ func (c *LeaseCache) Send(ctx context.Context, req *SendRequest) (*SendResponse,
index.Lease = secret.LeaseID
index.LeaseToken = req.Token
index.Type = cacheboltdb.SecretLeaseType
case secret.Auth != nil:
c.logger.Debug("processing auth response", "method", req.Request.Method, "path", req.Request.URL.Path)
@ -360,6 +385,8 @@ func (c *LeaseCache) Send(ctx context.Context, req *SendRequest) (*SendResponse,
index.Token = secret.Auth.ClientToken
index.TokenAccessor = secret.Auth.Accessor
index.Type = cacheboltdb.AuthLeaseType
default:
// We shouldn't be hitting this, but will err on the side of caution and
// simply proxy.
@ -394,9 +421,14 @@ func (c *LeaseCache) Send(ctx context.Context, req *SendRequest) (*SendResponse,
DoneCh: renewCtxInfo.DoneCh,
}
// Add extra information necessary for restoring from persisted cache
index.RequestMethod = req.Request.Method
index.RequestToken = req.Token
index.RequestHeader = req.Request.Header
// Store the index in the cache
c.logger.Debug("storing response into the cache", "method", req.Request.Method, "path", req.Request.URL.Path)
err = c.Set(ctx, index)
if err != nil {
c.logger.Error("failed to cache the proxied response", "error", err)
return nil, err
@ -420,8 +452,12 @@ func (c *LeaseCache) createCtxInfo(ctx context.Context) *cachememdb.ContextInfo
func (c *LeaseCache) startRenewing(ctx context.Context, index *cachememdb.Index, req *SendRequest, secret *api.Secret) {
defer func() {
id := ctx.Value(contextIndexID).(string)
if c.shuttingDown.Load() {
c.logger.Trace("not evicting index from cache during shutdown", "id", id, "method", req.Request.Method, "path", req.Request.URL.Path)
return
}
c.logger.Debug("evicting index from cache", "id", id, "method", req.Request.Method, "path", req.Request.URL.Path) c.logger.Debug("evicting index from cache", "id", id, "method", req.Request.Method, "path", req.Request.URL.Path)
err := c.db.Evict(cachememdb.IndexNameID, id) err := c.Evict(id)
if err != nil { if err != nil {
c.logger.Error("failed to evict index", "id", id, "error", err) c.logger.Error("failed to evict index", "id", id, "error", err)
return return
@ -466,6 +502,11 @@ func (c *LeaseCache) startRenewing(ctx context.Context, index *cachememdb.Index,
return return
case <-watcher.RenewCh(): case <-watcher.RenewCh():
c.logger.Debug("secret renewed", "path", req.Request.URL.Path) c.logger.Debug("secret renewed", "path", req.Request.URL.Path)
if c.ps != nil {
if err := c.updateLastRenewed(ctx, index, time.Now().UTC()); err != nil {
c.logger.Warn("not able to update lastRenewed time for cached index", "id", index.ID)
}
}
case <-index.RenewCtxInfo.DoneCh:
// This case indicates the renewal process to shutdown and evict
// the cache entry. This is triggered when a specific secret
@ -477,6 +518,22 @@ func (c *LeaseCache) startRenewing(ctx context.Context, index *cachememdb.Index,
}
}
func (c *LeaseCache) updateLastRenewed(ctx context.Context, index *cachememdb.Index, t time.Time) error {
idLock := locksutil.LockForKey(c.idLocks, index.ID)
idLock.Lock()
defer idLock.Unlock()
getIndex, err := c.db.Get(cachememdb.IndexNameID, index.ID)
if err != nil {
return err
}
index.LastRenewed = t
if err := c.Set(ctx, getIndex); err != nil {
return err
}
return nil
}
// computeIndexID results in a value that uniquely identifies a request
// received by the agent. It does so by SHA256 hashing the serialized request
// object containing the request path, query parameters and body parameters.
@ -642,8 +699,8 @@ func (c *LeaseCache) handleCacheClear(ctx context.Context, in *cacheClearInput)
}
c.l.Unlock()
// Reset the memdb instance (and persistent storage if enabled)
if err := c.Flush(); err != nil {
return err
}
@ -850,6 +907,213 @@ func (c *LeaseCache) handleRevocationRequest(ctx context.Context, req *SendReque
return true, nil
}
// Set stores the index in the cachememdb, and also stores it in the persistent
// cache (if enabled)
func (c *LeaseCache) Set(ctx context.Context, index *cachememdb.Index) error {
if err := c.db.Set(index); err != nil {
return err
}
if c.ps != nil {
b, err := index.Serialize()
if err != nil {
return err
}
if err := c.ps.Set(ctx, index.ID, b, index.Type); err != nil {
return err
}
c.logger.Trace("set entry in persistent storage", "type", index.Type, "path", index.RequestPath, "id", index.ID)
}
return nil
}
// Evict removes an Index from the cachememdb, and also removes it from the
// persistent cache (if enabled)
func (c *LeaseCache) Evict(id string) error {
if err := c.db.Evict(cachememdb.IndexNameID, id); err != nil {
return err
}
if c.ps != nil {
if err := c.ps.Delete(id); err != nil {
return err
}
c.logger.Trace("deleted item from persistent storage", "id", id)
}
return nil
}
// Flush the cachememdb and persistent cache (if enabled)
func (c *LeaseCache) Flush() error {
if err := c.db.Flush(); err != nil {
return err
}
if c.ps != nil {
c.logger.Trace("clearing persistent storage")
return c.ps.Clear()
}
return nil
}
// Restore loads the cachememdb from the persistent storage passed in. Loads
// tokens first, since restoring a lease's renewal context and watcher requires
// looking up the token in the cachememdb.
func (c *LeaseCache) Restore(ctx context.Context, storage *cacheboltdb.BoltStorage) error {
// Process tokens first
tokens, err := storage.GetByType(ctx, cacheboltdb.TokenType)
if err != nil {
return err
}
if err := c.restoreTokens(tokens); err != nil {
return err
}
// Then process auth leases
authLeases, err := storage.GetByType(ctx, cacheboltdb.AuthLeaseType)
if err != nil {
return err
}
if err := c.restoreLeases(authLeases); err != nil {
return err
}
// Then process secret leases
secretLeases, err := storage.GetByType(ctx, cacheboltdb.SecretLeaseType)
if err != nil {
return err
}
if err := c.restoreLeases(secretLeases); err != nil {
return err
}
return nil
}
func (c *LeaseCache) restoreTokens(tokens [][]byte) error {
for _, token := range tokens {
newIndex, err := cachememdb.Deserialize(token)
if err != nil {
return err
}
newIndex.RenewCtxInfo = c.createCtxInfo(nil)
if err := c.db.Set(newIndex); err != nil {
return err
}
c.logger.Trace("restored token", "id", newIndex.ID)
}
return nil
}
func (c *LeaseCache) restoreLeases(leases [][]byte) error {
for _, lease := range leases {
newIndex, err := cachememdb.Deserialize(lease)
if err != nil {
return err
}
// Check if this lease has already expired
expired, err := c.hasExpired(time.Now().UTC(), newIndex)
if err != nil {
c.logger.Warn("failed to check if lease is expired", "id", newIndex.ID, "error", err)
}
if expired {
continue
}
if err := c.restoreLeaseRenewCtx(newIndex); err != nil {
return err
}
if err := c.db.Set(newIndex); err != nil {
return err
}
c.logger.Trace("restored lease", "id", newIndex.ID, "path", newIndex.RequestPath)
}
return nil
}
// restoreLeaseRenewCtx re-creates a RenewCtx for an index object and starts
// the watcher goroutine
func (c *LeaseCache) restoreLeaseRenewCtx(index *cachememdb.Index) error {
if index.Response == nil {
return fmt.Errorf("cached response was nil for %s", index.ID)
}
// Parse the secret to determine which type it is
reader := bufio.NewReader(bytes.NewReader(index.Response))
resp, err := http.ReadResponse(reader, nil)
if err != nil {
c.logger.Error("failed to deserialize response", "error", err)
return err
}
secret, err := api.ParseSecret(resp.Body)
if err != nil {
c.logger.Error("failed to parse response as secret", "error", err)
return err
}
var renewCtxInfo *cachememdb.ContextInfo
switch {
case secret.LeaseID != "":
entry, err := c.db.Get(cachememdb.IndexNameToken, index.RequestToken)
if err != nil {
return err
}
if entry == nil {
return fmt.Errorf("could not find parent Token %s for req path %s", index.RequestToken, index.RequestPath)
}
// Derive a context for renewal using the token's context
renewCtxInfo = cachememdb.NewContextInfo(entry.RenewCtxInfo.Ctx)
case secret.Auth != nil:
var parentCtx context.Context
if !secret.Auth.Orphan {
entry, err := c.db.Get(cachememdb.IndexNameToken, index.RequestToken)
if err != nil {
return err
}
// If parent token is not managed by the agent, child shouldn't be
// either.
if entry == nil {
return fmt.Errorf("could not find parent Token %s for req path %s", index.RequestToken, index.RequestPath)
}
c.logger.Debug("setting parent context", "method", index.RequestMethod, "path", index.RequestPath)
parentCtx = entry.RenewCtxInfo.Ctx
}
renewCtxInfo = c.createCtxInfo(parentCtx)
default:
return fmt.Errorf("unknown cached index item: %s", index.ID)
}
renewCtx := context.WithValue(renewCtxInfo.Ctx, contextIndexID, index.ID)
index.RenewCtxInfo = &cachememdb.ContextInfo{
Ctx: renewCtx,
CancelFunc: renewCtxInfo.CancelFunc,
DoneCh: renewCtxInfo.DoneCh,
}
sendReq := &SendRequest{
Token: index.RequestToken,
Request: &http.Request{
Header: index.RequestHeader,
Method: index.RequestMethod,
URL: &url.URL{
Path: index.RequestPath,
},
},
}
go c.startRenewing(renewCtx, index, sendReq, secret)
return nil
}
// deriveNamespaceAndRevocationPath returns the namespace and relative path for
// revocation paths.
//
@ -912,9 +1176,11 @@ func (c *LeaseCache) RegisterAutoAuthToken(token string) error {
return err
}
// If the index is found, just keep it in the cache and ignore the incoming
// token (since they're the same)
if oldIndex != nil {
c.logger.Trace("auto-auth token already exists in cache; no need to store it again")
return nil
}
// The following randomly generated values are required for index stored by
@ -938,6 +1204,7 @@ func (c *LeaseCache) RegisterAutoAuthToken(token string) error {
Token: token,
Namespace: namespace,
RequestPath: requestPath,
Type: cacheboltdb.TokenType,
}
// Derive a context off of the lease cache's base context
@ -951,7 +1218,7 @@ func (c *LeaseCache) RegisterAutoAuthToken(token string) error {
// Store the index in the cache
c.logger.Debug("storing auto-auth token into the cache")
err = c.Set(c.baseCtxInfo.Ctx, index)
if err != nil {
c.logger.Error("failed to cache the auto-auth token", "error", err)
return err
@ -997,3 +1264,32 @@ func parseCacheClearInput(req *cacheClearRequest) (*cacheClearInput, error) {
return in, nil
}
func (c *LeaseCache) hasExpired(currentTime time.Time, index *cachememdb.Index) (bool, error) {
reader := bufio.NewReader(bytes.NewReader(index.Response))
resp, err := http.ReadResponse(reader, nil)
if err != nil {
return false, fmt.Errorf("failed to deserialize response: %w", err)
}
secret, err := api.ParseSecret(resp.Body)
if err != nil {
return false, fmt.Errorf("failed to parse response as secret: %w", err)
}
elapsed := currentTime.Sub(index.LastRenewed)
var leaseDuration int
switch index.Type {
case cacheboltdb.AuthLeaseType:
leaseDuration = secret.Auth.LeaseDuration
case cacheboltdb.SecretLeaseType:
leaseDuration = secret.LeaseDuration
default:
return false, fmt.Errorf("index type %q unexpected in expiration check", index.Type)
}
if int(elapsed.Seconds()) > leaseDuration {
c.logger.Trace("secret has expired", "id", index.ID, "elapsed", elapsed, "lease duration", leaseDuration)
return true, nil
}
return false, nil
}


@ -3,9 +3,11 @@ package cache
import (
"context"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"net/url"
"os"
"reflect"
"strings"
"sync"
@ -15,9 +17,13 @@ import (
"github.com/go-test/deep"
hclog "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/command/agent/cache/cacheboltdb"
"github.com/hashicorp/vault/command/agent/cache/cachememdb"
"github.com/hashicorp/vault/command/agent/cache/keymanager"
"github.com/hashicorp/vault/sdk/helper/consts"
"github.com/hashicorp/vault/sdk/helper/logging"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/atomic"
)
@ -63,6 +69,24 @@ func testNewLeaseCacheWithDelay(t *testing.T, cacheable bool, delay int) *LeaseC
return lc
}
func testNewLeaseCacheWithPersistence(t *testing.T, responses []*SendResponse, storage *cacheboltdb.BoltStorage) *LeaseCache {
t.Helper()
client, err := api.NewClient(api.DefaultConfig())
require.NoError(t, err)
lc, err := NewLeaseCache(&LeaseCacheConfig{
Client: client,
BaseContext: context.Background(),
Proxier: newMockProxier(responses),
Logger: logging.NewVaultLogger(hclog.Trace).Named("cache.leasecache"),
Storage: storage,
})
require.NoError(t, err)
return lc
}
func TestCache_ComputeIndexID(t *testing.T) {
type args struct {
req *http.Request
@ -649,3 +673,394 @@ func TestLeaseCache_Concurrent_Cacheable(t *testing.T) {
t.Fatalf("Should have returned a cached response 99 times, got %d", cacheCount.Load())
}
}
func setupBoltStorage(t *testing.T) (tempCacheDir string, boltStorage *cacheboltdb.BoltStorage) {
t.Helper()
km, err := keymanager.NewPassthroughKeyManager(nil)
require.NoError(t, err)
tempCacheDir, err = ioutil.TempDir("", "agent-cache-test")
require.NoError(t, err)
boltStorage, err = cacheboltdb.NewBoltStorage(&cacheboltdb.BoltStorageConfig{
Path: tempCacheDir,
Logger: hclog.Default(),
Wrapper: km.Wrapper(),
})
require.NoError(t, err)
require.NotNil(t, boltStorage)
// The calling function should `defer boltStorage.Close()` and `defer os.RemoveAll(tempCacheDir)`
return tempCacheDir, boltStorage
}
func TestLeaseCache_PersistAndRestore(t *testing.T) {
// Emulate 4 responses from the api proxy. The first two use the auto-auth
// token, and the last two use another token.
responses := []*SendResponse{
newTestSendResponse(200, `{"auth": {"client_token": "testtoken", "renewable": true, "lease_duration": 600}}`),
newTestSendResponse(201, `{"lease_id": "foo", "renewable": true, "data": {"value": "foo"}, "lease_duration": 600}`),
newTestSendResponse(202, `{"auth": {"client_token": "testtoken2", "renewable": true, "orphan": true, "lease_duration": 600}}`),
newTestSendResponse(203, `{"lease_id": "secret2-lease", "renewable": true, "data": {"number": "two"}, "lease_duration": 600}`),
}
tempDir, boltStorage := setupBoltStorage(t)
defer os.RemoveAll(tempDir)
defer boltStorage.Close()
lc := testNewLeaseCacheWithPersistence(t, responses, boltStorage)
// Register an auto-auth token so that the token and lease requests are cached
lc.RegisterAutoAuthToken("autoauthtoken")
cacheTests := []struct {
token string
method string
urlPath string
body string
wantStatusCode int
}{
{
// Make a request. A response with a new token is returned to the
// lease cache and that will be cached.
token: "autoauthtoken",
method: "GET",
urlPath: "http://example.com/v1/sample/api",
body: `{"value": "input"}`,
wantStatusCode: responses[0].Response.StatusCode,
},
{
// Modify the request a little bit to ensure the second response is
// returned to the lease cache.
token: "autoauthtoken",
method: "GET",
urlPath: "http://example.com/v1/sample/api",
body: `{"value": "input_changed"}`,
wantStatusCode: responses[1].Response.StatusCode,
},
{
// Simulate an approle login to get another token
method: "PUT",
urlPath: "http://example.com/v1/auth/approle/login",
body: `{"role_id": "my role", "secret_id": "my secret"}`,
wantStatusCode: responses[2].Response.StatusCode,
},
{
// Test caching with the token acquired from the approle login
token: "testtoken2",
method: "GET",
urlPath: "http://example.com/v1/sample2/api",
body: `{"second": "input"}`,
wantStatusCode: responses[3].Response.StatusCode,
},
}
for _, ct := range cacheTests {
// Send once to cache
sendReq := &SendRequest{
Token: ct.token,
Request: httptest.NewRequest(ct.method, ct.urlPath, strings.NewReader(ct.body)),
}
resp, err := lc.Send(context.Background(), sendReq)
require.NoError(t, err)
assert.Equal(t, resp.Response.StatusCode, ct.wantStatusCode, "expected proxied response")
assert.Nil(t, resp.CacheMeta)
// Send again to test cache. If this isn't cached, the response returned
// will be the next in the list and the status code will not match.
sendCacheReq := &SendRequest{
Token: ct.token,
Request: httptest.NewRequest(ct.method, ct.urlPath, strings.NewReader(ct.body)),
}
respCached, err := lc.Send(context.Background(), sendCacheReq)
require.NoError(t, err, "failed to send request %+v", ct)
assert.Equal(t, respCached.Response.StatusCode, ct.wantStatusCode, "expected proxied response")
require.NotNil(t, respCached.CacheMeta)
assert.True(t, respCached.CacheMeta.Hit)
}
// Now we know the cache is working, so try restoring from the persisted
// cache's storage
restoredCache := testNewLeaseCache(t, nil)
err := restoredCache.Restore(context.Background(), boltStorage)
assert.NoError(t, err)
// Now compare before and after
beforeDB, err := lc.db.GetByPrefix(cachememdb.IndexNameID)
require.NoError(t, err)
assert.Len(t, beforeDB, 5)
for _, cachedItem := range beforeDB {
restoredItem, err := restoredCache.db.Get(cachememdb.IndexNameID, cachedItem.ID)
require.NoError(t, err)
assert.NoError(t, err)
assert.Equal(t, cachedItem.ID, restoredItem.ID)
assert.Equal(t, cachedItem.Lease, restoredItem.Lease)
assert.Equal(t, cachedItem.LeaseToken, restoredItem.LeaseToken)
assert.Equal(t, cachedItem.Namespace, restoredItem.Namespace)
assert.Equal(t, cachedItem.RequestHeader, restoredItem.RequestHeader)
assert.Equal(t, cachedItem.RequestMethod, restoredItem.RequestMethod)
assert.Equal(t, cachedItem.RequestPath, restoredItem.RequestPath)
assert.Equal(t, cachedItem.RequestToken, restoredItem.RequestToken)
assert.Equal(t, cachedItem.Response, restoredItem.Response)
assert.Equal(t, cachedItem.Token, restoredItem.Token)
assert.Equal(t, cachedItem.TokenAccessor, restoredItem.TokenAccessor)
assert.Equal(t, cachedItem.TokenParent, restoredItem.TokenParent)
// check what we can in the renewal context
assert.NotEmpty(t, restoredItem.RenewCtxInfo.CancelFunc)
assert.NotZero(t, restoredItem.RenewCtxInfo.DoneCh)
require.NotEmpty(t, restoredItem.RenewCtxInfo.Ctx)
assert.Equal(t,
cachedItem.RenewCtxInfo.Ctx.Value(contextIndexID),
restoredItem.RenewCtxInfo.Ctx.Value(contextIndexID),
)
}
afterDB, err := restoredCache.db.GetByPrefix(cachememdb.IndexNameID)
require.NoError(t, err)
assert.Len(t, afterDB, 5)
// And finally send the cache requests once to make sure they're all being
// served from the restoredCache
for _, ct := range cacheTests {
sendCacheReq := &SendRequest{
Token: ct.token,
Request: httptest.NewRequest(ct.method, ct.urlPath, strings.NewReader(ct.body)),
}
respCached, err := restoredCache.Send(context.Background(), sendCacheReq)
require.NoError(t, err, "failed to send request %+v", ct)
assert.Equal(t, respCached.Response.StatusCode, ct.wantStatusCode, "expected proxied response")
require.NotNil(t, respCached.CacheMeta)
assert.True(t, respCached.CacheMeta.Hit)
}
}
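Outside the test harness, the persist-and-restore flow exercised above composes the same constructors: build a key manager and a BoltStorage, hand the storage to NewLeaseCache, and later call Restore on a fresh cache. The sketch below is a rough outline under those assumptions; the proxier argument, logger choices, and error handling are placeholders rather than the agent's actual wiring.

```go
package example

import (
	"context"

	hclog "github.com/hashicorp/go-hclog"
	"github.com/hashicorp/vault/api"
	"github.com/hashicorp/vault/command/agent/cache"
	"github.com/hashicorp/vault/command/agent/cache/cacheboltdb"
	"github.com/hashicorp/vault/command/agent/cache/keymanager"
)

// newPersistentLeaseCache wires a LeaseCache to bolt-backed storage at path.
func newPersistentLeaseCache(ctx context.Context, proxier cache.Proxier, path string) (*cache.LeaseCache, *cacheboltdb.BoltStorage, error) {
	km, err := keymanager.NewPassthroughKeyManager(nil)
	if err != nil {
		return nil, nil, err
	}
	storage, err := cacheboltdb.NewBoltStorage(&cacheboltdb.BoltStorageConfig{
		Path:    path,
		Logger:  hclog.Default(),
		Wrapper: km.Wrapper(),
	})
	if err != nil {
		return nil, nil, err
	}
	client, err := api.NewClient(api.DefaultConfig())
	if err != nil {
		return nil, nil, err
	}
	lc, err := cache.NewLeaseCache(&cache.LeaseCacheConfig{
		Client:      client,
		BaseContext: ctx,
		Proxier:     proxier,
		Logger:      hclog.Default().Named("cache.leasecache"),
		Storage:     storage,
	})
	if err != nil {
		return nil, nil, err
	}
	// On a subsequent start, a new LeaseCache can repopulate its in-memory
	// index from the same storage with: lc.Restore(ctx, storage)
	return lc, storage, nil
}
```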
func TestEvictPersistent(t *testing.T) {
ctx := context.Background()
responses := []*SendResponse{
newTestSendResponse(201, `{"lease_id": "foo", "renewable": true, "data": {"value": "foo"}}`),
}
tempDir, boltStorage := setupBoltStorage(t)
defer os.RemoveAll(tempDir)
defer boltStorage.Close()
lc := testNewLeaseCacheWithPersistence(t, responses, boltStorage)
lc.RegisterAutoAuthToken("autoauthtoken")
// populate cache by sending request through
sendReq := &SendRequest{
Token: "autoauthtoken",
Request: httptest.NewRequest("GET", "http://example.com/v1/sample/api", strings.NewReader(`{"value": "some_input"}`)),
}
resp, err := lc.Send(context.Background(), sendReq)
require.NoError(t, err)
assert.Equal(t, resp.Response.StatusCode, 201, "expected proxied response")
assert.Nil(t, resp.CacheMeta)
// Check bolt for the cached lease
secrets, err := lc.ps.GetByType(ctx, cacheboltdb.SecretLeaseType)
require.NoError(t, err)
assert.Len(t, secrets, 1)
// Call clear for the request path
err = lc.handleCacheClear(context.Background(), &cacheClearInput{
Type: "request_path",
RequestPath: "/v1/sample/api",
})
require.NoError(t, err)
time.Sleep(2 * time.Second)
// Check that cached item is gone
secrets, err = lc.ps.GetByType(ctx, cacheboltdb.SecretLeaseType)
require.NoError(t, err)
assert.Len(t, secrets, 0)
}
func TestRegisterAutoAuth_sameToken(t *testing.T) {
// If the auto-auth token already exists in the cache, it should not be
// stored again in a new index.
lc := testNewLeaseCache(t, nil)
err := lc.RegisterAutoAuthToken("autoauthtoken")
assert.NoError(t, err)
oldTokenIndex, err := lc.db.Get(cachememdb.IndexNameToken, "autoauthtoken")
assert.NoError(t, err)
oldTokenID := oldTokenIndex.ID
// register the same token again
err = lc.RegisterAutoAuthToken("autoauthtoken")
assert.NoError(t, err)
// check that there's only one index for autoauthtoken
entries, err := lc.db.GetByPrefix(cachememdb.IndexNameToken, "autoauthtoken")
assert.NoError(t, err)
assert.Len(t, entries, 1)
newTokenIndex, err := lc.db.Get(cachememdb.IndexNameToken, "autoauthtoken")
assert.NoError(t, err)
// Compare the IDs, since they are randomly generated when an index for a
// token is added to the cache; if a new token had been added, the IDs would
// not match.
assert.Equal(t, oldTokenID, newTokenIndex.ID)
}
func Test_hasExpired(t *testing.T) {
responses := []*SendResponse{
newTestSendResponse(200, `{"auth": {"client_token": "testtoken", "renewable": true, "lease_duration": 60}}`),
newTestSendResponse(201, `{"lease_id": "foo", "renewable": true, "data": {"value": "foo"}, "lease_duration": 60}`),
}
lc := testNewLeaseCache(t, responses)
lc.RegisterAutoAuthToken("autoauthtoken")
cacheTests := []struct {
token string
urlPath string
leaseType string
wantStatusCode int
}{
{
// auth lease
token: "autoauthtoken",
urlPath: "/v1/sample/auth",
leaseType: cacheboltdb.AuthLeaseType,
wantStatusCode: responses[0].Response.StatusCode,
},
{
// secret lease
token: "autoauthtoken",
urlPath: "/v1/sample/secret",
leaseType: cacheboltdb.SecretLeaseType,
wantStatusCode: responses[1].Response.StatusCode,
},
}
for _, ct := range cacheTests {
// Send once to cache
urlPath := "http://example.com" + ct.urlPath
sendReq := &SendRequest{
Token: ct.token,
Request: httptest.NewRequest("GET", urlPath, strings.NewReader(`{"value": "input"}`)),
}
resp, err := lc.Send(context.Background(), sendReq)
require.NoError(t, err)
assert.Equal(t, resp.Response.StatusCode, ct.wantStatusCode, "expected proxied response")
assert.Nil(t, resp.CacheMeta)
// get the Index out of the mem cache
index, err := lc.db.Get(cachememdb.IndexNameRequestPath, "root/", ct.urlPath)
require.NoError(t, err)
assert.Equal(t, ct.leaseType, index.Type)
// The lease duration is 60 seconds, so time.Now() should be within that
notExpired, err := lc.hasExpired(time.Now().UTC(), index)
require.NoError(t, err)
assert.False(t, notExpired)
// In 90 seconds the index should be "expired"
futureTime := time.Now().UTC().Add(time.Second * 90)
expired, err := lc.hasExpired(futureTime, index)
require.NoError(t, err)
assert.True(t, expired)
}
}
func TestLeaseCache_hasExpired_wrong_type(t *testing.T) {
index := &cachememdb.Index{
Type: cacheboltdb.TokenType,
Response: []byte(`HTTP/0.0 200 OK
Content-Type: application/json
Date: Tue, 02 Mar 2021 17:54:16 GMT
{"auth": {"client_token": "testtoken", "renewable": true, "lease_duration": 60}}`),
}
lc := testNewLeaseCache(t, nil)
expired, err := lc.hasExpired(time.Now().UTC(), index)
assert.False(t, expired)
assert.EqualError(t, err, `index type "token" unexpected in expiration check`)
}
func TestLeaseCacheRestore_expired(t *testing.T) {
// Emulate 2 responses from the api proxy, both expired
responses := []*SendResponse{
newTestSendResponse(200, `{"auth": {"client_token": "testtoken", "renewable": true, "lease_duration": -600}}`),
newTestSendResponse(201, `{"lease_id": "foo", "renewable": true, "data": {"value": "foo"}, "lease_duration": -600}`),
}
tempDir, boltStorage := setupBoltStorage(t)
defer os.RemoveAll(tempDir)
defer boltStorage.Close()
lc := testNewLeaseCacheWithPersistence(t, responses, boltStorage)
// Register an auto-auth token so that the token and lease requests are cached in mem
lc.RegisterAutoAuthToken("autoauthtoken")
cacheTests := []struct {
token string
method string
urlPath string
body string
wantStatusCode int
}{
{
// Make a request. A response with a new token is returned to the
// lease cache and that will be cached.
token: "autoauthtoken",
method: "GET",
urlPath: "http://example.com/v1/sample/api",
body: `{"value": "input"}`,
wantStatusCode: responses[0].Response.StatusCode,
},
{
// Modify the request a little bit to ensure the second response is
// returned to the lease cache.
token: "autoauthtoken",
method: "GET",
urlPath: "http://example.com/v1/sample/api",
body: `{"value": "input_changed"}`,
wantStatusCode: responses[1].Response.StatusCode,
},
}
for _, ct := range cacheTests {
// Send once to cache
sendReq := &SendRequest{
Token: ct.token,
Request: httptest.NewRequest(ct.method, ct.urlPath, strings.NewReader(ct.body)),
}
resp, err := lc.Send(context.Background(), sendReq)
require.NoError(t, err)
assert.Equal(t, resp.Response.StatusCode, ct.wantStatusCode, "expected proxied response")
assert.Nil(t, resp.CacheMeta)
}
// Restore from the persisted cache's storage
restoredCache := testNewLeaseCache(t, nil)
err := restoredCache.Restore(context.Background(), boltStorage)
assert.NoError(t, err)
// The original mem cache should have all three items
beforeDB, err := lc.db.GetByPrefix(cachememdb.IndexNameID)
require.NoError(t, err)
assert.Len(t, beforeDB, 3)
// There should only be one item in the restored cache: the autoauth token
afterDB, err := restoredCache.db.GetByPrefix(cachememdb.IndexNameID)
require.NoError(t, err)
assert.Len(t, afterDB, 1)
// Just verify that the one item in the restored mem cache matches one in the original mem cache, and that it's the auto-auth token
beforeItem, err := lc.db.Get(cachememdb.IndexNameID, afterDB[0].ID)
require.NoError(t, err)
assert.NotNil(t, beforeItem)
assert.Equal(t, "autoauthtoken", afterDB[0].Token)
assert.Equal(t, cacheboltdb.TokenType, afterDB[0].Type)
}

View File

@ -50,6 +50,16 @@ type Cache struct {
ForceAutoAuthToken bool `hcl:"-"`
EnforceConsistency string `hcl:"enforce_consistency"`
WhenInconsistent string `hcl:"when_inconsistent"`
Persist *Persist `hcl:"persist"`
}
// Persist contains configuration needed for persistent caching
type Persist struct {
Type string
Path string `hcl:"path"`
KeepAfterImport bool `hcl:"keep_after_import"`
ExitOnErr bool `hcl:"exit_on_err"`
ServiceAccountTokenFile string `hcl:"service_account_token_file"`
}
// AutoAuth is the configured authentication method and sinks
@ -309,8 +319,51 @@ func parseCache(result *Config, list *ast.ObjectList) error {
}
}
}
result.Cache = &c
subs, ok := item.Val.(*ast.ObjectType)
if !ok {
return fmt.Errorf("could not parse %q as an object", name)
}
subList := subs.List
if err := parsePersist(result, subList); err != nil {
return fmt.Errorf("error parsing persist: %w", err)
}
return nil
}
func parsePersist(result *Config, list *ast.ObjectList) error {
name := "persist"
persistList := list.Filter(name)
if len(persistList.Items) == 0 {
return nil
}
if len(persistList.Items) > 1 {
return fmt.Errorf("only one %q block is required", name)
}
item := persistList.Items[0]
var p Persist
err := hcl.DecodeObject(&p, item.Val)
if err != nil {
return err
}
if p.Type == "" {
if len(item.Keys) == 1 {
p.Type = strings.ToLower(item.Keys[0].Token.Value().(string))
}
if p.Type == "" {
return errors.New("persist type must be specified")
}
}
result.Cache.Persist = &p
return nil
}
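As the fixtures later in this commit show, parsePersist accepts the persist stanza in two equivalent forms: with the type given as a block label, or as an explicit type attribute inside the block. For example:

```hcl
# Type supplied as the block label
cache {
  persist "kubernetes" {
    path = "/vault/agent-cache/"
  }
}

# Type supplied as an attribute
cache {
  persist = {
    type = "kubernetes"
    path = "/vault/agent-cache/"
  }
}
```

If neither a label nor a type attribute is present, parsing fails with "persist type must be specified", which is what TestLoadConfigFile_AgentCache_PersistMissingType exercises.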

View File

@ -66,6 +66,13 @@ func TestLoadConfigFile_AgentCache(t *testing.T) {
UseAutoAuthToken: true,
UseAutoAuthTokenRaw: true,
ForceAutoAuthToken: false,
Persist: &Persist{
Type: "kubernetes",
Path: "/vault/agent-cache/",
KeepAfterImport: true,
ExitOnErr: true,
ServiceAccountTokenFile: "/tmp/serviceaccount/token",
},
},
Vault: &Vault{
Address: "http://127.0.0.1:1111",
@ -445,6 +452,52 @@ func TestLoadConfigFile_AgentCache_AutoAuth_False(t *testing.T) {
}
}
func TestLoadConfigFile_AgentCache_Persist(t *testing.T) {
config, err := LoadConfig("./test-fixtures/config-cache-persist-false.hcl")
if err != nil {
t.Fatalf("err: %s", err)
}
expected := &Config{
Cache: &Cache{
Persist: &Persist{
Type: "kubernetes",
Path: "/vault/agent-cache/",
KeepAfterImport: false,
ExitOnErr: false,
ServiceAccountTokenFile: "",
},
},
SharedConfig: &configutil.SharedConfig{
PidFile: "./pidfile",
Listeners: []*configutil.Listener{
{
Type: "tcp",
Address: "127.0.0.1:8300",
TLSDisable: true,
},
},
},
}
config.Listeners[0].RawConfig = nil
if diff := deep.Equal(config, expected); diff != nil {
t.Fatal(diff)
}
}
func TestLoadConfigFile_AgentCache_PersistMissingType(t *testing.T) {
_, err := LoadConfig("./test-fixtures/config-cache-persist-empty-type.hcl")
if err == nil || os.IsNotExist(err) {
t.Fatal("expected error or file is missing")
}
}
// TestLoadConfigFile_Template tests template definitions in Vault Agent
func TestLoadConfigFile_Template(t *testing.T) {
testCases := map[string]struct {

View File

@ -21,6 +21,12 @@ auto_auth {
cache {
use_auto_auth_token = true
persist "kubernetes" {
path = "/vault/agent-cache/"
keep_after_import = true
exit_on_err = true
service_account_token_file = "/tmp/serviceaccount/token"
}
}
listener {

View File

@ -0,0 +1,12 @@
pid_file = "./pidfile"
cache {
persist = {
path = "/vault/agent-cache/"
}
}
listener "tcp" {
address = "127.0.0.1:8300"
tls_disable = true
}

View File

@ -0,0 +1,14 @@
pid_file = "./pidfile"
cache {
persist "kubernetes" {
exit_on_err = false
keep_after_import = false
path = "/vault/agent-cache/"
}
}
listener "tcp" {
address = "127.0.0.1:8300"
tls_disable = true
}

View File

@ -21,6 +21,13 @@ auto_auth {
cache {
use_auto_auth_token = true
persist = {
type = "kubernetes"
path = "/vault/agent-cache/"
keep_after_import = true
exit_on_err = true
service_account_token_file = "/tmp/serviceaccount/token"
}
}
listener "unix" {

View File

@ -256,7 +256,7 @@ func newRunnerConfig(sc *ServerConfig, templates ctconfig.TemplateConfigs) (*ctc
}
// Use the cache if available or fallback to the Vault server values.
if sc.AgentConfig.Cache != nil && sc.AgentConfig.Cache.Persist != nil && len(sc.AgentConfig.Listeners) != 0 {
scheme := "unix://"
if sc.AgentConfig.Listeners[0].Type == "tcp" {
scheme = "https://"

View File

@ -28,7 +28,7 @@ func TestNewServer(t *testing.T) {
}
}
func newAgentConfig(listeners []*configutil.Listener, enableCache, enablePersisentCache bool) *config.Config {
agentConfig := &config.Config{
SharedConfig: &configutil.SharedConfig{
PidFile: "./pidfile",
@ -65,7 +65,13 @@ func newAgentConfig(listeners []*configutil.Listener, enableCache bool) *config.
},
}
if enableCache {
agentConfig.Cache = &config.Cache{
UseAutoAuthToken: true,
}
}
if enablePersisentCache {
agentConfig.Cache.Persist = &config.Persist{Type: "kubernetes"}
}
return agentConfig
@ -94,7 +100,7 @@ func TestCacheConfigUnix(t *testing.T) {
},
}
agentConfig := newAgentConfig(listeners, true, true)
serverConfig := ServerConfig{AgentConfig: agentConfig}
ctConfig, err := newRunnerConfig(&serverConfig, ctconfig.TemplateConfigs{})
@ -131,7 +137,7 @@ func TestCacheConfigHTTP(t *testing.T) {
},
}
agentConfig := newAgentConfig(listeners, true, true)
serverConfig := ServerConfig{AgentConfig: agentConfig}
ctConfig, err := newRunnerConfig(&serverConfig, ctconfig.TemplateConfigs{})
@ -168,7 +174,7 @@ func TestCacheConfigHTTPS(t *testing.T) {
},
}
agentConfig := newAgentConfig(listeners, true, true)
serverConfig := ServerConfig{AgentConfig: agentConfig}
ctConfig, err := newRunnerConfig(&serverConfig, ctconfig.TemplateConfigs{})
@ -209,7 +215,44 @@ func TestCacheConfigNoCache(t *testing.T) {
},
}
agentConfig := newAgentConfig(listeners, false, false)
serverConfig := ServerConfig{AgentConfig: agentConfig}
ctConfig, err := newRunnerConfig(&serverConfig, ctconfig.TemplateConfigs{})
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
expected := "http://127.0.0.1:1111"
if *ctConfig.Vault.Address != expected {
t.Fatalf("expected %s, got %s", expected, *ctConfig.Vault.Address)
}
}
func TestCacheConfigNoPersistentCache(t *testing.T) {
listeners := []*configutil.Listener{
{
Type: "tcp",
Address: "127.0.0.1:8300",
TLSKeyFile: "/path/to/cakey.pem",
TLSCertFile: "/path/to/cacert.pem",
},
{
Type: "unix",
Address: "foobar",
TLSDisable: true,
SocketMode: "configmode",
SocketUser: "configuser",
SocketGroup: "configgroup",
},
{
Type: "tcp",
Address: "127.0.0.1:8400",
TLSDisable: true,
},
}
agentConfig := newAgentConfig(listeners, true, false)
serverConfig := ServerConfig{AgentConfig: agentConfig}
ctConfig, err := newRunnerConfig(&serverConfig, ctconfig.TemplateConfigs{})
@ -226,7 +269,7 @@ func TestCacheConfigNoCache(t *testing.T) {
func TestCacheConfigNoListener(t *testing.T) {
listeners := []*configutil.Listener{}
agentConfig := newAgentConfig(listeners, true, true)
serverConfig := ServerConfig{AgentConfig: agentConfig}
ctConfig, err := newRunnerConfig(&serverConfig, ctconfig.TemplateConfigs{})
@ -264,7 +307,7 @@ func TestCacheConfigRejectMTLS(t *testing.T) {
},
}
agentConfig := newAgentConfig(listeners, true, true)
serverConfig := ServerConfig{AgentConfig: agentConfig}
_, err := newRunnerConfig(&serverConfig, ctconfig.TemplateConfigs{})