open-vault/builtin/logical/database/version_wrapper_test.go

854 lines
19 KiB
Go

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package database
import (
"context"
"fmt"
"reflect"
"testing"
"time"
v4 "github.com/hashicorp/vault/sdk/database/dbplugin"
v5 "github.com/hashicorp/vault/sdk/database/dbplugin/v5"
"github.com/hashicorp/vault/sdk/logical"
"github.com/stretchr/testify/mock"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
func TestInitDatabase_missingDB(t *testing.T) {
dbw := databaseVersionWrapper{}
req := v5.InitializeRequest{}
resp, err := dbw.Initialize(context.Background(), req)
if err == nil {
t.Fatalf("err expected, got nil")
}
expectedResp := v5.InitializeResponse{}
if !reflect.DeepEqual(resp, expectedResp) {
t.Fatalf("Actual resp: %#v\nExpected resp: %#v", resp, expectedResp)
}
}
func TestInitDatabase_newDB(t *testing.T) {
type testCase struct {
req v5.InitializeRequest
newInitResp v5.InitializeResponse
newInitErr error
newInitCalls int
expectedResp v5.InitializeResponse
expectErr bool
}
tests := map[string]testCase{
"success": {
req: v5.InitializeRequest{
Config: map[string]interface{}{
"foo": "bar",
},
VerifyConnection: true,
},
newInitResp: v5.InitializeResponse{
Config: map[string]interface{}{
"foo": "bar",
},
},
newInitCalls: 1,
expectedResp: v5.InitializeResponse{
Config: map[string]interface{}{
"foo": "bar",
},
},
expectErr: false,
},
"error": {
req: v5.InitializeRequest{
Config: map[string]interface{}{
"foo": "bar",
},
VerifyConnection: true,
},
newInitResp: v5.InitializeResponse{},
newInitErr: fmt.Errorf("test error"),
newInitCalls: 1,
expectedResp: v5.InitializeResponse{},
expectErr: true,
},
}
for name, test := range tests {
t.Run(name, func(t *testing.T) {
newDB := new(mockNewDatabase)
newDB.On("Initialize", mock.Anything, mock.Anything).
Return(test.newInitResp, test.newInitErr)
defer newDB.AssertNumberOfCalls(t, "Initialize", test.newInitCalls)
dbw := databaseVersionWrapper{
v5: newDB,
}
resp, err := dbw.Initialize(context.Background(), test.req)
if test.expectErr && err == nil {
t.Fatalf("err expected, got nil")
}
if !test.expectErr && err != nil {
t.Fatalf("no error expected, got: %s", err)
}
if !reflect.DeepEqual(resp, test.expectedResp) {
t.Fatalf("Actual resp: %#v\nExpected resp: %#v", resp, test.expectedResp)
}
})
}
}
func TestInitDatabase_legacyDB(t *testing.T) {
type testCase struct {
req v5.InitializeRequest
initConfig map[string]interface{}
initErr error
initCalls int
expectedResp v5.InitializeResponse
expectErr bool
}
tests := map[string]testCase{
"success": {
req: v5.InitializeRequest{
Config: map[string]interface{}{
"foo": "bar",
},
VerifyConnection: true,
},
initConfig: map[string]interface{}{
"foo": "bar",
},
initCalls: 1,
expectedResp: v5.InitializeResponse{
Config: map[string]interface{}{
"foo": "bar",
},
},
expectErr: false,
},
"error": {
req: v5.InitializeRequest{
Config: map[string]interface{}{
"foo": "bar",
},
VerifyConnection: true,
},
initErr: fmt.Errorf("test error"),
initCalls: 1,
expectedResp: v5.InitializeResponse{},
expectErr: true,
},
}
for name, test := range tests {
t.Run(name, func(t *testing.T) {
legacyDB := new(mockLegacyDatabase)
legacyDB.On("Init", mock.Anything, mock.Anything, mock.Anything).
Return(test.initConfig, test.initErr)
defer legacyDB.AssertNumberOfCalls(t, "Init", test.initCalls)
dbw := databaseVersionWrapper{
v4: legacyDB,
}
resp, err := dbw.Initialize(context.Background(), test.req)
if test.expectErr && err == nil {
t.Fatalf("err expected, got nil")
}
if !test.expectErr && err != nil {
t.Fatalf("no error expected, got: %s", err)
}
if !reflect.DeepEqual(resp, test.expectedResp) {
t.Fatalf("Actual resp: %#v\nExpected resp: %#v", resp, test.expectedResp)
}
})
}
}
func TestNewUser_missingDB(t *testing.T) {
dbw := databaseVersionWrapper{}
req := v5.NewUserRequest{}
resp, pass, err := dbw.NewUser(context.Background(), req)
if err == nil {
t.Fatalf("err expected, got nil")
}
expectedResp := v5.NewUserResponse{}
if !reflect.DeepEqual(resp, expectedResp) {
t.Fatalf("Actual resp: %#v\nExpected resp: %#v", resp, expectedResp)
}
if pass != "" {
t.Fatalf("Password should be empty but was: %s", pass)
}
}
func TestNewUser_newDB(t *testing.T) {
type testCase struct {
req v5.NewUserRequest
newUserResp v5.NewUserResponse
newUserErr error
newUserCalls int
expectedResp v5.NewUserResponse
expectErr bool
}
tests := map[string]testCase{
"success": {
req: v5.NewUserRequest{
Password: "new_password",
},
newUserResp: v5.NewUserResponse{
Username: "newuser",
},
newUserCalls: 1,
expectedResp: v5.NewUserResponse{
Username: "newuser",
},
expectErr: false,
},
"error": {
req: v5.NewUserRequest{
Password: "new_password",
},
newUserErr: fmt.Errorf("test error"),
newUserCalls: 1,
expectedResp: v5.NewUserResponse{},
expectErr: true,
},
}
for name, test := range tests {
t.Run(name, func(t *testing.T) {
newDB := new(mockNewDatabase)
newDB.On("NewUser", mock.Anything, mock.Anything).
Return(test.newUserResp, test.newUserErr)
defer newDB.AssertNumberOfCalls(t, "NewUser", test.newUserCalls)
dbw := databaseVersionWrapper{
v5: newDB,
}
resp, password, err := dbw.NewUser(context.Background(), test.req)
if test.expectErr && err == nil {
t.Fatalf("err expected, got nil")
}
if !test.expectErr && err != nil {
t.Fatalf("no error expected, got: %s", err)
}
if !reflect.DeepEqual(resp, test.expectedResp) {
t.Fatalf("Actual resp: %#v\nExpected resp: %#v", resp, test.expectedResp)
}
if password != test.req.Password {
t.Fatalf("Actual password: %s Expected password: %s", password, test.req.Password)
}
})
}
}
func TestNewUser_legacyDB(t *testing.T) {
type testCase struct {
req v5.NewUserRequest
createUserUsername string
createUserPassword string
createUserErr error
createUserCalls int
expectedResp v5.NewUserResponse
expectedPassword string
expectErr bool
}
tests := map[string]testCase{
"success": {
req: v5.NewUserRequest{
Password: "new_password",
},
createUserUsername: "newuser",
createUserPassword: "securepassword",
createUserCalls: 1,
expectedResp: v5.NewUserResponse{
Username: "newuser",
},
expectedPassword: "securepassword",
expectErr: false,
},
"error": {
req: v5.NewUserRequest{
Password: "new_password",
},
createUserErr: fmt.Errorf("test error"),
createUserCalls: 1,
expectedResp: v5.NewUserResponse{},
expectErr: true,
},
}
for name, test := range tests {
t.Run(name, func(t *testing.T) {
legacyDB := new(mockLegacyDatabase)
legacyDB.On("CreateUser", mock.Anything, mock.Anything, mock.Anything, mock.Anything).
Return(test.createUserUsername, test.createUserPassword, test.createUserErr)
defer legacyDB.AssertNumberOfCalls(t, "CreateUser", test.createUserCalls)
dbw := databaseVersionWrapper{
v4: legacyDB,
}
resp, password, err := dbw.NewUser(context.Background(), test.req)
if test.expectErr && err == nil {
t.Fatalf("err expected, got nil")
}
if !test.expectErr && err != nil {
t.Fatalf("no error expected, got: %s", err)
}
if !reflect.DeepEqual(resp, test.expectedResp) {
t.Fatalf("Actual resp: %#v\nExpected resp: %#v", resp, test.expectedResp)
}
if password != test.expectedPassword {
t.Fatalf("Actual password: %s Expected password: %s", password, test.req.Password)
}
})
}
}
func TestUpdateUser_missingDB(t *testing.T) {
dbw := databaseVersionWrapper{}
req := v5.UpdateUserRequest{}
resp, err := dbw.UpdateUser(context.Background(), req, false)
if err == nil {
t.Fatalf("err expected, got nil")
}
expectedConfig := map[string]interface{}(nil)
if !reflect.DeepEqual(resp, expectedConfig) {
t.Fatalf("Actual config: %#v\nExpected config: %#v", resp, expectedConfig)
}
}
func TestUpdateUser_newDB(t *testing.T) {
type testCase struct {
req v5.UpdateUserRequest
updateUserErr error
updateUserCalls int
expectedResp v5.UpdateUserResponse
expectErr bool
}
tests := map[string]testCase{
"success": {
req: v5.UpdateUserRequest{
Username: "existing_user",
},
updateUserCalls: 1,
expectErr: false,
},
"error": {
req: v5.UpdateUserRequest{
Username: "existing_user",
},
updateUserErr: fmt.Errorf("test error"),
updateUserCalls: 1,
expectErr: true,
},
}
for name, test := range tests {
t.Run(name, func(t *testing.T) {
newDB := new(mockNewDatabase)
newDB.On("UpdateUser", mock.Anything, mock.Anything).
Return(v5.UpdateUserResponse{}, test.updateUserErr)
defer newDB.AssertNumberOfCalls(t, "UpdateUser", test.updateUserCalls)
dbw := databaseVersionWrapper{
v5: newDB,
}
_, err := dbw.UpdateUser(context.Background(), test.req, false)
if test.expectErr && err == nil {
t.Fatalf("err expected, got nil")
}
if !test.expectErr && err != nil {
t.Fatalf("no error expected, got: %s", err)
}
})
}
}
func TestUpdateUser_legacyDB(t *testing.T) {
type testCase struct {
req v5.UpdateUserRequest
isRootUser bool
setCredentialsErr error
setCredentialsCalls int
rotateRootConfig map[string]interface{}
rotateRootErr error
rotateRootCalls int
renewUserErr error
renewUserCalls int
expectedConfig map[string]interface{}
expectErr bool
}
tests := map[string]testCase{
"missing changes": {
req: v5.UpdateUserRequest{
Username: "existing_user",
},
isRootUser: false,
setCredentialsCalls: 0,
rotateRootCalls: 0,
renewUserCalls: 0,
expectErr: true,
},
"both password and expiration changes": {
req: v5.UpdateUserRequest{
Username: "existing_user",
Password: &v5.ChangePassword{},
Expiration: &v5.ChangeExpiration{},
},
isRootUser: false,
setCredentialsCalls: 0,
rotateRootCalls: 0,
renewUserCalls: 0,
expectErr: true,
},
"change password - SetCredentials": {
req: v5.UpdateUserRequest{
Username: "existing_user",
Password: &v5.ChangePassword{
NewPassword: "newpassowrd",
},
},
isRootUser: false,
setCredentialsErr: nil,
setCredentialsCalls: 1,
rotateRootCalls: 0,
renewUserCalls: 0,
expectedConfig: nil,
expectErr: false,
},
"change password - SetCredentials failed": {
req: v5.UpdateUserRequest{
Username: "existing_user",
Password: &v5.ChangePassword{
NewPassword: "newpassowrd",
},
},
isRootUser: false,
setCredentialsErr: fmt.Errorf("set credentials failed"),
setCredentialsCalls: 1,
rotateRootCalls: 0,
renewUserCalls: 0,
expectedConfig: nil,
expectErr: true,
},
"change password - SetCredentials unimplemented but not a root user": {
req: v5.UpdateUserRequest{
Username: "existing_user",
Password: &v5.ChangePassword{
NewPassword: "newpassowrd",
},
},
isRootUser: false,
setCredentialsErr: status.Error(codes.Unimplemented, "SetCredentials is not implemented"),
setCredentialsCalls: 1,
rotateRootCalls: 0,
renewUserCalls: 0,
expectedConfig: nil,
expectErr: true,
},
"change password - RotateRootCredentials (gRPC Unimplemented)": {
req: v5.UpdateUserRequest{
Username: "existing_user",
Password: &v5.ChangePassword{
NewPassword: "newpassowrd",
},
},
isRootUser: true,
setCredentialsErr: status.Error(codes.Unimplemented, "SetCredentials is not implemented"),
setCredentialsCalls: 1,
rotateRootConfig: map[string]interface{}{
"foo": "bar",
},
rotateRootCalls: 1,
renewUserCalls: 0,
expectedConfig: map[string]interface{}{
"foo": "bar",
},
expectErr: false,
},
"change password - RotateRootCredentials (ErrPluginStaticUnsupported)": {
req: v5.UpdateUserRequest{
Username: "existing_user",
Password: &v5.ChangePassword{
NewPassword: "newpassowrd",
},
},
isRootUser: true,
setCredentialsErr: v4.ErrPluginStaticUnsupported,
setCredentialsCalls: 1,
rotateRootConfig: map[string]interface{}{
"foo": "bar",
},
rotateRootCalls: 1,
renewUserCalls: 0,
expectedConfig: map[string]interface{}{
"foo": "bar",
},
expectErr: false,
},
"change password - RotateRootCredentials failed": {
req: v5.UpdateUserRequest{
Username: "existing_user",
Password: &v5.ChangePassword{
NewPassword: "newpassowrd",
},
},
isRootUser: true,
setCredentialsErr: status.Error(codes.Unimplemented, "SetCredentials is not implemented"),
setCredentialsCalls: 1,
rotateRootErr: fmt.Errorf("rotate root failed"),
rotateRootCalls: 1,
renewUserCalls: 0,
expectedConfig: nil,
expectErr: true,
},
"change expiration": {
req: v5.UpdateUserRequest{
Username: "existing_user",
Expiration: &v5.ChangeExpiration{
NewExpiration: time.Now(),
},
},
isRootUser: false,
setCredentialsCalls: 0,
rotateRootCalls: 0,
renewUserErr: nil,
renewUserCalls: 1,
expectedConfig: nil,
expectErr: false,
},
"change expiration failed": {
req: v5.UpdateUserRequest{
Username: "existing_user",
Expiration: &v5.ChangeExpiration{
NewExpiration: time.Now(),
},
},
isRootUser: false,
setCredentialsCalls: 0,
rotateRootCalls: 0,
renewUserErr: fmt.Errorf("test error"),
renewUserCalls: 1,
expectedConfig: nil,
expectErr: true,
},
}
for name, test := range tests {
t.Run(name, func(t *testing.T) {
legacyDB := new(mockLegacyDatabase)
legacyDB.On("SetCredentials", mock.Anything, mock.Anything, mock.Anything).
Return("", "", test.setCredentialsErr)
defer legacyDB.AssertNumberOfCalls(t, "SetCredentials", test.setCredentialsCalls)
legacyDB.On("RotateRootCredentials", mock.Anything, mock.Anything).
Return(test.rotateRootConfig, test.rotateRootErr)
defer legacyDB.AssertNumberOfCalls(t, "RotateRootCredentials", test.rotateRootCalls)
legacyDB.On("RenewUser", mock.Anything, mock.Anything, mock.Anything, mock.Anything).
Return(test.renewUserErr)
defer legacyDB.AssertNumberOfCalls(t, "RenewUser", test.renewUserCalls)
dbw := databaseVersionWrapper{
v4: legacyDB,
}
newConfig, err := dbw.UpdateUser(context.Background(), test.req, test.isRootUser)
if test.expectErr && err == nil {
t.Fatalf("err expected, got nil")
}
if !test.expectErr && err != nil {
t.Fatalf("no error expected, got: %s", err)
}
if !reflect.DeepEqual(newConfig, test.expectedConfig) {
t.Fatalf("Actual config: %#v\nExpected config: %#v", newConfig, test.expectedConfig)
}
})
}
}
func TestDeleteUser_missingDB(t *testing.T) {
dbw := databaseVersionWrapper{}
req := v5.DeleteUserRequest{}
_, err := dbw.DeleteUser(context.Background(), req)
if err == nil {
t.Fatalf("err expected, got nil")
}
}
func TestDeleteUser_newDB(t *testing.T) {
type testCase struct {
req v5.DeleteUserRequest
deleteUserErr error
deleteUserCalls int
expectErr bool
}
tests := map[string]testCase{
"success": {
req: v5.DeleteUserRequest{
Username: "existing_user",
},
deleteUserErr: nil,
deleteUserCalls: 1,
expectErr: false,
},
"error": {
req: v5.DeleteUserRequest{
Username: "existing_user",
},
deleteUserErr: fmt.Errorf("test error"),
deleteUserCalls: 1,
expectErr: true,
},
}
for name, test := range tests {
t.Run(name, func(t *testing.T) {
newDB := new(mockNewDatabase)
newDB.On("DeleteUser", mock.Anything, mock.Anything).
Return(v5.DeleteUserResponse{}, test.deleteUserErr)
defer newDB.AssertNumberOfCalls(t, "DeleteUser", test.deleteUserCalls)
dbw := databaseVersionWrapper{
v5: newDB,
}
_, err := dbw.DeleteUser(context.Background(), test.req)
if test.expectErr && err == nil {
t.Fatalf("err expected, got nil")
}
if !test.expectErr && err != nil {
t.Fatalf("no error expected, got: %s", err)
}
})
}
}
func TestDeleteUser_legacyDB(t *testing.T) {
type testCase struct {
req v5.DeleteUserRequest
revokeUserErr error
revokeUserCalls int
expectErr bool
}
tests := map[string]testCase{
"success": {
req: v5.DeleteUserRequest{
Username: "existing_user",
},
revokeUserErr: nil,
revokeUserCalls: 1,
expectErr: false,
},
"error": {
req: v5.DeleteUserRequest{
Username: "existing_user",
},
revokeUserErr: fmt.Errorf("test error"),
revokeUserCalls: 1,
expectErr: true,
},
}
for name, test := range tests {
t.Run(name, func(t *testing.T) {
legacyDB := new(mockLegacyDatabase)
legacyDB.On("RevokeUser", mock.Anything, mock.Anything, mock.Anything, mock.Anything).
Return(test.revokeUserErr)
defer legacyDB.AssertNumberOfCalls(t, "RevokeUser", test.revokeUserCalls)
dbw := databaseVersionWrapper{
v4: legacyDB,
}
_, err := dbw.DeleteUser(context.Background(), test.req)
if test.expectErr && err == nil {
t.Fatalf("err expected, got nil")
}
if !test.expectErr && err != nil {
t.Fatalf("no error expected, got: %s", err)
}
})
}
}
type badValue struct{}
func (badValue) MarshalJSON() ([]byte, error) {
return nil, fmt.Errorf("this value cannot be marshalled to JSON")
}
var _ logical.Storage = fakeStorage{}
type fakeStorage struct {
putErr error
}
func (f fakeStorage) Put(ctx context.Context, entry *logical.StorageEntry) error {
return f.putErr
}
func (f fakeStorage) List(ctx context.Context, s string) ([]string, error) {
panic("list not implemented")
}
func (f fakeStorage) Get(ctx context.Context, s string) (*logical.StorageEntry, error) {
panic("get not implemented")
}
func (f fakeStorage) Delete(ctx context.Context, s string) error {
panic("delete not implemented")
}
func TestStoreConfig(t *testing.T) {
type testCase struct {
config *DatabaseConfig
putErr error
expectErr bool
}
tests := map[string]testCase{
"bad config": {
config: &DatabaseConfig{
PluginName: "testplugin",
ConnectionDetails: map[string]interface{}{
"bad value": badValue{},
},
},
putErr: nil,
expectErr: true,
},
"storage error": {
config: &DatabaseConfig{
PluginName: "testplugin",
ConnectionDetails: map[string]interface{}{
"foo": "bar",
},
},
putErr: fmt.Errorf("failed to store config"),
expectErr: true,
},
"happy path": {
config: &DatabaseConfig{
PluginName: "testplugin",
ConnectionDetails: map[string]interface{}{
"foo": "bar",
},
},
putErr: nil,
expectErr: false,
},
}
for name, test := range tests {
t.Run(name, func(t *testing.T) {
storage := fakeStorage{
putErr: test.putErr,
}
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()
err := storeConfig(ctx, storage, "testconfig", test.config)
if test.expectErr && err == nil {
t.Fatalf("err expected, got nil")
}
if !test.expectErr && err != nil {
t.Fatalf("no error expected, got: %s", err)
}
})
}
}