2016-08-19 15:03:53 +00:00
|
|
|
package vault
|
|
|
|
|
|
|
|
import (
|
2017-12-01 22:08:38 +00:00
|
|
|
"context"
|
2016-08-19 15:03:53 +00:00
|
|
|
"crypto/tls"
|
2017-03-02 15:03:49 +00:00
|
|
|
"crypto/x509"
|
2016-08-19 15:03:53 +00:00
|
|
|
"fmt"
|
2018-02-06 18:52:35 +00:00
|
|
|
math "math"
|
2016-08-19 15:03:53 +00:00
|
|
|
"net"
|
|
|
|
"net/http"
|
|
|
|
"net/url"
|
2017-06-20 00:20:44 +00:00
|
|
|
"runtime"
|
2016-08-19 15:03:53 +00:00
|
|
|
"sync"
|
|
|
|
"sync/atomic"
|
|
|
|
"time"
|
|
|
|
|
2018-01-20 00:24:04 +00:00
|
|
|
"github.com/hashicorp/vault/helper/consts"
|
2016-08-19 15:03:53 +00:00
|
|
|
"github.com/hashicorp/vault/helper/forwarding"
|
|
|
|
"golang.org/x/net/http2"
|
|
|
|
"google.golang.org/grpc"
|
2017-05-26 17:32:13 +00:00
|
|
|
"google.golang.org/grpc/keepalive"
|
2016-08-19 15:03:53 +00:00
|
|
|
)
|
|
|
|
|
|
|
|
// clusterListenerAcceptDeadline bounds how long each Accept call on a cluster
// listener may block, so the accept loop can periodically re-check whether it
// has been told to shut down.
const clusterListenerAcceptDeadline = 500 * time.Millisecond

// requestForwardingALPN is the ALPN protocol name negotiated on the cluster
// port for request-forwarding gRPC connections.
const requestForwardingALPN = "req_fw_sb-act_v1"

// HeartbeatInterval is the base period used for forwarding heartbeats and the
// keepalive/idle timeouts derived from it. It is a package-level variable
// rather than a constant so that tests can modify it.
var HeartbeatInterval = 5 * time.Second
|
|
|
|
|
2016-08-19 15:03:53 +00:00
|
|
|
// Starts the listeners and servers necessary to handle forwarded requests
|
2018-01-19 09:11:59 +00:00
|
|
|
func (c *Core) startForwarding(ctx context.Context) error {
|
2018-04-03 00:46:59 +00:00
|
|
|
c.logger.Debug("cluster listener setup function")
|
|
|
|
defer c.logger.Debug("leaving cluster listener setup function")
|
2017-03-01 23:16:47 +00:00
|
|
|
|
2016-08-19 15:03:53 +00:00
|
|
|
// Clean up in case we have transitioned from a client to a server
|
2017-03-01 23:16:47 +00:00
|
|
|
c.requestForwardingConnectionLock.Lock()
|
2016-08-19 18:49:11 +00:00
|
|
|
c.clearForwardingClients()
|
2017-03-01 23:16:47 +00:00
|
|
|
c.requestForwardingConnectionLock.Unlock()
|
2016-08-19 15:03:53 +00:00
|
|
|
|
2017-02-17 04:09:39 +00:00
|
|
|
// Resolve locally to avoid races
|
|
|
|
ha := c.ha != nil
|
|
|
|
|
2016-08-19 15:03:53 +00:00
|
|
|
// Get our TLS config
|
2018-02-23 19:01:15 +00:00
|
|
|
tlsConfig, err := c.ClusterTLSConfig(ctx, nil)
|
2016-08-19 15:03:53 +00:00
|
|
|
if err != nil {
|
2018-04-03 00:46:59 +00:00
|
|
|
c.logger.Error("failed to get tls configuration when starting forwarding", "error", err)
|
2016-08-19 15:03:53 +00:00
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
// The server supports all of the possible protos
|
2017-08-30 20:28:23 +00:00
|
|
|
tlsConfig.NextProtos = []string{"h2", requestForwardingALPN}
|
2016-08-19 15:03:53 +00:00
|
|
|
|
2018-01-25 01:23:08 +00:00
|
|
|
if !atomic.CompareAndSwapUint32(c.rpcServerActive, 0, 1) {
|
2018-04-03 00:46:59 +00:00
|
|
|
c.logger.Warn("forwarding rpc server already running")
|
2017-03-01 23:16:47 +00:00
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
2018-01-25 01:23:08 +00:00
|
|
|
fwRPCServer := grpc.NewServer(
|
2017-05-26 17:32:13 +00:00
|
|
|
grpc.KeepaliveParams(keepalive.ServerParameters{
|
2018-01-18 04:08:35 +00:00
|
|
|
Time: 2 * HeartbeatInterval,
|
2017-05-26 17:32:13 +00:00
|
|
|
}),
|
2018-06-29 16:52:23 +00:00
|
|
|
grpc.MaxRecvMsgSize(math.MaxInt32),
|
|
|
|
grpc.MaxSendMsgSize(math.MaxInt32),
|
2017-05-26 17:32:13 +00:00
|
|
|
)
|
2017-02-17 04:09:39 +00:00
|
|
|
|
2017-05-24 14:38:48 +00:00
|
|
|
if ha && c.clusterHandler != nil {
|
2018-01-25 01:23:08 +00:00
|
|
|
RegisterRequestForwardingServer(fwRPCServer, &forwardedRequestRPCServer{
|
2017-02-17 04:09:39 +00:00
|
|
|
core: c,
|
2017-05-24 14:38:48 +00:00
|
|
|
handler: c.clusterHandler,
|
2017-02-17 04:09:39 +00:00
|
|
|
})
|
|
|
|
}
|
2016-08-19 15:03:53 +00:00
|
|
|
|
|
|
|
// Create the HTTP/2 server that will be shared by both RPC and regular
|
|
|
|
// duties. Doing it this way instead of listening via the server and gRPC
|
|
|
|
// allows us to re-use the same port via ALPN. We can just tell the server
|
|
|
|
// to serve a given conn and which handler to use.
|
2018-06-16 22:21:33 +00:00
|
|
|
fws := &http2.Server{
|
|
|
|
// Our forwarding connections heartbeat regularly so anything else we
|
|
|
|
// want to go away/get cleaned up pretty rapidly
|
|
|
|
IdleTimeout: 5 * HeartbeatInterval,
|
|
|
|
}
|
2016-08-19 15:03:53 +00:00
|
|
|
|
|
|
|
// Shutdown coordination logic
|
2018-06-09 19:35:22 +00:00
|
|
|
shutdown := new(uint32)
|
2016-08-19 15:03:53 +00:00
|
|
|
shutdownWg := &sync.WaitGroup{}
|
|
|
|
|
|
|
|
for _, addr := range c.clusterListenerAddrs {
|
|
|
|
shutdownWg.Add(1)
|
|
|
|
|
|
|
|
// Force a local resolution to avoid data races
|
|
|
|
laddr := addr
|
|
|
|
|
|
|
|
// Start our listening loop
|
|
|
|
go func() {
|
|
|
|
defer shutdownWg.Done()
|
|
|
|
|
2017-12-13 01:18:04 +00:00
|
|
|
// closeCh is used to shutdown the spawned goroutines once this
|
|
|
|
// function returns
|
|
|
|
closeCh := make(chan struct{})
|
|
|
|
defer func() {
|
|
|
|
close(closeCh)
|
|
|
|
}()
|
|
|
|
|
2016-11-08 15:31:15 +00:00
|
|
|
if c.logger.IsInfo() {
|
|
|
|
c.logger.Info("core/startClusterListener: starting listener", "listener_address", laddr)
|
|
|
|
}
|
2016-08-19 15:03:53 +00:00
|
|
|
|
|
|
|
// Create a TCP listener. We do this separately and specifically
|
|
|
|
// with TCP so that we can set deadlines.
|
|
|
|
tcpLn, err := net.ListenTCP("tcp", laddr)
|
|
|
|
if err != nil {
|
2016-08-19 20:45:17 +00:00
|
|
|
c.logger.Error("core/startClusterListener: error starting listener", "error", err)
|
2016-08-19 15:03:53 +00:00
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
// Wrap the listener with TLS
|
|
|
|
tlsLn := tls.NewListener(tcpLn, tlsConfig)
|
2017-02-17 04:09:39 +00:00
|
|
|
defer tlsLn.Close()
|
2016-08-19 15:03:53 +00:00
|
|
|
|
2016-08-19 20:45:17 +00:00
|
|
|
if c.logger.IsInfo() {
|
|
|
|
c.logger.Info("core/startClusterListener: serving cluster requests", "cluster_listen_address", tlsLn.Addr())
|
|
|
|
}
|
2016-08-19 15:03:53 +00:00
|
|
|
|
|
|
|
for {
|
2018-06-09 19:35:22 +00:00
|
|
|
if atomic.LoadUint32(shutdown) > 0 {
|
2016-08-19 15:03:53 +00:00
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
// Set the deadline for the accept call. If it passes we'll get
|
|
|
|
// an error, causing us to check the condition at the top
|
|
|
|
// again.
|
|
|
|
tcpLn.SetDeadline(time.Now().Add(clusterListenerAcceptDeadline))
|
|
|
|
|
|
|
|
// Accept the connection
|
|
|
|
conn, err := tlsLn.Accept()
|
2017-11-07 22:27:13 +00:00
|
|
|
if err != nil {
|
|
|
|
if err, ok := err.(net.Error); ok && !err.Timeout() {
|
2018-04-03 00:46:59 +00:00
|
|
|
c.logger.Debug("non-timeout error accepting on cluster port", "error", err)
|
2017-11-07 22:27:13 +00:00
|
|
|
}
|
2017-11-01 01:58:45 +00:00
|
|
|
if conn != nil {
|
|
|
|
conn.Close()
|
|
|
|
}
|
2016-08-19 15:03:53 +00:00
|
|
|
continue
|
|
|
|
}
|
2017-11-07 22:27:13 +00:00
|
|
|
if conn == nil {
|
|
|
|
continue
|
|
|
|
}
|
2016-08-19 15:03:53 +00:00
|
|
|
|
|
|
|
// Type assert to TLS connection and handshake to populate the
|
|
|
|
// connection state
|
|
|
|
tlsConn := conn.(*tls.Conn)
|
2018-06-16 22:21:33 +00:00
|
|
|
|
|
|
|
// Set a deadline for the handshake. This will cause clients
|
|
|
|
// that don't successfully auth to be kicked out quickly.
|
|
|
|
// Cluster connections should be reliable so being marginally
|
|
|
|
// aggressive here is fine.
|
|
|
|
err = tlsConn.SetDeadline(time.Now().Add(30 * time.Second))
|
|
|
|
if err != nil {
|
|
|
|
if c.logger.IsDebug() {
|
|
|
|
c.logger.Debug("error setting deadline for cluster connection", "error", err)
|
|
|
|
}
|
|
|
|
tlsConn.Close()
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
|
2016-08-19 15:03:53 +00:00
|
|
|
err = tlsConn.Handshake()
|
|
|
|
if err != nil {
|
2016-08-19 20:45:17 +00:00
|
|
|
if c.logger.IsDebug() {
|
2018-04-03 00:46:59 +00:00
|
|
|
c.logger.Debug("error handshaking cluster connection", "error", err)
|
2016-08-19 20:45:17 +00:00
|
|
|
}
|
2017-11-07 22:27:13 +00:00
|
|
|
tlsConn.Close()
|
2016-08-19 15:03:53 +00:00
|
|
|
continue
|
|
|
|
}
|
|
|
|
|
2018-06-16 22:21:33 +00:00
|
|
|
// Now, set it back to unlimited
|
|
|
|
err = tlsConn.SetDeadline(time.Time{})
|
|
|
|
if err != nil {
|
|
|
|
if c.logger.IsDebug() {
|
|
|
|
c.logger.Debug("error setting deadline for cluster connection", "error", err)
|
|
|
|
}
|
|
|
|
tlsConn.Close()
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
|
2016-08-19 15:03:53 +00:00
|
|
|
switch tlsConn.ConnectionState().NegotiatedProtocol {
|
2017-08-30 20:28:23 +00:00
|
|
|
case requestForwardingALPN:
|
2017-02-17 04:09:39 +00:00
|
|
|
if !ha {
|
2017-11-07 22:27:13 +00:00
|
|
|
tlsConn.Close()
|
2017-02-17 04:09:39 +00:00
|
|
|
continue
|
|
|
|
}
|
|
|
|
|
2018-04-03 00:46:59 +00:00
|
|
|
c.logger.Debug("got request forwarding connection")
|
2017-12-13 01:18:04 +00:00
|
|
|
|
|
|
|
shutdownWg.Add(2)
|
|
|
|
// quitCh is used to close the connection and the second
|
|
|
|
// goroutine if the server closes before closeCh.
|
|
|
|
quitCh := make(chan struct{})
|
|
|
|
go func() {
|
|
|
|
select {
|
|
|
|
case <-quitCh:
|
|
|
|
case <-closeCh:
|
|
|
|
}
|
|
|
|
tlsConn.Close()
|
|
|
|
shutdownWg.Done()
|
|
|
|
}()
|
|
|
|
|
2017-11-01 01:58:45 +00:00
|
|
|
go func() {
|
2017-11-07 22:27:13 +00:00
|
|
|
fws.ServeConn(tlsConn, &http2.ServeConnOpts{
|
2018-01-25 01:23:08 +00:00
|
|
|
Handler: fwRPCServer,
|
2017-11-01 01:58:45 +00:00
|
|
|
})
|
2017-12-13 01:18:04 +00:00
|
|
|
// close the quitCh which will close the connection and
|
|
|
|
// the other goroutine.
|
|
|
|
close(quitCh)
|
|
|
|
shutdownWg.Done()
|
2017-11-01 01:58:45 +00:00
|
|
|
}()
|
2016-08-19 15:03:53 +00:00
|
|
|
|
|
|
|
default:
|
2018-04-03 00:46:59 +00:00
|
|
|
c.logger.Debug("unknown negotiated protocol on cluster port")
|
2017-11-07 22:27:13 +00:00
|
|
|
tlsConn.Close()
|
2016-08-19 15:03:53 +00:00
|
|
|
continue
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}()
|
|
|
|
}
|
|
|
|
|
|
|
|
// This is in its own goroutine so that we don't block the main thread, and
|
|
|
|
// thus we use atomic and channels to coordinate
|
2016-11-11 21:43:33 +00:00
|
|
|
// However, because you can't query the status of a channel, we set a bool
|
|
|
|
// here while we have the state lock to know whether to actually send a
|
|
|
|
// shutdown (e.g. whether the channel will block). See issue #2083.
|
|
|
|
c.clusterListenersRunning = true
|
2016-08-19 15:03:53 +00:00
|
|
|
go func() {
|
|
|
|
// If we get told to shut down...
|
|
|
|
<-c.clusterListenerShutdownCh
|
|
|
|
|
|
|
|
// Stop the RPC server
|
2018-04-03 00:46:59 +00:00
|
|
|
c.logger.Info("shutting down forwarding rpc listeners")
|
2018-01-25 01:23:08 +00:00
|
|
|
fwRPCServer.Stop()
|
2016-08-19 15:03:53 +00:00
|
|
|
|
|
|
|
// Set the shutdown flag. This will cause the listeners to shut down
|
|
|
|
// within the deadline in clusterListenerAcceptDeadline
|
2018-06-09 19:35:22 +00:00
|
|
|
atomic.StoreUint32(shutdown, 1)
|
2018-04-03 00:46:59 +00:00
|
|
|
c.logger.Info("forwarding rpc listeners stopped")
|
2016-08-19 15:03:53 +00:00
|
|
|
|
|
|
|
// Wait for them all to shut down
|
|
|
|
shutdownWg.Wait()
|
2018-04-03 00:46:59 +00:00
|
|
|
c.logger.Info("rpc listeners successfully shut down")
|
2016-08-19 15:03:53 +00:00
|
|
|
|
2018-01-25 01:23:08 +00:00
|
|
|
// Clear us up to run this function again
|
|
|
|
atomic.StoreUint32(c.rpcServerActive, 0)
|
|
|
|
|
2016-08-19 15:03:53 +00:00
|
|
|
// Tell the main thread that shutdown is done.
|
|
|
|
c.clusterListenerShutdownSuccessCh <- struct{}{}
|
|
|
|
}()
|
|
|
|
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
// refreshRequestForwardingConnection ensures that the client/transport are
|
|
|
|
// alive and that the current active address value matches the most
|
|
|
|
// recently-known address.
|
2018-01-19 09:11:59 +00:00
|
|
|
func (c *Core) refreshRequestForwardingConnection(ctx context.Context, clusterAddr string) error {
|
2018-04-03 00:46:59 +00:00
|
|
|
c.logger.Debug("refreshing forwarding connection")
|
|
|
|
defer c.logger.Debug("done refreshing forwarding connection")
|
2017-03-01 23:16:47 +00:00
|
|
|
|
2016-08-19 15:03:53 +00:00
|
|
|
c.requestForwardingConnectionLock.Lock()
|
|
|
|
defer c.requestForwardingConnectionLock.Unlock()
|
|
|
|
|
2017-03-01 23:16:47 +00:00
|
|
|
// Clean things up first
|
|
|
|
c.clearForwardingClients()
|
2016-08-19 15:03:53 +00:00
|
|
|
|
2017-03-01 23:16:47 +00:00
|
|
|
// If we don't have anything to connect to, just return
|
2016-08-19 15:03:53 +00:00
|
|
|
if clusterAddr == "" {
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
clusterURL, err := url.Parse(clusterAddr)
|
|
|
|
if err != nil {
|
2018-04-03 00:46:59 +00:00
|
|
|
c.logger.Error("error parsing cluster address attempting to refresh forwarding connection", "error", err)
|
2016-08-19 15:03:53 +00:00
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
2017-05-24 13:34:59 +00:00
|
|
|
// Set up grpc forwarding handling
|
|
|
|
// It's not really insecure, but we have to dial manually to get the
|
|
|
|
// ALPN header right. It's just "insecure" because GRPC isn't managing
|
|
|
|
// the TLS state.
|
2018-01-19 09:11:59 +00:00
|
|
|
dctx, cancelFunc := context.WithCancel(ctx)
|
|
|
|
c.rpcClientConn, err = grpc.DialContext(dctx, clusterURL.Host,
|
2018-02-23 19:01:15 +00:00
|
|
|
grpc.WithDialer(c.getGRPCDialer(ctx, requestForwardingALPN, "", nil, nil)),
|
2017-05-26 17:32:13 +00:00
|
|
|
grpc.WithInsecure(), // it's not, we handle it in the dialer
|
|
|
|
grpc.WithKeepaliveParams(keepalive.ClientParameters{
|
2018-01-18 04:08:35 +00:00
|
|
|
Time: 2 * HeartbeatInterval,
|
2018-02-06 18:52:35 +00:00
|
|
|
}),
|
|
|
|
grpc.WithDefaultCallOptions(
|
|
|
|
grpc.MaxCallRecvMsgSize(math.MaxInt32),
|
|
|
|
grpc.MaxCallSendMsgSize(math.MaxInt32),
|
|
|
|
))
|
2017-05-24 13:34:59 +00:00
|
|
|
if err != nil {
|
2017-05-24 19:06:56 +00:00
|
|
|
cancelFunc()
|
2018-04-03 00:46:59 +00:00
|
|
|
c.logger.Error("err setting up forwarding rpc client", "error", err)
|
2017-05-24 13:34:59 +00:00
|
|
|
return err
|
2016-08-19 15:03:53 +00:00
|
|
|
}
|
2018-01-19 12:22:31 +00:00
|
|
|
c.rpcClientConnContext = dctx
|
2017-05-24 19:06:56 +00:00
|
|
|
c.rpcClientConnCancelFunc = cancelFunc
|
|
|
|
c.rpcForwardingClient = &forwardingClient{
|
|
|
|
RequestForwardingClient: NewRequestForwardingClient(c.rpcClientConn),
|
|
|
|
core: c,
|
2018-01-18 04:08:35 +00:00
|
|
|
echoTicker: time.NewTicker(HeartbeatInterval),
|
2018-01-19 12:22:31 +00:00
|
|
|
echoContext: dctx,
|
2017-05-24 19:06:56 +00:00
|
|
|
}
|
2017-05-26 17:32:13 +00:00
|
|
|
c.rpcForwardingClient.startHeartbeat()
|
2016-08-19 15:03:53 +00:00
|
|
|
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
2016-08-19 18:49:11 +00:00
|
|
|
func (c *Core) clearForwardingClients() {
|
2018-04-03 00:46:59 +00:00
|
|
|
c.logger.Debug("clearing forwarding clients")
|
|
|
|
defer c.logger.Debug("done clearing forwarding clients")
|
2017-03-01 23:16:47 +00:00
|
|
|
|
2016-08-19 18:49:11 +00:00
|
|
|
if c.rpcClientConnCancelFunc != nil {
|
|
|
|
c.rpcClientConnCancelFunc()
|
|
|
|
c.rpcClientConnCancelFunc = nil
|
|
|
|
}
|
|
|
|
if c.rpcClientConn != nil {
|
|
|
|
c.rpcClientConn.Close()
|
|
|
|
c.rpcClientConn = nil
|
|
|
|
}
|
2017-05-24 19:06:56 +00:00
|
|
|
|
|
|
|
c.rpcClientConnContext = nil
|
2017-03-01 23:16:47 +00:00
|
|
|
c.rpcForwardingClient = nil
|
2016-08-19 18:49:11 +00:00
|
|
|
}
|
|
|
|
|
2016-08-19 15:03:53 +00:00
|
|
|
// ForwardRequest forwards a given request to the active node and returns the
|
|
|
|
// response.
|
2016-08-26 21:53:47 +00:00
|
|
|
func (c *Core) ForwardRequest(req *http.Request) (int, http.Header, []byte, error) {
|
2016-08-19 15:03:53 +00:00
|
|
|
c.requestForwardingConnectionLock.RLock()
|
|
|
|
defer c.requestForwardingConnectionLock.RUnlock()
|
|
|
|
|
2017-05-24 13:34:59 +00:00
|
|
|
if c.rpcForwardingClient == nil {
|
|
|
|
return 0, nil, nil, ErrCannotForward
|
|
|
|
}
|
2016-08-19 15:03:53 +00:00
|
|
|
|
2017-05-24 13:34:59 +00:00
|
|
|
freq, err := forwarding.GenerateForwardedRequest(req)
|
|
|
|
if err != nil {
|
2018-04-03 00:46:59 +00:00
|
|
|
c.logger.Error("error creating forwarding RPC request", "error", err)
|
2017-05-24 13:34:59 +00:00
|
|
|
return 0, nil, nil, fmt.Errorf("error creating forwarding RPC request")
|
|
|
|
}
|
|
|
|
if freq == nil {
|
2018-04-03 00:46:59 +00:00
|
|
|
c.logger.Error("got nil forwarding RPC request")
|
2017-05-24 13:34:59 +00:00
|
|
|
return 0, nil, nil, fmt.Errorf("got nil forwarding RPC request")
|
|
|
|
}
|
2017-05-24 19:06:56 +00:00
|
|
|
resp, err := c.rpcForwardingClient.ForwardRequest(c.rpcClientConnContext, freq)
|
2017-05-24 13:34:59 +00:00
|
|
|
if err != nil {
|
2018-04-03 00:46:59 +00:00
|
|
|
c.logger.Error("error during forwarded RPC request", "error", err)
|
2017-05-24 13:34:59 +00:00
|
|
|
return 0, nil, nil, fmt.Errorf("error during forwarding RPC request")
|
|
|
|
}
|
2016-08-26 21:53:47 +00:00
|
|
|
|
2017-05-24 13:34:59 +00:00
|
|
|
var header http.Header
|
|
|
|
if resp.HeaderEntries != nil {
|
|
|
|
header = make(http.Header)
|
|
|
|
for k, v := range resp.HeaderEntries {
|
2017-11-02 12:31:50 +00:00
|
|
|
header[k] = v.Values
|
2016-08-26 21:53:47 +00:00
|
|
|
}
|
2016-08-19 15:03:53 +00:00
|
|
|
}
|
2017-05-24 13:34:59 +00:00
|
|
|
|
|
|
|
return int(resp.StatusCode), header, resp.Body, nil
|
2016-08-19 15:03:53 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
// getGRPCDialer is used to return a dialer that has the correct TLS
|
|
|
|
// configuration. Otherwise gRPC tries to be helpful and stomps all over our
|
|
|
|
// NextProtos.
|
2018-02-23 19:01:15 +00:00
|
|
|
func (c *Core) getGRPCDialer(ctx context.Context, alpnProto, serverName string, caCert *x509.Certificate, repClusters *ReplicatedClusters) func(string, time.Duration) (net.Conn, error) {
|
2016-08-19 15:03:53 +00:00
|
|
|
return func(addr string, timeout time.Duration) (net.Conn, error) {
|
2018-02-23 19:01:15 +00:00
|
|
|
tlsConfig, err := c.ClusterTLSConfig(ctx, repClusters)
|
2016-08-19 15:03:53 +00:00
|
|
|
if err != nil {
|
2018-04-03 00:46:59 +00:00
|
|
|
c.logger.Error("failed to get tls configuration", "error", err)
|
2016-08-19 15:03:53 +00:00
|
|
|
return nil, err
|
|
|
|
}
|
2017-01-06 20:42:18 +00:00
|
|
|
if serverName != "" {
|
|
|
|
tlsConfig.ServerName = serverName
|
|
|
|
}
|
2017-03-02 15:03:49 +00:00
|
|
|
if caCert != nil {
|
|
|
|
pool := x509.NewCertPool()
|
|
|
|
pool.AddCert(caCert)
|
|
|
|
tlsConfig.RootCAs = pool
|
|
|
|
tlsConfig.ClientCAs = pool
|
|
|
|
}
|
2018-04-03 00:46:59 +00:00
|
|
|
c.logger.Debug("creating rpc dialer", "host", tlsConfig.ServerName)
|
2017-01-06 20:42:18 +00:00
|
|
|
|
|
|
|
tlsConfig.NextProtos = []string{alpnProto}
|
2016-08-19 15:03:53 +00:00
|
|
|
dialer := &net.Dialer{
|
|
|
|
Timeout: timeout,
|
|
|
|
}
|
|
|
|
return tls.DialWithDialer(dialer, "tcp", addr, tlsConfig)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
type forwardedRequestRPCServer struct {
|
|
|
|
core *Core
|
|
|
|
handler http.Handler
|
|
|
|
}
|
|
|
|
|
2017-01-06 22:08:43 +00:00
|
|
|
func (s *forwardedRequestRPCServer) ForwardRequest(ctx context.Context, freq *forwarding.Request) (*forwarding.Response, error) {
|
2018-04-03 00:46:59 +00:00
|
|
|
//s.core.logger.Debug("forwarding: serving rpc forwarded request")
|
2017-03-02 01:57:38 +00:00
|
|
|
|
2016-08-19 15:03:53 +00:00
|
|
|
// Parse an http.Request out of it
|
|
|
|
req, err := forwarding.ParseForwardedRequest(freq)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
// A very dummy response writer that doesn't follow normal semantics, just
|
|
|
|
// lets you write a status code (last written wins) and a body. But it
|
|
|
|
// meets the interface requirements.
|
|
|
|
w := forwarding.NewRPCResponseWriter()
|
|
|
|
|
2017-06-20 00:20:44 +00:00
|
|
|
resp := &forwarding.Response{}
|
|
|
|
|
|
|
|
runRequest := func() {
|
|
|
|
defer func() {
|
|
|
|
// Logic here comes mostly from the Go source code
|
|
|
|
if err := recover(); err != nil {
|
|
|
|
const size = 64 << 10
|
|
|
|
buf := make([]byte, size)
|
|
|
|
buf = buf[:runtime.Stack(buf, false)]
|
2018-01-18 16:40:59 +00:00
|
|
|
s.core.logger.Error("forwarding: panic serving request", "path", req.URL.Path, "error", err, "stacktrace", string(buf))
|
2017-06-20 00:20:44 +00:00
|
|
|
}
|
|
|
|
}()
|
|
|
|
s.handler.ServeHTTP(w, req)
|
|
|
|
}
|
|
|
|
runRequest()
|
2017-06-20 23:54:10 +00:00
|
|
|
resp.StatusCode = uint32(w.StatusCode())
|
|
|
|
resp.Body = w.Body().Bytes()
|
2016-08-26 21:53:47 +00:00
|
|
|
|
|
|
|
header := w.Header()
|
|
|
|
if header != nil {
|
|
|
|
resp.HeaderEntries = make(map[string]*forwarding.HeaderEntry, len(header))
|
|
|
|
for k, v := range header {
|
|
|
|
resp.HeaderEntries[k] = &forwarding.HeaderEntry{
|
|
|
|
Values: v,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
return resp, nil
|
2016-08-19 15:03:53 +00:00
|
|
|
}
|
2017-05-24 19:06:56 +00:00
|
|
|
|
|
|
|
func (s *forwardedRequestRPCServer) Echo(ctx context.Context, in *EchoRequest) (*EchoReply, error) {
|
2017-05-25 00:51:53 +00:00
|
|
|
if in.ClusterAddr != "" {
|
2017-05-25 01:10:32 +00:00
|
|
|
s.core.clusterPeerClusterAddrsCache.Set(in.ClusterAddr, nil, 0)
|
2017-05-25 00:51:53 +00:00
|
|
|
}
|
2017-05-24 19:06:56 +00:00
|
|
|
return &EchoReply{
|
2018-01-18 03:17:47 +00:00
|
|
|
Message: "pong",
|
|
|
|
ReplicationState: uint32(s.core.ReplicationState()),
|
2017-05-24 19:06:56 +00:00
|
|
|
}, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
type forwardingClient struct {
|
|
|
|
RequestForwardingClient
|
|
|
|
|
|
|
|
core *Core
|
|
|
|
|
|
|
|
echoTicker *time.Ticker
|
|
|
|
echoContext context.Context
|
|
|
|
}
|
|
|
|
|
2017-05-26 17:32:13 +00:00
|
|
|
// NOTE: we also take advantage of gRPC's keepalive bits, but as we send data
|
|
|
|
// with these requests it's useful to keep this as well
|
|
|
|
func (c *forwardingClient) startHeartbeat() {
|
2017-05-24 19:06:56 +00:00
|
|
|
go func() {
|
2017-05-25 01:45:51 +00:00
|
|
|
tick := func() {
|
|
|
|
c.core.stateLock.RLock()
|
|
|
|
clusterAddr := c.core.clusterAddr
|
|
|
|
c.core.stateLock.RUnlock()
|
|
|
|
|
|
|
|
ctx, cancel := context.WithTimeout(c.echoContext, 2*time.Second)
|
|
|
|
resp, err := c.RequestForwardingClient.Echo(ctx, &EchoRequest{
|
|
|
|
Message: "ping",
|
|
|
|
ClusterAddr: clusterAddr,
|
|
|
|
})
|
|
|
|
cancel()
|
|
|
|
if err != nil {
|
|
|
|
c.core.logger.Debug("forwarding: error sending echo request to active node", "error", err)
|
|
|
|
return
|
|
|
|
}
|
|
|
|
if resp == nil {
|
|
|
|
c.core.logger.Debug("forwarding: empty echo response from active node")
|
|
|
|
return
|
|
|
|
}
|
|
|
|
if resp.Message != "pong" {
|
|
|
|
c.core.logger.Debug("forwarding: unexpected echo response from active node", "message", resp.Message)
|
|
|
|
return
|
|
|
|
}
|
2018-01-18 03:17:47 +00:00
|
|
|
// Store the active node's replication state to display in
|
|
|
|
// sys/health calls
|
2018-01-20 00:24:04 +00:00
|
|
|
atomic.StoreUint32(c.core.activeNodeReplicationState, resp.ReplicationState)
|
2018-04-03 00:46:59 +00:00
|
|
|
//c.core.logger.Debug("forwarding: successful heartbeat")
|
2017-05-25 01:45:51 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
tick()
|
|
|
|
|
2017-05-24 19:06:56 +00:00
|
|
|
for {
|
|
|
|
select {
|
|
|
|
case <-c.echoContext.Done():
|
|
|
|
c.echoTicker.Stop()
|
2018-04-03 00:46:59 +00:00
|
|
|
c.core.logger.Debug("forwarding: stopping heartbeating")
|
2018-01-23 02:44:38 +00:00
|
|
|
atomic.StoreUint32(c.core.activeNodeReplicationState, uint32(consts.ReplicationUnknown))
|
2017-05-24 19:06:56 +00:00
|
|
|
return
|
|
|
|
case <-c.echoTicker.C:
|
2017-05-25 01:45:51 +00:00
|
|
|
tick()
|
2017-05-24 19:06:56 +00:00
|
|
|
}
|
|
|
|
}
|
|
|
|
}()
|
|
|
|
}
|