// open-vault/vault/custom_response_headers.go
//
// 91 lines
// 2.9 KiB
// Go

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package vault
import (
"fmt"
"net/http"
"net/textproto"
"strings"
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/internalshared/configutil"
"github.com/hashicorp/vault/sdk/logical"
)
// ListenerCustomHeaders holds the custom response headers configured for a
// single listener, identified by its address.
type ListenerCustomHeaders struct {
	// Address is the address of the listener these headers belong to.
	Address string

	// StatusCodeHeaderMap maps a status-code key (a string taken from the
	// listener configuration) to the custom headers configured under it.
	StatusCodeHeaderMap map[string][]*logical.CustomHeader

	// configuredHeadersStatusCodeMap is the reverse index of
	// StatusCodeHeaderMap: it maps a header name to the status-code keys the
	// header is configured under. It exists so callers can check whether a
	// header is configured without looping through StatusCodeHeaderMap.
	// Lookups canonicalize the requested name with
	// textproto.CanonicalMIMEHeaderKey, so keys must be stored in canonical
	// form to be found.
	configuredHeadersStatusCodeMap map[string][]string
}
// NewListenerCustomHeader builds a ListenerCustomHeaders value for each
// listener in ln from that listener's CustomResponseHeaders configuration.
//
// For every status-code key, headers are sanitized before being stored:
//   - header names are canonicalized with textproto.CanonicalMIMEHeaderKey,
//     so the reserved-prefix check is case-insensitive and the stored names
//     match the form ExistCustomResponseHeader looks up;
//   - headers using the reserved "X-Vault-" prefix are dropped with a warning;
//   - headers with an empty value are dropped with a warning;
//   - headers that are also present in uiHeaders are kept, with a warning
//     that the server configuration takes precedence.
//
// A nil logger is replaced by a no-op logger, and nil entries in ln are
// skipped. The result is nil when ln contains no non-nil listeners.
func NewListenerCustomHeader(ln []*configutil.Listener, logger log.Logger, uiHeaders http.Header) []*ListenerCustomHeaders {
	if logger == nil {
		logger = log.NewNullLogger()
	}

	var listenerCustomHeadersList []*ListenerCustomHeaders
	for _, l := range ln {
		if l == nil {
			continue
		}

		lch := &ListenerCustomHeaders{
			Address:                        l.Address,
			StatusCodeHeaderMap:            make(map[string][]*logical.CustomHeader),
			configuredHeadersStatusCodeMap: make(map[string][]string),
		}

		for statusCode, headerValMap := range l.CustomResponseHeaders {
			var customHeaderList []*logical.CustomHeader
			for rawName, headerVal := range headerValMap {
				// Canonicalize so that e.g. "x-vault-foo" cannot slip past the
				// reserved-prefix check, and so that ExistCustomResponseHeader
				// (which canonicalizes its argument) can find the header.
				headerName := textproto.CanonicalMIMEHeaderKey(rawName)

				// The X-Vault- prefix is reserved for Vault internal processes.
				if strings.HasPrefix(headerName, "X-Vault-") {
					logger.Warn("custom headers starting with X-Vault are not valid", "header", headerName)
					continue
				}

				// Drop empty values before the UI duplicate check, so we never
				// warn that a header "takes precedence" when it is discarded.
				if headerVal == "" {
					logger.Warn("header value is an empty string", "header", headerName, "value", headerVal)
					continue
				}

				// A header the UI also sets is kept because the server
				// configuration wins; the overlap is only reported.
				if uiHeaders != nil && uiHeaders.Get(headerName) != "" {
					logger.Warn(fmt.Sprintf("found a duplicate header in UI: header=%s. Headers defined in the server configuration take precedence.", headerName))
				}

				customHeaderList = append(customHeaderList, &logical.CustomHeader{
					Name:  headerName,
					Value: headerVal,
				})

				// Maintain the reverse index (header name -> status codes) for
				// cheap existence checks.
				lch.configuredHeadersStatusCodeMap[headerName] = append(lch.configuredHeadersStatusCodeMap[headerName], statusCode)
			}
			lch.StatusCodeHeaderMap[statusCode] = customHeaderList
		}

		listenerCustomHeadersList = append(listenerCustomHeadersList, lch)
	}
	return listenerCustomHeadersList
}
// ExistCustomResponseHeader reports whether header is configured as a custom
// response header for this listener under any status-code key.
//
// The name is canonicalized with textproto.CanonicalMIMEHeaderKey before the
// lookup, so "x-foo" and "X-Foo" resolve to the same configured entry. An
// empty name, a nil receiver, or a value whose header index was never
// initialized all report false.
func (l *ListenerCustomHeaders) ExistCustomResponseHeader(header string) bool {
	if l == nil || header == "" {
		return false
	}
	// configuredHeadersStatusCodeMap is the index actually consulted, so its
	// contents decide the answer. Reading a nil map is safe and yields
	// ok == false, so no separate initialization check is needed.
	_, ok := l.configuredHeadersStatusCodeMap[textproto.CanonicalMIMEHeaderKey(header)]
	return ok
}