diff --git a/changelog/17040.txt b/changelog/17040.txt new file mode 100644 index 000000000..add116d1e --- /dev/null +++ b/changelog/17040.txt @@ -0,0 +1,3 @@ +```release-note:bug +login: Store token in tokenhelper for interactive login MFA +``` diff --git a/command/base.go b/command/base.go index 3aa5bb749..3655494ab 100644 --- a/command/base.go +++ b/command/base.go @@ -15,7 +15,6 @@ import ( "github.com/hashicorp/vault/api" "github.com/hashicorp/vault/command/token" "github.com/hashicorp/vault/helper/namespace" - "github.com/hashicorp/vault/sdk/logical" "github.com/mattn/go-isatty" "github.com/mitchellh/cli" "github.com/pkg/errors" @@ -220,44 +219,55 @@ func (c *BaseCommand) DefaultWrappingLookupFunc(operation, path string) string { return api.DefaultWrappingLookupFunc(operation, path) } -func (c *BaseCommand) isInteractiveEnabled(mfaConstraintLen int) bool { - if mfaConstraintLen != 1 || !isatty.IsTerminal(os.Stdin.Fd()) { - return false - } - - if !c.flagNonInteractive { - return true +// getValidationRequired checks to see if the secret exists and has an MFA +// requirement. If MFA is required and the number of constraints is greater than +// 1, we can assert that interactive validation is not required. +func (c *BaseCommand) getMFAValidationRequired(secret *api.Secret) bool { + if secret != nil && secret.Auth != nil && secret.Auth.MFARequirement != nil { + if c.flagMFA == nil && len(secret.Auth.MFARequirement.MFAConstraints) == 1 { + return true + } else if len(secret.Auth.MFARequirement.MFAConstraints) > 1 { + return true + } } return false } -// getMFAMethodInfo returns MFA method information only if one MFA method is -// configured. -func (c *BaseCommand) getMFAMethodInfo(mfaConstraintAny map[string]*logical.MFAConstraintAny) MFAMethodInfo { - for _, mfaConstraint := range mfaConstraintAny { +// getInteractiveMFAMethodInfo returns MFA method information only if operating +// in interactive mode and one MFA method is configured. +func (c *BaseCommand) getInteractiveMFAMethodInfo(secret *api.Secret) *MFAMethodInfo { + if secret == nil || secret.Auth == nil || secret.Auth.MFARequirement == nil { + return nil + } + + mfaConstraints := secret.Auth.MFARequirement.MFAConstraints + if c.flagNonInteractive || len(mfaConstraints) != 1 || !isatty.IsTerminal(os.Stdin.Fd()) { + return nil + } + + for _, mfaConstraint := range mfaConstraints { if len(mfaConstraint.Any) != 1 { - return MFAMethodInfo{} + return nil } - return MFAMethodInfo{ + return &MFAMethodInfo{ methodType: mfaConstraint.Any[0].Type, methodID: mfaConstraint.Any[0].ID, usePasscode: mfaConstraint.Any[0].UsesPasscode, } } - return MFAMethodInfo{} + return nil } -func (c *BaseCommand) validateMFA(reqID string, methodInfo MFAMethodInfo) int { +func (c *BaseCommand) validateMFA(reqID string, methodInfo MFAMethodInfo) (*api.Secret, error) { var passcode string var err error if methodInfo.usePasscode { passcode, err = c.UI.AskSecret(fmt.Sprintf("Enter the passphrase for methodID %q of type %q:", methodInfo.methodID, methodInfo.methodType)) if err != nil { - c.UI.Error(fmt.Sprintf("failed to read the passphrase with error %q. please validate the login by sending a request to sys/mfa/validate", err.Error())) - return 2 + return nil, fmt.Errorf("failed to read passphrase: %w. please validate the login by sending a request to sys/mfa/validate", err) } } else { c.UI.Warn("Asking Vault to perform MFA validation with upstream service. " + @@ -271,32 +281,10 @@ func (c *BaseCommand) validateMFA(reqID string, methodInfo MFAMethodInfo) int { client, err := c.Client() if err != nil { - c.UI.Error(err.Error()) - return 2 + return nil, err } - secret, err := client.Sys().MFAValidate(reqID, mfaPayload) - if err != nil { - c.UI.Error(err.Error()) - if secret != nil { - OutputSecret(c.UI, secret) - } - return 2 - } - if secret == nil { - // Don't output anything unless using the "table" format - if Format(c.UI) == "table" { - c.UI.Info("Success! Data written to: sys/mfa/validate") - } - return 0 - } - - // Handle single field output - if c.flagField != "" { - return PrintRawField(c.UI, secret, c.flagField) - } - - return OutputSecret(c.UI, secret) + return client.Sys().MFAValidate(reqID, mfaPayload) } type FlagSetBit uint diff --git a/command/login.go b/command/login.go index 9beab755b..f46d26490 100644 --- a/command/login.go +++ b/command/login.go @@ -228,21 +228,27 @@ func (c *LoginCommand) Run(args []string) int { return 2 } - if secret != nil && secret.Auth != nil && secret.Auth.MFARequirement != nil { - if c.isInteractiveEnabled(len(secret.Auth.MFARequirement.MFAConstraints)) { - // Currently, if there is only one MFA method configured, the login - // request is validated interactively - methodInfo := c.getMFAMethodInfo(secret.Auth.MFARequirement.MFAConstraints) - if methodInfo.methodID != "" { - return c.validateMFA(secret.Auth.MFARequirement.MFARequestID, methodInfo) - } + // If there is only one MFA method configured and c.NonInteractive flag is + // unset, the login request is validated interactively. + // + // interactiveMethodInfo here means that `validateMFA` will complete the MFA + // by prompting for a password or directing you to a push notification. In + // this scenario, no external validation is needed. + interactiveMethodInfo := c.getInteractiveMFAMethodInfo(secret) + if interactiveMethodInfo != nil { + c.UI.Warn("Initiating Iteractive MFA Validation...") + secret, err = c.validateMFA(secret.Auth.MFARequirement.MFARequestID, *interactiveMethodInfo) + if err != nil { + c.UI.Error(err.Error()) + return 2 } + } else if c.getMFAValidationRequired(secret) { + // Warn about existing login token, but return here, since the secret + // won't have any token information if further validation is required. + c.checkForAndWarnAboutLoginToken() c.UI.Warn(wrapAtLength("A login request was issued that is subject to "+ "MFA validation. Please make sure to validate the login by sending another "+ "request to sys/mfa/validate endpoint.") + "\n") - - // We return early to prevent success message from being printed - c.checkForAndWarnAboutLoginToken() return OutputSecret(c.UI, secret) } diff --git a/command/login_test.go b/command/login_test.go index aefdd2585..56ed790f3 100644 --- a/command/login_test.go +++ b/command/login_test.go @@ -10,6 +10,7 @@ import ( credToken "github.com/hashicorp/vault/builtin/credential/token" credUserpass "github.com/hashicorp/vault/builtin/credential/userpass" "github.com/hashicorp/vault/command/token" + "github.com/hashicorp/vault/helper/testhelpers" "github.com/hashicorp/vault/vault" ) @@ -428,6 +429,91 @@ func TestLoginCommand_Run(t *testing.T) { } }) + t.Run("login_mfa_single_phase", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + ui, cmd := testLoginCommand(t) + + userclient, entityID, methodID := testhelpers.SetupLoginMFATOTP(t, client) + cmd.client = userclient + + enginePath := testhelpers.RegisterEntityInTOTPEngine(t, client, entityID, methodID) + totpCode := testhelpers.GetTOTPCodeFromEngine(t, client, enginePath) + + // login command bails early for test clients, so we have to explicitly set this + cmd.client.SetMFACreds([]string{methodID + ":" + totpCode}) + code := cmd.Run([]string{ + "-method", "userpass", + "username=testuser1", + "password=testpassword", + }) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + tokenHelper, err := cmd.TokenHelper() + if err != nil { + t.Fatal(err) + } + storedToken, err := tokenHelper.Get() + if err != nil { + t.Fatal(err) + } + output = ui.OutputWriter.String() + ui.ErrorWriter.String() + t.Logf("\n%+v", output) + if !strings.Contains(output, storedToken) { + t.Fatalf("expected stored token: %q, got: %q", storedToken, output) + } + }) + + t.Run("login_mfa_two_phase", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + ui, cmd := testLoginCommand(t) + + userclient, entityID, methodID := testhelpers.SetupLoginMFATOTP(t, client) + cmd.client = userclient + + _ = testhelpers.RegisterEntityInTOTPEngine(t, client, entityID, methodID) + + // clear the MFA creds just to be sure + cmd.client.SetMFACreds([]string{}) + + code := cmd.Run([]string{ + "-method", "userpass", + "username=testuser1", + "password=testpassword", + }) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := methodID + output = ui.OutputWriter.String() + ui.ErrorWriter.String() + t.Logf("\n%+v", output) + if !strings.Contains(output, expected) { + t.Fatalf("expected stored token: %q, got: %q", expected, output) + } + + tokenHelper, err := cmd.TokenHelper() + if err != nil { + t.Fatal(err) + } + storedToken, err := tokenHelper.Get() + if storedToken != "" { + t.Fatal("expected empty stored token") + } + if err != nil { + t.Fatal(err) + } + }) + t.Run("communication_failure", func(t *testing.T) { t.Parallel() diff --git a/command/write.go b/command/write.go index daea96c4c..3daa2bae6 100644 --- a/command/write.go +++ b/command/write.go @@ -158,15 +158,16 @@ func handleWriteSecretOutput(c *BaseCommand, path string, secret *api.Secret, er return 0 } - if secret != nil && secret.Auth != nil && secret.Auth.MFARequirement != nil { - if c.isInteractiveEnabled(len(secret.Auth.MFARequirement.MFAConstraints)) { - // Currently, if there is only one MFA method configured, the login - // request is validated interactively - methodInfo := c.getMFAMethodInfo(secret.Auth.MFARequirement.MFAConstraints) - if methodInfo.methodID != "" { - return c.validateMFA(secret.Auth.MFARequirement.MFARequestID, methodInfo) - } + // Currently, if there is only one MFA method configured, the login + // request is validated interactively + methodInfo := c.getInteractiveMFAMethodInfo(secret) + if methodInfo != nil { + secret, err = c.validateMFA(secret.Auth.MFARequirement.MFARequestID, *methodInfo) + if err != nil { + c.UI.Error(err.Error()) + return 2 } + } else if c.getMFAValidationRequired(secret) { c.UI.Warn(wrapAtLength("A login request was issued that is subject to "+ "MFA validation. Please make sure to validate the login by sending another "+ "request to sys/mfa/validate endpoint.") + "\n") diff --git a/helper/testhelpers/testhelpers.go b/helper/testhelpers/testhelpers.go index 9c07ea523..d124182fd 100644 --- a/helper/testhelpers/testhelpers.go +++ b/helper/testhelpers/testhelpers.go @@ -779,3 +779,200 @@ func RetryUntil(t testing.T, timeout time.Duration, f func() error) { } t.Fatalf("did not complete before deadline, err: %v", err) } + +// CreateEntityAndAlias clones an existing client and creates an entity/alias. +// It returns the cloned client, entityID, and aliasID. +func CreateEntityAndAlias(t testing.T, client *api.Client, mountAccessor, entityName, aliasName string) (*api.Client, string, string) { + t.Helper() + userClient, err := client.Clone() + if err != nil { + t.Fatalf("failed to clone the client:%v", err) + } + userClient.SetToken(client.Token()) + + resp, err := client.Logical().WriteWithContext(context.Background(), "identity/entity", map[string]interface{}{ + "name": entityName, + }) + if err != nil { + t.Fatalf("failed to create an entity:%v", err) + } + entityID := resp.Data["id"].(string) + + aliasResp, err := client.Logical().WriteWithContext(context.Background(), "identity/entity-alias", map[string]interface{}{ + "name": aliasName, + "canonical_id": entityID, + "mount_accessor": mountAccessor, + }) + if err != nil { + t.Fatalf("failed to create an entity alias:%v", err) + } + aliasID := aliasResp.Data["id"].(string) + if aliasID == "" { + t.Fatal("Alias ID not present in response") + } + _, err = client.Logical().WriteWithContext(context.Background(), fmt.Sprintf("auth/userpass/users/%s", aliasName), map[string]interface{}{ + "password": "testpassword", + }) + if err != nil { + t.Fatalf("failed to configure userpass backend: %v", err) + } + + return userClient, entityID, aliasID +} + +// SetupTOTPMount enables the totp secrets engine by mounting it. This requires +// that the test cluster has a totp backend available. +func SetupTOTPMount(t testing.T, client *api.Client) { + t.Helper() + // Mount the TOTP backend + mountInfo := &api.MountInput{ + Type: "totp", + } + if err := client.Sys().Mount("totp", mountInfo); err != nil { + t.Fatalf("failed to mount totp backend: %v", err) + } +} + +// SetupTOTPMethod configures the TOTP secrets engine with a provided config map. +func SetupTOTPMethod(t testing.T, client *api.Client, config map[string]interface{}) string { + t.Helper() + + resp1, err := client.Logical().Write("identity/mfa/method/totp", config) + + if err != nil || (resp1 == nil) { + t.Fatalf("bad: resp: %#v\n err: %v", resp1, err) + } + + methodID := resp1.Data["method_id"].(string) + if methodID == "" { + t.Fatalf("method ID is empty") + } + + return methodID +} + +// SetupMFALoginEnforcement configures a single enforcement method using the +// provided config map. "name" field is required in the config map. +func SetupMFALoginEnforcement(t testing.T, client *api.Client, config map[string]interface{}) { + t.Helper() + enfName, ok := config["name"] + if !ok { + t.Fatalf("couldn't find name in login-enforcement config") + } + _, err := client.Logical().WriteWithContext(context.Background(), fmt.Sprintf("identity/mfa/login-enforcement/%s", enfName), config) + if err != nil { + t.Fatalf("failed to configure MFAEnforcementConfig: %v", err) + } +} + +// SetupUserpassMountAccessor sets up userpass auth and returns its mount +// accessor. This requires that the test cluster has a "userpass" auth method +// available. +func SetupUserpassMountAccessor(t testing.T, client *api.Client) string { + t.Helper() + var mountAccessor string + // Enable Userpass authentication + err := client.Sys().EnableAuthWithOptions("userpass", &api.EnableAuthOptions{ + Type: "userpass", + }) + if err != nil { + t.Fatalf("failed to enable userpass auth: %v", err) + } + + auths, err := client.Sys().ListAuthWithContext(context.Background()) + if err != nil { + t.Fatalf("failed to list auth methods: %v", err) + } + if auths != nil && auths["userpass/"] != nil { + mountAccessor = auths["userpass/"].Accessor + } else { + t.Fatalf("failed to get userpass mount accessor") + } + + return mountAccessor +} + +// RegisterEntityInTOTPEngine registers an entity with a methodID and returns +// the generated name. +func RegisterEntityInTOTPEngine(t testing.T, client *api.Client, entityID, methodID string) string { + t.Helper() + totpGenName := fmt.Sprintf("%s-%s", entityID, methodID) + secret, err := client.Logical().WriteWithContext(context.Background(), fmt.Sprintf("identity/mfa/method/totp/admin-generate"), map[string]interface{}{ + "entity_id": entityID, + "method_id": methodID, + }) + if err != nil { + t.Fatalf("failed to generate a TOTP secret on an entity: %v", err) + } + totpURL := secret.Data["url"].(string) + if totpURL == "" { + t.Fatalf("failed to get TOTP url in secret response: %+v", secret) + } + _, err = client.Logical().WriteWithContext(context.Background(), fmt.Sprintf("totp/keys/%s", totpGenName), map[string]interface{}{ + "url": totpURL, + }) + if err != nil { + t.Fatalf("failed to register a TOTP URL: %v", err) + } + _, err = client.Logical().WriteWithContext(context.Background(), "identity/mfa/login-enforcement/randomName", map[string]interface{}{ + "name": "randomName", + "identity_entity_ids": []string{entityID}, + "mfa_method_ids": []string{methodID}, + }) + if err != nil { + t.Fatalf("failed to create login enforcement") + } + + return totpGenName +} + +// GetTOTPCodeFromEngine requests a TOTP code from the specified enginePath. +func GetTOTPCodeFromEngine(t testing.T, client *api.Client, enginePath string) string { + t.Helper() + totpPath := fmt.Sprintf("totp/code/%s", enginePath) + secret, err := client.Logical().ReadWithContext(context.Background(), totpPath) + if err != nil { + t.Fatalf("failed to create totp passcode: %v", err) + } + if secret == nil { + t.Fatalf("bad secret returned from %s", totpPath) + } + return secret.Data["code"].(string) +} + +// SetupLoginMFATOTP setups up a TOTP MFA using some basic configuration and +// returns all relevant information to the client. +func SetupLoginMFATOTP(t testing.T, client *api.Client) (*api.Client, string, string) { + t.Helper() + // Mount the totp secrets engine + SetupTOTPMount(t, client) + + // Create a mount accessor to associate with an entity + mountAccessor := SetupUserpassMountAccessor(t, client) + + // Create a test entity and alias + entityClient, entityID, _ := CreateEntityAndAlias(t, client, mountAccessor, "entity1", "testuser1") + + // Configure a default TOTP method + totpConfig := map[string]interface{}{ + "issuer": "yCorp", + "period": 5, + "algorithm": "SHA256", + "digits": 6, + "skew": 0, + "key_size": 20, + "qr_size": 200, + "max_validation_attempts": 5, + } + methodID := SetupTOTPMethod(t, client, totpConfig) + + // Configure a default login enforcement + enforcementConfig := map[string]interface{}{ + "auth_method_types": []string{"userpass"}, + "name": "randomName", + "mfa_method_ids": []string{methodID}, + } + + SetupMFALoginEnforcement(t, client, enforcementConfig) + return entityClient, entityID, methodID +} diff --git a/vault/external_tests/identity/aliases_test.go b/vault/external_tests/identity/aliases_test.go index d58540f8c..747bd6a6f 100644 --- a/vault/external_tests/identity/aliases_test.go +++ b/vault/external_tests/identity/aliases_test.go @@ -12,6 +12,7 @@ import ( auth "github.com/hashicorp/vault/api/auth/userpass" "github.com/hashicorp/vault/builtin/credential/github" "github.com/hashicorp/vault/builtin/credential/userpass" + "github.com/hashicorp/vault/helper/testhelpers" vaulthttp "github.com/hashicorp/vault/http" "github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/vault" @@ -301,7 +302,7 @@ func TestIdentityStore_MergeEntities_FailsDueToClash(t *testing.T) { t.Fatal("did not find userpass accessor") } - _, entityIdBob, aliasIdBob := createEntityAndAlias(client, mountAccessor, "bob-smith", "bob", t) + _, entityIdBob, aliasIdBob := testhelpers.CreateEntityAndAlias(t, client, mountAccessor, "bob-smith", "bob") // Create userpass login for alice _, err = client.Logical().Write("auth/userpass/users/alice", map[string]interface{}{ @@ -311,7 +312,7 @@ func TestIdentityStore_MergeEntities_FailsDueToClash(t *testing.T) { t.Fatal(err) } - _, entityIdAlice, aliasIdAlice := createEntityAndAlias(client, mountAccessor, "alice-smith", "alice", t) + _, entityIdAlice, aliasIdAlice := testhelpers.CreateEntityAndAlias(t, client, mountAccessor, "alice-smith", "alice") // Perform entity merge mergeResp, err := client.Logical().Write("identity/entity/merge", map[string]interface{}{ @@ -404,9 +405,9 @@ func TestIdentityStore_MergeEntities_FailsDueToClashInFromEntities(t *testing.T) t.Fatal("did not find github accessor") } - _, entityIdBob, _ := createEntityAndAlias(client, mountAccessor, "bob-smith", "bob", t) - _, entityIdAlice, _ := createEntityAndAlias(client, mountAccessorGitHub, "alice-smith", "alice", t) - _, entityIdClara, _ := createEntityAndAlias(client, mountAccessorGitHub, "clara-smith", "clara", t) + _, entityIdBob, _ := testhelpers.CreateEntityAndAlias(t, client, mountAccessor, "bob-smith", "bob") + _, entityIdAlice, _ := testhelpers.CreateEntityAndAlias(t, client, mountAccessorGitHub, "alice-smith", "alice") + _, entityIdClara, _ := testhelpers.CreateEntityAndAlias(t, client, mountAccessorGitHub, "clara-smith", "clara") // Perform entity merge mergeResp, err := client.Logical().Write("identity/entity/merge", map[string]interface{}{ @@ -491,7 +492,7 @@ func TestIdentityStore_MergeEntities_FailsDueToDoubleClash(t *testing.T) { t.Fatal("did not find github accessor") } - _, entityIdBob, aliasIdBob := createEntityAndAlias(client, mountAccessor, "bob-smith", "bob", t) + _, entityIdBob, aliasIdBob := testhelpers.CreateEntityAndAlias(t, client, mountAccessor, "bob-smith", "bob") aliasResp, err := client.Logical().Write("identity/entity-alias", map[string]interface{}{ "name": "bob-github", @@ -515,8 +516,8 @@ func TestIdentityStore_MergeEntities_FailsDueToDoubleClash(t *testing.T) { t.Fatal(err) } - _, entityIdAlice, aliasIdAlice := createEntityAndAlias(client, mountAccessor, "alice-smith", "alice", t) - _, entityIdClara, aliasIdClara := createEntityAndAlias(client, mountAccessorGitHub, "clara-smith", "clara", t) + _, entityIdAlice, aliasIdAlice := testhelpers.CreateEntityAndAlias(t, client, mountAccessor, "alice-smith", "alice") + _, entityIdClara, aliasIdClara := testhelpers.CreateEntityAndAlias(t, client, mountAccessorGitHub, "clara-smith", "clara") // Perform entity merge mergeResp, err := client.Logical().Write("identity/entity/merge", map[string]interface{}{ @@ -602,7 +603,7 @@ func TestIdentityStore_MergeEntities_FailsDueToClashInFromEntities_CheckRawReque t.Fatal("did not find userpass accessor") } - _, entityIdBob, _ := createEntityAndAlias(client, mountAccessor, "bob-smith", "bob", t) + _, entityIdBob, _ := testhelpers.CreateEntityAndAlias(t, client, mountAccessor, "bob-smith", "bob") // Create userpass login for alice _, err = client.Logical().Write("auth/userpass/users/alice", map[string]interface{}{ @@ -612,7 +613,7 @@ func TestIdentityStore_MergeEntities_FailsDueToClashInFromEntities_CheckRawReque t.Fatal(err) } - _, entityIdAlice, _ := createEntityAndAlias(client, mountAccessor, "alice-smith", "alice", t) + _, entityIdAlice, _ := testhelpers.CreateEntityAndAlias(t, client, mountAccessor, "alice-smith", "alice") // Perform entity merge as a Raw Request so we can investigate the response body req := client.NewRequest("POST", "/v1/identity/entity/merge") @@ -772,7 +773,7 @@ func TestIdentityStore_MergeEntities_SameMountAccessor_ThenUseAlias(t *testing.T t.Fatal("did not find userpass accessor") } - _, entityIdBob, aliasIdBob := createEntityAndAlias(client, mountAccessor, "bob-smith", "bob", t) + _, entityIdBob, aliasIdBob := testhelpers.CreateEntityAndAlias(t, client, mountAccessor, "bob-smith", "bob") // Create userpass login for alice _, err = client.Logical().Write("auth/userpass/users/alice", map[string]interface{}{ @@ -788,7 +789,7 @@ func TestIdentityStore_MergeEntities_SameMountAccessor_ThenUseAlias(t *testing.T t.Fatal(err) } - _, entityIdAlice, _ := createEntityAndAlias(client, mountAccessor, "alice-smith", "alice", t) + _, entityIdAlice, _ := testhelpers.CreateEntityAndAlias(t, client, mountAccessor, "alice-smith", "alice") // Try and login with alias 2 (alice) pre-merge userpassAuth, err := auth.NewUserpassAuth("alice", &auth.Password{FromString: "testpassword"}) @@ -909,7 +910,7 @@ func TestIdentityStore_MergeEntities_FailsDueToMultipleClashMergesAttempted(t *t t.Fatal("did not find github accessor") } - _, entityIdBob, _ := createEntityAndAlias(client, mountAccessor, "bob-smith", "bob", t) + _, entityIdBob, _ := testhelpers.CreateEntityAndAlias(t, client, mountAccessor, "bob-smith", "bob") aliasResp, err := client.Logical().Write("identity/entity-alias", map[string]interface{}{ "name": "bob-github", "canonical_id": entityIdBob, @@ -932,8 +933,8 @@ func TestIdentityStore_MergeEntities_FailsDueToMultipleClashMergesAttempted(t *t t.Fatal(err) } - _, entityIdAlice, aliasIdAlice := createEntityAndAlias(client, mountAccessor, "alice-smith", "alice", t) - _, entityIdClara, aliasIdClara := createEntityAndAlias(client, mountAccessorGitHub, "clara-smith", "alice", t) + _, entityIdAlice, aliasIdAlice := testhelpers.CreateEntityAndAlias(t, client, mountAccessor, "alice-smith", "alice") + _, entityIdClara, aliasIdClara := testhelpers.CreateEntityAndAlias(t, client, mountAccessorGitHub, "clara-smith", "alice") // Perform entity merge mergeResp, err := client.Logical().Write("identity/entity/merge", map[string]interface{}{ diff --git a/vault/external_tests/identity/login_mfa_totp_test.go b/vault/external_tests/identity/login_mfa_totp_test.go index 15d7bf009..b6f586145 100644 --- a/vault/external_tests/identity/login_mfa_totp_test.go +++ b/vault/external_tests/identity/login_mfa_totp_test.go @@ -8,6 +8,7 @@ import ( "time" upAuth "github.com/hashicorp/vault/api/auth/userpass" + "github.com/hashicorp/vault/helper/testhelpers" "github.com/hashicorp/vault/api" "github.com/hashicorp/vault/audit" @@ -18,70 +19,9 @@ import ( "github.com/hashicorp/vault/vault" ) -func createEntityAndAlias(client *api.Client, mountAccessor, entityName, aliasName string, t *testing.T) (*api.Client, string, string) { - _, err := client.Logical().WriteWithContext(context.Background(), fmt.Sprintf("auth/userpass/users/%s", aliasName), map[string]interface{}{ - "password": "testpassword", - }) - if err != nil { - t.Fatalf("failed to configure userpass backend: %v", err) - } - - userClient, err := client.Clone() - if err != nil { - t.Fatalf("failed to clone the client:%v", err) - } - userClient.SetToken(client.Token()) - - resp, err := client.Logical().WriteWithContext(context.Background(), "identity/entity", map[string]interface{}{ - "name": entityName, - }) - if err != nil { - t.Fatalf("failed to create an entity:%v", err) - } - entityID := resp.Data["id"].(string) - - aliasResp, err := client.Logical().WriteWithContext(context.Background(), "identity/entity-alias", map[string]interface{}{ - "name": aliasName, - "canonical_id": entityID, - "mount_accessor": mountAccessor, - }) - if err != nil { - t.Fatalf("failed to create an entity alias:%v", err) - } - - aliasID := aliasResp.Data["id"].(string) - if aliasID == "" { - t.Fatal("Alias ID not present in response") - } - return userClient, entityID, aliasID -} - -func registerEntityInTOTPEngine(client *api.Client, entityID, methodID string, t *testing.T) string { - totpGenName := fmt.Sprintf("%s-%s", entityID, methodID) - secret, err := client.Logical().WriteWithContext(context.Background(), fmt.Sprintf("identity/mfa/method/totp/admin-generate"), map[string]interface{}{ - "entity_id": entityID, - "method_id": methodID, - }) - if err != nil { - t.Fatalf("failed to generate a TOTP secret on an entity: %v", err) - } - totpURL := secret.Data["url"].(string) - - _, err = client.Logical().WriteWithContext(context.Background(), fmt.Sprintf("totp/keys/%s", totpGenName), map[string]interface{}{ - "url": totpURL, - }) - if err != nil { - t.Fatalf("failed to register a TOTP URL: %v", err) - } - return totpGenName -} - -func doTwoPhaseLogin(client *api.Client, totpCodePath, methodID, username string, t *testing.T) { - totpResp, err := client.Logical().ReadWithContext(context.Background(), totpCodePath) - if err != nil { - t.Fatalf("failed to create totp passcode: %v", err) - } - totpPasscode := totpResp.Data["code"].(string) +func doTwoPhaseLogin(t *testing.T, client *api.Client, totpCodePath, methodID, username string) { + t.Helper() + totpPasscode := testhelpers.GetTOTPCodeFromEngine(t, client, totpCodePath) upMethod, err := upAuth.NewUserpassAuth(username, &upAuth.Password{FromString: "testpassword"}) @@ -135,91 +75,48 @@ func TestLoginMfaGenerateTOTPTestAuditIncluded(t *testing.T) { client := cluster.Cores[0].Client // Enable the audit backend - err := client.Sys().EnableAuditWithOptions("noop", &api.EnableAuditOptions{Type: "noop"}) - if err != nil { + if err := client.Sys().EnableAuditWithOptions("noop", &api.EnableAuditOptions{Type: "noop"}); err != nil { t.Fatal(err) } - // Mount the TOTP backend - mountInfo := &api.MountInput{ - Type: "totp", - } - err = client.Sys().Mount("totp", mountInfo) - if err != nil { - t.Fatalf("failed to mount totp backend: %v", err) - } - - // Enable Userpass authentication - err = client.Sys().EnableAuthWithOptions("userpass", &api.EnableAuthOptions{ - Type: "userpass", - }) - if err != nil { - t.Fatalf("failed to enable userpass auth: %v", err) - } - - auths, err := client.Sys().ListAuthWithContext(context.Background()) - if err != nil { - t.Fatalf("bb") - } - var mountAccessor string - if auths != nil && auths["userpass/"] != nil { - mountAccessor = auths["userpass/"].Accessor - } + testhelpers.SetupTOTPMount(t, client) + mountAccessor := testhelpers.SetupUserpassMountAccessor(t, client) // Creating two users in the userpass auth mount - userClient1, entityID1, _ := createEntityAndAlias(client, mountAccessor, "entity1", "testuser1", t) - userClient2, entityID2, _ := createEntityAndAlias(client, mountAccessor, "entity2", "testuser2", t) + userClient1, entityID1, _ := testhelpers.CreateEntityAndAlias(t, client, mountAccessor, "entity1", "testuser1") + userClient2, entityID2, _ := testhelpers.CreateEntityAndAlias(t, client, mountAccessor, "entity2", "testuser2") - // configure TOTP secret engine - var methodID string - // login MFA - { - // create a config - resp1, err := client.Logical().Write("identity/mfa/method/totp", map[string]interface{}{ - "issuer": "yCorp", - "period": 5, - "algorithm": "SHA1", - "digits": 6, - "skew": 1, - "key_size": 10, - "qr_size": 100, - "max_validation_attempts": 3, - }) - - if err != nil || (resp1 == nil) { - t.Fatalf("bad: resp: %#v\n err: %v", resp1, err) - } - - methodID = resp1.Data["method_id"].(string) - if methodID == "" { - t.Fatalf("method ID is empty") - } - - // creating MFAEnforcementConfig - _, err = client.Logical().WriteWithContext(context.Background(), "identity/mfa/login-enforcement/randomName", map[string]interface{}{ - "auth_method_types": []string{"userpass"}, - "name": "randomName", - "mfa_method_ids": []string{methodID}, - }) - if err != nil { - t.Fatalf("failed to configure MFAEnforcementConfig: %v", err) - } + totpConfig := map[string]interface{}{ + "issuer": "yCorp", + "period": 10, + "algorithm": "SHA512", + "digits": 6, + "skew": 0, + "key_size": 20, + "qr_size": 200, + "max_validation_attempts": 5, } + methodID := testhelpers.SetupTOTPMethod(t, client, totpConfig) + // registering EntityIDs in the TOTP secret Engine for MethodID - totpEngineConfigName1 := registerEntityInTOTPEngine(client, entityID1, methodID, t) - totpEngineConfigName2 := registerEntityInTOTPEngine(client, entityID2, methodID, t) + enginePath1 := testhelpers.RegisterEntityInTOTPEngine(t, client, entityID1, methodID) + enginePath2 := testhelpers.RegisterEntityInTOTPEngine(t, client, entityID2, methodID) + + // Configure a default login enforcement + enforcementConfig := map[string]interface{}{ + "auth_method_types": []string{"userpass"}, + "name": "randomName", + "mfa_method_ids": []string{methodID}, + } + + testhelpers.SetupMFALoginEnforcement(t, client, enforcementConfig) // MFA single-phase login - totpCodePath1 := fmt.Sprintf("totp/code/%s", totpEngineConfigName1) - secret, err := client.Logical().ReadWithContext(context.Background(), totpCodePath1) - if err != nil { - t.Fatalf("failed to create totp passcode: %v", err) - } - totpPasscode1 := secret.Data["code"].(string) + totpPasscode1 := testhelpers.GetTOTPCodeFromEngine(t, client, enginePath1) userClient1.AddHeader("X-Vault-MFA", fmt.Sprintf("%s:%s", methodID, totpPasscode1)) - secret, err = userClient1.Logical().WriteWithContext(context.Background(), "auth/userpass/login/testuser1", map[string]interface{}{ + secret, err := userClient1.Logical().WriteWithContext(context.Background(), "auth/userpass/login/testuser1", map[string]interface{}{ "password": "testpassword", }) if err != nil { @@ -280,13 +177,8 @@ func TestLoginMfaGenerateTOTPTestAuditIncluded(t *testing.T) { // validation // waiting for 5 seconds so that a fresh code could be generated - time.Sleep(5 * time.Second) - // getting a fresh totp passcode for the validation step - totpResp, err := client.Logical().ReadWithContext(context.Background(), totpCodePath1) - if err != nil { - t.Fatalf("failed to create totp passcode: %v", err) - } - totpPasscode1 = totpResp.Data["code"].(string) + time.Sleep(10 * time.Second) + totpPasscode1 = testhelpers.GetTOTPCodeFromEngine(t, client, enginePath1) secret, err = userClient1.Logical().WriteWithContext(context.Background(), "sys/mfa/validate", map[string]interface{}{ "mfa_request_id": secret.Auth.MFARequirement.MFARequestID, @@ -350,7 +242,9 @@ func TestLoginMfaGenerateTOTPTestAuditIncluded(t *testing.T) { } var maxErr error - for i := 0; i < 4; i++ { + maxAttempts := 6 + i := 0 + for i = 0; i < maxAttempts; i++ { _, maxErr = userClient1.Logical().WriteWithContext(context.Background(), "sys/mfa/validate", map[string]interface{}{ "mfa_request_id": secret.Auth.MFARequirement.MFARequestID, "mfa_payload": map[string][]string{ @@ -361,18 +255,16 @@ func TestLoginMfaGenerateTOTPTestAuditIncluded(t *testing.T) { t.Fatalf("MFA succeeded with an invalid passcode") } } - if !strings.Contains(maxErr.Error(), "maximum TOTP validation attempts 4 exceeded the allowed attempts 3") { - t.Fatalf("unexpected error message when exceeding max failed validation attempts") + if !strings.Contains(maxErr.Error(), "maximum TOTP validation attempts") { + t.Fatalf("unexpected error message when exceeding max failed validation attempts: %s", maxErr.Error()) } // let's make sure the configID is not blocked for other users - totpCodePath2 := fmt.Sprintf("totp/code/%s", totpEngineConfigName2) - doTwoPhaseLogin(userClient2, totpCodePath2, methodID, "testuser2", t) + doTwoPhaseLogin(t, userClient2, enginePath2, methodID, "testuser2") // let's see if user1 is able to login after 5 seconds - time.Sleep(5 * time.Second) - // getting a fresh totp passcode for the validation step - doTwoPhaseLogin(userClient1, totpCodePath1, methodID, "testuser1", t) + time.Sleep(10 * time.Second) + doTwoPhaseLogin(t, userClient1, enginePath1, methodID, "testuser1") // Destroy the secret so that the token can self generate _, err = client.Logical().WriteWithContext(context.Background(), fmt.Sprintf("identity/mfa/method/totp/admin-destroy"), map[string]interface{}{