Customizing HTTP headers in the config file (#12485)
* Customizing HTTP headers in the config file * Add changelog, fix bad imports * fixing some bugs * fixing interaction of custom headers and /ui * Defining a member in core to set custom response headers * missing additional file * Some refactoring * Adding automated tests for the feature * Changing some error messages based on some recommendations * Incorporating custom response headers struct into the request context * removing some unused references * fixing a test * changing some error messages, removing a default header value from /ui * fixing a test * wrapping ResponseWriter to set the custom headers * adding a new test * some cleanup * removing some extra lines * Addressing comments * fixing some agent tests * skipping custom headers from agent listener config, removing two of the default headers as they cause issues with Vault in UI mode Adding X-Content-Type-Options to the ui default headers Let Content-Type be set as before * Removing default custom headers, and renaming some function varibles * some refacotring * Refactoring and addressing comments * removing a function and fixing comments
This commit is contained in:
parent
ce0091f5ee
commit
ad2ef412cc
|
@ -0,0 +1,3 @@
|
|||
```release-note:feature
|
||||
**Customizable HTTP Headers**: Add support to define custom HTTP headers for root path (`/`) and also on API endpoints (`/v1/*`)
|
||||
```
|
|
@ -35,6 +35,7 @@ func (c *Config) Prune() {
|
|||
l.RawConfig = nil
|
||||
l.Profiling.UnusedKeys = nil
|
||||
l.Telemetry.UnusedKeys = nil
|
||||
l.CustomResponseHeaders = nil
|
||||
}
|
||||
c.FoundKeys = nil
|
||||
c.UnusedKeys = nil
|
||||
|
@ -172,6 +173,12 @@ func LoadConfig(path string) (*Config, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Pruning custom headers for Agent for now
|
||||
for _, ln := range sharedConfig.Listeners {
|
||||
ln.CustomResponseHeaders = nil
|
||||
}
|
||||
|
||||
result.SharedConfig = sharedConfig
|
||||
|
||||
list, ok := obj.Node.(*ast.ObjectList)
|
||||
|
|
|
@ -536,7 +536,6 @@ func TestLoadConfigFile_AgentCache_PersistMissingType(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestLoadConfigFile_TemplateConfig(t *testing.T) {
|
||||
|
||||
testCases := map[string]struct {
|
||||
fixturePath string
|
||||
expectedTemplateConfig TemplateConfig
|
||||
|
@ -586,7 +585,6 @@ func TestLoadConfigFile_TemplateConfig(t *testing.T) {
|
|||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// TestLoadConfigFile_Template tests template definitions in Vault Agent
|
||||
|
|
|
@ -1541,6 +1541,12 @@ func (c *ServerCommand) Run(args []string) int {
|
|||
|
||||
core.SetConfig(config)
|
||||
|
||||
// reloading custom response headers to make sure we have
|
||||
// the most up to date headers after reloading the config file
|
||||
if err = core.ReloadCustomResponseHeaders(); err != nil {
|
||||
c.logger.Error(err.Error())
|
||||
}
|
||||
|
||||
if config.LogLevel != "" {
|
||||
configLogLevel := strings.ToLower(strings.TrimSpace(config.LogLevel))
|
||||
switch configLogLevel {
|
||||
|
|
|
@ -0,0 +1,109 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/go-test/deep"
|
||||
)
|
||||
|
||||
var defaultCustomHeaders = map[string]string{
|
||||
"Strict-Transport-Security": "max-age=1; domains",
|
||||
"Content-Security-Policy": "default-src 'others'",
|
||||
"X-Vault-Ignored": "ignored",
|
||||
"X-Custom-Header": "Custom header value default",
|
||||
}
|
||||
|
||||
var customHeaders307 = map[string]string{
|
||||
"X-Custom-Header": "Custom header value 307",
|
||||
}
|
||||
|
||||
var customHeader3xx = map[string]string{
|
||||
"X-Vault-Ignored-3xx": "Ignored 3xx",
|
||||
"X-Custom-Header": "Custom header value 3xx",
|
||||
}
|
||||
|
||||
var customHeaders200 = map[string]string{
|
||||
"Someheader-200": "200",
|
||||
"X-Custom-Header": "Custom header value 200",
|
||||
}
|
||||
|
||||
var customHeader2xx = map[string]string{
|
||||
"X-Custom-Header": "Custom header value 2xx",
|
||||
}
|
||||
|
||||
var customHeader400 = map[string]string{
|
||||
"Someheader-400": "400",
|
||||
}
|
||||
|
||||
var defaultCustomHeadersMultiListener = map[string]string{
|
||||
"Strict-Transport-Security": "max-age=31536000; includeSubDomains",
|
||||
"Content-Security-Policy": "default-src 'others'",
|
||||
"X-Vault-Ignored": "ignored",
|
||||
"X-Custom-Header": "Custom header value default",
|
||||
}
|
||||
|
||||
var defaultSTS = map[string]string{
|
||||
"Strict-Transport-Security": "max-age=31536000; includeSubDomains",
|
||||
}
|
||||
|
||||
func TestCustomResponseHeadersConfigs(t *testing.T) {
|
||||
expectedCustomResponseHeader := map[string]map[string]string{
|
||||
"default": defaultCustomHeaders,
|
||||
"307": customHeaders307,
|
||||
"3xx": customHeader3xx,
|
||||
"200": customHeaders200,
|
||||
"2xx": customHeader2xx,
|
||||
"400": customHeader400,
|
||||
}
|
||||
|
||||
config, err := LoadConfigFile("./test-fixtures/config_custom_response_headers_1.hcl")
|
||||
if err != nil {
|
||||
t.Fatalf("Error encountered when loading config %+v", err)
|
||||
}
|
||||
if diff := deep.Equal(expectedCustomResponseHeader, config.Listeners[0].CustomResponseHeaders); diff != nil {
|
||||
t.Fatalf(fmt.Sprintf("parsed custom headers do not match the expected ones, difference: %v", diff))
|
||||
}
|
||||
}
|
||||
|
||||
func TestCustomResponseHeadersConfigsMultipleListeners(t *testing.T) {
|
||||
expectedCustomResponseHeader := map[string]map[string]string{
|
||||
"default": defaultCustomHeadersMultiListener,
|
||||
"307": customHeaders307,
|
||||
"3xx": customHeader3xx,
|
||||
"200": customHeaders200,
|
||||
"2xx": customHeader2xx,
|
||||
"400": customHeader400,
|
||||
}
|
||||
|
||||
config, err := LoadConfigFile("./test-fixtures/config_custom_response_headers_multiple_listeners.hcl")
|
||||
if err != nil {
|
||||
t.Fatalf("Error encountered when loading config %+v", err)
|
||||
}
|
||||
if diff := deep.Equal(expectedCustomResponseHeader, config.Listeners[0].CustomResponseHeaders); diff != nil {
|
||||
t.Fatalf(fmt.Sprintf("parsed custom headers do not match the expected ones, difference: %v", diff))
|
||||
}
|
||||
|
||||
if diff := deep.Equal(expectedCustomResponseHeader, config.Listeners[1].CustomResponseHeaders); diff == nil {
|
||||
t.Fatalf(fmt.Sprintf("parsed custom headers do not match the expected ones, difference: %v", diff))
|
||||
}
|
||||
if diff := deep.Equal(expectedCustomResponseHeader["default"], config.Listeners[1].CustomResponseHeaders["default"]); diff != nil {
|
||||
t.Fatalf(fmt.Sprintf("parsed custom headers do not match the expected ones, difference: %v", diff))
|
||||
}
|
||||
|
||||
if diff := deep.Equal(expectedCustomResponseHeader, config.Listeners[2].CustomResponseHeaders); diff == nil {
|
||||
t.Fatalf(fmt.Sprintf("parsed custom headers do not match the expected ones, difference: %v", diff))
|
||||
}
|
||||
|
||||
if diff := deep.Equal(defaultSTS, config.Listeners[2].CustomResponseHeaders["default"]); diff != nil {
|
||||
t.Fatalf(fmt.Sprintf("parsed custom headers do not match the expected ones, difference: %v", diff))
|
||||
}
|
||||
|
||||
if diff := deep.Equal(expectedCustomResponseHeader, config.Listeners[3].CustomResponseHeaders); diff == nil {
|
||||
t.Fatalf(fmt.Sprintf("parsed custom headers do not match the expected ones, difference: %v", diff))
|
||||
}
|
||||
|
||||
if diff := deep.Equal(defaultSTS, config.Listeners[3].CustomResponseHeaders["default"]); diff != nil {
|
||||
t.Fatalf(fmt.Sprintf("parsed custom headers do not match the expected ones, difference: %v", diff))
|
||||
}
|
||||
}
|
|
@ -16,6 +16,12 @@ import (
|
|||
"github.com/hashicorp/vault/internalshared/configutil"
|
||||
)
|
||||
|
||||
var DefaultCustomHeaders = map[string]map[string]string {
|
||||
"default": {
|
||||
"Strict-Transport-Security": configutil.StrictTransportSecurity,
|
||||
},
|
||||
}
|
||||
|
||||
func boolPointer(x bool) *bool {
|
||||
return &x
|
||||
}
|
||||
|
@ -32,6 +38,7 @@ func testConfigRaftRetryJoin(t *testing.T) {
|
|||
{
|
||||
Type: "tcp",
|
||||
Address: "127.0.0.1:8200",
|
||||
CustomResponseHeaders: DefaultCustomHeaders,
|
||||
},
|
||||
},
|
||||
DisableMlock: true,
|
||||
|
@ -64,6 +71,7 @@ func testLoadConfigFile_topLevel(t *testing.T, entropy *configutil.Entropy) {
|
|||
{
|
||||
Type: "tcp",
|
||||
Address: "127.0.0.1:443",
|
||||
CustomResponseHeaders: DefaultCustomHeaders,
|
||||
},
|
||||
},
|
||||
|
||||
|
@ -174,10 +182,12 @@ func testLoadConfigFile_json2(t *testing.T, entropy *configutil.Entropy) {
|
|||
{
|
||||
Type: "tcp",
|
||||
Address: "127.0.0.1:443",
|
||||
CustomResponseHeaders: DefaultCustomHeaders,
|
||||
},
|
||||
{
|
||||
Type: "tcp",
|
||||
Address: "127.0.0.1:444",
|
||||
CustomResponseHeaders: DefaultCustomHeaders,
|
||||
},
|
||||
},
|
||||
|
||||
|
@ -336,6 +346,7 @@ func testLoadConfigFileIntegerAndBooleanValuesCommon(t *testing.T, path string)
|
|||
{
|
||||
Type: "tcp",
|
||||
Address: "127.0.0.1:8200",
|
||||
CustomResponseHeaders: DefaultCustomHeaders,
|
||||
},
|
||||
},
|
||||
DisableMlock: true,
|
||||
|
@ -379,6 +390,7 @@ func testLoadConfigFile(t *testing.T) {
|
|||
{
|
||||
Type: "tcp",
|
||||
Address: "127.0.0.1:443",
|
||||
CustomResponseHeaders: DefaultCustomHeaders,
|
||||
},
|
||||
},
|
||||
|
||||
|
@ -486,7 +498,7 @@ func testUnknownFieldValidation(t *testing.T) {
|
|||
for _, er1 := range errors {
|
||||
found := false
|
||||
if strings.Contains(er1.String(), "sentinel") {
|
||||
//This happens on OSS, and is fine
|
||||
// This happens on OSS, and is fine
|
||||
continue
|
||||
}
|
||||
for _, ex := range expected {
|
||||
|
@ -525,6 +537,7 @@ func testLoadConfigFile_json(t *testing.T) {
|
|||
{
|
||||
Type: "tcp",
|
||||
Address: "127.0.0.1:443",
|
||||
CustomResponseHeaders: DefaultCustomHeaders,
|
||||
},
|
||||
},
|
||||
|
||||
|
@ -610,6 +623,7 @@ func testLoadConfigDir(t *testing.T) {
|
|||
{
|
||||
Type: "tcp",
|
||||
Address: "127.0.0.1:443",
|
||||
CustomResponseHeaders: DefaultCustomHeaders,
|
||||
},
|
||||
},
|
||||
|
||||
|
@ -818,6 +832,7 @@ listener "tcp" {
|
|||
Profiling: configutil.ListenerProfiling{
|
||||
UnauthenticatedPProfAccess: true,
|
||||
},
|
||||
CustomResponseHeaders: DefaultCustomHeaders,
|
||||
},
|
||||
},
|
||||
},
|
||||
|
@ -845,6 +860,7 @@ func testParseSeals(t *testing.T) {
|
|||
{
|
||||
Type: "tcp",
|
||||
Address: "127.0.0.1:443",
|
||||
CustomResponseHeaders: DefaultCustomHeaders,
|
||||
},
|
||||
},
|
||||
Seals: []*configutil.KMS{
|
||||
|
@ -898,6 +914,7 @@ func testLoadConfigFileLeaseMetrics(t *testing.T) {
|
|||
{
|
||||
Type: "tcp",
|
||||
Address: "127.0.0.1:443",
|
||||
CustomResponseHeaders: DefaultCustomHeaders,
|
||||
},
|
||||
},
|
||||
|
||||
|
|
|
@ -0,0 +1,31 @@
|
|||
storage "inmem" {}
|
||||
listener "tcp" {
|
||||
address = "127.0.0.1:8200"
|
||||
tls_disable = true
|
||||
custom_response_headers {
|
||||
"default" = {
|
||||
"Strict-Transport-Security" = ["max-age=1","domains"],
|
||||
"Content-Security-Policy" = ["default-src 'others'"],
|
||||
"X-Vault-Ignored" = ["ignored"],
|
||||
"X-Custom-Header" = ["Custom header value default"],
|
||||
}
|
||||
"307" = {
|
||||
"X-Custom-Header" = ["Custom header value 307"],
|
||||
}
|
||||
"3xx" = {
|
||||
"X-Vault-Ignored-3xx" = ["Ignored 3xx"],
|
||||
"X-Custom-Header" = ["Custom header value 3xx"]
|
||||
}
|
||||
"200" = {
|
||||
"someheader-200" = ["200"],
|
||||
"X-Custom-Header" = ["Custom header value 200"]
|
||||
}
|
||||
"2xx" = {
|
||||
"X-Custom-Header" = ["Custom header value 2xx"]
|
||||
}
|
||||
"400" = {
|
||||
"someheader-400" = ["400"]
|
||||
}
|
||||
}
|
||||
}
|
||||
disable_mlock = true
|
|
@ -0,0 +1,56 @@
|
|||
storage "inmem" {}
|
||||
listener "tcp" {
|
||||
address = "127.0.0.1:8200"
|
||||
tls_disable = true
|
||||
custom_response_headers {
|
||||
"default" = {
|
||||
"Content-Security-Policy" = ["default-src 'others'"],
|
||||
"X-Vault-Ignored" = ["ignored"],
|
||||
"X-Custom-Header" = ["Custom header value default"],
|
||||
}
|
||||
"307" = {
|
||||
"X-Custom-Header" = ["Custom header value 307"],
|
||||
}
|
||||
"3xx" = {
|
||||
"X-Vault-Ignored-3xx" = ["Ignored 3xx"],
|
||||
"X-Custom-Header" = ["Custom header value 3xx"]
|
||||
}
|
||||
"200" = {
|
||||
"someheader-200" = ["200"],
|
||||
"X-Custom-Header" = ["Custom header value 200"]
|
||||
}
|
||||
"2xx" = {
|
||||
"X-Custom-Header" = ["Custom header value 2xx"]
|
||||
}
|
||||
"400" = {
|
||||
"someheader-400" = ["400"]
|
||||
}
|
||||
}
|
||||
}
|
||||
listener "tcp" {
|
||||
address = "127.0.0.2:8200"
|
||||
tls_disable = true
|
||||
custom_response_headers {
|
||||
"default" = {
|
||||
"Content-Security-Policy" = ["default-src 'others'"],
|
||||
"X-Vault-Ignored" = ["ignored"],
|
||||
"X-Custom-Header" = ["Custom header value default"],
|
||||
}
|
||||
}
|
||||
}
|
||||
listener "tcp" {
|
||||
address = "127.0.0.3:8200"
|
||||
tls_disable = true
|
||||
custom_response_headers {
|
||||
"2xx" = {
|
||||
"X-Custom-Header" = ["Custom header value 2xx"]
|
||||
}
|
||||
}
|
||||
}
|
||||
listener "tcp" {
|
||||
address = "127.0.0.4:8200"
|
||||
tls_disable = true
|
||||
}
|
||||
|
||||
|
||||
disable_mlock = true
|
|
@ -0,0 +1,128 @@
|
|||
package http
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/vault/vault"
|
||||
)
|
||||
|
||||
var defaultCustomHeaders = map[string]string {
|
||||
"Strict-Transport-Security": "max-age=1; domains",
|
||||
"Content-Security-Policy": "default-src 'others'",
|
||||
"X-Custom-Header": "Custom header value default",
|
||||
"X-Frame-Options": "Deny",
|
||||
"X-Content-Type-Options": "nosniff",
|
||||
"Content-Type": "application/json",
|
||||
"X-XSS-Protection": "1; mode=block",
|
||||
}
|
||||
|
||||
var customHeader2xx = map[string]string {
|
||||
"X-Custom-Header": "Custom header value 2xx",
|
||||
}
|
||||
|
||||
var customHeader200 = map[string]string {
|
||||
"Someheader-200": "200",
|
||||
"X-Custom-Header": "Custom header value 200",
|
||||
}
|
||||
|
||||
var customHeader4xx = map[string]string {
|
||||
"Someheader-4xx": "4xx",
|
||||
}
|
||||
|
||||
var customHeader400 = map[string]string {
|
||||
"Someheader-400": "400",
|
||||
}
|
||||
|
||||
var customHeader405 = map[string]string {
|
||||
"Someheader-405": "405",
|
||||
}
|
||||
|
||||
var CustomResponseHeaders = map[string]map[string]string{
|
||||
"default": defaultCustomHeaders,
|
||||
"307": {"X-Custom-Header": "Custom header value 307"},
|
||||
"3xx": {
|
||||
"X-Custom-Header": "Custom header value 3xx",
|
||||
"X-Vault-Ignored-3xx": "Ignored 3xx",
|
||||
},
|
||||
"200": customHeader200,
|
||||
"2xx": customHeader2xx,
|
||||
"400": customHeader400,
|
||||
"405": customHeader405,
|
||||
"4xx": customHeader4xx,
|
||||
}
|
||||
|
||||
func TestCustomResponseHeaders(t *testing.T) {
|
||||
core, _, token := vault.TestCoreWithCustomResponseHeaderAndUI(t, CustomResponseHeaders, true)
|
||||
ln, addr := TestServer(t, core)
|
||||
defer ln.Close()
|
||||
TestServerAuth(t, addr, token)
|
||||
|
||||
resp := testHttpGet(t, token, addr+"/v1/sys/raw/")
|
||||
testResponseStatus(t, resp, 404)
|
||||
testResponseHeader(t, resp, defaultCustomHeaders)
|
||||
testResponseHeader(t, resp, customHeader4xx)
|
||||
|
||||
resp = testHttpGet(t, token, addr+"/v1/sys/seal")
|
||||
testResponseStatus(t, resp, 405)
|
||||
testResponseHeader(t, resp, defaultCustomHeaders)
|
||||
testResponseHeader(t, resp, customHeader4xx)
|
||||
testResponseHeader(t, resp, customHeader405)
|
||||
|
||||
resp = testHttpGet(t, token, addr+"/v1/sys/leader")
|
||||
testResponseStatus(t, resp, 200)
|
||||
testResponseHeader(t, resp, customHeader200)
|
||||
|
||||
resp = testHttpGet(t, token, addr+"/v1/sys/health")
|
||||
testResponseStatus(t, resp, 200)
|
||||
testResponseHeader(t, resp, customHeader200)
|
||||
|
||||
resp = testHttpGet(t, token, addr+"/v1/sys/generate-root/attempt")
|
||||
testResponseStatus(t, resp, 200)
|
||||
testResponseHeader(t, resp, customHeader200)
|
||||
|
||||
resp = testHttpGet(t, token, addr+"/v1/sys/generate-root/update")
|
||||
testResponseStatus(t, resp, 400)
|
||||
testResponseHeader(t, resp, defaultCustomHeaders)
|
||||
testResponseHeader(t, resp, customHeader4xx)
|
||||
testResponseHeader(t, resp, customHeader400)
|
||||
|
||||
resp = testHttpGet(t, token, addr+"/v1/sys/")
|
||||
testResponseStatus(t, resp, 404)
|
||||
testResponseHeader(t, resp, defaultCustomHeaders)
|
||||
testResponseHeader(t, resp, customHeader4xx)
|
||||
|
||||
resp = testHttpGet(t, token, addr+"/v1/sys")
|
||||
testResponseStatus(t, resp, 404)
|
||||
testResponseHeader(t, resp, defaultCustomHeaders)
|
||||
testResponseHeader(t, resp, customHeader4xx)
|
||||
|
||||
resp = testHttpGet(t, token, addr+"/v1/")
|
||||
testResponseStatus(t, resp, 404)
|
||||
testResponseHeader(t, resp, defaultCustomHeaders)
|
||||
testResponseHeader(t, resp, customHeader4xx)
|
||||
|
||||
resp = testHttpGet(t, token, addr+"/v1")
|
||||
testResponseStatus(t, resp, 404)
|
||||
testResponseHeader(t, resp, defaultCustomHeaders)
|
||||
testResponseHeader(t, resp, customHeader4xx)
|
||||
|
||||
resp = testHttpGet(t, token, addr+"/")
|
||||
testResponseStatus(t, resp, 200)
|
||||
testResponseHeader(t, resp, customHeader200)
|
||||
|
||||
resp = testHttpGet(t, token, addr+"/ui")
|
||||
testResponseStatus(t, resp, 200)
|
||||
testResponseHeader(t, resp, customHeader200)
|
||||
|
||||
resp = testHttpGet(t, token, addr+"/ui/")
|
||||
testResponseStatus(t, resp, 200)
|
||||
testResponseHeader(t, resp, customHeader200)
|
||||
|
||||
resp = testHttpPost(t, token, addr+"/v1/sys/auth/foo", map[string]interface{}{
|
||||
"type": "noop",
|
||||
"description": "foo",
|
||||
})
|
||||
testResponseStatus(t, resp, 204)
|
||||
testResponseHeader(t, resp, customHeader2xx)
|
||||
|
||||
}
|
112
http/handler.go
112
http/handler.go
|
@ -16,12 +16,14 @@ import (
|
|||
"net/textproto"
|
||||
"net/url"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/NYTimes/gziphandler"
|
||||
"github.com/hashicorp/errwrap"
|
||||
"github.com/hashicorp/go-cleanhttp"
|
||||
log "github.com/hashicorp/go-hclog"
|
||||
"github.com/hashicorp/go-secure-stdlib/parseutil"
|
||||
"github.com/hashicorp/go-sockaddr"
|
||||
"github.com/hashicorp/vault/helper/namespace"
|
||||
|
@ -210,6 +212,90 @@ func Handler(props *vault.HandlerProperties) http.Handler {
|
|||
return printablePathCheckHandler
|
||||
}
|
||||
|
||||
type WrappingResponseWriter interface {
|
||||
http.ResponseWriter
|
||||
Wrapped() http.ResponseWriter
|
||||
}
|
||||
|
||||
type statusHeaderResponseWriter struct {
|
||||
wrapped http.ResponseWriter
|
||||
logger log.Logger
|
||||
wroteHeader bool
|
||||
statusCode int
|
||||
headers map[string][]*vault.CustomHeader
|
||||
}
|
||||
|
||||
func (w *statusHeaderResponseWriter) Wrapped() http.ResponseWriter {
|
||||
return w.wrapped
|
||||
}
|
||||
|
||||
func (w *statusHeaderResponseWriter) Header() http.Header {
|
||||
return w.wrapped.Header()
|
||||
}
|
||||
|
||||
func (w *statusHeaderResponseWriter) Write(buf []byte) (int, error) {
|
||||
// It is allowed to only call ResponseWriter.Write and skip
|
||||
// ResponseWriter.WriteHeader. An example of such a situation is
|
||||
// "handleUIStub". The Write function will internally set the status code
|
||||
// 200 for the response for which that call might invoke other
|
||||
// implementations of the WriteHeader function. So, we still need to set
|
||||
// the custom headers. In cases where both WriteHeader and Write of
|
||||
// statusHeaderResponseWriter struct are called the internal call to the
|
||||
// WriterHeader invoked from inside Write method won't change the headers.
|
||||
if !w.wroteHeader {
|
||||
w.setCustomResponseHeaders(w.statusCode)
|
||||
}
|
||||
|
||||
return w.wrapped.Write(buf)
|
||||
}
|
||||
|
||||
func (w *statusHeaderResponseWriter) WriteHeader(statusCode int) {
|
||||
w.setCustomResponseHeaders(statusCode)
|
||||
w.wrapped.WriteHeader(statusCode)
|
||||
w.statusCode = statusCode
|
||||
// in cases where Write is called after WriteHeader, let's prevent setting
|
||||
// ResponseWriter headers twice
|
||||
w.wroteHeader = true
|
||||
}
|
||||
|
||||
func (w *statusHeaderResponseWriter) setCustomResponseHeaders(status int) {
|
||||
sch := w.headers
|
||||
if sch == nil {
|
||||
w.logger.Warn("status code header map not configured")
|
||||
return
|
||||
}
|
||||
|
||||
// Checking the validity of the status code
|
||||
if status >= 600 || status < 100 {
|
||||
return
|
||||
}
|
||||
|
||||
// setter function to set the headers
|
||||
setter := func(hvl []*vault.CustomHeader) {
|
||||
for _, hv := range hvl {
|
||||
w.Header().Set(hv.Name, hv.Value)
|
||||
}
|
||||
}
|
||||
|
||||
// Setting the default headers first
|
||||
setter(sch["default"])
|
||||
|
||||
// setting the Xyy pattern first
|
||||
d := fmt.Sprintf("%vxx", status/100)
|
||||
if val, ok := sch[d]; ok {
|
||||
setter(val)
|
||||
}
|
||||
|
||||
// Setting the specific headers
|
||||
if val, ok := sch[strconv.Itoa(status)]; ok {
|
||||
setter(val)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
var _ WrappingResponseWriter = &statusHeaderResponseWriter{}
|
||||
|
||||
type copyResponseWriter struct {
|
||||
wrapped http.ResponseWriter
|
||||
statusCode int
|
||||
|
@ -300,6 +386,22 @@ func wrapGenericHandler(core *vault.Core, h http.Handler, props *vault.HandlerPr
|
|||
hostname, _ := os.Hostname()
|
||||
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// This block needs to be here so that upon sending SIGHUP, custom response
|
||||
// headers are also reloaded into the handlers.
|
||||
if props.ListenerConfig != nil {
|
||||
la := props.ListenerConfig.Address
|
||||
listenerCustomHeaders := core.GetListenerCustomResponseHeaders(la)
|
||||
if listenerCustomHeaders != nil {
|
||||
w = &statusHeaderResponseWriter{
|
||||
wrapped: w,
|
||||
logger: core.Logger(),
|
||||
wroteHeader: false,
|
||||
statusCode: 200,
|
||||
headers: listenerCustomHeaders.StatusCodeHeaderMap,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Set the Cache-Control header for all the responses returned
|
||||
// by Vault
|
||||
w.Header().Set("Cache-Control", "no-store")
|
||||
|
@ -632,7 +734,15 @@ func parseJSONRequest(perfStandby bool, r *http.Request, w http.ResponseWriter,
|
|||
return nil, errors.New("could not parse max_request_size from request context")
|
||||
}
|
||||
if max > 0 {
|
||||
reader = http.MaxBytesReader(w, r.Body, max)
|
||||
// MaxBytesReader won't do all the internal stuff it must unless it's
|
||||
// given a ResponseWriter that implements the internal http interface
|
||||
// requestTooLarger. So we let it have access to the underlying
|
||||
// ResponseWriter.
|
||||
inw := w
|
||||
if myw, ok := inw.(WrappingResponseWriter); ok {
|
||||
inw = myw.Wrapped()
|
||||
}
|
||||
reader = http.MaxBytesReader(inw, r.Body, max)
|
||||
}
|
||||
}
|
||||
var origBody io.ReadWriter
|
||||
|
|
|
@ -125,6 +125,16 @@ func testResponseStatus(t *testing.T, resp *http.Response, code int) {
|
|||
}
|
||||
}
|
||||
|
||||
func testResponseHeader(t *testing.T, resp *http.Response, expectedHeaders map[string]string) {
|
||||
t.Helper()
|
||||
for k, v := range expectedHeaders {
|
||||
hv := resp.Header.Get(k)
|
||||
if v != hv {
|
||||
t.Fatalf("expected header value %v=%v, got %v=%v", k, v, k, hv)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func testResponseBody(t *testing.T, resp *http.Response, out interface{}) {
|
||||
defer resp.Body.Close()
|
||||
|
||||
|
|
|
@ -35,12 +35,14 @@ func handleMetricsUnauthenticated(core *vault.Core) http.Handler {
|
|||
resp := core.MetricsHelper().ResponseForFormat(format)
|
||||
|
||||
// Manually extract the logical response and send back the information
|
||||
w.WriteHeader(resp.Data[logical.HTTPStatusCode].(int))
|
||||
status := resp.Data[logical.HTTPStatusCode].(int)
|
||||
w.Header().Set("Content-Type", resp.Data[logical.HTTPContentType].(string))
|
||||
switch v := resp.Data[logical.HTTPRawBody].(type) {
|
||||
case string:
|
||||
w.WriteHeader(status)
|
||||
w.Write([]byte(v))
|
||||
case []byte:
|
||||
w.WriteHeader(status)
|
||||
w.Write(v)
|
||||
default:
|
||||
respondError(w, http.StatusInternalServerError, fmt.Errorf("wrong response returned"))
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/vault/internalshared/configutil"
|
||||
"github.com/hashicorp/vault/vault"
|
||||
)
|
||||
|
||||
|
@ -41,10 +42,16 @@ func TestServerWithListenerAndProperties(tb testing.TB, ln net.Listener, addr st
|
|||
}
|
||||
|
||||
func TestServerWithListener(tb testing.TB, ln net.Listener, addr string, core *vault.Core) {
|
||||
ip, _, _ := net.SplitHostPort(ln.Addr().String())
|
||||
|
||||
// Create a muxer to handle our requests so that we can authenticate
|
||||
// for tests.
|
||||
props := &vault.HandlerProperties{
|
||||
Core: core,
|
||||
// This is needed for testing custom response headers
|
||||
ListenerConfig: &configutil.Listener {
|
||||
Address: ip,
|
||||
},
|
||||
}
|
||||
TestServerWithListenerAndProperties(tb, ln, addr, core, props)
|
||||
}
|
||||
|
|
|
@ -0,0 +1,129 @@
|
|||
package configutil
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/textproto"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/go-secure-stdlib/strutil"
|
||||
)
|
||||
|
||||
var ValidCustomStatusCodeCollection = []string{
|
||||
"default",
|
||||
"1xx",
|
||||
"2xx",
|
||||
"3xx",
|
||||
"4xx",
|
||||
"5xx",
|
||||
}
|
||||
|
||||
const StrictTransportSecurity = "max-age=31536000; includeSubDomains"
|
||||
|
||||
// ParseCustomResponseHeaders takes a raw config values for the
|
||||
// "custom_response_headers". It makes sure the config entry is passed in
|
||||
// as a map of status code to a map of header name and header values. It
|
||||
// verifies the validity of the status codes, and header values. It also
|
||||
// adds the default headers values.
|
||||
func ParseCustomResponseHeaders(responseHeaders interface{}) (map[string]map[string]string, error) {
|
||||
h := make(map[string]map[string]string)
|
||||
// if r is nil, we still should set the default custom headers
|
||||
if responseHeaders == nil {
|
||||
h["default"] = map[string]string{"Strict-Transport-Security": StrictTransportSecurity}
|
||||
return h, nil
|
||||
}
|
||||
|
||||
customResponseHeader, ok := responseHeaders.([]map[string]interface{})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("response headers were not configured correctly. please make sure they're in a slice of maps")
|
||||
}
|
||||
|
||||
for _, crh := range customResponseHeader {
|
||||
for statusCode, responseHeader := range crh {
|
||||
headerValList, ok := responseHeader.([]map[string]interface{})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("response headers were not configured correctly. please make sure they're in a slice of maps")
|
||||
}
|
||||
|
||||
if !IsValidStatusCode(statusCode) {
|
||||
return nil, fmt.Errorf("invalid status code found in the server configuration: %v", statusCode)
|
||||
}
|
||||
|
||||
if len(headerValList) != 1 {
|
||||
return nil, fmt.Errorf("invalid number of response headers exist")
|
||||
}
|
||||
headerValMap := headerValList[0]
|
||||
headerVal, err := parseHeaders(headerValMap)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
h[statusCode] = headerVal
|
||||
}
|
||||
}
|
||||
|
||||
// setting Strict-Transport-Security as a default header
|
||||
if h["default"] == nil {
|
||||
h["default"] = make(map[string]string)
|
||||
}
|
||||
if _, ok := h["default"]["Strict-Transport-Security"]; !ok {
|
||||
h["default"]["Strict-Transport-Security"] = StrictTransportSecurity
|
||||
}
|
||||
|
||||
return h, nil
|
||||
}
|
||||
|
||||
// IsValidStatusCode checking for status codes outside the boundary
|
||||
func IsValidStatusCode(sc string) bool {
|
||||
if strutil.StrListContains(ValidCustomStatusCodeCollection, sc) {
|
||||
return true
|
||||
}
|
||||
|
||||
i, err := strconv.Atoi(sc)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if i >= 600 || i < 100 {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func parseHeaders(in map[string]interface{}) (map[string]string, error) {
|
||||
hvMap := make(map[string]string)
|
||||
for k, v := range in {
|
||||
// parsing header name
|
||||
headerName := textproto.CanonicalMIMEHeaderKey(k)
|
||||
// parsing header values
|
||||
s, err := parseHeaderValues(v)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
hvMap[headerName] = s
|
||||
}
|
||||
return hvMap, nil
|
||||
}
|
||||
|
||||
func parseHeaderValues(header interface{}) (string, error) {
|
||||
var sl []string
|
||||
if _, ok := header.([]interface{}); !ok {
|
||||
return "", fmt.Errorf("headers must be given in a list of strings")
|
||||
}
|
||||
headerValList := header.([]interface{})
|
||||
for _, vh := range headerValList {
|
||||
if _, ok := vh.(string); !ok {
|
||||
return "", fmt.Errorf("found a non-string header value: %v", vh)
|
||||
}
|
||||
headerVal := strings.TrimSpace(vh.(string))
|
||||
if headerVal == "" {
|
||||
continue
|
||||
}
|
||||
sl = append(sl, headerVal)
|
||||
|
||||
}
|
||||
s := strings.Join(sl, "; ")
|
||||
|
||||
return s, nil
|
||||
}
|
|
@ -99,6 +99,10 @@ type Listener struct {
|
|||
CorsAllowedOrigins []string `hcl:"cors_allowed_origins"`
|
||||
CorsAllowedHeaders []string `hcl:"-"`
|
||||
CorsAllowedHeadersRaw []string `hcl:"cors_allowed_headers,alias:cors_allowed_headers"`
|
||||
|
||||
// Custom Http response headers
|
||||
CustomResponseHeaders map[string]map[string]string `hcl:"-"`
|
||||
CustomResponseHeadersRaw interface{} `hcl:"custom_response_headers"`
|
||||
}
|
||||
|
||||
func (l *Listener) GoString() string {
|
||||
|
@ -361,6 +365,17 @@ func ParseListeners(result *SharedConfig, list *ast.ObjectList) error {
|
|||
}
|
||||
}
|
||||
|
||||
// HTTP Headers
|
||||
{
|
||||
// if CustomResponseHeadersRaw is nil, we still need to set the default headers
|
||||
customHeadersMap, err := ParseCustomResponseHeaders(l.CustomResponseHeadersRaw)
|
||||
if err != nil {
|
||||
return multierror.Prefix(fmt.Errorf("failed to parse custom_response_headers: %w", err), fmt.Sprintf("listeners.%d", i))
|
||||
}
|
||||
l.CustomResponseHeaders = customHeadersMap
|
||||
l.CustomResponseHeadersRaw = nil
|
||||
}
|
||||
|
||||
result.Listeners = append(result.Listeners, &l)
|
||||
}
|
||||
|
||||
|
|
111
vault/core.go
111
vault/core.go
|
@ -510,6 +510,9 @@ type Core struct {
|
|||
// clusterListener starts up and manages connections on the cluster ports
|
||||
clusterListener *atomic.Value
|
||||
|
||||
// customListenerHeader holds custom response headers for a listener
|
||||
customListenerHeader *atomic.Value
|
||||
|
||||
// Telemetry objects
|
||||
metricsHelper *metricsutil.MetricsHelper
|
||||
|
||||
|
@ -769,23 +772,24 @@ func CreateCore(conf *CoreConfig) (*Core, error) {
|
|||
|
||||
// Setup the core
|
||||
c := &Core{
|
||||
entCore: entCore{},
|
||||
devToken: conf.DevToken,
|
||||
physical: conf.Physical,
|
||||
serviceRegistration: conf.GetServiceRegistration(),
|
||||
underlyingPhysical: conf.Physical,
|
||||
storageType: conf.StorageType,
|
||||
redirectAddr: conf.RedirectAddr,
|
||||
clusterAddr: new(atomic.Value),
|
||||
clusterListener: new(atomic.Value),
|
||||
seal: conf.Seal,
|
||||
router: NewRouter(),
|
||||
sealed: new(uint32),
|
||||
sealMigrationDone: new(uint32),
|
||||
standby: true,
|
||||
standbyStopCh: new(atomic.Value),
|
||||
baseLogger: conf.Logger,
|
||||
logger: conf.Logger.Named("core"),
|
||||
entCore: entCore{},
|
||||
devToken: conf.DevToken,
|
||||
physical: conf.Physical,
|
||||
serviceRegistration: conf.GetServiceRegistration(),
|
||||
underlyingPhysical: conf.Physical,
|
||||
storageType: conf.StorageType,
|
||||
redirectAddr: conf.RedirectAddr,
|
||||
clusterAddr: new(atomic.Value),
|
||||
clusterListener: new(atomic.Value),
|
||||
customListenerHeader: new(atomic.Value),
|
||||
seal: conf.Seal,
|
||||
router: NewRouter(),
|
||||
sealed: new(uint32),
|
||||
sealMigrationDone: new(uint32),
|
||||
standby: true,
|
||||
standbyStopCh: new(atomic.Value),
|
||||
baseLogger: conf.Logger,
|
||||
logger: conf.Logger.Named("core"),
|
||||
|
||||
defaultLeaseTTL: conf.DefaultLeaseTTL,
|
||||
maxLeaseTTL: conf.MaxLeaseTTL,
|
||||
|
@ -1005,6 +1009,17 @@ func NewCore(conf *CoreConfig) (*Core, error) {
|
|||
|
||||
c.clusterListener.Store((*cluster.Listener)(nil))
|
||||
|
||||
// for listeners with custom response headers, configuring customListenerHeader
|
||||
if conf.RawConfig.Listeners != nil {
|
||||
uiHeaders, err := c.UIHeaders()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.customListenerHeader.Store(NewListenerCustomHeader(conf.RawConfig.Listeners, c.logger, uiHeaders))
|
||||
} else {
|
||||
c.customListenerHeader.Store(([]*ListenerCustomHeaders)(nil))
|
||||
}
|
||||
|
||||
quotasLogger := conf.Logger.Named("quotas")
|
||||
c.allLoggers = append(c.allLoggers, quotasLogger)
|
||||
c.quotaManager, err = quotas.NewManager(quotasLogger, c.quotaLeaseWalker, c.metricSink)
|
||||
|
@ -2641,6 +2656,68 @@ func (c *Core) SetConfig(conf *server.Config) {
|
|||
c.logger.Debug("set config", "sanitized config", string(bz))
|
||||
}
|
||||
|
||||
func (c *Core) GetListenerCustomResponseHeaders(listenerAdd string) *ListenerCustomHeaders {
|
||||
|
||||
customHeaders := c.customListenerHeader.Load()
|
||||
if customHeaders == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
customHeadersList, ok := customHeaders.([]*ListenerCustomHeaders)
|
||||
if customHeadersList == nil || !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, l := range customHeadersList {
|
||||
if l.Address == listenerAdd {
|
||||
return l
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ExistCustomResponseHeader checks if a custom header is configured in any
|
||||
// listener's stanza
|
||||
func (c *Core) ExistCustomResponseHeader(header string) bool {
|
||||
customHeaders := c.customListenerHeader.Load()
|
||||
if customHeaders == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
customHeadersList, ok := customHeaders.([]*ListenerCustomHeaders)
|
||||
if customHeadersList == nil || !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, l := range customHeadersList {
|
||||
exist := l.ExistCustomResponseHeader(header)
|
||||
if exist {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *Core) ReloadCustomResponseHeaders() error {
|
||||
conf := c.rawConfig.Load()
|
||||
if conf == nil {
|
||||
return fmt.Errorf("failed to load core raw config")
|
||||
}
|
||||
lns := conf.(*server.Config).Listeners
|
||||
if lns == nil {
|
||||
return fmt.Errorf("no listener configured")
|
||||
}
|
||||
|
||||
uiHeaders, err := c.UIHeaders()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.customListenerHeader.Store(NewListenerCustomHeader(lns, c.logger, uiHeaders))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SanitizedConfig returns a sanitized version of the current config.
|
||||
// See server.Config.Sanitized for specific values omitted.
|
||||
func (c *Core) SanitizedConfig() map[string]interface{} {
|
||||
|
|
|
@ -0,0 +1,90 @@
|
|||
package vault
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
"strings"
|
||||
|
||||
log "github.com/hashicorp/go-hclog"
|
||||
"github.com/hashicorp/vault/internalshared/configutil"
|
||||
)
|
||||
|
||||
type ListenerCustomHeaders struct {
|
||||
Address string
|
||||
StatusCodeHeaderMap map[string][]*CustomHeader
|
||||
// ConfiguredHeadersStatusCodeMap field is introduced so that we would not need to loop through
|
||||
// StatusCodeHeaderMap to see if a header exists, the key for this map is the headers names
|
||||
configuredHeadersStatusCodeMap map[string][]string
|
||||
}
|
||||
|
||||
type CustomHeader struct {
|
||||
Name string
|
||||
Value string
|
||||
}
|
||||
|
||||
func NewListenerCustomHeader(ln []*configutil.Listener, logger log.Logger, uiHeaders http.Header) []*ListenerCustomHeaders {
|
||||
var listenerCustomHeadersList []*ListenerCustomHeaders
|
||||
|
||||
for _, l := range ln {
|
||||
listenerCustomHeaderStruct := &ListenerCustomHeaders{
|
||||
Address: l.Address,
|
||||
}
|
||||
listenerCustomHeaderStruct.StatusCodeHeaderMap = make(map[string][]*CustomHeader)
|
||||
listenerCustomHeaderStruct.configuredHeadersStatusCodeMap = make(map[string][]string)
|
||||
for statusCode, headerValMap := range l.CustomResponseHeaders {
|
||||
var customHeaderList []*CustomHeader
|
||||
for headerName, headerVal := range headerValMap {
|
||||
// Sanitizing custom headers
|
||||
// X-Vault- prefix is reserved for Vault internal processes
|
||||
if strings.HasPrefix(headerName, "X-Vault-") {
|
||||
logger.Warn("custom headers starting with X-Vault are not valid", "header", headerName)
|
||||
continue
|
||||
}
|
||||
|
||||
// Checking for UI headers, if any common header exists, we just log an error
|
||||
if uiHeaders != nil {
|
||||
exist := uiHeaders.Get(headerName)
|
||||
if exist != "" {
|
||||
logger.Warn("found a duplicate header in UI", "header:", headerName, "Headers defined in the server configuration take precedence.")
|
||||
}
|
||||
}
|
||||
|
||||
// Checking if the header value is not an empty string
|
||||
if headerVal == "" {
|
||||
logger.Warn("header value is an empty string", "header", headerName, "value", headerVal)
|
||||
continue
|
||||
}
|
||||
|
||||
ch := &CustomHeader{
|
||||
Name: headerName,
|
||||
Value: headerVal,
|
||||
}
|
||||
|
||||
customHeaderList = append(customHeaderList, ch)
|
||||
|
||||
// setting up the reverse map of header to status code for easy lookups
|
||||
listenerCustomHeaderStruct.configuredHeadersStatusCodeMap[headerName] = append(listenerCustomHeaderStruct.configuredHeadersStatusCodeMap[headerName], statusCode)
|
||||
}
|
||||
listenerCustomHeaderStruct.StatusCodeHeaderMap[statusCode] = customHeaderList
|
||||
}
|
||||
listenerCustomHeadersList = append(listenerCustomHeadersList, listenerCustomHeaderStruct)
|
||||
}
|
||||
|
||||
return listenerCustomHeadersList
|
||||
}
|
||||
|
||||
func (l *ListenerCustomHeaders) ExistCustomResponseHeader(header string) bool {
|
||||
if header == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
if l.StatusCodeHeaderMap == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
headerName := textproto.CanonicalMIMEHeaderKey(header)
|
||||
|
||||
headerMap := l.configuredHeadersStatusCodeMap
|
||||
_, ok := headerMap[headerName]
|
||||
return ok
|
||||
}
|
|
@ -0,0 +1,174 @@
|
|||
package vault
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
log "github.com/hashicorp/go-hclog"
|
||||
"github.com/hashicorp/vault/helper/namespace"
|
||||
"github.com/hashicorp/vault/internalshared/configutil"
|
||||
"github.com/hashicorp/vault/sdk/helper/logging"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
"github.com/hashicorp/vault/sdk/physical/inmem"
|
||||
)
|
||||
|
||||
var defaultCustomHeaders = map[string]string{
|
||||
"Strict-Transport-Security": "max-age=1; domains",
|
||||
"Content-Security-Policy": "default-src 'others'",
|
||||
"X-Vault-Ignored": "ignored",
|
||||
"X-Custom-Header": "Custom header value default",
|
||||
"X-Frame-Options": "Deny",
|
||||
"X-Content-Type-Options": "nosniff",
|
||||
"Content-Type": "text/plain; charset=utf-8",
|
||||
"X-XSS-Protection": "1; mode=block",
|
||||
}
|
||||
|
||||
var customHeaders307 = map[string]string{
|
||||
"X-Custom-Header": "Custom header value 307",
|
||||
}
|
||||
|
||||
var customHeader3xx = map[string]string{
|
||||
"X-Vault-Ignored-3xx": "Ignored 3xx",
|
||||
"X-Custom-Header": "Custom header value 3xx",
|
||||
}
|
||||
|
||||
var customHeaders200 = map[string]string{
|
||||
"Someheader-200": "200",
|
||||
"X-Custom-Header": "Custom header value 200",
|
||||
}
|
||||
|
||||
var customHeader2xx = map[string]string{
|
||||
"X-Custom-Header": "Custom header value 2xx",
|
||||
}
|
||||
|
||||
var customHeader400 = map[string]string{
|
||||
"Someheader-400": "400",
|
||||
}
|
||||
|
||||
func TestConfigCustomHeaders(t *testing.T) {
|
||||
logger := logging.NewVaultLogger(log.Trace)
|
||||
phys, err := inmem.NewTransactionalInmem(nil, logger)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
logl := &logical.InmemStorage{}
|
||||
uiConfig := NewUIConfig(true, phys, logl)
|
||||
|
||||
rawListenerConfig := []*configutil.Listener{
|
||||
{
|
||||
Type: "tcp",
|
||||
Address: "127.0.0.1:443",
|
||||
CustomResponseHeaders: map[string]map[string]string{
|
||||
"default": defaultCustomHeaders,
|
||||
"307": customHeaders307,
|
||||
"3xx": customHeader3xx,
|
||||
"200": customHeaders200,
|
||||
"2xx": customHeader2xx,
|
||||
"400": customHeader400,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
uiHeaders, err := uiConfig.Headers(context.Background())
|
||||
listenerCustomHeaders := NewListenerCustomHeader(rawListenerConfig, logger, uiHeaders)
|
||||
if listenerCustomHeaders == nil || len(listenerCustomHeaders) != 1 {
|
||||
t.Fatalf("failed to get custom header configuration")
|
||||
}
|
||||
|
||||
lch := listenerCustomHeaders[0]
|
||||
|
||||
if lch.ExistCustomResponseHeader("X-Vault-Ignored-307") {
|
||||
t.Fatalf("header name with X-Vault prefix is not valid")
|
||||
}
|
||||
if lch.ExistCustomResponseHeader("X-Vault-Ignored-3xx") {
|
||||
t.Fatalf("header name with X-Vault prefix is not valid")
|
||||
}
|
||||
|
||||
if !lch.ExistCustomResponseHeader("X-Custom-Header") {
|
||||
t.Fatalf("header name with X-Vault prefix is not valid")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCustomResponseHeadersConfigInteractUiConfig(t *testing.T) {
|
||||
b := testSystemBackend(t)
|
||||
_, barrier, _ := mockBarrier(t)
|
||||
view := NewBarrierView(barrier, "")
|
||||
b.(*SystemBackend).Core.systemBarrierView = view
|
||||
|
||||
logger := logging.NewVaultLogger(log.Trace)
|
||||
rawListenerConfig := []*configutil.Listener{
|
||||
{
|
||||
Type: "tcp",
|
||||
Address: "127.0.0.1:443",
|
||||
CustomResponseHeaders: map[string]map[string]string{
|
||||
"default": defaultCustomHeaders,
|
||||
"307": customHeaders307,
|
||||
"3xx": customHeader3xx,
|
||||
"200": customHeaders200,
|
||||
"2xx": customHeader2xx,
|
||||
"400": customHeader400,
|
||||
},
|
||||
},
|
||||
}
|
||||
uiHeaders, err := b.(*SystemBackend).Core.uiConfig.Headers(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get headers from ui config")
|
||||
}
|
||||
customListenerHeader := NewListenerCustomHeader(rawListenerConfig, logger, uiHeaders)
|
||||
if customListenerHeader == nil {
|
||||
t.Fatalf("custom header config should be configured")
|
||||
}
|
||||
b.(*SystemBackend).Core.customListenerHeader.Store(customListenerHeader)
|
||||
clh := b.(*SystemBackend).Core.customListenerHeader
|
||||
if clh == nil {
|
||||
t.Fatalf("custom header config should be configured in core")
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
hw := logical.NewHTTPResponseWriter(w)
|
||||
|
||||
// setting a header that already exist in custom headers
|
||||
req := logical.TestRequest(t, logical.UpdateOperation, "config/ui/headers/X-Custom-Header")
|
||||
req.Data["values"] = []string{"UI Custom Header"}
|
||||
req.ResponseWriter = hw
|
||||
|
||||
resp, err := b.HandleRequest(namespace.RootContext(nil), req)
|
||||
if err == nil {
|
||||
t.Fatal("request did not fail on setting a header that is present in custom response headers")
|
||||
}
|
||||
if !strings.Contains(resp.Data["error"].(string), fmt.Sprintf("This header already exists in the server configuration and cannot be set in the UI.")) {
|
||||
t.Fatalf("failed to get the expected error")
|
||||
}
|
||||
|
||||
// setting a header that already exist in custom headers
|
||||
req = logical.TestRequest(t, logical.UpdateOperation, "config/ui/headers/Someheader-400")
|
||||
req.Data["values"] = []string{"400"}
|
||||
req.ResponseWriter = hw
|
||||
|
||||
_, err = b.HandleRequest(namespace.RootContext(nil), req)
|
||||
if err == nil {
|
||||
t.Fatal("request did not fail on setting a header that is present in custom response headers")
|
||||
}
|
||||
h, err := b.(*SystemBackend).Core.uiConfig.Headers(context.Background())
|
||||
if h.Get("Someheader-400") == "400" {
|
||||
t.Fatalf("should not be able to set a header that is in custom response headers")
|
||||
}
|
||||
|
||||
// setting an ui specific header
|
||||
req = logical.TestRequest(t, logical.UpdateOperation, "config/ui/headers/X-CustomUiHeader")
|
||||
req.Data["values"] = []string{"Ui header value"}
|
||||
req.ResponseWriter = hw
|
||||
|
||||
_, err = b.HandleRequest(namespace.RootContext(nil), req)
|
||||
if err != nil {
|
||||
t.Fatal("request failed on setting a header that is not present in custom response headers.", "error:", err)
|
||||
}
|
||||
|
||||
h, err = b.(*SystemBackend).Core.uiConfig.Headers(context.Background())
|
||||
if h.Get("X-CustomUiHeader") != "Ui header value" {
|
||||
t.Fatalf("failed to set a header that is not in custom response headers")
|
||||
}
|
||||
}
|
|
@ -2623,6 +2623,9 @@ func (b *SystemBackend) handleConfigUIHeadersUpdate(ctx context.Context, req *lo
|
|||
// Translate the list of values to the valid header string
|
||||
value := http.Header{}
|
||||
for _, v := range values {
|
||||
if b.Core.ExistCustomResponseHeader(header) {
|
||||
return logical.ErrorResponse("This header already exists in the server configuration and cannot be set in the UI."), logical.ErrInvalidRequest
|
||||
}
|
||||
value.Add(header, v)
|
||||
}
|
||||
err := b.Core.uiConfig.SetHeader(ctx, header, value.Values(header))
|
||||
|
|
|
@ -128,6 +128,29 @@ func TestCoreWithSeal(t testing.T, testSeal Seal, enableRaw bool) *Core {
|
|||
return TestCoreWithSealAndUI(t, conf)
|
||||
}
|
||||
|
||||
func TestCoreWithCustomResponseHeaderAndUI(t testing.T, CustomResponseHeaders map[string]map[string]string, enableUI bool) (*Core, [][]byte, string) {
|
||||
confRaw := &server.Config{
|
||||
SharedConfig: &configutil.SharedConfig{
|
||||
Listeners: []*configutil.Listener{
|
||||
{
|
||||
Type: "tcp",
|
||||
Address: "127.0.0.1",
|
||||
CustomResponseHeaders: CustomResponseHeaders,
|
||||
},
|
||||
},
|
||||
DisableMlock: true,
|
||||
},
|
||||
}
|
||||
conf := &CoreConfig{
|
||||
RawConfig: confRaw,
|
||||
EnableUI: enableUI,
|
||||
EnableRaw: true,
|
||||
BuiltinRegistry: NewMockBuiltinRegistry(),
|
||||
}
|
||||
core := TestCoreWithSealAndUI(t, conf)
|
||||
return testCoreUnsealed(t, core)
|
||||
}
|
||||
|
||||
func TestCoreUI(t testing.T, enableUI bool) *Core {
|
||||
conf := &CoreConfig{
|
||||
EnableUI: enableUI,
|
||||
|
|
|
@ -32,8 +32,9 @@ type UIConfig struct {
|
|||
// NewUIConfig creates a new UI config
|
||||
func NewUIConfig(enabled bool, physicalStorage physical.Backend, barrierStorage logical.Storage) *UIConfig {
|
||||
defaultHeaders := http.Header{}
|
||||
defaultHeaders.Set("Content-Security-Policy", "default-src 'none'; connect-src 'self'; img-src 'self' data:; script-src 'self'; style-src 'unsafe-inline' 'self'; form-action 'none'; frame-ancestors 'none'; font-src 'self'")
|
||||
defaultHeaders.Set("Service-Worker-Allowed", "/")
|
||||
defaultHeaders.Set("X-Content-Type-Options", "nosniff")
|
||||
defaultHeaders.Set("Content-Security-Policy", "default-src 'none'; connect-src 'self'; img-src 'self' data:; script-src 'self'; style-src 'unsafe-inline' 'self'; form-action 'none'; frame-ancestors 'none'; font-src 'self'")
|
||||
|
||||
return &UIConfig{
|
||||
physicalStorage: physicalStorage,
|
||||
|
|
Loading…
Reference in New Issue