Add an sts_region parameter to the AWS auth engine's client config (#7922)

This commit is contained in:
Becca Petrin 2019-12-10 16:02:04 -08:00 committed by GitHub
parent 875e0f490a
commit 535e88a629
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 133 additions and 5 deletions

View file

@ -37,15 +37,20 @@ func (b *backend) getRawClientConfig(ctx context.Context, s logical.Storage, reg
endpoint := aws.String("") endpoint := aws.String("")
var maxRetries int = aws.UseServiceDefaultRetries var maxRetries int = aws.UseServiceDefaultRetries
if config != nil { if config != nil {
// Override the default endpoint with the configured endpoint. // Override the defaults with configured values.
switch { switch {
case clientType == "ec2" && config.Endpoint != "": case clientType == "ec2" && config.Endpoint != "":
endpoint = aws.String(config.Endpoint) endpoint = aws.String(config.Endpoint)
case clientType == "iam" && config.IAMEndpoint != "": case clientType == "iam" && config.IAMEndpoint != "":
endpoint = aws.String(config.IAMEndpoint) endpoint = aws.String(config.IAMEndpoint)
case clientType == "sts" && config.STSEndpoint != "": case clientType == "sts":
if config.STSEndpoint != "" {
endpoint = aws.String(config.STSEndpoint) endpoint = aws.String(config.STSEndpoint)
} }
if config.STSRegion != "" {
region = config.STSRegion
}
}
credsConfig.AccessKey = config.AccessKey credsConfig.AccessKey = config.AccessKey
credsConfig.SecretKey = config.SecretKey credsConfig.SecretKey = config.SecretKey

View file

@ -42,6 +42,12 @@ func (b *backend) pathConfigClient() *framework.Path {
Description: "URL to override the default generated endpoint for making AWS STS API calls.", Description: "URL to override the default generated endpoint for making AWS STS API calls.",
}, },
"sts_region": {
Type: framework.TypeString,
Default: "",
Description: "The region ID for the sts_endpoint, if set.",
},
"iam_server_id_header_value": { "iam_server_id_header_value": {
Type: framework.TypeString, Type: framework.TypeString,
Default: "", Default: "",
@ -127,6 +133,7 @@ func (b *backend) pathConfigClientRead(ctx context.Context, req *logical.Request
"endpoint": clientConfig.Endpoint, "endpoint": clientConfig.Endpoint,
"iam_endpoint": clientConfig.IAMEndpoint, "iam_endpoint": clientConfig.IAMEndpoint,
"sts_endpoint": clientConfig.STSEndpoint, "sts_endpoint": clientConfig.STSEndpoint,
"sts_region": clientConfig.STSRegion,
"iam_server_id_header_value": clientConfig.IAMServerIdHeaderValue, "iam_server_id_header_value": clientConfig.IAMServerIdHeaderValue,
"max_retries": clientConfig.MaxRetries, "max_retries": clientConfig.MaxRetries,
}, },
@ -217,7 +224,7 @@ func (b *backend) pathConfigClientCreateUpdate(ctx context.Context, req *logical
stsEndpointStr, ok := data.GetOk("sts_endpoint") stsEndpointStr, ok := data.GetOk("sts_endpoint")
if ok { if ok {
if configEntry.STSEndpoint != stsEndpointStr.(string) { if configEntry.STSEndpoint != stsEndpointStr.(string) {
// We don't directly cache STS clients as they are ever directly used. // We don't directly cache STS clients as they are never directly used.
// However, they are potentially indirectly used as credential providers // However, they are potentially indirectly used as credential providers
// for the EC2 and IAM clients, and thus we would be indirectly caching // for the EC2 and IAM clients, and thus we would be indirectly caching
// them there. So, if we change the STS endpoint, we should flush those // them there. So, if we change the STS endpoint, we should flush those
@ -229,6 +236,16 @@ func (b *backend) pathConfigClientCreateUpdate(ctx context.Context, req *logical
configEntry.STSEndpoint = data.Get("sts_endpoint").(string) configEntry.STSEndpoint = data.Get("sts_endpoint").(string)
} }
stsRegionStr, ok := data.GetOk("sts_region")
if ok {
if configEntry.STSRegion != stsRegionStr.(string) {
// Region is used when building STS clients. As such, all the comments
// regarding the sts_endpoint changing apply here as well.
changedCreds = true
configEntry.STSRegion = stsRegionStr.(string)
}
}
headerValStr, ok := data.GetOk("iam_server_id_header_value") headerValStr, ok := data.GetOk("iam_server_id_header_value")
if ok { if ok {
if configEntry.IAMServerIdHeaderValue != headerValStr.(string) { if configEntry.IAMServerIdHeaderValue != headerValStr.(string) {
@ -281,6 +298,7 @@ type clientConfig struct {
Endpoint string `json:"endpoint"` Endpoint string `json:"endpoint"`
IAMEndpoint string `json:"iam_endpoint"` IAMEndpoint string `json:"iam_endpoint"`
STSEndpoint string `json:"sts_endpoint"` STSEndpoint string `json:"sts_endpoint"`
STSRegion string `json:"sts_region"`
IAMServerIdHeaderValue string `json:"iam_server_id_header_value"` IAMServerIdHeaderValue string `json:"iam_server_id_header_value"`
MaxRetries int `json:"max_retries"` MaxRetries int `json:"max_retries"`
} }

View file

@ -44,6 +44,7 @@ func TestBackend_pathConfigClient(t *testing.T) {
data := map[string]interface{}{ data := map[string]interface{}{
"sts_endpoint": "https://my-custom-sts-endpoint.example.com", "sts_endpoint": "https://my-custom-sts-endpoint.example.com",
"sts_region": "us-east-2",
"iam_server_id_header_value": "vault_server_identification_314159", "iam_server_id_header_value": "vault_server_identification_314159",
} }
resp, err = b.HandleRequest(context.Background(), &logical.Request{ resp, err = b.HandleRequest(context.Background(), &logical.Request{
@ -52,7 +53,6 @@ func TestBackend_pathConfigClient(t *testing.T) {
Data: data, Data: data,
Storage: storage, Storage: storage,
}) })
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -75,8 +75,18 @@ func TestBackend_pathConfigClient(t *testing.T) {
t.Fatalf("expected iam_server_id_header_value: '%#v'; returned iam_server_id_header_value: '%#v'", t.Fatalf("expected iam_server_id_header_value: '%#v'; returned iam_server_id_header_value: '%#v'",
data["iam_server_id_header_value"], resp.Data["iam_server_id_header_value"]) data["iam_server_id_header_value"], resp.Data["iam_server_id_header_value"])
} }
if resp.Data["sts_endpoint"] != data["sts_endpoint"] {
t.Fatalf("expected sts_endpoint: '%#v'; returned sts_endpoint: '%#v'",
data["sts_endpoint"], resp.Data["sts_endpoint"])
}
if resp.Data["sts_region"] != data["sts_region"] {
t.Fatalf("expected sts_region: '%#v'; returned sts_region: '%#v'",
data["sts_region"], resp.Data["sts_region"])
}
data = map[string]interface{}{ data = map[string]interface{}{
"sts_endpoint": "https://my-custom-sts-endpoint2.example.com",
"sts_region": "us-west-1",
"iam_server_id_header_value": "vault_server_identification_2718281", "iam_server_id_header_value": "vault_server_identification_2718281",
} }
resp, err = b.HandleRequest(context.Background(), &logical.Request{ resp, err = b.HandleRequest(context.Background(), &logical.Request{
@ -108,4 +118,12 @@ func TestBackend_pathConfigClient(t *testing.T) {
t.Fatalf("expected iam_server_id_header_value: '%#v'; returned iam_server_id_header_value: '%#v'", t.Fatalf("expected iam_server_id_header_value: '%#v'; returned iam_server_id_header_value: '%#v'",
data["iam_server_id_header_value"], resp.Data["iam_server_id_header_value"]) data["iam_server_id_header_value"], resp.Data["iam_server_id_header_value"])
} }
if resp.Data["sts_endpoint"] != data["sts_endpoint"] {
t.Fatalf("expected sts_endpoint: '%#v'; returned sts_endpoint: '%#v'",
data["sts_endpoint"], resp.Data["sts_endpoint"])
}
if resp.Data["sts_region"] != data["sts_region"] {
t.Fatalf("expected sts_region: '%#v'; returned sts_region: '%#v'",
data["sts_region"], resp.Data["sts_region"])
}
} }

View file

@ -2,11 +2,14 @@ package awsauth
import ( import (
"context" "context"
"os"
"reflect" "reflect"
"strings" "strings"
"testing" "testing"
"github.com/go-test/deep" "github.com/go-test/deep"
"github.com/hashicorp/vault/helper/awsutil"
vlttesting "github.com/hashicorp/vault/helper/testhelpers/logical"
"github.com/hashicorp/vault/sdk/helper/policyutil" "github.com/hashicorp/vault/sdk/helper/policyutil"
"github.com/hashicorp/vault/sdk/helper/strutil" "github.com/hashicorp/vault/sdk/helper/strutil"
"github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/sdk/logical"
@ -986,6 +989,90 @@ func TestAwsVersion(t *testing.T) {
} }
} }
// This test was used to reproduce https://github.com/hashicorp/vault/issues/7418
// and verify its fix.
// Please run it at least 3 times to ensure that passing tests are due to actually
// passing, rather than the region being randomly chosen tying to the one in the
// test through luck.
func TestRoleResolutionWithSTSEndpointConfigured(t *testing.T) {
if enabled := os.Getenv(vlttesting.TestEnvVar); enabled == "" {
t.Skip()
}
/* ARN of an AWS role that Vault can query during testing.
This role should exist in your current AWS account and your credentials
should have iam:GetRole permissions to query it.
*/
assumableRoleArn := os.Getenv("AWS_ASSUMABLE_ROLE_ARN")
if assumableRoleArn == "" {
t.Skip("skipping because AWS_ASSUMABLE_ROLE_ARN is unset")
}
// Ensure aws credentials are available locally for testing.
credsConfig := &awsutil.CredentialsConfig{}
credsChain, err := credsConfig.GenerateCredentialChain()
if err != nil {
t.Fatal(err)
}
_, err = credsChain.Get()
if err != nil {
t.SkipNow()
}
config := logical.TestBackendConfig()
storage := &logical.InmemStorage{}
config.StorageView = storage
b, err := Backend(config)
if err != nil {
t.Fatal(err)
}
err = b.Setup(context.Background(), config)
if err != nil {
t.Fatal(err)
}
// configure the client with an sts endpoint that should be used in creating the role
data := map[string]interface{}{
"sts_endpoint": "https://sts.eu-west-1.amazonaws.com",
// Note - if you comment this out, you can reproduce the error shown
// in the linked GH issue above. This essentially reproduces the problem
// we had when we didn't have an sts_region field.
"sts_region": "eu-west-1",
}
resp, err := b.HandleRequest(context.Background(), &logical.Request{
Operation: logical.CreateOperation,
Path: "config/client",
Data: data,
Storage: storage,
})
if err != nil {
t.Fatal(err)
}
if resp != nil && resp.IsError() {
t.Fatalf("failed to create the role entry; resp: %#v", resp)
}
data = map[string]interface{}{
"auth_type": iamAuthType,
"bound_iam_principal_arn": assumableRoleArn,
"resolve_aws_unique_ids": true,
}
resp, err = b.HandleRequest(context.Background(), &logical.Request{
Operation: logical.CreateOperation,
Path: "role/MyRoleName",
Data: data,
Storage: storage,
})
if err != nil {
t.Fatal(err)
}
if resp != nil && resp.IsError() {
t.Fatalf("failed to create the role entry; resp: %#v", resp)
}
}
func resolveArnToFakeUniqueId(_ context.Context, _ logical.Storage, _ string) (string, error) { func resolveArnToFakeUniqueId(_ context.Context, _ logical.Storage, _ string) (string, error) {
return "FakeUniqueId1", nil return "FakeUniqueId1", nil
} }