Adding interface methods to logical.Backend for parity (#2242)

This commit is contained in:
Armon Dadgar 2017-01-07 15:18:22 -08:00 committed by Jeff Mitchell
parent 336dfed5c3
commit c37d17ed47
12 changed files with 165 additions and 14 deletions

View File

@ -68,6 +68,9 @@ type Backend struct {
// to the backend, if required. // to the backend, if required.
Clean CleanupFunc Clean CleanupFunc
// Invalidate is called when a keys is modified if required
Invalidate InvalidateFunc
// AuthRenew is the callback to call when a RenewRequest for an // AuthRenew is the callback to call when a RenewRequest for an
// authentication comes in. By default, renewal won't be allowed. // authentication comes in. By default, renewal won't be allowed.
// See the built-in AuthRenew helpers in lease.go for common callbacks. // See the built-in AuthRenew helpers in lease.go for common callbacks.
@ -92,6 +95,9 @@ type WALRollbackFunc func(*logical.Request, string, interface{}) error
// CleanupFunc is the callback for backend unload. // CleanupFunc is the callback for backend unload.
type CleanupFunc func() type CleanupFunc func()
// InvalidateFunc is the callback for backend key invalidation.
type InvalidateFunc func(string)
func (b *Backend) HandleExistenceCheck(req *logical.Request) (checkFound bool, exists bool, err error) { func (b *Backend) HandleExistenceCheck(req *logical.Request) (checkFound bool, exists bool, err error) {
b.once.Do(b.init) b.once.Do(b.init)
@ -218,12 +224,20 @@ func (b *Backend) Setup(config *logical.BackendConfig) (logical.Backend, error)
return b, nil return b, nil
} }
// Cleanup is used to release resources and prepare to stop the backend
func (b *Backend) Cleanup() { func (b *Backend) Cleanup() {
if b.Clean != nil { if b.Clean != nil {
b.Clean() b.Clean()
} }
} }
// InvalidateKey is used to clear caches and reset internal state on key changes
func (b *Backend) InvalidateKey(key string) {
if b.Invalidate != nil {
b.Invalidate(key)
}
}
// Logger can be used to get the logger. If no logger has been set, // Logger can be used to get the logger. If no logger has been set,
// the logs will be discarded. // the logs will be discarded.
func (b *Backend) Logger() log.Logger { func (b *Backend) Logger() log.Logger {

View File

@ -35,7 +35,14 @@ type Backend interface {
// existence check function was found, the item exists or not. // existence check function was found, the item exists or not.
HandleExistenceCheck(*Request) (bool, bool, error) HandleExistenceCheck(*Request) (bool, bool, error)
// Cleanup is invoked during an unmount of a backend to allow it to
// handle any cleanup like connection closing or releasing of file handles.
Cleanup() Cleanup()
// InvalidateKey may be invoked when an object is modified that belongs
// to the backend. The backend can use this to clear any caches or reset
// internal state as needed.
InvalidateKey(key string)
} }
// BackendConfig is provided to the factory to initialize the backend // BackendConfig is provided to the factory to initialize the backend

View File

@ -1,12 +1,18 @@
package logical package logical
import ( import (
"errors"
"fmt" "fmt"
"strings" "strings"
"github.com/hashicorp/vault/helper/jsonutil" "github.com/hashicorp/vault/helper/jsonutil"
) )
// ErrReadOnly is returned when a backend does not support
// writing. This can be caused by a read-only replica or secondary
// cluster operation.
var ErrReadOnly = errors.New("Cannot write to readonly storage")
// Storage is the way that logical backends are able read/write data. // Storage is the way that logical backends are able read/write data.
type Storage interface { type Storage interface {
List(prefix string) ([]string, error) List(prefix string) ([]string, error)

View File

@ -30,6 +30,11 @@ type SystemView interface {
// Returns true if caching is disabled. If true, no caches should be used, // Returns true if caching is disabled. If true, no caches should be used,
// despite known slowdowns. // despite known slowdowns.
CachingDisabled() bool CachingDisabled() bool
// IsPrimary checks if this is a primary Vault instance. This
// can be used to avoid writes on secondaries and to avoid doing
// lazy upgrades which may cause writes.
IsPrimary() bool
} }
type StaticSystemView struct { type StaticSystemView struct {
@ -38,6 +43,7 @@ type StaticSystemView struct {
SudoPrivilegeVal bool SudoPrivilegeVal bool
TaintedVal bool TaintedVal bool
CachingDisabledVal bool CachingDisabledVal bool
Primary bool
} }
func (d StaticSystemView) DefaultLeaseTTL() time.Duration { func (d StaticSystemView) DefaultLeaseTTL() time.Duration {
@ -59,3 +65,7 @@ func (d StaticSystemView) Tainted() bool {
func (d StaticSystemView) CachingDisabled() bool { func (d StaticSystemView) CachingDisabled() bool {
return d.CachingDisabledVal return d.CachingDisabledVal
} }
func (d StaticSystemView) IsPrimary() bool {
return d.Primary
}

View File

@ -15,8 +15,9 @@ import (
// BarrierView implements logical.Storage so it can be passed in as the // BarrierView implements logical.Storage so it can be passed in as the
// durable storage mechanism for logical views. // durable storage mechanism for logical views.
type BarrierView struct { type BarrierView struct {
barrier BarrierStorage barrier BarrierStorage
prefix string prefix string
readonly bool
} }
// NewBarrierView takes an underlying security barrier and returns // NewBarrierView takes an underlying security barrier and returns
@ -68,6 +69,9 @@ func (v *BarrierView) Get(key string) (*logical.StorageEntry, error) {
// logical.Storage impl. // logical.Storage impl.
func (v *BarrierView) Put(entry *logical.StorageEntry) error { func (v *BarrierView) Put(entry *logical.StorageEntry) error {
if v.readonly {
return logical.ErrReadOnly
}
if err := v.sanityCheck(entry.Key); err != nil { if err := v.sanityCheck(entry.Key); err != nil {
return err return err
} }
@ -80,6 +84,9 @@ func (v *BarrierView) Put(entry *logical.StorageEntry) error {
// logical.Storage impl. // logical.Storage impl.
func (v *BarrierView) Delete(key string) error { func (v *BarrierView) Delete(key string) error {
if v.readonly {
return logical.ErrReadOnly
}
if err := v.sanityCheck(key); err != nil { if err := v.sanityCheck(key); err != nil {
return err return err
} }
@ -89,7 +96,7 @@ func (v *BarrierView) Delete(key string) error {
// SubView constructs a nested sub-view using the given prefix // SubView constructs a nested sub-view using the given prefix
func (v *BarrierView) SubView(prefix string) *BarrierView { func (v *BarrierView) SubView(prefix string) *BarrierView {
sub := v.expandKey(prefix) sub := v.expandKey(prefix)
return &BarrierView{barrier: v.barrier, prefix: sub} return &BarrierView{barrier: v.barrier, prefix: sub, readonly: v.readonly}
} }
// expandKey is used to expand to the full key path with the prefix // expandKey is used to expand to the full key path with the prefix

View File

@ -282,3 +282,35 @@ func TestBarrierView_ClearView(t *testing.T) {
t.Fatalf("have keys: %#v", out) t.Fatalf("have keys: %#v", out)
} }
} }
func TestBarrierView_Readonly(t *testing.T) {
_, barrier, _ := mockBarrier(t)
view := NewBarrierView(barrier, "foo/")
// Add a key before enabling read-only
entry := &Entry{Key: "test", Value: []byte("test")}
if err := view.Put(entry.Logical()); err != nil {
t.Fatalf("err: %v", err)
}
// Enable read only mode
view.readonly = true
// Put should fail in readonly mode
if err := view.Put(entry.Logical()); err != logical.ErrReadOnly {
t.Fatalf("err: %v", err)
}
// Delete nested
if err := view.Delete("test"); err != logical.ErrReadOnly {
t.Fatalf("err: %v", err)
}
// Check the non-nested key
e, err := view.Get("test")
if err != nil {
t.Fatalf("err: %v", err)
}
if e == nil {
t.Fatalf("key test missing")
}
}

View File

@ -0,0 +1,8 @@
// +build vault
package vault
// IsPrimary checks if this is a primary Vault instance.
func (d dynamicSystemView) IsPrimary() bool {
return true
}

View File

@ -17,6 +17,9 @@ func mockRollback(t *testing.T) (*RollbackManager, *NoopBackend) {
mounts := new(MountTable) mounts := new(MountTable)
router := NewRouter() router := NewRouter()
_, barrier, _ := mockBarrier(t)
view := NewBarrierView(barrier, "logical/")
mounts.Entries = []*MountEntry{ mounts.Entries = []*MountEntry{
&MountEntry{ &MountEntry{
Path: "foo", Path: "foo",
@ -26,7 +29,7 @@ func mockRollback(t *testing.T) (*RollbackManager, *NoopBackend) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if err := router.Mount(backend, "foo", &MountEntry{UUID: meUUID}, nil); err != nil { if err := router.Mount(backend, "foo", &MountEntry{UUID: meUUID}, view); err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }

View File

@ -17,12 +17,18 @@ type Router struct {
l sync.RWMutex l sync.RWMutex
root *radix.Tree root *radix.Tree
tokenStoreSalt *salt.Salt tokenStoreSalt *salt.Salt
// storagePrefix maps the prefix used for storage (ala the BarrierView)
// to the backend. This is used to map a key back into the backend that owns it.
// For example, logical/uuid1/foobar -> secrets/ (generic backend) + foobar
storagePrefix *radix.Tree
} }
// NewRouter returns a new router // NewRouter returns a new router
func NewRouter() *Router { func NewRouter() *Router {
r := &Router{ r := &Router{
root: radix.New(), root: radix.New(),
storagePrefix: radix.New(),
} }
return r return r
} }
@ -69,6 +75,7 @@ func (r *Router) Mount(backend logical.Backend, prefix string, mountEntry *Mount
loginPaths: pathsToRadix(paths.Unauthenticated), loginPaths: pathsToRadix(paths.Unauthenticated),
} }
r.root.Insert(prefix, re) r.root.Insert(prefix, re)
r.storagePrefix.Insert(storageView.prefix, re)
return nil return nil
} }
@ -78,12 +85,19 @@ func (r *Router) Unmount(prefix string) error {
r.l.Lock() r.l.Lock()
defer r.l.Unlock() defer r.l.Unlock()
// Call backend's Cleanup routine // Fast-path out if the backend doesn't exist
re, ok := r.root.Get(prefix) raw, ok := r.root.Get(prefix)
if ok { if !ok {
re.(*routeEntry).backend.Cleanup() return nil
} }
// Call backend's Cleanup routine
re := raw.(*routeEntry)
re.backend.Cleanup()
// Purge from the radix trees
r.root.Delete(prefix) r.root.Delete(prefix)
r.storagePrefix.Delete(re.storageView.prefix)
return nil return nil
} }
@ -182,6 +196,23 @@ func (r *Router) MatchingSystemView(path string) logical.SystemView {
return raw.(*routeEntry).backend.System() return raw.(*routeEntry).backend.System()
} }
// MatchingStoragePrefix returns the mount path matching and storage prefix
// matching the given path
func (r *Router) MatchingStoragePrefix(path string) (string, string, bool) {
r.l.RLock()
_, raw, ok := r.storagePrefix.LongestPrefix(path)
r.l.RUnlock()
if !ok {
return "", "", false
}
// Extract the mount path and storage prefix
re := raw.(*routeEntry)
mountPath := re.mountEntry.Path
prefix := re.storageView.prefix
return mountPath, prefix, true
}
// Route is used to route a given request // Route is used to route a given request
func (r *Router) Route(req *logical.Request) (*logical.Response, error) { func (r *Router) Route(req *logical.Request) (*logical.Response, error) {
resp, _, _, err := r.routeCommon(req, false) resp, _, _, err := r.routeCommon(req, false)

View File

@ -57,6 +57,10 @@ func (n *NoopBackend) Cleanup() {
// noop // noop
} }
func (n *NoopBackend) InvalidateKey(string) {
// noop
}
func TestRouter_Mount(t *testing.T) { func TestRouter_Mount(t *testing.T) {
r := NewRouter() r := NewRouter()
_, barrier, _ := mockBarrier(t) _, barrier, _ := mockBarrier(t)
@ -67,7 +71,7 @@ func TestRouter_Mount(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
n := &NoopBackend{} n := &NoopBackend{}
err = r.Mount(n, "prod/aws/", &MountEntry{UUID: meUUID}, view) err = r.Mount(n, "prod/aws/", &MountEntry{Path: "prod/aws/", UUID: meUUID}, view)
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@ -97,6 +101,14 @@ func TestRouter_Mount(t *testing.T) {
t.Fatalf("bad: %s", v) t.Fatalf("bad: %s", v)
} }
mount, prefix, ok := r.MatchingStoragePrefix("logical/foo")
if !ok {
t.Fatalf("missing storage prefix")
}
if mount != "prod/aws/" || prefix != "logical/" {
t.Fatalf("Bad: %v - %v", mount, prefix)
}
req := &logical.Request{ req := &logical.Request{
Path: "prod/aws/foo", Path: "prod/aws/foo",
} }
@ -124,7 +136,7 @@ func TestRouter_Unmount(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
n := &NoopBackend{} n := &NoopBackend{}
err = r.Mount(n, "prod/aws/", &MountEntry{UUID: meUUID}, view) err = r.Mount(n, "prod/aws/", &MountEntry{Path: "prod/aws/", UUID: meUUID}, view)
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@ -141,6 +153,10 @@ func TestRouter_Unmount(t *testing.T) {
if !strings.Contains(err.Error(), "unsupported path") { if !strings.Contains(err.Error(), "unsupported path") {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
if _, _, ok := r.MatchingStoragePrefix("logical/foo"); ok {
t.Fatalf("should not have matching storage prefix")
}
} }
func TestRouter_Remount(t *testing.T) { func TestRouter_Remount(t *testing.T) {
@ -153,11 +169,13 @@ func TestRouter_Remount(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
n := &NoopBackend{} n := &NoopBackend{}
err = r.Mount(n, "prod/aws/", &MountEntry{UUID: meUUID}, view) me := &MountEntry{Path: "prod/aws/", UUID: meUUID}
err = r.Mount(n, "prod/aws/", me, view)
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
me.Path = "stage/aws/"
err = r.Remount("prod/aws/", "stage/aws/") err = r.Remount("prod/aws/", "stage/aws/")
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
@ -188,6 +206,15 @@ func TestRouter_Remount(t *testing.T) {
if len(n.Paths) != 1 || n.Paths[0] != "foo" { if len(n.Paths) != 1 || n.Paths[0] != "foo" {
t.Fatalf("bad: %v", n.Paths) t.Fatalf("bad: %v", n.Paths)
} }
// Check the resolve from storage still works
mount, prefix, _ := r.MatchingStoragePrefix("logical/foobar")
if mount != "stage/aws/" {
t.Fatalf("bad mount: %s", mount)
}
if prefix != "logical/" {
t.Fatalf("Bad prefix: %s", prefix)
}
} }
func TestRouter_RootPath(t *testing.T) { func TestRouter_RootPath(t *testing.T) {

View File

@ -381,6 +381,10 @@ func (n *rawHTTP) Cleanup() {
// noop // noop
} }
func (n *rawHTTP) InvalidateKey(string) {
// noop
}
func GenerateRandBytes(length int) ([]byte, error) { func GenerateRandBytes(length int) ([]byte, error) {
if length < 0 { if length < 0 {
return nil, fmt.Errorf("length must be >= 0") return nil, fmt.Errorf("length must be >= 0")

View File

@ -593,11 +593,13 @@ func TestTokenStore_Revoke(t *testing.T) {
} }
func TestTokenStore_Revoke_Leases(t *testing.T) { func TestTokenStore_Revoke_Leases(t *testing.T) {
_, ts, _, _ := TestCoreWithTokenStore(t) c, ts, _, _ := TestCoreWithTokenStore(t)
view := NewBarrierView(c.barrier, "noop/")
// Mount a noop backend // Mount a noop backend
noop := &NoopBackend{} noop := &NoopBackend{}
ts.expiration.router.Mount(noop, "", &MountEntry{UUID: ""}, nil) ts.expiration.router.Mount(noop, "", &MountEntry{UUID: ""}, view)
ent := &TokenEntry{Path: "test", Policies: []string{"dev", "ops"}} ent := &TokenEntry{Path: "test", Policies: []string{"dev", "ops"}}
if err := ts.create(ent); err != nil { if err := ts.create(ent); err != nil {