// Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: MPL-2.0 package awsauth import ( "context" "fmt" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/credentials/stscreds" "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/ec2" "github.com/aws/aws-sdk-go/service/iam" "github.com/aws/aws-sdk-go/service/sts" cleanhttp "github.com/hashicorp/go-cleanhttp" "github.com/hashicorp/go-secure-stdlib/awsutil" "github.com/hashicorp/vault/sdk/logical" ) // getRawClientConfig creates a aws-sdk-go config, which is used to create client // that can interact with AWS API. This builds credentials in the following // order of preference: // // * Static credentials from 'config/client' // * Environment variables // * Instance metadata role func (b *backend) getRawClientConfig(ctx context.Context, s logical.Storage, region, clientType string) (*aws.Config, error) { credsConfig := &awsutil.CredentialsConfig{ Region: region, Logger: b.Logger(), } // Read the configured secret key and access key config, err := b.nonLockedClientConfigEntry(ctx, s) if err != nil { return nil, err } endpoint := aws.String("") var maxRetries int = aws.UseServiceDefaultRetries if config != nil { // Override the defaults with configured values. switch { case clientType == "ec2" && config.Endpoint != "": endpoint = aws.String(config.Endpoint) case clientType == "iam" && config.IAMEndpoint != "": endpoint = aws.String(config.IAMEndpoint) case clientType == "sts": if config.STSEndpoint != "" { endpoint = aws.String(config.STSEndpoint) } if config.STSRegion != "" { region = config.STSRegion } } credsConfig.AccessKey = config.AccessKey credsConfig.SecretKey = config.SecretKey maxRetries = config.MaxRetries } credsConfig.HTTPClient = cleanhttp.DefaultClient() creds, err := credsConfig.GenerateCredentialChain() if err != nil { return nil, err } if creds == nil { return nil, fmt.Errorf("could not compile valid credential providers from static config, environment, shared, or instance metadata") } // Create a config that can be used to make the API calls. return &aws.Config{ Credentials: creds, Region: aws.String(region), HTTPClient: cleanhttp.DefaultClient(), Endpoint: endpoint, MaxRetries: aws.Int(maxRetries), }, nil } // getClientConfig returns an aws-sdk-go config, with optionally assumed credentials // It uses getRawClientConfig to obtain config for the runtime environment, and if // stsRole is a non-empty string, it will use AssumeRole to obtain a set of assumed // credentials. The credentials will expire after 15 minutes but will auto-refresh. func (b *backend) getClientConfig(ctx context.Context, s logical.Storage, region, stsRole, accountID, clientType string) (*aws.Config, error) { config, err := b.getRawClientConfig(ctx, s, region, clientType) if err != nil { return nil, err } if config == nil { return nil, fmt.Errorf("could not compile valid credentials through the default provider chain") } stsConfig, err := b.getRawClientConfig(ctx, s, region, "sts") if stsConfig == nil { return nil, fmt.Errorf("could not configure STS client") } if err != nil { return nil, err } if stsRole != "" { sess, err := session.NewSession(stsConfig) if err != nil { return nil, err } assumedCredentials := stscreds.NewCredentials(sess, stsRole) // Test that we actually have permissions to assume the role if _, err = assumedCredentials.Get(); err != nil { return nil, err } config.Credentials = assumedCredentials } else { if b.defaultAWSAccountID == "" { sess, err := session.NewSession(stsConfig) if err != nil { return nil, err } client := sts.New(sess) if client == nil { return nil, fmt.Errorf("could not obtain sts client: %w", err) } inputParams := &sts.GetCallerIdentityInput{} identity, err := client.GetCallerIdentityWithContext(ctx, inputParams) if err != nil { return nil, fmt.Errorf("unable to fetch current caller: %w", err) } if identity == nil { return nil, fmt.Errorf("got nil result from GetCallerIdentity") } b.defaultAWSAccountID = *identity.Account } if b.defaultAWSAccountID != accountID { return nil, fmt.Errorf("unable to fetch client for account ID %q -- default client is for account %q", accountID, b.defaultAWSAccountID) } } return config, nil } // flushCachedEC2Clients deletes all the cached ec2 client objects from the backend. // If the client credentials configuration is deleted or updated in the backend, all // the cached EC2 client objects will be flushed. Config mutex lock should be // acquired for write operation before calling this method. func (b *backend) flushCachedEC2Clients() { // deleting items in map during iteration is safe for region := range b.EC2ClientsMap { delete(b.EC2ClientsMap, region) } } // flushCachedIAMClients deletes all the cached iam client objects from the // backend. If the client credentials configuration is deleted or updated in // the backend, all the cached IAM client objects will be flushed. Config mutex // lock should be acquired for write operation before calling this method. func (b *backend) flushCachedIAMClients() { // deleting items in map during iteration is safe for region := range b.IAMClientsMap { delete(b.IAMClientsMap, region) } } // Gets an entry out of the user ID cache func (b *backend) getCachedUserId(userId string) string { if userId == "" { return "" } if entry, ok := b.iamUserIdToArnCache.Get(userId); ok { b.iamUserIdToArnCache.SetDefault(userId, entry) return entry.(string) } return "" } // Sets an entry in the user ID cache func (b *backend) setCachedUserId(userId, arn string) { if userId != "" { b.iamUserIdToArnCache.SetDefault(userId, arn) } } func (b *backend) stsRoleForAccount(ctx context.Context, s logical.Storage, accountID string) (string, error) { // Check if an STS configuration exists for the AWS account sts, err := b.lockedAwsStsEntry(ctx, s, accountID) if err != nil { return "", fmt.Errorf("error fetching STS config for account ID %q: %w", accountID, err) } // An empty STS role signifies the master account if sts != nil { return sts.StsRole, nil } return "", nil } // clientEC2 creates a client to interact with AWS EC2 API func (b *backend) clientEC2(ctx context.Context, s logical.Storage, region, accountID string) (*ec2.EC2, error) { stsRole, err := b.stsRoleForAccount(ctx, s, accountID) if err != nil { return nil, err } b.configMutex.RLock() if b.EC2ClientsMap[region] != nil && b.EC2ClientsMap[region][stsRole] != nil { defer b.configMutex.RUnlock() // If the client object was already created, return it return b.EC2ClientsMap[region][stsRole], nil } // Release the read lock and acquire the write lock b.configMutex.RUnlock() b.configMutex.Lock() defer b.configMutex.Unlock() // If the client gets created while switching the locks, return it if b.EC2ClientsMap[region] != nil && b.EC2ClientsMap[region][stsRole] != nil { return b.EC2ClientsMap[region][stsRole], nil } // Create an AWS config object using a chain of providers var awsConfig *aws.Config awsConfig, err = b.getClientConfig(ctx, s, region, stsRole, accountID, "ec2") if err != nil { return nil, err } if awsConfig == nil { return nil, fmt.Errorf("could not retrieve valid assumed credentials") } // Create a new EC2 client object, cache it and return the same sess, err := session.NewSession(awsConfig) if err != nil { return nil, err } client := ec2.New(sess) if client == nil { return nil, fmt.Errorf("could not obtain ec2 client") } if _, ok := b.EC2ClientsMap[region]; !ok { b.EC2ClientsMap[region] = map[string]*ec2.EC2{stsRole: client} } else { b.EC2ClientsMap[region][stsRole] = client } return b.EC2ClientsMap[region][stsRole], nil } // clientIAM creates a client to interact with AWS IAM API func (b *backend) clientIAM(ctx context.Context, s logical.Storage, region, accountID string) (*iam.IAM, error) { stsRole, err := b.stsRoleForAccount(ctx, s, accountID) if err != nil { return nil, err } if stsRole == "" { b.Logger().Debug(fmt.Sprintf("no stsRole found for %s", accountID)) } else { b.Logger().Debug(fmt.Sprintf("found stsRole %s for account %s", stsRole, accountID)) } b.configMutex.RLock() if b.IAMClientsMap[region] != nil && b.IAMClientsMap[region][stsRole] != nil { defer b.configMutex.RUnlock() // If the client object was already created, return it b.Logger().Debug(fmt.Sprintf("returning cached client for region %s and stsRole %s", region, stsRole)) return b.IAMClientsMap[region][stsRole], nil } b.Logger().Debug(fmt.Sprintf("no cached client for region %s and stsRole %s", region, stsRole)) // Release the read lock and acquire the write lock b.configMutex.RUnlock() b.configMutex.Lock() defer b.configMutex.Unlock() // If the client gets created while switching the locks, return it if b.IAMClientsMap[region] != nil && b.IAMClientsMap[region][stsRole] != nil { return b.IAMClientsMap[region][stsRole], nil } // Create an AWS config object using a chain of providers var awsConfig *aws.Config awsConfig, err = b.getClientConfig(ctx, s, region, stsRole, accountID, "iam") if err != nil { return nil, err } if awsConfig == nil { return nil, fmt.Errorf("could not retrieve valid assumed credentials") } // Create a new IAM client object, cache it and return the same sess, err := session.NewSession(awsConfig) if err != nil { return nil, err } client := iam.New(sess) if client == nil { return nil, fmt.Errorf("could not obtain iam client") } if _, ok := b.IAMClientsMap[region]; !ok { b.IAMClientsMap[region] = map[string]*iam.IAM{stsRole: client} } else { b.IAMClientsMap[region][stsRole] = client } return b.IAMClientsMap[region][stsRole], nil }