resource: add missing validation to the `List` and `WatchList` endpoints (#17213)
This commit is contained in:
parent
0d54d9a678
commit
f72d75d6b2
|
@ -15,6 +15,10 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s *Server) List(ctx context.Context, req *pbresource.ListRequest) (*pbresource.ListResponse, error) {
|
func (s *Server) List(ctx context.Context, req *pbresource.ListRequest) (*pbresource.ListResponse, error) {
|
||||||
|
if err := validateListRequest(req); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
// check type
|
// check type
|
||||||
reg, err := s.resolveType(req.Type)
|
reg, err := s.resolveType(req.Type)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -65,3 +69,16 @@ func (s *Server) List(ctx context.Context, req *pbresource.ListRequest) (*pbreso
|
||||||
}
|
}
|
||||||
return &pbresource.ListResponse{Resources: result}, nil
|
return &pbresource.ListResponse{Resources: result}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func validateListRequest(req *pbresource.ListRequest) error {
|
||||||
|
var field string
|
||||||
|
switch {
|
||||||
|
case req.Type == nil:
|
||||||
|
field = "type"
|
||||||
|
case req.Tenancy == nil:
|
||||||
|
field = "tenancy"
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return status.Errorf(codes.InvalidArgument, "%s is required", field)
|
||||||
|
}
|
||||||
|
|
|
@ -22,6 +22,31 @@ import (
|
||||||
"google.golang.org/grpc/status"
|
"google.golang.org/grpc/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func TestList_InputValidation(t *testing.T) {
|
||||||
|
server := testServer(t)
|
||||||
|
client := testClient(t, server)
|
||||||
|
|
||||||
|
demo.RegisterTypes(server.Registry)
|
||||||
|
|
||||||
|
testCases := map[string]func(*pbresource.ListRequest){
|
||||||
|
"no type": func(req *pbresource.ListRequest) { req.Type = nil },
|
||||||
|
"no tenancy": func(req *pbresource.ListRequest) { req.Tenancy = nil },
|
||||||
|
}
|
||||||
|
for desc, modFn := range testCases {
|
||||||
|
t.Run(desc, func(t *testing.T) {
|
||||||
|
req := &pbresource.ListRequest{
|
||||||
|
Type: demo.TypeV2Album,
|
||||||
|
Tenancy: demo.TenancyDefault,
|
||||||
|
}
|
||||||
|
modFn(req)
|
||||||
|
|
||||||
|
_, err := client.List(testContext(t), req)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Equal(t, codes.InvalidArgument.String(), status.Code(err).String())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestList_TypeNotFound(t *testing.T) {
|
func TestList_TypeNotFound(t *testing.T) {
|
||||||
server := testServer(t)
|
server := testServer(t)
|
||||||
client := testClient(t, server)
|
client := testClient(t, server)
|
||||||
|
|
|
@ -13,6 +13,10 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s *Server) WatchList(req *pbresource.WatchListRequest, stream pbresource.ResourceService_WatchListServer) error {
|
func (s *Server) WatchList(req *pbresource.WatchListRequest, stream pbresource.ResourceService_WatchListServer) error {
|
||||||
|
if err := validateWatchListRequest(req); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
// check type exists
|
// check type exists
|
||||||
reg, err := s.resolveType(req.Type)
|
reg, err := s.resolveType(req.Type)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -70,3 +74,16 @@ func (s *Server) WatchList(req *pbresource.WatchListRequest, stream pbresource.R
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func validateWatchListRequest(req *pbresource.WatchListRequest) error {
|
||||||
|
var field string
|
||||||
|
switch {
|
||||||
|
case req.Type == nil:
|
||||||
|
field = "type"
|
||||||
|
case req.Tenancy == nil:
|
||||||
|
field = "tenancy"
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return status.Errorf(codes.InvalidArgument, "%s is required", field)
|
||||||
|
}
|
||||||
|
|
|
@ -22,6 +22,34 @@ import (
|
||||||
"google.golang.org/grpc/status"
|
"google.golang.org/grpc/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func TestWatchList_InputValidation(t *testing.T) {
|
||||||
|
server := testServer(t)
|
||||||
|
client := testClient(t, server)
|
||||||
|
|
||||||
|
demo.RegisterTypes(server.Registry)
|
||||||
|
|
||||||
|
testCases := map[string]func(*pbresource.WatchListRequest){
|
||||||
|
"no type": func(req *pbresource.WatchListRequest) { req.Type = nil },
|
||||||
|
"no tenancy": func(req *pbresource.WatchListRequest) { req.Tenancy = nil },
|
||||||
|
}
|
||||||
|
for desc, modFn := range testCases {
|
||||||
|
t.Run(desc, func(t *testing.T) {
|
||||||
|
req := &pbresource.WatchListRequest{
|
||||||
|
Type: demo.TypeV2Album,
|
||||||
|
Tenancy: demo.TenancyDefault,
|
||||||
|
}
|
||||||
|
modFn(req)
|
||||||
|
|
||||||
|
stream, err := client.WatchList(testContext(t), req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = stream.Recv()
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Equal(t, codes.InvalidArgument.String(), status.Code(err).String())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestWatchList_TypeNotFound(t *testing.T) {
|
func TestWatchList_TypeNotFound(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue