diff --git a/agent/consul/server.go b/agent/consul/server.go index a12038be5..94f4560a0 100644 --- a/agent/consul/server.go +++ b/agent/consul/server.go @@ -490,7 +490,7 @@ func NewServer(config *Config, flat Deps, externalGRPCServer *grpc.Server, incom } rpcServerOpts := []func(*rpc.Server){ - rpc.WithPreBodyInterceptor(middleware.GetNetRPCRateLimitingInterceptor(s.incomingRPCLimiter)), + rpc.WithPreBodyInterceptor(middleware.GetNetRPCRateLimitingInterceptor(s.incomingRPCLimiter, middleware.NewPanicHandler(s.logger))), } if flat.GetNetRPCInterceptorFunc != nil { diff --git a/agent/rpc/middleware/interceptors.go b/agent/rpc/middleware/interceptors.go index a4aa432d6..8c5d6c15e 100644 --- a/agent/rpc/middleware/interceptors.go +++ b/agent/rpc/middleware/interceptors.go @@ -160,9 +160,16 @@ func GetNetRPCInterceptor(recorder *RequestRecorder) rpc.ServerServiceCallInterc } } -func GetNetRPCRateLimitingInterceptor(requestLimitsHandler rpcRate.RequestLimitsHandler) rpc.PreBodyInterceptor { +func GetNetRPCRateLimitingInterceptor(requestLimitsHandler rpcRate.RequestLimitsHandler, panicHandler RecoveryHandlerFunc) rpc.PreBodyInterceptor { + + return func(reqServiceMethod string, sourceAddr net.Addr) (retErr error) { + + defer func() { + if r := recover(); r != nil { + retErr = panicHandler(r) + } + }() - return func(reqServiceMethod string, sourceAddr net.Addr) error { op := rpcRate.Operation{ Name: reqServiceMethod, SourceAddr: sourceAddr, diff --git a/agent/rpc/middleware/interceptors_test.go b/agent/rpc/middleware/interceptors_test.go index c47cf17f4..fda01199b 100644 --- a/agent/rpc/middleware/interceptors_test.go +++ b/agent/rpc/middleware/interceptors_test.go @@ -1,13 +1,18 @@ package middleware import ( + "errors" + "net" + "net/netip" "strings" "sync" "testing" "time" "github.com/armon/go-metrics" + "github.com/hashicorp/consul/agent/consul/rate" "github.com/hashicorp/go-hclog" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) @@ -266,3 +271,42 @@ func TestRequestRecorder(t *testing.T) { }) } } + +func TestGetNetRPCRateLimitingInterceptor(t *testing.T) { + limiter := rate.NewMockRequestLimitsHandler(t) + + logger := hclog.NewNullLogger() + rateLimitInterceptor := GetNetRPCRateLimitingInterceptor(limiter, NewPanicHandler(logger)) + + addr := net.TCPAddrFromAddrPort(netip.MustParseAddrPort("1.2.3.4:5678")) + + t.Run("allow operation", func(t *testing.T) { + limiter.On("Allow", mock.Anything). + Return(nil). + Once() + + err := rateLimitInterceptor("Status.Leader", addr) + require.NoError(t, err) + }) + + t.Run("allow returns error", func(t *testing.T) { + limiter.On("Allow", mock.Anything). + Return(errors.New("uh oh")). + Once() + + err := rateLimitInterceptor("Status.Leader", addr) + require.Error(t, err) + require.Equal(t, "uh oh", err.Error()) + }) + + t.Run("allow panics", func(t *testing.T) { + limiter.On("Allow", mock.Anything). + Panic("uh oh"). + Once() + + err := rateLimitInterceptor("Status.Leader", addr) + + require.Error(t, err) + require.Equal(t, "rpc: panic serving request", err.Error()) + }) +} diff --git a/agent/rpc/middleware/recovery.go b/agent/rpc/middleware/recovery.go new file mode 100644 index 000000000..f381f0ee2 --- /dev/null +++ b/agent/rpc/middleware/recovery.go @@ -0,0 +1,24 @@ +package middleware + +import ( + "fmt" + + "github.com/hashicorp/go-hclog" +) + +// NewPanicHandler returns a RecoveryHandlerFunc type function +// to handle panic in RPC server's handlers. +func NewPanicHandler(logger hclog.Logger) RecoveryHandlerFunc { + return func(p interface{}) (err error) { + // Log the panic and the stack trace of the Goroutine that caused the panic. + stacktrace := hclog.Stacktrace() + logger.Error("panic serving rpc request", + "panic", p, + "stack", stacktrace, + ) + + return fmt.Errorf("rpc: panic serving request") + } +} + +type RecoveryHandlerFunc func(p interface{}) (err error)