507 lines
12 KiB
507 lines
12 KiB
package checks
import (
// splitURL returns the host and port components of a server URL such as
// "http://127.0.0.1:45678", or "http://[::1]:45678" when the server is bound to
// an IPv6 loopback address (the brackets are removed from the returned host).
//
// Any "scheme://" prefix and any trailing "/path" are ignored. If no ":port"
// suffix is present, the remaining host text is returned with an empty port
// instead of panicking on a missing token.
func splitURL(u string) (string, string) {
	// get the address and port for http server
	hostport := u
	if _, rest, found := strings.Cut(u, "://"); found {
		hostport = rest
	}

	// drop any path component, e.g. "127.0.0.1:80/foo" -> "127.0.0.1:80"
	if i := strings.IndexByte(hostport, '/'); i >= 0 {
		hostport = hostport[:i]
	}

	// net.SplitHostPort understands bracketed IPv6 literals; a naive split on
	// ":" yields the wrong tokens for "[::1]:80" and panics when there is no port.
	addr, port, err := net.SplitHostPort(hostport)
	if err != nil {
		// no port (or otherwise malformed host:port); return the bare host text
		return hostport, ""
	}
	return addr, port
}
// TestChecker_Do_HTTP runs http-type checks against an httptest server whose
// handler responds differently depending on the request path, and compares
// each resulting CheckQueryResult with an expected value built by
// makeExpResult. The checker's clock is replaced by a mock that always returns
// `now`, so the expected Timestamp (now.Unix()) is deterministic.
//
// NOTE(review): this copy appears to be missing lines. Closing braces are
// absent, makeExpResult takes four arguments but the calls below show at most
// one, and `truncate` is never referenced in the visible lines. As shown, it
// does not compile. Restore the missing lines from version control.
func TestChecker_Do_HTTP(t *testing.T) {
// an example response that will be truncated
tooLong, truncate := bigResponse()
// create an http server with various responses
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/fail":
_, _ = io.WriteString(w, "500 problem")
case "/hang":
// sleeps far longer than the 100ms Timeout that makeQuery sets, so the
// check is expected to fail with a context deadline error
time.Sleep(1 * time.Second)
_, _ = io.WriteString(w, "too slow")
case "/long-fail":
_, _ = io.WriteString(w, tooLong)
case "/long-not-fail":
_, _ = io.WriteString(w, tooLong)
// NOTE(review): as shown, this write follows the /long-not-fail write in the
// same case. A `default:` label (and the status-code writes) seem to be
// missing. Confirm against version control.
_, _ = io.WriteString(w, "200 ok")
defer ts.Close()
// get the address and port for http server
addr, port := splitURL(ts.URL)
// create a mock clock so we can assert time is set
now := time.Date(2022, 1, 2, 3, 4, 5, 6, time.UTC)
clock := libtimetest.NewClockMock(t).NowMock.Return(now)
// makeQueryContext builds the identifying context shared by every case; the
// custom address and port label point at the httptest server.
makeQueryContext := func() *QueryContext {
return &QueryContext{
ID: "abc123",
CustomAddress: addr,
ServicePortLabel: port,
Networks: nil,
NetworkStatus: mock.NewNetworkStatus(addr),
Ports: nil,
Group: "group",
Task: "task",
Service: "service",
Check: "check",
// makeQuery builds a GET http check of the given mode for the given path,
// with a 100ms timeout.
makeQuery := func(
kind structs.CheckMode,
path string,
) *Query {
return &Query{
Mode: kind,
Type: "http",
Timeout: 100 * time.Millisecond,
AddressMode: "auto",
PortLabel: port,
Protocol: "http",
Path: path,
Method: "GET",
// makeExpResult builds the expected result; the identifying fields match
// makeQueryContext and Timestamp matches the mocked clock.
makeExpResult := func(
kind structs.CheckMode,
status structs.CheckStatus,
code int,
output string,
) *structs.CheckQueryResult {
return &structs.CheckQueryResult{
ID: "abc123",
Mode: kind,
Status: status,
StatusCode: code,
Output: output,
Timestamp: now.Unix(),
Group: "group",
Task: "task",
Service: "service",
Check: "check",
cases := []struct {
name string
qc *QueryContext
q *Query
expResult *structs.CheckQueryResult
name: "200 healthiness",
qc: makeQueryContext(),
q: makeQuery(structs.Healthiness, "/"),
expResult: makeExpResult(
"nomad: http ok",
}, {
name: "200 readiness",
qc: makeQueryContext(),
q: makeQuery(structs.Readiness, "/"),
expResult: makeExpResult(
"nomad: http ok",
}, {
name: "500 healthiness",
qc: makeQueryContext(),
// NOTE(review): this path and the ones below lack the leading "/" that the
// handler's switch matches on. This relies on the checker's URL construction
// adding it (the "hang" expectation shows "%s/hang"). Confirm.
q: makeQuery(structs.Healthiness, "fail"),
expResult: makeExpResult(
"500 problem",
}, {
name: "hang",
qc: makeQueryContext(),
q: makeQuery(structs.Healthiness, "hang"),
expResult: makeExpResult(
// the output embeds the client error produced when the 100ms timeout expires
fmt.Sprintf(`nomad: Get "%s/hang": context deadline exceeded`, ts.URL),
}, {
name: "500 truncate",
qc: makeQueryContext(),
q: makeQuery(structs.Healthiness, "long-fail"),
// NOTE(review): arguments missing in this copy; presumably the output is
// `truncate` (the payload cut to outputSizeLimit). Confirm.
expResult: makeExpResult(
}, {
name: "201 truncate",
qc: makeQueryContext(),
q: makeQuery(structs.Healthiness, "long-not-fail"),
// NOTE(review): arguments missing in this copy, as in the case above.
expResult: makeExpResult(
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
logger := testlog.HCLogger(t)
c := New(logger)
// swap in the mock clock so result.Timestamp equals now.Unix(); the type
// assertion panics if New returns anything other than *checker
c.(*checker).clock = clock
ctx := context.Background()
result := c.Do(ctx, tc.qc, tc.q)
must.Eq(t, tc.expResult, result)
// bigResponse creates a response payload larger than the maximum outputSizeLimit
// as well as the same response but truncated to length of outputSizeLimit
func bigResponse() (string, string) {
size := outputSizeLimit + 5
b := make([]byte, size, size)
for i := 0; i < size; i++ {
b[i] = 'a'
s := string(b)
return s, s[:outputSizeLimit]
// TestChecker_Do_HTTP_extras verifies that the request sent by the checker
// carries the Query's Method, Body and Headers. An httptest handler records the
// method, body, headers and Host of the request it receives, and each case
// compares them with the case's expectations.
//
// NOTE(review): closing braces and case-literal openers are missing from this
// copy, so it does not compile as shown. Restore them from version control.
func TestChecker_Do_HTTP_extras(t *testing.T) {
// record the method, body, and headers of the request
// NOTE(review): these are written by the server's handler goroutine and read
// by the test goroutine without any explicit synchronization visible here.
var (
method string
body []byte
headers map[string][]string
host string
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
method = r.Method
body, _ = io.ReadAll(r.Body)
headers = maps.Clone(r.Header)
host = r.Host
defer ts.Close()
// get the address and port for http server
addr, port := splitURL(ts.URL)
// make headers from key-value pairs
// (http.Header.Set canonicalizes each key, e.g. "hoST" becomes "Host")
makeHeaders := func(more ...[2]string) http.Header {
h := make(http.Header)
for _, extra := range more {
h.Set(extra[0], extra[1])
return h
// headers the test expects the Go HTTP client to add on its own; note the
// HEAD case below expects only the User-Agent
encoding := [2]string{"Accept-Encoding", "gzip"}
agent := [2]string{"User-Agent", "Go-http-client/1.1"}
cases := []struct {
name string
method string
body string
headers http.Header
name: "method GET",
method: "GET",
headers: makeHeaders(encoding, agent),
name: "method Get",
method: "Get",
headers: makeHeaders(encoding, agent),
name: "method HEAD",
method: "HEAD",
headers: makeHeaders(agent),
name: "extra headers",
method: "GET",
headers: makeHeaders(encoding, agent,
[2]string{"X-My-Header", "hello"},
[2]string{"Authorization", "Basic ZWxhc3RpYzpjaGFuZ2VtZQ=="},
name: "host header",
method: "GET",
headers: makeHeaders(encoding, agent,
[2]string{"Host", "hello"},
[2]string{"Test-Abc", "hello"},
name: "host header without normalization",
method: "GET",
body: "",
// This is needed to prevent header normalization by http.Header.Set
headers: func() map[string][]string {
h := makeHeaders(encoding, agent, [2]string{"Test-Abc", "hello"})
h["hoST"] = []string{"heLLO"}
return h
name: "with body",
method: "POST",
headers: makeHeaders(encoding, agent),
body: "some payload",
for _, tc := range cases {
// the query context and query are built per case, pointing at the httptest server
qc := &QueryContext{
ID: "abc123",
CustomAddress: addr,
ServicePortLabel: port,
Networks: nil,
NetworkStatus: mock.NewNetworkStatus(addr),
Ports: nil,
Group: "group",
Task: "task",
Service: "service",
Check: "check",
q := &Query{
Mode: structs.Healthiness,
Type: "http",
Timeout: 1 * time.Second,
AddressMode: "auto",
PortLabel: port,
Protocol: "http",
Path: "/",
Method: tc.method,
Headers: tc.headers,
Body: tc.body,
t.Run(tc.name, func(t *testing.T) {
logger := testlog.HCLogger(t)
c := New(logger)
ctx := context.Background()
result := c.Do(ctx, qc, q)
must.Eq(t, http.StatusOK, result.StatusCode,
must.Sprintf("test.URL: %s", ts.URL),
must.Sprintf("headers: %v", tc.headers),
// NOTE(review): this message is labelled "received headers" but prints
// tc.headers (the expected headers), not `headers` (what the handler saw).
must.Sprintf("received headers: %v", tc.headers),
must.Eq(t, tc.method, method)
must.Eq(t, tc.body, string(body))
// net/http servers move an incoming Host header into Request.Host rather
// than Request.Header, so a Host entry (matched case-insensitively) is
// checked against `host` and then removed before the map comparison below.
// NOTE(review): delete mutates tc.headers, which is the case table's map.
hostSent := false
for key, values := range tc.headers {
if strings.EqualFold(key, "Host") && len(values) > 0 {
must.Eq(t, values[0], host)
hostSent = true
delete(tc.headers, key)
if !hostSent {
must.Eq(t, nil, tc.headers["Host"])
must.Eq(t, tc.headers, headers)
// TestChecker_Do_TCP runs tcp-type checks in three scenarios, selected by each
// case's tcpMode:
//   - "ok": a local listener (tcpServer) accepts connections on the port.
//   - "off": nothing listens on the port.
//   - "hang": the context passed to Do already has an expired deadline.
//
// Ports come from ci.PortAllocator. The checker's clock is mocked so that the
// expected Timestamp is now.Unix().
//
// NOTE(review): closing braces and the leading makeExpResult arguments (mode,
// status) are missing from this copy, so it does not compile as shown.
func TestChecker_Do_TCP(t *testing.T) {
// create a mock clock so we can assert time is set
now := time.Date(2022, 1, 2, 3, 4, 5, 6, time.UTC)
clock := libtimetest.NewClockMock(t).NowMock.Return(now)
// makeQueryContext builds the identifying context for a check against address:port.
makeQueryContext := func(address string, port int) *QueryContext {
return &QueryContext{
ID: "abc123",
CustomAddress: address,
ServicePortLabel: fmt.Sprintf("%d", port),
Networks: nil,
NetworkStatus: mock.NewNetworkStatus(address),
Ports: nil,
Group: "group",
Task: "task",
Service: "service",
Check: "check",
// makeQuery builds a tcp check of the given mode with a 100ms timeout.
makeQuery := func(
kind structs.CheckMode,
port int,
) *Query {
return &Query{
Mode: kind,
Type: "tcp",
Timeout: 100 * time.Millisecond,
AddressMode: "auto",
PortLabel: fmt.Sprintf("%d", port),
// makeExpResult builds the expected result; no StatusCode is set for tcp checks.
makeExpResult := func(
kind structs.CheckMode,
status structs.CheckStatus,
output string,
) *structs.CheckQueryResult {
return &structs.CheckQueryResult{
ID: "abc123",
Mode: kind,
Status: status,
Output: output,
Timestamp: now.Unix(),
Group: "group",
Task: "task",
Service: "service",
Check: "check",
// one port per case
ports := ci.PortAllocator.Grab(3)
cases := []struct {
name string
qc *QueryContext
q *Query
tcpMode string // "ok", "off", "hang"
tcpPort int
expResult *structs.CheckQueryResult
name: "tcp ok",
qc: makeQueryContext("localhost", ports[0]),
q: makeQuery(structs.Healthiness, ports[0]),
tcpMode: "ok",
tcpPort: ports[0],
expResult: makeExpResult(
"nomad: tcp ok",
}, {
name: "tcp not listening",
qc: makeQueryContext("", ports[1]),
q: makeQuery(structs.Healthiness, ports[1]),
tcpMode: "off",
tcpPort: ports[1],
expResult: makeExpResult(
// NOTE(review): this format string contains no verb but is given ports[1];
// go vet reports this, and the port never appears in the expected text.
// Presumably the address/port was meant to be formatted in. Confirm
// against the checker's actual error output.
fmt.Sprintf("dial tcp connect: connection refused", ports[1]),
}, {
name: "tcp slow accept",
qc: makeQueryContext("localhost", ports[2]),
q: makeQuery(structs.Healthiness, ports[2]),
tcpMode: "hang",
tcpPort: ports[2],
expResult: makeExpResult(
"dial tcp: lookup localhost: i/o timeout",
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
logger := testlog.HCLogger(t)
// cancelling ctx at the end of the subtest lets tcpServer shut down its listener
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
c := New(logger)
c.(*checker).clock = clock
switch tc.tcpMode {
case "ok":
// simulate tcp server by listening
tcpServer(t, ctx, tc.tcpPort)
case "hang":
// simulate tcp hang by setting an already expired context
// (`now` is a fixed 2022 instant, so now-1s is already in the past)
timeout, stop := context.WithDeadline(ctx, now.Add(-1*time.Second))
defer stop()
ctx = timeout
case "off":
// simulate tcp dead connection by not listening
result := c.Do(ctx, tc.qc, tc.q)
must.Eq(t, tc.expResult, result)
// tcpServer will start a tcp listener that accepts connections and closes them.
// The caller can close the listener by cancelling ctx.
func tcpServer(t *testing.T, ctx context.Context, port int) {
var lc net.ListenConfig
l, err := lc.Listen(ctx, "tcp", net.JoinHostPort(
"localhost", fmt.Sprintf("%d", port),
must.NoError(t, err, must.Sprint("port", port))
t.Cleanup(func() {
_ = l.Close()
go func() {
// caller can stop us by cancelling ctx
for {
_, acceptErr := l.Accept()
if acceptErr != nil {