open-vault/api/auth_test.go

132 lines
2.7 KiB
Go

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package api
import (
"context"
"testing"
)
// mockAuthMethod is a test double for an auth method. Its Login ignores the
// context and client it is given and hands back the canned secret and error
// stored in its fields.
type mockAuthMethod struct {
	mockedSecret *Secret // returned unchanged by Login
	mockedError  error   // returned unchanged by Login
}

// Login returns the mock's preconfigured secret/error pair; both arguments
// are deliberately unused.
func (mock *mockAuthMethod) Login(context.Context, *Client) (*Secret, error) {
	secret, err := mock.mockedSecret, mock.mockedError
	return secret, err
}
// TestAuth_Login checks that a successful Auth.Login leaves the underlying
// Client's token equal to the ClientToken of the secret returned by the auth
// method.
func TestAuth_Login(t *testing.T) {
	auth := &Auth{c: &Client{}}
	method := mockAuthMethod{
		mockedSecret: &Secret{
			Auth: &SecretAuth{ClientToken: "a-client-token"},
		},
	}

	t.Run("Login should set token on success", func(t *testing.T) {
		// Precondition: a fresh Client has no token, so any token seen
		// afterwards must have been set by Login.
		if got := auth.c.Token(); got != "" {
			t.Errorf("client token was %v expected to be unset", got)
		}

		if _, err := auth.Login(context.Background(), &method); err != nil {
			t.Errorf("Login() error = %v", err)
			return
		}

		if got, want := auth.c.Token(), method.mockedSecret.Auth.ClientToken; got != want {
			t.Errorf("client token was %v expected %v", got, want)
		}
	})
}
// TestAuth_MFALoginSinglePhase checks that MFALogin succeeds when an MFA
// credential string is supplied up front. It also checks that the client
// token from the mock's secret ends up on the Client.
func TestAuth_MFALoginSinglePhase(t *testing.T) {
	t.Run("MFALogin() should succeed if credentials are passed in", func(t *testing.T) {
		auth := &Auth{c: &Client{}}
		method := mockAuthMethod{
			mockedSecret: &Secret{
				Auth: &SecretAuth{ClientToken: "a-client-token"},
			},
		}

		// Passing a credential is what distinguishes this (single-phase) case
		// from the two-phase test, which calls MFALogin without one.
		const creds = "testMethod:testPasscode"
		if _, err := auth.MFALogin(context.Background(), &method, creds); err != nil {
			t.Errorf("MFALogin() error %v", err)
			return
		}

		if got, want := auth.c.Token(), method.mockedSecret.Auth.ClientToken; got != want {
			t.Errorf("client token was %v expected %v", got, want)
		}
	})
}
// TestAuth_MFALoginTwoPhase calls MFALogin without any MFA credentials (the
// two-phase flow). It expects the auth method's MFA requirement to be passed
// back to the caller, and expects an error when the secret carries no MFA
// requirement.
func TestAuth_MFALoginTwoPhase(t *testing.T) {
	tests := []struct {
		name    string
		a       *Auth
		m       *mockAuthMethod
		wantErr bool
	}{
		{
			name: "return MFARequirements",
			a: &Auth{
				c: &Client{},
			},
			m: &mockAuthMethod{
				mockedSecret: &Secret{
					Auth: &SecretAuth{
						MFARequirement: &MFARequirement{
							MFARequestID:   "a-req-id",
							MFAConstraints: nil,
						},
					},
				},
			},
			wantErr: false,
		},
		{
			name: "error if no MFARequirements",
			a: &Auth{
				c: &Client{},
			},
			m: &mockAuthMethod{
				mockedSecret: &Secret{
					Auth: &SecretAuth{},
				},
			},
			wantErr: true,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			secret, err := tt.a.MFALogin(context.Background(), tt.m)
			if (err != nil) != tt.wantErr {
				t.Fatalf("MFALogin() error = %v, wantErr %v", err, tt.wantErr)
			}

			// Guard before dereferencing: on error a nil secret is a legitimate
			// return, and on success a missing secret should fail cleanly rather
			// than panic the test binary.
			if secret == nil || secret.Auth == nil {
				if tt.wantErr {
					return
				}
				t.Fatalf("MFALogin() returned no secret auth data: %+v", secret)
			}

			if got, want := secret.Auth.MFARequirement, tt.m.mockedSecret.Auth.MFARequirement; got != want {
				t.Fatalf("MFALogin() returned %v, expected %v", got, want)
			}
		})
	}
}