open-consul/agent/consul/wanfed/pool.go

259 lines
4.3 KiB
Go

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package wanfed
import (
"fmt"
"net"
"sync"
"time"
)
// connPool caches idle, already negotiated ALPN_WANGossipPacket connections
// to remote servers, grouped by a caller-supplied key. A pooled connection is
// discarded once maxTime has passed since it was last acquired.
type connPool struct {
	// maxTime bounds how long a connection may stay pooled after its last
	// acquisition before the reaper closes it.
	maxTime time.Duration

	// mu guards pool and shutdown.
	mu       sync.Mutex
	pool     map[string][]*conn
	shutdown bool

	// shutdownCh is closed by Close to tell the reaper goroutine to exit.
	shutdownCh chan struct{}
	// reapWg tracks the reaper goroutine so Close can wait for it to finish.
	reapWg sync.WaitGroup
}
// newConnPool builds an empty pool and starts its background reaper
// goroutine. The caller must eventually call Close to stop that goroutine.
//
// maxTime is how long an idle connection may stay pooled after it was last
// acquired. It must be positive: with a zero or negative value the reaper's
// age check never passes, so every idle connection would be closed on each
// sweep. Such values are rejected with an error.
func newConnPool(maxTime time.Duration) (*connPool, error) {
	if maxTime <= 0 {
		return nil, fmt.Errorf("wanfed: conn pool needs a max time configured")
	}

	p := &connPool{
		maxTime:    maxTime,
		pool:       make(map[string][]*conn),
		shutdownCh: make(chan struct{}),
	}

	// Register the reaper before starting it so Close can always wait on it.
	p.reapWg.Add(1)
	go p.reap()

	return p, nil
}
// Close shuts the pool down: it closes every idle connection, marks the pool
// as shut down, signals the reaper goroutine to exit and waits until it has
// done so. It is safe to call more than once and always returns nil.
func (p *connPool) Close() error {
	p.mu.Lock()
	if !p.shutdown {
		for _, conns := range p.pool {
			for _, c := range conns {
				// Best-effort: a close error on an idle connection is not
				// actionable while shutting down.
				_ = c.Close()
			}
		}
		p.pool = nil
		p.shutdown = true
		close(p.shutdownCh)
	}
	p.mu.Unlock()

	// Wait for the reaper only after releasing p.mu. reapOnce acquires p.mu,
	// so waiting with the lock held would deadlock whenever the reaper has
	// already woken up for a sweep and is blocked on the mutex.
	p.reapWg.Wait()
	return nil
}
// AcquireOrDial hands out a connection for key: an idle one taken from the
// pool when available, otherwise a new one established through dialer. The
// returned connection is stamped as freshly used. An error is returned when
// the pool is closed or when dialer fails.
func (p *connPool) AcquireOrDial(key string, dialer func() (net.Conn, error)) (*conn, error) {
	pooled, err := p.maybeAcquire(key)
	switch {
	case err != nil:
		return nil, err
	case pooled == nil:
		// Nothing idle for this key; fall back to dialing.
		raw, dialErr := dialer()
		if dialErr != nil {
			return nil, dialErr
		}
		pooled = &conn{
			key:  key,
			pool: p,
			Conn: raw,
		}
	}

	pooled.markForUse()
	return pooled, nil
}
// errPoolClosed is returned when a connection is requested from a pool that
// has already been shut down.
var errPoolClosed = fmt.Errorf("wanfed: connection pool is closed")

// maybeAcquire removes and returns the most recently pooled idle connection
// for key. It returns (nil, nil) when no idle connection is available, in
// which case opening a new connection is the caller's responsibility, and
// (nil, errPoolClosed) once the pool has been shut down.
func (p *connPool) maybeAcquire(key string) (*conn, error) {
	p.mu.Lock()
	defer p.mu.Unlock()

	if p.shutdown {
		return nil, errPoolClosed
	}

	conns := p.pool[key]
	if len(conns) == 0 {
		// Covers both a missing key (delete is then a no-op) and a stray
		// empty entry that needs cleaning up.
		delete(p.pool, key)
		return nil, nil
	}

	last := len(conns) - 1
	c := conns[last]

	// Clear the vacated slot so the backing array no longer references the
	// connection once it has left the pool.
	conns[last] = nil

	if last == 0 {
		delete(p.pool, key)
	} else {
		p.pool[key] = conns[:last]
	}
	return c, nil
}
// returnConn files c back under its key so a later acquire can reuse it. If
// the pool has already been shut down the underlying connection is closed
// instead, and the result of that close is returned.
func (p *connPool) returnConn(c *conn) error {
	p.mu.Lock()
	defer p.mu.Unlock()

	if !p.shutdown {
		idle := p.pool[c.key]
		p.pool[c.key] = append(idle, c)
		return nil
	}

	// Pool is gone: there is nowhere to park the connection, so really
	// close it.
	return c.Conn.Close()
}
// reap is the body of the background goroutine started by newConnPool. Once
// per second it calls reapOnce to close idle connections that have not been
// used recently, and it exits as soon as shutdownCh is closed.
func (p *connPool) reap() {
	defer p.reapWg.Done()

	// One ticker is reused for every sweep and released when the loop exits.
	ticker := time.NewTicker(time.Second)
	defer ticker.Stop()

	for {
		select {
		case <-p.shutdownCh:
			return
		case <-ticker.C:
			p.reapOnce()
		}
	}
}
// reapOnce performs a single sweep over the idle pool. Every connection whose
// lastUsed timestamp is at least maxTime old has its underlying connection
// closed and is dropped from the pool; keys left without any connection are
// removed from the map. It does nothing once the pool is shut down.
func (p *connPool) reapOnce() {
	p.mu.Lock()
	defer p.mu.Unlock()

	if p.shutdown {
		return
	}

	sweepTime := time.Now()
	for key, idle := range p.pool {
		var fresh []*conn
		for _, c := range idle {
			if sweepTime.Sub(c.lastUsed) >= p.maxTime {
				// Idle for too long: close it and leave it out.
				c.Conn.Close()
				continue
			}
			fresh = append(fresh, c)
		}

		switch {
		case len(fresh) == 0:
			// Nothing survived (or the entry was already empty). Deleting
			// the current key while ranging over a map is permitted in Go.
			delete(p.pool, key)
		case len(fresh) < len(idle):
			p.pool[key] = fresh
		}
		// Otherwise every connection was kept and the entry is unchanged.
	}
}
// conn is a pooled connection: the embedded net.Conn plus the bookkeeping
// needed to hand it back to, or drop it from, its owning connPool.
type conn struct {
	// key is the pool map key this connection is filed under.
	key string

	// mu is held by conn's methods while they read or write lastUsed, failed
	// and closed.
	mu       sync.Mutex
	lastUsed time.Time // stamped by markForUse; the reaper compares it to maxTime
	failed   bool      // set by MarkFailed; makes ReturnOrClose close instead of pool
	closed   bool      // set by Close, or by ReturnOrClose on a failed connection

	// pool is the connPool that ReturnOrClose hands the connection back to.
	pool *connPool

	net.Conn
}
// ReturnOrClose releases the connection after use. A connection that was
// already closed is left alone, one that was marked failed is closed for
// good, and any other connection is handed back to its pool for reuse.
func (c *conn) ReturnOrClose() error {
	c.mu.Lock()
	wasClosed, hasFailed := c.closed, c.failed
	if hasFailed {
		// A failed connection is about to be closed below; record that now
		// while the lock is held.
		c.closed = true
	}
	c.mu.Unlock()

	switch {
	case wasClosed:
		return nil
	case hasFailed:
		return c.Conn.Close()
	default:
		return c.pool.returnConn(c)
	}
}
// Close closes the underlying connection the first time it is called and
// returns that close's result; every later call is a no-op returning nil.
func (c *conn) Close() error {
	c.mu.Lock()
	alreadyClosed := c.closed
	c.closed = true
	c.mu.Unlock()

	if !alreadyClosed {
		return c.Conn.Close()
	}
	return nil
}
// markForUse records that the connection is being handed out: it clears the
// failed flag and stamps lastUsed with the current time.
func (c *conn) markForUse() {
	c.mu.Lock()
	defer c.mu.Unlock()

	c.failed = false
	c.lastUsed = time.Now()
}
// MarkFailed flags the connection as failed so that a subsequent
// ReturnOrClose closes it rather than returning it to the pool.
func (c *conn) MarkFailed() {
	c.mu.Lock()
	defer c.mu.Unlock()

	c.failed = true
}