vault: extend router to handle login routing
This commit is contained in:
parent
af2fe5681a
commit
10e64d1e90
|
@ -15,6 +15,10 @@ type NoopCred struct {
|
|||
Paths []string
|
||||
Requests []*logical.Request
|
||||
Response *logical.Response
|
||||
|
||||
LPaths []string
|
||||
LoginRequests []*credential.Request
|
||||
LoginResponse *credential.Response
|
||||
}
|
||||
|
||||
func (n *NoopCred) HandleRequest(req *logical.Request) (*logical.Response, error) {
|
||||
|
@ -35,7 +39,12 @@ func (n *NoopCred) LoginPaths() []string {
|
|||
}
|
||||
|
||||
func (n *NoopCred) HandleLogin(req *credential.Request) (*credential.Response, error) {
|
||||
return nil, nil
|
||||
n.LPaths = append(n.LPaths, req.Path)
|
||||
n.LoginRequests = append(n.LoginRequests, req)
|
||||
if req.Storage == nil {
|
||||
return nil, fmt.Errorf("missing view")
|
||||
}
|
||||
return n.LoginResponse, nil
|
||||
}
|
||||
|
||||
func TestCore_DefaultAuthTable(t *testing.T) {
|
||||
|
|
114
vault/router.go
114
vault/router.go
|
@ -6,6 +6,7 @@ import (
|
|||
"sync"
|
||||
|
||||
"github.com/armon/go-radix"
|
||||
"github.com/hashicorp/vault/credential"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
)
|
||||
|
||||
|
@ -25,9 +26,10 @@ func NewRouter() *Router {
|
|||
|
||||
// mountEntry is used to represent a mount point
|
||||
type mountEntry struct {
|
||||
backend logical.Backend
|
||||
view *BarrierView
|
||||
rootPaths *radix.Tree
|
||||
backend logical.Backend
|
||||
view *BarrierView
|
||||
rootPaths *radix.Tree
|
||||
loginPaths *radix.Tree
|
||||
}
|
||||
|
||||
// Mount is used to expose a logical backend at a given prefix
|
||||
|
@ -41,25 +43,20 @@ func (r *Router) Mount(backend logical.Backend, prefix string, view *BarrierView
|
|||
}
|
||||
|
||||
// Get the root paths
|
||||
paths := backend.RootPaths()
|
||||
var rootPaths *radix.Tree
|
||||
if len(paths) > 0 {
|
||||
rootPaths = radix.New()
|
||||
}
|
||||
for _, path := range paths {
|
||||
// Check if this is a prefix or exact match
|
||||
prefixMatch := len(path) >= 1 && path[len(path)-1] == '*'
|
||||
if prefixMatch {
|
||||
path = path[:len(path)-1]
|
||||
}
|
||||
rootPaths.Insert(path, prefixMatch)
|
||||
rootPaths := pathsToRadix(backend.RootPaths())
|
||||
|
||||
// Check if this is a credential backend, calculate the login paths
|
||||
var loginPaths *radix.Tree
|
||||
if cred, ok := backend.(credential.Backend); ok {
|
||||
loginPaths = pathsToRadix(cred.LoginPaths())
|
||||
}
|
||||
|
||||
// Create a mount entry
|
||||
me := &mountEntry{
|
||||
backend: backend,
|
||||
view: view,
|
||||
rootPaths: rootPaths,
|
||||
backend: backend,
|
||||
view: view,
|
||||
rootPaths: rootPaths,
|
||||
loginPaths: loginPaths,
|
||||
}
|
||||
r.root.Insert(prefix, me)
|
||||
return nil
|
||||
|
@ -127,6 +124,43 @@ func (r *Router) Route(req *logical.Request) (*logical.Response, error) {
|
|||
return me.backend.HandleRequest(req)
|
||||
}
|
||||
|
||||
// RouteLogin is used to route a given login request
|
||||
func (r *Router) RouteLogin(req *credential.Request) (*credential.Response, error) {
|
||||
// Ensure this is a login path
|
||||
if !r.LoginPath(req.Path) {
|
||||
return nil, fmt.Errorf("invalid login route '%s'", req.Path)
|
||||
}
|
||||
|
||||
// Find the mount point
|
||||
r.l.RLock()
|
||||
mount, raw, ok := r.root.LongestPrefix(req.Path)
|
||||
r.l.RUnlock()
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no handler for route '%s'", req.Path)
|
||||
}
|
||||
me := raw.(*mountEntry)
|
||||
|
||||
// Adjust the path, attach the barrier view
|
||||
original := req.Path
|
||||
req.Path = strings.TrimPrefix(req.Path, mount)
|
||||
req.Storage = me.view
|
||||
|
||||
// Reset the request before returning
|
||||
defer func() {
|
||||
req.Path = original
|
||||
req.Storage = nil
|
||||
}()
|
||||
|
||||
// Convert to a credential backend
|
||||
cred, ok := me.backend.(credential.Backend)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid login route '%s'", req.Path)
|
||||
}
|
||||
|
||||
// Invoke the backend
|
||||
return cred.HandleLogin(req)
|
||||
}
|
||||
|
||||
// RootPath checks if the given path requires root privileges
|
||||
func (r *Router) RootPath(path string) bool {
|
||||
r.l.RLock()
|
||||
|
@ -155,3 +189,47 @@ func (r *Router) RootPath(path string) bool {
|
|||
// Handle the exact match case
|
||||
return match == remain
|
||||
}
|
||||
|
||||
// LoginPath checks if the given path is used for logins
|
||||
func (r *Router) LoginPath(path string) bool {
|
||||
r.l.RLock()
|
||||
mount, raw, ok := r.root.LongestPrefix(path)
|
||||
r.l.RUnlock()
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
me := raw.(*mountEntry)
|
||||
|
||||
// Trim to get remaining path
|
||||
remain := strings.TrimPrefix(path, mount)
|
||||
|
||||
// Check the loginPaths of this backend
|
||||
match, raw, ok := me.loginPaths.LongestPrefix(remain)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
prefixMatch := raw.(bool)
|
||||
|
||||
// Handle the prefix match case
|
||||
if prefixMatch {
|
||||
return strings.HasPrefix(remain, match)
|
||||
}
|
||||
|
||||
// Handle the exact match case
|
||||
return match == remain
|
||||
}
|
||||
|
||||
// pathsToRadix converts a list of paths potentially ending with
|
||||
// a wildcard expansion "*" into a radix tree.
|
||||
func pathsToRadix(paths []string) *radix.Tree {
|
||||
tree := radix.New()
|
||||
for _, path := range paths {
|
||||
// Check if this is a prefix or exact match
|
||||
prefixMatch := len(path) >= 1 && path[len(path)-1] == '*'
|
||||
if prefixMatch {
|
||||
path = path[:len(path)-1]
|
||||
}
|
||||
tree.Insert(path, prefixMatch)
|
||||
}
|
||||
return tree
|
||||
}
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/vault/credential"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
)
|
||||
|
||||
|
@ -174,3 +175,98 @@ func TestRouter_RootPath(t *testing.T) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_RouteLogin(t *testing.T) {
|
||||
r := NewRouter()
|
||||
_, barrier, _ := mockBarrier(t)
|
||||
view := NewBarrierView(barrier, "auth/")
|
||||
|
||||
n := &NoopCred{
|
||||
Login: []string{"bar"},
|
||||
}
|
||||
err := r.Mount(n, "auth/foo/", view)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
if path := r.MatchingMount("auth/foo/bar"); path != "auth/foo/" {
|
||||
t.Fatalf("bad: %s", path)
|
||||
}
|
||||
|
||||
req := &credential.Request{
|
||||
Path: "auth/foo/bar",
|
||||
}
|
||||
resp, err := r.RouteLogin(req)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if resp != nil {
|
||||
t.Fatalf("bad: %v", resp)
|
||||
}
|
||||
|
||||
// Verify the path
|
||||
if len(n.LPaths) != 1 || n.LPaths[0] != "bar" {
|
||||
t.Fatalf("bad: %v", n.Paths)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_LoginPath(t *testing.T) {
|
||||
r := NewRouter()
|
||||
_, barrier, _ := mockBarrier(t)
|
||||
view := NewBarrierView(barrier, "auth/")
|
||||
|
||||
n := &NoopCred{
|
||||
Login: []string{
|
||||
"login",
|
||||
"oauth/*",
|
||||
},
|
||||
}
|
||||
err := r.Mount(n, "auth/foo/", view)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
type tcase struct {
|
||||
path string
|
||||
expect bool
|
||||
}
|
||||
tcases := []tcase{
|
||||
{"random", false},
|
||||
{"auth/foo/bar", false},
|
||||
{"auth/foo/login", true},
|
||||
{"auth/foo/oauth", false},
|
||||
{"auth/foo/oauth/redirect", true},
|
||||
}
|
||||
|
||||
for _, tc := range tcases {
|
||||
out := r.LoginPath(tc.path)
|
||||
if out != tc.expect {
|
||||
t.Fatalf("bad: path: %s expect: %v got %v", tc.path, tc.expect, out)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPathsToRadix(t *testing.T) {
|
||||
// Provide real paths
|
||||
paths := []string{
|
||||
"foo",
|
||||
"foo/*",
|
||||
"sub/bar*",
|
||||
}
|
||||
r := pathsToRadix(paths)
|
||||
|
||||
raw, ok := r.Get("foo")
|
||||
if !ok || raw.(bool) != false {
|
||||
t.Fatalf("bad: %v (foo)", raw)
|
||||
}
|
||||
|
||||
raw, ok = r.Get("foo/")
|
||||
if !ok || raw.(bool) != true {
|
||||
t.Fatalf("bad: %v (foo/)", raw)
|
||||
}
|
||||
|
||||
raw, ok = r.Get("sub/bar")
|
||||
if !ok || raw.(bool) != true {
|
||||
t.Fatalf("bad: %v (sub/bar)", raw)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue