Merge pull request #510 from ctennis/more_descriptive_errors
More descriptive errors with specific HTTP return codes
This commit is contained in:
commit
4abc488cec
|
@ -156,6 +156,11 @@ func respondError(w http.ResponseWriter, status int, err error) {
|
|||
status = http.StatusServiceUnavailable
|
||||
}
|
||||
|
||||
// Allow HTTPCoded error passthrough to specify a code
|
||||
if t, ok := err.(logical.HTTPCodedError); ok {
|
||||
status = t.Code()
|
||||
}
|
||||
|
||||
w.Header().Add("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
|
||||
|
|
|
@ -1,10 +1,13 @@
|
|||
package http
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/vault"
|
||||
)
|
||||
|
||||
|
@ -57,3 +60,34 @@ func TestHandler_sealed(t *testing.T) {
|
|||
}
|
||||
testResponseStatus(t, resp, 503)
|
||||
}
|
||||
|
||||
func TestHandler_error(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
respondError(w, 500, errors.New("Test Error"))
|
||||
|
||||
if w.Code != 500 {
|
||||
t.Fatalf("expected 500, got %d", w.Code)
|
||||
}
|
||||
|
||||
// The code inside of the error should override
|
||||
// the argument to respondError
|
||||
w2 := httptest.NewRecorder()
|
||||
e := logical.CodedError(403, "error text")
|
||||
|
||||
respondError(w2, 500, e)
|
||||
|
||||
if w2.Code != 403 {
|
||||
t.Fatalf("expected 403, got %d", w2.Code)
|
||||
}
|
||||
|
||||
// vault.ErrSealed is a special case
|
||||
w3 := httptest.NewRecorder()
|
||||
|
||||
respondError(w3, 400, vault.ErrSealed)
|
||||
|
||||
if w3.Code != 503 {
|
||||
t.Fatalf("expected 503, got %d", w3.Code)
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -131,6 +131,7 @@ func handleSysMount(
|
|||
"description": req.Description,
|
||||
},
|
||||
}))
|
||||
|
||||
if err != nil {
|
||||
respondError(w, http.StatusInternalServerError, err)
|
||||
return
|
||||
|
@ -149,6 +150,7 @@ func handleSysUnmount(
|
|||
Path: "sys/mounts/" + path,
|
||||
Connection: getConnection(r),
|
||||
}))
|
||||
|
||||
if err != nil {
|
||||
respondError(w, http.StatusInternalServerError, err)
|
||||
return
|
||||
|
|
|
@ -0,0 +1,24 @@
|
|||
package logical
|
||||
|
||||
type HTTPCodedError interface {
|
||||
Error() string
|
||||
Code() int
|
||||
}
|
||||
|
||||
func CodedError(c int, s string) HTTPCodedError {
|
||||
return &codedError{s,c}
|
||||
}
|
||||
|
||||
type codedError struct {
|
||||
s string
|
||||
code int
|
||||
}
|
||||
|
||||
func (e *codedError) Error() string {
|
||||
return e.s
|
||||
}
|
||||
|
||||
func (e *codedError) Code() int {
|
||||
return e.code
|
||||
}
|
||||
|
|
@ -145,6 +145,6 @@ var (
|
|||
// ErrInvalidRequest is returned if the request is invalid
|
||||
ErrInvalidRequest = errors.New("invalid request")
|
||||
|
||||
// ErrPermissionDeneid is returned if the client is not authorized
|
||||
// ErrPermissionDenied is returned if the client is not authorized
|
||||
ErrPermissionDenied = errors.New("permission denied")
|
||||
)
|
||||
|
|
|
@ -13,7 +13,7 @@ const (
|
|||
// avoided like the HTTPContentType. The value must be a byte slice.
|
||||
HTTPRawBody = "http_raw_body"
|
||||
|
||||
// HTTPStatusCode is the response code the HTTP body that goes with the HTTPContentType.
|
||||
// HTTPStatusCode is the response code of the HTTP body that goes with the HTTPContentType.
|
||||
// This can only be specified for non-secrets, and should should be similarly
|
||||
// avoided like the HTTPContentType. The value must be an integer.
|
||||
HTTPStatusCode = "http_status_code"
|
||||
|
|
|
@ -371,9 +371,21 @@ func (b *SystemBackend) handleMount(
|
|||
// Attempt mount
|
||||
if err := b.Core.mount(me); err != nil {
|
||||
b.Backend.Logger().Printf("[ERR] sys: mount %#v failed: %v", me, err)
|
||||
return handleError(err)
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// used to intercept an HTTPCodedError so it goes back to callee
|
||||
func handleError(
|
||||
err error) (*logical.Response, error) {
|
||||
switch err.(type) {
|
||||
case logical.HTTPCodedError:
|
||||
return logical.ErrorResponse(err.Error()), err
|
||||
default:
|
||||
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// handleUnmount is used to unmount a path
|
||||
|
@ -387,7 +399,7 @@ func (b *SystemBackend) handleUnmount(
|
|||
// Attempt unmount
|
||||
if err := b.Core.unmount(suffix); err != nil {
|
||||
b.Backend.Logger().Printf("[ERR] sys: unmount '%s' failed: %v", suffix, err)
|
||||
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
|
||||
return handleError(err)
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
|
@ -408,7 +420,7 @@ func (b *SystemBackend) handleRemount(
|
|||
// Attempt remount
|
||||
if err := b.Core.remount(fromPath, toPath); err != nil {
|
||||
b.Backend.Logger().Printf("[ERR] sys: remount '%s' to '%s' failed: %v", fromPath, toPath, err)
|
||||
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
|
||||
return handleError(err)
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
|
@ -428,7 +440,7 @@ func (b *SystemBackend) handleRenew(
|
|||
resp, err := b.Core.expiration.Renew(leaseID, increment)
|
||||
if err != nil {
|
||||
b.Backend.Logger().Printf("[ERR] sys: renew '%s' failed: %v", leaseID, err)
|
||||
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
|
||||
return handleError(err)
|
||||
}
|
||||
return resp, err
|
||||
}
|
||||
|
@ -442,7 +454,7 @@ func (b *SystemBackend) handleRevoke(
|
|||
// Invoke the expiration manager directly
|
||||
if err := b.Core.expiration.Revoke(leaseID); err != nil {
|
||||
b.Backend.Logger().Printf("[ERR] sys: revoke '%s' failed: %v", leaseID, err)
|
||||
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
|
||||
return handleError(err)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
@ -456,7 +468,7 @@ func (b *SystemBackend) handleRevokePrefix(
|
|||
// Invoke the expiration manager directly
|
||||
if err := b.Core.expiration.RevokePrefix(prefix); err != nil {
|
||||
b.Backend.Logger().Printf("[ERR] sys: revoke prefix '%s' failed: %v", prefix, err)
|
||||
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
|
||||
return handleError(err)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
@ -504,7 +516,7 @@ func (b *SystemBackend) handleEnableAuth(
|
|||
// Attempt enabling
|
||||
if err := b.Core.enableCredential(me); err != nil {
|
||||
b.Backend.Logger().Printf("[ERR] sys: enable auth %#v failed: %v", me, err)
|
||||
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
|
||||
return handleError(err)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
@ -520,7 +532,7 @@ func (b *SystemBackend) handleDisableAuth(
|
|||
// Attempt disable
|
||||
if err := b.Core.disableCredential(suffix); err != nil {
|
||||
b.Backend.Logger().Printf("[ERR] sys: disable auth '%s' failed: %v", suffix, err)
|
||||
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
|
||||
return handleError(err)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
@ -543,7 +555,7 @@ func (b *SystemBackend) handlePolicyRead(
|
|||
|
||||
policy, err := b.Core.policy.GetPolicy(name)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
|
||||
return handleError(err)
|
||||
}
|
||||
|
||||
if policy == nil {
|
||||
|
@ -567,7 +579,7 @@ func (b *SystemBackend) handlePolicySet(
|
|||
// Validate the rules parse
|
||||
parse, err := Parse(rules)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
|
||||
return handleError(err)
|
||||
}
|
||||
|
||||
// Override the name
|
||||
|
@ -575,7 +587,7 @@ func (b *SystemBackend) handlePolicySet(
|
|||
|
||||
// Update the policy
|
||||
if err := b.Core.policy.SetPolicy(parse); err != nil {
|
||||
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
|
||||
return handleError(err)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
@ -585,7 +597,7 @@ func (b *SystemBackend) handlePolicyDelete(
|
|||
req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
name := data.Get("name").(string)
|
||||
if err := b.Core.policy.DeletePolicy(name); err != nil {
|
||||
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
|
||||
return handleError(err)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
@ -640,7 +652,7 @@ func (b *SystemBackend) handleEnableAudit(
|
|||
// Attempt enabling
|
||||
if err := b.Core.enableAudit(me); err != nil {
|
||||
b.Backend.Logger().Printf("[ERR] sys: enable audit %#v failed: %v", me, err)
|
||||
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
|
||||
return handleError(err)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
@ -653,7 +665,7 @@ func (b *SystemBackend) handleDisableAudit(
|
|||
// Attempt disable
|
||||
if err := b.Core.disableAudit(path); err != nil {
|
||||
b.Backend.Logger().Printf("[ERR] sys: disable audit '%s' failed: %v", path, err)
|
||||
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
|
||||
return handleError(err)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
@ -673,7 +685,7 @@ func (b *SystemBackend) handleRawRead(
|
|||
|
||||
entry, err := b.Core.barrier.Get(path)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
|
||||
return handleError(err)
|
||||
}
|
||||
if entry == nil {
|
||||
return nil, nil
|
||||
|
@ -724,7 +736,7 @@ func (b *SystemBackend) handleRawDelete(
|
|||
}
|
||||
|
||||
if err := b.Core.barrier.Delete(path); err != nil {
|
||||
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
|
||||
return handleError(err)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
@ -754,7 +766,7 @@ func (b *SystemBackend) handleRotate(
|
|||
newTerm, err := b.Core.barrier.Rotate()
|
||||
if err != nil {
|
||||
b.Backend.Logger().Printf("[ERR] sys: failed to create new encryption key: %v", err)
|
||||
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
|
||||
return handleError(err)
|
||||
}
|
||||
b.Backend.Logger().Printf("[INFO] sys: installed new encryption key")
|
||||
|
||||
|
|
|
@ -22,7 +22,7 @@ const (
|
|||
// barrier view for the backends.
|
||||
backendBarrierPrefix = "logical/"
|
||||
|
||||
// systemBarrierPrefix is sthe prefix used for the
|
||||
// systemBarrierPrefix is the prefix used for the
|
||||
// system logical backend.
|
||||
systemBarrierPrefix = "sys/"
|
||||
)
|
||||
|
@ -139,16 +139,16 @@ func (c *Core) mount(me *MountEntry) error {
|
|||
me.Path += "/"
|
||||
}
|
||||
|
||||
// Prevent protected paths from being unmounted
|
||||
// Prevent protected paths from being mounted
|
||||
for _, p := range protectedMounts {
|
||||
if strings.HasPrefix(me.Path, p) {
|
||||
return fmt.Errorf("cannot mount '%s'", me.Path)
|
||||
return logical.CodedError(403, fmt.Sprintf("cannot mount '%s'", me.Path))
|
||||
}
|
||||
}
|
||||
|
||||
// Verify there is no conflicting mount
|
||||
if match := c.router.MatchingMount(me.Path); match != "" {
|
||||
return fmt.Errorf("existing mount at '%s'", match)
|
||||
return logical.CodedError(409, fmt.Sprintf("existing mount at %s", match))
|
||||
}
|
||||
|
||||
// Generate a new UUID and view
|
||||
|
|
Loading…
Reference in New Issue