Merge pull request #8234 from hashicorp/dnephin/shutdown-on-http-server-error
agent: shutdown if the http server goroutine exits
This commit is contained in:
commit
3c77f4c7d2
133
agent/agent.go
133
agent/agent.go
|
@ -259,10 +259,12 @@ type Agent struct {
|
|||
// dnsServer provides the DNS API
|
||||
dnsServers []*DNSServer
|
||||
|
||||
// httpServers provides the HTTP API on various endpoints
|
||||
httpServers []*HTTPServer
|
||||
// apiServers listening for connections. If any of these server goroutines
|
||||
// fail, the agent will be shutdown.
|
||||
apiServers *apiServers
|
||||
|
||||
// wgServers is the wait group for all HTTP and DNS servers
|
||||
// TODO: remove once dnsServers are handled by apiServers
|
||||
wgServers sync.WaitGroup
|
||||
|
||||
// watchPlans tracks all the currently-running watch plans for the
|
||||
|
@ -375,6 +377,9 @@ func New(bd BaseDeps) (*Agent, error) {
|
|||
a.loadTokens(a.config)
|
||||
a.loadEnterpriseTokens(a.config)
|
||||
|
||||
// TODO: pass in a fully populated apiServers into Agent.New
|
||||
a.apiServers = NewAPIServers(a.logger)
|
||||
|
||||
return &a, nil
|
||||
}
|
||||
|
||||
|
@ -580,10 +585,7 @@ func (a *Agent) Start(ctx context.Context) error {
|
|||
|
||||
// Start HTTP and HTTPS servers.
|
||||
for _, srv := range servers {
|
||||
if err := a.serveHTTP(srv); err != nil {
|
||||
return err
|
||||
}
|
||||
a.httpServers = append(a.httpServers, srv)
|
||||
a.apiServers.Start(srv)
|
||||
}
|
||||
|
||||
// Start gRPC server.
|
||||
|
@ -605,6 +607,12 @@ func (a *Agent) Start(ctx context.Context) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// Failed returns a channel which is closed when the first server goroutine exits
|
||||
// with a non-nil error.
|
||||
func (a *Agent) Failed() <-chan struct{} {
|
||||
return a.apiServers.failed
|
||||
}
|
||||
|
||||
func (a *Agent) listenAndServeGRPC() error {
|
||||
if len(a.config.GRPCAddrs) < 1 {
|
||||
return nil
|
||||
|
@ -737,14 +745,16 @@ func (a *Agent) startListeners(addrs []net.Addr) ([]net.Listener, error) {
|
|||
//
|
||||
// This approach should ultimately be refactored to the point where we just
|
||||
// start the server and any error should trigger a proper shutdown of the agent.
|
||||
func (a *Agent) listenHTTP() ([]*HTTPServer, error) {
|
||||
func (a *Agent) listenHTTP() ([]apiServer, error) {
|
||||
var ln []net.Listener
|
||||
var servers []*HTTPServer
|
||||
var servers []apiServer
|
||||
|
||||
start := func(proto string, addrs []net.Addr) error {
|
||||
listeners, err := a.startListeners(addrs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ln = append(ln, listeners...)
|
||||
|
||||
for _, l := range listeners {
|
||||
var tlscfg *tls.Config
|
||||
|
@ -754,18 +764,15 @@ func (a *Agent) listenHTTP() ([]*HTTPServer, error) {
|
|||
l = tls.NewListener(l, tlscfg)
|
||||
}
|
||||
|
||||
srv := &HTTPServer{
|
||||
agent: a,
|
||||
denylist: NewDenylist(a.config.HTTPBlockEndpoints),
|
||||
}
|
||||
httpServer := &http.Server{
|
||||
Addr: l.Addr().String(),
|
||||
TLSConfig: tlscfg,
|
||||
Handler: srv.handler(a.config.EnableDebug),
|
||||
}
|
||||
srv := &HTTPServer{
|
||||
Server: httpServer,
|
||||
ln: l,
|
||||
agent: a,
|
||||
denylist: NewDenylist(a.config.HTTPBlockEndpoints),
|
||||
proto: proto,
|
||||
}
|
||||
httpServer.Handler = srv.handler(a.config.EnableDebug)
|
||||
|
||||
// Load the connlimit helper into the server
|
||||
connLimitFn := a.httpConnLimiter.HTTPConnStateFuncWithDefault429Handler(10 * time.Millisecond)
|
||||
|
@ -778,27 +785,39 @@ func (a *Agent) listenHTTP() ([]*HTTPServer, error) {
|
|||
httpServer.ConnState = connLimitFn
|
||||
}
|
||||
|
||||
ln = append(ln, l)
|
||||
servers = append(servers, srv)
|
||||
servers = append(servers, apiServer{
|
||||
Protocol: proto,
|
||||
Addr: l.Addr(),
|
||||
Shutdown: httpServer.Shutdown,
|
||||
Run: func() error {
|
||||
err := httpServer.Serve(l)
|
||||
if err == nil || err == http.ErrServerClosed {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("%s server %s failed: %w", proto, l.Addr(), err)
|
||||
},
|
||||
})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := start("http", a.config.HTTPAddrs); err != nil {
|
||||
for _, l := range ln {
|
||||
l.Close()
|
||||
}
|
||||
closeListeners(ln)
|
||||
return nil, err
|
||||
}
|
||||
if err := start("https", a.config.HTTPSAddrs); err != nil {
|
||||
for _, l := range ln {
|
||||
l.Close()
|
||||
}
|
||||
closeListeners(ln)
|
||||
return nil, err
|
||||
}
|
||||
return servers, nil
|
||||
}
|
||||
|
||||
// closeListeners closes every listener in lns, in slice order. Any error
// returned by Close is discarded.
func closeListeners(lns []net.Listener) {
	for i := range lns {
		lns[i].Close()
	}
}
|
||||
|
||||
// setupHTTPS adds HTTP/2 support, ConnState, and a connection handshake timeout
|
||||
// to the http.Server.
|
||||
func setupHTTPS(server *http.Server, connState func(net.Conn, http.ConnState), timeout time.Duration) error {
|
||||
|
@ -860,43 +879,6 @@ func (a *Agent) listenSocket(path string) (net.Listener, error) {
|
|||
return l, nil
|
||||
}
|
||||
|
||||
func (a *Agent) serveHTTP(srv *HTTPServer) error {
|
||||
// https://github.com/golang/go/issues/20239
|
||||
//
|
||||
// In go.8.1 there is a race between Serve and Shutdown. If
|
||||
// Shutdown is called before the Serve go routine was scheduled then
|
||||
// the Serve go routine never returns. This deadlocks the agent
|
||||
// shutdown for some tests since it will wait forever.
|
||||
notif := make(chan net.Addr)
|
||||
a.wgServers.Add(1)
|
||||
go func() {
|
||||
defer a.wgServers.Done()
|
||||
notif <- srv.ln.Addr()
|
||||
err := srv.Server.Serve(srv.ln)
|
||||
if err != nil && err != http.ErrServerClosed {
|
||||
a.logger.Error("error closing server", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case addr := <-notif:
|
||||
if srv.proto == "https" {
|
||||
a.logger.Info("Started HTTPS server",
|
||||
"address", addr.String(),
|
||||
"network", addr.Network(),
|
||||
)
|
||||
} else {
|
||||
a.logger.Info("Started HTTP server",
|
||||
"address", addr.String(),
|
||||
"network", addr.Network(),
|
||||
)
|
||||
}
|
||||
return nil
|
||||
case <-time.After(time.Second):
|
||||
return fmt.Errorf("agent: timeout starting HTTP servers")
|
||||
}
|
||||
}
|
||||
|
||||
// stopAllWatches stops all the currently running watches
|
||||
func (a *Agent) stopAllWatches() {
|
||||
for _, wp := range a.watchPlans {
|
||||
|
@ -1395,13 +1377,12 @@ func (a *Agent) ShutdownAgent() error {
|
|||
|
||||
// ShutdownEndpoints terminates the HTTP and DNS servers. Should be
|
||||
// preceded by ShutdownAgent.
|
||||
// TODO: remove this method, move to ShutdownAgent
|
||||
func (a *Agent) ShutdownEndpoints() {
|
||||
a.shutdownLock.Lock()
|
||||
defer a.shutdownLock.Unlock()
|
||||
|
||||
if len(a.dnsServers) == 0 && len(a.httpServers) == 0 {
|
||||
return
|
||||
}
|
||||
ctx := context.TODO()
|
||||
|
||||
for _, srv := range a.dnsServers {
|
||||
if srv.Server != nil {
|
||||
|
@ -1415,27 +1396,11 @@ func (a *Agent) ShutdownEndpoints() {
|
|||
}
|
||||
a.dnsServers = nil
|
||||
|
||||
for _, srv := range a.httpServers {
|
||||
a.logger.Info("Stopping server",
|
||||
"protocol", strings.ToUpper(srv.proto),
|
||||
"address", srv.ln.Addr().String(),
|
||||
"network", srv.ln.Addr().Network(),
|
||||
)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
srv.Server.Shutdown(ctx)
|
||||
if ctx.Err() == context.DeadlineExceeded {
|
||||
a.logger.Warn("Timeout stopping server",
|
||||
"protocol", strings.ToUpper(srv.proto),
|
||||
"address", srv.ln.Addr().String(),
|
||||
"network", srv.ln.Addr().Network(),
|
||||
)
|
||||
}
|
||||
}
|
||||
a.httpServers = nil
|
||||
|
||||
a.apiServers.Shutdown(ctx)
|
||||
a.logger.Info("Waiting for endpoints to shut down")
|
||||
a.wgServers.Wait()
|
||||
if err := a.apiServers.WaitForShutdown(); err != nil {
|
||||
a.logger.Error(err.Error())
|
||||
}
|
||||
a.logger.Info("Endpoints down")
|
||||
}
|
||||
|
||||
|
|
|
@ -1917,7 +1917,7 @@ func TestAgent_HTTPCheck_EnableAgentTLSForChecks(t *testing.T) {
|
|||
Status: api.HealthCritical,
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("https://%s/v1/agent/self", a.srv.ln.Addr().String())
|
||||
url := fmt.Sprintf("https://%s/v1/agent/self", a.HTTPAddr())
|
||||
chk := &structs.CheckType{
|
||||
HTTP: url,
|
||||
Interval: 20 * time.Millisecond,
|
||||
|
|
|
@ -0,0 +1,94 @@
|
|||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-hclog"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
// apiServers is a wrapper around errgroup.Group for managing go routines for
|
||||
// long running agent components (ex: http server, dns server). If any of the
|
||||
// servers fail, the failed channel will be closed, which will cause the agent
|
||||
// to be shutdown instead of running in a degraded state.
|
||||
//
|
||||
// This struct exists as a shim for using errgroup.Group without making major
|
||||
// changes to Agent. In the future it may be removed and replaced with more
|
||||
// direct usage of errgroup.Group.
|
||||
type apiServers struct {
|
||||
logger hclog.Logger
|
||||
group *errgroup.Group
|
||||
servers []apiServer
|
||||
// failed channel is closed when the first server goroutines exit with a
|
||||
// non-nil error.
|
||||
failed <-chan struct{}
|
||||
}
|
||||
|
||||
// apiServer describes a single server managed by apiServers.
type apiServer struct {
	// Protocol served by this server: one of dns, http, https.
	Protocol string
	// Addr is the address the server listens on.
	Addr net.Addr
	// Run is invoked in its own goroutine to run the server. The failed
	// channel is closed as soon as any Run returns a non-nil error.
	Run func() error
	// Shutdown is the function used to stop the server.
	Shutdown func(context.Context) error
}
|
||||
|
||||
// NewAPIServers returns an empty apiServers that is ready to Start servers.
|
||||
func NewAPIServers(logger hclog.Logger) *apiServers {
|
||||
group, ctx := errgroup.WithContext(context.TODO())
|
||||
return &apiServers{
|
||||
logger: logger,
|
||||
group: group,
|
||||
failed: ctx.Done(),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *apiServers) Start(srv apiServer) {
|
||||
srv.logger(s.logger).Info("Starting server")
|
||||
s.servers = append(s.servers, srv)
|
||||
s.group.Go(srv.Run)
|
||||
}
|
||||
|
||||
func (s apiServer) logger(base hclog.Logger) hclog.Logger {
|
||||
return base.With(
|
||||
"protocol", s.Protocol,
|
||||
"address", s.Addr.String(),
|
||||
"network", s.Addr.Network())
|
||||
}
|
||||
|
||||
// Shutdown all the servers and log any errors as warning. Each server is given
|
||||
// 1 second, or until ctx is cancelled, to shutdown gracefully.
|
||||
func (s *apiServers) Shutdown(ctx context.Context) {
|
||||
shutdownGroup := new(sync.WaitGroup)
|
||||
|
||||
for i := range s.servers {
|
||||
server := s.servers[i]
|
||||
shutdownGroup.Add(1)
|
||||
|
||||
go func() {
|
||||
defer shutdownGroup.Done()
|
||||
logger := server.logger(s.logger)
|
||||
logger.Info("Stopping server")
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, time.Second)
|
||||
defer cancel()
|
||||
if err := server.Shutdown(ctx); err != nil {
|
||||
logger.Warn("Failed to stop server")
|
||||
}
|
||||
}()
|
||||
}
|
||||
s.servers = nil
|
||||
shutdownGroup.Wait()
|
||||
}
|
||||
|
||||
// WaitForShutdown waits until all server goroutines have exited. Shutdown
|
||||
// must be called before WaitForShutdown, otherwise it will block forever.
|
||||
func (s *apiServers) WaitForShutdown() error {
|
||||
return s.group.Wait()
|
||||
}
|
|
@ -0,0 +1,65 @@
|
|||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-hclog"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAPIServers_WithServiceRunError(t *testing.T) {
|
||||
servers := NewAPIServers(hclog.New(nil))
|
||||
|
||||
server1, chErr1 := newAPIServerStub()
|
||||
server2, _ := newAPIServerStub()
|
||||
|
||||
t.Run("Start", func(t *testing.T) {
|
||||
servers.Start(server1)
|
||||
servers.Start(server2)
|
||||
|
||||
select {
|
||||
case <-servers.failed:
|
||||
t.Fatalf("expected servers to still be running")
|
||||
case <-time.After(5 * time.Millisecond):
|
||||
}
|
||||
})
|
||||
|
||||
err := fmt.Errorf("oops, I broke")
|
||||
|
||||
t.Run("server exit non-nil error", func(t *testing.T) {
|
||||
chErr1 <- err
|
||||
|
||||
select {
|
||||
case <-servers.failed:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatalf("expected failed channel to be closed")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("shutdown remaining services", func(t *testing.T) {
|
||||
servers.Shutdown(context.Background())
|
||||
require.Equal(t, err, servers.WaitForShutdown())
|
||||
})
|
||||
}
|
||||
|
||||
func newAPIServerStub() (apiServer, chan error) {
|
||||
chErr := make(chan error)
|
||||
return apiServer{
|
||||
Protocol: "http",
|
||||
Addr: &net.TCPAddr{
|
||||
IP: net.ParseIP("127.0.0.11"),
|
||||
Port: 5505,
|
||||
},
|
||||
Run: func() error {
|
||||
return <-chErr
|
||||
},
|
||||
Shutdown: func(ctx context.Context) error {
|
||||
close(chErr)
|
||||
return nil
|
||||
},
|
||||
}, chErr
|
||||
}
|
|
@ -80,16 +80,14 @@ func (e ForbiddenError) Error() string {
|
|||
}
|
||||
|
||||
// HTTPServer provides an HTTP api for an agent.
|
||||
//
|
||||
// TODO: rename this struct to something more appropriate. It is an http.Handler,
|
||||
// request router or multiplexer, but it is not a Server.
|
||||
type HTTPServer struct {
|
||||
// TODO(dnephin): remove Server field, it is not used by any of the HTTPServer methods
|
||||
Server *http.Server
|
||||
ln net.Listener
|
||||
agent *Agent
|
||||
denylist *Denylist
|
||||
|
||||
// proto is filled by the agent to "http" or "https".
|
||||
proto string
|
||||
}
|
||||
|
||||
type templatedFile struct {
|
||||
templated *bytes.Reader
|
||||
name string
|
||||
|
|
|
@ -1353,7 +1353,7 @@ func TestHTTPServer_HandshakeTimeout(t *testing.T) {
|
|||
|
||||
// Connect to it with a plain TCP client that doesn't attempt to send HTTP or
|
||||
// complete a TLS handshake.
|
||||
conn, err := net.Dial("tcp", a.srv.ln.Addr().String())
|
||||
conn, err := net.Dial("tcp", a.HTTPAddr())
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
|
@ -1413,7 +1413,7 @@ func TestRPC_HTTPSMaxConnsPerClient(t *testing.T) {
|
|||
})
|
||||
defer a.Shutdown()
|
||||
|
||||
addr := a.srv.ln.Addr()
|
||||
addr := a.HTTPAddr()
|
||||
|
||||
assertConn := func(conn net.Conn, wantOpen bool) {
|
||||
retry.Run(t, func(r *retry.R) {
|
||||
|
@ -1433,21 +1433,21 @@ func TestRPC_HTTPSMaxConnsPerClient(t *testing.T) {
|
|||
}
|
||||
|
||||
// Connect to the server with bare TCP
|
||||
conn1, err := net.DialTimeout("tcp", addr.String(), time.Second)
|
||||
conn1, err := net.DialTimeout("tcp", addr, time.Second)
|
||||
require.NoError(t, err)
|
||||
defer conn1.Close()
|
||||
|
||||
assertConn(conn1, true)
|
||||
|
||||
// Two conns should succeed
|
||||
conn2, err := net.DialTimeout("tcp", addr.String(), time.Second)
|
||||
conn2, err := net.DialTimeout("tcp", addr, time.Second)
|
||||
require.NoError(t, err)
|
||||
defer conn2.Close()
|
||||
|
||||
assertConn(conn2, true)
|
||||
|
||||
// Third should succeed negotiating TCP handshake...
|
||||
conn3, err := net.DialTimeout("tcp", addr.String(), time.Second)
|
||||
conn3, err := net.DialTimeout("tcp", addr, time.Second)
|
||||
require.NoError(t, err)
|
||||
defer conn3.Close()
|
||||
|
||||
|
@ -1460,7 +1460,7 @@ func TestRPC_HTTPSMaxConnsPerClient(t *testing.T) {
|
|||
require.NoError(t, a.reloadConfigInternal(&newCfg))
|
||||
|
||||
// Now another conn should be allowed
|
||||
conn4, err := net.DialTimeout("tcp", addr.String(), time.Second)
|
||||
conn4, err := net.DialTimeout("tcp", addr, time.Second)
|
||||
require.NoError(t, err)
|
||||
defer conn4.Close()
|
||||
|
||||
|
|
|
@ -73,8 +73,7 @@ type TestAgent struct {
|
|||
// It is valid after Start().
|
||||
dns *DNSServer
|
||||
|
||||
// srv is a reference to the first started HTTP endpoint.
|
||||
// It is valid after Start().
|
||||
// srv is an HTTPServer that may be used to test http endpoints.
|
||||
srv *HTTPServer
|
||||
|
||||
// overrides is an hcl config source to use to override otherwise
|
||||
|
@ -213,6 +212,8 @@ func (a *TestAgent) Start(t *testing.T) (err error) {
|
|||
// Start the anti-entropy syncer
|
||||
a.Agent.StartSync()
|
||||
|
||||
a.srv = &HTTPServer{agent: agent, denylist: NewDenylist(a.config.HTTPBlockEndpoints)}
|
||||
|
||||
if err := a.waitForUp(); err != nil {
|
||||
a.Shutdown()
|
||||
t.Logf("Error while waiting for test agent to start: %v", err)
|
||||
|
@ -220,7 +221,6 @@ func (a *TestAgent) Start(t *testing.T) (err error) {
|
|||
}
|
||||
|
||||
a.dns = a.dnsServers[0]
|
||||
a.srv = a.httpServers[0]
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -233,7 +233,7 @@ func (a *TestAgent) waitForUp() error {
|
|||
var retErr error
|
||||
var out structs.IndexedNodes
|
||||
for ; !time.Now().After(deadline); time.Sleep(timer.Wait) {
|
||||
if len(a.httpServers) == 0 {
|
||||
if len(a.apiServers.servers) == 0 {
|
||||
retErr = fmt.Errorf("waiting for server")
|
||||
continue // fail, try again
|
||||
}
|
||||
|
@ -262,7 +262,7 @@ func (a *TestAgent) waitForUp() error {
|
|||
} else {
|
||||
req := httptest.NewRequest("GET", "/v1/agent/self", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
_, err := a.httpServers[0].AgentSelf(resp, req)
|
||||
_, err := a.srv.AgentSelf(resp, req)
|
||||
if acl.IsErrPermissionDenied(err) || resp.Code == 403 {
|
||||
// permission denied is enough to show that the client is
|
||||
// connected to the servers as it would get a 503 if
|
||||
|
@ -313,10 +313,13 @@ func (a *TestAgent) DNSAddr() string {
|
|||
}
|
||||
|
||||
func (a *TestAgent) HTTPAddr() string {
|
||||
if a.srv == nil {
|
||||
return ""
|
||||
var srv apiServer
|
||||
for _, srv = range a.Agent.apiServers.servers {
|
||||
if srv.Protocol == "http" {
|
||||
break
|
||||
}
|
||||
}
|
||||
return a.srv.Server.Addr
|
||||
return srv.Addr.String()
|
||||
}
|
||||
|
||||
func (a *TestAgent) SegmentAddr(name string) string {
|
||||
|
|
|
@ -41,7 +41,7 @@ func TestUiIndex(t *testing.T) {
|
|||
// Register node
|
||||
req, _ := http.NewRequest("GET", "/ui/my-file", nil)
|
||||
req.URL.Scheme = "http"
|
||||
req.URL.Host = a.srv.Server.Addr
|
||||
req.URL.Host = a.HTTPAddr()
|
||||
|
||||
// Make the request
|
||||
client := cleanhttp.DefaultClient()
|
||||
|
|
|
@ -288,6 +288,9 @@ func (c *cmd) run(args []string) int {
|
|||
case err := <-agent.RetryJoinCh():
|
||||
c.logger.Error("Retry join failed", "error", err)
|
||||
return 1
|
||||
case <-agent.Failed():
|
||||
// The deferred Shutdown method will log the appropriate error
|
||||
return 1
|
||||
case <-agent.ShutdownCh():
|
||||
// agent is already down!
|
||||
return 0
|
||||
|
|
Loading…
Reference in New Issue