parent
e2de7ec7fe
commit
acb7391b12
|
@ -2,11 +2,15 @@ package vault
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/hashicorp/vault/logical"
|
"github.com/hashicorp/vault/logical"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// N.B.: While we could use textproto to get the canonical mime header, HTTP/2
|
||||||
|
// requires all headers to be converted to lower case, so we just do that.
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// Key used in the BarrierView to store and retrieve the header config
|
// Key used in the BarrierView to store and retrieve the header config
|
||||||
auditedHeadersEntry = "audited-headers"
|
auditedHeadersEntry = "audited-headers"
|
||||||
|
@ -37,7 +41,7 @@ func (a *AuditedHeadersConfig) add(header string, hmac bool) error {
|
||||||
a.Lock()
|
a.Lock()
|
||||||
defer a.Unlock()
|
defer a.Unlock()
|
||||||
|
|
||||||
a.Headers[header] = &auditedHeaderSettings{hmac}
|
a.Headers[strings.ToLower(header)] = &auditedHeaderSettings{hmac}
|
||||||
entry, err := logical.StorageEntryJSON(auditedHeadersEntry, a.Headers)
|
entry, err := logical.StorageEntryJSON(auditedHeadersEntry, a.Headers)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to persist audited headers config: %v", err)
|
return fmt.Errorf("failed to persist audited headers config: %v", err)
|
||||||
|
@ -60,7 +64,7 @@ func (a *AuditedHeadersConfig) remove(header string) error {
|
||||||
a.Lock()
|
a.Lock()
|
||||||
defer a.Unlock()
|
defer a.Unlock()
|
||||||
|
|
||||||
delete(a.Headers, header)
|
delete(a.Headers, strings.ToLower(header))
|
||||||
entry, err := logical.StorageEntryJSON(auditedHeadersEntry, a.Headers)
|
entry, err := logical.StorageEntryJSON(auditedHeadersEntry, a.Headers)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to persist audited headers config: %v", err)
|
return fmt.Errorf("failed to persist audited headers config: %v", err)
|
||||||
|
@ -80,9 +84,16 @@ func (a *AuditedHeadersConfig) ApplyConfig(headers map[string][]string, hashFunc
|
||||||
a.RLock()
|
a.RLock()
|
||||||
defer a.RUnlock()
|
defer a.RUnlock()
|
||||||
|
|
||||||
|
// Make a copy of the incoming headers with everything lower so we can
|
||||||
|
// case-insensitively compare
|
||||||
|
lowerHeaders := make(map[string][]string, len(headers))
|
||||||
|
for k, v := range headers {
|
||||||
|
lowerHeaders[strings.ToLower(k)] = v
|
||||||
|
}
|
||||||
|
|
||||||
result = make(map[string][]string, len(a.Headers))
|
result = make(map[string][]string, len(a.Headers))
|
||||||
for key, settings := range a.Headers {
|
for key, settings := range a.Headers {
|
||||||
if val, ok := headers[key]; ok {
|
if val, ok := lowerHeaders[key]; ok {
|
||||||
// copy the header values so we don't overwrite them
|
// copy the header values so we don't overwrite them
|
||||||
hVals := make([]string, len(val))
|
hVals := make([]string, len(val))
|
||||||
copy(hVals, val)
|
copy(hVals, val)
|
||||||
|
@ -120,8 +131,15 @@ func (c *Core) setupAuditedHeadersConfig() error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Ensure that we are able to case-sensitively access the headers;
|
||||||
|
// necessary for the upgrade case
|
||||||
|
lowerHeaders := make(map[string]*auditedHeaderSettings, len(headers))
|
||||||
|
for k, v := range headers {
|
||||||
|
lowerHeaders[strings.ToLower(k)] = v
|
||||||
|
}
|
||||||
|
|
||||||
c.auditedHeaders = &AuditedHeadersConfig{
|
c.auditedHeaders = &AuditedHeadersConfig{
|
||||||
Headers: headers,
|
Headers: lowerHeaders,
|
||||||
view: view,
|
view: view,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -29,7 +29,7 @@ func testAuditedHeadersConfig_Add(t *testing.T, conf *AuditedHeadersConfig) {
|
||||||
t.Fatalf("Error when adding header to config: %s", err)
|
t.Fatalf("Error when adding header to config: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
settings, ok := conf.Headers["X-Test-Header"]
|
settings, ok := conf.Headers["x-test-header"]
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("Expected header to be found in config")
|
t.Fatal("Expected header to be found in config")
|
||||||
}
|
}
|
||||||
|
@ -50,7 +50,7 @@ func testAuditedHeadersConfig_Add(t *testing.T, conf *AuditedHeadersConfig) {
|
||||||
}
|
}
|
||||||
|
|
||||||
expected := map[string]*auditedHeaderSettings{
|
expected := map[string]*auditedHeaderSettings{
|
||||||
"X-Test-Header": &auditedHeaderSettings{
|
"x-test-header": &auditedHeaderSettings{
|
||||||
HMAC: false,
|
HMAC: false,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@ -64,7 +64,7 @@ func testAuditedHeadersConfig_Add(t *testing.T, conf *AuditedHeadersConfig) {
|
||||||
t.Fatalf("Error when adding header to config: %s", err)
|
t.Fatalf("Error when adding header to config: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
settings, ok = conf.Headers["X-Vault-Header"]
|
settings, ok = conf.Headers["x-vault-header"]
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("Expected header to be found in config")
|
t.Fatal("Expected header to be found in config")
|
||||||
}
|
}
|
||||||
|
@ -84,7 +84,7 @@ func testAuditedHeadersConfig_Add(t *testing.T, conf *AuditedHeadersConfig) {
|
||||||
t.Fatalf("Error decoding header view: %s", err)
|
t.Fatalf("Error decoding header view: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
expected["X-Vault-Header"] = &auditedHeaderSettings{
|
expected["x-vault-header"] = &auditedHeaderSettings{
|
||||||
HMAC: true,
|
HMAC: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -100,7 +100,7 @@ func testAuditedHeadersConfig_Remove(t *testing.T, conf *AuditedHeadersConfig) {
|
||||||
t.Fatalf("Error when adding header to config: %s", err)
|
t.Fatalf("Error when adding header to config: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, ok := conf.Headers["X-Test-Header"]
|
_, ok := conf.Headers["x-Test-HeAder"]
|
||||||
if ok {
|
if ok {
|
||||||
t.Fatal("Expected header to not be found in config")
|
t.Fatal("Expected header to not be found in config")
|
||||||
}
|
}
|
||||||
|
@ -117,7 +117,7 @@ func testAuditedHeadersConfig_Remove(t *testing.T, conf *AuditedHeadersConfig) {
|
||||||
}
|
}
|
||||||
|
|
||||||
expected := map[string]*auditedHeaderSettings{
|
expected := map[string]*auditedHeaderSettings{
|
||||||
"X-Vault-Header": &auditedHeaderSettings{
|
"x-vault-header": &auditedHeaderSettings{
|
||||||
HMAC: true,
|
HMAC: true,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@ -126,12 +126,12 @@ func testAuditedHeadersConfig_Remove(t *testing.T, conf *AuditedHeadersConfig) {
|
||||||
t.Fatalf("Expected config didn't match actual. Expected: %#v, Got: %#v", expected, headers)
|
t.Fatalf("Expected config didn't match actual. Expected: %#v, Got: %#v", expected, headers)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = conf.remove("X-Vault-Header")
|
err = conf.remove("x-VaulT-Header")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Error when adding header to config: %s", err)
|
t.Fatalf("Error when adding header to config: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, ok = conf.Headers["X-Vault-Header"]
|
_, ok = conf.Headers["x-vault-header"]
|
||||||
if ok {
|
if ok {
|
||||||
t.Fatal("Expected header to not be found in config")
|
t.Fatal("Expected header to not be found in config")
|
||||||
}
|
}
|
||||||
|
@ -157,10 +157,8 @@ func testAuditedHeadersConfig_Remove(t *testing.T, conf *AuditedHeadersConfig) {
|
||||||
func TestAuditedHeadersConfig_ApplyConfig(t *testing.T) {
|
func TestAuditedHeadersConfig_ApplyConfig(t *testing.T) {
|
||||||
conf := mockAuditedHeadersConfig(t)
|
conf := mockAuditedHeadersConfig(t)
|
||||||
|
|
||||||
conf.Headers = map[string]*auditedHeaderSettings{
|
conf.add("X-TesT-Header", false)
|
||||||
"X-Test-Header": &auditedHeaderSettings{false},
|
conf.add("X-Vault-HeAdEr", true)
|
||||||
"X-Vault-Header": &auditedHeaderSettings{true},
|
|
||||||
}
|
|
||||||
|
|
||||||
reqHeaders := map[string][]string{
|
reqHeaders := map[string][]string{
|
||||||
"X-Test-Header": []string{"foo"},
|
"X-Test-Header": []string{"foo"},
|
||||||
|
@ -173,8 +171,8 @@ func TestAuditedHeadersConfig_ApplyConfig(t *testing.T) {
|
||||||
result := conf.ApplyConfig(reqHeaders, hashFunc)
|
result := conf.ApplyConfig(reqHeaders, hashFunc)
|
||||||
|
|
||||||
expected := map[string][]string{
|
expected := map[string][]string{
|
||||||
"X-Test-Header": []string{"foo"},
|
"x-test-header": []string{"foo"},
|
||||||
"X-Vault-Header": []string{"hashed", "hashed"},
|
"x-vault-header": []string{"hashed", "hashed"},
|
||||||
}
|
}
|
||||||
|
|
||||||
if !reflect.DeepEqual(result, expected) {
|
if !reflect.DeepEqual(result, expected) {
|
||||||
|
|
Loading…
Reference in New Issue