2020-10-15 19:20:12 +00:00
|
|
|
package dbplugin
|
|
|
|
|
|
|
|
import (
|
|
|
|
"fmt"
|
|
|
|
"reflect"
|
|
|
|
"strings"
|
|
|
|
"testing"
|
|
|
|
"time"
|
|
|
|
"unicode"
|
|
|
|
|
|
|
|
"github.com/hashicorp/vault/sdk/database/dbplugin/v5/proto"
|
|
|
|
"google.golang.org/protobuf/types/known/structpb"
|
|
|
|
"google.golang.org/protobuf/types/known/timestamppb"
|
|
|
|
)
|
|
|
|
|
|
|
|
func TestConversionsHaveAllFields(t *testing.T) {
|
|
|
|
t.Run("initReqToProto", func(t *testing.T) {
|
|
|
|
req := InitializeRequest{
|
|
|
|
Config: map[string]interface{}{
|
|
|
|
"foo": map[string]interface{}{
|
|
|
|
"bar": "baz",
|
|
|
|
},
|
|
|
|
},
|
|
|
|
VerifyConnection: true,
|
|
|
|
}
|
|
|
|
|
|
|
|
protoReq, err := initReqToProto(req)
|
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("Failed to convert request to proto request: %s", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
values := getAllGetterValues(protoReq)
|
|
|
|
if len(values) == 0 {
|
|
|
|
// Probably a test failure - the protos used in these tests should have Get functions on them
|
|
|
|
t.Fatalf("No values found from Get functions!")
|
|
|
|
}
|
|
|
|
|
|
|
|
for _, gtr := range values {
|
|
|
|
err := assertAllFieldsSet(fmt.Sprintf("InitializeRequest.%s", gtr.name), gtr.value)
|
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("%s", err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
})
|
|
|
|
|
|
|
|
t.Run("newUserReqToProto", func(t *testing.T) {
|
|
|
|
req := NewUserRequest{
|
|
|
|
UsernameConfig: UsernameMetadata{
|
|
|
|
DisplayName: "dispName",
|
|
|
|
RoleName: "roleName",
|
|
|
|
},
|
|
|
|
Statements: Statements{
|
|
|
|
Commands: []string{
|
|
|
|
"statement",
|
|
|
|
},
|
|
|
|
},
|
|
|
|
RollbackStatements: Statements{
|
|
|
|
Commands: []string{
|
|
|
|
"rollback_statement",
|
|
|
|
},
|
|
|
|
},
|
2022-05-17 16:21:26 +00:00
|
|
|
CredentialType: CredentialTypeRSAPrivateKey,
|
|
|
|
PublicKey: []byte("-----BEGIN PUBLIC KEY-----"),
|
|
|
|
Password: "password",
|
|
|
|
Expiration: time.Now(),
|
2020-10-15 19:20:12 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
protoReq, err := newUserReqToProto(req)
|
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("Failed to convert request to proto request: %s", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
values := getAllGetterValues(protoReq)
|
|
|
|
if len(values) == 0 {
|
|
|
|
// Probably a test failure - the protos used in these tests should have Get functions on them
|
|
|
|
t.Fatalf("No values found from Get functions!")
|
|
|
|
}
|
|
|
|
|
|
|
|
for _, gtr := range values {
|
|
|
|
err := assertAllFieldsSet(fmt.Sprintf("NewUserRequest.%s", gtr.name), gtr.value)
|
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("%s", err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
})
|
|
|
|
|
|
|
|
t.Run("updateUserReqToProto", func(t *testing.T) {
|
|
|
|
req := UpdateUserRequest{
|
2022-05-17 16:21:26 +00:00
|
|
|
Username: "username",
|
|
|
|
CredentialType: CredentialTypeRSAPrivateKey,
|
2020-10-15 19:20:12 +00:00
|
|
|
Password: &ChangePassword{
|
|
|
|
NewPassword: "newpassword",
|
|
|
|
Statements: Statements{
|
|
|
|
Commands: []string{
|
|
|
|
"statement",
|
|
|
|
},
|
|
|
|
},
|
|
|
|
},
|
2022-05-17 16:21:26 +00:00
|
|
|
PublicKey: &ChangePublicKey{
|
|
|
|
NewPublicKey: []byte("-----BEGIN PUBLIC KEY-----"),
|
|
|
|
Statements: Statements{
|
|
|
|
Commands: []string{
|
|
|
|
"statement",
|
|
|
|
},
|
|
|
|
},
|
|
|
|
},
|
2020-10-15 19:20:12 +00:00
|
|
|
Expiration: &ChangeExpiration{
|
|
|
|
NewExpiration: time.Now(),
|
|
|
|
Statements: Statements{
|
|
|
|
Commands: []string{
|
|
|
|
"statement",
|
|
|
|
},
|
|
|
|
},
|
|
|
|
},
|
|
|
|
}
|
|
|
|
|
|
|
|
protoReq, err := updateUserReqToProto(req)
|
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("Failed to convert request to proto request: %s", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
values := getAllGetterValues(protoReq)
|
|
|
|
if len(values) == 0 {
|
|
|
|
// Probably a test failure - the protos used in these tests should have Get functions on them
|
|
|
|
t.Fatalf("No values found from Get functions!")
|
|
|
|
}
|
|
|
|
|
|
|
|
for _, gtr := range values {
|
|
|
|
err := assertAllFieldsSet(fmt.Sprintf("UpdateUserRequest.%s", gtr.name), gtr.value)
|
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("%s", err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
})
|
|
|
|
|
|
|
|
t.Run("deleteUserReqToProto", func(t *testing.T) {
|
|
|
|
req := DeleteUserRequest{
|
|
|
|
Username: "username",
|
|
|
|
Statements: Statements{
|
|
|
|
Commands: []string{
|
|
|
|
"statement",
|
|
|
|
},
|
|
|
|
},
|
|
|
|
}
|
|
|
|
|
|
|
|
protoReq, err := deleteUserReqToProto(req)
|
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("Failed to convert request to proto request: %s", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
values := getAllGetterValues(protoReq)
|
|
|
|
if len(values) == 0 {
|
|
|
|
// Probably a test failure - the protos used in these tests should have Get functions on them
|
|
|
|
t.Fatalf("No values found from Get functions!")
|
|
|
|
}
|
|
|
|
|
|
|
|
for _, gtr := range values {
|
|
|
|
err := assertAllFieldsSet(fmt.Sprintf("DeleteUserRequest.%s", gtr.name), gtr.value)
|
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("%s", err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
})
|
|
|
|
|
|
|
|
t.Run("getUpdateUserRequest", func(t *testing.T) {
|
|
|
|
req := &proto.UpdateUserRequest{
|
2022-05-17 16:21:26 +00:00
|
|
|
Username: "username",
|
|
|
|
CredentialType: int32(CredentialTypeRSAPrivateKey),
|
2020-10-15 19:20:12 +00:00
|
|
|
Password: &proto.ChangePassword{
|
|
|
|
NewPassword: "newpass",
|
|
|
|
Statements: &proto.Statements{
|
|
|
|
Commands: []string{
|
|
|
|
"statement",
|
|
|
|
},
|
|
|
|
},
|
|
|
|
},
|
2022-05-17 16:21:26 +00:00
|
|
|
PublicKey: &proto.ChangePublicKey{
|
|
|
|
NewPublicKey: []byte("-----BEGIN PUBLIC KEY-----"),
|
|
|
|
Statements: &proto.Statements{
|
|
|
|
Commands: []string{
|
|
|
|
"statement",
|
|
|
|
},
|
|
|
|
},
|
|
|
|
},
|
2020-10-15 19:20:12 +00:00
|
|
|
Expiration: &proto.ChangeExpiration{
|
|
|
|
NewExpiration: timestamppb.Now(),
|
|
|
|
Statements: &proto.Statements{
|
|
|
|
Commands: []string{
|
|
|
|
"statement",
|
|
|
|
},
|
|
|
|
},
|
|
|
|
},
|
|
|
|
}
|
|
|
|
|
|
|
|
protoReq, err := getUpdateUserRequest(req)
|
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("Failed to convert request to proto request: %s", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
err = assertAllFieldsSet("proto.UpdateUserRequest", protoReq)
|
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("%s", err)
|
|
|
|
}
|
|
|
|
})
|
|
|
|
}
|
|
|
|
|
|
|
|
// getter pairs the name of a Get* method (with the "Get" prefix removed) with
// the first value that method returned.
type getter struct {
	name  string
	value interface{}
}

// getAllGetterValues calls every method in value's method set whose name
// starts with "Get" and returns the first result of each call, in the order
// reflect enumerates the methods.
//
// Methods that cannot be called without arguments, or that return nothing,
// are skipped instead of panicking inside reflect. A nil value yields no
// getters.
func getAllGetterValues(value interface{}) (values []getter) {
	if value == nil {
		// reflect.TypeOf(nil) is a nil Type; calling NumMethod on it panics.
		return nil
	}

	typ := reflect.TypeOf(value)
	val := reflect.ValueOf(value)
	for i := 0; i < typ.NumMethod(); i++ {
		method := typ.Method(i)
		if !strings.HasPrefix(method.Name, "Get") {
			continue
		}

		// The bound method's type excludes the receiver, so NumIn counts only
		// real parameters. A trailing variadic parameter may be omitted.
		valMethod := val.Method(i)
		sig := valMethod.Type()
		required := sig.NumIn()
		if sig.IsVariadic() {
			required--
		}
		if required != 0 || sig.NumOut() == 0 {
			continue
		}

		resp := valMethod.Call(nil)
		values = append(values, getter{
			name:  strings.TrimPrefix(method.Name, "Get"),
			value: resp[0].Interface(),
		})
	}
	return values
}
|
|
|
|
|
|
|
|
// Ensures the assertion works properly
|
|
|
|
func TestAssertAllFieldsSet(t *testing.T) {
|
|
|
|
type testCase struct {
|
|
|
|
value interface{}
|
|
|
|
expectErr bool
|
|
|
|
}
|
|
|
|
|
|
|
|
tests := map[string]testCase{
|
|
|
|
"zero int": {
|
|
|
|
value: 0,
|
|
|
|
expectErr: true,
|
|
|
|
},
|
|
|
|
"non-zero int": {
|
|
|
|
value: 1,
|
|
|
|
expectErr: false,
|
|
|
|
},
|
|
|
|
"zero float64": {
|
|
|
|
value: 0.0,
|
|
|
|
expectErr: true,
|
|
|
|
},
|
|
|
|
"non-zero float64": {
|
|
|
|
value: 1.0,
|
|
|
|
expectErr: false,
|
|
|
|
},
|
|
|
|
"empty string": {
|
|
|
|
value: "",
|
|
|
|
expectErr: true,
|
|
|
|
},
|
|
|
|
"true boolean": {
|
|
|
|
value: true,
|
|
|
|
expectErr: false,
|
|
|
|
},
|
|
|
|
"false boolean": { // False is an exception to the "is zero" rule
|
|
|
|
value: false,
|
|
|
|
expectErr: false,
|
|
|
|
},
|
|
|
|
"blank struct": {
|
|
|
|
value: struct{}{},
|
|
|
|
expectErr: true,
|
|
|
|
},
|
|
|
|
"non-blank but empty struct": {
|
|
|
|
value: struct {
|
|
|
|
str string
|
|
|
|
}{
|
|
|
|
str: "",
|
|
|
|
},
|
|
|
|
expectErr: true,
|
|
|
|
},
|
|
|
|
"non-empty string": {
|
|
|
|
value: "foo",
|
|
|
|
expectErr: false,
|
|
|
|
},
|
|
|
|
"non-empty struct": {
|
|
|
|
value: struct {
|
|
|
|
str string
|
|
|
|
}{
|
|
|
|
str: "foo",
|
|
|
|
},
|
|
|
|
expectErr: false,
|
|
|
|
},
|
|
|
|
"empty nested struct": {
|
|
|
|
value: struct {
|
|
|
|
Str string
|
|
|
|
Substruct struct {
|
|
|
|
Substr string
|
|
|
|
}
|
|
|
|
}{
|
|
|
|
Str: "foo",
|
|
|
|
Substruct: struct {
|
|
|
|
Substr string
|
|
|
|
}{}, // Empty sub-field
|
|
|
|
},
|
|
|
|
expectErr: true,
|
|
|
|
},
|
|
|
|
"filled nested struct": {
|
|
|
|
value: struct {
|
|
|
|
str string
|
|
|
|
substruct struct {
|
|
|
|
substr string
|
|
|
|
}
|
|
|
|
}{
|
|
|
|
str: "foo",
|
|
|
|
substruct: struct {
|
|
|
|
substr string
|
|
|
|
}{
|
|
|
|
substr: "sub-foo",
|
|
|
|
},
|
|
|
|
},
|
|
|
|
expectErr: false,
|
|
|
|
},
|
|
|
|
"nil map": {
|
|
|
|
value: map[string]string(nil),
|
|
|
|
expectErr: true,
|
|
|
|
},
|
|
|
|
"empty map": {
|
|
|
|
value: map[string]string{},
|
|
|
|
expectErr: true,
|
|
|
|
},
|
|
|
|
"filled map": {
|
|
|
|
value: map[string]string{
|
|
|
|
"foo": "bar",
|
|
|
|
"int": "42",
|
|
|
|
},
|
|
|
|
expectErr: false,
|
|
|
|
},
|
|
|
|
"map with empty string value": {
|
|
|
|
value: map[string]string{
|
|
|
|
"foo": "",
|
|
|
|
},
|
|
|
|
expectErr: true,
|
|
|
|
},
|
|
|
|
"nested map with empty string value": {
|
|
|
|
value: map[string]interface{}{
|
|
|
|
"bar": "baz",
|
|
|
|
"foo": map[string]interface{}{
|
|
|
|
"subfoo": "",
|
|
|
|
},
|
|
|
|
},
|
|
|
|
expectErr: true,
|
|
|
|
},
|
|
|
|
"nil slice": {
|
|
|
|
value: []string(nil),
|
|
|
|
expectErr: true,
|
|
|
|
},
|
|
|
|
"empty slice": {
|
|
|
|
value: []string{},
|
|
|
|
expectErr: true,
|
|
|
|
},
|
|
|
|
"filled slice": {
|
|
|
|
value: []string{
|
|
|
|
"foo",
|
|
|
|
},
|
|
|
|
expectErr: false,
|
|
|
|
},
|
|
|
|
"slice with empty string value": {
|
|
|
|
value: []string{
|
|
|
|
"",
|
|
|
|
},
|
|
|
|
expectErr: true,
|
|
|
|
},
|
|
|
|
"empty structpb": {
|
|
|
|
value: newStructPb(t, map[string]interface{}{}),
|
|
|
|
expectErr: true,
|
|
|
|
},
|
|
|
|
"filled structpb": {
|
|
|
|
value: newStructPb(t, map[string]interface{}{
|
|
|
|
"foo": "bar",
|
|
|
|
"int": 42,
|
|
|
|
}),
|
|
|
|
expectErr: false,
|
|
|
|
},
|
|
|
|
|
|
|
|
"pointer to zero int": {
|
|
|
|
value: intPtr(0),
|
|
|
|
expectErr: true,
|
|
|
|
},
|
|
|
|
"pointer to non-zero int": {
|
|
|
|
value: intPtr(1),
|
|
|
|
expectErr: false,
|
|
|
|
},
|
|
|
|
"pointer to zero float64": {
|
|
|
|
value: float64Ptr(0.0),
|
|
|
|
expectErr: true,
|
|
|
|
},
|
|
|
|
"pointer to non-zero float64": {
|
|
|
|
value: float64Ptr(1.0),
|
|
|
|
expectErr: false,
|
|
|
|
},
|
|
|
|
"pointer to nil string": {
|
|
|
|
value: new(string),
|
|
|
|
expectErr: true,
|
|
|
|
},
|
|
|
|
"pointer to non-nil string": {
|
|
|
|
value: strPtr("foo"),
|
|
|
|
expectErr: false,
|
|
|
|
},
|
|
|
|
}
|
|
|
|
|
|
|
|
for name, test := range tests {
|
|
|
|
t.Run(name, func(t *testing.T) {
|
|
|
|
err := assertAllFieldsSet("", test.value)
|
|
|
|
if test.expectErr && err == nil {
|
|
|
|
t.Fatalf("err expected, got nil")
|
|
|
|
}
|
|
|
|
if !test.expectErr && err != nil {
|
|
|
|
t.Fatalf("no error expected, got: %s", err)
|
|
|
|
}
|
|
|
|
})
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func assertAllFieldsSet(name string, val interface{}) error {
|
|
|
|
if val == nil {
|
|
|
|
return fmt.Errorf("value is nil")
|
|
|
|
}
|
|
|
|
|
|
|
|
rVal := reflect.ValueOf(val)
|
|
|
|
return assertAllFieldsSetValue(name, rVal)
|
|
|
|
}
|
|
|
|
|
|
|
|
func assertAllFieldsSetValue(name string, rVal reflect.Value) error {
|
|
|
|
// All booleans are allowed - we don't have a way of differentiating between
|
|
|
|
// and intentional false and a missing false
|
|
|
|
if rVal.Kind() == reflect.Bool {
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
// Primitives fall through here
|
|
|
|
if rVal.IsZero() {
|
|
|
|
return fmt.Errorf("%s is zero", name)
|
|
|
|
}
|
|
|
|
|
|
|
|
switch rVal.Kind() {
|
|
|
|
case reflect.Ptr, reflect.Interface:
|
|
|
|
return assertAllFieldsSetValue(name, rVal.Elem())
|
|
|
|
case reflect.Struct:
|
|
|
|
return assertAllFieldsSetStruct(name, rVal)
|
|
|
|
case reflect.Map:
|
|
|
|
if rVal.Len() == 0 {
|
|
|
|
return fmt.Errorf("%s (map type) is empty", name)
|
|
|
|
}
|
|
|
|
|
|
|
|
iter := rVal.MapRange()
|
|
|
|
for iter.Next() {
|
|
|
|
k := iter.Key()
|
|
|
|
v := iter.Value()
|
|
|
|
|
|
|
|
err := assertAllFieldsSetValue(fmt.Sprintf("%s[%s]", name, k), v)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
}
|
|
|
|
case reflect.Slice:
|
|
|
|
if rVal.Len() == 0 {
|
|
|
|
return fmt.Errorf("%s (slice type) is empty", name)
|
|
|
|
}
|
|
|
|
for i := 0; i < rVal.Len(); i++ {
|
|
|
|
sliceVal := rVal.Index(i)
|
|
|
|
err := assertAllFieldsSetValue(fmt.Sprintf("%s[%d]", name, i), sliceVal)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func assertAllFieldsSetStruct(name string, rVal reflect.Value) error {
|
|
|
|
switch rVal.Type() {
|
|
|
|
case reflect.TypeOf(timestamppb.Timestamp{}):
|
|
|
|
ts := rVal.Interface().(timestamppb.Timestamp)
|
|
|
|
if ts.AsTime().IsZero() {
|
|
|
|
return fmt.Errorf("%s is zero", name)
|
|
|
|
}
|
|
|
|
return nil
|
|
|
|
default:
|
|
|
|
for i := 0; i < rVal.NumField(); i++ {
|
|
|
|
field := rVal.Field(i)
|
|
|
|
fieldName := rVal.Type().Field(i)
|
|
|
|
|
|
|
|
// Skip fields that aren't exported
|
|
|
|
if unicode.IsLower([]rune(fieldName.Name)[0]) {
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
|
|
|
|
err := assertAllFieldsSetValue(fmt.Sprintf("%s.%s", name, fieldName.Name), field)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// intPtr returns a pointer to a newly allocated int holding i.
func intPtr(i int) *int {
	p := new(int)
	*p = i
	return p
}
|
|
|
|
|
|
|
|
// float64Ptr returns a pointer to a newly allocated float64 holding f.
func float64Ptr(f float64) *float64 {
	p := new(float64)
	*p = f
	return p
}
|
2021-04-08 16:43:39 +00:00
|
|
|
|
2020-10-15 19:20:12 +00:00
|
|
|
// strPtr returns a pointer to a newly allocated string holding str.
func strPtr(str string) *string {
	p := new(string)
	*p = str
	return p
}
|
|
|
|
|
|
|
|
func newStructPb(t *testing.T, m map[string]interface{}) *structpb.Struct {
|
|
|
|
t.Helper()
|
|
|
|
|
|
|
|
s, err := structpb.NewStruct(m)
|
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("Failed to convert map to struct: %s", err)
|
|
|
|
}
|
|
|
|
return s
|
|
|
|
}
|