From 6116903a37011b90b34e81dd990708780adcaaf9 Mon Sep 17 00:00:00 2001 From: Vinny Mannello <94396874+VinnyHC@users.noreply.github.com> Date: Fri, 15 Apr 2022 11:13:15 -0700 Subject: [PATCH] [Vault-5248] MFA support for api login helpers (#14900) * Add MFA support to login helpers --- api/auth.go | 78 ++++++++++- api/auth_test.go | 130 ++++++++++++++++++ api/sys_mfa.go | 45 ++++++ changelog/14900.txt | 3 + command/base.go | 13 +- .../identity/login_mfa_duo_test.go | 13 +- .../identity/login_mfa_okta_test.go | 11 +- .../identity/login_mfa_totp_test.go | 26 ++-- 8 files changed, 283 insertions(+), 36 deletions(-) create mode 100644 api/auth_test.go create mode 100644 api/sys_mfa.go create mode 100644 changelog/14900.txt diff --git a/api/auth.go b/api/auth.go index 10af56bb9..fa92de4b3 100644 --- a/api/auth.go +++ b/api/auth.go @@ -31,16 +31,82 @@ func (a *Auth) Login(ctx context.Context, authMethod AuthMethod) (*Secret, error if authMethod == nil { return nil, fmt.Errorf("no auth method provided for login") } + return a.login(ctx, authMethod) +} - authSecret, err := authMethod.Login(ctx, a.c) +// MFALogin is a wrapper that helps satisfy Vault's MFA implementation. +// If optional credentials are provided a single-phase login will be attempted +// and the resulting Secret will contain a ClientToken if the authentication is successful. +// The client's token will also be set accordingly. +// +// If no credentials are provided a two-phase MFA login will be assumed and the resulting +// Secret will have a MFARequirement containing the MFARequestID to be used in a follow-up +// call to `sys/mfa/validate` or by passing it to the method (*Auth).MFAValidate. +func (a *Auth) MFALogin(ctx context.Context, authMethod AuthMethod, creds ...string) (*Secret, error) { + if len(creds) > 0 { + a.c.SetMFACreds(creds) + return a.login(ctx, authMethod) + } + + return a.twoPhaseMFALogin(ctx, authMethod) +} + +// MFAValidate validates an MFA request using the appropriate payload and a secret containing +// Auth.MFARequirement, like the one returned by MFALogin when credentials are not provided. +// Upon successful validation the client token will be set accordingly. +// +// The Secret returned is the authentication secret, which if desired can be +// passed as input to the NewLifetimeWatcher method in order to start +// automatically renewing the token. +func (a *Auth) MFAValidate(ctx context.Context, mfaSecret *Secret, payload map[string]interface{}) (*Secret, error) { + if mfaSecret == nil || mfaSecret.Auth == nil || mfaSecret.Auth.MFARequirement == nil { + return nil, fmt.Errorf("secret does not contain MFARequirements") + } + + s, err := a.c.Sys().MFAValidateWithContext(ctx, mfaSecret.Auth.MFARequirement.GetMFARequestID(), payload) + if err != nil { + return nil, err + } + + return a.checkAndSetToken(s) +} + +// login performs the (*AuthMethod).Login() with the configured client and checks that a ClientToken is returned +func (a *Auth) login(ctx context.Context, authMethod AuthMethod) (*Secret, error) { + s, err := authMethod.Login(ctx, a.c) if err != nil { return nil, fmt.Errorf("unable to log in to auth method: %w", err) } - if authSecret == nil || authSecret.Auth == nil || authSecret.Auth.ClientToken == "" { - return nil, fmt.Errorf("login response from auth method did not return client token") + + return a.checkAndSetToken(s) +} + +// twoPhaseMFALogin performs the (*AuthMethod).Login() with the configured client +// and checks that an MFARequirement is returned +func (a *Auth) twoPhaseMFALogin(ctx context.Context, authMethod AuthMethod) (*Secret, error) { + s, err := authMethod.Login(ctx, a.c) + if err != nil { + return nil, fmt.Errorf("unable to log in: %w", err) + } + if s == nil || s.Auth == nil || s.Auth.MFARequirement == nil { + if s != nil { + s.Warnings = append(s.Warnings, "expected secret to contain MFARequirements") + } + return s, fmt.Errorf("assumed two-phase MFA login, returned secret is missing MFARequirements") } - a.c.SetToken(authSecret.Auth.ClientToken) - - return authSecret, nil + return s, nil +} + +func (a *Auth) checkAndSetToken(s *Secret) (*Secret, error) { + if s == nil || s.Auth == nil || s.Auth.ClientToken == "" { + if s != nil { + s.Warnings = append(s.Warnings, "expected secret to contain ClientToken") + } + return s, fmt.Errorf("response did not return ClientToken, client token not set") + } + + a.c.SetToken(s.Auth.ClientToken) + + return s, nil } diff --git a/api/auth_test.go b/api/auth_test.go new file mode 100644 index 000000000..46113d92f --- /dev/null +++ b/api/auth_test.go @@ -0,0 +1,130 @@ +package api + +import ( + "context" + "testing" + + "github.com/hashicorp/vault/sdk/logical" +) + +type mockAuthMethod struct { + mockedSecret *Secret + mockedError error +} + +func (m *mockAuthMethod) Login(_ context.Context, _ *Client) (*Secret, error) { + return m.mockedSecret, m.mockedError +} + +func TestAuth_Login(t *testing.T) { + a := &Auth{ + c: &Client{}, + } + + m := mockAuthMethod{ + mockedSecret: &Secret{ + Auth: &SecretAuth{ + ClientToken: "a-client-token", + }, + }, + mockedError: nil, + } + + t.Run("Login should set token on success", func(t *testing.T) { + if a.c.Token() != "" { + t.Errorf("client token was %v expected to be unset", a.c.Token()) + } + + _, err := a.Login(context.Background(), &m) + if err != nil { + t.Errorf("Login() error = %v", err) + return + } + + if a.c.Token() != m.mockedSecret.Auth.ClientToken { + t.Errorf("client token was %v expected %v", a.c.Token(), m.mockedSecret.Auth.ClientToken) + return + } + }) +} + +func TestAuth_MFALoginSinglePhase(t *testing.T) { + t.Run("MFALogin() should succeed if credentials are passed in", func(t *testing.T) { + a := &Auth{ + c: &Client{}, + } + + m := mockAuthMethod{ + mockedSecret: &Secret{ + Auth: &SecretAuth{ + ClientToken: "a-client-token", + }, + }, + mockedError: nil, + } + + _, err := a.MFALogin(context.Background(), &m, "testMethod:testPasscode") + if err != nil { + t.Errorf("MFALogin() error %v", err) + return + } + if a.c.Token() != m.mockedSecret.Auth.ClientToken { + t.Errorf("client token was %v expected %v", a.c.Token(), m.mockedSecret.Auth.ClientToken) + return + } + }) +} + +func TestAuth_MFALoginTwoPhase(t *testing.T) { + tests := []struct { + name string + a *Auth + m *mockAuthMethod + creds *string + wantErr bool + }{ + { + name: "return MFARequirements", + a: &Auth{ + c: &Client{}, + }, + m: &mockAuthMethod{ + mockedSecret: &Secret{ + Auth: &SecretAuth{ + MFARequirement: &logical.MFARequirement{ + MFARequestID: "a-req-id", + MFAConstraints: nil, + }, + }, + }, + }, + wantErr: false, + }, + { + name: "error if no MFARequirements", + a: &Auth{ + c: &Client{}, + }, + m: &mockAuthMethod{ + mockedSecret: &Secret{ + Auth: &SecretAuth{}, + }, + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + secret, err := tt.a.MFALogin(context.Background(), tt.m) + if (err != nil) != tt.wantErr { + t.Errorf("MFALogin() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if secret.Auth.MFARequirement != tt.m.mockedSecret.Auth.MFARequirement { + t.Errorf("MFALogin() returned %v, expected %v", secret.Auth.MFARequirement, tt.m.mockedSecret.Auth.MFARequirement) + return + } + }) + } +} diff --git a/api/sys_mfa.go b/api/sys_mfa.go new file mode 100644 index 000000000..a1ba1bd80 --- /dev/null +++ b/api/sys_mfa.go @@ -0,0 +1,45 @@ +package api + +import ( + "context" + "fmt" + "net/http" +) + +func (c *Sys) MFAValidate(requestID string, payload map[string]interface{}) (*Secret, error) { + return c.MFAValidateWithContext(context.Background(), requestID, payload) +} + +func (c *Sys) MFAValidateWithContext(ctx context.Context, requestID string, payload map[string]interface{}) (*Secret, error) { + ctx, cancelFunc := c.c.withConfiguredTimeout(ctx) + defer cancelFunc() + + body := map[string]interface{}{ + "mfa_request_id": requestID, + "mfa_payload": payload, + } + + r := c.c.NewRequest(http.MethodPost, fmt.Sprintf("/v1/sys/mfa/validate")) + if err := r.SetJSONBody(body); err != nil { + return nil, fmt.Errorf("failed to set request body: %w", err) + } + + resp, err := c.c.rawRequestWithContext(ctx, r) + if resp != nil { + defer resp.Body.Close() + } + if err != nil { + return nil, err + } + + secret, err := ParseSecret(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to parse secret from response: %w", err) + } + + if secret == nil { + return nil, fmt.Errorf("data from server response is empty") + } + + return secret, nil +} diff --git a/changelog/14900.txt b/changelog/14900.txt new file mode 100644 index 000000000..6d995fa36 --- /dev/null +++ b/changelog/14900.txt @@ -0,0 +1,3 @@ +```release-note:improvement +api: Added MFALogin() for handling MFA flow when using login helpers. +``` \ No newline at end of file diff --git a/command/base.go b/command/base.go index ea677514e..0363db3a0 100644 --- a/command/base.go +++ b/command/base.go @@ -259,8 +259,8 @@ func (c *BaseCommand) validateMFA(reqID string, methodInfo MFAMethodInfo) int { } // passcode could be an empty string - mfaPayload := map[string][]string{ - methodInfo.methodID: {passcode}, + mfaPayload := map[string]interface{}{ + methodInfo.methodID: []string{passcode}, } client, err := c.Client() @@ -269,12 +269,7 @@ func (c *BaseCommand) validateMFA(reqID string, methodInfo MFAMethodInfo) int { return 2 } - path := "sys/mfa/validate" - - secret, err := client.Logical().Write(path, map[string]interface{}{ - "mfa_request_id": reqID, - "mfa_payload": mfaPayload, - }) + secret, err := client.Sys().MFAValidate(reqID, mfaPayload) if err != nil { c.UI.Error(err.Error()) if secret != nil { @@ -285,7 +280,7 @@ func (c *BaseCommand) validateMFA(reqID string, methodInfo MFAMethodInfo) int { if secret == nil { // Don't output anything unless using the "table" format if Format(c.UI) == "table" { - c.UI.Info(fmt.Sprintf("Success! Data written to: %s", path)) + c.UI.Info("Success! Data written to: sys/mfa/validate") } return 0 } diff --git a/vault/external_tests/identity/login_mfa_duo_test.go b/vault/external_tests/identity/login_mfa_duo_test.go index 233afd4fd..4fe92202c 100644 --- a/vault/external_tests/identity/login_mfa_duo_test.go +++ b/vault/external_tests/identity/login_mfa_duo_test.go @@ -1,6 +1,7 @@ package identity import ( + "context" "fmt" "net/http" "reflect" @@ -241,7 +242,6 @@ func mfaGenerateLoginDUOTest(client *api.Client) error { return fmt.Errorf("failed to configure MFAEnforcementConfig: %v", err) } } - secret, err = client.Logical().Write("auth/userpass/login/vaultmfa", map[string]interface{}{ "password": "testpassword", }) @@ -272,12 +272,11 @@ func mfaGenerateLoginDUOTest(client *api.Client) error { } // validation - secret, err = client.Logical().Write("sys/mfa/validate", map[string]interface{}{ - "mfa_request_id": secret.Auth.MFARequirement.MFARequestID, - "mfa_payload": map[string][]string{ - methodID: {}, - }, - }) + secret, err = client.Sys().MFAValidateWithContext(context.Background(), + secret.Auth.MFARequirement.MFARequestID, + map[string]interface{}{ + methodID: []string{}, + }) if err != nil { return fmt.Errorf("MFA failed: %v", err) } diff --git a/vault/external_tests/identity/login_mfa_okta_test.go b/vault/external_tests/identity/login_mfa_okta_test.go index c80825af4..53127485b 100644 --- a/vault/external_tests/identity/login_mfa_okta_test.go +++ b/vault/external_tests/identity/login_mfa_okta_test.go @@ -1,6 +1,7 @@ package identity import ( + "context" "fmt" "reflect" "testing" @@ -322,12 +323,12 @@ func mfaGenerateOktaLoginMFATest(client *api.Client) error { } // validation - secret, err = client.Logical().Write("sys/mfa/validate", map[string]interface{}{ - "mfa_request_id": secret.Auth.MFARequirement.MFARequestID, - "mfa_payload": map[string][]string{ - methodID: {}, + secret, err = client.Sys().MFAValidateWithContext(context.Background(), + secret.Auth.MFARequirement.MFARequestID, + map[string]interface{}{ + methodID: []string{}, }, - }) + ) if err != nil { return fmt.Errorf("MFA failed: %v", err) } diff --git a/vault/external_tests/identity/login_mfa_totp_test.go b/vault/external_tests/identity/login_mfa_totp_test.go index 9103bc845..d0871cc28 100644 --- a/vault/external_tests/identity/login_mfa_totp_test.go +++ b/vault/external_tests/identity/login_mfa_totp_test.go @@ -7,6 +7,8 @@ import ( "testing" "time" + upAuth "github.com/hashicorp/vault/api/auth/userpass" + "github.com/hashicorp/vault/api" "github.com/hashicorp/vault/audit" "github.com/hashicorp/vault/builtin/credential/userpass" @@ -76,21 +78,27 @@ func doTwoPhaseLogin(client *api.Client, totpCodePath, methodID, username string } totpPasscode := totpResp.Data["code"].(string) - secret, err := client.Logical().WriteWithContext(context.Background(), fmt.Sprintf("auth/userpass/login/%s", username), map[string]interface{}{ - "password": "testpassword", - }) + upMethod, err := upAuth.NewUserpassAuth(username, &upAuth.Password{FromString: "testpassword"}) + + mfaSecret, err := client.Auth().MFALogin(context.Background(), upMethod) if err != nil { - t.Fatalf("first phase of login MFA failed: %v", err) + t.Fatalf("failed to login with userpass auth method: %v", err) } - secret, err = client.Logical().WriteWithContext(context.Background(), "sys/mfa/validate", map[string]interface{}{ - "mfa_request_id": secret.Auth.MFARequirement.MFARequestID, - "mfa_payload": map[string][]string{ - methodID: {totpPasscode}, + + secret, err := client.Auth().MFAValidate( + context.Background(), + mfaSecret, + map[string]interface{}{ + methodID: []string{totpPasscode}, }, - }) + ) if err != nil { t.Fatalf("MFA validation failed: %v", err) } + + if secret == nil || secret.Auth == nil || secret.Auth.ClientToken == "" { + t.Fatalf("MFA validation failed to return a ClientToken in secret: %v", secret) + } } func TestLoginMfaGenerateTOTPTestAuditIncluded(t *testing.T) {