Ensure consistency with error-handling across all handlers. (#11599)

This commit is contained in:
Mathew Estafanous 2022-01-05 12:11:03 -05:00 committed by GitHub
parent 45b5742973
commit dc18933cc2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 107 additions and 140 deletions

View File

@ -16,23 +16,22 @@ type aclBootstrapResponse struct {
structs.ACLToken
}
var aclDisabled = UnauthorizedError{Reason: "ACL support disabled"}
// checkACLDisabled will return a standard response if ACLs are disabled. This
// returns true if they are disabled and we should not continue.
func (s *HTTPHandlers) checkACLDisabled(resp http.ResponseWriter, _req *http.Request) bool {
func (s *HTTPHandlers) checkACLDisabled() bool {
if s.agent.config.ACLsEnabled {
return false
}
resp.WriteHeader(http.StatusUnauthorized)
fmt.Fprint(resp, "ACL support disabled")
return true
}
// ACLBootstrap is used to perform a one-time ACL bootstrap operation on
// a cluster to get the first management token.
func (s *HTTPHandlers) ACLBootstrap(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
if s.checkACLDisabled(resp, req) {
return nil, nil
if s.checkACLDisabled() {
return nil, aclDisabled
}
args := structs.DCSpecificRequest{
@ -42,9 +41,7 @@ func (s *HTTPHandlers) ACLBootstrap(resp http.ResponseWriter, req *http.Request)
err := s.agent.RPC("ACL.BootstrapTokens", &args, &out)
if err != nil {
if strings.Contains(err.Error(), structs.ACLBootstrapNotAllowedErr.Error()) {
resp.WriteHeader(http.StatusForbidden)
fmt.Fprint(resp, acl.PermissionDeniedError{Cause: err.Error()}.Error())
return nil, nil
return nil, acl.PermissionDeniedError{Cause: err.Error()}
} else {
return nil, err
}
@ -53,8 +50,8 @@ func (s *HTTPHandlers) ACLBootstrap(resp http.ResponseWriter, req *http.Request)
}
func (s *HTTPHandlers) ACLReplicationStatus(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
if s.checkACLDisabled(resp, req) {
return nil, nil
if s.checkACLDisabled() {
return nil, aclDisabled
}
// Note that we do not forward to the ACL DC here. This is a query for
@ -74,8 +71,8 @@ func (s *HTTPHandlers) ACLReplicationStatus(resp http.ResponseWriter, req *http.
}
func (s *HTTPHandlers) ACLPolicyList(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
if s.checkACLDisabled(resp, req) {
return nil, nil
if s.checkACLDisabled() {
return nil, aclDisabled
}
var args structs.ACLPolicyListRequest
@ -105,8 +102,8 @@ func (s *HTTPHandlers) ACLPolicyList(resp http.ResponseWriter, req *http.Request
}
func (s *HTTPHandlers) ACLPolicyCRUD(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
if s.checkACLDisabled(resp, req) {
return nil, nil
if s.checkACLDisabled() {
return nil, aclDisabled
}
var fn func(resp http.ResponseWriter, req *http.Request, policyID string) (interface{}, error)
@ -166,8 +163,8 @@ func (s *HTTPHandlers) ACLPolicyRead(resp http.ResponseWriter, req *http.Request
}
func (s *HTTPHandlers) ACLPolicyReadByName(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
if s.checkACLDisabled(resp, req) {
return nil, nil
if s.checkACLDisabled() {
return nil, aclDisabled
}
policyName := strings.TrimPrefix(req.URL.Path, "/v1/acl/policy/name/")
@ -183,8 +180,8 @@ func (s *HTTPHandlers) ACLPolicyReadByID(resp http.ResponseWriter, req *http.Req
}
func (s *HTTPHandlers) ACLPolicyCreate(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
if s.checkACLDisabled(resp, req) {
return nil, nil
if s.checkACLDisabled() {
return nil, aclDisabled
}
return s.aclPolicyWriteInternal(resp, req, "", true)
@ -248,8 +245,8 @@ func (s *HTTPHandlers) ACLPolicyDelete(resp http.ResponseWriter, req *http.Reque
}
func (s *HTTPHandlers) ACLTokenList(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
if s.checkACLDisabled(resp, req) {
return nil, nil
if s.checkACLDisabled() {
return nil, aclDisabled
}
args := &structs.ACLTokenListRequest{
@ -285,8 +282,8 @@ func (s *HTTPHandlers) ACLTokenList(resp http.ResponseWriter, req *http.Request)
}
func (s *HTTPHandlers) ACLTokenCRUD(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
if s.checkACLDisabled(resp, req) {
return nil, nil
if s.checkACLDisabled() {
return nil, aclDisabled
}
var fn func(resp http.ResponseWriter, req *http.Request, tokenID string) (interface{}, error)
@ -318,8 +315,8 @@ func (s *HTTPHandlers) ACLTokenCRUD(resp http.ResponseWriter, req *http.Request)
}
func (s *HTTPHandlers) ACLTokenSelf(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
if s.checkACLDisabled(resp, req) {
return nil, nil
if s.checkACLDisabled() {
return nil, aclDisabled
}
args := structs.ACLTokenGetRequest{
@ -351,8 +348,8 @@ func (s *HTTPHandlers) ACLTokenSelf(resp http.ResponseWriter, req *http.Request)
}
func (s *HTTPHandlers) ACLTokenCreate(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
if s.checkACLDisabled(resp, req) {
return nil, nil
if s.checkACLDisabled() {
return nil, aclDisabled
}
return s.aclTokenSetInternal(req, "", true)
@ -442,8 +439,8 @@ func (s *HTTPHandlers) ACLTokenDelete(resp http.ResponseWriter, req *http.Reques
}
func (s *HTTPHandlers) ACLTokenClone(resp http.ResponseWriter, req *http.Request, tokenID string) (interface{}, error) {
if s.checkACLDisabled(resp, req) {
return nil, nil
if s.checkACLDisabled() {
return nil, aclDisabled
}
args := structs.ACLTokenSetRequest{
@ -471,8 +468,8 @@ func (s *HTTPHandlers) ACLTokenClone(resp http.ResponseWriter, req *http.Request
}
func (s *HTTPHandlers) ACLRoleList(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
if s.checkACLDisabled(resp, req) {
return nil, nil
if s.checkACLDisabled() {
return nil, aclDisabled
}
var args structs.ACLRoleListRequest
@ -504,8 +501,8 @@ func (s *HTTPHandlers) ACLRoleList(resp http.ResponseWriter, req *http.Request)
}
func (s *HTTPHandlers) ACLRoleCRUD(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
if s.checkACLDisabled(resp, req) {
return nil, nil
if s.checkACLDisabled() {
return nil, aclDisabled
}
var fn func(resp http.ResponseWriter, req *http.Request, roleID string) (interface{}, error)
@ -533,8 +530,8 @@ func (s *HTTPHandlers) ACLRoleCRUD(resp http.ResponseWriter, req *http.Request)
}
func (s *HTTPHandlers) ACLRoleReadByName(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
if s.checkACLDisabled(resp, req) {
return nil, nil
if s.checkACLDisabled() {
return nil, aclDisabled
}
roleName := strings.TrimPrefix(req.URL.Path, "/v1/acl/role/name/")
@ -581,8 +578,8 @@ func (s *HTTPHandlers) ACLRoleRead(resp http.ResponseWriter, req *http.Request,
}
func (s *HTTPHandlers) ACLRoleCreate(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
if s.checkACLDisabled(resp, req) {
return nil, nil
if s.checkACLDisabled() {
return nil, aclDisabled
}
return s.ACLRoleWrite(resp, req, "")
@ -634,8 +631,8 @@ func (s *HTTPHandlers) ACLRoleDelete(resp http.ResponseWriter, req *http.Request
}
func (s *HTTPHandlers) ACLBindingRuleList(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
if s.checkACLDisabled(resp, req) {
return nil, nil
if s.checkACLDisabled() {
return nil, aclDisabled
}
var args structs.ACLBindingRuleListRequest
@ -668,8 +665,8 @@ func (s *HTTPHandlers) ACLBindingRuleList(resp http.ResponseWriter, req *http.Re
}
func (s *HTTPHandlers) ACLBindingRuleCRUD(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
if s.checkACLDisabled(resp, req) {
return nil, nil
if s.checkACLDisabled() {
return nil, aclDisabled
}
var fn func(resp http.ResponseWriter, req *http.Request, bindingRuleID string) (interface{}, error)
@ -728,8 +725,8 @@ func (s *HTTPHandlers) ACLBindingRuleRead(resp http.ResponseWriter, req *http.Re
}
func (s *HTTPHandlers) ACLBindingRuleCreate(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
if s.checkACLDisabled(resp, req) {
return nil, nil
if s.checkACLDisabled() {
return nil, aclDisabled
}
return s.ACLBindingRuleWrite(resp, req, "")
@ -781,8 +778,8 @@ func (s *HTTPHandlers) ACLBindingRuleDelete(resp http.ResponseWriter, req *http.
}
func (s *HTTPHandlers) ACLAuthMethodList(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
if s.checkACLDisabled(resp, req) {
return nil, nil
if s.checkACLDisabled() {
return nil, aclDisabled
}
var args structs.ACLAuthMethodListRequest
@ -812,8 +809,8 @@ func (s *HTTPHandlers) ACLAuthMethodList(resp http.ResponseWriter, req *http.Req
}
func (s *HTTPHandlers) ACLAuthMethodCRUD(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
if s.checkACLDisabled(resp, req) {
return nil, nil
if s.checkACLDisabled() {
return nil, aclDisabled
}
var fn func(resp http.ResponseWriter, req *http.Request, methodName string) (interface{}, error)
@ -872,8 +869,8 @@ func (s *HTTPHandlers) ACLAuthMethodRead(resp http.ResponseWriter, req *http.Req
}
func (s *HTTPHandlers) ACLAuthMethodCreate(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
if s.checkACLDisabled(resp, req) {
return nil, nil
if s.checkACLDisabled() {
return nil, aclDisabled
}
return s.ACLAuthMethodWrite(resp, req, "")
@ -928,8 +925,8 @@ func (s *HTTPHandlers) ACLAuthMethodDelete(resp http.ResponseWriter, req *http.R
}
func (s *HTTPHandlers) ACLLogin(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
if s.checkACLDisabled(resp, req) {
return nil, nil
if s.checkACLDisabled() {
return nil, aclDisabled
}
args := &structs.ACLLoginRequest{
@ -954,8 +951,8 @@ func (s *HTTPHandlers) ACLLogin(resp http.ResponseWriter, req *http.Request) (in
}
func (s *HTTPHandlers) ACLLogout(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
if s.checkACLDisabled(resp, req) {
return nil, nil
if s.checkACLDisabled() {
return nil, aclDisabled
}
args := structs.ACLLogoutRequest{
@ -1014,8 +1011,8 @@ func (s *HTTPHandlers) ACLAuthorize(resp http.ResponseWriter, req *http.Request)
// policy.
const maxRequests = 64
if s.checkACLDisabled(resp, req) {
return nil, nil
if s.checkACLDisabled() {
return nil, aclDisabled
}
request := structs.RemoteACLAuthorizationRequest{

View File

@ -70,10 +70,8 @@ func TestACL_Disabled_Response(t *testing.T) {
req, _ := http.NewRequest("PUT", "/should/not/care", nil)
resp := httptest.NewRecorder()
obj, err := tt.fn(resp, req)
require.NoError(t, err)
require.Nil(t, obj)
require.Equal(t, http.StatusUnauthorized, resp.Code)
require.Contains(t, resp.Body.String(), "ACL support disabled")
require.ErrorIs(t, err, UnauthorizedError{Reason: "ACL support disabled"})
})
}
}
@ -119,9 +117,6 @@ func TestACL_Bootstrap(t *testing.T) {
if tt.token && err != nil {
t.Fatalf("err: %v", err)
}
if got, want := resp.Code, tt.code; got != want {
t.Fatalf("got %d want %d", got, want)
}
if tt.token {
wrap, ok := out.(*aclBootstrapResponse)
if !ok {

View File

@ -155,9 +155,11 @@ func (s *HTTPHandlers) AgentMetrics(resp http.ResponseWriter, req *http.Request)
}
if enablePrometheusOutput(req) {
if s.agent.config.Telemetry.PrometheusOpts.Expiration < 1 {
resp.WriteHeader(http.StatusUnsupportedMediaType)
fmt.Fprint(resp, "Prometheus is not enabled since its retention time is not positive")
return nil, nil
return nil, CodeWithPayloadError{
StatusCode: http.StatusUnsupportedMediaType,
Reason: "Prometheus is not enabled since its retention time is not positive",
ContentType: "text/plain",
}
}
handlerOptions := promhttp.HandlerOpts{
ErrorLog: s.agent.logger.StandardLogger(&hclog.StandardLoggerOptions{
@ -423,11 +425,7 @@ func (s *HTTPHandlers) AgentService(resp http.ResponseWriter, req *http.Request)
svcState := s.agent.State.ServiceState(sid)
if svcState == nil {
resp.WriteHeader(http.StatusNotFound)
fmt.Fprintf(resp,
"Unknown service ID %q. Ensure that the service ID is passed, not the service name.",
sid.String())
return "", nil, nil
return "", nil, NotFoundError{Reason: fmt.Sprintf("unknown service ID: %s", sid.String())}
}
svc := svcState.Service
@ -557,9 +555,7 @@ func (s *HTTPHandlers) AgentMembers(resp http.ResponseWriter, req *http.Request)
// key are ok, otherwise the argument doesn't apply to
// the WAN.
default:
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprint(resp, "Cannot provide a segment with wan=true")
return nil, nil
return nil, BadRequestError{Reason: "Cannot provide a segment with wan=true"}
}
}
@ -735,16 +731,16 @@ func (s *HTTPHandlers) AgentRegisterCheck(resp http.ResponseWriter, req *http.Re
}
if err := decodeBody(req.Body, &args); err != nil {
return nil, BadRequestError{fmt.Sprintf("Request decode failed: %v", err)}
return nil, BadRequestError{Reason: fmt.Sprintf("Request decode failed: %v", err)}
}
// Verify the check has a name.
if args.Name == "" {
return nil, BadRequestError{"Missing check name"}
return nil, BadRequestError{Reason: "Missing check name"}
}
if args.Status != "" && !structs.ValidStatus(args.Status) {
return nil, BadRequestError{"Bad check status"}
return nil, BadRequestError{Reason: "Bad check status"}
}
authz, err := s.agent.delegate.ResolveTokenAndDefaultMeta(token, &args.EnterpriseMeta, nil)
@ -763,15 +759,15 @@ func (s *HTTPHandlers) AgentRegisterCheck(resp http.ResponseWriter, req *http.Re
chkType := args.CheckType()
err = chkType.Validate()
if err != nil {
return nil, BadRequestError{fmt.Sprintf("Invalid check: %v", err)}
return nil, BadRequestError{Reason: fmt.Sprintf("Invalid check: %v", err)}
}
// Store the type of check based on the definition
health.Type = chkType.Type()
if health.ServiceID != "" {
cid := health.CompoundServiceID()
// fixup the service name so that vetCheckRegister requires the right ACLs
cid := health.CompoundServiceID()
service := s.agent.State.Service(cid)
if service != nil {
health.ServiceName = service.Service
@ -881,9 +877,7 @@ type checkUpdate struct {
func (s *HTTPHandlers) AgentCheckUpdate(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
var update checkUpdate
if err := decodeBody(req.Body, &update); err != nil {
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprintf(resp, "Request decode failed: %v", err)
return nil, nil
return nil, BadRequestError{Reason: fmt.Sprintf("Request decode failed: %v", err)}
}
switch update.Status {
@ -891,9 +885,7 @@ func (s *HTTPHandlers) AgentCheckUpdate(resp http.ResponseWriter, req *http.Requ
case api.HealthWarning:
case api.HealthCritical:
default:
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprintf(resp, "Invalid check status: '%s'", update.Status)
return nil, nil
return nil, BadRequestError{Reason: fmt.Sprintf("Invalid check status: '%s'", update.Status)}
}
ID, err := getPathSuffixUnescaped(req.URL.Path, "/v1/agent/check/update/")
@ -1121,24 +1113,18 @@ func (s *HTTPHandlers) AgentRegisterService(resp http.ResponseWriter, req *http.
}
if err := decodeBody(req.Body, &args); err != nil {
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprintf(resp, "Request decode failed: %v", err)
return nil, nil
return nil, BadRequestError{Reason: fmt.Sprintf("Request decode failed: %v", err)}
}
// Verify the service has a name.
if args.Name == "" {
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprint(resp, "Missing service name")
return nil, nil
return nil, BadRequestError{Reason: "Missing service name"}
}
// Check the service address here and in the catalog RPC endpoint
// since service registration isn't synchronous.
if ipaddr.IsAny(args.Address) {
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprintf(resp, "Invalid service address")
return nil, nil
return nil, BadRequestError{Reason: "Invalid service address"}
}
var token string
@ -1157,37 +1143,27 @@ func (s *HTTPHandlers) AgentRegisterService(resp http.ResponseWriter, req *http.
ns := args.NodeService()
if ns.Weights != nil {
if err := structs.ValidateWeights(ns.Weights); err != nil {
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprint(resp, fmt.Errorf("Invalid Weights: %v", err))
return nil, nil
return nil, BadRequestError{Reason: fmt.Sprintf("Invalid Weights: %v", err)}
}
}
if err := structs.ValidateServiceMetadata(ns.Kind, ns.Meta, false); err != nil {
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprint(resp, fmt.Errorf("Invalid Service Meta: %v", err))
return nil, nil
return nil, BadRequestError{Reason: fmt.Sprintf("Invalid Service Meta: %v", err)}
}
// Run validation. This is the same validation that would happen on
// the catalog endpoint so it helps ensure the sync will work properly.
if err := ns.Validate(); err != nil {
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprint(resp, err.Error())
return nil, nil
return nil, BadRequestError{Reason: fmt.Sprintf("Validation failed: %v", err.Error())}
}
// Verify the check type.
chkTypes, err := args.CheckTypes()
if err != nil {
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprint(resp, fmt.Errorf("Invalid check: %v", err))
return nil, nil
return nil, BadRequestError{Reason: fmt.Sprintf("Invalid check: %v", err)}
}
for _, check := range chkTypes {
if check.Status != "" && !structs.ValidStatus(check.Status) {
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprint(resp, "Status for checks must 'passing', 'warning', 'critical'")
return nil, nil
return nil, BadRequestError{Reason: "Status for checks must 'passing', 'warning', 'critical'"}
}
}
@ -1221,9 +1197,7 @@ func (s *HTTPHandlers) AgentRegisterService(resp http.ResponseWriter, req *http.
}
if sidecar != nil {
if err := sidecar.Validate(); err != nil {
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprint(resp, err.Error())
return nil, nil
return nil, BadRequestError{Reason: fmt.Sprintf("Failed Validation: %v", err.Error())}
}
// Make sure we are allowed to register the sidecar using the token
// specified (might be specific to sidecar or the same one as the overall
@ -1324,25 +1298,19 @@ func (s *HTTPHandlers) AgentServiceMaintenance(resp http.ResponseWriter, req *ht
sid := structs.NewServiceID(serviceID, nil)
if sid.ID == "" {
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprint(resp, "Missing service ID")
return nil, nil
return nil, BadRequestError{Reason: "Missing service ID"}
}
// Ensure we have some action
params := req.URL.Query()
if _, ok := params["enable"]; !ok {
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprint(resp, "Missing value for enable")
return nil, nil
return nil, BadRequestError{Reason: "Missing value for enable"}
}
raw := params.Get("enable")
enable, err := strconv.ParseBool(raw)
if err != nil {
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprintf(resp, "Invalid value for enable: %q", raw)
return nil, nil
return nil, BadRequestError{Reason: fmt.Sprintf("Invalid value for enable: %q", raw)}
}
// Get the provided token, if any, and vet against any ACL policies.
@ -1371,15 +1339,11 @@ func (s *HTTPHandlers) AgentServiceMaintenance(resp http.ResponseWriter, req *ht
if enable {
reason := params.Get("reason")
if err = s.agent.EnableServiceMaintenance(sid, reason, token); err != nil {
resp.WriteHeader(http.StatusNotFound)
fmt.Fprint(resp, err.Error())
return nil, nil
return nil, NotFoundError{Reason: err.Error()}
}
} else {
if err = s.agent.DisableServiceMaintenance(sid); err != nil {
resp.WriteHeader(http.StatusNotFound)
fmt.Fprint(resp, err.Error())
return nil, nil
return nil, NotFoundError{Reason: err.Error()}
}
}
s.syncChanges()
@ -1390,17 +1354,13 @@ func (s *HTTPHandlers) AgentNodeMaintenance(resp http.ResponseWriter, req *http.
// Ensure we have some action
params := req.URL.Query()
if _, ok := params["enable"]; !ok {
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprint(resp, "Missing value for enable")
return nil, nil
return nil, BadRequestError{Reason: "Missing value for enable"}
}
raw := params.Get("enable")
enable, err := strconv.ParseBool(raw)
if err != nil {
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprintf(resp, "Invalid value for enable: %q", raw)
return nil, nil
return nil, BadRequestError{Reason: fmt.Sprintf("Invalid value for enable: %q", raw)}
}
// Get the provided token, if any, and vet against any ACL policies.
@ -1507,8 +1467,8 @@ func (s *HTTPHandlers) AgentMonitor(resp http.ResponseWriter, req *http.Request)
}
func (s *HTTPHandlers) AgentToken(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
if s.checkACLDisabled(resp, req) {
return nil, nil
if s.checkACLDisabled() {
return nil, UnauthorizedError{Reason: "ACL support disabled"}
}
// Fetch the ACL token, if any, and enforce agent policy.

View File

@ -660,9 +660,9 @@ func TestAgent_Service(t *testing.T) {
wantResp: &updatedResponse,
},
{
name: "err: non-existent proxy",
url: "/v1/agent/service/nope",
wantCode: 404,
name: "err: non-existent proxy",
url: "/v1/agent/service/nope",
wantErr: "unknown service ID: nope",
},
{
name: "err: bad ACL for service",
@ -3784,9 +3784,6 @@ func testAgent_RegisterService_InvalidAddress(t *testing.T, extraHCL string) {
if got, want := resp.Code, 400; got != want {
t.Fatalf("got code %d want %d", got, want)
}
if got, want := resp.Body.String(), "Invalid service address"; got != want {
t.Fatalf("got body %q want %q", got, want)
}
})
}
}

View File

@ -69,6 +69,15 @@ func (e NotFoundError) Error() string {
return e.Reason
}
// UnauthorizedError should be returned by a handler when the request lacks valid authorization.
type UnauthorizedError struct {
Reason string
}
func (e UnauthorizedError) Error() string {
return e.Reason
}
// CodeWithPayloadError allow returning non HTTP 200
// Error codes while not returning PlainText payload
type CodeWithPayloadError struct {
@ -241,7 +250,8 @@ func (s *HTTPHandlers) handler(enableDebug bool) http.Handler {
// If enableDebug is not set, and ACLs are disabled, write
// an unauthorized response
if !enableDebug && s.checkACLDisabled(resp, req) {
if !enableDebug && s.checkACLDisabled() {
resp.WriteHeader(http.StatusUnauthorized)
return
}
@ -423,6 +433,11 @@ func (s *HTTPHandlers) wrap(handler endpoint, methods []string) http.HandlerFunc
return ok
}
isUnauthorized := func(err error) bool {
_, ok := err.(UnauthorizedError)
return ok
}
isTooManyRequests := func(err error) bool {
// Sadness net/rpc can't do nice typed errors so this is all we got
return err.Error() == consul.ErrRateLimited.Error()
@ -467,6 +482,9 @@ func (s *HTTPHandlers) wrap(handler endpoint, methods []string) http.HandlerFunc
case isNotFound(err):
resp.WriteHeader(http.StatusNotFound)
fmt.Fprint(resp, err.Error())
case isUnauthorized(err):
resp.WriteHeader(http.StatusUnauthorized)
fmt.Fprint(resp, err.Error())
case isTooManyRequests(err):
resp.WriteHeader(http.StatusTooManyRequests)
fmt.Fprint(resp, err.Error())