Wire in rate limiter to handle internal and external gRPC calls (#15857)

This commit is contained in:
Dan Upton 2022-12-23 19:42:16 +00:00 committed by GitHub
parent 1586c7ba10
commit 006138beb4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 435 additions and 68 deletions

1
.gitignore vendored
View File

@ -16,6 +16,7 @@ Thumbs.db
.idea .idea
.vscode .vscode
__debug_bin __debug_bin
coverage.out
# MacOS # MacOS
.DS_Store .DS_Store

View File

@ -565,22 +565,6 @@ func (a *Agent) Start(ctx context.Context) error {
return fmt.Errorf("Failed to load TLS configurations after applying auto-config settings: %w", err) return fmt.Errorf("Failed to load TLS configurations after applying auto-config settings: %w", err)
} }
// gRPC calls are only rate-limited on server, not client agents.
var grpcRateLimiter rpcRate.RequestLimitsHandler
grpcRateLimiter = rpcRate.NullRequestLimitsHandler()
if s, ok := a.delegate.(*consul.Server); ok {
grpcRateLimiter = s.IncomingRPCLimiter()
}
// This needs to happen after the initial auto-config is loaded, because TLS
// can only be configured on the gRPC server at the point of creation.
a.externalGRPCServer = external.NewServer(
a.logger.Named("grpc.external"),
metrics.Default(),
a.tlsConfigurator,
grpcRateLimiter,
)
if err := a.startLicenseManager(ctx); err != nil { if err := a.startLicenseManager(ctx); err != nil {
return err return err
} }
@ -618,10 +602,21 @@ func (a *Agent) Start(ctx context.Context) error {
// Setup either the client or the server. // Setup either the client or the server.
if c.ServerMode { if c.ServerMode {
server, err := consul.NewServer(consulCfg, a.baseDeps.Deps, a.externalGRPCServer) serverLogger := a.baseDeps.Logger.NamedIntercept(logging.ConsulServer)
incomingRPCLimiter := consul.ConfiguredIncomingRPCLimiter(serverLogger, consulCfg)
a.externalGRPCServer = external.NewServer(
a.logger.Named("grpc.external"),
metrics.Default(),
a.tlsConfigurator,
incomingRPCLimiter,
)
server, err := consul.NewServer(consulCfg, a.baseDeps.Deps, a.externalGRPCServer, incomingRPCLimiter, serverLogger)
if err != nil { if err != nil {
return fmt.Errorf("Failed to start Consul server: %v", err) return fmt.Errorf("Failed to start Consul server: %v", err)
} }
incomingRPCLimiter.Register(server)
a.delegate = server a.delegate = server
if a.config.PeeringEnabled && a.config.ConnectEnabled { if a.config.PeeringEnabled && a.config.ConnectEnabled {
@ -642,6 +637,13 @@ func (a *Agent) Start(ctx context.Context) error {
} }
} else { } else {
a.externalGRPCServer = external.NewServer(
a.logger.Named("grpc.external"),
metrics.Default(),
a.tlsConfigurator,
rpcRate.NullRequestLimitsHandler(),
)
client, err := consul.NewClient(consulCfg, a.baseDeps.Deps) client, err := consul.NewClient(consulCfg, a.baseDeps.Deps)
if err != nil { if err != nil {
return fmt.Errorf("Failed to start Consul client: %v", err) return fmt.Errorf("Failed to start Consul client: %v", err)

View File

@ -563,7 +563,7 @@ func TestCAManager_Initialize_Logging(t *testing.T) {
deps := newDefaultDeps(t, conf1) deps := newDefaultDeps(t, conf1)
deps.Logger = logger deps.Logger = logger
s1, err := NewServer(conf1, deps, grpc.NewServer()) s1, err := NewServer(conf1, deps, grpc.NewServer(), nil, logger)
require.NoError(t, err) require.NoError(t, err)
defer s1.Shutdown() defer s1.Shutdown()
testrpc.WaitForLeader(t, s1.RPC, "dc1") testrpc.WaitForLeader(t, s1.RPC, "dc1")

View File

@ -1556,7 +1556,7 @@ func TestLeader_ConfigEntryBootstrap_Fail(t *testing.T) {
deps := newDefaultDeps(t, config) deps := newDefaultDeps(t, config)
deps.Logger = logger deps.Logger = logger
srv, err := NewServer(config, deps, grpc.NewServer()) srv, err := NewServer(config, deps, grpc.NewServer(), nil, logger)
require.NoError(t, err) require.NoError(t, err)
defer srv.Shutdown() defer srv.Shutdown()

View File

@ -4,6 +4,7 @@ package rate
import ( import (
"context" "context"
"errors" "errors"
"fmt"
"net" "net"
"reflect" "reflect"
"sync/atomic" "sync/atomic"
@ -112,11 +113,14 @@ type RequestLimitsHandler interface {
// Handler enforces rate limits for incoming RPCs. // Handler enforces rate limits for incoming RPCs.
type Handler struct { type Handler struct {
cfg *atomic.Pointer[HandlerConfig] cfg *atomic.Pointer[HandlerConfig]
delegate HandlerDelegate leaderStatusProvider LeaderStatusProvider
limiter multilimiter.RateLimiter limiter multilimiter.RateLimiter
logger hclog.Logger
// TODO: replace this with the real logger.
// https://github.com/hashicorp/consul/pull/15822
logger hclog.Logger
} }
type HandlerConfig struct { type HandlerConfig struct {
@ -135,7 +139,8 @@ type HandlerConfig struct {
GlobalReadConfig multilimiter.LimiterConfig GlobalReadConfig multilimiter.LimiterConfig
} }
type HandlerDelegate interface { //go:generate mockery --name LeaderStatusProvider --inpackage --filename mock_LeaderStatusProvider_test.go
type LeaderStatusProvider interface {
// IsLeader is used to determine whether the operation is being performed // IsLeader is used to determine whether the operation is being performed
// against the cluster leader, such that if it can _only_ be performed by // against the cluster leader, such that if it can _only_ be performed by
// the leader (e.g. write operations) we don't tell clients to retry against // the leader (e.g. write operations) we don't tell clients to retry against
@ -143,16 +148,18 @@ type HandlerDelegate interface {
IsLeader() bool IsLeader() bool
} }
func NewHandlerWithLimiter(cfg HandlerConfig, delegate HandlerDelegate, func NewHandlerWithLimiter(
limiter multilimiter.RateLimiter, logger hclog.Logger) *Handler { cfg HandlerConfig,
limiter multilimiter.RateLimiter,
logger hclog.Logger) *Handler {
limiter.UpdateConfig(cfg.GlobalWriteConfig, globalWrite) limiter.UpdateConfig(cfg.GlobalWriteConfig, globalWrite)
limiter.UpdateConfig(cfg.GlobalReadConfig, globalRead) limiter.UpdateConfig(cfg.GlobalReadConfig, globalRead)
h := &Handler{ h := &Handler{
cfg: new(atomic.Pointer[HandlerConfig]), cfg: new(atomic.Pointer[HandlerConfig]),
delegate: delegate, limiter: limiter,
limiter: limiter, logger: logger,
logger: logger,
} }
h.cfg.Store(&cfg) h.cfg.Store(&cfg)
@ -160,9 +167,9 @@ func NewHandlerWithLimiter(cfg HandlerConfig, delegate HandlerDelegate,
} }
// NewHandler creates a new RPC rate limit handler. // NewHandler creates a new RPC rate limit handler.
func NewHandler(cfg HandlerConfig, delegate HandlerDelegate, logger hclog.Logger) *Handler { func NewHandler(cfg HandlerConfig, logger hclog.Logger) *Handler {
limiter := multilimiter.NewMultiLimiter(cfg.Config) limiter := multilimiter.NewMultiLimiter(cfg.Config)
return NewHandlerWithLimiter(cfg, delegate, limiter, logger) return NewHandlerWithLimiter(cfg, limiter, logger)
} }
// Run the limiter cleanup routine until the given context is canceled. // Run the limiter cleanup routine until the given context is canceled.
@ -175,14 +182,45 @@ func (h *Handler) Run(ctx context.Context) {
// Allow returns an error if the given operation is not allowed to proceed // Allow returns an error if the given operation is not allowed to proceed
// because of an exhausted rate-limit. // because of an exhausted rate-limit.
func (h *Handler) Allow(op Operation) error { func (h *Handler) Allow(op Operation) error {
if h.leaderStatusProvider == nil {
h.logger.Error("leaderStatusProvider required to be set via Register(). bailing on rate limiter")
return nil
// TODO: panic and make sure to use the server's recovery handler
// panic("leaderStatusProvider required to be set via Register(..)")
}
cfg := h.cfg.Load() cfg := h.cfg.Load()
if cfg.GlobalMode == ModeDisabled { if cfg.GlobalMode == ModeDisabled {
return nil return nil
} }
if !h.limiter.Allow(globalWrite) { for _, l := range h.limits(op) {
// TODO(NET-1383): actually implement the rate limiting logic and replace this returned nil. if l.mode == ModeDisabled {
return nil continue
}
if h.limiter.Allow(l.ent) {
continue
}
// TODO: metrics.
// TODO: is this the correct log-level?
enforced := l.mode == ModeEnforcing
h.logger.Trace("RPC exceeded allowed rate limit",
"rpc", op.Name,
"source_addr", op.SourceAddr.String(),
"limit_type", l.desc,
"limit_enforced", enforced,
)
if enforced {
if h.leaderStatusProvider.IsLeader() && op.Type == OperationTypeWrite {
return ErrRetryLater
}
return ErrRetryElsewhere
}
} }
return nil return nil
} }
@ -202,6 +240,48 @@ func (h *Handler) UpdateConfig(cfg HandlerConfig) {
} }
} }
func (h *Handler) Register(leaderStatusProvider LeaderStatusProvider) {
h.leaderStatusProvider = leaderStatusProvider
}
type limit struct {
mode Mode
ent multilimiter.LimitedEntity
desc string
}
// limits returns the limits to check for the given operation (e.g. global +
// ip-based + tenant-based).
func (h *Handler) limits(op Operation) []limit {
limits := make([]limit, 0)
if global := h.globalLimit(op); global != nil {
limits = append(limits, *global)
}
return limits
}
func (h *Handler) globalLimit(op Operation) *limit {
if op.Type == OperationTypeExempt {
return nil
}
cfg := h.cfg.Load()
lim := &limit{mode: cfg.GlobalMode}
switch op.Type {
case OperationTypeRead:
lim.desc = "global/read"
lim.ent = globalRead
case OperationTypeWrite:
lim.desc = "global/write"
lim.ent = globalWrite
default:
panic(fmt.Sprintf("unknown operation type %d", op.Type))
}
return lim
}
var ( var (
// globalWrite identifies the global rate limit applied to write operations. // globalWrite identifies the global rate limit applied to write operations.
globalWrite = globalLimit("global.write") globalWrite = globalLimit("global.write")

View File

@ -1,15 +1,233 @@
package rate package rate
import ( import (
"bytes"
"context"
"net" "net"
"net/netip" "net/netip"
"testing" "testing"
"github.com/hashicorp/consul/agent/consul/multilimiter" "github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/hashicorp/go-hclog" "github.com/hashicorp/go-hclog"
mock "github.com/stretchr/testify/mock"
"github.com/hashicorp/consul/agent/consul/multilimiter"
) )
//
// Revisit test when handler.go:189 TODO implemented
//
// func TestHandler_Allow_PanicsWhenLeaderStatusProviderNotRegistered(t *testing.T) {
// defer func() {
// err := recover()
// if err == nil {
// t.Fatal("Run should panic")
// }
// }()
// handler := NewHandler(HandlerConfig{}, hclog.NewNullLogger())
// handler.Allow(Operation{})
// // intentionally skip handler.Register(...)
// }
func TestHandler(t *testing.T) {
var (
rpcName = "Foo.Bar"
sourceAddr = net.TCPAddrFromAddrPort(netip.MustParseAddrPort("1.2.3.4:5678"))
)
type limitCheck struct {
limit multilimiter.LimitedEntity
allow bool
}
testCases := map[string]struct {
op Operation
globalMode Mode
checks []limitCheck
isLeader bool
expectErr error
expectLog bool
}{
"operation exempt from limiting": {
op: Operation{
Type: OperationTypeExempt,
Name: rpcName,
SourceAddr: sourceAddr,
},
globalMode: ModeEnforcing,
checks: []limitCheck{},
expectErr: nil,
expectLog: false,
},
"global write limit disabled": {
op: Operation{
Type: OperationTypeWrite,
Name: rpcName,
SourceAddr: sourceAddr,
},
globalMode: ModeDisabled,
checks: []limitCheck{},
expectErr: nil,
expectLog: false,
},
"global write limit within allowance": {
op: Operation{
Type: OperationTypeWrite,
Name: rpcName,
SourceAddr: sourceAddr,
},
globalMode: ModeEnforcing,
checks: []limitCheck{
{limit: globalWrite, allow: true},
},
expectErr: nil,
expectLog: false,
},
"global write limit exceeded (permissive)": {
op: Operation{
Type: OperationTypeWrite,
Name: rpcName,
SourceAddr: sourceAddr,
},
globalMode: ModePermissive,
checks: []limitCheck{
{limit: globalWrite, allow: false},
},
expectErr: nil,
expectLog: true,
},
"global write limit exceeded (enforcing, leader)": {
op: Operation{
Type: OperationTypeWrite,
Name: rpcName,
SourceAddr: sourceAddr,
},
globalMode: ModeEnforcing,
checks: []limitCheck{
{limit: globalWrite, allow: false},
},
isLeader: true,
expectErr: ErrRetryLater,
expectLog: true,
},
"global write limit exceeded (enforcing, follower)": {
op: Operation{
Type: OperationTypeWrite,
Name: rpcName,
SourceAddr: sourceAddr,
},
globalMode: ModeEnforcing,
checks: []limitCheck{
{limit: globalWrite, allow: false},
},
isLeader: false,
expectErr: ErrRetryElsewhere,
expectLog: true,
},
"global read limit disabled": {
op: Operation{
Type: OperationTypeRead,
Name: rpcName,
SourceAddr: sourceAddr,
},
globalMode: ModeDisabled,
checks: []limitCheck{},
expectErr: nil,
expectLog: false,
},
"global read limit within allowance": {
op: Operation{
Type: OperationTypeRead,
Name: rpcName,
SourceAddr: sourceAddr,
},
globalMode: ModeEnforcing,
checks: []limitCheck{
{limit: globalRead, allow: true},
},
expectErr: nil,
expectLog: false,
},
"global read limit exceeded (permissive)": {
op: Operation{
Type: OperationTypeRead,
Name: rpcName,
SourceAddr: sourceAddr,
},
globalMode: ModePermissive,
checks: []limitCheck{
{limit: globalRead, allow: false},
},
expectErr: nil,
expectLog: true,
},
"global read limit exceeded (enforcing, leader)": {
op: Operation{
Type: OperationTypeRead,
Name: rpcName,
SourceAddr: sourceAddr,
},
globalMode: ModeEnforcing,
checks: []limitCheck{
{limit: globalRead, allow: false},
},
isLeader: true,
expectErr: ErrRetryElsewhere,
expectLog: true,
},
"global read limit exceeded (enforcing, follower)": {
op: Operation{
Type: OperationTypeRead,
Name: rpcName,
SourceAddr: sourceAddr,
},
globalMode: ModeEnforcing,
checks: []limitCheck{
{limit: globalRead, allow: false},
},
isLeader: false,
expectErr: ErrRetryElsewhere,
expectLog: true,
},
}
for desc, tc := range testCases {
t.Run(desc, func(t *testing.T) {
limiter := newMockLimiter(t)
limiter.On("UpdateConfig", mock.Anything, mock.Anything).Return()
for _, c := range tc.checks {
limiter.On("Allow", c.limit).Return(c.allow)
}
leaderStatusProvider := NewMockLeaderStatusProvider(t)
leaderStatusProvider.On("IsLeader").Return(tc.isLeader).Maybe()
var output bytes.Buffer
logger := hclog.NewInterceptLogger(&hclog.LoggerOptions{
Level: hclog.Trace,
Output: &output,
})
handler := NewHandlerWithLimiter(
HandlerConfig{
GlobalMode: tc.globalMode,
},
limiter,
logger,
)
handler.Register(leaderStatusProvider)
require.Equal(t, tc.expectErr, handler.Allow(tc.op))
if tc.expectLog {
require.Contains(t, output.String(), "RPC exceeded allowed rate limit")
} else {
require.Zero(t, output.Len(), "expected no logs to be emitted")
}
})
}
}
func TestNewHandlerWithLimiter_CallsUpdateConfig(t *testing.T) { func TestNewHandlerWithLimiter_CallsUpdateConfig(t *testing.T) {
mockRateLimiter := multilimiter.NewMockRateLimiter(t) mockRateLimiter := multilimiter.NewMockRateLimiter(t)
mockRateLimiter.On("UpdateConfig", mock.Anything, mock.Anything).Return() mockRateLimiter.On("UpdateConfig", mock.Anything, mock.Anything).Return()
@ -22,7 +240,7 @@ func TestNewHandlerWithLimiter_CallsUpdateConfig(t *testing.T) {
} }
logger := hclog.NewNullLogger() logger := hclog.NewNullLogger()
NewHandlerWithLimiter(*cfg, nil, mockRateLimiter, logger) NewHandlerWithLimiter(*cfg, mockRateLimiter, logger)
mockRateLimiter.AssertNumberOfCalls(t, "UpdateConfig", 2) mockRateLimiter.AssertNumberOfCalls(t, "UpdateConfig", 2)
} }
@ -83,7 +301,7 @@ func TestUpdateConfig(t *testing.T) {
mockRateLimiter := multilimiter.NewMockRateLimiter(t) mockRateLimiter := multilimiter.NewMockRateLimiter(t)
mockRateLimiter.On("UpdateConfig", mock.Anything, mock.Anything).Return() mockRateLimiter.On("UpdateConfig", mock.Anything, mock.Anything).Return()
logger := hclog.NewNullLogger() logger := hclog.NewNullLogger()
handler := NewHandlerWithLimiter(*cfg, nil, mockRateLimiter, logger) handler := NewHandlerWithLimiter(*cfg, mockRateLimiter, logger)
mockRateLimiter.Calls = nil mockRateLimiter.Calls = nil
tc.configModFunc(cfg) tc.configModFunc(cfg)
handler.UpdateConfig(*cfg) handler.UpdateConfig(*cfg)
@ -139,7 +357,10 @@ func TestAllow(t *testing.T) {
} }
mockRateLimiter.On("UpdateConfig", mock.Anything, mock.Anything).Return() mockRateLimiter.On("UpdateConfig", mock.Anything, mock.Anything).Return()
logger := hclog.NewNullLogger() logger := hclog.NewNullLogger()
handler := NewHandlerWithLimiter(*tc.cfg, nil, mockRateLimiter, logger) delegate := NewMockLeaderStatusProvider(t)
delegate.On("IsLeader").Return(true).Maybe()
handler := NewHandlerWithLimiter(*tc.cfg, mockRateLimiter, logger)
handler.Register(delegate)
addr := net.TCPAddrFromAddrPort(netip.MustParseAddrPort("127.0.0.1:1234")) addr := net.TCPAddrFromAddrPort(netip.MustParseAddrPort("127.0.0.1:1234"))
mockRateLimiter.Calls = nil mockRateLimiter.Calls = nil
handler.Allow(Operation{Name: "test", SourceAddr: addr}) handler.Allow(Operation{Name: "test", SourceAddr: addr})
@ -147,3 +368,24 @@ func TestAllow(t *testing.T) {
}) })
} }
} }
var _ multilimiter.RateLimiter = (*mockLimiter)(nil)
func newMockLimiter(t *testing.T) *mockLimiter {
l := &mockLimiter{}
l.Mock.Test(t)
t.Cleanup(func() { l.AssertExpectations(t) })
return l
}
type mockLimiter struct {
mock.Mock
}
func (m *mockLimiter) Allow(v multilimiter.LimitedEntity) bool { return m.Called(v).Bool(0) }
func (m *mockLimiter) Run(ctx context.Context) { m.Called(ctx) }
func (m *mockLimiter) UpdateConfig(cfg multilimiter.LimiterConfig, prefix []byte) {
m.Called(cfg, prefix)
}

View File

@ -0,0 +1,38 @@
// Code generated by mockery v2.12.2. DO NOT EDIT.
package rate
import (
testing "testing"
mock "github.com/stretchr/testify/mock"
)
// MockLeaderStatusProvider is an autogenerated mock type for the LeaderStatusProvider type
type MockLeaderStatusProvider struct {
mock.Mock
}
// IsLeader provides a mock function with given fields:
func (_m *MockLeaderStatusProvider) IsLeader() bool {
ret := _m.Called()
var r0 bool
if rf, ok := ret.Get(0).(func() bool); ok {
r0 = rf()
} else {
r0 = ret.Get(0).(bool)
}
return r0
}
// NewMockLeaderStatusProvider creates a new instance of MockLeaderStatusProvider. It also registers the testing.TB interface on the mock and a cleanup function to assert the mocks expectations.
func NewMockLeaderStatusProvider(t testing.TB) *MockLeaderStatusProvider {
mock := &MockLeaderStatusProvider{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}

View File

@ -407,7 +407,7 @@ type connHandler interface {
// NewServer is used to construct a new Consul server from the configuration // NewServer is used to construct a new Consul server from the configuration
// and extra options, potentially returning an error. // and extra options, potentially returning an error.
func NewServer(config *Config, flat Deps, externalGRPCServer *grpc.Server) (*Server, error) { func NewServer(config *Config, flat Deps, externalGRPCServer *grpc.Server, incomingRPCLimiter rpcRate.RequestLimitsHandler, serverLogger hclog.InterceptLogger) (*Server, error) {
logger := flat.Logger logger := flat.Logger
if err := config.CheckProtocolVersion(); err != nil { if err := config.CheckProtocolVersion(); err != nil {
return nil, err return nil, err
@ -428,7 +428,6 @@ func NewServer(config *Config, flat Deps, externalGRPCServer *grpc.Server) (*Ser
// Create the shutdown channel - this is closed but never written to. // Create the shutdown channel - this is closed but never written to.
shutdownCh := make(chan struct{}) shutdownCh := make(chan struct{})
serverLogger := flat.Logger.NamedIntercept(logging.ConsulServer)
loggers := newLoggerStore(serverLogger) loggers := newLoggerStore(serverLogger)
fsmDeps := fsm.Deps{ fsmDeps := fsm.Deps{
@ -439,6 +438,10 @@ func NewServer(config *Config, flat Deps, externalGRPCServer *grpc.Server) (*Ser
Publisher: flat.EventPublisher, Publisher: flat.EventPublisher,
} }
if incomingRPCLimiter == nil {
incomingRPCLimiter = rpcRate.NullRequestLimitsHandler()
}
// Create server. // Create server.
s := &Server{ s := &Server{
config: config, config: config,
@ -463,6 +466,7 @@ func NewServer(config *Config, flat Deps, externalGRPCServer *grpc.Server) (*Ser
aclAuthMethodValidators: authmethod.NewCache(), aclAuthMethodValidators: authmethod.NewCache(),
fsm: fsm.NewFromDeps(fsmDeps), fsm: fsm.NewFromDeps(fsmDeps),
publisher: flat.EventPublisher, publisher: flat.EventPublisher,
incomingRPCLimiter: incomingRPCLimiter,
} }
s.hcpManager = hcp.NewManager(hcp.ManagerConfig{ s.hcpManager = hcp.NewManager(hcp.ManagerConfig{
@ -471,17 +475,6 @@ func NewServer(config *Config, flat Deps, externalGRPCServer *grpc.Server) (*Ser
Logger: logger.Named("hcp_manager"), Logger: logger.Named("hcp_manager"),
}) })
// TODO(NET-1380, NET-1381): thread this into the net/rpc and gRPC interceptors.
if s.incomingRPCLimiter == nil {
mlCfg := &multilimiter.Config{ReconcileCheckLimit: 30 * time.Second, ReconcileCheckInterval: time.Second}
limitsConfig := &RequestLimits{
Mode: rpcRate.RequestLimitsModeFromNameWithDefault(config.RequestLimitsMode),
ReadRate: config.RequestLimitsReadRate,
WriteRate: config.RequestLimitsWriteRate,
}
s.incomingRPCLimiter = rpcRate.NewHandler(*s.convertConsulConfigToRateLimitHandlerConfig(*limitsConfig, mlCfg), s, s.logger)
}
s.incomingRPCLimiter.Run(&lib.StopChannelContext{StopCh: s.shutdownCh}) s.incomingRPCLimiter.Run(&lib.StopChannelContext{StopCh: s.shutdownCh})
var recorder *middleware.RequestRecorder var recorder *middleware.RequestRecorder
@ -1696,7 +1689,7 @@ func (s *Server) ReloadConfig(config ReloadableConfig) error {
s.rpcLimiter.Store(rate.NewLimiter(config.RPCRateLimit, config.RPCMaxBurst)) s.rpcLimiter.Store(rate.NewLimiter(config.RPCRateLimit, config.RPCMaxBurst))
if config.RequestLimits != nil { if config.RequestLimits != nil {
s.incomingRPCLimiter.UpdateConfig(*s.convertConsulConfigToRateLimitHandlerConfig(*config.RequestLimits, nil)) s.incomingRPCLimiter.UpdateConfig(*convertConsulConfigToRateLimitHandlerConfig(*config.RequestLimits, nil))
} }
s.rpcConnLimiter.SetConfig(connlimit.Config{ s.rpcConnLimiter.SetConfig(connlimit.Config{
@ -1849,9 +1842,25 @@ func (s *Server) hcpServerStatus(deps Deps) hcp.StatusCallback {
} }
} }
// convertConsulConfigToRateLimitHandlerConfig creates a rate limite handler config func ConfiguredIncomingRPCLimiter(serverLogger hclog.InterceptLogger, consulCfg *Config) *rpcRate.Handler {
// from the relevant fields in the consul runtime config. mlCfg := &multilimiter.Config{ReconcileCheckLimit: 30 * time.Second, ReconcileCheckInterval: time.Second}
func (s *Server) convertConsulConfigToRateLimitHandlerConfig(limitsConfig RequestLimits, multilimiterConfig *multilimiter.Config) *rpcRate.HandlerConfig { limitsConfig := &RequestLimits{
Mode: rpcRate.RequestLimitsModeFromNameWithDefault(consulCfg.RequestLimitsMode),
ReadRate: consulCfg.RequestLimitsReadRate,
WriteRate: consulCfg.RequestLimitsWriteRate,
}
rateLimiterConfig := convertConsulConfigToRateLimitHandlerConfig(*limitsConfig, mlCfg)
incomingRPCLimiter := rpcRate.NewHandler(
*rateLimiterConfig,
serverLogger.Named("rpc-rate-limit"),
)
return incomingRPCLimiter
}
func convertConsulConfigToRateLimitHandlerConfig(limitsConfig RequestLimits, multilimiterConfig *multilimiter.Config) *rpcRate.HandlerConfig {
hc := &rpcRate.HandlerConfig{ hc := &rpcRate.HandlerConfig{
GlobalMode: limitsConfig.Mode, GlobalMode: limitsConfig.Mode,
GlobalReadConfig: multilimiter.LimiterConfig{ GlobalReadConfig: multilimiter.LimiterConfig{
@ -1870,11 +1879,6 @@ func (s *Server) convertConsulConfigToRateLimitHandlerConfig(limitsConfig Reques
return hc return hc
} }
// IncomingRPCLimiter returns the server's configured rate limit handler for
// incoming RPCs. This is necessary because the external gRPC server is created
// by the agent (as it is also used for xDS).
func (s *Server) IncomingRPCLimiter() rpcRate.RequestLimitsHandler { return s.incomingRPCLimiter }
// peersInfoContent is used to help operators understand what happened to the // peersInfoContent is used to help operators understand what happened to the
// peers.json file. This is written to a file called peers.info in the same // peers.json file. This is written to a file called peers.info in the same
// location. // location.

View File

@ -334,7 +334,7 @@ func newServerWithDeps(t *testing.T, c *Config, deps Deps) (*Server, error) {
} }
} }
grpcServer := external.NewServer(deps.Logger.Named("grpc.external"), nil, deps.TLSConfigurator, rpcRate.NullRequestLimitsHandler()) grpcServer := external.NewServer(deps.Logger.Named("grpc.external"), nil, deps.TLSConfigurator, rpcRate.NullRequestLimitsHandler())
srv, err := NewServer(c, deps, grpcServer) srv, err := NewServer(c, deps, grpcServer, nil, deps.Logger)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -1241,7 +1241,7 @@ func TestServer_RPC_MetricsIntercept_Off(t *testing.T) {
} }
} }
s1, err := NewServer(conf, deps, grpc.NewServer()) s1, err := NewServer(conf, deps, grpc.NewServer(), nil, deps.Logger)
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@ -1279,7 +1279,7 @@ func TestServer_RPC_MetricsIntercept_Off(t *testing.T) {
return nil return nil
} }
s2, err := NewServer(conf, deps, grpc.NewServer()) s2, err := NewServer(conf, deps, grpc.NewServer(), nil, deps.Logger)
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@ -1313,7 +1313,7 @@ func TestServer_RPC_RequestRecorder(t *testing.T) {
deps := newDefaultDeps(t, conf) deps := newDefaultDeps(t, conf)
deps.NewRequestRecorderFunc = nil deps.NewRequestRecorderFunc = nil
s1, err := NewServer(conf, deps, grpc.NewServer()) s1, err := NewServer(conf, deps, grpc.NewServer(), nil, deps.Logger)
require.Error(t, err, "need err when provider func is nil") require.Error(t, err, "need err when provider func is nil")
require.Equal(t, err.Error(), "cannot initialize server without an RPC request recorder provider") require.Equal(t, err.Error(), "cannot initialize server without an RPC request recorder provider")
@ -1332,7 +1332,7 @@ func TestServer_RPC_RequestRecorder(t *testing.T) {
return nil return nil
} }
s2, err := NewServer(conf, deps, grpc.NewServer()) s2, err := NewServer(conf, deps, grpc.NewServer(), nil, deps.Logger)
require.Error(t, err, "need err when RequestRecorder is nil") require.Error(t, err, "need err when RequestRecorder is nil")
require.Equal(t, err.Error(), "cannot initialize server with a nil RPC request recorder") require.Equal(t, err.Error(), "cannot initialize server with a nil RPC request recorder")

View File

@ -41,7 +41,7 @@ func ServerRateLimiterMiddleware(limiter rate.RequestLimitsHandler, panicHandler
err := limiter.Allow(rate.Operation{ err := limiter.Allow(rate.Operation{
Name: info.FullMethodName, Name: info.FullMethodName,
SourceAddr: peer.Addr, SourceAddr: peer.Addr,
// TODO: operation type. // TODO: add operation type from https://github.com/hashicorp/consul/pull/15564
}) })
switch { switch {

View File

@ -1594,7 +1594,7 @@ func newTestServer(t *testing.T, cb func(conf *consul.Config)) testingServer {
deps := newDefaultDeps(t, conf) deps := newDefaultDeps(t, conf)
externalGRPCServer := external.NewServer(deps.Logger, nil, deps.TLSConfigurator, rate.NullRequestLimitsHandler()) externalGRPCServer := external.NewServer(deps.Logger, nil, deps.TLSConfigurator, rate.NullRequestLimitsHandler())
server, err := consul.NewServer(conf, deps, externalGRPCServer) server, err := consul.NewServer(conf, deps, externalGRPCServer, nil, deps.Logger)
require.NoError(t, err) require.NoError(t, err)
t.Cleanup(func() { t.Cleanup(func() {
require.NoError(t, server.Shutdown()) require.NoError(t, server.Shutdown())