304 lines
9.6 KiB
Go
304 lines
9.6 KiB
Go
package awsauth
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
|
|
"github.com/aws/aws-sdk-go/aws"
|
|
"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
|
|
"github.com/aws/aws-sdk-go/aws/session"
|
|
"github.com/aws/aws-sdk-go/service/ec2"
|
|
"github.com/aws/aws-sdk-go/service/iam"
|
|
"github.com/aws/aws-sdk-go/service/sts"
|
|
cleanhttp "github.com/hashicorp/go-cleanhttp"
|
|
"github.com/hashicorp/go-secure-stdlib/awsutil"
|
|
"github.com/hashicorp/vault/sdk/logical"
|
|
)
|
|
|
|
// getRawClientConfig creates a aws-sdk-go config, which is used to create client
|
|
// that can interact with AWS API. This builds credentials in the following
|
|
// order of preference:
|
|
//
|
|
// * Static credentials from 'config/client'
|
|
// * Environment variables
|
|
// * Instance metadata role
|
|
func (b *backend) getRawClientConfig(ctx context.Context, s logical.Storage, region, clientType string) (*aws.Config, error) {
|
|
credsConfig := &awsutil.CredentialsConfig{
|
|
Region: region,
|
|
Logger: b.Logger(),
|
|
}
|
|
|
|
// Read the configured secret key and access key
|
|
config, err := b.nonLockedClientConfigEntry(ctx, s)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
endpoint := aws.String("")
|
|
var maxRetries int = aws.UseServiceDefaultRetries
|
|
if config != nil {
|
|
// Override the defaults with configured values.
|
|
switch {
|
|
case clientType == "ec2" && config.Endpoint != "":
|
|
endpoint = aws.String(config.Endpoint)
|
|
case clientType == "iam" && config.IAMEndpoint != "":
|
|
endpoint = aws.String(config.IAMEndpoint)
|
|
case clientType == "sts":
|
|
if config.STSEndpoint != "" {
|
|
endpoint = aws.String(config.STSEndpoint)
|
|
}
|
|
if config.STSRegion != "" {
|
|
region = config.STSRegion
|
|
}
|
|
}
|
|
|
|
credsConfig.AccessKey = config.AccessKey
|
|
credsConfig.SecretKey = config.SecretKey
|
|
maxRetries = config.MaxRetries
|
|
}
|
|
|
|
credsConfig.HTTPClient = cleanhttp.DefaultClient()
|
|
|
|
creds, err := credsConfig.GenerateCredentialChain()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if creds == nil {
|
|
return nil, fmt.Errorf("could not compile valid credential providers from static config, environment, shared, or instance metadata")
|
|
}
|
|
|
|
// Create a config that can be used to make the API calls.
|
|
return &aws.Config{
|
|
Credentials: creds,
|
|
Region: aws.String(region),
|
|
HTTPClient: cleanhttp.DefaultClient(),
|
|
Endpoint: endpoint,
|
|
MaxRetries: aws.Int(maxRetries),
|
|
}, nil
|
|
}
|
|
|
|
// getClientConfig returns an aws-sdk-go config, with optionally assumed credentials
|
|
// It uses getRawClientConfig to obtain config for the runtime environment, and if
|
|
// stsRole is a non-empty string, it will use AssumeRole to obtain a set of assumed
|
|
// credentials. The credentials will expire after 15 minutes but will auto-refresh.
|
|
func (b *backend) getClientConfig(ctx context.Context, s logical.Storage, region, stsRole, accountID, clientType string) (*aws.Config, error) {
|
|
config, err := b.getRawClientConfig(ctx, s, region, clientType)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if config == nil {
|
|
return nil, fmt.Errorf("could not compile valid credentials through the default provider chain")
|
|
}
|
|
|
|
stsConfig, err := b.getRawClientConfig(ctx, s, region, "sts")
|
|
if stsConfig == nil {
|
|
return nil, fmt.Errorf("could not configure STS client")
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if stsRole != "" {
|
|
sess, err := session.NewSession(stsConfig)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
assumedCredentials := stscreds.NewCredentials(sess, stsRole)
|
|
// Test that we actually have permissions to assume the role
|
|
if _, err = assumedCredentials.Get(); err != nil {
|
|
return nil, err
|
|
}
|
|
config.Credentials = assumedCredentials
|
|
} else {
|
|
if b.defaultAWSAccountID == "" {
|
|
sess, err := session.NewSession(stsConfig)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
client := sts.New(sess)
|
|
if client == nil {
|
|
return nil, fmt.Errorf("could not obtain sts client: %w", err)
|
|
}
|
|
inputParams := &sts.GetCallerIdentityInput{}
|
|
identity, err := client.GetCallerIdentity(inputParams)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("unable to fetch current caller: %w", err)
|
|
}
|
|
if identity == nil {
|
|
return nil, fmt.Errorf("got nil result from GetCallerIdentity")
|
|
}
|
|
b.defaultAWSAccountID = *identity.Account
|
|
}
|
|
if b.defaultAWSAccountID != accountID {
|
|
return nil, fmt.Errorf("unable to fetch client for account ID %q -- default client is for account %q", accountID, b.defaultAWSAccountID)
|
|
}
|
|
}
|
|
|
|
return config, nil
|
|
}
|
|
|
|
// flushCachedEC2Clients deletes all the cached ec2 client objects from the backend.
|
|
// If the client credentials configuration is deleted or updated in the backend, all
|
|
// the cached EC2 client objects will be flushed. Config mutex lock should be
|
|
// acquired for write operation before calling this method.
|
|
func (b *backend) flushCachedEC2Clients() {
|
|
// deleting items in map during iteration is safe
|
|
for region := range b.EC2ClientsMap {
|
|
delete(b.EC2ClientsMap, region)
|
|
}
|
|
}
|
|
|
|
// flushCachedIAMClients deletes all the cached iam client objects from the
|
|
// backend. If the client credentials configuration is deleted or updated in
|
|
// the backend, all the cached IAM client objects will be flushed. Config mutex
|
|
// lock should be acquired for write operation before calling this method.
|
|
func (b *backend) flushCachedIAMClients() {
|
|
// deleting items in map during iteration is safe
|
|
for region := range b.IAMClientsMap {
|
|
delete(b.IAMClientsMap, region)
|
|
}
|
|
}
|
|
|
|
// Gets an entry out of the user ID cache
|
|
func (b *backend) getCachedUserId(userId string) string {
|
|
if userId == "" {
|
|
return ""
|
|
}
|
|
if entry, ok := b.iamUserIdToArnCache.Get(userId); ok {
|
|
b.iamUserIdToArnCache.SetDefault(userId, entry)
|
|
return entry.(string)
|
|
}
|
|
return ""
|
|
}
|
|
|
|
// Sets an entry in the user ID cache
|
|
func (b *backend) setCachedUserId(userId, arn string) {
|
|
if userId != "" {
|
|
b.iamUserIdToArnCache.SetDefault(userId, arn)
|
|
}
|
|
}
|
|
|
|
func (b *backend) stsRoleForAccount(ctx context.Context, s logical.Storage, accountID string) (string, error) {
|
|
// Check if an STS configuration exists for the AWS account
|
|
sts, err := b.lockedAwsStsEntry(ctx, s, accountID)
|
|
if err != nil {
|
|
return "", fmt.Errorf("error fetching STS config for account ID %q: %w", accountID, err)
|
|
}
|
|
// An empty STS role signifies the master account
|
|
if sts != nil {
|
|
return sts.StsRole, nil
|
|
}
|
|
return "", nil
|
|
}
|
|
|
|
// clientEC2 creates a client to interact with AWS EC2 API
|
|
func (b *backend) clientEC2(ctx context.Context, s logical.Storage, region, accountID string) (*ec2.EC2, error) {
|
|
stsRole, err := b.stsRoleForAccount(ctx, s, accountID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
b.configMutex.RLock()
|
|
if b.EC2ClientsMap[region] != nil && b.EC2ClientsMap[region][stsRole] != nil {
|
|
defer b.configMutex.RUnlock()
|
|
// If the client object was already created, return it
|
|
return b.EC2ClientsMap[region][stsRole], nil
|
|
}
|
|
|
|
// Release the read lock and acquire the write lock
|
|
b.configMutex.RUnlock()
|
|
b.configMutex.Lock()
|
|
defer b.configMutex.Unlock()
|
|
|
|
// If the client gets created while switching the locks, return it
|
|
if b.EC2ClientsMap[region] != nil && b.EC2ClientsMap[region][stsRole] != nil {
|
|
return b.EC2ClientsMap[region][stsRole], nil
|
|
}
|
|
|
|
// Create an AWS config object using a chain of providers
|
|
var awsConfig *aws.Config
|
|
awsConfig, err = b.getClientConfig(ctx, s, region, stsRole, accountID, "ec2")
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if awsConfig == nil {
|
|
return nil, fmt.Errorf("could not retrieve valid assumed credentials")
|
|
}
|
|
|
|
// Create a new EC2 client object, cache it and return the same
|
|
sess, err := session.NewSession(awsConfig)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
client := ec2.New(sess)
|
|
if client == nil {
|
|
return nil, fmt.Errorf("could not obtain ec2 client")
|
|
}
|
|
if _, ok := b.EC2ClientsMap[region]; !ok {
|
|
b.EC2ClientsMap[region] = map[string]*ec2.EC2{stsRole: client}
|
|
} else {
|
|
b.EC2ClientsMap[region][stsRole] = client
|
|
}
|
|
|
|
return b.EC2ClientsMap[region][stsRole], nil
|
|
}
|
|
|
|
// clientIAM creates a client to interact with AWS IAM API
|
|
func (b *backend) clientIAM(ctx context.Context, s logical.Storage, region, accountID string) (*iam.IAM, error) {
|
|
stsRole, err := b.stsRoleForAccount(ctx, s, accountID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if stsRole == "" {
|
|
b.Logger().Debug(fmt.Sprintf("no stsRole found for %s", accountID))
|
|
} else {
|
|
b.Logger().Debug(fmt.Sprintf("found stsRole %s for account %s", stsRole, accountID))
|
|
}
|
|
b.configMutex.RLock()
|
|
if b.IAMClientsMap[region] != nil && b.IAMClientsMap[region][stsRole] != nil {
|
|
defer b.configMutex.RUnlock()
|
|
// If the client object was already created, return it
|
|
b.Logger().Debug(fmt.Sprintf("returning cached client for region %s and stsRole %s", region, stsRole))
|
|
return b.IAMClientsMap[region][stsRole], nil
|
|
}
|
|
b.Logger().Debug(fmt.Sprintf("no cached client for region %s and stsRole %s", region, stsRole))
|
|
|
|
// Release the read lock and acquire the write lock
|
|
b.configMutex.RUnlock()
|
|
b.configMutex.Lock()
|
|
defer b.configMutex.Unlock()
|
|
|
|
// If the client gets created while switching the locks, return it
|
|
if b.IAMClientsMap[region] != nil && b.IAMClientsMap[region][stsRole] != nil {
|
|
return b.IAMClientsMap[region][stsRole], nil
|
|
}
|
|
|
|
// Create an AWS config object using a chain of providers
|
|
var awsConfig *aws.Config
|
|
awsConfig, err = b.getClientConfig(ctx, s, region, stsRole, accountID, "iam")
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if awsConfig == nil {
|
|
return nil, fmt.Errorf("could not retrieve valid assumed credentials")
|
|
}
|
|
|
|
// Create a new IAM client object, cache it and return the same
|
|
sess, err := session.NewSession(awsConfig)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
client := iam.New(sess)
|
|
if client == nil {
|
|
return nil, fmt.Errorf("could not obtain iam client")
|
|
}
|
|
if _, ok := b.IAMClientsMap[region]; !ok {
|
|
b.IAMClientsMap[region] = map[string]*iam.IAM{stsRole: client}
|
|
} else {
|
|
b.IAMClientsMap[region][stsRole] = client
|
|
}
|
|
return b.IAMClientsMap[region][stsRole], nil
|
|
}
|