counterpart of the ent in oss (#17367)
This commit is contained in:
parent
4a245b2bff
commit
c2a479bffa
|
@ -215,6 +215,7 @@ func NewHandlerWithLimiter(
|
||||||
logger: logger,
|
logger: logger,
|
||||||
}
|
}
|
||||||
h.globalCfg.Store(&cfg)
|
h.globalCfg.Store(&cfg)
|
||||||
|
h.ipCfg.Store(&IPLimitConfig{})
|
||||||
|
|
||||||
return h
|
return h
|
||||||
}
|
}
|
||||||
|
@ -248,45 +249,39 @@ func (h *Handler) Allow(op Operation) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, l := range h.limits(op) {
|
allow, throttledLimits := h.allowAllLimits(h.limits(op))
|
||||||
if l.mode == ModeDisabled {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if h.limiter.Allow(l.ent) {
|
if !allow {
|
||||||
continue
|
for _, l := range throttledLimits {
|
||||||
}
|
enforced := l.mode == ModeEnforcing
|
||||||
|
h.logger.Debug("RPC exceeded allowed rate limit",
|
||||||
|
"rpc", op.Name,
|
||||||
|
"source_addr", op.SourceAddr,
|
||||||
|
"limit_type", l.desc,
|
||||||
|
"limit_enforced", enforced,
|
||||||
|
)
|
||||||
|
|
||||||
// TODO(NET-1382): is this the correct log-level?
|
metrics.IncrCounterWithLabels([]string{"rpc", "rate_limit", "exceeded"}, 1, []metrics.Label{
|
||||||
|
{
|
||||||
|
Name: "limit_type",
|
||||||
|
Value: l.desc,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "op",
|
||||||
|
Value: op.Name,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "mode",
|
||||||
|
Value: l.mode.String(),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
enforced := l.mode == ModeEnforcing
|
if enforced {
|
||||||
h.logger.Debug("RPC exceeded allowed rate limit",
|
if h.leaderStatusProvider.IsLeader() && op.Type == OperationTypeWrite {
|
||||||
"rpc", op.Name,
|
return ErrRetryLater
|
||||||
"source_addr", op.SourceAddr,
|
}
|
||||||
"limit_type", l.desc,
|
return ErrRetryElsewhere
|
||||||
"limit_enforced", enforced,
|
|
||||||
)
|
|
||||||
|
|
||||||
metrics.IncrCounterWithLabels([]string{"rpc", "rate_limit", "exceeded"}, 1, []metrics.Label{
|
|
||||||
{
|
|
||||||
Name: "limit_type",
|
|
||||||
Value: l.desc,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Name: "op",
|
|
||||||
Value: op.Name,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Name: "mode",
|
|
||||||
Value: l.mode.String(),
|
|
||||||
},
|
|
||||||
})
|
|
||||||
|
|
||||||
if enforced {
|
|
||||||
if h.leaderStatusProvider.IsLeader() && op.Type == OperationTypeWrite {
|
|
||||||
return ErrRetryLater
|
|
||||||
}
|
}
|
||||||
return ErrRetryElsewhere
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
@ -320,6 +315,23 @@ type limit struct {
|
||||||
desc string
|
desc string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *Handler) allowAllLimits(limits []limit) (bool, []limit) {
|
||||||
|
allow := true
|
||||||
|
throttledLimits := make([]limit, 0)
|
||||||
|
|
||||||
|
for _, l := range limits {
|
||||||
|
if l.mode == ModeDisabled {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if !h.limiter.Allow(l.ent) {
|
||||||
|
throttledLimits = append(throttledLimits, l)
|
||||||
|
allow = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return allow, throttledLimits
|
||||||
|
}
|
||||||
|
|
||||||
// limits returns the limits to check for the given operation (e.g. global +
|
// limits returns the limits to check for the given operation (e.g. global +
|
||||||
// ip-based + tenant-based).
|
// ip-based + tenant-based).
|
||||||
func (h *Handler) limits(op Operation) []limit {
|
func (h *Handler) limits(op Operation) []limit {
|
||||||
|
@ -329,6 +341,14 @@ func (h *Handler) limits(op Operation) []limit {
|
||||||
limits = append(limits, *global)
|
limits = append(limits, *global)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if ipGlobal := h.ipGlobalLimit(op); ipGlobal != nil {
|
||||||
|
limits = append(limits, *ipGlobal)
|
||||||
|
}
|
||||||
|
|
||||||
|
if ipCategory := h.ipCategoryLimit(op); ipCategory != nil {
|
||||||
|
limits = append(limits, *ipCategory)
|
||||||
|
}
|
||||||
|
|
||||||
return limits
|
return limits
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -354,23 +374,23 @@ func (h *Handler) globalLimit(op Operation) *limit {
|
||||||
|
|
||||||
var (
|
var (
|
||||||
// globalWrite identifies the global rate limit applied to write operations.
|
// globalWrite identifies the global rate limit applied to write operations.
|
||||||
globalWrite = globalLimit("global.write")
|
globalWrite = limitedEntity("global.write")
|
||||||
|
|
||||||
// globalRead identifies the global rate limit applied to read operations.
|
// globalRead identifies the global rate limit applied to read operations.
|
||||||
globalRead = globalLimit("global.read")
|
globalRead = limitedEntity("global.read")
|
||||||
|
|
||||||
// globalIPRead identifies the global rate limit applied to read operations.
|
// globalIPRead identifies the global rate limit applied to read operations.
|
||||||
globalIPRead = globalLimit("global.ip.read")
|
globalIPRead = limitedEntity("global.ip.read")
|
||||||
|
|
||||||
// globalIPWrite identifies the global rate limit applied to read operations.
|
// globalIPWrite identifies the global rate limit applied to read operations.
|
||||||
globalIPWrite = globalLimit("global.ip.write")
|
globalIPWrite = limitedEntity("global.ip.write")
|
||||||
)
|
)
|
||||||
|
|
||||||
// globalLimit represents a limit that applies to all writes or reads.
|
// limitedEntity convert the string type to Multilimiter.LimitedEntity
|
||||||
type globalLimit []byte
|
type limitedEntity []byte
|
||||||
|
|
||||||
// Key satisfies the multilimiter.LimitedEntity interface.
|
// Key satisfies the multilimiter.LimitedEntity interface.
|
||||||
func (prefix globalLimit) Key() multilimiter.KeyType {
|
func (prefix limitedEntity) Key() multilimiter.KeyType {
|
||||||
return multilimiter.Key(prefix, nil)
|
return multilimiter.Key(prefix, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -6,9 +6,16 @@
|
||||||
|
|
||||||
package rate
|
package rate
|
||||||
|
|
||||||
type IPLimitConfig struct {
|
type IPLimitConfig struct{}
|
||||||
}
|
|
||||||
|
|
||||||
func (h *Handler) UpdateIPConfig(cfg IPLimitConfig) {
|
func (h *Handler) UpdateIPConfig(cfg IPLimitConfig) {
|
||||||
// noop
|
// noop
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *Handler) ipGlobalLimit(op Operation) *limit {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) ipCategoryLimit(op Operation) *limit {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
|
@ -5,20 +5,18 @@ package rate
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"github.com/hashicorp/consul/agent/metrics"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"golang.org/x/time/rate"
|
"golang.org/x/time/rate"
|
||||||
|
|
||||||
"github.com/stretchr/testify/mock"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
|
|
||||||
"github.com/hashicorp/go-hclog"
|
"github.com/hashicorp/go-hclog"
|
||||||
|
"github.com/stretchr/testify/mock"
|
||||||
|
|
||||||
"github.com/hashicorp/consul/agent/consul/multilimiter"
|
"github.com/hashicorp/consul/agent/consul/multilimiter"
|
||||||
"github.com/hashicorp/consul/agent/metrics"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
//
|
//
|
||||||
|
@ -226,10 +224,10 @@ func TestHandler(t *testing.T) {
|
||||||
for desc, tc := range testCases {
|
for desc, tc := range testCases {
|
||||||
t.Run(desc, func(t *testing.T) {
|
t.Run(desc, func(t *testing.T) {
|
||||||
sink := metrics.TestSetupMetrics(t, "")
|
sink := metrics.TestSetupMetrics(t, "")
|
||||||
limiter := newMockLimiter(t)
|
limiter := multilimiter.NewMockRateLimiter(t)
|
||||||
limiter.On("UpdateConfig", mock.Anything, mock.Anything).Return()
|
limiter.On("UpdateConfig", mock.Anything, mock.Anything).Return()
|
||||||
for _, c := range tc.checks {
|
for _, c := range tc.checks {
|
||||||
limiter.On("Allow", c.limit).Return(c.allow)
|
limiter.On("Allow", mock.Anything).Return(c.allow)
|
||||||
}
|
}
|
||||||
|
|
||||||
leaderStatusProvider := NewMockLeaderStatusProvider(t)
|
leaderStatusProvider := NewMockLeaderStatusProvider(t)
|
||||||
|
@ -376,7 +374,7 @@ func TestAllow(t *testing.T) {
|
||||||
type testCase struct {
|
type testCase struct {
|
||||||
description string
|
description string
|
||||||
cfg *HandlerConfig
|
cfg *HandlerConfig
|
||||||
expectedAllowCalls int
|
expectedAllowCalls bool
|
||||||
}
|
}
|
||||||
testCases := []testCase{
|
testCases := []testCase{
|
||||||
{
|
{
|
||||||
|
@ -390,7 +388,7 @@ func TestAllow(t *testing.T) {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
expectedAllowCalls: 0,
|
expectedAllowCalls: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "RateLimiter gets called when mode is permissive.",
|
description: "RateLimiter gets called when mode is permissive.",
|
||||||
|
@ -403,7 +401,7 @@ func TestAllow(t *testing.T) {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
expectedAllowCalls: 1,
|
expectedAllowCalls: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "RateLimiter gets called when mode is enforcing.",
|
description: "RateLimiter gets called when mode is enforcing.",
|
||||||
|
@ -416,14 +414,14 @@ func TestAllow(t *testing.T) {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
expectedAllowCalls: 1,
|
expectedAllowCalls: true,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
t.Run(tc.description, func(t *testing.T) {
|
t.Run(tc.description, func(t *testing.T) {
|
||||||
mockRateLimiter := multilimiter.NewMockRateLimiter(t)
|
mockRateLimiter := multilimiter.NewMockRateLimiter(t)
|
||||||
if tc.expectedAllowCalls > 0 {
|
if tc.expectedAllowCalls {
|
||||||
mockRateLimiter.On("Allow", mock.Anything).Return(func(entity multilimiter.LimitedEntity) bool { return true })
|
mockRateLimiter.On("Allow", mock.Anything).Return(func(entity multilimiter.LimitedEntity) bool { return true })
|
||||||
}
|
}
|
||||||
mockRateLimiter.On("UpdateConfig", mock.Anything, mock.Anything).Return()
|
mockRateLimiter.On("UpdateConfig", mock.Anything, mock.Anything).Return()
|
||||||
|
@ -435,31 +433,7 @@ func TestAllow(t *testing.T) {
|
||||||
addr := net.TCPAddrFromAddrPort(netip.MustParseAddrPort("127.0.0.1:1234"))
|
addr := net.TCPAddrFromAddrPort(netip.MustParseAddrPort("127.0.0.1:1234"))
|
||||||
mockRateLimiter.Calls = nil
|
mockRateLimiter.Calls = nil
|
||||||
handler.Allow(Operation{Name: "test", SourceAddr: addr})
|
handler.Allow(Operation{Name: "test", SourceAddr: addr})
|
||||||
mockRateLimiter.AssertNumberOfCalls(t, "Allow", tc.expectedAllowCalls)
|
mockRateLimiter.AssertExpectations(t)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ multilimiter.RateLimiter = (*mockLimiter)(nil)
|
|
||||||
|
|
||||||
func newMockLimiter(t *testing.T) *mockLimiter {
|
|
||||||
l := &mockLimiter{}
|
|
||||||
l.Mock.Test(t)
|
|
||||||
|
|
||||||
t.Cleanup(func() { l.AssertExpectations(t) })
|
|
||||||
|
|
||||||
return l
|
|
||||||
}
|
|
||||||
|
|
||||||
type mockLimiter struct {
|
|
||||||
mock.Mock
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockLimiter) Allow(v multilimiter.LimitedEntity) bool { return m.Called(v).Bool(0) }
|
|
||||||
func (m *mockLimiter) Run(ctx context.Context) { m.Called(ctx) }
|
|
||||||
func (m *mockLimiter) UpdateConfig(cfg multilimiter.LimiterConfig, prefix []byte) {
|
|
||||||
m.Called(cfg, prefix)
|
|
||||||
}
|
|
||||||
func (m *mockLimiter) DeleteConfig(prefix []byte) {
|
|
||||||
m.Called(prefix)
|
|
||||||
}
|
|
||||||
|
|
Loading…
Reference in New Issue