// Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: MPL-2.0 package agent import ( "bytes" "context" "crypto/tls" "encoding/json" "fmt" "io" "net" "net/http" "net/http/httptest" "net/netip" "net/url" "os" "path/filepath" "runtime" "strconv" "strings" "testing" "time" "github.com/NYTimes/gziphandler" "github.com/hashicorp/go-cleanhttp" "github.com/hashicorp/go-hclog" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/net/http2" "github.com/hashicorp/consul/agent/config" "github.com/hashicorp/consul/agent/consul" "github.com/hashicorp/consul/agent/structs" tokenStore "github.com/hashicorp/consul/agent/token" "github.com/hashicorp/consul/api" "github.com/hashicorp/consul/sdk/testutil" "github.com/hashicorp/consul/sdk/testutil/retry" "github.com/hashicorp/consul/testrpc" ) func TestHTTPServer_UnixSocket(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") } t.Parallel() if runtime.GOOS == "windows" { t.SkipNow() } tempDir := testutil.TempDir(t, "consul") socket := filepath.Join(tempDir, "test.sock") // Only testing mode, since uid/gid might not be settable // from test environment. a := NewTestAgent(t, ` addresses { http = "unix://`+socket+`" } unix_sockets { mode = "0777" } `) defer a.Shutdown() // Ensure the socket was created if _, err := os.Stat(socket); err != nil { t.Fatalf("err: %s", err) } // Ensure the mode was set properly fi, err := os.Stat(socket) if err != nil { t.Fatalf("err: %s", err) } if fi.Mode().String() != "Srwxrwxrwx" { t.Fatalf("bad permissions: %s", fi.Mode()) } // Ensure we can get a response from the socket. trans := cleanhttp.DefaultTransport() trans.DialContext = func(_ context.Context, _, _ string) (net.Conn, error) { return net.Dial("unix", socket) } client := &http.Client{ Transport: trans, } // This URL doesn't look like it makes sense, but the scheme (http://) and // the host (127.0.0.1) are required by the HTTP client library. In reality // this will just use the custom dialer and talk to the socket. resp, err := client.Get("http://127.0.0.1/v1/agent/self") if err != nil { t.Fatalf("err: %s", err) } defer resp.Body.Close() if body, err := io.ReadAll(resp.Body); err != nil || len(body) == 0 { t.Fatalf("bad: %s %v", body, err) } } func TestHTTPServer_UnixSocket_FileExists(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") } t.Parallel() if runtime.GOOS == "windows" { t.SkipNow() } tempDir := testutil.TempDir(t, "consul") socket := filepath.Join(tempDir, "test.sock") // Create a regular file at the socket path if err := os.WriteFile(socket, []byte("hello world"), 0644); err != nil { t.Fatalf("err: %s", err) } fi, err := os.Stat(socket) if err != nil { t.Fatalf("err: %s", err) } if !fi.Mode().IsRegular() { t.Fatalf("not a regular file: %s", socket) } a := NewTestAgent(t, ` addresses { http = "unix://`+socket+`" } `) defer a.Shutdown() // Ensure the file was replaced by the socket fi, err = os.Stat(socket) if err != nil { t.Fatalf("err: %s", err) } if fi.Mode()&os.ModeSocket == 0 { t.Fatalf("expected socket to replace file") } } func TestHTTPSServer_UnixSocket(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") } t.Parallel() if runtime.GOOS == "windows" { t.SkipNow() } tempDir := testutil.TempDir(t, "consul") socket := filepath.Join(tempDir, "test.sock") a := StartTestAgent(t, TestAgent{ UseHTTPS: true, HCL: ` addresses { https = "unix://` + socket + `" } unix_sockets { mode = "0777" } tls { defaults { ca_file = "../test/client_certs/rootca.crt" cert_file = "../test/client_certs/server.crt" key_file = "../test/client_certs/server.key" } } `, }) defer a.Shutdown() // Ensure the socket was created if _, err := os.Stat(socket); err != nil { t.Fatalf("err: %s", err) } // Ensure the mode was set properly fi, err := os.Stat(socket) if err != nil { t.Fatalf("err: %s", err) } if fi.Mode().String() != "Srwxrwxrwx" { t.Fatalf("bad permissions: %s", fi.Mode()) } // Make an HTTP/2-enabled client, using the API helpers to set // up TLS to be as normal as possible for Consul. tlscfg := &api.TLSConfig{ Address: "consul.test", KeyFile: "../test/client_certs/client.key", CertFile: "../test/client_certs/client.crt", CAFile: "../test/client_certs/rootca.crt", } tlsccfg, err := api.SetupTLSConfig(tlscfg) if err != nil { t.Fatalf("err: %v", err) } transport := api.DefaultConfig().Transport transport.TLSHandshakeTimeout = 30 * time.Second transport.TLSClientConfig = tlsccfg if err := http2.ConfigureTransport(transport); err != nil { t.Fatalf("err: %v", err) } transport.DialContext = func(_ context.Context, _, _ string) (net.Conn, error) { return net.Dial("unix", socket) } client := &http.Client{Transport: transport} u, err := url.Parse("https://unix" + socket) if err != nil { t.Fatalf("err: %s", err) } u.Path = "/v1/agent/self" u.Scheme = "https" resp, err := client.Get(u.String()) if err != nil { t.Fatalf("err: %s", err) } defer resp.Body.Close() if body, err := io.ReadAll(resp.Body); err != nil || len(body) == 0 { t.Fatalf("bad: %s %v", body, err) } else if !strings.Contains(string(body), "NodeName") { t.Fatalf("NodeName not found in results: %s", string(body)) } } func TestSetupHTTPServer_HTTP2(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") } t.Parallel() // Fire up an agent with TLS enabled. a := StartTestAgent(t, TestAgent{ UseHTTPS: true, HCL: ` tls { defaults { ca_file = "../test/client_certs/rootca.crt" cert_file = "../test/client_certs/server.crt" key_file = "../test/client_certs/server.key" } } `, }) defer a.Shutdown() // Make an HTTP/2-enabled client, using the API helpers to set // up TLS to be as normal as possible for Consul. tlscfg := &api.TLSConfig{ Address: "consul.test", KeyFile: "../test/client_certs/client.key", CertFile: "../test/client_certs/client.crt", CAFile: "../test/client_certs/rootca.crt", } tlsccfg, err := api.SetupTLSConfig(tlscfg) if err != nil { t.Fatalf("err: %v", err) } transport := api.DefaultConfig().Transport transport.TLSHandshakeTimeout = 30 * time.Second transport.TLSClientConfig = tlsccfg if err := http2.ConfigureTransport(transport); err != nil { t.Fatalf("err: %v", err) } httpClient := &http.Client{Transport: transport} // Hook a handler that echoes back the protocol. handler := func(resp http.ResponseWriter, req *http.Request) { resp.WriteHeader(http.StatusOK) fmt.Fprint(resp, req.Proto) } // Create an httpServer to be configured with setupHTTPS, and add our // custom handler. httpServer := &http.Server{} noopConnState := func(net.Conn, http.ConnState) {} err = setupHTTPS(httpServer, noopConnState, time.Second) require.NoError(t, err) a.enableDebug.Store(true) srvHandler := a.srv.handler() mux, ok := srvHandler.(*wrappedMux) require.True(t, ok, "expected a *wrappedMux, got %T", handler) mux.mux.HandleFunc("/echo", handler) httpServer.Handler = mux listener, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) tlsListener := tls.NewListener(listener, a.tlsConfigurator.IncomingHTTPSConfig()) go httpServer.Serve(tlsListener) defer httpServer.Shutdown(context.Background()) // Call it and make sure we see HTTP/2. url := fmt.Sprintf("https://%s/echo", listener.Addr().String()) resp, err := httpClient.Get(url) if err != nil { t.Fatalf("err: %v", err) } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { t.Fatalf("err: %v", err) } if !bytes.Equal(body, []byte("HTTP/2.0")) { t.Fatalf("bad: %v", body) } // This doesn't have a closed-loop way to verify HTTP/2 support for // some other endpoint, but configure an API client and make a call // just as a sanity check. cfg := &api.Config{ Address: listener.Addr().String(), Scheme: "https", HttpClient: httpClient, } client, err := api.NewClient(cfg) if err != nil { t.Fatalf("err: %v", err) } if _, err := client.Agent().Self(); err != nil { t.Fatalf("err: %v", err) } } func TestSetIndex(t *testing.T) { t.Parallel() resp := httptest.NewRecorder() setIndex(resp, 1000) header := resp.Header().Get("X-Consul-Index") if header != "1000" { t.Fatalf("Bad: %v", header) } setIndex(resp, 2000) if v := resp.Header()["X-Consul-Index"]; len(v) != 1 { t.Fatalf("bad: %#v", v) } } func TestSetKnownLeader(t *testing.T) { t.Parallel() resp := httptest.NewRecorder() setKnownLeader(resp, true) header := resp.Header().Get("X-Consul-KnownLeader") if header != "true" { t.Fatalf("Bad: %v", header) } resp = httptest.NewRecorder() setKnownLeader(resp, false) header = resp.Header().Get("X-Consul-KnownLeader") if header != "false" { t.Fatalf("Bad: %v", header) } } func TestSetFilteredByACLs(t *testing.T) { t.Parallel() resp := httptest.NewRecorder() setResultsFilteredByACLs(resp, true) header := resp.Header().Get("X-Consul-Results-Filtered-By-ACLs") if header != "true" { t.Fatalf("Bad: %v", header) } resp = httptest.NewRecorder() setResultsFilteredByACLs(resp, false) header = resp.Header().Get("X-Consul-Results-Filtered-By-ACLs") if header != "" { t.Fatalf("Bad: %v", header) } } func TestSetLastContact(t *testing.T) { t.Parallel() tests := []struct { desc string d time.Duration h string }{ {"neg", -1, "0"}, {"zero", 0, "0"}, {"pos", 123 * time.Millisecond, "123"}, {"pos ms only", 123456 * time.Microsecond, "123"}, } for _, tt := range tests { t.Run(tt.desc, func(t *testing.T) { resp := httptest.NewRecorder() setLastContact(resp, tt.d) header := resp.Header().Get("X-Consul-LastContact") if got, want := header, tt.h; got != want { t.Fatalf("got X-Consul-LastContact header %q want %q", got, want) } }) } } func TestSetMeta(t *testing.T) { t.Parallel() meta := structs.QueryMeta{ Index: 1000, KnownLeader: true, LastContact: 123456 * time.Microsecond, ResultsFilteredByACLs: true, } resp := httptest.NewRecorder() setMeta(resp, &meta) testCases := map[string]string{ "X-Consul-Index": "1000", "X-Consul-KnownLeader": "true", "X-Consul-LastContact": "123", "X-Consul-Results-Filtered-By-ACLs": "true", } for header, expectedValue := range testCases { if v := resp.Header().Get(header); v != expectedValue { t.Fatalf("expected %q for header %s got %q", expectedValue, header, v) } } } func TestHTTPAPI_BlockEndpoints(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") } t.Parallel() a := NewTestAgent(t, ` http_config { block_endpoints = ["/v1/agent/self"] } `) defer a.Shutdown() handler := func(resp http.ResponseWriter, req *http.Request) (interface{}, error) { return nil, nil } // Try a blocked endpoint, which should get a 403. { req, _ := http.NewRequest("GET", "/v1/agent/self", nil) resp := httptest.NewRecorder() a.srv.wrap(handler, []string{"GET"})(resp, req) if got, want := resp.Code, http.StatusForbidden; got != want { t.Fatalf("bad response code got %d want %d", got, want) } } // Make sure some other endpoint still works. { req, _ := http.NewRequest("GET", "/v1/agent/checks", nil) resp := httptest.NewRecorder() a.srv.wrap(handler, []string{"GET"})(resp, req) if got, want := resp.Code, http.StatusOK; got != want { t.Fatalf("bad response code got %d want %d", got, want) } } } func TestHTTPAPI_Ban_Nonprintable_Characters(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") } a := NewTestAgent(t, "") defer a.Shutdown() _, err := http.NewRequest("GET", "/v1/kv/bad\x00ness", nil) if err == nil { t.Fatal("expected error") } req, err := http.NewRequest("GET", "/v1/kv/bad%00ness", nil) if err != nil { t.Fatal(err) } resp := httptest.NewRecorder() a.enableDebug.Store(true) a.srv.handler().ServeHTTP(resp, req) if got, want := resp.Code, http.StatusBadRequest; got != want { t.Fatalf("bad response code got %d want %d", got, want) } } func TestHTTPAPI_Allow_Nonprintable_Characters_With_Flag(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") } a := NewTestAgent(t, "disable_http_unprintable_char_filter = true") defer a.Shutdown() _, err := http.NewRequest("GET", "/v1/kv/bad\x00ness", nil) if err == nil { t.Fatal("expected error") } req, err := http.NewRequest("GET", "/v1/kv/bad%00ness", nil) if err != nil { t.Fatal(err) } resp := httptest.NewRecorder() a.enableDebug.Store(true) a.srv.handler().ServeHTTP(resp, req) // Key doesn't actually exist so we should get 404 if got, want := resp.Code, http.StatusNotFound; got != want { t.Fatalf("bad response code got %d want %d", got, want) } } func TestHTTPAPI_TranslateAddrHeader(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") } t.Parallel() // Header should not be present if address translation is off. { a := NewTestAgent(t, "") defer a.Shutdown() resp := httptest.NewRecorder() handler := func(resp http.ResponseWriter, req *http.Request) (interface{}, error) { return nil, nil } req, _ := http.NewRequest("GET", "/v1/agent/self", nil) a.srv.wrap(handler, []string{"GET"})(resp, req) translate := resp.Header().Get("X-Consul-Translate-Addresses") if translate != "" { t.Fatalf("bad: expected %q, got %q", "", translate) } } // Header should be set to true if it's turned on. { a := NewTestAgent(t, ` translate_wan_addrs = true `) defer a.Shutdown() resp := httptest.NewRecorder() handler := func(resp http.ResponseWriter, req *http.Request) (interface{}, error) { return nil, nil } req, _ := http.NewRequest("GET", "/v1/agent/self", nil) a.srv.wrap(handler, []string{"GET"})(resp, req) translate := resp.Header().Get("X-Consul-Translate-Addresses") if translate != "true" { t.Fatalf("bad: expected %q, got %q", "true", translate) } } } func TestHTTPAPI_DefaultACLPolicy(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") } t.Parallel() type testcase struct { name string hcl string expect string } cases := []testcase{ { name: "default is allow", hcl: ``, expect: "allow", }, { name: "explicit allow", hcl: `acl { default_policy = "allow" }`, expect: "allow", }, { name: "explicit deny", hcl: `acl { default_policy = "deny" }`, expect: "deny", }, } for _, tc := range cases { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() a := NewTestAgent(t, tc.hcl) defer a.Shutdown() resp := httptest.NewRecorder() handler := func(resp http.ResponseWriter, req *http.Request) (interface{}, error) { return nil, nil } req, _ := http.NewRequest("GET", "/v1/agent/self", nil) a.srv.wrap(handler, []string{"GET"})(resp, req) require.Equal(t, tc.expect, resp.Header().Get("X-Consul-Default-ACL-Policy")) }) } } func TestHTTPAPIResponseHeaders(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") } t.Parallel() a := NewTestAgent(t, ` ui_config { # Explicitly disable UI so we can ensure the index replacement gets headers too. enabled = false } http_config { response_headers = { "Access-Control-Allow-Origin" = "*" "X-XSS-Protection" = "1; mode=block" "X-Frame-Options" = "SAMEORIGIN" } } `) defer a.Shutdown() requireHasHeadersSet(t, a, "/v1/agent/self") // Check the Index page that just renders a simple message with UI disabled // also gets the right headers. requireHasHeadersSet(t, a, "/") } func requireHasHeadersSet(t *testing.T, a *TestAgent, path string) { t.Helper() resp := httptest.NewRecorder() req, _ := http.NewRequest("GET", path, nil) a.enableDebug.Store(true) a.srv.handler().ServeHTTP(resp, req) hdrs := resp.Header() require.Equal(t, "*", hdrs.Get("Access-Control-Allow-Origin"), "Access-Control-Allow-Origin header value incorrect") require.Equal(t, "1; mode=block", hdrs.Get("X-XSS-Protection"), "X-XSS-Protection header value incorrect") } func TestUIResponseHeaders(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") } t.Parallel() a := NewTestAgent(t, ` http_config { response_headers = { "Access-Control-Allow-Origin" = "*" "X-XSS-Protection" = "1; mode=block" "X-Frame-Options" = "SAMEORIGIN" } } `) defer a.Shutdown() requireHasHeadersSet(t, a, "/ui") } func TestAcceptEncodingGzip(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") } t.Parallel() a := NewTestAgent(t, "") defer a.Shutdown() // Setting up the KV to store a short and a long value buf := bytes.NewBuffer([]byte("short")) req, _ := http.NewRequest("PUT", "/v1/kv/short", buf) resp := httptest.NewRecorder() _, err := a.srv.KVSEndpoint(resp, req) require.NoError(t, err) // this generates a string which is longer than // gziphandler.DefaultMinSize to trigger compression. long := fmt.Sprintf(fmt.Sprintf("%%0%dd", gziphandler.DefaultMinSize+1), 1) buf = bytes.NewBuffer([]byte(long)) req, _ = http.NewRequest("PUT", "/v1/kv/long", buf) resp = httptest.NewRecorder() _, err = a.srv.KVSEndpoint(resp, req) require.NoError(t, err) resp = httptest.NewRecorder() req, _ = http.NewRequest("GET", "/v1/kv/short", nil) // Usually this would be automatically set by transport content // negotiation, but since this call doesn't go through a real // transport, the header has to be set manually req.Header["Accept-Encoding"] = []string{"gzip"} a.enableDebug.Store(true) a.srv.handler().ServeHTTP(resp, req) require.Equal(t, 200, resp.Code) require.Equal(t, "", resp.Header().Get("Content-Encoding")) resp = httptest.NewRecorder() req, _ = http.NewRequest("GET", "/v1/kv/long", nil) req.Header["Accept-Encoding"] = []string{"gzip"} a.enableDebug.Store(true) a.srv.handler().ServeHTTP(resp, req) require.Equal(t, 200, resp.Code) require.Equal(t, "gzip", resp.Header().Get("Content-Encoding")) } func TestContentTypeIsJSON(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") } t.Parallel() a := NewTestAgent(t, "") defer a.Shutdown() resp := httptest.NewRecorder() handler := func(resp http.ResponseWriter, req *http.Request) (interface{}, error) { // stub out a DirEntry so that it will be encoded as JSON return &structs.DirEntry{Key: "key"}, nil } req, _ := http.NewRequest("GET", "/v1/kv/key", nil) a.srv.wrap(handler, []string{"GET"})(resp, req) contentType := resp.Header().Get("Content-Type") if contentType != "application/json" { t.Fatalf("Content-Type header was not 'application/json'") } } func TestHTTP_wrap_obfuscateLog(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") } t.Parallel() buf := &syncBuffer{b: new(bytes.Buffer)} a := StartTestAgent(t, TestAgent{ LogOutput: buf, LogLevel: hclog.Debug, }) defer a.Shutdown() handler := func(resp http.ResponseWriter, req *http.Request) (interface{}, error) { return nil, nil } for _, pair := range [][]string{ { "/some/url?token=secret1&token=secret2", "/some/url?token=&token=", }, { "/v1/acl/clone/secret1", "/v1/acl/clone/", }, { "/v1/acl/clone/secret1?token=secret2", "/v1/acl/clone/?token=", }, { "/v1/acl/destroy/secret1", "/v1/acl/destroy/", }, { "/v1/acl/destroy/secret1?token=secret2", "/v1/acl/destroy/?token=", }, { "/v1/acl/info/secret1", "/v1/acl/info/", }, { "/v1/acl/info/secret1?token=secret2", "/v1/acl/info/?token=", }, } { url, want := pair[0], pair[1] t.Run(url, func(t *testing.T) { resp := httptest.NewRecorder() req, _ := http.NewRequest("GET", url, nil) a.srv.wrap(handler, []string{"GET"})(resp, req) bufout := buf.String() require.Contains(t, bufout, want) }) } } func TestPrettyPrint(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") } t.Parallel() testPrettyPrint("pretty=1", t) } func TestPrettyPrintBare(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") } t.Parallel() testPrettyPrint("pretty", t) } func testPrettyPrint(pretty string, t *testing.T) { a := NewTestAgent(t, "") defer a.Shutdown() r := &structs.DirEntry{Key: "key"} resp := httptest.NewRecorder() handler := func(resp http.ResponseWriter, req *http.Request) (interface{}, error) { return r, nil } urlStr := "/v1/kv/key?" + pretty req, _ := http.NewRequest("GET", urlStr, nil) a.srv.wrap(handler, []string{"GET"})(resp, req) expected, _ := json.MarshalIndent(r, "", " ") expected = append(expected, "\n"...) actual, err := io.ReadAll(resp.Body) if err != nil { t.Fatalf("err: %s", err) } if !bytes.Equal(expected, actual) { t.Fatalf("bad: %q", string(actual)) } } func TestParseSource(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") } t.Parallel() a := NewTestAgent(t, "") defer a.Shutdown() // Default is agent's DC and no node (since the user didn't care, then // just give them the cheapest possible query). req, _ := http.NewRequest("GET", "/v1/catalog/nodes", nil) source := structs.QuerySource{} a.srv.parseSource(req, &source) if source.Datacenter != "dc1" || source.Node != "" { t.Fatalf("bad: %v", source) } // Adding the source parameter should set that node. req, _ = http.NewRequest("GET", "/v1/catalog/nodes?near=bob", nil) source = structs.QuerySource{} a.srv.parseSource(req, &source) if source.Datacenter != "dc1" || source.Node != "bob" { t.Fatalf("bad: %v", source) } // We should follow whatever dc parameter was given so that the node is // looked up correctly on the receiving end. req, _ = http.NewRequest("GET", "/v1/catalog/nodes?near=bob&dc=foo", nil) source = structs.QuerySource{} a.srv.parseSource(req, &source) if source.Datacenter != "foo" || source.Node != "bob" { t.Fatalf("bad: %v", source) } // We should follow whatever datacenter parameter was given so that the node is // looked up correctly on the receiving end. req, _ = http.NewRequest("GET", "/v1/catalog/nodes?near=bob&datacenter=foo", nil) source = structs.QuerySource{} a.srv.parseSource(req, &source) if source.Datacenter != "foo" || source.Node != "bob" { t.Fatalf("bad: %v", source) } // The magic "_agent" node name will use the agent's local node name. req, _ = http.NewRequest("GET", "/v1/catalog/nodes?near=_agent", nil) source = structs.QuerySource{} a.srv.parseSource(req, &source) if source.Datacenter != "dc1" || source.Node != a.Config.NodeName { t.Fatalf("bad: %v", source) } } func TestParseCacheControl(t *testing.T) { tests := []struct { name string headerVal string want structs.QueryOptions wantErr bool }{ { name: "empty header", headerVal: "", want: structs.QueryOptions{}, wantErr: false, }, { name: "simple max-age", headerVal: "max-age=30", want: structs.QueryOptions{ MaxAge: 30 * time.Second, }, wantErr: false, }, { name: "zero max-age", headerVal: "max-age=0", want: structs.QueryOptions{ MustRevalidate: true, }, wantErr: false, }, { name: "must-revalidate", headerVal: "must-revalidate", want: structs.QueryOptions{ MustRevalidate: true, }, wantErr: false, }, { name: "mixes age, must-revalidate", headerVal: "max-age=123, must-revalidate", want: structs.QueryOptions{ MaxAge: 123 * time.Second, MustRevalidate: true, }, wantErr: false, }, { name: "quoted max-age", headerVal: "max-age=\"30\"", want: structs.QueryOptions{}, wantErr: true, }, { name: "mixed case max-age", headerVal: "Max-Age=30", want: structs.QueryOptions{ MaxAge: 30 * time.Second, }, wantErr: false, }, { name: "simple stale-if-error", headerVal: "stale-if-error=300", want: structs.QueryOptions{ StaleIfError: 300 * time.Second, }, wantErr: false, }, { name: "combined with space", headerVal: "max-age=30, stale-if-error=300", want: structs.QueryOptions{ MaxAge: 30 * time.Second, StaleIfError: 300 * time.Second, }, wantErr: false, }, { name: "combined no space", headerVal: "stale-IF-error=300,max-age=30", want: structs.QueryOptions{ MaxAge: 30 * time.Second, StaleIfError: 300 * time.Second, }, wantErr: false, }, { name: "unsupported directive", headerVal: "no-cache", want: structs.QueryOptions{}, wantErr: false, }, { name: "mixed unsupported directive", headerVal: "no-cache, max-age=120", want: structs.QueryOptions{ MaxAge: 120 * time.Second, }, wantErr: false, }, { name: "garbage value", headerVal: "max-age=\"I'm not, an int\"", want: structs.QueryOptions{}, wantErr: true, }, { name: "garbage value with quotes", headerVal: "max-age=\"I'm \\\"not an int\"", want: structs.QueryOptions{}, wantErr: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { r, _ := http.NewRequest("GET", "/foo/bar", nil) if tt.headerVal != "" { r.Header.Set("Cache-Control", tt.headerVal) } rr := httptest.NewRecorder() var got structs.QueryOptions failed := parseCacheControl(rr, r, &got) if tt.wantErr { require.True(t, failed) require.Equal(t, http.StatusBadRequest, rr.Code) } else { require.False(t, failed) } require.Equal(t, tt.want, got) }) } } func TestParseWait(t *testing.T) { t.Parallel() resp := httptest.NewRecorder() var b structs.QueryOptions req, _ := http.NewRequest("GET", "/v1/catalog/nodes?wait=60s&index=1000", nil) if d := parseWait(resp, req, &b); d { t.Fatalf("unexpected done") } if b.MinQueryIndex != 1000 { t.Fatalf("Bad: %v", b) } if b.MaxQueryTime != 60*time.Second { t.Fatalf("Bad: %v", b) } } func TestHTTPServer_PProfHandlers_EnableDebug(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") } t.Parallel() a := NewTestAgent(t, ``) defer a.Shutdown() resp := httptest.NewRecorder() req, _ := http.NewRequest("GET", "/debug/pprof/profile?seconds=1", nil) a.enableDebug.Store(true) httpServer := &HTTPHandlers{agent: a.Agent} httpServer.handler().ServeHTTP(resp, req) require.Equal(t, http.StatusOK, resp.Code) } func TestHTTPServer_PProfHandlers_DisableDebugNoACLs(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") } t.Parallel() a := NewTestAgent(t, ``) defer a.Shutdown() resp := httptest.NewRecorder() req, _ := http.NewRequest("GET", "/debug/pprof/profile", nil) httpServer := &HTTPHandlers{agent: a.Agent} httpServer.handler().ServeHTTP(resp, req) require.Equal(t, http.StatusNotFound, resp.Code) } func TestHTTPServer_PProfHandlers_ACLs(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") } t.Parallel() dc1 := "dc1" a := NewTestAgent(t, ` primary_datacenter = "`+dc1+`" acl { enabled = true default_policy = "deny" tokens { initial_management = "root" agent = "agent" agent_recovery = "towel" } } enable_debug = false `) cases := []struct { code int token string endpoint string nilResponse bool }{ { code: http.StatusOK, token: "root", endpoint: "/debug/pprof/heap", nilResponse: false, }, { code: http.StatusForbidden, token: "agent", endpoint: "/debug/pprof/heap", nilResponse: true, }, { code: http.StatusForbidden, token: "agent", endpoint: "/debug/pprof/", nilResponse: true, }, { code: http.StatusForbidden, token: "", endpoint: "/debug/pprof/", nilResponse: true, }, { code: http.StatusOK, token: "root", endpoint: "/debug/pprof/heap", nilResponse: false, }, { code: http.StatusForbidden, token: "towel", endpoint: "/debug/pprof/heap", nilResponse: true, }, } defer a.Shutdown() testrpc.WaitForLeader(t, a.RPC, "dc1") for i, c := range cases { t.Run(fmt.Sprintf("case %d (%#v)", i, c), func(t *testing.T) { req, _ := http.NewRequest("GET", fmt.Sprintf("%s?token=%s", c.endpoint, c.token), nil) resp := httptest.NewRecorder() a.enableDebug.Store(true) a.srv.handler().ServeHTTP(resp, req) assert.Equal(t, c.code, resp.Code) }) } } func TestParseWait_InvalidTime(t *testing.T) { t.Parallel() resp := httptest.NewRecorder() var b structs.QueryOptions req, _ := http.NewRequest("GET", "/v1/catalog/nodes?wait=60foo&index=1000", nil) if d := parseWait(resp, req, &b); !d { t.Fatalf("expected done") } if resp.Code != 400 { t.Fatalf("bad code: %v", resp.Code) } } func TestParseWait_InvalidIndex(t *testing.T) { t.Parallel() resp := httptest.NewRecorder() var b structs.QueryOptions req, _ := http.NewRequest("GET", "/v1/catalog/nodes?wait=60s&index=foo", nil) if d := parseWait(resp, req, &b); !d { t.Fatalf("expected done") } if resp.Code != 400 { t.Fatalf("bad code: %v", resp.Code) } } func TestParseConsistency(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") } t.Parallel() resp := httptest.NewRecorder() var b structs.QueryOptions req, _ := http.NewRequest("GET", "/v1/catalog/nodes?stale", nil) a := NewTestAgent(t, "") defer a.Shutdown() if d := a.srv.parseConsistency(resp, req, &b); d { t.Fatalf("unexpected done") } if !b.AllowStale { t.Fatalf("Bad: %v", b) } if b.RequireConsistent { t.Fatalf("Bad: %v", b) } b = structs.QueryOptions{} req, _ = http.NewRequest("GET", "/v1/catalog/nodes?consistent", nil) if d := a.srv.parseConsistency(resp, req, &b); d { t.Fatalf("unexpected done") } if b.AllowStale { t.Fatalf("Bad: %v", b) } if !b.RequireConsistent { t.Fatalf("Bad: %v", b) } } // ensureConsistency check if consistency modes are correctly applied // if maxStale < 0 => stale, without MaxStaleDuration // if maxStale == 0 => no stale // if maxStale > 0 => stale + check duration func ensureConsistency(t *testing.T, a *TestAgent, path string, maxStale time.Duration, requireConsistent bool) { t.Helper() req, _ := http.NewRequest("GET", path, nil) var b structs.QueryOptions resp := httptest.NewRecorder() if d := a.srv.parseConsistency(resp, req, &b); d { t.Fatalf("unexpected done") } allowStale := maxStale.Nanoseconds() != 0 if b.AllowStale != allowStale { t.Fatalf("Bad Allow Stale") } if maxStale > 0 && b.MaxStaleDuration != maxStale { t.Fatalf("Bad MaxStaleDuration: %d VS expected %d", b.MaxStaleDuration, maxStale) } if b.RequireConsistent != requireConsistent { t.Fatal("Bad Consistent") } } func TestParseConsistencyAndMaxStale(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") } a := NewTestAgent(t, "") defer a.Shutdown() // Default => Consistent a.config.DiscoveryMaxStale = time.Duration(0) ensureConsistency(t, a, "/v1/catalog/nodes", 0, false) // Stale, without MaxStale ensureConsistency(t, a, "/v1/catalog/nodes?stale", -1, false) // Override explicitly ensureConsistency(t, a, "/v1/catalog/nodes?max_stale=3s", 3*time.Second, false) ensureConsistency(t, a, "/v1/catalog/nodes?stale&max_stale=3s", 3*time.Second, false) // stale by defaul on discovery a.config.DiscoveryMaxStale = 7 * time.Second ensureConsistency(t, a, "/v1/catalog/nodes", a.config.DiscoveryMaxStale, false) // Not in KV ensureConsistency(t, a, "/v1/kv/my/path", 0, false) // DiscoveryConsistencyLevel should apply ensureConsistency(t, a, "/v1/health/service/one", a.config.DiscoveryMaxStale, false) ensureConsistency(t, a, "/v1/catalog/service/one", a.config.DiscoveryMaxStale, false) ensureConsistency(t, a, "/v1/catalog/services", a.config.DiscoveryMaxStale, false) // Query path should be taken into account ensureConsistency(t, a, "/v1/catalog/services?consistent", 0, true) // Since stale is added, no MaxStale should be applied ensureConsistency(t, a, "/v1/catalog/services?stale", -1, false) ensureConsistency(t, a, "/v1/catalog/services?leader", 0, false) } func TestParseConsistency_Invalid(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") } t.Parallel() resp := httptest.NewRecorder() var b structs.QueryOptions req, _ := http.NewRequest("GET", "/v1/catalog/nodes?stale&consistent", nil) a := NewTestAgent(t, "") defer a.Shutdown() if d := a.srv.parseConsistency(resp, req, &b); !d { t.Fatalf("expected done") } if resp.Code != 400 { t.Fatalf("bad code: %v", resp.Code) } } // Test ACL token is resolved in correct order func TestACLResolution(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") } t.Parallel() var token string // Request without token req, _ := http.NewRequest("GET", "/v1/catalog/nodes", nil) // Request with explicit token reqToken, _ := http.NewRequest("GET", "/v1/catalog/nodes?token=foo", nil) // Request with header token only reqHeaderToken, _ := http.NewRequest("GET", "/v1/catalog/nodes", nil) reqHeaderToken.Header.Add("X-Consul-Token", "bar") // Request with header and querystring tokens reqBothTokens, _ := http.NewRequest("GET", "/v1/catalog/nodes?token=baz", nil) reqBothTokens.Header.Add("X-Consul-Token", "zap") // Request with Authorization Bearer token reqAuthBearerToken, _ := http.NewRequest("GET", "/v1/catalog/nodes", nil) reqAuthBearerToken.Header.Add("Authorization", "Bearer bearer-token") // Request with invalid Authorization scheme reqAuthBearerInvalidScheme, _ := http.NewRequest("GET", "/v1/catalog/nodes", nil) reqAuthBearerInvalidScheme.Header.Add("Authorization", "Beer") // Request with empty Authorization Bearer token reqAuthBearerTokenEmpty, _ := http.NewRequest("GET", "/v1/catalog/nodes", nil) reqAuthBearerTokenEmpty.Header.Add("Authorization", "Bearer") // Request with empty Authorization Bearer token reqAuthBearerTokenInvalid, _ := http.NewRequest("GET", "/v1/catalog/nodes", nil) reqAuthBearerTokenInvalid.Header.Add("Authorization", "Bearertoken") // Request with more than one space between Bearer and token reqAuthBearerTokenMultiSpaces, _ := http.NewRequest("GET", "/v1/catalog/nodes", nil) reqAuthBearerTokenMultiSpaces.Header.Add("Authorization", "Bearer bearer-token") // Request with Authorization Bearer token containing spaces reqAuthBearerTokenSpaces, _ := http.NewRequest("GET", "/v1/catalog/nodes", nil) reqAuthBearerTokenSpaces.Header.Add("Authorization", "Bearer bearer-token "+ " the rest is discarded ") // Request with Authorization Bearer and querystring token reqAuthBearerAndQsToken, _ := http.NewRequest("GET", "/v1/catalog/nodes?token=qstoken", nil) reqAuthBearerAndQsToken.Header.Add("Authorization", "Bearer bearer-token") // Request with Authorization Bearer and X-Consul-Token header token reqAuthBearerAndXToken, _ := http.NewRequest("GET", "/v1/catalog/nodes", nil) reqAuthBearerAndXToken.Header.Add("X-Consul-Token", "xtoken") reqAuthBearerAndXToken.Header.Add("Authorization", "Bearer bearer-token") a := NewTestAgent(t, "") defer a.Shutdown() // Check when no token is set a.tokens.UpdateUserToken("", tokenStore.TokenSourceConfig) a.srv.parseToken(req, &token) if token != "" { t.Fatalf("bad: %s", token) } // Check when ACLToken set a.tokens.UpdateUserToken("agent", tokenStore.TokenSourceAPI) a.srv.parseToken(req, &token) if token != "agent" { t.Fatalf("bad: %s", token) } // Explicit token has highest precedence a.srv.parseToken(reqToken, &token) if token != "foo" { t.Fatalf("bad: %s", token) } // Header token has precedence over agent token a.srv.parseToken(reqHeaderToken, &token) if token != "bar" { t.Fatalf("bad: %s", token) } // Querystring token has precedence over header and agent tokens a.srv.parseToken(reqBothTokens, &token) if token != "baz" { t.Fatalf("bad: %s", token) } // // Authorization Bearer token tests // // Check if Authorization bearer token header is parsed correctly a.srv.parseToken(reqAuthBearerToken, &token) if token != "bearer-token" { t.Fatalf("bad: %s", token) } // Check Authorization Bearer scheme invalid a.srv.parseToken(reqAuthBearerInvalidScheme, &token) if token != "agent" { t.Fatalf("bad: %s", token) } // Check if Authorization Bearer token is empty a.srv.parseToken(reqAuthBearerTokenEmpty, &token) if token != "agent" { t.Fatalf("bad: %s", token) } // Check if the Authorization Bearer token is invalid a.srv.parseToken(reqAuthBearerTokenInvalid, &token) if token != "agent" { t.Fatalf("bad: %s", token) } // Check multi spaces between Authorization Bearer and token value a.srv.parseToken(reqAuthBearerTokenMultiSpaces, &token) if token != "bearer-token" { t.Fatalf("bad: %s", token) } // Check if Authorization Bearer token with spaces is parsed correctly a.srv.parseToken(reqAuthBearerTokenSpaces, &token) if token != "bearer-token" { t.Fatalf("bad: %s", token) } // Check if explicit token has precedence over Authorization bearer token a.srv.parseToken(reqAuthBearerAndQsToken, &token) if token != "qstoken" { t.Fatalf("bad: %s", token) } // Check if X-Consul-Token has precedence over Authorization bearer token a.srv.parseToken(reqAuthBearerAndXToken, &token) if token != "xtoken" { t.Fatalf("bad: %s", token) } } func TestEnableWebUI(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") } t.Parallel() a := NewTestAgent(t, ` ui_config { enabled = true } `) defer a.Shutdown() req, _ := http.NewRequest("GET", "/ui/", nil) resp := httptest.NewRecorder() a.enableDebug.Store(true) a.srv.handler().ServeHTTP(resp, req) require.Equal(t, http.StatusOK, resp.Code) // Validate that it actually sent the index page we expect since an error // during serving the special intercepted index.html can result in an empty // response but a 200 status. require.Contains(t, resp.Body.String(), `