diff --git a/command/agent/kvs_endpoint.go b/command/agent/kvs_endpoint.go index 4f5bccb48..bca7bf321 100644 --- a/command/agent/kvs_endpoint.go +++ b/command/agent/kvs_endpoint.go @@ -348,11 +348,13 @@ func (s *HTTPServer) KVSTxn(resp http.ResponseWriter, req *http.Request) (interf s.parseToken(req, &args.Token) // Note the body is in API format, and not the RPC format. If we can't - // decode it, we will return a 500 since we don't have enough context to + // decode it, we will return a 400 since we don't have enough context to // associate the error with a given operation. var txn api.KVTxn if err := decodeBody(req, &txn, fixupValues); err != nil { - return nil, fmt.Errorf("failed to parse body: %v", err) + resp.WriteHeader(http.StatusBadRequest) + resp.Write([]byte(fmt.Sprintf("Failed to parse body: %v", err))) + return nil, nil } // Convert the API format into the RPC format. Note that fixupValues diff --git a/command/agent/kvs_endpoint_test.go b/command/agent/kvs_endpoint_test.go index ab6357ed1..3d0343c2f 100644 --- a/command/agent/kvs_endpoint_test.go +++ b/command/agent/kvs_endpoint_test.go @@ -573,3 +573,219 @@ func TestKVSEndpoint_DELETE_ConflictingFlags(t *testing.T) { } }) } + +func TestKVSEndpoint_Txn(t *testing.T) { + // Bad JSON. + httpTest(t, func(srv *HTTPServer) { + buf := bytes.NewBuffer([]byte("{")) + req, err := http.NewRequest("PUT", "/v1/kv-txn", buf) + if err != nil { + t.Fatalf("err: %v", err) + } + + resp := httptest.NewRecorder() + if _, err := srv.KVSTxn(resp, req); err != nil { + t.Fatalf("err: %v", err) + } + if resp.Code != 400 { + t.Fatalf("expected 400, got %d", resp.Code) + } + if !bytes.Contains(resp.Body.Bytes(), []byte("Failed to parse")) { + t.Fatalf("expected conflicting args error") + } + }) + + // Bad request. + httpTest(t, func(srv *HTTPServer) { + buf := bytes.NewBuffer([]byte("{")) + req, err := http.NewRequest("GET", "/v1/kv-txn", buf) + if err != nil { + t.Fatalf("err: %v", err) + } + + resp := httptest.NewRecorder() + if _, err := srv.KVSTxn(resp, req); err != nil { + t.Fatalf("err: %v", err) + } + if resp.Code != 405 { + t.Fatalf("expected 405, got %d", resp.Code) + } + }) + + // Make sure all incoming fields get converted properly to the internal + // RPC format. + httpTest(t, func(srv *HTTPServer) { + var index uint64 + id := makeTestSession(t, srv) + { + buf := bytes.NewBuffer([]byte(fmt.Sprintf(` +[ + { + "Op": "lock", + "Key": "key", + "Value": "aGVsbG8gd29ybGQ=", + "Flags": 23, + "Session": %q + }, + { + "Op": "get", + "Key": "key" + } +] +`, id))) + req, err := http.NewRequest("PUT", "/v1/kv-txn", buf) + if err != nil { + t.Fatalf("err: %v", err) + } + + resp := httptest.NewRecorder() + obj, err := srv.KVSTxn(resp, req) + if err != nil { + t.Fatalf("err: %v", err) + } + if resp.Code != 200 { + t.Fatalf("expected 200, got %d", resp.Code) + } + + atomic, ok := obj.(structs.KVSAtomicResponse) + if !ok { + t.Fatalf("bad type: %T", obj) + } + if len(atomic.Results) != 2 { + t.Fatalf("bad: %v", atomic) + } + index = atomic.Results[0].ModifyIndex + expected := structs.KVSAtomicResponse{ + Results: structs.DirEntries{ + &structs.DirEntry{ + Key: "key", + Value: nil, + Flags: 23, + Session: id, + LockIndex: 1, + RaftIndex: structs.RaftIndex{ + CreateIndex: index, + ModifyIndex: index, + }, + }, + &structs.DirEntry{ + Key: "key", + Value: []byte("hello world"), + Flags: 23, + Session: id, + LockIndex: 1, + RaftIndex: structs.RaftIndex{ + CreateIndex: index, + ModifyIndex: index, + }, + }, + }, + } + if !reflect.DeepEqual(atomic, expected) { + t.Fatalf("bad: %v", atomic) + } + } + + // Now that we have an index we can do a CAS to make sure the + // index field gets translated to the RPC format. + { + buf := bytes.NewBuffer([]byte(fmt.Sprintf(` +[ + { + "Op": "cas", + "Key": "key", + "Value": "Z29vZGJ5ZSB3b3JsZA==", + "Index": %d + }, + { + "Op": "get", + "Key": "key" + } +] +`, index))) + req, err := http.NewRequest("PUT", "/v1/kv-txn", buf) + if err != nil { + t.Fatalf("err: %v", err) + } + + resp := httptest.NewRecorder() + obj, err := srv.KVSTxn(resp, req) + if err != nil { + t.Fatalf("err: %v", err) + } + if resp.Code != 200 { + t.Fatalf("expected 200, got %d", resp.Code) + } + + atomic, ok := obj.(structs.KVSAtomicResponse) + if !ok { + t.Fatalf("bad type: %T", obj) + } + if len(atomic.Results) != 2 { + t.Fatalf("bad: %v", atomic) + } + modIndex := atomic.Results[0].ModifyIndex + expected := structs.KVSAtomicResponse{ + Results: structs.DirEntries{ + &structs.DirEntry{ + Key: "key", + Value: nil, + Session: id, + RaftIndex: structs.RaftIndex{ + CreateIndex: index, + ModifyIndex: modIndex, + }, + }, + &structs.DirEntry{ + Key: "key", + Value: []byte("goodbye world"), + Session: id, + RaftIndex: structs.RaftIndex{ + CreateIndex: index, + ModifyIndex: modIndex, + }, + }, + }, + } + for _, r := range atomic.Results { + fmt.Printf("%v\n", *r) + } + if !reflect.DeepEqual(atomic, expected) { + t.Fatalf("bad: %v", atomic) + } + } + }) + + // Verify an error inside a transaction. + httpTest(t, func(srv *HTTPServer) { + buf := bytes.NewBuffer([]byte(` +[ + { + "Op": "lock", + "Key": "key", + "Value": "aGVsbG8gd29ybGQ=", + "Session": "nope" + }, + { + "Op": "get", + "Key": "key" + } +] +`)) + req, err := http.NewRequest("PUT", "/v1/kv-txn", buf) + if err != nil { + t.Fatalf("err: %v", err) + } + + resp := httptest.NewRecorder() + if _, err = srv.KVSTxn(resp, req); err != nil { + t.Fatalf("err: %v", err) + } + if resp.Code != 409 { + t.Fatalf("expected 409, got %d", resp.Code) + } + if !bytes.Contains(resp.Body.Bytes(), []byte("failed session lookup")) { + t.Fatalf("bad: %s", resp.Body.String()) + } + }) +}