open-consul/agent/http_ce_test.go

227 lines
5.6 KiB
Go
Raw Normal View History

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
2017-11-29 00:06:26 +00:00
package agent
import (
"fmt"
"net"
2017-11-29 00:06:26 +00:00
"net/http"
"net/http/httptest"
2017-11-29 00:06:26 +00:00
"strings"
"testing"
"time"
2017-11-29 00:06:26 +00:00
"github.com/stretchr/testify/require"
"github.com/hashicorp/consul/testrpc"
2017-11-29 00:06:26 +00:00
)
2018-03-03 19:19:18 +00:00
// extraTestEndpoints holds endpoint paths that the method tests below exercise
// alongside the allowedMethods table. Each path maps to the HTTP methods that
// the endpoint accepts.
var extraTestEndpoints = map[string][]string{
	"/v1/query/xxx/explain": {"GET"},
	"/v1/query/xxx/execute": {"GET"},
	"/v1/query/":            {"GET", "PUT", "DELETE"},
	"/v1/query":             {"GET", "POST"},
}
2018-03-16 14:36:57 +00:00
// ignoredEndpoints are excluded from the generic response-code tests in this
// file.
var ignoredEndpoints = []string{"/v1/status/peers", "/v1/agent/monitor", "/v1/agent/reload"}

// customEndpoints have custom logic and are likewise excluded from the
// generic response-code tests.
var customEndpoints = []string{"/v1/query", "/v1/query/"}

// includePathInTest reports whether path should take part in the generic
// response-code tests. It returns false when path is listed in either
// ignoredEndpoints or customEndpoints, and true otherwise.
func includePathInTest(path string) bool {
	for _, excluded := range [][]string{ignoredEndpoints, customEndpoints} {
		for _, p := range excluded {
			if p == path {
				return false
			}
		}
	}
	return true
}
// newHttpClient returns an HTTP client that applies timeout to three things:
// the whole request (Client.Timeout), the TCP dial (net.Dialer.Timeout), and
// the TLS handshake.
func newHttpClient(timeout time.Duration) *http.Client {
	dialer := &net.Dialer{Timeout: timeout}
	return &http.Client{
		Timeout: timeout,
		Transport: &http.Transport{
			// DialContext replaces the deprecated Dial field. It also lets
			// the transport abandon a dial once the request is canceled.
			DialContext:         dialer.DialContext,
			TLSHandshakeTimeout: timeout,
		},
	}
}
func TestHTTPAPI_MethodNotAllowed_CE(t *testing.T) {
if testing.Short() {
t.Skip("too slow for testing.Short")
}
// To avoid actually triggering RPCs that are allowed, lock everything down
// with default-deny ACLs. This drops the test runtime from 11s to 0.6s.
a := NewTestAgent(t, `
primary_datacenter = "dc1"
acl {
enabled = true
default_policy = "deny"
tokens {
initial_management = "sekrit"
agent = "sekrit"
}
}
`)
2017-11-29 00:06:26 +00:00
defer a.Shutdown()
// Use the initial management token here so the wait actually works.
testrpc.WaitForTestAgent(t, a.RPC, "dc1", testrpc.WithToken("sekrit"))
2017-11-29 00:06:26 +00:00
all := []string{"GET", "PUT", "POST", "DELETE", "HEAD", "OPTIONS"}
client := newHttpClient(15 * time.Second)
2017-11-29 00:06:26 +00:00
testMethodNotAllowed := func(t *testing.T, method string, path string, allowedMethods []string) {
2018-03-03 19:19:18 +00:00
t.Run(method+" "+path, func(t *testing.T) {
uri := fmt.Sprintf("http://%s%s", a.HTTPAddr(), path)
req, _ := http.NewRequest(method, uri, nil)
resp, err := client.Do(req)
if err != nil {
t.Fatal("client.Do failed: ", err)
}
defer resp.Body.Close()
2017-11-29 00:06:26 +00:00
2018-03-03 19:19:18 +00:00
allowed := method == "OPTIONS"
for _, allowedMethod := range allowedMethods {
if allowedMethod == method {
allowed = true
break
2017-11-29 00:06:26 +00:00
}
2018-03-03 19:19:18 +00:00
}
if allowed && resp.StatusCode == http.StatusMethodNotAllowed {
t.Fatalf("method allowed: got status code %d want any other code", resp.StatusCode)
}
if !allowed && resp.StatusCode != http.StatusMethodNotAllowed {
t.Fatalf("method not allowed: got status code %d want %d", resp.StatusCode, http.StatusMethodNotAllowed)
}
})
}
for path, methods := range extraTestEndpoints {
for _, method := range all {
testMethodNotAllowed(t, method, path, methods)
2018-03-03 19:19:18 +00:00
}
}
for path, methods := range allowedMethods {
if includePathInTest(path) {
for _, method := range all {
testMethodNotAllowed(t, method, path, methods)
2018-03-03 19:19:18 +00:00
}
2017-11-29 00:06:26 +00:00
}
}
}
func TestHTTPAPI_OptionMethod_CE(t *testing.T) {
if testing.Short() {
t.Skip("too slow for testing.Short")
}
a := NewTestAgent(t, `acl_datacenter = "dc1"`)
defer a.Shutdown()
testrpc.WaitForTestAgent(t, a.RPC, "dc1")
2018-03-03 19:19:18 +00:00
testOptionMethod := func(path string, methods []string) {
t.Run("OPTIONS "+path, func(t *testing.T) {
uri := fmt.Sprintf("http://%s%s", a.HTTPAddr(), path)
req, _ := http.NewRequest("OPTIONS", uri, nil)
resp := httptest.NewRecorder()
a.enableDebug.Store(true)
a.srv.handler().ServeHTTP(resp, req)
2018-03-03 19:19:18 +00:00
allMethods := append([]string{"OPTIONS"}, methods...)
if resp.Code != http.StatusOK {
t.Fatalf("options request: got status code %d want %d", resp.Code, http.StatusOK)
}
optionsStr := resp.Header().Get("Allow")
if optionsStr == "" {
t.Fatalf("options request: got empty 'Allow' header")
2018-03-03 19:19:18 +00:00
} else if optionsStr != strings.Join(allMethods, ",") {
t.Fatalf("options request: got 'Allow' header value of %s want %s", optionsStr, allMethods)
}
})
}
2018-03-03 19:19:18 +00:00
for path, methods := range extraTestEndpoints {
testOptionMethod(path, methods)
}
for path, methods := range allowedMethods {
if includePathInTest(path) {
testOptionMethod(path, methods)
}
}
}
// TestHTTPAPI_AllowedNets_CE configures allow_write_http_from as
// ["127.0.0.1/8"]. For every endpoint in extraTestEndpoints and
// allowedMethods (filtered by includePathInTest), it sends each non-GET
// method with a RemoteAddr of 192.168.1.2 and expects 403 Forbidden.
func TestHTTPAPI_AllowedNets_CE(t *testing.T) {
	if testing.Short() {
		t.Skip("too slow for testing.Short")
	}

	a := NewTestAgent(t, `
acl_datacenter = "dc1"
http_config {
allow_write_http_from = ["127.0.0.1/8"]
}
`)
	defer a.Shutdown()
	testrpc.WaitForTestAgent(t, a.RPC, "dc1")

	// testWriteForbidden runs one subtest that serves method on path from a
	// remote address outside 127.0.0.1/8 and requires a 403.
	testWriteForbidden := func(path string, method string) {
		t.Run(method+" "+path, func(t *testing.T) {
			uri := fmt.Sprintf("http://%s%s", a.HTTPAddr(), path)
			req, err := http.NewRequest(method, uri, nil)
			if err != nil {
				t.Fatalf("building request %s %s: %v", method, uri, err)
			}
			req.RemoteAddr = "192.168.1.2:5555"
			resp := httptest.NewRecorder()
			a.enableDebug.Store(true)
			a.srv.handler().ServeHTTP(resp, req)
			require.Equal(t, http.StatusForbidden, resp.Code, "%s %s", method, path)
		})
	}

	// checkEndpoint runs testWriteForbidden for every method path accepts
	// except GET, skipping paths that includePathInTest excludes.
	checkEndpoint := func(path string, methods []string) {
		if !includePathInTest(path) {
			return
		}
		for _, method := range methods {
			if method == http.MethodGet {
				continue
			}
			testWriteForbidden(path, method)
		}
	}

	for path, methods := range extraTestEndpoints {
		checkEndpoint(path, methods)
	}
	for path, methods := range allowedMethods {
		checkEndpoint(path, methods)
	}
}