Clean up request logic and use retryable's more efficient handling (#4670)

This commit is contained in:
Jeff Mitchell 2018-06-01 09:12:43 -04:00 committed by GitHub
parent 4d5713d090
commit c7981e6417
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 56 additions and 64 deletions

View File

@ -635,7 +635,7 @@ func (c *Client) RawRequest(r *Request) (*Response, error) {
redirectCount := 0 redirectCount := 0
START: START:
req, err := r.toRetryableHTTP(false) req, err := r.toRetryableHTTP()
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -22,6 +22,12 @@ type Request struct {
MFAHeaderVals []string MFAHeaderVals []string
WrapTTL string WrapTTL string
Obj interface{} Obj interface{}
// When possible, use BodyBytes as it is more efficient due to how the
// retry logic works
BodyBytes []byte
// Fallback
Body io.Reader Body io.Reader
BodySize int64 BodySize int64
@ -33,68 +39,76 @@ type Request struct {
// SetJSONBody is used to set a request body that is a JSON-encoded value. // SetJSONBody is used to set a request body that is a JSON-encoded value.
func (r *Request) SetJSONBody(val interface{}) error { func (r *Request) SetJSONBody(val interface{}) error {
buf := bytes.NewBuffer(nil) buf, err := json.Marshal(val)
enc := json.NewEncoder(buf) if err != nil {
if err := enc.Encode(val); err != nil {
return err return err
} }
r.Obj = val r.Obj = val
r.Body = buf r.BodyBytes = buf
r.BodySize = int64(buf.Len())
return nil return nil
} }
// ResetJSONBody is used to reset the body for a redirect // ResetJSONBody is used to reset the body for a redirect
func (r *Request) ResetJSONBody() error { func (r *Request) ResetJSONBody() error {
if r.Body == nil { if r.BodyBytes == nil {
return nil return nil
} }
return r.SetJSONBody(r.Obj) return r.SetJSONBody(r.Obj)
} }
// ToHTTP turns this request into a valid *http.Request for use with the // DEPRECATED: ToHTTP turns this request into a valid *http.Request for use
// net/http package. // with the net/http package.
func (r *Request) ToHTTP() (*http.Request, error) { func (r *Request) ToHTTP() (*http.Request, error) {
req, err := r.toRetryableHTTP(true) req, err := r.toRetryableHTTP()
if err != nil { if err != nil {
return nil, err return nil, err
} }
switch {
case r.BodyBytes == nil && r.Body == nil:
// No body
case r.BodyBytes != nil:
req.Request.Body = ioutil.NopCloser(bytes.NewReader(r.BodyBytes))
default:
if c, ok := r.Body.(io.ReadCloser); ok {
req.Request.Body = c
} else {
req.Request.Body = ioutil.NopCloser(r.Body)
}
}
return req.Request, nil return req.Request, nil
} }
// legacy indicates whether we want to return a request derived from func (r *Request) toRetryableHTTP() (*retryablehttp.Request, error) {
// http.NewRequest instead of retryablehttp.NewRequest, so that legacy clents
// that might be using the public ToHTTP method still work
func (r *Request) toRetryableHTTP(legacy bool) (*retryablehttp.Request, error) {
// Encode the query parameters // Encode the query parameters
r.URL.RawQuery = r.Params.Encode() r.URL.RawQuery = r.Params.Encode()
// Create the HTTP request, defaulting to retryable // Create the HTTP request, defaulting to retryable
var req *retryablehttp.Request var req *retryablehttp.Request
if legacy {
regReq, err := http.NewRequest(r.Method, r.URL.RequestURI(), r.Body)
if err != nil {
return nil, err
}
req = &retryablehttp.Request{
Request: regReq,
}
} else {
var buf []byte
var err error var err error
if r.Body != nil { var body interface{}
buf, err = ioutil.ReadAll(r.Body)
switch {
case r.BodyBytes == nil && r.Body == nil:
// No body
case r.BodyBytes != nil:
// Use bytes, it's more efficient
body = r.BodyBytes
default:
body = r.Body
}
req, err = retryablehttp.NewRequest(r.Method, r.URL.RequestURI(), body)
if err != nil { if err != nil {
return nil, err return nil, err
} }
}
req, err = retryablehttp.NewRequest(r.Method, r.URL.RequestURI(), bytes.NewReader(buf))
if err != nil {
return nil, err
}
}
req.URL.User = r.URL.User req.URL.User = r.URL.User
req.URL.Scheme = r.URL.Scheme req.URL.Scheme = r.URL.Scheme

View File

@ -1,8 +1,6 @@
package api package api
import ( import (
"bytes"
"io"
"strings" "strings"
"testing" "testing"
) )
@ -14,20 +12,11 @@ func TestRequestSetJSONBody(t *testing.T) {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
var buf bytes.Buffer
if _, err := io.Copy(&buf, r.Body); err != nil {
t.Fatalf("err: %s", err)
}
expected := `{"foo":"bar"}` expected := `{"foo":"bar"}`
actual := strings.TrimSpace(buf.String()) actual := strings.TrimSpace(string(r.BodyBytes))
if actual != expected { if actual != expected {
t.Fatalf("bad: %s", actual) t.Fatalf("bad: %s", actual)
} }
if int64(len(buf.String())) != r.BodySize {
t.Fatalf("bad: %d", len(actual))
}
} }
func TestRequestResetJSONBody(t *testing.T) { func TestRequestResetJSONBody(t *testing.T) {
@ -37,27 +26,16 @@ func TestRequestResetJSONBody(t *testing.T) {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
var buf bytes.Buffer
if _, err := io.Copy(&buf, r.Body); err != nil {
t.Fatalf("err: %s", err)
}
if err := r.ResetJSONBody(); err != nil { if err := r.ResetJSONBody(); err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
var buf2 bytes.Buffer buf := make([]byte, len(r.BodyBytes))
if _, err := io.Copy(&buf2, r.Body); err != nil { copy(buf, r.BodyBytes)
t.Fatalf("err: %s", err)
}
expected := `{"foo":"bar"}` expected := `{"foo":"bar"}`
actual := strings.TrimSpace(buf2.String()) actual := strings.TrimSpace(string(buf))
if actual != expected { if actual != expected {
t.Fatalf("bad: %s", actual) t.Fatalf("bad: actual %s, expected %s", actual, expected)
}
if int64(len(buf2.String())) != r.BodySize {
t.Fatalf("bad: %d", len(actual))
} }
} }