Rejig locks during unmount/remount. (#1855)

This commit is contained in:
Jeff Mitchell 2016-09-13 11:50:14 -04:00 committed by GitHub
parent ac5ea8ccc2
commit fffee5611a
4 changed files with 42 additions and 46 deletions

View File

@ -82,7 +82,7 @@ func (c *Core) enableAudit(entry *MountEntry) error {
return err
}
newTable := c.audit.ShallowClone()
newTable := c.audit.shallowClone()
newTable.Entries = append(newTable.Entries, entry)
if err := c.persistAudit(newTable); err != nil {
return errors.New("failed to update audit table")
@ -109,8 +109,8 @@ func (c *Core) disableAudit(path string) error {
c.auditLock.Lock()
defer c.auditLock.Unlock()
newTable := c.audit.ShallowClone()
found := newTable.Remove(path)
newTable := c.audit.shallowClone()
found := newTable.remove(path)
// Ensure there was a match
if !found {

View File

@ -81,7 +81,7 @@ func (c *Core) enableCredential(entry *MountEntry) error {
}
// Update the auth table
newTable := c.auth.ShallowClone()
newTable := c.auth.shallowClone()
newTable.Entries = append(newTable.Entries, entry)
if err := c.persistAuth(newTable); err != nil {
return errors.New("failed to update auth table")
@ -162,8 +162,8 @@ func (c *Core) disableCredential(path string) error {
// removeCredEntry is used to remove an entry in the auth table
func (c *Core) removeCredEntry(path string) error {
// Taint the entry from the auth table
newTable := c.auth.ShallowClone()
newTable.Remove(path)
newTable := c.auth.shallowClone()
newTable.remove(path)
// Update the auth table
if err := c.persistAuth(newTable); err != nil {
@ -180,7 +180,7 @@ func (c *Core) taintCredEntry(path string) error {
// Taint the entry from the auth table
// We do this on the original since setting the taint operates
// on the entries which a shallow clone shares anyways
found := c.auth.SetTaint(path, true)
found := c.auth.setTaint(path, true)
// Ensure there was a match
if !found {

View File

@ -64,11 +64,11 @@ type MountTable struct {
Entries []*MountEntry `json:"entries"`
}
// ShallowClone returns a copy of the mount table that
// shallowClone returns a copy of the mount table that
// keeps the MountEntry locations, so as not to invalidate
// other locations holding pointers. Care needs to be taken
// if modifying entries rather than modifying the table itself
func (t *MountTable) ShallowClone() *MountTable {
func (t *MountTable) shallowClone() *MountTable {
mt := &MountTable{
Type: t.Type,
Entries: make([]*MountEntry, len(t.Entries)),
@ -89,19 +89,8 @@ func (t *MountTable) Hash() ([]byte, error) {
return hash[:], nil
}
// Find is used to lookup an entry
func (t *MountTable) Find(path string) *MountEntry {
n := len(t.Entries)
for i := 0; i < n; i++ {
if t.Entries[i].Path == path {
return t.Entries[i]
}
}
return nil
}
// SetTaint is used to set the taint on given entry
func (t *MountTable) SetTaint(path string, value bool) bool {
// setTaint is used to set the taint on given entry
func (t *MountTable) setTaint(path string, value bool) bool {
n := len(t.Entries)
for i := 0; i < n; i++ {
if t.Entries[i].Path == path {
@ -112,8 +101,8 @@ func (t *MountTable) SetTaint(path string, value bool) bool {
return false
}
// Remove is used to remove a given path entry
func (t *MountTable) Remove(path string) bool {
// remove is used to remove a given path entry
func (t *MountTable) remove(path string) bool {
n := len(t.Entries)
for i := 0; i < n; i++ {
if t.Entries[i].Path == path {
@ -186,9 +175,6 @@ func (c *Core) mount(me *MountEntry) error {
return logical.CodedError(409, fmt.Sprintf("existing mount at %s", match))
}
c.mountsLock.Lock()
defer c.mountsLock.Unlock()
// Generate a new UUID and view
meUUID, err := uuid.GenerateUUID()
if err != nil {
@ -203,12 +189,15 @@ func (c *Core) mount(me *MountEntry) error {
}
// Update the mount table
newTable := c.mounts.ShallowClone()
c.mountsLock.Lock()
newTable := c.mounts.shallowClone()
newTable.Entries = append(newTable.Entries, me)
if err := c.persistMounts(newTable); err != nil {
c.mountsLock.Unlock()
return logical.CodedError(500, "failed to update mount table")
}
c.mounts = newTable
c.mountsLock.Unlock()
// Mount the backend
if err := c.router.Mount(backend, me.Path, me, view); err != nil {
@ -240,18 +229,16 @@ func (c *Core) unmount(path string) error {
return fmt.Errorf("no matching mount")
}
// Store the view for this backend
// Get the view for this backend
view := c.router.MatchingStorageView(path)
c.mountsLock.Lock()
defer c.mountsLock.Unlock()
// Mark the entry as tainted
if err := c.taintMountEntry(path); err != nil {
return err
}
// Taint the router path to prevent routing
// Taint the router path to prevent routing. Note that in-flight requests
// are uncertain, right now.
if err := c.router.Taint(path); err != nil {
return err
}
@ -288,9 +275,12 @@ func (c *Core) unmount(path string) error {
// removeMountEntry is used to remove an entry from the mount table
func (c *Core) removeMountEntry(path string) error {
c.mountsLock.Lock()
defer c.mountsLock.Unlock()
// Remove the entry from the mount table
newTable := c.mounts.ShallowClone()
newTable.Remove(path)
newTable := c.mounts.shallowClone()
newTable.remove(path)
// Update the mount table
if err := c.persistMounts(newTable); err != nil {
@ -303,9 +293,12 @@ func (c *Core) removeMountEntry(path string) error {
// taintMountEntry is used to mark an entry in the mount table as tainted
func (c *Core) taintMountEntry(path string) error {
c.mountsLock.Lock()
defer c.mountsLock.Unlock()
// As modifying the taint of an entry affects shallow clones,
// we simply use the original
c.mounts.SetTaint(path, true)
c.mounts.setTaint(path, true)
// Update the mount table
if err := c.persistMounts(c.mounts); err != nil {
@ -342,9 +335,6 @@ func (c *Core) remount(src, dst string) error {
return fmt.Errorf("existing mount at '%s'", match)
}
c.mountsLock.Lock()
defer c.mountsLock.Unlock()
// Mark the entry as tainted
if err := c.taintMountEntry(src); err != nil {
return err
@ -365,6 +355,7 @@ func (c *Core) remount(src, dst string) error {
return err
}
c.mountsLock.Lock()
var ent *MountEntry
for _, ent = range c.mounts.Entries {
if ent.Path == src {
@ -378,8 +369,10 @@ func (c *Core) remount(src, dst string) error {
if err := c.persistMounts(c.mounts); err != nil {
ent.Path = src
ent.Tainted = true
c.mountsLock.Unlock()
return logical.CodedError(500, "failed to update mount table")
}
c.mountsLock.Unlock()
// Remount the backend
if err := c.router.Remount(src, dst); err != nil {
@ -570,7 +563,7 @@ func (c *Core) unloadMounts() error {
defer c.mountsLock.Unlock()
if c.mounts != nil {
mountTable := c.mounts.ShallowClone()
mountTable := c.mounts.shallowClone()
for _, e := range mountTable.Entries {
prefix := e.Path
b, ok := c.router.root.Get(prefix)

View File

@ -41,7 +41,7 @@ type RollbackManager struct {
inflightAll sync.WaitGroup
inflight map[string]*rollbackState
inflightLock sync.Mutex
inflightLock sync.RWMutex
doneCh chan struct{}
shutdown bool
@ -107,8 +107,6 @@ func (m *RollbackManager) run() {
// triggerRollbacks is used to trigger the rollbacks across all the backends
func (m *RollbackManager) triggerRollbacks() {
m.inflightLock.Lock()
defer m.inflightLock.Unlock()
backends := m.backends()
@ -117,7 +115,10 @@ func (m *RollbackManager) triggerRollbacks() {
if e.Table == credentialTableType {
path = "auth/" + path
}
if _, ok := m.inflight[path]; !ok {
m.inflightLock.RLock()
_, ok := m.inflight[path]
m.inflightLock.RUnlock()
if !ok {
m.startRollback(path)
}
}
@ -129,7 +130,9 @@ func (m *RollbackManager) startRollback(path string) *rollbackState {
rs := &rollbackState{}
rs.Add(1)
m.inflightAll.Add(1)
m.inflightLock.Lock()
m.inflight[path] = rs
m.inflightLock.Unlock()
go m.attemptRollback(path, rs)
return rs
}
@ -172,12 +175,12 @@ func (m *RollbackManager) attemptRollback(path string, rs *rollbackState) (err e
// or to join an existing rollback operation if in flight.
func (m *RollbackManager) Rollback(path string) error {
// Check for an existing attempt and start one if none
m.inflightLock.Lock()
m.inflightLock.RLock()
rs, ok := m.inflight[path]
m.inflightLock.RUnlock()
if !ok {
rs = m.startRollback(path)
}
m.inflightLock.Unlock()
// Wait for the attempt to finish
rs.Wait()