agent: persistent caching support (#10938)
Adds the option of a write-through cache, backed by boltdb Co-authored-by: Theron Voran <tvoran@users.noreply.github.com> Co-authored-by: Jason O'Donnell <2160810+jasonodonnell@users.noreply.github.com> Co-authored-by: Calvin Leung Huang <cleung2010@gmail.com>
This commit is contained in:
parent
910b45413b
commit
1fdf08b149
|
@ -1,3 +1,3 @@
|
||||||
```release-note:improvement
|
```release-note:improvement
|
||||||
agent: Route templating server through cache when enabled.
|
agent: Route templating server through cache when persistent cache is enabled.
|
||||||
```
|
```
|
||||||
|
|
|
@ -0,0 +1,3 @@
|
||||||
|
```release-note:feature
|
||||||
|
agent: Support for persisting the agent cache to disk
|
||||||
|
```
|
189
command/agent.go
189
command/agent.go
|
@ -6,10 +6,12 @@ import (
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"path"
|
"path"
|
||||||
|
"path/filepath"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
@ -30,6 +32,9 @@ import (
|
||||||
"github.com/hashicorp/vault/command/agent/auth/kerberos"
|
"github.com/hashicorp/vault/command/agent/auth/kerberos"
|
||||||
"github.com/hashicorp/vault/command/agent/auth/kubernetes"
|
"github.com/hashicorp/vault/command/agent/auth/kubernetes"
|
||||||
"github.com/hashicorp/vault/command/agent/cache"
|
"github.com/hashicorp/vault/command/agent/cache"
|
||||||
|
"github.com/hashicorp/vault/command/agent/cache/cacheboltdb"
|
||||||
|
"github.com/hashicorp/vault/command/agent/cache/cachememdb"
|
||||||
|
"github.com/hashicorp/vault/command/agent/cache/keymanager"
|
||||||
agentConfig "github.com/hashicorp/vault/command/agent/config"
|
agentConfig "github.com/hashicorp/vault/command/agent/config"
|
||||||
"github.com/hashicorp/vault/command/agent/sink"
|
"github.com/hashicorp/vault/command/agent/sink"
|
||||||
"github.com/hashicorp/vault/command/agent/sink/file"
|
"github.com/hashicorp/vault/command/agent/sink/file"
|
||||||
|
@ -461,6 +466,8 @@ func (c *AgentCommand) Run(args []string) int {
|
||||||
c.UI.Output("==> Vault agent started! Log data will stream in below:\n")
|
c.UI.Output("==> Vault agent started! Log data will stream in below:\n")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var leaseCache *cache.LeaseCache
|
||||||
|
var previousToken string
|
||||||
// Parse agent listener configurations
|
// Parse agent listener configurations
|
||||||
if config.Cache != nil && len(config.Listeners) != 0 {
|
if config.Cache != nil && len(config.Listeners) != 0 {
|
||||||
cacheLogger := c.logger.Named("cache")
|
cacheLogger := c.logger.Named("cache")
|
||||||
|
@ -479,7 +486,7 @@ func (c *AgentCommand) Run(args []string) int {
|
||||||
|
|
||||||
// Create the lease cache proxier and set its underlying proxier to
|
// Create the lease cache proxier and set its underlying proxier to
|
||||||
// the API proxier.
|
// the API proxier.
|
||||||
leaseCache, err := cache.NewLeaseCache(&cache.LeaseCacheConfig{
|
leaseCache, err = cache.NewLeaseCache(&cache.LeaseCacheConfig{
|
||||||
Client: client,
|
Client: client,
|
||||||
BaseContext: ctx,
|
BaseContext: ctx,
|
||||||
Proxier: apiProxy,
|
Proxier: apiProxy,
|
||||||
|
@ -490,6 +497,152 @@ func (c *AgentCommand) Run(args []string) int {
|
||||||
return 1
|
return 1
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Configure persistent storage and add to LeaseCache
|
||||||
|
if config.Cache.Persist != nil {
|
||||||
|
if config.Cache.Persist.Path == "" {
|
||||||
|
c.UI.Error("must specify persistent cache path")
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set AAD based on key protection type
|
||||||
|
var aad string
|
||||||
|
switch config.Cache.Persist.Type {
|
||||||
|
case "kubernetes":
|
||||||
|
aad, err = getServiceAccountJWT(config.Cache.Persist.ServiceAccountTokenFile)
|
||||||
|
if err != nil {
|
||||||
|
c.UI.Error(fmt.Sprintf("failed to read service account token from %s: %s", config.Cache.Persist.ServiceAccountTokenFile, err))
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
c.UI.Error(fmt.Sprintf("persistent key protection type %q not supported", config.Cache.Persist.Type))
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if bolt file exists already
|
||||||
|
dbFileExists, err := cacheboltdb.DBFileExists(config.Cache.Persist.Path)
|
||||||
|
if err != nil {
|
||||||
|
c.UI.Error(fmt.Sprintf("failed to check if bolt file exists at path %s: %s", config.Cache.Persist.Path, err))
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
if dbFileExists {
|
||||||
|
// Open the bolt file, but wait to setup Encryption
|
||||||
|
ps, err := cacheboltdb.NewBoltStorage(&cacheboltdb.BoltStorageConfig{
|
||||||
|
Path: config.Cache.Persist.Path,
|
||||||
|
Logger: cacheLogger.Named("cacheboltdb"),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
c.UI.Error(fmt.Sprintf("Error opening persistent cache: %v", err))
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the token from bolt for retrieving the encryption key,
|
||||||
|
// then setup encryption so that restore is possible
|
||||||
|
token, err := ps.GetRetrievalToken()
|
||||||
|
if err != nil {
|
||||||
|
c.UI.Error(fmt.Sprintf("Error getting retrieval token from persistent cache: %v", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := ps.Close(); err != nil {
|
||||||
|
c.UI.Warn(fmt.Sprintf("Failed to close persistent cache file after getting retrieval token: %s", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
km, err := keymanager.NewPassthroughKeyManager(token)
|
||||||
|
if err != nil {
|
||||||
|
c.UI.Error(fmt.Sprintf("failed to configure persistence encryption for cache: %s", err))
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
||||||
|
// Open the bolt file with the wrapper provided
|
||||||
|
ps, err = cacheboltdb.NewBoltStorage(&cacheboltdb.BoltStorageConfig{
|
||||||
|
Path: config.Cache.Persist.Path,
|
||||||
|
Logger: cacheLogger.Named("cacheboltdb"),
|
||||||
|
Wrapper: km.Wrapper(),
|
||||||
|
AAD: aad,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
c.UI.Error(fmt.Sprintf("Error opening persistent cache: %v", err))
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
||||||
|
// Restore anything in the persistent cache to the memory cache
|
||||||
|
if err := leaseCache.Restore(ctx, ps); err != nil {
|
||||||
|
c.UI.Error(fmt.Sprintf("Error restoring in-memory cache from persisted file: %v", err))
|
||||||
|
if config.Cache.Persist.ExitOnErr {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cacheLogger.Info("loaded memcache from persistent storage")
|
||||||
|
|
||||||
|
// Check for previous auto-auth token
|
||||||
|
oldTokenBytes, err := ps.GetAutoAuthToken(ctx)
|
||||||
|
if err != nil {
|
||||||
|
c.UI.Error(fmt.Sprintf("Error in fetching previous auto-auth token: %s", err))
|
||||||
|
if config.Cache.Persist.ExitOnErr {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(oldTokenBytes) > 0 {
|
||||||
|
oldToken, err := cachememdb.Deserialize(oldTokenBytes)
|
||||||
|
if err != nil {
|
||||||
|
c.UI.Error(fmt.Sprintf("Error in deserializing previous auto-auth token cache entry: %s", err))
|
||||||
|
if config.Cache.Persist.ExitOnErr {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
previousToken = oldToken.Token
|
||||||
|
}
|
||||||
|
|
||||||
|
// If keep_after_import true, set persistent storage layer in
|
||||||
|
// leaseCache, else remove db file
|
||||||
|
if config.Cache.Persist.KeepAfterImport {
|
||||||
|
defer ps.Close()
|
||||||
|
leaseCache.SetPersistentStorage(ps)
|
||||||
|
} else {
|
||||||
|
if err := ps.Close(); err != nil {
|
||||||
|
c.UI.Warn(fmt.Sprintf("failed to close persistent cache file: %s", err))
|
||||||
|
}
|
||||||
|
dbFile := filepath.Join(config.Cache.Persist.Path, cacheboltdb.DatabaseFileName)
|
||||||
|
if err := os.Remove(dbFile); err != nil {
|
||||||
|
c.UI.Error(fmt.Sprintf("failed to remove persistent storage file %s: %s", dbFile, err))
|
||||||
|
if config.Cache.Persist.ExitOnErr {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
km, err := keymanager.NewPassthroughKeyManager(nil)
|
||||||
|
if err != nil {
|
||||||
|
c.UI.Error(fmt.Sprintf("failed to configure persistence encryption for cache: %s", err))
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
ps, err := cacheboltdb.NewBoltStorage(&cacheboltdb.BoltStorageConfig{
|
||||||
|
Path: config.Cache.Persist.Path,
|
||||||
|
Logger: cacheLogger.Named("cacheboltdb"),
|
||||||
|
Wrapper: km.Wrapper(),
|
||||||
|
AAD: aad,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
c.UI.Error(fmt.Sprintf("Error creating persistent cache: %v", err))
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
cacheLogger.Info("configured persistent storage", "path", config.Cache.Persist.Path)
|
||||||
|
|
||||||
|
// Stash the key material in bolt
|
||||||
|
token, err := km.RetrievalToken()
|
||||||
|
if err != nil {
|
||||||
|
c.UI.Error(fmt.Sprintf("Error getting persistent key: %s", err))
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
if err := ps.StoreRetrievalToken(token); err != nil {
|
||||||
|
c.UI.Error(fmt.Sprintf("Error setting key in persistent cache: %v", err))
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
||||||
|
defer ps.Close()
|
||||||
|
leaseCache.SetPersistentStorage(ps)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
var inmemSink sink.Sink
|
var inmemSink sink.Sink
|
||||||
if config.Cache.UseAutoAuthToken {
|
if config.Cache.UseAutoAuthToken {
|
||||||
cacheLogger.Debug("auto-auth token is allowed to be used; configuring inmem sink")
|
cacheLogger.Debug("auto-auth token is allowed to be used; configuring inmem sink")
|
||||||
|
@ -585,6 +738,11 @@ func (c *AgentCommand) Run(args []string) int {
|
||||||
select {
|
select {
|
||||||
case <-c.ShutdownCh:
|
case <-c.ShutdownCh:
|
||||||
c.UI.Output("==> Vault agent shutdown triggered")
|
c.UI.Output("==> Vault agent shutdown triggered")
|
||||||
|
// Let the lease cache know this is a shutdown; no need to evict
|
||||||
|
// everything
|
||||||
|
if leaseCache != nil {
|
||||||
|
leaseCache.SetShuttingDown(true)
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return nil
|
return nil
|
||||||
|
@ -604,6 +762,7 @@ func (c *AgentCommand) Run(args []string) int {
|
||||||
MaxBackoff: config.AutoAuth.Method.MaxBackoff,
|
MaxBackoff: config.AutoAuth.Method.MaxBackoff,
|
||||||
EnableReauthOnNewCredentials: config.AutoAuth.EnableReauthOnNewCredentials,
|
EnableReauthOnNewCredentials: config.AutoAuth.EnableReauthOnNewCredentials,
|
||||||
EnableTemplateTokenCh: enableTokenCh,
|
EnableTemplateTokenCh: enableTokenCh,
|
||||||
|
Token: previousToken,
|
||||||
})
|
})
|
||||||
|
|
||||||
ss := sink.NewSinkServer(&sink.SinkServerConfig{
|
ss := sink.NewSinkServer(&sink.SinkServerConfig{
|
||||||
|
@ -624,6 +783,11 @@ func (c *AgentCommand) Run(args []string) int {
|
||||||
g.Add(func() error {
|
g.Add(func() error {
|
||||||
return ah.Run(ctx, method)
|
return ah.Run(ctx, method)
|
||||||
}, func(error) {
|
}, func(error) {
|
||||||
|
// Let the lease cache know this is a shutdown; no need to evict
|
||||||
|
// everything
|
||||||
|
if leaseCache != nil {
|
||||||
|
leaseCache.SetShuttingDown(true)
|
||||||
|
}
|
||||||
cancelFunc()
|
cancelFunc()
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -650,12 +814,22 @@ func (c *AgentCommand) Run(args []string) int {
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}, func(error) {
|
}, func(error) {
|
||||||
|
// Let the lease cache know this is a shutdown; no need to evict
|
||||||
|
// everything
|
||||||
|
if leaseCache != nil {
|
||||||
|
leaseCache.SetShuttingDown(true)
|
||||||
|
}
|
||||||
cancelFunc()
|
cancelFunc()
|
||||||
})
|
})
|
||||||
|
|
||||||
g.Add(func() error {
|
g.Add(func() error {
|
||||||
return ts.Run(ctx, ah.TemplateTokenCh, config.Templates)
|
return ts.Run(ctx, ah.TemplateTokenCh, config.Templates)
|
||||||
}, func(error) {
|
}, func(error) {
|
||||||
|
// Let the lease cache know this is a shutdown; no need to evict
|
||||||
|
// everything
|
||||||
|
if leaseCache != nil {
|
||||||
|
leaseCache.SetShuttingDown(true)
|
||||||
|
}
|
||||||
cancelFunc()
|
cancelFunc()
|
||||||
ts.Stop()
|
ts.Stop()
|
||||||
})
|
})
|
||||||
|
@ -793,3 +967,16 @@ func (c *AgentCommand) removePidFile(pidPath string) error {
|
||||||
}
|
}
|
||||||
return os.Remove(pidPath)
|
return os.Remove(pidPath)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetServiceAccountJWT reads the service account jwt from `tokenFile`. Default is
|
||||||
|
// the default service account file path in kubernetes.
|
||||||
|
func getServiceAccountJWT(tokenFile string) (string, error) {
|
||||||
|
if len(tokenFile) == 0 {
|
||||||
|
tokenFile = "/var/run/secrets/kubernetes.io/serviceaccount/token"
|
||||||
|
}
|
||||||
|
token, err := ioutil.ReadFile(tokenFile)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return strings.TrimSpace(string(token)), nil
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1,303 @@
|
||||||
|
package cacheboltdb
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang/protobuf/proto"
|
||||||
|
"github.com/hashicorp/go-hclog"
|
||||||
|
wrapping "github.com/hashicorp/go-kms-wrapping"
|
||||||
|
"github.com/hashicorp/go-multierror"
|
||||||
|
bolt "go.etcd.io/bbolt"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// Keep track of schema version for future migrations
|
||||||
|
storageVersionKey = "version"
|
||||||
|
storageVersion = "1"
|
||||||
|
|
||||||
|
// DatabaseFileName - filename for the persistent cache file
|
||||||
|
DatabaseFileName = "vault-agent-cache.db"
|
||||||
|
|
||||||
|
// metaBucketName - naming the meta bucket that holds the version and
|
||||||
|
// bootstrapping keys
|
||||||
|
metaBucketName = "meta"
|
||||||
|
|
||||||
|
// SecretLeaseType - Bucket/type for leases with secret info
|
||||||
|
SecretLeaseType = "secret-lease"
|
||||||
|
|
||||||
|
// AuthLeaseType - Bucket/type for leases with auth info
|
||||||
|
AuthLeaseType = "auth-lease"
|
||||||
|
|
||||||
|
// TokenType - Bucket/type for auto-auth tokens
|
||||||
|
TokenType = "token"
|
||||||
|
|
||||||
|
// AutoAuthToken - key for the latest auto-auth token
|
||||||
|
AutoAuthToken = "auto-auth-token"
|
||||||
|
|
||||||
|
// RetrievalTokenMaterial is the actual key or token in the key bucket
|
||||||
|
RetrievalTokenMaterial = "retrieval-token-material"
|
||||||
|
)
|
||||||
|
|
||||||
|
// BoltStorage is a persistent cache using a bolt db. Items are organized with
|
||||||
|
// the version and bootstrapping items in the "meta" bucket, and tokens, auth
|
||||||
|
// leases, and secret leases in their own buckets.
|
||||||
|
type BoltStorage struct {
|
||||||
|
db *bolt.DB
|
||||||
|
logger hclog.Logger
|
||||||
|
wrapper wrapping.Wrapper
|
||||||
|
aad string
|
||||||
|
}
|
||||||
|
|
||||||
|
// BoltStorageConfig is the collection of input parameters for setting up bolt
|
||||||
|
// storage
|
||||||
|
type BoltStorageConfig struct {
|
||||||
|
Path string
|
||||||
|
Logger hclog.Logger
|
||||||
|
Wrapper wrapping.Wrapper
|
||||||
|
AAD string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewBoltStorage opens a new bolt db at the specified file path and returns it.
|
||||||
|
// If the db already exists the buckets will just be created if they don't
|
||||||
|
// exist.
|
||||||
|
func NewBoltStorage(config *BoltStorageConfig) (*BoltStorage, error) {
|
||||||
|
dbPath := filepath.Join(config.Path, DatabaseFileName)
|
||||||
|
db, err := bolt.Open(dbPath, 0600, &bolt.Options{Timeout: 1 * time.Second})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
err = db.Update(func(tx *bolt.Tx) error {
|
||||||
|
return createBoltSchema(tx)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
bs := &BoltStorage{
|
||||||
|
db: db,
|
||||||
|
logger: config.Logger,
|
||||||
|
wrapper: config.Wrapper,
|
||||||
|
aad: config.AAD,
|
||||||
|
}
|
||||||
|
return bs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func createBoltSchema(tx *bolt.Tx) error {
|
||||||
|
// create the meta bucket at the top level
|
||||||
|
meta, err := tx.CreateBucketIfNotExists([]byte(metaBucketName))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create bucket %s: %w", metaBucketName, err)
|
||||||
|
}
|
||||||
|
// check and set file version in the meta bucket
|
||||||
|
version := meta.Get([]byte(storageVersionKey))
|
||||||
|
switch {
|
||||||
|
case version == nil:
|
||||||
|
err = meta.Put([]byte(storageVersionKey), []byte(storageVersion))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to set storage version: %w", err)
|
||||||
|
}
|
||||||
|
case string(version) != storageVersion:
|
||||||
|
return fmt.Errorf("storage migration from %s to %s not implemented", string(version), storageVersion)
|
||||||
|
}
|
||||||
|
|
||||||
|
// create the buckets for tokens and leases
|
||||||
|
_, err = tx.CreateBucketIfNotExists([]byte(TokenType))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create token bucket: %w", err)
|
||||||
|
}
|
||||||
|
_, err = tx.CreateBucketIfNotExists([]byte(AuthLeaseType))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create auth lease bucket: %w", err)
|
||||||
|
}
|
||||||
|
_, err = tx.CreateBucketIfNotExists([]byte(SecretLeaseType))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create secret lease bucket: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set an index (token or lease) in bolt storage
|
||||||
|
func (b *BoltStorage) Set(ctx context.Context, id string, plaintext []byte, indexType string) error {
|
||||||
|
blob, err := b.wrapper.Encrypt(ctx, plaintext, []byte(b.aad))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error encrypting %s index: %w", indexType, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
protoBlob, err := proto.Marshal(blob)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return b.db.Update(func(tx *bolt.Tx) error {
|
||||||
|
s := tx.Bucket([]byte(indexType))
|
||||||
|
if s == nil {
|
||||||
|
return fmt.Errorf("bucket %q not found", indexType)
|
||||||
|
}
|
||||||
|
// If this is an auto-auth token, also stash it in the meta bucket for
|
||||||
|
// easy retrieval upon restore
|
||||||
|
if indexType == TokenType {
|
||||||
|
meta := tx.Bucket([]byte(metaBucketName))
|
||||||
|
if err := meta.Put([]byte(AutoAuthToken), protoBlob); err != nil {
|
||||||
|
return fmt.Errorf("failed to set latest auto-auth token: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return s.Put([]byte(id), protoBlob)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func getBucketIDs(b *bolt.Bucket) ([][]byte, error) {
|
||||||
|
ids := [][]byte{}
|
||||||
|
err := b.ForEach(func(k, v []byte) error {
|
||||||
|
ids = append(ids, k)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
return ids, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete an index (token or lease) by id from bolt storage
|
||||||
|
func (b *BoltStorage) Delete(id string) error {
|
||||||
|
return b.db.Update(func(tx *bolt.Tx) error {
|
||||||
|
// Since Delete returns a nil error if the key doesn't exist, just call
|
||||||
|
// delete in all three index buckets without checking existence first
|
||||||
|
if err := tx.Bucket([]byte(TokenType)).Delete([]byte(id)); err != nil {
|
||||||
|
return fmt.Errorf("failed to delete %q from token bucket: %w", id, err)
|
||||||
|
}
|
||||||
|
if err := tx.Bucket([]byte(AuthLeaseType)).Delete([]byte(id)); err != nil {
|
||||||
|
return fmt.Errorf("failed to delete %q from auth lease bucket: %w", id, err)
|
||||||
|
}
|
||||||
|
if err := tx.Bucket([]byte(SecretLeaseType)).Delete([]byte(id)); err != nil {
|
||||||
|
return fmt.Errorf("failed to delete %q from secret lease bucket: %w", id, err)
|
||||||
|
}
|
||||||
|
b.logger.Trace("deleted index from bolt db", "id", id)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *BoltStorage) decrypt(ctx context.Context, ciphertext []byte) ([]byte, error) {
|
||||||
|
var blob wrapping.EncryptedBlobInfo
|
||||||
|
if err := proto.Unmarshal(ciphertext, &blob); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return b.wrapper.Decrypt(ctx, &blob, []byte(b.aad))
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetByType returns a list of stored items of the specified type
|
||||||
|
func (b *BoltStorage) GetByType(ctx context.Context, indexType string) ([][]byte, error) {
|
||||||
|
var returnBytes [][]byte
|
||||||
|
|
||||||
|
err := b.db.View(func(tx *bolt.Tx) error {
|
||||||
|
var errors *multierror.Error
|
||||||
|
|
||||||
|
tx.Bucket([]byte(indexType)).ForEach(func(id, ciphertext []byte) error {
|
||||||
|
plaintext, err := b.decrypt(ctx, ciphertext)
|
||||||
|
if err != nil {
|
||||||
|
errors = multierror.Append(errors, fmt.Errorf("error decrypting index id %s: %w", id, err))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
returnBytes = append(returnBytes, plaintext)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
return errors.ErrorOrNil()
|
||||||
|
})
|
||||||
|
|
||||||
|
return returnBytes, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAutoAuthToken retrieves the latest auto-auth token, and returns nil if non
|
||||||
|
// exists yet
|
||||||
|
func (b *BoltStorage) GetAutoAuthToken(ctx context.Context) ([]byte, error) {
|
||||||
|
var encryptedToken []byte
|
||||||
|
|
||||||
|
err := b.db.View(func(tx *bolt.Tx) error {
|
||||||
|
meta := tx.Bucket([]byte(metaBucketName))
|
||||||
|
if meta == nil {
|
||||||
|
return fmt.Errorf("bucket %q not found", metaBucketName)
|
||||||
|
}
|
||||||
|
encryptedToken = meta.Get([]byte(AutoAuthToken))
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if encryptedToken == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
plaintext, err := b.decrypt(ctx, encryptedToken)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to decrypt auto-auth token: %w", err)
|
||||||
|
}
|
||||||
|
return plaintext, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRetrievalToken retrieves a plaintext token from the KeyBucket, which will
|
||||||
|
// be used by the key manager to retrieve the encryption key, nil if none set
|
||||||
|
func (b *BoltStorage) GetRetrievalToken() ([]byte, error) {
|
||||||
|
var token []byte
|
||||||
|
|
||||||
|
err := b.db.View(func(tx *bolt.Tx) error {
|
||||||
|
keyBucket := tx.Bucket([]byte(metaBucketName))
|
||||||
|
if keyBucket == nil {
|
||||||
|
return fmt.Errorf("bucket %q not found", metaBucketName)
|
||||||
|
}
|
||||||
|
token = keyBucket.Get([]byte(RetrievalTokenMaterial))
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return token, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// StoreRetrievalToken sets plaintext token material in the RetrievalTokenBucket
|
||||||
|
func (b *BoltStorage) StoreRetrievalToken(token []byte) error {
|
||||||
|
return b.db.Update(func(tx *bolt.Tx) error {
|
||||||
|
bucket := tx.Bucket([]byte(metaBucketName))
|
||||||
|
if bucket == nil {
|
||||||
|
return fmt.Errorf("bucket %q not found", metaBucketName)
|
||||||
|
}
|
||||||
|
return bucket.Put([]byte(RetrievalTokenMaterial), token)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close the boltdb
|
||||||
|
func (b *BoltStorage) Close() error {
|
||||||
|
b.logger.Trace("closing bolt db", "path", b.db.Path())
|
||||||
|
return b.db.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear the boltdb by deleting all the token and lease buckets and recreating
|
||||||
|
// the schema/layout
|
||||||
|
func (b *BoltStorage) Clear() error {
|
||||||
|
return b.db.Update(func(tx *bolt.Tx) error {
|
||||||
|
for _, name := range []string{AuthLeaseType, SecretLeaseType, TokenType} {
|
||||||
|
b.logger.Trace("deleting bolt bucket", "name", name)
|
||||||
|
if err := tx.DeleteBucket([]byte(name)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return createBoltSchema(tx)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// DBFileExists checks whether the vault agent cache file at `filePath` exists
|
||||||
|
func DBFileExists(path string) (bool, error) {
|
||||||
|
checkFile, err := os.OpenFile(filepath.Join(path, DatabaseFileName), os.O_RDWR, 0600)
|
||||||
|
defer checkFile.Close()
|
||||||
|
switch {
|
||||||
|
case err == nil:
|
||||||
|
return true, nil
|
||||||
|
case os.IsNotExist(err):
|
||||||
|
return false, nil
|
||||||
|
default:
|
||||||
|
return false, fmt.Errorf("failed to check if bolt file exists at path %s: %w", path, err)
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,263 @@
|
||||||
|
package cacheboltdb
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"io/ioutil"
|
||||||
|
"os"
|
||||||
|
"path"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/hashicorp/go-hclog"
|
||||||
|
"github.com/hashicorp/vault/command/agent/cache/keymanager"
|
||||||
|
"github.com/ory/dockertest/v3/docker/pkg/ioutils"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func getTestKeyManager(t *testing.T) keymanager.KeyManager {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
km, err := keymanager.NewPassthroughKeyManager(nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
return km
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBolt_SetGet(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
path, err := ioutils.TempDir("", "bolt-test")
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer os.RemoveAll(path)
|
||||||
|
|
||||||
|
b, err := NewBoltStorage(&BoltStorageConfig{
|
||||||
|
Path: path,
|
||||||
|
Logger: hclog.Default(),
|
||||||
|
Wrapper: getTestKeyManager(t).Wrapper(),
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
secrets, err := b.GetByType(ctx, SecretLeaseType)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
require.Len(t, secrets, 0)
|
||||||
|
|
||||||
|
err = b.Set(ctx, "test1", []byte("hello"), SecretLeaseType)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
secrets, err = b.GetByType(ctx, SecretLeaseType)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
require.Len(t, secrets, 1)
|
||||||
|
assert.Equal(t, []byte("hello"), secrets[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBoltDelete(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
path, err := ioutils.TempDir("", "bolt-test")
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer os.RemoveAll(path)
|
||||||
|
|
||||||
|
b, err := NewBoltStorage(&BoltStorageConfig{
|
||||||
|
Path: path,
|
||||||
|
Logger: hclog.Default(),
|
||||||
|
Wrapper: getTestKeyManager(t).Wrapper(),
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = b.Set(ctx, "secret-test1", []byte("hello1"), SecretLeaseType)
|
||||||
|
require.NoError(t, err)
|
||||||
|
err = b.Set(ctx, "secret-test2", []byte("hello2"), SecretLeaseType)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
secrets, err := b.GetByType(ctx, SecretLeaseType)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Len(t, secrets, 2)
|
||||||
|
assert.ElementsMatch(t, [][]byte{[]byte("hello1"), []byte("hello2")}, secrets)
|
||||||
|
|
||||||
|
err = b.Delete("secret-test1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
secrets, err = b.GetByType(ctx, SecretLeaseType)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, secrets, 1)
|
||||||
|
assert.Equal(t, []byte("hello2"), secrets[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBoltClear(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
path, err := ioutils.TempDir("", "bolt-test")
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer os.RemoveAll(path)
|
||||||
|
|
||||||
|
b, err := NewBoltStorage(&BoltStorageConfig{
|
||||||
|
Path: path,
|
||||||
|
Logger: hclog.Default(),
|
||||||
|
Wrapper: getTestKeyManager(t).Wrapper(),
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Populate the bolt db
|
||||||
|
err = b.Set(ctx, "secret-test1", []byte("hello"), SecretLeaseType)
|
||||||
|
require.NoError(t, err)
|
||||||
|
secrets, err := b.GetByType(ctx, SecretLeaseType)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, secrets, 1)
|
||||||
|
assert.Equal(t, []byte("hello"), secrets[0])
|
||||||
|
|
||||||
|
err = b.Set(ctx, "auth-test1", []byte("hello"), AuthLeaseType)
|
||||||
|
require.NoError(t, err)
|
||||||
|
auths, err := b.GetByType(ctx, AuthLeaseType)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, auths, 1)
|
||||||
|
assert.Equal(t, []byte("hello"), auths[0])
|
||||||
|
|
||||||
|
err = b.Set(ctx, "token-test1", []byte("hello"), TokenType)
|
||||||
|
require.NoError(t, err)
|
||||||
|
tokens, err := b.GetByType(ctx, TokenType)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, tokens, 1)
|
||||||
|
assert.Equal(t, []byte("hello"), tokens[0])
|
||||||
|
|
||||||
|
// Clear the bolt db, and check that it's indeed clear
|
||||||
|
err = b.Clear()
|
||||||
|
require.NoError(t, err)
|
||||||
|
secrets, err = b.GetByType(ctx, SecretLeaseType)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Len(t, secrets, 0)
|
||||||
|
auths, err = b.GetByType(ctx, AuthLeaseType)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Len(t, auths, 0)
|
||||||
|
tokens, err = b.GetByType(ctx, TokenType)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Len(t, tokens, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBoltSetAutoAuthToken(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
path, err := ioutils.TempDir("", "bolt-test")
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer os.RemoveAll(path)
|
||||||
|
|
||||||
|
b, err := NewBoltStorage(&BoltStorageConfig{
|
||||||
|
Path: path,
|
||||||
|
Logger: hclog.Default(),
|
||||||
|
Wrapper: getTestKeyManager(t).Wrapper(),
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
token, err := b.GetAutoAuthToken(ctx)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Nil(t, token)
|
||||||
|
|
||||||
|
// set first token
|
||||||
|
err = b.Set(ctx, "token-test1", []byte("hello 1"), TokenType)
|
||||||
|
require.NoError(t, err)
|
||||||
|
secrets, err := b.GetByType(ctx, TokenType)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, secrets, 1)
|
||||||
|
assert.Equal(t, []byte("hello 1"), secrets[0])
|
||||||
|
token, err = b.GetAutoAuthToken(ctx)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, []byte("hello 1"), token)
|
||||||
|
|
||||||
|
// set second token
|
||||||
|
err = b.Set(ctx, "token-test2", []byte("hello 2"), TokenType)
|
||||||
|
require.NoError(t, err)
|
||||||
|
secrets, err = b.GetByType(ctx, TokenType)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, secrets, 2)
|
||||||
|
assert.ElementsMatch(t, [][]byte{[]byte("hello 1"), []byte("hello 2")}, secrets)
|
||||||
|
token, err = b.GetAutoAuthToken(ctx)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, []byte("hello 2"), token)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDBFileExists(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
mkDir bool
|
||||||
|
createFile bool
|
||||||
|
expectExist bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "all exists",
|
||||||
|
mkDir: true,
|
||||||
|
createFile: true,
|
||||||
|
expectExist: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "dir exist, file missing",
|
||||||
|
mkDir: true,
|
||||||
|
createFile: false,
|
||||||
|
expectExist: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "all missing",
|
||||||
|
mkDir: false,
|
||||||
|
createFile: false,
|
||||||
|
expectExist: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
var tmpPath string
|
||||||
|
var err error
|
||||||
|
if tc.mkDir {
|
||||||
|
tmpPath, err = ioutil.TempDir("", "test-db-path")
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
if tc.createFile {
|
||||||
|
err = ioutil.WriteFile(path.Join(tmpPath, DatabaseFileName), []byte("test-db-path"), 0600)
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
exists, err := DBFileExists(tmpPath)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, tc.expectExist, exists)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_SetGetRetrievalToken(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
tokenToSet []byte
|
||||||
|
expectedToken []byte
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "normal set and get",
|
||||||
|
tokenToSet: []byte("test token"),
|
||||||
|
expectedToken: []byte("test token"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no token set",
|
||||||
|
tokenToSet: nil,
|
||||||
|
expectedToken: nil,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
path, err := ioutils.TempDir("", "bolt-test")
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer os.RemoveAll(path)
|
||||||
|
|
||||||
|
b, err := NewBoltStorage(&BoltStorageConfig{
|
||||||
|
Path: path,
|
||||||
|
Logger: hclog.Default(),
|
||||||
|
Wrapper: getTestKeyManager(t).Wrapper(),
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer b.Close()
|
||||||
|
|
||||||
|
if tc.tokenToSet != nil {
|
||||||
|
err := b.StoreRetrievalToken(tc.tokenToSet)
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
gotKey, err := b.GetRetrievalToken()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, tc.expectedToken, gotKey)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,6 +1,11 @@
|
||||||
package cachememdb
|
package cachememdb
|
||||||
|
|
||||||
import "context"
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
// Index holds the response to be cached along with multiple other values that
|
// Index holds the response to be cached along with multiple other values that
|
||||||
// serve as pointers to refer back to this index.
|
// serve as pointers to refer back to this index.
|
||||||
|
@ -48,6 +53,21 @@ type Index struct {
|
||||||
// goroutine that manages the renewal of the secret belonging to the
|
// goroutine that manages the renewal of the secret belonging to the
|
||||||
// response in this index.
|
// response in this index.
|
||||||
RenewCtxInfo *ContextInfo
|
RenewCtxInfo *ContextInfo
|
||||||
|
|
||||||
|
// RequestMethod is the HTTP method of the request
|
||||||
|
RequestMethod string
|
||||||
|
|
||||||
|
// RequestToken is the token used in the request
|
||||||
|
RequestToken string
|
||||||
|
|
||||||
|
// RequestHeader is the header used in the request
|
||||||
|
RequestHeader http.Header
|
||||||
|
|
||||||
|
// LastRenewed is the timestamp of last renewal
|
||||||
|
LastRenewed time.Time
|
||||||
|
|
||||||
|
// Type is the index type (token, auth-lease, secret-lease)
|
||||||
|
Type string
|
||||||
}
|
}
|
||||||
|
|
||||||
type IndexName uint32
|
type IndexName uint32
|
||||||
|
@ -106,3 +126,25 @@ func NewContextInfo(ctx context.Context) *ContextInfo {
|
||||||
ctxInfo.DoneCh = make(chan struct{})
|
ctxInfo.DoneCh = make(chan struct{})
|
||||||
return ctxInfo
|
return ctxInfo
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Serialize returns a json marshal'ed Index object, without the RenewCtxInfo
|
||||||
|
func (i Index) Serialize() ([]byte, error) {
|
||||||
|
i.RenewCtxInfo = nil
|
||||||
|
|
||||||
|
indexBytes, err := json.Marshal(i)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return indexBytes, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deserialize converts json bytes to an Index object
|
||||||
|
// Note: RenewCtxInfo will need to be reconstructed elsewhere.
|
||||||
|
func Deserialize(indexBytes []byte) (*Index, error) {
|
||||||
|
index := new(Index)
|
||||||
|
if err := json.Unmarshal(indexBytes, index); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return index, nil
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1,43 @@
|
||||||
|
package cachememdb
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSerializeDeserialize(t *testing.T) {
|
||||||
|
|
||||||
|
testIndex := &Index{
|
||||||
|
ID: "testid",
|
||||||
|
Token: "testtoken",
|
||||||
|
TokenParent: "parent token",
|
||||||
|
TokenAccessor: "test accessor",
|
||||||
|
Namespace: "test namespace",
|
||||||
|
RequestPath: "/test/path",
|
||||||
|
Lease: "lease id",
|
||||||
|
LeaseToken: "lease token id",
|
||||||
|
Response: []byte(`{"something": "here"}`),
|
||||||
|
RenewCtxInfo: NewContextInfo(context.Background()),
|
||||||
|
RequestMethod: "GET",
|
||||||
|
RequestToken: "request token",
|
||||||
|
RequestHeader: http.Header{
|
||||||
|
"X-Test": []string{"vault", "agent"},
|
||||||
|
},
|
||||||
|
LastRenewed: time.Now().UTC(),
|
||||||
|
}
|
||||||
|
indexBytes, err := testIndex.Serialize()
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, len(indexBytes) > 0)
|
||||||
|
assert.NotNil(t, testIndex.RenewCtxInfo, "Serialize should not modify original Index object")
|
||||||
|
|
||||||
|
restoredIndex, err := Deserialize(indexBytes)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
testIndex.RenewCtxInfo = nil
|
||||||
|
assert.Equal(t, testIndex, restoredIndex, "They should be equal without RenewCtxInfo set on the original")
|
||||||
|
}
|
|
@ -1,18 +0,0 @@
|
||||||
package crypto
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
KeyID = "root"
|
|
||||||
)
|
|
||||||
|
|
||||||
type KeyManager interface {
|
|
||||||
GetKey() []byte
|
|
||||||
GetPersistentKey() ([]byte, error)
|
|
||||||
Renewable() bool
|
|
||||||
Renewer(context.Context) error
|
|
||||||
Encrypt(context.Context, []byte, []byte) ([]byte, error)
|
|
||||||
Decrypt(context.Context, []byte, []byte) ([]byte, error)
|
|
||||||
}
|
|
|
@ -1,97 +0,0 @@
|
||||||
package crypto
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"crypto/rand"
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
wrapping "github.com/hashicorp/go-kms-wrapping"
|
|
||||||
"github.com/hashicorp/go-kms-wrapping/wrappers/aead"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ KeyManager = (*KubeEncryptionKey)(nil)
|
|
||||||
|
|
||||||
type KubeEncryptionKey struct {
|
|
||||||
renewable bool
|
|
||||||
wrapper *aead.Wrapper
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewK8s returns a new instance of the Kube encryption key. Kubernetes
|
|
||||||
// encryption keys aren't renewable.
|
|
||||||
func NewK8s(existingKey []byte) (*KubeEncryptionKey, error) {
|
|
||||||
k := &KubeEncryptionKey{
|
|
||||||
renewable: false,
|
|
||||||
wrapper: aead.NewWrapper(nil),
|
|
||||||
}
|
|
||||||
|
|
||||||
k.wrapper.SetConfig(map[string]string{"key_id": KeyID})
|
|
||||||
|
|
||||||
var rootKey []byte = nil
|
|
||||||
if len(existingKey) != 0 {
|
|
||||||
if len(existingKey) != 32 {
|
|
||||||
return k, fmt.Errorf("invalid key size, should be 32, got %d", len(existingKey))
|
|
||||||
}
|
|
||||||
rootKey = existingKey
|
|
||||||
}
|
|
||||||
|
|
||||||
if rootKey == nil {
|
|
||||||
newKey := make([]byte, 32)
|
|
||||||
_, err := rand.Read(newKey)
|
|
||||||
if err != nil {
|
|
||||||
return k, err
|
|
||||||
}
|
|
||||||
rootKey = newKey
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := k.wrapper.SetAESGCMKeyBytes(rootKey); err != nil {
|
|
||||||
return k, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return k, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetKey returns the encryption key in a format optimized for storage.
|
|
||||||
// In k8s we store the key as is, so just return the key stored.
|
|
||||||
func (k *KubeEncryptionKey) GetKey() []byte {
|
|
||||||
return k.wrapper.GetKeyBytes()
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetPersistentKey returns the key which should be stored in the persisent
|
|
||||||
// cache. In k8s we store the key as is, so just return the key stored.
|
|
||||||
func (k *KubeEncryptionKey) GetPersistentKey() ([]byte, error) {
|
|
||||||
return k.wrapper.GetKeyBytes(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Renewable lets the caller know if this encryption key type is
|
|
||||||
// renewable. In Kubernetes the key isn't renewable.
|
|
||||||
func (k *KubeEncryptionKey) Renewable() bool {
|
|
||||||
return k.renewable
|
|
||||||
}
|
|
||||||
|
|
||||||
// Renewer is used when the encryption key type is renewable. Since Kubernetes
|
|
||||||
// keys aren't renewable, returning nothing.
|
|
||||||
func (k *KubeEncryptionKey) Renewer(ctx context.Context) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Encrypt takes plaintext values and encrypts them using the store key and additional
|
|
||||||
// data. For Kubernetes the AAD should be the service account JWT.
|
|
||||||
func (k *KubeEncryptionKey) Encrypt(ctx context.Context, plaintext, aad []byte) ([]byte, error) {
|
|
||||||
blob, err := k.wrapper.Encrypt(ctx, plaintext, aad)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return blob.Ciphertext, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Decrypt takes ciphertext and AAD values and returns the decrypted value. For Kubernetes the AAD
|
|
||||||
// should be the service account JWT.
|
|
||||||
func (k *KubeEncryptionKey) Decrypt(ctx context.Context, ciphertext, aad []byte) ([]byte, error) {
|
|
||||||
blob := &wrapping.EncryptedBlobInfo{
|
|
||||||
Ciphertext: ciphertext,
|
|
||||||
KeyInfo: &wrapping.KeyInfo{
|
|
||||||
KeyID: KeyID,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
return k.wrapper.Decrypt(ctx, blob, aad)
|
|
||||||
}
|
|
|
@ -1,155 +0,0 @@
|
||||||
package crypto
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"math/rand"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestCrypto_KubernetesNewKey(t *testing.T) {
|
|
||||||
k8sKey, err := NewK8s([]byte{})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf(fmt.Sprintf("unexpected error: %s", err))
|
|
||||||
}
|
|
||||||
|
|
||||||
key := k8sKey.GetKey()
|
|
||||||
if key == nil {
|
|
||||||
t.Fatalf(fmt.Sprintf("key is nil, it shouldn't be: %s", key))
|
|
||||||
}
|
|
||||||
|
|
||||||
persistentKey, _ := k8sKey.GetPersistentKey()
|
|
||||||
if persistentKey == nil {
|
|
||||||
t.Fatalf(fmt.Sprintf("key is nil, it shouldn't be: %s", persistentKey))
|
|
||||||
}
|
|
||||||
|
|
||||||
if string(key) != string(persistentKey) {
|
|
||||||
t.Fatalf("keys don't match, they should: key: %s, persistentKey: %s", key, persistentKey)
|
|
||||||
}
|
|
||||||
|
|
||||||
plaintextInput := []byte("test")
|
|
||||||
aad := []byte("kubernetes")
|
|
||||||
|
|
||||||
ciphertext, err := k8sKey.Encrypt(nil, plaintextInput, aad)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf(err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
if ciphertext == nil {
|
|
||||||
t.Fatalf("ciphertext nil, it shouldn't be")
|
|
||||||
}
|
|
||||||
|
|
||||||
plaintext, err := k8sKey.Decrypt(nil, ciphertext, aad)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf(err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
if string(plaintext) != string(plaintextInput) {
|
|
||||||
t.Fatalf("expected %s, got %s", plaintextInput, plaintext)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCrypto_KubernetesExistingKey(t *testing.T) {
|
|
||||||
rootKey := make([]byte, 32)
|
|
||||||
n, err := rand.Read(rootKey)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
if n != 32 {
|
|
||||||
t.Fatal(n)
|
|
||||||
}
|
|
||||||
|
|
||||||
k8sKey, err := NewK8s(rootKey)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf(fmt.Sprintf("unexpected error: %s", err))
|
|
||||||
}
|
|
||||||
|
|
||||||
key := k8sKey.GetKey()
|
|
||||||
if key == nil {
|
|
||||||
t.Fatalf(fmt.Sprintf("key is nil, it shouldn't be: %s", key))
|
|
||||||
}
|
|
||||||
|
|
||||||
if string(key) != string(rootKey) {
|
|
||||||
t.Fatalf(fmt.Sprintf("expected keys to be the same, they weren't: expected: %s, got: %s", rootKey, key))
|
|
||||||
}
|
|
||||||
|
|
||||||
persistentKey, _ := k8sKey.GetPersistentKey()
|
|
||||||
if persistentKey == nil {
|
|
||||||
t.Fatalf("key is nil, it shouldn't be")
|
|
||||||
}
|
|
||||||
|
|
||||||
if string(persistentKey) != string(rootKey) {
|
|
||||||
t.Fatalf(fmt.Sprintf("expected keys to be the same, they weren't: expected: %s, got: %s", rootKey, persistentKey))
|
|
||||||
}
|
|
||||||
|
|
||||||
if string(key) != string(persistentKey) {
|
|
||||||
t.Fatalf(fmt.Sprintf("expected keys to be the same, they weren't: %s %s", rootKey, persistentKey))
|
|
||||||
}
|
|
||||||
|
|
||||||
plaintextInput := []byte("test")
|
|
||||||
aad := []byte("kubernetes")
|
|
||||||
|
|
||||||
ciphertext, err := k8sKey.Encrypt(nil, plaintextInput, aad)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf(err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
if ciphertext == nil {
|
|
||||||
t.Fatalf("ciphertext nil, it shouldn't be")
|
|
||||||
}
|
|
||||||
|
|
||||||
plaintext, err := k8sKey.Decrypt(nil, ciphertext, aad)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf(err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
if string(plaintext) != string(plaintextInput) {
|
|
||||||
t.Fatalf("expected %s, got %s", plaintextInput, plaintext)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCrypto_KubernetesPassGeneratedKey(t *testing.T) {
|
|
||||||
k8sFirstKey, err := NewK8s([]byte{})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf(fmt.Sprintf("unexpected error: %s", err))
|
|
||||||
}
|
|
||||||
|
|
||||||
firstPersistentKey := k8sFirstKey.GetKey()
|
|
||||||
if firstPersistentKey == nil {
|
|
||||||
t.Fatalf(fmt.Sprintf("key is nil, it shouldn't be: %s", firstPersistentKey))
|
|
||||||
}
|
|
||||||
|
|
||||||
plaintextInput := []byte("test")
|
|
||||||
aad := []byte("kubernetes")
|
|
||||||
|
|
||||||
ciphertext, err := k8sFirstKey.Encrypt(nil, plaintextInput, aad)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf(err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
if ciphertext == nil {
|
|
||||||
t.Fatalf("ciphertext nil, it shouldn't be")
|
|
||||||
}
|
|
||||||
|
|
||||||
k8sLoadedKey, err := NewK8s(firstPersistentKey)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf(fmt.Sprintf("unexpected error: %s", err))
|
|
||||||
}
|
|
||||||
|
|
||||||
loadedKey, _ := k8sLoadedKey.GetPersistentKey()
|
|
||||||
if loadedKey == nil {
|
|
||||||
t.Fatalf(fmt.Sprintf("key is nil, it shouldn't be: %s", loadedKey))
|
|
||||||
}
|
|
||||||
|
|
||||||
if string(loadedKey) != string(firstPersistentKey) {
|
|
||||||
t.Fatalf(fmt.Sprintf("keys do not match"))
|
|
||||||
}
|
|
||||||
|
|
||||||
plaintext, err := k8sLoadedKey.Decrypt(nil, ciphertext, aad)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf(err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
if string(plaintext) != string(plaintextInput) {
|
|
||||||
t.Fatalf("expected %s, got %s", plaintextInput, plaintext)
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -0,0 +1,16 @@
|
||||||
|
package keymanager
|
||||||
|
|
||||||
|
import wrapping "github.com/hashicorp/go-kms-wrapping"
|
||||||
|
|
||||||
|
const (
|
||||||
|
KeyID = "root"
|
||||||
|
)
|
||||||
|
|
||||||
|
type KeyManager interface {
|
||||||
|
// Returns a wrapping.Wrapper which can be used to perform key-related operations.
|
||||||
|
Wrapper() wrapping.Wrapper
|
||||||
|
// RetrievalToken is the material returned which can be used to source back the
|
||||||
|
// encryption key. Depending on the implementation, the token can be the
|
||||||
|
// encryption key itself or a token/identifier used to exchange the token.
|
||||||
|
RetrievalToken() ([]byte, error)
|
||||||
|
}
|
|
@ -0,0 +1,68 @@
|
||||||
|
package keymanager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
wrapping "github.com/hashicorp/go-kms-wrapping"
|
||||||
|
"github.com/hashicorp/go-kms-wrapping/wrappers/aead"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ KeyManager = (*PassthroughKeyManager)(nil)
|
||||||
|
|
||||||
|
type PassthroughKeyManager struct {
|
||||||
|
wrapper *aead.Wrapper
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewPassthroughKeyManager returns a new instance of the Kube encryption key.
|
||||||
|
// If a key is provided, it will be used as the encryption key for the wrapper,
|
||||||
|
// otherwise one will be generated.
|
||||||
|
func NewPassthroughKeyManager(key []byte) (*PassthroughKeyManager, error) {
|
||||||
|
|
||||||
|
var rootKey []byte = nil
|
||||||
|
switch len(key) {
|
||||||
|
case 0:
|
||||||
|
newKey := make([]byte, 32)
|
||||||
|
_, err := rand.Read(newKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
rootKey = newKey
|
||||||
|
case 32:
|
||||||
|
rootKey = key
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("invalid key size, should be 32, got %d", len(key))
|
||||||
|
}
|
||||||
|
|
||||||
|
wrapper := aead.NewWrapper(nil)
|
||||||
|
|
||||||
|
if _, err := wrapper.SetConfig(map[string]string{"key_id": KeyID}); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := wrapper.SetAESGCMKeyBytes(rootKey); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
k := &PassthroughKeyManager{
|
||||||
|
wrapper: wrapper,
|
||||||
|
}
|
||||||
|
|
||||||
|
return k, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wrapper returns the manager's wrapper for key operations.
|
||||||
|
func (w *PassthroughKeyManager) Wrapper() wrapping.Wrapper {
|
||||||
|
return w.wrapper
|
||||||
|
}
|
||||||
|
|
||||||
|
// RetrievalToken returns the key that was used on the wrapper since this key
|
||||||
|
// manager is simply a passthrough and does not provide a mechanism to abstract
|
||||||
|
// this key.
|
||||||
|
func (w *PassthroughKeyManager) RetrievalToken() ([]byte, error) {
|
||||||
|
if w.wrapper == nil {
|
||||||
|
return nil, fmt.Errorf("unable to get wrapper for token retrieval")
|
||||||
|
}
|
||||||
|
|
||||||
|
return w.wrapper.GetKeyBytes(), nil
|
||||||
|
}
|
|
@ -0,0 +1,56 @@
|
||||||
|
package keymanager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestKeyManager_PassthrougKeyManager(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
key []byte
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"new key",
|
||||||
|
nil,
|
||||||
|
false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"existing valid key",
|
||||||
|
[]byte("e679e2f3d8d0e489d408bc617c6890d6"),
|
||||||
|
false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"invalid key length",
|
||||||
|
[]byte("foobar"),
|
||||||
|
true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
m, err := NewPassthroughKeyManager(tc.key)
|
||||||
|
if tc.wantErr {
|
||||||
|
require.Error(t, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
if w := m.Wrapper(); w == nil {
|
||||||
|
t.Fatalf("expected non-nil wrapper from the key manager")
|
||||||
|
}
|
||||||
|
|
||||||
|
token, err := m.RetrievalToken()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to retrieve token: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(tc.key) != 0 && !bytes.Equal(tc.key, token) {
|
||||||
|
t.Fatalf("expected key bytes: %x, got: %x", tc.key, token)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -11,6 +11,7 @@ import (
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
@ -18,6 +19,7 @@ import (
|
||||||
"github.com/hashicorp/errwrap"
|
"github.com/hashicorp/errwrap"
|
||||||
hclog "github.com/hashicorp/go-hclog"
|
hclog "github.com/hashicorp/go-hclog"
|
||||||
"github.com/hashicorp/vault/api"
|
"github.com/hashicorp/vault/api"
|
||||||
|
"github.com/hashicorp/vault/command/agent/cache/cacheboltdb"
|
||||||
cachememdb "github.com/hashicorp/vault/command/agent/cache/cachememdb"
|
cachememdb "github.com/hashicorp/vault/command/agent/cache/cachememdb"
|
||||||
"github.com/hashicorp/vault/helper/namespace"
|
"github.com/hashicorp/vault/helper/namespace"
|
||||||
nshelper "github.com/hashicorp/vault/helper/namespace"
|
nshelper "github.com/hashicorp/vault/helper/namespace"
|
||||||
|
@ -83,6 +85,13 @@ type LeaseCache struct {
|
||||||
|
|
||||||
// inflightCache keeps track of inflight requests
|
// inflightCache keeps track of inflight requests
|
||||||
inflightCache *gocache.Cache
|
inflightCache *gocache.Cache
|
||||||
|
|
||||||
|
// ps is the persistent storage for tokens and leases
|
||||||
|
ps *cacheboltdb.BoltStorage
|
||||||
|
|
||||||
|
// shuttingDown is used to determine if cache needs to be evicted or not
|
||||||
|
// when the context is cancelled
|
||||||
|
shuttingDown atomic.Bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// LeaseCacheConfig is the configuration for initializing a new
|
// LeaseCacheConfig is the configuration for initializing a new
|
||||||
|
@ -92,6 +101,7 @@ type LeaseCacheConfig struct {
|
||||||
BaseContext context.Context
|
BaseContext context.Context
|
||||||
Proxier Proxier
|
Proxier Proxier
|
||||||
Logger hclog.Logger
|
Logger hclog.Logger
|
||||||
|
Storage *cacheboltdb.BoltStorage
|
||||||
}
|
}
|
||||||
|
|
||||||
type inflightRequest struct {
|
type inflightRequest struct {
|
||||||
|
@ -141,9 +151,21 @@ func NewLeaseCache(conf *LeaseCacheConfig) (*LeaseCache, error) {
|
||||||
l: &sync.RWMutex{},
|
l: &sync.RWMutex{},
|
||||||
idLocks: locksutil.CreateLocks(),
|
idLocks: locksutil.CreateLocks(),
|
||||||
inflightCache: gocache.New(gocache.NoExpiration, gocache.NoExpiration),
|
inflightCache: gocache.New(gocache.NoExpiration, gocache.NoExpiration),
|
||||||
|
ps: conf.Storage,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetShuttingDown is a setter for the shuttingDown field
|
||||||
|
func (c *LeaseCache) SetShuttingDown(in bool) {
|
||||||
|
c.shuttingDown.Store(in)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetPersistentStorage is a setter for the persistent storage field in
|
||||||
|
// LeaseCache
|
||||||
|
func (c *LeaseCache) SetPersistentStorage(storageIn *cacheboltdb.BoltStorage) {
|
||||||
|
c.ps = storageIn
|
||||||
|
}
|
||||||
|
|
||||||
// checkCacheForRequest checks the cache for a particular request based on its
|
// checkCacheForRequest checks the cache for a particular request based on its
|
||||||
// computed ID. It returns a non-nil *SendResponse if an entry is found.
|
// computed ID. It returns a non-nil *SendResponse if an entry is found.
|
||||||
func (c *LeaseCache) checkCacheForRequest(id string) (*SendResponse, error) {
|
func (c *LeaseCache) checkCacheForRequest(id string) (*SendResponse, error) {
|
||||||
|
@ -275,6 +297,7 @@ func (c *LeaseCache) Send(ctx context.Context, req *SendRequest) (*SendResponse,
|
||||||
ID: id,
|
ID: id,
|
||||||
Namespace: namespace,
|
Namespace: namespace,
|
||||||
RequestPath: req.Request.URL.Path,
|
RequestPath: req.Request.URL.Path,
|
||||||
|
LastRenewed: time.Now().UTC(),
|
||||||
}
|
}
|
||||||
|
|
||||||
secret, err := api.ParseSecret(bytes.NewReader(resp.ResponseBody))
|
secret, err := api.ParseSecret(bytes.NewReader(resp.ResponseBody))
|
||||||
|
@ -332,6 +355,8 @@ func (c *LeaseCache) Send(ctx context.Context, req *SendRequest) (*SendResponse,
|
||||||
index.Lease = secret.LeaseID
|
index.Lease = secret.LeaseID
|
||||||
index.LeaseToken = req.Token
|
index.LeaseToken = req.Token
|
||||||
|
|
||||||
|
index.Type = cacheboltdb.SecretLeaseType
|
||||||
|
|
||||||
case secret.Auth != nil:
|
case secret.Auth != nil:
|
||||||
c.logger.Debug("processing auth response", "method", req.Request.Method, "path", req.Request.URL.Path)
|
c.logger.Debug("processing auth response", "method", req.Request.Method, "path", req.Request.URL.Path)
|
||||||
|
|
||||||
|
@ -360,6 +385,8 @@ func (c *LeaseCache) Send(ctx context.Context, req *SendRequest) (*SendResponse,
|
||||||
index.Token = secret.Auth.ClientToken
|
index.Token = secret.Auth.ClientToken
|
||||||
index.TokenAccessor = secret.Auth.Accessor
|
index.TokenAccessor = secret.Auth.Accessor
|
||||||
|
|
||||||
|
index.Type = cacheboltdb.AuthLeaseType
|
||||||
|
|
||||||
default:
|
default:
|
||||||
// We shouldn't be hitting this, but will err on the side of caution and
|
// We shouldn't be hitting this, but will err on the side of caution and
|
||||||
// simply proxy.
|
// simply proxy.
|
||||||
|
@ -394,9 +421,14 @@ func (c *LeaseCache) Send(ctx context.Context, req *SendRequest) (*SendResponse,
|
||||||
DoneCh: renewCtxInfo.DoneCh,
|
DoneCh: renewCtxInfo.DoneCh,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Add extra information necessary for restoring from persisted cache
|
||||||
|
index.RequestMethod = req.Request.Method
|
||||||
|
index.RequestToken = req.Token
|
||||||
|
index.RequestHeader = req.Request.Header
|
||||||
|
|
||||||
// Store the index in the cache
|
// Store the index in the cache
|
||||||
c.logger.Debug("storing response into the cache", "method", req.Request.Method, "path", req.Request.URL.Path)
|
c.logger.Debug("storing response into the cache", "method", req.Request.Method, "path", req.Request.URL.Path)
|
||||||
err = c.db.Set(index)
|
err = c.Set(ctx, index)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.logger.Error("failed to cache the proxied response", "error", err)
|
c.logger.Error("failed to cache the proxied response", "error", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -420,8 +452,12 @@ func (c *LeaseCache) createCtxInfo(ctx context.Context) *cachememdb.ContextInfo
|
||||||
func (c *LeaseCache) startRenewing(ctx context.Context, index *cachememdb.Index, req *SendRequest, secret *api.Secret) {
|
func (c *LeaseCache) startRenewing(ctx context.Context, index *cachememdb.Index, req *SendRequest, secret *api.Secret) {
|
||||||
defer func() {
|
defer func() {
|
||||||
id := ctx.Value(contextIndexID).(string)
|
id := ctx.Value(contextIndexID).(string)
|
||||||
|
if c.shuttingDown.Load() {
|
||||||
|
c.logger.Trace("not evicting index from cache during shutdown", "id", id, "method", req.Request.Method, "path", req.Request.URL.Path)
|
||||||
|
return
|
||||||
|
}
|
||||||
c.logger.Debug("evicting index from cache", "id", id, "method", req.Request.Method, "path", req.Request.URL.Path)
|
c.logger.Debug("evicting index from cache", "id", id, "method", req.Request.Method, "path", req.Request.URL.Path)
|
||||||
err := c.db.Evict(cachememdb.IndexNameID, id)
|
err := c.Evict(id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.logger.Error("failed to evict index", "id", id, "error", err)
|
c.logger.Error("failed to evict index", "id", id, "error", err)
|
||||||
return
|
return
|
||||||
|
@ -466,6 +502,11 @@ func (c *LeaseCache) startRenewing(ctx context.Context, index *cachememdb.Index,
|
||||||
return
|
return
|
||||||
case <-watcher.RenewCh():
|
case <-watcher.RenewCh():
|
||||||
c.logger.Debug("secret renewed", "path", req.Request.URL.Path)
|
c.logger.Debug("secret renewed", "path", req.Request.URL.Path)
|
||||||
|
if c.ps != nil {
|
||||||
|
if err := c.updateLastRenewed(ctx, index, time.Now().UTC()); err != nil {
|
||||||
|
c.logger.Warn("not able to update lastRenewed time for cached index", "id", index.ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
case <-index.RenewCtxInfo.DoneCh:
|
case <-index.RenewCtxInfo.DoneCh:
|
||||||
// This case indicates the renewal process to shutdown and evict
|
// This case indicates the renewal process to shutdown and evict
|
||||||
// the cache entry. This is triggered when a specific secret
|
// the cache entry. This is triggered when a specific secret
|
||||||
|
@ -477,6 +518,22 @@ func (c *LeaseCache) startRenewing(ctx context.Context, index *cachememdb.Index,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *LeaseCache) updateLastRenewed(ctx context.Context, index *cachememdb.Index, t time.Time) error {
|
||||||
|
idLock := locksutil.LockForKey(c.idLocks, index.ID)
|
||||||
|
idLock.Lock()
|
||||||
|
defer idLock.Unlock()
|
||||||
|
|
||||||
|
getIndex, err := c.db.Get(cachememdb.IndexNameID, index.ID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
index.LastRenewed = t
|
||||||
|
if err := c.Set(ctx, getIndex); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// computeIndexID results in a value that uniquely identifies a request
|
// computeIndexID results in a value that uniquely identifies a request
|
||||||
// received by the agent. It does so by SHA256 hashing the serialized request
|
// received by the agent. It does so by SHA256 hashing the serialized request
|
||||||
// object containing the request path, query parameters and body parameters.
|
// object containing the request path, query parameters and body parameters.
|
||||||
|
@ -642,8 +699,8 @@ func (c *LeaseCache) handleCacheClear(ctx context.Context, in *cacheClearInput)
|
||||||
}
|
}
|
||||||
c.l.Unlock()
|
c.l.Unlock()
|
||||||
|
|
||||||
// Reset the memdb instance
|
// Reset the memdb instance (and persistent storage if enabled)
|
||||||
if err := c.db.Flush(); err != nil {
|
if err := c.Flush(); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -850,6 +907,213 @@ func (c *LeaseCache) handleRevocationRequest(ctx context.Context, req *SendReque
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Set stores the index in the cachememdb, and also stores it in the persistent
|
||||||
|
// cache (if enabled)
|
||||||
|
func (c *LeaseCache) Set(ctx context.Context, index *cachememdb.Index) error {
|
||||||
|
if err := c.db.Set(index); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.ps != nil {
|
||||||
|
b, err := index.Serialize()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.ps.Set(ctx, index.ID, b, index.Type); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
c.logger.Trace("set entry in persistent storage", "type", index.Type, "path", index.RequestPath, "id", index.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Evict removes an Index from the cachememdb, and also removes it from the
|
||||||
|
// persistent cache (if enabled)
|
||||||
|
func (c *LeaseCache) Evict(id string) error {
|
||||||
|
if err := c.db.Evict(cachememdb.IndexNameID, id); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.ps != nil {
|
||||||
|
if err := c.ps.Delete(id); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
c.logger.Trace("deleted item from persistent storage", "id", id)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Flush the cachememdb and persistent cache (if enabled)
|
||||||
|
func (c *LeaseCache) Flush() error {
|
||||||
|
if err := c.db.Flush(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.ps != nil {
|
||||||
|
c.logger.Trace("clearing persistent storage")
|
||||||
|
return c.ps.Clear()
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Restore loads the cachememdb from the persistent storage passed in. Loads
|
||||||
|
// tokens first, since restoring a lease's renewal context and watcher requires
|
||||||
|
// looking up the token in the cachememdb.
|
||||||
|
func (c *LeaseCache) Restore(ctx context.Context, storage *cacheboltdb.BoltStorage) error {
|
||||||
|
// Process tokens first
|
||||||
|
tokens, err := storage.GetByType(ctx, cacheboltdb.TokenType)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := c.restoreTokens(tokens); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Then process auth leases
|
||||||
|
authLeases, err := storage.GetByType(ctx, cacheboltdb.AuthLeaseType)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := c.restoreLeases(authLeases); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Then process secret leases
|
||||||
|
secretLeases, err := storage.GetByType(ctx, cacheboltdb.SecretLeaseType)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := c.restoreLeases(secretLeases); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *LeaseCache) restoreTokens(tokens [][]byte) error {
|
||||||
|
for _, token := range tokens {
|
||||||
|
newIndex, err := cachememdb.Deserialize(token)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
newIndex.RenewCtxInfo = c.createCtxInfo(nil)
|
||||||
|
if err := c.db.Set(newIndex); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
c.logger.Trace("restored token", "id", newIndex.ID)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *LeaseCache) restoreLeases(leases [][]byte) error {
|
||||||
|
for _, lease := range leases {
|
||||||
|
newIndex, err := cachememdb.Deserialize(lease)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if this lease has already expired
|
||||||
|
expired, err := c.hasExpired(time.Now().UTC(), newIndex)
|
||||||
|
if err != nil {
|
||||||
|
c.logger.Warn("failed to check if lease is expired", "id", newIndex.ID, "error", err)
|
||||||
|
}
|
||||||
|
if expired {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.restoreLeaseRenewCtx(newIndex); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := c.db.Set(newIndex); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
c.logger.Trace("restored lease", "id", newIndex.ID, "path", newIndex.RequestPath)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// restoreLeaseRenewCtx re-creates a RenewCtx for an index object and starts
|
||||||
|
// the watcher go routine
|
||||||
|
func (c *LeaseCache) restoreLeaseRenewCtx(index *cachememdb.Index) error {
|
||||||
|
if index.Response == nil {
|
||||||
|
return fmt.Errorf("cached response was nil for %s", index.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse the secret to determine which type it is
|
||||||
|
reader := bufio.NewReader(bytes.NewReader(index.Response))
|
||||||
|
resp, err := http.ReadResponse(reader, nil)
|
||||||
|
if err != nil {
|
||||||
|
c.logger.Error("failed to deserialize response", "error", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
secret, err := api.ParseSecret(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
c.logger.Error("failed to parse response as secret", "error", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var renewCtxInfo *cachememdb.ContextInfo
|
||||||
|
switch {
|
||||||
|
case secret.LeaseID != "":
|
||||||
|
entry, err := c.db.Get(cachememdb.IndexNameToken, index.RequestToken)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if entry == nil {
|
||||||
|
return fmt.Errorf("could not find parent Token %s for req path %s", index.RequestToken, index.RequestPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Derive a context for renewal using the token's context
|
||||||
|
renewCtxInfo = cachememdb.NewContextInfo(entry.RenewCtxInfo.Ctx)
|
||||||
|
|
||||||
|
case secret.Auth != nil:
|
||||||
|
var parentCtx context.Context
|
||||||
|
if !secret.Auth.Orphan {
|
||||||
|
entry, err := c.db.Get(cachememdb.IndexNameToken, index.RequestToken)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// If parent token is not managed by the agent, child shouldn't be
|
||||||
|
// either.
|
||||||
|
if entry == nil {
|
||||||
|
return fmt.Errorf("could not find parent Token %s for req path %s", index.RequestToken, index.RequestPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.logger.Debug("setting parent context", "method", index.RequestMethod, "path", index.RequestPath)
|
||||||
|
parentCtx = entry.RenewCtxInfo.Ctx
|
||||||
|
}
|
||||||
|
renewCtxInfo = c.createCtxInfo(parentCtx)
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unknown cached index item: %s", index.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
renewCtx := context.WithValue(renewCtxInfo.Ctx, contextIndexID, index.ID)
|
||||||
|
index.RenewCtxInfo = &cachememdb.ContextInfo{
|
||||||
|
Ctx: renewCtx,
|
||||||
|
CancelFunc: renewCtxInfo.CancelFunc,
|
||||||
|
DoneCh: renewCtxInfo.DoneCh,
|
||||||
|
}
|
||||||
|
|
||||||
|
sendReq := &SendRequest{
|
||||||
|
Token: index.RequestToken,
|
||||||
|
Request: &http.Request{
|
||||||
|
Header: index.RequestHeader,
|
||||||
|
Method: index.RequestMethod,
|
||||||
|
URL: &url.URL{
|
||||||
|
Path: index.RequestPath,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
go c.startRenewing(renewCtx, index, sendReq, secret)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// deriveNamespaceAndRevocationPath returns the namespace and relative path for
|
// deriveNamespaceAndRevocationPath returns the namespace and relative path for
|
||||||
// revocation paths.
|
// revocation paths.
|
||||||
//
|
//
|
||||||
|
@ -912,9 +1176,11 @@ func (c *LeaseCache) RegisterAutoAuthToken(token string) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// If the index is found, defer its cancelFunc
|
// If the index is found, just keep it in the cache and ignore the incoming
|
||||||
|
// token (since they're the same)
|
||||||
if oldIndex != nil {
|
if oldIndex != nil {
|
||||||
defer oldIndex.RenewCtxInfo.CancelFunc()
|
c.logger.Trace("auto-auth token already exists in cache; no need to store it again")
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// The following randomly generated values are required for index stored by
|
// The following randomly generated values are required for index stored by
|
||||||
|
@ -938,6 +1204,7 @@ func (c *LeaseCache) RegisterAutoAuthToken(token string) error {
|
||||||
Token: token,
|
Token: token,
|
||||||
Namespace: namespace,
|
Namespace: namespace,
|
||||||
RequestPath: requestPath,
|
RequestPath: requestPath,
|
||||||
|
Type: cacheboltdb.TokenType,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Derive a context off of the lease cache's base context
|
// Derive a context off of the lease cache's base context
|
||||||
|
@ -951,7 +1218,7 @@ func (c *LeaseCache) RegisterAutoAuthToken(token string) error {
|
||||||
|
|
||||||
// Store the index in the cache
|
// Store the index in the cache
|
||||||
c.logger.Debug("storing auto-auth token into the cache")
|
c.logger.Debug("storing auto-auth token into the cache")
|
||||||
err = c.db.Set(index)
|
err = c.Set(c.baseCtxInfo.Ctx, index)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.logger.Error("failed to cache the auto-auth token", "error", err)
|
c.logger.Error("failed to cache the auto-auth token", "error", err)
|
||||||
return err
|
return err
|
||||||
|
@ -997,3 +1264,32 @@ func parseCacheClearInput(req *cacheClearRequest) (*cacheClearInput, error) {
|
||||||
|
|
||||||
return in, nil
|
return in, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *LeaseCache) hasExpired(currentTime time.Time, index *cachememdb.Index) (bool, error) {
|
||||||
|
reader := bufio.NewReader(bytes.NewReader(index.Response))
|
||||||
|
resp, err := http.ReadResponse(reader, nil)
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("failed to deserialize response: %w", err)
|
||||||
|
}
|
||||||
|
secret, err := api.ParseSecret(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("failed to parse response as secret: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
elapsed := currentTime.Sub(index.LastRenewed)
|
||||||
|
var leaseDuration int
|
||||||
|
switch index.Type {
|
||||||
|
case cacheboltdb.AuthLeaseType:
|
||||||
|
leaseDuration = secret.Auth.LeaseDuration
|
||||||
|
case cacheboltdb.SecretLeaseType:
|
||||||
|
leaseDuration = secret.LeaseDuration
|
||||||
|
default:
|
||||||
|
return false, fmt.Errorf("index type %q unexpected in expiration check", index.Type)
|
||||||
|
}
|
||||||
|
|
||||||
|
if int(elapsed.Seconds()) > leaseDuration {
|
||||||
|
c.logger.Trace("secret has expired", "id", index.ID, "elapsed", elapsed, "lease duration", leaseDuration)
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
|
@ -3,9 +3,11 @@ package cache
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"os"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
@ -15,9 +17,13 @@ import (
|
||||||
"github.com/go-test/deep"
|
"github.com/go-test/deep"
|
||||||
hclog "github.com/hashicorp/go-hclog"
|
hclog "github.com/hashicorp/go-hclog"
|
||||||
"github.com/hashicorp/vault/api"
|
"github.com/hashicorp/vault/api"
|
||||||
|
"github.com/hashicorp/vault/command/agent/cache/cacheboltdb"
|
||||||
"github.com/hashicorp/vault/command/agent/cache/cachememdb"
|
"github.com/hashicorp/vault/command/agent/cache/cachememdb"
|
||||||
|
"github.com/hashicorp/vault/command/agent/cache/keymanager"
|
||||||
"github.com/hashicorp/vault/sdk/helper/consts"
|
"github.com/hashicorp/vault/sdk/helper/consts"
|
||||||
"github.com/hashicorp/vault/sdk/helper/logging"
|
"github.com/hashicorp/vault/sdk/helper/logging"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
"go.uber.org/atomic"
|
"go.uber.org/atomic"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -63,6 +69,24 @@ func testNewLeaseCacheWithDelay(t *testing.T, cacheable bool, delay int) *LeaseC
|
||||||
return lc
|
return lc
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func testNewLeaseCacheWithPersistence(t *testing.T, responses []*SendResponse, storage *cacheboltdb.BoltStorage) *LeaseCache {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
client, err := api.NewClient(api.DefaultConfig())
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
lc, err := NewLeaseCache(&LeaseCacheConfig{
|
||||||
|
Client: client,
|
||||||
|
BaseContext: context.Background(),
|
||||||
|
Proxier: newMockProxier(responses),
|
||||||
|
Logger: logging.NewVaultLogger(hclog.Trace).Named("cache.leasecache"),
|
||||||
|
Storage: storage,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
return lc
|
||||||
|
}
|
||||||
|
|
||||||
func TestCache_ComputeIndexID(t *testing.T) {
|
func TestCache_ComputeIndexID(t *testing.T) {
|
||||||
type args struct {
|
type args struct {
|
||||||
req *http.Request
|
req *http.Request
|
||||||
|
@ -649,3 +673,394 @@ func TestLeaseCache_Concurrent_Cacheable(t *testing.T) {
|
||||||
t.Fatalf("Should have returned a cached response 99 times, got %d", cacheCount.Load())
|
t.Fatalf("Should have returned a cached response 99 times, got %d", cacheCount.Load())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func setupBoltStorage(t *testing.T) (tempCacheDir string, boltStorage *cacheboltdb.BoltStorage) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
km, err := keymanager.NewPassthroughKeyManager(nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
tempCacheDir, err = ioutil.TempDir("", "agent-cache-test")
|
||||||
|
require.NoError(t, err)
|
||||||
|
boltStorage, err = cacheboltdb.NewBoltStorage(&cacheboltdb.BoltStorageConfig{
|
||||||
|
Path: tempCacheDir,
|
||||||
|
Logger: hclog.Default(),
|
||||||
|
Wrapper: km.Wrapper(),
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, boltStorage)
|
||||||
|
// The calling function should `defer boltStorage.Close()` and `defer os.RemoveAll(tempCacheDir)`
|
||||||
|
return tempCacheDir, boltStorage
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLeaseCache_PersistAndRestore(t *testing.T) {
|
||||||
|
// Emulate 4 responses from the api proxy. The first two use the auto-auth
|
||||||
|
// token, and the last two use another token.
|
||||||
|
responses := []*SendResponse{
|
||||||
|
newTestSendResponse(200, `{"auth": {"client_token": "testtoken", "renewable": true, "lease_duration": 600}}`),
|
||||||
|
newTestSendResponse(201, `{"lease_id": "foo", "renewable": true, "data": {"value": "foo"}, "lease_duration": 600}`),
|
||||||
|
newTestSendResponse(202, `{"auth": {"client_token": "testtoken2", "renewable": true, "orphan": true, "lease_duration": 600}}`),
|
||||||
|
newTestSendResponse(203, `{"lease_id": "secret2-lease", "renewable": true, "data": {"number": "two"}, "lease_duration": 600}`),
|
||||||
|
}
|
||||||
|
|
||||||
|
tempDir, boltStorage := setupBoltStorage(t)
|
||||||
|
defer os.RemoveAll(tempDir)
|
||||||
|
defer boltStorage.Close()
|
||||||
|
lc := testNewLeaseCacheWithPersistence(t, responses, boltStorage)
|
||||||
|
|
||||||
|
// Register an auto-auth token so that the token and lease requests are cached
|
||||||
|
lc.RegisterAutoAuthToken("autoauthtoken")
|
||||||
|
|
||||||
|
cacheTests := []struct {
|
||||||
|
token string
|
||||||
|
method string
|
||||||
|
urlPath string
|
||||||
|
body string
|
||||||
|
wantStatusCode int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
// Make a request. A response with a new token is returned to the
|
||||||
|
// lease cache and that will be cached.
|
||||||
|
token: "autoauthtoken",
|
||||||
|
method: "GET",
|
||||||
|
urlPath: "http://example.com/v1/sample/api",
|
||||||
|
body: `{"value": "input"}`,
|
||||||
|
wantStatusCode: responses[0].Response.StatusCode,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// Modify the request a little bit to ensure the second response is
|
||||||
|
// returned to the lease cache.
|
||||||
|
token: "autoauthtoken",
|
||||||
|
method: "GET",
|
||||||
|
urlPath: "http://example.com/v1/sample/api",
|
||||||
|
body: `{"value": "input_changed"}`,
|
||||||
|
wantStatusCode: responses[1].Response.StatusCode,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// Simulate an approle login to get another token
|
||||||
|
method: "PUT",
|
||||||
|
urlPath: "http://example.com/v1/auth/approle/login",
|
||||||
|
body: `{"role_id": "my role", "secret_id": "my secret"}`,
|
||||||
|
wantStatusCode: responses[2].Response.StatusCode,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// Test caching with the token acquired from the approle login
|
||||||
|
token: "testtoken2",
|
||||||
|
method: "GET",
|
||||||
|
urlPath: "http://example.com/v1/sample2/api",
|
||||||
|
body: `{"second": "input"}`,
|
||||||
|
wantStatusCode: responses[3].Response.StatusCode,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, ct := range cacheTests {
|
||||||
|
// Send once to cache
|
||||||
|
sendReq := &SendRequest{
|
||||||
|
Token: ct.token,
|
||||||
|
Request: httptest.NewRequest(ct.method, ct.urlPath, strings.NewReader(ct.body)),
|
||||||
|
}
|
||||||
|
resp, err := lc.Send(context.Background(), sendReq)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, resp.Response.StatusCode, ct.wantStatusCode, "expected proxied response")
|
||||||
|
assert.Nil(t, resp.CacheMeta)
|
||||||
|
|
||||||
|
// Send again to test cache. If this isn't cached, the response returned
|
||||||
|
// will be the next in the list and the status code will not match.
|
||||||
|
sendCacheReq := &SendRequest{
|
||||||
|
Token: ct.token,
|
||||||
|
Request: httptest.NewRequest(ct.method, ct.urlPath, strings.NewReader(ct.body)),
|
||||||
|
}
|
||||||
|
respCached, err := lc.Send(context.Background(), sendCacheReq)
|
||||||
|
require.NoError(t, err, "failed to send request %+v", ct)
|
||||||
|
assert.Equal(t, respCached.Response.StatusCode, ct.wantStatusCode, "expected proxied response")
|
||||||
|
require.NotNil(t, respCached.CacheMeta)
|
||||||
|
assert.True(t, respCached.CacheMeta.Hit)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now we know the cache is working, so try restoring from the persisted
|
||||||
|
// cache's storage
|
||||||
|
restoredCache := testNewLeaseCache(t, nil)
|
||||||
|
|
||||||
|
err := restoredCache.Restore(context.Background(), boltStorage)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Now compare before and after
|
||||||
|
beforeDB, err := lc.db.GetByPrefix(cachememdb.IndexNameID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Len(t, beforeDB, 5)
|
||||||
|
|
||||||
|
for _, cachedItem := range beforeDB {
|
||||||
|
restoredItem, err := restoredCache.db.Get(cachememdb.IndexNameID, cachedItem.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, cachedItem.ID, restoredItem.ID)
|
||||||
|
assert.Equal(t, cachedItem.Lease, restoredItem.Lease)
|
||||||
|
assert.Equal(t, cachedItem.LeaseToken, restoredItem.LeaseToken)
|
||||||
|
assert.Equal(t, cachedItem.Namespace, restoredItem.Namespace)
|
||||||
|
assert.Equal(t, cachedItem.RequestHeader, restoredItem.RequestHeader)
|
||||||
|
assert.Equal(t, cachedItem.RequestMethod, restoredItem.RequestMethod)
|
||||||
|
assert.Equal(t, cachedItem.RequestPath, restoredItem.RequestPath)
|
||||||
|
assert.Equal(t, cachedItem.RequestToken, restoredItem.RequestToken)
|
||||||
|
assert.Equal(t, cachedItem.Response, restoredItem.Response)
|
||||||
|
assert.Equal(t, cachedItem.Token, restoredItem.Token)
|
||||||
|
assert.Equal(t, cachedItem.TokenAccessor, restoredItem.TokenAccessor)
|
||||||
|
assert.Equal(t, cachedItem.TokenParent, restoredItem.TokenParent)
|
||||||
|
|
||||||
|
// check what we can in the renewal context
|
||||||
|
assert.NotEmpty(t, restoredItem.RenewCtxInfo.CancelFunc)
|
||||||
|
assert.NotZero(t, restoredItem.RenewCtxInfo.DoneCh)
|
||||||
|
require.NotEmpty(t, restoredItem.RenewCtxInfo.Ctx)
|
||||||
|
assert.Equal(t,
|
||||||
|
cachedItem.RenewCtxInfo.Ctx.Value(contextIndexID),
|
||||||
|
restoredItem.RenewCtxInfo.Ctx.Value(contextIndexID),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
afterDB, err := restoredCache.db.GetByPrefix(cachememdb.IndexNameID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Len(t, afterDB, 5)
|
||||||
|
|
||||||
|
// And finally send the cache requests once to make sure they're all being
|
||||||
|
// served from the restoredCache
|
||||||
|
for _, ct := range cacheTests {
|
||||||
|
sendCacheReq := &SendRequest{
|
||||||
|
Token: ct.token,
|
||||||
|
Request: httptest.NewRequest(ct.method, ct.urlPath, strings.NewReader(ct.body)),
|
||||||
|
}
|
||||||
|
respCached, err := restoredCache.Send(context.Background(), sendCacheReq)
|
||||||
|
require.NoError(t, err, "failed to send request %+v", ct)
|
||||||
|
assert.Equal(t, respCached.Response.StatusCode, ct.wantStatusCode, "expected proxied response")
|
||||||
|
require.NotNil(t, respCached.CacheMeta)
|
||||||
|
assert.True(t, respCached.CacheMeta.Hit)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEvictPersistent(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
responses := []*SendResponse{
|
||||||
|
newTestSendResponse(201, `{"lease_id": "foo", "renewable": true, "data": {"value": "foo"}}`),
|
||||||
|
}
|
||||||
|
|
||||||
|
tempDir, boltStorage := setupBoltStorage(t)
|
||||||
|
defer os.RemoveAll(tempDir)
|
||||||
|
defer boltStorage.Close()
|
||||||
|
lc := testNewLeaseCacheWithPersistence(t, responses, boltStorage)
|
||||||
|
|
||||||
|
lc.RegisterAutoAuthToken("autoauthtoken")
|
||||||
|
|
||||||
|
// populate cache by sending request through
|
||||||
|
sendReq := &SendRequest{
|
||||||
|
Token: "autoauthtoken",
|
||||||
|
Request: httptest.NewRequest("GET", "http://example.com/v1/sample/api", strings.NewReader(`{"value": "some_input"}`)),
|
||||||
|
}
|
||||||
|
resp, err := lc.Send(context.Background(), sendReq)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, resp.Response.StatusCode, 201, "expected proxied response")
|
||||||
|
assert.Nil(t, resp.CacheMeta)
|
||||||
|
|
||||||
|
// Check bolt for the cached lease
|
||||||
|
secrets, err := lc.ps.GetByType(ctx, cacheboltdb.SecretLeaseType)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Len(t, secrets, 1)
|
||||||
|
|
||||||
|
// Call clear for the request path
|
||||||
|
err = lc.handleCacheClear(context.Background(), &cacheClearInput{
|
||||||
|
Type: "request_path",
|
||||||
|
RequestPath: "/v1/sample/api",
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
time.Sleep(2 * time.Second)
|
||||||
|
|
||||||
|
// Check that cached item is gone
|
||||||
|
secrets, err = lc.ps.GetByType(ctx, cacheboltdb.SecretLeaseType)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Len(t, secrets, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRegisterAutoAuth_sameToken(t *testing.T) {
|
||||||
|
// If the auto-auth token already exists in the cache, it should not be
|
||||||
|
// stored again in a new index.
|
||||||
|
lc := testNewLeaseCache(t, nil)
|
||||||
|
err := lc.RegisterAutoAuthToken("autoauthtoken")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
oldTokenIndex, err := lc.db.Get(cachememdb.IndexNameToken, "autoauthtoken")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
oldTokenID := oldTokenIndex.ID
|
||||||
|
|
||||||
|
// register the same token again
|
||||||
|
err = lc.RegisterAutoAuthToken("autoauthtoken")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// check that there's only one index for autoauthtoken
|
||||||
|
entries, err := lc.db.GetByPrefix(cachememdb.IndexNameToken, "autoauthtoken")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Len(t, entries, 1)
|
||||||
|
|
||||||
|
newTokenIndex, err := lc.db.Get(cachememdb.IndexNameToken, "autoauthtoken")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// compare the ID's since those are randomly generated when an index for a
|
||||||
|
// token is added to the cache, so if a new token was added, the id's will
|
||||||
|
// not match.
|
||||||
|
assert.Equal(t, oldTokenID, newTokenIndex.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_hasExpired(t *testing.T) {
|
||||||
|
|
||||||
|
responses := []*SendResponse{
|
||||||
|
newTestSendResponse(200, `{"auth": {"client_token": "testtoken", "renewable": true, "lease_duration": 60}}`),
|
||||||
|
newTestSendResponse(201, `{"lease_id": "foo", "renewable": true, "data": {"value": "foo"}, "lease_duration": 60}`),
|
||||||
|
}
|
||||||
|
lc := testNewLeaseCache(t, responses)
|
||||||
|
lc.RegisterAutoAuthToken("autoauthtoken")
|
||||||
|
|
||||||
|
cacheTests := []struct {
|
||||||
|
token string
|
||||||
|
urlPath string
|
||||||
|
leaseType string
|
||||||
|
wantStatusCode int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
// auth lease
|
||||||
|
token: "autoauthtoken",
|
||||||
|
urlPath: "/v1/sample/auth",
|
||||||
|
leaseType: cacheboltdb.AuthLeaseType,
|
||||||
|
wantStatusCode: responses[0].Response.StatusCode,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// secret lease
|
||||||
|
token: "autoauthtoken",
|
||||||
|
urlPath: "/v1/sample/secret",
|
||||||
|
leaseType: cacheboltdb.SecretLeaseType,
|
||||||
|
wantStatusCode: responses[1].Response.StatusCode,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, ct := range cacheTests {
|
||||||
|
// Send once to cache
|
||||||
|
urlPath := "http://example.com" + ct.urlPath
|
||||||
|
sendReq := &SendRequest{
|
||||||
|
Token: ct.token,
|
||||||
|
Request: httptest.NewRequest("GET", urlPath, strings.NewReader(`{"value": "input"}`)),
|
||||||
|
}
|
||||||
|
resp, err := lc.Send(context.Background(), sendReq)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, resp.Response.StatusCode, ct.wantStatusCode, "expected proxied response")
|
||||||
|
assert.Nil(t, resp.CacheMeta)
|
||||||
|
|
||||||
|
// get the Index out of the mem cache
|
||||||
|
index, err := lc.db.Get(cachememdb.IndexNameRequestPath, "root/", ct.urlPath)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, ct.leaseType, index.Type)
|
||||||
|
|
||||||
|
// The lease duration is 60 seconds, so time.Now() should be within that
|
||||||
|
notExpired, err := lc.hasExpired(time.Now().UTC(), index)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.False(t, notExpired)
|
||||||
|
|
||||||
|
// In 90 seconds the index should be "expired"
|
||||||
|
futureTime := time.Now().UTC().Add(time.Second * 90)
|
||||||
|
expired, err := lc.hasExpired(futureTime, index)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, expired)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLeaseCache_hasExpired_wrong_type(t *testing.T) {
|
||||||
|
index := &cachememdb.Index{
|
||||||
|
Type: cacheboltdb.TokenType,
|
||||||
|
Response: []byte(`HTTP/0.0 200 OK
|
||||||
|
Content-Type: application/json
|
||||||
|
Date: Tue, 02 Mar 2021 17:54:16 GMT
|
||||||
|
|
||||||
|
{"auth": {"client_token": "testtoken", "renewable": true, "lease_duration": 60}}`),
|
||||||
|
}
|
||||||
|
|
||||||
|
lc := testNewLeaseCache(t, nil)
|
||||||
|
expired, err := lc.hasExpired(time.Now().UTC(), index)
|
||||||
|
assert.False(t, expired)
|
||||||
|
assert.EqualError(t, err, `index type "token" unexpected in expiration check`)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLeaseCacheRestore_expired(t *testing.T) {
|
||||||
|
// Emulate 2 responses from the api proxy, both expired
|
||||||
|
responses := []*SendResponse{
|
||||||
|
newTestSendResponse(200, `{"auth": {"client_token": "testtoken", "renewable": true, "lease_duration": -600}}`),
|
||||||
|
newTestSendResponse(201, `{"lease_id": "foo", "renewable": true, "data": {"value": "foo"}, "lease_duration": -600}`),
|
||||||
|
}
|
||||||
|
|
||||||
|
tempDir, boltStorage := setupBoltStorage(t)
|
||||||
|
defer os.RemoveAll(tempDir)
|
||||||
|
defer boltStorage.Close()
|
||||||
|
lc := testNewLeaseCacheWithPersistence(t, responses, boltStorage)
|
||||||
|
|
||||||
|
// Register an auto-auth token so that the token and lease requests are cached in mem
|
||||||
|
lc.RegisterAutoAuthToken("autoauthtoken")
|
||||||
|
|
||||||
|
cacheTests := []struct {
|
||||||
|
token string
|
||||||
|
method string
|
||||||
|
urlPath string
|
||||||
|
body string
|
||||||
|
wantStatusCode int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
// Make a request. A response with a new token is returned to the
|
||||||
|
// lease cache and that will be cached.
|
||||||
|
token: "autoauthtoken",
|
||||||
|
method: "GET",
|
||||||
|
urlPath: "http://example.com/v1/sample/api",
|
||||||
|
body: `{"value": "input"}`,
|
||||||
|
wantStatusCode: responses[0].Response.StatusCode,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// Modify the request a little bit to ensure the second response is
|
||||||
|
// returned to the lease cache.
|
||||||
|
token: "autoauthtoken",
|
||||||
|
method: "GET",
|
||||||
|
urlPath: "http://example.com/v1/sample/api",
|
||||||
|
body: `{"value": "input_changed"}`,
|
||||||
|
wantStatusCode: responses[1].Response.StatusCode,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, ct := range cacheTests {
|
||||||
|
// Send once to cache
|
||||||
|
sendReq := &SendRequest{
|
||||||
|
Token: ct.token,
|
||||||
|
Request: httptest.NewRequest(ct.method, ct.urlPath, strings.NewReader(ct.body)),
|
||||||
|
}
|
||||||
|
resp, err := lc.Send(context.Background(), sendReq)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, resp.Response.StatusCode, ct.wantStatusCode, "expected proxied response")
|
||||||
|
assert.Nil(t, resp.CacheMeta)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Restore from the persisted cache's storage
|
||||||
|
restoredCache := testNewLeaseCache(t, nil)
|
||||||
|
|
||||||
|
err := restoredCache.Restore(context.Background(), boltStorage)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// The original mem cache should have all three items
|
||||||
|
beforeDB, err := lc.db.GetByPrefix(cachememdb.IndexNameID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Len(t, beforeDB, 3)
|
||||||
|
|
||||||
|
// There should only be one item in the restored cache: the autoauth token
|
||||||
|
afterDB, err := restoredCache.db.GetByPrefix(cachememdb.IndexNameID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Len(t, afterDB, 1)
|
||||||
|
|
||||||
|
// Just verify that the one item in the restored mem cache matches one in the original mem cache, and that it's the auto-auth token
|
||||||
|
beforeItem, err := lc.db.Get(cachememdb.IndexNameID, afterDB[0].ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, beforeItem)
|
||||||
|
|
||||||
|
assert.Equal(t, "autoauthtoken", afterDB[0].Token)
|
||||||
|
assert.Equal(t, cacheboltdb.TokenType, afterDB[0].Type)
|
||||||
|
}
|
||||||
|
|
|
@ -50,6 +50,16 @@ type Cache struct {
|
||||||
ForceAutoAuthToken bool `hcl:"-"`
|
ForceAutoAuthToken bool `hcl:"-"`
|
||||||
EnforceConsistency string `hcl:"enforce_consistency"`
|
EnforceConsistency string `hcl:"enforce_consistency"`
|
||||||
WhenInconsistent string `hcl:"when_inconsistent"`
|
WhenInconsistent string `hcl:"when_inconsistent"`
|
||||||
|
Persist *Persist `hcl:"persist"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Persist contains configuration needed for persistent caching
|
||||||
|
type Persist struct {
|
||||||
|
Type string
|
||||||
|
Path string `hcl:"path"`
|
||||||
|
KeepAfterImport bool `hcl:"keep_after_import"`
|
||||||
|
ExitOnErr bool `hcl:"exit_on_err"`
|
||||||
|
ServiceAccountTokenFile string `hcl:"service_account_token_file"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// AutoAuth is the configured authentication method and sinks
|
// AutoAuth is the configured authentication method and sinks
|
||||||
|
@ -309,8 +319,51 @@ func parseCache(result *Config, list *ast.ObjectList) error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
result.Cache = &c
|
result.Cache = &c
|
||||||
|
|
||||||
|
subs, ok := item.Val.(*ast.ObjectType)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("could not parse %q as an object", name)
|
||||||
|
}
|
||||||
|
subList := subs.List
|
||||||
|
if err := parsePersist(result, subList); err != nil {
|
||||||
|
return fmt.Errorf("error parsing persist: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parsePersist(result *Config, list *ast.ObjectList) error {
|
||||||
|
name := "persist"
|
||||||
|
|
||||||
|
persistList := list.Filter(name)
|
||||||
|
if len(persistList.Items) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(persistList.Items) > 1 {
|
||||||
|
return fmt.Errorf("only one %q block is required", name)
|
||||||
|
}
|
||||||
|
|
||||||
|
item := persistList.Items[0]
|
||||||
|
|
||||||
|
var p Persist
|
||||||
|
err := hcl.DecodeObject(&p, item.Val)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.Type == "" {
|
||||||
|
if len(item.Keys) == 1 {
|
||||||
|
p.Type = strings.ToLower(item.Keys[0].Token.Value().(string))
|
||||||
|
}
|
||||||
|
if p.Type == "" {
|
||||||
|
return errors.New("persist type must be specified")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
result.Cache.Persist = &p
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -66,6 +66,13 @@ func TestLoadConfigFile_AgentCache(t *testing.T) {
|
||||||
UseAutoAuthToken: true,
|
UseAutoAuthToken: true,
|
||||||
UseAutoAuthTokenRaw: true,
|
UseAutoAuthTokenRaw: true,
|
||||||
ForceAutoAuthToken: false,
|
ForceAutoAuthToken: false,
|
||||||
|
Persist: &Persist{
|
||||||
|
Type: "kubernetes",
|
||||||
|
Path: "/vault/agent-cache/",
|
||||||
|
KeepAfterImport: true,
|
||||||
|
ExitOnErr: true,
|
||||||
|
ServiceAccountTokenFile: "/tmp/serviceaccount/token",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
Vault: &Vault{
|
Vault: &Vault{
|
||||||
Address: "http://127.0.0.1:1111",
|
Address: "http://127.0.0.1:1111",
|
||||||
|
@ -445,6 +452,52 @@ func TestLoadConfigFile_AgentCache_AutoAuth_False(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestLoadConfigFile_AgentCache_Persist(t *testing.T) {
|
||||||
|
config, err := LoadConfig("./test-fixtures/config-cache-persist-false.hcl")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("err: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
expected := &Config{
|
||||||
|
Cache: &Cache{
|
||||||
|
Persist: &Persist{
|
||||||
|
Type: "kubernetes",
|
||||||
|
Path: "/vault/agent-cache/",
|
||||||
|
KeepAfterImport: false,
|
||||||
|
ExitOnErr: false,
|
||||||
|
ServiceAccountTokenFile: "",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
SharedConfig: &configutil.SharedConfig{
|
||||||
|
PidFile: "./pidfile",
|
||||||
|
Listeners: []*configutil.Listener{
|
||||||
|
{
|
||||||
|
Type: "tcp",
|
||||||
|
Address: "127.0.0.1:8300",
|
||||||
|
TLSDisable: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
config.Listeners[0].RawConfig = nil
|
||||||
|
if diff := deep.Equal(config, expected); diff != nil {
|
||||||
|
t.Fatal(diff)
|
||||||
|
}
|
||||||
|
|
||||||
|
config.Listeners[0].RawConfig = nil
|
||||||
|
if diff := deep.Equal(config, expected); diff != nil {
|
||||||
|
t.Fatal(diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadConfigFile_AgentCache_PersistMissingType(t *testing.T) {
|
||||||
|
_, err := LoadConfig("./test-fixtures/config-cache-persist-empty-type.hcl")
|
||||||
|
if err == nil || os.IsNotExist(err) {
|
||||||
|
t.Fatal("expected error or file is missing")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// TestLoadConfigFile_Template tests template definitions in Vault Agent
|
// TestLoadConfigFile_Template tests template definitions in Vault Agent
|
||||||
func TestLoadConfigFile_Template(t *testing.T) {
|
func TestLoadConfigFile_Template(t *testing.T) {
|
||||||
testCases := map[string]struct {
|
testCases := map[string]struct {
|
||||||
|
|
|
@ -21,6 +21,12 @@ auto_auth {
|
||||||
|
|
||||||
cache {
|
cache {
|
||||||
use_auto_auth_token = true
|
use_auto_auth_token = true
|
||||||
|
persist "kubernetes" {
|
||||||
|
path = "/vault/agent-cache/"
|
||||||
|
keep_after_import = true
|
||||||
|
exit_on_err = true
|
||||||
|
service_account_token_file = "/tmp/serviceaccount/token"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
listener {
|
listener {
|
||||||
|
|
|
@ -0,0 +1,12 @@
|
||||||
|
pid_file = "./pidfile"
|
||||||
|
|
||||||
|
cache {
|
||||||
|
persist = {
|
||||||
|
path = "/vault/agent-cache/"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
listener "tcp" {
|
||||||
|
address = "127.0.0.1:8300"
|
||||||
|
tls_disable = true
|
||||||
|
}
|
|
@ -0,0 +1,14 @@
|
||||||
|
pid_file = "./pidfile"
|
||||||
|
|
||||||
|
cache {
|
||||||
|
persist "kubernetes" {
|
||||||
|
exit_on_err = false
|
||||||
|
keep_after_import = false
|
||||||
|
path = "/vault/agent-cache/"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
listener "tcp" {
|
||||||
|
address = "127.0.0.1:8300"
|
||||||
|
tls_disable = true
|
||||||
|
}
|
|
@ -21,6 +21,13 @@ auto_auth {
|
||||||
|
|
||||||
cache {
|
cache {
|
||||||
use_auto_auth_token = true
|
use_auto_auth_token = true
|
||||||
|
persist = {
|
||||||
|
type = "kubernetes"
|
||||||
|
path = "/vault/agent-cache/"
|
||||||
|
keep_after_import = true
|
||||||
|
exit_on_err = true
|
||||||
|
service_account_token_file = "/tmp/serviceaccount/token"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
listener "unix" {
|
listener "unix" {
|
||||||
|
|
|
@ -256,7 +256,7 @@ func newRunnerConfig(sc *ServerConfig, templates ctconfig.TemplateConfigs) (*ctc
|
||||||
}
|
}
|
||||||
|
|
||||||
// Use the cache if available or fallback to the Vault server values.
|
// Use the cache if available or fallback to the Vault server values.
|
||||||
if sc.AgentConfig.Cache != nil && len(sc.AgentConfig.Listeners) != 0 {
|
if sc.AgentConfig.Cache != nil && sc.AgentConfig.Cache.Persist != nil && len(sc.AgentConfig.Listeners) != 0 {
|
||||||
scheme := "unix://"
|
scheme := "unix://"
|
||||||
if sc.AgentConfig.Listeners[0].Type == "tcp" {
|
if sc.AgentConfig.Listeners[0].Type == "tcp" {
|
||||||
scheme = "https://"
|
scheme = "https://"
|
||||||
|
|
|
@ -28,7 +28,7 @@ func TestNewServer(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func newAgentConfig(listeners []*configutil.Listener, enableCache bool) *config.Config {
|
func newAgentConfig(listeners []*configutil.Listener, enableCache, enablePersisentCache bool) *config.Config {
|
||||||
agentConfig := &config.Config{
|
agentConfig := &config.Config{
|
||||||
SharedConfig: &configutil.SharedConfig{
|
SharedConfig: &configutil.SharedConfig{
|
||||||
PidFile: "./pidfile",
|
PidFile: "./pidfile",
|
||||||
|
@ -65,7 +65,13 @@ func newAgentConfig(listeners []*configutil.Listener, enableCache bool) *config.
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
if enableCache {
|
if enableCache {
|
||||||
agentConfig.Cache = &config.Cache{UseAutoAuthToken: true}
|
agentConfig.Cache = &config.Cache{
|
||||||
|
UseAutoAuthToken: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if enablePersisentCache {
|
||||||
|
agentConfig.Cache.Persist = &config.Persist{Type: "kubernetes"}
|
||||||
}
|
}
|
||||||
|
|
||||||
return agentConfig
|
return agentConfig
|
||||||
|
@ -94,7 +100,7 @@ func TestCacheConfigUnix(t *testing.T) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
agentConfig := newAgentConfig(listeners, true)
|
agentConfig := newAgentConfig(listeners, true, true)
|
||||||
serverConfig := ServerConfig{AgentConfig: agentConfig}
|
serverConfig := ServerConfig{AgentConfig: agentConfig}
|
||||||
|
|
||||||
ctConfig, err := newRunnerConfig(&serverConfig, ctconfig.TemplateConfigs{})
|
ctConfig, err := newRunnerConfig(&serverConfig, ctconfig.TemplateConfigs{})
|
||||||
|
@ -131,7 +137,7 @@ func TestCacheConfigHTTP(t *testing.T) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
agentConfig := newAgentConfig(listeners, true)
|
agentConfig := newAgentConfig(listeners, true, true)
|
||||||
serverConfig := ServerConfig{AgentConfig: agentConfig}
|
serverConfig := ServerConfig{AgentConfig: agentConfig}
|
||||||
|
|
||||||
ctConfig, err := newRunnerConfig(&serverConfig, ctconfig.TemplateConfigs{})
|
ctConfig, err := newRunnerConfig(&serverConfig, ctconfig.TemplateConfigs{})
|
||||||
|
@ -168,7 +174,7 @@ func TestCacheConfigHTTPS(t *testing.T) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
agentConfig := newAgentConfig(listeners, true)
|
agentConfig := newAgentConfig(listeners, true, true)
|
||||||
serverConfig := ServerConfig{AgentConfig: agentConfig}
|
serverConfig := ServerConfig{AgentConfig: agentConfig}
|
||||||
|
|
||||||
ctConfig, err := newRunnerConfig(&serverConfig, ctconfig.TemplateConfigs{})
|
ctConfig, err := newRunnerConfig(&serverConfig, ctconfig.TemplateConfigs{})
|
||||||
|
@ -209,7 +215,44 @@ func TestCacheConfigNoCache(t *testing.T) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
agentConfig := newAgentConfig(listeners, false)
|
agentConfig := newAgentConfig(listeners, false, false)
|
||||||
|
serverConfig := ServerConfig{AgentConfig: agentConfig}
|
||||||
|
|
||||||
|
ctConfig, err := newRunnerConfig(&serverConfig, ctconfig.TemplateConfigs{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
expected := "http://127.0.0.1:1111"
|
||||||
|
if *ctConfig.Vault.Address != expected {
|
||||||
|
t.Fatalf("expected %s, got %s", expected, *ctConfig.Vault.Address)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCacheConfigNoPersistentCache(t *testing.T) {
|
||||||
|
listeners := []*configutil.Listener{
|
||||||
|
{
|
||||||
|
Type: "tcp",
|
||||||
|
Address: "127.0.0.1:8300",
|
||||||
|
TLSKeyFile: "/path/to/cakey.pem",
|
||||||
|
TLSCertFile: "/path/to/cacert.pem",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Type: "unix",
|
||||||
|
Address: "foobar",
|
||||||
|
TLSDisable: true,
|
||||||
|
SocketMode: "configmode",
|
||||||
|
SocketUser: "configuser",
|
||||||
|
SocketGroup: "configgroup",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Type: "tcp",
|
||||||
|
Address: "127.0.0.1:8400",
|
||||||
|
TLSDisable: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
agentConfig := newAgentConfig(listeners, true, false)
|
||||||
serverConfig := ServerConfig{AgentConfig: agentConfig}
|
serverConfig := ServerConfig{AgentConfig: agentConfig}
|
||||||
|
|
||||||
ctConfig, err := newRunnerConfig(&serverConfig, ctconfig.TemplateConfigs{})
|
ctConfig, err := newRunnerConfig(&serverConfig, ctconfig.TemplateConfigs{})
|
||||||
|
@ -226,7 +269,7 @@ func TestCacheConfigNoCache(t *testing.T) {
|
||||||
func TestCacheConfigNoListener(t *testing.T) {
|
func TestCacheConfigNoListener(t *testing.T) {
|
||||||
listeners := []*configutil.Listener{}
|
listeners := []*configutil.Listener{}
|
||||||
|
|
||||||
agentConfig := newAgentConfig(listeners, true)
|
agentConfig := newAgentConfig(listeners, true, true)
|
||||||
serverConfig := ServerConfig{AgentConfig: agentConfig}
|
serverConfig := ServerConfig{AgentConfig: agentConfig}
|
||||||
|
|
||||||
ctConfig, err := newRunnerConfig(&serverConfig, ctconfig.TemplateConfigs{})
|
ctConfig, err := newRunnerConfig(&serverConfig, ctconfig.TemplateConfigs{})
|
||||||
|
@ -264,7 +307,7 @@ func TestCacheConfigRejectMTLS(t *testing.T) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
agentConfig := newAgentConfig(listeners, true)
|
agentConfig := newAgentConfig(listeners, true, true)
|
||||||
serverConfig := ServerConfig{AgentConfig: agentConfig}
|
serverConfig := ServerConfig{AgentConfig: agentConfig}
|
||||||
|
|
||||||
_, err := newRunnerConfig(&serverConfig, ctconfig.TemplateConfigs{})
|
_, err := newRunnerConfig(&serverConfig, ctconfig.TemplateConfigs{})
|
||||||
|
|
Loading…
Reference in New Issue