First part of perf standby entity race fix (#6106)

This commit is contained in:
Jeff Mitchell 2019-01-25 14:08:42 -05:00 committed by GitHub
parent 1f57e3674a
commit e781ea3ac4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 29 additions and 6 deletions

View file

@ -561,7 +561,7 @@ func forwardRequest(core *vault.Core, w http.ResponseWriter, r *http.Request) {
// request is a helper to perform a request and properly exit in the // request is a helper to perform a request and properly exit in the
// case of an error. // case of an error.
func request(core *vault.Core, w http.ResponseWriter, rawReq *http.Request, r *logical.Request) (*logical.Response, bool) { func request(core *vault.Core, w http.ResponseWriter, rawReq *http.Request, r *logical.Request) (*logical.Response, bool, bool) {
resp, err := core.HandleRequest(rawReq.Context(), r) resp, err := core.HandleRequest(rawReq.Context(), r)
if r.LastRemoteWAL() > 0 && !vault.WaitUntilWALShipped(rawReq.Context(), core, r.LastRemoteWAL()) { if r.LastRemoteWAL() > 0 && !vault.WaitUntilWALShipped(rawReq.Context(), core, r.LastRemoteWAL()) {
if resp == nil { if resp == nil {
@ -571,14 +571,17 @@ func request(core *vault.Core, w http.ResponseWriter, rawReq *http.Request, r *l
} }
if errwrap.Contains(err, consts.ErrStandby.Error()) { if errwrap.Contains(err, consts.ErrStandby.Error()) {
respondStandby(core, w, rawReq.URL) respondStandby(core, w, rawReq.URL)
return resp, false return resp, false, false
}
if err != nil && errwrap.Contains(err, logical.ErrPerfStandbyPleaseForward.Error()) {
return nil, false, true
} }
if respondErrorCommon(w, r, resp, err) { if respondErrorCommon(w, r, resp, err) {
return resp, false return resp, false, false
} }
return resp, true return resp, true, false
} }
// respondStandby is used to trigger a redirect in the case that this Vault is currently a hot standby // respondStandby is used to trigger a redirect in the case that this Vault is currently a hot standby

View file

@ -208,7 +208,11 @@ func handleLogicalInternal(core *vault.Core, injectDataIntoTopLevel bool) http.H
// it. Vault core handles stripping this if we need to. This also // it. Vault core handles stripping this if we need to. This also
// handles all error cases; if we hit respondLogical, the request is a // handles all error cases; if we hit respondLogical, the request is a
// success. // success.
resp, ok := request(core, w, r, req) resp, ok, needsForward := request(core, w, r, req)
if needsForward {
forwardRequest(core, w, r)
return
}
if !ok { if !ok {
return return
} }

View file

@ -24,6 +24,10 @@ var (
// ErrUpstreamRateLimited is returned when Vault receives a rate limited // ErrUpstreamRateLimited is returned when Vault receives a rate limited
// response from an upstream // response from an upstream
ErrUpstreamRateLimited = errors.New("upstream rate limited") ErrUpstreamRateLimited = errors.New("upstream rate limited")
// ErrPerfStandbyForward is returned when Vault is in a state such that a
// perf standby cannot satisfy a request
ErrPerfStandbyPleaseForward = errors.New("please forward to the active node")
) )
type HTTPCodedError interface { type HTTPCodedError interface {

View file

@ -251,6 +251,9 @@ func (c *Core) checkToken(ctx context.Context, req *logical.Request, unauth bool
return nil, te, logical.ErrPermissionDenied return nil, te, logical.ErrPermissionDenied
} }
if te != nil && te.EntityID != "" && entity == nil { if te != nil && te.EntityID != "" && entity == nil {
if c.perfStandby {
return nil, nil, logical.ErrPerfStandbyPleaseForward
}
c.logger.Warn("permission denied as the entity on the token is invalid") c.logger.Warn("permission denied as the entity on the token is invalid")
return nil, te, logical.ErrPermissionDenied return nil, te, logical.ErrPermissionDenied
} }
@ -529,7 +532,13 @@ func (c *Core) handleRequest(ctx context.Context, req *logical.Request) (retResp
// Validate the token // Validate the token
auth, te, ctErr := c.checkToken(ctx, req, false) auth, te, ctErr := c.checkToken(ctx, req, false)
// We run this logic first because we want to decrement the use count even in the case of an error if ctErr == logical.ErrPerfStandbyPleaseForward {
return nil, nil, ctErr
}
// We run this logic first because we want to decrement the use count even
// in the case of an error (assuming we can successfully look up; if we
// need to forward, we exit before now)
if te != nil && !isControlGroupRun(req) { if te != nil && !isControlGroupRun(req) {
// Attempt to use the token (decrement NumUses) // Attempt to use the token (decrement NumUses)
var err error var err error
@ -854,6 +863,9 @@ func (c *Core) handleLoginRequest(ctx context.Context, req *logical.Request) (re
// Do an unauth check. This will cause EGP policies to be checked // Do an unauth check. This will cause EGP policies to be checked
var ctErr error var ctErr error
auth, _, ctErr = c.checkToken(ctx, req, true) auth, _, ctErr = c.checkToken(ctx, req, true)
if ctErr == logical.ErrPerfStandbyPleaseForward {
return nil, nil, ctErr
}
if ctErr != nil { if ctErr != nil {
// If it is an internal error we return that, otherwise we // If it is an internal error we return that, otherwise we
// return invalid request so that the status codes can be correct // return invalid request so that the status codes can be correct