From 0303f51b680878b4afdb36f2af7e4376115bcefa Mon Sep 17 00:00:00 2001 From: Aaron Salvo Date: Sat, 17 Jun 2017 00:04:55 -0400 Subject: [PATCH] Cors headers (#2021) --- api/sys_config_cors.go | 56 +++++++++ cli/commands.go | 1 - http/cors.go | 68 +++++++++++ http/handler.go | 3 +- http/handler_test.go | 81 +++++++++++++ http/http_test.go | 6 + http/logical.go | 1 + vault/core.go | 15 +++ vault/cors.go | 108 +++++++++++++++++ vault/logical_system.go | 81 ++++++++++++- vault/logical_system_test.go | 53 +++++++++ website/source/api/index.html.md | 4 +- .../source/docs/http/sys-config-cors.html.md | 109 ++++++++++++++++++ 13 files changed, 580 insertions(+), 6 deletions(-) create mode 100644 api/sys_config_cors.go create mode 100644 http/cors.go create mode 100644 vault/cors.go create mode 100644 website/source/docs/http/sys-config-cors.html.md diff --git a/api/sys_config_cors.go b/api/sys_config_cors.go new file mode 100644 index 000000000..e7f2a5945 --- /dev/null +++ b/api/sys_config_cors.go @@ -0,0 +1,56 @@ +package api + +func (c *Sys) CORSStatus() (*CORSResponse, error) { + r := c.c.NewRequest("GET", "/v1/sys/config/cors") + resp, err := c.c.RawRequest(r) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + var result CORSResponse + err = resp.DecodeJSON(&result) + return &result, err +} + +func (c *Sys) ConfigureCORS(req *CORSRequest) (*CORSResponse, error) { + r := c.c.NewRequest("PUT", "/v1/sys/config/cors") + if err := r.SetJSONBody(req); err != nil { + return nil, err + } + + resp, err := c.c.RawRequest(r) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + var result CORSResponse + err = resp.DecodeJSON(&result) + return &result, err +} + +func (c *Sys) DisableCORS() (*CORSResponse, error) { + r := c.c.NewRequest("DELETE", "/v1/sys/config/cors") + + resp, err := c.c.RawRequest(r) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + var result CORSResponse + err = resp.DecodeJSON(&result) + return &result, err + +} + +type CORSRequest struct { + AllowedOrigins string `json:"allowed_origins"` + Enabled bool `json:"enabled"` +} + +type CORSResponse struct { + AllowedOrigins string `json:"allowed_origins"` + Enabled bool `json:"enabled"` +} diff --git a/cli/commands.go b/cli/commands.go index 0d1945f11..3fab33956 100644 --- a/cli/commands.go +++ b/cli/commands.go @@ -61,7 +61,6 @@ func Commands(metaPtr *meta.Meta) map[string]cli.CommandFactory { Meta: *metaPtr, }, nil }, - "server": func() (cli.Command, error) { return &command.ServerCommand{ Meta: *metaPtr, diff --git a/http/cors.go b/http/cors.go new file mode 100644 index 000000000..5bd0a1366 --- /dev/null +++ b/http/cors.go @@ -0,0 +1,68 @@ +package http + +import ( + "net/http" + "strings" + + "github.com/hashicorp/vault/helper/strutil" + "github.com/hashicorp/vault/vault" +) + +var preflightHeaders = map[string]string{ + "Access-Control-Allow-Headers": "*", + "Access-Control-Max-Age": "300", +} + +var allowedMethods = []string{ + http.MethodDelete, + http.MethodGet, + http.MethodOptions, + http.MethodPost, + http.MethodPut, + "LIST", // LIST is not an official HTTP method, but Vault supports it. +} + +func wrapCORSHandler(h http.Handler, core *vault.Core) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + corsConf := core.CORSConfig() + + origin := req.Header.Get("Origin") + requestMethod := req.Header.Get("Access-Control-Request-Method") + + // If CORS is not enabled or if no Origin header is present (i.e. the request + // is from the Vault CLI. A browser will always send an Origin header), then + // just return a 204. + if !corsConf.IsEnabled() || origin == "" { + h.ServeHTTP(w, req) + return + } + + // Return a 403 if the origin is not + // allowed to make cross-origin requests. + if !corsConf.IsValidOrigin(origin) { + w.WriteHeader(http.StatusForbidden) + return + } + + if req.Method == http.MethodOptions && !strutil.StrListContains(allowedMethods, requestMethod) { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + + w.Header().Set("Access-Control-Allow-Origin", origin) + w.Header().Set("Vary", "Origin") + + // apply headers for preflight requests + if req.Method == http.MethodOptions { + w.Header().Set("Access-Control-Allow-Methods", strings.Join(allowedMethods, ",")) + + for k, v := range preflightHeaders { + w.Header().Set(k, v) + } + return + } + + h.ServeHTTP(w, req) + return + }) +} diff --git a/http/handler.go b/http/handler.go index 845cabaa8..bc5914aca 100644 --- a/http/handler.go +++ b/http/handler.go @@ -67,10 +67,11 @@ func Handler(core *vault.Core) http.Handler { // Wrap the handler in another handler to trigger all help paths. helpWrappedHandler := wrapHelpHandler(mux, core) + corsWrappedHandler := wrapCORSHandler(helpWrappedHandler, core) // Wrap the help wrapped handler with another layer with a generic // handler - genericWrappedHandler := wrapGenericHandler(helpWrappedHandler) + genericWrappedHandler := wrapGenericHandler(corsWrappedHandler) return genericWrappedHandler } diff --git a/http/handler_test.go b/http/handler_test.go index 149e60373..8450a8b6b 100644 --- a/http/handler_test.go +++ b/http/handler_test.go @@ -14,6 +14,87 @@ import ( "github.com/hashicorp/vault/vault" ) +func TestHandler_cors(t *testing.T) { + core, _, _ := vault.TestCoreUnsealed(t) + ln, addr := TestServer(t, core) + defer ln.Close() + + // Enable CORS and allow from any origin for testing. + corsConfig := core.CORSConfig() + err := corsConfig.Enable([]string{addr}) + if err != nil { + t.Fatalf("Error enabling CORS: %s", err) + } + + req, err := http.NewRequest(http.MethodOptions, addr+"/v1/sys/seal-status", nil) + if err != nil { + t.Fatalf("err: %s", err) + } + req.Header.Set("Origin", "BAD ORIGIN") + + // Requests from unacceptable origins will be rejected with a 403. + client := cleanhttp.DefaultClient() + resp, err := client.Do(req) + if err != nil { + t.Fatalf("err: %s", err) + } + + if resp.StatusCode != http.StatusForbidden { + t.Fatalf("Bad status:\nexpected: 403 Forbidden\nactual: %s", resp.Status) + } + + // + // Test preflight requests + // + + // Set a valid origin + req.Header.Set("Origin", addr) + + // Server should NOT accept arbitrary methods. + req.Header.Set("Access-Control-Request-Method", "FOO") + + client = cleanhttp.DefaultClient() + resp, err = client.Do(req) + if err != nil { + t.Fatalf("err: %s", err) + } + + // Fail if an arbitrary method is accepted. + if resp.StatusCode != http.StatusMethodNotAllowed { + t.Fatalf("Bad status:\nexpected: 405 Method Not Allowed\nactual: %s", resp.Status) + } + + // Server SHOULD accept acceptable methods. + req.Header.Set("Access-Control-Request-Method", http.MethodPost) + + client = cleanhttp.DefaultClient() + resp, err = client.Do(req) + if err != nil { + t.Fatalf("err: %s", err) + } + + // + // Test that the CORS headers are applied correctly. + // + expHeaders := map[string]string{ + "Access-Control-Allow-Origin": addr, + "Access-Control-Allow-Headers": "*", + "Access-Control-Max-Age": "300", + "Vary": "Origin", + } + + for expHeader, expected := range expHeaders { + actual := resp.Header.Get(expHeader) + if actual == "" { + t.Fatalf("bad:\nHeader: %#v was not on response.", expHeader) + } + + if actual != expected { + t.Fatalf("bad:\nExpected: %#v\nActual: %#v\n", expected, actual) + } + } +} + func TestHandler_CacheControlNoStore(t *testing.T) { core, _, token := vault.TestCoreUnsealed(t) ln, addr := TestServer(t, core) diff --git a/http/http_test.go b/http/http_test.go index 16e052171..eb43817e3 100644 --- a/http/http_test.go +++ b/http/http_test.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "net/http" + "regexp" "strings" "testing" "time" @@ -55,6 +56,11 @@ func testHttpData(t *testing.T, method string, token string, addr string, body i t.Fatalf("err: %s", err) } + // Get the address of the local listener in order to attach it to an Origin header. + // This will allow for the testing of requests that require CORS, without using a browser. + hostURLRegexp, _ := regexp.Compile("http[s]?://.+:[0-9]+") + req.Header.Set("Origin", hostURLRegexp.FindString(addr)) + req.Header.Set("Content-Type", "application/json") if len(token) != 0 { diff --git a/http/logical.go b/http/logical.go index f73e532d6..bc6355ce2 100644 --- a/http/logical.go +++ b/http/logical.go @@ -49,6 +49,7 @@ func buildLogicalRequest(core *vault.Core, w http.ResponseWriter, r *http.Reques op = logical.UpdateOperation case "LIST": op = logical.ListOperation + case "OPTIONS": default: return nil, http.StatusMethodNotAllowed, nil } diff --git a/vault/core.go b/vault/core.go index 01ebdd0d4..b940dbc6c 100644 --- a/vault/core.go +++ b/vault/core.go @@ -331,6 +331,9 @@ type Core struct { // The grpc forwarding client rpcForwardingClient *forwardingClient + // CORS Information + corsConfig *CORSConfig + // replicationState keeps the current replication state cached for quick // lookup replicationState consts.ReplicationState @@ -447,6 +450,7 @@ func NewCore(conf *CoreConfig) (*Core, error) { clusterName: conf.ClusterName, clusterListenerShutdownCh: make(chan struct{}), clusterListenerShutdownSuccessCh: make(chan struct{}), + corsConfig: &CORSConfig{}, clusterPeerClusterAddrsCache: cache.New(3*heartbeatInterval, time.Second), enableMlock: !conf.DisableMlock, } @@ -555,6 +559,11 @@ func (c *Core) Shutdown() error { return c.sealInternal() } +// CORSConfig returns the current CORS configuration +func (c *Core) CORSConfig() *CORSConfig { + return c.corsConfig +} + // LookupToken returns the properties of the token from the token store. This // is particularly useful to fetch the accessor of the client token and get it // populated in the logical request along with the client token. The accessor @@ -1291,6 +1300,9 @@ func (c *Core) postUnseal() (retErr error) { if err := c.setupPolicyStore(); err != nil { return err } + if err := c.loadCORSConfig(); err != nil { + return err + } if err := c.loadCredentials(); err != nil { return err } @@ -1356,6 +1368,9 @@ func (c *Core) preSeal() error { if err := c.teardownPolicyStore(); err != nil { result = multierror.Append(result, errwrap.Wrapf("error tearing down policy store: {{err}}", err)) } + if err := c.saveCORSConfig(); err != nil { + result = multierror.Append(result, errwrap.Wrapf("error tearing down CORS config: {{err}}", err)) + } if err := c.stopRollback(); err != nil { result = multierror.Append(result, errwrap.Wrapf("error stopping rollback: {{err}}", err)) } diff --git a/vault/cors.go b/vault/cors.go new file mode 100644 index 000000000..288c57b49 --- /dev/null +++ b/vault/cors.go @@ -0,0 +1,108 @@ +package vault + +import ( + "errors" + "fmt" + "sync" + + "github.com/hashicorp/vault/helper/strutil" + "github.com/hashicorp/vault/logical" +) + +var errCORSNotConfigured = errors.New("CORS is not configured") + +// CORSConfig stores the state of the CORS configuration. +type CORSConfig struct { + sync.RWMutex + Enabled bool `json:"enabled"` + AllowedOrigins []string `json:"allowed_origins"` +} + +func (c *Core) saveCORSConfig() error { + view := c.systemBarrierView.SubView("config/") + + entry, err := logical.StorageEntryJSON("cors", c.corsConfig) + if err != nil { + return fmt.Errorf("failed to create CORS config entry: %v", err) + } + + if err := view.Put(entry); err != nil { + return fmt.Errorf("failed to save CORS config: %v", err) + } + + return nil +} + +func (c *Core) loadCORSConfig() error { + view := c.systemBarrierView.SubView("config/") + + // Load the config in + out, err := view.Get("cors") + if err != nil { + return fmt.Errorf("failed to read CORS config: %v", err) + } + if out == nil { + return nil + } + + err = out.DecodeJSON(c.corsConfig) + if err != nil { + return err + } + + return nil +} + +// Enable takes either a '*' or a comma-seprated list of URLs that can make +// cross-origin requests to Vault. +func (c *CORSConfig) Enable(urls []string) error { + if len(urls) == 0 { + return errors.New("the list of allowed origins cannot be empty") + } + + if strutil.StrListContains(urls, "*") && len(urls) > 1 { + return errors.New("to allow all origins the '*' must be the only value for allowed_origins") + } + + c.Lock() + defer c.Unlock() + + c.AllowedOrigins = urls + c.Enabled = true + + return nil +} + +// IsEnabled returns the value of CORSConfig.isEnabled +func (c *CORSConfig) IsEnabled() bool { + c.RLock() + defer c.RUnlock() + + return c.Enabled +} + +// Disable sets CORS to disabled and clears the allowed origins +func (c *CORSConfig) Disable() { + c.Lock() + defer c.Unlock() + + c.Enabled = false + c.AllowedOrigins = []string{} +} + +// IsValidOrigin determines if the origin of the request is allowed to make +// cross-origin requests based on the CORSConfig. +func (c *CORSConfig) IsValidOrigin(origin string) bool { + c.RLock() + defer c.RUnlock() + + if c.AllowedOrigins == nil { + return false + } + + if len(c.AllowedOrigins) == 1 && (c.AllowedOrigins)[0] == "*" { + return true + } + + return strutil.StrListContains(c.AllowedOrigins, origin) +} diff --git a/vault/logical_system.go b/vault/logical_system.go index 936dd3a28..20573f6a9 100644 --- a/vault/logical_system.go +++ b/vault/logical_system.go @@ -62,6 +62,7 @@ func NewSystemBackend(core *Core, config *logical.BackendConfig) (logical.Backen "replication/primary/secondary-token", "replication/reindex", "rotate", + "config/*", "config/auditing/*", "plugins/catalog/*", "revoke-prefix/*", @@ -99,6 +100,30 @@ func NewSystemBackend(core *Core, config *logical.BackendConfig) (logical.Backen HelpDescription: strings.TrimSpace(sysHelp["capabilities_accessor"][1]), }, + &framework.Path{ + Pattern: "config/cors$", + + Fields: map[string]*framework.FieldSchema{ + "enable": &framework.FieldSchema{ + Type: framework.TypeBool, + Description: "Enables or disables CORS headers on requests.", + }, + "allowed_origins": &framework.FieldSchema{ + Type: framework.TypeCommaStringSlice, + Description: "A comma-separated list of origins that may make cross-origin requests.", + }, + }, + + Callbacks: map[logical.Operation]framework.OperationFunc{ + logical.ReadOperation: b.handleCORSRead, + logical.UpdateOperation: b.handleCORSUpdate, + logical.DeleteOperation: b.handleCORSDelete, + }, + + HelpDescription: strings.TrimSpace(sysHelp["config/cors"][0]), + HelpSynopsis: strings.TrimSpace(sysHelp["config/cors"][1]), + }, + &framework.Path{ Pattern: "capabilities$", @@ -809,6 +834,41 @@ type SystemBackend struct { Backend *framework.Backend } +// handleCORSRead returns the current CORS configuration +func (b *SystemBackend) handleCORSRead(req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + corsConf := b.Core.corsConfig + if corsConf == nil { + return nil, errCORSNotConfigured + } + + return &logical.Response{ + Data: map[string]interface{}{ + "enabled": corsConf.Enabled, + "allowed_origins": strings.Join(corsConf.AllowedOrigins, ","), + }, + }, nil +} + +// handleCORSUpdate sets the list of origins that are allowed +// to make cross-origin requests and sets the CORS enabled flag to true +func (b *SystemBackend) handleCORSUpdate(req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + origins := d.Get("allowed_origins").([]string) + + err := b.Core.corsConfig.Enable(origins) + if err != nil { + return nil, err + } + + return nil, nil +} + +// handleCORSDelete clears the allowed origins and sets the CORS enabled flag to false +func (b *SystemBackend) handleCORSDelete(req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + b.Core.CORSConfig().Disable() + + return nil, nil +} + func (b *SystemBackend) handleTidyLeases(req *logical.Request, d *framework.FieldData) (*logical.Response, error) { err := b.Core.expiration.Tidy() if err != nil { @@ -967,7 +1027,7 @@ func (b *SystemBackend) handleAuditedHeadersRead(req *logical.Request, d *framew }, nil } -// handleCapabilitiesreturns the ACL capabilities of the token for a given path +// handleCapabilities returns the ACL capabilities of the token for a given path func (b *SystemBackend) handleCapabilities(req *logical.Request, d *framework.FieldData) (*logical.Response, error) { token := d.Get("token").(string) if token == "" { @@ -985,8 +1045,8 @@ func (b *SystemBackend) handleCapabilities(req *logical.Request, d *framework.Fi }, nil } -// handleCapabilitiesAccessor returns the ACL capabilities of the token associted -// with the given accessor for a given path. +// handleCapabilitiesAccessor returns the ACL capabilities of the +// token associted with the given accessor for a given path. func (b *SystemBackend) handleCapabilitiesAccessor(req *logical.Request, d *framework.FieldData) (*logical.Response, error) { accessor := d.Get("accessor").(string) if accessor == "" { @@ -2244,6 +2304,21 @@ as well as perform core operations. // sysHelp is all the help text for the sys backend. var sysHelp = map[string][2]string{ + "config/cors": { + "Configures or returns the current configuration of CORS settings.", + ` +This path responds to the following HTTP methods. + + GET / + Returns the configuration of the CORS setting. + + POST / + Sets the comma-separated list of origins that can make cross-origin requests. + + DELETE / + Clears the CORS configuration and disables acceptance of CORS requests. + `, + }, "init": { "Initializes or returns the initialization status of the Vault.", ` diff --git a/vault/logical_system_test.go b/vault/logical_system_test.go index 536b4fa39..87df8e038 100644 --- a/vault/logical_system_test.go +++ b/vault/logical_system_test.go @@ -31,6 +31,7 @@ func TestSystemBackend_RootPaths(t *testing.T) { "replication/primary/secondary-token", "replication/reindex", "rotate", + "config/*", "config/auditing/*", "plugins/catalog/*", "revoke-prefix/*", @@ -46,6 +47,58 @@ func TestSystemBackend_RootPaths(t *testing.T) { } } +func TestSystemConfigCORS(t *testing.T) { + b := testSystemBackend(t) + + req := logical.TestRequest(t, logical.UpdateOperation, "config/cors") + req.Data["allowed_origins"] = "http://www.example.com" + _, err := b.HandleRequest(req) + if err != nil { + t.Fatal(err) + } + + expected := &logical.Response{ + Data: map[string]interface{}{ + "enabled": true, + "allowed_origins": "http://www.example.com", + }, + } + + req = logical.TestRequest(t, logical.ReadOperation, "config/cors") + actual, err := b.HandleRequest(req) + if err != nil { + t.Fatalf("err: %v", err) + } + + if !reflect.DeepEqual(actual, expected) { + t.Fatalf("UPDATE FAILED -- bad: %#v", actual) + } + + req = logical.TestRequest(t, logical.DeleteOperation, "config/cors") + _, err = b.HandleRequest(req) + if err != nil { + t.Fatalf("err: %v", err) + } + + req = logical.TestRequest(t, logical.ReadOperation, "config/cors") + actual, err = b.HandleRequest(req) + if err != nil { + t.Fatalf("err: %v", err) + } + + expected = &logical.Response{ + Data: map[string]interface{}{ + "enabled": false, + "allowed_origins": "", + }, + } + + if !reflect.DeepEqual(actual, expected) { + t.Fatalf("DELETE FAILED -- bad: %#v", actual) + } + +} + func TestSystemBackend_mounts(t *testing.T) { b := testSystemBackend(t) req := logical.TestRequest(t, logical.ReadOperation, "mounts") diff --git a/website/source/api/index.html.md b/website/source/api/index.html.md index e0b46ff06..4f84f82ec 100644 --- a/website/source/api/index.html.md +++ b/website/source/api/index.html.md @@ -156,7 +156,9 @@ The following HTTP status codes are used throughout the API. - `204` - Success, no data returned. - `400` - Invalid request, missing or invalid data. - `403` - Forbidden, your authentication details are either - incorrect or you don't have access to this feature. + incorrect, you don't have access to this feature, or - if CORS is + enabled - you made a cross-origin request from an origin that is + not allowed to make such requests. - `404` - Invalid path. This can both mean that the path truly doesn't exist or that you don't have permission to view a specific path. We use 404 in some cases to avoid state leakage. diff --git a/website/source/docs/http/sys-config-cors.html.md b/website/source/docs/http/sys-config-cors.html.md new file mode 100644 index 000000000..05755f996 --- /dev/null +++ b/website/source/docs/http/sys-config-cors.html.md @@ -0,0 +1,109 @@ +--- +layout: "http" +page_title: "HTTP API: /sys/config/cors" +sidebar_current: "docs-http-config-cors" +description: |- + The '/sys/config/cors' endpoint configures how the Vault server responds to cross-origin requests. +--- + +# /sys/config/cors + +This is a protected path, therefore all requests require a token with `root` +policy or `sudo` capability on the path. + +## GET + +
+
Description
+
+ Returns the current CORS configuration. +
+ +
Method
+
GET
+ +
URL
+
`/sys/config/cors`
+ +
Parameters
+
+ None +
+ +
Returns
+
+ + ```javascript + { + "enabled": true, + "allowed_origins": "http://www.example.com" + } + ``` + + Sample response when CORS is disabled. + + ```javascript + { + "enabled": false, + "allowed_origins": "" + } + ``` +
+
+ +## PUT + +
+
Description
+
+ Configures the Vault server to return CORS headers for origins that are + permitted to make cross-origin requests based on the `allowed_origins` + parameter. +
+ +
Method
+
PUT
+ +
URL
+
`/sys/config/cors`
+ +
Parameters
+
+
    +
  • + allowed_origins + required + Valid values are either a wildcard (*) or a comma-separated list of + exact origins that are permitted to make cross-origin requests. +
  • +
+
+ +
Returns
+
`204` response code. +
+
+ +## DELETE + +
+
Description
+
+ Disables the CORS functionality of the Vault server. +
+ +
Method
+
DELETE
+ +
URL
+
`/sys/config/cors`
+ +
Parameters
+
+ None +
+ +
Returns
+
`204` response code. +
+