open-consul/agent/grpc-middleware/rate.go

73 lines
2.0 KiB
Go
Raw Normal View History

package middleware
import (
"context"
"errors"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/status"
"google.golang.org/grpc/tap"
recovery "github.com/grpc-ecosystem/go-grpc-middleware/recovery"
"github.com/hashicorp/consul/agent/consul/rate"
)
// ServerRateLimiterMiddleware implements a ServerInHandle function to perform
// RPC rate limiting at the cheapest possible point (before the full request has
// been decoded).
//
// Every incoming RPC is described to limiter as a rate.Operation carrying the
// full method name and the peer's address. The limiter's verdict is mapped to
// a gRPC status as follows:
//
//   - nil: the RPC proceeds with the unmodified context
//   - rate.ErrRetryElsewhere: codes.ResourceExhausted
//   - rate.ErrRetryLater: codes.Unavailable
//   - any other error: codes.Internal
//
// If anything inside the handle panics, the recovered value is passed to
// panicHandler and the error it returns becomes the handle's error.
func ServerRateLimiterMiddleware(limiter RateLimiter, panicHandler recovery.RecoveryHandlerFunc) tap.ServerInHandle {
	// Do not rate-limit the xDS service, it handles its own limiting.
	const xdsDeltaMethod = "/envoy.service.discovery.v3.AggregatedDiscoveryService/DeltaAggregatedResources"

	return func(ctx context.Context, info *tap.Info) (retCtx context.Context, retErr error) {
		// This function is called before unary and stream RPC interceptors, so we
		// must handle our own panics here.
		defer func() {
			if r := recover(); r != nil {
				// Hand back the incoming context rather than the zero value so
				// that a nil context is never returned, even if panicHandler
				// chooses to return a nil error.
				retCtx, retErr = ctx, panicHandler(r)
			}
		}()

		if info.FullMethodName == xdsDeltaMethod {
			return ctx, nil
		}

		// Named p (not peer) so the imported peer package is not shadowed.
		p, ok := peer.FromContext(ctx)
		if !ok {
			// This should never happen!
			return ctx, status.Error(codes.Internal, "gRPC rate limit middleware unable to read peer")
		}

		err := limiter.Allow(rate.Operation{
			Name:       info.FullMethodName,
			SourceAddr: p.Addr,
			// TODO: operation type.
		})

		switch {
		case err == nil:
			return ctx, nil
		case errors.Is(err, rate.ErrRetryElsewhere):
			return ctx, status.Error(codes.ResourceExhausted, err.Error())
		case errors.Is(err, rate.ErrRetryLater):
			return ctx, status.Error(codes.Unavailable, err.Error())
		default:
			return ctx, status.Error(codes.Internal, err.Error())
		}
	}
}
// RateLimiter is the dependency the rate-limiting middleware consults for each
// operation. Allow returns nil when the operation may proceed and a non-nil
// error when it must be rejected.
//
//go:generate mockery --name RateLimiter --inpackage
type RateLimiter interface {
	Allow(op rate.Operation) error
}
// NullRateLimiter provides a RateLimiter that never rejects: its Allow method
// returns nil for every operation.
func NullRateLimiter() RateLimiter {
	var noop nullRateLimiter
	return noop
}
// nullRateLimiter is a stateless limiter that permits every operation.
type nullRateLimiter struct{}

// Allow ignores the operation and always reports success.
func (nullRateLimiter) Allow(_ rate.Operation) error {
	return nil
}