counterpart of the ent in oss (#17367)

This commit is contained in:
wangxinyi7 2023-05-15 10:49:43 -07:00 committed by GitHub
parent 4a245b2bff
commit c2a479bffa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 82 additions and 81 deletions

View File

@ -215,6 +215,7 @@ func NewHandlerWithLimiter(
logger: logger, logger: logger,
} }
h.globalCfg.Store(&cfg) h.globalCfg.Store(&cfg)
h.ipCfg.Store(&IPLimitConfig{})
return h return h
} }
@ -248,17 +249,10 @@ func (h *Handler) Allow(op Operation) error {
return nil return nil
} }
for _, l := range h.limits(op) { allow, throttledLimits := h.allowAllLimits(h.limits(op))
if l.mode == ModeDisabled {
continue
}
if h.limiter.Allow(l.ent) {
continue
}
// TODO(NET-1382): is this the correct log-level?
if !allow {
for _, l := range throttledLimits {
enforced := l.mode == ModeEnforcing enforced := l.mode == ModeEnforcing
h.logger.Debug("RPC exceeded allowed rate limit", h.logger.Debug("RPC exceeded allowed rate limit",
"rpc", op.Name, "rpc", op.Name,
@ -289,6 +283,7 @@ func (h *Handler) Allow(op Operation) error {
return ErrRetryElsewhere return ErrRetryElsewhere
} }
} }
}
return nil return nil
} }
@ -320,6 +315,23 @@ type limit struct {
desc string desc string
} }
func (h *Handler) allowAllLimits(limits []limit) (bool, []limit) {
allow := true
throttledLimits := make([]limit, 0)
for _, l := range limits {
if l.mode == ModeDisabled {
continue
}
if !h.limiter.Allow(l.ent) {
throttledLimits = append(throttledLimits, l)
allow = false
}
}
return allow, throttledLimits
}
// limits returns the limits to check for the given operation (e.g. global + // limits returns the limits to check for the given operation (e.g. global +
// ip-based + tenant-based). // ip-based + tenant-based).
func (h *Handler) limits(op Operation) []limit { func (h *Handler) limits(op Operation) []limit {
@ -329,6 +341,14 @@ func (h *Handler) limits(op Operation) []limit {
limits = append(limits, *global) limits = append(limits, *global)
} }
if ipGlobal := h.ipGlobalLimit(op); ipGlobal != nil {
limits = append(limits, *ipGlobal)
}
if ipCategory := h.ipCategoryLimit(op); ipCategory != nil {
limits = append(limits, *ipCategory)
}
return limits return limits
} }
@ -354,23 +374,23 @@ func (h *Handler) globalLimit(op Operation) *limit {
var ( var (
// globalWrite identifies the global rate limit applied to write operations. // globalWrite identifies the global rate limit applied to write operations.
globalWrite = globalLimit("global.write") globalWrite = limitedEntity("global.write")
// globalRead identifies the global rate limit applied to read operations. // globalRead identifies the global rate limit applied to read operations.
globalRead = globalLimit("global.read") globalRead = limitedEntity("global.read")
// globalIPRead identifies the global rate limit applied to read operations. // globalIPRead identifies the global rate limit applied to read operations.
globalIPRead = globalLimit("global.ip.read") globalIPRead = limitedEntity("global.ip.read")
// globalIPWrite identifies the global rate limit applied to read operations. // globalIPWrite identifies the global rate limit applied to read operations.
globalIPWrite = globalLimit("global.ip.write") globalIPWrite = limitedEntity("global.ip.write")
) )
// globalLimit represents a limit that applies to all writes or reads. // limitedEntity convert the string type to Multilimiter.LimitedEntity
type globalLimit []byte type limitedEntity []byte
// Key satisfies the multilimiter.LimitedEntity interface. // Key satisfies the multilimiter.LimitedEntity interface.
func (prefix globalLimit) Key() multilimiter.KeyType { func (prefix limitedEntity) Key() multilimiter.KeyType {
return multilimiter.Key(prefix, nil) return multilimiter.Key(prefix, nil)
} }

View File

@ -6,9 +6,16 @@
package rate package rate
type IPLimitConfig struct { type IPLimitConfig struct{}
}
func (h *Handler) UpdateIPConfig(cfg IPLimitConfig) { func (h *Handler) UpdateIPConfig(cfg IPLimitConfig) {
// noop // noop
} }
func (h *Handler) ipGlobalLimit(op Operation) *limit {
return nil
}
func (h *Handler) ipCategoryLimit(op Operation) *limit {
return nil
}

View File

@ -5,20 +5,18 @@ package rate
import ( import (
"bytes" "bytes"
"context" "github.com/hashicorp/consul/agent/metrics"
"github.com/stretchr/testify/require"
"net" "net"
"net/netip" "net/netip"
"testing" "testing"
"golang.org/x/time/rate" "golang.org/x/time/rate"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/hashicorp/go-hclog" "github.com/hashicorp/go-hclog"
"github.com/stretchr/testify/mock"
"github.com/hashicorp/consul/agent/consul/multilimiter" "github.com/hashicorp/consul/agent/consul/multilimiter"
"github.com/hashicorp/consul/agent/metrics"
) )
// //
@ -226,10 +224,10 @@ func TestHandler(t *testing.T) {
for desc, tc := range testCases { for desc, tc := range testCases {
t.Run(desc, func(t *testing.T) { t.Run(desc, func(t *testing.T) {
sink := metrics.TestSetupMetrics(t, "") sink := metrics.TestSetupMetrics(t, "")
limiter := newMockLimiter(t) limiter := multilimiter.NewMockRateLimiter(t)
limiter.On("UpdateConfig", mock.Anything, mock.Anything).Return() limiter.On("UpdateConfig", mock.Anything, mock.Anything).Return()
for _, c := range tc.checks { for _, c := range tc.checks {
limiter.On("Allow", c.limit).Return(c.allow) limiter.On("Allow", mock.Anything).Return(c.allow)
} }
leaderStatusProvider := NewMockLeaderStatusProvider(t) leaderStatusProvider := NewMockLeaderStatusProvider(t)
@ -376,7 +374,7 @@ func TestAllow(t *testing.T) {
type testCase struct { type testCase struct {
description string description string
cfg *HandlerConfig cfg *HandlerConfig
expectedAllowCalls int expectedAllowCalls bool
} }
testCases := []testCase{ testCases := []testCase{
{ {
@ -390,7 +388,7 @@ func TestAllow(t *testing.T) {
}, },
}, },
}, },
expectedAllowCalls: 0, expectedAllowCalls: false,
}, },
{ {
description: "RateLimiter gets called when mode is permissive.", description: "RateLimiter gets called when mode is permissive.",
@ -403,7 +401,7 @@ func TestAllow(t *testing.T) {
}, },
}, },
}, },
expectedAllowCalls: 1, expectedAllowCalls: true,
}, },
{ {
description: "RateLimiter gets called when mode is enforcing.", description: "RateLimiter gets called when mode is enforcing.",
@ -416,14 +414,14 @@ func TestAllow(t *testing.T) {
}, },
}, },
}, },
expectedAllowCalls: 1, expectedAllowCalls: true,
}, },
} }
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.description, func(t *testing.T) { t.Run(tc.description, func(t *testing.T) {
mockRateLimiter := multilimiter.NewMockRateLimiter(t) mockRateLimiter := multilimiter.NewMockRateLimiter(t)
if tc.expectedAllowCalls > 0 { if tc.expectedAllowCalls {
mockRateLimiter.On("Allow", mock.Anything).Return(func(entity multilimiter.LimitedEntity) bool { return true }) mockRateLimiter.On("Allow", mock.Anything).Return(func(entity multilimiter.LimitedEntity) bool { return true })
} }
mockRateLimiter.On("UpdateConfig", mock.Anything, mock.Anything).Return() mockRateLimiter.On("UpdateConfig", mock.Anything, mock.Anything).Return()
@ -435,31 +433,7 @@ func TestAllow(t *testing.T) {
addr := net.TCPAddrFromAddrPort(netip.MustParseAddrPort("127.0.0.1:1234")) addr := net.TCPAddrFromAddrPort(netip.MustParseAddrPort("127.0.0.1:1234"))
mockRateLimiter.Calls = nil mockRateLimiter.Calls = nil
handler.Allow(Operation{Name: "test", SourceAddr: addr}) handler.Allow(Operation{Name: "test", SourceAddr: addr})
mockRateLimiter.AssertNumberOfCalls(t, "Allow", tc.expectedAllowCalls) mockRateLimiter.AssertExpectations(t)
}) })
} }
} }
var _ multilimiter.RateLimiter = (*mockLimiter)(nil)
func newMockLimiter(t *testing.T) *mockLimiter {
l := &mockLimiter{}
l.Mock.Test(t)
t.Cleanup(func() { l.AssertExpectations(t) })
return l
}
type mockLimiter struct {
mock.Mock
}
func (m *mockLimiter) Allow(v multilimiter.LimitedEntity) bool { return m.Called(v).Bool(0) }
func (m *mockLimiter) Run(ctx context.Context) { m.Called(ctx) }
func (m *mockLimiter) UpdateConfig(cfg multilimiter.LimiterConfig, prefix []byte) {
m.Called(cfg, prefix)
}
func (m *mockLimiter) DeleteConfig(prefix []byte) {
m.Called(prefix)
}