diff --git a/internalshared/configutil/listener.go b/internalshared/configutil/listener.go index 7e08ba3d0..c5463a800 100644 --- a/internalshared/configutil/listener.go +++ b/internalshared/configutil/listener.go @@ -3,6 +3,7 @@ package configutil import ( "errors" "fmt" + "net/textproto" "strings" "time" @@ -11,6 +12,7 @@ import ( "github.com/hashicorp/hcl" "github.com/hashicorp/hcl/hcl/ast" "github.com/hashicorp/vault/sdk/helper/parseutil" + "github.com/hashicorp/vault/sdk/helper/strutil" "github.com/hashicorp/vault/sdk/helper/tlsutil" ) @@ -81,6 +83,12 @@ type Listener struct { // RandomPort is used only for some testing purposes RandomPort bool `hcl:"-"` + + CorsEnabledRaw interface{} `hcl:"cors_enabled"` + CorsEnabled bool `hcl:"-"` + CorsAllowedOrigins []string `hcl:"cors_allowed_origins"` + CorsAllowedHeaders []string `hcl:"-"` + CorsAllowedHeadersRaw []string `hcl:"cors_allowed_headers"` } func (l *Listener) GoString() string { @@ -127,6 +135,8 @@ func ParseListeners(result *SharedConfig, list *ast.ObjectList) error { for i, v := range l.Purpose { l.Purpose[i] = strings.ToLower(v) } + + l.PurposeRaw = nil } } @@ -308,6 +318,27 @@ func ParseListeners(result *SharedConfig, list *ast.ObjectList) error { } } + // CORS + { + if l.CorsEnabledRaw != nil { + if l.CorsEnabled, err = parseutil.ParseBool(l.CorsEnabledRaw); err != nil { + return multierror.Prefix(fmt.Errorf("invalid value for cors_enabled: %w", err), fmt.Sprintf("listeners.%d", i)) + } + + l.CorsEnabledRaw = nil + } + + if strutil.StrListContains(l.CorsAllowedOrigins, "*") && len(l.CorsAllowedOrigins) > 1 { + return multierror.Prefix(errors.New("cors_allowed_origins must only contain a wildcard or only non-wildcard values"), fmt.Sprintf("listeners.%d", i)) + } + + if len(l.CorsAllowedHeadersRaw) > 0 { + for _, header := range l.CorsAllowedHeadersRaw { + l.CorsAllowedHeaders = append(l.CorsAllowedHeaders, textproto.CanonicalMIMEHeaderKey(header)) + } + } + } + result.Listeners = append(result.Listeners, &l) }