diff --git a/agent/agent.go b/agent/agent.go index 9e5666a0f..14d3502df 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -1151,24 +1151,25 @@ func (a *Agent) listenHTTP() ([]*HTTPServer, error) { l = tls.NewListener(l, tlscfg) } + httpServer := &http.Server{ + Addr: l.Addr().String(), + TLSConfig: tlscfg, + } srv := &HTTPServer{ - Server: &http.Server{ - Addr: l.Addr().String(), - TLSConfig: tlscfg, - }, + Server: httpServer, ln: l, agent: a, denylist: NewDenylist(a.config.HTTPBlockEndpoints), proto: proto, } - srv.Server.Handler = srv.handler(a.config.EnableDebug) + httpServer.Handler = srv.handler(a.config.EnableDebug) // Load the connlimit helper into the server connLimitFn := a.httpConnLimiter.HTTPConnStateFunc() if proto == "https" { // Enforce TLS handshake timeout - srv.Server.ConnState = func(conn net.Conn, state http.ConnState) { + httpServer.ConnState = func(conn net.Conn, state http.ConnState) { switch state { case http.StateNew: // Set deadline to prevent slow send before TLS handshake or first @@ -1188,12 +1189,12 @@ func (a *Agent) listenHTTP() ([]*HTTPServer, error) { // This will enable upgrading connections to HTTP/2 as // part of TLS negotiation. - err = http2.ConfigureServer(srv.Server, nil) + err = http2.ConfigureServer(httpServer, nil) if err != nil { return err } } else { - srv.Server.ConnState = connLimitFn + httpServer.ConnState = connLimitFn } ln = append(ln, l) @@ -1263,7 +1264,7 @@ func (a *Agent) serveHTTP(srv *HTTPServer) error { go func() { defer a.wgServers.Done() notif <- srv.ln.Addr() - err := srv.Serve(srv.ln) + err := srv.Server.Serve(srv.ln) if err != nil && err != http.ErrServerClosed { a.logger.Error("error closing server", "error", err) } @@ -2111,7 +2112,7 @@ func (a *Agent) ShutdownEndpoints() { ) ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() - srv.Shutdown(ctx) + srv.Server.Shutdown(ctx) if ctx.Err() == context.DeadlineExceeded { a.logger.Warn("Timeout stopping server", "protocol", strings.ToUpper(srv.proto), diff --git a/agent/agent_endpoint_test.go b/agent/agent_endpoint_test.go index 31f2501f8..c9aef60e3 100644 --- a/agent/agent_endpoint_test.go +++ b/agent/agent_endpoint_test.go @@ -4465,7 +4465,8 @@ func TestAgent_Monitor(t *testing.T) { req = req.WithContext(cancelCtx) resp := httptest.NewRecorder() - go a.srv.Handler.ServeHTTP(resp, req) + handler := a.srv.handler(true) + go handler.ServeHTTP(resp, req) args := &structs.ServiceDefinition{ Name: "monitor", diff --git a/agent/http.go b/agent/http.go index 7402ef5f9..4cc754dfe 100644 --- a/agent/http.go +++ b/agent/http.go @@ -81,7 +81,8 @@ func (e ForbiddenError) Error() string { // HTTPServer provides an HTTP api for an agent. type HTTPServer struct { - *http.Server + // TODO(dnephin): remove Server field, it is not used by any of the HTTPServer methods + Server *http.Server ln net.Listener agent *Agent denylist *Denylist diff --git a/agent/http_oss_test.go b/agent/http_oss_test.go index 62dafc61a..8e936d938 100644 --- a/agent/http_oss_test.go +++ b/agent/http_oss_test.go @@ -133,7 +133,7 @@ func TestHTTPAPI_OptionMethod_OSS(t *testing.T) { uri := fmt.Sprintf("http://%s%s", a.HTTPAddr(), path) req, _ := http.NewRequest("OPTIONS", uri, nil) resp := httptest.NewRecorder() - a.srv.Handler.ServeHTTP(resp, req) + a.srv.handler(true).ServeHTTP(resp, req) allMethods := append([]string{"OPTIONS"}, methods...) if resp.Code != http.StatusOK { @@ -175,7 +175,7 @@ func TestHTTPAPI_AllowedNets_OSS(t *testing.T) { req, _ := http.NewRequest(method, uri, nil) req.RemoteAddr = "192.168.1.2:5555" resp := httptest.NewRecorder() - a.srv.Handler.ServeHTTP(resp, req) + a.srv.handler(true).ServeHTTP(resp, req) require.Equal(t, http.StatusForbidden, resp.Code, "%s %s", method, path) }) diff --git a/agent/http_test.go b/agent/http_test.go index e64edf0ef..b90715e29 100644 --- a/agent/http_test.go +++ b/agent/http_test.go @@ -129,7 +129,7 @@ func TestHTTPServer_UnixSocket_FileExists(t *testing.T) { } } -func TestHTTPServer_H2(t *testing.T) { +func TestHTTPServer_HTTP2(t *testing.T) { t.Parallel() // Fire up an agent with TLS enabled. @@ -161,16 +161,15 @@ func TestHTTPServer_H2(t *testing.T) { if err := http2.ConfigureTransport(transport); err != nil { t.Fatalf("err: %v", err) } - hc := &http.Client{ - Transport: transport, - } + httpClient := &http.Client{Transport: transport} // Hook a handler that echoes back the protocol. handler := func(resp http.ResponseWriter, req *http.Request) { resp.WriteHeader(http.StatusOK) fmt.Fprint(resp, req.Proto) } - w, ok := a.srv.Handler.(*wrappedMux) + + w, ok := a.srv.Server.Handler.(*wrappedMux) if !ok { t.Fatalf("handler is not expected type") } @@ -178,7 +177,7 @@ func TestHTTPServer_H2(t *testing.T) { // Call it and make sure we see HTTP/2. url := fmt.Sprintf("https://%s/echo", a.srv.ln.Addr().String()) - resp, err := hc.Get(url) + resp, err := httpClient.Get(url) if err != nil { t.Fatalf("err: %v", err) } @@ -197,7 +196,7 @@ func TestHTTPServer_H2(t *testing.T) { cfg := &api.Config{ Address: a.srv.ln.Addr().String(), Scheme: "https", - HttpClient: hc, + HttpClient: httpClient, } client, err := api.NewClient(cfg) if err != nil { @@ -333,7 +332,7 @@ func TestHTTPAPI_Ban_Nonprintable_Characters(t *testing.T) { t.Fatal(err) } resp := httptest.NewRecorder() - a.srv.Handler.ServeHTTP(resp, req) + a.srv.handler(true).ServeHTTP(resp, req) if got, want := resp.Code, http.StatusBadRequest; got != want { t.Fatalf("bad response code got %d want %d", got, want) } @@ -352,7 +351,7 @@ func TestHTTPAPI_Allow_Nonprintable_Characters_With_Flag(t *testing.T) { t.Fatal(err) } resp := httptest.NewRecorder() - a.srv.Handler.ServeHTTP(resp, req) + a.srv.handler(true).ServeHTTP(resp, req) // Key doesn't actually exist so we should get 404 if got, want := resp.Code, http.StatusNotFound; got != want { t.Fatalf("bad response code got %d want %d", got, want) @@ -490,14 +489,14 @@ func TestAcceptEncodingGzip(t *testing.T) { // negotiation, but since this call doesn't go through a real // transport, the header has to be set manually req.Header["Accept-Encoding"] = []string{"gzip"} - a.srv.Handler.ServeHTTP(resp, req) + a.srv.handler(true).ServeHTTP(resp, req) require.Equal(t, 200, resp.Code) require.Equal(t, "", resp.Header().Get("Content-Encoding")) resp = httptest.NewRecorder() req, _ = http.NewRequest("GET", "/v1/kv/long", nil) req.Header["Accept-Encoding"] = []string{"gzip"} - a.srv.Handler.ServeHTTP(resp, req) + a.srv.handler(true).ServeHTTP(resp, req) require.Equal(t, 200, resp.Code) require.Equal(t, "gzip", resp.Header().Get("Content-Encoding")) } @@ -811,35 +810,35 @@ func TestParseWait(t *testing.T) { } } -func TestPProfHandlers_EnableDebug(t *testing.T) { +func TestHTTPServer_PProfHandlers_EnableDebug(t *testing.T) { t.Parallel() - require := require.New(t) - a := NewTestAgent(t, "enable_debug = true") + a := NewTestAgent(t, ``) defer a.Shutdown() resp := httptest.NewRecorder() req, _ := http.NewRequest("GET", "/debug/pprof/profile?seconds=1", nil) - a.srv.Handler.ServeHTTP(resp, req) + httpServer := &HTTPServer{agent: a.Agent} + httpServer.handler(true).ServeHTTP(resp, req) - require.Equal(http.StatusOK, resp.Code) + require.Equal(t, http.StatusOK, resp.Code) } -func TestPProfHandlers_DisableDebugNoACLs(t *testing.T) { +func TestHTTPServer_PProfHandlers_DisableDebugNoACLs(t *testing.T) { t.Parallel() - require := require.New(t) - a := NewTestAgent(t, "enable_debug = false") + a := NewTestAgent(t, ``) defer a.Shutdown() resp := httptest.NewRecorder() req, _ := http.NewRequest("GET", "/debug/pprof/profile", nil) - a.srv.Handler.ServeHTTP(resp, req) + httpServer := &HTTPServer{agent: a.Agent} + httpServer.handler(false).ServeHTTP(resp, req) - require.Equal(http.StatusUnauthorized, resp.Code) + require.Equal(t, http.StatusUnauthorized, resp.Code) } -func TestPProfHandlers_ACLs(t *testing.T) { +func TestHTTPServer_PProfHandlers_ACLs(t *testing.T) { t.Parallel() assert := assert.New(t) dc1 := "dc1" @@ -904,7 +903,7 @@ func TestPProfHandlers_ACLs(t *testing.T) { t.Run(fmt.Sprintf("case %d (%#v)", i, c), func(t *testing.T) { req, _ := http.NewRequest("GET", fmt.Sprintf("%s?token=%s", c.endpoint, c.token), nil) resp := httptest.NewRecorder() - a.srv.Handler.ServeHTTP(resp, req) + a.srv.handler(true).ServeHTTP(resp, req) assert.Equal(c.code, resp.Code) }) } @@ -1192,7 +1191,7 @@ func TestEnableWebUI(t *testing.T) { req, _ := http.NewRequest("GET", "/ui/", nil) resp := httptest.NewRecorder() - a.srv.Handler.ServeHTTP(resp, req) + a.srv.handler(true).ServeHTTP(resp, req) if resp.Code != 200 { t.Fatalf("should handle ui") } diff --git a/agent/testagent.go b/agent/testagent.go index 635ae6e58..3ddcc2de6 100644 --- a/agent/testagent.go +++ b/agent/testagent.go @@ -384,7 +384,7 @@ func (a *TestAgent) HTTPAddr() string { if a.srv == nil { return "" } - return a.srv.Addr + return a.srv.Server.Addr } func (a *TestAgent) SegmentAddr(name string) string { diff --git a/agent/ui_endpoint_test.go b/agent/ui_endpoint_test.go index b139ba724..4694e392b 100644 --- a/agent/ui_endpoint_test.go +++ b/agent/ui_endpoint_test.go @@ -43,7 +43,7 @@ func TestUiIndex(t *testing.T) { // Register node req, _ := http.NewRequest("GET", "/ui/my-file", nil) req.URL.Scheme = "http" - req.URL.Host = a.srv.Addr + req.URL.Host = a.srv.Server.Addr // Make the request client := cleanhttp.DefaultClient()