Wire in rate limiter to handle internal and external gRPC calls (#15857)
This commit is contained in:
parent
1586c7ba10
commit
006138beb4
|
@ -16,6 +16,7 @@ Thumbs.db
|
||||||
.idea
|
.idea
|
||||||
.vscode
|
.vscode
|
||||||
__debug_bin
|
__debug_bin
|
||||||
|
coverage.out
|
||||||
|
|
||||||
# MacOS
|
# MacOS
|
||||||
.DS_Store
|
.DS_Store
|
||||||
|
|
|
@ -565,22 +565,6 @@ func (a *Agent) Start(ctx context.Context) error {
|
||||||
return fmt.Errorf("Failed to load TLS configurations after applying auto-config settings: %w", err)
|
return fmt.Errorf("Failed to load TLS configurations after applying auto-config settings: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// gRPC calls are only rate-limited on server, not client agents.
|
|
||||||
var grpcRateLimiter rpcRate.RequestLimitsHandler
|
|
||||||
grpcRateLimiter = rpcRate.NullRequestLimitsHandler()
|
|
||||||
if s, ok := a.delegate.(*consul.Server); ok {
|
|
||||||
grpcRateLimiter = s.IncomingRPCLimiter()
|
|
||||||
}
|
|
||||||
|
|
||||||
// This needs to happen after the initial auto-config is loaded, because TLS
|
|
||||||
// can only be configured on the gRPC server at the point of creation.
|
|
||||||
a.externalGRPCServer = external.NewServer(
|
|
||||||
a.logger.Named("grpc.external"),
|
|
||||||
metrics.Default(),
|
|
||||||
a.tlsConfigurator,
|
|
||||||
grpcRateLimiter,
|
|
||||||
)
|
|
||||||
|
|
||||||
if err := a.startLicenseManager(ctx); err != nil {
|
if err := a.startLicenseManager(ctx); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -618,10 +602,21 @@ func (a *Agent) Start(ctx context.Context) error {
|
||||||
|
|
||||||
// Setup either the client or the server.
|
// Setup either the client or the server.
|
||||||
if c.ServerMode {
|
if c.ServerMode {
|
||||||
server, err := consul.NewServer(consulCfg, a.baseDeps.Deps, a.externalGRPCServer)
|
serverLogger := a.baseDeps.Logger.NamedIntercept(logging.ConsulServer)
|
||||||
|
incomingRPCLimiter := consul.ConfiguredIncomingRPCLimiter(serverLogger, consulCfg)
|
||||||
|
|
||||||
|
a.externalGRPCServer = external.NewServer(
|
||||||
|
a.logger.Named("grpc.external"),
|
||||||
|
metrics.Default(),
|
||||||
|
a.tlsConfigurator,
|
||||||
|
incomingRPCLimiter,
|
||||||
|
)
|
||||||
|
|
||||||
|
server, err := consul.NewServer(consulCfg, a.baseDeps.Deps, a.externalGRPCServer, incomingRPCLimiter, serverLogger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("Failed to start Consul server: %v", err)
|
return fmt.Errorf("Failed to start Consul server: %v", err)
|
||||||
}
|
}
|
||||||
|
incomingRPCLimiter.Register(server)
|
||||||
a.delegate = server
|
a.delegate = server
|
||||||
|
|
||||||
if a.config.PeeringEnabled && a.config.ConnectEnabled {
|
if a.config.PeeringEnabled && a.config.ConnectEnabled {
|
||||||
|
@ -642,6 +637,13 @@ func (a *Agent) Start(ctx context.Context) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
|
a.externalGRPCServer = external.NewServer(
|
||||||
|
a.logger.Named("grpc.external"),
|
||||||
|
metrics.Default(),
|
||||||
|
a.tlsConfigurator,
|
||||||
|
rpcRate.NullRequestLimitsHandler(),
|
||||||
|
)
|
||||||
|
|
||||||
client, err := consul.NewClient(consulCfg, a.baseDeps.Deps)
|
client, err := consul.NewClient(consulCfg, a.baseDeps.Deps)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("Failed to start Consul client: %v", err)
|
return fmt.Errorf("Failed to start Consul client: %v", err)
|
||||||
|
|
|
@ -563,7 +563,7 @@ func TestCAManager_Initialize_Logging(t *testing.T) {
|
||||||
deps := newDefaultDeps(t, conf1)
|
deps := newDefaultDeps(t, conf1)
|
||||||
deps.Logger = logger
|
deps.Logger = logger
|
||||||
|
|
||||||
s1, err := NewServer(conf1, deps, grpc.NewServer())
|
s1, err := NewServer(conf1, deps, grpc.NewServer(), nil, logger)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer s1.Shutdown()
|
defer s1.Shutdown()
|
||||||
testrpc.WaitForLeader(t, s1.RPC, "dc1")
|
testrpc.WaitForLeader(t, s1.RPC, "dc1")
|
||||||
|
|
|
@ -1556,7 +1556,7 @@ func TestLeader_ConfigEntryBootstrap_Fail(t *testing.T) {
|
||||||
deps := newDefaultDeps(t, config)
|
deps := newDefaultDeps(t, config)
|
||||||
deps.Logger = logger
|
deps.Logger = logger
|
||||||
|
|
||||||
srv, err := NewServer(config, deps, grpc.NewServer())
|
srv, err := NewServer(config, deps, grpc.NewServer(), nil, logger)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer srv.Shutdown()
|
defer srv.Shutdown()
|
||||||
|
|
||||||
|
|
|
@ -4,6 +4,7 @@ package rate
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"reflect"
|
"reflect"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
@ -112,11 +113,14 @@ type RequestLimitsHandler interface {
|
||||||
|
|
||||||
// Handler enforces rate limits for incoming RPCs.
|
// Handler enforces rate limits for incoming RPCs.
|
||||||
type Handler struct {
|
type Handler struct {
|
||||||
cfg *atomic.Pointer[HandlerConfig]
|
cfg *atomic.Pointer[HandlerConfig]
|
||||||
delegate HandlerDelegate
|
leaderStatusProvider LeaderStatusProvider
|
||||||
|
|
||||||
limiter multilimiter.RateLimiter
|
limiter multilimiter.RateLimiter
|
||||||
logger hclog.Logger
|
|
||||||
|
// TODO: replace this with the real logger.
|
||||||
|
// https://github.com/hashicorp/consul/pull/15822
|
||||||
|
logger hclog.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
type HandlerConfig struct {
|
type HandlerConfig struct {
|
||||||
|
@ -135,7 +139,8 @@ type HandlerConfig struct {
|
||||||
GlobalReadConfig multilimiter.LimiterConfig
|
GlobalReadConfig multilimiter.LimiterConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
type HandlerDelegate interface {
|
//go:generate mockery --name LeaderStatusProvider --inpackage --filename mock_LeaderStatusProvider_test.go
|
||||||
|
type LeaderStatusProvider interface {
|
||||||
// IsLeader is used to determine whether the operation is being performed
|
// IsLeader is used to determine whether the operation is being performed
|
||||||
// against the cluster leader, such that if it can _only_ be performed by
|
// against the cluster leader, such that if it can _only_ be performed by
|
||||||
// the leader (e.g. write operations) we don't tell clients to retry against
|
// the leader (e.g. write operations) we don't tell clients to retry against
|
||||||
|
@ -143,16 +148,18 @@ type HandlerDelegate interface {
|
||||||
IsLeader() bool
|
IsLeader() bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewHandlerWithLimiter(cfg HandlerConfig, delegate HandlerDelegate,
|
func NewHandlerWithLimiter(
|
||||||
limiter multilimiter.RateLimiter, logger hclog.Logger) *Handler {
|
cfg HandlerConfig,
|
||||||
|
limiter multilimiter.RateLimiter,
|
||||||
|
logger hclog.Logger) *Handler {
|
||||||
|
|
||||||
limiter.UpdateConfig(cfg.GlobalWriteConfig, globalWrite)
|
limiter.UpdateConfig(cfg.GlobalWriteConfig, globalWrite)
|
||||||
limiter.UpdateConfig(cfg.GlobalReadConfig, globalRead)
|
limiter.UpdateConfig(cfg.GlobalReadConfig, globalRead)
|
||||||
|
|
||||||
h := &Handler{
|
h := &Handler{
|
||||||
cfg: new(atomic.Pointer[HandlerConfig]),
|
cfg: new(atomic.Pointer[HandlerConfig]),
|
||||||
delegate: delegate,
|
limiter: limiter,
|
||||||
limiter: limiter,
|
logger: logger,
|
||||||
logger: logger,
|
|
||||||
}
|
}
|
||||||
h.cfg.Store(&cfg)
|
h.cfg.Store(&cfg)
|
||||||
|
|
||||||
|
@ -160,9 +167,9 @@ func NewHandlerWithLimiter(cfg HandlerConfig, delegate HandlerDelegate,
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewHandler creates a new RPC rate limit handler.
|
// NewHandler creates a new RPC rate limit handler.
|
||||||
func NewHandler(cfg HandlerConfig, delegate HandlerDelegate, logger hclog.Logger) *Handler {
|
func NewHandler(cfg HandlerConfig, logger hclog.Logger) *Handler {
|
||||||
limiter := multilimiter.NewMultiLimiter(cfg.Config)
|
limiter := multilimiter.NewMultiLimiter(cfg.Config)
|
||||||
return NewHandlerWithLimiter(cfg, delegate, limiter, logger)
|
return NewHandlerWithLimiter(cfg, limiter, logger)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Run the limiter cleanup routine until the given context is canceled.
|
// Run the limiter cleanup routine until the given context is canceled.
|
||||||
|
@ -175,14 +182,45 @@ func (h *Handler) Run(ctx context.Context) {
|
||||||
// Allow returns an error if the given operation is not allowed to proceed
|
// Allow returns an error if the given operation is not allowed to proceed
|
||||||
// because of an exhausted rate-limit.
|
// because of an exhausted rate-limit.
|
||||||
func (h *Handler) Allow(op Operation) error {
|
func (h *Handler) Allow(op Operation) error {
|
||||||
|
|
||||||
|
if h.leaderStatusProvider == nil {
|
||||||
|
h.logger.Error("leaderStatusProvider required to be set via Register(). bailing on rate limiter")
|
||||||
|
return nil
|
||||||
|
// TODO: panic and make sure to use the server's recovery handler
|
||||||
|
// panic("leaderStatusProvider required to be set via Register(..)")
|
||||||
|
}
|
||||||
|
|
||||||
cfg := h.cfg.Load()
|
cfg := h.cfg.Load()
|
||||||
if cfg.GlobalMode == ModeDisabled {
|
if cfg.GlobalMode == ModeDisabled {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if !h.limiter.Allow(globalWrite) {
|
for _, l := range h.limits(op) {
|
||||||
// TODO(NET-1383): actually implement the rate limiting logic and replace this returned nil.
|
if l.mode == ModeDisabled {
|
||||||
return nil
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if h.limiter.Allow(l.ent) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: metrics.
|
||||||
|
// TODO: is this the correct log-level?
|
||||||
|
|
||||||
|
enforced := l.mode == ModeEnforcing
|
||||||
|
h.logger.Trace("RPC exceeded allowed rate limit",
|
||||||
|
"rpc", op.Name,
|
||||||
|
"source_addr", op.SourceAddr.String(),
|
||||||
|
"limit_type", l.desc,
|
||||||
|
"limit_enforced", enforced,
|
||||||
|
)
|
||||||
|
|
||||||
|
if enforced {
|
||||||
|
if h.leaderStatusProvider.IsLeader() && op.Type == OperationTypeWrite {
|
||||||
|
return ErrRetryLater
|
||||||
|
}
|
||||||
|
return ErrRetryElsewhere
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -202,6 +240,48 @@ func (h *Handler) UpdateConfig(cfg HandlerConfig) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *Handler) Register(leaderStatusProvider LeaderStatusProvider) {
|
||||||
|
h.leaderStatusProvider = leaderStatusProvider
|
||||||
|
}
|
||||||
|
|
||||||
|
type limit struct {
|
||||||
|
mode Mode
|
||||||
|
ent multilimiter.LimitedEntity
|
||||||
|
desc string
|
||||||
|
}
|
||||||
|
|
||||||
|
// limits returns the limits to check for the given operation (e.g. global +
|
||||||
|
// ip-based + tenant-based).
|
||||||
|
func (h *Handler) limits(op Operation) []limit {
|
||||||
|
limits := make([]limit, 0)
|
||||||
|
|
||||||
|
if global := h.globalLimit(op); global != nil {
|
||||||
|
limits = append(limits, *global)
|
||||||
|
}
|
||||||
|
|
||||||
|
return limits
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) globalLimit(op Operation) *limit {
|
||||||
|
if op.Type == OperationTypeExempt {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
cfg := h.cfg.Load()
|
||||||
|
|
||||||
|
lim := &limit{mode: cfg.GlobalMode}
|
||||||
|
switch op.Type {
|
||||||
|
case OperationTypeRead:
|
||||||
|
lim.desc = "global/read"
|
||||||
|
lim.ent = globalRead
|
||||||
|
case OperationTypeWrite:
|
||||||
|
lim.desc = "global/write"
|
||||||
|
lim.ent = globalWrite
|
||||||
|
default:
|
||||||
|
panic(fmt.Sprintf("unknown operation type %d", op.Type))
|
||||||
|
}
|
||||||
|
return lim
|
||||||
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
// globalWrite identifies the global rate limit applied to write operations.
|
// globalWrite identifies the global rate limit applied to write operations.
|
||||||
globalWrite = globalLimit("global.write")
|
globalWrite = globalLimit("global.write")
|
||||||
|
|
|
@ -1,15 +1,233 @@
|
||||||
package rate
|
package rate
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/hashicorp/consul/agent/consul/multilimiter"
|
"github.com/stretchr/testify/mock"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/hashicorp/go-hclog"
|
"github.com/hashicorp/go-hclog"
|
||||||
mock "github.com/stretchr/testify/mock"
|
|
||||||
|
"github.com/hashicorp/consul/agent/consul/multilimiter"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
//
|
||||||
|
// Revisit test when handler.go:189 TODO implemented
|
||||||
|
//
|
||||||
|
// func TestHandler_Allow_PanicsWhenLeaderStatusProviderNotRegistered(t *testing.T) {
|
||||||
|
// defer func() {
|
||||||
|
// err := recover()
|
||||||
|
// if err == nil {
|
||||||
|
// t.Fatal("Run should panic")
|
||||||
|
// }
|
||||||
|
// }()
|
||||||
|
|
||||||
|
// handler := NewHandler(HandlerConfig{}, hclog.NewNullLogger())
|
||||||
|
// handler.Allow(Operation{})
|
||||||
|
// // intentionally skip handler.Register(...)
|
||||||
|
// }
|
||||||
|
|
||||||
|
func TestHandler(t *testing.T) {
|
||||||
|
var (
|
||||||
|
rpcName = "Foo.Bar"
|
||||||
|
sourceAddr = net.TCPAddrFromAddrPort(netip.MustParseAddrPort("1.2.3.4:5678"))
|
||||||
|
)
|
||||||
|
|
||||||
|
type limitCheck struct {
|
||||||
|
limit multilimiter.LimitedEntity
|
||||||
|
allow bool
|
||||||
|
}
|
||||||
|
testCases := map[string]struct {
|
||||||
|
op Operation
|
||||||
|
globalMode Mode
|
||||||
|
checks []limitCheck
|
||||||
|
isLeader bool
|
||||||
|
expectErr error
|
||||||
|
expectLog bool
|
||||||
|
}{
|
||||||
|
"operation exempt from limiting": {
|
||||||
|
op: Operation{
|
||||||
|
Type: OperationTypeExempt,
|
||||||
|
Name: rpcName,
|
||||||
|
SourceAddr: sourceAddr,
|
||||||
|
},
|
||||||
|
globalMode: ModeEnforcing,
|
||||||
|
checks: []limitCheck{},
|
||||||
|
expectErr: nil,
|
||||||
|
expectLog: false,
|
||||||
|
},
|
||||||
|
"global write limit disabled": {
|
||||||
|
op: Operation{
|
||||||
|
Type: OperationTypeWrite,
|
||||||
|
Name: rpcName,
|
||||||
|
SourceAddr: sourceAddr,
|
||||||
|
},
|
||||||
|
globalMode: ModeDisabled,
|
||||||
|
checks: []limitCheck{},
|
||||||
|
expectErr: nil,
|
||||||
|
expectLog: false,
|
||||||
|
},
|
||||||
|
"global write limit within allowance": {
|
||||||
|
op: Operation{
|
||||||
|
Type: OperationTypeWrite,
|
||||||
|
Name: rpcName,
|
||||||
|
SourceAddr: sourceAddr,
|
||||||
|
},
|
||||||
|
globalMode: ModeEnforcing,
|
||||||
|
checks: []limitCheck{
|
||||||
|
{limit: globalWrite, allow: true},
|
||||||
|
},
|
||||||
|
expectErr: nil,
|
||||||
|
expectLog: false,
|
||||||
|
},
|
||||||
|
"global write limit exceeded (permissive)": {
|
||||||
|
op: Operation{
|
||||||
|
Type: OperationTypeWrite,
|
||||||
|
Name: rpcName,
|
||||||
|
SourceAddr: sourceAddr,
|
||||||
|
},
|
||||||
|
globalMode: ModePermissive,
|
||||||
|
checks: []limitCheck{
|
||||||
|
{limit: globalWrite, allow: false},
|
||||||
|
},
|
||||||
|
expectErr: nil,
|
||||||
|
expectLog: true,
|
||||||
|
},
|
||||||
|
"global write limit exceeded (enforcing, leader)": {
|
||||||
|
op: Operation{
|
||||||
|
Type: OperationTypeWrite,
|
||||||
|
Name: rpcName,
|
||||||
|
SourceAddr: sourceAddr,
|
||||||
|
},
|
||||||
|
globalMode: ModeEnforcing,
|
||||||
|
checks: []limitCheck{
|
||||||
|
{limit: globalWrite, allow: false},
|
||||||
|
},
|
||||||
|
isLeader: true,
|
||||||
|
expectErr: ErrRetryLater,
|
||||||
|
expectLog: true,
|
||||||
|
},
|
||||||
|
"global write limit exceeded (enforcing, follower)": {
|
||||||
|
op: Operation{
|
||||||
|
Type: OperationTypeWrite,
|
||||||
|
Name: rpcName,
|
||||||
|
SourceAddr: sourceAddr,
|
||||||
|
},
|
||||||
|
globalMode: ModeEnforcing,
|
||||||
|
checks: []limitCheck{
|
||||||
|
{limit: globalWrite, allow: false},
|
||||||
|
},
|
||||||
|
isLeader: false,
|
||||||
|
expectErr: ErrRetryElsewhere,
|
||||||
|
expectLog: true,
|
||||||
|
},
|
||||||
|
"global read limit disabled": {
|
||||||
|
op: Operation{
|
||||||
|
Type: OperationTypeRead,
|
||||||
|
Name: rpcName,
|
||||||
|
SourceAddr: sourceAddr,
|
||||||
|
},
|
||||||
|
globalMode: ModeDisabled,
|
||||||
|
checks: []limitCheck{},
|
||||||
|
expectErr: nil,
|
||||||
|
expectLog: false,
|
||||||
|
},
|
||||||
|
"global read limit within allowance": {
|
||||||
|
op: Operation{
|
||||||
|
Type: OperationTypeRead,
|
||||||
|
Name: rpcName,
|
||||||
|
SourceAddr: sourceAddr,
|
||||||
|
},
|
||||||
|
globalMode: ModeEnforcing,
|
||||||
|
checks: []limitCheck{
|
||||||
|
{limit: globalRead, allow: true},
|
||||||
|
},
|
||||||
|
expectErr: nil,
|
||||||
|
expectLog: false,
|
||||||
|
},
|
||||||
|
"global read limit exceeded (permissive)": {
|
||||||
|
op: Operation{
|
||||||
|
Type: OperationTypeRead,
|
||||||
|
Name: rpcName,
|
||||||
|
SourceAddr: sourceAddr,
|
||||||
|
},
|
||||||
|
globalMode: ModePermissive,
|
||||||
|
checks: []limitCheck{
|
||||||
|
{limit: globalRead, allow: false},
|
||||||
|
},
|
||||||
|
expectErr: nil,
|
||||||
|
expectLog: true,
|
||||||
|
},
|
||||||
|
"global read limit exceeded (enforcing, leader)": {
|
||||||
|
op: Operation{
|
||||||
|
Type: OperationTypeRead,
|
||||||
|
Name: rpcName,
|
||||||
|
SourceAddr: sourceAddr,
|
||||||
|
},
|
||||||
|
globalMode: ModeEnforcing,
|
||||||
|
checks: []limitCheck{
|
||||||
|
{limit: globalRead, allow: false},
|
||||||
|
},
|
||||||
|
isLeader: true,
|
||||||
|
expectErr: ErrRetryElsewhere,
|
||||||
|
expectLog: true,
|
||||||
|
},
|
||||||
|
"global read limit exceeded (enforcing, follower)": {
|
||||||
|
op: Operation{
|
||||||
|
Type: OperationTypeRead,
|
||||||
|
Name: rpcName,
|
||||||
|
SourceAddr: sourceAddr,
|
||||||
|
},
|
||||||
|
globalMode: ModeEnforcing,
|
||||||
|
checks: []limitCheck{
|
||||||
|
{limit: globalRead, allow: false},
|
||||||
|
},
|
||||||
|
isLeader: false,
|
||||||
|
expectErr: ErrRetryElsewhere,
|
||||||
|
expectLog: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for desc, tc := range testCases {
|
||||||
|
t.Run(desc, func(t *testing.T) {
|
||||||
|
limiter := newMockLimiter(t)
|
||||||
|
limiter.On("UpdateConfig", mock.Anything, mock.Anything).Return()
|
||||||
|
for _, c := range tc.checks {
|
||||||
|
limiter.On("Allow", c.limit).Return(c.allow)
|
||||||
|
}
|
||||||
|
|
||||||
|
leaderStatusProvider := NewMockLeaderStatusProvider(t)
|
||||||
|
leaderStatusProvider.On("IsLeader").Return(tc.isLeader).Maybe()
|
||||||
|
|
||||||
|
var output bytes.Buffer
|
||||||
|
logger := hclog.NewInterceptLogger(&hclog.LoggerOptions{
|
||||||
|
Level: hclog.Trace,
|
||||||
|
Output: &output,
|
||||||
|
})
|
||||||
|
|
||||||
|
handler := NewHandlerWithLimiter(
|
||||||
|
HandlerConfig{
|
||||||
|
GlobalMode: tc.globalMode,
|
||||||
|
},
|
||||||
|
limiter,
|
||||||
|
logger,
|
||||||
|
)
|
||||||
|
handler.Register(leaderStatusProvider)
|
||||||
|
|
||||||
|
require.Equal(t, tc.expectErr, handler.Allow(tc.op))
|
||||||
|
|
||||||
|
if tc.expectLog {
|
||||||
|
require.Contains(t, output.String(), "RPC exceeded allowed rate limit")
|
||||||
|
} else {
|
||||||
|
require.Zero(t, output.Len(), "expected no logs to be emitted")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestNewHandlerWithLimiter_CallsUpdateConfig(t *testing.T) {
|
func TestNewHandlerWithLimiter_CallsUpdateConfig(t *testing.T) {
|
||||||
mockRateLimiter := multilimiter.NewMockRateLimiter(t)
|
mockRateLimiter := multilimiter.NewMockRateLimiter(t)
|
||||||
mockRateLimiter.On("UpdateConfig", mock.Anything, mock.Anything).Return()
|
mockRateLimiter.On("UpdateConfig", mock.Anything, mock.Anything).Return()
|
||||||
|
@ -22,7 +240,7 @@ func TestNewHandlerWithLimiter_CallsUpdateConfig(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
logger := hclog.NewNullLogger()
|
logger := hclog.NewNullLogger()
|
||||||
NewHandlerWithLimiter(*cfg, nil, mockRateLimiter, logger)
|
NewHandlerWithLimiter(*cfg, mockRateLimiter, logger)
|
||||||
mockRateLimiter.AssertNumberOfCalls(t, "UpdateConfig", 2)
|
mockRateLimiter.AssertNumberOfCalls(t, "UpdateConfig", 2)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -83,7 +301,7 @@ func TestUpdateConfig(t *testing.T) {
|
||||||
mockRateLimiter := multilimiter.NewMockRateLimiter(t)
|
mockRateLimiter := multilimiter.NewMockRateLimiter(t)
|
||||||
mockRateLimiter.On("UpdateConfig", mock.Anything, mock.Anything).Return()
|
mockRateLimiter.On("UpdateConfig", mock.Anything, mock.Anything).Return()
|
||||||
logger := hclog.NewNullLogger()
|
logger := hclog.NewNullLogger()
|
||||||
handler := NewHandlerWithLimiter(*cfg, nil, mockRateLimiter, logger)
|
handler := NewHandlerWithLimiter(*cfg, mockRateLimiter, logger)
|
||||||
mockRateLimiter.Calls = nil
|
mockRateLimiter.Calls = nil
|
||||||
tc.configModFunc(cfg)
|
tc.configModFunc(cfg)
|
||||||
handler.UpdateConfig(*cfg)
|
handler.UpdateConfig(*cfg)
|
||||||
|
@ -139,7 +357,10 @@ func TestAllow(t *testing.T) {
|
||||||
}
|
}
|
||||||
mockRateLimiter.On("UpdateConfig", mock.Anything, mock.Anything).Return()
|
mockRateLimiter.On("UpdateConfig", mock.Anything, mock.Anything).Return()
|
||||||
logger := hclog.NewNullLogger()
|
logger := hclog.NewNullLogger()
|
||||||
handler := NewHandlerWithLimiter(*tc.cfg, nil, mockRateLimiter, logger)
|
delegate := NewMockLeaderStatusProvider(t)
|
||||||
|
delegate.On("IsLeader").Return(true).Maybe()
|
||||||
|
handler := NewHandlerWithLimiter(*tc.cfg, mockRateLimiter, logger)
|
||||||
|
handler.Register(delegate)
|
||||||
addr := net.TCPAddrFromAddrPort(netip.MustParseAddrPort("127.0.0.1:1234"))
|
addr := net.TCPAddrFromAddrPort(netip.MustParseAddrPort("127.0.0.1:1234"))
|
||||||
mockRateLimiter.Calls = nil
|
mockRateLimiter.Calls = nil
|
||||||
handler.Allow(Operation{Name: "test", SourceAddr: addr})
|
handler.Allow(Operation{Name: "test", SourceAddr: addr})
|
||||||
|
@ -147,3 +368,24 @@ func TestAllow(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var _ multilimiter.RateLimiter = (*mockLimiter)(nil)
|
||||||
|
|
||||||
|
func newMockLimiter(t *testing.T) *mockLimiter {
|
||||||
|
l := &mockLimiter{}
|
||||||
|
l.Mock.Test(t)
|
||||||
|
|
||||||
|
t.Cleanup(func() { l.AssertExpectations(t) })
|
||||||
|
|
||||||
|
return l
|
||||||
|
}
|
||||||
|
|
||||||
|
type mockLimiter struct {
|
||||||
|
mock.Mock
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockLimiter) Allow(v multilimiter.LimitedEntity) bool { return m.Called(v).Bool(0) }
|
||||||
|
func (m *mockLimiter) Run(ctx context.Context) { m.Called(ctx) }
|
||||||
|
func (m *mockLimiter) UpdateConfig(cfg multilimiter.LimiterConfig, prefix []byte) {
|
||||||
|
m.Called(cfg, prefix)
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1,38 @@
|
||||||
|
// Code generated by mockery v2.12.2. DO NOT EDIT.
|
||||||
|
|
||||||
|
package rate
|
||||||
|
|
||||||
|
import (
|
||||||
|
testing "testing"
|
||||||
|
|
||||||
|
mock "github.com/stretchr/testify/mock"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MockLeaderStatusProvider is an autogenerated mock type for the LeaderStatusProvider type
|
||||||
|
type MockLeaderStatusProvider struct {
|
||||||
|
mock.Mock
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsLeader provides a mock function with given fields:
|
||||||
|
func (_m *MockLeaderStatusProvider) IsLeader() bool {
|
||||||
|
ret := _m.Called()
|
||||||
|
|
||||||
|
var r0 bool
|
||||||
|
if rf, ok := ret.Get(0).(func() bool); ok {
|
||||||
|
r0 = rf()
|
||||||
|
} else {
|
||||||
|
r0 = ret.Get(0).(bool)
|
||||||
|
}
|
||||||
|
|
||||||
|
return r0
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMockLeaderStatusProvider creates a new instance of MockLeaderStatusProvider. It also registers the testing.TB interface on the mock and a cleanup function to assert the mocks expectations.
|
||||||
|
func NewMockLeaderStatusProvider(t testing.TB) *MockLeaderStatusProvider {
|
||||||
|
mock := &MockLeaderStatusProvider{}
|
||||||
|
mock.Mock.Test(t)
|
||||||
|
|
||||||
|
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||||
|
|
||||||
|
return mock
|
||||||
|
}
|
|
@ -407,7 +407,7 @@ type connHandler interface {
|
||||||
|
|
||||||
// NewServer is used to construct a new Consul server from the configuration
|
// NewServer is used to construct a new Consul server from the configuration
|
||||||
// and extra options, potentially returning an error.
|
// and extra options, potentially returning an error.
|
||||||
func NewServer(config *Config, flat Deps, externalGRPCServer *grpc.Server) (*Server, error) {
|
func NewServer(config *Config, flat Deps, externalGRPCServer *grpc.Server, incomingRPCLimiter rpcRate.RequestLimitsHandler, serverLogger hclog.InterceptLogger) (*Server, error) {
|
||||||
logger := flat.Logger
|
logger := flat.Logger
|
||||||
if err := config.CheckProtocolVersion(); err != nil {
|
if err := config.CheckProtocolVersion(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -428,7 +428,6 @@ func NewServer(config *Config, flat Deps, externalGRPCServer *grpc.Server) (*Ser
|
||||||
// Create the shutdown channel - this is closed but never written to.
|
// Create the shutdown channel - this is closed but never written to.
|
||||||
shutdownCh := make(chan struct{})
|
shutdownCh := make(chan struct{})
|
||||||
|
|
||||||
serverLogger := flat.Logger.NamedIntercept(logging.ConsulServer)
|
|
||||||
loggers := newLoggerStore(serverLogger)
|
loggers := newLoggerStore(serverLogger)
|
||||||
|
|
||||||
fsmDeps := fsm.Deps{
|
fsmDeps := fsm.Deps{
|
||||||
|
@ -439,6 +438,10 @@ func NewServer(config *Config, flat Deps, externalGRPCServer *grpc.Server) (*Ser
|
||||||
Publisher: flat.EventPublisher,
|
Publisher: flat.EventPublisher,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if incomingRPCLimiter == nil {
|
||||||
|
incomingRPCLimiter = rpcRate.NullRequestLimitsHandler()
|
||||||
|
}
|
||||||
|
|
||||||
// Create server.
|
// Create server.
|
||||||
s := &Server{
|
s := &Server{
|
||||||
config: config,
|
config: config,
|
||||||
|
@ -463,6 +466,7 @@ func NewServer(config *Config, flat Deps, externalGRPCServer *grpc.Server) (*Ser
|
||||||
aclAuthMethodValidators: authmethod.NewCache(),
|
aclAuthMethodValidators: authmethod.NewCache(),
|
||||||
fsm: fsm.NewFromDeps(fsmDeps),
|
fsm: fsm.NewFromDeps(fsmDeps),
|
||||||
publisher: flat.EventPublisher,
|
publisher: flat.EventPublisher,
|
||||||
|
incomingRPCLimiter: incomingRPCLimiter,
|
||||||
}
|
}
|
||||||
|
|
||||||
s.hcpManager = hcp.NewManager(hcp.ManagerConfig{
|
s.hcpManager = hcp.NewManager(hcp.ManagerConfig{
|
||||||
|
@ -471,17 +475,6 @@ func NewServer(config *Config, flat Deps, externalGRPCServer *grpc.Server) (*Ser
|
||||||
Logger: logger.Named("hcp_manager"),
|
Logger: logger.Named("hcp_manager"),
|
||||||
})
|
})
|
||||||
|
|
||||||
// TODO(NET-1380, NET-1381): thread this into the net/rpc and gRPC interceptors.
|
|
||||||
if s.incomingRPCLimiter == nil {
|
|
||||||
mlCfg := &multilimiter.Config{ReconcileCheckLimit: 30 * time.Second, ReconcileCheckInterval: time.Second}
|
|
||||||
limitsConfig := &RequestLimits{
|
|
||||||
Mode: rpcRate.RequestLimitsModeFromNameWithDefault(config.RequestLimitsMode),
|
|
||||||
ReadRate: config.RequestLimitsReadRate,
|
|
||||||
WriteRate: config.RequestLimitsWriteRate,
|
|
||||||
}
|
|
||||||
|
|
||||||
s.incomingRPCLimiter = rpcRate.NewHandler(*s.convertConsulConfigToRateLimitHandlerConfig(*limitsConfig, mlCfg), s, s.logger)
|
|
||||||
}
|
|
||||||
s.incomingRPCLimiter.Run(&lib.StopChannelContext{StopCh: s.shutdownCh})
|
s.incomingRPCLimiter.Run(&lib.StopChannelContext{StopCh: s.shutdownCh})
|
||||||
|
|
||||||
var recorder *middleware.RequestRecorder
|
var recorder *middleware.RequestRecorder
|
||||||
|
@ -1696,7 +1689,7 @@ func (s *Server) ReloadConfig(config ReloadableConfig) error {
|
||||||
s.rpcLimiter.Store(rate.NewLimiter(config.RPCRateLimit, config.RPCMaxBurst))
|
s.rpcLimiter.Store(rate.NewLimiter(config.RPCRateLimit, config.RPCMaxBurst))
|
||||||
|
|
||||||
if config.RequestLimits != nil {
|
if config.RequestLimits != nil {
|
||||||
s.incomingRPCLimiter.UpdateConfig(*s.convertConsulConfigToRateLimitHandlerConfig(*config.RequestLimits, nil))
|
s.incomingRPCLimiter.UpdateConfig(*convertConsulConfigToRateLimitHandlerConfig(*config.RequestLimits, nil))
|
||||||
}
|
}
|
||||||
|
|
||||||
s.rpcConnLimiter.SetConfig(connlimit.Config{
|
s.rpcConnLimiter.SetConfig(connlimit.Config{
|
||||||
|
@ -1849,9 +1842,25 @@ func (s *Server) hcpServerStatus(deps Deps) hcp.StatusCallback {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// convertConsulConfigToRateLimitHandlerConfig creates a rate limite handler config
|
func ConfiguredIncomingRPCLimiter(serverLogger hclog.InterceptLogger, consulCfg *Config) *rpcRate.Handler {
|
||||||
// from the relevant fields in the consul runtime config.
|
mlCfg := &multilimiter.Config{ReconcileCheckLimit: 30 * time.Second, ReconcileCheckInterval: time.Second}
|
||||||
func (s *Server) convertConsulConfigToRateLimitHandlerConfig(limitsConfig RequestLimits, multilimiterConfig *multilimiter.Config) *rpcRate.HandlerConfig {
|
limitsConfig := &RequestLimits{
|
||||||
|
Mode: rpcRate.RequestLimitsModeFromNameWithDefault(consulCfg.RequestLimitsMode),
|
||||||
|
ReadRate: consulCfg.RequestLimitsReadRate,
|
||||||
|
WriteRate: consulCfg.RequestLimitsWriteRate,
|
||||||
|
}
|
||||||
|
|
||||||
|
rateLimiterConfig := convertConsulConfigToRateLimitHandlerConfig(*limitsConfig, mlCfg)
|
||||||
|
|
||||||
|
incomingRPCLimiter := rpcRate.NewHandler(
|
||||||
|
*rateLimiterConfig,
|
||||||
|
serverLogger.Named("rpc-rate-limit"),
|
||||||
|
)
|
||||||
|
|
||||||
|
return incomingRPCLimiter
|
||||||
|
}
|
||||||
|
|
||||||
|
func convertConsulConfigToRateLimitHandlerConfig(limitsConfig RequestLimits, multilimiterConfig *multilimiter.Config) *rpcRate.HandlerConfig {
|
||||||
hc := &rpcRate.HandlerConfig{
|
hc := &rpcRate.HandlerConfig{
|
||||||
GlobalMode: limitsConfig.Mode,
|
GlobalMode: limitsConfig.Mode,
|
||||||
GlobalReadConfig: multilimiter.LimiterConfig{
|
GlobalReadConfig: multilimiter.LimiterConfig{
|
||||||
|
@ -1870,11 +1879,6 @@ func (s *Server) convertConsulConfigToRateLimitHandlerConfig(limitsConfig Reques
|
||||||
return hc
|
return hc
|
||||||
}
|
}
|
||||||
|
|
||||||
// IncomingRPCLimiter returns the server's configured rate limit handler for
|
|
||||||
// incoming RPCs. This is necessary because the external gRPC server is created
|
|
||||||
// by the agent (as it is also used for xDS).
|
|
||||||
func (s *Server) IncomingRPCLimiter() rpcRate.RequestLimitsHandler { return s.incomingRPCLimiter }
|
|
||||||
|
|
||||||
// peersInfoContent is used to help operators understand what happened to the
|
// peersInfoContent is used to help operators understand what happened to the
|
||||||
// peers.json file. This is written to a file called peers.info in the same
|
// peers.json file. This is written to a file called peers.info in the same
|
||||||
// location.
|
// location.
|
||||||
|
|
|
@ -334,7 +334,7 @@ func newServerWithDeps(t *testing.T, c *Config, deps Deps) (*Server, error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
grpcServer := external.NewServer(deps.Logger.Named("grpc.external"), nil, deps.TLSConfigurator, rpcRate.NullRequestLimitsHandler())
|
grpcServer := external.NewServer(deps.Logger.Named("grpc.external"), nil, deps.TLSConfigurator, rpcRate.NullRequestLimitsHandler())
|
||||||
srv, err := NewServer(c, deps, grpcServer)
|
srv, err := NewServer(c, deps, grpcServer, nil, deps.Logger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -1241,7 +1241,7 @@ func TestServer_RPC_MetricsIntercept_Off(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
s1, err := NewServer(conf, deps, grpc.NewServer())
|
s1, err := NewServer(conf, deps, grpc.NewServer(), nil, deps.Logger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %v", err)
|
t.Fatalf("err: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -1279,7 +1279,7 @@ func TestServer_RPC_MetricsIntercept_Off(t *testing.T) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
s2, err := NewServer(conf, deps, grpc.NewServer())
|
s2, err := NewServer(conf, deps, grpc.NewServer(), nil, deps.Logger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %v", err)
|
t.Fatalf("err: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -1313,7 +1313,7 @@ func TestServer_RPC_RequestRecorder(t *testing.T) {
|
||||||
deps := newDefaultDeps(t, conf)
|
deps := newDefaultDeps(t, conf)
|
||||||
deps.NewRequestRecorderFunc = nil
|
deps.NewRequestRecorderFunc = nil
|
||||||
|
|
||||||
s1, err := NewServer(conf, deps, grpc.NewServer())
|
s1, err := NewServer(conf, deps, grpc.NewServer(), nil, deps.Logger)
|
||||||
|
|
||||||
require.Error(t, err, "need err when provider func is nil")
|
require.Error(t, err, "need err when provider func is nil")
|
||||||
require.Equal(t, err.Error(), "cannot initialize server without an RPC request recorder provider")
|
require.Equal(t, err.Error(), "cannot initialize server without an RPC request recorder provider")
|
||||||
|
@ -1332,7 +1332,7 @@ func TestServer_RPC_RequestRecorder(t *testing.T) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
s2, err := NewServer(conf, deps, grpc.NewServer())
|
s2, err := NewServer(conf, deps, grpc.NewServer(), nil, deps.Logger)
|
||||||
|
|
||||||
require.Error(t, err, "need err when RequestRecorder is nil")
|
require.Error(t, err, "need err when RequestRecorder is nil")
|
||||||
require.Equal(t, err.Error(), "cannot initialize server with a nil RPC request recorder")
|
require.Equal(t, err.Error(), "cannot initialize server with a nil RPC request recorder")
|
||||||
|
|
|
@ -41,7 +41,7 @@ func ServerRateLimiterMiddleware(limiter rate.RequestLimitsHandler, panicHandler
|
||||||
err := limiter.Allow(rate.Operation{
|
err := limiter.Allow(rate.Operation{
|
||||||
Name: info.FullMethodName,
|
Name: info.FullMethodName,
|
||||||
SourceAddr: peer.Addr,
|
SourceAddr: peer.Addr,
|
||||||
// TODO: operation type.
|
// TODO: add operation type from https://github.com/hashicorp/consul/pull/15564
|
||||||
})
|
})
|
||||||
|
|
||||||
switch {
|
switch {
|
||||||
|
|
|
@ -1594,7 +1594,7 @@ func newTestServer(t *testing.T, cb func(conf *consul.Config)) testingServer {
|
||||||
deps := newDefaultDeps(t, conf)
|
deps := newDefaultDeps(t, conf)
|
||||||
externalGRPCServer := external.NewServer(deps.Logger, nil, deps.TLSConfigurator, rate.NullRequestLimitsHandler())
|
externalGRPCServer := external.NewServer(deps.Logger, nil, deps.TLSConfigurator, rate.NullRequestLimitsHandler())
|
||||||
|
|
||||||
server, err := consul.NewServer(conf, deps, externalGRPCServer)
|
server, err := consul.NewServer(conf, deps, externalGRPCServer, nil, deps.Logger)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
require.NoError(t, server.Shutdown())
|
require.NoError(t, server.Shutdown())
|
||||||
|
|
Loading…
Reference in New Issue