Rate limiting handler - ensure configuration has changed before modifying limiters (#15805)

* Rate limiting handler - ensure configuration has changed before modifying limiters

* Updating test to validate arguments to UpdateConfig

* Removing duplicate test.  Updating mock.

* adding logging for when UpdateConfig is called but the config has not changed.

* Update agent/consul/rate/handler.go

Co-authored-by: Dhia Ayachi <dhia@hashicorp.com>

Co-authored-by: Dhia Ayachi <dhia@hashicorp.com>
This commit is contained in:
John Murret 2022-12-20 14:12:03 -07:00 committed by GitHub
parent a0a8a205c5
commit 8c33d7cc0e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 172 additions and 8 deletions

View File

@ -0,0 +1,53 @@
// Code generated by mockery v2.15.0. DO NOT EDIT.
package multilimiter
import (
context "context"
mock "github.com/stretchr/testify/mock"
)
// MockRateLimiter is an autogenerated mock type for the RateLimiter type
type MockRateLimiter struct {
mock.Mock
}
// Allow provides a mock function with given fields: entity
func (_m *MockRateLimiter) Allow(entity LimitedEntity) bool {
ret := _m.Called(entity)
var r0 bool
if rf, ok := ret.Get(0).(func(LimitedEntity) bool); ok {
r0 = rf(entity)
} else {
r0 = ret.Get(0).(bool)
}
return r0
}
// Run provides a mock function with given fields: ctx
func (_m *MockRateLimiter) Run(ctx context.Context) {
_m.Called(ctx)
}
// UpdateConfig provides a mock function with given fields: c, prefix
func (_m *MockRateLimiter) UpdateConfig(c LimiterConfig, prefix []byte) {
_m.Called(c, prefix)
}
type mockConstructorTestingTNewMockRateLimiter interface {
mock.TestingT
Cleanup(func())
}
// NewMockRateLimiter creates a new instance of MockRateLimiter. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
func NewMockRateLimiter(t mockConstructorTestingTNewMockRateLimiter) *MockRateLimiter {
mock := &MockRateLimiter{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}

View File

@ -3,11 +3,12 @@ package multilimiter
import (
"bytes"
"context"
radix "github.com/hashicorp/go-immutable-radix"
"golang.org/x/time/rate"
"sync"
"sync/atomic"
"time"
radix "github.com/hashicorp/go-immutable-radix"
"golang.org/x/time/rate"
)
var _ RateLimiter = &MultiLimiter{}
@ -23,6 +24,8 @@ func Key(prefix, key []byte) KeyType {
}
// RateLimiter is the interface implemented by MultiLimiter
//
//go:generate mockery --name RateLimiter --inpackage --filename mock_RateLimiter.go
type RateLimiter interface {
Run(ctx context.Context)
Allow(entity LimitedEntity) bool

View File

@ -5,9 +5,11 @@ import (
"context"
"errors"
"net"
"reflect"
"sync/atomic"
"github.com/hashicorp/consul/agent/consul/multilimiter"
"github.com/hashicorp/go-hclog"
)
var (
@ -114,6 +116,7 @@ type Handler struct {
delegate HandlerDelegate
limiter multilimiter.RateLimiter
logger hclog.Logger
}
type HandlerConfig struct {
@ -140,9 +143,8 @@ type HandlerDelegate interface {
IsLeader() bool
}
// NewHandler creates a new RPC rate limit handler.
func NewHandler(cfg HandlerConfig, delegate HandlerDelegate) *Handler {
limiter := multilimiter.NewMultiLimiter(cfg.Config)
func NewHandlerWithLimiter(cfg HandlerConfig, delegate HandlerDelegate,
limiter multilimiter.RateLimiter, logger hclog.Logger) *Handler {
limiter.UpdateConfig(cfg.GlobalWriteConfig, globalWrite)
limiter.UpdateConfig(cfg.GlobalReadConfig, globalRead)
@ -150,12 +152,19 @@ func NewHandler(cfg HandlerConfig, delegate HandlerDelegate) *Handler {
cfg: new(atomic.Pointer[HandlerConfig]),
delegate: delegate,
limiter: limiter,
logger: logger,
}
h.cfg.Store(&cfg)
return h
}
// NewHandler creates a new RPC rate limit handler.
func NewHandler(cfg HandlerConfig, delegate HandlerDelegate, logger hclog.Logger) *Handler {
limiter := multilimiter.NewMultiLimiter(cfg.Config)
return NewHandlerWithLimiter(cfg, delegate, limiter, logger)
}
// Run the limiter cleanup routine until the given context is canceled.
//
// Note: this starts a goroutine.
@ -175,9 +184,18 @@ func (h *Handler) Allow(op Operation) error {
}
func (h *Handler) UpdateConfig(cfg HandlerConfig) {
existingCfg := h.cfg.Load()
h.cfg.Store(&cfg)
h.limiter.UpdateConfig(cfg.GlobalWriteConfig, globalWrite)
h.limiter.UpdateConfig(cfg.GlobalReadConfig, globalRead)
if reflect.DeepEqual(existingCfg, cfg) {
h.logger.Warn("UpdateConfig called but configuration has not changed. Skipping updating the server rate limiter configuration.")
return
}
if !reflect.DeepEqual(existingCfg.GlobalWriteConfig, cfg.GlobalWriteConfig) {
h.limiter.UpdateConfig(cfg.GlobalWriteConfig, globalWrite)
}
if !reflect.DeepEqual(existingCfg.GlobalReadConfig, cfg.GlobalReadConfig) {
h.limiter.UpdateConfig(cfg.GlobalReadConfig, globalRead)
}
}
var (

View File

@ -0,0 +1,90 @@
package rate
import (
"testing"
"github.com/hashicorp/consul/agent/consul/multilimiter"
"github.com/hashicorp/go-hclog"
mock "github.com/stretchr/testify/mock"
)
func TestNewHandlerWithLimiter_CallsUpdateConfig(t *testing.T) {
mockRateLimiter := multilimiter.NewMockRateLimiter(t)
mockRateLimiter.On("UpdateConfig", mock.Anything, mock.Anything).Return()
readCfg := multilimiter.LimiterConfig{Rate: 100, Burst: 100}
writeCfg := multilimiter.LimiterConfig{Rate: 99, Burst: 99}
cfg := &HandlerConfig{
GlobalReadConfig: readCfg,
GlobalWriteConfig: writeCfg,
GlobalMode: ModeEnforcing,
}
logger := hclog.NewNullLogger()
NewHandlerWithLimiter(*cfg, nil, mockRateLimiter, logger)
mockRateLimiter.AssertNumberOfCalls(t, "UpdateConfig", 2)
}
func TestUpdateConfig(t *testing.T) {
type testCase struct {
description string
configModFunc func(cfg *HandlerConfig)
assertFunc func(mockRateLimiter *multilimiter.MockRateLimiter, cfg *HandlerConfig)
}
testCases := []testCase{
{
description: "RateLimiter does not get updated when config does not change.",
configModFunc: func(cfg *HandlerConfig) {},
assertFunc: func(mockRateLimiter *multilimiter.MockRateLimiter, cfg *HandlerConfig) {
mockRateLimiter.AssertNumberOfCalls(t, "UpdateConfig", 0)
},
},
{
description: "RateLimiter gets updated when GlobalReadConfig changes.",
configModFunc: func(cfg *HandlerConfig) {
cfg.GlobalReadConfig.Burst++
},
assertFunc: func(mockRateLimiter *multilimiter.MockRateLimiter, cfg *HandlerConfig) {
mockRateLimiter.AssertNumberOfCalls(t, "UpdateConfig", 1)
mockRateLimiter.AssertCalled(t, "UpdateConfig", cfg.GlobalReadConfig, []byte("global.read"))
},
},
{
description: "RateLimiter gets updated when GlobalWriteConfig changes.",
configModFunc: func(cfg *HandlerConfig) {
cfg.GlobalWriteConfig.Burst++
},
assertFunc: func(mockRateLimiter *multilimiter.MockRateLimiter, cfg *HandlerConfig) {
mockRateLimiter.AssertNumberOfCalls(t, "UpdateConfig", 1)
mockRateLimiter.AssertCalled(t, "UpdateConfig", cfg.GlobalWriteConfig, []byte("global.write"))
},
},
{
description: "RateLimiter does not get updated when GlobalMode changes.",
configModFunc: func(cfg *HandlerConfig) {
cfg.GlobalMode = ModePermissive
},
assertFunc: func(mockRateLimiter *multilimiter.MockRateLimiter, cfg *HandlerConfig) {
mockRateLimiter.AssertNumberOfCalls(t, "UpdateConfig", 0)
},
},
}
for _, tc := range testCases {
t.Run(tc.description, func(t *testing.T) {
readCfg := multilimiter.LimiterConfig{Rate: 100, Burst: 100}
writeCfg := multilimiter.LimiterConfig{Rate: 99, Burst: 99}
cfg := &HandlerConfig{
GlobalReadConfig: readCfg,
GlobalWriteConfig: writeCfg,
GlobalMode: ModeEnforcing,
}
mockRateLimiter := multilimiter.NewMockRateLimiter(t)
mockRateLimiter.On("UpdateConfig", mock.Anything, mock.Anything).Return()
logger := hclog.NewNullLogger()
handler := NewHandlerWithLimiter(*cfg, nil, mockRateLimiter, logger)
mockRateLimiter.Calls = nil
tc.configModFunc(cfg)
handler.UpdateConfig(*cfg)
tc.assertFunc(mockRateLimiter, cfg)
})
}
}

View File

@ -480,7 +480,7 @@ func NewServer(config *Config, flat Deps, externalGRPCServer *grpc.Server) (*Ser
WriteRate: config.RequestLimitsWriteRate,
}
s.incomingRPCLimiter = rpcRate.NewHandler(*s.convertConsulConfigToRateLimitHandlerConfig(*limitsConfig, mlCfg), s)
s.incomingRPCLimiter = rpcRate.NewHandler(*s.convertConsulConfigToRateLimitHandlerConfig(*limitsConfig, mlCfg), s, s.logger)
}
s.incomingRPCLimiter.Run(&lib.StopChannelContext{StopCh: s.shutdownCh})