diff --git a/command/agent/cache/lease_cache.go b/command/agent/cache/lease_cache.go index c20a03ed1..cbaa6cc15 100644 --- a/command/agent/cache/lease_cache.go +++ b/command/agent/cache/lease_cache.go @@ -27,6 +27,8 @@ import ( "github.com/hashicorp/vault/sdk/helper/jsonutil" "github.com/hashicorp/vault/sdk/helper/locksutil" "github.com/hashicorp/vault/sdk/logical" + gocache "github.com/patrickmn/go-cache" + "go.uber.org/atomic" ) const ( @@ -78,6 +80,9 @@ type LeaseCache struct { // idLocks is used during cache lookup to ensure that identical requests made // in parallel won't trigger multiple renewal goroutines. idLocks []*locksutil.LockEntry + + // inflightCache keeps track of inflight requests + inflightCache *gocache.Cache } // LeaseCacheConfig is the configuration for initializing a new @@ -89,6 +94,22 @@ type LeaseCacheConfig struct { Logger hclog.Logger } +type inflightRequest struct { + // ch is closed by the request that ends up processing the set of + // parallel request + ch chan struct{} + + // remaining is the number of remaining inflight request that needs to + // be processed before this object can be cleaned up + remaining atomic.Uint64 +} + +func newInflightRequest() *inflightRequest { + return &inflightRequest{ + ch: make(chan struct{}), + } +} + // NewLeaseCache creates a new instance of a LeaseCache. func NewLeaseCache(conf *LeaseCacheConfig) (*LeaseCache, error) { if conf == nil { @@ -112,13 +133,14 @@ func NewLeaseCache(conf *LeaseCacheConfig) (*LeaseCache, error) { baseCtxInfo := cachememdb.NewContextInfo(conf.BaseContext) return &LeaseCache{ - client: conf.Client, - proxier: conf.Proxier, - logger: conf.Logger, - db: db, - baseCtxInfo: baseCtxInfo, - l: &sync.RWMutex{}, - idLocks: locksutil.CreateLocks(), + client: conf.Client, + proxier: conf.Proxier, + logger: conf.Logger, + db: db, + baseCtxInfo: baseCtxInfo, + l: &sync.RWMutex{}, + idLocks: locksutil.CreateLocks(), + inflightCache: gocache.New(gocache.NoExpiration, gocache.NoExpiration), }, nil } @@ -170,40 +192,60 @@ func (c *LeaseCache) Send(ctx context.Context, req *SendRequest) (*SendResponse, return nil, err } - // Grab a read lock for this particular request + // Check the inflight cache to see if there are other inflight requests + // of the same kind, based on the computed ID. If so, we increment a counter + + var inflight *inflightRequest + + defer func() { + // Cleanup on the cache if there are no remaining inflight requests. + // This is the last step, so we defer the call first + if inflight != nil && inflight.remaining.Load() == 0 { + c.inflightCache.Delete(id) + } + }() + idLock := locksutil.LockForKey(c.idLocks, id) - idLock.RLock() - unlockFunc := idLock.RUnlock - defer func() { unlockFunc() }() + // Briefly grab an ID-based lock in here to emulate a load-or-store behavior + // and prevent concurrent cacheable requests from being proxied twice if + // they both miss the cache due to it being clean when peeking the cache + // entry. + idLock.Lock() + inflightRaw, found := c.inflightCache.Get(id) + if found { + idLock.Unlock() + inflight = inflightRaw.(*inflightRequest) + inflight.remaining.Inc() + defer inflight.remaining.Dec() + + // If found it means that there's an inflight request being processed. + // We wait until that's finished before proceeding further. + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-inflight.ch: + } + } else { + inflight = newInflightRequest() + inflight.remaining.Inc() + defer inflight.remaining.Dec() + + c.inflightCache.Set(id, inflight, gocache.NoExpiration) + idLock.Unlock() + + // Signal that the processing request is done + defer close(inflight.ch) + } // Check if the response for this request is already in the cache - sendResp, err := c.checkCacheForRequest(id) + cachedResp, err := c.checkCacheForRequest(id) if err != nil { return nil, err } - if sendResp != nil { + if cachedResp != nil { c.logger.Debug("returning cached response", "path", req.Request.URL.Path) - return sendResp, nil - } - - // Perform a lock upgrade - idLock.RUnlock() - idLock.Lock() - unlockFunc = idLock.Unlock - - // Check cache once more after upgrade - sendResp, err = c.checkCacheForRequest(id) - if err != nil { - return nil, err - } - - // If found, it means that some other parallel request already cached this response - // in between this upgrade so we can simply return that. Otherwise, this request - // will be the one performing the cache write. - if sendResp != nil { - c.logger.Debug("returning cached response", "method", req.Request.Method, "path", req.Request.URL.Path) - return sendResp, nil + return cachedResp, nil } c.logger.Debug("forwarding request", "method", req.Request.Method, "path", req.Request.URL.Path) @@ -441,7 +483,7 @@ func (c *LeaseCache) startRenewing(ctx context.Context, index *cachememdb.Index, func computeIndexID(req *SendRequest) (string, error) { var b bytes.Buffer - // Serialze the request + // Serialize the request if err := req.Request.Write(&b); err != nil { return "", fmt.Errorf("failed to serialize request: %v", err) } diff --git a/command/agent/cache/lease_cache_test.go b/command/agent/cache/lease_cache_test.go index 7726b5d2f..3b31385d0 100644 --- a/command/agent/cache/lease_cache_test.go +++ b/command/agent/cache/lease_cache_test.go @@ -8,15 +8,17 @@ import ( "net/url" "reflect" "strings" + "sync" "testing" - - "github.com/hashicorp/vault/command/agent/cache/cachememdb" + "time" "github.com/go-test/deep" hclog "github.com/hashicorp/go-hclog" "github.com/hashicorp/vault/api" + "github.com/hashicorp/vault/command/agent/cache/cachememdb" "github.com/hashicorp/vault/sdk/helper/consts" "github.com/hashicorp/vault/sdk/helper/logging" + "go.uber.org/atomic" ) func testNewLeaseCache(t *testing.T, responses []*SendResponse) *LeaseCache { @@ -40,6 +42,27 @@ func testNewLeaseCache(t *testing.T, responses []*SendResponse) *LeaseCache { return lc } +func testNewLeaseCacheWithDelay(t *testing.T, cacheable bool, delay int) *LeaseCache { + t.Helper() + + client, err := api.NewClient(api.DefaultConfig()) + if err != nil { + t.Fatal(err) + } + + lc, err := NewLeaseCache(&LeaseCacheConfig{ + Client: client, + BaseContext: context.Background(), + Proxier: &mockDelayProxier{cacheable, delay}, + Logger: logging.NewVaultLogger(hclog.Trace).Named("cache.leasecache"), + }) + if err != nil { + t.Fatal(err) + } + + return lc +} + func TestCache_ComputeIndexID(t *testing.T) { type args struct { req *http.Request @@ -509,3 +532,108 @@ func TestCache_DeriveNamespaceAndRevocationPath(t *testing.T) { }) } } + +func TestLeaseCache_Concurrent_NonCacheable(t *testing.T) { + lc := testNewLeaseCacheWithDelay(t, false, 50) + + // We are going to send 100 requests, each taking 50ms to process. If these + // requests are processed serially, it will take ~5seconds to finish. we + // use a ContextWithTimeout to tell us if this is the case by giving ample + // time for it process them concurrently but time out if they get processed + // serially. + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + wgDoneCh := make(chan struct{}) + + go func() { + var wg sync.WaitGroup + // 100 concurrent requests + for i := 0; i < 100; i++ { + wg.Add(1) + + go func() { + defer wg.Done() + + // Send a request through the lease cache which is not cacheable (there is + // no lease information or auth information in the response) + sendReq := &SendRequest{ + Request: httptest.NewRequest("GET", "http://example.com", nil), + } + + _, err := lc.Send(ctx, sendReq) + if err != nil { + t.Fatal(err) + } + }() + } + + wg.Wait() + close(wgDoneCh) + }() + + select { + case <-ctx.Done(): + t.Fatalf("request timed out: %s", ctx.Err()) + case <-wgDoneCh: + } + +} + +func TestLeaseCache_Concurrent_Cacheable(t *testing.T) { + lc := testNewLeaseCacheWithDelay(t, true, 50) + + if err := lc.RegisterAutoAuthToken("autoauthtoken"); err != nil { + t.Fatal(err) + } + + // We are going to send 100 requests, each taking 50ms to process. If these + // requests are processed serially, it will take ~5seconds to finish, so we + // use a ContextWithTimeout to tell us if this is the case. + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + var cacheCount atomic.Uint32 + wgDoneCh := make(chan struct{}) + + go func() { + var wg sync.WaitGroup + // Start 100 concurrent requests + for i := 0; i < 100; i++ { + wg.Add(1) + + go func() { + defer wg.Done() + + sendReq := &SendRequest{ + Token: "autoauthtoken", + Request: httptest.NewRequest("GET", "http://example.com/v1/sample/api", nil), + } + + resp, err := lc.Send(ctx, sendReq) + if err != nil { + t.Fatal(err) + } + + if resp.CacheMeta != nil && resp.CacheMeta.Hit { + cacheCount.Inc() + } + }() + } + + wg.Wait() + close(wgDoneCh) + }() + + select { + case <-ctx.Done(): + t.Fatalf("request timed out: %s", ctx.Err()) + case <-wgDoneCh: + } + + // Ensure that all but one request got proxied. The other 99 should be + // returned from the cache. + if cacheCount.Load() != 99 { + t.Fatalf("Should have returned a cached response 99 times, got %d", cacheCount.Load()) + } +} diff --git a/command/agent/cache/testing.go b/command/agent/cache/testing.go index ca5e526e4..9ec637be4 100644 --- a/command/agent/cache/testing.go +++ b/command/agent/cache/testing.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "io/ioutil" + "math/rand" "net/http" "strings" "time" @@ -80,3 +81,27 @@ func (p *mockTokenVerifierProxier) Send(ctx context.Context, req *SendRequest) ( func (p *mockTokenVerifierProxier) GetCurrentRequestToken() string { return p.currentToken } + +type mockDelayProxier struct { + cacheableResp bool + delay int +} + +func (p *mockDelayProxier) Send(ctx context.Context, req *SendRequest) (*SendResponse, error) { + if p.delay > 0 { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(time.Duration(p.delay) * time.Millisecond): + } + } + + // If this is a cacheable response, we return a unique response every time + if p.cacheableResp { + rand.Seed(time.Now().Unix()) + s := fmt.Sprintf(`{"lease_id": "%d", "renewable": true, "data": {"foo": "bar"}}`, rand.Int()) + return newTestSendResponse(http.StatusOK, s), nil + } + + return newTestSendResponse(http.StatusOK, `{"value": "output"}`), nil +}