382 lines
9.4 KiB
Go
382 lines
9.4 KiB
Go
package raft
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/ecdsa"
|
|
"crypto/elliptic"
|
|
"crypto/rand"
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"crypto/x509/pkix"
|
|
"errors"
|
|
fmt "fmt"
|
|
"io"
|
|
"math/big"
|
|
mathrand "math/rand"
|
|
"net"
|
|
"net/url"
|
|
"sync"
|
|
"time"
|
|
|
|
log "github.com/hashicorp/go-hclog"
|
|
uuid "github.com/hashicorp/go-uuid"
|
|
"github.com/hashicorp/raft"
|
|
"github.com/hashicorp/vault/sdk/helper/certutil"
|
|
"github.com/hashicorp/vault/sdk/helper/consts"
|
|
"github.com/hashicorp/vault/vault/cluster"
|
|
)
|
|
|
|
// TLSKey is a single TLS keypair in the Keyring
|
|
type TLSKey struct {
|
|
// ID is a unique identifier for this Key
|
|
ID string `json:"id"`
|
|
|
|
// KeyType defines the algorighm used to generate the private keys
|
|
KeyType string `json:"key_type"`
|
|
|
|
// AppliedIndex is the earliest known raft index that safely contains this
|
|
// key.
|
|
AppliedIndex uint64 `json:"applied_index"`
|
|
|
|
// CertBytes is the marshaled certificate.
|
|
CertBytes []byte `json:"cluster_cert"`
|
|
|
|
// KeyParams is the marshaled private key.
|
|
KeyParams *certutil.ClusterKeyParams `json:"cluster_key_params"`
|
|
|
|
// CreatedTime is the time this key was generated. This value is useful in
|
|
// determining when the next rotation should be.
|
|
CreatedTime time.Time `json:"created_time"`
|
|
|
|
parsedCert *x509.Certificate
|
|
parsedKey *ecdsa.PrivateKey
|
|
}
|
|
|
|
// TLSKeyring is the set of keys that raft uses for network communication.
|
|
// Only one key is used to dial at a time but both keys will be used to accept
|
|
// connections.
|
|
type TLSKeyring struct {
|
|
// Keys is the set of available key pairs
|
|
Keys []*TLSKey `json:"keys"`
|
|
|
|
// AppliedIndex is the earliest known raft index that safely contains the
|
|
// latest key in the keyring.
|
|
AppliedIndex uint64 `json:"applied_index"`
|
|
|
|
// Term is an incrementing identifier value used to quickly determine if two
|
|
// states of the keyring are different.
|
|
Term uint64 `json:"term"`
|
|
|
|
// ActiveKeyID is the key ID to track the active key in the keyring. Only
|
|
// the active key is used for dialing.
|
|
ActiveKeyID string `json:"active_key_id"`
|
|
}
|
|
|
|
// GetActive returns the active key.
|
|
func (k *TLSKeyring) GetActive() *TLSKey {
|
|
if k.ActiveKeyID == "" {
|
|
return nil
|
|
}
|
|
|
|
for _, key := range k.Keys {
|
|
if key.ID == k.ActiveKeyID {
|
|
return key
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func GenerateTLSKey(reader io.Reader) (*TLSKey, error) {
|
|
key, err := ecdsa.GenerateKey(elliptic.P521(), reader)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
host, err := uuid.GenerateUUID()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
host = fmt.Sprintf("raft-%s", host)
|
|
template := &x509.Certificate{
|
|
Subject: pkix.Name{
|
|
CommonName: host,
|
|
},
|
|
DNSNames: []string{host},
|
|
ExtKeyUsage: []x509.ExtKeyUsage{
|
|
x509.ExtKeyUsageServerAuth,
|
|
x509.ExtKeyUsageClientAuth,
|
|
},
|
|
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement | x509.KeyUsageCertSign,
|
|
SerialNumber: big.NewInt(mathrand.Int63()),
|
|
NotBefore: time.Now().Add(-30 * time.Second),
|
|
// 30 years ought to be enough for anybody
|
|
NotAfter: time.Now().Add(262980 * time.Hour),
|
|
BasicConstraintsValid: true,
|
|
IsCA: true,
|
|
}
|
|
|
|
certBytes, err := x509.CreateCertificate(rand.Reader, template, template, key.Public(), key)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("unable to generate local cluster certificate: %w", err)
|
|
}
|
|
|
|
return &TLSKey{
|
|
ID: host,
|
|
KeyType: certutil.PrivateKeyTypeP521,
|
|
CertBytes: certBytes,
|
|
KeyParams: &certutil.ClusterKeyParams{
|
|
Type: certutil.PrivateKeyTypeP521,
|
|
X: key.PublicKey.X,
|
|
Y: key.PublicKey.Y,
|
|
D: key.D,
|
|
},
|
|
CreatedTime: time.Now(),
|
|
}, nil
|
|
}
|
|
|
|
var (
|
|
// Make sure raftLayer satisfies the raft.StreamLayer interface
|
|
_ raft.StreamLayer = (*raftLayer)(nil)
|
|
|
|
// Make sure raftLayer satisfies the cluster.Handler and cluster.Client
|
|
// interfaces
|
|
_ cluster.Handler = (*raftLayer)(nil)
|
|
_ cluster.Client = (*raftLayer)(nil)
|
|
)
|
|
|
|
// RaftLayer implements the raft.StreamLayer interface,
|
|
// so that we can use a single RPC layer for Raft and Vault
|
|
type raftLayer struct {
|
|
// Addr is the listener address to return
|
|
addr net.Addr
|
|
|
|
// connCh is used to accept connections
|
|
connCh chan net.Conn
|
|
|
|
// Tracks if we are closed
|
|
closed bool
|
|
closeCh chan struct{}
|
|
closeLock sync.Mutex
|
|
|
|
logger log.Logger
|
|
|
|
dialerFunc func(string, time.Duration) (net.Conn, error)
|
|
|
|
// TLS config
|
|
keyring *TLSKeyring
|
|
clusterListener cluster.ClusterHook
|
|
}
|
|
|
|
// NewRaftLayer creates a new raftLayer object. It parses the TLS information
|
|
// from the network config.
|
|
func NewRaftLayer(logger log.Logger, raftTLSKeyring *TLSKeyring, clusterListener cluster.ClusterHook) (*raftLayer, error) {
|
|
clusterAddr := clusterListener.Addr()
|
|
if clusterAddr == nil {
|
|
return nil, errors.New("no raft addr found")
|
|
}
|
|
|
|
{
|
|
// Test the advertised address to make sure it's not an unspecified IP
|
|
u := url.URL{
|
|
Host: clusterAddr.String(),
|
|
}
|
|
ip := net.ParseIP(u.Hostname())
|
|
if ip != nil && ip.IsUnspecified() {
|
|
return nil, fmt.Errorf("cannot use unspecified IP with raft storage: %s", clusterAddr.String())
|
|
}
|
|
}
|
|
|
|
layer := &raftLayer{
|
|
addr: clusterAddr,
|
|
connCh: make(chan net.Conn),
|
|
closeCh: make(chan struct{}),
|
|
logger: logger,
|
|
clusterListener: clusterListener,
|
|
}
|
|
|
|
if err := layer.setTLSKeyring(raftTLSKeyring); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return layer, nil
|
|
}
|
|
|
|
func (l *raftLayer) setTLSKeyring(keyring *TLSKeyring) error {
|
|
// Fast path a noop update
|
|
if l.keyring != nil && l.keyring.Term == keyring.Term {
|
|
return nil
|
|
}
|
|
|
|
for _, key := range keyring.Keys {
|
|
switch {
|
|
case key.KeyParams == nil:
|
|
return errors.New("no raft cluster key params found")
|
|
|
|
case key.KeyParams.X == nil, key.KeyParams.Y == nil, key.KeyParams.D == nil:
|
|
return errors.New("failed to parse raft cluster key")
|
|
|
|
case key.KeyParams.Type != certutil.PrivateKeyTypeP521:
|
|
return errors.New("failed to find valid raft cluster key type")
|
|
|
|
case len(key.CertBytes) == 0:
|
|
return errors.New("no cluster cert found")
|
|
}
|
|
|
|
parsedCert, err := x509.ParseCertificate(key.CertBytes)
|
|
if err != nil {
|
|
return fmt.Errorf("error parsing raft cluster certificate: %w", err)
|
|
}
|
|
|
|
key.parsedCert = parsedCert
|
|
key.parsedKey = &ecdsa.PrivateKey{
|
|
PublicKey: ecdsa.PublicKey{
|
|
Curve: elliptic.P521(),
|
|
X: key.KeyParams.X,
|
|
Y: key.KeyParams.Y,
|
|
},
|
|
D: key.KeyParams.D,
|
|
}
|
|
}
|
|
|
|
if keyring.GetActive() == nil {
|
|
return errors.New("expected one active key to be present in the keyring")
|
|
}
|
|
|
|
l.keyring = keyring
|
|
|
|
return nil
|
|
}
|
|
|
|
func (l *raftLayer) ServerName() string {
|
|
key := l.keyring.GetActive()
|
|
if key == nil {
|
|
return ""
|
|
}
|
|
|
|
return key.parsedCert.Subject.CommonName
|
|
}
|
|
|
|
func (l *raftLayer) CACert(ctx context.Context) *x509.Certificate {
|
|
key := l.keyring.GetActive()
|
|
if key == nil {
|
|
return nil
|
|
}
|
|
|
|
return key.parsedCert
|
|
}
|
|
|
|
func (l *raftLayer) ClientLookup(ctx context.Context, requestInfo *tls.CertificateRequestInfo) (*tls.Certificate, error) {
|
|
for _, subj := range requestInfo.AcceptableCAs {
|
|
for _, key := range l.keyring.Keys {
|
|
if bytes.Equal(subj, key.parsedCert.RawIssuer) {
|
|
localCert := make([]byte, len(key.CertBytes))
|
|
copy(localCert, key.CertBytes)
|
|
|
|
return &tls.Certificate{
|
|
Certificate: [][]byte{localCert},
|
|
PrivateKey: key.parsedKey,
|
|
Leaf: key.parsedCert,
|
|
}, nil
|
|
}
|
|
}
|
|
}
|
|
|
|
return nil, nil
|
|
}
|
|
|
|
func (l *raftLayer) ServerLookup(ctx context.Context, clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
|
if l.keyring == nil {
|
|
return nil, errors.New("got raft connection but no local cert")
|
|
}
|
|
|
|
for _, key := range l.keyring.Keys {
|
|
if clientHello.ServerName == key.ID {
|
|
localCert := make([]byte, len(key.CertBytes))
|
|
copy(localCert, key.CertBytes)
|
|
|
|
return &tls.Certificate{
|
|
Certificate: [][]byte{localCert},
|
|
PrivateKey: key.parsedKey,
|
|
Leaf: key.parsedCert,
|
|
}, nil
|
|
}
|
|
}
|
|
|
|
return nil, nil
|
|
}
|
|
|
|
// CALookup returns the CA to use when validating this connection.
|
|
func (l *raftLayer) CALookup(context.Context) ([]*x509.Certificate, error) {
|
|
ret := make([]*x509.Certificate, len(l.keyring.Keys))
|
|
for i, key := range l.keyring.Keys {
|
|
ret[i] = key.parsedCert
|
|
}
|
|
return ret, nil
|
|
}
|
|
|
|
// Stop shuts down the raft layer.
|
|
func (l *raftLayer) Stop() error {
|
|
l.Close()
|
|
return nil
|
|
}
|
|
|
|
// Handoff is used to hand off a connection to the
|
|
// RaftLayer. This allows it to be Accept()'ed
|
|
func (l *raftLayer) Handoff(ctx context.Context, wg *sync.WaitGroup, quit chan struct{}, conn *tls.Conn) error {
|
|
l.closeLock.Lock()
|
|
closed := l.closed
|
|
l.closeLock.Unlock()
|
|
|
|
if closed {
|
|
return errors.New("raft is shutdown")
|
|
}
|
|
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
select {
|
|
case l.connCh <- conn:
|
|
case <-l.closeCh:
|
|
case <-ctx.Done():
|
|
case <-quit:
|
|
}
|
|
}()
|
|
|
|
return nil
|
|
}
|
|
|
|
// Accept is used to return connection which are
|
|
// dialed to be used with the Raft layer
|
|
func (l *raftLayer) Accept() (net.Conn, error) {
|
|
select {
|
|
case conn := <-l.connCh:
|
|
return conn, nil
|
|
case <-l.closeCh:
|
|
return nil, fmt.Errorf("Raft RPC layer closed")
|
|
}
|
|
}
|
|
|
|
// Close is used to stop listening for Raft connections
|
|
func (l *raftLayer) Close() error {
|
|
l.closeLock.Lock()
|
|
defer l.closeLock.Unlock()
|
|
|
|
if !l.closed {
|
|
l.closed = true
|
|
close(l.closeCh)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Addr is used to return the address of the listener
|
|
func (l *raftLayer) Addr() net.Addr {
|
|
return l.addr
|
|
}
|
|
|
|
// Dial is used to create a new outgoing connection
|
|
func (l *raftLayer) Dial(address raft.ServerAddress, timeout time.Duration) (net.Conn, error) {
|
|
dialFunc := l.clusterListener.GetDialerFunc(context.Background(), consts.RaftStorageALPN)
|
|
return dialFunc(string(address), timeout)
|
|
}
|