From f5b5fbb3920ca86aa96107849c6fd411ce7b4450 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Thu, 14 Feb 2019 18:14:56 -0800 Subject: [PATCH] Refactor the cluster listener (#6232) * Port over OSS cluster port refactor components * Start forwarding * Cleanup a bit * Fix copy error * Return error from perf standby creation * Add some more comments * Fix copy/paste error --- helper/certutil/types.go | 9 + vault/cluster.go | 421 ++++++++++++--- vault/cluster_test.go | 33 +- vault/cluster_tls.go | 85 ---- vault/core.go | 108 ++-- vault/core_util.go | 9 +- vault/ha.go | 3 +- .../cluster.go} | 9 +- vault/request_forwarding.go | 478 +++++++----------- vault/request_forwarding_rpc.go | 3 +- vault/request_forwarding_util.go | 18 - vault/wrapping.go | 3 +- 12 files changed, 633 insertions(+), 546 deletions(-) delete mode 100644 vault/cluster_tls.go rename vault/{replication_cluster_util.go => replication/cluster.go} (60%) delete mode 100644 vault/request_forwarding_util.go diff --git a/helper/certutil/types.go b/helper/certutil/types.go index 9a27a6fb1..06c3b3b11 100644 --- a/helper/certutil/types.go +++ b/helper/certutil/types.go @@ -17,12 +17,21 @@ import ( "crypto/x509" "encoding/pem" "fmt" + "math/big" "strings" "github.com/hashicorp/errwrap" "github.com/hashicorp/vault/helper/errutil" ) +// This can be one of a few key types so the different params may or may not be filled +type ClusterKeyParams struct { + Type string `json:"type" structs:"type" mapstructure:"type"` + X *big.Int `json:"x" structs:"x" mapstructure:"x"` + Y *big.Int `json:"y" structs:"y" mapstructure:"y"` + D *big.Int `json:"d" structs:"d" mapstructure:"d"` +} + // Secret is used to attempt to unmarshal a Vault secret // JSON response, as a convenience type Secret struct { diff --git a/vault/cluster.go b/vault/cluster.go index 5960c3b5d..00445f825 100644 --- a/vault/cluster.go +++ b/vault/cluster.go @@ -15,12 +15,16 @@ import ( mathrand "math/rand" "net" "net/http" + "sync" + "sync/atomic" "time" "github.com/hashicorp/errwrap" + log "github.com/hashicorp/go-hclog" uuid "github.com/hashicorp/go-uuid" "github.com/hashicorp/vault/helper/jsonutil" "github.com/hashicorp/vault/logical" + "golang.org/x/net/http2" ) const ( @@ -44,19 +48,6 @@ type ClusterLeaderParams struct { LeaderClusterAddr string } -type ReplicatedClusters struct { - DR *ReplicatedCluster - Performance *ReplicatedCluster -} - -// This can be one of a few key types so the different params may or may not be filled -type clusterKeyParams struct { - Type string `json:"type" structs:"type" mapstructure:"type"` - X *big.Int `json:"x" structs:"x" mapstructure:"x"` - Y *big.Int `json:"y" structs:"y" mapstructure:"y"` - D *big.Int `json:"d" structs:"d" mapstructure:"d"` -} - // Structure representing the storage entry that holds cluster information type Cluster struct { // Name of the cluster @@ -290,10 +281,297 @@ func (c *Core) setupCluster(ctx context.Context) error { return nil } -// startClusterListener starts cluster request listeners during postunseal. It +// ClusterClient is used to lookup a client certificate. +type ClusterClient interface { + ClientLookup(context.Context, *tls.CertificateRequestInfo) (*tls.Certificate, error) +} + +// ClusterHandler exposes functions for looking up TLS configuration and handing +// off a connection for a cluster listener application. +type ClusterHandler interface { + ServerLookup(context.Context, *tls.ClientHelloInfo) (*tls.Certificate, error) + CALookup(context.Context) (*x509.Certificate, error) + + // Handoff is used to pass the connection lifetime off to + // the storage backend + Handoff(context.Context, *sync.WaitGroup, chan struct{}, *tls.Conn) error + Stop() error +} + +// ClusterListener is the source of truth for cluster handlers and connection +// clients. It dynamically builds the cluster TLS information. It's also +// responsible for starting tcp listeners and accepting new cluster connections. +type ClusterListener struct { + handlers map[string]ClusterHandler + clients map[string]ClusterClient + shutdown *uint32 + shutdownWg *sync.WaitGroup + server *http2.Server + + clusterListenerAddrs []*net.TCPAddr + clusterCipherSuites []uint16 + logger log.Logger + l sync.RWMutex +} + +// AddClient adds a new client for an ALPN name +func (cl *ClusterListener) AddClient(alpn string, client ClusterClient) { + cl.l.Lock() + cl.clients[alpn] = client + cl.l.Unlock() +} + +// RemoveClient removes the client for the specified ALPN name +func (cl *ClusterListener) RemoveClient(alpn string) { + cl.l.Lock() + delete(cl.clients, alpn) + cl.l.Unlock() +} + +// AddHandler registers a new cluster handler for the provided ALPN name. +func (cl *ClusterListener) AddHandler(alpn string, handler ClusterHandler) { + cl.l.Lock() + cl.handlers[alpn] = handler + cl.l.Unlock() +} + +// StopHandler stops the cluster handler for the provided ALPN name, it also +// calls stop on the handler. +func (cl *ClusterListener) StopHandler(alpn string) { + cl.l.Lock() + handler, ok := cl.handlers[alpn] + delete(cl.handlers, alpn) + cl.l.Unlock() + if ok { + handler.Stop() + } +} + +// Server returns the http2 server that the cluster listener is using +func (cl *ClusterListener) Server() *http2.Server { + return cl.server +} + +// TLSConfig returns a tls config object that uses dynamic lookups to correctly +// authenticate registered handlers/clients +func (cl *ClusterListener) TLSConfig(ctx context.Context) (*tls.Config, error) { + serverLookup := func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { + cl.logger.Debug("performing server cert lookup") + + cl.l.RLock() + defer cl.l.RUnlock() + for _, v := range clientHello.SupportedProtos { + if handler, ok := cl.handlers[v]; ok { + return handler.ServerLookup(ctx, clientHello) + } + } + + return nil, errors.New("unsupported protocol") + } + + clientLookup := func(requestInfo *tls.CertificateRequestInfo) (*tls.Certificate, error) { + cl.logger.Debug("performing client cert lookup") + + cl.l.RLock() + defer cl.l.RUnlock() + for _, client := range cl.clients { + cert, err := client.ClientLookup(ctx, requestInfo) + if err == nil && cert != nil { + return cert, nil + } + } + + return nil, errors.New("no client cert found") + } + + serverConfigLookup := func(clientHello *tls.ClientHelloInfo) (*tls.Config, error) { + caPool := x509.NewCertPool() + + ret := &tls.Config{ + ClientAuth: tls.RequireAndVerifyClientCert, + GetCertificate: serverLookup, + GetClientCertificate: clientLookup, + MinVersion: tls.VersionTLS12, + RootCAs: caPool, + ClientCAs: caPool, + NextProtos: clientHello.SupportedProtos, + CipherSuites: cl.clusterCipherSuites, + } + + cl.l.RLock() + defer cl.l.RUnlock() + for _, v := range clientHello.SupportedProtos { + if handler, ok := cl.handlers[v]; ok { + ca, err := handler.CALookup(ctx) + if err != nil { + return nil, err + } + + caPool.AddCert(ca) + return ret, nil + } + } + + return nil, errors.New("unsupported protocol") + } + + return &tls.Config{ + ClientAuth: tls.RequireAndVerifyClientCert, + GetCertificate: serverLookup, + GetClientCertificate: clientLookup, + GetConfigForClient: serverConfigLookup, + MinVersion: tls.VersionTLS12, + CipherSuites: cl.clusterCipherSuites, + }, nil +} + +// Run starts the tcp listeners and will accept connections until stop is +// called. This function blocks so should be called in a go routine. +func (cl *ClusterListener) Run(ctx context.Context) error { + // Get our TLS config + tlsConfig, err := cl.TLSConfig(ctx) + if err != nil { + cl.logger.Error("failed to get tls configuration when starting cluster listener", "error", err) + return err + } + + // The server supports all of the possible protos + tlsConfig.NextProtos = []string{"h2", requestForwardingALPN, perfStandbyALPN, PerformanceReplicationALPN, DRReplicationALPN} + + for _, addr := range cl.clusterListenerAddrs { + cl.shutdownWg.Add(1) + + // Force a local resolution to avoid data races + laddr := addr + + // Start our listening loop + go func() { + defer cl.shutdownWg.Done() + + // closeCh is used to shutdown the spawned goroutines once this + // function returns + closeCh := make(chan struct{}) + defer func() { + close(closeCh) + }() + + if cl.logger.IsInfo() { + cl.logger.Info("starting listener", "listener_address", laddr) + } + + // Create a TCP listener. We do this separately and specifically + // with TCP so that we can set deadlines. + tcpLn, err := net.ListenTCP("tcp", laddr) + if err != nil { + cl.logger.Error("error starting listener", "error", err) + return + } + + // Wrap the listener with TLS + tlsLn := tls.NewListener(tcpLn, tlsConfig) + defer tlsLn.Close() + + if cl.logger.IsInfo() { + cl.logger.Info("serving cluster requests", "cluster_listen_address", tlsLn.Addr()) + } + + for { + if atomic.LoadUint32(cl.shutdown) > 0 { + return + } + + // Set the deadline for the accept call. If it passes we'll get + // an error, causing us to check the condition at the top + // again. + tcpLn.SetDeadline(time.Now().Add(clusterListenerAcceptDeadline)) + + // Accept the connection + conn, err := tlsLn.Accept() + if err != nil { + if err, ok := err.(net.Error); ok && !err.Timeout() { + cl.logger.Debug("non-timeout error accepting on cluster port", "error", err) + } + if conn != nil { + conn.Close() + } + continue + } + if conn == nil { + continue + } + + // Type assert to TLS connection and handshake to populate the + // connection state + tlsConn := conn.(*tls.Conn) + + // Set a deadline for the handshake. This will cause clients + // that don't successfully auth to be kicked out quickly. + // Cluster connections should be reliable so being marginally + // aggressive here is fine. + err = tlsConn.SetDeadline(time.Now().Add(30 * time.Second)) + if err != nil { + if cl.logger.IsDebug() { + cl.logger.Debug("error setting deadline for cluster connection", "error", err) + } + tlsConn.Close() + continue + } + + err = tlsConn.Handshake() + if err != nil { + if cl.logger.IsDebug() { + cl.logger.Debug("error handshaking cluster connection", "error", err) + } + tlsConn.Close() + continue + } + + // Now, set it back to unlimited + err = tlsConn.SetDeadline(time.Time{}) + if err != nil { + if cl.logger.IsDebug() { + cl.logger.Debug("error setting deadline for cluster connection", "error", err) + } + tlsConn.Close() + continue + } + + cl.l.RLock() + handler, ok := cl.handlers[tlsConn.ConnectionState().NegotiatedProtocol] + cl.l.RUnlock() + if !ok { + cl.logger.Debug("unknown negotiated protocol on cluster port") + tlsConn.Close() + continue + } + + if err := handler.Handoff(ctx, cl.shutdownWg, closeCh, tlsConn); err != nil { + cl.logger.Error("error handling cluster connection", "error", err) + continue + } + } + }() + } + + return nil +} + +// Stop stops the cluster listner +func (cl *ClusterListener) Stop() { + // Set the shutdown flag. This will cause the listeners to shut down + // within the deadline in clusterListenerAcceptDeadline + atomic.StoreUint32(cl.shutdown, 1) + cl.logger.Info("forwarding rpc listeners stopped") + + // Wait for them all to shut down + cl.shutdownWg.Wait() + cl.logger.Info("rpc listeners successfully shut down") +} + +// startClusterListener starts cluster request listeners during unseal. It // is assumed that the state lock is held while this is run. Right now this -// only starts forwarding listeners; it's TBD whether other request types will -// be built in the same mechanism or started independently. +// only starts cluster listeners. Once the listener is started handlers/clients +// can start being registered to it. func (c *Core) startClusterListener(ctx context.Context) error { if c.clusterAddr == "" { c.logger.Info("clustering disabled, not starting listeners") @@ -307,76 +585,46 @@ func (c *Core) startClusterListener(ctx context.Context) error { c.logger.Debug("starting cluster listeners") - err := c.startForwarding(ctx) - if err != nil { - return err + // Create the HTTP/2 server that will be shared by both RPC and regular + // duties. Doing it this way instead of listening via the server and gRPC + // allows us to re-use the same port via ALPN. We can just tell the server + // to serve a given conn and which handler to use. + h2Server := &http2.Server{ + // Our forwarding connections heartbeat regularly so anything else we + // want to go away/get cleaned up pretty rapidly + IdleTimeout: 5 * HeartbeatInterval, } - return nil + c.clusterListener = &ClusterListener{ + handlers: make(map[string]ClusterHandler), + clients: make(map[string]ClusterClient), + shutdown: new(uint32), + shutdownWg: &sync.WaitGroup{}, + server: h2Server, + + clusterListenerAddrs: c.clusterListenerAddrs, + clusterCipherSuites: c.clusterCipherSuites, + logger: c.logger.Named("cluster-listener"), + } + + return c.clusterListener.Run(ctx) } -// stopClusterListener stops any existing listeners during preseal. It is +// stopClusterListener stops any existing listeners during seal. It is // assumed that the state lock is held while this is run. func (c *Core) stopClusterListener() { - if c.clusterAddr == "" { - + if c.clusterListener == nil { c.logger.Debug("clustering disabled, not stopping listeners") return } - if !c.clusterListenersRunning { - c.logger.Info("cluster listeners not running") - return - } c.logger.Info("stopping cluster listeners") - // Tell the goroutine managing the listeners to perform the shutdown - // process - c.clusterListenerShutdownCh <- struct{}{} - - // The reason for this loop-de-loop is that we may be unsealing again - // quickly, and if the listeners are not yet closed, we will get socket - // bind errors. This ensures proper ordering. - - c.logger.Debug("waiting for success notification while stopping cluster listeners") - <-c.clusterListenerShutdownSuccessCh - c.clusterListenersRunning = false + c.clusterListener.Stop() c.logger.Info("cluster listeners successfully shut down") } -// ClusterTLSConfig generates a TLS configuration based on the local/replicated -// cluster key and cert. -func (c *Core) ClusterTLSConfig(ctx context.Context, repClusters *ReplicatedClusters, perfStandbyCluster *ReplicatedCluster) (*tls.Config, error) { - // Using lookup functions allows just-in-time lookup of the current state - // of clustering as connections come and go - - tlsConfig := &tls.Config{ - ClientAuth: tls.RequireAndVerifyClientCert, - GetCertificate: clusterTLSServerLookup(ctx, c, repClusters, perfStandbyCluster), - GetClientCertificate: clusterTLSClientLookup(ctx, c, repClusters, perfStandbyCluster), - GetConfigForClient: clusterTLSServerConfigLookup(ctx, c, repClusters, perfStandbyCluster), - MinVersion: tls.VersionTLS12, - CipherSuites: c.clusterCipherSuites, - } - - parsedCert := c.localClusterParsedCert.Load().(*x509.Certificate) - currCert := c.localClusterCert.Load().([]byte) - localCert := make([]byte, len(currCert)) - copy(localCert, currCert) - - if parsedCert != nil { - tlsConfig.ServerName = parsedCert.Subject.CommonName - - pool := x509.NewCertPool() - pool.AddCert(parsedCert) - tlsConfig.RootCAs = pool - tlsConfig.ClientCAs = pool - } - - return tlsConfig, nil -} - func (c *Core) SetClusterListenerAddrs(addrs []*net.TCPAddr) { c.clusterListenerAddrs = addrs if c.clusterAddr == "" && len(addrs) == 1 { @@ -387,3 +635,36 @@ func (c *Core) SetClusterListenerAddrs(addrs []*net.TCPAddr) { func (c *Core) SetClusterHandler(handler http.Handler) { c.clusterHandler = handler } + +// getGRPCDialer is used to return a dialer that has the correct TLS +// configuration. Otherwise gRPC tries to be helpful and stomps all over our +// NextProtos. +func (c *Core) getGRPCDialer(ctx context.Context, alpnProto, serverName string, caCert *x509.Certificate) func(string, time.Duration) (net.Conn, error) { + return func(addr string, timeout time.Duration) (net.Conn, error) { + if c.clusterListener == nil { + return nil, errors.New("clustering disabled") + } + + tlsConfig, err := c.clusterListener.TLSConfig(ctx) + if err != nil { + c.logger.Error("failed to get tls configuration", "error", err) + return nil, err + } + if serverName != "" { + tlsConfig.ServerName = serverName + } + if caCert != nil { + pool := x509.NewCertPool() + pool.AddCert(caCert) + tlsConfig.RootCAs = pool + tlsConfig.ClientCAs = pool + } + c.logger.Debug("creating rpc dialer", "host", tlsConfig.ServerName) + + tlsConfig.NextProtos = []string{alpnProto} + dialer := &net.Dialer{ + Timeout: timeout, + } + return tls.DialWithDialer(dialer, "tcp", addr, tlsConfig) + } +} diff --git a/vault/cluster_test.go b/vault/cluster_test.go index 949670d27..3f348ac18 100644 --- a/vault/cluster_test.go +++ b/vault/cluster_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "crypto/tls" + "crypto/x509" "fmt" "net" "net/http" @@ -102,32 +103,25 @@ func TestCluster_ListenForRequests(t *testing.T) { TestWaitActive(t, cores[0].Core) // Use this to have a valid config after sealing since ClusterTLSConfig returns nil - var lastTLSConfig *tls.Config checkListenersFunc := func(expectFail bool) { - tlsConfig, err := cores[0].ClusterTLSConfig(context.Background(), nil, nil) - if err != nil { - if err.Error() != consts.ErrSealed.Error() { - t.Fatal(err) - } - tlsConfig = lastTLSConfig - } else { - tlsConfig.NextProtos = []string{"h2"} - lastTLSConfig = tlsConfig - } + cores[0].clusterListener.AddClient(requestForwardingALPN, &requestForwardingClusterClient{cores[0].Core}) + parsedCert := cores[0].localClusterParsedCert.Load().(*x509.Certificate) + dialer := cores[0].getGRPCDialer(context.Background(), requestForwardingALPN, parsedCert.Subject.CommonName, parsedCert) for _, ln := range cores[0].Listeners { tcpAddr, ok := ln.Addr().(*net.TCPAddr) if !ok { t.Fatalf("%s not a TCP port", tcpAddr.String()) } - conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", tcpAddr.IP.String(), tcpAddr.Port+105), tlsConfig) + netConn, err := dialer(fmt.Sprintf("%s:%d", tcpAddr.IP.String(), tcpAddr.Port+105), 0) + conn := netConn.(*tls.Conn) if err != nil { if expectFail { t.Logf("testing %s:%d unsuccessful as expected", tcpAddr.IP.String(), tcpAddr.Port+105) continue } - t.Fatalf("error: %v\nlisteners are\n%#v\n%#v\n", err, cores[0].Listeners[0], cores[0].Listeners[1]) + t.Fatalf("error: %v\nlisteners are\n%#v\n%#v\n", err, cores[0].Listeners[0], cores[0].Listeners[0]) } if expectFail { t.Fatalf("testing %s:%d not unsuccessful as expected", tcpAddr.IP.String(), tcpAddr.Port+105) @@ -140,7 +134,7 @@ func TestCluster_ListenForRequests(t *testing.T) { switch { case connState.Version != tls.VersionTLS12: t.Fatal("version mismatch") - case connState.NegotiatedProtocol != "h2" || !connState.NegotiatedProtocolIsMutual: + case connState.NegotiatedProtocol != requestForwardingALPN || !connState.NegotiatedProtocolIsMutual: t.Fatal("bad protocol negotiation") } t.Logf("testing %s:%d successful", tcpAddr.IP.String(), tcpAddr.Port+105) @@ -392,12 +386,13 @@ func TestCluster_CustomCipherSuites(t *testing.T) { // Wait for core to become active TestWaitActive(t, core.Core) - tlsConf, err := core.Core.ClusterTLSConfig(context.Background(), nil, nil) - if err != nil { - t.Fatal(err) - } + core.clusterListener.AddClient(requestForwardingALPN, &requestForwardingClusterClient{core.Core}) - conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", core.Listeners[0].Address.IP.String(), core.Listeners[0].Address.Port+105), tlsConf) + parsedCert := core.localClusterParsedCert.Load().(*x509.Certificate) + dialer := core.getGRPCDialer(context.Background(), requestForwardingALPN, parsedCert.Subject.CommonName, parsedCert) + + netConn, err := dialer(fmt.Sprintf("%s:%d", core.Listeners[0].Address.IP.String(), core.Listeners[0].Address.Port+105), 0) + conn := netConn.(*tls.Conn) if err != nil { t.Fatal(err) } diff --git a/vault/cluster_tls.go b/vault/cluster_tls.go deleted file mode 100644 index 4a63ecfa3..000000000 --- a/vault/cluster_tls.go +++ /dev/null @@ -1,85 +0,0 @@ -package vault - -import ( - "context" - "crypto/ecdsa" - "crypto/tls" - "crypto/x509" - "fmt" -) - -var ( - clusterTLSServerLookup = func(ctx context.Context, c *Core, repClusters *ReplicatedClusters, _ *ReplicatedCluster) func(*tls.ClientHelloInfo) (*tls.Certificate, error) { - return func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { - c.logger.Debug("performing server cert lookup") - - switch { - default: - currCert := c.localClusterCert.Load().([]byte) - if len(currCert) == 0 { - return nil, fmt.Errorf("got forwarding connection but no local cert") - } - - localCert := make([]byte, len(currCert)) - copy(localCert, currCert) - - return &tls.Certificate{ - Certificate: [][]byte{localCert}, - PrivateKey: c.localClusterPrivateKey.Load().(*ecdsa.PrivateKey), - Leaf: c.localClusterParsedCert.Load().(*x509.Certificate), - }, nil - } - } - } - - clusterTLSClientLookup = func(ctx context.Context, c *Core, repClusters *ReplicatedClusters, _ *ReplicatedCluster) func(*tls.CertificateRequestInfo) (*tls.Certificate, error) { - return func(requestInfo *tls.CertificateRequestInfo) (*tls.Certificate, error) { - if len(requestInfo.AcceptableCAs) != 1 { - return nil, fmt.Errorf("expected only a single acceptable CA") - } - - currCert := c.localClusterCert.Load().([]byte) - if len(currCert) == 0 { - return nil, fmt.Errorf("forwarding connection client but no local cert") - } - - localCert := make([]byte, len(currCert)) - copy(localCert, currCert) - - return &tls.Certificate{ - Certificate: [][]byte{localCert}, - PrivateKey: c.localClusterPrivateKey.Load().(*ecdsa.PrivateKey), - Leaf: c.localClusterParsedCert.Load().(*x509.Certificate), - }, nil - } - } - - clusterTLSServerConfigLookup = func(ctx context.Context, c *Core, repClusters *ReplicatedClusters, repCluster *ReplicatedCluster) func(clientHello *tls.ClientHelloInfo) (*tls.Config, error) { - return func(clientHello *tls.ClientHelloInfo) (*tls.Config, error) { - //c.logger.Trace("performing server config lookup") - - caPool := x509.NewCertPool() - - ret := &tls.Config{ - ClientAuth: tls.RequireAndVerifyClientCert, - GetCertificate: clusterTLSServerLookup(ctx, c, repClusters, repCluster), - GetClientCertificate: clusterTLSClientLookup(ctx, c, repClusters, repCluster), - MinVersion: tls.VersionTLS12, - RootCAs: caPool, - ClientCAs: caPool, - NextProtos: clientHello.SupportedProtos, - CipherSuites: c.clusterCipherSuites, - } - - parsedCert := c.localClusterParsedCert.Load().(*x509.Certificate) - - if parsedCert == nil { - return nil, fmt.Errorf("forwarding connection client but no local cert") - } - - caPool.AddCert(parsedCert) - - return ret, nil - } - } -) diff --git a/vault/core.go b/vault/core.go index 8807af66b..77064c84b 100644 --- a/vault/core.go +++ b/vault/core.go @@ -26,6 +26,7 @@ import ( "github.com/hashicorp/errwrap" "github.com/hashicorp/vault/audit" + "github.com/hashicorp/vault/helper/certutil" "github.com/hashicorp/vault/helper/consts" "github.com/hashicorp/vault/helper/jsonutil" "github.com/hashicorp/vault/helper/logging" @@ -132,10 +133,10 @@ func (e *ErrInvalidKey) Error() string { type RegisterAuthFunc func(context.Context, time.Duration, string, *logical.Auth) error type activeAdvertisement struct { - RedirectAddr string `json:"redirect_addr"` - ClusterAddr string `json:"cluster_addr,omitempty"` - ClusterCert []byte `json:"cluster_cert,omitempty"` - ClusterKeyParams *clusterKeyParams `json:"cluster_key_params,omitempty"` + RedirectAddr string `json:"redirect_addr"` + ClusterAddr string `json:"cluster_addr,omitempty"` + ClusterCert []byte `json:"cluster_cert,omitempty"` + ClusterKeyParams *certutil.ClusterKeyParams `json:"cluster_key_params,omitempty"` } type unlockInformation struct { @@ -328,14 +329,6 @@ type Core struct { clusterListenerAddrs []*net.TCPAddr // The handler to use for request forwarding clusterHandler http.Handler - // Tracks whether cluster listeners are running, e.g. it's safe to send a - // shutdown down the channel - clusterListenersRunning bool - // Shutdown channel for the cluster listeners - clusterListenerShutdownCh chan struct{} - // Shutdown success channel. We need this to be done serially to ensure - // that binds are removed before they might be reinstated. - clusterListenerShutdownSuccessCh chan struct{} // Write lock used to ensure that we don't have multiple connections adjust // this value at the same time requestForwardingConnectionLock sync.RWMutex @@ -346,8 +339,6 @@ type Core struct { clusterLeaderParams *atomic.Value // Info on cluster members clusterPeerClusterAddrsCache *cache.Cache - // Stores whether we currently have a server running - rpcServerActive *uint32 // The context for the client rpcClientConnContext context.Context // The function for canceling the client connection @@ -420,7 +411,10 @@ type Core struct { // loadCaseSensitiveIdentityStore enforces the loading of identity store // artifacts in a case sensitive manner. To be used only in testing. - loadCaseSensitiveIdentityStore bool + loadCaseSensitiveIdentityStore bool + + // clusterListener starts up and manages connections on the cluster ports + clusterListener *ClusterListener // Telemetry objects metricsHelper *metricsutil.MetricsHelper @@ -567,43 +561,40 @@ func NewCore(conf *CoreConfig) (*Core, error) { // Setup the core c := &Core{ - entCore: entCore{}, - devToken: conf.DevToken, - physical: conf.Physical, - redirectAddr: conf.RedirectAddr, - clusterAddr: conf.ClusterAddr, - seal: conf.Seal, - router: NewRouter(), - sealed: new(uint32), - standby: true, - baseLogger: conf.Logger, - logger: conf.Logger.Named("core"), - defaultLeaseTTL: conf.DefaultLeaseTTL, - maxLeaseTTL: conf.MaxLeaseTTL, - cachingDisabled: conf.DisableCache, - clusterName: conf.ClusterName, - clusterListenerShutdownCh: make(chan struct{}), - clusterListenerShutdownSuccessCh: make(chan struct{}), - clusterPeerClusterAddrsCache: cache.New(3*HeartbeatInterval, time.Second), - enableMlock: !conf.DisableMlock, - rawEnabled: conf.EnableRaw, - replicationState: new(uint32), - rpcServerActive: new(uint32), - atomicPrimaryClusterAddrs: new(atomic.Value), - atomicPrimaryFailoverAddrs: new(atomic.Value), - localClusterPrivateKey: new(atomic.Value), - localClusterCert: new(atomic.Value), - localClusterParsedCert: new(atomic.Value), - activeNodeReplicationState: new(uint32), - keepHALockOnStepDown: new(uint32), - replicationFailure: new(uint32), - disablePerfStandby: true, - activeContextCancelFunc: new(atomic.Value), - allLoggers: conf.AllLoggers, - builtinRegistry: conf.BuiltinRegistry, - neverBecomeActive: new(uint32), - clusterLeaderParams: new(atomic.Value), - metricsHelper: conf.MetricsHelper, + entCore: entCore{}, + devToken: conf.DevToken, + physical: conf.Physical, + redirectAddr: conf.RedirectAddr, + clusterAddr: conf.ClusterAddr, + seal: conf.Seal, + router: NewRouter(), + sealed: new(uint32), + standby: true, + baseLogger: conf.Logger, + logger: conf.Logger.Named("core"), + defaultLeaseTTL: conf.DefaultLeaseTTL, + maxLeaseTTL: conf.MaxLeaseTTL, + cachingDisabled: conf.DisableCache, + clusterName: conf.ClusterName, + clusterPeerClusterAddrsCache: cache.New(3*HeartbeatInterval, time.Second), + enableMlock: !conf.DisableMlock, + rawEnabled: conf.EnableRaw, + replicationState: new(uint32), + atomicPrimaryClusterAddrs: new(atomic.Value), + atomicPrimaryFailoverAddrs: new(atomic.Value), + localClusterPrivateKey: new(atomic.Value), + localClusterCert: new(atomic.Value), + localClusterParsedCert: new(atomic.Value), + activeNodeReplicationState: new(uint32), + keepHALockOnStepDown: new(uint32), + replicationFailure: new(uint32), + disablePerfStandby: true, + activeContextCancelFunc: new(atomic.Value), + allLoggers: conf.AllLoggers, + builtinRegistry: conf.BuiltinRegistry, + neverBecomeActive: new(uint32), + clusterLeaderParams: new(atomic.Value), + metricsHelper: conf.MetricsHelper, } atomic.StoreUint32(c.sealed, 1) @@ -1043,6 +1034,10 @@ func (c *Core) unsealInternal(ctx context.Context, masterKey []byte) (bool, erro return false, err } + if err := c.startClusterListener(ctx); err != nil { + return false, err + } + // Do post-unseal setup if HA is not enabled if c.ha == nil { // We still need to set up cluster info even if it's not part of a @@ -1365,6 +1360,9 @@ func (c *Core) sealInternalWithOptions(grabStateLock, keepHALock bool) error { c.logger.Debug("runStandby done") } + // Stop the cluster listener + c.stopClusterListener() + c.logger.Debug("sealing barrier") if err := c.barrier.Seal(); err != nil { c.logger.Error("error sealing barrier", "error", err) @@ -1461,8 +1459,8 @@ func (s standardUnsealStrategy) unseal(ctx context.Context, logger log.Logger, c c.auditBroker = NewAuditBroker(c.logger) } - if c.ha != nil || shouldStartClusterListener(c) { - if err := c.startClusterListener(ctx); err != nil { + if c.clusterListener != nil && (c.ha != nil || shouldStartClusterListener(c)) { + if err := c.startForwarding(ctx); err != nil { return err } } @@ -1553,7 +1551,7 @@ func (c *Core) preSeal() error { } c.clusterParamsLock.Unlock() - c.stopClusterListener() + c.stopForwarding() if err := c.teardownAudits(); err != nil { result = multierror.Append(result, errwrap.Wrapf("error tearing down audits: {{err}}", err)) diff --git a/vault/core_util.go b/vault/core_util.go index af3fff1ae..eddc924dc 100644 --- a/vault/core_util.go +++ b/vault/core_util.go @@ -4,11 +4,14 @@ package vault import ( "context" + "time" "github.com/hashicorp/vault/helper/license" "github.com/hashicorp/vault/helper/namespace" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/physical" + "github.com/hashicorp/vault/vault/replication" + cache "github.com/patrickmn/go-cache" ) type entCore struct{} @@ -89,7 +92,7 @@ func (c *Core) namepaceByPath(string) *namespace.Namespace { return namespace.RootNamespace } -func (c *Core) setupReplicatedClusterPrimary(*ReplicatedCluster) error { return nil } +func (c *Core) setupReplicatedClusterPrimary(*replication.Cluster) error { return nil } func (c *Core) perfStandbyCount() int { return 0 } @@ -104,3 +107,7 @@ func (c *Core) checkReplicatedFiltering(context.Context, *MountEntry, string) (b func (c *Core) invalidateSentinelPolicy(PolicyType, string) {} func (c *Core) removePerfStandbySecondary(context.Context, string) {} + +func (c *Core) perfStandbyClusterHandler() (*replication.Cluster, *cache.Cache, chan struct{}, error) { + return nil, cache.New(2*HeartbeatInterval, 1*time.Second), make(chan struct{}), nil +} diff --git a/vault/ha.go b/vault/ha.go index fc998132b..791167522 100644 --- a/vault/ha.go +++ b/vault/ha.go @@ -14,6 +14,7 @@ import ( multierror "github.com/hashicorp/go-multierror" uuid "github.com/hashicorp/go-uuid" "github.com/hashicorp/vault/audit" + "github.com/hashicorp/vault/helper/certutil" "github.com/hashicorp/vault/helper/consts" "github.com/hashicorp/vault/helper/jsonutil" "github.com/hashicorp/vault/helper/namespace" @@ -817,7 +818,7 @@ func (c *Core) advertiseLeader(ctx context.Context, uuid string, leaderLostCh <- return fmt.Errorf("unknown cluster private key type %T", c.localClusterPrivateKey.Load()) } - keyParams := &clusterKeyParams{ + keyParams := &certutil.ClusterKeyParams{ Type: corePrivateKeyTypeP521, X: key.X, Y: key.Y, diff --git a/vault/replication_cluster_util.go b/vault/replication/cluster.go similarity index 60% rename from vault/replication_cluster_util.go rename to vault/replication/cluster.go index 013cc8f70..20d445510 100644 --- a/vault/replication_cluster_util.go +++ b/vault/replication/cluster.go @@ -1,11 +1,16 @@ // +build !enterprise -package vault +package replication import "github.com/hashicorp/vault/helper/consts" -type ReplicatedCluster struct { +type Cluster struct { State consts.ReplicationState ClusterID string PrimaryClusterAddr string } + +type Clusters struct { + DR *Cluster + Performance *Cluster +} diff --git a/vault/request_forwarding.go b/vault/request_forwarding.go index ff0eb5fd4..ad8c6d42f 100644 --- a/vault/request_forwarding.go +++ b/vault/request_forwarding.go @@ -1,23 +1,23 @@ package vault import ( + "bytes" "context" + "crypto/ecdsa" "crypto/tls" "crypto/x509" + "errors" "fmt" math "math" - "net" "net/http" "net/url" "sync" - "sync/atomic" "time" - cache "github.com/patrickmn/go-cache" - - uuid "github.com/hashicorp/go-uuid" - "github.com/hashicorp/vault/helper/consts" + log "github.com/hashicorp/go-hclog" "github.com/hashicorp/vault/helper/forwarding" + "github.com/hashicorp/vault/vault/replication" + cache "github.com/patrickmn/go-cache" "golang.org/x/net/http2" "google.golang.org/grpc" "google.golang.org/grpc/keepalive" @@ -44,11 +44,154 @@ var ( HeartbeatInterval = 5 * time.Second ) -type SecondaryConnsCacheVals struct { - ID string - Token string - Connection net.Conn - Mode consts.ReplicationState +type requestForwardingHandler struct { + fws *http2.Server + fwRPCServer *grpc.Server + logger log.Logger + ha bool + core *Core + stopCh chan struct{} +} + +type requestForwardingClusterClient struct { + core *Core +} + +// NewRequestForwardingHandler creates a cluster handler for use with request +// forwarding. +func NewRequestForwardingHandler(c *Core, fws *http2.Server, perfStandbySlots chan struct{}, perfStandbyRepCluster *replication.Cluster, perfStandbyCache *cache.Cache) (*requestForwardingHandler, error) { + // Resolve locally to avoid races + ha := c.ha != nil + + fwRPCServer := grpc.NewServer( + grpc.KeepaliveParams(keepalive.ServerParameters{ + Time: 2 * HeartbeatInterval, + }), + grpc.MaxRecvMsgSize(math.MaxInt32), + grpc.MaxSendMsgSize(math.MaxInt32), + ) + + if ha && c.clusterHandler != nil { + RegisterRequestForwardingServer(fwRPCServer, &forwardedRequestRPCServer{ + core: c, + handler: c.clusterHandler, + perfStandbySlots: perfStandbySlots, + perfStandbyRepCluster: perfStandbyRepCluster, + perfStandbyCache: perfStandbyCache, + }) + } + + return &requestForwardingHandler{ + fws: fws, + fwRPCServer: fwRPCServer, + ha: ha, + logger: c.logger.Named("request-forward"), + core: c, + stopCh: make(chan struct{}), + }, nil +} + +// ClientLookup satisfies the ClusterClient interface and returns the ha tls +// client certs. +func (c *requestForwardingClusterClient) ClientLookup(ctx context.Context, requestInfo *tls.CertificateRequestInfo) (*tls.Certificate, error) { + parsedCert := c.core.localClusterParsedCert.Load().(*x509.Certificate) + if parsedCert == nil { + return nil, nil + } + currCert := c.core.localClusterCert.Load().([]byte) + if len(currCert) == 0 { + return nil, nil + } + localCert := make([]byte, len(currCert)) + copy(localCert, currCert) + + for _, subj := range requestInfo.AcceptableCAs { + if bytes.Equal(subj, parsedCert.RawIssuer) { + return &tls.Certificate{ + Certificate: [][]byte{localCert}, + PrivateKey: c.core.localClusterPrivateKey.Load().(*ecdsa.PrivateKey), + Leaf: c.core.localClusterParsedCert.Load().(*x509.Certificate), + }, nil + } + } + + return nil, nil +} + +// ServerLookup satisfies the ClusterHandler interface and returns the server's +// tls certs. +func (rf *requestForwardingHandler) ServerLookup(ctx context.Context, clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { + currCert := rf.core.localClusterCert.Load().([]byte) + if len(currCert) == 0 { + return nil, fmt.Errorf("got forwarding connection but no local cert") + } + + localCert := make([]byte, len(currCert)) + copy(localCert, currCert) + + return &tls.Certificate{ + Certificate: [][]byte{localCert}, + PrivateKey: rf.core.localClusterPrivateKey.Load().(*ecdsa.PrivateKey), + Leaf: rf.core.localClusterParsedCert.Load().(*x509.Certificate), + }, nil +} + +// CALookup satisfies the ClusterHandler interface and returns the ha ca cert. +func (rf *requestForwardingHandler) CALookup(ctx context.Context) (*x509.Certificate, error) { + parsedCert := rf.core.localClusterParsedCert.Load().(*x509.Certificate) + + if parsedCert == nil { + return nil, fmt.Errorf("forwarding connection client but no local cert") + } + + return parsedCert, nil +} + +// Handoff serves a request forwarding connection. +func (rf *requestForwardingHandler) Handoff(ctx context.Context, shutdownWg *sync.WaitGroup, closeCh chan struct{}, tlsConn *tls.Conn) error { + if !rf.ha { + tlsConn.Close() + return nil + } + + rf.logger.Debug("got request forwarding connection") + + shutdownWg.Add(2) + // quitCh is used to close the connection and the second + // goroutine if the server closes before closeCh. + quitCh := make(chan struct{}) + go func() { + select { + case <-quitCh: + case <-closeCh: + case <-rf.stopCh: + } + tlsConn.Close() + shutdownWg.Done() + }() + + go func() { + rf.fws.ServeConn(tlsConn, &http2.ServeConnOpts{ + Handler: rf.fwRPCServer, + BaseConfig: &http.Server{ + ErrorLog: rf.logger.StandardLogger(nil), + }, + }) + + // close the quitCh which will close the connection and + // the other goroutine. + close(quitCh) + shutdownWg.Done() + }() + + return nil +} + +// Stop stops the request forwarding server and closes connections. +func (rf *requestForwardingHandler) Stop() error { + close(rf.stopCh) + rf.fwRPCServer.Stop() + return nil } // Starts the listeners and servers necessary to handle forwarded requests @@ -62,269 +205,32 @@ func (c *Core) startForwarding(ctx context.Context) error { c.requestForwardingConnectionLock.Unlock() // Resolve locally to avoid races - ha := c.ha != nil - - var perfStandbyRepCluster *ReplicatedCluster - if ha { - id, err := uuid.GenerateUUID() - if err != nil { - return err - } - - perfStandbyRepCluster = &ReplicatedCluster{ - State: consts.ReplicationPerformanceStandby, - ClusterID: id, - PrimaryClusterAddr: c.clusterAddr, - } - if err = c.setupReplicatedClusterPrimary(perfStandbyRepCluster); err != nil { - return err - } - } - - // Get our TLS config - tlsConfig, err := c.ClusterTLSConfig(ctx, nil, perfStandbyRepCluster) - if err != nil { - c.logger.Error("failed to get tls configuration when starting forwarding", "error", err) - return err - } - - // The server supports all of the possible protos - tlsConfig.NextProtos = []string{"h2", requestForwardingALPN, perfStandbyALPN, PerformanceReplicationALPN, DRReplicationALPN} - - if !atomic.CompareAndSwapUint32(c.rpcServerActive, 0, 1) { - c.logger.Warn("forwarding rpc server already running") + if c.ha == nil || c.clusterListener == nil { return nil } - fwRPCServer := grpc.NewServer( - grpc.KeepaliveParams(keepalive.ServerParameters{ - Time: 2 * HeartbeatInterval, - }), - grpc.MaxRecvMsgSize(math.MaxInt32), - grpc.MaxSendMsgSize(math.MaxInt32), - ) - - // Setup performance standby RPC servers - perfStandbyCount := 0 - if !c.IsDRSecondary() && !c.disablePerfStandby { - perfStandbyCount = c.perfStandbyCount() - } - perfStandbySlots := make(chan struct{}, perfStandbyCount) - - perfStandbyCache := cache.New(2*HeartbeatInterval, 1*time.Second) - perfStandbyCache.OnEvicted(func(secondaryID string, _ interface{}) { - c.logger.Debug("removing performance standby", "id", secondaryID) - c.removePerfStandbySecondary(context.Background(), secondaryID) - select { - case <-perfStandbySlots: - default: - c.logger.Warn("perf secondary timeout hit but no slot to free") - } - }) - - perfStandbyReplicationRPCServer := perfStandbyRPCServer(c, perfStandbyCache) - - if ha && c.clusterHandler != nil { - RegisterRequestForwardingServer(fwRPCServer, &forwardedRequestRPCServer{ - core: c, - handler: c.clusterHandler, - perfStandbySlots: perfStandbySlots, - perfStandbyRepCluster: perfStandbyRepCluster, - perfStandbyCache: perfStandbyCache, - }) + perfStandbyRepCluster, perfStandbyCache, perfStandbySlots, err := c.perfStandbyClusterHandler() + if err != nil { + return err } - // Create the HTTP/2 server that will be shared by both RPC and regular - // duties. Doing it this way instead of listening via the server and gRPC - // allows us to re-use the same port via ALPN. We can just tell the server - // to serve a given conn and which handler to use. - fws := &http2.Server{ - // Our forwarding connections heartbeat regularly so anything else we - // want to go away/get cleaned up pretty rapidly - IdleTimeout: 5 * HeartbeatInterval, + handler, err := NewRequestForwardingHandler(c, c.clusterListener.Server(), perfStandbySlots, perfStandbyRepCluster, perfStandbyCache) + if err != nil { + return err } - // Shutdown coordination logic - shutdown := new(uint32) - shutdownWg := &sync.WaitGroup{} - - for _, addr := range c.clusterListenerAddrs { - shutdownWg.Add(1) - - // Force a local resolution to avoid data races - laddr := addr - - // Start our listening loop - go func() { - defer shutdownWg.Done() - - // closeCh is used to shutdown the spawned goroutines once this - // function returns - closeCh := make(chan struct{}) - defer func() { - close(closeCh) - }() - - if c.logger.IsInfo() { - c.logger.Info("starting listener", "listener_address", laddr) - } - - // Create a TCP listener. We do this separately and specifically - // with TCP so that we can set deadlines. - tcpLn, err := net.ListenTCP("tcp", laddr) - if err != nil { - c.logger.Error("error starting listener", "error", err) - return - } - - // Wrap the listener with TLS - tlsLn := tls.NewListener(tcpLn, tlsConfig) - defer tlsLn.Close() - - if c.logger.IsInfo() { - c.logger.Info("serving cluster requests", "cluster_listen_address", tlsLn.Addr()) - } - - for { - if atomic.LoadUint32(shutdown) > 0 { - return - } - - // Set the deadline for the accept call. If it passes we'll get - // an error, causing us to check the condition at the top - // again. - tcpLn.SetDeadline(time.Now().Add(clusterListenerAcceptDeadline)) - - // Accept the connection - conn, err := tlsLn.Accept() - if err != nil { - if err, ok := err.(net.Error); ok && !err.Timeout() { - c.logger.Debug("non-timeout error accepting on cluster port", "error", err) - } - if conn != nil { - conn.Close() - } - continue - } - if conn == nil { - continue - } - - // Type assert to TLS connection and handshake to populate the - // connection state - tlsConn := conn.(*tls.Conn) - - // Set a deadline for the handshake. This will cause clients - // that don't successfully auth to be kicked out quickly. - // Cluster connections should be reliable so being marginally - // aggressive here is fine. - err = tlsConn.SetDeadline(time.Now().Add(30 * time.Second)) - if err != nil { - if c.logger.IsDebug() { - c.logger.Debug("error setting deadline for cluster connection", "error", err) - } - tlsConn.Close() - continue - } - - err = tlsConn.Handshake() - if err != nil { - if c.logger.IsDebug() { - c.logger.Debug("error handshaking cluster connection", "error", err) - } - tlsConn.Close() - continue - } - - // Now, set it back to unlimited - err = tlsConn.SetDeadline(time.Time{}) - if err != nil { - if c.logger.IsDebug() { - c.logger.Debug("error setting deadline for cluster connection", "error", err) - } - tlsConn.Close() - continue - } - - switch tlsConn.ConnectionState().NegotiatedProtocol { - case requestForwardingALPN: - if !ha { - tlsConn.Close() - continue - } - - c.logger.Debug("got request forwarding connection") - - shutdownWg.Add(2) - // quitCh is used to close the connection and the second - // goroutine if the server closes before closeCh. - quitCh := make(chan struct{}) - go func() { - select { - case <-quitCh: - case <-closeCh: - } - tlsConn.Close() - shutdownWg.Done() - }() - - go func() { - fws.ServeConn(tlsConn, &http2.ServeConnOpts{ - Handler: fwRPCServer, - BaseConfig: &http.Server{ - ErrorLog: c.logger.StandardLogger(nil), - }, - }) - // close the quitCh which will close the connection and - // the other goroutine. - close(quitCh) - shutdownWg.Done() - }() - - case PerformanceReplicationALPN, DRReplicationALPN, perfStandbyALPN: - handleReplicationConn(ctx, c, shutdownWg, closeCh, fws, perfStandbyReplicationRPCServer, perfStandbyCache, tlsConn) - default: - c.logger.Debug("unknown negotiated protocol on cluster port") - tlsConn.Close() - continue - } - } - }() - } - - // This is in its own goroutine so that we don't block the main thread, and - // thus we use atomic and channels to coordinate - // However, because you can't query the status of a channel, we set a bool - // here while we have the state lock to know whether to actually send a - // shutdown (e.g. whether the channel will block). See issue #2083. - c.clusterListenersRunning = true - go func() { - // If we get told to shut down... - <-c.clusterListenerShutdownCh - - // Stop the RPC server - c.logger.Info("shutting down forwarding rpc listeners") - fwRPCServer.Stop() - - // Set the shutdown flag. This will cause the listeners to shut down - // within the deadline in clusterListenerAcceptDeadline - atomic.StoreUint32(shutdown, 1) - c.logger.Info("forwarding rpc listeners stopped") - - // Wait for them all to shut down - shutdownWg.Wait() - c.logger.Info("rpc listeners successfully shut down") - - // Clear us up to run this function again - atomic.StoreUint32(c.rpcServerActive, 0) - - // Tell the main thread that shutdown is done. - c.clusterListenerShutdownSuccessCh <- struct{}{} - }() + c.clusterListener.AddHandler(requestForwardingALPN, handler) return nil } +func (c *Core) stopForwarding() { + if c.clusterListener != nil { + c.clusterListener.StopHandler(requestForwardingALPN) + c.clusterListener.StopHandler(perfStandbyALPN) + } +} + // refreshRequestForwardingConnection ensures that the client/transport are // alive and that the current active address value matches the most // recently-known address. @@ -349,13 +255,25 @@ func (c *Core) refreshRequestForwardingConnection(ctx context.Context, clusterAd return err } + parsedCert := c.localClusterParsedCert.Load().(*x509.Certificate) + if parsedCert == nil { + c.logger.Error("no request forwarding cluster certificate found") + return errors.New("no request forwarding cluster certificate found") + } + + if c.clusterListener != nil { + c.clusterListener.AddClient(requestForwardingALPN, &requestForwardingClusterClient{ + core: c, + }) + } + // Set up grpc forwarding handling // It's not really insecure, but we have to dial manually to get the // ALPN header right. It's just "insecure" because GRPC isn't managing // the TLS state. dctx, cancelFunc := context.WithCancel(ctx) c.rpcClientConn, err = grpc.DialContext(dctx, clusterURL.Host, - grpc.WithDialer(c.getGRPCDialer(ctx, requestForwardingALPN, "", nil, nil, nil)), + grpc.WithDialer(c.getGRPCDialer(ctx, requestForwardingALPN, parsedCert.Subject.CommonName, parsedCert)), grpc.WithInsecure(), // it's not, we handle it in the dialer grpc.WithKeepaliveParams(keepalive.ClientParameters{ Time: 2 * HeartbeatInterval, @@ -398,6 +316,9 @@ func (c *Core) clearForwardingClients() { c.rpcClientConnContext = nil c.rpcForwardingClient = nil + if c.clusterListener != nil { + c.clusterListener.RemoveClient(requestForwardingALPN) + } c.clusterLeaderParams.Store((*ClusterLeaderParams)(nil)) } @@ -450,32 +371,3 @@ func (c *Core) ForwardRequest(req *http.Request) (int, http.Header, []byte, erro return int(resp.StatusCode), header, resp.Body, nil } - -// getGRPCDialer is used to return a dialer that has the correct TLS -// configuration. Otherwise gRPC tries to be helpful and stomps all over our -// NextProtos. -func (c *Core) getGRPCDialer(ctx context.Context, alpnProto, serverName string, caCert *x509.Certificate, repClusters *ReplicatedClusters, perfStandbyCluster *ReplicatedCluster) func(string, time.Duration) (net.Conn, error) { - return func(addr string, timeout time.Duration) (net.Conn, error) { - tlsConfig, err := c.ClusterTLSConfig(ctx, repClusters, perfStandbyCluster) - if err != nil { - c.logger.Error("failed to get tls configuration", "error", err) - return nil, err - } - if serverName != "" { - tlsConfig.ServerName = serverName - } - if caCert != nil { - pool := x509.NewCertPool() - pool.AddCert(caCert) - tlsConfig.RootCAs = pool - tlsConfig.ClientCAs = pool - } - c.logger.Debug("creating rpc dialer", "host", tlsConfig.ServerName) - - tlsConfig.NextProtos = []string{alpnProto} - dialer := &net.Dialer{ - Timeout: timeout, - } - return tls.DialWithDialer(dialer, "tcp", addr, tlsConfig) - } -} diff --git a/vault/request_forwarding_rpc.go b/vault/request_forwarding_rpc.go index b3b6e0b01..24adfac66 100644 --- a/vault/request_forwarding_rpc.go +++ b/vault/request_forwarding_rpc.go @@ -9,6 +9,7 @@ import ( "github.com/hashicorp/vault/helper/consts" "github.com/hashicorp/vault/helper/forwarding" + "github.com/hashicorp/vault/vault/replication" cache "github.com/patrickmn/go-cache" ) @@ -16,7 +17,7 @@ type forwardedRequestRPCServer struct { core *Core handler http.Handler perfStandbySlots chan struct{} - perfStandbyRepCluster *ReplicatedCluster + perfStandbyRepCluster *replication.Cluster perfStandbyCache *cache.Cache } diff --git a/vault/request_forwarding_util.go b/vault/request_forwarding_util.go deleted file mode 100644 index 20fae15f0..000000000 --- a/vault/request_forwarding_util.go +++ /dev/null @@ -1,18 +0,0 @@ -// +build !enterprise - -package vault - -import ( - "context" - "crypto/tls" - "sync" - - cache "github.com/patrickmn/go-cache" - "golang.org/x/net/http2" - grpc "google.golang.org/grpc" -) - -func perfStandbyRPCServer(*Core, *cache.Cache) *grpc.Server { return nil } - -func handleReplicationConn(context.Context, *Core, *sync.WaitGroup, chan struct{}, *http2.Server, *grpc.Server, *cache.Cache, *tls.Conn) { -} diff --git a/vault/wrapping.go b/vault/wrapping.go index b6ff5211c..9f244108b 100644 --- a/vault/wrapping.go +++ b/vault/wrapping.go @@ -14,6 +14,7 @@ import ( "github.com/SermoDigital/jose/jws" "github.com/SermoDigital/jose/jwt" "github.com/hashicorp/errwrap" + "github.com/hashicorp/vault/helper/certutil" "github.com/hashicorp/vault/helper/consts" "github.com/hashicorp/vault/helper/jsonutil" "github.com/hashicorp/vault/helper/namespace" @@ -31,7 +32,7 @@ func (c *Core) ensureWrappingKey(ctx context.Context) error { return err } - var keyParams clusterKeyParams + var keyParams certutil.ClusterKeyParams if entry == nil { key, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader)