diff --git a/command/base.go b/command/base.go index 9521e1722..ea677514e 100644 --- a/command/base.go +++ b/command/base.go @@ -15,6 +15,8 @@ import ( "github.com/hashicorp/vault/api" "github.com/hashicorp/vault/command/token" "github.com/hashicorp/vault/helper/namespace" + "github.com/hashicorp/vault/sdk/logical" + "github.com/mattn/go-isatty" "github.com/mitchellh/cli" "github.com/pkg/errors" "github.com/posener/complete" @@ -212,6 +214,90 @@ func (c *BaseCommand) DefaultWrappingLookupFunc(operation, path string) string { return api.DefaultWrappingLookupFunc(operation, path) } +func (c *BaseCommand) isInteractiveEnabled(mfaConstraintLen int) bool { + if mfaConstraintLen != 1 || !isatty.IsTerminal(os.Stdin.Fd()) { + return false + } + + if !c.flagNonInteractive { + return true + } + + return false +} + +// getMFAMethodInfo returns MFA method information only if one MFA method is +// configured. +func (c *BaseCommand) getMFAMethodInfo(mfaConstraintAny map[string]*logical.MFAConstraintAny) MFAMethodInfo { + for _, mfaConstraint := range mfaConstraintAny { + if len(mfaConstraint.Any) != 1 { + return MFAMethodInfo{} + } + + return MFAMethodInfo{ + methodType: mfaConstraint.Any[0].Type, + methodID: mfaConstraint.Any[0].ID, + usePasscode: mfaConstraint.Any[0].UsesPasscode, + } + } + + return MFAMethodInfo{} +} + +func (c *BaseCommand) validateMFA(reqID string, methodInfo MFAMethodInfo) int { + var passcode string + var err error + if methodInfo.usePasscode { + passcode, err = c.UI.AskSecret(fmt.Sprintf("Enter the passphrase for methodID %q of type %q:", methodInfo.methodID, methodInfo.methodType)) + if err != nil { + c.UI.Error(fmt.Sprintf("failed to read the passphrase with error %q. please validate the login by sending a request to sys/mfa/validate", err.Error())) + return 2 + } + } else { + c.UI.Warn("Asking Vault to perform MFA validation with upstream service. " + + "You should receive a push notification in your authenticator app shortly") + } + + // passcode could be an empty string + mfaPayload := map[string][]string{ + methodInfo.methodID: {passcode}, + } + + client, err := c.Client() + if err != nil { + c.UI.Error(err.Error()) + return 2 + } + + path := "sys/mfa/validate" + + secret, err := client.Logical().Write(path, map[string]interface{}{ + "mfa_request_id": reqID, + "mfa_payload": mfaPayload, + }) + if err != nil { + c.UI.Error(err.Error()) + if secret != nil { + OutputSecret(c.UI, secret) + } + return 2 + } + if secret == nil { + // Don't output anything unless using the "table" format + if Format(c.UI) == "table" { + c.UI.Info(fmt.Sprintf("Success! Data written to: %s", path)) + } + return 0 + } + + // Handle single field output + if c.flagField != "" { + return PrintRawField(c.UI, secret, c.flagField) + } + + return OutputSecret(c.UI, secret) +} + type FlagSetBit uint const ( diff --git a/command/login.go b/command/login.go index 1128bf962..4a56075e4 100644 --- a/command/login.go +++ b/command/login.go @@ -228,6 +228,20 @@ func (c *LoginCommand) Run(args []string) int { return 2 } + if secret != nil && secret.Auth != nil && secret.Auth.MFARequirement != nil { + if c.isInteractiveEnabled(len(secret.Auth.MFARequirement.MFAConstraints)) { + // Currently, if there is only one MFA method configured, the login + // request is validated interactively + methodInfo := c.getMFAMethodInfo(secret.Auth.MFARequirement.MFAConstraints) + if methodInfo.methodID != "" { + return c.validateMFA(secret.Auth.MFARequirement.MFARequestID, methodInfo) + } + } + c.UI.Warn(wrapAtLength("A login request was issued that is subject to "+ + "MFA validation. Please make sure to validate the login by sending another "+ + "request to sys/mfa/validate endpoint.") + "\n") + } + // Unset any previous token wrapping functionality. If the original request // was for a wrapped token, we don't want future requests to be wrapped. client.SetWrappingLookupFunc(func(string, string) string { return "" }) diff --git a/command/write.go b/command/write.go index 0de7299eb..4990cde71 100644 --- a/command/write.go +++ b/command/write.go @@ -6,8 +6,6 @@ import ( "os" "strings" - "github.com/hashicorp/vault/sdk/logical" - "github.com/mattn/go-isatty" "github.com/mitchellh/cli" "github.com/posener/complete" ) @@ -176,87 +174,3 @@ func (c *WriteCommand) Run(args []string) int { return OutputSecret(c.UI, secret) } - -func (c *WriteCommand) isInteractiveEnabled(mfaConstraintLen int) bool { - if mfaConstraintLen != 1 || !isatty.IsTerminal(os.Stdin.Fd()) { - return false - } - - if !c.flagNonInteractive { - return true - } - - return false -} - -// getMFAMethodInfo returns MFA method information only if one MFA method is -// configured. -func (c *WriteCommand) getMFAMethodInfo(mfaConstraintAny map[string]*logical.MFAConstraintAny) MFAMethodInfo { - for _, mfaConstraint := range mfaConstraintAny { - if len(mfaConstraint.Any) != 1 { - return MFAMethodInfo{} - } - - return MFAMethodInfo{ - methodType: mfaConstraint.Any[0].Type, - methodID: mfaConstraint.Any[0].ID, - usePasscode: mfaConstraint.Any[0].UsesPasscode, - } - } - - return MFAMethodInfo{} -} - -func (c *WriteCommand) validateMFA(reqID string, methodInfo MFAMethodInfo) int { - var passcode string - var err error - if methodInfo.usePasscode { - passcode, err = c.UI.AskSecret(fmt.Sprintf("Enter the passphrase for methodID %q of type %q:", methodInfo.methodID, methodInfo.methodType)) - if err != nil { - c.UI.Error(fmt.Sprintf("failed to read the passphrase with error %q. please validate the login by sending a request to sys/mfa/validate", err.Error())) - return 2 - } - } else { - c.UI.Warn("Asking Vault to perform MFA validation with upstream service. " + - "You should receive a push notification in your authenticator app shortly") - } - - // passcode could be an empty string - mfaPayload := map[string][]string{ - methodInfo.methodID: {passcode}, - } - - client, err := c.Client() - if err != nil { - c.UI.Error(err.Error()) - return 2 - } - - path := "sys/mfa/validate" - - secret, err := client.Logical().Write(path, map[string]interface{}{ - "mfa_request_id": reqID, - "mfa_payload": mfaPayload, - }) - if err != nil { - c.UI.Error(err.Error()) - if secret != nil { - OutputSecret(c.UI, secret) - } - return 2 - } - if secret == nil { - // Don't output anything unless using the "table" format - if Format(c.UI) == "table" { - c.UI.Info(fmt.Sprintf("Success! Data written to: %s", path)) - } - return 0 - } - - // Handle single field output - if c.flagField != "" { - return PrintRawField(c.UI, secret, c.flagField) - } - - return OutputSecret(c.UI, secret) -} diff --git a/vault/login_mfa.go b/vault/login_mfa.go index 4de0bb553..25e762d13 100644 --- a/vault/login_mfa.go +++ b/vault/login_mfa.go @@ -2010,15 +2010,11 @@ func (c *Core) validateTOTP(ctx context.Context, creds []string, entityMethodSec return fmt.Errorf("entity does not contain the TOTP secret") } - // Take the key skew, add two for behind and in front, and multiply that by - // the period to cover the full possibility of the validity of the key - validityPeriod := time.Duration(int64(time.Second) * int64(totpSecret.Period) * int64(2+totpSecret.Skew)) - usedName := fmt.Sprintf("%s_%s", configID, creds[0]) _, ok := c.loginMFABackend.usedCodes.Get(usedName) if ok { - return fmt.Errorf("code already used; new code is available in %v seconds", validityPeriod) + return fmt.Errorf("code already used; new code is available in %v seconds", totpSecret.Period) } key, err := c.fetchTOTPKey(ctx, configID, entityID) @@ -2046,6 +2042,10 @@ func (c *Core) validateTOTP(ctx context.Context, creds []string, entityMethodSec return fmt.Errorf("failed to validate TOTP passcode") } + // Take the key skew, add two for behind and in front, and multiply that by + // the period to cover the full possibility of the validity of the key + validityPeriod := time.Duration(int64(time.Second) * int64(totpSecret.Period) * int64(2+totpSecret.Skew)) + // Adding the used code to the cache err = c.loginMFABackend.usedCodes.Add(usedName, nil, validityPeriod) if err != nil {