diff --git a/changelog/10927.txt b/changelog/10927.txt index 80828de57..67ac0389b 100644 --- a/changelog/10927.txt +++ b/changelog/10927.txt @@ -1,3 +1,3 @@ ```release-note:improvement -agent: Route templating server through cache when enabled. +agent: Route templating server through cache when persistent cache is enabled. ``` diff --git a/changelog/10938.txt b/changelog/10938.txt new file mode 100644 index 000000000..841c37a9c --- /dev/null +++ b/changelog/10938.txt @@ -0,0 +1,3 @@ +```release-note:feature +agent: Support for persisting the agent cache to disk +``` diff --git a/command/agent.go b/command/agent.go index 8b3b53127..e1a714a38 100644 --- a/command/agent.go +++ b/command/agent.go @@ -6,10 +6,12 @@ import ( "flag" "fmt" "io" + "io/ioutil" "net" "net/http" "os" "path" + "path/filepath" "sort" "strings" "sync" @@ -30,6 +32,9 @@ import ( "github.com/hashicorp/vault/command/agent/auth/kerberos" "github.com/hashicorp/vault/command/agent/auth/kubernetes" "github.com/hashicorp/vault/command/agent/cache" + "github.com/hashicorp/vault/command/agent/cache/cacheboltdb" + "github.com/hashicorp/vault/command/agent/cache/cachememdb" + "github.com/hashicorp/vault/command/agent/cache/keymanager" agentConfig "github.com/hashicorp/vault/command/agent/config" "github.com/hashicorp/vault/command/agent/sink" "github.com/hashicorp/vault/command/agent/sink/file" @@ -461,6 +466,8 @@ func (c *AgentCommand) Run(args []string) int { c.UI.Output("==> Vault agent started! Log data will stream in below:\n") } + var leaseCache *cache.LeaseCache + var previousToken string // Parse agent listener configurations if config.Cache != nil && len(config.Listeners) != 0 { cacheLogger := c.logger.Named("cache") @@ -479,7 +486,7 @@ func (c *AgentCommand) Run(args []string) int { // Create the lease cache proxier and set its underlying proxier to // the API proxier. - leaseCache, err := cache.NewLeaseCache(&cache.LeaseCacheConfig{ + leaseCache, err = cache.NewLeaseCache(&cache.LeaseCacheConfig{ Client: client, BaseContext: ctx, Proxier: apiProxy, @@ -490,6 +497,152 @@ func (c *AgentCommand) Run(args []string) int { return 1 } + // Configure persistent storage and add to LeaseCache + if config.Cache.Persist != nil { + if config.Cache.Persist.Path == "" { + c.UI.Error("must specify persistent cache path") + return 1 + } + + // Set AAD based on key protection type + var aad string + switch config.Cache.Persist.Type { + case "kubernetes": + aad, err = getServiceAccountJWT(config.Cache.Persist.ServiceAccountTokenFile) + if err != nil { + c.UI.Error(fmt.Sprintf("failed to read service account token from %s: %s", config.Cache.Persist.ServiceAccountTokenFile, err)) + return 1 + } + default: + c.UI.Error(fmt.Sprintf("persistent key protection type %q not supported", config.Cache.Persist.Type)) + return 1 + } + + // Check if bolt file exists already + dbFileExists, err := cacheboltdb.DBFileExists(config.Cache.Persist.Path) + if err != nil { + c.UI.Error(fmt.Sprintf("failed to check if bolt file exists at path %s: %s", config.Cache.Persist.Path, err)) + return 1 + } + if dbFileExists { + // Open the bolt file, but wait to setup Encryption + ps, err := cacheboltdb.NewBoltStorage(&cacheboltdb.BoltStorageConfig{ + Path: config.Cache.Persist.Path, + Logger: cacheLogger.Named("cacheboltdb"), + }) + if err != nil { + c.UI.Error(fmt.Sprintf("Error opening persistent cache: %v", err)) + return 1 + } + + // Get the token from bolt for retrieving the encryption key, + // then setup encryption so that restore is possible + token, err := ps.GetRetrievalToken() + if err != nil { + c.UI.Error(fmt.Sprintf("Error getting retrieval token from persistent cache: %v", err)) + } + + if err := ps.Close(); err != nil { + c.UI.Warn(fmt.Sprintf("Failed to close persistent cache file after getting retrieval token: %s", err)) + } + + km, err := keymanager.NewPassthroughKeyManager(token) + if err != nil { + c.UI.Error(fmt.Sprintf("failed to configure persistence encryption for cache: %s", err)) + return 1 + } + + // Open the bolt file with the wrapper provided + ps, err = cacheboltdb.NewBoltStorage(&cacheboltdb.BoltStorageConfig{ + Path: config.Cache.Persist.Path, + Logger: cacheLogger.Named("cacheboltdb"), + Wrapper: km.Wrapper(), + AAD: aad, + }) + if err != nil { + c.UI.Error(fmt.Sprintf("Error opening persistent cache: %v", err)) + return 1 + } + + // Restore anything in the persistent cache to the memory cache + if err := leaseCache.Restore(ctx, ps); err != nil { + c.UI.Error(fmt.Sprintf("Error restoring in-memory cache from persisted file: %v", err)) + if config.Cache.Persist.ExitOnErr { + return 1 + } + } + cacheLogger.Info("loaded memcache from persistent storage") + + // Check for previous auto-auth token + oldTokenBytes, err := ps.GetAutoAuthToken(ctx) + if err != nil { + c.UI.Error(fmt.Sprintf("Error in fetching previous auto-auth token: %s", err)) + if config.Cache.Persist.ExitOnErr { + return 1 + } + } + if len(oldTokenBytes) > 0 { + oldToken, err := cachememdb.Deserialize(oldTokenBytes) + if err != nil { + c.UI.Error(fmt.Sprintf("Error in deserializing previous auto-auth token cache entry: %s", err)) + if config.Cache.Persist.ExitOnErr { + return 1 + } + } + previousToken = oldToken.Token + } + + // If keep_after_import true, set persistent storage layer in + // leaseCache, else remove db file + if config.Cache.Persist.KeepAfterImport { + defer ps.Close() + leaseCache.SetPersistentStorage(ps) + } else { + if err := ps.Close(); err != nil { + c.UI.Warn(fmt.Sprintf("failed to close persistent cache file: %s", err)) + } + dbFile := filepath.Join(config.Cache.Persist.Path, cacheboltdb.DatabaseFileName) + if err := os.Remove(dbFile); err != nil { + c.UI.Error(fmt.Sprintf("failed to remove persistent storage file %s: %s", dbFile, err)) + if config.Cache.Persist.ExitOnErr { + return 1 + } + } + } + } else { + km, err := keymanager.NewPassthroughKeyManager(nil) + if err != nil { + c.UI.Error(fmt.Sprintf("failed to configure persistence encryption for cache: %s", err)) + return 1 + } + ps, err := cacheboltdb.NewBoltStorage(&cacheboltdb.BoltStorageConfig{ + Path: config.Cache.Persist.Path, + Logger: cacheLogger.Named("cacheboltdb"), + Wrapper: km.Wrapper(), + AAD: aad, + }) + if err != nil { + c.UI.Error(fmt.Sprintf("Error creating persistent cache: %v", err)) + return 1 + } + cacheLogger.Info("configured persistent storage", "path", config.Cache.Persist.Path) + + // Stash the key material in bolt + token, err := km.RetrievalToken() + if err != nil { + c.UI.Error(fmt.Sprintf("Error getting persistent key: %s", err)) + return 1 + } + if err := ps.StoreRetrievalToken(token); err != nil { + c.UI.Error(fmt.Sprintf("Error setting key in persistent cache: %v", err)) + return 1 + } + + defer ps.Close() + leaseCache.SetPersistentStorage(ps) + } + } + var inmemSink sink.Sink if config.Cache.UseAutoAuthToken { cacheLogger.Debug("auto-auth token is allowed to be used; configuring inmem sink") @@ -585,6 +738,11 @@ func (c *AgentCommand) Run(args []string) int { select { case <-c.ShutdownCh: c.UI.Output("==> Vault agent shutdown triggered") + // Let the lease cache know this is a shutdown; no need to evict + // everything + if leaseCache != nil { + leaseCache.SetShuttingDown(true) + } return nil case <-ctx.Done(): return nil @@ -604,6 +762,7 @@ func (c *AgentCommand) Run(args []string) int { MaxBackoff: config.AutoAuth.Method.MaxBackoff, EnableReauthOnNewCredentials: config.AutoAuth.EnableReauthOnNewCredentials, EnableTemplateTokenCh: enableTokenCh, + Token: previousToken, }) ss := sink.NewSinkServer(&sink.SinkServerConfig{ @@ -624,6 +783,11 @@ func (c *AgentCommand) Run(args []string) int { g.Add(func() error { return ah.Run(ctx, method) }, func(error) { + // Let the lease cache know this is a shutdown; no need to evict + // everything + if leaseCache != nil { + leaseCache.SetShuttingDown(true) + } cancelFunc() }) @@ -650,12 +814,22 @@ func (c *AgentCommand) Run(args []string) int { return err }, func(error) { + // Let the lease cache know this is a shutdown; no need to evict + // everything + if leaseCache != nil { + leaseCache.SetShuttingDown(true) + } cancelFunc() }) g.Add(func() error { return ts.Run(ctx, ah.TemplateTokenCh, config.Templates) }, func(error) { + // Let the lease cache know this is a shutdown; no need to evict + // everything + if leaseCache != nil { + leaseCache.SetShuttingDown(true) + } cancelFunc() ts.Stop() }) @@ -793,3 +967,16 @@ func (c *AgentCommand) removePidFile(pidPath string) error { } return os.Remove(pidPath) } + +// GetServiceAccountJWT reads the service account jwt from `tokenFile`. Default is +// the default service account file path in kubernetes. +func getServiceAccountJWT(tokenFile string) (string, error) { + if len(tokenFile) == 0 { + tokenFile = "/var/run/secrets/kubernetes.io/serviceaccount/token" + } + token, err := ioutil.ReadFile(tokenFile) + if err != nil { + return "", err + } + return strings.TrimSpace(string(token)), nil +} diff --git a/command/agent/cache/cacheboltdb/bolt.go b/command/agent/cache/cacheboltdb/bolt.go new file mode 100644 index 000000000..afff5c647 --- /dev/null +++ b/command/agent/cache/cacheboltdb/bolt.go @@ -0,0 +1,303 @@ +package cacheboltdb + +import ( + "context" + "fmt" + "os" + "path/filepath" + "time" + + "github.com/golang/protobuf/proto" + "github.com/hashicorp/go-hclog" + wrapping "github.com/hashicorp/go-kms-wrapping" + "github.com/hashicorp/go-multierror" + bolt "go.etcd.io/bbolt" +) + +const ( + // Keep track of schema version for future migrations + storageVersionKey = "version" + storageVersion = "1" + + // DatabaseFileName - filename for the persistent cache file + DatabaseFileName = "vault-agent-cache.db" + + // metaBucketName - naming the meta bucket that holds the version and + // bootstrapping keys + metaBucketName = "meta" + + // SecretLeaseType - Bucket/type for leases with secret info + SecretLeaseType = "secret-lease" + + // AuthLeaseType - Bucket/type for leases with auth info + AuthLeaseType = "auth-lease" + + // TokenType - Bucket/type for auto-auth tokens + TokenType = "token" + + // AutoAuthToken - key for the latest auto-auth token + AutoAuthToken = "auto-auth-token" + + // RetrievalTokenMaterial is the actual key or token in the key bucket + RetrievalTokenMaterial = "retrieval-token-material" +) + +// BoltStorage is a persistent cache using a bolt db. Items are organized with +// the version and bootstrapping items in the "meta" bucket, and tokens, auth +// leases, and secret leases in their own buckets. +type BoltStorage struct { + db *bolt.DB + logger hclog.Logger + wrapper wrapping.Wrapper + aad string +} + +// BoltStorageConfig is the collection of input parameters for setting up bolt +// storage +type BoltStorageConfig struct { + Path string + Logger hclog.Logger + Wrapper wrapping.Wrapper + AAD string +} + +// NewBoltStorage opens a new bolt db at the specified file path and returns it. +// If the db already exists the buckets will just be created if they don't +// exist. +func NewBoltStorage(config *BoltStorageConfig) (*BoltStorage, error) { + dbPath := filepath.Join(config.Path, DatabaseFileName) + db, err := bolt.Open(dbPath, 0600, &bolt.Options{Timeout: 1 * time.Second}) + if err != nil { + return nil, err + } + err = db.Update(func(tx *bolt.Tx) error { + return createBoltSchema(tx) + }) + if err != nil { + return nil, err + } + bs := &BoltStorage{ + db: db, + logger: config.Logger, + wrapper: config.Wrapper, + aad: config.AAD, + } + return bs, nil +} + +func createBoltSchema(tx *bolt.Tx) error { + // create the meta bucket at the top level + meta, err := tx.CreateBucketIfNotExists([]byte(metaBucketName)) + if err != nil { + return fmt.Errorf("failed to create bucket %s: %w", metaBucketName, err) + } + // check and set file version in the meta bucket + version := meta.Get([]byte(storageVersionKey)) + switch { + case version == nil: + err = meta.Put([]byte(storageVersionKey), []byte(storageVersion)) + if err != nil { + return fmt.Errorf("failed to set storage version: %w", err) + } + case string(version) != storageVersion: + return fmt.Errorf("storage migration from %s to %s not implemented", string(version), storageVersion) + } + + // create the buckets for tokens and leases + _, err = tx.CreateBucketIfNotExists([]byte(TokenType)) + if err != nil { + return fmt.Errorf("failed to create token bucket: %w", err) + } + _, err = tx.CreateBucketIfNotExists([]byte(AuthLeaseType)) + if err != nil { + return fmt.Errorf("failed to create auth lease bucket: %w", err) + } + _, err = tx.CreateBucketIfNotExists([]byte(SecretLeaseType)) + if err != nil { + return fmt.Errorf("failed to create secret lease bucket: %w", err) + } + + return nil +} + +// Set an index (token or lease) in bolt storage +func (b *BoltStorage) Set(ctx context.Context, id string, plaintext []byte, indexType string) error { + blob, err := b.wrapper.Encrypt(ctx, plaintext, []byte(b.aad)) + if err != nil { + return fmt.Errorf("error encrypting %s index: %w", indexType, err) + } + + protoBlob, err := proto.Marshal(blob) + if err != nil { + return err + } + + return b.db.Update(func(tx *bolt.Tx) error { + s := tx.Bucket([]byte(indexType)) + if s == nil { + return fmt.Errorf("bucket %q not found", indexType) + } + // If this is an auto-auth token, also stash it in the meta bucket for + // easy retrieval upon restore + if indexType == TokenType { + meta := tx.Bucket([]byte(metaBucketName)) + if err := meta.Put([]byte(AutoAuthToken), protoBlob); err != nil { + return fmt.Errorf("failed to set latest auto-auth token: %w", err) + } + } + return s.Put([]byte(id), protoBlob) + }) +} + +func getBucketIDs(b *bolt.Bucket) ([][]byte, error) { + ids := [][]byte{} + err := b.ForEach(func(k, v []byte) error { + ids = append(ids, k) + return nil + }) + return ids, err +} + +// Delete an index (token or lease) by id from bolt storage +func (b *BoltStorage) Delete(id string) error { + return b.db.Update(func(tx *bolt.Tx) error { + // Since Delete returns a nil error if the key doesn't exist, just call + // delete in all three index buckets without checking existence first + if err := tx.Bucket([]byte(TokenType)).Delete([]byte(id)); err != nil { + return fmt.Errorf("failed to delete %q from token bucket: %w", id, err) + } + if err := tx.Bucket([]byte(AuthLeaseType)).Delete([]byte(id)); err != nil { + return fmt.Errorf("failed to delete %q from auth lease bucket: %w", id, err) + } + if err := tx.Bucket([]byte(SecretLeaseType)).Delete([]byte(id)); err != nil { + return fmt.Errorf("failed to delete %q from secret lease bucket: %w", id, err) + } + b.logger.Trace("deleted index from bolt db", "id", id) + return nil + }) +} + +func (b *BoltStorage) decrypt(ctx context.Context, ciphertext []byte) ([]byte, error) { + var blob wrapping.EncryptedBlobInfo + if err := proto.Unmarshal(ciphertext, &blob); err != nil { + return nil, err + } + + return b.wrapper.Decrypt(ctx, &blob, []byte(b.aad)) +} + +// GetByType returns a list of stored items of the specified type +func (b *BoltStorage) GetByType(ctx context.Context, indexType string) ([][]byte, error) { + var returnBytes [][]byte + + err := b.db.View(func(tx *bolt.Tx) error { + var errors *multierror.Error + + tx.Bucket([]byte(indexType)).ForEach(func(id, ciphertext []byte) error { + plaintext, err := b.decrypt(ctx, ciphertext) + if err != nil { + errors = multierror.Append(errors, fmt.Errorf("error decrypting index id %s: %w", id, err)) + return nil + } + + returnBytes = append(returnBytes, plaintext) + return nil + }) + return errors.ErrorOrNil() + }) + + return returnBytes, err +} + +// GetAutoAuthToken retrieves the latest auto-auth token, and returns nil if non +// exists yet +func (b *BoltStorage) GetAutoAuthToken(ctx context.Context) ([]byte, error) { + var encryptedToken []byte + + err := b.db.View(func(tx *bolt.Tx) error { + meta := tx.Bucket([]byte(metaBucketName)) + if meta == nil { + return fmt.Errorf("bucket %q not found", metaBucketName) + } + encryptedToken = meta.Get([]byte(AutoAuthToken)) + return nil + }) + if err != nil { + return nil, err + } + + if encryptedToken == nil { + return nil, nil + } + + plaintext, err := b.decrypt(ctx, encryptedToken) + if err != nil { + return nil, fmt.Errorf("failed to decrypt auto-auth token: %w", err) + } + return plaintext, nil +} + +// GetRetrievalToken retrieves a plaintext token from the KeyBucket, which will +// be used by the key manager to retrieve the encryption key, nil if none set +func (b *BoltStorage) GetRetrievalToken() ([]byte, error) { + var token []byte + + err := b.db.View(func(tx *bolt.Tx) error { + keyBucket := tx.Bucket([]byte(metaBucketName)) + if keyBucket == nil { + return fmt.Errorf("bucket %q not found", metaBucketName) + } + token = keyBucket.Get([]byte(RetrievalTokenMaterial)) + return nil + }) + if err != nil { + return nil, err + } + + return token, err +} + +// StoreRetrievalToken sets plaintext token material in the RetrievalTokenBucket +func (b *BoltStorage) StoreRetrievalToken(token []byte) error { + return b.db.Update(func(tx *bolt.Tx) error { + bucket := tx.Bucket([]byte(metaBucketName)) + if bucket == nil { + return fmt.Errorf("bucket %q not found", metaBucketName) + } + return bucket.Put([]byte(RetrievalTokenMaterial), token) + }) +} + +// Close the boltdb +func (b *BoltStorage) Close() error { + b.logger.Trace("closing bolt db", "path", b.db.Path()) + return b.db.Close() +} + +// Clear the boltdb by deleting all the token and lease buckets and recreating +// the schema/layout +func (b *BoltStorage) Clear() error { + return b.db.Update(func(tx *bolt.Tx) error { + for _, name := range []string{AuthLeaseType, SecretLeaseType, TokenType} { + b.logger.Trace("deleting bolt bucket", "name", name) + if err := tx.DeleteBucket([]byte(name)); err != nil { + return err + } + } + return createBoltSchema(tx) + }) +} + +// DBFileExists checks whether the vault agent cache file at `filePath` exists +func DBFileExists(path string) (bool, error) { + checkFile, err := os.OpenFile(filepath.Join(path, DatabaseFileName), os.O_RDWR, 0600) + defer checkFile.Close() + switch { + case err == nil: + return true, nil + case os.IsNotExist(err): + return false, nil + default: + return false, fmt.Errorf("failed to check if bolt file exists at path %s: %w", path, err) + } +} diff --git a/command/agent/cache/cacheboltdb/bolt_test.go b/command/agent/cache/cacheboltdb/bolt_test.go new file mode 100644 index 000000000..de457620b --- /dev/null +++ b/command/agent/cache/cacheboltdb/bolt_test.go @@ -0,0 +1,263 @@ +package cacheboltdb + +import ( + "context" + "io/ioutil" + "os" + "path" + "testing" + + "github.com/hashicorp/go-hclog" + "github.com/hashicorp/vault/command/agent/cache/keymanager" + "github.com/ory/dockertest/v3/docker/pkg/ioutils" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func getTestKeyManager(t *testing.T) keymanager.KeyManager { + t.Helper() + + km, err := keymanager.NewPassthroughKeyManager(nil) + require.NoError(t, err) + + return km +} + +func TestBolt_SetGet(t *testing.T) { + ctx := context.Background() + + path, err := ioutils.TempDir("", "bolt-test") + require.NoError(t, err) + defer os.RemoveAll(path) + + b, err := NewBoltStorage(&BoltStorageConfig{ + Path: path, + Logger: hclog.Default(), + Wrapper: getTestKeyManager(t).Wrapper(), + }) + require.NoError(t, err) + + secrets, err := b.GetByType(ctx, SecretLeaseType) + assert.NoError(t, err) + require.Len(t, secrets, 0) + + err = b.Set(ctx, "test1", []byte("hello"), SecretLeaseType) + assert.NoError(t, err) + secrets, err = b.GetByType(ctx, SecretLeaseType) + assert.NoError(t, err) + require.Len(t, secrets, 1) + assert.Equal(t, []byte("hello"), secrets[0]) +} + +func TestBoltDelete(t *testing.T) { + ctx := context.Background() + + path, err := ioutils.TempDir("", "bolt-test") + require.NoError(t, err) + defer os.RemoveAll(path) + + b, err := NewBoltStorage(&BoltStorageConfig{ + Path: path, + Logger: hclog.Default(), + Wrapper: getTestKeyManager(t).Wrapper(), + }) + require.NoError(t, err) + + err = b.Set(ctx, "secret-test1", []byte("hello1"), SecretLeaseType) + require.NoError(t, err) + err = b.Set(ctx, "secret-test2", []byte("hello2"), SecretLeaseType) + require.NoError(t, err) + + secrets, err := b.GetByType(ctx, SecretLeaseType) + require.NoError(t, err) + assert.Len(t, secrets, 2) + assert.ElementsMatch(t, [][]byte{[]byte("hello1"), []byte("hello2")}, secrets) + + err = b.Delete("secret-test1") + require.NoError(t, err) + secrets, err = b.GetByType(ctx, SecretLeaseType) + require.NoError(t, err) + require.Len(t, secrets, 1) + assert.Equal(t, []byte("hello2"), secrets[0]) +} + +func TestBoltClear(t *testing.T) { + ctx := context.Background() + + path, err := ioutils.TempDir("", "bolt-test") + require.NoError(t, err) + defer os.RemoveAll(path) + + b, err := NewBoltStorage(&BoltStorageConfig{ + Path: path, + Logger: hclog.Default(), + Wrapper: getTestKeyManager(t).Wrapper(), + }) + require.NoError(t, err) + + // Populate the bolt db + err = b.Set(ctx, "secret-test1", []byte("hello"), SecretLeaseType) + require.NoError(t, err) + secrets, err := b.GetByType(ctx, SecretLeaseType) + require.NoError(t, err) + require.Len(t, secrets, 1) + assert.Equal(t, []byte("hello"), secrets[0]) + + err = b.Set(ctx, "auth-test1", []byte("hello"), AuthLeaseType) + require.NoError(t, err) + auths, err := b.GetByType(ctx, AuthLeaseType) + require.NoError(t, err) + require.Len(t, auths, 1) + assert.Equal(t, []byte("hello"), auths[0]) + + err = b.Set(ctx, "token-test1", []byte("hello"), TokenType) + require.NoError(t, err) + tokens, err := b.GetByType(ctx, TokenType) + require.NoError(t, err) + require.Len(t, tokens, 1) + assert.Equal(t, []byte("hello"), tokens[0]) + + // Clear the bolt db, and check that it's indeed clear + err = b.Clear() + require.NoError(t, err) + secrets, err = b.GetByType(ctx, SecretLeaseType) + require.NoError(t, err) + assert.Len(t, secrets, 0) + auths, err = b.GetByType(ctx, AuthLeaseType) + require.NoError(t, err) + assert.Len(t, auths, 0) + tokens, err = b.GetByType(ctx, TokenType) + require.NoError(t, err) + assert.Len(t, tokens, 0) +} + +func TestBoltSetAutoAuthToken(t *testing.T) { + ctx := context.Background() + + path, err := ioutils.TempDir("", "bolt-test") + require.NoError(t, err) + defer os.RemoveAll(path) + + b, err := NewBoltStorage(&BoltStorageConfig{ + Path: path, + Logger: hclog.Default(), + Wrapper: getTestKeyManager(t).Wrapper(), + }) + require.NoError(t, err) + + token, err := b.GetAutoAuthToken(ctx) + assert.NoError(t, err) + assert.Nil(t, token) + + // set first token + err = b.Set(ctx, "token-test1", []byte("hello 1"), TokenType) + require.NoError(t, err) + secrets, err := b.GetByType(ctx, TokenType) + require.NoError(t, err) + require.Len(t, secrets, 1) + assert.Equal(t, []byte("hello 1"), secrets[0]) + token, err = b.GetAutoAuthToken(ctx) + assert.NoError(t, err) + assert.Equal(t, []byte("hello 1"), token) + + // set second token + err = b.Set(ctx, "token-test2", []byte("hello 2"), TokenType) + require.NoError(t, err) + secrets, err = b.GetByType(ctx, TokenType) + require.NoError(t, err) + require.Len(t, secrets, 2) + assert.ElementsMatch(t, [][]byte{[]byte("hello 1"), []byte("hello 2")}, secrets) + token, err = b.GetAutoAuthToken(ctx) + assert.NoError(t, err) + assert.Equal(t, []byte("hello 2"), token) +} + +func TestDBFileExists(t *testing.T) { + testCases := []struct { + name string + mkDir bool + createFile bool + expectExist bool + }{ + { + name: "all exists", + mkDir: true, + createFile: true, + expectExist: true, + }, + { + name: "dir exist, file missing", + mkDir: true, + createFile: false, + expectExist: false, + }, + { + name: "all missing", + mkDir: false, + createFile: false, + expectExist: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var tmpPath string + var err error + if tc.mkDir { + tmpPath, err = ioutil.TempDir("", "test-db-path") + require.NoError(t, err) + } + if tc.createFile { + err = ioutil.WriteFile(path.Join(tmpPath, DatabaseFileName), []byte("test-db-path"), 0600) + require.NoError(t, err) + } + exists, err := DBFileExists(tmpPath) + assert.NoError(t, err) + assert.Equal(t, tc.expectExist, exists) + }) + } + +} + +func Test_SetGetRetrievalToken(t *testing.T) { + testCases := []struct { + name string + tokenToSet []byte + expectedToken []byte + }{ + { + name: "normal set and get", + tokenToSet: []byte("test token"), + expectedToken: []byte("test token"), + }, + { + name: "no token set", + tokenToSet: nil, + expectedToken: nil, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + path, err := ioutils.TempDir("", "bolt-test") + require.NoError(t, err) + defer os.RemoveAll(path) + + b, err := NewBoltStorage(&BoltStorageConfig{ + Path: path, + Logger: hclog.Default(), + Wrapper: getTestKeyManager(t).Wrapper(), + }) + require.NoError(t, err) + defer b.Close() + + if tc.tokenToSet != nil { + err := b.StoreRetrievalToken(tc.tokenToSet) + require.NoError(t, err) + } + gotKey, err := b.GetRetrievalToken() + assert.NoError(t, err) + assert.Equal(t, tc.expectedToken, gotKey) + }) + } +} diff --git a/command/agent/cache/cachememdb/index.go b/command/agent/cache/cachememdb/index.go index ae822cf1d..546a528cb 100644 --- a/command/agent/cache/cachememdb/index.go +++ b/command/agent/cache/cachememdb/index.go @@ -1,6 +1,11 @@ package cachememdb -import "context" +import ( + "context" + "encoding/json" + "net/http" + "time" +) // Index holds the response to be cached along with multiple other values that // serve as pointers to refer back to this index. @@ -48,6 +53,21 @@ type Index struct { // goroutine that manages the renewal of the secret belonging to the // response in this index. RenewCtxInfo *ContextInfo + + // RequestMethod is the HTTP method of the request + RequestMethod string + + // RequestToken is the token used in the request + RequestToken string + + // RequestHeader is the header used in the request + RequestHeader http.Header + + // LastRenewed is the timestamp of last renewal + LastRenewed time.Time + + // Type is the index type (token, auth-lease, secret-lease) + Type string } type IndexName uint32 @@ -106,3 +126,25 @@ func NewContextInfo(ctx context.Context) *ContextInfo { ctxInfo.DoneCh = make(chan struct{}) return ctxInfo } + +// Serialize returns a json marshal'ed Index object, without the RenewCtxInfo +func (i Index) Serialize() ([]byte, error) { + i.RenewCtxInfo = nil + + indexBytes, err := json.Marshal(i) + if err != nil { + return nil, err + } + + return indexBytes, nil +} + +// Deserialize converts json bytes to an Index object +// Note: RenewCtxInfo will need to be reconstructed elsewhere. +func Deserialize(indexBytes []byte) (*Index, error) { + index := new(Index) + if err := json.Unmarshal(indexBytes, index); err != nil { + return nil, err + } + return index, nil +} diff --git a/command/agent/cache/cachememdb/index_test.go b/command/agent/cache/cachememdb/index_test.go new file mode 100644 index 000000000..f603399e0 --- /dev/null +++ b/command/agent/cache/cachememdb/index_test.go @@ -0,0 +1,43 @@ +package cachememdb + +import ( + "context" + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSerializeDeserialize(t *testing.T) { + + testIndex := &Index{ + ID: "testid", + Token: "testtoken", + TokenParent: "parent token", + TokenAccessor: "test accessor", + Namespace: "test namespace", + RequestPath: "/test/path", + Lease: "lease id", + LeaseToken: "lease token id", + Response: []byte(`{"something": "here"}`), + RenewCtxInfo: NewContextInfo(context.Background()), + RequestMethod: "GET", + RequestToken: "request token", + RequestHeader: http.Header{ + "X-Test": []string{"vault", "agent"}, + }, + LastRenewed: time.Now().UTC(), + } + indexBytes, err := testIndex.Serialize() + require.NoError(t, err) + assert.True(t, len(indexBytes) > 0) + assert.NotNil(t, testIndex.RenewCtxInfo, "Serialize should not modify original Index object") + + restoredIndex, err := Deserialize(indexBytes) + require.NoError(t, err) + + testIndex.RenewCtxInfo = nil + assert.Equal(t, testIndex, restoredIndex, "They should be equal without RenewCtxInfo set on the original") +} diff --git a/command/agent/cache/crypto/crypto.go b/command/agent/cache/crypto/crypto.go deleted file mode 100644 index 85e20c037..000000000 --- a/command/agent/cache/crypto/crypto.go +++ /dev/null @@ -1,18 +0,0 @@ -package crypto - -import ( - "context" -) - -const ( - KeyID = "root" -) - -type KeyManager interface { - GetKey() []byte - GetPersistentKey() ([]byte, error) - Renewable() bool - Renewer(context.Context) error - Encrypt(context.Context, []byte, []byte) ([]byte, error) - Decrypt(context.Context, []byte, []byte) ([]byte, error) -} diff --git a/command/agent/cache/crypto/k8s.go b/command/agent/cache/crypto/k8s.go deleted file mode 100644 index 36a6d86b8..000000000 --- a/command/agent/cache/crypto/k8s.go +++ /dev/null @@ -1,97 +0,0 @@ -package crypto - -import ( - "context" - "crypto/rand" - "fmt" - - wrapping "github.com/hashicorp/go-kms-wrapping" - "github.com/hashicorp/go-kms-wrapping/wrappers/aead" -) - -var _ KeyManager = (*KubeEncryptionKey)(nil) - -type KubeEncryptionKey struct { - renewable bool - wrapper *aead.Wrapper -} - -// NewK8s returns a new instance of the Kube encryption key. Kubernetes -// encryption keys aren't renewable. -func NewK8s(existingKey []byte) (*KubeEncryptionKey, error) { - k := &KubeEncryptionKey{ - renewable: false, - wrapper: aead.NewWrapper(nil), - } - - k.wrapper.SetConfig(map[string]string{"key_id": KeyID}) - - var rootKey []byte = nil - if len(existingKey) != 0 { - if len(existingKey) != 32 { - return k, fmt.Errorf("invalid key size, should be 32, got %d", len(existingKey)) - } - rootKey = existingKey - } - - if rootKey == nil { - newKey := make([]byte, 32) - _, err := rand.Read(newKey) - if err != nil { - return k, err - } - rootKey = newKey - } - - if err := k.wrapper.SetAESGCMKeyBytes(rootKey); err != nil { - return k, err - } - - return k, nil -} - -// GetKey returns the encryption key in a format optimized for storage. -// In k8s we store the key as is, so just return the key stored. -func (k *KubeEncryptionKey) GetKey() []byte { - return k.wrapper.GetKeyBytes() -} - -// GetPersistentKey returns the key which should be stored in the persisent -// cache. In k8s we store the key as is, so just return the key stored. -func (k *KubeEncryptionKey) GetPersistentKey() ([]byte, error) { - return k.wrapper.GetKeyBytes(), nil -} - -// Renewable lets the caller know if this encryption key type is -// renewable. In Kubernetes the key isn't renewable. -func (k *KubeEncryptionKey) Renewable() bool { - return k.renewable -} - -// Renewer is used when the encryption key type is renewable. Since Kubernetes -// keys aren't renewable, returning nothing. -func (k *KubeEncryptionKey) Renewer(ctx context.Context) error { - return nil -} - -// Encrypt takes plaintext values and encrypts them using the store key and additional -// data. For Kubernetes the AAD should be the service account JWT. -func (k *KubeEncryptionKey) Encrypt(ctx context.Context, plaintext, aad []byte) ([]byte, error) { - blob, err := k.wrapper.Encrypt(ctx, plaintext, aad) - if err != nil { - return nil, err - } - return blob.Ciphertext, nil -} - -// Decrypt takes ciphertext and AAD values and returns the decrypted value. For Kubernetes the AAD -// should be the service account JWT. -func (k *KubeEncryptionKey) Decrypt(ctx context.Context, ciphertext, aad []byte) ([]byte, error) { - blob := &wrapping.EncryptedBlobInfo{ - Ciphertext: ciphertext, - KeyInfo: &wrapping.KeyInfo{ - KeyID: KeyID, - }, - } - return k.wrapper.Decrypt(ctx, blob, aad) -} diff --git a/command/agent/cache/crypto/k8s_test.go b/command/agent/cache/crypto/k8s_test.go deleted file mode 100644 index 01b2f883f..000000000 --- a/command/agent/cache/crypto/k8s_test.go +++ /dev/null @@ -1,155 +0,0 @@ -package crypto - -import ( - "fmt" - "math/rand" - "testing" -) - -func TestCrypto_KubernetesNewKey(t *testing.T) { - k8sKey, err := NewK8s([]byte{}) - if err != nil { - t.Fatalf(fmt.Sprintf("unexpected error: %s", err)) - } - - key := k8sKey.GetKey() - if key == nil { - t.Fatalf(fmt.Sprintf("key is nil, it shouldn't be: %s", key)) - } - - persistentKey, _ := k8sKey.GetPersistentKey() - if persistentKey == nil { - t.Fatalf(fmt.Sprintf("key is nil, it shouldn't be: %s", persistentKey)) - } - - if string(key) != string(persistentKey) { - t.Fatalf("keys don't match, they should: key: %s, persistentKey: %s", key, persistentKey) - } - - plaintextInput := []byte("test") - aad := []byte("kubernetes") - - ciphertext, err := k8sKey.Encrypt(nil, plaintextInput, aad) - if err != nil { - t.Fatalf(err.Error()) - } - - if ciphertext == nil { - t.Fatalf("ciphertext nil, it shouldn't be") - } - - plaintext, err := k8sKey.Decrypt(nil, ciphertext, aad) - if err != nil { - t.Fatalf(err.Error()) - } - - if string(plaintext) != string(plaintextInput) { - t.Fatalf("expected %s, got %s", plaintextInput, plaintext) - } -} - -func TestCrypto_KubernetesExistingKey(t *testing.T) { - rootKey := make([]byte, 32) - n, err := rand.Read(rootKey) - if err != nil { - t.Fatal(err) - } - if n != 32 { - t.Fatal(n) - } - - k8sKey, err := NewK8s(rootKey) - if err != nil { - t.Fatalf(fmt.Sprintf("unexpected error: %s", err)) - } - - key := k8sKey.GetKey() - if key == nil { - t.Fatalf(fmt.Sprintf("key is nil, it shouldn't be: %s", key)) - } - - if string(key) != string(rootKey) { - t.Fatalf(fmt.Sprintf("expected keys to be the same, they weren't: expected: %s, got: %s", rootKey, key)) - } - - persistentKey, _ := k8sKey.GetPersistentKey() - if persistentKey == nil { - t.Fatalf("key is nil, it shouldn't be") - } - - if string(persistentKey) != string(rootKey) { - t.Fatalf(fmt.Sprintf("expected keys to be the same, they weren't: expected: %s, got: %s", rootKey, persistentKey)) - } - - if string(key) != string(persistentKey) { - t.Fatalf(fmt.Sprintf("expected keys to be the same, they weren't: %s %s", rootKey, persistentKey)) - } - - plaintextInput := []byte("test") - aad := []byte("kubernetes") - - ciphertext, err := k8sKey.Encrypt(nil, plaintextInput, aad) - if err != nil { - t.Fatalf(err.Error()) - } - - if ciphertext == nil { - t.Fatalf("ciphertext nil, it shouldn't be") - } - - plaintext, err := k8sKey.Decrypt(nil, ciphertext, aad) - if err != nil { - t.Fatalf(err.Error()) - } - - if string(plaintext) != string(plaintextInput) { - t.Fatalf("expected %s, got %s", plaintextInput, plaintext) - } -} - -func TestCrypto_KubernetesPassGeneratedKey(t *testing.T) { - k8sFirstKey, err := NewK8s([]byte{}) - if err != nil { - t.Fatalf(fmt.Sprintf("unexpected error: %s", err)) - } - - firstPersistentKey := k8sFirstKey.GetKey() - if firstPersistentKey == nil { - t.Fatalf(fmt.Sprintf("key is nil, it shouldn't be: %s", firstPersistentKey)) - } - - plaintextInput := []byte("test") - aad := []byte("kubernetes") - - ciphertext, err := k8sFirstKey.Encrypt(nil, plaintextInput, aad) - if err != nil { - t.Fatalf(err.Error()) - } - - if ciphertext == nil { - t.Fatalf("ciphertext nil, it shouldn't be") - } - - k8sLoadedKey, err := NewK8s(firstPersistentKey) - if err != nil { - t.Fatalf(fmt.Sprintf("unexpected error: %s", err)) - } - - loadedKey, _ := k8sLoadedKey.GetPersistentKey() - if loadedKey == nil { - t.Fatalf(fmt.Sprintf("key is nil, it shouldn't be: %s", loadedKey)) - } - - if string(loadedKey) != string(firstPersistentKey) { - t.Fatalf(fmt.Sprintf("keys do not match")) - } - - plaintext, err := k8sLoadedKey.Decrypt(nil, ciphertext, aad) - if err != nil { - t.Fatalf(err.Error()) - } - - if string(plaintext) != string(plaintextInput) { - t.Fatalf("expected %s, got %s", plaintextInput, plaintext) - } -} diff --git a/command/agent/cache/keymanager/manager.go b/command/agent/cache/keymanager/manager.go new file mode 100644 index 000000000..c69598623 --- /dev/null +++ b/command/agent/cache/keymanager/manager.go @@ -0,0 +1,16 @@ +package keymanager + +import wrapping "github.com/hashicorp/go-kms-wrapping" + +const ( + KeyID = "root" +) + +type KeyManager interface { + // Returns a wrapping.Wrapper which can be used to perform key-related operations. + Wrapper() wrapping.Wrapper + // RetrievalToken is the material returned which can be used to source back the + // encryption key. Depending on the implementation, the token can be the + // encryption key itself or a token/identifier used to exchange the token. + RetrievalToken() ([]byte, error) +} diff --git a/command/agent/cache/keymanager/passthrough.go b/command/agent/cache/keymanager/passthrough.go new file mode 100644 index 000000000..a4aff2eba --- /dev/null +++ b/command/agent/cache/keymanager/passthrough.go @@ -0,0 +1,68 @@ +package keymanager + +import ( + "crypto/rand" + "fmt" + + wrapping "github.com/hashicorp/go-kms-wrapping" + "github.com/hashicorp/go-kms-wrapping/wrappers/aead" +) + +var _ KeyManager = (*PassthroughKeyManager)(nil) + +type PassthroughKeyManager struct { + wrapper *aead.Wrapper +} + +// NewPassthroughKeyManager returns a new instance of the Kube encryption key. +// If a key is provided, it will be used as the encryption key for the wrapper, +// otherwise one will be generated. +func NewPassthroughKeyManager(key []byte) (*PassthroughKeyManager, error) { + + var rootKey []byte = nil + switch len(key) { + case 0: + newKey := make([]byte, 32) + _, err := rand.Read(newKey) + if err != nil { + return nil, err + } + rootKey = newKey + case 32: + rootKey = key + default: + return nil, fmt.Errorf("invalid key size, should be 32, got %d", len(key)) + } + + wrapper := aead.NewWrapper(nil) + + if _, err := wrapper.SetConfig(map[string]string{"key_id": KeyID}); err != nil { + return nil, err + } + + if err := wrapper.SetAESGCMKeyBytes(rootKey); err != nil { + return nil, err + } + + k := &PassthroughKeyManager{ + wrapper: wrapper, + } + + return k, nil +} + +// Wrapper returns the manager's wrapper for key operations. +func (w *PassthroughKeyManager) Wrapper() wrapping.Wrapper { + return w.wrapper +} + +// RetrievalToken returns the key that was used on the wrapper since this key +// manager is simply a passthrough and does not provide a mechanism to abstract +// this key. +func (w *PassthroughKeyManager) RetrievalToken() ([]byte, error) { + if w.wrapper == nil { + return nil, fmt.Errorf("unable to get wrapper for token retrieval") + } + + return w.wrapper.GetKeyBytes(), nil +} diff --git a/command/agent/cache/keymanager/passthrough_test.go b/command/agent/cache/keymanager/passthrough_test.go new file mode 100644 index 000000000..794f15bc2 --- /dev/null +++ b/command/agent/cache/keymanager/passthrough_test.go @@ -0,0 +1,56 @@ +package keymanager + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestKeyManager_PassthrougKeyManager(t *testing.T) { + tests := []struct { + name string + key []byte + wantErr bool + }{ + { + "new key", + nil, + false, + }, + { + "existing valid key", + []byte("e679e2f3d8d0e489d408bc617c6890d6"), + false, + }, + { + "invalid key length", + []byte("foobar"), + true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + m, err := NewPassthroughKeyManager(tc.key) + if tc.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + + if w := m.Wrapper(); w == nil { + t.Fatalf("expected non-nil wrapper from the key manager") + } + + token, err := m.RetrievalToken() + if err != nil { + t.Fatalf("unable to retrieve token: %s", err) + } + + if len(tc.key) != 0 && !bytes.Equal(tc.key, token) { + t.Fatalf("expected key bytes: %x, got: %x", tc.key, token) + } + }) + } +} diff --git a/command/agent/cache/lease_cache.go b/command/agent/cache/lease_cache.go index 7d3867520..da7aae6f3 100644 --- a/command/agent/cache/lease_cache.go +++ b/command/agent/cache/lease_cache.go @@ -11,6 +11,7 @@ import ( "io" "io/ioutil" "net/http" + "net/url" "strings" "sync" "time" @@ -18,6 +19,7 @@ import ( "github.com/hashicorp/errwrap" hclog "github.com/hashicorp/go-hclog" "github.com/hashicorp/vault/api" + "github.com/hashicorp/vault/command/agent/cache/cacheboltdb" cachememdb "github.com/hashicorp/vault/command/agent/cache/cachememdb" "github.com/hashicorp/vault/helper/namespace" nshelper "github.com/hashicorp/vault/helper/namespace" @@ -83,6 +85,13 @@ type LeaseCache struct { // inflightCache keeps track of inflight requests inflightCache *gocache.Cache + + // ps is the persistent storage for tokens and leases + ps *cacheboltdb.BoltStorage + + // shuttingDown is used to determine if cache needs to be evicted or not + // when the context is cancelled + shuttingDown atomic.Bool } // LeaseCacheConfig is the configuration for initializing a new @@ -92,6 +101,7 @@ type LeaseCacheConfig struct { BaseContext context.Context Proxier Proxier Logger hclog.Logger + Storage *cacheboltdb.BoltStorage } type inflightRequest struct { @@ -141,9 +151,21 @@ func NewLeaseCache(conf *LeaseCacheConfig) (*LeaseCache, error) { l: &sync.RWMutex{}, idLocks: locksutil.CreateLocks(), inflightCache: gocache.New(gocache.NoExpiration, gocache.NoExpiration), + ps: conf.Storage, }, nil } +// SetShuttingDown is a setter for the shuttingDown field +func (c *LeaseCache) SetShuttingDown(in bool) { + c.shuttingDown.Store(in) +} + +// SetPersistentStorage is a setter for the persistent storage field in +// LeaseCache +func (c *LeaseCache) SetPersistentStorage(storageIn *cacheboltdb.BoltStorage) { + c.ps = storageIn +} + // checkCacheForRequest checks the cache for a particular request based on its // computed ID. It returns a non-nil *SendResponse if an entry is found. func (c *LeaseCache) checkCacheForRequest(id string) (*SendResponse, error) { @@ -275,6 +297,7 @@ func (c *LeaseCache) Send(ctx context.Context, req *SendRequest) (*SendResponse, ID: id, Namespace: namespace, RequestPath: req.Request.URL.Path, + LastRenewed: time.Now().UTC(), } secret, err := api.ParseSecret(bytes.NewReader(resp.ResponseBody)) @@ -332,6 +355,8 @@ func (c *LeaseCache) Send(ctx context.Context, req *SendRequest) (*SendResponse, index.Lease = secret.LeaseID index.LeaseToken = req.Token + index.Type = cacheboltdb.SecretLeaseType + case secret.Auth != nil: c.logger.Debug("processing auth response", "method", req.Request.Method, "path", req.Request.URL.Path) @@ -360,6 +385,8 @@ func (c *LeaseCache) Send(ctx context.Context, req *SendRequest) (*SendResponse, index.Token = secret.Auth.ClientToken index.TokenAccessor = secret.Auth.Accessor + index.Type = cacheboltdb.AuthLeaseType + default: // We shouldn't be hitting this, but will err on the side of caution and // simply proxy. @@ -394,9 +421,14 @@ func (c *LeaseCache) Send(ctx context.Context, req *SendRequest) (*SendResponse, DoneCh: renewCtxInfo.DoneCh, } + // Add extra information necessary for restoring from persisted cache + index.RequestMethod = req.Request.Method + index.RequestToken = req.Token + index.RequestHeader = req.Request.Header + // Store the index in the cache c.logger.Debug("storing response into the cache", "method", req.Request.Method, "path", req.Request.URL.Path) - err = c.db.Set(index) + err = c.Set(ctx, index) if err != nil { c.logger.Error("failed to cache the proxied response", "error", err) return nil, err @@ -420,8 +452,12 @@ func (c *LeaseCache) createCtxInfo(ctx context.Context) *cachememdb.ContextInfo func (c *LeaseCache) startRenewing(ctx context.Context, index *cachememdb.Index, req *SendRequest, secret *api.Secret) { defer func() { id := ctx.Value(contextIndexID).(string) + if c.shuttingDown.Load() { + c.logger.Trace("not evicting index from cache during shutdown", "id", id, "method", req.Request.Method, "path", req.Request.URL.Path) + return + } c.logger.Debug("evicting index from cache", "id", id, "method", req.Request.Method, "path", req.Request.URL.Path) - err := c.db.Evict(cachememdb.IndexNameID, id) + err := c.Evict(id) if err != nil { c.logger.Error("failed to evict index", "id", id, "error", err) return @@ -466,6 +502,11 @@ func (c *LeaseCache) startRenewing(ctx context.Context, index *cachememdb.Index, return case <-watcher.RenewCh(): c.logger.Debug("secret renewed", "path", req.Request.URL.Path) + if c.ps != nil { + if err := c.updateLastRenewed(ctx, index, time.Now().UTC()); err != nil { + c.logger.Warn("not able to update lastRenewed time for cached index", "id", index.ID) + } + } case <-index.RenewCtxInfo.DoneCh: // This case indicates the renewal process to shutdown and evict // the cache entry. This is triggered when a specific secret @@ -477,6 +518,22 @@ func (c *LeaseCache) startRenewing(ctx context.Context, index *cachememdb.Index, } } +func (c *LeaseCache) updateLastRenewed(ctx context.Context, index *cachememdb.Index, t time.Time) error { + idLock := locksutil.LockForKey(c.idLocks, index.ID) + idLock.Lock() + defer idLock.Unlock() + + getIndex, err := c.db.Get(cachememdb.IndexNameID, index.ID) + if err != nil { + return err + } + index.LastRenewed = t + if err := c.Set(ctx, getIndex); err != nil { + return err + } + return nil +} + // computeIndexID results in a value that uniquely identifies a request // received by the agent. It does so by SHA256 hashing the serialized request // object containing the request path, query parameters and body parameters. @@ -642,8 +699,8 @@ func (c *LeaseCache) handleCacheClear(ctx context.Context, in *cacheClearInput) } c.l.Unlock() - // Reset the memdb instance - if err := c.db.Flush(); err != nil { + // Reset the memdb instance (and persistent storage if enabled) + if err := c.Flush(); err != nil { return err } @@ -850,6 +907,213 @@ func (c *LeaseCache) handleRevocationRequest(ctx context.Context, req *SendReque return true, nil } +// Set stores the index in the cachememdb, and also stores it in the persistent +// cache (if enabled) +func (c *LeaseCache) Set(ctx context.Context, index *cachememdb.Index) error { + if err := c.db.Set(index); err != nil { + return err + } + + if c.ps != nil { + b, err := index.Serialize() + if err != nil { + return err + } + + if err := c.ps.Set(ctx, index.ID, b, index.Type); err != nil { + return err + } + c.logger.Trace("set entry in persistent storage", "type", index.Type, "path", index.RequestPath, "id", index.ID) + } + + return nil +} + +// Evict removes an Index from the cachememdb, and also removes it from the +// persistent cache (if enabled) +func (c *LeaseCache) Evict(id string) error { + if err := c.db.Evict(cachememdb.IndexNameID, id); err != nil { + return err + } + + if c.ps != nil { + if err := c.ps.Delete(id); err != nil { + return err + } + c.logger.Trace("deleted item from persistent storage", "id", id) + } + + return nil +} + +// Flush the cachememdb and persistent cache (if enabled) +func (c *LeaseCache) Flush() error { + if err := c.db.Flush(); err != nil { + return err + } + + if c.ps != nil { + c.logger.Trace("clearing persistent storage") + return c.ps.Clear() + } + + return nil +} + +// Restore loads the cachememdb from the persistent storage passed in. Loads +// tokens first, since restoring a lease's renewal context and watcher requires +// looking up the token in the cachememdb. +func (c *LeaseCache) Restore(ctx context.Context, storage *cacheboltdb.BoltStorage) error { + // Process tokens first + tokens, err := storage.GetByType(ctx, cacheboltdb.TokenType) + if err != nil { + return err + } + if err := c.restoreTokens(tokens); err != nil { + return err + } + + // Then process auth leases + authLeases, err := storage.GetByType(ctx, cacheboltdb.AuthLeaseType) + if err != nil { + return err + } + if err := c.restoreLeases(authLeases); err != nil { + return err + } + + // Then process secret leases + secretLeases, err := storage.GetByType(ctx, cacheboltdb.SecretLeaseType) + if err != nil { + return err + } + if err := c.restoreLeases(secretLeases); err != nil { + return err + } + + return nil +} + +func (c *LeaseCache) restoreTokens(tokens [][]byte) error { + for _, token := range tokens { + newIndex, err := cachememdb.Deserialize(token) + if err != nil { + return err + } + newIndex.RenewCtxInfo = c.createCtxInfo(nil) + if err := c.db.Set(newIndex); err != nil { + return err + } + c.logger.Trace("restored token", "id", newIndex.ID) + } + return nil +} + +func (c *LeaseCache) restoreLeases(leases [][]byte) error { + for _, lease := range leases { + newIndex, err := cachememdb.Deserialize(lease) + if err != nil { + return err + } + + // Check if this lease has already expired + expired, err := c.hasExpired(time.Now().UTC(), newIndex) + if err != nil { + c.logger.Warn("failed to check if lease is expired", "id", newIndex.ID, "error", err) + } + if expired { + continue + } + + if err := c.restoreLeaseRenewCtx(newIndex); err != nil { + return err + } + if err := c.db.Set(newIndex); err != nil { + return err + } + c.logger.Trace("restored lease", "id", newIndex.ID, "path", newIndex.RequestPath) + } + return nil +} + +// restoreLeaseRenewCtx re-creates a RenewCtx for an index object and starts +// the watcher go routine +func (c *LeaseCache) restoreLeaseRenewCtx(index *cachememdb.Index) error { + if index.Response == nil { + return fmt.Errorf("cached response was nil for %s", index.ID) + } + + // Parse the secret to determine which type it is + reader := bufio.NewReader(bytes.NewReader(index.Response)) + resp, err := http.ReadResponse(reader, nil) + if err != nil { + c.logger.Error("failed to deserialize response", "error", err) + return err + } + secret, err := api.ParseSecret(resp.Body) + if err != nil { + c.logger.Error("failed to parse response as secret", "error", err) + return err + } + + var renewCtxInfo *cachememdb.ContextInfo + switch { + case secret.LeaseID != "": + entry, err := c.db.Get(cachememdb.IndexNameToken, index.RequestToken) + if err != nil { + return err + } + + if entry == nil { + return fmt.Errorf("could not find parent Token %s for req path %s", index.RequestToken, index.RequestPath) + } + + // Derive a context for renewal using the token's context + renewCtxInfo = cachememdb.NewContextInfo(entry.RenewCtxInfo.Ctx) + + case secret.Auth != nil: + var parentCtx context.Context + if !secret.Auth.Orphan { + entry, err := c.db.Get(cachememdb.IndexNameToken, index.RequestToken) + if err != nil { + return err + } + // If parent token is not managed by the agent, child shouldn't be + // either. + if entry == nil { + return fmt.Errorf("could not find parent Token %s for req path %s", index.RequestToken, index.RequestPath) + } + + c.logger.Debug("setting parent context", "method", index.RequestMethod, "path", index.RequestPath) + parentCtx = entry.RenewCtxInfo.Ctx + } + renewCtxInfo = c.createCtxInfo(parentCtx) + default: + return fmt.Errorf("unknown cached index item: %s", index.ID) + } + + renewCtx := context.WithValue(renewCtxInfo.Ctx, contextIndexID, index.ID) + index.RenewCtxInfo = &cachememdb.ContextInfo{ + Ctx: renewCtx, + CancelFunc: renewCtxInfo.CancelFunc, + DoneCh: renewCtxInfo.DoneCh, + } + + sendReq := &SendRequest{ + Token: index.RequestToken, + Request: &http.Request{ + Header: index.RequestHeader, + Method: index.RequestMethod, + URL: &url.URL{ + Path: index.RequestPath, + }, + }, + } + go c.startRenewing(renewCtx, index, sendReq, secret) + + return nil +} + // deriveNamespaceAndRevocationPath returns the namespace and relative path for // revocation paths. // @@ -912,9 +1176,11 @@ func (c *LeaseCache) RegisterAutoAuthToken(token string) error { return err } - // If the index is found, defer its cancelFunc + // If the index is found, just keep it in the cache and ignore the incoming + // token (since they're the same) if oldIndex != nil { - defer oldIndex.RenewCtxInfo.CancelFunc() + c.logger.Trace("auto-auth token already exists in cache; no need to store it again") + return nil } // The following randomly generated values are required for index stored by @@ -938,6 +1204,7 @@ func (c *LeaseCache) RegisterAutoAuthToken(token string) error { Token: token, Namespace: namespace, RequestPath: requestPath, + Type: cacheboltdb.TokenType, } // Derive a context off of the lease cache's base context @@ -951,7 +1218,7 @@ func (c *LeaseCache) RegisterAutoAuthToken(token string) error { // Store the index in the cache c.logger.Debug("storing auto-auth token into the cache") - err = c.db.Set(index) + err = c.Set(c.baseCtxInfo.Ctx, index) if err != nil { c.logger.Error("failed to cache the auto-auth token", "error", err) return err @@ -997,3 +1264,32 @@ func parseCacheClearInput(req *cacheClearRequest) (*cacheClearInput, error) { return in, nil } + +func (c *LeaseCache) hasExpired(currentTime time.Time, index *cachememdb.Index) (bool, error) { + reader := bufio.NewReader(bytes.NewReader(index.Response)) + resp, err := http.ReadResponse(reader, nil) + if err != nil { + return false, fmt.Errorf("failed to deserialize response: %w", err) + } + secret, err := api.ParseSecret(resp.Body) + if err != nil { + return false, fmt.Errorf("failed to parse response as secret: %w", err) + } + + elapsed := currentTime.Sub(index.LastRenewed) + var leaseDuration int + switch index.Type { + case cacheboltdb.AuthLeaseType: + leaseDuration = secret.Auth.LeaseDuration + case cacheboltdb.SecretLeaseType: + leaseDuration = secret.LeaseDuration + default: + return false, fmt.Errorf("index type %q unexpected in expiration check", index.Type) + } + + if int(elapsed.Seconds()) > leaseDuration { + c.logger.Trace("secret has expired", "id", index.ID, "elapsed", elapsed, "lease duration", leaseDuration) + return true, nil + } + return false, nil +} diff --git a/command/agent/cache/lease_cache_test.go b/command/agent/cache/lease_cache_test.go index ba6dafa8b..5cf999d9b 100644 --- a/command/agent/cache/lease_cache_test.go +++ b/command/agent/cache/lease_cache_test.go @@ -3,9 +3,11 @@ package cache import ( "context" "fmt" + "io/ioutil" "net/http" "net/http/httptest" "net/url" + "os" "reflect" "strings" "sync" @@ -15,9 +17,13 @@ import ( "github.com/go-test/deep" hclog "github.com/hashicorp/go-hclog" "github.com/hashicorp/vault/api" + "github.com/hashicorp/vault/command/agent/cache/cacheboltdb" "github.com/hashicorp/vault/command/agent/cache/cachememdb" + "github.com/hashicorp/vault/command/agent/cache/keymanager" "github.com/hashicorp/vault/sdk/helper/consts" "github.com/hashicorp/vault/sdk/helper/logging" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "go.uber.org/atomic" ) @@ -63,6 +69,24 @@ func testNewLeaseCacheWithDelay(t *testing.T, cacheable bool, delay int) *LeaseC return lc } +func testNewLeaseCacheWithPersistence(t *testing.T, responses []*SendResponse, storage *cacheboltdb.BoltStorage) *LeaseCache { + t.Helper() + + client, err := api.NewClient(api.DefaultConfig()) + require.NoError(t, err) + + lc, err := NewLeaseCache(&LeaseCacheConfig{ + Client: client, + BaseContext: context.Background(), + Proxier: newMockProxier(responses), + Logger: logging.NewVaultLogger(hclog.Trace).Named("cache.leasecache"), + Storage: storage, + }) + require.NoError(t, err) + + return lc +} + func TestCache_ComputeIndexID(t *testing.T) { type args struct { req *http.Request @@ -649,3 +673,394 @@ func TestLeaseCache_Concurrent_Cacheable(t *testing.T) { t.Fatalf("Should have returned a cached response 99 times, got %d", cacheCount.Load()) } } + +func setupBoltStorage(t *testing.T) (tempCacheDir string, boltStorage *cacheboltdb.BoltStorage) { + t.Helper() + + km, err := keymanager.NewPassthroughKeyManager(nil) + require.NoError(t, err) + + tempCacheDir, err = ioutil.TempDir("", "agent-cache-test") + require.NoError(t, err) + boltStorage, err = cacheboltdb.NewBoltStorage(&cacheboltdb.BoltStorageConfig{ + Path: tempCacheDir, + Logger: hclog.Default(), + Wrapper: km.Wrapper(), + }) + require.NoError(t, err) + require.NotNil(t, boltStorage) + // The calling function should `defer boltStorage.Close()` and `defer os.RemoveAll(tempCacheDir)` + return tempCacheDir, boltStorage +} + +func TestLeaseCache_PersistAndRestore(t *testing.T) { + // Emulate 4 responses from the api proxy. The first two use the auto-auth + // token, and the last two use another token. + responses := []*SendResponse{ + newTestSendResponse(200, `{"auth": {"client_token": "testtoken", "renewable": true, "lease_duration": 600}}`), + newTestSendResponse(201, `{"lease_id": "foo", "renewable": true, "data": {"value": "foo"}, "lease_duration": 600}`), + newTestSendResponse(202, `{"auth": {"client_token": "testtoken2", "renewable": true, "orphan": true, "lease_duration": 600}}`), + newTestSendResponse(203, `{"lease_id": "secret2-lease", "renewable": true, "data": {"number": "two"}, "lease_duration": 600}`), + } + + tempDir, boltStorage := setupBoltStorage(t) + defer os.RemoveAll(tempDir) + defer boltStorage.Close() + lc := testNewLeaseCacheWithPersistence(t, responses, boltStorage) + + // Register an auto-auth token so that the token and lease requests are cached + lc.RegisterAutoAuthToken("autoauthtoken") + + cacheTests := []struct { + token string + method string + urlPath string + body string + wantStatusCode int + }{ + { + // Make a request. A response with a new token is returned to the + // lease cache and that will be cached. + token: "autoauthtoken", + method: "GET", + urlPath: "http://example.com/v1/sample/api", + body: `{"value": "input"}`, + wantStatusCode: responses[0].Response.StatusCode, + }, + { + // Modify the request a little bit to ensure the second response is + // returned to the lease cache. + token: "autoauthtoken", + method: "GET", + urlPath: "http://example.com/v1/sample/api", + body: `{"value": "input_changed"}`, + wantStatusCode: responses[1].Response.StatusCode, + }, + { + // Simulate an approle login to get another token + method: "PUT", + urlPath: "http://example.com/v1/auth/approle/login", + body: `{"role_id": "my role", "secret_id": "my secret"}`, + wantStatusCode: responses[2].Response.StatusCode, + }, + { + // Test caching with the token acquired from the approle login + token: "testtoken2", + method: "GET", + urlPath: "http://example.com/v1/sample2/api", + body: `{"second": "input"}`, + wantStatusCode: responses[3].Response.StatusCode, + }, + } + + for _, ct := range cacheTests { + // Send once to cache + sendReq := &SendRequest{ + Token: ct.token, + Request: httptest.NewRequest(ct.method, ct.urlPath, strings.NewReader(ct.body)), + } + resp, err := lc.Send(context.Background(), sendReq) + require.NoError(t, err) + assert.Equal(t, resp.Response.StatusCode, ct.wantStatusCode, "expected proxied response") + assert.Nil(t, resp.CacheMeta) + + // Send again to test cache. If this isn't cached, the response returned + // will be the next in the list and the status code will not match. + sendCacheReq := &SendRequest{ + Token: ct.token, + Request: httptest.NewRequest(ct.method, ct.urlPath, strings.NewReader(ct.body)), + } + respCached, err := lc.Send(context.Background(), sendCacheReq) + require.NoError(t, err, "failed to send request %+v", ct) + assert.Equal(t, respCached.Response.StatusCode, ct.wantStatusCode, "expected proxied response") + require.NotNil(t, respCached.CacheMeta) + assert.True(t, respCached.CacheMeta.Hit) + } + + // Now we know the cache is working, so try restoring from the persisted + // cache's storage + restoredCache := testNewLeaseCache(t, nil) + + err := restoredCache.Restore(context.Background(), boltStorage) + assert.NoError(t, err) + + // Now compare before and after + beforeDB, err := lc.db.GetByPrefix(cachememdb.IndexNameID) + require.NoError(t, err) + assert.Len(t, beforeDB, 5) + + for _, cachedItem := range beforeDB { + restoredItem, err := restoredCache.db.Get(cachememdb.IndexNameID, cachedItem.ID) + require.NoError(t, err) + + assert.NoError(t, err) + assert.Equal(t, cachedItem.ID, restoredItem.ID) + assert.Equal(t, cachedItem.Lease, restoredItem.Lease) + assert.Equal(t, cachedItem.LeaseToken, restoredItem.LeaseToken) + assert.Equal(t, cachedItem.Namespace, restoredItem.Namespace) + assert.Equal(t, cachedItem.RequestHeader, restoredItem.RequestHeader) + assert.Equal(t, cachedItem.RequestMethod, restoredItem.RequestMethod) + assert.Equal(t, cachedItem.RequestPath, restoredItem.RequestPath) + assert.Equal(t, cachedItem.RequestToken, restoredItem.RequestToken) + assert.Equal(t, cachedItem.Response, restoredItem.Response) + assert.Equal(t, cachedItem.Token, restoredItem.Token) + assert.Equal(t, cachedItem.TokenAccessor, restoredItem.TokenAccessor) + assert.Equal(t, cachedItem.TokenParent, restoredItem.TokenParent) + + // check what we can in the renewal context + assert.NotEmpty(t, restoredItem.RenewCtxInfo.CancelFunc) + assert.NotZero(t, restoredItem.RenewCtxInfo.DoneCh) + require.NotEmpty(t, restoredItem.RenewCtxInfo.Ctx) + assert.Equal(t, + cachedItem.RenewCtxInfo.Ctx.Value(contextIndexID), + restoredItem.RenewCtxInfo.Ctx.Value(contextIndexID), + ) + } + afterDB, err := restoredCache.db.GetByPrefix(cachememdb.IndexNameID) + require.NoError(t, err) + assert.Len(t, afterDB, 5) + + // And finally send the cache requests once to make sure they're all being + // served from the restoredCache + for _, ct := range cacheTests { + sendCacheReq := &SendRequest{ + Token: ct.token, + Request: httptest.NewRequest(ct.method, ct.urlPath, strings.NewReader(ct.body)), + } + respCached, err := restoredCache.Send(context.Background(), sendCacheReq) + require.NoError(t, err, "failed to send request %+v", ct) + assert.Equal(t, respCached.Response.StatusCode, ct.wantStatusCode, "expected proxied response") + require.NotNil(t, respCached.CacheMeta) + assert.True(t, respCached.CacheMeta.Hit) + } +} + +func TestEvictPersistent(t *testing.T) { + ctx := context.Background() + + responses := []*SendResponse{ + newTestSendResponse(201, `{"lease_id": "foo", "renewable": true, "data": {"value": "foo"}}`), + } + + tempDir, boltStorage := setupBoltStorage(t) + defer os.RemoveAll(tempDir) + defer boltStorage.Close() + lc := testNewLeaseCacheWithPersistence(t, responses, boltStorage) + + lc.RegisterAutoAuthToken("autoauthtoken") + + // populate cache by sending request through + sendReq := &SendRequest{ + Token: "autoauthtoken", + Request: httptest.NewRequest("GET", "http://example.com/v1/sample/api", strings.NewReader(`{"value": "some_input"}`)), + } + resp, err := lc.Send(context.Background(), sendReq) + require.NoError(t, err) + assert.Equal(t, resp.Response.StatusCode, 201, "expected proxied response") + assert.Nil(t, resp.CacheMeta) + + // Check bolt for the cached lease + secrets, err := lc.ps.GetByType(ctx, cacheboltdb.SecretLeaseType) + require.NoError(t, err) + assert.Len(t, secrets, 1) + + // Call clear for the request path + err = lc.handleCacheClear(context.Background(), &cacheClearInput{ + Type: "request_path", + RequestPath: "/v1/sample/api", + }) + require.NoError(t, err) + + time.Sleep(2 * time.Second) + + // Check that cached item is gone + secrets, err = lc.ps.GetByType(ctx, cacheboltdb.SecretLeaseType) + require.NoError(t, err) + assert.Len(t, secrets, 0) +} + +func TestRegisterAutoAuth_sameToken(t *testing.T) { + // If the auto-auth token already exists in the cache, it should not be + // stored again in a new index. + lc := testNewLeaseCache(t, nil) + err := lc.RegisterAutoAuthToken("autoauthtoken") + assert.NoError(t, err) + + oldTokenIndex, err := lc.db.Get(cachememdb.IndexNameToken, "autoauthtoken") + assert.NoError(t, err) + oldTokenID := oldTokenIndex.ID + + // register the same token again + err = lc.RegisterAutoAuthToken("autoauthtoken") + assert.NoError(t, err) + + // check that there's only one index for autoauthtoken + entries, err := lc.db.GetByPrefix(cachememdb.IndexNameToken, "autoauthtoken") + assert.NoError(t, err) + assert.Len(t, entries, 1) + + newTokenIndex, err := lc.db.Get(cachememdb.IndexNameToken, "autoauthtoken") + assert.NoError(t, err) + + // compare the ID's since those are randomly generated when an index for a + // token is added to the cache, so if a new token was added, the id's will + // not match. + assert.Equal(t, oldTokenID, newTokenIndex.ID) +} + +func Test_hasExpired(t *testing.T) { + + responses := []*SendResponse{ + newTestSendResponse(200, `{"auth": {"client_token": "testtoken", "renewable": true, "lease_duration": 60}}`), + newTestSendResponse(201, `{"lease_id": "foo", "renewable": true, "data": {"value": "foo"}, "lease_duration": 60}`), + } + lc := testNewLeaseCache(t, responses) + lc.RegisterAutoAuthToken("autoauthtoken") + + cacheTests := []struct { + token string + urlPath string + leaseType string + wantStatusCode int + }{ + { + // auth lease + token: "autoauthtoken", + urlPath: "/v1/sample/auth", + leaseType: cacheboltdb.AuthLeaseType, + wantStatusCode: responses[0].Response.StatusCode, + }, + { + // secret lease + token: "autoauthtoken", + urlPath: "/v1/sample/secret", + leaseType: cacheboltdb.SecretLeaseType, + wantStatusCode: responses[1].Response.StatusCode, + }, + } + + for _, ct := range cacheTests { + // Send once to cache + urlPath := "http://example.com" + ct.urlPath + sendReq := &SendRequest{ + Token: ct.token, + Request: httptest.NewRequest("GET", urlPath, strings.NewReader(`{"value": "input"}`)), + } + resp, err := lc.Send(context.Background(), sendReq) + require.NoError(t, err) + assert.Equal(t, resp.Response.StatusCode, ct.wantStatusCode, "expected proxied response") + assert.Nil(t, resp.CacheMeta) + + // get the Index out of the mem cache + index, err := lc.db.Get(cachememdb.IndexNameRequestPath, "root/", ct.urlPath) + require.NoError(t, err) + assert.Equal(t, ct.leaseType, index.Type) + + // The lease duration is 60 seconds, so time.Now() should be within that + notExpired, err := lc.hasExpired(time.Now().UTC(), index) + require.NoError(t, err) + assert.False(t, notExpired) + + // In 90 seconds the index should be "expired" + futureTime := time.Now().UTC().Add(time.Second * 90) + expired, err := lc.hasExpired(futureTime, index) + require.NoError(t, err) + assert.True(t, expired) + } + +} + +func TestLeaseCache_hasExpired_wrong_type(t *testing.T) { + index := &cachememdb.Index{ + Type: cacheboltdb.TokenType, + Response: []byte(`HTTP/0.0 200 OK +Content-Type: application/json +Date: Tue, 02 Mar 2021 17:54:16 GMT + +{"auth": {"client_token": "testtoken", "renewable": true, "lease_duration": 60}}`), + } + + lc := testNewLeaseCache(t, nil) + expired, err := lc.hasExpired(time.Now().UTC(), index) + assert.False(t, expired) + assert.EqualError(t, err, `index type "token" unexpected in expiration check`) +} + +func TestLeaseCacheRestore_expired(t *testing.T) { + // Emulate 2 responses from the api proxy, both expired + responses := []*SendResponse{ + newTestSendResponse(200, `{"auth": {"client_token": "testtoken", "renewable": true, "lease_duration": -600}}`), + newTestSendResponse(201, `{"lease_id": "foo", "renewable": true, "data": {"value": "foo"}, "lease_duration": -600}`), + } + + tempDir, boltStorage := setupBoltStorage(t) + defer os.RemoveAll(tempDir) + defer boltStorage.Close() + lc := testNewLeaseCacheWithPersistence(t, responses, boltStorage) + + // Register an auto-auth token so that the token and lease requests are cached in mem + lc.RegisterAutoAuthToken("autoauthtoken") + + cacheTests := []struct { + token string + method string + urlPath string + body string + wantStatusCode int + }{ + { + // Make a request. A response with a new token is returned to the + // lease cache and that will be cached. + token: "autoauthtoken", + method: "GET", + urlPath: "http://example.com/v1/sample/api", + body: `{"value": "input"}`, + wantStatusCode: responses[0].Response.StatusCode, + }, + { + // Modify the request a little bit to ensure the second response is + // returned to the lease cache. + token: "autoauthtoken", + method: "GET", + urlPath: "http://example.com/v1/sample/api", + body: `{"value": "input_changed"}`, + wantStatusCode: responses[1].Response.StatusCode, + }, + } + + for _, ct := range cacheTests { + // Send once to cache + sendReq := &SendRequest{ + Token: ct.token, + Request: httptest.NewRequest(ct.method, ct.urlPath, strings.NewReader(ct.body)), + } + resp, err := lc.Send(context.Background(), sendReq) + require.NoError(t, err) + assert.Equal(t, resp.Response.StatusCode, ct.wantStatusCode, "expected proxied response") + assert.Nil(t, resp.CacheMeta) + } + + // Restore from the persisted cache's storage + restoredCache := testNewLeaseCache(t, nil) + + err := restoredCache.Restore(context.Background(), boltStorage) + assert.NoError(t, err) + + // The original mem cache should have all three items + beforeDB, err := lc.db.GetByPrefix(cachememdb.IndexNameID) + require.NoError(t, err) + assert.Len(t, beforeDB, 3) + + // There should only be one item in the restored cache: the autoauth token + afterDB, err := restoredCache.db.GetByPrefix(cachememdb.IndexNameID) + require.NoError(t, err) + assert.Len(t, afterDB, 1) + + // Just verify that the one item in the restored mem cache matches one in the original mem cache, and that it's the auto-auth token + beforeItem, err := lc.db.Get(cachememdb.IndexNameID, afterDB[0].ID) + require.NoError(t, err) + assert.NotNil(t, beforeItem) + + assert.Equal(t, "autoauthtoken", afterDB[0].Token) + assert.Equal(t, cacheboltdb.TokenType, afterDB[0].Type) +} diff --git a/command/agent/config/config.go b/command/agent/config/config.go index 32f99b55c..eb91385d8 100644 --- a/command/agent/config/config.go +++ b/command/agent/config/config.go @@ -50,6 +50,16 @@ type Cache struct { ForceAutoAuthToken bool `hcl:"-"` EnforceConsistency string `hcl:"enforce_consistency"` WhenInconsistent string `hcl:"when_inconsistent"` + Persist *Persist `hcl:"persist"` +} + +// Persist contains configuration needed for persistent caching +type Persist struct { + Type string + Path string `hcl:"path"` + KeepAfterImport bool `hcl:"keep_after_import"` + ExitOnErr bool `hcl:"exit_on_err"` + ServiceAccountTokenFile string `hcl:"service_account_token_file"` } // AutoAuth is the configured authentication method and sinks @@ -309,8 +319,51 @@ func parseCache(result *Config, list *ast.ObjectList) error { } } } - result.Cache = &c + + subs, ok := item.Val.(*ast.ObjectType) + if !ok { + return fmt.Errorf("could not parse %q as an object", name) + } + subList := subs.List + if err := parsePersist(result, subList); err != nil { + return fmt.Errorf("error parsing persist: %w", err) + } + + return nil +} + +func parsePersist(result *Config, list *ast.ObjectList) error { + name := "persist" + + persistList := list.Filter(name) + if len(persistList.Items) == 0 { + return nil + } + + if len(persistList.Items) > 1 { + return fmt.Errorf("only one %q block is required", name) + } + + item := persistList.Items[0] + + var p Persist + err := hcl.DecodeObject(&p, item.Val) + if err != nil { + return err + } + + if p.Type == "" { + if len(item.Keys) == 1 { + p.Type = strings.ToLower(item.Keys[0].Token.Value().(string)) + } + if p.Type == "" { + return errors.New("persist type must be specified") + } + } + + result.Cache.Persist = &p + return nil } diff --git a/command/agent/config/config_test.go b/command/agent/config/config_test.go index 8a502d7f9..167b7bd8a 100644 --- a/command/agent/config/config_test.go +++ b/command/agent/config/config_test.go @@ -66,6 +66,13 @@ func TestLoadConfigFile_AgentCache(t *testing.T) { UseAutoAuthToken: true, UseAutoAuthTokenRaw: true, ForceAutoAuthToken: false, + Persist: &Persist{ + Type: "kubernetes", + Path: "/vault/agent-cache/", + KeepAfterImport: true, + ExitOnErr: true, + ServiceAccountTokenFile: "/tmp/serviceaccount/token", + }, }, Vault: &Vault{ Address: "http://127.0.0.1:1111", @@ -445,6 +452,52 @@ func TestLoadConfigFile_AgentCache_AutoAuth_False(t *testing.T) { } } +func TestLoadConfigFile_AgentCache_Persist(t *testing.T) { + config, err := LoadConfig("./test-fixtures/config-cache-persist-false.hcl") + if err != nil { + t.Fatalf("err: %s", err) + } + + expected := &Config{ + Cache: &Cache{ + Persist: &Persist{ + Type: "kubernetes", + Path: "/vault/agent-cache/", + KeepAfterImport: false, + ExitOnErr: false, + ServiceAccountTokenFile: "", + }, + }, + SharedConfig: &configutil.SharedConfig{ + PidFile: "./pidfile", + Listeners: []*configutil.Listener{ + { + Type: "tcp", + Address: "127.0.0.1:8300", + TLSDisable: true, + }, + }, + }, + } + + config.Listeners[0].RawConfig = nil + if diff := deep.Equal(config, expected); diff != nil { + t.Fatal(diff) + } + + config.Listeners[0].RawConfig = nil + if diff := deep.Equal(config, expected); diff != nil { + t.Fatal(diff) + } +} + +func TestLoadConfigFile_AgentCache_PersistMissingType(t *testing.T) { + _, err := LoadConfig("./test-fixtures/config-cache-persist-empty-type.hcl") + if err == nil || os.IsNotExist(err) { + t.Fatal("expected error or file is missing") + } +} + // TestLoadConfigFile_Template tests template definitions in Vault Agent func TestLoadConfigFile_Template(t *testing.T) { testCases := map[string]struct { diff --git a/command/agent/config/test-fixtures/config-cache-embedded-type.hcl b/command/agent/config/test-fixtures/config-cache-embedded-type.hcl index 83242e941..3c5615315 100644 --- a/command/agent/config/test-fixtures/config-cache-embedded-type.hcl +++ b/command/agent/config/test-fixtures/config-cache-embedded-type.hcl @@ -21,6 +21,12 @@ auto_auth { cache { use_auto_auth_token = true + persist "kubernetes" { + path = "/vault/agent-cache/" + keep_after_import = true + exit_on_err = true + service_account_token_file = "/tmp/serviceaccount/token" + } } listener { diff --git a/command/agent/config/test-fixtures/config-cache-persist-empty-type.hcl b/command/agent/config/test-fixtures/config-cache-persist-empty-type.hcl new file mode 100644 index 000000000..55f1d6480 --- /dev/null +++ b/command/agent/config/test-fixtures/config-cache-persist-empty-type.hcl @@ -0,0 +1,12 @@ +pid_file = "./pidfile" + +cache { + persist = { + path = "/vault/agent-cache/" + } +} + +listener "tcp" { + address = "127.0.0.1:8300" + tls_disable = true +} diff --git a/command/agent/config/test-fixtures/config-cache-persist-false.hcl b/command/agent/config/test-fixtures/config-cache-persist-false.hcl new file mode 100644 index 000000000..5ab7f0449 --- /dev/null +++ b/command/agent/config/test-fixtures/config-cache-persist-false.hcl @@ -0,0 +1,14 @@ +pid_file = "./pidfile" + +cache { + persist "kubernetes" { + exit_on_err = false + keep_after_import = false + path = "/vault/agent-cache/" + } +} + +listener "tcp" { + address = "127.0.0.1:8300" + tls_disable = true +} diff --git a/command/agent/config/test-fixtures/config-cache.hcl b/command/agent/config/test-fixtures/config-cache.hcl index 64a9cdf08..b468e9a07 100644 --- a/command/agent/config/test-fixtures/config-cache.hcl +++ b/command/agent/config/test-fixtures/config-cache.hcl @@ -21,6 +21,13 @@ auto_auth { cache { use_auto_auth_token = true + persist = { + type = "kubernetes" + path = "/vault/agent-cache/" + keep_after_import = true + exit_on_err = true + service_account_token_file = "/tmp/serviceaccount/token" + } } listener "unix" { diff --git a/command/agent/template/template.go b/command/agent/template/template.go index b1b69b30f..bff82c868 100644 --- a/command/agent/template/template.go +++ b/command/agent/template/template.go @@ -256,7 +256,7 @@ func newRunnerConfig(sc *ServerConfig, templates ctconfig.TemplateConfigs) (*ctc } // Use the cache if available or fallback to the Vault server values. - if sc.AgentConfig.Cache != nil && len(sc.AgentConfig.Listeners) != 0 { + if sc.AgentConfig.Cache != nil && sc.AgentConfig.Cache.Persist != nil && len(sc.AgentConfig.Listeners) != 0 { scheme := "unix://" if sc.AgentConfig.Listeners[0].Type == "tcp" { scheme = "https://" diff --git a/command/agent/template/template_test.go b/command/agent/template/template_test.go index 347364c79..f856650de 100644 --- a/command/agent/template/template_test.go +++ b/command/agent/template/template_test.go @@ -28,7 +28,7 @@ func TestNewServer(t *testing.T) { } } -func newAgentConfig(listeners []*configutil.Listener, enableCache bool) *config.Config { +func newAgentConfig(listeners []*configutil.Listener, enableCache, enablePersisentCache bool) *config.Config { agentConfig := &config.Config{ SharedConfig: &configutil.SharedConfig{ PidFile: "./pidfile", @@ -65,7 +65,13 @@ func newAgentConfig(listeners []*configutil.Listener, enableCache bool) *config. }, } if enableCache { - agentConfig.Cache = &config.Cache{UseAutoAuthToken: true} + agentConfig.Cache = &config.Cache{ + UseAutoAuthToken: true, + } + } + + if enablePersisentCache { + agentConfig.Cache.Persist = &config.Persist{Type: "kubernetes"} } return agentConfig @@ -94,7 +100,7 @@ func TestCacheConfigUnix(t *testing.T) { }, } - agentConfig := newAgentConfig(listeners, true) + agentConfig := newAgentConfig(listeners, true, true) serverConfig := ServerConfig{AgentConfig: agentConfig} ctConfig, err := newRunnerConfig(&serverConfig, ctconfig.TemplateConfigs{}) @@ -131,7 +137,7 @@ func TestCacheConfigHTTP(t *testing.T) { }, } - agentConfig := newAgentConfig(listeners, true) + agentConfig := newAgentConfig(listeners, true, true) serverConfig := ServerConfig{AgentConfig: agentConfig} ctConfig, err := newRunnerConfig(&serverConfig, ctconfig.TemplateConfigs{}) @@ -168,7 +174,7 @@ func TestCacheConfigHTTPS(t *testing.T) { }, } - agentConfig := newAgentConfig(listeners, true) + agentConfig := newAgentConfig(listeners, true, true) serverConfig := ServerConfig{AgentConfig: agentConfig} ctConfig, err := newRunnerConfig(&serverConfig, ctconfig.TemplateConfigs{}) @@ -209,7 +215,44 @@ func TestCacheConfigNoCache(t *testing.T) { }, } - agentConfig := newAgentConfig(listeners, false) + agentConfig := newAgentConfig(listeners, false, false) + serverConfig := ServerConfig{AgentConfig: agentConfig} + + ctConfig, err := newRunnerConfig(&serverConfig, ctconfig.TemplateConfigs{}) + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + + expected := "http://127.0.0.1:1111" + if *ctConfig.Vault.Address != expected { + t.Fatalf("expected %s, got %s", expected, *ctConfig.Vault.Address) + } +} + +func TestCacheConfigNoPersistentCache(t *testing.T) { + listeners := []*configutil.Listener{ + { + Type: "tcp", + Address: "127.0.0.1:8300", + TLSKeyFile: "/path/to/cakey.pem", + TLSCertFile: "/path/to/cacert.pem", + }, + { + Type: "unix", + Address: "foobar", + TLSDisable: true, + SocketMode: "configmode", + SocketUser: "configuser", + SocketGroup: "configgroup", + }, + { + Type: "tcp", + Address: "127.0.0.1:8400", + TLSDisable: true, + }, + } + + agentConfig := newAgentConfig(listeners, true, false) serverConfig := ServerConfig{AgentConfig: agentConfig} ctConfig, err := newRunnerConfig(&serverConfig, ctconfig.TemplateConfigs{}) @@ -226,7 +269,7 @@ func TestCacheConfigNoCache(t *testing.T) { func TestCacheConfigNoListener(t *testing.T) { listeners := []*configutil.Listener{} - agentConfig := newAgentConfig(listeners, true) + agentConfig := newAgentConfig(listeners, true, true) serverConfig := ServerConfig{AgentConfig: agentConfig} ctConfig, err := newRunnerConfig(&serverConfig, ctconfig.TemplateConfigs{}) @@ -264,7 +307,7 @@ func TestCacheConfigRejectMTLS(t *testing.T) { }, } - agentConfig := newAgentConfig(listeners, true) + agentConfig := newAgentConfig(listeners, true, true) serverConfig := ServerConfig{AgentConfig: agentConfig} _, err := newRunnerConfig(&serverConfig, ctconfig.TemplateConfigs{})