Clean up request logic and use retryable's more efficient handling (#4670)

This commit is contained in:
Jeff Mitchell 2018-06-01 09:12:43 -04:00 committed by GitHub
parent 4d5713d090
commit c7981e6417
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 56 additions and 64 deletions

View File

@ -635,7 +635,7 @@ func (c *Client) RawRequest(r *Request) (*Response, error) {
redirectCount := 0 redirectCount := 0
START: START:
req, err := r.toRetryableHTTP(false) req, err := r.toRetryableHTTP()
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -22,8 +22,14 @@ type Request struct {
MFAHeaderVals []string MFAHeaderVals []string
WrapTTL string WrapTTL string
Obj interface{} Obj interface{}
Body io.Reader
BodySize int64 // When possible, use BodyBytes as it is more efficient due to how the
// retry logic works
BodyBytes []byte
// Fallback
Body io.Reader
BodySize int64
// Whether to request overriding soft-mandatory Sentinel policies (RGPs and // Whether to request overriding soft-mandatory Sentinel policies (RGPs and
// EGPs). If set, the override flag will take effect for all policies // EGPs). If set, the override flag will take effect for all policies
@ -33,67 +39,75 @@ type Request struct {
// SetJSONBody is used to set a request body that is a JSON-encoded value. // SetJSONBody is used to set a request body that is a JSON-encoded value.
func (r *Request) SetJSONBody(val interface{}) error { func (r *Request) SetJSONBody(val interface{}) error {
buf := bytes.NewBuffer(nil) buf, err := json.Marshal(val)
enc := json.NewEncoder(buf) if err != nil {
if err := enc.Encode(val); err != nil {
return err return err
} }
r.Obj = val r.Obj = val
r.Body = buf r.BodyBytes = buf
r.BodySize = int64(buf.Len())
return nil return nil
} }
// ResetJSONBody is used to reset the body for a redirect // ResetJSONBody is used to reset the body for a redirect
func (r *Request) ResetJSONBody() error { func (r *Request) ResetJSONBody() error {
if r.Body == nil { if r.BodyBytes == nil {
return nil return nil
} }
return r.SetJSONBody(r.Obj) return r.SetJSONBody(r.Obj)
} }
// ToHTTP turns this request into a valid *http.Request for use with the // DEPRECATED: ToHTTP turns this request into a valid *http.Request for use
// net/http package. // with the net/http package.
func (r *Request) ToHTTP() (*http.Request, error) { func (r *Request) ToHTTP() (*http.Request, error) {
req, err := r.toRetryableHTTP(true) req, err := r.toRetryableHTTP()
if err != nil { if err != nil {
return nil, err return nil, err
} }
switch {
case r.BodyBytes == nil && r.Body == nil:
// No body
case r.BodyBytes != nil:
req.Request.Body = ioutil.NopCloser(bytes.NewReader(r.BodyBytes))
default:
if c, ok := r.Body.(io.ReadCloser); ok {
req.Request.Body = c
} else {
req.Request.Body = ioutil.NopCloser(r.Body)
}
}
return req.Request, nil return req.Request, nil
} }
// legacy indicates whether we want to return a request derived from func (r *Request) toRetryableHTTP() (*retryablehttp.Request, error) {
// http.NewRequest instead of retryablehttp.NewRequest, so that legacy clents
// that might be using the public ToHTTP method still work
func (r *Request) toRetryableHTTP(legacy bool) (*retryablehttp.Request, error) {
// Encode the query parameters // Encode the query parameters
r.URL.RawQuery = r.Params.Encode() r.URL.RawQuery = r.Params.Encode()
// Create the HTTP request, defaulting to retryable // Create the HTTP request, defaulting to retryable
var req *retryablehttp.Request var req *retryablehttp.Request
if legacy { var err error
regReq, err := http.NewRequest(r.Method, r.URL.RequestURI(), r.Body) var body interface{}
if err != nil {
return nil, err switch {
} case r.BodyBytes == nil && r.Body == nil:
req = &retryablehttp.Request{ // No body
Request: regReq,
} case r.BodyBytes != nil:
} else { // Use bytes, it's more efficient
var buf []byte body = r.BodyBytes
var err error
if r.Body != nil { default:
buf, err = ioutil.ReadAll(r.Body) body = r.Body
if err != nil { }
return nil, err
} req, err = retryablehttp.NewRequest(r.Method, r.URL.RequestURI(), body)
} if err != nil {
req, err = retryablehttp.NewRequest(r.Method, r.URL.RequestURI(), bytes.NewReader(buf)) return nil, err
if err != nil {
return nil, err
}
} }
req.URL.User = r.URL.User req.URL.User = r.URL.User

View File

@ -1,8 +1,6 @@
package api package api
import ( import (
"bytes"
"io"
"strings" "strings"
"testing" "testing"
) )
@ -14,20 +12,11 @@ func TestRequestSetJSONBody(t *testing.T) {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
var buf bytes.Buffer
if _, err := io.Copy(&buf, r.Body); err != nil {
t.Fatalf("err: %s", err)
}
expected := `{"foo":"bar"}` expected := `{"foo":"bar"}`
actual := strings.TrimSpace(buf.String()) actual := strings.TrimSpace(string(r.BodyBytes))
if actual != expected { if actual != expected {
t.Fatalf("bad: %s", actual) t.Fatalf("bad: %s", actual)
} }
if int64(len(buf.String())) != r.BodySize {
t.Fatalf("bad: %d", len(actual))
}
} }
func TestRequestResetJSONBody(t *testing.T) { func TestRequestResetJSONBody(t *testing.T) {
@ -37,27 +26,16 @@ func TestRequestResetJSONBody(t *testing.T) {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
var buf bytes.Buffer
if _, err := io.Copy(&buf, r.Body); err != nil {
t.Fatalf("err: %s", err)
}
if err := r.ResetJSONBody(); err != nil { if err := r.ResetJSONBody(); err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
var buf2 bytes.Buffer buf := make([]byte, len(r.BodyBytes))
if _, err := io.Copy(&buf2, r.Body); err != nil { copy(buf, r.BodyBytes)
t.Fatalf("err: %s", err)
}
expected := `{"foo":"bar"}` expected := `{"foo":"bar"}`
actual := strings.TrimSpace(buf2.String()) actual := strings.TrimSpace(string(buf))
if actual != expected { if actual != expected {
t.Fatalf("bad: %s", actual) t.Fatalf("bad: actual %s, expected %s", actual, expected)
}
if int64(len(buf2.String())) != r.BodySize {
t.Fatalf("bad: %d", len(actual))
} }
} }