// Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: MPL-2.0 package agent import ( "context" "fmt" "net" "net/http" "sync" "time" "github.com/hashicorp/go-hclog" "golang.org/x/sync/errgroup" ) // apiServers is a wrapper around errgroup.Group for managing go routines for // long running agent components (ex: http server, dns server). If any of the // servers fail, the failed channel will be closed, which will cause the agent // to be shutdown instead of running in a degraded state. // // This struct exists as a shim for using errgroup.Group without making major // changes to Agent. In the future it may be removed and replaced with more // direct usage of errgroup.Group. type apiServers struct { logger hclog.Logger group *errgroup.Group servers []apiServer // failed channel is closed when the first server goroutines exit with a // non-nil error. failed <-chan struct{} } type apiServer struct { // Protocol supported by this server. One of: dns, http, https Protocol string // Addr the server is listening on Addr net.Addr // Run will be called in a goroutine to run the server. When any Run exits // with a non-nil error, the failed channel will be closed. Run func() error // Shutdown function used to stop the server Shutdown func(context.Context) error } // NewAPIServers returns an empty apiServers that is ready to Start servers. func NewAPIServers(logger hclog.Logger) *apiServers { group, ctx := errgroup.WithContext(context.TODO()) return &apiServers{ logger: logger, group: group, failed: ctx.Done(), } } func (s *apiServers) Start(srv apiServer) { srv.logger(s.logger).Info("Starting server") s.servers = append(s.servers, srv) s.group.Go(srv.Run) } func (s apiServer) logger(base hclog.Logger) hclog.Logger { return base.With( "protocol", s.Protocol, "address", s.Addr.String(), "network", s.Addr.Network()) } // Shutdown all the servers and log any errors as warning. Each server is given // 1 second, or until ctx is cancelled, to shutdown gracefully. func (s *apiServers) Shutdown(ctx context.Context) { shutdownGroup := new(sync.WaitGroup) for i := range s.servers { server := s.servers[i] shutdownGroup.Add(1) go func() { defer shutdownGroup.Done() logger := server.logger(s.logger) logger.Info("Stopping server") ctx, cancel := context.WithTimeout(ctx, time.Second) defer cancel() if err := server.Shutdown(ctx); err != nil { logger.Warn("Failed to stop server") } }() } s.servers = nil shutdownGroup.Wait() } // WaitForShutdown waits until all server goroutines have exited. Shutdown // must be called before WaitForShutdown, otherwise it will block forever. func (s *apiServers) WaitForShutdown() error { return s.group.Wait() } func newAPIServerHTTP(proto string, l net.Listener, httpServer *http.Server) apiServer { return apiServer{ Protocol: proto, Addr: l.Addr(), Shutdown: httpServer.Shutdown, Run: func() error { err := httpServer.Serve(l) if err == nil || err == http.ErrServerClosed { return nil } return fmt.Errorf("%s server %s failed: %w", proto, l.Addr(), err) }, } }