diff --git a/http/logical.go b/http/logical.go index 41d051283..075c25ca4 100644 --- a/http/logical.go +++ b/http/logical.go @@ -14,72 +14,76 @@ import ( type PrepareRequestFunc func(req *logical.Request) error +func buildLogicalRequest(w http.ResponseWriter, r *http.Request) (*logical.Request, int, error) { + // Determine the path... + if !strings.HasPrefix(r.URL.Path, "/v1/") { + return nil, http.StatusNotFound, nil + } + path := r.URL.Path[len("/v1/"):] + if path == "" { + return nil, http.StatusNotFound, nil + } + + // Determine the operation + var op logical.Operation + switch r.Method { + case "DELETE": + op = logical.DeleteOperation + case "GET": + op = logical.ReadOperation + // Need to call ParseForm to get query params loaded + queryVals := r.URL.Query() + listStr := queryVals.Get("list") + if listStr != "" { + list, err := strconv.ParseBool(listStr) + if err != nil { + return nil, http.StatusBadRequest, nil + } + if list { + op = logical.ListOperation + } + } + case "POST", "PUT": + op = logical.UpdateOperation + case "LIST": + op = logical.ListOperation + default: + return nil, http.StatusMethodNotAllowed, nil + } + + // Parse the request if we can + var data map[string]interface{} + if op == logical.UpdateOperation { + err := parseRequest(r, &data) + if err == io.EOF { + data = nil + err = nil + } + if err != nil { + return nil, http.StatusBadRequest, err + } + } + + var err error + req := requestAuth(r, &logical.Request{ + Operation: op, + Path: path, + Data: data, + Connection: getConnection(r), + }) + req, err = requestWrapTTL(r, req) + if err != nil { + return nil, http.StatusBadRequest, errwrap.Wrapf("error parsing X-Vault-Wrap-TTL header: {{err}}", err) + } + + return req, 0, nil +} + func handleLogical(core *vault.Core, dataOnly bool, prepareRequestCallback PrepareRequestFunc) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Determine the path... - if !strings.HasPrefix(r.URL.Path, "/v1/") { - respondError(w, http.StatusNotFound, nil) - return - } - path := r.URL.Path[len("/v1/"):] - if path == "" { - respondError(w, http.StatusNotFound, nil) - return - } - - // Determine the operation - var op logical.Operation - switch r.Method { - case "DELETE": - op = logical.DeleteOperation - case "GET": - op = logical.ReadOperation - // Need to call ParseForm to get query params loaded - queryVals := r.URL.Query() - listStr := queryVals.Get("list") - if listStr != "" { - list, err := strconv.ParseBool(listStr) - if err != nil { - respondError(w, http.StatusBadRequest, nil) - return - } - if list { - op = logical.ListOperation - } - } - case "POST", "PUT": - op = logical.UpdateOperation - case "LIST": - op = logical.ListOperation - default: - respondError(w, http.StatusMethodNotAllowed, nil) - return - } - - // Parse the request if we can - var data map[string]interface{} - if op == logical.UpdateOperation { - err := parseRequest(r, &data) - if err == io.EOF { - data = nil - err = nil - } - if err != nil { - respondError(w, http.StatusBadRequest, err) - return - } - } - - var err error - req := requestAuth(r, &logical.Request{ - Operation: op, - Path: path, - Data: data, - Connection: getConnection(r), - }) - req, err = requestWrapTTL(r, req) - if err != nil { - respondError(w, http.StatusBadRequest, errwrap.Wrapf("error parsing X-Vault-Wrap-TTL header: {{err}}", err)) + req, statusCode, err := buildLogicalRequest(w, r) + if err != nil || statusCode != 0 { + respondError(w, statusCode, err) return } @@ -101,7 +105,7 @@ func handleLogical(core *vault.Core, dataOnly bool, prepareRequestCallback Prepa return } switch { - case op == logical.ReadOperation: + case req.Operation == logical.ReadOperation: if resp == nil { respondError(w, http.StatusNotFound, nil) return @@ -109,7 +113,7 @@ func handleLogical(core *vault.Core, dataOnly bool, prepareRequestCallback Prepa // Basically: if we have empty "keys" or no keys at all, 404. This // provides consistency with GET. - case op == logical.ListOperation: + case req.Operation == logical.ListOperation: if resp == nil || len(resp.Data) == 0 { respondError(w, http.StatusNotFound, nil) return @@ -131,7 +135,7 @@ func handleLogical(core *vault.Core, dataOnly bool, prepareRequestCallback Prepa } // Build the proper response - respondLogical(w, r, path, dataOnly, resp) + respondLogical(w, r, req.Path, dataOnly, resp) }) } diff --git a/http/sys_seal.go b/http/sys_seal.go index 60136a1bb..c5d52256a 100644 --- a/http/sys_seal.go +++ b/http/sys_seal.go @@ -13,19 +13,21 @@ import ( func handleSysSeal(core *vault.Core) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - switch r.Method { - case "PUT": - case "POST": + req, statusCode, err := buildLogicalRequest(w, r) + if err != nil || statusCode != 0 { + respondError(w, statusCode, err) + return + } + + switch req.Operation { + case logical.UpdateOperation: default: respondError(w, http.StatusMethodNotAllowed, nil) return } - // Get the auth for the request so we can access the token directly - req := requestAuth(r, &logical.Request{}) - // Seal with the token above - if err := core.Seal(req.ClientToken); err != nil { + if err := core.SealWithRequest(req); err != nil { respondError(w, http.StatusInternalServerError, err) return } @@ -36,19 +38,21 @@ func handleSysSeal(core *vault.Core) http.Handler { func handleSysStepDown(core *vault.Core) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - switch r.Method { - case "PUT": - case "POST": + req, statusCode, err := buildLogicalRequest(w, r) + if err != nil || statusCode != 0 { + respondError(w, statusCode, err) + return + } + + switch req.Operation { + case logical.UpdateOperation: default: respondError(w, http.StatusMethodNotAllowed, nil) return } - // Get the auth for the request so we can access the token directly - req := requestAuth(r, &logical.Request{}) - // Seal with the token above - if err := core.StepDown(req.ClientToken); err != nil { + if err := core.StepDown(req); err != nil { respondError(w, http.StatusInternalServerError, err) return } diff --git a/vault/core.go b/vault/core.go index 340e58d1d..9dcc2fcf3 100644 --- a/vault/core.go +++ b/vault/core.go @@ -658,25 +658,54 @@ func (c *Core) Unseal(key []byte) (bool, error) { return true, nil } -// Seal is used to re-seal the Vault. This requires the Vault to -// be unsealed again to perform any further operations. -func (c *Core) Seal(token string) (retErr error) { +// SealWithRequest takes in a logical.Request, acquires the lock, and passes +// through to sealInternal +func (c *Core) SealWithRequest(req *logical.Request) error { + defer metrics.MeasureSince([]string{"core", "seal-with-request"}, time.Now()) + + c.stateLock.Lock() + defer c.stateLock.Unlock() + + if c.sealed { + return nil + } + + return c.sealInitCommon(req) +} + +// Seal takes in a token and creates a logical.Request, acquires the lock, and +// passes through to sealInternal +func (c *Core) Seal(token string) error { defer metrics.MeasureSince([]string{"core", "seal"}, time.Now()) c.stateLock.Lock() defer c.stateLock.Unlock() if c.sealed { - return retErr + return nil } - // Validate the token is a root token req := &logical.Request{ Operation: logical.UpdateOperation, Path: "sys/seal", ClientToken: token, } + return c.sealInitCommon(req) +} + +// sealInitCommon is common logic for Seal and SealWithRequest and is used to +// re-seal the Vault. This requires the Vault to be unsealed again to perform +// any further operations. +func (c *Core) sealInitCommon(req *logical.Request) (retErr error) { + defer metrics.MeasureSince([]string{"core", "seal-internal"}, time.Now()) + + if req == nil { + retErr = multierror.Append(retErr, errors.New("nil request to seal")) + return retErr + } + + // Validate the token is a root token acl, te, err := c.fetchACLandTokenEntry(req) if err != nil { // Since there is no token store in standby nodes, sealing cannot @@ -692,6 +721,22 @@ func (c *Core) Seal(token string) (retErr error) { retErr = multierror.Append(retErr, err) return retErr } + + // Audit-log the request before going any further + auth := &logical.Auth{ + ClientToken: req.ClientToken, + Policies: te.Policies, + Metadata: te.Meta, + DisplayName: te.DisplayName, + } + + if err := c.auditBroker.LogRequest(auth, req, nil); err != nil { + c.logger.Printf("[ERR] core: failed to audit request with path %s: %v", + req.Path, err) + retErr = multierror.Append(retErr, errors.New("failed to audit request, cannot continue")) + return retErr + } + // Attempt to use the token (decrement num_uses) // On error bail out; if the token has been revoked, bail out too if te != nil { @@ -741,9 +786,14 @@ func (c *Core) Seal(token string) (retErr error) { } // StepDown is used to step down from leadership -func (c *Core) StepDown(token string) (retErr error) { +func (c *Core) StepDown(req *logical.Request) (retErr error) { defer metrics.MeasureSince([]string{"core", "step_down"}, time.Now()) + if req == nil { + retErr = multierror.Append(retErr, errors.New("nil request to step-down")) + return retErr + } + c.stateLock.Lock() defer c.stateLock.Unlock() if c.sealed { @@ -753,18 +803,27 @@ func (c *Core) StepDown(token string) (retErr error) { return nil } - // Validate the token is a root token - req := &logical.Request{ - Operation: logical.UpdateOperation, - Path: "sys/step-down", - ClientToken: token, - } - acl, te, err := c.fetchACLandTokenEntry(req) if err != nil { retErr = multierror.Append(retErr, err) return retErr } + + // Audit-log the request before going any further + auth := &logical.Auth{ + ClientToken: req.ClientToken, + Policies: te.Policies, + Metadata: te.Meta, + DisplayName: te.DisplayName, + } + + if err := c.auditBroker.LogRequest(auth, req, nil); err != nil { + c.logger.Printf("[ERR] core: failed to audit request with path %s: %v", + req.Path, err) + retErr = multierror.Append(retErr, errors.New("failed to audit request, cannot continue")) + return retErr + } + // Attempt to use the token (decrement num_uses) if te != nil { te, err = c.tokenStore.UseToken(te) diff --git a/vault/core_test.go b/vault/core_test.go index 0027df15e..4692f1c83 100644 --- a/vault/core_test.go +++ b/vault/core_test.go @@ -1148,8 +1148,13 @@ func TestCore_StepDown(t *testing.T) { t.Fatalf("Bad advertise: %v", advertise) } + req := &logical.Request{ + ClientToken: root, + Path: "sys/step-down", + } + // Step down core - err = core.StepDown(root) + err = core.StepDown(req) if err != nil { t.Fatal("error stepping down core 1") } @@ -1191,7 +1196,7 @@ func TestCore_StepDown(t *testing.T) { } // Step down core2 - err = core2.StepDown(root) + err = core2.StepDown(req) if err != nil { t.Fatal("error stepping down core 1") }