open-vault/sdk/framework/field_data.go
Anton Averchenkov 9696600e59
Add response schema validation methods & test helpers (#18635)
This pull request adds 3 functions (and corresponding tests):

`testhelpers/response_validation.go`:

  - `ValidateResponse`
  - `ValidateResponseData`
  
`field_data.go`:

  - `ValidateStrict` (has the "strict" validation logic)

The functions are primarily meant to be used in tests to ensure that the responses are consistent with the defined response schema. An example of how the functions can be used in tests can be found in #18636.

### Background

This PR is part of the ongoing work to add structured responses in Vault OpenAPI (VLT-234)
2023-01-13 14:55:56 -05:00

464 lines
13 KiB
Go

package framework
import (
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"net/http"
"regexp"
"strings"
"github.com/hashicorp/errwrap"
"github.com/hashicorp/go-secure-stdlib/parseutil"
"github.com/hashicorp/go-secure-stdlib/strutil"
"github.com/hashicorp/vault/sdk/helper/jsonutil"
"github.com/mitchellh/mapstructure"
)
// FieldData is the structure passed to the callback to handle a path
// containing the populated parameters for fields. This should be used
// instead of the raw (*vault.Request).Data to access data in a type-safe
// way.
type FieldData struct {
// Raw holds the undecoded values, keyed by field name. The accessor
// methods look a field up here and convert it according to Schema.
Raw map[string]interface{}
// Schema maps a field name to the schema that determines how the raw
// value for that field is decoded and validated.
Schema map[string]*FieldSchema
}
// Validate walks every raw field that has a schema entry and checks that its
// value converts to the schema's type, so that reading the field later does
// not fail or panic. Raw fields that are absent from the schema are skipped
// without error. It returns an error for the first field whose conversion
// fails or whose schema declares an unknown type.
func (d *FieldData) Validate() error {
	for name, rawValue := range d.Raw {
		fieldSchema, known := d.Schema[name]
		if !known {
			// Data outside the schema is not an error at this point.
			continue
		}

		// Reject unsupported types up front; supported types fall out of the
		// switch and are converted below.
		switch fieldSchema.Type {
		default:
			return fmt.Errorf("unknown field type %q for field %q", fieldSchema.Type, name)
		case TypeBool, TypeInt, TypeInt64, TypeMap, TypeDurationSecond, TypeSignedDurationSecond, TypeString,
			TypeLowerCaseString, TypeNameString, TypeSlice, TypeStringSlice, TypeCommaStringSlice,
			TypeKVPairs, TypeCommaIntSlice, TypeHeader, TypeFloat, TypeTime:
		}

		// Only the error matters here; the converted value is discarded.
		if _, _, err := d.getPrimitive(name, fieldSchema); err != nil {
			msg := fmt.Sprintf("error converting input %v for field %q: {{err}}", rawValue, name)
			return errwrap.Wrapf(msg, err)
		}
	}

	return nil
}
// ValidateStrict checks the raw data against the schema more strictly than
// Validate does: besides verifying that values convert, it requires every
// raw field to be declared in the schema and every required schema field to
// be present in the raw data. A non-nil error is returned when:
//
//  1. a conversion (parsing of the field's value) fails
//  2. a raw field does not exist in the schema (unless the schema is nil)
//  3. a required schema field is missing from the raw data
//
// This function is currently used for validating response schemas in tests.
func (d *FieldData) ValidateStrict() error {
	if d.Schema == nil {
		// Without a schema there is nothing to validate against.
		return nil
	}

	// GetOkErr fails both for fields unknown to the schema and for values
	// that cannot be parsed, which covers conditions 1 and 2.
	for name := range d.Raw {
		_, _, err := d.GetOkErr(name)
		if err != nil {
			return fmt.Errorf("field %q: %w", name, err)
		}
	}

	// Condition 3: every required schema field must appear in the raw data.
	for name, fieldSchema := range d.Schema {
		if _, present := d.Raw[name]; fieldSchema.Required && !present {
			return fmt.Errorf("missing required field %q", name)
		}
	}

	return nil
}
// Get returns the value for field k. It panics if k is not in the schema,
// and (through GetOk) if the raw value for k cannot be parsed; use GetOk or
// GetOkErr for safer variants. When k is not set in the raw data, or its
// decoded value is nil, the schema's default value (if set) or otherwise the
// zero value is returned.
func (d *FieldData) Get(k string) interface{} {
	fieldSchema, known := d.Schema[k]
	if !known {
		panic(fmt.Sprintf("field %s not in the schema", k))
	}

	// A field that is set and decodes to a non-nil value is returned as is.
	if value, set := d.GetOk(k); set && value != nil {
		return value
	}

	// Unset or nil: fall back to the schema's default or zero value.
	return fieldSchema.DefaultOrZero()
}
// GetDefaultOrZero returns the schema's DefaultOrZero value for field k,
// i.e. the default configured on the schema or, when none is set, the zero
// value of the field's type. The raw data is not consulted. It panics if k
// is not in the schema.
func (d *FieldData) GetDefaultOrZero(k string) interface{} {
	if fieldSchema, known := d.Schema[k]; known {
		return fieldSchema.DefaultOrZero()
	}
	panic(fmt.Sprintf("field %s not in the schema", k))
}
// GetFirst tries the given field names in order and returns the value of the
// first one that GetOk reports as set. This is useful for a field that has a
// current name plus one or more deprecated names. The second return value is
// false when none of the keys is valid or none of them is set.
func (d *FieldData) GetFirst(k ...string) (interface{}, bool) {
	for i := range k {
		value, found := d.GetOk(k[i])
		if !found {
			continue
		}
		return value, true
	}
	return nil, false
}
// GetOk returns the value for field k and whether it is set. The second
// return value is false if k is not in the schema or is not set in the raw
// data. If k is set but its decoded value is nil, the schema's default or
// zero value is returned instead. GetOk panics if the raw value cannot be
// parsed; use GetOkErr to receive that failure as an error.
func (d *FieldData) GetOk(k string) (interface{}, bool) {
	fieldSchema, known := d.Schema[k]
	if !known {
		return nil, false
	}

	value, set, err := d.GetOkErr(k)
	switch {
	case err != nil:
		panic(fmt.Sprintf("error reading %s: %s", k, err))
	case set && value == nil:
		// Set but decoded to nil: substitute the default or zero value.
		return fieldSchema.DefaultOrZero(), true
	}

	return value, set
}
// GetOkErr is the most conservative of the Get methods: it never panics for
// an unknown field or an unparsable value. It returns the decoded value,
// whether the field is set, and an error. The error is non-nil if k is not
// in the schema, if the schema declares an unknown field type, or if the raw
// value could not be parsed.
func (d *FieldData) GetOkErr(k string) (interface{}, bool, error) {
	fieldSchema, known := d.Schema[k]
	if !known {
		return nil, false, fmt.Errorf("unknown field: %q", k)
	}

	// Only the types getPrimitive knows how to decode are let through.
	switch fieldSchema.Type {
	default:
		return nil, false,
			fmt.Errorf("unknown field type %q for field %q", fieldSchema.Type, k)
	case TypeBool, TypeInt, TypeInt64, TypeMap, TypeDurationSecond, TypeSignedDurationSecond, TypeString,
		TypeLowerCaseString, TypeNameString, TypeSlice, TypeStringSlice, TypeCommaStringSlice,
		TypeKVPairs, TypeCommaIntSlice, TypeHeader, TypeFloat, TypeTime:
	}

	return d.getPrimitive(k, fieldSchema)
}
// fieldDataNameStringRe is the format a TypeNameString value must match: a
// word character at each end, with word characters, '-' or '.' in between.
// It is compiled once here instead of on every getPrimitive call.
var fieldDataNameStringRe = regexp.MustCompile("^\\w(([\\w-.]+)?\\w)?$")

// getPrimitive decodes the raw value stored under key k according to
// schema.Type.
//
// It returns (value, true, nil) when the field is present and decodes
// successfully, (nil, false, nil) when the field is absent from d.Raw (or,
// for duration and time types, present but nil), and (nil, false, err) when
// the raw value cannot be converted to the schema's type. It panics if
// schema.Type is not one of the types handled below; callers are expected to
// have filtered the type beforehand.
func (d *FieldData) getPrimitive(k string, schema *FieldSchema) (interface{}, bool, error) {
	raw, ok := d.Raw[k]
	if !ok {
		return nil, false, nil
	}

	switch t := schema.Type; t {
	case TypeBool:
		var result bool
		if err := mapstructure.WeakDecode(raw, &result); err != nil {
			return nil, false, err
		}
		return result, true, nil

	case TypeInt:
		var result int
		if err := mapstructure.WeakDecode(raw, &result); err != nil {
			return nil, false, err
		}
		return result, true, nil

	case TypeInt64:
		var result int64
		if err := mapstructure.WeakDecode(raw, &result); err != nil {
			return nil, false, err
		}
		return result, true, nil

	case TypeFloat:
		var result float64
		if err := mapstructure.WeakDecode(raw, &result); err != nil {
			return nil, false, err
		}
		return result, true, nil

	case TypeString:
		var result string
		if err := mapstructure.WeakDecode(raw, &result); err != nil {
			return nil, false, err
		}
		return result, true, nil

	case TypeLowerCaseString:
		var result string
		if err := mapstructure.WeakDecode(raw, &result); err != nil {
			return nil, false, err
		}
		return strings.ToLower(result), true, nil

	case TypeNameString:
		var result string
		if err := mapstructure.WeakDecode(raw, &result); err != nil {
			return nil, false, err
		}
		if !fieldDataNameStringRe.MatchString(result) {
			return nil, false, errors.New("field does not match the formatting rules")
		}
		return result, true, nil

	case TypeMap:
		var result map[string]interface{}
		if err := mapstructure.WeakDecode(raw, &result); err != nil {
			return nil, false, err
		}
		return result, true, nil

	case TypeDurationSecond, TypeSignedDurationSecond:
		// A nil value is treated as "not set" rather than as an error.
		if raw == nil {
			return nil, false, nil
		}
		dur, err := parseutil.ParseDurationSecond(raw)
		if err != nil {
			return nil, false, err
		}
		result := int(dur.Seconds())
		// Only the signed variant may carry a negative duration.
		if t == TypeDurationSecond && result < 0 {
			return nil, false, fmt.Errorf("cannot provide negative value '%d'", result)
		}
		return result, true, nil

	case TypeTime:
		// A nil value is treated as "not set" rather than as an error.
		if raw == nil {
			return nil, false, nil
		}
		parsed, err := parseutil.ParseAbsoluteTime(raw)
		if err != nil {
			return nil, false, err
		}
		return parsed.UTC(), true, nil

	case TypeCommaIntSlice:
		var result []int

		// A lone json.Number is converted to its string form so that the
		// string-to-slice hook below can handle it.
		if jsonIn, ok := raw.(json.Number); ok {
			raw = jsonIn.String()
		}

		config := &mapstructure.DecoderConfig{
			Result:           &result,
			WeaklyTypedInput: true,
			DecodeHook:       mapstructure.StringToSliceHookFunc(","),
		}
		decoder, err := mapstructure.NewDecoder(config)
		if err != nil {
			return nil, false, err
		}
		if err := decoder.Decode(raw); err != nil {
			return nil, false, err
		}
		// Always hand back a non-nil slice, even when it is empty.
		if len(result) == 0 {
			return make([]int, 0), true, nil
		}
		return result, true, nil

	case TypeSlice:
		var result []interface{}
		if err := mapstructure.WeakDecode(raw, &result); err != nil {
			return nil, false, err
		}
		// Always hand back a non-nil slice, even when it is empty.
		if len(result) == 0 {
			return make([]interface{}, 0), true, nil
		}
		return result, true, nil

	case TypeStringSlice:
		// An empty string means an empty list, not a list holding "".
		if rawString, ok := raw.(string); ok && rawString == "" {
			return []string{}, true, nil
		}

		var result []string
		if err := mapstructure.WeakDecode(raw, &result); err != nil {
			return nil, false, err
		}
		// Always hand back a non-nil slice, even when it is empty.
		if len(result) == 0 {
			return make([]string, 0), true, nil
		}
		return strutil.TrimStrings(result), true, nil

	case TypeCommaStringSlice:
		res, err := parseutil.ParseCommaStringSlice(raw)
		if err != nil {
			return nil, false, err
		}
		return res, true, nil

	case TypeKVPairs:
		// First try to parse this as a map.
		var mapResult map[string]string
		if err := mapstructure.WeakDecode(raw, &mapResult); err == nil {
			return mapResult, true, nil
		}

		// If map parse fails, parse as a string list of = delimited pairs.
		var listResult []string
		if err := mapstructure.WeakDecode(raw, &listResult); err != nil {
			return nil, false, err
		}

		result := make(map[string]string, len(listResult))
		for _, keyPair := range listResult {
			keyPairSlice := strings.SplitN(keyPair, "=", 2)
			if len(keyPairSlice) != 2 || keyPairSlice[0] == "" {
				return nil, false, fmt.Errorf("invalid key pair %q", keyPair)
			}
			result[keyPairSlice[0]] = keyPairSlice[1]
		}
		return result, true, nil

	case TypeHeader:
		/*
			There are multiple ways a header could be provided:
			1.	As a map[string]interface{} that resolves to a map[string]string or map[string][]string, or a mix of both
				because that's permitted for headers.
				This mainly comes from the API.
			2.	As a string...
				a. That contains JSON that originally was JSON, but then was base64 encoded.
				b. That contains JSON, ex. `{"content-type":"text/json","accept":["encoding/json"]}`.
				This mainly comes from the API and is used to save space while sending in the header.
			3.	As an array of strings that contains comma-delimited key-value pairs associated via a colon,
				ex: `content-type:text/json`,`accept:encoding/json`.
				This mainly comes from the CLI.
			We go through these sequentially below.
		*/

		// toHeader converts a decoded map into an http.Header, accepting
		// string, json.Number, []string and []interface{} (of string or
		// json.Number) values and rejecting everything else.
		toHeader := func(resultMap map[string]interface{}) (http.Header, error) {
			header := http.Header{}
			for headerKey, headerValGroup := range resultMap {
				switch typedHeader := headerValGroup.(type) {
				case string:
					header.Add(headerKey, typedHeader)
				case []string:
					for _, headerVal := range typedHeader {
						header.Add(headerKey, headerVal)
					}
				case json.Number:
					header.Add(headerKey, typedHeader.String())
				case []interface{}:
					for _, headerVal := range typedHeader {
						switch typedVal := headerVal.(type) {
						case string:
							header.Add(headerKey, typedVal)
						case json.Number:
							header.Add(headerKey, typedVal.String())
						default:
							// All header values should already be strings when they're being sent in.
							// Even numbers and booleans will be treated as strings.
							// %v is used because the group may hold non-string values.
							return nil, fmt.Errorf("received non-string value for header key:%s, val:%v", headerKey, headerValGroup)
						}
					}
				default:
					// %v is used because the value is, by definition, not a string here.
					return nil, fmt.Errorf("unrecognized type for %v", headerValGroup)
				}
			}
			return header, nil
		}

		resultMap := make(map[string]interface{})

		// 1. Are we getting a map from the API?
		if err := mapstructure.WeakDecode(raw, &resultMap); err == nil {
			header, err := toHeader(resultMap)
			if err != nil {
				return nil, false, err
			}
			return header, true, nil
		}

		// 2. Are we getting a JSON string?
		if headerStr, ok := raw.(string); ok {
			// a. Is it base64 encoded?
			headerBytes, err := base64.StdEncoding.DecodeString(headerStr)
			if err != nil {
				// b. It's not base64 encoded, it's a straight-out JSON string.
				headerBytes = []byte(headerStr)
			}
			if err := jsonutil.DecodeJSON(headerBytes, &resultMap); err != nil {
				return nil, false, err
			}
			header, err := toHeader(resultMap)
			if err != nil {
				return nil, false, err
			}
			return header, true, nil
		}

		// 3. Are we getting an array of fields like "content-type:encoding/json" from the CLI?
		var keyPairs []interface{}
		if err := mapstructure.WeakDecode(raw, &keyPairs); err == nil {
			result := http.Header{}
			for _, keyPairIfc := range keyPairs {
				keyPair, ok := keyPairIfc.(string)
				if !ok {
					// Report the offending element itself; keyPair is just
					// the empty string when the assertion fails.
					return nil, false, fmt.Errorf("invalid key pair %q", fmt.Sprint(keyPairIfc))
				}
				keyPairSlice := strings.SplitN(keyPair, ":", 2)
				if len(keyPairSlice) != 2 || keyPairSlice[0] == "" {
					return nil, false, fmt.Errorf("invalid key pair %q", keyPair)
				}
				result.Add(keyPairSlice[0], keyPairSlice[1])
			}
			return result, true, nil
		}

		// raw is an arbitrary value here, so format it with %v.
		return nil, false, fmt.Errorf("%v not provided an expected format", raw)

	default:
		panic(fmt.Sprintf("Unknown type: %s", schema.Type))
	}
}