diff --git a/builtin/credential/aws/backend.go b/builtin/credential/aws/backend.go index 5e94db7b9..4e5cd2f62 100644 --- a/builtin/credential/aws/backend.go +++ b/builtin/credential/aws/backend.go @@ -312,7 +312,7 @@ func (b *backend) resolveArnToRealUniqueId(ctx context.Context, s logical.Storag switch entity.Type { case "user": - userInfo, err := iamClient.GetUser(&iam.GetUserInput{UserName: &entity.FriendlyName}) + userInfo, err := iamClient.GetUserWithContext(ctx, &iam.GetUserInput{UserName: &entity.FriendlyName}) if err != nil { return "", awsutil.AppendAWSError(err) } @@ -321,7 +321,7 @@ func (b *backend) resolveArnToRealUniqueId(ctx context.Context, s logical.Storag } return *userInfo.User.UserId, nil case "role": - roleInfo, err := iamClient.GetRole(&iam.GetRoleInput{RoleName: &entity.FriendlyName}) + roleInfo, err := iamClient.GetRoleWithContext(ctx, &iam.GetRoleInput{RoleName: &entity.FriendlyName}) if err != nil { return "", awsutil.AppendAWSError(err) } @@ -330,7 +330,7 @@ func (b *backend) resolveArnToRealUniqueId(ctx context.Context, s logical.Storag } return *roleInfo.Role.RoleId, nil case "instance-profile": - profileInfo, err := iamClient.GetInstanceProfile(&iam.GetInstanceProfileInput{InstanceProfileName: &entity.FriendlyName}) + profileInfo, err := iamClient.GetInstanceProfileWithContext(ctx, &iam.GetInstanceProfileInput{InstanceProfileName: &entity.FriendlyName}) if err != nil { return "", awsutil.AppendAWSError(err) } diff --git a/builtin/credential/aws/client.go b/builtin/credential/aws/client.go index 079eabbe8..314c97ec3 100644 --- a/builtin/credential/aws/client.go +++ b/builtin/credential/aws/client.go @@ -122,7 +122,7 @@ func (b *backend) getClientConfig(ctx context.Context, s logical.Storage, region return nil, fmt.Errorf("could not obtain sts client: %w", err) } inputParams := &sts.GetCallerIdentityInput{} - identity, err := client.GetCallerIdentity(inputParams) + identity, err := client.GetCallerIdentityWithContext(ctx, inputParams) if err != nil { return nil, fmt.Errorf("unable to fetch current caller: %w", err) } diff --git a/builtin/credential/aws/path_config_rotate_root.go b/builtin/credential/aws/path_config_rotate_root.go index 0a28b627b..6c517b9c4 100644 --- a/builtin/credential/aws/path_config_rotate_root.go +++ b/builtin/credential/aws/path_config_rotate_root.go @@ -100,7 +100,7 @@ func (b *backend) pathConfigRotateRootUpdate(ctx context.Context, req *logical.R // Get the current user's name since it's required to create an access key. // Empty input means get the current user. var getUserInput iam.GetUserInput - getUserRes, err := iamClient.GetUser(&getUserInput) + getUserRes, err := iamClient.GetUserWithContext(ctx, &getUserInput) if err != nil { return nil, fmt.Errorf("error calling GetUser: %w", err) } @@ -118,7 +118,7 @@ func (b *backend) pathConfigRotateRootUpdate(ctx context.Context, req *logical.R createAccessKeyInput := iam.CreateAccessKeyInput{ UserName: getUserRes.User.UserName, } - createAccessKeyRes, err := iamClient.CreateAccessKey(&createAccessKeyInput) + createAccessKeyRes, err := iamClient.CreateAccessKeyWithContext(ctx, &createAccessKeyInput) if err != nil { return nil, fmt.Errorf("error calling CreateAccessKey: %w", err) } @@ -142,7 +142,7 @@ func (b *backend) pathConfigRotateRootUpdate(ctx context.Context, req *logical.R AccessKeyId: createAccessKeyRes.AccessKey.AccessKeyId, UserName: getUserRes.User.UserName, } - if _, err := iamClient.DeleteAccessKey(&deleteAccessKeyInput); err != nil { + if _, err := iamClient.DeleteAccessKeyWithContext(ctx, &deleteAccessKeyInput); err != nil { // Include this error in the errs returned by this method. errs = multierror.Append(errs, fmt.Errorf("error deleting newly created but unstored access key ID %s: %s", *createAccessKeyRes.AccessKey.AccessKeyId, err)) } @@ -179,7 +179,7 @@ func (b *backend) pathConfigRotateRootUpdate(ctx context.Context, req *logical.R AccessKeyId: aws.String(oldAccessKey), UserName: getUserRes.User.UserName, } - if _, err = iamClient.DeleteAccessKey(&deleteAccessKeyInput); err != nil { + if _, err = iamClient.DeleteAccessKeyWithContext(ctx, &deleteAccessKeyInput); err != nil { errs = multierror.Append(errs, fmt.Errorf("error deleting old access key ID %s: %w", oldAccessKey, err)) return nil, errs } diff --git a/builtin/credential/aws/path_config_rotate_root_test.go b/builtin/credential/aws/path_config_rotate_root_test.go index 21f7f0fbb..3fe5b29c0 100644 --- a/builtin/credential/aws/path_config_rotate_root_test.go +++ b/builtin/credential/aws/path_config_rotate_root_test.go @@ -8,6 +8,7 @@ import ( "testing" "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/iam" "github.com/aws/aws-sdk-go/service/iam/iamiface" @@ -15,9 +16,23 @@ import ( "github.com/hashicorp/vault/sdk/logical" ) +type mockIAMClient awsutil.MockIAM + +func (m *mockIAMClient) GetUserWithContext(_ aws.Context, input *iam.GetUserInput, _ ...request.Option) (*iam.GetUserOutput, error) { + return (*awsutil.MockIAM)(m).GetUser(input) +} + +func (m *mockIAMClient) CreateAccessKeyWithContext(_ aws.Context, input *iam.CreateAccessKeyInput, _ ...request.Option) (*iam.CreateAccessKeyOutput, error) { + return (*awsutil.MockIAM)(m).CreateAccessKey(input) +} + +func (m *mockIAMClient) DeleteAccessKeyWithContext(_ aws.Context, input *iam.DeleteAccessKeyInput, _ ...request.Option) (*iam.DeleteAccessKeyOutput, error) { + return (*awsutil.MockIAM)(m).DeleteAccessKey(input) +} + func TestPathConfigRotateRoot(t *testing.T) { getIAMClient = func(sess *session.Session) iamiface.IAMAPI { - return &awsutil.MockIAM{ + return &mockIAMClient{ CreateAccessKeyOutput: &iam.CreateAccessKeyOutput{ AccessKey: &iam.AccessKey{ AccessKeyId: aws.String("fizz2"), diff --git a/builtin/credential/aws/path_login.go b/builtin/credential/aws/path_login.go index 320230534..f85c5c6d8 100644 --- a/builtin/credential/aws/path_login.go +++ b/builtin/credential/aws/path_login.go @@ -106,8 +106,8 @@ This must match the request body included in the signature.`, "iam_request_headers": { Type: framework.TypeHeader, Description: `Key/value pairs of headers for use in the -sts:GetCallerIdentity HTTP requests headers when auth_type is iam. Can be either -a Base64-encoded, JSON-serialized string, or a JSON object of key/value pairs. +sts:GetCallerIdentity HTTP requests headers when auth_type is iam. Can be either +a Base64-encoded, JSON-serialized string, or a JSON object of key/value pairs. This must at a minimum include the headers over which AWS has included a signature.`, }, "identity": { @@ -340,7 +340,7 @@ func (b *backend) pathLoginResolveRoleIam(ctx context.Context, req *logical.Requ // instanceIamRoleARN fetches the IAM role ARN associated with the given // instance profile name -func (b *backend) instanceIamRoleARN(iamClient *iam.IAM, instanceProfileName string) (string, error) { +func (b *backend) instanceIamRoleARN(ctx context.Context, iamClient *iam.IAM, instanceProfileName string) (string, error) { if iamClient == nil { return "", fmt.Errorf("nil iamClient") } @@ -348,7 +348,7 @@ func (b *backend) instanceIamRoleARN(iamClient *iam.IAM, instanceProfileName str return "", fmt.Errorf("missing instance profile name") } - profile, err := iamClient.GetInstanceProfile(&iam.GetInstanceProfileInput{ + profile, err := iamClient.GetInstanceProfileWithContext(ctx, &iam.GetInstanceProfileInput{ InstanceProfileName: aws.String(instanceProfileName), }) if err != nil { @@ -382,7 +382,7 @@ func (b *backend) validateInstance(ctx context.Context, s logical.Storage, insta return nil, err } - status, err := ec2Client.DescribeInstances(&ec2.DescribeInstancesInput{ + status, err := ec2Client.DescribeInstancesWithContext(ctx, &ec2.DescribeInstancesInput{ InstanceIds: []*string{ aws.String(instanceID), }, @@ -724,7 +724,7 @@ func (b *backend) verifyInstanceMeetsRoleRequirements(ctx context.Context, } else if iamClient == nil { return nil, fmt.Errorf("received a nil iamClient") } - iamRoleARN, err := b.instanceIamRoleARN(iamClient, iamInstanceProfileEntity.FriendlyName) + iamRoleARN, err := b.instanceIamRoleARN(ctx, iamClient, iamInstanceProfileEntity.FriendlyName) if err != nil { return nil, fmt.Errorf("IAM role ARN could not be fetched: %w", err) } @@ -1835,7 +1835,7 @@ func (b *backend) fullArn(ctx context.Context, e *iamEntity, s logical.Storage) input := iam.GetUserInput{ UserName: aws.String(e.FriendlyName), } - resp, err := client.GetUser(&input) + resp, err := client.GetUserWithContext(ctx, &input) if err != nil { return "", fmt.Errorf("error fetching user %q: %w", e.FriendlyName, err) } @@ -1849,7 +1849,7 @@ func (b *backend) fullArn(ctx context.Context, e *iamEntity, s logical.Storage) input := iam.GetRoleInput{ RoleName: aws.String(e.FriendlyName), } - resp, err := client.GetRole(&input) + resp, err := client.GetRoleWithContext(ctx, &input) if err != nil { return "", fmt.Errorf("error fetching role %q: %w", e.FriendlyName, err) } diff --git a/builtin/logical/aws/backend_test.go b/builtin/logical/aws/backend_test.go index 5e59fdf2e..706a705d3 100644 --- a/builtin/logical/aws/backend_test.go +++ b/builtin/logical/aws/backend_test.go @@ -19,6 +19,7 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/dynamodb" "github.com/aws/aws-sdk-go/service/ec2" @@ -39,7 +40,7 @@ type mockIAMClient struct { iamiface.IAMAPI } -func (m *mockIAMClient) CreateUser(input *iam.CreateUserInput) (*iam.CreateUserOutput, error) { +func (m *mockIAMClient) CreateUserWithContext(_ aws.Context, input *iam.CreateUserInput, _ ...request.Option) (*iam.CreateUserOutput, error) { return nil, awserr.New("Throttling", "", nil) } diff --git a/builtin/logical/aws/iam_policies.go b/builtin/logical/aws/iam_policies.go index 27b6f1822..002a7389e 100644 --- a/builtin/logical/aws/iam_policies.go +++ b/builtin/logical/aws/iam_policies.go @@ -73,7 +73,7 @@ func (b *backend) getGroupPolicies(ctx context.Context, s logical.Storage, iamGr for _, g := range iamGroups { // Collect managed policy ARNs from the IAM Group - agp, err = iamClient.ListAttachedGroupPolicies(&iam.ListAttachedGroupPoliciesInput{ + agp, err = iamClient.ListAttachedGroupPoliciesWithContext(ctx, &iam.ListAttachedGroupPoliciesInput{ GroupName: aws.String(g), }) if err != nil { @@ -84,14 +84,14 @@ func (b *backend) getGroupPolicies(ctx context.Context, s logical.Storage, iamGr } // Collect inline policy names from the IAM Group - inlinePolicies, err = iamClient.ListGroupPolicies(&iam.ListGroupPoliciesInput{ + inlinePolicies, err = iamClient.ListGroupPoliciesWithContext(ctx, &iam.ListGroupPoliciesInput{ GroupName: aws.String(g), }) if err != nil { return nil, nil, err } for _, iP := range inlinePolicies.PolicyNames { - inlinePolicyDoc, err = iamClient.GetGroupPolicy(&iam.GetGroupPolicyInput{ + inlinePolicyDoc, err = iamClient.GetGroupPolicyWithContext(ctx, &iam.GetGroupPolicyInput{ GroupName: &g, PolicyName: iP, }) diff --git a/builtin/logical/aws/iam_policies_test.go b/builtin/logical/aws/iam_policies_test.go index 584018630..5e2de534b 100644 --- a/builtin/logical/aws/iam_policies_test.go +++ b/builtin/logical/aws/iam_policies_test.go @@ -8,6 +8,7 @@ import ( "testing" "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/service/iam" "github.com/aws/aws-sdk-go/service/iam/iamiface" "github.com/hashicorp/vault/sdk/logical" @@ -29,15 +30,15 @@ type mockGroupIAMClient struct { GetGroupPolicyResp iam.GetGroupPolicyOutput } -func (m mockGroupIAMClient) ListAttachedGroupPolicies(in *iam.ListAttachedGroupPoliciesInput) (*iam.ListAttachedGroupPoliciesOutput, error) { +func (m mockGroupIAMClient) ListAttachedGroupPoliciesWithContext(_ aws.Context, in *iam.ListAttachedGroupPoliciesInput, _ ...request.Option) (*iam.ListAttachedGroupPoliciesOutput, error) { return &m.ListAttachedGroupPoliciesResp, nil } -func (m mockGroupIAMClient) ListGroupPolicies(in *iam.ListGroupPoliciesInput) (*iam.ListGroupPoliciesOutput, error) { +func (m mockGroupIAMClient) ListGroupPoliciesWithContext(_ aws.Context, in *iam.ListGroupPoliciesInput, _ ...request.Option) (*iam.ListGroupPoliciesOutput, error) { return &m.ListGroupPoliciesResp, nil } -func (m mockGroupIAMClient) GetGroupPolicy(in *iam.GetGroupPolicyInput) (*iam.GetGroupPolicyOutput, error) { +func (m mockGroupIAMClient) GetGroupPolicyWithContext(_ aws.Context, in *iam.GetGroupPolicyInput, _ ...request.Option) (*iam.GetGroupPolicyOutput, error) { return &m.GetGroupPolicyResp, nil } diff --git a/builtin/logical/aws/path_config_rotate_root.go b/builtin/logical/aws/path_config_rotate_root.go index 212a9eb3a..295b08547 100644 --- a/builtin/logical/aws/path_config_rotate_root.go +++ b/builtin/logical/aws/path_config_rotate_root.go @@ -59,7 +59,7 @@ func (b *backend) pathConfigRotateRootUpdate(ctx context.Context, req *logical.R } var getUserInput iam.GetUserInput // empty input means get current user - getUserRes, err := client.GetUser(&getUserInput) + getUserRes, err := client.GetUserWithContext(ctx, &getUserInput) if err != nil { return nil, fmt.Errorf("error calling GetUser: %w", err) } @@ -76,7 +76,7 @@ func (b *backend) pathConfigRotateRootUpdate(ctx context.Context, req *logical.R createAccessKeyInput := iam.CreateAccessKeyInput{ UserName: getUserRes.User.UserName, } - createAccessKeyRes, err := client.CreateAccessKey(&createAccessKeyInput) + createAccessKeyRes, err := client.CreateAccessKeyWithContext(ctx, &createAccessKeyInput) if err != nil { return nil, fmt.Errorf("error calling CreateAccessKey: %w", err) } @@ -107,7 +107,7 @@ func (b *backend) pathConfigRotateRootUpdate(ctx context.Context, req *logical.R AccessKeyId: aws.String(oldAccessKey), UserName: getUserRes.User.UserName, } - _, err = client.DeleteAccessKey(&deleteAccessKeyInput) + _, err = client.DeleteAccessKeyWithContext(ctx, &deleteAccessKeyInput) if err != nil { return nil, fmt.Errorf("error deleting old access key: %w", err) } diff --git a/builtin/logical/aws/path_user.go b/builtin/logical/aws/path_user.go index ca5e1a295..4fce31d02 100644 --- a/builtin/logical/aws/path_user.go +++ b/builtin/logical/aws/path_user.go @@ -155,7 +155,7 @@ func (b *backend) pathUserRollback(ctx context.Context, req *logical.Request, _k } // Get information about this user - groupsResp, err := client.ListGroupsForUser(&iam.ListGroupsForUserInput{ + groupsResp, err := client.ListGroupsForUserWithContext(ctx, &iam.ListGroupsForUserInput{ UserName: aws.String(username), MaxItems: aws.Int64(1000), }) @@ -194,7 +194,7 @@ func (b *backend) pathUserRollback(ctx context.Context, req *logical.Request, _k groups := groupsResp.Groups // Inline (user) policies - policiesResp, err := client.ListUserPolicies(&iam.ListUserPoliciesInput{ + policiesResp, err := client.ListUserPoliciesWithContext(ctx, &iam.ListUserPoliciesInput{ UserName: aws.String(username), MaxItems: aws.Int64(1000), }) @@ -204,7 +204,7 @@ func (b *backend) pathUserRollback(ctx context.Context, req *logical.Request, _k policies := policiesResp.PolicyNames // Attached managed policies - manPoliciesResp, err := client.ListAttachedUserPolicies(&iam.ListAttachedUserPoliciesInput{ + manPoliciesResp, err := client.ListAttachedUserPoliciesWithContext(ctx, &iam.ListAttachedUserPoliciesInput{ UserName: aws.String(username), MaxItems: aws.Int64(1000), }) @@ -213,7 +213,7 @@ func (b *backend) pathUserRollback(ctx context.Context, req *logical.Request, _k } manPolicies := manPoliciesResp.AttachedPolicies - keysResp, err := client.ListAccessKeys(&iam.ListAccessKeysInput{ + keysResp, err := client.ListAccessKeysWithContext(ctx, &iam.ListAccessKeysInput{ UserName: aws.String(username), MaxItems: aws.Int64(1000), }) @@ -224,7 +224,7 @@ func (b *backend) pathUserRollback(ctx context.Context, req *logical.Request, _k // Revoke all keys for _, k := range keys { - _, err = client.DeleteAccessKey(&iam.DeleteAccessKeyInput{ + _, err = client.DeleteAccessKeyWithContext(ctx, &iam.DeleteAccessKeyInput{ AccessKeyId: k.AccessKeyId, UserName: aws.String(username), }) @@ -235,7 +235,7 @@ func (b *backend) pathUserRollback(ctx context.Context, req *logical.Request, _k // Detach managed policies for _, p := range manPolicies { - _, err = client.DetachUserPolicy(&iam.DetachUserPolicyInput{ + _, err = client.DetachUserPolicyWithContext(ctx, &iam.DetachUserPolicyInput{ UserName: aws.String(username), PolicyArn: p.PolicyArn, }) @@ -246,7 +246,7 @@ func (b *backend) pathUserRollback(ctx context.Context, req *logical.Request, _k // Delete any inline (user) policies for _, p := range policies { - _, err = client.DeleteUserPolicy(&iam.DeleteUserPolicyInput{ + _, err = client.DeleteUserPolicyWithContext(ctx, &iam.DeleteUserPolicyInput{ UserName: aws.String(username), PolicyName: p, }) @@ -257,7 +257,7 @@ func (b *backend) pathUserRollback(ctx context.Context, req *logical.Request, _k // Remove the user from all their groups for _, g := range groups { - _, err = client.RemoveUserFromGroup(&iam.RemoveUserFromGroupInput{ + _, err = client.RemoveUserFromGroupWithContext(ctx, &iam.RemoveUserFromGroupInput{ GroupName: g.GroupName, UserName: aws.String(username), }) @@ -267,7 +267,7 @@ func (b *backend) pathUserRollback(ctx context.Context, req *logical.Request, _k } // Delete the user - _, err = client.DeleteUser(&iam.DeleteUserInput{ + _, err = client.DeleteUserWithContext(ctx, &iam.DeleteUserInput{ UserName: aws.String(username), }) if err != nil { diff --git a/builtin/logical/aws/secret_access_keys.go b/builtin/logical/aws/secret_access_keys.go index 9b8a2bc9b..a4c57d278 100644 --- a/builtin/logical/aws/secret_access_keys.go +++ b/builtin/logical/aws/secret_access_keys.go @@ -153,7 +153,7 @@ func (b *backend) getFederationToken(ctx context.Context, s logical.Storage, return logical.ErrorResponse("must specify at least one of policy_arns or policy_document with %s credential_type", federationTokenCred), nil } - tokenResp, err := stsClient.GetFederationToken(getTokenInput) + tokenResp, err := stsClient.GetFederationTokenWithContext(ctx, getTokenInput) if err != nil { return logical.ErrorResponse("Error generating STS keys: %s", err), awsutil.CheckAWSError(err) } @@ -228,7 +228,7 @@ func (b *backend) assumeRole(ctx context.Context, s logical.Storage, if len(policyARNs) > 0 { assumeRoleInput.SetPolicyArns(convertPolicyARNs(policyARNs)) } - tokenResp, err := stsClient.AssumeRole(assumeRoleInput) + tokenResp, err := stsClient.AssumeRoleWithContext(ctx, assumeRoleInput) if err != nil { return logical.ErrorResponse("Error assuming role: %s", err), awsutil.CheckAWSError(err) } @@ -314,7 +314,7 @@ func (b *backend) secretAccessKeysCreate( } // Create the user - _, err = iamClient.CreateUser(createUserRequest) + _, err = iamClient.CreateUserWithContext(ctx, createUserRequest) if err != nil { if walErr := framework.DeleteWAL(ctx, s, walID); walErr != nil { iamErr := fmt.Errorf("error creating IAM user: %w", err) @@ -325,7 +325,7 @@ func (b *backend) secretAccessKeysCreate( for _, arn := range role.PolicyArns { // Attach existing policy against user - _, err = iamClient.AttachUserPolicy(&iam.AttachUserPolicyInput{ + _, err = iamClient.AttachUserPolicyWithContext(ctx, &iam.AttachUserPolicyInput{ UserName: aws.String(username), PolicyArn: aws.String(arn), }) @@ -336,7 +336,7 @@ func (b *backend) secretAccessKeysCreate( } if role.PolicyDocument != "" { // Add new inline user policy against user - _, err = iamClient.PutUserPolicy(&iam.PutUserPolicyInput{ + _, err = iamClient.PutUserPolicyWithContext(ctx, &iam.PutUserPolicyInput{ UserName: aws.String(username), PolicyName: aws.String(policyName), PolicyDocument: aws.String(role.PolicyDocument), @@ -348,7 +348,7 @@ func (b *backend) secretAccessKeysCreate( for _, group := range role.IAMGroups { // Add user to IAM groups - _, err = iamClient.AddUserToGroup(&iam.AddUserToGroupInput{ + _, err = iamClient.AddUserToGroupWithContext(ctx, &iam.AddUserToGroupInput{ UserName: aws.String(username), GroupName: aws.String(group), }) @@ -367,7 +367,7 @@ func (b *backend) secretAccessKeysCreate( } if len(tags) > 0 { - _, err = iamClient.TagUser(&iam.TagUserInput{ + _, err = iamClient.TagUserWithContext(ctx, &iam.TagUserInput{ Tags: tags, UserName: &username, }) @@ -378,7 +378,7 @@ func (b *backend) secretAccessKeysCreate( } // Create the keys - keyResp, err := iamClient.CreateAccessKey(&iam.CreateAccessKeyInput{ + keyResp, err := iamClient.CreateAccessKeyWithContext(ctx, &iam.CreateAccessKeyInput{ UserName: aws.String(username), }) if err != nil { diff --git a/changelog/19365.txt b/changelog/19365.txt new file mode 100644 index 000000000..774c750f4 --- /dev/null +++ b/changelog/19365.txt @@ -0,0 +1,7 @@ +```release-note: enhancement +auth/aws: Support request cancellation with AWS requests +``` + +```release-note: enhancement +secrets/aws: Support request cancellation with AWS requests +```