Add a TRACE log with TLS connection details on replication connections (#12754)

* remove cruft
use helper
Add a helper for getting public key sizes
wip

* error names

* Fix ecdsa

* only if trace is on

* Log listener side as well

* rename

* Add remote address

* Make the log level configurable via the env var, and a member of the Listener and thus modifiable by tests

* Fix certutil_test
This commit is contained in:
Scott Miller 2021-10-07 14:17:31 -05:00 committed by GitHub
parent 7fd527dc9a
commit 1097f356af
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 79 additions and 9 deletions

View File

@ -455,6 +455,31 @@ vitin0L6nprauWkKO38XgM4T75qKZpqtiOcT
}
}
func TestGetPublicKeySize(t *testing.T) {
rsa, err := rsa.GenerateKey(rand.Reader, 3072)
if err != nil {
t.Fatal(err)
}
if GetPublicKeySize(&rsa.PublicKey) != 3072 {
t.Fatal("unexpected rsa key size")
}
ecdsa, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
if err != nil {
t.Fatal(err)
}
if GetPublicKeySize(&ecdsa.PublicKey) != 384 {
t.Fatal("unexpected ecdsa key size")
}
ed25519, _, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
t.Fatal(err)
}
if GetPublicKeySize(ed25519) != 256 {
t.Fatal("unexpected ed25519 key size")
}
//Skipping DSA as too slow
}
func refreshRSA8CertBundle() *CertBundle {
initTest.Do(setCerts)
return &CertBundle{

View File

@ -3,6 +3,7 @@ package certutil
import (
"bytes"
"crypto"
"crypto/dsa"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/elliptic"
@ -1002,3 +1003,22 @@ func parseCertsPEM(pemCerts []byte) ([]*x509.Certificate, error) {
}
return certs, nil
}
// GetPublicKeySize returns the key size in bits for a given arbitrary crypto.PublicKey
// Returns -1 for an unsupported key type.
func GetPublicKeySize(key crypto.PublicKey) int {
if key, ok := key.(*rsa.PublicKey); ok {
return key.Size() * 8
}
if key, ok := key.(*ecdsa.PublicKey); ok {
return key.Params().BitSize
}
if key, ok := key.(ed25519.PublicKey); ok {
return len(key) * 8
}
if key, ok := key.(dsa.PublicKey); ok {
return key.Y.BitLen()
}
return -1
}

View File

@ -6,8 +6,11 @@ import (
"crypto/x509"
"errors"
"fmt"
"github.com/hashicorp/vault/sdk/helper/certutil"
"github.com/hashicorp/vault/sdk/helper/tlsutil"
"net"
"net/url"
"os"
"sync"
"sync/atomic"
"time"
@ -60,11 +63,12 @@ type Listener struct {
shutdownWg *sync.WaitGroup
server *http2.Server
networkLayer NetworkLayer
cipherSuites []uint16
advertise net.Addr
logger log.Logger
l sync.RWMutex
networkLayer NetworkLayer
cipherSuites []uint16
advertise net.Addr
logger log.Logger
l sync.RWMutex
tlsConnectionLoggingLevel log.Level
}
func NewListener(networkLayer NetworkLayer, cipherSuites []uint16, logger log.Logger, idleTimeout time.Duration) *Listener {
@ -85,9 +89,10 @@ func NewListener(networkLayer NetworkLayer, cipherSuites []uint16, logger log.Lo
shutdownWg: &sync.WaitGroup{},
server: h2Server,
networkLayer: networkLayer,
cipherSuites: cipherSuites,
logger: logger,
networkLayer: networkLayer,
cipherSuites: cipherSuites,
logger: logger,
tlsConnectionLoggingLevel: log.LevelFromString(os.Getenv("VAULT_CLUSTER_TLS_SESSION_LOG_LEVEL")),
}
}
@ -359,6 +364,8 @@ func (cl *Listener) Run(ctx context.Context) error {
continue
}
cl.logTLSSessionStart(tlsConn.RemoteAddr().String(), tlsConn.ConnectionState())
// Now, set it back to unlimited
err = tlsConn.SetDeadline(time.Time{})
if err != nil {
@ -438,7 +445,25 @@ func (cl *Listener) GetDialerFunc(ctx context.Context, alpn string) func(string,
tlsConfig.NextProtos = []string{alpn}
cl.logger.Debug("creating rpc dialer", "address", addr, "alpn", alpn, "host", tlsConfig.ServerName)
return cl.networkLayer.Dial(addr, timeout, tlsConfig)
conn, err := cl.networkLayer.Dial(addr, timeout, tlsConfig)
if err != nil {
return nil, err
}
cl.logTLSSessionStart(conn.RemoteAddr().String(), conn.ConnectionState())
return conn, nil
}
}
func (cl *Listener) logTLSSessionStart(peerAddress string, state tls.ConnectionState) {
if cl.tlsConnectionLoggingLevel != log.NoLevel {
cipherName, _ := tlsutil.GetCipherName(state.CipherSuite)
cl.logger.Log(cl.tlsConnectionLoggingLevel, "TLS connection established", "peer", peerAddress, "negotiated_protocol", state.NegotiatedProtocol, "cipher_suite", cipherName)
for _, chain := range state.VerifiedChains {
for _, cert := range chain {
cl.logger.Log(cl.tlsConnectionLoggingLevel, "Peer certificate", "is_ca", cert.IsCA, "serial_number", cert.SerialNumber.String(), "subject", cert.Subject.String(),
"signature_algorithm", cert.SignatureAlgorithm.String(), "public_key_algorithm", cert.PublicKeyAlgorithm.String(), "public_key_size", certutil.GetPublicKeySize(cert.PublicKey))
}
}
}
}