diff --git a/consul/pool.go b/consul/pool.go
index 1c2f86fe5..e6c7220eb 100644
--- a/consul/pool.go
+++ b/consul/pool.go
@@ -135,8 +135,10 @@ type ConnPool struct {
 	pool map[string]*Conn
 
 	// limiter is used to throttle the number of connect attempts
-	// to a given address.
-	limiter map[string]chan int
+	// to a given address. The first thread will attempt a connection
+	// and put a channel in here, which all other threads will wait
+	// on to close.
+	limiter map[string]chan struct{}
 
 	// TLS wrapper
 	tlsWrap tlsutil.DCWrapper
@@ -157,7 +159,7 @@ func NewPool(logOutput io.Writer, maxTime time.Duration, maxStreams int, tlsWrap
 		maxTime:    maxTime,
 		maxStreams: maxStreams,
 		pool:       make(map[string]*Conn),
-		limiter:    make(map[string]chan int),
+		limiter:    make(map[string]chan struct{}),
 		tlsWrap:    tlsWrap,
 		shutdownCh: make(chan struct{}),
 	}
@@ -209,52 +211,57 @@ func (p *ConnPool) acquire(dc string, addr net.Addr, version int) (*Conn, error)
 	}
 
 	// If not (while we are still locked), set up the throttling structure
-	// for this address.
-	var wait chan int
+	// for this address, which will make everyone else wait until our
+	// attempt is done.
+	var wait chan struct{}
 	var ok bool
 	if wait, ok = p.limiter[addr.String()]; !ok {
-		wait = make(chan int, 1)
+		wait = make(chan struct{}, 1)
 		p.limiter[addr.String()] = wait
 	}
+	isLeadThread := !ok
 	p.Unlock()
 
-	// Now throttle so we don't pound on a server if there are a ton of
-	// outstanding requests to one server.
-	wait <- 1
-	defer func() { <-wait }()
+	// If we are the lead thread, make the new connection and then wake
+	// everybody else up to see if we got it.
+	if isLeadThread {
+		defer func() {
+			p.Lock()
+			delete(p.limiter, addr.String())
+			p.Unlock()
 
-	// In case we got throttled, check the pool one more time.
+			close(wait)
+		}()
+
+		c, err := p.getNewConn(dc, addr, version)
+		if err != nil {
+			return nil, err
+		}
+
+		p.Lock()
+		p.pool[addr.String()] = c
+		p.Unlock()
+		return c, nil
+	}
+
+	// Otherwise, wait for the lead thread to attempt the connection
+	// and use what's in the pool at that point.
+	select {
+	case <-p.shutdownCh:
+		return nil, fmt.Errorf("rpc error: shutdown")
+	case <-wait:
+	}
+
+	// See if the lead thread was able to get us a connection.
 	p.Lock()
-	c = p.pool[addr.String()]
-	if c != nil {
+	if c := p.pool[addr.String()]; c != nil {
 		markForUse(c)
 		p.Unlock()
 		return c, nil
 	}
+	p.Unlock()
 
-
-	// Go ahead and make a new connection.
-	c, err := p.getNewConn(dc, addr, version)
-	if err != nil {
-		return nil, err
-	}
-
-	// Return the new connection, adding it to the pool. If the connection
-	// the throttle was waiting for fails then all the threads will then try
-	// to connect, so we have to handle that potential race condition and
-	// scuttle the connection we just made if someone else got there first.
-	p.Lock()
-	if existing := p.pool[addr.String()]; existing != nil {
-		c.Close()
-
-		markForUse(existing)
-		p.Unlock()
-		return existing, nil
-	}
-
-	p.pool[addr.String()] = c
-	p.Unlock()
-	return c, nil
+	return nil, fmt.Errorf("rpc error: lead thread didn't get connection")
 }
 
 // getNewConn is used to return a new connection
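
For review context, here is a minimal, self-contained sketch of the lead-thread ("single flight") pattern this diff implements: the first goroutine to miss the pool for an address dials it while everyone else blocks on a channel that the lead closes when its attempt is done. The `Pool`, `Acquire`, and `net.Dial` usage below are illustrative stand-ins, not Consul's actual types or RPC dialing.

```go
// Package connlead sketches the lead-thread limiter pattern from the
// patch above; it is not Consul's code.
package connlead

import (
	"fmt"
	"net"
	"sync"
)

// Pool caches one connection per address and lets only one goroutine
// (the "lead") dial a given address at a time.
type Pool struct {
	mu      sync.Mutex
	conns   map[string]net.Conn
	limiter map[string]chan struct{}
}

func NewPool() *Pool {
	return &Pool{
		conns:   make(map[string]net.Conn),
		limiter: make(map[string]chan struct{}),
	}
}

// Acquire returns a pooled connection, dialing it if this caller is the
// lead goroutine for the address, or waiting on the lead otherwise.
func (p *Pool) Acquire(addr string) (net.Conn, error) {
	p.mu.Lock()
	if c, ok := p.conns[addr]; ok {
		p.mu.Unlock()
		return c, nil
	}

	// The first goroutine here becomes the lead; later ones find the
	// channel already present and wait for it to be closed.
	wait, ok := p.limiter[addr]
	if !ok {
		wait = make(chan struct{})
		p.limiter[addr] = wait
	}
	isLead := !ok
	p.mu.Unlock()

	if isLead {
		// Wake the waiters once the attempt is over, success or not.
		defer func() {
			p.mu.Lock()
			delete(p.limiter, addr)
			p.mu.Unlock()
			close(wait)
		}()

		c, err := net.Dial("tcp", addr)
		if err != nil {
			return nil, err
		}
		p.mu.Lock()
		p.conns[addr] = c
		p.mu.Unlock()
		return c, nil
	}

	// Not the lead: block until the lead's attempt finishes, then see
	// whether it managed to populate the pool.
	<-wait
	p.mu.Lock()
	defer p.mu.Unlock()
	if c, ok := p.conns[addr]; ok {
		return c, nil
	}
	return nil, fmt.Errorf("lead goroutine failed to connect to %s", addr)
}
```

Closing the channel broadcasts to every waiter at once, and deleting the limiter entry before the close means the next goroutine that misses the pool starts a fresh lead attempt rather than waiting on a stale channel.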