package cluster import ( "crypto/tls" "errors" "fmt" "net" "sync" "time" log "github.com/hashicorp/go-hclog" "github.com/hashicorp/vault/sdk/helper/base62" "go.uber.org/atomic" ) // InmemLayer is an in-memory implementation of NetworkLayer. This is // primarially useful for tests. type InmemLayer struct { listener *inmemListener addr string logger log.Logger servConns map[string][]net.Conn clientConns map[string][]net.Conn peers map[string]*InmemLayer l sync.Mutex stopped *atomic.Bool stopCh chan struct{} connectionCh chan *ConnectionInfo readerDelay time.Duration } // NewInmemLayer returns a new in-memory layer configured to listen on the // provided address. func NewInmemLayer(addr string, logger log.Logger) *InmemLayer { return &InmemLayer{ addr: addr, logger: logger, stopped: atomic.NewBool(false), stopCh: make(chan struct{}), peers: make(map[string]*InmemLayer), servConns: make(map[string][]net.Conn), clientConns: make(map[string][]net.Conn), } } func (l *InmemLayer) SetConnectionCh(ch chan *ConnectionInfo) { l.l.Lock() l.connectionCh = ch l.l.Unlock() } func (l *InmemLayer) SetReaderDelay(delay time.Duration) { l.l.Lock() defer l.l.Unlock() l.readerDelay = delay // Update the existing server and client connections for _, servConns := range l.servConns { for _, c := range servConns { c.(*delayedConn).SetDelay(delay) } } for _, clientConns := range l.clientConns { for _, c := range clientConns { c.(*delayedConn).SetDelay(delay) } } } // Addrs implements NetworkLayer. func (l *InmemLayer) Addrs() []net.Addr { l.l.Lock() defer l.l.Unlock() if l.listener == nil { return nil } return []net.Addr{l.listener.Addr()} } // Listeners implements NetworkLayer. func (l *InmemLayer) Listeners() []NetworkListener { l.l.Lock() defer l.l.Unlock() if l.listener != nil { return []NetworkListener{l.listener} } l.listener = &inmemListener{ addr: l.addr, pendingConns: make(chan net.Conn), stopped: atomic.NewBool(false), stopCh: make(chan struct{}), } return []NetworkListener{l.listener} } // Dial implements NetworkLayer. func (l *InmemLayer) Dial(addr string, timeout time.Duration, tlsConfig *tls.Config) (*tls.Conn, error) { l.l.Lock() defer l.l.Unlock() if addr == l.addr { panic(fmt.Sprintf("%q attempted to dial itself", l.addr)) } peer, ok := l.peers[addr] if !ok { return nil, errors.New("inmemlayer: no address found") } alpn := "" if tlsConfig != nil { alpn = tlsConfig.NextProtos[0] } if l.logger.IsDebug() { l.logger.Debug("dailing connection", "node", l.addr, "remote", addr, "alpn", alpn) } if l.connectionCh != nil { select { case l.connectionCh <- &ConnectionInfo{ Node: l.addr, Remote: addr, IsServer: false, ALPN: alpn, }: case <-time.After(2 * time.Second): l.logger.Warn("failed to send connection info") } } conn, err := peer.clientConn(l.addr) if err != nil { return nil, err } tlsConn := tls.Client(conn, tlsConfig) l.clientConns[addr] = append(l.clientConns[addr], conn) return tlsConn, nil } // clientConn is executed on a server when a new client connection comes in and // needs to be Accepted. func (l *InmemLayer) clientConn(addr string) (net.Conn, error) { l.l.Lock() defer l.l.Unlock() if l.listener == nil { return nil, errors.New("inmemlayer: listener not started") } _, ok := l.peers[addr] if !ok { return nil, errors.New("inmemlayer: no peer found") } retConn, servConn := net.Pipe() retConn = newDelayedConn(retConn, l.readerDelay) servConn = newDelayedConn(servConn, l.readerDelay) l.servConns[addr] = append(l.servConns[addr], servConn) if l.logger.IsDebug() { l.logger.Debug("received connection", "node", l.addr, "remote", addr) } if l.connectionCh != nil { select { case l.connectionCh <- &ConnectionInfo{ Node: l.addr, Remote: addr, IsServer: true, }: case <-time.After(2 * time.Second): l.logger.Warn("failed to send connection info") } } select { case l.listener.pendingConns <- servConn: case <-time.After(2 * time.Second): return nil, errors.New("inmemlayer: timeout while accepting connection") } return retConn, nil } // Connect is used to connect this transport to another transport for // a given peer name. This allows for local routing. func (l *InmemLayer) Connect(remote *InmemLayer) { l.l.Lock() defer l.l.Unlock() l.peers[remote.addr] = remote } // Disconnect is used to remove the ability to route to a given peer. func (l *InmemLayer) Disconnect(peer string) { l.l.Lock() defer l.l.Unlock() delete(l.peers, peer) // Remove any open connections servConns := l.servConns[peer] for _, c := range servConns { c.Close() } delete(l.servConns, peer) clientConns := l.clientConns[peer] for _, c := range clientConns { c.Close() } delete(l.clientConns, peer) } // DisconnectAll is used to remove all routes to peers. func (l *InmemLayer) DisconnectAll() { l.l.Lock() defer l.l.Unlock() l.peers = make(map[string]*InmemLayer) // Close all connections for _, peerConns := range l.servConns { for _, c := range peerConns { c.Close() } } l.servConns = make(map[string][]net.Conn) for _, peerConns := range l.clientConns { for _, c := range peerConns { c.Close() } } l.clientConns = make(map[string][]net.Conn) } // Close is used to permanently disable the transport func (l *InmemLayer) Close() error { if l.stopped.Swap(true) { return nil } l.DisconnectAll() close(l.stopCh) return nil } // inmemListener implements the NetworkListener interface. type inmemListener struct { addr string pendingConns chan net.Conn stopped *atomic.Bool stopCh chan struct{} deadline time.Time } // Accept implements the NetworkListener interface. func (ln *inmemListener) Accept() (net.Conn, error) { deadline := ln.deadline if !deadline.IsZero() { select { case conn := <-ln.pendingConns: return conn, nil case <-time.After(time.Until(deadline)): return nil, deadlineError("deadline") case <-ln.stopCh: return nil, errors.New("listener shut down") } } select { case conn := <-ln.pendingConns: return conn, nil case <-ln.stopCh: return nil, errors.New("listener shut down") } } // Close implements the NetworkListener interface. func (ln *inmemListener) Close() error { if ln.stopped.Swap(true) { return nil } close(ln.stopCh) return nil } // Addr implements the NetworkListener interface. func (ln *inmemListener) Addr() net.Addr { return inmemAddr{addr: ln.addr} } // SetDeadline implements the NetworkListener interface. func (ln *inmemListener) SetDeadline(d time.Time) error { ln.deadline = d return nil } type inmemAddr struct { addr string } func (a inmemAddr) Network() string { return "inmem" } func (a inmemAddr) String() string { return a.addr } type deadlineError string func (d deadlineError) Error() string { return string(d) } func (d deadlineError) Timeout() bool { return true } func (d deadlineError) Temporary() bool { return true } // InmemLayerCluster composes a set of layers and handles connecting them all // together. It also satisfies the NetworkLayerSet interface. type InmemLayerCluster struct { layers []*InmemLayer } // NewInmemLayerCluster returns a new in-memory layer set that builds n nodes // and connects them all together. func NewInmemLayerCluster(clusterName string, nodes int, logger log.Logger) (*InmemLayerCluster, error) { if clusterName == "" { clusterID, err := base62.Random(4) if err != nil { return nil, err } clusterName = "cluster_" + clusterID } layers := make([]*InmemLayer, nodes) for i := 0; i < nodes; i++ { layers[i] = NewInmemLayer(fmt.Sprintf("%s_node_%d", clusterName, i), logger) } // Connect all the peers together for _, node := range layers { for _, peer := range layers { // Don't connect to itself if node == peer { continue } node.Connect(peer) peer.Connect(node) } } return &InmemLayerCluster{layers: layers}, nil } // ConnectCluster connects this cluster with the provided remote cluster, // connecting all nodes to each other. func (ic *InmemLayerCluster) ConnectCluster(remote *InmemLayerCluster) { for _, node := range ic.layers { for _, peer := range remote.layers { node.Connect(peer) peer.Connect(node) } } } // Layers implements the NetworkLayerSet interface. func (ic *InmemLayerCluster) Layers() []NetworkLayer { ret := make([]NetworkLayer, len(ic.layers)) for i, l := range ic.layers { ret[i] = l } return ret } func (ic *InmemLayerCluster) SetConnectionCh(ch chan *ConnectionInfo) { for _, node := range ic.layers { node.SetConnectionCh(ch) } } func (ic *InmemLayerCluster) SetReaderDelay(delay time.Duration) { for _, node := range ic.layers { node.SetReaderDelay(delay) } } type ConnectionInfo struct { Node string Remote string IsServer bool ALPN string }