Added utility on router to fetch mount entry using its ID
This commit is contained in:
parent
81e95c62e3
commit
4fe7fc4ef9
|
@ -14,9 +14,10 @@ import (
|
|||
|
||||
// Router is used to do prefix based routing of a request to a logical backend
|
||||
type Router struct {
|
||||
l sync.RWMutex
|
||||
root *radix.Tree
|
||||
tokenStoreSalt *salt.Salt
|
||||
l sync.RWMutex
|
||||
root *radix.Tree
|
||||
mountEntryCache *radix.Tree
|
||||
tokenStoreSalt *salt.Salt
|
||||
|
||||
// storagePrefix maps the prefix used for storage (ala the BarrierView)
|
||||
// to the backend. This is used to map a key back into the backend that owns it.
|
||||
|
@ -27,8 +28,9 @@ type Router struct {
|
|||
// NewRouter returns a new router
|
||||
func NewRouter() *Router {
|
||||
r := &Router{
|
||||
root: radix.New(),
|
||||
storagePrefix: radix.New(),
|
||||
root: radix.New(),
|
||||
storagePrefix: radix.New(),
|
||||
mountEntryCache: radix.New(),
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
@ -74,8 +76,10 @@ func (r *Router) Mount(backend logical.Backend, prefix string, mountEntry *Mount
|
|||
rootPaths: pathsToRadix(paths.Root),
|
||||
loginPaths: pathsToRadix(paths.Unauthenticated),
|
||||
}
|
||||
|
||||
r.root.Insert(prefix, re)
|
||||
r.storagePrefix.Insert(storageView.prefix, re)
|
||||
r.mountEntryCache.Insert(re.mountEntry.UUID, re.mountEntry)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -98,6 +102,8 @@ func (r *Router) Unmount(prefix string) error {
|
|||
// Purge from the radix trees
|
||||
r.root.Delete(prefix)
|
||||
r.storagePrefix.Delete(re.storageView.prefix)
|
||||
r.mountEntryCache.Delete(re.mountEntry.UUID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -141,6 +147,22 @@ func (r *Router) Untaint(path string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (r *Router) MatchingMountByID(mountID string) *MountEntry {
|
||||
if mountID == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
r.l.RLock()
|
||||
defer r.l.RUnlock()
|
||||
|
||||
_, raw, ok := r.mountEntryCache.LongestPrefix(mountID)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
return raw.(*MountEntry)
|
||||
}
|
||||
|
||||
// MatchingMount returns the mount prefix that would be used for a path
|
||||
func (r *Router) MatchingMount(path string) string {
|
||||
r.l.RLock()
|
||||
|
|
|
@ -2,6 +2,7 @@ package vault
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
@ -75,8 +76,14 @@ func TestRouter_Mount(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
mountEntry := &MountEntry{
|
||||
Path: "prod/aws/",
|
||||
UUID: meUUID,
|
||||
}
|
||||
|
||||
n := &NoopBackend{}
|
||||
err = r.Mount(n, "prod/aws/", &MountEntry{Path: "prod/aws/", UUID: meUUID}, view)
|
||||
err = r.Mount(n, "prod/aws/", mountEntry, view)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
@ -85,6 +92,7 @@ func TestRouter_Mount(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = r.Mount(n, "prod/aws/", &MountEntry{UUID: meUUID}, view)
|
||||
if !strings.Contains(err.Error(), "cannot mount under existing mount") {
|
||||
t.Fatalf("err: %v", err)
|
||||
|
@ -106,6 +114,11 @@ func TestRouter_Mount(t *testing.T) {
|
|||
t.Fatalf("bad: %v", v)
|
||||
}
|
||||
|
||||
mountEntryFetched := r.MatchingMountByID(mountEntry.UUID)
|
||||
if mountEntryFetched == nil || !reflect.DeepEqual(mountEntry, mountEntryFetched) {
|
||||
t.Fatalf("failed to fetch mount entry using its ID; mountEntry: %#v\n mountEntryFetched: %#v\n", mountEntry, mountEntryFetched)
|
||||
}
|
||||
|
||||
mount, prefix, ok := r.MatchingStoragePrefix("logical/foo")
|
||||
if !ok {
|
||||
t.Fatalf("missing storage prefix")
|
||||
|
|
Loading…
Reference in New Issue