lib/decode: fix hook to work with embedded squash struct

The decode hook is not call for the embedded squashed struct, so we need to recurse when we
find squash tags.

See https://github.com/mitchellh/mapstructure/issues/226
This commit is contained in:
Daniel Nephin 2020-12-22 14:14:01 -05:00
parent bc04a155fb
commit 7ffd1560aa
2 changed files with 89 additions and 15 deletions

View File

@ -73,12 +73,28 @@ func translationsForType(to reflect.Type) map[string]string {
translations := map[string]string{}
for i := 0; i < to.NumField(); i++ {
field := to.Field(i)
tags := fieldTags(field)
if tags.squash {
embedded := field.Type
if embedded.Kind() == reflect.Ptr {
embedded = embedded.Elem()
}
if embedded.Kind() != reflect.Struct {
// mapstructure will handle reporting this error
continue
}
for k, v := range translationsForType(embedded) {
translations[k] = v
}
continue
}
tag, ok := field.Tag.Lookup("alias")
if !ok {
continue
}
canonKey := strings.ToLower(canonicalFieldKey(field))
canonKey := strings.ToLower(tags.name)
for _, alias := range strings.Split(tag, ",") {
translations[strings.ToLower(alias)] = canonKey
}
@ -86,19 +102,31 @@ func translationsForType(to reflect.Type) map[string]string {
return translations
}
func canonicalFieldKey(field reflect.StructField) string {
func fieldTags(field reflect.StructField) mapstructureFieldTags {
tag, ok := field.Tag.Lookup("mapstructure")
if !ok {
return field.Name
return mapstructureFieldTags{name: field.Name}
}
parts := strings.SplitN(tag, ",", 2)
switch {
case len(parts) < 1:
return field.Name
case parts[0] == "":
return field.Name
tags := mapstructureFieldTags{name: field.Name}
parts := strings.Split(tag, ",")
if len(parts) == 0 {
return tags
}
return parts[0]
if parts[0] != "" {
tags.name = parts[0]
}
for _, part := range parts[1:] {
if part == "squash" {
tags.squash = true
}
}
return tags
}
type mapstructureFieldTags struct {
name string
squash bool
}
// HookWeakDecodeFromSlice looks for []map[string]interface{} and []interface{}

View File

@ -1,6 +1,7 @@
package decode
import (
"fmt"
"reflect"
"testing"
@ -210,16 +211,29 @@ type translateExample struct {
FieldWithMapstructureTag string `alias:"second" mapstructure:"field_with_mapstruct_tag"`
FieldWithMapstructureTagOmit string `mapstructure:"field_with_mapstruct_omit,omitempty" alias:"third"`
FieldWithEmptyTag string `mapstructure:"" alias:"forth"`
EmbeddedStruct `mapstructure:",squash"`
*PtrEmbeddedStruct `mapstructure:",squash"`
BadField string `mapstructure:",squash"`
}
type EmbeddedStruct struct {
NextField string `alias:"next"`
}
type PtrEmbeddedStruct struct {
OtherNextField string `alias:"othernext"`
}
func TestTranslationsForType(t *testing.T) {
to := reflect.TypeOf(translateExample{})
actual := translationsForType(to)
expected := map[string]string{
"first": "fielddefaultcanonical",
"second": "field_with_mapstruct_tag",
"third": "field_with_mapstruct_omit",
"forth": "fieldwithemptytag",
"first": "fielddefaultcanonical",
"second": "field_with_mapstruct_tag",
"third": "field_with_mapstruct_omit",
"forth": "fieldwithemptytag",
"next": "nextfield",
"othernext": "othernextfield",
}
require.Equal(t, expected, actual)
}
@ -389,3 +403,35 @@ service {
}
require.Equal(t, target, expected)
}
func TestFieldTags(t *testing.T) {
type testCase struct {
tags string
expected mapstructureFieldTags
}
fn := func(t *testing.T, tc testCase) {
tag := fmt.Sprintf(`mapstructure:"%v"`, tc.tags)
field := reflect.StructField{
Tag: reflect.StructTag(tag),
Name: "Original",
}
actual := fieldTags(field)
require.Equal(t, tc.expected, actual)
}
var testCases = []testCase{
{tags: "", expected: mapstructureFieldTags{name: "Original"}},
{tags: "just-a-name", expected: mapstructureFieldTags{name: "just-a-name"}},
{tags: "name,squash", expected: mapstructureFieldTags{name: "name", squash: true}},
{tags: ",squash", expected: mapstructureFieldTags{name: "Original", squash: true}},
{tags: ",omitempty,squash", expected: mapstructureFieldTags{name: "Original", squash: true}},
{tags: "named,omitempty,squash", expected: mapstructureFieldTags{name: "named", squash: true}},
}
for _, tc := range testCases {
t.Run(tc.tags, func(t *testing.T) {
fn(t, tc)
})
}
}