diff --git a/vault/router.go b/vault/router.go index 60eb431c6..7b899bed5 100644 --- a/vault/router.go +++ b/vault/router.go @@ -24,13 +24,14 @@ func NewRouter() *Router { // mountEntry is used to represent a mount point type mountEntry struct { + mtype string backend LogicalBackend view *BarrierView rootPaths *radix.Tree } // Mount is used to expose a logical backend at a given prefix -func (r *Router) Mount(backend LogicalBackend, prefix string, view *BarrierView) error { +func (r *Router) Mount(backend LogicalBackend, mtype, prefix string, view *BarrierView) error { r.l.Lock() defer r.l.Unlock() @@ -56,6 +57,7 @@ func (r *Router) Mount(backend LogicalBackend, prefix string, view *BarrierView) // Create a mount entry me := &mountEntry{ + mtype: mtype, backend: backend, view: view, rootPaths: rootPaths, diff --git a/vault/router_test.go b/vault/router_test.go index 63d70548a..f975267d5 100644 --- a/vault/router_test.go +++ b/vault/router_test.go @@ -29,12 +29,12 @@ func TestRouter_Mount(t *testing.T) { view := NewBarrierView(barrier, "logical/") n := &NoopBackend{} - err := r.Mount(n, "prod/aws/", view) + err := r.Mount(n, "noop", "prod/aws/", view) if err != nil { t.Fatalf("err: %v", err) } - err = r.Mount(n, "prod/aws/", view) + err = r.Mount(n, "noop", "prod/aws/", view) if !strings.Contains(err.Error(), "cannot mount under existing mount") { t.Fatalf("err: %v", err) } @@ -62,7 +62,7 @@ func TestRouter_Unmount(t *testing.T) { view := NewBarrierView(barrier, "logical/") n := &NoopBackend{} - err := r.Mount(n, "prod/aws/", view) + err := r.Mount(n, "noop", "prod/aws/", view) if err != nil { t.Fatalf("err: %v", err) } @@ -87,7 +87,7 @@ func TestRouter_Remount(t *testing.T) { view := NewBarrierView(barrier, "logical/") n := &NoopBackend{} - err := r.Mount(n, "prod/aws/", view) + err := r.Mount(n, "noop", "prod/aws/", view) if err != nil { t.Fatalf("err: %v", err) } @@ -135,7 +135,7 @@ func TestRouter_RootPath(t *testing.T) { "policy/*", }, } - err := r.Mount(n, "prod/aws/", view) + err := r.Mount(n, "noop", "prod/aws/", view) if err != nil { t.Fatalf("err: %v", err) }