open-vault/sdk/database/dbplugin/v5/grpc_client_test.go

565 lines
12 KiB
Go

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package dbplugin
import (
"context"
"encoding/json"
"errors"
"reflect"
"testing"
"time"
"github.com/hashicorp/vault/sdk/database/dbplugin/v5/proto"
"google.golang.org/grpc"
)
func TestGRPCClient_Initialize(t *testing.T) {
type testCase struct {
client proto.DatabaseClient
req InitializeRequest
expectedResp InitializeResponse
assertErr errorAssertion
}
tests := map[string]testCase{
"bad config": {
client: fakeClient{},
req: InitializeRequest{
Config: map[string]interface{}{
"foo": badJSONValue{},
},
},
assertErr: assertErrNotNil,
},
"database error": {
client: fakeClient{
initErr: errors.New("initialize error"),
},
req: InitializeRequest{
Config: map[string]interface{}{
"foo": "bar",
},
},
assertErr: assertErrNotNil,
},
"happy path": {
client: fakeClient{
initResp: &proto.InitializeResponse{
ConfigData: marshal(t, map[string]interface{}{
"foo": "bar",
"baz": "biz",
}),
},
},
req: InitializeRequest{
Config: map[string]interface{}{
"foo": "bar",
},
},
expectedResp: InitializeResponse{
Config: map[string]interface{}{
"foo": "bar",
"baz": "biz",
},
},
assertErr: assertErrNil,
},
"JSON number type in initialize request": {
client: fakeClient{
initResp: &proto.InitializeResponse{
ConfigData: marshal(t, map[string]interface{}{
"foo": "bar",
"max": "10",
}),
},
},
req: InitializeRequest{
Config: map[string]interface{}{
"foo": "bar",
"max": json.Number("10"),
},
},
expectedResp: InitializeResponse{
Config: map[string]interface{}{
"foo": "bar",
"max": "10",
},
},
assertErr: assertErrNil,
},
}
for name, test := range tests {
t.Run(name, func(t *testing.T) {
c := gRPCClient{
client: test.client,
doneCtx: nil,
}
// Context doesn't need to timeout since this is just passed through
ctx := context.Background()
resp, err := c.Initialize(ctx, test.req)
test.assertErr(t, err)
if !reflect.DeepEqual(resp, test.expectedResp) {
t.Fatalf("Actual response: %#v\nExpected response: %#v", resp, test.expectedResp)
}
})
}
}
func TestGRPCClient_NewUser(t *testing.T) {
runningCtx := context.Background()
cancelledCtx, cancel := context.WithCancel(context.Background())
cancel()
type testCase struct {
client proto.DatabaseClient
req NewUserRequest
doneCtx context.Context
expectedResp NewUserResponse
assertErr errorAssertion
}
tests := map[string]testCase{
"missing password": {
client: fakeClient{},
req: NewUserRequest{
Password: "",
Expiration: time.Now(),
},
doneCtx: runningCtx,
assertErr: assertErrNotNil,
},
"bad expiration": {
client: fakeClient{},
req: NewUserRequest{
Password: "njkvcb8y934u90grsnkjl",
Expiration: invalidExpiration,
},
doneCtx: runningCtx,
assertErr: assertErrNotNil,
},
"database error": {
client: fakeClient{
newUserErr: errors.New("new user error"),
},
req: NewUserRequest{
Password: "njkvcb8y934u90grsnkjl",
Expiration: time.Now(),
},
doneCtx: runningCtx,
assertErr: assertErrNotNil,
},
"plugin shut down": {
client: fakeClient{
newUserErr: errors.New("new user error"),
},
req: NewUserRequest{
Password: "njkvcb8y934u90grsnkjl",
Expiration: time.Now(),
},
doneCtx: cancelledCtx,
assertErr: assertErrEquals(ErrPluginShutdown),
},
"happy path": {
client: fakeClient{
newUserResp: &proto.NewUserResponse{
Username: "new_user",
},
},
req: NewUserRequest{
Password: "njkvcb8y934u90grsnkjl",
Expiration: time.Now(),
},
doneCtx: runningCtx,
expectedResp: NewUserResponse{
Username: "new_user",
},
assertErr: assertErrNil,
},
}
for name, test := range tests {
t.Run(name, func(t *testing.T) {
c := gRPCClient{
client: test.client,
doneCtx: test.doneCtx,
}
ctx := context.Background()
resp, err := c.NewUser(ctx, test.req)
test.assertErr(t, err)
if !reflect.DeepEqual(resp, test.expectedResp) {
t.Fatalf("Actual response: %#v\nExpected response: %#v", resp, test.expectedResp)
}
})
}
}
func TestGRPCClient_UpdateUser(t *testing.T) {
runningCtx := context.Background()
cancelledCtx, cancel := context.WithCancel(context.Background())
cancel()
type testCase struct {
client proto.DatabaseClient
req UpdateUserRequest
doneCtx context.Context
assertErr errorAssertion
}
tests := map[string]testCase{
"missing username": {
client: fakeClient{},
req: UpdateUserRequest{},
doneCtx: runningCtx,
assertErr: assertErrNotNil,
},
"missing changes": {
client: fakeClient{},
req: UpdateUserRequest{
Username: "user",
},
doneCtx: runningCtx,
assertErr: assertErrNotNil,
},
"empty password": {
client: fakeClient{},
req: UpdateUserRequest{
Username: "user",
Password: &ChangePassword{
NewPassword: "",
},
},
doneCtx: runningCtx,
assertErr: assertErrNotNil,
},
"zero expiration": {
client: fakeClient{},
req: UpdateUserRequest{
Username: "user",
Expiration: &ChangeExpiration{
NewExpiration: time.Time{},
},
},
doneCtx: runningCtx,
assertErr: assertErrNotNil,
},
"bad expiration": {
client: fakeClient{},
req: UpdateUserRequest{
Username: "user",
Expiration: &ChangeExpiration{
NewExpiration: invalidExpiration,
},
},
doneCtx: runningCtx,
assertErr: assertErrNotNil,
},
"database error": {
client: fakeClient{
updateUserErr: errors.New("update user error"),
},
req: UpdateUserRequest{
Username: "user",
Password: &ChangePassword{
NewPassword: "asdf",
},
},
doneCtx: runningCtx,
assertErr: assertErrNotNil,
},
"plugin shut down": {
client: fakeClient{
updateUserErr: errors.New("update user error"),
},
req: UpdateUserRequest{
Username: "user",
Password: &ChangePassword{
NewPassword: "asdf",
},
},
doneCtx: cancelledCtx,
assertErr: assertErrEquals(ErrPluginShutdown),
},
"happy path - change password": {
client: fakeClient{},
req: UpdateUserRequest{
Username: "user",
Password: &ChangePassword{
NewPassword: "asdf",
},
},
doneCtx: runningCtx,
assertErr: assertErrNil,
},
"happy path - change expiration": {
client: fakeClient{},
req: UpdateUserRequest{
Username: "user",
Expiration: &ChangeExpiration{
NewExpiration: time.Now(),
},
},
doneCtx: runningCtx,
assertErr: assertErrNil,
},
}
for name, test := range tests {
t.Run(name, func(t *testing.T) {
c := gRPCClient{
client: test.client,
doneCtx: test.doneCtx,
}
ctx := context.Background()
_, err := c.UpdateUser(ctx, test.req)
test.assertErr(t, err)
})
}
}
func TestGRPCClient_DeleteUser(t *testing.T) {
runningCtx := context.Background()
cancelledCtx, cancel := context.WithCancel(context.Background())
cancel()
type testCase struct {
client proto.DatabaseClient
req DeleteUserRequest
doneCtx context.Context
assertErr errorAssertion
}
tests := map[string]testCase{
"missing username": {
client: fakeClient{},
req: DeleteUserRequest{},
doneCtx: runningCtx,
assertErr: assertErrNotNil,
},
"database error": {
client: fakeClient{
deleteUserErr: errors.New("delete user error'"),
},
req: DeleteUserRequest{
Username: "user",
},
doneCtx: runningCtx,
assertErr: assertErrNotNil,
},
"plugin shut down": {
client: fakeClient{
deleteUserErr: errors.New("delete user error'"),
},
req: DeleteUserRequest{
Username: "user",
},
doneCtx: cancelledCtx,
assertErr: assertErrEquals(ErrPluginShutdown),
},
"happy path": {
client: fakeClient{},
req: DeleteUserRequest{
Username: "user",
},
doneCtx: runningCtx,
assertErr: assertErrNil,
},
}
for name, test := range tests {
t.Run(name, func(t *testing.T) {
c := gRPCClient{
client: test.client,
doneCtx: test.doneCtx,
}
ctx := context.Background()
_, err := c.DeleteUser(ctx, test.req)
test.assertErr(t, err)
})
}
}
func TestGRPCClient_Type(t *testing.T) {
runningCtx := context.Background()
cancelledCtx, cancel := context.WithCancel(context.Background())
cancel()
type testCase struct {
client proto.DatabaseClient
doneCtx context.Context
expectedType string
assertErr errorAssertion
}
tests := map[string]testCase{
"database error": {
client: fakeClient{
typeErr: errors.New("type error"),
},
doneCtx: runningCtx,
assertErr: assertErrNotNil,
},
"plugin shut down": {
client: fakeClient{
typeErr: errors.New("type error"),
},
doneCtx: cancelledCtx,
assertErr: assertErrEquals(ErrPluginShutdown),
},
"happy path": {
client: fakeClient{
typeResp: &proto.TypeResponse{
Type: "test type",
},
},
doneCtx: runningCtx,
expectedType: "test type",
assertErr: assertErrNil,
},
}
for name, test := range tests {
t.Run(name, func(t *testing.T) {
c := gRPCClient{
client: test.client,
doneCtx: test.doneCtx,
}
dbType, err := c.Type()
test.assertErr(t, err)
if dbType != test.expectedType {
t.Fatalf("Actual type: %s Expected type: %s", dbType, test.expectedType)
}
})
}
}
func TestGRPCClient_Close(t *testing.T) {
runningCtx := context.Background()
cancelledCtx, cancel := context.WithCancel(context.Background())
cancel()
type testCase struct {
client proto.DatabaseClient
doneCtx context.Context
assertErr errorAssertion
}
tests := map[string]testCase{
"database error": {
client: fakeClient{
typeErr: errors.New("type error"),
},
doneCtx: runningCtx,
assertErr: assertErrNotNil,
},
"plugin shut down": {
client: fakeClient{
typeErr: errors.New("type error"),
},
doneCtx: cancelledCtx,
assertErr: assertErrEquals(ErrPluginShutdown),
},
"happy path": {
client: fakeClient{},
doneCtx: runningCtx,
assertErr: assertErrNil,
},
}
for name, test := range tests {
t.Run(name, func(t *testing.T) {
c := gRPCClient{
client: test.client,
doneCtx: test.doneCtx,
}
err := c.Close()
test.assertErr(t, err)
})
}
}
type errorAssertion func(*testing.T, error)
func assertErrNotNil(t *testing.T, err error) {
t.Helper()
if err == nil {
t.Fatalf("err expected, got nil")
}
}
func assertErrNil(t *testing.T, err error) {
t.Helper()
if err != nil {
t.Fatalf("no error expected, got: %s", err)
}
}
func assertErrEquals(expectedErr error) errorAssertion {
return func(t *testing.T, err error) {
t.Helper()
if err != expectedErr {
t.Fatalf("Actual err: %#v Expected err: %#v", err, expectedErr)
}
}
}
var _ proto.DatabaseClient = fakeClient{}
type fakeClient struct {
initResp *proto.InitializeResponse
initErr error
newUserResp *proto.NewUserResponse
newUserErr error
updateUserResp *proto.UpdateUserResponse
updateUserErr error
deleteUserResp *proto.DeleteUserResponse
deleteUserErr error
typeResp *proto.TypeResponse
typeErr error
closeErr error
}
func (f fakeClient) Initialize(context.Context, *proto.InitializeRequest, ...grpc.CallOption) (*proto.InitializeResponse, error) {
return f.initResp, f.initErr
}
func (f fakeClient) NewUser(context.Context, *proto.NewUserRequest, ...grpc.CallOption) (*proto.NewUserResponse, error) {
return f.newUserResp, f.newUserErr
}
func (f fakeClient) UpdateUser(context.Context, *proto.UpdateUserRequest, ...grpc.CallOption) (*proto.UpdateUserResponse, error) {
return f.updateUserResp, f.updateUserErr
}
func (f fakeClient) DeleteUser(context.Context, *proto.DeleteUserRequest, ...grpc.CallOption) (*proto.DeleteUserResponse, error) {
return f.deleteUserResp, f.deleteUserErr
}
func (f fakeClient) Type(context.Context, *proto.Empty, ...grpc.CallOption) (*proto.TypeResponse, error) {
return f.typeResp, f.typeErr
}
func (f fakeClient) Close(context.Context, *proto.Empty, ...grpc.CallOption) (*proto.Empty, error) {
return &proto.Empty{}, f.typeErr
}