Add table/type checking to mounts table.

This commit is contained in:
Jeff Mitchell 2016-05-26 12:55:00 -04:00
parent 05d1da0656
commit 475b0e2d33
5 changed files with 136 additions and 13 deletions

View File

@ -965,8 +965,9 @@ func TestExpiration_RevokeForce(t *testing.T) {
core.logicalBackends["badrenew"] = badRenewFactory
me := &MountEntry{
Path: "badrenew/",
Type: "badrenew",
Table: mountTableType,
Path: "badrenew/",
Type: "badrenew",
}
err := core.mount(me)

View File

@ -720,6 +720,7 @@ func (b *SystemBackend) handleMount(
// Create the mount entry
me := &MountEntry{
Table: mountTableType,
Path: path,
Type: logicalType,
Description: description,

View File

@ -25,6 +25,9 @@ const (
// systemBarrierPrefix is the prefix used for the
// system logical backend.
systemBarrierPrefix = "sys/"
// mountTableType is the value we expect to find for the mount table and corresponding entries
mountTableType = "mounts"
)
var (
@ -55,6 +58,7 @@ var (
// MountTable is used to represent the internal mount table
type MountTable struct {
Type string `json:"type"`
Entries []*MountEntry `json:"entries"`
}
@ -64,6 +68,7 @@ type MountTable struct {
// if modifying entries rather than modifying the table itself
func (t *MountTable) ShallowClone() *MountTable {
mt := &MountTable{
Type: t.Type,
Entries: make([]*MountEntry, len(t.Entries)),
}
for i, e := range t.Entries {
@ -120,6 +125,7 @@ func (t *MountTable) Remove(path string) bool {
// MountEntry is used to represent a mount table entry
type MountEntry struct {
Table string `json:"table"` // The table it belongs to
Path string `json:"path"` // Mount Path
Type string `json:"type"` // Logical backend Type
Description string `json:"description"` // User-provided description
@ -142,6 +148,7 @@ func (e *MountEntry) Clone() *MountEntry {
optClone[k] = v
}
return &MountEntry{
Table: e.Table,
Path: e.Path,
Type: e.Type,
Description: e.Description,
@ -409,6 +416,13 @@ func (c *Core) loadMounts() error {
// by type only.
if c.mounts != nil {
needPersist := false
// Upgrade to typed mount table
if c.mounts.Type == "" {
c.mounts.Type = mountTableType
needPersist = true
}
for _, requiredMount := range requiredMountTable().Entries {
foundRequired := false
for _, coreMount := range c.mounts.Entries {
@ -423,6 +437,14 @@ func (c *Core) loadMounts() error {
}
}
// Upgrade to table-scoped entries
for _, entry := range c.mounts.Entries {
if entry.Table == "" {
entry.Table = c.mounts.Type
needPersist = true
}
}
// Done if we have restored the mount table and we don't need
// to persist
if !needPersist {
@ -441,6 +463,25 @@ func (c *Core) loadMounts() error {
// persistMounts is used to persist the mount table after modification
func (c *Core) persistMounts(table *MountTable) error {
if table.Type != mountTableType {
c.logger.Printf(
"[ERR] core: given table to persist has type %s but need type %s",
table.Type,
mountTableType)
return fmt.Errorf("invalid table type given, not persisting")
}
for _, entry := range table.Entries {
if entry.Table != table.Type {
c.logger.Printf(
"[ERR] core: entry in mount table with path %s has table value %s but is in table %s, refusing to persist",
entry.Path,
entry.Table,
table.Type)
return fmt.Errorf("invalid mount entry found, not persisting")
}
}
// Marshal the table
raw, err := json.Marshal(table)
if err != nil {
@ -574,12 +615,15 @@ func (c *Core) mountEntrySysView(me *MountEntry) logical.SystemView {
// defaultMountTable creates a default mount table
func defaultMountTable() *MountTable {
table := &MountTable{}
table := &MountTable{
Type: mountTableType,
}
mountUUID, err := uuid.GenerateUUID()
if err != nil {
panic(fmt.Sprintf("could not create default mount table UUID: %v", err))
}
genericMount := &MountEntry{
Table: mountTableType,
Path: "secret/",
Type: "generic",
Description: "generic secret storage",
@ -593,12 +637,15 @@ func defaultMountTable() *MountTable {
// requiredMountTable() creates a mount table with entries required
// to be available
func requiredMountTable() *MountTable {
table := &MountTable{}
table := &MountTable{
Type: mountTableType,
}
cubbyholeUUID, err := uuid.GenerateUUID()
if err != nil {
panic(fmt.Sprintf("could not create cubbyhole UUID: %v", err))
}
cubbyholeMount := &MountEntry{
Table: mountTableType,
Path: "cubbyhole/",
Type: "cubbyhole",
Description: "per-token private secret storage",
@ -610,6 +657,7 @@ func requiredMountTable() *MountTable {
panic(fmt.Sprintf("could not create sys UUID: %v", err))
}
sysMount := &MountEntry{
Table: mountTableType,
Path: "sys/",
Type: "system",
Description: "system endpoints used for control, policy and debugging",

View File

@ -1,6 +1,7 @@
package vault
import (
"encoding/json"
"reflect"
"testing"
"time"
@ -38,8 +39,9 @@ func TestCore_DefaultMountTable(t *testing.T) {
func TestCore_Mount(t *testing.T) {
c, key, _ := TestCoreUnsealed(t)
me := &MountEntry{
Path: "foo",
Type: "generic",
Table: mountTableType,
Path: "foo",
Type: "generic",
}
err := c.mount(me)
if err != nil {
@ -116,8 +118,9 @@ func TestCore_Unmount_Cleanup(t *testing.T) {
// Mount the noop backend
me := &MountEntry{
Path: "test/",
Type: "noop",
Table: mountTableType,
Path: "test/",
Type: "noop",
}
if err := c.mount(me); err != nil {
t.Fatalf("err: %v", err)
@ -233,8 +236,9 @@ func TestCore_Remount_Cleanup(t *testing.T) {
// Mount the noop backend
me := &MountEntry{
Path: "test/",
Type: "noop",
Table: mountTableType,
Path: "test/",
Type: "noop",
}
if err := c.mount(me); err != nil {
t.Fatalf("err: %v", err)
@ -320,6 +324,71 @@ func TestDefaultMountTable(t *testing.T) {
verifyDefaultTable(t, table)
}
func TestCore_MountTable_UpgradeToTyped(t *testing.T) {
c, _, _ := TestCoreUnsealed(t)
// Save the expected table
goodJson, err := json.Marshal(c.mounts)
if err != nil {
t.Fatal(err)
}
// Create a pre-typed version
c.mounts.Type = ""
for _, entry := range c.mounts.Entries {
entry.Table = ""
}
raw, err := json.Marshal(c.mounts)
if err != nil {
t.Fatal(err)
}
if reflect.DeepEqual(raw, goodJson) {
t.Fatalf("bad: values here should be different")
}
entry := &Entry{
Key: coreMountConfigPath,
Value: raw,
}
if err := c.barrier.Put(entry); err != nil {
t.Fatal(err)
}
// It should load successfully and be upgraded and persisted
err = c.loadMounts()
if err != nil {
t.Fatal(err)
}
entry, err = c.barrier.Get(coreMountConfigPath)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(entry.Value, goodJson) {
t.Fatalf("bad: expected\n%s\ngot\n%s\n", string(goodJson), string(entry.Value))
}
// Now try saving invalid versions
c.mounts.Type = "auth"
if err := c.persistMounts(c.mounts); err == nil {
t.Fatal("expected error")
}
c.mounts.Type = mountTableType
c.mounts.Entries[0].Table = "foobar"
if err := c.persistMounts(c.mounts); err == nil {
t.Fatal("expected error")
}
c.mounts.Entries[0].Table = c.mounts.Type
if err := c.persistMounts(c.mounts); err != nil {
t.Fatal(err)
}
}
func verifyDefaultTable(t *testing.T, table *MountTable) {
if len(table.Entries) != 3 {
t.Fatalf("bad: %v", table.Entries)
@ -348,6 +417,9 @@ func verifyDefaultTable(t *testing.T, table *MountTable) {
t.Fatalf("bad: %v", entry)
}
}
if entry.Table != mountTableType {
t.Fatalf("bad: %v", entry)
}
if entry.Description == "" {
t.Fatalf("bad: %v", entry)
}

View File

@ -15,9 +15,10 @@ func TestRequestHandling_Wrapping(t *testing.T) {
meUUID, _ := uuid.GenerateUUID()
err := core.mount(&MountEntry{
UUID: meUUID,
Path: "wraptest",
Type: "generic",
Table: mountTableType,
UUID: meUUID,
Path: "wraptest",
Type: "generic",
})
if err != nil {
t.Fatalf("err: %v", err)