[Vault-5248] MFA support for api login helpers (#14900)

* Add MFA support to login helpers
This commit is contained in:
Vinny Mannello 2022-04-15 11:13:15 -07:00 committed by GitHub
parent d43324831b
commit 6116903a37
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 283 additions and 36 deletions

View File

@ -31,16 +31,82 @@ func (a *Auth) Login(ctx context.Context, authMethod AuthMethod) (*Secret, error
if authMethod == nil { if authMethod == nil {
return nil, fmt.Errorf("no auth method provided for login") return nil, fmt.Errorf("no auth method provided for login")
} }
return a.login(ctx, authMethod)
}
authSecret, err := authMethod.Login(ctx, a.c) // MFALogin is a wrapper that helps satisfy Vault's MFA implementation.
// If optional credentials are provided a single-phase login will be attempted
// and the resulting Secret will contain a ClientToken if the authentication is successful.
// The client's token will also be set accordingly.
//
// If no credentials are provided a two-phase MFA login will be assumed and the resulting
// Secret will have a MFARequirement containing the MFARequestID to be used in a follow-up
// call to `sys/mfa/validate` or by passing it to the method (*Auth).MFAValidate.
func (a *Auth) MFALogin(ctx context.Context, authMethod AuthMethod, creds ...string) (*Secret, error) {
if len(creds) > 0 {
a.c.SetMFACreds(creds)
return a.login(ctx, authMethod)
}
return a.twoPhaseMFALogin(ctx, authMethod)
}
// MFAValidate validates an MFA request using the appropriate payload and a secret containing
// Auth.MFARequirement, like the one returned by MFALogin when credentials are not provided.
// Upon successful validation the client token will be set accordingly.
//
// The Secret returned is the authentication secret, which if desired can be
// passed as input to the NewLifetimeWatcher method in order to start
// automatically renewing the token.
func (a *Auth) MFAValidate(ctx context.Context, mfaSecret *Secret, payload map[string]interface{}) (*Secret, error) {
if mfaSecret == nil || mfaSecret.Auth == nil || mfaSecret.Auth.MFARequirement == nil {
return nil, fmt.Errorf("secret does not contain MFARequirements")
}
s, err := a.c.Sys().MFAValidateWithContext(ctx, mfaSecret.Auth.MFARequirement.GetMFARequestID(), payload)
if err != nil {
return nil, err
}
return a.checkAndSetToken(s)
}
// login performs the (*AuthMethod).Login() with the configured client and checks that a ClientToken is returned
func (a *Auth) login(ctx context.Context, authMethod AuthMethod) (*Secret, error) {
s, err := authMethod.Login(ctx, a.c)
if err != nil { if err != nil {
return nil, fmt.Errorf("unable to log in to auth method: %w", err) return nil, fmt.Errorf("unable to log in to auth method: %w", err)
} }
if authSecret == nil || authSecret.Auth == nil || authSecret.Auth.ClientToken == "" {
return nil, fmt.Errorf("login response from auth method did not return client token") return a.checkAndSetToken(s)
}
// twoPhaseMFALogin performs the (*AuthMethod).Login() with the configured client
// and checks that an MFARequirement is returned
func (a *Auth) twoPhaseMFALogin(ctx context.Context, authMethod AuthMethod) (*Secret, error) {
s, err := authMethod.Login(ctx, a.c)
if err != nil {
return nil, fmt.Errorf("unable to log in: %w", err)
}
if s == nil || s.Auth == nil || s.Auth.MFARequirement == nil {
if s != nil {
s.Warnings = append(s.Warnings, "expected secret to contain MFARequirements")
}
return s, fmt.Errorf("assumed two-phase MFA login, returned secret is missing MFARequirements")
} }
a.c.SetToken(authSecret.Auth.ClientToken) return s, nil
}
return authSecret, nil
func (a *Auth) checkAndSetToken(s *Secret) (*Secret, error) {
if s == nil || s.Auth == nil || s.Auth.ClientToken == "" {
if s != nil {
s.Warnings = append(s.Warnings, "expected secret to contain ClientToken")
}
return s, fmt.Errorf("response did not return ClientToken, client token not set")
}
a.c.SetToken(s.Auth.ClientToken)
return s, nil
} }

130
api/auth_test.go Normal file
View File

@ -0,0 +1,130 @@
package api
import (
"context"
"testing"
"github.com/hashicorp/vault/sdk/logical"
)
type mockAuthMethod struct {
mockedSecret *Secret
mockedError error
}
func (m *mockAuthMethod) Login(_ context.Context, _ *Client) (*Secret, error) {
return m.mockedSecret, m.mockedError
}
func TestAuth_Login(t *testing.T) {
a := &Auth{
c: &Client{},
}
m := mockAuthMethod{
mockedSecret: &Secret{
Auth: &SecretAuth{
ClientToken: "a-client-token",
},
},
mockedError: nil,
}
t.Run("Login should set token on success", func(t *testing.T) {
if a.c.Token() != "" {
t.Errorf("client token was %v expected to be unset", a.c.Token())
}
_, err := a.Login(context.Background(), &m)
if err != nil {
t.Errorf("Login() error = %v", err)
return
}
if a.c.Token() != m.mockedSecret.Auth.ClientToken {
t.Errorf("client token was %v expected %v", a.c.Token(), m.mockedSecret.Auth.ClientToken)
return
}
})
}
func TestAuth_MFALoginSinglePhase(t *testing.T) {
t.Run("MFALogin() should succeed if credentials are passed in", func(t *testing.T) {
a := &Auth{
c: &Client{},
}
m := mockAuthMethod{
mockedSecret: &Secret{
Auth: &SecretAuth{
ClientToken: "a-client-token",
},
},
mockedError: nil,
}
_, err := a.MFALogin(context.Background(), &m, "testMethod:testPasscode")
if err != nil {
t.Errorf("MFALogin() error %v", err)
return
}
if a.c.Token() != m.mockedSecret.Auth.ClientToken {
t.Errorf("client token was %v expected %v", a.c.Token(), m.mockedSecret.Auth.ClientToken)
return
}
})
}
func TestAuth_MFALoginTwoPhase(t *testing.T) {
tests := []struct {
name string
a *Auth
m *mockAuthMethod
creds *string
wantErr bool
}{
{
name: "return MFARequirements",
a: &Auth{
c: &Client{},
},
m: &mockAuthMethod{
mockedSecret: &Secret{
Auth: &SecretAuth{
MFARequirement: &logical.MFARequirement{
MFARequestID: "a-req-id",
MFAConstraints: nil,
},
},
},
},
wantErr: false,
},
{
name: "error if no MFARequirements",
a: &Auth{
c: &Client{},
},
m: &mockAuthMethod{
mockedSecret: &Secret{
Auth: &SecretAuth{},
},
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
secret, err := tt.a.MFALogin(context.Background(), tt.m)
if (err != nil) != tt.wantErr {
t.Errorf("MFALogin() error = %v, wantErr %v", err, tt.wantErr)
return
}
if secret.Auth.MFARequirement != tt.m.mockedSecret.Auth.MFARequirement {
t.Errorf("MFALogin() returned %v, expected %v", secret.Auth.MFARequirement, tt.m.mockedSecret.Auth.MFARequirement)
return
}
})
}
}

45
api/sys_mfa.go Normal file
View File

@ -0,0 +1,45 @@
package api
import (
"context"
"fmt"
"net/http"
)
func (c *Sys) MFAValidate(requestID string, payload map[string]interface{}) (*Secret, error) {
return c.MFAValidateWithContext(context.Background(), requestID, payload)
}
func (c *Sys) MFAValidateWithContext(ctx context.Context, requestID string, payload map[string]interface{}) (*Secret, error) {
ctx, cancelFunc := c.c.withConfiguredTimeout(ctx)
defer cancelFunc()
body := map[string]interface{}{
"mfa_request_id": requestID,
"mfa_payload": payload,
}
r := c.c.NewRequest(http.MethodPost, fmt.Sprintf("/v1/sys/mfa/validate"))
if err := r.SetJSONBody(body); err != nil {
return nil, fmt.Errorf("failed to set request body: %w", err)
}
resp, err := c.c.rawRequestWithContext(ctx, r)
if resp != nil {
defer resp.Body.Close()
}
if err != nil {
return nil, err
}
secret, err := ParseSecret(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to parse secret from response: %w", err)
}
if secret == nil {
return nil, fmt.Errorf("data from server response is empty")
}
return secret, nil
}

3
changelog/14900.txt Normal file
View File

@ -0,0 +1,3 @@
```release-note:improvement
api: Added MFALogin() for handling MFA flow when using login helpers.
```

View File

@ -259,8 +259,8 @@ func (c *BaseCommand) validateMFA(reqID string, methodInfo MFAMethodInfo) int {
} }
// passcode could be an empty string // passcode could be an empty string
mfaPayload := map[string][]string{ mfaPayload := map[string]interface{}{
methodInfo.methodID: {passcode}, methodInfo.methodID: []string{passcode},
} }
client, err := c.Client() client, err := c.Client()
@ -269,12 +269,7 @@ func (c *BaseCommand) validateMFA(reqID string, methodInfo MFAMethodInfo) int {
return 2 return 2
} }
path := "sys/mfa/validate" secret, err := client.Sys().MFAValidate(reqID, mfaPayload)
secret, err := client.Logical().Write(path, map[string]interface{}{
"mfa_request_id": reqID,
"mfa_payload": mfaPayload,
})
if err != nil { if err != nil {
c.UI.Error(err.Error()) c.UI.Error(err.Error())
if secret != nil { if secret != nil {
@ -285,7 +280,7 @@ func (c *BaseCommand) validateMFA(reqID string, methodInfo MFAMethodInfo) int {
if secret == nil { if secret == nil {
// Don't output anything unless using the "table" format // Don't output anything unless using the "table" format
if Format(c.UI) == "table" { if Format(c.UI) == "table" {
c.UI.Info(fmt.Sprintf("Success! Data written to: %s", path)) c.UI.Info("Success! Data written to: sys/mfa/validate")
} }
return 0 return 0
} }

View File

@ -1,6 +1,7 @@
package identity package identity
import ( import (
"context"
"fmt" "fmt"
"net/http" "net/http"
"reflect" "reflect"
@ -241,7 +242,6 @@ func mfaGenerateLoginDUOTest(client *api.Client) error {
return fmt.Errorf("failed to configure MFAEnforcementConfig: %v", err) return fmt.Errorf("failed to configure MFAEnforcementConfig: %v", err)
} }
} }
secret, err = client.Logical().Write("auth/userpass/login/vaultmfa", map[string]interface{}{ secret, err = client.Logical().Write("auth/userpass/login/vaultmfa", map[string]interface{}{
"password": "testpassword", "password": "testpassword",
}) })
@ -272,12 +272,11 @@ func mfaGenerateLoginDUOTest(client *api.Client) error {
} }
// validation // validation
secret, err = client.Logical().Write("sys/mfa/validate", map[string]interface{}{ secret, err = client.Sys().MFAValidateWithContext(context.Background(),
"mfa_request_id": secret.Auth.MFARequirement.MFARequestID, secret.Auth.MFARequirement.MFARequestID,
"mfa_payload": map[string][]string{ map[string]interface{}{
methodID: {}, methodID: []string{},
}, })
})
if err != nil { if err != nil {
return fmt.Errorf("MFA failed: %v", err) return fmt.Errorf("MFA failed: %v", err)
} }

View File

@ -1,6 +1,7 @@
package identity package identity
import ( import (
"context"
"fmt" "fmt"
"reflect" "reflect"
"testing" "testing"
@ -322,12 +323,12 @@ func mfaGenerateOktaLoginMFATest(client *api.Client) error {
} }
// validation // validation
secret, err = client.Logical().Write("sys/mfa/validate", map[string]interface{}{ secret, err = client.Sys().MFAValidateWithContext(context.Background(),
"mfa_request_id": secret.Auth.MFARequirement.MFARequestID, secret.Auth.MFARequirement.MFARequestID,
"mfa_payload": map[string][]string{ map[string]interface{}{
methodID: {}, methodID: []string{},
}, },
}) )
if err != nil { if err != nil {
return fmt.Errorf("MFA failed: %v", err) return fmt.Errorf("MFA failed: %v", err)
} }

View File

@ -7,6 +7,8 @@ import (
"testing" "testing"
"time" "time"
upAuth "github.com/hashicorp/vault/api/auth/userpass"
"github.com/hashicorp/vault/api" "github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/audit" "github.com/hashicorp/vault/audit"
"github.com/hashicorp/vault/builtin/credential/userpass" "github.com/hashicorp/vault/builtin/credential/userpass"
@ -76,21 +78,27 @@ func doTwoPhaseLogin(client *api.Client, totpCodePath, methodID, username string
} }
totpPasscode := totpResp.Data["code"].(string) totpPasscode := totpResp.Data["code"].(string)
secret, err := client.Logical().WriteWithContext(context.Background(), fmt.Sprintf("auth/userpass/login/%s", username), map[string]interface{}{ upMethod, err := upAuth.NewUserpassAuth(username, &upAuth.Password{FromString: "testpassword"})
"password": "testpassword",
}) mfaSecret, err := client.Auth().MFALogin(context.Background(), upMethod)
if err != nil { if err != nil {
t.Fatalf("first phase of login MFA failed: %v", err) t.Fatalf("failed to login with userpass auth method: %v", err)
} }
secret, err = client.Logical().WriteWithContext(context.Background(), "sys/mfa/validate", map[string]interface{}{
"mfa_request_id": secret.Auth.MFARequirement.MFARequestID, secret, err := client.Auth().MFAValidate(
"mfa_payload": map[string][]string{ context.Background(),
methodID: {totpPasscode}, mfaSecret,
map[string]interface{}{
methodID: []string{totpPasscode},
}, },
}) )
if err != nil { if err != nil {
t.Fatalf("MFA validation failed: %v", err) t.Fatalf("MFA validation failed: %v", err)
} }
if secret == nil || secret.Auth == nil || secret.Auth.ClientToken == "" {
t.Fatalf("MFA validation failed to return a ClientToken in secret: %v", secret)
}
} }
func TestLoginMfaGenerateTOTPTestAuditIncluded(t *testing.T) { func TestLoginMfaGenerateTOTPTestAuditIncluded(t *testing.T) {