grpc: add rate-limiting middleware (#15550)
Implements the gRPC middleware for rate-limiting as a tap.ServerInHandle function (executed before the request is unmarshaled). Mappings between gRPC methods and their operation type are generated by a protoc plugin introduced by #15564.
This commit is contained in:
parent
4894848993
commit
c73707ca3c
|
@ -563,12 +563,19 @@ func (a *Agent) Start(ctx context.Context) error {
|
|||
return fmt.Errorf("Failed to load TLS configurations after applying auto-config settings: %w", err)
|
||||
}
|
||||
|
||||
// gRPC calls are only rate-limited on server, not client agents.
|
||||
grpcRateLimiter := middleware.NullRateLimiter()
|
||||
if s, ok := a.delegate.(*consul.Server); ok {
|
||||
grpcRateLimiter = s.IncomingRPCLimiter()
|
||||
}
|
||||
|
||||
// This needs to happen after the initial auto-config is loaded, because TLS
|
||||
// can only be configured on the gRPC server at the point of creation.
|
||||
a.externalGRPCServer = external.NewServer(
|
||||
a.logger.Named("grpc.external"),
|
||||
metrics.Default(),
|
||||
a.tlsConfigurator,
|
||||
grpcRateLimiter,
|
||||
)
|
||||
|
||||
if err := a.startLicenseManager(ctx); err != nil {
|
||||
|
|
|
@ -18,7 +18,6 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/armon/go-metrics"
|
||||
"github.com/hashicorp/consul-net-rpc/net/rpc"
|
||||
"github.com/hashicorp/go-connlimit"
|
||||
"github.com/hashicorp/go-hclog"
|
||||
"github.com/hashicorp/go-memdb"
|
||||
|
@ -30,6 +29,8 @@ import (
|
|||
"golang.org/x/time/rate"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/hashicorp/consul-net-rpc/net/rpc"
|
||||
|
||||
"github.com/hashicorp/consul/acl"
|
||||
"github.com/hashicorp/consul/agent/consul/authmethod"
|
||||
"github.com/hashicorp/consul/agent/consul/authmethod/ssoauth"
|
||||
|
@ -876,7 +877,7 @@ func newGRPCHandlerFromConfig(deps Deps, config *Config, s *Server) connHandler
|
|||
s.externalConnectCAServer.Register(srv)
|
||||
}
|
||||
|
||||
return agentgrpc.NewHandler(deps.Logger, config.RPCAddr, register, nil)
|
||||
return agentgrpc.NewHandler(deps.Logger, config.RPCAddr, register, nil, s.incomingRPCLimiter)
|
||||
}
|
||||
|
||||
func (s *Server) connectCARootsMonitor(ctx context.Context) {
|
||||
|
@ -1829,6 +1830,11 @@ func (s *Server) hcpServerStatus(deps Deps) hcp.StatusCallback {
|
|||
}
|
||||
}
|
||||
|
||||
// IncomingRPCLimiter returns the server's configured rate limit handler for
|
||||
// incoming RPCs. This is necessary because the external gRPC server is created
|
||||
// by the agent (as it is also used for xDS).
|
||||
func (s *Server) IncomingRPCLimiter() *rpcRate.Handler { return s.incomingRPCLimiter }
|
||||
|
||||
// peersInfoContent is used to help operators understand what happened to the
|
||||
// peers.json file. This is written to a file called peers.info in the same
|
||||
// location.
|
||||
|
|
|
@ -331,7 +331,7 @@ func newServerWithDeps(t *testing.T, c *Config, deps Deps) (*Server, error) {
|
|||
oldNotify()
|
||||
}
|
||||
}
|
||||
grpcServer := external.NewServer(deps.Logger.Named("grpc.external"), nil, deps.TLSConfigurator)
|
||||
grpcServer := external.NewServer(deps.Logger.Named("grpc.external"), nil, deps.TLSConfigurator, grpcmiddleware.NullRateLimiter())
|
||||
srv, err := NewServer(c, deps, grpcServer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
@ -23,7 +23,7 @@ var (
|
|||
|
||||
// NewServer constructs a gRPC server for the external gRPC port, to which
|
||||
// handlers can be registered.
|
||||
func NewServer(logger agentmiddleware.Logger, metricsObj *metrics.Metrics, tls *tlsutil.Configurator) *grpc.Server {
|
||||
func NewServer(logger agentmiddleware.Logger, metricsObj *metrics.Metrics, tls *tlsutil.Configurator, limiter agentmiddleware.RateLimiter) *grpc.Server {
|
||||
if metricsObj == nil {
|
||||
metricsObj = metrics.Default()
|
||||
}
|
||||
|
@ -48,6 +48,7 @@ func NewServer(logger agentmiddleware.Logger, metricsObj *metrics.Metrics, tls *
|
|||
opts := []grpc.ServerOption{
|
||||
grpc.MaxConcurrentStreams(2048),
|
||||
grpc.MaxRecvMsgSize(50 * 1024 * 1024),
|
||||
grpc.InTapHandle(agentmiddleware.ServerRateLimiterMiddleware(limiter, agentmiddleware.NewPanicHandler(logger))),
|
||||
grpc.StatsHandler(agentmiddleware.NewStatsHandler(metricsObj, metricsLabels)),
|
||||
middleware.WithUnaryServerChain(unaryInterceptors...),
|
||||
middleware.WithStreamServerChain(streamInterceptors...),
|
||||
|
|
|
@ -23,7 +23,7 @@ import (
|
|||
func TestServer_EmitsStats(t *testing.T) {
|
||||
sink, metricsObj := testutil.NewFakeSink(t)
|
||||
|
||||
srv := NewServer(hclog.Default(), metricsObj, nil)
|
||||
srv := NewServer(hclog.Default(), metricsObj, nil, grpcmiddleware.NullRateLimiter())
|
||||
|
||||
testservice.RegisterSimpleServer(srv, &testservice.Simple{})
|
||||
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/armon/go-metrics"
|
||||
|
||||
agentmiddleware "github.com/hashicorp/consul/agent/grpc-middleware"
|
||||
|
||||
middleware "github.com/grpc-ecosystem/go-grpc-middleware"
|
||||
|
@ -24,7 +25,7 @@ var (
|
|||
// NewHandler returns a gRPC server that accepts connections from Handle(conn).
|
||||
// The register function will be called with the grpc.Server to register
|
||||
// gRPC services with the server.
|
||||
func NewHandler(logger Logger, addr net.Addr, register func(server *grpc.Server), metricsObj *metrics.Metrics) *Handler {
|
||||
func NewHandler(logger Logger, addr net.Addr, register func(server *grpc.Server), metricsObj *metrics.Metrics, rateLimiter agentmiddleware.RateLimiter) *Handler {
|
||||
if metricsObj == nil {
|
||||
metricsObj = metrics.Default()
|
||||
}
|
||||
|
@ -34,6 +35,7 @@ func NewHandler(logger Logger, addr net.Addr, register func(server *grpc.Server)
|
|||
recoveryOpts := agentmiddleware.PanicHandlerMiddlewareOpts(logger)
|
||||
|
||||
opts := []grpc.ServerOption{
|
||||
grpc.InTapHandle(agentmiddleware.ServerRateLimiterMiddleware(rateLimiter, agentmiddleware.NewPanicHandler(logger))),
|
||||
grpc.StatsHandler(agentmiddleware.NewStatsHandler(metricsObj, metricsLabels)),
|
||||
middleware.WithUnaryServerChain(
|
||||
// Add middlware interceptors to recover in case of panics.
|
||||
|
|
|
@ -14,6 +14,7 @@ import (
|
|||
"golang.org/x/sync/errgroup"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
middleware "github.com/hashicorp/consul/agent/grpc-middleware"
|
||||
"github.com/hashicorp/consul/agent/grpc-middleware/testutil/testservice"
|
||||
"github.com/hashicorp/consul/agent/metadata"
|
||||
"github.com/hashicorp/consul/agent/pool"
|
||||
|
@ -54,7 +55,7 @@ func newPanicTestServer(t *testing.T, logger hclog.Logger, name, dc string, tlsC
|
|||
|
||||
func newTestServer(t *testing.T, logger hclog.Logger, name, dc string, tlsConf *tlsutil.Configurator, register func(server *grpc.Server)) testServer {
|
||||
addr := &net.IPAddr{IP: net.ParseIP("127.0.0.1")}
|
||||
handler := NewHandler(logger, addr, register, nil)
|
||||
handler := NewHandler(logger, addr, register, nil, middleware.NullRateLimiter())
|
||||
|
||||
lis, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
|
|
@ -22,6 +22,7 @@ import (
|
|||
"github.com/hashicorp/consul/agent/consul/state"
|
||||
"github.com/hashicorp/consul/agent/consul/stream"
|
||||
grpc "github.com/hashicorp/consul/agent/grpc-internal"
|
||||
middleware "github.com/hashicorp/consul/agent/grpc-middleware"
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/hashicorp/consul/api"
|
||||
"github.com/hashicorp/consul/proto/pbcommon"
|
||||
|
@ -380,6 +381,7 @@ func runTestServer(t *testing.T, server *Server) net.Addr {
|
|||
pbsubscribe.RegisterStateChangeSubscriptionServer(srv, server)
|
||||
},
|
||||
nil,
|
||||
middleware.NullRateLimiter(),
|
||||
)
|
||||
|
||||
lis, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
|
|
|
@ -14,6 +14,7 @@ import (
|
|||
|
||||
"github.com/hashicorp/go-hclog"
|
||||
|
||||
middleware "github.com/hashicorp/consul/agent/grpc-middleware"
|
||||
"github.com/hashicorp/consul/agent/grpc-middleware/testutil"
|
||||
"github.com/hashicorp/consul/agent/grpc-middleware/testutil/testservice"
|
||||
"github.com/hashicorp/consul/proto/prototest"
|
||||
|
@ -25,7 +26,7 @@ func TestHandler_EmitsStats(t *testing.T) {
|
|||
sink, metricsObj := testutil.NewFakeSink(t)
|
||||
|
||||
addr := &net.IPAddr{IP: net.ParseIP("127.0.0.1")}
|
||||
handler := NewHandler(hclog.Default(), addr, noopRegister, metricsObj)
|
||||
handler := NewHandler(hclog.Default(), addr, noopRegister, metricsObj, middleware.NullRateLimiter())
|
||||
|
||||
testservice.RegisterSimpleServer(handler.srv, &testservice.Simple{})
|
||||
|
||||
|
|
|
@ -0,0 +1,39 @@
|
|||
// Code generated by mockery v2.12.0. DO NOT EDIT.
|
||||
|
||||
package middleware
|
||||
|
||||
import (
|
||||
testing "testing"
|
||||
|
||||
rate "github.com/hashicorp/consul/agent/consul/rate"
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
// MockRateLimiter is an autogenerated mock type for the RateLimiter type
|
||||
type MockRateLimiter struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
// Allow provides a mock function with given fields: _a0
|
||||
func (_m *MockRateLimiter) Allow(_a0 rate.Operation) error {
|
||||
ret := _m.Called(_a0)
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(rate.Operation) error); ok {
|
||||
r0 = rf(_a0)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// NewMockRateLimiter creates a new instance of MockRateLimiter. It also registers the testing.TB interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
func NewMockRateLimiter(t testing.TB) *MockRateLimiter {
|
||||
mock := &MockRateLimiter{}
|
||||
mock.Mock.Test(t)
|
||||
|
||||
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||
|
||||
return mock
|
||||
}
|
|
@ -0,0 +1,72 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/peer"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/grpc/tap"
|
||||
|
||||
recovery "github.com/grpc-ecosystem/go-grpc-middleware/recovery"
|
||||
|
||||
"github.com/hashicorp/consul/agent/consul/rate"
|
||||
)
|
||||
|
||||
// ServerRateLimiterMiddleware implements a ServerInHandle function to perform
|
||||
// RPC rate limiting at the cheapest possible point (before the full request has
|
||||
// been decoded).
|
||||
func ServerRateLimiterMiddleware(limiter RateLimiter, panicHandler recovery.RecoveryHandlerFunc) tap.ServerInHandle {
|
||||
return func(ctx context.Context, info *tap.Info) (_ context.Context, retErr error) {
|
||||
// This function is called before unary and stream RPC interceptors, so we
|
||||
// must handle our own panics here.
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
retErr = panicHandler(r)
|
||||
}
|
||||
}()
|
||||
|
||||
// Do not rate-limit the xDS service, it handles its own limiting.
|
||||
if info.FullMethodName == "/envoy.service.discovery.v3.AggregatedDiscoveryService/DeltaAggregatedResources" {
|
||||
return ctx, nil
|
||||
}
|
||||
|
||||
peer, ok := peer.FromContext(ctx)
|
||||
if !ok {
|
||||
// This should never happen!
|
||||
return ctx, status.Error(codes.Internal, "gRPC rate limit middleware unable to read peer")
|
||||
}
|
||||
|
||||
err := limiter.Allow(rate.Operation{
|
||||
Name: info.FullMethodName,
|
||||
SourceAddr: peer.Addr,
|
||||
// TODO: operation type.
|
||||
})
|
||||
|
||||
switch {
|
||||
case err == nil:
|
||||
return ctx, nil
|
||||
case errors.Is(err, rate.ErrRetryElsewhere):
|
||||
return ctx, status.Error(codes.ResourceExhausted, err.Error())
|
||||
case errors.Is(err, rate.ErrRetryLater):
|
||||
return ctx, status.Error(codes.Unavailable, err.Error())
|
||||
default:
|
||||
return ctx, status.Error(codes.Internal, err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//go:generate mockery --name RateLimiter --inpackage
|
||||
type RateLimiter interface {
|
||||
Allow(rate.Operation) error
|
||||
}
|
||||
|
||||
// NullRateLimiter returns a RateLimiter that allows every operation.
|
||||
func NullRateLimiter() RateLimiter {
|
||||
return nullRateLimiter{}
|
||||
}
|
||||
|
||||
type nullRateLimiter struct{}
|
||||
|
||||
func (nullRateLimiter) Allow(rate.Operation) error { return nil }
|
|
@ -0,0 +1,111 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
"google.golang.org/grpc/health"
|
||||
healthpb "google.golang.org/grpc/health/grpc_health_v1"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/hashicorp/go-hclog"
|
||||
|
||||
"github.com/hashicorp/consul/agent/consul/rate"
|
||||
)
|
||||
|
||||
func TestServerRateLimiterMiddleware_Integration(t *testing.T) {
|
||||
limiter := NewMockRateLimiter(t)
|
||||
|
||||
server := grpc.NewServer(
|
||||
grpc.InTapHandle(ServerRateLimiterMiddleware(limiter, NewPanicHandler(hclog.NewNullLogger()))),
|
||||
)
|
||||
server.RegisterService(&healthpb.Health_ServiceDesc, health.NewServer())
|
||||
|
||||
lis, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
if err := lis.Close(); err != nil {
|
||||
t.Logf("failed to close listener: %v", err)
|
||||
}
|
||||
})
|
||||
go server.Serve(lis)
|
||||
t.Cleanup(server.Stop)
|
||||
|
||||
conn, err := grpc.Dial(
|
||||
lis.Addr().String(),
|
||||
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
t.Logf("failed to close client connection: %v", err)
|
||||
}
|
||||
})
|
||||
client := healthpb.NewHealthClient(conn)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
t.Cleanup(cancel)
|
||||
|
||||
t.Run("ErrRetryElsewhere = ResourceExhausted", func(t *testing.T) {
|
||||
limiter.On("Allow", mock.Anything).
|
||||
Run(func(args mock.Arguments) {
|
||||
op := args.Get(0).(rate.Operation)
|
||||
require.Equal(t, "/grpc.health.v1.Health/Check", op.Name)
|
||||
|
||||
addr := op.SourceAddr.(*net.TCPAddr)
|
||||
require.True(t, addr.IP.IsLoopback())
|
||||
}).
|
||||
Return(rate.ErrRetryElsewhere).
|
||||
Once()
|
||||
|
||||
_, err = client.Check(ctx, &healthpb.HealthCheckRequest{})
|
||||
require.Error(t, err)
|
||||
require.Equal(t, codes.ResourceExhausted.String(), status.Code(err).String())
|
||||
})
|
||||
|
||||
t.Run("ErrRetryLater = Unavailable", func(t *testing.T) {
|
||||
limiter.On("Allow", mock.Anything).
|
||||
Return(rate.ErrRetryLater).
|
||||
Once()
|
||||
|
||||
_, err = client.Check(ctx, &healthpb.HealthCheckRequest{})
|
||||
require.Error(t, err)
|
||||
require.Equal(t, codes.Unavailable.String(), status.Code(err).String())
|
||||
})
|
||||
|
||||
t.Run("unexpected error", func(t *testing.T) {
|
||||
limiter.On("Allow", mock.Anything).
|
||||
Return(errors.New("uh oh")).
|
||||
Once()
|
||||
|
||||
_, err = client.Check(ctx, &healthpb.HealthCheckRequest{})
|
||||
require.Error(t, err)
|
||||
require.Equal(t, codes.Internal.String(), status.Code(err).String())
|
||||
})
|
||||
|
||||
t.Run("operation allowed", func(t *testing.T) {
|
||||
limiter.On("Allow", mock.Anything).
|
||||
Return(nil).
|
||||
Once()
|
||||
|
||||
_, err = client.Check(ctx, &healthpb.HealthCheckRequest{})
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("Allow panics", func(t *testing.T) {
|
||||
limiter.On("Allow", mock.Anything).
|
||||
Panic("uh oh").
|
||||
Once()
|
||||
|
||||
_, err = client.Check(ctx, &healthpb.HealthCheckRequest{})
|
||||
require.Error(t, err)
|
||||
require.Equal(t, codes.Internal.String(), status.Code(err).String())
|
||||
})
|
||||
}
|
|
@ -1591,7 +1591,7 @@ func newTestServer(t *testing.T, cb func(conf *consul.Config)) testingServer {
|
|||
conf.ACLResolverSettings.EnterpriseMeta = *conf.AgentEnterpriseMeta()
|
||||
|
||||
deps := newDefaultDeps(t, conf)
|
||||
externalGRPCServer := external.NewServer(deps.Logger, nil, deps.TLSConfigurator)
|
||||
externalGRPCServer := external.NewServer(deps.Logger, nil, deps.TLSConfigurator, agentmiddleware.NullRateLimiter())
|
||||
|
||||
server, err := consul.NewServer(conf, deps, externalGRPCServer)
|
||||
require.NoError(t, err)
|
||||
|
|
Loading…
Reference in New Issue