api: Support redirect for HA
This commit is contained in:
parent
57f3ceac14
commit
fbaca87f56
|
@ -1,6 +1,8 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/cookiejar"
|
||||
"net/url"
|
||||
|
@ -9,6 +11,10 @@ import (
|
|||
vaultHttp "github.com/hashicorp/vault/http"
|
||||
)
|
||||
|
||||
var (
|
||||
errRedirect = errors.New("redirect")
|
||||
)
|
||||
|
||||
// Config is used to configure the creation of the client.
|
||||
type Config struct {
|
||||
// Address is the address of the Vault server. This should be a complete
|
||||
|
@ -30,7 +36,7 @@ type Config struct {
|
|||
func DefaultConfig() Config {
|
||||
config := Config{
|
||||
Address: "https://127.0.0.1:8200",
|
||||
HttpClient: http.DefaultClient,
|
||||
HttpClient: &http.Client{},
|
||||
}
|
||||
|
||||
return config
|
||||
|
@ -64,6 +70,11 @@ func NewClient(c Config) (*Client, error) {
|
|||
c.HttpClient.Jar = jar
|
||||
}
|
||||
|
||||
// Ensure redirects are not automatically followed
|
||||
c.HttpClient.CheckRedirect = func(req *http.Request, via []*http.Request) error {
|
||||
return errRedirect
|
||||
}
|
||||
|
||||
return &Client{
|
||||
addr: u,
|
||||
config: c,
|
||||
|
@ -128,6 +139,8 @@ func (c *Client) NewRequest(method, path string) *Request {
|
|||
// a Vault server not configured with this client. This is an advanced operation
|
||||
// that generally won't need to be called externally.
|
||||
func (c *Client) RawRequest(r *Request) (*Response, error) {
|
||||
redirectCount := 0
|
||||
START:
|
||||
req, err := r.ToHTTP()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -138,10 +151,46 @@ func (c *Client) RawRequest(r *Request) (*Response, error) {
|
|||
if resp != nil {
|
||||
result = &Response{Response: resp}
|
||||
}
|
||||
if err != nil {
|
||||
urlErr, ok := err.(*url.Error)
|
||||
if ok && urlErr.Err == errRedirect {
|
||||
err = nil
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
|
||||
// Check for a redirect, only allowing for a single redirect
|
||||
if (resp.StatusCode == 302 || resp.StatusCode == 307) && redirectCount == 0 {
|
||||
// Parse the updated location
|
||||
respLoc, err := resp.Location()
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
|
||||
// Ensure a protocol downgrade doesn't happen
|
||||
if req.URL.Scheme == "https" && respLoc.Scheme != "https" {
|
||||
return result, fmt.Errorf("redirect would cause protocol downgrade")
|
||||
}
|
||||
|
||||
// Copy the cookies so that our client auth transfers
|
||||
cookies := c.config.HttpClient.Jar.Cookies(r.URL)
|
||||
c.config.HttpClient.Jar.SetCookies(respLoc, cookies)
|
||||
|
||||
// Update the request
|
||||
r.URL = respLoc
|
||||
|
||||
// Reset the request body if any
|
||||
if err := r.ResetJSONBody(); err != nil {
|
||||
return result, err
|
||||
}
|
||||
|
||||
// Retry the request
|
||||
redirectCount++
|
||||
goto START
|
||||
}
|
||||
|
||||
if err := result.Error(); err != nil {
|
||||
return result, err
|
||||
}
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
@ -94,3 +96,49 @@ func TestClientSetToken(t *testing.T) {
|
|||
t.Fatalf("bad: %s", v)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientRedirect(t *testing.T) {
|
||||
primary := func(w http.ResponseWriter, req *http.Request) {
|
||||
cookie, err := req.Cookie(vaultHttp.AuthCookieName)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
if cookie.Value != "foo" {
|
||||
t.Fatalf("Bad: %#v", cookie)
|
||||
}
|
||||
|
||||
w.Write([]byte("test"))
|
||||
}
|
||||
config, ln := testHTTPServer(t, http.HandlerFunc(primary))
|
||||
defer ln.Close()
|
||||
|
||||
standby := func(w http.ResponseWriter, req *http.Request) {
|
||||
w.Header().Set("Location", config.Address)
|
||||
w.WriteHeader(307)
|
||||
}
|
||||
config2, ln2 := testHTTPServer(t, http.HandlerFunc(standby))
|
||||
defer ln2.Close()
|
||||
|
||||
client, err := NewClient(config2)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
// Set the cookie manually
|
||||
client.SetToken("foo")
|
||||
|
||||
// Do a raw "/" request
|
||||
resp, err := client.RawRequest(client.NewRequest("PUT", "/"))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
// Copy the response
|
||||
var buf bytes.Buffer
|
||||
io.Copy(&buf, resp.Body)
|
||||
|
||||
// Verify we got the response from the primary
|
||||
if buf.String() != "test" {
|
||||
t.Fatalf("Bad: %s", buf.String())
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue