vault: RollbackManager
There are some major TODO items here, and it isn't hooked into the core yet, but the basic functionality is there.
This commit is contained in:
parent
abe0859aa5
commit
c7b9148841
|
@ -50,6 +50,7 @@ const (
|
|||
ListOperation = "list"
|
||||
RevokeOperation = "revoke"
|
||||
RenewOperation = "renew"
|
||||
RollbackOperation = "rollback"
|
||||
HelpOperation = "help"
|
||||
)
|
||||
|
||||
|
|
103
vault/rollback.go
Normal file
103
vault/rollback.go
Normal file
|
@ -0,0 +1,103 @@
|
|||
package vault
|
||||
|
||||
import (
|
||||
"log"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/logical"
|
||||
)
|
||||
|
||||
// RollbackManager is responsible for performing rollbacks of partial
|
||||
// secrets within logical backends.
|
||||
//
|
||||
// During normal operations, it is possible for logical backends to
|
||||
// error partially through an operation. These are called "partial secrets":
|
||||
// they are never sent back to a user, but they do need to be cleaned up.
|
||||
// This manager handles that by periodically (on a timer) requesting that the
|
||||
// backends clean up.
|
||||
//
|
||||
// The RollbackManager periodically (according to the Period option)
|
||||
// initiates a logical.RollbackOperation on every mounted logical backend.
|
||||
// It ensures that only one rollback operation is in-flight at any given
|
||||
// time within a single seal/unseal phase.
|
||||
type RollbackManager struct {
|
||||
Logger *log.Logger
|
||||
Mounts *MountTable
|
||||
Router *Router
|
||||
|
||||
Period time.Duration // time between rollback calls
|
||||
|
||||
running uint32
|
||||
}
|
||||
|
||||
// Start starts the rollback manager. This will block until Stop is called
|
||||
// so it should be executed within a goroutine.
|
||||
func (m *RollbackManager) Start() {
|
||||
// If we're already running, then don't start again
|
||||
if !atomic.CompareAndSwapUint32(&m.running, 0, 1) {
|
||||
return
|
||||
}
|
||||
|
||||
var mounts map[string]*uint32
|
||||
tick := time.NewTicker(m.Period)
|
||||
defer tick.Stop()
|
||||
for {
|
||||
// Wait for the tick
|
||||
<-tick.C
|
||||
|
||||
// If we're quitting, then stop
|
||||
if atomic.LoadUint32(&m.running) != 1 {
|
||||
return
|
||||
}
|
||||
|
||||
// Get the list of paths that we should rollback and setup our
|
||||
// mounts mapping. Mounts that have since been unmounted will
|
||||
// just "fall off" naturally: they aren't in our new mount mapping
|
||||
// and when their goroutine ends they'll naturally lose the reference.
|
||||
newMounts := make(map[string]*uint32)
|
||||
m.Mounts.RLock()
|
||||
for _, e := range m.Mounts.Entries {
|
||||
if s, ok := mounts[e.Path]; ok {
|
||||
newMounts[e.Path] = s
|
||||
} else {
|
||||
newMounts[e.Path] = new(uint32)
|
||||
}
|
||||
}
|
||||
m.Mounts.RUnlock()
|
||||
mounts = newMounts
|
||||
|
||||
// Go through the mounts and start the rollback if we can
|
||||
for path, status := range mounts {
|
||||
// If we can change the status from 0 to 1, we can start it
|
||||
if !atomic.CompareAndSwapUint32(status, 0, 1) {
|
||||
continue
|
||||
}
|
||||
|
||||
go m.rollback(path, status)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Stop stops the running manager. This will not halt any in-flight
|
||||
// rollbacks.
|
||||
func (m *RollbackManager) Stop() {
|
||||
atomic.StoreUint32(&m.running, 0)
|
||||
}
|
||||
|
||||
func (m *RollbackManager) rollback(path string, state *uint32) {
|
||||
defer atomic.StoreUint32(state, 0)
|
||||
|
||||
m.Logger.Printf(
|
||||
"[DEBUG] rollback: starting rollback for %s",
|
||||
path)
|
||||
req := &logical.Request{
|
||||
Operation: logical.RollbackOperation,
|
||||
Path: path,
|
||||
}
|
||||
if _, err := m.Router.Route(req); err != nil {
|
||||
m.Logger.Printf(
|
||||
"[ERR] rollback: error rolling back %s: %s",
|
||||
path, err)
|
||||
}
|
||||
}
|
57
vault/rollback_test.go
Normal file
57
vault/rollback_test.go
Normal file
|
@ -0,0 +1,57 @@
|
|||
package vault
|
||||
|
||||
import (
|
||||
"log"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// mockRollback returns a mock rollback manager
|
||||
func mockRollback(t *testing.T) (*RollbackManager, *NoopBackend) {
|
||||
backend := new(NoopBackend)
|
||||
mounts := new(MountTable)
|
||||
router := NewRouter()
|
||||
|
||||
mounts.Entries = []*MountEntry{
|
||||
&MountEntry{
|
||||
Path: "foo",
|
||||
},
|
||||
}
|
||||
if err := router.Mount(backend, "noop", "foo", nil); err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
logger := log.New(os.Stderr, "", log.LstdFlags)
|
||||
return &RollbackManager{
|
||||
Logger: logger,
|
||||
Mounts: mounts,
|
||||
Router: router,
|
||||
Period: 10 * time.Millisecond,
|
||||
}, backend
|
||||
}
|
||||
|
||||
func TestRollbackManager(t *testing.T) {
|
||||
m, backend := mockRollback(t)
|
||||
if len(backend.Paths) > 0 {
|
||||
t.Fatalf("bad: %#v", backend)
|
||||
}
|
||||
|
||||
go m.Start()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
m.Stop()
|
||||
|
||||
count := len(backend.Paths)
|
||||
if count == 0 {
|
||||
t.Fatalf("bad: %#v", backend)
|
||||
}
|
||||
if backend.Paths[0] != "" {
|
||||
t.Fatalf("bad: %#v", backend)
|
||||
}
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
if count != len(backend.Paths) {
|
||||
t.Fatalf("should stop requests: %#v", backend)
|
||||
}
|
||||
}
|
Loading…
Reference in a new issue