Port some ent mount changes (#4330)

This commit is contained in:
Brian Kassouf 2018-04-11 11:32:55 -07:00 committed by Jeff Mitchell
parent 28f6d65032
commit 6ca3ae4007
6 changed files with 92 additions and 75 deletions

View File

@ -134,7 +134,7 @@ func (c *Core) enableCredential(ctx context.Context, entry *MountEntry) error {
// Update the auth table
newTable := c.auth.shallowClone()
newTable.Entries = append(newTable.Entries, entry)
if err := c.persistAuth(ctx, newTable, entry.Local); err != nil {
if err := c.persistAuth(ctx, newTable, &entry.Local); err != nil {
return errors.New("failed to update auth table")
}
@ -235,7 +235,7 @@ func (c *Core) removeCredEntry(ctx context.Context, path string) error {
}
// Update the auth table
if err := c.persistAuth(ctx, newTable, entry.Local); err != nil {
if err := c.persistAuth(ctx, newTable, &entry.Local); err != nil {
return errors.New("failed to update auth table")
}
@ -281,7 +281,7 @@ func (c *Core) taintCredEntry(ctx context.Context, path string) error {
}
// Update the auth table
if err := c.persistAuth(ctx, c.auth, entry.Local); err != nil {
if err := c.persistAuth(ctx, c.auth, &entry.Local); err != nil {
return errors.New("failed to update auth table")
}
@ -369,7 +369,7 @@ func (c *Core) loadCredentials(ctx context.Context) error {
return nil
}
if err := c.persistAuth(ctx, c.auth, false); err != nil {
if err := c.persistAuth(ctx, c.auth, nil); err != nil {
c.logger.Error("failed to persist auth table", "error", err)
return errLoadAuthFailed
}
@ -377,7 +377,7 @@ func (c *Core) loadCredentials(ctx context.Context) error {
}
// persistAuth is used to persist the auth table after modification
func (c *Core) persistAuth(ctx context.Context, table *MountTable, localOnly bool) error {
func (c *Core) persistAuth(ctx context.Context, table *MountTable, local *bool) error {
if table.Type != credentialTableType {
c.logger.Error("given table to persist has wrong type", "actual_type", table.Type, "expected_type", credentialTableType)
return fmt.Errorf("invalid table type given, not persisting")
@ -406,45 +406,49 @@ func (c *Core) persistAuth(ctx context.Context, table *MountTable, localOnly boo
}
}
if !localOnly {
// Marshal the table
compressedBytes, err := jsonutil.EncodeJSONAndCompress(nonLocalAuth, nil)
writeTable := func(mt *MountTable, path string) error {
// Encode the mount table into JSON and compress it (lzw).
compressedBytes, err := jsonutil.EncodeJSONAndCompress(mt, nil)
if err != nil {
c.logger.Error("failed to encode and/or compress auth table", "error", err)
c.logger.Error("failed to encode or compress auth mount table", "error", err)
return err
}
// Create an entry
entry := &Entry{
Key: coreAuthConfigPath,
Key: path,
Value: compressedBytes,
}
// Write to the physical backend
if err := c.barrier.Put(ctx, entry); err != nil {
c.logger.Error("failed to persist auth table", "error", err)
c.logger.Error("failed to persist auth mount table", "error", err)
return err
}
return nil
}
// Repeat with local auth
compressedBytes, err := jsonutil.EncodeJSONAndCompress(localAuth, nil)
if err != nil {
c.logger.Error("failed to encode and/or compress local auth table", "error", err)
return err
var err error
switch {
case local == nil:
// Write non-local mounts
err := writeTable(nonLocalAuth, coreAuthConfigPath)
if err != nil {
return err
}
// Write local mounts
err = writeTable(localAuth, coreLocalAuthConfigPath)
if err != nil {
return err
}
case *local:
err = writeTable(localAuth, coreLocalAuthConfigPath)
default:
err = writeTable(nonLocalAuth, coreAuthConfigPath)
}
entry := &Entry{
Key: coreLocalAuthConfigPath,
Value: compressedBytes,
}
if err := c.barrier.Put(ctx, entry); err != nil {
c.logger.Error("failed to persist local auth table", "error", err)
return err
}
return nil
return err
}
// setupCredentials is invoked after we've loaded the auth table to
@ -520,7 +524,7 @@ func (c *Core) setupCredentials(ctx context.Context) error {
}
if persistNeeded {
return c.persistAuth(ctx, c.auth, false)
return c.persistAuth(ctx, c.auth, nil)
}
return nil

View File

@ -164,7 +164,7 @@ func TestCore_EnableCredential_Local(t *testing.T) {
}
c.auth.Entries[1].Local = true
if err := c.persistAuth(context.Background(), c.auth, false); err != nil {
if err := c.persistAuth(context.Background(), c.auth, nil); err != nil {
t.Fatal(err)
}

View File

@ -1988,9 +1988,9 @@ func (b *SystemBackend) handleTuneWriteCommon(ctx context.Context, path string,
var err error
switch {
case strings.HasPrefix(path, "auth/"):
err = b.Core.persistAuth(ctx, b.Core.auth, mountEntry.Local)
err = b.Core.persistAuth(ctx, b.Core.auth, &mountEntry.Local)
default:
err = b.Core.persistMounts(ctx, b.Core.mounts, mountEntry.Local)
err = b.Core.persistMounts(ctx, b.Core.mounts, &mountEntry.Local)
}
if err != nil {
mountEntry.Description = oldDesc
@ -2011,9 +2011,9 @@ func (b *SystemBackend) handleTuneWriteCommon(ctx context.Context, path string,
var err error
switch {
case strings.HasPrefix(path, "auth/"):
err = b.Core.persistAuth(ctx, b.Core.auth, mountEntry.Local)
err = b.Core.persistAuth(ctx, b.Core.auth, &mountEntry.Local)
default:
err = b.Core.persistMounts(ctx, b.Core.mounts, mountEntry.Local)
err = b.Core.persistMounts(ctx, b.Core.mounts, &mountEntry.Local)
}
if err != nil {
mountEntry.Config.AuditNonHMACRequestKeys = oldVal
@ -2037,9 +2037,9 @@ func (b *SystemBackend) handleTuneWriteCommon(ctx context.Context, path string,
var err error
switch {
case strings.HasPrefix(path, "auth/"):
err = b.Core.persistAuth(ctx, b.Core.auth, mountEntry.Local)
err = b.Core.persistAuth(ctx, b.Core.auth, &mountEntry.Local)
default:
err = b.Core.persistMounts(ctx, b.Core.mounts, mountEntry.Local)
err = b.Core.persistMounts(ctx, b.Core.mounts, &mountEntry.Local)
}
if err != nil {
mountEntry.Config.AuditNonHMACResponseKeys = oldVal
@ -2068,9 +2068,9 @@ func (b *SystemBackend) handleTuneWriteCommon(ctx context.Context, path string,
var err error
switch {
case strings.HasPrefix(path, "auth/"):
err = b.Core.persistAuth(ctx, b.Core.auth, mountEntry.Local)
err = b.Core.persistAuth(ctx, b.Core.auth, &mountEntry.Local)
default:
err = b.Core.persistMounts(ctx, b.Core.mounts, mountEntry.Local)
err = b.Core.persistMounts(ctx, b.Core.mounts, &mountEntry.Local)
}
if err != nil {
mountEntry.Config.ListingVisibility = oldVal
@ -2092,9 +2092,9 @@ func (b *SystemBackend) handleTuneWriteCommon(ctx context.Context, path string,
var err error
switch {
case strings.HasPrefix(path, "auth/"):
err = b.Core.persistAuth(ctx, b.Core.auth, mountEntry.Local)
err = b.Core.persistAuth(ctx, b.Core.auth, &mountEntry.Local)
default:
err = b.Core.persistMounts(ctx, b.Core.mounts, mountEntry.Local)
err = b.Core.persistMounts(ctx, b.Core.mounts, &mountEntry.Local)
}
if err != nil {
mountEntry.Config.PassthroughRequestHeaders = oldVal
@ -2154,9 +2154,9 @@ func (b *SystemBackend) handleTuneWriteCommon(ctx context.Context, path string,
// Update the mount table
switch {
case strings.HasPrefix(path, "auth/"):
err = b.Core.persistAuth(ctx, b.Core.auth, mountEntry.Local)
err = b.Core.persistAuth(ctx, b.Core.auth, &mountEntry.Local)
default:
err = b.Core.persistMounts(ctx, b.Core.mounts, mountEntry.Local)
err = b.Core.persistMounts(ctx, b.Core.mounts, &mountEntry.Local)
}
if err != nil {
mountEntry.Options = oldVal

View File

@ -37,9 +37,9 @@ func (b *SystemBackend) tuneMountTTLs(ctx context.Context, path string, me *Moun
var err error
switch {
case strings.HasPrefix(path, credentialRoutePrefix):
err = b.Core.persistAuth(ctx, b.Core.auth, me.Local)
err = b.Core.persistAuth(ctx, b.Core.auth, &me.Local)
default:
err = b.Core.persistMounts(ctx, b.Core.mounts, me.Local)
err = b.Core.persistMounts(ctx, b.Core.mounts, &me.Local)
}
if err != nil {
me.Config.MaxLeaseTTL = origMax

View File

@ -336,7 +336,7 @@ func (c *Core) mountInternal(ctx context.Context, entry *MountEntry) error {
newTable := c.mounts.shallowClone()
newTable.Entries = append(newTable.Entries, entry)
if err := c.persistMounts(ctx, newTable, entry.Local); err != nil {
if err := c.persistMounts(ctx, newTable, &entry.Local); err != nil {
c.logger.Error("failed to update mount table", "error", err)
return logical.CodedError(500, "failed to update mount table")
}
@ -457,7 +457,7 @@ func (c *Core) removeMountEntry(ctx context.Context, path string) error {
}
// Update the mount table
if err := c.persistMounts(ctx, newTable, entry.Local); err != nil {
if err := c.persistMounts(ctx, newTable, &entry.Local); err != nil {
c.logger.Error("failed to remove entry from mounts table", "error", err)
return logical.CodedError(500, "failed to remove entry from mounts table")
}
@ -480,7 +480,7 @@ func (c *Core) taintMountEntry(ctx context.Context, path string) error {
}
// Update the mount table
if err := c.persistMounts(ctx, c.mounts, entry.Local); err != nil {
if err := c.persistMounts(ctx, c.mounts, &entry.Local); err != nil {
c.logger.Error("failed to taint entry in mounts table", "error", err)
return logical.CodedError(500, "failed to taint entry in mounts table")
}
@ -571,7 +571,7 @@ func (c *Core) remount(ctx context.Context, src, dst string) error {
}
// Update the mount table
if err := c.persistMounts(ctx, c.mounts, entry.Local); err != nil {
if err := c.persistMounts(ctx, c.mounts, &entry.Local); err != nil {
entry.Path = src
entry.Tainted = true
c.mountsLock.Unlock()
@ -710,7 +710,7 @@ func (c *Core) loadMounts(ctx context.Context) error {
return nil
}
if err := c.persistMounts(ctx, c.mounts, false); err != nil {
if err := c.persistMounts(ctx, c.mounts, nil); err != nil {
c.logger.Error("failed to persist mount table", "error", err)
return errLoadMountsFailed
}
@ -718,7 +718,7 @@ func (c *Core) loadMounts(ctx context.Context) error {
}
// persistMounts is used to persist the mount table after modification
func (c *Core) persistMounts(ctx context.Context, table *MountTable, localOnly bool) error {
func (c *Core) persistMounts(ctx context.Context, table *MountTable, local *bool) error {
if table.Type != mountTableType {
c.logger.Error("given table to persist has wrong type", "actual_type", table.Type, "expected_type", mountTableType)
return fmt.Errorf("invalid table type given, not persisting")
@ -747,17 +747,17 @@ func (c *Core) persistMounts(ctx context.Context, table *MountTable, localOnly b
}
}
if !localOnly {
writeTable := func(mt *MountTable, path string) error {
// Encode the mount table into JSON and compress it (lzw).
compressedBytes, err := jsonutil.EncodeJSONAndCompress(nonLocalMounts, nil)
compressedBytes, err := jsonutil.EncodeJSONAndCompress(mt, nil)
if err != nil {
c.logger.Error("failed to encode and/or compress the mount table", "error", err)
c.logger.Error("failed to encode or compress mount table", "error", err)
return err
}
// Create an entry
entry := &Entry{
Key: coreMountConfigPath,
Key: path,
Value: compressedBytes,
}
@ -766,26 +766,33 @@ func (c *Core) persistMounts(ctx context.Context, table *MountTable, localOnly b
c.logger.Error("failed to persist mount table", "error", err)
return err
}
return nil
}
// Repeat with local mounts
compressedBytes, err := jsonutil.EncodeJSONAndCompress(localMounts, nil)
if err != nil {
c.logger.Error("failed to encode and/or compress the local mount table", "error", err)
return err
var err error
switch {
case local == nil:
// Write non-local mounts
err := writeTable(nonLocalMounts, coreMountConfigPath)
if err != nil {
return err
}
// Write local mounts
err = writeTable(localMounts, coreLocalMountConfigPath)
if err != nil {
return err
}
case *local:
// Write local mounts
err = writeTable(localMounts, coreLocalMountConfigPath)
default:
// Write non-local mounts
err = writeTable(nonLocalMounts, coreMountConfigPath)
}
entry := &Entry{
Key: coreLocalMountConfigPath,
Value: compressedBytes,
}
if err := c.barrier.Put(ctx, entry); err != nil {
c.logger.Error("failed to persist local mount table", "error", err)
return err
}
return nil
return err
}
// setupMounts is invoked after we've loaded the mount table to

View File

@ -161,7 +161,7 @@ func TestCore_Mount_Local(t *testing.T) {
}
c.mounts.Entries[1].Local = true
if err := c.persistMounts(context.Background(), c.mounts, false); err != nil {
if err := c.persistMounts(context.Background(), c.mounts, nil); err != nil {
t.Fatal(err)
}
@ -557,7 +557,7 @@ func testCore_MountTable_UpgradeToTyped_Common(
t.Fatal(err)
}
var persistFunc func(context.Context, *MountTable, bool) error
var persistFunc func(context.Context, *MountTable, *bool) error
// It should load successfully and be upgraded and persisted
switch testType {
@ -571,7 +571,13 @@ func testCore_MountTable_UpgradeToTyped_Common(
mt = c.auth
case "audits":
err = c.loadAudits(context.Background())
persistFunc = c.persistAudit
persistFunc = func(ctx context.Context, mt *MountTable, b *bool) error {
if b == nil {
b = new(bool)
*b = false
}
return c.persistAudit(ctx, mt, *b)
}
mt = c.audit
}
if err != nil {
@ -600,19 +606,19 @@ func testCore_MountTable_UpgradeToTyped_Common(
// Now try saving invalid versions
origTableType := mt.Type
mt.Type = "foo"
if err := persistFunc(context.Background(), mt, false); err == nil {
if err := persistFunc(context.Background(), mt, nil); err == nil {
t.Fatal("expected error")
}
if len(mt.Entries) > 0 {
mt.Type = origTableType
mt.Entries[0].Table = "bar"
if err := persistFunc(context.Background(), mt, false); err == nil {
if err := persistFunc(context.Background(), mt, nil); err == nil {
t.Fatal("expected error")
}
mt.Entries[0].Table = mt.Type
if err := persistFunc(context.Background(), mt, false); err != nil {
if err := persistFunc(context.Background(), mt, nil); err != nil {
t.Fatal(err)
}
}