diff --git a/agent/consul/config.go b/agent/consul/config.go index 820e8cbea..2c752064d 100644 --- a/agent/consul/config.go +++ b/agent/consul/config.go @@ -85,6 +85,11 @@ type Config struct { // as a voting member of the Raft cluster. NonVoter bool + // NotifyListen is called after the RPC listener has been configured. + // RPCAdvertise will be set to the listener address if it hasn't been + // configured at this point. + NotifyListen func() + // RPCAddr is the RPC address used by Consul. This should be reachable // by the WAN and LAN RPCAddr *net.TCPAddr @@ -92,7 +97,8 @@ type Config struct { // RPCAdvertise is the address that is advertised to other nodes for // the RPC endpoint. This can differ from the RPC address, if for example // the RPCAddr is unspecified "0.0.0.0:8300", but this address must be - // reachable + // reachable. If RPCAdvertise is nil then it will be set to the Listener + // address after the listening socket is configured. RPCAdvertise *net.TCPAddr // RPCSrcAddr is the source address for outgoing RPC connections. diff --git a/agent/consul/rpc.go b/agent/consul/rpc.go index 55b330c2c..9f63166c6 100644 --- a/agent/consul/rpc.go +++ b/agent/consul/rpc.go @@ -49,7 +49,7 @@ const ( func (s *Server) listen() { for { // Accept a connection - conn, err := s.rpcListener.Accept() + conn, err := s.Listener.Accept() if err != nil { if s.shutdown { return diff --git a/agent/consul/server.go b/agent/consul/server.go index 12d5f10be..a53f5a213 100644 --- a/agent/consul/server.go +++ b/agent/consul/server.go @@ -150,9 +150,9 @@ type Server struct { // Enterprise user-defined areas. router *servers.Router - // rpcListener is used to listen for incoming connections - rpcListener net.Listener - rpcServer *rpc.Server + // Listener is used to listen for incoming connections + Listener net.Listener + rpcServer *rpc.Server // rpcTLS is the TLS config for incoming TLS requests rpcTLS *tls.Config @@ -392,7 +392,7 @@ func NewServerLogger(config *Config, logger *log.Logger) (*Server, error) { // setupSerf is used to setup and initialize a Serf func (s *Server) setupSerf(conf *serf.Config, ch chan serf.Event, path string, wan bool) (*serf.Serf, error) { - addr := s.rpcListener.Addr().(*net.TCPAddr) + addr := s.Listener.Addr().(*net.TCPAddr) conf.Init() if wan { conf.NodeName = fmt.Sprintf("%s.%s", s.config.NodeName, s.config.Datacenter) @@ -645,7 +645,14 @@ func (s *Server) setupRPC(tlsWrap tlsutil.DCWrapper) error { if err != nil { return err } - s.rpcListener = ln + s.Listener = ln + if s.config.NotifyListen != nil { + s.config.NotifyListen() + } + // todo(fs): we should probably guard this + if s.config.RPCAdvertise == nil { + s.config.RPCAdvertise = ln.Addr().(*net.TCPAddr) + } // Verify that we have a usable advertise address if s.config.RPCAdvertise.IP.IsUnspecified() { @@ -714,8 +721,8 @@ func (s *Server) Shutdown() error { } } - if s.rpcListener != nil { - s.rpcListener.Close() + if s.Listener != nil { + s.Listener.Close() } // Close the connection pool diff --git a/agent/consul/server_test.go b/agent/consul/server_test.go index 36ad92ffc..91a474d9d 100644 --- a/agent/consul/server_test.go +++ b/agent/consul/server_test.go @@ -35,25 +35,30 @@ func testServerConfig(t *testing.T, NodeName string) (string, *Config) { config.Bootstrap = true config.Datacenter = "dc1" config.DataDir = dir - config.RPCAddr = &net.TCPAddr{ - IP: []byte{127, 0, 0, 1}, - Port: getPort(), - } - config.RPCAdvertise = config.RPCAddr + + // bind the rpc server to a random port. config.RPCAdvertise will be + // set to the listen address unless it was set in the configuration. + // In that case get the address from srv.Listener.Addr(). + config.RPCAddr = &net.TCPAddr{IP: []byte{127, 0, 0, 1}} + nodeID, err := uuid.GenerateUUID() if err != nil { t.Fatal(err) } config.NodeID = types.NodeID(nodeID) + + // set the memberlist bind port to 0 to bind to a random port. + // memberlist will update the value of BindPort after bind + // to the actual value. config.SerfLANConfig.MemberlistConfig.BindAddr = "127.0.0.1" - config.SerfLANConfig.MemberlistConfig.BindPort = getPort() + config.SerfLANConfig.MemberlistConfig.BindPort = 0 config.SerfLANConfig.MemberlistConfig.SuspicionMult = 2 config.SerfLANConfig.MemberlistConfig.ProbeTimeout = 50 * time.Millisecond config.SerfLANConfig.MemberlistConfig.ProbeInterval = 100 * time.Millisecond config.SerfLANConfig.MemberlistConfig.GossipInterval = 100 * time.Millisecond config.SerfWANConfig.MemberlistConfig.BindAddr = "127.0.0.1" - config.SerfWANConfig.MemberlistConfig.BindPort = getPort() + config.SerfWANConfig.MemberlistConfig.BindPort = 0 config.SerfWANConfig.MemberlistConfig.SuspicionMult = 2 config.SerfWANConfig.MemberlistConfig.ProbeTimeout = 50 * time.Millisecond config.SerfWANConfig.MemberlistConfig.ProbeInterval = 100 * time.Millisecond @@ -107,14 +112,50 @@ func testServerDCExpect(t *testing.T, dc string, expect int) (string, *Server) { func testServerWithConfig(t *testing.T, cb func(c *Config)) (string, *Server) { name := fmt.Sprintf("Node %d", getPort()) dir, config := testServerConfig(t, name) - cb(config) - server, err := NewServer(config) + if cb != nil { + cb(config) + } + server, err := newServer(config) if err != nil { t.Fatalf("err: %v", err) } return dir, server } +func newServer(c *Config) (*Server, error) { + // chain server up notification + oldNotify := c.NotifyListen + up := make(chan struct{}) + c.NotifyListen = func() { + close(up) + if oldNotify != nil { + oldNotify() + } + } + + // start server + srv, err := NewServer(c) + if err != nil { + return nil, err + } + + // wait until after listen + <-up + + // get the real address + // + // the server already sets the RPCAdvertise address + // if it wasn't configured since it needs it for + // some initialization + // + // todo(fs): setting RPCAddr should probably be guarded + // todo(fs): but for now it is a shortcut to avoid fixing + // todo(fs): tests which depend on that value. They should + // todo(fs): just get the listener address instead. + c.RPCAddr = srv.Listener.Addr().(*net.TCPAddr) + return srv, nil +} + func TestServer_StartStop(t *testing.T) { // Start up a server and then stop it. dir1, s1 := testServer(t) @@ -381,7 +422,7 @@ func TestServer_JoinLAN_TLS(t *testing.T) { conf1.VerifyIncoming = true conf1.VerifyOutgoing = true configureTLS(conf1) - s1, err := NewServer(conf1) + s1, err := newServer(conf1) if err != nil { t.Fatalf("err: %v", err) } @@ -393,7 +434,7 @@ func TestServer_JoinLAN_TLS(t *testing.T) { conf2.VerifyIncoming = true conf2.VerifyOutgoing = true configureTLS(conf2) - s2, err := NewServer(conf2) + s2, err := newServer(conf2) if err != nil { t.Fatalf("err: %v", err) } diff --git a/agent/consul/status_endpoint_test.go b/agent/consul/status_endpoint_test.go index 776d91b77..0c7440663 100644 --- a/agent/consul/status_endpoint_test.go +++ b/agent/consul/status_endpoint_test.go @@ -13,7 +13,7 @@ import ( ) func rpcClient(t *testing.T, s *Server) rpc.ClientCodec { - addr := s.config.RPCAddr + addr := s.config.RPCAdvertise conn, err := net.DialTimeout("tcp", addr.String(), time.Second) if err != nil { t.Fatalf("err: %v", err)