Address most review feedback. Change responses to multierror to better return more useful values when there are multiple errors
This commit is contained in:
parent
4c67a739b9
commit
c4431a7e30
|
@ -29,6 +29,8 @@ var (
|
|||
errRedirect = errors.New("redirect")
|
||||
)
|
||||
|
||||
type WrappingLookupFunc func(operation, path string) string
|
||||
|
||||
// Config is used to configure the creation of the client.
|
||||
type Config struct {
|
||||
// Address is the address of the Vault server. This should be a complete
|
||||
|
@ -37,14 +39,6 @@ type Config struct {
|
|||
// HttpClient.
|
||||
Address string
|
||||
|
||||
// WrapTTL, if specified, asks the Vault server to return the normal
|
||||
// response wrapped in the cubbyhole of a token, with the TTL of the token
|
||||
// being set to the lesser of this value or a value requested by the
|
||||
// backend originating the response. Specified either as a number of
|
||||
// seconds, or a string duration with a "s", "m", or "h" suffix for
|
||||
// "seconds", "minutes", or "hours" respectively.
|
||||
WrapTTL string
|
||||
|
||||
// HttpClient is the HTTP client to use, which will currently always have the
|
||||
// same values as http.DefaultClient. This is used to control redirect behavior.
|
||||
HttpClient *http.Client
|
||||
|
@ -86,7 +80,6 @@ func (c *Config) ReadEnvironment() error {
|
|||
var envCAPath string
|
||||
var envClientCert string
|
||||
var envClientKey string
|
||||
var envWrapTTL string
|
||||
var envInsecure bool
|
||||
var foundInsecure bool
|
||||
var envTLSServerName string
|
||||
|
@ -110,9 +103,6 @@ func (c *Config) ReadEnvironment() error {
|
|||
if v := os.Getenv(EnvVaultClientKey); v != "" {
|
||||
envClientKey = v
|
||||
}
|
||||
if v := os.Getenv(EnvVaultWrapTTL); v != "" {
|
||||
envWrapTTL = v
|
||||
}
|
||||
if v := os.Getenv(EnvVaultInsecure); v != "" {
|
||||
var err error
|
||||
envInsecure, err = strconv.ParseBool(v)
|
||||
|
@ -141,10 +131,6 @@ func (c *Config) ReadEnvironment() error {
|
|||
c.Address = envAddress
|
||||
}
|
||||
|
||||
if envWrapTTL != "" {
|
||||
c.WrapTTL = envWrapTTL
|
||||
}
|
||||
|
||||
clientTLSConfig := c.HttpClient.Transport.(*http.Transport).TLSClientConfig
|
||||
if foundInsecure {
|
||||
clientTLSConfig.InsecureSkipVerify = envInsecure
|
||||
|
@ -172,9 +158,10 @@ func (c *Config) ReadEnvironment() error {
|
|||
// Client is the client to the Vault API. Create a client with
|
||||
// NewClient.
|
||||
type Client struct {
|
||||
addr *url.URL
|
||||
config *Config
|
||||
token string
|
||||
addr *url.URL
|
||||
config *Config
|
||||
token string
|
||||
wrappingLookupFunc WrappingLookupFunc
|
||||
}
|
||||
|
||||
// NewClient returns a new client for the given configuration.
|
||||
|
@ -216,6 +203,12 @@ func NewClient(c *Config) (*Client, error) {
|
|||
return client, nil
|
||||
}
|
||||
|
||||
// SetWrappingLookupFunc sets a lookup function that returns desired wrap TTLs
|
||||
// for a given operation and path
|
||||
func (c *Client) SetWrappingLookupFunc(lookupFunc WrappingLookupFunc) {
|
||||
c.wrappingLookupFunc = lookupFunc
|
||||
}
|
||||
|
||||
// Token returns the access token being used by this client. It will
|
||||
// return the empty string if there is no token set.
|
||||
func (c *Client) Token() string {
|
||||
|
@ -245,10 +238,13 @@ func (c *Client) NewRequest(method, path string) *Request {
|
|||
Path: path,
|
||||
},
|
||||
ClientToken: c.token,
|
||||
WrapTTL: c.config.WrapTTL,
|
||||
Params: make(map[string][]string),
|
||||
}
|
||||
|
||||
if c.wrappingLookupFunc != nil {
|
||||
req.WrapTTL = c.wrappingLookupFunc(method, path)
|
||||
}
|
||||
|
||||
return req
|
||||
}
|
||||
|
||||
|
|
|
@ -107,19 +107,16 @@ func TestClientEnvSettings(t *testing.T) {
|
|||
oldClientCert := os.Getenv(EnvVaultClientCert)
|
||||
oldClientKey := os.Getenv(EnvVaultClientKey)
|
||||
oldSkipVerify := os.Getenv(EnvVaultInsecure)
|
||||
oldWrapTTL := os.Getenv(EnvVaultWrapTTL)
|
||||
os.Setenv("VAULT_CACERT", cwd+"/test-fixtures/keys/cert.pem")
|
||||
os.Setenv("VAULT_CAPATH", cwd+"/test-fixtures/keys")
|
||||
os.Setenv("VAULT_CLIENT_CERT", cwd+"/test-fixtures/keys/cert.pem")
|
||||
os.Setenv("VAULT_CLIENT_KEY", cwd+"/test-fixtures/keys/key.pem")
|
||||
os.Setenv("VAULT_SKIP_VERIFY", "true")
|
||||
os.Setenv("VAULT_WRAP_TTL", "60")
|
||||
defer os.Setenv("VAULT_CACERT", oldCACert)
|
||||
defer os.Setenv("VAULT_CAPATH", oldCAPath)
|
||||
defer os.Setenv("VAULT_CLIENT_CERT", oldClientCert)
|
||||
defer os.Setenv("VAULT_CLIENT_KEY", oldClientKey)
|
||||
defer os.Setenv("VAULT_SKIP_VERIFY", oldSkipVerify)
|
||||
defer os.Setenv("VAULT_WRAP_TTL", oldWrapTTL)
|
||||
|
||||
config := DefaultConfig()
|
||||
if err := config.ReadEnvironment(); err != nil {
|
||||
|
@ -136,8 +133,4 @@ func TestClientEnvSettings(t *testing.T) {
|
|||
if tlsConfig.InsecureSkipVerify != true {
|
||||
t.Fatalf("bad: %v", tlsConfig.InsecureSkipVerify)
|
||||
}
|
||||
|
||||
if config.WrapTTL != "60" {
|
||||
t.Fatalf("bad: %v", config.WrapTTL)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,109 @@
|
|||
package command
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/vault/api"
|
||||
"github.com/hashicorp/vault/http"
|
||||
"github.com/hashicorp/vault/meta"
|
||||
"github.com/hashicorp/vault/vault"
|
||||
"github.com/mitchellh/cli"
|
||||
)
|
||||
|
||||
func TestWrapping_Env(t *testing.T) {
|
||||
core, _, token := vault.TestCoreUnsealed(t)
|
||||
ln, addr := http.TestServer(t, core)
|
||||
defer ln.Close()
|
||||
|
||||
ui := new(cli.MockUi)
|
||||
c := &TokenLookupCommand{
|
||||
Meta: meta.Meta{
|
||||
ClientToken: token,
|
||||
Ui: ui,
|
||||
},
|
||||
}
|
||||
|
||||
args := []string{
|
||||
"-address", addr,
|
||||
}
|
||||
// Run it once for client
|
||||
c.Run(args)
|
||||
|
||||
// Create a new token for us to use
|
||||
client, err := c.Client()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
resp, err := client.Auth().Token().Create(&api.TokenCreateRequest{
|
||||
Lease: "1h",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
prevWrapTTLEnv := os.Getenv(api.EnvVaultWrapTTL)
|
||||
os.Setenv(api.EnvVaultWrapTTL, "5s")
|
||||
defer func() {
|
||||
os.Setenv(api.EnvVaultWrapTTL, prevWrapTTLEnv)
|
||||
}()
|
||||
|
||||
// Now when we do a lookup-self the response should be wrapped
|
||||
args = append(args, resp.Auth.ClientToken)
|
||||
|
||||
resp, err = client.Auth().Token().LookupSelf()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
if resp == nil {
|
||||
t.Fatal("nil response")
|
||||
}
|
||||
if resp.WrapInfo == nil {
|
||||
t.Fatal("nil wrap info")
|
||||
}
|
||||
if resp.WrapInfo.Token == "" || resp.WrapInfo.TTL != 5 {
|
||||
t.Fatal("did not get token or ttl wrong")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWrapping_Flag(t *testing.T) {
|
||||
core, _, token := vault.TestCoreUnsealed(t)
|
||||
ln, addr := http.TestServer(t, core)
|
||||
defer ln.Close()
|
||||
|
||||
ui := new(cli.MockUi)
|
||||
c := &TokenLookupCommand{
|
||||
Meta: meta.Meta{
|
||||
ClientToken: token,
|
||||
Ui: ui,
|
||||
},
|
||||
}
|
||||
|
||||
args := []string{
|
||||
"-address", addr,
|
||||
"-wrap-ttl", "5s",
|
||||
}
|
||||
// Run it once for client
|
||||
c.Run(args)
|
||||
|
||||
// Create a new token for us to use
|
||||
client, err := c.Client()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
resp, err := client.Auth().Token().Create(&api.TokenCreateRequest{
|
||||
Lease: "1h",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
if resp == nil {
|
||||
t.Fatal("nil response")
|
||||
}
|
||||
if resp.WrapInfo == nil {
|
||||
t.Fatal("nil wrap info")
|
||||
}
|
||||
if resp.WrapInfo.Token == "" || resp.WrapInfo.TTL != 5 {
|
||||
t.Fatal("did not get token or ttl wrong")
|
||||
}
|
||||
}
|
|
@ -93,7 +93,7 @@ func parseRequest(r *http.Request, out interface{}) error {
|
|||
// case of an error.
|
||||
func request(core *vault.Core, w http.ResponseWriter, rawReq *http.Request, r *logical.Request) (*logical.Response, bool) {
|
||||
resp, err := core.HandleRequest(r)
|
||||
if err == vault.ErrStandby {
|
||||
if errwrap.Contains(err, vault.ErrStandby.Error()) {
|
||||
respondStandby(core, w, rawReq.URL)
|
||||
return resp, false
|
||||
}
|
||||
|
@ -195,7 +195,7 @@ func respondErrorStatus(w http.ResponseWriter, err error) {
|
|||
status := http.StatusInternalServerError
|
||||
switch {
|
||||
// Keep adding more error types here to appropriate the status codes
|
||||
case errwrap.ContainsType(err, new(vault.StatusBadRequest)):
|
||||
case err != nil && errwrap.ContainsType(err, new(vault.StatusBadRequest)):
|
||||
status = http.StatusBadRequest
|
||||
}
|
||||
respondError(w, status, err)
|
||||
|
@ -203,7 +203,7 @@ func respondErrorStatus(w http.ResponseWriter, err error) {
|
|||
|
||||
func respondError(w http.ResponseWriter, status int, err error) {
|
||||
// Adjust status code when sealed
|
||||
if err == vault.ErrSealed {
|
||||
if errwrap.Contains(err, vault.ErrSealed.Error()) {
|
||||
status = http.StatusServiceUnavailable
|
||||
}
|
||||
|
||||
|
@ -230,19 +230,19 @@ func respondCommon(w http.ResponseWriter, resp *logical.Response, err error) boo
|
|||
}
|
||||
|
||||
if resp.IsError() {
|
||||
var statusCode int
|
||||
statusCode := http.StatusBadRequest
|
||||
|
||||
switch err {
|
||||
case logical.ErrPermissionDenied:
|
||||
statusCode = http.StatusForbidden
|
||||
case logical.ErrUnsupportedOperation:
|
||||
statusCode = http.StatusMethodNotAllowed
|
||||
case logical.ErrUnsupportedPath:
|
||||
statusCode = http.StatusNotFound
|
||||
case logical.ErrInvalidRequest:
|
||||
statusCode = http.StatusBadRequest
|
||||
default:
|
||||
statusCode = http.StatusBadRequest
|
||||
if err != nil {
|
||||
switch err {
|
||||
case logical.ErrPermissionDenied:
|
||||
statusCode = http.StatusForbidden
|
||||
case logical.ErrUnsupportedOperation:
|
||||
statusCode = http.StatusMethodNotAllowed
|
||||
case logical.ErrUnsupportedPath:
|
||||
statusCode = http.StatusNotFound
|
||||
case logical.ErrInvalidRequest:
|
||||
statusCode = http.StatusBadRequest
|
||||
}
|
||||
}
|
||||
|
||||
err := fmt.Errorf("%s", resp.Data["error"].(string))
|
||||
|
|
|
@ -3,6 +3,7 @@ package http
|
|||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/hashicorp/errwrap"
|
||||
"github.com/hashicorp/vault/vault"
|
||||
)
|
||||
|
||||
|
@ -20,7 +21,7 @@ func handleSysLeader(core *vault.Core) http.Handler {
|
|||
func handleSysLeaderGet(core *vault.Core, w http.ResponseWriter, r *http.Request) {
|
||||
haEnabled := true
|
||||
isLeader, address, err := core.Leader()
|
||||
if err == vault.ErrHANotEnabled {
|
||||
if errwrap.Contains(err, vault.ErrHANotEnabled.Error()) {
|
||||
haEnabled = false
|
||||
err = nil
|
||||
}
|
||||
|
|
|
@ -9,6 +9,7 @@ import (
|
|||
"sort"
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/errwrap"
|
||||
"github.com/hashicorp/vault/api"
|
||||
"github.com/hashicorp/vault/http"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
|
@ -302,8 +303,10 @@ func Test(t TestT, c TestCase) {
|
|||
if err == nil && resp.IsError() {
|
||||
err = fmt.Errorf("Erroneous response:\n\n%#v", resp)
|
||||
}
|
||||
if err != nil && err != logical.ErrUnsupportedOperation {
|
||||
t.Error(fmt.Sprintf("[ERR] Rollback error: %s", err))
|
||||
if err != nil {
|
||||
if !errwrap.Contains(err, logical.ErrUnsupportedOperation.Error()) {
|
||||
t.Error(fmt.Sprintf("[ERR] Rollback error: %s", err))
|
||||
}
|
||||
}
|
||||
|
||||
// If we have any failed revokes, log it.
|
||||
|
|
28
meta/meta.go
28
meta/meta.go
|
@ -7,6 +7,7 @@ import (
|
|||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/hashicorp/errwrap"
|
||||
"github.com/hashicorp/go-rootcerts"
|
||||
|
@ -49,6 +50,14 @@ type Meta struct {
|
|||
TokenHelper TokenHelperFunc
|
||||
}
|
||||
|
||||
func (m *Meta) DefaultWrappingLookupFunc(operation, path string) string {
|
||||
if m.flagWrapTTL != "" {
|
||||
return m.flagWrapTTL
|
||||
}
|
||||
|
||||
return os.Getenv(api.EnvVaultWrapTTL)
|
||||
}
|
||||
|
||||
// Client returns the API client to a Vault server given the configured
|
||||
// flag settings for this command.
|
||||
func (m *Meta) Client() (*api.Client, error) {
|
||||
|
@ -94,16 +103,14 @@ func (m *Meta) Client() (*api.Client, error) {
|
|||
}
|
||||
}
|
||||
|
||||
if m.flagWrapTTL != "" {
|
||||
config.WrapTTL = m.flagWrapTTL
|
||||
}
|
||||
|
||||
// Build the client
|
||||
client, err := api.NewClient(config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
client.SetWrappingLookupFunc(m.DefaultWrappingLookupFunc)
|
||||
|
||||
// If we have a token directly, then set that
|
||||
token := m.ClientToken
|
||||
|
||||
|
@ -200,17 +207,8 @@ func GeneralOptionsUsage() string {
|
|||
-tls-skip-verify Do not verify TLS certificate. This is highly
|
||||
not recommended. Verification will also be skipped
|
||||
if VAULT_SKIP_VERIFY is set.
|
||||
|
||||
-wrap-ttl="" Indiciates that the response should be wrapped in a
|
||||
cubbyhole token with the requested TTL. The response
|
||||
will live at "/response" in the cubbyhole of the
|
||||
returned token with a key of "response" and can be
|
||||
parsed as a normal API Secret. The backend can also
|
||||
request wrapping; the lesser of the values is used.
|
||||
This is a numeric string with an optional suffix of
|
||||
"s", "m", or "h"; if no suffix is specified it will
|
||||
be parsed as seconds. May also be specified via
|
||||
VAULT_WRAP_TTL.
|
||||
`
|
||||
|
||||
general += AdditionalOptionsUsage()
|
||||
return general
|
||||
}
|
||||
|
|
|
@ -0,0 +1,7 @@
|
|||
// +build !vault
|
||||
|
||||
package meta
|
||||
|
||||
func AdditionalOptionsUsage() string {
|
||||
return ""
|
||||
}
|
|
@ -0,0 +1,18 @@
|
|||
// +build vault
|
||||
|
||||
package meta
|
||||
|
||||
func AdditionalOptionsUsage() string {
|
||||
return `
|
||||
-wrap-ttl="" Indiciates that the response should be wrapped in a
|
||||
cubbyhole token with the requested TTL. The response
|
||||
will live at "cubbyhole/response" in the cubbyhole of
|
||||
the returned token with a key of "response" and can
|
||||
be parsed as a normal API Secret. The backend can
|
||||
also request wrapping; the lesser of the values is
|
||||
used. This is a numeric string with an optional
|
||||
suffix "s", "m", or "h"; if no suffix is specified it
|
||||
will be parsed as seconds. May also be specified via
|
||||
VAULT_WRAP_TTL.
|
||||
`
|
||||
}
|
|
@ -665,8 +665,9 @@ func (c *Core) Seal(token string) (retErr error) {
|
|||
|
||||
c.stateLock.Lock()
|
||||
defer c.stateLock.Unlock()
|
||||
|
||||
if c.sealed {
|
||||
return nil
|
||||
return retErr
|
||||
}
|
||||
|
||||
// Validate the token is a root token
|
||||
|
@ -685,9 +686,11 @@ func (c *Core) Seal(token string) (retErr error) {
|
|||
// essentially does the same thing.
|
||||
if c.standby {
|
||||
c.logger.Printf("[ERR] core: vault cannot seal when in standby mode; please restart instead")
|
||||
return errors.New("vault cannot seal when in standby mode; please restart instead")
|
||||
retErr = multierror.Append(retErr, errors.New("vault cannot seal when in standby mode; please restart instead"))
|
||||
return retErr
|
||||
}
|
||||
return err
|
||||
retErr = multierror.Append(retErr, err)
|
||||
return retErr
|
||||
}
|
||||
// Attempt to use the token (decrement num_uses)
|
||||
// On error bail out; if the token has been revoked, bail out too
|
||||
|
@ -695,42 +698,50 @@ func (c *Core) Seal(token string) (retErr error) {
|
|||
te, err = c.tokenStore.UseToken(te)
|
||||
if err != nil {
|
||||
c.logger.Printf("[ERR] core: failed to use token: %v", err)
|
||||
return ErrInternalError
|
||||
retErr = multierror.Append(retErr, ErrInternalError)
|
||||
return retErr
|
||||
}
|
||||
if te == nil {
|
||||
// Token is no longer valid
|
||||
return logical.ErrPermissionDenied
|
||||
retErr = multierror.Append(retErr, logical.ErrPermissionDenied)
|
||||
return retErr
|
||||
}
|
||||
if te.NumUses == -1 {
|
||||
// Token needs to be revoked
|
||||
return c.tokenStore.Revoke(te.ID)
|
||||
defer func(id string) {
|
||||
err = c.tokenStore.Revoke(id)
|
||||
if err != nil {
|
||||
c.logger.Printf("[ERR] core: token needed revocation after seal but failed to revoke: %v", err)
|
||||
retErr = multierror.Append(retErr, ErrInternalError)
|
||||
}
|
||||
}(te.ID)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify that this operation is allowed
|
||||
allowed, rootPrivs := acl.AllowOperation(req.Operation, req.Path)
|
||||
if !allowed {
|
||||
return logical.ErrPermissionDenied
|
||||
retErr = multierror.Append(retErr, logical.ErrPermissionDenied)
|
||||
return retErr
|
||||
}
|
||||
|
||||
// We always require root privileges for this operation
|
||||
if !rootPrivs {
|
||||
return logical.ErrPermissionDenied
|
||||
retErr = multierror.Append(retErr, logical.ErrPermissionDenied)
|
||||
return retErr
|
||||
}
|
||||
|
||||
//Seal the Vault
|
||||
err = c.sealInternal()
|
||||
if err == nil && retErr == ErrInternalError {
|
||||
c.logger.Printf("[ERR] core: core is successfully sealed but another error occurred during the operation")
|
||||
} else {
|
||||
retErr = err
|
||||
if err != nil {
|
||||
retErr = multierror.Append(retErr, err)
|
||||
}
|
||||
|
||||
return
|
||||
return retErr
|
||||
}
|
||||
|
||||
// StepDown is used to step down from leadership
|
||||
func (c *Core) StepDown(token string) error {
|
||||
func (c *Core) StepDown(token string) (retErr error) {
|
||||
defer metrics.MeasureSince([]string{"core", "step_down"}, time.Now())
|
||||
|
||||
c.stateLock.Lock()
|
||||
|
@ -751,34 +762,45 @@ func (c *Core) StepDown(token string) error {
|
|||
|
||||
acl, te, err := c.fetchACLandTokenEntry(req)
|
||||
if err != nil {
|
||||
return err
|
||||
retErr = multierror.Append(retErr, err)
|
||||
return retErr
|
||||
}
|
||||
// Attempt to use the token (decrement num_uses)
|
||||
if te != nil {
|
||||
te, err = c.tokenStore.UseToken(te)
|
||||
if err != nil {
|
||||
c.logger.Printf("[ERR] core: failed to use token: %v", err)
|
||||
return err
|
||||
retErr = multierror.Append(retErr, ErrInternalError)
|
||||
return retErr
|
||||
}
|
||||
if te == nil {
|
||||
// Token has been revoked
|
||||
return logical.ErrPermissionDenied
|
||||
retErr = multierror.Append(retErr, logical.ErrPermissionDenied)
|
||||
return retErr
|
||||
}
|
||||
if te.NumUses == -1 {
|
||||
// Token needs to be revoked
|
||||
return c.tokenStore.Revoke(te.ID)
|
||||
defer func(id string) {
|
||||
err = c.tokenStore.Revoke(id)
|
||||
if err != nil {
|
||||
c.logger.Printf("[ERR] core: token needed revocation after step-down but failed to revoke: %v", err)
|
||||
retErr = multierror.Append(retErr, ErrInternalError)
|
||||
}
|
||||
}(te.ID)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify that this operation is allowed
|
||||
allowed, rootPrivs := acl.AllowOperation(req.Operation, req.Path)
|
||||
if !allowed {
|
||||
return logical.ErrPermissionDenied
|
||||
retErr = multierror.Append(retErr, logical.ErrPermissionDenied)
|
||||
return retErr
|
||||
}
|
||||
|
||||
// We always require root privileges for this operation
|
||||
if !rootPrivs {
|
||||
return logical.ErrPermissionDenied
|
||||
retErr = multierror.Append(retErr, logical.ErrPermissionDenied)
|
||||
return retErr
|
||||
}
|
||||
|
||||
select {
|
||||
|
@ -787,7 +809,7 @@ func (c *Core) StepDown(token string) error {
|
|||
c.logger.Printf("[WARN] core: manual step-down operation already queued")
|
||||
}
|
||||
|
||||
return nil
|
||||
return retErr
|
||||
}
|
||||
|
||||
// sealInternal is an internal method used to seal the vault. It does not do
|
||||
|
|
|
@ -7,6 +7,7 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/errwrap"
|
||||
"github.com/hashicorp/go-uuid"
|
||||
"github.com/hashicorp/vault/audit"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
|
@ -384,7 +385,7 @@ func TestCore_HandleRequest_MissingToken(t *testing.T) {
|
|||
},
|
||||
}
|
||||
resp, err := c.HandleRequest(req)
|
||||
if err != logical.ErrInvalidRequest {
|
||||
if err == nil || !errwrap.Contains(err, logical.ErrInvalidRequest.Error()) {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if resp.Data["error"] != "missing client token" {
|
||||
|
@ -405,7 +406,7 @@ func TestCore_HandleRequest_InvalidToken(t *testing.T) {
|
|||
ClientToken: "foobarbaz",
|
||||
}
|
||||
resp, err := c.HandleRequest(req)
|
||||
if err != logical.ErrPermissionDenied {
|
||||
if err == nil || !errwrap.Contains(err, logical.ErrPermissionDenied.Error()) {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if resp.Data["error"] != "permission denied" {
|
||||
|
@ -442,7 +443,7 @@ func TestCore_HandleRequest_RootPath(t *testing.T) {
|
|||
ClientToken: "child",
|
||||
}
|
||||
resp, err := c.HandleRequest(req)
|
||||
if err != logical.ErrPermissionDenied {
|
||||
if err == nil || !errwrap.Contains(err, logical.ErrPermissionDenied.Error()) {
|
||||
t.Fatalf("err: %v, resp: %v", err, resp)
|
||||
}
|
||||
}
|
||||
|
@ -499,7 +500,7 @@ func TestCore_HandleRequest_PermissionDenied(t *testing.T) {
|
|||
ClientToken: "child",
|
||||
}
|
||||
resp, err := c.HandleRequest(req)
|
||||
if err != logical.ErrPermissionDenied {
|
||||
if err == nil || !errwrap.Contains(err, logical.ErrPermissionDenied.Error()) {
|
||||
t.Fatalf("err: %v, resp: %v", err, resp)
|
||||
}
|
||||
}
|
||||
|
@ -947,7 +948,7 @@ func TestCore_LimitedUseToken(t *testing.T) {
|
|||
|
||||
// Second operation should fail
|
||||
_, err = c.HandleRequest(req)
|
||||
if err != logical.ErrPermissionDenied {
|
||||
if err == nil || !errwrap.Contains(err, logical.ErrPermissionDenied.Error()) {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -7,6 +7,7 @@ import (
|
|||
"github.com/armon/go-metrics"
|
||||
"github.com/hashicorp/errwrap"
|
||||
"github.com/hashicorp/golang-lru"
|
||||
"github.com/hashicorp/vault/helper/strutil"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
)
|
||||
|
||||
|
@ -18,6 +19,9 @@ const (
|
|||
// policyCacheSize is the number of policies that are kept cached
|
||||
policyCacheSize = 1024
|
||||
|
||||
// cubbyholeResponseWrappingPolicyName is the name of the fixed policy
|
||||
cubbyholeResponseWrappingPolicyName = "cubbyhole-response-wrapping"
|
||||
|
||||
// cubbyholeResponseWrappingPolicy is the policy that ensures cubbyhole
|
||||
// response wrapping can always succeed
|
||||
cubbyholeResponseWrappingPolicy = `
|
||||
|
@ -27,6 +31,13 @@ path "cubbyhole/response" {
|
|||
`
|
||||
)
|
||||
|
||||
var (
|
||||
immutablePolicies = []string{
|
||||
"root",
|
||||
cubbyholeResponseWrappingPolicyName,
|
||||
}
|
||||
)
|
||||
|
||||
// PolicyStore is used to provide durable storage of policy, and to
|
||||
// manage ACLs associated with them.
|
||||
type PolicyStore struct {
|
||||
|
@ -76,7 +87,7 @@ func (c *Core) setupPolicyStore() error {
|
|||
}
|
||||
|
||||
// Ensure that the cubbyhole response wrapping policy exists
|
||||
policy, err = c.policyStore.GetPolicy("cubbyhole-response-wrapping")
|
||||
policy, err = c.policyStore.GetPolicy(cubbyholeResponseWrappingPolicyName)
|
||||
if err != nil {
|
||||
return errwrap.Wrapf("error fetching default policy from store: {{err}}", err)
|
||||
}
|
||||
|
@ -100,15 +111,12 @@ func (c *Core) teardownPolicyStore() error {
|
|||
// SetPolicy is used to create or update the given policy
|
||||
func (ps *PolicyStore) SetPolicy(p *Policy) error {
|
||||
defer metrics.MeasureSince([]string{"policy", "set_policy"}, time.Now())
|
||||
if p.Name == "root" {
|
||||
return fmt.Errorf("cannot update root policy")
|
||||
}
|
||||
if p.Name == "cubbyhole-response-wrapping" {
|
||||
return fmt.Errorf("cannot update cubbyhole-response-wrapping policy")
|
||||
}
|
||||
if p.Name == "" {
|
||||
return fmt.Errorf("policy name missing")
|
||||
}
|
||||
if strutil.StrListContains(immutablePolicies, p.Name) {
|
||||
return fmt.Errorf("cannot update %s policy", p.Name)
|
||||
}
|
||||
|
||||
return ps.setPolicyInternal(p)
|
||||
}
|
||||
|
@ -208,15 +216,12 @@ func (ps *PolicyStore) ListPolicies() ([]string, error) {
|
|||
// DeletePolicy is used to delete the named policy
|
||||
func (ps *PolicyStore) DeletePolicy(name string) error {
|
||||
defer metrics.MeasureSince([]string{"policy", "delete_policy"}, time.Now())
|
||||
if name == "root" {
|
||||
return fmt.Errorf("cannot delete root policy")
|
||||
if strutil.StrListContains(immutablePolicies, name) {
|
||||
return fmt.Errorf("cannot delete %s policy", name)
|
||||
}
|
||||
if name == "default" {
|
||||
return fmt.Errorf("cannot delete default policy")
|
||||
}
|
||||
if name == "cubbyhole-response-wrapping" {
|
||||
return fmt.Errorf("cannot delete cubbyhole-response-wrapping policy")
|
||||
}
|
||||
if err := ps.view.Delete(name); err != nil {
|
||||
return fmt.Errorf("failed to delete policy: %v", err)
|
||||
}
|
||||
|
@ -286,13 +291,13 @@ path "cubbyhole" {
|
|||
func (ps *PolicyStore) createCubbyholeResponseWrappingPolicy() error {
|
||||
policy, err := Parse(cubbyholeResponseWrappingPolicy)
|
||||
if err != nil {
|
||||
return errwrap.Wrapf("error parsing cubbyhole-response-wrapping policy: {{err}}", err)
|
||||
return errwrap.Wrapf(fmt.Sprintf("error parsing %s policy: {{err}}", cubbyholeResponseWrappingPolicyName), err)
|
||||
}
|
||||
|
||||
if policy == nil {
|
||||
return fmt.Errorf("parsing cubbyhole-response-wrapping policy resulted in nil policy")
|
||||
return fmt.Errorf("parsing %s policy resulted in nil policy", cubbyholeResponseWrappingPolicyName)
|
||||
}
|
||||
|
||||
policy.Name = "cubbyhole-response-wrapping"
|
||||
policy.Name = cubbyholeResponseWrappingPolicyName
|
||||
return ps.setPolicyInternal(policy)
|
||||
}
|
||||
|
|
|
@ -142,12 +142,33 @@ func TestPolicyStore_Predefined(t *testing.T) {
|
|||
t.Fatalf("bad: %v", out)
|
||||
}
|
||||
|
||||
p, err := core.policyStore.GetPolicy("cubbyhole-response-wrapping")
|
||||
pCubby, err := core.policyStore.GetPolicy("cubbyhole-response-wrapping")
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if p.Raw != cubbyholeResponseWrappingPolicy {
|
||||
t.Fatalf("bad: expected\n%s\ngot\n%s\n", cubbyholeResponseWrappingPolicy, p.Raw)
|
||||
if pCubby.Raw != cubbyholeResponseWrappingPolicy {
|
||||
t.Fatalf("bad: expected\n%s\ngot\n%s\n", cubbyholeResponseWrappingPolicy, pCubby.Raw)
|
||||
}
|
||||
pRoot, err := core.policyStore.GetPolicy("root")
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
err = core.policyStore.SetPolicy(pCubby)
|
||||
if err == nil {
|
||||
t.Fatalf("expected err setting %s", pCubby.Name)
|
||||
}
|
||||
err = core.policyStore.SetPolicy(pRoot)
|
||||
if err == nil {
|
||||
t.Fatalf("expected err setting %s", pRoot.Name)
|
||||
}
|
||||
err = core.policyStore.DeletePolicy(pCubby.Name)
|
||||
if err == nil {
|
||||
t.Fatalf("expected err deleting %s", pCubby.Name)
|
||||
}
|
||||
err = core.policyStore.DeletePolicy(pRoot.Name)
|
||||
if err == nil {
|
||||
t.Fatalf("expected err deleting %s", pRoot.Name)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -3,23 +3,14 @@ package vault
|
|||
import (
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/armon/go-metrics"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/hashicorp/vault/helper/strutil"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
)
|
||||
|
||||
var (
|
||||
// Value for memoizing whether cubbyhole is mounted, e.g. if we are in
|
||||
// normal operation and not test mode
|
||||
cubbyholeMounted bool
|
||||
|
||||
// mutex to ensure the same
|
||||
cubbyholeMountedMutex sync.Mutex
|
||||
)
|
||||
|
||||
// HandleRequest is used to handle a new incoming request
|
||||
func (c *Core) HandleRequest(req *logical.Request) (resp *logical.Response, err error) {
|
||||
c.stateLock.RLock()
|
||||
|
@ -60,87 +51,16 @@ func (c *Core) HandleRequest(req *logical.Request) (resp *logical.Response, err
|
|||
}
|
||||
}
|
||||
|
||||
// In order to wrap, we need cubbyhole to be mounted, so we ensure that
|
||||
// cubbyhole is actually mounted, as it may not be during tests. We memoize
|
||||
// a true response, since cubbyhole cannot be mounted or unmounted during
|
||||
// normal operation.
|
||||
if !cubbyholeMounted {
|
||||
cubbyholeMountedMutex.Lock()
|
||||
// Ensure it wasn't changed by another goroutine
|
||||
if !cubbyholeMounted {
|
||||
if c.router.MatchingMount("cubbyhole/") != "" {
|
||||
cubbyholeMounted = true
|
||||
}
|
||||
}
|
||||
cubbyholeMountedMutex.Unlock()
|
||||
}
|
||||
|
||||
// We are wrapping if there is anything to wrap (not a nil response) and a
|
||||
// TTL was specified for the token, plus if cubbyhole is mounted (which
|
||||
// will be the case normally)
|
||||
wrapping := cubbyholeMounted && resp != nil && resp.WrapInfo != nil && resp.WrapInfo.TTL != 0
|
||||
// TTL was specified for the token
|
||||
wrapping := resp != nil && resp.WrapInfo != nil && resp.WrapInfo.TTL != 0
|
||||
|
||||
// If we are wrapping, the first part happens before auditing so that
|
||||
// resp.WrapInfo.Token can contain the HMAC'd wrapping token ID in the
|
||||
// audit logs, so that it can be determined from the audit logs whether the
|
||||
// token was ever actually used.
|
||||
if wrapping {
|
||||
// Create the wrapping token
|
||||
te := TokenEntry{
|
||||
Path: req.Path,
|
||||
Policies: []string{"cubbyhole-response-wrapping"},
|
||||
CreationTime: time.Now().Unix(),
|
||||
TTL: resp.WrapInfo.TTL,
|
||||
NumUses: 1,
|
||||
}
|
||||
|
||||
if err := c.tokenStore.create(&te); err != nil {
|
||||
c.logger.Printf("[ERR] core: failed to create wrapping token: %v", err)
|
||||
return nil, ErrInternalError
|
||||
}
|
||||
|
||||
resp.WrapInfo.Token = te.ID
|
||||
|
||||
httpResponse := logical.SanitizeResponse(resp)
|
||||
|
||||
cubbyReq := &logical.Request{
|
||||
Operation: logical.CreateOperation,
|
||||
Path: "cubbyhole/response",
|
||||
ClientToken: te.ID,
|
||||
Data: map[string]interface{}{
|
||||
"response": httpResponse,
|
||||
},
|
||||
}
|
||||
|
||||
cubbyResp, err := c.router.Route(cubbyReq)
|
||||
if err != nil {
|
||||
// Revoke since it's not yet being tracked for expiration
|
||||
c.tokenStore.Revoke(te.ID)
|
||||
c.logger.Printf("[ERR] core: failed to store wrapped response information: %v", err)
|
||||
return nil, ErrInternalError
|
||||
}
|
||||
if cubbyResp != nil && cubbyResp.IsError() {
|
||||
c.tokenStore.Revoke(te.ID)
|
||||
c.logger.Printf("[ERR] core: failed to store wrapped response information: %v", cubbyResp.Data["error"])
|
||||
return cubbyResp, nil
|
||||
}
|
||||
|
||||
auth := &logical.Auth{
|
||||
ClientToken: te.ID,
|
||||
Policies: []string{"cubbyhole-response-wrapping"},
|
||||
LeaseOptions: logical.LeaseOptions{
|
||||
TTL: te.TTL,
|
||||
Renewable: false,
|
||||
},
|
||||
}
|
||||
|
||||
// Register the wrapped token with the expiration manager
|
||||
if err := c.expiration.RegisterAuth(te.Path, auth); err != nil {
|
||||
// Revoke since it's not yet being tracked for expiration
|
||||
c.tokenStore.Revoke(te.ID)
|
||||
c.logger.Printf("[ERR] core: failed to register cubbyhole wrapping token lease "+
|
||||
"(request path: %s): %v", req.Path, err)
|
||||
return nil, ErrInternalError
|
||||
cubbyResp, err := c.wrapInCubbyhole(req, resp)
|
||||
// If not successful, returns either an error response from the
|
||||
// cubbyhole backend or an error; if either is set, return
|
||||
if cubbyResp != nil || err != nil {
|
||||
return cubbyResp, err
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -176,11 +96,13 @@ func (c *Core) handleRequest(req *logical.Request) (retResp *logical.Response, r
|
|||
te, err = c.tokenStore.UseToken(te)
|
||||
if err != nil {
|
||||
c.logger.Printf("[ERR] core: failed to use token: %v", err)
|
||||
return nil, nil, ErrInternalError
|
||||
retErr = multierror.Append(retErr, ErrInternalError)
|
||||
return nil, nil, retErr
|
||||
}
|
||||
if te == nil {
|
||||
// Token has been revoked by this point
|
||||
return nil, nil, logical.ErrPermissionDenied
|
||||
retErr = multierror.Append(retErr, logical.ErrPermissionDenied)
|
||||
return nil, nil, retErr
|
||||
}
|
||||
if te.NumUses == -1 {
|
||||
// We defer a revocation until after logic has run, since this is a
|
||||
|
@ -192,7 +114,7 @@ func (c *Core) handleRequest(req *logical.Request) (retResp *logical.Response, r
|
|||
c.logger.Printf("[ERR] core: failed to revoke token: %v", err)
|
||||
retResp = nil
|
||||
retAuth = nil
|
||||
retErr = ErrInternalError
|
||||
retErr = multierror.Append(retErr, ErrInternalError)
|
||||
}
|
||||
if retResp != nil && retResp.Secret != nil &&
|
||||
// Some backends return a TTL even without a Lease ID
|
||||
|
@ -219,7 +141,10 @@ func (c *Core) handleRequest(req *logical.Request) (retResp *logical.Response, r
|
|||
req.Path, err)
|
||||
}
|
||||
|
||||
return logical.ErrorResponse(ctErr.Error()), nil, errType
|
||||
if errType != nil {
|
||||
retErr = multierror.Append(retErr, errType)
|
||||
}
|
||||
return logical.ErrorResponse(ctErr.Error()), nil, retErr
|
||||
}
|
||||
|
||||
// Attach the display name
|
||||
|
@ -229,11 +154,37 @@ func (c *Core) handleRequest(req *logical.Request) (retResp *logical.Response, r
|
|||
if err := c.auditBroker.LogRequest(auth, req, nil); err != nil {
|
||||
c.logger.Printf("[ERR] core: failed to audit request with path (%s): %v",
|
||||
req.Path, err)
|
||||
return nil, auth, ErrInternalError
|
||||
retErr = multierror.Append(retErr, ErrInternalError)
|
||||
return nil, auth, retErr
|
||||
}
|
||||
|
||||
// Route the request
|
||||
resp, err := c.router.Route(req)
|
||||
if resp != nil {
|
||||
// If either of the request or response requested wrapping, ensure that
|
||||
// the lowest value is what ends up in the response.
|
||||
switch {
|
||||
case req.WrapTTL == 0 && (resp.WrapInfo == nil || resp.WrapInfo.TTL == 0):
|
||||
// Neither defines it, so do nothing
|
||||
|
||||
case req.WrapTTL != 0 && (resp.WrapInfo != nil && resp.WrapInfo.TTL != 0):
|
||||
// Both define, so use the lowest
|
||||
if req.WrapTTL < resp.WrapInfo.TTL {
|
||||
resp.WrapInfo.TTL = req.WrapTTL
|
||||
}
|
||||
|
||||
case req.WrapTTL != 0:
|
||||
// Response wrap info doesn't exist, or its TTL is zero, so set
|
||||
// it to the request TTL
|
||||
resp.WrapInfo = &logical.WrapInfo{
|
||||
TTL: req.WrapTTL,
|
||||
}
|
||||
|
||||
default:
|
||||
// Only case left is that only resp defines it, which doesn't
|
||||
// need to be explicitly handled
|
||||
}
|
||||
}
|
||||
|
||||
// If there is a secret, we must register it with the expiration manager.
|
||||
// We exclude renewal of a lease, since it does not need to be re-registered
|
||||
|
@ -242,7 +193,8 @@ func (c *Core) handleRequest(req *logical.Request) (retResp *logical.Response, r
|
|||
sysView := c.router.MatchingSystemView(req.Path)
|
||||
if sysView == nil {
|
||||
c.logger.Println("[ERR] core: unable to retrieve system view from router")
|
||||
return nil, auth, ErrInternalError
|
||||
retErr = multierror.Append(retErr, ErrInternalError)
|
||||
return nil, auth, retErr
|
||||
}
|
||||
|
||||
// Apply the default lease if none given
|
||||
|
@ -262,7 +214,8 @@ func (c *Core) handleRequest(req *logical.Request) (retResp *logical.Response, r
|
|||
matchingBackend := c.router.MatchingBackend(req.Path)
|
||||
if matchingBackend == nil {
|
||||
c.logger.Println("[ERR] core: unable to retrieve generic backend from router")
|
||||
return nil, auth, ErrInternalError
|
||||
retErr = multierror.Append(retErr, ErrInternalError)
|
||||
return nil, auth, retErr
|
||||
}
|
||||
if ptbe, ok := matchingBackend.(*PassthroughBackend); ok {
|
||||
if !ptbe.GeneratesLeases() {
|
||||
|
@ -277,7 +230,8 @@ func (c *Core) handleRequest(req *logical.Request) (retResp *logical.Response, r
|
|||
c.logger.Printf(
|
||||
"[ERR] core: failed to register lease "+
|
||||
"(request path: %s): %v", req.Path, err)
|
||||
return nil, auth, ErrInternalError
|
||||
retErr = multierror.Append(retErr, ErrInternalError)
|
||||
return nil, auth, retErr
|
||||
}
|
||||
resp.Secret.LeaseID = leaseID
|
||||
}
|
||||
|
@ -291,7 +245,8 @@ func (c *Core) handleRequest(req *logical.Request) (retResp *logical.Response, r
|
|||
c.logger.Printf(
|
||||
"[ERR] core: unexpected Auth response for non-token backend "+
|
||||
"(request path: %s)", req.Path)
|
||||
return nil, auth, ErrInternalError
|
||||
retErr = multierror.Append(retErr, ErrInternalError)
|
||||
return nil, auth, retErr
|
||||
}
|
||||
|
||||
// Register with the expiration manager. We use the token's actual path
|
||||
|
@ -299,18 +254,23 @@ func (c *Core) handleRequest(req *logical.Request) (retResp *logical.Response, r
|
|||
te, err := c.tokenStore.Lookup(resp.Auth.ClientToken)
|
||||
if err != nil {
|
||||
c.logger.Printf("[ERR] core: failed to lookup token: %v", err)
|
||||
return nil, nil, ErrInternalError
|
||||
retErr = multierror.Append(retErr, ErrInternalError)
|
||||
return nil, nil, retErr
|
||||
}
|
||||
|
||||
if err := c.expiration.RegisterAuth(te.Path, resp.Auth); err != nil {
|
||||
c.logger.Printf("[ERR] core: failed to register token lease "+
|
||||
"(request path: %s): %v", req.Path, err)
|
||||
return nil, auth, ErrInternalError
|
||||
retErr = multierror.Append(retErr, ErrInternalError)
|
||||
return nil, auth, retErr
|
||||
}
|
||||
}
|
||||
|
||||
// Return the response and error
|
||||
return resp, auth, err
|
||||
if err != nil {
|
||||
retErr = multierror.Append(retErr, err)
|
||||
}
|
||||
return resp, auth, retErr
|
||||
}
|
||||
|
||||
// handleLoginRequest is used to handle a login request, which is an
|
||||
|
@ -423,3 +383,68 @@ func (c *Core) handleLoginRequest(req *logical.Request) (*logical.Response, *log
|
|||
|
||||
return resp, auth, err
|
||||
}
|
||||
|
||||
func (c *Core) wrapInCubbyhole(req *logical.Request, resp *logical.Response) (*logical.Response, error) {
|
||||
// If we are wrapping, the first part (performed in this functions) happens
|
||||
// before auditing so that resp.WrapInfo.Token can contain the HMAC'd
|
||||
// wrapping token ID in the audit logs, so that it can be determined from
|
||||
// the audit logs whether the token was ever actually used.
|
||||
te := TokenEntry{
|
||||
Path: req.Path,
|
||||
Policies: []string{"cubbyhole-response-wrapping"},
|
||||
CreationTime: time.Now().Unix(),
|
||||
TTL: resp.WrapInfo.TTL,
|
||||
NumUses: 1,
|
||||
}
|
||||
|
||||
if err := c.tokenStore.create(&te); err != nil {
|
||||
c.logger.Printf("[ERR] core: failed to create wrapping token: %v", err)
|
||||
return nil, ErrInternalError
|
||||
}
|
||||
|
||||
resp.WrapInfo.Token = te.ID
|
||||
|
||||
httpResponse := logical.SanitizeResponse(resp)
|
||||
|
||||
cubbyReq := &logical.Request{
|
||||
Operation: logical.CreateOperation,
|
||||
Path: "cubbyhole/response",
|
||||
ClientToken: te.ID,
|
||||
Data: map[string]interface{}{
|
||||
"response": httpResponse,
|
||||
},
|
||||
}
|
||||
|
||||
cubbyResp, err := c.router.Route(cubbyReq)
|
||||
if err != nil {
|
||||
// Revoke since it's not yet being tracked for expiration
|
||||
c.tokenStore.Revoke(te.ID)
|
||||
c.logger.Printf("[ERR] core: failed to store wrapped response information: %v", err)
|
||||
return nil, ErrInternalError
|
||||
}
|
||||
if cubbyResp != nil && cubbyResp.IsError() {
|
||||
c.tokenStore.Revoke(te.ID)
|
||||
c.logger.Printf("[ERR] core: failed to store wrapped response information: %v", cubbyResp.Data["error"])
|
||||
return cubbyResp, nil
|
||||
}
|
||||
|
||||
auth := &logical.Auth{
|
||||
ClientToken: te.ID,
|
||||
Policies: []string{"cubbyhole-response-wrapping"},
|
||||
LeaseOptions: logical.LeaseOptions{
|
||||
TTL: te.TTL,
|
||||
Renewable: false,
|
||||
},
|
||||
}
|
||||
|
||||
// Register the wrapped token with the expiration manager
|
||||
if err := c.expiration.RegisterAuth(te.Path, auth); err != nil {
|
||||
// Revoke since it's not yet being tracked for expiration
|
||||
c.tokenStore.Revoke(te.ID)
|
||||
c.logger.Printf("[ERR] core: failed to register cubbyhole wrapping token lease "+
|
||||
"(request path: %s): %v", req.Path, err)
|
||||
return nil, ErrInternalError
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
|
|
@ -0,0 +1,111 @@
|
|||
package vault
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-uuid"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
)
|
||||
|
||||
func TestRequestHandling_Wrapping(t *testing.T) {
|
||||
core, _, root := TestCoreUnsealed(t)
|
||||
|
||||
n := &NoopBackend{}
|
||||
|
||||
core.logicalBackends["noop"] = func(config *logical.BackendConfig) (logical.Backend, error) {
|
||||
return n, nil
|
||||
}
|
||||
|
||||
meUUID, _ := uuid.GenerateUUID()
|
||||
err := core.mount(&MountEntry{
|
||||
UUID: meUUID,
|
||||
Path: "wraptest",
|
||||
Type: "noop",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
// No duration specified
|
||||
req := &logical.Request{
|
||||
Path: "wraptest/foo",
|
||||
ClientToken: root,
|
||||
Operation: logical.UpdateOperation,
|
||||
}
|
||||
resp, err := core.HandleRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if resp != nil {
|
||||
t.Fatalf("bad: %#v", resp)
|
||||
}
|
||||
|
||||
// Just in the request
|
||||
req = &logical.Request{
|
||||
Path: "wraptest/foo",
|
||||
ClientToken: root,
|
||||
Operation: logical.UpdateOperation,
|
||||
WrapTTL: time.Duration(15 * time.Second),
|
||||
}
|
||||
resp, err = core.HandleRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if resp == nil {
|
||||
t.Fatalf("bad: %v", resp)
|
||||
}
|
||||
if resp.WrapInfo == nil || resp.WrapInfo.TTL != time.Duration(15*time.Second) {
|
||||
t.Fatalf("bad: %#v", resp)
|
||||
}
|
||||
|
||||
// Just in the response
|
||||
n.WrapTTL = time.Duration(15 * time.Second)
|
||||
req = &logical.Request{
|
||||
Path: "wraptest/foo",
|
||||
ClientToken: root,
|
||||
Operation: logical.UpdateOperation,
|
||||
}
|
||||
resp, err = core.HandleRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if resp == nil {
|
||||
t.Fatalf("bad: %v", resp)
|
||||
}
|
||||
if resp.WrapInfo == nil || resp.WrapInfo.TTL != time.Duration(15*time.Second) {
|
||||
t.Fatalf("bad: %#v", resp)
|
||||
}
|
||||
|
||||
// In both, with request less
|
||||
n.WrapTTL = time.Duration(15 * time.Second)
|
||||
req = &logical.Request{
|
||||
Path: "wraptest/foo",
|
||||
ClientToken: root,
|
||||
Operation: logical.UpdateOperation,
|
||||
WrapTTL: time.Duration(10 * time.Second),
|
||||
}
|
||||
resp, err = core.HandleRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if resp == nil {
|
||||
t.Fatalf("bad: %v", resp)
|
||||
}
|
||||
if resp.WrapInfo == nil || resp.WrapInfo.TTL != time.Duration(10*time.Second) {
|
||||
t.Fatalf("bad: %#v", resp)
|
||||
}
|
||||
|
||||
// In both, with response less
|
||||
n.WrapTTL = time.Duration(10 * time.Second)
|
||||
req = &logical.Request{
|
||||
Path: "wraptest/foo",
|
||||
ClientToken: root,
|
||||
Operation: logical.UpdateOperation,
|
||||
WrapTTL: time.Duration(15 * time.Second),
|
||||
}
|
||||
resp, err = core.HandleRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
}
|
|
@ -194,7 +194,7 @@ func (r *Router) RouteExistenceCheck(req *logical.Request) (bool, bool, error) {
|
|||
return ok, exists, err
|
||||
}
|
||||
|
||||
func (r *Router) routeCommon(req *logical.Request, existenceCheck bool) (resp *logical.Response, ok bool, exists bool, err error) {
|
||||
func (r *Router) routeCommon(req *logical.Request, existenceCheck bool) (*logical.Response, bool, bool, error) {
|
||||
// Find the mount point
|
||||
r.l.RLock()
|
||||
mount, raw, ok := r.root.LongestPrefix(req.Path)
|
||||
|
@ -250,37 +250,8 @@ func (r *Router) routeCommon(req *logical.Request, existenceCheck bool) (resp *l
|
|||
|
||||
// Reset the request before returning
|
||||
defer func() {
|
||||
// We only run this if resp is not nil, so for instance we don't during
|
||||
// an existence check
|
||||
if resp != nil {
|
||||
// If either of the request or response requested wrapping, ensure that
|
||||
// the lowest value is what ends up in the response.
|
||||
switch {
|
||||
case req.WrapTTL == 0 && (resp.WrapInfo == nil || resp.WrapInfo.TTL == 0):
|
||||
// Neither defines it, so do nothing
|
||||
|
||||
case req.WrapTTL != 0 && (resp.WrapInfo != nil && resp.WrapInfo.TTL != 0):
|
||||
// Both define, so use the lowest
|
||||
if req.WrapTTL < resp.WrapInfo.TTL {
|
||||
resp.WrapInfo.TTL = req.WrapTTL
|
||||
}
|
||||
|
||||
case req.WrapTTL != 0:
|
||||
// Response wrap info doesn't exist, or its TTL is zero, so set
|
||||
// it to the request TTL
|
||||
resp.WrapInfo = &logical.WrapInfo{
|
||||
TTL: req.WrapTTL,
|
||||
}
|
||||
|
||||
default:
|
||||
// Only case left is that only resp defines it, which doesn't
|
||||
// need to be explicitly handled
|
||||
}
|
||||
}
|
||||
|
||||
// Reset other parameters
|
||||
req.MountPoint = ""
|
||||
req.Path = original
|
||||
req.MountPoint = ""
|
||||
req.Connection = originalConn
|
||||
req.Storage = nil
|
||||
req.ClientToken = clientToken
|
||||
|
@ -291,7 +262,7 @@ func (r *Router) routeCommon(req *logical.Request, existenceCheck bool) (resp *l
|
|||
ok, exists, err := re.backend.HandleExistenceCheck(req)
|
||||
return nil, ok, exists, err
|
||||
} else {
|
||||
resp, err = re.backend.HandleRequest(req)
|
||||
resp, err := re.backend.HandleRequest(req)
|
||||
return resp, false, false, err
|
||||
}
|
||||
}
|
||||
|
|
|
@ -384,111 +384,3 @@ func TestPathsToRadix(t *testing.T) {
|
|||
t.Fatalf("bad: %v (sub/bar)", raw)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_Wrapping(t *testing.T) {
|
||||
core, _, root := TestCoreUnsealed(t)
|
||||
|
||||
n := &NoopBackend{}
|
||||
|
||||
core.logicalBackends["noop"] = func(config *logical.BackendConfig) (logical.Backend, error) {
|
||||
return n, nil
|
||||
}
|
||||
|
||||
meUUID, _ := uuid.GenerateUUID()
|
||||
err := core.mount(&MountEntry{
|
||||
UUID: meUUID,
|
||||
Path: "wraptest",
|
||||
Type: "noop",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
// No duration specified
|
||||
req := &logical.Request{
|
||||
Path: "wraptest/foo",
|
||||
ClientToken: root,
|
||||
Operation: logical.UpdateOperation,
|
||||
}
|
||||
resp, err := core.HandleRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if resp != nil {
|
||||
t.Fatalf("bad: %#v", resp)
|
||||
}
|
||||
|
||||
// Just in the request
|
||||
req = &logical.Request{
|
||||
Path: "wraptest/foo",
|
||||
ClientToken: root,
|
||||
Operation: logical.UpdateOperation,
|
||||
WrapTTL: time.Duration(15 * time.Second),
|
||||
}
|
||||
resp, err = core.HandleRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if resp == nil {
|
||||
t.Fatalf("bad: %v", resp)
|
||||
}
|
||||
if resp.WrapInfo == nil || resp.WrapInfo.TTL != time.Duration(15*time.Second) {
|
||||
t.Fatalf("bad: %#v", resp)
|
||||
}
|
||||
|
||||
// Just in the response
|
||||
n.WrapTTL = time.Duration(15 * time.Second)
|
||||
req = &logical.Request{
|
||||
Path: "wraptest/foo",
|
||||
ClientToken: root,
|
||||
Operation: logical.UpdateOperation,
|
||||
}
|
||||
resp, err = core.HandleRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if resp == nil {
|
||||
t.Fatalf("bad: %v", resp)
|
||||
}
|
||||
if resp.WrapInfo == nil || resp.WrapInfo.TTL != time.Duration(15*time.Second) {
|
||||
t.Fatalf("bad: %#v", resp)
|
||||
}
|
||||
|
||||
// In both, with request less
|
||||
n.WrapTTL = time.Duration(15 * time.Second)
|
||||
req = &logical.Request{
|
||||
Path: "wraptest/foo",
|
||||
ClientToken: root,
|
||||
Operation: logical.UpdateOperation,
|
||||
WrapTTL: time.Duration(10 * time.Second),
|
||||
}
|
||||
resp, err = core.HandleRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if resp == nil {
|
||||
t.Fatalf("bad: %v", resp)
|
||||
}
|
||||
if resp.WrapInfo == nil || resp.WrapInfo.TTL != time.Duration(10*time.Second) {
|
||||
t.Fatalf("bad: %#v", resp)
|
||||
}
|
||||
|
||||
// In both, with response less
|
||||
n.WrapTTL = time.Duration(10 * time.Second)
|
||||
req = &logical.Request{
|
||||
Path: "wraptest/foo",
|
||||
ClientToken: root,
|
||||
Operation: logical.UpdateOperation,
|
||||
WrapTTL: time.Duration(15 * time.Second),
|
||||
}
|
||||
resp, err = core.HandleRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if resp == nil {
|
||||
t.Fatalf("bad: %v", resp)
|
||||
}
|
||||
if resp.WrapInfo == nil || resp.WrapInfo.TTL != time.Duration(10*time.Second) {
|
||||
t.Fatalf("bad: %#v", resp)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -92,14 +92,11 @@ func NewTokenStore(c *Core, config *logical.BackendConfig) (*TokenStore, error)
|
|||
t.salt = salt
|
||||
|
||||
t.tokenLocks = map[string]*sync.RWMutex{}
|
||||
for i := int64(0); i < 16; i++ {
|
||||
for j := int64(0); j < 16; j++ {
|
||||
t.tokenLocks[fmt.Sprintf("%s%s",
|
||||
strconv.FormatInt(i, 16),
|
||||
strconv.FormatInt(j, 16))] = &sync.RWMutex{}
|
||||
}
|
||||
for i := int64(0); i < 256; i++ {
|
||||
t.tokenLocks[fmt.Sprintf("%2x",
|
||||
strconv.FormatInt(i, 16))] = &sync.RWMutex{}
|
||||
}
|
||||
t.tokenLocks["global"] = &sync.RWMutex{}
|
||||
t.tokenLocks["custom"] = &sync.RWMutex{}
|
||||
|
||||
// Setup the framework endpoints
|
||||
t.Backend = &framework.Backend{
|
||||
|
@ -580,7 +577,8 @@ func (ts *TokenStore) getTokenLock(id string) *sync.RWMutex {
|
|||
lock, ok = ts.tokenLocks[id[0:2]]
|
||||
}
|
||||
if !ok || lock == nil {
|
||||
lock = ts.tokenLocks["global"]
|
||||
// Fall back for custom token IDs
|
||||
lock = ts.tokenLocks["custom"]
|
||||
}
|
||||
|
||||
return lock
|
||||
|
@ -614,20 +612,20 @@ func (ts *TokenStore) UseToken(te *TokenEntry) (*TokenEntry, error) {
|
|||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to refresh entry: %v", err)
|
||||
}
|
||||
// If it can't be found we shouldn't be trying to use it, so if we get nil
|
||||
// back, it is because it has been revoked in the interim or will be
|
||||
// revoked (NumUses is -1)
|
||||
if te == nil {
|
||||
return nil, fmt.Errorf("token entry nil after refreshing to decrement use count; token has likely been used already")
|
||||
}
|
||||
|
||||
// If the token is already being revoked, return nil to indicate that it's
|
||||
// no longer valid
|
||||
if te.NumUses == -1 {
|
||||
return nil, nil
|
||||
return nil, fmt.Errorf("token not found or fully used already")
|
||||
}
|
||||
|
||||
// Decrement the count. If this is our last use count, we need to indicate
|
||||
// that this is no longer valid, but revocation is deferred to the end of
|
||||
// the call, so this will make sure that any Lookup that happens doesn't
|
||||
// return an entry.
|
||||
// return an entry. This essentially acts as a write-ahead lock and is
|
||||
// especially useful since revocation can end up (via the expiration
|
||||
// manager revoking children) attempting to acquire the same lock
|
||||
// repeatedly.
|
||||
if te.NumUses == 1 {
|
||||
te.NumUses = -1
|
||||
} else {
|
||||
|
|
Loading…
Reference in New Issue