agent: add an inflight cache for better concurrent request handling (#10705)
* agent: do not grab idLock writelock until caching entry
* agent: inflight cache using sync.Map
* agent: implement an inflight caching mechanism
* agent/lease: add lock for inflight cache to prevent simultaneous Set calls
* agent/lease: lock on a per-ID basis so unique requests can be processed independently
* agent/lease: add some concurrency tests
* test: use lease_id for uniqueness
* agent: remove env flags, add comments around locks
* agent: clean up test comment
* agent: clean up test comment
* agent: remove commented debug code
* agent/lease: word-smithing
* Update command/agent/cache/lease_cache.go
  Co-authored-by: Nick Cabatoff <ncabatoff@hashicorp.com>
* agent/lease: return the context error if the Done ch got closed
* agent/lease: fix data race in concurrency tests
* agent/lease: mockDelayProxier: return ctx.Err() if context got canceled
* agent/lease: remove unused inflightCacheLock
* agent/lease: test: bump context timeout to 3s

Co-authored-by: Nick Cabatoff <ncabatoff@hashicorp.com>
This commit is contained in:
parent fb049caa7f, commit 0df09e356d
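At its core the change adds a singleflight-style guard in front of the agent's lease cache: the first request for a given cache ID does the real proxying, while identical concurrent requests wait on a shared channel and then re-read the cache. As a minimal, self-contained sketch of that coordination pattern (the names flightGroup and Do are illustrative only, not the commit's API):

	package main

	import (
		"fmt"
		"sync"
	)

	// inflight mirrors the commit's inflightRequest idea: done is closed by
	// the goroutine that performs the real work; waiters block on it.
	type inflight struct {
		done chan struct{}
		val  string
	}

	type flightGroup struct {
		mu sync.Mutex
		m  map[string]*inflight
	}

	// Do runs fn once per key; concurrent callers for the same key wait for
	// the first caller's result instead of duplicating the work.
	func (g *flightGroup) Do(key string, fn func() string) string {
		g.mu.Lock()
		if in, ok := g.m[key]; ok {
			g.mu.Unlock()
			<-in.done // wait for the active request to finish
			return in.val
		}
		in := &inflight{done: make(chan struct{})}
		g.m[key] = in
		g.mu.Unlock()

		in.val = fn() // only the first caller does the expensive work
		close(in.done)

		g.mu.Lock()
		delete(g.m, key) // cleanup once the result is published
		g.mu.Unlock()
		return in.val
	}

	func main() {
		g := &flightGroup{m: map[string]*inflight{}}
		var wg sync.WaitGroup
		for i := 0; i < 4; i++ {
			wg.Add(1)
			go func() {
				defer wg.Done()
				fmt.Println(g.Do("id-1", func() string { return "proxied once" }))
			}()
		}
		wg.Wait()
	}

The commit's real implementation differs in that waiters re-check the lease cache rather than reading a stored value, and cleanup is gated on a remaining counter so the entry is only deleted once every waiter has finished.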
command/agent/cache/lease_cache.go
@@ -27,6 +27,8 @@ import (
 	"github.com/hashicorp/vault/sdk/helper/jsonutil"
 	"github.com/hashicorp/vault/sdk/helper/locksutil"
 	"github.com/hashicorp/vault/sdk/logical"
+	gocache "github.com/patrickmn/go-cache"
+	"go.uber.org/atomic"
 )
 
 const (
@@ -78,6 +80,9 @@ type LeaseCache struct {
 	// idLocks is used during cache lookup to ensure that identical requests made
 	// in parallel won't trigger multiple renewal goroutines.
 	idLocks []*locksutil.LockEntry
+
+	// inflightCache keeps track of inflight requests
+	inflightCache *gocache.Cache
 }
 
 // LeaseCacheConfig is the configuration for initializing a new
@@ -89,6 +94,22 @@ type LeaseCacheConfig struct {
 	Logger      hclog.Logger
 }
 
+type inflightRequest struct {
+	// ch is closed by the request that ends up processing the set of
+	// parallel requests
+	ch chan struct{}
+
+	// remaining is the number of remaining inflight requests that need to
+	// be processed before this object can be cleaned up
+	remaining atomic.Uint64
+}
+
+func newInflightRequest() *inflightRequest {
+	return &inflightRequest{
+		ch: make(chan struct{}),
+	}
+}
+
 // NewLeaseCache creates a new instance of a LeaseCache.
 func NewLeaseCache(conf *LeaseCacheConfig) (*LeaseCache, error) {
 	if conf == nil {
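Two properties of Go make this small struct work as a rendezvous point: a receive from a closed channel never blocks, so a single close(ch) releases every waiter at once, and atomic.Uint64 lets waiters register themselves without extra locking. A tiny demonstration of the broadcast half (hypothetical demo code, not part of the commit):

	package main

	import (
		"fmt"
		"sync"
	)

	func main() {
		ch := make(chan struct{}) // plays the role of inflightRequest.ch
		var wg sync.WaitGroup
		for i := 0; i < 3; i++ {
			wg.Add(1)
			go func(n int) {
				defer wg.Done()
				<-ch // blocks until the channel is closed, then never blocks again
				fmt.Println("waiter", n, "released")
			}(i)
		}
		close(ch) // the request doing the real work signals completion exactly once
		wg.Wait()
	}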
@@ -112,13 +133,14 @@ func NewLeaseCache(conf *LeaseCacheConfig) (*LeaseCache, error) {
 	baseCtxInfo := cachememdb.NewContextInfo(conf.BaseContext)
 
 	return &LeaseCache{
-		client:      conf.Client,
-		proxier:     conf.Proxier,
-		logger:      conf.Logger,
-		db:          db,
-		baseCtxInfo: baseCtxInfo,
-		l:           &sync.RWMutex{},
-		idLocks:     locksutil.CreateLocks(),
+		client:        conf.Client,
+		proxier:       conf.Proxier,
+		logger:        conf.Logger,
+		db:            db,
+		baseCtxInfo:   baseCtxInfo,
+		l:             &sync.RWMutex{},
+		idLocks:       locksutil.CreateLocks(),
+		inflightCache: gocache.New(gocache.NoExpiration, gocache.NoExpiration),
 	}, nil
 }
 
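For reference, go-cache's New takes a default expiration and a janitor cleanup interval; passing gocache.NoExpiration for both means entries never expire and no background cleanup goroutine runs, so inflight entries live until the deferred c.inflightCache.Delete(id) in Send removes them. A quick usage sketch of that API (standalone example, not commit code):

	package main

	import (
		"fmt"

		gocache "github.com/patrickmn/go-cache"
	)

	func main() {
		c := gocache.New(gocache.NoExpiration, gocache.NoExpiration)
		c.Set("request-id", 42, gocache.NoExpiration) // per-entry TTL disabled too
		if v, found := c.Get("request-id"); found {
			fmt.Println(v.(int)) // values come back as interface{}, so assert the type
		}
		c.Delete("request-id") // with no janitor, Delete is how entries leave
	}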
@@ -170,40 +192,60 @@ func (c *LeaseCache) Send(ctx context.Context, req *SendRequest) (*SendResponse, error) {
 		return nil, err
 	}
 
-	// Grab a read lock for this particular request
+	// Check the inflight cache to see if there are other inflight requests
+	// of the same kind, based on the computed ID. If so, we increment a counter
+
+	var inflight *inflightRequest
+
+	defer func() {
+		// Cleanup on the cache if there are no remaining inflight requests.
+		// This is the last step, so we defer the call first
+		if inflight != nil && inflight.remaining.Load() == 0 {
+			c.inflightCache.Delete(id)
+		}
+	}()
+
 	idLock := locksutil.LockForKey(c.idLocks, id)
 
-	idLock.RLock()
-	unlockFunc := idLock.RUnlock
-	defer func() { unlockFunc() }()
+	// Briefly grab an ID-based lock in here to emulate a load-or-store behavior
+	// and prevent concurrent cacheable requests from being proxied twice if
+	// they both miss the cache due to it being clean when peeking the cache
+	// entry.
+	idLock.Lock()
+	inflightRaw, found := c.inflightCache.Get(id)
+	if found {
+		idLock.Unlock()
+		inflight = inflightRaw.(*inflightRequest)
+		inflight.remaining.Inc()
+		defer inflight.remaining.Dec()
+
+		// If found, it means that there's an inflight request being processed.
+		// We wait until that's finished before proceeding further.
+		select {
+		case <-ctx.Done():
+			return nil, ctx.Err()
+		case <-inflight.ch:
+		}
+	} else {
+		inflight = newInflightRequest()
+		inflight.remaining.Inc()
+		defer inflight.remaining.Dec()
+
+		c.inflightCache.Set(id, inflight, gocache.NoExpiration)
+		idLock.Unlock()
+
+		// Signal that the processing request is done
+		defer close(inflight.ch)
+	}
 
 	// Check if the response for this request is already in the cache
-	sendResp, err := c.checkCacheForRequest(id)
+	cachedResp, err := c.checkCacheForRequest(id)
 	if err != nil {
 		return nil, err
 	}
-	if sendResp != nil {
+	if cachedResp != nil {
 		c.logger.Debug("returning cached response", "path", req.Request.URL.Path)
-		return sendResp, nil
-	}
-
-	// Perform a lock upgrade
-	idLock.RUnlock()
-	idLock.Lock()
-	unlockFunc = idLock.Unlock
-
-	// Check cache once more after upgrade
-	sendResp, err = c.checkCacheForRequest(id)
-	if err != nil {
-		return nil, err
-	}
-
-	// If found, it means that some other parallel request already cached this response
-	// in between this upgrade so we can simply return that. Otherwise, this request
-	// will be the one performing the cache write.
-	if sendResp != nil {
-		c.logger.Debug("returning cached response", "method", req.Request.Method, "path", req.Request.URL.Path)
-		return sendResp, nil
+		return cachedResp, nil
 	}
 
 	c.logger.Debug("forwarding request", "method", req.Request.Method, "path", req.Request.URL.Path)
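The Get/Set pair has to happen under idLock because go-cache offers no atomic load-or-store. The commit history above shows an earlier iteration built the inflight cache on sync.Map, whose LoadOrStore provides exactly that guarantee; a sketch of what that variant could look like (a hypothetical helper reusing the inflightRequest type from the hunk above, not the code that shipped):

	var inflightRequests sync.Map // map[string]*inflightRequest

	// loadOrStoreInflight returns the inflight entry for id; loaded is false
	// for exactly one caller per id, which then owns the proxying and must
	// close(inflight.ch) when done.
	func loadOrStoreInflight(id string) (*inflightRequest, bool) {
		// Note: LoadOrStore allocates a fresh inflightRequest even on the
		// loaded path; that throwaway allocation is one cost of this variant.
		actual, loaded := inflightRequests.LoadOrStore(id, newInflightRequest())
		return actual.(*inflightRequest), loaded
	}

Either way the observable behavior of Send is the same; the shipped version pays a short critical section instead of a per-call allocation.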
@@ -441,7 +483,7 @@ func (c *LeaseCache) startRenewing(ctx context.Context, index *cachememdb.Index,
 func computeIndexID(req *SendRequest) (string, error) {
 	var b bytes.Buffer
 
-	// Serialze the request
+	// Serialize the request
 	if err := req.Request.Write(&b); err != nil {
 		return "", fmt.Errorf("failed to serialize request: %v", err)
 	}
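computeIndexID is only partially visible here: the hunk shows the serialization step, and the buffer is then reduced to the string ID used as the cache key. As a general illustration of the technique, not Vault's exact implementation, hashing the serialized request yields a stable key on which byte-identical requests collide (SHA-256 is an assumption made for this sketch):

	package main

	import (
		"bytes"
		"crypto/sha256"
		"encoding/hex"
		"fmt"
		"net/http"
	)

	// indexIDForRequest sketches the idea behind computeIndexID: serialize the
	// whole request, hash it, and use the hex digest as the cache key.
	func indexIDForRequest(req *http.Request) (string, error) {
		var b bytes.Buffer
		if err := req.Write(&b); err != nil {
			return "", fmt.Errorf("failed to serialize request: %w", err)
		}
		sum := sha256.Sum256(b.Bytes())
		return hex.EncodeToString(sum[:]), nil
	}

	func main() {
		req, _ := http.NewRequest("GET", "http://example.com/v1/sample/api", nil)
		id, err := indexIDForRequest(req)
		if err != nil {
			panic(err)
		}
		fmt.Println(id)
	}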
command/agent/cache/lease_cache_test.go
@@ -8,15 +8,17 @@ import (
 	"net/url"
 	"reflect"
 	"strings"
 	"sync"
 	"testing"
-
-	"github.com/hashicorp/vault/command/agent/cache/cachememdb"
+	"time"
 
 	"github.com/go-test/deep"
 	hclog "github.com/hashicorp/go-hclog"
 	"github.com/hashicorp/vault/api"
+	"github.com/hashicorp/vault/command/agent/cache/cachememdb"
 	"github.com/hashicorp/vault/sdk/helper/consts"
 	"github.com/hashicorp/vault/sdk/helper/logging"
+	"go.uber.org/atomic"
 )
 
 func testNewLeaseCache(t *testing.T, responses []*SendResponse) *LeaseCache {
@@ -40,6 +42,27 @@ func testNewLeaseCache(t *testing.T, responses []*SendResponse) *LeaseCache {
 	return lc
 }
 
+func testNewLeaseCacheWithDelay(t *testing.T, cacheable bool, delay int) *LeaseCache {
+	t.Helper()
+
+	client, err := api.NewClient(api.DefaultConfig())
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	lc, err := NewLeaseCache(&LeaseCacheConfig{
+		Client:      client,
+		BaseContext: context.Background(),
+		Proxier:     &mockDelayProxier{cacheable, delay},
+		Logger:      logging.NewVaultLogger(hclog.Trace).Named("cache.leasecache"),
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	return lc
+}
+
 func TestCache_ComputeIndexID(t *testing.T) {
 	type args struct {
 		req *http.Request
@@ -509,3 +532,108 @@ func TestCache_DeriveNamespaceAndRevocationPath(t *testing.T) {
 		})
 	}
 }
+
+func TestLeaseCache_Concurrent_NonCacheable(t *testing.T) {
+	lc := testNewLeaseCacheWithDelay(t, false, 50)
+
+	// We are going to send 100 requests, each taking 50ms to process. If these
+	// requests are processed serially, it will take ~5 seconds to finish. We
+	// use context.WithTimeout to tell us if this is the case by giving ample
+	// time to process them concurrently but timing out if they get processed
+	// serially.
+	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+	defer cancel()
+
+	wgDoneCh := make(chan struct{})
+
+	go func() {
+		var wg sync.WaitGroup
+		// 100 concurrent requests
+		for i := 0; i < 100; i++ {
+			wg.Add(1)
+
+			go func() {
+				defer wg.Done()
+
+				// Send a request through the lease cache which is not cacheable (there is
+				// no lease information or auth information in the response)
+				sendReq := &SendRequest{
+					Request: httptest.NewRequest("GET", "http://example.com", nil),
+				}
+
+				_, err := lc.Send(ctx, sendReq)
+				if err != nil {
+					t.Fatal(err)
+				}
+			}()
+		}
+
+		wg.Wait()
+		close(wgDoneCh)
+	}()
+
+	select {
+	case <-ctx.Done():
+		t.Fatalf("request timed out: %s", ctx.Err())
+	case <-wgDoneCh:
+	}
+}
+
+func TestLeaseCache_Concurrent_Cacheable(t *testing.T) {
+	lc := testNewLeaseCacheWithDelay(t, true, 50)
+
+	if err := lc.RegisterAutoAuthToken("autoauthtoken"); err != nil {
+		t.Fatal(err)
+	}
+
+	// We are going to send 100 requests, each taking 50ms to process. If these
+	// requests are processed serially, it will take ~5 seconds to finish, so we
+	// use context.WithTimeout to tell us if this is the case.
+	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+	defer cancel()
+
+	var cacheCount atomic.Uint32
+	wgDoneCh := make(chan struct{})
+
+	go func() {
+		var wg sync.WaitGroup
+		// Start 100 concurrent requests
+		for i := 0; i < 100; i++ {
+			wg.Add(1)
+
+			go func() {
+				defer wg.Done()
+
+				sendReq := &SendRequest{
+					Token:   "autoauthtoken",
+					Request: httptest.NewRequest("GET", "http://example.com/v1/sample/api", nil),
+				}
+
+				resp, err := lc.Send(ctx, sendReq)
+				if err != nil {
+					t.Fatal(err)
+				}
+
+				if resp.CacheMeta != nil && resp.CacheMeta.Hit {
+					cacheCount.Inc()
+				}
+			}()
+		}
+
+		wg.Wait()
+		close(wgDoneCh)
+	}()
+
+	select {
+	case <-ctx.Done():
+		t.Fatalf("request timed out: %s", ctx.Err())
+	case <-wgDoneCh:
+	}
+
+	// Ensure that only one request got proxied; the other 99 should be
+	// returned from the cache.
+	if cacheCount.Load() != 99 {
+		t.Fatalf("Should have returned a cached response 99 times, got %d", cacheCount.Load())
+	}
+}
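One structural note on these tests: t.Fatal is called from spawned goroutines, which the testing package documents as unsafe off the test goroutine (later go vet releases flag it). A sketch of an errgroup-based variant that surfaces errors back on the test goroutine instead, assuming golang.org/x/sync/errgroup and net/http/httptest are imported and lc, ctx, and t are in scope as in the test above:

	g, gctx := errgroup.WithContext(ctx)
	for i := 0; i < 100; i++ {
		g.Go(func() error {
			sendReq := &SendRequest{
				Request: httptest.NewRequest("GET", "http://example.com", nil),
			}
			_, err := lc.Send(gctx, sendReq)
			return err // collected by g.Wait() instead of t.Fatal in a goroutine
		})
	}
	if err := g.Wait(); err != nil {
		t.Fatal(err) // safe: we are back on the test goroutine
	}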
command/agent/cache/testing.go
@@ -5,6 +5,7 @@ import (
 	"encoding/json"
 	"fmt"
 	"io/ioutil"
+	"math/rand"
 	"net/http"
 	"strings"
 	"time"
@@ -80,3 +81,27 @@ func (p *mockTokenVerifierProxier) Send(ctx context.Context, req *SendRequest) (*SendResponse, error) {
 func (p *mockTokenVerifierProxier) GetCurrentRequestToken() string {
 	return p.currentToken
 }
+
+type mockDelayProxier struct {
+	cacheableResp bool
+	delay         int
+}
+
+func (p *mockDelayProxier) Send(ctx context.Context, req *SendRequest) (*SendResponse, error) {
+	if p.delay > 0 {
+		select {
+		case <-ctx.Done():
+			return nil, ctx.Err()
+		case <-time.After(time.Duration(p.delay) * time.Millisecond):
+		}
+	}
+
+	// If this is a cacheable response, we return a unique response every time
+	if p.cacheableResp {
+		rand.Seed(time.Now().Unix())
+		s := fmt.Sprintf(`{"lease_id": "%d", "renewable": true, "data": {"foo": "bar"}}`, rand.Int())
+		return newTestSendResponse(http.StatusOK, s), nil
+	}
+
+	return newTestSendResponse(http.StatusOK, `{"value": "output"}`), nil
+}
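A small quirk in this mock: reseeding the shared global source with time.Now().Unix() on every Send means two calls landing in the same second reset the generator to the same state and can emit identical lease_id values. The tests tolerate this because only one request per run actually reaches the proxier, but a per-mock generator sidesteps it entirely. A hypothetical variant (mockDelayProxierV2 is not in the commit; it assumes the file's existing imports plus sync):

	type mockDelayProxierV2 struct {
		cacheableResp bool
		delay         int
		mu            sync.Mutex // rand.Rand is not safe for concurrent use
		rng           *rand.Rand
	}

	func newMockDelayProxierV2(cacheable bool, delay int) *mockDelayProxierV2 {
		return &mockDelayProxierV2{
			cacheableResp: cacheable,
			delay:         delay,
			// Seed once with nanoseconds instead of reseeding on every call
			rng: rand.New(rand.NewSource(time.Now().UnixNano())),
		}
	}

	func (p *mockDelayProxierV2) nextLeaseID() string {
		p.mu.Lock()
		defer p.mu.Unlock()
		return fmt.Sprintf("%d", p.rng.Int())
	}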