* Request/Response field extension
* Parsing of header into request object
* Handling of duration/mount point within router
* Tests of router WrapDuration handling
This commit is contained in:
Jeff Mitchell 2016-05-01 22:39:45 -04:00
parent 0a2e78f8d8
commit d81806b446
7 changed files with 267 additions and 47 deletions

View File

@ -6,15 +6,23 @@ import (
"io"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"github.com/hashicorp/errwrap"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/vault"
)
// AuthHeaderName is the name of the header containing the token.
const AuthHeaderName = "X-Vault-Token"
const (
// AuthHeaderName is the name of the header containing the token.
AuthHeaderName = "X-Vault-Token"
// WrapHeaderName is the name of the header containing a directive to wrap the
// response.
WrapDurationHeaderName = "X-Vault-Wrap-Duration"
)
// Handler returns an http.Handler for the API. This can be used on
// its own to mount the Vault API within another web server.
@ -153,6 +161,34 @@ func requestAuth(r *http.Request, req *logical.Request) *logical.Request {
return req
}
// requestWrapDuration adds the WrapDuration value to the logical.Request if it
// exists.
func requestWrapDuration(r *http.Request, req *logical.Request) (*logical.Request, error) {
// First try for the header value
wrapDuration := r.Header.Get(WrapDurationHeaderName)
if wrapDuration == "" {
return req, nil
}
// If it has an allowed suffix parse as a duration string
if strings.HasSuffix(wrapDuration, "s") || strings.HasSuffix(wrapDuration, "m") || strings.HasSuffix(wrapDuration, "h") {
dur, err := time.ParseDuration(wrapDuration)
if err != nil {
return req, err
}
req.WrapDuration = dur
} else {
// Parse as a straight number of seconds
seconds, err := strconv.ParseInt(wrapDuration, 10, 64)
if err != nil {
return req, err
}
req.WrapDuration = time.Duration(time.Duration(seconds) * time.Second)
}
return req, nil
}
// Determines the type of the error being returned and sets the HTTP
// status code appropriately
func respondErrorStatus(w http.ResponseWriter, err error) {

View File

@ -7,6 +7,7 @@ import (
"strconv"
"strings"
"github.com/hashicorp/errwrap"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/vault"
)
@ -68,12 +69,18 @@ func handleLogical(core *vault.Core, dataOnly bool, prepareRequestCallback Prepa
}
}
var err error
req := requestAuth(r, &logical.Request{
Operation: op,
Path: path,
Data: data,
Connection: getConnection(r),
})
req, err = requestWrapDuration(r, req)
if err != nil {
respondError(w, http.StatusBadRequest, errwrap.Wrapf("error parsing X-Vault-Wrap-Duration header: {{err}}", err))
return
}
// Certain endpoints may require changes to the request object.
// They will have a callback registered to do the needful.
@ -123,30 +130,7 @@ func respondLogical(w http.ResponseWriter, r *http.Request, path string, dataOnl
return
}
logicalResp := &LogicalResponse{
Data: resp.Data,
Warnings: resp.Warnings(),
}
if resp.Secret != nil {
logicalResp.LeaseID = resp.Secret.LeaseID
logicalResp.Renewable = resp.Secret.Renewable
logicalResp.LeaseDuration = int(resp.Secret.TTL.Seconds())
}
// If we have authentication information, then
// set up the result structure.
if resp.Auth != nil {
logicalResp.Auth = &Auth{
ClientToken: resp.Auth.ClientToken,
Accessor: resp.Auth.Accessor,
Policies: resp.Auth.Policies,
Metadata: resp.Auth.Metadata,
LeaseDuration: int(resp.Auth.TTL.Seconds()),
Renewable: resp.Auth.Renewable,
}
}
httpResp = logicalResp
httpResp = logical.SanitizeResponse(resp)
}
// Respond
@ -221,21 +205,3 @@ func getConnection(r *http.Request) (connection *logical.Connection) {
}
return
}
type LogicalResponse struct {
LeaseID string `json:"lease_id"`
Renewable bool `json:"renewable"`
LeaseDuration int `json:"lease_duration"`
Data map[string]interface{} `json:"data"`
Warnings []string `json:"warnings"`
Auth *Auth `json:"auth"`
}
type Auth struct {
ClientToken string `json:"client_token"`
Accessor string `json:"accessor"`
Policies []string `json:"policies"`
Metadata map[string]string `json:"metadata"`
LeaseDuration int `json:"lease_duration"`
Renewable bool `json:"renewable"`
}

View File

@ -3,6 +3,7 @@ package logical
import (
"errors"
"fmt"
"time"
)
// Request is a struct that stores the parameters and context
@ -52,6 +53,10 @@ type Request struct {
// paths relative to itself. The `Path` is effectively the client
// request path with the MountPoint trimmed off.
MountPoint string
// WrapDuration contains the requested TTL of the token used to wrap the
// response in a cubbyhole.
WrapDuration time.Duration
}
// Get returns a data field and guards for nil Data

View File

@ -3,6 +3,7 @@ package logical
import (
"fmt"
"reflect"
"time"
"github.com/mitchellh/copystructure"
)
@ -26,6 +27,19 @@ const (
HTTPStatusCode = "http_status_code"
)
type WrapInfo struct {
// Setting to non-zero specifies that the response should be wrapped.
// Specifies the desired TTL of the wrapping token.
Duration time.Duration
// The token containing the wrapped response
Token string
// The mount point of the backend, useful for further requests (such as
// logging in with the given credentials)
MountPoint string
}
// Response is a struct that stores the response of a request.
// It is used to abstract the details of the higher level request protocol.
type Response struct {
@ -54,6 +68,9 @@ type Response struct {
// Vault (backend, core, etc.) to add warnings without accidentally
// replacing what exists.
warnings []string
// Information for wrapping the response in a cubbyhole
WrapInfo WrapInfo
}
func init() {

46
logical/sanitize.go Normal file
View File

@ -0,0 +1,46 @@
package logical
func SanitizeResponse(input *Response) *HTTPResponse {
logicalResp := &HTTPResponse{
Data: input.Data,
Warnings: input.Warnings(),
}
if input.Secret != nil {
logicalResp.LeaseID = input.Secret.LeaseID
logicalResp.Renewable = input.Secret.Renewable
logicalResp.LeaseDuration = int(input.Secret.TTL.Seconds())
}
// If we have authentication information, then
// set up the result structure.
if input.Auth != nil {
logicalResp.Auth = &HTTPAuth{
ClientToken: input.Auth.ClientToken,
Accessor: input.Auth.Accessor,
Policies: input.Auth.Policies,
Metadata: input.Auth.Metadata,
LeaseDuration: int(input.Auth.TTL.Seconds()),
Renewable: input.Auth.Renewable,
}
}
return logicalResp
}
type HTTPResponse struct {
LeaseID string `json:"lease_id"`
Renewable bool `json:"renewable"`
LeaseDuration int `json:"lease_duration"`
Data map[string]interface{} `json:"data"`
Warnings []string `json:"warnings"`
Auth *HTTPAuth `json:"auth"`
}
type HTTPAuth struct {
ClientToken string `json:"client_token"`
Accessor string `json:"accessor"`
Policies []string `json:"policies"`
Metadata map[string]string `json:"metadata"`
LeaseDuration int `json:"lease_duration"`
Renewable bool `json:"renewable"`
}

View File

@ -194,7 +194,7 @@ func (r *Router) RouteExistenceCheck(req *logical.Request) (bool, bool, error) {
return ok, exists, err
}
func (r *Router) routeCommon(req *logical.Request, existenceCheck bool) (*logical.Response, bool, bool, error) {
func (r *Router) routeCommon(req *logical.Request, existenceCheck bool) (resp *logical.Response, ok bool, exists bool, err error) {
// Find the mount point
r.l.RLock()
mount, raw, ok := r.root.LongestPrefix(req.Path)
@ -250,8 +250,36 @@ func (r *Router) routeCommon(req *logical.Request, existenceCheck bool) (*logica
// Reset the request before returning
defer func() {
// Keep the mount point populated in case it's needed for wrapping, but
// ensure that it hasn't been modified by the backend by setting it to
// the original value
req.MountPoint = mount
// We only run this if resp is not nil, so for instance we don't during
// an existence check
if resp != nil {
// If either of the request or response requested wrapping, ensure that
// the lowest value is what ends up in the response.
switch {
case req.WrapDuration == 0 && resp.WrapInfo.Duration == 0:
case req.WrapDuration != 0 && resp.WrapInfo.Duration != 0:
if req.WrapDuration < resp.WrapInfo.Duration {
resp.WrapInfo.Duration = req.WrapDuration
}
case req.WrapDuration != 0:
resp.WrapInfo.Duration = req.WrapDuration
// Only case left is that only resp defines it, which doesn't need to
// be explicitly handled
}
// Now set the mount point if we are wrapping
if resp.WrapInfo.Duration != 0 {
resp.WrapInfo.MountPoint = mount
}
}
// Reset other parameters
req.Path = original
req.MountPoint = ""
req.Connection = originalConn
req.Storage = nil
req.ClientToken = clientToken
@ -262,7 +290,7 @@ func (r *Router) routeCommon(req *logical.Request, existenceCheck bool) (*logica
ok, exists, err := re.backend.HandleExistenceCheck(req)
return nil, ok, exists, err
} else {
resp, err := re.backend.HandleRequest(req)
resp, err = re.backend.HandleRequest(req)
return resp, false, false, err
}
}

View File

@ -19,6 +19,8 @@ type NoopBackend struct {
Paths []string
Requests []*logical.Request
Response *logical.Response
WrapDuration time.Duration
}
func (n *NoopBackend) HandleRequest(req *logical.Request) (*logical.Response, error) {
@ -32,6 +34,14 @@ func (n *NoopBackend) HandleRequest(req *logical.Request) (*logical.Response, er
return nil, fmt.Errorf("missing view")
}
if n.Response == nil && (req.WrapDuration != 0 || n.WrapDuration != 0) {
n.Response = &logical.Response{}
}
if n.WrapDuration != 0 {
n.Response.WrapInfo.Duration = n.WrapDuration
}
return n.Response, nil
}
@ -374,3 +384,115 @@ func TestPathsToRadix(t *testing.T) {
t.Fatalf("bad: %v (sub/bar)", raw)
}
}
func TestRouter_Wrapping(t *testing.T) {
core, _, root := TestCoreUnsealed(t)
n := &NoopBackend{}
core.logicalBackends["noop"] = func(config *logical.BackendConfig) (logical.Backend, error) {
return n, nil
}
meUUID, _ := uuid.GenerateUUID()
err := core.mount(&MountEntry{
UUID: meUUID,
Path: "wraptest",
Type: "noop",
})
if err != nil {
t.Fatalf("err: %v", err)
}
// No duration specified
req := &logical.Request{
Path: "wraptest/foo",
ClientToken: root,
Operation: logical.UpdateOperation,
}
resp, err := core.HandleRequest(req)
if err != nil {
t.Fatalf("err: %v", err)
}
if resp != nil {
t.Fatalf("bad: %#v", resp)
}
// Just in the request
req = &logical.Request{
Path: "wraptest/foo",
ClientToken: root,
Operation: logical.UpdateOperation,
WrapDuration: time.Duration(15 * time.Second),
}
resp, err = core.HandleRequest(req)
if err != nil {
t.Fatalf("err: %v", err)
}
if resp == nil {
t.Fatalf("bad: %v", resp)
}
if resp.WrapInfo.Duration != time.Duration(15*time.Second) ||
resp.WrapInfo.MountPoint != "wraptest/" {
t.Fatalf("bad: %#v", resp)
}
// Just in the response
n.WrapDuration = time.Duration(15 * time.Second)
req = &logical.Request{
Path: "wraptest/foo",
ClientToken: root,
Operation: logical.UpdateOperation,
}
resp, err = core.HandleRequest(req)
if err != nil {
t.Fatalf("err: %v", err)
}
if resp == nil {
t.Fatalf("bad: %v", resp)
}
if resp.WrapInfo.Duration != time.Duration(15*time.Second) ||
resp.WrapInfo.MountPoint != "wraptest/" {
t.Fatalf("bad: %#v", resp)
}
// In both, with request less
n.WrapDuration = time.Duration(15 * time.Second)
req = &logical.Request{
Path: "wraptest/foo",
ClientToken: root,
Operation: logical.UpdateOperation,
WrapDuration: time.Duration(10 * time.Second),
}
resp, err = core.HandleRequest(req)
if err != nil {
t.Fatalf("err: %v", err)
}
if resp == nil {
t.Fatalf("bad: %v", resp)
}
if resp.WrapInfo.Duration != time.Duration(10*time.Second) ||
resp.WrapInfo.MountPoint != "wraptest/" {
t.Fatalf("bad: %#v", resp)
}
// In both, with response less
n.WrapDuration = time.Duration(10 * time.Second)
req = &logical.Request{
Path: "wraptest/foo",
ClientToken: root,
Operation: logical.UpdateOperation,
WrapDuration: time.Duration(15 * time.Second),
}
resp, err = core.HandleRequest(req)
if err != nil {
t.Fatalf("err: %v", err)
}
if resp == nil {
t.Fatalf("bad: %v", resp)
}
if resp.WrapInfo.Duration != time.Duration(10*time.Second) ||
resp.WrapInfo.MountPoint != "wraptest/" {
t.Fatalf("bad: %#v", resp)
}
}