159 lines
4.1 KiB
Go
159 lines
4.1 KiB
Go
// Copyright (c) HashiCorp, Inc.
|
|
// SPDX-License-Identifier: MPL-2.0
|
|
|
|
package dbplugin
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"reflect"
|
|
"testing"
|
|
"time"
|
|
|
|
log "github.com/hashicorp/go-hclog"
|
|
"github.com/hashicorp/vault/sdk/database/dbplugin/v5/proto"
|
|
"github.com/hashicorp/vault/sdk/helper/consts"
|
|
"github.com/hashicorp/vault/sdk/helper/pluginutil"
|
|
"github.com/hashicorp/vault/sdk/helper/wrapping"
|
|
"github.com/hashicorp/vault/sdk/logical"
|
|
"github.com/stretchr/testify/mock"
|
|
"google.golang.org/grpc"
|
|
)
|
|
|
|
func TestNewPluginClient(t *testing.T) {
|
|
type testCase struct {
|
|
config pluginutil.PluginClientConfig
|
|
pluginClient pluginutil.PluginClient
|
|
expectedResp *DatabasePluginClient
|
|
expectedErr error
|
|
}
|
|
|
|
tests := map[string]testCase{
|
|
"happy path": {
|
|
config: testPluginClientConfig(),
|
|
pluginClient: &fakePluginClient{
|
|
connResp: nil,
|
|
dispenseResp: gRPCClient{client: fakeClient{}},
|
|
dispenseErr: nil,
|
|
},
|
|
expectedResp: &DatabasePluginClient{
|
|
client: &fakePluginClient{
|
|
connResp: nil,
|
|
dispenseResp: gRPCClient{client: fakeClient{}},
|
|
dispenseErr: nil,
|
|
},
|
|
Database: gRPCClient{client: proto.NewDatabaseClient(nil), versionClient: logical.NewPluginVersionClient(nil), doneCtx: context.Context(nil)},
|
|
},
|
|
expectedErr: nil,
|
|
},
|
|
"dispense error": {
|
|
config: testPluginClientConfig(),
|
|
pluginClient: &fakePluginClient{
|
|
connResp: nil,
|
|
dispenseResp: gRPCClient{},
|
|
dispenseErr: errors.New("dispense error"),
|
|
},
|
|
expectedResp: nil,
|
|
expectedErr: errors.New("dispense error"),
|
|
},
|
|
"error unsupported client type": {
|
|
config: testPluginClientConfig(),
|
|
pluginClient: &fakePluginClient{
|
|
connResp: nil,
|
|
dispenseResp: nil,
|
|
dispenseErr: nil,
|
|
},
|
|
expectedResp: nil,
|
|
expectedErr: errors.New("unsupported client type"),
|
|
},
|
|
}
|
|
|
|
for name, test := range tests {
|
|
t.Run(name, func(t *testing.T) {
|
|
ctx := context.Background()
|
|
|
|
mockWrapper := new(mockRunnerUtil)
|
|
mockWrapper.On("NewPluginClient", ctx, mock.Anything).
|
|
Return(test.pluginClient, nil)
|
|
defer mockWrapper.AssertNumberOfCalls(t, "NewPluginClient", 1)
|
|
|
|
resp, err := NewPluginClient(ctx, mockWrapper, test.config)
|
|
if test.expectedErr != nil && err == nil {
|
|
t.Fatalf("err expected, got nil")
|
|
}
|
|
if test.expectedErr == nil && err != nil {
|
|
t.Fatalf("no error expected, got: %s", err)
|
|
}
|
|
if test.expectedErr == nil && !reflect.DeepEqual(resp, test.expectedResp) {
|
|
t.Fatalf("Actual response: %#v\nExpected response: %#v", resp, test.expectedResp)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func testPluginClientConfig() pluginutil.PluginClientConfig {
|
|
return pluginutil.PluginClientConfig{
|
|
Name: "test-plugin",
|
|
PluginSets: PluginSets,
|
|
PluginType: consts.PluginTypeDatabase,
|
|
HandshakeConfig: HandshakeConfig,
|
|
Logger: log.NewNullLogger(),
|
|
IsMetadataMode: true,
|
|
AutoMTLS: true,
|
|
}
|
|
}
|
|
|
|
// Compile-time check that *fakePluginClient satisfies pluginutil.PluginClient.
var _ pluginutil.PluginClient = (*fakePluginClient)(nil)
|
|
|
|
type fakePluginClient struct {
|
|
connResp grpc.ClientConnInterface
|
|
|
|
dispenseResp interface{}
|
|
dispenseErr error
|
|
}
|
|
|
|
func (f *fakePluginClient) Conn() grpc.ClientConnInterface {
|
|
return nil
|
|
}
|
|
|
|
func (f *fakePluginClient) Reload() error {
|
|
return nil
|
|
}
|
|
|
|
func (f *fakePluginClient) Dispense(name string) (interface{}, error) {
|
|
return f.dispenseResp, f.dispenseErr
|
|
}
|
|
|
|
func (f *fakePluginClient) Ping() error {
|
|
return nil
|
|
}
|
|
|
|
func (f *fakePluginClient) Close() error {
|
|
return nil
|
|
}
|
|
|
|
// Compile-time check that *mockRunnerUtil satisfies pluginutil.RunnerUtil.
var _ pluginutil.RunnerUtil = (*mockRunnerUtil)(nil)
|
|
|
|
type mockRunnerUtil struct {
|
|
mock.Mock
|
|
}
|
|
|
|
func (m *mockRunnerUtil) VaultVersion(ctx context.Context) (string, error) {
|
|
return "dummyversion", nil
|
|
}
|
|
|
|
func (m *mockRunnerUtil) NewPluginClient(ctx context.Context, config pluginutil.PluginClientConfig) (pluginutil.PluginClient, error) {
|
|
args := m.Called(ctx, config)
|
|
return args.Get(0).(pluginutil.PluginClient), args.Error(1)
|
|
}
|
|
|
|
func (m *mockRunnerUtil) ResponseWrapData(ctx context.Context, data map[string]interface{}, ttl time.Duration, jwt bool) (*wrapping.ResponseWrapInfo, error) {
|
|
args := m.Called(ctx, data, ttl, jwt)
|
|
return args.Get(0).(*wrapping.ResponseWrapInfo), args.Error(1)
|
|
}
|
|
|
|
func (m *mockRunnerUtil) MlockEnabled() bool {
|
|
args := m.Called()
|
|
return args.Bool(0)
|
|
}
|