250 lines
6.7 KiB
Go
250 lines
6.7 KiB
Go
package http
|
|
|
|
import (
|
|
"bytes"
|
|
"net/http"
|
|
"strings"
|
|
"testing"
|
|
|
|
sockaddr "github.com/hashicorp/go-sockaddr"
|
|
"github.com/hashicorp/vault/vault"
|
|
)
|
|
|
|
func TestHandler_XForwardedFor(t *testing.T) {
|
|
goodAddr, err := sockaddr.NewIPAddr("127.0.0.1")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
badAddr, err := sockaddr.NewIPAddr("1.2.3.4")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
// First: test reject not present
|
|
t.Run("reject_not_present", func(t *testing.T) {
|
|
t.Parallel()
|
|
testHandler := func(c *vault.Core) http.Handler {
|
|
origHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write([]byte(r.RemoteAddr))
|
|
})
|
|
return WrapForwardedForHandler(origHandler, []*sockaddr.SockAddrMarshaler{
|
|
&sockaddr.SockAddrMarshaler{
|
|
SockAddr: goodAddr,
|
|
},
|
|
}, true, false, 0)
|
|
}
|
|
|
|
cluster := vault.NewTestCluster(t, nil, &vault.TestClusterOptions{
|
|
HandlerFunc: testHandler,
|
|
})
|
|
cluster.Start()
|
|
defer cluster.Cleanup()
|
|
client := cluster.Cores[0].Client
|
|
|
|
req := client.NewRequest("GET", "/")
|
|
_, err = client.RawRequest(req)
|
|
if err == nil {
|
|
t.Fatal("expected error")
|
|
}
|
|
if !strings.Contains(err.Error(), "missing x-forwarded-for") {
|
|
t.Fatalf("bad error message: %v", err)
|
|
}
|
|
req = client.NewRequest("GET", "/")
|
|
req.Headers = make(http.Header)
|
|
req.Headers.Set("x-forwarded-for", "1.2.3.4")
|
|
resp, err := client.RawRequest(req)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer resp.Body.Close()
|
|
buf := bytes.NewBuffer(nil)
|
|
buf.ReadFrom(resp.Body)
|
|
if !strings.HasPrefix(buf.String(), "1.2.3.4:") {
|
|
t.Fatalf("bad body: %s", buf.String())
|
|
}
|
|
})
|
|
|
|
// Next: test allow unauth
|
|
t.Run("allow_unauth", func(t *testing.T) {
|
|
t.Parallel()
|
|
testHandler := func(c *vault.Core) http.Handler {
|
|
origHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write([]byte(r.RemoteAddr))
|
|
})
|
|
return WrapForwardedForHandler(origHandler, []*sockaddr.SockAddrMarshaler{
|
|
&sockaddr.SockAddrMarshaler{
|
|
SockAddr: badAddr,
|
|
},
|
|
}, true, false, 0)
|
|
}
|
|
|
|
cluster := vault.NewTestCluster(t, nil, &vault.TestClusterOptions{
|
|
HandlerFunc: testHandler,
|
|
})
|
|
cluster.Start()
|
|
defer cluster.Cleanup()
|
|
client := cluster.Cores[0].Client
|
|
|
|
req := client.NewRequest("GET", "/")
|
|
req.Headers = make(http.Header)
|
|
req.Headers.Set("x-forwarded-for", "5.6.7.8")
|
|
resp, err := client.RawRequest(req)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer resp.Body.Close()
|
|
buf := bytes.NewBuffer(nil)
|
|
buf.ReadFrom(resp.Body)
|
|
if !strings.HasPrefix(buf.String(), "127.0.0.1:") {
|
|
t.Fatalf("bad body: %s", buf.String())
|
|
}
|
|
})
|
|
|
|
// Next: test fail unauth
|
|
t.Run("fail_unauth", func(t *testing.T) {
|
|
t.Parallel()
|
|
testHandler := func(c *vault.Core) http.Handler {
|
|
origHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write([]byte(r.RemoteAddr))
|
|
})
|
|
return WrapForwardedForHandler(origHandler, []*sockaddr.SockAddrMarshaler{
|
|
&sockaddr.SockAddrMarshaler{
|
|
SockAddr: badAddr,
|
|
},
|
|
}, true, true, 0)
|
|
}
|
|
|
|
cluster := vault.NewTestCluster(t, nil, &vault.TestClusterOptions{
|
|
HandlerFunc: testHandler,
|
|
})
|
|
cluster.Start()
|
|
defer cluster.Cleanup()
|
|
client := cluster.Cores[0].Client
|
|
|
|
req := client.NewRequest("GET", "/")
|
|
req.Headers = make(http.Header)
|
|
req.Headers.Set("x-forwarded-for", "5.6.7.8")
|
|
_, err = client.RawRequest(req)
|
|
if err == nil {
|
|
t.Fatal("expected error")
|
|
}
|
|
if !strings.Contains(err.Error(), "not authorized for x-forwarded-for") {
|
|
t.Fatalf("bad error message: %v", err)
|
|
}
|
|
})
|
|
|
|
// Next: test bad hops (too many)
|
|
t.Run("too_many_hops", func(t *testing.T) {
|
|
t.Parallel()
|
|
testHandler := func(c *vault.Core) http.Handler {
|
|
origHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write([]byte(r.RemoteAddr))
|
|
})
|
|
return WrapForwardedForHandler(origHandler, []*sockaddr.SockAddrMarshaler{
|
|
&sockaddr.SockAddrMarshaler{
|
|
SockAddr: goodAddr,
|
|
},
|
|
}, true, true, 4)
|
|
}
|
|
|
|
cluster := vault.NewTestCluster(t, nil, &vault.TestClusterOptions{
|
|
HandlerFunc: testHandler,
|
|
})
|
|
cluster.Start()
|
|
defer cluster.Cleanup()
|
|
client := cluster.Cores[0].Client
|
|
|
|
req := client.NewRequest("GET", "/")
|
|
req.Headers = make(http.Header)
|
|
req.Headers.Set("x-forwarded-for", "2.3.4.5,3.4.5.6")
|
|
_, err = client.RawRequest(req)
|
|
if err == nil {
|
|
t.Fatal("expected error")
|
|
}
|
|
if !strings.Contains(err.Error(), "would skip before earliest") {
|
|
t.Fatalf("bad error message: %v", err)
|
|
}
|
|
})
|
|
|
|
// Next: test picking correct value
|
|
t.Run("correct_hop_skipping", func(t *testing.T) {
|
|
t.Parallel()
|
|
testHandler := func(c *vault.Core) http.Handler {
|
|
origHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write([]byte(r.RemoteAddr))
|
|
})
|
|
return WrapForwardedForHandler(origHandler, []*sockaddr.SockAddrMarshaler{
|
|
&sockaddr.SockAddrMarshaler{
|
|
SockAddr: goodAddr,
|
|
},
|
|
}, true, true, 1)
|
|
}
|
|
|
|
cluster := vault.NewTestCluster(t, nil, &vault.TestClusterOptions{
|
|
HandlerFunc: testHandler,
|
|
})
|
|
cluster.Start()
|
|
defer cluster.Cleanup()
|
|
client := cluster.Cores[0].Client
|
|
|
|
req := client.NewRequest("GET", "/")
|
|
req.Headers = make(http.Header)
|
|
req.Headers.Set("x-forwarded-for", "2.3.4.5,3.4.5.6,4.5.6.7,5.6.7.8")
|
|
resp, err := client.RawRequest(req)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer resp.Body.Close()
|
|
buf := bytes.NewBuffer(nil)
|
|
buf.ReadFrom(resp.Body)
|
|
if !strings.HasPrefix(buf.String(), "4.5.6.7:") {
|
|
t.Fatalf("bad body: %s", buf.String())
|
|
}
|
|
})
|
|
|
|
// Next: multi-header approach
|
|
t.Run("correct_hop_skipping_multi_header", func(t *testing.T) {
|
|
t.Parallel()
|
|
testHandler := func(c *vault.Core) http.Handler {
|
|
origHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write([]byte(r.RemoteAddr))
|
|
})
|
|
return WrapForwardedForHandler(origHandler, []*sockaddr.SockAddrMarshaler{
|
|
&sockaddr.SockAddrMarshaler{
|
|
SockAddr: goodAddr,
|
|
},
|
|
}, true, true, 1)
|
|
}
|
|
|
|
cluster := vault.NewTestCluster(t, nil, &vault.TestClusterOptions{
|
|
HandlerFunc: testHandler,
|
|
})
|
|
cluster.Start()
|
|
defer cluster.Cleanup()
|
|
client := cluster.Cores[0].Client
|
|
|
|
req := client.NewRequest("GET", "/")
|
|
req.Headers = make(http.Header)
|
|
req.Headers.Add("x-forwarded-for", "2.3.4.5")
|
|
req.Headers.Add("x-forwarded-for", "3.4.5.6,4.5.6.7")
|
|
req.Headers.Add("x-forwarded-for", "5.6.7.8")
|
|
resp, err := client.RawRequest(req)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer resp.Body.Close()
|
|
buf := bytes.NewBuffer(nil)
|
|
buf.ReadFrom(resp.Body)
|
|
if !strings.HasPrefix(buf.String(), "4.5.6.7:") {
|
|
t.Fatalf("bad body: %s", buf.String())
|
|
}
|
|
})
|
|
}
|