// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0

package vault

import (
	"context"
	"fmt"
	"regexp"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"github.com/armon/go-metrics"
	"github.com/armon/go-radix"
	"github.com/hashicorp/go-hclog"
	"github.com/hashicorp/go-secure-stdlib/strutil"
	"github.com/hashicorp/vault/helper/namespace"
	"github.com/hashicorp/vault/sdk/helper/consts"
	"github.com/hashicorp/vault/sdk/helper/salt"
	"github.com/hashicorp/vault/sdk/logical"
)

// deniedPassthroughRequestHeaders lists request headers that routeCommon
// never forwards to a backend, even when the mount's
// "passthrough_request_headers" configuration names them.
var deniedPassthroughRequestHeaders = []string{
	consts.AuthHeaderName,
}

// matches when '+' is next to a non-slash char
var wcAdjacentNonSlashRegEx = regexp.MustCompile(`\+[^/]|[^/]\+`).MatchString

// Router is used to do prefix based routing of a request to a logical backend
type Router struct {
	// l guards the four radix trees below.
	l sync.RWMutex
	// root maps namespace path + mount prefix to the mount's *routeEntry.
	root *radix.Tree
	// mountUUIDCache maps a mount UUID to its *MountEntry.
	mountUUIDCache *radix.Tree
	// mountAccessorCache maps a mount accessor to its *MountEntry.
	mountAccessorCache *radix.Tree
	// tokenStoreSaltFunc supplies the salt routeCommon uses to double-salt
	// client tokens that are sent to the cubbyhole backend.
	tokenStoreSaltFunc func(context.Context) (*salt.Salt, error)
	// storagePrefix maps the prefix used for storage (ala the BarrierView)
	// to the backend. This is used to map a key back into the backend that owns it.
	// For example, logical/uuid1/foobar -> secrets/ (kv backend) + foobar
	storagePrefix *radix.Tree
	logger        hclog.Logger
}

// NewRouter returns a new router
func NewRouter() *Router {
	r := &Router{
		root:               radix.New(),
		storagePrefix:      radix.New(),
		mountUUIDCache:     radix.New(),
		mountAccessorCache: radix.New(),
		// this will get replaced in production with a real logger but it's useful to have a default in place for tests
		logger: hclog.NewNullLogger(),
	}
	return r
}

// routeEntry is used to represent a mount point in the router
type routeEntry struct {
	// tainted, when set, makes routeCommon reject every operation except
	// revoke and rollback.
	tainted bool
	// backend handles requests for the mount; routeCommon treats a nil
	// backend as a filtered mount and refuses to route to it.
	backend       logical.Backend
	mountEntry    *MountEntry
	storageView   logical.Storage
	storagePrefix string
	// rootPaths holds a *radix.Tree built by pathsToRadix.
	rootPaths atomic.Value
	// loginPaths holds a *loginPathsEntry built by parseUnauthenticatedPaths.
	loginPaths atomic.Value
	// l is read-locked by routeCommon for the duration of a request.
	l sync.RWMutex
}

// wildcardPath is an unauthenticated path pattern that contains '+'
// wildcards, stored already split on '/'.
type wildcardPath struct {
	// this sits in the hot path of requests so we are micro-optimizing by
	// storing pre-split slices of path segments
	segments []string
	// isPrefix records that the original pattern ended in '*'.
	isPrefix bool
}

// loginPathsEntry is used to hold the routeEntry loginPaths
type loginPathsEntry struct {
	// paths holds the patterns without '+'; each value is a bool that is
	// true for a prefix match and false for an exact match.
	paths *radix.Tree
	// wildcardPaths holds the patterns that contain '+'.
	wildcardPaths []wildcardPath
}

// ValidateMountResponse describes the mount found for an accessor by
// ValidateMountByAccessor.
type ValidateMountResponse struct {
	MountType     string `json:"mount_type" structs:"mount_type" mapstructure:"mount_type"`
	MountAccessor string `json:"mount_accessor" structs:"mount_accessor" mapstructure:"mount_accessor"`
	MountPath     string `json:"mount_path" structs:"mount_path" mapstructure:"mount_path"`
	MountLocal    bool   `json:"mount_local" structs:"mount_local" mapstructure:"mount_local"`
}

// reset replaces all four routing trees with empty ones while holding the
// router's write lock.
func (r *Router) reset() {
	r.l.Lock()
	defer r.l.Unlock()
	r.root = radix.New()
	r.storagePrefix = radix.New()
	r.mountUUIDCache = radix.New()
	r.mountAccessorCache = radix.New()
}

// GetRecords returns the deserialized form of every entry in the tree
// selected by tag: "root", "uuid", "accessor" or "storage". Any other tag
// yields logical.ErrUnsupportedPath. The result is nil when the tree is empty.
func (r *Router) GetRecords(tag string) ([]map[string]interface{}, error) {
	r.l.RLock()
	defer r.l.RUnlock()
	var data []map[string]interface{}
	var tree *radix.Tree
	switch tag {
	case "root":
		tree = r.root
	case "uuid":
		tree = r.mountUUIDCache
	case "accessor":
		tree = r.mountAccessorCache
	case "storage":
		tree = r.storagePrefix
	default:
		return nil, logical.ErrUnsupportedPath
	}
	for _, v := range tree.ToMap() {
		// Unchecked assertion: every value stored in these trees is expected
		// to implement Deserializable, otherwise this panics.
		info := v.(Deserializable).Deserialize()
		data = append(data, info)
	}
	return data, nil
}

// Deserialize returns the entry's tainted flag and storage prefix merged with
// the fields reported by its mount entry. Keys from the mount entry overwrite
// the entry's own keys on collision.
func (entry *routeEntry) Deserialize() map[string]interface{} {
	entry.l.RLock()
	defer entry.l.RUnlock()
	ret := map[string]interface{}{
		"tainted":        entry.tainted,
		"storage_prefix": entry.storagePrefix,
	}
	for k, v := range entry.mountEntry.Deserialize() {
		ret[k] = v
	}
	return ret
}

// ValidateMountByAccessor returns the mount type, accessor, path and local
// flag for a given mount accessor. It returns nil when the accessor is empty
// or no mount matches. Paths of mounts in the credential table are prefixed
// with credentialRoutePrefix.
func (r *Router) ValidateMountByAccessor(accessor string) *ValidateMountResponse {
	if accessor == "" {
		return nil
	}

	mountEntry := r.MatchingMountByAccessor(accessor)
	if mountEntry == nil {
		return nil
	}

	mountPath := mountEntry.Path
	if mountEntry.Table == credentialTableType {
		mountPath = credentialRoutePrefix + mountPath
	}

	return &ValidateMountResponse{
		MountAccessor: mountEntry.Accessor,
		MountType:     mountEntry.Type,
		MountPath:     mountPath,
		MountLocal:    mountEntry.Local,
	}
}

// SaltID is used to apply a salt and hash to an ID to make sure it's not
// reversible. The mount's UUID is the salt and the hash is SHA1.
func (re *routeEntry) SaltID(id string) string {
	return salt.SaltID(re.mountEntry.UUID, id, salt.SHA1Hash)
}

// Mount is used to expose a logical backend at a given prefix, using a unique salt,
// and the barrier view for that path.
func (r *Router) Mount(backend logical.Backend, prefix string, mountEntry *MountEntry, storageView *BarrierView) error { r.l.Lock() defer r.l.Unlock() // prepend namespace prefix = mountEntry.Namespace().Path + prefix // Check if this is a nested mount if existing, _, ok := r.root.LongestPrefix(prefix); ok && existing != "" { return fmt.Errorf("cannot mount under existing mount %q", existing) } // Build the paths paths := new(logical.Paths) if backend != nil { specialPaths := backend.SpecialPaths() if specialPaths != nil { paths = specialPaths } } // Create a mount entry re := &routeEntry{ tainted: mountEntry.Tainted, backend: backend, mountEntry: mountEntry, storagePrefix: storageView.Prefix(), storageView: storageView, } re.rootPaths.Store(pathsToRadix(paths.Root)) loginPathsEntry, err := parseUnauthenticatedPaths(paths.Unauthenticated) if err != nil { return err } re.loginPaths.Store(loginPathsEntry) switch { case prefix == "": return fmt.Errorf("missing prefix to be used for router entry; mount_path: %q, mount_type: %q", re.mountEntry.Path, re.mountEntry.Type) case re.storagePrefix == "": return fmt.Errorf("missing storage view prefix; mount_path: %q, mount_type: %q", re.mountEntry.Path, re.mountEntry.Type) case re.mountEntry.UUID == "": return fmt.Errorf("missing mount identifier; mount_path: %q, mount_type: %q", re.mountEntry.Path, re.mountEntry.Type) case re.mountEntry.Accessor == "": return fmt.Errorf("missing mount accessor; mount_path: %q, mount_type: %q", re.mountEntry.Path, re.mountEntry.Type) } r.root.Insert(prefix, re) r.storagePrefix.Insert(re.storagePrefix, re) r.mountUUIDCache.Insert(re.mountEntry.UUID, re.mountEntry) r.mountAccessorCache.Insert(re.mountEntry.Accessor, re.mountEntry) return nil } // Unmount is used to remove a logical backend from a given prefix func (r *Router) Unmount(ctx context.Context, prefix string) error { ns, err := namespace.FromContext(ctx) if err != nil { return err } prefix = ns.Path + prefix r.l.Lock() defer 
r.l.Unlock() // Fast-path out if the backend doesn't exist raw, ok := r.root.Get(prefix) if !ok { return nil } // Call backend's Cleanup routine re := raw.(*routeEntry) if re.backend != nil { re.backend.Cleanup(ctx) } // Purge from the radix trees r.root.Delete(prefix) r.storagePrefix.Delete(re.storagePrefix) r.mountUUIDCache.Delete(re.mountEntry.UUID) r.mountAccessorCache.Delete(re.mountEntry.Accessor) return nil } // Remount is used to change the mount location of a logical backend func (r *Router) Remount(ctx context.Context, src, dst string) error { ns, err := namespace.FromContext(ctx) if err != nil { return err } src = ns.Path + src dst = ns.Path + dst r.l.Lock() defer r.l.Unlock() // Check for existing mount raw, ok := r.root.Get(src) if !ok { return fmt.Errorf("no mount at %q", src) } // Update the mount point r.root.Delete(src) r.root.Insert(dst, raw) return nil } // Taint is used to mark a path as tainted. This means only RollbackOperation // RevokeOperation requests are allowed to proceed func (r *Router) Taint(ctx context.Context, path string) error { ns, err := namespace.FromContext(ctx) if err != nil { return err } path = ns.Path + path r.l.Lock() defer r.l.Unlock() _, raw, ok := r.root.LongestPrefix(path) if ok { raw.(*routeEntry).tainted = true } return nil } // Untaint is used to unmark a path as tainted. 
// It clears the tainted flag on the route entry whose prefix is the longest
// match for the namespaced path, and does nothing when no mount matches.
func (r *Router) Untaint(ctx context.Context, path string) error {
	ns, err := namespace.FromContext(ctx)
	if err != nil {
		return err
	}
	path = ns.Path + path

	r.l.Lock()
	defer r.l.Unlock()
	_, raw, ok := r.root.LongestPrefix(path)
	if ok {
		raw.(*routeEntry).tainted = false
	}
	return nil
}

// MatchingMountByUUID returns the MountEntry stored under the longest prefix
// of mountID in the UUID cache, or nil when mountID is empty or nothing
// matches.
func (r *Router) MatchingMountByUUID(mountID string) *MountEntry {
	if mountID == "" {
		return nil
	}

	r.l.RLock()

	_, raw, ok := r.mountUUIDCache.LongestPrefix(mountID)
	if !ok {
		r.l.RUnlock()
		return nil
	}

	r.l.RUnlock()
	return raw.(*MountEntry)
}

// MatchingMountByAccessor returns the MountEntry by accessor lookup. The
// lookup is a longest-prefix match in the accessor cache; nil is returned
// when mountAccessor is empty or nothing matches.
func (r *Router) MatchingMountByAccessor(mountAccessor string) *MountEntry {
	if mountAccessor == "" {
		return nil
	}

	r.l.RLock()

	_, raw, ok := r.mountAccessorCache.LongestPrefix(mountAccessor)
	if !ok {
		r.l.RUnlock()
		return nil
	}

	r.l.RUnlock()
	return raw.(*MountEntry)
}

// MatchingMount returns the mount prefix that would be used for a path
func (r *Router) MatchingMount(ctx context.Context, path string) string {
	r.l.RLock()
	mount := r.matchingMountInternal(ctx, path)
	r.l.RUnlock()
	return mount
}

// matchingMountInternal returns the longest registered mount prefix of the
// namespaced path, or "" when there is none or ctx carries no namespace. It
// takes no lock itself; the callers in this file hold r.l.
func (r *Router) matchingMountInternal(ctx context.Context, path string) string {
	ns, err := namespace.FromContext(ctx)
	if err != nil {
		return ""
	}
	path = ns.Path + path

	mount, _, ok := r.root.LongestPrefix(path)
	if !ok {
		return ""
	}
	return mount
}

// matchingPrefixInternal returns a mount prefix that a path may be a part of,
// i.e. the first registered mount visited by WalkPrefix that begins with the
// namespaced path, or "" when there is none. It takes no lock itself.
func (r *Router) matchingPrefixInternal(ctx context.Context, path string) string {
	ns, err := namespace.FromContext(ctx)
	if err != nil {
		return ""
	}
	path = ns.Path + path

	var existing string
	fn := func(existingPath string, v interface{}) bool {
		if strings.HasPrefix(existingPath, path) {
			existing = existingPath
			// Returning true stops the walk at the first hit.
			return true
		}
		return false
	}
	r.root.WalkPrefix(path, fn)
	return existing
}

// MountConflict determines if there are potential path conflicts. It returns
// an existing mount that is a prefix of path, otherwise an existing mount
// that path is a prefix of, otherwise "".
func (r *Router) MountConflict(ctx context.Context, path string) string {
	r.l.RLock()
	defer r.l.RUnlock()
	if exactMatch := r.matchingMountInternal(ctx, path); exactMatch != "" {
		return exactMatch
	}
	if prefixMatch := r.matchingPrefixInternal(ctx, path); prefixMatch != "" {
		return prefixMatch
	}
	return ""
}

// MatchingStorageByAPIPath/StoragePath returns the storage used for
// API/Storage paths respectively
func (r *Router) MatchingStorageByAPIPath(ctx context.Context, path string) logical.Storage {
	return r.matchingStorage(ctx, path, true)
}

// MatchingStorageByStoragePath returns the storage view of the mount whose
// storage prefix is the longest match for the given path.
func (r *Router) MatchingStorageByStoragePath(ctx context.Context, path string) logical.Storage {
	return r.matchingStorage(ctx, path, false)
}

// matchingStorage looks the namespaced path up in the root tree (apiPath
// true) or the storage-prefix tree (apiPath false) and returns the matching
// entry's storage view, or nil when nothing matches or ctx carries no
// namespace.
func (r *Router) matchingStorage(ctx context.Context, path string, apiPath bool) logical.Storage {
	ns, err := namespace.FromContext(ctx)
	if err != nil {
		return nil
	}
	path = ns.Path + path

	var raw interface{}
	var ok bool
	r.l.RLock()
	if apiPath {
		_, raw, ok = r.root.LongestPrefix(path)
	} else {
		_, raw, ok = r.storagePrefix.LongestPrefix(path)
	}
	r.l.RUnlock()
	if !ok {
		return nil
	}
	return raw.(*routeEntry).storageView
}

// MatchingMountEntry returns the MountEntry used for a path
func (r *Router) MatchingMountEntry(ctx context.Context, path string) *MountEntry {
	ns, err := namespace.FromContext(ctx)
	if err != nil {
		return nil
	}
	path = ns.Path + path

	r.l.RLock()
	_, raw, ok := r.root.LongestPrefix(path)
	r.l.RUnlock()
	if !ok {
		return nil
	}
	return raw.(*routeEntry).mountEntry
}

// MatchingBackend returns the backend used for a path. The backend field is
// read under the route entry's read lock.
func (r *Router) MatchingBackend(ctx context.Context, path string) logical.Backend {
	ns, err := namespace.FromContext(ctx)
	if err != nil {
		return nil
	}
	path = ns.Path + path

	r.l.RLock()
	_, raw, ok := r.root.LongestPrefix(path)
	r.l.RUnlock()
	if !ok {
		return nil
	}

	re := raw.(*routeEntry)
	re.l.RLock()
	defer re.l.RUnlock()

	return re.backend
}

// MatchingSystemView returns the SystemView used for a path, or nil when no
// mount matches or the matching entry has no backend.
func (r *Router) MatchingSystemView(ctx context.Context, path string) logical.SystemView {
	ns, err := namespace.FromContext(ctx)
	if err != nil {
		return nil
	}
	path = ns.Path + path

	r.l.RLock()
	_, raw, ok := r.root.LongestPrefix(path)
	r.l.RUnlock()
	if !ok || raw.(*routeEntry).backend == nil {
		return nil
	}
	return raw.(*routeEntry).backend.System()
}

// MatchingMountByAPIPath returns the Path of the mount entry that owns the
// given API path, or "" when no mount matches.
//
// NOTE(review): unlike MatchingStoragePrefixByAPIPath, the path is not
// prefixed with the namespace from ctx before the lookup — confirm this is
// intentional.
func (r *Router) MatchingMountByAPIPath(ctx context.Context, path string) string {
	me, _, _ := r.matchingMountEntryByPath(ctx, path, true)
	if me == nil {
		return ""
	}
	return me.Path
}

// MatchingStoragePrefixByAPIPath returns the storage prefix for the given api
// path and whether a mount was found.
func (r *Router) MatchingStoragePrefixByAPIPath(ctx context.Context, path string) (string, bool) {
	ns, err := namespace.FromContext(ctx)
	if err != nil {
		return "", false
	}
	path = ns.Path + path

	_, prefix, found := r.matchingMountEntryByPath(ctx, path, true)
	return prefix, found
}

// MatchingAPIPrefixByStoragePath returns the api path information for the
// given storage path: the mount's namespace, its mount path, its storage
// prefix, and whether a mount was found.
func (r *Router) MatchingAPIPrefixByStoragePath(ctx context.Context, path string) (*namespace.Namespace, string, string, bool) {
	me, prefix, found := r.matchingMountEntryByPath(ctx, path, false)
	if !found {
		return nil, "", "", found
	}

	mountPath := me.Path
	// Add back the prefix for credential backends
	if strings.HasPrefix(path, credentialBarrierPrefix) {
		mountPath = credentialRoutePrefix + mountPath
	}

	return me.Namespace(), mountPath, prefix, found
}

// matchingMountEntryByPath looks path up, as given, in the root tree (apiPath
// true) or the storage-prefix tree (apiPath false). It returns the matching
// mount entry, that entry's storage prefix, and whether anything matched.
// ctx is not used.
func (r *Router) matchingMountEntryByPath(ctx context.Context, path string, apiPath bool) (*MountEntry, string, bool) {
	var raw interface{}
	var ok bool
	r.l.RLock()
	if apiPath {
		_, raw, ok = r.root.LongestPrefix(path)
	} else {
		_, raw, ok = r.storagePrefix.LongestPrefix(path)
	}
	r.l.RUnlock()
	if !ok {
		return nil, "", false
	}

	// Extract the mount path and storage prefix
	re := raw.(*routeEntry)
	prefix := re.storagePrefix

	return re.mountEntry, prefix, true
}

// Route is used to route a given request
func (r *Router) Route(ctx context.Context, req *logical.Request) (*logical.Response, error) {
	resp, _, _, err := r.routeCommon(ctx, req, false)
	return resp, err
}

// RouteExistenceCheck is used to route a given existence check request
func (r *Router) RouteExistenceCheck(ctx context.Context, req *logical.Request) (*logical.Response, bool, bool, error) {
	resp, ok, exists, err := r.routeCommon(ctx, req, true)
	return resp, ok, exists, err
}

// routeCommon finds the mount that owns req.Path, rewrites the request so
// the backend sees a mount-relative path and a salted client token, invokes
// the backend, and restores the caller-visible request fields before
// returning.
//
// When existenceCheck is true the backend's HandleExistenceCheck is called
// and its two booleans are returned with a nil response. Otherwise
// HandleRequest is called and both booleans are false.
func (r *Router) routeCommon(ctx context.Context, req *logical.Request, existenceCheck bool) (*logical.Response, bool, bool, error) {
	ns, err := namespace.FromContext(ctx)
	if err != nil {
		return nil, false, false, err
	}

	// Find the mount point
	r.l.RLock()
	adjustedPath := req.Path
	mount, raw, ok := r.root.LongestPrefix(ns.Path + adjustedPath)
	if !ok && !strings.HasSuffix(adjustedPath, "/") {
		// Re-check for a backend by appending a slash. This lets "foo" mean
		// "foo/" at the root level which is almost always what we want.
		adjustedPath += "/"
		mount, raw, ok = r.root.LongestPrefix(ns.Path + adjustedPath)
	}
	r.l.RUnlock()
	if !ok {
		return logical.ErrorResponse(fmt.Sprintf("no handler for route %q. route entry not found.", req.Path)), false, false, logical.ErrUnsupportedPath
	}
	req.Path = adjustedPath
	if !existenceCheck {
		// time.Now() is evaluated here, so the metric covers everything from
		// this point until routeCommon returns.
		defer metrics.MeasureSince([]string{
			"route", string(req.Operation),
			strings.ReplaceAll(mount, "/", "-"),
		}, time.Now())
	}
	re := raw.(*routeEntry)

	// Grab a read lock on the route entry, this protects against the backend
	// being reloaded during a request. The exception is a renew request on the
	// token store; such a request will have already been routed through the
	// token store -> exp manager -> here so we need to not grab the lock again
	// or we'll be recursively grabbing it.
	if !(req.Operation == logical.RenewOperation && strings.HasPrefix(req.Path, "auth/token/")) {
		re.l.RLock()
		defer re.l.RUnlock()
	}

	// Filtered mounts will have a nil backend
	if re.backend == nil {
		return logical.ErrorResponse(fmt.Sprintf("no handler for route %q. route entry found, but backend is nil.", req.Path)), false, false, logical.ErrUnsupportedPath
	}

	// If the path is tainted, we reject any operation except for
	// Rollback and Revoke
	if re.tainted {
		switch req.Operation {
		case logical.RevokeOperation, logical.RollbackOperation:
		default:
			return logical.ErrorResponse(fmt.Sprintf("no handler for route %q. route entry is tainted.", req.Path)), false, false, logical.ErrUnsupportedPath
		}
	}

	// Adjust the path to exclude the routing prefix
	originalPath := req.Path
	req.Path = strings.TrimPrefix(ns.Path+req.Path, mount)
	req.MountPoint = mount
	req.MountType = re.mountEntry.Type
	req.SetMountRunningSha256(re.mountEntry.RunningSha256)
	req.SetMountRunningVersion(re.mountEntry.RunningVersion)
	req.SetMountIsExternalPlugin(re.mountEntry.IsExternalPlugin())
	req.SetMountClass(re.mountEntry.MountClass())
	if req.Path == "/" {
		req.Path = ""
	}

	// Attach the storage view for the request
	req.Storage = re.storageView

	originalEntityID := req.EntityID

	// Hash the request token unless the request is being routed to the token
	// or system backend.
	//
	// NOTE(review): the early returns in the cubbyhole case below happen
	// before the deferred reset further down is registered, so req.Path,
	// req.MountPoint and req.Storage keep their routed values on those
	// paths — confirm callers do not rely on them being restored.
	clientToken := req.ClientToken
	switch {
	case strings.HasPrefix(originalPath, "auth/token/"):
	case strings.HasPrefix(originalPath, mountPathSystem):
	case strings.HasPrefix(originalPath, mountPathIdentity):
	case strings.HasPrefix(originalPath, mountPathCubbyhole):
		if req.Operation == logical.RollbackOperation {
			// Backend doesn't support this and it can't properly look up a
			// cubbyhole ID so just return here
			return nil, false, false, nil
		}

		te := req.TokenEntry()

		if te == nil {
			return nil, false, false, fmt.Errorf("nil token entry")
		}

		if te.Type != logical.TokenTypeService {
			return logical.ErrorResponse(`cubbyhole operations are only supported by "service" type tokens`), false, false, nil
		}

		switch {
		case te.NamespaceID == namespace.RootNamespaceID && !strings.HasPrefix(req.ClientToken, consts.LegacyServiceTokenPrefix) &&
			!strings.HasPrefix(req.ClientToken, consts.ServiceTokenPrefix):
			// In order for the token store to revoke later, we need to have the same
			// salted ID, so we double-salt what's going to the cubbyhole backend
			salt, err := r.tokenStoreSaltFunc(ctx)
			if err != nil {
				return nil, false, false, err
			}
			req.ClientToken = re.SaltID(salt.SaltID(req.ClientToken))
		default:
			if te.CubbyholeID == "" {
				return nil, false, false, fmt.Errorf("empty cubbyhole id")
			}
			req.ClientToken = te.CubbyholeID
		}

	default:
		req.ClientToken = re.SaltID(req.ClientToken)
	}

	// Cache the pointer to the original connection object
	originalConn := req.Connection

	// Cache the identifier of the request
	originalReqID := req.ID

	// Cache the client token's number of uses in the request
	originalClientTokenRemainingUses := req.ClientTokenRemainingUses
	req.ClientTokenRemainingUses = 0

	originalMFACreds := req.MFACreds
	req.MFACreds = nil

	originalControlGroup := req.ControlGroup
	req.ControlGroup = nil

	// Cache the headers
	headers := req.Headers
	req.Headers = nil

	// Cache the saved request SSC token
	inboundToken := req.InboundSSCToken

	// Ensure that the inbound token we cache in the
	// request during token creation isn't sent to backends
	req.InboundSSCToken = ""

	// Filter and add passthrough headers to the backend
	var passthroughRequestHeaders []string
	if rawVal, ok := re.mountEntry.synthesizedConfigCache.Load("passthrough_request_headers"); ok {
		passthroughRequestHeaders = rawVal.([]string)
	}
	var allowedResponseHeaders []string
	if rawVal, ok := re.mountEntry.synthesizedConfigCache.Load("allowed_response_headers"); ok {
		allowedResponseHeaders = rawVal.([]string)
	}

	if len(passthroughRequestHeaders) > 0 {
		req.Headers = filteredHeaders(headers, passthroughRequestHeaders, deniedPassthroughRequestHeaders)
	}

	// Cache the wrap info of the request. Only TTL, Format and SealWrap are
	// copied; the deferred reset puts this copy back on the request.
	var wrapInfo *logical.RequestWrapInfo
	if req.WrapInfo != nil {
		wrapInfo = &logical.RequestWrapInfo{
			TTL:      req.WrapInfo.TTL,
			Format:   req.WrapInfo.Format,
			SealWrap: req.WrapInfo.SealWrap,
		}
	}

	originalPolicyOverride := req.PolicyOverride
	reqTokenEntry := req.TokenEntry()
	req.SetTokenEntry(nil)

	// Reset the request before returning
	defer func() {
		req.Path = originalPath
		req.MountPoint = mount
		req.MountType = re.mountEntry.Type
		req.SetMountRunningSha256(re.mountEntry.RunningSha256)
		req.SetMountRunningVersion(re.mountEntry.RunningVersion)
		req.SetMountIsExternalPlugin(re.mountEntry.IsExternalPlugin())
		req.SetMountClass(re.mountEntry.MountClass())
		req.Connection = originalConn
		req.ID = originalReqID
		req.Storage = nil
		req.ClientToken = clientToken
		req.ClientTokenRemainingUses = originalClientTokenRemainingUses
		req.WrapInfo = wrapInfo
		req.Headers = headers
		req.PolicyOverride = originalPolicyOverride
		// This is only set in one place, after routing, so should never be set
		// by a backend
		req.SetLastRemoteWAL(0)

		// This will be used for attaching the mount accessor for the identities
		// returned by the authentication backends
		req.MountAccessor = re.mountEntry.Accessor

		req.EntityID = originalEntityID

		req.MFACreds = originalMFACreds

		req.InboundSSCToken = inboundToken

		// Before resetting the tokenEntry, see if an ExternalID was added
		//
		// NOTE(review): reqTokenEntry is dereferenced here without a nil
		// check; confirm it cannot be nil when a backend attaches a token
		// entry carrying an ExternalID.
		if req.TokenEntry() != nil && req.TokenEntry().ExternalID != "" {
			reqTokenEntry.ExternalID = req.TokenEntry().ExternalID
		}

		req.SetTokenEntry(reqTokenEntry)
		req.ControlGroup = originalControlGroup
	}()

	// Invoke the backend
	if existenceCheck {
		ok, exists, err := re.backend.HandleExistenceCheck(ctx, req)
		return nil, ok, exists, err
	} else {
		resp, err := re.backend.HandleRequest(ctx, req)
		if resp != nil {
			// Only response headers the mount explicitly allows are kept.
			if len(allowedResponseHeaders) > 0 {
				resp.Headers = filteredHeaders(resp.Headers, allowedResponseHeaders, nil)
			} else {
				resp.Headers = nil
			}

			if resp.Auth != nil {
				// When a token gets renewed, the request hits this path and
				// reaches token store. Token store delegates the renewal to the
				// expiration manager. Expiration manager in-turn creates a
				// different logical request and forwards the request to the auth
				// backend that had initially authenticated the login request. The
				// forwarding to auth backend will make this code path hit for the
				// second time for the same renewal request. The accessors in the
				// Alias structs should be of the auth backend and not of the token
				// store. Therefore, avoiding the overwriting of accessors by
				// having a check for path prefix having "renew". This gets applied
				// for "renew" and "renew-self" requests.
				if !strings.HasPrefix(req.Path, "renew") {
					if resp.Auth.Alias != nil {
						resp.Auth.Alias.MountAccessor = re.mountEntry.Accessor
					}
					for _, alias := range resp.Auth.GroupAliases {
						alias.MountAccessor = re.mountEntry.Accessor
					}
				}

				switch re.mountEntry.Type {
				case mountTypeToken, mountTypeNSToken:
					// Nothing; we respect what the token store is telling us and
					// we don't allow tuning
				default:
					// The mount's configured token type either forces the
					// type or supplies the default when the backend did not
					// pick a concrete one.
					switch re.mountEntry.Config.TokenType {
					case logical.TokenTypeService, logical.TokenTypeBatch:
						resp.Auth.TokenType = re.mountEntry.Config.TokenType
					case logical.TokenTypeDefault, logical.TokenTypeDefaultService:
						switch resp.Auth.TokenType {
						case logical.TokenTypeDefault, logical.TokenTypeDefaultService, logical.TokenTypeService:
							resp.Auth.TokenType = logical.TokenTypeService
						default:
							resp.Auth.TokenType = logical.TokenTypeBatch
						}
					case logical.TokenTypeDefaultBatch:
						switch resp.Auth.TokenType {
						case logical.TokenTypeDefault, logical.TokenTypeDefaultBatch, logical.TokenTypeBatch:
							resp.Auth.TokenType = logical.TokenTypeBatch
						default:
							resp.Auth.TokenType = logical.TokenTypeService
						}
					}
				}
			}
		}

		return resp, false, false, err
	}
}

// RootPath checks if the given path requires root privileges. It returns
// false when ctx carries no namespace, when no mount owns the path, or when
// the mount-relative path matches none of the backend's root paths.
func (r *Router) RootPath(ctx context.Context, path string) bool {
	ns, err := namespace.FromContext(ctx)
	if err != nil {
		return false
	}

	adjustedPath := ns.Path + path

	r.l.RLock()
	mount, raw, ok := r.root.LongestPrefix(adjustedPath)
	r.l.RUnlock()
	if !ok {
		return false
	}
	re := raw.(*routeEntry)

	// Trim to get remaining path
	remain := strings.TrimPrefix(adjustedPath, mount)

	// Check the rootPaths of this backend
	rootPaths := re.rootPaths.Load().(*radix.Tree)
	match, raw, ok := rootPaths.LongestPrefix(remain)
	if !ok {
		return false
	}
	// The stored value is true for a prefix entry and false for an exact one.
	prefixMatch := raw.(bool)

	// Handle the prefix match case
	if prefixMatch {
		return strings.HasPrefix(remain, match)
	}

	// Handle the exact match case
	return match == remain
}

// LoginPath checks if the given path is used for logins
// Matching Priority
// 1. prefix
// 2. exact
// 3.
wildcard func (r *Router) LoginPath(ctx context.Context, path string) bool { ns, err := namespace.FromContext(ctx) if err != nil { return false } adjustedPath := ns.Path + path r.l.RLock() mount, raw, ok := r.root.LongestPrefix(adjustedPath) r.l.RUnlock() if !ok { return false } re := raw.(*routeEntry) // Trim to get remaining path remain := strings.TrimPrefix(adjustedPath, mount) // Check the loginPaths of this backend pe := re.loginPaths.Load().(*loginPathsEntry) match, raw, ok := pe.paths.LongestPrefix(remain) if !ok && len(pe.wildcardPaths) == 0 { // no match found return false } if ok { prefixMatch := raw.(bool) if prefixMatch { // Handle the prefix match case return strings.HasPrefix(remain, match) } if match == remain { // Handle the exact match case return true } } // check Login Paths containing wildcards reqPathParts := strings.Split(remain, "/") for _, w := range pe.wildcardPaths { if pathMatchesWildcardPath(reqPathParts, w.segments, w.isPrefix) { return true } } return false } // pathMatchesWildcardPath returns true if the path made up of the path slice // matches the given wildcard path slice func pathMatchesWildcardPath(path, wcPath []string, isPrefix bool) bool { if len(wcPath) == 0 { return false } if len(path) < len(wcPath) { // check if the path coming in is shorter; if so it can't match return false } if !isPrefix && len(wcPath) != len(path) { // If it's not a prefix we expect the same number of segments return false } for i, wcPathPart := range wcPath { switch { case wcPathPart == "+": case wcPathPart == path[i]: case isPrefix && i == len(wcPath)-1 && strings.HasPrefix(path[i], wcPathPart): default: // we encountered segments that did not match return false } } return true } func wildcardError(path, msg string) error { return fmt.Errorf("path %q: invalid use of wildcards %s", path, msg) } func isValidUnauthenticatedPath(path string) (bool, error) { switch { case strings.Count(path, "*") > 1: return false, wildcardError(path, "(multiple '*' is 
forbidden)") case strings.Contains(path, "+*"): return false, wildcardError(path, "('+*' is forbidden)") case strings.Contains(path, "*") && path[len(path)-1] != '*': return false, wildcardError(path, "('*' is only allowed at the end of a path)") case wcAdjacentNonSlashRegEx(path): return false, wildcardError(path, "('+' is not allowed next to a non-slash)") } return true, nil } // parseUnauthenticatedPaths converts a list of special paths to a // loginPathsEntry func parseUnauthenticatedPaths(paths []string) (*loginPathsEntry, error) { var tempPaths []string tempWildcardPaths := make([]wildcardPath, 0) for _, path := range paths { if ok, err := isValidUnauthenticatedPath(path); !ok { return nil, err } if strings.Contains(path, "+") { // Paths with wildcards are not stored in the radix tree because // the radix tree does not handle wildcards in the middle of strings. isPrefix := false if path[len(path)-1] == '*' { isPrefix = true path = path[0 : len(path)-1] } // We are micro-optimizing by storing pre-split slices of path segments wcPath := wildcardPath{segments: strings.Split(path, "/"), isPrefix: isPrefix} tempWildcardPaths = append(tempWildcardPaths, wcPath) } else { // accumulate paths that do not contain wildcards // to be stored in the radix tree tempPaths = append(tempPaths, path) } } return &loginPathsEntry{ paths: pathsToRadix(tempPaths), wildcardPaths: tempWildcardPaths, }, nil } // pathsToRadix converts a list of special paths to a radix tree. func pathsToRadix(paths []string) *radix.Tree { tree := radix.New() for _, path := range paths { // Check if this is a prefix or exact match prefixMatch := len(path) >= 1 && path[len(path)-1] == '*' if prefixMatch { path = path[:len(path)-1] } tree.Insert(path, prefixMatch) } return tree } // filteredHeaders returns a headers map[string][]string that // contains the filtered values contained in candidateHeaders. Filtering of // candidateHeaders from the origHeaders is done is a case-insensitive manner. 
// Headers that match values from deniedHeaders will be ignored. func filteredHeaders(origHeaders map[string][]string, candidateHeaders, deniedHeaders []string) map[string][]string { // Short-circuit if there's nothing to filter if len(candidateHeaders) == 0 { return nil } retHeaders := make(map[string][]string, len(origHeaders)) // Filter candidateHeaders values through deniedHeaders first. Returns the // lowercased complement set. We call even if no denied headers to get the // values lowercased. allowedCandidateHeaders := strutil.Difference(candidateHeaders, deniedHeaders, true) // Create a map that uses lowercased header values as the key and the original // header naming as the value for comparison down below. lowerOrigHeaderKeys := make(map[string]string, len(origHeaders)) for key := range origHeaders { lowerOrigHeaderKeys[strings.ToLower(key)] = key } // Case-insensitive compare of passthrough headers against originating // headers. The returned headers will be the same casing as the originating // header name. for _, ch := range allowedCandidateHeaders { if header, ok := lowerOrigHeaderKeys[ch]; ok { retHeaders[header] = origHeaders[header] } } return retHeaders }