Add an sts_region parameter to the AWS auth engine's client config (#7922)
This commit is contained in:
parent
875e0f490a
commit
535e88a629
|
@ -37,14 +37,19 @@ func (b *backend) getRawClientConfig(ctx context.Context, s logical.Storage, reg
|
|||
endpoint := aws.String("")
|
||||
var maxRetries int = aws.UseServiceDefaultRetries
|
||||
if config != nil {
|
||||
// Override the default endpoint with the configured endpoint.
|
||||
// Override the defaults with configured values.
|
||||
switch {
|
||||
case clientType == "ec2" && config.Endpoint != "":
|
||||
endpoint = aws.String(config.Endpoint)
|
||||
case clientType == "iam" && config.IAMEndpoint != "":
|
||||
endpoint = aws.String(config.IAMEndpoint)
|
||||
case clientType == "sts" && config.STSEndpoint != "":
|
||||
endpoint = aws.String(config.STSEndpoint)
|
||||
case clientType == "sts":
|
||||
if config.STSEndpoint != "" {
|
||||
endpoint = aws.String(config.STSEndpoint)
|
||||
}
|
||||
if config.STSRegion != "" {
|
||||
region = config.STSRegion
|
||||
}
|
||||
}
|
||||
|
||||
credsConfig.AccessKey = config.AccessKey
|
||||
|
|
|
@ -42,6 +42,12 @@ func (b *backend) pathConfigClient() *framework.Path {
|
|||
Description: "URL to override the default generated endpoint for making AWS STS API calls.",
|
||||
},
|
||||
|
||||
"sts_region": {
|
||||
Type: framework.TypeString,
|
||||
Default: "",
|
||||
Description: "The region ID for the sts_endpoint, if set.",
|
||||
},
|
||||
|
||||
"iam_server_id_header_value": {
|
||||
Type: framework.TypeString,
|
||||
Default: "",
|
||||
|
@ -127,6 +133,7 @@ func (b *backend) pathConfigClientRead(ctx context.Context, req *logical.Request
|
|||
"endpoint": clientConfig.Endpoint,
|
||||
"iam_endpoint": clientConfig.IAMEndpoint,
|
||||
"sts_endpoint": clientConfig.STSEndpoint,
|
||||
"sts_region": clientConfig.STSRegion,
|
||||
"iam_server_id_header_value": clientConfig.IAMServerIdHeaderValue,
|
||||
"max_retries": clientConfig.MaxRetries,
|
||||
},
|
||||
|
@ -217,7 +224,7 @@ func (b *backend) pathConfigClientCreateUpdate(ctx context.Context, req *logical
|
|||
stsEndpointStr, ok := data.GetOk("sts_endpoint")
|
||||
if ok {
|
||||
if configEntry.STSEndpoint != stsEndpointStr.(string) {
|
||||
// We don't directly cache STS clients as they are ever directly used.
|
||||
// We don't directly cache STS clients as they are never directly used.
|
||||
// However, they are potentially indirectly used as credential providers
|
||||
// for the EC2 and IAM clients, and thus we would be indirectly caching
|
||||
// them there. So, if we change the STS endpoint, we should flush those
|
||||
|
@ -229,6 +236,16 @@ func (b *backend) pathConfigClientCreateUpdate(ctx context.Context, req *logical
|
|||
configEntry.STSEndpoint = data.Get("sts_endpoint").(string)
|
||||
}
|
||||
|
||||
stsRegionStr, ok := data.GetOk("sts_region")
|
||||
if ok {
|
||||
if configEntry.STSRegion != stsRegionStr.(string) {
|
||||
// Region is used when building STS clients. As such, all the comments
|
||||
// regarding the sts_endpoint changing apply here as well.
|
||||
changedCreds = true
|
||||
configEntry.STSRegion = stsRegionStr.(string)
|
||||
}
|
||||
}
|
||||
|
||||
headerValStr, ok := data.GetOk("iam_server_id_header_value")
|
||||
if ok {
|
||||
if configEntry.IAMServerIdHeaderValue != headerValStr.(string) {
|
||||
|
@ -281,6 +298,7 @@ type clientConfig struct {
|
|||
Endpoint string `json:"endpoint"`
|
||||
IAMEndpoint string `json:"iam_endpoint"`
|
||||
STSEndpoint string `json:"sts_endpoint"`
|
||||
STSRegion string `json:"sts_region"`
|
||||
IAMServerIdHeaderValue string `json:"iam_server_id_header_value"`
|
||||
MaxRetries int `json:"max_retries"`
|
||||
}
|
||||
|
|
|
@ -44,6 +44,7 @@ func TestBackend_pathConfigClient(t *testing.T) {
|
|||
|
||||
data := map[string]interface{}{
|
||||
"sts_endpoint": "https://my-custom-sts-endpoint.example.com",
|
||||
"sts_region": "us-east-2",
|
||||
"iam_server_id_header_value": "vault_server_identification_314159",
|
||||
}
|
||||
resp, err = b.HandleRequest(context.Background(), &logical.Request{
|
||||
|
@ -52,7 +53,6 @@ func TestBackend_pathConfigClient(t *testing.T) {
|
|||
Data: data,
|
||||
Storage: storage,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -75,8 +75,18 @@ func TestBackend_pathConfigClient(t *testing.T) {
|
|||
t.Fatalf("expected iam_server_id_header_value: '%#v'; returned iam_server_id_header_value: '%#v'",
|
||||
data["iam_server_id_header_value"], resp.Data["iam_server_id_header_value"])
|
||||
}
|
||||
if resp.Data["sts_endpoint"] != data["sts_endpoint"] {
|
||||
t.Fatalf("expected sts_endpoint: '%#v'; returned sts_endpoint: '%#v'",
|
||||
data["sts_endpoint"], resp.Data["sts_endpoint"])
|
||||
}
|
||||
if resp.Data["sts_region"] != data["sts_region"] {
|
||||
t.Fatalf("expected sts_region: '%#v'; returned sts_region: '%#v'",
|
||||
data["sts_region"], resp.Data["sts_region"])
|
||||
}
|
||||
|
||||
data = map[string]interface{}{
|
||||
"sts_endpoint": "https://my-custom-sts-endpoint2.example.com",
|
||||
"sts_region": "us-west-1",
|
||||
"iam_server_id_header_value": "vault_server_identification_2718281",
|
||||
}
|
||||
resp, err = b.HandleRequest(context.Background(), &logical.Request{
|
||||
|
@ -108,4 +118,12 @@ func TestBackend_pathConfigClient(t *testing.T) {
|
|||
t.Fatalf("expected iam_server_id_header_value: '%#v'; returned iam_server_id_header_value: '%#v'",
|
||||
data["iam_server_id_header_value"], resp.Data["iam_server_id_header_value"])
|
||||
}
|
||||
if resp.Data["sts_endpoint"] != data["sts_endpoint"] {
|
||||
t.Fatalf("expected sts_endpoint: '%#v'; returned sts_endpoint: '%#v'",
|
||||
data["sts_endpoint"], resp.Data["sts_endpoint"])
|
||||
}
|
||||
if resp.Data["sts_region"] != data["sts_region"] {
|
||||
t.Fatalf("expected sts_region: '%#v'; returned sts_region: '%#v'",
|
||||
data["sts_region"], resp.Data["sts_region"])
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,11 +2,14 @@ package awsauth
|
|||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/go-test/deep"
|
||||
"github.com/hashicorp/vault/helper/awsutil"
|
||||
vlttesting "github.com/hashicorp/vault/helper/testhelpers/logical"
|
||||
"github.com/hashicorp/vault/sdk/helper/policyutil"
|
||||
"github.com/hashicorp/vault/sdk/helper/strutil"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
|
@ -986,6 +989,90 @@ func TestAwsVersion(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
// This test was used to reproduce https://github.com/hashicorp/vault/issues/7418
|
||||
// and verify its fix.
|
||||
// Please run it at least 3 times to ensure that passing tests are due to actually
|
||||
// passing, rather than the region being randomly chosen tying to the one in the
|
||||
// test through luck.
|
||||
func TestRoleResolutionWithSTSEndpointConfigured(t *testing.T) {
|
||||
if enabled := os.Getenv(vlttesting.TestEnvVar); enabled == "" {
|
||||
t.Skip()
|
||||
}
|
||||
|
||||
/* ARN of an AWS role that Vault can query during testing.
|
||||
This role should exist in your current AWS account and your credentials
|
||||
should have iam:GetRole permissions to query it.
|
||||
*/
|
||||
assumableRoleArn := os.Getenv("AWS_ASSUMABLE_ROLE_ARN")
|
||||
if assumableRoleArn == "" {
|
||||
t.Skip("skipping because AWS_ASSUMABLE_ROLE_ARN is unset")
|
||||
}
|
||||
|
||||
// Ensure aws credentials are available locally for testing.
|
||||
credsConfig := &awsutil.CredentialsConfig{}
|
||||
credsChain, err := credsConfig.GenerateCredentialChain()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, err = credsChain.Get()
|
||||
if err != nil {
|
||||
t.SkipNow()
|
||||
}
|
||||
|
||||
config := logical.TestBackendConfig()
|
||||
storage := &logical.InmemStorage{}
|
||||
config.StorageView = storage
|
||||
|
||||
b, err := Backend(config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = b.Setup(context.Background(), config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// configure the client with an sts endpoint that should be used in creating the role
|
||||
data := map[string]interface{}{
|
||||
"sts_endpoint": "https://sts.eu-west-1.amazonaws.com",
|
||||
// Note - if you comment this out, you can reproduce the error shown
|
||||
// in the linked GH issue above. This essentially reproduces the problem
|
||||
// we had when we didn't have an sts_region field.
|
||||
"sts_region": "eu-west-1",
|
||||
}
|
||||
resp, err := b.HandleRequest(context.Background(), &logical.Request{
|
||||
Operation: logical.CreateOperation,
|
||||
Path: "config/client",
|
||||
Data: data,
|
||||
Storage: storage,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp != nil && resp.IsError() {
|
||||
t.Fatalf("failed to create the role entry; resp: %#v", resp)
|
||||
}
|
||||
|
||||
data = map[string]interface{}{
|
||||
"auth_type": iamAuthType,
|
||||
"bound_iam_principal_arn": assumableRoleArn,
|
||||
"resolve_aws_unique_ids": true,
|
||||
}
|
||||
resp, err = b.HandleRequest(context.Background(), &logical.Request{
|
||||
Operation: logical.CreateOperation,
|
||||
Path: "role/MyRoleName",
|
||||
Data: data,
|
||||
Storage: storage,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp != nil && resp.IsError() {
|
||||
t.Fatalf("failed to create the role entry; resp: %#v", resp)
|
||||
}
|
||||
}
|
||||
|
||||
func resolveArnToFakeUniqueId(_ context.Context, _ logical.Storage, _ string) (string, error) {
|
||||
return "FakeUniqueId1", nil
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue