open-consul/connect/proxy/listener.go

153 lines
4.2 KiB
Go
Raw Normal View History

package proxy
import (
"context"
"crypto/tls"
"errors"
"fmt"
"log"
"net"
"sync/atomic"
"time"
"github.com/hashicorp/consul/connect"
)
// Listener is the implementation of a specific proxy listener. It has pluggable
// Listen and Dial methods to suit public mTLS vs upstream semantics. It handles
// the lifecycle of the listener and all connections opened through it
type Listener struct {
// Service is the connect service instance to use.
Service *connect.Service
// listenFunc, dialFunc and bindAddr are set by type-specific constructors
listenFunc func() (net.Listener, error)
dialFunc func() (net.Conn, error)
bindAddr string
stopFlag int32
stopChan chan struct{}
// listeningChan is closed when listener is opened successfully. It's really
// only for use in tests where we need to coordinate wait for the Serve
// goroutine to be running before we proceed trying to connect. On my laptop
// this always works out anyway but on constrained VMs and especially docker
// containers (e.g. in CI) we often see the Dial routine win the race and get
// `connection refused`. Retry loops and sleeps are unpleasant workarounds and
// this is cheap and correct.
listeningChan chan struct{}
logger *log.Logger
}
// NewPublicListener returns a Listener setup to listen for public mTLS
// connections and proxy them to the configured local application over TCP.
func NewPublicListener(svc *connect.Service, cfg PublicListenerConfig,
logger *log.Logger) *Listener {
bindAddr := fmt.Sprintf("%s:%d", cfg.BindAddress, cfg.BindPort)
return &Listener{
Service: svc,
listenFunc: func() (net.Listener, error) {
return tls.Listen("tcp", bindAddr, svc.ServerTLSConfig())
},
dialFunc: func() (net.Conn, error) {
return net.DialTimeout("tcp", cfg.LocalServiceAddress,
time.Duration(cfg.LocalConnectTimeoutMs)*time.Millisecond)
},
bindAddr: bindAddr,
stopChan: make(chan struct{}),
listeningChan: make(chan struct{}),
logger: logger,
}
}
// NewUpstreamListener returns a Listener setup to listen locally for TCP
// connections that are proxied to a discovered Connect service instance.
func NewUpstreamListener(svc *connect.Service, cfg UpstreamConfig,
logger *log.Logger) *Listener {
bindAddr := fmt.Sprintf("%s:%d", cfg.LocalBindAddress, cfg.LocalBindPort)
return &Listener{
Service: svc,
listenFunc: func() (net.Listener, error) {
return net.Listen("tcp", bindAddr)
},
dialFunc: func() (net.Conn, error) {
if cfg.resolver == nil {
return nil, errors.New("no resolver provided")
}
ctx, cancel := context.WithTimeout(context.Background(),
time.Duration(cfg.ConnectTimeoutMs)*time.Millisecond)
defer cancel()
return svc.Dial(ctx, cfg.resolver)
},
bindAddr: bindAddr,
stopChan: make(chan struct{}),
listeningChan: make(chan struct{}),
logger: logger,
}
}
// Serve runs the listener until it is stopped. It is an error to call Serve
// more than once for any given Listener instance.
func (l *Listener) Serve() error {
// Ensure we mark state closed if we fail before Close is called externally.
defer l.Close()
if atomic.LoadInt32(&l.stopFlag) != 0 {
return errors.New("serve called on a closed listener")
}
listen, err := l.listenFunc()
if err != nil {
return err
}
close(l.listeningChan)
for {
conn, err := listen.Accept()
if err != nil {
if atomic.LoadInt32(&l.stopFlag) == 1 {
return nil
}
return err
}
go l.handleConn(conn)
}
}
// handleConn is the internal connection handler goroutine.
func (l *Listener) handleConn(src net.Conn) {
defer src.Close()
dst, err := l.dialFunc()
if err != nil {
l.logger.Printf("[ERR] failed to dial: %s", err)
return
}
// Note no need to defer dst.Close() since conn handles that for us.
conn := NewConn(src, dst)
defer conn.Close()
err = conn.CopyBytes()
if err != nil {
l.logger.Printf("[ERR] connection failed: %s", err)
return
}
}
// Close terminates the listener and all active connections.
func (l *Listener) Close() error {
return nil
}
// Wait for the listener to be ready to accept connections.
func (l *Listener) Wait() {
<-l.listeningChan
}
// BindAddr returns the address the listen is bound to.
func (l *Listener) BindAddr() string {
return l.bindAddr
}