Merge pull request #510 from ctennis/more_descriptive_errors

More descriptive errors with specific HTTP return codes
This commit is contained in:
Armon Dadgar 2015-08-11 10:11:26 -07:00
commit 4abc488cec
8 changed files with 100 additions and 23 deletions

View File

@ -156,6 +156,11 @@ func respondError(w http.ResponseWriter, status int, err error) {
status = http.StatusServiceUnavailable
}
// Allow HTTPCoded error passthrough to specify a code
if t, ok := err.(logical.HTTPCodedError); ok {
status = t.Code()
}
w.Header().Add("Content-Type", "application/json")
w.WriteHeader(status)

View File

@ -1,10 +1,13 @@
package http
import (
"errors"
"net/http"
"net/http/httptest"
"reflect"
"testing"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/vault"
)
@ -57,3 +60,34 @@ func TestHandler_sealed(t *testing.T) {
}
testResponseStatus(t, resp, 503)
}
func TestHandler_error(t *testing.T) {
w := httptest.NewRecorder()
respondError(w, 500, errors.New("Test Error"))
if w.Code != 500 {
t.Fatalf("expected 500, got %d", w.Code)
}
// The code inside of the error should override
// the argument to respondError
w2 := httptest.NewRecorder()
e := logical.CodedError(403, "error text")
respondError(w2, 500, e)
if w2.Code != 403 {
t.Fatalf("expected 403, got %d", w2.Code)
}
// vault.ErrSealed is a special case
w3 := httptest.NewRecorder()
respondError(w3, 400, vault.ErrSealed)
if w3.Code != 503 {
t.Fatalf("expected 503, got %d", w3.Code)
}
}

View File

@ -131,6 +131,7 @@ func handleSysMount(
"description": req.Description,
},
}))
if err != nil {
respondError(w, http.StatusInternalServerError, err)
return
@ -149,6 +150,7 @@ func handleSysUnmount(
Path: "sys/mounts/" + path,
Connection: getConnection(r),
}))
if err != nil {
respondError(w, http.StatusInternalServerError, err)
return

24
logical/error.go Normal file
View File

@ -0,0 +1,24 @@
package logical
type HTTPCodedError interface {
Error() string
Code() int
}
func CodedError(c int, s string) HTTPCodedError {
return &codedError{s,c}
}
type codedError struct {
s string
code int
}
func (e *codedError) Error() string {
return e.s
}
func (e *codedError) Code() int {
return e.code
}

View File

@ -145,6 +145,6 @@ var (
// ErrInvalidRequest is returned if the request is invalid
ErrInvalidRequest = errors.New("invalid request")
// ErrPermissionDeneid is returned if the client is not authorized
// ErrPermissionDenied is returned if the client is not authorized
ErrPermissionDenied = errors.New("permission denied")
)

View File

@ -13,7 +13,7 @@ const (
// avoided like the HTTPContentType. The value must be a byte slice.
HTTPRawBody = "http_raw_body"
// HTTPStatusCode is the response code the HTTP body that goes with the HTTPContentType.
// HTTPStatusCode is the response code of the HTTP body that goes with the HTTPContentType.
// This can only be specified for non-secrets, and should should be similarly
// avoided like the HTTPContentType. The value must be an integer.
HTTPStatusCode = "http_status_code"

View File

@ -371,9 +371,21 @@ func (b *SystemBackend) handleMount(
// Attempt mount
if err := b.Core.mount(me); err != nil {
b.Backend.Logger().Printf("[ERR] sys: mount %#v failed: %v", me, err)
return handleError(err)
}
return nil, nil
}
// used to intercept an HTTPCodedError so it goes back to callee
func handleError(
err error) (*logical.Response, error) {
switch err.(type) {
case logical.HTTPCodedError:
return logical.ErrorResponse(err.Error()), err
default:
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
}
return nil, nil
}
// handleUnmount is used to unmount a path
@ -387,7 +399,7 @@ func (b *SystemBackend) handleUnmount(
// Attempt unmount
if err := b.Core.unmount(suffix); err != nil {
b.Backend.Logger().Printf("[ERR] sys: unmount '%s' failed: %v", suffix, err)
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
return handleError(err)
}
return nil, nil
@ -408,7 +420,7 @@ func (b *SystemBackend) handleRemount(
// Attempt remount
if err := b.Core.remount(fromPath, toPath); err != nil {
b.Backend.Logger().Printf("[ERR] sys: remount '%s' to '%s' failed: %v", fromPath, toPath, err)
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
return handleError(err)
}
return nil, nil
@ -428,7 +440,7 @@ func (b *SystemBackend) handleRenew(
resp, err := b.Core.expiration.Renew(leaseID, increment)
if err != nil {
b.Backend.Logger().Printf("[ERR] sys: renew '%s' failed: %v", leaseID, err)
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
return handleError(err)
}
return resp, err
}
@ -442,7 +454,7 @@ func (b *SystemBackend) handleRevoke(
// Invoke the expiration manager directly
if err := b.Core.expiration.Revoke(leaseID); err != nil {
b.Backend.Logger().Printf("[ERR] sys: revoke '%s' failed: %v", leaseID, err)
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
return handleError(err)
}
return nil, nil
}
@ -456,7 +468,7 @@ func (b *SystemBackend) handleRevokePrefix(
// Invoke the expiration manager directly
if err := b.Core.expiration.RevokePrefix(prefix); err != nil {
b.Backend.Logger().Printf("[ERR] sys: revoke prefix '%s' failed: %v", prefix, err)
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
return handleError(err)
}
return nil, nil
}
@ -504,7 +516,7 @@ func (b *SystemBackend) handleEnableAuth(
// Attempt enabling
if err := b.Core.enableCredential(me); err != nil {
b.Backend.Logger().Printf("[ERR] sys: enable auth %#v failed: %v", me, err)
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
return handleError(err)
}
return nil, nil
}
@ -520,7 +532,7 @@ func (b *SystemBackend) handleDisableAuth(
// Attempt disable
if err := b.Core.disableCredential(suffix); err != nil {
b.Backend.Logger().Printf("[ERR] sys: disable auth '%s' failed: %v", suffix, err)
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
return handleError(err)
}
return nil, nil
}
@ -543,7 +555,7 @@ func (b *SystemBackend) handlePolicyRead(
policy, err := b.Core.policy.GetPolicy(name)
if err != nil {
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
return handleError(err)
}
if policy == nil {
@ -567,7 +579,7 @@ func (b *SystemBackend) handlePolicySet(
// Validate the rules parse
parse, err := Parse(rules)
if err != nil {
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
return handleError(err)
}
// Override the name
@ -575,7 +587,7 @@ func (b *SystemBackend) handlePolicySet(
// Update the policy
if err := b.Core.policy.SetPolicy(parse); err != nil {
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
return handleError(err)
}
return nil, nil
}
@ -585,7 +597,7 @@ func (b *SystemBackend) handlePolicyDelete(
req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
name := data.Get("name").(string)
if err := b.Core.policy.DeletePolicy(name); err != nil {
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
return handleError(err)
}
return nil, nil
}
@ -640,7 +652,7 @@ func (b *SystemBackend) handleEnableAudit(
// Attempt enabling
if err := b.Core.enableAudit(me); err != nil {
b.Backend.Logger().Printf("[ERR] sys: enable audit %#v failed: %v", me, err)
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
return handleError(err)
}
return nil, nil
}
@ -653,7 +665,7 @@ func (b *SystemBackend) handleDisableAudit(
// Attempt disable
if err := b.Core.disableAudit(path); err != nil {
b.Backend.Logger().Printf("[ERR] sys: disable audit '%s' failed: %v", path, err)
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
return handleError(err)
}
return nil, nil
}
@ -673,7 +685,7 @@ func (b *SystemBackend) handleRawRead(
entry, err := b.Core.barrier.Get(path)
if err != nil {
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
return handleError(err)
}
if entry == nil {
return nil, nil
@ -724,7 +736,7 @@ func (b *SystemBackend) handleRawDelete(
}
if err := b.Core.barrier.Delete(path); err != nil {
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
return handleError(err)
}
return nil, nil
}
@ -754,7 +766,7 @@ func (b *SystemBackend) handleRotate(
newTerm, err := b.Core.barrier.Rotate()
if err != nil {
b.Backend.Logger().Printf("[ERR] sys: failed to create new encryption key: %v", err)
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
return handleError(err)
}
b.Backend.Logger().Printf("[INFO] sys: installed new encryption key")

View File

@ -22,7 +22,7 @@ const (
// barrier view for the backends.
backendBarrierPrefix = "logical/"
// systemBarrierPrefix is sthe prefix used for the
// systemBarrierPrefix is the prefix used for the
// system logical backend.
systemBarrierPrefix = "sys/"
)
@ -139,16 +139,16 @@ func (c *Core) mount(me *MountEntry) error {
me.Path += "/"
}
// Prevent protected paths from being unmounted
// Prevent protected paths from being mounted
for _, p := range protectedMounts {
if strings.HasPrefix(me.Path, p) {
return fmt.Errorf("cannot mount '%s'", me.Path)
return logical.CodedError(403, fmt.Sprintf("cannot mount '%s'", me.Path))
}
}
// Verify there is no conflicting mount
if match := c.router.MatchingMount(me.Path); match != "" {
return fmt.Errorf("existing mount at '%s'", match)
return logical.CodedError(409, fmt.Sprintf("existing mount at %s", match))
}
// Generate a new UUID and view