Share mgw addrs in peering stream if needed

This commit adds handling so that the replication stream considers
whether the user intends to peer through mesh gateways.

The subscription will return server or mesh gateway addresses depending
on the mesh configuration setting. These watches can be updated at
runtime by modifying the mesh config entry.
This commit is contained in:
freddygv 2022-09-21 09:55:19 -06:00
parent 17463472b7
commit 2c5caec97c
8 changed files with 233 additions and 80 deletions

View File

@ -8,6 +8,7 @@ import (
"time"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/sdk/freeport"
"github.com/hashicorp/consul/types"
"github.com/stretchr/testify/require"
gogrpc "google.golang.org/grpc"
@ -71,7 +72,7 @@ func TestPeeringBackend_GetServerAddresses(t *testing.T) {
}
_, cfg := testServerConfig(t)
cfg.GRPCTLSPort = 8505
cfg.GRPCTLSPort = freeport.GetOne(t)
srv, err := newServer(t, cfg)
require.NoError(t, err)

View File

@ -150,7 +150,6 @@ func (m *CertManager) watchServerToken(ctx context.Context) {
// Cancel existing the leaf cert watch and spin up new one any time the server token changes.
// The watch needs the current token as set by the leader since certificate signing requests go to the leader.
fmt.Println("canceling and resetting")
cancel()
notifyCtx, cancel = context.WithCancel(ctx)

View File

@ -123,5 +123,6 @@ type StateStore interface {
CAConfig(ws memdb.WatchSet) (uint64, *structs.CAConfiguration, error)
TrustBundleListByService(ws memdb.WatchSet, service, dc string, entMeta acl.EnterpriseMeta) (uint64, []*pbpeering.PeeringTrustBundle, error)
ServiceList(ws memdb.WatchSet, entMeta *acl.EnterpriseMeta, peerName string) (uint64, structs.ServiceList, error)
ConfigEntry(ws memdb.WatchSet, kind, name string, entMeta *acl.EnterpriseMeta) (uint64, structs.ConfigEntry, error)
AbandonCh() <-chan struct{}
}

View File

@ -640,9 +640,6 @@ func (s *Server) realHandleStream(streamReq HandleStreamRequest) error {
continue
}
case strings.HasPrefix(update.CorrelationID, subMeshGateway):
// TODO(Peering): figure out how to sync this separately
case update.CorrelationID == subCARoot:
resp, err = makeCARootsResponse(update)
if err != nil {

View File

@ -98,25 +98,25 @@ func (m *subscriptionManager) syncViaBlockingQuery(
ws.Add(store.AbandonCh())
ws.Add(ctx.Done())
if result, err := queryFn(ctx, store, ws); err != nil {
if result, err := queryFn(ctx, store, ws); err != nil && ctx.Err() == nil {
logger.Error("failed to sync from query", "error", err)
} else {
// Block for any changes to the state store.
updateCh <- cache.UpdateEvent{
CorrelationID: correlationID,
Result: result,
select {
case <-ctx.Done():
return
case updateCh <- cache.UpdateEvent{CorrelationID: correlationID, Result: result}:
}
// Block for any changes to the state store.
ws.WatchCtx(ctx)
}
if err := waiter.Wait(ctx); err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
logger.Error("failed to wait before re-trying sync", "error", err)
}
select {
case <-ctx.Done():
err := waiter.Wait(ctx)
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return
default:
} else if err != nil {
logger.Error("failed to wait before re-trying sync", "error", err)
}
}
}

View File

@ -6,9 +6,13 @@ import (
"fmt"
"strconv"
"strings"
"time"
"github.com/golang/protobuf/proto"
"github.com/hashicorp/consul/ipaddr"
"github.com/hashicorp/consul/lib/retry"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-memdb"
"github.com/hashicorp/consul/acl"
"github.com/hashicorp/consul/agent/cache"
@ -247,16 +251,10 @@ func (m *subscriptionManager) handleEvent(ctx context.Context, state *subscripti
pending := &pendingPayload{}
// Directly replicate information about our mesh gateways to the consuming side.
// TODO(peering): should we scrub anything before replicating this?
if err := pending.Add(meshGatewayPayloadID, u.CorrelationID, csn); err != nil {
return err
}
if state.exportList != nil {
// Trigger public events for all synthetic discovery chain replies.
for chainName, info := range state.connectServices {
m.emitEventForDiscoveryChain(ctx, state, pending, chainName, info)
m.collectPendingEventForDiscoveryChain(ctx, state, pending, chainName, info)
}
}
@ -490,7 +488,7 @@ func (m *subscriptionManager) syncDiscoveryChains(
state.connectServices[chainName] = info
m.emitEventForDiscoveryChain(ctx, state, pending, chainName, info)
m.collectPendingEventForDiscoveryChain(ctx, state, pending, chainName, info)
}
// if it was dropped, try to emit an DELETE event
@ -517,7 +515,7 @@ func (m *subscriptionManager) syncDiscoveryChains(
}
}
func (m *subscriptionManager) emitEventForDiscoveryChain(
func (m *subscriptionManager) collectPendingEventForDiscoveryChain(
ctx context.Context,
state *subscriptionState,
pending *pendingPayload,
@ -738,32 +736,118 @@ func (m *subscriptionManager) notifyServerAddrUpdates(
ctx context.Context,
updateCh chan<- cache.UpdateEvent,
) {
// Wait until this is subscribed-to.
// Wait until server address updates are subscribed-to.
select {
case <-m.serverAddrsSubReady:
case <-ctx.Done():
return
}
var idx uint64
// TODO(peering): retry logic; fail past a threshold
for {
var err error
// Typically, this function will block inside `m.subscribeServerAddrs` and only return on error.
// Errors are logged and the watch is retried.
idx, err = m.subscribeServerAddrs(ctx, idx, updateCh)
if errors.Is(err, stream.ErrSubForceClosed) {
m.logger.Trace("subscription force-closed due to an ACL change or snapshot restore, will attempt resume")
} else if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
m.logger.Warn("failed to subscribe to server addresses, will attempt resume", "error", err.Error())
} else {
m.logger.Trace(err.Error())
}
configNotifyCh := m.notifyMeshConfigUpdates(ctx)
// Intentionally initialized to empty values.
// These are set after the first mesh config entry update arrives.
var queryCtx context.Context
cancel := func() {}
useGateways := false
for {
select {
case <-ctx.Done():
cancel()
return
case event := <-configNotifyCh:
entry, ok := event.Result.(*structs.MeshConfigEntry)
if event.Result != nil && !ok {
m.logger.Error(fmt.Sprintf("saw unexpected type %T for mesh config entry: falling back to pushing direct server addresses", event.Result))
}
if entry != nil && entry.Peering != nil && entry.Peering.PeerThroughMeshGateways {
useGateways = true
} else {
useGateways = false
}
// Cancel and re-set watches based on the updated config entry.
cancel()
queryCtx, cancel = context.WithCancel(ctx)
if useGateways {
go m.notifyServerMeshGatewayAddresses(queryCtx, updateCh)
} else {
go m.ensureServerAddrSubscription(queryCtx, updateCh)
}
}
}
}
func (m *subscriptionManager) notifyMeshConfigUpdates(ctx context.Context) <-chan cache.UpdateEvent {
const meshConfigWatch = "mesh-config-entry"
notifyCh := make(chan cache.UpdateEvent, 1)
go m.syncViaBlockingQuery(ctx, meshConfigWatch, func(ctx_ context.Context, store StateStore, ws memdb.WatchSet) (interface{}, error) {
_, rawEntry, err := store.ConfigEntry(ws, structs.MeshConfig, structs.MeshConfigMesh, acl.DefaultEnterpriseMeta())
if err != nil {
return nil, fmt.Errorf("failed to get mesh config entry: %w", err)
}
return rawEntry, nil
}, meshConfigWatch, notifyCh)
return notifyCh
}
func (m *subscriptionManager) notifyServerMeshGatewayAddresses(ctx context.Context, updateCh chan<- cache.UpdateEvent) {
m.syncViaBlockingQuery(ctx, "mesh-gateways", func(ctx context.Context, store StateStore, ws memdb.WatchSet) (interface{}, error) {
_, nodes, err := store.ServiceDump(ws, structs.ServiceKindMeshGateway, true, acl.DefaultEnterpriseMeta(), structs.DefaultPeerKeyword)
if err != nil {
return nil, fmt.Errorf("failed to watch mesh gateways services for servers: %w", err)
}
var gatewayAddrs []string
for _, csn := range nodes {
_, addr, port := csn.BestAddress(true)
gatewayAddrs = append(gatewayAddrs, ipaddr.FormatAddressPort(addr, port))
}
if len(gatewayAddrs) == 0 {
return nil, errors.New("configured to peer through mesh gateways but no mesh gateways are registered")
}
// We may return an empty list if there are no gateway addresses.
return &pbpeering.PeeringServerAddresses{
Addresses: gatewayAddrs,
}, nil
}, subServerAddrs, updateCh)
}
func (m *subscriptionManager) ensureServerAddrSubscription(ctx context.Context, updateCh chan<- cache.UpdateEvent) {
waiter := &retry.Waiter{
MinFailures: 1,
Factor: 500 * time.Millisecond,
MaxWait: 60 * time.Second,
Jitter: retry.NewJitter(100),
}
logger := m.logger.With("queryType", "server-addresses")
var idx uint64
for {
var err error
idx, err = m.subscribeServerAddrs(ctx, idx, updateCh)
if errors.Is(err, stream.ErrSubForceClosed) {
logger.Trace("subscription force-closed due to an ACL change or snapshot restore, will attempt resume")
} else if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
logger.Warn("failed to subscribe to server addresses, will attempt resume", "error", err.Error())
} else if err != nil {
logger.Trace(err.Error())
return
}
if err := waiter.Wait(ctx); err != nil {
return
default:
}
}
}
@ -826,17 +910,22 @@ func (m *subscriptionManager) subscribeServerAddrs(
grpcAddr := srv.Address + ":" + strconv.Itoa(srv.ExtGRPCPort)
serverAddrs = append(serverAddrs, grpcAddr)
}
if len(serverAddrs) == 0 {
m.logger.Warn("did not find any server addresses with external gRPC ports to publish")
continue
}
updateCh <- cache.UpdateEvent{
u := cache.UpdateEvent{
CorrelationID: subServerAddrs,
Result: &pbpeering.PeeringServerAddresses{
Addresses: serverAddrs,
},
}
select {
case <-ctx.Done():
return 0, ctx.Err()
case updateCh <- u:
}
}
}

View File

@ -7,6 +7,7 @@ import (
"testing"
"time"
"github.com/hashicorp/consul/types"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
@ -49,10 +50,7 @@ func TestSubscriptionManager_RegisterDeregister(t *testing.T) {
subCh := mgr.subscribe(ctx, id, "my-peering", partition)
var (
gatewayCorrID = subMeshGateway + partition
mysqlCorrID = subExportedService + structs.NewServiceName("mysql", nil).String()
mysqlCorrID = subExportedService + structs.NewServiceName("mysql", nil).String()
mysqlProxyCorrID = subExportedService + structs.NewServiceName("mysql-sidecar-proxy", nil).String()
)
@ -60,11 +58,7 @@ func TestSubscriptionManager_RegisterDeregister(t *testing.T) {
expectEvents(t, subCh,
func(t *testing.T, got cache.UpdateEvent) {
checkExportedServices(t, got, []string{})
},
func(t *testing.T, got cache.UpdateEvent) {
checkEvent(t, got, gatewayCorrID, 0)
},
)
})
// Initially add in L4 failover so that later we can test removing it. We
// cannot do the other way around because it would fail validation to
@ -300,17 +294,6 @@ func TestSubscriptionManager_RegisterDeregister(t *testing.T) {
},
}, res.Nodes[0])
},
func(t *testing.T, got cache.UpdateEvent) {
require.Equal(t, gatewayCorrID, got.CorrelationID)
res := got.Result.(*pbservice.IndexedCheckServiceNodes)
require.Equal(t, uint64(0), res.Index)
require.Len(t, res.Nodes, 1)
prototest.AssertDeepEqual(t, &pbservice.CheckServiceNode{
Node: pbNode("mgw", "10.1.1.1", partition),
Service: pbService("mesh-gateway", "gateway-1", "gateway", 8443, nil),
}, res.Nodes[0])
},
)
})
@ -434,13 +417,6 @@ func TestSubscriptionManager_RegisterDeregister(t *testing.T) {
res := got.Result.(*pbservice.IndexedCheckServiceNodes)
require.Equal(t, uint64(0), res.Index)
require.Len(t, res.Nodes, 0)
},
func(t *testing.T, got cache.UpdateEvent) {
require.Equal(t, gatewayCorrID, got.CorrelationID)
res := got.Result.(*pbservice.IndexedCheckServiceNodes)
require.Equal(t, uint64(0), res.Index)
require.Len(t, res.Nodes, 0)
},
)
@ -506,8 +482,6 @@ func TestSubscriptionManager_InitialSnapshot(t *testing.T) {
backend.ensureService(t, "zip", mongo.Service)
var (
gatewayCorrID = subMeshGateway + partition
mysqlCorrID = subExportedService + structs.NewServiceName("mysql", nil).String()
mongoCorrID = subExportedService + structs.NewServiceName("mongo", nil).String()
chainCorrID = subExportedService + structs.NewServiceName("chain", nil).String()
@ -521,9 +495,6 @@ func TestSubscriptionManager_InitialSnapshot(t *testing.T) {
expectEvents(t, subCh,
func(t *testing.T, got cache.UpdateEvent) {
checkExportedServices(t, got, []string{})
},
func(t *testing.T, got cache.UpdateEvent) {
checkEvent(t, got, gatewayCorrID, 0)
})
// At this point in time we'll have a mesh-gateway notification with no
@ -597,9 +568,6 @@ func TestSubscriptionManager_InitialSnapshot(t *testing.T) {
func(t *testing.T, got cache.UpdateEvent) {
checkEvent(t, got, mysqlProxyCorrID, 1, "mysql-sidecar-proxy", string(structs.ServiceKindConnectProxy))
},
func(t *testing.T, got cache.UpdateEvent) {
checkEvent(t, got, gatewayCorrID, 1, "gateway", string(structs.ServiceKindMeshGateway))
},
)
})
}
@ -741,6 +709,102 @@ func TestSubscriptionManager_ServerAddrs(t *testing.T) {
},
)
})
testutil.RunStep(t, "flipped to peering through mesh gateways", func(t *testing.T) {
require.NoError(t, backend.store.EnsureConfigEntry(1, &structs.MeshConfigEntry{
Peering: &structs.PeeringMeshConfig{
PeerThroughMeshGateways: true,
},
}))
select {
case <-time.After(100 * time.Millisecond):
case <-subCh:
t.Fatal("expected to time out: no mesh gateways are registered")
}
})
testutil.RunStep(t, "registered and received a mesh gateway", func(t *testing.T) {
reg := structs.RegisterRequest{
ID: types.NodeID("b5489ca9-f5e9-4dba-a779-61fec4e8e364"),
Node: "gw-node",
Address: "1.2.3.4",
TaggedAddresses: map[string]string{
structs.TaggedAddressWAN: "172.217.22.14",
},
Service: &structs.NodeService{
ID: "mesh-gateway",
Service: "mesh-gateway",
Kind: structs.ServiceKindMeshGateway,
Port: 443,
TaggedAddresses: map[string]structs.ServiceAddress{
structs.TaggedAddressWAN: {Address: "154.238.12.252", Port: 8443},
},
},
}
require.NoError(t, backend.store.EnsureRegistration(2, &reg))
expectEvents(t, subCh,
func(t *testing.T, got cache.UpdateEvent) {
require.Equal(t, subServerAddrs, got.CorrelationID)
addrs, ok := got.Result.(*pbpeering.PeeringServerAddresses)
require.True(t, ok)
require.Equal(t, []string{"154.238.12.252:8443"}, addrs.GetAddresses())
},
)
})
testutil.RunStep(t, "registered and received a second mesh gateway", func(t *testing.T) {
reg := structs.RegisterRequest{
ID: types.NodeID("e4cc0af3-5c09-4ddf-94a9-5840e427bc45"),
Node: "gw-node-2",
Address: "1.2.3.5",
TaggedAddresses: map[string]string{
structs.TaggedAddressWAN: "172.217.22.15",
},
Service: &structs.NodeService{
ID: "mesh-gateway",
Service: "mesh-gateway",
Kind: structs.ServiceKindMeshGateway,
Port: 443,
},
}
require.NoError(t, backend.store.EnsureRegistration(3, &reg))
expectEvents(t, subCh,
func(t *testing.T, got cache.UpdateEvent) {
require.Equal(t, subServerAddrs, got.CorrelationID)
addrs, ok := got.Result.(*pbpeering.PeeringServerAddresses)
require.True(t, ok)
require.Equal(t, []string{"154.238.12.252:8443", "172.217.22.15:443"}, addrs.GetAddresses())
},
)
})
testutil.RunStep(t, "disabled peering through gateways and received server addresses", func(t *testing.T) {
require.NoError(t, backend.store.EnsureConfigEntry(4, &structs.MeshConfigEntry{
Peering: &structs.PeeringMeshConfig{
PeerThroughMeshGateways: false,
},
}))
expectEvents(t, subCh,
func(t *testing.T, got cache.UpdateEvent) {
require.Equal(t, subServerAddrs, got.CorrelationID)
addrs, ok := got.Result.(*pbpeering.PeeringServerAddresses)
require.True(t, ok)
// New subscriptions receive a snapshot from the event publisher.
// At the start of the test the handler registered a mock that only returns a single address.
require.Equal(t, []string{"198.18.0.1:8502"}, addrs.GetAddresses())
},
)
})
}
type testSubscriptionBackend struct {

View File

@ -96,7 +96,9 @@ func (w *Waiter) Failures() int {
// Every call to Wait increments the failures count, so Reset must be called
// after Wait when there wasn't a failure.
//
// Wait will return ctx.Err() if the context is cancelled.
// The only non-nil error that Wait returns will come from ctx.Err(),
// such as when the context is canceled. This makes it suitable for
// long-running routines that do not get re-initialized, such as replication.
func (w *Waiter) Wait(ctx context.Context) error {
w.failures++
timer := time.NewTimer(w.delay())