diff --git a/agent/consul/rate/handler.go b/agent/consul/rate/handler.go index 2573894b1..12ed974c6 100644 --- a/agent/consul/rate/handler.go +++ b/agent/consul/rate/handler.go @@ -1,7 +1,7 @@ // Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: MPL-2.0 -// package rate implements server-side RPC rate limiting. +// Package rate implements server-side RPC rate limiting. package rate import ( @@ -14,6 +14,7 @@ import ( "github.com/armon/go-metrics" "github.com/hashicorp/go-hclog" + "golang.org/x/time/rate" "github.com/hashicorp/consul/agent/consul/multilimiter" ) @@ -54,7 +55,7 @@ var modeToName = map[Mode]string{ ModeEnforcing: "enforcing", ModePermissive: "permissive", } -var modeFromName = func() map[string]Mode { +var ModeFromName = func() map[string]Mode { vals := map[string]Mode{ "": ModeDisabled, } @@ -70,13 +71,13 @@ func (m Mode) String() string { // RequestLimitsModeFromName will unmarshal the string form of a configMode. func RequestLimitsModeFromName(name string) (Mode, bool) { - s, ok := modeFromName[name] + s, ok := ModeFromName[name] return s, ok } // RequestLimitsModeFromNameWithDefault will unmarshal the string form of a configMode. func RequestLimitsModeFromNameWithDefault(name string) Mode { - s, ok := modeFromName[name] + s, ok := ModeFromName[name] if !ok { return ModePermissive } @@ -151,12 +152,14 @@ type RequestLimitsHandler interface { Run(ctx context.Context) Allow(op Operation) error UpdateConfig(cfg HandlerConfig) + UpdateIPConfig(cfg IPLimitConfig) Register(leaderStatusProvider LeaderStatusProvider) } // Handler enforces rate limits for incoming RPCs. type Handler struct { - cfg *atomic.Pointer[HandlerConfig] + globalCfg *atomic.Pointer[HandlerConfig] + ipCfg *atomic.Pointer[IPLimitConfig] leaderStatusProvider LeaderStatusProvider limiter multilimiter.RateLimiter @@ -164,20 +167,23 @@ type Handler struct { logger hclog.Logger } +type ReadWriteConfig struct { + // WriteConfig configures the global rate limiter for write operations. + WriteConfig multilimiter.LimiterConfig + + // ReadConfig configures the global rate limiter for read operations. + ReadConfig multilimiter.LimiterConfig +} + +type GlobalLimitConfig struct { + Mode Mode + ReadWriteConfig +} + type HandlerConfig struct { multilimiter.Config - // GlobalMode configures the action that will be taken when a global rate-limit - // has been exhausted. - // - // Note: in the future there'll be a separate Mode for IP-based limits. - GlobalMode Mode - - // GlobalWriteConfig configures the global rate limiter for write operations. - GlobalWriteConfig multilimiter.LimiterConfig - - // GlobalReadConfig configures the global rate limiter for read operations. - GlobalReadConfig multilimiter.LimiterConfig + GlobalLimitConfig GlobalLimitConfig } //go:generate mockery --name LeaderStatusProvider --inpackage --filename mock_LeaderStatusProvider_test.go @@ -189,20 +195,26 @@ type LeaderStatusProvider interface { IsLeader() bool } +func isInfRate(cfg multilimiter.LimiterConfig) bool { + return cfg.Rate == rate.Inf +} + func NewHandlerWithLimiter( cfg HandlerConfig, limiter multilimiter.RateLimiter, logger hclog.Logger) *Handler { - limiter.UpdateConfig(cfg.GlobalWriteConfig, globalWrite) - limiter.UpdateConfig(cfg.GlobalReadConfig, globalRead) + limiter.UpdateConfig(cfg.GlobalLimitConfig.WriteConfig, globalWrite) + + limiter.UpdateConfig(cfg.GlobalLimitConfig.ReadConfig, globalRead) h := &Handler{ - cfg: new(atomic.Pointer[HandlerConfig]), - limiter: limiter, - logger: logger, + ipCfg: new(atomic.Pointer[IPLimitConfig]), + globalCfg: new(atomic.Pointer[HandlerConfig]), + limiter: limiter, + logger: logger, } - h.cfg.Store(&cfg) + h.globalCfg.Store(&cfg) return h } @@ -231,8 +243,8 @@ func (h *Handler) Allow(op Operation) error { // panic("leaderStatusProvider required to be set via Register(..)") } - cfg := h.cfg.Load() - if cfg.GlobalMode == ModeDisabled { + cfg := h.globalCfg.Load() + if cfg.GlobalLimitConfig.Mode == ModeDisabled { return nil } @@ -281,18 +293,21 @@ func (h *Handler) Allow(op Operation) error { } func (h *Handler) UpdateConfig(cfg HandlerConfig) { - existingCfg := h.cfg.Load() - h.cfg.Store(&cfg) - if reflect.DeepEqual(existingCfg, cfg) { + existingCfg := h.globalCfg.Load() + h.globalCfg.Store(&cfg) + if reflect.DeepEqual(existingCfg, &cfg) { h.logger.Warn("UpdateConfig called but configuration has not changed. Skipping updating the server rate limiter configuration.") return } - if !reflect.DeepEqual(existingCfg.GlobalWriteConfig, cfg.GlobalWriteConfig) { - h.limiter.UpdateConfig(cfg.GlobalWriteConfig, globalWrite) + + if !reflect.DeepEqual(existingCfg.GlobalLimitConfig.WriteConfig, cfg.GlobalLimitConfig.WriteConfig) { + h.limiter.UpdateConfig(cfg.GlobalLimitConfig.WriteConfig, globalWrite) } - if !reflect.DeepEqual(existingCfg.GlobalReadConfig, cfg.GlobalReadConfig) { - h.limiter.UpdateConfig(cfg.GlobalReadConfig, globalRead) + + if !reflect.DeepEqual(existingCfg.GlobalLimitConfig.ReadConfig, cfg.GlobalLimitConfig.ReadConfig) { + h.limiter.UpdateConfig(cfg.GlobalLimitConfig.ReadConfig, globalRead) } + } func (h *Handler) Register(leaderStatusProvider LeaderStatusProvider) { @@ -321,9 +336,9 @@ func (h *Handler) globalLimit(op Operation) *limit { if op.Type == OperationTypeExempt { return nil } - cfg := h.cfg.Load() + cfg := h.globalCfg.Load() - lim := &limit{mode: cfg.GlobalMode} + lim := &limit{mode: cfg.GlobalLimitConfig.Mode} switch op.Type { case OperationTypeRead: lim.desc = "global/read" @@ -343,6 +358,12 @@ var ( // globalRead identifies the global rate limit applied to read operations. globalRead = globalLimit("global.read") + + // globalIPRead identifies the global rate limit applied to read operations. + globalIPRead = globalLimit("global.ip.read") + + // globalIPWrite identifies the global rate limit applied to read operations. + globalIPWrite = globalLimit("global.ip.write") ) // globalLimit represents a limit that applies to all writes or reads. @@ -360,10 +381,12 @@ func NullRequestLimitsHandler() RequestLimitsHandler { type nullRequestLimitsHandler struct{} +func (h nullRequestLimitsHandler) UpdateIPConfig(cfg IPLimitConfig) {} + func (nullRequestLimitsHandler) Allow(Operation) error { return nil } -func (nullRequestLimitsHandler) Run(ctx context.Context) {} +func (nullRequestLimitsHandler) Run(_ context.Context) {} -func (nullRequestLimitsHandler) UpdateConfig(cfg HandlerConfig) {} +func (nullRequestLimitsHandler) UpdateConfig(_ HandlerConfig) {} -func (nullRequestLimitsHandler) Register(leaderStatusProvider LeaderStatusProvider) {} +func (nullRequestLimitsHandler) Register(_ LeaderStatusProvider) {} diff --git a/agent/consul/rate/handler_oss.go b/agent/consul/rate/handler_oss.go new file mode 100644 index 000000000..127064381 --- /dev/null +++ b/agent/consul/rate/handler_oss.go @@ -0,0 +1,11 @@ +//go:build !consulent +// +build !consulent + +package rate + +type IPLimitConfig struct { +} + +func (h *Handler) UpdateIPConfig(cfg IPLimitConfig) { + // noop +} diff --git a/agent/consul/rate/handler_test.go b/agent/consul/rate/handler_test.go index 3b610a147..8f1b465f4 100644 --- a/agent/consul/rate/handler_test.go +++ b/agent/consul/rate/handler_test.go @@ -10,6 +10,8 @@ import ( "net/netip" "testing" + "golang.org/x/time/rate" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" @@ -241,7 +243,13 @@ func TestHandler(t *testing.T) { handler := NewHandlerWithLimiter( HandlerConfig{ - GlobalMode: tc.globalMode, + GlobalLimitConfig: GlobalLimitConfig{ + Mode: tc.globalMode, + ReadWriteConfig: ReadWriteConfig{ + ReadConfig: multilimiter.LimiterConfig{}, + WriteConfig: multilimiter.LimiterConfig{}, + }, + }, }, limiter, logger, @@ -269,15 +277,26 @@ func TestNewHandlerWithLimiter_CallsUpdateConfig(t *testing.T) { readCfg := multilimiter.LimiterConfig{Rate: 100, Burst: 100} writeCfg := multilimiter.LimiterConfig{Rate: 99, Burst: 99} cfg := &HandlerConfig{ - GlobalReadConfig: readCfg, - GlobalWriteConfig: writeCfg, - GlobalMode: ModeEnforcing, + GlobalLimitConfig: GlobalLimitConfig{ + Mode: ModeEnforcing, + ReadWriteConfig: ReadWriteConfig{ + ReadConfig: readCfg, + WriteConfig: writeCfg, + }, + }, } logger := hclog.NewNullLogger() NewHandlerWithLimiter(*cfg, mockRateLimiter, logger) mockRateLimiter.AssertNumberOfCalls(t, "UpdateConfig", 2) } +func infReadRateConfig() ReadWriteConfig { + return ReadWriteConfig{ + ReadConfig: multilimiter.LimiterConfig{Rate: rate.Inf}, + WriteConfig: multilimiter.LimiterConfig{Rate: rate.Inf}, + } +} + func TestUpdateConfig(t *testing.T) { type testCase struct { description string @@ -295,27 +314,29 @@ func TestUpdateConfig(t *testing.T) { { description: "RateLimiter gets updated when GlobalReadConfig changes.", configModFunc: func(cfg *HandlerConfig) { - cfg.GlobalReadConfig.Burst++ + rc := multilimiter.LimiterConfig{Rate: cfg.GlobalLimitConfig.ReadWriteConfig.ReadConfig.Rate, Burst: cfg.GlobalLimitConfig.ReadWriteConfig.ReadConfig.Burst + 1} + cfg.GlobalLimitConfig.ReadWriteConfig.ReadConfig = rc }, assertFunc: func(mockRateLimiter *multilimiter.MockRateLimiter, cfg *HandlerConfig) { mockRateLimiter.AssertNumberOfCalls(t, "UpdateConfig", 1) - mockRateLimiter.AssertCalled(t, "UpdateConfig", cfg.GlobalReadConfig, []byte("global.read")) + mockRateLimiter.AssertCalled(t, "UpdateConfig", cfg.GlobalLimitConfig.ReadWriteConfig.ReadConfig, []byte("global.read")) }, }, { description: "RateLimiter gets updated when GlobalWriteConfig changes.", configModFunc: func(cfg *HandlerConfig) { - cfg.GlobalWriteConfig.Burst++ + wc := multilimiter.LimiterConfig{Rate: cfg.GlobalLimitConfig.ReadWriteConfig.WriteConfig.Rate, Burst: cfg.GlobalLimitConfig.ReadWriteConfig.WriteConfig.Burst + 1} + cfg.GlobalLimitConfig.ReadWriteConfig.WriteConfig = wc }, assertFunc: func(mockRateLimiter *multilimiter.MockRateLimiter, cfg *HandlerConfig) { mockRateLimiter.AssertNumberOfCalls(t, "UpdateConfig", 1) - mockRateLimiter.AssertCalled(t, "UpdateConfig", cfg.GlobalWriteConfig, []byte("global.write")) + mockRateLimiter.AssertCalled(t, "UpdateConfig", cfg.GlobalLimitConfig.ReadWriteConfig.WriteConfig, []byte("global.write")) }, }, { description: "RateLimiter does not get updated when GlobalMode changes.", configModFunc: func(cfg *HandlerConfig) { - cfg.GlobalMode = ModePermissive + cfg.GlobalLimitConfig.Mode = ModePermissive }, assertFunc: func(mockRateLimiter *multilimiter.MockRateLimiter, cfg *HandlerConfig) { mockRateLimiter.AssertNumberOfCalls(t, "UpdateConfig", 0) @@ -328,9 +349,13 @@ func TestUpdateConfig(t *testing.T) { readCfg := multilimiter.LimiterConfig{Rate: 100, Burst: 100} writeCfg := multilimiter.LimiterConfig{Rate: 99, Burst: 99} cfg := &HandlerConfig{ - GlobalReadConfig: readCfg, - GlobalWriteConfig: writeCfg, - GlobalMode: ModeEnforcing, + GlobalLimitConfig: GlobalLimitConfig{ + Mode: ModeEnforcing, + ReadWriteConfig: ReadWriteConfig{ + ReadConfig: readCfg, + WriteConfig: writeCfg, + }, + }, } mockRateLimiter := multilimiter.NewMockRateLimiter(t) mockRateLimiter.On("UpdateConfig", mock.Anything, mock.Anything).Return() @@ -357,27 +382,39 @@ func TestAllow(t *testing.T) { { description: "RateLimiter does not get called when mode is disabled.", cfg: &HandlerConfig{ - GlobalReadConfig: readCfg, - GlobalWriteConfig: writeCfg, - GlobalMode: ModeDisabled, + GlobalLimitConfig: GlobalLimitConfig{ + Mode: ModeDisabled, + ReadWriteConfig: ReadWriteConfig{ + ReadConfig: readCfg, + WriteConfig: writeCfg, + }, + }, }, expectedAllowCalls: 0, }, { description: "RateLimiter gets called when mode is permissive.", cfg: &HandlerConfig{ - GlobalReadConfig: readCfg, - GlobalWriteConfig: writeCfg, - GlobalMode: ModePermissive, + GlobalLimitConfig: GlobalLimitConfig{ + Mode: ModePermissive, + ReadWriteConfig: ReadWriteConfig{ + ReadConfig: readCfg, + WriteConfig: writeCfg, + }, + }, }, expectedAllowCalls: 1, }, { description: "RateLimiter gets called when mode is enforcing.", cfg: &HandlerConfig{ - GlobalReadConfig: readCfg, - GlobalWriteConfig: writeCfg, - GlobalMode: ModeEnforcing, + GlobalLimitConfig: GlobalLimitConfig{ + Mode: ModeEnforcing, + ReadWriteConfig: ReadWriteConfig{ + ReadConfig: readCfg, + WriteConfig: writeCfg, + }, + }, }, expectedAllowCalls: 1, }, diff --git a/agent/consul/rate/mock_LeaderStatusProvider_test.go b/agent/consul/rate/mock_LeaderStatusProvider_test.go index 2c7f1b6cb..92af311b1 100644 --- a/agent/consul/rate/mock_LeaderStatusProvider_test.go +++ b/agent/consul/rate/mock_LeaderStatusProvider_test.go @@ -1,12 +1,8 @@ -// Code generated by mockery v2.12.2. DO NOT EDIT. +// Code generated by mockery v2.20.0. DO NOT EDIT. package rate -import ( - testing "testing" - - mock "github.com/stretchr/testify/mock" -) +import mock "github.com/stretchr/testify/mock" // MockLeaderStatusProvider is an autogenerated mock type for the LeaderStatusProvider type type MockLeaderStatusProvider struct { @@ -27,8 +23,13 @@ func (_m *MockLeaderStatusProvider) IsLeader() bool { return r0 } -// NewMockLeaderStatusProvider creates a new instance of MockLeaderStatusProvider. It also registers the testing.TB interface on the mock and a cleanup function to assert the mocks expectations. -func NewMockLeaderStatusProvider(t testing.TB) *MockLeaderStatusProvider { +type mockConstructorTestingTNewMockLeaderStatusProvider interface { + mock.TestingT + Cleanup(func()) +} + +// NewMockLeaderStatusProvider creates a new instance of MockLeaderStatusProvider. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +func NewMockLeaderStatusProvider(t mockConstructorTestingTNewMockLeaderStatusProvider) *MockLeaderStatusProvider { mock := &MockLeaderStatusProvider{} mock.Mock.Test(t) diff --git a/agent/consul/rate/mock_RequestLimitsHandler.go b/agent/consul/rate/mock_RequestLimitsHandler.go index 9ff0e3baa..dcd131af4 100644 --- a/agent/consul/rate/mock_RequestLimitsHandler.go +++ b/agent/consul/rate/mock_RequestLimitsHandler.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.15.0. DO NOT EDIT. +// Code generated by mockery v2.20.0. DO NOT EDIT. package rate @@ -27,6 +27,11 @@ func (_m *MockRequestLimitsHandler) Allow(op Operation) error { return r0 } +// Register provides a mock function with given fields: leaderStatusProvider +func (_m *MockRequestLimitsHandler) Register(leaderStatusProvider LeaderStatusProvider) { + _m.Called(leaderStatusProvider) +} + // Run provides a mock function with given fields: ctx func (_m *MockRequestLimitsHandler) Run(ctx context.Context) { _m.Called(ctx) @@ -37,9 +42,9 @@ func (_m *MockRequestLimitsHandler) UpdateConfig(cfg HandlerConfig) { _m.Called(cfg) } -// Register provides a mock function with given fields: leaderStatusProvider -func (_m *MockRequestLimitsHandler) Register(leaderStatusProvider LeaderStatusProvider) { - _m.Called(leaderStatusProvider) +// UpdateIPConfig provides a mock function with given fields: cfg +func (_m *MockRequestLimitsHandler) UpdateIPConfig(cfg IPLimitConfig) { + _m.Called(cfg) } type mockConstructorTestingTNewMockRequestLimitsHandler interface { diff --git a/agent/consul/server.go b/agent/consul/server.go index 5347b0a78..ec365ef39 100644 --- a/agent/consul/server.go +++ b/agent/consul/server.go @@ -1997,14 +1997,18 @@ func ConfiguredIncomingRPCLimiter(ctx context.Context, serverLogger hclog.Interc func convertConsulConfigToRateLimitHandlerConfig(limitsConfig RequestLimits, multilimiterConfig *multilimiter.Config) *rpcRate.HandlerConfig { hc := &rpcRate.HandlerConfig{ - GlobalMode: limitsConfig.Mode, - GlobalReadConfig: multilimiter.LimiterConfig{ - Rate: limitsConfig.ReadRate, - Burst: int(limitsConfig.ReadRate) * requestLimitsBurstMultiplier, - }, - GlobalWriteConfig: multilimiter.LimiterConfig{ - Rate: limitsConfig.WriteRate, - Burst: int(limitsConfig.WriteRate) * requestLimitsBurstMultiplier, + GlobalLimitConfig: rpcRate.GlobalLimitConfig{ + Mode: limitsConfig.Mode, + ReadWriteConfig: rpcRate.ReadWriteConfig{ + ReadConfig: multilimiter.LimiterConfig{ + Rate: limitsConfig.ReadRate, + Burst: int(limitsConfig.ReadRate) * requestLimitsBurstMultiplier, + }, + WriteConfig: multilimiter.LimiterConfig{ + Rate: limitsConfig.WriteRate, + Burst: int(limitsConfig.WriteRate) * requestLimitsBurstMultiplier, + }, + }, }, } if multilimiterConfig != nil { diff --git a/agent/consul/server_test.go b/agent/consul/server_test.go index 5fcbb4533..c246428b2 100644 --- a/agent/consul/server_test.go +++ b/agent/consul/server_test.go @@ -1884,14 +1884,18 @@ func TestServer_ReloadConfig(t *testing.T) { // Check the incoming RPC rate limiter got updated mockHandler.AssertCalled(t, "UpdateConfig", rpcRate.HandlerConfig{ - GlobalMode: rc.RequestLimits.Mode, - GlobalReadConfig: multilimiter.LimiterConfig{ - Rate: rc.RequestLimits.ReadRate, - Burst: int(rc.RequestLimits.ReadRate) * requestLimitsBurstMultiplier, - }, - GlobalWriteConfig: multilimiter.LimiterConfig{ - Rate: rc.RequestLimits.WriteRate, - Burst: int(rc.RequestLimits.WriteRate) * requestLimitsBurstMultiplier, + GlobalLimitConfig: rpcRate.GlobalLimitConfig{ + Mode: rc.RequestLimits.Mode, + ReadWriteConfig: rpcRate.ReadWriteConfig{ + ReadConfig: multilimiter.LimiterConfig{ + Rate: rc.RequestLimits.ReadRate, + Burst: int(rc.RequestLimits.ReadRate) * requestLimitsBurstMultiplier, + }, + WriteConfig: multilimiter.LimiterConfig{ + Rate: rc.RequestLimits.WriteRate, + Burst: int(rc.RequestLimits.WriteRate) * requestLimitsBurstMultiplier, + }, + }, }, })