Add denylist check when filtering passthrough headers (#5436)

* Add denylist check when filtering passthrough headers

* Minor comment update
This commit is contained in:
Calvin Leung Huang 2018-10-01 12:20:31 -07:00 committed by GitHub
parent ac8816a7a9
commit 37c0b83669
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 166 additions and 6 deletions

View File

@ -340,9 +340,42 @@ func MergeSlices(args ...[]string) []string {
}
result := make([]string, 0, len(all))
for k, _ := range all {
for k := range all {
result = append(result, k)
}
sort.Strings(result)
return result
}
// Difference returns the set difference (A - B) of the two given slices. The
// result will also remove any duplicated values in set A regardless of whether
// that matches any values in set B.
func Difference(a, b []string, lowercase bool) []string {
if len(a) == 0 || len(b) == 0 {
return a
}
a = RemoveDuplicates(a, lowercase)
b = RemoveDuplicates(b, lowercase)
itemsMap := map[string]bool{}
for _, aVal := range a {
itemsMap[aVal] = true
}
// Perform difference calculation
for _, bVal := range b {
if _, ok := itemsMap[bVal]; ok {
itemsMap[bVal] = false
}
}
items := []string{}
for item, exists := range itemsMap {
if exists {
items = append(items, item)
}
}
sort.Strings(items)
return items
}

View File

@ -459,3 +459,59 @@ func TestStrUtil_MergeSlices(t *testing.T) {
t.Fatalf("expected %v, got %v", expect, res)
}
}
func TestDifference(t *testing.T) {
testCases := []struct {
Name string
SetA []string
SetB []string
Lowercase bool
ExpectedResult []string
}{
{
Name: "case_sensitive",
SetA: []string{"a", "b", "c"},
SetB: []string{"b", "c"},
Lowercase: false,
ExpectedResult: []string{"a"},
},
{
Name: "case_insensitive",
SetA: []string{"a", "B", "c"},
SetB: []string{"b", "C"},
Lowercase: true,
ExpectedResult: []string{"a"},
},
{
Name: "no_match",
SetA: []string{"a", "b", "c"},
SetB: []string{"d"},
Lowercase: false,
ExpectedResult: []string{"a", "b", "c"},
},
{
Name: "empty_set_a",
SetA: []string{},
SetB: []string{"d", "e"},
Lowercase: false,
ExpectedResult: []string{},
},
{
Name: "empty_set_b",
SetA: []string{"a", "b"},
SetB: []string{},
Lowercase: false,
ExpectedResult: []string{"a", "b"},
},
}
for _, tc := range testCases {
t.Run(tc.Name, func(t *testing.T) {
actualResult := Difference(tc.SetA, tc.SetB, tc.Lowercase)
if !reflect.DeepEqual(actualResult, tc.ExpectedResult) {
t.Fatalf("expected %v, got %v", tc.ExpectedResult, actualResult)
}
})
}
}

View File

@ -2308,3 +2308,61 @@ func TestCore_HandleRequest_Headers(t *testing.T) {
t.Fatalf("did not expect 'Should-Not-Passthrough' to be in the headers map")
}
}
func TestCore_HandleRequest_Headers_denyList(t *testing.T) {
noop := &NoopBackend{
Response: &logical.Response{
Data: map[string]interface{}{},
},
}
c, _, root := TestCoreUnsealed(t)
c.logicalBackends["noop"] = func(context.Context, *logical.BackendConfig) (logical.Backend, error) {
return noop, nil
}
// Enable the backend
req := logical.TestRequest(t, logical.UpdateOperation, "sys/mounts/foo")
req.Data["type"] = "noop"
req.ClientToken = root
_, err := c.HandleRequest(namespace.RootContext(nil), req)
if err != nil {
t.Fatalf("err: %v", err)
}
// Mount tune
req = logical.TestRequest(t, logical.UpdateOperation, "sys/mounts/foo/tune")
req.Data["passthrough_request_headers"] = []string{"Authorization", consts.AuthHeaderName}
req.ClientToken = root
_, err = c.HandleRequest(namespace.RootContext(nil), req)
if err != nil {
t.Fatalf("err: %v", err)
}
// Attempt to read
lreq := &logical.Request{
Operation: logical.ReadOperation,
Path: "foo/test",
ClientToken: root,
Headers: map[string][]string{
consts.AuthHeaderName: []string{"foo"},
"Authorization": []string{"baz"},
},
}
_, err = c.HandleRequest(namespace.RootContext(nil), lreq)
if err != nil {
t.Fatalf("err: %v", err)
}
// Check the headers
headers := noop.Requests[0].Headers
// Test passthrough values, they should not be present in the backend
if _, ok := headers["Authorization"]; ok {
t.Fatalf("did not expect 'Should-Not-Passthrough' to be in the headers map")
}
if _, ok := headers[consts.AuthHeaderName]; ok {
t.Fatalf("did not expect %q to be in the headers map", consts.AuthHeaderName)
}
}

View File

@ -10,11 +10,20 @@ import (
"github.com/armon/go-metrics"
"github.com/armon/go-radix"
"github.com/hashicorp/vault/helper/consts"
"github.com/hashicorp/vault/helper/namespace"
"github.com/hashicorp/vault/helper/salt"
"github.com/hashicorp/vault/helper/strutil"
"github.com/hashicorp/vault/logical"
)
var (
denylistHeaders = []string{
"Authorization",
consts.AuthHeaderName,
}
)
// Router is used to do prefix based routing of a request to a logical backend
type Router struct {
l sync.RWMutex
@ -753,9 +762,9 @@ func pathsToRadix(paths []string) *radix.Tree {
}
// filteredPassthroughHeaders returns a headers map[string][]string that
// contains the filtered values contained in passthroughHeaders, as well as the
// values in whitelistedHeaders. Filtering of passthroughHeaders from the
// origHeaders is done is a case-insensitive manner.
// contains the filtered values contained in passthroughHeaders. Filtering of
// passthroughHeaders from the origHeaders is done is a case-insensitive manner.
// Headers that match values from denylistHeaders will be ignored.
func filteredPassthroughHeaders(origHeaders map[string][]string, passthroughHeaders []string) map[string][]string {
retHeaders := make(map[string][]string)
@ -764,6 +773,10 @@ func filteredPassthroughHeaders(origHeaders map[string][]string, passthroughHead
return retHeaders
}
// Filter passthroughHeaders values through denyListHeaders first. Returns the
// lowercased the complement set.
passthroughHeadersSubset := strutil.Difference(passthroughHeaders, denylistHeaders, true)
// Create a map that uses lowercased header values as the key and the original
// header naming as the value for comparison down below.
lowerHeadersRef := make(map[string]string, len(origHeaders))
@ -774,8 +787,8 @@ func filteredPassthroughHeaders(origHeaders map[string][]string, passthroughHead
// Case-insensitive compare of passthrough headers against originating
// headers. The returned headers will be the same casing as the originating
// header name.
for _, ph := range passthroughHeaders {
if header, ok := lowerHeadersRef[strings.ToLower(ph)]; ok {
for _, ph := range passthroughHeadersSubset {
if header, ok := lowerHeadersRef[ph]; ok {
retHeaders[header] = origHeaders[header]
}
}