diff --git a/http/sys_health.go b/http/sys_health.go index 6e7c0b92c..2883744c3 100644 --- a/http/sys_health.go +++ b/http/sys_health.go @@ -14,6 +14,8 @@ func handleSysHealth(core *vault.Core) http.Handler { switch r.Method { case "GET": handleSysHealthGet(core, w, r) + case "HEAD": + handleSysHealthHead(core, w, r) default: respondError(w, http.StatusMethodNotAllowed, nil) } @@ -34,7 +36,38 @@ func fetchStatusCode(r *http.Request, field string) (int, bool, bool) { } func handleSysHealthGet(core *vault.Core, w http.ResponseWriter, r *http.Request) { + code, body, err := getSysHealth(core, r) + if err != nil { + respondError(w, http.StatusInternalServerError, nil) + return + } + if body == nil { + respondError(w, code, nil) + return + } + + w.Header().Add("Content-Type", "application/json") + w.WriteHeader(code) + + // Generate the response + enc := json.NewEncoder(w) + enc.Encode(body) +} + +func handleSysHealthHead(core *vault.Core, w http.ResponseWriter, r *http.Request) { + code, body, err := getSysHealth(core, r) + if err != nil { + code = http.StatusInternalServerError + } + + if body != nil { + w.Header().Add("Content-Type", "application/json") + } + w.WriteHeader(code) +} + +func getSysHealth(core *vault.Core, r *http.Request) (int, *HealthResponse, error) { // Check if being a standby is allowed for the purpose of a 200 OK _, standbyOK := r.URL.Query()["standbyok"] @@ -42,24 +75,21 @@ func handleSysHealthGet(core *vault.Core, w http.ResponseWriter, r *http.Request // point sealedCode := http.StatusInternalServerError if code, found, ok := fetchStatusCode(r, "sealedcode"); !ok { - respondError(w, http.StatusBadRequest, nil) - return + return http.StatusBadRequest, nil, nil } else if found { sealedCode = code } standbyCode := http.StatusTooManyRequests // Consul warning code if code, found, ok := fetchStatusCode(r, "standbycode"); !ok { - respondError(w, http.StatusBadRequest, nil) - return + return http.StatusBadRequest, nil, nil } else if found { standbyCode = code } activeCode := http.StatusOK if code, found, ok := fetchStatusCode(r, "activecode"); !ok { - respondError(w, http.StatusBadRequest, nil) - return + return http.StatusBadRequest, nil, nil } else if found { activeCode = code } @@ -69,8 +99,7 @@ func handleSysHealthGet(core *vault.Core, w http.ResponseWriter, r *http.Request standby, _ := core.Standby() init, err := core.Initialized() if err != nil { - respondError(w, http.StatusInternalServerError, err) - return + return http.StatusInternalServerError, nil, err } // Determine the status code @@ -91,12 +120,7 @@ func handleSysHealthGet(core *vault.Core, w http.ResponseWriter, r *http.Request Standby: standby, ServerTimeUTC: time.Now().UTC().Unix(), } - - // Generate the response - w.Header().Add("Content-Type", "application/json") - w.WriteHeader(code) - enc := json.NewEncoder(w) - enc.Encode(body) + return code, body, nil } type HealthResponse struct { diff --git a/http/sys_health_test.go b/http/sys_health_test.go index 03b5f207a..452bd0cf1 100644 --- a/http/sys_health_test.go +++ b/http/sys_health_test.go @@ -1,6 +1,8 @@ package http import ( + "io/ioutil" + "net/http" "net/url" "reflect" @@ -105,3 +107,41 @@ func TestSysHealth_customcodes(t *testing.T) { t.Fatalf("bad: %#v", actual) } } + +func TestSysHealth_head(t *testing.T) { + core, _, _ := vault.TestCoreUnsealed(t) + ln, addr := TestServer(t, core) + defer ln.Close() + + testData := []struct{ + uri string + code int + }{ + {"", 200}, + {"?activecode=503", 503}, + {"?activecode=notacode", 400}, + } + + for _, tt := range testData { + queryurl, err := url.Parse(addr + "/v1/sys/health" + tt.uri) + if err != nil { + t.Fatalf("err on %v: %s", queryurl, err) + } + resp, err := http.Head(queryurl.String()) + if err != nil { + t.Fatalf("err on %v: %s", queryurl, err) + } + + if resp.StatusCode != tt.code { + t.Fatalf("HEAD %v expected code %d, got %d.", queryurl, tt.code, resp.StatusCode) + } + + data, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatalf("err on %v: %s", queryurl, err) + } + if len(data) > 0 { + t.Fatalf("HEAD %v expected no body, received \"%v\".", queryurl, data) + } + } +}