// Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: MPL-2.0 package middleware import ( "context" "errors" "google.golang.org/grpc/codes" "google.golang.org/grpc/peer" "google.golang.org/grpc/status" "google.golang.org/grpc/tap" recovery "github.com/grpc-ecosystem/go-grpc-middleware/recovery" "github.com/hashicorp/consul/agent/consul/rate" ) // ServerRateLimiterMiddleware implements a ServerInHandle function to perform // RPC rate limiting at the cheapest possible point (before the full request has // been decoded). func ServerRateLimiterMiddleware(limiter rate.RequestLimitsHandler, panicHandler recovery.RecoveryHandlerFunc, logger Logger) tap.ServerInHandle { return func(ctx context.Context, info *tap.Info) (_ context.Context, retErr error) { // This function is called before unary and stream RPC interceptors, so we // must handle our own panics here. defer func() { if r := recover(); r != nil { retErr = panicHandler(r) } }() // Do not rate-limit the xDS service, it handles its own limiting. if info.FullMethodName == "/envoy.service.discovery.v3.AggregatedDiscoveryService/DeltaAggregatedResources" { return ctx, nil } peer, ok := peer.FromContext(ctx) if !ok { // This should never happen! return ctx, status.Error(codes.Internal, "gRPC rate limit middleware unable to read peer") } operationSpec, ok := rpcRateLimitSpecs[info.FullMethodName] if !ok { logger.Warn("failed to determine which rate limit to apply to RPC", "rpc", info.FullMethodName) return ctx, nil } err := limiter.Allow(rate.Operation{ Name: info.FullMethodName, SourceAddr: peer.Addr, Type: operationSpec.Type, Category: operationSpec.Category, }) switch { case err == nil: return ctx, nil case errors.Is(err, rate.ErrRetryElsewhere): return ctx, status.Error(codes.ResourceExhausted, err.Error()) case errors.Is(err, rate.ErrRetryLater): return ctx, status.Error(codes.Unavailable, err.Error()) default: return ctx, status.Error(codes.Internal, err.Error()) } } }