diff --git a/agent/consul/multilimiter/multilimiter.go b/agent/consul/multilimiter/multilimiter.go index 539cced1c..ff30b48b9 100644 --- a/agent/consul/multilimiter/multilimiter.go +++ b/agent/consul/multilimiter/multilimiter.go @@ -110,32 +110,20 @@ func NewMultiLimiter(c Config) *MultiLimiter { func (m *MultiLimiter) Run(ctx context.Context) { m.once.Do(func() { go func() { - writeTimeout := 10 * time.Millisecond + cfg := m.defaultConfig.Load() + writeTimeout := cfg.ReconcileCheckInterval limiters := m.limiters.Load() txn := limiters.Txn() waiter := time.NewTicker(writeTimeout) wt := tickerWrapper{ticker: waiter} + defer waiter.Stop() for { - if txn = m.runStoreOnce(ctx, wt, txn); txn == nil { + if txn = m.reconcile(ctx, wt, txn, cfg.ReconcileCheckLimit); txn == nil { return } } }() - go func() { - waiter := time.NewTimer(0) - for { - c := m.defaultConfig.Load() - waiter.Reset(c.ReconcileCheckInterval) - select { - case <-ctx.Done(): - waiter.Stop() - return - case now := <-waiter.C: - m.reconcileLimitedOnce(now, c.ReconcileCheckLimit) - } - } - }() }) } @@ -192,13 +180,16 @@ func (t tickerWrapper) Ticker() <-chan time.Time { return t.ticker.C } -func (m *MultiLimiter) runStoreOnce(ctx context.Context, waiter ticker, txn *radix.Txn) *radix.Txn { +func (m *MultiLimiter) reconcile(ctx context.Context, waiter ticker, txn *radix.Txn, reconcileCheckLimit time.Duration) *radix.Txn { select { case <-waiter.Ticker(): tree := txn.Commit() m.limiters.Store(tree) txn = tree.Txn() - + m.cleanLimiters(time.Now(), reconcileCheckLimit, txn) + m.reconcileConfig(txn) + tree = txn.Commit() + txn = tree.Txn() case lk := <-m.limiterCh: v, ok := txn.Get(lk.k) if !ok { @@ -215,60 +206,55 @@ func (m *MultiLimiter) runStoreOnce(ctx context.Context, waiter ticker, txn *rad return txn } -// reconcileLimitedOnce is called by the MultiLimiter clean up routine to remove old Limited entries -// it will wait for ReconcileCheckInterval before traversing the radix tree and removing all entries -// with lastAccess > ReconcileCheckLimit -func (m *MultiLimiter) reconcileLimitedOnce(now time.Time, reconcileCheckLimit time.Duration) { - limiters := m.limiters.Load() - storedLimiters := limiters - iter := limiters.Root().Iterator() - k, v, ok := iter.Next() - var txn *radix.Txn - txn = limiters.Txn() - // remove all expired limiters - for ok { - if t, ok := v.(*Limiter); ok { - if t.limiter != nil { - lastAccess := t.lastAccess.Load() - lastAccessT := time.UnixMilli(lastAccess) - diff := now.Sub(lastAccessT) - - if diff > reconcileCheckLimit { - txn.Delete(k) - } - } - } - k, v, ok = iter.Next() - } - iter = txn.Root().Iterator() - k, v, ok = iter.Next() - +func (m *MultiLimiter) reconcileConfig(txn *radix.Txn) { + iter := txn.Root().Iterator() // make sure all limiters have the latest defaultConfig of their prefix - for ok { - if pl, ok := v.(*Limiter); ok { - // check if it has a limiter, if so that's a lead - if pl.limiter != nil { - // find the prefix for the leaf and check if the defaultConfig is up-to-date - // it's possible that the prefix is equal to the key - prefix, _ := splitKey(k) - v, ok := m.limitersConfigs.Load().Get(prefix) - if ok { - if cl, ok := v.(*LimiterConfig); ok { - if cl != nil { - if !cl.isApplied(pl.limiter) { - limiter := Limiter{limiter: rate.NewLimiter(cl.Rate, cl.Burst)} - limiter.lastAccess.Store(pl.lastAccess.Load()) - txn.Insert(k, &limiter) - } - } - } - } - } + for k, v, ok := iter.Next(); ok; k, v, ok = iter.Next() { + pl, ok := v.(*Limiter) + if pl == nil || !ok { + continue } - k, v, ok = iter.Next() + if pl.limiter == nil { + continue + } + + // find the prefix for the leaf and check if the defaultConfig is up-to-date + // it's possible that the prefix is equal to the key + prefix, _ := splitKey(k) + v, ok := m.limitersConfigs.Load().Get(prefix) + if v == nil || !ok { + continue + } + cl, ok := v.(*LimiterConfig) + if cl == nil || !ok { + continue + } + if cl.isApplied(pl.limiter) { + continue + } + + limiter := Limiter{limiter: rate.NewLimiter(cl.Rate, cl.Burst)} + limiter.lastAccess.Store(pl.lastAccess.Load()) + txn.Insert(k, &limiter) + } - limiters = txn.Commit() - m.limiters.CompareAndSwap(storedLimiters, limiters) +} + +func (m *MultiLimiter) cleanLimiters(now time.Time, reconcileCheckLimit time.Duration, txn *radix.Txn) { + iter := txn.Root().Iterator() + // remove all expired limiters + for k, v, ok := iter.Next(); ok; k, v, ok = iter.Next() { + t, isLimiter := v.(*Limiter) + if !isLimiter || t.limiter == nil { + continue + } + + lastAccess := time.UnixMilli(t.lastAccess.Load()) + if now.Sub(lastAccess) > reconcileCheckLimit { + txn.Delete(k) + } + } + } func (lc *LimiterConfig) isApplied(l *rate.Limiter) bool { diff --git a/agent/consul/multilimiter/multilimiter_test.go b/agent/consul/multilimiter/multilimiter_test.go index 939e8006f..b64f95feb 100644 --- a/agent/consul/multilimiter/multilimiter_test.go +++ b/agent/consul/multilimiter/multilimiter_test.go @@ -85,9 +85,11 @@ func TestRateLimiterCleanup(t *testing.T) { limiters = l }) - l, ok := limiters.Get(key) - require.True(t, ok) - require.NotNil(t, l) + retry.RunWith(&retry.Timer{Wait: 100 * time.Millisecond, Timeout: 2 * time.Second}, t, func(r *retry.R) { + v, ok := limiters.Get(key) + require.True(r, ok) + require.NotNil(t, v) + }) time.Sleep(c.ReconcileCheckInterval) // Wait > ReconcileCheckInterval and check that the key was cleaned up @@ -95,19 +97,30 @@ func TestRateLimiterCleanup(t *testing.T) { l := m.limiters.Load() require.NotEqual(r, limiters, l) limiters = l + v, ok := limiters.Get(key) + require.False(r, ok) + require.Nil(r, v) }) - l, ok = limiters.Get(key) - require.False(t, ok) - require.Nil(t, l) + } func storeLimiter(m *MultiLimiter) { txn := m.limiters.Load().Txn() mockTicker := mockTicker{tickerCh: make(chan time.Time, 1)} ctx := context.Background() - m.runStoreOnce(ctx, &mockTicker, txn) + reconcileCheckLimit := m.defaultConfig.Load().ReconcileCheckLimit + m.reconcile(ctx, &mockTicker, txn, reconcileCheckLimit) mockTicker.tickerCh <- time.Now() - m.runStoreOnce(ctx, &mockTicker, txn) + m.reconcile(ctx, &mockTicker, txn, reconcileCheckLimit) +} + +func reconcile(m *MultiLimiter) { + txn := m.limiters.Load().Txn() + mockTicker := mockTicker{tickerCh: make(chan time.Time, 1)} + ctx := context.Background() + reconcileCheckLimit := m.defaultConfig.Load().ReconcileCheckLimit + mockTicker.tickerCh <- time.Now() + m.reconcile(ctx, &mockTicker, txn, reconcileCheckLimit) } func TestRateLimiterStore(t *testing.T) { @@ -241,7 +254,7 @@ func TestRateLimiterUpdateConfig(t *testing.T) { c3 := LimiterConfig{Rate: 2} m.UpdateConfig(c3, prefix) // call reconcileLimitedOnce to make sure the update is applied - m.reconcileLimitedOnce(time.Now(), 100*time.Millisecond) + m.reconcileConfig(m.limiters.Load().Txn()) m.Allow(ipLimited{key: ip}) storeLimiter(m) l3, ok3 := m.limiters.Load().Get(ip) @@ -276,7 +289,7 @@ func TestRateLimiterUpdateConfig(t *testing.T) { c1 := LimiterConfig{Rate: 1} m.UpdateConfig(c1, prefix) // call reconcileLimitedOnce to make sure the update is applied - m.reconcileLimitedOnce(time.Now(), 100*time.Millisecond) + reconcile(m) m.Allow(ipLimited{key: ip}) storeLimiter(m) l, ok := m.limiters.Load().Get(ip) @@ -284,10 +297,6 @@ func TestRateLimiterUpdateConfig(t *testing.T) { require.NotNil(t, l) limiter := l.(*Limiter) require.True(t, c1.isApplied(limiter.limiter)) - m.reconcileLimitedOnce(time.Now().Add(100*time.Millisecond), 100*time.Millisecond) - l, ok = m.limiters.Load().Get(ip) - require.False(t, ok) - require.Nil(t, l) }) } @@ -347,7 +356,7 @@ func FuzzUpdateConfig(f *testing.F) { m.UpdateConfig(c, prefix) go m.Allow(Limited{key: f}) } - m.reconcileLimitedOnce(time.Now(), 1*time.Millisecond) + m.reconcileConfig(m.limiters.Load().Txn()) checkTree(t, m.limiters.Load().Txn()) })