From 4ced73875b8b1d7de6bf2ffd7acd746edd52bc03 Mon Sep 17 00:00:00 2001 From: Drew Bailey <2614075+drewbailey@users.noreply.github.com> Date: Fri, 13 Dec 2019 15:06:19 -0500 Subject: [PATCH] leave acl checking to rpc endpoints fix test expectation test wrapNonJSON --- command/agent/agent_endpoint.go | 22 ++---------- command/agent/agent_endpoint_test.go | 2 +- command/agent/http_test.go | 54 ++++++++++++++++++++++++++++ nomad/client_agent_endpoint.go | 2 -- nomad/structs/structs.go | 2 ++ 5 files changed, 59 insertions(+), 23 deletions(-) diff --git a/command/agent/agent_endpoint.go b/command/agent/agent_endpoint.go index dfa390196..1b96c1372 100644 --- a/command/agent/agent_endpoint.go +++ b/command/agent/agent_endpoint.go @@ -361,26 +361,8 @@ func (s *HTTPServer) agentPprof(reqType profile.ReqType, resp http.ResponseWrite var secret string s.parseToken(req, &secret) - // Check agent write permissions - aclObj, err := s.agent.Server().ResolveToken(secret) - if err != nil { - return nil, err - } else if aclObj != nil && !aclObj.AllowAgentWrite() { - return nil, structs.ErrPermissionDenied - } - - enableDebug := s.agent.GetConfig().EnableDebug - - // ACLs not enabled - if aclObj == nil { - // If debug is not explicitly enabled - // return unauthorized - if enableDebug == false { - return nil, structs.ErrPermissionDenied - } - } - // Parse profile duration, default to 1 second + var err error secondsParam := req.URL.Query().Get("seconds") var seconds int if secondsParam == "" { @@ -388,7 +370,7 @@ func (s *HTTPServer) agentPprof(reqType profile.ReqType, resp http.ResponseWrite } else { seconds, err = strconv.Atoi(secondsParam) if err != nil { - errStr := fmt.Sprintf("Error parsing seconds parameter %s", seconds) + errStr := fmt.Sprintf("Error parsing seconds parameter %s", secondsParam) return nil, CodedError(400, errStr) } } diff --git a/command/agent/agent_endpoint_test.go b/command/agent/agent_endpoint_test.go index c6671aa5d..db89e1d38 100644 --- a/command/agent/agent_endpoint_test.go +++ b/command/agent/agent_endpoint_test.go @@ -513,7 +513,7 @@ func TestAgent_PprofRequest(t *testing.T) { desc: "invalid server request", url: "/v1/agent/pprof/unknown", addServerID: true, - expectedErr: "RPC Error:: 404,Pprof profile not found profile: unknnown", + expectedErr: "RPC Error:: 404,Pprof profile not found profile: unknown", }, { desc: "cpu profile request", diff --git a/command/agent/http_test.go b/command/agent/http_test.go index fa1feaf31..ab26ed4f9 100644 --- a/command/agent/http_test.go +++ b/command/agent/http_test.go @@ -218,6 +218,60 @@ func TestContentTypeIsJSON(t *testing.T) { } } +func TestWrapNonJSON(t *testing.T) { + t.Parallel() + s := makeHTTPServer(t, nil) + defer s.Shutdown() + + resp := httptest.NewRecorder() + + handler := func(resp http.ResponseWriter, req *http.Request) ([]byte, error) { + return []byte("test response"), nil + } + + req, _ := http.NewRequest("GET", "/v1/kv/key", nil) + s.Server.wrapNonJSON(handler)(resp, req) + + respBody, _ := ioutil.ReadAll(resp.Body) + require.Equal(t, respBody, []byte("test response")) + +} + +func TestWrapNonJSON_Error(t *testing.T) { + t.Parallel() + s := makeHTTPServer(t, nil) + defer s.Shutdown() + + handlerRPCErr := func(resp http.ResponseWriter, req *http.Request) ([]byte, error) { + return nil, structs.NewErrRPCCoded(404, "not found") + } + + handlerCodedErr := func(resp http.ResponseWriter, req *http.Request) ([]byte, error) { + return nil, CodedError(422, "unprocessable") + } + + // RPC coded error + { + resp := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/v1/kv/key", nil) + s.Server.wrapNonJSON(handlerRPCErr)(resp, req) + respBody, _ := ioutil.ReadAll(resp.Body) + require.Equal(t, []byte("not found"), respBody) + require.Equal(t, 404, resp.Code) + } + + // CodedError + { + resp := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/v1/kv/key", nil) + s.Server.wrapNonJSON(handlerCodedErr)(resp, req) + respBody, _ := ioutil.ReadAll(resp.Body) + require.Equal(t, []byte("unprocessable"), respBody) + require.Equal(t, 422, resp.Code) + } + +} + func TestPrettyPrint(t *testing.T) { t.Parallel() testPrettyPrint("pretty=1", true, t) diff --git a/nomad/client_agent_endpoint.go b/nomad/client_agent_endpoint.go index 8887cd00c..8a3a666da 100644 --- a/nomad/client_agent_endpoint.go +++ b/nomad/client_agent_endpoint.go @@ -285,8 +285,6 @@ func (a *Agent) forwardFor(serverID, region string) (*serverParts, error) { } else { members := a.srv.Members() for _, mem := range members { - // TODO find a better way to get the agent ID we associate - // with a serf member if mem.Name == serverID || mem.Tags["id"] == serverID { if ok, srv := isNomadServer(mem); ok { if srv.Region != region { diff --git a/nomad/structs/structs.go b/nomad/structs/structs.go index 48f4fd638..98e934d4a 100644 --- a/nomad/structs/structs.go +++ b/nomad/structs/structs.go @@ -294,6 +294,8 @@ type AgentPprofResponse struct { // Payload is the generated pprof profile Payload []byte + // HTTPHeaders are a set of key value pairs to be applied as + // HTTP headers for a specific runtime profile HTTPHeaders map[string]string }