logical/aws: Refactor role validation (#7276)

This refactors role validation for the AWS secrets engine to be in a
separate method. Previously, all validation was interspersed with the
parsing of parameters when creating/updating a role, which led to a high
degree of complexity. Now, all validation is centralized which makes it
easier to understand and also easier to test (and so a number of test
cases have been added).
This commit is contained in:
Joel Thompson 2019-08-08 14:53:06 -04:00 committed by Becca Petrin
parent baed23d816
commit e4b9efd37f
2 changed files with 173 additions and 33 deletions

View file

@ -11,6 +11,7 @@ import (
"time" "time"
"github.com/aws/aws-sdk-go/aws/arn" "github.com/aws/aws-sdk-go/aws/arn"
"github.com/hashicorp/go-multierror"
"github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/consts" "github.com/hashicorp/vault/sdk/helper/consts"
"github.com/hashicorp/vault/sdk/helper/strutil" "github.com/hashicorp/vault/sdk/helper/strutil"
@ -212,12 +213,7 @@ func (b *backend) pathRolesWrite(ctx context.Context, req *logical.Request, d *f
if legacyRole != "" { if legacyRole != "" {
return logical.ErrorResponse("cannot supply deprecated role or policy parameters with an explicit credential_type"), nil return logical.ErrorResponse("cannot supply deprecated role or policy parameters with an explicit credential_type"), nil
} }
credentialType := credentialTypeRaw.(string) roleEntry.CredentialTypes = []string{credentialTypeRaw.(string)}
allowedCredentialTypes := []string{iamUserCred, assumedRoleCred, federationTokenCred}
if !strutil.StrListContains(allowedCredentialTypes, credentialType) {
return logical.ErrorResponse(fmt.Sprintf("unrecognized credential_type: %q, not one of %#v", credentialType, allowedCredentialTypes)), nil
}
roleEntry.CredentialTypes = []string{credentialType}
} }
if roleArnsRaw, ok := d.GetOk("role_arns"); ok { if roleArnsRaw, ok := d.GetOk("role_arns"); ok {
@ -252,9 +248,6 @@ func (b *backend) pathRolesWrite(ctx context.Context, req *logical.Request, d *f
if legacyRole != "" { if legacyRole != "" {
return logical.ErrorResponse("cannot supply deprecated role or policy parameters with default_sts_ttl"), nil return logical.ErrorResponse("cannot supply deprecated role or policy parameters with default_sts_ttl"), nil
} }
if !strutil.StrListContains(roleEntry.CredentialTypes, assumedRoleCred) && !strutil.StrListContains(roleEntry.CredentialTypes, federationTokenCred) {
return logical.ErrorResponse(fmt.Sprintf("default_sts_ttl parameter only valid for %s and %s credential types", assumedRoleCred, federationTokenCred)), nil
}
roleEntry.DefaultSTSTTL = time.Duration(defaultSTSTTLRaw.(int)) * time.Second roleEntry.DefaultSTSTTL = time.Duration(defaultSTSTTLRaw.(int)) * time.Second
} }
@ -262,10 +255,6 @@ func (b *backend) pathRolesWrite(ctx context.Context, req *logical.Request, d *f
if legacyRole != "" { if legacyRole != "" {
return logical.ErrorResponse("cannot supply deprecated role or policy parameters with max_sts_ttl"), nil return logical.ErrorResponse("cannot supply deprecated role or policy parameters with max_sts_ttl"), nil
} }
if !strutil.StrListContains(roleEntry.CredentialTypes, assumedRoleCred) && !strutil.StrListContains(roleEntry.CredentialTypes, federationTokenCred) {
return logical.ErrorResponse(fmt.Sprintf("max_sts_ttl parameter only valid for %s and %s credential types", assumedRoleCred, federationTokenCred)), nil
}
roleEntry.MaxSTSTTL = time.Duration(maxSTSTTLRaw.(int)) * time.Second roleEntry.MaxSTSTTL = time.Duration(maxSTSTTLRaw.(int)) * time.Second
} }
@ -273,21 +262,10 @@ func (b *backend) pathRolesWrite(ctx context.Context, req *logical.Request, d *f
if legacyRole != "" { if legacyRole != "" {
return logical.ErrorResponse("cannot supply deprecated role or policy parameters with user_path"), nil return logical.ErrorResponse("cannot supply deprecated role or policy parameters with user_path"), nil
} }
if !strutil.StrListContains(roleEntry.CredentialTypes, iamUserCred) {
return logical.ErrorResponse(fmt.Sprintf("user_path parameter only valid for %s credential type", iamUserCred)), nil
}
if !userPathRegex.MatchString(userPathRaw.(string)) {
return logical.ErrorResponse(fmt.Sprintf("The specified value for user_path is invalid. It must match '%s' regexp", userPathRegex.String())), nil
}
roleEntry.UserPath = userPathRaw.(string) roleEntry.UserPath = userPathRaw.(string)
} }
if roleEntry.MaxSTSTTL > 0 &&
roleEntry.DefaultSTSTTL > 0 &&
roleEntry.DefaultSTSTTL > roleEntry.MaxSTSTTL {
return logical.ErrorResponse(`"default_sts_ttl" value must be less than or equal to "max_sts_ttl" value`), nil
}
if legacyRole != "" { if legacyRole != "" {
roleEntry = upgradeLegacyPolicyEntry(legacyRole) roleEntry = upgradeLegacyPolicyEntry(legacyRole)
@ -299,15 +277,9 @@ func (b *backend) pathRolesWrite(ctx context.Context, req *logical.Request, d *f
roleEntry.ProhibitFlexibleCredPath = false roleEntry.ProhibitFlexibleCredPath = false
} }
if len(roleEntry.CredentialTypes) == 0 { err = roleEntry.validate()
return logical.ErrorResponse("did not supply credential_type"), nil if err != nil {
} return logical.ErrorResponse(fmt.Sprintf("error(s) validating supplied role data: %q", err)), nil
if len(roleEntry.RoleArns) > 0 && !strutil.StrListContains(roleEntry.CredentialTypes, assumedRoleCred) {
return logical.ErrorResponse(fmt.Sprintf("cannot supply role_arns when credential_type isn't %s", assumedRoleCred)), nil
}
if len(roleEntry.PolicyArns) > 0 && !strutil.StrListContains(roleEntry.CredentialTypes, iamUserCred) {
return logical.ErrorResponse(fmt.Sprintf("cannot supply policy_arns when credential_type isn't %s", iamUserCred)), nil
} }
err = setAwsRole(ctx, req.Storage, roleName, roleEntry) err = setAwsRole(ctx, req.Storage, roleName, roleEntry)
@ -490,6 +462,53 @@ func (r *awsRoleEntry) toResponseData() map[string]interface{} {
return respData return respData
} }
func (r *awsRoleEntry) validate() error {
var errors *multierror.Error
if len(r.CredentialTypes) == 0 {
errors = multierror.Append(errors, fmt.Errorf("did not supply credential_type"))
}
allowedCredentialTypes := []string{iamUserCred, assumedRoleCred, federationTokenCred}
for _, credType := range r.CredentialTypes {
if !strutil.StrListContains(allowedCredentialTypes, credType) {
errors = multierror.Append(errors, fmt.Errorf("unrecognized credential type: %s", credType))
}
}
if r.DefaultSTSTTL != 0 && !strutil.StrListContains(r.CredentialTypes, assumedRoleCred) && !strutil.StrListContains(r.CredentialTypes, federationTokenCred) {
errors = multierror.Append(errors, fmt.Errorf("default_sts_ttl parameter only valid for %s and %s credential types", assumedRoleCred, federationTokenCred))
}
if r.MaxSTSTTL != 0 && !strutil.StrListContains(r.CredentialTypes, assumedRoleCred) && !strutil.StrListContains(r.CredentialTypes, federationTokenCred) {
errors = multierror.Append(errors, fmt.Errorf("max_sts_ttl parameter only valid for %s and %s credential types", assumedRoleCred, federationTokenCred))
}
if r.MaxSTSTTL > 0 &&
r.DefaultSTSTTL > 0 &&
r.DefaultSTSTTL > r.MaxSTSTTL {
errors = multierror.Append(errors, fmt.Errorf(`"default_sts_ttl" value must be less than or equal to "max_sts_ttl" value`))
}
if r.UserPath != "" {
if !strutil.StrListContains(r.CredentialTypes, iamUserCred) {
errors = multierror.Append(errors, fmt.Errorf("user_path parameter only valid for %s credential type", iamUserCred))
}
if !userPathRegex.MatchString(r.UserPath) {
errors = multierror.Append(errors, fmt.Errorf("The specified value for user_path is invalid. It must match '%s' regexp", userPathRegex.String()))
}
}
if len(r.RoleArns) > 0 && !strutil.StrListContains(r.CredentialTypes, assumedRoleCred) {
errors = multierror.Append(errors, fmt.Errorf("cannot supply role_arns when credential_type isn't %s", assumedRoleCred))
}
if len(r.PolicyArns) > 0 && !strutil.StrListContains(r.CredentialTypes, iamUserCred) {
errors = multierror.Append(errors, fmt.Errorf("cannot supply policy_arns when credential_type isn't %s", iamUserCred))
}
return errors.ErrorOrNil()
}
func compactJSON(input string) (string, error) { func compactJSON(input string) (string, error) {
var compacted bytes.Buffer var compacted bytes.Buffer
err := json.Compact(&compacted, []byte(input)) err := json.Compact(&compacted, []byte(input))

View file

@ -213,3 +213,124 @@ func TestUserPathValidity(t *testing.T) {
}) })
} }
} }
func TestRoleEntryValidationCredTypes(t *testing.T) {
roleEntry := awsRoleEntry{
CredentialTypes: []string{},
PolicyArns: []string{"arn:aws:iam::aws:policy/AdministratorAccess"},
}
if roleEntry.validate() == nil {
t.Errorf("bad: invalid roleEntry with no CredentialTypes %#v passed validation", roleEntry)
}
roleEntry.CredentialTypes = []string{"invalid_type"}
if roleEntry.validate() == nil {
t.Errorf("bad: invalid roleEntry with invalid CredentialTypes %#v passed validation", roleEntry)
}
roleEntry.CredentialTypes = []string{iamUserCred, "invalid_type"}
if roleEntry.validate() == nil {
t.Errorf("bad: invalid roleEntry with invalid CredentialTypes %#v passed validation", roleEntry)
}
}
func TestRoleEntryValidationIamUserCred(t *testing.T) {
var allowAllPolicyDocument = `{"Version": "2012-10-17", "Statement": [{"Sid": "AllowAll", "Effect": "Allow", "Action": "*", "Resource": "*"}]}`
roleEntry := awsRoleEntry{
CredentialTypes: []string{iamUserCred},
PolicyArns: []string{"arn:aws:iam::aws:policy/AdministratorAccess"},
}
err := roleEntry.validate()
if err != nil {
t.Errorf("bad: valid roleEntry %#v failed validation: %v", roleEntry, err)
}
roleEntry.PolicyDocument = allowAllPolicyDocument
err = roleEntry.validate()
if err != nil {
t.Errorf("bad: valid roleEntry %#v failed validation: %v", roleEntry, err)
}
roleEntry.PolicyArns = []string{}
err = roleEntry.validate()
if err != nil {
t.Errorf("bad: valid roleEntry %#v failed validation: %v", roleEntry, err)
}
roleEntry = awsRoleEntry{
CredentialTypes: []string{iamUserCred},
RoleArns: []string{"arn:aws:iam::123456789012:role/SomeRole"},
}
if roleEntry.validate() == nil {
t.Errorf("bad: invalid roleEntry with invalid RoleArns parameter %#v passed validation", roleEntry)
}
roleEntry = awsRoleEntry{
CredentialTypes: []string{iamUserCred},
PolicyArns: []string{"arn:aws:iam::aws:policy/AdministratorAccess"},
DefaultSTSTTL: 1,
}
if roleEntry.validate() == nil {
t.Errorf("bad: invalid roleEntry with unrecognized DefaultSTSTTL %#v passed validation", roleEntry)
}
roleEntry.DefaultSTSTTL = 0
roleEntry.MaxSTSTTL = 1
if roleEntry.validate() == nil {
t.Errorf("bad: invalid roleEntry with unrecognized MaxSTSTTL %#v passed validation", roleEntry)
}
}
func TestRoleEntryValidationAssumedRoleCred(t *testing.T) {
var allowAllPolicyDocument = `{"Version": "2012-10-17", "Statement": [{"Sid": "AllowAll", "Effect": "Allow", "Action": "*", "Resource": "*"}]}`
roleEntry := awsRoleEntry{
CredentialTypes: []string{assumedRoleCred},
RoleArns: []string{"arn:aws:iam::123456789012:role/SomeRole"},
PolicyDocument: allowAllPolicyDocument,
DefaultSTSTTL: 2,
MaxSTSTTL: 3,
}
if err := roleEntry.validate(); err != nil {
t.Errorf("bad: valid roleEntry %#v failed validation: %v", roleEntry, err)
}
roleEntry.PolicyArns = []string{"arn:aws:iam::aws:policy/AdministratorAccess"}
if roleEntry.validate() == nil {
t.Errorf("bad: invalid roleEntry with unrecognized PolicyArns %#v passed validation", roleEntry)
}
roleEntry.PolicyArns = []string{}
roleEntry.MaxSTSTTL = 1
if roleEntry.validate() == nil {
t.Errorf("bad: invalid roleEntry with MaxSTSTTL < DefaultSTSTTL %#v passed validation", roleEntry)
}
roleEntry.MaxSTSTTL = 0
roleEntry.UserPath = "/foobar/"
if roleEntry.validate() == nil {
t.Errorf("bad: invalid roleEntry with unrecognized UserPath %#v passed validation", roleEntry)
}
}
func TestRoleEntryValidationFederationTokenCred(t *testing.T) {
var allowAllPolicyDocument = `{"Version": "2012-10-17", "Statement": [{"Sid": "AllowAll", "Effect": "Allow", "Action": "*", "Resource": "*"}]}`
roleEntry := awsRoleEntry{
CredentialTypes: []string{federationTokenCred},
PolicyDocument: allowAllPolicyDocument,
DefaultSTSTTL: 2,
MaxSTSTTL: 3,
}
if err := roleEntry.validate(); err != nil {
t.Errorf("bad: valid roleEntry %#v failed validation: %v", roleEntry, err)
}
roleEntry.RoleArns = []string{"arn:aws:iam::123456789012:role/SomeRole"}
if roleEntry.validate() == nil {
t.Errorf("bad: invalid roleEntry with unrecognized RoleArns %#v passed validation", roleEntry)
}
roleEntry.RoleArns = []string{}
roleEntry.UserPath = "/foobar/"
if roleEntry.validate() == nil {
t.Errorf("bad: invalid roleEntry with unrecognized UserPath %#v passed validation", roleEntry)
}
roleEntry.UserPath = ""
roleEntry.MaxSTSTTL = 1
if roleEntry.validate() == nil {
t.Errorf("bad: invalid roleEntry with MaxSTSTTL < DefaultSTSTTL %#v passed validation", roleEntry)
}
}