diff --git a/vault/core.go b/vault/core.go index 5ec324430..4df0cbdc7 100644 --- a/vault/core.go +++ b/vault/core.go @@ -1,6 +1,7 @@ package vault import ( + "context" "crypto" "crypto/ecdsa" "crypto/subtle" @@ -18,7 +19,6 @@ import ( "github.com/armon/go-metrics" log "github.com/mgutz/logxi/v1" - "golang.org/x/net/context" "google.golang.org/grpc" "github.com/hashicorp/errwrap" @@ -1498,8 +1498,6 @@ func (c *Core) sealInternal() error { // Signal the standby goroutine to shutdown, wait for completion close(c.standbyStopCh) - c.requestContext = nil - // Release the lock while we wait to avoid deadlocking c.stateLock.Unlock() <-c.standbyDoneCh @@ -1536,9 +1534,8 @@ func (c *Core) postUnseal() (retErr error) { defer metrics.MeasureSince([]string{"core", "post_unseal"}, time.Now()) defer func() { if retErr != nil { + c.requestContextCancelFunc() c.preSeal() - } else { - c.requestContext, c.requestContextCancelFunc = context.WithCancel(context.Background()) } }() c.logger.Info("core: post-unseal setup starting") @@ -1559,6 +1556,8 @@ func (c *Core) postUnseal() (retErr error) { c.seal.SetRecoveryConfig(nil) } + c.requestContext, c.requestContextCancelFunc = context.WithCancel(context.Background()) + if err := enterprisePostUnseal(c); err != nil { return err } diff --git a/vault/expiration.go b/vault/expiration.go index 710fcb8f0..23176ae44 100644 --- a/vault/expiration.go +++ b/vault/expiration.go @@ -1,6 +1,7 @@ package vault import ( + "context" "encoding/json" "errors" "fmt" @@ -68,21 +69,20 @@ type ExpirationManager struct { restoreLocks []*locksutil.LockEntry restoreLoaded sync.Map quitCh chan struct{} + + coreStateLock *sync.RWMutex + quitContext context.Context } // NewExpirationManager creates a new ExpirationManager that is backed // using a given view, and uses the provided router for revocation. -func NewExpirationManager(router *Router, view *BarrierView, ts *TokenStore, logger log.Logger) *ExpirationManager { - if logger == nil { - logger = log.New("expiration_manager") - } - +func NewExpirationManager(c *Core, view *BarrierView) *ExpirationManager { exp := &ExpirationManager{ - router: router, + router: c.router, idView: view.SubView(leaseViewPrefix), tokenView: view.SubView(tokenViewPrefix), - tokenStore: ts, - logger: logger, + tokenStore: c.tokenStore, + logger: c.logger, pending: make(map[string]*time.Timer), // new instances of the expiration manager will go immediately into @@ -90,7 +90,15 @@ func NewExpirationManager(router *Router, view *BarrierView, ts *TokenStore, log restoreMode: 1, restoreLocks: locksutil.CreateLocks(), quitCh: make(chan struct{}), + + coreStateLock: &c.stateLock, + quitContext: c.requestContext, } + + if exp.logger == nil { + exp.logger = log.New("expiration_manager") + } + return exp } @@ -103,7 +111,7 @@ func (c *Core) setupExpiration() error { view := c.systemBarrierView.SubView(expirationSubPath) // Create the manager - mgr := NewExpirationManager(c.router, view, c.tokenStore, c.logger) + mgr := NewExpirationManager(c, view) c.expiration = mgr // Link the token store to this @@ -430,6 +438,10 @@ func (m *ExpirationManager) Stop() error { m.logger.Debug("expiration: stop triggered") defer m.logger.Debug("expiration: finished stopping") + // Do this before stopping pending timers to avoid potential races with + // expiring timers + close(m.quitCh) + m.pendingLock.Lock() for _, timer := range m.pending { timer.Stop() @@ -437,7 +449,6 @@ func (m *ExpirationManager) Stop() error { m.pending = make(map[string]*time.Timer) m.pendingLock.Unlock() - close(m.quitCh) if m.inRestoreMode() { for { if !m.inRestoreMode() { @@ -969,13 +980,24 @@ func (m *ExpirationManager) expireID(leaseID string) { return default: } + + m.coreStateLock.RLock() + if m.quitContext.Err() == context.Canceled { + m.logger.Error("expiration: core context canceled, not attempting further revocation of lease", "lease_id", leaseID) + m.coreStateLock.RUnlock() + return + } + err := m.Revoke(leaseID) if err == nil { if m.logger.IsInfo() { m.logger.Info("expiration: revoked lease", "lease_id", leaseID) } + m.coreStateLock.RUnlock() return } + + m.coreStateLock.RUnlock() m.logger.Error("expiration: failed to revoke lease", "lease_id", leaseID, "error", err) time.Sleep((1 << attempt) * revokeRetryBase) } diff --git a/vault/request_forwarding.go b/vault/request_forwarding.go index 10e6a9a57..d031566d3 100644 --- a/vault/request_forwarding.go +++ b/vault/request_forwarding.go @@ -1,6 +1,7 @@ package vault import ( + "context" "crypto/tls" "crypto/x509" "fmt" @@ -13,7 +14,6 @@ import ( "time" "github.com/hashicorp/vault/helper/forwarding" - "golang.org/x/net/context" "golang.org/x/net/http2" "google.golang.org/grpc" "google.golang.org/grpc/keepalive"