From f72d75d6b2d046366261ba288be7744674246221 Mon Sep 17 00:00:00 2001 From: Dan Upton Date: Wed, 10 May 2023 10:38:48 +0100 Subject: [PATCH] resource: add missing validation to the `List` and `WatchList` endpoints (#17213) --- agent/grpc-external/services/resource/list.go | 17 +++++++++++ .../services/resource/list_test.go | 25 +++++++++++++++++ .../grpc-external/services/resource/watch.go | 17 +++++++++++ .../services/resource/watch_test.go | 28 +++++++++++++++++++ 4 files changed, 87 insertions(+) diff --git a/agent/grpc-external/services/resource/list.go b/agent/grpc-external/services/resource/list.go index 65cc37a26..77269e746 100644 --- a/agent/grpc-external/services/resource/list.go +++ b/agent/grpc-external/services/resource/list.go @@ -15,6 +15,10 @@ import ( ) func (s *Server) List(ctx context.Context, req *pbresource.ListRequest) (*pbresource.ListResponse, error) { + if err := validateListRequest(req); err != nil { + return nil, err + } + // check type reg, err := s.resolveType(req.Type) if err != nil { @@ -65,3 +69,16 @@ func (s *Server) List(ctx context.Context, req *pbresource.ListRequest) (*pbreso } return &pbresource.ListResponse{Resources: result}, nil } + +func validateListRequest(req *pbresource.ListRequest) error { + var field string + switch { + case req.Type == nil: + field = "type" + case req.Tenancy == nil: + field = "tenancy" + default: + return nil + } + return status.Errorf(codes.InvalidArgument, "%s is required", field) +} diff --git a/agent/grpc-external/services/resource/list_test.go b/agent/grpc-external/services/resource/list_test.go index b476c82ae..7128d5e31 100644 --- a/agent/grpc-external/services/resource/list_test.go +++ b/agent/grpc-external/services/resource/list_test.go @@ -22,6 +22,31 @@ import ( "google.golang.org/grpc/status" ) +func TestList_InputValidation(t *testing.T) { + server := testServer(t) + client := testClient(t, server) + + demo.RegisterTypes(server.Registry) + + testCases := map[string]func(*pbresource.ListRequest){ + "no type": func(req *pbresource.ListRequest) { req.Type = nil }, + "no tenancy": func(req *pbresource.ListRequest) { req.Tenancy = nil }, + } + for desc, modFn := range testCases { + t.Run(desc, func(t *testing.T) { + req := &pbresource.ListRequest{ + Type: demo.TypeV2Album, + Tenancy: demo.TenancyDefault, + } + modFn(req) + + _, err := client.List(testContext(t), req) + require.Error(t, err) + require.Equal(t, codes.InvalidArgument.String(), status.Code(err).String()) + }) + } +} + func TestList_TypeNotFound(t *testing.T) { server := testServer(t) client := testClient(t, server) diff --git a/agent/grpc-external/services/resource/watch.go b/agent/grpc-external/services/resource/watch.go index 77ffe19b0..2fd943a6c 100644 --- a/agent/grpc-external/services/resource/watch.go +++ b/agent/grpc-external/services/resource/watch.go @@ -13,6 +13,10 @@ import ( ) func (s *Server) WatchList(req *pbresource.WatchListRequest, stream pbresource.ResourceService_WatchListServer) error { + if err := validateWatchListRequest(req); err != nil { + return err + } + // check type exists reg, err := s.resolveType(req.Type) if err != nil { @@ -70,3 +74,16 @@ func (s *Server) WatchList(req *pbresource.WatchListRequest, stream pbresource.R } } } + +func validateWatchListRequest(req *pbresource.WatchListRequest) error { + var field string + switch { + case req.Type == nil: + field = "type" + case req.Tenancy == nil: + field = "tenancy" + default: + return nil + } + return status.Errorf(codes.InvalidArgument, "%s is required", field) +} diff --git a/agent/grpc-external/services/resource/watch_test.go b/agent/grpc-external/services/resource/watch_test.go index b62dc8a40..687fe0d06 100644 --- a/agent/grpc-external/services/resource/watch_test.go +++ b/agent/grpc-external/services/resource/watch_test.go @@ -22,6 +22,34 @@ import ( "google.golang.org/grpc/status" ) +func TestWatchList_InputValidation(t *testing.T) { + server := testServer(t) + client := testClient(t, server) + + demo.RegisterTypes(server.Registry) + + testCases := map[string]func(*pbresource.WatchListRequest){ + "no type": func(req *pbresource.WatchListRequest) { req.Type = nil }, + "no tenancy": func(req *pbresource.WatchListRequest) { req.Tenancy = nil }, + } + for desc, modFn := range testCases { + t.Run(desc, func(t *testing.T) { + req := &pbresource.WatchListRequest{ + Type: demo.TypeV2Album, + Tenancy: demo.TenancyDefault, + } + modFn(req) + + stream, err := client.WatchList(testContext(t), req) + require.NoError(t, err) + + _, err = stream.Recv() + require.Error(t, err) + require.Equal(t, codes.InvalidArgument.String(), status.Code(err).String()) + }) + } +} + func TestWatchList_TypeNotFound(t *testing.T) { t.Parallel()