Set allowed headers via API instead of defaulting to wildcard. (#3023)

This commit is contained in:
Aaron Salvo 2017-08-07 10:03:30 -04:00 committed by Jeff Mitchell
parent 3fb75beb59
commit ad1d74cae0
8 changed files with 142 additions and 24 deletions

View File

@ -9,11 +9,6 @@ import (
"github.com/hashicorp/vault/vault"
)
var preflightHeaders = map[string]string{
"Access-Control-Allow-Headers": "*",
"Access-Control-Max-Age": "300",
}
var allowedMethods = []string{
http.MethodDelete,
http.MethodGet,
@ -38,8 +33,7 @@ func wrapCORSHandler(h http.Handler, core *vault.Core) http.Handler {
return
}
// Return a 403 if the origin is not
// allowed to make cross-origin requests.
// Return a 403 if the origin is not allowed to make cross-origin requests.
if !corsConf.IsValidOrigin(origin) {
respondError(w, http.StatusForbidden, fmt.Errorf("origin not allowed"))
return
@ -56,10 +50,9 @@ func wrapCORSHandler(h http.Handler, core *vault.Core) http.Handler {
// apply headers for preflight requests
if req.Method == http.MethodOptions {
w.Header().Set("Access-Control-Allow-Methods", strings.Join(allowedMethods, ","))
w.Header().Set("Access-Control-Allow-Headers", strings.Join(corsConf.AllowedHeaders, ","))
w.Header().Set("Access-Control-Max-Age", "300")
for k, v := range preflightHeaders {
w.Header().Set(k, v)
}
return
}

View File

@ -6,6 +6,7 @@ import (
"net/http"
"net/http/httptest"
"reflect"
"strings"
"testing"
"github.com/hashicorp/go-cleanhttp"
@ -21,7 +22,7 @@ func TestHandler_cors(t *testing.T) {
// Enable CORS and allow from any origin for testing.
corsConfig := core.CORSConfig()
err := corsConfig.Enable([]string{addr})
err := corsConfig.Enable([]string{addr}, nil)
if err != nil {
t.Fatalf("Error enabling CORS: %s", err)
}
@ -78,7 +79,7 @@ func TestHandler_cors(t *testing.T) {
//
expHeaders := map[string]string{
"Access-Control-Allow-Origin": addr,
"Access-Control-Allow-Headers": "*",
"Access-Control-Allow-Headers": strings.Join(stdAllowedHeaders, ","),
"Access-Control-Max-Age": "300",
"Vary": "Origin",
}

View File

@ -0,0 +1,78 @@
package http
import (
"encoding/json"
"net/http"
"reflect"
"testing"
"github.com/hashicorp/vault/vault"
)
func TestSysConfigCors(t *testing.T) {
var resp *http.Response
core, _, token := vault.TestCoreUnsealed(t)
ln, addr := TestServer(t, core)
defer ln.Close()
TestServerAuth(t, addr, token)
corsConf := core.CORSConfig()
// Try to enable CORS without providing a value for allowed_origins
resp = testHttpPut(t, token, addr+"/v1/sys/config/cors", map[string]interface{}{
"allowed_headers": "X-Custom-Header",
})
testResponseStatus(t, resp, 500)
// Enable CORS, but provide an origin this time.
resp = testHttpPut(t, token, addr+"/v1/sys/config/cors", map[string]interface{}{
"allowed_origins": addr,
"allowed_headers": "X-Custom-Header",
})
testResponseStatus(t, resp, 204)
// Read the CORS configuration
resp = testHttpGet(t, token, addr+"/v1/sys/config/cors")
testResponseStatus(t, resp, 200)
var actual map[string]interface{}
var expected map[string]interface{}
lenStdHeaders := len(corsConf.AllowedHeaders)
expectedHeaders := make([]interface{}, lenStdHeaders)
for i := range corsConf.AllowedHeaders {
expectedHeaders[i] = corsConf.AllowedHeaders[i]
}
expected = map[string]interface{}{
"lease_id": "",
"renewable": false,
"lease_duration": json.Number("0"),
"wrap_info": nil,
"warnings": nil,
"auth": nil,
"data": map[string]interface{}{
"enabled": true,
"allowed_origins": []interface{}{addr},
"allowed_headers": expectedHeaders,
},
"enabled": true,
"allowed_origins": []interface{}{addr},
"allowed_headers": expectedHeaders,
}
testResponseStatus(t, resp, 200)
testResponseBody(t, resp, &actual)
expected["request_id"] = actual["request_id"]
if !reflect.DeepEqual(actual, expected) {
t.Fatalf("bad: expected: %#v\nactual: %#v", expected, actual)
}
}

View File

@ -459,8 +459,8 @@ func NewCore(conf *CoreConfig) (*Core, error) {
enableMlock: !conf.DisableMlock,
}
// Load CORS config and provide core
c.corsConfig = &CORSConfig{core: c}
// Load CORS config and provide a value for the core field.
// Wrap the physical backend in a cache layer if enabled and not already wrapped
if _, isCache := conf.Physical.(*physical.Cache); !conf.DisableCache && !isCache {

View File

@ -15,12 +15,24 @@ const (
CORSEnabled
)
var stdAllowedHeaders = []string{
"Content-Type",
"X-Requested-With",
"X-Vault-AWS-IAM-Server-ID",
"X-Vault-MFA",
"X-Vault-No-Request-Forwarding",
"X-Vault-Token",
"X-Vault-Wrap-Format",
"X-Vault-Wrap-TTL",
}
// CORSConfig stores the state of the CORS configuration.
type CORSConfig struct {
sync.RWMutex `json:"-"`
core *Core
Enabled uint32 `json:"enabled"`
AllowedOrigins []string `json:"allowed_origins,omitempty"`
AllowedHeaders []string `json:"allowed_headers,omitempty"`
}
func (c *Core) saveCORSConfig() error {
@ -31,6 +43,7 @@ func (c *Core) saveCORSConfig() error {
}
c.corsConfig.RLock()
localConfig.AllowedOrigins = c.corsConfig.AllowedOrigins
localConfig.AllowedHeaders = c.corsConfig.AllowedHeaders
c.corsConfig.RUnlock()
entry, err := logical.StorageEntryJSON("cors", localConfig)
@ -72,9 +85,9 @@ func (c *Core) loadCORSConfig() error {
// Enable takes either a '*' or a comma-seprated list of URLs that can make
// cross-origin requests to Vault.
func (c *CORSConfig) Enable(urls []string) error {
func (c *CORSConfig) Enable(urls []string, headers []string) error {
if len(urls) == 0 {
return errors.New("the list of allowed origins cannot be empty")
return errors.New("at least one origin or the wildcard must be provided.")
}
if strutil.StrListContains(urls, "*") && len(urls) > 1 {
@ -83,6 +96,15 @@ func (c *CORSConfig) Enable(urls []string) error {
c.Lock()
c.AllowedOrigins = urls
// Start with the standard headers to Vault accepts.
c.AllowedHeaders = append(c.AllowedHeaders, stdAllowedHeaders...)
// Allow the user to add additional headers to the list of
// headers allowed on cross-origin requests.
if len(headers) > 0 {
c.AllowedHeaders = append(c.AllowedHeaders, headers...)
}
c.Unlock()
atomic.StoreUint32(&c.Enabled, CORSEnabled)
@ -95,12 +117,16 @@ func (c *CORSConfig) IsEnabled() bool {
return atomic.LoadUint32(&c.Enabled) == CORSEnabled
}
// Disable sets CORS to disabled and clears the allowed origins
// Disable sets CORS to disabled and clears the allowed origins & headers.
func (c *CORSConfig) Disable() error {
atomic.StoreUint32(&c.Enabled, CORSDisabled)
c.Lock()
c.AllowedOrigins = []string(nil)
c.AllowedOrigins = nil
c.AllowedHeaders = nil
c.Unlock()
return c.core.saveCORSConfig()
}

View File

@ -115,6 +115,10 @@ func NewSystemBackend(core *Core) *SystemBackend {
Type: framework.TypeCommaStringSlice,
Description: "A comma-separated string or array of strings indicating origins that may make cross-origin requests.",
},
"allowed_headers": &framework.FieldSchema{
Type: framework.TypeCommaStringSlice,
Description: "A comma-separated string or array of strings indicating headers that are allowed on cross-origin requests.",
},
},
Callbacks: map[logical.Operation]framework.OperationFunc{
@ -854,6 +858,7 @@ func (b *SystemBackend) handleCORSRead(req *logical.Request, d *framework.FieldD
if enabled {
corsConf.RLock()
resp.Data["allowed_origins"] = corsConf.AllowedOrigins
resp.Data["allowed_headers"] = corsConf.AllowedHeaders
corsConf.RUnlock()
}
@ -864,12 +869,13 @@ func (b *SystemBackend) handleCORSRead(req *logical.Request, d *framework.FieldD
// cross-origin requests and sets the CORS enabled flag to true
func (b *SystemBackend) handleCORSUpdate(req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
origins := d.Get("allowed_origins").([]string)
headers := d.Get("allowed_headers").([]string)
return nil, b.Core.corsConfig.Enable(origins)
return nil, b.Core.corsConfig.Enable(origins, headers)
}
// handleCORSDelete clears the allowed origins and sets the CORS enabled flag
// to false
// handleCORSDelete sets the CORS enabled flag to false and clears the list of
// allowed origins & headers.
func (b *SystemBackend) handleCORSDelete(req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
return nil, b.Core.corsConfig.Disable()
}

View File

@ -56,6 +56,7 @@ func TestSystemConfigCORS(t *testing.T) {
req := logical.TestRequest(t, logical.UpdateOperation, "config/cors")
req.Data["allowed_origins"] = "http://www.example.com"
req.Data["allowed_headers"] = "X-Custom-Header"
_, err := b.HandleRequest(req)
if err != nil {
t.Fatal(err)
@ -65,6 +66,7 @@ func TestSystemConfigCORS(t *testing.T) {
Data: map[string]interface{}{
"enabled": true,
"allowed_origins": []string{"http://www.example.com"},
"allowed_headers": append(stdAllowedHeaders, "X-Custom-Header"),
},
}

View File

@ -34,14 +34,23 @@ $ curl \
```json
{
"enabled": true,
"allowed_origins": "http://www.example.com"
"allowed_origins": ["http://www.example.com"],
"allowed_headers": [
"Content-Type",
"X-Requested-With",
"X-Vault-AWS-IAM-Server-ID",
"X-Vault-No-Request-Forwarding",
"X-Vault-Token",
"X-Vault-Wrap-Format",
"X-Vault-Wrap-TTL",
]
}
```
## Configure CORS Settings
This endpoint allows configuring the origins that are permitted to make
cross-origin requests.
cross-origin requests, as well as headers that are allowed on cross-origin requests.
| Method | Path | Produces |
| :------- | :--------------------------- | :--------------------- |
@ -49,13 +58,16 @@ cross-origin requests.
### Parameters
- `allowed_origins` `(string or string array: "" or [])`  A wildcard (`*`), comma-delimited string, or array of strings specifying the origins that are permitted to make cross-origin requests.
- `allowed_origins` `(string or string array: <required>)`  A wildcard (`*`), comma-delimited string, or array of strings specifying the origins that are permitted to make cross-origin requests.
- `allowed_headers` `(string or string array: "" or [])`  A comma-delimited string or array of strings specifying headers that are permitted to be on cross-origin requests. Headers set via this parameter will be appended to the list of headers that Vault allows by default.
### Sample Payload
```json
{
"allowed_origins": "*"
"allowed_origins": "*",
"allowed_headers": "X-Custom-Header"
}
```