diff --git a/consul/client_test.go b/consul/client_test.go
index 16e751d69..dad717397 100644
--- a/consul/client_test.go
+++ b/consul/client_test.go
@@ -4,6 +4,7 @@ import (
 	"fmt"
 	"net"
 	"os"
+	"sync"
 	"testing"
 	"time"
 
@@ -189,6 +190,46 @@ func TestClient_RPC(t *testing.T) {
 	})
 }
 
+func TestClient_RPC_Pool(t *testing.T) {
+	dir1, s1 := testServer(t)
+	defer os.RemoveAll(dir1)
+	defer s1.Shutdown()
+
+	dir2, c1 := testClient(t)
+	defer os.RemoveAll(dir2)
+	defer c1.Shutdown()
+
+	// Try to join.
+	addr := fmt.Sprintf("127.0.0.1:%d",
+		s1.config.SerfLANConfig.MemberlistConfig.BindPort)
+	if _, err := c1.JoinLAN([]string{addr}); err != nil {
+		t.Fatalf("err: %v", err)
+	}
+	if len(s1.LANMembers()) != 2 || len(c1.LANMembers()) != 2 {
+		t.Fatalf("bad len")
+	}
+
+	// Blast out a bunch of RPC requests at the same time to try to get
+	// contention opening new connections.
+	var wg sync.WaitGroup
+	for i := 0; i < 150; i++ {
+		wg.Add(1)
+
+		go func() {
+			defer wg.Done()
+			var out struct{}
+			testutil.WaitForResult(func() (bool, error) {
+				err := c1.RPC("Status.Ping", struct{}{}, &out)
+				return err == nil, err
+			}, func(err error) {
+				t.Fatalf("err: %v", err)
+			})
+		}()
+	}
+
+	wg.Wait()
+}
+
 func TestClient_RPC_TLS(t *testing.T) {
 	dir1, conf1 := testServerConfig(t, "a.testco.internal")
 	conf1.VerifyIncoming = true
diff --git a/consul/pool.go b/consul/pool.go
index 3512fa621..0cd0a99df 100644
--- a/consul/pool.go
+++ b/consul/pool.go
@@ -114,6 +114,12 @@ func (c *Conn) returnClient(client *StreamClient) {
 	}
 }
 
+// markForUse does all the bookkeeping required to ready a connection for use.
+func (c *Conn) markForUse() {
+	c.lastUsed = time.Now()
+	atomic.AddInt32(&c.refCount, 1)
+}
+
 // ConnPool is used to maintain a connection pool to other
 // Consul servers. This is used to reduce the latency of
 // RPC requests between servers. It is only used to pool
@@ -134,6 +140,12 @@ type ConnPool struct {
 	// Pool maps an address to a open connection
 	pool map[string]*Conn
 
+	// limiter is used to throttle the number of connect attempts
+	// to a given address. The first thread will attempt a connection
+	// and put a channel in here, which all other threads will wait
+	// on to close.
+	limiter map[string]chan struct{}
+
 	// TLS wrapper
 	tlsWrap tlsutil.DCWrapper
 
@@ -153,6 +165,7 @@ func NewPool(logOutput io.Writer, maxTime time.Duration, maxStreams int, tlsWrap
 		maxTime:    maxTime,
 		maxStreams: maxStreams,
 		pool:       make(map[string]*Conn),
+		limiter:    make(map[string]chan struct{}),
 		tlsWrap:    tlsWrap,
 		shutdownCh: make(chan struct{}),
 	}
@@ -180,28 +193,69 @@ func (p *ConnPool) Shutdown() error {
 	return nil
 }
 
-// Acquire is used to get a connection that is
-// pooled or to return a new connection
+// acquire will return a pooled connection, if available. Otherwise it will
+// wait for an existing connection attempt to finish, if one is in progress,
+// and will return that one if it succeeds. If all else fails, it will return a
+// newly-created connection and add it to the pool.
 func (p *ConnPool) acquire(dc string, addr net.Addr, version int) (*Conn, error) {
-	// Check for a pooled ocnn
-	if conn := p.getPooled(addr, version); conn != nil {
-		return conn, nil
-	}
-
-	// Create a new connection
-	return p.getNewConn(dc, addr, version)
-}
-
-// getPooled is used to return a pooled connection
-func (p *ConnPool) getPooled(addr net.Addr, version int) *Conn {
+	// Check to see if there's a pooled connection available. This is up
+	// here since it should be the vastly more common case than the rest
+	// of the code here.
 	p.Lock()
 	c := p.pool[addr.String()]
 	if c != nil {
-		c.lastUsed = time.Now()
-		atomic.AddInt32(&c.refCount, 1)
+		c.markForUse()
+		p.Unlock()
+		return c, nil
 	}
+
+	// If not (while we are still locked), set up the throttling structure
+	// for this address, which will make everyone else wait until our
+	// attempt is done.
+	var wait chan struct{}
+	var ok bool
+	if wait, ok = p.limiter[addr.String()]; !ok {
+		wait = make(chan struct{})
+		p.limiter[addr.String()] = wait
+	}
+	isLeadThread := !ok
 	p.Unlock()
-	return c
+
+	// If we are the lead thread, make the new connection and then wake
+	// everybody else up to see if we got it.
+	if isLeadThread {
+		c, err := p.getNewConn(dc, addr, version)
+		p.Lock()
+		delete(p.limiter, addr.String())
+		close(wait)
+		if err != nil {
+			p.Unlock()
+			return nil, err
+		}
+
+		p.pool[addr.String()] = c
+		p.Unlock()
+		return c, nil
+	}
+
+	// Otherwise, wait for the lead thread to attempt the connection
+	// and use what's in the pool at that point.
+	select {
+	case <-p.shutdownCh:
+		return nil, fmt.Errorf("rpc error: shutdown")
+	case <-wait:
+	}
+
+	// See if the lead thread was able to get us a connection.
+	p.Lock()
+	if c := p.pool[addr.String()]; c != nil {
+		c.markForUse()
+		p.Unlock()
+		return c, nil
+	}
+
+	p.Unlock()
+	return nil, fmt.Errorf("rpc error: lead thread didn't get connection")
 }
 
 // getNewConn is used to return a new connection
@@ -272,18 +326,7 @@ func (p *ConnPool) getNewConn(dc string, addr net.Addr, version int) (*Conn, err
 		version:  version,
 		pool:     p,
 	}
-
-	// Track this connection, handle potential race condition
-	p.Lock()
-	if existing := p.pool[addr.String()]; existing != nil {
-		c.Close()
-		p.Unlock()
-		return existing, nil
-	} else {
-		p.pool[addr.String()] = c
-		p.Unlock()
-		return c, nil
-	}
+	return c, nil
 }
 
 // clearConn is used to clear any cached connection, potentially in response to an error
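
Note on the pattern above (supplementary, not part of the patch): acquire() now uses a per-address "lead thread" to serialize connection attempts. The first goroutine to miss the pool installs a wait channel in limiter, dials, publishes the result into pool, and closes the channel; every other goroutine for the same address blocks on that channel and then re-checks the pool. The sketch below shows the same idea as a self-contained Go program with hypothetical names (fakePool, dial); it illustrates the technique only and is not Consul code.

package main

import (
	"fmt"
	"sync"
	"time"
)

// fakePool is a stand-in for ConnPool: one mutex guards both maps.
type fakePool struct {
	sync.Mutex
	pool    map[string]string        // addr -> established "connection"
	limiter map[string]chan struct{} // addr -> wait channel for an in-flight dial
}

// dial stands in for the expensive connection setup (getNewConn in the diff).
func dial(addr string) string {
	time.Sleep(10 * time.Millisecond)
	return "conn to " + addr
}

func (p *fakePool) acquire(addr string) (string, error) {
	p.Lock()
	if c, ok := p.pool[addr]; ok {
		p.Unlock()
		return c, nil // fast path: already pooled
	}

	// First caller for this addr becomes the lead; everyone else will
	// wait on the same channel until the lead closes it.
	wait, ok := p.limiter[addr]
	if !ok {
		wait = make(chan struct{})
		p.limiter[addr] = wait
	}
	isLead := !ok
	p.Unlock()

	if isLead {
		c := dial(addr) // only one dial per addr is in flight at a time
		p.Lock()
		p.pool[addr] = c
		delete(p.limiter, addr)
		close(wait) // wake the waiters
		p.Unlock()
		return c, nil
	}

	<-wait // wait for the lead's attempt to finish
	p.Lock()
	defer p.Unlock()
	if c, ok := p.pool[addr]; ok {
		return c, nil
	}
	return "", fmt.Errorf("lead thread didn't get a connection")
}

func main() {
	p := &fakePool{pool: map[string]string{}, limiter: map[string]chan struct{}{}}
	var wg sync.WaitGroup
	for i := 0; i < 50; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			if _, err := p.acquire("10.0.0.1:8300"); err != nil {
				fmt.Println("acquire failed:", err)
			}
		}()
	}
	wg.Wait()
	fmt.Println("pooled:", p.pool["10.0.0.1:8300"])
}

In this sketch only a single dial happens for the 50 concurrent acquires, which is the contention-on-new-connections behavior that the added TestClient_RPC_Pool test exercises against a real client and server.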