vault: extend router to handle login routing

This commit is contained in:
Armon Dadgar 2015-03-23 11:47:55 -07:00
parent af2fe5681a
commit 10e64d1e90
3 changed files with 202 additions and 19 deletions

View File

@ -15,6 +15,10 @@ type NoopCred struct {
Paths []string
Requests []*logical.Request
Response *logical.Response
LPaths []string
LoginRequests []*credential.Request
LoginResponse *credential.Response
}
func (n *NoopCred) HandleRequest(req *logical.Request) (*logical.Response, error) {
@ -35,7 +39,12 @@ func (n *NoopCred) LoginPaths() []string {
}
func (n *NoopCred) HandleLogin(req *credential.Request) (*credential.Response, error) {
return nil, nil
n.LPaths = append(n.LPaths, req.Path)
n.LoginRequests = append(n.LoginRequests, req)
if req.Storage == nil {
return nil, fmt.Errorf("missing view")
}
return n.LoginResponse, nil
}
func TestCore_DefaultAuthTable(t *testing.T) {

View File

@ -6,6 +6,7 @@ import (
"sync"
"github.com/armon/go-radix"
"github.com/hashicorp/vault/credential"
"github.com/hashicorp/vault/logical"
)
@ -25,9 +26,10 @@ func NewRouter() *Router {
// mountEntry is used to represent a mount point
type mountEntry struct {
backend logical.Backend
view *BarrierView
rootPaths *radix.Tree
backend logical.Backend
view *BarrierView
rootPaths *radix.Tree
loginPaths *radix.Tree
}
// Mount is used to expose a logical backend at a given prefix
@ -41,25 +43,20 @@ func (r *Router) Mount(backend logical.Backend, prefix string, view *BarrierView
}
// Get the root paths
paths := backend.RootPaths()
var rootPaths *radix.Tree
if len(paths) > 0 {
rootPaths = radix.New()
}
for _, path := range paths {
// Check if this is a prefix or exact match
prefixMatch := len(path) >= 1 && path[len(path)-1] == '*'
if prefixMatch {
path = path[:len(path)-1]
}
rootPaths.Insert(path, prefixMatch)
rootPaths := pathsToRadix(backend.RootPaths())
// Check if this is a credential backend, calculate the login paths
var loginPaths *radix.Tree
if cred, ok := backend.(credential.Backend); ok {
loginPaths = pathsToRadix(cred.LoginPaths())
}
// Create a mount entry
me := &mountEntry{
backend: backend,
view: view,
rootPaths: rootPaths,
backend: backend,
view: view,
rootPaths: rootPaths,
loginPaths: loginPaths,
}
r.root.Insert(prefix, me)
return nil
@ -127,6 +124,43 @@ func (r *Router) Route(req *logical.Request) (*logical.Response, error) {
return me.backend.HandleRequest(req)
}
// RouteLogin is used to route a given login request
func (r *Router) RouteLogin(req *credential.Request) (*credential.Response, error) {
// Ensure this is a login path
if !r.LoginPath(req.Path) {
return nil, fmt.Errorf("invalid login route '%s'", req.Path)
}
// Find the mount point
r.l.RLock()
mount, raw, ok := r.root.LongestPrefix(req.Path)
r.l.RUnlock()
if !ok {
return nil, fmt.Errorf("no handler for route '%s'", req.Path)
}
me := raw.(*mountEntry)
// Adjust the path, attach the barrier view
original := req.Path
req.Path = strings.TrimPrefix(req.Path, mount)
req.Storage = me.view
// Reset the request before returning
defer func() {
req.Path = original
req.Storage = nil
}()
// Convert to a credential backend
cred, ok := me.backend.(credential.Backend)
if !ok {
return nil, fmt.Errorf("invalid login route '%s'", req.Path)
}
// Invoke the backend
return cred.HandleLogin(req)
}
// RootPath checks if the given path requires root privileges
func (r *Router) RootPath(path string) bool {
r.l.RLock()
@ -155,3 +189,47 @@ func (r *Router) RootPath(path string) bool {
// Handle the exact match case
return match == remain
}
// LoginPath checks if the given path is used for logins
func (r *Router) LoginPath(path string) bool {
r.l.RLock()
mount, raw, ok := r.root.LongestPrefix(path)
r.l.RUnlock()
if !ok {
return false
}
me := raw.(*mountEntry)
// Trim to get remaining path
remain := strings.TrimPrefix(path, mount)
// Check the loginPaths of this backend
match, raw, ok := me.loginPaths.LongestPrefix(remain)
if !ok {
return false
}
prefixMatch := raw.(bool)
// Handle the prefix match case
if prefixMatch {
return strings.HasPrefix(remain, match)
}
// Handle the exact match case
return match == remain
}
// pathsToRadix converts a list of paths potentially ending with
// a wildcard expansion "*" into a radix tree.
func pathsToRadix(paths []string) *radix.Tree {
tree := radix.New()
for _, path := range paths {
// Check if this is a prefix or exact match
prefixMatch := len(path) >= 1 && path[len(path)-1] == '*'
if prefixMatch {
path = path[:len(path)-1]
}
tree.Insert(path, prefixMatch)
}
return tree
}

View File

@ -5,6 +5,7 @@ import (
"strings"
"testing"
"github.com/hashicorp/vault/credential"
"github.com/hashicorp/vault/logical"
)
@ -174,3 +175,98 @@ func TestRouter_RootPath(t *testing.T) {
}
}
}
func TestRouter_RouteLogin(t *testing.T) {
r := NewRouter()
_, barrier, _ := mockBarrier(t)
view := NewBarrierView(barrier, "auth/")
n := &NoopCred{
Login: []string{"bar"},
}
err := r.Mount(n, "auth/foo/", view)
if err != nil {
t.Fatalf("err: %v", err)
}
if path := r.MatchingMount("auth/foo/bar"); path != "auth/foo/" {
t.Fatalf("bad: %s", path)
}
req := &credential.Request{
Path: "auth/foo/bar",
}
resp, err := r.RouteLogin(req)
if err != nil {
t.Fatalf("err: %v", err)
}
if resp != nil {
t.Fatalf("bad: %v", resp)
}
// Verify the path
if len(n.LPaths) != 1 || n.LPaths[0] != "bar" {
t.Fatalf("bad: %v", n.Paths)
}
}
func TestRouter_LoginPath(t *testing.T) {
r := NewRouter()
_, barrier, _ := mockBarrier(t)
view := NewBarrierView(barrier, "auth/")
n := &NoopCred{
Login: []string{
"login",
"oauth/*",
},
}
err := r.Mount(n, "auth/foo/", view)
if err != nil {
t.Fatalf("err: %v", err)
}
type tcase struct {
path string
expect bool
}
tcases := []tcase{
{"random", false},
{"auth/foo/bar", false},
{"auth/foo/login", true},
{"auth/foo/oauth", false},
{"auth/foo/oauth/redirect", true},
}
for _, tc := range tcases {
out := r.LoginPath(tc.path)
if out != tc.expect {
t.Fatalf("bad: path: %s expect: %v got %v", tc.path, tc.expect, out)
}
}
}
func TestPathsToRadix(t *testing.T) {
// Provide real paths
paths := []string{
"foo",
"foo/*",
"sub/bar*",
}
r := pathsToRadix(paths)
raw, ok := r.Get("foo")
if !ok || raw.(bool) != false {
t.Fatalf("bad: %v (foo)", raw)
}
raw, ok = r.Get("foo/")
if !ok || raw.(bool) != true {
t.Fatalf("bad: %v (foo/)", raw)
}
raw, ok = r.Get("sub/bar")
if !ok || raw.(bool) != true {
t.Fatalf("bad: %v (sub/bar)", raw)
}
}