add necessary plumbing to implement per server ip based rate limiting (#17436)

This commit is contained in:
Dhia Ayachi 2023-05-23 15:37:01 -04:00 committed by GitHub
parent 3ed4f7a33a
commit cdc47ea200
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 113 additions and 81 deletions

View File

@ -8,6 +8,7 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"github.com/hashicorp/consul/agent/metadata"
"net" "net"
"reflect" "reflect"
"sync/atomic" "sync/atomic"
@ -153,14 +154,14 @@ type RequestLimitsHandler interface {
Allow(op Operation) error Allow(op Operation) error
UpdateConfig(cfg HandlerConfig) UpdateConfig(cfg HandlerConfig)
UpdateIPConfig(cfg IPLimitConfig) UpdateIPConfig(cfg IPLimitConfig)
Register(leaderStatusProvider LeaderStatusProvider) Register(serversStatusProvider ServersStatusProvider)
} }
// Handler enforces rate limits for incoming RPCs. // Handler enforces rate limits for incoming RPCs.
type Handler struct { type Handler struct {
globalCfg *atomic.Pointer[HandlerConfig] globalCfg *atomic.Pointer[HandlerConfig]
ipCfg *atomic.Pointer[IPLimitConfig] ipCfg *atomic.Pointer[IPLimitConfig]
leaderStatusProvider LeaderStatusProvider serversStatusProvider ServersStatusProvider
limiter multilimiter.RateLimiter limiter multilimiter.RateLimiter
@ -186,13 +187,14 @@ type HandlerConfig struct {
GlobalLimitConfig GlobalLimitConfig GlobalLimitConfig GlobalLimitConfig
} }
//go:generate mockery --name LeaderStatusProvider --inpackage --filename mock_LeaderStatusProvider_test.go //go:generate mockery --name ServersStatusProvider --inpackage --filename mock_ServersStatusProvider_test.go
type LeaderStatusProvider interface { type ServersStatusProvider interface {
// IsLeader is used to determine whether the operation is being performed // IsLeader is used to determine whether the operation is being performed
// against the cluster leader, such that if it can _only_ be performed by // against the cluster leader, such that if it can _only_ be performed by
// the leader (e.g. write operations) we don't tell clients to retry against // the leader (e.g. write operations) we don't tell clients to retry against
// a different server. // a different server.
IsLeader() bool IsLeader() bool
IsServer(addr string) bool
} }
func isInfRate(cfg multilimiter.LimiterConfig) bool { func isInfRate(cfg multilimiter.LimiterConfig) bool {
@ -237,11 +239,11 @@ func (h *Handler) Run(ctx context.Context) {
// because of an exhausted rate-limit. // because of an exhausted rate-limit.
func (h *Handler) Allow(op Operation) error { func (h *Handler) Allow(op Operation) error {
if h.leaderStatusProvider == nil { if h.serversStatusProvider == nil {
h.logger.Error("leaderStatusProvider required to be set via Register(). bailing on rate limiter") h.logger.Error("serversStatusProvider required to be set via Register(). bailing on rate limiter")
return nil return nil
// TODO: panic and make sure to use the server's recovery handler // TODO: panic and make sure to use the server's recovery handler
// panic("leaderStatusProvider required to be set via Register(..)") // panic("serversStatusProvider required to be set via Register(..)")
} }
cfg := h.globalCfg.Load() cfg := h.globalCfg.Load()
@ -249,7 +251,7 @@ func (h *Handler) Allow(op Operation) error {
return nil return nil
} }
allow, throttledLimits := h.allowAllLimits(h.limits(op)) allow, throttledLimits := h.allowAllLimits(h.limits(op), h.serversStatusProvider.IsServer(string(metadata.GetIP(op.SourceAddr))))
if !allow { if !allow {
for _, l := range throttledLimits { for _, l := range throttledLimits {
@ -277,7 +279,7 @@ func (h *Handler) Allow(op Operation) error {
}) })
if enforced { if enforced {
if h.leaderStatusProvider.IsLeader() && op.Type == OperationTypeWrite { if h.serversStatusProvider.IsLeader() && op.Type == OperationTypeWrite {
return ErrRetryLater return ErrRetryLater
} }
return ErrRetryElsewhere return ErrRetryElsewhere
@ -305,17 +307,18 @@ func (h *Handler) UpdateConfig(cfg HandlerConfig) {
} }
func (h *Handler) Register(leaderStatusProvider LeaderStatusProvider) { func (h *Handler) Register(serversStatusProvider ServersStatusProvider) {
h.leaderStatusProvider = leaderStatusProvider h.serversStatusProvider = serversStatusProvider
} }
type limit struct { type limit struct {
mode Mode mode Mode
ent multilimiter.LimitedEntity ent multilimiter.LimitedEntity
desc string desc string
applyOnServer bool
} }
func (h *Handler) allowAllLimits(limits []limit) (bool, []limit) { func (h *Handler) allowAllLimits(limits []limit, isServer bool) (bool, []limit) {
allow := true allow := true
throttledLimits := make([]limit, 0) throttledLimits := make([]limit, 0)
@ -324,6 +327,10 @@ func (h *Handler) allowAllLimits(limits []limit) (bool, []limit) {
continue continue
} }
if isServer && !l.applyOnServer {
continue
}
if !h.limiter.Allow(l.ent) { if !h.limiter.Allow(l.ent) {
throttledLimits = append(throttledLimits, l) throttledLimits = append(throttledLimits, l)
allow = false allow = false
@ -358,7 +365,7 @@ func (h *Handler) globalLimit(op Operation) *limit {
} }
cfg := h.globalCfg.Load() cfg := h.globalCfg.Load()
lim := &limit{mode: cfg.GlobalLimitConfig.Mode} lim := &limit{mode: cfg.GlobalLimitConfig.Mode, applyOnServer: true}
switch op.Type { switch op.Type {
case OperationTypeRead: case OperationTypeRead:
lim.desc = "global/read" lim.desc = "global/read"
@ -409,4 +416,4 @@ func (nullRequestLimitsHandler) Run(_ context.Context) {}
func (nullRequestLimitsHandler) UpdateConfig(_ HandlerConfig) {} func (nullRequestLimitsHandler) UpdateConfig(_ HandlerConfig) {}
func (nullRequestLimitsHandler) Register(_ LeaderStatusProvider) {} func (nullRequestLimitsHandler) Register(_ ServersStatusProvider) {}

View File

@ -19,22 +19,6 @@ import (
"github.com/hashicorp/consul/agent/consul/multilimiter" "github.com/hashicorp/consul/agent/consul/multilimiter"
) )
//
// Revisit test when handler.go:189 TODO implemented
//
// func TestHandler_Allow_PanicsWhenLeaderStatusProviderNotRegistered(t *testing.T) {
// defer func() {
// err := recover()
// if err == nil {
// t.Fatal("Run should panic")
// }
// }()
// handler := NewHandler(HandlerConfig{}, hclog.NewNullLogger())
// handler.Allow(Operation{})
// // intentionally skip handler.Register(...)
// }
func TestHandler(t *testing.T) { func TestHandler(t *testing.T) {
var ( var (
rpcName = "Foo.Bar" rpcName = "Foo.Bar"
@ -50,6 +34,7 @@ func TestHandler(t *testing.T) {
globalMode Mode globalMode Mode
checks []limitCheck checks []limitCheck
isLeader bool isLeader bool
isServer bool
expectErr error expectErr error
expectLog bool expectLog bool
expectMetric bool expectMetric bool
@ -230,8 +215,9 @@ func TestHandler(t *testing.T) {
limiter.On("Allow", mock.Anything).Return(c.allow) limiter.On("Allow", mock.Anything).Return(c.allow)
} }
leaderStatusProvider := NewMockLeaderStatusProvider(t) serversStatusProvider := NewMockServersStatusProvider(t)
leaderStatusProvider.On("IsLeader").Return(tc.isLeader).Maybe() serversStatusProvider.On("IsLeader").Return(tc.isLeader).Maybe()
serversStatusProvider.On("IsServer", mock.Anything).Return(tc.isServer).Maybe()
var output bytes.Buffer var output bytes.Buffer
logger := hclog.NewInterceptLogger(&hclog.LoggerOptions{ logger := hclog.NewInterceptLogger(&hclog.LoggerOptions{
@ -252,7 +238,7 @@ func TestHandler(t *testing.T) {
limiter, limiter,
logger, logger,
) )
handler.Register(leaderStatusProvider) handler.Register(serversStatusProvider)
require.Equal(t, tc.expectErr, handler.Allow(tc.op)) require.Equal(t, tc.expectErr, handler.Allow(tc.op))
@ -426,8 +412,9 @@ func TestAllow(t *testing.T) {
} }
mockRateLimiter.On("UpdateConfig", mock.Anything, mock.Anything).Return() mockRateLimiter.On("UpdateConfig", mock.Anything, mock.Anything).Return()
logger := hclog.NewNullLogger() logger := hclog.NewNullLogger()
delegate := NewMockLeaderStatusProvider(t) delegate := NewMockServersStatusProvider(t)
delegate.On("IsLeader").Return(true).Maybe() delegate.On("IsLeader").Return(true).Maybe()
delegate.On("IsServer", mock.Anything).Return(false).Maybe()
handler := NewHandlerWithLimiter(*tc.cfg, mockRateLimiter, logger) handler := NewHandlerWithLimiter(*tc.cfg, mockRateLimiter, logger)
handler.Register(delegate) handler.Register(delegate)
addr := net.TCPAddrFromAddrPort(netip.MustParseAddrPort("127.0.0.1:1234")) addr := net.TCPAddrFromAddrPort(netip.MustParseAddrPort("127.0.0.1:1234"))

View File

@ -1,39 +0,0 @@
// Code generated by mockery v2.20.0. DO NOT EDIT.
package rate
import mock "github.com/stretchr/testify/mock"
// MockLeaderStatusProvider is an autogenerated mock type for the LeaderStatusProvider type
type MockLeaderStatusProvider struct {
mock.Mock
}
// IsLeader provides a mock function with given fields:
func (_m *MockLeaderStatusProvider) IsLeader() bool {
ret := _m.Called()
var r0 bool
if rf, ok := ret.Get(0).(func() bool); ok {
r0 = rf()
} else {
r0 = ret.Get(0).(bool)
}
return r0
}
type mockConstructorTestingTNewMockLeaderStatusProvider interface {
mock.TestingT
Cleanup(func())
}
// NewMockLeaderStatusProvider creates a new instance of MockLeaderStatusProvider. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
func NewMockLeaderStatusProvider(t mockConstructorTestingTNewMockLeaderStatusProvider) *MockLeaderStatusProvider {
mock := &MockLeaderStatusProvider{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}

View File

@ -27,9 +27,9 @@ func (_m *MockRequestLimitsHandler) Allow(op Operation) error {
return r0 return r0
} }
// Register provides a mock function with given fields: leaderStatusProvider // Register provides a mock function with given fields: serversStatusProvider
func (_m *MockRequestLimitsHandler) Register(leaderStatusProvider LeaderStatusProvider) { func (_m *MockRequestLimitsHandler) Register(serversStatusProvider ServersStatusProvider) {
_m.Called(leaderStatusProvider) _m.Called(serversStatusProvider)
} }
// Run provides a mock function with given fields: ctx // Run provides a mock function with given fields: ctx

View File

@ -0,0 +1,53 @@
// Code generated by mockery v2.20.0. DO NOT EDIT.
package rate
import mock "github.com/stretchr/testify/mock"
// MockServersStatusProvider is an autogenerated mock type for the ServersStatusProvider type
type MockServersStatusProvider struct {
mock.Mock
}
// IsLeader provides a mock function with given fields:
func (_m *MockServersStatusProvider) IsLeader() bool {
ret := _m.Called()
var r0 bool
if rf, ok := ret.Get(0).(func() bool); ok {
r0 = rf()
} else {
r0 = ret.Get(0).(bool)
}
return r0
}
// IsServer provides a mock function with given fields: addr
func (_m *MockServersStatusProvider) IsServer(addr string) bool {
ret := _m.Called(addr)
var r0 bool
if rf, ok := ret.Get(0).(func(string) bool); ok {
r0 = rf(addr)
} else {
r0 = ret.Get(0).(bool)
}
return r0
}
type mockConstructorTestingTNewMockServersStatusProvider interface {
mock.TestingT
Cleanup(func())
}
// NewMockServersStatusProvider creates a new instance of MockServersStatusProvider. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
func NewMockServersStatusProvider(t mockConstructorTestingTNewMockServersStatusProvider) *MockServersStatusProvider {
mock := &MockServersStatusProvider{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}

View File

@ -1660,6 +1660,20 @@ func (s *Server) IsLeader() bool {
return s.raft.State() == raft.Leader return s.raft.State() == raft.Leader
} }
// IsServer checks if this addr is of a server
func (s *Server) IsServer(addr string) bool {
for _, s := range s.raft.GetConfiguration().Configuration().Servers {
a, err := net.ResolveTCPAddr("tcp", string(s.Address))
if err != nil {
continue
}
if string(metadata.GetIP(a)) == addr {
return true
}
}
return false
}
// LeaderLastContact returns the time of last contact by a leader. // LeaderLastContact returns the time of last contact by a leader.
// This only makes sense if we are currently a follower. // This only makes sense if we are currently a follower.
func (s *Server) LeaderLastContact() time.Time { func (s *Server) LeaderLastContact() time.Time {

View File

@ -221,3 +221,13 @@ func AddFeatureFlags(tags map[string]string, flags ...string) {
tags[featureFlagPrefix+flag] = "1" tags[featureFlagPrefix+flag] = "1"
} }
} }
func GetIP(addr net.Addr) []byte {
switch a := addr.(type) {
case *net.UDPAddr:
return []byte(a.IP.String())
case *net.TCPAddr:
return []byte(a.IP.String())
}
return []byte{}
}