156 lines
3.7 KiB
Go
156 lines
3.7 KiB
Go
package vault
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"sync"
|
|
"sync/atomic"
|
|
|
|
"github.com/hashicorp/errwrap"
|
|
"github.com/hashicorp/vault/helper/strutil"
|
|
"github.com/hashicorp/vault/logical"
|
|
)
|
|
|
|
// Values stored in CORSConfig.Enabled.
const (
	// CORSDisabled is the zero value: CORS handling is switched off.
	CORSDisabled uint32 = iota

	// CORSEnabled marks CORS handling as switched on.
	CORSEnabled
)
|
|
|
|
// StdAllowedHeaders lists the request headers that are always placed on the
// CORS allowed-header list when CORS is enabled; callers of
// CORSConfig.Enable may supply extra headers on top of these.
var StdAllowedHeaders = []string{
	"Content-Type",
	"X-Requested-With",
	"X-Vault-AWS-IAM-Server-ID",
	"X-Vault-MFA",
	"X-Vault-No-Request-Forwarding",
	"X-Vault-Token",
	"X-Vault-Wrap-Format",
	"X-Vault-Wrap-TTL",
	"X-Vault-Policy-Override",
}
|
|
|
|
// CORSConfig stores the state of the CORS configuration.
|
|
type CORSConfig struct {
|
|
sync.RWMutex `json:"-"`
|
|
core *Core
|
|
Enabled uint32 `json:"enabled"`
|
|
AllowedOrigins []string `json:"allowed_origins,omitempty"`
|
|
AllowedHeaders []string `json:"allowed_headers,omitempty"`
|
|
}
|
|
|
|
func (c *Core) saveCORSConfig(ctx context.Context) error {
|
|
view := c.systemBarrierView.SubView("config/")
|
|
|
|
localConfig := &CORSConfig{
|
|
Enabled: atomic.LoadUint32(&c.corsConfig.Enabled),
|
|
}
|
|
c.corsConfig.RLock()
|
|
localConfig.AllowedOrigins = c.corsConfig.AllowedOrigins
|
|
localConfig.AllowedHeaders = c.corsConfig.AllowedHeaders
|
|
c.corsConfig.RUnlock()
|
|
|
|
entry, err := logical.StorageEntryJSON("cors", localConfig)
|
|
if err != nil {
|
|
return errwrap.Wrapf("failed to create CORS config entry: {{err}}", err)
|
|
}
|
|
|
|
if err := view.Put(ctx, entry); err != nil {
|
|
return errwrap.Wrapf("failed to save CORS config: {{err}}", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// This should only be called with the core state lock held for writing
|
|
func (c *Core) loadCORSConfig(ctx context.Context) error {
|
|
view := c.systemBarrierView.SubView("config/")
|
|
|
|
// Load the config in
|
|
out, err := view.Get(ctx, "cors")
|
|
if err != nil {
|
|
return errwrap.Wrapf("failed to read CORS config: {{err}}", err)
|
|
}
|
|
if out == nil {
|
|
return nil
|
|
}
|
|
|
|
newConfig := new(CORSConfig)
|
|
err = out.DecodeJSON(newConfig)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
newConfig.core = c
|
|
|
|
c.corsConfig = newConfig
|
|
|
|
return nil
|
|
}
|
|
|
|
// Enable takes either a '*' or a comma-separated list of URLs that can make
|
|
// cross-origin requests to Vault.
|
|
func (c *CORSConfig) Enable(ctx context.Context, urls []string, headers []string) error {
|
|
if len(urls) == 0 {
|
|
return errors.New("at least one origin or the wildcard must be provided")
|
|
}
|
|
|
|
if strutil.StrListContains(urls, "*") && len(urls) > 1 {
|
|
return errors.New("to allow all origins the '*' must be the only value for allowed_origins")
|
|
}
|
|
|
|
c.Lock()
|
|
c.AllowedOrigins = urls
|
|
|
|
// Start with the standard headers to Vault accepts.
|
|
c.AllowedHeaders = append(c.AllowedHeaders, StdAllowedHeaders...)
|
|
|
|
// Allow the user to add additional headers to the list of
|
|
// headers allowed on cross-origin requests.
|
|
if len(headers) > 0 {
|
|
c.AllowedHeaders = append(c.AllowedHeaders, headers...)
|
|
}
|
|
c.Unlock()
|
|
|
|
atomic.StoreUint32(&c.Enabled, CORSEnabled)
|
|
|
|
return c.core.saveCORSConfig(ctx)
|
|
}
|
|
|
|
// IsEnabled returns the value of CORSConfig.isEnabled
|
|
func (c *CORSConfig) IsEnabled() bool {
|
|
return atomic.LoadUint32(&c.Enabled) == CORSEnabled
|
|
}
|
|
|
|
// Disable sets CORS to disabled and clears the allowed origins & headers.
|
|
func (c *CORSConfig) Disable(ctx context.Context) error {
|
|
atomic.StoreUint32(&c.Enabled, CORSDisabled)
|
|
c.Lock()
|
|
|
|
c.AllowedOrigins = nil
|
|
c.AllowedHeaders = nil
|
|
|
|
c.Unlock()
|
|
|
|
return c.core.saveCORSConfig(ctx)
|
|
}
|
|
|
|
// IsValidOrigin determines if the origin of the request is allowed to make
|
|
// cross-origin requests based on the CORSConfig.
|
|
func (c *CORSConfig) IsValidOrigin(origin string) bool {
|
|
// If we aren't enabling CORS then all origins are valid
|
|
if !c.IsEnabled() {
|
|
return true
|
|
}
|
|
|
|
c.RLock()
|
|
defer c.RUnlock()
|
|
|
|
if len(c.AllowedOrigins) == 0 {
|
|
return false
|
|
}
|
|
|
|
if len(c.AllowedOrigins) == 1 && (c.AllowedOrigins)[0] == "*" {
|
|
return true
|
|
}
|
|
|
|
return strutil.StrListContains(c.AllowedOrigins, origin)
|
|
}
|