aws: pass cancelable context with aws calls (#19365)

* auth/aws: use cancelable context with aws calls

* secrets/aws: use cancelable context with aws calls
This commit is contained in:
Mason Foster 2023-03-23 13:02:24 -04:00 committed by GitHub
parent a3f26af4c5
commit 09c6ff0623
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 68 additions and 44 deletions

View File

@ -312,7 +312,7 @@ func (b *backend) resolveArnToRealUniqueId(ctx context.Context, s logical.Storag
switch entity.Type {
case "user":
userInfo, err := iamClient.GetUser(&iam.GetUserInput{UserName: &entity.FriendlyName})
userInfo, err := iamClient.GetUserWithContext(ctx, &iam.GetUserInput{UserName: &entity.FriendlyName})
if err != nil {
return "", awsutil.AppendAWSError(err)
}
@ -321,7 +321,7 @@ func (b *backend) resolveArnToRealUniqueId(ctx context.Context, s logical.Storag
}
return *userInfo.User.UserId, nil
case "role":
roleInfo, err := iamClient.GetRole(&iam.GetRoleInput{RoleName: &entity.FriendlyName})
roleInfo, err := iamClient.GetRoleWithContext(ctx, &iam.GetRoleInput{RoleName: &entity.FriendlyName})
if err != nil {
return "", awsutil.AppendAWSError(err)
}
@ -330,7 +330,7 @@ func (b *backend) resolveArnToRealUniqueId(ctx context.Context, s logical.Storag
}
return *roleInfo.Role.RoleId, nil
case "instance-profile":
profileInfo, err := iamClient.GetInstanceProfile(&iam.GetInstanceProfileInput{InstanceProfileName: &entity.FriendlyName})
profileInfo, err := iamClient.GetInstanceProfileWithContext(ctx, &iam.GetInstanceProfileInput{InstanceProfileName: &entity.FriendlyName})
if err != nil {
return "", awsutil.AppendAWSError(err)
}

View File

@ -122,7 +122,7 @@ func (b *backend) getClientConfig(ctx context.Context, s logical.Storage, region
return nil, fmt.Errorf("could not obtain sts client: %w", err)
}
inputParams := &sts.GetCallerIdentityInput{}
identity, err := client.GetCallerIdentity(inputParams)
identity, err := client.GetCallerIdentityWithContext(ctx, inputParams)
if err != nil {
return nil, fmt.Errorf("unable to fetch current caller: %w", err)
}

View File

@ -100,7 +100,7 @@ func (b *backend) pathConfigRotateRootUpdate(ctx context.Context, req *logical.R
// Get the current user's name since it's required to create an access key.
// Empty input means get the current user.
var getUserInput iam.GetUserInput
getUserRes, err := iamClient.GetUser(&getUserInput)
getUserRes, err := iamClient.GetUserWithContext(ctx, &getUserInput)
if err != nil {
return nil, fmt.Errorf("error calling GetUser: %w", err)
}
@ -118,7 +118,7 @@ func (b *backend) pathConfigRotateRootUpdate(ctx context.Context, req *logical.R
createAccessKeyInput := iam.CreateAccessKeyInput{
UserName: getUserRes.User.UserName,
}
createAccessKeyRes, err := iamClient.CreateAccessKey(&createAccessKeyInput)
createAccessKeyRes, err := iamClient.CreateAccessKeyWithContext(ctx, &createAccessKeyInput)
if err != nil {
return nil, fmt.Errorf("error calling CreateAccessKey: %w", err)
}
@ -142,7 +142,7 @@ func (b *backend) pathConfigRotateRootUpdate(ctx context.Context, req *logical.R
AccessKeyId: createAccessKeyRes.AccessKey.AccessKeyId,
UserName: getUserRes.User.UserName,
}
if _, err := iamClient.DeleteAccessKey(&deleteAccessKeyInput); err != nil {
if _, err := iamClient.DeleteAccessKeyWithContext(ctx, &deleteAccessKeyInput); err != nil {
// Include this error in the errs returned by this method.
errs = multierror.Append(errs, fmt.Errorf("error deleting newly created but unstored access key ID %s: %s", *createAccessKeyRes.AccessKey.AccessKeyId, err))
}
@ -179,7 +179,7 @@ func (b *backend) pathConfigRotateRootUpdate(ctx context.Context, req *logical.R
AccessKeyId: aws.String(oldAccessKey),
UserName: getUserRes.User.UserName,
}
if _, err = iamClient.DeleteAccessKey(&deleteAccessKeyInput); err != nil {
if _, err = iamClient.DeleteAccessKeyWithContext(ctx, &deleteAccessKeyInput); err != nil {
errs = multierror.Append(errs, fmt.Errorf("error deleting old access key ID %s: %w", oldAccessKey, err))
return nil, errs
}

View File

@ -8,6 +8,7 @@ import (
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/iam"
"github.com/aws/aws-sdk-go/service/iam/iamiface"
@ -15,9 +16,23 @@ import (
"github.com/hashicorp/vault/sdk/logical"
)
type mockIAMClient awsutil.MockIAM
func (m *mockIAMClient) GetUserWithContext(_ aws.Context, input *iam.GetUserInput, _ ...request.Option) (*iam.GetUserOutput, error) {
return (*awsutil.MockIAM)(m).GetUser(input)
}
func (m *mockIAMClient) CreateAccessKeyWithContext(_ aws.Context, input *iam.CreateAccessKeyInput, _ ...request.Option) (*iam.CreateAccessKeyOutput, error) {
return (*awsutil.MockIAM)(m).CreateAccessKey(input)
}
func (m *mockIAMClient) DeleteAccessKeyWithContext(_ aws.Context, input *iam.DeleteAccessKeyInput, _ ...request.Option) (*iam.DeleteAccessKeyOutput, error) {
return (*awsutil.MockIAM)(m).DeleteAccessKey(input)
}
func TestPathConfigRotateRoot(t *testing.T) {
getIAMClient = func(sess *session.Session) iamiface.IAMAPI {
return &awsutil.MockIAM{
return &mockIAMClient{
CreateAccessKeyOutput: &iam.CreateAccessKeyOutput{
AccessKey: &iam.AccessKey{
AccessKeyId: aws.String("fizz2"),

View File

@ -106,8 +106,8 @@ This must match the request body included in the signature.`,
"iam_request_headers": {
Type: framework.TypeHeader,
Description: `Key/value pairs of headers for use in the
sts:GetCallerIdentity HTTP requests headers when auth_type is iam. Can be either
a Base64-encoded, JSON-serialized string, or a JSON object of key/value pairs.
sts:GetCallerIdentity HTTP requests headers when auth_type is iam. Can be either
a Base64-encoded, JSON-serialized string, or a JSON object of key/value pairs.
This must at a minimum include the headers over which AWS has included a signature.`,
},
"identity": {
@ -340,7 +340,7 @@ func (b *backend) pathLoginResolveRoleIam(ctx context.Context, req *logical.Requ
// instanceIamRoleARN fetches the IAM role ARN associated with the given
// instance profile name
func (b *backend) instanceIamRoleARN(iamClient *iam.IAM, instanceProfileName string) (string, error) {
func (b *backend) instanceIamRoleARN(ctx context.Context, iamClient *iam.IAM, instanceProfileName string) (string, error) {
if iamClient == nil {
return "", fmt.Errorf("nil iamClient")
}
@ -348,7 +348,7 @@ func (b *backend) instanceIamRoleARN(iamClient *iam.IAM, instanceProfileName str
return "", fmt.Errorf("missing instance profile name")
}
profile, err := iamClient.GetInstanceProfile(&iam.GetInstanceProfileInput{
profile, err := iamClient.GetInstanceProfileWithContext(ctx, &iam.GetInstanceProfileInput{
InstanceProfileName: aws.String(instanceProfileName),
})
if err != nil {
@ -382,7 +382,7 @@ func (b *backend) validateInstance(ctx context.Context, s logical.Storage, insta
return nil, err
}
status, err := ec2Client.DescribeInstances(&ec2.DescribeInstancesInput{
status, err := ec2Client.DescribeInstancesWithContext(ctx, &ec2.DescribeInstancesInput{
InstanceIds: []*string{
aws.String(instanceID),
},
@ -724,7 +724,7 @@ func (b *backend) verifyInstanceMeetsRoleRequirements(ctx context.Context,
} else if iamClient == nil {
return nil, fmt.Errorf("received a nil iamClient")
}
iamRoleARN, err := b.instanceIamRoleARN(iamClient, iamInstanceProfileEntity.FriendlyName)
iamRoleARN, err := b.instanceIamRoleARN(ctx, iamClient, iamInstanceProfileEntity.FriendlyName)
if err != nil {
return nil, fmt.Errorf("IAM role ARN could not be fetched: %w", err)
}
@ -1835,7 +1835,7 @@ func (b *backend) fullArn(ctx context.Context, e *iamEntity, s logical.Storage)
input := iam.GetUserInput{
UserName: aws.String(e.FriendlyName),
}
resp, err := client.GetUser(&input)
resp, err := client.GetUserWithContext(ctx, &input)
if err != nil {
return "", fmt.Errorf("error fetching user %q: %w", e.FriendlyName, err)
}
@ -1849,7 +1849,7 @@ func (b *backend) fullArn(ctx context.Context, e *iamEntity, s logical.Storage)
input := iam.GetRoleInput{
RoleName: aws.String(e.FriendlyName),
}
resp, err := client.GetRole(&input)
resp, err := client.GetRoleWithContext(ctx, &input)
if err != nil {
return "", fmt.Errorf("error fetching role %q: %w", e.FriendlyName, err)
}

View File

@ -19,6 +19,7 @@ import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/dynamodb"
"github.com/aws/aws-sdk-go/service/ec2"
@ -39,7 +40,7 @@ type mockIAMClient struct {
iamiface.IAMAPI
}
func (m *mockIAMClient) CreateUser(input *iam.CreateUserInput) (*iam.CreateUserOutput, error) {
func (m *mockIAMClient) CreateUserWithContext(_ aws.Context, input *iam.CreateUserInput, _ ...request.Option) (*iam.CreateUserOutput, error) {
return nil, awserr.New("Throttling", "", nil)
}

View File

@ -73,7 +73,7 @@ func (b *backend) getGroupPolicies(ctx context.Context, s logical.Storage, iamGr
for _, g := range iamGroups {
// Collect managed policy ARNs from the IAM Group
agp, err = iamClient.ListAttachedGroupPolicies(&iam.ListAttachedGroupPoliciesInput{
agp, err = iamClient.ListAttachedGroupPoliciesWithContext(ctx, &iam.ListAttachedGroupPoliciesInput{
GroupName: aws.String(g),
})
if err != nil {
@ -84,14 +84,14 @@ func (b *backend) getGroupPolicies(ctx context.Context, s logical.Storage, iamGr
}
// Collect inline policy names from the IAM Group
inlinePolicies, err = iamClient.ListGroupPolicies(&iam.ListGroupPoliciesInput{
inlinePolicies, err = iamClient.ListGroupPoliciesWithContext(ctx, &iam.ListGroupPoliciesInput{
GroupName: aws.String(g),
})
if err != nil {
return nil, nil, err
}
for _, iP := range inlinePolicies.PolicyNames {
inlinePolicyDoc, err = iamClient.GetGroupPolicy(&iam.GetGroupPolicyInput{
inlinePolicyDoc, err = iamClient.GetGroupPolicyWithContext(ctx, &iam.GetGroupPolicyInput{
GroupName: &g,
PolicyName: iP,
})

View File

@ -8,6 +8,7 @@ import (
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/service/iam"
"github.com/aws/aws-sdk-go/service/iam/iamiface"
"github.com/hashicorp/vault/sdk/logical"
@ -29,15 +30,15 @@ type mockGroupIAMClient struct {
GetGroupPolicyResp iam.GetGroupPolicyOutput
}
func (m mockGroupIAMClient) ListAttachedGroupPolicies(in *iam.ListAttachedGroupPoliciesInput) (*iam.ListAttachedGroupPoliciesOutput, error) {
func (m mockGroupIAMClient) ListAttachedGroupPoliciesWithContext(_ aws.Context, in *iam.ListAttachedGroupPoliciesInput, _ ...request.Option) (*iam.ListAttachedGroupPoliciesOutput, error) {
return &m.ListAttachedGroupPoliciesResp, nil
}
func (m mockGroupIAMClient) ListGroupPolicies(in *iam.ListGroupPoliciesInput) (*iam.ListGroupPoliciesOutput, error) {
func (m mockGroupIAMClient) ListGroupPoliciesWithContext(_ aws.Context, in *iam.ListGroupPoliciesInput, _ ...request.Option) (*iam.ListGroupPoliciesOutput, error) {
return &m.ListGroupPoliciesResp, nil
}
func (m mockGroupIAMClient) GetGroupPolicy(in *iam.GetGroupPolicyInput) (*iam.GetGroupPolicyOutput, error) {
func (m mockGroupIAMClient) GetGroupPolicyWithContext(_ aws.Context, in *iam.GetGroupPolicyInput, _ ...request.Option) (*iam.GetGroupPolicyOutput, error) {
return &m.GetGroupPolicyResp, nil
}

View File

@ -59,7 +59,7 @@ func (b *backend) pathConfigRotateRootUpdate(ctx context.Context, req *logical.R
}
var getUserInput iam.GetUserInput // empty input means get current user
getUserRes, err := client.GetUser(&getUserInput)
getUserRes, err := client.GetUserWithContext(ctx, &getUserInput)
if err != nil {
return nil, fmt.Errorf("error calling GetUser: %w", err)
}
@ -76,7 +76,7 @@ func (b *backend) pathConfigRotateRootUpdate(ctx context.Context, req *logical.R
createAccessKeyInput := iam.CreateAccessKeyInput{
UserName: getUserRes.User.UserName,
}
createAccessKeyRes, err := client.CreateAccessKey(&createAccessKeyInput)
createAccessKeyRes, err := client.CreateAccessKeyWithContext(ctx, &createAccessKeyInput)
if err != nil {
return nil, fmt.Errorf("error calling CreateAccessKey: %w", err)
}
@ -107,7 +107,7 @@ func (b *backend) pathConfigRotateRootUpdate(ctx context.Context, req *logical.R
AccessKeyId: aws.String(oldAccessKey),
UserName: getUserRes.User.UserName,
}
_, err = client.DeleteAccessKey(&deleteAccessKeyInput)
_, err = client.DeleteAccessKeyWithContext(ctx, &deleteAccessKeyInput)
if err != nil {
return nil, fmt.Errorf("error deleting old access key: %w", err)
}

View File

@ -155,7 +155,7 @@ func (b *backend) pathUserRollback(ctx context.Context, req *logical.Request, _k
}
// Get information about this user
groupsResp, err := client.ListGroupsForUser(&iam.ListGroupsForUserInput{
groupsResp, err := client.ListGroupsForUserWithContext(ctx, &iam.ListGroupsForUserInput{
UserName: aws.String(username),
MaxItems: aws.Int64(1000),
})
@ -194,7 +194,7 @@ func (b *backend) pathUserRollback(ctx context.Context, req *logical.Request, _k
groups := groupsResp.Groups
// Inline (user) policies
policiesResp, err := client.ListUserPolicies(&iam.ListUserPoliciesInput{
policiesResp, err := client.ListUserPoliciesWithContext(ctx, &iam.ListUserPoliciesInput{
UserName: aws.String(username),
MaxItems: aws.Int64(1000),
})
@ -204,7 +204,7 @@ func (b *backend) pathUserRollback(ctx context.Context, req *logical.Request, _k
policies := policiesResp.PolicyNames
// Attached managed policies
manPoliciesResp, err := client.ListAttachedUserPolicies(&iam.ListAttachedUserPoliciesInput{
manPoliciesResp, err := client.ListAttachedUserPoliciesWithContext(ctx, &iam.ListAttachedUserPoliciesInput{
UserName: aws.String(username),
MaxItems: aws.Int64(1000),
})
@ -213,7 +213,7 @@ func (b *backend) pathUserRollback(ctx context.Context, req *logical.Request, _k
}
manPolicies := manPoliciesResp.AttachedPolicies
keysResp, err := client.ListAccessKeys(&iam.ListAccessKeysInput{
keysResp, err := client.ListAccessKeysWithContext(ctx, &iam.ListAccessKeysInput{
UserName: aws.String(username),
MaxItems: aws.Int64(1000),
})
@ -224,7 +224,7 @@ func (b *backend) pathUserRollback(ctx context.Context, req *logical.Request, _k
// Revoke all keys
for _, k := range keys {
_, err = client.DeleteAccessKey(&iam.DeleteAccessKeyInput{
_, err = client.DeleteAccessKeyWithContext(ctx, &iam.DeleteAccessKeyInput{
AccessKeyId: k.AccessKeyId,
UserName: aws.String(username),
})
@ -235,7 +235,7 @@ func (b *backend) pathUserRollback(ctx context.Context, req *logical.Request, _k
// Detach managed policies
for _, p := range manPolicies {
_, err = client.DetachUserPolicy(&iam.DetachUserPolicyInput{
_, err = client.DetachUserPolicyWithContext(ctx, &iam.DetachUserPolicyInput{
UserName: aws.String(username),
PolicyArn: p.PolicyArn,
})
@ -246,7 +246,7 @@ func (b *backend) pathUserRollback(ctx context.Context, req *logical.Request, _k
// Delete any inline (user) policies
for _, p := range policies {
_, err = client.DeleteUserPolicy(&iam.DeleteUserPolicyInput{
_, err = client.DeleteUserPolicyWithContext(ctx, &iam.DeleteUserPolicyInput{
UserName: aws.String(username),
PolicyName: p,
})
@ -257,7 +257,7 @@ func (b *backend) pathUserRollback(ctx context.Context, req *logical.Request, _k
// Remove the user from all their groups
for _, g := range groups {
_, err = client.RemoveUserFromGroup(&iam.RemoveUserFromGroupInput{
_, err = client.RemoveUserFromGroupWithContext(ctx, &iam.RemoveUserFromGroupInput{
GroupName: g.GroupName,
UserName: aws.String(username),
})
@ -267,7 +267,7 @@ func (b *backend) pathUserRollback(ctx context.Context, req *logical.Request, _k
}
// Delete the user
_, err = client.DeleteUser(&iam.DeleteUserInput{
_, err = client.DeleteUserWithContext(ctx, &iam.DeleteUserInput{
UserName: aws.String(username),
})
if err != nil {

View File

@ -153,7 +153,7 @@ func (b *backend) getFederationToken(ctx context.Context, s logical.Storage,
return logical.ErrorResponse("must specify at least one of policy_arns or policy_document with %s credential_type", federationTokenCred), nil
}
tokenResp, err := stsClient.GetFederationToken(getTokenInput)
tokenResp, err := stsClient.GetFederationTokenWithContext(ctx, getTokenInput)
if err != nil {
return logical.ErrorResponse("Error generating STS keys: %s", err), awsutil.CheckAWSError(err)
}
@ -228,7 +228,7 @@ func (b *backend) assumeRole(ctx context.Context, s logical.Storage,
if len(policyARNs) > 0 {
assumeRoleInput.SetPolicyArns(convertPolicyARNs(policyARNs))
}
tokenResp, err := stsClient.AssumeRole(assumeRoleInput)
tokenResp, err := stsClient.AssumeRoleWithContext(ctx, assumeRoleInput)
if err != nil {
return logical.ErrorResponse("Error assuming role: %s", err), awsutil.CheckAWSError(err)
}
@ -314,7 +314,7 @@ func (b *backend) secretAccessKeysCreate(
}
// Create the user
_, err = iamClient.CreateUser(createUserRequest)
_, err = iamClient.CreateUserWithContext(ctx, createUserRequest)
if err != nil {
if walErr := framework.DeleteWAL(ctx, s, walID); walErr != nil {
iamErr := fmt.Errorf("error creating IAM user: %w", err)
@ -325,7 +325,7 @@ func (b *backend) secretAccessKeysCreate(
for _, arn := range role.PolicyArns {
// Attach existing policy against user
_, err = iamClient.AttachUserPolicy(&iam.AttachUserPolicyInput{
_, err = iamClient.AttachUserPolicyWithContext(ctx, &iam.AttachUserPolicyInput{
UserName: aws.String(username),
PolicyArn: aws.String(arn),
})
@ -336,7 +336,7 @@ func (b *backend) secretAccessKeysCreate(
}
if role.PolicyDocument != "" {
// Add new inline user policy against user
_, err = iamClient.PutUserPolicy(&iam.PutUserPolicyInput{
_, err = iamClient.PutUserPolicyWithContext(ctx, &iam.PutUserPolicyInput{
UserName: aws.String(username),
PolicyName: aws.String(policyName),
PolicyDocument: aws.String(role.PolicyDocument),
@ -348,7 +348,7 @@ func (b *backend) secretAccessKeysCreate(
for _, group := range role.IAMGroups {
// Add user to IAM groups
_, err = iamClient.AddUserToGroup(&iam.AddUserToGroupInput{
_, err = iamClient.AddUserToGroupWithContext(ctx, &iam.AddUserToGroupInput{
UserName: aws.String(username),
GroupName: aws.String(group),
})
@ -367,7 +367,7 @@ func (b *backend) secretAccessKeysCreate(
}
if len(tags) > 0 {
_, err = iamClient.TagUser(&iam.TagUserInput{
_, err = iamClient.TagUserWithContext(ctx, &iam.TagUserInput{
Tags: tags,
UserName: &username,
})
@ -378,7 +378,7 @@ func (b *backend) secretAccessKeysCreate(
}
// Create the keys
keyResp, err := iamClient.CreateAccessKey(&iam.CreateAccessKeyInput{
keyResp, err := iamClient.CreateAccessKeyWithContext(ctx, &iam.CreateAccessKeyInput{
UserName: aws.String(username),
})
if err != nil {

7
changelog/19365.txt Normal file
View File

@ -0,0 +1,7 @@
```release-note: enhancement
auth/aws: Support request cancellation with AWS requests
```
```release-note: enhancement
secrets/aws: Support request cancellation with AWS requests
```