Pass headers back when request forwarding (#1795)

This commit is contained in:
Jeff Mitchell 2016-08-26 17:53:47 -04:00 committed by GitHub
parent d9c46aadc2
commit 7e41d5ab45
5 changed files with 114 additions and 43 deletions

View File

@ -98,6 +98,9 @@ type Response struct {
// uint64 id = 1;
StatusCode uint32 `protobuf:"varint,2,opt,name=status_code,json=statusCode" json:"status_code,omitempty"`
Body []byte `protobuf:"bytes,3,opt,name=body,proto3" json:"body,omitempty"`
// Added in 0.6.2 to ensure that the content-type is set appropriately, as
// well as any other information
HeaderEntries map[string]*HeaderEntry `protobuf:"bytes,4,rep,name=header_entries,json=headerEntries" json:"header_entries,omitempty" protobuf_key:"bytes,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"`
}
func (m *Response) Reset() { *m = Response{} }
@ -105,6 +108,13 @@ func (m *Response) String() string { return proto.CompactTextString(m
func (*Response) ProtoMessage() {}
func (*Response) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{3} }
func (m *Response) GetHeaderEntries() map[string]*HeaderEntry {
if m != nil {
return m.HeaderEntries
}
return nil
}
func init() {
proto.RegisterType((*Request)(nil), "forwarding.Request")
proto.RegisterType((*URL)(nil), "forwarding.URL")
@ -115,31 +125,33 @@ func init() {
func init() { proto.RegisterFile("types.proto", fileDescriptor0) }
var fileDescriptor0 = []byte{
// 416 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0x64, 0x52, 0xcd, 0x6e, 0xd4, 0x30,
0x10, 0x56, 0xd6, 0xdb, 0xdd, 0x64, 0xd2, 0x42, 0xf1, 0x01, 0x4c, 0x39, 0x10, 0x56, 0x02, 0x45,
0x42, 0xe4, 0xb0, 0x5c, 0x10, 0x17, 0x84, 0x2a, 0x24, 0x0e, 0x45, 0x02, 0x4b, 0x3d, 0x70, 0x8a,
0xdc, 0x78, 0xb6, 0x89, 0x68, 0xe2, 0xac, 0xed, 0xb0, 0xca, 0x63, 0xf1, 0x4e, 0x3c, 0x08, 0xb2,
0x1d, 0xba, 0x91, 0x7a, 0xca, 0x7c, 0x3f, 0x99, 0xcc, 0x37, 0x13, 0x48, 0xed, 0xd8, 0xa3, 0x29,
0x7a, 0xad, 0xac, 0xa2, 0xb0, 0x53, 0xfa, 0x20, 0xb4, 0x6c, 0xba, 0xdb, 0xcd, 0xdf, 0x05, 0xac,
0x39, 0xee, 0x07, 0x34, 0x96, 0x3e, 0x85, 0x55, 0x8b, 0xb6, 0x56, 0x92, 0x2d, 0xb2, 0x28, 0x4f,
0xf8, 0x84, 0xe8, 0x2b, 0x20, 0x83, 0xbe, 0x63, 0x24, 0x8b, 0xf2, 0x74, 0xfb, 0xb8, 0x38, 0xbe,
0x5d, 0x5c, 0xf3, 0x2b, 0xee, 0x34, 0xfa, 0x0d, 0x1e, 0xd5, 0x28, 0x24, 0xea, 0x12, 0x3b, 0xab,
0x1b, 0x34, 0x6c, 0x99, 0x91, 0x3c, 0xdd, 0xbe, 0x99, 0xbb, 0xa7, 0xef, 0x14, 0x5f, 0xbd, 0xf3,
0x4b, 0x30, 0xba, 0xc7, 0xc8, 0xcf, 0xea, 0x39, 0x47, 0x29, 0x2c, 0x6f, 0x94, 0x1c, 0xd9, 0x49,
0x16, 0xe5, 0xa7, 0xdc, 0xd7, 0x8e, 0xab, 0x95, 0xb1, 0x6c, 0xe5, 0x67, 0xf3, 0x35, 0x7d, 0x09,
0xa9, 0xc6, 0x56, 0x59, 0x2c, 0x85, 0x94, 0x9a, 0xad, 0xbd, 0x04, 0x81, 0xfa, 0x2c, 0xa5, 0xa6,
0x6f, 0xe1, 0x49, 0x8f, 0xa8, 0xcb, 0x0a, 0xb5, 0x6d, 0x76, 0x4d, 0x25, 0x2c, 0x1a, 0x16, 0x67,
0x24, 0x3f, 0xe5, 0xe7, 0x4e, 0xb8, 0x9c, 0xf1, 0x17, 0x3f, 0x81, 0x3e, 0x1c, 0x8d, 0x9e, 0x03,
0xf9, 0x85, 0x23, 0x8b, 0x7c, 0x6f, 0x57, 0xd2, 0x77, 0x70, 0xf2, 0x5b, 0xdc, 0x0d, 0xe8, 0xd7,
0x94, 0x6e, 0x9f, 0xcd, 0x33, 0x1e, 0x1b, 0x8c, 0x3c, 0xb8, 0x3e, 0x2e, 0x3e, 0x44, 0x9b, 0x3f,
0x11, 0x90, 0x6b, 0x7e, 0xe5, 0x56, 0x6c, 0xaa, 0x1a, 0x5b, 0x9c, 0xfa, 0x4d, 0xc8, 0xf1, 0xaa,
0x17, 0xfb, 0xa9, 0x67, 0xc2, 0x27, 0x74, 0x1f, 0x7a, 0x39, 0x0b, 0x4d, 0x61, 0xd9, 0x0b, 0x5b,
0xfb, 0xe5, 0x24, 0xdc, 0xd7, 0xf4, 0x39, 0xc4, 0x5a, 0x1c, 0x4a, 0xcf, 0x87, 0x05, 0xad, 0xb5,
0x38, 0x7c, 0x77, 0xd2, 0x0b, 0x48, 0x9c, 0xb4, 0x1f, 0x50, 0x8f, 0x2c, 0xf6, 0x9a, 0xf3, 0xfe,
0x70, 0x98, 0x5e, 0x40, 0xbc, 0xd3, 0xe2, 0xb6, 0xc5, 0xce, 0xb2, 0x24, 0x68, 0xff, 0xf1, 0xe6,
0x35, 0xa4, 0xb3, 0x34, 0x6e, 0x44, 0x9f, 0xc7, 0xb0, 0x28, 0x23, 0x6e, 0xc4, 0x80, 0x36, 0x9f,
0x20, 0xe6, 0x68, 0x7a, 0xd5, 0x19, 0x74, 0xf7, 0x30, 0x56, 0xd8, 0xc1, 0x94, 0x95, 0x92, 0x21,
0xcb, 0x19, 0x87, 0x40, 0x5d, 0x2a, 0x89, 0xf7, 0x87, 0x25, 0xc7, 0xc3, 0xde, 0xac, 0xfc, 0x5f,
0xf9, 0xfe, 0x5f, 0x00, 0x00, 0x00, 0xff, 0xff, 0xb9, 0x58, 0xe7, 0xdc, 0xa4, 0x02, 0x00, 0x00,
// 437 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xb4, 0x53, 0xc1, 0x6e, 0xd4, 0x30,
0x10, 0x95, 0xd7, 0xdb, 0xdd, 0x64, 0xd2, 0x42, 0xf1, 0x01, 0x4c, 0x39, 0x10, 0x56, 0x02, 0x22,
0x21, 0xf6, 0xb0, 0x5c, 0x10, 0x37, 0x54, 0x21, 0x71, 0x28, 0x08, 0x2c, 0xf5, 0xc0, 0x29, 0x72,
0xd7, 0xb3, 0xcd, 0x8a, 0x26, 0xce, 0xda, 0x0e, 0xab, 0x7c, 0x16, 0xff, 0xc4, 0x89, 0xaf, 0x40,
0xb6, 0x43, 0x1b, 0x84, 0x10, 0xa7, 0x9e, 0x76, 0xde, 0x7b, 0xb3, 0xe3, 0x79, 0x33, 0x13, 0xc8,
0x5c, 0xdf, 0xa2, 0x5d, 0xb6, 0x46, 0x3b, 0xcd, 0x60, 0xa3, 0xcd, 0x5e, 0x1a, 0xb5, 0x6d, 0x2e,
0x17, 0x3f, 0x26, 0x30, 0x17, 0xb8, 0xeb, 0xd0, 0x3a, 0x76, 0x1f, 0x66, 0x35, 0xba, 0x4a, 0x2b,
0x3e, 0xc9, 0x49, 0x91, 0x8a, 0x01, 0xb1, 0x27, 0x40, 0x3b, 0x73, 0xc5, 0x69, 0x4e, 0x8a, 0x6c,
0x75, 0x77, 0x79, 0xf3, 0xef, 0xe5, 0xb9, 0x38, 0x13, 0x5e, 0x63, 0x1f, 0xe0, 0x4e, 0x85, 0x52,
0xa1, 0x29, 0xb1, 0x71, 0x66, 0x8b, 0x96, 0x4f, 0x73, 0x5a, 0x64, 0xab, 0x67, 0xe3, 0xec, 0xe1,
0x9d, 0xe5, 0xfb, 0x90, 0xf9, 0x2e, 0x26, 0xfa, 0x9f, 0x5e, 0x1c, 0x55, 0x63, 0x8e, 0x31, 0x98,
0x5e, 0x68, 0xd5, 0xf3, 0x83, 0x9c, 0x14, 0x87, 0x22, 0xc4, 0x9e, 0xab, 0xb4, 0x75, 0x7c, 0x16,
0x7a, 0x0b, 0x31, 0x7b, 0x0c, 0x99, 0xc1, 0x5a, 0x3b, 0x2c, 0xa5, 0x52, 0x86, 0xcf, 0x83, 0x04,
0x91, 0x7a, 0xab, 0x94, 0x61, 0x2f, 0xe0, 0x5e, 0x8b, 0x68, 0xca, 0x35, 0x1a, 0xb7, 0xdd, 0x6c,
0xd7, 0xd2, 0xa1, 0xe5, 0x49, 0x4e, 0x8b, 0x43, 0x71, 0xec, 0x85, 0xd3, 0x11, 0x7f, 0xf2, 0x05,
0xd8, 0xdf, 0xad, 0xb1, 0x63, 0xa0, 0x5f, 0xb1, 0xe7, 0x24, 0xd4, 0xf6, 0x21, 0x7b, 0x09, 0x07,
0xdf, 0xe4, 0x55, 0x87, 0x61, 0x4c, 0xd9, 0xea, 0xc1, 0xd8, 0xe3, 0x4d, 0x81, 0x5e, 0xc4, 0xac,
0x37, 0x93, 0xd7, 0x64, 0xf1, 0x9d, 0x00, 0x3d, 0x17, 0x67, 0x7e, 0xc4, 0x76, 0x5d, 0x61, 0x8d,
0x43, 0xbd, 0x01, 0x79, 0x5e, 0xb7, 0x72, 0x37, 0xd4, 0x4c, 0xc5, 0x80, 0xae, 0x4d, 0x4f, 0x47,
0xa6, 0x19, 0x4c, 0x5b, 0xe9, 0xaa, 0x30, 0x9c, 0x54, 0x84, 0x98, 0x3d, 0x84, 0xc4, 0xc8, 0x7d,
0x19, 0xf8, 0x38, 0xa0, 0xb9, 0x91, 0xfb, 0x4f, 0x5e, 0x7a, 0x04, 0xa9, 0x97, 0x76, 0x1d, 0x9a,
0x9e, 0x27, 0x41, 0xf3, 0xb9, 0x9f, 0x3d, 0x66, 0x27, 0x90, 0x6c, 0x8c, 0xbc, 0xac, 0xb1, 0x71,
0x3c, 0x8d, 0xda, 0x6f, 0xbc, 0x78, 0x0a, 0xd9, 0xc8, 0x8d, 0x6f, 0x31, 0xf8, 0xb1, 0x9c, 0xe4,
0xd4, 0xb7, 0x18, 0xd1, 0xe2, 0x27, 0x81, 0x44, 0xa0, 0x6d, 0x75, 0x63, 0xd1, 0x2f, 0xc4, 0x3a,
0xe9, 0x3a, 0x5b, 0xae, 0xb5, 0x8a, 0x66, 0x8e, 0x04, 0x44, 0xea, 0x54, 0x2b, 0xbc, 0xde, 0x2c,
0x1d, 0x6d, 0xf6, 0xe3, 0x3f, 0x8e, 0xe7, 0xf9, 0x9f, 0xc7, 0x13, 0x9f, 0xf8, 0xff, 0xf5, 0xdc,
0xe2, 0x1e, 0x2f, 0x66, 0xe1, 0x0b, 0x7a, 0xf5, 0x2b, 0x00, 0x00, 0xff, 0xff, 0x57, 0x73, 0xdf,
0x6b, 0x50, 0x03, 0x00, 0x00,
}

View File

@ -40,4 +40,7 @@ message Response {
//uint64 id = 1;
uint32 status_code = 2;
bytes body = 3;
// Added in 0.6.2 to ensure that the content-type is set appropriately, as
// well as any other information
map<string, HeaderEntry> header_entries = 4;
}

View File

@ -136,7 +136,7 @@ func handleRequestForwarding(core *vault.Core, handler http.Handler) http.Handle
// Attempt forwarding the request. If we cannot forward -- perhaps it's
// been disabled on the active node -- this will return with an
// ErrCannotForward and we simply fall back
statusCode, retBytes, err := core.ForwardRequest(r)
statusCode, header, retBytes, err := core.ForwardRequest(r)
if err != nil {
if err == vault.ErrCannotForward {
core.Logger().Trace("http/handleRequestForwarding: cannot forward (possibly disabled on active node), falling back")
@ -149,6 +149,15 @@ func handleRequestForwarding(core *vault.Core, handler http.Handler) http.Handle
return
}
if header != nil {
for k, v := range header {
for _, j := range v {
core.Logger().Trace("writing header %v %v", k, j)
w.Header().Add(k, j)
}
}
}
w.WriteHeader(statusCode)
w.Write(retBytes)
return

View File

@ -6,6 +6,7 @@ import (
"fmt"
"net"
"net/http"
"os"
"testing"
"time"
@ -168,18 +169,33 @@ func TestCluster_ForwardRequests(t *testing.T) {
// Make this nicer for tests
manualStepDownSleepPeriod = 5 * time.Second
testCluster_ForwardRequestsCommon(t, false)
testCluster_ForwardRequestsCommon(t, true)
os.Setenv("VAULT_USE_GRPC_REQUEST_FORWARDING", "")
}
func testCluster_ForwardRequestsCommon(t *testing.T, rpc bool) {
if rpc {
os.Setenv("VAULT_USE_GRPC_REQUEST_FORWARDING", "1")
} else {
os.Setenv("VAULT_USE_GRPC_REQUEST_FORWARDING", "")
}
handler1 := http.NewServeMux()
handler1.HandleFunc("/core1", func(w http.ResponseWriter, req *http.Request) {
w.Header().Add("Content-Type", "application/json")
w.WriteHeader(201)
w.Write([]byte("core1"))
})
handler2 := http.NewServeMux()
handler2.HandleFunc("/core2", func(w http.ResponseWriter, req *http.Request) {
w.Header().Add("Content-Type", "application/json")
w.WriteHeader(202)
w.Write([]byte("core2"))
})
handler3 := http.NewServeMux()
handler3.HandleFunc("/core3", func(w http.ResponseWriter, req *http.Request) {
w.Header().Add("Content-Type", "application/json")
w.WriteHeader(203)
w.Write([]byte("core3"))
})
@ -331,10 +347,18 @@ func testCluster_ForwardRequests(t *testing.T, c *TestClusterCore, remoteCoreID
}
req.Header.Add("X-Vault-Token", c.Root)
statusCode, respBytes, err := c.ForwardRequest(req)
statusCode, header, respBytes, err := c.ForwardRequest(req)
if err != nil {
t.Fatalf("err: %v", err)
}
if header == nil {
t.Fatal("err: expected at least a content-type header")
}
if header.Get("Content-Type") != "application/json" {
t.Fatalf("bad content-type: %s", header.Get("Content-Type"))
}
body := string(respBytes)
if body != remoteCoreID {

View File

@ -251,30 +251,30 @@ func (c *Core) clearForwardingClients() {
// ForwardRequest forwards a given request to the active node and returns the
// response.
func (c *Core) ForwardRequest(req *http.Request) (int, []byte, error) {
func (c *Core) ForwardRequest(req *http.Request) (int, http.Header, []byte, error) {
c.requestForwardingConnectionLock.RLock()
defer c.requestForwardingConnectionLock.RUnlock()
switch os.Getenv("VAULT_USE_GRPC_REQUEST_FORWARDING") {
case "":
if c.requestForwardingConnection == nil {
return 0, nil, ErrCannotForward
return 0, nil, nil, ErrCannotForward
}
if c.requestForwardingConnection.clusterAddr == "" {
return 0, nil, ErrCannotForward
return 0, nil, nil, ErrCannotForward
}
freq, err := forwarding.GenerateForwardedHTTPRequest(req, c.requestForwardingConnection.clusterAddr+"/cluster/local/forwarded-request")
if err != nil {
c.logger.Error("core/ForwardRequest: error creating forwarded request", "error", err)
return 0, nil, fmt.Errorf("error creating forwarding request")
return 0, nil, nil, fmt.Errorf("error creating forwarding request")
}
//resp, err := c.requestForwardingConnection.Do(freq)
resp, err := c.requestForwardingConnection.transport.RoundTrip(freq)
if err != nil {
return 0, nil, err
return 0, nil, nil, err
}
defer resp.Body.Close()
@ -283,30 +283,41 @@ func (c *Core) ForwardRequest(req *http.Request) (int, []byte, error) {
buf := bytes.NewBuffer(nil)
_, err = buf.ReadFrom(resp.Body)
if err != nil {
return 0, nil, err
return 0, nil, nil, err
}
return resp.StatusCode, buf.Bytes(), nil
return resp.StatusCode, resp.Header, buf.Bytes(), nil
default:
if c.rpcForwardingClient == nil {
return 0, nil, ErrCannotForward
return 0, nil, nil, ErrCannotForward
}
freq, err := forwarding.GenerateForwardedRequest(req)
if err != nil {
c.logger.Error("core/ForwardRequest: error creating forwarding RPC request", "error", err)
return 0, nil, fmt.Errorf("error creating forwarding RPC request")
return 0, nil, nil, fmt.Errorf("error creating forwarding RPC request")
}
if freq == nil {
c.logger.Error("core/ForwardRequest: got nil forwarding RPC request")
return 0, nil, fmt.Errorf("got nil forwarding RPC request")
return 0, nil, nil, fmt.Errorf("got nil forwarding RPC request")
}
resp, err := c.rpcForwardingClient.HandleRequest(context.Background(), freq, grpc.FailFast(true))
if err != nil {
c.logger.Error("core/ForwardRequest: error during forwarded RPC request", "error", err)
return 0, nil, fmt.Errorf("error during forwarding RPC request")
return 0, nil, nil, fmt.Errorf("error during forwarding RPC request")
}
return int(resp.StatusCode), resp.Body, nil
var header http.Header
if resp.HeaderEntries != nil {
header = make(http.Header)
for k, v := range resp.HeaderEntries {
for _, j := range v.Values {
header.Add(k, j)
}
}
}
return int(resp.StatusCode), header, resp.Body, nil
}
}
@ -347,8 +358,20 @@ func (s *forwardedRequestRPCServer) HandleRequest(ctx context.Context, freq *for
s.handler.ServeHTTP(w, req)
return &forwarding.Response{
resp := &forwarding.Response{
StatusCode: uint32(w.StatusCode()),
Body: w.Body().Bytes(),
}, nil
}
header := w.Header()
if header != nil {
resp.HeaderEntries = make(map[string]*forwarding.HeaderEntry, len(header))
for k, v := range header {
resp.HeaderEntries[k] = &forwarding.HeaderEntry{
Values: v,
}
}
}
return resp, nil
}