Add an sts_region parameter to the AWS auth engine's client config (#7922)
This commit is contained in:
parent
875e0f490a
commit
535e88a629
|
@ -37,15 +37,20 @@ func (b *backend) getRawClientConfig(ctx context.Context, s logical.Storage, reg
|
||||||
endpoint := aws.String("")
|
endpoint := aws.String("")
|
||||||
var maxRetries int = aws.UseServiceDefaultRetries
|
var maxRetries int = aws.UseServiceDefaultRetries
|
||||||
if config != nil {
|
if config != nil {
|
||||||
// Override the default endpoint with the configured endpoint.
|
// Override the defaults with configured values.
|
||||||
switch {
|
switch {
|
||||||
case clientType == "ec2" && config.Endpoint != "":
|
case clientType == "ec2" && config.Endpoint != "":
|
||||||
endpoint = aws.String(config.Endpoint)
|
endpoint = aws.String(config.Endpoint)
|
||||||
case clientType == "iam" && config.IAMEndpoint != "":
|
case clientType == "iam" && config.IAMEndpoint != "":
|
||||||
endpoint = aws.String(config.IAMEndpoint)
|
endpoint = aws.String(config.IAMEndpoint)
|
||||||
case clientType == "sts" && config.STSEndpoint != "":
|
case clientType == "sts":
|
||||||
|
if config.STSEndpoint != "" {
|
||||||
endpoint = aws.String(config.STSEndpoint)
|
endpoint = aws.String(config.STSEndpoint)
|
||||||
}
|
}
|
||||||
|
if config.STSRegion != "" {
|
||||||
|
region = config.STSRegion
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
credsConfig.AccessKey = config.AccessKey
|
credsConfig.AccessKey = config.AccessKey
|
||||||
credsConfig.SecretKey = config.SecretKey
|
credsConfig.SecretKey = config.SecretKey
|
||||||
|
|
|
@ -42,6 +42,12 @@ func (b *backend) pathConfigClient() *framework.Path {
|
||||||
Description: "URL to override the default generated endpoint for making AWS STS API calls.",
|
Description: "URL to override the default generated endpoint for making AWS STS API calls.",
|
||||||
},
|
},
|
||||||
|
|
||||||
|
"sts_region": {
|
||||||
|
Type: framework.TypeString,
|
||||||
|
Default: "",
|
||||||
|
Description: "The region ID for the sts_endpoint, if set.",
|
||||||
|
},
|
||||||
|
|
||||||
"iam_server_id_header_value": {
|
"iam_server_id_header_value": {
|
||||||
Type: framework.TypeString,
|
Type: framework.TypeString,
|
||||||
Default: "",
|
Default: "",
|
||||||
|
@ -127,6 +133,7 @@ func (b *backend) pathConfigClientRead(ctx context.Context, req *logical.Request
|
||||||
"endpoint": clientConfig.Endpoint,
|
"endpoint": clientConfig.Endpoint,
|
||||||
"iam_endpoint": clientConfig.IAMEndpoint,
|
"iam_endpoint": clientConfig.IAMEndpoint,
|
||||||
"sts_endpoint": clientConfig.STSEndpoint,
|
"sts_endpoint": clientConfig.STSEndpoint,
|
||||||
|
"sts_region": clientConfig.STSRegion,
|
||||||
"iam_server_id_header_value": clientConfig.IAMServerIdHeaderValue,
|
"iam_server_id_header_value": clientConfig.IAMServerIdHeaderValue,
|
||||||
"max_retries": clientConfig.MaxRetries,
|
"max_retries": clientConfig.MaxRetries,
|
||||||
},
|
},
|
||||||
|
@ -217,7 +224,7 @@ func (b *backend) pathConfigClientCreateUpdate(ctx context.Context, req *logical
|
||||||
stsEndpointStr, ok := data.GetOk("sts_endpoint")
|
stsEndpointStr, ok := data.GetOk("sts_endpoint")
|
||||||
if ok {
|
if ok {
|
||||||
if configEntry.STSEndpoint != stsEndpointStr.(string) {
|
if configEntry.STSEndpoint != stsEndpointStr.(string) {
|
||||||
// We don't directly cache STS clients as they are ever directly used.
|
// We don't directly cache STS clients as they are never directly used.
|
||||||
// However, they are potentially indirectly used as credential providers
|
// However, they are potentially indirectly used as credential providers
|
||||||
// for the EC2 and IAM clients, and thus we would be indirectly caching
|
// for the EC2 and IAM clients, and thus we would be indirectly caching
|
||||||
// them there. So, if we change the STS endpoint, we should flush those
|
// them there. So, if we change the STS endpoint, we should flush those
|
||||||
|
@ -229,6 +236,16 @@ func (b *backend) pathConfigClientCreateUpdate(ctx context.Context, req *logical
|
||||||
configEntry.STSEndpoint = data.Get("sts_endpoint").(string)
|
configEntry.STSEndpoint = data.Get("sts_endpoint").(string)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
stsRegionStr, ok := data.GetOk("sts_region")
|
||||||
|
if ok {
|
||||||
|
if configEntry.STSRegion != stsRegionStr.(string) {
|
||||||
|
// Region is used when building STS clients. As such, all the comments
|
||||||
|
// regarding the sts_endpoint changing apply here as well.
|
||||||
|
changedCreds = true
|
||||||
|
configEntry.STSRegion = stsRegionStr.(string)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
headerValStr, ok := data.GetOk("iam_server_id_header_value")
|
headerValStr, ok := data.GetOk("iam_server_id_header_value")
|
||||||
if ok {
|
if ok {
|
||||||
if configEntry.IAMServerIdHeaderValue != headerValStr.(string) {
|
if configEntry.IAMServerIdHeaderValue != headerValStr.(string) {
|
||||||
|
@ -281,6 +298,7 @@ type clientConfig struct {
|
||||||
Endpoint string `json:"endpoint"`
|
Endpoint string `json:"endpoint"`
|
||||||
IAMEndpoint string `json:"iam_endpoint"`
|
IAMEndpoint string `json:"iam_endpoint"`
|
||||||
STSEndpoint string `json:"sts_endpoint"`
|
STSEndpoint string `json:"sts_endpoint"`
|
||||||
|
STSRegion string `json:"sts_region"`
|
||||||
IAMServerIdHeaderValue string `json:"iam_server_id_header_value"`
|
IAMServerIdHeaderValue string `json:"iam_server_id_header_value"`
|
||||||
MaxRetries int `json:"max_retries"`
|
MaxRetries int `json:"max_retries"`
|
||||||
}
|
}
|
||||||
|
|
|
@ -44,6 +44,7 @@ func TestBackend_pathConfigClient(t *testing.T) {
|
||||||
|
|
||||||
data := map[string]interface{}{
|
data := map[string]interface{}{
|
||||||
"sts_endpoint": "https://my-custom-sts-endpoint.example.com",
|
"sts_endpoint": "https://my-custom-sts-endpoint.example.com",
|
||||||
|
"sts_region": "us-east-2",
|
||||||
"iam_server_id_header_value": "vault_server_identification_314159",
|
"iam_server_id_header_value": "vault_server_identification_314159",
|
||||||
}
|
}
|
||||||
resp, err = b.HandleRequest(context.Background(), &logical.Request{
|
resp, err = b.HandleRequest(context.Background(), &logical.Request{
|
||||||
|
@ -52,7 +53,6 @@ func TestBackend_pathConfigClient(t *testing.T) {
|
||||||
Data: data,
|
Data: data,
|
||||||
Storage: storage,
|
Storage: storage,
|
||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -75,8 +75,18 @@ func TestBackend_pathConfigClient(t *testing.T) {
|
||||||
t.Fatalf("expected iam_server_id_header_value: '%#v'; returned iam_server_id_header_value: '%#v'",
|
t.Fatalf("expected iam_server_id_header_value: '%#v'; returned iam_server_id_header_value: '%#v'",
|
||||||
data["iam_server_id_header_value"], resp.Data["iam_server_id_header_value"])
|
data["iam_server_id_header_value"], resp.Data["iam_server_id_header_value"])
|
||||||
}
|
}
|
||||||
|
if resp.Data["sts_endpoint"] != data["sts_endpoint"] {
|
||||||
|
t.Fatalf("expected sts_endpoint: '%#v'; returned sts_endpoint: '%#v'",
|
||||||
|
data["sts_endpoint"], resp.Data["sts_endpoint"])
|
||||||
|
}
|
||||||
|
if resp.Data["sts_region"] != data["sts_region"] {
|
||||||
|
t.Fatalf("expected sts_region: '%#v'; returned sts_region: '%#v'",
|
||||||
|
data["sts_region"], resp.Data["sts_region"])
|
||||||
|
}
|
||||||
|
|
||||||
data = map[string]interface{}{
|
data = map[string]interface{}{
|
||||||
|
"sts_endpoint": "https://my-custom-sts-endpoint2.example.com",
|
||||||
|
"sts_region": "us-west-1",
|
||||||
"iam_server_id_header_value": "vault_server_identification_2718281",
|
"iam_server_id_header_value": "vault_server_identification_2718281",
|
||||||
}
|
}
|
||||||
resp, err = b.HandleRequest(context.Background(), &logical.Request{
|
resp, err = b.HandleRequest(context.Background(), &logical.Request{
|
||||||
|
@ -108,4 +118,12 @@ func TestBackend_pathConfigClient(t *testing.T) {
|
||||||
t.Fatalf("expected iam_server_id_header_value: '%#v'; returned iam_server_id_header_value: '%#v'",
|
t.Fatalf("expected iam_server_id_header_value: '%#v'; returned iam_server_id_header_value: '%#v'",
|
||||||
data["iam_server_id_header_value"], resp.Data["iam_server_id_header_value"])
|
data["iam_server_id_header_value"], resp.Data["iam_server_id_header_value"])
|
||||||
}
|
}
|
||||||
|
if resp.Data["sts_endpoint"] != data["sts_endpoint"] {
|
||||||
|
t.Fatalf("expected sts_endpoint: '%#v'; returned sts_endpoint: '%#v'",
|
||||||
|
data["sts_endpoint"], resp.Data["sts_endpoint"])
|
||||||
|
}
|
||||||
|
if resp.Data["sts_region"] != data["sts_region"] {
|
||||||
|
t.Fatalf("expected sts_region: '%#v'; returned sts_region: '%#v'",
|
||||||
|
data["sts_region"], resp.Data["sts_region"])
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,11 +2,14 @@ package awsauth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"os"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/go-test/deep"
|
"github.com/go-test/deep"
|
||||||
|
"github.com/hashicorp/vault/helper/awsutil"
|
||||||
|
vlttesting "github.com/hashicorp/vault/helper/testhelpers/logical"
|
||||||
"github.com/hashicorp/vault/sdk/helper/policyutil"
|
"github.com/hashicorp/vault/sdk/helper/policyutil"
|
||||||
"github.com/hashicorp/vault/sdk/helper/strutil"
|
"github.com/hashicorp/vault/sdk/helper/strutil"
|
||||||
"github.com/hashicorp/vault/sdk/logical"
|
"github.com/hashicorp/vault/sdk/logical"
|
||||||
|
@ -986,6 +989,90 @@ func TestAwsVersion(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// This test was used to reproduce https://github.com/hashicorp/vault/issues/7418
|
||||||
|
// and verify its fix.
|
||||||
|
// Please run it at least 3 times to ensure that passing tests are due to actually
|
||||||
|
// passing, rather than the region being randomly chosen tying to the one in the
|
||||||
|
// test through luck.
|
||||||
|
func TestRoleResolutionWithSTSEndpointConfigured(t *testing.T) {
|
||||||
|
if enabled := os.Getenv(vlttesting.TestEnvVar); enabled == "" {
|
||||||
|
t.Skip()
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ARN of an AWS role that Vault can query during testing.
|
||||||
|
This role should exist in your current AWS account and your credentials
|
||||||
|
should have iam:GetRole permissions to query it.
|
||||||
|
*/
|
||||||
|
assumableRoleArn := os.Getenv("AWS_ASSUMABLE_ROLE_ARN")
|
||||||
|
if assumableRoleArn == "" {
|
||||||
|
t.Skip("skipping because AWS_ASSUMABLE_ROLE_ARN is unset")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure aws credentials are available locally for testing.
|
||||||
|
credsConfig := &awsutil.CredentialsConfig{}
|
||||||
|
credsChain, err := credsConfig.GenerateCredentialChain()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
_, err = credsChain.Get()
|
||||||
|
if err != nil {
|
||||||
|
t.SkipNow()
|
||||||
|
}
|
||||||
|
|
||||||
|
config := logical.TestBackendConfig()
|
||||||
|
storage := &logical.InmemStorage{}
|
||||||
|
config.StorageView = storage
|
||||||
|
|
||||||
|
b, err := Backend(config)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = b.Setup(context.Background(), config)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// configure the client with an sts endpoint that should be used in creating the role
|
||||||
|
data := map[string]interface{}{
|
||||||
|
"sts_endpoint": "https://sts.eu-west-1.amazonaws.com",
|
||||||
|
// Note - if you comment this out, you can reproduce the error shown
|
||||||
|
// in the linked GH issue above. This essentially reproduces the problem
|
||||||
|
// we had when we didn't have an sts_region field.
|
||||||
|
"sts_region": "eu-west-1",
|
||||||
|
}
|
||||||
|
resp, err := b.HandleRequest(context.Background(), &logical.Request{
|
||||||
|
Operation: logical.CreateOperation,
|
||||||
|
Path: "config/client",
|
||||||
|
Data: data,
|
||||||
|
Storage: storage,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if resp != nil && resp.IsError() {
|
||||||
|
t.Fatalf("failed to create the role entry; resp: %#v", resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
data = map[string]interface{}{
|
||||||
|
"auth_type": iamAuthType,
|
||||||
|
"bound_iam_principal_arn": assumableRoleArn,
|
||||||
|
"resolve_aws_unique_ids": true,
|
||||||
|
}
|
||||||
|
resp, err = b.HandleRequest(context.Background(), &logical.Request{
|
||||||
|
Operation: logical.CreateOperation,
|
||||||
|
Path: "role/MyRoleName",
|
||||||
|
Data: data,
|
||||||
|
Storage: storage,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if resp != nil && resp.IsError() {
|
||||||
|
t.Fatalf("failed to create the role entry; resp: %#v", resp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func resolveArnToFakeUniqueId(_ context.Context, _ logical.Storage, _ string) (string, error) {
|
func resolveArnToFakeUniqueId(_ context.Context, _ logical.Storage, _ string) (string, error) {
|
||||||
return "FakeUniqueId1", nil
|
return "FakeUniqueId1", nil
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue