Merge pull request #8234 from hashicorp/dnephin/shutdown-on-http-server-error

agent: shutdown if the http server goroutine exits
This commit is contained in:
Daniel Nephin 2020-09-03 16:44:21 -04:00 committed by GitHub
commit 3c77f4c7d2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 234 additions and 106 deletions

View File

@ -259,10 +259,12 @@ type Agent struct {
// dnsServer provides the DNS API
dnsServers []*DNSServer
// httpServers provides the HTTP API on various endpoints
httpServers []*HTTPServer
// apiServers listening for connections. If any of these server goroutines
// fail, the agent will be shutdown.
apiServers *apiServers
// wgServers is the wait group for all HTTP and DNS servers
// TODO: remove once dnsServers are handled by apiServers
wgServers sync.WaitGroup
// watchPlans tracks all the currently-running watch plans for the
@ -375,6 +377,9 @@ func New(bd BaseDeps) (*Agent, error) {
a.loadTokens(a.config)
a.loadEnterpriseTokens(a.config)
// TODO: pass in a fully populated apiServers into Agent.New
a.apiServers = NewAPIServers(a.logger)
return &a, nil
}
@ -580,10 +585,7 @@ func (a *Agent) Start(ctx context.Context) error {
// Start HTTP and HTTPS servers.
for _, srv := range servers {
if err := a.serveHTTP(srv); err != nil {
return err
}
a.httpServers = append(a.httpServers, srv)
a.apiServers.Start(srv)
}
// Start gRPC server.
@ -605,6 +607,12 @@ func (a *Agent) Start(ctx context.Context) error {
return nil
}
// Failed returns a channel which is closed when the first server goroutine exits
// with a non-nil error.
func (a *Agent) Failed() <-chan struct{} {
return a.apiServers.failed
}
func (a *Agent) listenAndServeGRPC() error {
if len(a.config.GRPCAddrs) < 1 {
return nil
@ -737,14 +745,16 @@ func (a *Agent) startListeners(addrs []net.Addr) ([]net.Listener, error) {
//
// This approach should ultimately be refactored to the point where we just
// start the server and any error should trigger a proper shutdown of the agent.
func (a *Agent) listenHTTP() ([]*HTTPServer, error) {
func (a *Agent) listenHTTP() ([]apiServer, error) {
var ln []net.Listener
var servers []*HTTPServer
var servers []apiServer
start := func(proto string, addrs []net.Addr) error {
listeners, err := a.startListeners(addrs)
if err != nil {
return err
}
ln = append(ln, listeners...)
for _, l := range listeners {
var tlscfg *tls.Config
@ -754,18 +764,15 @@ func (a *Agent) listenHTTP() ([]*HTTPServer, error) {
l = tls.NewListener(l, tlscfg)
}
srv := &HTTPServer{
agent: a,
denylist: NewDenylist(a.config.HTTPBlockEndpoints),
}
httpServer := &http.Server{
Addr: l.Addr().String(),
TLSConfig: tlscfg,
Handler: srv.handler(a.config.EnableDebug),
}
srv := &HTTPServer{
Server: httpServer,
ln: l,
agent: a,
denylist: NewDenylist(a.config.HTTPBlockEndpoints),
proto: proto,
}
httpServer.Handler = srv.handler(a.config.EnableDebug)
// Load the connlimit helper into the server
connLimitFn := a.httpConnLimiter.HTTPConnStateFuncWithDefault429Handler(10 * time.Millisecond)
@ -778,27 +785,39 @@ func (a *Agent) listenHTTP() ([]*HTTPServer, error) {
httpServer.ConnState = connLimitFn
}
ln = append(ln, l)
servers = append(servers, srv)
servers = append(servers, apiServer{
Protocol: proto,
Addr: l.Addr(),
Shutdown: httpServer.Shutdown,
Run: func() error {
err := httpServer.Serve(l)
if err == nil || err == http.ErrServerClosed {
return nil
}
return fmt.Errorf("%s server %s failed: %w", proto, l.Addr(), err)
},
})
}
return nil
}
if err := start("http", a.config.HTTPAddrs); err != nil {
for _, l := range ln {
l.Close()
}
closeListeners(ln)
return nil, err
}
if err := start("https", a.config.HTTPSAddrs); err != nil {
for _, l := range ln {
l.Close()
}
closeListeners(ln)
return nil, err
}
return servers, nil
}
func closeListeners(lns []net.Listener) {
for _, l := range lns {
l.Close()
}
}
// setupHTTPS adds HTTP/2 support, ConnState, and a connection handshake timeout
// to the http.Server.
func setupHTTPS(server *http.Server, connState func(net.Conn, http.ConnState), timeout time.Duration) error {
@ -860,43 +879,6 @@ func (a *Agent) listenSocket(path string) (net.Listener, error) {
return l, nil
}
func (a *Agent) serveHTTP(srv *HTTPServer) error {
// https://github.com/golang/go/issues/20239
//
// In go.8.1 there is a race between Serve and Shutdown. If
// Shutdown is called before the Serve go routine was scheduled then
// the Serve go routine never returns. This deadlocks the agent
// shutdown for some tests since it will wait forever.
notif := make(chan net.Addr)
a.wgServers.Add(1)
go func() {
defer a.wgServers.Done()
notif <- srv.ln.Addr()
err := srv.Server.Serve(srv.ln)
if err != nil && err != http.ErrServerClosed {
a.logger.Error("error closing server", "error", err)
}
}()
select {
case addr := <-notif:
if srv.proto == "https" {
a.logger.Info("Started HTTPS server",
"address", addr.String(),
"network", addr.Network(),
)
} else {
a.logger.Info("Started HTTP server",
"address", addr.String(),
"network", addr.Network(),
)
}
return nil
case <-time.After(time.Second):
return fmt.Errorf("agent: timeout starting HTTP servers")
}
}
// stopAllWatches stops all the currently running watches
func (a *Agent) stopAllWatches() {
for _, wp := range a.watchPlans {
@ -1395,13 +1377,12 @@ func (a *Agent) ShutdownAgent() error {
// ShutdownEndpoints terminates the HTTP and DNS servers. Should be
// preceded by ShutdownAgent.
// TODO: remove this method, move to ShutdownAgent
func (a *Agent) ShutdownEndpoints() {
a.shutdownLock.Lock()
defer a.shutdownLock.Unlock()
if len(a.dnsServers) == 0 && len(a.httpServers) == 0 {
return
}
ctx := context.TODO()
for _, srv := range a.dnsServers {
if srv.Server != nil {
@ -1415,27 +1396,11 @@ func (a *Agent) ShutdownEndpoints() {
}
a.dnsServers = nil
for _, srv := range a.httpServers {
a.logger.Info("Stopping server",
"protocol", strings.ToUpper(srv.proto),
"address", srv.ln.Addr().String(),
"network", srv.ln.Addr().Network(),
)
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
srv.Server.Shutdown(ctx)
if ctx.Err() == context.DeadlineExceeded {
a.logger.Warn("Timeout stopping server",
"protocol", strings.ToUpper(srv.proto),
"address", srv.ln.Addr().String(),
"network", srv.ln.Addr().Network(),
)
}
}
a.httpServers = nil
a.apiServers.Shutdown(ctx)
a.logger.Info("Waiting for endpoints to shut down")
a.wgServers.Wait()
if err := a.apiServers.WaitForShutdown(); err != nil {
a.logger.Error(err.Error())
}
a.logger.Info("Endpoints down")
}

View File

@ -1917,7 +1917,7 @@ func TestAgent_HTTPCheck_EnableAgentTLSForChecks(t *testing.T) {
Status: api.HealthCritical,
}
url := fmt.Sprintf("https://%s/v1/agent/self", a.srv.ln.Addr().String())
url := fmt.Sprintf("https://%s/v1/agent/self", a.HTTPAddr())
chk := &structs.CheckType{
HTTP: url,
Interval: 20 * time.Millisecond,

94
agent/apiserver.go Normal file
View File

@ -0,0 +1,94 @@
package agent
import (
"context"
"net"
"sync"
"time"
"github.com/hashicorp/go-hclog"
"golang.org/x/sync/errgroup"
)
// apiServers is a wrapper around errgroup.Group for managing go routines for
// long running agent components (ex: http server, dns server). If any of the
// servers fail, the failed channel will be closed, which will cause the agent
// to be shutdown instead of running in a degraded state.
//
// This struct exists as a shim for using errgroup.Group without making major
// changes to Agent. In the future it may be removed and replaced with more
// direct usage of errgroup.Group.
type apiServers struct {
logger hclog.Logger
group *errgroup.Group
servers []apiServer
// failed channel is closed when the first server goroutines exit with a
// non-nil error.
failed <-chan struct{}
}
type apiServer struct {
// Protocol supported by this server. One of: dns, http, https
Protocol string
// Addr the server is listening on
Addr net.Addr
// Run will be called in a goroutine to run the server. When any Run exits
// with a non-nil error, the failed channel will be closed.
Run func() error
// Shutdown function used to stop the server
Shutdown func(context.Context) error
}
// NewAPIServers returns an empty apiServers that is ready to Start servers.
func NewAPIServers(logger hclog.Logger) *apiServers {
group, ctx := errgroup.WithContext(context.TODO())
return &apiServers{
logger: logger,
group: group,
failed: ctx.Done(),
}
}
func (s *apiServers) Start(srv apiServer) {
srv.logger(s.logger).Info("Starting server")
s.servers = append(s.servers, srv)
s.group.Go(srv.Run)
}
func (s apiServer) logger(base hclog.Logger) hclog.Logger {
return base.With(
"protocol", s.Protocol,
"address", s.Addr.String(),
"network", s.Addr.Network())
}
// Shutdown all the servers and log any errors as warning. Each server is given
// 1 second, or until ctx is cancelled, to shutdown gracefully.
func (s *apiServers) Shutdown(ctx context.Context) {
shutdownGroup := new(sync.WaitGroup)
for i := range s.servers {
server := s.servers[i]
shutdownGroup.Add(1)
go func() {
defer shutdownGroup.Done()
logger := server.logger(s.logger)
logger.Info("Stopping server")
ctx, cancel := context.WithTimeout(ctx, time.Second)
defer cancel()
if err := server.Shutdown(ctx); err != nil {
logger.Warn("Failed to stop server")
}
}()
}
s.servers = nil
shutdownGroup.Wait()
}
// WaitForShutdown waits until all server goroutines have exited. Shutdown
// must be called before WaitForShutdown, otherwise it will block forever.
func (s *apiServers) WaitForShutdown() error {
return s.group.Wait()
}

65
agent/apiserver_test.go Normal file
View File

@ -0,0 +1,65 @@
package agent
import (
"context"
"fmt"
"net"
"testing"
"time"
"github.com/hashicorp/go-hclog"
"github.com/stretchr/testify/require"
)
func TestAPIServers_WithServiceRunError(t *testing.T) {
servers := NewAPIServers(hclog.New(nil))
server1, chErr1 := newAPIServerStub()
server2, _ := newAPIServerStub()
t.Run("Start", func(t *testing.T) {
servers.Start(server1)
servers.Start(server2)
select {
case <-servers.failed:
t.Fatalf("expected servers to still be running")
case <-time.After(5 * time.Millisecond):
}
})
err := fmt.Errorf("oops, I broke")
t.Run("server exit non-nil error", func(t *testing.T) {
chErr1 <- err
select {
case <-servers.failed:
case <-time.After(time.Second):
t.Fatalf("expected failed channel to be closed")
}
})
t.Run("shutdown remaining services", func(t *testing.T) {
servers.Shutdown(context.Background())
require.Equal(t, err, servers.WaitForShutdown())
})
}
func newAPIServerStub() (apiServer, chan error) {
chErr := make(chan error)
return apiServer{
Protocol: "http",
Addr: &net.TCPAddr{
IP: net.ParseIP("127.0.0.11"),
Port: 5505,
},
Run: func() error {
return <-chErr
},
Shutdown: func(ctx context.Context) error {
close(chErr)
return nil
},
}, chErr
}

View File

@ -80,16 +80,14 @@ func (e ForbiddenError) Error() string {
}
// HTTPServer provides an HTTP api for an agent.
//
// TODO: rename this struct to something more appropriate. It is an http.Handler,
// request router or multiplexer, but it is not a Server.
type HTTPServer struct {
// TODO(dnephin): remove Server field, it is not used by any of the HTTPServer methods
Server *http.Server
ln net.Listener
agent *Agent
denylist *Denylist
// proto is filled by the agent to "http" or "https".
proto string
}
type templatedFile struct {
templated *bytes.Reader
name string

View File

@ -1353,7 +1353,7 @@ func TestHTTPServer_HandshakeTimeout(t *testing.T) {
// Connect to it with a plain TCP client that doesn't attempt to send HTTP or
// complete a TLS handshake.
conn, err := net.Dial("tcp", a.srv.ln.Addr().String())
conn, err := net.Dial("tcp", a.HTTPAddr())
require.NoError(t, err)
defer conn.Close()
@ -1413,7 +1413,7 @@ func TestRPC_HTTPSMaxConnsPerClient(t *testing.T) {
})
defer a.Shutdown()
addr := a.srv.ln.Addr()
addr := a.HTTPAddr()
assertConn := func(conn net.Conn, wantOpen bool) {
retry.Run(t, func(r *retry.R) {
@ -1433,21 +1433,21 @@ func TestRPC_HTTPSMaxConnsPerClient(t *testing.T) {
}
// Connect to the server with bare TCP
conn1, err := net.DialTimeout("tcp", addr.String(), time.Second)
conn1, err := net.DialTimeout("tcp", addr, time.Second)
require.NoError(t, err)
defer conn1.Close()
assertConn(conn1, true)
// Two conns should succeed
conn2, err := net.DialTimeout("tcp", addr.String(), time.Second)
conn2, err := net.DialTimeout("tcp", addr, time.Second)
require.NoError(t, err)
defer conn2.Close()
assertConn(conn2, true)
// Third should succeed negotiating TCP handshake...
conn3, err := net.DialTimeout("tcp", addr.String(), time.Second)
conn3, err := net.DialTimeout("tcp", addr, time.Second)
require.NoError(t, err)
defer conn3.Close()
@ -1460,7 +1460,7 @@ func TestRPC_HTTPSMaxConnsPerClient(t *testing.T) {
require.NoError(t, a.reloadConfigInternal(&newCfg))
// Now another conn should be allowed
conn4, err := net.DialTimeout("tcp", addr.String(), time.Second)
conn4, err := net.DialTimeout("tcp", addr, time.Second)
require.NoError(t, err)
defer conn4.Close()

View File

@ -73,8 +73,7 @@ type TestAgent struct {
// It is valid after Start().
dns *DNSServer
// srv is a reference to the first started HTTP endpoint.
// It is valid after Start().
// srv is an HTTPServer that may be used to test http endpoints.
srv *HTTPServer
// overrides is an hcl config source to use to override otherwise
@ -213,6 +212,8 @@ func (a *TestAgent) Start(t *testing.T) (err error) {
// Start the anti-entropy syncer
a.Agent.StartSync()
a.srv = &HTTPServer{agent: agent, denylist: NewDenylist(a.config.HTTPBlockEndpoints)}
if err := a.waitForUp(); err != nil {
a.Shutdown()
t.Logf("Error while waiting for test agent to start: %v", err)
@ -220,7 +221,6 @@ func (a *TestAgent) Start(t *testing.T) (err error) {
}
a.dns = a.dnsServers[0]
a.srv = a.httpServers[0]
return nil
}
@ -233,7 +233,7 @@ func (a *TestAgent) waitForUp() error {
var retErr error
var out structs.IndexedNodes
for ; !time.Now().After(deadline); time.Sleep(timer.Wait) {
if len(a.httpServers) == 0 {
if len(a.apiServers.servers) == 0 {
retErr = fmt.Errorf("waiting for server")
continue // fail, try again
}
@ -262,7 +262,7 @@ func (a *TestAgent) waitForUp() error {
} else {
req := httptest.NewRequest("GET", "/v1/agent/self", nil)
resp := httptest.NewRecorder()
_, err := a.httpServers[0].AgentSelf(resp, req)
_, err := a.srv.AgentSelf(resp, req)
if acl.IsErrPermissionDenied(err) || resp.Code == 403 {
// permission denied is enough to show that the client is
// connected to the servers as it would get a 503 if
@ -313,10 +313,13 @@ func (a *TestAgent) DNSAddr() string {
}
func (a *TestAgent) HTTPAddr() string {
if a.srv == nil {
return ""
var srv apiServer
for _, srv = range a.Agent.apiServers.servers {
if srv.Protocol == "http" {
break
}
}
return a.srv.Server.Addr
return srv.Addr.String()
}
func (a *TestAgent) SegmentAddr(name string) string {

View File

@ -41,7 +41,7 @@ func TestUiIndex(t *testing.T) {
// Register node
req, _ := http.NewRequest("GET", "/ui/my-file", nil)
req.URL.Scheme = "http"
req.URL.Host = a.srv.Server.Addr
req.URL.Host = a.HTTPAddr()
// Make the request
client := cleanhttp.DefaultClient()

View File

@ -288,6 +288,9 @@ func (c *cmd) run(args []string) int {
case err := <-agent.RetryJoinCh():
c.logger.Error("Retry join failed", "error", err)
return 1
case <-agent.Failed():
// The deferred Shutdown method will log the appropriate error
return 1
case <-agent.ShutdownCh():
// agent is already down!
return 0