From 475b0e2d339553a6fd697f71207be54e9cbe77bc Mon Sep 17 00:00:00 2001 From: Jeff Mitchell Date: Thu, 26 May 2016 12:55:00 -0400 Subject: [PATCH] Add table/type checking to mounts table. --- vault/expiration_test.go | 5 +- vault/logical_system.go | 1 + vault/mount.go | 52 ++++++++++++++++++++- vault/mount_test.go | 84 +++++++++++++++++++++++++++++++--- vault/request_handling_test.go | 7 +-- 5 files changed, 136 insertions(+), 13 deletions(-) diff --git a/vault/expiration_test.go b/vault/expiration_test.go index ae8d45b46..c5a4f8004 100644 --- a/vault/expiration_test.go +++ b/vault/expiration_test.go @@ -965,8 +965,9 @@ func TestExpiration_RevokeForce(t *testing.T) { core.logicalBackends["badrenew"] = badRenewFactory me := &MountEntry{ - Path: "badrenew/", - Type: "badrenew", + Table: mountTableType, + Path: "badrenew/", + Type: "badrenew", } err := core.mount(me) diff --git a/vault/logical_system.go b/vault/logical_system.go index d3976b016..06b7352de 100644 --- a/vault/logical_system.go +++ b/vault/logical_system.go @@ -720,6 +720,7 @@ func (b *SystemBackend) handleMount( // Create the mount entry me := &MountEntry{ + Table: mountTableType, Path: path, Type: logicalType, Description: description, diff --git a/vault/mount.go b/vault/mount.go index 7a06e8e0d..ff5be16c9 100644 --- a/vault/mount.go +++ b/vault/mount.go @@ -25,6 +25,9 @@ const ( // systemBarrierPrefix is the prefix used for the // system logical backend. systemBarrierPrefix = "sys/" + + // mountTableType is the value we expect to find for the mount table and corresponding entries + mountTableType = "mounts" ) var ( @@ -55,6 +58,7 @@ var ( // MountTable is used to represent the internal mount table type MountTable struct { + Type string `json:"type"` Entries []*MountEntry `json:"entries"` } @@ -64,6 +68,7 @@ type MountTable struct { // if modifying entries rather than modifying the table itself func (t *MountTable) ShallowClone() *MountTable { mt := &MountTable{ + Type: t.Type, Entries: make([]*MountEntry, len(t.Entries)), } for i, e := range t.Entries { @@ -120,6 +125,7 @@ func (t *MountTable) Remove(path string) bool { // MountEntry is used to represent a mount table entry type MountEntry struct { + Table string `json:"table"` // The table it belongs to Path string `json:"path"` // Mount Path Type string `json:"type"` // Logical backend Type Description string `json:"description"` // User-provided description @@ -142,6 +148,7 @@ func (e *MountEntry) Clone() *MountEntry { optClone[k] = v } return &MountEntry{ + Table: e.Table, Path: e.Path, Type: e.Type, Description: e.Description, @@ -409,6 +416,13 @@ func (c *Core) loadMounts() error { // by type only. if c.mounts != nil { needPersist := false + + // Upgrade to typed mount table + if c.mounts.Type == "" { + c.mounts.Type = mountTableType + needPersist = true + } + for _, requiredMount := range requiredMountTable().Entries { foundRequired := false for _, coreMount := range c.mounts.Entries { @@ -423,6 +437,14 @@ func (c *Core) loadMounts() error { } } + // Upgrade to table-scoped entries + for _, entry := range c.mounts.Entries { + if entry.Table == "" { + entry.Table = c.mounts.Type + needPersist = true + } + } + // Done if we have restored the mount table and we don't need // to persist if !needPersist { @@ -441,6 +463,25 @@ func (c *Core) loadMounts() error { // persistMounts is used to persist the mount table after modification func (c *Core) persistMounts(table *MountTable) error { + if table.Type != mountTableType { + c.logger.Printf( + "[ERR] core: given table to persist has type %s but need type %s", + table.Type, + mountTableType) + return fmt.Errorf("invalid table type given, not persisting") + } + + for _, entry := range table.Entries { + if entry.Table != table.Type { + c.logger.Printf( + "[ERR] core: entry in mount table with path %s has table value %s but is in table %s, refusing to persist", + entry.Path, + entry.Table, + table.Type) + return fmt.Errorf("invalid mount entry found, not persisting") + } + } + // Marshal the table raw, err := json.Marshal(table) if err != nil { @@ -574,12 +615,15 @@ func (c *Core) mountEntrySysView(me *MountEntry) logical.SystemView { // defaultMountTable creates a default mount table func defaultMountTable() *MountTable { - table := &MountTable{} + table := &MountTable{ + Type: mountTableType, + } mountUUID, err := uuid.GenerateUUID() if err != nil { panic(fmt.Sprintf("could not create default mount table UUID: %v", err)) } genericMount := &MountEntry{ + Table: mountTableType, Path: "secret/", Type: "generic", Description: "generic secret storage", @@ -593,12 +637,15 @@ func defaultMountTable() *MountTable { // requiredMountTable() creates a mount table with entries required // to be available func requiredMountTable() *MountTable { - table := &MountTable{} + table := &MountTable{ + Type: mountTableType, + } cubbyholeUUID, err := uuid.GenerateUUID() if err != nil { panic(fmt.Sprintf("could not create cubbyhole UUID: %v", err)) } cubbyholeMount := &MountEntry{ + Table: mountTableType, Path: "cubbyhole/", Type: "cubbyhole", Description: "per-token private secret storage", @@ -610,6 +657,7 @@ func requiredMountTable() *MountTable { panic(fmt.Sprintf("could not create sys UUID: %v", err)) } sysMount := &MountEntry{ + Table: mountTableType, Path: "sys/", Type: "system", Description: "system endpoints used for control, policy and debugging", diff --git a/vault/mount_test.go b/vault/mount_test.go index aa90bfd4d..fc269f739 100644 --- a/vault/mount_test.go +++ b/vault/mount_test.go @@ -1,6 +1,7 @@ package vault import ( + "encoding/json" "reflect" "testing" "time" @@ -38,8 +39,9 @@ func TestCore_DefaultMountTable(t *testing.T) { func TestCore_Mount(t *testing.T) { c, key, _ := TestCoreUnsealed(t) me := &MountEntry{ - Path: "foo", - Type: "generic", + Table: mountTableType, + Path: "foo", + Type: "generic", } err := c.mount(me) if err != nil { @@ -116,8 +118,9 @@ func TestCore_Unmount_Cleanup(t *testing.T) { // Mount the noop backend me := &MountEntry{ - Path: "test/", - Type: "noop", + Table: mountTableType, + Path: "test/", + Type: "noop", } if err := c.mount(me); err != nil { t.Fatalf("err: %v", err) @@ -233,8 +236,9 @@ func TestCore_Remount_Cleanup(t *testing.T) { // Mount the noop backend me := &MountEntry{ - Path: "test/", - Type: "noop", + Table: mountTableType, + Path: "test/", + Type: "noop", } if err := c.mount(me); err != nil { t.Fatalf("err: %v", err) @@ -320,6 +324,71 @@ func TestDefaultMountTable(t *testing.T) { verifyDefaultTable(t, table) } +func TestCore_MountTable_UpgradeToTyped(t *testing.T) { + c, _, _ := TestCoreUnsealed(t) + + // Save the expected table + goodJson, err := json.Marshal(c.mounts) + if err != nil { + t.Fatal(err) + } + + // Create a pre-typed version + c.mounts.Type = "" + for _, entry := range c.mounts.Entries { + entry.Table = "" + } + + raw, err := json.Marshal(c.mounts) + if err != nil { + t.Fatal(err) + } + + if reflect.DeepEqual(raw, goodJson) { + t.Fatalf("bad: values here should be different") + } + + entry := &Entry{ + Key: coreMountConfigPath, + Value: raw, + } + if err := c.barrier.Put(entry); err != nil { + t.Fatal(err) + } + + // It should load successfully and be upgraded and persisted + err = c.loadMounts() + if err != nil { + t.Fatal(err) + } + + entry, err = c.barrier.Get(coreMountConfigPath) + if err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(entry.Value, goodJson) { + t.Fatalf("bad: expected\n%s\ngot\n%s\n", string(goodJson), string(entry.Value)) + } + + // Now try saving invalid versions + c.mounts.Type = "auth" + if err := c.persistMounts(c.mounts); err == nil { + t.Fatal("expected error") + } + + c.mounts.Type = mountTableType + c.mounts.Entries[0].Table = "foobar" + if err := c.persistMounts(c.mounts); err == nil { + t.Fatal("expected error") + } + + c.mounts.Entries[0].Table = c.mounts.Type + if err := c.persistMounts(c.mounts); err != nil { + t.Fatal(err) + } +} + func verifyDefaultTable(t *testing.T, table *MountTable) { if len(table.Entries) != 3 { t.Fatalf("bad: %v", table.Entries) @@ -348,6 +417,9 @@ func verifyDefaultTable(t *testing.T, table *MountTable) { t.Fatalf("bad: %v", entry) } } + if entry.Table != mountTableType { + t.Fatalf("bad: %v", entry) + } if entry.Description == "" { t.Fatalf("bad: %v", entry) } diff --git a/vault/request_handling_test.go b/vault/request_handling_test.go index 02003d843..a7659ec48 100644 --- a/vault/request_handling_test.go +++ b/vault/request_handling_test.go @@ -15,9 +15,10 @@ func TestRequestHandling_Wrapping(t *testing.T) { meUUID, _ := uuid.GenerateUUID() err := core.mount(&MountEntry{ - UUID: meUUID, - Path: "wraptest", - Type: "generic", + Table: mountTableType, + UUID: meUUID, + Path: "wraptest", + Type: "generic", }) if err != nil { t.Fatalf("err: %v", err)