Refactor the cluster listener (#6232)

* Port over OSS cluster port refactor components

* Start forwarding

* Cleanup a bit

* Fix copy error

* Return error from perf standby creation

* Add some more comments

* Fix copy/paste error
Brian Kassouf 2019-02-14 18:14:56 -08:00 committed by GitHub
parent feb235d5f8
commit f5b5fbb392
12 changed files with 633 additions and 546 deletions

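At a high level, the diff below swaps the hand-rolled accept loop in request_forwarding.go for a shared ClusterListener that routes connections by ALPN name; applications register a handler (server side) and a client (outbound certificate source) against it. The sketch below condenses that wiring using only identifiers introduced in this commit; it is illustrative, assumes the surrounding Core state, and is not a standalone program (startForwarding and refreshRequestForwardingConnection in the diff do the real registration).

// Illustrative sketch only: how a cluster application now attaches to the
// shared listener. Assumes a *Core whose clusterListener was started by
// startClusterListener; the helper name registerForwardingSketch is made up.
func registerForwardingSketch(c *Core) error {
	// Server side: build the request-forwarding handler against the shared
	// HTTP/2 server and register it under its ALPN name.
	perfStandbyRepCluster, perfStandbyCache, perfStandbySlots, err := c.perfStandbyClusterHandler()
	if err != nil {
		return err
	}
	handler, err := NewRequestForwardingHandler(c, c.clusterListener.Server(), perfStandbySlots, perfStandbyRepCluster, perfStandbyCache)
	if err != nil {
		return err
	}
	c.clusterListener.AddHandler(requestForwardingALPN, handler)

	// Client side: register a certificate source so outbound dials for the
	// same ALPN present the local cluster certificate.
	c.clusterListener.AddClient(requestForwardingALPN, &requestForwardingClusterClient{core: c})

	// Teardown mirrors registration (see stopForwarding and
	// clearForwardingClients below):
	//   c.clusterListener.StopHandler(requestForwardingALPN)
	//   c.clusterListener.RemoveClient(requestForwardingALPN)
	return nil
}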
View File

@ -17,12 +17,21 @@ import (
"crypto/x509"
"encoding/pem"
"fmt"
"math/big"
"strings"
"github.com/hashicorp/errwrap"
"github.com/hashicorp/vault/helper/errutil"
)
// This can be one of a few key types so the different params may or may not be filled
type ClusterKeyParams struct {
Type string `json:"type" structs:"type" mapstructure:"type"`
X *big.Int `json:"x" structs:"x" mapstructure:"x"`
Y *big.Int `json:"y" structs:"y" mapstructure:"y"`
D *big.Int `json:"d" structs:"d" mapstructure:"d"`
}
// Secret is used to attempt to unmarshal a Vault secret
// JSON response, as a convenience
type Secret struct {

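Since ClusterKeyParams is now exported from helper/certutil and carries json tags, its big.Int coordinates round-trip through JSON as plain numbers. A small standalone illustration of that behavior; the struct is mirrored locally so the snippet runs outside the Vault tree, and the "p521" type string is illustrative rather than Vault's constant.

package main

import (
	"crypto/ecdsa"
	"crypto/elliptic"
	"crypto/rand"
	"encoding/json"
	"fmt"
	"log"
	"math/big"
)

// Local mirror of certutil.ClusterKeyParams from the diff above.
type ClusterKeyParams struct {
	Type string   `json:"type"`
	X    *big.Int `json:"x"`
	Y    *big.Int `json:"y"`
	D    *big.Int `json:"d"`
}

func main() {
	// Generate a P-521 key, the same curve the core uses for its cluster key.
	key, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
	if err != nil {
		log.Fatal(err)
	}
	params := ClusterKeyParams{Type: "p521", X: key.X, Y: key.Y, D: key.D}

	// *big.Int marshals as a JSON number, so the coordinates survive the
	// round trip intact.
	buf, err := json.Marshal(params)
	if err != nil {
		log.Fatal(err)
	}
	var decoded ClusterKeyParams
	if err := json.Unmarshal(buf, &decoded); err != nil {
		log.Fatal(err)
	}
	fmt.Println("round trip ok:", decoded.X.Cmp(key.X) == 0 && decoded.D.Cmp(key.D) == 0)
}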
View File

@ -15,12 +15,16 @@ import (
mathrand "math/rand"
"net"
"net/http"
"sync"
"sync/atomic"
"time"
"github.com/hashicorp/errwrap"
log "github.com/hashicorp/go-hclog"
uuid "github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/helper/jsonutil"
"github.com/hashicorp/vault/logical"
"golang.org/x/net/http2"
)
const (
@ -44,19 +48,6 @@ type ClusterLeaderParams struct {
LeaderClusterAddr string
}
type ReplicatedClusters struct {
DR *ReplicatedCluster
Performance *ReplicatedCluster
}
// This can be one of a few key types so the different params may or may not be filled
type clusterKeyParams struct {
Type string `json:"type" structs:"type" mapstructure:"type"`
X *big.Int `json:"x" structs:"x" mapstructure:"x"`
Y *big.Int `json:"y" structs:"y" mapstructure:"y"`
D *big.Int `json:"d" structs:"d" mapstructure:"d"`
}
// Structure representing the storage entry that holds cluster information
type Cluster struct {
// Name of the cluster
@ -290,10 +281,297 @@ func (c *Core) setupCluster(ctx context.Context) error {
return nil
}
// startClusterListener starts cluster request listeners during postunseal. It
// ClusterClient is used to lookup a client certificate.
type ClusterClient interface {
ClientLookup(context.Context, *tls.CertificateRequestInfo) (*tls.Certificate, error)
}
// ClusterHandler exposes functions for looking up TLS configuration and handing
// off a connection for a cluster listener application.
type ClusterHandler interface {
ServerLookup(context.Context, *tls.ClientHelloInfo) (*tls.Certificate, error)
CALookup(context.Context) (*x509.Certificate, error)
// Handoff is used to pass the connection lifetime off to
// the storage backend
Handoff(context.Context, *sync.WaitGroup, chan struct{}, *tls.Conn) error
Stop() error
}
// ClusterListener is the source of truth for cluster handlers and connection
// clients. It dynamically builds the cluster TLS information. It's also
// responsible for starting tcp listeners and accepting new cluster connections.
type ClusterListener struct {
handlers map[string]ClusterHandler
clients map[string]ClusterClient
shutdown *uint32
shutdownWg *sync.WaitGroup
server *http2.Server
clusterListenerAddrs []*net.TCPAddr
clusterCipherSuites []uint16
logger log.Logger
l sync.RWMutex
}
// AddClient adds a new client for an ALPN name
func (cl *ClusterListener) AddClient(alpn string, client ClusterClient) {
cl.l.Lock()
cl.clients[alpn] = client
cl.l.Unlock()
}
// RemoveClient removes the client for the specified ALPN name
func (cl *ClusterListener) RemoveClient(alpn string) {
cl.l.Lock()
delete(cl.clients, alpn)
cl.l.Unlock()
}
// AddHandler registers a new cluster handler for the provided ALPN name.
func (cl *ClusterListener) AddHandler(alpn string, handler ClusterHandler) {
cl.l.Lock()
cl.handlers[alpn] = handler
cl.l.Unlock()
}
// StopHandler stops the cluster handler for the provided ALPN name; it also
// calls Stop on the handler.
func (cl *ClusterListener) StopHandler(alpn string) {
cl.l.Lock()
handler, ok := cl.handlers[alpn]
delete(cl.handlers, alpn)
cl.l.Unlock()
if ok {
handler.Stop()
}
}
// Server returns the http2 server that the cluster listener is using
func (cl *ClusterListener) Server() *http2.Server {
return cl.server
}
// TLSConfig returns a tls config object that uses dynamic lookups to correctly
// authenticate registered handlers/clients
func (cl *ClusterListener) TLSConfig(ctx context.Context) (*tls.Config, error) {
serverLookup := func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
cl.logger.Debug("performing server cert lookup")
cl.l.RLock()
defer cl.l.RUnlock()
for _, v := range clientHello.SupportedProtos {
if handler, ok := cl.handlers[v]; ok {
return handler.ServerLookup(ctx, clientHello)
}
}
return nil, errors.New("unsupported protocol")
}
clientLookup := func(requestInfo *tls.CertificateRequestInfo) (*tls.Certificate, error) {
cl.logger.Debug("performing client cert lookup")
cl.l.RLock()
defer cl.l.RUnlock()
for _, client := range cl.clients {
cert, err := client.ClientLookup(ctx, requestInfo)
if err == nil && cert != nil {
return cert, nil
}
}
return nil, errors.New("no client cert found")
}
serverConfigLookup := func(clientHello *tls.ClientHelloInfo) (*tls.Config, error) {
caPool := x509.NewCertPool()
ret := &tls.Config{
ClientAuth: tls.RequireAndVerifyClientCert,
GetCertificate: serverLookup,
GetClientCertificate: clientLookup,
MinVersion: tls.VersionTLS12,
RootCAs: caPool,
ClientCAs: caPool,
NextProtos: clientHello.SupportedProtos,
CipherSuites: cl.clusterCipherSuites,
}
cl.l.RLock()
defer cl.l.RUnlock()
for _, v := range clientHello.SupportedProtos {
if handler, ok := cl.handlers[v]; ok {
ca, err := handler.CALookup(ctx)
if err != nil {
return nil, err
}
caPool.AddCert(ca)
return ret, nil
}
}
return nil, errors.New("unsupported protocol")
}
return &tls.Config{
ClientAuth: tls.RequireAndVerifyClientCert,
GetCertificate: serverLookup,
GetClientCertificate: clientLookup,
GetConfigForClient: serverConfigLookup,
MinVersion: tls.VersionTLS12,
CipherSuites: cl.clusterCipherSuites,
}, nil
}
// Run starts the tcp listeners and will accept connections until Stop is
// called. This function blocks, so it should be called in a goroutine.
func (cl *ClusterListener) Run(ctx context.Context) error {
// Get our TLS config
tlsConfig, err := cl.TLSConfig(ctx)
if err != nil {
cl.logger.Error("failed to get tls configuration when starting cluster listener", "error", err)
return err
}
// The server supports all of the possible protos
tlsConfig.NextProtos = []string{"h2", requestForwardingALPN, perfStandbyALPN, PerformanceReplicationALPN, DRReplicationALPN}
for _, addr := range cl.clusterListenerAddrs {
cl.shutdownWg.Add(1)
// Force a local resolution to avoid data races
laddr := addr
// Start our listening loop
go func() {
defer cl.shutdownWg.Done()
// closeCh is used to shutdown the spawned goroutines once this
// function returns
closeCh := make(chan struct{})
defer func() {
close(closeCh)
}()
if cl.logger.IsInfo() {
cl.logger.Info("starting listener", "listener_address", laddr)
}
// Create a TCP listener. We do this separately and specifically
// with TCP so that we can set deadlines.
tcpLn, err := net.ListenTCP("tcp", laddr)
if err != nil {
cl.logger.Error("error starting listener", "error", err)
return
}
// Wrap the listener with TLS
tlsLn := tls.NewListener(tcpLn, tlsConfig)
defer tlsLn.Close()
if cl.logger.IsInfo() {
cl.logger.Info("serving cluster requests", "cluster_listen_address", tlsLn.Addr())
}
for {
if atomic.LoadUint32(cl.shutdown) > 0 {
return
}
// Set the deadline for the accept call. If it passes we'll get
// an error, causing us to check the condition at the top
// again.
tcpLn.SetDeadline(time.Now().Add(clusterListenerAcceptDeadline))
// Accept the connection
conn, err := tlsLn.Accept()
if err != nil {
if err, ok := err.(net.Error); ok && !err.Timeout() {
cl.logger.Debug("non-timeout error accepting on cluster port", "error", err)
}
if conn != nil {
conn.Close()
}
continue
}
if conn == nil {
continue
}
// Type assert to TLS connection and handshake to populate the
// connection state
tlsConn := conn.(*tls.Conn)
// Set a deadline for the handshake. This will cause clients
// that don't successfully auth to be kicked out quickly.
// Cluster connections should be reliable so being marginally
// aggressive here is fine.
err = tlsConn.SetDeadline(time.Now().Add(30 * time.Second))
if err != nil {
if cl.logger.IsDebug() {
cl.logger.Debug("error setting deadline for cluster connection", "error", err)
}
tlsConn.Close()
continue
}
err = tlsConn.Handshake()
if err != nil {
if cl.logger.IsDebug() {
cl.logger.Debug("error handshaking cluster connection", "error", err)
}
tlsConn.Close()
continue
}
// Now, set it back to unlimited
err = tlsConn.SetDeadline(time.Time{})
if err != nil {
if cl.logger.IsDebug() {
cl.logger.Debug("error setting deadline for cluster connection", "error", err)
}
tlsConn.Close()
continue
}
cl.l.RLock()
handler, ok := cl.handlers[tlsConn.ConnectionState().NegotiatedProtocol]
cl.l.RUnlock()
if !ok {
cl.logger.Debug("unknown negotiated protocol on cluster port")
tlsConn.Close()
continue
}
if err := handler.Handoff(ctx, cl.shutdownWg, closeCh, tlsConn); err != nil {
cl.logger.Error("error handling cluster connection", "error", err)
continue
}
}
}()
}
return nil
}
// Stop stops the cluster listener
func (cl *ClusterListener) Stop() {
// Set the shutdown flag. This will cause the listeners to shut down
// within the deadline in clusterListenerAcceptDeadline
atomic.StoreUint32(cl.shutdown, 1)
cl.logger.Info("forwarding rpc listeners stopped")
// Wait for them all to shut down
cl.shutdownWg.Wait()
cl.logger.Info("rpc listeners successfully shut down")
}
// startClusterListener starts cluster request listeners during unseal. It
// is assumed that the state lock is held while this is run. Right now this
// only starts forwarding listeners; it's TBD whether other request types will
// be built in the same mechanism or started independently.
// only starts cluster listeners. Once the listener is started handlers/clients
// can start being registered to it.
func (c *Core) startClusterListener(ctx context.Context) error {
if c.clusterAddr == "" {
c.logger.Info("clustering disabled, not starting listeners")
@ -307,76 +585,46 @@ func (c *Core) startClusterListener(ctx context.Context) error {
c.logger.Debug("starting cluster listeners")
err := c.startForwarding(ctx)
if err != nil {
return err
// Create the HTTP/2 server that will be shared by both RPC and regular
// duties. Doing it this way instead of listening via the server and gRPC
// allows us to re-use the same port via ALPN. We can just tell the server
// to serve a given conn and which handler to use.
h2Server := &http2.Server{
// Our forwarding connections heartbeat regularly so anything else we
// want to go away/get cleaned up pretty rapidly
IdleTimeout: 5 * HeartbeatInterval,
}
return nil
c.clusterListener = &ClusterListener{
handlers: make(map[string]ClusterHandler),
clients: make(map[string]ClusterClient),
shutdown: new(uint32),
shutdownWg: &sync.WaitGroup{},
server: h2Server,
clusterListenerAddrs: c.clusterListenerAddrs,
clusterCipherSuites: c.clusterCipherSuites,
logger: c.logger.Named("cluster-listener"),
}
return c.clusterListener.Run(ctx)
}
// stopClusterListener stops any existing listeners during preseal. It is
// stopClusterListener stops any existing listeners during seal. It is
// assumed that the state lock is held while this is run.
func (c *Core) stopClusterListener() {
if c.clusterAddr == "" {
if c.clusterListener == nil {
c.logger.Debug("clustering disabled, not stopping listeners")
return
}
if !c.clusterListenersRunning {
c.logger.Info("cluster listeners not running")
return
}
c.logger.Info("stopping cluster listeners")
// Tell the goroutine managing the listeners to perform the shutdown
// process
c.clusterListenerShutdownCh <- struct{}{}
// The reason for this loop-de-loop is that we may be unsealing again
// quickly, and if the listeners are not yet closed, we will get socket
// bind errors. This ensures proper ordering.
c.logger.Debug("waiting for success notification while stopping cluster listeners")
<-c.clusterListenerShutdownSuccessCh
c.clusterListenersRunning = false
c.clusterListener.Stop()
c.logger.Info("cluster listeners successfully shut down")
}
// ClusterTLSConfig generates a TLS configuration based on the local/replicated
// cluster key and cert.
func (c *Core) ClusterTLSConfig(ctx context.Context, repClusters *ReplicatedClusters, perfStandbyCluster *ReplicatedCluster) (*tls.Config, error) {
// Using lookup functions allows just-in-time lookup of the current state
// of clustering as connections come and go
tlsConfig := &tls.Config{
ClientAuth: tls.RequireAndVerifyClientCert,
GetCertificate: clusterTLSServerLookup(ctx, c, repClusters, perfStandbyCluster),
GetClientCertificate: clusterTLSClientLookup(ctx, c, repClusters, perfStandbyCluster),
GetConfigForClient: clusterTLSServerConfigLookup(ctx, c, repClusters, perfStandbyCluster),
MinVersion: tls.VersionTLS12,
CipherSuites: c.clusterCipherSuites,
}
parsedCert := c.localClusterParsedCert.Load().(*x509.Certificate)
currCert := c.localClusterCert.Load().([]byte)
localCert := make([]byte, len(currCert))
copy(localCert, currCert)
if parsedCert != nil {
tlsConfig.ServerName = parsedCert.Subject.CommonName
pool := x509.NewCertPool()
pool.AddCert(parsedCert)
tlsConfig.RootCAs = pool
tlsConfig.ClientCAs = pool
}
return tlsConfig, nil
}
func (c *Core) SetClusterListenerAddrs(addrs []*net.TCPAddr) {
c.clusterListenerAddrs = addrs
if c.clusterAddr == "" && len(addrs) == 1 {
@ -387,3 +635,36 @@ func (c *Core) SetClusterListenerAddrs(addrs []*net.TCPAddr) {
func (c *Core) SetClusterHandler(handler http.Handler) {
c.clusterHandler = handler
}
// getGRPCDialer is used to return a dialer that has the correct TLS
// configuration. Otherwise gRPC tries to be helpful and stomps all over our
// NextProtos.
func (c *Core) getGRPCDialer(ctx context.Context, alpnProto, serverName string, caCert *x509.Certificate) func(string, time.Duration) (net.Conn, error) {
return func(addr string, timeout time.Duration) (net.Conn, error) {
if c.clusterListener == nil {
return nil, errors.New("clustering disabled")
}
tlsConfig, err := c.clusterListener.TLSConfig(ctx)
if err != nil {
c.logger.Error("failed to get tls configuration", "error", err)
return nil, err
}
if serverName != "" {
tlsConfig.ServerName = serverName
}
if caCert != nil {
pool := x509.NewCertPool()
pool.AddCert(caCert)
tlsConfig.RootCAs = pool
tlsConfig.ClientCAs = pool
}
c.logger.Debug("creating rpc dialer", "host", tlsConfig.ServerName)
tlsConfig.NextProtos = []string{alpnProto}
dialer := &net.Dialer{
Timeout: timeout,
}
return tls.DialWithDialer(dialer, "tcp", addr, tlsConfig)
}
}

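The interesting technique in ClusterListener.TLSConfig above is deferring all certificate and CA decisions to callbacks keyed off the client's advertised ALPN protocols. Below is a standalone sketch of that routing pattern using only the standard library; the certificate paths and the "example-proto" protocol name are placeholders, not Vault's actual values.

package main

import (
	"crypto/tls"
	"errors"
	"log"
	"net"
)

// baseTLSConfig routes each incoming handshake to a per-protocol tls.Config
// by inspecting the ClientHello's SupportedProtos, mirroring the
// GetConfigForClient approach in ClusterListener.TLSConfig.
func baseTLSConfig(perProto map[string]*tls.Config) *tls.Config {
	return &tls.Config{
		MinVersion: tls.VersionTLS12,
		GetConfigForClient: func(hello *tls.ClientHelloInfo) (*tls.Config, error) {
			for _, proto := range hello.SupportedProtos {
				if cfg, ok := perProto[proto]; ok {
					return cfg, nil
				}
			}
			return nil, errors.New("unsupported protocol")
		},
	}
}

func main() {
	// Placeholder certificate; Vault builds these dynamically per handler.
	cert, err := tls.LoadX509KeyPair("server.crt", "server.key")
	if err != nil {
		log.Fatal(err)
	}
	perProto := map[string]*tls.Config{
		"example-proto": {
			Certificates: []tls.Certificate{cert},
			NextProtos:   []string{"example-proto"},
			MinVersion:   tls.VersionTLS12,
		},
	}
	tcpLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		log.Fatal(err)
	}
	tlsLn := tls.NewListener(tcpLn, baseTLSConfig(perProto))
	defer tlsLn.Close()
	log.Println("accepting ALPN-routed connections on", tlsLn.Addr())
}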
View File

@ -4,6 +4,7 @@ import (
"bytes"
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"net"
"net/http"
@ -102,32 +103,25 @@ func TestCluster_ListenForRequests(t *testing.T) {
TestWaitActive(t, cores[0].Core)
// Use this to have a valid config after sealing since ClusterTLSConfig returns nil
var lastTLSConfig *tls.Config
checkListenersFunc := func(expectFail bool) {
tlsConfig, err := cores[0].ClusterTLSConfig(context.Background(), nil, nil)
if err != nil {
if err.Error() != consts.ErrSealed.Error() {
t.Fatal(err)
}
tlsConfig = lastTLSConfig
} else {
tlsConfig.NextProtos = []string{"h2"}
lastTLSConfig = tlsConfig
}
cores[0].clusterListener.AddClient(requestForwardingALPN, &requestForwardingClusterClient{cores[0].Core})
parsedCert := cores[0].localClusterParsedCert.Load().(*x509.Certificate)
dialer := cores[0].getGRPCDialer(context.Background(), requestForwardingALPN, parsedCert.Subject.CommonName, parsedCert)
for _, ln := range cores[0].Listeners {
tcpAddr, ok := ln.Addr().(*net.TCPAddr)
if !ok {
t.Fatalf("%s not a TCP port", tcpAddr.String())
}
conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", tcpAddr.IP.String(), tcpAddr.Port+105), tlsConfig)
netConn, err := dialer(fmt.Sprintf("%s:%d", tcpAddr.IP.String(), tcpAddr.Port+105), 0)
conn := netConn.(*tls.Conn)
if err != nil {
if expectFail {
t.Logf("testing %s:%d unsuccessful as expected", tcpAddr.IP.String(), tcpAddr.Port+105)
continue
}
t.Fatalf("error: %v\nlisteners are\n%#v\n%#v\n", err, cores[0].Listeners[0], cores[0].Listeners[1])
t.Fatalf("error: %v\nlisteners are\n%#v\n%#v\n", err, cores[0].Listeners[0], cores[0].Listeners[0])
}
if expectFail {
t.Fatalf("testing %s:%d not unsuccessful as expected", tcpAddr.IP.String(), tcpAddr.Port+105)
@ -140,7 +134,7 @@ func TestCluster_ListenForRequests(t *testing.T) {
switch {
case connState.Version != tls.VersionTLS12:
t.Fatal("version mismatch")
case connState.NegotiatedProtocol != "h2" || !connState.NegotiatedProtocolIsMutual:
case connState.NegotiatedProtocol != requestForwardingALPN || !connState.NegotiatedProtocolIsMutual:
t.Fatal("bad protocol negotiation")
}
t.Logf("testing %s:%d successful", tcpAddr.IP.String(), tcpAddr.Port+105)
@ -392,12 +386,13 @@ func TestCluster_CustomCipherSuites(t *testing.T) {
// Wait for core to become active
TestWaitActive(t, core.Core)
tlsConf, err := core.Core.ClusterTLSConfig(context.Background(), nil, nil)
if err != nil {
t.Fatal(err)
}
core.clusterListener.AddClient(requestForwardingALPN, &requestForwardingClusterClient{core.Core})
conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", core.Listeners[0].Address.IP.String(), core.Listeners[0].Address.Port+105), tlsConf)
parsedCert := core.localClusterParsedCert.Load().(*x509.Certificate)
dialer := core.getGRPCDialer(context.Background(), requestForwardingALPN, parsedCert.Subject.CommonName, parsedCert)
netConn, err := dialer(fmt.Sprintf("%s:%d", core.Listeners[0].Address.IP.String(), core.Listeners[0].Address.Port+105), 0)
conn := netConn.(*tls.Conn)
if err != nil {
t.Fatal(err)
}

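The updated tests now dial through getGRPCDialer rather than building a ClusterTLSConfig by hand, and then assert the negotiated ALPN protocol. Here is a standalone sketch of that client-side pattern; the address, protocol name, and InsecureSkipVerify are placeholders for the demo, whereas the real dialer pins the cluster CA and client certificate instead.

package main

import (
	"crypto/tls"
	"log"
	"net"
	"time"
)

// dialWithALPN forces a single application protocol on the dial, the way
// getGRPCDialer does, and returns the handshaked TLS connection.
func dialWithALPN(addr, alpnProto string, timeout time.Duration, tlsConfig *tls.Config) (*tls.Conn, error) {
	cfg := tlsConfig.Clone()
	cfg.NextProtos = []string{alpnProto} // advertise only the desired protocol
	dialer := &net.Dialer{Timeout: timeout}
	return tls.DialWithDialer(dialer, "tcp", addr, cfg)
}

func main() {
	conn, err := dialWithALPN("127.0.0.1:8201", "example-proto", 5*time.Second, &tls.Config{
		InsecureSkipVerify: true, // demo only; the real dialer trusts the cluster CA
		MinVersion:         tls.VersionTLS12,
	})
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close()

	// Mirror the test's assertion on the negotiated protocol.
	state := conn.ConnectionState()
	if state.NegotiatedProtocol != "example-proto" {
		log.Fatal("bad protocol negotiation")
	}
	log.Println("negotiated", state.NegotiatedProtocol)
}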
View File

@ -1,85 +0,0 @@
package vault
import (
"context"
"crypto/ecdsa"
"crypto/tls"
"crypto/x509"
"fmt"
)
var (
clusterTLSServerLookup = func(ctx context.Context, c *Core, repClusters *ReplicatedClusters, _ *ReplicatedCluster) func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
return func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
c.logger.Debug("performing server cert lookup")
switch {
default:
currCert := c.localClusterCert.Load().([]byte)
if len(currCert) == 0 {
return nil, fmt.Errorf("got forwarding connection but no local cert")
}
localCert := make([]byte, len(currCert))
copy(localCert, currCert)
return &tls.Certificate{
Certificate: [][]byte{localCert},
PrivateKey: c.localClusterPrivateKey.Load().(*ecdsa.PrivateKey),
Leaf: c.localClusterParsedCert.Load().(*x509.Certificate),
}, nil
}
}
}
clusterTLSClientLookup = func(ctx context.Context, c *Core, repClusters *ReplicatedClusters, _ *ReplicatedCluster) func(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
return func(requestInfo *tls.CertificateRequestInfo) (*tls.Certificate, error) {
if len(requestInfo.AcceptableCAs) != 1 {
return nil, fmt.Errorf("expected only a single acceptable CA")
}
currCert := c.localClusterCert.Load().([]byte)
if len(currCert) == 0 {
return nil, fmt.Errorf("forwarding connection client but no local cert")
}
localCert := make([]byte, len(currCert))
copy(localCert, currCert)
return &tls.Certificate{
Certificate: [][]byte{localCert},
PrivateKey: c.localClusterPrivateKey.Load().(*ecdsa.PrivateKey),
Leaf: c.localClusterParsedCert.Load().(*x509.Certificate),
}, nil
}
}
clusterTLSServerConfigLookup = func(ctx context.Context, c *Core, repClusters *ReplicatedClusters, repCluster *ReplicatedCluster) func(clientHello *tls.ClientHelloInfo) (*tls.Config, error) {
return func(clientHello *tls.ClientHelloInfo) (*tls.Config, error) {
//c.logger.Trace("performing server config lookup")
caPool := x509.NewCertPool()
ret := &tls.Config{
ClientAuth: tls.RequireAndVerifyClientCert,
GetCertificate: clusterTLSServerLookup(ctx, c, repClusters, repCluster),
GetClientCertificate: clusterTLSClientLookup(ctx, c, repClusters, repCluster),
MinVersion: tls.VersionTLS12,
RootCAs: caPool,
ClientCAs: caPool,
NextProtos: clientHello.SupportedProtos,
CipherSuites: c.clusterCipherSuites,
}
parsedCert := c.localClusterParsedCert.Load().(*x509.Certificate)
if parsedCert == nil {
return nil, fmt.Errorf("forwarding connection client but no local cert")
}
caPool.AddCert(parsedCert)
return ret, nil
}
}
)

View File

@ -26,6 +26,7 @@ import (
"github.com/hashicorp/errwrap"
"github.com/hashicorp/vault/audit"
"github.com/hashicorp/vault/helper/certutil"
"github.com/hashicorp/vault/helper/consts"
"github.com/hashicorp/vault/helper/jsonutil"
"github.com/hashicorp/vault/helper/logging"
@ -135,7 +136,7 @@ type activeAdvertisement struct {
RedirectAddr string `json:"redirect_addr"`
ClusterAddr string `json:"cluster_addr,omitempty"`
ClusterCert []byte `json:"cluster_cert,omitempty"`
ClusterKeyParams *clusterKeyParams `json:"cluster_key_params,omitempty"`
ClusterKeyParams *certutil.ClusterKeyParams `json:"cluster_key_params,omitempty"`
}
type unlockInformation struct {
@ -328,14 +329,6 @@ type Core struct {
clusterListenerAddrs []*net.TCPAddr
// The handler to use for request forwarding
clusterHandler http.Handler
// Tracks whether cluster listeners are running, e.g. it's safe to send a
// shutdown down the channel
clusterListenersRunning bool
// Shutdown channel for the cluster listeners
clusterListenerShutdownCh chan struct{}
// Shutdown success channel. We need this to be done serially to ensure
// that binds are removed before they might be reinstated.
clusterListenerShutdownSuccessCh chan struct{}
// Write lock used to ensure that we don't have multiple connections adjust
// this value at the same time
requestForwardingConnectionLock sync.RWMutex
@ -346,8 +339,6 @@ type Core struct {
clusterLeaderParams *atomic.Value
// Info on cluster members
clusterPeerClusterAddrsCache *cache.Cache
// Stores whether we currently have a server running
rpcServerActive *uint32
// The context for the client
rpcClientConnContext context.Context
// The function for canceling the client connection
@ -422,6 +413,9 @@ type Core struct {
// artifacts in a case sensitive manner. To be used only in testing.
loadCaseSensitiveIdentityStore bool
// clusterListener starts up and manages connections on the cluster ports
clusterListener *ClusterListener
// Telemetry objects
metricsHelper *metricsutil.MetricsHelper
}
@ -582,13 +576,10 @@ func NewCore(conf *CoreConfig) (*Core, error) {
maxLeaseTTL: conf.MaxLeaseTTL,
cachingDisabled: conf.DisableCache,
clusterName: conf.ClusterName,
clusterListenerShutdownCh: make(chan struct{}),
clusterListenerShutdownSuccessCh: make(chan struct{}),
clusterPeerClusterAddrsCache: cache.New(3*HeartbeatInterval, time.Second),
enableMlock: !conf.DisableMlock,
rawEnabled: conf.EnableRaw,
replicationState: new(uint32),
rpcServerActive: new(uint32),
atomicPrimaryClusterAddrs: new(atomic.Value),
atomicPrimaryFailoverAddrs: new(atomic.Value),
localClusterPrivateKey: new(atomic.Value),
@ -1043,6 +1034,10 @@ func (c *Core) unsealInternal(ctx context.Context, masterKey []byte) (bool, erro
return false, err
}
if err := c.startClusterListener(ctx); err != nil {
return false, err
}
// Do post-unseal setup if HA is not enabled
if c.ha == nil {
// We still need to set up cluster info even if it's not part of a
@ -1365,6 +1360,9 @@ func (c *Core) sealInternalWithOptions(grabStateLock, keepHALock bool) error {
c.logger.Debug("runStandby done")
}
// Stop the cluster listener
c.stopClusterListener()
c.logger.Debug("sealing barrier")
if err := c.barrier.Seal(); err != nil {
c.logger.Error("error sealing barrier", "error", err)
@ -1461,8 +1459,8 @@ func (s standardUnsealStrategy) unseal(ctx context.Context, logger log.Logger, c
c.auditBroker = NewAuditBroker(c.logger)
}
if c.ha != nil || shouldStartClusterListener(c) {
if err := c.startClusterListener(ctx); err != nil {
if c.clusterListener != nil && (c.ha != nil || shouldStartClusterListener(c)) {
if err := c.startForwarding(ctx); err != nil {
return err
}
}
@ -1553,7 +1551,7 @@ func (c *Core) preSeal() error {
}
c.clusterParamsLock.Unlock()
c.stopClusterListener()
c.stopForwarding()
if err := c.teardownAudits(); err != nil {
result = multierror.Append(result, errwrap.Wrapf("error tearing down audits: {{err}}", err))

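Taken together, the core.go changes split the cluster networking lifecycle in two: the listener itself now starts during unseal and stops during seal, while the forwarding handler is registered around post-unseal and removed in pre-seal. The following is a condensed, illustrative method (not part of the commit) that strings those calls together in the order the diff enforces.

// clusterLifecycleSketch is a hypothetical helper condensing the ordering
// spread across unsealInternal, standardUnsealStrategy.unseal, preSeal and
// sealInternalWithOptions in the diff above.
func (c *Core) clusterLifecycleSketch(ctx context.Context) error {
	// Unseal path: the shared TLS accept loop comes up first...
	if err := c.startClusterListener(ctx); err != nil {
		return err
	}
	// ...then post-unseal setup registers the forwarding handler/clients on it.
	if c.clusterListener != nil && (c.ha != nil || shouldStartClusterListener(c)) {
		if err := c.startForwarding(ctx); err != nil {
			return err
		}
	}

	// Seal path: handlers come off before the listener itself is stopped, so
	// in-flight forwarding connections are closed before the ports are freed.
	c.stopForwarding()
	c.stopClusterListener()
	return nil
}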
View File

@ -4,11 +4,14 @@ package vault
import (
"context"
"time"
"github.com/hashicorp/vault/helper/license"
"github.com/hashicorp/vault/helper/namespace"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/physical"
"github.com/hashicorp/vault/vault/replication"
cache "github.com/patrickmn/go-cache"
)
type entCore struct{}
@ -89,7 +92,7 @@ func (c *Core) namepaceByPath(string) *namespace.Namespace {
return namespace.RootNamespace
}
func (c *Core) setupReplicatedClusterPrimary(*ReplicatedCluster) error { return nil }
func (c *Core) setupReplicatedClusterPrimary(*replication.Cluster) error { return nil }
func (c *Core) perfStandbyCount() int { return 0 }
@ -104,3 +107,7 @@ func (c *Core) checkReplicatedFiltering(context.Context, *MountEntry, string) (b
func (c *Core) invalidateSentinelPolicy(PolicyType, string) {}
func (c *Core) removePerfStandbySecondary(context.Context, string) {}
func (c *Core) perfStandbyClusterHandler() (*replication.Cluster, *cache.Cache, chan struct{}, error) {
return nil, cache.New(2*HeartbeatInterval, 1*time.Second), make(chan struct{}), nil
}

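core_util.go and the new vault/replication package keep the OSS tree compiling by providing no-op stand-ins behind a build constraint. A minimal, self-contained illustration of that stub pattern, under assumed package, file, and function names:

// feature_oss.go — built only when the "enterprise" tag is absent. A
// hypothetical feature_ent.go guarded by "+build enterprise" would hold the
// real implementation with the same signature.

// +build !enterprise

package feature

// PerfStandbySlots reports how many performance-standby slots are available;
// the open-source stub always answers zero, mirroring perfStandbyCount above.
func PerfStandbySlots() int { return 0 }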
View File

@ -14,6 +14,7 @@ import (
multierror "github.com/hashicorp/go-multierror"
uuid "github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/audit"
"github.com/hashicorp/vault/helper/certutil"
"github.com/hashicorp/vault/helper/consts"
"github.com/hashicorp/vault/helper/jsonutil"
"github.com/hashicorp/vault/helper/namespace"
@ -817,7 +818,7 @@ func (c *Core) advertiseLeader(ctx context.Context, uuid string, leaderLostCh <-
return fmt.Errorf("unknown cluster private key type %T", c.localClusterPrivateKey.Load())
}
keyParams := &clusterKeyParams{
keyParams := &certutil.ClusterKeyParams{
Type: corePrivateKeyTypeP521,
X: key.X,
Y: key.Y,

View File

@ -1,11 +1,16 @@
// +build !enterprise
package vault
package replication
import "github.com/hashicorp/vault/helper/consts"
type ReplicatedCluster struct {
type Cluster struct {
State consts.ReplicationState
ClusterID string
PrimaryClusterAddr string
}
type Clusters struct {
DR *Cluster
Performance *Cluster
}

View File

@ -1,23 +1,23 @@
package vault
import (
"bytes"
"context"
"crypto/ecdsa"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
math "math"
"net"
"net/http"
"net/url"
"sync"
"sync/atomic"
"time"
cache "github.com/patrickmn/go-cache"
uuid "github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/helper/consts"
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/helper/forwarding"
"github.com/hashicorp/vault/vault/replication"
cache "github.com/patrickmn/go-cache"
"golang.org/x/net/http2"
"google.golang.org/grpc"
"google.golang.org/grpc/keepalive"
@ -44,11 +44,154 @@ var (
HeartbeatInterval = 5 * time.Second
)
type SecondaryConnsCacheVals struct {
ID string
Token string
Connection net.Conn
Mode consts.ReplicationState
type requestForwardingHandler struct {
fws *http2.Server
fwRPCServer *grpc.Server
logger log.Logger
ha bool
core *Core
stopCh chan struct{}
}
type requestForwardingClusterClient struct {
core *Core
}
// NewRequestForwardingHandler creates a cluster handler for use with request
// forwarding.
func NewRequestForwardingHandler(c *Core, fws *http2.Server, perfStandbySlots chan struct{}, perfStandbyRepCluster *replication.Cluster, perfStandbyCache *cache.Cache) (*requestForwardingHandler, error) {
// Resolve locally to avoid races
ha := c.ha != nil
fwRPCServer := grpc.NewServer(
grpc.KeepaliveParams(keepalive.ServerParameters{
Time: 2 * HeartbeatInterval,
}),
grpc.MaxRecvMsgSize(math.MaxInt32),
grpc.MaxSendMsgSize(math.MaxInt32),
)
if ha && c.clusterHandler != nil {
RegisterRequestForwardingServer(fwRPCServer, &forwardedRequestRPCServer{
core: c,
handler: c.clusterHandler,
perfStandbySlots: perfStandbySlots,
perfStandbyRepCluster: perfStandbyRepCluster,
perfStandbyCache: perfStandbyCache,
})
}
return &requestForwardingHandler{
fws: fws,
fwRPCServer: fwRPCServer,
ha: ha,
logger: c.logger.Named("request-forward"),
core: c,
stopCh: make(chan struct{}),
}, nil
}
// ClientLookup satisfies the ClusterClient interface and returns the ha tls
// client certs.
func (c *requestForwardingClusterClient) ClientLookup(ctx context.Context, requestInfo *tls.CertificateRequestInfo) (*tls.Certificate, error) {
parsedCert := c.core.localClusterParsedCert.Load().(*x509.Certificate)
if parsedCert == nil {
return nil, nil
}
currCert := c.core.localClusterCert.Load().([]byte)
if len(currCert) == 0 {
return nil, nil
}
localCert := make([]byte, len(currCert))
copy(localCert, currCert)
for _, subj := range requestInfo.AcceptableCAs {
if bytes.Equal(subj, parsedCert.RawIssuer) {
return &tls.Certificate{
Certificate: [][]byte{localCert},
PrivateKey: c.core.localClusterPrivateKey.Load().(*ecdsa.PrivateKey),
Leaf: c.core.localClusterParsedCert.Load().(*x509.Certificate),
}, nil
}
}
return nil, nil
}
// ServerLookup satisfies the ClusterHandler interface and returns the server's
// tls certs.
func (rf *requestForwardingHandler) ServerLookup(ctx context.Context, clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
currCert := rf.core.localClusterCert.Load().([]byte)
if len(currCert) == 0 {
return nil, fmt.Errorf("got forwarding connection but no local cert")
}
localCert := make([]byte, len(currCert))
copy(localCert, currCert)
return &tls.Certificate{
Certificate: [][]byte{localCert},
PrivateKey: rf.core.localClusterPrivateKey.Load().(*ecdsa.PrivateKey),
Leaf: rf.core.localClusterParsedCert.Load().(*x509.Certificate),
}, nil
}
// CALookup satisfies the ClusterHandler interface and returns the ha ca cert.
func (rf *requestForwardingHandler) CALookup(ctx context.Context) (*x509.Certificate, error) {
parsedCert := rf.core.localClusterParsedCert.Load().(*x509.Certificate)
if parsedCert == nil {
return nil, fmt.Errorf("forwarding connection client but no local cert")
}
return parsedCert, nil
}
// Handoff serves a request forwarding connection.
func (rf *requestForwardingHandler) Handoff(ctx context.Context, shutdownWg *sync.WaitGroup, closeCh chan struct{}, tlsConn *tls.Conn) error {
if !rf.ha {
tlsConn.Close()
return nil
}
rf.logger.Debug("got request forwarding connection")
shutdownWg.Add(2)
// quitCh is used to close the connection and the second
// goroutine if the server closes before closeCh.
quitCh := make(chan struct{})
go func() {
select {
case <-quitCh:
case <-closeCh:
case <-rf.stopCh:
}
tlsConn.Close()
shutdownWg.Done()
}()
go func() {
rf.fws.ServeConn(tlsConn, &http2.ServeConnOpts{
Handler: rf.fwRPCServer,
BaseConfig: &http.Server{
ErrorLog: rf.logger.StandardLogger(nil),
},
})
// close the quitCh which will close the connection and
// the other goroutine.
close(quitCh)
shutdownWg.Done()
}()
return nil
}
// Stop stops the request forwarding server and closes connections.
func (rf *requestForwardingHandler) Stop() error {
close(rf.stopCh)
rf.fwRPCServer.Stop()
return nil
}
// Starts the listeners and servers necessary to handle forwarded requests
@ -62,269 +205,32 @@ func (c *Core) startForwarding(ctx context.Context) error {
c.requestForwardingConnectionLock.Unlock()
// Resolve locally to avoid races
ha := c.ha != nil
var perfStandbyRepCluster *ReplicatedCluster
if ha {
id, err := uuid.GenerateUUID()
if err != nil {
return err
}
perfStandbyRepCluster = &ReplicatedCluster{
State: consts.ReplicationPerformanceStandby,
ClusterID: id,
PrimaryClusterAddr: c.clusterAddr,
}
if err = c.setupReplicatedClusterPrimary(perfStandbyRepCluster); err != nil {
return err
}
}
// Get our TLS config
tlsConfig, err := c.ClusterTLSConfig(ctx, nil, perfStandbyRepCluster)
if err != nil {
c.logger.Error("failed to get tls configuration when starting forwarding", "error", err)
return err
}
// The server supports all of the possible protos
tlsConfig.NextProtos = []string{"h2", requestForwardingALPN, perfStandbyALPN, PerformanceReplicationALPN, DRReplicationALPN}
if !atomic.CompareAndSwapUint32(c.rpcServerActive, 0, 1) {
c.logger.Warn("forwarding rpc server already running")
if c.ha == nil || c.clusterListener == nil {
return nil
}
fwRPCServer := grpc.NewServer(
grpc.KeepaliveParams(keepalive.ServerParameters{
Time: 2 * HeartbeatInterval,
}),
grpc.MaxRecvMsgSize(math.MaxInt32),
grpc.MaxSendMsgSize(math.MaxInt32),
)
// Setup performance standby RPC servers
perfStandbyCount := 0
if !c.IsDRSecondary() && !c.disablePerfStandby {
perfStandbyCount = c.perfStandbyCount()
}
perfStandbySlots := make(chan struct{}, perfStandbyCount)
perfStandbyCache := cache.New(2*HeartbeatInterval, 1*time.Second)
perfStandbyCache.OnEvicted(func(secondaryID string, _ interface{}) {
c.logger.Debug("removing performance standby", "id", secondaryID)
c.removePerfStandbySecondary(context.Background(), secondaryID)
select {
case <-perfStandbySlots:
default:
c.logger.Warn("perf secondary timeout hit but no slot to free")
}
})
perfStandbyReplicationRPCServer := perfStandbyRPCServer(c, perfStandbyCache)
if ha && c.clusterHandler != nil {
RegisterRequestForwardingServer(fwRPCServer, &forwardedRequestRPCServer{
core: c,
handler: c.clusterHandler,
perfStandbySlots: perfStandbySlots,
perfStandbyRepCluster: perfStandbyRepCluster,
perfStandbyCache: perfStandbyCache,
})
}
// Create the HTTP/2 server that will be shared by both RPC and regular
// duties. Doing it this way instead of listening via the server and gRPC
// allows us to re-use the same port via ALPN. We can just tell the server
// to serve a given conn and which handler to use.
fws := &http2.Server{
// Our forwarding connections heartbeat regularly so anything else we
// want to go away/get cleaned up pretty rapidly
IdleTimeout: 5 * HeartbeatInterval,
}
// Shutdown coordination logic
shutdown := new(uint32)
shutdownWg := &sync.WaitGroup{}
for _, addr := range c.clusterListenerAddrs {
shutdownWg.Add(1)
// Force a local resolution to avoid data races
laddr := addr
// Start our listening loop
go func() {
defer shutdownWg.Done()
// closeCh is used to shutdown the spawned goroutines once this
// function returns
closeCh := make(chan struct{})
defer func() {
close(closeCh)
}()
if c.logger.IsInfo() {
c.logger.Info("starting listener", "listener_address", laddr)
}
// Create a TCP listener. We do this separately and specifically
// with TCP so that we can set deadlines.
tcpLn, err := net.ListenTCP("tcp", laddr)
perfStandbyRepCluster, perfStandbyCache, perfStandbySlots, err := c.perfStandbyClusterHandler()
if err != nil {
c.logger.Error("error starting listener", "error", err)
return
return err
}
// Wrap the listener with TLS
tlsLn := tls.NewListener(tcpLn, tlsConfig)
defer tlsLn.Close()
if c.logger.IsInfo() {
c.logger.Info("serving cluster requests", "cluster_listen_address", tlsLn.Addr())
}
for {
if atomic.LoadUint32(shutdown) > 0 {
return
}
// Set the deadline for the accept call. If it passes we'll get
// an error, causing us to check the condition at the top
// again.
tcpLn.SetDeadline(time.Now().Add(clusterListenerAcceptDeadline))
// Accept the connection
conn, err := tlsLn.Accept()
handler, err := NewRequestForwardingHandler(c, c.clusterListener.Server(), perfStandbySlots, perfStandbyRepCluster, perfStandbyCache)
if err != nil {
if err, ok := err.(net.Error); ok && !err.Timeout() {
c.logger.Debug("non-timeout error accepting on cluster port", "error", err)
}
if conn != nil {
conn.Close()
}
continue
}
if conn == nil {
continue
return err
}
// Type assert to TLS connection and handshake to populate the
// connection state
tlsConn := conn.(*tls.Conn)
// Set a deadline for the handshake. This will cause clients
// that don't successfully auth to be kicked out quickly.
// Cluster connections should be reliable so being marginally
// aggressive here is fine.
err = tlsConn.SetDeadline(time.Now().Add(30 * time.Second))
if err != nil {
if c.logger.IsDebug() {
c.logger.Debug("error setting deadline for cluster connection", "error", err)
}
tlsConn.Close()
continue
}
err = tlsConn.Handshake()
if err != nil {
if c.logger.IsDebug() {
c.logger.Debug("error handshaking cluster connection", "error", err)
}
tlsConn.Close()
continue
}
// Now, set it back to unlimited
err = tlsConn.SetDeadline(time.Time{})
if err != nil {
if c.logger.IsDebug() {
c.logger.Debug("error setting deadline for cluster connection", "error", err)
}
tlsConn.Close()
continue
}
switch tlsConn.ConnectionState().NegotiatedProtocol {
case requestForwardingALPN:
if !ha {
tlsConn.Close()
continue
}
c.logger.Debug("got request forwarding connection")
shutdownWg.Add(2)
// quitCh is used to close the connection and the second
// goroutine if the server closes before closeCh.
quitCh := make(chan struct{})
go func() {
select {
case <-quitCh:
case <-closeCh:
}
tlsConn.Close()
shutdownWg.Done()
}()
go func() {
fws.ServeConn(tlsConn, &http2.ServeConnOpts{
Handler: fwRPCServer,
BaseConfig: &http.Server{
ErrorLog: c.logger.StandardLogger(nil),
},
})
// close the quitCh which will close the connection and
// the other goroutine.
close(quitCh)
shutdownWg.Done()
}()
case PerformanceReplicationALPN, DRReplicationALPN, perfStandbyALPN:
handleReplicationConn(ctx, c, shutdownWg, closeCh, fws, perfStandbyReplicationRPCServer, perfStandbyCache, tlsConn)
default:
c.logger.Debug("unknown negotiated protocol on cluster port")
tlsConn.Close()
continue
}
}
}()
}
// This is in its own goroutine so that we don't block the main thread, and
// thus we use atomic and channels to coordinate
// However, because you can't query the status of a channel, we set a bool
// here while we have the state lock to know whether to actually send a
// shutdown (e.g. whether the channel will block). See issue #2083.
c.clusterListenersRunning = true
go func() {
// If we get told to shut down...
<-c.clusterListenerShutdownCh
// Stop the RPC server
c.logger.Info("shutting down forwarding rpc listeners")
fwRPCServer.Stop()
// Set the shutdown flag. This will cause the listeners to shut down
// within the deadline in clusterListenerAcceptDeadline
atomic.StoreUint32(shutdown, 1)
c.logger.Info("forwarding rpc listeners stopped")
// Wait for them all to shut down
shutdownWg.Wait()
c.logger.Info("rpc listeners successfully shut down")
// Clear us up to run this function again
atomic.StoreUint32(c.rpcServerActive, 0)
// Tell the main thread that shutdown is done.
c.clusterListenerShutdownSuccessCh <- struct{}{}
}()
c.clusterListener.AddHandler(requestForwardingALPN, handler)
return nil
}
func (c *Core) stopForwarding() {
if c.clusterListener != nil {
c.clusterListener.StopHandler(requestForwardingALPN)
c.clusterListener.StopHandler(perfStandbyALPN)
}
}
// refreshRequestForwardingConnection ensures that the client/transport are
// alive and that the current active address value matches the most
// recently-known address.
@ -349,13 +255,25 @@ func (c *Core) refreshRequestForwardingConnection(ctx context.Context, clusterAd
return err
}
parsedCert := c.localClusterParsedCert.Load().(*x509.Certificate)
if parsedCert == nil {
c.logger.Error("no request forwarding cluster certificate found")
return errors.New("no request forwarding cluster certificate found")
}
if c.clusterListener != nil {
c.clusterListener.AddClient(requestForwardingALPN, &requestForwardingClusterClient{
core: c,
})
}
// Set up grpc forwarding handling
// It's not really insecure, but we have to dial manually to get the
// ALPN header right. It's just "insecure" because GRPC isn't managing
// the TLS state.
dctx, cancelFunc := context.WithCancel(ctx)
c.rpcClientConn, err = grpc.DialContext(dctx, clusterURL.Host,
grpc.WithDialer(c.getGRPCDialer(ctx, requestForwardingALPN, "", nil, nil, nil)),
grpc.WithDialer(c.getGRPCDialer(ctx, requestForwardingALPN, parsedCert.Subject.CommonName, parsedCert)),
grpc.WithInsecure(), // it's not, we handle it in the dialer
grpc.WithKeepaliveParams(keepalive.ClientParameters{
Time: 2 * HeartbeatInterval,
@ -398,6 +316,9 @@ func (c *Core) clearForwardingClients() {
c.rpcClientConnContext = nil
c.rpcForwardingClient = nil
if c.clusterListener != nil {
c.clusterListener.RemoveClient(requestForwardingALPN)
}
c.clusterLeaderParams.Store((*ClusterLeaderParams)(nil))
}
@ -450,32 +371,3 @@ func (c *Core) ForwardRequest(req *http.Request) (int, http.Header, []byte, erro
return int(resp.StatusCode), header, resp.Body, nil
}
// getGRPCDialer is used to return a dialer that has the correct TLS
// configuration. Otherwise gRPC tries to be helpful and stomps all over our
// NextProtos.
func (c *Core) getGRPCDialer(ctx context.Context, alpnProto, serverName string, caCert *x509.Certificate, repClusters *ReplicatedClusters, perfStandbyCluster *ReplicatedCluster) func(string, time.Duration) (net.Conn, error) {
return func(addr string, timeout time.Duration) (net.Conn, error) {
tlsConfig, err := c.ClusterTLSConfig(ctx, repClusters, perfStandbyCluster)
if err != nil {
c.logger.Error("failed to get tls configuration", "error", err)
return nil, err
}
if serverName != "" {
tlsConfig.ServerName = serverName
}
if caCert != nil {
pool := x509.NewCertPool()
pool.AddCert(caCert)
tlsConfig.RootCAs = pool
tlsConfig.ClientCAs = pool
}
c.logger.Debug("creating rpc dialer", "host", tlsConfig.ServerName)
tlsConfig.NextProtos = []string{alpnProto}
dialer := &net.Dialer{
Timeout: timeout,
}
return tls.DialWithDialer(dialer, "tcp", addr, tlsConfig)
}
}

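The extracted Handoff method above is the template every cluster application now follows: take ownership of an already-handshaked *tls.Conn, serve it on the shared http2.Server, and tie its lifetime to the listener's close channel plus the handler's own stop channel. A trimmed standalone sketch of that technique follows; the package, function, and channel names are placeholders, and the real handler serves the gRPC forwarding server rather than a plain http.Handler.

package forwardsketch

import (
	"crypto/tls"
	"net/http"
	"sync"

	"golang.org/x/net/http2"
)

// handoff serves one accepted TLS connection on a shared HTTP/2 server and
// closes it when either shutdown channel fires, mirroring
// requestForwardingHandler.Handoff above.
func handoff(fws *http2.Server, handler http.Handler, wg *sync.WaitGroup, closeCh, stopCh chan struct{}, tlsConn *tls.Conn) {
	wg.Add(2)

	// quitCh lets the serving goroutine tear the connection down if it
	// finishes before either shutdown channel is closed.
	quitCh := make(chan struct{})
	go func() {
		defer wg.Done()
		select {
		case <-quitCh:
		case <-closeCh:
		case <-stopCh:
		}
		tlsConn.Close()
	}()

	go func() {
		defer wg.Done()
		fws.ServeConn(tlsConn, &http2.ServeConnOpts{Handler: handler})
		close(quitCh)
	}()
}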
View File

@ -9,6 +9,7 @@ import (
"github.com/hashicorp/vault/helper/consts"
"github.com/hashicorp/vault/helper/forwarding"
"github.com/hashicorp/vault/vault/replication"
cache "github.com/patrickmn/go-cache"
)
@ -16,7 +17,7 @@ type forwardedRequestRPCServer struct {
core *Core
handler http.Handler
perfStandbySlots chan struct{}
perfStandbyRepCluster *ReplicatedCluster
perfStandbyRepCluster *replication.Cluster
perfStandbyCache *cache.Cache
}

View File

@ -1,18 +0,0 @@
// +build !enterprise
package vault
import (
"context"
"crypto/tls"
"sync"
cache "github.com/patrickmn/go-cache"
"golang.org/x/net/http2"
grpc "google.golang.org/grpc"
)
func perfStandbyRPCServer(*Core, *cache.Cache) *grpc.Server { return nil }
func handleReplicationConn(context.Context, *Core, *sync.WaitGroup, chan struct{}, *http2.Server, *grpc.Server, *cache.Cache, *tls.Conn) {
}

View File

@ -14,6 +14,7 @@ import (
"github.com/SermoDigital/jose/jws"
"github.com/SermoDigital/jose/jwt"
"github.com/hashicorp/errwrap"
"github.com/hashicorp/vault/helper/certutil"
"github.com/hashicorp/vault/helper/consts"
"github.com/hashicorp/vault/helper/jsonutil"
"github.com/hashicorp/vault/helper/namespace"
@ -31,7 +32,7 @@ func (c *Core) ensureWrappingKey(ctx context.Context) error {
return err
}
var keyParams clusterKeyParams
var keyParams certutil.ClusterKeyParams
if entry == nil {
key, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader)