agent: persistent caching support (#10938)
Adds the option of a write-through cache, backed by boltdb Co-authored-by: Theron Voran <tvoran@users.noreply.github.com> Co-authored-by: Jason O'Donnell <2160810+jasonodonnell@users.noreply.github.com> Co-authored-by: Calvin Leung Huang <cleung2010@gmail.com>
This commit is contained in:
parent
910b45413b
commit
1fdf08b149
|
@ -1,3 +1,3 @@
|
|||
```release-note:improvement
|
||||
agent: Route templating server through cache when enabled.
|
||||
agent: Route templating server through cache when persistent cache is enabled.
|
||||
```
|
||||
|
|
|
@ -0,0 +1,3 @@
|
|||
```release-note:feature
|
||||
agent: Support for persisting the agent cache to disk
|
||||
```
|
189
command/agent.go
189
command/agent.go
|
@ -6,10 +6,12 @@ import (
|
|||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
|
@ -30,6 +32,9 @@ import (
|
|||
"github.com/hashicorp/vault/command/agent/auth/kerberos"
|
||||
"github.com/hashicorp/vault/command/agent/auth/kubernetes"
|
||||
"github.com/hashicorp/vault/command/agent/cache"
|
||||
"github.com/hashicorp/vault/command/agent/cache/cacheboltdb"
|
||||
"github.com/hashicorp/vault/command/agent/cache/cachememdb"
|
||||
"github.com/hashicorp/vault/command/agent/cache/keymanager"
|
||||
agentConfig "github.com/hashicorp/vault/command/agent/config"
|
||||
"github.com/hashicorp/vault/command/agent/sink"
|
||||
"github.com/hashicorp/vault/command/agent/sink/file"
|
||||
|
@ -461,6 +466,8 @@ func (c *AgentCommand) Run(args []string) int {
|
|||
c.UI.Output("==> Vault agent started! Log data will stream in below:\n")
|
||||
}
|
||||
|
||||
var leaseCache *cache.LeaseCache
|
||||
var previousToken string
|
||||
// Parse agent listener configurations
|
||||
if config.Cache != nil && len(config.Listeners) != 0 {
|
||||
cacheLogger := c.logger.Named("cache")
|
||||
|
@ -479,7 +486,7 @@ func (c *AgentCommand) Run(args []string) int {
|
|||
|
||||
// Create the lease cache proxier and set its underlying proxier to
|
||||
// the API proxier.
|
||||
leaseCache, err := cache.NewLeaseCache(&cache.LeaseCacheConfig{
|
||||
leaseCache, err = cache.NewLeaseCache(&cache.LeaseCacheConfig{
|
||||
Client: client,
|
||||
BaseContext: ctx,
|
||||
Proxier: apiProxy,
|
||||
|
@ -490,6 +497,152 @@ func (c *AgentCommand) Run(args []string) int {
|
|||
return 1
|
||||
}
|
||||
|
||||
// Configure persistent storage and add to LeaseCache
|
||||
if config.Cache.Persist != nil {
|
||||
if config.Cache.Persist.Path == "" {
|
||||
c.UI.Error("must specify persistent cache path")
|
||||
return 1
|
||||
}
|
||||
|
||||
// Set AAD based on key protection type
|
||||
var aad string
|
||||
switch config.Cache.Persist.Type {
|
||||
case "kubernetes":
|
||||
aad, err = getServiceAccountJWT(config.Cache.Persist.ServiceAccountTokenFile)
|
||||
if err != nil {
|
||||
c.UI.Error(fmt.Sprintf("failed to read service account token from %s: %s", config.Cache.Persist.ServiceAccountTokenFile, err))
|
||||
return 1
|
||||
}
|
||||
default:
|
||||
c.UI.Error(fmt.Sprintf("persistent key protection type %q not supported", config.Cache.Persist.Type))
|
||||
return 1
|
||||
}
|
||||
|
||||
// Check if bolt file exists already
|
||||
dbFileExists, err := cacheboltdb.DBFileExists(config.Cache.Persist.Path)
|
||||
if err != nil {
|
||||
c.UI.Error(fmt.Sprintf("failed to check if bolt file exists at path %s: %s", config.Cache.Persist.Path, err))
|
||||
return 1
|
||||
}
|
||||
if dbFileExists {
|
||||
// Open the bolt file, but wait to setup Encryption
|
||||
ps, err := cacheboltdb.NewBoltStorage(&cacheboltdb.BoltStorageConfig{
|
||||
Path: config.Cache.Persist.Path,
|
||||
Logger: cacheLogger.Named("cacheboltdb"),
|
||||
})
|
||||
if err != nil {
|
||||
c.UI.Error(fmt.Sprintf("Error opening persistent cache: %v", err))
|
||||
return 1
|
||||
}
|
||||
|
||||
// Get the token from bolt for retrieving the encryption key,
|
||||
// then setup encryption so that restore is possible
|
||||
token, err := ps.GetRetrievalToken()
|
||||
if err != nil {
|
||||
c.UI.Error(fmt.Sprintf("Error getting retrieval token from persistent cache: %v", err))
|
||||
}
|
||||
|
||||
if err := ps.Close(); err != nil {
|
||||
c.UI.Warn(fmt.Sprintf("Failed to close persistent cache file after getting retrieval token: %s", err))
|
||||
}
|
||||
|
||||
km, err := keymanager.NewPassthroughKeyManager(token)
|
||||
if err != nil {
|
||||
c.UI.Error(fmt.Sprintf("failed to configure persistence encryption for cache: %s", err))
|
||||
return 1
|
||||
}
|
||||
|
||||
// Open the bolt file with the wrapper provided
|
||||
ps, err = cacheboltdb.NewBoltStorage(&cacheboltdb.BoltStorageConfig{
|
||||
Path: config.Cache.Persist.Path,
|
||||
Logger: cacheLogger.Named("cacheboltdb"),
|
||||
Wrapper: km.Wrapper(),
|
||||
AAD: aad,
|
||||
})
|
||||
if err != nil {
|
||||
c.UI.Error(fmt.Sprintf("Error opening persistent cache: %v", err))
|
||||
return 1
|
||||
}
|
||||
|
||||
// Restore anything in the persistent cache to the memory cache
|
||||
if err := leaseCache.Restore(ctx, ps); err != nil {
|
||||
c.UI.Error(fmt.Sprintf("Error restoring in-memory cache from persisted file: %v", err))
|
||||
if config.Cache.Persist.ExitOnErr {
|
||||
return 1
|
||||
}
|
||||
}
|
||||
cacheLogger.Info("loaded memcache from persistent storage")
|
||||
|
||||
// Check for previous auto-auth token
|
||||
oldTokenBytes, err := ps.GetAutoAuthToken(ctx)
|
||||
if err != nil {
|
||||
c.UI.Error(fmt.Sprintf("Error in fetching previous auto-auth token: %s", err))
|
||||
if config.Cache.Persist.ExitOnErr {
|
||||
return 1
|
||||
}
|
||||
}
|
||||
if len(oldTokenBytes) > 0 {
|
||||
oldToken, err := cachememdb.Deserialize(oldTokenBytes)
|
||||
if err != nil {
|
||||
c.UI.Error(fmt.Sprintf("Error in deserializing previous auto-auth token cache entry: %s", err))
|
||||
if config.Cache.Persist.ExitOnErr {
|
||||
return 1
|
||||
}
|
||||
}
|
||||
previousToken = oldToken.Token
|
||||
}
|
||||
|
||||
// If keep_after_import true, set persistent storage layer in
|
||||
// leaseCache, else remove db file
|
||||
if config.Cache.Persist.KeepAfterImport {
|
||||
defer ps.Close()
|
||||
leaseCache.SetPersistentStorage(ps)
|
||||
} else {
|
||||
if err := ps.Close(); err != nil {
|
||||
c.UI.Warn(fmt.Sprintf("failed to close persistent cache file: %s", err))
|
||||
}
|
||||
dbFile := filepath.Join(config.Cache.Persist.Path, cacheboltdb.DatabaseFileName)
|
||||
if err := os.Remove(dbFile); err != nil {
|
||||
c.UI.Error(fmt.Sprintf("failed to remove persistent storage file %s: %s", dbFile, err))
|
||||
if config.Cache.Persist.ExitOnErr {
|
||||
return 1
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
km, err := keymanager.NewPassthroughKeyManager(nil)
|
||||
if err != nil {
|
||||
c.UI.Error(fmt.Sprintf("failed to configure persistence encryption for cache: %s", err))
|
||||
return 1
|
||||
}
|
||||
ps, err := cacheboltdb.NewBoltStorage(&cacheboltdb.BoltStorageConfig{
|
||||
Path: config.Cache.Persist.Path,
|
||||
Logger: cacheLogger.Named("cacheboltdb"),
|
||||
Wrapper: km.Wrapper(),
|
||||
AAD: aad,
|
||||
})
|
||||
if err != nil {
|
||||
c.UI.Error(fmt.Sprintf("Error creating persistent cache: %v", err))
|
||||
return 1
|
||||
}
|
||||
cacheLogger.Info("configured persistent storage", "path", config.Cache.Persist.Path)
|
||||
|
||||
// Stash the key material in bolt
|
||||
token, err := km.RetrievalToken()
|
||||
if err != nil {
|
||||
c.UI.Error(fmt.Sprintf("Error getting persistent key: %s", err))
|
||||
return 1
|
||||
}
|
||||
if err := ps.StoreRetrievalToken(token); err != nil {
|
||||
c.UI.Error(fmt.Sprintf("Error setting key in persistent cache: %v", err))
|
||||
return 1
|
||||
}
|
||||
|
||||
defer ps.Close()
|
||||
leaseCache.SetPersistentStorage(ps)
|
||||
}
|
||||
}
|
||||
|
||||
var inmemSink sink.Sink
|
||||
if config.Cache.UseAutoAuthToken {
|
||||
cacheLogger.Debug("auto-auth token is allowed to be used; configuring inmem sink")
|
||||
|
@ -585,6 +738,11 @@ func (c *AgentCommand) Run(args []string) int {
|
|||
select {
|
||||
case <-c.ShutdownCh:
|
||||
c.UI.Output("==> Vault agent shutdown triggered")
|
||||
// Let the lease cache know this is a shutdown; no need to evict
|
||||
// everything
|
||||
if leaseCache != nil {
|
||||
leaseCache.SetShuttingDown(true)
|
||||
}
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
|
@ -604,6 +762,7 @@ func (c *AgentCommand) Run(args []string) int {
|
|||
MaxBackoff: config.AutoAuth.Method.MaxBackoff,
|
||||
EnableReauthOnNewCredentials: config.AutoAuth.EnableReauthOnNewCredentials,
|
||||
EnableTemplateTokenCh: enableTokenCh,
|
||||
Token: previousToken,
|
||||
})
|
||||
|
||||
ss := sink.NewSinkServer(&sink.SinkServerConfig{
|
||||
|
@ -624,6 +783,11 @@ func (c *AgentCommand) Run(args []string) int {
|
|||
g.Add(func() error {
|
||||
return ah.Run(ctx, method)
|
||||
}, func(error) {
|
||||
// Let the lease cache know this is a shutdown; no need to evict
|
||||
// everything
|
||||
if leaseCache != nil {
|
||||
leaseCache.SetShuttingDown(true)
|
||||
}
|
||||
cancelFunc()
|
||||
})
|
||||
|
||||
|
@ -650,12 +814,22 @@ func (c *AgentCommand) Run(args []string) int {
|
|||
|
||||
return err
|
||||
}, func(error) {
|
||||
// Let the lease cache know this is a shutdown; no need to evict
|
||||
// everything
|
||||
if leaseCache != nil {
|
||||
leaseCache.SetShuttingDown(true)
|
||||
}
|
||||
cancelFunc()
|
||||
})
|
||||
|
||||
g.Add(func() error {
|
||||
return ts.Run(ctx, ah.TemplateTokenCh, config.Templates)
|
||||
}, func(error) {
|
||||
// Let the lease cache know this is a shutdown; no need to evict
|
||||
// everything
|
||||
if leaseCache != nil {
|
||||
leaseCache.SetShuttingDown(true)
|
||||
}
|
||||
cancelFunc()
|
||||
ts.Stop()
|
||||
})
|
||||
|
@ -793,3 +967,16 @@ func (c *AgentCommand) removePidFile(pidPath string) error {
|
|||
}
|
||||
return os.Remove(pidPath)
|
||||
}
|
||||
|
||||
// GetServiceAccountJWT reads the service account jwt from `tokenFile`. Default is
|
||||
// the default service account file path in kubernetes.
|
||||
func getServiceAccountJWT(tokenFile string) (string, error) {
|
||||
if len(tokenFile) == 0 {
|
||||
tokenFile = "/var/run/secrets/kubernetes.io/serviceaccount/token"
|
||||
}
|
||||
token, err := ioutil.ReadFile(tokenFile)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return strings.TrimSpace(string(token)), nil
|
||||
}
|
||||
|
|
|
@ -0,0 +1,303 @@
|
|||
package cacheboltdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
"github.com/hashicorp/go-hclog"
|
||||
wrapping "github.com/hashicorp/go-kms-wrapping"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
bolt "go.etcd.io/bbolt"
|
||||
)
|
||||
|
||||
const (
|
||||
// Keep track of schema version for future migrations
|
||||
storageVersionKey = "version"
|
||||
storageVersion = "1"
|
||||
|
||||
// DatabaseFileName - filename for the persistent cache file
|
||||
DatabaseFileName = "vault-agent-cache.db"
|
||||
|
||||
// metaBucketName - naming the meta bucket that holds the version and
|
||||
// bootstrapping keys
|
||||
metaBucketName = "meta"
|
||||
|
||||
// SecretLeaseType - Bucket/type for leases with secret info
|
||||
SecretLeaseType = "secret-lease"
|
||||
|
||||
// AuthLeaseType - Bucket/type for leases with auth info
|
||||
AuthLeaseType = "auth-lease"
|
||||
|
||||
// TokenType - Bucket/type for auto-auth tokens
|
||||
TokenType = "token"
|
||||
|
||||
// AutoAuthToken - key for the latest auto-auth token
|
||||
AutoAuthToken = "auto-auth-token"
|
||||
|
||||
// RetrievalTokenMaterial is the actual key or token in the key bucket
|
||||
RetrievalTokenMaterial = "retrieval-token-material"
|
||||
)
|
||||
|
||||
// BoltStorage is a persistent cache using a bolt db. Items are organized with
|
||||
// the version and bootstrapping items in the "meta" bucket, and tokens, auth
|
||||
// leases, and secret leases in their own buckets.
|
||||
type BoltStorage struct {
|
||||
db *bolt.DB
|
||||
logger hclog.Logger
|
||||
wrapper wrapping.Wrapper
|
||||
aad string
|
||||
}
|
||||
|
||||
// BoltStorageConfig is the collection of input parameters for setting up bolt
|
||||
// storage
|
||||
type BoltStorageConfig struct {
|
||||
Path string
|
||||
Logger hclog.Logger
|
||||
Wrapper wrapping.Wrapper
|
||||
AAD string
|
||||
}
|
||||
|
||||
// NewBoltStorage opens a new bolt db at the specified file path and returns it.
|
||||
// If the db already exists the buckets will just be created if they don't
|
||||
// exist.
|
||||
func NewBoltStorage(config *BoltStorageConfig) (*BoltStorage, error) {
|
||||
dbPath := filepath.Join(config.Path, DatabaseFileName)
|
||||
db, err := bolt.Open(dbPath, 0600, &bolt.Options{Timeout: 1 * time.Second})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = db.Update(func(tx *bolt.Tx) error {
|
||||
return createBoltSchema(tx)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
bs := &BoltStorage{
|
||||
db: db,
|
||||
logger: config.Logger,
|
||||
wrapper: config.Wrapper,
|
||||
aad: config.AAD,
|
||||
}
|
||||
return bs, nil
|
||||
}
|
||||
|
||||
func createBoltSchema(tx *bolt.Tx) error {
|
||||
// create the meta bucket at the top level
|
||||
meta, err := tx.CreateBucketIfNotExists([]byte(metaBucketName))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create bucket %s: %w", metaBucketName, err)
|
||||
}
|
||||
// check and set file version in the meta bucket
|
||||
version := meta.Get([]byte(storageVersionKey))
|
||||
switch {
|
||||
case version == nil:
|
||||
err = meta.Put([]byte(storageVersionKey), []byte(storageVersion))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set storage version: %w", err)
|
||||
}
|
||||
case string(version) != storageVersion:
|
||||
return fmt.Errorf("storage migration from %s to %s not implemented", string(version), storageVersion)
|
||||
}
|
||||
|
||||
// create the buckets for tokens and leases
|
||||
_, err = tx.CreateBucketIfNotExists([]byte(TokenType))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create token bucket: %w", err)
|
||||
}
|
||||
_, err = tx.CreateBucketIfNotExists([]byte(AuthLeaseType))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create auth lease bucket: %w", err)
|
||||
}
|
||||
_, err = tx.CreateBucketIfNotExists([]byte(SecretLeaseType))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create secret lease bucket: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Set an index (token or lease) in bolt storage
|
||||
func (b *BoltStorage) Set(ctx context.Context, id string, plaintext []byte, indexType string) error {
|
||||
blob, err := b.wrapper.Encrypt(ctx, plaintext, []byte(b.aad))
|
||||
if err != nil {
|
||||
return fmt.Errorf("error encrypting %s index: %w", indexType, err)
|
||||
}
|
||||
|
||||
protoBlob, err := proto.Marshal(blob)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return b.db.Update(func(tx *bolt.Tx) error {
|
||||
s := tx.Bucket([]byte(indexType))
|
||||
if s == nil {
|
||||
return fmt.Errorf("bucket %q not found", indexType)
|
||||
}
|
||||
// If this is an auto-auth token, also stash it in the meta bucket for
|
||||
// easy retrieval upon restore
|
||||
if indexType == TokenType {
|
||||
meta := tx.Bucket([]byte(metaBucketName))
|
||||
if err := meta.Put([]byte(AutoAuthToken), protoBlob); err != nil {
|
||||
return fmt.Errorf("failed to set latest auto-auth token: %w", err)
|
||||
}
|
||||
}
|
||||
return s.Put([]byte(id), protoBlob)
|
||||
})
|
||||
}
|
||||
|
||||
func getBucketIDs(b *bolt.Bucket) ([][]byte, error) {
|
||||
ids := [][]byte{}
|
||||
err := b.ForEach(func(k, v []byte) error {
|
||||
ids = append(ids, k)
|
||||
return nil
|
||||
})
|
||||
return ids, err
|
||||
}
|
||||
|
||||
// Delete an index (token or lease) by id from bolt storage
|
||||
func (b *BoltStorage) Delete(id string) error {
|
||||
return b.db.Update(func(tx *bolt.Tx) error {
|
||||
// Since Delete returns a nil error if the key doesn't exist, just call
|
||||
// delete in all three index buckets without checking existence first
|
||||
if err := tx.Bucket([]byte(TokenType)).Delete([]byte(id)); err != nil {
|
||||
return fmt.Errorf("failed to delete %q from token bucket: %w", id, err)
|
||||
}
|
||||
if err := tx.Bucket([]byte(AuthLeaseType)).Delete([]byte(id)); err != nil {
|
||||
return fmt.Errorf("failed to delete %q from auth lease bucket: %w", id, err)
|
||||
}
|
||||
if err := tx.Bucket([]byte(SecretLeaseType)).Delete([]byte(id)); err != nil {
|
||||
return fmt.Errorf("failed to delete %q from secret lease bucket: %w", id, err)
|
||||
}
|
||||
b.logger.Trace("deleted index from bolt db", "id", id)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (b *BoltStorage) decrypt(ctx context.Context, ciphertext []byte) ([]byte, error) {
|
||||
var blob wrapping.EncryptedBlobInfo
|
||||
if err := proto.Unmarshal(ciphertext, &blob); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return b.wrapper.Decrypt(ctx, &blob, []byte(b.aad))
|
||||
}
|
||||
|
||||
// GetByType returns a list of stored items of the specified type
|
||||
func (b *BoltStorage) GetByType(ctx context.Context, indexType string) ([][]byte, error) {
|
||||
var returnBytes [][]byte
|
||||
|
||||
err := b.db.View(func(tx *bolt.Tx) error {
|
||||
var errors *multierror.Error
|
||||
|
||||
tx.Bucket([]byte(indexType)).ForEach(func(id, ciphertext []byte) error {
|
||||
plaintext, err := b.decrypt(ctx, ciphertext)
|
||||
if err != nil {
|
||||
errors = multierror.Append(errors, fmt.Errorf("error decrypting index id %s: %w", id, err))
|
||||
return nil
|
||||
}
|
||||
|
||||
returnBytes = append(returnBytes, plaintext)
|
||||
return nil
|
||||
})
|
||||
return errors.ErrorOrNil()
|
||||
})
|
||||
|
||||
return returnBytes, err
|
||||
}
|
||||
|
||||
// GetAutoAuthToken retrieves the latest auto-auth token, and returns nil if non
|
||||
// exists yet
|
||||
func (b *BoltStorage) GetAutoAuthToken(ctx context.Context) ([]byte, error) {
|
||||
var encryptedToken []byte
|
||||
|
||||
err := b.db.View(func(tx *bolt.Tx) error {
|
||||
meta := tx.Bucket([]byte(metaBucketName))
|
||||
if meta == nil {
|
||||
return fmt.Errorf("bucket %q not found", metaBucketName)
|
||||
}
|
||||
encryptedToken = meta.Get([]byte(AutoAuthToken))
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if encryptedToken == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
plaintext, err := b.decrypt(ctx, encryptedToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decrypt auto-auth token: %w", err)
|
||||
}
|
||||
return plaintext, nil
|
||||
}
|
||||
|
||||
// GetRetrievalToken retrieves a plaintext token from the KeyBucket, which will
|
||||
// be used by the key manager to retrieve the encryption key, nil if none set
|
||||
func (b *BoltStorage) GetRetrievalToken() ([]byte, error) {
|
||||
var token []byte
|
||||
|
||||
err := b.db.View(func(tx *bolt.Tx) error {
|
||||
keyBucket := tx.Bucket([]byte(metaBucketName))
|
||||
if keyBucket == nil {
|
||||
return fmt.Errorf("bucket %q not found", metaBucketName)
|
||||
}
|
||||
token = keyBucket.Get([]byte(RetrievalTokenMaterial))
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return token, err
|
||||
}
|
||||
|
||||
// StoreRetrievalToken sets plaintext token material in the RetrievalTokenBucket
|
||||
func (b *BoltStorage) StoreRetrievalToken(token []byte) error {
|
||||
return b.db.Update(func(tx *bolt.Tx) error {
|
||||
bucket := tx.Bucket([]byte(metaBucketName))
|
||||
if bucket == nil {
|
||||
return fmt.Errorf("bucket %q not found", metaBucketName)
|
||||
}
|
||||
return bucket.Put([]byte(RetrievalTokenMaterial), token)
|
||||
})
|
||||
}
|
||||
|
||||
// Close the boltdb
|
||||
func (b *BoltStorage) Close() error {
|
||||
b.logger.Trace("closing bolt db", "path", b.db.Path())
|
||||
return b.db.Close()
|
||||
}
|
||||
|
||||
// Clear the boltdb by deleting all the token and lease buckets and recreating
|
||||
// the schema/layout
|
||||
func (b *BoltStorage) Clear() error {
|
||||
return b.db.Update(func(tx *bolt.Tx) error {
|
||||
for _, name := range []string{AuthLeaseType, SecretLeaseType, TokenType} {
|
||||
b.logger.Trace("deleting bolt bucket", "name", name)
|
||||
if err := tx.DeleteBucket([]byte(name)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return createBoltSchema(tx)
|
||||
})
|
||||
}
|
||||
|
||||
// DBFileExists checks whether the vault agent cache file at `filePath` exists
|
||||
func DBFileExists(path string) (bool, error) {
|
||||
checkFile, err := os.OpenFile(filepath.Join(path, DatabaseFileName), os.O_RDWR, 0600)
|
||||
defer checkFile.Close()
|
||||
switch {
|
||||
case err == nil:
|
||||
return true, nil
|
||||
case os.IsNotExist(err):
|
||||
return false, nil
|
||||
default:
|
||||
return false, fmt.Errorf("failed to check if bolt file exists at path %s: %w", path, err)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,263 @@
|
|||
package cacheboltdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path"
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/go-hclog"
|
||||
"github.com/hashicorp/vault/command/agent/cache/keymanager"
|
||||
"github.com/ory/dockertest/v3/docker/pkg/ioutils"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func getTestKeyManager(t *testing.T) keymanager.KeyManager {
|
||||
t.Helper()
|
||||
|
||||
km, err := keymanager.NewPassthroughKeyManager(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
return km
|
||||
}
|
||||
|
||||
func TestBolt_SetGet(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
path, err := ioutils.TempDir("", "bolt-test")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(path)
|
||||
|
||||
b, err := NewBoltStorage(&BoltStorageConfig{
|
||||
Path: path,
|
||||
Logger: hclog.Default(),
|
||||
Wrapper: getTestKeyManager(t).Wrapper(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
secrets, err := b.GetByType(ctx, SecretLeaseType)
|
||||
assert.NoError(t, err)
|
||||
require.Len(t, secrets, 0)
|
||||
|
||||
err = b.Set(ctx, "test1", []byte("hello"), SecretLeaseType)
|
||||
assert.NoError(t, err)
|
||||
secrets, err = b.GetByType(ctx, SecretLeaseType)
|
||||
assert.NoError(t, err)
|
||||
require.Len(t, secrets, 1)
|
||||
assert.Equal(t, []byte("hello"), secrets[0])
|
||||
}
|
||||
|
||||
func TestBoltDelete(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
path, err := ioutils.TempDir("", "bolt-test")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(path)
|
||||
|
||||
b, err := NewBoltStorage(&BoltStorageConfig{
|
||||
Path: path,
|
||||
Logger: hclog.Default(),
|
||||
Wrapper: getTestKeyManager(t).Wrapper(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = b.Set(ctx, "secret-test1", []byte("hello1"), SecretLeaseType)
|
||||
require.NoError(t, err)
|
||||
err = b.Set(ctx, "secret-test2", []byte("hello2"), SecretLeaseType)
|
||||
require.NoError(t, err)
|
||||
|
||||
secrets, err := b.GetByType(ctx, SecretLeaseType)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, secrets, 2)
|
||||
assert.ElementsMatch(t, [][]byte{[]byte("hello1"), []byte("hello2")}, secrets)
|
||||
|
||||
err = b.Delete("secret-test1")
|
||||
require.NoError(t, err)
|
||||
secrets, err = b.GetByType(ctx, SecretLeaseType)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, secrets, 1)
|
||||
assert.Equal(t, []byte("hello2"), secrets[0])
|
||||
}
|
||||
|
||||
func TestBoltClear(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
path, err := ioutils.TempDir("", "bolt-test")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(path)
|
||||
|
||||
b, err := NewBoltStorage(&BoltStorageConfig{
|
||||
Path: path,
|
||||
Logger: hclog.Default(),
|
||||
Wrapper: getTestKeyManager(t).Wrapper(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Populate the bolt db
|
||||
err = b.Set(ctx, "secret-test1", []byte("hello"), SecretLeaseType)
|
||||
require.NoError(t, err)
|
||||
secrets, err := b.GetByType(ctx, SecretLeaseType)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, secrets, 1)
|
||||
assert.Equal(t, []byte("hello"), secrets[0])
|
||||
|
||||
err = b.Set(ctx, "auth-test1", []byte("hello"), AuthLeaseType)
|
||||
require.NoError(t, err)
|
||||
auths, err := b.GetByType(ctx, AuthLeaseType)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, auths, 1)
|
||||
assert.Equal(t, []byte("hello"), auths[0])
|
||||
|
||||
err = b.Set(ctx, "token-test1", []byte("hello"), TokenType)
|
||||
require.NoError(t, err)
|
||||
tokens, err := b.GetByType(ctx, TokenType)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, tokens, 1)
|
||||
assert.Equal(t, []byte("hello"), tokens[0])
|
||||
|
||||
// Clear the bolt db, and check that it's indeed clear
|
||||
err = b.Clear()
|
||||
require.NoError(t, err)
|
||||
secrets, err = b.GetByType(ctx, SecretLeaseType)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, secrets, 0)
|
||||
auths, err = b.GetByType(ctx, AuthLeaseType)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, auths, 0)
|
||||
tokens, err = b.GetByType(ctx, TokenType)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, tokens, 0)
|
||||
}
|
||||
|
||||
func TestBoltSetAutoAuthToken(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
path, err := ioutils.TempDir("", "bolt-test")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(path)
|
||||
|
||||
b, err := NewBoltStorage(&BoltStorageConfig{
|
||||
Path: path,
|
||||
Logger: hclog.Default(),
|
||||
Wrapper: getTestKeyManager(t).Wrapper(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
token, err := b.GetAutoAuthToken(ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, token)
|
||||
|
||||
// set first token
|
||||
err = b.Set(ctx, "token-test1", []byte("hello 1"), TokenType)
|
||||
require.NoError(t, err)
|
||||
secrets, err := b.GetByType(ctx, TokenType)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, secrets, 1)
|
||||
assert.Equal(t, []byte("hello 1"), secrets[0])
|
||||
token, err = b.GetAutoAuthToken(ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []byte("hello 1"), token)
|
||||
|
||||
// set second token
|
||||
err = b.Set(ctx, "token-test2", []byte("hello 2"), TokenType)
|
||||
require.NoError(t, err)
|
||||
secrets, err = b.GetByType(ctx, TokenType)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, secrets, 2)
|
||||
assert.ElementsMatch(t, [][]byte{[]byte("hello 1"), []byte("hello 2")}, secrets)
|
||||
token, err = b.GetAutoAuthToken(ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []byte("hello 2"), token)
|
||||
}
|
||||
|
||||
func TestDBFileExists(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
mkDir bool
|
||||
createFile bool
|
||||
expectExist bool
|
||||
}{
|
||||
{
|
||||
name: "all exists",
|
||||
mkDir: true,
|
||||
createFile: true,
|
||||
expectExist: true,
|
||||
},
|
||||
{
|
||||
name: "dir exist, file missing",
|
||||
mkDir: true,
|
||||
createFile: false,
|
||||
expectExist: false,
|
||||
},
|
||||
{
|
||||
name: "all missing",
|
||||
mkDir: false,
|
||||
createFile: false,
|
||||
expectExist: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
var tmpPath string
|
||||
var err error
|
||||
if tc.mkDir {
|
||||
tmpPath, err = ioutil.TempDir("", "test-db-path")
|
||||
require.NoError(t, err)
|
||||
}
|
||||
if tc.createFile {
|
||||
err = ioutil.WriteFile(path.Join(tmpPath, DatabaseFileName), []byte("test-db-path"), 0600)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
exists, err := DBFileExists(tmpPath)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tc.expectExist, exists)
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func Test_SetGetRetrievalToken(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
tokenToSet []byte
|
||||
expectedToken []byte
|
||||
}{
|
||||
{
|
||||
name: "normal set and get",
|
||||
tokenToSet: []byte("test token"),
|
||||
expectedToken: []byte("test token"),
|
||||
},
|
||||
{
|
||||
name: "no token set",
|
||||
tokenToSet: nil,
|
||||
expectedToken: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
path, err := ioutils.TempDir("", "bolt-test")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(path)
|
||||
|
||||
b, err := NewBoltStorage(&BoltStorageConfig{
|
||||
Path: path,
|
||||
Logger: hclog.Default(),
|
||||
Wrapper: getTestKeyManager(t).Wrapper(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer b.Close()
|
||||
|
||||
if tc.tokenToSet != nil {
|
||||
err := b.StoreRetrievalToken(tc.tokenToSet)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
gotKey, err := b.GetRetrievalToken()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tc.expectedToken, gotKey)
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,6 +1,11 @@
|
|||
package cachememdb
|
||||
|
||||
import "context"
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Index holds the response to be cached along with multiple other values that
|
||||
// serve as pointers to refer back to this index.
|
||||
|
@ -48,6 +53,21 @@ type Index struct {
|
|||
// goroutine that manages the renewal of the secret belonging to the
|
||||
// response in this index.
|
||||
RenewCtxInfo *ContextInfo
|
||||
|
||||
// RequestMethod is the HTTP method of the request
|
||||
RequestMethod string
|
||||
|
||||
// RequestToken is the token used in the request
|
||||
RequestToken string
|
||||
|
||||
// RequestHeader is the header used in the request
|
||||
RequestHeader http.Header
|
||||
|
||||
// LastRenewed is the timestamp of last renewal
|
||||
LastRenewed time.Time
|
||||
|
||||
// Type is the index type (token, auth-lease, secret-lease)
|
||||
Type string
|
||||
}
|
||||
|
||||
type IndexName uint32
|
||||
|
@ -106,3 +126,25 @@ func NewContextInfo(ctx context.Context) *ContextInfo {
|
|||
ctxInfo.DoneCh = make(chan struct{})
|
||||
return ctxInfo
|
||||
}
|
||||
|
||||
// Serialize returns a json marshal'ed Index object, without the RenewCtxInfo
|
||||
func (i Index) Serialize() ([]byte, error) {
|
||||
i.RenewCtxInfo = nil
|
||||
|
||||
indexBytes, err := json.Marshal(i)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return indexBytes, nil
|
||||
}
|
||||
|
||||
// Deserialize converts json bytes to an Index object
|
||||
// Note: RenewCtxInfo will need to be reconstructed elsewhere.
|
||||
func Deserialize(indexBytes []byte) (*Index, error) {
|
||||
index := new(Index)
|
||||
if err := json.Unmarshal(indexBytes, index); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return index, nil
|
||||
}
|
||||
|
|
|
@ -0,0 +1,43 @@
|
|||
package cachememdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSerializeDeserialize(t *testing.T) {
|
||||
|
||||
testIndex := &Index{
|
||||
ID: "testid",
|
||||
Token: "testtoken",
|
||||
TokenParent: "parent token",
|
||||
TokenAccessor: "test accessor",
|
||||
Namespace: "test namespace",
|
||||
RequestPath: "/test/path",
|
||||
Lease: "lease id",
|
||||
LeaseToken: "lease token id",
|
||||
Response: []byte(`{"something": "here"}`),
|
||||
RenewCtxInfo: NewContextInfo(context.Background()),
|
||||
RequestMethod: "GET",
|
||||
RequestToken: "request token",
|
||||
RequestHeader: http.Header{
|
||||
"X-Test": []string{"vault", "agent"},
|
||||
},
|
||||
LastRenewed: time.Now().UTC(),
|
||||
}
|
||||
indexBytes, err := testIndex.Serialize()
|
||||
require.NoError(t, err)
|
||||
assert.True(t, len(indexBytes) > 0)
|
||||
assert.NotNil(t, testIndex.RenewCtxInfo, "Serialize should not modify original Index object")
|
||||
|
||||
restoredIndex, err := Deserialize(indexBytes)
|
||||
require.NoError(t, err)
|
||||
|
||||
testIndex.RenewCtxInfo = nil
|
||||
assert.Equal(t, testIndex, restoredIndex, "They should be equal without RenewCtxInfo set on the original")
|
||||
}
|
|
@ -1,18 +0,0 @@
|
|||
package crypto
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
const (
|
||||
KeyID = "root"
|
||||
)
|
||||
|
||||
type KeyManager interface {
|
||||
GetKey() []byte
|
||||
GetPersistentKey() ([]byte, error)
|
||||
Renewable() bool
|
||||
Renewer(context.Context) error
|
||||
Encrypt(context.Context, []byte, []byte) ([]byte, error)
|
||||
Decrypt(context.Context, []byte, []byte) ([]byte, error)
|
||||
}
|
|
@ -1,97 +0,0 @@
|
|||
package crypto
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
|
||||
wrapping "github.com/hashicorp/go-kms-wrapping"
|
||||
"github.com/hashicorp/go-kms-wrapping/wrappers/aead"
|
||||
)
|
||||
|
||||
var _ KeyManager = (*KubeEncryptionKey)(nil)
|
||||
|
||||
type KubeEncryptionKey struct {
|
||||
renewable bool
|
||||
wrapper *aead.Wrapper
|
||||
}
|
||||
|
||||
// NewK8s returns a new instance of the Kube encryption key. Kubernetes
|
||||
// encryption keys aren't renewable.
|
||||
func NewK8s(existingKey []byte) (*KubeEncryptionKey, error) {
|
||||
k := &KubeEncryptionKey{
|
||||
renewable: false,
|
||||
wrapper: aead.NewWrapper(nil),
|
||||
}
|
||||
|
||||
k.wrapper.SetConfig(map[string]string{"key_id": KeyID})
|
||||
|
||||
var rootKey []byte = nil
|
||||
if len(existingKey) != 0 {
|
||||
if len(existingKey) != 32 {
|
||||
return k, fmt.Errorf("invalid key size, should be 32, got %d", len(existingKey))
|
||||
}
|
||||
rootKey = existingKey
|
||||
}
|
||||
|
||||
if rootKey == nil {
|
||||
newKey := make([]byte, 32)
|
||||
_, err := rand.Read(newKey)
|
||||
if err != nil {
|
||||
return k, err
|
||||
}
|
||||
rootKey = newKey
|
||||
}
|
||||
|
||||
if err := k.wrapper.SetAESGCMKeyBytes(rootKey); err != nil {
|
||||
return k, err
|
||||
}
|
||||
|
||||
return k, nil
|
||||
}
|
||||
|
||||
// GetKey returns the encryption key in a format optimized for storage.
|
||||
// In k8s we store the key as is, so just return the key stored.
|
||||
func (k *KubeEncryptionKey) GetKey() []byte {
|
||||
return k.wrapper.GetKeyBytes()
|
||||
}
|
||||
|
||||
// GetPersistentKey returns the key which should be stored in the persisent
|
||||
// cache. In k8s we store the key as is, so just return the key stored.
|
||||
func (k *KubeEncryptionKey) GetPersistentKey() ([]byte, error) {
|
||||
return k.wrapper.GetKeyBytes(), nil
|
||||
}
|
||||
|
||||
// Renewable lets the caller know if this encryption key type is
|
||||
// renewable. In Kubernetes the key isn't renewable.
|
||||
func (k *KubeEncryptionKey) Renewable() bool {
|
||||
return k.renewable
|
||||
}
|
||||
|
||||
// Renewer is used when the encryption key type is renewable. Since Kubernetes
|
||||
// keys aren't renewable, returning nothing.
|
||||
func (k *KubeEncryptionKey) Renewer(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Encrypt takes plaintext values and encrypts them using the store key and additional
|
||||
// data. For Kubernetes the AAD should be the service account JWT.
|
||||
func (k *KubeEncryptionKey) Encrypt(ctx context.Context, plaintext, aad []byte) ([]byte, error) {
|
||||
blob, err := k.wrapper.Encrypt(ctx, plaintext, aad)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return blob.Ciphertext, nil
|
||||
}
|
||||
|
||||
// Decrypt takes ciphertext and AAD values and returns the decrypted value. For Kubernetes the AAD
|
||||
// should be the service account JWT.
|
||||
func (k *KubeEncryptionKey) Decrypt(ctx context.Context, ciphertext, aad []byte) ([]byte, error) {
|
||||
blob := &wrapping.EncryptedBlobInfo{
|
||||
Ciphertext: ciphertext,
|
||||
KeyInfo: &wrapping.KeyInfo{
|
||||
KeyID: KeyID,
|
||||
},
|
||||
}
|
||||
return k.wrapper.Decrypt(ctx, blob, aad)
|
||||
}
|
|
@ -1,155 +0,0 @@
|
|||
package crypto
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCrypto_KubernetesNewKey(t *testing.T) {
|
||||
k8sKey, err := NewK8s([]byte{})
|
||||
if err != nil {
|
||||
t.Fatalf(fmt.Sprintf("unexpected error: %s", err))
|
||||
}
|
||||
|
||||
key := k8sKey.GetKey()
|
||||
if key == nil {
|
||||
t.Fatalf(fmt.Sprintf("key is nil, it shouldn't be: %s", key))
|
||||
}
|
||||
|
||||
persistentKey, _ := k8sKey.GetPersistentKey()
|
||||
if persistentKey == nil {
|
||||
t.Fatalf(fmt.Sprintf("key is nil, it shouldn't be: %s", persistentKey))
|
||||
}
|
||||
|
||||
if string(key) != string(persistentKey) {
|
||||
t.Fatalf("keys don't match, they should: key: %s, persistentKey: %s", key, persistentKey)
|
||||
}
|
||||
|
||||
plaintextInput := []byte("test")
|
||||
aad := []byte("kubernetes")
|
||||
|
||||
ciphertext, err := k8sKey.Encrypt(nil, plaintextInput, aad)
|
||||
if err != nil {
|
||||
t.Fatalf(err.Error())
|
||||
}
|
||||
|
||||
if ciphertext == nil {
|
||||
t.Fatalf("ciphertext nil, it shouldn't be")
|
||||
}
|
||||
|
||||
plaintext, err := k8sKey.Decrypt(nil, ciphertext, aad)
|
||||
if err != nil {
|
||||
t.Fatalf(err.Error())
|
||||
}
|
||||
|
||||
if string(plaintext) != string(plaintextInput) {
|
||||
t.Fatalf("expected %s, got %s", plaintextInput, plaintext)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCrypto_KubernetesExistingKey(t *testing.T) {
|
||||
rootKey := make([]byte, 32)
|
||||
n, err := rand.Read(rootKey)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if n != 32 {
|
||||
t.Fatal(n)
|
||||
}
|
||||
|
||||
k8sKey, err := NewK8s(rootKey)
|
||||
if err != nil {
|
||||
t.Fatalf(fmt.Sprintf("unexpected error: %s", err))
|
||||
}
|
||||
|
||||
key := k8sKey.GetKey()
|
||||
if key == nil {
|
||||
t.Fatalf(fmt.Sprintf("key is nil, it shouldn't be: %s", key))
|
||||
}
|
||||
|
||||
if string(key) != string(rootKey) {
|
||||
t.Fatalf(fmt.Sprintf("expected keys to be the same, they weren't: expected: %s, got: %s", rootKey, key))
|
||||
}
|
||||
|
||||
persistentKey, _ := k8sKey.GetPersistentKey()
|
||||
if persistentKey == nil {
|
||||
t.Fatalf("key is nil, it shouldn't be")
|
||||
}
|
||||
|
||||
if string(persistentKey) != string(rootKey) {
|
||||
t.Fatalf(fmt.Sprintf("expected keys to be the same, they weren't: expected: %s, got: %s", rootKey, persistentKey))
|
||||
}
|
||||
|
||||
if string(key) != string(persistentKey) {
|
||||
t.Fatalf(fmt.Sprintf("expected keys to be the same, they weren't: %s %s", rootKey, persistentKey))
|
||||
}
|
||||
|
||||
plaintextInput := []byte("test")
|
||||
aad := []byte("kubernetes")
|
||||
|
||||
ciphertext, err := k8sKey.Encrypt(nil, plaintextInput, aad)
|
||||
if err != nil {
|
||||
t.Fatalf(err.Error())
|
||||
}
|
||||
|
||||
if ciphertext == nil {
|
||||
t.Fatalf("ciphertext nil, it shouldn't be")
|
||||
}
|
||||
|
||||
plaintext, err := k8sKey.Decrypt(nil, ciphertext, aad)
|
||||
if err != nil {
|
||||
t.Fatalf(err.Error())
|
||||
}
|
||||
|
||||
if string(plaintext) != string(plaintextInput) {
|
||||
t.Fatalf("expected %s, got %s", plaintextInput, plaintext)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCrypto_KubernetesPassGeneratedKey(t *testing.T) {
|
||||
k8sFirstKey, err := NewK8s([]byte{})
|
||||
if err != nil {
|
||||
t.Fatalf(fmt.Sprintf("unexpected error: %s", err))
|
||||
}
|
||||
|
||||
firstPersistentKey := k8sFirstKey.GetKey()
|
||||
if firstPersistentKey == nil {
|
||||
t.Fatalf(fmt.Sprintf("key is nil, it shouldn't be: %s", firstPersistentKey))
|
||||
}
|
||||
|
||||
plaintextInput := []byte("test")
|
||||
aad := []byte("kubernetes")
|
||||
|
||||
ciphertext, err := k8sFirstKey.Encrypt(nil, plaintextInput, aad)
|
||||
if err != nil {
|
||||
t.Fatalf(err.Error())
|
||||
}
|
||||
|
||||
if ciphertext == nil {
|
||||
t.Fatalf("ciphertext nil, it shouldn't be")
|
||||
}
|
||||
|
||||
k8sLoadedKey, err := NewK8s(firstPersistentKey)
|
||||
if err != nil {
|
||||
t.Fatalf(fmt.Sprintf("unexpected error: %s", err))
|
||||
}
|
||||
|
||||
loadedKey, _ := k8sLoadedKey.GetPersistentKey()
|
||||
if loadedKey == nil {
|
||||
t.Fatalf(fmt.Sprintf("key is nil, it shouldn't be: %s", loadedKey))
|
||||
}
|
||||
|
||||
if string(loadedKey) != string(firstPersistentKey) {
|
||||
t.Fatalf(fmt.Sprintf("keys do not match"))
|
||||
}
|
||||
|
||||
plaintext, err := k8sLoadedKey.Decrypt(nil, ciphertext, aad)
|
||||
if err != nil {
|
||||
t.Fatalf(err.Error())
|
||||
}
|
||||
|
||||
if string(plaintext) != string(plaintextInput) {
|
||||
t.Fatalf("expected %s, got %s", plaintextInput, plaintext)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,16 @@
|
|||
package keymanager
|
||||
|
||||
import wrapping "github.com/hashicorp/go-kms-wrapping"
|
||||
|
||||
const (
|
||||
KeyID = "root"
|
||||
)
|
||||
|
||||
type KeyManager interface {
|
||||
// Returns a wrapping.Wrapper which can be used to perform key-related operations.
|
||||
Wrapper() wrapping.Wrapper
|
||||
// RetrievalToken is the material returned which can be used to source back the
|
||||
// encryption key. Depending on the implementation, the token can be the
|
||||
// encryption key itself or a token/identifier used to exchange the token.
|
||||
RetrievalToken() ([]byte, error)
|
||||
}
|
|
@ -0,0 +1,68 @@
|
|||
package keymanager
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
|
||||
wrapping "github.com/hashicorp/go-kms-wrapping"
|
||||
"github.com/hashicorp/go-kms-wrapping/wrappers/aead"
|
||||
)
|
||||
|
||||
var _ KeyManager = (*PassthroughKeyManager)(nil)
|
||||
|
||||
type PassthroughKeyManager struct {
|
||||
wrapper *aead.Wrapper
|
||||
}
|
||||
|
||||
// NewPassthroughKeyManager returns a new instance of the Kube encryption key.
|
||||
// If a key is provided, it will be used as the encryption key for the wrapper,
|
||||
// otherwise one will be generated.
|
||||
func NewPassthroughKeyManager(key []byte) (*PassthroughKeyManager, error) {
|
||||
|
||||
var rootKey []byte = nil
|
||||
switch len(key) {
|
||||
case 0:
|
||||
newKey := make([]byte, 32)
|
||||
_, err := rand.Read(newKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rootKey = newKey
|
||||
case 32:
|
||||
rootKey = key
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid key size, should be 32, got %d", len(key))
|
||||
}
|
||||
|
||||
wrapper := aead.NewWrapper(nil)
|
||||
|
||||
if _, err := wrapper.SetConfig(map[string]string{"key_id": KeyID}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := wrapper.SetAESGCMKeyBytes(rootKey); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
k := &PassthroughKeyManager{
|
||||
wrapper: wrapper,
|
||||
}
|
||||
|
||||
return k, nil
|
||||
}
|
||||
|
||||
// Wrapper returns the manager's wrapper for key operations.
|
||||
func (w *PassthroughKeyManager) Wrapper() wrapping.Wrapper {
|
||||
return w.wrapper
|
||||
}
|
||||
|
||||
// RetrievalToken returns the key that was used on the wrapper since this key
|
||||
// manager is simply a passthrough and does not provide a mechanism to abstract
|
||||
// this key.
|
||||
func (w *PassthroughKeyManager) RetrievalToken() ([]byte, error) {
|
||||
if w.wrapper == nil {
|
||||
return nil, fmt.Errorf("unable to get wrapper for token retrieval")
|
||||
}
|
||||
|
||||
return w.wrapper.GetKeyBytes(), nil
|
||||
}
|
|
@ -0,0 +1,56 @@
|
|||
package keymanager
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestKeyManager_PassthrougKeyManager(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
key []byte
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
"new key",
|
||||
nil,
|
||||
false,
|
||||
},
|
||||
{
|
||||
"existing valid key",
|
||||
[]byte("e679e2f3d8d0e489d408bc617c6890d6"),
|
||||
false,
|
||||
},
|
||||
{
|
||||
"invalid key length",
|
||||
[]byte("foobar"),
|
||||
true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
m, err := NewPassthroughKeyManager(tc.key)
|
||||
if tc.wantErr {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
if w := m.Wrapper(); w == nil {
|
||||
t.Fatalf("expected non-nil wrapper from the key manager")
|
||||
}
|
||||
|
||||
token, err := m.RetrievalToken()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to retrieve token: %s", err)
|
||||
}
|
||||
|
||||
if len(tc.key) != 0 && !bytes.Equal(tc.key, token) {
|
||||
t.Fatalf("expected key bytes: %x, got: %x", tc.key, token)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -11,6 +11,7 @@ import (
|
|||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
@ -18,6 +19,7 @@ import (
|
|||
"github.com/hashicorp/errwrap"
|
||||
hclog "github.com/hashicorp/go-hclog"
|
||||
"github.com/hashicorp/vault/api"
|
||||
"github.com/hashicorp/vault/command/agent/cache/cacheboltdb"
|
||||
cachememdb "github.com/hashicorp/vault/command/agent/cache/cachememdb"
|
||||
"github.com/hashicorp/vault/helper/namespace"
|
||||
nshelper "github.com/hashicorp/vault/helper/namespace"
|
||||
|
@ -83,6 +85,13 @@ type LeaseCache struct {
|
|||
|
||||
// inflightCache keeps track of inflight requests
|
||||
inflightCache *gocache.Cache
|
||||
|
||||
// ps is the persistent storage for tokens and leases
|
||||
ps *cacheboltdb.BoltStorage
|
||||
|
||||
// shuttingDown is used to determine if cache needs to be evicted or not
|
||||
// when the context is cancelled
|
||||
shuttingDown atomic.Bool
|
||||
}
|
||||
|
||||
// LeaseCacheConfig is the configuration for initializing a new
|
||||
|
@ -92,6 +101,7 @@ type LeaseCacheConfig struct {
|
|||
BaseContext context.Context
|
||||
Proxier Proxier
|
||||
Logger hclog.Logger
|
||||
Storage *cacheboltdb.BoltStorage
|
||||
}
|
||||
|
||||
type inflightRequest struct {
|
||||
|
@ -141,9 +151,21 @@ func NewLeaseCache(conf *LeaseCacheConfig) (*LeaseCache, error) {
|
|||
l: &sync.RWMutex{},
|
||||
idLocks: locksutil.CreateLocks(),
|
||||
inflightCache: gocache.New(gocache.NoExpiration, gocache.NoExpiration),
|
||||
ps: conf.Storage,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SetShuttingDown is a setter for the shuttingDown field
|
||||
func (c *LeaseCache) SetShuttingDown(in bool) {
|
||||
c.shuttingDown.Store(in)
|
||||
}
|
||||
|
||||
// SetPersistentStorage is a setter for the persistent storage field in
|
||||
// LeaseCache
|
||||
func (c *LeaseCache) SetPersistentStorage(storageIn *cacheboltdb.BoltStorage) {
|
||||
c.ps = storageIn
|
||||
}
|
||||
|
||||
// checkCacheForRequest checks the cache for a particular request based on its
|
||||
// computed ID. It returns a non-nil *SendResponse if an entry is found.
|
||||
func (c *LeaseCache) checkCacheForRequest(id string) (*SendResponse, error) {
|
||||
|
@ -275,6 +297,7 @@ func (c *LeaseCache) Send(ctx context.Context, req *SendRequest) (*SendResponse,
|
|||
ID: id,
|
||||
Namespace: namespace,
|
||||
RequestPath: req.Request.URL.Path,
|
||||
LastRenewed: time.Now().UTC(),
|
||||
}
|
||||
|
||||
secret, err := api.ParseSecret(bytes.NewReader(resp.ResponseBody))
|
||||
|
@ -332,6 +355,8 @@ func (c *LeaseCache) Send(ctx context.Context, req *SendRequest) (*SendResponse,
|
|||
index.Lease = secret.LeaseID
|
||||
index.LeaseToken = req.Token
|
||||
|
||||
index.Type = cacheboltdb.SecretLeaseType
|
||||
|
||||
case secret.Auth != nil:
|
||||
c.logger.Debug("processing auth response", "method", req.Request.Method, "path", req.Request.URL.Path)
|
||||
|
||||
|
@ -360,6 +385,8 @@ func (c *LeaseCache) Send(ctx context.Context, req *SendRequest) (*SendResponse,
|
|||
index.Token = secret.Auth.ClientToken
|
||||
index.TokenAccessor = secret.Auth.Accessor
|
||||
|
||||
index.Type = cacheboltdb.AuthLeaseType
|
||||
|
||||
default:
|
||||
// We shouldn't be hitting this, but will err on the side of caution and
|
||||
// simply proxy.
|
||||
|
@ -394,9 +421,14 @@ func (c *LeaseCache) Send(ctx context.Context, req *SendRequest) (*SendResponse,
|
|||
DoneCh: renewCtxInfo.DoneCh,
|
||||
}
|
||||
|
||||
// Add extra information necessary for restoring from persisted cache
|
||||
index.RequestMethod = req.Request.Method
|
||||
index.RequestToken = req.Token
|
||||
index.RequestHeader = req.Request.Header
|
||||
|
||||
// Store the index in the cache
|
||||
c.logger.Debug("storing response into the cache", "method", req.Request.Method, "path", req.Request.URL.Path)
|
||||
err = c.db.Set(index)
|
||||
err = c.Set(ctx, index)
|
||||
if err != nil {
|
||||
c.logger.Error("failed to cache the proxied response", "error", err)
|
||||
return nil, err
|
||||
|
@ -420,8 +452,12 @@ func (c *LeaseCache) createCtxInfo(ctx context.Context) *cachememdb.ContextInfo
|
|||
func (c *LeaseCache) startRenewing(ctx context.Context, index *cachememdb.Index, req *SendRequest, secret *api.Secret) {
|
||||
defer func() {
|
||||
id := ctx.Value(contextIndexID).(string)
|
||||
if c.shuttingDown.Load() {
|
||||
c.logger.Trace("not evicting index from cache during shutdown", "id", id, "method", req.Request.Method, "path", req.Request.URL.Path)
|
||||
return
|
||||
}
|
||||
c.logger.Debug("evicting index from cache", "id", id, "method", req.Request.Method, "path", req.Request.URL.Path)
|
||||
err := c.db.Evict(cachememdb.IndexNameID, id)
|
||||
err := c.Evict(id)
|
||||
if err != nil {
|
||||
c.logger.Error("failed to evict index", "id", id, "error", err)
|
||||
return
|
||||
|
@ -466,6 +502,11 @@ func (c *LeaseCache) startRenewing(ctx context.Context, index *cachememdb.Index,
|
|||
return
|
||||
case <-watcher.RenewCh():
|
||||
c.logger.Debug("secret renewed", "path", req.Request.URL.Path)
|
||||
if c.ps != nil {
|
||||
if err := c.updateLastRenewed(ctx, index, time.Now().UTC()); err != nil {
|
||||
c.logger.Warn("not able to update lastRenewed time for cached index", "id", index.ID)
|
||||
}
|
||||
}
|
||||
case <-index.RenewCtxInfo.DoneCh:
|
||||
// This case indicates the renewal process to shutdown and evict
|
||||
// the cache entry. This is triggered when a specific secret
|
||||
|
@ -477,6 +518,22 @@ func (c *LeaseCache) startRenewing(ctx context.Context, index *cachememdb.Index,
|
|||
}
|
||||
}
|
||||
|
||||
func (c *LeaseCache) updateLastRenewed(ctx context.Context, index *cachememdb.Index, t time.Time) error {
|
||||
idLock := locksutil.LockForKey(c.idLocks, index.ID)
|
||||
idLock.Lock()
|
||||
defer idLock.Unlock()
|
||||
|
||||
getIndex, err := c.db.Get(cachememdb.IndexNameID, index.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
index.LastRenewed = t
|
||||
if err := c.Set(ctx, getIndex); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// computeIndexID results in a value that uniquely identifies a request
|
||||
// received by the agent. It does so by SHA256 hashing the serialized request
|
||||
// object containing the request path, query parameters and body parameters.
|
||||
|
@ -642,8 +699,8 @@ func (c *LeaseCache) handleCacheClear(ctx context.Context, in *cacheClearInput)
|
|||
}
|
||||
c.l.Unlock()
|
||||
|
||||
// Reset the memdb instance
|
||||
if err := c.db.Flush(); err != nil {
|
||||
// Reset the memdb instance (and persistent storage if enabled)
|
||||
if err := c.Flush(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -850,6 +907,213 @@ func (c *LeaseCache) handleRevocationRequest(ctx context.Context, req *SendReque
|
|||
return true, nil
|
||||
}
|
||||
|
||||
// Set stores the index in the cachememdb, and also stores it in the persistent
|
||||
// cache (if enabled)
|
||||
func (c *LeaseCache) Set(ctx context.Context, index *cachememdb.Index) error {
|
||||
if err := c.db.Set(index); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if c.ps != nil {
|
||||
b, err := index.Serialize()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := c.ps.Set(ctx, index.ID, b, index.Type); err != nil {
|
||||
return err
|
||||
}
|
||||
c.logger.Trace("set entry in persistent storage", "type", index.Type, "path", index.RequestPath, "id", index.ID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Evict removes an Index from the cachememdb, and also removes it from the
|
||||
// persistent cache (if enabled)
|
||||
func (c *LeaseCache) Evict(id string) error {
|
||||
if err := c.db.Evict(cachememdb.IndexNameID, id); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if c.ps != nil {
|
||||
if err := c.ps.Delete(id); err != nil {
|
||||
return err
|
||||
}
|
||||
c.logger.Trace("deleted item from persistent storage", "id", id)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Flush the cachememdb and persistent cache (if enabled)
|
||||
func (c *LeaseCache) Flush() error {
|
||||
if err := c.db.Flush(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if c.ps != nil {
|
||||
c.logger.Trace("clearing persistent storage")
|
||||
return c.ps.Clear()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Restore loads the cachememdb from the persistent storage passed in. Loads
|
||||
// tokens first, since restoring a lease's renewal context and watcher requires
|
||||
// looking up the token in the cachememdb.
|
||||
func (c *LeaseCache) Restore(ctx context.Context, storage *cacheboltdb.BoltStorage) error {
|
||||
// Process tokens first
|
||||
tokens, err := storage.GetByType(ctx, cacheboltdb.TokenType)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := c.restoreTokens(tokens); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Then process auth leases
|
||||
authLeases, err := storage.GetByType(ctx, cacheboltdb.AuthLeaseType)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := c.restoreLeases(authLeases); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Then process secret leases
|
||||
secretLeases, err := storage.GetByType(ctx, cacheboltdb.SecretLeaseType)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := c.restoreLeases(secretLeases); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *LeaseCache) restoreTokens(tokens [][]byte) error {
|
||||
for _, token := range tokens {
|
||||
newIndex, err := cachememdb.Deserialize(token)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
newIndex.RenewCtxInfo = c.createCtxInfo(nil)
|
||||
if err := c.db.Set(newIndex); err != nil {
|
||||
return err
|
||||
}
|
||||
c.logger.Trace("restored token", "id", newIndex.ID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *LeaseCache) restoreLeases(leases [][]byte) error {
|
||||
for _, lease := range leases {
|
||||
newIndex, err := cachememdb.Deserialize(lease)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Check if this lease has already expired
|
||||
expired, err := c.hasExpired(time.Now().UTC(), newIndex)
|
||||
if err != nil {
|
||||
c.logger.Warn("failed to check if lease is expired", "id", newIndex.ID, "error", err)
|
||||
}
|
||||
if expired {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := c.restoreLeaseRenewCtx(newIndex); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := c.db.Set(newIndex); err != nil {
|
||||
return err
|
||||
}
|
||||
c.logger.Trace("restored lease", "id", newIndex.ID, "path", newIndex.RequestPath)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// restoreLeaseRenewCtx re-creates a RenewCtx for an index object and starts
|
||||
// the watcher go routine
|
||||
func (c *LeaseCache) restoreLeaseRenewCtx(index *cachememdb.Index) error {
|
||||
if index.Response == nil {
|
||||
return fmt.Errorf("cached response was nil for %s", index.ID)
|
||||
}
|
||||
|
||||
// Parse the secret to determine which type it is
|
||||
reader := bufio.NewReader(bytes.NewReader(index.Response))
|
||||
resp, err := http.ReadResponse(reader, nil)
|
||||
if err != nil {
|
||||
c.logger.Error("failed to deserialize response", "error", err)
|
||||
return err
|
||||
}
|
||||
secret, err := api.ParseSecret(resp.Body)
|
||||
if err != nil {
|
||||
c.logger.Error("failed to parse response as secret", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
var renewCtxInfo *cachememdb.ContextInfo
|
||||
switch {
|
||||
case secret.LeaseID != "":
|
||||
entry, err := c.db.Get(cachememdb.IndexNameToken, index.RequestToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if entry == nil {
|
||||
return fmt.Errorf("could not find parent Token %s for req path %s", index.RequestToken, index.RequestPath)
|
||||
}
|
||||
|
||||
// Derive a context for renewal using the token's context
|
||||
renewCtxInfo = cachememdb.NewContextInfo(entry.RenewCtxInfo.Ctx)
|
||||
|
||||
case secret.Auth != nil:
|
||||
var parentCtx context.Context
|
||||
if !secret.Auth.Orphan {
|
||||
entry, err := c.db.Get(cachememdb.IndexNameToken, index.RequestToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// If parent token is not managed by the agent, child shouldn't be
|
||||
// either.
|
||||
if entry == nil {
|
||||
return fmt.Errorf("could not find parent Token %s for req path %s", index.RequestToken, index.RequestPath)
|
||||
}
|
||||
|
||||
c.logger.Debug("setting parent context", "method", index.RequestMethod, "path", index.RequestPath)
|
||||
parentCtx = entry.RenewCtxInfo.Ctx
|
||||
}
|
||||
renewCtxInfo = c.createCtxInfo(parentCtx)
|
||||
default:
|
||||
return fmt.Errorf("unknown cached index item: %s", index.ID)
|
||||
}
|
||||
|
||||
renewCtx := context.WithValue(renewCtxInfo.Ctx, contextIndexID, index.ID)
|
||||
index.RenewCtxInfo = &cachememdb.ContextInfo{
|
||||
Ctx: renewCtx,
|
||||
CancelFunc: renewCtxInfo.CancelFunc,
|
||||
DoneCh: renewCtxInfo.DoneCh,
|
||||
}
|
||||
|
||||
sendReq := &SendRequest{
|
||||
Token: index.RequestToken,
|
||||
Request: &http.Request{
|
||||
Header: index.RequestHeader,
|
||||
Method: index.RequestMethod,
|
||||
URL: &url.URL{
|
||||
Path: index.RequestPath,
|
||||
},
|
||||
},
|
||||
}
|
||||
go c.startRenewing(renewCtx, index, sendReq, secret)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// deriveNamespaceAndRevocationPath returns the namespace and relative path for
|
||||
// revocation paths.
|
||||
//
|
||||
|
@ -912,9 +1176,11 @@ func (c *LeaseCache) RegisterAutoAuthToken(token string) error {
|
|||
return err
|
||||
}
|
||||
|
||||
// If the index is found, defer its cancelFunc
|
||||
// If the index is found, just keep it in the cache and ignore the incoming
|
||||
// token (since they're the same)
|
||||
if oldIndex != nil {
|
||||
defer oldIndex.RenewCtxInfo.CancelFunc()
|
||||
c.logger.Trace("auto-auth token already exists in cache; no need to store it again")
|
||||
return nil
|
||||
}
|
||||
|
||||
// The following randomly generated values are required for index stored by
|
||||
|
@ -938,6 +1204,7 @@ func (c *LeaseCache) RegisterAutoAuthToken(token string) error {
|
|||
Token: token,
|
||||
Namespace: namespace,
|
||||
RequestPath: requestPath,
|
||||
Type: cacheboltdb.TokenType,
|
||||
}
|
||||
|
||||
// Derive a context off of the lease cache's base context
|
||||
|
@ -951,7 +1218,7 @@ func (c *LeaseCache) RegisterAutoAuthToken(token string) error {
|
|||
|
||||
// Store the index in the cache
|
||||
c.logger.Debug("storing auto-auth token into the cache")
|
||||
err = c.db.Set(index)
|
||||
err = c.Set(c.baseCtxInfo.Ctx, index)
|
||||
if err != nil {
|
||||
c.logger.Error("failed to cache the auto-auth token", "error", err)
|
||||
return err
|
||||
|
@ -997,3 +1264,32 @@ func parseCacheClearInput(req *cacheClearRequest) (*cacheClearInput, error) {
|
|||
|
||||
return in, nil
|
||||
}
|
||||
|
||||
func (c *LeaseCache) hasExpired(currentTime time.Time, index *cachememdb.Index) (bool, error) {
|
||||
reader := bufio.NewReader(bytes.NewReader(index.Response))
|
||||
resp, err := http.ReadResponse(reader, nil)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to deserialize response: %w", err)
|
||||
}
|
||||
secret, err := api.ParseSecret(resp.Body)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to parse response as secret: %w", err)
|
||||
}
|
||||
|
||||
elapsed := currentTime.Sub(index.LastRenewed)
|
||||
var leaseDuration int
|
||||
switch index.Type {
|
||||
case cacheboltdb.AuthLeaseType:
|
||||
leaseDuration = secret.Auth.LeaseDuration
|
||||
case cacheboltdb.SecretLeaseType:
|
||||
leaseDuration = secret.LeaseDuration
|
||||
default:
|
||||
return false, fmt.Errorf("index type %q unexpected in expiration check", index.Type)
|
||||
}
|
||||
|
||||
if int(elapsed.Seconds()) > leaseDuration {
|
||||
c.logger.Trace("secret has expired", "id", index.ID, "elapsed", elapsed, "lease duration", leaseDuration)
|
||||
return true, nil
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
|
|
@ -3,9 +3,11 @@ package cache
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"os"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
|
@ -15,9 +17,13 @@ import (
|
|||
"github.com/go-test/deep"
|
||||
hclog "github.com/hashicorp/go-hclog"
|
||||
"github.com/hashicorp/vault/api"
|
||||
"github.com/hashicorp/vault/command/agent/cache/cacheboltdb"
|
||||
"github.com/hashicorp/vault/command/agent/cache/cachememdb"
|
||||
"github.com/hashicorp/vault/command/agent/cache/keymanager"
|
||||
"github.com/hashicorp/vault/sdk/helper/consts"
|
||||
"github.com/hashicorp/vault/sdk/helper/logging"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/atomic"
|
||||
)
|
||||
|
||||
|
@ -63,6 +69,24 @@ func testNewLeaseCacheWithDelay(t *testing.T, cacheable bool, delay int) *LeaseC
|
|||
return lc
|
||||
}
|
||||
|
||||
func testNewLeaseCacheWithPersistence(t *testing.T, responses []*SendResponse, storage *cacheboltdb.BoltStorage) *LeaseCache {
|
||||
t.Helper()
|
||||
|
||||
client, err := api.NewClient(api.DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
|
||||
lc, err := NewLeaseCache(&LeaseCacheConfig{
|
||||
Client: client,
|
||||
BaseContext: context.Background(),
|
||||
Proxier: newMockProxier(responses),
|
||||
Logger: logging.NewVaultLogger(hclog.Trace).Named("cache.leasecache"),
|
||||
Storage: storage,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
return lc
|
||||
}
|
||||
|
||||
func TestCache_ComputeIndexID(t *testing.T) {
|
||||
type args struct {
|
||||
req *http.Request
|
||||
|
@ -649,3 +673,394 @@ func TestLeaseCache_Concurrent_Cacheable(t *testing.T) {
|
|||
t.Fatalf("Should have returned a cached response 99 times, got %d", cacheCount.Load())
|
||||
}
|
||||
}
|
||||
|
||||
func setupBoltStorage(t *testing.T) (tempCacheDir string, boltStorage *cacheboltdb.BoltStorage) {
|
||||
t.Helper()
|
||||
|
||||
km, err := keymanager.NewPassthroughKeyManager(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
tempCacheDir, err = ioutil.TempDir("", "agent-cache-test")
|
||||
require.NoError(t, err)
|
||||
boltStorage, err = cacheboltdb.NewBoltStorage(&cacheboltdb.BoltStorageConfig{
|
||||
Path: tempCacheDir,
|
||||
Logger: hclog.Default(),
|
||||
Wrapper: km.Wrapper(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, boltStorage)
|
||||
// The calling function should `defer boltStorage.Close()` and `defer os.RemoveAll(tempCacheDir)`
|
||||
return tempCacheDir, boltStorage
|
||||
}
|
||||
|
||||
func TestLeaseCache_PersistAndRestore(t *testing.T) {
|
||||
// Emulate 4 responses from the api proxy. The first two use the auto-auth
|
||||
// token, and the last two use another token.
|
||||
responses := []*SendResponse{
|
||||
newTestSendResponse(200, `{"auth": {"client_token": "testtoken", "renewable": true, "lease_duration": 600}}`),
|
||||
newTestSendResponse(201, `{"lease_id": "foo", "renewable": true, "data": {"value": "foo"}, "lease_duration": 600}`),
|
||||
newTestSendResponse(202, `{"auth": {"client_token": "testtoken2", "renewable": true, "orphan": true, "lease_duration": 600}}`),
|
||||
newTestSendResponse(203, `{"lease_id": "secret2-lease", "renewable": true, "data": {"number": "two"}, "lease_duration": 600}`),
|
||||
}
|
||||
|
||||
tempDir, boltStorage := setupBoltStorage(t)
|
||||
defer os.RemoveAll(tempDir)
|
||||
defer boltStorage.Close()
|
||||
lc := testNewLeaseCacheWithPersistence(t, responses, boltStorage)
|
||||
|
||||
// Register an auto-auth token so that the token and lease requests are cached
|
||||
lc.RegisterAutoAuthToken("autoauthtoken")
|
||||
|
||||
cacheTests := []struct {
|
||||
token string
|
||||
method string
|
||||
urlPath string
|
||||
body string
|
||||
wantStatusCode int
|
||||
}{
|
||||
{
|
||||
// Make a request. A response with a new token is returned to the
|
||||
// lease cache and that will be cached.
|
||||
token: "autoauthtoken",
|
||||
method: "GET",
|
||||
urlPath: "http://example.com/v1/sample/api",
|
||||
body: `{"value": "input"}`,
|
||||
wantStatusCode: responses[0].Response.StatusCode,
|
||||
},
|
||||
{
|
||||
// Modify the request a little bit to ensure the second response is
|
||||
// returned to the lease cache.
|
||||
token: "autoauthtoken",
|
||||
method: "GET",
|
||||
urlPath: "http://example.com/v1/sample/api",
|
||||
body: `{"value": "input_changed"}`,
|
||||
wantStatusCode: responses[1].Response.StatusCode,
|
||||
},
|
||||
{
|
||||
// Simulate an approle login to get another token
|
||||
method: "PUT",
|
||||
urlPath: "http://example.com/v1/auth/approle/login",
|
||||
body: `{"role_id": "my role", "secret_id": "my secret"}`,
|
||||
wantStatusCode: responses[2].Response.StatusCode,
|
||||
},
|
||||
{
|
||||
// Test caching with the token acquired from the approle login
|
||||
token: "testtoken2",
|
||||
method: "GET",
|
||||
urlPath: "http://example.com/v1/sample2/api",
|
||||
body: `{"second": "input"}`,
|
||||
wantStatusCode: responses[3].Response.StatusCode,
|
||||
},
|
||||
}
|
||||
|
||||
for _, ct := range cacheTests {
|
||||
// Send once to cache
|
||||
sendReq := &SendRequest{
|
||||
Token: ct.token,
|
||||
Request: httptest.NewRequest(ct.method, ct.urlPath, strings.NewReader(ct.body)),
|
||||
}
|
||||
resp, err := lc.Send(context.Background(), sendReq)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, resp.Response.StatusCode, ct.wantStatusCode, "expected proxied response")
|
||||
assert.Nil(t, resp.CacheMeta)
|
||||
|
||||
// Send again to test cache. If this isn't cached, the response returned
|
||||
// will be the next in the list and the status code will not match.
|
||||
sendCacheReq := &SendRequest{
|
||||
Token: ct.token,
|
||||
Request: httptest.NewRequest(ct.method, ct.urlPath, strings.NewReader(ct.body)),
|
||||
}
|
||||
respCached, err := lc.Send(context.Background(), sendCacheReq)
|
||||
require.NoError(t, err, "failed to send request %+v", ct)
|
||||
assert.Equal(t, respCached.Response.StatusCode, ct.wantStatusCode, "expected proxied response")
|
||||
require.NotNil(t, respCached.CacheMeta)
|
||||
assert.True(t, respCached.CacheMeta.Hit)
|
||||
}
|
||||
|
||||
// Now we know the cache is working, so try restoring from the persisted
|
||||
// cache's storage
|
||||
restoredCache := testNewLeaseCache(t, nil)
|
||||
|
||||
err := restoredCache.Restore(context.Background(), boltStorage)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Now compare before and after
|
||||
beforeDB, err := lc.db.GetByPrefix(cachememdb.IndexNameID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, beforeDB, 5)
|
||||
|
||||
for _, cachedItem := range beforeDB {
|
||||
restoredItem, err := restoredCache.db.Get(cachememdb.IndexNameID, cachedItem.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, cachedItem.ID, restoredItem.ID)
|
||||
assert.Equal(t, cachedItem.Lease, restoredItem.Lease)
|
||||
assert.Equal(t, cachedItem.LeaseToken, restoredItem.LeaseToken)
|
||||
assert.Equal(t, cachedItem.Namespace, restoredItem.Namespace)
|
||||
assert.Equal(t, cachedItem.RequestHeader, restoredItem.RequestHeader)
|
||||
assert.Equal(t, cachedItem.RequestMethod, restoredItem.RequestMethod)
|
||||
assert.Equal(t, cachedItem.RequestPath, restoredItem.RequestPath)
|
||||
assert.Equal(t, cachedItem.RequestToken, restoredItem.RequestToken)
|
||||
assert.Equal(t, cachedItem.Response, restoredItem.Response)
|
||||
assert.Equal(t, cachedItem.Token, restoredItem.Token)
|
||||
assert.Equal(t, cachedItem.TokenAccessor, restoredItem.TokenAccessor)
|
||||
assert.Equal(t, cachedItem.TokenParent, restoredItem.TokenParent)
|
||||
|
||||
// check what we can in the renewal context
|
||||
assert.NotEmpty(t, restoredItem.RenewCtxInfo.CancelFunc)
|
||||
assert.NotZero(t, restoredItem.RenewCtxInfo.DoneCh)
|
||||
require.NotEmpty(t, restoredItem.RenewCtxInfo.Ctx)
|
||||
assert.Equal(t,
|
||||
cachedItem.RenewCtxInfo.Ctx.Value(contextIndexID),
|
||||
restoredItem.RenewCtxInfo.Ctx.Value(contextIndexID),
|
||||
)
|
||||
}
|
||||
afterDB, err := restoredCache.db.GetByPrefix(cachememdb.IndexNameID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, afterDB, 5)
|
||||
|
||||
// And finally send the cache requests once to make sure they're all being
|
||||
// served from the restoredCache
|
||||
for _, ct := range cacheTests {
|
||||
sendCacheReq := &SendRequest{
|
||||
Token: ct.token,
|
||||
Request: httptest.NewRequest(ct.method, ct.urlPath, strings.NewReader(ct.body)),
|
||||
}
|
||||
respCached, err := restoredCache.Send(context.Background(), sendCacheReq)
|
||||
require.NoError(t, err, "failed to send request %+v", ct)
|
||||
assert.Equal(t, respCached.Response.StatusCode, ct.wantStatusCode, "expected proxied response")
|
||||
require.NotNil(t, respCached.CacheMeta)
|
||||
assert.True(t, respCached.CacheMeta.Hit)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEvictPersistent(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
responses := []*SendResponse{
|
||||
newTestSendResponse(201, `{"lease_id": "foo", "renewable": true, "data": {"value": "foo"}}`),
|
||||
}
|
||||
|
||||
tempDir, boltStorage := setupBoltStorage(t)
|
||||
defer os.RemoveAll(tempDir)
|
||||
defer boltStorage.Close()
|
||||
lc := testNewLeaseCacheWithPersistence(t, responses, boltStorage)
|
||||
|
||||
lc.RegisterAutoAuthToken("autoauthtoken")
|
||||
|
||||
// populate cache by sending request through
|
||||
sendReq := &SendRequest{
|
||||
Token: "autoauthtoken",
|
||||
Request: httptest.NewRequest("GET", "http://example.com/v1/sample/api", strings.NewReader(`{"value": "some_input"}`)),
|
||||
}
|
||||
resp, err := lc.Send(context.Background(), sendReq)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, resp.Response.StatusCode, 201, "expected proxied response")
|
||||
assert.Nil(t, resp.CacheMeta)
|
||||
|
||||
// Check bolt for the cached lease
|
||||
secrets, err := lc.ps.GetByType(ctx, cacheboltdb.SecretLeaseType)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, secrets, 1)
|
||||
|
||||
// Call clear for the request path
|
||||
err = lc.handleCacheClear(context.Background(), &cacheClearInput{
|
||||
Type: "request_path",
|
||||
RequestPath: "/v1/sample/api",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
// Check that cached item is gone
|
||||
secrets, err = lc.ps.GetByType(ctx, cacheboltdb.SecretLeaseType)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, secrets, 0)
|
||||
}
|
||||
|
||||
func TestRegisterAutoAuth_sameToken(t *testing.T) {
|
||||
// If the auto-auth token already exists in the cache, it should not be
|
||||
// stored again in a new index.
|
||||
lc := testNewLeaseCache(t, nil)
|
||||
err := lc.RegisterAutoAuthToken("autoauthtoken")
|
||||
assert.NoError(t, err)
|
||||
|
||||
oldTokenIndex, err := lc.db.Get(cachememdb.IndexNameToken, "autoauthtoken")
|
||||
assert.NoError(t, err)
|
||||
oldTokenID := oldTokenIndex.ID
|
||||
|
||||
// register the same token again
|
||||
err = lc.RegisterAutoAuthToken("autoauthtoken")
|
||||
assert.NoError(t, err)
|
||||
|
||||
// check that there's only one index for autoauthtoken
|
||||
entries, err := lc.db.GetByPrefix(cachememdb.IndexNameToken, "autoauthtoken")
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, entries, 1)
|
||||
|
||||
newTokenIndex, err := lc.db.Get(cachememdb.IndexNameToken, "autoauthtoken")
|
||||
assert.NoError(t, err)
|
||||
|
||||
// compare the ID's since those are randomly generated when an index for a
|
||||
// token is added to the cache, so if a new token was added, the id's will
|
||||
// not match.
|
||||
assert.Equal(t, oldTokenID, newTokenIndex.ID)
|
||||
}
|
||||
|
||||
func Test_hasExpired(t *testing.T) {
|
||||
|
||||
responses := []*SendResponse{
|
||||
newTestSendResponse(200, `{"auth": {"client_token": "testtoken", "renewable": true, "lease_duration": 60}}`),
|
||||
newTestSendResponse(201, `{"lease_id": "foo", "renewable": true, "data": {"value": "foo"}, "lease_duration": 60}`),
|
||||
}
|
||||
lc := testNewLeaseCache(t, responses)
|
||||
lc.RegisterAutoAuthToken("autoauthtoken")
|
||||
|
||||
cacheTests := []struct {
|
||||
token string
|
||||
urlPath string
|
||||
leaseType string
|
||||
wantStatusCode int
|
||||
}{
|
||||
{
|
||||
// auth lease
|
||||
token: "autoauthtoken",
|
||||
urlPath: "/v1/sample/auth",
|
||||
leaseType: cacheboltdb.AuthLeaseType,
|
||||
wantStatusCode: responses[0].Response.StatusCode,
|
||||
},
|
||||
{
|
||||
// secret lease
|
||||
token: "autoauthtoken",
|
||||
urlPath: "/v1/sample/secret",
|
||||
leaseType: cacheboltdb.SecretLeaseType,
|
||||
wantStatusCode: responses[1].Response.StatusCode,
|
||||
},
|
||||
}
|
||||
|
||||
for _, ct := range cacheTests {
|
||||
// Send once to cache
|
||||
urlPath := "http://example.com" + ct.urlPath
|
||||
sendReq := &SendRequest{
|
||||
Token: ct.token,
|
||||
Request: httptest.NewRequest("GET", urlPath, strings.NewReader(`{"value": "input"}`)),
|
||||
}
|
||||
resp, err := lc.Send(context.Background(), sendReq)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, resp.Response.StatusCode, ct.wantStatusCode, "expected proxied response")
|
||||
assert.Nil(t, resp.CacheMeta)
|
||||
|
||||
// get the Index out of the mem cache
|
||||
index, err := lc.db.Get(cachememdb.IndexNameRequestPath, "root/", ct.urlPath)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, ct.leaseType, index.Type)
|
||||
|
||||
// The lease duration is 60 seconds, so time.Now() should be within that
|
||||
notExpired, err := lc.hasExpired(time.Now().UTC(), index)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, notExpired)
|
||||
|
||||
// In 90 seconds the index should be "expired"
|
||||
futureTime := time.Now().UTC().Add(time.Second * 90)
|
||||
expired, err := lc.hasExpired(futureTime, index)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, expired)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestLeaseCache_hasExpired_wrong_type(t *testing.T) {
|
||||
index := &cachememdb.Index{
|
||||
Type: cacheboltdb.TokenType,
|
||||
Response: []byte(`HTTP/0.0 200 OK
|
||||
Content-Type: application/json
|
||||
Date: Tue, 02 Mar 2021 17:54:16 GMT
|
||||
|
||||
{"auth": {"client_token": "testtoken", "renewable": true, "lease_duration": 60}}`),
|
||||
}
|
||||
|
||||
lc := testNewLeaseCache(t, nil)
|
||||
expired, err := lc.hasExpired(time.Now().UTC(), index)
|
||||
assert.False(t, expired)
|
||||
assert.EqualError(t, err, `index type "token" unexpected in expiration check`)
|
||||
}
|
||||
|
||||
func TestLeaseCacheRestore_expired(t *testing.T) {
|
||||
// Emulate 2 responses from the api proxy, both expired
|
||||
responses := []*SendResponse{
|
||||
newTestSendResponse(200, `{"auth": {"client_token": "testtoken", "renewable": true, "lease_duration": -600}}`),
|
||||
newTestSendResponse(201, `{"lease_id": "foo", "renewable": true, "data": {"value": "foo"}, "lease_duration": -600}`),
|
||||
}
|
||||
|
||||
tempDir, boltStorage := setupBoltStorage(t)
|
||||
defer os.RemoveAll(tempDir)
|
||||
defer boltStorage.Close()
|
||||
lc := testNewLeaseCacheWithPersistence(t, responses, boltStorage)
|
||||
|
||||
// Register an auto-auth token so that the token and lease requests are cached in mem
|
||||
lc.RegisterAutoAuthToken("autoauthtoken")
|
||||
|
||||
cacheTests := []struct {
|
||||
token string
|
||||
method string
|
||||
urlPath string
|
||||
body string
|
||||
wantStatusCode int
|
||||
}{
|
||||
{
|
||||
// Make a request. A response with a new token is returned to the
|
||||
// lease cache and that will be cached.
|
||||
token: "autoauthtoken",
|
||||
method: "GET",
|
||||
urlPath: "http://example.com/v1/sample/api",
|
||||
body: `{"value": "input"}`,
|
||||
wantStatusCode: responses[0].Response.StatusCode,
|
||||
},
|
||||
{
|
||||
// Modify the request a little bit to ensure the second response is
|
||||
// returned to the lease cache.
|
||||
token: "autoauthtoken",
|
||||
method: "GET",
|
||||
urlPath: "http://example.com/v1/sample/api",
|
||||
body: `{"value": "input_changed"}`,
|
||||
wantStatusCode: responses[1].Response.StatusCode,
|
||||
},
|
||||
}
|
||||
|
||||
for _, ct := range cacheTests {
|
||||
// Send once to cache
|
||||
sendReq := &SendRequest{
|
||||
Token: ct.token,
|
||||
Request: httptest.NewRequest(ct.method, ct.urlPath, strings.NewReader(ct.body)),
|
||||
}
|
||||
resp, err := lc.Send(context.Background(), sendReq)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, resp.Response.StatusCode, ct.wantStatusCode, "expected proxied response")
|
||||
assert.Nil(t, resp.CacheMeta)
|
||||
}
|
||||
|
||||
// Restore from the persisted cache's storage
|
||||
restoredCache := testNewLeaseCache(t, nil)
|
||||
|
||||
err := restoredCache.Restore(context.Background(), boltStorage)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// The original mem cache should have all three items
|
||||
beforeDB, err := lc.db.GetByPrefix(cachememdb.IndexNameID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, beforeDB, 3)
|
||||
|
||||
// There should only be one item in the restored cache: the autoauth token
|
||||
afterDB, err := restoredCache.db.GetByPrefix(cachememdb.IndexNameID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, afterDB, 1)
|
||||
|
||||
// Just verify that the one item in the restored mem cache matches one in the original mem cache, and that it's the auto-auth token
|
||||
beforeItem, err := lc.db.Get(cachememdb.IndexNameID, afterDB[0].ID)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, beforeItem)
|
||||
|
||||
assert.Equal(t, "autoauthtoken", afterDB[0].Token)
|
||||
assert.Equal(t, cacheboltdb.TokenType, afterDB[0].Type)
|
||||
}
|
||||
|
|
|
@ -50,6 +50,16 @@ type Cache struct {
|
|||
ForceAutoAuthToken bool `hcl:"-"`
|
||||
EnforceConsistency string `hcl:"enforce_consistency"`
|
||||
WhenInconsistent string `hcl:"when_inconsistent"`
|
||||
Persist *Persist `hcl:"persist"`
|
||||
}
|
||||
|
||||
// Persist contains configuration needed for persistent caching
|
||||
type Persist struct {
|
||||
Type string
|
||||
Path string `hcl:"path"`
|
||||
KeepAfterImport bool `hcl:"keep_after_import"`
|
||||
ExitOnErr bool `hcl:"exit_on_err"`
|
||||
ServiceAccountTokenFile string `hcl:"service_account_token_file"`
|
||||
}
|
||||
|
||||
// AutoAuth is the configured authentication method and sinks
|
||||
|
@ -309,8 +319,51 @@ func parseCache(result *Config, list *ast.ObjectList) error {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
result.Cache = &c
|
||||
|
||||
subs, ok := item.Val.(*ast.ObjectType)
|
||||
if !ok {
|
||||
return fmt.Errorf("could not parse %q as an object", name)
|
||||
}
|
||||
subList := subs.List
|
||||
if err := parsePersist(result, subList); err != nil {
|
||||
return fmt.Errorf("error parsing persist: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func parsePersist(result *Config, list *ast.ObjectList) error {
|
||||
name := "persist"
|
||||
|
||||
persistList := list.Filter(name)
|
||||
if len(persistList.Items) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(persistList.Items) > 1 {
|
||||
return fmt.Errorf("only one %q block is required", name)
|
||||
}
|
||||
|
||||
item := persistList.Items[0]
|
||||
|
||||
var p Persist
|
||||
err := hcl.DecodeObject(&p, item.Val)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if p.Type == "" {
|
||||
if len(item.Keys) == 1 {
|
||||
p.Type = strings.ToLower(item.Keys[0].Token.Value().(string))
|
||||
}
|
||||
if p.Type == "" {
|
||||
return errors.New("persist type must be specified")
|
||||
}
|
||||
}
|
||||
|
||||
result.Cache.Persist = &p
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -66,6 +66,13 @@ func TestLoadConfigFile_AgentCache(t *testing.T) {
|
|||
UseAutoAuthToken: true,
|
||||
UseAutoAuthTokenRaw: true,
|
||||
ForceAutoAuthToken: false,
|
||||
Persist: &Persist{
|
||||
Type: "kubernetes",
|
||||
Path: "/vault/agent-cache/",
|
||||
KeepAfterImport: true,
|
||||
ExitOnErr: true,
|
||||
ServiceAccountTokenFile: "/tmp/serviceaccount/token",
|
||||
},
|
||||
},
|
||||
Vault: &Vault{
|
||||
Address: "http://127.0.0.1:1111",
|
||||
|
@ -445,6 +452,52 @@ func TestLoadConfigFile_AgentCache_AutoAuth_False(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestLoadConfigFile_AgentCache_Persist(t *testing.T) {
|
||||
config, err := LoadConfig("./test-fixtures/config-cache-persist-false.hcl")
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
expected := &Config{
|
||||
Cache: &Cache{
|
||||
Persist: &Persist{
|
||||
Type: "kubernetes",
|
||||
Path: "/vault/agent-cache/",
|
||||
KeepAfterImport: false,
|
||||
ExitOnErr: false,
|
||||
ServiceAccountTokenFile: "",
|
||||
},
|
||||
},
|
||||
SharedConfig: &configutil.SharedConfig{
|
||||
PidFile: "./pidfile",
|
||||
Listeners: []*configutil.Listener{
|
||||
{
|
||||
Type: "tcp",
|
||||
Address: "127.0.0.1:8300",
|
||||
TLSDisable: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
config.Listeners[0].RawConfig = nil
|
||||
if diff := deep.Equal(config, expected); diff != nil {
|
||||
t.Fatal(diff)
|
||||
}
|
||||
|
||||
config.Listeners[0].RawConfig = nil
|
||||
if diff := deep.Equal(config, expected); diff != nil {
|
||||
t.Fatal(diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfigFile_AgentCache_PersistMissingType(t *testing.T) {
|
||||
_, err := LoadConfig("./test-fixtures/config-cache-persist-empty-type.hcl")
|
||||
if err == nil || os.IsNotExist(err) {
|
||||
t.Fatal("expected error or file is missing")
|
||||
}
|
||||
}
|
||||
|
||||
// TestLoadConfigFile_Template tests template definitions in Vault Agent
|
||||
func TestLoadConfigFile_Template(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
|
|
|
@ -21,6 +21,12 @@ auto_auth {
|
|||
|
||||
cache {
|
||||
use_auto_auth_token = true
|
||||
persist "kubernetes" {
|
||||
path = "/vault/agent-cache/"
|
||||
keep_after_import = true
|
||||
exit_on_err = true
|
||||
service_account_token_file = "/tmp/serviceaccount/token"
|
||||
}
|
||||
}
|
||||
|
||||
listener {
|
||||
|
|
|
@ -0,0 +1,12 @@
|
|||
pid_file = "./pidfile"
|
||||
|
||||
cache {
|
||||
persist = {
|
||||
path = "/vault/agent-cache/"
|
||||
}
|
||||
}
|
||||
|
||||
listener "tcp" {
|
||||
address = "127.0.0.1:8300"
|
||||
tls_disable = true
|
||||
}
|
|
@ -0,0 +1,14 @@
|
|||
pid_file = "./pidfile"
|
||||
|
||||
cache {
|
||||
persist "kubernetes" {
|
||||
exit_on_err = false
|
||||
keep_after_import = false
|
||||
path = "/vault/agent-cache/"
|
||||
}
|
||||
}
|
||||
|
||||
listener "tcp" {
|
||||
address = "127.0.0.1:8300"
|
||||
tls_disable = true
|
||||
}
|
|
@ -21,6 +21,13 @@ auto_auth {
|
|||
|
||||
cache {
|
||||
use_auto_auth_token = true
|
||||
persist = {
|
||||
type = "kubernetes"
|
||||
path = "/vault/agent-cache/"
|
||||
keep_after_import = true
|
||||
exit_on_err = true
|
||||
service_account_token_file = "/tmp/serviceaccount/token"
|
||||
}
|
||||
}
|
||||
|
||||
listener "unix" {
|
||||
|
|
|
@ -256,7 +256,7 @@ func newRunnerConfig(sc *ServerConfig, templates ctconfig.TemplateConfigs) (*ctc
|
|||
}
|
||||
|
||||
// Use the cache if available or fallback to the Vault server values.
|
||||
if sc.AgentConfig.Cache != nil && len(sc.AgentConfig.Listeners) != 0 {
|
||||
if sc.AgentConfig.Cache != nil && sc.AgentConfig.Cache.Persist != nil && len(sc.AgentConfig.Listeners) != 0 {
|
||||
scheme := "unix://"
|
||||
if sc.AgentConfig.Listeners[0].Type == "tcp" {
|
||||
scheme = "https://"
|
||||
|
|
|
@ -28,7 +28,7 @@ func TestNewServer(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func newAgentConfig(listeners []*configutil.Listener, enableCache bool) *config.Config {
|
||||
func newAgentConfig(listeners []*configutil.Listener, enableCache, enablePersisentCache bool) *config.Config {
|
||||
agentConfig := &config.Config{
|
||||
SharedConfig: &configutil.SharedConfig{
|
||||
PidFile: "./pidfile",
|
||||
|
@ -65,7 +65,13 @@ func newAgentConfig(listeners []*configutil.Listener, enableCache bool) *config.
|
|||
},
|
||||
}
|
||||
if enableCache {
|
||||
agentConfig.Cache = &config.Cache{UseAutoAuthToken: true}
|
||||
agentConfig.Cache = &config.Cache{
|
||||
UseAutoAuthToken: true,
|
||||
}
|
||||
}
|
||||
|
||||
if enablePersisentCache {
|
||||
agentConfig.Cache.Persist = &config.Persist{Type: "kubernetes"}
|
||||
}
|
||||
|
||||
return agentConfig
|
||||
|
@ -94,7 +100,7 @@ func TestCacheConfigUnix(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
agentConfig := newAgentConfig(listeners, true)
|
||||
agentConfig := newAgentConfig(listeners, true, true)
|
||||
serverConfig := ServerConfig{AgentConfig: agentConfig}
|
||||
|
||||
ctConfig, err := newRunnerConfig(&serverConfig, ctconfig.TemplateConfigs{})
|
||||
|
@ -131,7 +137,7 @@ func TestCacheConfigHTTP(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
agentConfig := newAgentConfig(listeners, true)
|
||||
agentConfig := newAgentConfig(listeners, true, true)
|
||||
serverConfig := ServerConfig{AgentConfig: agentConfig}
|
||||
|
||||
ctConfig, err := newRunnerConfig(&serverConfig, ctconfig.TemplateConfigs{})
|
||||
|
@ -168,7 +174,7 @@ func TestCacheConfigHTTPS(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
agentConfig := newAgentConfig(listeners, true)
|
||||
agentConfig := newAgentConfig(listeners, true, true)
|
||||
serverConfig := ServerConfig{AgentConfig: agentConfig}
|
||||
|
||||
ctConfig, err := newRunnerConfig(&serverConfig, ctconfig.TemplateConfigs{})
|
||||
|
@ -209,7 +215,44 @@ func TestCacheConfigNoCache(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
agentConfig := newAgentConfig(listeners, false)
|
||||
agentConfig := newAgentConfig(listeners, false, false)
|
||||
serverConfig := ServerConfig{AgentConfig: agentConfig}
|
||||
|
||||
ctConfig, err := newRunnerConfig(&serverConfig, ctconfig.TemplateConfigs{})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
|
||||
expected := "http://127.0.0.1:1111"
|
||||
if *ctConfig.Vault.Address != expected {
|
||||
t.Fatalf("expected %s, got %s", expected, *ctConfig.Vault.Address)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheConfigNoPersistentCache(t *testing.T) {
|
||||
listeners := []*configutil.Listener{
|
||||
{
|
||||
Type: "tcp",
|
||||
Address: "127.0.0.1:8300",
|
||||
TLSKeyFile: "/path/to/cakey.pem",
|
||||
TLSCertFile: "/path/to/cacert.pem",
|
||||
},
|
||||
{
|
||||
Type: "unix",
|
||||
Address: "foobar",
|
||||
TLSDisable: true,
|
||||
SocketMode: "configmode",
|
||||
SocketUser: "configuser",
|
||||
SocketGroup: "configgroup",
|
||||
},
|
||||
{
|
||||
Type: "tcp",
|
||||
Address: "127.0.0.1:8400",
|
||||
TLSDisable: true,
|
||||
},
|
||||
}
|
||||
|
||||
agentConfig := newAgentConfig(listeners, true, false)
|
||||
serverConfig := ServerConfig{AgentConfig: agentConfig}
|
||||
|
||||
ctConfig, err := newRunnerConfig(&serverConfig, ctconfig.TemplateConfigs{})
|
||||
|
@ -226,7 +269,7 @@ func TestCacheConfigNoCache(t *testing.T) {
|
|||
func TestCacheConfigNoListener(t *testing.T) {
|
||||
listeners := []*configutil.Listener{}
|
||||
|
||||
agentConfig := newAgentConfig(listeners, true)
|
||||
agentConfig := newAgentConfig(listeners, true, true)
|
||||
serverConfig := ServerConfig{AgentConfig: agentConfig}
|
||||
|
||||
ctConfig, err := newRunnerConfig(&serverConfig, ctconfig.TemplateConfigs{})
|
||||
|
@ -264,7 +307,7 @@ func TestCacheConfigRejectMTLS(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
agentConfig := newAgentConfig(listeners, true)
|
||||
agentConfig := newAgentConfig(listeners, true, true)
|
||||
serverConfig := ServerConfig{AgentConfig: agentConfig}
|
||||
|
||||
_, err := newRunnerConfig(&serverConfig, ctconfig.TemplateConfigs{})
|
||||
|
|
Loading…
Reference in New Issue