12e1b609ac
Create global quotas of each type in every NewTestCluster. Also switch some key locks to use DeadlockMutex to make it easier to discover deadlocks in testing. NewTestCluster also now starts the cluster, and the Start method becomes a no-op. Unless SkipInit is provided, we also wait for a node to become active, eliminating the need for WaitForActiveNode. This was needed because otherwise we can't safely make the quota API call. We can't do it in Start because Start doesn't return an error, and I didn't want to begin storing the testing object T inside TestCluster just so we could call t.Fatal inside Start. The last change here addresses the problem of how to skip setting up quotas when creating a cluster with a nonstandard handler that might not even implement the quotas endpoint. The challenge is that because we were taking a func pointer to generate the real handler func, we didn't have any way to compare that func pointer to the standard handler-generating func http.Handler without creating a circular dependency between packages vault and http. The solution was to pass a method instead of an anonymous func pointer so that we can do reflection on it.
256 lines
7.6 KiB
Go
256 lines
7.6 KiB
Go
package http
|
|
|
|
import (
|
|
"bytes"
|
|
"net/http"
|
|
"strings"
|
|
"testing"
|
|
|
|
sockaddr "github.com/hashicorp/go-sockaddr"
|
|
"github.com/hashicorp/vault/internalshared/configutil"
|
|
"github.com/hashicorp/vault/vault"
|
|
)
|
|
|
|
func getListenerConfigForMarshalerTest(addr sockaddr.IPAddr) *configutil.Listener {
|
|
return &configutil.Listener{
|
|
XForwardedForAuthorizedAddrs: []*sockaddr.SockAddrMarshaler{
|
|
{
|
|
SockAddr: addr,
|
|
},
|
|
},
|
|
}
|
|
}
|
|
|
|
func TestHandler_XForwardedFor(t *testing.T) {
|
|
goodAddr, err := sockaddr.NewIPAddr("127.0.0.1")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
badAddr, err := sockaddr.NewIPAddr("1.2.3.4")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
// First: test reject not present
|
|
t.Run("reject_not_present", func(t *testing.T) {
|
|
t.Parallel()
|
|
testHandler := func(props *vault.HandlerProperties) http.Handler {
|
|
origHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write([]byte(r.RemoteAddr))
|
|
})
|
|
listenerConfig := getListenerConfigForMarshalerTest(goodAddr)
|
|
listenerConfig.XForwardedForRejectNotPresent = true
|
|
return WrapForwardedForHandler(origHandler, listenerConfig)
|
|
}
|
|
|
|
cluster := vault.NewTestCluster(t, nil, &vault.TestClusterOptions{
|
|
HandlerFunc: HandlerFunc(testHandler),
|
|
})
|
|
cluster.Start()
|
|
defer cluster.Cleanup()
|
|
client := cluster.Cores[0].Client
|
|
|
|
req := client.NewRequest("GET", "/")
|
|
_, err := client.RawRequest(req)
|
|
if err == nil {
|
|
t.Fatal("expected error")
|
|
}
|
|
if !strings.Contains(err.Error(), "missing x-forwarded-for") {
|
|
t.Fatalf("bad error message: %v", err)
|
|
}
|
|
req = client.NewRequest("GET", "/")
|
|
req.Headers = make(http.Header)
|
|
req.Headers.Set("x-forwarded-for", "1.2.3.4")
|
|
resp, err := client.RawRequest(req)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer resp.Body.Close()
|
|
buf := bytes.NewBuffer(nil)
|
|
buf.ReadFrom(resp.Body)
|
|
if !strings.HasPrefix(buf.String(), "1.2.3.4:") {
|
|
t.Fatalf("bad body: %s", buf.String())
|
|
}
|
|
})
|
|
|
|
// Next: test allow unauth
|
|
t.Run("allow_unauth", func(t *testing.T) {
|
|
t.Parallel()
|
|
testHandler := func(props *vault.HandlerProperties) http.Handler {
|
|
origHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write([]byte(r.RemoteAddr))
|
|
})
|
|
listenerConfig := getListenerConfigForMarshalerTest(badAddr)
|
|
listenerConfig.XForwardedForRejectNotPresent = true
|
|
return WrapForwardedForHandler(origHandler, listenerConfig)
|
|
}
|
|
|
|
cluster := vault.NewTestCluster(t, nil, &vault.TestClusterOptions{
|
|
HandlerFunc: HandlerFunc(testHandler),
|
|
})
|
|
cluster.Start()
|
|
defer cluster.Cleanup()
|
|
client := cluster.Cores[0].Client
|
|
|
|
req := client.NewRequest("GET", "/")
|
|
req.Headers = make(http.Header)
|
|
req.Headers.Set("x-forwarded-for", "5.6.7.8")
|
|
resp, err := client.RawRequest(req)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer resp.Body.Close()
|
|
buf := bytes.NewBuffer(nil)
|
|
buf.ReadFrom(resp.Body)
|
|
if !strings.HasPrefix(buf.String(), "127.0.0.1:") {
|
|
t.Fatalf("bad body: %s", buf.String())
|
|
}
|
|
})
|
|
|
|
// Next: test fail unauth
|
|
t.Run("fail_unauth", func(t *testing.T) {
|
|
t.Parallel()
|
|
testHandler := func(props *vault.HandlerProperties) http.Handler {
|
|
origHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write([]byte(r.RemoteAddr))
|
|
})
|
|
listenerConfig := getListenerConfigForMarshalerTest(badAddr)
|
|
listenerConfig.XForwardedForRejectNotPresent = true
|
|
listenerConfig.XForwardedForRejectNotAuthorized = true
|
|
return WrapForwardedForHandler(origHandler, listenerConfig)
|
|
}
|
|
|
|
cluster := vault.NewTestCluster(t, nil, &vault.TestClusterOptions{
|
|
HandlerFunc: HandlerFunc(testHandler),
|
|
})
|
|
cluster.Start()
|
|
defer cluster.Cleanup()
|
|
client := cluster.Cores[0].Client
|
|
|
|
req := client.NewRequest("GET", "/")
|
|
req.Headers = make(http.Header)
|
|
req.Headers.Set("x-forwarded-for", "5.6.7.8")
|
|
_, err := client.RawRequest(req)
|
|
if err == nil {
|
|
t.Fatal("expected error")
|
|
}
|
|
if !strings.Contains(err.Error(), "not authorized for x-forwarded-for") {
|
|
t.Fatalf("bad error message: %v", err)
|
|
}
|
|
})
|
|
|
|
// Next: test bad hops (too many)
|
|
t.Run("too_many_hops", func(t *testing.T) {
|
|
t.Parallel()
|
|
testHandler := func(props *vault.HandlerProperties) http.Handler {
|
|
origHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write([]byte(r.RemoteAddr))
|
|
})
|
|
listenerConfig := getListenerConfigForMarshalerTest(goodAddr)
|
|
listenerConfig.XForwardedForRejectNotPresent = true
|
|
listenerConfig.XForwardedForRejectNotAuthorized = true
|
|
listenerConfig.XForwardedForHopSkips = 4
|
|
return WrapForwardedForHandler(origHandler, listenerConfig)
|
|
}
|
|
|
|
cluster := vault.NewTestCluster(t, nil, &vault.TestClusterOptions{
|
|
HandlerFunc: HandlerFunc(testHandler),
|
|
})
|
|
cluster.Start()
|
|
defer cluster.Cleanup()
|
|
client := cluster.Cores[0].Client
|
|
|
|
req := client.NewRequest("GET", "/")
|
|
req.Headers = make(http.Header)
|
|
req.Headers.Set("x-forwarded-for", "2.3.4.5,3.4.5.6")
|
|
_, err := client.RawRequest(req)
|
|
if err == nil {
|
|
t.Fatal("expected error")
|
|
}
|
|
if !strings.Contains(err.Error(), "would skip before earliest") {
|
|
t.Fatalf("bad error message: %v", err)
|
|
}
|
|
})
|
|
|
|
// Next: test picking correct value
|
|
t.Run("correct_hop_skipping", func(t *testing.T) {
|
|
t.Parallel()
|
|
testHandler := func(props *vault.HandlerProperties) http.Handler {
|
|
origHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write([]byte(r.RemoteAddr))
|
|
})
|
|
listenerConfig := getListenerConfigForMarshalerTest(goodAddr)
|
|
listenerConfig.XForwardedForRejectNotPresent = true
|
|
listenerConfig.XForwardedForRejectNotAuthorized = true
|
|
listenerConfig.XForwardedForHopSkips = 1
|
|
return WrapForwardedForHandler(origHandler, listenerConfig)
|
|
}
|
|
|
|
cluster := vault.NewTestCluster(t, nil, &vault.TestClusterOptions{
|
|
HandlerFunc: HandlerFunc(testHandler),
|
|
})
|
|
cluster.Start()
|
|
defer cluster.Cleanup()
|
|
client := cluster.Cores[0].Client
|
|
|
|
req := client.NewRequest("GET", "/")
|
|
req.Headers = make(http.Header)
|
|
req.Headers.Set("x-forwarded-for", "2.3.4.5,3.4.5.6,4.5.6.7,5.6.7.8")
|
|
resp, err := client.RawRequest(req)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer resp.Body.Close()
|
|
buf := bytes.NewBuffer(nil)
|
|
buf.ReadFrom(resp.Body)
|
|
if !strings.HasPrefix(buf.String(), "4.5.6.7:") {
|
|
t.Fatalf("bad body: %s", buf.String())
|
|
}
|
|
})
|
|
|
|
// Next: multi-header approach
|
|
t.Run("correct_hop_skipping_multi_header", func(t *testing.T) {
|
|
t.Parallel()
|
|
testHandler := func(props *vault.HandlerProperties) http.Handler {
|
|
origHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write([]byte(r.RemoteAddr))
|
|
})
|
|
listenerConfig := getListenerConfigForMarshalerTest(goodAddr)
|
|
listenerConfig.XForwardedForRejectNotPresent = true
|
|
listenerConfig.XForwardedForRejectNotAuthorized = true
|
|
listenerConfig.XForwardedForHopSkips = 1
|
|
return WrapForwardedForHandler(origHandler, listenerConfig)
|
|
}
|
|
|
|
cluster := vault.NewTestCluster(t, nil, &vault.TestClusterOptions{
|
|
HandlerFunc: HandlerFunc(testHandler),
|
|
})
|
|
cluster.Start()
|
|
defer cluster.Cleanup()
|
|
client := cluster.Cores[0].Client
|
|
|
|
req := client.NewRequest("GET", "/")
|
|
req.Headers = make(http.Header)
|
|
req.Headers.Add("x-forwarded-for", "2.3.4.5")
|
|
req.Headers.Add("x-forwarded-for", "3.4.5.6,4.5.6.7")
|
|
req.Headers.Add("x-forwarded-for", "5.6.7.8")
|
|
resp, err := client.RawRequest(req)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer resp.Body.Close()
|
|
buf := bytes.NewBuffer(nil)
|
|
buf.ReadFrom(resp.Body)
|
|
if !strings.HasPrefix(buf.String(), "4.5.6.7:") {
|
|
t.Fatalf("bad body: %s", buf.String())
|
|
}
|
|
})
|
|
}
|