Merge pull request #1461 from hashicorp/check-table-sanity

Add table/type checking to mounts table.
This commit is contained in:
Jeff Mitchell 2016-05-26 13:59:00 -04:00
commit 4152ee31d0
10 changed files with 334 additions and 30 deletions

View file

@ -26,6 +26,10 @@ const (
// auditBarrierPrefix is the prefix to the UUID used in the // auditBarrierPrefix is the prefix to the UUID used in the
// barrier view for the audit backends. // barrier view for the audit backends.
auditBarrierPrefix = "audit/" auditBarrierPrefix = "audit/"
// auditTableType is the value we expect to find for the audit table and
// corresponding entries
auditTableType = "audit"
) )
var ( var (
@ -146,6 +150,26 @@ func (c *Core) loadAudits() error {
// Done if we have restored the audit table // Done if we have restored the audit table
if c.audit != nil { if c.audit != nil {
needPersist := false
// Upgrade to typed auth table
if c.audit.Type == "" {
c.audit.Type = auditTableType
needPersist = true
}
// Upgrade to table-scoped entries
for _, entry := range c.audit.Entries {
if entry.Table == "" {
entry.Table = c.audit.Type
needPersist = true
}
}
if needPersist {
return c.persistAudit(c.audit)
}
return nil return nil
} }
@ -159,6 +183,25 @@ func (c *Core) loadAudits() error {
// persistAudit is used to persist the audit table after modification // persistAudit is used to persist the audit table after modification
func (c *Core) persistAudit(table *MountTable) error { func (c *Core) persistAudit(table *MountTable) error {
if table.Type != auditTableType {
c.logger.Printf(
"[ERR] core: given table to persist has type %s but need type %s",
table.Type,
auditTableType)
return fmt.Errorf("invalid table type given, not persisting")
}
for _, entry := range table.Entries {
if entry.Table != table.Type {
c.logger.Printf(
"[ERR] core: entry in audit table with path %s has table value %s but is in table %s, refusing to persist",
entry.Path,
entry.Table,
table.Type)
return fmt.Errorf("invalid audit entry found, not persisting")
}
}
// Marshal the table // Marshal the table
raw, err := json.Marshal(table) raw, err := json.Marshal(table)
if err != nil { if err != nil {
@ -240,7 +283,9 @@ func (c *Core) newAuditBackend(t string, view logical.Storage, conf map[string]s
// defaultAuditTable creates a default audit table // defaultAuditTable creates a default audit table
func defaultAuditTable() *MountTable { func defaultAuditTable() *MountTable {
table := &MountTable{} table := &MountTable{
Type: auditTableType,
}
return table return table
} }

View file

@ -56,6 +56,7 @@ func TestCore_EnableAudit(t *testing.T) {
} }
me := &MountEntry{ me := &MountEntry{
Table: auditTableType,
Path: "foo", Path: "foo",
Type: "noop", Type: "noop",
} }
@ -115,6 +116,7 @@ func TestCore_DisableAudit(t *testing.T) {
} }
me := &MountEntry{ me := &MountEntry{
Table: auditTableType,
Path: "foo", Path: "foo",
Type: "noop", Type: "noop",
} }
@ -196,6 +198,9 @@ func verifyDefaultAuditTable(t *testing.T, table *MountTable) {
if len(table.Entries) != 0 { if len(table.Entries) != 0 {
t.Fatalf("bad: %v", table.Entries) t.Fatalf("bad: %v", table.Entries)
} }
if table.Type != auditTableType {
t.Fatalf("bad: %v", *table)
}
} }
func TestAuditBroker_LogRequest(t *testing.T) { func TestAuditBroker_LogRequest(t *testing.T) {

View file

@ -22,6 +22,10 @@ const (
// credentialRoutePrefix is the mount prefix used for the router // credentialRoutePrefix is the mount prefix used for the router
credentialRoutePrefix = "auth/" credentialRoutePrefix = "auth/"
// credentialTableType is the value we expect to find for the credential
// table and corresponding entries
credentialTableType = "auth"
) )
var ( var (
@ -210,6 +214,26 @@ func (c *Core) loadCredentials() error {
// Done if we have restored the auth table // Done if we have restored the auth table
if c.auth != nil { if c.auth != nil {
needPersist := false
// Upgrade to typed auth table
if c.auth.Type == "" {
c.auth.Type = credentialTableType
needPersist = true
}
// Upgrade to table-scoped entries
for _, entry := range c.auth.Entries {
if entry.Table == "" {
entry.Table = c.auth.Type
needPersist = true
}
}
if needPersist {
return c.persistAuth(c.auth)
}
return nil return nil
} }
@ -224,6 +248,25 @@ func (c *Core) loadCredentials() error {
// persistAuth is used to persist the auth table after modification // persistAuth is used to persist the auth table after modification
func (c *Core) persistAuth(table *MountTable) error { func (c *Core) persistAuth(table *MountTable) error {
if table.Type != credentialTableType {
c.logger.Printf(
"[ERR] core: given table to persist has type %s but need type %s",
table.Type,
credentialTableType)
return fmt.Errorf("invalid table type given, not persisting")
}
for _, entry := range table.Entries {
if entry.Table != table.Type {
c.logger.Printf(
"[ERR] core: entry in auth table with path %s has table value %s but is in table %s, refusing to persist",
entry.Path,
entry.Table,
table.Type)
return fmt.Errorf("invalid auth entry found, not persisting")
}
}
// Marshal the table // Marshal the table
raw, err := json.Marshal(table) raw, err := json.Marshal(table)
if err != nil { if err != nil {
@ -341,12 +384,15 @@ func (c *Core) newCredentialBackend(
// defaultAuthTable creates a default auth table // defaultAuthTable creates a default auth table
func defaultAuthTable() *MountTable { func defaultAuthTable() *MountTable {
table := &MountTable{} table := &MountTable{
Type: credentialTableType,
}
tokenUUID, err := uuid.GenerateUUID() tokenUUID, err := uuid.GenerateUUID()
if err != nil { if err != nil {
panic(fmt.Sprintf("could not generate UUID for default auth table token entry: %v", err)) panic(fmt.Sprintf("could not generate UUID for default auth table token entry: %v", err))
} }
tokenAuth := &MountEntry{ tokenAuth := &MountEntry{
Table: credentialTableType,
Path: "token/", Path: "token/",
Type: "token", Type: "token",
Description: "token based credentials", Description: "token based credentials",

View file

@ -41,6 +41,7 @@ func TestCore_EnableCredential(t *testing.T) {
} }
me := &MountEntry{ me := &MountEntry{
Table: credentialTableType,
Path: "foo", Path: "foo",
Type: "noop", Type: "noop",
} }
@ -86,6 +87,7 @@ func TestCore_EnableCredential_twice_409(t *testing.T) {
} }
me := &MountEntry{ me := &MountEntry{
Table: credentialTableType,
Path: "foo", Path: "foo",
Type: "noop", Type: "noop",
} }
@ -109,6 +111,7 @@ func TestCore_EnableCredential_twice_409(t *testing.T) {
func TestCore_EnableCredential_Token(t *testing.T) { func TestCore_EnableCredential_Token(t *testing.T) {
c, _, _ := TestCoreUnsealed(t) c, _, _ := TestCoreUnsealed(t)
me := &MountEntry{ me := &MountEntry{
Table: credentialTableType,
Path: "foo", Path: "foo",
Type: "token", Type: "token",
} }
@ -130,6 +133,7 @@ func TestCore_DisableCredential(t *testing.T) {
} }
me := &MountEntry{ me := &MountEntry{
Table: credentialTableType,
Path: "foo", Path: "foo",
Type: "noop", Type: "noop",
} }
@ -188,6 +192,7 @@ func TestCore_DisableCredential_Cleanup(t *testing.T) {
} }
me := &MountEntry{ me := &MountEntry{
Table: credentialTableType,
Path: "foo", Path: "foo",
Type: "noop", Type: "noop",
} }
@ -260,6 +265,9 @@ func verifyDefaultAuthTable(t *testing.T, table *MountTable) {
if len(table.Entries) != 1 { if len(table.Entries) != 1 {
t.Fatalf("bad: %v", table.Entries) t.Fatalf("bad: %v", table.Entries)
} }
if table.Type != credentialTableType {
t.Fatalf("bad: %v", *table)
}
for idx, entry := range table.Entries { for idx, entry := range table.Entries {
switch idx { switch idx {
case 0: case 0:

View file

@ -965,6 +965,7 @@ func TestExpiration_RevokeForce(t *testing.T) {
core.logicalBackends["badrenew"] = badRenewFactory core.logicalBackends["badrenew"] = badRenewFactory
me := &MountEntry{ me := &MountEntry{
Table: mountTableType,
Path: "badrenew/", Path: "badrenew/",
Type: "badrenew", Type: "badrenew",
} }

View file

@ -720,6 +720,7 @@ func (b *SystemBackend) handleMount(
// Create the mount entry // Create the mount entry
me := &MountEntry{ me := &MountEntry{
Table: mountTableType,
Path: path, Path: path,
Type: logicalType, Type: logicalType,
Description: description, Description: description,
@ -1001,6 +1002,7 @@ func (b *SystemBackend) handleEnableAuth(
// Create the mount entry // Create the mount entry
me := &MountEntry{ me := &MountEntry{
Table: credentialTableType,
Path: path, Path: path,
Type: logicalType, Type: logicalType,
Description: description, Description: description,
@ -1169,6 +1171,7 @@ func (b *SystemBackend) handleEnableAudit(
// Create the mount entry // Create the mount entry
me := &MountEntry{ me := &MountEntry{
Table: auditTableType,
Path: path, Path: path,
Type: backendType, Type: backendType,
Description: description, Description: description,

View file

@ -25,6 +25,10 @@ const (
// systemBarrierPrefix is the prefix used for the // systemBarrierPrefix is the prefix used for the
// system logical backend. // system logical backend.
systemBarrierPrefix = "sys/" systemBarrierPrefix = "sys/"
// mountTableType is the value we expect to find for the mount table and
// corresponding entries
mountTableType = "mounts"
) )
var ( var (
@ -55,6 +59,7 @@ var (
// MountTable is used to represent the internal mount table // MountTable is used to represent the internal mount table
type MountTable struct { type MountTable struct {
Type string `json:"type"`
Entries []*MountEntry `json:"entries"` Entries []*MountEntry `json:"entries"`
} }
@ -64,6 +69,7 @@ type MountTable struct {
// if modifying entries rather than modifying the table itself // if modifying entries rather than modifying the table itself
func (t *MountTable) ShallowClone() *MountTable { func (t *MountTable) ShallowClone() *MountTable {
mt := &MountTable{ mt := &MountTable{
Type: t.Type,
Entries: make([]*MountEntry, len(t.Entries)), Entries: make([]*MountEntry, len(t.Entries)),
} }
for i, e := range t.Entries { for i, e := range t.Entries {
@ -120,6 +126,7 @@ func (t *MountTable) Remove(path string) bool {
// MountEntry is used to represent a mount table entry // MountEntry is used to represent a mount table entry
type MountEntry struct { type MountEntry struct {
Table string `json:"table"` // The table it belongs to
Path string `json:"path"` // Mount Path Path string `json:"path"` // Mount Path
Type string `json:"type"` // Logical backend Type Type string `json:"type"` // Logical backend Type
Description string `json:"description"` // User-provided description Description string `json:"description"` // User-provided description
@ -142,6 +149,7 @@ func (e *MountEntry) Clone() *MountEntry {
optClone[k] = v optClone[k] = v
} }
return &MountEntry{ return &MountEntry{
Table: e.Table,
Path: e.Path, Path: e.Path,
Type: e.Type, Type: e.Type,
Description: e.Description, Description: e.Description,
@ -409,6 +417,13 @@ func (c *Core) loadMounts() error {
// by type only. // by type only.
if c.mounts != nil { if c.mounts != nil {
needPersist := false needPersist := false
// Upgrade to typed mount table
if c.mounts.Type == "" {
c.mounts.Type = mountTableType
needPersist = true
}
for _, requiredMount := range requiredMountTable().Entries { for _, requiredMount := range requiredMountTable().Entries {
foundRequired := false foundRequired := false
for _, coreMount := range c.mounts.Entries { for _, coreMount := range c.mounts.Entries {
@ -423,6 +438,14 @@ func (c *Core) loadMounts() error {
} }
} }
// Upgrade to table-scoped entries
for _, entry := range c.mounts.Entries {
if entry.Table == "" {
entry.Table = c.mounts.Type
needPersist = true
}
}
// Done if we have restored the mount table and we don't need // Done if we have restored the mount table and we don't need
// to persist // to persist
if !needPersist { if !needPersist {
@ -441,6 +464,25 @@ func (c *Core) loadMounts() error {
// persistMounts is used to persist the mount table after modification // persistMounts is used to persist the mount table after modification
func (c *Core) persistMounts(table *MountTable) error { func (c *Core) persistMounts(table *MountTable) error {
if table.Type != mountTableType {
c.logger.Printf(
"[ERR] core: given table to persist has type %s but need type %s",
table.Type,
mountTableType)
return fmt.Errorf("invalid table type given, not persisting")
}
for _, entry := range table.Entries {
if entry.Table != table.Type {
c.logger.Printf(
"[ERR] core: entry in mount table with path %s has table value %s but is in table %s, refusing to persist",
entry.Path,
entry.Table,
table.Type)
return fmt.Errorf("invalid mount entry found, not persisting")
}
}
// Marshal the table // Marshal the table
raw, err := json.Marshal(table) raw, err := json.Marshal(table)
if err != nil { if err != nil {
@ -574,12 +616,15 @@ func (c *Core) mountEntrySysView(me *MountEntry) logical.SystemView {
// defaultMountTable creates a default mount table // defaultMountTable creates a default mount table
func defaultMountTable() *MountTable { func defaultMountTable() *MountTable {
table := &MountTable{} table := &MountTable{
Type: mountTableType,
}
mountUUID, err := uuid.GenerateUUID() mountUUID, err := uuid.GenerateUUID()
if err != nil { if err != nil {
panic(fmt.Sprintf("could not create default mount table UUID: %v", err)) panic(fmt.Sprintf("could not create default mount table UUID: %v", err))
} }
genericMount := &MountEntry{ genericMount := &MountEntry{
Table: mountTableType,
Path: "secret/", Path: "secret/",
Type: "generic", Type: "generic",
Description: "generic secret storage", Description: "generic secret storage",
@ -593,12 +638,15 @@ func defaultMountTable() *MountTable {
// requiredMountTable() creates a mount table with entries required // requiredMountTable() creates a mount table with entries required
// to be available // to be available
func requiredMountTable() *MountTable { func requiredMountTable() *MountTable {
table := &MountTable{} table := &MountTable{
Type: mountTableType,
}
cubbyholeUUID, err := uuid.GenerateUUID() cubbyholeUUID, err := uuid.GenerateUUID()
if err != nil { if err != nil {
panic(fmt.Sprintf("could not create cubbyhole UUID: %v", err)) panic(fmt.Sprintf("could not create cubbyhole UUID: %v", err))
} }
cubbyholeMount := &MountEntry{ cubbyholeMount := &MountEntry{
Table: mountTableType,
Path: "cubbyhole/", Path: "cubbyhole/",
Type: "cubbyhole", Type: "cubbyhole",
Description: "per-token private secret storage", Description: "per-token private secret storage",
@ -610,6 +658,7 @@ func requiredMountTable() *MountTable {
panic(fmt.Sprintf("could not create sys UUID: %v", err)) panic(fmt.Sprintf("could not create sys UUID: %v", err))
} }
sysMount := &MountEntry{ sysMount := &MountEntry{
Table: mountTableType,
Path: "sys/", Path: "sys/",
Type: "system", Type: "system",
Description: "system endpoints used for control, policy and debugging", Description: "system endpoints used for control, policy and debugging",

View file

@ -1,10 +1,12 @@
package vault package vault
import ( import (
"encoding/json"
"reflect" "reflect"
"testing" "testing"
"time" "time"
"github.com/hashicorp/vault/audit"
"github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical"
) )
@ -38,6 +40,7 @@ func TestCore_DefaultMountTable(t *testing.T) {
func TestCore_Mount(t *testing.T) { func TestCore_Mount(t *testing.T) {
c, key, _ := TestCoreUnsealed(t) c, key, _ := TestCoreUnsealed(t)
me := &MountEntry{ me := &MountEntry{
Table: mountTableType,
Path: "foo", Path: "foo",
Type: "generic", Type: "generic",
} }
@ -116,6 +119,7 @@ func TestCore_Unmount_Cleanup(t *testing.T) {
// Mount the noop backend // Mount the noop backend
me := &MountEntry{ me := &MountEntry{
Table: mountTableType,
Path: "test/", Path: "test/",
Type: "noop", Type: "noop",
} }
@ -233,6 +237,7 @@ func TestCore_Remount_Cleanup(t *testing.T) {
// Mount the noop backend // Mount the noop backend
me := &MountEntry{ me := &MountEntry{
Table: mountTableType,
Path: "test/", Path: "test/",
Type: "noop", Type: "noop",
} }
@ -320,6 +325,143 @@ func TestDefaultMountTable(t *testing.T) {
verifyDefaultTable(t, table) verifyDefaultTable(t, table)
} }
func TestCore_MountTable_UpgradeToTyped(t *testing.T) {
c, _, _ := TestCoreUnsealed(t)
c.auditBackends["noop"] = func(config *audit.BackendConfig) (audit.Backend, error) {
return &NoopAudit{
Config: config,
}, nil
}
me := &MountEntry{
Table: auditTableType,
Path: "foo",
Type: "noop",
}
err := c.enableAudit(me)
if err != nil {
t.Fatalf("err: %v", err)
}
c.credentialBackends["noop"] = func(*logical.BackendConfig) (logical.Backend, error) {
return &NoopBackend{}, nil
}
me = &MountEntry{
Table: credentialTableType,
Path: "foo",
Type: "noop",
}
err = c.enableCredential(me)
if err != nil {
t.Fatalf("err: %v", err)
}
testCore_MountTable_UpgradeToTyped_Common(t, c, "mounts")
testCore_MountTable_UpgradeToTyped_Common(t, c, "audits")
testCore_MountTable_UpgradeToTyped_Common(t, c, "credentials")
}
func testCore_MountTable_UpgradeToTyped_Common(
t *testing.T,
c *Core,
testType string) {
var path string
var mt *MountTable
switch testType {
case "mounts":
path = coreMountConfigPath
mt = c.mounts
case "audits":
path = coreAuditConfigPath
mt = c.audit
case "credentials":
path = coreAuthConfigPath
mt = c.auth
}
// Save the expected table
goodJson, err := json.Marshal(mt)
if err != nil {
t.Fatal(err)
}
// Create a pre-typed version
mt.Type = ""
for _, entry := range mt.Entries {
entry.Table = ""
}
raw, err := json.Marshal(mt)
if err != nil {
t.Fatal(err)
}
if reflect.DeepEqual(raw, goodJson) {
t.Fatalf("bad: values here should be different")
}
entry := &Entry{
Key: path,
Value: raw,
}
if err := c.barrier.Put(entry); err != nil {
t.Fatal(err)
}
var persistFunc func(*MountTable) error
// It should load successfully and be upgraded and persisted
switch testType {
case "mounts":
err = c.loadMounts()
persistFunc = c.persistMounts
mt = c.mounts
case "credentials":
err = c.loadCredentials()
persistFunc = c.persistAuth
mt = c.auth
case "audits":
err = c.loadAudits()
persistFunc = c.persistAudit
mt = c.audit
}
if err != nil {
t.Fatal(err)
}
entry, err = c.barrier.Get(path)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(entry.Value, goodJson) {
t.Fatalf("bad: expected\n%s\ngot\n%s\n", string(goodJson), string(entry.Value))
}
// Now try saving invalid versions
origTableType := mt.Type
mt.Type = "foo"
if err := persistFunc(mt); err == nil {
t.Fatal("expected error")
}
if len(mt.Entries) > 0 {
mt.Type = origTableType
mt.Entries[0].Table = "bar"
if err := persistFunc(mt); err == nil {
t.Fatal("expected error")
}
mt.Entries[0].Table = mt.Type
if err := persistFunc(mt); err != nil {
t.Fatal(err)
}
}
}
func verifyDefaultTable(t *testing.T, table *MountTable) { func verifyDefaultTable(t *testing.T, table *MountTable) {
if len(table.Entries) != 3 { if len(table.Entries) != 3 {
t.Fatalf("bad: %v", table.Entries) t.Fatalf("bad: %v", table.Entries)
@ -348,6 +490,9 @@ func verifyDefaultTable(t *testing.T, table *MountTable) {
t.Fatalf("bad: %v", entry) t.Fatalf("bad: %v", entry)
} }
} }
if entry.Table != mountTableType {
t.Fatalf("bad: %v", entry)
}
if entry.Description == "" { if entry.Description == "" {
t.Fatalf("bad: %v", entry) t.Fatalf("bad: %v", entry)
} }

View file

@ -15,6 +15,7 @@ func TestRequestHandling_Wrapping(t *testing.T) {
meUUID, _ := uuid.GenerateUUID() meUUID, _ := uuid.GenerateUUID()
err := core.mount(&MountEntry{ err := core.mount(&MountEntry{
Table: mountTableType,
UUID: meUUID, UUID: meUUID,
Path: "wraptest", Path: "wraptest",
Type: "generic", Type: "generic",

View file

@ -167,6 +167,7 @@ func TestCoreWithTokenStore(t *testing.T) (*Core, *TokenStore, []byte, string) {
c, key, root := TestCoreUnsealed(t) c, key, root := TestCoreUnsealed(t)
me := &MountEntry{ me := &MountEntry{
Table: credentialTableType,
Path: "token/", Path: "token/",
Type: "token", Type: "token",
Description: "token based credentials", Description: "token based credentials",
@ -184,7 +185,7 @@ func TestCoreWithTokenStore(t *testing.T) (*Core, *TokenStore, []byte, string) {
ts := tokenstore.(*TokenStore) ts := tokenstore.(*TokenStore)
router := NewRouter() router := NewRouter()
router.Mount(ts, "auth/token/", &MountEntry{UUID: ""}, ts.view) router.Mount(ts, "auth/token/", &MountEntry{Table: credentialTableType, UUID: ""}, ts.view)
subview := c.systemBarrierView.SubView(expirationSubPath) subview := c.systemBarrierView.SubView(expirationSubPath)
logger := log.New(os.Stderr, "", log.LstdFlags) logger := log.New(os.Stderr, "", log.LstdFlags)