From c7981e64173b98acdb43da970ae61314793cc26f Mon Sep 17 00:00:00 2001 From: Jeff Mitchell Date: Fri, 1 Jun 2018 09:12:43 -0400 Subject: [PATCH] Clean up request logic and use retryable's more efficient handling (#4670) --- api/client.go | 2 +- api/request.go | 86 ++++++++++++++++++++++++++------------------- api/request_test.go | 32 +++-------------- 3 files changed, 56 insertions(+), 64 deletions(-) diff --git a/api/client.go b/api/client.go index ce10fff14..8f0d3f86e 100644 --- a/api/client.go +++ b/api/client.go @@ -635,7 +635,7 @@ func (c *Client) RawRequest(r *Request) (*Response, error) { redirectCount := 0 START: - req, err := r.toRetryableHTTP(false) + req, err := r.toRetryableHTTP() if err != nil { return nil, err } diff --git a/api/request.go b/api/request.go index 8e8a26fda..5bcff8c6c 100644 --- a/api/request.go +++ b/api/request.go @@ -22,8 +22,14 @@ type Request struct { MFAHeaderVals []string WrapTTL string Obj interface{} - Body io.Reader - BodySize int64 + + // When possible, use BodyBytes as it is more efficient due to how the + // retry logic works + BodyBytes []byte + + // Fallback + Body io.Reader + BodySize int64 // Whether to request overriding soft-mandatory Sentinel policies (RGPs and // EGPs). If set, the override flag will take effect for all policies @@ -33,67 +39,75 @@ type Request struct { // SetJSONBody is used to set a request body that is a JSON-encoded value. func (r *Request) SetJSONBody(val interface{}) error { - buf := bytes.NewBuffer(nil) - enc := json.NewEncoder(buf) - if err := enc.Encode(val); err != nil { + buf, err := json.Marshal(val) + if err != nil { return err } r.Obj = val - r.Body = buf - r.BodySize = int64(buf.Len()) + r.BodyBytes = buf return nil } // ResetJSONBody is used to reset the body for a redirect func (r *Request) ResetJSONBody() error { - if r.Body == nil { + if r.BodyBytes == nil { return nil } return r.SetJSONBody(r.Obj) } -// ToHTTP turns this request into a valid *http.Request for use with the -// net/http package. +// DEPRECATED: ToHTTP turns this request into a valid *http.Request for use +// with the net/http package. func (r *Request) ToHTTP() (*http.Request, error) { - req, err := r.toRetryableHTTP(true) + req, err := r.toRetryableHTTP() if err != nil { return nil, err } + + switch { + case r.BodyBytes == nil && r.Body == nil: + // No body + + case r.BodyBytes != nil: + req.Request.Body = ioutil.NopCloser(bytes.NewReader(r.BodyBytes)) + + default: + if c, ok := r.Body.(io.ReadCloser); ok { + req.Request.Body = c + } else { + req.Request.Body = ioutil.NopCloser(r.Body) + } + } + return req.Request, nil } -// legacy indicates whether we want to return a request derived from -// http.NewRequest instead of retryablehttp.NewRequest, so that legacy clents -// that might be using the public ToHTTP method still work -func (r *Request) toRetryableHTTP(legacy bool) (*retryablehttp.Request, error) { +func (r *Request) toRetryableHTTP() (*retryablehttp.Request, error) { // Encode the query parameters r.URL.RawQuery = r.Params.Encode() // Create the HTTP request, defaulting to retryable var req *retryablehttp.Request - if legacy { - regReq, err := http.NewRequest(r.Method, r.URL.RequestURI(), r.Body) - if err != nil { - return nil, err - } - req = &retryablehttp.Request{ - Request: regReq, - } - } else { - var buf []byte - var err error - if r.Body != nil { - buf, err = ioutil.ReadAll(r.Body) - if err != nil { - return nil, err - } - } - req, err = retryablehttp.NewRequest(r.Method, r.URL.RequestURI(), bytes.NewReader(buf)) - if err != nil { - return nil, err - } + var err error + var body interface{} + + switch { + case r.BodyBytes == nil && r.Body == nil: + // No body + + case r.BodyBytes != nil: + // Use bytes, it's more efficient + body = r.BodyBytes + + default: + body = r.Body + } + + req, err = retryablehttp.NewRequest(r.Method, r.URL.RequestURI(), body) + if err != nil { + return nil, err } req.URL.User = r.URL.User diff --git a/api/request_test.go b/api/request_test.go index 904f59a16..f2657e61c 100644 --- a/api/request_test.go +++ b/api/request_test.go @@ -1,8 +1,6 @@ package api import ( - "bytes" - "io" "strings" "testing" ) @@ -14,20 +12,11 @@ func TestRequestSetJSONBody(t *testing.T) { t.Fatalf("err: %s", err) } - var buf bytes.Buffer - if _, err := io.Copy(&buf, r.Body); err != nil { - t.Fatalf("err: %s", err) - } - expected := `{"foo":"bar"}` - actual := strings.TrimSpace(buf.String()) + actual := strings.TrimSpace(string(r.BodyBytes)) if actual != expected { t.Fatalf("bad: %s", actual) } - - if int64(len(buf.String())) != r.BodySize { - t.Fatalf("bad: %d", len(actual)) - } } func TestRequestResetJSONBody(t *testing.T) { @@ -37,27 +26,16 @@ func TestRequestResetJSONBody(t *testing.T) { t.Fatalf("err: %s", err) } - var buf bytes.Buffer - if _, err := io.Copy(&buf, r.Body); err != nil { - t.Fatalf("err: %s", err) - } - if err := r.ResetJSONBody(); err != nil { t.Fatalf("err: %s", err) } - var buf2 bytes.Buffer - if _, err := io.Copy(&buf2, r.Body); err != nil { - t.Fatalf("err: %s", err) - } + buf := make([]byte, len(r.BodyBytes)) + copy(buf, r.BodyBytes) expected := `{"foo":"bar"}` - actual := strings.TrimSpace(buf2.String()) + actual := strings.TrimSpace(string(buf)) if actual != expected { - t.Fatalf("bad: %s", actual) - } - - if int64(len(buf2.String())) != r.BodySize { - t.Fatalf("bad: %d", len(actual)) + t.Fatalf("bad: actual %s, expected %s", actual, expected) } }