From e4b9efd37f57d19e4a7ee1f11d88b04846e93796 Mon Sep 17 00:00:00 2001 From: Joel Thompson Date: Thu, 8 Aug 2019 14:53:06 -0400 Subject: [PATCH] logical/aws: Refactor role validation (#7276) This refactors role validation for the AWS secrets engine to be in a separate method. Previously, all validation was interspersed with the parsing of parameters when creating/updating a role, which led to a high degree of complexity. Now, all validation is centralized which makes it easier to understand and also easier to test (and so a number of test cases have been added). --- builtin/logical/aws/path_roles.go | 85 ++++++++++------- builtin/logical/aws/path_roles_test.go | 121 +++++++++++++++++++++++++ 2 files changed, 173 insertions(+), 33 deletions(-) diff --git a/builtin/logical/aws/path_roles.go b/builtin/logical/aws/path_roles.go index 8a3f728f9..d8e9664b6 100644 --- a/builtin/logical/aws/path_roles.go +++ b/builtin/logical/aws/path_roles.go @@ -11,6 +11,7 @@ import ( "time" "github.com/aws/aws-sdk-go/aws/arn" + "github.com/hashicorp/go-multierror" "github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/helper/consts" "github.com/hashicorp/vault/sdk/helper/strutil" @@ -212,12 +213,7 @@ func (b *backend) pathRolesWrite(ctx context.Context, req *logical.Request, d *f if legacyRole != "" { return logical.ErrorResponse("cannot supply deprecated role or policy parameters with an explicit credential_type"), nil } - credentialType := credentialTypeRaw.(string) - allowedCredentialTypes := []string{iamUserCred, assumedRoleCred, federationTokenCred} - if !strutil.StrListContains(allowedCredentialTypes, credentialType) { - return logical.ErrorResponse(fmt.Sprintf("unrecognized credential_type: %q, not one of %#v", credentialType, allowedCredentialTypes)), nil - } - roleEntry.CredentialTypes = []string{credentialType} + roleEntry.CredentialTypes = []string{credentialTypeRaw.(string)} } if roleArnsRaw, ok := d.GetOk("role_arns"); ok { @@ -252,9 +248,6 @@ func (b *backend) pathRolesWrite(ctx context.Context, req *logical.Request, d *f if legacyRole != "" { return logical.ErrorResponse("cannot supply deprecated role or policy parameters with default_sts_ttl"), nil } - if !strutil.StrListContains(roleEntry.CredentialTypes, assumedRoleCred) && !strutil.StrListContains(roleEntry.CredentialTypes, federationTokenCred) { - return logical.ErrorResponse(fmt.Sprintf("default_sts_ttl parameter only valid for %s and %s credential types", assumedRoleCred, federationTokenCred)), nil - } roleEntry.DefaultSTSTTL = time.Duration(defaultSTSTTLRaw.(int)) * time.Second } @@ -262,10 +255,6 @@ func (b *backend) pathRolesWrite(ctx context.Context, req *logical.Request, d *f if legacyRole != "" { return logical.ErrorResponse("cannot supply deprecated role or policy parameters with max_sts_ttl"), nil } - if !strutil.StrListContains(roleEntry.CredentialTypes, assumedRoleCred) && !strutil.StrListContains(roleEntry.CredentialTypes, federationTokenCred) { - return logical.ErrorResponse(fmt.Sprintf("max_sts_ttl parameter only valid for %s and %s credential types", assumedRoleCred, federationTokenCred)), nil - } - roleEntry.MaxSTSTTL = time.Duration(maxSTSTTLRaw.(int)) * time.Second } @@ -273,21 +262,10 @@ func (b *backend) pathRolesWrite(ctx context.Context, req *logical.Request, d *f if legacyRole != "" { return logical.ErrorResponse("cannot supply deprecated role or policy parameters with user_path"), nil } - if !strutil.StrListContains(roleEntry.CredentialTypes, iamUserCred) { - return logical.ErrorResponse(fmt.Sprintf("user_path parameter only valid for %s credential type", iamUserCred)), nil - } - if !userPathRegex.MatchString(userPathRaw.(string)) { - return logical.ErrorResponse(fmt.Sprintf("The specified value for user_path is invalid. It must match '%s' regexp", userPathRegex.String())), nil - } roleEntry.UserPath = userPathRaw.(string) } - if roleEntry.MaxSTSTTL > 0 && - roleEntry.DefaultSTSTTL > 0 && - roleEntry.DefaultSTSTTL > roleEntry.MaxSTSTTL { - return logical.ErrorResponse(`"default_sts_ttl" value must be less than or equal to "max_sts_ttl" value`), nil - } if legacyRole != "" { roleEntry = upgradeLegacyPolicyEntry(legacyRole) @@ -299,15 +277,9 @@ func (b *backend) pathRolesWrite(ctx context.Context, req *logical.Request, d *f roleEntry.ProhibitFlexibleCredPath = false } - if len(roleEntry.CredentialTypes) == 0 { - return logical.ErrorResponse("did not supply credential_type"), nil - } - - if len(roleEntry.RoleArns) > 0 && !strutil.StrListContains(roleEntry.CredentialTypes, assumedRoleCred) { - return logical.ErrorResponse(fmt.Sprintf("cannot supply role_arns when credential_type isn't %s", assumedRoleCred)), nil - } - if len(roleEntry.PolicyArns) > 0 && !strutil.StrListContains(roleEntry.CredentialTypes, iamUserCred) { - return logical.ErrorResponse(fmt.Sprintf("cannot supply policy_arns when credential_type isn't %s", iamUserCred)), nil + err = roleEntry.validate() + if err != nil { + return logical.ErrorResponse(fmt.Sprintf("error(s) validating supplied role data: %q", err)), nil } err = setAwsRole(ctx, req.Storage, roleName, roleEntry) @@ -490,6 +462,53 @@ func (r *awsRoleEntry) toResponseData() map[string]interface{} { return respData } +func (r *awsRoleEntry) validate() error { + var errors *multierror.Error + + if len(r.CredentialTypes) == 0 { + errors = multierror.Append(errors, fmt.Errorf("did not supply credential_type")) + } + + allowedCredentialTypes := []string{iamUserCred, assumedRoleCred, federationTokenCred} + for _, credType := range r.CredentialTypes { + if !strutil.StrListContains(allowedCredentialTypes, credType) { + errors = multierror.Append(errors, fmt.Errorf("unrecognized credential type: %s", credType)) + } + } + + if r.DefaultSTSTTL != 0 && !strutil.StrListContains(r.CredentialTypes, assumedRoleCred) && !strutil.StrListContains(r.CredentialTypes, federationTokenCred) { + errors = multierror.Append(errors, fmt.Errorf("default_sts_ttl parameter only valid for %s and %s credential types", assumedRoleCred, federationTokenCred)) + } + + if r.MaxSTSTTL != 0 && !strutil.StrListContains(r.CredentialTypes, assumedRoleCred) && !strutil.StrListContains(r.CredentialTypes, federationTokenCred) { + errors = multierror.Append(errors, fmt.Errorf("max_sts_ttl parameter only valid for %s and %s credential types", assumedRoleCred, federationTokenCred)) + } + + if r.MaxSTSTTL > 0 && + r.DefaultSTSTTL > 0 && + r.DefaultSTSTTL > r.MaxSTSTTL { + errors = multierror.Append(errors, fmt.Errorf(`"default_sts_ttl" value must be less than or equal to "max_sts_ttl" value`)) + } + + if r.UserPath != "" { + if !strutil.StrListContains(r.CredentialTypes, iamUserCred) { + errors = multierror.Append(errors, fmt.Errorf("user_path parameter only valid for %s credential type", iamUserCred)) + } + if !userPathRegex.MatchString(r.UserPath) { + errors = multierror.Append(errors, fmt.Errorf("The specified value for user_path is invalid. It must match '%s' regexp", userPathRegex.String())) + } + } + + if len(r.RoleArns) > 0 && !strutil.StrListContains(r.CredentialTypes, assumedRoleCred) { + errors = multierror.Append(errors, fmt.Errorf("cannot supply role_arns when credential_type isn't %s", assumedRoleCred)) + } + if len(r.PolicyArns) > 0 && !strutil.StrListContains(r.CredentialTypes, iamUserCred) { + errors = multierror.Append(errors, fmt.Errorf("cannot supply policy_arns when credential_type isn't %s", iamUserCred)) + } + + return errors.ErrorOrNil() +} + func compactJSON(input string) (string, error) { var compacted bytes.Buffer err := json.Compact(&compacted, []byte(input)) diff --git a/builtin/logical/aws/path_roles_test.go b/builtin/logical/aws/path_roles_test.go index c421da606..18a60c8f4 100644 --- a/builtin/logical/aws/path_roles_test.go +++ b/builtin/logical/aws/path_roles_test.go @@ -213,3 +213,124 @@ func TestUserPathValidity(t *testing.T) { }) } } + +func TestRoleEntryValidationCredTypes(t *testing.T) { + roleEntry := awsRoleEntry{ + CredentialTypes: []string{}, + PolicyArns: []string{"arn:aws:iam::aws:policy/AdministratorAccess"}, + } + if roleEntry.validate() == nil { + t.Errorf("bad: invalid roleEntry with no CredentialTypes %#v passed validation", roleEntry) + } + roleEntry.CredentialTypes = []string{"invalid_type"} + if roleEntry.validate() == nil { + t.Errorf("bad: invalid roleEntry with invalid CredentialTypes %#v passed validation", roleEntry) + } + roleEntry.CredentialTypes = []string{iamUserCred, "invalid_type"} + if roleEntry.validate() == nil { + t.Errorf("bad: invalid roleEntry with invalid CredentialTypes %#v passed validation", roleEntry) + } +} + +func TestRoleEntryValidationIamUserCred(t *testing.T) { + var allowAllPolicyDocument = `{"Version": "2012-10-17", "Statement": [{"Sid": "AllowAll", "Effect": "Allow", "Action": "*", "Resource": "*"}]}` + + roleEntry := awsRoleEntry{ + CredentialTypes: []string{iamUserCred}, + PolicyArns: []string{"arn:aws:iam::aws:policy/AdministratorAccess"}, + } + err := roleEntry.validate() + if err != nil { + t.Errorf("bad: valid roleEntry %#v failed validation: %v", roleEntry, err) + } + roleEntry.PolicyDocument = allowAllPolicyDocument + err = roleEntry.validate() + if err != nil { + t.Errorf("bad: valid roleEntry %#v failed validation: %v", roleEntry, err) + } + roleEntry.PolicyArns = []string{} + err = roleEntry.validate() + if err != nil { + t.Errorf("bad: valid roleEntry %#v failed validation: %v", roleEntry, err) + } + + roleEntry = awsRoleEntry{ + CredentialTypes: []string{iamUserCred}, + RoleArns: []string{"arn:aws:iam::123456789012:role/SomeRole"}, + } + if roleEntry.validate() == nil { + t.Errorf("bad: invalid roleEntry with invalid RoleArns parameter %#v passed validation", roleEntry) + } + + roleEntry = awsRoleEntry{ + CredentialTypes: []string{iamUserCred}, + PolicyArns: []string{"arn:aws:iam::aws:policy/AdministratorAccess"}, + DefaultSTSTTL: 1, + } + if roleEntry.validate() == nil { + t.Errorf("bad: invalid roleEntry with unrecognized DefaultSTSTTL %#v passed validation", roleEntry) + } + roleEntry.DefaultSTSTTL = 0 + roleEntry.MaxSTSTTL = 1 + if roleEntry.validate() == nil { + t.Errorf("bad: invalid roleEntry with unrecognized MaxSTSTTL %#v passed validation", roleEntry) + } +} + +func TestRoleEntryValidationAssumedRoleCred(t *testing.T) { + var allowAllPolicyDocument = `{"Version": "2012-10-17", "Statement": [{"Sid": "AllowAll", "Effect": "Allow", "Action": "*", "Resource": "*"}]}` + roleEntry := awsRoleEntry{ + CredentialTypes: []string{assumedRoleCred}, + RoleArns: []string{"arn:aws:iam::123456789012:role/SomeRole"}, + PolicyDocument: allowAllPolicyDocument, + DefaultSTSTTL: 2, + MaxSTSTTL: 3, + } + if err := roleEntry.validate(); err != nil { + t.Errorf("bad: valid roleEntry %#v failed validation: %v", roleEntry, err) + } + + roleEntry.PolicyArns = []string{"arn:aws:iam::aws:policy/AdministratorAccess"} + if roleEntry.validate() == nil { + t.Errorf("bad: invalid roleEntry with unrecognized PolicyArns %#v passed validation", roleEntry) + } + roleEntry.PolicyArns = []string{} + roleEntry.MaxSTSTTL = 1 + if roleEntry.validate() == nil { + t.Errorf("bad: invalid roleEntry with MaxSTSTTL < DefaultSTSTTL %#v passed validation", roleEntry) + } + roleEntry.MaxSTSTTL = 0 + roleEntry.UserPath = "/foobar/" + if roleEntry.validate() == nil { + t.Errorf("bad: invalid roleEntry with unrecognized UserPath %#v passed validation", roleEntry) + } +} + +func TestRoleEntryValidationFederationTokenCred(t *testing.T) { + var allowAllPolicyDocument = `{"Version": "2012-10-17", "Statement": [{"Sid": "AllowAll", "Effect": "Allow", "Action": "*", "Resource": "*"}]}` + roleEntry := awsRoleEntry{ + CredentialTypes: []string{federationTokenCred}, + PolicyDocument: allowAllPolicyDocument, + DefaultSTSTTL: 2, + MaxSTSTTL: 3, + } + if err := roleEntry.validate(); err != nil { + t.Errorf("bad: valid roleEntry %#v failed validation: %v", roleEntry, err) + } + + roleEntry.RoleArns = []string{"arn:aws:iam::123456789012:role/SomeRole"} + if roleEntry.validate() == nil { + t.Errorf("bad: invalid roleEntry with unrecognized RoleArns %#v passed validation", roleEntry) + } + roleEntry.RoleArns = []string{} + roleEntry.UserPath = "/foobar/" + if roleEntry.validate() == nil { + t.Errorf("bad: invalid roleEntry with unrecognized UserPath %#v passed validation", roleEntry) + } + + roleEntry.UserPath = "" + roleEntry.MaxSTSTTL = 1 + if roleEntry.validate() == nil { + t.Errorf("bad: invalid roleEntry with MaxSTSTTL < DefaultSTSTTL %#v passed validation", roleEntry) + } +}