Create network layer abstraction to allow in-memory cluster traffic (#8173)

Brian Kassouf 2020-01-16 23:03:02 -08:00 committed by GitHub
parent 3956072c93
commit f32a86ee7a
11 changed files with 870 additions and 118 deletions
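
The upshot, before the per-file diffs: cluster traffic now flows through a pluggable NetworkLayer, so tests can swap the default TCP layer for an in-memory one. A minimal sketch of the new test wiring, mirroring the inmemLayer subtest added below (the test name here is illustrative, not part of the commit):

package vault

import (
	"testing"

	"github.com/hashicorp/vault/vault/cluster"
)

func TestForwardRequests_InmemSketch(t *testing.T) {
	// Build three in-memory layers, pre-connected to each other.
	inmemCluster, err := cluster.NewInmemLayerCluster(3, nil)
	if err != nil {
		t.Fatal(err)
	}
	// ClusterLayers hands one layer to each core in place of TCP.
	testCluster_ForwardRequestsCommon(t, &TestClusterOptions{
		ClusterLayers: inmemCluster,
	})
}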

View File

@ -503,14 +503,8 @@ func (b *RaftBackend) SetupCluster(ctx context.Context, opts SetupOpts) error {
case opts.ClusterListener == nil:
return errors.New("no cluster listener provided")
default:
- // Load the base TLS config from the cluster listener.
- baseTLSConfig, err := opts.ClusterListener.TLSConfig(ctx)
- if err != nil {
- return err
- }
// Set the local address and localID in the streaming layer and the raft config.
- streamLayer, err := NewRaftLayer(b.logger.Named("stream"), opts.TLSKeyring, opts.ClusterListener.Addr(), baseTLSConfig)
+ streamLayer, err := NewRaftLayer(b.logger.Named("stream"), opts.TLSKeyring, opts.ClusterListener)
if err != nil {
return err
}

View File

@ -110,7 +110,7 @@ func GenerateTLSKey(reader io.Reader) (*TLSKey, error) {
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement | x509.KeyUsageCertSign,
SerialNumber: big.NewInt(mathrand.Int63()),
NotBefore: time.Now().Add(-30 * time.Second),
- // 30 years of single-active uptime ought to be enough for anybody
+ // 30 years ought to be enough for anybody
NotAfter: time.Now().Add(262980 * time.Hour),
BasicConstraintsValid: true,
IsCA: true,
@ -162,13 +162,14 @@ type raftLayer struct {
dialerFunc func(string, time.Duration) (net.Conn, error)
// TLS config
- keyring *TLSKeyring
- baseTLSConfig *tls.Config
+ keyring *TLSKeyring
+ clusterListener cluster.ClusterHook
}
// NewRaftLayer creates a new raftLayer object. It parses the TLS information
// from the network config.
- func NewRaftLayer(logger log.Logger, raftTLSKeyring *TLSKeyring, clusterAddr net.Addr, baseTLSConfig *tls.Config) (*raftLayer, error) {
+ func NewRaftLayer(logger log.Logger, raftTLSKeyring *TLSKeyring, clusterListener cluster.ClusterHook) (*raftLayer, error) {
+ clusterAddr := clusterListener.Addr()
switch {
case clusterAddr == nil:
// Clustering disabled on the server, don't try to look for params
@ -176,11 +177,11 @@ func NewRaftLayer(logger log.Logger, raftTLSKeyring *TLSKeyring, clusterAddr net
}
layer := &raftLayer{
- addr: clusterAddr,
- connCh: make(chan net.Conn),
- closeCh: make(chan struct{}),
- logger: logger,
- baseTLSConfig: baseTLSConfig,
+ addr: clusterAddr,
+ connCh: make(chan net.Conn),
+ closeCh: make(chan struct{}),
+ logger: logger,
+ clusterListener: clusterListener,
}
if err := layer.setTLSKeyring(raftTLSKeyring); err != nil {
@ -236,6 +237,24 @@ func (l *raftLayer) setTLSKeyring(keyring *TLSKeyring) error {
return nil
}
func (l *raftLayer) ServerName() string {
key := l.keyring.GetActive()
if key == nil {
return ""
}
return key.parsedCert.Subject.CommonName
}
func (l *raftLayer) CACert(ctx context.Context) *x509.Certificate {
key := l.keyring.GetActive()
if key == nil {
return nil
}
return key.parsedCert
}
func (l *raftLayer) ClientLookup(ctx context.Context, requestInfo *tls.CertificateRequestInfo) (*tls.Certificate, error) {
for _, subj := range requestInfo.AcceptableCAs {
for _, key := range l.keyring.Keys {
@ -346,26 +365,6 @@ func (l *raftLayer) Addr() net.Addr {
// Dial is used to create a new outgoing connection
func (l *raftLayer) Dial(address raft.ServerAddress, timeout time.Duration) (net.Conn, error) {
- tlsConfig := l.baseTLSConfig.Clone()
- key := l.keyring.GetActive()
- if key == nil {
- return nil, errors.New("no active key")
- }
- tlsConfig.NextProtos = []string{consts.RaftStorageALPN}
- tlsConfig.ServerName = key.parsedCert.Subject.CommonName
- l.logger.Debug("creating rpc dialer", "host", tlsConfig.ServerName)
- pool := x509.NewCertPool()
- pool.AddCert(key.parsedCert)
- tlsConfig.RootCAs = pool
- tlsConfig.ClientCAs = pool
- dialer := &net.Dialer{
- Timeout: timeout,
- }
- return tls.DialWithDialer(dialer, "tcp", string(address), tlsConfig)
+ dialFunc := l.clusterListener.GetDialerFunc(context.Background(), consts.RaftStorageALPN)
+ return dialFunc(string(address), timeout)
}
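
Since the layer now derives its TLS state and dialer from the cluster listener, it still satisfies hashicorp/raft's StreamLayer and plugs into a standard network transport unchanged. A hedged sketch, not part of this commit (the helper name and pool/timeout values are illustrative):

package raft

import (
	"time"

	log "github.com/hashicorp/go-hclog"
	"github.com/hashicorp/raft"
	"github.com/hashicorp/vault/vault/cluster"
)

func newRaftTransport(logger log.Logger, keyring *TLSKeyring, cl cluster.ClusterHook) (*raft.NetworkTransport, error) {
	// The layer pulls ServerName, CA cert, and the dialer from the listener.
	layer, err := NewRaftLayer(logger.Named("stream"), keyring, cl)
	if err != nil {
		return nil, err
	}
	// maxPool=3 and the 10s timeout are illustrative values.
	return raft.NewNetworkTransport(layer, 3, 10*time.Second, nil), nil
}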

View File

@ -5,7 +5,6 @@ import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/json"
@ -302,7 +301,13 @@ func (c *Core) startClusterListener(ctx context.Context) error {
c.logger.Debug("starting cluster listeners")
- c.clusterListener.Store(cluster.NewListener(c.clusterListenerAddrs, c.clusterCipherSuites, c.logger.Named("cluster-listener")))
+ networkLayer := c.clusterNetworkLayer
+ if networkLayer == nil {
+ networkLayer = cluster.NewTCPLayer(c.clusterListenerAddrs, c.logger.Named("cluster-listener.tcp"))
+ }
+ c.clusterListener.Store(cluster.NewListener(networkLayer, c.clusterCipherSuites, c.logger.Named("cluster-listener")))
err := c.getClusterListener().Run(ctx)
if err != nil {
@ -310,7 +315,7 @@ func (c *Core) startClusterListener(ctx context.Context) error {
}
if strings.HasSuffix(c.ClusterAddr(), ":0") {
// If we listened on port 0, record the port the OS gave us.
c.clusterAddr.Store(fmt.Sprintf("https://%s", c.getClusterListener().Addrs()[0]))
c.clusterAddr.Store(fmt.Sprintf("https://%s", c.getClusterListener().Addr()))
}
return nil
}
@ -355,37 +360,3 @@ func (c *Core) SetClusterListenerAddrs(addrs []*net.TCPAddr) {
func (c *Core) SetClusterHandler(handler http.Handler) {
c.clusterHandler = handler
}
- // getGRPCDialer is used to return a dialer that has the correct TLS
- // configuration. Otherwise gRPC tries to be helpful and stomps all over our
- // NextProtos.
- func (c *Core) getGRPCDialer(ctx context.Context, alpnProto, serverName string, caCert *x509.Certificate) func(string, time.Duration) (net.Conn, error) {
- return func(addr string, timeout time.Duration) (net.Conn, error) {
- clusterListener := c.getClusterListener()
- if clusterListener == nil {
- return nil, errors.New("clustering disabled")
- }
- tlsConfig, err := clusterListener.TLSConfig(ctx)
- if err != nil {
- c.logger.Error("failed to get tls configuration", "error", err)
- return nil, err
- }
- if serverName != "" {
- tlsConfig.ServerName = serverName
- }
- if caCert != nil {
- pool := x509.NewCertPool()
- pool.AddCert(caCert)
- tlsConfig.RootCAs = pool
- tlsConfig.ClientCAs = pool
- }
- c.logger.Debug("creating rpc dialer", "host", tlsConfig.ServerName)
- tlsConfig.NextProtos = []string{alpnProto}
- dialer := &net.Dialer{
- Timeout: timeout,
- }
- return tls.DialWithDialer(dialer, "tcp", addr, tlsConfig)
- }
- }
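
With getGRPCDialer gone, call sites obtain a per-ALPN dialer from the cluster listener instead; the ServerName and CA cert come from the Client registered for that ALPN (see the request_forwarding.go hunk below). A sketch of the replacement pattern — dialForwarding is a hypothetical helper, not part of the commit:

package vault

import (
	"context"

	"github.com/hashicorp/vault/sdk/helper/consts"
	"github.com/hashicorp/vault/vault/cluster"
	"google.golang.org/grpc"
)

func dialForwarding(ctx context.Context, host string, cl *cluster.Listener) (*grpc.ClientConn, error) {
	return grpc.DialContext(ctx, host,
		grpc.WithDialer(cl.GetDialerFunc(ctx, consts.RequestForwardingALPN)),
		grpc.WithInsecure(), // TLS happens inside the dialer, not in gRPC
	)
}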

View File

@ -5,6 +5,7 @@ import (
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"net"
"sync"
"sync/atomic"
@ -27,6 +28,8 @@ const (
// Client is used to lookup a client certificate.
type Client interface {
ClientLookup(context.Context, *tls.CertificateRequestInfo) (*tls.Certificate, error)
ServerName() string
CACert(ctx context.Context) *x509.Certificate
}
// Handler exposes functions for looking up TLS configuration and handing
@ -48,6 +51,7 @@ type ClusterHook interface {
StopHandler(alpn string)
TLSConfig(ctx context.Context) (*tls.Config, error)
Addr() net.Addr
GetDialerFunc(ctx context.Context, alpnProto string) func(string, time.Duration) (net.Conn, error)
}
// Listener is the source of truth for cluster handlers and connection
@ -60,13 +64,13 @@ type Listener struct {
shutdownWg *sync.WaitGroup
server *http2.Server
- listenerAddrs []*net.TCPAddr
- cipherSuites []uint16
- logger log.Logger
- l sync.RWMutex
+ networkLayer NetworkLayer
+ cipherSuites []uint16
+ logger log.Logger
+ l sync.RWMutex
}
- func NewListener(addrs []*net.TCPAddr, cipherSuites []uint16, logger log.Logger) *Listener {
+ func NewListener(networkLayer NetworkLayer, cipherSuites []uint16, logger log.Logger) *Listener {
// Create the HTTP/2 server that will be shared by both RPC and regular
// duties. Doing it this way instead of listening via the server and gRPC
// allows us to re-use the same port via ALPN. We can just tell the server
@ -84,19 +88,22 @@ func NewListener(addrs []*net.TCPAddr, cipherSuites []uint16, logger log.Logger)
shutdownWg: &sync.WaitGroup{},
server: h2Server,
- listenerAddrs: addrs,
- cipherSuites: cipherSuites,
- logger: logger,
+ networkLayer: networkLayer,
+ cipherSuites: cipherSuites,
+ logger: logger,
}
}
- // TODO: This probably isn't correct
func (cl *Listener) Addr() net.Addr {
- return cl.listenerAddrs[0]
+ addrs := cl.Addrs()
+ if len(addrs) == 0 {
+ return nil
+ }
+ return addrs[0]
}
- func (cl *Listener) Addrs() []*net.TCPAddr {
- return cl.listenerAddrs
+ func (cl *Listener) Addrs() []net.Addr {
+ return cl.networkLayer.Addrs()
}
// AddClient adds a new client for an ALPN name
@ -236,29 +243,15 @@ func (cl *Listener) Run(ctx context.Context) error {
// The server supports all of the possible protos
tlsConfig.NextProtos = []string{"h2", consts.RequestForwardingALPN, consts.PerfStandbyALPN, consts.PerformanceReplicationALPN, consts.DRReplicationALPN}
- for i, laddr := range cl.listenerAddrs {
+ for _, ln := range cl.networkLayer.Listeners() {
// closeCh is used to shutdown the spawned goroutines once this
// function returns
closeCh := make(chan struct{})
- if cl.logger.IsInfo() {
- cl.logger.Info("starting listener", "listener_address", laddr)
- }
- // Create a TCP listener. We do this separately and specifically
- // with TCP so that we can set deadlines.
- tcpLn, err := net.ListenTCP("tcp", laddr)
- if err != nil {
- cl.logger.Error("error starting listener", "error", err)
- continue
- }
- if laddr.String() != tcpLn.Addr().String() {
- // If we listened on port 0, record the port the OS gave us.
- cl.listenerAddrs[i] = tcpLn.Addr().(*net.TCPAddr)
- }
+ localLn := ln
// Wrap the listener with TLS
- tlsLn := tls.NewListener(tcpLn, tlsConfig)
+ tlsLn := tls.NewListener(localLn, tlsConfig)
if cl.logger.IsInfo() {
cl.logger.Info("serving cluster requests", "cluster_listen_address", tlsLn.Addr())
@ -281,7 +274,7 @@ func (cl *Listener) Run(ctx context.Context) error {
// Set the deadline for the accept call. If it passes we'll get
// an error, causing us to check the condition at the top
// again.
- tcpLn.SetDeadline(time.Now().Add(ListenerAcceptDeadline))
+ localLn.SetDeadline(time.Now().Add(ListenerAcceptDeadline))
// Accept the connection
conn, err := tlsLn.Accept()
@ -365,3 +358,67 @@ func (cl *Listener) Stop() {
cl.shutdownWg.Wait()
cl.logger.Info("rpc listeners successfully shut down")
}
// GetDialerFunc returns a function that looks up the TLS information for the
// provided alpn name and calls the network layer's dial function.
func (cl *Listener) GetDialerFunc(ctx context.Context, alpn string) func(string, time.Duration) (net.Conn, error) {
return func(addr string, timeout time.Duration) (net.Conn, error) {
tlsConfig, err := cl.TLSConfig(ctx)
if err != nil {
cl.logger.Error("failed to get tls configuration", "error", err)
return nil, err
}
if tlsConfig == nil {
return nil, errors.New("no tls config found")
}
cl.l.RLock()
client, ok := cl.clients[alpn]
cl.l.RUnlock()
if !ok {
return nil, fmt.Errorf("no client configured for alpn: %q", alpn)
}
serverName := client.ServerName()
if serverName != "" {
tlsConfig.ServerName = serverName
}
caCert := client.CACert(ctx)
if caCert != nil {
pool := x509.NewCertPool()
pool.AddCert(caCert)
tlsConfig.RootCAs = pool
tlsConfig.ClientCAs = pool
}
tlsConfig.NextProtos = []string{alpn}
cl.logger.Debug("creating rpc dialer", "alpn", alpn, "host", tlsConfig.ServerName)
return cl.networkLayer.Dial(addr, timeout, tlsConfig)
}
}
// NetworkListener is used by the network layer to define a net.Listener for use
// in the cluster listener.
type NetworkListener interface {
net.Listener
SetDeadline(t time.Time) error
}
// NetworkLayer is the network abstraction used in the cluster listener.
// Abstracting the network layer out allows us to swap the underlying
// implementations for tests.
type NetworkLayer interface {
Addrs() []net.Addr
Listeners() []NetworkListener
Dial(address string, timeout time.Duration, tlsConfig *tls.Config) (*tls.Conn, error)
Close() error
}
// NetworkLayerSet is used for returning a slice of layers to a caller.
type NetworkLayerSet interface {
Layers() []NetworkLayer
}
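
The listener itself is now transport-agnostic: anything satisfying NetworkLayer can back it. A minimal sketch of the default wiring, mirroring Core.startClusterListener above (the helper name is hypothetical):

package cluster

import (
	"net"

	log "github.com/hashicorp/go-hclog"
)

// newTCPClusterListener shows the default wiring: a TCP layer in production,
// an InmemLayer in tests, with no change to the Listener itself.
func newTCPClusterListener(addrs []*net.TCPAddr, suites []uint16, logger log.Logger) *Listener {
	layer := NewTCPLayer(addrs, logger.Named("cluster-listener.tcp"))
	return NewListener(layer, suites, logger.Named("cluster-listener"))
}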

View File

@ -0,0 +1,323 @@
package cluster
import (
"crypto/tls"
"errors"
"net"
"sync"
"time"
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/sdk/helper/base62"
"go.uber.org/atomic"
)
// InmemLayer is an in-memory implementation of NetworkLayer. This is
// primarily useful for tests.
type InmemLayer struct {
listener *inmemListener
addr string
logger log.Logger
servConns map[string][]net.Conn
clientConns map[string][]net.Conn
peers map[string]*InmemLayer
l sync.Mutex
stopped *atomic.Bool
stopCh chan struct{}
}
// NewInmemLayer returns a new in-memory layer configured to listen on the
// provided address.
func NewInmemLayer(addr string, logger log.Logger) *InmemLayer {
return &InmemLayer{
addr: addr,
logger: logger,
stopped: atomic.NewBool(false),
stopCh: make(chan struct{}),
peers: make(map[string]*InmemLayer),
servConns: make(map[string][]net.Conn),
clientConns: make(map[string][]net.Conn),
}
}
// Addrs implements NetworkLayer.
func (l *InmemLayer) Addrs() []net.Addr {
l.l.Lock()
defer l.l.Unlock()
if l.listener == nil {
return nil
}
return []net.Addr{l.listener.Addr()}
}
// Listeners implements NetworkLayer.
func (l *InmemLayer) Listeners() []NetworkListener {
l.l.Lock()
defer l.l.Unlock()
if l.listener != nil {
return []NetworkListener{l.listener}
}
l.listener = &inmemListener{
addr: l.addr,
pendingConns: make(chan net.Conn),
stopped: atomic.NewBool(false),
stopCh: make(chan struct{}),
}
return []NetworkListener{l.listener}
}
// Dial implements NetworkLayer.
func (l *InmemLayer) Dial(addr string, timeout time.Duration, tlsConfig *tls.Config) (*tls.Conn, error) {
l.l.Lock()
defer l.l.Unlock()
peer, ok := l.peers[addr]
if !ok {
return nil, errors.New("inmemlayer: no address found")
}
conn, err := peer.clientConn(l.addr)
if err != nil {
return nil, err
}
tlsConn := tls.Client(conn, tlsConfig)
l.clientConns[addr] = append(l.clientConns[addr], tlsConn)
return tlsConn, nil
}
// clientConn is executed on a server when a new client connection comes in and
// needs to be Accepted.
func (l *InmemLayer) clientConn(addr string) (net.Conn, error) {
l.l.Lock()
defer l.l.Unlock()
if l.listener == nil {
return nil, errors.New("inmemlayer: listener not started")
}
_, ok := l.peers[addr]
if !ok {
return nil, errors.New("inmemlayer: no peer found")
}
retConn, servConn := net.Pipe()
l.servConns[addr] = append(l.servConns[addr], servConn)
select {
case l.listener.pendingConns <- servConn:
case <-time.After(2 * time.Second):
return nil, errors.New("inmemlayer: timeout while accepting connection")
}
return retConn, nil
}
// Connect is used to connect this transport to another transport for
// a given peer name. This allows for local routing.
func (l *InmemLayer) Connect(peer string, remote *InmemLayer) {
l.l.Lock()
defer l.l.Unlock()
l.peers[peer] = remote
}
// Disconnect is used to remove the ability to route to a given peer.
func (l *InmemLayer) Disconnect(peer string) {
l.l.Lock()
defer l.l.Unlock()
delete(l.peers, peer)
// Remove any open connections
servConns := l.servConns[peer]
for _, c := range servConns {
c.Close()
}
delete(l.servConns, peer)
clientConns := l.clientConns[peer]
for _, c := range clientConns {
c.Close()
}
delete(l.clientConns, peer)
}
// DisconnectAll is used to remove all routes to peers.
func (l *InmemLayer) DisconnectAll() {
l.l.Lock()
defer l.l.Unlock()
l.peers = make(map[string]*InmemLayer)
// Close all connections
for _, peerConns := range l.servConns {
for _, c := range peerConns {
c.Close()
}
}
l.servConns = make(map[string][]net.Conn)
for _, peerConns := range l.clientConns {
for _, c := range peerConns {
c.Close()
}
}
l.clientConns = make(map[string][]net.Conn)
}
// Close is used to permanently disable the transport
func (l *InmemLayer) Close() error {
if l.stopped.Swap(true) {
return nil
}
l.DisconnectAll()
close(l.stopCh)
return nil
}
// inmemListener implements the NetworkListener interface.
type inmemListener struct {
addr string
pendingConns chan net.Conn
stopped *atomic.Bool
stopCh chan struct{}
deadline time.Time
}
// Accept implements the NetworkListener interface.
func (ln *inmemListener) Accept() (net.Conn, error) {
deadline := ln.deadline
if !deadline.IsZero() {
select {
case conn := <-ln.pendingConns:
return conn, nil
case <-time.After(time.Until(deadline)):
return nil, deadlineError("deadline")
case <-ln.stopCh:
return nil, errors.New("listener shut down")
}
}
select {
case conn := <-ln.pendingConns:
return conn, nil
case <-ln.stopCh:
return nil, errors.New("listener shut down")
}
}
// Close implements the NetworkListener interface.
func (ln *inmemListener) Close() error {
if ln.stopped.Swap(true) {
return nil
}
close(ln.stopCh)
return nil
}
// Addr implements the NetworkListener interface.
func (ln *inmemListener) Addr() net.Addr {
return inmemAddr{addr: ln.addr}
}
// SetDeadline implements the NetworkListener interface.
func (ln *inmemListener) SetDeadline(d time.Time) error {
ln.deadline = d
return nil
}
type inmemAddr struct {
addr string
}
func (a inmemAddr) Network() string {
return "inmem"
}
func (a inmemAddr) String() string {
return a.addr
}
type deadlineError string
func (d deadlineError) Error() string { return string(d) }
func (d deadlineError) Timeout() bool { return true }
func (d deadlineError) Temporary() bool { return true }
// InmemLayerCluster composes a set of layers and handles connecting them all
// together. It also satisfies the NetworkLayerSet interface.
type InmemLayerCluster struct {
layers []*InmemLayer
}
// NewInmemLayerCluster returns a new in-memory layer set that builds n nodes
// and connects them all together.
func NewInmemLayerCluster(nodes int, logger log.Logger) (*InmemLayerCluster, error) {
clusterID, err := base62.Random(4)
if err != nil {
return nil, err
}
clusterName := "cluster_" + clusterID
var layers []*InmemLayer
for i := 0; i < nodes; i++ {
nodeID, err := base62.Random(4)
if err != nil {
return nil, err
}
nodeName := clusterName + "_node_" + nodeID
layers = append(layers, NewInmemLayer(nodeName, logger))
}
// Connect all the peers together
for _, node := range layers {
for _, peer := range layers {
// Don't connect to itself
if node == peer {
continue
}
node.Connect(peer.addr, peer)
peer.Connect(node.addr, node)
}
}
return &InmemLayerCluster{layers: layers}, nil
}
// ConnectCluster connects this cluster with the provided remote cluster,
// connecting all nodes to each other.
func (ic *InmemLayerCluster) ConnectCluster(remote *InmemLayerCluster) {
for _, node := range ic.layers {
for _, peer := range remote.layers {
node.Connect(peer.addr, peer)
peer.Connect(node.addr, node)
}
}
}
// Layers implements the NetworkLayerSet interface.
func (ic *InmemLayerCluster) Layers() []NetworkLayer {
ret := make([]NetworkLayer, len(ic.layers))
for i, l := range ic.layers {
ret[i] = l
}
return ret
}

View File

@ -0,0 +1,240 @@
package cluster
import (
"sync"
"testing"
"time"
"go.uber.org/atomic"
)
func TestInmemCluster_Connect(t *testing.T) {
cluster, err := NewInmemLayerCluster(3, nil)
if err != nil {
t.Fatal(err)
}
server := cluster.layers[0]
listener := server.Listeners()[0]
var accepted int
stopCh := make(chan struct{})
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
for {
select {
case <-stopCh:
return
default:
}
listener.SetDeadline(time.Now().Add(5 * time.Second))
_, err := listener.Accept()
if err != nil {
return
}
accepted++
}
}()
// Make sure two nodes can connect in
conn, err := cluster.layers[1].Dial(server.addr, 0, nil)
if err != nil {
t.Fatal(err)
}
if conn == nil {
t.Fatal("nil conn")
}
conn, err = cluster.layers[2].Dial(server.addr, 0, nil)
if err != nil {
t.Fatal(err)
}
if conn == nil {
t.Fatal("nil conn")
}
close(stopCh)
wg.Wait()
if accepted != 2 {
t.Fatalf("expected 2 connections to be accepted, got %d", accepted)
}
}
func TestInmemCluster_Disconnect(t *testing.T) {
cluster, err := NewInmemLayerCluster(3, nil)
if err != nil {
t.Fatal(err)
}
server := cluster.layers[0]
server.Disconnect(cluster.layers[1].addr)
listener := server.Listeners()[0]
var accepted int
stopCh := make(chan struct{})
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
for {
select {
case <-stopCh:
return
default:
}
listener.SetDeadline(time.Now().Add(5 * time.Second))
_, err := listener.Accept()
if err != nil {
return
}
accepted++
}
}()
// Make sure node1 cannot connect in
conn, err := cluster.layers[1].Dial(server.addr, 0, nil)
if err == nil {
t.Fatal("expected error")
}
if conn != nil {
t.Fatal("expected nil conn")
}
// Node2 should be able to connect
conn, err = cluster.layers[2].Dial(server.addr, 0, nil)
if err != nil {
t.Fatal(err)
}
if conn == nil {
t.Fatal("nil conn")
}
close(stopCh)
wg.Wait()
if accepted != 1 {
t.Fatalf("expected 1 connections to be accepted, got %d", accepted)
}
}
func TestInmemCluster_DisconnectAll(t *testing.T) {
cluster, err := NewInmemLayerCluster(3, nil)
if err != nil {
t.Fatal(err)
}
server := cluster.layers[0]
server.DisconnectAll()
// Make sure nodes cannot connect in
conn, err := cluster.layers[1].Dial(server.addr, 0, nil)
if err == nil {
t.Fatal("expected error")
}
if conn != nil {
t.Fatal("expected nil conn")
}
conn, err = cluster.layers[2].Dial(server.addr, 0, nil)
if err == nil {
t.Fatal("expected error")
}
if conn != nil {
t.Fatal("expected nil conn")
}
}
func TestInmemCluster_ConnectCluster(t *testing.T) {
cluster, err := NewInmemLayerCluster(3, nil)
if err != nil {
t.Fatal(err)
}
cluster2, err := NewInmemLayerCluster(3, nil)
if err != nil {
t.Fatal(err)
}
cluster.ConnectCluster(cluster2)
var accepted atomic.Int32
stopCh := make(chan struct{})
var wg sync.WaitGroup
acceptConns := func(listener NetworkListener) {
wg.Add(1)
go func() {
defer wg.Done()
for {
select {
case <-stopCh:
return
default:
}
listener.SetDeadline(time.Now().Add(5 * time.Second))
_, err := listener.Accept()
if err != nil {
return
}
accepted.Add(1)
}
}()
}
// Start a listener on each node.
for _, node := range cluster.layers {
acceptConns(node.Listeners()[0])
}
for _, node := range cluster2.layers {
acceptConns(node.Listeners()[0])
}
// Make sure each node can connect to each other
for _, node1 := range cluster.layers {
for _, node2 := range cluster2.layers {
conn, err := node1.Dial(node2.addr, 0, nil)
if err != nil {
t.Fatal(err)
}
if conn == nil {
t.Fatal("nil conn")
}
conn, err = node2.Dial(node1.addr, 0, nil)
if err != nil {
t.Fatal(err)
}
if conn == nil {
t.Fatal("nil conn")
}
}
}
close(stopCh)
wg.Wait()
if accepted.Load() != 18 {
t.Fatalf("expected 18 connections to be accepted, got %d", accepted)
}
}

vault/cluster/tcp_layer.go (new file)
View File

@ -0,0 +1,111 @@
package cluster
import (
"crypto/tls"
"net"
"sync"
"time"
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-multierror"
"go.uber.org/atomic"
)
// TCPLayer implements the NetworkLayer interface and uses TCP as the underlying
// network.
type TCPLayer struct {
listeners []NetworkListener
addrs []*net.TCPAddr
logger log.Logger
l sync.Mutex
stopped *atomic.Bool
}
// NewTCPLayer returns a TCPLayer.
func NewTCPLayer(addrs []*net.TCPAddr, logger log.Logger) *TCPLayer {
return &TCPLayer{
addrs: addrs,
logger: logger,
stopped: atomic.NewBool(false),
}
}
// Addrs implements NetworkLayer.
func (l *TCPLayer) Addrs() []net.Addr {
l.l.Lock()
defer l.l.Unlock()
if len(l.addrs) == 0 {
return nil
}
ret := make([]net.Addr, len(l.addrs))
for i, a := range l.addrs {
ret[i] = a
}
return ret
}
// Listeners implements NetworkLayer. It starts a new TCP listener for each
// configured address.
func (l *TCPLayer) Listeners() []NetworkListener {
l.l.Lock()
defer l.l.Unlock()
if l.listeners != nil {
return l.listeners
}
listeners := make([]NetworkListener, len(l.addrs))
for i, laddr := range l.addrs {
if l.logger.IsInfo() {
l.logger.Info("starting listener", "listener_address", laddr)
}
tcpLn, err := net.ListenTCP("tcp", laddr)
if err != nil {
l.logger.Error("error starting listener", "error", err)
continue
}
if laddr.String() != tcpLn.Addr().String() {
// If we listened on port 0, record the port the OS gave us.
l.addrs[i] = tcpLn.Addr().(*net.TCPAddr)
}
listeners[i] = tcpLn
}
l.listeners = listeners
return listeners
}
// Dial implements the NetworkLayer interface.
func (l *TCPLayer) Dial(address string, timeout time.Duration, tlsConfig *tls.Config) (*tls.Conn, error) {
dialer := &net.Dialer{
Timeout: timeout,
}
return tls.DialWithDialer(dialer, "tcp", address, tlsConfig)
}
// Close implements the NetworkLayer interface.
func (l *TCPLayer) Close() error {
if l.stopped.Swap(true) {
return nil
}
l.l.Lock()
defer l.l.Unlock()
var retErr *multierror.Error
for _, ln := range l.listeners {
if err := ln.Close(); err != nil {
retErr = multierror.Append(retErr, err)
}
}
l.listeners = nil
return retErr.ErrorOrNil()
}
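
A note on behavior that moved here from the old Run loop: listening on port 0 rewrites the stored address with the OS-assigned port, so Addrs() returns dialable addresses once Listeners() has run. A runnable sketch, illustrative rather than part of the commit:

package main

import (
	"fmt"
	"net"

	log "github.com/hashicorp/go-hclog"
	"github.com/hashicorp/vault/vault/cluster"
)

func main() {
	// Port 0 asks the OS for a free port.
	addr := &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}
	layer := cluster.NewTCPLayer([]*net.TCPAddr{addr}, log.Default())
	defer layer.Close()

	layer.Listeners() // starts listening and records the real port
	fmt.Println("cluster address:", layer.Addrs()[0])
}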

View File

@ -4,7 +4,6 @@ import (
"bytes"
"context"
"crypto/tls"
"crypto/x509"
"net/http"
"testing"
"time"
@ -15,6 +14,7 @@ import (
"github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/sdk/physical"
"github.com/hashicorp/vault/sdk/physical/inmem"
"github.com/hashicorp/vault/vault/cluster"
)
var (
@ -100,13 +100,13 @@ func TestCluster_ListenForRequests(t *testing.T) {
// Wait for core to become active
TestWaitActive(t, cores[0].Core)
- cores[0].getClusterListener().AddClient(consts.RequestForwardingALPN, &requestForwardingClusterClient{cores[0].Core})
+ clusterListener := cores[0].getClusterListener()
+ clusterListener.AddClient(consts.RequestForwardingALPN, &requestForwardingClusterClient{cores[0].Core})
addrs := cores[0].getClusterListener().Addrs()
// Use this to have a valid config after sealing since ClusterTLSConfig returns nil
checkListenersFunc := func(expectFail bool) {
- parsedCert := cores[0].localClusterParsedCert.Load().(*x509.Certificate)
- dialer := cores[0].getGRPCDialer(context.Background(), consts.RequestForwardingALPN, parsedCert.Subject.CommonName, parsedCert)
+ dialer := clusterListener.GetDialerFunc(context.Background(), consts.RequestForwardingALPN)
for i := range cores[0].Listeners {
clnAddr := addrs[i]
@ -172,11 +172,25 @@ func TestCluster_ForwardRequests(t *testing.T) {
// Make this nicer for tests
manualStepDownSleepPeriod = 5 * time.Second
- testCluster_ForwardRequestsCommon(t)
+ t.Run("tcpLayer", func(t *testing.T) {
+ testCluster_ForwardRequestsCommon(t, nil)
+ })
+ t.Run("inmemLayer", func(t *testing.T) {
+ // Run again with in-memory network
+ inmemCluster, err := cluster.NewInmemLayerCluster(3, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ testCluster_ForwardRequestsCommon(t, &TestClusterOptions{
+ ClusterLayers: inmemCluster,
+ })
+ })
}
- func testCluster_ForwardRequestsCommon(t *testing.T) {
- cluster := NewTestCluster(t, nil, nil)
+ func testCluster_ForwardRequestsCommon(t *testing.T, clusterOpts *TestClusterOptions) {
+ cluster := NewTestCluster(t, nil, clusterOpts)
cores := cluster.Cores
cores[0].Handler.(*http.ServeMux).HandleFunc("/core1", func(w http.ResponseWriter, req *http.Request) {
w.Header().Add("Content-Type", "application/json")

View File

@ -493,6 +493,14 @@ type Core struct {
secureRandomReader io.Reader
recoveryMode bool
clusterNetworkLayer cluster.NetworkLayer
// PR1103disabled is used to test upgrade workflows: when set to true,
// the correct behaviour for namespaced cubbyholes is disabled, so we
// can test an upgrade to a version that includes the fixes from
// https://github.com/hashicorp/vault-enterprise/pull/1103
PR1103disabled bool
}
// CoreConfig is used to parameterize a core
@ -576,6 +584,8 @@ type CoreConfig struct {
CounterSyncInterval time.Duration
RecoveryMode bool
ClusterNetworkLayer cluster.NetworkLayer
}
func (c *CoreConfig) Clone() *CoreConfig {
@ -611,6 +621,7 @@ func (c *CoreConfig) Clone() *CoreConfig {
DisableIndexing: c.DisableIndexing,
AllLoggers: c.AllLoggers,
CounterSyncInterval: c.CounterSyncInterval,
ClusterNetworkLayer: c.ClusterNetworkLayer,
}
}
@ -706,6 +717,7 @@ func NewCore(conf *CoreConfig) (*Core, error) {
maxLeaseTTL: conf.MaxLeaseTTL,
cachingDisabled: conf.DisableCache,
clusterName: conf.ClusterName,
clusterNetworkLayer: conf.ClusterNetworkLayer,
clusterPeerClusterAddrsCache: cache.New(3*cluster.HeartbeatInterval, time.Second),
enableMlock: !conf.DisableMlock,
rawEnabled: conf.EnableRaw,

View File

@ -99,6 +99,19 @@ func (c *requestForwardingClusterClient) ClientLookup(ctx context.Context, reque
return nil, nil
}
func (c *requestForwardingClusterClient) ServerName() string {
parsedCert := c.core.localClusterParsedCert.Load().(*x509.Certificate)
if parsedCert == nil {
return ""
}
return parsedCert.Subject.CommonName
}
func (c *requestForwardingClusterClient) CACert(ctx context.Context) *x509.Certificate {
return c.core.localClusterParsedCert.Load().(*x509.Certificate)
}
// ServerLookup satisfies the ClusterHandler interface and returns the server's
// tls certs.
func (rf *requestForwardingHandler) ServerLookup(ctx context.Context, clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
@ -246,19 +259,22 @@ func (c *Core) refreshRequestForwardingConnection(ctx context.Context, clusterAd
}
clusterListener := c.getClusterListener()
- if clusterListener != nil {
- clusterListener.AddClient(consts.RequestForwardingALPN, &requestForwardingClusterClient{
- core: c,
- })
+ if clusterListener == nil {
+ c.logger.Error("no cluster listener configured")
+ return nil
}
+ clusterListener.AddClient(consts.RequestForwardingALPN, &requestForwardingClusterClient{
+ core: c,
+ })
// Set up grpc forwarding handling
// It's not really insecure, but we have to dial manually to get the
// ALPN header right. It's just "insecure" because GRPC isn't managing
// the TLS state.
dctx, cancelFunc := context.WithCancel(ctx)
c.rpcClientConn, err = grpc.DialContext(dctx, clusterURL.Host,
- grpc.WithDialer(c.getGRPCDialer(ctx, consts.RequestForwardingALPN, parsedCert.Subject.CommonName, parsedCert)),
+ grpc.WithDialer(clusterListener.GetDialerFunc(ctx, consts.RequestForwardingALPN)),
grpc.WithInsecure(), // it's not, we handle it in the dialer
grpc.WithKeepaliveParams(keepalive.ClientParameters{
Time: 2 * cluster.HeartbeatInterval,

View File

@ -31,6 +31,7 @@ import (
hclog "github.com/hashicorp/go-hclog"
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/helper/metricsutil"
"github.com/hashicorp/vault/vault/cluster"
"github.com/hashicorp/vault/vault/seal"
"github.com/mitchellh/copystructure"
@ -1057,7 +1058,11 @@ type TestClusterOptions struct {
FirstCoreNumber int
RequireClientAuth bool
// SetupFunc is called after the cluster is started.
- SetupFunc func(t testing.T, c *TestCluster)
+ SetupFunc func(t testing.T, c *TestCluster)
+ PR1103Disabled bool
+ // ClusterLayers are used to override the default cluster connection layer
+ ClusterLayers cluster.NetworkLayerSet
}
var DefaultNumCores = 3
@ -1093,6 +1098,11 @@ func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *Te
numCores = opts.NumCores
}
var disablePR1103 bool
if opts != nil && opts.PR1103Disabled {
disablePR1103 = true
}
var firstCoreNumber int
if opts != nil {
firstCoreNumber = opts.FirstCoreNumber
@ -1486,6 +1496,10 @@ func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *Te
}
}
if opts != nil && opts.ClusterLayers != nil {
localConfig.ClusterNetworkLayer = opts.ClusterLayers.Layers()[i]
}
switch {
case localConfig.LicensingConfig != nil:
if pubKey != nil {
@ -1506,6 +1520,7 @@ func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *Te
t.Fatalf("err: %v", err)
}
c.coreNumber = firstCoreNumber + i
c.PR1103disabled = disablePR1103
cores = append(cores, c)
coreConfigs = append(coreConfigs, &localConfig)
if opts != nil && opts.HandlerFunc != nil {