move custom metadata validation logic to its own package (#16464)

* move custom metadata validation logic to its own package

* add comments

* add custom metadata Validate unit tests
This commit is contained in:
Chris Capurso 2022-07-28 10:40:38 -04:00 committed by GitHub
parent 488858e919
commit 013e1d12b1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 165 additions and 59 deletions

View File

@ -0,0 +1,78 @@
package custommetadata
import (
"fmt"
"github.com/hashicorp/go-multierror"
"github.com/hashicorp/go-secure-stdlib/strutil"
)
// CustomMetadata should be arbitrary user-provided key-value pairs meant to
// provide supplemental information about a resource.
type CustomMetadata map[string]string
// The following constants are used by Validate and are meant to be imposed
// broadly for consistency.
const (
maxKeys = 64
maxKeyLength = 128
maxValueLength = 512
validationErrorPrefix = "custom_metadata validation failed"
)
// Validate will perform input validation for custom metadata. If the key count
// exceeds maxKeys, the validation will be short-circuited to prevent
// unnecessary (and potentially costly) validation to be run. If the key count
// falls at or below maxKeys, multiple checks will be made per key and value.
// These checks include:
// - 0 < length of key <= maxKeyLength
// - 0 < length of value <= maxValueLength
// - keys and values cannot include unprintable characters
func Validate(cm CustomMetadata) error {
var errs *multierror.Error
if keyCount := len(cm); keyCount > maxKeys {
errs = multierror.Append(errs, fmt.Errorf("%s: payload must contain at most %d keys, provided %d",
validationErrorPrefix,
maxKeys,
keyCount))
return errs.ErrorOrNil()
}
// Perform validation on each key and value and return ALL errors
for key, value := range cm {
if keyLen := len(key); 0 == keyLen || keyLen > maxKeyLength {
errs = multierror.Append(errs, fmt.Errorf("%s: length of key %q is %d but must be 0 < len(key) <= %d",
validationErrorPrefix,
key,
keyLen,
maxKeyLength))
}
if valueLen := len(value); 0 == valueLen || valueLen > maxValueLength {
errs = multierror.Append(errs, fmt.Errorf("%s: length of value for key %q is %d but must be 0 < len(value) <= %d",
validationErrorPrefix,
key,
valueLen,
maxValueLength))
}
if !strutil.Printable(key) {
// Include unquoted format (%s) to also include the string without the unprintable
// characters visible to allow for easier debug and key identification
errs = multierror.Append(errs, fmt.Errorf("%s: key %q (%s) contains unprintable characters",
validationErrorPrefix,
key,
key))
}
if !strutil.Printable(value) {
errs = multierror.Append(errs, fmt.Errorf("%s: value for key %q contains unprintable characters",
validationErrorPrefix,
key))
}
}
return errs.ErrorOrNil()
}

View File

@ -0,0 +1,85 @@
package custommetadata
import (
"strconv"
"strings"
"testing"
)
func TestValidate(t *testing.T) {
cases := []struct {
name string
input CustomMetadata
shouldPass bool
}{
{
"valid",
CustomMetadata{
"foo": "abc",
"bar": "def",
"baz": "ghi",
},
true,
},
{
"too_many_keys",
func() CustomMetadata {
cm := make(CustomMetadata)
for i := 0; i < maxKeyLength+1; i++ {
s := strconv.Itoa(i)
cm[s] = s
}
return cm
}(),
false,
},
{
"key_too_long",
CustomMetadata{
strings.Repeat("a", maxKeyLength+1): "abc",
},
false,
},
{
"value_too_long",
CustomMetadata{
"foo": strings.Repeat("a", maxValueLength+1),
},
false,
},
{
"unprintable_key",
CustomMetadata{
"unprint\u200bable": "abc",
},
false,
},
{
"unprintable_value",
CustomMetadata{
"foo": "unprint\u200bable",
},
false,
},
}
for _, tc := range cases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
err := Validate(tc.input)
if tc.shouldPass && err != nil {
t.Fatalf("expected validation to pass, input: %#v, err: %v", tc.input, err)
}
if !tc.shouldPass && err == nil {
t.Fatalf("expected validation to fail, input: %#v, err: %v", tc.input, err)
}
})
}
}

View File

@ -6,23 +6,15 @@ import (
"strings"
"github.com/golang/protobuf/ptypes"
"github.com/hashicorp/go-multierror"
"github.com/hashicorp/go-secure-stdlib/strutil"
"github.com/hashicorp/vault/helper/identity"
"github.com/hashicorp/vault/helper/namespace"
"github.com/hashicorp/vault/helper/storagepacker"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/custommetadata"
"github.com/hashicorp/vault/sdk/logical"
)
const (
maxCustomMetadataKeys = 64
maxCustomMetadataKeyLength = 128
maxCustomMetadataValueLength = 512
customMetadataValidationErrorPrefix = "custom_metadata validation failed"
)
// aliasPaths returns the API endpoints to operate on aliases.
// Following are the paths supported:
// entity-alias - To register/modify an alias
@ -152,7 +144,7 @@ func (i *IdentityStore) handleAliasCreateUpdate() framework.OperationFunc {
// validate customMetadata if provided
if len(customMetadata) != 0 {
if err := validateCustomMetadata(customMetadata); err != nil {
if err := custommetadata.Validate(customMetadata); err != nil {
return nil, err
}
}
@ -468,55 +460,6 @@ func (i *IdentityStore) handleAliasUpdate(ctx context.Context, canonicalID, name
}, nil
}
func validateCustomMetadata(customMetadata map[string]string) error {
var errs *multierror.Error
if keyCount := len(customMetadata); keyCount > maxCustomMetadataKeys {
errs = multierror.Append(errs, fmt.Errorf("%s: payload must contain at most %d keys, provided %d",
customMetadataValidationErrorPrefix,
maxCustomMetadataKeys,
keyCount))
return errs.ErrorOrNil()
}
// Perform validation on each key and value and return ALL errors
for key, value := range customMetadata {
if keyLen := len(key); 0 == keyLen || keyLen > maxCustomMetadataKeyLength {
errs = multierror.Append(errs, fmt.Errorf("%s: length of key %q is %d but must be 0 < len(key) <= %d",
customMetadataValidationErrorPrefix,
key,
keyLen,
maxCustomMetadataKeyLength))
}
if valueLen := len(value); 0 == valueLen || valueLen > maxCustomMetadataValueLength {
errs = multierror.Append(errs, fmt.Errorf("%s: length of value for key %q is %d but must be 0 < len(value) <= %d",
customMetadataValidationErrorPrefix,
key,
valueLen,
maxCustomMetadataValueLength))
}
if !strutil.Printable(key) {
// Include unquoted format (%s) to also include the string without the unprintable
// characters visible to allow for easier debug and key identification
errs = multierror.Append(errs, fmt.Errorf("%s: key %q (%s) contains unprintable characters",
customMetadataValidationErrorPrefix,
key,
key))
}
if !strutil.Printable(value) {
errs = multierror.Append(errs, fmt.Errorf("%s: value for key %q contains unprintable characters",
customMetadataValidationErrorPrefix,
key))
}
}
return errs.ErrorOrNil()
}
// pathAliasIDRead returns the properties of an alias for a given
// alias ID
func (i *IdentityStore) pathAliasIDRead() framework.OperationFunc {