feat: panic handler in rpc rate limit interceptor (#16022)

* feat: handle panic in rpc rate limit interceptor

* test: additional test cases to rpc rate limiting interceptor

* refactor: remove unused listener
This commit is contained in:
Poonam Jadhav 2023-01-25 14:13:38 -05:00 committed by GitHub
parent 3e5e03aa95
commit c50bf92b84
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 78 additions and 3 deletions

View File

@ -490,7 +490,7 @@ func NewServer(config *Config, flat Deps, externalGRPCServer *grpc.Server, incom
}
rpcServerOpts := []func(*rpc.Server){
rpc.WithPreBodyInterceptor(middleware.GetNetRPCRateLimitingInterceptor(s.incomingRPCLimiter)),
rpc.WithPreBodyInterceptor(middleware.GetNetRPCRateLimitingInterceptor(s.incomingRPCLimiter, middleware.NewPanicHandler(s.logger))),
}
if flat.GetNetRPCInterceptorFunc != nil {

View File

@ -160,9 +160,16 @@ func GetNetRPCInterceptor(recorder *RequestRecorder) rpc.ServerServiceCallInterc
}
}
func GetNetRPCRateLimitingInterceptor(requestLimitsHandler rpcRate.RequestLimitsHandler) rpc.PreBodyInterceptor {
func GetNetRPCRateLimitingInterceptor(requestLimitsHandler rpcRate.RequestLimitsHandler, panicHandler RecoveryHandlerFunc) rpc.PreBodyInterceptor {
return func(reqServiceMethod string, sourceAddr net.Addr) (retErr error) {
defer func() {
if r := recover(); r != nil {
retErr = panicHandler(r)
}
}()
return func(reqServiceMethod string, sourceAddr net.Addr) error {
op := rpcRate.Operation{
Name: reqServiceMethod,
SourceAddr: sourceAddr,

View File

@ -1,13 +1,18 @@
package middleware
import (
"errors"
"net"
"net/netip"
"strings"
"sync"
"testing"
"time"
"github.com/armon/go-metrics"
"github.com/hashicorp/consul/agent/consul/rate"
"github.com/hashicorp/go-hclog"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
@ -266,3 +271,42 @@ func TestRequestRecorder(t *testing.T) {
})
}
}
func TestGetNetRPCRateLimitingInterceptor(t *testing.T) {
limiter := rate.NewMockRequestLimitsHandler(t)
logger := hclog.NewNullLogger()
rateLimitInterceptor := GetNetRPCRateLimitingInterceptor(limiter, NewPanicHandler(logger))
addr := net.TCPAddrFromAddrPort(netip.MustParseAddrPort("1.2.3.4:5678"))
t.Run("allow operation", func(t *testing.T) {
limiter.On("Allow", mock.Anything).
Return(nil).
Once()
err := rateLimitInterceptor("Status.Leader", addr)
require.NoError(t, err)
})
t.Run("allow returns error", func(t *testing.T) {
limiter.On("Allow", mock.Anything).
Return(errors.New("uh oh")).
Once()
err := rateLimitInterceptor("Status.Leader", addr)
require.Error(t, err)
require.Equal(t, "uh oh", err.Error())
})
t.Run("allow panics", func(t *testing.T) {
limiter.On("Allow", mock.Anything).
Panic("uh oh").
Once()
err := rateLimitInterceptor("Status.Leader", addr)
require.Error(t, err)
require.Equal(t, "rpc: panic serving request", err.Error())
})
}

View File

@ -0,0 +1,24 @@
package middleware
import (
"fmt"
"github.com/hashicorp/go-hclog"
)
// NewPanicHandler returns a RecoveryHandlerFunc type function
// to handle panic in RPC server's handlers.
func NewPanicHandler(logger hclog.Logger) RecoveryHandlerFunc {
return func(p interface{}) (err error) {
// Log the panic and the stack trace of the Goroutine that caused the panic.
stacktrace := hclog.Stacktrace()
logger.Error("panic serving rpc request",
"panic", p,
"stack", stacktrace,
)
return fmt.Errorf("rpc: panic serving request")
}
}
type RecoveryHandlerFunc func(p interface{}) (err error)