vault: rollback supports joining an inflight operation
This commit is contained in:
parent
c3aed5589e
commit
f231a6c67d
|
@ -2,12 +2,17 @@ package vault
|
|||
|
||||
import (
|
||||
"log"
|
||||
"sync/atomic"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/logical"
|
||||
)
|
||||
|
||||
const (
|
||||
// rollbackPeriod is how often we attempt rollbacks for all the backends
|
||||
rollbackPeriod = time.Minute
|
||||
)
|
||||
|
||||
// RollbackManager is responsible for performing rollbacks of partial
|
||||
// secrets within logical backends.
|
||||
//
|
||||
|
@ -17,147 +22,169 @@ import (
|
|||
// This manager handles that by periodically (on a timer) requesting that the
|
||||
// backends clean up.
|
||||
//
|
||||
// The RollbackManager periodically (according to the Period option)
|
||||
// initiates a logical.RollbackOperation on every mounted logical backend.
|
||||
// It ensures that only one rollback operation is in-flight at any given
|
||||
// time within a single seal/unseal phase.
|
||||
// The RollbackManager periodically initiates a logical.RollbackOperation
|
||||
// on every mounted logical backend. It ensures that only one rollback operation
|
||||
// is in-flight at any given time within a single seal/unseal phase.
|
||||
type RollbackManager struct {
|
||||
// NOTE: This must always be at the top of the struct to avoid
|
||||
// atomic alignment issues. Go bug.
|
||||
running uint32
|
||||
logger *log.Logger
|
||||
mounts *MountTable
|
||||
router *Router
|
||||
period time.Duration
|
||||
|
||||
Logger *log.Logger
|
||||
Mounts *MountTable
|
||||
Router *Router
|
||||
inflightAll sync.WaitGroup
|
||||
inflight map[string]*rollbackState
|
||||
inflightLock sync.Mutex
|
||||
|
||||
Period time.Duration // time between rollback calls
|
||||
doneCh chan struct{}
|
||||
shutdown bool
|
||||
shutdownCh chan struct{}
|
||||
shutdownLock sync.Mutex
|
||||
}
|
||||
|
||||
// Start starts the rollback manager. This will block until Stop is called
|
||||
// so it should be executed within a goroutine.
|
||||
func (m *RollbackManager) Start() {
|
||||
// If we're already running, then don't start again
|
||||
if !atomic.CompareAndSwapUint32(&m.running, 0, 1) {
|
||||
return
|
||||
// rollbackState is used to track the state of a single rollback attempt
|
||||
type rollbackState struct {
|
||||
lastError error
|
||||
sync.WaitGroup
|
||||
}
|
||||
|
||||
// NewRollbackManager is used to create a new rollback manager
|
||||
func NewRollbackManager(logger *log.Logger, mounts *MountTable, router *Router) *RollbackManager {
|
||||
r := &RollbackManager{
|
||||
logger: logger,
|
||||
mounts: mounts,
|
||||
router: router,
|
||||
period: rollbackPeriod,
|
||||
inflight: make(map[string]*rollbackState),
|
||||
doneCh: make(chan struct{}),
|
||||
shutdownCh: make(chan struct{}),
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
m.Logger.Printf("[INFO] rollback: starting rollback manager")
|
||||
// Start starts the rollback manager
|
||||
func (m *RollbackManager) Start() {
|
||||
go m.run()
|
||||
}
|
||||
|
||||
// mounts is a mapping of a mount path (i.e. "sys") to a uint32 pointer
|
||||
// we can do atomic operations on. The purpose of this map is to ensure
|
||||
// we only ever have one RollbackOperation request in-flight for each
|
||||
// path.
|
||||
//
|
||||
// When a RollbackOperation is started, the pointer is changed to 0 to 1
|
||||
// atomically. When the operation completes, it is atomatically loaded
|
||||
// to 0 (from anything). Before we start a rollback operation, we use a
|
||||
// CAS 0 to 1 and only start a rollback if that succeeds.
|
||||
//
|
||||
// As a result, we only ever get one in-flight request at one time.
|
||||
var mounts map[string]*uint32
|
||||
// Stop stops the running manager. This will wait for any in-flight
|
||||
// rollbacks to complete.
|
||||
func (m *RollbackManager) Stop() {
|
||||
m.shutdownLock.Lock()
|
||||
defer m.shutdownLock.Unlock()
|
||||
if !m.shutdown {
|
||||
m.shutdown = true
|
||||
close(m.shutdownCh)
|
||||
<-m.doneCh
|
||||
}
|
||||
m.inflightAll.Wait()
|
||||
}
|
||||
|
||||
tick := time.NewTicker(m.Period)
|
||||
// run is a long running routine to periodically invoke rollback
|
||||
func (m *RollbackManager) run() {
|
||||
m.logger.Printf("[INFO] rollback: starting rollback manager")
|
||||
tick := time.NewTicker(m.period)
|
||||
defer tick.Stop()
|
||||
defer close(m.doneCh)
|
||||
for {
|
||||
// Wait for the tick
|
||||
<-tick.C
|
||||
select {
|
||||
case <-tick.C:
|
||||
m.triggerRollbacks()
|
||||
|
||||
// If we're quitting, then stop
|
||||
if atomic.LoadUint32(&m.running) != 1 {
|
||||
m.Logger.Printf("[INFO] rollback: stopping rollback manager")
|
||||
case <-m.shutdownCh:
|
||||
m.logger.Printf("[INFO] rollback: stopping rollback manager")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Get the list of paths that we should rollback and setup our
|
||||
// mounts mapping. Mounts that have since been unmounted will
|
||||
// just "fall off" naturally: they aren't in our new mount mapping
|
||||
// and when their goroutine ends they'll naturally lose the reference.
|
||||
//
|
||||
// The reason we make a new mapping is so that unmounted paths
|
||||
// are automatically removed. If a mount path was in the last mapping
|
||||
// we copy the uint32 pointer. So the result of the copy is: new
|
||||
// mount paths get a new uint32 pointer, unmounted paths are removed
|
||||
// from the map, and existing mount paths nothing changes.
|
||||
//
|
||||
// The purpose of the map is documented above where mounts is defined.
|
||||
newMounts := make(map[string]*uint32)
|
||||
m.Mounts.RLock()
|
||||
for _, e := range m.Mounts.Entries {
|
||||
if s, ok := mounts[e.Path]; ok {
|
||||
newMounts[e.Path] = s
|
||||
} else {
|
||||
newMounts[e.Path] = new(uint32)
|
||||
}
|
||||
}
|
||||
m.Mounts.RUnlock()
|
||||
mounts = newMounts
|
||||
// triggerRollbacks is used to trigger the rollbacks across all the backends
|
||||
func (m *RollbackManager) triggerRollbacks() {
|
||||
m.mounts.RLock()
|
||||
defer m.mounts.RUnlock()
|
||||
m.inflightLock.Lock()
|
||||
defer m.inflightLock.Unlock()
|
||||
|
||||
// Go through the mounts and start the rollback if we can
|
||||
for path, status := range mounts {
|
||||
// If we can change the status from 0 to 1, we can start it
|
||||
if !atomic.CompareAndSwapUint32(status, 0, 1) {
|
||||
continue
|
||||
}
|
||||
|
||||
go m.rollback(path, status)
|
||||
for _, e := range m.mounts.Entries {
|
||||
if _, ok := m.inflight[e.Path]; !ok {
|
||||
m.startRollback(e.Path)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Stop stops the running manager. This will not halt any in-flight
|
||||
// rollbacks.
|
||||
func (m *RollbackManager) Stop() {
|
||||
atomic.StoreUint32(&m.running, 0)
|
||||
// startRollback is used to start an async rollback attempt.
|
||||
// This must be called with the inflightLock held.
|
||||
func (m *RollbackManager) startRollback(path string) *rollbackState {
|
||||
rs := &rollbackState{}
|
||||
rs.Add(1)
|
||||
m.inflightAll.Add(1)
|
||||
m.inflight[path] = rs
|
||||
go m.attemptRollback(path, rs)
|
||||
return rs
|
||||
}
|
||||
|
||||
func (m *RollbackManager) rollback(path string, state *uint32) {
|
||||
defer atomic.StoreUint32(state, 0)
|
||||
// attemptRollback invokes a RollbackOperation for the given path
|
||||
func (m *RollbackManager) attemptRollback(path string, rs *rollbackState) (err error) {
|
||||
defer func() {
|
||||
rs.lastError = err
|
||||
rs.Done()
|
||||
m.inflightAll.Done()
|
||||
m.inflightLock.Lock()
|
||||
delete(m.inflight, path)
|
||||
m.inflightLock.Unlock()
|
||||
}()
|
||||
|
||||
m.Logger.Printf(
|
||||
"[DEBUG] rollback: starting rollback for %s",
|
||||
path)
|
||||
// Invoke a RollbackOperation
|
||||
m.logger.Printf("[DEBUG] rollback: starting rollback for %s", path)
|
||||
req := &logical.Request{
|
||||
Operation: logical.RollbackOperation,
|
||||
Path: path,
|
||||
}
|
||||
if _, err := m.Router.Route(req); err != nil {
|
||||
// If the error is an unsupported operation, then it doesn't
|
||||
// matter, the backend doesn't support it.
|
||||
if err == logical.ErrUnsupportedOperation {
|
||||
return
|
||||
}
|
||||
_, err = m.router.Route(req)
|
||||
|
||||
m.Logger.Printf(
|
||||
"[ERR] rollback: error rolling back %s: %s",
|
||||
// If the error is an unsupported operation, then it doesn't
|
||||
// matter, the backend doesn't support it.
|
||||
if err == logical.ErrUnsupportedOperation {
|
||||
err = nil
|
||||
}
|
||||
if err != nil {
|
||||
m.logger.Printf("[ERR] rollback: error rolling back %s: %s",
|
||||
path, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Rollback is used to trigger an immediate rollback of the path,
|
||||
// or to join an existing rollback operation if in flight.
|
||||
func (m *RollbackManager) Rollback(path string) error {
|
||||
// Check for an existing attempt and start one if none
|
||||
m.inflightLock.Lock()
|
||||
rs, ok := m.inflight[path]
|
||||
if !ok {
|
||||
rs = m.startRollback(path)
|
||||
}
|
||||
m.inflightLock.Unlock()
|
||||
|
||||
// Wait for the attempt to finish
|
||||
rs.Wait()
|
||||
|
||||
// Return the last error
|
||||
return rs.lastError
|
||||
}
|
||||
|
||||
// The methods below are the hooks from core that are called pre/post seal.
|
||||
|
||||
// startRollback is used to start the rollback manager after unsealing
|
||||
func (c *Core) startRollback() error {
|
||||
// Ensure if we had a rollback it was stopped. This should never
|
||||
// be the case but it doesn't hurt to check.
|
||||
if c.rollback != nil {
|
||||
c.rollback.Stop()
|
||||
}
|
||||
|
||||
c.rollback = &RollbackManager{
|
||||
Logger: c.logger,
|
||||
Router: c.router,
|
||||
Mounts: c.mounts,
|
||||
Period: 1 * time.Minute,
|
||||
}
|
||||
go c.rollback.Start()
|
||||
|
||||
c.rollback = NewRollbackManager(c.logger, c.mounts, c.router)
|
||||
c.rollback.Start()
|
||||
return nil
|
||||
}
|
||||
|
||||
// stopRollback is used to stop running the rollback manager before sealing
|
||||
func (c *Core) stopRollback() error {
|
||||
if c.rollback != nil {
|
||||
c.rollback.Stop()
|
||||
c.rollback = nil
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -3,6 +3,7 @@ package vault
|
|||
import (
|
||||
"log"
|
||||
"os"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
@ -23,12 +24,9 @@ func mockRollback(t *testing.T) (*RollbackManager, *NoopBackend) {
|
|||
}
|
||||
|
||||
logger := log.New(os.Stderr, "", log.LstdFlags)
|
||||
return &RollbackManager{
|
||||
Logger: logger,
|
||||
Mounts: mounts,
|
||||
Router: router,
|
||||
Period: 10 * time.Millisecond,
|
||||
}, backend
|
||||
rb := NewRollbackManager(logger, mounts, router)
|
||||
rb.period = 10 * time.Millisecond
|
||||
return rb, backend
|
||||
}
|
||||
|
||||
func TestRollbackManager(t *testing.T) {
|
||||
|
@ -37,8 +35,8 @@ func TestRollbackManager(t *testing.T) {
|
|||
t.Fatalf("bad: %#v", backend)
|
||||
}
|
||||
|
||||
go m.Start()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
m.Start()
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
m.Stop()
|
||||
|
||||
count := len(backend.Paths)
|
||||
|
@ -49,9 +47,47 @@ func TestRollbackManager(t *testing.T) {
|
|||
t.Fatalf("bad: %#v", backend)
|
||||
}
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
if count != len(backend.Paths) {
|
||||
t.Fatalf("should stop requests: %#v", backend)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRollbackManager_Join(t *testing.T) {
|
||||
m, backend := mockRollback(t)
|
||||
if len(backend.Paths) > 0 {
|
||||
t.Fatalf("bad: %#v", backend)
|
||||
}
|
||||
|
||||
m.Start()
|
||||
defer m.Stop()
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(3)
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
err := m.Rollback("foo")
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
err := m.Rollback("foo")
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
err := m.Rollback("foo")
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
}()
|
||||
wg.Wait()
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue