Rearchitect MountTable locking and fix rollback.

The rollback manager was using a saved MountTable rather than the
current table, causing it to attempt to rollback unmounted mounts, and
never rollback new mounts.

In fixing this, it became clear that bad things could happen to the
mount table...the table itself could be locked, but the table pointer
(which is what the rollback manager needs) could be modified at any time
without locking. This commit therefore also returns locking to a mutex
outside the table instead of inside, and plumbs RLock/RUnlock through to
the various places that are reading the table but not holding a write
lock.

Both unit tests and race detection pass.

Fixes #771
This commit is contained in:
Jeff Mitchell 2015-11-11 11:44:07 -05:00
parent fa646a1eb1
commit bc4c18a1cf
8 changed files with 136 additions and 60 deletions

View File

@ -82,6 +82,7 @@ generate them, leading to client errors.
* core: Fix an error that could happen in some failure scenarios where Vault
could fail to revert to a clean state [GH-733]
* core: Ensure secondary indexes are removed when a lease is expired [GH-749]
* core: Ensure rollback manager uses an up-to-date mounts table [GH-771]
* everywhere: Don't use http.DefaultClient, as it shares state implicitly and
is a source of hard-to-track-down bugs [GH-700]
* credential/token: Allow creating orphan tokens via an API path [GH-748]

View File

@ -35,9 +35,6 @@ var (
// enableAudit is used to enable a new audit backend
func (c *Core) enableAudit(entry *MountEntry) error {
c.audit.Lock()
defer c.audit.Unlock()
// Ensure we end the path in a slash
if !strings.HasSuffix(entry.Path, "/") {
entry.Path += "/"
@ -48,6 +45,10 @@ func (c *Core) enableAudit(entry *MountEntry) error {
return fmt.Errorf("backend path must be specified")
}
// Update the audit table
c.auditLock.Lock()
defer c.auditLock.Unlock()
// Look for matching name
for _, ent := range c.audit.Entries {
switch {
@ -70,12 +71,12 @@ func (c *Core) enableAudit(entry *MountEntry) error {
return err
}
// Update the audit table
newTable := c.audit.ShallowClone()
newTable.Entries = append(newTable.Entries, entry)
if err := c.persistAudit(newTable); err != nil {
return errors.New("failed to update audit table")
}
c.audit = newTable
// Register the backend
@ -87,15 +88,15 @@ func (c *Core) enableAudit(entry *MountEntry) error {
// disableAudit is used to disable an existing audit backend
func (c *Core) disableAudit(path string) error {
c.audit.Lock()
defer c.audit.Unlock()
// Ensure we end the path in a slash
if !strings.HasSuffix(path, "/") {
path += "/"
}
// Remove the entry from the mount table
c.auditLock.Lock()
defer c.auditLock.Unlock()
newTable := c.audit.ShallowClone()
found := newTable.Remove(path)
@ -108,6 +109,7 @@ func (c *Core) disableAudit(path string) error {
if err := c.persistAudit(newTable); err != nil {
return errors.New("failed to update audit table")
}
c.audit = newTable
// Unmount the backend
@ -118,18 +120,24 @@ func (c *Core) disableAudit(path string) error {
// loadAudits is invoked as part of postUnseal to load the audit table
func (c *Core) loadAudits() error {
auditTable := &MountTable{}
// Load the existing audit table
raw, err := c.barrier.Get(coreAuditConfigPath)
if err != nil {
c.logger.Printf("[ERR] core: failed to read audit table: %v", err)
return errLoadAuditFailed
}
c.auditLock.Lock()
defer c.auditLock.Unlock()
if raw != nil {
c.audit = &MountTable{}
if err := json.Unmarshal(raw.Value, c.audit); err != nil {
if err := json.Unmarshal(raw.Value, auditTable); err != nil {
c.logger.Printf("[ERR] core: failed to decode audit table: %v", err)
return errLoadAuditFailed
}
c.audit = auditTable
}
// Done if we have restored the audit table
@ -172,6 +180,10 @@ func (c *Core) persistAudit(table *MountTable) error {
// initialize the audit backends
func (c *Core) setupAudits() error {
broker := NewAuditBroker(c.logger)
c.auditLock.Lock()
defer c.auditLock.Unlock()
for _, entry := range c.audit.Entries {
// Create a barrier view using the UUID
view := NewBarrierView(c.barrier, auditBarrierPrefix+entry.UUID+"/")
@ -195,6 +207,9 @@ func (c *Core) setupAudits() error {
// teardownAudit is used before we seal the vault to reset the audit
// backends to their unloaded state. This is reversed by loadAudits.
func (c *Core) teardownAudits() error {
c.auditLock.Lock()
defer c.auditLock.Unlock()
c.audit = nil
c.auditBroker = nil
return nil

View File

@ -31,9 +31,6 @@ var (
// enableCredential is used to enable a new credential backend
func (c *Core) enableCredential(entry *MountEntry) error {
c.auth.Lock()
defer c.auth.Unlock()
// Ensure we end the path in a slash
if !strings.HasSuffix(entry.Path, "/") {
entry.Path += "/"
@ -44,6 +41,9 @@ func (c *Core) enableCredential(entry *MountEntry) error {
return fmt.Errorf("backend path must be specified")
}
c.authLock.Lock()
defer c.authLock.Unlock()
// Look for matching name
for _, ent := range c.auth.Entries {
switch {
@ -77,6 +77,7 @@ func (c *Core) enableCredential(entry *MountEntry) error {
if err := c.persistAuth(newTable); err != nil {
return errors.New("failed to update auth table")
}
c.auth = newTable
// Mount the backend
@ -91,9 +92,6 @@ func (c *Core) enableCredential(entry *MountEntry) error {
// disableCredential is used to disable an existing credential backend
func (c *Core) disableCredential(path string) error {
c.auth.Lock()
defer c.auth.Unlock()
// Ensure we end the path in a slash
if !strings.HasSuffix(path, "/") {
path += "/"
@ -111,6 +109,9 @@ func (c *Core) disableCredential(path string) error {
return fmt.Errorf("no matching backend")
}
c.authLock.Lock()
defer c.authLock.Unlock()
// Mark the entry as tainted
if err := c.taintCredEntry(path); err != nil {
return err
@ -156,15 +157,18 @@ func (c *Core) removeCredEntry(path string) error {
if err := c.persistAuth(newTable); err != nil {
return errors.New("failed to update auth table")
}
c.auth = newTable
return nil
}
// taintCredEntry is used to mark an entry in the auth table as tainted
func (c *Core) taintCredEntry(path string) error {
// Taint the entry from the auth table
newTable := c.auth.ShallowClone()
found := newTable.SetTaint(path, true)
// We do this on the original since setting the taint operates
// on the entries which a shallow clone shares anyways
found := c.auth.SetTaint(path, true)
// Ensure there was a match
if !found {
@ -172,27 +176,32 @@ func (c *Core) taintCredEntry(path string) error {
}
// Update the auth table
if err := c.persistAuth(newTable); err != nil {
if err := c.persistAuth(c.auth); err != nil {
return errors.New("failed to update auth table")
}
c.auth = newTable
return nil
}
// loadCredentials is invoked as part of postUnseal to load the auth table
func (c *Core) loadCredentials() error {
authTable := &MountTable{}
// Load the existing mount table
raw, err := c.barrier.Get(coreAuthConfigPath)
if err != nil {
c.logger.Printf("[ERR] core: failed to read auth table: %v", err)
return errLoadAuthFailed
}
c.authLock.Lock()
defer c.authLock.Unlock()
if raw != nil {
c.auth = &MountTable{}
if err := json.Unmarshal(raw.Value, c.auth); err != nil {
if err := json.Unmarshal(raw.Value, authTable); err != nil {
c.logger.Printf("[ERR] core: failed to decode auth table: %v", err)
return errLoadAuthFailed
}
c.auth = authTable
}
// Done if we have restored the auth table
@ -238,6 +247,10 @@ func (c *Core) setupCredentials() error {
var backend logical.Backend
var view *BarrierView
var err error
c.authLock.Lock()
defer c.authLock.Unlock()
for _, entry := range c.auth.Entries {
// Create a barrier view using the UUID
view = NewBarrierView(c.barrier, credentialBarrierPrefix+entry.UUID+"/")
@ -279,6 +292,9 @@ func (c *Core) setupCredentials() error {
// teardownCredentials is used before we seal the vault to reset the credential
// backends to their unloaded state. This is reversed by loadCredentials.
func (c *Core) teardownCredentials() error {
c.authLock.Lock()
defer c.authLock.Unlock()
c.auth = nil
c.tokenStore = nil
return nil

View File

@ -214,14 +214,26 @@ type Core struct {
// configuration
mounts *MountTable
// mountsLock is used to ensure that the mounts table does not
// change underneath a calling function
mountsLock sync.RWMutex
// auth is loaded after unseal since it is a protected
// configuration
auth *MountTable
// authLock is used to ensure that the auth table does not
// change underneath a calling function
authLock sync.RWMutex
// audit is loaded after unseal since it is a protected
// configuration
audit *MountTable
// auditLock is used to ensure that the audit table does not
// change underneath a calling function
auditLock sync.RWMutex
// auditBroker is used to ingest the audit events and fan
// out into the configured audit backends
auditBroker *AuditBroker

View File

@ -364,12 +364,13 @@ type SystemBackend struct {
// handleMountTable handles the "mounts" endpoint to provide the mount table
func (b *SystemBackend) handleMountTable(
req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
b.Core.mounts.Lock()
defer b.Core.mounts.Unlock()
b.Core.mountsLock.RLock()
defer b.Core.mountsLock.RUnlock()
resp := &logical.Response{
Data: make(map[string]interface{}),
}
for _, entry := range b.Core.mounts.Entries {
info := map[string]interface{}{
"type": entry.Type,
@ -613,10 +614,13 @@ func (b *SystemBackend) handleMountTuneWrite(
}
if newDefault != nil || newMax != nil {
b.Core.mountsLock.Lock()
if err := b.tuneMountTTLs(path, &mountEntry.Config, newDefault, newMax); err != nil {
b.Core.mountsLock.Unlock()
b.Backend.Logger().Printf("[ERR] sys: tune of path '%s' failed: %v", path, err)
return handleError(err)
}
b.Core.mountsLock.Unlock()
}
}
@ -673,8 +677,8 @@ func (b *SystemBackend) handleRevokePrefix(
// handleAuthTable handles the "auth" endpoint to provide the auth table
func (b *SystemBackend) handleAuthTable(
req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
b.Core.auth.Lock()
defer b.Core.auth.Unlock()
b.Core.authLock.RLock()
defer b.Core.authLock.RUnlock()
resp := &logical.Response{
Data: make(map[string]interface{}),
@ -802,8 +806,8 @@ func (b *SystemBackend) handlePolicyDelete(
// handleAuditTable handles the "audit" endpoint to provide the audit table
func (b *SystemBackend) handleAuditTable(
req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
b.Core.audit.Lock()
defer b.Core.audit.Unlock()
b.Core.auditLock.RLock()
defer b.Core.auditLock.RUnlock()
resp := &logical.Response{
Data: make(map[string]interface{}),

View File

@ -6,7 +6,6 @@ import (
"errors"
"fmt"
"strings"
"sync"
"time"
"github.com/hashicorp/uuid"
@ -56,9 +55,6 @@ var (
// MountTable is used to represent the internal mount table
type MountTable struct {
// This lock should be held whenever modifying the Entries field.
sync.RWMutex
Entries []*MountEntry `json:"entries"`
}
@ -157,9 +153,6 @@ func (e *MountEntry) Clone() *MountEntry {
// Mount is used to mount a new backend to the mount table.
func (c *Core) mount(me *MountEntry) error {
c.mounts.Lock()
defer c.mounts.Unlock()
// Ensure we end the path in a slash
if !strings.HasSuffix(me.Path, "/") {
me.Path += "/"
@ -184,6 +177,9 @@ func (c *Core) mount(me *MountEntry) error {
return logical.CodedError(409, fmt.Sprintf("existing mount at %s", match))
}
c.mountsLock.Lock()
defer c.mountsLock.Unlock()
// Generate a new UUID and view
me.UUID = uuid.GenerateUUID()
view := NewBarrierView(c.barrier, backendBarrierPrefix+me.UUID+"/")
@ -211,9 +207,6 @@ func (c *Core) mount(me *MountEntry) error {
// Unmount is used to unmount a path.
func (c *Core) unmount(path string) error {
c.mounts.Lock()
defer c.mounts.Unlock()
// Ensure we end the path in a slash
if !strings.HasSuffix(path, "/") {
path += "/"
@ -235,6 +228,9 @@ func (c *Core) unmount(path string) error {
// Store the view for this backend
view := c.router.MatchingStorageView(path)
c.mountsLock.Lock()
defer c.mountsLock.Unlock()
// Mark the entry as tainted
if err := c.taintMountEntry(path); err != nil {
return err
@ -283,29 +279,27 @@ func (c *Core) removeMountEntry(path string) error {
if err := c.persistMounts(newTable); err != nil {
return errors.New("failed to update mount table")
}
c.mounts = newTable
return nil
}
// taintMountEntry is used to mark an entry in the mount table as tainted
func (c *Core) taintMountEntry(path string) error {
// Remove the entry from the mount table
newTable := c.mounts.ShallowClone()
newTable.SetTaint(path, true)
// As modifying the taint of an entry affects shallow clones,
// we simply use the original
c.mounts.SetTaint(path, true)
// Update the mount table
if err := c.persistMounts(newTable); err != nil {
if err := c.persistMounts(c.mounts); err != nil {
return errors.New("failed to update mount table")
}
c.mounts = newTable
return nil
}
// Remount is used to remount a path at a new mount point.
func (c *Core) remount(src, dst string) error {
c.mounts.Lock()
defer c.mounts.Unlock()
// Ensure we end the path in a slash
if !strings.HasSuffix(src, "/") {
src += "/"
@ -331,6 +325,9 @@ func (c *Core) remount(src, dst string) error {
return fmt.Errorf("existing mount at '%s'", match)
}
c.mountsLock.Lock()
defer c.mountsLock.Unlock()
// Mark the entry as tainted
if err := c.taintMountEntry(src); err != nil {
return err
@ -351,10 +348,8 @@ func (c *Core) remount(src, dst string) error {
return err
}
// Update the entry in the mount table
newTable := c.mounts.ShallowClone()
var ent *MountEntry
for _, ent = range newTable.Entries {
for _, ent = range c.mounts.Entries {
if ent.Path == src {
ent.Path = dst
ent.Tainted = false
@ -363,12 +358,11 @@ func (c *Core) remount(src, dst string) error {
}
// Update the mount table
if err := c.persistMounts(newTable); err != nil {
if err := c.persistMounts(c.mounts); err != nil {
ent.Path = src
ent.Tainted = true
return errors.New("failed to update mount table")
}
c.mounts = newTable
// Remount the backend
if err := c.router.Remount(src, dst); err != nil {
@ -386,18 +380,23 @@ func (c *Core) remount(src, dst string) error {
// loadMounts is invoked as part of postUnseal to load the mount table
func (c *Core) loadMounts() error {
mountTable := &MountTable{}
// Load the existing mount table
raw, err := c.barrier.Get(coreMountConfigPath)
if err != nil {
c.logger.Printf("[ERR] core: failed to read mount table: %v", err)
return errLoadMountsFailed
}
c.mountsLock.Lock()
defer c.mountsLock.Unlock()
if raw != nil {
c.mounts = &MountTable{}
if err := json.Unmarshal(raw.Value, c.mounts); err != nil {
if err := json.Unmarshal(raw.Value, mountTable); err != nil {
c.logger.Printf("[ERR] core: failed to decode mount table: %v", err)
return errLoadMountsFailed
}
c.mounts = mountTable
}
// Ensure that required entries are loaded, or new ones
@ -462,9 +461,13 @@ func (c *Core) persistMounts(table *MountTable) error {
// setupMounts is invoked after we've loaded the mount table to
// initialize the logical backends and setup the router
func (c *Core) setupMounts() error {
c.mountsLock.Lock()
defer c.mountsLock.Unlock()
var backend logical.Backend
var view *BarrierView
var err error
for _, entry := range c.mounts.Entries {
// Initialize the backend, special casing for system
barrierPath := backendBarrierPrefix + entry.UUID + "/"
@ -514,8 +517,12 @@ func (c *Core) setupMounts() error {
// unloadMounts is used before we seal the vault to reset the mounts to
// their unloaded state, calling Cleanup if defined. This is reversed by load and setup mounts.
func (c *Core) unloadMounts() error {
c.mountsLock.Lock()
defer c.mountsLock.Unlock()
if c.mounts != nil {
for _, e := range c.mounts.Entries {
mountTable := c.mounts.ShallowClone()
for _, e := range mountTable.Entries {
prefix := e.Path
b, ok := c.router.root.Get(prefix)
if ok {
@ -523,6 +530,7 @@ func (c *Core) unloadMounts() error {
}
}
}
c.mounts = nil
c.router = NewRouter()
c.systemBarrierView = nil

View File

@ -29,7 +29,12 @@ const (
// is in-flight at any given time within a single seal/unseal phase.
type RollbackManager struct {
logger *log.Logger
mounts *MountTable
// This gives the current mount table, plus a RWMutex that is
// locked for reading. It is up to the caller to RUnlock it
// when done with the mount table
mounts func() (*MountTable, *sync.RWMutex)
router *Router
period time.Duration
@ -50,7 +55,7 @@ type rollbackState struct {
}
// NewRollbackManager is used to create a new rollback manager
func NewRollbackManager(logger *log.Logger, mounts *MountTable, router *Router) *RollbackManager {
func NewRollbackManager(logger *log.Logger, mounts func() (*MountTable, *sync.RWMutex), router *Router) *RollbackManager {
r := &RollbackManager{
logger: logger,
mounts: mounts,
@ -101,12 +106,17 @@ func (m *RollbackManager) run() {
// triggerRollbacks is used to trigger the rollbacks across all the backends
func (m *RollbackManager) triggerRollbacks() {
m.mounts.RLock()
defer m.mounts.RUnlock()
m.inflightLock.Lock()
defer m.inflightLock.Unlock()
for _, e := range m.mounts.Entries {
mounts, mountsLock := m.mounts()
mountsLock.RLock()
defer mountsLock.RUnlock()
if mounts == nil {
return
}
for _, e := range mounts.Entries {
if _, ok := m.inflight[e.Path]; !ok {
m.startRollback(e.Path)
}
@ -127,6 +137,8 @@ func (m *RollbackManager) startRollback(path string) *rollbackState {
// attemptRollback invokes a RollbackOperation for the given path
func (m *RollbackManager) attemptRollback(path string, rs *rollbackState) (err error) {
defer metrics.MeasureSince([]string{"rollback", "attempt", strings.Replace(path, "/", "-", -1)}, time.Now())
m.logger.Printf("[DEBUG] rollback: attempting rollback on %s", path)
defer func() {
rs.lastError = err
rs.Done()
@ -177,7 +189,10 @@ func (m *RollbackManager) Rollback(path string) error {
// startRollback is used to start the rollback manager after unsealing
func (c *Core) startRollback() error {
c.rollback = NewRollbackManager(c.logger, c.mounts, c.router)
mountsFunc := func() (*MountTable, *sync.RWMutex) {
return c.mounts, &c.mountsLock
}
c.rollback = NewRollbackManager(c.logger, mountsFunc, c.router)
c.rollback.Start()
return nil
}

View File

@ -25,8 +25,13 @@ func mockRollback(t *testing.T) (*RollbackManager, *NoopBackend) {
t.Fatalf("err: %s", err)
}
mountsMutex := &sync.RWMutex{}
mountsFunc := func() (*MountTable, *sync.RWMutex) {
return mounts, mountsMutex
}
logger := log.New(os.Stderr, "", log.LstdFlags)
rb := NewRollbackManager(logger, mounts, router)
rb := NewRollbackManager(logger, mountsFunc, router)
rb.period = 10 * time.Millisecond
return rb, backend
}