api: Support redirect for HA

This commit is contained in:
Armon Dadgar 2015-04-20 11:30:35 -07:00
parent 57f3ceac14
commit fbaca87f56
2 changed files with 98 additions and 1 deletions

View file

@ -1,6 +1,8 @@
package api
import (
"errors"
"fmt"
"net/http"
"net/http/cookiejar"
"net/url"
@ -9,6 +11,10 @@ import (
vaultHttp "github.com/hashicorp/vault/http"
)
var (
errRedirect = errors.New("redirect")
)
// Config is used to configure the creation of the client.
type Config struct {
// Address is the address of the Vault server. This should be a complete
@ -30,7 +36,7 @@ type Config struct {
func DefaultConfig() Config {
config := Config{
Address: "https://127.0.0.1:8200",
HttpClient: http.DefaultClient,
HttpClient: &http.Client{},
}
return config
@ -64,6 +70,11 @@ func NewClient(c Config) (*Client, error) {
c.HttpClient.Jar = jar
}
// Ensure redirects are not automatically followed
c.HttpClient.CheckRedirect = func(req *http.Request, via []*http.Request) error {
return errRedirect
}
return &Client{
addr: u,
config: c,
@ -128,6 +139,8 @@ func (c *Client) NewRequest(method, path string) *Request {
// a Vault server not configured with this client. This is an advanced operation
// that generally won't need to be called externally.
func (c *Client) RawRequest(r *Request) (*Response, error) {
redirectCount := 0
START:
req, err := r.ToHTTP()
if err != nil {
return nil, err
@ -138,10 +151,46 @@ func (c *Client) RawRequest(r *Request) (*Response, error) {
if resp != nil {
result = &Response{Response: resp}
}
if err != nil {
urlErr, ok := err.(*url.Error)
if ok && urlErr.Err == errRedirect {
err = nil
}
}
if err != nil {
return result, err
}
// Check for a redirect, only allowing for a single redirect
if (resp.StatusCode == 302 || resp.StatusCode == 307) && redirectCount == 0 {
// Parse the updated location
respLoc, err := resp.Location()
if err != nil {
return result, err
}
// Ensure a protocol downgrade doesn't happen
if req.URL.Scheme == "https" && respLoc.Scheme != "https" {
return result, fmt.Errorf("redirect would cause protocol downgrade")
}
// Copy the cookies so that our client auth transfers
cookies := c.config.HttpClient.Jar.Cookies(r.URL)
c.config.HttpClient.Jar.SetCookies(respLoc, cookies)
// Update the request
r.URL = respLoc
// Reset the request body if any
if err := r.ResetJSONBody(); err != nil {
return result, err
}
// Retry the request
redirectCount++
goto START
}
if err := result.Error(); err != nil {
return result, err
}

View file

@ -1,6 +1,8 @@
package api
import (
"bytes"
"io"
"net/http"
"testing"
"time"
@ -94,3 +96,49 @@ func TestClientSetToken(t *testing.T) {
t.Fatalf("bad: %s", v)
}
}
func TestClientRedirect(t *testing.T) {
primary := func(w http.ResponseWriter, req *http.Request) {
cookie, err := req.Cookie(vaultHttp.AuthCookieName)
if err != nil {
t.Fatalf("err: %s", err)
}
if cookie.Value != "foo" {
t.Fatalf("Bad: %#v", cookie)
}
w.Write([]byte("test"))
}
config, ln := testHTTPServer(t, http.HandlerFunc(primary))
defer ln.Close()
standby := func(w http.ResponseWriter, req *http.Request) {
w.Header().Set("Location", config.Address)
w.WriteHeader(307)
}
config2, ln2 := testHTTPServer(t, http.HandlerFunc(standby))
defer ln2.Close()
client, err := NewClient(config2)
if err != nil {
t.Fatalf("err: %s", err)
}
// Set the cookie manually
client.SetToken("foo")
// Do a raw "/" request
resp, err := client.RawRequest(client.NewRequest("PUT", "/"))
if err != nil {
t.Fatalf("err: %s", err)
}
// Copy the response
var buf bytes.Buffer
io.Copy(&buf, resp.Body)
// Verify we got the response from the primary
if buf.String() != "test" {
t.Fatalf("Bad: %s", buf.String())
}
}