From c460ff10ca2959e28ea2decde12b6f350dcbc532 Mon Sep 17 00:00:00 2001 From: Jeff Mitchell Date: Fri, 4 Sep 2015 16:58:12 -0400 Subject: [PATCH] Push a lot of logic into Router to make a bunch of it nicer and enable a lot of cleanup. Plumb config and calls to framework.Backend.Setup() into logical_system and elsewhere, including tests. --- .gitignore | 2 + command/mounttune.go | 2 +- vault/auth.go | 12 +- vault/core.go | 18 +-- vault/dynamic_system_view.go | 54 +++++++++ vault/expiration.go | 2 +- vault/expiration_test.go | 22 ++-- vault/logical_passthrough.go | 7 +- vault/logical_passthrough_test.go | 8 +- vault/logical_system.go | 31 ++++- vault/logical_system_test.go | 18 ++- vault/mount.go | 189 +++++++----------------------- vault/mount_test.go | 6 +- vault/policy_store.go | 2 +- vault/rollback_test.go | 2 +- vault/router.go | 71 +++++++---- vault/router_test.go | 16 +-- vault/testing.go | 8 +- vault/token_store.go | 6 +- vault/token_store_test.go | 23 +++- 20 files changed, 268 insertions(+), 231 deletions(-) create mode 100644 vault/dynamic_system_view.go diff --git a/.gitignore b/.gitignore index 14789deca..8bc0d6f24 100644 --- a/.gitignore +++ b/.gitignore @@ -49,6 +49,8 @@ Vagrantfile dist/* +tags + # Editor backups *~ *.sw[a-z] diff --git a/command/mounttune.go b/command/mounttune.go index 42a12027f..43b3cf498 100644 --- a/command/mounttune.go +++ b/command/mounttune.go @@ -8,7 +8,7 @@ import ( "github.com/hashicorp/vault/vault" ) -// RemountCommand is a Command that remounts a mounted secret backend +// MountTuneCommand is a Command that remounts a mounted secret backend // to a new endpoint. type MountTuneCommand struct { Meta diff --git a/vault/auth.go b/vault/auth.go index 06e5f539a..7b09c2948 100644 --- a/vault/auth.go +++ b/vault/auth.go @@ -66,7 +66,7 @@ func (c *Core) enableCredential(entry *MountEntry) error { view := NewBarrierView(c.barrier, credentialBarrierPrefix+entry.UUID+"/") // Create the new backend - backend, err := c.newCredentialBackend(entry.Type, view, nil) + backend, err := c.newCredentialBackend(entry.Type, c.mountEntrySysView(entry), view, nil) if err != nil { return err } @@ -81,7 +81,7 @@ func (c *Core) enableCredential(entry *MountEntry) error { // Mount the backend path := credentialRoutePrefix + entry.Path - if err := c.router.Mount(backend, path, entry.UUID, view); err != nil { + if err := c.router.Mount(backend, path, entry, view); err != nil { return err } c.logger.Printf("[INFO] core: enabled credential backend '%s' type: %s", @@ -242,7 +242,7 @@ func (c *Core) setupCredentials() error { view = NewBarrierView(c.barrier, credentialBarrierPrefix+entry.UUID+"/") // Initialize the backend - backend, err = c.newCredentialBackend(entry.Type, view, nil) + backend, err = c.newCredentialBackend(entry.Type, c.mountEntrySysView(entry), view, nil) if err != nil { c.logger.Printf( "[ERR] core: failed to create credential entry %#v: %v", @@ -252,7 +252,7 @@ func (c *Core) setupCredentials() error { // Mount the backend path := credentialRoutePrefix + entry.Path - err = c.router.Mount(backend, path, entry.UUID, view) + err = c.router.Mount(backend, path, entry, view) if err != nil { c.logger.Printf("[ERR] core: failed to mount auth entry %#v: %v", entry, err) return loadAuthFailed @@ -281,7 +281,7 @@ func (c *Core) teardownCredentials() error { // newCredentialBackend is used to create and configure a new credential backend by name func (c *Core) newCredentialBackend( - t string, view logical.Storage, conf map[string]string) (logical.Backend, error) { + t string, sysView logical.SystemView, view logical.Storage, conf map[string]string) (logical.Backend, error) { f, ok := c.credentialBackends[t] if !ok { return nil, fmt.Errorf("unknown backend type: %s", t) @@ -291,12 +291,14 @@ func (c *Core) newCredentialBackend( View: view, Logger: c.logger, Config: conf, + System: sysView, } b, err := f(config) if err != nil { return nil, err } + return b, nil } diff --git a/vault/core.go b/vault/core.go index 82b85a2b6..84e658b1c 100644 --- a/vault/core.go +++ b/vault/core.go @@ -220,8 +220,8 @@ type Core struct { // out into the configured audit backends auditBroker *AuditBroker - // systemView is the barrier view for the system backend - systemView *BarrierView + // systemBarrierView is the barrier view for the system backend + systemBarrierView *BarrierView // expiration manager is used for managing LeaseIDs, // renewal, expiration and revocation @@ -351,8 +351,8 @@ func NewCore(conf *CoreConfig) (*Core, error) { logicalBackends[k] = f } logicalBackends["generic"] = PassthroughBackendFactory - logicalBackends["system"] = func(*logical.BackendConfig) (logical.Backend, error) { - return NewSystemBackend(c), nil + logicalBackends["system"] = func(config *logical.BackendConfig) (logical.Backend, error) { + return NewSystemBackend(c, config), nil } c.logicalBackends = logicalBackends @@ -360,8 +360,8 @@ func NewCore(conf *CoreConfig) (*Core, error) { for k, f := range conf.CredentialBackends { credentialBackends[k] = f } - credentialBackends["token"] = func(*logical.BackendConfig) (logical.Backend, error) { - return NewTokenStore(c) + credentialBackends["token"] = func(config *logical.BackendConfig) (logical.Backend, error) { + return NewTokenStore(c, config) } c.credentialBackends = credentialBackends @@ -478,9 +478,9 @@ func (c *Core) handleRequest(req *logical.Request) (retResp *logical.Response, r // We exclude renewal of a lease, since it does not need to be re-registered if resp != nil && resp.Secret != nil && !strings.HasPrefix(req.Path, "sys/renew/") { // Get the SystemView for the mount - sysView, err := c.sysViewByPath(req.Path) - if err != nil { - c.logger.Println(err) + sysView := c.router.MatchingSystemView(req.Path) + if sysView == nil { + c.logger.Println("[ERR] core: unable to retrieve system view from router") return nil, auth, ErrInternalError } diff --git a/vault/dynamic_system_view.go b/vault/dynamic_system_view.go new file mode 100644 index 000000000..b6fb90364 --- /dev/null +++ b/vault/dynamic_system_view.go @@ -0,0 +1,54 @@ +package vault + +import ( + "fmt" + "strings" + "time" +) + +type dynamicSystemView struct { + core *Core + path string +} + +func (d dynamicSystemView) DefaultLeaseTTL() (time.Duration, error) { + def, _, err := d.fetchTTLs() + if err != nil { + return 0, err + } + return def, nil +} + +func (d dynamicSystemView) MaxLeaseTTL() (time.Duration, error) { + _, max, err := d.fetchTTLs() + if err != nil { + return 0, err + } + return max, nil +} + +// TTLsByPath returns the default and max TTLs corresponding to a particular +// mount point, or the system default +func (d dynamicSystemView) fetchTTLs() (def, max time.Duration, retErr error) { + // Ensure we end the path in a slash + if !strings.HasSuffix(d.path, "/") { + d.path += "/" + } + + me := d.core.router.MatchingMountEntry(d.path) + if me == nil { + return 0, 0, fmt.Errorf("[ERR] core: failed to get mount entry for %s", d.path) + } + + def = d.core.defaultLeaseTTL + max = d.core.maxLeaseTTL + + if me.Config.DefaultLeaseTTL != nil && *me.Config.DefaultLeaseTTL != 0 { + def = *me.Config.DefaultLeaseTTL + } + if me.Config.MaxLeaseTTL != nil && *me.Config.MaxLeaseTTL != 0 { + max = *me.Config.MaxLeaseTTL + } + + return +} diff --git a/vault/expiration.go b/vault/expiration.go index 11454e421..702821157 100644 --- a/vault/expiration.go +++ b/vault/expiration.go @@ -78,7 +78,7 @@ func NewExpirationManager(router *Router, view *BarrierView, ts *TokenStore, log // initialize the expiration manager func (c *Core) setupExpiration() error { // Create a sub-view - view := c.systemView.SubView(expirationSubPath) + view := c.systemBarrierView.SubView(expirationSubPath) // Create the manager mgr := NewExpirationManager(c.router, view, c.tokenStore, c.logger) diff --git a/vault/expiration_test.go b/vault/expiration_test.go index 4c89dc9a5..4825678e9 100644 --- a/vault/expiration_test.go +++ b/vault/expiration_test.go @@ -22,7 +22,7 @@ func TestExpiration_Restore(t *testing.T) { noop := &NoopBackend{} _, barrier, _ := mockBarrier(t) view := NewBarrierView(barrier, "logical/") - exp.router.Mount(noop, "prod/aws/", uuid.GenerateUUID(), view) + exp.router.Mount(noop, "prod/aws/", &MountEntry{UUID: uuid.GenerateUUID()}, view) paths := []string{ "prod/aws/foo", @@ -175,7 +175,7 @@ func TestExpiration_Revoke(t *testing.T) { noop := &NoopBackend{} _, barrier, _ := mockBarrier(t) view := NewBarrierView(barrier, "logical/") - exp.router.Mount(noop, "prod/aws/", uuid.GenerateUUID(), view) + exp.router.Mount(noop, "prod/aws/", &MountEntry{UUID: uuid.GenerateUUID()}, view) req := &logical.Request{ Operation: logical.ReadOperation, @@ -213,7 +213,7 @@ func TestExpiration_RevokeOnExpire(t *testing.T) { noop := &NoopBackend{} _, barrier, _ := mockBarrier(t) view := NewBarrierView(barrier, "logical/") - exp.router.Mount(noop, "prod/aws/", uuid.GenerateUUID(), view) + exp.router.Mount(noop, "prod/aws/", &MountEntry{UUID: uuid.GenerateUUID()}, view) req := &logical.Request{ Operation: logical.ReadOperation, @@ -262,7 +262,7 @@ func TestExpiration_RevokePrefix(t *testing.T) { noop := &NoopBackend{} _, barrier, _ := mockBarrier(t) view := NewBarrierView(barrier, "logical/") - exp.router.Mount(noop, "prod/aws/", uuid.GenerateUUID(), view) + exp.router.Mount(noop, "prod/aws/", &MountEntry{UUID: uuid.GenerateUUID()}, view) paths := []string{ "prod/aws/foo", @@ -322,7 +322,7 @@ func TestExpiration_RevokeByToken(t *testing.T) { noop := &NoopBackend{} _, barrier, _ := mockBarrier(t) view := NewBarrierView(barrier, "logical/") - exp.router.Mount(noop, "prod/aws/", uuid.GenerateUUID(), view) + exp.router.Mount(noop, "prod/aws/", &MountEntry{UUID: uuid.GenerateUUID()}, view) paths := []string{ "prod/aws/foo", @@ -441,7 +441,7 @@ func TestExpiration_Renew(t *testing.T) { noop := &NoopBackend{} _, barrier, _ := mockBarrier(t) view := NewBarrierView(barrier, "logical/") - exp.router.Mount(noop, "prod/aws/", uuid.GenerateUUID(), view) + exp.router.Mount(noop, "prod/aws/", &MountEntry{UUID: uuid.GenerateUUID()}, view) req := &logical.Request{ Operation: logical.ReadOperation, @@ -503,7 +503,7 @@ func TestExpiration_Renew_NotRenewable(t *testing.T) { noop := &NoopBackend{} _, barrier, _ := mockBarrier(t) view := NewBarrierView(barrier, "logical/") - exp.router.Mount(noop, "prod/aws/", uuid.GenerateUUID(), view) + exp.router.Mount(noop, "prod/aws/", &MountEntry{UUID: uuid.GenerateUUID()}, view) req := &logical.Request{ Operation: logical.ReadOperation, @@ -545,7 +545,7 @@ func TestExpiration_Renew_RevokeOnExpire(t *testing.T) { noop := &NoopBackend{} _, barrier, _ := mockBarrier(t) view := NewBarrierView(barrier, "logical/") - exp.router.Mount(noop, "prod/aws/", uuid.GenerateUUID(), view) + exp.router.Mount(noop, "prod/aws/", &MountEntry{UUID: uuid.GenerateUUID()}, view) req := &logical.Request{ Operation: logical.ReadOperation, @@ -613,7 +613,7 @@ func TestExpiration_revokeEntry(t *testing.T) { noop := &NoopBackend{} _, barrier, _ := mockBarrier(t) view := NewBarrierView(barrier, "logical/") - exp.router.Mount(noop, "", uuid.GenerateUUID(), view) + exp.router.Mount(noop, "", &MountEntry{UUID: uuid.GenerateUUID()}, view) le := &leaseEntry{ LeaseID: "foo/bar/1234", @@ -702,7 +702,7 @@ func TestExpiration_renewEntry(t *testing.T) { } _, barrier, _ := mockBarrier(t) view := NewBarrierView(barrier, "logical/") - exp.router.Mount(noop, "", uuid.GenerateUUID(), view) + exp.router.Mount(noop, "", &MountEntry{UUID: uuid.GenerateUUID()}, view) le := &leaseEntry{ LeaseID: "foo/bar/1234", @@ -764,7 +764,7 @@ func TestExpiration_renewAuthEntry(t *testing.T) { } _, barrier, _ := mockBarrier(t) view := NewBarrierView(barrier, "auth/foo/") - exp.router.Mount(noop, "auth/foo/", uuid.GenerateUUID(), view) + exp.router.Mount(noop, "auth/foo/", &MountEntry{UUID: uuid.GenerateUUID()}, view) le := &leaseEntry{ LeaseID: "auth/foo/1234", diff --git a/vault/logical_passthrough.go b/vault/logical_passthrough.go index 5ab91b282..03d721c24 100644 --- a/vault/logical_passthrough.go +++ b/vault/logical_passthrough.go @@ -11,7 +11,7 @@ import ( ) // logical.Factory -func PassthroughBackendFactory(*logical.BackendConfig) (logical.Backend, error) { +func PassthroughBackendFactory(conf *logical.BackendConfig) (logical.Backend, error) { var b PassthroughBackend b.Backend = &framework.Backend{ Help: strings.TrimSpace(passthroughHelp), @@ -53,6 +53,11 @@ func PassthroughBackendFactory(*logical.BackendConfig) (logical.Backend, error) }, } + if conf == nil { + return nil, fmt.Errorf("Configuation passed into backend is nil") + } + b.Backend.Setup(conf) + return b, nil } diff --git a/vault/logical_passthrough_test.go b/vault/logical_passthrough_test.go index 7c3b58dde..c10eee516 100644 --- a/vault/logical_passthrough_test.go +++ b/vault/logical_passthrough_test.go @@ -176,6 +176,12 @@ func TestPassthroughBackend_List(t *testing.T) { } func testPassthroughBackend() logical.Backend { - b, _ := PassthroughBackendFactory(nil) + b, _ := PassthroughBackendFactory(&logical.BackendConfig{ + Logger: nil, + System: logical.StaticSystemView{ + DefaultLeaseTTLVal: time.Hour * 24, + MaxLeaseTTLVal: time.Hour * 24 * 30, + }, + }) return b } diff --git a/vault/logical_system.go b/vault/logical_system.go index 17896eadc..82236023e 100644 --- a/vault/logical_system.go +++ b/vault/logical_system.go @@ -20,10 +20,11 @@ var ( } ) -func NewSystemBackend(core *Core) logical.Backend { +func NewSystemBackend(core *Core, config *logical.BackendConfig) logical.Backend { b := &SystemBackend{ Core: core, } + b.Backend = &framework.Backend{ Help: strings.TrimSpace(sysHelpRoot), @@ -346,6 +347,9 @@ func NewSystemBackend(core *Core) logical.Backend { }, }, } + + b.Backend.Setup(config) + return b.Backend } @@ -486,9 +490,26 @@ func (b *SystemBackend) handleMountConfig( logical.ErrInvalidRequest } - def, max, err := b.Core.TTLsByPath(path) + if !strings.HasSuffix(path, "/") { + path += "/" + } + + sysView := b.Core.router.MatchingSystemView(path) + if sysView == nil { + err := fmt.Errorf("[ERR] sys: cannot fetch sysview for path %s", path) + b.Backend.Logger().Print(err) + return handleError(err) + } + + def, err := sysView.DefaultLeaseTTL() if err != nil { - b.Backend.Logger().Printf("[ERR] sys: fetching config of path '%s' failed: %v", path, err) + b.Backend.Logger().Printf("[ERR] sys: fetching config default TTL of path '%s' failed: %v", path, err) + return handleError(err) + } + + max, err := sysView.MaxLeaseTTL() + if err != nil { + b.Backend.Logger().Printf("[ERR] sys: fetching config max TTL of path '%s' failed: %v", path, err) return handleError(err) } @@ -516,6 +537,10 @@ func (b *SystemBackend) handleMountTune( logical.ErrInvalidRequest } + if !strings.HasSuffix(path, "/") { + path += "/" + } + var config MountConfig configMap := data.Get("config").(map[string]interface{}) if configMap == nil || len(configMap) == 0 { diff --git a/vault/logical_system_test.go b/vault/logical_system_test.go index 5b794a78f..1d2020785 100644 --- a/vault/logical_system_test.go +++ b/vault/logical_system_test.go @@ -760,10 +760,24 @@ func TestSystemBackend_rotate(t *testing.T) { func testSystemBackend(t *testing.T) logical.Backend { c, _, _ := TestCoreUnsealed(t) - return NewSystemBackend(c) + bc := &logical.BackendConfig{ + Logger: c.logger, + System: logical.StaticSystemView{ + DefaultLeaseTTLVal: time.Hour * 24, + MaxLeaseTTLVal: time.Hour * 24 * 30, + }, + } + return NewSystemBackend(c, bc) } func testCoreSystemBackend(t *testing.T) (*Core, logical.Backend, string) { c, _, root := TestCoreUnsealed(t) - return c, NewSystemBackend(c), root + bc := &logical.BackendConfig{ + Logger: c.logger, + System: logical.StaticSystemView{ + DefaultLeaseTTLVal: time.Hour * 24, + MaxLeaseTTLVal: time.Hour * 24 * 30, + }, + } + return c, NewSystemBackend(c, bc), root } diff --git a/vault/mount.go b/vault/mount.go index 35ed7c897..cc16ae5b5 100644 --- a/vault/mount.go +++ b/vault/mount.go @@ -40,27 +40,6 @@ var ( } ) -type dynamicSystemView struct { - core *Core - path string -} - -func (d dynamicSystemView) DefaultLeaseTTL() (time.Duration, error) { - def, _, err := d.core.TTLsByPath(d.path) - if err != nil { - return 0, err - } - return def, nil -} - -func (d dynamicSystemView) MaxLeaseTTL() (time.Duration, error) { - _, max, err := d.core.TTLsByPath(d.path) - if err != nil { - return 0, err - } - return max, nil -} - // MountTable is used to represent the internal mount table type MountTable struct { // This lock should be held whenever modifying the Entries field. @@ -185,12 +164,7 @@ func (c *Core) mount(me *MountEntry) error { me.UUID = uuid.GenerateUUID() view := NewBarrierView(c.barrier, backendBarrierPrefix+me.UUID+"/") - // Create the new backend - sysView, err := c.mountEntrySysView(me) - if err != nil { - return err - } - backend, err := c.newLogicalBackend(me.Type, sysView, view, nil) + backend, err := c.newLogicalBackend(me.Type, c.mountEntrySysView(me), view, nil) if err != nil { return err } @@ -204,7 +178,7 @@ func (c *Core) mount(me *MountEntry) error { c.mounts = newTable // Mount the backend - if err := c.router.Mount(backend, me.Path, me.UUID, view); err != nil { + if err := c.router.Mount(backend, me.Path, me, view); err != nil { return err } c.logger.Printf("[INFO] core: mounted '%s' type: %s", me.Path, me.Type) @@ -394,51 +368,44 @@ func (c *Core) tuneMount(path string, config MountConfig) error { // Prevent protected paths from being changed for _, p := range protectedMounts { if strings.HasPrefix(path, p) { - return fmt.Errorf("cannot tune '%s'", path) + return fmt.Errorf("[ERR] core: cannot tune '%s'", path) } } - // Verify exact match of the route - match := c.router.MatchingMount(path) - if match == "" || path != match { - return fmt.Errorf("no matching mount at '%s'", path) + me := c.router.MatchingMountEntry(path) + if me == nil { + return fmt.Errorf("[ERR] core: no matching mount at '%s'", path) } - // Find and modify mount - for _, ent := range c.mounts.Entries { - if ent.Path == path { - if config.MaxLeaseTTL != nil { - if *ent.Config.DefaultLeaseTTL != 0 { - if *config.MaxLeaseTTL < *ent.Config.DefaultLeaseTTL { - return fmt.Errorf("Given backend max lease TTL of %d less than backend default lease TTL of %d", - *config.MaxLeaseTTL, *ent.Config.DefaultLeaseTTL) - } - } - if *config.MaxLeaseTTL == 0 { - *ent.Config.MaxLeaseTTL = 0 - } else { - ent.Config.MaxLeaseTTL = config.MaxLeaseTTL - } + if config.MaxLeaseTTL != nil { + if *me.Config.DefaultLeaseTTL != 0 { + if *config.MaxLeaseTTL < *me.Config.DefaultLeaseTTL { + return fmt.Errorf("Given backend max lease TTL of %d less than backend default lease TTL of %d", + *config.MaxLeaseTTL, *me.Config.DefaultLeaseTTL) } - if config.DefaultLeaseTTL != nil { - if *ent.Config.MaxLeaseTTL == 0 { - if *config.DefaultLeaseTTL > c.maxLeaseTTL { - return fmt.Errorf("Given default lease TTL of %d greater than system default lease TTL of %d", - *config.DefaultLeaseTTL, c.maxLeaseTTL) - } - } else { - if *ent.Config.MaxLeaseTTL != 0 && *ent.Config.MaxLeaseTTL < *config.DefaultLeaseTTL { - return fmt.Errorf("Given default lease TTL of %d greater than backend max lease TTL of %d", - *config.DefaultLeaseTTL, *ent.Config.MaxLeaseTTL) - } - } - if *config.DefaultLeaseTTL == 0 { - *ent.Config.DefaultLeaseTTL = 0 - } else { - ent.Config.DefaultLeaseTTL = config.DefaultLeaseTTL - } + } + if *config.MaxLeaseTTL == 0 { + *me.Config.MaxLeaseTTL = 0 + } else { + me.Config.MaxLeaseTTL = config.MaxLeaseTTL + } + } + if config.DefaultLeaseTTL != nil { + if *me.Config.MaxLeaseTTL == 0 { + if *config.DefaultLeaseTTL > c.maxLeaseTTL { + return fmt.Errorf("Given default lease TTL of %d greater than system default lease TTL of %d", + *config.DefaultLeaseTTL, c.maxLeaseTTL) } - break + } else { + if *me.Config.MaxLeaseTTL != 0 && *me.Config.MaxLeaseTTL < *config.DefaultLeaseTTL { + return fmt.Errorf("Given default lease TTL of %d greater than backend max lease TTL of %d", + *config.DefaultLeaseTTL, *me.Config.MaxLeaseTTL) + } + } + if *config.DefaultLeaseTTL == 0 { + *me.Config.DefaultLeaseTTL = 0 + } else { + me.Config.DefaultLeaseTTL = config.DefaultLeaseTTL } } @@ -508,6 +475,7 @@ func (c *Core) persistMounts(table *MountTable) error { func (c *Core) setupMounts() error { var backend logical.Backend var view *BarrierView + var err error for _, entry := range c.mounts.Entries { // Initialize the backend, special casing for system barrierPath := backendBarrierPrefix + entry.UUID + "/" @@ -520,11 +488,7 @@ func (c *Core) setupMounts() error { // Initialize the backend // Create the new backend - sysView, err := c.mountEntrySysView(entry) - if err != nil { - return err - } - backend, err = c.newLogicalBackend(entry.Type, sysView, view, nil) + backend, err = c.newLogicalBackend(entry.Type, c.mountEntrySysView(entry), view, nil) if err != nil { c.logger.Printf( "[ERR] core: failed to create mount entry %#v: %v", @@ -533,11 +497,11 @@ func (c *Core) setupMounts() error { } if entry.Type == "system" { - c.systemView = view + c.systemBarrierView = view } // Mount the backend - err = c.router.Mount(backend, entry.Path, entry.UUID, view) + err = c.router.Mount(backend, entry.Path, entry, view) if err != nil { c.logger.Printf("[ERR] core: failed to mount entry %#v: %v", entry, err) return errLoadMountsFailed @@ -556,7 +520,7 @@ func (c *Core) setupMounts() error { func (c *Core) unloadMounts() error { c.mounts = nil c.router = NewRouter() - c.systemView = nil + c.systemBarrierView = nil return nil } @@ -582,82 +546,13 @@ func (c *Core) newLogicalBackend(t string, sysView logical.SystemView, view logi } // mountEntrySysView creates a logical.SystemView from global and -// mount-specific entries -func (c *Core) mountEntrySysView(me *MountEntry) (logical.SystemView, error) { - if me == nil { - return nil, fmt.Errorf("[ERR] core: nil MountEntry when generating SystemView") - } - - sysView := dynamicSystemView{ +// mount-specific entries; because this should be called when setting +// up a mountEntry, it doesn't check to ensure that me is not nil +func (c *Core) mountEntrySysView(me *MountEntry) logical.SystemView { + return dynamicSystemView{ core: c, path: me.Path, } - - return sysView, nil -} - -// sysViewByPath is a simple helper for MountEntrySysView -func (c *Core) sysViewByPath(path string) (logical.SystemView, error) { - // Ensure we end the path in a slash - if !strings.HasSuffix(path, "/") { - path += "/" - } - - me, err := c.mountEntryByPath(path) - if err != nil { - return nil, err - } - return c.mountEntrySysView(me) -} - -// mountEntryByPath searches across all tables to find the MountEntry -func (c *Core) mountEntryByPath(path string) (*MountEntry, error) { - // Ensure we end the path in a slash - if !strings.HasSuffix(path, "/") { - path += "/" - } - - pathSep := strings.IndexRune(path, '/') - if pathSep == -1 { - return nil, fmt.Errorf("[ERR] core: failed to find separator for path %s", path) - } - me := c.mounts.Find(path[0 : pathSep+1]) - if me == nil { - me = c.auth.Find(path[0 : pathSep+1]) - } - if me == nil { - me = c.audit.Find(path[0 : pathSep+1]) - } - if me == nil { - return nil, fmt.Errorf("[ERR] core: failed to find mount entry for path %s", path) - } - return me, nil -} - -// TTLsByPath returns the default and max TTLs corresponding to a particular -// mount point, or the system default -func (c *Core) TTLsByPath(path string) (def, max time.Duration, retErr error) { - // Ensure we end the path in a slash - if !strings.HasSuffix(path, "/") { - path += "/" - } - - me, err := c.mountEntryByPath(path) - if err != nil { - return 0, 0, err - } - - def = c.defaultLeaseTTL - max = c.maxLeaseTTL - - if me.Config.DefaultLeaseTTL != nil && *me.Config.DefaultLeaseTTL != 0 { - def = *me.Config.DefaultLeaseTTL - } - if me.Config.MaxLeaseTTL != nil && *me.Config.MaxLeaseTTL != 0 { - max = *me.Config.MaxLeaseTTL - } - - return } // defaultMountTable creates a default mount table diff --git a/vault/mount_test.go b/vault/mount_test.go index c577611c2..683282b42 100644 --- a/vault/mount_test.go +++ b/vault/mount_test.go @@ -192,7 +192,7 @@ func TestCore_Unmount_Cleanup(t *testing.T) { func TestCore_Remount(t *testing.T) { c, key, _ := TestCoreUnsealed(t) - err := c.remount("secret", "foo", MountConfig{}) + err := c.remount("secret", "foo") if err != nil { t.Fatalf("err: %v", err) } @@ -280,7 +280,7 @@ func TestCore_Remount_Cleanup(t *testing.T) { } // Remount, this should cleanup - if err := c.remount("test/", "new/", MountConfig{}); err != nil { + if err := c.remount("test/", "new/"); err != nil { t.Fatalf("err: %v", err) } @@ -309,7 +309,7 @@ func TestCore_Remount_Cleanup(t *testing.T) { func TestCore_Remount_Protected(t *testing.T) { c, _, _ := TestCoreUnsealed(t) - err := c.remount("sys", "foo", MountConfig{}) + err := c.remount("sys", "foo") if err.Error() != "cannot remount 'sys/'" { t.Fatalf("err: %v", err) } diff --git a/vault/policy_store.go b/vault/policy_store.go index 99f9256b3..bef087ef5 100644 --- a/vault/policy_store.go +++ b/vault/policy_store.go @@ -46,7 +46,7 @@ func NewPolicyStore(view *BarrierView) *PolicyStore { // when the vault is being unsealed. func (c *Core) setupPolicyStore() error { // Create a sub-view - view := c.systemView.SubView(policySubPath) + view := c.systemBarrierView.SubView(policySubPath) // Create the policy store c.policy = NewPolicyStore(view) diff --git a/vault/rollback_test.go b/vault/rollback_test.go index 293d03b5a..9083a6541 100644 --- a/vault/rollback_test.go +++ b/vault/rollback_test.go @@ -21,7 +21,7 @@ func mockRollback(t *testing.T) (*RollbackManager, *NoopBackend) { Path: "foo", }, } - if err := router.Mount(backend, "foo", uuid.GenerateUUID(), nil); err != nil { + if err := router.Mount(backend, "foo", &MountEntry{UUID: uuid.GenerateUUID()}, nil); err != nil { t.Fatalf("err: %s", err) } diff --git a/vault/router.go b/vault/router.go index 56ab811b9..8b38532cc 100644 --- a/vault/router.go +++ b/vault/router.go @@ -26,24 +26,24 @@ func NewRouter() *Router { return r } -// mountEntry is used to represent a mount point -type mountEntry struct { +// routeEntry is used to represent a mount point in the router +type routeEntry struct { tainted bool - salt string backend logical.Backend + mountEntry *MountEntry view *BarrierView rootPaths *radix.Tree loginPaths *radix.Tree } // SaltID is used to apply a salt and hash to an ID to make sure its not reversable -func (me *mountEntry) SaltID(id string) string { - return salt.SaltID(me.salt, id, salt.SHA1Hash) +func (re *routeEntry) SaltID(id string) string { + return salt.SaltID(re.mountEntry.UUID, id, salt.SHA1Hash) } // Mount is used to expose a logical backend at a given prefix, using a unique salt, // and the barrier view for that path. -func (r *Router) Mount(backend logical.Backend, prefix, salt string, view *BarrierView) error { +func (r *Router) Mount(backend logical.Backend, prefix string, mountEntry *MountEntry, view *BarrierView) error { r.l.Lock() defer r.l.Unlock() @@ -59,14 +59,15 @@ func (r *Router) Mount(backend logical.Backend, prefix, salt string, view *Barri } // Create a mount entry - me := &mountEntry{ + re := &routeEntry{ tainted: false, backend: backend, + mountEntry: mountEntry, view: view, rootPaths: pathsToRadix(paths.Root), loginPaths: pathsToRadix(paths.Unauthenticated), } - r.root.Insert(prefix, me) + r.root.Insert(prefix, re) return nil } @@ -91,12 +92,8 @@ func (r *Router) Remount(src, dst string) error { // Update the mount point r.root.Delete(src) - mountEntry, ok := raw.(*mountEntry) - if !ok { - return fmt.Errorf("Unable to retrieve mount entry at '%s'", src) - } - sysView := mountEntry.backend.System() - dynSysView, ok := sysView.(dynamicSystemView) + routeEntry := raw.(*routeEntry) + dynSysView, ok := routeEntry.backend.System().(dynamicSystemView) if ok { dynSysView.path = dst } @@ -111,7 +108,7 @@ func (r *Router) Taint(path string) error { defer r.l.Unlock() _, raw, ok := r.root.LongestPrefix(path) if ok { - raw.(*mountEntry).tainted = true + raw.(*routeEntry).tainted = true } return nil } @@ -122,7 +119,7 @@ func (r *Router) Untaint(path string) error { defer r.l.Unlock() _, raw, ok := r.root.LongestPrefix(path) if ok { - raw.(*mountEntry).tainted = false + raw.(*routeEntry).tainted = false } return nil } @@ -146,7 +143,29 @@ func (r *Router) MatchingView(path string) *BarrierView { if !ok { return nil } - return raw.(*mountEntry).view + return raw.(*routeEntry).view +} + +// MatchingMountEntry returns the MountEntry used for a path +func (r *Router) MatchingMountEntry(path string) *MountEntry { + r.l.RLock() + _, raw, ok := r.root.LongestPrefix(path) + r.l.RUnlock() + if !ok { + return nil + } + return raw.(*routeEntry).mountEntry +} + +// MatchingSystemView returns the SystemView used for a path +func (r *Router) MatchingSystemView(path string) logical.SystemView { + r.l.RLock() + _, raw, ok := r.root.LongestPrefix(path) + r.l.RUnlock() + if !ok { + return nil + } + return raw.(*routeEntry).backend.System() } // Route is used to route a given request @@ -166,11 +185,11 @@ func (r *Router) Route(req *logical.Request) (*logical.Response, error) { } defer metrics.MeasureSince([]string{"route", string(req.Operation), strings.Replace(mount, "/", "-", -1)}, time.Now()) - me := raw.(*mountEntry) + re := raw.(*routeEntry) // If the path is tainted, we reject any operation except for // Rollback and Revoke - if me.tainted { + if re.tainted { switch req.Operation { case logical.RevokeOperation, logical.RollbackOperation: default: @@ -190,12 +209,12 @@ func (r *Router) Route(req *logical.Request) (*logical.Response, error) { } // Attach the storage view for the request - req.Storage = me.view + req.Storage = re.view // Hash the request token unless this is the token backend clientToken := req.ClientToken if !strings.HasPrefix(original, "auth/token/") { - req.ClientToken = me.SaltID(req.ClientToken) + req.ClientToken = re.SaltID(req.ClientToken) } // If the request is not a login path, then clear the connection @@ -214,7 +233,7 @@ func (r *Router) Route(req *logical.Request) (*logical.Response, error) { }() // Invoke the backend - return me.backend.HandleRequest(req) + return re.backend.HandleRequest(req) } // RootPath checks if the given path requires root privileges @@ -225,13 +244,13 @@ func (r *Router) RootPath(path string) bool { if !ok { return false } - me := raw.(*mountEntry) + re := raw.(*routeEntry) // Trim to get remaining path remain := strings.TrimPrefix(path, mount) // Check the rootPaths of this backend - match, raw, ok := me.rootPaths.LongestPrefix(remain) + match, raw, ok := re.rootPaths.LongestPrefix(remain) if !ok { return false } @@ -254,13 +273,13 @@ func (r *Router) LoginPath(path string) bool { if !ok { return false } - me := raw.(*mountEntry) + re := raw.(*routeEntry) // Trim to get remaining path remain := strings.TrimPrefix(path, mount) // Check the loginPaths of this backend - match, raw, ok := me.loginPaths.LongestPrefix(remain) + match, raw, ok := re.loginPaths.LongestPrefix(remain) if !ok { return false } diff --git a/vault/router_test.go b/vault/router_test.go index 94450c9be..e1a6190c2 100644 --- a/vault/router_test.go +++ b/vault/router_test.go @@ -55,12 +55,12 @@ func TestRouter_Mount(t *testing.T) { view := NewBarrierView(barrier, "logical/") n := &NoopBackend{} - err := r.Mount(n, "prod/aws/", uuid.GenerateUUID(), view) + err := r.Mount(n, "prod/aws/", &MountEntry{UUID: uuid.GenerateUUID()}, view) if err != nil { t.Fatalf("err: %v", err) } - err = r.Mount(n, "prod/aws/", uuid.GenerateUUID(), view) + err = r.Mount(n, "prod/aws/", &MountEntry{UUID: uuid.GenerateUUID()}, view) if !strings.Contains(err.Error(), "cannot mount under existing mount") { t.Fatalf("err: %v", err) } @@ -104,7 +104,7 @@ func TestRouter_Unmount(t *testing.T) { view := NewBarrierView(barrier, "logical/") n := &NoopBackend{} - err := r.Mount(n, "prod/aws/", uuid.GenerateUUID(), view) + err := r.Mount(n, "prod/aws/", &MountEntry{UUID: uuid.GenerateUUID()}, view) if err != nil { t.Fatalf("err: %v", err) } @@ -129,7 +129,7 @@ func TestRouter_Remount(t *testing.T) { view := NewBarrierView(barrier, "logical/") n := &NoopBackend{} - err := r.Mount(n, "prod/aws/", uuid.GenerateUUID(), view) + err := r.Mount(n, "prod/aws/", &MountEntry{UUID: uuid.GenerateUUID()}, view) if err != nil { t.Fatalf("err: %v", err) } @@ -177,7 +177,7 @@ func TestRouter_RootPath(t *testing.T) { "policy/*", }, } - err := r.Mount(n, "prod/aws/", uuid.GenerateUUID(), view) + err := r.Mount(n, "prod/aws/", &MountEntry{UUID: uuid.GenerateUUID()}, view) if err != nil { t.Fatalf("err: %v", err) } @@ -215,7 +215,7 @@ func TestRouter_LoginPath(t *testing.T) { "oauth/*", }, } - err := r.Mount(n, "auth/foo/", uuid.GenerateUUID(), view) + err := r.Mount(n, "auth/foo/", &MountEntry{UUID: uuid.GenerateUUID()}, view) if err != nil { t.Fatalf("err: %v", err) } @@ -246,7 +246,7 @@ func TestRouter_Taint(t *testing.T) { view := NewBarrierView(barrier, "logical/") n := &NoopBackend{} - err := r.Mount(n, "prod/aws/", uuid.GenerateUUID(), view) + err := r.Mount(n, "prod/aws/", &MountEntry{UUID: uuid.GenerateUUID()}, view) if err != nil { t.Fatalf("err: %v", err) } @@ -285,7 +285,7 @@ func TestRouter_Untaint(t *testing.T) { view := NewBarrierView(barrier, "logical/") n := &NoopBackend{} - err := r.Mount(n, "prod/aws/", uuid.GenerateUUID(), view) + err := r.Mount(n, "prod/aws/", &MountEntry{UUID: uuid.GenerateUUID()}, view) if err != nil { t.Fatalf("err: %v", err) } diff --git a/vault/testing.go b/vault/testing.go index a4b2b085a..e8e0e2fc4 100644 --- a/vault/testing.go +++ b/vault/testing.go @@ -62,10 +62,12 @@ func TestCore(t *testing.T) *Core { }, } noopBackends := make(map[string]logical.Factory) - noopBackends["noop"] = func(*logical.BackendConfig) (logical.Backend, error) { - return new(framework.Backend), nil + noopBackends["noop"] = func(config *logical.BackendConfig) (logical.Backend, error) { + b := new(framework.Backend) + b.Setup(config) + return b, nil } - noopBackends["http"] = func(*logical.BackendConfig) (logical.Backend, error) { + noopBackends["http"] = func(config *logical.BackendConfig) (logical.Backend, error) { return new(rawHTTP), nil } logicalBackends := make(map[string]logical.Factory) diff --git a/vault/token_store.go b/vault/token_store.go index 37e22f7f7..7d3dec488 100644 --- a/vault/token_store.go +++ b/vault/token_store.go @@ -48,9 +48,9 @@ type TokenStore struct { // NewTokenStore is used to construct a token store that is // backed by the given barrier view. -func NewTokenStore(c *Core) (*TokenStore, error) { +func NewTokenStore(c *Core, config *logical.BackendConfig) (*TokenStore, error) { // Create a sub-view - view := c.systemView.SubView(tokenSubPath) + view := c.systemBarrierView.SubView(tokenSubPath) // Initialize the store t := &TokenStore{ @@ -203,6 +203,8 @@ func NewTokenStore(c *Core) (*TokenStore, error) { }, } + t.Backend.Setup(config) + return t, nil } diff --git a/vault/token_store_test.go b/vault/token_store_test.go index bda3cd415..7daac5e6c 100644 --- a/vault/token_store_test.go +++ b/vault/token_store_test.go @@ -10,19 +10,30 @@ import ( "github.com/hashicorp/vault/logical" ) +func getBackendConfig(c *Core) *logical.BackendConfig { + return &logical.BackendConfig{ + Logger: c.logger, + System: logical.StaticSystemView{ + DefaultLeaseTTLVal: time.Hour * 24, + MaxLeaseTTLVal: time.Hour * 24 * 30, + }, + } +} + func mockTokenStore(t *testing.T) (*Core, *TokenStore, string) { logger := log.New(os.Stderr, "", log.LstdFlags) c, _, root := TestCoreUnsealed(t) - ts, err := NewTokenStore(c) + + ts, err := NewTokenStore(c, getBackendConfig(c)) if err != nil { t.Fatalf("err: %v", err) } router := NewRouter() - router.Mount(ts, "auth/token/", "", ts.view) + router.Mount(ts, "auth/token/", &MountEntry{UUID: ""}, ts.view) - view := c.systemView.SubView(expirationSubPath) + view := c.systemBarrierView.SubView(expirationSubPath) exp := NewExpirationManager(router, view, ts, logger) ts.SetExpirationManager(exp) return c, ts, root @@ -68,7 +79,7 @@ func TestTokenStore_CreateLookup(t *testing.T) { } // New store should share the salt - ts2, err := NewTokenStore(c) + ts2, err := NewTokenStore(c, getBackendConfig(c)) if err != nil { t.Fatalf("err: %v", err) } @@ -107,7 +118,7 @@ func TestTokenStore_CreateLookup_ProvidedID(t *testing.T) { } // New store should share the salt - ts2, err := NewTokenStore(c) + ts2, err := NewTokenStore(c, getBackendConfig(c)) if err != nil { t.Fatalf("err: %v", err) } @@ -219,7 +230,7 @@ func TestTokenStore_Revoke_Leases(t *testing.T) { // Mount a noop backend noop := &NoopBackend{} - ts.expiration.router.Mount(noop, "", "", nil) + ts.expiration.router.Mount(noop, "", &MountEntry{UUID: ""}, nil) ent := &TokenEntry{Path: "test", Policies: []string{"dev", "ops"}} if err := ts.Create(ent); err != nil {