// Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: MPL-2.0 package framework import ( "encoding/base64" "encoding/json" "errors" "fmt" "net/http" "regexp" "strings" "github.com/hashicorp/errwrap" "github.com/hashicorp/go-secure-stdlib/parseutil" "github.com/hashicorp/go-secure-stdlib/strutil" "github.com/hashicorp/vault/sdk/helper/jsonutil" "github.com/mitchellh/mapstructure" ) // FieldData is the structure passed to the callback to handle a path // containing the populated parameters for fields. This should be used // instead of the raw (*vault.Request).Data to access data in a type-safe // way. type FieldData struct { Raw map[string]interface{} Schema map[string]*FieldSchema } // Validate cycles through raw data and validates conversions in // the schema, so we don't get an error/panic later when // trying to get data out. Data not in the schema is not // an error at this point, so we don't worry about it. func (d *FieldData) Validate() error { for field, value := range d.Raw { schema, ok := d.Schema[field] if !ok { continue } switch schema.Type { case TypeBool, TypeInt, TypeInt64, TypeMap, TypeDurationSecond, TypeSignedDurationSecond, TypeString, TypeLowerCaseString, TypeNameString, TypeSlice, TypeStringSlice, TypeCommaStringSlice, TypeKVPairs, TypeCommaIntSlice, TypeHeader, TypeFloat, TypeTime: _, _, err := d.getPrimitive(field, schema) if err != nil { return errwrap.Wrapf(fmt.Sprintf("error converting input %v for field %q: {{err}}", value, field), err) } default: return fmt.Errorf("unknown field type %q for field %q", schema.Type, field) } } return nil } // ValidateStrict cycles through raw data and validates conversions in the // schema. In addition to the checks done by Validate, this function ensures // that the raw data has all of the schema's required fields and does not // have any fields outside of the schema. It will return a non-nil error if: // // 1. a conversion (parsing of the field's value) fails // 2. a raw field does not exist in the schema (unless the schema is nil) // 3. a required schema field is missing from the raw data // // This function is currently used for validating response schemas in tests. func (d *FieldData) ValidateStrict() error { // the schema is nil, nothing to validate if d.Schema == nil { return nil } for field := range d.Raw { if _, _, err := d.GetOkErr(field); err != nil { return fmt.Errorf("field %q: %w", field, err) } } for field, schema := range d.Schema { if !schema.Required { continue } if _, ok := d.Raw[field]; !ok { return fmt.Errorf("missing required field %q", field) } } return nil } // Get gets the value for the given field. If the key is an invalid field, // FieldData will panic. If you want a safer version of this method, use // GetOk. If the field k is not set, the default value (if set) will be // returned, otherwise the zero value will be returned. func (d *FieldData) Get(k string) interface{} { schema, ok := d.Schema[k] if !ok { panic(fmt.Sprintf("field %s not in the schema", k)) } // If the value can't be decoded, use the zero or default value for the field // type value, ok := d.GetOk(k) if !ok || value == nil { value = schema.DefaultOrZero() } return value } // GetDefaultOrZero gets the default value set on the schema for the given // field. If there is no default value set, the zero value of the type // will be returned. func (d *FieldData) GetDefaultOrZero(k string) interface{} { schema, ok := d.Schema[k] if !ok { panic(fmt.Sprintf("field %s not in the schema", k)) } return schema.DefaultOrZero() } // GetFirst gets the value for the given field names, in order from first // to last. This can be useful for fields with a current name, and one or // more deprecated names. The second return value will be false if the keys // are invalid or the keys are not set at all. func (d *FieldData) GetFirst(k ...string) (interface{}, bool) { for _, v := range k { if result, ok := d.GetOk(v); ok { return result, ok } } return nil, false } // GetOk gets the value for the given field. The second return value will be // false if the key is invalid or the key is not set at all. If the field k is // set and the decoded value is nil, the default or zero value // will be returned instead. func (d *FieldData) GetOk(k string) (interface{}, bool) { schema, ok := d.Schema[k] if !ok { return nil, false } result, ok, err := d.GetOkErr(k) if err != nil { panic(fmt.Sprintf("error reading %s: %s", k, err)) } if ok && result == nil { result = schema.DefaultOrZero() } return result, ok } // GetOkErr is the most conservative of all the Get methods. It returns // whether key is set or not, but also an error value. The error value is // non-nil if the field doesn't exist or there was an error parsing the // field value. func (d *FieldData) GetOkErr(k string) (interface{}, bool, error) { schema, ok := d.Schema[k] if !ok { return nil, false, fmt.Errorf("unknown field: %q", k) } switch schema.Type { case TypeBool, TypeInt, TypeInt64, TypeMap, TypeDurationSecond, TypeSignedDurationSecond, TypeString, TypeLowerCaseString, TypeNameString, TypeSlice, TypeStringSlice, TypeCommaStringSlice, TypeKVPairs, TypeCommaIntSlice, TypeHeader, TypeFloat, TypeTime: return d.getPrimitive(k, schema) default: return nil, false, fmt.Errorf("unknown field type %q for field %q", schema.Type, k) } } func (d *FieldData) getPrimitive(k string, schema *FieldSchema) (interface{}, bool, error) { raw, ok := d.Raw[k] if !ok { return nil, false, nil } switch t := schema.Type; t { case TypeBool: var result bool if err := mapstructure.WeakDecode(raw, &result); err != nil { return nil, false, err } return result, true, nil case TypeInt: var result int if err := mapstructure.WeakDecode(raw, &result); err != nil { return nil, false, err } return result, true, nil case TypeInt64: var result int64 if err := mapstructure.WeakDecode(raw, &result); err != nil { return nil, false, err } return result, true, nil case TypeFloat: var result float64 if err := mapstructure.WeakDecode(raw, &result); err != nil { return nil, false, err } return result, true, nil case TypeString: var result string if err := mapstructure.WeakDecode(raw, &result); err != nil { return nil, false, err } return result, true, nil case TypeLowerCaseString: var result string if err := mapstructure.WeakDecode(raw, &result); err != nil { return nil, false, err } return strings.ToLower(result), true, nil case TypeNameString: var result string if err := mapstructure.WeakDecode(raw, &result); err != nil { return nil, false, err } matched, err := regexp.MatchString("^\\w(([\\w-.]+)?\\w)?$", result) if err != nil { return nil, false, err } if !matched { return nil, false, errors.New("field does not match the formatting rules") } return result, true, nil case TypeMap: var result map[string]interface{} if err := mapstructure.WeakDecode(raw, &result); err != nil { return nil, false, err } return result, true, nil case TypeDurationSecond, TypeSignedDurationSecond: var result int switch inp := raw.(type) { case nil: return nil, false, nil default: dur, err := parseutil.ParseDurationSecond(inp) if err != nil { return nil, false, err } result = int(dur.Seconds()) } if t == TypeDurationSecond && result < 0 { return nil, false, fmt.Errorf("cannot provide negative value '%d'", result) } return result, true, nil case TypeTime: switch inp := raw.(type) { case nil: // Handle nil interface{} as a non-error case return nil, false, nil default: time, err := parseutil.ParseAbsoluteTime(inp) if err != nil { return nil, false, err } return time.UTC(), true, nil } case TypeCommaIntSlice: var result []int jsonIn, ok := raw.(json.Number) if ok { raw = jsonIn.String() } config := &mapstructure.DecoderConfig{ Result: &result, WeaklyTypedInput: true, DecodeHook: mapstructure.StringToSliceHookFunc(","), } decoder, err := mapstructure.NewDecoder(config) if err != nil { return nil, false, err } if err := decoder.Decode(raw); err != nil { return nil, false, err } if len(result) == 0 { return make([]int, 0), true, nil } return result, true, nil case TypeSlice: var result []interface{} if err := mapstructure.WeakDecode(raw, &result); err != nil { return nil, false, err } if len(result) == 0 { return make([]interface{}, 0), true, nil } return result, true, nil case TypeStringSlice: rawString, ok := raw.(string) if ok && rawString == "" { return []string{}, true, nil } var result []string if err := mapstructure.WeakDecode(raw, &result); err != nil { return nil, false, err } if len(result) == 0 { return make([]string, 0), true, nil } return strutil.TrimStrings(result), true, nil case TypeCommaStringSlice: res, err := parseutil.ParseCommaStringSlice(raw) if err != nil { return nil, false, err } return res, true, nil case TypeKVPairs: // First try to parse this as a map var mapResult map[string]string if err := mapstructure.WeakDecode(raw, &mapResult); err == nil { return mapResult, true, nil } // If map parse fails, parse as a string list of = delimited pairs var listResult []string if err := mapstructure.WeakDecode(raw, &listResult); err != nil { return nil, false, err } result := make(map[string]string, len(listResult)) for _, keyPair := range listResult { keyPairSlice := strings.SplitN(keyPair, "=", 2) if len(keyPairSlice) != 2 || keyPairSlice[0] == "" { return nil, false, fmt.Errorf("invalid key pair %q", keyPair) } result[keyPairSlice[0]] = keyPairSlice[1] } return result, true, nil case TypeHeader: /* There are multiple ways a header could be provided: 1. As a map[string]interface{} that resolves to a map[string]string or map[string][]string, or a mix of both because that's permitted for headers. This mainly comes from the API. 2. As a string... a. That contains JSON that originally was JSON, but then was base64 encoded. b. That contains JSON, ex. `{"content-type":"text/json","accept":["encoding/json"]}`. This mainly comes from the API and is used to save space while sending in the header. 3. As an array of strings that contains comma-delimited key-value pairs associated via a colon, ex: `content-type:text/json`,`accept:encoding/json`. This mainly comes from the CLI. We go through these sequentially below. */ result := http.Header{} toHeader := func(resultMap map[string]interface{}) (http.Header, error) { header := http.Header{} for headerKey, headerValGroup := range resultMap { switch typedHeader := headerValGroup.(type) { case string: header.Add(headerKey, typedHeader) case []string: for _, headerVal := range typedHeader { header.Add(headerKey, headerVal) } case json.Number: header.Add(headerKey, typedHeader.String()) case []interface{}: for _, headerVal := range typedHeader { switch typedHeader := headerVal.(type) { case string: header.Add(headerKey, typedHeader) case json.Number: header.Add(headerKey, typedHeader.String()) default: // All header values should already be strings when they're being sent in. // Even numbers and booleans will be treated as strings. return nil, fmt.Errorf("received non-string value for header key:%s, val:%s", headerKey, headerValGroup) } } default: return nil, fmt.Errorf("unrecognized type for %s", headerValGroup) } } return header, nil } resultMap := make(map[string]interface{}) // 1. Are we getting a map from the API? if err := mapstructure.WeakDecode(raw, &resultMap); err == nil { result, err = toHeader(resultMap) if err != nil { return nil, false, err } return result, true, nil } // 2. Are we getting a JSON string? if headerStr, ok := raw.(string); ok { // a. Is it base64 encoded? headerBytes, err := base64.StdEncoding.DecodeString(headerStr) if err != nil { // b. It's not base64 encoded, it's a straight-out JSON string. headerBytes = []byte(headerStr) } if err := jsonutil.DecodeJSON(headerBytes, &resultMap); err != nil { return nil, false, err } result, err = toHeader(resultMap) if err != nil { return nil, false, err } return result, true, nil } // 3. Are we getting an array of fields like "content-type:encoding/json" from the CLI? var keyPairs []interface{} if err := mapstructure.WeakDecode(raw, &keyPairs); err == nil { for _, keyPairIfc := range keyPairs { keyPair, ok := keyPairIfc.(string) if !ok { return nil, false, fmt.Errorf("invalid key pair %q", keyPair) } keyPairSlice := strings.SplitN(keyPair, ":", 2) if len(keyPairSlice) != 2 || keyPairSlice[0] == "" { return nil, false, fmt.Errorf("invalid key pair %q", keyPair) } result.Add(keyPairSlice[0], keyPairSlice[1]) } return result, true, nil } return nil, false, fmt.Errorf("%s not provided an expected format", raw) default: panic(fmt.Sprintf("Unknown type: %s", schema.Type)) } }