From ad2ef412cc343dab9bbd29465ed33faaac815174 Mon Sep 17 00:00:00 2001 From: hghaf099 <83242695+hghaf099@users.noreply.github.com> Date: Wed, 13 Oct 2021 11:06:33 -0400 Subject: [PATCH] Customizing HTTP headers in the config file (#12485) * Customizing HTTP headers in the config file * Add changelog, fix bad imports * fixing some bugs * fixing interaction of custom headers and /ui * Defining a member in core to set custom response headers * missing additional file * Some refactoring * Adding automated tests for the feature * Changing some error messages based on some recommendations * Incorporating custom response headers struct into the request context * removing some unused references * fixing a test * changing some error messages, removing a default header value from /ui * fixing a test * wrapping ResponseWriter to set the custom headers * adding a new test * some cleanup * removing some extra lines * Addressing comments * fixing some agent tests * skipping custom headers from agent listener config, removing two of the default headers as they cause issues with Vault in UI mode Adding X-Content-Type-Options to the ui default headers Let Content-Type be set as before * Removing default custom headers, and renaming some function varibles * some refacotring * Refactoring and addressing comments * removing a function and fixing comments --- changelog/12485.txt | 3 + command/agent/config/config.go | 7 + command/agent/config/config_test.go | 2 - command/server.go | 6 + .../config_custom_response_headers_test.go | 109 +++++++++++ command/server/config_test_helpers.go | 19 +- .../config_custom_response_headers_1.hcl | 31 ++++ ...om_response_headers_multiple_listeners.hcl | 56 ++++++ http/custom_header_test.go | 128 +++++++++++++ http/handler.go | 112 ++++++++++- http/http_test.go | 10 + http/sys_metrics.go | 4 +- http/testing.go | 7 + .../configutil/http_response_headers.go | 129 +++++++++++++ internalshared/configutil/listener.go | 15 ++ vault/core.go | 111 +++++++++-- vault/custom_response_headers.go | 90 +++++++++ vault/custom_response_headers_test.go | 174 ++++++++++++++++++ vault/logical_system.go | 3 + vault/testing.go | 23 +++ vault/ui.go | 3 +- 21 files changed, 1019 insertions(+), 23 deletions(-) create mode 100644 changelog/12485.txt create mode 100644 command/server/config_custom_response_headers_test.go create mode 100644 command/server/test-fixtures/config_custom_response_headers_1.hcl create mode 100644 command/server/test-fixtures/config_custom_response_headers_multiple_listeners.hcl create mode 100644 http/custom_header_test.go create mode 100644 internalshared/configutil/http_response_headers.go create mode 100644 vault/custom_response_headers.go create mode 100644 vault/custom_response_headers_test.go diff --git a/changelog/12485.txt b/changelog/12485.txt new file mode 100644 index 000000000..6c8a87cd2 --- /dev/null +++ b/changelog/12485.txt @@ -0,0 +1,3 @@ +```release-note:feature +**Customizable HTTP Headers**: Add support to define custom HTTP headers for root path (`/`) and also on API endpoints (`/v1/*`) +``` diff --git a/command/agent/config/config.go b/command/agent/config/config.go index 9438bd327..502d512d1 100644 --- a/command/agent/config/config.go +++ b/command/agent/config/config.go @@ -35,6 +35,7 @@ func (c *Config) Prune() { l.RawConfig = nil l.Profiling.UnusedKeys = nil l.Telemetry.UnusedKeys = nil + l.CustomResponseHeaders = nil } c.FoundKeys = nil c.UnusedKeys = nil @@ -172,6 +173,12 @@ func LoadConfig(path string) (*Config, error) { if err != nil { return nil, err } + + // Pruning custom headers for Agent for now + for _, ln := range sharedConfig.Listeners { + ln.CustomResponseHeaders = nil + } + result.SharedConfig = sharedConfig list, ok := obj.Node.(*ast.ObjectList) diff --git a/command/agent/config/config_test.go b/command/agent/config/config_test.go index 0db8cf919..252461236 100644 --- a/command/agent/config/config_test.go +++ b/command/agent/config/config_test.go @@ -536,7 +536,6 @@ func TestLoadConfigFile_AgentCache_PersistMissingType(t *testing.T) { } func TestLoadConfigFile_TemplateConfig(t *testing.T) { - testCases := map[string]struct { fixturePath string expectedTemplateConfig TemplateConfig @@ -586,7 +585,6 @@ func TestLoadConfigFile_TemplateConfig(t *testing.T) { } }) } - } // TestLoadConfigFile_Template tests template definitions in Vault Agent diff --git a/command/server.go b/command/server.go index 84814c410..718009b8c 100644 --- a/command/server.go +++ b/command/server.go @@ -1541,6 +1541,12 @@ func (c *ServerCommand) Run(args []string) int { core.SetConfig(config) + // reloading custom response headers to make sure we have + // the most up to date headers after reloading the config file + if err = core.ReloadCustomResponseHeaders(); err != nil { + c.logger.Error(err.Error()) + } + if config.LogLevel != "" { configLogLevel := strings.ToLower(strings.TrimSpace(config.LogLevel)) switch configLogLevel { diff --git a/command/server/config_custom_response_headers_test.go b/command/server/config_custom_response_headers_test.go new file mode 100644 index 000000000..5380568c2 --- /dev/null +++ b/command/server/config_custom_response_headers_test.go @@ -0,0 +1,109 @@ +package server + +import ( + "fmt" + "testing" + + "github.com/go-test/deep" +) + +var defaultCustomHeaders = map[string]string{ + "Strict-Transport-Security": "max-age=1; domains", + "Content-Security-Policy": "default-src 'others'", + "X-Vault-Ignored": "ignored", + "X-Custom-Header": "Custom header value default", +} + +var customHeaders307 = map[string]string{ + "X-Custom-Header": "Custom header value 307", +} + +var customHeader3xx = map[string]string{ + "X-Vault-Ignored-3xx": "Ignored 3xx", + "X-Custom-Header": "Custom header value 3xx", +} + +var customHeaders200 = map[string]string{ + "Someheader-200": "200", + "X-Custom-Header": "Custom header value 200", +} + +var customHeader2xx = map[string]string{ + "X-Custom-Header": "Custom header value 2xx", +} + +var customHeader400 = map[string]string{ + "Someheader-400": "400", +} + +var defaultCustomHeadersMultiListener = map[string]string{ + "Strict-Transport-Security": "max-age=31536000; includeSubDomains", + "Content-Security-Policy": "default-src 'others'", + "X-Vault-Ignored": "ignored", + "X-Custom-Header": "Custom header value default", +} + +var defaultSTS = map[string]string{ + "Strict-Transport-Security": "max-age=31536000; includeSubDomains", +} + +func TestCustomResponseHeadersConfigs(t *testing.T) { + expectedCustomResponseHeader := map[string]map[string]string{ + "default": defaultCustomHeaders, + "307": customHeaders307, + "3xx": customHeader3xx, + "200": customHeaders200, + "2xx": customHeader2xx, + "400": customHeader400, + } + + config, err := LoadConfigFile("./test-fixtures/config_custom_response_headers_1.hcl") + if err != nil { + t.Fatalf("Error encountered when loading config %+v", err) + } + if diff := deep.Equal(expectedCustomResponseHeader, config.Listeners[0].CustomResponseHeaders); diff != nil { + t.Fatalf(fmt.Sprintf("parsed custom headers do not match the expected ones, difference: %v", diff)) + } +} + +func TestCustomResponseHeadersConfigsMultipleListeners(t *testing.T) { + expectedCustomResponseHeader := map[string]map[string]string{ + "default": defaultCustomHeadersMultiListener, + "307": customHeaders307, + "3xx": customHeader3xx, + "200": customHeaders200, + "2xx": customHeader2xx, + "400": customHeader400, + } + + config, err := LoadConfigFile("./test-fixtures/config_custom_response_headers_multiple_listeners.hcl") + if err != nil { + t.Fatalf("Error encountered when loading config %+v", err) + } + if diff := deep.Equal(expectedCustomResponseHeader, config.Listeners[0].CustomResponseHeaders); diff != nil { + t.Fatalf(fmt.Sprintf("parsed custom headers do not match the expected ones, difference: %v", diff)) + } + + if diff := deep.Equal(expectedCustomResponseHeader, config.Listeners[1].CustomResponseHeaders); diff == nil { + t.Fatalf(fmt.Sprintf("parsed custom headers do not match the expected ones, difference: %v", diff)) + } + if diff := deep.Equal(expectedCustomResponseHeader["default"], config.Listeners[1].CustomResponseHeaders["default"]); diff != nil { + t.Fatalf(fmt.Sprintf("parsed custom headers do not match the expected ones, difference: %v", diff)) + } + + if diff := deep.Equal(expectedCustomResponseHeader, config.Listeners[2].CustomResponseHeaders); diff == nil { + t.Fatalf(fmt.Sprintf("parsed custom headers do not match the expected ones, difference: %v", diff)) + } + + if diff := deep.Equal(defaultSTS, config.Listeners[2].CustomResponseHeaders["default"]); diff != nil { + t.Fatalf(fmt.Sprintf("parsed custom headers do not match the expected ones, difference: %v", diff)) + } + + if diff := deep.Equal(expectedCustomResponseHeader, config.Listeners[3].CustomResponseHeaders); diff == nil { + t.Fatalf(fmt.Sprintf("parsed custom headers do not match the expected ones, difference: %v", diff)) + } + + if diff := deep.Equal(defaultSTS, config.Listeners[3].CustomResponseHeaders["default"]); diff != nil { + t.Fatalf(fmt.Sprintf("parsed custom headers do not match the expected ones, difference: %v", diff)) + } +} diff --git a/command/server/config_test_helpers.go b/command/server/config_test_helpers.go index e40f6c836..8936a0244 100644 --- a/command/server/config_test_helpers.go +++ b/command/server/config_test_helpers.go @@ -16,6 +16,12 @@ import ( "github.com/hashicorp/vault/internalshared/configutil" ) +var DefaultCustomHeaders = map[string]map[string]string { + "default": { + "Strict-Transport-Security": configutil.StrictTransportSecurity, + }, +} + func boolPointer(x bool) *bool { return &x } @@ -32,6 +38,7 @@ func testConfigRaftRetryJoin(t *testing.T) { { Type: "tcp", Address: "127.0.0.1:8200", + CustomResponseHeaders: DefaultCustomHeaders, }, }, DisableMlock: true, @@ -64,6 +71,7 @@ func testLoadConfigFile_topLevel(t *testing.T, entropy *configutil.Entropy) { { Type: "tcp", Address: "127.0.0.1:443", + CustomResponseHeaders: DefaultCustomHeaders, }, }, @@ -174,10 +182,12 @@ func testLoadConfigFile_json2(t *testing.T, entropy *configutil.Entropy) { { Type: "tcp", Address: "127.0.0.1:443", + CustomResponseHeaders: DefaultCustomHeaders, }, { Type: "tcp", Address: "127.0.0.1:444", + CustomResponseHeaders: DefaultCustomHeaders, }, }, @@ -336,6 +346,7 @@ func testLoadConfigFileIntegerAndBooleanValuesCommon(t *testing.T, path string) { Type: "tcp", Address: "127.0.0.1:8200", + CustomResponseHeaders: DefaultCustomHeaders, }, }, DisableMlock: true, @@ -379,6 +390,7 @@ func testLoadConfigFile(t *testing.T) { { Type: "tcp", Address: "127.0.0.1:443", + CustomResponseHeaders: DefaultCustomHeaders, }, }, @@ -486,7 +498,7 @@ func testUnknownFieldValidation(t *testing.T) { for _, er1 := range errors { found := false if strings.Contains(er1.String(), "sentinel") { - //This happens on OSS, and is fine + // This happens on OSS, and is fine continue } for _, ex := range expected { @@ -525,6 +537,7 @@ func testLoadConfigFile_json(t *testing.T) { { Type: "tcp", Address: "127.0.0.1:443", + CustomResponseHeaders: DefaultCustomHeaders, }, }, @@ -610,6 +623,7 @@ func testLoadConfigDir(t *testing.T) { { Type: "tcp", Address: "127.0.0.1:443", + CustomResponseHeaders: DefaultCustomHeaders, }, }, @@ -818,6 +832,7 @@ listener "tcp" { Profiling: configutil.ListenerProfiling{ UnauthenticatedPProfAccess: true, }, + CustomResponseHeaders: DefaultCustomHeaders, }, }, }, @@ -845,6 +860,7 @@ func testParseSeals(t *testing.T) { { Type: "tcp", Address: "127.0.0.1:443", + CustomResponseHeaders: DefaultCustomHeaders, }, }, Seals: []*configutil.KMS{ @@ -898,6 +914,7 @@ func testLoadConfigFileLeaseMetrics(t *testing.T) { { Type: "tcp", Address: "127.0.0.1:443", + CustomResponseHeaders: DefaultCustomHeaders, }, }, diff --git a/command/server/test-fixtures/config_custom_response_headers_1.hcl b/command/server/test-fixtures/config_custom_response_headers_1.hcl new file mode 100644 index 000000000..c2f868c2f --- /dev/null +++ b/command/server/test-fixtures/config_custom_response_headers_1.hcl @@ -0,0 +1,31 @@ +storage "inmem" {} +listener "tcp" { + address = "127.0.0.1:8200" + tls_disable = true + custom_response_headers { + "default" = { + "Strict-Transport-Security" = ["max-age=1","domains"], + "Content-Security-Policy" = ["default-src 'others'"], + "X-Vault-Ignored" = ["ignored"], + "X-Custom-Header" = ["Custom header value default"], + } + "307" = { + "X-Custom-Header" = ["Custom header value 307"], + } + "3xx" = { + "X-Vault-Ignored-3xx" = ["Ignored 3xx"], + "X-Custom-Header" = ["Custom header value 3xx"] + } + "200" = { + "someheader-200" = ["200"], + "X-Custom-Header" = ["Custom header value 200"] + } + "2xx" = { + "X-Custom-Header" = ["Custom header value 2xx"] + } + "400" = { + "someheader-400" = ["400"] + } + } +} +disable_mlock = true diff --git a/command/server/test-fixtures/config_custom_response_headers_multiple_listeners.hcl b/command/server/test-fixtures/config_custom_response_headers_multiple_listeners.hcl new file mode 100644 index 000000000..11aa09923 --- /dev/null +++ b/command/server/test-fixtures/config_custom_response_headers_multiple_listeners.hcl @@ -0,0 +1,56 @@ +storage "inmem" {} +listener "tcp" { + address = "127.0.0.1:8200" + tls_disable = true + custom_response_headers { + "default" = { + "Content-Security-Policy" = ["default-src 'others'"], + "X-Vault-Ignored" = ["ignored"], + "X-Custom-Header" = ["Custom header value default"], + } + "307" = { + "X-Custom-Header" = ["Custom header value 307"], + } + "3xx" = { + "X-Vault-Ignored-3xx" = ["Ignored 3xx"], + "X-Custom-Header" = ["Custom header value 3xx"] + } + "200" = { + "someheader-200" = ["200"], + "X-Custom-Header" = ["Custom header value 200"] + } + "2xx" = { + "X-Custom-Header" = ["Custom header value 2xx"] + } + "400" = { + "someheader-400" = ["400"] + } + } +} +listener "tcp" { + address = "127.0.0.2:8200" + tls_disable = true + custom_response_headers { + "default" = { + "Content-Security-Policy" = ["default-src 'others'"], + "X-Vault-Ignored" = ["ignored"], + "X-Custom-Header" = ["Custom header value default"], + } + } +} +listener "tcp" { + address = "127.0.0.3:8200" + tls_disable = true + custom_response_headers { + "2xx" = { + "X-Custom-Header" = ["Custom header value 2xx"] + } + } +} +listener "tcp" { + address = "127.0.0.4:8200" + tls_disable = true +} + + +disable_mlock = true diff --git a/http/custom_header_test.go b/http/custom_header_test.go new file mode 100644 index 000000000..5125050ad --- /dev/null +++ b/http/custom_header_test.go @@ -0,0 +1,128 @@ +package http + +import ( + "testing" + + "github.com/hashicorp/vault/vault" +) + +var defaultCustomHeaders = map[string]string { + "Strict-Transport-Security": "max-age=1; domains", + "Content-Security-Policy": "default-src 'others'", + "X-Custom-Header": "Custom header value default", + "X-Frame-Options": "Deny", + "X-Content-Type-Options": "nosniff", + "Content-Type": "application/json", + "X-XSS-Protection": "1; mode=block", +} + +var customHeader2xx = map[string]string { + "X-Custom-Header": "Custom header value 2xx", +} + +var customHeader200 = map[string]string { + "Someheader-200": "200", + "X-Custom-Header": "Custom header value 200", +} + +var customHeader4xx = map[string]string { + "Someheader-4xx": "4xx", +} + +var customHeader400 = map[string]string { + "Someheader-400": "400", +} + +var customHeader405 = map[string]string { + "Someheader-405": "405", +} + +var CustomResponseHeaders = map[string]map[string]string{ + "default": defaultCustomHeaders, + "307": {"X-Custom-Header": "Custom header value 307"}, + "3xx": { + "X-Custom-Header": "Custom header value 3xx", + "X-Vault-Ignored-3xx": "Ignored 3xx", + }, + "200": customHeader200, + "2xx": customHeader2xx, + "400": customHeader400, + "405": customHeader405, + "4xx": customHeader4xx, +} + +func TestCustomResponseHeaders(t *testing.T) { + core, _, token := vault.TestCoreWithCustomResponseHeaderAndUI(t, CustomResponseHeaders, true) + ln, addr := TestServer(t, core) + defer ln.Close() + TestServerAuth(t, addr, token) + + resp := testHttpGet(t, token, addr+"/v1/sys/raw/") + testResponseStatus(t, resp, 404) + testResponseHeader(t, resp, defaultCustomHeaders) + testResponseHeader(t, resp, customHeader4xx) + + resp = testHttpGet(t, token, addr+"/v1/sys/seal") + testResponseStatus(t, resp, 405) + testResponseHeader(t, resp, defaultCustomHeaders) + testResponseHeader(t, resp, customHeader4xx) + testResponseHeader(t, resp, customHeader405) + + resp = testHttpGet(t, token, addr+"/v1/sys/leader") + testResponseStatus(t, resp, 200) + testResponseHeader(t, resp, customHeader200) + + resp = testHttpGet(t, token, addr+"/v1/sys/health") + testResponseStatus(t, resp, 200) + testResponseHeader(t, resp, customHeader200) + + resp = testHttpGet(t, token, addr+"/v1/sys/generate-root/attempt") + testResponseStatus(t, resp, 200) + testResponseHeader(t, resp, customHeader200) + + resp = testHttpGet(t, token, addr+"/v1/sys/generate-root/update") + testResponseStatus(t, resp, 400) + testResponseHeader(t, resp, defaultCustomHeaders) + testResponseHeader(t, resp, customHeader4xx) + testResponseHeader(t, resp, customHeader400) + + resp = testHttpGet(t, token, addr+"/v1/sys/") + testResponseStatus(t, resp, 404) + testResponseHeader(t, resp, defaultCustomHeaders) + testResponseHeader(t, resp, customHeader4xx) + + resp = testHttpGet(t, token, addr+"/v1/sys") + testResponseStatus(t, resp, 404) + testResponseHeader(t, resp, defaultCustomHeaders) + testResponseHeader(t, resp, customHeader4xx) + + resp = testHttpGet(t, token, addr+"/v1/") + testResponseStatus(t, resp, 404) + testResponseHeader(t, resp, defaultCustomHeaders) + testResponseHeader(t, resp, customHeader4xx) + + resp = testHttpGet(t, token, addr+"/v1") + testResponseStatus(t, resp, 404) + testResponseHeader(t, resp, defaultCustomHeaders) + testResponseHeader(t, resp, customHeader4xx) + + resp = testHttpGet(t, token, addr+"/") + testResponseStatus(t, resp, 200) + testResponseHeader(t, resp, customHeader200) + + resp = testHttpGet(t, token, addr+"/ui") + testResponseStatus(t, resp, 200) + testResponseHeader(t, resp, customHeader200) + + resp = testHttpGet(t, token, addr+"/ui/") + testResponseStatus(t, resp, 200) + testResponseHeader(t, resp, customHeader200) + + resp = testHttpPost(t, token, addr+"/v1/sys/auth/foo", map[string]interface{}{ + "type": "noop", + "description": "foo", + }) + testResponseStatus(t, resp, 204) + testResponseHeader(t, resp, customHeader2xx) + +} \ No newline at end of file diff --git a/http/handler.go b/http/handler.go index 7d48f97ae..22aab6ccb 100644 --- a/http/handler.go +++ b/http/handler.go @@ -16,12 +16,14 @@ import ( "net/textproto" "net/url" "os" + "strconv" "strings" "time" "github.com/NYTimes/gziphandler" "github.com/hashicorp/errwrap" "github.com/hashicorp/go-cleanhttp" + log "github.com/hashicorp/go-hclog" "github.com/hashicorp/go-secure-stdlib/parseutil" "github.com/hashicorp/go-sockaddr" "github.com/hashicorp/vault/helper/namespace" @@ -210,6 +212,90 @@ func Handler(props *vault.HandlerProperties) http.Handler { return printablePathCheckHandler } +type WrappingResponseWriter interface { + http.ResponseWriter + Wrapped() http.ResponseWriter +} + +type statusHeaderResponseWriter struct { + wrapped http.ResponseWriter + logger log.Logger + wroteHeader bool + statusCode int + headers map[string][]*vault.CustomHeader +} + +func (w *statusHeaderResponseWriter) Wrapped() http.ResponseWriter { + return w.wrapped +} + +func (w *statusHeaderResponseWriter) Header() http.Header { + return w.wrapped.Header() +} + +func (w *statusHeaderResponseWriter) Write(buf []byte) (int, error) { + // It is allowed to only call ResponseWriter.Write and skip + // ResponseWriter.WriteHeader. An example of such a situation is + // "handleUIStub". The Write function will internally set the status code + // 200 for the response for which that call might invoke other + // implementations of the WriteHeader function. So, we still need to set + // the custom headers. In cases where both WriteHeader and Write of + // statusHeaderResponseWriter struct are called the internal call to the + // WriterHeader invoked from inside Write method won't change the headers. + if !w.wroteHeader { + w.setCustomResponseHeaders(w.statusCode) + } + + return w.wrapped.Write(buf) +} + +func (w *statusHeaderResponseWriter) WriteHeader(statusCode int) { + w.setCustomResponseHeaders(statusCode) + w.wrapped.WriteHeader(statusCode) + w.statusCode = statusCode + // in cases where Write is called after WriteHeader, let's prevent setting + // ResponseWriter headers twice + w.wroteHeader = true +} + +func (w *statusHeaderResponseWriter) setCustomResponseHeaders(status int) { + sch := w.headers + if sch == nil { + w.logger.Warn("status code header map not configured") + return + } + + // Checking the validity of the status code + if status >= 600 || status < 100 { + return + } + + // setter function to set the headers + setter := func(hvl []*vault.CustomHeader) { + for _, hv := range hvl { + w.Header().Set(hv.Name, hv.Value) + } + } + + // Setting the default headers first + setter(sch["default"]) + + // setting the Xyy pattern first + d := fmt.Sprintf("%vxx", status/100) + if val, ok := sch[d]; ok { + setter(val) + } + + // Setting the specific headers + if val, ok := sch[strconv.Itoa(status)]; ok { + setter(val) + } + + return +} + +var _ WrappingResponseWriter = &statusHeaderResponseWriter{} + type copyResponseWriter struct { wrapped http.ResponseWriter statusCode int @@ -300,6 +386,22 @@ func wrapGenericHandler(core *vault.Core, h http.Handler, props *vault.HandlerPr hostname, _ := os.Hostname() return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // This block needs to be here so that upon sending SIGHUP, custom response + // headers are also reloaded into the handlers. + if props.ListenerConfig != nil { + la := props.ListenerConfig.Address + listenerCustomHeaders := core.GetListenerCustomResponseHeaders(la) + if listenerCustomHeaders != nil { + w = &statusHeaderResponseWriter{ + wrapped: w, + logger: core.Logger(), + wroteHeader: false, + statusCode: 200, + headers: listenerCustomHeaders.StatusCodeHeaderMap, + } + } + } + // Set the Cache-Control header for all the responses returned // by Vault w.Header().Set("Cache-Control", "no-store") @@ -632,7 +734,15 @@ func parseJSONRequest(perfStandby bool, r *http.Request, w http.ResponseWriter, return nil, errors.New("could not parse max_request_size from request context") } if max > 0 { - reader = http.MaxBytesReader(w, r.Body, max) + // MaxBytesReader won't do all the internal stuff it must unless it's + // given a ResponseWriter that implements the internal http interface + // requestTooLarger. So we let it have access to the underlying + // ResponseWriter. + inw := w + if myw, ok := inw.(WrappingResponseWriter); ok { + inw = myw.Wrapped() + } + reader = http.MaxBytesReader(inw, r.Body, max) } } var origBody io.ReadWriter diff --git a/http/http_test.go b/http/http_test.go index e37b9c3d7..692aef0d8 100644 --- a/http/http_test.go +++ b/http/http_test.go @@ -125,6 +125,16 @@ func testResponseStatus(t *testing.T, resp *http.Response, code int) { } } +func testResponseHeader(t *testing.T, resp *http.Response, expectedHeaders map[string]string) { + t.Helper() + for k, v := range expectedHeaders { + hv := resp.Header.Get(k) + if v != hv { + t.Fatalf("expected header value %v=%v, got %v=%v", k, v, k, hv) + } + } +} + func testResponseBody(t *testing.T, resp *http.Response, out interface{}) { defer resp.Body.Close() diff --git a/http/sys_metrics.go b/http/sys_metrics.go index 0e58be3ea..012417282 100644 --- a/http/sys_metrics.go +++ b/http/sys_metrics.go @@ -35,12 +35,14 @@ func handleMetricsUnauthenticated(core *vault.Core) http.Handler { resp := core.MetricsHelper().ResponseForFormat(format) // Manually extract the logical response and send back the information - w.WriteHeader(resp.Data[logical.HTTPStatusCode].(int)) + status := resp.Data[logical.HTTPStatusCode].(int) w.Header().Set("Content-Type", resp.Data[logical.HTTPContentType].(string)) switch v := resp.Data[logical.HTTPRawBody].(type) { case string: + w.WriteHeader(status) w.Write([]byte(v)) case []byte: + w.WriteHeader(status) w.Write(v) default: respondError(w, http.StatusInternalServerError, fmt.Errorf("wrong response returned")) diff --git a/http/testing.go b/http/testing.go index be9569dc9..84ab73fc0 100644 --- a/http/testing.go +++ b/http/testing.go @@ -6,6 +6,7 @@ import ( "net/http" "testing" + "github.com/hashicorp/vault/internalshared/configutil" "github.com/hashicorp/vault/vault" ) @@ -41,10 +42,16 @@ func TestServerWithListenerAndProperties(tb testing.TB, ln net.Listener, addr st } func TestServerWithListener(tb testing.TB, ln net.Listener, addr string, core *vault.Core) { + ip, _, _ := net.SplitHostPort(ln.Addr().String()) + // Create a muxer to handle our requests so that we can authenticate // for tests. props := &vault.HandlerProperties{ Core: core, + // This is needed for testing custom response headers + ListenerConfig: &configutil.Listener { + Address: ip, + }, } TestServerWithListenerAndProperties(tb, ln, addr, core, props) } diff --git a/internalshared/configutil/http_response_headers.go b/internalshared/configutil/http_response_headers.go new file mode 100644 index 000000000..2db3034e5 --- /dev/null +++ b/internalshared/configutil/http_response_headers.go @@ -0,0 +1,129 @@ +package configutil + +import ( + "fmt" + "net/textproto" + "strconv" + "strings" + + "github.com/hashicorp/go-secure-stdlib/strutil" +) + +var ValidCustomStatusCodeCollection = []string{ + "default", + "1xx", + "2xx", + "3xx", + "4xx", + "5xx", +} + +const StrictTransportSecurity = "max-age=31536000; includeSubDomains" + +// ParseCustomResponseHeaders takes a raw config values for the +// "custom_response_headers". It makes sure the config entry is passed in +// as a map of status code to a map of header name and header values. It +// verifies the validity of the status codes, and header values. It also +// adds the default headers values. +func ParseCustomResponseHeaders(responseHeaders interface{}) (map[string]map[string]string, error) { + h := make(map[string]map[string]string) + // if r is nil, we still should set the default custom headers + if responseHeaders == nil { + h["default"] = map[string]string{"Strict-Transport-Security": StrictTransportSecurity} + return h, nil + } + + customResponseHeader, ok := responseHeaders.([]map[string]interface{}) + if !ok { + return nil, fmt.Errorf("response headers were not configured correctly. please make sure they're in a slice of maps") + } + + for _, crh := range customResponseHeader { + for statusCode, responseHeader := range crh { + headerValList, ok := responseHeader.([]map[string]interface{}) + if !ok { + return nil, fmt.Errorf("response headers were not configured correctly. please make sure they're in a slice of maps") + } + + if !IsValidStatusCode(statusCode) { + return nil, fmt.Errorf("invalid status code found in the server configuration: %v", statusCode) + } + + if len(headerValList) != 1 { + return nil, fmt.Errorf("invalid number of response headers exist") + } + headerValMap := headerValList[0] + headerVal, err := parseHeaders(headerValMap) + if err != nil { + return nil, err + } + + h[statusCode] = headerVal + } + } + + // setting Strict-Transport-Security as a default header + if h["default"] == nil { + h["default"] = make(map[string]string) + } + if _, ok := h["default"]["Strict-Transport-Security"]; !ok { + h["default"]["Strict-Transport-Security"] = StrictTransportSecurity + } + + return h, nil +} + +// IsValidStatusCode checking for status codes outside the boundary +func IsValidStatusCode(sc string) bool { + if strutil.StrListContains(ValidCustomStatusCodeCollection, sc) { + return true + } + + i, err := strconv.Atoi(sc) + if err != nil { + return false + } + + if i >= 600 || i < 100 { + return false + } + + return true +} + +func parseHeaders(in map[string]interface{}) (map[string]string, error) { + hvMap := make(map[string]string) + for k, v := range in { + // parsing header name + headerName := textproto.CanonicalMIMEHeaderKey(k) + // parsing header values + s, err := parseHeaderValues(v) + if err != nil { + return nil, err + } + hvMap[headerName] = s + } + return hvMap, nil +} + +func parseHeaderValues(header interface{}) (string, error) { + var sl []string + if _, ok := header.([]interface{}); !ok { + return "", fmt.Errorf("headers must be given in a list of strings") + } + headerValList := header.([]interface{}) + for _, vh := range headerValList { + if _, ok := vh.(string); !ok { + return "", fmt.Errorf("found a non-string header value: %v", vh) + } + headerVal := strings.TrimSpace(vh.(string)) + if headerVal == "" { + continue + } + sl = append(sl, headerVal) + + } + s := strings.Join(sl, "; ") + + return s, nil +} diff --git a/internalshared/configutil/listener.go b/internalshared/configutil/listener.go index 981990828..677dbf9df 100644 --- a/internalshared/configutil/listener.go +++ b/internalshared/configutil/listener.go @@ -99,6 +99,10 @@ type Listener struct { CorsAllowedOrigins []string `hcl:"cors_allowed_origins"` CorsAllowedHeaders []string `hcl:"-"` CorsAllowedHeadersRaw []string `hcl:"cors_allowed_headers,alias:cors_allowed_headers"` + + // Custom Http response headers + CustomResponseHeaders map[string]map[string]string `hcl:"-"` + CustomResponseHeadersRaw interface{} `hcl:"custom_response_headers"` } func (l *Listener) GoString() string { @@ -361,6 +365,17 @@ func ParseListeners(result *SharedConfig, list *ast.ObjectList) error { } } + // HTTP Headers + { + // if CustomResponseHeadersRaw is nil, we still need to set the default headers + customHeadersMap, err := ParseCustomResponseHeaders(l.CustomResponseHeadersRaw) + if err != nil { + return multierror.Prefix(fmt.Errorf("failed to parse custom_response_headers: %w", err), fmt.Sprintf("listeners.%d", i)) + } + l.CustomResponseHeaders = customHeadersMap + l.CustomResponseHeadersRaw = nil + } + result.Listeners = append(result.Listeners, &l) } diff --git a/vault/core.go b/vault/core.go index ddd2a2cf4..39f7b6da7 100644 --- a/vault/core.go +++ b/vault/core.go @@ -510,6 +510,9 @@ type Core struct { // clusterListener starts up and manages connections on the cluster ports clusterListener *atomic.Value + // customListenerHeader holds custom response headers for a listener + customListenerHeader *atomic.Value + // Telemetry objects metricsHelper *metricsutil.MetricsHelper @@ -769,23 +772,24 @@ func CreateCore(conf *CoreConfig) (*Core, error) { // Setup the core c := &Core{ - entCore: entCore{}, - devToken: conf.DevToken, - physical: conf.Physical, - serviceRegistration: conf.GetServiceRegistration(), - underlyingPhysical: conf.Physical, - storageType: conf.StorageType, - redirectAddr: conf.RedirectAddr, - clusterAddr: new(atomic.Value), - clusterListener: new(atomic.Value), - seal: conf.Seal, - router: NewRouter(), - sealed: new(uint32), - sealMigrationDone: new(uint32), - standby: true, - standbyStopCh: new(atomic.Value), - baseLogger: conf.Logger, - logger: conf.Logger.Named("core"), + entCore: entCore{}, + devToken: conf.DevToken, + physical: conf.Physical, + serviceRegistration: conf.GetServiceRegistration(), + underlyingPhysical: conf.Physical, + storageType: conf.StorageType, + redirectAddr: conf.RedirectAddr, + clusterAddr: new(atomic.Value), + clusterListener: new(atomic.Value), + customListenerHeader: new(atomic.Value), + seal: conf.Seal, + router: NewRouter(), + sealed: new(uint32), + sealMigrationDone: new(uint32), + standby: true, + standbyStopCh: new(atomic.Value), + baseLogger: conf.Logger, + logger: conf.Logger.Named("core"), defaultLeaseTTL: conf.DefaultLeaseTTL, maxLeaseTTL: conf.MaxLeaseTTL, @@ -1005,6 +1009,17 @@ func NewCore(conf *CoreConfig) (*Core, error) { c.clusterListener.Store((*cluster.Listener)(nil)) + // for listeners with custom response headers, configuring customListenerHeader + if conf.RawConfig.Listeners != nil { + uiHeaders, err := c.UIHeaders() + if err != nil { + return nil, err + } + c.customListenerHeader.Store(NewListenerCustomHeader(conf.RawConfig.Listeners, c.logger, uiHeaders)) + } else { + c.customListenerHeader.Store(([]*ListenerCustomHeaders)(nil)) + } + quotasLogger := conf.Logger.Named("quotas") c.allLoggers = append(c.allLoggers, quotasLogger) c.quotaManager, err = quotas.NewManager(quotasLogger, c.quotaLeaseWalker, c.metricSink) @@ -2641,6 +2656,68 @@ func (c *Core) SetConfig(conf *server.Config) { c.logger.Debug("set config", "sanitized config", string(bz)) } +func (c *Core) GetListenerCustomResponseHeaders(listenerAdd string) *ListenerCustomHeaders { + + customHeaders := c.customListenerHeader.Load() + if customHeaders == nil { + return nil + } + + customHeadersList, ok := customHeaders.([]*ListenerCustomHeaders) + if customHeadersList == nil || !ok { + return nil + } + + for _, l := range customHeadersList { + if l.Address == listenerAdd { + return l + } + } + return nil +} + +// ExistCustomResponseHeader checks if a custom header is configured in any +// listener's stanza +func (c *Core) ExistCustomResponseHeader(header string) bool { + customHeaders := c.customListenerHeader.Load() + if customHeaders == nil { + return false + } + + customHeadersList, ok := customHeaders.([]*ListenerCustomHeaders) + if customHeadersList == nil || !ok { + return false + } + + for _, l := range customHeadersList { + exist := l.ExistCustomResponseHeader(header) + if exist { + return true + } + } + + return false +} + +func (c *Core) ReloadCustomResponseHeaders() error { + conf := c.rawConfig.Load() + if conf == nil { + return fmt.Errorf("failed to load core raw config") + } + lns := conf.(*server.Config).Listeners + if lns == nil { + return fmt.Errorf("no listener configured") + } + + uiHeaders, err := c.UIHeaders() + if err != nil { + return err + } + c.customListenerHeader.Store(NewListenerCustomHeader(lns, c.logger, uiHeaders)) + + return nil +} + // SanitizedConfig returns a sanitized version of the current config. // See server.Config.Sanitized for specific values omitted. func (c *Core) SanitizedConfig() map[string]interface{} { diff --git a/vault/custom_response_headers.go b/vault/custom_response_headers.go new file mode 100644 index 000000000..54df08954 --- /dev/null +++ b/vault/custom_response_headers.go @@ -0,0 +1,90 @@ +package vault + +import ( + "net/http" + "net/textproto" + "strings" + + log "github.com/hashicorp/go-hclog" + "github.com/hashicorp/vault/internalshared/configutil" +) + +type ListenerCustomHeaders struct { + Address string + StatusCodeHeaderMap map[string][]*CustomHeader + // ConfiguredHeadersStatusCodeMap field is introduced so that we would not need to loop through + // StatusCodeHeaderMap to see if a header exists, the key for this map is the headers names + configuredHeadersStatusCodeMap map[string][]string +} + +type CustomHeader struct { + Name string + Value string +} + +func NewListenerCustomHeader(ln []*configutil.Listener, logger log.Logger, uiHeaders http.Header) []*ListenerCustomHeaders { + var listenerCustomHeadersList []*ListenerCustomHeaders + + for _, l := range ln { + listenerCustomHeaderStruct := &ListenerCustomHeaders{ + Address: l.Address, + } + listenerCustomHeaderStruct.StatusCodeHeaderMap = make(map[string][]*CustomHeader) + listenerCustomHeaderStruct.configuredHeadersStatusCodeMap = make(map[string][]string) + for statusCode, headerValMap := range l.CustomResponseHeaders { + var customHeaderList []*CustomHeader + for headerName, headerVal := range headerValMap { + // Sanitizing custom headers + // X-Vault- prefix is reserved for Vault internal processes + if strings.HasPrefix(headerName, "X-Vault-") { + logger.Warn("custom headers starting with X-Vault are not valid", "header", headerName) + continue + } + + // Checking for UI headers, if any common header exists, we just log an error + if uiHeaders != nil { + exist := uiHeaders.Get(headerName) + if exist != "" { + logger.Warn("found a duplicate header in UI", "header:", headerName, "Headers defined in the server configuration take precedence.") + } + } + + // Checking if the header value is not an empty string + if headerVal == "" { + logger.Warn("header value is an empty string", "header", headerName, "value", headerVal) + continue + } + + ch := &CustomHeader{ + Name: headerName, + Value: headerVal, + } + + customHeaderList = append(customHeaderList, ch) + + // setting up the reverse map of header to status code for easy lookups + listenerCustomHeaderStruct.configuredHeadersStatusCodeMap[headerName] = append(listenerCustomHeaderStruct.configuredHeadersStatusCodeMap[headerName], statusCode) + } + listenerCustomHeaderStruct.StatusCodeHeaderMap[statusCode] = customHeaderList + } + listenerCustomHeadersList = append(listenerCustomHeadersList, listenerCustomHeaderStruct) + } + + return listenerCustomHeadersList +} + +func (l *ListenerCustomHeaders) ExistCustomResponseHeader(header string) bool { + if header == "" { + return false + } + + if l.StatusCodeHeaderMap == nil { + return false + } + + headerName := textproto.CanonicalMIMEHeaderKey(header) + + headerMap := l.configuredHeadersStatusCodeMap + _, ok := headerMap[headerName] + return ok +} diff --git a/vault/custom_response_headers_test.go b/vault/custom_response_headers_test.go new file mode 100644 index 000000000..1ea79cf7e --- /dev/null +++ b/vault/custom_response_headers_test.go @@ -0,0 +1,174 @@ +package vault + +import ( + "context" + "fmt" + "net/http/httptest" + "strings" + "testing" + + log "github.com/hashicorp/go-hclog" + "github.com/hashicorp/vault/helper/namespace" + "github.com/hashicorp/vault/internalshared/configutil" + "github.com/hashicorp/vault/sdk/helper/logging" + "github.com/hashicorp/vault/sdk/logical" + "github.com/hashicorp/vault/sdk/physical/inmem" +) + +var defaultCustomHeaders = map[string]string{ + "Strict-Transport-Security": "max-age=1; domains", + "Content-Security-Policy": "default-src 'others'", + "X-Vault-Ignored": "ignored", + "X-Custom-Header": "Custom header value default", + "X-Frame-Options": "Deny", + "X-Content-Type-Options": "nosniff", + "Content-Type": "text/plain; charset=utf-8", + "X-XSS-Protection": "1; mode=block", +} + +var customHeaders307 = map[string]string{ + "X-Custom-Header": "Custom header value 307", +} + +var customHeader3xx = map[string]string{ + "X-Vault-Ignored-3xx": "Ignored 3xx", + "X-Custom-Header": "Custom header value 3xx", +} + +var customHeaders200 = map[string]string{ + "Someheader-200": "200", + "X-Custom-Header": "Custom header value 200", +} + +var customHeader2xx = map[string]string{ + "X-Custom-Header": "Custom header value 2xx", +} + +var customHeader400 = map[string]string{ + "Someheader-400": "400", +} + +func TestConfigCustomHeaders(t *testing.T) { + logger := logging.NewVaultLogger(log.Trace) + phys, err := inmem.NewTransactionalInmem(nil, logger) + if err != nil { + t.Fatal(err) + } + logl := &logical.InmemStorage{} + uiConfig := NewUIConfig(true, phys, logl) + + rawListenerConfig := []*configutil.Listener{ + { + Type: "tcp", + Address: "127.0.0.1:443", + CustomResponseHeaders: map[string]map[string]string{ + "default": defaultCustomHeaders, + "307": customHeaders307, + "3xx": customHeader3xx, + "200": customHeaders200, + "2xx": customHeader2xx, + "400": customHeader400, + }, + }, + } + + uiHeaders, err := uiConfig.Headers(context.Background()) + listenerCustomHeaders := NewListenerCustomHeader(rawListenerConfig, logger, uiHeaders) + if listenerCustomHeaders == nil || len(listenerCustomHeaders) != 1 { + t.Fatalf("failed to get custom header configuration") + } + + lch := listenerCustomHeaders[0] + + if lch.ExistCustomResponseHeader("X-Vault-Ignored-307") { + t.Fatalf("header name with X-Vault prefix is not valid") + } + if lch.ExistCustomResponseHeader("X-Vault-Ignored-3xx") { + t.Fatalf("header name with X-Vault prefix is not valid") + } + + if !lch.ExistCustomResponseHeader("X-Custom-Header") { + t.Fatalf("header name with X-Vault prefix is not valid") + } +} + +func TestCustomResponseHeadersConfigInteractUiConfig(t *testing.T) { + b := testSystemBackend(t) + _, barrier, _ := mockBarrier(t) + view := NewBarrierView(barrier, "") + b.(*SystemBackend).Core.systemBarrierView = view + + logger := logging.NewVaultLogger(log.Trace) + rawListenerConfig := []*configutil.Listener{ + { + Type: "tcp", + Address: "127.0.0.1:443", + CustomResponseHeaders: map[string]map[string]string{ + "default": defaultCustomHeaders, + "307": customHeaders307, + "3xx": customHeader3xx, + "200": customHeaders200, + "2xx": customHeader2xx, + "400": customHeader400, + }, + }, + } + uiHeaders, err := b.(*SystemBackend).Core.uiConfig.Headers(context.Background()) + if err != nil { + t.Fatalf("failed to get headers from ui config") + } + customListenerHeader := NewListenerCustomHeader(rawListenerConfig, logger, uiHeaders) + if customListenerHeader == nil { + t.Fatalf("custom header config should be configured") + } + b.(*SystemBackend).Core.customListenerHeader.Store(customListenerHeader) + clh := b.(*SystemBackend).Core.customListenerHeader + if clh == nil { + t.Fatalf("custom header config should be configured in core") + } + + w := httptest.NewRecorder() + hw := logical.NewHTTPResponseWriter(w) + + // setting a header that already exist in custom headers + req := logical.TestRequest(t, logical.UpdateOperation, "config/ui/headers/X-Custom-Header") + req.Data["values"] = []string{"UI Custom Header"} + req.ResponseWriter = hw + + resp, err := b.HandleRequest(namespace.RootContext(nil), req) + if err == nil { + t.Fatal("request did not fail on setting a header that is present in custom response headers") + } + if !strings.Contains(resp.Data["error"].(string), fmt.Sprintf("This header already exists in the server configuration and cannot be set in the UI.")) { + t.Fatalf("failed to get the expected error") + } + + // setting a header that already exist in custom headers + req = logical.TestRequest(t, logical.UpdateOperation, "config/ui/headers/Someheader-400") + req.Data["values"] = []string{"400"} + req.ResponseWriter = hw + + _, err = b.HandleRequest(namespace.RootContext(nil), req) + if err == nil { + t.Fatal("request did not fail on setting a header that is present in custom response headers") + } + h, err := b.(*SystemBackend).Core.uiConfig.Headers(context.Background()) + if h.Get("Someheader-400") == "400" { + t.Fatalf("should not be able to set a header that is in custom response headers") + } + + // setting an ui specific header + req = logical.TestRequest(t, logical.UpdateOperation, "config/ui/headers/X-CustomUiHeader") + req.Data["values"] = []string{"Ui header value"} + req.ResponseWriter = hw + + _, err = b.HandleRequest(namespace.RootContext(nil), req) + if err != nil { + t.Fatal("request failed on setting a header that is not present in custom response headers.", "error:", err) + } + + h, err = b.(*SystemBackend).Core.uiConfig.Headers(context.Background()) + if h.Get("X-CustomUiHeader") != "Ui header value" { + t.Fatalf("failed to set a header that is not in custom response headers") + } +} diff --git a/vault/logical_system.go b/vault/logical_system.go index 617928781..1675475a4 100644 --- a/vault/logical_system.go +++ b/vault/logical_system.go @@ -2623,6 +2623,9 @@ func (b *SystemBackend) handleConfigUIHeadersUpdate(ctx context.Context, req *lo // Translate the list of values to the valid header string value := http.Header{} for _, v := range values { + if b.Core.ExistCustomResponseHeader(header) { + return logical.ErrorResponse("This header already exists in the server configuration and cannot be set in the UI."), logical.ErrInvalidRequest + } value.Add(header, v) } err := b.Core.uiConfig.SetHeader(ctx, header, value.Values(header)) diff --git a/vault/testing.go b/vault/testing.go index dbe921969..6ba933e85 100644 --- a/vault/testing.go +++ b/vault/testing.go @@ -128,6 +128,29 @@ func TestCoreWithSeal(t testing.T, testSeal Seal, enableRaw bool) *Core { return TestCoreWithSealAndUI(t, conf) } +func TestCoreWithCustomResponseHeaderAndUI(t testing.T, CustomResponseHeaders map[string]map[string]string, enableUI bool) (*Core, [][]byte, string) { + confRaw := &server.Config{ + SharedConfig: &configutil.SharedConfig{ + Listeners: []*configutil.Listener{ + { + Type: "tcp", + Address: "127.0.0.1", + CustomResponseHeaders: CustomResponseHeaders, + }, + }, + DisableMlock: true, + }, + } + conf := &CoreConfig{ + RawConfig: confRaw, + EnableUI: enableUI, + EnableRaw: true, + BuiltinRegistry: NewMockBuiltinRegistry(), + } + core := TestCoreWithSealAndUI(t, conf) + return testCoreUnsealed(t, core) +} + func TestCoreUI(t testing.T, enableUI bool) *Core { conf := &CoreConfig{ EnableUI: enableUI, diff --git a/vault/ui.go b/vault/ui.go index c36a247af..bd1d3c688 100644 --- a/vault/ui.go +++ b/vault/ui.go @@ -32,8 +32,9 @@ type UIConfig struct { // NewUIConfig creates a new UI config func NewUIConfig(enabled bool, physicalStorage physical.Backend, barrierStorage logical.Storage) *UIConfig { defaultHeaders := http.Header{} - defaultHeaders.Set("Content-Security-Policy", "default-src 'none'; connect-src 'self'; img-src 'self' data:; script-src 'self'; style-src 'unsafe-inline' 'self'; form-action 'none'; frame-ancestors 'none'; font-src 'self'") defaultHeaders.Set("Service-Worker-Allowed", "/") + defaultHeaders.Set("X-Content-Type-Options", "nosniff") + defaultHeaders.Set("Content-Security-Policy", "default-src 'none'; connect-src 'self'; img-src 'self' data:; script-src 'self'; style-src 'unsafe-inline' 'self'; form-action 'none'; frame-ancestors 'none'; font-src 'self'") return &UIConfig{ physicalStorage: physicalStorage,