Merge PR #9667: Rate Limit Backoff
This commit is contained in:
parent
ca65131543
commit
f873863263
|
@ -73,6 +73,11 @@ The 'rate' must be positive.`,
|
|||
Type: framework.TypeDurationSecond,
|
||||
Description: "The duration to enforce rate limiting for (default '1s').",
|
||||
},
|
||||
"block_interval": {
|
||||
Type: framework.TypeDurationSecond,
|
||||
Description: `If set, when a client reaches a rate limit threshold, the client will be prohibited
|
||||
from any further requests until after the 'block_interval' has elapsed.`,
|
||||
},
|
||||
},
|
||||
Operations: map[logical.Operation]framework.OperationHandler{
|
||||
logical.UpdateOperation: &framework.PathOperation{
|
||||
|
@ -154,6 +159,11 @@ func (b *SystemBackend) handleRateLimitQuotasUpdate() framework.OperationFunc {
|
|||
interval = time.Second
|
||||
}
|
||||
|
||||
blockInterval := time.Second * time.Duration(d.Get("block_interval").(int))
|
||||
if blockInterval < 0 {
|
||||
return logical.ErrorResponse("'block' is invalid"), nil
|
||||
}
|
||||
|
||||
mountPath := sanitizePath(d.Get("path").(string))
|
||||
ns := b.Core.namespaceByPath(mountPath)
|
||||
if ns.ID != namespace.RootNamespaceID {
|
||||
|
@ -185,13 +195,14 @@ func (b *SystemBackend) handleRateLimitQuotasUpdate() framework.OperationFunc {
|
|||
return logical.ErrorResponse("quota rule with similar properties exists under the name %q", quotaByFactors.QuotaName()), nil
|
||||
}
|
||||
|
||||
quota = quotas.NewRateLimitQuota(name, ns.Path, mountPath, rate, interval)
|
||||
quota = quotas.NewRateLimitQuota(name, ns.Path, mountPath, rate, interval, blockInterval)
|
||||
default:
|
||||
rlq := quota.(*quotas.RateLimitQuota)
|
||||
rlq.NamespacePath = ns.Path
|
||||
rlq.MountPath = mountPath
|
||||
rlq.Rate = rate
|
||||
rlq.Interval = interval
|
||||
rlq.BlockInterval = blockInterval
|
||||
}
|
||||
|
||||
entry, err := logical.StorageEntryJSON(quotas.QuotaStoragePath(qType, name), quota)
|
||||
|
@ -232,11 +243,12 @@ func (b *SystemBackend) handleRateLimitQuotasRead() framework.OperationFunc {
|
|||
}
|
||||
|
||||
data := map[string]interface{}{
|
||||
"type": qType,
|
||||
"name": rlq.Name,
|
||||
"path": nsPath + rlq.MountPath,
|
||||
"rate": rlq.Rate,
|
||||
"interval": int(rlq.Interval.Seconds()),
|
||||
"type": qType,
|
||||
"name": rlq.Name,
|
||||
"path": nsPath + rlq.MountPath,
|
||||
"rate": rlq.Rate,
|
||||
"interval": int(rlq.Interval.Seconds()),
|
||||
"block_interval": int(rlq.BlockInterval.Seconds()),
|
||||
}
|
||||
|
||||
return &logical.Response{
|
||||
|
|
|
@ -73,18 +73,28 @@ type RateLimitQuota struct {
|
|||
// Interval defines the duration to which rate limiting is applied.
|
||||
Interval time.Duration `json:"interval"`
|
||||
|
||||
lock *sync.RWMutex
|
||||
store limiter.Store
|
||||
logger log.Logger
|
||||
metricSink *metricsutil.ClusterMetricSink
|
||||
purgeInterval time.Duration
|
||||
staleAge time.Duration
|
||||
// BlockInterval defines the duration during which all requests are blocked for
|
||||
// a given client. This interval is enforced only if non-zero and a client
|
||||
// reaches the rate limit.
|
||||
BlockInterval time.Duration `json:"block_interval"`
|
||||
|
||||
lock *sync.RWMutex
|
||||
store limiter.Store
|
||||
logger log.Logger
|
||||
metricSink *metricsutil.ClusterMetricSink
|
||||
purgeInterval time.Duration
|
||||
staleAge time.Duration
|
||||
blockedClients sync.Map
|
||||
purgeBlocked bool
|
||||
closePurgeBlockedCh chan struct{}
|
||||
}
|
||||
|
||||
// NewRateLimitQuota creates a quota checker for imposing limits on the number
|
||||
// of requests in a given interval. An interval time duration of zero may be
|
||||
// provided, which will default to 1s when initialized.
|
||||
func NewRateLimitQuota(name, nsPath, mountPath string, rate float64, interval time.Duration) *RateLimitQuota {
|
||||
// provided, which will default to 1s when initialized. An optional block
|
||||
// duration may be provided, where if set, when a client reaches the rate limit,
|
||||
// subsequent requests will fail until the block duration has passed.
|
||||
func NewRateLimitQuota(name, nsPath, mountPath string, rate float64, interval, block time.Duration) *RateLimitQuota {
|
||||
return &RateLimitQuota{
|
||||
Name: name,
|
||||
Type: TypeRateLimit,
|
||||
|
@ -92,6 +102,7 @@ func NewRateLimitQuota(name, nsPath, mountPath string, rate float64, interval ti
|
|||
MountPath: mountPath,
|
||||
Rate: rate,
|
||||
Interval: interval,
|
||||
BlockInterval: block,
|
||||
purgeInterval: DefaultRateLimitPurgeInterval,
|
||||
staleAge: DefaultRateLimitStaleAge,
|
||||
}
|
||||
|
@ -119,7 +130,11 @@ func (rlq *RateLimitQuota) initialize(logger log.Logger, ms *metricsutil.Cluster
|
|||
}
|
||||
|
||||
if rlq.Rate <= 0 {
|
||||
return fmt.Errorf("invalid avg rps: %v", rlq.Rate)
|
||||
return fmt.Errorf("invalid rate: %v", rlq.Rate)
|
||||
}
|
||||
|
||||
if rlq.BlockInterval < 0 {
|
||||
return fmt.Errorf("invalid block interval: %v", rlq.BlockInterval)
|
||||
}
|
||||
|
||||
if logger != nil {
|
||||
|
@ -150,10 +165,73 @@ func (rlq *RateLimitQuota) initialize(logger log.Logger, ms *metricsutil.Cluster
|
|||
}
|
||||
|
||||
rlq.store = rlStore
|
||||
rlq.blockedClients = sync.Map{}
|
||||
|
||||
if rlq.BlockInterval > 0 && !rlq.purgeBlocked {
|
||||
rlq.purgeBlocked = true
|
||||
rlq.closePurgeBlockedCh = make(chan struct{})
|
||||
go rlq.purgeBlockedClients()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// purgeBlockedClients performs a blocking process where every purgeInterval
|
||||
// duration, we look at all blocked clients to potentially remove from the blocked
|
||||
// clients map.
|
||||
//
|
||||
// A blocked client will only be removed if the current time minus the time the
|
||||
// client was blocked at is greater than or equal to the block duration. The loop
|
||||
// will continue to run indefinitely until a value is sent on the closePurgeBlockedCh
|
||||
// in which we stop the ticker and return.
|
||||
func (rlq *RateLimitQuota) purgeBlockedClients() {
|
||||
rlq.lock.RLock()
|
||||
ticker := time.NewTicker(rlq.purgeInterval)
|
||||
rlq.lock.RUnlock()
|
||||
|
||||
for {
|
||||
select {
|
||||
case t := <-ticker.C:
|
||||
rlq.blockedClients.Range(func(key, value interface{}) bool {
|
||||
blockedAt := value.(time.Time)
|
||||
if t.Sub(blockedAt) >= rlq.BlockInterval {
|
||||
rlq.blockedClients.Delete(key)
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
case <-rlq.closePurgeBlockedCh:
|
||||
ticker.Stop()
|
||||
|
||||
rlq.lock.Lock()
|
||||
rlq.purgeBlocked = false
|
||||
rlq.lock.Unlock()
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (rlq *RateLimitQuota) getPurgeBlocked() bool {
|
||||
rlq.lock.RLock()
|
||||
defer rlq.lock.RUnlock()
|
||||
return rlq.purgeBlocked
|
||||
}
|
||||
|
||||
func (rlq *RateLimitQuota) numBlockedClients() int {
|
||||
rlq.lock.RLock()
|
||||
defer rlq.lock.RUnlock()
|
||||
|
||||
size := 0
|
||||
rlq.blockedClients.Range(func(_, _ interface{}) bool {
|
||||
size++
|
||||
return true
|
||||
})
|
||||
|
||||
return size
|
||||
}
|
||||
|
||||
// quotaID returns the identifier of the quota rule
|
||||
func (rlq *RateLimitQuota) quotaID() string {
|
||||
return rlq.ID
|
||||
|
@ -169,7 +247,9 @@ func (rlq *RateLimitQuota) QuotaName() string {
|
|||
// quota will not be evaluated. Otherwise, the client rate limiter is retrieved
|
||||
// by address and the rate limit quota is checked against that limiter.
|
||||
func (rlq *RateLimitQuota) allow(req *Request) (Response, error) {
|
||||
var resp Response
|
||||
resp := Response{
|
||||
Headers: make(map[string]string),
|
||||
}
|
||||
|
||||
// Skip rate limit checks for paths that are exempt from rate limiting.
|
||||
if rateLimitExemptPaths.HasPath(req.Path) {
|
||||
|
@ -181,17 +261,46 @@ func (rlq *RateLimitQuota) allow(req *Request) (Response, error) {
|
|||
return resp, fmt.Errorf("missing request client address in quota request")
|
||||
}
|
||||
|
||||
limit, remaining, reset, allow := rlq.store.Take(req.ClientAddress)
|
||||
resp.Allowed = allow
|
||||
resp.Headers = map[string]string{
|
||||
httplimit.HeaderRateLimitLimit: strconv.FormatUint(limit, 10),
|
||||
httplimit.HeaderRateLimitRemaining: strconv.FormatUint(remaining, 10),
|
||||
httplimit.HeaderRateLimitReset: time.Unix(0, int64(reset)).UTC().Format(time.RFC1123),
|
||||
var retryAfter string
|
||||
|
||||
defer func() {
|
||||
if !resp.Allowed {
|
||||
resp.Headers[httplimit.HeaderRetryAfter] = retryAfter
|
||||
rlq.metricSink.IncrCounterWithLabels([]string{"quota", "rate_limit", "violation"}, 1, []metrics.Label{{"name", rlq.Name}})
|
||||
}
|
||||
}()
|
||||
|
||||
// Check if the client is currently blocked and if so, deny the request. Note,
|
||||
// we cannot simply rely on the presence of the client in the map as the timing
|
||||
// of purging blocked clients may yield a false positive. In other words,
|
||||
// a client may no longer be considered blocked whereas the purging interval
|
||||
// has yet to run.
|
||||
if v, ok := rlq.blockedClients.Load(req.ClientAddress); ok {
|
||||
blockedAt := v.(time.Time)
|
||||
if time.Since(blockedAt) >= rlq.BlockInterval {
|
||||
// allow the request and remove the blocked client
|
||||
rlq.blockedClients.Delete(req.ClientAddress)
|
||||
} else {
|
||||
// deny the request and return early
|
||||
resp.Allowed = false
|
||||
retryAfter = strconv.Itoa(int(time.Until(blockedAt.Add(rlq.BlockInterval)).Seconds()))
|
||||
return resp, nil
|
||||
}
|
||||
}
|
||||
|
||||
if !resp.Allowed {
|
||||
resp.Headers[httplimit.HeaderRetryAfter] = resp.Headers[httplimit.HeaderRateLimitReset]
|
||||
rlq.metricSink.IncrCounterWithLabels([]string{"quota", "rate_limit", "violation"}, 1, []metrics.Label{{"name", rlq.Name}})
|
||||
limit, remaining, reset, allow := rlq.store.Take(req.ClientAddress)
|
||||
resp.Allowed = allow
|
||||
resp.Headers[httplimit.HeaderRateLimitLimit] = strconv.FormatUint(limit, 10)
|
||||
resp.Headers[httplimit.HeaderRateLimitRemaining] = strconv.FormatUint(remaining, 10)
|
||||
resp.Headers[httplimit.HeaderRateLimitReset] = strconv.Itoa(int(time.Until(time.Unix(0, int64(reset))).Seconds()))
|
||||
retryAfter = resp.Headers[httplimit.HeaderRateLimitReset]
|
||||
|
||||
// If the request is not allowed (i.e. rate limit threshold reached) and blocking
|
||||
// is enabled, we add the client to the set of blocked clients.
|
||||
if !resp.Allowed && rlq.purgeBlocked {
|
||||
blockedAt := time.Now()
|
||||
retryAfter = strconv.Itoa(int(time.Until(blockedAt.Add(rlq.BlockInterval)).Seconds()))
|
||||
rlq.blockedClients.Store(req.ClientAddress, blockedAt)
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
|
@ -199,6 +308,10 @@ func (rlq *RateLimitQuota) allow(req *Request) (Response, error) {
|
|||
|
||||
// close stops the current running client purge loop.
|
||||
func (rlq *RateLimitQuota) close() error {
|
||||
if rlq.purgeBlocked {
|
||||
close(rlq.closePurgeBlockedCh)
|
||||
}
|
||||
|
||||
if rlq.store != nil {
|
||||
return rlq.store.Close()
|
||||
}
|
||||
|
|
|
@ -14,13 +14,18 @@ import (
|
|||
"go.uber.org/atomic"
|
||||
)
|
||||
|
||||
type clientResult struct {
|
||||
atomicNumAllow *atomic.Int32
|
||||
atomicNumFail *atomic.Int32
|
||||
}
|
||||
|
||||
func TestNewRateLimitQuota(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
rlq *RateLimitQuota
|
||||
expectErr bool
|
||||
}{
|
||||
{"valid rate", NewRateLimitQuota("test-rate-limiter", "qa", "/foo/bar", 16.7, time.Second), false},
|
||||
{"valid rate", NewRateLimitQuota("test-rate-limiter", "qa", "/foo/bar", 16.7, time.Second, 0), false},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
|
@ -34,9 +39,12 @@ func TestNewRateLimitQuota(t *testing.T) {
|
|||
}
|
||||
|
||||
// TestRateLimitQuota_Close initializes a rate limit quota, closes it, and then
// checks that blocked client purging is reported as disabled.
func TestRateLimitQuota_Close(t *testing.T) {
|
||||
rlq := NewRateLimitQuota("test-rate-limiter", "qa", "/foo/bar", 16.7, time.Second)
|
||||
rlq := NewRateLimitQuota("test-rate-limiter", "qa", "/foo/bar", 16.7, time.Second, time.Minute)
|
||||
require.NoError(t, rlq.initialize(logging.NewVaultLogger(log.Trace), metricsutil.BlackholeSink()))
|
||||
require.NoError(t, rlq.close())
|
||||
|
||||
time.Sleep(time.Second) // allow enough time for purgeBlockedClients to receive on closePurgeBlockedCh
|
||||
require.False(t, rlq.getPurgeBlocked(), "expected blocked client purging to be disabled after explicit close")
|
||||
}
|
||||
|
||||
func TestRateLimitQuota_Allow(t *testing.T) {
|
||||
|
@ -56,11 +64,6 @@ func TestRateLimitQuota_Allow(t *testing.T) {
|
|||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
type clientResult struct {
|
||||
atomicNumAllow *atomic.Int32
|
||||
atomicNumFail *atomic.Int32
|
||||
}
|
||||
|
||||
reqFunc := func(addr string, atomicNumAllow, atomicNumFail *atomic.Int32) {
|
||||
defer wg.Done()
|
||||
|
||||
|
@ -80,8 +83,8 @@ func TestRateLimitQuota_Allow(t *testing.T) {
|
|||
|
||||
start := time.Now()
|
||||
end := start.Add(5 * time.Second)
|
||||
for time.Now().Before(end) {
|
||||
|
||||
for time.Now().Before(end) {
|
||||
for i := 0; i < 5; i++ {
|
||||
wg.Add(1)
|
||||
|
||||
|
@ -116,3 +119,88 @@ func TestRateLimitQuota_Allow(t *testing.T) {
|
|||
require.Falsef(t, numAllow > want, "too many successful requests; addr: %s, want: %d, numSuccess: %d, numFail: %d, elapsed: %d", addr, want, numAllow, numFail, elapsed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimitQuota_Allow_WithBlock(t *testing.T) {
|
||||
rlq := &RateLimitQuota{
|
||||
Name: "test-rate-limiter",
|
||||
Type: TypeRateLimit,
|
||||
NamespacePath: "qa",
|
||||
MountPath: "/foo/bar",
|
||||
Rate: 16.7,
|
||||
BlockInterval: 10 * time.Second,
|
||||
|
||||
// override values to lower durations for testing purposes
|
||||
purgeInterval: 10 * time.Second,
|
||||
staleAge: 10 * time.Second,
|
||||
}
|
||||
|
||||
require.NoError(t, rlq.initialize(logging.NewVaultLogger(log.Trace), metricsutil.BlackholeSink()))
|
||||
require.True(t, rlq.getPurgeBlocked())
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
reqFunc := func(addr string, atomicNumAllow, atomicNumFail *atomic.Int32) {
|
||||
defer wg.Done()
|
||||
|
||||
resp, err := rlq.allow(&Request{ClientAddress: addr})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if resp.Allowed {
|
||||
atomicNumAllow.Add(1)
|
||||
} else {
|
||||
atomicNumFail.Add(1)
|
||||
}
|
||||
}
|
||||
|
||||
results := make(map[string]*clientResult)
|
||||
|
||||
start := time.Now()
|
||||
end := start.Add(5 * time.Second)
|
||||
|
||||
for time.Now().Before(end) {
|
||||
for i := 0; i < 5; i++ {
|
||||
wg.Add(1)
|
||||
|
||||
addr := fmt.Sprintf("127.0.0.%d", i)
|
||||
cr, ok := results[addr]
|
||||
if !ok {
|
||||
results[addr] = &clientResult{atomicNumAllow: atomic.NewInt32(0), atomicNumFail: atomic.NewInt32(0)}
|
||||
cr = results[addr]
|
||||
}
|
||||
|
||||
go reqFunc(addr, cr.atomicNumAllow, cr.atomicNumFail)
|
||||
|
||||
time.Sleep(2 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
for _, cr := range results {
|
||||
numAllow := cr.atomicNumAllow.Load()
|
||||
numFail := cr.atomicNumFail.Load()
|
||||
|
||||
// Since blocking is enabled, each client should only have 'rate' successful
|
||||
// requests, whereas all subsequent requests fail.
|
||||
require.Equal(t, int32(17), numAllow)
|
||||
require.NotZero(t, numFail)
|
||||
}
|
||||
|
||||
func() {
|
||||
timeout := time.After(rlq.purgeInterval * 2)
|
||||
ticker := time.Tick(time.Second)
|
||||
for {
|
||||
select {
|
||||
case <-timeout:
|
||||
require.Failf(t, "timeout exceeded waiting for blocked clients to be purged", "num blocked: %d", rlq.numBlockedClients())
|
||||
|
||||
case <-ticker:
|
||||
if rlq.numBlockedClients() == 0 {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
|
|
@ -18,7 +18,7 @@ func TestQuotas_Precedence(t *testing.T) {
|
|||
|
||||
setQuotaFunc := func(t *testing.T, name, nsPath, mountPath string) Quota {
|
||||
t.Helper()
|
||||
quota := NewRateLimitQuota(name, nsPath, mountPath, 10, time.Second)
|
||||
quota := NewRateLimitQuota(name, nsPath, mountPath, 10, time.Second, 0)
|
||||
require.NoError(t, qm.SetQuota(context.Background(), TypeRateLimit.String(), quota, true))
|
||||
return quota
|
||||
}
|
||||
|
|
|
@ -32,6 +32,9 @@ either be a namespace or mount.
|
|||
- `rate` `(float: 0.0)` - The maximum number of requests in a given interval to
|
||||
be allowed by the quota rule. The `rate` must be positive.
|
||||
- `interval` `(string: "")` - The duration to enforce rate limiting for (default `"1s"`).
|
||||
- `block_interval` `(string: "")` - If set, when a client reaches a rate limit
|
||||
threshold, the client will be prohibited from any further requests until after
|
||||
the `block_interval` has elapsed.
|
||||
|
||||
### Sample Payload
|
||||
|
||||
|
@ -39,7 +42,8 @@ either be a namespace or mount.
|
|||
{
|
||||
"path": "",
|
||||
"rate": 897.3,
|
||||
"interval": "2m"
|
||||
"interval": "2m",
|
||||
"block_interval": "5m"
|
||||
}
|
||||
```
|
||||
|
||||
|
@ -96,7 +100,8 @@ $ curl \
|
|||
"lease_duration": 0,
|
||||
"renewable": false,
|
||||
"data": {
|
||||
"interval": "2m0s",
|
||||
"block_interval": 300,
|
||||
"interval": 120,
|
||||
"name": "global-rate-limiter",
|
||||
"path": "",
|
||||
"rate": 897.3,
|
||||
|
|
|
@ -40,6 +40,10 @@ rate limit quota, and a rate limit quota defined for a mount takes precedence ov
|
|||
the global and namespace rate limit quotas. In other words, the most specific
|
||||
quota rule will be applied.
|
||||
|
||||
A rate limit can be created with an optional `block_interval`, such that when set
|
||||
to a non-zero value, any client that hits a rate limit threshold will be blocked
|
||||
from all subsequent requests for a duration of `block_interval` seconds.
|
||||
|
||||
Vault also allows the inspection of the state of rate limiting in a Vault node
|
||||
through various [metrics](/docs/internals/telemetry#Resource-Quota-Metrics) exposed
|
||||
and through enabling optional audit logging.
|
||||
|
|
Loading…
Reference in New Issue