open-vault/http/forwarded_for_test.go
Hamid Ghaf 27bb03bbc0
adding copyright header (#19555)
* adding copyright header

* fix fmt and a test
2023-03-15 09:00:52 -07:00

259 lines
7.7 KiB
Go

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package http
import (
"bytes"
"net/http"
"strings"
"testing"
sockaddr "github.com/hashicorp/go-sockaddr"
"github.com/hashicorp/vault/internalshared/configutil"
"github.com/hashicorp/vault/vault"
)
// getListenerConfigForMarshalerTest returns a listener configuration whose
// XForwardedForAuthorizedAddrs list contains exactly one entry, addr. Every
// other listener field is left at its zero value so callers can enable the
// X-Forwarded-For options they want to exercise.
func getListenerConfigForMarshalerTest(addr sockaddr.IPAddr) *configutil.Listener {
	authorized := &sockaddr.SockAddrMarshaler{SockAddr: addr}
	cfg := new(configutil.Listener)
	cfg.XForwardedForAuthorizedAddrs = []*sockaddr.SockAddrMarshaler{authorized}
	return cfg
}
// TestHandler_XForwardedFor exercises WrapForwardedForHandler against live
// test clusters. Each subtest starts its own cluster whose HTTP handler
// replies 200 with r.RemoteAddr as the body, so a successful response shows
// which client address the wrapper settled on, while a rejected request
// surfaces the wrapper's error text.
func TestHandler_XForwardedFor(t *testing.T) {
	// goodAddr matches the address the test client connects from (allow_unauth
	// shows 127.0.0.1 reaching the handler untouched), so authorizing it makes
	// the wrapper honor x-forwarded-for. badAddr is never the connecting peer.
	goodAddr, err := sockaddr.NewIPAddr("127.0.0.1")
	if err != nil {
		t.Fatal(err)
	}
	badAddr, err := sockaddr.NewIPAddr("1.2.3.4")
	if err != nil {
		t.Fatal(err)
	}

	// First: a request without the header is rejected; with the header, its
	// value replaces the remote address.
	t.Run("reject_not_present", func(t *testing.T) {
		t.Parallel()
		cluster := startForwardedForTestCluster(t, goodAddr, nil)

		_, err := forwardedForRequest(t, cluster)
		requireForwardedForError(t, err, "missing x-forwarded-for")

		body, err := forwardedForRequest(t, cluster, "1.2.3.4")
		requireForwardedForBody(t, body, err, "1.2.3.4:")
	})

	// Next: when the peer is not authorized and XForwardedForRejectNotAuthorized
	// is unset, the header is ignored and the real remote address is kept.
	t.Run("allow_unauth", func(t *testing.T) {
		t.Parallel()
		cluster := startForwardedForTestCluster(t, badAddr, nil)

		body, err := forwardedForRequest(t, cluster, "5.6.7.8")
		requireForwardedForBody(t, body, err, "127.0.0.1:")
	})

	// Next: the same unauthorized request fails once
	// XForwardedForRejectNotAuthorized is enabled.
	t.Run("fail_unauth", func(t *testing.T) {
		t.Parallel()
		cluster := startForwardedForTestCluster(t, badAddr, func(cfg *configutil.Listener) {
			cfg.XForwardedForRejectNotAuthorized = true
		})

		_, err := forwardedForRequest(t, cluster, "5.6.7.8")
		requireForwardedForError(t, err, "not authorized for x-forwarded-for")
	})

	// Next: asking to skip more hops than the header contains is an error.
	t.Run("too_many_hops", func(t *testing.T) {
		t.Parallel()
		cluster := startForwardedForTestCluster(t, goodAddr, func(cfg *configutil.Listener) {
			cfg.XForwardedForRejectNotAuthorized = true
			cfg.XForwardedForHopSkips = 4
		})

		_, err := forwardedForRequest(t, cluster, "2.3.4.5,3.4.5.6")
		requireForwardedForError(t, err, "would skip before earliest")
	})

	// Next: with one hop skip, the last entry is ignored and the one before it
	// becomes the remote address.
	t.Run("correct_hop_skipping", func(t *testing.T) {
		t.Parallel()
		cluster := startForwardedForTestCluster(t, goodAddr, func(cfg *configutil.Listener) {
			cfg.XForwardedForRejectNotAuthorized = true
			cfg.XForwardedForHopSkips = 1
		})

		body, err := forwardedForRequest(t, cluster, "2.3.4.5,3.4.5.6,4.5.6.7,5.6.7.8")
		requireForwardedForBody(t, body, err, "4.5.6.7:")
	})

	// Next: the same address list split across several header lines must
	// yield the same result as the single comma-separated header above.
	t.Run("correct_hop_skipping_multi_header", func(t *testing.T) {
		t.Parallel()
		cluster := startForwardedForTestCluster(t, goodAddr, func(cfg *configutil.Listener) {
			cfg.XForwardedForRejectNotAuthorized = true
			cfg.XForwardedForHopSkips = 1
		})

		body, err := forwardedForRequest(t, cluster, "2.3.4.5", "3.4.5.6,4.5.6.7", "5.6.7.8")
		requireForwardedForBody(t, body, err, "4.5.6.7:")
	})
}

// startForwardedForTestCluster starts a test cluster whose HTTP handler
// writes r.RemoteAddr back to the client, wrapped by WrapForwardedForHandler.
//
// The listener authorizes only the authorized address and always sets
// XForwardedForRejectNotPresent. configure, when non-nil, may adjust the
// listener further before the wrapper is built. The cluster is cleaned up
// when t completes.
func startForwardedForTestCluster(t *testing.T, authorized sockaddr.IPAddr, configure func(*configutil.Listener)) *vault.TestCluster {
	t.Helper()
	testHandler := func(props *vault.HandlerProperties) http.Handler {
		origHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			w.WriteHeader(http.StatusOK)
			w.Write([]byte(r.RemoteAddr))
		})
		listenerConfig := getListenerConfigForMarshalerTest(authorized)
		listenerConfig.XForwardedForRejectNotPresent = true
		if configure != nil {
			configure(listenerConfig)
		}
		return WrapForwardedForHandler(origHandler, listenerConfig)
	}
	cluster := vault.NewTestCluster(t, nil, &vault.TestClusterOptions{
		HandlerFunc: HandlerFunc(testHandler),
	})
	cluster.Start()
	t.Cleanup(cluster.Cleanup)
	return cluster
}

// forwardedForRequest sends GET / through the first core's client.
//
// Each value in xff is added as its own x-forwarded-for header line. With no
// values, the request's headers are left untouched. On success it returns the
// full response body. If the client reports an error, it returns that error
// and an empty body. A failure while reading the body fails t rather than
// being ignored.
func forwardedForRequest(t *testing.T, cluster *vault.TestCluster, xff ...string) (string, error) {
	t.Helper()
	client := cluster.Cores[0].Client
	req := client.NewRequest("GET", "/")
	if len(xff) > 0 {
		req.Headers = make(http.Header)
		for _, v := range xff {
			req.Headers.Add("x-forwarded-for", v)
		}
	}
	resp, err := client.RawRequest(req)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()

	var body bytes.Buffer
	if _, err := body.ReadFrom(resp.Body); err != nil {
		t.Fatalf("reading response body: %v", err)
	}
	return body.String(), nil
}

// requireForwardedForError fails t unless err is non-nil and its message
// contains substr.
func requireForwardedForError(t *testing.T, err error, substr string) {
	t.Helper()
	if err == nil {
		t.Fatal("expected error")
	}
	if !strings.Contains(err.Error(), substr) {
		t.Fatalf("bad error message: %v", err)
	}
}

// requireForwardedForBody fails t if err is non-nil or if body does not
// start with prefix.
func requireForwardedForBody(t *testing.T, body string, err error, prefix string) {
	t.Helper()
	if err != nil {
		t.Fatal(err)
	}
	if !strings.HasPrefix(body, prefix) {
		t.Fatalf("bad body: %s", body)
	}
}