loading MFA configs upont restart (#15261)

* loading MFA configs upont restart

* Adding CL

* feedback

* Update vault/core.go

Co-authored-by: Chris Capurso <1036769+ccapurso@users.noreply.github.com>

Co-authored-by: Chris Capurso <1036769+ccapurso@users.noreply.github.com>
This commit is contained in:
Hamid Ghaf 2022-05-05 18:53:57 -04:00 committed by GitHub
parent 36f1938b30
commit 0fc7a363cc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 201 additions and 39 deletions

3
changelog/15261.txt Normal file
View File

@ -0,0 +1,3 @@
```release-note:bug
auth: load login MFA configuration upon restart
```

View File

@ -35,6 +35,7 @@ import (
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/audit"
"github.com/hashicorp/vault/command/server"
"github.com/hashicorp/vault/helper/identity/mfa"
"github.com/hashicorp/vault/helper/metricsutil"
"github.com/hashicorp/vault/helper/namespace"
"github.com/hashicorp/vault/helper/osutil"
@ -2139,7 +2140,6 @@ func (s standardUnsealStrategy) unseal(ctx context.Context, logger log.Logger, c
if err := c.setupQuotas(ctx, false); err != nil {
return err
}
c.setupCachedMFAResponseAuth()
if err := c.setupHeaderHMACKey(ctx, false); err != nil {
return err
@ -2161,9 +2161,14 @@ func (s standardUnsealStrategy) unseal(ctx context.Context, logger log.Logger, c
if err := c.loadIdentityStoreArtifacts(ctx); err != nil {
return err
}
if err := loadMFAConfigs(ctx, c); err != nil {
if err := loadPolicyMFAConfigs(ctx, c); err != nil {
return err
}
c.setupCachedMFAResponseAuth()
if err := c.loadLoginMFAConfigs(ctx); err != nil {
return err
}
if err := c.setupAuditedHeadersConfig(ctx); err != nil {
return err
}
@ -2325,10 +2330,6 @@ func (c *Core) preSeal() error {
result = multierror.Append(result, fmt.Errorf("error stopping expiration: %w", err))
}
c.stopActivityLog()
// Clear any cached auth response
c.mfaResponseAuthQueueLock.Lock()
c.mfaResponseAuthQueue = nil
c.mfaResponseAuthQueueLock.Unlock()
if err := c.teardownCredentials(context.Background()); err != nil {
result = multierror.Append(result, fmt.Errorf("error tearing down credentials: %w", err))
@ -2356,10 +2357,13 @@ func (c *Core) preSeal() error {
seal.StopHealthCheck()
}
c.loginMFABackend.usedCodes = nil
if c.systemBackend != nil && c.systemBackend.mfaBackend != nil {
c.systemBackend.mfaBackend.usedCodes = nil
}
if err := c.teardownLoginMFA(); err != nil {
result = multierror.Append(result, fmt.Errorf("error tearing down login MFA, error: %w", err))
}
preSealPhysical(c)
c.logger.Info("pre-seal teardown complete")
@ -3073,6 +3077,31 @@ type LicenseState struct {
Terminated bool
}
func (c *Core) loadLoginMFAConfigs(ctx context.Context) error {
eConfigs := make([]*mfa.MFAEnforcementConfig, 0)
allNamespaces := c.collectNamespaces()
for _, ns := range allNamespaces {
err := c.loginMFABackend.loadMFAMethodConfigs(ctx, ns)
if err != nil {
return fmt.Errorf("error loading MFA method Config, namespaceid %s, error: %w", ns.ID, err)
}
loadedConfigs, err := c.loginMFABackend.loadMFAEnforcementConfigs(ctx, ns)
if err != nil {
return fmt.Errorf("error loading MFA enforcement Config, namespaceid %s, error: %w", ns.ID, err)
}
eConfigs = append(eConfigs, loadedConfigs...)
}
for _, conf := range eConfigs {
if err := c.loginMFABackend.loginMFAMethodExistenceCheck(conf); err != nil {
c.loginMFABackend.mfaLogger.Error("failed to find all MFA methods that exist in MFA enforcement configs", "configID", conf.ID, "namespaceID", conf.NamespaceID, "error", err.Error())
}
}
return nil
}
type MFACachedAuthResponse struct {
CachedAuth *logical.Auth
RequestPath string

View File

@ -115,7 +115,7 @@ func postUnsealPhysical(c *Core) error {
return nil
}
func loadMFAConfigs(context.Context, *Core) error { return nil }
func loadPolicyMFAConfigs(context.Context, *Core) error { return nil }
func shouldStartClusterListener(*Core) bool { return true }

View File

@ -122,6 +122,18 @@ func NewLoginMFABackend(core *Core, logger hclog.Logger) *LoginMFABackend {
}
func NewMFABackend(core *Core, logger hclog.Logger, prefix string, schemaFuncs []func() *memdb.TableSchema) *MFABackend {
db, _ := SetupMFAMemDB(schemaFuncs)
return &MFABackend{
Core: core,
mfaLock: &sync.RWMutex{},
db: db,
mfaLogger: logger.Named("mfa"),
namespacer: core,
methodTable: prefix,
}
}
func SetupMFAMemDB(schemaFuncs []func() *memdb.TableSchema) (*memdb.MemDB, error) {
mfaSchemas := &memdb.DBSchema{
Tables: make(map[string]*memdb.TableSchema),
}
@ -134,15 +146,24 @@ func NewMFABackend(core *Core, logger hclog.Logger, prefix string, schemaFuncs [
mfaSchemas.Tables[schema.Name] = schema
}
db, _ := memdb.NewMemDB(mfaSchemas)
return &MFABackend{
Core: core,
mfaLock: &sync.RWMutex{},
db: db,
mfaLogger: logger.Named("mfa"),
namespacer: core,
methodTable: prefix,
db, err := memdb.NewMemDB(mfaSchemas)
if err != nil {
return nil, err
}
return db, nil
}
func (b *LoginMFABackend) ResetLoginMFAMemDB() error {
var err error
db, err := SetupMFAMemDB(loginMFASchemaFuncs())
if err != nil {
return err
}
b.db = db
return nil
}
func (i *IdentityStore) handleMFAMethodListTOTP(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
@ -474,6 +495,103 @@ func (i *IdentityStore) handleLoginMFAAdminDestroyUpdate(ctx context.Context, re
return nil, nil
}
// loadMFAMethodConfigs loads MFA method configs for login MFA
func (b *LoginMFABackend) loadMFAMethodConfigs(ctx context.Context, ns *namespace.Namespace) error {
b.mfaLogger.Trace("loading login MFA configurations")
barrierView, err := b.Core.barrierViewForNamespace(ns.ID)
if err != nil {
return fmt.Errorf("error getting namespace view, namespaceid %s, error %w", ns.ID, err)
}
existing, err := barrierView.List(ctx, loginMFAConfigPrefix)
if err != nil {
return fmt.Errorf("failed to list MFA configurations for namespace path %s and prefix %s: %w", ns.Path, loginMFAConfigPrefix, err)
}
b.mfaLogger.Trace("methods collected", "num_existing", len(existing))
for _, key := range existing {
b.mfaLogger.Trace("loading method", "method", key)
// Read the config from storage
mConfig, err := b.getMFAConfig(ctx, loginMFAConfigPrefix+key, barrierView)
if err != nil {
return err
}
if mConfig == nil {
b.mfaLogger.Trace("failed to find the config related to a method", "namespace", ns.Path, "prefix", loginMFAConfigPrefix, "method", key)
continue
}
// Load the config in MemDB
err = b.MemDBUpsertMFAConfig(ctx, mConfig)
if err != nil {
return fmt.Errorf("failed to load configuration ID %s prefix %s in MemDB: %w", mConfig.ID, loginMFAConfigPrefix, err)
}
}
b.mfaLogger.Trace("configurations restored", "namespace", ns.Path, "prefix", loginMFAConfigPrefix)
return nil
}
// loadMFAEnforcementConfigs loads MFA method configs for login MFA
func (b *LoginMFABackend) loadMFAEnforcementConfigs(ctx context.Context, ns *namespace.Namespace) ([]*mfa.MFAEnforcementConfig, error) {
b.mfaLogger.Trace("loading login MFA enforcement configurations")
barrierView, err := b.Core.barrierViewForNamespace(ns.ID)
if err != nil {
return nil, fmt.Errorf("error getting namespace view, namespaceid %s, error %w", ns.ID, err)
}
existing, err := barrierView.List(ctx, mfaLoginEnforcementPrefix)
if err != nil {
return nil, fmt.Errorf("failed to list MFA enforcement configurations for namespace %s with prefix %s: %w", ns.Path, mfaLoginEnforcementPrefix, err)
}
b.mfaLogger.Trace("enforcements configs collected", "num_existing", len(existing))
eConfigs := make([]*mfa.MFAEnforcementConfig, 0)
for _, key := range existing {
b.mfaLogger.Trace("loading enforcement", "config", key)
// Read the config from storage
mConfig, err := b.getMFALoginEnforcementConfig(ctx, mfaLoginEnforcementPrefix+key, barrierView)
if err != nil {
return nil, err
}
if mConfig == nil {
b.mfaLogger.Trace("failed to find an enforcement config", "namespace", ns.Path, "prefix", mfaLoginEnforcementPrefix, "config", key)
continue
}
// Load the config in MemDB
err = b.MemDBUpsertMFALoginEnforcementConfig(ctx, mConfig)
if err != nil {
return nil, fmt.Errorf("failed to load enforcement configuration ID %s with prefix %s in MemDB: %w", mConfig.ID, mfaLoginEnforcementPrefix, err)
}
eConfigs = append(eConfigs, mConfig)
}
b.mfaLogger.Trace("enforcement configurations restored", "namespace", ns.Path, "prefix", mfaLoginEnforcementPrefix)
return eConfigs, nil
}
func (b *LoginMFABackend) loginMFAMethodExistenceCheck(eConfig *mfa.MFAEnforcementConfig) error {
var aggErr *multierror.Error
for _, confID := range eConfig.MFAMethodIDs {
config, memErr := b.MemDBMFAConfigByID(confID)
if memErr != nil {
aggErr = multierror.Append(aggErr, memErr)
return aggErr.ErrorOrNil()
}
if config == nil {
aggErr = multierror.Append(aggErr, fmt.Errorf("found an MFA method ID in enforcement config, but failed to find the MFA method config method ID %s", confID))
}
}
return aggErr.ErrorOrNil()
}
func (b *LoginMFABackend) handleMFALoginValidate(ctx context.Context, req *logical.Request, d *framework.FieldData) (retResp *logical.Response, retErr error) {
// mfaReqID is the ID of the login request
mfaReqID := d.Get("mfa_request_id").(string)
@ -551,6 +669,22 @@ func (b *LoginMFABackend) handleMFALoginValidate(ctx context.Context, req *logic
return resp, nil
}
func (c *Core) teardownLoginMFA() error {
if !c.IsDRSecondary() {
// Clear any cached auth response
c.mfaResponseAuthQueueLock.Lock()
c.mfaResponseAuthQueue = nil
c.mfaResponseAuthQueueLock.Unlock()
c.loginMFABackend.usedCodes = nil
if err := c.loginMFABackend.ResetLoginMFAMemDB(); err != nil {
return err
}
}
return nil
}
// LoginMFACreateToken creates a token after the login MFA is validated.
// It also applies the lease quotas on the original login request path.
func (c *Core) LoginMFACreateToken(ctx context.Context, reqPath string, cachedAuth *logical.Auth) (*logical.Response, error) {
@ -2320,7 +2454,6 @@ func (b *LoginMFABackend) deleteMFALoginEnforcementConfigByNameAndNamespace(ctx
}
entryIndex := mfaLoginEnforcementPrefix + eConfig.ID
barrierView, err := b.Core.barrierViewForNamespace(eConfig.NamespaceID)
if err != nil {
return err
@ -2530,6 +2663,25 @@ func (b *MFABackend) getMFAConfig(ctx context.Context, path string, barrierView
return &mConfig, nil
}
func (b *LoginMFABackend) getMFALoginEnforcementConfig(ctx context.Context, path string, barrierView *BarrierView) (*mfa.MFAEnforcementConfig, error) {
entry, err := barrierView.Get(ctx, path)
if err != nil {
return nil, err
}
if entry == nil {
return nil, nil
}
var mConfig mfa.MFAEnforcementConfig
err = proto.Unmarshal(entry.Value, &mConfig)
if err != nil {
return nil, err
}
return &mConfig, nil
}
func (b *LoginMFABackend) putMFALoginEnforcementConfig(ctx context.Context, eConfig *mfa.MFAEnforcementConfig) error {
entryIndex := mfaLoginEnforcementPrefix + eConfig.ID
marshaledEntry, err := proto.Marshal(eConfig)
@ -2548,28 +2700,6 @@ func (b *LoginMFABackend) putMFALoginEnforcementConfig(ctx context.Context, eCon
})
}
func (b *LoginMFABackend) getMFALoginEnforcementConfig(ctx context.Context, key, namespaceId string) (*mfa.MFAEnforcementConfig, error) {
barrierView, err := b.Core.barrierViewForNamespace(namespaceId)
if err != nil {
return nil, err
}
entry, err := barrierView.Get(ctx, mfaLoginEnforcementPrefix+key)
if err != nil {
return nil, err
}
if entry == nil {
return nil, nil
}
var eConfig mfa.MFAEnforcementConfig
err = proto.Unmarshal(entry.Value, &eConfig)
if err != nil {
return nil, err
}
return &eConfig, nil
}
var mfaHelp = map[string][2]string{
"methods-list": {
"Lists all the available MFA methods by their name.",