open-vault/vault/cluster.go
Brian Kassouf f5b5fbb392
Refactor the cluster listener (#6232)
* Port over OSS cluster port refactor components

* Start forwarding

* Cleanup a bit

* Fix copy error

* Return error from perf standby creation

* Add some more comments

* Fix copy/paste error
2019-02-14 18:14:56 -08:00

671 lines
19 KiB
Go

package vault
import (
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/json"
"errors"
"fmt"
"math/big"
mathrand "math/rand"
"net"
"net/http"
"sync"
"sync/atomic"
"time"
"github.com/hashicorp/errwrap"
log "github.com/hashicorp/go-hclog"
uuid "github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/helper/jsonutil"
"github.com/hashicorp/vault/logical"
"golang.org/x/net/http2"
)
const (
	// Storage path where the local cluster name and identifier are stored
	coreLocalClusterInfoPath = "core/cluster/local/info"

	// Identifiers for the type of the local cluster private key carried in
	// an active-node advertisement. loadLocalClusterTLS only accepts
	// corePrivateKeyTypeP521.
	corePrivateKeyTypeP521    = "p521"
	corePrivateKeyTypeED25519 = "ed25519"

	// IntNoForwardingHeaderName is an HTTP header name; judging by its value
	// it asks that a request not be forwarded — NOTE(review): confirm
	// against its users.
	// Internal so as not to log a trace message
	IntNoForwardingHeaderName = "X-Vault-Internal-No-Request-Forwarding"
)

var (
	// ErrCannotForward is returned when a request cannot be forwarded because
	// there is no connection or the target address is not known.
	ErrCannotForward = errors.New("cannot forward request; no connection or address not known")
)
// ClusterLeaderParams groups identifying and addressing information about
// the cluster leader.
// NOTE(review): field meanings below are inferred from their names only —
// confirm against the code that populates this struct.
type ClusterLeaderParams struct {
	// Presumably the leader's unique identifier.
	LeaderUUID string
	// Presumably the address clients are redirected to.
	LeaderRedirectAddr string
	// Presumably the leader's cluster (server-to-server) address.
	LeaderClusterAddr string
}
// Cluster is the storage entry that holds cluster information. setupCluster
// persists it as JSON at coreLocalClusterInfoPath.
type Cluster struct {
	// Name of the cluster
	Name string `json:"name" structs:"name" mapstructure:"name"`

	// Identifier of the cluster
	ID string `json:"id" structs:"id" mapstructure:"id"`
}
// Cluster returns the name and identifier of the local cluster as recorded in
// storage. The read goes through the barrier, so it fails while Vault is
// sealed. If nothing has been stored yet, an empty (non-nil) Cluster is
// returned. Otherwise the stored record is decoded, and a cluster name set in
// the config file (c.clusterName) overrides the stored name.
func (c *Core) Cluster(ctx context.Context) (*Cluster, error) {
	entry, err := c.barrier.Get(ctx, coreLocalClusterInfoPath)
	switch {
	case err != nil:
		return nil, err
	case entry == nil:
		// Nothing persisted yet; hand back a zero-value record.
		return &Cluster{}, nil
	}

	info := new(Cluster)
	if decodeErr := jsonutil.DecodeJSON(entry.Value, info); decodeErr != nil {
		return nil, errwrap.Wrapf("failed to decode cluster details: {{err}}", decodeErr)
	}

	// A name supplied via the config file wins over the persisted one.
	if name := c.clusterName; name != "" {
		info.Name = name
	}
	return info, nil
}
// loadLocalClusterTLS installs the local cluster certificate and private key
// carried in an active-node advertisement.
//
// If the advertising server has clustering disabled (empty ClusterAddr),
// nothing is loaded and nil is returned. Otherwise the advertisement must
// contain:
//   - P-521 key parameters (X, Y, D), from which an ECDSA private key is
//     rebuilt;
//   - a non-empty DER certificate. A private copy is stored, and the parsed
//     form of that copy is stored alongside it.
//
// On any error, the deferred cleanup clears the stored certificate, parsed
// certificate and private key, and drops all request forwarding clients. This
// ensures no partially loaded TLS state is left behind.
func (c *Core) loadLocalClusterTLS(adv activeAdvertisement) (retErr error) {
	defer func() {
		if retErr != nil {
			c.localClusterCert.Store(([]byte)(nil))
			c.localClusterParsedCert.Store((*x509.Certificate)(nil))
			c.localClusterPrivateKey.Store((*ecdsa.PrivateKey)(nil))

			c.requestForwardingConnectionLock.Lock()
			c.clearForwardingClients()
			c.requestForwardingConnectionLock.Unlock()
		}
	}()

	switch {
	case adv.ClusterAddr == "":
		// Clustering disabled on the server, don't try to look for params
		return nil

	case adv.ClusterKeyParams == nil:
		c.logger.Error("no key params found loading local cluster TLS information")
		return fmt.Errorf("no local cluster key params found")

	case adv.ClusterKeyParams.X == nil, adv.ClusterKeyParams.Y == nil, adv.ClusterKeyParams.D == nil:
		c.logger.Error("failed to parse local cluster key due to missing params")
		return fmt.Errorf("failed to parse local cluster key")

	case adv.ClusterKeyParams.Type != corePrivateKeyTypeP521:
		c.logger.Error("unknown local cluster key type", "key_type", adv.ClusterKeyParams.Type)
		return fmt.Errorf("failed to find valid local cluster key type")

	case len(adv.ClusterCert) == 0:
		// len of a nil slice is 0, so this also covers a missing certificate.
		c.logger.Error("no local cluster cert found")
		return fmt.Errorf("no local cluster cert found")
	}

	c.localClusterPrivateKey.Store(&ecdsa.PrivateKey{
		PublicKey: ecdsa.PublicKey{
			Curve: elliptic.P521(),
			X:     adv.ClusterKeyParams.X,
			Y:     adv.ClusterKeyParams.Y,
		},
		D: adv.ClusterKeyParams.D,
	})

	// Keep our own copy of the DER bytes so the stored certificate does not
	// share memory with the advertisement.
	locCert := make([]byte, len(adv.ClusterCert))
	copy(locCert, adv.ClusterCert)
	c.localClusterCert.Store(locCert)

	// Parse the private copy rather than the advertisement's slice. Parsed
	// fields such as Raw may reference the input bytes, and this keeps them
	// pointing at locCert.
	cert, err := x509.ParseCertificate(locCert)
	if err != nil {
		c.logger.Error("failed parsing local cluster certificate", "error", err)
		return errwrap.Wrapf("error parsing local cluster certificate: {{err}}", err)
	}

	c.localClusterParsedCert.Store(cert)

	return nil
}
// setupCluster makes sure the local cluster's name and ID exist in storage.
// It generates whichever is missing and persists the record only when
// something changed; existing stored values are kept.
//
// If no name is stored and none was configured (c.clusterName), a random
// "vault-cluster-" name with an 8-hex-digit suffix is generated. That name is
// also remembered in c.clusterName.
//
// When HA is enabled (c.ha != nil), it also creates, if absent:
//   - the P-521 ECDSA private key;
//   - the self-signed CA certificate used for server-to-server connections.
//
// Everything runs under clusterParamsLock to prevent data races on the TLS
// parameters.
func (c *Core) setupCluster(ctx context.Context) error {
	c.clusterParamsLock.Lock()
	defer c.clusterParamsLock.Unlock()

	// Start from whatever is already persisted (an empty record if nothing is).
	cluster, err := c.Cluster(ctx)
	if err != nil {
		c.logger.Error("failed to get cluster details", "error", err)
		return err
	}
	if cluster == nil {
		cluster = &Cluster{}
	}

	dirty := false

	if cluster.Name == "" {
		if c.clusterName == "" {
			c.logger.Debug("cluster name not found/set, generating new")
			suffix, genErr := uuid.GenerateRandomBytes(4)
			if genErr != nil {
				c.logger.Error("failed to generate cluster name", "error", genErr)
				return genErr
			}
			c.clusterName = fmt.Sprintf("vault-cluster-%08x", suffix)
		}
		cluster.Name = c.clusterName
		if c.logger.IsDebug() {
			c.logger.Debug("cluster name set", "name", cluster.Name)
		}
		dirty = true
	}

	if cluster.ID == "" {
		c.logger.Debug("cluster ID not found, generating new")
		id, idErr := uuid.GenerateUUID()
		if idErr != nil {
			c.logger.Error("failed to generate cluster identifier", "error", idErr)
			return idErr
		}
		cluster.ID = id
		if c.logger.IsDebug() {
			c.logger.Debug("cluster ID set", "id", cluster.ID)
		}
		dirty = true
	}

	// Server-to-server TLS material is only needed when running in HA mode.
	if c.ha != nil {
		if c.localClusterPrivateKey.Load().(*ecdsa.PrivateKey) == nil {
			c.logger.Debug("generating cluster private key")
			key, keyErr := ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
			if keyErr != nil {
				c.logger.Error("failed to generate local cluster key", "error", keyErr)
				return keyErr
			}
			c.localClusterPrivateKey.Store(key)
		}

		if c.localClusterCert.Load().([]byte) == nil {
			c.logger.Debug("generating local cluster certificate")

			hostID, hostErr := uuid.GenerateUUID()
			if hostErr != nil {
				return hostErr
			}
			host := "fw-" + hostID

			template := &x509.Certificate{
				Subject:  pkix.Name{CommonName: host},
				DNSNames: []string{host},
				ExtKeyUsage: []x509.ExtKeyUsage{
					x509.ExtKeyUsageServerAuth,
					x509.ExtKeyUsageClientAuth,
				},
				KeyUsage:     x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement | x509.KeyUsageCertSign,
				SerialNumber: big.NewInt(mathrand.Int63()),
				NotBefore:    time.Now().Add(-30 * time.Second),
				// 30 years of single-active uptime ought to be enough for anybody
				NotAfter:              time.Now().Add(262980 * time.Hour),
				BasicConstraintsValid: true,
				IsCA:                  true,
			}

			// Self-signed: the template acts as both subject and issuer, and
			// the cluster key both certifies and signs.
			signer := c.localClusterPrivateKey.Load().(*ecdsa.PrivateKey)
			certBytes, certErr := x509.CreateCertificate(rand.Reader, template, template, signer.Public(), signer)
			if certErr != nil {
				c.logger.Error("error generating self-signed cert", "error", certErr)
				return errwrap.Wrapf("unable to generate local cluster certificate: {{err}}", certErr)
			}

			parsed, parseErr := x509.ParseCertificate(certBytes)
			if parseErr != nil {
				c.logger.Error("error parsing self-signed cert", "error", parseErr)
				return errwrap.Wrapf("error parsing generated certificate: {{err}}", parseErr)
			}

			c.localClusterCert.Store(certBytes)
			c.localClusterParsedCert.Store(parsed)
		}
	}

	if !dirty {
		return nil
	}

	// Persist the updated cluster record as JSON.
	rawCluster, err := json.Marshal(cluster)
	if err != nil {
		c.logger.Error("failed to encode cluster details", "error", err)
		return err
	}
	if err := c.barrier.Put(ctx, &logical.StorageEntry{
		Key:   coreLocalClusterInfoPath,
		Value: rawCluster,
	}); err != nil {
		c.logger.Error("failed to store cluster details", "error", err)
		return err
	}
	return nil
}
// ClusterClient supplies client certificates for outgoing cluster
// connections.
type ClusterClient interface {
	// ClientLookup returns the certificate to present when a cluster peer
	// requests one. ClusterListener.TLSConfig treats an error or a nil
	// certificate as "no certificate from this client" and tries the next.
	ClientLookup(ctx context.Context, requestInfo *tls.CertificateRequestInfo) (*tls.Certificate, error)
}

// ClusterHandler exposes functions for looking up TLS configuration and
// handing off a connection for a cluster listener application. Handlers are
// registered on a ClusterListener under an ALPN protocol name.
type ClusterHandler interface {
	// ServerLookup returns the certificate to present to a connecting client.
	ServerLookup(ctx context.Context, clientHello *tls.ClientHelloInfo) (*tls.Certificate, error)

	// CALookup returns the CA certificate that ClusterListener trusts, as
	// both RootCAs and ClientCAs, for connections routed to this handler.
	CALookup(ctx context.Context) (*x509.Certificate, error)

	// Handoff is used to pass the connection lifetime off to the storage
	// backend. ClusterListener.Run passes its shutdown wait group and a
	// channel that is closed when the accepting listener goroutine exits.
	Handoff(ctx context.Context, shutdownWg *sync.WaitGroup, closeCh chan struct{}, conn *tls.Conn) error

	// Stop shuts the handler down; ClusterListener.StopHandler calls it after
	// unregistering the handler.
	Stop() error
}
// ClusterListener is the source of truth for cluster handlers and connection
// clients. It dynamically builds the cluster TLS information. It's also
// responsible for starting tcp listeners and accepting new cluster connections.
type ClusterListener struct {
	// handlers maps an ALPN protocol name to the handler serving it.
	handlers map[string]ClusterHandler
	// clients maps an ALPN name to a client-certificate source. TLSConfig's
	// client lookup consults every entry regardless of key.
	clients map[string]ClusterClient

	// shutdown is set to 1 by Stop; Run's accept loops return once they see it.
	shutdown *uint32
	// shutdownWg tracks Run's listener goroutines and is handed to
	// ClusterHandler.Handoff; Stop waits on it.
	shutdownWg *sync.WaitGroup

	// server is the HTTP/2 server returned by Server.
	server *http2.Server

	// clusterListenerAddrs are the TCP addresses Run listens on.
	clusterListenerAddrs []*net.TCPAddr
	// clusterCipherSuites is used as CipherSuites in the configs built by
	// TLSConfig.
	clusterCipherSuites []uint16

	logger log.Logger

	// l guards handlers and clients.
	l sync.RWMutex
}
// AddClient registers client under the given ALPN name, replacing any client
// previously registered under that name.
func (cl *ClusterListener) AddClient(alpn string, client ClusterClient) {
	cl.l.Lock()
	defer cl.l.Unlock()

	cl.clients[alpn] = client
}
// RemoveClient unregisters the client stored under the given ALPN name. It is
// a no-op when no such client exists.
func (cl *ClusterListener) RemoveClient(alpn string) {
	cl.l.Lock()
	defer cl.l.Unlock()

	delete(cl.clients, alpn)
}
// AddHandler registers handler for the given ALPN name, replacing any handler
// previously registered under that name. The old handler is not stopped.
func (cl *ClusterListener) AddHandler(alpn string, handler ClusterHandler) {
	cl.l.Lock()
	defer cl.l.Unlock()

	cl.handlers[alpn] = handler
}
// StopHandler unregisters the cluster handler for the provided ALPN name and,
// if one was registered, calls its Stop method.
//
// Stop is invoked after the lock is released, so lookups are not blocked
// while the handler shuts down. A Stop failure is logged rather than silently
// discarded.
func (cl *ClusterListener) StopHandler(alpn string) {
	cl.l.Lock()
	handler, ok := cl.handlers[alpn]
	delete(cl.handlers, alpn)
	cl.l.Unlock()

	if !ok {
		return
	}
	if err := handler.Stop(); err != nil {
		cl.logger.Error("error stopping cluster handler", "alpn", alpn, "error", err)
	}
}
// Server returns the http2 server that the cluster listener is using. The
// field is set when the listener is constructed, and this accessor does not
// take cl.l.
func (cl *ClusterListener) Server() *http2.Server {
	return cl.server
}
// TLSConfig builds a *tls.Config whose certificates and trusted CAs are
// resolved at handshake time from the currently registered handlers and
// clients:
//
//   - As a server, the certificate comes from the handler registered for the
//     first of the client's offered ALPN protocols that has one. A
//     per-connection config then trusts only that handler's CA and requires
//     a verified client certificate.
//   - As a client, the certificate is the first one successfully returned
//     while iterating the registered clients. The iteration is over a map,
//     so the order is unspecified.
//
// All lookups use ctx. The returned error is currently always nil.
func (cl *ClusterListener) TLSConfig(ctx context.Context) (*tls.Config, error) {
	// handlerFor returns the handler registered for the earliest entry of
	// protos that has one. The caller must hold cl.l.
	handlerFor := func(protos []string) (ClusterHandler, bool) {
		for _, proto := range protos {
			if h, found := cl.handlers[proto]; found {
				return h, true
			}
		}
		return nil, false
	}

	// serverLookup supplies our certificate when acting as the TLS server.
	serverLookup := func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
		cl.logger.Debug("performing server cert lookup")
		cl.l.RLock()
		defer cl.l.RUnlock()

		h, found := handlerFor(hello.SupportedProtos)
		if !found {
			return nil, errors.New("unsupported protocol")
		}
		return h.ServerLookup(ctx, hello)
	}

	// clientLookup supplies our certificate when acting as the TLS client.
	clientLookup := func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) {
		cl.logger.Debug("performing client cert lookup")
		cl.l.RLock()
		defer cl.l.RUnlock()

		for _, client := range cl.clients {
			if cert, err := client.ClientLookup(ctx, info); err == nil && cert != nil {
				return cert, nil
			}
		}
		return nil, errors.New("no client cert found")
	}

	// serverConfigLookup builds the per-connection server config, trusting
	// only the CA of the handler selected by the client's ALPN offer.
	serverConfigLookup := func(hello *tls.ClientHelloInfo) (*tls.Config, error) {
		cl.l.RLock()
		defer cl.l.RUnlock()

		h, found := handlerFor(hello.SupportedProtos)
		if !found {
			return nil, errors.New("unsupported protocol")
		}
		ca, err := h.CALookup(ctx)
		if err != nil {
			return nil, err
		}

		trusted := x509.NewCertPool()
		trusted.AddCert(ca)
		return &tls.Config{
			ClientAuth:           tls.RequireAndVerifyClientCert,
			GetCertificate:       serverLookup,
			GetClientCertificate: clientLookup,
			MinVersion:           tls.VersionTLS12,
			RootCAs:              trusted,
			ClientCAs:            trusted,
			NextProtos:           hello.SupportedProtos,
			CipherSuites:         cl.clusterCipherSuites,
		}, nil
	}

	return &tls.Config{
		ClientAuth:           tls.RequireAndVerifyClientCert,
		GetCertificate:       serverLookup,
		GetClientCertificate: clientLookup,
		GetConfigForClient:   serverConfigLookup,
		MinVersion:           tls.VersionTLS12,
		CipherSuites:         cl.clusterCipherSuites,
	}, nil
}
// Run starts one TCP listener per address in cl.clusterListenerAddrs. Each
// listener serves cluster connections in its own goroutine, tracked by
// cl.shutdownWg, until Stop sets the shutdown flag.
//
// Run does not block: it returns nil as soon as the goroutines have been
// launched. It returns an error only if the TLS configuration cannot be
// built. A failure to bind an address is logged and ends only that
// address's goroutine.
//
// For each accepted connection, the goroutine:
//   - performs the TLS handshake under a 30 second deadline;
//   - looks up the handler registered for the negotiated ALPN protocol and
//     passes the connection to its Handoff;
//   - closes the connection if no handler matches.
func (cl *ClusterListener) Run(ctx context.Context) error {
	// Get our TLS config
	tlsConfig, err := cl.TLSConfig(ctx)
	if err != nil {
		cl.logger.Error("failed to get tls configuration when starting cluster listener", "error", err)
		return err
	}

	// The server supports all of the possible protos
	tlsConfig.NextProtos = []string{"h2", requestForwardingALPN, perfStandbyALPN, PerformanceReplicationALPN, DRReplicationALPN}

	for _, addr := range cl.clusterListenerAddrs {
		// Registered before the goroutine starts so Stop's Wait cannot miss it.
		cl.shutdownWg.Add(1)

		// Force a local resolution to avoid data races
		laddr := addr

		// Start our listening loop
		go func() {
			defer cl.shutdownWg.Done()

			// closeCh is used to shutdown the spawned goroutines once this
			// function returns
			closeCh := make(chan struct{})
			defer func() {
				close(closeCh)
			}()

			if cl.logger.IsInfo() {
				cl.logger.Info("starting listener", "listener_address", laddr)
			}

			// Create a TCP listener. We do this separately and specifically
			// with TCP so that we can set deadlines.
			tcpLn, err := net.ListenTCP("tcp", laddr)
			if err != nil {
				cl.logger.Error("error starting listener", "error", err)
				return
			}

			// Wrap the listener with TLS
			tlsLn := tls.NewListener(tcpLn, tlsConfig)
			defer tlsLn.Close()

			if cl.logger.IsInfo() {
				cl.logger.Info("serving cluster requests", "cluster_listen_address", tlsLn.Addr())
			}

			for {
				if atomic.LoadUint32(cl.shutdown) > 0 {
					return
				}

				// Set the deadline for the accept call. If it passes we'll get
				// an error, causing us to check the condition at the top
				// again.
				// NOTE(review): the error from SetDeadline is ignored here.
				tcpLn.SetDeadline(time.Now().Add(clusterListenerAcceptDeadline))

				// Accept the connection
				conn, err := tlsLn.Accept()
				if err != nil {
					// Timeouts are expected (see the deadline above); only
					// other network errors are worth a debug log.
					if err, ok := err.(net.Error); ok && !err.Timeout() {
						cl.logger.Debug("non-timeout error accepting on cluster port", "error", err)
					}
					if conn != nil {
						conn.Close()
					}
					continue
				}
				if conn == nil {
					continue
				}

				// Type assert to TLS connection and handshake to populate the
				// connection state
				tlsConn := conn.(*tls.Conn)

				// Set a deadline for the handshake. This will cause clients
				// that don't successfully auth to be kicked out quickly.
				// Cluster connections should be reliable so being marginally
				// aggressive here is fine.
				err = tlsConn.SetDeadline(time.Now().Add(30 * time.Second))
				if err != nil {
					if cl.logger.IsDebug() {
						cl.logger.Debug("error setting deadline for cluster connection", "error", err)
					}
					tlsConn.Close()
					continue
				}

				err = tlsConn.Handshake()
				if err != nil {
					if cl.logger.IsDebug() {
						cl.logger.Debug("error handshaking cluster connection", "error", err)
					}
					tlsConn.Close()
					continue
				}

				// Now, set it back to unlimited
				err = tlsConn.SetDeadline(time.Time{})
				if err != nil {
					if cl.logger.IsDebug() {
						cl.logger.Debug("error setting deadline for cluster connection", "error", err)
					}
					tlsConn.Close()
					continue
				}

				// Route the connection by the ALPN protocol agreed in the
				// handshake.
				cl.l.RLock()
				handler, ok := cl.handlers[tlsConn.ConnectionState().NegotiatedProtocol]
				cl.l.RUnlock()
				if !ok {
					cl.logger.Debug("unknown negotiated protocol on cluster port")
					tlsConn.Close()
					continue
				}

				// NOTE(review): on a Handoff error the connection is not
				// closed here; presumably the handler owns it — confirm.
				if err := handler.Handoff(ctx, cl.shutdownWg, closeCh, tlsConn); err != nil {
					cl.logger.Error("error handling cluster connection", "error", err)
					continue
				}
			}
		}()
	}

	return nil
}
// Stop stops the cluster listener. It sets the shutdown flag, which every
// accept loop started by Run checks before each Accept. It then blocks until
// every goroutine tracked by shutdownWg has called Done.
func (cl *ClusterListener) Stop() {
	// Set the shutdown flag. This will cause the listeners to shut down
	// within the deadline in clusterListenerAcceptDeadline
	atomic.StoreUint32(cl.shutdown, 1)
	// NOTE(review): this message is logged before the listeners have exited;
	// the Wait below is what actually blocks until they do.
	cl.logger.Info("forwarding rpc listeners stopped")

	// Wait for them all to shut down
	cl.shutdownWg.Wait()
	cl.logger.Info("rpc listeners successfully shut down")
}
// startClusterListener starts cluster request listeners during unseal. It
// is assumed that the state lock is held while this is run. Right now this
// only starts cluster listeners. Once the listener is started handlers/clients
// can start being registered to it.
//
// Behavior:
//   - No cluster address configured: does nothing and returns nil.
//   - Cluster address set but no listener addresses: returns an error.
//   - Otherwise: builds a new ClusterListener, stores it in
//     c.clusterListener and calls its Run. Run launches the accept loops in
//     background goroutines and returns immediately.
func (c *Core) startClusterListener(ctx context.Context) error {
	if c.clusterAddr == "" {
		c.logger.Info("clustering disabled, not starting listeners")
		return nil
	}

	// len of a nil slice is 0, so this also covers an unset address list.
	if len(c.clusterListenerAddrs) == 0 {
		c.logger.Warn("clustering not disabled but no addresses to listen on")
		return fmt.Errorf("cluster addresses not found")
	}

	c.logger.Debug("starting cluster listeners")

	// Create the HTTP/2 server that will be shared by both RPC and regular
	// duties. Doing it this way instead of listening via the server and gRPC
	// allows us to re-use the same port via ALPN. We can just tell the server
	// to serve a given conn and which handler to use.
	h2Server := &http2.Server{
		// Our forwarding connections heartbeat regularly so anything else we
		// want to go away/get cleaned up pretty rapidly
		IdleTimeout: 5 * HeartbeatInterval,
	}

	c.clusterListener = &ClusterListener{
		handlers:             make(map[string]ClusterHandler),
		clients:              make(map[string]ClusterClient),
		shutdown:             new(uint32),
		shutdownWg:           &sync.WaitGroup{},
		server:               h2Server,
		clusterListenerAddrs: c.clusterListenerAddrs,
		clusterCipherSuites:  c.clusterCipherSuites,
		logger:               c.logger.Named("cluster-listener"),
	}

	return c.clusterListener.Run(ctx)
}
// stopClusterListener shuts down the cluster listener, if one was started,
// and waits for its goroutines to exit. It is called during seal, and it is
// assumed that the state lock is held while it runs. The c.clusterListener
// field is left in place.
func (c *Core) stopClusterListener() {
	if c.clusterListener != nil {
		c.logger.Info("stopping cluster listeners")
		c.clusterListener.Stop()
		c.logger.Info("cluster listeners successfully shut down")
		return
	}

	c.logger.Debug("clustering disabled, not stopping listeners")
}
// SetClusterListenerAddrs records the TCP addresses the cluster listener
// binds to. If no cluster address has been configured and exactly one
// listener address is given, the cluster address defaults to "https://"
// followed by that address.
func (c *Core) SetClusterListenerAddrs(addrs []*net.TCPAddr) {
	c.clusterListenerAddrs = addrs

	if c.clusterAddr != "" || len(addrs) != 1 {
		return
	}
	c.clusterAddr = "https://" + addrs[0].String()
}
// SetClusterHandler stores handler in c.clusterHandler.
// NOTE(review): presumably this handler serves HTTP requests arriving over
// cluster connections — confirm against its users.
func (c *Core) SetClusterHandler(handler http.Handler) {
	c.clusterHandler = handler
}
// getGRPCDialer returns a gRPC dial function that opens a TLS connection
// using the cluster listener's dynamic TLS configuration. Otherwise gRPC
// tries to be helpful and stomps all over our NextProtos.
//
// Each dial:
//   - pins NextProtos to alpnProto;
//   - sets ServerName when serverName is non-empty;
//   - when caCert is non-nil, uses a pool containing only caCert as both
//     RootCAs and ClientCAs;
//   - fails with "clustering disabled" if there is no cluster listener at
//     dial time.
func (c *Core) getGRPCDialer(ctx context.Context, alpnProto, serverName string, caCert *x509.Certificate) func(string, time.Duration) (net.Conn, error) {
	dial := func(addr string, timeout time.Duration) (net.Conn, error) {
		listener := c.clusterListener
		if listener == nil {
			return nil, errors.New("clustering disabled")
		}

		cfg, err := listener.TLSConfig(ctx)
		if err != nil {
			c.logger.Error("failed to get tls configuration", "error", err)
			return nil, err
		}

		if serverName != "" {
			cfg.ServerName = serverName
		}
		if caCert != nil {
			trusted := x509.NewCertPool()
			trusted.AddCert(caCert)
			cfg.RootCAs, cfg.ClientCAs = trusted, trusted
		}

		c.logger.Debug("creating rpc dialer", "host", cfg.ServerName)

		cfg.NextProtos = []string{alpnProto}
		return tls.DialWithDialer(&net.Dialer{Timeout: timeout}, "tcp", addr, cfg)
	}
	return dial
}