Refactor the cluster listener (#6232)
* Port over OSS cluster port refactor components
* Start forwarding
* Cleanup a bit
* Fix copy error
* Return error from perf standby creation
* Add some more comments
* Fix copy/paste error

This commit is contained in:
parent feb235d5f8
commit f5b5fbb392
@@ -17,12 +17,21 @@ import (
	"crypto/x509"
	"encoding/pem"
	"fmt"
+	"math/big"
	"strings"

	"github.com/hashicorp/errwrap"
	"github.com/hashicorp/vault/helper/errutil"
)

+// This can be one of a few key types so the different params may or may not be filled
+type ClusterKeyParams struct {
+	Type string   `json:"type" structs:"type" mapstructure:"type"`
+	X    *big.Int `json:"x" structs:"x" mapstructure:"x"`
+	Y    *big.Int `json:"y" structs:"y" mapstructure:"y"`
+	D    *big.Int `json:"d" structs:"d" mapstructure:"d"`
+}
+
// Secret is used to attempt to unmarshal a Vault secret
// JSON response, as a convenience
type Secret struct {
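ClusterKeyParams is now exported from helper/certutil so that core.go, ha.go, and wrapping.go (all updated below) can share one serialization type. A minimal sketch, not part of the commit, of round-tripping the params through JSON for an EC key; it assumes the `"p521"` type string that the `corePrivateKeyTypeP521` constant carries in core:

// Illustrative only: fill and round-trip certutil.ClusterKeyParams.
package main

import (
	"crypto/ecdsa"
	"crypto/elliptic"
	"crypto/rand"
	"encoding/json"
	"fmt"

	"github.com/hashicorp/vault/helper/certutil"
)

func main() {
	key, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
	if err != nil {
		panic(err)
	}

	// Only the EC-relevant params are filled in; other key types would
	// leave some of these *big.Int fields nil.
	params := certutil.ClusterKeyParams{
		Type: "p521", // assumption: the value of corePrivateKeyTypeP521
		X:    key.X,
		Y:    key.Y,
		D:    key.D,
	}

	buf, err := json.Marshal(&params)
	if err != nil {
		panic(err)
	}

	var decoded certutil.ClusterKeyParams
	if err := json.Unmarshal(buf, &decoded); err != nil {
		panic(err)
	}
	fmt.Println(decoded.Type, decoded.D.Cmp(key.D) == 0) // p521 true
}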
|
421
vault/cluster.go
421
vault/cluster.go
|
@@ -15,12 +15,16 @@ import (
	mathrand "math/rand"
	"net"
	"net/http"
	"sync"
	"sync/atomic"
	"time"

	"github.com/hashicorp/errwrap"
	log "github.com/hashicorp/go-hclog"
	uuid "github.com/hashicorp/go-uuid"
	"github.com/hashicorp/vault/helper/jsonutil"
	"github.com/hashicorp/vault/logical"
	"golang.org/x/net/http2"
)

const (
@@ -44,19 +48,6 @@ type ClusterLeaderParams struct {
	LeaderClusterAddr string
}

-type ReplicatedClusters struct {
-	DR          *ReplicatedCluster
-	Performance *ReplicatedCluster
-}
-
-// This can be one of a few key types so the different params may or may not be filled
-type clusterKeyParams struct {
-	Type string   `json:"type" structs:"type" mapstructure:"type"`
-	X    *big.Int `json:"x" structs:"x" mapstructure:"x"`
-	Y    *big.Int `json:"y" structs:"y" mapstructure:"y"`
-	D    *big.Int `json:"d" structs:"d" mapstructure:"d"`
-}
-
// Structure representing the storage entry that holds cluster information
type Cluster struct {
	// Name of the cluster
@@ -290,10 +281,297 @@ func (c *Core) setupCluster(ctx context.Context) error {
	return nil
}

-// startClusterListener starts cluster request listeners during postunseal. It
+// ClusterClient is used to lookup a client certificate.
+type ClusterClient interface {
+	ClientLookup(context.Context, *tls.CertificateRequestInfo) (*tls.Certificate, error)
+}
+
+// ClusterHandler exposes functions for looking up TLS configuration and handing
+// off a connection for a cluster listener application.
+type ClusterHandler interface {
+	ServerLookup(context.Context, *tls.ClientHelloInfo) (*tls.Certificate, error)
+	CALookup(context.Context) (*x509.Certificate, error)
+
+	// Handoff is used to pass the connection lifetime off to
+	// the storage backend
+	Handoff(context.Context, *sync.WaitGroup, chan struct{}, *tls.Conn) error
+	Stop() error
+}
+
+// ClusterListener is the source of truth for cluster handlers and connection
+// clients. It dynamically builds the cluster TLS information. It's also
+// responsible for starting tcp listeners and accepting new cluster connections.
+type ClusterListener struct {
+	handlers   map[string]ClusterHandler
+	clients    map[string]ClusterClient
+	shutdown   *uint32
+	shutdownWg *sync.WaitGroup
+	server     *http2.Server
+
+	clusterListenerAddrs []*net.TCPAddr
+	clusterCipherSuites  []uint16
+	logger               log.Logger
+	l                    sync.RWMutex
+}
+
+// AddClient adds a new client for an ALPN name
+func (cl *ClusterListener) AddClient(alpn string, client ClusterClient) {
+	cl.l.Lock()
+	cl.clients[alpn] = client
+	cl.l.Unlock()
+}
+
+// RemoveClient removes the client for the specified ALPN name
+func (cl *ClusterListener) RemoveClient(alpn string) {
+	cl.l.Lock()
+	delete(cl.clients, alpn)
+	cl.l.Unlock()
+}
+
+// AddHandler registers a new cluster handler for the provided ALPN name.
+func (cl *ClusterListener) AddHandler(alpn string, handler ClusterHandler) {
+	cl.l.Lock()
+	cl.handlers[alpn] = handler
+	cl.l.Unlock()
+}
+
+// StopHandler stops the cluster handler for the provided ALPN name; it also
+// calls Stop on the handler.
+func (cl *ClusterListener) StopHandler(alpn string) {
+	cl.l.Lock()
+	handler, ok := cl.handlers[alpn]
+	delete(cl.handlers, alpn)
+	cl.l.Unlock()
+	if ok {
+		handler.Stop()
+	}
+}
+
+// Server returns the http2 server that the cluster listener is using
+func (cl *ClusterListener) Server() *http2.Server {
+	return cl.server
+}
+
+// TLSConfig returns a tls config object that uses dynamic lookups to correctly
+// authenticate registered handlers/clients
+func (cl *ClusterListener) TLSConfig(ctx context.Context) (*tls.Config, error) {
+	serverLookup := func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
+		cl.logger.Debug("performing server cert lookup")
+
+		cl.l.RLock()
+		defer cl.l.RUnlock()
+		for _, v := range clientHello.SupportedProtos {
+			if handler, ok := cl.handlers[v]; ok {
+				return handler.ServerLookup(ctx, clientHello)
+			}
+		}
+
+		return nil, errors.New("unsupported protocol")
+	}
+
+	clientLookup := func(requestInfo *tls.CertificateRequestInfo) (*tls.Certificate, error) {
+		cl.logger.Debug("performing client cert lookup")
+
+		cl.l.RLock()
+		defer cl.l.RUnlock()
+		for _, client := range cl.clients {
+			cert, err := client.ClientLookup(ctx, requestInfo)
+			if err == nil && cert != nil {
+				return cert, nil
+			}
+		}
+
+		return nil, errors.New("no client cert found")
+	}
+
+	serverConfigLookup := func(clientHello *tls.ClientHelloInfo) (*tls.Config, error) {
+		caPool := x509.NewCertPool()
+
+		ret := &tls.Config{
+			ClientAuth:           tls.RequireAndVerifyClientCert,
+			GetCertificate:       serverLookup,
+			GetClientCertificate: clientLookup,
+			MinVersion:           tls.VersionTLS12,
+			RootCAs:              caPool,
+			ClientCAs:            caPool,
+			NextProtos:           clientHello.SupportedProtos,
+			CipherSuites:         cl.clusterCipherSuites,
+		}
+
+		cl.l.RLock()
+		defer cl.l.RUnlock()
+		for _, v := range clientHello.SupportedProtos {
+			if handler, ok := cl.handlers[v]; ok {
+				ca, err := handler.CALookup(ctx)
+				if err != nil {
+					return nil, err
+				}
+
+				caPool.AddCert(ca)
+				return ret, nil
+			}
+		}
+
+		return nil, errors.New("unsupported protocol")
+	}
+
+	return &tls.Config{
+		ClientAuth:           tls.RequireAndVerifyClientCert,
+		GetCertificate:       serverLookup,
+		GetClientCertificate: clientLookup,
+		GetConfigForClient:   serverConfigLookup,
+		MinVersion:           tls.VersionTLS12,
+		CipherSuites:         cl.clusterCipherSuites,
+	}, nil
+}
+
+// Run starts the tcp listeners and will accept connections until stop is
+// called. This function blocks, so it should be called in a goroutine.
+func (cl *ClusterListener) Run(ctx context.Context) error {
+	// Get our TLS config
+	tlsConfig, err := cl.TLSConfig(ctx)
+	if err != nil {
+		cl.logger.Error("failed to get tls configuration when starting cluster listener", "error", err)
+		return err
+	}
+
+	// The server supports all of the possible protos
+	tlsConfig.NextProtos = []string{"h2", requestForwardingALPN, perfStandbyALPN, PerformanceReplicationALPN, DRReplicationALPN}
+
+	for _, addr := range cl.clusterListenerAddrs {
+		cl.shutdownWg.Add(1)
+
+		// Force a local resolution to avoid data races
+		laddr := addr
+
+		// Start our listening loop
+		go func() {
+			defer cl.shutdownWg.Done()
+
+			// closeCh is used to shutdown the spawned goroutines once this
+			// function returns
+			closeCh := make(chan struct{})
+			defer func() {
+				close(closeCh)
+			}()
+
+			if cl.logger.IsInfo() {
+				cl.logger.Info("starting listener", "listener_address", laddr)
+			}
+
+			// Create a TCP listener. We do this separately and specifically
+			// with TCP so that we can set deadlines.
+			tcpLn, err := net.ListenTCP("tcp", laddr)
+			if err != nil {
+				cl.logger.Error("error starting listener", "error", err)
+				return
+			}
+
+			// Wrap the listener with TLS
+			tlsLn := tls.NewListener(tcpLn, tlsConfig)
+			defer tlsLn.Close()
+
+			if cl.logger.IsInfo() {
+				cl.logger.Info("serving cluster requests", "cluster_listen_address", tlsLn.Addr())
+			}
+
+			for {
+				if atomic.LoadUint32(cl.shutdown) > 0 {
+					return
+				}
+
+				// Set the deadline for the accept call. If it passes we'll get
+				// an error, causing us to check the condition at the top
+				// again.
+				tcpLn.SetDeadline(time.Now().Add(clusterListenerAcceptDeadline))
+
+				// Accept the connection
+				conn, err := tlsLn.Accept()
+				if err != nil {
+					if err, ok := err.(net.Error); ok && !err.Timeout() {
+						cl.logger.Debug("non-timeout error accepting on cluster port", "error", err)
+					}
+					if conn != nil {
+						conn.Close()
+					}
+					continue
+				}
+				if conn == nil {
+					continue
+				}
+
+				// Type assert to TLS connection and handshake to populate the
+				// connection state
+				tlsConn := conn.(*tls.Conn)
+
+				// Set a deadline for the handshake. This will cause clients
+				// that don't successfully auth to be kicked out quickly.
+				// Cluster connections should be reliable so being marginally
+				// aggressive here is fine.
+				err = tlsConn.SetDeadline(time.Now().Add(30 * time.Second))
+				if err != nil {
+					if cl.logger.IsDebug() {
+						cl.logger.Debug("error setting deadline for cluster connection", "error", err)
+					}
+					tlsConn.Close()
+					continue
+				}
+
+				err = tlsConn.Handshake()
+				if err != nil {
+					if cl.logger.IsDebug() {
+						cl.logger.Debug("error handshaking cluster connection", "error", err)
+					}
+					tlsConn.Close()
+					continue
+				}
+
+				// Now, set it back to unlimited
+				err = tlsConn.SetDeadline(time.Time{})
+				if err != nil {
+					if cl.logger.IsDebug() {
+						cl.logger.Debug("error setting deadline for cluster connection", "error", err)
+					}
+					tlsConn.Close()
+					continue
+				}
+
+				cl.l.RLock()
+				handler, ok := cl.handlers[tlsConn.ConnectionState().NegotiatedProtocol]
+				cl.l.RUnlock()
+				if !ok {
+					cl.logger.Debug("unknown negotiated protocol on cluster port")
+					tlsConn.Close()
+					continue
+				}
+
+				if err := handler.Handoff(ctx, cl.shutdownWg, closeCh, tlsConn); err != nil {
+					cl.logger.Error("error handling cluster connection", "error", err)
+					continue
+				}
+			}
+		}()
+	}
+
+	return nil
+}
+
+// Stop stops the cluster listener
+func (cl *ClusterListener) Stop() {
+	// Set the shutdown flag. This will cause the listeners to shut down
+	// within the deadline in clusterListenerAcceptDeadline
+	atomic.StoreUint32(cl.shutdown, 1)
+	cl.logger.Info("forwarding rpc listeners stopped")
+
+	// Wait for them all to shut down
+	cl.shutdownWg.Wait()
+	cl.logger.Info("rpc listeners successfully shut down")
+}
+
+// startClusterListener starts cluster request listeners during unseal. It
// is assumed that the state lock is held while this is run. Right now this
-// only starts forwarding listeners; it's TBD whether other request types will
-// be built in the same mechanism or started independently.
+// only starts cluster listeners. Once the listener is started handlers/clients
+// can start being registered to it.
func (c *Core) startClusterListener(ctx context.Context) error {
	if c.clusterAddr == "" {
		c.logger.Info("clustering disabled, not starting listeners")
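Everything on the cluster port is now multiplexed by TLS ALPN: each ClusterHandler supplies the server certificate and CA for its protocol name, and takes ownership of accepted connections via Handoff. A hypothetical sketch, not part of this commit, of what registering another application could look like (the ALPN name and handler are invented; only the request-forwarding handler below actually exists):

// Hypothetical stub ClusterHandler, assuming the vault package context above.
package vault

import (
	"context"
	"crypto/tls"
	"crypto/x509"
	"sync"
)

type echoHandler struct {
	cert   *tls.Certificate  // server cert presented for our ALPN
	caCert *x509.Certificate // CA used to verify peer client certs
}

func (e *echoHandler) ServerLookup(_ context.Context, _ *tls.ClientHelloInfo) (*tls.Certificate, error) {
	return e.cert, nil
}

func (e *echoHandler) CALookup(_ context.Context) (*x509.Certificate, error) {
	return e.caCert, nil
}

// Handoff owns the connection from here on; a real handler would serve its
// protocol and respect closeCh/shutdownWg the way the forwarding handler does.
func (e *echoHandler) Handoff(_ context.Context, _ *sync.WaitGroup, _ chan struct{}, conn *tls.Conn) error {
	return conn.Close()
}

func (e *echoHandler) Stop() error { return nil }

// Wiring it up and tearing it down against a running listener:
//
//	cl.AddHandler("echo_v1", &echoHandler{cert: srvCert, caCert: ca})
//	defer cl.StopHandler("echo_v1")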
@@ -307,76 +585,46 @@ func (c *Core) startClusterListener(ctx context.Context) error {
	c.logger.Debug("starting cluster listeners")

-	err := c.startForwarding(ctx)
-	if err != nil {
-		return err
+	// Create the HTTP/2 server that will be shared by both RPC and regular
+	// duties. Doing it this way instead of listening via the server and gRPC
+	// allows us to re-use the same port via ALPN. We can just tell the server
+	// to serve a given conn and which handler to use.
+	h2Server := &http2.Server{
+		// Our forwarding connections heartbeat regularly so anything else we
+		// want to go away/get cleaned up pretty rapidly
+		IdleTimeout: 5 * HeartbeatInterval,
	}

-	return nil
+	c.clusterListener = &ClusterListener{
+		handlers:   make(map[string]ClusterHandler),
+		clients:    make(map[string]ClusterClient),
+		shutdown:   new(uint32),
+		shutdownWg: &sync.WaitGroup{},
+		server:     h2Server,
+
+		clusterListenerAddrs: c.clusterListenerAddrs,
+		clusterCipherSuites:  c.clusterCipherSuites,
+		logger:               c.logger.Named("cluster-listener"),
+	}
+
+	return c.clusterListener.Run(ctx)
}

-// stopClusterListener stops any existing listeners during preseal. It is
+// stopClusterListener stops any existing listeners during seal. It is
// assumed that the state lock is held while this is run.
func (c *Core) stopClusterListener() {
-	if c.clusterAddr == "" {
+	if c.clusterListener == nil {
		c.logger.Debug("clustering disabled, not stopping listeners")
		return
	}

-	if !c.clusterListenersRunning {
-		c.logger.Info("cluster listeners not running")
-		return
-	}
	c.logger.Info("stopping cluster listeners")
-
-	// Tell the goroutine managing the listeners to perform the shutdown
-	// process
-	c.clusterListenerShutdownCh <- struct{}{}
-
-	// The reason for this loop-de-loop is that we may be unsealing again
-	// quickly, and if the listeners are not yet closed, we will get socket
-	// bind errors. This ensures proper ordering.
-
-	c.logger.Debug("waiting for success notification while stopping cluster listeners")
-	<-c.clusterListenerShutdownSuccessCh
-	c.clusterListenersRunning = false
+	c.clusterListener.Stop()

	c.logger.Info("cluster listeners successfully shut down")
}

-// ClusterTLSConfig generates a TLS configuration based on the local/replicated
-// cluster key and cert.
-func (c *Core) ClusterTLSConfig(ctx context.Context, repClusters *ReplicatedClusters, perfStandbyCluster *ReplicatedCluster) (*tls.Config, error) {
-	// Using lookup functions allows just-in-time lookup of the current state
-	// of clustering as connections come and go
-
-	tlsConfig := &tls.Config{
-		ClientAuth:           tls.RequireAndVerifyClientCert,
-		GetCertificate:       clusterTLSServerLookup(ctx, c, repClusters, perfStandbyCluster),
-		GetClientCertificate: clusterTLSClientLookup(ctx, c, repClusters, perfStandbyCluster),
-		GetConfigForClient:   clusterTLSServerConfigLookup(ctx, c, repClusters, perfStandbyCluster),
-		MinVersion:           tls.VersionTLS12,
-		CipherSuites:         c.clusterCipherSuites,
-	}
-
-	parsedCert := c.localClusterParsedCert.Load().(*x509.Certificate)
-	currCert := c.localClusterCert.Load().([]byte)
-	localCert := make([]byte, len(currCert))
-	copy(localCert, currCert)
-
-	if parsedCert != nil {
-		tlsConfig.ServerName = parsedCert.Subject.CommonName
-
-		pool := x509.NewCertPool()
-		pool.AddCert(parsedCert)
-		tlsConfig.RootCAs = pool
-		tlsConfig.ClientCAs = pool
-	}
-
-	return tlsConfig, nil
-}
-
func (c *Core) SetClusterListenerAddrs(addrs []*net.TCPAddr) {
	c.clusterListenerAddrs = addrs
	if c.clusterAddr == "" && len(addrs) == 1 {
@@ -387,3 +635,36 @@ func (c *Core) SetClusterListenerAddrs(addrs []*net.TCPAddr) {
func (c *Core) SetClusterHandler(handler http.Handler) {
	c.clusterHandler = handler
}
+
+// getGRPCDialer is used to return a dialer that has the correct TLS
+// configuration. Otherwise gRPC tries to be helpful and stomps all over our
+// NextProtos.
+func (c *Core) getGRPCDialer(ctx context.Context, alpnProto, serverName string, caCert *x509.Certificate) func(string, time.Duration) (net.Conn, error) {
+	return func(addr string, timeout time.Duration) (net.Conn, error) {
+		if c.clusterListener == nil {
+			return nil, errors.New("clustering disabled")
+		}
+
+		tlsConfig, err := c.clusterListener.TLSConfig(ctx)
+		if err != nil {
+			c.logger.Error("failed to get tls configuration", "error", err)
+			return nil, err
+		}
+		if serverName != "" {
+			tlsConfig.ServerName = serverName
+		}
+		if caCert != nil {
+			pool := x509.NewCertPool()
+			pool.AddCert(caCert)
+			tlsConfig.RootCAs = pool
+			tlsConfig.ClientCAs = pool
+		}
+		c.logger.Debug("creating rpc dialer", "host", tlsConfig.ServerName)
+
+		tlsConfig.NextProtos = []string{alpnProto}
+		dialer := &net.Dialer{
+			Timeout: timeout,
+		}
+		return tls.DialWithDialer(dialer, "tcp", addr, tlsConfig)
+	}
+}
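Because gRPC would otherwise overwrite NextProtos, callers hand this dialer to grpc.WithDialer and disable gRPC's own TLS. A condensed sketch of the client side, mirroring the refreshRequestForwardingConnection change later in this diff (`clusterHost` is a stand-in for the parsed cluster URL's host:port):

// Sketch only; assumes an unsealed Core c with a running cluster listener.
parsedCert := c.localClusterParsedCert.Load().(*x509.Certificate)

conn, err := grpc.DialContext(ctx, clusterHost, // hypothetical "host:port" value
	grpc.WithDialer(c.getGRPCDialer(ctx, requestForwardingALPN, parsedCert.Subject.CommonName, parsedCert)),
	grpc.WithInsecure(), // not actually insecure: the dialer handles TLS itself
)
if err != nil {
	return err
}
defer conn.Close()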
@@ -4,6 +4,7 @@ import (
	"bytes"
	"context"
	"crypto/tls"
+	"crypto/x509"
	"fmt"
	"net"
	"net/http"

@@ -102,32 +103,25 @@ func TestCluster_ListenForRequests(t *testing.T) {
	TestWaitActive(t, cores[0].Core)

-	// Use this to have a valid config after sealing since ClusterTLSConfig returns nil
-	var lastTLSConfig *tls.Config
	checkListenersFunc := func(expectFail bool) {
-		tlsConfig, err := cores[0].ClusterTLSConfig(context.Background(), nil, nil)
-		if err != nil {
-			if err.Error() != consts.ErrSealed.Error() {
-				t.Fatal(err)
-			}
-			tlsConfig = lastTLSConfig
-		} else {
-			tlsConfig.NextProtos = []string{"h2"}
-			lastTLSConfig = tlsConfig
-		}
+		cores[0].clusterListener.AddClient(requestForwardingALPN, &requestForwardingClusterClient{cores[0].Core})

+		parsedCert := cores[0].localClusterParsedCert.Load().(*x509.Certificate)
+		dialer := cores[0].getGRPCDialer(context.Background(), requestForwardingALPN, parsedCert.Subject.CommonName, parsedCert)
		for _, ln := range cores[0].Listeners {
			tcpAddr, ok := ln.Addr().(*net.TCPAddr)
			if !ok {
				t.Fatalf("%s not a TCP port", tcpAddr.String())
			}

-			conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", tcpAddr.IP.String(), tcpAddr.Port+105), tlsConfig)
+			netConn, err := dialer(fmt.Sprintf("%s:%d", tcpAddr.IP.String(), tcpAddr.Port+105), 0)
+			conn := netConn.(*tls.Conn)
			if err != nil {
				if expectFail {
					t.Logf("testing %s:%d unsuccessful as expected", tcpAddr.IP.String(), tcpAddr.Port+105)
					continue
				}
-				t.Fatalf("error: %v\nlisteners are\n%#v\n%#v\n", err, cores[0].Listeners[0], cores[0].Listeners[1])
+				t.Fatalf("error: %v\nlisteners are\n%#v\n%#v\n", err, cores[0].Listeners[0], cores[0].Listeners[0])
			}
			if expectFail {
				t.Fatalf("testing %s:%d not unsuccessful as expected", tcpAddr.IP.String(), tcpAddr.Port+105)

@@ -140,7 +134,7 @@ func TestCluster_ListenForRequests(t *testing.T) {
			switch {
			case connState.Version != tls.VersionTLS12:
				t.Fatal("version mismatch")
-			case connState.NegotiatedProtocol != "h2" || !connState.NegotiatedProtocolIsMutual:
+			case connState.NegotiatedProtocol != requestForwardingALPN || !connState.NegotiatedProtocolIsMutual:
				t.Fatal("bad protocol negotiation")
			}
			t.Logf("testing %s:%d successful", tcpAddr.IP.String(), tcpAddr.Port+105)

@@ -392,12 +386,13 @@ func TestCluster_CustomCipherSuites(t *testing.T) {
	// Wait for core to become active
	TestWaitActive(t, core.Core)

-	tlsConf, err := core.Core.ClusterTLSConfig(context.Background(), nil, nil)
-	if err != nil {
-		t.Fatal(err)
-	}
+	core.clusterListener.AddClient(requestForwardingALPN, &requestForwardingClusterClient{core.Core})

-	conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", core.Listeners[0].Address.IP.String(), core.Listeners[0].Address.Port+105), tlsConf)
+	parsedCert := core.localClusterParsedCert.Load().(*x509.Certificate)
+	dialer := core.getGRPCDialer(context.Background(), requestForwardingALPN, parsedCert.Subject.CommonName, parsedCert)
+
+	netConn, err := dialer(fmt.Sprintf("%s:%d", core.Listeners[0].Address.IP.String(), core.Listeners[0].Address.Port+105), 0)
+	conn := netConn.(*tls.Conn)
	if err != nil {
		t.Fatal(err)
	}
@@ -1,85 +0,0 @@
-package vault
-
-import (
-	"context"
-	"crypto/ecdsa"
-	"crypto/tls"
-	"crypto/x509"
-	"fmt"
-)
-
-var (
-	clusterTLSServerLookup = func(ctx context.Context, c *Core, repClusters *ReplicatedClusters, _ *ReplicatedCluster) func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
-		return func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
-			c.logger.Debug("performing server cert lookup")
-
-			switch {
-			default:
-				currCert := c.localClusterCert.Load().([]byte)
-				if len(currCert) == 0 {
-					return nil, fmt.Errorf("got forwarding connection but no local cert")
-				}
-
-				localCert := make([]byte, len(currCert))
-				copy(localCert, currCert)
-
-				return &tls.Certificate{
-					Certificate: [][]byte{localCert},
-					PrivateKey:  c.localClusterPrivateKey.Load().(*ecdsa.PrivateKey),
-					Leaf:        c.localClusterParsedCert.Load().(*x509.Certificate),
-				}, nil
-			}
-		}
-	}
-
-	clusterTLSClientLookup = func(ctx context.Context, c *Core, repClusters *ReplicatedClusters, _ *ReplicatedCluster) func(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
-		return func(requestInfo *tls.CertificateRequestInfo) (*tls.Certificate, error) {
-			if len(requestInfo.AcceptableCAs) != 1 {
-				return nil, fmt.Errorf("expected only a single acceptable CA")
-			}
-
-			currCert := c.localClusterCert.Load().([]byte)
-			if len(currCert) == 0 {
-				return nil, fmt.Errorf("forwarding connection client but no local cert")
-			}
-
-			localCert := make([]byte, len(currCert))
-			copy(localCert, currCert)
-
-			return &tls.Certificate{
-				Certificate: [][]byte{localCert},
-				PrivateKey:  c.localClusterPrivateKey.Load().(*ecdsa.PrivateKey),
-				Leaf:        c.localClusterParsedCert.Load().(*x509.Certificate),
-			}, nil
-		}
-	}
-
-	clusterTLSServerConfigLookup = func(ctx context.Context, c *Core, repClusters *ReplicatedClusters, repCluster *ReplicatedCluster) func(clientHello *tls.ClientHelloInfo) (*tls.Config, error) {
-		return func(clientHello *tls.ClientHelloInfo) (*tls.Config, error) {
-			//c.logger.Trace("performing server config lookup")
-
-			caPool := x509.NewCertPool()
-
-			ret := &tls.Config{
-				ClientAuth:           tls.RequireAndVerifyClientCert,
-				GetCertificate:       clusterTLSServerLookup(ctx, c, repClusters, repCluster),
-				GetClientCertificate: clusterTLSClientLookup(ctx, c, repClusters, repCluster),
-				MinVersion:           tls.VersionTLS12,
-				RootCAs:              caPool,
-				ClientCAs:            caPool,
-				NextProtos:           clientHello.SupportedProtos,
-				CipherSuites:         c.clusterCipherSuites,
-			}
-
-			parsedCert := c.localClusterParsedCert.Load().(*x509.Certificate)
-
-			if parsedCert == nil {
-				return nil, fmt.Errorf("forwarding connection client but no local cert")
-			}
-
-			caPool.AddCert(parsedCert)
-
-			return ret, nil
-		}
-	}
-)
vault/core.go (106 changed lines)
@@ -26,6 +26,7 @@ import (
	"github.com/hashicorp/errwrap"
	"github.com/hashicorp/vault/audit"
+	"github.com/hashicorp/vault/helper/certutil"
	"github.com/hashicorp/vault/helper/consts"
	"github.com/hashicorp/vault/helper/jsonutil"
	"github.com/hashicorp/vault/helper/logging"

@@ -132,10 +133,10 @@ func (e *ErrInvalidKey) Error() string {
type RegisterAuthFunc func(context.Context, time.Duration, string, *logical.Auth) error

type activeAdvertisement struct {
-	RedirectAddr     string            `json:"redirect_addr"`
-	ClusterAddr      string            `json:"cluster_addr,omitempty"`
-	ClusterCert      []byte            `json:"cluster_cert,omitempty"`
-	ClusterKeyParams *clusterKeyParams `json:"cluster_key_params,omitempty"`
+	RedirectAddr     string                     `json:"redirect_addr"`
+	ClusterAddr      string                     `json:"cluster_addr,omitempty"`
+	ClusterCert      []byte                     `json:"cluster_cert,omitempty"`
+	ClusterKeyParams *certutil.ClusterKeyParams `json:"cluster_key_params,omitempty"`
}

type unlockInformation struct {

@@ -328,14 +329,6 @@ type Core struct {
	clusterListenerAddrs []*net.TCPAddr
	// The handler to use for request forwarding
	clusterHandler http.Handler
-	// Tracks whether cluster listeners are running, e.g. it's safe to send a
-	// shutdown down the channel
-	clusterListenersRunning bool
-	// Shutdown channel for the cluster listeners
-	clusterListenerShutdownCh chan struct{}
-	// Shutdown success channel. We need this to be done serially to ensure
-	// that binds are removed before they might be reinstated.
-	clusterListenerShutdownSuccessCh chan struct{}
	// Write lock used to ensure that we don't have multiple connections adjust
	// this value at the same time
	requestForwardingConnectionLock sync.RWMutex

@@ -346,8 +339,6 @@ type Core struct {
	clusterLeaderParams *atomic.Value
	// Info on cluster members
	clusterPeerClusterAddrsCache *cache.Cache
-	// Stores whether we currently have a server running
-	rpcServerActive *uint32
	// The context for the client
	rpcClientConnContext context.Context
	// The function for canceling the client connection

@@ -422,6 +413,9 @@ type Core struct {
	// artifacts in a case sensitive manner. To be used only in testing.
	loadCaseSensitiveIdentityStore bool

+	// clusterListener starts up and manages connections on the cluster ports
+	clusterListener *ClusterListener
+
	// Telemetry objects
	metricsHelper *metricsutil.MetricsHelper
}

@@ -567,43 +561,40 @@ func NewCore(conf *CoreConfig) (*Core, error) {

	// Setup the core
	c := &Core{
		entCore:         entCore{},
		devToken:        conf.DevToken,
		physical:        conf.Physical,
		redirectAddr:    conf.RedirectAddr,
		clusterAddr:     conf.ClusterAddr,
		seal:            conf.Seal,
		router:          NewRouter(),
		sealed:          new(uint32),
		standby:         true,
		baseLogger:      conf.Logger,
		logger:          conf.Logger.Named("core"),
		defaultLeaseTTL: conf.DefaultLeaseTTL,
		maxLeaseTTL:     conf.MaxLeaseTTL,
		cachingDisabled: conf.DisableCache,
		clusterName:     conf.ClusterName,
-		clusterListenerShutdownCh:        make(chan struct{}),
-		clusterListenerShutdownSuccessCh: make(chan struct{}),
		clusterPeerClusterAddrsCache: cache.New(3*HeartbeatInterval, time.Second),
		enableMlock:                  !conf.DisableMlock,
		rawEnabled:                   conf.EnableRaw,
		replicationState:             new(uint32),
-		rpcServerActive:              new(uint32),
		atomicPrimaryClusterAddrs:    new(atomic.Value),
		atomicPrimaryFailoverAddrs:   new(atomic.Value),
		localClusterPrivateKey:       new(atomic.Value),
		localClusterCert:             new(atomic.Value),
		localClusterParsedCert:       new(atomic.Value),
		activeNodeReplicationState:   new(uint32),
		keepHALockOnStepDown:         new(uint32),
		replicationFailure:           new(uint32),
		disablePerfStandby:           true,
		activeContextCancelFunc:      new(atomic.Value),
		allLoggers:                   conf.AllLoggers,
		builtinRegistry:              conf.BuiltinRegistry,
		neverBecomeActive:            new(uint32),
		clusterLeaderParams:          new(atomic.Value),
		metricsHelper:                conf.MetricsHelper,
	}

	atomic.StoreUint32(c.sealed, 1)

@@ -1043,6 +1034,10 @@ func (c *Core) unsealInternal(ctx context.Context, masterKey []byte) (bool, error) {
		return false, err
	}

+	if err := c.startClusterListener(ctx); err != nil {
+		return false, err
+	}
+
	// Do post-unseal setup if HA is not enabled
	if c.ha == nil {
		// We still need to set up cluster info even if it's not part of a

@@ -1365,6 +1360,9 @@ func (c *Core) sealInternalWithOptions(grabStateLock, keepHALock bool) error {
		c.logger.Debug("runStandby done")
	}

+	// Stop the cluster listener
+	c.stopClusterListener()
+
	c.logger.Debug("sealing barrier")
	if err := c.barrier.Seal(); err != nil {
		c.logger.Error("error sealing barrier", "error", err)

@@ -1461,8 +1459,8 @@ func (s standardUnsealStrategy) unseal(ctx context.Context, logger log.Logger, c *Core) error {
		c.auditBroker = NewAuditBroker(c.logger)
	}

-	if c.ha != nil || shouldStartClusterListener(c) {
-		if err := c.startClusterListener(ctx); err != nil {
+	if c.clusterListener != nil && (c.ha != nil || shouldStartClusterListener(c)) {
+		if err := c.startForwarding(ctx); err != nil {
			return err
		}
	}

@@ -1553,7 +1551,7 @@ func (c *Core) preSeal() error {
	}
	c.clusterParamsLock.Unlock()

-	c.stopClusterListener()
+	c.stopForwarding()

	if err := c.teardownAudits(); err != nil {
		result = multierror.Append(result, errwrap.Wrapf("error tearing down audits: {{err}}", err))
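Taken together, the core.go hunks split the old combined lifecycle in two: raw cluster listeners now live for the whole unsealed period, while the forwarding handler is attached and detached around active duty. A condensed, purely illustrative sketch of the resulting call order (a hypothetical wrapper, error handling elided):

// Not real code in this commit; just the ordering the hunks above establish.
func lifecycleSketch(ctx context.Context, c *Core) {
	_ = c.startClusterListener(ctx) // unsealInternal: TCP+TLS accept loops start
	_ = c.startForwarding(ctx)      // standardUnsealStrategy.unseal: forwarding handler registered via ALPN
	c.stopForwarding()              // preSeal: forwarding/perf-standby handlers stopped
	c.stopClusterListener()         // sealInternalWithOptions: listeners drain and exit
}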
@@ -4,11 +4,14 @@ package vault

import (
	"context"
+	"time"

	"github.com/hashicorp/vault/helper/license"
	"github.com/hashicorp/vault/helper/namespace"
	"github.com/hashicorp/vault/logical"
	"github.com/hashicorp/vault/physical"
+	"github.com/hashicorp/vault/vault/replication"
+	cache "github.com/patrickmn/go-cache"
)

type entCore struct{}

@@ -89,7 +92,7 @@ func (c *Core) namepaceByPath(string) *namespace.Namespace {
	return namespace.RootNamespace
}

-func (c *Core) setupReplicatedClusterPrimary(*ReplicatedCluster) error { return nil }
+func (c *Core) setupReplicatedClusterPrimary(*replication.Cluster) error { return nil }

func (c *Core) perfStandbyCount() int { return 0 }

@@ -104,3 +107,7 @@ func (c *Core) checkReplicatedFiltering(context.Context, *MountEntry, string) (bool, error) {
func (c *Core) invalidateSentinelPolicy(PolicyType, string) {}

func (c *Core) removePerfStandbySecondary(context.Context, string) {}

+func (c *Core) perfStandbyClusterHandler() (*replication.Cluster, *cache.Cache, chan struct{}, error) {
+	return nil, cache.New(2*HeartbeatInterval, 1*time.Second), make(chan struct{}), nil
+}
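The OSS stub above still hands back a live go-cache instance keyed to the heartbeat interval (the enterprise build presumably supplies the real cluster and slot bookkeeping). A short sketch of the go-cache semantics this relies on — standard github.com/patrickmn/go-cache API, mirroring the OnEvicted usage visible in the old startForwarding body further down in this diff:

// Sketch only; assumes the vault package context for HeartbeatInterval.
c := cache.New(2*HeartbeatInterval, 1*time.Second) // default TTL, janitor sweep interval
c.OnEvicted(func(id string, _ interface{}) {
	// Fires when a standby stops heartbeating and its entry expires.
	fmt.Println("standby expired:", id)
})
c.SetDefault("standby-id", struct{}{}) // each heartbeat would re-Set to extend the TTL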
@@ -14,6 +14,7 @@ import (
	multierror "github.com/hashicorp/go-multierror"
	uuid "github.com/hashicorp/go-uuid"
	"github.com/hashicorp/vault/audit"
+	"github.com/hashicorp/vault/helper/certutil"
	"github.com/hashicorp/vault/helper/consts"
	"github.com/hashicorp/vault/helper/jsonutil"
	"github.com/hashicorp/vault/helper/namespace"

@@ -817,7 +818,7 @@ func (c *Core) advertiseLeader(ctx context.Context, uuid string, leaderLostCh <-chan struct{}) error {
		return fmt.Errorf("unknown cluster private key type %T", c.localClusterPrivateKey.Load())
	}

-	keyParams := &clusterKeyParams{
+	keyParams := &certutil.ClusterKeyParams{
		Type: corePrivateKeyTypeP521,
		X:    key.X,
		Y:    key.Y,
@@ -1,11 +1,16 @@
// +build !enterprise

-package vault
+package replication

import "github.com/hashicorp/vault/helper/consts"

-type ReplicatedCluster struct {
+type Cluster struct {
	State              consts.ReplicationState
	ClusterID          string
	PrimaryClusterAddr string
}
+
+type Clusters struct {
+	DR          *Cluster
+	Performance *Cluster
+}
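ReplicatedCluster moves out of package vault and becomes replication.Cluster, presumably so that other packages can share the type without importing vault itself. An illustrative, hypothetical use of the new types (the states shown are existing consts.ReplicationState values, not something this commit adds):

// Illustrative only: bundling per-mode replication state with the new types.
package main

import (
	"fmt"

	"github.com/hashicorp/vault/helper/consts"
	"github.com/hashicorp/vault/vault/replication"
)

func main() {
	clusters := replication.Clusters{
		DR:          &replication.Cluster{State: consts.ReplicationDRPrimary, ClusterID: "dr-1"},
		Performance: &replication.Cluster{State: consts.ReplicationPerformancePrimary, ClusterID: "perf-1"},
	}
	fmt.Println(clusters.DR.State, clusters.Performance.State)
}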
@@ -1,23 +1,23 @@
package vault

import (
	"bytes"
	"context"
	"crypto/ecdsa"
	"crypto/tls"
	"crypto/x509"
	"errors"
	"fmt"
	math "math"
	"net"
	"net/http"
	"net/url"
	"sync"
	"sync/atomic"
	"time"

-	cache "github.com/patrickmn/go-cache"
-
-	uuid "github.com/hashicorp/go-uuid"
-	"github.com/hashicorp/vault/helper/consts"
+	log "github.com/hashicorp/go-hclog"
	"github.com/hashicorp/vault/helper/forwarding"
+	"github.com/hashicorp/vault/vault/replication"
+	cache "github.com/patrickmn/go-cache"
	"golang.org/x/net/http2"
	"google.golang.org/grpc"
	"google.golang.org/grpc/keepalive"
@@ -44,11 +44,154 @@ var (
	HeartbeatInterval = 5 * time.Second
)

-type SecondaryConnsCacheVals struct {
-	ID         string
-	Token      string
-	Connection net.Conn
-	Mode       consts.ReplicationState
-}
+type requestForwardingHandler struct {
+	fws         *http2.Server
+	fwRPCServer *grpc.Server
+	logger      log.Logger
+	ha          bool
+	core        *Core
+	stopCh      chan struct{}
+}
+
+type requestForwardingClusterClient struct {
+	core *Core
+}
+
+// NewRequestForwardingHandler creates a cluster handler for use with request
+// forwarding.
+func NewRequestForwardingHandler(c *Core, fws *http2.Server, perfStandbySlots chan struct{}, perfStandbyRepCluster *replication.Cluster, perfStandbyCache *cache.Cache) (*requestForwardingHandler, error) {
+	// Resolve locally to avoid races
+	ha := c.ha != nil
+
+	fwRPCServer := grpc.NewServer(
+		grpc.KeepaliveParams(keepalive.ServerParameters{
+			Time: 2 * HeartbeatInterval,
+		}),
+		grpc.MaxRecvMsgSize(math.MaxInt32),
+		grpc.MaxSendMsgSize(math.MaxInt32),
+	)
+
+	if ha && c.clusterHandler != nil {
+		RegisterRequestForwardingServer(fwRPCServer, &forwardedRequestRPCServer{
+			core:                  c,
+			handler:               c.clusterHandler,
+			perfStandbySlots:      perfStandbySlots,
+			perfStandbyRepCluster: perfStandbyRepCluster,
+			perfStandbyCache:      perfStandbyCache,
+		})
+	}
+
+	return &requestForwardingHandler{
+		fws:         fws,
+		fwRPCServer: fwRPCServer,
+		ha:          ha,
+		logger:      c.logger.Named("request-forward"),
+		core:        c,
+		stopCh:      make(chan struct{}),
+	}, nil
+}
+
+// ClientLookup satisfies the ClusterClient interface and returns the ha tls
+// client certs.
+func (c *requestForwardingClusterClient) ClientLookup(ctx context.Context, requestInfo *tls.CertificateRequestInfo) (*tls.Certificate, error) {
+	parsedCert := c.core.localClusterParsedCert.Load().(*x509.Certificate)
+	if parsedCert == nil {
+		return nil, nil
+	}
+	currCert := c.core.localClusterCert.Load().([]byte)
+	if len(currCert) == 0 {
+		return nil, nil
+	}
+	localCert := make([]byte, len(currCert))
+	copy(localCert, currCert)
+
+	for _, subj := range requestInfo.AcceptableCAs {
+		if bytes.Equal(subj, parsedCert.RawIssuer) {
+			return &tls.Certificate{
+				Certificate: [][]byte{localCert},
+				PrivateKey:  c.core.localClusterPrivateKey.Load().(*ecdsa.PrivateKey),
+				Leaf:        c.core.localClusterParsedCert.Load().(*x509.Certificate),
+			}, nil
+		}
+	}
+
+	return nil, nil
+}
+
+// ServerLookup satisfies the ClusterHandler interface and returns the server's
+// tls certs.
+func (rf *requestForwardingHandler) ServerLookup(ctx context.Context, clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
+	currCert := rf.core.localClusterCert.Load().([]byte)
+	if len(currCert) == 0 {
+		return nil, fmt.Errorf("got forwarding connection but no local cert")
+	}
+
+	localCert := make([]byte, len(currCert))
+	copy(localCert, currCert)
+
+	return &tls.Certificate{
+		Certificate: [][]byte{localCert},
+		PrivateKey:  rf.core.localClusterPrivateKey.Load().(*ecdsa.PrivateKey),
+		Leaf:        rf.core.localClusterParsedCert.Load().(*x509.Certificate),
+	}, nil
+}
+
+// CALookup satisfies the ClusterHandler interface and returns the ha ca cert.
+func (rf *requestForwardingHandler) CALookup(ctx context.Context) (*x509.Certificate, error) {
+	parsedCert := rf.core.localClusterParsedCert.Load().(*x509.Certificate)
+
+	if parsedCert == nil {
+		return nil, fmt.Errorf("forwarding connection client but no local cert")
+	}
+
+	return parsedCert, nil
+}
+
+// Handoff serves a request forwarding connection.
+func (rf *requestForwardingHandler) Handoff(ctx context.Context, shutdownWg *sync.WaitGroup, closeCh chan struct{}, tlsConn *tls.Conn) error {
+	if !rf.ha {
+		tlsConn.Close()
+		return nil
+	}
+
+	rf.logger.Debug("got request forwarding connection")
+
+	shutdownWg.Add(2)
+	// quitCh is used to close the connection and the second
+	// goroutine if the server closes before closeCh.
+	quitCh := make(chan struct{})
+	go func() {
+		select {
+		case <-quitCh:
+		case <-closeCh:
+		case <-rf.stopCh:
+		}
+		tlsConn.Close()
+		shutdownWg.Done()
+	}()
+
+	go func() {
+		rf.fws.ServeConn(tlsConn, &http2.ServeConnOpts{
+			Handler: rf.fwRPCServer,
+			BaseConfig: &http.Server{
+				ErrorLog: rf.logger.StandardLogger(nil),
+			},
+		})
+
+		// close the quitCh which will close the connection and
+		// the other goroutine.
+		close(quitCh)
+		shutdownWg.Done()
+	}()
+
+	return nil
+}
+
+// Stop stops the request forwarding server and closes connections.
+func (rf *requestForwardingHandler) Stop() error {
+	close(rf.stopCh)
+	rf.fwRPCServer.Stop()
+	return nil
+}

// Starts the listeners and servers necessary to handle forwarded requests
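On the client side of these connections, a ClusterClient only has to produce a certificate when asked; the implementation above additionally checks the server's AcceptableCAs against the local cluster cert's issuer before answering. A minimal hypothetical client for comparison (not part of this commit):

// Hypothetical minimal ClusterClient: always offer one pre-built client
// certificate, with none of the issuer matching done above.
type staticClusterClient struct {
	cert *tls.Certificate
}

func (s *staticClusterClient) ClientLookup(_ context.Context, _ *tls.CertificateRequestInfo) (*tls.Certificate, error) {
	return s.cert, nil
}

// Registered the same way the forwarding client is registered later in this
// diff:
//
//	cl.AddClient(requestForwardingALPN, &staticClusterClient{cert: clientCert})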
@@ -62,269 +205,32 @@ func (c *Core) startForwarding(ctx context.Context) error {
	c.requestForwardingConnectionLock.Unlock()

-	// Resolve locally to avoid races
-	ha := c.ha != nil
-
-	var perfStandbyRepCluster *ReplicatedCluster
-	if ha {
-		id, err := uuid.GenerateUUID()
-		if err != nil {
-			return err
-		}
-
-		perfStandbyRepCluster = &ReplicatedCluster{
-			State:              consts.ReplicationPerformanceStandby,
-			ClusterID:          id,
-			PrimaryClusterAddr: c.clusterAddr,
-		}
-		if err = c.setupReplicatedClusterPrimary(perfStandbyRepCluster); err != nil {
-			return err
-		}
-	}
-
-	// Get our TLS config
-	tlsConfig, err := c.ClusterTLSConfig(ctx, nil, perfStandbyRepCluster)
-	if err != nil {
-		c.logger.Error("failed to get tls configuration when starting forwarding", "error", err)
-		return err
-	}
-
-	// The server supports all of the possible protos
-	tlsConfig.NextProtos = []string{"h2", requestForwardingALPN, perfStandbyALPN, PerformanceReplicationALPN, DRReplicationALPN}
-
-	if !atomic.CompareAndSwapUint32(c.rpcServerActive, 0, 1) {
-		c.logger.Warn("forwarding rpc server already running")
+	if c.ha == nil || c.clusterListener == nil {
		return nil
	}

-	fwRPCServer := grpc.NewServer(
-		grpc.KeepaliveParams(keepalive.ServerParameters{
-			Time: 2 * HeartbeatInterval,
-		}),
-		grpc.MaxRecvMsgSize(math.MaxInt32),
-		grpc.MaxSendMsgSize(math.MaxInt32),
-	)
-
-	// Setup performance standby RPC servers
-	perfStandbyCount := 0
-	if !c.IsDRSecondary() && !c.disablePerfStandby {
-		perfStandbyCount = c.perfStandbyCount()
-	}
-	perfStandbySlots := make(chan struct{}, perfStandbyCount)
-
-	perfStandbyCache := cache.New(2*HeartbeatInterval, 1*time.Second)
-	perfStandbyCache.OnEvicted(func(secondaryID string, _ interface{}) {
-		c.logger.Debug("removing performance standby", "id", secondaryID)
-		c.removePerfStandbySecondary(context.Background(), secondaryID)
-		select {
-		case <-perfStandbySlots:
-		default:
-			c.logger.Warn("perf secondary timeout hit but no slot to free")
-		}
-	})
-
-	perfStandbyReplicationRPCServer := perfStandbyRPCServer(c, perfStandbyCache)
-
-	if ha && c.clusterHandler != nil {
-		RegisterRequestForwardingServer(fwRPCServer, &forwardedRequestRPCServer{
-			core:                  c,
-			handler:               c.clusterHandler,
-			perfStandbySlots:      perfStandbySlots,
-			perfStandbyRepCluster: perfStandbyRepCluster,
-			perfStandbyCache:      perfStandbyCache,
-		})
+	perfStandbyRepCluster, perfStandbyCache, perfStandbySlots, err := c.perfStandbyClusterHandler()
+	if err != nil {
+		return err
	}

-	// Create the HTTP/2 server that will be shared by both RPC and regular
-	// duties. Doing it this way instead of listening via the server and gRPC
-	// allows us to re-use the same port via ALPN. We can just tell the server
-	// to serve a given conn and which handler to use.
-	fws := &http2.Server{
-		// Our forwarding connections heartbeat regularly so anything else we
-		// want to go away/get cleaned up pretty rapidly
-		IdleTimeout: 5 * HeartbeatInterval,
+	handler, err := NewRequestForwardingHandler(c, c.clusterListener.Server(), perfStandbySlots, perfStandbyRepCluster, perfStandbyCache)
+	if err != nil {
+		return err
	}

-	// Shutdown coordination logic
-	shutdown := new(uint32)
-	shutdownWg := &sync.WaitGroup{}
-
-	for _, addr := range c.clusterListenerAddrs {
-		shutdownWg.Add(1)
-
-		// Force a local resolution to avoid data races
-		laddr := addr
-
-		// Start our listening loop
-		go func() {
-			defer shutdownWg.Done()
-
-			// closeCh is used to shutdown the spawned goroutines once this
-			// function returns
-			closeCh := make(chan struct{})
-			defer func() {
-				close(closeCh)
-			}()
-
-			if c.logger.IsInfo() {
-				c.logger.Info("starting listener", "listener_address", laddr)
-			}
-
-			// Create a TCP listener. We do this separately and specifically
-			// with TCP so that we can set deadlines.
-			tcpLn, err := net.ListenTCP("tcp", laddr)
-			if err != nil {
-				c.logger.Error("error starting listener", "error", err)
-				return
-			}
-
-			// Wrap the listener with TLS
-			tlsLn := tls.NewListener(tcpLn, tlsConfig)
-			defer tlsLn.Close()
-
-			if c.logger.IsInfo() {
-				c.logger.Info("serving cluster requests", "cluster_listen_address", tlsLn.Addr())
-			}
-
-			for {
-				if atomic.LoadUint32(shutdown) > 0 {
-					return
-				}
-
-				// Set the deadline for the accept call. If it passes we'll get
-				// an error, causing us to check the condition at the top
-				// again.
-				tcpLn.SetDeadline(time.Now().Add(clusterListenerAcceptDeadline))
-
-				// Accept the connection
-				conn, err := tlsLn.Accept()
-				if err != nil {
-					if err, ok := err.(net.Error); ok && !err.Timeout() {
-						c.logger.Debug("non-timeout error accepting on cluster port", "error", err)
-					}
-					if conn != nil {
-						conn.Close()
-					}
-					continue
-				}
-				if conn == nil {
-					continue
-				}
-
-				// Type assert to TLS connection and handshake to populate the
-				// connection state
-				tlsConn := conn.(*tls.Conn)
-
-				// Set a deadline for the handshake. This will cause clients
-				// that don't successfully auth to be kicked out quickly.
-				// Cluster connections should be reliable so being marginally
-				// aggressive here is fine.
-				err = tlsConn.SetDeadline(time.Now().Add(30 * time.Second))
-				if err != nil {
-					if c.logger.IsDebug() {
-						c.logger.Debug("error setting deadline for cluster connection", "error", err)
-					}
-					tlsConn.Close()
-					continue
-				}
-
-				err = tlsConn.Handshake()
-				if err != nil {
-					if c.logger.IsDebug() {
-						c.logger.Debug("error handshaking cluster connection", "error", err)
-					}
-					tlsConn.Close()
-					continue
-				}
-
-				// Now, set it back to unlimited
-				err = tlsConn.SetDeadline(time.Time{})
-				if err != nil {
-					if c.logger.IsDebug() {
-						c.logger.Debug("error setting deadline for cluster connection", "error", err)
-					}
-					tlsConn.Close()
-					continue
-				}
-
-				switch tlsConn.ConnectionState().NegotiatedProtocol {
-				case requestForwardingALPN:
-					if !ha {
-						tlsConn.Close()
-						continue
-					}
-
-					c.logger.Debug("got request forwarding connection")
-
-					shutdownWg.Add(2)
-					// quitCh is used to close the connection and the second
-					// goroutine if the server closes before closeCh.
-					quitCh := make(chan struct{})
-					go func() {
-						select {
-						case <-quitCh:
-						case <-closeCh:
-						}
-						tlsConn.Close()
-						shutdownWg.Done()
-					}()
-
-					go func() {
-						fws.ServeConn(tlsConn, &http2.ServeConnOpts{
-							Handler: fwRPCServer,
-							BaseConfig: &http.Server{
-								ErrorLog: c.logger.StandardLogger(nil),
-							},
-						})
-						// close the quitCh which will close the connection and
-						// the other goroutine.
-						close(quitCh)
-						shutdownWg.Done()
-					}()
-
-				case PerformanceReplicationALPN, DRReplicationALPN, perfStandbyALPN:
-					handleReplicationConn(ctx, c, shutdownWg, closeCh, fws, perfStandbyReplicationRPCServer, perfStandbyCache, tlsConn)
-				default:
-					c.logger.Debug("unknown negotiated protocol on cluster port")
-					tlsConn.Close()
-					continue
-				}
-			}
-		}()
-	}
-
-	// This is in its own goroutine so that we don't block the main thread, and
-	// thus we use atomic and channels to coordinate
-	// However, because you can't query the status of a channel, we set a bool
-	// here while we have the state lock to know whether to actually send a
-	// shutdown (e.g. whether the channel will block). See issue #2083.
-	c.clusterListenersRunning = true
-	go func() {
-		// If we get told to shut down...
-		<-c.clusterListenerShutdownCh
-
-		// Stop the RPC server
-		c.logger.Info("shutting down forwarding rpc listeners")
-		fwRPCServer.Stop()
-
-		// Set the shutdown flag. This will cause the listeners to shut down
-		// within the deadline in clusterListenerAcceptDeadline
-		atomic.StoreUint32(shutdown, 1)
-		c.logger.Info("forwarding rpc listeners stopped")
-
-		// Wait for them all to shut down
-		shutdownWg.Wait()
-		c.logger.Info("rpc listeners successfully shut down")
-
-		// Clear us up to run this function again
-		atomic.StoreUint32(c.rpcServerActive, 0)
-
-		// Tell the main thread that shutdown is done.
-		c.clusterListenerShutdownSuccessCh <- struct{}{}
-	}()
+	c.clusterListener.AddHandler(requestForwardingALPN, handler)

	return nil
}

+func (c *Core) stopForwarding() {
+	if c.clusterListener != nil {
+		c.clusterListener.StopHandler(requestForwardingALPN)
+		c.clusterListener.StopHandler(perfStandbyALPN)
+	}
+}
+
// refreshRequestForwardingConnection ensures that the client/transport are
// alive and that the current active address value matches the most
// recently-known address.
@@ -349,13 +255,25 @@ func (c *Core) refreshRequestForwardingConnection(ctx context.Context, clusterAddr string) error {
		return err
	}

+	parsedCert := c.localClusterParsedCert.Load().(*x509.Certificate)
+	if parsedCert == nil {
+		c.logger.Error("no request forwarding cluster certificate found")
+		return errors.New("no request forwarding cluster certificate found")
+	}
+
+	if c.clusterListener != nil {
+		c.clusterListener.AddClient(requestForwardingALPN, &requestForwardingClusterClient{
+			core: c,
+		})
+	}
+
	// Set up grpc forwarding handling
	// It's not really insecure, but we have to dial manually to get the
	// ALPN header right. It's just "insecure" because GRPC isn't managing
	// the TLS state.
	dctx, cancelFunc := context.WithCancel(ctx)
	c.rpcClientConn, err = grpc.DialContext(dctx, clusterURL.Host,
-		grpc.WithDialer(c.getGRPCDialer(ctx, requestForwardingALPN, "", nil, nil, nil)),
+		grpc.WithDialer(c.getGRPCDialer(ctx, requestForwardingALPN, parsedCert.Subject.CommonName, parsedCert)),
		grpc.WithInsecure(), // it's not, we handle it in the dialer
		grpc.WithKeepaliveParams(keepalive.ClientParameters{
			Time: 2 * HeartbeatInterval,
@@ -398,6 +316,9 @@ func (c *Core) clearForwardingClients() {
	c.rpcClientConnContext = nil
	c.rpcForwardingClient = nil

+	if c.clusterListener != nil {
+		c.clusterListener.RemoveClient(requestForwardingALPN)
+	}
	c.clusterLeaderParams.Store((*ClusterLeaderParams)(nil))
}
@@ -450,32 +371,3 @@ func (c *Core) ForwardRequest(req *http.Request) (int, http.Header, []byte, error) {
	return int(resp.StatusCode), header, resp.Body, nil
}

-// getGRPCDialer is used to return a dialer that has the correct TLS
-// configuration. Otherwise gRPC tries to be helpful and stomps all over our
-// NextProtos.
-func (c *Core) getGRPCDialer(ctx context.Context, alpnProto, serverName string, caCert *x509.Certificate, repClusters *ReplicatedClusters, perfStandbyCluster *ReplicatedCluster) func(string, time.Duration) (net.Conn, error) {
-	return func(addr string, timeout time.Duration) (net.Conn, error) {
-		tlsConfig, err := c.ClusterTLSConfig(ctx, repClusters, perfStandbyCluster)
-		if err != nil {
-			c.logger.Error("failed to get tls configuration", "error", err)
-			return nil, err
-		}
-		if serverName != "" {
-			tlsConfig.ServerName = serverName
-		}
-		if caCert != nil {
-			pool := x509.NewCertPool()
-			pool.AddCert(caCert)
-			tlsConfig.RootCAs = pool
-			tlsConfig.ClientCAs = pool
-		}
-		c.logger.Debug("creating rpc dialer", "host", tlsConfig.ServerName)
-
-		tlsConfig.NextProtos = []string{alpnProto}
-		dialer := &net.Dialer{
-			Timeout: timeout,
-		}
-		return tls.DialWithDialer(dialer, "tcp", addr, tlsConfig)
-	}
-}
@@ -9,6 +9,7 @@ import (
	"github.com/hashicorp/vault/helper/consts"
	"github.com/hashicorp/vault/helper/forwarding"
+	"github.com/hashicorp/vault/vault/replication"
	cache "github.com/patrickmn/go-cache"
)

@@ -16,7 +17,7 @@ type forwardedRequestRPCServer struct {
	core             *Core
	handler          http.Handler
	perfStandbySlots chan struct{}
-	perfStandbyRepCluster *ReplicatedCluster
+	perfStandbyRepCluster *replication.Cluster
	perfStandbyCache      *cache.Cache
}
@@ -1,18 +0,0 @@
-// +build !enterprise
-
-package vault
-
-import (
-	"context"
-	"crypto/tls"
-	"sync"
-
-	cache "github.com/patrickmn/go-cache"
-	"golang.org/x/net/http2"
-	grpc "google.golang.org/grpc"
-)
-
-func perfStandbyRPCServer(*Core, *cache.Cache) *grpc.Server { return nil }
-
-func handleReplicationConn(context.Context, *Core, *sync.WaitGroup, chan struct{}, *http2.Server, *grpc.Server, *cache.Cache, *tls.Conn) {
-}
@@ -14,6 +14,7 @@ import (
	"github.com/SermoDigital/jose/jws"
	"github.com/SermoDigital/jose/jwt"
	"github.com/hashicorp/errwrap"
+	"github.com/hashicorp/vault/helper/certutil"
	"github.com/hashicorp/vault/helper/consts"
	"github.com/hashicorp/vault/helper/jsonutil"
	"github.com/hashicorp/vault/helper/namespace"

@@ -31,7 +32,7 @@ func (c *Core) ensureWrappingKey(ctx context.Context) error {
		return err
	}

-	var keyParams clusterKeyParams
+	var keyParams certutil.ClusterKeyParams

	if entry == nil {
		key, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader)