Server side rate limiter: handle the race condition for limiters tree write in multilimiter (#15767)

* change to perform all tree writes in the same go routine to avoid race condition.

* rename runStoreOnce to reconcile

* Apply suggestions from code review

Co-authored-by: Dan Upton <daniel@floppy.co>

* reduce nesting

Co-authored-by: Dan Upton <daniel@floppy.co>
This commit is contained in:
Dhia Ayachi 2022-12-14 12:32:11 -05:00 committed by GitHub
parent 1f82e82e04
commit 11f245f24f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 79 additions and 84 deletions

View File

@ -110,32 +110,20 @@ func NewMultiLimiter(c Config) *MultiLimiter {
func (m *MultiLimiter) Run(ctx context.Context) {
m.once.Do(func() {
go func() {
writeTimeout := 10 * time.Millisecond
cfg := m.defaultConfig.Load()
writeTimeout := cfg.ReconcileCheckInterval
limiters := m.limiters.Load()
txn := limiters.Txn()
waiter := time.NewTicker(writeTimeout)
wt := tickerWrapper{ticker: waiter}
defer waiter.Stop()
for {
if txn = m.runStoreOnce(ctx, wt, txn); txn == nil {
if txn = m.reconcile(ctx, wt, txn, cfg.ReconcileCheckLimit); txn == nil {
return
}
}
}()
go func() {
waiter := time.NewTimer(0)
for {
c := m.defaultConfig.Load()
waiter.Reset(c.ReconcileCheckInterval)
select {
case <-ctx.Done():
waiter.Stop()
return
case now := <-waiter.C:
m.reconcileLimitedOnce(now, c.ReconcileCheckLimit)
}
}
}()
})
}
@ -192,13 +180,16 @@ func (t tickerWrapper) Ticker() <-chan time.Time {
return t.ticker.C
}
func (m *MultiLimiter) runStoreOnce(ctx context.Context, waiter ticker, txn *radix.Txn) *radix.Txn {
func (m *MultiLimiter) reconcile(ctx context.Context, waiter ticker, txn *radix.Txn, reconcileCheckLimit time.Duration) *radix.Txn {
select {
case <-waiter.Ticker():
tree := txn.Commit()
m.limiters.Store(tree)
txn = tree.Txn()
m.cleanLimiters(time.Now(), reconcileCheckLimit, txn)
m.reconcileConfig(txn)
tree = txn.Commit()
txn = tree.Txn()
case lk := <-m.limiterCh:
v, ok := txn.Get(lk.k)
if !ok {
@ -215,60 +206,55 @@ func (m *MultiLimiter) runStoreOnce(ctx context.Context, waiter ticker, txn *rad
return txn
}
// reconcileLimitedOnce is called by the MultiLimiter clean up routine to remove old Limited entries
// it will wait for ReconcileCheckInterval before traversing the radix tree and removing all entries
// with lastAccess > ReconcileCheckLimit
func (m *MultiLimiter) reconcileLimitedOnce(now time.Time, reconcileCheckLimit time.Duration) {
limiters := m.limiters.Load()
storedLimiters := limiters
iter := limiters.Root().Iterator()
k, v, ok := iter.Next()
var txn *radix.Txn
txn = limiters.Txn()
// remove all expired limiters
for ok {
if t, ok := v.(*Limiter); ok {
if t.limiter != nil {
lastAccess := t.lastAccess.Load()
lastAccessT := time.UnixMilli(lastAccess)
diff := now.Sub(lastAccessT)
if diff > reconcileCheckLimit {
txn.Delete(k)
}
}
}
k, v, ok = iter.Next()
}
iter = txn.Root().Iterator()
k, v, ok = iter.Next()
func (m *MultiLimiter) reconcileConfig(txn *radix.Txn) {
iter := txn.Root().Iterator()
// make sure all limiters have the latest defaultConfig of their prefix
for ok {
if pl, ok := v.(*Limiter); ok {
// check if it has a limiter, if so that's a lead
if pl.limiter != nil {
// find the prefix for the leaf and check if the defaultConfig is up-to-date
// it's possible that the prefix is equal to the key
prefix, _ := splitKey(k)
v, ok := m.limitersConfigs.Load().Get(prefix)
if ok {
if cl, ok := v.(*LimiterConfig); ok {
if cl != nil {
if !cl.isApplied(pl.limiter) {
limiter := Limiter{limiter: rate.NewLimiter(cl.Rate, cl.Burst)}
limiter.lastAccess.Store(pl.lastAccess.Load())
txn.Insert(k, &limiter)
}
}
}
}
}
for k, v, ok := iter.Next(); ok; k, v, ok = iter.Next() {
pl, ok := v.(*Limiter)
if pl == nil || !ok {
continue
}
k, v, ok = iter.Next()
if pl.limiter == nil {
continue
}
// find the prefix for the leaf and check if the defaultConfig is up-to-date
// it's possible that the prefix is equal to the key
prefix, _ := splitKey(k)
v, ok := m.limitersConfigs.Load().Get(prefix)
if v == nil || !ok {
continue
}
cl, ok := v.(*LimiterConfig)
if cl == nil || !ok {
continue
}
if cl.isApplied(pl.limiter) {
continue
}
limiter := Limiter{limiter: rate.NewLimiter(cl.Rate, cl.Burst)}
limiter.lastAccess.Store(pl.lastAccess.Load())
txn.Insert(k, &limiter)
}
limiters = txn.Commit()
m.limiters.CompareAndSwap(storedLimiters, limiters)
}
func (m *MultiLimiter) cleanLimiters(now time.Time, reconcileCheckLimit time.Duration, txn *radix.Txn) {
iter := txn.Root().Iterator()
// remove all expired limiters
for k, v, ok := iter.Next(); ok; k, v, ok = iter.Next() {
t, isLimiter := v.(*Limiter)
if !isLimiter || t.limiter == nil {
continue
}
lastAccess := time.UnixMilli(t.lastAccess.Load())
if now.Sub(lastAccess) > reconcileCheckLimit {
txn.Delete(k)
}
}
}
func (lc *LimiterConfig) isApplied(l *rate.Limiter) bool {

View File

@ -85,9 +85,11 @@ func TestRateLimiterCleanup(t *testing.T) {
limiters = l
})
l, ok := limiters.Get(key)
require.True(t, ok)
require.NotNil(t, l)
retry.RunWith(&retry.Timer{Wait: 100 * time.Millisecond, Timeout: 2 * time.Second}, t, func(r *retry.R) {
v, ok := limiters.Get(key)
require.True(r, ok)
require.NotNil(t, v)
})
time.Sleep(c.ReconcileCheckInterval)
// Wait > ReconcileCheckInterval and check that the key was cleaned up
@ -95,19 +97,30 @@ func TestRateLimiterCleanup(t *testing.T) {
l := m.limiters.Load()
require.NotEqual(r, limiters, l)
limiters = l
v, ok := limiters.Get(key)
require.False(r, ok)
require.Nil(r, v)
})
l, ok = limiters.Get(key)
require.False(t, ok)
require.Nil(t, l)
}
func storeLimiter(m *MultiLimiter) {
txn := m.limiters.Load().Txn()
mockTicker := mockTicker{tickerCh: make(chan time.Time, 1)}
ctx := context.Background()
m.runStoreOnce(ctx, &mockTicker, txn)
reconcileCheckLimit := m.defaultConfig.Load().ReconcileCheckLimit
m.reconcile(ctx, &mockTicker, txn, reconcileCheckLimit)
mockTicker.tickerCh <- time.Now()
m.runStoreOnce(ctx, &mockTicker, txn)
m.reconcile(ctx, &mockTicker, txn, reconcileCheckLimit)
}
func reconcile(m *MultiLimiter) {
txn := m.limiters.Load().Txn()
mockTicker := mockTicker{tickerCh: make(chan time.Time, 1)}
ctx := context.Background()
reconcileCheckLimit := m.defaultConfig.Load().ReconcileCheckLimit
mockTicker.tickerCh <- time.Now()
m.reconcile(ctx, &mockTicker, txn, reconcileCheckLimit)
}
func TestRateLimiterStore(t *testing.T) {
@ -241,7 +254,7 @@ func TestRateLimiterUpdateConfig(t *testing.T) {
c3 := LimiterConfig{Rate: 2}
m.UpdateConfig(c3, prefix)
// call reconcileLimitedOnce to make sure the update is applied
m.reconcileLimitedOnce(time.Now(), 100*time.Millisecond)
m.reconcileConfig(m.limiters.Load().Txn())
m.Allow(ipLimited{key: ip})
storeLimiter(m)
l3, ok3 := m.limiters.Load().Get(ip)
@ -276,7 +289,7 @@ func TestRateLimiterUpdateConfig(t *testing.T) {
c1 := LimiterConfig{Rate: 1}
m.UpdateConfig(c1, prefix)
// call reconcileLimitedOnce to make sure the update is applied
m.reconcileLimitedOnce(time.Now(), 100*time.Millisecond)
reconcile(m)
m.Allow(ipLimited{key: ip})
storeLimiter(m)
l, ok := m.limiters.Load().Get(ip)
@ -284,10 +297,6 @@ func TestRateLimiterUpdateConfig(t *testing.T) {
require.NotNil(t, l)
limiter := l.(*Limiter)
require.True(t, c1.isApplied(limiter.limiter))
m.reconcileLimitedOnce(time.Now().Add(100*time.Millisecond), 100*time.Millisecond)
l, ok = m.limiters.Load().Get(ip)
require.False(t, ok)
require.Nil(t, l)
})
}
@ -347,7 +356,7 @@ func FuzzUpdateConfig(f *testing.F) {
m.UpdateConfig(c, prefix)
go m.Allow(Limited{key: f})
}
m.reconcileLimitedOnce(time.Now(), 1*time.Millisecond)
m.reconcileConfig(m.limiters.Load().Txn())
checkTree(t, m.limiters.Load().Txn())
})