337d13cc3a
Co-authored-by: Nick Cabatoff <ncabatoff@hashicorp.com>
525 lines
15 KiB
Go
525 lines
15 KiB
Go
// Copyright (c) HashiCorp, Inc.
|
|
// SPDX-License-Identifier: MPL-2.0
|
|
|
|
package cluster
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"errors"
|
|
"fmt"
|
|
"math"
|
|
"net"
|
|
"net/url"
|
|
"os"
|
|
"strconv"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/hashicorp/vault/sdk/helper/certutil"
|
|
"github.com/hashicorp/vault/sdk/helper/tlsutil"
|
|
|
|
log "github.com/hashicorp/go-hclog"
|
|
"github.com/hashicorp/vault/sdk/helper/consts"
|
|
"golang.org/x/net/http2"
|
|
)
|
|
|
|
// ListenerAcceptDeadline is the deadline applied to each Accept call made by
// the cluster accept loops. When it expires, Accept returns a timeout error
// and the loop gets a chance to re-check the shutdown flag.
const ListenerAcceptDeadline = 500 * time.Millisecond
|
|
|
|
// Client is used to lookup a client certificate. A Client is registered on a
// Listener under an ALPN name and is consulted when dialing out and when a
// peer requests a client certificate.
type Client interface {
	// ClientLookup returns the certificate to present for the given
	// certificate request. The Listener treats a nil certificate or a
	// non-nil error as "no match" and moves on to the next client.
	ClientLookup(ctx context.Context, requestInfo *tls.CertificateRequestInfo) (*tls.Certificate, error)

	// ServerName returns the TLS server name to dial with. An empty string
	// leaves the server name of the dial config untouched.
	ServerName() string

	// CACert returns the CA certificate used to verify the peer when
	// dialing. A nil result leaves the CA pools of the dial config untouched.
	CACert(ctx context.Context) *x509.Certificate
}
|
|
|
|
// Handler exposes functions for looking up TLS configuration and handing
// off a connection for a cluster listener application. A Handler is
// registered on a Listener under an ALPN name.
type Handler interface {
	// ServerLookup returns the certificate the server side should present
	// for the given client hello.
	ServerLookup(ctx context.Context, clientHello *tls.ClientHelloInfo) (*tls.Certificate, error)

	// CALookup returns the CA certificates used to populate the CA pools of
	// the per-connection TLS configuration.
	CALookup(ctx context.Context) ([]*x509.Certificate, error)

	// Handoff is used to pass the connection lifetime off to the handler.
	// The Listener supplies its shutdown wait group and a channel that is
	// closed when the accepting listener's loop exits.
	Handoff(ctx context.Context, shutdownWg *sync.WaitGroup, closeCh chan struct{}, conn *tls.Conn) error

	// Stop is called by the Listener when the handler is unregistered via
	// StopHandler.
	Stop() error
}
|
|
|
|
type ClusterHook interface {
|
|
AddClient(alpn string, client Client)
|
|
RemoveClient(alpn string)
|
|
AddHandler(alpn string, handler Handler)
|
|
StopHandler(alpn string)
|
|
TLSConfig(ctx context.Context) (*tls.Config, error)
|
|
Addr() net.Addr
|
|
GetDialerFunc(ctx context.Context, alpnProto string) func(string, time.Duration) (net.Conn, error)
|
|
}
|
|
|
|
// Listener is the source of truth for cluster handlers and connection
|
|
// clients. It dynamically builds the cluster TLS information. It's also
|
|
// responsible for starting tcp listeners and accepting new cluster connections.
|
|
type Listener struct {
|
|
handlers map[string]Handler
|
|
clients map[string]Client
|
|
shutdown *uint32
|
|
shutdownWg *sync.WaitGroup
|
|
server *http2.Server
|
|
|
|
networkLayer NetworkLayer
|
|
cipherSuites []uint16
|
|
advertise net.Addr
|
|
logger log.Logger
|
|
l sync.RWMutex
|
|
tlsConnectionLoggingLevel log.Level
|
|
}
|
|
|
|
func NewListener(networkLayer NetworkLayer, cipherSuites []uint16, logger log.Logger, idleTimeout time.Duration) *Listener {
|
|
var maxStreams uint32 = math.MaxUint32
|
|
if override := os.Getenv("VAULT_GRPC_MAX_STREAMS"); override != "" {
|
|
i, err := strconv.ParseUint(override, 10, 32)
|
|
if err != nil {
|
|
logger.Warn("vault grpc max streams override must be an uint32 integer", "value", override)
|
|
} else {
|
|
maxStreams = uint32(i)
|
|
logger.Info("overriding grpc max streams", "value", i)
|
|
}
|
|
}
|
|
|
|
// Create the HTTP/2 server that will be shared by both RPC and regular
|
|
// duties. Doing it this way instead of listening via the server and gRPC
|
|
// allows us to re-use the same port via ALPN. We can just tell the server
|
|
// to serve a given conn and which handler to use.
|
|
h2Server := &http2.Server{
|
|
// Our forwarding connections heartbeat regularly so anything else we
|
|
// want to go away/get cleaned up pretty rapidly
|
|
IdleTimeout: idleTimeout,
|
|
|
|
// By default this is 250 which can be too small on high traffic
|
|
// clusters with many forwarded or replication gRPC connections.
|
|
MaxConcurrentStreams: maxStreams,
|
|
}
|
|
|
|
return &Listener{
|
|
handlers: make(map[string]Handler),
|
|
clients: make(map[string]Client),
|
|
shutdown: new(uint32),
|
|
shutdownWg: &sync.WaitGroup{},
|
|
server: h2Server,
|
|
|
|
networkLayer: networkLayer,
|
|
cipherSuites: cipherSuites,
|
|
logger: logger,
|
|
tlsConnectionLoggingLevel: log.LevelFromString(os.Getenv("VAULT_CLUSTER_TLS_SESSION_LOG_LEVEL")),
|
|
}
|
|
}
|
|
|
|
func (cl *Listener) SetAdvertiseAddr(addr string) error {
|
|
u, err := url.ParseRequestURI(addr)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to parse advertise address: %w", err)
|
|
}
|
|
cl.advertise = &NetAddr{
|
|
Host: u.Host,
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (cl *Listener) Addr() net.Addr {
|
|
if cl.advertise != nil {
|
|
return cl.advertise
|
|
}
|
|
|
|
addrs := cl.Addrs()
|
|
if len(addrs) == 0 {
|
|
return nil
|
|
}
|
|
return addrs[0]
|
|
}
|
|
|
|
func (cl *Listener) Addrs() []net.Addr {
|
|
return cl.networkLayer.Addrs()
|
|
}
|
|
|
|
// AddClient adds a new client for an ALPN name
|
|
func (cl *Listener) AddClient(alpn string, client Client) {
|
|
cl.l.Lock()
|
|
cl.clients[alpn] = client
|
|
cl.l.Unlock()
|
|
}
|
|
|
|
// RemoveClient removes the client for the specified ALPN name
|
|
func (cl *Listener) RemoveClient(alpn string) {
|
|
cl.l.Lock()
|
|
delete(cl.clients, alpn)
|
|
cl.l.Unlock()
|
|
}
|
|
|
|
// AddHandler registers a new cluster handler for the provided ALPN name.
|
|
func (cl *Listener) AddHandler(alpn string, handler Handler) {
|
|
cl.l.Lock()
|
|
cl.handlers[alpn] = handler
|
|
cl.l.Unlock()
|
|
}
|
|
|
|
// StopHandler stops the cluster handler for the provided ALPN name, it also
|
|
// calls stop on the handler.
|
|
func (cl *Listener) StopHandler(alpn string) {
|
|
cl.l.Lock()
|
|
handler, ok := cl.handlers[alpn]
|
|
delete(cl.handlers, alpn)
|
|
cl.l.Unlock()
|
|
if ok {
|
|
handler.Stop()
|
|
}
|
|
}
|
|
|
|
// Handler returns the handler for the provided ALPN name
|
|
func (cl *Listener) Handler(alpn string) (Handler, bool) {
|
|
cl.l.RLock()
|
|
handler, ok := cl.handlers[alpn]
|
|
cl.l.RUnlock()
|
|
return handler, ok
|
|
}
|
|
|
|
// Server returns the http2 server that the cluster listener is using
|
|
func (cl *Listener) Server() *http2.Server {
|
|
return cl.server
|
|
}
|
|
|
|
// TLSConfig returns a tls config object that uses dynamic lookups to correctly
|
|
// authenticate registered handlers/clients
|
|
func (cl *Listener) TLSConfig(ctx context.Context) (*tls.Config, error) {
|
|
serverLookup := func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
|
cl.logger.Debug("performing server cert lookup")
|
|
|
|
cl.l.RLock()
|
|
defer cl.l.RUnlock()
|
|
for _, v := range clientHello.SupportedProtos {
|
|
if handler, ok := cl.handlers[v]; ok {
|
|
return handler.ServerLookup(ctx, clientHello)
|
|
}
|
|
}
|
|
|
|
cl.logger.Warn("no TLS certs found for ALPN", "ALPN", clientHello.SupportedProtos)
|
|
return nil, errors.New("unsupported protocol")
|
|
}
|
|
|
|
clientLookup := func(requestInfo *tls.CertificateRequestInfo) (*tls.Certificate, error) {
|
|
cl.logger.Debug("performing client cert lookup")
|
|
|
|
cl.l.RLock()
|
|
defer cl.l.RUnlock()
|
|
for _, client := range cl.clients {
|
|
cert, err := client.ClientLookup(ctx, requestInfo)
|
|
if err == nil && cert != nil {
|
|
return cert, nil
|
|
}
|
|
}
|
|
|
|
cl.logger.Warn("no client information found")
|
|
return nil, errors.New("no client cert found")
|
|
}
|
|
|
|
serverConfigLookup := func(clientHello *tls.ClientHelloInfo) (*tls.Config, error) {
|
|
caPool := x509.NewCertPool()
|
|
|
|
ret := &tls.Config{
|
|
ClientAuth: tls.RequireAndVerifyClientCert,
|
|
GetCertificate: serverLookup,
|
|
GetClientCertificate: clientLookup,
|
|
MinVersion: tls.VersionTLS12,
|
|
RootCAs: caPool,
|
|
ClientCAs: caPool,
|
|
NextProtos: clientHello.SupportedProtos,
|
|
CipherSuites: cl.cipherSuites,
|
|
}
|
|
|
|
cl.l.RLock()
|
|
defer cl.l.RUnlock()
|
|
for _, v := range clientHello.SupportedProtos {
|
|
if handler, ok := cl.handlers[v]; ok {
|
|
caList, err := handler.CALookup(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
for _, ca := range caList {
|
|
caPool.AddCert(ca)
|
|
}
|
|
return ret, nil
|
|
}
|
|
}
|
|
|
|
cl.logger.Warn("no TLS config found for ALPN", "ALPN", clientHello.SupportedProtos)
|
|
return nil, errors.New("unsupported protocol")
|
|
}
|
|
|
|
return &tls.Config{
|
|
ClientAuth: tls.RequireAndVerifyClientCert,
|
|
GetCertificate: serverLookup,
|
|
GetClientCertificate: clientLookup,
|
|
GetConfigForClient: serverConfigLookup,
|
|
MinVersion: tls.VersionTLS12,
|
|
CipherSuites: cl.cipherSuites,
|
|
}, nil
|
|
}
|
|
|
|
// Run starts the tcp listeners and will accept connections until stop is
|
|
// called. This function blocks so should be called in a goroutine.
|
|
func (cl *Listener) Run(ctx context.Context) error {
|
|
// Get our TLS config
|
|
tlsConfig, err := cl.TLSConfig(ctx)
|
|
if err != nil {
|
|
cl.logger.Error("failed to get tls configuration when starting cluster listener", "error", err)
|
|
return err
|
|
}
|
|
|
|
// The server supports all of the possible protos
|
|
tlsConfig.NextProtos = []string{"h2", consts.RequestForwardingALPN, consts.PerfStandbyALPN, consts.PerformanceReplicationALPN, consts.DRReplicationALPN}
|
|
|
|
for _, ln := range cl.networkLayer.Listeners() {
|
|
// closeCh is used to shutdown the spawned goroutines once this
|
|
// function returns
|
|
closeCh := make(chan struct{})
|
|
|
|
localLn := ln
|
|
|
|
// Wrap the listener with TLS
|
|
tlsLn := tls.NewListener(localLn, tlsConfig)
|
|
|
|
if cl.logger.IsInfo() {
|
|
cl.logger.Info("serving cluster requests", "cluster_listen_address", tlsLn.Addr())
|
|
}
|
|
|
|
cl.shutdownWg.Add(1)
|
|
// Start our listening loop
|
|
go func(closeCh chan struct{}, tlsLn net.Listener) {
|
|
defer func() {
|
|
cl.shutdownWg.Done()
|
|
tlsLn.Close()
|
|
close(closeCh)
|
|
}()
|
|
|
|
// baseDelay is the initial delay after an Accept() error before attempting again
|
|
const baseDelay = 5 * time.Millisecond
|
|
|
|
// maxDelay is the maximum delay after an Accept() error before attempting again.
|
|
// In the case that this function is error-looping, it will delay the shutdown check.
|
|
// Therefore, changes to maxDelay may have an effect on the latency of shutdown.
|
|
const maxDelay = 1 * time.Second
|
|
|
|
var loopDelay time.Duration
|
|
for {
|
|
if atomic.LoadUint32(cl.shutdown) > 0 {
|
|
return
|
|
}
|
|
|
|
// Set the deadline for the accept call. If it passes we'll get
|
|
// an error, causing us to check the condition at the top
|
|
// again.
|
|
localLn.SetDeadline(time.Now().Add(ListenerAcceptDeadline))
|
|
|
|
// Accept the connection
|
|
conn, err := tlsLn.Accept()
|
|
if err != nil {
|
|
err, ok := err.(net.Error)
|
|
if ok && !err.Timeout() {
|
|
cl.logger.Debug("non-timeout error accepting on cluster port", "error", err)
|
|
}
|
|
if conn != nil {
|
|
conn.Close()
|
|
}
|
|
if ok && err.Timeout() {
|
|
loopDelay = 0
|
|
continue
|
|
}
|
|
|
|
if loopDelay == 0 {
|
|
loopDelay = baseDelay
|
|
} else {
|
|
loopDelay *= 2
|
|
}
|
|
|
|
if loopDelay > maxDelay {
|
|
loopDelay = maxDelay
|
|
}
|
|
|
|
time.Sleep(loopDelay)
|
|
continue
|
|
}
|
|
// No error, reset loop delay
|
|
loopDelay = 0
|
|
|
|
if conn == nil {
|
|
continue
|
|
}
|
|
|
|
// Type assert to TLS connection and handshake to populate the
|
|
// connection state
|
|
tlsConn := conn.(*tls.Conn)
|
|
|
|
// Set a deadline for the handshake. This will cause clients
|
|
// that don't successfully auth to be kicked out quickly.
|
|
// Cluster connections should be reliable so being marginally
|
|
// aggressive here is fine.
|
|
err = tlsConn.SetDeadline(time.Now().Add(30 * time.Second))
|
|
if err != nil {
|
|
if cl.logger.IsDebug() {
|
|
cl.logger.Debug("error setting deadline for cluster connection", "error", err)
|
|
}
|
|
tlsConn.Close()
|
|
continue
|
|
}
|
|
|
|
err = tlsConn.Handshake()
|
|
if err != nil {
|
|
if cl.logger.IsDebug() {
|
|
cl.logger.Debug("error handshaking cluster connection", "error", err)
|
|
}
|
|
tlsConn.Close()
|
|
continue
|
|
}
|
|
|
|
cl.logTLSSessionStart(tlsConn.RemoteAddr().String(), tlsConn.ConnectionState())
|
|
|
|
// Now, set it back to unlimited
|
|
err = tlsConn.SetDeadline(time.Time{})
|
|
if err != nil {
|
|
if cl.logger.IsDebug() {
|
|
cl.logger.Debug("error setting deadline for cluster connection", "error", err)
|
|
}
|
|
tlsConn.Close()
|
|
continue
|
|
}
|
|
|
|
cl.l.RLock()
|
|
handler, ok := cl.handlers[tlsConn.ConnectionState().NegotiatedProtocol]
|
|
cl.l.RUnlock()
|
|
if !ok {
|
|
cl.logger.Debug("unknown negotiated protocol on cluster port")
|
|
tlsConn.Close()
|
|
continue
|
|
}
|
|
|
|
if err := handler.Handoff(ctx, cl.shutdownWg, closeCh, tlsConn); err != nil {
|
|
cl.logger.Error("error handling cluster connection", "error", err)
|
|
continue
|
|
}
|
|
}
|
|
}(closeCh, tlsLn)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Stop stops the cluster listener
|
|
func (cl *Listener) Stop() {
|
|
// Set the shutdown flag. This will cause the listeners to shut down
|
|
// within the deadline in clusterListenerAcceptDeadline
|
|
atomic.StoreUint32(cl.shutdown, 1)
|
|
cl.logger.Info("forwarding rpc listeners stopped")
|
|
|
|
// Wait for them all to shut down
|
|
cl.shutdownWg.Wait()
|
|
cl.logger.Info("rpc listeners successfully shut down")
|
|
}
|
|
|
|
// GetDialerFunc returns a function that looks up the TLS information for the
|
|
// provided alpn name and calls the network layer's dial function.
|
|
func (cl *Listener) GetDialerFunc(ctx context.Context, alpn string) func(string, time.Duration) (net.Conn, error) {
|
|
return func(addr string, timeout time.Duration) (net.Conn, error) {
|
|
tlsConfig, err := cl.TLSConfig(ctx)
|
|
if err != nil {
|
|
cl.logger.Error("failed to get tls configuration", "error", err)
|
|
return nil, err
|
|
}
|
|
|
|
if tlsConfig == nil {
|
|
return nil, errors.New("no tls config found")
|
|
}
|
|
|
|
cl.l.RLock()
|
|
client, ok := cl.clients[alpn]
|
|
cl.l.RUnlock()
|
|
if !ok {
|
|
return nil, fmt.Errorf("no client configured for alpn: %q", alpn)
|
|
}
|
|
|
|
serverName := client.ServerName()
|
|
if serverName != "" {
|
|
tlsConfig.ServerName = serverName
|
|
}
|
|
|
|
caCert := client.CACert(ctx)
|
|
if caCert != nil {
|
|
pool := x509.NewCertPool()
|
|
pool.AddCert(caCert)
|
|
tlsConfig.RootCAs = pool
|
|
tlsConfig.ClientCAs = pool
|
|
}
|
|
|
|
tlsConfig.NextProtos = []string{alpn}
|
|
cl.logger.Debug("creating rpc dialer", "address", addr, "alpn", alpn, "host", tlsConfig.ServerName)
|
|
|
|
conn, err := cl.networkLayer.Dial(addr, timeout, tlsConfig)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
cl.logTLSSessionStart(conn.RemoteAddr().String(), conn.ConnectionState())
|
|
return conn, nil
|
|
}
|
|
}
|
|
|
|
func (cl *Listener) logTLSSessionStart(peerAddress string, state tls.ConnectionState) {
|
|
if cl.tlsConnectionLoggingLevel != log.NoLevel {
|
|
cipherName, _ := tlsutil.GetCipherName(state.CipherSuite)
|
|
cl.logger.Log(cl.tlsConnectionLoggingLevel, "TLS connection established", "peer", peerAddress, "negotiated_protocol", state.NegotiatedProtocol, "cipher_suite", cipherName)
|
|
for _, chain := range state.VerifiedChains {
|
|
for _, cert := range chain {
|
|
cl.logger.Log(cl.tlsConnectionLoggingLevel, "Peer certificate", "is_ca", cert.IsCA, "serial_number", cert.SerialNumber.String(), "subject", cert.Subject.String(),
|
|
"signature_algorithm", cert.SignatureAlgorithm.String(), "public_key_algorithm", cert.PublicKeyAlgorithm.String(), "public_key_size", certutil.GetPublicKeySize(cert.PublicKey))
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// NetworkListener is used by the network layer to define a net.Listener for use
// in the cluster listener. On top of net.Listener it must support deadlines,
// which the accept loop uses to bound each Accept call.
type NetworkListener interface {
	net.Listener

	// SetDeadline sets the deadline applied to the listener.
	SetDeadline(deadline time.Time) error
}
|
|
|
|
// NetworkLayer is the network abstraction used in the cluster listener.
|
|
// Abstracting the network layer out allows us to swap the underlying
|
|
// implementations for tests.
|
|
type NetworkLayer interface {
|
|
Addrs() []net.Addr
|
|
Listeners() []NetworkListener
|
|
Dial(address string, timeout time.Duration, tlsConfig *tls.Config) (*tls.Conn, error)
|
|
Close() error
|
|
}
|
|
|
|
// NetworkLayerSet is used for returning a slice of layers to a caller.
|
|
type NetworkLayerSet interface {
|
|
Layers() []NetworkLayer
|
|
}
|
|
|
|
// NetAddr is a minimal address value backed by a plain host string. Its
// String and Network methods let a *NetAddr be used wherever a net.Addr is
// expected, as the cluster listener does for its advertise address.
type NetAddr struct {
	// Host is returned verbatim by String.
	Host string
}

// String returns the stored host.
func (a *NetAddr) String() string {
	return a.Host
}

// Network always reports "tcp".
func (a *NetAddr) Network() string {
	return "tcp"
}
|