Added utility on router to fetch mount entry using its ID

This commit is contained in:
vishalnayak 2017-05-05 12:54:37 -04:00 committed by Jeff Mitchell
parent 81e95c62e3
commit 4fe7fc4ef9
2 changed files with 41 additions and 6 deletions

View File

@ -14,9 +14,10 @@ import (
// Router is used to do prefix based routing of a request to a logical backend
type Router struct {
l sync.RWMutex
root *radix.Tree
tokenStoreSalt *salt.Salt
l sync.RWMutex
root *radix.Tree
mountEntryCache *radix.Tree
tokenStoreSalt *salt.Salt
// storagePrefix maps the prefix used for storage (ala the BarrierView)
// to the backend. This is used to map a key back into the backend that owns it.
@ -27,8 +28,9 @@ type Router struct {
// NewRouter returns a new router
func NewRouter() *Router {
r := &Router{
root: radix.New(),
storagePrefix: radix.New(),
root: radix.New(),
storagePrefix: radix.New(),
mountEntryCache: radix.New(),
}
return r
}
@ -74,8 +76,10 @@ func (r *Router) Mount(backend logical.Backend, prefix string, mountEntry *Mount
rootPaths: pathsToRadix(paths.Root),
loginPaths: pathsToRadix(paths.Unauthenticated),
}
r.root.Insert(prefix, re)
r.storagePrefix.Insert(storageView.prefix, re)
r.mountEntryCache.Insert(re.mountEntry.UUID, re.mountEntry)
return nil
}
@ -98,6 +102,8 @@ func (r *Router) Unmount(prefix string) error {
// Purge from the radix trees
r.root.Delete(prefix)
r.storagePrefix.Delete(re.storageView.prefix)
r.mountEntryCache.Delete(re.mountEntry.UUID)
return nil
}
@ -141,6 +147,22 @@ func (r *Router) Untaint(path string) error {
return nil
}
func (r *Router) MatchingMountByID(mountID string) *MountEntry {
if mountID == "" {
return nil
}
r.l.RLock()
defer r.l.RUnlock()
_, raw, ok := r.mountEntryCache.LongestPrefix(mountID)
if !ok {
return nil
}
return raw.(*MountEntry)
}
// MatchingMount returns the mount prefix that would be used for a path
func (r *Router) MatchingMount(path string) string {
r.l.RLock()

View File

@ -2,6 +2,7 @@ package vault
import (
"fmt"
"reflect"
"strings"
"sync"
"testing"
@ -75,8 +76,14 @@ func TestRouter_Mount(t *testing.T) {
if err != nil {
t.Fatal(err)
}
mountEntry := &MountEntry{
Path: "prod/aws/",
UUID: meUUID,
}
n := &NoopBackend{}
err = r.Mount(n, "prod/aws/", &MountEntry{Path: "prod/aws/", UUID: meUUID}, view)
err = r.Mount(n, "prod/aws/", mountEntry, view)
if err != nil {
t.Fatalf("err: %v", err)
}
@ -85,6 +92,7 @@ func TestRouter_Mount(t *testing.T) {
if err != nil {
t.Fatal(err)
}
err = r.Mount(n, "prod/aws/", &MountEntry{UUID: meUUID}, view)
if !strings.Contains(err.Error(), "cannot mount under existing mount") {
t.Fatalf("err: %v", err)
@ -106,6 +114,11 @@ func TestRouter_Mount(t *testing.T) {
t.Fatalf("bad: %v", v)
}
mountEntryFetched := r.MatchingMountByID(mountEntry.UUID)
if mountEntryFetched == nil || !reflect.DeepEqual(mountEntry, mountEntryFetched) {
t.Fatalf("failed to fetch mount entry using its ID; mountEntry: %#v\n mountEntryFetched: %#v\n", mountEntry, mountEntryFetched)
}
mount, prefix, ok := r.MatchingStoragePrefix("logical/foo")
if !ok {
t.Fatalf("missing storage prefix")