Merge pull request #1553 from hashicorp/fix-status-code-regression
Fix up error detection regression to return correct status codes
This commit is contained in:
commit
f24d6a10f5
|
@ -97,11 +97,7 @@ func request(core *vault.Core, w http.ResponseWriter, rawReq *http.Request, r *l
|
|||
respondStandby(core, w, rawReq.URL)
|
||||
return resp, false
|
||||
}
|
||||
if respondCommon(w, resp, err) {
|
||||
return resp, false
|
||||
}
|
||||
if err != nil {
|
||||
respondErrorStatus(w, err)
|
||||
if respondErrorCommon(w, resp, err) {
|
||||
return resp, false
|
||||
}
|
||||
|
||||
|
@ -192,18 +188,6 @@ func requestWrapTTL(r *http.Request, req *logical.Request) (*logical.Request, er
|
|||
return req, nil
|
||||
}
|
||||
|
||||
// Determines the type of the error being returned and sets the HTTP
|
||||
// status code appropriately
|
||||
func respondErrorStatus(w http.ResponseWriter, err error) {
|
||||
status := http.StatusInternalServerError
|
||||
switch {
|
||||
// Keep adding more error types here to appropriate the status codes
|
||||
case err != nil && errwrap.ContainsType(err, new(vault.StatusBadRequest)):
|
||||
status = http.StatusBadRequest
|
||||
}
|
||||
respondError(w, status, err)
|
||||
}
|
||||
|
||||
func respondError(w http.ResponseWriter, status int, err error) {
|
||||
// Adjust status code when sealed
|
||||
if errwrap.Contains(err, vault.ErrSealed.Error()) {
|
||||
|
@ -227,33 +211,43 @@ func respondError(w http.ResponseWriter, status int, err error) {
|
|||
enc.Encode(resp)
|
||||
}
|
||||
|
||||
func respondCommon(w http.ResponseWriter, resp *logical.Response, err error) bool {
|
||||
if resp == nil {
|
||||
func respondErrorCommon(w http.ResponseWriter, resp *logical.Response, err error) bool {
|
||||
// If there are no errors return
|
||||
if err == nil && (resp == nil || !resp.IsError()) {
|
||||
return false
|
||||
}
|
||||
|
||||
if resp.IsError() {
|
||||
statusCode := http.StatusBadRequest
|
||||
|
||||
if err != nil {
|
||||
switch err {
|
||||
case logical.ErrPermissionDenied:
|
||||
statusCode = http.StatusForbidden
|
||||
case logical.ErrUnsupportedOperation:
|
||||
statusCode = http.StatusMethodNotAllowed
|
||||
case logical.ErrUnsupportedPath:
|
||||
statusCode = http.StatusNotFound
|
||||
case logical.ErrInvalidRequest:
|
||||
statusCode = http.StatusBadRequest
|
||||
}
|
||||
}
|
||||
|
||||
err := fmt.Errorf("%s", resp.Data["error"].(string))
|
||||
respondError(w, statusCode, err)
|
||||
return true
|
||||
// Start out with internal server error since in most of these cases there
|
||||
// won't be a response so this won't be overridden
|
||||
statusCode := http.StatusInternalServerError
|
||||
// If we actually have a response, start out with bad request
|
||||
if resp != nil {
|
||||
statusCode = http.StatusBadRequest
|
||||
}
|
||||
|
||||
return false
|
||||
// Now, check the error itself; if it has a specific logical error, set the
|
||||
// appropriate code
|
||||
if err != nil {
|
||||
switch {
|
||||
case errwrap.ContainsType(err, new(vault.StatusBadRequest)):
|
||||
statusCode = http.StatusBadRequest
|
||||
case errwrap.Contains(err, logical.ErrPermissionDenied.Error()):
|
||||
statusCode = http.StatusForbidden
|
||||
case errwrap.Contains(err, logical.ErrUnsupportedOperation.Error()):
|
||||
statusCode = http.StatusMethodNotAllowed
|
||||
case errwrap.Contains(err, logical.ErrUnsupportedPath.Error()):
|
||||
statusCode = http.StatusNotFound
|
||||
case errwrap.Contains(err, logical.ErrInvalidRequest.Error()):
|
||||
statusCode = http.StatusBadRequest
|
||||
}
|
||||
}
|
||||
|
||||
if resp != nil {
|
||||
err = fmt.Errorf("%s", resp.Data["error"].(string))
|
||||
}
|
||||
|
||||
respondError(w, statusCode, err)
|
||||
return true
|
||||
}
|
||||
|
||||
func respondOk(w http.ResponseWriter, body interface{}) {
|
||||
|
|
|
@ -30,8 +30,11 @@ func TestLogical(t *testing.T) {
|
|||
testResponseStatus(t, resp, 204)
|
||||
|
||||
// READ
|
||||
resp = testHttpGet(t, token, addr+"/v1/secret/foo")
|
||||
// Bad token should return a 403
|
||||
resp = testHttpGet(t, token+"bad", addr+"/v1/secret/foo")
|
||||
testResponseStatus(t, resp, 403)
|
||||
|
||||
resp = testHttpGet(t, token, addr+"/v1/secret/foo")
|
||||
var actual map[string]interface{}
|
||||
var nilWarnings interface{}
|
||||
expected := map[string]interface{}{
|
||||
|
|
Loading…
Reference in New Issue