3b57df33e3
Before this change, Client had 2 copies of the config object: config and configCopy. There was no guidance around which to use where (other than configCopy's comment to pass it to alloc runners), both are shared among goroutines and mutated in data racy ways. At least at one point I think the idea was to have `config` be mutable and then grab a lock to overwrite `configCopy`'s pointer atomically. This would have allowed alloc runners to read their config copies in data race safe ways, but this isn't how the current implementation worked. This change takes the following approach to safely handling configs in the client: 1. `Client.config` is the only copy of the config and all access must go through the `Client.configLock` mutex 2. Since the mutex *only protects the config pointer itself and not fields inside the Config struct:* all config mutation must be done on a *copy* of the config, and then Client's config pointer is overwritten while the mutex is acquired. Alloc runners and other goroutines with the old config pointer will not see config updates. 3. Deep copying is implemented on the Config struct to satisfy the previous approach. The TLS Keyloader is an exception because it has its own internal locking to support mutating in place. An unfortunate complication but one I couldn't find a way to untangle in a timely fashion. 4. To facilitate deep copying I made an *internally backward incompatible API change:* our `helper/funcs` used to turn containers (slices and maps) with 0 elements into nils. This probably saves a few memory allocations but makes it very easy to cause panics. Since my new config handling approach uses more copying, it became very difficult to ensure all code that used containers on configs could handle nils properly. Since this code has caused panics in the past, I fixed it: nil containers are copied as nil, but 0-element containers properly return a new 0-element container. No more "downgrading to nil!"
1489 lines
36 KiB
Go
1489 lines
36 KiB
Go
package agent
|
|
|
|
import (
|
|
"bytes"
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"io/ioutil"
|
|
"net"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/url"
|
|
"os"
|
|
"strconv"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/hashicorp/go-msgpack/codec"
|
|
"github.com/hashicorp/nomad/acl"
|
|
"github.com/hashicorp/nomad/api"
|
|
"github.com/hashicorp/nomad/ci"
|
|
"github.com/hashicorp/nomad/helper/pointer"
|
|
"github.com/hashicorp/nomad/helper/testlog"
|
|
"github.com/hashicorp/nomad/nomad/mock"
|
|
"github.com/hashicorp/nomad/nomad/structs"
|
|
"github.com/hashicorp/nomad/nomad/structs/config"
|
|
"github.com/hashicorp/nomad/testutil"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
// makeHTTPServer returns a test server whose logs will be written to
|
|
// the passed writer. If the writer is nil, the logs are written to stderr.
|
|
func makeHTTPServer(t testing.TB, cb func(c *Config)) *TestAgent {
|
|
return NewTestAgent(t, t.Name(), cb)
|
|
}
|
|
|
|
func BenchmarkHTTPRequests(b *testing.B) {
|
|
s := makeHTTPServer(b, func(c *Config) {
|
|
c.Client.Enabled = false
|
|
})
|
|
defer s.Shutdown()
|
|
|
|
job := mock.Job()
|
|
var allocs []*structs.Allocation
|
|
count := 1000
|
|
for i := 0; i < count; i++ {
|
|
alloc := mock.Alloc()
|
|
alloc.Job = job
|
|
alloc.JobID = job.ID
|
|
alloc.Name = fmt.Sprintf("my-job.web[%d]", i)
|
|
allocs = append(allocs, alloc)
|
|
}
|
|
|
|
handler := func(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
|
return allocs[:count], nil
|
|
}
|
|
b.ResetTimer()
|
|
|
|
b.RunParallel(func(pb *testing.PB) {
|
|
for pb.Next() {
|
|
resp := httptest.NewRecorder()
|
|
req, _ := http.NewRequest("GET", "/v1/kv/key", nil)
|
|
s.Server.wrap(handler)(resp, req)
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestMultipleInterfaces(t *testing.T) {
|
|
ci.Parallel(t)
|
|
|
|
httpIps := []string{"127.0.0.1", "127.0.0.2"}
|
|
|
|
s := makeHTTPServer(t, func(c *Config) {
|
|
c.Addresses.HTTP = strings.Join(httpIps, " ")
|
|
c.ACL.Enabled = true
|
|
})
|
|
defer s.Shutdown()
|
|
|
|
httpPort := s.ports[0]
|
|
for _, ip := range httpIps {
|
|
resp, err := http.Get(fmt.Sprintf("http://%s:%d/", ip, httpPort))
|
|
|
|
assert.Nil(t, err)
|
|
assert.Equal(t, resp.StatusCode, 200)
|
|
}
|
|
}
|
|
|
|
// TestRootFallthrough tests rootFallthrough handler to
|
|
// verify redirect and 404 behavior
|
|
func TestRootFallthrough(t *testing.T) {
|
|
ci.Parallel(t)
|
|
|
|
cases := []struct {
|
|
desc string
|
|
path string
|
|
expectedPath string
|
|
expectedRawQuery string
|
|
expectedCode int
|
|
}{
|
|
{
|
|
desc: "unknown endpoint 404s",
|
|
path: "/v1/unknown/endpoint",
|
|
expectedCode: 404,
|
|
},
|
|
{
|
|
desc: "root path redirects to ui",
|
|
path: "/",
|
|
expectedPath: "/ui/",
|
|
expectedCode: 307,
|
|
},
|
|
{
|
|
desc: "root path with one-time token redirects to ui",
|
|
path: "/?ott=whatever",
|
|
expectedPath: "/ui/",
|
|
expectedRawQuery: "ott=whatever",
|
|
expectedCode: 307,
|
|
},
|
|
}
|
|
|
|
s := makeHTTPServer(t, nil)
|
|
defer s.Shutdown()
|
|
|
|
// setup a client that doesn't follow redirects
|
|
client := &http.Client{
|
|
CheckRedirect: func(_ *http.Request, _ []*http.Request) error {
|
|
return http.ErrUseLastResponse
|
|
},
|
|
}
|
|
|
|
for _, tc := range cases {
|
|
t.Run(tc.desc, func(t *testing.T) {
|
|
|
|
reqURL := fmt.Sprintf("http://%s%s", s.Agent.config.AdvertiseAddrs.HTTP, tc.path)
|
|
|
|
resp, err := client.Get(reqURL)
|
|
require.NoError(t, err)
|
|
require.Equal(t, tc.expectedCode, resp.StatusCode)
|
|
|
|
if tc.expectedPath != "" {
|
|
loc, err := resp.Location()
|
|
require.NoError(t, err)
|
|
require.Equal(t, tc.expectedPath, loc.Path)
|
|
require.Equal(t, tc.expectedRawQuery, loc.RawQuery)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestSetIndex(t *testing.T) {
|
|
ci.Parallel(t)
|
|
resp := httptest.NewRecorder()
|
|
setIndex(resp, 1000)
|
|
header := resp.Header().Get("X-Nomad-Index")
|
|
if header != "1000" {
|
|
t.Fatalf("Bad: %v", header)
|
|
}
|
|
setIndex(resp, 2000)
|
|
if v := resp.Header()["X-Nomad-Index"]; len(v) != 1 {
|
|
t.Fatalf("bad: %#v", v)
|
|
}
|
|
}
|
|
|
|
func TestSetKnownLeader(t *testing.T) {
|
|
ci.Parallel(t)
|
|
resp := httptest.NewRecorder()
|
|
setKnownLeader(resp, true)
|
|
header := resp.Header().Get("X-Nomad-KnownLeader")
|
|
if header != "true" {
|
|
t.Fatalf("Bad: %v", header)
|
|
}
|
|
resp = httptest.NewRecorder()
|
|
setKnownLeader(resp, false)
|
|
header = resp.Header().Get("X-Nomad-KnownLeader")
|
|
if header != "false" {
|
|
t.Fatalf("Bad: %v", header)
|
|
}
|
|
}
|
|
|
|
func TestSetLastContact(t *testing.T) {
|
|
ci.Parallel(t)
|
|
resp := httptest.NewRecorder()
|
|
setLastContact(resp, 123456*time.Microsecond)
|
|
header := resp.Header().Get("X-Nomad-LastContact")
|
|
if header != "123" {
|
|
t.Fatalf("Bad: %v", header)
|
|
}
|
|
}
|
|
|
|
func TestSetMeta(t *testing.T) {
|
|
ci.Parallel(t)
|
|
meta := structs.QueryMeta{
|
|
Index: 1000,
|
|
KnownLeader: true,
|
|
LastContact: 123456 * time.Microsecond,
|
|
}
|
|
resp := httptest.NewRecorder()
|
|
setMeta(resp, &meta)
|
|
header := resp.Header().Get("X-Nomad-Index")
|
|
if header != "1000" {
|
|
t.Fatalf("Bad: %v", header)
|
|
}
|
|
header = resp.Header().Get("X-Nomad-KnownLeader")
|
|
if header != "true" {
|
|
t.Fatalf("Bad: %v", header)
|
|
}
|
|
header = resp.Header().Get("X-Nomad-LastContact")
|
|
if header != "123" {
|
|
t.Fatalf("Bad: %v", header)
|
|
}
|
|
}
|
|
|
|
func TestSetHeaders(t *testing.T) {
|
|
ci.Parallel(t)
|
|
s := makeHTTPServer(t, nil)
|
|
s.Agent.config.HTTPAPIResponseHeaders = map[string]string{"foo": "bar"}
|
|
defer s.Shutdown()
|
|
|
|
resp := httptest.NewRecorder()
|
|
handler := func(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
|
return &structs.Job{Name: "foo"}, nil
|
|
}
|
|
|
|
req, _ := http.NewRequest("GET", "/v1/kv/key", nil)
|
|
s.Server.wrap(handler)(resp, req)
|
|
header := resp.Header().Get("foo")
|
|
|
|
if header != "bar" {
|
|
t.Fatalf("expected header: %v, actual: %v", "bar", header)
|
|
}
|
|
|
|
}
|
|
|
|
func TestContentTypeIsJSON(t *testing.T) {
|
|
ci.Parallel(t)
|
|
s := makeHTTPServer(t, nil)
|
|
defer s.Shutdown()
|
|
|
|
resp := httptest.NewRecorder()
|
|
|
|
handler := func(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
|
return &structs.Job{Name: "foo"}, nil
|
|
}
|
|
|
|
req, _ := http.NewRequest("GET", "/v1/kv/key", nil)
|
|
s.Server.wrap(handler)(resp, req)
|
|
|
|
contentType := resp.Header().Get("Content-Type")
|
|
|
|
if contentType != "application/json" {
|
|
t.Fatalf("Content-Type header was not 'application/json'")
|
|
}
|
|
}
|
|
|
|
func TestWrapNonJSON(t *testing.T) {
|
|
ci.Parallel(t)
|
|
s := makeHTTPServer(t, nil)
|
|
defer s.Shutdown()
|
|
|
|
resp := httptest.NewRecorder()
|
|
|
|
handler := func(resp http.ResponseWriter, req *http.Request) ([]byte, error) {
|
|
return []byte("test response"), nil
|
|
}
|
|
|
|
req, _ := http.NewRequest("GET", "/v1/kv/key", nil)
|
|
s.Server.wrapNonJSON(handler)(resp, req)
|
|
|
|
respBody, _ := ioutil.ReadAll(resp.Body)
|
|
require.Equal(t, respBody, []byte("test response"))
|
|
|
|
}
|
|
|
|
func TestWrapNonJSON_Error(t *testing.T) {
|
|
ci.Parallel(t)
|
|
s := makeHTTPServer(t, nil)
|
|
defer s.Shutdown()
|
|
|
|
handlerRPCErr := func(resp http.ResponseWriter, req *http.Request) ([]byte, error) {
|
|
return nil, structs.NewErrRPCCoded(404, "not found")
|
|
}
|
|
|
|
handlerCodedErr := func(resp http.ResponseWriter, req *http.Request) ([]byte, error) {
|
|
return nil, CodedError(422, "unprocessable")
|
|
}
|
|
|
|
// RPC coded error
|
|
{
|
|
resp := httptest.NewRecorder()
|
|
req, _ := http.NewRequest("GET", "/v1/kv/key", nil)
|
|
s.Server.wrapNonJSON(handlerRPCErr)(resp, req)
|
|
respBody, _ := ioutil.ReadAll(resp.Body)
|
|
require.Equal(t, []byte("not found"), respBody)
|
|
require.Equal(t, 404, resp.Code)
|
|
}
|
|
|
|
// CodedError
|
|
{
|
|
resp := httptest.NewRecorder()
|
|
req, _ := http.NewRequest("GET", "/v1/kv/key", nil)
|
|
s.Server.wrapNonJSON(handlerCodedErr)(resp, req)
|
|
respBody, _ := ioutil.ReadAll(resp.Body)
|
|
require.Equal(t, []byte("unprocessable"), respBody)
|
|
require.Equal(t, 422, resp.Code)
|
|
}
|
|
|
|
}
|
|
|
|
func TestPrettyPrint(t *testing.T) {
|
|
ci.Parallel(t)
|
|
testPrettyPrint("pretty=1", true, t)
|
|
}
|
|
|
|
func TestPrettyPrintOff(t *testing.T) {
|
|
ci.Parallel(t)
|
|
testPrettyPrint("pretty=0", false, t)
|
|
}
|
|
|
|
func TestPrettyPrintBare(t *testing.T) {
|
|
ci.Parallel(t)
|
|
testPrettyPrint("pretty", true, t)
|
|
}
|
|
|
|
func testPrettyPrint(pretty string, prettyFmt bool, t *testing.T) {
|
|
s := makeHTTPServer(t, nil)
|
|
defer s.Shutdown()
|
|
|
|
r := &structs.Job{Name: "foo"}
|
|
|
|
resp := httptest.NewRecorder()
|
|
handler := func(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
|
return r, nil
|
|
}
|
|
|
|
urlStr := "/v1/job/foo?" + pretty
|
|
req, _ := http.NewRequest("GET", urlStr, nil)
|
|
s.Server.wrap(handler)(resp, req)
|
|
|
|
var expected bytes.Buffer
|
|
var err error
|
|
if prettyFmt {
|
|
enc := codec.NewEncoder(&expected, structs.JsonHandlePretty)
|
|
err = enc.Encode(r)
|
|
expected.WriteByte('\n')
|
|
} else {
|
|
enc := codec.NewEncoder(&expected, structs.JsonHandleWithExtensions)
|
|
err = enc.Encode(r)
|
|
}
|
|
if err != nil {
|
|
t.Fatalf("failed to encode: %v", err)
|
|
}
|
|
actual, err := ioutil.ReadAll(resp.Body)
|
|
if err != nil {
|
|
t.Fatalf("err: %s", err)
|
|
}
|
|
|
|
if !bytes.Equal(expected.Bytes(), actual) {
|
|
t.Fatalf("bad:\nexpected:\t%q\nactual:\t\t%q", expected.String(), string(actual))
|
|
}
|
|
}
|
|
|
|
func TestPermissionDenied(t *testing.T) {
|
|
s := makeHTTPServer(t, func(c *Config) {
|
|
c.ACL.Enabled = true
|
|
})
|
|
defer s.Shutdown()
|
|
|
|
{
|
|
resp := httptest.NewRecorder()
|
|
handler := func(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
|
return nil, structs.ErrPermissionDenied
|
|
}
|
|
|
|
req, _ := http.NewRequest("GET", "/v1/job/foo", nil)
|
|
s.Server.wrap(handler)(resp, req)
|
|
assert.Equal(t, resp.Code, 403)
|
|
}
|
|
|
|
// When remote RPC is used the errors have "rpc error: " prependend
|
|
{
|
|
resp := httptest.NewRecorder()
|
|
handler := func(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
|
return nil, fmt.Errorf("rpc error: %v", structs.ErrPermissionDenied)
|
|
}
|
|
|
|
req, _ := http.NewRequest("GET", "/v1/job/foo", nil)
|
|
s.Server.wrap(handler)(resp, req)
|
|
assert.Equal(t, resp.Code, 403)
|
|
}
|
|
}
|
|
|
|
func TestTokenNotFound(t *testing.T) {
|
|
s := makeHTTPServer(t, func(c *Config) {
|
|
c.ACL.Enabled = true
|
|
})
|
|
defer s.Shutdown()
|
|
|
|
resp := httptest.NewRecorder()
|
|
handler := func(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
|
return nil, structs.ErrTokenNotFound
|
|
}
|
|
|
|
urlStr := "/v1/job/foo"
|
|
req, _ := http.NewRequest("GET", urlStr, nil)
|
|
s.Server.wrap(handler)(resp, req)
|
|
assert.Equal(t, resp.Code, 403)
|
|
}
|
|
|
|
func TestParseWait(t *testing.T) {
|
|
ci.Parallel(t)
|
|
resp := httptest.NewRecorder()
|
|
var b structs.QueryOptions
|
|
|
|
req, err := http.NewRequest("GET",
|
|
"/v1/catalog/nodes?wait=60s&index=1000", nil)
|
|
if err != nil {
|
|
t.Fatalf("err: %v", err)
|
|
}
|
|
|
|
if d := parseWait(resp, req, &b); d {
|
|
t.Fatalf("unexpected done")
|
|
}
|
|
|
|
if b.MinQueryIndex != 1000 {
|
|
t.Fatalf("Bad: %v", b)
|
|
}
|
|
if b.MaxQueryTime != 60*time.Second {
|
|
t.Fatalf("Bad: %v", b)
|
|
}
|
|
}
|
|
|
|
func TestParseWait_InvalidTime(t *testing.T) {
|
|
ci.Parallel(t)
|
|
resp := httptest.NewRecorder()
|
|
var b structs.QueryOptions
|
|
|
|
req, err := http.NewRequest("GET",
|
|
"/v1/catalog/nodes?wait=60foo&index=1000", nil)
|
|
if err != nil {
|
|
t.Fatalf("err: %v", err)
|
|
}
|
|
|
|
if d := parseWait(resp, req, &b); !d {
|
|
t.Fatalf("expected done")
|
|
}
|
|
|
|
if resp.Code != 400 {
|
|
t.Fatalf("bad code: %v", resp.Code)
|
|
}
|
|
}
|
|
|
|
func TestParseWait_InvalidIndex(t *testing.T) {
|
|
ci.Parallel(t)
|
|
resp := httptest.NewRecorder()
|
|
var b structs.QueryOptions
|
|
|
|
req, err := http.NewRequest("GET",
|
|
"/v1/catalog/nodes?wait=60s&index=foo", nil)
|
|
if err != nil {
|
|
t.Fatalf("err: %v", err)
|
|
}
|
|
|
|
if d := parseWait(resp, req, &b); !d {
|
|
t.Fatalf("expected done")
|
|
}
|
|
|
|
if resp.Code != 400 {
|
|
t.Fatalf("bad code: %v", resp.Code)
|
|
}
|
|
}
|
|
|
|
func TestParseConsistency(t *testing.T) {
|
|
ci.Parallel(t)
|
|
var b structs.QueryOptions
|
|
|
|
req, err := http.NewRequest("GET",
|
|
"/v1/catalog/nodes?stale", nil)
|
|
if err != nil {
|
|
t.Fatalf("err: %v", err)
|
|
}
|
|
|
|
parseConsistency(req, &b)
|
|
if !b.AllowStale {
|
|
t.Fatalf("Bad: %v", b)
|
|
}
|
|
|
|
b = structs.QueryOptions{}
|
|
req, err = http.NewRequest("GET",
|
|
"/v1/catalog/nodes?consistent", nil)
|
|
if err != nil {
|
|
t.Fatalf("err: %v", err)
|
|
}
|
|
|
|
parseConsistency(req, &b)
|
|
if b.AllowStale {
|
|
t.Fatalf("Bad: %v", b)
|
|
}
|
|
}
|
|
|
|
func TestParseRegion(t *testing.T) {
|
|
ci.Parallel(t)
|
|
s := makeHTTPServer(t, nil)
|
|
defer s.Shutdown()
|
|
|
|
req, err := http.NewRequest("GET",
|
|
"/v1/jobs?region=foo", nil)
|
|
if err != nil {
|
|
t.Fatalf("err: %v", err)
|
|
}
|
|
|
|
var region string
|
|
s.Server.parseRegion(req, ®ion)
|
|
if region != "foo" {
|
|
t.Fatalf("bad %s", region)
|
|
}
|
|
|
|
region = ""
|
|
req, err = http.NewRequest("GET", "/v1/jobs", nil)
|
|
if err != nil {
|
|
t.Fatalf("err: %v", err)
|
|
}
|
|
|
|
s.Server.parseRegion(req, ®ion)
|
|
if region != "global" {
|
|
t.Fatalf("bad %s", region)
|
|
}
|
|
}
|
|
|
|
func TestParseToken(t *testing.T) {
|
|
ci.Parallel(t)
|
|
s := makeHTTPServer(t, nil)
|
|
defer s.Shutdown()
|
|
|
|
cases := []struct {
|
|
Name string
|
|
HeaderKey string
|
|
HeaderValue string
|
|
ExpectedToken string
|
|
}{
|
|
{
|
|
Name: "Parses token from X-Nomad-Token",
|
|
HeaderKey: "X-Nomad-Token",
|
|
HeaderValue: "foobar",
|
|
ExpectedToken: "foobar",
|
|
},
|
|
{
|
|
Name: "Parses token from bearer authentication",
|
|
HeaderKey: "Authorization",
|
|
HeaderValue: "Bearer foobar",
|
|
ExpectedToken: "foobar",
|
|
},
|
|
{
|
|
Name: "Fails to parse token from bad bearer authentication",
|
|
HeaderKey: "Authorization",
|
|
HeaderValue: "foobar",
|
|
ExpectedToken: "",
|
|
},
|
|
}
|
|
|
|
for i := range cases {
|
|
tc := cases[i]
|
|
t.Run(tc.Name, func(t *testing.T) {
|
|
req, err := http.NewRequest("GET", "/v1/jobs", nil)
|
|
req.Header.Add(tc.HeaderKey, tc.HeaderValue)
|
|
if err != nil {
|
|
t.Fatalf("err: %v", err)
|
|
}
|
|
|
|
var token string
|
|
s.Server.parseToken(req, &token)
|
|
if token != tc.ExpectedToken {
|
|
t.Fatalf("bad %s", token)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestParseBool(t *testing.T) {
|
|
ci.Parallel(t)
|
|
|
|
cases := []struct {
|
|
Input string
|
|
Expected *bool
|
|
Err bool // true if an error should be expected
|
|
}{
|
|
{
|
|
Input: "",
|
|
Expected: nil,
|
|
},
|
|
{
|
|
Input: "true",
|
|
Expected: pointer.Of(true),
|
|
},
|
|
{
|
|
Input: "false",
|
|
Expected: pointer.Of(false),
|
|
},
|
|
{
|
|
Input: "1234",
|
|
Err: true,
|
|
},
|
|
}
|
|
|
|
for i := range cases {
|
|
tc := cases[i]
|
|
t.Run("Input-"+tc.Input, func(t *testing.T) {
|
|
testURL, err := url.Parse("http://localhost/foo?resources=" + tc.Input)
|
|
require.NoError(t, err)
|
|
req := &http.Request{
|
|
URL: testURL,
|
|
}
|
|
|
|
result, err := parseBool(req, "resources")
|
|
if tc.Err {
|
|
require.Error(t, err)
|
|
require.Nil(t, result)
|
|
} else {
|
|
require.NoError(t, err)
|
|
require.Equal(t, tc.Expected, result)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func Test_parseInt(t *testing.T) {
|
|
ci.Parallel(t)
|
|
|
|
cases := []struct {
|
|
Input string
|
|
Expected *int
|
|
Err bool
|
|
}{
|
|
{
|
|
Input: "",
|
|
Expected: nil,
|
|
},
|
|
{
|
|
Input: "13",
|
|
Expected: pointer.Of(13),
|
|
},
|
|
{
|
|
Input: "99",
|
|
Expected: pointer.Of(99),
|
|
},
|
|
{
|
|
Input: "ten",
|
|
Err: true,
|
|
},
|
|
}
|
|
|
|
for i := range cases {
|
|
tc := cases[i]
|
|
t.Run("Input-"+tc.Input, func(t *testing.T) {
|
|
testURL, err := url.Parse("http://localhost/foo?eval_priority=" + tc.Input)
|
|
require.NoError(t, err)
|
|
req := &http.Request{
|
|
URL: testURL,
|
|
}
|
|
|
|
result, err := parseInt(req, "eval_priority")
|
|
if tc.Err {
|
|
require.Error(t, err)
|
|
require.Nil(t, result)
|
|
} else {
|
|
require.NoError(t, err)
|
|
require.Equal(t, tc.Expected, result)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestParsePagination(t *testing.T) {
|
|
ci.Parallel(t)
|
|
s := makeHTTPServer(t, nil)
|
|
defer s.Shutdown()
|
|
|
|
cases := []struct {
|
|
Input string
|
|
ExpectedNextToken string
|
|
ExpectedPerPage int32
|
|
}{
|
|
{
|
|
Input: "",
|
|
},
|
|
{
|
|
Input: "next_token=a&per_page=3",
|
|
ExpectedNextToken: "a",
|
|
ExpectedPerPage: 3,
|
|
},
|
|
{
|
|
Input: "next_token=a&next_token=b",
|
|
ExpectedNextToken: "a",
|
|
},
|
|
{
|
|
Input: "per_page=a",
|
|
},
|
|
}
|
|
|
|
for i := range cases {
|
|
tc := cases[i]
|
|
t.Run("Input-"+tc.Input, func(t *testing.T) {
|
|
|
|
req, err := http.NewRequest("GET",
|
|
"/v1/volumes/csi/external?"+tc.Input, nil)
|
|
|
|
require.NoError(t, err)
|
|
opts := &structs.QueryOptions{}
|
|
parsePagination(req, opts)
|
|
require.Equal(t, tc.ExpectedNextToken, opts.NextToken)
|
|
require.Equal(t, tc.ExpectedPerPage, opts.PerPage)
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestHTTP_VerifyHTTPSClient asserts that a client certificate signed by the
|
|
// appropriate CA is required when VerifyHTTPSClient=true.
|
|
func TestHTTP_VerifyHTTPSClient(t *testing.T) {
|
|
ci.Parallel(t)
|
|
const (
|
|
cafile = "../../helper/tlsutil/testdata/ca.pem"
|
|
foocert = "../../helper/tlsutil/testdata/nomad-foo.pem"
|
|
fookey = "../../helper/tlsutil/testdata/nomad-foo-key.pem"
|
|
)
|
|
s := makeHTTPServer(t, func(c *Config) {
|
|
c.Region = "foo" // match the region on foocert
|
|
c.TLSConfig = &config.TLSConfig{
|
|
EnableHTTP: true,
|
|
VerifyHTTPSClient: true,
|
|
CAFile: cafile,
|
|
CertFile: foocert,
|
|
KeyFile: fookey,
|
|
}
|
|
})
|
|
defer s.Shutdown()
|
|
|
|
reqURL := fmt.Sprintf("https://%s/v1/agent/self", s.Agent.config.AdvertiseAddrs.HTTP)
|
|
|
|
// FAIL: Requests that expect 127.0.0.1 as the name should fail
|
|
resp, err := http.Get(reqURL)
|
|
if err == nil {
|
|
resp.Body.Close()
|
|
t.Fatalf("expected non-nil error but received: %v", resp.StatusCode)
|
|
}
|
|
urlErr, ok := err.(*url.Error)
|
|
if !ok {
|
|
t.Fatalf("expected a *url.Error but received: %T -> %v", err, err)
|
|
}
|
|
hostErr, ok := urlErr.Err.(x509.HostnameError)
|
|
if !ok {
|
|
t.Fatalf("expected a x509.HostnameError but received: %T -> %v", urlErr.Err, urlErr.Err)
|
|
}
|
|
if expected := "127.0.0.1"; hostErr.Host != expected {
|
|
t.Fatalf("expected hostname on error to be %q but found %q", expected, hostErr.Host)
|
|
}
|
|
|
|
// FAIL: Requests that specify a valid hostname but not the CA should
|
|
// fail
|
|
tlsConf := &tls.Config{
|
|
ServerName: "client.regionFoo.nomad",
|
|
}
|
|
transport := &http.Transport{TLSClientConfig: tlsConf}
|
|
client := &http.Client{Transport: transport}
|
|
req, err := http.NewRequest("GET", reqURL, nil)
|
|
if err != nil {
|
|
t.Fatalf("error creating request: %v", err)
|
|
}
|
|
resp, err = client.Do(req)
|
|
if err == nil {
|
|
resp.Body.Close()
|
|
t.Fatalf("expected non-nil error but received: %v", resp.StatusCode)
|
|
}
|
|
urlErr, ok = err.(*url.Error)
|
|
if !ok {
|
|
t.Fatalf("expected a *url.Error but received: %T -> %v", err, err)
|
|
}
|
|
_, ok = urlErr.Err.(x509.UnknownAuthorityError)
|
|
if !ok {
|
|
t.Fatalf("expected a x509.UnknownAuthorityError but received: %T -> %v", urlErr.Err, urlErr.Err)
|
|
}
|
|
|
|
// FAIL: Requests that specify a valid hostname and CA cert but lack a
|
|
// client certificate should fail
|
|
cacertBytes, err := ioutil.ReadFile(cafile)
|
|
if err != nil {
|
|
t.Fatalf("error reading cacert: %v", err)
|
|
}
|
|
tlsConf.RootCAs = x509.NewCertPool()
|
|
tlsConf.RootCAs.AppendCertsFromPEM(cacertBytes)
|
|
req, err = http.NewRequest("GET", reqURL, nil)
|
|
if err != nil {
|
|
t.Fatalf("error creating request: %v", err)
|
|
}
|
|
resp, err = client.Do(req)
|
|
if err == nil {
|
|
resp.Body.Close()
|
|
t.Fatalf("expected non-nil error but received: %v", resp.StatusCode)
|
|
}
|
|
urlErr, ok = err.(*url.Error)
|
|
if !ok {
|
|
t.Fatalf("expected a *url.Error but received: %T -> %v", err, err)
|
|
}
|
|
|
|
var opErr *net.OpError
|
|
ok = errors.As(urlErr.Err, &opErr)
|
|
if !ok {
|
|
t.Fatalf("expected a *net.OpErr but received: %T -> %v", urlErr.Err, urlErr.Err)
|
|
}
|
|
const badCertificate = "tls: bad certificate" // from crypto/tls/alert.go:52 and RFC 5246 § A.3
|
|
if opErr.Err.Error() != badCertificate {
|
|
t.Fatalf("expected tls.alert bad_certificate but received: %q", opErr.Err.Error())
|
|
}
|
|
|
|
// PASS: Requests that specify a valid hostname, CA cert, and client
|
|
// certificate succeed.
|
|
tlsConf.GetClientCertificate = func(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
|
|
c, err := tls.LoadX509KeyPair(foocert, fookey)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &c, nil
|
|
}
|
|
transport = &http.Transport{TLSClientConfig: tlsConf}
|
|
client = &http.Client{Transport: transport}
|
|
req, err = http.NewRequest("GET", reqURL, nil)
|
|
if err != nil {
|
|
t.Fatalf("error creating request: %v", err)
|
|
}
|
|
resp, err = client.Do(req)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
resp.Body.Close()
|
|
if resp.StatusCode != 200 {
|
|
t.Fatalf("expected 200 status code but got: %d", resp.StatusCode)
|
|
}
|
|
}
|
|
|
|
func TestHTTP_VerifyHTTPSClient_AfterConfigReload(t *testing.T) {
|
|
ci.Parallel(t)
|
|
assert := assert.New(t)
|
|
|
|
const (
|
|
cafile = "../../helper/tlsutil/testdata/ca.pem"
|
|
foocert = "../../helper/tlsutil/testdata/nomad-bad.pem"
|
|
fookey = "../../helper/tlsutil/testdata/nomad-bad-key.pem"
|
|
foocert2 = "../../helper/tlsutil/testdata/nomad-foo.pem"
|
|
fookey2 = "../../helper/tlsutil/testdata/nomad-foo-key.pem"
|
|
)
|
|
|
|
agentConfig := &Config{
|
|
TLSConfig: &config.TLSConfig{
|
|
EnableHTTP: true,
|
|
VerifyHTTPSClient: true,
|
|
CAFile: cafile,
|
|
CertFile: foocert,
|
|
KeyFile: fookey,
|
|
},
|
|
}
|
|
|
|
newConfig := &Config{
|
|
TLSConfig: &config.TLSConfig{
|
|
EnableHTTP: true,
|
|
VerifyHTTPSClient: true,
|
|
CAFile: cafile,
|
|
CertFile: foocert2,
|
|
KeyFile: fookey2,
|
|
},
|
|
}
|
|
|
|
s := makeHTTPServer(t, func(c *Config) {
|
|
c.TLSConfig = agentConfig.TLSConfig
|
|
})
|
|
defer s.Shutdown()
|
|
|
|
// Make an initial request that should fail.
|
|
// Requests that specify a valid hostname, CA cert, and client
|
|
// certificate succeed.
|
|
tlsConf := &tls.Config{
|
|
ServerName: "client.regionFoo.nomad",
|
|
RootCAs: x509.NewCertPool(),
|
|
GetClientCertificate: func(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
|
|
c, err := tls.LoadX509KeyPair(foocert, fookey)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &c, nil
|
|
},
|
|
}
|
|
|
|
// HTTPS request should succeed
|
|
httpsReqURL := fmt.Sprintf("https://%s/v1/agent/self", s.Agent.config.AdvertiseAddrs.HTTP)
|
|
|
|
cacertBytes, err := ioutil.ReadFile(cafile)
|
|
assert.Nil(err)
|
|
tlsConf.RootCAs.AppendCertsFromPEM(cacertBytes)
|
|
|
|
transport := &http.Transport{TLSClientConfig: tlsConf}
|
|
client := &http.Client{Transport: transport}
|
|
req, err := http.NewRequest("GET", httpsReqURL, nil)
|
|
assert.Nil(err)
|
|
|
|
// Check that we get an error that the certificate isn't valid for the
|
|
// region we are contacting.
|
|
_, err = client.Do(req)
|
|
assert.Contains(err.Error(), "certificate is valid for")
|
|
|
|
// Reload the TLS configuration==
|
|
assert.Nil(s.Agent.Reload(newConfig))
|
|
|
|
// Requests that specify a valid hostname, CA cert, and client
|
|
// certificate succeed.
|
|
tlsConf = &tls.Config{
|
|
ServerName: "client.regionFoo.nomad",
|
|
RootCAs: x509.NewCertPool(),
|
|
GetClientCertificate: func(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
|
|
c, err := tls.LoadX509KeyPair(foocert2, fookey2)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &c, nil
|
|
},
|
|
}
|
|
|
|
cacertBytes, err = ioutil.ReadFile(cafile)
|
|
assert.Nil(err)
|
|
tlsConf.RootCAs.AppendCertsFromPEM(cacertBytes)
|
|
|
|
transport = &http.Transport{TLSClientConfig: tlsConf}
|
|
client = &http.Client{Transport: transport}
|
|
req, err = http.NewRequest("GET", httpsReqURL, nil)
|
|
assert.Nil(err)
|
|
|
|
resp, err := client.Do(req)
|
|
if assert.NoError(err) {
|
|
resp.Body.Close()
|
|
assert.Equal(resp.StatusCode, 200)
|
|
}
|
|
}
|
|
|
|
// TestHTTPServer_Limits_Error asserts invalid Limits cause errors. This is the
|
|
// HTTP counterpart to TestAgent_ServerConfig_Limits_Error.
|
|
func TestHTTPServer_Limits_Error(t *testing.T) {
|
|
ci.Parallel(t)
|
|
|
|
cases := []struct {
|
|
tls bool
|
|
timeout string
|
|
limit *int
|
|
expectedErr string
|
|
}{
|
|
{
|
|
tls: true,
|
|
timeout: "",
|
|
limit: nil,
|
|
expectedErr: "error parsing https_handshake_timeout: ",
|
|
},
|
|
{
|
|
tls: false,
|
|
timeout: "",
|
|
limit: nil,
|
|
expectedErr: "error parsing https_handshake_timeout: ",
|
|
},
|
|
{
|
|
tls: true,
|
|
timeout: "-1s",
|
|
limit: nil,
|
|
expectedErr: "https_handshake_timeout must be >= 0",
|
|
},
|
|
{
|
|
tls: false,
|
|
timeout: "-1s",
|
|
limit: nil,
|
|
expectedErr: "https_handshake_timeout must be >= 0",
|
|
},
|
|
{
|
|
tls: true,
|
|
timeout: "5s",
|
|
limit: pointer.Of(-1),
|
|
expectedErr: "http_max_conns_per_client must be >= 0",
|
|
},
|
|
{
|
|
tls: false,
|
|
timeout: "5s",
|
|
limit: pointer.Of(-1),
|
|
expectedErr: "http_max_conns_per_client must be >= 0",
|
|
},
|
|
}
|
|
|
|
for i := range cases {
|
|
tc := cases[i]
|
|
name := fmt.Sprintf("%d-tls-%t-timeout-%s-limit-%v", i, tc.tls, tc.timeout, tc.limit)
|
|
t.Run(name, func(t *testing.T) {
|
|
ci.Parallel(t)
|
|
|
|
conf := &Config{
|
|
normalizedAddrs: &NormalizedAddrs{
|
|
HTTP: []string{"localhost:0"}, // port is never used
|
|
},
|
|
TLSConfig: &config.TLSConfig{
|
|
EnableHTTP: tc.tls,
|
|
},
|
|
Limits: config.Limits{
|
|
HTTPSHandshakeTimeout: tc.timeout,
|
|
HTTPMaxConnsPerClient: tc.limit,
|
|
},
|
|
}
|
|
|
|
// Use a fake agent since the HTTP server should never start
|
|
agent := &Agent{
|
|
logger: testlog.HCLogger(t),
|
|
httpLogger: testlog.HCLogger(t),
|
|
config: conf,
|
|
}
|
|
|
|
srv, err := NewHTTPServers(agent, conf)
|
|
require.Error(t, err)
|
|
require.Nil(t, srv)
|
|
require.Contains(t, err.Error(), tc.expectedErr)
|
|
})
|
|
}
|
|
}
|
|
|
|
func limitStr(limit *int) string {
|
|
if limit == nil {
|
|
return "none"
|
|
}
|
|
return strconv.Itoa(*limit)
|
|
}
|
|
|
|
// TestHTTPServer_Limits_OK asserts that all valid limits combinations
|
|
// (tls/timeout/conns) work.
|
|
func TestHTTPServer_Limits_OK(t *testing.T) {
|
|
ci.Parallel(t)
|
|
|
|
const (
|
|
cafile = "../../helper/tlsutil/testdata/ca.pem"
|
|
foocert = "../../helper/tlsutil/testdata/nomad-foo.pem"
|
|
fookey = "../../helper/tlsutil/testdata/nomad-foo-key.pem"
|
|
maxConns = 10 // limit must be < this for testing
|
|
bufSize = 1 // enough to know if something was written
|
|
)
|
|
|
|
cases := []struct {
|
|
tls bool
|
|
timeout string
|
|
limit *int
|
|
assertTimeout bool
|
|
assertLimit bool
|
|
}{
|
|
{
|
|
tls: false,
|
|
timeout: "5s",
|
|
limit: nil,
|
|
assertTimeout: false,
|
|
assertLimit: false,
|
|
},
|
|
{
|
|
tls: true,
|
|
timeout: "5s",
|
|
limit: nil,
|
|
assertTimeout: true,
|
|
assertLimit: false,
|
|
},
|
|
{
|
|
tls: false,
|
|
timeout: "0",
|
|
limit: nil,
|
|
assertTimeout: false,
|
|
assertLimit: false,
|
|
},
|
|
{
|
|
tls: true,
|
|
timeout: "0",
|
|
limit: nil,
|
|
assertTimeout: false,
|
|
assertLimit: false,
|
|
},
|
|
{
|
|
tls: false,
|
|
timeout: "0",
|
|
limit: pointer.Of(2),
|
|
assertTimeout: false,
|
|
assertLimit: true,
|
|
},
|
|
{
|
|
tls: true,
|
|
timeout: "0",
|
|
limit: pointer.Of(2),
|
|
assertTimeout: false,
|
|
assertLimit: true,
|
|
},
|
|
{
|
|
tls: false,
|
|
timeout: "5s",
|
|
limit: pointer.Of(2),
|
|
assertTimeout: false,
|
|
assertLimit: true,
|
|
},
|
|
{
|
|
tls: true,
|
|
timeout: "5s",
|
|
limit: pointer.Of(2),
|
|
assertTimeout: true,
|
|
assertLimit: true,
|
|
},
|
|
}
|
|
|
|
assertTimeout := func(t *testing.T, a *TestAgent, assertTimeout bool, timeout string) {
|
|
timeoutDeadline, err := time.ParseDuration(timeout)
|
|
require.NoError(t, err)
|
|
|
|
// Increase deadline to detect timeouts
|
|
deadline := timeoutDeadline + time.Second
|
|
|
|
conn, err := net.DialTimeout("tcp", a.Server.Addr, deadline)
|
|
require.NoError(t, err)
|
|
defer func() {
|
|
require.NoError(t, conn.Close())
|
|
}()
|
|
|
|
buf := []byte{0}
|
|
readDeadline := time.Now().Add(deadline)
|
|
err = conn.SetReadDeadline(readDeadline)
|
|
require.NoError(t, err)
|
|
n, err := conn.Read(buf)
|
|
require.Zero(t, n)
|
|
if assertTimeout {
|
|
// Server timeouts == EOF
|
|
require.Equal(t, io.EOF, err)
|
|
|
|
// Perform blocking query to assert timeout is not
|
|
// enabled post-TLS-handshake.
|
|
q := &api.QueryOptions{
|
|
WaitIndex: 10000, // wait a looong time
|
|
WaitTime: deadline,
|
|
}
|
|
|
|
// Assertions don't require certificate validation
|
|
conf := api.DefaultConfig()
|
|
conf.Address = a.HTTPAddr()
|
|
conf.TLSConfig.Insecure = true
|
|
client, err := api.NewClient(conf)
|
|
require.NoError(t, err)
|
|
defer client.Close()
|
|
|
|
// Assert a blocking query isn't timed out by the
|
|
// handshake timeout
|
|
jobs, meta, err := client.Jobs().List(q)
|
|
require.NoError(t, err)
|
|
require.Len(t, jobs, 0)
|
|
require.Truef(t, meta.RequestTime >= deadline,
|
|
"expected RequestTime (%s) >= Deadline (%s)",
|
|
meta.RequestTime, deadline)
|
|
|
|
return
|
|
}
|
|
|
|
// HTTP Server should *not* have timed out.
|
|
// Now() should always be after the read deadline, but
|
|
// isn't a sufficient assertion for correctness as slow
|
|
// tests may cause this to be true even if the server
|
|
// timed out.
|
|
require.True(t, time.Now().After(readDeadline))
|
|
|
|
require.Truef(t, errors.Is(err, os.ErrDeadlineExceeded),
|
|
"error does not wrap os.ErrDeadlineExceeded: (%T) %v", err, err)
|
|
}
|
|
|
|
assertNoLimit := func(t *testing.T, addr string) {
|
|
var err error
|
|
|
|
// Create max connections
|
|
conns := make([]net.Conn, maxConns)
|
|
errCh := make(chan error, maxConns)
|
|
for i := 0; i < maxConns; i++ {
|
|
conns[i], err = net.DialTimeout("tcp", addr, 1*time.Second)
|
|
require.NoError(t, err)
|
|
|
|
go func(i int) {
|
|
buf := []byte{0}
|
|
readDeadline := time.Now().Add(1 * time.Second)
|
|
err = conns[i].SetReadDeadline(readDeadline)
|
|
require.NoError(t, err)
|
|
n, err := conns[i].Read(buf)
|
|
if n > 0 {
|
|
errCh <- fmt.Errorf("n > 0: %d", n)
|
|
return
|
|
}
|
|
errCh <- err
|
|
}(i)
|
|
}
|
|
|
|
// Now assert each error is a clientside read deadline error
|
|
for i := 0; i < maxConns; i++ {
|
|
select {
|
|
case <-time.After(2 * time.Second):
|
|
t.Fatalf("timed out waiting for conn error %d", i)
|
|
case err := <-errCh:
|
|
require.Truef(t, errors.Is(err, os.ErrDeadlineExceeded),
|
|
"error does not wrap os.ErrDeadlineExceeded: (%T) %v", err, err)
|
|
}
|
|
}
|
|
|
|
for i := 0; i < maxConns; i++ {
|
|
require.NoError(t, conns[i].Close())
|
|
}
|
|
}
|
|
|
|
dial := func(t *testing.T, addr string, useTLS bool) (net.Conn, error) {
|
|
if useTLS {
|
|
cert, err := tls.LoadX509KeyPair(foocert, fookey)
|
|
require.NoError(t, err)
|
|
return tls.Dial("tcp", addr, &tls.Config{
|
|
Certificates: []tls.Certificate{cert},
|
|
InsecureSkipVerify: true, // good enough
|
|
})
|
|
} else {
|
|
return net.DialTimeout("tcp", addr, 1*time.Second)
|
|
}
|
|
}
|
|
|
|
assertLimit := func(t *testing.T, addr string, limit int, useTLS bool) {
|
|
var err error
|
|
|
|
// Create limit connections
|
|
conns := make([]net.Conn, limit)
|
|
errCh := make(chan error, limit)
|
|
for i := range conns {
|
|
conn, err := dial(t, addr, useTLS)
|
|
require.NoError(t, err)
|
|
conns[i] = conn
|
|
|
|
go func(i int) {
|
|
buf := []byte{0}
|
|
n, err := conns[i].Read(buf)
|
|
if n > 0 {
|
|
errCh <- fmt.Errorf("n > 0: %d", n)
|
|
return
|
|
}
|
|
errCh <- err
|
|
}(i)
|
|
}
|
|
|
|
select {
|
|
case err := <-errCh:
|
|
t.Fatalf("unexpected error from connection prior to limit: %T %v", err, err)
|
|
case <-time.After(500 * time.Millisecond):
|
|
}
|
|
|
|
// Create a new connection that will go over the connection limit.
|
|
limitConn, err := dial(t, addr, useTLS)
|
|
require.NoError(t, err)
|
|
|
|
response := "HTTP/1.1 429"
|
|
buf := make([]byte, len(response))
|
|
deadline := time.Now().Add(10 * time.Second)
|
|
require.NoError(t, limitConn.SetReadDeadline(deadline))
|
|
n, err := limitConn.Read(buf)
|
|
require.Equal(t, response, string(buf))
|
|
require.Nil(t, err)
|
|
require.Equal(t, len(response), n)
|
|
require.NoError(t, limitConn.Close())
|
|
|
|
// Assert existing connections are ok
|
|
require.Len(t, errCh, 0)
|
|
|
|
// Cleanup
|
|
for _, conn := range conns {
|
|
require.NoError(t, conn.Close())
|
|
}
|
|
|
|
for range conns {
|
|
err := <-errCh
|
|
require.Contains(t, err.Error(), "use of closed network connection")
|
|
}
|
|
}
|
|
|
|
for i := range cases {
|
|
tc := cases[i]
|
|
name := fmt.Sprintf("%d-tls-%t-timeout-%s-limit-%v", i, tc.tls, tc.timeout, limitStr(tc.limit))
|
|
t.Run(name, func(t *testing.T) {
|
|
ci.Parallel(t)
|
|
|
|
if tc.limit != nil && *tc.limit >= maxConns {
|
|
t.Fatalf("test fixture failure: cannot assert limit (%d) >= max (%d)", *tc.limit, maxConns)
|
|
}
|
|
|
|
s := makeHTTPServer(t, func(c *Config) {
|
|
if tc.tls {
|
|
c.TLSConfig = &config.TLSConfig{
|
|
EnableHTTP: true,
|
|
CAFile: cafile,
|
|
CertFile: foocert,
|
|
KeyFile: fookey,
|
|
}
|
|
}
|
|
c.Limits.HTTPSHandshakeTimeout = tc.timeout
|
|
c.Limits.HTTPMaxConnsPerClient = tc.limit
|
|
c.LogLevel = "ERROR"
|
|
})
|
|
defer func() {
|
|
require.NoError(t, s.Shutdown())
|
|
}()
|
|
|
|
assertTimeout(t, s, tc.assertTimeout, tc.timeout)
|
|
|
|
if tc.assertLimit {
|
|
// There's a race between assertTimeout(false) closing
|
|
// its connection and the HTTP server noticing and
|
|
// un-tracking it. Since there's no way to coordinate
|
|
// when this occurs, sleeping is the only way to avoid
|
|
// asserting limits before the timed out connection is
|
|
// untracked.
|
|
time.Sleep(1 * time.Second)
|
|
|
|
assertLimit(t, s.Server.Addr, *tc.limit, tc.tls)
|
|
} else {
|
|
assertNoLimit(t, s.Server.Addr)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestHTTPServer_ResolveToken(t *testing.T) {
|
|
ci.Parallel(t)
|
|
|
|
// Setup two servers, one with ACL enabled and another with ACL disabled.
|
|
noACLServer := makeHTTPServer(t, func(c *Config) {
|
|
c.ACL = &ACLConfig{Enabled: false}
|
|
})
|
|
defer noACLServer.Shutdown()
|
|
|
|
ACLServer := makeHTTPServer(t, func(c *Config) {
|
|
c.ACL = &ACLConfig{Enabled: true}
|
|
})
|
|
defer ACLServer.Shutdown()
|
|
|
|
// Register sample token.
|
|
state := ACLServer.Agent.server.State()
|
|
token := mock.CreatePolicyAndToken(t, state, 1000, "node", mock.NodePolicy(acl.PolicyWrite))
|
|
|
|
// Tests cases.
|
|
t.Run("acl disabled", func(t *testing.T) {
|
|
req := &http.Request{Body: http.NoBody}
|
|
got, err := noACLServer.Server.ResolveToken(req)
|
|
require.NoError(t, err)
|
|
require.Nil(t, got)
|
|
})
|
|
|
|
t.Run("token not found", func(t *testing.T) {
|
|
req := &http.Request{
|
|
Body: http.NoBody,
|
|
Header: make(map[string][]string),
|
|
}
|
|
setToken(req, mock.ACLToken())
|
|
got, err := ACLServer.Server.ResolveToken(req)
|
|
require.Nil(t, got)
|
|
require.Error(t, err)
|
|
require.Contains(t, err.Error(), "ACL token not found")
|
|
})
|
|
|
|
t.Run("set token", func(t *testing.T) {
|
|
req := &http.Request{
|
|
Body: http.NoBody,
|
|
Header: make(map[string][]string),
|
|
}
|
|
setToken(req, token)
|
|
got, err := ACLServer.Server.ResolveToken(req)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, got)
|
|
require.True(t, got.AllowNodeWrite())
|
|
})
|
|
}
|
|
|
|
func Test_IsAPIClientError(t *testing.T) {
|
|
ci.Parallel(t)
|
|
|
|
trueCases := []int{400, 403, 404, 499}
|
|
for _, c := range trueCases {
|
|
require.Truef(t, isAPIClientError(c), "code: %v", c)
|
|
}
|
|
|
|
falseCases := []int{100, 300, 500, 501, 505}
|
|
for _, c := range falseCases {
|
|
require.Falsef(t, isAPIClientError(c), "code: %v", c)
|
|
}
|
|
}
|
|
|
|
func Test_decodeBody(t *testing.T) {
|
|
ci.Parallel(t)
|
|
|
|
testCases := []struct {
|
|
inputReq *http.Request
|
|
inputOut interface{}
|
|
expectedOut interface{}
|
|
expectedError error
|
|
name string
|
|
}{
|
|
{
|
|
inputReq: &http.Request{Body: http.NoBody},
|
|
expectedError: errors.New("Request body is empty"),
|
|
name: "empty input request body",
|
|
},
|
|
{
|
|
inputReq: &http.Request{Body: ioutil.NopCloser(strings.NewReader(`{"foo":"bar"}`))},
|
|
inputOut: &struct {
|
|
Foo string `json:"foo"`
|
|
}{},
|
|
expectedOut: &struct {
|
|
Foo string `json:"foo"`
|
|
}{Foo: "bar"},
|
|
expectedError: nil,
|
|
name: "populated request body and correct out",
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
actualError := decodeBody(tc.inputReq, tc.inputOut)
|
|
assert.Equal(t, tc.expectedError, actualError, tc.name)
|
|
assert.Equal(t, tc.expectedOut, tc.inputOut, tc.name)
|
|
})
|
|
}
|
|
}
|
|
|
|
// BenchmarkHTTPServer_JSONEncodingWithExtensions benchmarks the performance of
|
|
// encoding JSON objects using extensions
|
|
func BenchmarkHTTPServer_JSONEncodingWithExtensions(b *testing.B) {
|
|
benchmarkJsonEncoding(b, structs.JsonHandleWithExtensions)
|
|
}
|
|
|
|
// BenchmarkHTTPServer_JSONEncodingWithoutExtensions benchmarks the performance of
|
|
// encoding JSON objects using extensions
|
|
func BenchmarkHTTPServer_JSONEncodingWithoutExtensions(b *testing.B) {
|
|
benchmarkJsonEncoding(b, structs.JsonHandle)
|
|
}
|
|
|
|
func benchmarkJsonEncoding(b *testing.B, handle *codec.JsonHandle) {
|
|
n := mock.Node()
|
|
var buf bytes.Buffer
|
|
|
|
enc := codec.NewEncoder(&buf, handle)
|
|
for i := 0; i < b.N; i++ {
|
|
buf.Reset()
|
|
err := enc.Encode(n)
|
|
require.NoError(b, err)
|
|
}
|
|
}
|
|
|
|
func httpTest(t testing.TB, cb func(c *Config), f func(srv *TestAgent)) {
|
|
s := makeHTTPServer(t, cb)
|
|
defer s.Shutdown()
|
|
testutil.WaitForLeader(t, s.Agent.RPC)
|
|
f(s)
|
|
}
|
|
|
|
func httpACLTest(t testing.TB, cb func(c *Config), f func(srv *TestAgent)) {
|
|
s := makeHTTPServer(t, func(c *Config) {
|
|
c.ACL.Enabled = true
|
|
if cb != nil {
|
|
cb(c)
|
|
}
|
|
})
|
|
defer s.Shutdown()
|
|
testutil.WaitForLeader(t, s.Agent.RPC)
|
|
f(s)
|
|
}
|
|
|
|
func setToken(req *http.Request, token *structs.ACLToken) {
|
|
req.Header.Set("X-Nomad-Token", token.SecretID)
|
|
}
|
|
|
|
func setNamespace(req *http.Request, ns string) {
|
|
q := req.URL.Query()
|
|
q.Add("namespace", ns)
|
|
req.URL.RawQuery = q.Encode()
|
|
}
|
|
|
|
func encodeReq(obj interface{}) io.ReadCloser {
|
|
buf := bytes.NewBuffer(nil)
|
|
enc := json.NewEncoder(buf)
|
|
enc.Encode(obj)
|
|
return ioutil.NopCloser(buf)
|
|
}
|