[Vault-5248] MFA support for api login helpers (#14900)
* Add MFA support to login helpers
This commit is contained in:
parent
d43324831b
commit
6116903a37
78
api/auth.go
78
api/auth.go
|
@ -31,16 +31,82 @@ func (a *Auth) Login(ctx context.Context, authMethod AuthMethod) (*Secret, error
|
||||||
if authMethod == nil {
|
if authMethod == nil {
|
||||||
return nil, fmt.Errorf("no auth method provided for login")
|
return nil, fmt.Errorf("no auth method provided for login")
|
||||||
}
|
}
|
||||||
|
return a.login(ctx, authMethod)
|
||||||
|
}
|
||||||
|
|
||||||
authSecret, err := authMethod.Login(ctx, a.c)
|
// MFALogin is a wrapper that helps satisfy Vault's MFA implementation.
|
||||||
|
// If optional credentials are provided a single-phase login will be attempted
|
||||||
|
// and the resulting Secret will contain a ClientToken if the authentication is successful.
|
||||||
|
// The client's token will also be set accordingly.
|
||||||
|
//
|
||||||
|
// If no credentials are provided a two-phase MFA login will be assumed and the resulting
|
||||||
|
// Secret will have a MFARequirement containing the MFARequestID to be used in a follow-up
|
||||||
|
// call to `sys/mfa/validate` or by passing it to the method (*Auth).MFAValidate.
|
||||||
|
func (a *Auth) MFALogin(ctx context.Context, authMethod AuthMethod, creds ...string) (*Secret, error) {
|
||||||
|
if len(creds) > 0 {
|
||||||
|
a.c.SetMFACreds(creds)
|
||||||
|
return a.login(ctx, authMethod)
|
||||||
|
}
|
||||||
|
|
||||||
|
return a.twoPhaseMFALogin(ctx, authMethod)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MFAValidate validates an MFA request using the appropriate payload and a secret containing
|
||||||
|
// Auth.MFARequirement, like the one returned by MFALogin when credentials are not provided.
|
||||||
|
// Upon successful validation the client token will be set accordingly.
|
||||||
|
//
|
||||||
|
// The Secret returned is the authentication secret, which if desired can be
|
||||||
|
// passed as input to the NewLifetimeWatcher method in order to start
|
||||||
|
// automatically renewing the token.
|
||||||
|
func (a *Auth) MFAValidate(ctx context.Context, mfaSecret *Secret, payload map[string]interface{}) (*Secret, error) {
|
||||||
|
if mfaSecret == nil || mfaSecret.Auth == nil || mfaSecret.Auth.MFARequirement == nil {
|
||||||
|
return nil, fmt.Errorf("secret does not contain MFARequirements")
|
||||||
|
}
|
||||||
|
|
||||||
|
s, err := a.c.Sys().MFAValidateWithContext(ctx, mfaSecret.Auth.MFARequirement.GetMFARequestID(), payload)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return a.checkAndSetToken(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
// login performs the (*AuthMethod).Login() with the configured client and checks that a ClientToken is returned
|
||||||
|
func (a *Auth) login(ctx context.Context, authMethod AuthMethod) (*Secret, error) {
|
||||||
|
s, err := authMethod.Login(ctx, a.c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("unable to log in to auth method: %w", err)
|
return nil, fmt.Errorf("unable to log in to auth method: %w", err)
|
||||||
}
|
}
|
||||||
if authSecret == nil || authSecret.Auth == nil || authSecret.Auth.ClientToken == "" {
|
|
||||||
return nil, fmt.Errorf("login response from auth method did not return client token")
|
return a.checkAndSetToken(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
// twoPhaseMFALogin performs the (*AuthMethod).Login() with the configured client
|
||||||
|
// and checks that an MFARequirement is returned
|
||||||
|
func (a *Auth) twoPhaseMFALogin(ctx context.Context, authMethod AuthMethod) (*Secret, error) {
|
||||||
|
s, err := authMethod.Login(ctx, a.c)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("unable to log in: %w", err)
|
||||||
|
}
|
||||||
|
if s == nil || s.Auth == nil || s.Auth.MFARequirement == nil {
|
||||||
|
if s != nil {
|
||||||
|
s.Warnings = append(s.Warnings, "expected secret to contain MFARequirements")
|
||||||
|
}
|
||||||
|
return s, fmt.Errorf("assumed two-phase MFA login, returned secret is missing MFARequirements")
|
||||||
}
|
}
|
||||||
|
|
||||||
a.c.SetToken(authSecret.Auth.ClientToken)
|
return s, nil
|
||||||
|
}
|
||||||
return authSecret, nil
|
|
||||||
|
func (a *Auth) checkAndSetToken(s *Secret) (*Secret, error) {
|
||||||
|
if s == nil || s.Auth == nil || s.Auth.ClientToken == "" {
|
||||||
|
if s != nil {
|
||||||
|
s.Warnings = append(s.Warnings, "expected secret to contain ClientToken")
|
||||||
|
}
|
||||||
|
return s, fmt.Errorf("response did not return ClientToken, client token not set")
|
||||||
|
}
|
||||||
|
|
||||||
|
a.c.SetToken(s.Auth.ClientToken)
|
||||||
|
|
||||||
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,130 @@
|
||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/hashicorp/vault/sdk/logical"
|
||||||
|
)
|
||||||
|
|
||||||
|
type mockAuthMethod struct {
|
||||||
|
mockedSecret *Secret
|
||||||
|
mockedError error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockAuthMethod) Login(_ context.Context, _ *Client) (*Secret, error) {
|
||||||
|
return m.mockedSecret, m.mockedError
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuth_Login(t *testing.T) {
|
||||||
|
a := &Auth{
|
||||||
|
c: &Client{},
|
||||||
|
}
|
||||||
|
|
||||||
|
m := mockAuthMethod{
|
||||||
|
mockedSecret: &Secret{
|
||||||
|
Auth: &SecretAuth{
|
||||||
|
ClientToken: "a-client-token",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
mockedError: nil,
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("Login should set token on success", func(t *testing.T) {
|
||||||
|
if a.c.Token() != "" {
|
||||||
|
t.Errorf("client token was %v expected to be unset", a.c.Token())
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := a.Login(context.Background(), &m)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Login() error = %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if a.c.Token() != m.mockedSecret.Auth.ClientToken {
|
||||||
|
t.Errorf("client token was %v expected %v", a.c.Token(), m.mockedSecret.Auth.ClientToken)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuth_MFALoginSinglePhase(t *testing.T) {
|
||||||
|
t.Run("MFALogin() should succeed if credentials are passed in", func(t *testing.T) {
|
||||||
|
a := &Auth{
|
||||||
|
c: &Client{},
|
||||||
|
}
|
||||||
|
|
||||||
|
m := mockAuthMethod{
|
||||||
|
mockedSecret: &Secret{
|
||||||
|
Auth: &SecretAuth{
|
||||||
|
ClientToken: "a-client-token",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
mockedError: nil,
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := a.MFALogin(context.Background(), &m, "testMethod:testPasscode")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("MFALogin() error %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if a.c.Token() != m.mockedSecret.Auth.ClientToken {
|
||||||
|
t.Errorf("client token was %v expected %v", a.c.Token(), m.mockedSecret.Auth.ClientToken)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuth_MFALoginTwoPhase(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
a *Auth
|
||||||
|
m *mockAuthMethod
|
||||||
|
creds *string
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "return MFARequirements",
|
||||||
|
a: &Auth{
|
||||||
|
c: &Client{},
|
||||||
|
},
|
||||||
|
m: &mockAuthMethod{
|
||||||
|
mockedSecret: &Secret{
|
||||||
|
Auth: &SecretAuth{
|
||||||
|
MFARequirement: &logical.MFARequirement{
|
||||||
|
MFARequestID: "a-req-id",
|
||||||
|
MFAConstraints: nil,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "error if no MFARequirements",
|
||||||
|
a: &Auth{
|
||||||
|
c: &Client{},
|
||||||
|
},
|
||||||
|
m: &mockAuthMethod{
|
||||||
|
mockedSecret: &Secret{
|
||||||
|
Auth: &SecretAuth{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
secret, err := tt.a.MFALogin(context.Background(), tt.m)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("MFALogin() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if secret.Auth.MFARequirement != tt.m.mockedSecret.Auth.MFARequirement {
|
||||||
|
t.Errorf("MFALogin() returned %v, expected %v", secret.Auth.MFARequirement, tt.m.mockedSecret.Auth.MFARequirement)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,45 @@
|
||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (c *Sys) MFAValidate(requestID string, payload map[string]interface{}) (*Secret, error) {
|
||||||
|
return c.MFAValidateWithContext(context.Background(), requestID, payload)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Sys) MFAValidateWithContext(ctx context.Context, requestID string, payload map[string]interface{}) (*Secret, error) {
|
||||||
|
ctx, cancelFunc := c.c.withConfiguredTimeout(ctx)
|
||||||
|
defer cancelFunc()
|
||||||
|
|
||||||
|
body := map[string]interface{}{
|
||||||
|
"mfa_request_id": requestID,
|
||||||
|
"mfa_payload": payload,
|
||||||
|
}
|
||||||
|
|
||||||
|
r := c.c.NewRequest(http.MethodPost, fmt.Sprintf("/v1/sys/mfa/validate"))
|
||||||
|
if err := r.SetJSONBody(body); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to set request body: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := c.c.rawRequestWithContext(ctx, r)
|
||||||
|
if resp != nil {
|
||||||
|
defer resp.Body.Close()
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
secret, err := ParseSecret(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse secret from response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if secret == nil {
|
||||||
|
return nil, fmt.Errorf("data from server response is empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
return secret, nil
|
||||||
|
}
|
|
@ -0,0 +1,3 @@
|
||||||
|
```release-note:improvement
|
||||||
|
api: Added MFALogin() for handling MFA flow when using login helpers.
|
||||||
|
```
|
|
@ -259,8 +259,8 @@ func (c *BaseCommand) validateMFA(reqID string, methodInfo MFAMethodInfo) int {
|
||||||
}
|
}
|
||||||
|
|
||||||
// passcode could be an empty string
|
// passcode could be an empty string
|
||||||
mfaPayload := map[string][]string{
|
mfaPayload := map[string]interface{}{
|
||||||
methodInfo.methodID: {passcode},
|
methodInfo.methodID: []string{passcode},
|
||||||
}
|
}
|
||||||
|
|
||||||
client, err := c.Client()
|
client, err := c.Client()
|
||||||
|
@ -269,12 +269,7 @@ func (c *BaseCommand) validateMFA(reqID string, methodInfo MFAMethodInfo) int {
|
||||||
return 2
|
return 2
|
||||||
}
|
}
|
||||||
|
|
||||||
path := "sys/mfa/validate"
|
secret, err := client.Sys().MFAValidate(reqID, mfaPayload)
|
||||||
|
|
||||||
secret, err := client.Logical().Write(path, map[string]interface{}{
|
|
||||||
"mfa_request_id": reqID,
|
|
||||||
"mfa_payload": mfaPayload,
|
|
||||||
})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.UI.Error(err.Error())
|
c.UI.Error(err.Error())
|
||||||
if secret != nil {
|
if secret != nil {
|
||||||
|
@ -285,7 +280,7 @@ func (c *BaseCommand) validateMFA(reqID string, methodInfo MFAMethodInfo) int {
|
||||||
if secret == nil {
|
if secret == nil {
|
||||||
// Don't output anything unless using the "table" format
|
// Don't output anything unless using the "table" format
|
||||||
if Format(c.UI) == "table" {
|
if Format(c.UI) == "table" {
|
||||||
c.UI.Info(fmt.Sprintf("Success! Data written to: %s", path))
|
c.UI.Info("Success! Data written to: sys/mfa/validate")
|
||||||
}
|
}
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package identity
|
package identity
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
@ -241,7 +242,6 @@ func mfaGenerateLoginDUOTest(client *api.Client) error {
|
||||||
return fmt.Errorf("failed to configure MFAEnforcementConfig: %v", err)
|
return fmt.Errorf("failed to configure MFAEnforcementConfig: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
secret, err = client.Logical().Write("auth/userpass/login/vaultmfa", map[string]interface{}{
|
secret, err = client.Logical().Write("auth/userpass/login/vaultmfa", map[string]interface{}{
|
||||||
"password": "testpassword",
|
"password": "testpassword",
|
||||||
})
|
})
|
||||||
|
@ -272,12 +272,11 @@ func mfaGenerateLoginDUOTest(client *api.Client) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// validation
|
// validation
|
||||||
secret, err = client.Logical().Write("sys/mfa/validate", map[string]interface{}{
|
secret, err = client.Sys().MFAValidateWithContext(context.Background(),
|
||||||
"mfa_request_id": secret.Auth.MFARequirement.MFARequestID,
|
secret.Auth.MFARequirement.MFARequestID,
|
||||||
"mfa_payload": map[string][]string{
|
map[string]interface{}{
|
||||||
methodID: {},
|
methodID: []string{},
|
||||||
},
|
})
|
||||||
})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("MFA failed: %v", err)
|
return fmt.Errorf("MFA failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package identity
|
package identity
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
|
@ -322,12 +323,12 @@ func mfaGenerateOktaLoginMFATest(client *api.Client) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// validation
|
// validation
|
||||||
secret, err = client.Logical().Write("sys/mfa/validate", map[string]interface{}{
|
secret, err = client.Sys().MFAValidateWithContext(context.Background(),
|
||||||
"mfa_request_id": secret.Auth.MFARequirement.MFARequestID,
|
secret.Auth.MFARequirement.MFARequestID,
|
||||||
"mfa_payload": map[string][]string{
|
map[string]interface{}{
|
||||||
methodID: {},
|
methodID: []string{},
|
||||||
},
|
},
|
||||||
})
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("MFA failed: %v", err)
|
return fmt.Errorf("MFA failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,6 +7,8 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
upAuth "github.com/hashicorp/vault/api/auth/userpass"
|
||||||
|
|
||||||
"github.com/hashicorp/vault/api"
|
"github.com/hashicorp/vault/api"
|
||||||
"github.com/hashicorp/vault/audit"
|
"github.com/hashicorp/vault/audit"
|
||||||
"github.com/hashicorp/vault/builtin/credential/userpass"
|
"github.com/hashicorp/vault/builtin/credential/userpass"
|
||||||
|
@ -76,21 +78,27 @@ func doTwoPhaseLogin(client *api.Client, totpCodePath, methodID, username string
|
||||||
}
|
}
|
||||||
totpPasscode := totpResp.Data["code"].(string)
|
totpPasscode := totpResp.Data["code"].(string)
|
||||||
|
|
||||||
secret, err := client.Logical().WriteWithContext(context.Background(), fmt.Sprintf("auth/userpass/login/%s", username), map[string]interface{}{
|
upMethod, err := upAuth.NewUserpassAuth(username, &upAuth.Password{FromString: "testpassword"})
|
||||||
"password": "testpassword",
|
|
||||||
})
|
mfaSecret, err := client.Auth().MFALogin(context.Background(), upMethod)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("first phase of login MFA failed: %v", err)
|
t.Fatalf("failed to login with userpass auth method: %v", err)
|
||||||
}
|
}
|
||||||
secret, err = client.Logical().WriteWithContext(context.Background(), "sys/mfa/validate", map[string]interface{}{
|
|
||||||
"mfa_request_id": secret.Auth.MFARequirement.MFARequestID,
|
secret, err := client.Auth().MFAValidate(
|
||||||
"mfa_payload": map[string][]string{
|
context.Background(),
|
||||||
methodID: {totpPasscode},
|
mfaSecret,
|
||||||
|
map[string]interface{}{
|
||||||
|
methodID: []string{totpPasscode},
|
||||||
},
|
},
|
||||||
})
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("MFA validation failed: %v", err)
|
t.Fatalf("MFA validation failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if secret == nil || secret.Auth == nil || secret.Auth.ClientToken == "" {
|
||||||
|
t.Fatalf("MFA validation failed to return a ClientToken in secret: %v", secret)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestLoginMfaGenerateTOTPTestAuditIncluded(t *testing.T) {
|
func TestLoginMfaGenerateTOTPTestAuditIncluded(t *testing.T) {
|
||||||
|
|
Loading…
Reference in New Issue