256 lines
6.3 KiB
Go
256 lines
6.3 KiB
Go
package multilimiter
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
radix "github.com/hashicorp/go-immutable-radix"
|
|
"golang.org/x/time/rate"
|
|
)
|
|
|
|
var _ RateLimiter = &MultiLimiter{}
|
|
|
|
const separator = "♣"
|
|
|
|
func Key(keys ...[]byte) KeyType {
|
|
return bytes.Join(keys, []byte(separator))
|
|
}
|
|
|
|
// RateLimiter is the interface implemented by MultiLimiter
|
|
//
|
|
//go:generate mockery --name RateLimiter --inpackage --filename mock_RateLimiter.go
|
|
type RateLimiter interface {
|
|
Run(ctx context.Context)
|
|
Allow(entity LimitedEntity) bool
|
|
UpdateConfig(c LimiterConfig, prefix []byte)
|
|
}
|
|
|
|
type limiterWithKey struct {
|
|
l *Limiter
|
|
k []byte
|
|
t time.Time
|
|
}
|
|
|
|
// MultiLimiter implement RateLimiter interface and represent a set of rate limiters
|
|
// specific to different LimitedEntities and queried by a LimitedEntities.Key()
|
|
type MultiLimiter struct {
|
|
limiters *atomic.Pointer[radix.Tree]
|
|
limitersConfigs *atomic.Pointer[radix.Tree]
|
|
defaultConfig *atomic.Pointer[Config]
|
|
limiterCh chan *limiterWithKey
|
|
configsLock sync.Mutex
|
|
once sync.Once
|
|
}
|
|
|
|
type KeyType = []byte
|
|
|
|
// LimitedEntity is an interface used by MultiLimiter.Allow to determine
|
|
// which rate limiter to use to allow the request
|
|
type LimitedEntity interface {
|
|
Key() KeyType
|
|
}
|
|
|
|
// Limiter define a limiter to be part of the MultiLimiter structure
|
|
type Limiter struct {
|
|
limiter *rate.Limiter
|
|
lastAccess atomic.Int64
|
|
}
|
|
|
|
// LimiterConfig is a Limiter configuration
|
|
type LimiterConfig struct {
|
|
Rate rate.Limit
|
|
Burst int
|
|
}
|
|
|
|
// Config is a MultiLimiter configuration
|
|
type Config struct {
|
|
LimiterConfig
|
|
ReconcileCheckLimit time.Duration
|
|
ReconcileCheckInterval time.Duration
|
|
}
|
|
|
|
// UpdateConfig will update the MultiLimiter Config
|
|
// which will cascade to all the Limiter(s) LimiterConfig
|
|
func (m *MultiLimiter) UpdateConfig(c LimiterConfig, prefix []byte) {
|
|
m.configsLock.Lock()
|
|
defer m.configsLock.Unlock()
|
|
if prefix == nil {
|
|
prefix = []byte("")
|
|
}
|
|
configs := m.limitersConfigs.Load()
|
|
newConfigs, _, _ := configs.Insert(prefix, &c)
|
|
m.limitersConfigs.Store(newConfigs)
|
|
}
|
|
|
|
// NewMultiLimiter create a new MultiLimiter
|
|
func NewMultiLimiter(c Config) *MultiLimiter {
|
|
limiters := atomic.Pointer[radix.Tree]{}
|
|
configs := atomic.Pointer[radix.Tree]{}
|
|
config := atomic.Pointer[Config]{}
|
|
config.Store(&c)
|
|
limiters.Store(radix.New())
|
|
configs.Store(radix.New())
|
|
if c.ReconcileCheckLimit == 0 {
|
|
c.ReconcileCheckLimit = 30 * time.Millisecond
|
|
}
|
|
if c.ReconcileCheckInterval == 0 {
|
|
c.ReconcileCheckLimit = 1 * time.Second
|
|
}
|
|
chLimiter := make(chan *limiterWithKey, 100)
|
|
m := &MultiLimiter{limiters: &limiters, defaultConfig: &config, limitersConfigs: &configs, limiterCh: chLimiter}
|
|
|
|
return m
|
|
}
|
|
|
|
// Run the cleanup routine to remove old entries of Limiters based on ReconcileCheckLimit and ReconcileCheckInterval.
|
|
func (m *MultiLimiter) Run(ctx context.Context) {
|
|
m.once.Do(func() {
|
|
go func() {
|
|
cfg := m.defaultConfig.Load()
|
|
writeTimeout := cfg.ReconcileCheckInterval
|
|
limiters := m.limiters.Load()
|
|
txn := limiters.Txn()
|
|
waiter := time.NewTicker(writeTimeout)
|
|
wt := tickerWrapper{ticker: waiter}
|
|
|
|
defer waiter.Stop()
|
|
for {
|
|
if txn = m.reconcile(ctx, wt, txn, cfg.ReconcileCheckLimit); txn == nil {
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
})
|
|
|
|
}
|
|
|
|
// Allow should be called by a request processor to check if the current request is Limited
|
|
// The request processor should provide a LimitedEntity that implement the right Key()
|
|
func (m *MultiLimiter) Allow(e LimitedEntity) bool {
|
|
limiters := m.limiters.Load()
|
|
l, ok := limiters.Get(e.Key())
|
|
now := time.Now()
|
|
unixNow := time.Now().UnixMilli()
|
|
if ok {
|
|
limiter := l.(*Limiter)
|
|
if limiter.limiter != nil {
|
|
limiter.lastAccess.Store(unixNow)
|
|
return limiter.limiter.Allow()
|
|
}
|
|
}
|
|
|
|
configs := m.limitersConfigs.Load()
|
|
config := &m.defaultConfig.Load().LimiterConfig
|
|
p, _, ok := configs.Root().LongestPrefix(e.Key())
|
|
|
|
if ok {
|
|
c, ok := configs.Get(p)
|
|
if ok && c != nil {
|
|
config = c.(*LimiterConfig)
|
|
}
|
|
}
|
|
|
|
limiter := &Limiter{limiter: rate.NewLimiter(config.Rate, config.Burst)}
|
|
limiter.lastAccess.Store(unixNow)
|
|
m.limiterCh <- &limiterWithKey{l: limiter, k: e.Key(), t: now}
|
|
return limiter.limiter.Allow()
|
|
}
|
|
|
|
type ticker interface {
|
|
Ticker() <-chan time.Time
|
|
}
|
|
|
|
type tickerWrapper struct {
|
|
ticker *time.Ticker
|
|
}
|
|
|
|
func (t tickerWrapper) Ticker() <-chan time.Time {
|
|
return t.ticker.C
|
|
}
|
|
func (m *MultiLimiter) reconcile(ctx context.Context, waiter ticker, txn *radix.Txn, reconcileCheckLimit time.Duration) *radix.Txn {
|
|
select {
|
|
case <-waiter.Ticker():
|
|
tree := txn.Commit()
|
|
m.limiters.Store(tree)
|
|
txn = tree.Txn()
|
|
m.cleanLimiters(time.Now(), reconcileCheckLimit, txn)
|
|
m.reconcileConfig(txn)
|
|
tree = txn.Commit()
|
|
txn = tree.Txn()
|
|
case lk := <-m.limiterCh:
|
|
v, ok := txn.Get(lk.k)
|
|
if !ok {
|
|
txn.Insert(lk.k, lk.l)
|
|
} else {
|
|
if l, ok := v.(*Limiter); ok {
|
|
l.lastAccess.Store(lk.t.Unix())
|
|
l.limiter.AllowN(lk.t, 1)
|
|
}
|
|
}
|
|
case <-ctx.Done():
|
|
return nil
|
|
}
|
|
return txn
|
|
}
|
|
|
|
func (m *MultiLimiter) reconcileConfig(txn *radix.Txn) {
|
|
iter := txn.Root().Iterator()
|
|
// make sure all limiters have the latest defaultConfig of their prefix
|
|
for k, v, ok := iter.Next(); ok; k, v, ok = iter.Next() {
|
|
pl, ok := v.(*Limiter)
|
|
if pl == nil || !ok {
|
|
continue
|
|
}
|
|
if pl.limiter == nil {
|
|
continue
|
|
}
|
|
|
|
// find the prefix for the leaf and check if the defaultConfig is up-to-date
|
|
// it's possible that the prefix is equal to the key
|
|
p, _, ok := m.limitersConfigs.Load().Root().LongestPrefix(k)
|
|
if !ok {
|
|
continue
|
|
}
|
|
v, ok := m.limitersConfigs.Load().Get(p)
|
|
if v == nil || !ok {
|
|
continue
|
|
}
|
|
cl, ok := v.(*LimiterConfig)
|
|
if cl == nil || !ok {
|
|
continue
|
|
}
|
|
if cl.isApplied(pl.limiter) {
|
|
continue
|
|
}
|
|
|
|
limiter := Limiter{limiter: rate.NewLimiter(cl.Rate, cl.Burst)}
|
|
limiter.lastAccess.Store(pl.lastAccess.Load())
|
|
txn.Insert(k, &limiter)
|
|
|
|
}
|
|
}
|
|
|
|
func (m *MultiLimiter) cleanLimiters(now time.Time, reconcileCheckLimit time.Duration, txn *radix.Txn) {
|
|
iter := txn.Root().Iterator()
|
|
// remove all expired limiters
|
|
for k, v, ok := iter.Next(); ok; k, v, ok = iter.Next() {
|
|
t, isLimiter := v.(*Limiter)
|
|
if !isLimiter || t.limiter == nil {
|
|
continue
|
|
}
|
|
|
|
lastAccess := time.UnixMilli(t.lastAccess.Load())
|
|
if now.Sub(lastAccess) > reconcileCheckLimit {
|
|
txn.Delete(k)
|
|
}
|
|
}
|
|
|
|
}
|
|
|
|
func (lc *LimiterConfig) isApplied(l *rate.Limiter) bool {
|
|
return l.Limit() == lc.Rate && l.Burst() == lc.Burst
|
|
}
|