Add:
* Request/Response field extension * Parsing of header into request object * Handling of duration/mount point within router * Tests of router WrapDuration handling
This commit is contained in:
parent
0a2e78f8d8
commit
d81806b446
|
@ -6,15 +6,23 @@ import (
|
|||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/errwrap"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/vault"
|
||||
)
|
||||
|
||||
// AuthHeaderName is the name of the header containing the token.
|
||||
const AuthHeaderName = "X-Vault-Token"
|
||||
const (
|
||||
// AuthHeaderName is the name of the header containing the token.
|
||||
AuthHeaderName = "X-Vault-Token"
|
||||
|
||||
// WrapHeaderName is the name of the header containing a directive to wrap the
|
||||
// response.
|
||||
WrapDurationHeaderName = "X-Vault-Wrap-Duration"
|
||||
)
|
||||
|
||||
// Handler returns an http.Handler for the API. This can be used on
|
||||
// its own to mount the Vault API within another web server.
|
||||
|
@ -153,6 +161,34 @@ func requestAuth(r *http.Request, req *logical.Request) *logical.Request {
|
|||
return req
|
||||
}
|
||||
|
||||
// requestWrapDuration adds the WrapDuration value to the logical.Request if it
|
||||
// exists.
|
||||
func requestWrapDuration(r *http.Request, req *logical.Request) (*logical.Request, error) {
|
||||
// First try for the header value
|
||||
wrapDuration := r.Header.Get(WrapDurationHeaderName)
|
||||
if wrapDuration == "" {
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// If it has an allowed suffix parse as a duration string
|
||||
if strings.HasSuffix(wrapDuration, "s") || strings.HasSuffix(wrapDuration, "m") || strings.HasSuffix(wrapDuration, "h") {
|
||||
dur, err := time.ParseDuration(wrapDuration)
|
||||
if err != nil {
|
||||
return req, err
|
||||
}
|
||||
req.WrapDuration = dur
|
||||
} else {
|
||||
// Parse as a straight number of seconds
|
||||
seconds, err := strconv.ParseInt(wrapDuration, 10, 64)
|
||||
if err != nil {
|
||||
return req, err
|
||||
}
|
||||
req.WrapDuration = time.Duration(time.Duration(seconds) * time.Second)
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// Determines the type of the error being returned and sets the HTTP
|
||||
// status code appropriately
|
||||
func respondErrorStatus(w http.ResponseWriter, err error) {
|
||||
|
|
|
@ -7,6 +7,7 @@ import (
|
|||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/errwrap"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/vault"
|
||||
)
|
||||
|
@ -68,12 +69,18 @@ func handleLogical(core *vault.Core, dataOnly bool, prepareRequestCallback Prepa
|
|||
}
|
||||
}
|
||||
|
||||
var err error
|
||||
req := requestAuth(r, &logical.Request{
|
||||
Operation: op,
|
||||
Path: path,
|
||||
Data: data,
|
||||
Connection: getConnection(r),
|
||||
})
|
||||
req, err = requestWrapDuration(r, req)
|
||||
if err != nil {
|
||||
respondError(w, http.StatusBadRequest, errwrap.Wrapf("error parsing X-Vault-Wrap-Duration header: {{err}}", err))
|
||||
return
|
||||
}
|
||||
|
||||
// Certain endpoints may require changes to the request object.
|
||||
// They will have a callback registered to do the needful.
|
||||
|
@ -123,30 +130,7 @@ func respondLogical(w http.ResponseWriter, r *http.Request, path string, dataOnl
|
|||
return
|
||||
}
|
||||
|
||||
logicalResp := &LogicalResponse{
|
||||
Data: resp.Data,
|
||||
Warnings: resp.Warnings(),
|
||||
}
|
||||
if resp.Secret != nil {
|
||||
logicalResp.LeaseID = resp.Secret.LeaseID
|
||||
logicalResp.Renewable = resp.Secret.Renewable
|
||||
logicalResp.LeaseDuration = int(resp.Secret.TTL.Seconds())
|
||||
}
|
||||
|
||||
// If we have authentication information, then
|
||||
// set up the result structure.
|
||||
if resp.Auth != nil {
|
||||
logicalResp.Auth = &Auth{
|
||||
ClientToken: resp.Auth.ClientToken,
|
||||
Accessor: resp.Auth.Accessor,
|
||||
Policies: resp.Auth.Policies,
|
||||
Metadata: resp.Auth.Metadata,
|
||||
LeaseDuration: int(resp.Auth.TTL.Seconds()),
|
||||
Renewable: resp.Auth.Renewable,
|
||||
}
|
||||
}
|
||||
|
||||
httpResp = logicalResp
|
||||
httpResp = logical.SanitizeResponse(resp)
|
||||
}
|
||||
|
||||
// Respond
|
||||
|
@ -221,21 +205,3 @@ func getConnection(r *http.Request) (connection *logical.Connection) {
|
|||
}
|
||||
return
|
||||
}
|
||||
|
||||
type LogicalResponse struct {
|
||||
LeaseID string `json:"lease_id"`
|
||||
Renewable bool `json:"renewable"`
|
||||
LeaseDuration int `json:"lease_duration"`
|
||||
Data map[string]interface{} `json:"data"`
|
||||
Warnings []string `json:"warnings"`
|
||||
Auth *Auth `json:"auth"`
|
||||
}
|
||||
|
||||
type Auth struct {
|
||||
ClientToken string `json:"client_token"`
|
||||
Accessor string `json:"accessor"`
|
||||
Policies []string `json:"policies"`
|
||||
Metadata map[string]string `json:"metadata"`
|
||||
LeaseDuration int `json:"lease_duration"`
|
||||
Renewable bool `json:"renewable"`
|
||||
}
|
||||
|
|
|
@ -3,6 +3,7 @@ package logical
|
|||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Request is a struct that stores the parameters and context
|
||||
|
@ -52,6 +53,10 @@ type Request struct {
|
|||
// paths relative to itself. The `Path` is effectively the client
|
||||
// request path with the MountPoint trimmed off.
|
||||
MountPoint string
|
||||
|
||||
// WrapDuration contains the requested TTL of the token used to wrap the
|
||||
// response in a cubbyhole.
|
||||
WrapDuration time.Duration
|
||||
}
|
||||
|
||||
// Get returns a data field and guards for nil Data
|
||||
|
|
|
@ -3,6 +3,7 @@ package logical
|
|||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"github.com/mitchellh/copystructure"
|
||||
)
|
||||
|
@ -26,6 +27,19 @@ const (
|
|||
HTTPStatusCode = "http_status_code"
|
||||
)
|
||||
|
||||
type WrapInfo struct {
|
||||
// Setting to non-zero specifies that the response should be wrapped.
|
||||
// Specifies the desired TTL of the wrapping token.
|
||||
Duration time.Duration
|
||||
|
||||
// The token containing the wrapped response
|
||||
Token string
|
||||
|
||||
// The mount point of the backend, useful for further requests (such as
|
||||
// logging in with the given credentials)
|
||||
MountPoint string
|
||||
}
|
||||
|
||||
// Response is a struct that stores the response of a request.
|
||||
// It is used to abstract the details of the higher level request protocol.
|
||||
type Response struct {
|
||||
|
@ -54,6 +68,9 @@ type Response struct {
|
|||
// Vault (backend, core, etc.) to add warnings without accidentally
|
||||
// replacing what exists.
|
||||
warnings []string
|
||||
|
||||
// Information for wrapping the response in a cubbyhole
|
||||
WrapInfo WrapInfo
|
||||
}
|
||||
|
||||
func init() {
|
||||
|
|
|
@ -0,0 +1,46 @@
|
|||
package logical
|
||||
|
||||
func SanitizeResponse(input *Response) *HTTPResponse {
|
||||
logicalResp := &HTTPResponse{
|
||||
Data: input.Data,
|
||||
Warnings: input.Warnings(),
|
||||
}
|
||||
if input.Secret != nil {
|
||||
logicalResp.LeaseID = input.Secret.LeaseID
|
||||
logicalResp.Renewable = input.Secret.Renewable
|
||||
logicalResp.LeaseDuration = int(input.Secret.TTL.Seconds())
|
||||
}
|
||||
|
||||
// If we have authentication information, then
|
||||
// set up the result structure.
|
||||
if input.Auth != nil {
|
||||
logicalResp.Auth = &HTTPAuth{
|
||||
ClientToken: input.Auth.ClientToken,
|
||||
Accessor: input.Auth.Accessor,
|
||||
Policies: input.Auth.Policies,
|
||||
Metadata: input.Auth.Metadata,
|
||||
LeaseDuration: int(input.Auth.TTL.Seconds()),
|
||||
Renewable: input.Auth.Renewable,
|
||||
}
|
||||
}
|
||||
|
||||
return logicalResp
|
||||
}
|
||||
|
||||
type HTTPResponse struct {
|
||||
LeaseID string `json:"lease_id"`
|
||||
Renewable bool `json:"renewable"`
|
||||
LeaseDuration int `json:"lease_duration"`
|
||||
Data map[string]interface{} `json:"data"`
|
||||
Warnings []string `json:"warnings"`
|
||||
Auth *HTTPAuth `json:"auth"`
|
||||
}
|
||||
|
||||
type HTTPAuth struct {
|
||||
ClientToken string `json:"client_token"`
|
||||
Accessor string `json:"accessor"`
|
||||
Policies []string `json:"policies"`
|
||||
Metadata map[string]string `json:"metadata"`
|
||||
LeaseDuration int `json:"lease_duration"`
|
||||
Renewable bool `json:"renewable"`
|
||||
}
|
|
@ -194,7 +194,7 @@ func (r *Router) RouteExistenceCheck(req *logical.Request) (bool, bool, error) {
|
|||
return ok, exists, err
|
||||
}
|
||||
|
||||
func (r *Router) routeCommon(req *logical.Request, existenceCheck bool) (*logical.Response, bool, bool, error) {
|
||||
func (r *Router) routeCommon(req *logical.Request, existenceCheck bool) (resp *logical.Response, ok bool, exists bool, err error) {
|
||||
// Find the mount point
|
||||
r.l.RLock()
|
||||
mount, raw, ok := r.root.LongestPrefix(req.Path)
|
||||
|
@ -250,8 +250,36 @@ func (r *Router) routeCommon(req *logical.Request, existenceCheck bool) (*logica
|
|||
|
||||
// Reset the request before returning
|
||||
defer func() {
|
||||
// Keep the mount point populated in case it's needed for wrapping, but
|
||||
// ensure that it hasn't been modified by the backend by setting it to
|
||||
// the original value
|
||||
req.MountPoint = mount
|
||||
|
||||
// We only run this if resp is not nil, so for instance we don't during
|
||||
// an existence check
|
||||
if resp != nil {
|
||||
// If either of the request or response requested wrapping, ensure that
|
||||
// the lowest value is what ends up in the response.
|
||||
switch {
|
||||
case req.WrapDuration == 0 && resp.WrapInfo.Duration == 0:
|
||||
case req.WrapDuration != 0 && resp.WrapInfo.Duration != 0:
|
||||
if req.WrapDuration < resp.WrapInfo.Duration {
|
||||
resp.WrapInfo.Duration = req.WrapDuration
|
||||
}
|
||||
case req.WrapDuration != 0:
|
||||
resp.WrapInfo.Duration = req.WrapDuration
|
||||
// Only case left is that only resp defines it, which doesn't need to
|
||||
// be explicitly handled
|
||||
}
|
||||
|
||||
// Now set the mount point if we are wrapping
|
||||
if resp.WrapInfo.Duration != 0 {
|
||||
resp.WrapInfo.MountPoint = mount
|
||||
}
|
||||
}
|
||||
|
||||
// Reset other parameters
|
||||
req.Path = original
|
||||
req.MountPoint = ""
|
||||
req.Connection = originalConn
|
||||
req.Storage = nil
|
||||
req.ClientToken = clientToken
|
||||
|
@ -262,7 +290,7 @@ func (r *Router) routeCommon(req *logical.Request, existenceCheck bool) (*logica
|
|||
ok, exists, err := re.backend.HandleExistenceCheck(req)
|
||||
return nil, ok, exists, err
|
||||
} else {
|
||||
resp, err := re.backend.HandleRequest(req)
|
||||
resp, err = re.backend.HandleRequest(req)
|
||||
return resp, false, false, err
|
||||
}
|
||||
}
|
||||
|
|
|
@ -19,6 +19,8 @@ type NoopBackend struct {
|
|||
Paths []string
|
||||
Requests []*logical.Request
|
||||
Response *logical.Response
|
||||
|
||||
WrapDuration time.Duration
|
||||
}
|
||||
|
||||
func (n *NoopBackend) HandleRequest(req *logical.Request) (*logical.Response, error) {
|
||||
|
@ -32,6 +34,14 @@ func (n *NoopBackend) HandleRequest(req *logical.Request) (*logical.Response, er
|
|||
return nil, fmt.Errorf("missing view")
|
||||
}
|
||||
|
||||
if n.Response == nil && (req.WrapDuration != 0 || n.WrapDuration != 0) {
|
||||
n.Response = &logical.Response{}
|
||||
}
|
||||
|
||||
if n.WrapDuration != 0 {
|
||||
n.Response.WrapInfo.Duration = n.WrapDuration
|
||||
}
|
||||
|
||||
return n.Response, nil
|
||||
}
|
||||
|
||||
|
@ -374,3 +384,115 @@ func TestPathsToRadix(t *testing.T) {
|
|||
t.Fatalf("bad: %v (sub/bar)", raw)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_Wrapping(t *testing.T) {
|
||||
core, _, root := TestCoreUnsealed(t)
|
||||
|
||||
n := &NoopBackend{}
|
||||
|
||||
core.logicalBackends["noop"] = func(config *logical.BackendConfig) (logical.Backend, error) {
|
||||
return n, nil
|
||||
}
|
||||
|
||||
meUUID, _ := uuid.GenerateUUID()
|
||||
err := core.mount(&MountEntry{
|
||||
UUID: meUUID,
|
||||
Path: "wraptest",
|
||||
Type: "noop",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
// No duration specified
|
||||
req := &logical.Request{
|
||||
Path: "wraptest/foo",
|
||||
ClientToken: root,
|
||||
Operation: logical.UpdateOperation,
|
||||
}
|
||||
resp, err := core.HandleRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if resp != nil {
|
||||
t.Fatalf("bad: %#v", resp)
|
||||
}
|
||||
|
||||
// Just in the request
|
||||
req = &logical.Request{
|
||||
Path: "wraptest/foo",
|
||||
ClientToken: root,
|
||||
Operation: logical.UpdateOperation,
|
||||
WrapDuration: time.Duration(15 * time.Second),
|
||||
}
|
||||
resp, err = core.HandleRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if resp == nil {
|
||||
t.Fatalf("bad: %v", resp)
|
||||
}
|
||||
if resp.WrapInfo.Duration != time.Duration(15*time.Second) ||
|
||||
resp.WrapInfo.MountPoint != "wraptest/" {
|
||||
t.Fatalf("bad: %#v", resp)
|
||||
}
|
||||
|
||||
// Just in the response
|
||||
n.WrapDuration = time.Duration(15 * time.Second)
|
||||
req = &logical.Request{
|
||||
Path: "wraptest/foo",
|
||||
ClientToken: root,
|
||||
Operation: logical.UpdateOperation,
|
||||
}
|
||||
resp, err = core.HandleRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if resp == nil {
|
||||
t.Fatalf("bad: %v", resp)
|
||||
}
|
||||
if resp.WrapInfo.Duration != time.Duration(15*time.Second) ||
|
||||
resp.WrapInfo.MountPoint != "wraptest/" {
|
||||
t.Fatalf("bad: %#v", resp)
|
||||
}
|
||||
|
||||
// In both, with request less
|
||||
n.WrapDuration = time.Duration(15 * time.Second)
|
||||
req = &logical.Request{
|
||||
Path: "wraptest/foo",
|
||||
ClientToken: root,
|
||||
Operation: logical.UpdateOperation,
|
||||
WrapDuration: time.Duration(10 * time.Second),
|
||||
}
|
||||
resp, err = core.HandleRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if resp == nil {
|
||||
t.Fatalf("bad: %v", resp)
|
||||
}
|
||||
if resp.WrapInfo.Duration != time.Duration(10*time.Second) ||
|
||||
resp.WrapInfo.MountPoint != "wraptest/" {
|
||||
t.Fatalf("bad: %#v", resp)
|
||||
}
|
||||
|
||||
// In both, with response less
|
||||
n.WrapDuration = time.Duration(10 * time.Second)
|
||||
req = &logical.Request{
|
||||
Path: "wraptest/foo",
|
||||
ClientToken: root,
|
||||
Operation: logical.UpdateOperation,
|
||||
WrapDuration: time.Duration(15 * time.Second),
|
||||
}
|
||||
resp, err = core.HandleRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if resp == nil {
|
||||
t.Fatalf("bad: %v", resp)
|
||||
}
|
||||
if resp.WrapInfo.Duration != time.Duration(10*time.Second) ||
|
||||
resp.WrapInfo.MountPoint != "wraptest/" {
|
||||
t.Fatalf("bad: %#v", resp)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue