From 79d00f59cd79e1b8f8d2d7e82e81787ae82d7a78 Mon Sep 17 00:00:00 2001 From: "Chris S. Kim" Date: Tue, 9 Aug 2022 12:22:39 -0400 Subject: [PATCH] Close active listeners on error If startListeners successfully created listeners for some of its input addresses but eventually failed, the function would return an error and existing listeners would not be cleaned up. --- .changelog/14081.txt | 3 ++ agent/agent.go | 19 +++++++++++-- agent/agent_test.go | 67 ++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 86 insertions(+), 3 deletions(-) create mode 100644 .changelog/14081.txt diff --git a/.changelog/14081.txt b/.changelog/14081.txt new file mode 100644 index 000000000..ccb03ffb0 --- /dev/null +++ b/.changelog/14081.txt @@ -0,0 +1,3 @@ +```release-note:bug +agent: Fixes an issue where an agent that fails to start due to bad addresses won't clean up any existing listeners +``` diff --git a/agent/agent.go b/agent/agent.go index e087af518..8a263647e 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -863,8 +863,18 @@ func (a *Agent) listenAndServeDNS() error { return merr.ErrorOrNil() } +// startListeners will return a net.Listener for every address unless an +// error is encountered, in which case it will close all previously opened +// listeners and return the error. func (a *Agent) startListeners(addrs []net.Addr) ([]net.Listener, error) { - var ln []net.Listener + var lns []net.Listener + + closeAll := func() { + for _, l := range lns { + l.Close() + } + } + for _, addr := range addrs { var l net.Listener var err error @@ -873,22 +883,25 @@ func (a *Agent) startListeners(addrs []net.Addr) ([]net.Listener, error) { case *net.UnixAddr: l, err = a.listenSocket(x.Name) if err != nil { + closeAll() return nil, err } case *net.TCPAddr: l, err = net.Listen("tcp", x.String()) if err != nil { + closeAll() return nil, err } l = &tcpKeepAliveListener{l.(*net.TCPListener)} default: + closeAll() return nil, fmt.Errorf("unsupported address type %T", addr) } - ln = append(ln, l) + lns = append(lns, l) } - return ln, nil + return lns, nil } // listenHTTP binds listeners to the provided addresses and also returns diff --git a/agent/agent_test.go b/agent/agent_test.go index d7b118fcb..8bae81ce4 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -5857,6 +5857,73 @@ func Test_coalesceTimerTwoPeriods(t *testing.T) { } +func TestAgent_startListeners(t *testing.T) { + if testing.Short() { + t.Skip("too slow for testing.Short") + } + t.Parallel() + + ports := freeport.GetN(t, 3) + bd := BaseDeps{ + Deps: consul.Deps{ + Logger: hclog.NewInterceptLogger(nil), + Tokens: new(token.Store), + GRPCConnPool: &fakeGRPCConnPool{}, + }, + RuntimeConfig: &config.RuntimeConfig{ + HTTPAddrs: []net.Addr{}, + }, + Cache: cache.New(cache.Options{}), + } + + bd, err := initEnterpriseBaseDeps(bd, nil) + require.NoError(t, err) + + agent, err := New(bd) + require.NoError(t, err) + + // use up an address + used := net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: ports[2]} + l, err := net.Listen("tcp", used.String()) + require.NoError(t, err) + t.Cleanup(func() { l.Close() }) + + var lns []net.Listener + t.Cleanup(func() { + for _, ln := range lns { + ln.Close() + } + }) + + // first two addresses open listeners but third address should fail + lns, err = agent.startListeners([]net.Addr{ + &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: ports[0]}, + &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: ports[1]}, + &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: ports[2]}, + }) + require.Contains(t, err.Error(), "address already in use") + + // first two ports should be freed up + retry.Run(t, func(r *retry.R) { + lns, err = agent.startListeners([]net.Addr{ + &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: ports[0]}, + &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: ports[1]}, + }) + require.NoError(r, err) + require.Len(r, lns, 2) + }) + + // first two ports should be in use + retry.Run(t, func(r *retry.R) { + _, err = agent.startListeners([]net.Addr{ + &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: ports[0]}, + &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: ports[1]}, + }) + require.Contains(r, err.Error(), "address already in use") + }) + +} + func getExpectedCaPoolByFile(t *testing.T) *x509.CertPool { pool := x509.NewCertPool() data, err := ioutil.ReadFile("../test/ca/root.cer")