532 lines
15 KiB
Go
532 lines
15 KiB
Go
package vault
|
|
|
|
import (
|
|
"bytes"
|
|
"crypto/ecdsa"
|
|
"crypto/elliptic"
|
|
"crypto/rand"
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"crypto/x509/pkix"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"math/big"
|
|
mathrand "math/rand"
|
|
"net"
|
|
"net/http"
|
|
"time"
|
|
|
|
log "github.com/mgutz/logxi/v1"
|
|
|
|
"golang.org/x/net/http2"
|
|
|
|
"github.com/hashicorp/errwrap"
|
|
"github.com/hashicorp/go-uuid"
|
|
"github.com/hashicorp/vault/helper/forwarding"
|
|
"github.com/hashicorp/vault/helper/jsonutil"
|
|
)
|
|
|
|
const (
	// coreLocalClusterInfoPath is the storage path under which the local
	// cluster's name and identifier are kept.
	coreLocalClusterInfoPath = "core/cluster/local/info"

	// Identifiers for the kind of private key described by a
	// clusterKeyParams value. loadLocalClusterTLS accepts only
	// corePrivateKeyTypeP521.
	corePrivateKeyTypeP521    = "p521"
	corePrivateKeyTypeED25519 = "ed25519"

	// IntNoForwardingHeaderName is the header set to "true" on requests that
	// arrived via forwarding so that they are not forwarded again. It is
	// internal so as not to log a trace message.
	IntNoForwardingHeaderName = "X-Vault-Internal-No-Request-Forwarding"
)
|
|
|
|
// ErrCannotForward reports that a request could not be forwarded because
// there is no connection or the address is not known.
var ErrCannotForward = errors.New("cannot forward request; no connection or address not known")
|
|
|
|
// clusterKeyParams holds the parameters of a cluster private key in a form
// that can be serialized. More than one key type can be described, so some of
// the numeric fields may be left unset depending on Type.
//
// For corePrivateKeyTypeP521, loadLocalClusterTLS uses X and Y as the ECDSA
// public key coordinates and D as the private value.
type clusterKeyParams struct {
	Type string   `json:"type" structs:"type" mapstructure:"type"`
	X    *big.Int `json:"x" structs:"x" mapstructure:"x"`
	Y    *big.Int `json:"y" structs:"y" mapstructure:"y"`
	D    *big.Int `json:"d" structs:"d" mapstructure:"d"`
}
|
|
|
|
type activeConnection struct {
|
|
transport *http2.Transport
|
|
clusterAddr string
|
|
}
|
|
|
|
// Cluster is the storage entry that records the local cluster's identity. It
// is what (*Core).Cluster decodes from, and setupCluster encodes to,
// coreLocalClusterInfoPath.
type Cluster struct {
	// Name is the cluster's name.
	Name string `json:"name" structs:"name" mapstructure:"name"`

	// ID is the cluster's identifier.
	ID string `json:"id" structs:"id" mapstructure:"id"`
}
|
|
|
|
// Cluster fetches the details of the local cluster. This method errors out
|
|
// when Vault is sealed.
|
|
func (c *Core) Cluster() (*Cluster, error) {
|
|
var cluster Cluster
|
|
|
|
// Fetch the storage entry. This call fails when Vault is sealed.
|
|
entry, err := c.barrier.Get(coreLocalClusterInfoPath)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if entry == nil {
|
|
return &cluster, nil
|
|
}
|
|
|
|
// Decode the cluster information
|
|
if err = jsonutil.DecodeJSON(entry.Value, &cluster); err != nil {
|
|
return nil, fmt.Errorf("failed to decode cluster details: %v", err)
|
|
}
|
|
|
|
// Set in config file
|
|
if c.clusterName != "" {
|
|
cluster.Name = c.clusterName
|
|
}
|
|
|
|
return &cluster, nil
|
|
}
|
|
|
|
// This sets our local cluster cert and private key based on the advertisement.
|
|
// It also ensures the cert is in our local cluster cert pool.
|
|
func (c *Core) loadLocalClusterTLS(adv activeAdvertisement) (retErr error) {
|
|
defer func() {
|
|
if retErr != nil {
|
|
c.clusterParamsLock.Lock()
|
|
c.localClusterCert = nil
|
|
c.localClusterPrivateKey = nil
|
|
c.localClusterParsedCert = nil
|
|
c.clusterParamsLock.Unlock()
|
|
|
|
c.requestForwardingConnectionLock.Lock()
|
|
c.clearForwardingClients()
|
|
c.requestForwardingConnectionLock.Unlock()
|
|
}
|
|
}()
|
|
|
|
switch {
|
|
case adv.ClusterAddr == "":
|
|
// Clustering disabled on the server, don't try to look for params
|
|
return nil
|
|
|
|
case adv.ClusterKeyParams == nil:
|
|
c.logger.Error("core: no key params found loading local cluster TLS information")
|
|
return fmt.Errorf("no local cluster key params found")
|
|
|
|
case adv.ClusterKeyParams.X == nil, adv.ClusterKeyParams.Y == nil, adv.ClusterKeyParams.D == nil:
|
|
c.logger.Error("core: failed to parse local cluster key due to missing params")
|
|
return fmt.Errorf("failed to parse local cluster key")
|
|
|
|
case adv.ClusterKeyParams.Type != corePrivateKeyTypeP521:
|
|
c.logger.Error("core: unknown local cluster key type", "key_type", adv.ClusterKeyParams.Type)
|
|
return fmt.Errorf("failed to find valid local cluster key type")
|
|
|
|
case adv.ClusterCert == nil || len(adv.ClusterCert) == 0:
|
|
c.logger.Error("core: no local cluster cert found")
|
|
return fmt.Errorf("no local cluster cert found")
|
|
|
|
}
|
|
|
|
// Prevent data races with the TLS parameters
|
|
c.clusterParamsLock.Lock()
|
|
defer c.clusterParamsLock.Unlock()
|
|
|
|
c.localClusterPrivateKey = &ecdsa.PrivateKey{
|
|
PublicKey: ecdsa.PublicKey{
|
|
Curve: elliptic.P521(),
|
|
X: adv.ClusterKeyParams.X,
|
|
Y: adv.ClusterKeyParams.Y,
|
|
},
|
|
D: adv.ClusterKeyParams.D,
|
|
}
|
|
|
|
c.localClusterCert = adv.ClusterCert
|
|
|
|
cert, err := x509.ParseCertificate(c.localClusterCert)
|
|
if err != nil {
|
|
c.logger.Error("core: failed parsing local cluster certificate", "error", err)
|
|
return fmt.Errorf("error parsing local cluster certificate: %v", err)
|
|
}
|
|
|
|
c.localClusterParsedCert = cert
|
|
|
|
return nil
|
|
}
|
|
|
|
// setupCluster creates storage entries for holding Vault cluster information.
|
|
// Entries will be created only if they are not already present. If clusterName
|
|
// is not supplied, this method will auto-generate it.
|
|
func (c *Core) setupCluster() error {
|
|
// Prevent data races with the TLS parameters
|
|
c.clusterParamsLock.Lock()
|
|
defer c.clusterParamsLock.Unlock()
|
|
|
|
// Check if storage index is already present or not
|
|
cluster, err := c.Cluster()
|
|
if err != nil {
|
|
c.logger.Error("core: failed to get cluster details", "error", err)
|
|
return err
|
|
}
|
|
|
|
var modified bool
|
|
|
|
if cluster == nil {
|
|
cluster = &Cluster{}
|
|
}
|
|
|
|
if cluster.Name == "" {
|
|
// If cluster name is not supplied, generate one
|
|
if c.clusterName == "" {
|
|
c.logger.Trace("core: cluster name not found/set, generating new")
|
|
clusterNameBytes, err := uuid.GenerateRandomBytes(4)
|
|
if err != nil {
|
|
c.logger.Error("core: failed to generate cluster name", "error", err)
|
|
return err
|
|
}
|
|
|
|
c.clusterName = fmt.Sprintf("vault-cluster-%08x", clusterNameBytes)
|
|
}
|
|
|
|
cluster.Name = c.clusterName
|
|
if c.logger.IsDebug() {
|
|
c.logger.Debug("core: cluster name set", "name", cluster.Name)
|
|
}
|
|
modified = true
|
|
}
|
|
|
|
if cluster.ID == "" {
|
|
c.logger.Trace("core: cluster ID not found, generating new")
|
|
// Generate a clusterID
|
|
cluster.ID, err = uuid.GenerateUUID()
|
|
if err != nil {
|
|
c.logger.Error("core: failed to generate cluster identifier", "error", err)
|
|
return err
|
|
}
|
|
if c.logger.IsDebug() {
|
|
c.logger.Debug("core: cluster ID set", "id", cluster.ID)
|
|
}
|
|
modified = true
|
|
}
|
|
|
|
// If we're using HA, generate server-to-server parameters
|
|
if c.ha != nil {
|
|
// Create a private key
|
|
if c.localClusterPrivateKey == nil {
|
|
c.logger.Trace("core: generating cluster private key")
|
|
key, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
|
|
if err != nil {
|
|
c.logger.Error("core: failed to generate local cluster key", "error", err)
|
|
return err
|
|
}
|
|
|
|
c.localClusterPrivateKey = key
|
|
}
|
|
|
|
// Create a certificate
|
|
if c.localClusterCert == nil {
|
|
c.logger.Trace("core: generating local cluster certificate")
|
|
|
|
host, err := uuid.GenerateUUID()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
host = fmt.Sprintf("fw-%s", host)
|
|
template := &x509.Certificate{
|
|
Subject: pkix.Name{
|
|
CommonName: host,
|
|
},
|
|
DNSNames: []string{host},
|
|
ExtKeyUsage: []x509.ExtKeyUsage{
|
|
x509.ExtKeyUsageServerAuth,
|
|
x509.ExtKeyUsageClientAuth,
|
|
},
|
|
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement | x509.KeyUsageCertSign,
|
|
SerialNumber: big.NewInt(mathrand.Int63()),
|
|
NotBefore: time.Now().Add(-30 * time.Second),
|
|
// 30 years of single-active uptime ought to be enough for anybody
|
|
NotAfter: time.Now().Add(262980 * time.Hour),
|
|
BasicConstraintsValid: true,
|
|
IsCA: true,
|
|
}
|
|
|
|
certBytes, err := x509.CreateCertificate(rand.Reader, template, template, c.localClusterPrivateKey.Public(), c.localClusterPrivateKey)
|
|
if err != nil {
|
|
c.logger.Error("core: error generating self-signed cert", "error", err)
|
|
return errwrap.Wrapf("unable to generate local cluster certificate: {{err}}", err)
|
|
}
|
|
|
|
parsedCert, err := x509.ParseCertificate(certBytes)
|
|
if err != nil {
|
|
c.logger.Error("core: error parsing self-signed cert", "error", err)
|
|
return errwrap.Wrapf("error parsing generated certificate: {{err}}", err)
|
|
}
|
|
|
|
c.localClusterCert = certBytes
|
|
c.localClusterParsedCert = parsedCert
|
|
}
|
|
}
|
|
|
|
if modified {
|
|
// Encode the cluster information into as a JSON string
|
|
rawCluster, err := json.Marshal(cluster)
|
|
if err != nil {
|
|
c.logger.Error("core: failed to encode cluster details", "error", err)
|
|
return err
|
|
}
|
|
|
|
// Store it
|
|
err = c.barrier.Put(&Entry{
|
|
Key: coreLocalClusterInfoPath,
|
|
Value: rawCluster,
|
|
})
|
|
if err != nil {
|
|
c.logger.Error("core: failed to store cluster details", "error", err)
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// SetClusterSetupFuncs sets the handler setup func
|
|
func (c *Core) SetClusterSetupFuncs(handler func() (http.Handler, http.Handler)) {
|
|
c.clusterHandlerSetupFunc = handler
|
|
}
|
|
|
|
// startClusterListener starts cluster request listeners during postunseal. It
|
|
// is assumed that the state lock is held while this is run. Right now this
|
|
// only starts forwarding listeners; it's TBD whether other request types will
|
|
// be built in the same mechanism or started independently.
|
|
func (c *Core) startClusterListener() error {
|
|
if c.clusterHandlerSetupFunc == nil {
|
|
c.logger.Error("core: cluster handler setup function has not been set when trying to start listeners")
|
|
return fmt.Errorf("cluster handler setup function has not been set")
|
|
}
|
|
|
|
if c.clusterAddr == "" {
|
|
c.logger.Info("core: clustering disabled, not starting listeners")
|
|
return nil
|
|
}
|
|
|
|
if c.clusterListenerAddrs == nil || len(c.clusterListenerAddrs) == 0 {
|
|
c.logger.Warn("core: clustering not disabled but no addresses to listen on")
|
|
return fmt.Errorf("cluster addresses not found")
|
|
}
|
|
|
|
c.logger.Trace("core: starting cluster listeners")
|
|
|
|
err := c.startForwarding()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// stopClusterListener stops any existing listeners during preseal. It is
|
|
// assumed that the state lock is held while this is run.
|
|
func (c *Core) stopClusterListener() {
|
|
if c.clusterAddr == "" {
|
|
c.logger.Trace("core: clustering disabled, not stopping listeners")
|
|
return
|
|
}
|
|
|
|
if !c.clusterListenersRunning {
|
|
c.logger.Info("core: cluster listeners not running")
|
|
return
|
|
}
|
|
c.logger.Info("core: stopping cluster listeners")
|
|
|
|
// Tell the goroutine managing the listeners to perform the shutdown
|
|
// process
|
|
c.clusterListenerShutdownCh <- struct{}{}
|
|
|
|
// The reason for this loop-de-loop is that we may be unsealing again
|
|
// quickly, and if the listeners are not yet closed, we will get socket
|
|
// bind errors. This ensures proper ordering.
|
|
c.logger.Trace("core: waiting for success notification while stopping cluster listeners")
|
|
<-c.clusterListenerShutdownSuccessCh
|
|
c.clusterListenersRunning = false
|
|
|
|
c.logger.Info("core: cluster listeners successfully shut down")
|
|
}
|
|
|
|
// ClusterTLSConfig generates a TLS configuration based on the local/replicated
|
|
// cluster key and cert.
|
|
func (c *Core) ClusterTLSConfig() (*tls.Config, error) {
|
|
// Using lookup functions allows just-in-time lookup of the current state
|
|
// of clustering as connections come and go
|
|
|
|
serverLookup := func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
|
switch {
|
|
default:
|
|
var localCert bytes.Buffer
|
|
|
|
c.clusterParamsLock.RLock()
|
|
localCert.Write(c.localClusterCert)
|
|
localSigner := c.localClusterPrivateKey
|
|
parsedCert := c.localClusterParsedCert
|
|
c.clusterParamsLock.RUnlock()
|
|
|
|
if localCert.Len() == 0 {
|
|
return nil, fmt.Errorf("got forwarding connection but no local cert")
|
|
}
|
|
|
|
//c.logger.Trace("core: performing cert name lookup", "hello_server_name", clientHello.ServerName, "local_cluster_cert_name", parsedCert.Subject.CommonName)
|
|
|
|
return &tls.Certificate{
|
|
Certificate: [][]byte{localCert.Bytes()},
|
|
PrivateKey: localSigner,
|
|
Leaf: parsedCert,
|
|
}, nil
|
|
|
|
}
|
|
|
|
return nil, nil
|
|
}
|
|
|
|
clientLookup := func(requestInfo *tls.CertificateRequestInfo) (*tls.Certificate, error) {
|
|
//c.logger.Trace("core: performing client cert lookup")
|
|
|
|
if len(requestInfo.AcceptableCAs) != 1 {
|
|
return nil, fmt.Errorf("expected only a single acceptable CA")
|
|
}
|
|
var localCert bytes.Buffer
|
|
|
|
c.clusterParamsLock.RLock()
|
|
localCert.Write(c.localClusterCert)
|
|
localSigner := c.localClusterPrivateKey
|
|
parsedCert := c.localClusterParsedCert
|
|
c.clusterParamsLock.RUnlock()
|
|
|
|
if localCert.Len() == 0 {
|
|
return nil, fmt.Errorf("forwarding connection client but no local cert")
|
|
}
|
|
|
|
return &tls.Certificate{
|
|
Certificate: [][]byte{localCert.Bytes()},
|
|
PrivateKey: localSigner,
|
|
Leaf: parsedCert,
|
|
}, nil
|
|
}
|
|
|
|
serverConfigLookup := func(clientHello *tls.ClientHelloInfo) (*tls.Config, error) {
|
|
//c.logger.Trace("core: performing server config lookup")
|
|
for _, v := range clientHello.SupportedProtos {
|
|
switch v {
|
|
case "h2", "req_fw_sb-act_v1":
|
|
default:
|
|
return nil, fmt.Errorf("unknown ALPN proto %s", v)
|
|
}
|
|
}
|
|
|
|
caPool := x509.NewCertPool()
|
|
|
|
ret := &tls.Config{
|
|
ClientAuth: tls.RequireAndVerifyClientCert,
|
|
GetCertificate: serverLookup,
|
|
GetClientCertificate: clientLookup,
|
|
MinVersion: tls.VersionTLS12,
|
|
RootCAs: caPool,
|
|
ClientCAs: caPool,
|
|
NextProtos: clientHello.SupportedProtos,
|
|
}
|
|
|
|
switch {
|
|
default:
|
|
c.clusterParamsLock.RLock()
|
|
parsedCert := c.localClusterParsedCert
|
|
c.clusterParamsLock.RUnlock()
|
|
|
|
if parsedCert == nil {
|
|
return nil, fmt.Errorf("forwarding connection client but no local cert")
|
|
}
|
|
|
|
caPool.AddCert(parsedCert)
|
|
}
|
|
|
|
return ret, nil
|
|
}
|
|
|
|
tlsConfig := &tls.Config{
|
|
ClientAuth: tls.RequireAndVerifyClientCert,
|
|
GetCertificate: serverLookup,
|
|
GetClientCertificate: clientLookup,
|
|
GetConfigForClient: serverConfigLookup,
|
|
MinVersion: tls.VersionTLS12,
|
|
}
|
|
|
|
var localCert bytes.Buffer
|
|
c.clusterParamsLock.RLock()
|
|
localCert.Write(c.localClusterCert)
|
|
parsedCert := c.localClusterParsedCert
|
|
c.clusterParamsLock.RUnlock()
|
|
|
|
if parsedCert != nil {
|
|
tlsConfig.ServerName = parsedCert.Subject.CommonName
|
|
|
|
pool := x509.NewCertPool()
|
|
pool.AddCert(parsedCert)
|
|
tlsConfig.RootCAs = pool
|
|
tlsConfig.ClientCAs = pool
|
|
}
|
|
|
|
return tlsConfig, nil
|
|
}
|
|
|
|
func (c *Core) SetClusterListenerAddrs(addrs []*net.TCPAddr) {
|
|
c.clusterListenerAddrs = addrs
|
|
}
|
|
|
|
// WrapHandlerForClustering takes in Vault's HTTP handler and returns a setup
|
|
// function that returns both the original handler and one wrapped with cluster
|
|
// methods
|
|
func WrapHandlerForClustering(handler http.Handler, logger log.Logger) func() (http.Handler, http.Handler) {
|
|
return func() (http.Handler, http.Handler) {
|
|
// This mux handles cluster functions (right now, only forwarded requests)
|
|
mux := http.NewServeMux()
|
|
mux.HandleFunc("/cluster/local/forwarded-request", func(w http.ResponseWriter, req *http.Request) {
|
|
//logger.Trace("forwarding: serving h2 forwarded request")
|
|
freq, err := forwarding.ParseForwardedHTTPRequest(req)
|
|
if err != nil {
|
|
if logger != nil {
|
|
logger.Error("http/forwarded-request-server: error parsing forwarded request", "error", err)
|
|
}
|
|
|
|
w.Header().Add("Content-Type", "application/json")
|
|
|
|
// The response writer here is different from
|
|
// the one set in Vault's HTTP handler.
|
|
// Hence, set the Cache-Control explicitly.
|
|
w.Header().Set("Cache-Control", "no-store")
|
|
|
|
w.WriteHeader(http.StatusInternalServerError)
|
|
|
|
type errorResponse struct {
|
|
Errors []string
|
|
}
|
|
resp := &errorResponse{
|
|
Errors: []string{
|
|
err.Error(),
|
|
},
|
|
}
|
|
|
|
enc := json.NewEncoder(w)
|
|
enc.Encode(resp)
|
|
return
|
|
}
|
|
|
|
// To avoid the risk of a forward loop in some pathological condition,
|
|
// set the no-forward header
|
|
freq.Header.Set(IntNoForwardingHeaderName, "true")
|
|
handler.ServeHTTP(w, freq)
|
|
})
|
|
|
|
return handler, mux
|
|
}
|
|
}
|