Port some ent mount changes (#4330)
This commit is contained in:
parent
28f6d65032
commit
6ca3ae4007
|
@ -134,7 +134,7 @@ func (c *Core) enableCredential(ctx context.Context, entry *MountEntry) error {
|
|||
// Update the auth table
|
||||
newTable := c.auth.shallowClone()
|
||||
newTable.Entries = append(newTable.Entries, entry)
|
||||
if err := c.persistAuth(ctx, newTable, entry.Local); err != nil {
|
||||
if err := c.persistAuth(ctx, newTable, &entry.Local); err != nil {
|
||||
return errors.New("failed to update auth table")
|
||||
}
|
||||
|
||||
|
@ -235,7 +235,7 @@ func (c *Core) removeCredEntry(ctx context.Context, path string) error {
|
|||
}
|
||||
|
||||
// Update the auth table
|
||||
if err := c.persistAuth(ctx, newTable, entry.Local); err != nil {
|
||||
if err := c.persistAuth(ctx, newTable, &entry.Local); err != nil {
|
||||
return errors.New("failed to update auth table")
|
||||
}
|
||||
|
||||
|
@ -281,7 +281,7 @@ func (c *Core) taintCredEntry(ctx context.Context, path string) error {
|
|||
}
|
||||
|
||||
// Update the auth table
|
||||
if err := c.persistAuth(ctx, c.auth, entry.Local); err != nil {
|
||||
if err := c.persistAuth(ctx, c.auth, &entry.Local); err != nil {
|
||||
return errors.New("failed to update auth table")
|
||||
}
|
||||
|
||||
|
@ -369,7 +369,7 @@ func (c *Core) loadCredentials(ctx context.Context) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
if err := c.persistAuth(ctx, c.auth, false); err != nil {
|
||||
if err := c.persistAuth(ctx, c.auth, nil); err != nil {
|
||||
c.logger.Error("failed to persist auth table", "error", err)
|
||||
return errLoadAuthFailed
|
||||
}
|
||||
|
@ -377,7 +377,7 @@ func (c *Core) loadCredentials(ctx context.Context) error {
|
|||
}
|
||||
|
||||
// persistAuth is used to persist the auth table after modification
|
||||
func (c *Core) persistAuth(ctx context.Context, table *MountTable, localOnly bool) error {
|
||||
func (c *Core) persistAuth(ctx context.Context, table *MountTable, local *bool) error {
|
||||
if table.Type != credentialTableType {
|
||||
c.logger.Error("given table to persist has wrong type", "actual_type", table.Type, "expected_type", credentialTableType)
|
||||
return fmt.Errorf("invalid table type given, not persisting")
|
||||
|
@ -406,45 +406,49 @@ func (c *Core) persistAuth(ctx context.Context, table *MountTable, localOnly boo
|
|||
}
|
||||
}
|
||||
|
||||
if !localOnly {
|
||||
// Marshal the table
|
||||
compressedBytes, err := jsonutil.EncodeJSONAndCompress(nonLocalAuth, nil)
|
||||
writeTable := func(mt *MountTable, path string) error {
|
||||
// Encode the mount table into JSON and compress it (lzw).
|
||||
compressedBytes, err := jsonutil.EncodeJSONAndCompress(mt, nil)
|
||||
if err != nil {
|
||||
c.logger.Error("failed to encode and/or compress auth table", "error", err)
|
||||
c.logger.Error("failed to encode or compress auth mount table", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// Create an entry
|
||||
entry := &Entry{
|
||||
Key: coreAuthConfigPath,
|
||||
Key: path,
|
||||
Value: compressedBytes,
|
||||
}
|
||||
|
||||
// Write to the physical backend
|
||||
if err := c.barrier.Put(ctx, entry); err != nil {
|
||||
c.logger.Error("failed to persist auth table", "error", err)
|
||||
c.logger.Error("failed to persist auth mount table", "error", err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Repeat with local auth
|
||||
compressedBytes, err := jsonutil.EncodeJSONAndCompress(localAuth, nil)
|
||||
if err != nil {
|
||||
c.logger.Error("failed to encode and/or compress local auth table", "error", err)
|
||||
return err
|
||||
var err error
|
||||
switch {
|
||||
case local == nil:
|
||||
// Write non-local mounts
|
||||
err := writeTable(nonLocalAuth, coreAuthConfigPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Write local mounts
|
||||
err = writeTable(localAuth, coreLocalAuthConfigPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
case *local:
|
||||
err = writeTable(localAuth, coreLocalAuthConfigPath)
|
||||
default:
|
||||
err = writeTable(nonLocalAuth, coreAuthConfigPath)
|
||||
}
|
||||
|
||||
entry := &Entry{
|
||||
Key: coreLocalAuthConfigPath,
|
||||
Value: compressedBytes,
|
||||
}
|
||||
|
||||
if err := c.barrier.Put(ctx, entry); err != nil {
|
||||
c.logger.Error("failed to persist local auth table", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
return err
|
||||
}
|
||||
|
||||
// setupCredentials is invoked after we've loaded the auth table to
|
||||
|
@ -520,7 +524,7 @@ func (c *Core) setupCredentials(ctx context.Context) error {
|
|||
}
|
||||
|
||||
if persistNeeded {
|
||||
return c.persistAuth(ctx, c.auth, false)
|
||||
return c.persistAuth(ctx, c.auth, nil)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
|
@ -164,7 +164,7 @@ func TestCore_EnableCredential_Local(t *testing.T) {
|
|||
}
|
||||
|
||||
c.auth.Entries[1].Local = true
|
||||
if err := c.persistAuth(context.Background(), c.auth, false); err != nil {
|
||||
if err := c.persistAuth(context.Background(), c.auth, nil); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
|
|
|
@ -1988,9 +1988,9 @@ func (b *SystemBackend) handleTuneWriteCommon(ctx context.Context, path string,
|
|||
var err error
|
||||
switch {
|
||||
case strings.HasPrefix(path, "auth/"):
|
||||
err = b.Core.persistAuth(ctx, b.Core.auth, mountEntry.Local)
|
||||
err = b.Core.persistAuth(ctx, b.Core.auth, &mountEntry.Local)
|
||||
default:
|
||||
err = b.Core.persistMounts(ctx, b.Core.mounts, mountEntry.Local)
|
||||
err = b.Core.persistMounts(ctx, b.Core.mounts, &mountEntry.Local)
|
||||
}
|
||||
if err != nil {
|
||||
mountEntry.Description = oldDesc
|
||||
|
@ -2011,9 +2011,9 @@ func (b *SystemBackend) handleTuneWriteCommon(ctx context.Context, path string,
|
|||
var err error
|
||||
switch {
|
||||
case strings.HasPrefix(path, "auth/"):
|
||||
err = b.Core.persistAuth(ctx, b.Core.auth, mountEntry.Local)
|
||||
err = b.Core.persistAuth(ctx, b.Core.auth, &mountEntry.Local)
|
||||
default:
|
||||
err = b.Core.persistMounts(ctx, b.Core.mounts, mountEntry.Local)
|
||||
err = b.Core.persistMounts(ctx, b.Core.mounts, &mountEntry.Local)
|
||||
}
|
||||
if err != nil {
|
||||
mountEntry.Config.AuditNonHMACRequestKeys = oldVal
|
||||
|
@ -2037,9 +2037,9 @@ func (b *SystemBackend) handleTuneWriteCommon(ctx context.Context, path string,
|
|||
var err error
|
||||
switch {
|
||||
case strings.HasPrefix(path, "auth/"):
|
||||
err = b.Core.persistAuth(ctx, b.Core.auth, mountEntry.Local)
|
||||
err = b.Core.persistAuth(ctx, b.Core.auth, &mountEntry.Local)
|
||||
default:
|
||||
err = b.Core.persistMounts(ctx, b.Core.mounts, mountEntry.Local)
|
||||
err = b.Core.persistMounts(ctx, b.Core.mounts, &mountEntry.Local)
|
||||
}
|
||||
if err != nil {
|
||||
mountEntry.Config.AuditNonHMACResponseKeys = oldVal
|
||||
|
@ -2068,9 +2068,9 @@ func (b *SystemBackend) handleTuneWriteCommon(ctx context.Context, path string,
|
|||
var err error
|
||||
switch {
|
||||
case strings.HasPrefix(path, "auth/"):
|
||||
err = b.Core.persistAuth(ctx, b.Core.auth, mountEntry.Local)
|
||||
err = b.Core.persistAuth(ctx, b.Core.auth, &mountEntry.Local)
|
||||
default:
|
||||
err = b.Core.persistMounts(ctx, b.Core.mounts, mountEntry.Local)
|
||||
err = b.Core.persistMounts(ctx, b.Core.mounts, &mountEntry.Local)
|
||||
}
|
||||
if err != nil {
|
||||
mountEntry.Config.ListingVisibility = oldVal
|
||||
|
@ -2092,9 +2092,9 @@ func (b *SystemBackend) handleTuneWriteCommon(ctx context.Context, path string,
|
|||
var err error
|
||||
switch {
|
||||
case strings.HasPrefix(path, "auth/"):
|
||||
err = b.Core.persistAuth(ctx, b.Core.auth, mountEntry.Local)
|
||||
err = b.Core.persistAuth(ctx, b.Core.auth, &mountEntry.Local)
|
||||
default:
|
||||
err = b.Core.persistMounts(ctx, b.Core.mounts, mountEntry.Local)
|
||||
err = b.Core.persistMounts(ctx, b.Core.mounts, &mountEntry.Local)
|
||||
}
|
||||
if err != nil {
|
||||
mountEntry.Config.PassthroughRequestHeaders = oldVal
|
||||
|
@ -2154,9 +2154,9 @@ func (b *SystemBackend) handleTuneWriteCommon(ctx context.Context, path string,
|
|||
// Update the mount table
|
||||
switch {
|
||||
case strings.HasPrefix(path, "auth/"):
|
||||
err = b.Core.persistAuth(ctx, b.Core.auth, mountEntry.Local)
|
||||
err = b.Core.persistAuth(ctx, b.Core.auth, &mountEntry.Local)
|
||||
default:
|
||||
err = b.Core.persistMounts(ctx, b.Core.mounts, mountEntry.Local)
|
||||
err = b.Core.persistMounts(ctx, b.Core.mounts, &mountEntry.Local)
|
||||
}
|
||||
if err != nil {
|
||||
mountEntry.Options = oldVal
|
||||
|
|
|
@ -37,9 +37,9 @@ func (b *SystemBackend) tuneMountTTLs(ctx context.Context, path string, me *Moun
|
|||
var err error
|
||||
switch {
|
||||
case strings.HasPrefix(path, credentialRoutePrefix):
|
||||
err = b.Core.persistAuth(ctx, b.Core.auth, me.Local)
|
||||
err = b.Core.persistAuth(ctx, b.Core.auth, &me.Local)
|
||||
default:
|
||||
err = b.Core.persistMounts(ctx, b.Core.mounts, me.Local)
|
||||
err = b.Core.persistMounts(ctx, b.Core.mounts, &me.Local)
|
||||
}
|
||||
if err != nil {
|
||||
me.Config.MaxLeaseTTL = origMax
|
||||
|
|
|
@ -336,7 +336,7 @@ func (c *Core) mountInternal(ctx context.Context, entry *MountEntry) error {
|
|||
|
||||
newTable := c.mounts.shallowClone()
|
||||
newTable.Entries = append(newTable.Entries, entry)
|
||||
if err := c.persistMounts(ctx, newTable, entry.Local); err != nil {
|
||||
if err := c.persistMounts(ctx, newTable, &entry.Local); err != nil {
|
||||
c.logger.Error("failed to update mount table", "error", err)
|
||||
return logical.CodedError(500, "failed to update mount table")
|
||||
}
|
||||
|
@ -457,7 +457,7 @@ func (c *Core) removeMountEntry(ctx context.Context, path string) error {
|
|||
}
|
||||
|
||||
// Update the mount table
|
||||
if err := c.persistMounts(ctx, newTable, entry.Local); err != nil {
|
||||
if err := c.persistMounts(ctx, newTable, &entry.Local); err != nil {
|
||||
c.logger.Error("failed to remove entry from mounts table", "error", err)
|
||||
return logical.CodedError(500, "failed to remove entry from mounts table")
|
||||
}
|
||||
|
@ -480,7 +480,7 @@ func (c *Core) taintMountEntry(ctx context.Context, path string) error {
|
|||
}
|
||||
|
||||
// Update the mount table
|
||||
if err := c.persistMounts(ctx, c.mounts, entry.Local); err != nil {
|
||||
if err := c.persistMounts(ctx, c.mounts, &entry.Local); err != nil {
|
||||
c.logger.Error("failed to taint entry in mounts table", "error", err)
|
||||
return logical.CodedError(500, "failed to taint entry in mounts table")
|
||||
}
|
||||
|
@ -571,7 +571,7 @@ func (c *Core) remount(ctx context.Context, src, dst string) error {
|
|||
}
|
||||
|
||||
// Update the mount table
|
||||
if err := c.persistMounts(ctx, c.mounts, entry.Local); err != nil {
|
||||
if err := c.persistMounts(ctx, c.mounts, &entry.Local); err != nil {
|
||||
entry.Path = src
|
||||
entry.Tainted = true
|
||||
c.mountsLock.Unlock()
|
||||
|
@ -710,7 +710,7 @@ func (c *Core) loadMounts(ctx context.Context) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
if err := c.persistMounts(ctx, c.mounts, false); err != nil {
|
||||
if err := c.persistMounts(ctx, c.mounts, nil); err != nil {
|
||||
c.logger.Error("failed to persist mount table", "error", err)
|
||||
return errLoadMountsFailed
|
||||
}
|
||||
|
@ -718,7 +718,7 @@ func (c *Core) loadMounts(ctx context.Context) error {
|
|||
}
|
||||
|
||||
// persistMounts is used to persist the mount table after modification
|
||||
func (c *Core) persistMounts(ctx context.Context, table *MountTable, localOnly bool) error {
|
||||
func (c *Core) persistMounts(ctx context.Context, table *MountTable, local *bool) error {
|
||||
if table.Type != mountTableType {
|
||||
c.logger.Error("given table to persist has wrong type", "actual_type", table.Type, "expected_type", mountTableType)
|
||||
return fmt.Errorf("invalid table type given, not persisting")
|
||||
|
@ -747,17 +747,17 @@ func (c *Core) persistMounts(ctx context.Context, table *MountTable, localOnly b
|
|||
}
|
||||
}
|
||||
|
||||
if !localOnly {
|
||||
writeTable := func(mt *MountTable, path string) error {
|
||||
// Encode the mount table into JSON and compress it (lzw).
|
||||
compressedBytes, err := jsonutil.EncodeJSONAndCompress(nonLocalMounts, nil)
|
||||
compressedBytes, err := jsonutil.EncodeJSONAndCompress(mt, nil)
|
||||
if err != nil {
|
||||
c.logger.Error("failed to encode and/or compress the mount table", "error", err)
|
||||
c.logger.Error("failed to encode or compress mount table", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// Create an entry
|
||||
entry := &Entry{
|
||||
Key: coreMountConfigPath,
|
||||
Key: path,
|
||||
Value: compressedBytes,
|
||||
}
|
||||
|
||||
|
@ -766,26 +766,33 @@ func (c *Core) persistMounts(ctx context.Context, table *MountTable, localOnly b
|
|||
c.logger.Error("failed to persist mount table", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Repeat with local mounts
|
||||
compressedBytes, err := jsonutil.EncodeJSONAndCompress(localMounts, nil)
|
||||
if err != nil {
|
||||
c.logger.Error("failed to encode and/or compress the local mount table", "error", err)
|
||||
return err
|
||||
var err error
|
||||
switch {
|
||||
case local == nil:
|
||||
// Write non-local mounts
|
||||
err := writeTable(nonLocalMounts, coreMountConfigPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Write local mounts
|
||||
err = writeTable(localMounts, coreLocalMountConfigPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
case *local:
|
||||
// Write local mounts
|
||||
err = writeTable(localMounts, coreLocalMountConfigPath)
|
||||
default:
|
||||
// Write non-local mounts
|
||||
err = writeTable(nonLocalMounts, coreMountConfigPath)
|
||||
}
|
||||
|
||||
entry := &Entry{
|
||||
Key: coreLocalMountConfigPath,
|
||||
Value: compressedBytes,
|
||||
}
|
||||
|
||||
if err := c.barrier.Put(ctx, entry); err != nil {
|
||||
c.logger.Error("failed to persist local mount table", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
return err
|
||||
}
|
||||
|
||||
// setupMounts is invoked after we've loaded the mount table to
|
||||
|
|
|
@ -161,7 +161,7 @@ func TestCore_Mount_Local(t *testing.T) {
|
|||
}
|
||||
|
||||
c.mounts.Entries[1].Local = true
|
||||
if err := c.persistMounts(context.Background(), c.mounts, false); err != nil {
|
||||
if err := c.persistMounts(context.Background(), c.mounts, nil); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
|
@ -557,7 +557,7 @@ func testCore_MountTable_UpgradeToTyped_Common(
|
|||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var persistFunc func(context.Context, *MountTable, bool) error
|
||||
var persistFunc func(context.Context, *MountTable, *bool) error
|
||||
|
||||
// It should load successfully and be upgraded and persisted
|
||||
switch testType {
|
||||
|
@ -571,7 +571,13 @@ func testCore_MountTable_UpgradeToTyped_Common(
|
|||
mt = c.auth
|
||||
case "audits":
|
||||
err = c.loadAudits(context.Background())
|
||||
persistFunc = c.persistAudit
|
||||
persistFunc = func(ctx context.Context, mt *MountTable, b *bool) error {
|
||||
if b == nil {
|
||||
b = new(bool)
|
||||
*b = false
|
||||
}
|
||||
return c.persistAudit(ctx, mt, *b)
|
||||
}
|
||||
mt = c.audit
|
||||
}
|
||||
if err != nil {
|
||||
|
@ -600,19 +606,19 @@ func testCore_MountTable_UpgradeToTyped_Common(
|
|||
// Now try saving invalid versions
|
||||
origTableType := mt.Type
|
||||
mt.Type = "foo"
|
||||
if err := persistFunc(context.Background(), mt, false); err == nil {
|
||||
if err := persistFunc(context.Background(), mt, nil); err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
|
||||
if len(mt.Entries) > 0 {
|
||||
mt.Type = origTableType
|
||||
mt.Entries[0].Table = "bar"
|
||||
if err := persistFunc(context.Background(), mt, false); err == nil {
|
||||
if err := persistFunc(context.Background(), mt, nil); err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
|
||||
mt.Entries[0].Table = mt.Type
|
||||
if err := persistFunc(context.Background(), mt, false); err != nil {
|
||||
if err := persistFunc(context.Background(), mt, nil); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue