From 26c9d96ccb1ea3e97b3e7b937ddb4cfaee695217 Mon Sep 17 00:00:00 2001 From: Ryan Uber Date: Fri, 23 Jan 2015 12:48:39 -0800 Subject: [PATCH] agent: error from KVS endpoint if incompatible flags are passed. Fixes #432 --- command/agent/kvs_endpoint.go | 25 ++++++++++++++++++ command/agent/kvs_endpoint_test.go | 42 ++++++++++++++++++++++++++++++ 2 files changed, 67 insertions(+) diff --git a/command/agent/kvs_endpoint.go b/command/agent/kvs_endpoint.go index 54f843c91..f632cfb2c 100644 --- a/command/agent/kvs_endpoint.go +++ b/command/agent/kvs_endpoint.go @@ -136,6 +136,9 @@ func (s *HTTPServer) KVSPut(resp http.ResponseWriter, req *http.Request, args *s if missingKey(resp, args) { return nil, nil } + if conflictingFlags(resp, req, "cas", "acquire", "release") { + return nil, nil + } applyReq := structs.KVSRequest{ Datacenter: args.Datacenter, Op: structs.KVSSet, @@ -209,6 +212,9 @@ func (s *HTTPServer) KVSPut(resp http.ResponseWriter, req *http.Request, args *s // KVSPut handles a DELETE request func (s *HTTPServer) KVSDelete(resp http.ResponseWriter, req *http.Request, args *structs.KeyRequest) (interface{}, error) { + if conflictingFlags(resp, req, "recurse", "cas") { + return nil, nil + } applyReq := structs.KVSRequest{ Datacenter: args.Datacenter, Op: structs.KVSDelete, @@ -259,3 +265,22 @@ func missingKey(resp http.ResponseWriter, args *structs.KeyRequest) bool { } return false } + +// conflictingFlags determines if non-composable flags were passed in a request. +func conflictingFlags(resp http.ResponseWriter, req *http.Request, flags ...string) bool { + params := req.URL.Query() + + found := false + for _, conflict := range flags { + if _, ok := params[conflict]; ok { + if found { + resp.WriteHeader(400) + resp.Write([]byte("Conflicting flags: " + params.Encode())) + return true + } + found = true + } + } + + return false +} diff --git a/command/agent/kvs_endpoint_test.go b/command/agent/kvs_endpoint_test.go index 9e2b17416..ab6357ed1 100644 --- a/command/agent/kvs_endpoint_test.go +++ b/command/agent/kvs_endpoint_test.go @@ -531,3 +531,45 @@ func TestKVSEndpoint_GET_Raw(t *testing.T) { } }) } + +func TestKVSEndpoint_PUT_ConflictingFlags(t *testing.T) { + httpTest(t, func(srv *HTTPServer) { + req, err := http.NewRequest("PUT", "/v1/kv/test?cas=0&acquire=xxx", nil) + if err != nil { + t.Fatalf("err: %v", err) + } + + resp := httptest.NewRecorder() + if _, err := srv.KVSEndpoint(resp, req); err != nil { + t.Fatalf("err: %v", err) + } + + if resp.Code != 400 { + t.Fatalf("expected 400, got %d", resp.Code) + } + if !bytes.Contains(resp.Body.Bytes(), []byte("Conflicting")) { + t.Fatalf("expected conflicting args error") + } + }) +} + +func TestKVSEndpoint_DELETE_ConflictingFlags(t *testing.T) { + httpTest(t, func(srv *HTTPServer) { + req, err := http.NewRequest("DELETE", "/v1/kv/test?recurse&cas=0", nil) + if err != nil { + t.Fatalf("err: %v", err) + } + + resp := httptest.NewRecorder() + if _, err := srv.KVSEndpoint(resp, req); err != nil { + t.Fatalf("err: %v", err) + } + + if resp.Code != 400 { + t.Fatalf("expected 400, got %d", resp.Code) + } + if !bytes.Contains(resp.Body.Bytes(), []byte("Conflicting")) { + t.Fatalf("expected conflicting args error") + } + }) +}