Adding interface methods to logical.Backend for parity (#2242)

This commit is contained in:
Armon Dadgar 2017-01-07 15:18:22 -08:00 committed by Jeff Mitchell
parent 336dfed5c3
commit c37d17ed47
12 changed files with 165 additions and 14 deletions

View File

@ -68,6 +68,9 @@ type Backend struct {
// to the backend, if required.
Clean CleanupFunc
// Invalidate is called when a keys is modified if required
Invalidate InvalidateFunc
// AuthRenew is the callback to call when a RenewRequest for an
// authentication comes in. By default, renewal won't be allowed.
// See the built-in AuthRenew helpers in lease.go for common callbacks.
@ -92,6 +95,9 @@ type WALRollbackFunc func(*logical.Request, string, interface{}) error
// CleanupFunc is the callback for backend unload.
type CleanupFunc func()
// InvalidateFunc is the callback for backend key invalidation.
type InvalidateFunc func(string)
func (b *Backend) HandleExistenceCheck(req *logical.Request) (checkFound bool, exists bool, err error) {
b.once.Do(b.init)
@ -218,12 +224,20 @@ func (b *Backend) Setup(config *logical.BackendConfig) (logical.Backend, error)
return b, nil
}
// Cleanup is used to release resources and prepare to stop the backend
func (b *Backend) Cleanup() {
if b.Clean != nil {
b.Clean()
}
}
// InvalidateKey is used to clear caches and reset internal state on key changes
func (b *Backend) InvalidateKey(key string) {
if b.Invalidate != nil {
b.Invalidate(key)
}
}
// Logger can be used to get the logger. If no logger has been set,
// the logs will be discarded.
func (b *Backend) Logger() log.Logger {

View File

@ -35,7 +35,14 @@ type Backend interface {
// existence check function was found, the item exists or not.
HandleExistenceCheck(*Request) (bool, bool, error)
// Cleanup is invoked during an unmount of a backend to allow it to
// handle any cleanup like connection closing or releasing of file handles.
Cleanup()
// InvalidateKey may be invoked when an object is modified that belongs
// to the backend. The backend can use this to clear any caches or reset
// internal state as needed.
InvalidateKey(key string)
}
// BackendConfig is provided to the factory to initialize the backend

View File

@ -1,12 +1,18 @@
package logical
import (
"errors"
"fmt"
"strings"
"github.com/hashicorp/vault/helper/jsonutil"
)
// ErrReadOnly is returned when a backend does not support
// writing. This can be caused by a read-only replica or secondary
// cluster operation.
var ErrReadOnly = errors.New("Cannot write to readonly storage")
// Storage is the way that logical backends are able read/write data.
type Storage interface {
List(prefix string) ([]string, error)

View File

@ -30,6 +30,11 @@ type SystemView interface {
// Returns true if caching is disabled. If true, no caches should be used,
// despite known slowdowns.
CachingDisabled() bool
// IsPrimary checks if this is a primary Vault instance. This
// can be used to avoid writes on secondaries and to avoid doing
// lazy upgrades which may cause writes.
IsPrimary() bool
}
type StaticSystemView struct {
@ -38,6 +43,7 @@ type StaticSystemView struct {
SudoPrivilegeVal bool
TaintedVal bool
CachingDisabledVal bool
Primary bool
}
func (d StaticSystemView) DefaultLeaseTTL() time.Duration {
@ -59,3 +65,7 @@ func (d StaticSystemView) Tainted() bool {
func (d StaticSystemView) CachingDisabled() bool {
return d.CachingDisabledVal
}
func (d StaticSystemView) IsPrimary() bool {
return d.Primary
}

View File

@ -15,8 +15,9 @@ import (
// BarrierView implements logical.Storage so it can be passed in as the
// durable storage mechanism for logical views.
type BarrierView struct {
barrier BarrierStorage
prefix string
barrier BarrierStorage
prefix string
readonly bool
}
// NewBarrierView takes an underlying security barrier and returns
@ -68,6 +69,9 @@ func (v *BarrierView) Get(key string) (*logical.StorageEntry, error) {
// logical.Storage impl.
func (v *BarrierView) Put(entry *logical.StorageEntry) error {
if v.readonly {
return logical.ErrReadOnly
}
if err := v.sanityCheck(entry.Key); err != nil {
return err
}
@ -80,6 +84,9 @@ func (v *BarrierView) Put(entry *logical.StorageEntry) error {
// logical.Storage impl.
func (v *BarrierView) Delete(key string) error {
if v.readonly {
return logical.ErrReadOnly
}
if err := v.sanityCheck(key); err != nil {
return err
}
@ -89,7 +96,7 @@ func (v *BarrierView) Delete(key string) error {
// SubView constructs a nested sub-view using the given prefix
func (v *BarrierView) SubView(prefix string) *BarrierView {
sub := v.expandKey(prefix)
return &BarrierView{barrier: v.barrier, prefix: sub}
return &BarrierView{barrier: v.barrier, prefix: sub, readonly: v.readonly}
}
// expandKey is used to expand to the full key path with the prefix

View File

@ -282,3 +282,35 @@ func TestBarrierView_ClearView(t *testing.T) {
t.Fatalf("have keys: %#v", out)
}
}
func TestBarrierView_Readonly(t *testing.T) {
_, barrier, _ := mockBarrier(t)
view := NewBarrierView(barrier, "foo/")
// Add a key before enabling read-only
entry := &Entry{Key: "test", Value: []byte("test")}
if err := view.Put(entry.Logical()); err != nil {
t.Fatalf("err: %v", err)
}
// Enable read only mode
view.readonly = true
// Put should fail in readonly mode
if err := view.Put(entry.Logical()); err != logical.ErrReadOnly {
t.Fatalf("err: %v", err)
}
// Delete nested
if err := view.Delete("test"); err != logical.ErrReadOnly {
t.Fatalf("err: %v", err)
}
// Check the non-nested key
e, err := view.Get("test")
if err != nil {
t.Fatalf("err: %v", err)
}
if e == nil {
t.Fatalf("key test missing")
}
}

View File

@ -0,0 +1,8 @@
// +build vault
package vault
// IsPrimary checks if this is a primary Vault instance.
func (d dynamicSystemView) IsPrimary() bool {
return true
}

View File

@ -17,6 +17,9 @@ func mockRollback(t *testing.T) (*RollbackManager, *NoopBackend) {
mounts := new(MountTable)
router := NewRouter()
_, barrier, _ := mockBarrier(t)
view := NewBarrierView(barrier, "logical/")
mounts.Entries = []*MountEntry{
&MountEntry{
Path: "foo",
@ -26,7 +29,7 @@ func mockRollback(t *testing.T) (*RollbackManager, *NoopBackend) {
if err != nil {
t.Fatal(err)
}
if err := router.Mount(backend, "foo", &MountEntry{UUID: meUUID}, nil); err != nil {
if err := router.Mount(backend, "foo", &MountEntry{UUID: meUUID}, view); err != nil {
t.Fatalf("err: %s", err)
}

View File

@ -17,12 +17,18 @@ type Router struct {
l sync.RWMutex
root *radix.Tree
tokenStoreSalt *salt.Salt
// storagePrefix maps the prefix used for storage (ala the BarrierView)
// to the backend. This is used to map a key back into the backend that owns it.
// For example, logical/uuid1/foobar -> secrets/ (generic backend) + foobar
storagePrefix *radix.Tree
}
// NewRouter returns a new router
func NewRouter() *Router {
r := &Router{
root: radix.New(),
root: radix.New(),
storagePrefix: radix.New(),
}
return r
}
@ -69,6 +75,7 @@ func (r *Router) Mount(backend logical.Backend, prefix string, mountEntry *Mount
loginPaths: pathsToRadix(paths.Unauthenticated),
}
r.root.Insert(prefix, re)
r.storagePrefix.Insert(storageView.prefix, re)
return nil
}
@ -78,12 +85,19 @@ func (r *Router) Unmount(prefix string) error {
r.l.Lock()
defer r.l.Unlock()
// Call backend's Cleanup routine
re, ok := r.root.Get(prefix)
if ok {
re.(*routeEntry).backend.Cleanup()
// Fast-path out if the backend doesn't exist
raw, ok := r.root.Get(prefix)
if !ok {
return nil
}
// Call backend's Cleanup routine
re := raw.(*routeEntry)
re.backend.Cleanup()
// Purge from the radix trees
r.root.Delete(prefix)
r.storagePrefix.Delete(re.storageView.prefix)
return nil
}
@ -182,6 +196,23 @@ func (r *Router) MatchingSystemView(path string) logical.SystemView {
return raw.(*routeEntry).backend.System()
}
// MatchingStoragePrefix returns the mount path matching and storage prefix
// matching the given path
func (r *Router) MatchingStoragePrefix(path string) (string, string, bool) {
r.l.RLock()
_, raw, ok := r.storagePrefix.LongestPrefix(path)
r.l.RUnlock()
if !ok {
return "", "", false
}
// Extract the mount path and storage prefix
re := raw.(*routeEntry)
mountPath := re.mountEntry.Path
prefix := re.storageView.prefix
return mountPath, prefix, true
}
// Route is used to route a given request
func (r *Router) Route(req *logical.Request) (*logical.Response, error) {
resp, _, _, err := r.routeCommon(req, false)

View File

@ -57,6 +57,10 @@ func (n *NoopBackend) Cleanup() {
// noop
}
func (n *NoopBackend) InvalidateKey(string) {
// noop
}
func TestRouter_Mount(t *testing.T) {
r := NewRouter()
_, barrier, _ := mockBarrier(t)
@ -67,7 +71,7 @@ func TestRouter_Mount(t *testing.T) {
t.Fatal(err)
}
n := &NoopBackend{}
err = r.Mount(n, "prod/aws/", &MountEntry{UUID: meUUID}, view)
err = r.Mount(n, "prod/aws/", &MountEntry{Path: "prod/aws/", UUID: meUUID}, view)
if err != nil {
t.Fatalf("err: %v", err)
}
@ -97,6 +101,14 @@ func TestRouter_Mount(t *testing.T) {
t.Fatalf("bad: %s", v)
}
mount, prefix, ok := r.MatchingStoragePrefix("logical/foo")
if !ok {
t.Fatalf("missing storage prefix")
}
if mount != "prod/aws/" || prefix != "logical/" {
t.Fatalf("Bad: %v - %v", mount, prefix)
}
req := &logical.Request{
Path: "prod/aws/foo",
}
@ -124,7 +136,7 @@ func TestRouter_Unmount(t *testing.T) {
t.Fatal(err)
}
n := &NoopBackend{}
err = r.Mount(n, "prod/aws/", &MountEntry{UUID: meUUID}, view)
err = r.Mount(n, "prod/aws/", &MountEntry{Path: "prod/aws/", UUID: meUUID}, view)
if err != nil {
t.Fatalf("err: %v", err)
}
@ -141,6 +153,10 @@ func TestRouter_Unmount(t *testing.T) {
if !strings.Contains(err.Error(), "unsupported path") {
t.Fatalf("err: %v", err)
}
if _, _, ok := r.MatchingStoragePrefix("logical/foo"); ok {
t.Fatalf("should not have matching storage prefix")
}
}
func TestRouter_Remount(t *testing.T) {
@ -153,11 +169,13 @@ func TestRouter_Remount(t *testing.T) {
t.Fatal(err)
}
n := &NoopBackend{}
err = r.Mount(n, "prod/aws/", &MountEntry{UUID: meUUID}, view)
me := &MountEntry{Path: "prod/aws/", UUID: meUUID}
err = r.Mount(n, "prod/aws/", me, view)
if err != nil {
t.Fatalf("err: %v", err)
}
me.Path = "stage/aws/"
err = r.Remount("prod/aws/", "stage/aws/")
if err != nil {
t.Fatalf("err: %v", err)
@ -188,6 +206,15 @@ func TestRouter_Remount(t *testing.T) {
if len(n.Paths) != 1 || n.Paths[0] != "foo" {
t.Fatalf("bad: %v", n.Paths)
}
// Check the resolve from storage still works
mount, prefix, _ := r.MatchingStoragePrefix("logical/foobar")
if mount != "stage/aws/" {
t.Fatalf("bad mount: %s", mount)
}
if prefix != "logical/" {
t.Fatalf("Bad prefix: %s", prefix)
}
}
func TestRouter_RootPath(t *testing.T) {

View File

@ -381,6 +381,10 @@ func (n *rawHTTP) Cleanup() {
// noop
}
func (n *rawHTTP) InvalidateKey(string) {
// noop
}
func GenerateRandBytes(length int) ([]byte, error) {
if length < 0 {
return nil, fmt.Errorf("length must be >= 0")

View File

@ -593,11 +593,13 @@ func TestTokenStore_Revoke(t *testing.T) {
}
func TestTokenStore_Revoke_Leases(t *testing.T) {
_, ts, _, _ := TestCoreWithTokenStore(t)
c, ts, _, _ := TestCoreWithTokenStore(t)
view := NewBarrierView(c.barrier, "noop/")
// Mount a noop backend
noop := &NoopBackend{}
ts.expiration.router.Mount(noop, "", &MountEntry{UUID: ""}, nil)
ts.expiration.router.Mount(noop, "", &MountEntry{UUID: ""}, view)
ent := &TokenEntry{Path: "test", Policies: []string{"dev", "ops"}}
if err := ts.create(ent); err != nil {