Merge pull request #775 from hashicorp/issue-771

Rearchitect MountTable locking and fix rollback.
This commit is contained in:
Jeff Mitchell 2015-11-15 17:33:30 -05:00
commit 0b3c7b177a
8 changed files with 135 additions and 60 deletions

View File

@ -82,6 +82,7 @@ generate them, leading to client errors.
* core: Fix an error that could happen in some failure scenarios where Vault
could fail to revert to a clean state [GH-733]
* core: Ensure secondary indexes are removed when a lease is expired [GH-749]
* core: Ensure rollback manager uses an up-to-date mounts table [GH-771]
* everywhere: Don't use http.DefaultClient, as it shares state implicitly and
is a source of hard-to-track-down bugs [GH-700]
* credential/token: Allow creating orphan tokens via an API path [GH-748]

View File

@ -35,9 +35,6 @@ var (
// enableAudit is used to enable a new audit backend
func (c *Core) enableAudit(entry *MountEntry) error {
c.audit.Lock()
defer c.audit.Unlock()
// Ensure we end the path in a slash
if !strings.HasSuffix(entry.Path, "/") {
entry.Path += "/"
@ -48,6 +45,10 @@ func (c *Core) enableAudit(entry *MountEntry) error {
return fmt.Errorf("backend path must be specified")
}
// Update the audit table
c.auditLock.Lock()
defer c.auditLock.Unlock()
// Look for matching name
for _, ent := range c.audit.Entries {
switch {
@ -70,12 +71,12 @@ func (c *Core) enableAudit(entry *MountEntry) error {
return err
}
// Update the audit table
newTable := c.audit.ShallowClone()
newTable.Entries = append(newTable.Entries, entry)
if err := c.persistAudit(newTable); err != nil {
return errors.New("failed to update audit table")
}
c.audit = newTable
// Register the backend
@ -87,15 +88,15 @@ func (c *Core) enableAudit(entry *MountEntry) error {
// disableAudit is used to disable an existing audit backend
func (c *Core) disableAudit(path string) error {
c.audit.Lock()
defer c.audit.Unlock()
// Ensure we end the path in a slash
if !strings.HasSuffix(path, "/") {
path += "/"
}
// Remove the entry from the mount table
c.auditLock.Lock()
defer c.auditLock.Unlock()
newTable := c.audit.ShallowClone()
found := newTable.Remove(path)
@ -108,6 +109,7 @@ func (c *Core) disableAudit(path string) error {
if err := c.persistAudit(newTable); err != nil {
return errors.New("failed to update audit table")
}
c.audit = newTable
// Unmount the backend
@ -118,18 +120,24 @@ func (c *Core) disableAudit(path string) error {
// loadAudits is invoked as part of postUnseal to load the audit table
func (c *Core) loadAudits() error {
auditTable := &MountTable{}
// Load the existing audit table
raw, err := c.barrier.Get(coreAuditConfigPath)
if err != nil {
c.logger.Printf("[ERR] core: failed to read audit table: %v", err)
return errLoadAuditFailed
}
c.auditLock.Lock()
defer c.auditLock.Unlock()
if raw != nil {
c.audit = &MountTable{}
if err := json.Unmarshal(raw.Value, c.audit); err != nil {
if err := json.Unmarshal(raw.Value, auditTable); err != nil {
c.logger.Printf("[ERR] core: failed to decode audit table: %v", err)
return errLoadAuditFailed
}
c.audit = auditTable
}
// Done if we have restored the audit table
@ -172,6 +180,10 @@ func (c *Core) persistAudit(table *MountTable) error {
// initialize the audit backends
func (c *Core) setupAudits() error {
broker := NewAuditBroker(c.logger)
c.auditLock.Lock()
defer c.auditLock.Unlock()
for _, entry := range c.audit.Entries {
// Create a barrier view using the UUID
view := NewBarrierView(c.barrier, auditBarrierPrefix+entry.UUID+"/")
@ -195,6 +207,9 @@ func (c *Core) setupAudits() error {
// teardownAudit is used before we seal the vault to reset the audit
// backends to their unloaded state. This is reversed by loadAudits.
func (c *Core) teardownAudits() error {
c.auditLock.Lock()
defer c.auditLock.Unlock()
c.audit = nil
c.auditBroker = nil
return nil

View File

@ -31,9 +31,6 @@ var (
// enableCredential is used to enable a new credential backend
func (c *Core) enableCredential(entry *MountEntry) error {
c.auth.Lock()
defer c.auth.Unlock()
// Ensure we end the path in a slash
if !strings.HasSuffix(entry.Path, "/") {
entry.Path += "/"
@ -44,6 +41,9 @@ func (c *Core) enableCredential(entry *MountEntry) error {
return fmt.Errorf("backend path must be specified")
}
c.authLock.Lock()
defer c.authLock.Unlock()
// Look for matching name
for _, ent := range c.auth.Entries {
switch {
@ -77,6 +77,7 @@ func (c *Core) enableCredential(entry *MountEntry) error {
if err := c.persistAuth(newTable); err != nil {
return errors.New("failed to update auth table")
}
c.auth = newTable
// Mount the backend
@ -91,9 +92,6 @@ func (c *Core) enableCredential(entry *MountEntry) error {
// disableCredential is used to disable an existing credential backend
func (c *Core) disableCredential(path string) error {
c.auth.Lock()
defer c.auth.Unlock()
// Ensure we end the path in a slash
if !strings.HasSuffix(path, "/") {
path += "/"
@ -111,6 +109,9 @@ func (c *Core) disableCredential(path string) error {
return fmt.Errorf("no matching backend")
}
c.authLock.Lock()
defer c.authLock.Unlock()
// Mark the entry as tainted
if err := c.taintCredEntry(path); err != nil {
return err
@ -156,15 +157,18 @@ func (c *Core) removeCredEntry(path string) error {
if err := c.persistAuth(newTable); err != nil {
return errors.New("failed to update auth table")
}
c.auth = newTable
return nil
}
// taintCredEntry is used to mark an entry in the auth table as tainted
func (c *Core) taintCredEntry(path string) error {
// Taint the entry from the auth table
newTable := c.auth.ShallowClone()
found := newTable.SetTaint(path, true)
// We do this on the original since setting the taint operates
// on the entries which a shallow clone shares anyways
found := c.auth.SetTaint(path, true)
// Ensure there was a match
if !found {
@ -172,27 +176,32 @@ func (c *Core) taintCredEntry(path string) error {
}
// Update the auth table
if err := c.persistAuth(newTable); err != nil {
if err := c.persistAuth(c.auth); err != nil {
return errors.New("failed to update auth table")
}
c.auth = newTable
return nil
}
// loadCredentials is invoked as part of postUnseal to load the auth table
func (c *Core) loadCredentials() error {
authTable := &MountTable{}
// Load the existing mount table
raw, err := c.barrier.Get(coreAuthConfigPath)
if err != nil {
c.logger.Printf("[ERR] core: failed to read auth table: %v", err)
return errLoadAuthFailed
}
c.authLock.Lock()
defer c.authLock.Unlock()
if raw != nil {
c.auth = &MountTable{}
if err := json.Unmarshal(raw.Value, c.auth); err != nil {
if err := json.Unmarshal(raw.Value, authTable); err != nil {
c.logger.Printf("[ERR] core: failed to decode auth table: %v", err)
return errLoadAuthFailed
}
c.auth = authTable
}
// Done if we have restored the auth table
@ -238,6 +247,10 @@ func (c *Core) setupCredentials() error {
var backend logical.Backend
var view *BarrierView
var err error
c.authLock.Lock()
defer c.authLock.Unlock()
for _, entry := range c.auth.Entries {
// Create a barrier view using the UUID
view = NewBarrierView(c.barrier, credentialBarrierPrefix+entry.UUID+"/")
@ -279,6 +292,9 @@ func (c *Core) setupCredentials() error {
// teardownCredentials is used before we seal the vault to reset the credential
// backends to their unloaded state. This is reversed by loadCredentials.
func (c *Core) teardownCredentials() error {
c.authLock.Lock()
defer c.authLock.Unlock()
c.auth = nil
c.tokenStore = nil
return nil

View File

@ -214,14 +214,26 @@ type Core struct {
// configuration
mounts *MountTable
// mountsLock is used to ensure that the mounts table does not
// change underneath a calling function
mountsLock sync.RWMutex
// auth is loaded after unseal since it is a protected
// configuration
auth *MountTable
// authLock is used to ensure that the auth table does not
// change underneath a calling function
authLock sync.RWMutex
// audit is loaded after unseal since it is a protected
// configuration
audit *MountTable
// auditLock is used to ensure that the audit table does not
// change underneath a calling function
auditLock sync.RWMutex
// auditBroker is used to ingest the audit events and fan
// out into the configured audit backends
auditBroker *AuditBroker

View File

@ -364,12 +364,13 @@ type SystemBackend struct {
// handleMountTable handles the "mounts" endpoint to provide the mount table
func (b *SystemBackend) handleMountTable(
req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
b.Core.mounts.Lock()
defer b.Core.mounts.Unlock()
b.Core.mountsLock.RLock()
defer b.Core.mountsLock.RUnlock()
resp := &logical.Response{
Data: make(map[string]interface{}),
}
for _, entry := range b.Core.mounts.Entries {
info := map[string]interface{}{
"type": entry.Type,
@ -613,6 +614,8 @@ func (b *SystemBackend) handleMountTuneWrite(
}
if newDefault != nil || newMax != nil {
b.Core.mountsLock.Lock()
defer b.Core.mountsLock.Unlock()
if err := b.tuneMountTTLs(path, &mountEntry.Config, newDefault, newMax); err != nil {
b.Backend.Logger().Printf("[ERR] sys: tune of path '%s' failed: %v", path, err)
return handleError(err)
@ -673,8 +676,8 @@ func (b *SystemBackend) handleRevokePrefix(
// handleAuthTable handles the "auth" endpoint to provide the auth table
func (b *SystemBackend) handleAuthTable(
req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
b.Core.auth.Lock()
defer b.Core.auth.Unlock()
b.Core.authLock.RLock()
defer b.Core.authLock.RUnlock()
resp := &logical.Response{
Data: make(map[string]interface{}),
@ -802,8 +805,8 @@ func (b *SystemBackend) handlePolicyDelete(
// handleAuditTable handles the "audit" endpoint to provide the audit table
func (b *SystemBackend) handleAuditTable(
req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
b.Core.audit.Lock()
defer b.Core.audit.Unlock()
b.Core.auditLock.RLock()
defer b.Core.auditLock.RUnlock()
resp := &logical.Response{
Data: make(map[string]interface{}),

View File

@ -6,7 +6,6 @@ import (
"errors"
"fmt"
"strings"
"sync"
"time"
"github.com/hashicorp/uuid"
@ -56,9 +55,6 @@ var (
// MountTable is used to represent the internal mount table
type MountTable struct {
// This lock should be held whenever modifying the Entries field.
sync.RWMutex
Entries []*MountEntry `json:"entries"`
}
@ -157,9 +153,6 @@ func (e *MountEntry) Clone() *MountEntry {
// Mount is used to mount a new backend to the mount table.
func (c *Core) mount(me *MountEntry) error {
c.mounts.Lock()
defer c.mounts.Unlock()
// Ensure we end the path in a slash
if !strings.HasSuffix(me.Path, "/") {
me.Path += "/"
@ -184,6 +177,9 @@ func (c *Core) mount(me *MountEntry) error {
return logical.CodedError(409, fmt.Sprintf("existing mount at %s", match))
}
c.mountsLock.Lock()
defer c.mountsLock.Unlock()
// Generate a new UUID and view
me.UUID = uuid.GenerateUUID()
view := NewBarrierView(c.barrier, backendBarrierPrefix+me.UUID+"/")
@ -211,9 +207,6 @@ func (c *Core) mount(me *MountEntry) error {
// Unmount is used to unmount a path.
func (c *Core) unmount(path string) error {
c.mounts.Lock()
defer c.mounts.Unlock()
// Ensure we end the path in a slash
if !strings.HasSuffix(path, "/") {
path += "/"
@ -235,6 +228,9 @@ func (c *Core) unmount(path string) error {
// Store the view for this backend
view := c.router.MatchingStorageView(path)
c.mountsLock.Lock()
defer c.mountsLock.Unlock()
// Mark the entry as tainted
if err := c.taintMountEntry(path); err != nil {
return err
@ -283,29 +279,27 @@ func (c *Core) removeMountEntry(path string) error {
if err := c.persistMounts(newTable); err != nil {
return errors.New("failed to update mount table")
}
c.mounts = newTable
return nil
}
// taintMountEntry is used to mark an entry in the mount table as tainted
func (c *Core) taintMountEntry(path string) error {
// Remove the entry from the mount table
newTable := c.mounts.ShallowClone()
newTable.SetTaint(path, true)
// As modifying the taint of an entry affects shallow clones,
// we simply use the original
c.mounts.SetTaint(path, true)
// Update the mount table
if err := c.persistMounts(newTable); err != nil {
if err := c.persistMounts(c.mounts); err != nil {
return errors.New("failed to update mount table")
}
c.mounts = newTable
return nil
}
// Remount is used to remount a path at a new mount point.
func (c *Core) remount(src, dst string) error {
c.mounts.Lock()
defer c.mounts.Unlock()
// Ensure we end the path in a slash
if !strings.HasSuffix(src, "/") {
src += "/"
@ -331,6 +325,9 @@ func (c *Core) remount(src, dst string) error {
return fmt.Errorf("existing mount at '%s'", match)
}
c.mountsLock.Lock()
defer c.mountsLock.Unlock()
// Mark the entry as tainted
if err := c.taintMountEntry(src); err != nil {
return err
@ -351,10 +348,8 @@ func (c *Core) remount(src, dst string) error {
return err
}
// Update the entry in the mount table
newTable := c.mounts.ShallowClone()
var ent *MountEntry
for _, ent = range newTable.Entries {
for _, ent = range c.mounts.Entries {
if ent.Path == src {
ent.Path = dst
ent.Tainted = false
@ -363,12 +358,11 @@ func (c *Core) remount(src, dst string) error {
}
// Update the mount table
if err := c.persistMounts(newTable); err != nil {
if err := c.persistMounts(c.mounts); err != nil {
ent.Path = src
ent.Tainted = true
return errors.New("failed to update mount table")
}
c.mounts = newTable
// Remount the backend
if err := c.router.Remount(src, dst); err != nil {
@ -386,18 +380,23 @@ func (c *Core) remount(src, dst string) error {
// loadMounts is invoked as part of postUnseal to load the mount table
func (c *Core) loadMounts() error {
mountTable := &MountTable{}
// Load the existing mount table
raw, err := c.barrier.Get(coreMountConfigPath)
if err != nil {
c.logger.Printf("[ERR] core: failed to read mount table: %v", err)
return errLoadMountsFailed
}
c.mountsLock.Lock()
defer c.mountsLock.Unlock()
if raw != nil {
c.mounts = &MountTable{}
if err := json.Unmarshal(raw.Value, c.mounts); err != nil {
if err := json.Unmarshal(raw.Value, mountTable); err != nil {
c.logger.Printf("[ERR] core: failed to decode mount table: %v", err)
return errLoadMountsFailed
}
c.mounts = mountTable
}
// Ensure that required entries are loaded, or new ones
@ -462,9 +461,13 @@ func (c *Core) persistMounts(table *MountTable) error {
// setupMounts is invoked after we've loaded the mount table to
// initialize the logical backends and setup the router
func (c *Core) setupMounts() error {
c.mountsLock.Lock()
defer c.mountsLock.Unlock()
var backend logical.Backend
var view *BarrierView
var err error
for _, entry := range c.mounts.Entries {
// Initialize the backend, special casing for system
barrierPath := backendBarrierPrefix + entry.UUID + "/"
@ -514,8 +517,12 @@ func (c *Core) setupMounts() error {
// unloadMounts is used before we seal the vault to reset the mounts to
// their unloaded state, calling Cleanup if defined. This is reversed by load and setup mounts.
func (c *Core) unloadMounts() error {
c.mountsLock.Lock()
defer c.mountsLock.Unlock()
if c.mounts != nil {
for _, e := range c.mounts.Entries {
mountTable := c.mounts.ShallowClone()
for _, e := range mountTable.Entries {
prefix := e.Path
b, ok := c.router.root.Get(prefix)
if ok {
@ -523,6 +530,7 @@ func (c *Core) unloadMounts() error {
}
}
}
c.mounts = nil
c.router = NewRouter()
c.systemBarrierView = nil

View File

@ -29,7 +29,12 @@ const (
// is in-flight at any given time within a single seal/unseal phase.
type RollbackManager struct {
logger *log.Logger
mounts *MountTable
// This gives the current mount table, plus a RWMutex that is
// locked for reading. It is up to the caller to RUnlock it
// when done with the mount table
mounts func() []*MountEntry
router *Router
period time.Duration
@ -50,7 +55,7 @@ type rollbackState struct {
}
// NewRollbackManager is used to create a new rollback manager
func NewRollbackManager(logger *log.Logger, mounts *MountTable, router *Router) *RollbackManager {
func NewRollbackManager(logger *log.Logger, mounts func() []*MountEntry, router *Router) *RollbackManager {
r := &RollbackManager{
logger: logger,
mounts: mounts,
@ -101,12 +106,12 @@ func (m *RollbackManager) run() {
// triggerRollbacks is used to trigger the rollbacks across all the backends
func (m *RollbackManager) triggerRollbacks() {
m.mounts.RLock()
defer m.mounts.RUnlock()
m.inflightLock.Lock()
defer m.inflightLock.Unlock()
for _, e := range m.mounts.Entries {
mounts := m.mounts()
for _, e := range mounts {
if _, ok := m.inflight[e.Path]; !ok {
m.startRollback(e.Path)
}
@ -127,6 +132,8 @@ func (m *RollbackManager) startRollback(path string) *rollbackState {
// attemptRollback invokes a RollbackOperation for the given path
func (m *RollbackManager) attemptRollback(path string, rs *rollbackState) (err error) {
defer metrics.MeasureSince([]string{"rollback", "attempt", strings.Replace(path, "/", "-", -1)}, time.Now())
m.logger.Printf("[DEBUG] rollback: attempting rollback on %s", path)
defer func() {
rs.lastError = err
rs.Done()
@ -177,7 +184,16 @@ func (m *RollbackManager) Rollback(path string) error {
// startRollback is used to start the rollback manager after unsealing
func (c *Core) startRollback() error {
c.rollback = NewRollbackManager(c.logger, c.mounts, c.router)
mountsFunc := func() []*MountEntry {
ret := []*MountEntry{}
c.mountsLock.RLock()
defer c.mountsLock.RUnlock()
for _, entry := range c.mounts.Entries {
ret = append(ret, entry)
}
return ret
}
c.rollback = NewRollbackManager(c.logger, mountsFunc, c.router)
c.rollback.Start()
return nil
}

View File

@ -25,8 +25,12 @@ func mockRollback(t *testing.T) (*RollbackManager, *NoopBackend) {
t.Fatalf("err: %s", err)
}
mountsFunc := func() []*MountEntry {
return mounts.Entries
}
logger := log.New(os.Stderr, "", log.LstdFlags)
rb := NewRollbackManager(logger, mounts, router)
rb := NewRollbackManager(logger, mountsFunc, router)
rb.period = 10 * time.Millisecond
return rb, backend
}