Customizing HTTP headers in the config file (#12485)

* Customizing HTTP headers in the config file

* Add changelog, fix bad imports

* fixing some bugs

* fixing interaction of custom headers and /ui

* Defining a member in core to set custom response headers

* missing additional file

* Some refactoring

* Adding automated tests for the feature

* Changing some error messages based on some recommendations

* Incorporating custom response headers struct into the request context

* removing some unused references

* fixing a test

* changing some error messages, removing a default header value from /ui

* fixing a test

* wrapping ResponseWriter to set the custom headers

* adding a new test

* some cleanup

* removing some extra lines

* Addressing comments

* fixing some agent tests

* skipping custom headers from agent listener config,
removing two of the default headers as they cause issues with Vault in UI mode
Adding X-Content-Type-Options to the ui default headers
Let Content-Type be set as before

* Removing default custom headers, and renaming some function varibles

* some refacotring

* Refactoring and addressing comments

* removing a function and fixing comments
This commit is contained in:
hghaf099 2021-10-13 11:06:33 -04:00 committed by GitHub
parent ce0091f5ee
commit ad2ef412cc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 1019 additions and 23 deletions

3
changelog/12485.txt Normal file
View File

@ -0,0 +1,3 @@
```release-note:feature
**Customizable HTTP Headers**: Add support to define custom HTTP headers for root path (`/`) and also on API endpoints (`/v1/*`)
```

View File

@ -35,6 +35,7 @@ func (c *Config) Prune() {
l.RawConfig = nil
l.Profiling.UnusedKeys = nil
l.Telemetry.UnusedKeys = nil
l.CustomResponseHeaders = nil
}
c.FoundKeys = nil
c.UnusedKeys = nil
@ -172,6 +173,12 @@ func LoadConfig(path string) (*Config, error) {
if err != nil {
return nil, err
}
// Pruning custom headers for Agent for now
for _, ln := range sharedConfig.Listeners {
ln.CustomResponseHeaders = nil
}
result.SharedConfig = sharedConfig
list, ok := obj.Node.(*ast.ObjectList)

View File

@ -536,7 +536,6 @@ func TestLoadConfigFile_AgentCache_PersistMissingType(t *testing.T) {
}
func TestLoadConfigFile_TemplateConfig(t *testing.T) {
testCases := map[string]struct {
fixturePath string
expectedTemplateConfig TemplateConfig
@ -586,7 +585,6 @@ func TestLoadConfigFile_TemplateConfig(t *testing.T) {
}
})
}
}
// TestLoadConfigFile_Template tests template definitions in Vault Agent

View File

@ -1541,6 +1541,12 @@ func (c *ServerCommand) Run(args []string) int {
core.SetConfig(config)
// reloading custom response headers to make sure we have
// the most up to date headers after reloading the config file
if err = core.ReloadCustomResponseHeaders(); err != nil {
c.logger.Error(err.Error())
}
if config.LogLevel != "" {
configLogLevel := strings.ToLower(strings.TrimSpace(config.LogLevel))
switch configLogLevel {

View File

@ -0,0 +1,109 @@
package server
import (
"fmt"
"testing"
"github.com/go-test/deep"
)
var defaultCustomHeaders = map[string]string{
"Strict-Transport-Security": "max-age=1; domains",
"Content-Security-Policy": "default-src 'others'",
"X-Vault-Ignored": "ignored",
"X-Custom-Header": "Custom header value default",
}
var customHeaders307 = map[string]string{
"X-Custom-Header": "Custom header value 307",
}
var customHeader3xx = map[string]string{
"X-Vault-Ignored-3xx": "Ignored 3xx",
"X-Custom-Header": "Custom header value 3xx",
}
var customHeaders200 = map[string]string{
"Someheader-200": "200",
"X-Custom-Header": "Custom header value 200",
}
var customHeader2xx = map[string]string{
"X-Custom-Header": "Custom header value 2xx",
}
var customHeader400 = map[string]string{
"Someheader-400": "400",
}
var defaultCustomHeadersMultiListener = map[string]string{
"Strict-Transport-Security": "max-age=31536000; includeSubDomains",
"Content-Security-Policy": "default-src 'others'",
"X-Vault-Ignored": "ignored",
"X-Custom-Header": "Custom header value default",
}
var defaultSTS = map[string]string{
"Strict-Transport-Security": "max-age=31536000; includeSubDomains",
}
func TestCustomResponseHeadersConfigs(t *testing.T) {
expectedCustomResponseHeader := map[string]map[string]string{
"default": defaultCustomHeaders,
"307": customHeaders307,
"3xx": customHeader3xx,
"200": customHeaders200,
"2xx": customHeader2xx,
"400": customHeader400,
}
config, err := LoadConfigFile("./test-fixtures/config_custom_response_headers_1.hcl")
if err != nil {
t.Fatalf("Error encountered when loading config %+v", err)
}
if diff := deep.Equal(expectedCustomResponseHeader, config.Listeners[0].CustomResponseHeaders); diff != nil {
t.Fatalf(fmt.Sprintf("parsed custom headers do not match the expected ones, difference: %v", diff))
}
}
func TestCustomResponseHeadersConfigsMultipleListeners(t *testing.T) {
expectedCustomResponseHeader := map[string]map[string]string{
"default": defaultCustomHeadersMultiListener,
"307": customHeaders307,
"3xx": customHeader3xx,
"200": customHeaders200,
"2xx": customHeader2xx,
"400": customHeader400,
}
config, err := LoadConfigFile("./test-fixtures/config_custom_response_headers_multiple_listeners.hcl")
if err != nil {
t.Fatalf("Error encountered when loading config %+v", err)
}
if diff := deep.Equal(expectedCustomResponseHeader, config.Listeners[0].CustomResponseHeaders); diff != nil {
t.Fatalf(fmt.Sprintf("parsed custom headers do not match the expected ones, difference: %v", diff))
}
if diff := deep.Equal(expectedCustomResponseHeader, config.Listeners[1].CustomResponseHeaders); diff == nil {
t.Fatalf(fmt.Sprintf("parsed custom headers do not match the expected ones, difference: %v", diff))
}
if diff := deep.Equal(expectedCustomResponseHeader["default"], config.Listeners[1].CustomResponseHeaders["default"]); diff != nil {
t.Fatalf(fmt.Sprintf("parsed custom headers do not match the expected ones, difference: %v", diff))
}
if diff := deep.Equal(expectedCustomResponseHeader, config.Listeners[2].CustomResponseHeaders); diff == nil {
t.Fatalf(fmt.Sprintf("parsed custom headers do not match the expected ones, difference: %v", diff))
}
if diff := deep.Equal(defaultSTS, config.Listeners[2].CustomResponseHeaders["default"]); diff != nil {
t.Fatalf(fmt.Sprintf("parsed custom headers do not match the expected ones, difference: %v", diff))
}
if diff := deep.Equal(expectedCustomResponseHeader, config.Listeners[3].CustomResponseHeaders); diff == nil {
t.Fatalf(fmt.Sprintf("parsed custom headers do not match the expected ones, difference: %v", diff))
}
if diff := deep.Equal(defaultSTS, config.Listeners[3].CustomResponseHeaders["default"]); diff != nil {
t.Fatalf(fmt.Sprintf("parsed custom headers do not match the expected ones, difference: %v", diff))
}
}

View File

@ -16,6 +16,12 @@ import (
"github.com/hashicorp/vault/internalshared/configutil"
)
var DefaultCustomHeaders = map[string]map[string]string {
"default": {
"Strict-Transport-Security": configutil.StrictTransportSecurity,
},
}
func boolPointer(x bool) *bool {
return &x
}
@ -32,6 +38,7 @@ func testConfigRaftRetryJoin(t *testing.T) {
{
Type: "tcp",
Address: "127.0.0.1:8200",
CustomResponseHeaders: DefaultCustomHeaders,
},
},
DisableMlock: true,
@ -64,6 +71,7 @@ func testLoadConfigFile_topLevel(t *testing.T, entropy *configutil.Entropy) {
{
Type: "tcp",
Address: "127.0.0.1:443",
CustomResponseHeaders: DefaultCustomHeaders,
},
},
@ -174,10 +182,12 @@ func testLoadConfigFile_json2(t *testing.T, entropy *configutil.Entropy) {
{
Type: "tcp",
Address: "127.0.0.1:443",
CustomResponseHeaders: DefaultCustomHeaders,
},
{
Type: "tcp",
Address: "127.0.0.1:444",
CustomResponseHeaders: DefaultCustomHeaders,
},
},
@ -336,6 +346,7 @@ func testLoadConfigFileIntegerAndBooleanValuesCommon(t *testing.T, path string)
{
Type: "tcp",
Address: "127.0.0.1:8200",
CustomResponseHeaders: DefaultCustomHeaders,
},
},
DisableMlock: true,
@ -379,6 +390,7 @@ func testLoadConfigFile(t *testing.T) {
{
Type: "tcp",
Address: "127.0.0.1:443",
CustomResponseHeaders: DefaultCustomHeaders,
},
},
@ -486,7 +498,7 @@ func testUnknownFieldValidation(t *testing.T) {
for _, er1 := range errors {
found := false
if strings.Contains(er1.String(), "sentinel") {
//This happens on OSS, and is fine
// This happens on OSS, and is fine
continue
}
for _, ex := range expected {
@ -525,6 +537,7 @@ func testLoadConfigFile_json(t *testing.T) {
{
Type: "tcp",
Address: "127.0.0.1:443",
CustomResponseHeaders: DefaultCustomHeaders,
},
},
@ -610,6 +623,7 @@ func testLoadConfigDir(t *testing.T) {
{
Type: "tcp",
Address: "127.0.0.1:443",
CustomResponseHeaders: DefaultCustomHeaders,
},
},
@ -818,6 +832,7 @@ listener "tcp" {
Profiling: configutil.ListenerProfiling{
UnauthenticatedPProfAccess: true,
},
CustomResponseHeaders: DefaultCustomHeaders,
},
},
},
@ -845,6 +860,7 @@ func testParseSeals(t *testing.T) {
{
Type: "tcp",
Address: "127.0.0.1:443",
CustomResponseHeaders: DefaultCustomHeaders,
},
},
Seals: []*configutil.KMS{
@ -898,6 +914,7 @@ func testLoadConfigFileLeaseMetrics(t *testing.T) {
{
Type: "tcp",
Address: "127.0.0.1:443",
CustomResponseHeaders: DefaultCustomHeaders,
},
},

View File

@ -0,0 +1,31 @@
storage "inmem" {}
listener "tcp" {
address = "127.0.0.1:8200"
tls_disable = true
custom_response_headers {
"default" = {
"Strict-Transport-Security" = ["max-age=1","domains"],
"Content-Security-Policy" = ["default-src 'others'"],
"X-Vault-Ignored" = ["ignored"],
"X-Custom-Header" = ["Custom header value default"],
}
"307" = {
"X-Custom-Header" = ["Custom header value 307"],
}
"3xx" = {
"X-Vault-Ignored-3xx" = ["Ignored 3xx"],
"X-Custom-Header" = ["Custom header value 3xx"]
}
"200" = {
"someheader-200" = ["200"],
"X-Custom-Header" = ["Custom header value 200"]
}
"2xx" = {
"X-Custom-Header" = ["Custom header value 2xx"]
}
"400" = {
"someheader-400" = ["400"]
}
}
}
disable_mlock = true

View File

@ -0,0 +1,56 @@
storage "inmem" {}
listener "tcp" {
address = "127.0.0.1:8200"
tls_disable = true
custom_response_headers {
"default" = {
"Content-Security-Policy" = ["default-src 'others'"],
"X-Vault-Ignored" = ["ignored"],
"X-Custom-Header" = ["Custom header value default"],
}
"307" = {
"X-Custom-Header" = ["Custom header value 307"],
}
"3xx" = {
"X-Vault-Ignored-3xx" = ["Ignored 3xx"],
"X-Custom-Header" = ["Custom header value 3xx"]
}
"200" = {
"someheader-200" = ["200"],
"X-Custom-Header" = ["Custom header value 200"]
}
"2xx" = {
"X-Custom-Header" = ["Custom header value 2xx"]
}
"400" = {
"someheader-400" = ["400"]
}
}
}
listener "tcp" {
address = "127.0.0.2:8200"
tls_disable = true
custom_response_headers {
"default" = {
"Content-Security-Policy" = ["default-src 'others'"],
"X-Vault-Ignored" = ["ignored"],
"X-Custom-Header" = ["Custom header value default"],
}
}
}
listener "tcp" {
address = "127.0.0.3:8200"
tls_disable = true
custom_response_headers {
"2xx" = {
"X-Custom-Header" = ["Custom header value 2xx"]
}
}
}
listener "tcp" {
address = "127.0.0.4:8200"
tls_disable = true
}
disable_mlock = true

128
http/custom_header_test.go Normal file
View File

@ -0,0 +1,128 @@
package http
import (
"testing"
"github.com/hashicorp/vault/vault"
)
var defaultCustomHeaders = map[string]string {
"Strict-Transport-Security": "max-age=1; domains",
"Content-Security-Policy": "default-src 'others'",
"X-Custom-Header": "Custom header value default",
"X-Frame-Options": "Deny",
"X-Content-Type-Options": "nosniff",
"Content-Type": "application/json",
"X-XSS-Protection": "1; mode=block",
}
var customHeader2xx = map[string]string {
"X-Custom-Header": "Custom header value 2xx",
}
var customHeader200 = map[string]string {
"Someheader-200": "200",
"X-Custom-Header": "Custom header value 200",
}
var customHeader4xx = map[string]string {
"Someheader-4xx": "4xx",
}
var customHeader400 = map[string]string {
"Someheader-400": "400",
}
var customHeader405 = map[string]string {
"Someheader-405": "405",
}
var CustomResponseHeaders = map[string]map[string]string{
"default": defaultCustomHeaders,
"307": {"X-Custom-Header": "Custom header value 307"},
"3xx": {
"X-Custom-Header": "Custom header value 3xx",
"X-Vault-Ignored-3xx": "Ignored 3xx",
},
"200": customHeader200,
"2xx": customHeader2xx,
"400": customHeader400,
"405": customHeader405,
"4xx": customHeader4xx,
}
func TestCustomResponseHeaders(t *testing.T) {
core, _, token := vault.TestCoreWithCustomResponseHeaderAndUI(t, CustomResponseHeaders, true)
ln, addr := TestServer(t, core)
defer ln.Close()
TestServerAuth(t, addr, token)
resp := testHttpGet(t, token, addr+"/v1/sys/raw/")
testResponseStatus(t, resp, 404)
testResponseHeader(t, resp, defaultCustomHeaders)
testResponseHeader(t, resp, customHeader4xx)
resp = testHttpGet(t, token, addr+"/v1/sys/seal")
testResponseStatus(t, resp, 405)
testResponseHeader(t, resp, defaultCustomHeaders)
testResponseHeader(t, resp, customHeader4xx)
testResponseHeader(t, resp, customHeader405)
resp = testHttpGet(t, token, addr+"/v1/sys/leader")
testResponseStatus(t, resp, 200)
testResponseHeader(t, resp, customHeader200)
resp = testHttpGet(t, token, addr+"/v1/sys/health")
testResponseStatus(t, resp, 200)
testResponseHeader(t, resp, customHeader200)
resp = testHttpGet(t, token, addr+"/v1/sys/generate-root/attempt")
testResponseStatus(t, resp, 200)
testResponseHeader(t, resp, customHeader200)
resp = testHttpGet(t, token, addr+"/v1/sys/generate-root/update")
testResponseStatus(t, resp, 400)
testResponseHeader(t, resp, defaultCustomHeaders)
testResponseHeader(t, resp, customHeader4xx)
testResponseHeader(t, resp, customHeader400)
resp = testHttpGet(t, token, addr+"/v1/sys/")
testResponseStatus(t, resp, 404)
testResponseHeader(t, resp, defaultCustomHeaders)
testResponseHeader(t, resp, customHeader4xx)
resp = testHttpGet(t, token, addr+"/v1/sys")
testResponseStatus(t, resp, 404)
testResponseHeader(t, resp, defaultCustomHeaders)
testResponseHeader(t, resp, customHeader4xx)
resp = testHttpGet(t, token, addr+"/v1/")
testResponseStatus(t, resp, 404)
testResponseHeader(t, resp, defaultCustomHeaders)
testResponseHeader(t, resp, customHeader4xx)
resp = testHttpGet(t, token, addr+"/v1")
testResponseStatus(t, resp, 404)
testResponseHeader(t, resp, defaultCustomHeaders)
testResponseHeader(t, resp, customHeader4xx)
resp = testHttpGet(t, token, addr+"/")
testResponseStatus(t, resp, 200)
testResponseHeader(t, resp, customHeader200)
resp = testHttpGet(t, token, addr+"/ui")
testResponseStatus(t, resp, 200)
testResponseHeader(t, resp, customHeader200)
resp = testHttpGet(t, token, addr+"/ui/")
testResponseStatus(t, resp, 200)
testResponseHeader(t, resp, customHeader200)
resp = testHttpPost(t, token, addr+"/v1/sys/auth/foo", map[string]interface{}{
"type": "noop",
"description": "foo",
})
testResponseStatus(t, resp, 204)
testResponseHeader(t, resp, customHeader2xx)
}

View File

@ -16,12 +16,14 @@ import (
"net/textproto"
"net/url"
"os"
"strconv"
"strings"
"time"
"github.com/NYTimes/gziphandler"
"github.com/hashicorp/errwrap"
"github.com/hashicorp/go-cleanhttp"
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-secure-stdlib/parseutil"
"github.com/hashicorp/go-sockaddr"
"github.com/hashicorp/vault/helper/namespace"
@ -210,6 +212,90 @@ func Handler(props *vault.HandlerProperties) http.Handler {
return printablePathCheckHandler
}
type WrappingResponseWriter interface {
http.ResponseWriter
Wrapped() http.ResponseWriter
}
type statusHeaderResponseWriter struct {
wrapped http.ResponseWriter
logger log.Logger
wroteHeader bool
statusCode int
headers map[string][]*vault.CustomHeader
}
func (w *statusHeaderResponseWriter) Wrapped() http.ResponseWriter {
return w.wrapped
}
func (w *statusHeaderResponseWriter) Header() http.Header {
return w.wrapped.Header()
}
func (w *statusHeaderResponseWriter) Write(buf []byte) (int, error) {
// It is allowed to only call ResponseWriter.Write and skip
// ResponseWriter.WriteHeader. An example of such a situation is
// "handleUIStub". The Write function will internally set the status code
// 200 for the response for which that call might invoke other
// implementations of the WriteHeader function. So, we still need to set
// the custom headers. In cases where both WriteHeader and Write of
// statusHeaderResponseWriter struct are called the internal call to the
// WriterHeader invoked from inside Write method won't change the headers.
if !w.wroteHeader {
w.setCustomResponseHeaders(w.statusCode)
}
return w.wrapped.Write(buf)
}
func (w *statusHeaderResponseWriter) WriteHeader(statusCode int) {
w.setCustomResponseHeaders(statusCode)
w.wrapped.WriteHeader(statusCode)
w.statusCode = statusCode
// in cases where Write is called after WriteHeader, let's prevent setting
// ResponseWriter headers twice
w.wroteHeader = true
}
func (w *statusHeaderResponseWriter) setCustomResponseHeaders(status int) {
sch := w.headers
if sch == nil {
w.logger.Warn("status code header map not configured")
return
}
// Checking the validity of the status code
if status >= 600 || status < 100 {
return
}
// setter function to set the headers
setter := func(hvl []*vault.CustomHeader) {
for _, hv := range hvl {
w.Header().Set(hv.Name, hv.Value)
}
}
// Setting the default headers first
setter(sch["default"])
// setting the Xyy pattern first
d := fmt.Sprintf("%vxx", status/100)
if val, ok := sch[d]; ok {
setter(val)
}
// Setting the specific headers
if val, ok := sch[strconv.Itoa(status)]; ok {
setter(val)
}
return
}
var _ WrappingResponseWriter = &statusHeaderResponseWriter{}
type copyResponseWriter struct {
wrapped http.ResponseWriter
statusCode int
@ -300,6 +386,22 @@ func wrapGenericHandler(core *vault.Core, h http.Handler, props *vault.HandlerPr
hostname, _ := os.Hostname()
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// This block needs to be here so that upon sending SIGHUP, custom response
// headers are also reloaded into the handlers.
if props.ListenerConfig != nil {
la := props.ListenerConfig.Address
listenerCustomHeaders := core.GetListenerCustomResponseHeaders(la)
if listenerCustomHeaders != nil {
w = &statusHeaderResponseWriter{
wrapped: w,
logger: core.Logger(),
wroteHeader: false,
statusCode: 200,
headers: listenerCustomHeaders.StatusCodeHeaderMap,
}
}
}
// Set the Cache-Control header for all the responses returned
// by Vault
w.Header().Set("Cache-Control", "no-store")
@ -632,7 +734,15 @@ func parseJSONRequest(perfStandby bool, r *http.Request, w http.ResponseWriter,
return nil, errors.New("could not parse max_request_size from request context")
}
if max > 0 {
reader = http.MaxBytesReader(w, r.Body, max)
// MaxBytesReader won't do all the internal stuff it must unless it's
// given a ResponseWriter that implements the internal http interface
// requestTooLarger. So we let it have access to the underlying
// ResponseWriter.
inw := w
if myw, ok := inw.(WrappingResponseWriter); ok {
inw = myw.Wrapped()
}
reader = http.MaxBytesReader(inw, r.Body, max)
}
}
var origBody io.ReadWriter

View File

@ -125,6 +125,16 @@ func testResponseStatus(t *testing.T, resp *http.Response, code int) {
}
}
func testResponseHeader(t *testing.T, resp *http.Response, expectedHeaders map[string]string) {
t.Helper()
for k, v := range expectedHeaders {
hv := resp.Header.Get(k)
if v != hv {
t.Fatalf("expected header value %v=%v, got %v=%v", k, v, k, hv)
}
}
}
func testResponseBody(t *testing.T, resp *http.Response, out interface{}) {
defer resp.Body.Close()

View File

@ -35,12 +35,14 @@ func handleMetricsUnauthenticated(core *vault.Core) http.Handler {
resp := core.MetricsHelper().ResponseForFormat(format)
// Manually extract the logical response and send back the information
w.WriteHeader(resp.Data[logical.HTTPStatusCode].(int))
status := resp.Data[logical.HTTPStatusCode].(int)
w.Header().Set("Content-Type", resp.Data[logical.HTTPContentType].(string))
switch v := resp.Data[logical.HTTPRawBody].(type) {
case string:
w.WriteHeader(status)
w.Write([]byte(v))
case []byte:
w.WriteHeader(status)
w.Write(v)
default:
respondError(w, http.StatusInternalServerError, fmt.Errorf("wrong response returned"))

View File

@ -6,6 +6,7 @@ import (
"net/http"
"testing"
"github.com/hashicorp/vault/internalshared/configutil"
"github.com/hashicorp/vault/vault"
)
@ -41,10 +42,16 @@ func TestServerWithListenerAndProperties(tb testing.TB, ln net.Listener, addr st
}
func TestServerWithListener(tb testing.TB, ln net.Listener, addr string, core *vault.Core) {
ip, _, _ := net.SplitHostPort(ln.Addr().String())
// Create a muxer to handle our requests so that we can authenticate
// for tests.
props := &vault.HandlerProperties{
Core: core,
// This is needed for testing custom response headers
ListenerConfig: &configutil.Listener {
Address: ip,
},
}
TestServerWithListenerAndProperties(tb, ln, addr, core, props)
}

View File

@ -0,0 +1,129 @@
package configutil
import (
"fmt"
"net/textproto"
"strconv"
"strings"
"github.com/hashicorp/go-secure-stdlib/strutil"
)
var ValidCustomStatusCodeCollection = []string{
"default",
"1xx",
"2xx",
"3xx",
"4xx",
"5xx",
}
const StrictTransportSecurity = "max-age=31536000; includeSubDomains"
// ParseCustomResponseHeaders takes a raw config values for the
// "custom_response_headers". It makes sure the config entry is passed in
// as a map of status code to a map of header name and header values. It
// verifies the validity of the status codes, and header values. It also
// adds the default headers values.
func ParseCustomResponseHeaders(responseHeaders interface{}) (map[string]map[string]string, error) {
h := make(map[string]map[string]string)
// if r is nil, we still should set the default custom headers
if responseHeaders == nil {
h["default"] = map[string]string{"Strict-Transport-Security": StrictTransportSecurity}
return h, nil
}
customResponseHeader, ok := responseHeaders.([]map[string]interface{})
if !ok {
return nil, fmt.Errorf("response headers were not configured correctly. please make sure they're in a slice of maps")
}
for _, crh := range customResponseHeader {
for statusCode, responseHeader := range crh {
headerValList, ok := responseHeader.([]map[string]interface{})
if !ok {
return nil, fmt.Errorf("response headers were not configured correctly. please make sure they're in a slice of maps")
}
if !IsValidStatusCode(statusCode) {
return nil, fmt.Errorf("invalid status code found in the server configuration: %v", statusCode)
}
if len(headerValList) != 1 {
return nil, fmt.Errorf("invalid number of response headers exist")
}
headerValMap := headerValList[0]
headerVal, err := parseHeaders(headerValMap)
if err != nil {
return nil, err
}
h[statusCode] = headerVal
}
}
// setting Strict-Transport-Security as a default header
if h["default"] == nil {
h["default"] = make(map[string]string)
}
if _, ok := h["default"]["Strict-Transport-Security"]; !ok {
h["default"]["Strict-Transport-Security"] = StrictTransportSecurity
}
return h, nil
}
// IsValidStatusCode checking for status codes outside the boundary
func IsValidStatusCode(sc string) bool {
if strutil.StrListContains(ValidCustomStatusCodeCollection, sc) {
return true
}
i, err := strconv.Atoi(sc)
if err != nil {
return false
}
if i >= 600 || i < 100 {
return false
}
return true
}
func parseHeaders(in map[string]interface{}) (map[string]string, error) {
hvMap := make(map[string]string)
for k, v := range in {
// parsing header name
headerName := textproto.CanonicalMIMEHeaderKey(k)
// parsing header values
s, err := parseHeaderValues(v)
if err != nil {
return nil, err
}
hvMap[headerName] = s
}
return hvMap, nil
}
func parseHeaderValues(header interface{}) (string, error) {
var sl []string
if _, ok := header.([]interface{}); !ok {
return "", fmt.Errorf("headers must be given in a list of strings")
}
headerValList := header.([]interface{})
for _, vh := range headerValList {
if _, ok := vh.(string); !ok {
return "", fmt.Errorf("found a non-string header value: %v", vh)
}
headerVal := strings.TrimSpace(vh.(string))
if headerVal == "" {
continue
}
sl = append(sl, headerVal)
}
s := strings.Join(sl, "; ")
return s, nil
}

View File

@ -99,6 +99,10 @@ type Listener struct {
CorsAllowedOrigins []string `hcl:"cors_allowed_origins"`
CorsAllowedHeaders []string `hcl:"-"`
CorsAllowedHeadersRaw []string `hcl:"cors_allowed_headers,alias:cors_allowed_headers"`
// Custom Http response headers
CustomResponseHeaders map[string]map[string]string `hcl:"-"`
CustomResponseHeadersRaw interface{} `hcl:"custom_response_headers"`
}
func (l *Listener) GoString() string {
@ -361,6 +365,17 @@ func ParseListeners(result *SharedConfig, list *ast.ObjectList) error {
}
}
// HTTP Headers
{
// if CustomResponseHeadersRaw is nil, we still need to set the default headers
customHeadersMap, err := ParseCustomResponseHeaders(l.CustomResponseHeadersRaw)
if err != nil {
return multierror.Prefix(fmt.Errorf("failed to parse custom_response_headers: %w", err), fmt.Sprintf("listeners.%d", i))
}
l.CustomResponseHeaders = customHeadersMap
l.CustomResponseHeadersRaw = nil
}
result.Listeners = append(result.Listeners, &l)
}

View File

@ -510,6 +510,9 @@ type Core struct {
// clusterListener starts up and manages connections on the cluster ports
clusterListener *atomic.Value
// customListenerHeader holds custom response headers for a listener
customListenerHeader *atomic.Value
// Telemetry objects
metricsHelper *metricsutil.MetricsHelper
@ -769,23 +772,24 @@ func CreateCore(conf *CoreConfig) (*Core, error) {
// Setup the core
c := &Core{
entCore: entCore{},
devToken: conf.DevToken,
physical: conf.Physical,
serviceRegistration: conf.GetServiceRegistration(),
underlyingPhysical: conf.Physical,
storageType: conf.StorageType,
redirectAddr: conf.RedirectAddr,
clusterAddr: new(atomic.Value),
clusterListener: new(atomic.Value),
seal: conf.Seal,
router: NewRouter(),
sealed: new(uint32),
sealMigrationDone: new(uint32),
standby: true,
standbyStopCh: new(atomic.Value),
baseLogger: conf.Logger,
logger: conf.Logger.Named("core"),
entCore: entCore{},
devToken: conf.DevToken,
physical: conf.Physical,
serviceRegistration: conf.GetServiceRegistration(),
underlyingPhysical: conf.Physical,
storageType: conf.StorageType,
redirectAddr: conf.RedirectAddr,
clusterAddr: new(atomic.Value),
clusterListener: new(atomic.Value),
customListenerHeader: new(atomic.Value),
seal: conf.Seal,
router: NewRouter(),
sealed: new(uint32),
sealMigrationDone: new(uint32),
standby: true,
standbyStopCh: new(atomic.Value),
baseLogger: conf.Logger,
logger: conf.Logger.Named("core"),
defaultLeaseTTL: conf.DefaultLeaseTTL,
maxLeaseTTL: conf.MaxLeaseTTL,
@ -1005,6 +1009,17 @@ func NewCore(conf *CoreConfig) (*Core, error) {
c.clusterListener.Store((*cluster.Listener)(nil))
// for listeners with custom response headers, configuring customListenerHeader
if conf.RawConfig.Listeners != nil {
uiHeaders, err := c.UIHeaders()
if err != nil {
return nil, err
}
c.customListenerHeader.Store(NewListenerCustomHeader(conf.RawConfig.Listeners, c.logger, uiHeaders))
} else {
c.customListenerHeader.Store(([]*ListenerCustomHeaders)(nil))
}
quotasLogger := conf.Logger.Named("quotas")
c.allLoggers = append(c.allLoggers, quotasLogger)
c.quotaManager, err = quotas.NewManager(quotasLogger, c.quotaLeaseWalker, c.metricSink)
@ -2641,6 +2656,68 @@ func (c *Core) SetConfig(conf *server.Config) {
c.logger.Debug("set config", "sanitized config", string(bz))
}
func (c *Core) GetListenerCustomResponseHeaders(listenerAdd string) *ListenerCustomHeaders {
customHeaders := c.customListenerHeader.Load()
if customHeaders == nil {
return nil
}
customHeadersList, ok := customHeaders.([]*ListenerCustomHeaders)
if customHeadersList == nil || !ok {
return nil
}
for _, l := range customHeadersList {
if l.Address == listenerAdd {
return l
}
}
return nil
}
// ExistCustomResponseHeader checks if a custom header is configured in any
// listener's stanza
func (c *Core) ExistCustomResponseHeader(header string) bool {
customHeaders := c.customListenerHeader.Load()
if customHeaders == nil {
return false
}
customHeadersList, ok := customHeaders.([]*ListenerCustomHeaders)
if customHeadersList == nil || !ok {
return false
}
for _, l := range customHeadersList {
exist := l.ExistCustomResponseHeader(header)
if exist {
return true
}
}
return false
}
func (c *Core) ReloadCustomResponseHeaders() error {
conf := c.rawConfig.Load()
if conf == nil {
return fmt.Errorf("failed to load core raw config")
}
lns := conf.(*server.Config).Listeners
if lns == nil {
return fmt.Errorf("no listener configured")
}
uiHeaders, err := c.UIHeaders()
if err != nil {
return err
}
c.customListenerHeader.Store(NewListenerCustomHeader(lns, c.logger, uiHeaders))
return nil
}
// SanitizedConfig returns a sanitized version of the current config.
// See server.Config.Sanitized for specific values omitted.
func (c *Core) SanitizedConfig() map[string]interface{} {

View File

@ -0,0 +1,90 @@
package vault
import (
"net/http"
"net/textproto"
"strings"
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/internalshared/configutil"
)
type ListenerCustomHeaders struct {
Address string
StatusCodeHeaderMap map[string][]*CustomHeader
// ConfiguredHeadersStatusCodeMap field is introduced so that we would not need to loop through
// StatusCodeHeaderMap to see if a header exists, the key for this map is the headers names
configuredHeadersStatusCodeMap map[string][]string
}
type CustomHeader struct {
Name string
Value string
}
func NewListenerCustomHeader(ln []*configutil.Listener, logger log.Logger, uiHeaders http.Header) []*ListenerCustomHeaders {
var listenerCustomHeadersList []*ListenerCustomHeaders
for _, l := range ln {
listenerCustomHeaderStruct := &ListenerCustomHeaders{
Address: l.Address,
}
listenerCustomHeaderStruct.StatusCodeHeaderMap = make(map[string][]*CustomHeader)
listenerCustomHeaderStruct.configuredHeadersStatusCodeMap = make(map[string][]string)
for statusCode, headerValMap := range l.CustomResponseHeaders {
var customHeaderList []*CustomHeader
for headerName, headerVal := range headerValMap {
// Sanitizing custom headers
// X-Vault- prefix is reserved for Vault internal processes
if strings.HasPrefix(headerName, "X-Vault-") {
logger.Warn("custom headers starting with X-Vault are not valid", "header", headerName)
continue
}
// Checking for UI headers, if any common header exists, we just log an error
if uiHeaders != nil {
exist := uiHeaders.Get(headerName)
if exist != "" {
logger.Warn("found a duplicate header in UI", "header:", headerName, "Headers defined in the server configuration take precedence.")
}
}
// Checking if the header value is not an empty string
if headerVal == "" {
logger.Warn("header value is an empty string", "header", headerName, "value", headerVal)
continue
}
ch := &CustomHeader{
Name: headerName,
Value: headerVal,
}
customHeaderList = append(customHeaderList, ch)
// setting up the reverse map of header to status code for easy lookups
listenerCustomHeaderStruct.configuredHeadersStatusCodeMap[headerName] = append(listenerCustomHeaderStruct.configuredHeadersStatusCodeMap[headerName], statusCode)
}
listenerCustomHeaderStruct.StatusCodeHeaderMap[statusCode] = customHeaderList
}
listenerCustomHeadersList = append(listenerCustomHeadersList, listenerCustomHeaderStruct)
}
return listenerCustomHeadersList
}
func (l *ListenerCustomHeaders) ExistCustomResponseHeader(header string) bool {
if header == "" {
return false
}
if l.StatusCodeHeaderMap == nil {
return false
}
headerName := textproto.CanonicalMIMEHeaderKey(header)
headerMap := l.configuredHeadersStatusCodeMap
_, ok := headerMap[headerName]
return ok
}

View File

@ -0,0 +1,174 @@
package vault
import (
"context"
"fmt"
"net/http/httptest"
"strings"
"testing"
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/helper/namespace"
"github.com/hashicorp/vault/internalshared/configutil"
"github.com/hashicorp/vault/sdk/helper/logging"
"github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/sdk/physical/inmem"
)
var defaultCustomHeaders = map[string]string{
"Strict-Transport-Security": "max-age=1; domains",
"Content-Security-Policy": "default-src 'others'",
"X-Vault-Ignored": "ignored",
"X-Custom-Header": "Custom header value default",
"X-Frame-Options": "Deny",
"X-Content-Type-Options": "nosniff",
"Content-Type": "text/plain; charset=utf-8",
"X-XSS-Protection": "1; mode=block",
}
var customHeaders307 = map[string]string{
"X-Custom-Header": "Custom header value 307",
}
var customHeader3xx = map[string]string{
"X-Vault-Ignored-3xx": "Ignored 3xx",
"X-Custom-Header": "Custom header value 3xx",
}
var customHeaders200 = map[string]string{
"Someheader-200": "200",
"X-Custom-Header": "Custom header value 200",
}
var customHeader2xx = map[string]string{
"X-Custom-Header": "Custom header value 2xx",
}
var customHeader400 = map[string]string{
"Someheader-400": "400",
}
func TestConfigCustomHeaders(t *testing.T) {
logger := logging.NewVaultLogger(log.Trace)
phys, err := inmem.NewTransactionalInmem(nil, logger)
if err != nil {
t.Fatal(err)
}
logl := &logical.InmemStorage{}
uiConfig := NewUIConfig(true, phys, logl)
rawListenerConfig := []*configutil.Listener{
{
Type: "tcp",
Address: "127.0.0.1:443",
CustomResponseHeaders: map[string]map[string]string{
"default": defaultCustomHeaders,
"307": customHeaders307,
"3xx": customHeader3xx,
"200": customHeaders200,
"2xx": customHeader2xx,
"400": customHeader400,
},
},
}
uiHeaders, err := uiConfig.Headers(context.Background())
listenerCustomHeaders := NewListenerCustomHeader(rawListenerConfig, logger, uiHeaders)
if listenerCustomHeaders == nil || len(listenerCustomHeaders) != 1 {
t.Fatalf("failed to get custom header configuration")
}
lch := listenerCustomHeaders[0]
if lch.ExistCustomResponseHeader("X-Vault-Ignored-307") {
t.Fatalf("header name with X-Vault prefix is not valid")
}
if lch.ExistCustomResponseHeader("X-Vault-Ignored-3xx") {
t.Fatalf("header name with X-Vault prefix is not valid")
}
if !lch.ExistCustomResponseHeader("X-Custom-Header") {
t.Fatalf("header name with X-Vault prefix is not valid")
}
}
func TestCustomResponseHeadersConfigInteractUiConfig(t *testing.T) {
b := testSystemBackend(t)
_, barrier, _ := mockBarrier(t)
view := NewBarrierView(barrier, "")
b.(*SystemBackend).Core.systemBarrierView = view
logger := logging.NewVaultLogger(log.Trace)
rawListenerConfig := []*configutil.Listener{
{
Type: "tcp",
Address: "127.0.0.1:443",
CustomResponseHeaders: map[string]map[string]string{
"default": defaultCustomHeaders,
"307": customHeaders307,
"3xx": customHeader3xx,
"200": customHeaders200,
"2xx": customHeader2xx,
"400": customHeader400,
},
},
}
uiHeaders, err := b.(*SystemBackend).Core.uiConfig.Headers(context.Background())
if err != nil {
t.Fatalf("failed to get headers from ui config")
}
customListenerHeader := NewListenerCustomHeader(rawListenerConfig, logger, uiHeaders)
if customListenerHeader == nil {
t.Fatalf("custom header config should be configured")
}
b.(*SystemBackend).Core.customListenerHeader.Store(customListenerHeader)
clh := b.(*SystemBackend).Core.customListenerHeader
if clh == nil {
t.Fatalf("custom header config should be configured in core")
}
w := httptest.NewRecorder()
hw := logical.NewHTTPResponseWriter(w)
// setting a header that already exist in custom headers
req := logical.TestRequest(t, logical.UpdateOperation, "config/ui/headers/X-Custom-Header")
req.Data["values"] = []string{"UI Custom Header"}
req.ResponseWriter = hw
resp, err := b.HandleRequest(namespace.RootContext(nil), req)
if err == nil {
t.Fatal("request did not fail on setting a header that is present in custom response headers")
}
if !strings.Contains(resp.Data["error"].(string), fmt.Sprintf("This header already exists in the server configuration and cannot be set in the UI.")) {
t.Fatalf("failed to get the expected error")
}
// setting a header that already exist in custom headers
req = logical.TestRequest(t, logical.UpdateOperation, "config/ui/headers/Someheader-400")
req.Data["values"] = []string{"400"}
req.ResponseWriter = hw
_, err = b.HandleRequest(namespace.RootContext(nil), req)
if err == nil {
t.Fatal("request did not fail on setting a header that is present in custom response headers")
}
h, err := b.(*SystemBackend).Core.uiConfig.Headers(context.Background())
if h.Get("Someheader-400") == "400" {
t.Fatalf("should not be able to set a header that is in custom response headers")
}
// setting an ui specific header
req = logical.TestRequest(t, logical.UpdateOperation, "config/ui/headers/X-CustomUiHeader")
req.Data["values"] = []string{"Ui header value"}
req.ResponseWriter = hw
_, err = b.HandleRequest(namespace.RootContext(nil), req)
if err != nil {
t.Fatal("request failed on setting a header that is not present in custom response headers.", "error:", err)
}
h, err = b.(*SystemBackend).Core.uiConfig.Headers(context.Background())
if h.Get("X-CustomUiHeader") != "Ui header value" {
t.Fatalf("failed to set a header that is not in custom response headers")
}
}

View File

@ -2623,6 +2623,9 @@ func (b *SystemBackend) handleConfigUIHeadersUpdate(ctx context.Context, req *lo
// Translate the list of values to the valid header string
value := http.Header{}
for _, v := range values {
if b.Core.ExistCustomResponseHeader(header) {
return logical.ErrorResponse("This header already exists in the server configuration and cannot be set in the UI."), logical.ErrInvalidRequest
}
value.Add(header, v)
}
err := b.Core.uiConfig.SetHeader(ctx, header, value.Values(header))

View File

@ -128,6 +128,29 @@ func TestCoreWithSeal(t testing.T, testSeal Seal, enableRaw bool) *Core {
return TestCoreWithSealAndUI(t, conf)
}
func TestCoreWithCustomResponseHeaderAndUI(t testing.T, CustomResponseHeaders map[string]map[string]string, enableUI bool) (*Core, [][]byte, string) {
confRaw := &server.Config{
SharedConfig: &configutil.SharedConfig{
Listeners: []*configutil.Listener{
{
Type: "tcp",
Address: "127.0.0.1",
CustomResponseHeaders: CustomResponseHeaders,
},
},
DisableMlock: true,
},
}
conf := &CoreConfig{
RawConfig: confRaw,
EnableUI: enableUI,
EnableRaw: true,
BuiltinRegistry: NewMockBuiltinRegistry(),
}
core := TestCoreWithSealAndUI(t, conf)
return testCoreUnsealed(t, core)
}
func TestCoreUI(t testing.T, enableUI bool) *Core {
conf := &CoreConfig{
EnableUI: enableUI,

View File

@ -32,8 +32,9 @@ type UIConfig struct {
// NewUIConfig creates a new UI config
func NewUIConfig(enabled bool, physicalStorage physical.Backend, barrierStorage logical.Storage) *UIConfig {
defaultHeaders := http.Header{}
defaultHeaders.Set("Content-Security-Policy", "default-src 'none'; connect-src 'self'; img-src 'self' data:; script-src 'self'; style-src 'unsafe-inline' 'self'; form-action 'none'; frame-ancestors 'none'; font-src 'self'")
defaultHeaders.Set("Service-Worker-Allowed", "/")
defaultHeaders.Set("X-Content-Type-Options", "nosniff")
defaultHeaders.Set("Content-Security-Policy", "default-src 'none'; connect-src 'self'; img-src 'self' data:; script-src 'self'; style-src 'unsafe-inline' 'self'; form-action 'none'; frame-ancestors 'none'; font-src 'self'")
return &UIConfig{
physicalStorage: physicalStorage,