Merge pull request #510 from ctennis/more_descriptive_errors

More descriptive errors with specific HTTP return codes
This commit is contained in:
Armon Dadgar 2015-08-11 10:11:26 -07:00
commit 4abc488cec
8 changed files with 100 additions and 23 deletions

View file

@ -156,6 +156,11 @@ func respondError(w http.ResponseWriter, status int, err error) {
status = http.StatusServiceUnavailable status = http.StatusServiceUnavailable
} }
// Allow HTTPCoded error passthrough to specify a code
if t, ok := err.(logical.HTTPCodedError); ok {
status = t.Code()
}
w.Header().Add("Content-Type", "application/json") w.Header().Add("Content-Type", "application/json")
w.WriteHeader(status) w.WriteHeader(status)

View file

@ -1,10 +1,13 @@
package http package http
import ( import (
"errors"
"net/http" "net/http"
"net/http/httptest"
"reflect" "reflect"
"testing" "testing"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/vault" "github.com/hashicorp/vault/vault"
) )
@ -57,3 +60,34 @@ func TestHandler_sealed(t *testing.T) {
} }
testResponseStatus(t, resp, 503) testResponseStatus(t, resp, 503)
} }
func TestHandler_error(t *testing.T) {
w := httptest.NewRecorder()
respondError(w, 500, errors.New("Test Error"))
if w.Code != 500 {
t.Fatalf("expected 500, got %d", w.Code)
}
// The code inside of the error should override
// the argument to respondError
w2 := httptest.NewRecorder()
e := logical.CodedError(403, "error text")
respondError(w2, 500, e)
if w2.Code != 403 {
t.Fatalf("expected 403, got %d", w2.Code)
}
// vault.ErrSealed is a special case
w3 := httptest.NewRecorder()
respondError(w3, 400, vault.ErrSealed)
if w3.Code != 503 {
t.Fatalf("expected 503, got %d", w3.Code)
}
}

View file

@ -131,6 +131,7 @@ func handleSysMount(
"description": req.Description, "description": req.Description,
}, },
})) }))
if err != nil { if err != nil {
respondError(w, http.StatusInternalServerError, err) respondError(w, http.StatusInternalServerError, err)
return return
@ -149,6 +150,7 @@ func handleSysUnmount(
Path: "sys/mounts/" + path, Path: "sys/mounts/" + path,
Connection: getConnection(r), Connection: getConnection(r),
})) }))
if err != nil { if err != nil {
respondError(w, http.StatusInternalServerError, err) respondError(w, http.StatusInternalServerError, err)
return return

24
logical/error.go Normal file
View file

@ -0,0 +1,24 @@
package logical
type HTTPCodedError interface {
Error() string
Code() int
}
func CodedError(c int, s string) HTTPCodedError {
return &codedError{s,c}
}
type codedError struct {
s string
code int
}
func (e *codedError) Error() string {
return e.s
}
func (e *codedError) Code() int {
return e.code
}

View file

@ -145,6 +145,6 @@ var (
// ErrInvalidRequest is returned if the request is invalid // ErrInvalidRequest is returned if the request is invalid
ErrInvalidRequest = errors.New("invalid request") ErrInvalidRequest = errors.New("invalid request")
// ErrPermissionDeneid is returned if the client is not authorized // ErrPermissionDenied is returned if the client is not authorized
ErrPermissionDenied = errors.New("permission denied") ErrPermissionDenied = errors.New("permission denied")
) )

View file

@ -13,7 +13,7 @@ const (
// avoided like the HTTPContentType. The value must be a byte slice. // avoided like the HTTPContentType. The value must be a byte slice.
HTTPRawBody = "http_raw_body" HTTPRawBody = "http_raw_body"
// HTTPStatusCode is the response code the HTTP body that goes with the HTTPContentType. // HTTPStatusCode is the response code of the HTTP body that goes with the HTTPContentType.
// This can only be specified for non-secrets, and should should be similarly // This can only be specified for non-secrets, and should should be similarly
// avoided like the HTTPContentType. The value must be an integer. // avoided like the HTTPContentType. The value must be an integer.
HTTPStatusCode = "http_status_code" HTTPStatusCode = "http_status_code"

View file

@ -371,9 +371,21 @@ func (b *SystemBackend) handleMount(
// Attempt mount // Attempt mount
if err := b.Core.mount(me); err != nil { if err := b.Core.mount(me); err != nil {
b.Backend.Logger().Printf("[ERR] sys: mount %#v failed: %v", me, err) b.Backend.Logger().Printf("[ERR] sys: mount %#v failed: %v", me, err)
return handleError(err)
}
return nil, nil
}
// used to intercept an HTTPCodedError so it goes back to callee
func handleError(
err error) (*logical.Response, error) {
switch err.(type) {
case logical.HTTPCodedError:
return logical.ErrorResponse(err.Error()), err
default:
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
} }
return nil, nil
} }
// handleUnmount is used to unmount a path // handleUnmount is used to unmount a path
@ -387,7 +399,7 @@ func (b *SystemBackend) handleUnmount(
// Attempt unmount // Attempt unmount
if err := b.Core.unmount(suffix); err != nil { if err := b.Core.unmount(suffix); err != nil {
b.Backend.Logger().Printf("[ERR] sys: unmount '%s' failed: %v", suffix, err) b.Backend.Logger().Printf("[ERR] sys: unmount '%s' failed: %v", suffix, err)
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest return handleError(err)
} }
return nil, nil return nil, nil
@ -408,7 +420,7 @@ func (b *SystemBackend) handleRemount(
// Attempt remount // Attempt remount
if err := b.Core.remount(fromPath, toPath); err != nil { if err := b.Core.remount(fromPath, toPath); err != nil {
b.Backend.Logger().Printf("[ERR] sys: remount '%s' to '%s' failed: %v", fromPath, toPath, err) b.Backend.Logger().Printf("[ERR] sys: remount '%s' to '%s' failed: %v", fromPath, toPath, err)
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest return handleError(err)
} }
return nil, nil return nil, nil
@ -428,7 +440,7 @@ func (b *SystemBackend) handleRenew(
resp, err := b.Core.expiration.Renew(leaseID, increment) resp, err := b.Core.expiration.Renew(leaseID, increment)
if err != nil { if err != nil {
b.Backend.Logger().Printf("[ERR] sys: renew '%s' failed: %v", leaseID, err) b.Backend.Logger().Printf("[ERR] sys: renew '%s' failed: %v", leaseID, err)
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest return handleError(err)
} }
return resp, err return resp, err
} }
@ -442,7 +454,7 @@ func (b *SystemBackend) handleRevoke(
// Invoke the expiration manager directly // Invoke the expiration manager directly
if err := b.Core.expiration.Revoke(leaseID); err != nil { if err := b.Core.expiration.Revoke(leaseID); err != nil {
b.Backend.Logger().Printf("[ERR] sys: revoke '%s' failed: %v", leaseID, err) b.Backend.Logger().Printf("[ERR] sys: revoke '%s' failed: %v", leaseID, err)
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest return handleError(err)
} }
return nil, nil return nil, nil
} }
@ -456,7 +468,7 @@ func (b *SystemBackend) handleRevokePrefix(
// Invoke the expiration manager directly // Invoke the expiration manager directly
if err := b.Core.expiration.RevokePrefix(prefix); err != nil { if err := b.Core.expiration.RevokePrefix(prefix); err != nil {
b.Backend.Logger().Printf("[ERR] sys: revoke prefix '%s' failed: %v", prefix, err) b.Backend.Logger().Printf("[ERR] sys: revoke prefix '%s' failed: %v", prefix, err)
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest return handleError(err)
} }
return nil, nil return nil, nil
} }
@ -504,7 +516,7 @@ func (b *SystemBackend) handleEnableAuth(
// Attempt enabling // Attempt enabling
if err := b.Core.enableCredential(me); err != nil { if err := b.Core.enableCredential(me); err != nil {
b.Backend.Logger().Printf("[ERR] sys: enable auth %#v failed: %v", me, err) b.Backend.Logger().Printf("[ERR] sys: enable auth %#v failed: %v", me, err)
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest return handleError(err)
} }
return nil, nil return nil, nil
} }
@ -520,7 +532,7 @@ func (b *SystemBackend) handleDisableAuth(
// Attempt disable // Attempt disable
if err := b.Core.disableCredential(suffix); err != nil { if err := b.Core.disableCredential(suffix); err != nil {
b.Backend.Logger().Printf("[ERR] sys: disable auth '%s' failed: %v", suffix, err) b.Backend.Logger().Printf("[ERR] sys: disable auth '%s' failed: %v", suffix, err)
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest return handleError(err)
} }
return nil, nil return nil, nil
} }
@ -543,7 +555,7 @@ func (b *SystemBackend) handlePolicyRead(
policy, err := b.Core.policy.GetPolicy(name) policy, err := b.Core.policy.GetPolicy(name)
if err != nil { if err != nil {
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest return handleError(err)
} }
if policy == nil { if policy == nil {
@ -567,7 +579,7 @@ func (b *SystemBackend) handlePolicySet(
// Validate the rules parse // Validate the rules parse
parse, err := Parse(rules) parse, err := Parse(rules)
if err != nil { if err != nil {
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest return handleError(err)
} }
// Override the name // Override the name
@ -575,7 +587,7 @@ func (b *SystemBackend) handlePolicySet(
// Update the policy // Update the policy
if err := b.Core.policy.SetPolicy(parse); err != nil { if err := b.Core.policy.SetPolicy(parse); err != nil {
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest return handleError(err)
} }
return nil, nil return nil, nil
} }
@ -585,7 +597,7 @@ func (b *SystemBackend) handlePolicyDelete(
req *logical.Request, data *framework.FieldData) (*logical.Response, error) { req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
name := data.Get("name").(string) name := data.Get("name").(string)
if err := b.Core.policy.DeletePolicy(name); err != nil { if err := b.Core.policy.DeletePolicy(name); err != nil {
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest return handleError(err)
} }
return nil, nil return nil, nil
} }
@ -640,7 +652,7 @@ func (b *SystemBackend) handleEnableAudit(
// Attempt enabling // Attempt enabling
if err := b.Core.enableAudit(me); err != nil { if err := b.Core.enableAudit(me); err != nil {
b.Backend.Logger().Printf("[ERR] sys: enable audit %#v failed: %v", me, err) b.Backend.Logger().Printf("[ERR] sys: enable audit %#v failed: %v", me, err)
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest return handleError(err)
} }
return nil, nil return nil, nil
} }
@ -653,7 +665,7 @@ func (b *SystemBackend) handleDisableAudit(
// Attempt disable // Attempt disable
if err := b.Core.disableAudit(path); err != nil { if err := b.Core.disableAudit(path); err != nil {
b.Backend.Logger().Printf("[ERR] sys: disable audit '%s' failed: %v", path, err) b.Backend.Logger().Printf("[ERR] sys: disable audit '%s' failed: %v", path, err)
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest return handleError(err)
} }
return nil, nil return nil, nil
} }
@ -673,7 +685,7 @@ func (b *SystemBackend) handleRawRead(
entry, err := b.Core.barrier.Get(path) entry, err := b.Core.barrier.Get(path)
if err != nil { if err != nil {
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest return handleError(err)
} }
if entry == nil { if entry == nil {
return nil, nil return nil, nil
@ -724,7 +736,7 @@ func (b *SystemBackend) handleRawDelete(
} }
if err := b.Core.barrier.Delete(path); err != nil { if err := b.Core.barrier.Delete(path); err != nil {
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest return handleError(err)
} }
return nil, nil return nil, nil
} }
@ -754,7 +766,7 @@ func (b *SystemBackend) handleRotate(
newTerm, err := b.Core.barrier.Rotate() newTerm, err := b.Core.barrier.Rotate()
if err != nil { if err != nil {
b.Backend.Logger().Printf("[ERR] sys: failed to create new encryption key: %v", err) b.Backend.Logger().Printf("[ERR] sys: failed to create new encryption key: %v", err)
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest return handleError(err)
} }
b.Backend.Logger().Printf("[INFO] sys: installed new encryption key") b.Backend.Logger().Printf("[INFO] sys: installed new encryption key")

View file

@ -22,7 +22,7 @@ const (
// barrier view for the backends. // barrier view for the backends.
backendBarrierPrefix = "logical/" backendBarrierPrefix = "logical/"
// systemBarrierPrefix is sthe prefix used for the // systemBarrierPrefix is the prefix used for the
// system logical backend. // system logical backend.
systemBarrierPrefix = "sys/" systemBarrierPrefix = "sys/"
) )
@ -139,16 +139,16 @@ func (c *Core) mount(me *MountEntry) error {
me.Path += "/" me.Path += "/"
} }
// Prevent protected paths from being unmounted // Prevent protected paths from being mounted
for _, p := range protectedMounts { for _, p := range protectedMounts {
if strings.HasPrefix(me.Path, p) { if strings.HasPrefix(me.Path, p) {
return fmt.Errorf("cannot mount '%s'", me.Path) return logical.CodedError(403, fmt.Sprintf("cannot mount '%s'", me.Path))
} }
} }
// Verify there is no conflicting mount // Verify there is no conflicting mount
if match := c.router.MatchingMount(me.Path); match != "" { if match := c.router.MatchingMount(me.Path); match != "" {
return fmt.Errorf("existing mount at '%s'", match) return logical.CodedError(409, fmt.Sprintf("existing mount at %s", match))
} }
// Generate a new UUID and view // Generate a new UUID and view