Address most review feedback. Change responses to multierror to better return more useful values when there are multiple errors

This commit is contained in:
Jeff Mitchell 2016-05-16 16:11:33 -04:00
parent 4c67a739b9
commit c4431a7e30
18 changed files with 531 additions and 360 deletions

View File

@ -29,6 +29,8 @@ var (
errRedirect = errors.New("redirect")
)
type WrappingLookupFunc func(operation, path string) string
// Config is used to configure the creation of the client.
type Config struct {
// Address is the address of the Vault server. This should be a complete
@ -37,14 +39,6 @@ type Config struct {
// HttpClient.
Address string
// WrapTTL, if specified, asks the Vault server to return the normal
// response wrapped in the cubbyhole of a token, with the TTL of the token
// being set to the lesser of this value or a value requested by the
// backend originating the response. Specified either as a number of
// seconds, or a string duration with a "s", "m", or "h" suffix for
// "seconds", "minutes", or "hours" respectively.
WrapTTL string
// HttpClient is the HTTP client to use, which will currently always have the
// same values as http.DefaultClient. This is used to control redirect behavior.
HttpClient *http.Client
@ -86,7 +80,6 @@ func (c *Config) ReadEnvironment() error {
var envCAPath string
var envClientCert string
var envClientKey string
var envWrapTTL string
var envInsecure bool
var foundInsecure bool
var envTLSServerName string
@ -110,9 +103,6 @@ func (c *Config) ReadEnvironment() error {
if v := os.Getenv(EnvVaultClientKey); v != "" {
envClientKey = v
}
if v := os.Getenv(EnvVaultWrapTTL); v != "" {
envWrapTTL = v
}
if v := os.Getenv(EnvVaultInsecure); v != "" {
var err error
envInsecure, err = strconv.ParseBool(v)
@ -141,10 +131,6 @@ func (c *Config) ReadEnvironment() error {
c.Address = envAddress
}
if envWrapTTL != "" {
c.WrapTTL = envWrapTTL
}
clientTLSConfig := c.HttpClient.Transport.(*http.Transport).TLSClientConfig
if foundInsecure {
clientTLSConfig.InsecureSkipVerify = envInsecure
@ -172,9 +158,10 @@ func (c *Config) ReadEnvironment() error {
// Client is the client to the Vault API. Create a client with
// NewClient.
type Client struct {
addr *url.URL
config *Config
token string
addr *url.URL
config *Config
token string
wrappingLookupFunc WrappingLookupFunc
}
// NewClient returns a new client for the given configuration.
@ -216,6 +203,12 @@ func NewClient(c *Config) (*Client, error) {
return client, nil
}
// SetWrappingLookupFunc sets a lookup function that returns desired wrap TTLs
// for a given operation and path
func (c *Client) SetWrappingLookupFunc(lookupFunc WrappingLookupFunc) {
c.wrappingLookupFunc = lookupFunc
}
// Token returns the access token being used by this client. It will
// return the empty string if there is no token set.
func (c *Client) Token() string {
@ -245,10 +238,13 @@ func (c *Client) NewRequest(method, path string) *Request {
Path: path,
},
ClientToken: c.token,
WrapTTL: c.config.WrapTTL,
Params: make(map[string][]string),
}
if c.wrappingLookupFunc != nil {
req.WrapTTL = c.wrappingLookupFunc(method, path)
}
return req
}

View File

@ -107,19 +107,16 @@ func TestClientEnvSettings(t *testing.T) {
oldClientCert := os.Getenv(EnvVaultClientCert)
oldClientKey := os.Getenv(EnvVaultClientKey)
oldSkipVerify := os.Getenv(EnvVaultInsecure)
oldWrapTTL := os.Getenv(EnvVaultWrapTTL)
os.Setenv("VAULT_CACERT", cwd+"/test-fixtures/keys/cert.pem")
os.Setenv("VAULT_CAPATH", cwd+"/test-fixtures/keys")
os.Setenv("VAULT_CLIENT_CERT", cwd+"/test-fixtures/keys/cert.pem")
os.Setenv("VAULT_CLIENT_KEY", cwd+"/test-fixtures/keys/key.pem")
os.Setenv("VAULT_SKIP_VERIFY", "true")
os.Setenv("VAULT_WRAP_TTL", "60")
defer os.Setenv("VAULT_CACERT", oldCACert)
defer os.Setenv("VAULT_CAPATH", oldCAPath)
defer os.Setenv("VAULT_CLIENT_CERT", oldClientCert)
defer os.Setenv("VAULT_CLIENT_KEY", oldClientKey)
defer os.Setenv("VAULT_SKIP_VERIFY", oldSkipVerify)
defer os.Setenv("VAULT_WRAP_TTL", oldWrapTTL)
config := DefaultConfig()
if err := config.ReadEnvironment(); err != nil {
@ -136,8 +133,4 @@ func TestClientEnvSettings(t *testing.T) {
if tlsConfig.InsecureSkipVerify != true {
t.Fatalf("bad: %v", tlsConfig.InsecureSkipVerify)
}
if config.WrapTTL != "60" {
t.Fatalf("bad: %v", config.WrapTTL)
}
}

109
command/wrapping_test.go Normal file
View File

@ -0,0 +1,109 @@
package command
import (
"os"
"testing"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/http"
"github.com/hashicorp/vault/meta"
"github.com/hashicorp/vault/vault"
"github.com/mitchellh/cli"
)
func TestWrapping_Env(t *testing.T) {
core, _, token := vault.TestCoreUnsealed(t)
ln, addr := http.TestServer(t, core)
defer ln.Close()
ui := new(cli.MockUi)
c := &TokenLookupCommand{
Meta: meta.Meta{
ClientToken: token,
Ui: ui,
},
}
args := []string{
"-address", addr,
}
// Run it once for client
c.Run(args)
// Create a new token for us to use
client, err := c.Client()
if err != nil {
t.Fatalf("err: %s", err)
}
resp, err := client.Auth().Token().Create(&api.TokenCreateRequest{
Lease: "1h",
})
if err != nil {
t.Fatalf("err: %s", err)
}
prevWrapTTLEnv := os.Getenv(api.EnvVaultWrapTTL)
os.Setenv(api.EnvVaultWrapTTL, "5s")
defer func() {
os.Setenv(api.EnvVaultWrapTTL, prevWrapTTLEnv)
}()
// Now when we do a lookup-self the response should be wrapped
args = append(args, resp.Auth.ClientToken)
resp, err = client.Auth().Token().LookupSelf()
if err != nil {
t.Fatalf("err: %s", err)
}
if resp == nil {
t.Fatal("nil response")
}
if resp.WrapInfo == nil {
t.Fatal("nil wrap info")
}
if resp.WrapInfo.Token == "" || resp.WrapInfo.TTL != 5 {
t.Fatal("did not get token or ttl wrong")
}
}
func TestWrapping_Flag(t *testing.T) {
core, _, token := vault.TestCoreUnsealed(t)
ln, addr := http.TestServer(t, core)
defer ln.Close()
ui := new(cli.MockUi)
c := &TokenLookupCommand{
Meta: meta.Meta{
ClientToken: token,
Ui: ui,
},
}
args := []string{
"-address", addr,
"-wrap-ttl", "5s",
}
// Run it once for client
c.Run(args)
// Create a new token for us to use
client, err := c.Client()
if err != nil {
t.Fatalf("err: %s", err)
}
resp, err := client.Auth().Token().Create(&api.TokenCreateRequest{
Lease: "1h",
})
if err != nil {
t.Fatalf("err: %s", err)
}
if resp == nil {
t.Fatal("nil response")
}
if resp.WrapInfo == nil {
t.Fatal("nil wrap info")
}
if resp.WrapInfo.Token == "" || resp.WrapInfo.TTL != 5 {
t.Fatal("did not get token or ttl wrong")
}
}

View File

@ -93,7 +93,7 @@ func parseRequest(r *http.Request, out interface{}) error {
// case of an error.
func request(core *vault.Core, w http.ResponseWriter, rawReq *http.Request, r *logical.Request) (*logical.Response, bool) {
resp, err := core.HandleRequest(r)
if err == vault.ErrStandby {
if errwrap.Contains(err, vault.ErrStandby.Error()) {
respondStandby(core, w, rawReq.URL)
return resp, false
}
@ -195,7 +195,7 @@ func respondErrorStatus(w http.ResponseWriter, err error) {
status := http.StatusInternalServerError
switch {
// Keep adding more error types here to appropriate the status codes
case errwrap.ContainsType(err, new(vault.StatusBadRequest)):
case err != nil && errwrap.ContainsType(err, new(vault.StatusBadRequest)):
status = http.StatusBadRequest
}
respondError(w, status, err)
@ -203,7 +203,7 @@ func respondErrorStatus(w http.ResponseWriter, err error) {
func respondError(w http.ResponseWriter, status int, err error) {
// Adjust status code when sealed
if err == vault.ErrSealed {
if errwrap.Contains(err, vault.ErrSealed.Error()) {
status = http.StatusServiceUnavailable
}
@ -230,19 +230,19 @@ func respondCommon(w http.ResponseWriter, resp *logical.Response, err error) boo
}
if resp.IsError() {
var statusCode int
statusCode := http.StatusBadRequest
switch err {
case logical.ErrPermissionDenied:
statusCode = http.StatusForbidden
case logical.ErrUnsupportedOperation:
statusCode = http.StatusMethodNotAllowed
case logical.ErrUnsupportedPath:
statusCode = http.StatusNotFound
case logical.ErrInvalidRequest:
statusCode = http.StatusBadRequest
default:
statusCode = http.StatusBadRequest
if err != nil {
switch err {
case logical.ErrPermissionDenied:
statusCode = http.StatusForbidden
case logical.ErrUnsupportedOperation:
statusCode = http.StatusMethodNotAllowed
case logical.ErrUnsupportedPath:
statusCode = http.StatusNotFound
case logical.ErrInvalidRequest:
statusCode = http.StatusBadRequest
}
}
err := fmt.Errorf("%s", resp.Data["error"].(string))

View File

@ -3,6 +3,7 @@ package http
import (
"net/http"
"github.com/hashicorp/errwrap"
"github.com/hashicorp/vault/vault"
)
@ -20,7 +21,7 @@ func handleSysLeader(core *vault.Core) http.Handler {
func handleSysLeaderGet(core *vault.Core, w http.ResponseWriter, r *http.Request) {
haEnabled := true
isLeader, address, err := core.Leader()
if err == vault.ErrHANotEnabled {
if errwrap.Contains(err, vault.ErrHANotEnabled.Error()) {
haEnabled = false
err = nil
}

View File

@ -9,6 +9,7 @@ import (
"sort"
"testing"
"github.com/hashicorp/errwrap"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/http"
"github.com/hashicorp/vault/logical"
@ -302,8 +303,10 @@ func Test(t TestT, c TestCase) {
if err == nil && resp.IsError() {
err = fmt.Errorf("Erroneous response:\n\n%#v", resp)
}
if err != nil && err != logical.ErrUnsupportedOperation {
t.Error(fmt.Sprintf("[ERR] Rollback error: %s", err))
if err != nil {
if !errwrap.Contains(err, logical.ErrUnsupportedOperation.Error()) {
t.Error(fmt.Sprintf("[ERR] Rollback error: %s", err))
}
}
// If we have any failed revokes, log it.

View File

@ -7,6 +7,7 @@ import (
"fmt"
"io"
"net/http"
"os"
"github.com/hashicorp/errwrap"
"github.com/hashicorp/go-rootcerts"
@ -49,6 +50,14 @@ type Meta struct {
TokenHelper TokenHelperFunc
}
func (m *Meta) DefaultWrappingLookupFunc(operation, path string) string {
if m.flagWrapTTL != "" {
return m.flagWrapTTL
}
return os.Getenv(api.EnvVaultWrapTTL)
}
// Client returns the API client to a Vault server given the configured
// flag settings for this command.
func (m *Meta) Client() (*api.Client, error) {
@ -94,16 +103,14 @@ func (m *Meta) Client() (*api.Client, error) {
}
}
if m.flagWrapTTL != "" {
config.WrapTTL = m.flagWrapTTL
}
// Build the client
client, err := api.NewClient(config)
if err != nil {
return nil, err
}
client.SetWrappingLookupFunc(m.DefaultWrappingLookupFunc)
// If we have a token directly, then set that
token := m.ClientToken
@ -200,17 +207,8 @@ func GeneralOptionsUsage() string {
-tls-skip-verify Do not verify TLS certificate. This is highly
not recommended. Verification will also be skipped
if VAULT_SKIP_VERIFY is set.
-wrap-ttl="" Indiciates that the response should be wrapped in a
cubbyhole token with the requested TTL. The response
will live at "/response" in the cubbyhole of the
returned token with a key of "response" and can be
parsed as a normal API Secret. The backend can also
request wrapping; the lesser of the values is used.
This is a numeric string with an optional suffix of
"s", "m", or "h"; if no suffix is specified it will
be parsed as seconds. May also be specified via
VAULT_WRAP_TTL.
`
general += AdditionalOptionsUsage()
return general
}

7
meta/meta_non_vault.go Normal file
View File

@ -0,0 +1,7 @@
// +build !vault
package meta
func AdditionalOptionsUsage() string {
return ""
}

18
meta/meta_vault.go Normal file
View File

@ -0,0 +1,18 @@
// +build vault
package meta
func AdditionalOptionsUsage() string {
return `
-wrap-ttl="" Indiciates that the response should be wrapped in a
cubbyhole token with the requested TTL. The response
will live at "cubbyhole/response" in the cubbyhole of
the returned token with a key of "response" and can
be parsed as a normal API Secret. The backend can
also request wrapping; the lesser of the values is
used. This is a numeric string with an optional
suffix "s", "m", or "h"; if no suffix is specified it
will be parsed as seconds. May also be specified via
VAULT_WRAP_TTL.
`
}

View File

@ -665,8 +665,9 @@ func (c *Core) Seal(token string) (retErr error) {
c.stateLock.Lock()
defer c.stateLock.Unlock()
if c.sealed {
return nil
return retErr
}
// Validate the token is a root token
@ -685,9 +686,11 @@ func (c *Core) Seal(token string) (retErr error) {
// essentially does the same thing.
if c.standby {
c.logger.Printf("[ERR] core: vault cannot seal when in standby mode; please restart instead")
return errors.New("vault cannot seal when in standby mode; please restart instead")
retErr = multierror.Append(retErr, errors.New("vault cannot seal when in standby mode; please restart instead"))
return retErr
}
return err
retErr = multierror.Append(retErr, err)
return retErr
}
// Attempt to use the token (decrement num_uses)
// On error bail out; if the token has been revoked, bail out too
@ -695,42 +698,50 @@ func (c *Core) Seal(token string) (retErr error) {
te, err = c.tokenStore.UseToken(te)
if err != nil {
c.logger.Printf("[ERR] core: failed to use token: %v", err)
return ErrInternalError
retErr = multierror.Append(retErr, ErrInternalError)
return retErr
}
if te == nil {
// Token is no longer valid
return logical.ErrPermissionDenied
retErr = multierror.Append(retErr, logical.ErrPermissionDenied)
return retErr
}
if te.NumUses == -1 {
// Token needs to be revoked
return c.tokenStore.Revoke(te.ID)
defer func(id string) {
err = c.tokenStore.Revoke(id)
if err != nil {
c.logger.Printf("[ERR] core: token needed revocation after seal but failed to revoke: %v", err)
retErr = multierror.Append(retErr, ErrInternalError)
}
}(te.ID)
}
}
// Verify that this operation is allowed
allowed, rootPrivs := acl.AllowOperation(req.Operation, req.Path)
if !allowed {
return logical.ErrPermissionDenied
retErr = multierror.Append(retErr, logical.ErrPermissionDenied)
return retErr
}
// We always require root privileges for this operation
if !rootPrivs {
return logical.ErrPermissionDenied
retErr = multierror.Append(retErr, logical.ErrPermissionDenied)
return retErr
}
//Seal the Vault
err = c.sealInternal()
if err == nil && retErr == ErrInternalError {
c.logger.Printf("[ERR] core: core is successfully sealed but another error occurred during the operation")
} else {
retErr = err
if err != nil {
retErr = multierror.Append(retErr, err)
}
return
return retErr
}
// StepDown is used to step down from leadership
func (c *Core) StepDown(token string) error {
func (c *Core) StepDown(token string) (retErr error) {
defer metrics.MeasureSince([]string{"core", "step_down"}, time.Now())
c.stateLock.Lock()
@ -751,34 +762,45 @@ func (c *Core) StepDown(token string) error {
acl, te, err := c.fetchACLandTokenEntry(req)
if err != nil {
return err
retErr = multierror.Append(retErr, err)
return retErr
}
// Attempt to use the token (decrement num_uses)
if te != nil {
te, err = c.tokenStore.UseToken(te)
if err != nil {
c.logger.Printf("[ERR] core: failed to use token: %v", err)
return err
retErr = multierror.Append(retErr, ErrInternalError)
return retErr
}
if te == nil {
// Token has been revoked
return logical.ErrPermissionDenied
retErr = multierror.Append(retErr, logical.ErrPermissionDenied)
return retErr
}
if te.NumUses == -1 {
// Token needs to be revoked
return c.tokenStore.Revoke(te.ID)
defer func(id string) {
err = c.tokenStore.Revoke(id)
if err != nil {
c.logger.Printf("[ERR] core: token needed revocation after step-down but failed to revoke: %v", err)
retErr = multierror.Append(retErr, ErrInternalError)
}
}(te.ID)
}
}
// Verify that this operation is allowed
allowed, rootPrivs := acl.AllowOperation(req.Operation, req.Path)
if !allowed {
return logical.ErrPermissionDenied
retErr = multierror.Append(retErr, logical.ErrPermissionDenied)
return retErr
}
// We always require root privileges for this operation
if !rootPrivs {
return logical.ErrPermissionDenied
retErr = multierror.Append(retErr, logical.ErrPermissionDenied)
return retErr
}
select {
@ -787,7 +809,7 @@ func (c *Core) StepDown(token string) error {
c.logger.Printf("[WARN] core: manual step-down operation already queued")
}
return nil
return retErr
}
// sealInternal is an internal method used to seal the vault. It does not do

View File

@ -7,6 +7,7 @@ import (
"testing"
"time"
"github.com/hashicorp/errwrap"
"github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/audit"
"github.com/hashicorp/vault/logical"
@ -384,7 +385,7 @@ func TestCore_HandleRequest_MissingToken(t *testing.T) {
},
}
resp, err := c.HandleRequest(req)
if err != logical.ErrInvalidRequest {
if err == nil || !errwrap.Contains(err, logical.ErrInvalidRequest.Error()) {
t.Fatalf("err: %v", err)
}
if resp.Data["error"] != "missing client token" {
@ -405,7 +406,7 @@ func TestCore_HandleRequest_InvalidToken(t *testing.T) {
ClientToken: "foobarbaz",
}
resp, err := c.HandleRequest(req)
if err != logical.ErrPermissionDenied {
if err == nil || !errwrap.Contains(err, logical.ErrPermissionDenied.Error()) {
t.Fatalf("err: %v", err)
}
if resp.Data["error"] != "permission denied" {
@ -442,7 +443,7 @@ func TestCore_HandleRequest_RootPath(t *testing.T) {
ClientToken: "child",
}
resp, err := c.HandleRequest(req)
if err != logical.ErrPermissionDenied {
if err == nil || !errwrap.Contains(err, logical.ErrPermissionDenied.Error()) {
t.Fatalf("err: %v, resp: %v", err, resp)
}
}
@ -499,7 +500,7 @@ func TestCore_HandleRequest_PermissionDenied(t *testing.T) {
ClientToken: "child",
}
resp, err := c.HandleRequest(req)
if err != logical.ErrPermissionDenied {
if err == nil || !errwrap.Contains(err, logical.ErrPermissionDenied.Error()) {
t.Fatalf("err: %v, resp: %v", err, resp)
}
}
@ -947,7 +948,7 @@ func TestCore_LimitedUseToken(t *testing.T) {
// Second operation should fail
_, err = c.HandleRequest(req)
if err != logical.ErrPermissionDenied {
if err == nil || !errwrap.Contains(err, logical.ErrPermissionDenied.Error()) {
t.Fatalf("err: %v", err)
}
}

View File

@ -7,6 +7,7 @@ import (
"github.com/armon/go-metrics"
"github.com/hashicorp/errwrap"
"github.com/hashicorp/golang-lru"
"github.com/hashicorp/vault/helper/strutil"
"github.com/hashicorp/vault/logical"
)
@ -18,6 +19,9 @@ const (
// policyCacheSize is the number of policies that are kept cached
policyCacheSize = 1024
// cubbyholeResponseWrappingPolicyName is the name of the fixed policy
cubbyholeResponseWrappingPolicyName = "cubbyhole-response-wrapping"
// cubbyholeResponseWrappingPolicy is the policy that ensures cubbyhole
// response wrapping can always succeed
cubbyholeResponseWrappingPolicy = `
@ -27,6 +31,13 @@ path "cubbyhole/response" {
`
)
var (
immutablePolicies = []string{
"root",
cubbyholeResponseWrappingPolicyName,
}
)
// PolicyStore is used to provide durable storage of policy, and to
// manage ACLs associated with them.
type PolicyStore struct {
@ -76,7 +87,7 @@ func (c *Core) setupPolicyStore() error {
}
// Ensure that the cubbyhole response wrapping policy exists
policy, err = c.policyStore.GetPolicy("cubbyhole-response-wrapping")
policy, err = c.policyStore.GetPolicy(cubbyholeResponseWrappingPolicyName)
if err != nil {
return errwrap.Wrapf("error fetching default policy from store: {{err}}", err)
}
@ -100,15 +111,12 @@ func (c *Core) teardownPolicyStore() error {
// SetPolicy is used to create or update the given policy
func (ps *PolicyStore) SetPolicy(p *Policy) error {
defer metrics.MeasureSince([]string{"policy", "set_policy"}, time.Now())
if p.Name == "root" {
return fmt.Errorf("cannot update root policy")
}
if p.Name == "cubbyhole-response-wrapping" {
return fmt.Errorf("cannot update cubbyhole-response-wrapping policy")
}
if p.Name == "" {
return fmt.Errorf("policy name missing")
}
if strutil.StrListContains(immutablePolicies, p.Name) {
return fmt.Errorf("cannot update %s policy", p.Name)
}
return ps.setPolicyInternal(p)
}
@ -208,15 +216,12 @@ func (ps *PolicyStore) ListPolicies() ([]string, error) {
// DeletePolicy is used to delete the named policy
func (ps *PolicyStore) DeletePolicy(name string) error {
defer metrics.MeasureSince([]string{"policy", "delete_policy"}, time.Now())
if name == "root" {
return fmt.Errorf("cannot delete root policy")
if strutil.StrListContains(immutablePolicies, name) {
return fmt.Errorf("cannot delete %s policy", name)
}
if name == "default" {
return fmt.Errorf("cannot delete default policy")
}
if name == "cubbyhole-response-wrapping" {
return fmt.Errorf("cannot delete cubbyhole-response-wrapping policy")
}
if err := ps.view.Delete(name); err != nil {
return fmt.Errorf("failed to delete policy: %v", err)
}
@ -286,13 +291,13 @@ path "cubbyhole" {
func (ps *PolicyStore) createCubbyholeResponseWrappingPolicy() error {
policy, err := Parse(cubbyholeResponseWrappingPolicy)
if err != nil {
return errwrap.Wrapf("error parsing cubbyhole-response-wrapping policy: {{err}}", err)
return errwrap.Wrapf(fmt.Sprintf("error parsing %s policy: {{err}}", cubbyholeResponseWrappingPolicyName), err)
}
if policy == nil {
return fmt.Errorf("parsing cubbyhole-response-wrapping policy resulted in nil policy")
return fmt.Errorf("parsing %s policy resulted in nil policy", cubbyholeResponseWrappingPolicyName)
}
policy.Name = "cubbyhole-response-wrapping"
policy.Name = cubbyholeResponseWrappingPolicyName
return ps.setPolicyInternal(policy)
}

View File

@ -142,12 +142,33 @@ func TestPolicyStore_Predefined(t *testing.T) {
t.Fatalf("bad: %v", out)
}
p, err := core.policyStore.GetPolicy("cubbyhole-response-wrapping")
pCubby, err := core.policyStore.GetPolicy("cubbyhole-response-wrapping")
if err != nil {
t.Fatalf("err: %v", err)
}
if p.Raw != cubbyholeResponseWrappingPolicy {
t.Fatalf("bad: expected\n%s\ngot\n%s\n", cubbyholeResponseWrappingPolicy, p.Raw)
if pCubby.Raw != cubbyholeResponseWrappingPolicy {
t.Fatalf("bad: expected\n%s\ngot\n%s\n", cubbyholeResponseWrappingPolicy, pCubby.Raw)
}
pRoot, err := core.policyStore.GetPolicy("root")
if err != nil {
t.Fatalf("err: %v", err)
}
err = core.policyStore.SetPolicy(pCubby)
if err == nil {
t.Fatalf("expected err setting %s", pCubby.Name)
}
err = core.policyStore.SetPolicy(pRoot)
if err == nil {
t.Fatalf("expected err setting %s", pRoot.Name)
}
err = core.policyStore.DeletePolicy(pCubby.Name)
if err == nil {
t.Fatalf("expected err deleting %s", pCubby.Name)
}
err = core.policyStore.DeletePolicy(pRoot.Name)
if err == nil {
t.Fatalf("expected err deleting %s", pRoot.Name)
}
}

View File

@ -3,23 +3,14 @@ package vault
import (
"sort"
"strings"
"sync"
"time"
"github.com/armon/go-metrics"
"github.com/hashicorp/go-multierror"
"github.com/hashicorp/vault/helper/strutil"
"github.com/hashicorp/vault/logical"
)
var (
// Value for memoizing whether cubbyhole is mounted, e.g. if we are in
// normal operation and not test mode
cubbyholeMounted bool
// mutex to ensure the same
cubbyholeMountedMutex sync.Mutex
)
// HandleRequest is used to handle a new incoming request
func (c *Core) HandleRequest(req *logical.Request) (resp *logical.Response, err error) {
c.stateLock.RLock()
@ -60,87 +51,16 @@ func (c *Core) HandleRequest(req *logical.Request) (resp *logical.Response, err
}
}
// In order to wrap, we need cubbyhole to be mounted, so we ensure that
// cubbyhole is actually mounted, as it may not be during tests. We memoize
// a true response, since cubbyhole cannot be mounted or unmounted during
// normal operation.
if !cubbyholeMounted {
cubbyholeMountedMutex.Lock()
// Ensure it wasn't changed by another goroutine
if !cubbyholeMounted {
if c.router.MatchingMount("cubbyhole/") != "" {
cubbyholeMounted = true
}
}
cubbyholeMountedMutex.Unlock()
}
// We are wrapping if there is anything to wrap (not a nil response) and a
// TTL was specified for the token, plus if cubbyhole is mounted (which
// will be the case normally)
wrapping := cubbyholeMounted && resp != nil && resp.WrapInfo != nil && resp.WrapInfo.TTL != 0
// TTL was specified for the token
wrapping := resp != nil && resp.WrapInfo != nil && resp.WrapInfo.TTL != 0
// If we are wrapping, the first part happens before auditing so that
// resp.WrapInfo.Token can contain the HMAC'd wrapping token ID in the
// audit logs, so that it can be determined from the audit logs whether the
// token was ever actually used.
if wrapping {
// Create the wrapping token
te := TokenEntry{
Path: req.Path,
Policies: []string{"cubbyhole-response-wrapping"},
CreationTime: time.Now().Unix(),
TTL: resp.WrapInfo.TTL,
NumUses: 1,
}
if err := c.tokenStore.create(&te); err != nil {
c.logger.Printf("[ERR] core: failed to create wrapping token: %v", err)
return nil, ErrInternalError
}
resp.WrapInfo.Token = te.ID
httpResponse := logical.SanitizeResponse(resp)
cubbyReq := &logical.Request{
Operation: logical.CreateOperation,
Path: "cubbyhole/response",
ClientToken: te.ID,
Data: map[string]interface{}{
"response": httpResponse,
},
}
cubbyResp, err := c.router.Route(cubbyReq)
if err != nil {
// Revoke since it's not yet being tracked for expiration
c.tokenStore.Revoke(te.ID)
c.logger.Printf("[ERR] core: failed to store wrapped response information: %v", err)
return nil, ErrInternalError
}
if cubbyResp != nil && cubbyResp.IsError() {
c.tokenStore.Revoke(te.ID)
c.logger.Printf("[ERR] core: failed to store wrapped response information: %v", cubbyResp.Data["error"])
return cubbyResp, nil
}
auth := &logical.Auth{
ClientToken: te.ID,
Policies: []string{"cubbyhole-response-wrapping"},
LeaseOptions: logical.LeaseOptions{
TTL: te.TTL,
Renewable: false,
},
}
// Register the wrapped token with the expiration manager
if err := c.expiration.RegisterAuth(te.Path, auth); err != nil {
// Revoke since it's not yet being tracked for expiration
c.tokenStore.Revoke(te.ID)
c.logger.Printf("[ERR] core: failed to register cubbyhole wrapping token lease "+
"(request path: %s): %v", req.Path, err)
return nil, ErrInternalError
cubbyResp, err := c.wrapInCubbyhole(req, resp)
// If not successful, returns either an error response from the
// cubbyhole backend or an error; if either is set, return
if cubbyResp != nil || err != nil {
return cubbyResp, err
}
}
@ -176,11 +96,13 @@ func (c *Core) handleRequest(req *logical.Request) (retResp *logical.Response, r
te, err = c.tokenStore.UseToken(te)
if err != nil {
c.logger.Printf("[ERR] core: failed to use token: %v", err)
return nil, nil, ErrInternalError
retErr = multierror.Append(retErr, ErrInternalError)
return nil, nil, retErr
}
if te == nil {
// Token has been revoked by this point
return nil, nil, logical.ErrPermissionDenied
retErr = multierror.Append(retErr, logical.ErrPermissionDenied)
return nil, nil, retErr
}
if te.NumUses == -1 {
// We defer a revocation until after logic has run, since this is a
@ -192,7 +114,7 @@ func (c *Core) handleRequest(req *logical.Request) (retResp *logical.Response, r
c.logger.Printf("[ERR] core: failed to revoke token: %v", err)
retResp = nil
retAuth = nil
retErr = ErrInternalError
retErr = multierror.Append(retErr, ErrInternalError)
}
if retResp != nil && retResp.Secret != nil &&
// Some backends return a TTL even without a Lease ID
@ -219,7 +141,10 @@ func (c *Core) handleRequest(req *logical.Request) (retResp *logical.Response, r
req.Path, err)
}
return logical.ErrorResponse(ctErr.Error()), nil, errType
if errType != nil {
retErr = multierror.Append(retErr, errType)
}
return logical.ErrorResponse(ctErr.Error()), nil, retErr
}
// Attach the display name
@ -229,11 +154,37 @@ func (c *Core) handleRequest(req *logical.Request) (retResp *logical.Response, r
if err := c.auditBroker.LogRequest(auth, req, nil); err != nil {
c.logger.Printf("[ERR] core: failed to audit request with path (%s): %v",
req.Path, err)
return nil, auth, ErrInternalError
retErr = multierror.Append(retErr, ErrInternalError)
return nil, auth, retErr
}
// Route the request
resp, err := c.router.Route(req)
if resp != nil {
// If either of the request or response requested wrapping, ensure that
// the lowest value is what ends up in the response.
switch {
case req.WrapTTL == 0 && (resp.WrapInfo == nil || resp.WrapInfo.TTL == 0):
// Neither defines it, so do nothing
case req.WrapTTL != 0 && (resp.WrapInfo != nil && resp.WrapInfo.TTL != 0):
// Both define, so use the lowest
if req.WrapTTL < resp.WrapInfo.TTL {
resp.WrapInfo.TTL = req.WrapTTL
}
case req.WrapTTL != 0:
// Response wrap info doesn't exist, or its TTL is zero, so set
// it to the request TTL
resp.WrapInfo = &logical.WrapInfo{
TTL: req.WrapTTL,
}
default:
// Only case left is that only resp defines it, which doesn't
// need to be explicitly handled
}
}
// If there is a secret, we must register it with the expiration manager.
// We exclude renewal of a lease, since it does not need to be re-registered
@ -242,7 +193,8 @@ func (c *Core) handleRequest(req *logical.Request) (retResp *logical.Response, r
sysView := c.router.MatchingSystemView(req.Path)
if sysView == nil {
c.logger.Println("[ERR] core: unable to retrieve system view from router")
return nil, auth, ErrInternalError
retErr = multierror.Append(retErr, ErrInternalError)
return nil, auth, retErr
}
// Apply the default lease if none given
@ -262,7 +214,8 @@ func (c *Core) handleRequest(req *logical.Request) (retResp *logical.Response, r
matchingBackend := c.router.MatchingBackend(req.Path)
if matchingBackend == nil {
c.logger.Println("[ERR] core: unable to retrieve generic backend from router")
return nil, auth, ErrInternalError
retErr = multierror.Append(retErr, ErrInternalError)
return nil, auth, retErr
}
if ptbe, ok := matchingBackend.(*PassthroughBackend); ok {
if !ptbe.GeneratesLeases() {
@ -277,7 +230,8 @@ func (c *Core) handleRequest(req *logical.Request) (retResp *logical.Response, r
c.logger.Printf(
"[ERR] core: failed to register lease "+
"(request path: %s): %v", req.Path, err)
return nil, auth, ErrInternalError
retErr = multierror.Append(retErr, ErrInternalError)
return nil, auth, retErr
}
resp.Secret.LeaseID = leaseID
}
@ -291,7 +245,8 @@ func (c *Core) handleRequest(req *logical.Request) (retResp *logical.Response, r
c.logger.Printf(
"[ERR] core: unexpected Auth response for non-token backend "+
"(request path: %s)", req.Path)
return nil, auth, ErrInternalError
retErr = multierror.Append(retErr, ErrInternalError)
return nil, auth, retErr
}
// Register with the expiration manager. We use the token's actual path
@ -299,18 +254,23 @@ func (c *Core) handleRequest(req *logical.Request) (retResp *logical.Response, r
te, err := c.tokenStore.Lookup(resp.Auth.ClientToken)
if err != nil {
c.logger.Printf("[ERR] core: failed to lookup token: %v", err)
return nil, nil, ErrInternalError
retErr = multierror.Append(retErr, ErrInternalError)
return nil, nil, retErr
}
if err := c.expiration.RegisterAuth(te.Path, resp.Auth); err != nil {
c.logger.Printf("[ERR] core: failed to register token lease "+
"(request path: %s): %v", req.Path, err)
return nil, auth, ErrInternalError
retErr = multierror.Append(retErr, ErrInternalError)
return nil, auth, retErr
}
}
// Return the response and error
return resp, auth, err
if err != nil {
retErr = multierror.Append(retErr, err)
}
return resp, auth, retErr
}
// handleLoginRequest is used to handle a login request, which is an
@ -423,3 +383,68 @@ func (c *Core) handleLoginRequest(req *logical.Request) (*logical.Response, *log
return resp, auth, err
}
func (c *Core) wrapInCubbyhole(req *logical.Request, resp *logical.Response) (*logical.Response, error) {
// If we are wrapping, the first part (performed in this functions) happens
// before auditing so that resp.WrapInfo.Token can contain the HMAC'd
// wrapping token ID in the audit logs, so that it can be determined from
// the audit logs whether the token was ever actually used.
te := TokenEntry{
Path: req.Path,
Policies: []string{"cubbyhole-response-wrapping"},
CreationTime: time.Now().Unix(),
TTL: resp.WrapInfo.TTL,
NumUses: 1,
}
if err := c.tokenStore.create(&te); err != nil {
c.logger.Printf("[ERR] core: failed to create wrapping token: %v", err)
return nil, ErrInternalError
}
resp.WrapInfo.Token = te.ID
httpResponse := logical.SanitizeResponse(resp)
cubbyReq := &logical.Request{
Operation: logical.CreateOperation,
Path: "cubbyhole/response",
ClientToken: te.ID,
Data: map[string]interface{}{
"response": httpResponse,
},
}
cubbyResp, err := c.router.Route(cubbyReq)
if err != nil {
// Revoke since it's not yet being tracked for expiration
c.tokenStore.Revoke(te.ID)
c.logger.Printf("[ERR] core: failed to store wrapped response information: %v", err)
return nil, ErrInternalError
}
if cubbyResp != nil && cubbyResp.IsError() {
c.tokenStore.Revoke(te.ID)
c.logger.Printf("[ERR] core: failed to store wrapped response information: %v", cubbyResp.Data["error"])
return cubbyResp, nil
}
auth := &logical.Auth{
ClientToken: te.ID,
Policies: []string{"cubbyhole-response-wrapping"},
LeaseOptions: logical.LeaseOptions{
TTL: te.TTL,
Renewable: false,
},
}
// Register the wrapped token with the expiration manager
if err := c.expiration.RegisterAuth(te.Path, auth); err != nil {
// Revoke since it's not yet being tracked for expiration
c.tokenStore.Revoke(te.ID)
c.logger.Printf("[ERR] core: failed to register cubbyhole wrapping token lease "+
"(request path: %s): %v", req.Path, err)
return nil, ErrInternalError
}
return nil, nil
}

View File

@ -0,0 +1,111 @@
package vault
import (
"testing"
"time"
"github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/logical"
)
func TestRequestHandling_Wrapping(t *testing.T) {
core, _, root := TestCoreUnsealed(t)
n := &NoopBackend{}
core.logicalBackends["noop"] = func(config *logical.BackendConfig) (logical.Backend, error) {
return n, nil
}
meUUID, _ := uuid.GenerateUUID()
err := core.mount(&MountEntry{
UUID: meUUID,
Path: "wraptest",
Type: "noop",
})
if err != nil {
t.Fatalf("err: %v", err)
}
// No duration specified
req := &logical.Request{
Path: "wraptest/foo",
ClientToken: root,
Operation: logical.UpdateOperation,
}
resp, err := core.HandleRequest(req)
if err != nil {
t.Fatalf("err: %v", err)
}
if resp != nil {
t.Fatalf("bad: %#v", resp)
}
// Just in the request
req = &logical.Request{
Path: "wraptest/foo",
ClientToken: root,
Operation: logical.UpdateOperation,
WrapTTL: time.Duration(15 * time.Second),
}
resp, err = core.HandleRequest(req)
if err != nil {
t.Fatalf("err: %v", err)
}
if resp == nil {
t.Fatalf("bad: %v", resp)
}
if resp.WrapInfo == nil || resp.WrapInfo.TTL != time.Duration(15*time.Second) {
t.Fatalf("bad: %#v", resp)
}
// Just in the response
n.WrapTTL = time.Duration(15 * time.Second)
req = &logical.Request{
Path: "wraptest/foo",
ClientToken: root,
Operation: logical.UpdateOperation,
}
resp, err = core.HandleRequest(req)
if err != nil {
t.Fatalf("err: %v", err)
}
if resp == nil {
t.Fatalf("bad: %v", resp)
}
if resp.WrapInfo == nil || resp.WrapInfo.TTL != time.Duration(15*time.Second) {
t.Fatalf("bad: %#v", resp)
}
// In both, with request less
n.WrapTTL = time.Duration(15 * time.Second)
req = &logical.Request{
Path: "wraptest/foo",
ClientToken: root,
Operation: logical.UpdateOperation,
WrapTTL: time.Duration(10 * time.Second),
}
resp, err = core.HandleRequest(req)
if err != nil {
t.Fatalf("err: %v", err)
}
if resp == nil {
t.Fatalf("bad: %v", resp)
}
if resp.WrapInfo == nil || resp.WrapInfo.TTL != time.Duration(10*time.Second) {
t.Fatalf("bad: %#v", resp)
}
// In both, with response less
n.WrapTTL = time.Duration(10 * time.Second)
req = &logical.Request{
Path: "wraptest/foo",
ClientToken: root,
Operation: logical.UpdateOperation,
WrapTTL: time.Duration(15 * time.Second),
}
resp, err = core.HandleRequest(req)
if err != nil {
t.Fatalf("err: %v", err)
}
}

View File

@ -194,7 +194,7 @@ func (r *Router) RouteExistenceCheck(req *logical.Request) (bool, bool, error) {
return ok, exists, err
}
func (r *Router) routeCommon(req *logical.Request, existenceCheck bool) (resp *logical.Response, ok bool, exists bool, err error) {
func (r *Router) routeCommon(req *logical.Request, existenceCheck bool) (*logical.Response, bool, bool, error) {
// Find the mount point
r.l.RLock()
mount, raw, ok := r.root.LongestPrefix(req.Path)
@ -250,37 +250,8 @@ func (r *Router) routeCommon(req *logical.Request, existenceCheck bool) (resp *l
// Reset the request before returning
defer func() {
// We only run this if resp is not nil, so for instance we don't during
// an existence check
if resp != nil {
// If either of the request or response requested wrapping, ensure that
// the lowest value is what ends up in the response.
switch {
case req.WrapTTL == 0 && (resp.WrapInfo == nil || resp.WrapInfo.TTL == 0):
// Neither defines it, so do nothing
case req.WrapTTL != 0 && (resp.WrapInfo != nil && resp.WrapInfo.TTL != 0):
// Both define, so use the lowest
if req.WrapTTL < resp.WrapInfo.TTL {
resp.WrapInfo.TTL = req.WrapTTL
}
case req.WrapTTL != 0:
// Response wrap info doesn't exist, or its TTL is zero, so set
// it to the request TTL
resp.WrapInfo = &logical.WrapInfo{
TTL: req.WrapTTL,
}
default:
// Only case left is that only resp defines it, which doesn't
// need to be explicitly handled
}
}
// Reset other parameters
req.MountPoint = ""
req.Path = original
req.MountPoint = ""
req.Connection = originalConn
req.Storage = nil
req.ClientToken = clientToken
@ -291,7 +262,7 @@ func (r *Router) routeCommon(req *logical.Request, existenceCheck bool) (resp *l
ok, exists, err := re.backend.HandleExistenceCheck(req)
return nil, ok, exists, err
} else {
resp, err = re.backend.HandleRequest(req)
resp, err := re.backend.HandleRequest(req)
return resp, false, false, err
}
}

View File

@ -384,111 +384,3 @@ func TestPathsToRadix(t *testing.T) {
t.Fatalf("bad: %v (sub/bar)", raw)
}
}
func TestRouter_Wrapping(t *testing.T) {
core, _, root := TestCoreUnsealed(t)
n := &NoopBackend{}
core.logicalBackends["noop"] = func(config *logical.BackendConfig) (logical.Backend, error) {
return n, nil
}
meUUID, _ := uuid.GenerateUUID()
err := core.mount(&MountEntry{
UUID: meUUID,
Path: "wraptest",
Type: "noop",
})
if err != nil {
t.Fatalf("err: %v", err)
}
// No duration specified
req := &logical.Request{
Path: "wraptest/foo",
ClientToken: root,
Operation: logical.UpdateOperation,
}
resp, err := core.HandleRequest(req)
if err != nil {
t.Fatalf("err: %v", err)
}
if resp != nil {
t.Fatalf("bad: %#v", resp)
}
// Just in the request
req = &logical.Request{
Path: "wraptest/foo",
ClientToken: root,
Operation: logical.UpdateOperation,
WrapTTL: time.Duration(15 * time.Second),
}
resp, err = core.HandleRequest(req)
if err != nil {
t.Fatalf("err: %v", err)
}
if resp == nil {
t.Fatalf("bad: %v", resp)
}
if resp.WrapInfo == nil || resp.WrapInfo.TTL != time.Duration(15*time.Second) {
t.Fatalf("bad: %#v", resp)
}
// Just in the response
n.WrapTTL = time.Duration(15 * time.Second)
req = &logical.Request{
Path: "wraptest/foo",
ClientToken: root,
Operation: logical.UpdateOperation,
}
resp, err = core.HandleRequest(req)
if err != nil {
t.Fatalf("err: %v", err)
}
if resp == nil {
t.Fatalf("bad: %v", resp)
}
if resp.WrapInfo == nil || resp.WrapInfo.TTL != time.Duration(15*time.Second) {
t.Fatalf("bad: %#v", resp)
}
// In both, with request less
n.WrapTTL = time.Duration(15 * time.Second)
req = &logical.Request{
Path: "wraptest/foo",
ClientToken: root,
Operation: logical.UpdateOperation,
WrapTTL: time.Duration(10 * time.Second),
}
resp, err = core.HandleRequest(req)
if err != nil {
t.Fatalf("err: %v", err)
}
if resp == nil {
t.Fatalf("bad: %v", resp)
}
if resp.WrapInfo == nil || resp.WrapInfo.TTL != time.Duration(10*time.Second) {
t.Fatalf("bad: %#v", resp)
}
// In both, with response less
n.WrapTTL = time.Duration(10 * time.Second)
req = &logical.Request{
Path: "wraptest/foo",
ClientToken: root,
Operation: logical.UpdateOperation,
WrapTTL: time.Duration(15 * time.Second),
}
resp, err = core.HandleRequest(req)
if err != nil {
t.Fatalf("err: %v", err)
}
if resp == nil {
t.Fatalf("bad: %v", resp)
}
if resp.WrapInfo == nil || resp.WrapInfo.TTL != time.Duration(10*time.Second) {
t.Fatalf("bad: %#v", resp)
}
}

View File

@ -92,14 +92,11 @@ func NewTokenStore(c *Core, config *logical.BackendConfig) (*TokenStore, error)
t.salt = salt
t.tokenLocks = map[string]*sync.RWMutex{}
for i := int64(0); i < 16; i++ {
for j := int64(0); j < 16; j++ {
t.tokenLocks[fmt.Sprintf("%s%s",
strconv.FormatInt(i, 16),
strconv.FormatInt(j, 16))] = &sync.RWMutex{}
}
for i := int64(0); i < 256; i++ {
t.tokenLocks[fmt.Sprintf("%2x",
strconv.FormatInt(i, 16))] = &sync.RWMutex{}
}
t.tokenLocks["global"] = &sync.RWMutex{}
t.tokenLocks["custom"] = &sync.RWMutex{}
// Setup the framework endpoints
t.Backend = &framework.Backend{
@ -580,7 +577,8 @@ func (ts *TokenStore) getTokenLock(id string) *sync.RWMutex {
lock, ok = ts.tokenLocks[id[0:2]]
}
if !ok || lock == nil {
lock = ts.tokenLocks["global"]
// Fall back for custom token IDs
lock = ts.tokenLocks["custom"]
}
return lock
@ -614,20 +612,20 @@ func (ts *TokenStore) UseToken(te *TokenEntry) (*TokenEntry, error) {
if err != nil {
return nil, fmt.Errorf("failed to refresh entry: %v", err)
}
// If it can't be found we shouldn't be trying to use it, so if we get nil
// back, it is because it has been revoked in the interim or will be
// revoked (NumUses is -1)
if te == nil {
return nil, fmt.Errorf("token entry nil after refreshing to decrement use count; token has likely been used already")
}
// If the token is already being revoked, return nil to indicate that it's
// no longer valid
if te.NumUses == -1 {
return nil, nil
return nil, fmt.Errorf("token not found or fully used already")
}
// Decrement the count. If this is our last use count, we need to indicate
// that this is no longer valid, but revocation is deferred to the end of
// the call, so this will make sure that any Lookup that happens doesn't
// return an entry.
// return an entry. This essentially acts as a write-ahead lock and is
// especially useful since revocation can end up (via the expiration
// manager revoking children) attempting to acquire the same lock
// repeatedly.
if te.NumUses == 1 {
te.NumUses = -1
} else {