Merge pull request #510 from ctennis/more_descriptive_errors
More descriptive errors with specific HTTP return codes
This commit is contained in:
commit
4abc488cec
|
@ -156,6 +156,11 @@ func respondError(w http.ResponseWriter, status int, err error) {
|
||||||
status = http.StatusServiceUnavailable
|
status = http.StatusServiceUnavailable
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Allow HTTPCoded error passthrough to specify a code
|
||||||
|
if t, ok := err.(logical.HTTPCodedError); ok {
|
||||||
|
status = t.Code()
|
||||||
|
}
|
||||||
|
|
||||||
w.Header().Add("Content-Type", "application/json")
|
w.Header().Add("Content-Type", "application/json")
|
||||||
w.WriteHeader(status)
|
w.WriteHeader(status)
|
||||||
|
|
||||||
|
|
|
@ -1,10 +1,13 @@
|
||||||
package http
|
package http
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/hashicorp/vault/logical"
|
||||||
"github.com/hashicorp/vault/vault"
|
"github.com/hashicorp/vault/vault"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -57,3 +60,34 @@ func TestHandler_sealed(t *testing.T) {
|
||||||
}
|
}
|
||||||
testResponseStatus(t, resp, 503)
|
testResponseStatus(t, resp, 503)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestHandler_error(t *testing.T) {
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
respondError(w, 500, errors.New("Test Error"))
|
||||||
|
|
||||||
|
if w.Code != 500 {
|
||||||
|
t.Fatalf("expected 500, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
// The code inside of the error should override
|
||||||
|
// the argument to respondError
|
||||||
|
w2 := httptest.NewRecorder()
|
||||||
|
e := logical.CodedError(403, "error text")
|
||||||
|
|
||||||
|
respondError(w2, 500, e)
|
||||||
|
|
||||||
|
if w2.Code != 403 {
|
||||||
|
t.Fatalf("expected 403, got %d", w2.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
// vault.ErrSealed is a special case
|
||||||
|
w3 := httptest.NewRecorder()
|
||||||
|
|
||||||
|
respondError(w3, 400, vault.ErrSealed)
|
||||||
|
|
||||||
|
if w3.Code != 503 {
|
||||||
|
t.Fatalf("expected 503, got %d", w3.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
|
@ -131,6 +131,7 @@ func handleSysMount(
|
||||||
"description": req.Description,
|
"description": req.Description,
|
||||||
},
|
},
|
||||||
}))
|
}))
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
respondError(w, http.StatusInternalServerError, err)
|
respondError(w, http.StatusInternalServerError, err)
|
||||||
return
|
return
|
||||||
|
@ -149,6 +150,7 @@ func handleSysUnmount(
|
||||||
Path: "sys/mounts/" + path,
|
Path: "sys/mounts/" + path,
|
||||||
Connection: getConnection(r),
|
Connection: getConnection(r),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
respondError(w, http.StatusInternalServerError, err)
|
respondError(w, http.StatusInternalServerError, err)
|
||||||
return
|
return
|
||||||
|
|
24
logical/error.go
Normal file
24
logical/error.go
Normal file
|
@ -0,0 +1,24 @@
|
||||||
|
package logical
|
||||||
|
|
||||||
|
type HTTPCodedError interface {
|
||||||
|
Error() string
|
||||||
|
Code() int
|
||||||
|
}
|
||||||
|
|
||||||
|
func CodedError(c int, s string) HTTPCodedError {
|
||||||
|
return &codedError{s,c}
|
||||||
|
}
|
||||||
|
|
||||||
|
type codedError struct {
|
||||||
|
s string
|
||||||
|
code int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *codedError) Error() string {
|
||||||
|
return e.s
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *codedError) Code() int {
|
||||||
|
return e.code
|
||||||
|
}
|
||||||
|
|
|
@ -145,6 +145,6 @@ var (
|
||||||
// ErrInvalidRequest is returned if the request is invalid
|
// ErrInvalidRequest is returned if the request is invalid
|
||||||
ErrInvalidRequest = errors.New("invalid request")
|
ErrInvalidRequest = errors.New("invalid request")
|
||||||
|
|
||||||
// ErrPermissionDeneid is returned if the client is not authorized
|
// ErrPermissionDenied is returned if the client is not authorized
|
||||||
ErrPermissionDenied = errors.New("permission denied")
|
ErrPermissionDenied = errors.New("permission denied")
|
||||||
)
|
)
|
||||||
|
|
|
@ -13,7 +13,7 @@ const (
|
||||||
// avoided like the HTTPContentType. The value must be a byte slice.
|
// avoided like the HTTPContentType. The value must be a byte slice.
|
||||||
HTTPRawBody = "http_raw_body"
|
HTTPRawBody = "http_raw_body"
|
||||||
|
|
||||||
// HTTPStatusCode is the response code the HTTP body that goes with the HTTPContentType.
|
// HTTPStatusCode is the response code of the HTTP body that goes with the HTTPContentType.
|
||||||
// This can only be specified for non-secrets, and should should be similarly
|
// This can only be specified for non-secrets, and should should be similarly
|
||||||
// avoided like the HTTPContentType. The value must be an integer.
|
// avoided like the HTTPContentType. The value must be an integer.
|
||||||
HTTPStatusCode = "http_status_code"
|
HTTPStatusCode = "http_status_code"
|
||||||
|
|
|
@ -371,9 +371,21 @@ func (b *SystemBackend) handleMount(
|
||||||
// Attempt mount
|
// Attempt mount
|
||||||
if err := b.Core.mount(me); err != nil {
|
if err := b.Core.mount(me); err != nil {
|
||||||
b.Backend.Logger().Printf("[ERR] sys: mount %#v failed: %v", me, err)
|
b.Backend.Logger().Printf("[ERR] sys: mount %#v failed: %v", me, err)
|
||||||
|
return handleError(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// used to intercept an HTTPCodedError so it goes back to callee
|
||||||
|
func handleError(
|
||||||
|
err error) (*logical.Response, error) {
|
||||||
|
switch err.(type) {
|
||||||
|
case logical.HTTPCodedError:
|
||||||
|
return logical.ErrorResponse(err.Error()), err
|
||||||
|
default:
|
||||||
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
|
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
|
||||||
}
|
}
|
||||||
return nil, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleUnmount is used to unmount a path
|
// handleUnmount is used to unmount a path
|
||||||
|
@ -387,7 +399,7 @@ func (b *SystemBackend) handleUnmount(
|
||||||
// Attempt unmount
|
// Attempt unmount
|
||||||
if err := b.Core.unmount(suffix); err != nil {
|
if err := b.Core.unmount(suffix); err != nil {
|
||||||
b.Backend.Logger().Printf("[ERR] sys: unmount '%s' failed: %v", suffix, err)
|
b.Backend.Logger().Printf("[ERR] sys: unmount '%s' failed: %v", suffix, err)
|
||||||
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
|
return handleError(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, nil
|
return nil, nil
|
||||||
|
@ -408,7 +420,7 @@ func (b *SystemBackend) handleRemount(
|
||||||
// Attempt remount
|
// Attempt remount
|
||||||
if err := b.Core.remount(fromPath, toPath); err != nil {
|
if err := b.Core.remount(fromPath, toPath); err != nil {
|
||||||
b.Backend.Logger().Printf("[ERR] sys: remount '%s' to '%s' failed: %v", fromPath, toPath, err)
|
b.Backend.Logger().Printf("[ERR] sys: remount '%s' to '%s' failed: %v", fromPath, toPath, err)
|
||||||
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
|
return handleError(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, nil
|
return nil, nil
|
||||||
|
@ -428,7 +440,7 @@ func (b *SystemBackend) handleRenew(
|
||||||
resp, err := b.Core.expiration.Renew(leaseID, increment)
|
resp, err := b.Core.expiration.Renew(leaseID, increment)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
b.Backend.Logger().Printf("[ERR] sys: renew '%s' failed: %v", leaseID, err)
|
b.Backend.Logger().Printf("[ERR] sys: renew '%s' failed: %v", leaseID, err)
|
||||||
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
|
return handleError(err)
|
||||||
}
|
}
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
|
@ -442,7 +454,7 @@ func (b *SystemBackend) handleRevoke(
|
||||||
// Invoke the expiration manager directly
|
// Invoke the expiration manager directly
|
||||||
if err := b.Core.expiration.Revoke(leaseID); err != nil {
|
if err := b.Core.expiration.Revoke(leaseID); err != nil {
|
||||||
b.Backend.Logger().Printf("[ERR] sys: revoke '%s' failed: %v", leaseID, err)
|
b.Backend.Logger().Printf("[ERR] sys: revoke '%s' failed: %v", leaseID, err)
|
||||||
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
|
return handleError(err)
|
||||||
}
|
}
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
@ -456,7 +468,7 @@ func (b *SystemBackend) handleRevokePrefix(
|
||||||
// Invoke the expiration manager directly
|
// Invoke the expiration manager directly
|
||||||
if err := b.Core.expiration.RevokePrefix(prefix); err != nil {
|
if err := b.Core.expiration.RevokePrefix(prefix); err != nil {
|
||||||
b.Backend.Logger().Printf("[ERR] sys: revoke prefix '%s' failed: %v", prefix, err)
|
b.Backend.Logger().Printf("[ERR] sys: revoke prefix '%s' failed: %v", prefix, err)
|
||||||
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
|
return handleError(err)
|
||||||
}
|
}
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
@ -504,7 +516,7 @@ func (b *SystemBackend) handleEnableAuth(
|
||||||
// Attempt enabling
|
// Attempt enabling
|
||||||
if err := b.Core.enableCredential(me); err != nil {
|
if err := b.Core.enableCredential(me); err != nil {
|
||||||
b.Backend.Logger().Printf("[ERR] sys: enable auth %#v failed: %v", me, err)
|
b.Backend.Logger().Printf("[ERR] sys: enable auth %#v failed: %v", me, err)
|
||||||
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
|
return handleError(err)
|
||||||
}
|
}
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
@ -520,7 +532,7 @@ func (b *SystemBackend) handleDisableAuth(
|
||||||
// Attempt disable
|
// Attempt disable
|
||||||
if err := b.Core.disableCredential(suffix); err != nil {
|
if err := b.Core.disableCredential(suffix); err != nil {
|
||||||
b.Backend.Logger().Printf("[ERR] sys: disable auth '%s' failed: %v", suffix, err)
|
b.Backend.Logger().Printf("[ERR] sys: disable auth '%s' failed: %v", suffix, err)
|
||||||
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
|
return handleError(err)
|
||||||
}
|
}
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
@ -543,7 +555,7 @@ func (b *SystemBackend) handlePolicyRead(
|
||||||
|
|
||||||
policy, err := b.Core.policy.GetPolicy(name)
|
policy, err := b.Core.policy.GetPolicy(name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
|
return handleError(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if policy == nil {
|
if policy == nil {
|
||||||
|
@ -567,7 +579,7 @@ func (b *SystemBackend) handlePolicySet(
|
||||||
// Validate the rules parse
|
// Validate the rules parse
|
||||||
parse, err := Parse(rules)
|
parse, err := Parse(rules)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
|
return handleError(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Override the name
|
// Override the name
|
||||||
|
@ -575,7 +587,7 @@ func (b *SystemBackend) handlePolicySet(
|
||||||
|
|
||||||
// Update the policy
|
// Update the policy
|
||||||
if err := b.Core.policy.SetPolicy(parse); err != nil {
|
if err := b.Core.policy.SetPolicy(parse); err != nil {
|
||||||
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
|
return handleError(err)
|
||||||
}
|
}
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
@ -585,7 +597,7 @@ func (b *SystemBackend) handlePolicyDelete(
|
||||||
req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||||
name := data.Get("name").(string)
|
name := data.Get("name").(string)
|
||||||
if err := b.Core.policy.DeletePolicy(name); err != nil {
|
if err := b.Core.policy.DeletePolicy(name); err != nil {
|
||||||
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
|
return handleError(err)
|
||||||
}
|
}
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
@ -640,7 +652,7 @@ func (b *SystemBackend) handleEnableAudit(
|
||||||
// Attempt enabling
|
// Attempt enabling
|
||||||
if err := b.Core.enableAudit(me); err != nil {
|
if err := b.Core.enableAudit(me); err != nil {
|
||||||
b.Backend.Logger().Printf("[ERR] sys: enable audit %#v failed: %v", me, err)
|
b.Backend.Logger().Printf("[ERR] sys: enable audit %#v failed: %v", me, err)
|
||||||
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
|
return handleError(err)
|
||||||
}
|
}
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
@ -653,7 +665,7 @@ func (b *SystemBackend) handleDisableAudit(
|
||||||
// Attempt disable
|
// Attempt disable
|
||||||
if err := b.Core.disableAudit(path); err != nil {
|
if err := b.Core.disableAudit(path); err != nil {
|
||||||
b.Backend.Logger().Printf("[ERR] sys: disable audit '%s' failed: %v", path, err)
|
b.Backend.Logger().Printf("[ERR] sys: disable audit '%s' failed: %v", path, err)
|
||||||
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
|
return handleError(err)
|
||||||
}
|
}
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
@ -673,7 +685,7 @@ func (b *SystemBackend) handleRawRead(
|
||||||
|
|
||||||
entry, err := b.Core.barrier.Get(path)
|
entry, err := b.Core.barrier.Get(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
|
return handleError(err)
|
||||||
}
|
}
|
||||||
if entry == nil {
|
if entry == nil {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
|
@ -724,7 +736,7 @@ func (b *SystemBackend) handleRawDelete(
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := b.Core.barrier.Delete(path); err != nil {
|
if err := b.Core.barrier.Delete(path); err != nil {
|
||||||
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
|
return handleError(err)
|
||||||
}
|
}
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
@ -754,7 +766,7 @@ func (b *SystemBackend) handleRotate(
|
||||||
newTerm, err := b.Core.barrier.Rotate()
|
newTerm, err := b.Core.barrier.Rotate()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
b.Backend.Logger().Printf("[ERR] sys: failed to create new encryption key: %v", err)
|
b.Backend.Logger().Printf("[ERR] sys: failed to create new encryption key: %v", err)
|
||||||
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
|
return handleError(err)
|
||||||
}
|
}
|
||||||
b.Backend.Logger().Printf("[INFO] sys: installed new encryption key")
|
b.Backend.Logger().Printf("[INFO] sys: installed new encryption key")
|
||||||
|
|
||||||
|
|
|
@ -22,7 +22,7 @@ const (
|
||||||
// barrier view for the backends.
|
// barrier view for the backends.
|
||||||
backendBarrierPrefix = "logical/"
|
backendBarrierPrefix = "logical/"
|
||||||
|
|
||||||
// systemBarrierPrefix is sthe prefix used for the
|
// systemBarrierPrefix is the prefix used for the
|
||||||
// system logical backend.
|
// system logical backend.
|
||||||
systemBarrierPrefix = "sys/"
|
systemBarrierPrefix = "sys/"
|
||||||
)
|
)
|
||||||
|
@ -139,16 +139,16 @@ func (c *Core) mount(me *MountEntry) error {
|
||||||
me.Path += "/"
|
me.Path += "/"
|
||||||
}
|
}
|
||||||
|
|
||||||
// Prevent protected paths from being unmounted
|
// Prevent protected paths from being mounted
|
||||||
for _, p := range protectedMounts {
|
for _, p := range protectedMounts {
|
||||||
if strings.HasPrefix(me.Path, p) {
|
if strings.HasPrefix(me.Path, p) {
|
||||||
return fmt.Errorf("cannot mount '%s'", me.Path)
|
return logical.CodedError(403, fmt.Sprintf("cannot mount '%s'", me.Path))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify there is no conflicting mount
|
// Verify there is no conflicting mount
|
||||||
if match := c.router.MatchingMount(me.Path); match != "" {
|
if match := c.router.MatchingMount(me.Path); match != "" {
|
||||||
return fmt.Errorf("existing mount at '%s'", match)
|
return logical.CodedError(409, fmt.Sprintf("existing mount at %s", match))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate a new UUID and view
|
// Generate a new UUID and view
|
||||||
|
|
Loading…
Reference in a new issue