resource: add missing validation to the `List` and `WatchList` endpoints (#17213)
This commit is contained in:
parent
0d54d9a678
commit
f72d75d6b2
|
@ -15,6 +15,10 @@ import (
|
|||
)
|
||||
|
||||
func (s *Server) List(ctx context.Context, req *pbresource.ListRequest) (*pbresource.ListResponse, error) {
|
||||
if err := validateListRequest(req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// check type
|
||||
reg, err := s.resolveType(req.Type)
|
||||
if err != nil {
|
||||
|
@ -65,3 +69,16 @@ func (s *Server) List(ctx context.Context, req *pbresource.ListRequest) (*pbreso
|
|||
}
|
||||
return &pbresource.ListResponse{Resources: result}, nil
|
||||
}
|
||||
|
||||
func validateListRequest(req *pbresource.ListRequest) error {
|
||||
var field string
|
||||
switch {
|
||||
case req.Type == nil:
|
||||
field = "type"
|
||||
case req.Tenancy == nil:
|
||||
field = "tenancy"
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
return status.Errorf(codes.InvalidArgument, "%s is required", field)
|
||||
}
|
||||
|
|
|
@ -22,6 +22,31 @@ import (
|
|||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
func TestList_InputValidation(t *testing.T) {
|
||||
server := testServer(t)
|
||||
client := testClient(t, server)
|
||||
|
||||
demo.RegisterTypes(server.Registry)
|
||||
|
||||
testCases := map[string]func(*pbresource.ListRequest){
|
||||
"no type": func(req *pbresource.ListRequest) { req.Type = nil },
|
||||
"no tenancy": func(req *pbresource.ListRequest) { req.Tenancy = nil },
|
||||
}
|
||||
for desc, modFn := range testCases {
|
||||
t.Run(desc, func(t *testing.T) {
|
||||
req := &pbresource.ListRequest{
|
||||
Type: demo.TypeV2Album,
|
||||
Tenancy: demo.TenancyDefault,
|
||||
}
|
||||
modFn(req)
|
||||
|
||||
_, err := client.List(testContext(t), req)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, codes.InvalidArgument.String(), status.Code(err).String())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestList_TypeNotFound(t *testing.T) {
|
||||
server := testServer(t)
|
||||
client := testClient(t, server)
|
||||
|
|
|
@ -13,6 +13,10 @@ import (
|
|||
)
|
||||
|
||||
func (s *Server) WatchList(req *pbresource.WatchListRequest, stream pbresource.ResourceService_WatchListServer) error {
|
||||
if err := validateWatchListRequest(req); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// check type exists
|
||||
reg, err := s.resolveType(req.Type)
|
||||
if err != nil {
|
||||
|
@ -70,3 +74,16 @@ func (s *Server) WatchList(req *pbresource.WatchListRequest, stream pbresource.R
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func validateWatchListRequest(req *pbresource.WatchListRequest) error {
|
||||
var field string
|
||||
switch {
|
||||
case req.Type == nil:
|
||||
field = "type"
|
||||
case req.Tenancy == nil:
|
||||
field = "tenancy"
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
return status.Errorf(codes.InvalidArgument, "%s is required", field)
|
||||
}
|
||||
|
|
|
@ -22,6 +22,34 @@ import (
|
|||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
func TestWatchList_InputValidation(t *testing.T) {
|
||||
server := testServer(t)
|
||||
client := testClient(t, server)
|
||||
|
||||
demo.RegisterTypes(server.Registry)
|
||||
|
||||
testCases := map[string]func(*pbresource.WatchListRequest){
|
||||
"no type": func(req *pbresource.WatchListRequest) { req.Type = nil },
|
||||
"no tenancy": func(req *pbresource.WatchListRequest) { req.Tenancy = nil },
|
||||
}
|
||||
for desc, modFn := range testCases {
|
||||
t.Run(desc, func(t *testing.T) {
|
||||
req := &pbresource.WatchListRequest{
|
||||
Type: demo.TypeV2Album,
|
||||
Tenancy: demo.TenancyDefault,
|
||||
}
|
||||
modFn(req)
|
||||
|
||||
stream, err := client.WatchList(testContext(t), req)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = stream.Recv()
|
||||
require.Error(t, err)
|
||||
require.Equal(t, codes.InvalidArgument.String(), status.Code(err).String())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWatchList_TypeNotFound(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
|
Loading…
Reference in New Issue