Create network layer abstraction to allow in-memory cluster traffic (#8173)

parent 3956072c93
commit f32a86ee7a
@@ -503,14 +503,8 @@ func (b *RaftBackend) SetupCluster(ctx context.Context, opts SetupOpts) error {
 	case opts.ClusterListener == nil:
 		return errors.New("no cluster listener provided")
 	default:
-		// Load the base TLS config from the cluster listener.
-		baseTLSConfig, err := opts.ClusterListener.TLSConfig(ctx)
-		if err != nil {
-			return err
-		}
-
 		// Set the local address and localID in the streaming layer and the raft config.
-		streamLayer, err := NewRaftLayer(b.logger.Named("stream"), opts.TLSKeyring, opts.ClusterListener.Addr(), baseTLSConfig)
+		streamLayer, err := NewRaftLayer(b.logger.Named("stream"), opts.TLSKeyring, opts.ClusterListener)
 		if err != nil {
 			return err
 		}

@@ -110,7 +110,7 @@ func GenerateTLSKey(reader io.Reader) (*TLSKey, error) {
 		KeyUsage:     x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement | x509.KeyUsageCertSign,
 		SerialNumber: big.NewInt(mathrand.Int63()),
 		NotBefore:    time.Now().Add(-30 * time.Second),
-		// 30 years of single-active uptime ought to be enough for anybody
+		// 30 years ought to be enough for anybody
 		NotAfter:              time.Now().Add(262980 * time.Hour),
 		BasicConstraintsValid: true,
 		IsCA:                  true,
@@ -162,13 +162,14 @@ type raftLayer struct {
 	dialerFunc func(string, time.Duration) (net.Conn, error)

 	// TLS config
-	keyring       *TLSKeyring
-	baseTLSConfig *tls.Config
+	keyring         *TLSKeyring
+	clusterListener cluster.ClusterHook
 }

 // NewRaftLayer creates a new raftLayer object. It parses the TLS information
 // from the network config.
-func NewRaftLayer(logger log.Logger, raftTLSKeyring *TLSKeyring, clusterAddr net.Addr, baseTLSConfig *tls.Config) (*raftLayer, error) {
+func NewRaftLayer(logger log.Logger, raftTLSKeyring *TLSKeyring, clusterListener cluster.ClusterHook) (*raftLayer, error) {
+	clusterAddr := clusterListener.Addr()
 	switch {
 	case clusterAddr == nil:
 		// Clustering disabled on the server, don't try to look for params
@@ -176,11 +177,11 @@ func NewRaftLayer(logger log.Logger, raftTLSKeyring *TLSKeyring, clusterAddr net
 	}

 	layer := &raftLayer{
-		addr:          clusterAddr,
-		connCh:        make(chan net.Conn),
-		closeCh:       make(chan struct{}),
-		logger:        logger,
-		baseTLSConfig: baseTLSConfig,
+		addr:            clusterAddr,
+		connCh:          make(chan net.Conn),
+		closeCh:         make(chan struct{}),
+		logger:          logger,
+		clusterListener: clusterListener,
 	}

 	if err := layer.setTLSKeyring(raftTLSKeyring); err != nil {
@@ -236,6 +237,24 @@ func (l *raftLayer) setTLSKeyring(keyring *TLSKeyring) error {
 	return nil
 }

+func (l *raftLayer) ServerName() string {
+	key := l.keyring.GetActive()
+	if key == nil {
+		return ""
+	}
+
+	return key.parsedCert.Subject.CommonName
+}
+
+func (l *raftLayer) CACert(ctx context.Context) *x509.Certificate {
+	key := l.keyring.GetActive()
+	if key == nil {
+		return nil
+	}
+
+	return key.parsedCert
+}
+
 func (l *raftLayer) ClientLookup(ctx context.Context, requestInfo *tls.CertificateRequestInfo) (*tls.Certificate, error) {
 	for _, subj := range requestInfo.AcceptableCAs {
 		for _, key := range l.keyring.Keys {
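Together with the ClientLookup method visible in the context above, these two additions make raftLayer satisfy the cluster.Client interface that this commit extends further down. A hypothetical compile-time assertion, not part of the commit, assuming the raft package imports vault/cluster:

// Hypothetical check: raftLayer implements cluster.Client (ClientLookup,
// ServerName, CACert), which GetDialerFunc consults for SNI and CA pinning.
var _ cluster.Client = (*raftLayer)(nil)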
@@ -346,26 +365,6 @@ func (l *raftLayer) Addr() net.Addr {

 // Dial is used to create a new outgoing connection
 func (l *raftLayer) Dial(address raft.ServerAddress, timeout time.Duration) (net.Conn, error) {
-
-	tlsConfig := l.baseTLSConfig.Clone()
-
-	key := l.keyring.GetActive()
-	if key == nil {
-		return nil, errors.New("no active key")
-	}
-
-	tlsConfig.NextProtos = []string{consts.RaftStorageALPN}
-	tlsConfig.ServerName = key.parsedCert.Subject.CommonName
-
-	l.logger.Debug("creating rpc dialer", "host", tlsConfig.ServerName)
-
-	pool := x509.NewCertPool()
-	pool.AddCert(key.parsedCert)
-	tlsConfig.RootCAs = pool
-	tlsConfig.ClientCAs = pool
-
-	dialer := &net.Dialer{
-		Timeout: timeout,
-	}
-	return tls.DialWithDialer(dialer, "tcp", string(address), tlsConfig)
+	dialFunc := l.clusterListener.GetDialerFunc(context.Background(), consts.RaftStorageALPN)
+	return dialFunc(string(address), timeout)
 }

@@ -5,7 +5,6 @@ import (
 	"crypto/ecdsa"
 	"crypto/elliptic"
 	"crypto/rand"
-	"crypto/tls"
 	"crypto/x509"
 	"crypto/x509/pkix"
 	"encoding/json"
@@ -302,7 +301,13 @@ func (c *Core) startClusterListener(ctx context.Context) error {

 	c.logger.Debug("starting cluster listeners")

-	c.clusterListener.Store(cluster.NewListener(c.clusterListenerAddrs, c.clusterCipherSuites, c.logger.Named("cluster-listener")))
+	networkLayer := c.clusterNetworkLayer
+
+	if networkLayer == nil {
+		networkLayer = cluster.NewTCPLayer(c.clusterListenerAddrs, c.logger.Named("cluster-listener.tcp"))
+	}
+
+	c.clusterListener.Store(cluster.NewListener(networkLayer, c.clusterCipherSuites, c.logger.Named("cluster-listener")))

 	err := c.getClusterListener().Run(ctx)
 	if err != nil {
@@ -310,7 +315,7 @@ func (c *Core) startClusterListener(ctx context.Context) error {
 	}
 	if strings.HasSuffix(c.ClusterAddr(), ":0") {
 		// If we listened on port 0, record the port the OS gave us.
-		c.clusterAddr.Store(fmt.Sprintf("https://%s", c.getClusterListener().Addrs()[0]))
+		c.clusterAddr.Store(fmt.Sprintf("https://%s", c.getClusterListener().Addr()))
 	}
 	return nil
 }
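The two hunks above are the seam the whole change hinges on: startClusterListener now builds the cluster listener on top of a pluggable NetworkLayer and falls back to a plain TCP layer when none is injected. A minimal standalone sketch of that wiring, illustrative only (the main function, logger, and loopback address are invented for the example):

package main

import (
	"net"

	log "github.com/hashicorp/go-hclog"
	"github.com/hashicorp/vault/vault/cluster"
)

func main() {
	logger := log.Default()
	addrs := []*net.TCPAddr{{IP: net.ParseIP("127.0.0.1"), Port: 0}}

	// customLayer stands in for Core.clusterNetworkLayer: nil in production,
	// injected (e.g. an *InmemLayer) in tests.
	var customLayer cluster.NetworkLayer

	networkLayer := customLayer
	if networkLayer == nil {
		networkLayer = cluster.NewTCPLayer(addrs, logger.Named("cluster-listener.tcp"))
	}

	// The listener no longer knows or cares what transport it sits on.
	_ = cluster.NewListener(networkLayer, nil, logger.Named("cluster-listener"))
}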
@@ -355,37 +360,3 @@ func (c *Core) SetClusterListenerAddrs(addrs []*net.TCPAddr) {
 func (c *Core) SetClusterHandler(handler http.Handler) {
 	c.clusterHandler = handler
 }
-
-// getGRPCDialer is used to return a dialer that has the correct TLS
-// configuration. Otherwise gRPC tries to be helpful and stomps all over our
-// NextProtos.
-func (c *Core) getGRPCDialer(ctx context.Context, alpnProto, serverName string, caCert *x509.Certificate) func(string, time.Duration) (net.Conn, error) {
-	return func(addr string, timeout time.Duration) (net.Conn, error) {
-		clusterListener := c.getClusterListener()
-		if clusterListener == nil {
-			return nil, errors.New("clustering disabled")
-		}
-
-		tlsConfig, err := clusterListener.TLSConfig(ctx)
-		if err != nil {
-			c.logger.Error("failed to get tls configuration", "error", err)
-			return nil, err
-		}
-		if serverName != "" {
-			tlsConfig.ServerName = serverName
-		}
-		if caCert != nil {
-			pool := x509.NewCertPool()
-			pool.AddCert(caCert)
-			tlsConfig.RootCAs = pool
-			tlsConfig.ClientCAs = pool
-		}
-		c.logger.Debug("creating rpc dialer", "host", tlsConfig.ServerName)
-
-		tlsConfig.NextProtos = []string{alpnProto}
-		dialer := &net.Dialer{
-			Timeout: timeout,
-		}
-		return tls.DialWithDialer(dialer, "tcp", addr, tlsConfig)
-	}
-}

@@ -5,6 +5,7 @@ import (
 	"crypto/tls"
+	"crypto/x509"
 	"errors"
 	"fmt"
 	"net"
 	"sync"
 	"sync/atomic"
@@ -27,6 +28,8 @@ const (
 // Client is used to lookup a client certificate.
 type Client interface {
 	ClientLookup(context.Context, *tls.CertificateRequestInfo) (*tls.Certificate, error)
+	ServerName() string
+	CACert(ctx context.Context) *x509.Certificate
 }

 // Handler exposes functions for looking up TLS configuration and handing
@@ -48,6 +51,7 @@ type ClusterHook interface {
 	StopHandler(alpn string)
 	TLSConfig(ctx context.Context) (*tls.Config, error)
 	Addr() net.Addr
+	GetDialerFunc(ctx context.Context, alpnProto string) func(string, time.Duration) (net.Conn, error)
 }

 // Listener is the source of truth for cluster handlers and connection
@@ -60,13 +64,13 @@ type Listener struct {
 	shutdownWg *sync.WaitGroup
 	server     *http2.Server

-	listenerAddrs []*net.TCPAddr
-	cipherSuites  []uint16
-	logger        log.Logger
-	l             sync.RWMutex
+	networkLayer NetworkLayer
+	cipherSuites []uint16
+	logger       log.Logger
+	l            sync.RWMutex
 }

-func NewListener(addrs []*net.TCPAddr, cipherSuites []uint16, logger log.Logger) *Listener {
+func NewListener(networkLayer NetworkLayer, cipherSuites []uint16, logger log.Logger) *Listener {
 	// Create the HTTP/2 server that will be shared by both RPC and regular
 	// duties. Doing it this way instead of listening via the server and gRPC
 	// allows us to re-use the same port via ALPN. We can just tell the server
@@ -84,19 +88,22 @@ func NewListener(addrs []*net.TCPAddr, cipherSuites []uint16, logger log.Logger)
 		shutdownWg: &sync.WaitGroup{},
 		server:     h2Server,

-		listenerAddrs: addrs,
-		cipherSuites:  cipherSuites,
-		logger:        logger,
+		networkLayer: networkLayer,
+		cipherSuites: cipherSuites,
+		logger:       logger,
 	}
 }

+// TODO: This probably isn't correct
 func (cl *Listener) Addr() net.Addr {
-	return cl.listenerAddrs[0]
+	addrs := cl.Addrs()
+	if len(addrs) == 0 {
+		return nil
+	}
+	return addrs[0]
 }

-func (cl *Listener) Addrs() []*net.TCPAddr {
-	return cl.listenerAddrs
+func (cl *Listener) Addrs() []net.Addr {
+	return cl.networkLayer.Addrs()
 }

 // AddClient adds a new client for an ALPN name
@@ -236,29 +243,15 @@ func (cl *Listener) Run(ctx context.Context) error {
 	// The server supports all of the possible protos
 	tlsConfig.NextProtos = []string{"h2", consts.RequestForwardingALPN, consts.PerfStandbyALPN, consts.PerformanceReplicationALPN, consts.DRReplicationALPN}

-	for i, laddr := range cl.listenerAddrs {
+	for _, ln := range cl.networkLayer.Listeners() {
 		// closeCh is used to shutdown the spawned goroutines once this
 		// function returns
 		closeCh := make(chan struct{})

-		if cl.logger.IsInfo() {
-			cl.logger.Info("starting listener", "listener_address", laddr)
-		}
-
-		// Create a TCP listener. We do this separately and specifically
-		// with TCP so that we can set deadlines.
-		tcpLn, err := net.ListenTCP("tcp", laddr)
-		if err != nil {
-			cl.logger.Error("error starting listener", "error", err)
-			continue
-		}
-		if laddr.String() != tcpLn.Addr().String() {
-			// If we listened on port 0, record the port the OS gave us.
-			cl.listenerAddrs[i] = tcpLn.Addr().(*net.TCPAddr)
-		}
+		localLn := ln

 		// Wrap the listener with TLS
-		tlsLn := tls.NewListener(tcpLn, tlsConfig)
+		tlsLn := tls.NewListener(localLn, tlsConfig)

 		if cl.logger.IsInfo() {
 			cl.logger.Info("serving cluster requests", "cluster_listen_address", tlsLn.Addr())
@@ -281,7 +274,7 @@ func (cl *Listener) Run(ctx context.Context) error {
 				// Set the deadline for the accept call. If it passes we'll get
 				// an error, causing us to check the condition at the top
 				// again.
-				tcpLn.SetDeadline(time.Now().Add(ListenerAcceptDeadline))
+				localLn.SetDeadline(time.Now().Add(ListenerAcceptDeadline))

 				// Accept the connection
 				conn, err := tlsLn.Accept()
@@ -365,3 +358,67 @@ func (cl *Listener) Stop() {
 	cl.shutdownWg.Wait()
 	cl.logger.Info("rpc listeners successfully shut down")
 }
+
+// GetDialerFunc returns a function that looks up the TLS information for the
+// provided alpn name and calls the network layer's dial function.
+func (cl *Listener) GetDialerFunc(ctx context.Context, alpn string) func(string, time.Duration) (net.Conn, error) {
+	return func(addr string, timeout time.Duration) (net.Conn, error) {
+		tlsConfig, err := cl.TLSConfig(ctx)
+		if err != nil {
+			cl.logger.Error("failed to get tls configuration", "error", err)
+			return nil, err
+		}
+
+		if tlsConfig == nil {
+			return nil, errors.New("no tls config found")
+		}
+
+		cl.l.RLock()
+		client, ok := cl.clients[alpn]
+		cl.l.RUnlock()
+		if !ok {
+			return nil, fmt.Errorf("no client configured for alpn: %q", alpn)
+		}
+
+		serverName := client.ServerName()
+		if serverName != "" {
+			tlsConfig.ServerName = serverName
+		}
+
+		caCert := client.CACert(ctx)
+		if caCert != nil {
+			pool := x509.NewCertPool()
+			pool.AddCert(caCert)
+			tlsConfig.RootCAs = pool
+			tlsConfig.ClientCAs = pool
+		}
+
+		tlsConfig.NextProtos = []string{alpn}
+		cl.logger.Debug("creating rpc dialer", "alpn", alpn, "host", tlsConfig.ServerName)
+
+		return cl.networkLayer.Dial(addr, timeout, tlsConfig)
+	}
+}
+
+// NetworkListener is used by the network layer to define a net.Listener for use
+// in the cluster listener.
+type NetworkListener interface {
+	net.Listener
+
+	SetDeadline(t time.Time) error
+}
+
+// NetworkLayer is the network abstraction used in the cluster listener.
+// Abstracting the network layer out allows us to swap the underlying
+// implementations for tests.
+type NetworkLayer interface {
+	Addrs() []net.Addr
+	Listeners() []NetworkListener
+	Dial(address string, timeout time.Duration, tlsConfig *tls.Config) (*tls.Conn, error)
+	Close() error
+}
+
+// NetworkLayerSet is used for returning a slice of layers to a caller.
+type NetworkLayerSet interface {
+	Layers() []NetworkLayer
+}
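GetDialerFunc centralizes what getGRPCDialer and the old raftLayer.Dial each did by hand: fetch the TLS config, ask the registered Client for the server name and CA certificate, pin NextProtos to the ALPN name, and delegate the actual dial to the network layer. The interfaces above are the swap point for tests. A sketch of the smallest custom NetworkLayer a test double would need; the nopLayer type is invented for illustration and assumes it compiles inside the cluster package:

package cluster

import (
	"crypto/tls"
	"net"
	"time"
)

// nopLayer is a hypothetical do-nothing layer, here only to show the surface
// area an implementation must cover. Dial falls through to real TCP.
type nopLayer struct{}

func (nopLayer) Addrs() []net.Addr            { return nil }
func (nopLayer) Listeners() []NetworkListener { return nil }
func (nopLayer) Close() error                 { return nil }

func (nopLayer) Dial(addr string, timeout time.Duration, tlsConfig *tls.Config) (*tls.Conn, error) {
	dialer := &net.Dialer{Timeout: timeout}
	return tls.DialWithDialer(dialer, "tcp", addr, tlsConfig)
}

// Compile-time check that nopLayer satisfies the interface.
var _ NetworkLayer = nopLayer{}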
@@ -0,0 +1,323 @@
+package cluster
+
+import (
+	"crypto/tls"
+	"errors"
+	"net"
+	"sync"
+	"time"
+
+	log "github.com/hashicorp/go-hclog"
+	"github.com/hashicorp/vault/sdk/helper/base62"
+	"go.uber.org/atomic"
+)
+
+// InmemLayer is an in-memory implementation of NetworkLayer. This is
+// primarily useful for tests.
+type InmemLayer struct {
+	listener *inmemListener
+	addr     string
+	logger   log.Logger
+
+	servConns   map[string][]net.Conn
+	clientConns map[string][]net.Conn
+
+	peers map[string]*InmemLayer
+	l     sync.Mutex
+
+	stopped *atomic.Bool
+	stopCh  chan struct{}
+}
+
+// NewInmemLayer returns a new in-memory layer configured to listen on the
+// provided address.
+func NewInmemLayer(addr string, logger log.Logger) *InmemLayer {
+	return &InmemLayer{
+		addr:        addr,
+		logger:      logger,
+		stopped:     atomic.NewBool(false),
+		stopCh:      make(chan struct{}),
+		peers:       make(map[string]*InmemLayer),
+		servConns:   make(map[string][]net.Conn),
+		clientConns: make(map[string][]net.Conn),
+	}
+}
+
+// Addrs implements NetworkLayer.
+func (l *InmemLayer) Addrs() []net.Addr {
+	l.l.Lock()
+	defer l.l.Unlock()
+
+	if l.listener == nil {
+		return nil
+	}
+
+	return []net.Addr{l.listener.Addr()}
+}
+
+// Listeners implements NetworkLayer.
+func (l *InmemLayer) Listeners() []NetworkListener {
+	l.l.Lock()
+	defer l.l.Unlock()
+
+	if l.listener != nil {
+		return []NetworkListener{l.listener}
+	}
+
+	l.listener = &inmemListener{
+		addr:         l.addr,
+		pendingConns: make(chan net.Conn),
+
+		stopped: atomic.NewBool(false),
+		stopCh:  make(chan struct{}),
+	}
+
+	return []NetworkListener{l.listener}
+}
+
+// Dial implements NetworkLayer.
+func (l *InmemLayer) Dial(addr string, timeout time.Duration, tlsConfig *tls.Config) (*tls.Conn, error) {
+	l.l.Lock()
+	defer l.l.Unlock()
+
+	peer, ok := l.peers[addr]
+	if !ok {
+		return nil, errors.New("inmemlayer: no address found")
+	}
+
+	conn, err := peer.clientConn(l.addr)
+	if err != nil {
+		return nil, err
+	}
+
+	tlsConn := tls.Client(conn, tlsConfig)
+
+	l.clientConns[addr] = append(l.clientConns[addr], tlsConn)
+
+	return tlsConn, nil
+}
+
+// clientConn is executed on a server when a new client connection comes in and
+// needs to be Accepted.
+func (l *InmemLayer) clientConn(addr string) (net.Conn, error) {
+	l.l.Lock()
+	defer l.l.Unlock()
+
+	if l.listener == nil {
+		return nil, errors.New("inmemlayer: listener not started")
+	}
+
+	_, ok := l.peers[addr]
+	if !ok {
+		return nil, errors.New("inmemlayer: no peer found")
+	}
+
+	retConn, servConn := net.Pipe()
+
+	l.servConns[addr] = append(l.servConns[addr], servConn)
+
+	select {
+	case l.listener.pendingConns <- servConn:
+	case <-time.After(2 * time.Second):
+		return nil, errors.New("inmemlayer: timeout while accepting connection")
+	}
+
+	return retConn, nil
+}
+
+// Connect is used to connect this transport to another transport for
+// a given peer name. This allows for local routing.
+func (l *InmemLayer) Connect(peer string, remote *InmemLayer) {
+	l.l.Lock()
+	defer l.l.Unlock()
+	l.peers[peer] = remote
+}
+
+// Disconnect is used to remove the ability to route to a given peer.
+func (l *InmemLayer) Disconnect(peer string) {
+	l.l.Lock()
+	defer l.l.Unlock()
+	delete(l.peers, peer)
+
+	// Remove any open connections
+	servConns := l.servConns[peer]
+	for _, c := range servConns {
+		c.Close()
+	}
+	delete(l.servConns, peer)
+
+	clientConns := l.clientConns[peer]
+	for _, c := range clientConns {
+		c.Close()
+	}
+	delete(l.clientConns, peer)
+}
+
+// DisconnectAll is used to remove all routes to peers.
+func (l *InmemLayer) DisconnectAll() {
+	l.l.Lock()
+	defer l.l.Unlock()
+	l.peers = make(map[string]*InmemLayer)
+
+	// Close all connections
+	for _, peerConns := range l.servConns {
+		for _, c := range peerConns {
+			c.Close()
+		}
+	}
+	l.servConns = make(map[string][]net.Conn)
+
+	for _, peerConns := range l.clientConns {
+		for _, c := range peerConns {
+			c.Close()
+		}
+	}
+	l.clientConns = make(map[string][]net.Conn)
+}
+
+// Close is used to permanently disable the transport
+func (l *InmemLayer) Close() error {
+	if l.stopped.Swap(true) {
+		return nil
+	}
+
+	l.DisconnectAll()
+	close(l.stopCh)
+	return nil
+}
+
+// inmemListener implements the NetworkListener interface.
+type inmemListener struct {
+	addr         string
+	pendingConns chan net.Conn
+
+	stopped *atomic.Bool
+	stopCh  chan struct{}
+
+	deadline time.Time
+}
+
+// Accept implements the NetworkListener interface.
+func (ln *inmemListener) Accept() (net.Conn, error) {
+	deadline := ln.deadline
+	if !deadline.IsZero() {
+		select {
+		case conn := <-ln.pendingConns:
+			return conn, nil
+		case <-time.After(time.Until(deadline)):
+			return nil, deadlineError("deadline")
+		case <-ln.stopCh:
+			return nil, errors.New("listener shut down")
+		}
+	}
+
+	select {
+	case conn := <-ln.pendingConns:
+		return conn, nil
+	case <-ln.stopCh:
+		return nil, errors.New("listener shut down")
+	}
+}
+
+// Close implements the NetworkListener interface.
+func (ln *inmemListener) Close() error {
+	if ln.stopped.Swap(true) {
+		return nil
+	}
+
+	close(ln.stopCh)
+	return nil
+}
+
+// Addr implements the NetworkListener interface.
+func (ln *inmemListener) Addr() net.Addr {
+	return inmemAddr{addr: ln.addr}
+}
+
+// SetDeadline implements the NetworkListener interface.
+func (ln *inmemListener) SetDeadline(d time.Time) error {
+	ln.deadline = d
+	return nil
+}
+
+type inmemAddr struct {
+	addr string
+}
+
+func (a inmemAddr) Network() string {
+	return "inmem"
+}
+
+func (a inmemAddr) String() string {
+	return a.addr
+}
+
+type deadlineError string
+
+func (d deadlineError) Error() string   { return string(d) }
+func (d deadlineError) Timeout() bool   { return true }
+func (d deadlineError) Temporary() bool { return true }
+
+// InmemLayerCluster composes a set of layers and handles connecting them all
+// together. It also satisfies the NetworkLayerSet interface.
+type InmemLayerCluster struct {
+	layers []*InmemLayer
+}
+
+// NewInmemLayerCluster returns a new in-memory layer set that builds n nodes
+// and connects them all together.
+func NewInmemLayerCluster(nodes int, logger log.Logger) (*InmemLayerCluster, error) {
+	clusterID, err := base62.Random(4)
+	if err != nil {
+		return nil, err
+	}
+
+	clusterName := "cluster_" + clusterID
+
+	var layers []*InmemLayer
+	for i := 0; i < nodes; i++ {
+		nodeID, err := base62.Random(4)
+		if err != nil {
+			return nil, err
+		}
+
+		nodeName := clusterName + "_node_" + nodeID
+
+		layers = append(layers, NewInmemLayer(nodeName, logger))
+	}
+
+	// Connect all the peers together
+	for _, node := range layers {
+		for _, peer := range layers {
+			// Don't connect to itself
+			if node == peer {
+				continue
+			}
+
+			node.Connect(peer.addr, peer)
+			peer.Connect(node.addr, node)
+		}
+	}
+
+	return &InmemLayerCluster{layers: layers}, nil
+}
+
+// ConnectCluster connects this cluster with the provided remote cluster,
+// connecting all nodes to each other.
+func (ic *InmemLayerCluster) ConnectCluster(remote *InmemLayerCluster) {
+	for _, node := range ic.layers {
+		for _, peer := range remote.layers {
+			node.Connect(peer.addr, peer)
+			peer.Connect(node.addr, node)
+		}
+	}
+}
+
+// Layers implements the NetworkLayerSet interface.
+func (ic *InmemLayerCluster) Layers() []NetworkLayer {
+	ret := make([]NetworkLayer, len(ic.layers))
+	for i, l := range ic.layers {
+		ret[i] = l
+	}
+
+	return ret
+}
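A usage sketch for the in-memory layer, illustrative and not part of the commit: routes are registered per direction with Connect, the callee must have started its listener via Listeners() before a peer dials it, and Dial returns a *tls.Conn over a net.Pipe without handshaking, which is why a nil TLS config is tolerated here just as in the tests below:

package main

import (
	"fmt"

	"github.com/hashicorp/vault/vault/cluster"
)

func main() {
	a := cluster.NewInmemLayer("node-a", nil)
	b := cluster.NewInmemLayer("node-b", nil)

	// Routing is per direction; register both ways, as the cluster helper does.
	a.Connect("node-b", b)
	b.Connect("node-a", a)

	// Listeners() lazily creates b's listener; without it, clientConn fails
	// with "listener not started".
	ln := b.Listeners()[0]

	done := make(chan struct{})
	go func() {
		ln.Accept() // receives the server half of the net.Pipe
		close(done)
	}()

	// No TLS handshake happens at dial time, so a nil config is fine here.
	conn, err := a.Dial("node-b", 0, nil)
	<-done
	fmt.Println(conn != nil, err) // true <nil>
}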
@@ -0,0 +1,240 @@
+package cluster
+
+import (
+	"sync"
+	"testing"
+	"time"
+
+	"go.uber.org/atomic"
+)
+
+func TestInmemCluster_Connect(t *testing.T) {
+	cluster, err := NewInmemLayerCluster(3, nil)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	server := cluster.layers[0]
+
+	listener := server.Listeners()[0]
+	var accepted int
+	stopCh := make(chan struct{})
+	var wg sync.WaitGroup
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+		for {
+			select {
+			case <-stopCh:
+				return
+			default:
+			}
+
+			listener.SetDeadline(time.Now().Add(5 * time.Second))
+
+			_, err := listener.Accept()
+			if err != nil {
+				return
+			}
+
+			accepted++
+		}
+	}()
+
+	// Make sure two nodes can connect in
+	conn, err := cluster.layers[1].Dial(server.addr, 0, nil)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if conn == nil {
+		t.Fatal("nil conn")
+	}
+
+	conn, err = cluster.layers[2].Dial(server.addr, 0, nil)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if conn == nil {
+		t.Fatal("nil conn")
+	}
+
+	close(stopCh)
+	wg.Wait()
+
+	if accepted != 2 {
+		t.Fatalf("expected 2 connections to be accepted, got %d", accepted)
+	}
+}
+
+func TestInmemCluster_Disconnect(t *testing.T) {
+	cluster, err := NewInmemLayerCluster(3, nil)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	server := cluster.layers[0]
+	server.Disconnect(cluster.layers[1].addr)
+
+	listener := server.Listeners()[0]
+	var accepted int
+	stopCh := make(chan struct{})
+	var wg sync.WaitGroup
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+		for {
+			select {
+			case <-stopCh:
+				return
+			default:
+			}
+
+			listener.SetDeadline(time.Now().Add(5 * time.Second))
+
+			_, err := listener.Accept()
+			if err != nil {
+				return
+			}
+
+			accepted++
+		}
+	}()
+
+	// Make sure node1 cannot connect in
+	conn, err := cluster.layers[1].Dial(server.addr, 0, nil)
+	if err == nil {
+		t.Fatal("expected error")
+	}
+
+	if conn != nil {
+		t.Fatal("expected nil conn")
+	}
+
+	// Node2 should be able to connect
+	conn, err = cluster.layers[2].Dial(server.addr, 0, nil)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if conn == nil {
+		t.Fatal("nil conn")
+	}
+
+	close(stopCh)
+	wg.Wait()
+
+	if accepted != 1 {
+		t.Fatalf("expected 1 connections to be accepted, got %d", accepted)
+	}
+}
+
+func TestInmemCluster_DisconnectAll(t *testing.T) {
+	cluster, err := NewInmemLayerCluster(3, nil)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	server := cluster.layers[0]
+	server.DisconnectAll()
+
+	// Make sure nodes cannot connect in
+	conn, err := cluster.layers[1].Dial(server.addr, 0, nil)
+	if err == nil {
+		t.Fatal("expected error")
+	}
+
+	if conn != nil {
+		t.Fatal("expected nil conn")
+	}
+
+	conn, err = cluster.layers[2].Dial(server.addr, 0, nil)
+	if err == nil {
+		t.Fatal("expected error")
+	}
+
+	if conn != nil {
+		t.Fatal("expected nil conn")
+	}
+}
+
+func TestInmemCluster_ConnectCluster(t *testing.T) {
+	cluster, err := NewInmemLayerCluster(3, nil)
+	if err != nil {
+		t.Fatal(err)
+	}
+	cluster2, err := NewInmemLayerCluster(3, nil)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	cluster.ConnectCluster(cluster2)
+
+	var accepted atomic.Int32
+	stopCh := make(chan struct{})
+	var wg sync.WaitGroup
+	acceptConns := func(listener NetworkListener) {
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			for {
+				select {
+				case <-stopCh:
+					return
+				default:
+				}
+
+				listener.SetDeadline(time.Now().Add(5 * time.Second))
+
+				_, err := listener.Accept()
+				if err != nil {
+					return
+				}
+
+				accepted.Add(1)
+			}
+		}()
+	}
+
+	// Start a listener on each node.
+	for _, node := range cluster.layers {
+		acceptConns(node.Listeners()[0])
+	}
+	for _, node := range cluster2.layers {
+		acceptConns(node.Listeners()[0])
+	}
+
+	// Make sure each node can connect to each other
+	for _, node1 := range cluster.layers {
+		for _, node2 := range cluster2.layers {
+			conn, err := node1.Dial(node2.addr, 0, nil)
+			if err != nil {
+				t.Fatal(err)
+			}
+
+			if conn == nil {
+				t.Fatal("nil conn")
+			}
+
+			conn, err = node2.Dial(node1.addr, 0, nil)
+			if err != nil {
+				t.Fatal(err)
+			}
+
+			if conn == nil {
+				t.Fatal("nil conn")
+			}
+		}
+	}
+
+	close(stopCh)
+	wg.Wait()
+
+	if accepted.Load() != 18 {
+		t.Fatalf("expected 18 connections to be accepted, got %d", accepted)
+	}
+}
@@ -0,0 +1,111 @@
+package cluster
+
+import (
+	"crypto/tls"
+	"net"
+	"sync"
+	"time"
+
+	log "github.com/hashicorp/go-hclog"
+	"github.com/hashicorp/go-multierror"
+	"go.uber.org/atomic"
+)
+
+// TCPLayer implements the NetworkLayer interface and uses TCP as the underlying
+// network.
+type TCPLayer struct {
+	listeners []NetworkListener
+	addrs     []*net.TCPAddr
+	logger    log.Logger
+
+	l       sync.Mutex
+	stopped *atomic.Bool
+}
+
+// NewTCPLayer returns a TCPLayer.
+func NewTCPLayer(addrs []*net.TCPAddr, logger log.Logger) *TCPLayer {
+	return &TCPLayer{
+		addrs:   addrs,
+		logger:  logger,
+		stopped: atomic.NewBool(false),
+	}
+}
+
+// Addrs implements NetworkLayer.
+func (l *TCPLayer) Addrs() []net.Addr {
+	l.l.Lock()
+	defer l.l.Unlock()
+
+	if len(l.addrs) == 0 {
+		return nil
+	}
+
+	ret := make([]net.Addr, len(l.addrs))
+	for i, a := range l.addrs {
+		ret[i] = a
+	}
+
+	return ret
+}
+
+// Listeners implements NetworkLayer. It starts a new TCP listener for each
+// configured address.
+func (l *TCPLayer) Listeners() []NetworkListener {
+	l.l.Lock()
+	defer l.l.Unlock()
+
+	if l.listeners != nil {
+		return l.listeners
+	}
+
+	listeners := make([]NetworkListener, len(l.addrs))
+	for i, laddr := range l.addrs {
+		if l.logger.IsInfo() {
+			l.logger.Info("starting listener", "listener_address", laddr)
+		}
+
+		tcpLn, err := net.ListenTCP("tcp", laddr)
+		if err != nil {
+			l.logger.Error("error starting listener", "error", err)
+			continue
+		}
+		if laddr.String() != tcpLn.Addr().String() {
+			// If we listened on port 0, record the port the OS gave us.
+			l.addrs[i] = tcpLn.Addr().(*net.TCPAddr)
+		}
+
+		listeners[i] = tcpLn
+	}
+
+	l.listeners = listeners
+
+	return listeners
+}
+
+// Dial implements the NetworkLayer interface.
+func (l *TCPLayer) Dial(address string, timeout time.Duration, tlsConfig *tls.Config) (*tls.Conn, error) {
+	dialer := &net.Dialer{
+		Timeout: timeout,
+	}
+	return tls.DialWithDialer(dialer, "tcp", address, tlsConfig)
+}
+
+// Close implements the NetworkLayer interface.
+func (l *TCPLayer) Close() error {
+	if l.stopped.Swap(true) {
+		return nil
+	}
+	l.l.Lock()
+	defer l.l.Unlock()
+
+	var retErr *multierror.Error
+	for _, ln := range l.listeners {
+		if err := ln.Close(); err != nil {
+			retErr = multierror.Append(retErr, err)
+		}
+	}
+
+	l.listeners = nil
+
+	return retErr.ErrorOrNil()
+}
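A short sketch of the TCP layer in isolation, illustrative only: binding port 0 and reading back the OS-assigned address, the same behavior startClusterListener relies on when it records the cluster address:

package main

import (
	"fmt"
	"net"

	log "github.com/hashicorp/go-hclog"
	"github.com/hashicorp/vault/vault/cluster"
)

func main() {
	layer := cluster.NewTCPLayer([]*net.TCPAddr{
		{IP: net.ParseIP("127.0.0.1"), Port: 0},
	}, log.Default())
	defer layer.Close()

	lns := layer.Listeners() // binds 127.0.0.1:0; rewrites addrs with the real port
	fmt.Println(len(lns))    // 1
	fmt.Println(layer.Addrs()[0]) // 127.0.0.1:<os-assigned port>
}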
@@ -4,7 +4,6 @@ import (
 	"bytes"
 	"context"
 	"crypto/tls"
-	"crypto/x509"
 	"net/http"
 	"testing"
 	"time"
@@ -15,6 +14,7 @@ import (
 	"github.com/hashicorp/vault/sdk/logical"
 	"github.com/hashicorp/vault/sdk/physical"
 	"github.com/hashicorp/vault/sdk/physical/inmem"
+	"github.com/hashicorp/vault/vault/cluster"
 )

 var (
@@ -100,13 +100,13 @@ func TestCluster_ListenForRequests(t *testing.T) {
 	// Wait for core to become active
 	TestWaitActive(t, cores[0].Core)

-	cores[0].getClusterListener().AddClient(consts.RequestForwardingALPN, &requestForwardingClusterClient{cores[0].Core})
+	clusterListener := cores[0].getClusterListener()
+	clusterListener.AddClient(consts.RequestForwardingALPN, &requestForwardingClusterClient{cores[0].Core})
 	addrs := cores[0].getClusterListener().Addrs()

 	// Use this to have a valid config after sealing since ClusterTLSConfig returns nil
 	checkListenersFunc := func(expectFail bool) {
-		parsedCert := cores[0].localClusterParsedCert.Load().(*x509.Certificate)
-		dialer := cores[0].getGRPCDialer(context.Background(), consts.RequestForwardingALPN, parsedCert.Subject.CommonName, parsedCert)
+		dialer := clusterListener.GetDialerFunc(context.Background(), consts.RequestForwardingALPN)
 		for i := range cores[0].Listeners {

 			clnAddr := addrs[i]
@@ -172,11 +172,25 @@ func TestCluster_ForwardRequests(t *testing.T) {
 	// Make this nicer for tests
 	manualStepDownSleepPeriod = 5 * time.Second

-	testCluster_ForwardRequestsCommon(t)
+	t.Run("tcpLayer", func(t *testing.T) {
+		testCluster_ForwardRequestsCommon(t, nil)
+	})
+
+	t.Run("inmemLayer", func(t *testing.T) {
+		// Run again with in-memory network
+		inmemCluster, err := cluster.NewInmemLayerCluster(3, nil)
+		if err != nil {
+			t.Fatal(err)
+		}
+
+		testCluster_ForwardRequestsCommon(t, &TestClusterOptions{
+			ClusterLayers: inmemCluster,
+		})
+	})
 }

-func testCluster_ForwardRequestsCommon(t *testing.T) {
-	cluster := NewTestCluster(t, nil, nil)
+func testCluster_ForwardRequestsCommon(t *testing.T, clusterOpts *TestClusterOptions) {
+	cluster := NewTestCluster(t, nil, clusterOpts)
 	cores := cluster.Cores
 	cores[0].Handler.(*http.ServeMux).HandleFunc("/core1", func(w http.ResponseWriter, req *http.Request) {
 		w.Header().Add("Content-Type", "application/json")

@@ -493,6 +493,14 @@ type Core struct {
 	secureRandomReader io.Reader

 	recoveryMode bool
+
+	clusterNetworkLayer cluster.NetworkLayer
+
+	// PR1103disabled is used to test upgrade workflows: when set to true,
+	// the correct behaviour for namespaced cubbyholes is disabled, so we
+	// can test an upgrade to a version that includes the fixes from
+	// https://github.com/hashicorp/vault-enterprise/pull/1103
+	PR1103disabled bool
 }

 // CoreConfig is used to parameterize a core
@@ -576,6 +584,8 @@ type CoreConfig struct {
 	CounterSyncInterval time.Duration

 	RecoveryMode bool
+
+	ClusterNetworkLayer cluster.NetworkLayer
 }

 func (c *CoreConfig) Clone() *CoreConfig {
@@ -611,6 +621,7 @@ func (c *CoreConfig) Clone() *CoreConfig {
 		DisableIndexing:     c.DisableIndexing,
 		AllLoggers:          c.AllLoggers,
 		CounterSyncInterval: c.CounterSyncInterval,
+		ClusterNetworkLayer: c.ClusterNetworkLayer,
 	}
 }

@@ -706,6 +717,7 @@ func NewCore(conf *CoreConfig) (*Core, error) {
 		maxLeaseTTL:                  conf.MaxLeaseTTL,
 		cachingDisabled:              conf.DisableCache,
 		clusterName:                  conf.ClusterName,
+		clusterNetworkLayer:          conf.ClusterNetworkLayer,
 		clusterPeerClusterAddrsCache: cache.New(3*cluster.HeartbeatInterval, time.Second),
 		enableMlock:                  !conf.DisableMlock,
 		rawEnabled:                   conf.EnableRaw,

@@ -99,6 +99,19 @@ func (c *requestForwardingClusterClient) ClientLookup(ctx context.Context, reque
 	return nil, nil
 }

+func (c *requestForwardingClusterClient) ServerName() string {
+	parsedCert := c.core.localClusterParsedCert.Load().(*x509.Certificate)
+	if parsedCert == nil {
+		return ""
+	}
+
+	return parsedCert.Subject.CommonName
+}
+
+func (c *requestForwardingClusterClient) CACert(ctx context.Context) *x509.Certificate {
+	return c.core.localClusterParsedCert.Load().(*x509.Certificate)
+}
+
 // ServerLookup satisfies the ClusterHandler interface and returns the server's
 // tls certs.
 func (rf *requestForwardingHandler) ServerLookup(ctx context.Context, clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
@@ -246,19 +259,22 @@ func (c *Core) refreshRequestForwardingConnection(ctx context.Context, clusterAd
 	}

 	clusterListener := c.getClusterListener()
-	if clusterListener != nil {
-		clusterListener.AddClient(consts.RequestForwardingALPN, &requestForwardingClusterClient{
-			core: c,
-		})
+	if clusterListener == nil {
+		c.logger.Error("no cluster listener configured")
+		return nil
 	}

+	clusterListener.AddClient(consts.RequestForwardingALPN, &requestForwardingClusterClient{
+		core: c,
+	})
+
 	// Set up grpc forwarding handling
 	// It's not really insecure, but we have to dial manually to get the
 	// ALPN header right. It's just "insecure" because GRPC isn't managing
 	// the TLS state.
 	dctx, cancelFunc := context.WithCancel(ctx)
 	c.rpcClientConn, err = grpc.DialContext(dctx, clusterURL.Host,
-		grpc.WithDialer(c.getGRPCDialer(ctx, consts.RequestForwardingALPN, parsedCert.Subject.CommonName, parsedCert)),
+		grpc.WithDialer(clusterListener.GetDialerFunc(ctx, consts.RequestForwardingALPN)),
 		grpc.WithInsecure(), // it's not, we handle it in the dialer
 		grpc.WithKeepaliveParams(keepalive.ClientParameters{
 			Time: 2 * cluster.HeartbeatInterval,

@@ -31,6 +31,7 @@ import (
 	hclog "github.com/hashicorp/go-hclog"
 	log "github.com/hashicorp/go-hclog"
 	"github.com/hashicorp/vault/helper/metricsutil"
+	"github.com/hashicorp/vault/vault/cluster"
 	"github.com/hashicorp/vault/vault/seal"
 	"github.com/mitchellh/copystructure"

@@ -1057,7 +1058,11 @@ type TestClusterOptions struct {
 	FirstCoreNumber   int
 	RequireClientAuth bool
 	// SetupFunc is called after the cluster is started.
-	SetupFunc func(t testing.T, c *TestCluster)
+	SetupFunc      func(t testing.T, c *TestCluster)
+	PR1103Disabled bool
+
+	// ClusterLayers are used to override the default cluster connection layer
+	ClusterLayers cluster.NetworkLayerSet
 }

 var DefaultNumCores = 3
@@ -1093,6 +1098,11 @@ func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *Te
 		numCores = opts.NumCores
 	}

+	var disablePR1103 bool
+	if opts != nil && opts.PR1103Disabled {
+		disablePR1103 = true
+	}
+
 	var firstCoreNumber int
 	if opts != nil {
 		firstCoreNumber = opts.FirstCoreNumber
@@ -1486,6 +1496,10 @@ func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *Te
 		}
 	}

+	if opts != nil && opts.ClusterLayers != nil {
+		localConfig.ClusterNetworkLayer = opts.ClusterLayers.Layers()[i]
+	}
+
 	switch {
 	case localConfig.LicensingConfig != nil:
 		if pubKey != nil {
@@ -1506,6 +1520,7 @@ func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *Te
 			t.Fatalf("err: %v", err)
 		}
 		c.coreNumber = firstCoreNumber + i
+		c.PR1103disabled = disablePR1103
 		cores = append(cores, c)
 		coreConfigs = append(coreConfigs, &localConfig)
 		if opts != nil && opts.HandlerFunc != nil {