544 lines
11 KiB
Go
544 lines
11 KiB
Go
|
package database
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"crypto/rand"
|
||
|
"crypto/rsa"
|
||
|
"crypto/x509"
|
||
|
"encoding/pem"
|
||
|
"testing"
|
||
|
|
||
|
"github.com/hashicorp/vault/sdk/helper/base62"
|
||
|
"github.com/hashicorp/vault/sdk/logical"
|
||
|
"github.com/stretchr/testify/assert"
|
||
|
"github.com/stretchr/testify/mock"
|
||
|
)
|
||
|
|
||
|
// Test_newPasswordGenerator verifies that newPasswordGenerator copies the
// optional "password_policy" config key into PasswordPolicy, and leaves it
// empty when the key (or the whole config map) is absent.
func Test_newPasswordGenerator(t *testing.T) {
	type args struct {
		config map[string]interface{}
	}
	tests := []struct {
		name    string
		args    args
		want    passwordGenerator
		wantErr bool
	}{
		{
			name: "newPasswordGenerator with nil config",
			args: args{
				config: nil,
			},
			want: passwordGenerator{
				PasswordPolicy: "",
			},
		},
		{
			name: "newPasswordGenerator without password_policy",
			args: args{
				config: map[string]interface{}{},
			},
			want: passwordGenerator{
				PasswordPolicy: "",
			},
		},
		{
			name: "newPasswordGenerator with password_policy",
			args: args{
				config: map[string]interface{}{
					"password_policy": "test-policy",
				},
			},
			want: passwordGenerator{
				PasswordPolicy: "test-policy",
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, err := newPasswordGenerator(tt.args.config)
			if tt.wantErr {
				assert.Error(t, err)
				return
			}
			// Without this check, an unexpected error paired with a
			// zero-value result would still satisfy the "empty policy"
			// expectations above.
			assert.NoError(t, err)
			assert.Equal(t, tt.want, got)
		})
	}
}
|
||
|
|
||
|
// Test_newRSAKeyGenerator verifies how newRSAKeyGenerator interprets its
// config map. Per the expectations below:
//   - missing or zero-valued "format"/"key_bits" fall back to "pkcs8"/2048;
//   - an accepted "format" is stored as given (including its case);
//   - "key_bits" may be supplied as a string and must be 2048, 3072 or 4096;
//   - unsupported values ("4097" bits, "pkcs1" format) are rejected.
func Test_newRSAKeyGenerator(t *testing.T) {
	type args struct {
		config map[string]interface{}
	}
	tests := []struct {
		name    string
		args    args
		want    rsaKeyGenerator
		wantErr bool
	}{
		{
			name: "newRSAKeyGenerator with nil config",
			args: args{
				config: nil,
			},
			want: rsaKeyGenerator{
				Format:  "pkcs8",
				KeyBits: 2048,
			},
		},
		{
			name: "newRSAKeyGenerator with empty config",
			args: args{
				config: map[string]interface{}{},
			},
			want: rsaKeyGenerator{
				Format:  "pkcs8",
				KeyBits: 2048,
			},
		},
		{
			name: "newRSAKeyGenerator with zero value format",
			args: args{
				config: map[string]interface{}{
					"format": "",
				},
			},
			want: rsaKeyGenerator{
				Format:  "pkcs8",
				KeyBits: 2048,
			},
		},
		{
			name: "newRSAKeyGenerator with zero value key_bits",
			args: args{
				config: map[string]interface{}{
					"key_bits": "0",
				},
			},
			want: rsaKeyGenerator{
				Format:  "pkcs8",
				KeyBits: 2048,
			},
		},
		{
			name: "newRSAKeyGenerator with format",
			args: args{
				config: map[string]interface{}{
					"format": "pkcs8",
				},
			},
			want: rsaKeyGenerator{
				Format:  "pkcs8",
				KeyBits: 2048,
			},
		},
		{
			name: "newRSAKeyGenerator with format case insensitive",
			args: args{
				config: map[string]interface{}{
					"format": "PKCS8",
				},
			},
			want: rsaKeyGenerator{
				Format:  "PKCS8",
				KeyBits: 2048,
			},
		},
		{
			name: "newRSAKeyGenerator with 3072 key_bits",
			args: args{
				config: map[string]interface{}{
					"key_bits": "3072",
				},
			},
			want: rsaKeyGenerator{
				Format:  "pkcs8",
				KeyBits: 3072,
			},
		},
		{
			name: "newRSAKeyGenerator with 4096 key_bits",
			args: args{
				config: map[string]interface{}{
					"key_bits": "4096",
				},
			},
			want: rsaKeyGenerator{
				Format:  "pkcs8",
				KeyBits: 4096,
			},
		},
		{
			name: "newRSAKeyGenerator with invalid key_bits",
			args: args{
				config: map[string]interface{}{
					"key_bits": "4097",
				},
			},
			wantErr: true,
		},
		{
			name: "newRSAKeyGenerator with invalid format",
			args: args{
				config: map[string]interface{}{
					"format": "pkcs1",
				},
			},
			wantErr: true,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, err := newRSAKeyGenerator(tt.args.config)
			if tt.wantErr {
				assert.Error(t, err)
				return
			}
			// Valid configs must construct cleanly; otherwise an error could
			// slip through whenever the returned value happened to match.
			assert.NoError(t, err)
			assert.Equal(t, tt.want, got)
		})
	}
}
|
||
|
|
||
|
// Test_passwordGenerator_generate exercises passwordGenerator.generate against
// a databaseVersionWrapper backed by mocks:
//   - a wrapper with neither a v4 nor a v5 database must fail;
//   - a v4 (legacy) database is expected to have GenerateCredentials called
//     exactly once, and its result returned;
//   - a v5 database produces a password either without a policy, or via the
//     named password policy registered on the backend's system view (a
//     policy name that was never registered must fail).
//
// NOTE: every subtest re-registers "test-policy" on the shared
// StaticSystemView, so these subtests must not be run with t.Parallel.
func Test_passwordGenerator_generate(t *testing.T) {
	config := logical.TestBackendConfig()
	b := Backend(config)
	if err := b.Setup(context.Background(), config); err != nil {
		t.Fatalf("backend setup: %v", err)
	}

	type args struct {
		config  map[string]interface{}
		mock    func() interface{}
		passGen logical.PasswordGenerator
	}
	tests := []struct {
		name       string
		args       args
		wantRegexp string
		wantErr    bool
	}{
		{
			name: "wrapper missing v4 and v5 interface",
			args: args{
				mock: func() interface{} {
					return nil
				},
			},
			wantErr: true,
		},
		{
			name: "v4: generate password using GenerateCredentials",
			args: args{
				mock: func() interface{} {
					v4Mock := new(mockLegacyDatabase)
					v4Mock.On("GenerateCredentials", mock.Anything).
						Return("v4-generated-password", nil).
						Times(1)
					return v4Mock
				},
			},
			wantRegexp: "^v4-generated-password$",
		},
		{
			name: "v5: generate password without policy",
			args: args{
				mock: func() interface{} {
					return new(mockNewDatabase)
				},
			},
			wantRegexp: "^[a-zA-Z0-9-]{20}$",
		},
		{
			name: "v5: generate password with non-existing policy",
			args: args{
				config: map[string]interface{}{
					"password_policy": "not-created",
				},
				mock: func() interface{} {
					return new(mockNewDatabase)
				},
			},
			wantErr: true,
		},
		{
			name: "v5: generate password with existing policy",
			args: args{
				config: map[string]interface{}{
					"password_policy": "test-policy",
				},
				mock: func() interface{} {
					return new(mockNewDatabase)
				},
				passGen: func() (string, error) {
					return base62.Random(30)
				},
			},
			wantRegexp: "^[a-zA-Z0-9]{30}$",
		},
		{
			name: "v5: generate password with existing policy static",
			args: args{
				config: map[string]interface{}{
					"password_policy": "test-policy",
				},
				mock: func() interface{} {
					return new(mockNewDatabase)
				},
				passGen: func() (string, error) {
					return "policy-generated-password", nil
				},
			},
			wantRegexp: "^policy-generated-password$",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Set up the version wrapper with a mock database implementation.
			// A nil mock leaves both v4 and v5 unset.
			wrapper := databaseVersionWrapper{}
			switch m := tt.args.mock().(type) {
			case *mockNewDatabase:
				wrapper.v5 = m
			case *mockLegacyDatabase:
				wrapper.v4 = m
			}

			// Set the password policy for the test case
			config.System.(*logical.StaticSystemView).SetPasswordPolicy(
				"test-policy", tt.args.passGen)

			// Build the generator. Its error is checked here because the
			// generate call below re-declares err and would otherwise hide it.
			pg, err := newPasswordGenerator(tt.args.config)
			if !assert.NoError(t, err) {
				return
			}

			// Generate the password
			got, err := pg.generate(context.Background(), b, wrapper)
			if tt.wantErr {
				assert.Error(t, err)
				return
			}
			assert.NoError(t, err)
			assert.Regexp(t, tt.wantRegexp, got)

			// Assert all expected calls took place on the mock
			if m, ok := wrapper.v5.(*mockNewDatabase); ok {
				m.AssertExpectations(t)
			}
			if m, ok := wrapper.v4.(*mockLegacyDatabase); ok {
				m.AssertExpectations(t)
			}
		})
	}
}
|
||
|
|
||
|
func Test_passwordGenerator_configMap(t *testing.T) {
|
||
|
type args struct {
|
||
|
config map[string]interface{}
|
||
|
}
|
||
|
tests := []struct {
|
||
|
name string
|
||
|
args args
|
||
|
want map[string]interface{}
|
||
|
}{
|
||
|
{
|
||
|
name: "nil config results in empty map",
|
||
|
args: args{
|
||
|
config: nil,
|
||
|
},
|
||
|
want: map[string]interface{}{},
|
||
|
},
|
||
|
{
|
||
|
name: "empty config results in empty map",
|
||
|
args: args{
|
||
|
config: map[string]interface{}{},
|
||
|
},
|
||
|
want: map[string]interface{}{},
|
||
|
},
|
||
|
{
|
||
|
name: "input config is equal to output config",
|
||
|
args: args{
|
||
|
config: map[string]interface{}{
|
||
|
"password_policy": "test-policy",
|
||
|
},
|
||
|
},
|
||
|
want: map[string]interface{}{
|
||
|
"password_policy": "test-policy",
|
||
|
},
|
||
|
},
|
||
|
}
|
||
|
for _, tt := range tests {
|
||
|
t.Run(tt.name, func(t *testing.T) {
|
||
|
pg, err := newPasswordGenerator(tt.args.config)
|
||
|
assert.NoError(t, err)
|
||
|
cm, err := pg.configMap()
|
||
|
assert.NoError(t, err)
|
||
|
assert.Equal(t, tt.want, cm)
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Test_rsaKeyGenerator_generate builds an rsaKeyGenerator from each config and
// checks the generated key pair:
//   - the public key is a single "PUBLIC KEY" PEM block parseable as PKIX RSA;
//   - the private key is a single "PRIVATE KEY" PEM block parseable in the
//     configured format;
//   - the modulus size equals the configured KeyBits;
//   - the private key's public half matches the public key.
func Test_rsaKeyGenerator_generate(t *testing.T) {
	type args struct {
		config map[string]interface{}
	}
	tests := []struct {
		name string
		args args
	}{
		{
			name: "generate RSA key with nil default config",
			args: args{
				config: nil,
			},
		},
		{
			name: "generate RSA key with empty default config",
			args: args{
				config: map[string]interface{}{},
			},
		},
		{
			name: "generate RSA key with 2048 key_bits and format",
			args: args{
				config: map[string]interface{}{
					"key_bits": "2048",
					"format":   "pkcs8",
				},
			},
		},
		{
			name: "generate RSA key with 2048 key_bits",
			args: args{
				config: map[string]interface{}{
					"key_bits": "2048",
				},
			},
		},
		{
			name: "generate RSA key with 3072 key_bits",
			args: args{
				config: map[string]interface{}{
					"key_bits": "3072",
				},
			},
		},
		{
			name: "generate RSA key with 4096 key_bits",
			args: args{
				config: map[string]interface{}{
					"key_bits": "4096",
				},
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Build the generator. The generate call below re-declares err,
			// so a construction error must be checked here or it is lost.
			kg, err := newRSAKeyGenerator(tt.args.config)
			if !assert.NoError(t, err) {
				return
			}

			// Generate the RSA key pair
			public, private, err := kg.generate(rand.Reader)
			assert.NoError(t, err)
			assert.NotEmpty(t, public)
			assert.NotEmpty(t, private)

			// Decode the public and private key PEMs. Stop if either block is
			// missing: reading .Type/.Bytes from a nil *pem.Block would panic
			// and mask the actual assertion failure.
			pubBlock, pubRest := pem.Decode(public)
			privBlock, privRest := pem.Decode(private)
			pubOK := assert.NotNil(t, pubBlock)
			privOK := assert.NotNil(t, privBlock)
			if !pubOK || !privOK {
				return
			}
			assert.Empty(t, pubRest)
			assert.Equal(t, "PUBLIC KEY", pubBlock.Type)
			assert.Empty(t, privRest)
			assert.Equal(t, "PRIVATE KEY", privBlock.Type)

			// Assert that we can parse the public key PEM block, and that the
			// modulus has the size the generator was configured with.
			pub, err := x509.ParsePKIXPublicKey(pubBlock.Bytes)
			assert.NoError(t, err)
			assert.NotNil(t, pub)
			assert.IsType(t, &rsa.PublicKey{}, pub)
			rsaPub, isRSAPub := pub.(*rsa.PublicKey)
			if isRSAPub {
				assert.EqualValues(t, kg.KeyBits, rsaPub.N.BitLen())
			}

			// Assert that we can parse the private key PEM block in
			// the configured format, and that it pairs with the public key.
			switch kg.Format {
			case "pkcs8":
				priv, err := x509.ParsePKCS8PrivateKey(privBlock.Bytes)
				assert.NoError(t, err)
				assert.NotNil(t, priv)
				assert.IsType(t, &rsa.PrivateKey{}, priv)
				if rsaPriv, ok := priv.(*rsa.PrivateKey); ok && isRSAPub {
					assert.True(t, rsaPub.Equal(&rsaPriv.PublicKey),
						"private key does not correspond to the public key")
				}
			default:
				t.Fatalf("unknown format %q", kg.Format)
			}
		})
	}
}
|
||
|
|
||
|
func Test_rsaKeyGenerator_configMap(t *testing.T) {
|
||
|
type args struct {
|
||
|
config map[string]interface{}
|
||
|
}
|
||
|
tests := []struct {
|
||
|
name string
|
||
|
args args
|
||
|
want map[string]interface{}
|
||
|
}{
|
||
|
{
|
||
|
name: "nil config results in defaults",
|
||
|
args: args{
|
||
|
config: nil,
|
||
|
},
|
||
|
want: map[string]interface{}{
|
||
|
"format": "pkcs8",
|
||
|
"key_bits": 2048,
|
||
|
},
|
||
|
},
|
||
|
{
|
||
|
name: "empty config results in defaults",
|
||
|
args: args{
|
||
|
config: map[string]interface{}{},
|
||
|
},
|
||
|
want: map[string]interface{}{
|
||
|
"format": "pkcs8",
|
||
|
"key_bits": 2048,
|
||
|
},
|
||
|
},
|
||
|
{
|
||
|
name: "config with format",
|
||
|
args: args{
|
||
|
config: map[string]interface{}{
|
||
|
"format": "pkcs8",
|
||
|
},
|
||
|
},
|
||
|
want: map[string]interface{}{
|
||
|
"format": "pkcs8",
|
||
|
"key_bits": 2048,
|
||
|
},
|
||
|
},
|
||
|
{
|
||
|
name: "config with key_bits",
|
||
|
args: args{
|
||
|
config: map[string]interface{}{
|
||
|
"key_bits": 4096,
|
||
|
},
|
||
|
},
|
||
|
want: map[string]interface{}{
|
||
|
"format": "pkcs8",
|
||
|
"key_bits": 4096,
|
||
|
},
|
||
|
},
|
||
|
{
|
||
|
name: "config with format and key_bits",
|
||
|
args: args{
|
||
|
config: map[string]interface{}{
|
||
|
"format": "pkcs8",
|
||
|
"key_bits": 3072,
|
||
|
},
|
||
|
},
|
||
|
want: map[string]interface{}{
|
||
|
"format": "pkcs8",
|
||
|
"key_bits": 3072,
|
||
|
},
|
||
|
},
|
||
|
}
|
||
|
for _, tt := range tests {
|
||
|
t.Run(tt.name, func(t *testing.T) {
|
||
|
kg, err := newRSAKeyGenerator(tt.args.config)
|
||
|
assert.NoError(t, err)
|
||
|
cm, err := kg.configMap()
|
||
|
assert.NoError(t, err)
|
||
|
assert.Equal(t, tt.want, cm)
|
||
|
})
|
||
|
}
|
||
|
}
|