Rearchitect MountTable locking and fix rollback.
The rollback manager was using a saved MountTable rather than the current table, causing it to attempt to rollback unmounted mounts, and never rollback new mounts. In fixing this, it became clear that bad things could happen to the mount table...the table itself could be locked, but the table pointer (which is what the rollback manager needs) could be modified at any time without locking. This commit therefore also returns locking to a mutex outside the table instead of inside, and plumbs RLock/RUnlock through to the various places that are reading the table but not holding a write lock. Both unit tests and race detection pass. Fixes #771
This commit is contained in:
parent
fa646a1eb1
commit
bc4c18a1cf
|
@ -82,6 +82,7 @@ generate them, leading to client errors.
|
|||
* core: Fix an error that could happen in some failure scenarios where Vault
|
||||
could fail to revert to a clean state [GH-733]
|
||||
* core: Ensure secondary indexes are removed when a lease is expired [GH-749]
|
||||
* core: Ensure rollback manager uses an up-to-date mounts table [GH-771]
|
||||
* everywhere: Don't use http.DefaultClient, as it shares state implicitly and
|
||||
is a source of hard-to-track-down bugs [GH-700]
|
||||
* credential/token: Allow creating orphan tokens via an API path [GH-748]
|
||||
|
|
|
@ -35,9 +35,6 @@ var (
|
|||
|
||||
// enableAudit is used to enable a new audit backend
|
||||
func (c *Core) enableAudit(entry *MountEntry) error {
|
||||
c.audit.Lock()
|
||||
defer c.audit.Unlock()
|
||||
|
||||
// Ensure we end the path in a slash
|
||||
if !strings.HasSuffix(entry.Path, "/") {
|
||||
entry.Path += "/"
|
||||
|
@ -48,6 +45,10 @@ func (c *Core) enableAudit(entry *MountEntry) error {
|
|||
return fmt.Errorf("backend path must be specified")
|
||||
}
|
||||
|
||||
// Update the audit table
|
||||
c.auditLock.Lock()
|
||||
defer c.auditLock.Unlock()
|
||||
|
||||
// Look for matching name
|
||||
for _, ent := range c.audit.Entries {
|
||||
switch {
|
||||
|
@ -70,12 +71,12 @@ func (c *Core) enableAudit(entry *MountEntry) error {
|
|||
return err
|
||||
}
|
||||
|
||||
// Update the audit table
|
||||
newTable := c.audit.ShallowClone()
|
||||
newTable.Entries = append(newTable.Entries, entry)
|
||||
if err := c.persistAudit(newTable); err != nil {
|
||||
return errors.New("failed to update audit table")
|
||||
}
|
||||
|
||||
c.audit = newTable
|
||||
|
||||
// Register the backend
|
||||
|
@ -87,15 +88,15 @@ func (c *Core) enableAudit(entry *MountEntry) error {
|
|||
|
||||
// disableAudit is used to disable an existing audit backend
|
||||
func (c *Core) disableAudit(path string) error {
|
||||
c.audit.Lock()
|
||||
defer c.audit.Unlock()
|
||||
|
||||
// Ensure we end the path in a slash
|
||||
if !strings.HasSuffix(path, "/") {
|
||||
path += "/"
|
||||
}
|
||||
|
||||
// Remove the entry from the mount table
|
||||
c.auditLock.Lock()
|
||||
defer c.auditLock.Unlock()
|
||||
|
||||
newTable := c.audit.ShallowClone()
|
||||
found := newTable.Remove(path)
|
||||
|
||||
|
@ -108,6 +109,7 @@ func (c *Core) disableAudit(path string) error {
|
|||
if err := c.persistAudit(newTable); err != nil {
|
||||
return errors.New("failed to update audit table")
|
||||
}
|
||||
|
||||
c.audit = newTable
|
||||
|
||||
// Unmount the backend
|
||||
|
@ -118,18 +120,24 @@ func (c *Core) disableAudit(path string) error {
|
|||
|
||||
// loadAudits is invoked as part of postUnseal to load the audit table
|
||||
func (c *Core) loadAudits() error {
|
||||
auditTable := &MountTable{}
|
||||
|
||||
// Load the existing audit table
|
||||
raw, err := c.barrier.Get(coreAuditConfigPath)
|
||||
if err != nil {
|
||||
c.logger.Printf("[ERR] core: failed to read audit table: %v", err)
|
||||
return errLoadAuditFailed
|
||||
}
|
||||
|
||||
c.auditLock.Lock()
|
||||
defer c.auditLock.Unlock()
|
||||
|
||||
if raw != nil {
|
||||
c.audit = &MountTable{}
|
||||
if err := json.Unmarshal(raw.Value, c.audit); err != nil {
|
||||
if err := json.Unmarshal(raw.Value, auditTable); err != nil {
|
||||
c.logger.Printf("[ERR] core: failed to decode audit table: %v", err)
|
||||
return errLoadAuditFailed
|
||||
}
|
||||
c.audit = auditTable
|
||||
}
|
||||
|
||||
// Done if we have restored the audit table
|
||||
|
@ -172,6 +180,10 @@ func (c *Core) persistAudit(table *MountTable) error {
|
|||
// initialize the audit backends
|
||||
func (c *Core) setupAudits() error {
|
||||
broker := NewAuditBroker(c.logger)
|
||||
|
||||
c.auditLock.Lock()
|
||||
defer c.auditLock.Unlock()
|
||||
|
||||
for _, entry := range c.audit.Entries {
|
||||
// Create a barrier view using the UUID
|
||||
view := NewBarrierView(c.barrier, auditBarrierPrefix+entry.UUID+"/")
|
||||
|
@ -195,6 +207,9 @@ func (c *Core) setupAudits() error {
|
|||
// teardownAudit is used before we seal the vault to reset the audit
|
||||
// backends to their unloaded state. This is reversed by loadAudits.
|
||||
func (c *Core) teardownAudits() error {
|
||||
c.auditLock.Lock()
|
||||
defer c.auditLock.Unlock()
|
||||
|
||||
c.audit = nil
|
||||
c.auditBroker = nil
|
||||
return nil
|
||||
|
|
|
@ -31,9 +31,6 @@ var (
|
|||
|
||||
// enableCredential is used to enable a new credential backend
|
||||
func (c *Core) enableCredential(entry *MountEntry) error {
|
||||
c.auth.Lock()
|
||||
defer c.auth.Unlock()
|
||||
|
||||
// Ensure we end the path in a slash
|
||||
if !strings.HasSuffix(entry.Path, "/") {
|
||||
entry.Path += "/"
|
||||
|
@ -44,6 +41,9 @@ func (c *Core) enableCredential(entry *MountEntry) error {
|
|||
return fmt.Errorf("backend path must be specified")
|
||||
}
|
||||
|
||||
c.authLock.Lock()
|
||||
defer c.authLock.Unlock()
|
||||
|
||||
// Look for matching name
|
||||
for _, ent := range c.auth.Entries {
|
||||
switch {
|
||||
|
@ -77,6 +77,7 @@ func (c *Core) enableCredential(entry *MountEntry) error {
|
|||
if err := c.persistAuth(newTable); err != nil {
|
||||
return errors.New("failed to update auth table")
|
||||
}
|
||||
|
||||
c.auth = newTable
|
||||
|
||||
// Mount the backend
|
||||
|
@ -91,9 +92,6 @@ func (c *Core) enableCredential(entry *MountEntry) error {
|
|||
|
||||
// disableCredential is used to disable an existing credential backend
|
||||
func (c *Core) disableCredential(path string) error {
|
||||
c.auth.Lock()
|
||||
defer c.auth.Unlock()
|
||||
|
||||
// Ensure we end the path in a slash
|
||||
if !strings.HasSuffix(path, "/") {
|
||||
path += "/"
|
||||
|
@ -111,6 +109,9 @@ func (c *Core) disableCredential(path string) error {
|
|||
return fmt.Errorf("no matching backend")
|
||||
}
|
||||
|
||||
c.authLock.Lock()
|
||||
defer c.authLock.Unlock()
|
||||
|
||||
// Mark the entry as tainted
|
||||
if err := c.taintCredEntry(path); err != nil {
|
||||
return err
|
||||
|
@ -156,15 +157,18 @@ func (c *Core) removeCredEntry(path string) error {
|
|||
if err := c.persistAuth(newTable); err != nil {
|
||||
return errors.New("failed to update auth table")
|
||||
}
|
||||
|
||||
c.auth = newTable
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// taintCredEntry is used to mark an entry in the auth table as tainted
|
||||
func (c *Core) taintCredEntry(path string) error {
|
||||
// Taint the entry from the auth table
|
||||
newTable := c.auth.ShallowClone()
|
||||
found := newTable.SetTaint(path, true)
|
||||
// We do this on the original since setting the taint operates
|
||||
// on the entries which a shallow clone shares anyways
|
||||
found := c.auth.SetTaint(path, true)
|
||||
|
||||
// Ensure there was a match
|
||||
if !found {
|
||||
|
@ -172,27 +176,32 @@ func (c *Core) taintCredEntry(path string) error {
|
|||
}
|
||||
|
||||
// Update the auth table
|
||||
if err := c.persistAuth(newTable); err != nil {
|
||||
if err := c.persistAuth(c.auth); err != nil {
|
||||
return errors.New("failed to update auth table")
|
||||
}
|
||||
c.auth = newTable
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadCredentials is invoked as part of postUnseal to load the auth table
|
||||
func (c *Core) loadCredentials() error {
|
||||
authTable := &MountTable{}
|
||||
// Load the existing mount table
|
||||
raw, err := c.barrier.Get(coreAuthConfigPath)
|
||||
if err != nil {
|
||||
c.logger.Printf("[ERR] core: failed to read auth table: %v", err)
|
||||
return errLoadAuthFailed
|
||||
}
|
||||
|
||||
c.authLock.Lock()
|
||||
defer c.authLock.Unlock()
|
||||
|
||||
if raw != nil {
|
||||
c.auth = &MountTable{}
|
||||
if err := json.Unmarshal(raw.Value, c.auth); err != nil {
|
||||
if err := json.Unmarshal(raw.Value, authTable); err != nil {
|
||||
c.logger.Printf("[ERR] core: failed to decode auth table: %v", err)
|
||||
return errLoadAuthFailed
|
||||
}
|
||||
c.auth = authTable
|
||||
}
|
||||
|
||||
// Done if we have restored the auth table
|
||||
|
@ -238,6 +247,10 @@ func (c *Core) setupCredentials() error {
|
|||
var backend logical.Backend
|
||||
var view *BarrierView
|
||||
var err error
|
||||
|
||||
c.authLock.Lock()
|
||||
defer c.authLock.Unlock()
|
||||
|
||||
for _, entry := range c.auth.Entries {
|
||||
// Create a barrier view using the UUID
|
||||
view = NewBarrierView(c.barrier, credentialBarrierPrefix+entry.UUID+"/")
|
||||
|
@ -279,6 +292,9 @@ func (c *Core) setupCredentials() error {
|
|||
// teardownCredentials is used before we seal the vault to reset the credential
|
||||
// backends to their unloaded state. This is reversed by loadCredentials.
|
||||
func (c *Core) teardownCredentials() error {
|
||||
c.authLock.Lock()
|
||||
defer c.authLock.Unlock()
|
||||
|
||||
c.auth = nil
|
||||
c.tokenStore = nil
|
||||
return nil
|
||||
|
|
|
@ -214,14 +214,26 @@ type Core struct {
|
|||
// configuration
|
||||
mounts *MountTable
|
||||
|
||||
// mountsLock is used to ensure that the mounts table does not
|
||||
// change underneath a calling function
|
||||
mountsLock sync.RWMutex
|
||||
|
||||
// auth is loaded after unseal since it is a protected
|
||||
// configuration
|
||||
auth *MountTable
|
||||
|
||||
// authLock is used to ensure that the auth table does not
|
||||
// change underneath a calling function
|
||||
authLock sync.RWMutex
|
||||
|
||||
// audit is loaded after unseal since it is a protected
|
||||
// configuration
|
||||
audit *MountTable
|
||||
|
||||
// auditLock is used to ensure that the audit table does not
|
||||
// change underneath a calling function
|
||||
auditLock sync.RWMutex
|
||||
|
||||
// auditBroker is used to ingest the audit events and fan
|
||||
// out into the configured audit backends
|
||||
auditBroker *AuditBroker
|
||||
|
|
|
@ -364,12 +364,13 @@ type SystemBackend struct {
|
|||
// handleMountTable handles the "mounts" endpoint to provide the mount table
|
||||
func (b *SystemBackend) handleMountTable(
|
||||
req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
b.Core.mounts.Lock()
|
||||
defer b.Core.mounts.Unlock()
|
||||
b.Core.mountsLock.RLock()
|
||||
defer b.Core.mountsLock.RUnlock()
|
||||
|
||||
resp := &logical.Response{
|
||||
Data: make(map[string]interface{}),
|
||||
}
|
||||
|
||||
for _, entry := range b.Core.mounts.Entries {
|
||||
info := map[string]interface{}{
|
||||
"type": entry.Type,
|
||||
|
@ -613,10 +614,13 @@ func (b *SystemBackend) handleMountTuneWrite(
|
|||
}
|
||||
|
||||
if newDefault != nil || newMax != nil {
|
||||
b.Core.mountsLock.Lock()
|
||||
if err := b.tuneMountTTLs(path, &mountEntry.Config, newDefault, newMax); err != nil {
|
||||
b.Core.mountsLock.Unlock()
|
||||
b.Backend.Logger().Printf("[ERR] sys: tune of path '%s' failed: %v", path, err)
|
||||
return handleError(err)
|
||||
}
|
||||
b.Core.mountsLock.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -673,8 +677,8 @@ func (b *SystemBackend) handleRevokePrefix(
|
|||
// handleAuthTable handles the "auth" endpoint to provide the auth table
|
||||
func (b *SystemBackend) handleAuthTable(
|
||||
req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
b.Core.auth.Lock()
|
||||
defer b.Core.auth.Unlock()
|
||||
b.Core.authLock.RLock()
|
||||
defer b.Core.authLock.RUnlock()
|
||||
|
||||
resp := &logical.Response{
|
||||
Data: make(map[string]interface{}),
|
||||
|
@ -802,8 +806,8 @@ func (b *SystemBackend) handlePolicyDelete(
|
|||
// handleAuditTable handles the "audit" endpoint to provide the audit table
|
||||
func (b *SystemBackend) handleAuditTable(
|
||||
req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
b.Core.audit.Lock()
|
||||
defer b.Core.audit.Unlock()
|
||||
b.Core.auditLock.RLock()
|
||||
defer b.Core.auditLock.RUnlock()
|
||||
|
||||
resp := &logical.Response{
|
||||
Data: make(map[string]interface{}),
|
||||
|
|
|
@ -6,7 +6,6 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/uuid"
|
||||
|
@ -56,9 +55,6 @@ var (
|
|||
|
||||
// MountTable is used to represent the internal mount table
|
||||
type MountTable struct {
|
||||
// This lock should be held whenever modifying the Entries field.
|
||||
sync.RWMutex
|
||||
|
||||
Entries []*MountEntry `json:"entries"`
|
||||
}
|
||||
|
||||
|
@ -157,9 +153,6 @@ func (e *MountEntry) Clone() *MountEntry {
|
|||
|
||||
// Mount is used to mount a new backend to the mount table.
|
||||
func (c *Core) mount(me *MountEntry) error {
|
||||
c.mounts.Lock()
|
||||
defer c.mounts.Unlock()
|
||||
|
||||
// Ensure we end the path in a slash
|
||||
if !strings.HasSuffix(me.Path, "/") {
|
||||
me.Path += "/"
|
||||
|
@ -184,6 +177,9 @@ func (c *Core) mount(me *MountEntry) error {
|
|||
return logical.CodedError(409, fmt.Sprintf("existing mount at %s", match))
|
||||
}
|
||||
|
||||
c.mountsLock.Lock()
|
||||
defer c.mountsLock.Unlock()
|
||||
|
||||
// Generate a new UUID and view
|
||||
me.UUID = uuid.GenerateUUID()
|
||||
view := NewBarrierView(c.barrier, backendBarrierPrefix+me.UUID+"/")
|
||||
|
@ -211,9 +207,6 @@ func (c *Core) mount(me *MountEntry) error {
|
|||
|
||||
// Unmount is used to unmount a path.
|
||||
func (c *Core) unmount(path string) error {
|
||||
c.mounts.Lock()
|
||||
defer c.mounts.Unlock()
|
||||
|
||||
// Ensure we end the path in a slash
|
||||
if !strings.HasSuffix(path, "/") {
|
||||
path += "/"
|
||||
|
@ -235,6 +228,9 @@ func (c *Core) unmount(path string) error {
|
|||
// Store the view for this backend
|
||||
view := c.router.MatchingStorageView(path)
|
||||
|
||||
c.mountsLock.Lock()
|
||||
defer c.mountsLock.Unlock()
|
||||
|
||||
// Mark the entry as tainted
|
||||
if err := c.taintMountEntry(path); err != nil {
|
||||
return err
|
||||
|
@ -283,29 +279,27 @@ func (c *Core) removeMountEntry(path string) error {
|
|||
if err := c.persistMounts(newTable); err != nil {
|
||||
return errors.New("failed to update mount table")
|
||||
}
|
||||
|
||||
c.mounts = newTable
|
||||
return nil
|
||||
}
|
||||
|
||||
// taintMountEntry is used to mark an entry in the mount table as tainted
|
||||
func (c *Core) taintMountEntry(path string) error {
|
||||
// Remove the entry from the mount table
|
||||
newTable := c.mounts.ShallowClone()
|
||||
newTable.SetTaint(path, true)
|
||||
// As modifying the taint of an entry affects shallow clones,
|
||||
// we simply use the original
|
||||
c.mounts.SetTaint(path, true)
|
||||
|
||||
// Update the mount table
|
||||
if err := c.persistMounts(newTable); err != nil {
|
||||
if err := c.persistMounts(c.mounts); err != nil {
|
||||
return errors.New("failed to update mount table")
|
||||
}
|
||||
c.mounts = newTable
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Remount is used to remount a path at a new mount point.
|
||||
func (c *Core) remount(src, dst string) error {
|
||||
c.mounts.Lock()
|
||||
defer c.mounts.Unlock()
|
||||
|
||||
// Ensure we end the path in a slash
|
||||
if !strings.HasSuffix(src, "/") {
|
||||
src += "/"
|
||||
|
@ -331,6 +325,9 @@ func (c *Core) remount(src, dst string) error {
|
|||
return fmt.Errorf("existing mount at '%s'", match)
|
||||
}
|
||||
|
||||
c.mountsLock.Lock()
|
||||
defer c.mountsLock.Unlock()
|
||||
|
||||
// Mark the entry as tainted
|
||||
if err := c.taintMountEntry(src); err != nil {
|
||||
return err
|
||||
|
@ -351,10 +348,8 @@ func (c *Core) remount(src, dst string) error {
|
|||
return err
|
||||
}
|
||||
|
||||
// Update the entry in the mount table
|
||||
newTable := c.mounts.ShallowClone()
|
||||
var ent *MountEntry
|
||||
for _, ent = range newTable.Entries {
|
||||
for _, ent = range c.mounts.Entries {
|
||||
if ent.Path == src {
|
||||
ent.Path = dst
|
||||
ent.Tainted = false
|
||||
|
@ -363,12 +358,11 @@ func (c *Core) remount(src, dst string) error {
|
|||
}
|
||||
|
||||
// Update the mount table
|
||||
if err := c.persistMounts(newTable); err != nil {
|
||||
if err := c.persistMounts(c.mounts); err != nil {
|
||||
ent.Path = src
|
||||
ent.Tainted = true
|
||||
return errors.New("failed to update mount table")
|
||||
}
|
||||
c.mounts = newTable
|
||||
|
||||
// Remount the backend
|
||||
if err := c.router.Remount(src, dst); err != nil {
|
||||
|
@ -386,18 +380,23 @@ func (c *Core) remount(src, dst string) error {
|
|||
|
||||
// loadMounts is invoked as part of postUnseal to load the mount table
|
||||
func (c *Core) loadMounts() error {
|
||||
mountTable := &MountTable{}
|
||||
// Load the existing mount table
|
||||
raw, err := c.barrier.Get(coreMountConfigPath)
|
||||
if err != nil {
|
||||
c.logger.Printf("[ERR] core: failed to read mount table: %v", err)
|
||||
return errLoadMountsFailed
|
||||
}
|
||||
|
||||
c.mountsLock.Lock()
|
||||
defer c.mountsLock.Unlock()
|
||||
|
||||
if raw != nil {
|
||||
c.mounts = &MountTable{}
|
||||
if err := json.Unmarshal(raw.Value, c.mounts); err != nil {
|
||||
if err := json.Unmarshal(raw.Value, mountTable); err != nil {
|
||||
c.logger.Printf("[ERR] core: failed to decode mount table: %v", err)
|
||||
return errLoadMountsFailed
|
||||
}
|
||||
c.mounts = mountTable
|
||||
}
|
||||
|
||||
// Ensure that required entries are loaded, or new ones
|
||||
|
@ -462,9 +461,13 @@ func (c *Core) persistMounts(table *MountTable) error {
|
|||
// setupMounts is invoked after we've loaded the mount table to
|
||||
// initialize the logical backends and setup the router
|
||||
func (c *Core) setupMounts() error {
|
||||
c.mountsLock.Lock()
|
||||
defer c.mountsLock.Unlock()
|
||||
|
||||
var backend logical.Backend
|
||||
var view *BarrierView
|
||||
var err error
|
||||
|
||||
for _, entry := range c.mounts.Entries {
|
||||
// Initialize the backend, special casing for system
|
||||
barrierPath := backendBarrierPrefix + entry.UUID + "/"
|
||||
|
@ -514,8 +517,12 @@ func (c *Core) setupMounts() error {
|
|||
// unloadMounts is used before we seal the vault to reset the mounts to
|
||||
// their unloaded state, calling Cleanup if defined. This is reversed by load and setup mounts.
|
||||
func (c *Core) unloadMounts() error {
|
||||
c.mountsLock.Lock()
|
||||
defer c.mountsLock.Unlock()
|
||||
|
||||
if c.mounts != nil {
|
||||
for _, e := range c.mounts.Entries {
|
||||
mountTable := c.mounts.ShallowClone()
|
||||
for _, e := range mountTable.Entries {
|
||||
prefix := e.Path
|
||||
b, ok := c.router.root.Get(prefix)
|
||||
if ok {
|
||||
|
@ -523,6 +530,7 @@ func (c *Core) unloadMounts() error {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
c.mounts = nil
|
||||
c.router = NewRouter()
|
||||
c.systemBarrierView = nil
|
||||
|
|
|
@ -29,7 +29,12 @@ const (
|
|||
// is in-flight at any given time within a single seal/unseal phase.
|
||||
type RollbackManager struct {
|
||||
logger *log.Logger
|
||||
mounts *MountTable
|
||||
|
||||
// This gives the current mount table, plus a RWMutex that is
|
||||
// locked for reading. It is up to the caller to RUnlock it
|
||||
// when done with the mount table
|
||||
mounts func() (*MountTable, *sync.RWMutex)
|
||||
|
||||
router *Router
|
||||
period time.Duration
|
||||
|
||||
|
@ -50,7 +55,7 @@ type rollbackState struct {
|
|||
}
|
||||
|
||||
// NewRollbackManager is used to create a new rollback manager
|
||||
func NewRollbackManager(logger *log.Logger, mounts *MountTable, router *Router) *RollbackManager {
|
||||
func NewRollbackManager(logger *log.Logger, mounts func() (*MountTable, *sync.RWMutex), router *Router) *RollbackManager {
|
||||
r := &RollbackManager{
|
||||
logger: logger,
|
||||
mounts: mounts,
|
||||
|
@ -101,12 +106,17 @@ func (m *RollbackManager) run() {
|
|||
|
||||
// triggerRollbacks is used to trigger the rollbacks across all the backends
|
||||
func (m *RollbackManager) triggerRollbacks() {
|
||||
m.mounts.RLock()
|
||||
defer m.mounts.RUnlock()
|
||||
m.inflightLock.Lock()
|
||||
defer m.inflightLock.Unlock()
|
||||
|
||||
for _, e := range m.mounts.Entries {
|
||||
mounts, mountsLock := m.mounts()
|
||||
mountsLock.RLock()
|
||||
defer mountsLock.RUnlock()
|
||||
if mounts == nil {
|
||||
return
|
||||
}
|
||||
|
||||
for _, e := range mounts.Entries {
|
||||
if _, ok := m.inflight[e.Path]; !ok {
|
||||
m.startRollback(e.Path)
|
||||
}
|
||||
|
@ -127,6 +137,8 @@ func (m *RollbackManager) startRollback(path string) *rollbackState {
|
|||
// attemptRollback invokes a RollbackOperation for the given path
|
||||
func (m *RollbackManager) attemptRollback(path string, rs *rollbackState) (err error) {
|
||||
defer metrics.MeasureSince([]string{"rollback", "attempt", strings.Replace(path, "/", "-", -1)}, time.Now())
|
||||
m.logger.Printf("[DEBUG] rollback: attempting rollback on %s", path)
|
||||
|
||||
defer func() {
|
||||
rs.lastError = err
|
||||
rs.Done()
|
||||
|
@ -177,7 +189,10 @@ func (m *RollbackManager) Rollback(path string) error {
|
|||
|
||||
// startRollback is used to start the rollback manager after unsealing
|
||||
func (c *Core) startRollback() error {
|
||||
c.rollback = NewRollbackManager(c.logger, c.mounts, c.router)
|
||||
mountsFunc := func() (*MountTable, *sync.RWMutex) {
|
||||
return c.mounts, &c.mountsLock
|
||||
}
|
||||
c.rollback = NewRollbackManager(c.logger, mountsFunc, c.router)
|
||||
c.rollback.Start()
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -25,8 +25,13 @@ func mockRollback(t *testing.T) (*RollbackManager, *NoopBackend) {
|
|||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
mountsMutex := &sync.RWMutex{}
|
||||
mountsFunc := func() (*MountTable, *sync.RWMutex) {
|
||||
return mounts, mountsMutex
|
||||
}
|
||||
|
||||
logger := log.New(os.Stderr, "", log.LstdFlags)
|
||||
rb := NewRollbackManager(logger, mounts, router)
|
||||
rb := NewRollbackManager(logger, mountsFunc, router)
|
||||
rb.period = 10 * time.Millisecond
|
||||
return rb, backend
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue