open-vault/vault/cors.go

154 lines
3.5 KiB
Go
Raw Normal View History

2017-06-17 04:04:55 +00:00
package vault
import (
"errors"
"fmt"
"sync"
2017-06-17 05:26:25 +00:00
"sync/atomic"
2017-06-17 04:04:55 +00:00
"github.com/hashicorp/vault/helper/strutil"
"github.com/hashicorp/vault/logical"
)
2017-06-17 05:26:25 +00:00
// CORSDisabled and CORSEnabled are the two states stored in
// CORSConfig.Enabled. That field is accessed through sync/atomic, hence the
// explicit uint32 type.
const (
	CORSDisabled uint32 = 0
	CORSEnabled  uint32 = 1
)
2017-06-17 04:04:55 +00:00
2017-08-07 19:02:08 +00:00
// StdAllowedHeaders is the baseline set of request headers that Enable always
// places at the front of CORSConfig.AllowedHeaders, ahead of any
// user-supplied extras.
var (
	StdAllowedHeaders = []string{
		"Content-Type",
		"X-Requested-With",
		"X-Vault-AWS-IAM-Server-ID",
		"X-Vault-MFA",
		"X-Vault-No-Request-Forwarding",
		"X-Vault-Token",
		"X-Vault-Wrap-Format",
		"X-Vault-Wrap-TTL",
	}
)
2017-06-17 04:04:55 +00:00
// CORSConfig stores the state of the CORS configuration.
type CORSConfig struct {
2017-06-17 05:26:25 +00:00
sync.RWMutex `json:"-"`
core *Core
Enabled uint32 `json:"enabled"`
AllowedOrigins []string `json:"allowed_origins,omitempty"`
AllowedHeaders []string `json:"allowed_headers,omitempty"`
2017-06-17 04:04:55 +00:00
}
func (c *Core) saveCORSConfig() error {
view := c.systemBarrierView.SubView("config/")
2017-06-17 05:26:25 +00:00
localConfig := &CORSConfig{
Enabled: atomic.LoadUint32(&c.corsConfig.Enabled),
}
c.corsConfig.RLock()
localConfig.AllowedOrigins = c.corsConfig.AllowedOrigins
localConfig.AllowedHeaders = c.corsConfig.AllowedHeaders
2017-06-17 05:26:25 +00:00
c.corsConfig.RUnlock()
entry, err := logical.StorageEntryJSON("cors", localConfig)
2017-06-17 04:04:55 +00:00
if err != nil {
return fmt.Errorf("failed to create CORS config entry: %v", err)
}
if err := view.Put(entry); err != nil {
return fmt.Errorf("failed to save CORS config: %v", err)
}
return nil
}
2017-06-17 05:26:25 +00:00
// This should only be called with the core state lock held for writing
2017-06-17 04:04:55 +00:00
func (c *Core) loadCORSConfig() error {
view := c.systemBarrierView.SubView("config/")
// Load the config in
out, err := view.Get("cors")
if err != nil {
return fmt.Errorf("failed to read CORS config: %v", err)
}
if out == nil {
return nil
}
2017-06-17 05:26:25 +00:00
newConfig := new(CORSConfig)
err = out.DecodeJSON(newConfig)
2017-06-17 04:04:55 +00:00
if err != nil {
return err
}
2017-06-17 05:26:25 +00:00
newConfig.core = c
c.corsConfig = newConfig
2017-06-17 04:04:55 +00:00
return nil
}
// Enable takes either a '*' or a comma-seprated list of URLs that can make
// cross-origin requests to Vault.
func (c *CORSConfig) Enable(urls []string, headers []string) error {
2017-06-17 04:04:55 +00:00
if len(urls) == 0 {
return errors.New("at least one origin or the wildcard must be provided.")
2017-06-17 04:04:55 +00:00
}
if strutil.StrListContains(urls, "*") && len(urls) > 1 {
return errors.New("to allow all origins the '*' must be the only value for allowed_origins")
}
c.Lock()
c.AllowedOrigins = urls
// Start with the standard headers to Vault accepts.
2017-08-07 19:02:08 +00:00
c.AllowedHeaders = append(c.AllowedHeaders, StdAllowedHeaders...)
// Allow the user to add additional headers to the list of
// headers allowed on cross-origin requests.
if len(headers) > 0 {
c.AllowedHeaders = append(c.AllowedHeaders, headers...)
}
2017-06-17 05:26:25 +00:00
c.Unlock()
2017-06-17 04:04:55 +00:00
2017-06-17 05:26:25 +00:00
atomic.StoreUint32(&c.Enabled, CORSEnabled)
return c.core.saveCORSConfig()
2017-06-17 04:04:55 +00:00
}
// IsEnabled returns the value of CORSConfig.isEnabled
func (c *CORSConfig) IsEnabled() bool {
2017-06-17 05:26:25 +00:00
return atomic.LoadUint32(&c.Enabled) == CORSEnabled
2017-06-17 04:04:55 +00:00
}
// Disable sets CORS to disabled and clears the allowed origins & headers.
2017-06-17 05:26:25 +00:00
func (c *CORSConfig) Disable() error {
atomic.StoreUint32(&c.Enabled, CORSDisabled)
2017-06-17 04:04:55 +00:00
c.Lock()
c.AllowedOrigins = nil
c.AllowedHeaders = nil
2017-06-17 05:26:25 +00:00
c.Unlock()
2017-06-17 05:26:25 +00:00
return c.core.saveCORSConfig()
2017-06-17 04:04:55 +00:00
}
// IsValidOrigin determines if the origin of the request is allowed to make
// cross-origin requests based on the CORSConfig.
func (c *CORSConfig) IsValidOrigin(origin string) bool {
2017-06-17 05:26:25 +00:00
// If we aren't enabling CORS then all origins are valid
if !c.IsEnabled() {
return true
}
2017-06-17 04:04:55 +00:00
c.RLock()
defer c.RUnlock()
2017-06-17 05:26:25 +00:00
if len(c.AllowedOrigins) == 0 {
2017-06-17 04:04:55 +00:00
return false
}
if len(c.AllowedOrigins) == 1 && (c.AllowedOrigins)[0] == "*" {
return true
}
return strutil.StrListContains(c.AllowedOrigins, origin)
}