Change error-handling across handlers. (#12225)

This commit is contained in:
Mathew Estafanous 2022-01-31 11:17:35 -05:00 committed by GitHub
parent 8f33317819
commit 1113a7533c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 149 additions and 254 deletions

View File

@ -136,9 +136,7 @@ func (s *HTTPHandlers) CatalogRegister(resp http.ResponseWriter, req *http.Reque
}
if err := s.rewordUnknownEnterpriseFieldError(decodeBody(req.Body, &args)); err != nil {
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprintf(resp, "Request decode failed: %v", err)
return nil, nil
return nil, BadRequestError{Reason: fmt.Sprintf("Request decode failed: %v", err)}
}
// Setup the default DC if not provided
@ -168,9 +166,7 @@ func (s *HTTPHandlers) CatalogDeregister(resp http.ResponseWriter, req *http.Req
return nil, err
}
if err := s.rewordUnknownEnterpriseFieldError(decodeBody(req.Body, &args)); err != nil {
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprintf(resp, "Request decode failed: %v", err)
return nil, nil
return nil, BadRequestError{Reason: fmt.Sprintf("Request decode failed: %v", err)}
}
// Setup the default DC if not provided
@ -367,9 +363,7 @@ func (s *HTTPHandlers) catalogServiceNodes(resp http.ResponseWriter, req *http.R
return nil, err
}
if args.ServiceName == "" {
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprint(resp, "Missing service name")
return nil, nil
return nil, BadRequestError{Reason: "Missing service name"}
}
// Make the RPC request
@ -444,9 +438,7 @@ func (s *HTTPHandlers) CatalogNodeServices(resp http.ResponseWriter, req *http.R
return nil, err
}
if args.Node == "" {
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprint(resp, "Missing node name")
return nil, nil
return nil, BadRequestError{Reason: "Missing node name"}
}
// Make the RPC request
@ -511,9 +503,7 @@ func (s *HTTPHandlers) CatalogNodeServiceList(resp http.ResponseWriter, req *htt
return nil, err
}
if args.Node == "" {
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprint(resp, "Missing node name")
return nil, nil
return nil, BadRequestError{Reason: "Missing node name"}
}
// Make the RPC request
@ -564,9 +554,7 @@ func (s *HTTPHandlers) CatalogGatewayServices(resp http.ResponseWriter, req *htt
return nil, err
}
if args.ServiceName == "" {
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprint(resp, "Missing gateway name")
return nil, nil
return nil, BadRequestError{Reason: "Missing gateway name"}
}
// Make the RPC request

View File

@ -90,16 +90,12 @@ func (s *HTTPHandlers) configDelete(resp http.ResponseWriter, req *http.Request)
pathArgs := strings.SplitN(kindAndName, "/", 2)
if len(pathArgs) != 2 {
resp.WriteHeader(http.StatusNotFound)
fmt.Fprintf(resp, "Must provide both a kind and name to delete")
return nil, nil
return nil, NotFoundError{Reason: "Must provide both a kind and name to delete"}
}
entry, err := structs.MakeConfigEntry(pathArgs[0], pathArgs[1])
if err != nil {
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprintf(resp, "%v", err)
return nil, nil
return nil, BadRequestError{Reason: err.Error()}
}
args.Entry = entry
// Parse enterprise meta.

View File

@ -8,16 +8,13 @@ import (
"github.com/hashicorp/consul/agent/structs"
)
// checkCoordinateDisabled will return a standard response if coordinates are
// disabled. This returns true if they are disabled and we should not continue.
func (s *HTTPHandlers) checkCoordinateDisabled(resp http.ResponseWriter, req *http.Request) bool {
// checkCoordinateDisabled will return an unauthorized error if coordinates are
// disabled. Otherwise, a nil error will be returned.
func (s *HTTPHandlers) checkCoordinateDisabled() error {
if !s.agent.config.DisableCoordinates {
return false
return nil
}
resp.WriteHeader(http.StatusUnauthorized)
fmt.Fprint(resp, "Coordinate support disabled")
return true
return UnauthorizedError{Reason: "Coordinate support disabled"}
}
// sorter wraps a coordinate list and implements the sort.Interface to sort by
@ -44,8 +41,8 @@ func (s *sorter) Less(i, j int) bool {
// CoordinateDatacenters returns the WAN nodes in each datacenter, along with
// raw network coordinates.
func (s *HTTPHandlers) CoordinateDatacenters(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
if s.checkCoordinateDisabled(resp, req) {
return nil, nil
if err := s.checkCoordinateDisabled(); err != nil {
return nil, err
}
var out []structs.DatacenterMap
@ -73,8 +70,8 @@ func (s *HTTPHandlers) CoordinateDatacenters(resp http.ResponseWriter, req *http
// CoordinateNodes returns the LAN nodes in the given datacenter, along with
// raw network coordinates.
func (s *HTTPHandlers) CoordinateNodes(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
if s.checkCoordinateDisabled(resp, req) {
return nil, nil
if err := s.checkCoordinateDisabled(); err != nil {
return nil, err
}
args := structs.DCSpecificRequest{}
@ -98,8 +95,8 @@ func (s *HTTPHandlers) CoordinateNodes(resp http.ResponseWriter, req *http.Reque
// CoordinateNode returns the LAN node in the given datacenter, along with
// raw network coordinates.
func (s *HTTPHandlers) CoordinateNode(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
if s.checkCoordinateDisabled(resp, req) {
return nil, nil
if err := s.checkCoordinateDisabled(); err != nil {
return nil, err
}
node, err := getPathSuffixUnescaped(req.URL.Path, "/v1/coordinate/node/")
@ -153,15 +150,13 @@ func filterCoordinates(req *http.Request, in structs.Coordinates) structs.Coordi
// CoordinateUpdate inserts or updates the LAN coordinate of a node.
func (s *HTTPHandlers) CoordinateUpdate(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
if s.checkCoordinateDisabled(resp, req) {
return nil, nil
if err := s.checkCoordinateDisabled(); err != nil {
return nil, err
}
args := structs.CoordinateUpdateRequest{}
if err := decodeBody(req.Body, &args); err != nil {
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprintf(resp, "Request decode failed: %v", err)
return nil, nil
return nil, BadRequestError{Reason: fmt.Sprintf("Request decode failed: %v", err)}
}
s.parseDC(req, &args.Datacenter)
s.parseToken(req, &args.Token)

View File

@ -39,16 +39,14 @@ func TestCoordinate_Disabled_Response(t *testing.T) {
req, _ := http.NewRequest("PUT", "/should/not/care", nil)
resp := httptest.NewRecorder()
obj, err := tt(resp, req)
if err != nil {
t.Fatalf("err: %v", err)
err, ok := err.(UnauthorizedError)
if !ok {
t.Fatalf("expected unauthorized error but got %v", err)
}
if obj != nil {
t.Fatalf("bad: %#v", obj)
}
if got, want := resp.Code, http.StatusUnauthorized; got != want {
t.Fatalf("got %d want %d", got, want)
}
if !strings.Contains(resp.Body.String(), "Coordinate support disabled") {
if !strings.Contains(err.Error(), "Coordinate support disabled") {
t.Fatalf("bad: %#v", resp)
}
})

View File

@ -51,9 +51,7 @@ func (s *HTTPHandlers) DiscoveryChainRead(resp http.ResponseWriter, req *http.Re
if apiReq.OverrideMeshGateway.Mode != "" {
_, err := structs.ValidateMeshGatewayMode(string(apiReq.OverrideMeshGateway.Mode))
if err != nil {
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprint(resp, "Invalid OverrideMeshGateway.Mode parameter")
return nil, nil
return nil, BadRequestError{Reason: "Invalid OverrideMeshGateway.Mode parameter"}
}
args.OverrideMeshGateway = apiReq.OverrideMeshGateway
}

View File

@ -2,7 +2,6 @@ package agent
import (
"bytes"
"fmt"
"io"
"net/http"
"strconv"
@ -26,9 +25,7 @@ func (s *HTTPHandlers) EventFire(resp http.ResponseWriter, req *http.Request) (i
return nil, err
}
if event.Name == "" {
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprint(resp, "Missing name")
return nil, nil
return nil, BadRequestError{Reason: "Missing name"}
}
// Get the ACL token
@ -58,9 +55,7 @@ func (s *HTTPHandlers) EventFire(resp http.ResponseWriter, req *http.Request) (i
// Try to fire the event
if err := s.agent.UserEvent(dc, token, event); err != nil {
if acl.IsErrPermissionDenied(err) {
resp.WriteHeader(http.StatusForbidden)
fmt.Fprint(resp, acl.ErrPermissionDenied.Error())
return nil, nil
return nil, ForbiddenError{Reason: acl.ErrPermissionDenied.Error()}
}
resp.WriteHeader(http.StatusInternalServerError)
return nil, err

View File

@ -88,13 +88,11 @@ func TestEventFire_token(t *testing.T) {
url := fmt.Sprintf("/v1/event/fire/%s?token=%s", c.event, token)
req, _ := http.NewRequest("PUT", url, nil)
resp := httptest.NewRecorder()
if _, err := a.srv.EventFire(resp, req); err != nil {
t.Fatalf("err: %s", err)
}
_, err := a.srv.EventFire(resp, req)
// Check the result
body := resp.Body.String()
if c.allowed {
body := resp.Body.String()
if acl.IsErrPermissionDenied(errors.New(body)) {
t.Fatalf("bad: %s", body)
}
@ -102,11 +100,11 @@ func TestEventFire_token(t *testing.T) {
t.Fatalf("bad: %d", resp.Code)
}
} else {
if !acl.IsErrPermissionDenied(errors.New(body)) {
t.Fatalf("bad: %s", body)
if !acl.IsErrPermissionDenied(err) {
t.Fatalf("bad: %s", err.Error())
}
if resp.Code != 403 {
t.Fatalf("bad: %d", resp.Code)
if err, ok := err.(ForbiddenError); !ok {
t.Fatalf("Expected forbidden but got %v", err)
}
}
}

View File

@ -1,7 +1,6 @@
package agent
import (
"fmt"
"net/http"
"net/url"
"strconv"
@ -36,9 +35,7 @@ func (s *HTTPHandlers) HealthChecksInState(resp http.ResponseWriter, req *http.R
return nil, err
}
if args.State == "" {
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprint(resp, "Missing check state")
return nil, nil
return nil, BadRequestError{Reason: "Missing check state"}
}
// Make the RPC request
@ -82,9 +79,7 @@ func (s *HTTPHandlers) HealthNodeChecks(resp http.ResponseWriter, req *http.Requ
// Pull out the service name
args.Node = strings.TrimPrefix(req.URL.Path, "/v1/health/node/")
if args.Node == "" {
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprint(resp, "Missing node name")
return nil, nil
return nil, BadRequestError{Reason: "Missing node name"}
}
// Make the RPC request
@ -130,9 +125,7 @@ func (s *HTTPHandlers) HealthServiceChecks(resp http.ResponseWriter, req *http.R
// Pull out the service name
args.ServiceName = strings.TrimPrefix(req.URL.Path, "/v1/health/checks/")
if args.ServiceName == "" {
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprint(resp, "Missing service name")
return nil, nil
return nil, BadRequestError{Reason: "Missing service name"}
}
// Make the RPC request
@ -218,9 +211,7 @@ func (s *HTTPHandlers) healthServiceNodes(resp http.ResponseWriter, req *http.Re
// Pull out the service name
args.ServiceName = strings.TrimPrefix(req.URL.Path, prefix)
if args.ServiceName == "" {
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprint(resp, "Missing service name")
return nil, nil
return nil, BadRequestError{Reason: "Missing service name"}
}
out, md, err := s.agent.rpcClientHealth.ServiceNodes(req.Context(), args)
@ -238,9 +229,7 @@ func (s *HTTPHandlers) healthServiceNodes(resp http.ResponseWriter, req *http.Re
// Filter to only passing if specified
filter, err := getBoolQueryParam(params, api.HealthPassing)
if err != nil {
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprint(resp, "Invalid value for ?passing")
return nil, nil
return nil, BadRequestError{Reason: "Invalid value for ?passing"}
}
// FIXME: remove filterNonPassing, replace with nodes.Filter, which is used by DNSServer

View File

@ -1,14 +1,13 @@
package agent
import (
"bytes"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"net/url"
"reflect"
"strconv"
"strings"
"testing"
"time"
@ -1241,18 +1240,12 @@ func TestHealthServiceNodes_PassingFilter(t *testing.T) {
t.Run("passing_bad", func(t *testing.T) {
req, _ := http.NewRequest("GET", "/v1/health/service/consul?passing=nope-nope-nope", nil)
resp := httptest.NewRecorder()
a.srv.HealthServiceNodes(resp, req)
if code := resp.Code; code != 400 {
t.Errorf("bad response code %d, expected %d", code, 400)
_, err := a.srv.HealthServiceNodes(resp, req)
if _, ok := err.(BadRequestError); !ok {
t.Fatalf("Expected bad request error but got %v", err)
}
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}
if !bytes.Contains(body, []byte("Invalid value for ?passing")) {
t.Errorf("bad %s", body)
if !strings.Contains(err.Error(), "Invalid value for ?passing") {
t.Errorf("bad %s", err.Error())
}
})
}
@ -1654,12 +1647,12 @@ func TestHealthConnectServiceNodes_PassingFilter(t *testing.T) {
req, _ := http.NewRequest("GET", fmt.Sprintf(
"/v1/health/connect/%s?passing=nope-nope", args.Service.Proxy.DestinationServiceName), nil)
resp := httptest.NewRecorder()
a.srv.HealthConnectServiceNodes(resp, req)
assert.Equal(t, 400, resp.Code)
_, err := a.srv.HealthConnectServiceNodes(resp, req)
assert.NotNil(t, err)
_, ok := err.(BadRequestError)
assert.True(t, ok)
body, err := ioutil.ReadAll(resp.Body)
assert.Nil(t, err)
assert.True(t, bytes.Contains(body, []byte("Invalid value for ?passing")))
assert.True(t, strings.Contains(err.Error(), "Invalid value for ?passing"))
})
}

View File

@ -78,6 +78,14 @@ func (e UnauthorizedError) Error() string {
return e.Reason
}
type EntityTooLargeError struct {
Reason string
}
func (e EntityTooLargeError) Error() string {
return e.Reason
}
// CodeWithPayloadError allow returning non HTTP 200
// Error codes while not returning PlainText payload
type CodeWithPayloadError struct {
@ -91,10 +99,11 @@ func (e CodeWithPayloadError) Error() string {
}
type ForbiddenError struct {
Reason string
}
func (e ForbiddenError) Error() string {
return "Access is restricted"
return e.Reason
}
// HTTPHandlers provides an HTTP api for an agent.
@ -443,6 +452,11 @@ func (s *HTTPHandlers) wrap(handler endpoint, methods []string) http.HandlerFunc
return err.Error() == consul.ErrRateLimited.Error()
}
isEntityToLarge := func(err error) bool {
_, ok := err.(EntityTooLargeError)
return ok
}
addAllowHeader := func(methods []string) {
resp.Header().Add("Allow", strings.Join(methods, ","))
}
@ -488,6 +502,9 @@ func (s *HTTPHandlers) wrap(handler endpoint, methods []string) http.HandlerFunc
case isTooManyRequests(err):
resp.WriteHeader(http.StatusTooManyRequests)
fmt.Fprint(resp, err.Error())
case isEntityToLarge(err):
resp.WriteHeader(http.StatusRequestEntityTooLarge)
fmt.Fprint(resp, err.Error())
default:
resp.WriteHeader(http.StatusInternalServerError)
fmt.Fprint(resp, err.Error())
@ -1136,7 +1153,7 @@ func (s *HTTPHandlers) checkWriteAccess(req *http.Request) error {
}
}
return ForbiddenError{}
return ForbiddenError{Reason: "Access is restricted"}
}
func (s *HTTPHandlers) parseFilter(req *http.Request, filter *string) {

View File

@ -323,9 +323,7 @@ func (s *HTTPHandlers) IntentionGetExact(resp http.ResponseWriter, req *http.Req
if err := s.agent.RPC("Intention.Get", &args, &reply); err != nil {
// We have to check the string since the RPC sheds the error type
if err.Error() == consul.ErrIntentionNotFound.Error() {
resp.WriteHeader(http.StatusNotFound)
fmt.Fprint(resp, err.Error())
return nil, nil
return nil, NotFoundError{Reason: err.Error()}
}
// Not ideal, but there are a number of error scenarios that are not
@ -521,9 +519,7 @@ func (s *HTTPHandlers) IntentionSpecificGet(id string, resp http.ResponseWriter,
if err := s.agent.RPC("Intention.Get", &args, &reply); err != nil {
// We have to check the string since the RPC sheds the error type
if err.Error() == consul.ErrIntentionNotFound.Error() {
resp.WriteHeader(http.StatusNotFound)
fmt.Fprint(resp, err.Error())
return nil, nil
return nil, NotFoundError{Reason: err.Error()}
}
// Not ideal, but there are a number of error scenarios that are not

View File

@ -55,8 +55,8 @@ func (s *HTTPHandlers) KVSGet(resp http.ResponseWriter, req *http.Request, args
params := req.URL.Query()
if _, ok := params["recurse"]; ok {
method = "KVS.List"
} else if missingKey(resp, args) {
return nil, nil
} else if args.Key == "" {
return nil, BadRequestError{Reason: "Missing key name"}
}
// Do not allow wildcard NS on GET reqs
@ -156,8 +156,8 @@ func (s *HTTPHandlers) KVSPut(resp http.ResponseWriter, req *http.Request, args
if err := s.parseEntMetaNoWildcard(req, &args.EnterpriseMeta); err != nil {
return nil, err
}
if missingKey(resp, args) {
return nil, nil
if args.Key == "" {
return nil, BadRequestError{Reason: "Missing key name"}
}
if conflictingFlags(resp, req, "cas", "acquire", "release") {
return nil, nil
@ -208,13 +208,10 @@ func (s *HTTPHandlers) KVSPut(resp http.ResponseWriter, req *http.Request, args
// Check the content-length
if req.ContentLength > int64(s.agent.config.KVMaxValueSize) {
resp.WriteHeader(http.StatusRequestEntityTooLarge)
fmt.Fprintf(resp,
"Request body(%d bytes) too large, max size: %d bytes. See %s.",
req.ContentLength, s.agent.config.KVMaxValueSize,
"https://www.consul.io/docs/agent/options.html#kv_max_value_size",
)
return nil, nil
return nil, EntityTooLargeError{
Reason: fmt.Sprintf("Request body(%d bytes) too large, max size: %d bytes. See %s.",
req.ContentLength, s.agent.config.KVMaxValueSize, "https://www.consul.io/docs/agent/options.html#kv_max_value_size"),
}
}
// Copy the value
@ -259,8 +256,8 @@ func (s *HTTPHandlers) KVSDelete(resp http.ResponseWriter, req *http.Request, ar
params := req.URL.Query()
if _, ok := params["recurse"]; ok {
applyReq.Op = api.KVDeleteTree
} else if missingKey(resp, args) {
return nil, nil
} else if args.Key == "" {
return nil, BadRequestError{Reason: "Missing key name"}
}
// Check for cas value
@ -286,16 +283,6 @@ func (s *HTTPHandlers) KVSDelete(resp http.ResponseWriter, req *http.Request, ar
return true, nil
}
// missingKey checks if the key is missing
func missingKey(resp http.ResponseWriter, args *structs.KeyRequest) bool {
if args.Key == "" {
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprint(resp, "Missing key name")
return true
}
return false
}
// conflictingFlags determines if non-composable flags were passed in a request.
func conflictingFlags(resp http.ResponseWriter, req *http.Request, flags ...string) bool {
params := req.URL.Query()

View File

@ -23,9 +23,7 @@ func (s *HTTPHandlers) preparedQueryCreate(resp http.ResponseWriter, req *http.R
s.parseDC(req, &args.Datacenter)
s.parseToken(req, &args.Token)
if err := decodeBody(req.Body, &args.Query); err != nil {
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprintf(resp, "Request decode failed: %v", err)
return nil, nil
return nil, BadRequestError{Reason: fmt.Sprintf("Request decode failed: %v", err)}
}
var reply string
@ -145,9 +143,7 @@ func (s *HTTPHandlers) preparedQueryExecute(id string, resp http.ResponseWriter,
// We have to check the string since the RPC sheds
// the specific error type.
if structs.IsErrQueryNotFound(err) {
resp.WriteHeader(http.StatusNotFound)
fmt.Fprint(resp, err.Error())
return nil, nil
return nil, NotFoundError{Reason: err.Error()}
}
return nil, err
}
@ -200,9 +196,7 @@ RETRY_ONCE:
// We have to check the string since the RPC sheds
// the specific error type.
if structs.IsErrQueryNotFound(err) {
resp.WriteHeader(http.StatusNotFound)
fmt.Fprint(resp, err.Error())
return nil, nil
return nil, NotFoundError{Reason: err.Error()}
}
return nil, err
}
@ -231,9 +225,7 @@ RETRY_ONCE:
// We have to check the string since the RPC sheds
// the specific error type.
if structs.IsErrQueryNotFound(err) {
resp.WriteHeader(http.StatusNotFound)
fmt.Fprint(resp, err.Error())
return nil, nil
return nil, NotFoundError{Reason: err.Error()}
}
return nil, err
}
@ -255,9 +247,7 @@ func (s *HTTPHandlers) preparedQueryUpdate(id string, resp http.ResponseWriter,
s.parseToken(req, &args.Token)
if req.ContentLength > 0 {
if err := decodeBody(req.Body, &args.Query); err != nil {
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprintf(resp, "Request decode failed: %v", err)
return nil, nil
return nil, BadRequestError{Reason: fmt.Sprintf("Request decode failed: %v", err)}
}
}

View File

@ -620,11 +620,9 @@ func TestPreparedQuery_Execute(t *testing.T) {
body := bytes.NewBuffer(nil)
req, _ := http.NewRequest("GET", "/v1/query/not-there/execute", body)
resp := httptest.NewRecorder()
if _, err := a.srv.PreparedQuerySpecific(resp, req); err != nil {
t.Fatalf("err: %v", err)
}
if resp.Code != 404 {
t.Fatalf("bad code: %d", resp.Code)
_, err := a.srv.PreparedQuerySpecific(resp, req)
if err, ok := err.(NotFoundError); !ok {
t.Fatalf("Expected not found error but got %v", err)
}
})
}
@ -757,11 +755,9 @@ func TestPreparedQuery_Explain(t *testing.T) {
body := bytes.NewBuffer(nil)
req, _ := http.NewRequest("GET", "/v1/query/not-there/explain", body)
resp := httptest.NewRecorder()
if _, err := a.srv.PreparedQuerySpecific(resp, req); err != nil {
t.Fatalf("err: %v", err)
}
if resp.Code != 404 {
t.Fatalf("bad code: %d", resp.Code)
_, err := a.srv.PreparedQuerySpecific(resp, req)
if err, ok := err.(NotFoundError); !ok {
t.Fatalf("Expected not found error but got %v", err)
}
})
@ -848,11 +844,9 @@ func TestPreparedQuery_Get(t *testing.T) {
body := bytes.NewBuffer(nil)
req, _ := http.NewRequest("GET", "/v1/query/f004177f-2c28-83b7-4229-eacc25fe55d1", body)
resp := httptest.NewRecorder()
if _, err := a.srv.PreparedQuerySpecific(resp, req); err != nil {
t.Fatalf("err: %v", err)
}
if resp.Code != 404 {
t.Fatalf("bad code: %d", resp.Code)
_, err := a.srv.PreparedQuerySpecific(resp, req)
if err, ok := err.(NotFoundError); !ok {
t.Fatalf("Expected not found error but got %v", err)
}
})
}

View File

@ -40,9 +40,7 @@ func (s *HTTPHandlers) SessionCreate(resp http.ResponseWriter, req *http.Request
// Handle optional request body
if req.ContentLength > 0 {
if err := s.rewordUnknownEnterpriseFieldError(lib.DecodeJSON(req.Body, &args.Session)); err != nil {
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprintf(resp, "Request decode failed: %v", err)
return nil, nil
return nil, BadRequestError{Reason: fmt.Sprintf("Request decode failed: %v", err)}
}
}
@ -77,9 +75,7 @@ func (s *HTTPHandlers) SessionDestroy(resp http.ResponseWriter, req *http.Reques
return nil, err
}
if args.Session.ID == "" {
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprint(resp, "Missing session")
return nil, nil
return nil, BadRequestError{Reason: "Missing session"}
}
var out string
@ -107,18 +103,14 @@ func (s *HTTPHandlers) SessionRenew(resp http.ResponseWriter, req *http.Request)
}
args.Session = args.SessionID
if args.SessionID == "" {
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprint(resp, "Missing session")
return nil, nil
return nil, BadRequestError{Reason: "Missing session"}
}
var out structs.IndexedSessions
if err := s.agent.RPC("Session.Renew", &args, &out); err != nil {
return nil, err
} else if out.Sessions == nil {
resp.WriteHeader(http.StatusNotFound)
fmt.Fprintf(resp, "Session id '%s' not found", args.SessionID)
return nil, nil
return nil, NotFoundError{Reason: fmt.Sprintf("Session id '%s' not found", args.SessionID)}
}
return out.Sessions, nil
@ -142,9 +134,7 @@ func (s *HTTPHandlers) SessionGet(resp http.ResponseWriter, req *http.Request) (
}
args.Session = args.SessionID
if args.SessionID == "" {
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprint(resp, "Missing session")
return nil, nil
return nil, BadRequestError{Reason: "Missing session"}
}
var out structs.IndexedSessions
@ -200,9 +190,7 @@ func (s *HTTPHandlers) SessionsForNode(resp http.ResponseWriter, req *http.Reque
return nil, err
}
if args.Node == "" {
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprint(resp, "Missing node name")
return nil, nil
return nil, BadRequestError{Reason: "Missing node name"}
}
var out structs.IndexedSessions

View File

@ -63,7 +63,7 @@ func isWrite(op api.KVOp) bool {
// internal RPC format. This returns a count of the number of write ops, and
// a boolean, that if false means an error response has been generated and
// processing should stop.
func (s *HTTPHandlers) convertOps(resp http.ResponseWriter, req *http.Request) (structs.TxnOps, int, bool) {
func (s *HTTPHandlers) convertOps(resp http.ResponseWriter, req *http.Request) (structs.TxnOps, int, error) {
// The TxnMaxReqLen limit and KVMaxValueSize limit both default to the
// suggested raft data size and can be configured independently. The
// TxnMaxReqLen is enforced on the cumulative size of the transaction,
@ -87,13 +87,10 @@ func (s *HTTPHandlers) convertOps(resp http.ResponseWriter, req *http.Request) (
// Check Content-Length first before decoding to return early
if req.ContentLength > maxTxnLen {
resp.WriteHeader(http.StatusRequestEntityTooLarge)
fmt.Fprintf(resp,
"Request body(%d bytes) too large, max size: %d bytes. See %s.",
req.ContentLength, maxTxnLen,
"https://www.consul.io/docs/agent/options.html#txn_max_req_len",
)
return nil, 0, false
return nil, 0, EntityTooLargeError{
Reason: fmt.Sprintf("Request body(%d bytes) too large, max size: %d bytes. See %s.",
req.ContentLength, maxTxnLen, "https://www.consul.io/docs/agent/options.html#txn_max_req_len"),
}
}
var ops api.TxnOps
@ -102,30 +99,24 @@ func (s *HTTPHandlers) convertOps(resp http.ResponseWriter, req *http.Request) (
if err.Error() == "http: request body too large" {
// The request size is also verified during decoding to double check
// if the Content-Length header was not set by the client.
resp.WriteHeader(http.StatusRequestEntityTooLarge)
fmt.Fprintf(resp,
"Request body too large, max size: %d bytes. See %s.",
maxTxnLen,
"https://www.consul.io/docs/agent/options.html#txn_max_req_len",
)
return nil, 0, EntityTooLargeError{
Reason: fmt.Sprintf("Request body too large, max size: %d bytes. See %s.",
maxTxnLen, "https://www.consul.io/docs/agent/options.html#txn_max_req_len"),
}
} else {
// Note the body is in API format, and not the RPC format. If we can't
// decode it, we will return a 400 since we don't have enough context to
// associate the error with a given operation.
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprintf(resp, "Failed to parse body: %v", err)
return nil, 0, BadRequestError{Reason: fmt.Sprintf("Failed to parse body: %v", err)}
}
return nil, 0, false
}
// Enforce a reasonable upper limit on the number of operations in a
// transaction in order to curb abuse.
if size := len(ops); size > maxTxnOps {
resp.WriteHeader(http.StatusRequestEntityTooLarge)
fmt.Fprintf(resp, "Transaction contains too many operations (%d > %d)",
size, maxTxnOps)
return nil, 0, false
return nil, 0, EntityTooLargeError{
Reason: fmt.Sprintf("Transaction contains too many operations (%d > %d)", size, maxTxnOps),
}
}
// Convert the KV API format into the RPC format. Note that fixupKVOps
@ -138,9 +129,9 @@ func (s *HTTPHandlers) convertOps(resp http.ResponseWriter, req *http.Request) (
case in.KV != nil:
size := len(in.KV.Value)
if int64(size) > kvMaxValueSize {
resp.WriteHeader(http.StatusRequestEntityTooLarge)
fmt.Fprintf(resp, "Value for key %q is too large (%d > %d bytes)", in.KV.Key, size, s.agent.config.KVMaxValueSize)
return nil, 0, false
return nil, 0, EntityTooLargeError{
Reason: fmt.Sprintf("Value for key %q is too large (%d > %d bytes)", in.KV.Key, size, s.agent.config.KVMaxValueSize),
}
}
verb := in.KV.Verb
@ -297,7 +288,7 @@ func (s *HTTPHandlers) convertOps(resp http.ResponseWriter, req *http.Request) (
}
}
return opsRPC, writes, true
return opsRPC, writes, nil
}
// Txn handles requests to apply multiple operations in a single, atomic
@ -306,9 +297,9 @@ func (s *HTTPHandlers) convertOps(resp http.ResponseWriter, req *http.Request) (
// and everything else will be routed through Raft like a normal write.
func (s *HTTPHandlers) Txn(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
// Convert the ops from the API format to the internal format.
ops, writes, ok := s.convertOps(resp, req)
if !ok {
return nil, nil
ops, writes, err := s.convertOps(resp, req)
if err != nil {
return nil, err
}
// Fast-path a transaction with only writes to the read-only endpoint,

View File

@ -30,13 +30,12 @@ func TestTxnEndpoint_Bad_JSON(t *testing.T) {
buf := bytes.NewBuffer([]byte("{"))
req, _ := http.NewRequest("PUT", "/v1/txn", buf)
resp := httptest.NewRecorder()
if _, err := a.srv.Txn(resp, req); err != nil {
t.Fatalf("err: %v", err)
_, err := a.srv.Txn(resp, req)
err, ok := err.(BadRequestError)
if !ok {
t.Fatalf("expected bad request error but got %v", err)
}
if resp.Code != 400 {
t.Fatalf("expected 400, got %d", resp.Code)
}
if !bytes.Contains(resp.Body.Bytes(), []byte("Failed to parse")) {
if !strings.Contains(err.Error(), "Failed to parse") {
t.Fatalf("expected conflicting args error")
}
}
@ -63,15 +62,13 @@ func TestTxnEndpoint_Bad_Size_Item(t *testing.T) {
`, value)))
req, _ := http.NewRequest("PUT", "/v1/txn", buf)
resp := httptest.NewRecorder()
if _, err := agent.srv.Txn(resp, req); err != nil {
_, err := agent.srv.Txn(resp, req)
if err, ok := err.(EntityTooLargeError); !ok && !wantPass {
t.Fatalf("expected too large error but got %v", err)
}
if err != nil && wantPass {
t.Fatalf("err: %v", err)
}
if resp.Code != 413 && !wantPass {
t.Fatalf("expected 413, got %d", resp.Code)
}
if resp.Code != 200 && wantPass {
t.Fatalf("expected 200, got %d", resp.Code)
}
}
t.Run("exceeds default limits", func(t *testing.T) {
@ -140,15 +137,13 @@ func TestTxnEndpoint_Bad_Size_Net(t *testing.T) {
`, value, value, value)))
req, _ := http.NewRequest("PUT", "/v1/txn", buf)
resp := httptest.NewRecorder()
if _, err := agent.srv.Txn(resp, req); err != nil {
_, err := agent.srv.Txn(resp, req)
if err, ok := err.(EntityTooLargeError); !ok && !wantPass {
t.Fatalf("expected too large error but got %v", err)
}
if err != nil && wantPass {
t.Fatalf("err: %v", err)
}
if resp.Code != 413 && !wantPass {
t.Fatalf("expected 413, got %d", resp.Code)
}
if resp.Code != 200 && wantPass {
t.Fatalf("expected 200, got %d", resp.Code)
}
}
t.Run("exceeds default limits", func(t *testing.T) {
@ -209,11 +204,9 @@ func TestTxnEndpoint_Bad_Size_Ops(t *testing.T) {
`, strings.Repeat(`{ "KV": { "Verb": "get", "Key": "key" } },`, 2*maxTxnOps))))
req, _ := http.NewRequest("PUT", "/v1/txn", buf)
resp := httptest.NewRecorder()
if _, err := a.srv.Txn(resp, req); err != nil {
t.Fatalf("err: %v", err)
}
if resp.Code != 413 {
t.Fatalf("expected 413, got %d", resp.Code)
_, err := a.srv.Txn(resp, req)
if err, ok := err.(EntityTooLargeError); !ok {
t.Fatalf("expected too large error but got %v", err)
}
}

View File

@ -139,9 +139,7 @@ func (s *HTTPHandlers) UINodeInfo(resp http.ResponseWriter, req *http.Request) (
return nil, err
}
if args.Node == "" {
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprint(resp, "Missing node name")
return nil, nil
return nil, BadRequestError{Reason: "Missing node name"}
}
// Make the RPC request
@ -255,9 +253,7 @@ func (s *HTTPHandlers) UIGatewayServicesNodes(resp http.ResponseWriter, req *htt
return nil, err
}
if args.ServiceName == "" {
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprint(resp, "Missing gateway name")
return nil, nil
return nil, BadRequestError{Reason: "Missing gateway name"}
}
// Make the RPC request
@ -301,16 +297,12 @@ func (s *HTTPHandlers) UIServiceTopology(resp http.ResponseWriter, req *http.Req
return nil, err
}
if args.ServiceName == "" {
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprint(resp, "Missing service name")
return nil, nil
return nil, BadRequestError{Reason: "Missing service name"}
}
kind, ok := req.URL.Query()["kind"]
if !ok {
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprint(resp, "Missing service kind")
return nil, nil
return nil, BadRequestError{Reason: "Missing service kind"}
}
args.ServiceKind = structs.ServiceKind(kind[0])
@ -318,9 +310,7 @@ func (s *HTTPHandlers) UIServiceTopology(resp http.ResponseWriter, req *http.Req
case structs.ServiceKindTypical, structs.ServiceKindIngressGateway:
// allowed
default:
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprintf(resp, "Unsupported service kind %q", args.ServiceKind)
return nil, nil
return nil, BadRequestError{Reason: fmt.Sprintf("Unsupported service kind %q", args.ServiceKind)}
}
// Make the RPC request
@ -584,9 +574,7 @@ func (s *HTTPHandlers) UIGatewayIntentions(resp http.ResponseWriter, req *http.R
return nil, err
}
if name == "" {
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprint(resp, "Missing gateway name")
return nil, nil
return nil, BadRequestError{Reason: "Missing gateway name"}
}
args.Match = &structs.IntentionQueryMatch{
Type: structs.IntentionMatchDestination,

View File

@ -1409,14 +1409,15 @@ func TestUIServiceTopology(t *testing.T) {
retry.Run(t, func(r *retry.R) {
resp := httptest.NewRecorder()
obj, err := a.srv.UIServiceTopology(resp, tc.httpReq)
assert.Nil(r, err)
if tc.wantErr != "" {
assert.NotNil(r, err)
assert.Nil(r, tc.want) // should not define a non-nil want
require.Equal(r, tc.wantErr, resp.Body.String())
require.Contains(r, err.Error(), tc.wantErr)
require.Nil(r, obj)
return
}
assert.Nil(r, err)
require.NoError(r, checkIndex(resp))
require.NotNil(r, obj)