Add denylist check when filtering passthrough headers (#5436)
* Add denylist check when filtering passthrough headers * Minor comment update
This commit is contained in:
parent
ac8816a7a9
commit
37c0b83669
|
@ -340,9 +340,42 @@ func MergeSlices(args ...[]string) []string {
|
|||
}
|
||||
|
||||
result := make([]string, 0, len(all))
|
||||
for k, _ := range all {
|
||||
for k := range all {
|
||||
result = append(result, k)
|
||||
}
|
||||
sort.Strings(result)
|
||||
return result
|
||||
}
|
||||
|
||||
// Difference returns the set difference (A - B) of the two given slices. The
|
||||
// result will also remove any duplicated values in set A regardless of whether
|
||||
// that matches any values in set B.
|
||||
func Difference(a, b []string, lowercase bool) []string {
|
||||
if len(a) == 0 || len(b) == 0 {
|
||||
return a
|
||||
}
|
||||
|
||||
a = RemoveDuplicates(a, lowercase)
|
||||
b = RemoveDuplicates(b, lowercase)
|
||||
|
||||
itemsMap := map[string]bool{}
|
||||
for _, aVal := range a {
|
||||
itemsMap[aVal] = true
|
||||
}
|
||||
|
||||
// Perform difference calculation
|
||||
for _, bVal := range b {
|
||||
if _, ok := itemsMap[bVal]; ok {
|
||||
itemsMap[bVal] = false
|
||||
}
|
||||
}
|
||||
|
||||
items := []string{}
|
||||
for item, exists := range itemsMap {
|
||||
if exists {
|
||||
items = append(items, item)
|
||||
}
|
||||
}
|
||||
sort.Strings(items)
|
||||
return items
|
||||
}
|
||||
|
|
|
@ -459,3 +459,59 @@ func TestStrUtil_MergeSlices(t *testing.T) {
|
|||
t.Fatalf("expected %v, got %v", expect, res)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDifference(t *testing.T) {
|
||||
testCases := []struct {
|
||||
Name string
|
||||
SetA []string
|
||||
SetB []string
|
||||
Lowercase bool
|
||||
ExpectedResult []string
|
||||
}{
|
||||
{
|
||||
Name: "case_sensitive",
|
||||
SetA: []string{"a", "b", "c"},
|
||||
SetB: []string{"b", "c"},
|
||||
Lowercase: false,
|
||||
ExpectedResult: []string{"a"},
|
||||
},
|
||||
{
|
||||
Name: "case_insensitive",
|
||||
SetA: []string{"a", "B", "c"},
|
||||
SetB: []string{"b", "C"},
|
||||
Lowercase: true,
|
||||
ExpectedResult: []string{"a"},
|
||||
},
|
||||
{
|
||||
Name: "no_match",
|
||||
SetA: []string{"a", "b", "c"},
|
||||
SetB: []string{"d"},
|
||||
Lowercase: false,
|
||||
ExpectedResult: []string{"a", "b", "c"},
|
||||
},
|
||||
{
|
||||
Name: "empty_set_a",
|
||||
SetA: []string{},
|
||||
SetB: []string{"d", "e"},
|
||||
Lowercase: false,
|
||||
ExpectedResult: []string{},
|
||||
},
|
||||
{
|
||||
Name: "empty_set_b",
|
||||
SetA: []string{"a", "b"},
|
||||
SetB: []string{},
|
||||
Lowercase: false,
|
||||
ExpectedResult: []string{"a", "b"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.Name, func(t *testing.T) {
|
||||
actualResult := Difference(tc.SetA, tc.SetB, tc.Lowercase)
|
||||
|
||||
if !reflect.DeepEqual(actualResult, tc.ExpectedResult) {
|
||||
t.Fatalf("expected %v, got %v", tc.ExpectedResult, actualResult)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2308,3 +2308,61 @@ func TestCore_HandleRequest_Headers(t *testing.T) {
|
|||
t.Fatalf("did not expect 'Should-Not-Passthrough' to be in the headers map")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCore_HandleRequest_Headers_denyList(t *testing.T) {
|
||||
noop := &NoopBackend{
|
||||
Response: &logical.Response{
|
||||
Data: map[string]interface{}{},
|
||||
},
|
||||
}
|
||||
|
||||
c, _, root := TestCoreUnsealed(t)
|
||||
c.logicalBackends["noop"] = func(context.Context, *logical.BackendConfig) (logical.Backend, error) {
|
||||
return noop, nil
|
||||
}
|
||||
|
||||
// Enable the backend
|
||||
req := logical.TestRequest(t, logical.UpdateOperation, "sys/mounts/foo")
|
||||
req.Data["type"] = "noop"
|
||||
req.ClientToken = root
|
||||
_, err := c.HandleRequest(namespace.RootContext(nil), req)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
// Mount tune
|
||||
req = logical.TestRequest(t, logical.UpdateOperation, "sys/mounts/foo/tune")
|
||||
req.Data["passthrough_request_headers"] = []string{"Authorization", consts.AuthHeaderName}
|
||||
req.ClientToken = root
|
||||
_, err = c.HandleRequest(namespace.RootContext(nil), req)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
// Attempt to read
|
||||
lreq := &logical.Request{
|
||||
Operation: logical.ReadOperation,
|
||||
Path: "foo/test",
|
||||
ClientToken: root,
|
||||
Headers: map[string][]string{
|
||||
consts.AuthHeaderName: []string{"foo"},
|
||||
"Authorization": []string{"baz"},
|
||||
},
|
||||
}
|
||||
_, err = c.HandleRequest(namespace.RootContext(nil), lreq)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
// Check the headers
|
||||
headers := noop.Requests[0].Headers
|
||||
|
||||
// Test passthrough values, they should not be present in the backend
|
||||
if _, ok := headers["Authorization"]; ok {
|
||||
t.Fatalf("did not expect 'Should-Not-Passthrough' to be in the headers map")
|
||||
}
|
||||
|
||||
if _, ok := headers[consts.AuthHeaderName]; ok {
|
||||
t.Fatalf("did not expect %q to be in the headers map", consts.AuthHeaderName)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -10,11 +10,20 @@ import (
|
|||
|
||||
"github.com/armon/go-metrics"
|
||||
"github.com/armon/go-radix"
|
||||
"github.com/hashicorp/vault/helper/consts"
|
||||
"github.com/hashicorp/vault/helper/namespace"
|
||||
"github.com/hashicorp/vault/helper/salt"
|
||||
"github.com/hashicorp/vault/helper/strutil"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
)
|
||||
|
||||
var (
|
||||
denylistHeaders = []string{
|
||||
"Authorization",
|
||||
consts.AuthHeaderName,
|
||||
}
|
||||
)
|
||||
|
||||
// Router is used to do prefix based routing of a request to a logical backend
|
||||
type Router struct {
|
||||
l sync.RWMutex
|
||||
|
@ -753,9 +762,9 @@ func pathsToRadix(paths []string) *radix.Tree {
|
|||
}
|
||||
|
||||
// filteredPassthroughHeaders returns a headers map[string][]string that
|
||||
// contains the filtered values contained in passthroughHeaders, as well as the
|
||||
// values in whitelistedHeaders. Filtering of passthroughHeaders from the
|
||||
// origHeaders is done is a case-insensitive manner.
|
||||
// contains the filtered values contained in passthroughHeaders. Filtering of
|
||||
// passthroughHeaders from the origHeaders is done is a case-insensitive manner.
|
||||
// Headers that match values from denylistHeaders will be ignored.
|
||||
func filteredPassthroughHeaders(origHeaders map[string][]string, passthroughHeaders []string) map[string][]string {
|
||||
retHeaders := make(map[string][]string)
|
||||
|
||||
|
@ -764,6 +773,10 @@ func filteredPassthroughHeaders(origHeaders map[string][]string, passthroughHead
|
|||
return retHeaders
|
||||
}
|
||||
|
||||
// Filter passthroughHeaders values through denyListHeaders first. Returns the
|
||||
// lowercased the complement set.
|
||||
passthroughHeadersSubset := strutil.Difference(passthroughHeaders, denylistHeaders, true)
|
||||
|
||||
// Create a map that uses lowercased header values as the key and the original
|
||||
// header naming as the value for comparison down below.
|
||||
lowerHeadersRef := make(map[string]string, len(origHeaders))
|
||||
|
@ -774,8 +787,8 @@ func filteredPassthroughHeaders(origHeaders map[string][]string, passthroughHead
|
|||
// Case-insensitive compare of passthrough headers against originating
|
||||
// headers. The returned headers will be the same casing as the originating
|
||||
// header name.
|
||||
for _, ph := range passthroughHeaders {
|
||||
if header, ok := lowerHeadersRef[strings.ToLower(ph)]; ok {
|
||||
for _, ph := range passthroughHeadersSubset {
|
||||
if header, ok := lowerHeadersRef[ph]; ok {
|
||||
retHeaders[header] = origHeaders[header]
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue