Wire in rate limiter to handle internal and external gRPC calls (#15857)

parent 1586c7ba10
commit 006138beb4
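Summary: the incoming-RPC rate limiter is no longer built lazily inside consul.NewServer. On server agents, the agent now constructs a single handler up front via the new consul.ConfiguredIncomingRPCLimiter, hands the same instance to both the external gRPC server and consul.NewServer, and late-binds the leader-status source with Register once the server exists. Client agents pass the no-op rpcRate.NullRequestLimitsHandler(). A condensed sketch of the wiring, using exactly the identifiers from the agent.go hunks below (an excerpt of the diff, not compilable on its own):

    if c.ServerMode {
        serverLogger := a.baseDeps.Logger.NamedIntercept(logging.ConsulServer)
        // One limiter instance is shared by internal RPCs and external gRPC.
        incomingRPCLimiter := consul.ConfiguredIncomingRPCLimiter(serverLogger, consulCfg)

        a.externalGRPCServer = external.NewServer(
            a.logger.Named("grpc.external"), metrics.Default(), a.tlsConfigurator, incomingRPCLimiter)

        server, err := consul.NewServer(consulCfg, a.baseDeps.Deps, a.externalGRPCServer, incomingRPCLimiter, serverLogger)
        // ...
        incomingRPCLimiter.Register(server) // leader status wired in after the server exists
    } else {
        // Client agents never rate-limit incoming gRPC calls.
        a.externalGRPCServer = external.NewServer(
            a.logger.Named("grpc.external"), metrics.Default(), a.tlsConfigurator, rpcRate.NullRequestLimitsHandler())
    }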
@@ -16,6 +16,7 @@ Thumbs.db
 .idea
 .vscode
 __debug_bin
+coverage.out
 
 # MacOS
 .DS_Store
@@ -565,22 +565,6 @@ func (a *Agent) Start(ctx context.Context) error {
 		return fmt.Errorf("Failed to load TLS configurations after applying auto-config settings: %w", err)
 	}
 
-	// gRPC calls are only rate-limited on server, not client agents.
-	var grpcRateLimiter rpcRate.RequestLimitsHandler
-	grpcRateLimiter = rpcRate.NullRequestLimitsHandler()
-	if s, ok := a.delegate.(*consul.Server); ok {
-		grpcRateLimiter = s.IncomingRPCLimiter()
-	}
-
-	// This needs to happen after the initial auto-config is loaded, because TLS
-	// can only be configured on the gRPC server at the point of creation.
-	a.externalGRPCServer = external.NewServer(
-		a.logger.Named("grpc.external"),
-		metrics.Default(),
-		a.tlsConfigurator,
-		grpcRateLimiter,
-	)
-
 	if err := a.startLicenseManager(ctx); err != nil {
 		return err
 	}
@@ -618,10 +602,21 @@ func (a *Agent) Start(ctx context.Context) error {
 
 	// Setup either the client or the server.
 	if c.ServerMode {
-		server, err := consul.NewServer(consulCfg, a.baseDeps.Deps, a.externalGRPCServer)
+		serverLogger := a.baseDeps.Logger.NamedIntercept(logging.ConsulServer)
+		incomingRPCLimiter := consul.ConfiguredIncomingRPCLimiter(serverLogger, consulCfg)
+
+		a.externalGRPCServer = external.NewServer(
+			a.logger.Named("grpc.external"),
+			metrics.Default(),
+			a.tlsConfigurator,
+			incomingRPCLimiter,
+		)
+
+		server, err := consul.NewServer(consulCfg, a.baseDeps.Deps, a.externalGRPCServer, incomingRPCLimiter, serverLogger)
 		if err != nil {
 			return fmt.Errorf("Failed to start Consul server: %v", err)
 		}
+		incomingRPCLimiter.Register(server)
 		a.delegate = server
 
 		if a.config.PeeringEnabled && a.config.ConnectEnabled {
@@ -642,6 +637,13 @@ func (a *Agent) Start(ctx context.Context) error {
 		}
 
 	} else {
+		a.externalGRPCServer = external.NewServer(
+			a.logger.Named("grpc.external"),
+			metrics.Default(),
+			a.tlsConfigurator,
+			rpcRate.NullRequestLimitsHandler(),
+		)
+
 		client, err := consul.NewClient(consulCfg, a.baseDeps.Deps)
 		if err != nil {
 			return fmt.Errorf("Failed to start Consul client: %v", err)
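Note: Register exists to break a construction cycle. The handler has to exist before consul.NewServer is called (both the external gRPC server and the Server take it as an argument), but leadership can only be answered by the Server once it exists. A stripped-down, runnable illustration of this late-binding pattern; only the Handler, Register, and LeaderStatusProvider names come from the diff, the rest is illustrative:

    package main

    import "fmt"

    type LeaderStatusProvider interface{ IsLeader() bool }

    type Handler struct{ leaderStatusProvider LeaderStatusProvider }

    // Register late-binds the leader-status source, mirroring Handler.Register above.
    func (h *Handler) Register(lsp LeaderStatusProvider) { h.leaderStatusProvider = lsp }

    type Server struct{ leader bool }

    func (s *Server) IsLeader() bool { return s.leader }

    func main() {
        h := &Handler{}            // built first; the server's constructor needs it
        s := &Server{leader: true} // built second, taking h as a dependency
        h.Register(s)              // cycle broken: h learns about s afterwards
        fmt.Println(h.leaderStatusProvider.IsLeader()) // true
    }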
@@ -563,7 +563,7 @@ func TestCAManager_Initialize_Logging(t *testing.T) {
 	deps := newDefaultDeps(t, conf1)
 	deps.Logger = logger
 
-	s1, err := NewServer(conf1, deps, grpc.NewServer())
+	s1, err := NewServer(conf1, deps, grpc.NewServer(), nil, logger)
 	require.NoError(t, err)
 	defer s1.Shutdown()
 	testrpc.WaitForLeader(t, s1.RPC, "dc1")
@@ -1556,7 +1556,7 @@ func TestLeader_ConfigEntryBootstrap_Fail(t *testing.T) {
 	deps := newDefaultDeps(t, config)
 	deps.Logger = logger
 
-	srv, err := NewServer(config, deps, grpc.NewServer())
+	srv, err := NewServer(config, deps, grpc.NewServer(), nil, logger)
 	require.NoError(t, err)
 	defer srv.Shutdown()
@@ -4,6 +4,7 @@ package rate
 import (
 	"context"
 	"errors"
+	"fmt"
 	"net"
 	"reflect"
 	"sync/atomic"
@@ -112,11 +113,14 @@ type RequestLimitsHandler interface {
 
 // Handler enforces rate limits for incoming RPCs.
 type Handler struct {
-	cfg      *atomic.Pointer[HandlerConfig]
-	delegate HandlerDelegate
+	cfg                  *atomic.Pointer[HandlerConfig]
+	leaderStatusProvider LeaderStatusProvider
 
 	limiter multilimiter.RateLimiter
-	logger  hclog.Logger
+
+	// TODO: replace this with the real logger.
+	// https://github.com/hashicorp/consul/pull/15822
+	logger hclog.Logger
 }
 
 type HandlerConfig struct {
@@ -135,7 +139,8 @@
 	GlobalReadConfig multilimiter.LimiterConfig
 }
 
-type HandlerDelegate interface {
+//go:generate mockery --name LeaderStatusProvider --inpackage --filename mock_LeaderStatusProvider_test.go
+type LeaderStatusProvider interface {
 	// IsLeader is used to determine whether the operation is being performed
 	// against the cluster leader, such that if it can _only_ be performed by
 	// the leader (e.g. write operations) we don't tell clients to retry against
@@ -143,16 +148,18 @@ type HandlerDelegate interface {
 	IsLeader() bool
 }
 
-func NewHandlerWithLimiter(cfg HandlerConfig, delegate HandlerDelegate,
-	limiter multilimiter.RateLimiter, logger hclog.Logger) *Handler {
+func NewHandlerWithLimiter(
+	cfg HandlerConfig,
+	limiter multilimiter.RateLimiter,
+	logger hclog.Logger) *Handler {
+
 	limiter.UpdateConfig(cfg.GlobalWriteConfig, globalWrite)
 	limiter.UpdateConfig(cfg.GlobalReadConfig, globalRead)
 
 	h := &Handler{
-		cfg:      new(atomic.Pointer[HandlerConfig]),
-		delegate: delegate,
-		limiter:  limiter,
-		logger:   logger,
+		cfg:     new(atomic.Pointer[HandlerConfig]),
+		limiter: limiter,
+		logger:  logger,
 	}
 	h.cfg.Store(&cfg)
 
@@ -160,9 +167,9 @@ func NewHandlerWithLimiter(cfg HandlerConfig, delegate HandlerDelegate,
 }
 
 // NewHandler creates a new RPC rate limit handler.
-func NewHandler(cfg HandlerConfig, delegate HandlerDelegate, logger hclog.Logger) *Handler {
+func NewHandler(cfg HandlerConfig, logger hclog.Logger) *Handler {
 	limiter := multilimiter.NewMultiLimiter(cfg.Config)
-	return NewHandlerWithLimiter(cfg, delegate, limiter, logger)
+	return NewHandlerWithLimiter(cfg, limiter, logger)
 }
 
 // Run the limiter cleanup routine until the given context is canceled.
@@ -175,14 +182,45 @@ func (h *Handler) Run(ctx context.Context) {
 // Allow returns an error if the given operation is not allowed to proceed
 // because of an exhausted rate-limit.
 func (h *Handler) Allow(op Operation) error {
+	if h.leaderStatusProvider == nil {
+		h.logger.Error("leaderStatusProvider required to be set via Register(). bailing on rate limiter")
+		return nil
+		// TODO: panic and make sure to use the server's recovery handler
+		// panic("leaderStatusProvider required to be set via Register(..)")
+	}
+
 	cfg := h.cfg.Load()
 	if cfg.GlobalMode == ModeDisabled {
 		return nil
 	}
 
-	if !h.limiter.Allow(globalWrite) {
-		// TODO(NET-1383): actually implement the rate limiting logic and replace this returned nil.
-		return nil
+	for _, l := range h.limits(op) {
+		if l.mode == ModeDisabled {
+			continue
+		}
+
+		if h.limiter.Allow(l.ent) {
+			continue
+		}
+
+		// TODO: metrics.
+		// TODO: is this the correct log-level?
+
+		enforced := l.mode == ModeEnforcing
+		h.logger.Trace("RPC exceeded allowed rate limit",
+			"rpc", op.Name,
+			"source_addr", op.SourceAddr.String(),
+			"limit_type", l.desc,
+			"limit_enforced", enforced,
+		)
+
+		if enforced {
+			if h.leaderStatusProvider.IsLeader() && op.Type == OperationTypeWrite {
+				return ErrRetryLater
+			}
+			return ErrRetryElsewhere
+		}
 	}
 	return nil
 }
@@ -202,6 +240,48 @@ func (h *Handler) UpdateConfig(cfg HandlerConfig) {
 	}
 }
 
+func (h *Handler) Register(leaderStatusProvider LeaderStatusProvider) {
+	h.leaderStatusProvider = leaderStatusProvider
+}
+
+type limit struct {
+	mode Mode
+	ent  multilimiter.LimitedEntity
+	desc string
+}
+
+// limits returns the limits to check for the given operation (e.g. global +
+// ip-based + tenant-based).
+func (h *Handler) limits(op Operation) []limit {
+	limits := make([]limit, 0)
+
+	if global := h.globalLimit(op); global != nil {
+		limits = append(limits, *global)
+	}
+
+	return limits
+}
+
+func (h *Handler) globalLimit(op Operation) *limit {
+	if op.Type == OperationTypeExempt {
+		return nil
+	}
+	cfg := h.cfg.Load()
+
+	lim := &limit{mode: cfg.GlobalMode}
+	switch op.Type {
+	case OperationTypeRead:
+		lim.desc = "global/read"
+		lim.ent = globalRead
+	case OperationTypeWrite:
+		lim.desc = "global/write"
+		lim.ent = globalWrite
+	default:
+		panic(fmt.Sprintf("unknown operation type %d", op.Type))
+	}
+	return lim
+}
+
 var (
 	// globalWrite identifies the global rate limit applied to write operations.
 	globalWrite = globalLimit("global.write")
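The Allow/limits hunks above reduce each configured limit to a small decision table: a disabled limit is skipped, an available token passes, and otherwise the event is logged and, only in enforcing mode, rejected, with ErrRetryLater reserved for writes hitting the leader (no other server could take them) and ErrRetryElsewhere for everything else. A self-contained toy of that decision logic, using golang.org/x/time/rate as a stand-in for the internal multilimiter (that substitution, and the error strings, are this sketch's assumptions):

    package main

    import (
        "errors"
        "fmt"

        "golang.org/x/time/rate"
    )

    var (
        ErrRetryElsewhere = errors.New("rate limit exceeded, try another server")
        ErrRetryLater     = errors.New("rate limit exceeded, try again later")
    )

    type Mode int

    const (
        ModeDisabled Mode = iota
        ModePermissive
        ModeEnforcing
    )

    // allow mirrors the per-limit branch in Handler.Allow above.
    func allow(mode Mode, bucket *rate.Limiter, isLeader, isWrite bool) error {
        if mode == ModeDisabled || bucket.Allow() {
            return nil
        }
        fmt.Println("RPC exceeded allowed rate limit") // logged in permissive and enforcing modes
        if mode == ModeEnforcing {
            if isLeader && isWrite {
                return ErrRetryLater // only the leader can service this write
            }
            return ErrRetryElsewhere
        }
        return nil // permissive: log but let the RPC through
    }

    func main() {
        exhausted := rate.NewLimiter(0, 0) // a bucket that never has tokens
        fmt.Println(allow(ModePermissive, exhausted, false, true)) // <nil>
        fmt.Println(allow(ModeEnforcing, exhausted, true, true))   // retry later
        fmt.Println(allow(ModeEnforcing, exhausted, false, true))  // retry elsewhere
    }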
@@ -1,15 +1,233 @@
 package rate
 
 import (
+	"bytes"
+	"context"
 	"net"
 	"net/netip"
 	"testing"
 
+	"github.com/hashicorp/consul/agent/consul/multilimiter"
+	"github.com/stretchr/testify/mock"
+	"github.com/stretchr/testify/require"
+
 	"github.com/hashicorp/go-hclog"
-	mock "github.com/stretchr/testify/mock"
-
-	"github.com/hashicorp/consul/agent/consul/multilimiter"
 )
 
+//
+// Revisit test when handler.go:189 TODO implemented
+//
+// func TestHandler_Allow_PanicsWhenLeaderStatusProviderNotRegistered(t *testing.T) {
+// 	defer func() {
+// 		err := recover()
+// 		if err == nil {
+// 			t.Fatal("Run should panic")
+// 		}
+// 	}()
+
+// 	handler := NewHandler(HandlerConfig{}, hclog.NewNullLogger())
+// 	handler.Allow(Operation{})
+// 	// intentionally skip handler.Register(...)
+// }
+
+func TestHandler(t *testing.T) {
+	var (
+		rpcName    = "Foo.Bar"
+		sourceAddr = net.TCPAddrFromAddrPort(netip.MustParseAddrPort("1.2.3.4:5678"))
+	)
+
+	type limitCheck struct {
+		limit multilimiter.LimitedEntity
+		allow bool
+	}
+	testCases := map[string]struct {
+		op         Operation
+		globalMode Mode
+		checks     []limitCheck
+		isLeader   bool
+		expectErr  error
+		expectLog  bool
+	}{
+		"operation exempt from limiting": {
+			op: Operation{
+				Type:       OperationTypeExempt,
+				Name:       rpcName,
+				SourceAddr: sourceAddr,
+			},
+			globalMode: ModeEnforcing,
+			checks:     []limitCheck{},
+			expectErr:  nil,
+			expectLog:  false,
+		},
+		"global write limit disabled": {
+			op: Operation{
+				Type:       OperationTypeWrite,
+				Name:       rpcName,
+				SourceAddr: sourceAddr,
+			},
+			globalMode: ModeDisabled,
+			checks:     []limitCheck{},
+			expectErr:  nil,
+			expectLog:  false,
+		},
+		"global write limit within allowance": {
+			op: Operation{
+				Type:       OperationTypeWrite,
+				Name:       rpcName,
+				SourceAddr: sourceAddr,
+			},
+			globalMode: ModeEnforcing,
+			checks: []limitCheck{
+				{limit: globalWrite, allow: true},
+			},
+			expectErr: nil,
+			expectLog: false,
+		},
+		"global write limit exceeded (permissive)": {
+			op: Operation{
+				Type:       OperationTypeWrite,
+				Name:       rpcName,
+				SourceAddr: sourceAddr,
+			},
+			globalMode: ModePermissive,
+			checks: []limitCheck{
+				{limit: globalWrite, allow: false},
+			},
+			expectErr: nil,
+			expectLog: true,
+		},
+		"global write limit exceeded (enforcing, leader)": {
+			op: Operation{
+				Type:       OperationTypeWrite,
+				Name:       rpcName,
+				SourceAddr: sourceAddr,
+			},
+			globalMode: ModeEnforcing,
+			checks: []limitCheck{
+				{limit: globalWrite, allow: false},
+			},
+			isLeader:  true,
+			expectErr: ErrRetryLater,
+			expectLog: true,
+		},
+		"global write limit exceeded (enforcing, follower)": {
+			op: Operation{
+				Type:       OperationTypeWrite,
+				Name:       rpcName,
+				SourceAddr: sourceAddr,
+			},
+			globalMode: ModeEnforcing,
+			checks: []limitCheck{
+				{limit: globalWrite, allow: false},
+			},
+			isLeader:  false,
+			expectErr: ErrRetryElsewhere,
+			expectLog: true,
+		},
+		"global read limit disabled": {
+			op: Operation{
+				Type:       OperationTypeRead,
+				Name:       rpcName,
+				SourceAddr: sourceAddr,
+			},
+			globalMode: ModeDisabled,
+			checks:     []limitCheck{},
+			expectErr:  nil,
+			expectLog:  false,
+		},
+		"global read limit within allowance": {
+			op: Operation{
+				Type:       OperationTypeRead,
+				Name:       rpcName,
+				SourceAddr: sourceAddr,
+			},
+			globalMode: ModeEnforcing,
+			checks: []limitCheck{
+				{limit: globalRead, allow: true},
+			},
+			expectErr: nil,
+			expectLog: false,
+		},
+		"global read limit exceeded (permissive)": {
+			op: Operation{
+				Type:       OperationTypeRead,
+				Name:       rpcName,
+				SourceAddr: sourceAddr,
+			},
+			globalMode: ModePermissive,
+			checks: []limitCheck{
+				{limit: globalRead, allow: false},
+			},
+			expectErr: nil,
+			expectLog: true,
+		},
+		"global read limit exceeded (enforcing, leader)": {
+			op: Operation{
+				Type:       OperationTypeRead,
+				Name:       rpcName,
+				SourceAddr: sourceAddr,
+			},
+			globalMode: ModeEnforcing,
+			checks: []limitCheck{
+				{limit: globalRead, allow: false},
+			},
+			isLeader:  true,
+			expectErr: ErrRetryElsewhere,
+			expectLog: true,
+		},
+		"global read limit exceeded (enforcing, follower)": {
+			op: Operation{
+				Type:       OperationTypeRead,
+				Name:       rpcName,
+				SourceAddr: sourceAddr,
+			},
+			globalMode: ModeEnforcing,
+			checks: []limitCheck{
+				{limit: globalRead, allow: false},
+			},
+			isLeader:  false,
+			expectErr: ErrRetryElsewhere,
+			expectLog: true,
+		},
+	}
+	for desc, tc := range testCases {
+		t.Run(desc, func(t *testing.T) {
+			limiter := newMockLimiter(t)
+			limiter.On("UpdateConfig", mock.Anything, mock.Anything).Return()
+			for _, c := range tc.checks {
+				limiter.On("Allow", c.limit).Return(c.allow)
+			}
+
+			leaderStatusProvider := NewMockLeaderStatusProvider(t)
+			leaderStatusProvider.On("IsLeader").Return(tc.isLeader).Maybe()
+
+			var output bytes.Buffer
+			logger := hclog.NewInterceptLogger(&hclog.LoggerOptions{
+				Level:  hclog.Trace,
+				Output: &output,
+			})
+
+			handler := NewHandlerWithLimiter(
+				HandlerConfig{
+					GlobalMode: tc.globalMode,
+				},
+				limiter,
+				logger,
+			)
+			handler.Register(leaderStatusProvider)
+
+			require.Equal(t, tc.expectErr, handler.Allow(tc.op))
+
+			if tc.expectLog {
+				require.Contains(t, output.String(), "RPC exceeded allowed rate limit")
+			} else {
+				require.Zero(t, output.Len(), "expected no logs to be emitted")
+			}
+		})
+	}
+}
+
 func TestNewHandlerWithLimiter_CallsUpdateConfig(t *testing.T) {
 	mockRateLimiter := multilimiter.NewMockRateLimiter(t)
 	mockRateLimiter.On("UpdateConfig", mock.Anything, mock.Anything).Return()
@@ -22,7 +240,7 @@ func TestNewHandlerWithLimiter_CallsUpdateConfig(t *testing.T) {
 	}
 
 	logger := hclog.NewNullLogger()
-	NewHandlerWithLimiter(*cfg, nil, mockRateLimiter, logger)
+	NewHandlerWithLimiter(*cfg, mockRateLimiter, logger)
 	mockRateLimiter.AssertNumberOfCalls(t, "UpdateConfig", 2)
 }
 
@@ -83,7 +301,7 @@ func TestUpdateConfig(t *testing.T) {
 	mockRateLimiter := multilimiter.NewMockRateLimiter(t)
 	mockRateLimiter.On("UpdateConfig", mock.Anything, mock.Anything).Return()
 	logger := hclog.NewNullLogger()
-	handler := NewHandlerWithLimiter(*cfg, nil, mockRateLimiter, logger)
+	handler := NewHandlerWithLimiter(*cfg, mockRateLimiter, logger)
 	mockRateLimiter.Calls = nil
 	tc.configModFunc(cfg)
 	handler.UpdateConfig(*cfg)
@@ -139,7 +357,10 @@ func TestAllow(t *testing.T) {
 			}
 			mockRateLimiter.On("UpdateConfig", mock.Anything, mock.Anything).Return()
 			logger := hclog.NewNullLogger()
-			handler := NewHandlerWithLimiter(*tc.cfg, nil, mockRateLimiter, logger)
+			delegate := NewMockLeaderStatusProvider(t)
+			delegate.On("IsLeader").Return(true).Maybe()
+			handler := NewHandlerWithLimiter(*tc.cfg, mockRateLimiter, logger)
+			handler.Register(delegate)
 			addr := net.TCPAddrFromAddrPort(netip.MustParseAddrPort("127.0.0.1:1234"))
 			mockRateLimiter.Calls = nil
 			handler.Allow(Operation{Name: "test", SourceAddr: addr})
@@ -147,3 +368,24 @@ func TestAllow(t *testing.T) {
 		})
 	}
 }
+
+var _ multilimiter.RateLimiter = (*mockLimiter)(nil)
+
+func newMockLimiter(t *testing.T) *mockLimiter {
+	l := &mockLimiter{}
+	l.Mock.Test(t)
+
+	t.Cleanup(func() { l.AssertExpectations(t) })
+
+	return l
+}
+
+type mockLimiter struct {
+	mock.Mock
+}
+
+func (m *mockLimiter) Allow(v multilimiter.LimitedEntity) bool { return m.Called(v).Bool(0) }
+func (m *mockLimiter) Run(ctx context.Context)                 { m.Called(ctx) }
+func (m *mockLimiter) UpdateConfig(cfg multilimiter.LimiterConfig, prefix []byte) {
+	m.Called(cfg, prefix)
+}
@@ -0,0 +1,38 @@
+// Code generated by mockery v2.12.2. DO NOT EDIT.
+
+package rate
+
+import (
+	testing "testing"
+
+	mock "github.com/stretchr/testify/mock"
+)
+
+// MockLeaderStatusProvider is an autogenerated mock type for the LeaderStatusProvider type
+type MockLeaderStatusProvider struct {
+	mock.Mock
+}
+
+// IsLeader provides a mock function with given fields:
+func (_m *MockLeaderStatusProvider) IsLeader() bool {
+	ret := _m.Called()
+
+	var r0 bool
+	if rf, ok := ret.Get(0).(func() bool); ok {
+		r0 = rf()
+	} else {
+		r0 = ret.Get(0).(bool)
+	}
+
+	return r0
+}
+
+// NewMockLeaderStatusProvider creates a new instance of MockLeaderStatusProvider. It also registers the testing.TB interface on the mock and a cleanup function to assert the mocks expectations.
+func NewMockLeaderStatusProvider(t testing.TB) *MockLeaderStatusProvider {
+	mock := &MockLeaderStatusProvider{}
+	mock.Mock.Test(t)
+
+	t.Cleanup(func() { mock.AssertExpectations(t) })
+
+	return mock
+}
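Note: the file above is generated by mockery v2.12.2 from the //go:generate directive added to handler.go, not written by hand; rerunning go generate in the rate package rebuilds it after the LeaderStatusProvider interface changes. Typical use in a test, mirroring TestHandler above (a sketch that assumes it lives inside package rate; the test name is illustrative):

    func TestLeaderStatusMock(t *testing.T) {
        limiter := newMockLimiter(t)
        limiter.On("UpdateConfig", mock.Anything, mock.Anything).Return()

        lsp := NewMockLeaderStatusProvider(t)   // cleanup assertions registered on t
        lsp.On("IsLeader").Return(true).Maybe() // .Maybe() tolerates zero calls

        h := NewHandlerWithLimiter(HandlerConfig{}, limiter, hclog.NewNullLogger())
        h.Register(lsp)

        // Exempt operations bypass every limit, so no error is expected.
        require.NoError(t, h.Allow(Operation{Type: OperationTypeExempt}))
    }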
@@ -407,7 +407,7 @@ type connHandler interface {
 
 // NewServer is used to construct a new Consul server from the configuration
 // and extra options, potentially returning an error.
-func NewServer(config *Config, flat Deps, externalGRPCServer *grpc.Server) (*Server, error) {
+func NewServer(config *Config, flat Deps, externalGRPCServer *grpc.Server, incomingRPCLimiter rpcRate.RequestLimitsHandler, serverLogger hclog.InterceptLogger) (*Server, error) {
 	logger := flat.Logger
 	if err := config.CheckProtocolVersion(); err != nil {
 		return nil, err
@@ -428,7 +428,6 @@ func NewServer(config *Config, flat Deps, externalGRPCServer *grpc.Server) (*Ser
 	// Create the shutdown channel - this is closed but never written to.
 	shutdownCh := make(chan struct{})
 
-	serverLogger := flat.Logger.NamedIntercept(logging.ConsulServer)
 	loggers := newLoggerStore(serverLogger)
 
 	fsmDeps := fsm.Deps{
@@ -439,6 +438,10 @@ func NewServer(config *Config, flat Deps, externalGRPCServer *grpc.Server) (*Ser
 		Publisher: flat.EventPublisher,
 	}
 
+	if incomingRPCLimiter == nil {
+		incomingRPCLimiter = rpcRate.NullRequestLimitsHandler()
+	}
+
 	// Create server.
 	s := &Server{
 		config: config,
@@ -463,6 +466,7 @@ func NewServer(config *Config, flat Deps, externalGRPCServer *grpc.Server) (*Ser
 		aclAuthMethodValidators: authmethod.NewCache(),
 		fsm:                     fsm.NewFromDeps(fsmDeps),
 		publisher:               flat.EventPublisher,
+		incomingRPCLimiter:      incomingRPCLimiter,
 	}
 
 	s.hcpManager = hcp.NewManager(hcp.ManagerConfig{
@@ -471,17 +475,6 @@ func NewServer(config *Config, flat Deps, externalGRPCServer *grpc.Server) (*Ser
 		Logger: logger.Named("hcp_manager"),
 	})
 
-	// TODO(NET-1380, NET-1381): thread this into the net/rpc and gRPC interceptors.
-	if s.incomingRPCLimiter == nil {
-		mlCfg := &multilimiter.Config{ReconcileCheckLimit: 30 * time.Second, ReconcileCheckInterval: time.Second}
-		limitsConfig := &RequestLimits{
-			Mode:      rpcRate.RequestLimitsModeFromNameWithDefault(config.RequestLimitsMode),
-			ReadRate:  config.RequestLimitsReadRate,
-			WriteRate: config.RequestLimitsWriteRate,
-		}
-
-		s.incomingRPCLimiter = rpcRate.NewHandler(*s.convertConsulConfigToRateLimitHandlerConfig(*limitsConfig, mlCfg), s, s.logger)
-	}
 	s.incomingRPCLimiter.Run(&lib.StopChannelContext{StopCh: s.shutdownCh})
 
 	var recorder *middleware.RequestRecorder
@@ -1696,7 +1689,7 @@ func (s *Server) ReloadConfig(config ReloadableConfig) error {
 	s.rpcLimiter.Store(rate.NewLimiter(config.RPCRateLimit, config.RPCMaxBurst))
 
 	if config.RequestLimits != nil {
-		s.incomingRPCLimiter.UpdateConfig(*s.convertConsulConfigToRateLimitHandlerConfig(*config.RequestLimits, nil))
+		s.incomingRPCLimiter.UpdateConfig(*convertConsulConfigToRateLimitHandlerConfig(*config.RequestLimits, nil))
 	}
 
 	s.rpcConnLimiter.SetConfig(connlimit.Config{
@@ -1849,9 +1842,25 @@ func (s *Server) hcpServerStatus(deps Deps) hcp.StatusCallback {
 	}
 }
 
-// convertConsulConfigToRateLimitHandlerConfig creates a rate limite handler config
-// from the relevant fields in the consul runtime config.
-func (s *Server) convertConsulConfigToRateLimitHandlerConfig(limitsConfig RequestLimits, multilimiterConfig *multilimiter.Config) *rpcRate.HandlerConfig {
+func ConfiguredIncomingRPCLimiter(serverLogger hclog.InterceptLogger, consulCfg *Config) *rpcRate.Handler {
+	mlCfg := &multilimiter.Config{ReconcileCheckLimit: 30 * time.Second, ReconcileCheckInterval: time.Second}
+	limitsConfig := &RequestLimits{
+		Mode:      rpcRate.RequestLimitsModeFromNameWithDefault(consulCfg.RequestLimitsMode),
+		ReadRate:  consulCfg.RequestLimitsReadRate,
+		WriteRate: consulCfg.RequestLimitsWriteRate,
+	}
+
+	rateLimiterConfig := convertConsulConfigToRateLimitHandlerConfig(*limitsConfig, mlCfg)
+
+	incomingRPCLimiter := rpcRate.NewHandler(
+		*rateLimiterConfig,
+		serverLogger.Named("rpc-rate-limit"),
+	)
+
+	return incomingRPCLimiter
+}
+
+func convertConsulConfigToRateLimitHandlerConfig(limitsConfig RequestLimits, multilimiterConfig *multilimiter.Config) *rpcRate.HandlerConfig {
 	hc := &rpcRate.HandlerConfig{
 		GlobalMode: limitsConfig.Mode,
 		GlobalReadConfig: multilimiter.LimiterConfig{
@@ -1870,11 +1879,6 @@ func (s *Server) convertConsulConfigToRateLimitHandlerConfig(limitsConfig Reques
 	return hc
 }
 
-// IncomingRPCLimiter returns the server's configured rate limit handler for
-// incoming RPCs. This is necessary because the external gRPC server is created
-// by the agent (as it is also used for xDS).
-func (s *Server) IncomingRPCLimiter() rpcRate.RequestLimitsHandler { return s.incomingRPCLimiter }
-
 // peersInfoContent is used to help operators understand what happened to the
 // peers.json file. This is written to a file called peers.info in the same
 // location.
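Note: because NewServer now substitutes the null handler when it receives a nil limiter (the @@ -439,6 hunk above), callers that do not care about rate limiting can simply pass nil; every test call site below relies on this. The fallback amounts to (hypothetical helper name, extracted purely for illustration):

    // limiterOrNull mirrors the nil-check added in NewServer: a nil limiter
    // becomes the no-op handler, under which every Allow call passes.
    func limiterOrNull(l rpcRate.RequestLimitsHandler) rpcRate.RequestLimitsHandler {
        if l == nil {
            return rpcRate.NullRequestLimitsHandler()
        }
        return l
    }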
@@ -334,7 +334,7 @@ func newServerWithDeps(t *testing.T, c *Config, deps Deps) (*Server, error) {
 		}
 	}
 	grpcServer := external.NewServer(deps.Logger.Named("grpc.external"), nil, deps.TLSConfigurator, rpcRate.NullRequestLimitsHandler())
-	srv, err := NewServer(c, deps, grpcServer)
+	srv, err := NewServer(c, deps, grpcServer, nil, deps.Logger)
 	if err != nil {
 		return nil, err
 	}
@@ -1241,7 +1241,7 @@ func TestServer_RPC_MetricsIntercept_Off(t *testing.T) {
 		}
 	}
 
-	s1, err := NewServer(conf, deps, grpc.NewServer())
+	s1, err := NewServer(conf, deps, grpc.NewServer(), nil, deps.Logger)
 	if err != nil {
 		t.Fatalf("err: %v", err)
 	}
@@ -1279,7 +1279,7 @@ func TestServer_RPC_MetricsIntercept_Off(t *testing.T) {
 		return nil
 	}
 
-	s2, err := NewServer(conf, deps, grpc.NewServer())
+	s2, err := NewServer(conf, deps, grpc.NewServer(), nil, deps.Logger)
 	if err != nil {
 		t.Fatalf("err: %v", err)
 	}
@@ -1313,7 +1313,7 @@ func TestServer_RPC_RequestRecorder(t *testing.T) {
 	deps := newDefaultDeps(t, conf)
 	deps.NewRequestRecorderFunc = nil
 
-	s1, err := NewServer(conf, deps, grpc.NewServer())
+	s1, err := NewServer(conf, deps, grpc.NewServer(), nil, deps.Logger)
 
 	require.Error(t, err, "need err when provider func is nil")
 	require.Equal(t, err.Error(), "cannot initialize server without an RPC request recorder provider")
@@ -1332,7 +1332,7 @@ func TestServer_RPC_RequestRecorder(t *testing.T) {
 		return nil
 	}
 
-	s2, err := NewServer(conf, deps, grpc.NewServer())
+	s2, err := NewServer(conf, deps, grpc.NewServer(), nil, deps.Logger)
 
 	require.Error(t, err, "need err when RequestRecorder is nil")
 	require.Equal(t, err.Error(), "cannot initialize server with a nil RPC request recorder")
@@ -41,7 +41,7 @@ func ServerRateLimiterMiddleware(limiter rate.RequestLimitsHandler, panicHandler
 		err := limiter.Allow(rate.Operation{
 			Name:       info.FullMethodName,
 			SourceAddr: peer.Addr,
-			// TODO: operation type.
+			// TODO: add operation type from https://github.com/hashicorp/consul/pull/15564
 		})
 
 		switch {
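The middleware hunk above is where Allow's errors finally surface to gRPC callers; the switch that maps them to status codes falls outside this hunk. A hypothetical, self-contained analogue written as a unary interceptor, with an assumed error-to-code mapping (ResourceExhausted/Unavailable are this sketch's choice, not taken from the diff, and the import path for package rate is assumed):

    package middleware // illustrative package; not the file in this diff

    import (
        "context"
        "errors"
        "net"

        "google.golang.org/grpc"
        "google.golang.org/grpc/codes"
        "google.golang.org/grpc/peer"
        "google.golang.org/grpc/status"

        "github.com/hashicorp/consul/agent/consul/rate" // assumed import path
    )

    // rateLimitUnaryInterceptor is a hypothetical name; the diff's real entry
    // point is ServerRateLimiterMiddleware above.
    func rateLimitUnaryInterceptor(limiter rate.RequestLimitsHandler) grpc.UnaryServerInterceptor {
        return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
            var addr net.Addr
            if p, ok := peer.FromContext(ctx); ok {
                addr = p.Addr // source address of the caller
            }
            err := limiter.Allow(rate.Operation{
                Name:       info.FullMethod,
                SourceAddr: addr,
            })
            switch {
            case err == nil:
                return handler(ctx, req)
            case errors.Is(err, rate.ErrRetryElsewhere):
                return nil, status.Error(codes.ResourceExhausted, err.Error())
            default: // e.g. rate.ErrRetryLater
                return nil, status.Error(codes.Unavailable, err.Error())
            }
        }
    }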
@@ -1594,7 +1594,7 @@ func newTestServer(t *testing.T, cb func(conf *consul.Config)) testingServer {
 	deps := newDefaultDeps(t, conf)
 	externalGRPCServer := external.NewServer(deps.Logger, nil, deps.TLSConfigurator, rate.NullRequestLimitsHandler())
 
-	server, err := consul.NewServer(conf, deps, externalGRPCServer)
+	server, err := consul.NewServer(conf, deps, externalGRPCServer, nil, deps.Logger)
 	require.NoError(t, err)
 	t.Cleanup(func() {
 		require.NoError(t, server.Shutdown())