diff --git a/vault/audit.go b/vault/audit.go index 36828c806..5857bd19f 100644 --- a/vault/audit.go +++ b/vault/audit.go @@ -26,6 +26,10 @@ const ( // auditBarrierPrefix is the prefix to the UUID used in the // barrier view for the audit backends. auditBarrierPrefix = "audit/" + + // auditTableType is the value we expect to find for the audit table and + // corresponding entries + auditTableType = "audit" ) var ( @@ -146,6 +150,26 @@ func (c *Core) loadAudits() error { // Done if we have restored the audit table if c.audit != nil { + needPersist := false + + // Upgrade to typed auth table + if c.audit.Type == "" { + c.audit.Type = auditTableType + needPersist = true + } + + // Upgrade to table-scoped entries + for _, entry := range c.audit.Entries { + if entry.Table == "" { + entry.Table = c.audit.Type + needPersist = true + } + } + + if needPersist { + return c.persistAudit(c.audit) + } + return nil } @@ -159,6 +183,25 @@ func (c *Core) loadAudits() error { // persistAudit is used to persist the audit table after modification func (c *Core) persistAudit(table *MountTable) error { + if table.Type != auditTableType { + c.logger.Printf( + "[ERR] core: given table to persist has type %s but need type %s", + table.Type, + auditTableType) + return fmt.Errorf("invalid table type given, not persisting") + } + + for _, entry := range table.Entries { + if entry.Table != table.Type { + c.logger.Printf( + "[ERR] core: entry in audit table with path %s has table value %s but is in table %s, refusing to persist", + entry.Path, + entry.Table, + table.Type) + return fmt.Errorf("invalid audit entry found, not persisting") + } + } + // Marshal the table raw, err := json.Marshal(table) if err != nil { @@ -240,7 +283,9 @@ func (c *Core) newAuditBackend(t string, view logical.Storage, conf map[string]s // defaultAuditTable creates a default audit table func defaultAuditTable() *MountTable { - table := &MountTable{} + table := &MountTable{ + Type: auditTableType, + } return table } diff --git a/vault/audit_test.go b/vault/audit_test.go index 385c716d7..f673e0a26 100644 --- a/vault/audit_test.go +++ b/vault/audit_test.go @@ -56,8 +56,9 @@ func TestCore_EnableAudit(t *testing.T) { } me := &MountEntry{ - Path: "foo", - Type: "noop", + Table: auditTableType, + Path: "foo", + Type: "noop", } err := c.enableAudit(me) if err != nil { @@ -115,8 +116,9 @@ func TestCore_DisableAudit(t *testing.T) { } me := &MountEntry{ - Path: "foo", - Type: "noop", + Table: auditTableType, + Path: "foo", + Type: "noop", } err = c.enableAudit(me) if err != nil { @@ -196,6 +198,9 @@ func verifyDefaultAuditTable(t *testing.T, table *MountTable) { if len(table.Entries) != 0 { t.Fatalf("bad: %v", table.Entries) } + if table.Type != auditTableType { + t.Fatalf("bad: %v", *table) + } } func TestAuditBroker_LogRequest(t *testing.T) { diff --git a/vault/auth.go b/vault/auth.go index 1db4b7859..2c1f2d286 100644 --- a/vault/auth.go +++ b/vault/auth.go @@ -22,6 +22,10 @@ const ( // credentialRoutePrefix is the mount prefix used for the router credentialRoutePrefix = "auth/" + + // credentialTableType is the value we expect to find for the credential + // table and corresponding entries + credentialTableType = "auth" ) var ( @@ -210,6 +214,26 @@ func (c *Core) loadCredentials() error { // Done if we have restored the auth table if c.auth != nil { + needPersist := false + + // Upgrade to typed auth table + if c.auth.Type == "" { + c.auth.Type = credentialTableType + needPersist = true + } + + // Upgrade to table-scoped entries + for _, entry := range c.auth.Entries { + if entry.Table == "" { + entry.Table = c.auth.Type + needPersist = true + } + } + + if needPersist { + return c.persistAuth(c.auth) + } + return nil } @@ -224,6 +248,25 @@ func (c *Core) loadCredentials() error { // persistAuth is used to persist the auth table after modification func (c *Core) persistAuth(table *MountTable) error { + if table.Type != credentialTableType { + c.logger.Printf( + "[ERR] core: given table to persist has type %s but need type %s", + table.Type, + credentialTableType) + return fmt.Errorf("invalid table type given, not persisting") + } + + for _, entry := range table.Entries { + if entry.Table != table.Type { + c.logger.Printf( + "[ERR] core: entry in auth table with path %s has table value %s but is in table %s, refusing to persist", + entry.Path, + entry.Table, + table.Type) + return fmt.Errorf("invalid auth entry found, not persisting") + } + } + // Marshal the table raw, err := json.Marshal(table) if err != nil { @@ -341,12 +384,15 @@ func (c *Core) newCredentialBackend( // defaultAuthTable creates a default auth table func defaultAuthTable() *MountTable { - table := &MountTable{} + table := &MountTable{ + Type: credentialTableType, + } tokenUUID, err := uuid.GenerateUUID() if err != nil { panic(fmt.Sprintf("could not generate UUID for default auth table token entry: %v", err)) } tokenAuth := &MountEntry{ + Table: credentialTableType, Path: "token/", Type: "token", Description: "token based credentials", diff --git a/vault/auth_test.go b/vault/auth_test.go index edf103f95..40dbf8ef4 100644 --- a/vault/auth_test.go +++ b/vault/auth_test.go @@ -41,8 +41,9 @@ func TestCore_EnableCredential(t *testing.T) { } me := &MountEntry{ - Path: "foo", - Type: "noop", + Table: credentialTableType, + Path: "foo", + Type: "noop", } err := c.enableCredential(me) if err != nil { @@ -86,8 +87,9 @@ func TestCore_EnableCredential_twice_409(t *testing.T) { } me := &MountEntry{ - Path: "foo", - Type: "noop", + Table: credentialTableType, + Path: "foo", + Type: "noop", } err := c.enableCredential(me) if err != nil { @@ -109,8 +111,9 @@ func TestCore_EnableCredential_twice_409(t *testing.T) { func TestCore_EnableCredential_Token(t *testing.T) { c, _, _ := TestCoreUnsealed(t) me := &MountEntry{ - Path: "foo", - Type: "token", + Table: credentialTableType, + Path: "foo", + Type: "token", } err := c.enableCredential(me) if err.Error() != "token credential backend cannot be instantiated" { @@ -130,8 +133,9 @@ func TestCore_DisableCredential(t *testing.T) { } me := &MountEntry{ - Path: "foo", - Type: "noop", + Table: credentialTableType, + Path: "foo", + Type: "noop", } err = c.enableCredential(me) if err != nil { @@ -188,8 +192,9 @@ func TestCore_DisableCredential_Cleanup(t *testing.T) { } me := &MountEntry{ - Path: "foo", - Type: "noop", + Table: credentialTableType, + Path: "foo", + Type: "noop", } err := c.enableCredential(me) if err != nil { @@ -260,6 +265,9 @@ func verifyDefaultAuthTable(t *testing.T, table *MountTable) { if len(table.Entries) != 1 { t.Fatalf("bad: %v", table.Entries) } + if table.Type != credentialTableType { + t.Fatalf("bad: %v", *table) + } for idx, entry := range table.Entries { switch idx { case 0: diff --git a/vault/expiration_test.go b/vault/expiration_test.go index ae8d45b46..c5a4f8004 100644 --- a/vault/expiration_test.go +++ b/vault/expiration_test.go @@ -965,8 +965,9 @@ func TestExpiration_RevokeForce(t *testing.T) { core.logicalBackends["badrenew"] = badRenewFactory me := &MountEntry{ - Path: "badrenew/", - Type: "badrenew", + Table: mountTableType, + Path: "badrenew/", + Type: "badrenew", } err := core.mount(me) diff --git a/vault/logical_system.go b/vault/logical_system.go index d3976b016..a0784a068 100644 --- a/vault/logical_system.go +++ b/vault/logical_system.go @@ -720,6 +720,7 @@ func (b *SystemBackend) handleMount( // Create the mount entry me := &MountEntry{ + Table: mountTableType, Path: path, Type: logicalType, Description: description, @@ -1001,6 +1002,7 @@ func (b *SystemBackend) handleEnableAuth( // Create the mount entry me := &MountEntry{ + Table: credentialTableType, Path: path, Type: logicalType, Description: description, @@ -1169,6 +1171,7 @@ func (b *SystemBackend) handleEnableAudit( // Create the mount entry me := &MountEntry{ + Table: auditTableType, Path: path, Type: backendType, Description: description, diff --git a/vault/mount.go b/vault/mount.go index 7a06e8e0d..d809a9a18 100644 --- a/vault/mount.go +++ b/vault/mount.go @@ -25,6 +25,10 @@ const ( // systemBarrierPrefix is the prefix used for the // system logical backend. systemBarrierPrefix = "sys/" + + // mountTableType is the value we expect to find for the mount table and + // corresponding entries + mountTableType = "mounts" ) var ( @@ -55,6 +59,7 @@ var ( // MountTable is used to represent the internal mount table type MountTable struct { + Type string `json:"type"` Entries []*MountEntry `json:"entries"` } @@ -64,6 +69,7 @@ type MountTable struct { // if modifying entries rather than modifying the table itself func (t *MountTable) ShallowClone() *MountTable { mt := &MountTable{ + Type: t.Type, Entries: make([]*MountEntry, len(t.Entries)), } for i, e := range t.Entries { @@ -120,6 +126,7 @@ func (t *MountTable) Remove(path string) bool { // MountEntry is used to represent a mount table entry type MountEntry struct { + Table string `json:"table"` // The table it belongs to Path string `json:"path"` // Mount Path Type string `json:"type"` // Logical backend Type Description string `json:"description"` // User-provided description @@ -142,6 +149,7 @@ func (e *MountEntry) Clone() *MountEntry { optClone[k] = v } return &MountEntry{ + Table: e.Table, Path: e.Path, Type: e.Type, Description: e.Description, @@ -409,6 +417,13 @@ func (c *Core) loadMounts() error { // by type only. if c.mounts != nil { needPersist := false + + // Upgrade to typed mount table + if c.mounts.Type == "" { + c.mounts.Type = mountTableType + needPersist = true + } + for _, requiredMount := range requiredMountTable().Entries { foundRequired := false for _, coreMount := range c.mounts.Entries { @@ -423,6 +438,14 @@ func (c *Core) loadMounts() error { } } + // Upgrade to table-scoped entries + for _, entry := range c.mounts.Entries { + if entry.Table == "" { + entry.Table = c.mounts.Type + needPersist = true + } + } + // Done if we have restored the mount table and we don't need // to persist if !needPersist { @@ -441,6 +464,25 @@ func (c *Core) loadMounts() error { // persistMounts is used to persist the mount table after modification func (c *Core) persistMounts(table *MountTable) error { + if table.Type != mountTableType { + c.logger.Printf( + "[ERR] core: given table to persist has type %s but need type %s", + table.Type, + mountTableType) + return fmt.Errorf("invalid table type given, not persisting") + } + + for _, entry := range table.Entries { + if entry.Table != table.Type { + c.logger.Printf( + "[ERR] core: entry in mount table with path %s has table value %s but is in table %s, refusing to persist", + entry.Path, + entry.Table, + table.Type) + return fmt.Errorf("invalid mount entry found, not persisting") + } + } + // Marshal the table raw, err := json.Marshal(table) if err != nil { @@ -574,12 +616,15 @@ func (c *Core) mountEntrySysView(me *MountEntry) logical.SystemView { // defaultMountTable creates a default mount table func defaultMountTable() *MountTable { - table := &MountTable{} + table := &MountTable{ + Type: mountTableType, + } mountUUID, err := uuid.GenerateUUID() if err != nil { panic(fmt.Sprintf("could not create default mount table UUID: %v", err)) } genericMount := &MountEntry{ + Table: mountTableType, Path: "secret/", Type: "generic", Description: "generic secret storage", @@ -593,12 +638,15 @@ func defaultMountTable() *MountTable { // requiredMountTable() creates a mount table with entries required // to be available func requiredMountTable() *MountTable { - table := &MountTable{} + table := &MountTable{ + Type: mountTableType, + } cubbyholeUUID, err := uuid.GenerateUUID() if err != nil { panic(fmt.Sprintf("could not create cubbyhole UUID: %v", err)) } cubbyholeMount := &MountEntry{ + Table: mountTableType, Path: "cubbyhole/", Type: "cubbyhole", Description: "per-token private secret storage", @@ -610,6 +658,7 @@ func requiredMountTable() *MountTable { panic(fmt.Sprintf("could not create sys UUID: %v", err)) } sysMount := &MountEntry{ + Table: mountTableType, Path: "sys/", Type: "system", Description: "system endpoints used for control, policy and debugging", diff --git a/vault/mount_test.go b/vault/mount_test.go index aa90bfd4d..dde024144 100644 --- a/vault/mount_test.go +++ b/vault/mount_test.go @@ -1,10 +1,12 @@ package vault import ( + "encoding/json" "reflect" "testing" "time" + "github.com/hashicorp/vault/audit" "github.com/hashicorp/vault/logical" ) @@ -38,8 +40,9 @@ func TestCore_DefaultMountTable(t *testing.T) { func TestCore_Mount(t *testing.T) { c, key, _ := TestCoreUnsealed(t) me := &MountEntry{ - Path: "foo", - Type: "generic", + Table: mountTableType, + Path: "foo", + Type: "generic", } err := c.mount(me) if err != nil { @@ -116,8 +119,9 @@ func TestCore_Unmount_Cleanup(t *testing.T) { // Mount the noop backend me := &MountEntry{ - Path: "test/", - Type: "noop", + Table: mountTableType, + Path: "test/", + Type: "noop", } if err := c.mount(me); err != nil { t.Fatalf("err: %v", err) @@ -233,8 +237,9 @@ func TestCore_Remount_Cleanup(t *testing.T) { // Mount the noop backend me := &MountEntry{ - Path: "test/", - Type: "noop", + Table: mountTableType, + Path: "test/", + Type: "noop", } if err := c.mount(me); err != nil { t.Fatalf("err: %v", err) @@ -320,6 +325,143 @@ func TestDefaultMountTable(t *testing.T) { verifyDefaultTable(t, table) } +func TestCore_MountTable_UpgradeToTyped(t *testing.T) { + c, _, _ := TestCoreUnsealed(t) + + c.auditBackends["noop"] = func(config *audit.BackendConfig) (audit.Backend, error) { + return &NoopAudit{ + Config: config, + }, nil + } + + me := &MountEntry{ + Table: auditTableType, + Path: "foo", + Type: "noop", + } + err := c.enableAudit(me) + if err != nil { + t.Fatalf("err: %v", err) + } + + c.credentialBackends["noop"] = func(*logical.BackendConfig) (logical.Backend, error) { + return &NoopBackend{}, nil + } + + me = &MountEntry{ + Table: credentialTableType, + Path: "foo", + Type: "noop", + } + err = c.enableCredential(me) + if err != nil { + t.Fatalf("err: %v", err) + } + + testCore_MountTable_UpgradeToTyped_Common(t, c, "mounts") + testCore_MountTable_UpgradeToTyped_Common(t, c, "audits") + testCore_MountTable_UpgradeToTyped_Common(t, c, "credentials") +} + +func testCore_MountTable_UpgradeToTyped_Common( + t *testing.T, + c *Core, + testType string) { + + var path string + var mt *MountTable + switch testType { + case "mounts": + path = coreMountConfigPath + mt = c.mounts + case "audits": + path = coreAuditConfigPath + mt = c.audit + case "credentials": + path = coreAuthConfigPath + mt = c.auth + } + + // Save the expected table + goodJson, err := json.Marshal(mt) + if err != nil { + t.Fatal(err) + } + + // Create a pre-typed version + mt.Type = "" + for _, entry := range mt.Entries { + entry.Table = "" + } + + raw, err := json.Marshal(mt) + if err != nil { + t.Fatal(err) + } + + if reflect.DeepEqual(raw, goodJson) { + t.Fatalf("bad: values here should be different") + } + + entry := &Entry{ + Key: path, + Value: raw, + } + if err := c.barrier.Put(entry); err != nil { + t.Fatal(err) + } + + var persistFunc func(*MountTable) error + + // It should load successfully and be upgraded and persisted + switch testType { + case "mounts": + err = c.loadMounts() + persistFunc = c.persistMounts + mt = c.mounts + case "credentials": + err = c.loadCredentials() + persistFunc = c.persistAuth + mt = c.auth + case "audits": + err = c.loadAudits() + persistFunc = c.persistAudit + mt = c.audit + } + if err != nil { + t.Fatal(err) + } + + entry, err = c.barrier.Get(path) + if err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(entry.Value, goodJson) { + t.Fatalf("bad: expected\n%s\ngot\n%s\n", string(goodJson), string(entry.Value)) + } + + // Now try saving invalid versions + origTableType := mt.Type + mt.Type = "foo" + if err := persistFunc(mt); err == nil { + t.Fatal("expected error") + } + + if len(mt.Entries) > 0 { + mt.Type = origTableType + mt.Entries[0].Table = "bar" + if err := persistFunc(mt); err == nil { + t.Fatal("expected error") + } + + mt.Entries[0].Table = mt.Type + if err := persistFunc(mt); err != nil { + t.Fatal(err) + } + } +} + func verifyDefaultTable(t *testing.T, table *MountTable) { if len(table.Entries) != 3 { t.Fatalf("bad: %v", table.Entries) @@ -348,6 +490,9 @@ func verifyDefaultTable(t *testing.T, table *MountTable) { t.Fatalf("bad: %v", entry) } } + if entry.Table != mountTableType { + t.Fatalf("bad: %v", entry) + } if entry.Description == "" { t.Fatalf("bad: %v", entry) } diff --git a/vault/request_handling_test.go b/vault/request_handling_test.go index 02003d843..a7659ec48 100644 --- a/vault/request_handling_test.go +++ b/vault/request_handling_test.go @@ -15,9 +15,10 @@ func TestRequestHandling_Wrapping(t *testing.T) { meUUID, _ := uuid.GenerateUUID() err := core.mount(&MountEntry{ - UUID: meUUID, - Path: "wraptest", - Type: "generic", + Table: mountTableType, + UUID: meUUID, + Path: "wraptest", + Type: "generic", }) if err != nil { t.Fatalf("err: %v", err) diff --git a/vault/testing.go b/vault/testing.go index 1f1647709..e84e598bd 100644 --- a/vault/testing.go +++ b/vault/testing.go @@ -167,6 +167,7 @@ func TestCoreWithTokenStore(t *testing.T) (*Core, *TokenStore, []byte, string) { c, key, root := TestCoreUnsealed(t) me := &MountEntry{ + Table: credentialTableType, Path: "token/", Type: "token", Description: "token based credentials", @@ -184,7 +185,7 @@ func TestCoreWithTokenStore(t *testing.T) (*Core, *TokenStore, []byte, string) { ts := tokenstore.(*TokenStore) router := NewRouter() - router.Mount(ts, "auth/token/", &MountEntry{UUID: ""}, ts.view) + router.Mount(ts, "auth/token/", &MountEntry{Table: credentialTableType, UUID: ""}, ts.view) subview := c.systemBarrierView.SubView(expirationSubPath) logger := log.New(os.Stderr, "", log.LstdFlags)