diff --git a/agent/acl_endpoint.go b/agent/acl_endpoint.go index 9a40a6596..94a845d07 100644 --- a/agent/acl_endpoint.go +++ b/agent/acl_endpoint.go @@ -32,9 +32,6 @@ func (s *HTTPServer) ACLBootstrap(resp http.ResponseWriter, req *http.Request) ( if s.checkACLDisabled(resp, req) { return nil, nil } - if req.Method != "PUT" { - return nil, MethodNotAllowedError{req.Method, []string{"PUT"}} - } args := structs.DCSpecificRequest{ Datacenter: s.agent.config.ACLDatacenter, @@ -59,9 +56,6 @@ func (s *HTTPServer) ACLDestroy(resp http.ResponseWriter, req *http.Request) (in if s.checkACLDisabled(resp, req) { return nil, nil } - if req.Method != "PUT" { - return nil, MethodNotAllowedError{req.Method, []string{"PUT"}} - } args := structs.ACLRequest{ Datacenter: s.agent.config.ACLDatacenter, @@ -88,9 +82,6 @@ func (s *HTTPServer) ACLCreate(resp http.ResponseWriter, req *http.Request) (int if s.checkACLDisabled(resp, req) { return nil, nil } - if req.Method != "PUT" { - return nil, MethodNotAllowedError{req.Method, []string{"PUT"}} - } return s.aclSet(resp, req, false) } @@ -98,17 +89,10 @@ func (s *HTTPServer) ACLUpdate(resp http.ResponseWriter, req *http.Request) (int if s.checkACLDisabled(resp, req) { return nil, nil } - if req.Method != "PUT" { - return nil, MethodNotAllowedError{req.Method, []string{"PUT"}} - } return s.aclSet(resp, req, true) } func (s *HTTPServer) aclSet(resp http.ResponseWriter, req *http.Request, update bool) (interface{}, error) { - if req.Method != "PUT" { - return nil, MethodNotAllowedError{req.Method, []string{"PUT"}} - } - args := structs.ACLRequest{ Datacenter: s.agent.config.ACLDatacenter, Op: structs.ACLSet, @@ -149,9 +133,6 @@ func (s *HTTPServer) ACLClone(resp http.ResponseWriter, req *http.Request) (inte if s.checkACLDisabled(resp, req) { return nil, nil } - if req.Method != "PUT" { - return nil, MethodNotAllowedError{req.Method, []string{"PUT"}} - } args := structs.ACLSpecificRequest{ Datacenter: s.agent.config.ACLDatacenter, @@ -204,9 +185,6 @@ func (s *HTTPServer) ACLGet(resp http.ResponseWriter, req *http.Request) (interf if s.checkACLDisabled(resp, req) { return nil, nil } - if req.Method != "GET" { - return nil, MethodNotAllowedError{req.Method, []string{"GET"}} - } args := structs.ACLSpecificRequest{ Datacenter: s.agent.config.ACLDatacenter, @@ -241,9 +219,6 @@ func (s *HTTPServer) ACLList(resp http.ResponseWriter, req *http.Request) (inter if s.checkACLDisabled(resp, req) { return nil, nil } - if req.Method != "GET" { - return nil, MethodNotAllowedError{req.Method, []string{"GET"}} - } args := structs.DCSpecificRequest{ Datacenter: s.agent.config.ACLDatacenter, @@ -270,9 +245,6 @@ func (s *HTTPServer) ACLReplicationStatus(resp http.ResponseWriter, req *http.Re if s.checkACLDisabled(resp, req) { return nil, nil } - if req.Method != "GET" { - return nil, MethodNotAllowedError{req.Method, []string{"GET"}} - } // Note that we do not forward to the ACL DC here. This is a query for // any DC that's doing replication. diff --git a/agent/agent_endpoint.go b/agent/agent_endpoint.go index 5b4040f0c..20388a38e 100644 --- a/agent/agent_endpoint.go +++ b/agent/agent_endpoint.go @@ -31,10 +31,6 @@ type Self struct { } func (s *HTTPServer) AgentSelf(resp http.ResponseWriter, req *http.Request) (interface{}, error) { - if req.Method != "GET" { - return nil, MethodNotAllowedError{req.Method, []string{"GET"}} - } - // Fetch the ACL token, if any, and enforce agent policy. var token string s.parseToken(req, &token) @@ -80,10 +76,6 @@ func (s *HTTPServer) AgentSelf(resp http.ResponseWriter, req *http.Request) (int } func (s *HTTPServer) AgentMetrics(resp http.ResponseWriter, req *http.Request) (interface{}, error) { - if req.Method != "GET" { - return nil, MethodNotAllowedError{req.Method, []string{"GET"}} - } - // Fetch the ACL token, if any, and enforce agent policy. var token string s.parseToken(req, &token) @@ -99,10 +91,6 @@ func (s *HTTPServer) AgentMetrics(resp http.ResponseWriter, req *http.Request) ( } func (s *HTTPServer) AgentReload(resp http.ResponseWriter, req *http.Request) (interface{}, error) { - if req.Method != "PUT" { - return nil, MethodNotAllowedError{req.Method, []string{"PUT"}} - } - // Fetch the ACL token, if any, and enforce agent policy. var token string s.parseToken(req, &token) @@ -132,10 +120,6 @@ func (s *HTTPServer) AgentReload(resp http.ResponseWriter, req *http.Request) (i } func (s *HTTPServer) AgentServices(resp http.ResponseWriter, req *http.Request) (interface{}, error) { - if req.Method != "GET" { - return nil, MethodNotAllowedError{req.Method, []string{"GET"}} - } - // Fetch the ACL token, if any. var token string s.parseToken(req, &token) @@ -158,10 +142,6 @@ func (s *HTTPServer) AgentServices(resp http.ResponseWriter, req *http.Request) } func (s *HTTPServer) AgentChecks(resp http.ResponseWriter, req *http.Request) (interface{}, error) { - if req.Method != "GET" { - return nil, MethodNotAllowedError{req.Method, []string{"GET"}} - } - // Fetch the ACL token, if any. var token string s.parseToken(req, &token) @@ -184,10 +164,6 @@ func (s *HTTPServer) AgentChecks(resp http.ResponseWriter, req *http.Request) (i } func (s *HTTPServer) AgentMembers(resp http.ResponseWriter, req *http.Request) (interface{}, error) { - if req.Method != "GET" { - return nil, MethodNotAllowedError{req.Method, []string{"GET"}} - } - // Fetch the ACL token, if any. var token string s.parseToken(req, &token) @@ -233,10 +209,6 @@ func (s *HTTPServer) AgentMembers(resp http.ResponseWriter, req *http.Request) ( } func (s *HTTPServer) AgentJoin(resp http.ResponseWriter, req *http.Request) (interface{}, error) { - if req.Method != "PUT" { - return nil, MethodNotAllowedError{req.Method, []string{"PUT"}} - } - // Fetch the ACL token, if any, and enforce agent policy. var token string s.parseToken(req, &token) @@ -265,10 +237,6 @@ func (s *HTTPServer) AgentJoin(resp http.ResponseWriter, req *http.Request) (int } func (s *HTTPServer) AgentLeave(resp http.ResponseWriter, req *http.Request) (interface{}, error) { - if req.Method != "PUT" { - return nil, MethodNotAllowedError{req.Method, []string{"PUT"}} - } - // Fetch the ACL token, if any, and enforce agent policy. var token string s.parseToken(req, &token) @@ -287,10 +255,6 @@ func (s *HTTPServer) AgentLeave(resp http.ResponseWriter, req *http.Request) (in } func (s *HTTPServer) AgentForceLeave(resp http.ResponseWriter, req *http.Request) (interface{}, error) { - if req.Method != "PUT" { - return nil, MethodNotAllowedError{req.Method, []string{"PUT"}} - } - // Fetch the ACL token, if any, and enforce agent policy. var token string s.parseToken(req, &token) @@ -316,10 +280,6 @@ func (s *HTTPServer) syncChanges() { } func (s *HTTPServer) AgentRegisterCheck(resp http.ResponseWriter, req *http.Request) (interface{}, error) { - if req.Method != "PUT" { - return nil, MethodNotAllowedError{req.Method, []string{"PUT"}} - } - var args structs.CheckDefinition // Fixup the type decode of TTL or Interval. decodeCB := func(raw interface{}) error { @@ -372,10 +332,6 @@ func (s *HTTPServer) AgentRegisterCheck(resp http.ResponseWriter, req *http.Requ } func (s *HTTPServer) AgentDeregisterCheck(resp http.ResponseWriter, req *http.Request) (interface{}, error) { - if req.Method != "PUT" { - return nil, MethodNotAllowedError{req.Method, []string{"PUT"}} - } - checkID := types.CheckID(strings.TrimPrefix(req.URL.Path, "/v1/agent/check/deregister/")) // Get the provided token, if any, and vet against any ACL policies. @@ -393,10 +349,6 @@ func (s *HTTPServer) AgentDeregisterCheck(resp http.ResponseWriter, req *http.Re } func (s *HTTPServer) AgentCheckPass(resp http.ResponseWriter, req *http.Request) (interface{}, error) { - if req.Method != "PUT" { - return nil, MethodNotAllowedError{req.Method, []string{"PUT"}} - } - checkID := types.CheckID(strings.TrimPrefix(req.URL.Path, "/v1/agent/check/pass/")) note := req.URL.Query().Get("note") @@ -415,10 +367,6 @@ func (s *HTTPServer) AgentCheckPass(resp http.ResponseWriter, req *http.Request) } func (s *HTTPServer) AgentCheckWarn(resp http.ResponseWriter, req *http.Request) (interface{}, error) { - if req.Method != "PUT" { - return nil, MethodNotAllowedError{req.Method, []string{"PUT"}} - } - checkID := types.CheckID(strings.TrimPrefix(req.URL.Path, "/v1/agent/check/warn/")) note := req.URL.Query().Get("note") @@ -437,10 +385,6 @@ func (s *HTTPServer) AgentCheckWarn(resp http.ResponseWriter, req *http.Request) } func (s *HTTPServer) AgentCheckFail(resp http.ResponseWriter, req *http.Request) (interface{}, error) { - if req.Method != "PUT" { - return nil, MethodNotAllowedError{req.Method, []string{"PUT"}} - } - checkID := types.CheckID(strings.TrimPrefix(req.URL.Path, "/v1/agent/check/fail/")) note := req.URL.Query().Get("note") @@ -474,10 +418,6 @@ type checkUpdate struct { // AgentCheckUpdate is a PUT-based alternative to the GET-based Pass/Warn/Fail // APIs. func (s *HTTPServer) AgentCheckUpdate(resp http.ResponseWriter, req *http.Request) (interface{}, error) { - if req.Method != "PUT" { - return nil, MethodNotAllowedError{req.Method, []string{"PUT"}} - } - var update checkUpdate if err := decodeBody(req, &update, nil); err != nil { resp.WriteHeader(http.StatusBadRequest) @@ -518,10 +458,6 @@ func (s *HTTPServer) AgentCheckUpdate(resp http.ResponseWriter, req *http.Reques } func (s *HTTPServer) AgentRegisterService(resp http.ResponseWriter, req *http.Request) (interface{}, error) { - if req.Method != "PUT" { - return nil, MethodNotAllowedError{req.Method, []string{"PUT"}} - } - var args structs.ServiceDefinition // Fixup the type decode of TTL or Interval if a check if provided. decodeCB := func(raw interface{}) error { @@ -611,10 +547,6 @@ func (s *HTTPServer) AgentRegisterService(resp http.ResponseWriter, req *http.Re } func (s *HTTPServer) AgentDeregisterService(resp http.ResponseWriter, req *http.Request) (interface{}, error) { - if req.Method != "PUT" { - return nil, MethodNotAllowedError{req.Method, []string{"PUT"}} - } - serviceID := strings.TrimPrefix(req.URL.Path, "/v1/agent/service/deregister/") // Get the provided token, if any, and vet against any ACL policies. @@ -632,10 +564,6 @@ func (s *HTTPServer) AgentDeregisterService(resp http.ResponseWriter, req *http. } func (s *HTTPServer) AgentServiceMaintenance(resp http.ResponseWriter, req *http.Request) (interface{}, error) { - if req.Method != "PUT" { - return nil, MethodNotAllowedError{req.Method, []string{"PUT"}} - } - // Ensure we have a service ID serviceID := strings.TrimPrefix(req.URL.Path, "/v1/agent/service/maintenance/") if serviceID == "" { @@ -686,10 +614,6 @@ func (s *HTTPServer) AgentServiceMaintenance(resp http.ResponseWriter, req *http } func (s *HTTPServer) AgentNodeMaintenance(resp http.ResponseWriter, req *http.Request) (interface{}, error) { - if req.Method != "PUT" { - return nil, MethodNotAllowedError{req.Method, []string{"PUT"}} - } - // Ensure we have some action params := req.URL.Query() if _, ok := params["enable"]; !ok { @@ -727,10 +651,6 @@ func (s *HTTPServer) AgentNodeMaintenance(resp http.ResponseWriter, req *http.Re } func (s *HTTPServer) AgentMonitor(resp http.ResponseWriter, req *http.Request) (interface{}, error) { - if req.Method != "GET" { - return nil, MethodNotAllowedError{req.Method, []string{"GET"}} - } - // Fetch the ACL token, if any, and enforce agent policy. var token string s.parseToken(req, &token) @@ -821,9 +741,6 @@ func (s *HTTPServer) AgentToken(resp http.ResponseWriter, req *http.Request) (in if s.checkACLDisabled(resp, req) { return nil, nil } - if req.Method != "PUT" { - return nil, MethodNotAllowedError{req.Method, []string{"PUT"}} - } // Fetch the ACL token, if any, and enforce agent policy. var token string diff --git a/agent/catalog_endpoint.go b/agent/catalog_endpoint.go index ca781b36f..ac330f3b1 100644 --- a/agent/catalog_endpoint.go +++ b/agent/catalog_endpoint.go @@ -14,9 +14,6 @@ var durations = NewDurationFixer("interval", "timeout", "deregistercriticalservi func (s *HTTPServer) CatalogRegister(resp http.ResponseWriter, req *http.Request) (interface{}, error) { metrics.IncrCounterWithLabels([]string{"client", "api", "catalog_register"}, 1, []metrics.Label{{Name: "node", Value: s.nodeName()}}) - if req.Method != "PUT" { - return nil, MethodNotAllowedError{req.Method, []string{"PUT"}} - } var args structs.RegisterRequest if err := decodeBody(req, &args, durations.FixupDurations); err != nil { @@ -46,9 +43,6 @@ func (s *HTTPServer) CatalogRegister(resp http.ResponseWriter, req *http.Request func (s *HTTPServer) CatalogDeregister(resp http.ResponseWriter, req *http.Request) (interface{}, error) { metrics.IncrCounterWithLabels([]string{"client", "api", "catalog_deregister"}, 1, []metrics.Label{{Name: "node", Value: s.nodeName()}}) - if req.Method != "PUT" { - return nil, MethodNotAllowedError{req.Method, []string{"PUT"}} - } var args structs.DeregisterRequest if err := decodeBody(req, &args, nil); err != nil { @@ -78,9 +72,6 @@ func (s *HTTPServer) CatalogDeregister(resp http.ResponseWriter, req *http.Reque func (s *HTTPServer) CatalogDatacenters(resp http.ResponseWriter, req *http.Request) (interface{}, error) { metrics.IncrCounterWithLabels([]string{"client", "api", "catalog_datacenters"}, 1, []metrics.Label{{Name: "node", Value: s.nodeName()}}) - if req.Method != "GET" { - return nil, MethodNotAllowedError{req.Method, []string{"GET"}} - } var out []string if err := s.agent.RPC("Catalog.ListDatacenters", struct{}{}, &out); err != nil { @@ -96,9 +87,6 @@ func (s *HTTPServer) CatalogDatacenters(resp http.ResponseWriter, req *http.Requ func (s *HTTPServer) CatalogNodes(resp http.ResponseWriter, req *http.Request) (interface{}, error) { metrics.IncrCounterWithLabels([]string{"client", "api", "catalog_nodes"}, 1, []metrics.Label{{Name: "node", Value: s.nodeName()}}) - if req.Method != "GET" { - return nil, MethodNotAllowedError{req.Method, []string{"GET"}} - } // Setup the request args := structs.DCSpecificRequest{} @@ -129,9 +117,6 @@ func (s *HTTPServer) CatalogNodes(resp http.ResponseWriter, req *http.Request) ( func (s *HTTPServer) CatalogServices(resp http.ResponseWriter, req *http.Request) (interface{}, error) { metrics.IncrCounterWithLabels([]string{"client", "api", "catalog_services"}, 1, []metrics.Label{{Name: "node", Value: s.nodeName()}}) - if req.Method != "GET" { - return nil, MethodNotAllowedError{req.Method, []string{"GET"}} - } // Set default DC args := structs.DCSpecificRequest{} @@ -160,9 +145,6 @@ func (s *HTTPServer) CatalogServices(resp http.ResponseWriter, req *http.Request func (s *HTTPServer) CatalogServiceNodes(resp http.ResponseWriter, req *http.Request) (interface{}, error) { metrics.IncrCounterWithLabels([]string{"client", "api", "catalog_service_nodes"}, 1, []metrics.Label{{Name: "node", Value: s.nodeName()}}) - if req.Method != "GET" { - return nil, MethodNotAllowedError{req.Method, []string{"GET"}} - } // Set default DC args := structs.ServiceSpecificRequest{} @@ -216,9 +198,6 @@ func (s *HTTPServer) CatalogServiceNodes(resp http.ResponseWriter, req *http.Req func (s *HTTPServer) CatalogNodeServices(resp http.ResponseWriter, req *http.Request) (interface{}, error) { metrics.IncrCounterWithLabels([]string{"client", "api", "catalog_node_services"}, 1, []metrics.Label{{Name: "node", Value: s.nodeName()}}) - if req.Method != "GET" { - return nil, MethodNotAllowedError{req.Method, []string{"GET"}} - } // Set default Datacenter args := structs.NodeSpecificRequest{} diff --git a/agent/coordinate_endpoint.go b/agent/coordinate_endpoint.go index 6bbfdb015..f81f4753f 100644 --- a/agent/coordinate_endpoint.go +++ b/agent/coordinate_endpoint.go @@ -48,9 +48,6 @@ func (s *HTTPServer) CoordinateDatacenters(resp http.ResponseWriter, req *http.R if s.checkCoordinateDisabled(resp, req) { return nil, nil } - if req.Method != "GET" { - return nil, MethodNotAllowedError{req.Method, []string{"GET"}} - } var out []structs.DatacenterMap if err := s.agent.RPC("Coordinate.ListDatacenters", struct{}{}, &out); err != nil { @@ -80,9 +77,6 @@ func (s *HTTPServer) CoordinateNodes(resp http.ResponseWriter, req *http.Request if s.checkCoordinateDisabled(resp, req) { return nil, nil } - if req.Method != "GET" { - return nil, MethodNotAllowedError{req.Method, []string{"GET"}} - } args := structs.DCSpecificRequest{} if done := s.parse(resp, req, &args.Datacenter, &args.QueryOptions); done { @@ -105,9 +99,6 @@ func (s *HTTPServer) CoordinateNode(resp http.ResponseWriter, req *http.Request) if s.checkCoordinateDisabled(resp, req) { return nil, nil } - if req.Method != "GET" { - return nil, MethodNotAllowedError{req.Method, []string{"GET"}} - } node := strings.TrimPrefix(req.URL.Path, "/v1/coordinate/node/") args := structs.NodeSpecificRequest{Node: node} @@ -157,9 +148,6 @@ func (s *HTTPServer) CoordinateUpdate(resp http.ResponseWriter, req *http.Reques if s.checkCoordinateDisabled(resp, req) { return nil, nil } - if req.Method != "PUT" { - return nil, MethodNotAllowedError{req.Method, []string{"PUT"}} - } args := structs.CoordinateUpdateRequest{} if err := decodeBody(req, &args, nil); err != nil { diff --git a/agent/event_endpoint.go b/agent/event_endpoint.go index 132106cc1..b9fd0d1f4 100644 --- a/agent/event_endpoint.go +++ b/agent/event_endpoint.go @@ -20,9 +20,6 @@ const ( // EventFire is used to fire a new event func (s *HTTPServer) EventFire(resp http.ResponseWriter, req *http.Request) (interface{}, error) { - if req.Method != "PUT" { - return nil, MethodNotAllowedError{req.Method, []string{"PUT"}} - } // Get the datacenter var dc string @@ -77,10 +74,6 @@ func (s *HTTPServer) EventFire(resp http.ResponseWriter, req *http.Request) (int // EventList is used to retrieve the recent list of events func (s *HTTPServer) EventList(resp http.ResponseWriter, req *http.Request) (interface{}, error) { - if req.Method != "GET" { - return nil, MethodNotAllowedError{req.Method, []string{"GET"}} - } - // Parse the query options, since we simulate a blocking query var b structs.QueryOptions if parseWait(resp, req, &b) { diff --git a/agent/health_endpoint.go b/agent/health_endpoint.go index e3cf4539a..c04cc423b 100644 --- a/agent/health_endpoint.go +++ b/agent/health_endpoint.go @@ -11,10 +11,6 @@ import ( ) func (s *HTTPServer) HealthChecksInState(resp http.ResponseWriter, req *http.Request) (interface{}, error) { - if req.Method != "GET" { - return nil, MethodNotAllowedError{req.Method, []string{"GET"}} - } - // Set default DC args := structs.ChecksInStateRequest{} s.parseSource(req, &args.Source) @@ -53,10 +49,6 @@ func (s *HTTPServer) HealthChecksInState(resp http.ResponseWriter, req *http.Req } func (s *HTTPServer) HealthNodeChecks(resp http.ResponseWriter, req *http.Request) (interface{}, error) { - if req.Method != "GET" { - return nil, MethodNotAllowedError{req.Method, []string{"GET"}} - } - // Set default DC args := structs.NodeSpecificRequest{} if done := s.parse(resp, req, &args.Datacenter, &args.QueryOptions); done { @@ -93,10 +85,6 @@ func (s *HTTPServer) HealthNodeChecks(resp http.ResponseWriter, req *http.Reques } func (s *HTTPServer) HealthServiceChecks(resp http.ResponseWriter, req *http.Request) (interface{}, error) { - if req.Method != "GET" { - return nil, MethodNotAllowedError{req.Method, []string{"GET"}} - } - // Set default DC args := structs.ServiceSpecificRequest{} s.parseSource(req, &args.Source) @@ -135,10 +123,6 @@ func (s *HTTPServer) HealthServiceChecks(resp http.ResponseWriter, req *http.Req } func (s *HTTPServer) HealthServiceNodes(resp http.ResponseWriter, req *http.Request) (interface{}, error) { - if req.Method != "GET" { - return nil, MethodNotAllowedError{req.Method, []string{"GET"}} - } - // Set default DC args := structs.ServiceSpecificRequest{} s.parseSource(req, &args.Source) diff --git a/agent/http.go b/agent/http.go index 51bee5766..823606f2f 100644 --- a/agent/http.go +++ b/agent/http.go @@ -51,16 +51,21 @@ type unboundEndpoint func(s *HTTPServer, resp http.ResponseWriter, req *http.Req // endpoints is a map from URL pattern to unbound endpoint. var endpoints map[string]unboundEndpoint +// allowedMethods is a map from endpoint prefix to supported HTTP methods. +// An empty slice means an endpoint handles OPTIONS requests and MethodNotFound errors itself. +var allowedMethods map[string][]string + // registerEndpoint registers a new endpoint, which should be done at package // init() time. -func registerEndpoint(pattern string, fn unboundEndpoint) { +func registerEndpoint(pattern string, methods []string, fn unboundEndpoint) { if endpoints == nil { endpoints = make(map[string]unboundEndpoint) } - if endpoints[pattern] != nil { + if endpoints[pattern] != nil || allowedMethods[pattern] != nil { panic(fmt.Errorf("Pattern %q is already registered", pattern)) } endpoints[pattern] = fn + allowedMethods[pattern] = methods } // wrappedMux hangs on to the underlying mux for unit tests. @@ -112,10 +117,11 @@ func (s *HTTPServer) handler(enableDebug bool) http.Handler { mux.HandleFunc("/", s.Index) for pattern, fn := range endpoints { thisFn := fn + methods, _ := allowedMethods[pattern] bound := func(resp http.ResponseWriter, req *http.Request) (interface{}, error) { return thisFn(s, resp, req) } - handleFuncMetrics(pattern, s.wrap(bound)) + handleFuncMetrics(pattern, s.wrap(bound, methods)) } if enableDebug { handleFuncMetrics("/debug/pprof/", pprof.Index) @@ -168,7 +174,7 @@ var ( ) // wrap is used to wrap functions to make them more convenient -func (s *HTTPServer) wrap(handler endpoint) http.HandlerFunc { +func (s *HTTPServer) wrap(handler endpoint, methods []string) http.HandlerFunc { return func(resp http.ResponseWriter, req *http.Request) { setHeaders(resp, s.agent.config.HTTPResponseHeaders) setTranslateAddr(resp, s.agent.config.TranslateWANAddrs) @@ -205,6 +211,10 @@ func (s *HTTPServer) wrap(handler endpoint) http.HandlerFunc { return ok } + addAllowHeader := func(methods []string) { + resp.Header().Add("Allow", strings.Join(methods, ",")) + } + handleErr := func(err error) { s.agent.logger.Printf("[ERR] http: Request %s %v, error: %v from=%s", req.Method, logURL, err, req.RemoteAddr) switch { @@ -218,7 +228,7 @@ func (s *HTTPServer) wrap(handler endpoint) http.HandlerFunc { // MUST include an Allow header containing the list of valid // methods for the requested resource. // https://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html - resp.Header()["Allow"] = err.(MethodNotAllowedError).Allow + addAllowHeader(err.(MethodNotAllowedError).Allow) resp.WriteHeader(http.StatusMethodNotAllowed) // 405 fmt.Fprint(resp, err.Error()) default: @@ -227,12 +237,35 @@ func (s *HTTPServer) wrap(handler endpoint) http.HandlerFunc { } } - // Invoke the handler start := time.Now() defer func() { s.agent.logger.Printf("[DEBUG] http: Request %s %v (%v) from=%s", req.Method, logURL, time.Since(start), req.RemoteAddr) }() - obj, err := handler(resp, req) + + var obj interface{} + + // if this endpoint has declared methods, respond appropriately to OPTIONS requests. Otherwise let the endpoint handle that. + if req.Method == "OPTIONS" && len(methods) > 0 { + addAllowHeader(append([]string{"OPTIONS"}, methods...)) + return + } + + // if this endpoint has declared methods, check the request method. Otherwise let the endpoint handle that. + methodFound := len(methods) == 0 + for _, method := range methods { + if method == req.Method { + methodFound = true + break + } + } + + if !methodFound { + err = MethodNotAllowedError{req.Method, append([]string{"OPTIONS"}, methods...)} + } else { + // Invoke the handler + obj, err = handler(resp, req) + } + if err != nil { handleErr(err) return diff --git a/agent/http_oss.go b/agent/http_oss.go index 10e9abda0..4a2017d28 100644 --- a/agent/http_oss.go +++ b/agent/http_oss.go @@ -1,71 +1,75 @@ package agent func init() { - registerEndpoint("/v1/acl/bootstrap", (*HTTPServer).ACLBootstrap) - registerEndpoint("/v1/acl/create", (*HTTPServer).ACLCreate) - registerEndpoint("/v1/acl/update", (*HTTPServer).ACLUpdate) - registerEndpoint("/v1/acl/destroy/", (*HTTPServer).ACLDestroy) - registerEndpoint("/v1/acl/info/", (*HTTPServer).ACLGet) - registerEndpoint("/v1/acl/clone/", (*HTTPServer).ACLClone) - registerEndpoint("/v1/acl/list", (*HTTPServer).ACLList) - registerEndpoint("/v1/acl/replication", (*HTTPServer).ACLReplicationStatus) - registerEndpoint("/v1/agent/token/", (*HTTPServer).AgentToken) - registerEndpoint("/v1/agent/self", (*HTTPServer).AgentSelf) - registerEndpoint("/v1/agent/maintenance", (*HTTPServer).AgentNodeMaintenance) - registerEndpoint("/v1/agent/reload", (*HTTPServer).AgentReload) - registerEndpoint("/v1/agent/monitor", (*HTTPServer).AgentMonitor) - registerEndpoint("/v1/agent/metrics", (*HTTPServer).AgentMetrics) - registerEndpoint("/v1/agent/services", (*HTTPServer).AgentServices) - registerEndpoint("/v1/agent/checks", (*HTTPServer).AgentChecks) - registerEndpoint("/v1/agent/members", (*HTTPServer).AgentMembers) - registerEndpoint("/v1/agent/join/", (*HTTPServer).AgentJoin) - registerEndpoint("/v1/agent/leave", (*HTTPServer).AgentLeave) - registerEndpoint("/v1/agent/force-leave/", (*HTTPServer).AgentForceLeave) - registerEndpoint("/v1/agent/check/register", (*HTTPServer).AgentRegisterCheck) - registerEndpoint("/v1/agent/check/deregister/", (*HTTPServer).AgentDeregisterCheck) - registerEndpoint("/v1/agent/check/pass/", (*HTTPServer).AgentCheckPass) - registerEndpoint("/v1/agent/check/warn/", (*HTTPServer).AgentCheckWarn) - registerEndpoint("/v1/agent/check/fail/", (*HTTPServer).AgentCheckFail) - registerEndpoint("/v1/agent/check/update/", (*HTTPServer).AgentCheckUpdate) - registerEndpoint("/v1/agent/service/register", (*HTTPServer).AgentRegisterService) - registerEndpoint("/v1/agent/service/deregister/", (*HTTPServer).AgentDeregisterService) - registerEndpoint("/v1/agent/service/maintenance/", (*HTTPServer).AgentServiceMaintenance) - registerEndpoint("/v1/catalog/register", (*HTTPServer).CatalogRegister) - registerEndpoint("/v1/catalog/deregister", (*HTTPServer).CatalogDeregister) - registerEndpoint("/v1/catalog/datacenters", (*HTTPServer).CatalogDatacenters) - registerEndpoint("/v1/catalog/nodes", (*HTTPServer).CatalogNodes) - registerEndpoint("/v1/catalog/services", (*HTTPServer).CatalogServices) - registerEndpoint("/v1/catalog/service/", (*HTTPServer).CatalogServiceNodes) - registerEndpoint("/v1/catalog/node/", (*HTTPServer).CatalogNodeServices) - registerEndpoint("/v1/coordinate/datacenters", (*HTTPServer).CoordinateDatacenters) - registerEndpoint("/v1/coordinate/nodes", (*HTTPServer).CoordinateNodes) - registerEndpoint("/v1/coordinate/node/", (*HTTPServer).CoordinateNode) - registerEndpoint("/v1/coordinate/update", (*HTTPServer).CoordinateUpdate) - registerEndpoint("/v1/event/fire/", (*HTTPServer).EventFire) - registerEndpoint("/v1/event/list", (*HTTPServer).EventList) - registerEndpoint("/v1/health/node/", (*HTTPServer).HealthNodeChecks) - registerEndpoint("/v1/health/checks/", (*HTTPServer).HealthServiceChecks) - registerEndpoint("/v1/health/state/", (*HTTPServer).HealthChecksInState) - registerEndpoint("/v1/health/service/", (*HTTPServer).HealthServiceNodes) - registerEndpoint("/v1/internal/ui/nodes", (*HTTPServer).UINodes) - registerEndpoint("/v1/internal/ui/node/", (*HTTPServer).UINodeInfo) - registerEndpoint("/v1/internal/ui/services", (*HTTPServer).UIServices) - registerEndpoint("/v1/kv/", (*HTTPServer).KVSEndpoint) - registerEndpoint("/v1/operator/raft/configuration", (*HTTPServer).OperatorRaftConfiguration) - registerEndpoint("/v1/operator/raft/peer", (*HTTPServer).OperatorRaftPeer) - registerEndpoint("/v1/operator/keyring", (*HTTPServer).OperatorKeyringEndpoint) - registerEndpoint("/v1/operator/autopilot/configuration", (*HTTPServer).OperatorAutopilotConfiguration) - registerEndpoint("/v1/operator/autopilot/health", (*HTTPServer).OperatorServerHealth) - registerEndpoint("/v1/query", (*HTTPServer).PreparedQueryGeneral) - registerEndpoint("/v1/query/", (*HTTPServer).PreparedQuerySpecific) - registerEndpoint("/v1/session/create", (*HTTPServer).SessionCreate) - registerEndpoint("/v1/session/destroy/", (*HTTPServer).SessionDestroy) - registerEndpoint("/v1/session/renew/", (*HTTPServer).SessionRenew) - registerEndpoint("/v1/session/info/", (*HTTPServer).SessionGet) - registerEndpoint("/v1/session/node/", (*HTTPServer).SessionsForNode) - registerEndpoint("/v1/session/list", (*HTTPServer).SessionList) - registerEndpoint("/v1/status/leader", (*HTTPServer).StatusLeader) - registerEndpoint("/v1/status/peers", (*HTTPServer).StatusPeers) - registerEndpoint("/v1/snapshot", (*HTTPServer).Snapshot) - registerEndpoint("/v1/txn", (*HTTPServer).Txn) + allowedMethods = make(map[string][]string) + + registerEndpoint("/v1/acl/bootstrap", []string{"PUT"}, (*HTTPServer).ACLBootstrap) + registerEndpoint("/v1/acl/create", []string{"PUT"}, (*HTTPServer).ACLCreate) + registerEndpoint("/v1/acl/update", []string{"PUT"}, (*HTTPServer).ACLUpdate) + registerEndpoint("/v1/acl/destroy/", []string{"PUT"}, (*HTTPServer).ACLDestroy) + registerEndpoint("/v1/acl/info/", []string{"GET"}, (*HTTPServer).ACLGet) + registerEndpoint("/v1/acl/clone/", []string{"PUT"}, (*HTTPServer).ACLClone) + registerEndpoint("/v1/acl/list", []string{"GET"}, (*HTTPServer).ACLList) + registerEndpoint("/v1/acl/replication", []string{"GET"}, (*HTTPServer).ACLReplicationStatus) + registerEndpoint("/v1/agent/token/", []string{"PUT"}, (*HTTPServer).AgentToken) + registerEndpoint("/v1/agent/self", []string{"GET"}, (*HTTPServer).AgentSelf) + registerEndpoint("/v1/agent/maintenance", []string{"PUT"}, (*HTTPServer).AgentNodeMaintenance) + registerEndpoint("/v1/agent/reload", []string{"PUT"}, (*HTTPServer).AgentReload) + registerEndpoint("/v1/agent/monitor", []string{"GET"}, (*HTTPServer).AgentMonitor) + registerEndpoint("/v1/agent/metrics", []string{"GET"}, (*HTTPServer).AgentMetrics) + registerEndpoint("/v1/agent/services", []string{"GET"}, (*HTTPServer).AgentServices) + registerEndpoint("/v1/agent/checks", []string{"GET"}, (*HTTPServer).AgentChecks) + registerEndpoint("/v1/agent/members", []string{"GET"}, (*HTTPServer).AgentMembers) + registerEndpoint("/v1/agent/join/", []string{"PUT"}, (*HTTPServer).AgentJoin) + registerEndpoint("/v1/agent/leave", []string{"PUT"}, (*HTTPServer).AgentLeave) + registerEndpoint("/v1/agent/force-leave/", []string{"PUT"}, (*HTTPServer).AgentForceLeave) + registerEndpoint("/v1/agent/check/register", []string{"PUT"}, (*HTTPServer).AgentRegisterCheck) + registerEndpoint("/v1/agent/check/deregister/", []string{"PUT"}, (*HTTPServer).AgentDeregisterCheck) + registerEndpoint("/v1/agent/check/pass/", []string{"PUT"}, (*HTTPServer).AgentCheckPass) + registerEndpoint("/v1/agent/check/warn/", []string{"PUT"}, (*HTTPServer).AgentCheckWarn) + registerEndpoint("/v1/agent/check/fail/", []string{"PUT"}, (*HTTPServer).AgentCheckFail) + registerEndpoint("/v1/agent/check/update/", []string{"PUT"}, (*HTTPServer).AgentCheckUpdate) + registerEndpoint("/v1/agent/service/register", []string{"PUT"}, (*HTTPServer).AgentRegisterService) + registerEndpoint("/v1/agent/service/deregister/", []string{"PUT"}, (*HTTPServer).AgentDeregisterService) + registerEndpoint("/v1/agent/service/maintenance/", []string{"PUT"}, (*HTTPServer).AgentServiceMaintenance) + registerEndpoint("/v1/catalog/register", []string{"PUT"}, (*HTTPServer).CatalogRegister) + registerEndpoint("/v1/catalog/deregister", []string{"PUT"}, (*HTTPServer).CatalogDeregister) + registerEndpoint("/v1/catalog/datacenters", []string{"GET"}, (*HTTPServer).CatalogDatacenters) + registerEndpoint("/v1/catalog/nodes", []string{"GET"}, (*HTTPServer).CatalogNodes) + registerEndpoint("/v1/catalog/services", []string{"GET"}, (*HTTPServer).CatalogServices) + registerEndpoint("/v1/catalog/service/", []string{"GET"}, (*HTTPServer).CatalogServiceNodes) + registerEndpoint("/v1/catalog/node/", []string{"GET"}, (*HTTPServer).CatalogNodeServices) + registerEndpoint("/v1/coordinate/datacenters", []string{"GET"}, (*HTTPServer).CoordinateDatacenters) + registerEndpoint("/v1/coordinate/nodes", []string{"GET"}, (*HTTPServer).CoordinateNodes) + registerEndpoint("/v1/coordinate/node/", []string{"GET"}, (*HTTPServer).CoordinateNode) + registerEndpoint("/v1/coordinate/update", []string{"PUT"}, (*HTTPServer).CoordinateUpdate) + registerEndpoint("/v1/event/fire/", []string{"PUT"}, (*HTTPServer).EventFire) + registerEndpoint("/v1/event/list", []string{"GET"}, (*HTTPServer).EventList) + registerEndpoint("/v1/health/node/", []string{"GET"}, (*HTTPServer).HealthNodeChecks) + registerEndpoint("/v1/health/checks/", []string{"GET"}, (*HTTPServer).HealthServiceChecks) + registerEndpoint("/v1/health/state/", []string{"GET"}, (*HTTPServer).HealthChecksInState) + registerEndpoint("/v1/health/service/", []string{"GET"}, (*HTTPServer).HealthServiceNodes) + registerEndpoint("/v1/internal/ui/nodes", []string{"GET"}, (*HTTPServer).UINodes) + registerEndpoint("/v1/internal/ui/node/", []string{"GET"}, (*HTTPServer).UINodeInfo) + registerEndpoint("/v1/internal/ui/services", []string{"GET"}, (*HTTPServer).UIServices) + registerEndpoint("/v1/kv/", []string{"GET", "PUT", "DELETE"}, (*HTTPServer).KVSEndpoint) + registerEndpoint("/v1/operator/raft/configuration", []string{"GET"}, (*HTTPServer).OperatorRaftConfiguration) + registerEndpoint("/v1/operator/raft/peer", []string{"DELETE"}, (*HTTPServer).OperatorRaftPeer) + registerEndpoint("/v1/operator/keyring", []string{"GET", "POST", "PUT", "DELETE"}, (*HTTPServer).OperatorKeyringEndpoint) + registerEndpoint("/v1/operator/autopilot/configuration", []string{"GET", "PUT"}, (*HTTPServer).OperatorAutopilotConfiguration) + registerEndpoint("/v1/operator/autopilot/health", []string{"GET"}, (*HTTPServer).OperatorServerHealth) + registerEndpoint("/v1/query", []string{"GET", "POST"}, (*HTTPServer).PreparedQueryGeneral) + // specific prepared query endpoints have more complex rules for allowed methods, so + // the prefix is registered with no methods. + registerEndpoint("/v1/query/", []string{}, (*HTTPServer).PreparedQuerySpecific) + registerEndpoint("/v1/session/create", []string{"PUT"}, (*HTTPServer).SessionCreate) + registerEndpoint("/v1/session/destroy/", []string{"PUT"}, (*HTTPServer).SessionDestroy) + registerEndpoint("/v1/session/renew/", []string{"PUT"}, (*HTTPServer).SessionRenew) + registerEndpoint("/v1/session/info/", []string{"GET"}, (*HTTPServer).SessionGet) + registerEndpoint("/v1/session/node/", []string{"GET"}, (*HTTPServer).SessionsForNode) + registerEndpoint("/v1/session/list", []string{"GET"}, (*HTTPServer).SessionList) + registerEndpoint("/v1/status/leader", []string{"GET"}, (*HTTPServer).StatusLeader) + registerEndpoint("/v1/status/peers", []string{"GET"}, (*HTTPServer).StatusPeers) + registerEndpoint("/v1/snapshot", []string{"GET", "PUT"}, (*HTTPServer).Snapshot) + registerEndpoint("/v1/txn", []string{"PUT"}, (*HTTPServer).Txn) } diff --git a/agent/http_oss_test.go b/agent/http_oss_test.go index d2e04419b..af5c328ac 100644 --- a/agent/http_oss_test.go +++ b/agent/http_oss_test.go @@ -3,111 +3,110 @@ package agent import ( "fmt" "net/http" + "net/http/httptest" "strings" "testing" "github.com/hashicorp/consul/logger" ) +// extra endpoints that should be tested, and their allowed methods +var extraTestEndpoints = map[string][]string{ + "/v1/query": []string{"GET", "POST"}, + "/v1/query/": []string{"GET", "PUT", "DELETE"}, + "/v1/query/xxx/execute": []string{"GET"}, + "/v1/query/xxx/explain": []string{"GET"}, +} + +// certain endpoints can't be unit tested. +func includePathInTest(path string) bool { + var hanging = path == "/v1/status/peers" || path == "/v1/agent/monitor" || path == "/v1/agent/reload" // these hang + var custom = path == "/v1/query" || path == "/v1/query/" // these have custom logic + return !(hanging || custom) +} + func TestHTTPAPI_MethodNotAllowed_OSS(t *testing.T) { - tests := []struct { - methods, uri string - }{ - {"PUT", "/v1/acl/bootstrap"}, - {"PUT", "/v1/acl/create"}, - {"PUT", "/v1/acl/update"}, - {"PUT", "/v1/acl/destroy/"}, - {"GET", "/v1/acl/info/"}, - {"PUT", "/v1/acl/clone/"}, - {"GET", "/v1/acl/list"}, - {"GET", "/v1/acl/replication"}, - {"PUT", "/v1/agent/token/"}, - {"GET", "/v1/agent/self"}, - {"GET", "/v1/agent/members"}, - {"PUT", "/v1/agent/check/deregister/"}, - {"PUT", "/v1/agent/check/fail/"}, - {"PUT", "/v1/agent/check/pass/"}, - {"PUT", "/v1/agent/check/register"}, - {"PUT", "/v1/agent/check/update/"}, - {"PUT", "/v1/agent/check/warn/"}, - {"GET", "/v1/agent/checks"}, - {"PUT", "/v1/agent/force-leave/"}, - {"PUT", "/v1/agent/join/"}, - {"PUT", "/v1/agent/leave"}, - {"PUT", "/v1/agent/maintenance"}, - {"GET", "/v1/agent/metrics"}, - // {"GET", "/v1/agent/monitor"}, // requires LogWriter. Hangs if LogWriter is provided - {"PUT", "/v1/agent/reload"}, - {"PUT", "/v1/agent/service/deregister/"}, - {"PUT", "/v1/agent/service/maintenance/"}, - {"PUT", "/v1/agent/service/register"}, - {"GET", "/v1/agent/services"}, - {"GET", "/v1/catalog/datacenters"}, - {"PUT", "/v1/catalog/deregister"}, - {"GET", "/v1/catalog/node/"}, - {"GET", "/v1/catalog/nodes"}, - {"PUT", "/v1/catalog/register"}, - {"GET", "/v1/catalog/service/"}, - {"GET", "/v1/catalog/services"}, - {"GET", "/v1/coordinate/datacenters"}, - {"GET", "/v1/coordinate/nodes"}, - {"GET", "/v1/coordinate/node/"}, - {"PUT", "/v1/event/fire/"}, - {"GET", "/v1/event/list"}, - {"GET", "/v1/health/checks/"}, - {"GET", "/v1/health/node/"}, - {"GET", "/v1/health/service/"}, - {"GET", "/v1/health/state/"}, - {"GET", "/v1/internal/ui/node/"}, - {"GET", "/v1/internal/ui/nodes"}, - {"GET", "/v1/internal/ui/services"}, - {"GET PUT DELETE", "/v1/kv/"}, - {"GET PUT", "/v1/operator/autopilot/configuration"}, - {"GET", "/v1/operator/autopilot/health"}, - {"GET POST PUT DELETE", "/v1/operator/keyring"}, - {"GET", "/v1/operator/raft/configuration"}, - {"DELETE", "/v1/operator/raft/peer"}, - {"GET POST", "/v1/query"}, - {"GET PUT DELETE", "/v1/query/"}, - {"GET", "/v1/query/xxx/execute"}, - {"GET", "/v1/query/xxx/explain"}, - {"PUT", "/v1/session/create"}, - {"PUT", "/v1/session/destroy/"}, - {"GET", "/v1/session/info/"}, - {"GET", "/v1/session/list"}, - {"GET", "/v1/session/node/"}, - {"PUT", "/v1/session/renew/"}, - {"GET PUT", "/v1/snapshot"}, - {"GET", "/v1/status/leader"}, - // {"GET", "/v1/status/peers"},// hangs - {"PUT", "/v1/txn"}, - } a := NewTestAgent(t.Name(), `acl_datacenter = "dc1"`) a.Agent.LogWriter = logger.NewLogWriter(512) defer a.Shutdown() - all := []string{"GET", "PUT", "POST", "DELETE", "HEAD"} + all := []string{"GET", "PUT", "POST", "DELETE", "HEAD", "OPTIONS"} client := http.Client{} - for _, tt := range tests { - for _, m := range all { - t.Run(m+" "+tt.uri, func(t *testing.T) { - uri := fmt.Sprintf("http://%s%s", a.HTTPAddr(), tt.uri) - req, _ := http.NewRequest(m, uri, nil) - resp, err := client.Do(req) - if err != nil { - t.Fatal("client.Do failed: ", err) - } + testMethodNotAllowed := func(method string, path string, allowedMethods []string) { + t.Run(method+" "+path, func(t *testing.T) { + uri := fmt.Sprintf("http://%s%s", a.HTTPAddr(), path) + req, _ := http.NewRequest(method, uri, nil) + resp, err := client.Do(req) + if err != nil { + t.Fatal("client.Do failed: ", err) + } - allowed := strings.Contains(tt.methods, m) - if allowed && resp.StatusCode == http.StatusMethodNotAllowed { - t.Fatalf("method allowed: got status code %d want any other code", resp.StatusCode) + allowed := method == "OPTIONS" + for _, allowedMethod := range allowedMethods { + if allowedMethod == method { + allowed = true + break } - if !allowed && resp.StatusCode != http.StatusMethodNotAllowed { - t.Fatalf("method not allowed: got status code %d want %d", resp.StatusCode, http.StatusMethodNotAllowed) - } - }) + } + + if allowed && resp.StatusCode == http.StatusMethodNotAllowed { + t.Fatalf("method allowed: got status code %d want any other code", resp.StatusCode) + } + if !allowed && resp.StatusCode != http.StatusMethodNotAllowed { + t.Fatalf("method not allowed: got status code %d want %d", resp.StatusCode, http.StatusMethodNotAllowed) + } + }) + } + + for path, methods := range extraTestEndpoints { + for _, method := range all { + testMethodNotAllowed(method, path, methods) + } + } + + for path, methods := range allowedMethods { + if includePathInTest(path) { + for _, method := range all { + testMethodNotAllowed(method, path, methods) + } + } + } +} + +func TestHTTPAPI_OptionMethod_OSS(t *testing.T) { + a := NewTestAgent(t.Name(), `acl_datacenter = "dc1"`) + a.Agent.LogWriter = logger.NewLogWriter(512) + defer a.Shutdown() + + testOptionMethod := func(path string, methods []string) { + t.Run("OPTIONS "+path, func(t *testing.T) { + uri := fmt.Sprintf("http://%s%s", a.HTTPAddr(), path) + req, _ := http.NewRequest("OPTIONS", uri, nil) + resp := httptest.NewRecorder() + a.srv.Handler.ServeHTTP(resp, req) + allMethods := append([]string{"OPTIONS"}, methods...) + + if resp.Code != http.StatusOK { + t.Fatalf("options request: got status code %d want %d", resp.Code, http.StatusOK) + } + + optionsStr := resp.Header().Get("Allow") + if optionsStr == "" { + t.Fatalf("options request: got empty 'Allow' header") + } else if optionsStr != strings.Join(allMethods, ",") { + t.Fatalf("options request: got 'Allow' header value of %s want %s", optionsStr, allMethods) + } + }) + } + + for path, methods := range extraTestEndpoints { + testOptionMethod(path, methods) + } + for path, methods := range allowedMethods { + if includePathInTest(path) { + testOptionMethod(path, methods) } } } diff --git a/agent/http_test.go b/agent/http_test.go index 12ce2a813..23c90014e 100644 --- a/agent/http_test.go +++ b/agent/http_test.go @@ -298,7 +298,7 @@ func TestHTTPAPI_BlockEndpoints(t *testing.T) { { req, _ := http.NewRequest("GET", "/v1/agent/self", nil) resp := httptest.NewRecorder() - a.srv.wrap(handler)(resp, req) + a.srv.wrap(handler, []string{"GET"})(resp, req) if got, want := resp.Code, http.StatusForbidden; got != want { t.Fatalf("bad response code got %d want %d", got, want) } @@ -308,7 +308,7 @@ func TestHTTPAPI_BlockEndpoints(t *testing.T) { { req, _ := http.NewRequest("GET", "/v1/agent/checks", nil) resp := httptest.NewRecorder() - a.srv.wrap(handler)(resp, req) + a.srv.wrap(handler, []string{"GET"})(resp, req) if got, want := resp.Code, http.StatusOK; got != want { t.Fatalf("bad response code got %d want %d", got, want) } @@ -340,7 +340,7 @@ func TestHTTPAPI_TranslateAddrHeader(t *testing.T) { } req, _ := http.NewRequest("GET", "/v1/agent/self", nil) - a.srv.wrap(handler)(resp, req) + a.srv.wrap(handler, []string{"GET"})(resp, req) translate := resp.Header().Get("X-Consul-Translate-Addresses") if translate != "" { @@ -361,7 +361,7 @@ func TestHTTPAPI_TranslateAddrHeader(t *testing.T) { } req, _ := http.NewRequest("GET", "/v1/agent/self", nil) - a.srv.wrap(handler)(resp, req) + a.srv.wrap(handler, []string{"GET"})(resp, req) translate := resp.Header().Get("X-Consul-Translate-Addresses") if translate != "true" { @@ -388,7 +388,7 @@ func TestHTTPAPIResponseHeaders(t *testing.T) { } req, _ := http.NewRequest("GET", "/v1/agent/self", nil) - a.srv.wrap(handler)(resp, req) + a.srv.wrap(handler, []string{"GET"})(resp, req) origin := resp.Header().Get("Access-Control-Allow-Origin") if origin != "*" { @@ -413,7 +413,7 @@ func TestContentTypeIsJSON(t *testing.T) { } req, _ := http.NewRequest("GET", "/v1/kv/key", nil) - a.srv.wrap(handler)(resp, req) + a.srv.wrap(handler, []string{"GET"})(resp, req) contentType := resp.Header().Get("Content-Type") @@ -467,7 +467,7 @@ func TestHTTP_wrap_obfuscateLog(t *testing.T) { t.Run(url, func(t *testing.T) { resp := httptest.NewRecorder() req, _ := http.NewRequest("GET", url, nil) - a.srv.wrap(handler)(resp, req) + a.srv.wrap(handler, []string{"GET"})(resp, req) if got := buf.String(); !strings.Contains(got, want) { t.Fatalf("got %s want %s", got, want) @@ -499,7 +499,7 @@ func testPrettyPrint(pretty string, t *testing.T) { urlStr := "/v1/kv/key?" + pretty req, _ := http.NewRequest("GET", urlStr, nil) - a.srv.wrap(handler)(resp, req) + a.srv.wrap(handler, []string{"GET"})(resp, req) expected, _ := json.MarshalIndent(r, "", " ") expected = append(expected, "\n"...) diff --git a/agent/kvs_endpoint.go b/agent/kvs_endpoint.go index 0e1a1cd87..e95570f50 100644 --- a/agent/kvs_endpoint.go +++ b/agent/kvs_endpoint.go @@ -48,8 +48,7 @@ func (s *HTTPServer) KVSEndpoint(resp http.ResponseWriter, req *http.Request) (i case "DELETE": return s.KVSDelete(resp, req, &args) default: - resp.WriteHeader(http.StatusMethodNotAllowed) - return nil, nil + return nil, MethodNotAllowedError{req.Method, []string{"GET", "PUT", "DELETE"}} } } diff --git a/agent/operator_endpoint.go b/agent/operator_endpoint.go index 4cf580e20..34aa90d93 100644 --- a/agent/operator_endpoint.go +++ b/agent/operator_endpoint.go @@ -16,10 +16,6 @@ import ( // OperatorRaftConfiguration is used to inspect the current Raft configuration. // This supports the stale query mode in case the cluster doesn't have a leader. func (s *HTTPServer) OperatorRaftConfiguration(resp http.ResponseWriter, req *http.Request) (interface{}, error) { - if req.Method != "GET" { - return nil, MethodNotAllowedError{req.Method, []string{"GET"}} - } - var args structs.DCSpecificRequest if done := s.parse(resp, req, &args.Datacenter, &args.QueryOptions); done { return nil, nil @@ -36,10 +32,6 @@ func (s *HTTPServer) OperatorRaftConfiguration(resp http.ResponseWriter, req *ht // OperatorRaftPeer supports actions on Raft peers. Currently we only support // removing peers by address. func (s *HTTPServer) OperatorRaftPeer(resp http.ResponseWriter, req *http.Request) (interface{}, error) { - if req.Method != "DELETE" { - return nil, MethodNotAllowedError{req.Method, []string{"DELETE"}} - } - var args structs.RaftRemovePeerRequest s.parseDC(req, &args.Datacenter) s.parseToken(req, &args.Token) @@ -268,10 +260,6 @@ func (s *HTTPServer) OperatorAutopilotConfiguration(resp http.ResponseWriter, re // OperatorServerHealth is used to get the health of the servers in the local DC func (s *HTTPServer) OperatorServerHealth(resp http.ResponseWriter, req *http.Request) (interface{}, error) { - if req.Method != "GET" { - return nil, MethodNotAllowedError{req.Method, []string{"GET"}} - } - var args structs.DCSpecificRequest if done := s.parse(resp, req, &args.Datacenter, &args.QueryOptions); done { return nil, nil diff --git a/agent/prepared_query_endpoint.go b/agent/prepared_query_endpoint.go index 0be0ea3ed..532cf70f4 100644 --- a/agent/prepared_query_endpoint.go +++ b/agent/prepared_query_endpoint.go @@ -228,9 +228,31 @@ func (s *HTTPServer) preparedQueryDelete(id string, resp http.ResponseWriter, re return nil, nil } +// PreparedQuerySpecificOptions handles OPTIONS requests to prepared query endpoints. +func (s *HTTPServer) preparedQuerySpecificOptions(resp http.ResponseWriter, req *http.Request) interface{} { + path := req.URL.Path + switch { + case strings.HasSuffix(path, "/execute"): + resp.Header().Add("Allow", strings.Join([]string{"OPTIONS", "GET"}, ",")) + return resp + + case strings.HasSuffix(path, "/explain"): + resp.Header().Add("Allow", strings.Join([]string{"OPTIONS", "GET"}, ",")) + return resp + + default: + resp.Header().Add("Allow", strings.Join([]string{"OPTIONS", "GET", "PUT", "DELETE"}, ",")) + return resp + } +} + // PreparedQuerySpecific handles all the prepared query requests specific to a // particular query. func (s *HTTPServer) PreparedQuerySpecific(resp http.ResponseWriter, req *http.Request) (interface{}, error) { + if req.Method == "OPTIONS" { + return s.preparedQuerySpecificOptions(resp, req), nil + } + path := req.URL.Path id := strings.TrimPrefix(path, "/v1/query/") diff --git a/agent/session_endpoint.go b/agent/session_endpoint.go index 85d392abb..1a6cc930d 100644 --- a/agent/session_endpoint.go +++ b/agent/session_endpoint.go @@ -27,10 +27,6 @@ type sessionCreateResponse struct { // SessionCreate is used to create a new session func (s *HTTPServer) SessionCreate(resp http.ResponseWriter, req *http.Request) (interface{}, error) { - if req.Method != "PUT" { - return nil, MethodNotAllowedError{req.Method, []string{"PUT"}} - } - // Default the session to our node + serf check + release session // invalidate behavior. args := structs.SessionRequest{ @@ -136,10 +132,6 @@ func FixupChecks(raw interface{}, s *structs.Session) error { // SessionDestroy is used to destroy an existing session func (s *HTTPServer) SessionDestroy(resp http.ResponseWriter, req *http.Request) (interface{}, error) { - if req.Method != "PUT" { - return nil, MethodNotAllowedError{req.Method, []string{"PUT"}} - } - args := structs.SessionRequest{ Op: structs.SessionDestroy, } @@ -163,10 +155,6 @@ func (s *HTTPServer) SessionDestroy(resp http.ResponseWriter, req *http.Request) // SessionRenew is used to renew the TTL on an existing TTL session func (s *HTTPServer) SessionRenew(resp http.ResponseWriter, req *http.Request) (interface{}, error) { - if req.Method != "PUT" { - return nil, MethodNotAllowedError{req.Method, []string{"PUT"}} - } - args := structs.SessionSpecificRequest{} if done := s.parse(resp, req, &args.Datacenter, &args.QueryOptions); done { return nil, nil @@ -194,10 +182,6 @@ func (s *HTTPServer) SessionRenew(resp http.ResponseWriter, req *http.Request) ( // SessionGet is used to get info for a particular session func (s *HTTPServer) SessionGet(resp http.ResponseWriter, req *http.Request) (interface{}, error) { - if req.Method != "GET" { - return nil, MethodNotAllowedError{req.Method, []string{"GET"}} - } - args := structs.SessionSpecificRequest{} if done := s.parse(resp, req, &args.Datacenter, &args.QueryOptions); done { return nil, nil @@ -226,10 +210,6 @@ func (s *HTTPServer) SessionGet(resp http.ResponseWriter, req *http.Request) (in // SessionList is used to list all the sessions func (s *HTTPServer) SessionList(resp http.ResponseWriter, req *http.Request) (interface{}, error) { - if req.Method != "GET" { - return nil, MethodNotAllowedError{req.Method, []string{"GET"}} - } - args := structs.DCSpecificRequest{} if done := s.parse(resp, req, &args.Datacenter, &args.QueryOptions); done { return nil, nil @@ -250,10 +230,6 @@ func (s *HTTPServer) SessionList(resp http.ResponseWriter, req *http.Request) (i // SessionsForNode returns all the nodes belonging to a node func (s *HTTPServer) SessionsForNode(resp http.ResponseWriter, req *http.Request) (interface{}, error) { - if req.Method != "GET" { - return nil, MethodNotAllowedError{req.Method, []string{"GET"}} - } - args := structs.NodeSpecificRequest{} if done := s.parse(resp, req, &args.Datacenter, &args.QueryOptions); done { return nil, nil diff --git a/agent/status_endpoint.go b/agent/status_endpoint.go index 2e3748267..75275800f 100644 --- a/agent/status_endpoint.go +++ b/agent/status_endpoint.go @@ -5,10 +5,6 @@ import ( ) func (s *HTTPServer) StatusLeader(resp http.ResponseWriter, req *http.Request) (interface{}, error) { - if req.Method != "GET" { - return nil, MethodNotAllowedError{req.Method, []string{"GET"}} - } - var out string if err := s.agent.RPC("Status.Leader", struct{}{}, &out); err != nil { return nil, err @@ -17,10 +13,6 @@ func (s *HTTPServer) StatusLeader(resp http.ResponseWriter, req *http.Request) ( } func (s *HTTPServer) StatusPeers(resp http.ResponseWriter, req *http.Request) (interface{}, error) { - if req.Method != "GET" { - return nil, MethodNotAllowedError{req.Method, []string{"GET"}} - } - var out []string if err := s.agent.RPC("Status.Peers", struct{}{}, &out); err != nil { return nil, err diff --git a/agent/txn_endpoint.go b/agent/txn_endpoint.go index 2d0b30574..4870b0327 100644 --- a/agent/txn_endpoint.go +++ b/agent/txn_endpoint.go @@ -172,10 +172,6 @@ func (s *HTTPServer) convertOps(resp http.ResponseWriter, req *http.Request) (st // pathed to an endpoint that supports consistency modes (but not blocking), // and everything else will be routed through Raft like a normal write. func (s *HTTPServer) Txn(resp http.ResponseWriter, req *http.Request) (interface{}, error) { - if req.Method != "PUT" { - return nil, MethodNotAllowedError{req.Method, []string{"PUT"}} - } - // Convert the ops from the API format to the internal format. ops, writes, ok := s.convertOps(resp, req) if !ok { diff --git a/agent/ui_endpoint.go b/agent/ui_endpoint.go index e43935131..e13f3ebb6 100644 --- a/agent/ui_endpoint.go +++ b/agent/ui_endpoint.go @@ -22,10 +22,6 @@ type ServiceSummary struct { // UINodes is used to list the nodes in a given datacenter. We return a // NodeDump which provides overview information for all the nodes func (s *HTTPServer) UINodes(resp http.ResponseWriter, req *http.Request) (interface{}, error) { - if req.Method != "GET" { - return nil, MethodNotAllowedError{req.Method, []string{"GET"}} - } - // Parse arguments args := structs.DCSpecificRequest{} if done := s.parse(resp, req, &args.Datacenter, &args.QueryOptions); done { @@ -63,10 +59,6 @@ RPC: // UINodeInfo is used to get info on a single node in a given datacenter. We return a // NodeInfo which provides overview information for the node func (s *HTTPServer) UINodeInfo(resp http.ResponseWriter, req *http.Request) (interface{}, error) { - if req.Method != "GET" { - return nil, MethodNotAllowedError{req.Method, []string{"GET"}} - } - // Parse arguments args := structs.NodeSpecificRequest{} if done := s.parse(resp, req, &args.Datacenter, &args.QueryOptions); done { @@ -113,10 +105,6 @@ RPC: // UIServices is used to list the services in a given datacenter. We return a // ServiceSummary which provides overview information for the service func (s *HTTPServer) UIServices(resp http.ResponseWriter, req *http.Request) (interface{}, error) { - if req.Method != "GET" { - return nil, MethodNotAllowedError{req.Method, []string{"GET"}} - } - // Parse arguments args := structs.DCSpecificRequest{} if done := s.parse(resp, req, &args.Datacenter, &args.QueryOptions); done {