tlsutil: un-embed the RWMutex

Embedded structs make code harder to navidate because an IDE can not show all uses of
the methods of that field separate from other uses.

Generally embedding of structs should only be used to satisfy an interface, and in this
case the Configurator type does not need to implement the RWMutex interface.
This commit is contained in:
Daniel Nephin 2021-06-17 18:48:44 -04:00
parent b0a2252fa0
commit 3717888b30
1 changed files with 50 additions and 49 deletions

View File

@ -176,7 +176,8 @@ type manual struct {
// Configurator holds a Config and is responsible for generating all the // Configurator holds a Config and is responsible for generating all the
// *tls.Config necessary for Consul. Except the one in the api package. // *tls.Config necessary for Consul. Except the one in the api package.
type Configurator struct { type Configurator struct {
sync.RWMutex // lock synchronizes access to all fields on this struct
lock sync.RWMutex
base *Config base *Config
autoTLS *autoTLS autoTLS *autoTLS
manual *manual manual *manual
@ -211,15 +212,15 @@ func NewConfigurator(config Config, logger hclog.Logger) (*Configurator, error)
// CAPems returns the currently loaded CAs in PEM format. // CAPems returns the currently loaded CAs in PEM format.
func (c *Configurator) CAPems() []string { func (c *Configurator) CAPems() []string {
c.RLock() c.lock.RLock()
defer c.RUnlock() defer c.lock.RUnlock()
return append(c.manual.caPems, c.autoTLS.caPems()...) return append(c.manual.caPems, c.autoTLS.caPems()...)
} }
// ManualCAPems returns the currently loaded CAs in PEM format. // ManualCAPems returns the currently loaded CAs in PEM format.
func (c *Configurator) ManualCAPems() []string { func (c *Configurator) ManualCAPems() []string {
c.RLock() c.lock.RLock()
defer c.RUnlock() defer c.lock.RUnlock()
return c.manual.caPems return c.manual.caPems
} }
@ -227,10 +228,10 @@ func (c *Configurator) ManualCAPems() []string {
// *tls.Config. // *tls.Config.
// This function acquires a write lock because it writes the new config. // This function acquires a write lock because it writes the new config.
func (c *Configurator) Update(config Config) error { func (c *Configurator) Update(config Config) error {
c.Lock() c.lock.Lock()
// order of defers matters because log acquires a RLock() // order of defers matters because log acquires a RLock()
defer c.log("Update") defer c.log("Update")
defer c.Unlock() defer c.lock.Unlock()
cert, err := loadKeyPair(config.CertFile, config.KeyFile) cert, err := loadKeyPair(config.CertFile, config.KeyFile)
if err != nil { if err != nil {
@ -260,18 +261,18 @@ func (c *Configurator) Update(config Config) error {
// certificates. // certificates.
// Or it is being called on the client side when CA changes are detected. // Or it is being called on the client side when CA changes are detected.
func (c *Configurator) UpdateAutoTLSCA(connectCAPems []string) error { func (c *Configurator) UpdateAutoTLSCA(connectCAPems []string) error {
c.Lock() c.lock.Lock()
// order of defers matters because log acquires a RLock() // order of defers matters because log acquires a RLock()
defer c.log("UpdateAutoEncryptCA") defer c.log("UpdateAutoEncryptCA")
defer c.Unlock() defer c.lock.Unlock()
pool, err := pool(append(c.manual.caPems, append(c.autoTLS.manualCAPems, connectCAPems...)...)) pool, err := pool(append(c.manual.caPems, append(c.autoTLS.manualCAPems, connectCAPems...)...))
if err != nil { if err != nil {
c.RUnlock() c.lock.RUnlock()
return err return err
} }
if err = c.check(*c.base, pool, c.manual.cert); err != nil { if err = c.check(*c.base, pool, c.manual.cert); err != nil {
c.RUnlock() c.lock.RUnlock()
return err return err
} }
c.autoTLS.connectCAPems = connectCAPems c.autoTLS.connectCAPems = connectCAPems
@ -289,8 +290,8 @@ func (c *Configurator) UpdateAutoTLSCert(pub, priv string) error {
return fmt.Errorf("Failed to load cert/key pair: %v", err) return fmt.Errorf("Failed to load cert/key pair: %v", err)
} }
c.Lock() c.lock.Lock()
defer c.Unlock() defer c.lock.Unlock()
c.autoTLS.cert = &cert c.autoTLS.cert = &cert
c.version++ c.version++
@ -307,8 +308,8 @@ func (c *Configurator) UpdateAutoTLS(manualCAPems, connectCAPems []string, pub,
return fmt.Errorf("Failed to load cert/key pair: %v", err) return fmt.Errorf("Failed to load cert/key pair: %v", err)
} }
c.Lock() c.lock.Lock()
defer c.Unlock() defer c.lock.Unlock()
pool, err := pool(append(c.manual.caPems, append(manualCAPems, connectCAPems...)...)) pool, err := pool(append(c.manual.caPems, append(manualCAPems, connectCAPems...)...))
if err != nil { if err != nil {
@ -324,15 +325,15 @@ func (c *Configurator) UpdateAutoTLS(manualCAPems, connectCAPems []string, pub,
} }
func (c *Configurator) UpdateAreaPeerDatacenterUseTLS(peerDatacenter string, useTLS bool) { func (c *Configurator) UpdateAreaPeerDatacenterUseTLS(peerDatacenter string, useTLS bool) {
c.Lock() c.lock.Lock()
defer c.Unlock() defer c.lock.Unlock()
c.version++ c.version++
c.peerDatacenterUseTLS[peerDatacenter] = useTLS c.peerDatacenterUseTLS[peerDatacenter] = useTLS
} }
func (c *Configurator) getAreaForPeerDatacenterUseTLS(peerDatacenter string) bool { func (c *Configurator) getAreaForPeerDatacenterUseTLS(peerDatacenter string) bool {
c.RLock() c.lock.RLock()
defer c.RUnlock() defer c.lock.RUnlock()
if v, ok := c.peerDatacenterUseTLS[peerDatacenter]; ok { if v, ok := c.peerDatacenterUseTLS[peerDatacenter]; ok {
return v return v
} }
@ -340,8 +341,8 @@ func (c *Configurator) getAreaForPeerDatacenterUseTLS(peerDatacenter string) boo
} }
func (c *Configurator) Base() Config { func (c *Configurator) Base() Config {
c.RLock() c.lock.RLock()
defer c.RUnlock() defer c.lock.RUnlock()
return *c.base return *c.base
} }
@ -472,8 +473,8 @@ func (c *Configurator) commonTLSConfig(verifyIncoming bool) *tls.Config {
// this needs to be outside of RLock because it acquires an RLock itself // this needs to be outside of RLock because it acquires an RLock itself
verifyServerHostname := c.VerifyServerHostname() verifyServerHostname := c.VerifyServerHostname()
c.RLock() c.lock.RLock()
defer c.RUnlock() defer c.lock.RUnlock()
tlsConfig := &tls.Config{ tlsConfig := &tls.Config{
InsecureSkipVerify: !verifyServerHostname, InsecureSkipVerify: !verifyServerHostname,
} }
@ -529,8 +530,8 @@ func (c *Configurator) commonTLSConfig(verifyIncoming bool) *tls.Config {
// This function acquires a read lock because it reads from the config. // This function acquires a read lock because it reads from the config.
func (c *Configurator) Cert() *tls.Certificate { func (c *Configurator) Cert() *tls.Certificate {
c.RLock() c.lock.RLock()
defer c.RUnlock() defer c.lock.RUnlock()
cert := c.manual.cert cert := c.manual.cert
if cert == nil { if cert == nil {
cert = c.autoTLS.cert cert = c.autoTLS.cert
@ -540,15 +541,15 @@ func (c *Configurator) Cert() *tls.Certificate {
// This function acquires a read lock because it reads from the config. // This function acquires a read lock because it reads from the config.
func (c *Configurator) VerifyIncomingRPC() bool { func (c *Configurator) VerifyIncomingRPC() bool {
c.RLock() c.lock.RLock()
defer c.RUnlock() defer c.lock.RUnlock()
return c.base.verifyIncomingRPC() return c.base.verifyIncomingRPC()
} }
// This function acquires a read lock because it reads from the config. // This function acquires a read lock because it reads from the config.
func (c *Configurator) outgoingRPCTLSDisabled() bool { func (c *Configurator) outgoingRPCTLSDisabled() bool {
c.RLock() c.lock.RLock()
defer c.RUnlock() defer c.lock.RUnlock()
// if AutoEncrypt enabled, always use TLS // if AutoEncrypt enabled, always use TLS
if c.base.AutoTLS { if c.base.AutoTLS {
@ -569,15 +570,15 @@ func (c *Configurator) MutualTLSCapable() bool {
// This function acquires a read lock because it reads from the config. // This function acquires a read lock because it reads from the config.
func (c *Configurator) mutualTLSCapable() bool { func (c *Configurator) mutualTLSCapable() bool {
c.RLock() c.lock.RLock()
defer c.RUnlock() defer c.lock.RUnlock()
return c.caPool != nil && (c.autoTLS.cert != nil || c.manual.cert != nil) return c.caPool != nil && (c.autoTLS.cert != nil || c.manual.cert != nil)
} }
// This function acquires a read lock because it reads from the config. // This function acquires a read lock because it reads from the config.
func (c *Configurator) verifyOutgoing() bool { func (c *Configurator) verifyOutgoing() bool {
c.RLock() c.lock.RLock()
defer c.RUnlock() defer c.lock.RUnlock()
// If AutoEncryptTLS is enabled and there is a CA, then verify // If AutoEncryptTLS is enabled and there is a CA, then verify
// outgoing. // outgoing.
@ -601,36 +602,36 @@ func (c *Configurator) ServerSNI(dc, nodeName string) string {
// This function acquires a read lock because it reads from the config. // This function acquires a read lock because it reads from the config.
func (c *Configurator) domain() string { func (c *Configurator) domain() string {
c.RLock() c.lock.RLock()
defer c.RUnlock() defer c.lock.RUnlock()
return c.base.Domain return c.base.Domain
} }
// This function acquires a read lock because it reads from the config. // This function acquires a read lock because it reads from the config.
func (c *Configurator) verifyIncomingRPC() bool { func (c *Configurator) verifyIncomingRPC() bool {
c.RLock() c.lock.RLock()
defer c.RUnlock() defer c.lock.RUnlock()
return c.base.verifyIncomingRPC() return c.base.verifyIncomingRPC()
} }
// This function acquires a read lock because it reads from the config. // This function acquires a read lock because it reads from the config.
func (c *Configurator) verifyIncomingHTTPS() bool { func (c *Configurator) verifyIncomingHTTPS() bool {
c.RLock() c.lock.RLock()
defer c.RUnlock() defer c.lock.RUnlock()
return c.base.verifyIncomingHTTPS() return c.base.verifyIncomingHTTPS()
} }
// This function acquires a read lock because it reads from the config. // This function acquires a read lock because it reads from the config.
func (c *Configurator) enableAgentTLSForChecks() bool { func (c *Configurator) enableAgentTLSForChecks() bool {
c.RLock() c.lock.RLock()
defer c.RUnlock() defer c.lock.RUnlock()
return c.base.EnableAgentTLSForChecks return c.base.EnableAgentTLSForChecks
} }
// This function acquires a read lock because it reads from the config. // This function acquires a read lock because it reads from the config.
func (c *Configurator) serverNameOrNodeName() string { func (c *Configurator) serverNameOrNodeName() string {
c.RLock() c.lock.RLock()
defer c.RUnlock() defer c.lock.RUnlock()
if c.base.ServerName != "" { if c.base.ServerName != "" {
return c.base.ServerName return c.base.ServerName
} }
@ -639,8 +640,8 @@ func (c *Configurator) serverNameOrNodeName() string {
// This function acquires a read lock because it reads from the config. // This function acquires a read lock because it reads from the config.
func (c *Configurator) VerifyServerHostname() bool { func (c *Configurator) VerifyServerHostname() bool {
c.RLock() c.lock.RLock()
defer c.RUnlock() defer c.lock.RUnlock()
return c.base.VerifyServerHostname || c.autoTLS.verifyServerHostname return c.base.VerifyServerHostname || c.autoTLS.verifyServerHostname
} }
@ -799,8 +800,8 @@ func (c *Configurator) OutgoingALPNRPCWrapper() ALPNWrapper {
// AutoEncryptCertNotAfter returns NotAfter from the auto_encrypt cert. In case // AutoEncryptCertNotAfter returns NotAfter from the auto_encrypt cert. In case
// there is no cert, it will return a time in the past. // there is no cert, it will return a time in the past.
func (c *Configurator) AutoEncryptCertNotAfter() time.Time { func (c *Configurator) AutoEncryptCertNotAfter() time.Time {
c.RLock() c.lock.RLock()
defer c.RUnlock() defer c.lock.RUnlock()
tlsCert := c.autoTLS.cert tlsCert := c.autoTLS.cert
if tlsCert == nil || tlsCert.Certificate == nil { if tlsCert == nil || tlsCert.Certificate == nil {
return time.Now().AddDate(0, 0, -1) return time.Now().AddDate(0, 0, -1)
@ -820,8 +821,8 @@ func (c *Configurator) AutoEncryptCertExpired() bool {
// This function acquires a read lock because it reads from the config. // This function acquires a read lock because it reads from the config.
func (c *Configurator) log(name string) { func (c *Configurator) log(name string) {
if c.logger != nil { if c.logger != nil {
c.RLock() c.lock.RLock()
defer c.RUnlock() defer c.lock.RUnlock()
c.logger.Trace(name, "version", c.version) c.logger.Trace(name, "version", c.version)
} }
} }