diff --git a/agent/consul/multilimiter/mock_RateLimiter.go b/agent/consul/multilimiter/mock_RateLimiter.go new file mode 100644 index 000000000..d36c2e161 --- /dev/null +++ b/agent/consul/multilimiter/mock_RateLimiter.go @@ -0,0 +1,53 @@ +// Code generated by mockery v2.15.0. DO NOT EDIT. + +package multilimiter + +import ( + context "context" + + mock "github.com/stretchr/testify/mock" +) + +// MockRateLimiter is an autogenerated mock type for the RateLimiter type +type MockRateLimiter struct { + mock.Mock +} + +// Allow provides a mock function with given fields: entity +func (_m *MockRateLimiter) Allow(entity LimitedEntity) bool { + ret := _m.Called(entity) + + var r0 bool + if rf, ok := ret.Get(0).(func(LimitedEntity) bool); ok { + r0 = rf(entity) + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + +// Run provides a mock function with given fields: ctx +func (_m *MockRateLimiter) Run(ctx context.Context) { + _m.Called(ctx) +} + +// UpdateConfig provides a mock function with given fields: c, prefix +func (_m *MockRateLimiter) UpdateConfig(c LimiterConfig, prefix []byte) { + _m.Called(c, prefix) +} + +type mockConstructorTestingTNewMockRateLimiter interface { + mock.TestingT + Cleanup(func()) +} + +// NewMockRateLimiter creates a new instance of MockRateLimiter. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +func NewMockRateLimiter(t mockConstructorTestingTNewMockRateLimiter) *MockRateLimiter { + mock := &MockRateLimiter{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/agent/consul/multilimiter/multilimiter.go b/agent/consul/multilimiter/multilimiter.go index ff30b48b9..c66948b26 100644 --- a/agent/consul/multilimiter/multilimiter.go +++ b/agent/consul/multilimiter/multilimiter.go @@ -3,11 +3,12 @@ package multilimiter import ( "bytes" "context" - radix "github.com/hashicorp/go-immutable-radix" - "golang.org/x/time/rate" "sync" "sync/atomic" "time" + + radix "github.com/hashicorp/go-immutable-radix" + "golang.org/x/time/rate" ) var _ RateLimiter = &MultiLimiter{} @@ -23,6 +24,8 @@ func Key(prefix, key []byte) KeyType { } // RateLimiter is the interface implemented by MultiLimiter +// +//go:generate mockery --name RateLimiter --inpackage --filename mock_RateLimiter.go type RateLimiter interface { Run(ctx context.Context) Allow(entity LimitedEntity) bool diff --git a/agent/consul/rate/handler.go b/agent/consul/rate/handler.go index 8bfa7b247..0c9b0ccd2 100644 --- a/agent/consul/rate/handler.go +++ b/agent/consul/rate/handler.go @@ -5,9 +5,11 @@ import ( "context" "errors" "net" + "reflect" "sync/atomic" "github.com/hashicorp/consul/agent/consul/multilimiter" + "github.com/hashicorp/go-hclog" ) var ( @@ -114,6 +116,7 @@ type Handler struct { delegate HandlerDelegate limiter multilimiter.RateLimiter + logger hclog.Logger } type HandlerConfig struct { @@ -140,9 +143,8 @@ type HandlerDelegate interface { IsLeader() bool } -// NewHandler creates a new RPC rate limit handler. -func NewHandler(cfg HandlerConfig, delegate HandlerDelegate) *Handler { - limiter := multilimiter.NewMultiLimiter(cfg.Config) +func NewHandlerWithLimiter(cfg HandlerConfig, delegate HandlerDelegate, + limiter multilimiter.RateLimiter, logger hclog.Logger) *Handler { limiter.UpdateConfig(cfg.GlobalWriteConfig, globalWrite) limiter.UpdateConfig(cfg.GlobalReadConfig, globalRead) @@ -150,12 +152,19 @@ func NewHandler(cfg HandlerConfig, delegate HandlerDelegate) *Handler { cfg: new(atomic.Pointer[HandlerConfig]), delegate: delegate, limiter: limiter, + logger: logger, } h.cfg.Store(&cfg) return h } +// NewHandler creates a new RPC rate limit handler. +func NewHandler(cfg HandlerConfig, delegate HandlerDelegate, logger hclog.Logger) *Handler { + limiter := multilimiter.NewMultiLimiter(cfg.Config) + return NewHandlerWithLimiter(cfg, delegate, limiter, logger) +} + // Run the limiter cleanup routine until the given context is canceled. // // Note: this starts a goroutine. @@ -175,9 +184,18 @@ func (h *Handler) Allow(op Operation) error { } func (h *Handler) UpdateConfig(cfg HandlerConfig) { + existingCfg := h.cfg.Load() h.cfg.Store(&cfg) - h.limiter.UpdateConfig(cfg.GlobalWriteConfig, globalWrite) - h.limiter.UpdateConfig(cfg.GlobalReadConfig, globalRead) + if reflect.DeepEqual(existingCfg, cfg) { + h.logger.Warn("UpdateConfig called but configuration has not changed. Skipping updating the server rate limiter configuration.") + return + } + if !reflect.DeepEqual(existingCfg.GlobalWriteConfig, cfg.GlobalWriteConfig) { + h.limiter.UpdateConfig(cfg.GlobalWriteConfig, globalWrite) + } + if !reflect.DeepEqual(existingCfg.GlobalReadConfig, cfg.GlobalReadConfig) { + h.limiter.UpdateConfig(cfg.GlobalReadConfig, globalRead) + } } var ( diff --git a/agent/consul/rate/handler_test.go b/agent/consul/rate/handler_test.go new file mode 100644 index 000000000..76fefe818 --- /dev/null +++ b/agent/consul/rate/handler_test.go @@ -0,0 +1,90 @@ +package rate + +import ( + "testing" + + "github.com/hashicorp/consul/agent/consul/multilimiter" + "github.com/hashicorp/go-hclog" + mock "github.com/stretchr/testify/mock" +) + +func TestNewHandlerWithLimiter_CallsUpdateConfig(t *testing.T) { + mockRateLimiter := multilimiter.NewMockRateLimiter(t) + mockRateLimiter.On("UpdateConfig", mock.Anything, mock.Anything).Return() + readCfg := multilimiter.LimiterConfig{Rate: 100, Burst: 100} + writeCfg := multilimiter.LimiterConfig{Rate: 99, Burst: 99} + cfg := &HandlerConfig{ + GlobalReadConfig: readCfg, + GlobalWriteConfig: writeCfg, + GlobalMode: ModeEnforcing, + } + logger := hclog.NewNullLogger() + NewHandlerWithLimiter(*cfg, nil, mockRateLimiter, logger) + mockRateLimiter.AssertNumberOfCalls(t, "UpdateConfig", 2) +} + +func TestUpdateConfig(t *testing.T) { + type testCase struct { + description string + configModFunc func(cfg *HandlerConfig) + assertFunc func(mockRateLimiter *multilimiter.MockRateLimiter, cfg *HandlerConfig) + } + testCases := []testCase{ + { + description: "RateLimiter does not get updated when config does not change.", + configModFunc: func(cfg *HandlerConfig) {}, + assertFunc: func(mockRateLimiter *multilimiter.MockRateLimiter, cfg *HandlerConfig) { + mockRateLimiter.AssertNumberOfCalls(t, "UpdateConfig", 0) + }, + }, + { + description: "RateLimiter gets updated when GlobalReadConfig changes.", + configModFunc: func(cfg *HandlerConfig) { + cfg.GlobalReadConfig.Burst++ + }, + assertFunc: func(mockRateLimiter *multilimiter.MockRateLimiter, cfg *HandlerConfig) { + mockRateLimiter.AssertNumberOfCalls(t, "UpdateConfig", 1) + mockRateLimiter.AssertCalled(t, "UpdateConfig", cfg.GlobalReadConfig, []byte("global.read")) + }, + }, + { + description: "RateLimiter gets updated when GlobalWriteConfig changes.", + configModFunc: func(cfg *HandlerConfig) { + cfg.GlobalWriteConfig.Burst++ + }, + assertFunc: func(mockRateLimiter *multilimiter.MockRateLimiter, cfg *HandlerConfig) { + mockRateLimiter.AssertNumberOfCalls(t, "UpdateConfig", 1) + mockRateLimiter.AssertCalled(t, "UpdateConfig", cfg.GlobalWriteConfig, []byte("global.write")) + }, + }, + { + description: "RateLimiter does not get updated when GlobalMode changes.", + configModFunc: func(cfg *HandlerConfig) { + cfg.GlobalMode = ModePermissive + }, + assertFunc: func(mockRateLimiter *multilimiter.MockRateLimiter, cfg *HandlerConfig) { + mockRateLimiter.AssertNumberOfCalls(t, "UpdateConfig", 0) + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.description, func(t *testing.T) { + readCfg := multilimiter.LimiterConfig{Rate: 100, Burst: 100} + writeCfg := multilimiter.LimiterConfig{Rate: 99, Burst: 99} + cfg := &HandlerConfig{ + GlobalReadConfig: readCfg, + GlobalWriteConfig: writeCfg, + GlobalMode: ModeEnforcing, + } + mockRateLimiter := multilimiter.NewMockRateLimiter(t) + mockRateLimiter.On("UpdateConfig", mock.Anything, mock.Anything).Return() + logger := hclog.NewNullLogger() + handler := NewHandlerWithLimiter(*cfg, nil, mockRateLimiter, logger) + mockRateLimiter.Calls = nil + tc.configModFunc(cfg) + handler.UpdateConfig(*cfg) + tc.assertFunc(mockRateLimiter, cfg) + }) + } +} diff --git a/agent/consul/server.go b/agent/consul/server.go index cf16f5d01..0aea2782c 100644 --- a/agent/consul/server.go +++ b/agent/consul/server.go @@ -480,7 +480,7 @@ func NewServer(config *Config, flat Deps, externalGRPCServer *grpc.Server) (*Ser WriteRate: config.RequestLimitsWriteRate, } - s.incomingRPCLimiter = rpcRate.NewHandler(*s.convertConsulConfigToRateLimitHandlerConfig(*limitsConfig, mlCfg), s) + s.incomingRPCLimiter = rpcRate.NewHandler(*s.convertConsulConfigToRateLimitHandlerConfig(*limitsConfig, mlCfg), s, s.logger) } s.incomingRPCLimiter.Run(&lib.StopChannelContext{StopCh: s.shutdownCh})