named Login MFA methods (#18610)

* named MFA method configurations

* fix a test

* CL

* fix an issue with same config name different ID and add a test

* feedback

* feedback on test

* consistent use of passcode for all MFA methods (#18611)

* make use of passcode factor consistent for all MFA types

* improved type for MFA factors

* add method name to login CLI

* minor refactoring

* only accept MFA method name with its namespace path in the login request MFA header

* fix a bug

* fixing an ErrorOrNil return value

* more informative error message

* Apply suggestions from code review

Co-authored-by: Nick Cabatoff <ncabatoff@hashicorp.com>

* feedback

* test refactor a bit

* adding godoc for a test

* feedback

* remove sanitize method name

* guard a possbile nil ref

Co-authored-by: Nick Cabatoff <ncabatoff@hashicorp.com>
This commit is contained in:
Hamid Ghaf 2023-01-23 15:51:22 -05:00 committed by GitHub
parent 43f679bf05
commit 65a41d4f08
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 1008 additions and 516 deletions

4
changelog/18610.txt Normal file
View File

@ -0,0 +1,4 @@
```release-note:improvement
auth: Allow naming login MFA methods and using those names instead of IDs in satisfying MFA requirement for requests.
Make passcode arguments consistent across login MFA method types.
```

View File

@ -557,6 +557,9 @@ func (t TableFormatter) OutputSecret(ui cli.Ui, secret *api.Secret) error {
for _, constraint := range constraintSet.Any { for _, constraint := range constraintSet.Any {
out = append(out, fmt.Sprintf("mfa_constraint_%s_%s_id %s %s", k, constraint.Type, hopeDelim, constraint.ID)) out = append(out, fmt.Sprintf("mfa_constraint_%s_%s_id %s %s", k, constraint.Type, hopeDelim, constraint.ID))
out = append(out, fmt.Sprintf("mfa_constraint_%s_%s_uses_passcode %s %t", k, constraint.Type, hopeDelim, constraint.UsesPasscode)) out = append(out, fmt.Sprintf("mfa_constraint_%s_%s_uses_passcode %s %t", k, constraint.Type, hopeDelim, constraint.UsesPasscode))
if constraint.Name != "" {
out = append(out, fmt.Sprintf("mfa_constraint_%s_%s_name %s %s", k, constraint.Type, hopeDelim, constraint.Name))
}
} }
} }
} else { // Token information only makes sense if no further MFA requirement (i.e. if we actually have a token) } else { // Token information only makes sense if no further MFA requirement (i.e. if we actually have a token)

View File

@ -1,8 +1,11 @@
package command package command
import ( import (
"context"
"regexp"
"strings" "strings"
"testing" "testing"
"time"
"github.com/mitchellh/cli" "github.com/mitchellh/cli"
@ -37,10 +40,7 @@ func testLoginCommand(tb testing.TB) (*cli.MockUi, *LoginCommand) {
} }
} }
func TestLoginCommand_Run(t *testing.T) { func TestCustomPath(t *testing.T) {
t.Parallel()
t.Run("custom_path", func(t *testing.T) {
t.Parallel() t.Parallel()
client, closer := testVaultServer(t) client, closer := testVaultServer(t)
@ -91,9 +91,10 @@ func TestLoginCommand_Run(t *testing.T) {
if l, exp := len(storedToken), minTokenLengthExternal+vault.TokenPrefixLength; l < exp { if l, exp := len(storedToken), minTokenLengthExternal+vault.TokenPrefixLength; l < exp {
t.Errorf("expected token to be %d characters, was %d: %q", exp, l, storedToken) t.Errorf("expected token to be %d characters, was %d: %q", exp, l, storedToken)
} }
}) }
t.Run("no_store", func(t *testing.T) { // Do not persist the token to the token helper
func TestNoStore(t *testing.T) {
t.Parallel() t.Parallel()
client, closer := testVaultServer(t) client, closer := testVaultServer(t)
@ -137,9 +138,9 @@ func TestLoginCommand_Run(t *testing.T) {
if exp := ""; storedToken != exp { if exp := ""; storedToken != exp {
t.Errorf("expected %q to be %q", storedToken, exp) t.Errorf("expected %q to be %q", storedToken, exp)
} }
}) }
t.Run("stores", func(t *testing.T) { func TestStores(t *testing.T) {
t.Parallel() t.Parallel()
client, closer := testVaultServer(t) client, closer := testVaultServer(t)
@ -177,9 +178,9 @@ func TestLoginCommand_Run(t *testing.T) {
if storedToken != token { if storedToken != token {
t.Errorf("expected %q to be %q", storedToken, token) t.Errorf("expected %q to be %q", storedToken, token)
} }
}) }
t.Run("token_only", func(t *testing.T) { func TestTokenOnly(t *testing.T) {
t.Parallel() t.Parallel()
client, closer := testVaultServer(t) client, closer := testVaultServer(t)
@ -223,9 +224,9 @@ func TestLoginCommand_Run(t *testing.T) {
if storedToken, err := tokenHelper.Get(); err != nil || storedToken != "" { if storedToken, err := tokenHelper.Get(); err != nil || storedToken != "" {
t.Fatalf("expected token to not be stored: %s: %q", err, storedToken) t.Fatalf("expected token to not be stored: %s: %q", err, storedToken)
} }
}) }
t.Run("failure_no_store", func(t *testing.T) { func TestFailureNoStore(t *testing.T) {
t.Parallel() t.Parallel()
client, closer := testVaultServer(t) client, closer := testVaultServer(t)
@ -255,9 +256,9 @@ func TestLoginCommand_Run(t *testing.T) {
if storedToken, err := tokenHelper.Get(); err != nil || storedToken != "" { if storedToken, err := tokenHelper.Get(); err != nil || storedToken != "" {
t.Fatalf("expected token to not be stored: %s: %q", err, storedToken) t.Fatalf("expected token to not be stored: %s: %q", err, storedToken)
} }
}) }
t.Run("wrap_auto_unwrap", func(t *testing.T) { func TestWrapAutoUnwrap(t *testing.T) {
t.Parallel() t.Parallel()
client, closer := testVaultServer(t) client, closer := testVaultServer(t)
@ -314,9 +315,9 @@ func TestLoginCommand_Run(t *testing.T) {
if secret.WrapInfo != nil { if secret.WrapInfo != nil {
t.Errorf("expected to be unwrapped: %#v", secret) t.Errorf("expected to be unwrapped: %#v", secret)
} }
}) }
t.Run("wrap_token_only", func(t *testing.T) { func TestWrapTokenOnly(t *testing.T) {
t.Parallel() t.Parallel()
client, closer := testVaultServer(t) client, closer := testVaultServer(t)
@ -375,9 +376,9 @@ func TestLoginCommand_Run(t *testing.T) {
if secret == nil || secret.Auth == nil || secret.Auth.ClientToken == "" { if secret == nil || secret.Auth == nil || secret.Auth.ClientToken == "" {
t.Fatalf("expected secret to have auth: %#v", secret) t.Fatalf("expected secret to have auth: %#v", secret)
} }
}) }
t.Run("wrap_no_store", func(t *testing.T) { func TestWrapNoStore(t *testing.T) {
t.Parallel() t.Parallel()
client, closer := testVaultServer(t) client, closer := testVaultServer(t)
@ -427,94 +428,9 @@ func TestLoginCommand_Run(t *testing.T) {
if !strings.Contains(output, expected) { if !strings.Contains(output, expected) {
t.Errorf("expected %q to contain %q", output, expected) t.Errorf("expected %q to contain %q", output, expected)
} }
})
t.Run("login_mfa_single_phase", func(t *testing.T) {
t.Parallel()
client, closer := testVaultServer(t)
defer closer()
ui, cmd := testLoginCommand(t)
userclient, entityID, methodID := testhelpers.SetupLoginMFATOTP(t, client)
cmd.client = userclient
enginePath := testhelpers.RegisterEntityInTOTPEngine(t, client, entityID, methodID)
totpCode := testhelpers.GetTOTPCodeFromEngine(t, client, enginePath)
// login command bails early for test clients, so we have to explicitly set this
cmd.client.SetMFACreds([]string{methodID + ":" + totpCode})
code := cmd.Run([]string{
"-method", "userpass",
"username=testuser1",
"password=testpassword",
})
if exp := 0; code != exp {
t.Errorf("expected %d to be %d", code, exp)
} }
tokenHelper, err := cmd.TokenHelper() func TestCommunicationFailure(t *testing.T) {
if err != nil {
t.Fatal(err)
}
storedToken, err := tokenHelper.Get()
if err != nil {
t.Fatal(err)
}
output = ui.OutputWriter.String() + ui.ErrorWriter.String()
t.Logf("\n%+v", output)
if !strings.Contains(output, storedToken) {
t.Fatalf("expected stored token: %q, got: %q", storedToken, output)
}
})
t.Run("login_mfa_two_phase", func(t *testing.T) {
t.Parallel()
client, closer := testVaultServer(t)
defer closer()
ui, cmd := testLoginCommand(t)
userclient, entityID, methodID := testhelpers.SetupLoginMFATOTP(t, client)
cmd.client = userclient
_ = testhelpers.RegisterEntityInTOTPEngine(t, client, entityID, methodID)
// clear the MFA creds just to be sure
cmd.client.SetMFACreds([]string{})
code := cmd.Run([]string{
"-method", "userpass",
"username=testuser1",
"password=testpassword",
})
if exp := 0; code != exp {
t.Errorf("expected %d to be %d", code, exp)
}
expected := methodID
output = ui.OutputWriter.String() + ui.ErrorWriter.String()
t.Logf("\n%+v", output)
if !strings.Contains(output, expected) {
t.Fatalf("expected stored token: %q, got: %q", expected, output)
}
tokenHelper, err := cmd.TokenHelper()
if err != nil {
t.Fatal(err)
}
storedToken, err := tokenHelper.Get()
if storedToken != "" {
t.Fatal("expected empty stored token")
}
if err != nil {
t.Fatal(err)
}
})
t.Run("communication_failure", func(t *testing.T) {
t.Parallel() t.Parallel()
client, closer := testVaultServerBad(t) client, closer := testVaultServerBad(t)
@ -535,12 +451,163 @@ func TestLoginCommand_Run(t *testing.T) {
if !strings.Contains(combined, expected) { if !strings.Contains(combined, expected) {
t.Errorf("expected %q to contain %q", combined, expected) t.Errorf("expected %q to contain %q", combined, expected)
} }
}) }
t.Run("no_tabs", func(t *testing.T) { func TestNoTabs(t *testing.T) {
t.Parallel() t.Parallel()
_, cmd := testLoginCommand(t) _, cmd := testLoginCommand(t)
assertNoTabs(t, cmd) assertNoTabs(t, cmd)
}) }
func TestLoginMFASinglePhase(t *testing.T) {
t.Parallel()
client, closer := testVaultServer(t)
defer closer()
methodName := "foo"
waitPeriod := 5
userClient, entityID, methodID := testhelpers.SetupLoginMFATOTP(t, client, methodName, waitPeriod)
enginePath := testhelpers.RegisterEntityInTOTPEngine(t, client, entityID, methodID)
runCommand := func(methodIdentifier string) {
// the time required for the totp engine to generate a new code
time.Sleep(time.Duration(waitPeriod) * time.Second)
totpCode := testhelpers.GetTOTPCodeFromEngine(t, client, enginePath)
ui, cmd := testLoginCommand(t)
cmd.client = userClient
// login command bails early for test clients, so we have to explicitly set this
cmd.client.SetMFACreds([]string{methodIdentifier + ":" + totpCode})
code := cmd.Run([]string{
"-method", "userpass",
"username=testuser1",
"password=testpassword",
})
if exp := 0; code != exp {
t.Errorf("expected %d to be %d", code, exp)
}
tokenHelper, err := cmd.TokenHelper()
if err != nil {
t.Fatal(err)
}
storedToken, err := tokenHelper.Get()
if err != nil {
t.Fatal(err)
}
if storedToken == "" {
t.Fatal("expected non-empty stored token")
}
output = ui.OutputWriter.String()
if !strings.Contains(output, storedToken) {
t.Fatalf("expected stored token: %q, got: %q", storedToken, output)
}
}
runCommand(methodID)
runCommand(methodName)
}
func TestLoginMFATwoPhase(t *testing.T) {
t.Parallel()
client, closer := testVaultServer(t)
defer closer()
ui, cmd := testLoginCommand(t)
userclient, entityID, methodID := testhelpers.SetupLoginMFATOTP(t, client, "", 5)
cmd.client = userclient
_ = testhelpers.RegisterEntityInTOTPEngine(t, client, entityID, methodID)
// clear the MFA creds just to be sure
cmd.client.SetMFACreds([]string{})
code := cmd.Run([]string{
"-method", "userpass",
"username=testuser1",
"password=testpassword",
})
if exp := 0; code != exp {
t.Errorf("expected %d to be %d", code, exp)
}
expected := methodID
output = ui.OutputWriter.String()
if !strings.Contains(output, expected) {
t.Fatalf("expected stored token: %q, got: %q", expected, output)
}
tokenHelper, err := cmd.TokenHelper()
if err != nil {
t.Fatal(err)
}
storedToken, err := tokenHelper.Get()
if storedToken != "" {
t.Fatal("expected empty stored token")
}
if err != nil {
t.Fatal(err)
}
}
func TestLoginMFATwoPhaseNonInteractiveMethodName(t *testing.T) {
t.Parallel()
client, closer := testVaultServer(t)
defer closer()
ui, cmd := testLoginCommand(t)
methodName := "foo"
waitPeriod := 5
userclient, entityID, methodID := testhelpers.SetupLoginMFATOTP(t, client, methodName, waitPeriod)
cmd.client = userclient
engineName := testhelpers.RegisterEntityInTOTPEngine(t, client, entityID, methodID)
// clear the MFA creds just to be sure
cmd.client.SetMFACreds([]string{})
code := cmd.Run([]string{
"-method", "userpass",
"-non-interactive",
"username=testuser1",
"password=testpassword",
})
if exp := 0; code != exp {
t.Errorf("expected %d to be %d", code, exp)
}
output = ui.OutputWriter.String()
reqIdReg := regexp.MustCompile(`mfa_request_id\s+([0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12})\s+mfa_constraint`)
reqIDRaw := reqIdReg.FindAllStringSubmatch(output, -1)
if len(reqIDRaw) == 0 || len(reqIDRaw[0]) < 2 {
t.Fatal("failed to MFA request ID from output")
}
mfaReqID := reqIDRaw[0][1]
validateFunc := func(methodIdentifier string) {
// the time required for the totp engine to generate a new code
time.Sleep(time.Duration(waitPeriod) * time.Second)
totpPasscode1 := "passcode=" + testhelpers.GetTOTPCodeFromEngine(t, client, engineName)
secret, err := cmd.client.Logical().WriteWithContext(context.Background(), "sys/mfa/validate", map[string]interface{}{
"mfa_request_id": mfaReqID,
"mfa_payload": map[string][]string{
methodIdentifier: {totpPasscode1},
},
})
if err != nil {
t.Fatalf("mfa validation failed: %v", err)
}
if secret.Auth == nil || secret.Auth.ClientToken == "" {
t.Fatalf("mfa validation did not return a client token")
}
}
validateFunc(methodName)
} }

View File

@ -33,7 +33,7 @@ const (
GenerateRecovery GenerateRecovery
) )
// Generates a root token on the target cluster. // GenerateRoot generates a root token on the target cluster.
func GenerateRoot(t testing.T, cluster *vault.TestCluster, kind GenerateRootKind) string { func GenerateRoot(t testing.T, cluster *vault.TestCluster, kind GenerateRootKind) string {
t.Helper() t.Helper()
token, err := GenerateRootWithError(t, cluster, kind) token, err := GenerateRootWithError(t, cluster, kind)
@ -767,6 +767,21 @@ func SetNonRootToken(client *api.Client) error {
return nil return nil
} }
// RetryUntilAtCadence runs f until it returns a nil result or the timeout is reached.
// If a nil result hasn't been obtained by timeout, calls t.Fatal.
func RetryUntilAtCadence(t testing.T, timeout, sleepTime time.Duration, f func() error) {
t.Helper()
deadline := time.Now().Add(timeout)
var err error
for time.Now().Before(deadline) {
if err = f(); err == nil {
return
}
time.Sleep(sleepTime)
}
t.Fatalf("did not complete before deadline, err: %v", err)
}
// RetryUntil runs f until it returns a nil result or the timeout is reached. // RetryUntil runs f until it returns a nil result or the timeout is reached.
// If a nil result hasn't been obtained by timeout, calls t.Fatal. // If a nil result hasn't been obtained by timeout, calls t.Fatal.
func RetryUntil(t testing.T, timeout time.Duration, f func() error) { func RetryUntil(t testing.T, timeout time.Duration, f func() error) {
@ -942,7 +957,7 @@ func GetTOTPCodeFromEngine(t testing.T, client *api.Client, enginePath string) s
// SetupLoginMFATOTP setups up a TOTP MFA using some basic configuration and // SetupLoginMFATOTP setups up a TOTP MFA using some basic configuration and
// returns all relevant information to the client. // returns all relevant information to the client.
func SetupLoginMFATOTP(t testing.T, client *api.Client) (*api.Client, string, string) { func SetupLoginMFATOTP(t testing.T, client *api.Client, methodName string, waitPeriod int) (*api.Client, string, string) {
t.Helper() t.Helper()
// Mount the totp secrets engine // Mount the totp secrets engine
SetupTOTPMount(t, client) SetupTOTPMount(t, client)
@ -956,13 +971,14 @@ func SetupLoginMFATOTP(t testing.T, client *api.Client) (*api.Client, string, st
// Configure a default TOTP method // Configure a default TOTP method
totpConfig := map[string]interface{}{ totpConfig := map[string]interface{}{
"issuer": "yCorp", "issuer": "yCorp",
"period": 20, "period": waitPeriod,
"algorithm": "SHA256", "algorithm": "SHA256",
"digits": 6, "digits": 6,
"skew": 1, "skew": 1,
"key_size": 20, "key_size": 20,
"qr_size": 200, "qr_size": 200,
"max_validation_attempts": 5, "max_validation_attempts": 5,
"method_name": methodName,
} }
methodID := SetupTOTPMethod(t, client, totpConfig) methodID := SetupTOTPMethod(t, client, totpConfig)

View File

@ -318,6 +318,7 @@ type MFAMethodID struct {
Type string `protobuf:"bytes,1,opt,name=type,proto3" json:"type,omitempty"` Type string `protobuf:"bytes,1,opt,name=type,proto3" json:"type,omitempty"`
ID string `protobuf:"bytes,2,opt,name=id,proto3" json:"id,omitempty"` ID string `protobuf:"bytes,2,opt,name=id,proto3" json:"id,omitempty"`
UsesPasscode bool `protobuf:"varint,3,opt,name=uses_passcode,json=usesPasscode,proto3" json:"uses_passcode,omitempty"` UsesPasscode bool `protobuf:"varint,3,opt,name=uses_passcode,json=usesPasscode,proto3" json:"uses_passcode,omitempty"`
Name string `protobuf:"bytes,4,opt,name=name,proto3" json:"name,omitempty"`
} }
func (x *MFAMethodID) Reset() { func (x *MFAMethodID) Reset() {
@ -373,6 +374,13 @@ func (x *MFAMethodID) GetUsesPasscode() bool {
return false return false
} }
func (x *MFAMethodID) GetName() string {
if x != nil {
return x.Name
}
return ""
}
type MFAConstraintAny struct { type MFAConstraintAny struct {
state protoimpl.MessageState state protoimpl.MessageState
sizeCache protoimpl.SizeCache sizeCache protoimpl.SizeCache
@ -537,34 +545,35 @@ var file_sdk_logical_identity_proto_rawDesc = []byte{
0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01,
0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65,
0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38,
0x01, 0x22, 0x56, 0x0a, 0x0b, 0x4d, 0x46, 0x41, 0x4d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x49, 0x44, 0x01, 0x22, 0x6a, 0x0a, 0x0b, 0x4d, 0x46, 0x41, 0x4d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x49, 0x44,
0x12, 0x12, 0x0a, 0x04, 0x74, 0x79, 0x70, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x12, 0x12, 0x0a, 0x04, 0x74, 0x79, 0x70, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04,
0x74, 0x79, 0x70, 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x74, 0x79, 0x70, 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09,
0x52, 0x02, 0x69, 0x64, 0x12, 0x23, 0x0a, 0x0d, 0x75, 0x73, 0x65, 0x73, 0x5f, 0x70, 0x61, 0x73, 0x52, 0x02, 0x69, 0x64, 0x12, 0x23, 0x0a, 0x0d, 0x75, 0x73, 0x65, 0x73, 0x5f, 0x70, 0x61, 0x73,
0x73, 0x63, 0x6f, 0x64, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0c, 0x75, 0x73, 0x65, 0x73, 0x63, 0x6f, 0x64, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0c, 0x75, 0x73, 0x65,
0x73, 0x50, 0x61, 0x73, 0x73, 0x63, 0x6f, 0x64, 0x65, 0x22, 0x3a, 0x0a, 0x10, 0x4d, 0x46, 0x41, 0x73, 0x50, 0x61, 0x73, 0x73, 0x63, 0x6f, 0x64, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x6e, 0x61, 0x6d,
0x43, 0x6f, 0x6e, 0x73, 0x74, 0x72, 0x61, 0x69, 0x6e, 0x74, 0x41, 0x6e, 0x79, 0x12, 0x26, 0x0a, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x22, 0x3a, 0x0a,
0x03, 0x61, 0x6e, 0x79, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6c, 0x6f, 0x67, 0x10, 0x4d, 0x46, 0x41, 0x43, 0x6f, 0x6e, 0x73, 0x74, 0x72, 0x61, 0x69, 0x6e, 0x74, 0x41, 0x6e,
0x69, 0x63, 0x61, 0x6c, 0x2e, 0x4d, 0x46, 0x41, 0x4d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x49, 0x44, 0x79, 0x12, 0x26, 0x0a, 0x03, 0x61, 0x6e, 0x79, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x14,
0x52, 0x03, 0x61, 0x6e, 0x79, 0x22, 0xea, 0x01, 0x0a, 0x0e, 0x4d, 0x46, 0x41, 0x52, 0x65, 0x71, 0x2e, 0x6c, 0x6f, 0x67, 0x69, 0x63, 0x61, 0x6c, 0x2e, 0x4d, 0x46, 0x41, 0x4d, 0x65, 0x74, 0x68,
0x75, 0x69, 0x72, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x12, 0x24, 0x0a, 0x0e, 0x6d, 0x66, 0x61, 0x5f, 0x6f, 0x64, 0x49, 0x44, 0x52, 0x03, 0x61, 0x6e, 0x79, 0x22, 0xea, 0x01, 0x0a, 0x0e, 0x4d, 0x46,
0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x41, 0x52, 0x65, 0x71, 0x75, 0x69, 0x72, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x12, 0x24, 0x0a, 0x0e,
0x52, 0x0c, 0x6d, 0x66, 0x61, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x49, 0x64, 0x12, 0x54, 0x6d, 0x66, 0x61, 0x5f, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x01,
0x0a, 0x0f, 0x6d, 0x66, 0x61, 0x5f, 0x63, 0x6f, 0x6e, 0x73, 0x74, 0x72, 0x61, 0x69, 0x6e, 0x74, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x6d, 0x66, 0x61, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74,
0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x2b, 0x2e, 0x6c, 0x6f, 0x67, 0x69, 0x63, 0x61, 0x49, 0x64, 0x12, 0x54, 0x0a, 0x0f, 0x6d, 0x66, 0x61, 0x5f, 0x63, 0x6f, 0x6e, 0x73, 0x74, 0x72,
0x6c, 0x2e, 0x4d, 0x46, 0x41, 0x52, 0x65, 0x71, 0x75, 0x69, 0x72, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x61, 0x69, 0x6e, 0x74, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x2b, 0x2e, 0x6c, 0x6f,
0x2e, 0x4d, 0x66, 0x61, 0x43, 0x6f, 0x6e, 0x73, 0x74, 0x72, 0x61, 0x69, 0x6e, 0x74, 0x73, 0x45, 0x67, 0x69, 0x63, 0x61, 0x6c, 0x2e, 0x4d, 0x46, 0x41, 0x52, 0x65, 0x71, 0x75, 0x69, 0x72, 0x65,
0x6e, 0x74, 0x72, 0x79, 0x52, 0x0e, 0x6d, 0x66, 0x61, 0x43, 0x6f, 0x6e, 0x73, 0x74, 0x72, 0x61, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4d, 0x66, 0x61, 0x43, 0x6f, 0x6e, 0x73, 0x74, 0x72, 0x61, 0x69,
0x69, 0x6e, 0x74, 0x73, 0x1a, 0x5c, 0x0a, 0x13, 0x4d, 0x66, 0x61, 0x43, 0x6f, 0x6e, 0x73, 0x74, 0x6e, 0x74, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, 0x0e, 0x6d, 0x66, 0x61, 0x43, 0x6f, 0x6e,
0x72, 0x61, 0x69, 0x6e, 0x74, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x73, 0x74, 0x72, 0x61, 0x69, 0x6e, 0x74, 0x73, 0x1a, 0x5c, 0x0a, 0x13, 0x4d, 0x66, 0x61, 0x43,
0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x2f, 0x0a, 0x6f, 0x6e, 0x73, 0x74, 0x72, 0x61, 0x69, 0x6e, 0x74, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12,
0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e, 0x6c, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65,
0x6f, 0x67, 0x69, 0x63, 0x61, 0x6c, 0x2e, 0x4d, 0x46, 0x41, 0x43, 0x6f, 0x6e, 0x73, 0x74, 0x72, 0x79, 0x12, 0x2f, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b,
0x61, 0x69, 0x6e, 0x74, 0x41, 0x6e, 0x79, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x32, 0x19, 0x2e, 0x6c, 0x6f, 0x67, 0x69, 0x63, 0x61, 0x6c, 0x2e, 0x4d, 0x46, 0x41, 0x43, 0x6f,
0x38, 0x01, 0x42, 0x28, 0x5a, 0x26, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x6e, 0x73, 0x74, 0x72, 0x61, 0x69, 0x6e, 0x74, 0x41, 0x6e, 0x79, 0x52, 0x05, 0x76, 0x61, 0x6c,
0x2f, 0x68, 0x61, 0x73, 0x68, 0x69, 0x63, 0x6f, 0x72, 0x70, 0x2f, 0x76, 0x61, 0x75, 0x6c, 0x74, 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x42, 0x28, 0x5a, 0x26, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62,
0x2f, 0x73, 0x64, 0x6b, 0x2f, 0x6c, 0x6f, 0x67, 0x69, 0x63, 0x61, 0x6c, 0x62, 0x06, 0x70, 0x72, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x68, 0x61, 0x73, 0x68, 0x69, 0x63, 0x6f, 0x72, 0x70, 0x2f, 0x76,
0x6f, 0x74, 0x6f, 0x33, 0x61, 0x75, 0x6c, 0x74, 0x2f, 0x73, 0x64, 0x6b, 0x2f, 0x6c, 0x6f, 0x67, 0x69, 0x63, 0x61, 0x6c,
0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
} }
var ( var (

View File

@ -79,6 +79,7 @@ message MFAMethodID {
string type = 1; string type = 1;
string id = 2; string id = 2;
bool uses_passcode = 3; bool uses_passcode = 3;
string name = 4;
} }
message MFAConstraintAny { message MFAConstraintAny {

View File

@ -93,16 +93,17 @@ func TestLoginMfaGenerateTOTPTestAuditIncluded(t *testing.T) {
// Creating two users in the userpass auth mount // Creating two users in the userpass auth mount
userClient1, entityID1, _ := testhelpers.CreateEntityAndAlias(t, client, mountAccessor, entity1, testuser1) userClient1, entityID1, _ := testhelpers.CreateEntityAndAlias(t, client, mountAccessor, entity1, testuser1)
userClient2, entityID2, _ := testhelpers.CreateEntityAndAlias(t, client, mountAccessor, entity2, testuser2) userClient2, entityID2, _ := testhelpers.CreateEntityAndAlias(t, client, mountAccessor, entity2, testuser2)
waitPeriod := 5
totpConfig := map[string]interface{}{ totpConfig := map[string]interface{}{
"issuer": "yCorp", "issuer": "yCorp",
"period": 5, "period": waitPeriod,
"algorithm": "SHA1", "algorithm": "SHA1",
"digits": 6, "digits": 6,
"skew": 1, "skew": 1,
"key_size": 10, "key_size": 10,
"qr_size": 100, "qr_size": 100,
"max_validation_attempts": 3, "max_validation_attempts": 3,
"method_name": "foo",
} }
methodID := testhelpers.SetupTOTPMethod(t, client, totpConfig) methodID := testhelpers.SetupTOTPMethod(t, client, totpConfig)
@ -123,22 +124,7 @@ func TestLoginMfaGenerateTOTPTestAuditIncluded(t *testing.T) {
userpassPath := fmt.Sprintf("auth/userpass/login/%s", testuser1) userpassPath := fmt.Sprintf("auth/userpass/login/%s", testuser1)
// MFA single-phase login // MFA single-phase login
time.Sleep(5 * time.Second) verifyLoginRequest := func(secret *api.Secret) {
var secret *api.Secret
testhelpers.RetryUntil(t, 20*time.Second, func() error {
var err error
totpPasscode := testhelpers.GetTOTPCodeFromEngine(t, client, enginePath1)
userClient1.AddHeader("X-Vault-MFA", fmt.Sprintf("%s:%s", methodID, totpPasscode))
secret, err = userClient1.Logical().WriteWithContext(context.Background(), userpassPath, map[string]interface{}{
"password": "testpassword",
})
if err != nil {
return fmt.Errorf("MFA failed: %w", err)
}
return nil
})
userpassToken := secret.Auth.ClientToken userpassToken := secret.Auth.ClientToken
userClient1.SetToken(client.Token()) userClient1.SetToken(client.Token())
secret, err := userClient1.Logical().WriteWithContext(context.Background(), "auth/token/lookup", map[string]interface{}{ secret, err := userClient1.Logical().WriteWithContext(context.Background(), "auth/token/lookup", map[string]interface{}{
@ -152,11 +138,44 @@ func TestLoginMfaGenerateTOTPTestAuditIncluded(t *testing.T) {
if entityIDCheck != entityID1 { if entityIDCheck != entityID1 {
t.Fatalf("different entityID assigned") t.Fatalf("different entityID assigned")
} }
}
// helper function to clear the MFA request header
clearMFARequestHeaders := func(c *api.Client) {
headers := c.Headers()
headers.Del("X-Vault-MFA")
c.SetHeaders(headers)
}
var secret *api.Secret
var err error
var methodIdentifier string
singlePhaseLoginFunc := func() error {
totpPasscode := testhelpers.GetTOTPCodeFromEngine(t, client, enginePath1)
userClient1.AddHeader("X-Vault-MFA", fmt.Sprintf("%s:%s", methodIdentifier, totpPasscode))
defer clearMFARequestHeaders(userClient1)
secret, err = userClient1.Logical().WriteWithContext(context.Background(), userpassPath, map[string]interface{}{
"password": "testpassword",
})
if err != nil {
return fmt.Errorf("MFA failed for identifier %s: %v", methodIdentifier, err)
}
return nil
}
// single phase login for both method name and method ID
methodIdentifier = totpConfig["method_name"].(string)
testhelpers.RetryUntilAtCadence(t, 20*time.Second, 100*time.Millisecond, singlePhaseLoginFunc)
verifyLoginRequest(secret)
methodIdentifier = methodID
// Need to wait a bit longer to avoid hitting maximum allowed consecutive
// failed TOTP validation
testhelpers.RetryUntilAtCadence(t, 20*time.Second, time.Duration(waitPeriod)*time.Second, singlePhaseLoginFunc)
verifyLoginRequest(secret)
// Two-phase login // Two-phase login
headers := userClient1.Headers()
headers.Del("X-Vault-MFA")
userClient1.SetHeaders(headers)
secret, err = userClient1.Logical().WriteWithContext(context.Background(), userpassPath, map[string]interface{}{ secret, err = userClient1.Logical().WriteWithContext(context.Background(), userpassPath, map[string]interface{}{
"password": "testpassword", "password": "testpassword",
}) })
@ -191,26 +210,43 @@ func TestLoginMfaGenerateTOTPTestAuditIncluded(t *testing.T) {
} }
// validation // validation
time.Sleep(5 * time.Second) var mfaReqID string
var totpPasscode1 string var totpPasscode1 string
testhelpers.RetryUntil(t, 20*time.Second, func() error { mfaValidateFunc := func() error {
totpPasscode1 = testhelpers.GetTOTPCodeFromEngine(t, client, enginePath1) totpPasscode1 = testhelpers.GetTOTPCodeFromEngine(t, client, enginePath1)
secret, err = userClient1.Logical().WriteWithContext(context.Background(), "sys/mfa/validate", map[string]interface{}{ secret, err = userClient1.Logical().WriteWithContext(context.Background(), "sys/mfa/validate", map[string]interface{}{
"mfa_request_id": secret.Auth.MFARequirement.MFARequestID, "mfa_request_id": mfaReqID,
"mfa_payload": map[string][]string{ "mfa_payload": map[string][]string{
methodID: {totpPasscode1}, methodIdentifier: {totpPasscode1},
}, },
}) })
if err != nil { if err != nil {
return fmt.Errorf("MFA failed: %w", err) return fmt.Errorf("MFA failed: %v", err)
} }
return nil
})
if secret.Auth == nil || secret.Auth.ClientToken == "" { if secret.Auth == nil || secret.Auth.ClientToken == "" {
t.Fatalf("successful mfa validation did not return a client token") t.Fatalf("successful mfa validation did not return a client token")
} }
return nil
}
methodIdentifier = methodID
mfaReqID = secret.Auth.MFARequirement.MFARequestID
testhelpers.RetryUntilAtCadence(t, 20*time.Second, time.Duration(waitPeriod)*time.Second, mfaValidateFunc)
// two phase login with method name
secret, err = userClient1.Logical().WriteWithContext(context.Background(), userpassPath, map[string]interface{}{
"password": "testpassword",
})
if err != nil {
t.Fatalf("MFA failed: %v", err)
}
methodIdentifier = totpConfig["method_name"].(string)
mfaReqID = secret.Auth.MFARequirement.MFARequestID
testhelpers.RetryUntilAtCadence(t, 20*time.Second, time.Duration(waitPeriod)*time.Second, mfaValidateFunc)
// checking audit log
if noop.Req == nil { if noop.Req == nil {
t.Fatalf("no request was logged in audit log") t.Fatalf("no request was logged in audit log")
} }

View File

@ -13,7 +13,7 @@ import (
"github.com/hashicorp/vault/vault" "github.com/hashicorp/vault/vault"
) )
// TestLoginMFA_Method_CRUD tests creating/reading/updating/deleting a method config for all of the MFA providers // TestLoginMFA_Method_CRUD tests creating/reading/updating/deleting a method config for all the MFA providers
func TestLoginMFA_Method_CRUD(t *testing.T) { func TestLoginMFA_Method_CRUD(t *testing.T) {
cluster := vault.NewTestCluster(t, &vault.CoreConfig{ cluster := vault.NewTestCluster(t, &vault.CoreConfig{
CredentialBackends: map[string]logical.Factory{ CredentialBackends: map[string]logical.Factory{
@ -216,6 +216,126 @@ func TestLoginMFA_Method_CRUD(t *testing.T) {
} }
} }
func TestLoginMFAMethodName(t *testing.T) {
cluster := vault.NewTestCluster(t, &vault.CoreConfig{
CredentialBackends: map[string]logical.Factory{
"userpass": userpass.Factory,
},
}, &vault.TestClusterOptions{
HandlerFunc: vaulthttp.Handler,
})
cluster.Start()
defer cluster.Cleanup()
core := cluster.Cores[0].Core
vault.TestWaitActive(t, core)
client := cluster.Cores[0].Client
// Enable userpass authentication
err := client.Sys().EnableAuthWithOptions("userpass", &api.EnableAuthOptions{
Type: "userpass",
})
if err != nil {
t.Fatalf("failed to enable userpass auth: %v", err)
}
auths, err := client.Sys().ListAuth()
if err != nil {
t.Fatal(err)
}
mountAccessor := auths["userpass/"].Accessor
testCases := []struct {
methodType string
configData map[string]interface{}
}{
{
"totp",
map[string]interface{}{
"issuer": "yCorp",
"method_name": "totp-method",
},
},
{
"duo",
map[string]interface{}{
"mount_accessor": mountAccessor,
"secret_key": "lol-secret",
"integration_key": "integration-key",
"api_hostname": "some-hostname",
"method_name": "duo-method",
},
},
{
"okta",
map[string]interface{}{
"mount_accessor": mountAccessor,
"base_url": "example.com",
"org_name": "my-org",
"api_token": "lol-token",
"method_name": "okta-method",
},
},
{
"pingid",
map[string]interface{}{
"mount_accessor": mountAccessor,
"settings_file_base64": "I0F1dG8tR2VuZXJhdGVkIGZyb20gUGluZ09uZSwgZG93bmxvYWRlZCBieSBpZD1bU1NPXSBlbWFpbD1baGFtaWRAaGFzaGljb3JwLmNvbV0KI1dlZCBEZWMgMTUgMTM6MDg6NDQgTVNUIDIwMjEKdXNlX2Jhc2U2NF9rZXk9YlhrdGMyVmpjbVYwTFd0bGVRPT0KdXNlX3NpZ25hdHVyZT10cnVlCnRva2VuPWxvbC10b2tlbgppZHBfdXJsPWh0dHBzOi8vaWRweG55bDNtLnBpbmdpZGVudGl0eS5jb20vcGluZ2lkCm9yZ19hbGlhcz1sb2wtb3JnLWFsaWFzCmFkbWluX3VybD1odHRwczovL2lkcHhueWwzbS5waW5naWRlbnRpdHkuY29tL3BpbmdpZAphdXRoZW50aWNhdG9yX3VybD1odHRwczovL2F1dGhlbnRpY2F0b3IucGluZ29uZS5jb20vcGluZ2lkL3BwbQ==",
"method_name": "pingid-method",
},
},
}
for _, tc := range testCases {
t.Run(tc.methodType, func(t *testing.T) {
// create a new method config
myPath := fmt.Sprintf("identity/mfa/method/%s", tc.methodType)
resp, err := client.Logical().Write(myPath, tc.configData)
if err != nil {
t.Fatal(err)
}
methodId := resp.Data["method_id"]
if methodId == "" {
t.Fatal("method id is empty")
}
// creating an MFA config with the same name should not return a new method ID
resp, err = client.Logical().Write(myPath, tc.configData)
if err != nil {
t.Fatal(err)
}
if methodId != resp.Data["method_id"] {
t.Fatal("trying to create a new MFA config with the same name should not result in a new MFA config")
}
originalName := tc.configData["method_name"]
// create a new MFA config name
tc.configData["method_name"] = "newName"
resp, err = client.Logical().Write(myPath, tc.configData)
if err != nil {
t.Fatal(err)
}
myNewPath := fmt.Sprintf("%s/%s", myPath, methodId)
// Updating an existing MFA config with another config's name
resp, err = client.Logical().Write(myNewPath, tc.configData)
if err == nil {
t.Fatalf("expected a failure for configuring an MFA method with an existing MFA method name, %v", err)
}
// Create a method with a / in the name
tc.configData["method_name"] = fmt.Sprintf("ns1/%s", originalName)
_, err = client.Logical().Write(myNewPath, tc.configData)
if err != nil {
t.Fatal(err)
}
})
}
}
// TestLoginMFA_ListAllMFAConfigs tests listing all configs globally // TestLoginMFA_ListAllMFAConfigs tests listing all configs globally
func TestLoginMFA_ListAllMFAConfigsGlobally(t *testing.T) { func TestLoginMFA_ListAllMFAConfigsGlobally(t *testing.T) {
cluster := vault.NewTestCluster(t, &vault.CoreConfig{ cluster := vault.NewTestCluster(t, &vault.CoreConfig{

View File

@ -170,6 +170,10 @@ func mfaPaths(i *IdentityStore) []*framework.Path {
{ {
Pattern: "mfa/method/totp" + genericOptionalUUIDRegex("method_id"), Pattern: "mfa/method/totp" + genericOptionalUUIDRegex("method_id"),
Fields: map[string]*framework.FieldSchema{ Fields: map[string]*framework.FieldSchema{
"method_name": {
Type: framework.TypeString,
Description: `The unique name identifier for this MFA method.`,
},
"method_id": { "method_id": {
Type: framework.TypeString, Type: framework.TypeString,
Description: `The unique identifier for this MFA method.`, Description: `The unique identifier for this MFA method.`,
@ -298,6 +302,10 @@ func mfaPaths(i *IdentityStore) []*framework.Path {
{ {
Pattern: "mfa/method/okta" + genericOptionalUUIDRegex("method_id"), Pattern: "mfa/method/okta" + genericOptionalUUIDRegex("method_id"),
Fields: map[string]*framework.FieldSchema{ Fields: map[string]*framework.FieldSchema{
"method_name": {
Type: framework.TypeString,
Description: `The unique name identifier for this MFA method.`,
},
"method_id": { "method_id": {
Type: framework.TypeString, Type: framework.TypeString,
Description: `The unique identifier for this MFA method.`, Description: `The unique identifier for this MFA method.`,
@ -354,6 +362,10 @@ func mfaPaths(i *IdentityStore) []*framework.Path {
{ {
Pattern: "mfa/method/duo" + genericOptionalUUIDRegex("method_id"), Pattern: "mfa/method/duo" + genericOptionalUUIDRegex("method_id"),
Fields: map[string]*framework.FieldSchema{ Fields: map[string]*framework.FieldSchema{
"method_name": {
Type: framework.TypeString,
Description: `The unique name identifier for this MFA method.`,
},
"method_id": { "method_id": {
Type: framework.TypeString, Type: framework.TypeString,
Description: `The unique identifier for this MFA method.`, Description: `The unique identifier for this MFA method.`,
@ -410,6 +422,10 @@ func mfaPaths(i *IdentityStore) []*framework.Path {
{ {
Pattern: "mfa/method/pingid" + genericOptionalUUIDRegex("method_id"), Pattern: "mfa/method/pingid" + genericOptionalUUIDRegex("method_id"),
Fields: map[string]*framework.FieldSchema{ Fields: map[string]*framework.FieldSchema{
"method_name": {
Type: framework.TypeString,
Description: `The unique name identifier for this MFA method.`,
},
"method_id": { "method_id": {
Type: framework.TypeString, Type: framework.TypeString,
Description: `The unique identifier for this MFA method.`, Description: `The unique identifier for this MFA method.`,

View File

@ -269,6 +269,7 @@ func (i *IdentityStore) handleMFAMethodUpdateCommon(ctx context.Context, req *lo
} }
methodID := d.Get("method_id").(string) methodID := d.Get("method_id").(string)
methodName := d.Get("method_name").(string)
b := i.mfaBackend b := i.mfaBackend
b.mfaLock.Lock() b.mfaLock.Lock()
@ -286,6 +287,23 @@ func (i *IdentityStore) handleMFAMethodUpdateCommon(ctx context.Context, req *lo
} }
} }
// check if an MFA method configuration exists with that method name
if methodName != "" {
namedMfaConfig, err := b.MemDBMFAConfigByName(ctx, methodName)
if err != nil {
return nil, err
}
if namedMfaConfig != nil {
if mConfig == nil {
mConfig = namedMfaConfig
} else {
if mConfig.ID != namedMfaConfig.ID {
return nil, fmt.Errorf("a login MFA method configuration with the method name %s already exists", methodName)
}
}
}
}
if mConfig == nil { if mConfig == nil {
configID, err := uuid.GenerateUUID() configID, err := uuid.GenerateUUID()
if err != nil { if err != nil {
@ -298,6 +316,11 @@ func (i *IdentityStore) handleMFAMethodUpdateCommon(ctx context.Context, req *lo
} }
} }
// Updating the method config name
if methodName != "" {
mConfig.Name = methodName
}
mfaNs, err := i.namespacer.NamespaceByID(ctx, mConfig.NamespaceID) mfaNs, err := i.namespacer.NamespaceByID(ctx, mConfig.NamespaceID)
if err != nil { if err != nil {
return nil, err return nil, err
@ -647,6 +670,50 @@ func (b *LoginMFABackend) loginMFAMethodExistenceCheck(eConfig *mfa.MFAEnforceme
return aggErr.ErrorOrNil() return aggErr.ErrorOrNil()
} }
// sanitizeMFACredsWithLoginEnforcementMethodIDs updates the MFACred map
// looping through the matched login enforcement configurations, and
// replacing MFA method names with MFA method IDs
func (b *LoginMFABackend) sanitizeMFACredsWithLoginEnforcementMethodIDs(ctx context.Context, mfaCredsMap logical.MFACreds, mfaMethodIDs []string) (logical.MFACreds, error) {
sanitizedMfaCreds := make(logical.MFACreds, 0)
var multiError *multierror.Error
for _, methodID := range mfaMethodIDs {
val, ok := mfaCredsMap[methodID]
if ok {
sanitizedMfaCreds[methodID] = val
continue
}
mConfig, err := b.MemDBMFAConfigByID(methodID)
if err != nil {
return nil, err
}
// method name in the MFACredsMap should be the method full name,
// i.e., namespacePath+name. This is because, a user in a child
// namespace can reference an MFA method ID in a parent namespace
configNS, err := NamespaceByID(ctx, mConfig.NamespaceID, b.Core)
if err != nil {
return nil, err
}
if configNS != nil {
val, ok = mfaCredsMap[configNS.Path+mConfig.Name]
if ok {
sanitizedMfaCreds[mConfig.ID] = val
} else {
multiError = multierror.Append(multiError, fmt.Errorf("failed to find MFA credentials associated with an MFA method ID %v, method name %v", methodID, configNS.Path+mConfig.Name))
}
} else {
multiError = multierror.Append(multiError, fmt.Errorf("failed to find the namespace associated with an MFA method ID %v", mConfig.ID))
}
}
// we don't need to find every MFA method identifiers in the MFA header
// So, don't return errors if that is the case.
if len(sanitizedMfaCreds) > 0 {
return sanitizedMfaCreds, nil
}
return sanitizedMfaCreds, multiError
}
func (b *LoginMFABackend) handleMFALoginValidate(ctx context.Context, req *logical.Request, d *framework.FieldData) (retResp *logical.Response, retErr error) { func (b *LoginMFABackend) handleMFALoginValidate(ctx context.Context, req *logical.Request, d *framework.FieldData) (retResp *logical.Response, retErr error) {
// mfaReqID is the ID of the login request // mfaReqID is the ID of the login request
mfaReqID := d.Get("mfa_request_id").(string) mfaReqID := d.Get("mfa_request_id").(string)
@ -655,13 +722,13 @@ func (b *LoginMFABackend) handleMFALoginValidate(ctx context.Context, req *logic
} }
// a map of methodID to passcode // a map of methodID to passcode
methodIDToPasscodeInterface := d.Get("mfa_payload") mfaPayload := d.Get("mfa_payload")
if methodIDToPasscodeInterface == nil { if mfaPayload == nil {
return logical.ErrorResponse("missing mfa payload"), nil return logical.ErrorResponse("missing mfa payload"), nil
} }
var mfaCreds logical.MFACreds var mfaCreds logical.MFACreds
err := mapstructure.Decode(methodIDToPasscodeInterface, &mfaCreds) err := mapstructure.Decode(mfaPayload, &mfaCreds)
if err != nil { if err != nil {
return logical.ErrorResponse("invalid mfa payload"), nil return logical.ErrorResponse("invalid mfa payload"), nil
} }
@ -1574,11 +1641,19 @@ func parseOktaConfig(mConfig *mfa.Config, d *framework.FieldData) error {
} }
func (c *Core) validateLoginMFA(ctx context.Context, eConfig *mfa.MFAEnforcementConfig, entity *identity.Entity, requestConnRemoteAddr string, mfaCredsMap logical.MFACreds) error { func (c *Core) validateLoginMFA(ctx context.Context, eConfig *mfa.MFAEnforcementConfig, entity *identity.Entity, requestConnRemoteAddr string, mfaCredsMap logical.MFACreds) error {
sanitizedMfaCreds, err := c.loginMFABackend.sanitizeMFACredsWithLoginEnforcementMethodIDs(ctx, mfaCredsMap, eConfig.MFAMethodIDs)
if err != nil {
return fmt.Errorf("failed to sanitize MFA creds, %w", err)
}
if len(sanitizedMfaCreds) == 0 && len(eConfig.MFAMethodIDs) > 0 {
return fmt.Errorf("login MFA validation failed for methodID: %v", eConfig.MFAMethodIDs)
}
var retErr error var retErr error
for _, methodID := range eConfig.MFAMethodIDs { for _, methodID := range eConfig.MFAMethodIDs {
// as configID is the same as methodID, and methodID is unique, we can // as configID is the same as methodID, and methodID is unique, we can
// use it to retrieve the MFACreds // use it to retrieve the MFACreds
mfaCreds, ok := mfaCredsMap[methodID] mfaCreds, ok := sanitizedMfaCreds[methodID]
if !ok || mfaCreds == nil { if !ok || mfaCreds == nil {
continue continue
} }
@ -1634,6 +1709,11 @@ func (c *Core) validateLoginMFAInternal(ctx context.Context, methodID string, en
} }
} }
mfaFactors, err := parseMfaFactors(mfaCreds)
if err != nil {
return fmt.Errorf("failed to parse MFA factor, %w", err)
}
switch mConfig.Type { switch mConfig.Type {
case mfaMethodTypeTOTP: case mfaMethodTypeTOTP:
// Get the MFA secret data required to validate the supplied credentials // Get the MFA secret data required to validate the supplied credentials
@ -1645,17 +1725,13 @@ func (c *Core) validateLoginMFAInternal(ctx context.Context, methodID string, en
return fmt.Errorf("MFA secret for method name %q not present in entity %q", mConfig.Name, entity.ID) return fmt.Errorf("MFA secret for method name %q not present in entity %q", mConfig.Name, entity.ID)
} }
if mfaCreds == nil { return c.validateTOTP(ctx, mfaFactors, entityMFASecret, mConfig.ID, entity.ID, c.loginMFABackend.usedCodes, mConfig.GetTOTPConfig().MaxValidationAttempts)
return fmt.Errorf("MFA credentials not supplied")
}
return c.validateTOTP(ctx, mfaCreds, entityMFASecret, mConfig.ID, entity.ID, c.loginMFABackend.usedCodes, mConfig.GetTOTPConfig().MaxValidationAttempts)
case mfaMethodTypeOkta: case mfaMethodTypeOkta:
return c.validateOkta(ctx, mConfig, finalUsername) return c.validateOkta(ctx, mConfig, finalUsername)
case mfaMethodTypeDuo: case mfaMethodTypeDuo:
return c.validateDuo(ctx, mfaCreds, mConfig, finalUsername, reqConnectionRemoteAddress) return c.validateDuo(ctx, mfaFactors, mConfig, finalUsername, reqConnectionRemoteAddress)
case mfaMethodTypePingID: case mfaMethodTypePingID:
return c.validatePingID(ctx, mConfig, finalUsername) return c.validatePingID(ctx, mConfig, finalUsername)
@ -1764,23 +1840,52 @@ func formatUsername(format string, alias *identity.Alias, entity *identity.Entit
return username return username
} }
func (c *Core) validateDuo(ctx context.Context, creds []string, mConfig *mfa.Config, username, reqConnectionRemoteAddr string) error { type MFAFactor struct {
passcode string
}
func parseMfaFactors(creds []string) (*MFAFactor, error) {
mfaFactor := &MFAFactor{}
for _, cred := range creds {
switch {
case cred == "": // for the case of push notification
continue
case strings.HasPrefix(cred, "passcode="):
if mfaFactor.passcode != "" {
return nil, fmt.Errorf("found multiple passcodes for the same MFA method")
}
splits := strings.SplitN(cred, "=", 2)
if splits[1] == "" {
return nil, fmt.Errorf("invalid passcode")
}
mfaFactor.passcode = splits[1]
case strings.Contains(cred, "="):
return nil, fmt.Errorf("found an invalid MFA cred: %v", cred)
default:
// a non-empty cred that does not match the above
// means it is a passcode
if mfaFactor.passcode != "" {
return nil, fmt.Errorf("found multiple passcodes for the same MFA method")
}
mfaFactor.passcode = cred
}
}
return mfaFactor, nil
}
func (c *Core) validateDuo(ctx context.Context, mfaFactors *MFAFactor, mConfig *mfa.Config, username, reqConnectionRemoteAddr string) error {
duoConfig := mConfig.GetDuoConfig() duoConfig := mConfig.GetDuoConfig()
if duoConfig == nil { if duoConfig == nil {
return fmt.Errorf("failed to get Duo configuration for method %q", mConfig.Name) return fmt.Errorf("failed to get Duo configuration for method %q", mConfig.Name)
} }
passcode := "" var passcode string
for _, cred := range creds { if mfaFactors != nil {
if strings.HasPrefix(cred, "passcode") { passcode = mfaFactors.passcode
splits := strings.SplitN(cred, "=", 2)
if len(splits) != 2 {
return fmt.Errorf("invalid credential %q", cred)
}
if splits[0] == "passcode" {
passcode = splits[1]
}
}
} }
client := duoapi.NewDuoApi( client := duoapi.NewDuoApi(
@ -2229,21 +2334,18 @@ func (c *Core) validatePingID(ctx context.Context, mConfig *mfa.Config, username
return nil return nil
} }
func (c *Core) validateTOTP(ctx context.Context, creds []string, entityMethodSecret *mfa.Secret, configID, entityID string, usedCodes *cache.Cache, maximumValidationAttempts uint32) error { func (c *Core) validateTOTP(ctx context.Context, mfaFactors *MFAFactor, entityMethodSecret *mfa.Secret, configID, entityID string, usedCodes *cache.Cache, maximumValidationAttempts uint32) error {
if len(creds) == 0 { if mfaFactors.passcode == "" {
return fmt.Errorf("missing TOTP passcode") return fmt.Errorf("MFA credentials not supplied")
}
if len(creds) > 1 {
return fmt.Errorf("more than one TOTP passcode supplied")
} }
passcode := mfaFactors.passcode
totpSecret := entityMethodSecret.GetTOTPSecret() totpSecret := entityMethodSecret.GetTOTPSecret()
if totpSecret == nil { if totpSecret == nil {
return fmt.Errorf("entity does not contain the TOTP secret") return fmt.Errorf("entity does not contain the TOTP secret")
} }
usedName := fmt.Sprintf("%s_%s", configID, creds[0]) usedName := fmt.Sprintf("%s_%s", configID, passcode)
_, ok := usedCodes.Get(usedName) _, ok := usedCodes.Get(usedName)
if ok { if ok {
@ -2290,7 +2392,7 @@ func (c *Core) validateTOTP(ctx context.Context, creds []string, entityMethodSec
Algorithm: otplib.Algorithm(int(totpSecret.Algorithm)), Algorithm: otplib.Algorithm(int(totpSecret.Algorithm)),
} }
valid, err := totplib.ValidateCustom(creds[0], key, time.Now(), validateOpts) valid, err := totplib.ValidateCustom(passcode, key, time.Now(), validateOpts)
if err != nil && err != otplib.ErrValidateInputInvalidLength { if err != nil && err != otplib.ErrValidateInputInvalidLength {
return errwrap.Wrapf("failed to validate TOTP passcode: {{err}}", err) return errwrap.Wrapf("failed to validate TOTP passcode: {{err}}", err)
} }
@ -2340,6 +2442,21 @@ func loginMFAConfigTableSchema() *memdb.TableSchema {
Field: "Type", Field: "Type",
}, },
}, },
"name": {
Name: "name",
Unique: true,
AllowMissing: true,
Indexer: &memdb.CompoundIndex{
Indexes: []memdb.Indexer{
&memdb.StringFieldIndex{
Field: "NamespaceID",
},
&memdb.StringFieldIndex{
Field: "Name",
},
},
},
},
}, },
} }
} }
@ -2487,6 +2604,47 @@ func (b *LoginMFABackend) MemDBMFAConfigByID(mConfigID string) (*mfa.Config, err
return b.MemDBMFAConfigByIDInTxn(txn, mConfigID) return b.MemDBMFAConfigByIDInTxn(txn, mConfigID)
} }
func (b *LoginMFABackend) MemDBMFAConfigByNameInTxn(ctx context.Context, txn *memdb.Txn, mConfigName string) (*mfa.Config, error) {
if mConfigName == "" {
return nil, fmt.Errorf("missing config name")
}
if txn == nil {
return nil, fmt.Errorf("txn is nil")
}
ns, err := namespace.FromContext(ctx)
if err != nil {
return nil, err
}
mConfigRaw, err := txn.First(b.methodTable, "name", ns.ID, mConfigName)
if err != nil {
return nil, fmt.Errorf("failed to fetch MFA config from memdb using name: %w", err)
}
if mConfigRaw == nil {
return nil, nil
}
mConfig, ok := mConfigRaw.(*mfa.Config)
if !ok {
return nil, fmt.Errorf("failed to declare the type of fetched MFA config")
}
return mConfig.Clone()
}
func (b *LoginMFABackend) MemDBMFAConfigByName(ctx context.Context, name string) (*mfa.Config, error) {
if name == "" {
return nil, fmt.Errorf("missing config name")
}
txn := b.db.Txn(false)
return b.MemDBMFAConfigByNameInTxn(ctx, txn, name)
}
func (b *LoginMFABackend) MemDBMFALoginEnforcementConfigByNameAndNamespace(name, namespaceId string) (*mfa.MFAEnforcementConfig, error) { func (b *LoginMFABackend) MemDBMFALoginEnforcementConfigByNameAndNamespace(name, namespaceId string) (*mfa.MFAEnforcementConfig, error) {
if name == "" { if name == "" {
return nil, fmt.Errorf("missing config name") return nil, fmt.Errorf("missing config name")

61
vault/login_mfa_test.go Normal file
View File

@ -0,0 +1,61 @@
package vault
import (
"strings"
"testing"
)
func TestParseFactors(t *testing.T) {
testcases := []struct {
name string
invalidMFAHeaderVal []string
expectedError string
}{
{
"two headers with passcode",
[]string{"passcode", "foo"},
"found multiple passcodes for the same MFA method",
},
{
"single header with passcode=",
[]string{"passcode="},
"invalid passcode",
},
{
"single invalid header",
[]string{"foo="},
"found an invalid MFA cred",
},
{
"single header equal char",
[]string{"=="},
"found an invalid MFA cred",
},
{
"two headers with passcode=",
[]string{"passcode=foo", "foo"},
"found multiple passcodes for the same MFA method",
},
{
"two headers invalid name",
[]string{"passcode=foo", "passcode=bar"},
"found multiple passcodes for the same MFA method",
},
{
"two headers, two invalid",
[]string{"foo", "bar"},
"found multiple passcodes for the same MFA method",
},
}
for _, tc := range testcases {
t.Run(tc.name, func(t *testing.T) {
_, err := parseMfaFactors(tc.invalidMFAHeaderVal)
if err == nil {
t.Fatal("nil error returned")
}
if !strings.Contains(err.Error(), tc.expectedError) {
t.Fatalf("expected %s, got %v", tc.expectedError, err)
}
})
}
}

View File

@ -2059,6 +2059,7 @@ func (c *Core) buildMfaEnforcementResponse(eConfig *mfa.MFAEnforcementConfig) (*
Type: mConfig.Type, Type: mConfig.Type,
ID: methodID, ID: methodID,
UsesPasscode: mConfig.Type == mfaMethodTypeTOTP || duoUsePasscode, UsesPasscode: mConfig.Type == mfaMethodTypeTOTP || duoUsePasscode,
Name: mConfig.Name,
} }
mfaAny.Any = append(mfaAny.Any, mfaMethod) mfaAny.Any = append(mfaAny.Any, mfaMethod)
} }