diff --git a/http/handler.go b/http/handler.go index 987580935..b711574c0 100644 --- a/http/handler.go +++ b/http/handler.go @@ -665,18 +665,19 @@ func respondStandby(core *vault.Core, w http.ResponseWriter, reqURL *url.URL) { // Returns true if the token was sourced from a Bearer header. func getTokenFromReq(r *http.Request) (string, bool) { if token := r.Header.Get(consts.AuthHeaderName); token != "" { - r.Header.Del(consts.AuthHeaderName) return token, false } - if v := r.Header.Get("Authorization"); v != "" { + if headers, ok := r.Header["Authorization"]; ok { // Reference for Authorization header format: https://tools.ietf.org/html/rfc7236#section-3 // If string does not start by 'Bearer ', it is not one we would use, // but might be used by plugins - if !strings.HasPrefix(v, "Bearer ") { - return "", false + for _, v := range headers { + if !strings.HasPrefix(v, "Bearer ") { + continue + } + return strings.TrimSpace(v[7:]), true } - return strings.TrimSpace(v[7:]), true } return "", false } @@ -687,6 +688,10 @@ func requestAuth(core *vault.Core, r *http.Request, req *logical.Request) (*logi token, fromAuthzHeader := getTokenFromReq(r) if token != "" { req.ClientToken = token + req.ClientTokenSource = logical.ClientTokenFromVaultHeader + if fromAuthzHeader { + req.ClientTokenSource = logical.ClientTokenFromAuthzHeader + } // Also attach the accessor if we have it. This doesn't fail if it // doesn't exist because the request may be to an unauthenticated @@ -700,10 +705,6 @@ func requestAuth(core *vault.Core, r *http.Request, req *logical.Request) (*logi req.ClientTokenAccessor = te.Accessor req.ClientTokenRemainingUses = te.NumUses req.SetTokenEntry(te) - if fromAuthzHeader { - // This was a valid token in an authz header - r.Header.Del("Authorization") - } } } diff --git a/http/handler_test.go b/http/handler_test.go index a0a9e046b..244eb183e 100644 --- a/http/handler_test.go +++ b/http/handler_test.go @@ -610,8 +610,6 @@ func TestHandler_getTokenFromReq(t *testing.T) { t.Fatalf("%s header should be prioritized", consts.AuthHeaderName) } else if tok != "NEWTOKEN" { t.Fatalf("expected 'NEWTOKEN' as result, got '%s'", tok) - } else if r.Header.Get(consts.AuthHeaderName) != "" { - t.Fatal("expected auth header to be removed") } r.Header = http.Header{} diff --git a/logical/request.go b/logical/request.go index 8380270df..f5156adc9 100644 --- a/logical/request.go +++ b/logical/request.go @@ -43,6 +43,14 @@ func (r *RequestWrapInfo) SentinelKeys() []string { } } +type ClientTokenSource uint32 + +const ( + NoClientToken ClientTokenSource = iota + ClientTokenFromVaultHeader + ClientTokenFromAuthzHeader +) + // Request is a struct that stores the parameters and context of a request // being made to Vault. It is used to abstract the details of the higher level // request protocol from the handlers. @@ -157,6 +165,10 @@ type Request struct { // For replication, contains the last WAL on the remote side after handling // the request, used for best-effort avoidance of stale read-after-write lastRemoteWAL uint64 + + // ClientTokenSource tells us where the client token was sourced from, so + // we can delete it before sending off to plugins + ClientTokenSource ClientTokenSource } // Get returns a data field and guards for nil Data diff --git a/vault/request_handling.go b/vault/request_handling.go index a46921a08..146b72aaf 100644 --- a/vault/request_handling.go +++ b/vault/request_handling.go @@ -265,6 +265,24 @@ func (c *Core) checkToken(ctx context.Context, req *logical.Request, unauth bool return nil, nil, errors.New("cannot access root path in unauthenticated request") } + // At this point we won't be forwarding a raw request; we should delete + // authorization headers as appropriate + switch req.ClientTokenSource { + case logical.ClientTokenFromVaultHeader: + delete(req.Headers, consts.AuthHeaderName) + case logical.ClientTokenFromAuthzHeader: + if headers, ok := req.Headers["Authorization"]; ok { + retHeaders := make([]string, 0, len(headers)) + for _, v := range headers { + if strings.HasPrefix(v, "Bearer ") { + continue + } + retHeaders = append(retHeaders, v) + } + req.Headers["Authorization"] = retHeaders + } + } + // When we receive a write of either type, rather than require clients to // PUT/POST and trust the operation, we ask the backend to give us the real // skinny -- if the backend implements an existence check, it can tell us