open-consul/connect/proxy/listener.go

296 lines
8.4 KiB
Go

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package proxy
import (
"context"
"crypto/tls"
"errors"
"net"
"sync"
"sync/atomic"
"time"
metrics "github.com/armon/go-metrics"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/consul/api"
"github.com/hashicorp/consul/connect"
"github.com/hashicorp/consul/ipaddr"
)
// publicListenerPrefix names the logger and prefixes the metric keys of the
// public (inbound mTLS) listener.
const publicListenerPrefix = "inbound"

// upstreamListenerPrefix names the logger and prefixes the metric keys of
// local upstream listeners.
const upstreamListenerPrefix = "upstream"
// Listener is the implementation of a specific proxy listener. It has pluggable
// Listen and Dial methods to suit public mTLS vs upstream semantics. It handles
// the lifecycle of the listener and all connections opened through it.
//
// A Listener is built by NewPublicListener or NewUpstreamListener, run with
// Serve, and shut down with Close.
type Listener struct {
// Service is the connect service instance to use.
Service *connect.Service
// listenFunc, dialFunc, and bindAddr are set by type-specific constructors.
// listenFunc opens the accepting socket, dialFunc opens the connection each
// accepted conn is proxied to, and bindAddr is the configured host:port.
listenFunc func() (net.Listener, error)
dialFunc func() (net.Conn, error)
bindAddr string
// stopFlag is atomically swapped to 1 by the first Close call. Once set,
// Serve refuses to start and treats Accept errors as a clean shutdown.
stopFlag int32
// stopChan is closed by Close to tell running handleConn goroutines to
// return (which closes their proxied connections).
stopChan chan struct{}
// listeningChan is closed when listener is opened successfully. It's really
// only for use in tests where we need to coordinate wait for the Serve
// goroutine to be running before we proceed trying to connect. On my laptop
// this always works out anyway but on constrained VMs and especially docker
// containers (e.g. in CI) we often see the Dial routine win the race and get
// `connection refused`. Retry loops and sleeps are unpleasant workarounds and
// this is cheap and correct.
listeningChan chan struct{}
// listenerLock guards access to the listener field
listenerLock sync.Mutex
// listener is the socket opened by listenFunc; nil until Serve sets it.
listener net.Listener
// logger is named after the listener kind ("inbound" or "upstream").
logger hclog.Logger
// Gauge to track current open connections. Updated atomically by trackConn
// and published as the "<metricPrefix>.conns" gauge.
activeConns int32
// connWG counts running handleConn goroutines so Close can wait for them.
connWG sync.WaitGroup
// metricPrefix is the first element of every metric key this listener
// emits; metricLabels are attached to every one of those metrics.
metricPrefix string
metricLabels []metrics.Label
}
// NewPublicListener returns a Listener setup to listen for public mTLS
// connections and proxy them to the configured local application over TCP.
//
// The Listener binds to cfg.BindAddress:cfg.BindPort using the service's
// server TLS config, and for every accepted connection dials
// cfg.LocalServiceAddress with a timeout of cfg.LocalConnectTimeoutMs.
func NewPublicListener(svc *connect.Service, cfg PublicListenerConfig,
	logger hclog.Logger) *Listener {
	addr := ipaddr.FormatAddressPort(cfg.BindAddress, cfg.BindPort)
	dialTimeout := time.Duration(cfg.LocalConnectTimeoutMs) * time.Millisecond

	// The TLS config is requested from svc each time the socket is opened,
	// not captured once at construction.
	listen := func() (net.Listener, error) {
		return tls.Listen("tcp", addr, svc.ServerTLSConfig())
	}
	dial := func() (net.Conn, error) {
		return net.DialTimeout("tcp", cfg.LocalServiceAddress, dialTimeout)
	}

	// For now we only label ourselves as the destination. We could fetch the
	// src service from the cert on each connection and label metrics
	// differently, but that significantly complicates the active connection
	// tracking here and it's not clear it's very valuable: on aggregate,
	// looking at all _outbound_ connections across all proxies gives a full
	// picture of src->dst traffic. We might expand this later for better
	// debugging of which clients are abusing a particular service instance,
	// weighed against the extra complication of tracking many gauges here.
	labels := []metrics.Label{{Name: "dst", Value: svc.Name()}}

	return &Listener{
		Service:       svc,
		listenFunc:    listen,
		dialFunc:      dial,
		bindAddr:      addr,
		stopChan:      make(chan struct{}),
		listeningChan: make(chan struct{}),
		logger:        logger.Named(publicListenerPrefix),
		metricPrefix:  publicListenerPrefix,
		metricLabels:  labels,
	}
}
// NewUpstreamListener returns a Listener setup to listen locally for TCP
// connections that are proxied to a discovered Connect service instance.
//
// Discovery goes through a resolver built from client for each dial; see
// newUpstreamListenerWithResolver for the listener construction itself.
func NewUpstreamListener(svc *connect.Service, client *api.Client,
	cfg UpstreamConfig, logger hclog.Logger) *Listener {
	resolverFromClient := UpstreamResolverFuncFromClient(client)
	return newUpstreamListenerWithResolver(svc, cfg, resolverFromClient, logger)
}
// newUpstreamListenerWithResolver builds an upstream Listener that listens on
// plain TCP at cfg.LocalBindAddress:cfg.LocalBindPort and, for every accepted
// connection, obtains a resolver from resolverFunc and dials the upstream via
// svc.Dial bounded by cfg.ConnectTimeout(). Taking resolverFunc as a parameter
// lets callers substitute how upstream instances are resolved.
func newUpstreamListenerWithResolver(svc *connect.Service, cfg UpstreamConfig,
	resolverFunc func(UpstreamConfig) (connect.Resolver, error),
	logger hclog.Logger) *Listener {
	addr := ipaddr.FormatAddressPort(cfg.LocalBindAddress, cfg.LocalBindPort)

	listen := func() (net.Listener, error) {
		return net.Listen("tcp", addr)
	}

	// A resolver and a fresh timeout context are created per dial.
	dial := func() (net.Conn, error) {
		resolver, resolveErr := resolverFunc(cfg)
		if resolveErr != nil {
			return nil, resolveErr
		}
		dialCtx, cancelDial := context.WithTimeout(context.Background(),
			cfg.ConnectTimeout())
		defer cancelDial()
		return svc.Dial(dialCtx, resolver)
	}

	labels := []metrics.Label{
		{Name: "src", Value: svc.Name()},
		// TODO(banks): namespace support
		{Name: "dst_type", Value: string(cfg.DestinationType)},
		{Name: "dst", Value: cfg.DestinationName},
	}

	return &Listener{
		Service:       svc,
		listenFunc:    listen,
		dialFunc:      dial,
		bindAddr:      addr,
		stopChan:      make(chan struct{}),
		listeningChan: make(chan struct{}),
		logger:        logger.Named(upstreamListenerPrefix),
		metricPrefix:  upstreamListenerPrefix,
		metricLabels:  labels,
	}
}
// Serve runs the listener until it is stopped. It is an error to call Serve
// more than once for any given Listener instance.
//
// Serve opens the socket with listenFunc, closes listeningChan to signal
// readiness, and then accepts connections, handing each to handleConn on its
// own goroutine. It returns nil when the Listener is shut down via Close, the
// error from listenFunc if the socket cannot be opened, or the Accept error
// if accepting fails for any other reason.
func (l *Listener) Serve() error {
	// Ensure we mark state closed if we fail before Close is called externally.
	defer l.Close()

	if atomic.LoadInt32(&l.stopFlag) != 0 {
		return errors.New("serve called on a closed listener")
	}

	listener, err := l.listenFunc()
	if err != nil {
		return err
	}
	l.setListener(listener)
	close(l.listeningChan)

	// Close may have run before setListener published the listener, in which
	// case it found nothing to close and the socket we just opened would keep
	// accepting forever. setListener happens before this load, and Close sets
	// stopFlag before calling getListener, so at least one side observes the
	// other and closes the listener.
	if atomic.LoadInt32(&l.stopFlag) != 0 {
		_ = listener.Close() // best-effort: we are shutting down anyway
		return nil
	}

	for {
		conn, err := listener.Accept()
		if err != nil {
			if atomic.LoadInt32(&l.stopFlag) == 1 {
				return nil
			}
			return err
		}
		// Once Close has started it may already be waiting on connWG, so
		// don't register new work. This narrows, but cannot fully eliminate,
		// the window in which connWG.Add could race with connWG.Wait.
		if atomic.LoadInt32(&l.stopFlag) != 0 {
			_ = conn.Close()
			return nil
		}
		l.connWG.Add(1)
		go l.handleConn(conn)
	}
}
// handleConn is the internal connection handler goroutine.
//
// It dials the destination with dialFunc and proxies bytes between src and
// the dialed connection. It returns when the copy finishes or the Listener is
// closed. While running, it publishes any new tx/rx bytes as counters every 5
// seconds and once more on exit. Serve calls connWG.Add(1) before starting it;
// handleConn always calls connWG.Done on return so Close can wait for it.
func (l *Listener) handleConn(src net.Conn) {
	defer func() {
		// Make sure Listener.Close waits for this conn to be cleaned up.
		src.Close()
		l.connWG.Done()
	}()

	dst, err := l.dialFunc()
	if err != nil {
		l.logger.Error("failed to dial", "error", err)
		return
	}

	// Track active conn now (first function call) and defer un-counting it when
	// it closes.
	defer l.trackConn()()

	// Note no need to defer dst.Close() since conn handles that for us.
	conn := NewConn(src, dst)
	defer conn.Close()

	// Copy bytes on another goroutine; connStop is closed when it finishes.
	// The goroutine declares its own err rather than assigning to the outer
	// one, so no variable is shared between the two goroutines.
	connStop := make(chan struct{})
	go func() {
		defer close(connStop)
		if err := conn.CopyBytes(); err != nil {
			l.logger.Error("connection failed", "error", err)
		}
	}()

	// Periodically copy stats from conn to metrics (to keep metrics calls out of
	// the path of every single packet copy). 5 seconds is probably good enough
	// resolution - statsd and most others tend to summarize with lower resolution
	// anyway and this amortizes the cost more.
	var tx, rx uint64
	statsT := time.NewTicker(5 * time.Second)
	defer statsT.Stop()
	reportStats := func() {
		newTx, newRx := conn.Stats()
		if delta := newTx - tx; delta > 0 {
			metrics.IncrCounterWithLabels([]string{l.metricPrefix, "tx_bytes"},
				float32(delta), l.metricLabels)
		}
		if delta := newRx - rx; delta > 0 {
			metrics.IncrCounterWithLabels([]string{l.metricPrefix, "rx_bytes"},
				float32(delta), l.metricLabels)
		}
		tx, rx = newTx, newRx
	}
	// Always report final stats for the conn.
	defer reportStats()

	// Wait for the copy to finish or the Listener to close, reporting stats
	// on every tick in the meantime.
	for {
		select {
		case <-connStop:
			return
		case <-l.stopChan:
			return
		case <-statsT.C:
			reportStats()
		}
	}
}
// trackConn records one more active connection and returns a closure that
// records its departure; call the closure (typically via defer) when the
// connection ends. Each change atomically adjusts activeConns and publishes
// the new total as the "<metricPrefix>.conns" gauge.
func (l *Listener) trackConn() func() {
	publish := func(step int32) {
		open := atomic.AddInt32(&l.activeConns, step)
		metrics.SetGaugeWithLabels([]string{l.metricPrefix, "conns"},
			float32(open), l.metricLabels)
	}
	publish(1)
	return func() { publish(-1) }
}
// Close terminates the listener and all active connections.
//
// Only the first call does any work; later calls return nil immediately.
// Close stops accepting, signals every connection handler to stop, and
// blocks until they have all returned. It always returns nil.
func (l *Listener) Close() error {
	// Flipping the flag first also prevents a later Serve from starting.
	if alreadyStopped := atomic.SwapInt32(&l.stopFlag, 1) != 0; alreadyStopped {
		return nil
	}

	// Unblock Accept in Serve so no new connections are taken.
	if ln := l.getListener(); ln != nil {
		ln.Close()
	}

	// Tell in-flight handlers to tear down, then wait for every one of them.
	close(l.stopChan)
	l.connWG.Wait()

	return nil
}
// Wait for the listener to be ready to accept connections.
//
// Wait blocks until listeningChan is closed, which Serve does only after
// listenFunc has successfully opened the socket; nothing else in this file
// closes it. If Serve is never called, or listenFunc fails, Wait blocks
// forever, so it is intended mainly for coordinating tests.
func (l *Listener) Wait() {
<-l.listeningChan
}
// BindAddr returns the address the listener is configured to bind to: the
// host:port string the constructor formatted from its config. It is never
// updated after construction, so it does not reflect the resolved address of
// the open socket (for example, a port chosen by the OS is not reported).
func (l *Listener) BindAddr() string {
return l.bindAddr
}
// setListener publishes the open socket under listenerLock so that Close,
// which reads it via getListener, can shut it down.
func (l *Listener) setListener(listener net.Listener) {
	l.listenerLock.Lock()
	defer l.listenerLock.Unlock()
	l.listener = listener
}
// getListener returns the socket stored by setListener, or nil if Serve has
// not opened one yet. The read is done under listenerLock.
func (l *Listener) getListener() net.Listener {
	l.listenerLock.Lock()
	current := l.listener
	l.listenerLock.Unlock()
	return current
}