[Vault-5248] MFA support for api login helpers (#14900)
* Add MFA support to login helpers
This commit is contained in:
parent
d43324831b
commit
6116903a37
78
api/auth.go
78
api/auth.go
|
@ -31,16 +31,82 @@ func (a *Auth) Login(ctx context.Context, authMethod AuthMethod) (*Secret, error
|
|||
if authMethod == nil {
|
||||
return nil, fmt.Errorf("no auth method provided for login")
|
||||
}
|
||||
return a.login(ctx, authMethod)
|
||||
}
|
||||
|
||||
authSecret, err := authMethod.Login(ctx, a.c)
|
||||
// MFALogin is a wrapper that helps satisfy Vault's MFA implementation.
|
||||
// If optional credentials are provided a single-phase login will be attempted
|
||||
// and the resulting Secret will contain a ClientToken if the authentication is successful.
|
||||
// The client's token will also be set accordingly.
|
||||
//
|
||||
// If no credentials are provided a two-phase MFA login will be assumed and the resulting
|
||||
// Secret will have a MFARequirement containing the MFARequestID to be used in a follow-up
|
||||
// call to `sys/mfa/validate` or by passing it to the method (*Auth).MFAValidate.
|
||||
func (a *Auth) MFALogin(ctx context.Context, authMethod AuthMethod, creds ...string) (*Secret, error) {
|
||||
if len(creds) > 0 {
|
||||
a.c.SetMFACreds(creds)
|
||||
return a.login(ctx, authMethod)
|
||||
}
|
||||
|
||||
return a.twoPhaseMFALogin(ctx, authMethod)
|
||||
}
|
||||
|
||||
// MFAValidate validates an MFA request using the appropriate payload and a secret containing
|
||||
// Auth.MFARequirement, like the one returned by MFALogin when credentials are not provided.
|
||||
// Upon successful validation the client token will be set accordingly.
|
||||
//
|
||||
// The Secret returned is the authentication secret, which if desired can be
|
||||
// passed as input to the NewLifetimeWatcher method in order to start
|
||||
// automatically renewing the token.
|
||||
func (a *Auth) MFAValidate(ctx context.Context, mfaSecret *Secret, payload map[string]interface{}) (*Secret, error) {
|
||||
if mfaSecret == nil || mfaSecret.Auth == nil || mfaSecret.Auth.MFARequirement == nil {
|
||||
return nil, fmt.Errorf("secret does not contain MFARequirements")
|
||||
}
|
||||
|
||||
s, err := a.c.Sys().MFAValidateWithContext(ctx, mfaSecret.Auth.MFARequirement.GetMFARequestID(), payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return a.checkAndSetToken(s)
|
||||
}
|
||||
|
||||
// login performs the (*AuthMethod).Login() with the configured client and checks that a ClientToken is returned
|
||||
func (a *Auth) login(ctx context.Context, authMethod AuthMethod) (*Secret, error) {
|
||||
s, err := authMethod.Login(ctx, a.c)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to log in to auth method: %w", err)
|
||||
}
|
||||
if authSecret == nil || authSecret.Auth == nil || authSecret.Auth.ClientToken == "" {
|
||||
return nil, fmt.Errorf("login response from auth method did not return client token")
|
||||
|
||||
return a.checkAndSetToken(s)
|
||||
}
|
||||
|
||||
a.c.SetToken(authSecret.Auth.ClientToken)
|
||||
|
||||
return authSecret, nil
|
||||
// twoPhaseMFALogin performs the (*AuthMethod).Login() with the configured client
|
||||
// and checks that an MFARequirement is returned
|
||||
func (a *Auth) twoPhaseMFALogin(ctx context.Context, authMethod AuthMethod) (*Secret, error) {
|
||||
s, err := authMethod.Login(ctx, a.c)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to log in: %w", err)
|
||||
}
|
||||
if s == nil || s.Auth == nil || s.Auth.MFARequirement == nil {
|
||||
if s != nil {
|
||||
s.Warnings = append(s.Warnings, "expected secret to contain MFARequirements")
|
||||
}
|
||||
return s, fmt.Errorf("assumed two-phase MFA login, returned secret is missing MFARequirements")
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (a *Auth) checkAndSetToken(s *Secret) (*Secret, error) {
|
||||
if s == nil || s.Auth == nil || s.Auth.ClientToken == "" {
|
||||
if s != nil {
|
||||
s.Warnings = append(s.Warnings, "expected secret to contain ClientToken")
|
||||
}
|
||||
return s, fmt.Errorf("response did not return ClientToken, client token not set")
|
||||
}
|
||||
|
||||
a.c.SetToken(s.Auth.ClientToken)
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
|
|
@ -0,0 +1,130 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
)
|
||||
|
||||
type mockAuthMethod struct {
|
||||
mockedSecret *Secret
|
||||
mockedError error
|
||||
}
|
||||
|
||||
func (m *mockAuthMethod) Login(_ context.Context, _ *Client) (*Secret, error) {
|
||||
return m.mockedSecret, m.mockedError
|
||||
}
|
||||
|
||||
func TestAuth_Login(t *testing.T) {
|
||||
a := &Auth{
|
||||
c: &Client{},
|
||||
}
|
||||
|
||||
m := mockAuthMethod{
|
||||
mockedSecret: &Secret{
|
||||
Auth: &SecretAuth{
|
||||
ClientToken: "a-client-token",
|
||||
},
|
||||
},
|
||||
mockedError: nil,
|
||||
}
|
||||
|
||||
t.Run("Login should set token on success", func(t *testing.T) {
|
||||
if a.c.Token() != "" {
|
||||
t.Errorf("client token was %v expected to be unset", a.c.Token())
|
||||
}
|
||||
|
||||
_, err := a.Login(context.Background(), &m)
|
||||
if err != nil {
|
||||
t.Errorf("Login() error = %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if a.c.Token() != m.mockedSecret.Auth.ClientToken {
|
||||
t.Errorf("client token was %v expected %v", a.c.Token(), m.mockedSecret.Auth.ClientToken)
|
||||
return
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestAuth_MFALoginSinglePhase(t *testing.T) {
|
||||
t.Run("MFALogin() should succeed if credentials are passed in", func(t *testing.T) {
|
||||
a := &Auth{
|
||||
c: &Client{},
|
||||
}
|
||||
|
||||
m := mockAuthMethod{
|
||||
mockedSecret: &Secret{
|
||||
Auth: &SecretAuth{
|
||||
ClientToken: "a-client-token",
|
||||
},
|
||||
},
|
||||
mockedError: nil,
|
||||
}
|
||||
|
||||
_, err := a.MFALogin(context.Background(), &m, "testMethod:testPasscode")
|
||||
if err != nil {
|
||||
t.Errorf("MFALogin() error %v", err)
|
||||
return
|
||||
}
|
||||
if a.c.Token() != m.mockedSecret.Auth.ClientToken {
|
||||
t.Errorf("client token was %v expected %v", a.c.Token(), m.mockedSecret.Auth.ClientToken)
|
||||
return
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestAuth_MFALoginTwoPhase(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
a *Auth
|
||||
m *mockAuthMethod
|
||||
creds *string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "return MFARequirements",
|
||||
a: &Auth{
|
||||
c: &Client{},
|
||||
},
|
||||
m: &mockAuthMethod{
|
||||
mockedSecret: &Secret{
|
||||
Auth: &SecretAuth{
|
||||
MFARequirement: &logical.MFARequirement{
|
||||
MFARequestID: "a-req-id",
|
||||
MFAConstraints: nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "error if no MFARequirements",
|
||||
a: &Auth{
|
||||
c: &Client{},
|
||||
},
|
||||
m: &mockAuthMethod{
|
||||
mockedSecret: &Secret{
|
||||
Auth: &SecretAuth{},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
secret, err := tt.a.MFALogin(context.Background(), tt.m)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("MFALogin() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
if secret.Auth.MFARequirement != tt.m.mockedSecret.Auth.MFARequirement {
|
||||
t.Errorf("MFALogin() returned %v, expected %v", secret.Auth.MFARequirement, tt.m.mockedSecret.Auth.MFARequirement)
|
||||
return
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -0,0 +1,45 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func (c *Sys) MFAValidate(requestID string, payload map[string]interface{}) (*Secret, error) {
|
||||
return c.MFAValidateWithContext(context.Background(), requestID, payload)
|
||||
}
|
||||
|
||||
func (c *Sys) MFAValidateWithContext(ctx context.Context, requestID string, payload map[string]interface{}) (*Secret, error) {
|
||||
ctx, cancelFunc := c.c.withConfiguredTimeout(ctx)
|
||||
defer cancelFunc()
|
||||
|
||||
body := map[string]interface{}{
|
||||
"mfa_request_id": requestID,
|
||||
"mfa_payload": payload,
|
||||
}
|
||||
|
||||
r := c.c.NewRequest(http.MethodPost, fmt.Sprintf("/v1/sys/mfa/validate"))
|
||||
if err := r.SetJSONBody(body); err != nil {
|
||||
return nil, fmt.Errorf("failed to set request body: %w", err)
|
||||
}
|
||||
|
||||
resp, err := c.c.rawRequestWithContext(ctx, r)
|
||||
if resp != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
secret, err := ParseSecret(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse secret from response: %w", err)
|
||||
}
|
||||
|
||||
if secret == nil {
|
||||
return nil, fmt.Errorf("data from server response is empty")
|
||||
}
|
||||
|
||||
return secret, nil
|
||||
}
|
|
@ -0,0 +1,3 @@
|
|||
```release-note:improvement
|
||||
api: Added MFALogin() for handling MFA flow when using login helpers.
|
||||
```
|
|
@ -259,8 +259,8 @@ func (c *BaseCommand) validateMFA(reqID string, methodInfo MFAMethodInfo) int {
|
|||
}
|
||||
|
||||
// passcode could be an empty string
|
||||
mfaPayload := map[string][]string{
|
||||
methodInfo.methodID: {passcode},
|
||||
mfaPayload := map[string]interface{}{
|
||||
methodInfo.methodID: []string{passcode},
|
||||
}
|
||||
|
||||
client, err := c.Client()
|
||||
|
@ -269,12 +269,7 @@ func (c *BaseCommand) validateMFA(reqID string, methodInfo MFAMethodInfo) int {
|
|||
return 2
|
||||
}
|
||||
|
||||
path := "sys/mfa/validate"
|
||||
|
||||
secret, err := client.Logical().Write(path, map[string]interface{}{
|
||||
"mfa_request_id": reqID,
|
||||
"mfa_payload": mfaPayload,
|
||||
})
|
||||
secret, err := client.Sys().MFAValidate(reqID, mfaPayload)
|
||||
if err != nil {
|
||||
c.UI.Error(err.Error())
|
||||
if secret != nil {
|
||||
|
@ -285,7 +280,7 @@ func (c *BaseCommand) validateMFA(reqID string, methodInfo MFAMethodInfo) int {
|
|||
if secret == nil {
|
||||
// Don't output anything unless using the "table" format
|
||||
if Format(c.UI) == "table" {
|
||||
c.UI.Info(fmt.Sprintf("Success! Data written to: %s", path))
|
||||
c.UI.Info("Success! Data written to: sys/mfa/validate")
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package identity
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"reflect"
|
||||
|
@ -241,7 +242,6 @@ func mfaGenerateLoginDUOTest(client *api.Client) error {
|
|||
return fmt.Errorf("failed to configure MFAEnforcementConfig: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
secret, err = client.Logical().Write("auth/userpass/login/vaultmfa", map[string]interface{}{
|
||||
"password": "testpassword",
|
||||
})
|
||||
|
@ -272,11 +272,10 @@ func mfaGenerateLoginDUOTest(client *api.Client) error {
|
|||
}
|
||||
|
||||
// validation
|
||||
secret, err = client.Logical().Write("sys/mfa/validate", map[string]interface{}{
|
||||
"mfa_request_id": secret.Auth.MFARequirement.MFARequestID,
|
||||
"mfa_payload": map[string][]string{
|
||||
methodID: {},
|
||||
},
|
||||
secret, err = client.Sys().MFAValidateWithContext(context.Background(),
|
||||
secret.Auth.MFARequirement.MFARequestID,
|
||||
map[string]interface{}{
|
||||
methodID: []string{},
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("MFA failed: %v", err)
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package identity
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
@ -322,12 +323,12 @@ func mfaGenerateOktaLoginMFATest(client *api.Client) error {
|
|||
}
|
||||
|
||||
// validation
|
||||
secret, err = client.Logical().Write("sys/mfa/validate", map[string]interface{}{
|
||||
"mfa_request_id": secret.Auth.MFARequirement.MFARequestID,
|
||||
"mfa_payload": map[string][]string{
|
||||
methodID: {},
|
||||
secret, err = client.Sys().MFAValidateWithContext(context.Background(),
|
||||
secret.Auth.MFARequirement.MFARequestID,
|
||||
map[string]interface{}{
|
||||
methodID: []string{},
|
||||
},
|
||||
})
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("MFA failed: %v", err)
|
||||
}
|
||||
|
|
|
@ -7,6 +7,8 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
upAuth "github.com/hashicorp/vault/api/auth/userpass"
|
||||
|
||||
"github.com/hashicorp/vault/api"
|
||||
"github.com/hashicorp/vault/audit"
|
||||
"github.com/hashicorp/vault/builtin/credential/userpass"
|
||||
|
@ -76,21 +78,27 @@ func doTwoPhaseLogin(client *api.Client, totpCodePath, methodID, username string
|
|||
}
|
||||
totpPasscode := totpResp.Data["code"].(string)
|
||||
|
||||
secret, err := client.Logical().WriteWithContext(context.Background(), fmt.Sprintf("auth/userpass/login/%s", username), map[string]interface{}{
|
||||
"password": "testpassword",
|
||||
})
|
||||
upMethod, err := upAuth.NewUserpassAuth(username, &upAuth.Password{FromString: "testpassword"})
|
||||
|
||||
mfaSecret, err := client.Auth().MFALogin(context.Background(), upMethod)
|
||||
if err != nil {
|
||||
t.Fatalf("first phase of login MFA failed: %v", err)
|
||||
t.Fatalf("failed to login with userpass auth method: %v", err)
|
||||
}
|
||||
secret, err = client.Logical().WriteWithContext(context.Background(), "sys/mfa/validate", map[string]interface{}{
|
||||
"mfa_request_id": secret.Auth.MFARequirement.MFARequestID,
|
||||
"mfa_payload": map[string][]string{
|
||||
methodID: {totpPasscode},
|
||||
|
||||
secret, err := client.Auth().MFAValidate(
|
||||
context.Background(),
|
||||
mfaSecret,
|
||||
map[string]interface{}{
|
||||
methodID: []string{totpPasscode},
|
||||
},
|
||||
})
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("MFA validation failed: %v", err)
|
||||
}
|
||||
|
||||
if secret == nil || secret.Auth == nil || secret.Auth.ClientToken == "" {
|
||||
t.Fatalf("MFA validation failed to return a ClientToken in secret: %v", secret)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoginMfaGenerateTOTPTestAuditIncluded(t *testing.T) {
|
||||
|
|
Loading…
Reference in New Issue