diff --git a/vault/audited_headers.go b/vault/audited_headers.go index 18b4aea3c..e7cb69a63 100644 --- a/vault/audited_headers.go +++ b/vault/audited_headers.go @@ -2,11 +2,15 @@ package vault import ( "fmt" + "strings" "sync" "github.com/hashicorp/vault/logical" ) +// N.B.: While we could use textproto to get the canonical mime header, HTTP/2 +// requires all headers to be converted to lower case, so we just do that. + const ( // Key used in the BarrierView to store and retrieve the header config auditedHeadersEntry = "audited-headers" @@ -37,7 +41,7 @@ func (a *AuditedHeadersConfig) add(header string, hmac bool) error { a.Lock() defer a.Unlock() - a.Headers[header] = &auditedHeaderSettings{hmac} + a.Headers[strings.ToLower(header)] = &auditedHeaderSettings{hmac} entry, err := logical.StorageEntryJSON(auditedHeadersEntry, a.Headers) if err != nil { return fmt.Errorf("failed to persist audited headers config: %v", err) @@ -60,7 +64,7 @@ func (a *AuditedHeadersConfig) remove(header string) error { a.Lock() defer a.Unlock() - delete(a.Headers, header) + delete(a.Headers, strings.ToLower(header)) entry, err := logical.StorageEntryJSON(auditedHeadersEntry, a.Headers) if err != nil { return fmt.Errorf("failed to persist audited headers config: %v", err) @@ -80,9 +84,16 @@ func (a *AuditedHeadersConfig) ApplyConfig(headers map[string][]string, hashFunc a.RLock() defer a.RUnlock() + // Make a copy of the incoming headers with everything lower so we can + // case-insensitively compare + lowerHeaders := make(map[string][]string, len(headers)) + for k, v := range headers { + lowerHeaders[strings.ToLower(k)] = v + } + result = make(map[string][]string, len(a.Headers)) for key, settings := range a.Headers { - if val, ok := headers[key]; ok { + if val, ok := lowerHeaders[key]; ok { // copy the header values so we don't overwrite them hVals := make([]string, len(val)) copy(hVals, val) @@ -120,8 +131,15 @@ func (c *Core) setupAuditedHeadersConfig() error { } } + // Ensure that we are able to case-sensitively access the headers; + // necessary for the upgrade case + lowerHeaders := make(map[string]*auditedHeaderSettings, len(headers)) + for k, v := range headers { + lowerHeaders[strings.ToLower(k)] = v + } + c.auditedHeaders = &AuditedHeadersConfig{ - Headers: headers, + Headers: lowerHeaders, view: view, } diff --git a/vault/audited_headers_test.go b/vault/audited_headers_test.go index 07da7c9c5..5e82ec71d 100644 --- a/vault/audited_headers_test.go +++ b/vault/audited_headers_test.go @@ -29,7 +29,7 @@ func testAuditedHeadersConfig_Add(t *testing.T, conf *AuditedHeadersConfig) { t.Fatalf("Error when adding header to config: %s", err) } - settings, ok := conf.Headers["X-Test-Header"] + settings, ok := conf.Headers["x-test-header"] if !ok { t.Fatal("Expected header to be found in config") } @@ -50,7 +50,7 @@ func testAuditedHeadersConfig_Add(t *testing.T, conf *AuditedHeadersConfig) { } expected := map[string]*auditedHeaderSettings{ - "X-Test-Header": &auditedHeaderSettings{ + "x-test-header": &auditedHeaderSettings{ HMAC: false, }, } @@ -64,7 +64,7 @@ func testAuditedHeadersConfig_Add(t *testing.T, conf *AuditedHeadersConfig) { t.Fatalf("Error when adding header to config: %s", err) } - settings, ok = conf.Headers["X-Vault-Header"] + settings, ok = conf.Headers["x-vault-header"] if !ok { t.Fatal("Expected header to be found in config") } @@ -84,7 +84,7 @@ func testAuditedHeadersConfig_Add(t *testing.T, conf *AuditedHeadersConfig) { t.Fatalf("Error decoding header view: %s", err) } - expected["X-Vault-Header"] = &auditedHeaderSettings{ + expected["x-vault-header"] = &auditedHeaderSettings{ HMAC: true, } @@ -100,7 +100,7 @@ func testAuditedHeadersConfig_Remove(t *testing.T, conf *AuditedHeadersConfig) { t.Fatalf("Error when adding header to config: %s", err) } - _, ok := conf.Headers["X-Test-Header"] + _, ok := conf.Headers["x-Test-HeAder"] if ok { t.Fatal("Expected header to not be found in config") } @@ -117,7 +117,7 @@ func testAuditedHeadersConfig_Remove(t *testing.T, conf *AuditedHeadersConfig) { } expected := map[string]*auditedHeaderSettings{ - "X-Vault-Header": &auditedHeaderSettings{ + "x-vault-header": &auditedHeaderSettings{ HMAC: true, }, } @@ -126,12 +126,12 @@ func testAuditedHeadersConfig_Remove(t *testing.T, conf *AuditedHeadersConfig) { t.Fatalf("Expected config didn't match actual. Expected: %#v, Got: %#v", expected, headers) } - err = conf.remove("X-Vault-Header") + err = conf.remove("x-VaulT-Header") if err != nil { t.Fatalf("Error when adding header to config: %s", err) } - _, ok = conf.Headers["X-Vault-Header"] + _, ok = conf.Headers["x-vault-header"] if ok { t.Fatal("Expected header to not be found in config") } @@ -157,10 +157,8 @@ func testAuditedHeadersConfig_Remove(t *testing.T, conf *AuditedHeadersConfig) { func TestAuditedHeadersConfig_ApplyConfig(t *testing.T) { conf := mockAuditedHeadersConfig(t) - conf.Headers = map[string]*auditedHeaderSettings{ - "X-Test-Header": &auditedHeaderSettings{false}, - "X-Vault-Header": &auditedHeaderSettings{true}, - } + conf.add("X-TesT-Header", false) + conf.add("X-Vault-HeAdEr", true) reqHeaders := map[string][]string{ "X-Test-Header": []string{"foo"}, @@ -173,8 +171,8 @@ func TestAuditedHeadersConfig_ApplyConfig(t *testing.T) { result := conf.ApplyConfig(reqHeaders, hashFunc) expected := map[string][]string{ - "X-Test-Header": []string{"foo"}, - "X-Vault-Header": []string{"hashed", "hashed"}, + "x-test-header": []string{"foo"}, + "x-vault-header": []string{"hashed", "hashed"}, } if !reflect.DeepEqual(result, expected) {