// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
|
|
|
|
package vault
|
|
|
|
import (
|
|
"fmt"
|
|
"net/http"
|
|
"net/textproto"
|
|
"strings"
|
|
|
|
log "github.com/hashicorp/go-hclog"
|
|
"github.com/hashicorp/vault/internalshared/configutil"
|
|
"github.com/hashicorp/vault/sdk/logical"
|
|
)
|
|
|
|
type ListenerCustomHeaders struct {
|
|
Address string
|
|
StatusCodeHeaderMap map[string][]*logical.CustomHeader
|
|
// ConfiguredHeadersStatusCodeMap field is introduced so that we would not need to loop through
|
|
// StatusCodeHeaderMap to see if a header exists, the key for this map is the headers names
|
|
configuredHeadersStatusCodeMap map[string][]string
|
|
}
|
|
|
|
func NewListenerCustomHeader(ln []*configutil.Listener, logger log.Logger, uiHeaders http.Header) []*ListenerCustomHeaders {
|
|
var listenerCustomHeadersList []*ListenerCustomHeaders
|
|
|
|
for _, l := range ln {
|
|
listenerCustomHeaderStruct := &ListenerCustomHeaders{
|
|
Address: l.Address,
|
|
}
|
|
listenerCustomHeaderStruct.StatusCodeHeaderMap = make(map[string][]*logical.CustomHeader)
|
|
listenerCustomHeaderStruct.configuredHeadersStatusCodeMap = make(map[string][]string)
|
|
for statusCode, headerValMap := range l.CustomResponseHeaders {
|
|
var customHeaderList []*logical.CustomHeader
|
|
for headerName, headerVal := range headerValMap {
|
|
// Sanitizing custom headers
|
|
// X-Vault- prefix is reserved for Vault internal processes
|
|
if strings.HasPrefix(headerName, "X-Vault-") {
|
|
logger.Warn("custom headers starting with X-Vault are not valid", "header", headerName)
|
|
continue
|
|
}
|
|
|
|
// Checking for UI headers, if any common header exists, we just log an error
|
|
if uiHeaders != nil {
|
|
exist := uiHeaders.Get(headerName)
|
|
if exist != "" {
|
|
logger.Warn(fmt.Sprintf("found a duplicate header in UI: header=%s. Headers defined in the server configuration take precedence.", headerName))
|
|
}
|
|
}
|
|
|
|
// Checking if the header value is not an empty string
|
|
if headerVal == "" {
|
|
logger.Warn("header value is an empty string", "header", headerName, "value", headerVal)
|
|
continue
|
|
}
|
|
|
|
ch := &logical.CustomHeader{
|
|
Name: headerName,
|
|
Value: headerVal,
|
|
}
|
|
|
|
customHeaderList = append(customHeaderList, ch)
|
|
|
|
// setting up the reverse map of header to status code for easy lookups
|
|
listenerCustomHeaderStruct.configuredHeadersStatusCodeMap[headerName] = append(listenerCustomHeaderStruct.configuredHeadersStatusCodeMap[headerName], statusCode)
|
|
}
|
|
listenerCustomHeaderStruct.StatusCodeHeaderMap[statusCode] = customHeaderList
|
|
}
|
|
listenerCustomHeadersList = append(listenerCustomHeadersList, listenerCustomHeaderStruct)
|
|
}
|
|
|
|
return listenerCustomHeadersList
|
|
}
|
|
|
|
func (l *ListenerCustomHeaders) ExistCustomResponseHeader(header string) bool {
|
|
if header == "" {
|
|
return false
|
|
}
|
|
|
|
if l.StatusCodeHeaderMap == nil {
|
|
return false
|
|
}
|
|
|
|
headerName := textproto.CanonicalMIMEHeaderKey(header)
|
|
|
|
headerMap := l.configuredHeadersStatusCodeMap
|
|
_, ok := headerMap[headerName]
|
|
return ok
|
|
}
|