diff --git a/builtin/credential/aws/backend.go b/builtin/credential/aws/backend.go index 863801526..44bd9d8b3 100644 --- a/builtin/credential/aws/backend.go +++ b/builtin/credential/aws/backend.go @@ -43,6 +43,8 @@ func Backend(conf *logical.BackendConfig) (*framework.Backend, error) { pathImageTag(&b), pathConfigClient(&b), pathConfigCertificate(&b), + pathConfigTidyBlacklistRoleTag(&b), + pathConfigTidyWhitelistIdentity(&b), pathListCertificates(&b), pathBlacklistRoleTag(&b), pathListBlacklistRoleTags(&b), @@ -53,6 +55,8 @@ func Backend(conf *logical.BackendConfig) (*framework.Backend, error) { }), AuthRenew: b.pathLoginRenew, + + TidyFunc: b.tidyFunc, } b.EC2ClientsMap = make(map[string]*ec2.EC2) @@ -69,6 +73,45 @@ type backend struct { EC2ClientsMap map[string]*ec2.EC2 } +func (b *backend) tidyFunc(req *logical.Request) error { + b.configMutex.Lock() + defer b.configMutex.Unlock() + // safety_buffer defaults to 72h + safety_buffer := 259200 + tidyBlacklistConfigEntry, err := configTidyBlacklistRoleTag(req.Storage) + if err != nil { + return err + } + skipBlacklistTidy := false + if tidyBlacklistConfigEntry != nil { + if tidyBlacklistConfigEntry.DisablePeriodicTidy { + skipBlacklistTidy = true + } + safety_buffer = tidyBlacklistConfigEntry.SafetyBuffer + } + if !skipBlacklistTidy { + tidyBlacklistRoleTag(req.Storage, safety_buffer) + } + + // reset the safety_buffer to 72h + safety_buffer = 259200 + tidyWhitelistConfigEntry, err := configTidyWhitelistIdentity(req.Storage) + if err != nil { + return err + } + skipWhitelistTidy := false + if tidyWhitelistConfigEntry != nil { + if tidyWhitelistConfigEntry.DisablePeriodicTidy { + skipWhitelistTidy = true + } + safety_buffer = tidyWhitelistConfigEntry.SafetyBuffer + } + if !skipWhitelistTidy { + tidyWhitelistIdentity(req.Storage, safety_buffer) + } + return nil +} + const backendHelp = ` AWS auth backend takes in a AWS EC2 instance identity document, its PKCS#7 signature and a client created nonce to authenticates the instance with Vault. diff --git a/builtin/credential/aws/path_blacklist_roletag.go b/builtin/credential/aws/path_blacklist_roletag.go index 9100f2346..4ebad4301 100644 --- a/builtin/credential/aws/path_blacklist_roletag.go +++ b/builtin/credential/aws/path_blacklist_roletag.go @@ -171,8 +171,7 @@ func (b *backend) pathBlacklistRoleTagUpdate( currentTime := time.Now() - var epoch time.Time - if blEntry.CreationTime.Equal(epoch) { + if blEntry.CreationTime.IsZero() { // Set the creation time for the blacklist entry. // This should not be updated after setting it once. // If blacklist operation is invoked more than once, only update the expiration time. diff --git a/builtin/credential/aws/path_blacklist_roletag_tidy.go b/builtin/credential/aws/path_blacklist_roletag_tidy.go index 7de0f0808..1a506840a 100644 --- a/builtin/credential/aws/path_blacklist_roletag_tidy.go +++ b/builtin/credential/aws/path_blacklist_roletag_tidy.go @@ -29,46 +29,47 @@ expiration, before it is removed from the backend storage.`, } } -// pathBlacklistRoleTagTidyUpdate is used to clean-up the entries in the role tag blacklist. -func (b *backend) pathBlacklistRoleTagTidyUpdate( - req *logical.Request, data *framework.FieldData) (*logical.Response, error) { - - // safety_buffer is an optional parameter. - safety_buffer := data.Get("safety_buffer").(int) +// tidyBlacklistRoleTag is used to clean-up the entries in the role tag blacklist. +func tidyBlacklistRoleTag(s logical.Storage, safety_buffer int) error { bufferDuration := time.Duration(safety_buffer) * time.Second - - tags, err := req.Storage.List("blacklist/roletag/") + tags, err := s.List("blacklist/roletag/") if err != nil { - return nil, err + return err } for _, tag := range tags { - tagEntry, err := req.Storage.Get("blacklist/roletag/" + tag) + tagEntry, err := s.Get("blacklist/roletag/" + tag) if err != nil { - return nil, fmt.Errorf("error fetching tag %s: %s", tag, err) + return fmt.Errorf("error fetching tag %s: %s", tag, err) } if tagEntry == nil { - return nil, fmt.Errorf("tag entry for tag %s is nil", tag) + return fmt.Errorf("tag entry for tag %s is nil", tag) } if tagEntry.Value == nil || len(tagEntry.Value) == 0 { - return nil, fmt.Errorf("found entry for tag %s but actual tag is empty", tag) + return fmt.Errorf("found entry for tag %s but actual tag is empty", tag) } var result roleTagBlacklistEntry if err := tagEntry.DecodeJSON(&result); err != nil { - return nil, err + return err } if time.Now().After(result.ExpirationTime.Add(bufferDuration)) { - if err := req.Storage.Delete("blacklist/roletag" + tag); err != nil { - return nil, fmt.Errorf("error deleting tag %s from storage: %s", tag, err) + if err := s.Delete("blacklist/roletag" + tag); err != nil { + return fmt.Errorf("error deleting tag %s from storage: %s", tag, err) } } } - return nil, nil + return nil +} + +// pathBlacklistRoleTagTidyUpdate is used to clean-up the entries in the role tag blacklist. +func (b *backend) pathBlacklistRoleTagTidyUpdate( + req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + return nil, tidyBlacklistRoleTag(req.Storage, data.Get("safety_buffer").(int)) } const pathBlacklistRoleTagTidySyn = ` diff --git a/builtin/credential/aws/path_config_client.go b/builtin/credential/aws/path_config_client.go index fd0c600c1..331dd3d6d 100644 --- a/builtin/credential/aws/path_config_client.go +++ b/builtin/credential/aws/path_config_client.go @@ -8,7 +8,7 @@ import ( func pathConfigClient(b *backend) *framework.Path { return &framework.Path{ - Pattern: "config/client", + Pattern: "config/client$", Fields: map[string]*framework.FieldSchema{ "access_key": &framework.FieldSchema{ Type: framework.TypeString, diff --git a/builtin/credential/aws/path_config_tidy_blacklist_roletag.go b/builtin/credential/aws/path_config_tidy_blacklist_roletag.go new file mode 100644 index 000000000..4d319927e --- /dev/null +++ b/builtin/credential/aws/path_config_tidy_blacklist_roletag.go @@ -0,0 +1,141 @@ +package aws + +import ( + "github.com/fatih/structs" + "github.com/hashicorp/vault/logical" + "github.com/hashicorp/vault/logical/framework" +) + +func pathConfigTidyBlacklistRoleTag(b *backend) *framework.Path { + return &framework.Path{ + Pattern: "config/tidy/blacklist/roletag$", + Fields: map[string]*framework.FieldSchema{ + "safety_buffer": &framework.FieldSchema{ + Type: framework.TypeDurationSecond, + Default: 259200, //72h + Description: `The amount of extra time that must have passed beyond the roletag +expiration, before it is removed from the backend storage.`, + }, + "disable_periodic_tidy": &framework.FieldSchema{ + Type: framework.TypeBool, + Default: false, + Description: "If set to 'true', disables the periodic tidying of the 'blacklist/roletag/' entries and 'whitelist/identity' entries.", + }, + }, + + ExistenceCheck: b.pathConfigTidyBlacklistRoleTagExistenceCheck, + + Callbacks: map[logical.Operation]framework.OperationFunc{ + logical.CreateOperation: b.pathConfigTidyBlacklistRoleTagCreateUpdate, + logical.UpdateOperation: b.pathConfigTidyBlacklistRoleTagCreateUpdate, + }, + + HelpSynopsis: pathConfigTidyBlacklistRoleTagHelpSyn, + HelpDescription: pathConfigTidyBlacklistRoleTagHelpDesc, + } +} + +func (b *backend) pathConfigTidyBlacklistRoleTagExistenceCheck(req *logical.Request, data *framework.FieldData) (bool, error) { + b.configMutex.RLock() + defer b.configMutex.RUnlock() + + entry, err := configTidyBlacklistRoleTag(req.Storage) + if err != nil { + return false, err + } + return entry != nil, nil +} + +func configTidyBlacklistRoleTag(s logical.Storage) (*tidyBlacklistRoleTagConfig, error) { + entry, err := s.Get("config/tidy/blacklist/roletag") + if err != nil { + return nil, err + } + if entry == nil { + return nil, nil + } + + var result tidyBlacklistRoleTagConfig + if err := entry.DecodeJSON(&result); err != nil { + return nil, err + } + return &result, nil +} + +func (b *backend) pathConfigTidyBlacklistRoleTagCreateUpdate(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + b.configMutex.Lock() + defer b.configMutex.Unlock() + configEntry, err := configTidyBlacklistRoleTag(req.Storage) + if err != nil { + return nil, err + } + if configEntry == nil { + configEntry = &tidyBlacklistRoleTagConfig{} + } + safetyBufferInt, ok := data.GetOk("safety_buffer") + if ok { + configEntry.SafetyBuffer = safetyBufferInt.(int) + } else if req.Operation == logical.CreateOperation { + configEntry.SafetyBuffer = data.Get("safety_buffer").(int) + } + disablePeriodicTidyBool, ok := data.GetOk("disable_periodic_tidy") + if ok { + configEntry.DisablePeriodicTidy = disablePeriodicTidyBool.(bool) + } else if req.Operation == logical.CreateOperation { + configEntry.DisablePeriodicTidy = data.Get("disable_periodic_tidy").(bool) + } + + entry, err := logical.StorageEntryJSON("config/tidy/blacklist/roletag", configEntry) + if err != nil { + return nil, err + } + + if err := req.Storage.Put(entry); err != nil { + return nil, err + } + + return nil, nil +} + +func (b *backend) pathConfigTidyBlacklistRoleTagRead(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + b.configMutex.RLock() + defer b.configMutex.RUnlock() + + clientConfig, err := configTidyBlacklistRoleTag(req.Storage) + if err != nil { + return nil, err + } + + if clientConfig == nil { + return nil, nil + } + return &logical.Response{ + Data: structs.New(clientConfig).Map(), + }, nil +} + +func (b *backend) pathConfigTidyBlacklistRoleTagDelete(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + b.configMutex.Lock() + defer b.configMutex.Unlock() + + if err := req.Storage.Delete("config/tidy/blacklist/roletag"); err != nil { + return nil, err + } + + return nil, nil +} + +type tidyBlacklistRoleTagConfig struct { + SafetyBuffer int `json:"safety_buffer" structs:"safety_buffer" mapstructure:"safety_buffer"` + DisablePeriodicTidy bool `json:"disable_periodic_tidy" structs:"disable_periodic_tidy" mapstructure:"disable_periodic_tidy"` +} + +const pathConfigTidyBlacklistRoleTagHelpSyn = ` +Configures the periodic tidying operation of the blacklisted role tag entries. +` +const pathConfigTidyBlacklistRoleTagHelpDesc = ` +By default, the expired entries in the blacklist will be attempted to be removed +periodically. This operation will look for expired items in the list and purge them. +However, there is a safety buffer duration (defaults to 72h), which purges the entries, +only if they have been persisting this duration, past its expiration time. +` diff --git a/builtin/credential/aws/path_config_tidy_whitelist_identity.go b/builtin/credential/aws/path_config_tidy_whitelist_identity.go new file mode 100644 index 000000000..1baf28885 --- /dev/null +++ b/builtin/credential/aws/path_config_tidy_whitelist_identity.go @@ -0,0 +1,141 @@ +package aws + +import ( + "github.com/fatih/structs" + "github.com/hashicorp/vault/logical" + "github.com/hashicorp/vault/logical/framework" +) + +func pathConfigTidyWhitelistIdentity(b *backend) *framework.Path { + return &framework.Path{ + Pattern: "config/tidy/whitelist/identity$", + Fields: map[string]*framework.FieldSchema{ + "safety_buffer": &framework.FieldSchema{ + Type: framework.TypeDurationSecond, + Default: 259200, //72h + Description: `The amount of extra time that must have passed beyond the identity's +expiration, before it is removed from the backend storage.`, + }, + "disable_periodic_tidy": &framework.FieldSchema{ + Type: framework.TypeBool, + Default: false, + Description: "If set to 'true', disables the periodic tidying of the 'whitelist/identity/' entries and 'whitelist/identity' entries.", + }, + }, + + ExistenceCheck: b.pathConfigTidyWhitelistIdentityExistenceCheck, + + Callbacks: map[logical.Operation]framework.OperationFunc{ + logical.CreateOperation: b.pathConfigTidyWhitelistIdentityCreateUpdate, + logical.UpdateOperation: b.pathConfigTidyWhitelistIdentityCreateUpdate, + }, + + HelpSynopsis: pathConfigTidyWhitelistIdentityHelpSyn, + HelpDescription: pathConfigTidyWhitelistIdentityHelpDesc, + } +} + +func (b *backend) pathConfigTidyWhitelistIdentityExistenceCheck(req *logical.Request, data *framework.FieldData) (bool, error) { + b.configMutex.RLock() + defer b.configMutex.RUnlock() + + entry, err := configTidyWhitelistIdentity(req.Storage) + if err != nil { + return false, err + } + return entry != nil, nil +} + +func configTidyWhitelistIdentity(s logical.Storage) (*tidyWhitelistIdentityConfig, error) { + entry, err := s.Get("config/tidy/whitelist/identity") + if err != nil { + return nil, err + } + if entry == nil { + return nil, nil + } + + var result tidyWhitelistIdentityConfig + if err := entry.DecodeJSON(&result); err != nil { + return nil, err + } + return &result, nil +} + +func (b *backend) pathConfigTidyWhitelistIdentityCreateUpdate(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + b.configMutex.Lock() + defer b.configMutex.Unlock() + configEntry, err := configTidyWhitelistIdentity(req.Storage) + if err != nil { + return nil, err + } + if configEntry == nil { + configEntry = &tidyWhitelistIdentityConfig{} + } + safetyBufferInt, ok := data.GetOk("safety_buffer") + if ok { + configEntry.SafetyBuffer = safetyBufferInt.(int) + } else if req.Operation == logical.CreateOperation { + configEntry.SafetyBuffer = data.Get("safety_buffer").(int) + } + disablePeriodicTidyBool, ok := data.GetOk("disable_periodic_tidy") + if ok { + configEntry.DisablePeriodicTidy = disablePeriodicTidyBool.(bool) + } else if req.Operation == logical.CreateOperation { + configEntry.DisablePeriodicTidy = data.Get("disable_periodic_tidy").(bool) + } + + entry, err := logical.StorageEntryJSON("config/tidy/whitelist/identity", configEntry) + if err != nil { + return nil, err + } + + if err := req.Storage.Put(entry); err != nil { + return nil, err + } + + return nil, nil +} + +func (b *backend) pathConfigTidyWhitelistIdentityRead(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + b.configMutex.RLock() + defer b.configMutex.RUnlock() + + clientConfig, err := configTidyWhitelistIdentity(req.Storage) + if err != nil { + return nil, err + } + + if clientConfig == nil { + return nil, nil + } + return &logical.Response{ + Data: structs.New(clientConfig).Map(), + }, nil +} + +func (b *backend) pathConfigTidyWhitelistIdentityDelete(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + b.configMutex.Lock() + defer b.configMutex.Unlock() + + if err := req.Storage.Delete("config/tidy/whitelist/identity"); err != nil { + return nil, err + } + + return nil, nil +} + +type tidyWhitelistIdentityConfig struct { + SafetyBuffer int `json:"safety_buffer" structs:"safety_buffer" mapstructure:"safety_buffer"` + DisablePeriodicTidy bool `json:"disable_periodic_tidy" structs:"disable_periodic_tidy" mapstructure:"disable_periodic_tidy"` +} + +const pathConfigTidyWhitelistIdentityHelpSyn = ` +Configures the periodic tidying operation of the whitelisted identity entries. +` +const pathConfigTidyWhitelistIdentityHelpDesc = ` +By default, the expired entries in teb whitelist will be attempted to be removed +periodically. This operation will look for expired items in the list and purge them. +However, there is a safety buffer duration (defaults to 72h), which purges the entries, +only if they have been persisting this duration, past its expiration time. +` diff --git a/builtin/credential/aws/path_whitelist_identity_tidy.go b/builtin/credential/aws/path_whitelist_identity_tidy.go index bbefba46d..b80c49596 100644 --- a/builtin/credential/aws/path_whitelist_identity_tidy.go +++ b/builtin/credential/aws/path_whitelist_identity_tidy.go @@ -29,46 +29,48 @@ expiration, before it is removed from the backend storage.`, } } -// pathWhitelistIdentityTidyUpdate is used to delete entries in the whitelist that are expired. -func (b *backend) pathWhitelistIdentityTidyUpdate( - req *logical.Request, data *framework.FieldData) (*logical.Response, error) { - - safety_buffer := data.Get("safety_buffer").(int) - +// tidyWhitelistIdentity is used to delete entries in the whitelist that are expired. +func tidyWhitelistIdentity(s logical.Storage, safety_buffer int) error { bufferDuration := time.Duration(safety_buffer) * time.Second - identities, err := req.Storage.List("whitelist/identity/") + identities, err := s.List("whitelist/identity/") if err != nil { - return nil, err + return err } for _, instanceID := range identities { - identityEntry, err := req.Storage.Get("whitelist/identity/" + instanceID) + identityEntry, err := s.Get("whitelist/identity/" + instanceID) if err != nil { - return nil, fmt.Errorf("error fetching identity of instanceID %s: %s", instanceID, err) + return fmt.Errorf("error fetching identity of instanceID %s: %s", instanceID, err) } if identityEntry == nil { - return nil, fmt.Errorf("identity entry for instanceID %s is nil", instanceID) + return fmt.Errorf("identity entry for instanceID %s is nil", instanceID) } if identityEntry.Value == nil || len(identityEntry.Value) == 0 { - return nil, fmt.Errorf("found identity entry for instanceID %s but actual identity is empty", instanceID) + return fmt.Errorf("found identity entry for instanceID %s but actual identity is empty", instanceID) } var result whitelistIdentity if err := identityEntry.DecodeJSON(&result); err != nil { - return nil, err + return err } if time.Now().After(result.ExpirationTime.Add(bufferDuration)) { - if err := req.Storage.Delete("whitelist/identity" + instanceID); err != nil { - return nil, fmt.Errorf("error deleting identity of instanceID %s from storage: %s", instanceID, err) + if err := s.Delete("whitelist/identity" + instanceID); err != nil { + return fmt.Errorf("error deleting identity of instanceID %s from storage: %s", instanceID, err) } } } - return nil, nil + return nil +} + +// pathWhitelistIdentityTidyUpdate is used to delete entries in the whitelist that are expired. +func (b *backend) pathWhitelistIdentityTidyUpdate( + req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + return nil, tidyWhitelistIdentity(req.Storage, data.Get("safety_buffer").(int)) } const pathWhitelistIdentityTidySyn = ` diff --git a/builtin/logical/aws/backend.go b/builtin/logical/aws/backend.go index 8d9614d60..721a4c382 100644 --- a/builtin/logical/aws/backend.go +++ b/builtin/logical/aws/backend.go @@ -35,8 +35,8 @@ func Backend() *framework.Backend { secretAccessKeys(&b), }, - Rollback: rollback, - RollbackMinAge: 5 * time.Minute, + WALRollback: walRollback, + WALRollbackMinAge: 5 * time.Minute, } return b.Backend diff --git a/builtin/logical/aws/rollback.go b/builtin/logical/aws/rollback.go index 8f133396f..5d1b335ed 100644 --- a/builtin/logical/aws/rollback.go +++ b/builtin/logical/aws/rollback.go @@ -7,12 +7,12 @@ import ( "github.com/hashicorp/vault/logical/framework" ) -var rollbackMap = map[string]framework.RollbackFunc{ +var walRollbackMap = map[string]framework.WALRollbackFunc{ "user": pathUserRollback, } -func rollback(req *logical.Request, kind string, data interface{}) error { - f, ok := rollbackMap[kind] +func walRollback(req *logical.Request, kind string, data interface{}) error { + f, ok := walRollbackMap[kind] if !ok { return fmt.Errorf("unknown type to rollback") } diff --git a/logical/framework/backend.go b/logical/framework/backend.go index e6749935a..0df8624b9 100644 --- a/logical/framework/backend.go +++ b/logical/framework/backend.go @@ -42,14 +42,24 @@ type Backend struct { // and ease specifying callbacks for revocation, renewal, etc. Secrets []*Secret - // Rollback is called when a WAL entry (see wal.go) has to be rolled + // TidyFunc is the callback, which if set, will be invoked when the + // periodic timer of RollbackManager ticks. This can be used by + // backends to do any tidying tasks. + // + // TidyFunc is different from 'Clean' in the sense that, TidyFunc is + // invoked to, say to periodically delete expired/stale entries in backend's + // storage, while the backend is still being used. Whereas `Clean` is + // invoked just before the backend is unmounted. + TidyFunc tidyFunc + + // WALRollback is called when a WAL entry (see wal.go) has to be rolled // back. It is called with the data from the entry. // - // RollbackMinAge is the minimum age of a WAL entry before it is attempted + // WALRollbackMinAge is the minimum age of a WAL entry before it is attempted // to be rolled back. This should be longer than the maximum time it takes // to successfully create a secret. - Rollback RollbackFunc - RollbackMinAge time.Duration + WALRollback WALRollbackFunc + WALRollbackMinAge time.Duration // Clean is called on unload to clean up e.g any existing connections // to the backend, if required. @@ -66,11 +76,15 @@ type Backend struct { pathsRe []*regexp.Regexp } +// tidyFunc is the callback called when the RollbackManager's timer ticks. +// This can be utilized by the backends to do tidying tasks. +type tidyFunc func(*logical.Request) error + // OperationFunc is the callback called for an operation on a path. type OperationFunc func(*logical.Request, *FieldData) (*logical.Response, error) -// RollbackFunc is the callback for rollbacks. -type RollbackFunc func(*logical.Request, string, interface{}) error +// WALRollbackFunc is the callback for rollbacks. +type WALRollbackFunc func(*logical.Request, string, interface{}) error // CleanupFunc is the callback for backend unload. type CleanupFunc func() @@ -385,6 +399,19 @@ func (b *Backend) handleRevokeRenew( } } +// handleRollback invokes the TidyFunc set on the backend. It also does a WAL rollback operation. +func (b *Backend) handleRollback( + req *logical.Request) (*logical.Response, error) { + // Response is not expected from the tidy operation. + if b.TidyFunc != nil { + if err := b.TidyFunc(req); err != nil { + return nil, err + } + } + + return b.handleWALRollback(req) +} + func (b *Backend) handleAuthRenew(req *logical.Request) (*logical.Response, error) { if b.AuthRenew == nil { return logical.ErrorResponse("this auth type doesn't support renew"), nil @@ -393,9 +420,9 @@ func (b *Backend) handleAuthRenew(req *logical.Request) (*logical.Response, erro return b.AuthRenew(req, nil) } -func (b *Backend) handleRollback( +func (b *Backend) handleWALRollback( req *logical.Request) (*logical.Response, error) { - if b.Rollback == nil { + if b.WALRollback == nil { return nil, logical.ErrUnsupportedOperation } @@ -410,7 +437,7 @@ func (b *Backend) handleRollback( // Calculate the minimum time that the WAL entries could be // created in order to be rolled back. - age := b.RollbackMinAge + age := b.WALRollbackMinAge if age == 0 { age = 10 * time.Minute } @@ -434,8 +461,8 @@ func (b *Backend) handleRollback( continue } - // Attempt a rollback - err = b.Rollback(req, entry.Kind, entry.Data) + // Attempt a WAL rollback + err = b.WALRollback(req, entry.Kind, entry.Data) if err != nil { err = fmt.Errorf( "Error rolling back '%s' entry: %s", entry.Kind, err) diff --git a/vault/rollback.go b/vault/rollback.go index 443a2c0ba..1c13d6f7c 100644 --- a/vault/rollback.go +++ b/vault/rollback.go @@ -30,10 +30,10 @@ const ( type RollbackManager struct { logger *log.Logger - // This gives the current mount table, plus a RWMutex that is - // locked for reading. It is up to the caller to RUnlock it - // when done with the mount table - mounts func() []*MountEntry + // This gives the current mount table of both logical and credential backends, + // plus a RWMutex that is locked for reading. It is up to the caller to RUnlock + // it when done with the mount table. + backends func() []*MountEntry router *Router period time.Duration @@ -55,10 +55,10 @@ type rollbackState struct { } // NewRollbackManager is used to create a new rollback manager -func NewRollbackManager(logger *log.Logger, mounts func() []*MountEntry, router *Router) *RollbackManager { +func NewRollbackManager(logger *log.Logger, backendsFunc func() []*MountEntry, router *Router) *RollbackManager { r := &RollbackManager{ logger: logger, - mounts: mounts, + backends: backendsFunc, router: router, period: rollbackPeriod, inflight: make(map[string]*rollbackState), @@ -109,9 +109,9 @@ func (m *RollbackManager) triggerRollbacks() { m.inflightLock.Lock() defer m.inflightLock.Unlock() - mounts := m.mounts() + backends := m.backends() - for _, e := range mounts { + for _, e := range backends { if _, ok := m.inflight[e.Path]; !ok { m.startRollback(e.Path) } @@ -184,16 +184,24 @@ func (m *RollbackManager) Rollback(path string) error { // startRollback is used to start the rollback manager after unsealing func (c *Core) startRollback() error { - mountsFunc := func() []*MountEntry { + backendsFunc := func() []*MountEntry { ret := []*MountEntry{} c.mountsLock.RLock() defer c.mountsLock.RUnlock() for _, entry := range c.mounts.Entries { ret = append(ret, entry) } + c.authLock.RLock() + defer c.authLock.RUnlock() + for _, entry := range c.auth.Entries { + if !strings.HasPrefix(entry.Path, "auth/") { + entry.Path = "auth/" + entry.Path + } + ret = append(ret, entry) + } return ret } - c.rollback = NewRollbackManager(c.logger, mountsFunc, c.router) + c.rollback = NewRollbackManager(c.logger, backendsFunc, c.router) c.rollback.Start() return nil }