diff --git a/Godeps/_workspace/src/github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds/ec2_role_provider.go b/Godeps/_workspace/src/github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds/ec2_role_provider.go new file mode 100644 index 000000000..946a11720 --- /dev/null +++ b/Godeps/_workspace/src/github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds/ec2_role_provider.go @@ -0,0 +1,168 @@ +package ec2rolecreds + +import ( + "bufio" + "encoding/json" + "fmt" + "path" + "strings" + "time" + + "github.com/aws/aws-sdk-go/aws/awserr" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/ec2metadata" +) + +// A EC2RoleProvider retrieves credentials from the EC2 service, and keeps track if +// those credentials are expired. +// +// Example how to configure the EC2RoleProvider with custom http Client, Endpoint +// or ExpiryWindow +// +// p := &ec2rolecreds.EC2RoleProvider{ +// // Pass in a custom timeout to be used when requesting +// // IAM EC2 Role credentials. +// Client: &http.Client{ +// Timeout: 10 * time.Second, +// }, +// // Use default EC2 Role metadata endpoint, Alternate endpoints can be +// // specified setting Endpoint to something else. +// Endpoint: "", +// // Do not use early expiry of credentials. If a non zero value is +// // specified the credentials will be expired early +// ExpiryWindow: 0, +// } +type EC2RoleProvider struct { + credentials.Expiry + + // EC2Metadata client to use when connecting to EC2 metadata service + Client *ec2metadata.Client + + // ExpiryWindow will allow the credentials to trigger refreshing prior to + // the credentials actually expiring. This is beneficial so race conditions + // with expiring credentials do not cause request to fail unexpectedly + // due to ExpiredTokenException exceptions. + // + // So a ExpiryWindow of 10s would cause calls to IsExpired() to return true + // 10 seconds before the credentials are actually expired. + // + // If ExpiryWindow is 0 or less it will be ignored. + ExpiryWindow time.Duration +} + +// NewCredentials returns a pointer to a new Credentials object +// wrapping the EC2RoleProvider. +// +// Takes a custom http.Client which can be configured for custom handling of +// things such as timeout. +// +// Endpoint is the URL that the EC2RoleProvider will connect to when retrieving +// role and credentials. +// +// Window is the expiry window that will be subtracted from the expiry returned +// by the role credential request. This is done so that the credentials will +// expire sooner than their actual lifespan. +func NewCredentials(client *ec2metadata.Client, window time.Duration) *credentials.Credentials { + return credentials.NewCredentials(&EC2RoleProvider{ + Client: client, + ExpiryWindow: window, + }) +} + +// Retrieve retrieves credentials from the EC2 service. +// Error will be returned if the request fails, or unable to extract +// the desired credentials. +func (m *EC2RoleProvider) Retrieve() (credentials.Value, error) { + if m.Client == nil { + m.Client = ec2metadata.New(nil) + } + + credsList, err := requestCredList(m.Client) + if err != nil { + return credentials.Value{}, err + } + + if len(credsList) == 0 { + return credentials.Value{}, awserr.New("EmptyEC2RoleList", "empty EC2 Role list", nil) + } + credsName := credsList[0] + + roleCreds, err := requestCred(m.Client, credsName) + if err != nil { + return credentials.Value{}, err + } + + m.SetExpiration(roleCreds.Expiration, m.ExpiryWindow) + + return credentials.Value{ + AccessKeyID: roleCreds.AccessKeyID, + SecretAccessKey: roleCreds.SecretAccessKey, + SessionToken: roleCreds.Token, + }, nil +} + +// A ec2RoleCredRespBody provides the shape for deserializing credential +// request responses. +type ec2RoleCredRespBody struct { + // Success State + Expiration time.Time + AccessKeyID string + SecretAccessKey string + Token string + + // Error state + Code string + Message string +} + +const iamSecurityCredsPath = "/iam/security-credentials" + +// requestCredList requests a list of credentials from the EC2 service. +// If there are no credentials, or there is an error making or receiving the request +func requestCredList(client *ec2metadata.Client) ([]string, error) { + resp, err := client.GetMetadata(iamSecurityCredsPath) + if err != nil { + return nil, awserr.New("EC2RoleRequestError", "failed to list EC2 Roles", err) + } + + credsList := []string{} + s := bufio.NewScanner(strings.NewReader(resp)) + for s.Scan() { + credsList = append(credsList, s.Text()) + } + + if err := s.Err(); err != nil { + return nil, awserr.New("SerializationError", "failed to read list of EC2 Roles", err) + } + + return credsList, nil +} + +// requestCred requests the credentials for a specific credentials from the EC2 service. +// +// If the credentials cannot be found, or there is an error reading the response +// and error will be returned. +func requestCred(client *ec2metadata.Client, credsName string) (ec2RoleCredRespBody, error) { + resp, err := client.GetMetadata(path.Join(iamSecurityCredsPath, credsName)) + if err != nil { + return ec2RoleCredRespBody{}, + awserr.New("EC2RoleRequestError", + fmt.Sprintf("failed to get %s EC2 Role credentials", credsName), + err) + } + + respCreds := ec2RoleCredRespBody{} + if err := json.NewDecoder(strings.NewReader(resp)).Decode(&respCreds); err != nil { + return ec2RoleCredRespBody{}, + awserr.New("SerializationError", + fmt.Sprintf("failed to decode %s EC2 Role credentials", credsName), + err) + } + + if respCreds.Code != "Success" { + // If an error code was returned something failed requesting the role. + return ec2RoleCredRespBody{}, awserr.New(respCreds.Code, respCreds.Message, nil) + } + + return respCreds, nil +} diff --git a/Godeps/_workspace/src/github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds/ec2_role_provider_test.go b/Godeps/_workspace/src/github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds/ec2_role_provider_test.go new file mode 100644 index 000000000..cd0cbc97e --- /dev/null +++ b/Godeps/_workspace/src/github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds/ec2_role_provider_test.go @@ -0,0 +1,161 @@ +package ec2rolecreds_test + +import ( + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/awserr" + "github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds" + "github.com/aws/aws-sdk-go/aws/ec2metadata" +) + +const credsRespTmpl = `{ + "Code": "Success", + "Type": "AWS-HMAC", + "AccessKeyId" : "accessKey", + "SecretAccessKey" : "secret", + "Token" : "token", + "Expiration" : "%s", + "LastUpdated" : "2009-11-23T0:00:00Z" +}` + +const credsFailRespTmpl = `{ + "Code": "ErrorCode", + "Message": "ErrorMsg", + "LastUpdated": "2009-11-23T0:00:00Z" +}` + +func initTestServer(expireOn string, failAssume bool) *httptest.Server { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/latest/meta-data/iam/security-credentials" { + fmt.Fprintln(w, "RoleName") + } else if r.URL.Path == "/latest/meta-data/iam/security-credentials/RoleName" { + if failAssume { + fmt.Fprintf(w, credsFailRespTmpl) + } else { + fmt.Fprintf(w, credsRespTmpl, expireOn) + } + } else { + http.Error(w, "bad request", http.StatusBadRequest) + } + })) + + return server +} + +func TestEC2RoleProvider(t *testing.T) { + server := initTestServer("2014-12-16T01:51:37Z", false) + defer server.Close() + + p := &ec2rolecreds.EC2RoleProvider{ + Client: ec2metadata.New(&ec2metadata.Config{Endpoint: aws.String(server.URL + "/latest")}), + } + + creds, err := p.Retrieve() + assert.Nil(t, err, "Expect no error") + + assert.Equal(t, "accessKey", creds.AccessKeyID, "Expect access key ID to match") + assert.Equal(t, "secret", creds.SecretAccessKey, "Expect secret access key to match") + assert.Equal(t, "token", creds.SessionToken, "Expect session token to match") +} + +func TestEC2RoleProviderFailAssume(t *testing.T) { + server := initTestServer("2014-12-16T01:51:37Z", true) + defer server.Close() + + p := &ec2rolecreds.EC2RoleProvider{ + Client: ec2metadata.New(&ec2metadata.Config{Endpoint: aws.String(server.URL + "/latest")}), + } + + creds, err := p.Retrieve() + assert.Error(t, err, "Expect error") + + e := err.(awserr.Error) + assert.Equal(t, "ErrorCode", e.Code()) + assert.Equal(t, "ErrorMsg", e.Message()) + assert.Nil(t, e.OrigErr()) + + assert.Equal(t, "", creds.AccessKeyID, "Expect access key ID to match") + assert.Equal(t, "", creds.SecretAccessKey, "Expect secret access key to match") + assert.Equal(t, "", creds.SessionToken, "Expect session token to match") +} + +func TestEC2RoleProviderIsExpired(t *testing.T) { + server := initTestServer("2014-12-16T01:51:37Z", false) + defer server.Close() + + p := &ec2rolecreds.EC2RoleProvider{ + Client: ec2metadata.New(&ec2metadata.Config{Endpoint: aws.String(server.URL + "/latest")}), + } + p.CurrentTime = func() time.Time { + return time.Date(2014, 12, 15, 21, 26, 0, 0, time.UTC) + } + + assert.True(t, p.IsExpired(), "Expect creds to be expired before retrieve.") + + _, err := p.Retrieve() + assert.Nil(t, err, "Expect no error") + + assert.False(t, p.IsExpired(), "Expect creds to not be expired after retrieve.") + + p.CurrentTime = func() time.Time { + return time.Date(3014, 12, 15, 21, 26, 0, 0, time.UTC) + } + + assert.True(t, p.IsExpired(), "Expect creds to be expired.") +} + +func TestEC2RoleProviderExpiryWindowIsExpired(t *testing.T) { + server := initTestServer("2014-12-16T01:51:37Z", false) + defer server.Close() + + p := &ec2rolecreds.EC2RoleProvider{ + Client: ec2metadata.New(&ec2metadata.Config{Endpoint: aws.String(server.URL + "/latest")}), + ExpiryWindow: time.Hour * 1, + } + p.CurrentTime = func() time.Time { + return time.Date(2014, 12, 15, 0, 51, 37, 0, time.UTC) + } + + assert.True(t, p.IsExpired(), "Expect creds to be expired before retrieve.") + + _, err := p.Retrieve() + assert.Nil(t, err, "Expect no error") + + assert.False(t, p.IsExpired(), "Expect creds to not be expired after retrieve.") + + p.CurrentTime = func() time.Time { + return time.Date(2014, 12, 16, 0, 55, 37, 0, time.UTC) + } + + assert.True(t, p.IsExpired(), "Expect creds to be expired.") +} + +func BenchmarkEC2RoleProvider(b *testing.B) { + server := initTestServer("2014-12-16T01:51:37Z", false) + defer server.Close() + + p := &ec2rolecreds.EC2RoleProvider{ + Client: ec2metadata.New(&ec2metadata.Config{Endpoint: aws.String(server.URL + "/latest")}), + } + _, err := p.Retrieve() + if err != nil { + b.Fatal(err) + } + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + _, err := p.Retrieve() + if err != nil { + b.Fatal(err) + } + } + }) +} diff --git a/Godeps/_workspace/src/github.com/aws/aws-sdk-go/aws/ec2metadata/api.go b/Godeps/_workspace/src/github.com/aws/aws-sdk-go/aws/ec2metadata/api.go new file mode 100644 index 000000000..896abe218 --- /dev/null +++ b/Godeps/_workspace/src/github.com/aws/aws-sdk-go/aws/ec2metadata/api.go @@ -0,0 +1,43 @@ +package ec2metadata + +import ( + "path" + + "github.com/aws/aws-sdk-go/aws/service" +) + +// GetMetadata uses the path provided to request +func (c *Client) GetMetadata(p string) (string, error) { + op := &service.Operation{ + Name: "GetMetadata", + HTTPMethod: "GET", + HTTPPath: path.Join("/", "meta-data", p), + } + + output := &metadataOutput{} + req := service.NewRequest(c.Service, op, nil, output) + + return output.Content, req.Send() +} + +// Region returns the region the instance is running in. +func (c *Client) Region() (string, error) { + resp, err := c.GetMetadata("placement/availability-zone") + if err != nil { + return "", err + } + + // returns region without the suffix. Eg: us-west-2a becomes us-west-2 + return resp[:len(resp)-1], nil +} + +// Available returns if the application has access to the EC2 Metadata service. +// Can be used to determine if application is running within an EC2 Instance and +// the metadata service is available. +func (c *Client) Available() bool { + if _, err := c.GetMetadata("instance-id"); err != nil { + return false + } + + return true +} diff --git a/Godeps/_workspace/src/github.com/aws/aws-sdk-go/aws/ec2metadata/api_test.go b/Godeps/_workspace/src/github.com/aws/aws-sdk-go/aws/ec2metadata/api_test.go new file mode 100644 index 000000000..a371cf065 --- /dev/null +++ b/Godeps/_workspace/src/github.com/aws/aws-sdk-go/aws/ec2metadata/api_test.go @@ -0,0 +1,100 @@ +package ec2metadata_test + +import ( + "bytes" + "io/ioutil" + "net/http" + "net/http/httptest" + "path" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/awserr" + "github.com/aws/aws-sdk-go/aws/ec2metadata" + "github.com/aws/aws-sdk-go/aws/service" +) + +func initTestServer(path string, resp string) *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.RequestURI != path { + http.Error(w, "not found", http.StatusNotFound) + return + } + + w.Write([]byte(resp)) + })) +} + +func TestEndpoint(t *testing.T) { + c := ec2metadata.New(&ec2metadata.Config{}) + op := &service.Operation{ + Name: "GetMetadata", + HTTPMethod: "GET", + HTTPPath: path.Join("/", "meta-data", "testpath"), + } + + req := service.NewRequest(c.Service, op, nil, nil) + assert.Equal(t, "http://169.254.169.254/latest", req.Endpoint) + assert.Equal(t, "http://169.254.169.254/latest/meta-data/testpath", req.HTTPRequest.URL.String()) +} + +func TestGetMetadata(t *testing.T) { + server := initTestServer( + "/latest/meta-data/some/path", + "success", // real response includes suffix + ) + defer server.Close() + c := ec2metadata.New(&ec2metadata.Config{Endpoint: aws.String(server.URL + "/latest")}) + + resp, err := c.GetMetadata("some/path") + + assert.NoError(t, err) + assert.Equal(t, "success", resp) +} + +func TestGetRegion(t *testing.T) { + server := initTestServer( + "/latest/meta-data/placement/availability-zone", + "us-west-2a", // real response includes suffix + ) + defer server.Close() + c := ec2metadata.New(&ec2metadata.Config{Endpoint: aws.String(server.URL + "/latest")}) + + region, err := c.Region() + + assert.NoError(t, err) + assert.Equal(t, "us-west-2", region) +} + +func TestMetadataAvailable(t *testing.T) { + server := initTestServer( + "/latest/meta-data/instance-id", + "instance-id", + ) + defer server.Close() + c := ec2metadata.New(&ec2metadata.Config{Endpoint: aws.String(server.URL + "/latest")}) + + available := c.Available() + + assert.True(t, available) +} + +func TestMetadataNotAvailable(t *testing.T) { + c := ec2metadata.New(nil) + c.Handlers.Send.Clear() + c.Handlers.Send.PushBack(func(r *service.Request) { + r.HTTPResponse = &http.Response{ + StatusCode: int(0), + Status: http.StatusText(int(0)), + Body: ioutil.NopCloser(bytes.NewReader([]byte{})), + } + r.Error = awserr.New("RequestError", "send request failed", nil) + r.Retryable = aws.Bool(true) // network errors are retryable + }) + + available := c.Available() + + assert.False(t, available) +} diff --git a/Godeps/_workspace/src/github.com/aws/aws-sdk-go/aws/ec2metadata/service.go b/Godeps/_workspace/src/github.com/aws/aws-sdk-go/aws/ec2metadata/service.go new file mode 100644 index 000000000..cbf067dd4 --- /dev/null +++ b/Godeps/_workspace/src/github.com/aws/aws-sdk-go/aws/ec2metadata/service.go @@ -0,0 +1,131 @@ +package ec2metadata + +import ( + "io/ioutil" + "net/http" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/awserr" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/service" +) + +// DefaultRetries states the default number of times the service client will +// attempt to retry a failed request before failing. +const DefaultRetries = 3 + +// A Config provides the configuration for the EC2 Metadata service. +type Config struct { + // An optional endpoint URL (hostname only or fully qualified URI) + // that overrides the default service endpoint for a client. Set this + // to nil, or `""` to use the default service endpoint. + Endpoint *string + + // The HTTP client to use when sending requests. Defaults to + // `http.DefaultClient`. + HTTPClient *http.Client + + // An integer value representing the logging level. The default log level + // is zero (LogOff), which represents no logging. To enable logging set + // to a LogLevel Value. + Logger aws.Logger + + // The logger writer interface to write logging messages to. Defaults to + // standard out. + LogLevel *aws.LogLevelType + + // The maximum number of times that a request will be retried for failures. + // Defaults to DefaultRetries for the number of retries to be performed + // per request. + MaxRetries *int +} + +// A Client is an EC2 Metadata service Client. +type Client struct { + *service.Service +} + +// New creates a new instance of the EC2 Metadata service client. +// +// In the general use case the configuration for this service client should not +// be needed and `nil` can be provided. Configuration is only needed if the +// `ec2metadata.Config` defaults need to be overridden. Eg. Setting LogLevel. +// +// @note This configuration will NOT be merged with the default AWS service +// client configuration `defaults.DefaultConfig`. Due to circular dependencies +// with the defaults package and credentials EC2 Role Provider. +func New(config *Config) *Client { + service := &service.Service{ + Config: copyConfig(config), + ServiceName: "Client", + Endpoint: "http://169.254.169.254/latest", + APIVersion: "latest", + } + service.Initialize() + service.Handlers.Unmarshal.PushBack(unmarshalHandler) + service.Handlers.UnmarshalError.PushBack(unmarshalError) + service.Handlers.Validate.Clear() + service.Handlers.Validate.PushBack(validateEndpointHandler) + + return &Client{service} +} + +func copyConfig(config *Config) *aws.Config { + if config == nil { + config = &Config{} + } + c := &aws.Config{ + Credentials: credentials.AnonymousCredentials, + Endpoint: config.Endpoint, + HTTPClient: config.HTTPClient, + Logger: config.Logger, + LogLevel: config.LogLevel, + MaxRetries: config.MaxRetries, + } + + if c.HTTPClient == nil { + c.HTTPClient = http.DefaultClient + } + if c.Logger == nil { + c.Logger = aws.NewDefaultLogger() + } + if c.LogLevel == nil { + c.LogLevel = aws.LogLevel(aws.LogOff) + } + if c.MaxRetries == nil { + c.MaxRetries = aws.Int(DefaultRetries) + } + + return c +} + +type metadataOutput struct { + Content string +} + +func unmarshalHandler(r *service.Request) { + defer r.HTTPResponse.Body.Close() + b, err := ioutil.ReadAll(r.HTTPResponse.Body) + if err != nil { + r.Error = awserr.New("SerializationError", "unable to unmarshal EC2 metadata respose", err) + } + + data := r.Data.(*metadataOutput) + data.Content = string(b) +} + +func unmarshalError(r *service.Request) { + defer r.HTTPResponse.Body.Close() + _, err := ioutil.ReadAll(r.HTTPResponse.Body) + if err != nil { + r.Error = awserr.New("SerializationError", "unable to unmarshal EC2 metadata error respose", err) + } + + // TODO extract the error... +} + +func validateEndpointHandler(r *service.Request) { + if r.Service.Endpoint == "" { + r.Error = service.ErrMissingEndpoint + } +} diff --git a/Godeps/_workspace/src/github.com/aws/aws-sdk-go/aws/logger.go b/Godeps/_workspace/src/github.com/aws/aws-sdk-go/aws/logger.go index 935661c0b..f53694873 100644 --- a/Godeps/_workspace/src/github.com/aws/aws-sdk-go/aws/logger.go +++ b/Godeps/_workspace/src/github.com/aws/aws-sdk-go/aws/logger.go @@ -62,6 +62,15 @@ const ( // see the body content of requests and responses made while using the SDK // Will also enable LogDebug. LogDebugWithHTTPBody + + // LogDebugWithRequestRetries states the SDK should log when service requests will + // be retried. This should be used to log when you want to log when service + // requests are being retried. Will also enable LogDebug. + LogDebugWithRequestRetries + + // LogDebugWithRequestErrors states the SDK should log when service requests fail + // to build, send, validate, or unmarshal. + LogDebugWithRequestErrors ) // A Logger is a minimalistic interface for the SDK to log messages to. Should diff --git a/Godeps/_workspace/src/github.com/aws/aws-sdk-go/aws/service/handler_functions.go b/Godeps/_workspace/src/github.com/aws/aws-sdk-go/aws/service/handler_functions.go new file mode 100644 index 000000000..8afce7523 --- /dev/null +++ b/Godeps/_workspace/src/github.com/aws/aws-sdk-go/aws/service/handler_functions.go @@ -0,0 +1,158 @@ +package service + +import ( + "bytes" + "fmt" + "io" + "io/ioutil" + "net/http" + "net/url" + "regexp" + "strconv" + "time" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/awserr" +) + +var sleepDelay = func(delay time.Duration) { + time.Sleep(delay) +} + +// Interface for matching types which also have a Len method. +type lener interface { + Len() int +} + +// BuildContentLength builds the content length of a request based on the body, +// or will use the HTTPRequest.Header's "Content-Length" if defined. If unable +// to determine request body length and no "Content-Length" was specified it will panic. +func BuildContentLength(r *Request) { + if slength := r.HTTPRequest.Header.Get("Content-Length"); slength != "" { + length, _ := strconv.ParseInt(slength, 10, 64) + r.HTTPRequest.ContentLength = length + return + } + + var length int64 + switch body := r.Body.(type) { + case nil: + length = 0 + case lener: + length = int64(body.Len()) + case io.Seeker: + r.bodyStart, _ = body.Seek(0, 1) + end, _ := body.Seek(0, 2) + body.Seek(r.bodyStart, 0) // make sure to seek back to original location + length = end - r.bodyStart + default: + panic("Cannot get length of body, must provide `ContentLength`") + } + + r.HTTPRequest.ContentLength = length + r.HTTPRequest.Header.Set("Content-Length", fmt.Sprintf("%d", length)) +} + +// UserAgentHandler is a request handler for injecting User agent into requests. +func UserAgentHandler(r *Request) { + r.HTTPRequest.Header.Set("User-Agent", aws.SDKName+"/"+aws.SDKVersion) +} + +var reStatusCode = regexp.MustCompile(`^(\d+)`) + +// SendHandler is a request handler to send service request using HTTP client. +func SendHandler(r *Request) { + var err error + r.HTTPResponse, err = r.Service.Config.HTTPClient.Do(r.HTTPRequest) + if err != nil { + // Capture the case where url.Error is returned for error processing + // response. e.g. 301 without location header comes back as string + // error and r.HTTPResponse is nil. Other url redirect errors will + // comeback in a similar method. + if e, ok := err.(*url.Error); ok { + if s := reStatusCode.FindStringSubmatch(e.Error()); s != nil { + code, _ := strconv.ParseInt(s[1], 10, 64) + r.HTTPResponse = &http.Response{ + StatusCode: int(code), + Status: http.StatusText(int(code)), + Body: ioutil.NopCloser(bytes.NewReader([]byte{})), + } + return + } + } + if r.HTTPResponse == nil { + // Add a dummy request response object to ensure the HTTPResponse + // value is consistent. + r.HTTPResponse = &http.Response{ + StatusCode: int(0), + Status: http.StatusText(int(0)), + Body: ioutil.NopCloser(bytes.NewReader([]byte{})), + } + } + // Catch all other request errors. + r.Error = awserr.New("RequestError", "send request failed", err) + r.Retryable = aws.Bool(true) // network errors are retryable + } +} + +// ValidateResponseHandler is a request handler to validate service response. +func ValidateResponseHandler(r *Request) { + if r.HTTPResponse.StatusCode == 0 || r.HTTPResponse.StatusCode >= 300 { + // this may be replaced by an UnmarshalError handler + r.Error = awserr.New("UnknownError", "unknown error", nil) + } +} + +// AfterRetryHandler performs final checks to determine if the request should +// be retried and how long to delay. +func AfterRetryHandler(r *Request) { + // If one of the other handlers already set the retry state + // we don't want to override it based on the service's state + if r.Retryable == nil { + r.Retryable = aws.Bool(r.Service.ShouldRetry(r)) + } + + if r.WillRetry() { + r.RetryDelay = r.Service.RetryRules(r) + sleepDelay(r.RetryDelay) + + // when the expired token exception occurs the credentials + // need to be expired locally so that the next request to + // get credentials will trigger a credentials refresh. + if r.Error != nil { + if err, ok := r.Error.(awserr.Error); ok { + if isCodeExpiredCreds(err.Code()) { + r.Config.Credentials.Expire() + } + } + } + + r.RetryCount++ + r.Error = nil + } +} + +var ( + // ErrMissingRegion is an error that is returned if region configuration is + // not found. + // + // @readonly + ErrMissingRegion error = awserr.New("MissingRegion", "could not find region configuration", nil) + + // ErrMissingEndpoint is an error that is returned if an endpoint cannot be + // resolved for a service. + // + // @readonly + ErrMissingEndpoint error = awserr.New("MissingEndpoint", "'Endpoint' configuration is required for this service", nil) +) + +// ValidateEndpointHandler is a request handler to validate a request had the +// appropriate Region and Endpoint set. Will set r.Error if the endpoint or +// region is not valid. +func ValidateEndpointHandler(r *Request) { + if r.Service.SigningRegion == "" && aws.StringValue(r.Service.Config.Region) == "" { + r.Error = ErrMissingRegion + } else if r.Service.Endpoint == "" { + r.Error = ErrMissingEndpoint + } +} diff --git a/Godeps/_workspace/src/github.com/aws/aws-sdk-go/aws/service/handler_functions_test.go b/Godeps/_workspace/src/github.com/aws/aws-sdk-go/aws/service/handler_functions_test.go new file mode 100644 index 000000000..418da7b7c --- /dev/null +++ b/Godeps/_workspace/src/github.com/aws/aws-sdk-go/aws/service/handler_functions_test.go @@ -0,0 +1,106 @@ +package service + +import ( + "fmt" + "net/http" + "os" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/awserr" + "github.com/aws/aws-sdk-go/aws/credentials" +) + +func TestValidateEndpointHandler(t *testing.T) { + os.Clearenv() + svc := NewService(aws.NewConfig().WithRegion("us-west-2")) + svc.Handlers.Clear() + svc.Handlers.Validate.PushBack(ValidateEndpointHandler) + + req := NewRequest(svc, &Operation{Name: "Operation"}, nil, nil) + err := req.Build() + + assert.NoError(t, err) +} + +func TestValidateEndpointHandlerErrorRegion(t *testing.T) { + os.Clearenv() + svc := NewService(nil) + svc.Handlers.Clear() + svc.Handlers.Validate.PushBack(ValidateEndpointHandler) + + req := NewRequest(svc, &Operation{Name: "Operation"}, nil, nil) + err := req.Build() + + assert.Error(t, err) + assert.Equal(t, ErrMissingRegion, err) +} + +type mockCredsProvider struct { + expired bool + retrieveCalled bool +} + +func (m *mockCredsProvider) Retrieve() (credentials.Value, error) { + m.retrieveCalled = true + return credentials.Value{}, nil +} + +func (m *mockCredsProvider) IsExpired() bool { + return m.expired +} + +func TestAfterRetryRefreshCreds(t *testing.T) { + os.Clearenv() + credProvider := &mockCredsProvider{} + svc := NewService(&aws.Config{Credentials: credentials.NewCredentials(credProvider), MaxRetries: aws.Int(1)}) + + svc.Handlers.Clear() + svc.Handlers.ValidateResponse.PushBack(func(r *Request) { + r.Error = awserr.New("UnknownError", "", nil) + r.HTTPResponse = &http.Response{StatusCode: 400} + }) + svc.Handlers.UnmarshalError.PushBack(func(r *Request) { + r.Error = awserr.New("ExpiredTokenException", "", nil) + }) + svc.Handlers.AfterRetry.PushBack(func(r *Request) { + AfterRetryHandler(r) + }) + + assert.True(t, svc.Config.Credentials.IsExpired(), "Expect to start out expired") + assert.False(t, credProvider.retrieveCalled) + + req := NewRequest(svc, &Operation{Name: "Operation"}, nil, nil) + req.Send() + + assert.True(t, svc.Config.Credentials.IsExpired()) + assert.False(t, credProvider.retrieveCalled) + + _, err := svc.Config.Credentials.Get() + assert.NoError(t, err) + assert.True(t, credProvider.retrieveCalled) +} + +type testSendHandlerTransport struct{} + +func (t *testSendHandlerTransport) RoundTrip(r *http.Request) (*http.Response, error) { + return nil, fmt.Errorf("mock error") +} + +func TestSendHandlerError(t *testing.T) { + svc := NewService(&aws.Config{ + HTTPClient: &http.Client{ + Transport: &testSendHandlerTransport{}, + }, + }) + svc.Handlers.Clear() + svc.Handlers.Send.PushBack(SendHandler) + r := NewRequest(svc, &Operation{Name: "Operation"}, nil, nil) + + r.Send() + + assert.Error(t, r.Error) + assert.NotNil(t, r.HTTPResponse) +} diff --git a/Godeps/_workspace/src/github.com/aws/aws-sdk-go/aws/service/handlers.go b/Godeps/_workspace/src/github.com/aws/aws-sdk-go/aws/service/handlers.go new file mode 100644 index 000000000..3f125e5a8 --- /dev/null +++ b/Godeps/_workspace/src/github.com/aws/aws-sdk-go/aws/service/handlers.go @@ -0,0 +1,85 @@ +package service + +// A Handlers provides a collection of request handlers for various +// stages of handling requests. +type Handlers struct { + Validate HandlerList + Build HandlerList + Sign HandlerList + Send HandlerList + ValidateResponse HandlerList + Unmarshal HandlerList + UnmarshalMeta HandlerList + UnmarshalError HandlerList + Retry HandlerList + AfterRetry HandlerList +} + +// copy returns of this handler's lists. +func (h *Handlers) copy() Handlers { + return Handlers{ + Validate: h.Validate.copy(), + Build: h.Build.copy(), + Sign: h.Sign.copy(), + Send: h.Send.copy(), + ValidateResponse: h.ValidateResponse.copy(), + Unmarshal: h.Unmarshal.copy(), + UnmarshalError: h.UnmarshalError.copy(), + UnmarshalMeta: h.UnmarshalMeta.copy(), + Retry: h.Retry.copy(), + AfterRetry: h.AfterRetry.copy(), + } +} + +// Clear removes callback functions for all handlers +func (h *Handlers) Clear() { + h.Validate.Clear() + h.Build.Clear() + h.Send.Clear() + h.Sign.Clear() + h.Unmarshal.Clear() + h.UnmarshalMeta.Clear() + h.UnmarshalError.Clear() + h.ValidateResponse.Clear() + h.Retry.Clear() + h.AfterRetry.Clear() +} + +// A HandlerList manages zero or more handlers in a list. +type HandlerList struct { + list []func(*Request) +} + +// copy creates a copy of the handler list. +func (l *HandlerList) copy() HandlerList { + var n HandlerList + n.list = append([]func(*Request){}, l.list...) + return n +} + +// Clear clears the handler list. +func (l *HandlerList) Clear() { + l.list = []func(*Request){} +} + +// Len returns the number of handlers in the list. +func (l *HandlerList) Len() int { + return len(l.list) +} + +// PushBack pushes handlers f to the back of the handler list. +func (l *HandlerList) PushBack(f ...func(*Request)) { + l.list = append(l.list, f...) +} + +// PushFront pushes handlers f to the front of the handler list. +func (l *HandlerList) PushFront(f ...func(*Request)) { + l.list = append(f, l.list...) +} + +// Run executes all handlers in the list with a given request object. +func (l *HandlerList) Run(r *Request) { + for _, f := range l.list { + f(r) + } +} diff --git a/Godeps/_workspace/src/github.com/aws/aws-sdk-go/aws/service/handlers_test.go b/Godeps/_workspace/src/github.com/aws/aws-sdk-go/aws/service/handlers_test.go new file mode 100644 index 000000000..d129269f0 --- /dev/null +++ b/Godeps/_workspace/src/github.com/aws/aws-sdk-go/aws/service/handlers_test.go @@ -0,0 +1,33 @@ +package service + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/aws/aws-sdk-go/aws" +) + +func TestHandlerList(t *testing.T) { + s := "" + r := &Request{} + l := HandlerList{} + l.PushBack(func(r *Request) { + s += "a" + r.Data = s + }) + l.Run(r) + assert.Equal(t, "a", s) + assert.Equal(t, "a", r.Data) +} + +func TestMultipleHandlers(t *testing.T) { + r := &Request{} + l := HandlerList{} + l.PushBack(func(r *Request) { r.Data = nil }) + l.PushFront(func(r *Request) { r.Data = aws.Bool(true) }) + l.Run(r) + if r.Data != nil { + t.Error("Expected handler to execute") + } +} diff --git a/Godeps/_workspace/src/github.com/aws/aws-sdk-go/aws/service/param_validator.go b/Godeps/_workspace/src/github.com/aws/aws-sdk-go/aws/service/param_validator.go new file mode 100644 index 000000000..76c67d029 --- /dev/null +++ b/Godeps/_workspace/src/github.com/aws/aws-sdk-go/aws/service/param_validator.go @@ -0,0 +1,89 @@ +package service + +import ( + "fmt" + "reflect" + "strings" + + "github.com/aws/aws-sdk-go/aws/awserr" +) + +// ValidateParameters is a request handler to validate the input parameters. +// Validating parameters only has meaning if done prior to the request being sent. +func ValidateParameters(r *Request) { + if r.ParamsFilled() { + v := validator{errors: []string{}} + v.validateAny(reflect.ValueOf(r.Params), "") + + if count := len(v.errors); count > 0 { + format := "%d validation errors:\n- %s" + msg := fmt.Sprintf(format, count, strings.Join(v.errors, "\n- ")) + r.Error = awserr.New("InvalidParameter", msg, nil) + } + } +} + +// A validator validates values. Collects validations errors which occurs. +type validator struct { + errors []string +} + +// validateAny will validate any struct, slice or map type. All validations +// are also performed recursively for nested types. +func (v *validator) validateAny(value reflect.Value, path string) { + value = reflect.Indirect(value) + if !value.IsValid() { + return + } + + switch value.Kind() { + case reflect.Struct: + v.validateStruct(value, path) + case reflect.Slice: + for i := 0; i < value.Len(); i++ { + v.validateAny(value.Index(i), path+fmt.Sprintf("[%d]", i)) + } + case reflect.Map: + for _, n := range value.MapKeys() { + v.validateAny(value.MapIndex(n), path+fmt.Sprintf("[%q]", n.String())) + } + } +} + +// validateStruct will validate the struct value's fields. If the structure has +// nested types those types will be validated also. +func (v *validator) validateStruct(value reflect.Value, path string) { + prefix := "." + if path == "" { + prefix = "" + } + + for i := 0; i < value.Type().NumField(); i++ { + f := value.Type().Field(i) + if strings.ToLower(f.Name[0:1]) == f.Name[0:1] { + continue + } + fvalue := value.FieldByName(f.Name) + + notset := false + if f.Tag.Get("required") != "" { + switch fvalue.Kind() { + case reflect.Ptr, reflect.Slice, reflect.Map: + if fvalue.IsNil() { + notset = true + } + default: + if !fvalue.IsValid() { + notset = true + } + } + } + + if notset { + msg := "missing required parameter: " + path + prefix + f.Name + v.errors = append(v.errors, msg) + } else { + v.validateAny(fvalue, path+prefix+f.Name) + } + } +} diff --git a/Godeps/_workspace/src/github.com/aws/aws-sdk-go/aws/service/param_validator_test.go b/Godeps/_workspace/src/github.com/aws/aws-sdk-go/aws/service/param_validator_test.go new file mode 100644 index 000000000..514f844af --- /dev/null +++ b/Godeps/_workspace/src/github.com/aws/aws-sdk-go/aws/service/param_validator_test.go @@ -0,0 +1,85 @@ +package service + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/awserr" +) + +var testSvc = func() *Service { + s := &Service{ + Config: &aws.Config{}, + ServiceName: "mock-service", + APIVersion: "2015-01-01", + } + return s +}() + +type StructShape struct { + RequiredList []*ConditionalStructShape `required:"true"` + RequiredMap map[string]*ConditionalStructShape `required:"true"` + RequiredBool *bool `required:"true"` + OptionalStruct *ConditionalStructShape + + hiddenParameter *string + + metadataStructureShape +} + +type metadataStructureShape struct { + SDKShapeTraits bool +} + +type ConditionalStructShape struct { + Name *string `required:"true"` + SDKShapeTraits bool +} + +func TestNoErrors(t *testing.T) { + input := &StructShape{ + RequiredList: []*ConditionalStructShape{}, + RequiredMap: map[string]*ConditionalStructShape{ + "key1": {Name: aws.String("Name")}, + "key2": {Name: aws.String("Name")}, + }, + RequiredBool: aws.Bool(true), + OptionalStruct: &ConditionalStructShape{Name: aws.String("Name")}, + } + + req := NewRequest(testSvc, &Operation{}, input, nil) + ValidateParameters(req) + assert.NoError(t, req.Error) +} + +func TestMissingRequiredParameters(t *testing.T) { + input := &StructShape{} + req := NewRequest(testSvc, &Operation{}, input, nil) + ValidateParameters(req) + + assert.Error(t, req.Error) + assert.Equal(t, "InvalidParameter", req.Error.(awserr.Error).Code()) + assert.Equal(t, "3 validation errors:\n- missing required parameter: RequiredList\n- missing required parameter: RequiredMap\n- missing required parameter: RequiredBool", req.Error.(awserr.Error).Message()) +} + +func TestNestedMissingRequiredParameters(t *testing.T) { + input := &StructShape{ + RequiredList: []*ConditionalStructShape{{}}, + RequiredMap: map[string]*ConditionalStructShape{ + "key1": {Name: aws.String("Name")}, + "key2": {}, + }, + RequiredBool: aws.Bool(true), + OptionalStruct: &ConditionalStructShape{}, + } + + req := NewRequest(testSvc, &Operation{}, input, nil) + ValidateParameters(req) + + assert.Error(t, req.Error) + assert.Equal(t, "InvalidParameter", req.Error.(awserr.Error).Code()) + assert.Equal(t, "3 validation errors:\n- missing required parameter: RequiredList[0].Name\n- missing required parameter: RequiredMap[\"key2\"].Name\n- missing required parameter: OptionalStruct.Name", req.Error.(awserr.Error).Message()) + +} diff --git a/Godeps/_workspace/src/github.com/aws/aws-sdk-go/aws/service/request.go b/Godeps/_workspace/src/github.com/aws/aws-sdk-go/aws/service/request.go new file mode 100644 index 000000000..24b9c99dc --- /dev/null +++ b/Godeps/_workspace/src/github.com/aws/aws-sdk-go/aws/service/request.go @@ -0,0 +1,345 @@ +package service + +import ( + "bytes" + "fmt" + "io" + "io/ioutil" + "net/http" + "net/url" + "reflect" + "strings" + "time" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/awsutil" +) + +// A Request is the service request to be made. +type Request struct { + *Service + Handlers Handlers + Time time.Time + ExpireTime time.Duration + Operation *Operation + HTTPRequest *http.Request + HTTPResponse *http.Response + Body io.ReadSeeker + bodyStart int64 // offset from beginning of Body that the request body starts + Params interface{} + Error error + Data interface{} + RequestID string + RetryCount uint + Retryable *bool + RetryDelay time.Duration + + built bool +} + +// An Operation is the service API operation to be made. +type Operation struct { + Name string + HTTPMethod string + HTTPPath string + *Paginator +} + +// Paginator keeps track of pagination configuration for an API operation. +type Paginator struct { + InputTokens []string + OutputTokens []string + LimitToken string + TruncationToken string +} + +// NewRequest returns a new Request pointer for the service API +// operation and parameters. +// +// Params is any value of input parameters to be the request payload. +// Data is pointer value to an object which the request's response +// payload will be deserialized to. +func NewRequest(service *Service, operation *Operation, params interface{}, data interface{}) *Request { + method := operation.HTTPMethod + if method == "" { + method = "POST" + } + p := operation.HTTPPath + if p == "" { + p = "/" + } + + httpReq, _ := http.NewRequest(method, "", nil) + httpReq.URL, _ = url.Parse(service.Endpoint + p) + + r := &Request{ + Service: service, + Handlers: service.Handlers.copy(), + Time: time.Now(), + ExpireTime: 0, + Operation: operation, + HTTPRequest: httpReq, + Body: nil, + Params: params, + Error: nil, + Data: data, + } + r.SetBufferBody([]byte{}) + + return r +} + +// WillRetry returns if the request's can be retried. +func (r *Request) WillRetry() bool { + return r.Error != nil && aws.BoolValue(r.Retryable) && r.RetryCount < r.Service.MaxRetries() +} + +// ParamsFilled returns if the request's parameters have been populated +// and the parameters are valid. False is returned if no parameters are +// provided or invalid. +func (r *Request) ParamsFilled() bool { + return r.Params != nil && reflect.ValueOf(r.Params).Elem().IsValid() +} + +// DataFilled returns true if the request's data for response deserialization +// target has been set and is a valid. False is returned if data is not +// set, or is invalid. +func (r *Request) DataFilled() bool { + return r.Data != nil && reflect.ValueOf(r.Data).Elem().IsValid() +} + +// SetBufferBody will set the request's body bytes that will be sent to +// the service API. +func (r *Request) SetBufferBody(buf []byte) { + r.SetReaderBody(bytes.NewReader(buf)) +} + +// SetStringBody sets the body of the request to be backed by a string. +func (r *Request) SetStringBody(s string) { + r.SetReaderBody(strings.NewReader(s)) +} + +// SetReaderBody will set the request's body reader. +func (r *Request) SetReaderBody(reader io.ReadSeeker) { + r.HTTPRequest.Body = ioutil.NopCloser(reader) + r.Body = reader +} + +// Presign returns the request's signed URL. Error will be returned +// if the signing fails. +func (r *Request) Presign(expireTime time.Duration) (string, error) { + r.ExpireTime = expireTime + r.Sign() + if r.Error != nil { + return "", r.Error + } + return r.HTTPRequest.URL.String(), nil +} + +func debugLogReqError(r *Request, stage string, retrying bool, err error) { + if !r.Config.LogLevel.Matches(aws.LogDebugWithRequestErrors) { + return + } + + retryStr := "not retrying" + if retrying { + retryStr = "will retry" + } + + r.Config.Logger.Log(fmt.Sprintf("DEBUG: %s %s/%s failed, %s, error %v", + stage, r.ServiceName, r.Operation.Name, retryStr, err)) +} + +// Build will build the request's object so it can be signed and sent +// to the service. Build will also validate all the request's parameters. +// Anny additional build Handlers set on this request will be run +// in the order they were set. +// +// The request will only be built once. Multiple calls to build will have +// no effect. +// +// If any Validate or Build errors occur the build will stop and the error +// which occurred will be returned. +func (r *Request) Build() error { + if !r.built { + r.Error = nil + r.Handlers.Validate.Run(r) + if r.Error != nil { + debugLogReqError(r, "Validate Request", false, r.Error) + return r.Error + } + r.Handlers.Build.Run(r) + r.built = true + } + + return r.Error +} + +// Sign will sign the request retuning error if errors are encountered. +// +// Send will build the request prior to signing. All Sign Handlers will +// be executed in the order they were set. +func (r *Request) Sign() error { + r.Build() + if r.Error != nil { + debugLogReqError(r, "Build Request", false, r.Error) + return r.Error + } + + r.Handlers.Sign.Run(r) + return r.Error +} + +// Send will send the request returning error if errors are encountered. +// +// Send will sign the request prior to sending. All Send Handlers will +// be executed in the order they were set. +func (r *Request) Send() error { + for { + r.Sign() + if r.Error != nil { + return r.Error + } + + if aws.BoolValue(r.Retryable) { + if r.Config.LogLevel.Matches(aws.LogDebugWithRequestRetries) { + r.Config.Logger.Log(fmt.Sprintf("DEBUG: Retrying Request %s/%s, attempt %d", + r.ServiceName, r.Operation.Name, r.RetryCount)) + } + + // Re-seek the body back to the original point in for a retry so that + // send will send the body's contents again in the upcoming request. + r.Body.Seek(r.bodyStart, 0) + r.HTTPRequest.Body = ioutil.NopCloser(r.Body) + } + r.Retryable = nil + + r.Handlers.Send.Run(r) + if r.Error != nil { + err := r.Error + r.Handlers.Retry.Run(r) + r.Handlers.AfterRetry.Run(r) + if r.Error != nil { + debugLogReqError(r, "Send Request", false, r.Error) + return r.Error + } + debugLogReqError(r, "Send Request", true, err) + continue + } + + r.Handlers.UnmarshalMeta.Run(r) + r.Handlers.ValidateResponse.Run(r) + if r.Error != nil { + err := r.Error + r.Handlers.UnmarshalError.Run(r) + r.Handlers.Retry.Run(r) + r.Handlers.AfterRetry.Run(r) + if r.Error != nil { + debugLogReqError(r, "Validate Response", false, r.Error) + return r.Error + } + debugLogReqError(r, "Validate Response", true, err) + continue + } + + r.Handlers.Unmarshal.Run(r) + if r.Error != nil { + err := r.Error + r.Handlers.Retry.Run(r) + r.Handlers.AfterRetry.Run(r) + if r.Error != nil { + debugLogReqError(r, "Unmarshal Response", false, r.Error) + return r.Error + } + debugLogReqError(r, "Unmarshal Response", true, err) + continue + } + + break + } + + return nil +} + +// HasNextPage returns true if this request has more pages of data available. +func (r *Request) HasNextPage() bool { + return r.nextPageTokens() != nil +} + +// nextPageTokens returns the tokens to use when asking for the next page of +// data. +func (r *Request) nextPageTokens() []interface{} { + if r.Operation.Paginator == nil { + return nil + } + + if r.Operation.TruncationToken != "" { + tr := awsutil.ValuesAtAnyPath(r.Data, r.Operation.TruncationToken) + if tr == nil || len(tr) == 0 { + return nil + } + switch v := tr[0].(type) { + case bool: + if v == false { + return nil + } + } + } + + found := false + tokens := make([]interface{}, len(r.Operation.OutputTokens)) + + for i, outtok := range r.Operation.OutputTokens { + v := awsutil.ValuesAtAnyPath(r.Data, outtok) + if v != nil && len(v) > 0 { + found = true + tokens[i] = v[0] + } + } + + if found { + return tokens + } + return nil +} + +// NextPage returns a new Request that can be executed to return the next +// page of result data. Call .Send() on this request to execute it. +func (r *Request) NextPage() *Request { + tokens := r.nextPageTokens() + if tokens == nil { + return nil + } + + data := reflect.New(reflect.TypeOf(r.Data).Elem()).Interface() + nr := NewRequest(r.Service, r.Operation, awsutil.CopyOf(r.Params), data) + for i, intok := range nr.Operation.InputTokens { + awsutil.SetValueAtAnyPath(nr.Params, intok, tokens[i]) + } + return nr +} + +// EachPage iterates over each page of a paginated request object. The fn +// parameter should be a function with the following sample signature: +// +// func(page *T, lastPage bool) bool { +// return true // return false to stop iterating +// } +// +// Where "T" is the structure type matching the output structure of the given +// operation. For example, a request object generated by +// DynamoDB.ListTablesRequest() would expect to see dynamodb.ListTablesOutput +// as the structure "T". The lastPage value represents whether the page is +// the last page of data or not. The return value of this function should +// return true to keep iterating or false to stop. +func (r *Request) EachPage(fn func(data interface{}, isLastPage bool) (shouldContinue bool)) error { + for page := r; page != nil; page = page.NextPage() { + page.Send() + shouldContinue := fn(page.Data, !page.HasNextPage()) + if page.Error != nil || !shouldContinue { + return page.Error + } + } + + return nil +} diff --git a/Godeps/_workspace/src/github.com/aws/aws-sdk-go/aws/service/request_pagination_test.go b/Godeps/_workspace/src/github.com/aws/aws-sdk-go/aws/service/request_pagination_test.go new file mode 100644 index 000000000..a236960fa --- /dev/null +++ b/Godeps/_workspace/src/github.com/aws/aws-sdk-go/aws/service/request_pagination_test.go @@ -0,0 +1,307 @@ +package service_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/service" + "github.com/aws/aws-sdk-go/internal/test/unit" + "github.com/aws/aws-sdk-go/service/dynamodb" + "github.com/aws/aws-sdk-go/service/s3" +) + +var _ = unit.Imported + +// Use DynamoDB methods for simplicity +func TestPagination(t *testing.T) { + db := dynamodb.New(nil) + tokens, pages, numPages, gotToEnd := []string{}, []string{}, 0, false + + reqNum := 0 + resps := []*dynamodb.ListTablesOutput{ + {TableNames: []*string{aws.String("Table1"), aws.String("Table2")}, LastEvaluatedTableName: aws.String("Table2")}, + {TableNames: []*string{aws.String("Table3"), aws.String("Table4")}, LastEvaluatedTableName: aws.String("Table4")}, + {TableNames: []*string{aws.String("Table5")}}, + } + + db.Handlers.Send.Clear() // mock sending + db.Handlers.Unmarshal.Clear() + db.Handlers.UnmarshalMeta.Clear() + db.Handlers.ValidateResponse.Clear() + db.Handlers.Build.PushBack(func(r *service.Request) { + in := r.Params.(*dynamodb.ListTablesInput) + if in == nil { + tokens = append(tokens, "") + } else if in.ExclusiveStartTableName != nil { + tokens = append(tokens, *in.ExclusiveStartTableName) + } + }) + db.Handlers.Unmarshal.PushBack(func(r *service.Request) { + r.Data = resps[reqNum] + reqNum++ + }) + + params := &dynamodb.ListTablesInput{Limit: aws.Int64(2)} + err := db.ListTablesPages(params, func(p *dynamodb.ListTablesOutput, last bool) bool { + numPages++ + for _, t := range p.TableNames { + pages = append(pages, *t) + } + if last { + if gotToEnd { + assert.Fail(t, "last=true happened twice") + } + gotToEnd = true + } + return true + }) + + assert.Equal(t, []string{"Table2", "Table4"}, tokens) + assert.Equal(t, []string{"Table1", "Table2", "Table3", "Table4", "Table5"}, pages) + assert.Equal(t, 3, numPages) + assert.True(t, gotToEnd) + assert.Nil(t, err) + assert.Nil(t, params.ExclusiveStartTableName) +} + +// Use DynamoDB methods for simplicity +func TestPaginationEachPage(t *testing.T) { + db := dynamodb.New(nil) + tokens, pages, numPages, gotToEnd := []string{}, []string{}, 0, false + + reqNum := 0 + resps := []*dynamodb.ListTablesOutput{ + {TableNames: []*string{aws.String("Table1"), aws.String("Table2")}, LastEvaluatedTableName: aws.String("Table2")}, + {TableNames: []*string{aws.String("Table3"), aws.String("Table4")}, LastEvaluatedTableName: aws.String("Table4")}, + {TableNames: []*string{aws.String("Table5")}}, + } + + db.Handlers.Send.Clear() // mock sending + db.Handlers.Unmarshal.Clear() + db.Handlers.UnmarshalMeta.Clear() + db.Handlers.ValidateResponse.Clear() + db.Handlers.Build.PushBack(func(r *service.Request) { + in := r.Params.(*dynamodb.ListTablesInput) + if in == nil { + tokens = append(tokens, "") + } else if in.ExclusiveStartTableName != nil { + tokens = append(tokens, *in.ExclusiveStartTableName) + } + }) + db.Handlers.Unmarshal.PushBack(func(r *service.Request) { + r.Data = resps[reqNum] + reqNum++ + }) + + params := &dynamodb.ListTablesInput{Limit: aws.Int64(2)} + req, _ := db.ListTablesRequest(params) + err := req.EachPage(func(p interface{}, last bool) bool { + numPages++ + for _, t := range p.(*dynamodb.ListTablesOutput).TableNames { + pages = append(pages, *t) + } + if last { + if gotToEnd { + assert.Fail(t, "last=true happened twice") + } + gotToEnd = true + } + + return true + }) + + assert.Equal(t, []string{"Table2", "Table4"}, tokens) + assert.Equal(t, []string{"Table1", "Table2", "Table3", "Table4", "Table5"}, pages) + assert.Equal(t, 3, numPages) + assert.True(t, gotToEnd) + assert.Nil(t, err) +} + +// Use DynamoDB methods for simplicity +func TestPaginationEarlyExit(t *testing.T) { + db := dynamodb.New(nil) + numPages, gotToEnd := 0, false + + reqNum := 0 + resps := []*dynamodb.ListTablesOutput{ + {TableNames: []*string{aws.String("Table1"), aws.String("Table2")}, LastEvaluatedTableName: aws.String("Table2")}, + {TableNames: []*string{aws.String("Table3"), aws.String("Table4")}, LastEvaluatedTableName: aws.String("Table4")}, + {TableNames: []*string{aws.String("Table5")}}, + } + + db.Handlers.Send.Clear() // mock sending + db.Handlers.Unmarshal.Clear() + db.Handlers.UnmarshalMeta.Clear() + db.Handlers.ValidateResponse.Clear() + db.Handlers.Unmarshal.PushBack(func(r *service.Request) { + r.Data = resps[reqNum] + reqNum++ + }) + + params := &dynamodb.ListTablesInput{Limit: aws.Int64(2)} + err := db.ListTablesPages(params, func(p *dynamodb.ListTablesOutput, last bool) bool { + numPages++ + if numPages == 2 { + return false + } + if last { + if gotToEnd { + assert.Fail(t, "last=true happened twice") + } + gotToEnd = true + } + return true + }) + + assert.Equal(t, 2, numPages) + assert.False(t, gotToEnd) + assert.Nil(t, err) +} + +func TestSkipPagination(t *testing.T) { + client := s3.New(nil) + client.Handlers.Send.Clear() // mock sending + client.Handlers.Unmarshal.Clear() + client.Handlers.UnmarshalMeta.Clear() + client.Handlers.ValidateResponse.Clear() + client.Handlers.Unmarshal.PushBack(func(r *service.Request) { + r.Data = &s3.HeadBucketOutput{} + }) + + req, _ := client.HeadBucketRequest(&s3.HeadBucketInput{Bucket: aws.String("bucket")}) + + numPages, gotToEnd := 0, false + req.EachPage(func(p interface{}, last bool) bool { + numPages++ + if last { + gotToEnd = true + } + return true + }) + assert.Equal(t, 1, numPages) + assert.True(t, gotToEnd) +} + +// Use S3 for simplicity +func TestPaginationTruncation(t *testing.T) { + count := 0 + client := s3.New(nil) + + reqNum := &count + resps := []*s3.ListObjectsOutput{ + {IsTruncated: aws.Bool(true), Contents: []*s3.Object{{Key: aws.String("Key1")}}}, + {IsTruncated: aws.Bool(true), Contents: []*s3.Object{{Key: aws.String("Key2")}}}, + {IsTruncated: aws.Bool(false), Contents: []*s3.Object{{Key: aws.String("Key3")}}}, + {IsTruncated: aws.Bool(true), Contents: []*s3.Object{{Key: aws.String("Key4")}}}, + } + + client.Handlers.Send.Clear() // mock sending + client.Handlers.Unmarshal.Clear() + client.Handlers.UnmarshalMeta.Clear() + client.Handlers.ValidateResponse.Clear() + client.Handlers.Unmarshal.PushBack(func(r *service.Request) { + r.Data = resps[*reqNum] + *reqNum++ + }) + + params := &s3.ListObjectsInput{Bucket: aws.String("bucket")} + + results := []string{} + err := client.ListObjectsPages(params, func(p *s3.ListObjectsOutput, last bool) bool { + results = append(results, *p.Contents[0].Key) + return true + }) + + assert.Equal(t, []string{"Key1", "Key2", "Key3"}, results) + assert.Nil(t, err) + + // Try again without truncation token at all + count = 0 + resps[1].IsTruncated = nil + resps[2].IsTruncated = aws.Bool(true) + results = []string{} + err = client.ListObjectsPages(params, func(p *s3.ListObjectsOutput, last bool) bool { + results = append(results, *p.Contents[0].Key) + return true + }) + + assert.Equal(t, []string{"Key1", "Key2"}, results) + assert.Nil(t, err) + +} + +// Benchmarks +var benchResps = []*dynamodb.ListTablesOutput{ + {TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")}, + {TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")}, + {TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")}, + {TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")}, + {TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")}, + {TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")}, + {TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")}, + {TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")}, + {TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")}, + {TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")}, + {TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")}, + {TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")}, + {TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")}, + {TableNames: []*string{aws.String("TABLE")}}, +} + +var benchDb = func() *dynamodb.DynamoDB { + db := dynamodb.New(nil) + db.Handlers.Send.Clear() // mock sending + db.Handlers.Unmarshal.Clear() + db.Handlers.UnmarshalMeta.Clear() + db.Handlers.ValidateResponse.Clear() + return db +} + +func BenchmarkCodegenIterator(b *testing.B) { + reqNum := 0 + db := benchDb() + db.Handlers.Unmarshal.PushBack(func(r *service.Request) { + r.Data = benchResps[reqNum] + reqNum++ + }) + + input := &dynamodb.ListTablesInput{Limit: aws.Int64(2)} + iter := func(fn func(*dynamodb.ListTablesOutput, bool) bool) error { + page, _ := db.ListTablesRequest(input) + for ; page != nil; page = page.NextPage() { + page.Send() + out := page.Data.(*dynamodb.ListTablesOutput) + if result := fn(out, !page.HasNextPage()); page.Error != nil || !result { + return page.Error + } + } + return nil + } + + for i := 0; i < b.N; i++ { + reqNum = 0 + iter(func(p *dynamodb.ListTablesOutput, last bool) bool { + return true + }) + } +} + +func BenchmarkEachPageIterator(b *testing.B) { + reqNum := 0 + db := benchDb() + db.Handlers.Unmarshal.PushBack(func(r *service.Request) { + r.Data = benchResps[reqNum] + reqNum++ + }) + + input := &dynamodb.ListTablesInput{Limit: aws.Int64(2)} + for i := 0; i < b.N; i++ { + reqNum = 0 + req, _ := db.ListTablesRequest(input) + req.EachPage(func(p interface{}, last bool) bool { + return true + }) + } +} diff --git a/Godeps/_workspace/src/github.com/aws/aws-sdk-go/aws/service/request_test.go b/Godeps/_workspace/src/github.com/aws/aws-sdk-go/aws/service/request_test.go new file mode 100644 index 000000000..a07353dcb --- /dev/null +++ b/Godeps/_workspace/src/github.com/aws/aws-sdk-go/aws/service/request_test.go @@ -0,0 +1,226 @@ +package service + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "io/ioutil" + "net/http" + "testing" + "time" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/awserr" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/stretchr/testify/assert" +) + +type testData struct { + Data string +} + +func body(str string) io.ReadCloser { + return ioutil.NopCloser(bytes.NewReader([]byte(str))) +} + +func unmarshal(req *Request) { + defer req.HTTPResponse.Body.Close() + if req.Data != nil { + json.NewDecoder(req.HTTPResponse.Body).Decode(req.Data) + } + return +} + +func unmarshalError(req *Request) { + bodyBytes, err := ioutil.ReadAll(req.HTTPResponse.Body) + if err != nil { + req.Error = awserr.New("UnmarshaleError", req.HTTPResponse.Status, err) + return + } + if len(bodyBytes) == 0 { + req.Error = awserr.NewRequestFailure( + awserr.New("UnmarshaleError", req.HTTPResponse.Status, fmt.Errorf("empty body")), + req.HTTPResponse.StatusCode, + "", + ) + return + } + var jsonErr jsonErrorResponse + if err := json.Unmarshal(bodyBytes, &jsonErr); err != nil { + req.Error = awserr.New("UnmarshaleError", "JSON unmarshal", err) + return + } + req.Error = awserr.NewRequestFailure( + awserr.New(jsonErr.Code, jsonErr.Message, nil), + req.HTTPResponse.StatusCode, + "", + ) +} + +type jsonErrorResponse struct { + Code string `json:"__type"` + Message string `json:"message"` +} + +// test that retries occur for 5xx status codes +func TestRequestRecoverRetry5xx(t *testing.T) { + reqNum := 0 + reqs := []http.Response{ + {StatusCode: 500, Body: body(`{"__type":"UnknownError","message":"An error occurred."}`)}, + {StatusCode: 501, Body: body(`{"__type":"UnknownError","message":"An error occurred."}`)}, + {StatusCode: 200, Body: body(`{"data":"valid"}`)}, + } + + s := NewService(aws.NewConfig().WithMaxRetries(10)) + s.Handlers.Validate.Clear() + s.Handlers.Unmarshal.PushBack(unmarshal) + s.Handlers.UnmarshalError.PushBack(unmarshalError) + s.Handlers.Send.Clear() // mock sending + s.Handlers.Send.PushBack(func(r *Request) { + r.HTTPResponse = &reqs[reqNum] + reqNum++ + }) + out := &testData{} + r := NewRequest(s, &Operation{Name: "Operation"}, nil, out) + err := r.Send() + assert.Nil(t, err) + assert.Equal(t, 2, int(r.RetryCount)) + assert.Equal(t, "valid", out.Data) +} + +// test that retries occur for 4xx status codes with a response type that can be retried - see `shouldRetry` +func TestRequestRecoverRetry4xxRetryable(t *testing.T) { + reqNum := 0 + reqs := []http.Response{ + {StatusCode: 400, Body: body(`{"__type":"Throttling","message":"Rate exceeded."}`)}, + {StatusCode: 429, Body: body(`{"__type":"ProvisionedThroughputExceededException","message":"Rate exceeded."}`)}, + {StatusCode: 200, Body: body(`{"data":"valid"}`)}, + } + + s := NewService(aws.NewConfig().WithMaxRetries(10)) + s.Handlers.Validate.Clear() + s.Handlers.Unmarshal.PushBack(unmarshal) + s.Handlers.UnmarshalError.PushBack(unmarshalError) + s.Handlers.Send.Clear() // mock sending + s.Handlers.Send.PushBack(func(r *Request) { + r.HTTPResponse = &reqs[reqNum] + reqNum++ + }) + out := &testData{} + r := NewRequest(s, &Operation{Name: "Operation"}, nil, out) + err := r.Send() + assert.Nil(t, err) + assert.Equal(t, 2, int(r.RetryCount)) + assert.Equal(t, "valid", out.Data) +} + +// test that retries don't occur for 4xx status codes with a response type that can't be retried +func TestRequest4xxUnretryable(t *testing.T) { + s := NewService(aws.NewConfig().WithMaxRetries(10)) + s.Handlers.Validate.Clear() + s.Handlers.Unmarshal.PushBack(unmarshal) + s.Handlers.UnmarshalError.PushBack(unmarshalError) + s.Handlers.Send.Clear() // mock sending + s.Handlers.Send.PushBack(func(r *Request) { + r.HTTPResponse = &http.Response{StatusCode: 401, Body: body(`{"__type":"SignatureDoesNotMatch","message":"Signature does not match."}`)} + }) + out := &testData{} + r := NewRequest(s, &Operation{Name: "Operation"}, nil, out) + err := r.Send() + assert.NotNil(t, err) + if e, ok := err.(awserr.RequestFailure); ok { + assert.Equal(t, 401, e.StatusCode()) + } else { + assert.Fail(t, "Expected error to be a service failure") + } + assert.Equal(t, "SignatureDoesNotMatch", err.(awserr.Error).Code()) + assert.Equal(t, "Signature does not match.", err.(awserr.Error).Message()) + assert.Equal(t, 0, int(r.RetryCount)) +} + +func TestRequestExhaustRetries(t *testing.T) { + delays := []time.Duration{} + sleepDelay = func(delay time.Duration) { + delays = append(delays, delay) + } + + reqNum := 0 + reqs := []http.Response{ + {StatusCode: 500, Body: body(`{"__type":"UnknownError","message":"An error occurred."}`)}, + {StatusCode: 500, Body: body(`{"__type":"UnknownError","message":"An error occurred."}`)}, + {StatusCode: 500, Body: body(`{"__type":"UnknownError","message":"An error occurred."}`)}, + {StatusCode: 500, Body: body(`{"__type":"UnknownError","message":"An error occurred."}`)}, + } + + s := NewService(aws.NewConfig().WithMaxRetries(aws.DefaultRetries)) + s.Handlers.Validate.Clear() + s.Handlers.Unmarshal.PushBack(unmarshal) + s.Handlers.UnmarshalError.PushBack(unmarshalError) + s.Handlers.Send.Clear() // mock sending + s.Handlers.Send.PushBack(func(r *Request) { + r.HTTPResponse = &reqs[reqNum] + reqNum++ + }) + r := NewRequest(s, &Operation{Name: "Operation"}, nil, nil) + err := r.Send() + assert.NotNil(t, err) + if e, ok := err.(awserr.RequestFailure); ok { + assert.Equal(t, 500, e.StatusCode()) + } else { + assert.Fail(t, "Expected error to be a service failure") + } + assert.Equal(t, "UnknownError", err.(awserr.Error).Code()) + assert.Equal(t, "An error occurred.", err.(awserr.Error).Message()) + assert.Equal(t, 3, int(r.RetryCount)) + + expectDelays := []struct{ min, max time.Duration }{{30, 59}, {60, 118}, {120, 236}} + for i, v := range delays { + min := expectDelays[i].min * time.Millisecond + max := expectDelays[i].max * time.Millisecond + assert.True(t, min <= v && v <= max, + "Expect delay to be within range, i:%d, v:%s, min:%s, max:%s", i, v, min, max) + } +} + +// test that the request is retried after the credentials are expired. +func TestRequestRecoverExpiredCreds(t *testing.T) { + reqNum := 0 + reqs := []http.Response{ + {StatusCode: 400, Body: body(`{"__type":"ExpiredTokenException","message":"expired token"}`)}, + {StatusCode: 200, Body: body(`{"data":"valid"}`)}, + } + + s := NewService(&aws.Config{MaxRetries: aws.Int(10), Credentials: credentials.NewStaticCredentials("AKID", "SECRET", "")}) + s.Handlers.Validate.Clear() + s.Handlers.Unmarshal.PushBack(unmarshal) + s.Handlers.UnmarshalError.PushBack(unmarshalError) + + credExpiredBeforeRetry := false + credExpiredAfterRetry := false + + s.Handlers.AfterRetry.PushBack(func(r *Request) { + credExpiredAfterRetry = r.Config.Credentials.IsExpired() + }) + + s.Handlers.Sign.Clear() + s.Handlers.Sign.PushBack(func(r *Request) { + r.Config.Credentials.Get() + }) + s.Handlers.Send.Clear() // mock sending + s.Handlers.Send.PushBack(func(r *Request) { + r.HTTPResponse = &reqs[reqNum] + reqNum++ + }) + out := &testData{} + r := NewRequest(s, &Operation{Name: "Operation"}, nil, out) + err := r.Send() + assert.Nil(t, err) + + assert.False(t, credExpiredBeforeRetry, "Expect valid creds before retry check") + assert.True(t, credExpiredAfterRetry, "Expect expired creds after retry check") + assert.False(t, s.Config.Credentials.IsExpired(), "Expect valid creds after cred expired recovery") + + assert.Equal(t, 1, int(r.RetryCount)) + assert.Equal(t, "valid", out.Data) +} diff --git a/Godeps/_workspace/src/github.com/aws/aws-sdk-go/aws/service/service.go b/Godeps/_workspace/src/github.com/aws/aws-sdk-go/aws/service/service.go new file mode 100644 index 000000000..900c51744 --- /dev/null +++ b/Godeps/_workspace/src/github.com/aws/aws-sdk-go/aws/service/service.go @@ -0,0 +1,204 @@ +package service + +import ( + "fmt" + "io/ioutil" + "math" + "math/rand" + "net/http" + "net/http/httputil" + "regexp" + "time" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/awserr" + "github.com/aws/aws-sdk-go/internal/endpoints" +) + +// A Service implements the base service request and response handling +// used by all services. +type Service struct { + Config *aws.Config + Handlers Handlers + ServiceName string + APIVersion string + Endpoint string + SigningName string + SigningRegion string + JSONVersion string + TargetPrefix string + RetryRules func(*Request) time.Duration + ShouldRetry func(*Request) bool + DefaultMaxRetries uint +} + +var schemeRE = regexp.MustCompile("^([^:]+)://") + +// NewService will return a pointer to a new Server object initialized. +func NewService(config *aws.Config) *Service { + svc := &Service{Config: config} + svc.Initialize() + return svc +} + +// Initialize initializes the service. +func (s *Service) Initialize() { + if s.Config == nil { + s.Config = &aws.Config{} + } + if s.Config.HTTPClient == nil { + s.Config.HTTPClient = http.DefaultClient + } + + if s.RetryRules == nil { + s.RetryRules = retryRules + } + + if s.ShouldRetry == nil { + s.ShouldRetry = shouldRetry + } + + s.DefaultMaxRetries = 3 + s.Handlers.Validate.PushBack(ValidateEndpointHandler) + s.Handlers.Build.PushBack(UserAgentHandler) + s.Handlers.Sign.PushBack(BuildContentLength) + s.Handlers.Send.PushBack(SendHandler) + s.Handlers.AfterRetry.PushBack(AfterRetryHandler) + s.Handlers.ValidateResponse.PushBack(ValidateResponseHandler) + s.AddDebugHandlers() + s.buildEndpoint() + + if !aws.BoolValue(s.Config.DisableParamValidation) { + s.Handlers.Validate.PushBack(ValidateParameters) + } +} + +// buildEndpoint builds the endpoint values the service will use to make requests with. +func (s *Service) buildEndpoint() { + if aws.StringValue(s.Config.Endpoint) != "" { + s.Endpoint = *s.Config.Endpoint + } else if s.Endpoint == "" { + s.Endpoint, s.SigningRegion = + endpoints.EndpointForRegion(s.ServiceName, aws.StringValue(s.Config.Region)) + } + + if s.Endpoint != "" && !schemeRE.MatchString(s.Endpoint) { + scheme := "https" + if aws.BoolValue(s.Config.DisableSSL) { + scheme = "http" + } + s.Endpoint = scheme + "://" + s.Endpoint + } +} + +// AddDebugHandlers injects debug logging handlers into the service to log request +// debug information. +func (s *Service) AddDebugHandlers() { + if !s.Config.LogLevel.AtLeast(aws.LogDebug) { + return + } + + s.Handlers.Send.PushFront(logRequest) + s.Handlers.Send.PushBack(logResponse) +} + +const logReqMsg = `DEBUG: Request %s/%s Details: +---[ REQUEST POST-SIGN ]----------------------------- +%s +-----------------------------------------------------` + +func logRequest(r *Request) { + logBody := r.Config.LogLevel.Matches(aws.LogDebugWithHTTPBody) + dumpedBody, _ := httputil.DumpRequestOut(r.HTTPRequest, logBody) + + if logBody { + // Reset the request body because dumpRequest will re-wrap the r.HTTPRequest's + // Body as a NoOpCloser and will not be reset after read by the HTTP + // client reader. + r.Body.Seek(r.bodyStart, 0) + r.HTTPRequest.Body = ioutil.NopCloser(r.Body) + } + + r.Config.Logger.Log(fmt.Sprintf(logReqMsg, r.ServiceName, r.Operation.Name, string(dumpedBody))) +} + +const logRespMsg = `DEBUG: Response %s/%s Details: +---[ RESPONSE ]-------------------------------------- +%s +-----------------------------------------------------` + +func logResponse(r *Request) { + var msg = "no reponse data" + if r.HTTPResponse != nil { + logBody := r.Config.LogLevel.Matches(aws.LogDebugWithHTTPBody) + dumpedBody, _ := httputil.DumpResponse(r.HTTPResponse, logBody) + msg = string(dumpedBody) + } else if r.Error != nil { + msg = r.Error.Error() + } + r.Config.Logger.Log(fmt.Sprintf(logRespMsg, r.ServiceName, r.Operation.Name, msg)) +} + +// MaxRetries returns the number of maximum returns the service will use to make +// an individual API request. +func (s *Service) MaxRetries() uint { + if aws.IntValue(s.Config.MaxRetries) < 0 { + return s.DefaultMaxRetries + } + return uint(aws.IntValue(s.Config.MaxRetries)) +} + +var seededRand = rand.New(rand.NewSource(time.Now().UnixNano())) + +// retryRules returns the delay duration before retrying this request again +func retryRules(r *Request) time.Duration { + + delay := int(math.Pow(2, float64(r.RetryCount))) * (seededRand.Intn(30) + 30) + return time.Duration(delay) * time.Millisecond +} + +// retryableCodes is a collection of service response codes which are retry-able +// without any further action. +var retryableCodes = map[string]struct{}{ + "RequestError": {}, + "ProvisionedThroughputExceededException": {}, + "Throttling": {}, + "ThrottlingException": {}, + "RequestLimitExceeded": {}, + "RequestThrottled": {}, +} + +// credsExpiredCodes is a collection of error codes which signify the credentials +// need to be refreshed. Expired tokens require refreshing of credentials, and +// resigning before the request can be retried. +var credsExpiredCodes = map[string]struct{}{ + "ExpiredToken": {}, + "ExpiredTokenException": {}, + "RequestExpired": {}, // EC2 Only +} + +func isCodeRetryable(code string) bool { + if _, ok := retryableCodes[code]; ok { + return true + } + + return isCodeExpiredCreds(code) +} + +func isCodeExpiredCreds(code string) bool { + _, ok := credsExpiredCodes[code] + return ok +} + +// shouldRetry returns if the request should be retried. +func shouldRetry(r *Request) bool { + if r.HTTPResponse.StatusCode >= 500 { + return true + } + if r.Error != nil { + if err, ok := r.Error.(awserr.Error); ok { + return isCodeRetryable(err.Code()) + } + } + return false +}