03aa6106ff
We noticed that TestUpstreamListener would sometimes deadlock when run with the race detector. While debugging this issue I found and fixed the following problems. 1. The net.Listener was not being closed properly when Listener.Stop was called. This caused the Listener.Serve goroutine to run forever. Fixed by storing a reference to net.Listener and closing it properly when Listener.Stop is called. 2. Call connWG.Add in the correct place. WaitGroup.Add must be called before starting a goroutine, not from inside the goroutine. 3. Set metrics config EnableRuntimeMetrics to `false` so that we don't start a background goroutine in each test for no reason. There is no way to shut down this goroutine, and it was an added distraction while debugging these timeouts. 4. Two tests were calling require.NoError from a goroutine. require.NoError calls t.FailNow, which MUST be called from the main test goroutine. Instead use t.Errorf, which can be called from other goroutines and will still fail the test. 5. `assertCurrentGaugeValue` was breaking out of a for loop, which would cause the `RWMutex.RUnlock` to be missed. Fixed by calling unlock before `break`. The core deadlock issue was fixed by https://github.com/armon/go-metrics/pull/124.
267 lines
7.8 KiB
Go
267 lines
7.8 KiB
Go
package proxy
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"errors"
|
|
"net"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
metrics "github.com/armon/go-metrics"
|
|
"github.com/hashicorp/go-hclog"
|
|
|
|
"github.com/hashicorp/consul/api"
|
|
"github.com/hashicorp/consul/connect"
|
|
"github.com/hashicorp/consul/ipaddr"
|
|
)
|
|
|
|
const (
	// publicListenerPrefix is the logger name and the first metric key element
	// used by listeners created with NewPublicListener.
	publicListenerPrefix = "inbound"

	// upstreamListenerPrefix is the logger name and the first metric key
	// element used by upstream listeners (NewUpstreamListener).
	upstreamListenerPrefix = "upstream"
)
|
|
|
|
// Listener is the implementation of a specific proxy listener. It has pluggable
// Listen and Dial methods to suit public mTLS vs upstream semantics. It handles
// the lifecycle of the listener and all connections opened through it.
//
// Listeners are created by the constructors in this file (NewPublicListener,
// NewUpstreamListener), which initialize the channels below; in the zero value
// those channels are nil.
type Listener struct {
	// Service is the connect service instance to use.
	Service *connect.Service

	// listenFunc, dialFunc, and bindAddr are set by type-specific constructors.
	// listenFunc opens the listening socket when Serve starts, dialFunc opens
	// the outbound side for each accepted connection, and bindAddr is the
	// address (built with ipaddr.FormatAddressPort) that listenFunc binds to.
	listenFunc func() (net.Listener, error)
	dialFunc   func() (net.Conn, error)
	bindAddr   string

	// stopFlag is set to 1 by Close and read atomically by Serve, which uses
	// it to distinguish a deliberate shutdown from an Accept failure.
	stopFlag int32

	// stopChan is closed by Close to make in-flight handleConn goroutines
	// return.
	stopChan chan struct{}

	// listeningChan is closed when listener is opened successfully. It's really
	// only for use in tests where we need to coordinate wait for the Serve
	// goroutine to be running before we proceed trying to connect. On my laptop
	// this always works out anyway but on constrained VMs and especially docker
	// containers (e.g. in CI) we often see the Dial routine win the race and get
	// `connection refused`. Retry loops and sleeps are unpleasant workarounds and
	// this is cheap and correct.
	listeningChan chan struct{}

	// listener is the net.Listener opened by Serve; Close closes it.
	listener net.Listener

	// logger is named after the listener kind (publicListenerPrefix or
	// upstreamListenerPrefix).
	logger hclog.Logger

	// activeConns counts connections currently being proxied. trackConn
	// updates it atomically and publishes the value as a gauge.
	activeConns int32

	// connWG tracks running handleConn goroutines so Close can wait for them.
	connWG sync.WaitGroup

	// metricPrefix is the first element of every metric key this listener
	// emits; metricLabels are attached to every one of those metrics.
	metricPrefix string
	metricLabels []metrics.Label
}
|
|
|
|
// NewPublicListener returns a Listener setup to listen for public mTLS
|
|
// connections and proxy them to the configured local application over TCP.
|
|
func NewPublicListener(svc *connect.Service, cfg PublicListenerConfig,
|
|
logger hclog.Logger) *Listener {
|
|
bindAddr := ipaddr.FormatAddressPort(cfg.BindAddress, cfg.BindPort)
|
|
return &Listener{
|
|
Service: svc,
|
|
listenFunc: func() (net.Listener, error) {
|
|
return tls.Listen("tcp", bindAddr, svc.ServerTLSConfig())
|
|
},
|
|
dialFunc: func() (net.Conn, error) {
|
|
return net.DialTimeout("tcp", cfg.LocalServiceAddress,
|
|
time.Duration(cfg.LocalConnectTimeoutMs)*time.Millisecond)
|
|
},
|
|
bindAddr: bindAddr,
|
|
stopChan: make(chan struct{}),
|
|
listeningChan: make(chan struct{}),
|
|
logger: logger.Named(publicListenerPrefix),
|
|
metricPrefix: publicListenerPrefix,
|
|
// For now we only label ourselves as source - we could fetch the src
|
|
// service from cert on each connection and label metrics differently but it
|
|
// significaly complicates the active connection tracking here and it's not
|
|
// clear that it's very valuable - on aggregate looking at all _outbound_
|
|
// connections across all proxies gets you a full picture of src->dst
|
|
// traffic. We might expand this later for better debugging of which clients
|
|
// are abusing a particular service instance but we'll see how valuable that
|
|
// seems for the extra complication of tracking many gauges here.
|
|
metricLabels: []metrics.Label{{Name: "dst", Value: svc.Name()}},
|
|
}
|
|
}
|
|
|
|
// NewUpstreamListener returns a Listener setup to listen locally for TCP
|
|
// connections that are proxied to a discovered Connect service instance.
|
|
func NewUpstreamListener(svc *connect.Service, client *api.Client,
|
|
cfg UpstreamConfig, logger hclog.Logger) *Listener {
|
|
return newUpstreamListenerWithResolver(svc, cfg,
|
|
UpstreamResolverFuncFromClient(client), logger)
|
|
}
|
|
|
|
func newUpstreamListenerWithResolver(svc *connect.Service, cfg UpstreamConfig,
|
|
resolverFunc func(UpstreamConfig) (connect.Resolver, error),
|
|
logger hclog.Logger) *Listener {
|
|
bindAddr := ipaddr.FormatAddressPort(cfg.LocalBindAddress, cfg.LocalBindPort)
|
|
return &Listener{
|
|
Service: svc,
|
|
listenFunc: func() (net.Listener, error) {
|
|
return net.Listen("tcp", bindAddr)
|
|
},
|
|
dialFunc: func() (net.Conn, error) {
|
|
rf, err := resolverFunc(cfg)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
ctx, cancel := context.WithTimeout(context.Background(),
|
|
cfg.ConnectTimeout())
|
|
defer cancel()
|
|
return svc.Dial(ctx, rf)
|
|
},
|
|
bindAddr: bindAddr,
|
|
stopChan: make(chan struct{}),
|
|
listeningChan: make(chan struct{}),
|
|
logger: logger.Named(upstreamListenerPrefix),
|
|
metricPrefix: upstreamListenerPrefix,
|
|
metricLabels: []metrics.Label{
|
|
{Name: "src", Value: svc.Name()},
|
|
// TODO(banks): namespace support
|
|
{Name: "dst_type", Value: string(cfg.DestinationType)},
|
|
{Name: "dst", Value: cfg.DestinationName},
|
|
},
|
|
}
|
|
}
|
|
|
|
// Serve runs the listener until it is stopped. It is an error to call Serve
|
|
// more than once for any given Listener instance.
|
|
func (l *Listener) Serve() error {
|
|
// Ensure we mark state closed if we fail before Close is called externally.
|
|
defer l.Close()
|
|
|
|
if atomic.LoadInt32(&l.stopFlag) != 0 {
|
|
return errors.New("serve called on a closed listener")
|
|
}
|
|
|
|
var err error
|
|
l.listener, err = l.listenFunc()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
close(l.listeningChan)
|
|
|
|
for {
|
|
conn, err := l.listener.Accept()
|
|
if err != nil {
|
|
if atomic.LoadInt32(&l.stopFlag) == 1 {
|
|
return nil
|
|
}
|
|
return err
|
|
}
|
|
|
|
l.connWG.Add(1)
|
|
go l.handleConn(conn)
|
|
}
|
|
}
|
|
|
|
// handleConn is the internal connection handler goroutine.
|
|
func (l *Listener) handleConn(src net.Conn) {
|
|
defer src.Close()
|
|
// Make sure Listener.Close waits for this conn to be cleaned up.
|
|
defer l.connWG.Done()
|
|
|
|
dst, err := l.dialFunc()
|
|
if err != nil {
|
|
l.logger.Error("failed to dial", "error", err)
|
|
return
|
|
}
|
|
|
|
// Track active conn now (first function call) and defer un-counting it when
|
|
// it closes.
|
|
defer l.trackConn()()
|
|
|
|
// Note no need to defer dst.Close() since conn handles that for us.
|
|
conn := NewConn(src, dst)
|
|
defer conn.Close()
|
|
|
|
connStop := make(chan struct{})
|
|
|
|
// Run another goroutine to copy the bytes.
|
|
go func() {
|
|
err = conn.CopyBytes()
|
|
if err != nil {
|
|
l.logger.Error("connection failed", "error", err)
|
|
}
|
|
close(connStop)
|
|
}()
|
|
|
|
// Periodically copy stats from conn to metrics (to keep metrics calls out of
|
|
// the path of every single packet copy). 5 seconds is probably good enough
|
|
// resolution - statsd and most others tend to summarize with lower resolution
|
|
// anyway and this amortizes the cost more.
|
|
var tx, rx uint64
|
|
statsT := time.NewTicker(5 * time.Second)
|
|
defer statsT.Stop()
|
|
|
|
reportStats := func() {
|
|
newTx, newRx := conn.Stats()
|
|
if delta := newTx - tx; delta > 0 {
|
|
metrics.IncrCounterWithLabels([]string{l.metricPrefix, "tx_bytes"},
|
|
float32(newTx-tx), l.metricLabels)
|
|
}
|
|
if delta := newRx - rx; delta > 0 {
|
|
metrics.IncrCounterWithLabels([]string{l.metricPrefix, "rx_bytes"},
|
|
float32(newRx-rx), l.metricLabels)
|
|
}
|
|
tx, rx = newTx, newRx
|
|
}
|
|
// Always report final stats for the conn.
|
|
defer reportStats()
|
|
|
|
// Wait for conn to close
|
|
for {
|
|
select {
|
|
case <-connStop:
|
|
return
|
|
case <-l.stopChan:
|
|
return
|
|
case <-statsT.C:
|
|
reportStats()
|
|
}
|
|
}
|
|
}
|
|
|
|
// trackConn increments the count of active conns and returns a func() that can
|
|
// be deferred on to decrement the counter again on connection close.
|
|
func (l *Listener) trackConn() func() {
|
|
c := atomic.AddInt32(&l.activeConns, 1)
|
|
metrics.SetGaugeWithLabels([]string{l.metricPrefix, "conns"}, float32(c),
|
|
l.metricLabels)
|
|
|
|
return func() {
|
|
c := atomic.AddInt32(&l.activeConns, -1)
|
|
metrics.SetGaugeWithLabels([]string{l.metricPrefix, "conns"}, float32(c),
|
|
l.metricLabels)
|
|
}
|
|
}
|
|
|
|
// Close terminates the listener and all active connections.
|
|
func (l *Listener) Close() error {
|
|
oldFlag := atomic.SwapInt32(&l.stopFlag, 1)
|
|
if oldFlag == 0 {
|
|
close(l.stopChan)
|
|
// Wait for all conns to close
|
|
l.connWG.Wait()
|
|
|
|
if l.listener != nil {
|
|
l.listener.Close()
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Wait blocks until the listener is ready to accept connections, i.e. until
// Serve has successfully opened it and closed listeningChan. Serve only closes
// listeningChan after listenFunc succeeds, so if listening fails Wait does not
// return.
func (l *Listener) Wait() {
	<-l.listeningChan
}
|
|
|
|
// BindAddr returns the address the listener is bound to, as computed by the
// constructor with ipaddr.FormatAddressPort.
func (l *Listener) BindAddr() string {
	return l.bindAddr
}
|