grpc: add rate-limiting middleware (#15550)
Implements the gRPC middleware for rate-limiting as a tap.ServerInHandle function (executed before the request is unmarshaled). Mappings between gRPC methods and their operation type are generated by a protoc plugin introduced by #15564.
This commit is contained in:
parent
4894848993
commit
c73707ca3c
|
@ -563,12 +563,19 @@ func (a *Agent) Start(ctx context.Context) error {
|
||||||
return fmt.Errorf("Failed to load TLS configurations after applying auto-config settings: %w", err)
|
return fmt.Errorf("Failed to load TLS configurations after applying auto-config settings: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// gRPC calls are only rate-limited on server, not client agents.
|
||||||
|
grpcRateLimiter := middleware.NullRateLimiter()
|
||||||
|
if s, ok := a.delegate.(*consul.Server); ok {
|
||||||
|
grpcRateLimiter = s.IncomingRPCLimiter()
|
||||||
|
}
|
||||||
|
|
||||||
// This needs to happen after the initial auto-config is loaded, because TLS
|
// This needs to happen after the initial auto-config is loaded, because TLS
|
||||||
// can only be configured on the gRPC server at the point of creation.
|
// can only be configured on the gRPC server at the point of creation.
|
||||||
a.externalGRPCServer = external.NewServer(
|
a.externalGRPCServer = external.NewServer(
|
||||||
a.logger.Named("grpc.external"),
|
a.logger.Named("grpc.external"),
|
||||||
metrics.Default(),
|
metrics.Default(),
|
||||||
a.tlsConfigurator,
|
a.tlsConfigurator,
|
||||||
|
grpcRateLimiter,
|
||||||
)
|
)
|
||||||
|
|
||||||
if err := a.startLicenseManager(ctx); err != nil {
|
if err := a.startLicenseManager(ctx); err != nil {
|
||||||
|
|
|
@ -18,7 +18,6 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/armon/go-metrics"
|
"github.com/armon/go-metrics"
|
||||||
"github.com/hashicorp/consul-net-rpc/net/rpc"
|
|
||||||
"github.com/hashicorp/go-connlimit"
|
"github.com/hashicorp/go-connlimit"
|
||||||
"github.com/hashicorp/go-hclog"
|
"github.com/hashicorp/go-hclog"
|
||||||
"github.com/hashicorp/go-memdb"
|
"github.com/hashicorp/go-memdb"
|
||||||
|
@ -30,6 +29,8 @@ import (
|
||||||
"golang.org/x/time/rate"
|
"golang.org/x/time/rate"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
|
|
||||||
|
"github.com/hashicorp/consul-net-rpc/net/rpc"
|
||||||
|
|
||||||
"github.com/hashicorp/consul/acl"
|
"github.com/hashicorp/consul/acl"
|
||||||
"github.com/hashicorp/consul/agent/consul/authmethod"
|
"github.com/hashicorp/consul/agent/consul/authmethod"
|
||||||
"github.com/hashicorp/consul/agent/consul/authmethod/ssoauth"
|
"github.com/hashicorp/consul/agent/consul/authmethod/ssoauth"
|
||||||
|
@ -876,7 +877,7 @@ func newGRPCHandlerFromConfig(deps Deps, config *Config, s *Server) connHandler
|
||||||
s.externalConnectCAServer.Register(srv)
|
s.externalConnectCAServer.Register(srv)
|
||||||
}
|
}
|
||||||
|
|
||||||
return agentgrpc.NewHandler(deps.Logger, config.RPCAddr, register, nil)
|
return agentgrpc.NewHandler(deps.Logger, config.RPCAddr, register, nil, s.incomingRPCLimiter)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) connectCARootsMonitor(ctx context.Context) {
|
func (s *Server) connectCARootsMonitor(ctx context.Context) {
|
||||||
|
@ -1829,6 +1830,11 @@ func (s *Server) hcpServerStatus(deps Deps) hcp.StatusCallback {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IncomingRPCLimiter returns the server's configured rate limit handler for
|
||||||
|
// incoming RPCs. This is necessary because the external gRPC server is created
|
||||||
|
// by the agent (as it is also used for xDS).
|
||||||
|
func (s *Server) IncomingRPCLimiter() *rpcRate.Handler { return s.incomingRPCLimiter }
|
||||||
|
|
||||||
// peersInfoContent is used to help operators understand what happened to the
|
// peersInfoContent is used to help operators understand what happened to the
|
||||||
// peers.json file. This is written to a file called peers.info in the same
|
// peers.json file. This is written to a file called peers.info in the same
|
||||||
// location.
|
// location.
|
||||||
|
|
|
@ -331,7 +331,7 @@ func newServerWithDeps(t *testing.T, c *Config, deps Deps) (*Server, error) {
|
||||||
oldNotify()
|
oldNotify()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
grpcServer := external.NewServer(deps.Logger.Named("grpc.external"), nil, deps.TLSConfigurator)
|
grpcServer := external.NewServer(deps.Logger.Named("grpc.external"), nil, deps.TLSConfigurator, grpcmiddleware.NullRateLimiter())
|
||||||
srv, err := NewServer(c, deps, grpcServer)
|
srv, err := NewServer(c, deps, grpcServer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
|
@ -23,7 +23,7 @@ var (
|
||||||
|
|
||||||
// NewServer constructs a gRPC server for the external gRPC port, to which
|
// NewServer constructs a gRPC server for the external gRPC port, to which
|
||||||
// handlers can be registered.
|
// handlers can be registered.
|
||||||
func NewServer(logger agentmiddleware.Logger, metricsObj *metrics.Metrics, tls *tlsutil.Configurator) *grpc.Server {
|
func NewServer(logger agentmiddleware.Logger, metricsObj *metrics.Metrics, tls *tlsutil.Configurator, limiter agentmiddleware.RateLimiter) *grpc.Server {
|
||||||
if metricsObj == nil {
|
if metricsObj == nil {
|
||||||
metricsObj = metrics.Default()
|
metricsObj = metrics.Default()
|
||||||
}
|
}
|
||||||
|
@ -48,6 +48,7 @@ func NewServer(logger agentmiddleware.Logger, metricsObj *metrics.Metrics, tls *
|
||||||
opts := []grpc.ServerOption{
|
opts := []grpc.ServerOption{
|
||||||
grpc.MaxConcurrentStreams(2048),
|
grpc.MaxConcurrentStreams(2048),
|
||||||
grpc.MaxRecvMsgSize(50 * 1024 * 1024),
|
grpc.MaxRecvMsgSize(50 * 1024 * 1024),
|
||||||
|
grpc.InTapHandle(agentmiddleware.ServerRateLimiterMiddleware(limiter, agentmiddleware.NewPanicHandler(logger))),
|
||||||
grpc.StatsHandler(agentmiddleware.NewStatsHandler(metricsObj, metricsLabels)),
|
grpc.StatsHandler(agentmiddleware.NewStatsHandler(metricsObj, metricsLabels)),
|
||||||
middleware.WithUnaryServerChain(unaryInterceptors...),
|
middleware.WithUnaryServerChain(unaryInterceptors...),
|
||||||
middleware.WithStreamServerChain(streamInterceptors...),
|
middleware.WithStreamServerChain(streamInterceptors...),
|
||||||
|
|
|
@ -23,7 +23,7 @@ import (
|
||||||
func TestServer_EmitsStats(t *testing.T) {
|
func TestServer_EmitsStats(t *testing.T) {
|
||||||
sink, metricsObj := testutil.NewFakeSink(t)
|
sink, metricsObj := testutil.NewFakeSink(t)
|
||||||
|
|
||||||
srv := NewServer(hclog.Default(), metricsObj, nil)
|
srv := NewServer(hclog.Default(), metricsObj, nil, grpcmiddleware.NullRateLimiter())
|
||||||
|
|
||||||
testservice.RegisterSimpleServer(srv, &testservice.Simple{})
|
testservice.RegisterSimpleServer(srv, &testservice.Simple{})
|
||||||
|
|
||||||
|
|
|
@ -6,6 +6,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/armon/go-metrics"
|
"github.com/armon/go-metrics"
|
||||||
|
|
||||||
agentmiddleware "github.com/hashicorp/consul/agent/grpc-middleware"
|
agentmiddleware "github.com/hashicorp/consul/agent/grpc-middleware"
|
||||||
|
|
||||||
middleware "github.com/grpc-ecosystem/go-grpc-middleware"
|
middleware "github.com/grpc-ecosystem/go-grpc-middleware"
|
||||||
|
@ -24,7 +25,7 @@ var (
|
||||||
// NewHandler returns a gRPC server that accepts connections from Handle(conn).
|
// NewHandler returns a gRPC server that accepts connections from Handle(conn).
|
||||||
// The register function will be called with the grpc.Server to register
|
// The register function will be called with the grpc.Server to register
|
||||||
// gRPC services with the server.
|
// gRPC services with the server.
|
||||||
func NewHandler(logger Logger, addr net.Addr, register func(server *grpc.Server), metricsObj *metrics.Metrics) *Handler {
|
func NewHandler(logger Logger, addr net.Addr, register func(server *grpc.Server), metricsObj *metrics.Metrics, rateLimiter agentmiddleware.RateLimiter) *Handler {
|
||||||
if metricsObj == nil {
|
if metricsObj == nil {
|
||||||
metricsObj = metrics.Default()
|
metricsObj = metrics.Default()
|
||||||
}
|
}
|
||||||
|
@ -34,6 +35,7 @@ func NewHandler(logger Logger, addr net.Addr, register func(server *grpc.Server)
|
||||||
recoveryOpts := agentmiddleware.PanicHandlerMiddlewareOpts(logger)
|
recoveryOpts := agentmiddleware.PanicHandlerMiddlewareOpts(logger)
|
||||||
|
|
||||||
opts := []grpc.ServerOption{
|
opts := []grpc.ServerOption{
|
||||||
|
grpc.InTapHandle(agentmiddleware.ServerRateLimiterMiddleware(rateLimiter, agentmiddleware.NewPanicHandler(logger))),
|
||||||
grpc.StatsHandler(agentmiddleware.NewStatsHandler(metricsObj, metricsLabels)),
|
grpc.StatsHandler(agentmiddleware.NewStatsHandler(metricsObj, metricsLabels)),
|
||||||
middleware.WithUnaryServerChain(
|
middleware.WithUnaryServerChain(
|
||||||
// Add middlware interceptors to recover in case of panics.
|
// Add middlware interceptors to recover in case of panics.
|
||||||
|
|
|
@ -14,6 +14,7 @@ import (
|
||||||
"golang.org/x/sync/errgroup"
|
"golang.org/x/sync/errgroup"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
|
|
||||||
|
middleware "github.com/hashicorp/consul/agent/grpc-middleware"
|
||||||
"github.com/hashicorp/consul/agent/grpc-middleware/testutil/testservice"
|
"github.com/hashicorp/consul/agent/grpc-middleware/testutil/testservice"
|
||||||
"github.com/hashicorp/consul/agent/metadata"
|
"github.com/hashicorp/consul/agent/metadata"
|
||||||
"github.com/hashicorp/consul/agent/pool"
|
"github.com/hashicorp/consul/agent/pool"
|
||||||
|
@ -54,7 +55,7 @@ func newPanicTestServer(t *testing.T, logger hclog.Logger, name, dc string, tlsC
|
||||||
|
|
||||||
func newTestServer(t *testing.T, logger hclog.Logger, name, dc string, tlsConf *tlsutil.Configurator, register func(server *grpc.Server)) testServer {
|
func newTestServer(t *testing.T, logger hclog.Logger, name, dc string, tlsConf *tlsutil.Configurator, register func(server *grpc.Server)) testServer {
|
||||||
addr := &net.IPAddr{IP: net.ParseIP("127.0.0.1")}
|
addr := &net.IPAddr{IP: net.ParseIP("127.0.0.1")}
|
||||||
handler := NewHandler(logger, addr, register, nil)
|
handler := NewHandler(logger, addr, register, nil, middleware.NullRateLimiter())
|
||||||
|
|
||||||
lis, err := net.Listen("tcp", "127.0.0.1:0")
|
lis, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
|
@ -22,6 +22,7 @@ import (
|
||||||
"github.com/hashicorp/consul/agent/consul/state"
|
"github.com/hashicorp/consul/agent/consul/state"
|
||||||
"github.com/hashicorp/consul/agent/consul/stream"
|
"github.com/hashicorp/consul/agent/consul/stream"
|
||||||
grpc "github.com/hashicorp/consul/agent/grpc-internal"
|
grpc "github.com/hashicorp/consul/agent/grpc-internal"
|
||||||
|
middleware "github.com/hashicorp/consul/agent/grpc-middleware"
|
||||||
"github.com/hashicorp/consul/agent/structs"
|
"github.com/hashicorp/consul/agent/structs"
|
||||||
"github.com/hashicorp/consul/api"
|
"github.com/hashicorp/consul/api"
|
||||||
"github.com/hashicorp/consul/proto/pbcommon"
|
"github.com/hashicorp/consul/proto/pbcommon"
|
||||||
|
@ -380,6 +381,7 @@ func runTestServer(t *testing.T, server *Server) net.Addr {
|
||||||
pbsubscribe.RegisterStateChangeSubscriptionServer(srv, server)
|
pbsubscribe.RegisterStateChangeSubscriptionServer(srv, server)
|
||||||
},
|
},
|
||||||
nil,
|
nil,
|
||||||
|
middleware.NullRateLimiter(),
|
||||||
)
|
)
|
||||||
|
|
||||||
lis, err := net.Listen("tcp", "127.0.0.1:0")
|
lis, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
|
|
@ -14,6 +14,7 @@ import (
|
||||||
|
|
||||||
"github.com/hashicorp/go-hclog"
|
"github.com/hashicorp/go-hclog"
|
||||||
|
|
||||||
|
middleware "github.com/hashicorp/consul/agent/grpc-middleware"
|
||||||
"github.com/hashicorp/consul/agent/grpc-middleware/testutil"
|
"github.com/hashicorp/consul/agent/grpc-middleware/testutil"
|
||||||
"github.com/hashicorp/consul/agent/grpc-middleware/testutil/testservice"
|
"github.com/hashicorp/consul/agent/grpc-middleware/testutil/testservice"
|
||||||
"github.com/hashicorp/consul/proto/prototest"
|
"github.com/hashicorp/consul/proto/prototest"
|
||||||
|
@ -25,7 +26,7 @@ func TestHandler_EmitsStats(t *testing.T) {
|
||||||
sink, metricsObj := testutil.NewFakeSink(t)
|
sink, metricsObj := testutil.NewFakeSink(t)
|
||||||
|
|
||||||
addr := &net.IPAddr{IP: net.ParseIP("127.0.0.1")}
|
addr := &net.IPAddr{IP: net.ParseIP("127.0.0.1")}
|
||||||
handler := NewHandler(hclog.Default(), addr, noopRegister, metricsObj)
|
handler := NewHandler(hclog.Default(), addr, noopRegister, metricsObj, middleware.NullRateLimiter())
|
||||||
|
|
||||||
testservice.RegisterSimpleServer(handler.srv, &testservice.Simple{})
|
testservice.RegisterSimpleServer(handler.srv, &testservice.Simple{})
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,39 @@
|
||||||
|
// Code generated by mockery v2.12.0. DO NOT EDIT.
|
||||||
|
|
||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
testing "testing"
|
||||||
|
|
||||||
|
rate "github.com/hashicorp/consul/agent/consul/rate"
|
||||||
|
mock "github.com/stretchr/testify/mock"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MockRateLimiter is an autogenerated mock type for the RateLimiter type
|
||||||
|
type MockRateLimiter struct {
|
||||||
|
mock.Mock
|
||||||
|
}
|
||||||
|
|
||||||
|
// Allow provides a mock function with given fields: _a0
|
||||||
|
func (_m *MockRateLimiter) Allow(_a0 rate.Operation) error {
|
||||||
|
ret := _m.Called(_a0)
|
||||||
|
|
||||||
|
var r0 error
|
||||||
|
if rf, ok := ret.Get(0).(func(rate.Operation) error); ok {
|
||||||
|
r0 = rf(_a0)
|
||||||
|
} else {
|
||||||
|
r0 = ret.Error(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
return r0
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMockRateLimiter creates a new instance of MockRateLimiter. It also registers the testing.TB interface on the mock and a cleanup function to assert the mocks expectations.
|
||||||
|
func NewMockRateLimiter(t testing.TB) *MockRateLimiter {
|
||||||
|
mock := &MockRateLimiter{}
|
||||||
|
mock.Mock.Test(t)
|
||||||
|
|
||||||
|
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||||
|
|
||||||
|
return mock
|
||||||
|
}
|
|
@ -0,0 +1,72 @@
|
||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"google.golang.org/grpc/codes"
|
||||||
|
"google.golang.org/grpc/peer"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
|
"google.golang.org/grpc/tap"
|
||||||
|
|
||||||
|
recovery "github.com/grpc-ecosystem/go-grpc-middleware/recovery"
|
||||||
|
|
||||||
|
"github.com/hashicorp/consul/agent/consul/rate"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ServerRateLimiterMiddleware implements a ServerInHandle function to perform
|
||||||
|
// RPC rate limiting at the cheapest possible point (before the full request has
|
||||||
|
// been decoded).
|
||||||
|
func ServerRateLimiterMiddleware(limiter RateLimiter, panicHandler recovery.RecoveryHandlerFunc) tap.ServerInHandle {
|
||||||
|
return func(ctx context.Context, info *tap.Info) (_ context.Context, retErr error) {
|
||||||
|
// This function is called before unary and stream RPC interceptors, so we
|
||||||
|
// must handle our own panics here.
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
retErr = panicHandler(r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Do not rate-limit the xDS service, it handles its own limiting.
|
||||||
|
if info.FullMethodName == "/envoy.service.discovery.v3.AggregatedDiscoveryService/DeltaAggregatedResources" {
|
||||||
|
return ctx, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
peer, ok := peer.FromContext(ctx)
|
||||||
|
if !ok {
|
||||||
|
// This should never happen!
|
||||||
|
return ctx, status.Error(codes.Internal, "gRPC rate limit middleware unable to read peer")
|
||||||
|
}
|
||||||
|
|
||||||
|
err := limiter.Allow(rate.Operation{
|
||||||
|
Name: info.FullMethodName,
|
||||||
|
SourceAddr: peer.Addr,
|
||||||
|
// TODO: operation type.
|
||||||
|
})
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case err == nil:
|
||||||
|
return ctx, nil
|
||||||
|
case errors.Is(err, rate.ErrRetryElsewhere):
|
||||||
|
return ctx, status.Error(codes.ResourceExhausted, err.Error())
|
||||||
|
case errors.Is(err, rate.ErrRetryLater):
|
||||||
|
return ctx, status.Error(codes.Unavailable, err.Error())
|
||||||
|
default:
|
||||||
|
return ctx, status.Error(codes.Internal, err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//go:generate mockery --name RateLimiter --inpackage
|
||||||
|
type RateLimiter interface {
|
||||||
|
Allow(rate.Operation) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// NullRateLimiter returns a RateLimiter that allows every operation.
|
||||||
|
func NullRateLimiter() RateLimiter {
|
||||||
|
return nullRateLimiter{}
|
||||||
|
}
|
||||||
|
|
||||||
|
type nullRateLimiter struct{}
|
||||||
|
|
||||||
|
func (nullRateLimiter) Allow(rate.Operation) error { return nil }
|
|
@ -0,0 +1,111 @@
|
||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/mock"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
"google.golang.org/grpc/codes"
|
||||||
|
"google.golang.org/grpc/credentials/insecure"
|
||||||
|
"google.golang.org/grpc/health"
|
||||||
|
healthpb "google.golang.org/grpc/health/grpc_health_v1"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
|
|
||||||
|
"github.com/hashicorp/go-hclog"
|
||||||
|
|
||||||
|
"github.com/hashicorp/consul/agent/consul/rate"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestServerRateLimiterMiddleware_Integration(t *testing.T) {
|
||||||
|
limiter := NewMockRateLimiter(t)
|
||||||
|
|
||||||
|
server := grpc.NewServer(
|
||||||
|
grpc.InTapHandle(ServerRateLimiterMiddleware(limiter, NewPanicHandler(hclog.NewNullLogger()))),
|
||||||
|
)
|
||||||
|
server.RegisterService(&healthpb.Health_ServiceDesc, health.NewServer())
|
||||||
|
|
||||||
|
lis, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
if err := lis.Close(); err != nil {
|
||||||
|
t.Logf("failed to close listener: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
go server.Serve(lis)
|
||||||
|
t.Cleanup(server.Stop)
|
||||||
|
|
||||||
|
conn, err := grpc.Dial(
|
||||||
|
lis.Addr().String(),
|
||||||
|
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
if err := conn.Close(); err != nil {
|
||||||
|
t.Logf("failed to close client connection: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
client := healthpb.NewHealthClient(conn)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
t.Cleanup(cancel)
|
||||||
|
|
||||||
|
t.Run("ErrRetryElsewhere = ResourceExhausted", func(t *testing.T) {
|
||||||
|
limiter.On("Allow", mock.Anything).
|
||||||
|
Run(func(args mock.Arguments) {
|
||||||
|
op := args.Get(0).(rate.Operation)
|
||||||
|
require.Equal(t, "/grpc.health.v1.Health/Check", op.Name)
|
||||||
|
|
||||||
|
addr := op.SourceAddr.(*net.TCPAddr)
|
||||||
|
require.True(t, addr.IP.IsLoopback())
|
||||||
|
}).
|
||||||
|
Return(rate.ErrRetryElsewhere).
|
||||||
|
Once()
|
||||||
|
|
||||||
|
_, err = client.Check(ctx, &healthpb.HealthCheckRequest{})
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Equal(t, codes.ResourceExhausted.String(), status.Code(err).String())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("ErrRetryLater = Unavailable", func(t *testing.T) {
|
||||||
|
limiter.On("Allow", mock.Anything).
|
||||||
|
Return(rate.ErrRetryLater).
|
||||||
|
Once()
|
||||||
|
|
||||||
|
_, err = client.Check(ctx, &healthpb.HealthCheckRequest{})
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Equal(t, codes.Unavailable.String(), status.Code(err).String())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("unexpected error", func(t *testing.T) {
|
||||||
|
limiter.On("Allow", mock.Anything).
|
||||||
|
Return(errors.New("uh oh")).
|
||||||
|
Once()
|
||||||
|
|
||||||
|
_, err = client.Check(ctx, &healthpb.HealthCheckRequest{})
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Equal(t, codes.Internal.String(), status.Code(err).String())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("operation allowed", func(t *testing.T) {
|
||||||
|
limiter.On("Allow", mock.Anything).
|
||||||
|
Return(nil).
|
||||||
|
Once()
|
||||||
|
|
||||||
|
_, err = client.Check(ctx, &healthpb.HealthCheckRequest{})
|
||||||
|
require.NoError(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Allow panics", func(t *testing.T) {
|
||||||
|
limiter.On("Allow", mock.Anything).
|
||||||
|
Panic("uh oh").
|
||||||
|
Once()
|
||||||
|
|
||||||
|
_, err = client.Check(ctx, &healthpb.HealthCheckRequest{})
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Equal(t, codes.Internal.String(), status.Code(err).String())
|
||||||
|
})
|
||||||
|
}
|
|
@ -1591,7 +1591,7 @@ func newTestServer(t *testing.T, cb func(conf *consul.Config)) testingServer {
|
||||||
conf.ACLResolverSettings.EnterpriseMeta = *conf.AgentEnterpriseMeta()
|
conf.ACLResolverSettings.EnterpriseMeta = *conf.AgentEnterpriseMeta()
|
||||||
|
|
||||||
deps := newDefaultDeps(t, conf)
|
deps := newDefaultDeps(t, conf)
|
||||||
externalGRPCServer := external.NewServer(deps.Logger, nil, deps.TLSConfigurator)
|
externalGRPCServer := external.NewServer(deps.Logger, nil, deps.TLSConfigurator, agentmiddleware.NullRateLimiter())
|
||||||
|
|
||||||
server, err := consul.NewServer(conf, deps, externalGRPCServer)
|
server, err := consul.NewServer(conf, deps, externalGRPCServer)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
Loading…
Reference in New Issue