225 lines
7.8 KiB
Go
225 lines
7.8 KiB
Go
package awsutil
|
|
|
|
import (
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io/ioutil"
|
|
"net/http"
|
|
"os"
|
|
"time"
|
|
|
|
"github.com/aws/aws-sdk-go/aws"
|
|
"github.com/aws/aws-sdk-go/aws/credentials"
|
|
"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
|
|
"github.com/aws/aws-sdk-go/aws/defaults"
|
|
"github.com/aws/aws-sdk-go/aws/endpoints"
|
|
"github.com/aws/aws-sdk-go/aws/session"
|
|
"github.com/aws/aws-sdk-go/service/sts"
|
|
"github.com/hashicorp/go-hclog"
|
|
"github.com/pkg/errors"
|
|
)
|
|
|
|
// iamServerIdHeader is the HTTP header name under which GenerateLoginData
// places the caller-supplied server ID value before signing the STS request.
const (
	iamServerIdHeader = "X-Vault-AWS-IAM-Server-ID"
)
|
|
|
|
type CredentialsConfig struct {
|
|
// The access key if static credentials are being used
|
|
AccessKey string
|
|
|
|
// The secret key if static credentials are being used
|
|
SecretKey string
|
|
|
|
// The session token if it is being used
|
|
SessionToken string
|
|
|
|
// If specified, the region will be provided to the config of the
|
|
// EC2RoleProvider's client. This may be useful if you want to e.g. reuse
|
|
// the client elsewhere.
|
|
Region string
|
|
|
|
// The filename for the shared credentials provider, if being used
|
|
Filename string
|
|
|
|
// The profile for the shared credentials provider, if being used
|
|
Profile string
|
|
|
|
// The http.Client to use, or nil for the client to use its default
|
|
HTTPClient *http.Client
|
|
|
|
// The logger to use for credential acquisition debugging
|
|
Logger hclog.Logger
|
|
}
|
|
|
|
// Make sure the logger isn't nil before logging
|
|
func (c *CredentialsConfig) log(level hclog.Level, msg string, args ...interface{}) {
|
|
if c.Logger != nil {
|
|
c.Logger.Log(level, msg, args...)
|
|
}
|
|
}
|
|
|
|
func (c *CredentialsConfig) GenerateCredentialChain() (*credentials.Credentials, error) {
|
|
var providers []credentials.Provider
|
|
|
|
switch {
|
|
case c.AccessKey != "" && c.SecretKey != "":
|
|
// Add the static credential provider
|
|
providers = append(providers, &credentials.StaticProvider{
|
|
Value: credentials.Value{
|
|
AccessKeyID: c.AccessKey,
|
|
SecretAccessKey: c.SecretKey,
|
|
SessionToken: c.SessionToken,
|
|
},
|
|
})
|
|
c.log(hclog.Debug, "added static credential provider", "AccessKey", c.AccessKey)
|
|
|
|
case c.AccessKey == "" && c.SecretKey == "":
|
|
// Attempt to get credentials from the IAM instance role below
|
|
|
|
default: // Have one or the other but not both and not neither
|
|
return nil, fmt.Errorf(
|
|
"static AWS client credentials haven't been properly configured (the access key or secret key were provided but not both)")
|
|
}
|
|
|
|
roleARN := os.Getenv("AWS_ROLE_ARN")
|
|
tokenPath := os.Getenv("AWS_WEB_IDENTITY_TOKEN_FILE")
|
|
sessionName := os.Getenv("AWS_ROLE_SESSION_NAME")
|
|
if roleARN != "" && tokenPath != "" {
|
|
// this session is only created to create the WebIdentityRoleProvider, as the env variables are already there
|
|
// this automatically assumes the role, but the provider needs to be added to the chain
|
|
c.log(hclog.Debug, "adding web identity provider", "roleARN", roleARN)
|
|
sess, err := session.NewSession()
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "error creating a new session to create a WebIdentityRoleProvider")
|
|
}
|
|
webIdentityProvider := stscreds.NewWebIdentityRoleProvider(sts.New(sess), roleARN, sessionName, tokenPath)
|
|
|
|
// Check if the webIdentityProvider can successfully retrieve
|
|
// credentials (via sts:AssumeRole), and warn if there's a problem.
|
|
if _, err := webIdentityProvider.Retrieve(); err != nil {
|
|
c.log(hclog.Warn, "error assuming role", "roleARN", roleARN, "tokenPath", tokenPath, "sessionName", sessionName, "err", err)
|
|
}
|
|
|
|
// Add the web identity role credential provider
|
|
providers = append(providers, webIdentityProvider)
|
|
}
|
|
|
|
// Add the environment credential provider
|
|
providers = append(providers, &credentials.EnvProvider{})
|
|
|
|
// Add the shared credentials provider
|
|
providers = append(providers, &credentials.SharedCredentialsProvider{
|
|
Filename: c.Filename,
|
|
Profile: c.Profile,
|
|
})
|
|
|
|
// Add the remote provider
|
|
def := defaults.Get()
|
|
if c.Region != "" {
|
|
def.Config.Region = aws.String(c.Region)
|
|
}
|
|
if c.HTTPClient != nil {
|
|
def.Config.HTTPClient = c.HTTPClient
|
|
_, checkFullURI := os.LookupEnv("AWS_CONTAINER_CREDENTIALS_FULL_URI")
|
|
_, checkRelativeURI := os.LookupEnv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI")
|
|
if !checkFullURI && !checkRelativeURI {
|
|
// match the sdk defaults from https://github.com/aws/aws-sdk-go/pull/3066
|
|
def.Config.HTTPClient.Timeout = 1 * time.Second
|
|
def.Config.MaxRetries = aws.Int(2)
|
|
}
|
|
}
|
|
|
|
providers = append(providers, defaults.RemoteCredProvider(*def.Config, def.Handlers))
|
|
|
|
// Create the credentials required to access the API.
|
|
creds := credentials.NewChainCredentials(providers)
|
|
if creds == nil {
|
|
return nil, fmt.Errorf("could not compile valid credential providers from static config, environment, shared, web identity or instance metadata")
|
|
}
|
|
|
|
return creds, nil
|
|
}
|
|
|
|
func RetrieveCreds(accessKey, secretKey, sessionToken string, logger hclog.Logger) (*credentials.Credentials, error) {
|
|
credConfig := CredentialsConfig{
|
|
AccessKey: accessKey,
|
|
SecretKey: secretKey,
|
|
SessionToken: sessionToken,
|
|
Logger: logger,
|
|
}
|
|
creds, err := credConfig.GenerateCredentialChain()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if creds == nil {
|
|
return nil, fmt.Errorf("could not compile valid credential providers from static config, environment, shared, or instance metadata")
|
|
}
|
|
|
|
_, err = creds.Get()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to retrieve credentials from credential chain: %w", err)
|
|
}
|
|
return creds, nil
|
|
}
|
|
|
|
// GenerateLoginData populates the necessary data to send to the Vault server for generating a token
|
|
// This is useful for other API clients to use
|
|
func GenerateLoginData(creds *credentials.Credentials, headerValue, configuredRegion string, logger hclog.Logger) (map[string]interface{}, error) {
|
|
loginData := make(map[string]interface{})
|
|
|
|
// Use the credentials we've found to construct an STS session
|
|
region, err := GetRegion(configuredRegion)
|
|
if err != nil {
|
|
logger.Warn(fmt.Sprintf("defaulting region to %q due to %s", DefaultRegion, err.Error()))
|
|
region = DefaultRegion
|
|
}
|
|
stsSession, err := session.NewSessionWithOptions(session.Options{
|
|
Config: aws.Config{
|
|
Credentials: creds,
|
|
Region: ®ion,
|
|
EndpointResolver: endpoints.ResolverFunc(stsSigningResolver),
|
|
},
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var params *sts.GetCallerIdentityInput
|
|
svc := sts.New(stsSession)
|
|
stsRequest, _ := svc.GetCallerIdentityRequest(params)
|
|
|
|
// Inject the required auth header value, if supplied, and then sign the request including that header
|
|
if headerValue != "" {
|
|
stsRequest.HTTPRequest.Header.Add(iamServerIdHeader, headerValue)
|
|
}
|
|
stsRequest.Sign()
|
|
|
|
// Now extract out the relevant parts of the request
|
|
headersJson, err := json.Marshal(stsRequest.HTTPRequest.Header)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
requestBody, err := ioutil.ReadAll(stsRequest.HTTPRequest.Body)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
loginData["iam_http_request_method"] = stsRequest.HTTPRequest.Method
|
|
loginData["iam_request_url"] = base64.StdEncoding.EncodeToString([]byte(stsRequest.HTTPRequest.URL.String()))
|
|
loginData["iam_request_headers"] = base64.StdEncoding.EncodeToString(headersJson)
|
|
loginData["iam_request_body"] = base64.StdEncoding.EncodeToString(requestBody)
|
|
|
|
return loginData, nil
|
|
}
|
|
|
|
// STS is a really weird service that used to only have global endpoints but now has regional endpoints as well.
|
|
// For backwards compatibility, even if you request a region other than us-east-1, it'll still sign for us-east-1.
|
|
// See, e.g., https://docs.aws.amazon.com/IAM/latest/UserGuide/id_credentials_temp_enable-regions.html#id_credentials_temp_enable-regions_writing_code
|
|
// So we have to shim in this EndpointResolver to force it to sign for the right region
|
|
func stsSigningResolver(service, region string, optFns ...func(*endpoints.Options)) (endpoints.ResolvedEndpoint, error) {
|
|
defaultEndpoint, err := endpoints.DefaultResolver().EndpointFor(service, region, optFns...)
|
|
if err != nil {
|
|
return defaultEndpoint, err
|
|
}
|
|
defaultEndpoint.SigningRegion = region
|
|
return defaultEndpoint, nil
|
|
}
|