diff --git a/changelog/15261.txt b/changelog/15261.txt new file mode 100644 index 000000000..c04c80a20 --- /dev/null +++ b/changelog/15261.txt @@ -0,0 +1,3 @@ +```release-note:bug +auth: load login MFA configuration upon restart +``` diff --git a/vault/core.go b/vault/core.go index cc74ed247..7a0e3fc69 100644 --- a/vault/core.go +++ b/vault/core.go @@ -35,6 +35,7 @@ import ( "github.com/hashicorp/vault/api" "github.com/hashicorp/vault/audit" "github.com/hashicorp/vault/command/server" + "github.com/hashicorp/vault/helper/identity/mfa" "github.com/hashicorp/vault/helper/metricsutil" "github.com/hashicorp/vault/helper/namespace" "github.com/hashicorp/vault/helper/osutil" @@ -2139,7 +2140,6 @@ func (s standardUnsealStrategy) unseal(ctx context.Context, logger log.Logger, c if err := c.setupQuotas(ctx, false); err != nil { return err } - c.setupCachedMFAResponseAuth() if err := c.setupHeaderHMACKey(ctx, false); err != nil { return err @@ -2161,9 +2161,14 @@ func (s standardUnsealStrategy) unseal(ctx context.Context, logger log.Logger, c if err := c.loadIdentityStoreArtifacts(ctx); err != nil { return err } - if err := loadMFAConfigs(ctx, c); err != nil { + if err := loadPolicyMFAConfigs(ctx, c); err != nil { return err } + c.setupCachedMFAResponseAuth() + if err := c.loadLoginMFAConfigs(ctx); err != nil { + return err + } + if err := c.setupAuditedHeadersConfig(ctx); err != nil { return err } @@ -2325,10 +2330,6 @@ func (c *Core) preSeal() error { result = multierror.Append(result, fmt.Errorf("error stopping expiration: %w", err)) } c.stopActivityLog() - // Clear any cached auth response - c.mfaResponseAuthQueueLock.Lock() - c.mfaResponseAuthQueue = nil - c.mfaResponseAuthQueueLock.Unlock() if err := c.teardownCredentials(context.Background()); err != nil { result = multierror.Append(result, fmt.Errorf("error tearing down credentials: %w", err)) @@ -2356,10 +2357,13 @@ func (c *Core) preSeal() error { seal.StopHealthCheck() } - c.loginMFABackend.usedCodes = nil if c.systemBackend != nil && c.systemBackend.mfaBackend != nil { c.systemBackend.mfaBackend.usedCodes = nil } + if err := c.teardownLoginMFA(); err != nil { + result = multierror.Append(result, fmt.Errorf("error tearing down login MFA, error: %w", err)) + } + preSealPhysical(c) c.logger.Info("pre-seal teardown complete") @@ -3073,6 +3077,31 @@ type LicenseState struct { Terminated bool } +func (c *Core) loadLoginMFAConfigs(ctx context.Context) error { + eConfigs := make([]*mfa.MFAEnforcementConfig, 0) + allNamespaces := c.collectNamespaces() + for _, ns := range allNamespaces { + err := c.loginMFABackend.loadMFAMethodConfigs(ctx, ns) + if err != nil { + return fmt.Errorf("error loading MFA method Config, namespaceid %s, error: %w", ns.ID, err) + } + + loadedConfigs, err := c.loginMFABackend.loadMFAEnforcementConfigs(ctx, ns) + if err != nil { + return fmt.Errorf("error loading MFA enforcement Config, namespaceid %s, error: %w", ns.ID, err) + } + + eConfigs = append(eConfigs, loadedConfigs...) + } + + for _, conf := range eConfigs { + if err := c.loginMFABackend.loginMFAMethodExistenceCheck(conf); err != nil { + c.loginMFABackend.mfaLogger.Error("failed to find all MFA methods that exist in MFA enforcement configs", "configID", conf.ID, "namespaceID", conf.NamespaceID, "error", err.Error()) + } + } + return nil +} + type MFACachedAuthResponse struct { CachedAuth *logical.Auth RequestPath string diff --git a/vault/core_util.go b/vault/core_util.go index 965cc32ae..efa155bc2 100644 --- a/vault/core_util.go +++ b/vault/core_util.go @@ -115,7 +115,7 @@ func postUnsealPhysical(c *Core) error { return nil } -func loadMFAConfigs(context.Context, *Core) error { return nil } +func loadPolicyMFAConfigs(context.Context, *Core) error { return nil } func shouldStartClusterListener(*Core) bool { return true } diff --git a/vault/login_mfa.go b/vault/login_mfa.go index 6cb039023..0976482b8 100644 --- a/vault/login_mfa.go +++ b/vault/login_mfa.go @@ -122,6 +122,18 @@ func NewLoginMFABackend(core *Core, logger hclog.Logger) *LoginMFABackend { } func NewMFABackend(core *Core, logger hclog.Logger, prefix string, schemaFuncs []func() *memdb.TableSchema) *MFABackend { + db, _ := SetupMFAMemDB(schemaFuncs) + return &MFABackend{ + Core: core, + mfaLock: &sync.RWMutex{}, + db: db, + mfaLogger: logger.Named("mfa"), + namespacer: core, + methodTable: prefix, + } +} + +func SetupMFAMemDB(schemaFuncs []func() *memdb.TableSchema) (*memdb.MemDB, error) { mfaSchemas := &memdb.DBSchema{ Tables: make(map[string]*memdb.TableSchema), } @@ -134,15 +146,24 @@ func NewMFABackend(core *Core, logger hclog.Logger, prefix string, schemaFuncs [ mfaSchemas.Tables[schema.Name] = schema } - db, _ := memdb.NewMemDB(mfaSchemas) - return &MFABackend{ - Core: core, - mfaLock: &sync.RWMutex{}, - db: db, - mfaLogger: logger.Named("mfa"), - namespacer: core, - methodTable: prefix, + db, err := memdb.NewMemDB(mfaSchemas) + if err != nil { + return nil, err } + return db, nil +} + +func (b *LoginMFABackend) ResetLoginMFAMemDB() error { + var err error + + db, err := SetupMFAMemDB(loginMFASchemaFuncs()) + if err != nil { + return err + } + + b.db = db + + return nil } func (i *IdentityStore) handleMFAMethodListTOTP(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { @@ -474,6 +495,103 @@ func (i *IdentityStore) handleLoginMFAAdminDestroyUpdate(ctx context.Context, re return nil, nil } +// loadMFAMethodConfigs loads MFA method configs for login MFA +func (b *LoginMFABackend) loadMFAMethodConfigs(ctx context.Context, ns *namespace.Namespace) error { + b.mfaLogger.Trace("loading login MFA configurations") + barrierView, err := b.Core.barrierViewForNamespace(ns.ID) + if err != nil { + return fmt.Errorf("error getting namespace view, namespaceid %s, error %w", ns.ID, err) + } + existing, err := barrierView.List(ctx, loginMFAConfigPrefix) + if err != nil { + return fmt.Errorf("failed to list MFA configurations for namespace path %s and prefix %s: %w", ns.Path, loginMFAConfigPrefix, err) + } + b.mfaLogger.Trace("methods collected", "num_existing", len(existing)) + + for _, key := range existing { + b.mfaLogger.Trace("loading method", "method", key) + + // Read the config from storage + mConfig, err := b.getMFAConfig(ctx, loginMFAConfigPrefix+key, barrierView) + if err != nil { + return err + } + + if mConfig == nil { + b.mfaLogger.Trace("failed to find the config related to a method", "namespace", ns.Path, "prefix", loginMFAConfigPrefix, "method", key) + continue + } + + // Load the config in MemDB + err = b.MemDBUpsertMFAConfig(ctx, mConfig) + if err != nil { + return fmt.Errorf("failed to load configuration ID %s prefix %s in MemDB: %w", mConfig.ID, loginMFAConfigPrefix, err) + } + } + + b.mfaLogger.Trace("configurations restored", "namespace", ns.Path, "prefix", loginMFAConfigPrefix) + + return nil +} + +// loadMFAEnforcementConfigs loads MFA method configs for login MFA +func (b *LoginMFABackend) loadMFAEnforcementConfigs(ctx context.Context, ns *namespace.Namespace) ([]*mfa.MFAEnforcementConfig, error) { + b.mfaLogger.Trace("loading login MFA enforcement configurations") + barrierView, err := b.Core.barrierViewForNamespace(ns.ID) + if err != nil { + return nil, fmt.Errorf("error getting namespace view, namespaceid %s, error %w", ns.ID, err) + } + existing, err := barrierView.List(ctx, mfaLoginEnforcementPrefix) + if err != nil { + return nil, fmt.Errorf("failed to list MFA enforcement configurations for namespace %s with prefix %s: %w", ns.Path, mfaLoginEnforcementPrefix, err) + } + b.mfaLogger.Trace("enforcements configs collected", "num_existing", len(existing)) + + eConfigs := make([]*mfa.MFAEnforcementConfig, 0) + for _, key := range existing { + b.mfaLogger.Trace("loading enforcement", "config", key) + + // Read the config from storage + mConfig, err := b.getMFALoginEnforcementConfig(ctx, mfaLoginEnforcementPrefix+key, barrierView) + if err != nil { + return nil, err + } + + if mConfig == nil { + b.mfaLogger.Trace("failed to find an enforcement config", "namespace", ns.Path, "prefix", mfaLoginEnforcementPrefix, "config", key) + continue + } + + // Load the config in MemDB + err = b.MemDBUpsertMFALoginEnforcementConfig(ctx, mConfig) + if err != nil { + return nil, fmt.Errorf("failed to load enforcement configuration ID %s with prefix %s in MemDB: %w", mConfig.ID, mfaLoginEnforcementPrefix, err) + } + + eConfigs = append(eConfigs, mConfig) + } + + b.mfaLogger.Trace("enforcement configurations restored", "namespace", ns.Path, "prefix", mfaLoginEnforcementPrefix) + + return eConfigs, nil +} + +func (b *LoginMFABackend) loginMFAMethodExistenceCheck(eConfig *mfa.MFAEnforcementConfig) error { + var aggErr *multierror.Error + for _, confID := range eConfig.MFAMethodIDs { + config, memErr := b.MemDBMFAConfigByID(confID) + if memErr != nil { + aggErr = multierror.Append(aggErr, memErr) + return aggErr.ErrorOrNil() + } + if config == nil { + aggErr = multierror.Append(aggErr, fmt.Errorf("found an MFA method ID in enforcement config, but failed to find the MFA method config method ID %s", confID)) + } + } + + return aggErr.ErrorOrNil() +} + func (b *LoginMFABackend) handleMFALoginValidate(ctx context.Context, req *logical.Request, d *framework.FieldData) (retResp *logical.Response, retErr error) { // mfaReqID is the ID of the login request mfaReqID := d.Get("mfa_request_id").(string) @@ -551,6 +669,22 @@ func (b *LoginMFABackend) handleMFALoginValidate(ctx context.Context, req *logic return resp, nil } +func (c *Core) teardownLoginMFA() error { + if !c.IsDRSecondary() { + // Clear any cached auth response + c.mfaResponseAuthQueueLock.Lock() + c.mfaResponseAuthQueue = nil + c.mfaResponseAuthQueueLock.Unlock() + + c.loginMFABackend.usedCodes = nil + + if err := c.loginMFABackend.ResetLoginMFAMemDB(); err != nil { + return err + } + } + return nil +} + // LoginMFACreateToken creates a token after the login MFA is validated. // It also applies the lease quotas on the original login request path. func (c *Core) LoginMFACreateToken(ctx context.Context, reqPath string, cachedAuth *logical.Auth) (*logical.Response, error) { @@ -2320,7 +2454,6 @@ func (b *LoginMFABackend) deleteMFALoginEnforcementConfigByNameAndNamespace(ctx } entryIndex := mfaLoginEnforcementPrefix + eConfig.ID - barrierView, err := b.Core.barrierViewForNamespace(eConfig.NamespaceID) if err != nil { return err @@ -2530,6 +2663,25 @@ func (b *MFABackend) getMFAConfig(ctx context.Context, path string, barrierView return &mConfig, nil } +func (b *LoginMFABackend) getMFALoginEnforcementConfig(ctx context.Context, path string, barrierView *BarrierView) (*mfa.MFAEnforcementConfig, error) { + entry, err := barrierView.Get(ctx, path) + if err != nil { + return nil, err + } + + if entry == nil { + return nil, nil + } + + var mConfig mfa.MFAEnforcementConfig + err = proto.Unmarshal(entry.Value, &mConfig) + if err != nil { + return nil, err + } + + return &mConfig, nil +} + func (b *LoginMFABackend) putMFALoginEnforcementConfig(ctx context.Context, eConfig *mfa.MFAEnforcementConfig) error { entryIndex := mfaLoginEnforcementPrefix + eConfig.ID marshaledEntry, err := proto.Marshal(eConfig) @@ -2548,28 +2700,6 @@ func (b *LoginMFABackend) putMFALoginEnforcementConfig(ctx context.Context, eCon }) } -func (b *LoginMFABackend) getMFALoginEnforcementConfig(ctx context.Context, key, namespaceId string) (*mfa.MFAEnforcementConfig, error) { - barrierView, err := b.Core.barrierViewForNamespace(namespaceId) - if err != nil { - return nil, err - } - entry, err := barrierView.Get(ctx, mfaLoginEnforcementPrefix+key) - if err != nil { - return nil, err - } - if entry == nil { - return nil, nil - } - - var eConfig mfa.MFAEnforcementConfig - err = proto.Unmarshal(entry.Value, &eConfig) - if err != nil { - return nil, err - } - - return &eConfig, nil -} - var mfaHelp = map[string][2]string{ "methods-list": { "Lists all the available MFA methods by their name.",