diff --git a/api/client.go b/api/client.go index 96d9d05fb..fd5a11ee3 100644 --- a/api/client.go +++ b/api/client.go @@ -1,6 +1,8 @@ package api import ( + "errors" + "fmt" "net/http" "net/http/cookiejar" "net/url" @@ -9,6 +11,10 @@ import ( vaultHttp "github.com/hashicorp/vault/http" ) +var ( + errRedirect = errors.New("redirect") +) + // Config is used to configure the creation of the client. type Config struct { // Address is the address of the Vault server. This should be a complete @@ -30,7 +36,7 @@ type Config struct { func DefaultConfig() Config { config := Config{ Address: "https://127.0.0.1:8200", - HttpClient: http.DefaultClient, + HttpClient: &http.Client{}, } return config @@ -64,6 +70,11 @@ func NewClient(c Config) (*Client, error) { c.HttpClient.Jar = jar } + // Ensure redirects are not automatically followed + c.HttpClient.CheckRedirect = func(req *http.Request, via []*http.Request) error { + return errRedirect + } + return &Client{ addr: u, config: c, @@ -128,6 +139,8 @@ func (c *Client) NewRequest(method, path string) *Request { // a Vault server not configured with this client. This is an advanced operation // that generally won't need to be called externally. func (c *Client) RawRequest(r *Request) (*Response, error) { + redirectCount := 0 +START: req, err := r.ToHTTP() if err != nil { return nil, err @@ -138,10 +151,46 @@ func (c *Client) RawRequest(r *Request) (*Response, error) { if resp != nil { result = &Response{Response: resp} } + if err != nil { + urlErr, ok := err.(*url.Error) + if ok && urlErr.Err == errRedirect { + err = nil + } + } if err != nil { return result, err } + // Check for a redirect, only allowing for a single redirect + if (resp.StatusCode == 302 || resp.StatusCode == 307) && redirectCount == 0 { + // Parse the updated location + respLoc, err := resp.Location() + if err != nil { + return result, err + } + + // Ensure a protocol downgrade doesn't happen + if req.URL.Scheme == "https" && respLoc.Scheme != "https" { + return result, fmt.Errorf("redirect would cause protocol downgrade") + } + + // Copy the cookies so that our client auth transfers + cookies := c.config.HttpClient.Jar.Cookies(r.URL) + c.config.HttpClient.Jar.SetCookies(respLoc, cookies) + + // Update the request + r.URL = respLoc + + // Reset the request body if any + if err := r.ResetJSONBody(); err != nil { + return result, err + } + + // Retry the request + redirectCount++ + goto START + } + if err := result.Error(); err != nil { return result, err } diff --git a/api/client_test.go b/api/client_test.go index e249516f3..9525e2112 100644 --- a/api/client_test.go +++ b/api/client_test.go @@ -1,6 +1,8 @@ package api import ( + "bytes" + "io" "net/http" "testing" "time" @@ -94,3 +96,49 @@ func TestClientSetToken(t *testing.T) { t.Fatalf("bad: %s", v) } } + +func TestClientRedirect(t *testing.T) { + primary := func(w http.ResponseWriter, req *http.Request) { + cookie, err := req.Cookie(vaultHttp.AuthCookieName) + if err != nil { + t.Fatalf("err: %s", err) + } + if cookie.Value != "foo" { + t.Fatalf("Bad: %#v", cookie) + } + + w.Write([]byte("test")) + } + config, ln := testHTTPServer(t, http.HandlerFunc(primary)) + defer ln.Close() + + standby := func(w http.ResponseWriter, req *http.Request) { + w.Header().Set("Location", config.Address) + w.WriteHeader(307) + } + config2, ln2 := testHTTPServer(t, http.HandlerFunc(standby)) + defer ln2.Close() + + client, err := NewClient(config2) + if err != nil { + t.Fatalf("err: %s", err) + } + + // Set the cookie manually + client.SetToken("foo") + + // Do a raw "/" request + resp, err := client.RawRequest(client.NewRequest("PUT", "/")) + if err != nil { + t.Fatalf("err: %s", err) + } + + // Copy the response + var buf bytes.Buffer + io.Copy(&buf, resp.Body) + + // Verify we got the response from the primary + if buf.String() != "test" { + t.Fatalf("Bad: %s", buf.String()) + } +}