diff --git a/agent/consul/server.go b/agent/consul/server.go index e5e4ecb37..a7a651767 100644 --- a/agent/consul/server.go +++ b/agent/consul/server.go @@ -640,7 +640,7 @@ func newGRPCHandlerFromConfig(deps Deps, config *Config, s *Server) connHandler &subscribeBackend{srv: s, connPool: deps.GRPCConnPool}, deps.Logger.Named("grpc-api.subscription"))) } - return agentgrpc.NewHandler(config.RPCAddr, register) + return agentgrpc.NewHandler(deps.Logger, config.RPCAddr, register) } func (s *Server) connectCARootsMonitor(ctx context.Context) { diff --git a/agent/grpc/client_test.go b/agent/grpc/client_test.go index 49922a309..a831bc8ba 100644 --- a/agent/grpc/client_test.go +++ b/agent/grpc/client_test.go @@ -1,6 +1,7 @@ package grpc import ( + "bytes" "context" "fmt" "net" @@ -11,6 +12,8 @@ import ( "github.com/hashicorp/go-hclog" "github.com/stretchr/testify/require" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" "github.com/hashicorp/consul/agent/grpc/internal/testservice" "github.com/hashicorp/consul/agent/grpc/resolver" @@ -54,7 +57,7 @@ func TestNewDialer_IntegrationWithTLSEnabledHandler(t *testing.T) { res := resolver.NewServerResolverBuilder(newConfig(t)) registerWithGRPC(t, res) - srv := newTestServer(t, "server-1", "dc1") + srv := newSimpleTestServer(t, "server-1", "dc1") tlsConf, err := tlsutil.NewConfigurator(tlsutil.Config{ VerifyIncoming: true, VerifyOutgoing: true, @@ -91,7 +94,7 @@ func TestClientConnPool_IntegrationWithGRPCResolver_Failover(t *testing.T) { for i := 0; i < count; i++ { name := fmt.Sprintf("server-%d", i) - srv := newTestServer(t, name, "dc1") + srv := newSimpleTestServer(t, name, "dc1") res.AddServer(srv.Metadata()) t.Cleanup(srv.shutdown) } @@ -128,7 +131,7 @@ func TestClientConnPool_IntegrationWithGRPCResolver_Rebalance(t *testing.T) { for i := 0; i < count; i++ { name := fmt.Sprintf("server-%d", i) - srv := newTestServer(t, name, "dc1") + srv := newSimpleTestServer(t, name, "dc1") res.AddServer(srv.Metadata()) t.Cleanup(srv.shutdown) } @@ -177,7 +180,7 @@ func TestClientConnPool_IntegrationWithGRPCResolver_MultiDC(t *testing.T) { for _, dc := range dcs { name := "server-0-" + dc - srv := newTestServer(t, name, dc) + srv := newSimpleTestServer(t, name, dc) res.AddServer(srv.Metadata()) t.Cleanup(srv.shutdown) } @@ -202,3 +205,41 @@ func registerWithGRPC(t *testing.T, b *resolver.ServerResolverBuilder) { resolver.Deregister(b.Authority()) }) } + +func TestRecoverMiddleware(t *testing.T) { + // Prepare a logger with output to a buffer + // so we can check what it writes. + var buf bytes.Buffer + + logger := hclog.New(&hclog.LoggerOptions{ + Output: &buf, + }) + + res := resolver.NewServerResolverBuilder(newConfig(t)) + registerWithGRPC(t, res) + + srv := newPanicTestServer(t, logger, "server-1", "dc1") + res.AddServer(srv.Metadata()) + t.Cleanup(srv.shutdown) + + pool := NewClientConnPool(res, nil, useTLSForDcAlwaysTrue) + + conn, err := pool.ClientConn("dc1") + require.NoError(t, err) + client := testservice.NewSimpleClient(conn) + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + t.Cleanup(cancel) + + resp, err := client.Something(ctx, &testservice.Req{}) + expectedErr := status.Errorf(codes.Internal, "grpc: panic serving request: panic from Something") + require.Equal(t, expectedErr, err) + require.Nil(t, resp) + + // Read the log + strLog := buf.String() + // Checking the entire stack trace is not possible, let's + // make sure that it contains a couple of expected strings. + require.Contains(t, strLog, `[ERROR] panic serving grpc request: panic="panic from Something`) + require.Contains(t, strLog, `github.com/hashicorp/consul/agent/grpc.(*simplePanic).Something`) +} diff --git a/agent/grpc/handler.go b/agent/grpc/handler.go index e21a4b768..3a48679b0 100644 --- a/agent/grpc/handler.go +++ b/agent/grpc/handler.go @@ -4,30 +4,39 @@ Package grpc provides a Handler and client for agent gRPC connections. package grpc import ( + "context" "fmt" "net" "time" "google.golang.org/grpc" + "google.golang.org/grpc/codes" "google.golang.org/grpc/keepalive" + "google.golang.org/grpc/status" middleware "github.com/grpc-ecosystem/go-grpc-middleware/v2" "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/recovery" + "github.com/hashicorp/go-hclog" ) // NewHandler returns a gRPC server that accepts connections from Handle(conn). // The register function will be called with the grpc.Server to register // gRPC services with the server. -func NewHandler(addr net.Addr, register func(server *grpc.Server)) *Handler { +func NewHandler(logger Logger, addr net.Addr, register func(server *grpc.Server)) *Handler { + recoveryOpts := []recovery.Option{ + recovery.WithRecoveryHandlerContext(newPanicHandler(logger)), + } metrics := defaultMetrics() // We don't need to pass tls.Config to the server since it's multiplexed // behind the RPC listener, which already has TLS configured. srv := grpc.NewServer( middleware.WithUnaryServerChain( - recovery.UnaryServerInterceptor(), + // Add middlware interceptors to recover in case of panics. + recovery.UnaryServerInterceptor(recoveryOpts...), ), middleware.WithStreamServerChain( - recovery.StreamServerInterceptor(), + // Add middlware interceptors to recover in case of panics. + recovery.StreamServerInterceptor(recoveryOpts...), (&activeStreamCounter{metrics: metrics}).Intercept, ), grpc.StatsHandler(newStatsHandler(metrics)), @@ -41,6 +50,21 @@ func NewHandler(addr net.Addr, register func(server *grpc.Server)) *Handler { return &Handler{srv: srv, listener: lis} } +// newPanicHandler returns a recovery.RecoveryHandlerFuncContext closure function +// to handle panic in GRPC server's handlers. +func newPanicHandler(logger Logger) recovery.RecoveryHandlerFuncContext { + return func(ctx context.Context, p interface{}) (err error) { + // Log the panic and the stack trace of the Goroutine that caused the panic. + stacktrace := hclog.Stacktrace() + logger.Error("panic serving grpc request", + "panic", p, + "stack", stacktrace, + ) + + return status.Errorf(codes.Internal, "grpc: panic serving request: %v", p) + } +} + // Handler implements a handler for the rpc server listener, and the // agent.Component interface for managing the lifecycle of the grpc.Server. type Handler struct { diff --git a/agent/grpc/server_test.go b/agent/grpc/server_test.go index 442b617d5..d6efa826d 100644 --- a/agent/grpc/server_test.go +++ b/agent/grpc/server_test.go @@ -18,6 +18,7 @@ import ( "github.com/hashicorp/consul/agent/metadata" "github.com/hashicorp/consul/agent/pool" "github.com/hashicorp/consul/tlsutil" + "github.com/hashicorp/go-hclog" ) type testServer struct { @@ -37,11 +38,22 @@ func (s testServer) Metadata() *metadata.Server { } } -func newTestServer(t *testing.T, name string, dc string) testServer { - addr := &net.IPAddr{IP: net.ParseIP("127.0.0.1")} - handler := NewHandler(addr, func(server *grpc.Server) { +func newSimpleTestServer(t *testing.T, name, dc string) testServer { + return newTestServer(t, hclog.Default(), name, dc, func(server *grpc.Server) { testservice.RegisterSimpleServer(server, &simple{name: name, dc: dc}) }) +} + +// newPanicTestServer sets up a simple server with handlers that panic. +func newPanicTestServer(t *testing.T, logger hclog.Logger, name, dc string) testServer { + return newTestServer(t, logger, name, dc, func(server *grpc.Server) { + testservice.RegisterSimpleServer(server, &simplePanic{name: name, dc: dc}) + }) +} + +func newTestServer(t *testing.T, logger hclog.Logger, name, dc string, register func(server *grpc.Server)) testServer { + addr := &net.IPAddr{IP: net.ParseIP("127.0.0.1")} + handler := NewHandler(logger, addr, register) lis, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) @@ -101,6 +113,23 @@ func (s *simple) Something(_ context.Context, _ *testservice.Req) (*testservice. return &testservice.Resp{ServerName: s.name, Datacenter: s.dc}, nil } +type simplePanic struct { + name, dc string +} + +func (s *simplePanic) Flow(_ *testservice.Req, flow testservice.Simple_FlowServer) error { + for flow.Context().Err() == nil { + time.Sleep(time.Millisecond) + panic("panic from Flow") + } + return nil +} + +func (s *simplePanic) Something(_ context.Context, _ *testservice.Req) (*testservice.Resp, error) { + time.Sleep(time.Millisecond) + panic("panic from Something") +} + // fakeRPCListener mimics agent/consul.Server.listen to handle the RPCType byte. // In the future we should be able to refactor Server and extract this RPC // handling logic so that we don't need to use a fake. diff --git a/agent/grpc/stats_test.go b/agent/grpc/stats_test.go index 475bbf6df..079de3408 100644 --- a/agent/grpc/stats_test.go +++ b/agent/grpc/stats_test.go @@ -15,6 +15,7 @@ import ( "google.golang.org/grpc" "github.com/hashicorp/consul/agent/grpc/internal/testservice" + "github.com/hashicorp/go-hclog" ) func noopRegister(*grpc.Server) {} @@ -23,7 +24,7 @@ func TestHandler_EmitsStats(t *testing.T) { sink, reset := patchGlobalMetrics(t) addr := &net.IPAddr{IP: net.ParseIP("127.0.0.1")} - handler := NewHandler(addr, noopRegister) + handler := NewHandler(hclog.Default(), addr, noopRegister) reset() testservice.RegisterSimpleServer(handler.srv, &simple{}) diff --git a/agent/rpc/subscribe/subscribe_test.go b/agent/rpc/subscribe/subscribe_test.go index d2c13716d..7ec636ec8 100644 --- a/agent/rpc/subscribe/subscribe_test.go +++ b/agent/rpc/subscribe/subscribe_test.go @@ -317,7 +317,7 @@ var _ Backend = (*testBackend)(nil) func runTestServer(t *testing.T, server *Server) net.Addr { addr := &net.IPAddr{IP: net.ParseIP("127.0.0.1")} var grpcServer *gogrpc.Server - handler := grpc.NewHandler(addr, func(srv *gogrpc.Server) { + handler := grpc.NewHandler(hclog.New(nil), addr, func(srv *gogrpc.Server) { grpcServer = srv pbsubscribe.RegisterStateChangeSubscriptionServer(srv, server) }) diff --git a/agent/xds/server.go b/agent/xds/server.go index 19ff44aba..011cdb653 100644 --- a/agent/xds/server.go +++ b/agent/xds/server.go @@ -545,15 +545,36 @@ func tokenFromContext(ctx context.Context) string { return "" } +// newPanicHandler returns a recovery.RecoveryHandlerFuncContext closure function +// to handle panic in GRPC server's handlers. +func newPanicHandler(logger hclog.Logger) recovery.RecoveryHandlerFuncContext { + return func(ctx context.Context, p interface{}) (err error) { + // Log the panic and the stack trace of the Goroutine that caused the panic. + stacktrace := hclog.Stacktrace() + logger.Error("panic serving grpc request", + "panic", p, + "stack", stacktrace, + ) + + return status.Errorf(codes.Internal, "grpc: panic serving request: %v", p) + } +} + // GRPCServer returns a server instance that can handle xDS requests. func (s *Server) GRPCServer(tlsConfigurator *tlsutil.Configurator) (*grpc.Server, error) { + recoveryOpts := []recovery.Option{ + recovery.WithRecoveryHandlerContext(newPanicHandler(s.Logger)), + } + opts := []grpc.ServerOption{ grpc.MaxConcurrentStreams(2048), middleware.WithUnaryServerChain( - recovery.UnaryServerInterceptor(), + // Add middlware interceptors to recover in case of panics. + recovery.UnaryServerInterceptor(recoveryOpts...), ), middleware.WithStreamServerChain( - recovery.StreamServerInterceptor(), + // Add middlware interceptors to recover in case of panics. + recovery.StreamServerInterceptor(recoveryOpts...), ), } if tlsConfigurator != nil {