Add a TRACE log with TLS connection details on replication connections (#12754)
* remove cruft use helper Add a helper for getting public key sizes wip * error names * Fix ecdsa * only if trace is on * Log listener side as well * rename * Add remote address * Make the log level configurable via the env var, and a member of the Listener and thus modifiable by tests * Fix certutil_test
This commit is contained in:
parent
7fd527dc9a
commit
1097f356af
|
@ -455,6 +455,31 @@ vitin0L6nprauWkKO38XgM4T75qKZpqtiOcT
|
|||
}
|
||||
}
|
||||
|
||||
func TestGetPublicKeySize(t *testing.T) {
|
||||
rsa, err := rsa.GenerateKey(rand.Reader, 3072)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if GetPublicKeySize(&rsa.PublicKey) != 3072 {
|
||||
t.Fatal("unexpected rsa key size")
|
||||
}
|
||||
ecdsa, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if GetPublicKeySize(&ecdsa.PublicKey) != 384 {
|
||||
t.Fatal("unexpected ecdsa key size")
|
||||
}
|
||||
ed25519, _, err := ed25519.GenerateKey(rand.Reader)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if GetPublicKeySize(ed25519) != 256 {
|
||||
t.Fatal("unexpected ed25519 key size")
|
||||
}
|
||||
//Skipping DSA as too slow
|
||||
}
|
||||
|
||||
func refreshRSA8CertBundle() *CertBundle {
|
||||
initTest.Do(setCerts)
|
||||
return &CertBundle{
|
||||
|
|
|
@ -3,6 +3,7 @@ package certutil
|
|||
import (
|
||||
"bytes"
|
||||
"crypto"
|
||||
"crypto/dsa"
|
||||
"crypto/ecdsa"
|
||||
"crypto/ed25519"
|
||||
"crypto/elliptic"
|
||||
|
@ -1002,3 +1003,22 @@ func parseCertsPEM(pemCerts []byte) ([]*x509.Certificate, error) {
|
|||
}
|
||||
return certs, nil
|
||||
}
|
||||
|
||||
// GetPublicKeySize returns the key size in bits for a given arbitrary crypto.PublicKey
|
||||
// Returns -1 for an unsupported key type.
|
||||
func GetPublicKeySize(key crypto.PublicKey) int {
|
||||
if key, ok := key.(*rsa.PublicKey); ok {
|
||||
return key.Size() * 8
|
||||
}
|
||||
if key, ok := key.(*ecdsa.PublicKey); ok {
|
||||
return key.Params().BitSize
|
||||
}
|
||||
if key, ok := key.(ed25519.PublicKey); ok {
|
||||
return len(key) * 8
|
||||
}
|
||||
if key, ok := key.(dsa.PublicKey); ok {
|
||||
return key.Y.BitLen()
|
||||
}
|
||||
|
||||
return -1
|
||||
}
|
||||
|
|
|
@ -6,8 +6,11 @@ import (
|
|||
"crypto/x509"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/hashicorp/vault/sdk/helper/certutil"
|
||||
"github.com/hashicorp/vault/sdk/helper/tlsutil"
|
||||
"net"
|
||||
"net/url"
|
||||
"os"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
@ -60,11 +63,12 @@ type Listener struct {
|
|||
shutdownWg *sync.WaitGroup
|
||||
server *http2.Server
|
||||
|
||||
networkLayer NetworkLayer
|
||||
cipherSuites []uint16
|
||||
advertise net.Addr
|
||||
logger log.Logger
|
||||
l sync.RWMutex
|
||||
networkLayer NetworkLayer
|
||||
cipherSuites []uint16
|
||||
advertise net.Addr
|
||||
logger log.Logger
|
||||
l sync.RWMutex
|
||||
tlsConnectionLoggingLevel log.Level
|
||||
}
|
||||
|
||||
func NewListener(networkLayer NetworkLayer, cipherSuites []uint16, logger log.Logger, idleTimeout time.Duration) *Listener {
|
||||
|
@ -85,9 +89,10 @@ func NewListener(networkLayer NetworkLayer, cipherSuites []uint16, logger log.Lo
|
|||
shutdownWg: &sync.WaitGroup{},
|
||||
server: h2Server,
|
||||
|
||||
networkLayer: networkLayer,
|
||||
cipherSuites: cipherSuites,
|
||||
logger: logger,
|
||||
networkLayer: networkLayer,
|
||||
cipherSuites: cipherSuites,
|
||||
logger: logger,
|
||||
tlsConnectionLoggingLevel: log.LevelFromString(os.Getenv("VAULT_CLUSTER_TLS_SESSION_LOG_LEVEL")),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -359,6 +364,8 @@ func (cl *Listener) Run(ctx context.Context) error {
|
|||
continue
|
||||
}
|
||||
|
||||
cl.logTLSSessionStart(tlsConn.RemoteAddr().String(), tlsConn.ConnectionState())
|
||||
|
||||
// Now, set it back to unlimited
|
||||
err = tlsConn.SetDeadline(time.Time{})
|
||||
if err != nil {
|
||||
|
@ -438,7 +445,25 @@ func (cl *Listener) GetDialerFunc(ctx context.Context, alpn string) func(string,
|
|||
tlsConfig.NextProtos = []string{alpn}
|
||||
cl.logger.Debug("creating rpc dialer", "address", addr, "alpn", alpn, "host", tlsConfig.ServerName)
|
||||
|
||||
return cl.networkLayer.Dial(addr, timeout, tlsConfig)
|
||||
conn, err := cl.networkLayer.Dial(addr, timeout, tlsConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cl.logTLSSessionStart(conn.RemoteAddr().String(), conn.ConnectionState())
|
||||
return conn, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (cl *Listener) logTLSSessionStart(peerAddress string, state tls.ConnectionState) {
|
||||
if cl.tlsConnectionLoggingLevel != log.NoLevel {
|
||||
cipherName, _ := tlsutil.GetCipherName(state.CipherSuite)
|
||||
cl.logger.Log(cl.tlsConnectionLoggingLevel, "TLS connection established", "peer", peerAddress, "negotiated_protocol", state.NegotiatedProtocol, "cipher_suite", cipherName)
|
||||
for _, chain := range state.VerifiedChains {
|
||||
for _, cert := range chain {
|
||||
cl.logger.Log(cl.tlsConnectionLoggingLevel, "Peer certificate", "is_ca", cert.IsCA, "serial_number", cert.SerialNumber.String(), "subject", cert.Subject.String(),
|
||||
"signature_algorithm", cert.SignatureAlgorithm.String(), "public_key_algorithm", cert.PublicKeyAlgorithm.String(), "public_key_size", certutil.GetPublicKeySize(cert.PublicKey))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue