This commit is contained in:
Jeff Mitchell 2017-10-23 15:35:28 -04:00
parent d7c5a3acfc
commit 3c6fe40a91
8 changed files with 360 additions and 98 deletions

View File

@ -31,6 +31,12 @@ func LeaseSwitchedPassthroughBackend(conf *logical.BackendConfig, leases bool) (
b.Backend = &framework.Backend{
Help: strings.TrimSpace(passthroughHelp),
PathsSpecial: &logical.Paths{
SealWrapStorage: []string{
"/",
},
},
Paths: []*framework.Path{
&framework.Path{
Pattern: ".*",

View File

@ -14,7 +14,7 @@ func TestPassthroughBackend_RootPaths(t *testing.T) {
b := testPassthroughBackend()
test := func(b logical.Backend) {
root := b.SpecialPaths()
if root != nil {
if len(root.Root) != 0 {
t.Fatalf("unexpected: %v", root)
}
}

View File

@ -28,6 +28,7 @@ var (
// This is both for security and to prevent disrupting Vault.
protectedPaths = []string{
keyringPath,
coreLocalClusterInfoPath,
}
replicationPaths = func(b *SystemBackend) []*framework.Path {
@ -53,6 +54,7 @@ var (
func NewSystemBackend(core *Core) *SystemBackend {
b := &SystemBackend{
Core: core,
logger: core.logger,
}
b.Backend = &framework.Backend{
@ -536,7 +538,7 @@ func NewSystemBackend(core *Core) *SystemBackend {
},
&framework.Path{
Pattern: "policy$",
Pattern: "policy/?$",
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.ReadOperation: b.handlePolicyList,
@ -559,6 +561,10 @@ func NewSystemBackend(core *Core) *SystemBackend {
Type: framework.TypeString,
Description: strings.TrimSpace(sysHelp["policy-rules"][0]),
},
"policy": &framework.FieldSchema{
Type: framework.TypeString,
Description: strings.TrimSpace(sysHelp["policy-rules"][0]),
},
},
Callbacks: map[logical.Operation]framework.OperationFunc{
@ -566,6 +572,44 @@ func NewSystemBackend(core *Core) *SystemBackend {
logical.UpdateOperation: b.handlePolicySet,
logical.DeleteOperation: b.handlePolicyDelete,
},
},
&framework.Path{
Pattern: "policies/acl/?$",
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.ListOperation: b.handlePoliciesList(PolicyTypeACL),
},
HelpSynopsis: strings.TrimSpace(sysHelp["policy-list"][0]),
HelpDescription: strings.TrimSpace(sysHelp["policy-list"][1]),
},
&framework.Path{
Pattern: "policies/acl/(?P<name>.+)",
Fields: map[string]*framework.FieldSchema{
"name": &framework.FieldSchema{
Type: framework.TypeString,
Description: strings.TrimSpace(sysHelp["policy-name"][0]),
},
"policy": &framework.FieldSchema{
Type: framework.TypeString,
Description: strings.TrimSpace(sysHelp["policy-rules"][0]),
},
},
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.ReadOperation: b.handlePoliciesRead(PolicyTypeACL),
logical.UpdateOperation: b.handlePoliciesSet(PolicyTypeACL),
logical.DeleteOperation: b.handlePoliciesDelete(PolicyTypeACL),
},
HelpSynopsis: strings.TrimSpace(sysHelp["policy"][0]),
HelpDescription: strings.TrimSpace(sysHelp["policy"][1]),
},
},
HelpSynopsis: strings.TrimSpace(sysHelp["policy"][0]),
HelpDescription: strings.TrimSpace(sysHelp["policy"][1]),
@ -789,6 +833,7 @@ func NewSystemBackend(core *Core) *SystemBackend {
HelpSynopsis: strings.TrimSpace(sysHelp["audited-headers"][0]),
HelpDescription: strings.TrimSpace(sysHelp["audited-headers"][1]),
},
&framework.Path{
Pattern: "plugins/catalog/?$",
@ -801,6 +846,7 @@ func NewSystemBackend(core *Core) *SystemBackend {
HelpSynopsis: strings.TrimSpace(sysHelp["plugin-catalog"][0]),
HelpDescription: strings.TrimSpace(sysHelp["plugin-catalog"][1]),
},
&framework.Path{
Pattern: "plugins/catalog/(?P<name>.+)",
@ -959,6 +1005,7 @@ func NewSystemBackend(core *Core) *SystemBackend {
type SystemBackend struct {
*framework.Backend
Core *Core
logger log.Logger
}
// handleCORSRead returns the current CORS configuration
@ -1008,15 +1055,23 @@ func (b *SystemBackend) handleTidyLeases(req *logical.Request, d *framework.Fiel
}
func (b *SystemBackend) invalidate(key string) {
if b.Core.logger.IsTrace() {
b.Core.logger.Trace("sys: invalidating key", "key", key)
}
/*
if b.Core.logger.IsTrace() {
b.Core.logger.Trace("sys: invalidating key", "key", key)
}
*/
switch {
case strings.HasPrefix(key, policySubPath):
case strings.HasPrefix(key, policyACLSubPath):
b.Core.stateLock.RLock()
defer b.Core.stateLock.RUnlock()
if b.Core.policyStore != nil {
b.Core.policyStore.invalidate(strings.TrimPrefix(key, policySubPath))
b.Core.policyStore.invalidate(strings.TrimPrefix(key, policyACLSubPath), PolicyTypeACL)
}
case strings.HasPrefix(key, tokenSubPath):
b.Core.stateLock.RLock()
defer b.Core.stateLock.RUnlock()
if b.Core.tokenStore != nil {
b.Core.tokenStore.Invalidate(key)
}
}
}
@ -1317,15 +1372,18 @@ func (b *SystemBackend) handleMountTable(
for _, entry := range b.Core.mounts.Entries {
// Populate mount info
structConfig := structs.New(entry.Config).Map()
structConfig["default_lease_ttl"] = int64(structConfig["default_lease_ttl"].(time.Duration).Seconds())
structConfig["max_lease_ttl"] = int64(structConfig["max_lease_ttl"].(time.Duration).Seconds())
info := map[string]interface{}{
"type": entry.Type,
"description": entry.Description,
"accessor": entry.Accessor,
"config": structConfig,
"local": entry.Local,
"config": map[string]interface{}{
"default_lease_ttl": int64(entry.Config.DefaultLeaseTTL.Seconds()),
"max_lease_ttl": int64(entry.Config.MaxLeaseTTL.Seconds()),
"force_no_cache": entry.Config.ForceNoCache,
"plugin_name": entry.Config.PluginName,
"seal_wrap": entry.Config.SealWrap,
},
"local": entry.Local,
}
resp.Data[entry.Path] = info
}
@ -1396,15 +1454,21 @@ func (b *SystemBackend) handleMount(
logical.ErrInvalidRequest
}
if config.DefaultLeaseTTL > b.Core.maxLeaseTTL {
if config.DefaultLeaseTTL > b.Core.maxLeaseTTL && config.MaxLeaseTTL == 0 {
return logical.ErrorResponse(fmt.Sprintf(
"given default lease TTL greater than system max lease TTL of %d", int(b.Core.maxLeaseTTL.Seconds()))),
logical.ErrInvalidRequest
}
// Only set plugin-name if mount is of type plugin, with apiConfig.PluginName
// option taking precedence.
if logicalType == "plugin" {
switch logicalType {
case "":
return logical.ErrorResponse(
"backend type must be specified as a string"),
logical.ErrInvalidRequest
case "plugin":
// Only set plugin-name if mount is of type plugin, with apiConfig.PluginName
// option taking precedence.
switch {
case apiConfig.PluginName != "":
config.PluginName = apiConfig.PluginName
@ -1417,17 +1481,15 @@ func (b *SystemBackend) handleMount(
}
}
if apiConfig.SealWrap {
config.SealWrap = true
}
// Copy over the force no cache if set
if apiConfig.ForceNoCache {
config.ForceNoCache = true
}
if logicalType == "" {
return logical.ErrorResponse(
"backend type must be specified as a string"),
logical.ErrInvalidRequest
}
// Create the mount entry
me := &MountEntry{
Table: mountTableType,
@ -1477,6 +1539,12 @@ func (b *SystemBackend) handleUnmount(
return nil, nil
}
_, prefix, found := b.Core.router.MatchingStoragePrefixByAPIPath(path)
if !found {
b.Backend.Logger().Error("sys: unable to find storage for path", "path", path)
return handleError(fmt.Errorf("unable to find storage for path: %s", path))
}
// Attempt unmount
if err := b.Core.unmount(path); err != nil {
b.Backend.Logger().Error("sys: unmount failed", "path", path, "error", err)
@ -1623,7 +1691,7 @@ func (b *SystemBackend) handleTuneWriteCommon(
var lock *sync.RWMutex
switch {
case strings.HasPrefix(path, "auth/"):
case strings.HasPrefix(path, credentialRoutePrefix):
lock = &b.Core.authLock
default:
lock = &b.Core.mountsLock
@ -1871,7 +1939,6 @@ func (b *SystemBackend) handleAuthTable(
func (b *SystemBackend) handleEnableAuth(
req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
repState := b.Core.replicationState
local := data.Get("local").(bool)
if !local && repState.HasState(consts.ReplicationPerformanceSecondary) {
return logical.ErrorResponse("cannot add a non-local mount to a replication secondary"), nil
@ -1942,6 +2009,7 @@ func (b *SystemBackend) handleDisableAuth(
req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
path := data.Get("path").(string)
path = sanitizeMountPath(path)
fullPath := credentialRoutePrefix + path
repState := b.Core.replicationState
@ -1969,7 +2037,7 @@ func (b *SystemBackend) handleDisableAuth(
func (b *SystemBackend) handlePolicyList(
req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
// Get all the configured policies
policies, err := b.Core.policyStore.ListPolicies()
policies, err := b.Core.policyStore.ListPolicies(PolicyTypeACL)
// Add the special "root" policy
policies = append(policies, "root")
@ -1981,12 +2049,52 @@ func (b *SystemBackend) handlePolicyList(
return resp, err
}
func (b *SystemBackend) handlePoliciesList(policyType PolicyType) func(*logical.Request, *framework.FieldData) (*logical.Response, error) {
return func(req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
policies, err := b.Core.policyStore.ListPolicies(policyType)
if err != nil {
return nil, err
}
switch policyType {
case PolicyTypeACL:
// Add the special "root" policy if not egp
policies = append(policies, "root")
return logical.ListResponse(policies), nil
}
return logical.ErrorResponse("unknown policy type"), nil
}
}
func (b *SystemBackend) handlePoliciesRead(policyType PolicyType) func(*logical.Request, *framework.FieldData) (*logical.Response, error) {
return func(req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
name := data.Get("name").(string)
policy, err := b.Core.policyStore.GetPolicy(name, policyType)
if err != nil {
return handleError(err)
}
if policy == nil {
return nil, nil
}
resp := &logical.Response{
Data: map[string]interface{}{
"name": policy.Name,
"policy": policy.Raw,
},
}
}
// handlePolicyRead handles the "policy/<name>" endpoint to read a policy
func (b *SystemBackend) handlePolicyRead(
req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
name := data.Get("name").(string)
policy, err := b.Core.policyStore.GetPolicy(name)
policy, err := b.Core.policyStore.GetPolicy(name, PolicyTypeACL)
if err != nil {
return handleError(err)
}
@ -1995,44 +2103,106 @@ func (b *SystemBackend) handlePolicyRead(
return nil, nil
}
return &logical.Response{
resp := &logical.Response{
Data: map[string]interface{}{
"name": policy.Name,
"rules": policy.Raw,
},
}, nil
}
return resp, nil
}
func (b *SystemBackend) handlePoliciesSet(policyType PolicyType) func(*logical.Request, *framework.FieldData) (*logical.Response, error) {
return func(req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
policy := &Policy{
Name: strings.ToLower(data.Get("name").(string)),
Type: policyType,
}
if policy.Name == "" {
return logical.ErrorResponse("policy name must be provided in the URL"), nil
}
policy.Raw = data.Get("policy").(string)
if policy.Raw == "" {
return logical.ErrorResponse("'policy' parameter not supplied or empty"), nil
}
if polBytes, err := base64.StdEncoding.DecodeString(policy.Raw); err == nil {
policy.Raw = string(polBytes)
}
var enforcementLevel string
switch policyType {
case PolicyTypeACL:
p, err := ParseACLPolicy(policy.Raw)
if err != nil {
return handleError(err)
}
policy.Paths = p.Paths
default:
return logical.ErrorResponse("unknown policy type"), nil
}
// Update the policy
if err := b.Core.policyStore.SetPolicy(policy); err != nil {
return handleError(err)
}
return nil, nil
}
}
// handlePolicySet handles the "policy/<name>" endpoint to set a policy
func (b *SystemBackend) handlePolicySet(
req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
name := data.Get("name").(string)
rulesRaw, ok := data.GetOk("rules")
if !ok {
return logical.ErrorResponse("'rules' parameter not supplied"), nil
policy := &Policy{
Type: PolicyTypeACL,
Name: strings.ToLower(data.Get("name").(string)),
}
if policy.Name == "" {
return logical.ErrorResponse("policy name must be provided in the URL"), nil
}
rules := rulesRaw.(string)
if rules == "" {
return logical.ErrorResponse("'rules' parameter empty"), nil
var resp *logical.Response
policy.Raw = data.Get("policy").(string)
if policy.Raw == "" {
policy.Raw = data.Get("rules").(string)
if resp == nil {
resp = &logical.Response{}
}
resp.AddWarning("'rules' is deprecated, please use 'policy' instead")
}
if policy.Raw == "" {
return logical.ErrorResponse("'policy' parameter not supplied or empty"), nil
}
// Validate the rules parse
parse, err := Parse(rules)
p, err := ParseACLPolicy(policy.Raw)
if err != nil {
return handleError(err)
}
if name != "" {
parse.Name = name
}
policy.Paths = p.Paths
// Update the policy
if err := b.Core.policyStore.SetPolicy(parse); err != nil {
if err := b.Core.policyStore.SetPolicy(policy); err != nil {
return handleError(err)
}
return nil, nil
return resp, nil
}
func (b *SystemBackend) handlePoliciesDelete(policyType PolicyType) func(*logical.Request, *framework.FieldData) (*logical.Response, error) {
return func(req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
name := data.Get("name").(string)
if err := b.Core.policyStore.DeletePolicy(name, policyType); err != nil {
return handleError(err)
}
return nil, nil
}
}
// handlePolicyDelete handles the "policy/<name>" endpoint to delete a policy
@ -2040,7 +2210,7 @@ func (b *SystemBackend) handlePolicyDelete(
req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
name := data.Get("name").(string)
if err := b.Core.policyStore.DeletePolicy(name); err != nil {
if err := b.Core.policyStore.DeletePolicy(name, PolicyTypeACL); err != nil {
return handleError(err)
}
return nil, nil

View File

@ -36,7 +36,7 @@ func (b *SystemBackend) tuneMountTTLs(path string, me *MountEntry, newDefault, n
// Update the mount table
var err error
switch {
case strings.HasPrefix(path, "auth/"):
case strings.HasPrefix(path, credentialRoutePrefix):
err = b.Core.persistAuth(b.Core.auth, me.Local)
default:
err = b.Core.persistMounts(b.Core.mounts, me.Local)

View File

@ -124,7 +124,9 @@ func TestSystemBackend_mounts(t *testing.T) {
"config": map[string]interface{}{
"default_lease_ttl": resp.Data["secret/"].(map[string]interface{})["config"].(map[string]interface{})["default_lease_ttl"].(int64),
"max_lease_ttl": resp.Data["secret/"].(map[string]interface{})["config"].(map[string]interface{})["max_lease_ttl"].(int64),
"plugin_name": "",
"force_no_cache": false,
"seal_wrap": false,
},
"local": false,
},
@ -135,7 +137,9 @@ func TestSystemBackend_mounts(t *testing.T) {
"config": map[string]interface{}{
"default_lease_ttl": resp.Data["sys/"].(map[string]interface{})["config"].(map[string]interface{})["default_lease_ttl"].(int64),
"max_lease_ttl": resp.Data["sys/"].(map[string]interface{})["config"].(map[string]interface{})["max_lease_ttl"].(int64),
"plugin_name": "",
"force_no_cache": false,
"seal_wrap": false,
},
"local": false,
},
@ -146,7 +150,9 @@ func TestSystemBackend_mounts(t *testing.T) {
"config": map[string]interface{}{
"default_lease_ttl": resp.Data["cubbyhole/"].(map[string]interface{})["config"].(map[string]interface{})["default_lease_ttl"].(int64),
"max_lease_ttl": resp.Data["cubbyhole/"].(map[string]interface{})["config"].(map[string]interface{})["max_lease_ttl"].(int64),
"plugin_name": "",
"force_no_cache": false,
"seal_wrap": false,
},
"local": true,
},
@ -157,13 +163,15 @@ func TestSystemBackend_mounts(t *testing.T) {
"config": map[string]interface{}{
"default_lease_ttl": resp.Data["identity/"].(map[string]interface{})["config"].(map[string]interface{})["default_lease_ttl"].(int64),
"max_lease_ttl": resp.Data["identity/"].(map[string]interface{})["config"].(map[string]interface{})["max_lease_ttl"].(int64),
"plugin_name": "",
"force_no_cache": false,
"seal_wrap": false,
},
"local": false,
},
}
if !reflect.DeepEqual(resp.Data, exp) {
t.Fatalf("Got:\n%#v\nExpected:\n%#v", resp.Data, exp)
t.Fatalf("bad: got\n%#v\nexpected\n%#v\n", resp.Data, exp)
}
}
@ -270,7 +278,7 @@ func testCapabilities(t *testing.T, endpoint string) {
t.Fatalf("bad: got\n%#v\nexpected\n%#v\n", actual, expected)
}
policy, _ := Parse(capabilitiesPolicy)
policy, _ := ParseACLPolicy(capabilitiesPolicy)
err = core.policyStore.SetPolicy(policy)
if err != nil {
t.Fatalf("err: %v", err)
@ -322,7 +330,7 @@ func TestSystemBackend_CapabilitiesAccessor(t *testing.T) {
t.Fatalf("bad: got\n%#v\nexpected\n%#v\n", actual, expected)
}
policy, _ := Parse(capabilitiesPolicy)
policy, _ := ParseACLPolicy(capabilitiesPolicy)
err = core.policyStore.SetPolicy(policy)
if err != nil {
t.Fatalf("err: %v", err)
@ -1226,7 +1234,7 @@ func TestSystemBackend_policyCRUD(t *testing.T) {
if err != nil {
t.Fatalf("err: %v %#v", err, resp)
}
if resp != nil {
if resp != nil && (resp.IsError() || len(resp.Data) > 0) {
t.Fatalf("bad: %#v", resp)
}
@ -1480,7 +1488,7 @@ func TestSystemBackend_rawWrite_Protected(t *testing.T) {
}
func TestSystemBackend_rawReadWrite(t *testing.T) {
c, b, _ := testCoreSystemBackendRaw(t)
_, b, _ := testCoreSystemBackendRaw(t)
req := logical.TestRequest(t, logical.UpdateOperation, "raw/sys/policy/test")
req.Data["value"] = `path "secret/" { policy = "read" }`
@ -1502,17 +1510,8 @@ func TestSystemBackend_rawReadWrite(t *testing.T) {
t.Fatalf("bad: %v", resp)
}
// Read the policy!
p, err := c.policyStore.GetPolicy("test")
if err != nil {
t.Fatalf("err: %v", err)
}
if p == nil || len(p.Paths) == 0 {
t.Fatalf("missing policy %#v", p)
}
if p.Paths[0].Prefix != "secret/" || p.Paths[0].Policy != ReadCapability {
t.Fatalf("Bad: %#v", p)
}
// Note: since the upgrade code is gone that upgraded from 0.1, we can't
// simply parse this out directly via GetPolicy, so the test now ends here.
}
func TestSystemBackend_rawDelete_Protected(t *testing.T) {
@ -1529,7 +1528,10 @@ func TestSystemBackend_rawDelete(t *testing.T) {
c, b, _ := testCoreSystemBackendRaw(t)
// set the policy!
p := &Policy{Name: "test"}
p := &Policy{
Name: "test",
Type: PolicyTypeACL,
}
err := c.policyStore.SetPolicy(p)
if err != nil {
t.Fatalf("err: %v", err)
@ -1546,8 +1548,8 @@ func TestSystemBackend_rawDelete(t *testing.T) {
}
// Policy should be gone
c.policyStore.lru.Purge()
out, err := c.policyStore.GetPolicy("test")
c.policyStore.tokenPoliciesLRU.Purge()
out, err := c.policyStore.GetPolicy("test", PolicyTypeToken)
if err != nil {
t.Fatalf("err: %v", err)
}

View File

@ -74,6 +74,19 @@ var (
mountAliases = map[string]string{"generic": "kv"}
)
func collectBackendLocalPaths(backend logical.Backend, viewPath string) []string {
if backend == nil || backend.SpecialPaths() == nil || len(backend.SpecialPaths().LocalStorage) == 0 {
return nil
}
var paths []string
for _, path := range backend.SpecialPaths().LocalStorage {
paths = append(paths, viewPath+path)
}
return paths
}
func (c *Core) generateMountAccessor(entryType string) (string, error) {
var accessor string
for {
@ -176,6 +189,7 @@ type MountConfig struct {
MaxLeaseTTL time.Duration `json:"max_lease_ttl" structs:"max_lease_ttl" mapstructure:"max_lease_ttl"` // Override for global default
ForceNoCache bool `json:"force_no_cache" structs:"force_no_cache" mapstructure:"force_no_cache"` // Override for global default
PluginName string `json:"plugin_name,omitempty" structs:"plugin_name,omitempty" mapstructure:"plugin_name"`
SealWrap bool `json:"seal_wrap" structs:"seal_wrap" mapstructure:"seal_wrap"`
}
// APIMountConfig is an embedded struct of api.MountConfigInput
@ -184,6 +198,16 @@ type APIMountConfig struct {
MaxLeaseTTL string `json:"max_lease_ttl" structs:"max_lease_ttl" mapstructure:"max_lease_ttl"`
ForceNoCache bool `json:"force_no_cache" structs:"force_no_cache" mapstructure:"force_no_cache"`
PluginName string `json:"plugin_name,omitempty" structs:"plugin_name,omitempty" mapstructure:"plugin_name"`
SealWrap bool `json:"seal_wrap" structs:"seal_wrap" mapstructure:"seal_wrap"`
}
// Clone returns a deep copy of the mount entry
func (e *MountEntry) Clone() (*MountEntry, error) {
cp, err := copystructure.Copy(e)
if err != nil {
return nil, err
}
return cp.(*MountEntry), nil
}
// Mount is used to mount a new backend to the mount table.
@ -206,7 +230,10 @@ func (c *Core) mount(entry *MountEntry) error {
return logical.CodedError(403, fmt.Sprintf("Cannot mount more than one instance of '%s'", entry.Type))
}
}
return c.mountInternal(entry)
}
func (c *Core) mountInternal(entry *MountEntry) error {
c.mountsLock.Lock()
defer c.mountsLock.Unlock()
@ -232,13 +259,16 @@ func (c *Core) mount(entry *MountEntry) error {
}
viewPath := backendBarrierPrefix + entry.UUID + "/"
view := NewBarrierView(c.barrier, viewPath)
var backend logical.Backend
var err error
sysView := c.mountEntrySysView(entry)
conf := make(map[string]string)
if entry.Config.PluginName != "" {
conf["plugin_name"] = entry.Config.PluginName
}
backend, err := c.newLogicalBackend(entry.Type, sysView, view, conf)
// Consider having plugin name under entry.Options
backend, err = c.newLogicalBackend(entry.Type, sysView, view, conf)
if err != nil {
return err
}
@ -253,7 +283,7 @@ func (c *Core) mount(entry *MountEntry) error {
}
// Call initialize; this takes care of init tasks that must be run after
// the ignore paths are collected.
// the ignore paths are collected
if err := backend.Initialize(); err != nil {
return err
}
@ -292,7 +322,10 @@ func (c *Core) unmount(path string) error {
return fmt.Errorf("cannot unmount '%s'", path)
}
}
return c.unmountInternal(path)
}
func (c *Core) unmountInternal(path string) error {
// Verify exact match of the route
match := c.router.MatchingMount(path)
if match == "" || path != match {
@ -300,10 +333,16 @@ func (c *Core) unmount(path string) error {
}
// Get the view for this backend
view := c.router.MatchingStorageView(path)
view := c.router.MatchingStorageByAPIPath(path)
// Get the backend/mount entry for this path, used to remove ignored
// replication prefixes
backend := c.router.MatchingBackend(path)
entry := c.router.MatchingMountEntry(path)
// Mark the entry as tainted
if err := c.taintMountEntry(path); err != nil {
c.logger.Error("core: failed to taint mount entry for path being unmounted", "error", err, "path", path)
return err
}
@ -313,19 +352,18 @@ func (c *Core) unmount(path string) error {
return err
}
// Invoke the rollback manager a final time
if err := c.rollback.Rollback(path); err != nil {
return err
}
// Revoke all the dynamic keys
if err := c.expiration.RevokePrefix(path); err != nil {
return err
}
// Call cleanup function if it exists
backend := c.router.MatchingBackend(path)
if backend != nil {
// Invoke the rollback manager a final time
if err := c.rollback.Rollback(path); err != nil {
return err
}
// Revoke all the dynamic keys
if err := c.expiration.RevokePrefix(path); err != nil {
return err
}
// Call cleanup function if it exists
backend.Cleanup()
}
@ -334,15 +372,21 @@ func (c *Core) unmount(path string) error {
return err
}
// Clear the data in the view
if err := logical.ClearView(view); err != nil {
return err
switch {
case entry.Local, !c.replicationState.HasState(consts.ReplicationPerformanceSecondary):
// Have writable storage, remove the whole thing
if err := logical.ClearView(view); err != nil {
c.logger.Error("core: failed to clear view for path being unmounted", "error", err, "path", path)
return err
}
}
// Remove the mount table entry
if err := c.removeMountEntry(path); err != nil {
c.logger.Error("core: failed to remove mount entry for path being unmounted", "error", err, "path", path)
return err
}
if c.logger.IsInfo() {
c.logger.Info("core: successfully unmounted", "path", path)
}
@ -400,6 +444,25 @@ func (c *Core) taintMountEntry(path string) error {
return nil
}
// remountForce takes a copy of the mount entry for the path and fully unmounts
// and remounts the backend to pick up any changes, such as filtered paths
func (c *Core) remountForce(path string) error {
me := c.router.MatchingMountEntry(path)
if me == nil {
return fmt.Errorf("cannot find mount for path '%s'", path)
}
me, err := me.Clone()
if err != nil {
return err
}
if err := c.unmount(path); err != nil {
return err
}
return c.mount(me)
}
// Remount is used to remount a path at a new mount point.
func (c *Core) remount(src, dst string) error {
// Ensure we end the path in a slash
@ -448,24 +511,25 @@ func (c *Core) remount(src, dst string) error {
}
c.mountsLock.Lock()
var ent *MountEntry
for _, ent = range c.mounts.Entries {
if ent.Path == src {
ent.Path = dst
ent.Tainted = false
var entry *MountEntry
for _, entry = range c.mounts.Entries {
if entry.Path == src {
entry.Path = dst
entry.Tainted = false
break
}
}
if ent == nil {
if entry == nil {
c.mountsLock.Unlock()
c.logger.Error("core: failed to find entry in mounts table")
return logical.CodedError(500, "failed to find entry in mounts table")
}
// Update the mount table
if err := c.persistMounts(c.mounts, ent.Local); err != nil {
ent.Path = src
ent.Tainted = true
if err := c.persistMounts(c.mounts, entry.Local); err != nil {
entry.Path = src
entry.Tainted = true
c.mountsLock.Unlock()
c.logger.Error("core: failed to update mounts table", "error", err)
return logical.CodedError(500, "failed to update mounts table")
@ -546,6 +610,7 @@ func (c *Core) loadMounts() error {
break
}
}
// In a replication scenario we will let sync invalidation take
// care of creating a new required mount that doesn't exist yet.
// This should only happen in the upgrade case where a new one is
@ -553,7 +618,7 @@ func (c *Core) loadMounts() error {
// ensure this comes over. If we upgrade first, we simply don't
// create the mount, so we won't conflict when we sync. If this is
// local (e.g. cubbyhole) we do still add it.
if !foundRequired && (c.replicationState.HasState(consts.ReplicationPerformanceSecondary) || requiredMount.Local) {
if !foundRequired && (!c.replicationState.HasState(consts.ReplicationPerformanceSecondary) || requiredMount.Local) {
c.mounts.Entries = append(c.mounts.Entries, requiredMount)
needPersist = true
}
@ -677,7 +742,6 @@ func (c *Core) setupMounts() error {
var err error
for _, entry := range c.mounts.Entries {
var backend logical.Backend
// Initialize the backend, special casing for system
barrierPath := backendBarrierPrefix + entry.UUID + "/"
@ -687,6 +751,9 @@ func (c *Core) setupMounts() error {
// Create a barrier view using the UUID
view = NewBarrierView(c.barrier, barrierPath)
var backend logical.Backend
var err error
sysView := c.mountEntrySysView(entry)
// Set up conf to pass in plugin_name
conf := make(map[string]string)
@ -710,8 +777,9 @@ func (c *Core) setupMounts() error {
}
// Check for the correct backend type
if entry.Type == "plugin" && backend.Type() != logical.TypeLogical {
return fmt.Errorf("cannot mount '%s' of type '%s' as a logical backend", entry.Config.PluginName, backend.Type())
backendType := backend.Type()
if entry.Type == "plugin" && backendType != logical.TypeLogical {
return fmt.Errorf("cannot mount '%s' of type '%s' as a logical backend", entry.Config.PluginName, backendType)
}
if err := backend.Initialize(); err != nil {
@ -727,6 +795,7 @@ func (c *Core) setupMounts() error {
c.logger.Error("core: failed to mount entry", "path", entry.Path, "error", err)
return errLoadMountsFailed
}
if c.logger.IsInfo() {
c.logger.Info("core: successfully mounted backend", "type", entry.Type, "path", entry.Path)
}

View File

@ -233,7 +233,7 @@ func TestCore_Unmount_Cleanup(t *testing.T) {
}
// Store the view
view := c.router.MatchingStorageView("test/")
view := c.router.MatchingStorageByAPIPath("test/")
// Inject data
se := &logical.StorageEntry{
@ -353,7 +353,7 @@ func TestCore_Remount_Cleanup(t *testing.T) {
}
// Store the view
view := c.router.MatchingStorageView("test/")
view := c.router.MatchingStorageByAPIPath("test/")
// Inject data
se := &logical.StorageEntry{
@ -633,8 +633,9 @@ func TestSingletonMountTableFunc(t *testing.T) {
mounts, auth := c.singletonMountTables()
if len(mounts.Entries) != 2 {
t.Fatal("length of mounts is wrong")
t.Fatalf("length of mounts is wrong; expected 2, got %d", len(mounts.Entries))
}
for _, entry := range mounts.Entries {
switch entry.Type {
case "system":

14
vault/router_access.go Normal file
View File

@ -0,0 +1,14 @@
package vault
// RouterAccess provides access into some things necessary for testing
type RouterAccess struct {
c *Core
}
func NewRouterAccess(c *Core) *RouterAccess {
return &RouterAccess{c: c}
}
func (r *RouterAccess) StoragePrefixByAPIPath(path string) (string, string, bool) {
return r.c.router.MatchingStoragePrefixByAPIPath(path)
}