agent: add an inflight cache better concurrent request handling (#10705)

* agent: do not grap idLock writelock until caching entry

* agent: inflight cache using sync.Map

* agent: implement an inflight caching mechanism

* agent/lease: add lock for inflight cache to prevent simultaneous Set calls

* agent/lease: lock on a per-ID basis so unique requests can be processed independently

* agent/lease: add some concurrency tests

* test: use lease_id for uniqueness

* agent: remove env flags, add comments around locks

* agent: clean up test comment

* agent: clean up test comment

* agent: remove commented debug code

* agent/lease: word-smithing

* Update command/agent/cache/lease_cache.go

Co-authored-by: Nick Cabatoff <ncabatoff@hashicorp.com>

* agent/lease: return the context error if the Done ch got closed

* agent/lease: fix data race in concurrency tests

* agent/lease: mockDelayProxier: return ctx.Err() if context got canceled

* agent/lease: remove unused inflightCacheLock

* agent/lease: test: bump context timeout to 3s

Co-authored-by: Nick Cabatoff <ncabatoff@hashicorp.com>
This commit is contained in:
Calvin Leung Huang 2021-01-26 12:09:37 -08:00 committed by GitHub
parent fb049caa7f
commit 0df09e356d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 231 additions and 36 deletions

View File

@ -27,6 +27,8 @@ import (
"github.com/hashicorp/vault/sdk/helper/jsonutil"
"github.com/hashicorp/vault/sdk/helper/locksutil"
"github.com/hashicorp/vault/sdk/logical"
gocache "github.com/patrickmn/go-cache"
"go.uber.org/atomic"
)
const (
@ -78,6 +80,9 @@ type LeaseCache struct {
// idLocks is used during cache lookup to ensure that identical requests made
// in parallel won't trigger multiple renewal goroutines.
idLocks []*locksutil.LockEntry
// inflightCache keeps track of inflight requests
inflightCache *gocache.Cache
}
// LeaseCacheConfig is the configuration for initializing a new
@ -89,6 +94,22 @@ type LeaseCacheConfig struct {
Logger hclog.Logger
}
type inflightRequest struct {
// ch is closed by the request that ends up processing the set of
// parallel request
ch chan struct{}
// remaining is the number of remaining inflight request that needs to
// be processed before this object can be cleaned up
remaining atomic.Uint64
}
func newInflightRequest() *inflightRequest {
return &inflightRequest{
ch: make(chan struct{}),
}
}
// NewLeaseCache creates a new instance of a LeaseCache.
func NewLeaseCache(conf *LeaseCacheConfig) (*LeaseCache, error) {
if conf == nil {
@ -119,6 +140,7 @@ func NewLeaseCache(conf *LeaseCacheConfig) (*LeaseCache, error) {
baseCtxInfo: baseCtxInfo,
l: &sync.RWMutex{},
idLocks: locksutil.CreateLocks(),
inflightCache: gocache.New(gocache.NoExpiration, gocache.NoExpiration),
}, nil
}
@ -170,40 +192,60 @@ func (c *LeaseCache) Send(ctx context.Context, req *SendRequest) (*SendResponse,
return nil, err
}
// Grab a read lock for this particular request
// Check the inflight cache to see if there are other inflight requests
// of the same kind, based on the computed ID. If so, we increment a counter
var inflight *inflightRequest
defer func() {
// Cleanup on the cache if there are no remaining inflight requests.
// This is the last step, so we defer the call first
if inflight != nil && inflight.remaining.Load() == 0 {
c.inflightCache.Delete(id)
}
}()
idLock := locksutil.LockForKey(c.idLocks, id)
idLock.RLock()
unlockFunc := idLock.RUnlock
defer func() { unlockFunc() }()
// Briefly grab an ID-based lock in here to emulate a load-or-store behavior
// and prevent concurrent cacheable requests from being proxied twice if
// they both miss the cache due to it being clean when peeking the cache
// entry.
idLock.Lock()
inflightRaw, found := c.inflightCache.Get(id)
if found {
idLock.Unlock()
inflight = inflightRaw.(*inflightRequest)
inflight.remaining.Inc()
defer inflight.remaining.Dec()
// If found it means that there's an inflight request being processed.
// We wait until that's finished before proceeding further.
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-inflight.ch:
}
} else {
inflight = newInflightRequest()
inflight.remaining.Inc()
defer inflight.remaining.Dec()
c.inflightCache.Set(id, inflight, gocache.NoExpiration)
idLock.Unlock()
// Signal that the processing request is done
defer close(inflight.ch)
}
// Check if the response for this request is already in the cache
sendResp, err := c.checkCacheForRequest(id)
cachedResp, err := c.checkCacheForRequest(id)
if err != nil {
return nil, err
}
if sendResp != nil {
if cachedResp != nil {
c.logger.Debug("returning cached response", "path", req.Request.URL.Path)
return sendResp, nil
}
// Perform a lock upgrade
idLock.RUnlock()
idLock.Lock()
unlockFunc = idLock.Unlock
// Check cache once more after upgrade
sendResp, err = c.checkCacheForRequest(id)
if err != nil {
return nil, err
}
// If found, it means that some other parallel request already cached this response
// in between this upgrade so we can simply return that. Otherwise, this request
// will be the one performing the cache write.
if sendResp != nil {
c.logger.Debug("returning cached response", "method", req.Request.Method, "path", req.Request.URL.Path)
return sendResp, nil
return cachedResp, nil
}
c.logger.Debug("forwarding request", "method", req.Request.Method, "path", req.Request.URL.Path)
@ -441,7 +483,7 @@ func (c *LeaseCache) startRenewing(ctx context.Context, index *cachememdb.Index,
func computeIndexID(req *SendRequest) (string, error) {
var b bytes.Buffer
// Serialze the request
// Serialize the request
if err := req.Request.Write(&b); err != nil {
return "", fmt.Errorf("failed to serialize request: %v", err)
}

View File

@ -8,15 +8,17 @@ import (
"net/url"
"reflect"
"strings"
"sync"
"testing"
"github.com/hashicorp/vault/command/agent/cache/cachememdb"
"time"
"github.com/go-test/deep"
hclog "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/command/agent/cache/cachememdb"
"github.com/hashicorp/vault/sdk/helper/consts"
"github.com/hashicorp/vault/sdk/helper/logging"
"go.uber.org/atomic"
)
func testNewLeaseCache(t *testing.T, responses []*SendResponse) *LeaseCache {
@ -40,6 +42,27 @@ func testNewLeaseCache(t *testing.T, responses []*SendResponse) *LeaseCache {
return lc
}
func testNewLeaseCacheWithDelay(t *testing.T, cacheable bool, delay int) *LeaseCache {
t.Helper()
client, err := api.NewClient(api.DefaultConfig())
if err != nil {
t.Fatal(err)
}
lc, err := NewLeaseCache(&LeaseCacheConfig{
Client: client,
BaseContext: context.Background(),
Proxier: &mockDelayProxier{cacheable, delay},
Logger: logging.NewVaultLogger(hclog.Trace).Named("cache.leasecache"),
})
if err != nil {
t.Fatal(err)
}
return lc
}
func TestCache_ComputeIndexID(t *testing.T) {
type args struct {
req *http.Request
@ -509,3 +532,108 @@ func TestCache_DeriveNamespaceAndRevocationPath(t *testing.T) {
})
}
}
func TestLeaseCache_Concurrent_NonCacheable(t *testing.T) {
lc := testNewLeaseCacheWithDelay(t, false, 50)
// We are going to send 100 requests, each taking 50ms to process. If these
// requests are processed serially, it will take ~5seconds to finish. we
// use a ContextWithTimeout to tell us if this is the case by giving ample
// time for it process them concurrently but time out if they get processed
// serially.
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
wgDoneCh := make(chan struct{})
go func() {
var wg sync.WaitGroup
// 100 concurrent requests
for i := 0; i < 100; i++ {
wg.Add(1)
go func() {
defer wg.Done()
// Send a request through the lease cache which is not cacheable (there is
// no lease information or auth information in the response)
sendReq := &SendRequest{
Request: httptest.NewRequest("GET", "http://example.com", nil),
}
_, err := lc.Send(ctx, sendReq)
if err != nil {
t.Fatal(err)
}
}()
}
wg.Wait()
close(wgDoneCh)
}()
select {
case <-ctx.Done():
t.Fatalf("request timed out: %s", ctx.Err())
case <-wgDoneCh:
}
}
func TestLeaseCache_Concurrent_Cacheable(t *testing.T) {
lc := testNewLeaseCacheWithDelay(t, true, 50)
if err := lc.RegisterAutoAuthToken("autoauthtoken"); err != nil {
t.Fatal(err)
}
// We are going to send 100 requests, each taking 50ms to process. If these
// requests are processed serially, it will take ~5seconds to finish, so we
// use a ContextWithTimeout to tell us if this is the case.
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
var cacheCount atomic.Uint32
wgDoneCh := make(chan struct{})
go func() {
var wg sync.WaitGroup
// Start 100 concurrent requests
for i := 0; i < 100; i++ {
wg.Add(1)
go func() {
defer wg.Done()
sendReq := &SendRequest{
Token: "autoauthtoken",
Request: httptest.NewRequest("GET", "http://example.com/v1/sample/api", nil),
}
resp, err := lc.Send(ctx, sendReq)
if err != nil {
t.Fatal(err)
}
if resp.CacheMeta != nil && resp.CacheMeta.Hit {
cacheCount.Inc()
}
}()
}
wg.Wait()
close(wgDoneCh)
}()
select {
case <-ctx.Done():
t.Fatalf("request timed out: %s", ctx.Err())
case <-wgDoneCh:
}
// Ensure that all but one request got proxied. The other 99 should be
// returned from the cache.
if cacheCount.Load() != 99 {
t.Fatalf("Should have returned a cached response 99 times, got %d", cacheCount.Load())
}
}

View File

@ -5,6 +5,7 @@ import (
"encoding/json"
"fmt"
"io/ioutil"
"math/rand"
"net/http"
"strings"
"time"
@ -80,3 +81,27 @@ func (p *mockTokenVerifierProxier) Send(ctx context.Context, req *SendRequest) (
func (p *mockTokenVerifierProxier) GetCurrentRequestToken() string {
return p.currentToken
}
type mockDelayProxier struct {
cacheableResp bool
delay int
}
func (p *mockDelayProxier) Send(ctx context.Context, req *SendRequest) (*SendResponse, error) {
if p.delay > 0 {
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-time.After(time.Duration(p.delay) * time.Millisecond):
}
}
// If this is a cacheable response, we return a unique response every time
if p.cacheableResp {
rand.Seed(time.Now().Unix())
s := fmt.Sprintf(`{"lease_id": "%d", "renewable": true, "data": {"foo": "bar"}}`, rand.Int())
return newTestSendResponse(http.StatusOK, s), nil
}
return newTestSendResponse(http.StatusOK, `{"value": "output"}`), nil
}