resource: add missing validation to the `List` and `WatchList` endpoints (#17213)

This commit is contained in:
Dan Upton 2023-05-10 10:38:48 +01:00 committed by GitHub
parent 0d54d9a678
commit f72d75d6b2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 87 additions and 0 deletions

View File

@ -15,6 +15,10 @@ import (
) )
func (s *Server) List(ctx context.Context, req *pbresource.ListRequest) (*pbresource.ListResponse, error) { func (s *Server) List(ctx context.Context, req *pbresource.ListRequest) (*pbresource.ListResponse, error) {
if err := validateListRequest(req); err != nil {
return nil, err
}
// check type // check type
reg, err := s.resolveType(req.Type) reg, err := s.resolveType(req.Type)
if err != nil { if err != nil {
@ -65,3 +69,16 @@ func (s *Server) List(ctx context.Context, req *pbresource.ListRequest) (*pbreso
} }
return &pbresource.ListResponse{Resources: result}, nil return &pbresource.ListResponse{Resources: result}, nil
} }
func validateListRequest(req *pbresource.ListRequest) error {
var field string
switch {
case req.Type == nil:
field = "type"
case req.Tenancy == nil:
field = "tenancy"
default:
return nil
}
return status.Errorf(codes.InvalidArgument, "%s is required", field)
}

View File

@ -22,6 +22,31 @@ import (
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
) )
func TestList_InputValidation(t *testing.T) {
server := testServer(t)
client := testClient(t, server)
demo.RegisterTypes(server.Registry)
testCases := map[string]func(*pbresource.ListRequest){
"no type": func(req *pbresource.ListRequest) { req.Type = nil },
"no tenancy": func(req *pbresource.ListRequest) { req.Tenancy = nil },
}
for desc, modFn := range testCases {
t.Run(desc, func(t *testing.T) {
req := &pbresource.ListRequest{
Type: demo.TypeV2Album,
Tenancy: demo.TenancyDefault,
}
modFn(req)
_, err := client.List(testContext(t), req)
require.Error(t, err)
require.Equal(t, codes.InvalidArgument.String(), status.Code(err).String())
})
}
}
func TestList_TypeNotFound(t *testing.T) { func TestList_TypeNotFound(t *testing.T) {
server := testServer(t) server := testServer(t)
client := testClient(t, server) client := testClient(t, server)

View File

@ -13,6 +13,10 @@ import (
) )
func (s *Server) WatchList(req *pbresource.WatchListRequest, stream pbresource.ResourceService_WatchListServer) error { func (s *Server) WatchList(req *pbresource.WatchListRequest, stream pbresource.ResourceService_WatchListServer) error {
if err := validateWatchListRequest(req); err != nil {
return err
}
// check type exists // check type exists
reg, err := s.resolveType(req.Type) reg, err := s.resolveType(req.Type)
if err != nil { if err != nil {
@ -70,3 +74,16 @@ func (s *Server) WatchList(req *pbresource.WatchListRequest, stream pbresource.R
} }
} }
} }
func validateWatchListRequest(req *pbresource.WatchListRequest) error {
var field string
switch {
case req.Type == nil:
field = "type"
case req.Tenancy == nil:
field = "tenancy"
default:
return nil
}
return status.Errorf(codes.InvalidArgument, "%s is required", field)
}

View File

@ -22,6 +22,34 @@ import (
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
) )
func TestWatchList_InputValidation(t *testing.T) {
server := testServer(t)
client := testClient(t, server)
demo.RegisterTypes(server.Registry)
testCases := map[string]func(*pbresource.WatchListRequest){
"no type": func(req *pbresource.WatchListRequest) { req.Type = nil },
"no tenancy": func(req *pbresource.WatchListRequest) { req.Tenancy = nil },
}
for desc, modFn := range testCases {
t.Run(desc, func(t *testing.T) {
req := &pbresource.WatchListRequest{
Type: demo.TypeV2Album,
Tenancy: demo.TenancyDefault,
}
modFn(req)
stream, err := client.WatchList(testContext(t), req)
require.NoError(t, err)
_, err = stream.Recv()
require.Error(t, err)
require.Equal(t, codes.InvalidArgument.String(), status.Code(err).String())
})
}
}
func TestWatchList_TypeNotFound(t *testing.T) { func TestWatchList_TypeNotFound(t *testing.T) {
t.Parallel() t.Parallel()