321 lines
9.1 KiB
Go
321 lines
9.1 KiB
Go
// Copyright (c) HashiCorp, Inc.
|
|
// SPDX-License-Identifier: MPL-2.0
|
|
|
|
package consul
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"time"
|
|
|
|
"golang.org/x/time/rate"
|
|
|
|
"github.com/hashicorp/go-version"
|
|
|
|
"github.com/hashicorp/consul/agent/consul/gateways"
|
|
"github.com/hashicorp/consul/agent/structs"
|
|
"github.com/hashicorp/consul/logging"
|
|
)
|
|
|
|
const (
|
|
// loopRateLimit is the maximum rate per second at which we can rerun CA and intention
|
|
// replication watches.
|
|
loopRateLimit rate.Limit = 0.2
|
|
|
|
// retryBucketSize is the maximum number of stored rate limit attempts for looped
|
|
// blocking query operations.
|
|
retryBucketSize = 5
|
|
)
|
|
|
|
var (
|
|
// maxRetryBackoff is the maximum number of seconds to wait between failed blocking
|
|
// queries when backing off.
|
|
maxRetryBackoff = 256
|
|
|
|
// minVirtualIPVersion is the minimum version for all Consul servers for virtual IP
|
|
// assignment to be enabled.
|
|
minVirtualIPVersion = version.Must(version.NewVersion("1.11.0"))
|
|
|
|
// minVirtualIPVersion is the minimum version for all Consul servers for virtual IP
|
|
// assignment to be enabled for terminating gateways.
|
|
minVirtualIPTerminatingGatewayVersion = version.Must(version.NewVersion("1.11.2"))
|
|
|
|
// virtualIPVersionCheckInterval is the frequency we check whether all servers meet
|
|
// the minimum version to enable virtual IP assignment for services.
|
|
virtualIPVersionCheckInterval = time.Minute
|
|
)
|
|
|
|
// startConnectLeader starts multi-dc connect leader routines.
|
|
func (s *Server) startConnectLeader(ctx context.Context) error {
|
|
if !s.config.ConnectEnabled {
|
|
return nil
|
|
}
|
|
|
|
s.caManager.Start(ctx)
|
|
s.leaderRoutineManager.Start(ctx, caRootPruningRoutineName, s.runCARootPruning)
|
|
s.leaderRoutineManager.Start(ctx, caRootMetricRoutineName, rootCAExpiryMonitor(s).Monitor)
|
|
s.leaderRoutineManager.Start(ctx, caSigningMetricRoutineName, signingCAExpiryMonitor(s).Monitor)
|
|
s.leaderRoutineManager.Start(ctx, virtualIPCheckRoutineName, s.runVirtualIPVersionCheck)
|
|
s.leaderRoutineManager.Start(ctx, configEntryControllersRoutineName, s.runConfigEntryControllers)
|
|
|
|
return s.startIntentionConfigEntryMigration(ctx)
|
|
}
|
|
|
|
// stopConnectLeader stops connect specific leader functions.
|
|
func (s *Server) stopConnectLeader() {
|
|
s.caManager.Stop()
|
|
s.leaderRoutineManager.Stop(intentionMigrationRoutineName)
|
|
s.leaderRoutineManager.Stop(caRootPruningRoutineName)
|
|
s.leaderRoutineManager.Stop(caRootMetricRoutineName)
|
|
s.leaderRoutineManager.Stop(caSigningMetricRoutineName)
|
|
s.leaderRoutineManager.Stop(virtualIPCheckRoutineName)
|
|
s.leaderRoutineManager.Stop(configEntryControllersRoutineName)
|
|
}
|
|
|
|
func (s *Server) runConfigEntryControllers(ctx context.Context) error {
|
|
updater := &gateways.Updater{
|
|
UpdateWithStatus: func(entry structs.ControlledConfigEntry) error {
|
|
_, err := s.leaderRaftApply("ConfigEntry.Apply", structs.ConfigEntryRequestType, &structs.ConfigEntryRequest{
|
|
Op: structs.ConfigEntryUpsertWithStatusCAS,
|
|
Entry: entry,
|
|
})
|
|
return err
|
|
},
|
|
Update: func(entry structs.ConfigEntry) error {
|
|
_, err := s.leaderRaftApply("ConfigEntry.Apply", structs.ConfigEntryRequestType, &structs.ConfigEntryRequest{
|
|
Op: structs.ConfigEntryUpsertCAS,
|
|
Entry: entry,
|
|
})
|
|
return err
|
|
},
|
|
Delete: func(entry structs.ConfigEntry) error {
|
|
_, err := s.leaderRaftApply("ConfigEntry.Delete", structs.ConfigEntryRequestType, &structs.ConfigEntryRequest{
|
|
Op: structs.ConfigEntryDeleteCAS,
|
|
Entry: entry,
|
|
})
|
|
return err
|
|
},
|
|
}
|
|
logger := s.logger.Named(logging.APIGatewayController)
|
|
return gateways.NewAPIGatewayController(s.fsm, s.publisher, updater, logger).Run(ctx)
|
|
}
|
|
|
|
func (s *Server) runCARootPruning(ctx context.Context) error {
|
|
ticker := time.NewTicker(caRootPruneInterval)
|
|
defer ticker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return nil
|
|
case <-ticker.C:
|
|
if err := s.pruneCARoots(); err != nil {
|
|
s.loggers.Named(logging.Connect).Error("error pruning CA roots", "error", err)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// pruneCARoots looks for any CARoots that have been rotated out and expired.
|
|
func (s *Server) pruneCARoots() error {
|
|
if !s.config.ConnectEnabled {
|
|
return nil
|
|
}
|
|
|
|
state := s.fsm.State()
|
|
idx, roots, err := state.CARoots(nil)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
_, caConf, err := state.CAConfig(nil)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
common, err := caConf.GetCommonConfig()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
var newRoots structs.CARoots
|
|
for _, r := range roots {
|
|
if !r.Active && !r.RotatedOutAt.IsZero() && time.Now().Sub(r.RotatedOutAt) > common.LeafCertTTL*2 {
|
|
s.loggers.Named(logging.Connect).Info("pruning old unused root CA", "id", r.ID)
|
|
continue
|
|
}
|
|
newRoot := *r
|
|
newRoots = append(newRoots, &newRoot)
|
|
}
|
|
|
|
// Return early if there's nothing to remove.
|
|
if len(newRoots) == len(roots) {
|
|
return nil
|
|
}
|
|
|
|
// Commit the new root state.
|
|
var args structs.CARequest
|
|
args.Op = structs.CAOpSetRoots
|
|
args.Index = idx
|
|
args.Roots = newRoots
|
|
_, err = s.raftApply(structs.ConnectCARequestType, args)
|
|
return err
|
|
}
|
|
|
|
func (s *Server) runVirtualIPVersionCheck(ctx context.Context) error {
|
|
// Return early if the flag is already set.
|
|
done, err := s.setVirtualIPFlags()
|
|
if err != nil {
|
|
s.loggers.Named(logging.Connect).Warn("error enabling virtual IPs", "error", err)
|
|
}
|
|
if done {
|
|
s.leaderRoutineManager.Stop(virtualIPCheckRoutineName)
|
|
return nil
|
|
}
|
|
|
|
ticker := time.NewTicker(virtualIPVersionCheckInterval)
|
|
defer ticker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return nil
|
|
case <-ticker.C:
|
|
done, err := s.setVirtualIPFlags()
|
|
if err != nil {
|
|
s.loggers.Named(logging.Connect).Warn("error enabling virtual IPs", "error", err)
|
|
continue
|
|
}
|
|
if done {
|
|
return nil
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *Server) setVirtualIPFlags() (bool, error) {
|
|
virtualIPFlag, err := s.setVirtualIPVersionFlag()
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
terminatingGatewayVirtualIPFlag, err := s.setVirtualIPTerminatingGatewayVersionFlag()
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
|
|
return virtualIPFlag && terminatingGatewayVirtualIPFlag, nil
|
|
}
|
|
|
|
func (s *Server) setVirtualIPVersionFlag() (bool, error) {
|
|
val, err := s.getSystemMetadata(structs.SystemMetadataVirtualIPsEnabled)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
if val != "" {
|
|
return true, nil
|
|
}
|
|
|
|
if ok, _ := ServersInDCMeetMinimumVersion(s, s.config.Datacenter, minVirtualIPVersion); !ok {
|
|
return false, fmt.Errorf("can't allocate Virtual IPs until all servers >= %s",
|
|
minVirtualIPVersion.String())
|
|
}
|
|
|
|
if err := s.setSystemMetadataKey(structs.SystemMetadataVirtualIPsEnabled, "true"); err != nil {
|
|
return false, nil
|
|
}
|
|
|
|
return true, nil
|
|
}
|
|
|
|
func (s *Server) setVirtualIPTerminatingGatewayVersionFlag() (bool, error) {
|
|
val, err := s.getSystemMetadata(structs.SystemMetadataTermGatewayVirtualIPsEnabled)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
if val != "" {
|
|
return true, nil
|
|
}
|
|
|
|
if ok, _ := ServersInDCMeetMinimumVersion(s, s.config.Datacenter, minVirtualIPTerminatingGatewayVersion); !ok {
|
|
return false, fmt.Errorf("can't allocate Virtual IPs for terminating gateways until all servers >= %s",
|
|
minVirtualIPTerminatingGatewayVersion.String())
|
|
}
|
|
|
|
if err := s.setSystemMetadataKey(structs.SystemMetadataTermGatewayVirtualIPsEnabled, "true"); err != nil {
|
|
return false, nil
|
|
}
|
|
|
|
return true, nil
|
|
}
|
|
|
|
// retryLoopBackoff loops a given function indefinitely, backing off exponentially
|
|
// upon errors up to a maximum of maxRetryBackoff seconds.
|
|
func retryLoopBackoff(ctx context.Context, loopFn func() error, errFn func(error)) {
|
|
retryLoopBackoffHandleSuccess(ctx, loopFn, errFn, false)
|
|
}
|
|
|
|
func retryLoopBackoffAbortOnSuccess(ctx context.Context, loopFn func() error, errFn func(error)) {
|
|
retryLoopBackoffHandleSuccess(ctx, loopFn, errFn, true)
|
|
}
|
|
|
|
func retryLoopBackoffHandleSuccess(ctx context.Context, loopFn func() error, errFn func(error), abortOnSuccess bool) {
|
|
var failedAttempts uint
|
|
limiter := rate.NewLimiter(loopRateLimit, retryBucketSize)
|
|
for {
|
|
// Rate limit how often we run the loop
|
|
limiter.Wait(ctx)
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
default:
|
|
}
|
|
if (1 << failedAttempts) < maxRetryBackoff {
|
|
failedAttempts++
|
|
}
|
|
retryTime := (1 << failedAttempts) * time.Second
|
|
|
|
if err := loopFn(); err != nil {
|
|
errFn(err)
|
|
|
|
timer := time.NewTimer(retryTime)
|
|
select {
|
|
case <-ctx.Done():
|
|
timer.Stop()
|
|
return
|
|
case <-timer.C:
|
|
continue
|
|
}
|
|
} else if abortOnSuccess {
|
|
return
|
|
}
|
|
|
|
// Reset the failed attempts after a successful run.
|
|
failedAttempts = 0
|
|
}
|
|
}
|
|
|
|
// nextIndexVal computes the next index value to query for, resetting to zero
|
|
// if the index went backward.
|
|
func nextIndexVal(prevIdx, idx uint64) uint64 {
|
|
if prevIdx > idx {
|
|
return 0
|
|
}
|
|
return idx
|
|
}
|
|
|
|
// halfTime returns a duration that is half the time between notBefore and
|
|
// notAfter.
|
|
func halfTime(notBefore, notAfter time.Time) time.Duration {
|
|
interval := notAfter.Sub(notBefore)
|
|
return interval / 2
|
|
}
|
|
|
|
// lessThanHalfTimePassed decides if half the time between notBefore and
|
|
// notAfter has passed relative to now.
|
|
// lessThanHalfTimePassed is being called while holding caProviderReconfigurationLock
|
|
// which means it must never take that lock itself or call anything that does.
|
|
func lessThanHalfTimePassed(now, notBefore, notAfter time.Time) bool {
|
|
t := notBefore.Add(halfTime(notBefore, notAfter))
|
|
return t.Sub(now) > 0
|
|
}
|