Share mgw addrs in peering stream if needed

This commit adds handling so that the replication stream considers
whether the user intends to peer through mesh gateways.

The subscription will return server or mesh gateway addresses depending
on the mesh configuration setting. These watches can be updated at
runtime by modifying the mesh config entry.
This commit is contained in:
freddygv 2022-09-21 09:55:19 -06:00
parent 17463472b7
commit 2c5caec97c
8 changed files with 233 additions and 80 deletions

View File

@ -8,6 +8,7 @@ import (
"time" "time"
"github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/sdk/freeport"
"github.com/hashicorp/consul/types" "github.com/hashicorp/consul/types"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
gogrpc "google.golang.org/grpc" gogrpc "google.golang.org/grpc"
@ -71,7 +72,7 @@ func TestPeeringBackend_GetServerAddresses(t *testing.T) {
} }
_, cfg := testServerConfig(t) _, cfg := testServerConfig(t)
cfg.GRPCTLSPort = 8505 cfg.GRPCTLSPort = freeport.GetOne(t)
srv, err := newServer(t, cfg) srv, err := newServer(t, cfg)
require.NoError(t, err) require.NoError(t, err)

View File

@ -150,7 +150,6 @@ func (m *CertManager) watchServerToken(ctx context.Context) {
// Cancel the existing leaf cert watch and spin up a new one any time the server token changes. // Cancel the existing leaf cert watch and spin up a new one any time the server token changes.
// The watch needs the current token as set by the leader since certificate signing requests go to the leader. // The watch needs the current token as set by the leader since certificate signing requests go to the leader.
fmt.Println("canceling and resetting")
cancel() cancel()
notifyCtx, cancel = context.WithCancel(ctx) notifyCtx, cancel = context.WithCancel(ctx)

View File

@ -123,5 +123,6 @@ type StateStore interface {
CAConfig(ws memdb.WatchSet) (uint64, *structs.CAConfiguration, error) CAConfig(ws memdb.WatchSet) (uint64, *structs.CAConfiguration, error)
TrustBundleListByService(ws memdb.WatchSet, service, dc string, entMeta acl.EnterpriseMeta) (uint64, []*pbpeering.PeeringTrustBundle, error) TrustBundleListByService(ws memdb.WatchSet, service, dc string, entMeta acl.EnterpriseMeta) (uint64, []*pbpeering.PeeringTrustBundle, error)
ServiceList(ws memdb.WatchSet, entMeta *acl.EnterpriseMeta, peerName string) (uint64, structs.ServiceList, error) ServiceList(ws memdb.WatchSet, entMeta *acl.EnterpriseMeta, peerName string) (uint64, structs.ServiceList, error)
ConfigEntry(ws memdb.WatchSet, kind, name string, entMeta *acl.EnterpriseMeta) (uint64, structs.ConfigEntry, error)
AbandonCh() <-chan struct{} AbandonCh() <-chan struct{}
} }

View File

@ -640,9 +640,6 @@ func (s *Server) realHandleStream(streamReq HandleStreamRequest) error {
continue continue
} }
case strings.HasPrefix(update.CorrelationID, subMeshGateway):
// TODO(Peering): figure out how to sync this separately
case update.CorrelationID == subCARoot: case update.CorrelationID == subCARoot:
resp, err = makeCARootsResponse(update) resp, err = makeCARootsResponse(update)
if err != nil { if err != nil {

View File

@ -98,25 +98,25 @@ func (m *subscriptionManager) syncViaBlockingQuery(
ws.Add(store.AbandonCh()) ws.Add(store.AbandonCh())
ws.Add(ctx.Done()) ws.Add(ctx.Done())
if result, err := queryFn(ctx, store, ws); err != nil { if result, err := queryFn(ctx, store, ws); err != nil && ctx.Err() == nil {
logger.Error("failed to sync from query", "error", err) logger.Error("failed to sync from query", "error", err)
} else { } else {
// Block for any changes to the state store. select {
updateCh <- cache.UpdateEvent{ case <-ctx.Done():
CorrelationID: correlationID, return
Result: result, case updateCh <- cache.UpdateEvent{CorrelationID: correlationID, Result: result}:
} }
// Block for any changes to the state store.
ws.WatchCtx(ctx) ws.WatchCtx(ctx)
} }
if err := waiter.Wait(ctx); err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) { err := waiter.Wait(ctx)
logger.Error("failed to wait before re-trying sync", "error", err) if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
}
select {
case <-ctx.Done():
return return
default: } else if err != nil {
logger.Error("failed to wait before re-trying sync", "error", err)
} }
} }
} }

View File

@ -6,9 +6,13 @@ import (
"fmt" "fmt"
"strconv" "strconv"
"strings" "strings"
"time"
"github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
"github.com/hashicorp/consul/ipaddr"
"github.com/hashicorp/consul/lib/retry"
"github.com/hashicorp/go-hclog" "github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-memdb"
"github.com/hashicorp/consul/acl" "github.com/hashicorp/consul/acl"
"github.com/hashicorp/consul/agent/cache" "github.com/hashicorp/consul/agent/cache"
@ -247,16 +251,10 @@ func (m *subscriptionManager) handleEvent(ctx context.Context, state *subscripti
pending := &pendingPayload{} pending := &pendingPayload{}
// Directly replicate information about our mesh gateways to the consuming side.
// TODO(peering): should we scrub anything before replicating this?
if err := pending.Add(meshGatewayPayloadID, u.CorrelationID, csn); err != nil {
return err
}
if state.exportList != nil { if state.exportList != nil {
// Trigger public events for all synthetic discovery chain replies. // Trigger public events for all synthetic discovery chain replies.
for chainName, info := range state.connectServices { for chainName, info := range state.connectServices {
m.emitEventForDiscoveryChain(ctx, state, pending, chainName, info) m.collectPendingEventForDiscoveryChain(ctx, state, pending, chainName, info)
} }
} }
@ -490,7 +488,7 @@ func (m *subscriptionManager) syncDiscoveryChains(
state.connectServices[chainName] = info state.connectServices[chainName] = info
m.emitEventForDiscoveryChain(ctx, state, pending, chainName, info) m.collectPendingEventForDiscoveryChain(ctx, state, pending, chainName, info)
} }
// if it was dropped, try to emit a DELETE event // if it was dropped, try to emit a DELETE event
@ -517,7 +515,7 @@ func (m *subscriptionManager) syncDiscoveryChains(
} }
} }
func (m *subscriptionManager) emitEventForDiscoveryChain( func (m *subscriptionManager) collectPendingEventForDiscoveryChain(
ctx context.Context, ctx context.Context,
state *subscriptionState, state *subscriptionState,
pending *pendingPayload, pending *pendingPayload,
@ -738,32 +736,118 @@ func (m *subscriptionManager) notifyServerAddrUpdates(
ctx context.Context, ctx context.Context,
updateCh chan<- cache.UpdateEvent, updateCh chan<- cache.UpdateEvent,
) { ) {
// Wait until this is subscribed-to. // Wait until server address updates are subscribed-to.
select { select {
case <-m.serverAddrsSubReady: case <-m.serverAddrsSubReady:
case <-ctx.Done(): case <-ctx.Done():
return return
} }
var idx uint64 configNotifyCh := m.notifyMeshConfigUpdates(ctx)
// TODO(peering): retry logic; fail past a threshold
for {
var err error
// Typically, this function will block inside `m.subscribeServerAddrs` and only return on error.
// Errors are logged and the watch is retried.
idx, err = m.subscribeServerAddrs(ctx, idx, updateCh)
if errors.Is(err, stream.ErrSubForceClosed) {
m.logger.Trace("subscription force-closed due to an ACL change or snapshot restore, will attempt resume")
} else if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
m.logger.Warn("failed to subscribe to server addresses, will attempt resume", "error", err.Error())
} else {
m.logger.Trace(err.Error())
}
// Intentionally initialized to empty values.
// These are set after the first mesh config entry update arrives.
var queryCtx context.Context
cancel := func() {}
useGateways := false
for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
cancel()
return
case event := <-configNotifyCh:
entry, ok := event.Result.(*structs.MeshConfigEntry)
if event.Result != nil && !ok {
m.logger.Error(fmt.Sprintf("saw unexpected type %T for mesh config entry: falling back to pushing direct server addresses", event.Result))
}
if entry != nil && entry.Peering != nil && entry.Peering.PeerThroughMeshGateways {
useGateways = true
} else {
useGateways = false
}
// Cancel and re-set watches based on the updated config entry.
cancel()
queryCtx, cancel = context.WithCancel(ctx)
if useGateways {
go m.notifyServerMeshGatewayAddresses(queryCtx, updateCh)
} else {
go m.ensureServerAddrSubscription(queryCtx, updateCh)
}
}
}
}
// notifyMeshConfigUpdates starts a background blocking query for the mesh
// config entry and returns the channel on which each observed value is
// delivered.
//
// The query runs in its own goroutine through syncViaBlockingQuery and is
// bound to ctx. Every event carries the "mesh-config-entry" correlation ID and
// the raw entry returned by the state store as its Result; callers must be
// prepared for a Result that does not hold a *structs.MeshConfigEntry.
//
// The returned channel has a buffer of one and is never closed by this
// function.
func (m *subscriptionManager) notifyMeshConfigUpdates(ctx context.Context) <-chan cache.UpdateEvent {
	const meshConfigWatch = "mesh-config-entry"

	notifyCh := make(chan cache.UpdateEvent, 1)

	// The query callback ignores its context argument: the lookup only needs
	// the state store and the watch set it registers on.
	go m.syncViaBlockingQuery(ctx, meshConfigWatch, func(_ context.Context, store StateStore, ws memdb.WatchSet) (interface{}, error) {
		_, rawEntry, err := store.ConfigEntry(ws, structs.MeshConfig, structs.MeshConfigMesh, acl.DefaultEnterpriseMeta())
		if err != nil {
			return nil, fmt.Errorf("failed to get mesh config entry: %w", err)
		}
		return rawEntry, nil
	}, meshConfigWatch, notifyCh)

	return notifyCh
}
// notifyServerMeshGatewayAddresses publishes the addresses of the registered
// mesh gateways on updateCh, using the subServerAddrs correlation ID so that
// consumers receive them in place of direct server addresses.
//
// It blocks inside syncViaBlockingQuery, which re-runs the query as the
// watched state changes, and is intended to be run in its own goroutine with
// a cancellable ctx.
//
// Each gateway contributes one "address:port" string built from
// BestAddress(true).
// NOTE(review): the boolean presumably selects WAN addresses; confirm against
// CheckServiceNode.BestAddress.
//
// When no mesh gateways are registered the query returns an error rather than
// an empty address list, so nothing is sent on updateCh for that attempt.
func (m *subscriptionManager) notifyServerMeshGatewayAddresses(ctx context.Context, updateCh chan<- cache.UpdateEvent) {
	m.syncViaBlockingQuery(ctx, "mesh-gateways", func(ctx context.Context, store StateStore, ws memdb.WatchSet) (interface{}, error) {
		_, nodes, err := store.ServiceDump(ws, structs.ServiceKindMeshGateway, true, acl.DefaultEnterpriseMeta(), structs.DefaultPeerKeyword)
		if err != nil {
			return nil, fmt.Errorf("failed to watch mesh gateways services for servers: %w", err)
		}

		// One address per gateway instance, so the final size is known up front.
		gatewayAddrs := make([]string, 0, len(nodes))
		for _, csn := range nodes {
			_, addr, port := csn.BestAddress(true)
			gatewayAddrs = append(gatewayAddrs, ipaddr.FormatAddressPort(addr, port))
		}

		// Never publish an empty list: the user asked to peer through mesh
		// gateways, so having none registered is reported as an error and the
		// query is retried instead of replacing the peer's known addresses
		// with nothing.
		if len(gatewayAddrs) == 0 {
			return nil, errors.New("configured to peer through mesh gateways but no mesh gateways are registered")
		}

		return &pbpeering.PeeringServerAddresses{
			Addresses: gatewayAddrs,
		}, nil
	}, subServerAddrs, updateCh)
}
func (m *subscriptionManager) ensureServerAddrSubscription(ctx context.Context, updateCh chan<- cache.UpdateEvent) {
waiter := &retry.Waiter{
MinFailures: 1,
Factor: 500 * time.Millisecond,
MaxWait: 60 * time.Second,
Jitter: retry.NewJitter(100),
}
logger := m.logger.With("queryType", "server-addresses")
var idx uint64
for {
var err error
idx, err = m.subscribeServerAddrs(ctx, idx, updateCh)
if errors.Is(err, stream.ErrSubForceClosed) {
logger.Trace("subscription force-closed due to an ACL change or snapshot restore, will attempt resume")
} else if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
logger.Warn("failed to subscribe to server addresses, will attempt resume", "error", err.Error())
} else if err != nil {
logger.Trace(err.Error())
return
}
if err := waiter.Wait(ctx); err != nil {
return return
default:
} }
} }
} }
@ -826,17 +910,22 @@ func (m *subscriptionManager) subscribeServerAddrs(
grpcAddr := srv.Address + ":" + strconv.Itoa(srv.ExtGRPCPort) grpcAddr := srv.Address + ":" + strconv.Itoa(srv.ExtGRPCPort)
serverAddrs = append(serverAddrs, grpcAddr) serverAddrs = append(serverAddrs, grpcAddr)
} }
if len(serverAddrs) == 0 { if len(serverAddrs) == 0 {
m.logger.Warn("did not find any server addresses with external gRPC ports to publish") m.logger.Warn("did not find any server addresses with external gRPC ports to publish")
continue continue
} }
updateCh <- cache.UpdateEvent{ u := cache.UpdateEvent{
CorrelationID: subServerAddrs, CorrelationID: subServerAddrs,
Result: &pbpeering.PeeringServerAddresses{ Result: &pbpeering.PeeringServerAddresses{
Addresses: serverAddrs, Addresses: serverAddrs,
}, },
} }
select {
case <-ctx.Done():
return 0, ctx.Err()
case updateCh <- u:
}
} }
} }

View File

@ -7,6 +7,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/hashicorp/consul/types"
"github.com/stretchr/testify/mock" "github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -49,10 +50,7 @@ func TestSubscriptionManager_RegisterDeregister(t *testing.T) {
subCh := mgr.subscribe(ctx, id, "my-peering", partition) subCh := mgr.subscribe(ctx, id, "my-peering", partition)
var ( var (
gatewayCorrID = subMeshGateway + partition mysqlCorrID = subExportedService + structs.NewServiceName("mysql", nil).String()
mysqlCorrID = subExportedService + structs.NewServiceName("mysql", nil).String()
mysqlProxyCorrID = subExportedService + structs.NewServiceName("mysql-sidecar-proxy", nil).String() mysqlProxyCorrID = subExportedService + structs.NewServiceName("mysql-sidecar-proxy", nil).String()
) )
@ -60,11 +58,7 @@ func TestSubscriptionManager_RegisterDeregister(t *testing.T) {
expectEvents(t, subCh, expectEvents(t, subCh,
func(t *testing.T, got cache.UpdateEvent) { func(t *testing.T, got cache.UpdateEvent) {
checkExportedServices(t, got, []string{}) checkExportedServices(t, got, []string{})
}, })
func(t *testing.T, got cache.UpdateEvent) {
checkEvent(t, got, gatewayCorrID, 0)
},
)
// Initially add in L4 failover so that later we can test removing it. We // Initially add in L4 failover so that later we can test removing it. We
// cannot do the other way around because it would fail validation to // cannot do the other way around because it would fail validation to
@ -300,17 +294,6 @@ func TestSubscriptionManager_RegisterDeregister(t *testing.T) {
}, },
}, res.Nodes[0]) }, res.Nodes[0])
}, },
func(t *testing.T, got cache.UpdateEvent) {
require.Equal(t, gatewayCorrID, got.CorrelationID)
res := got.Result.(*pbservice.IndexedCheckServiceNodes)
require.Equal(t, uint64(0), res.Index)
require.Len(t, res.Nodes, 1)
prototest.AssertDeepEqual(t, &pbservice.CheckServiceNode{
Node: pbNode("mgw", "10.1.1.1", partition),
Service: pbService("mesh-gateway", "gateway-1", "gateway", 8443, nil),
}, res.Nodes[0])
},
) )
}) })
@ -434,13 +417,6 @@ func TestSubscriptionManager_RegisterDeregister(t *testing.T) {
res := got.Result.(*pbservice.IndexedCheckServiceNodes) res := got.Result.(*pbservice.IndexedCheckServiceNodes)
require.Equal(t, uint64(0), res.Index) require.Equal(t, uint64(0), res.Index)
require.Len(t, res.Nodes, 0)
},
func(t *testing.T, got cache.UpdateEvent) {
require.Equal(t, gatewayCorrID, got.CorrelationID)
res := got.Result.(*pbservice.IndexedCheckServiceNodes)
require.Equal(t, uint64(0), res.Index)
require.Len(t, res.Nodes, 0) require.Len(t, res.Nodes, 0)
}, },
) )
@ -506,8 +482,6 @@ func TestSubscriptionManager_InitialSnapshot(t *testing.T) {
backend.ensureService(t, "zip", mongo.Service) backend.ensureService(t, "zip", mongo.Service)
var ( var (
gatewayCorrID = subMeshGateway + partition
mysqlCorrID = subExportedService + structs.NewServiceName("mysql", nil).String() mysqlCorrID = subExportedService + structs.NewServiceName("mysql", nil).String()
mongoCorrID = subExportedService + structs.NewServiceName("mongo", nil).String() mongoCorrID = subExportedService + structs.NewServiceName("mongo", nil).String()
chainCorrID = subExportedService + structs.NewServiceName("chain", nil).String() chainCorrID = subExportedService + structs.NewServiceName("chain", nil).String()
@ -521,9 +495,6 @@ func TestSubscriptionManager_InitialSnapshot(t *testing.T) {
expectEvents(t, subCh, expectEvents(t, subCh,
func(t *testing.T, got cache.UpdateEvent) { func(t *testing.T, got cache.UpdateEvent) {
checkExportedServices(t, got, []string{}) checkExportedServices(t, got, []string{})
},
func(t *testing.T, got cache.UpdateEvent) {
checkEvent(t, got, gatewayCorrID, 0)
}) })
// At this point in time we'll have a mesh-gateway notification with no // At this point in time we'll have a mesh-gateway notification with no
@ -597,9 +568,6 @@ func TestSubscriptionManager_InitialSnapshot(t *testing.T) {
func(t *testing.T, got cache.UpdateEvent) { func(t *testing.T, got cache.UpdateEvent) {
checkEvent(t, got, mysqlProxyCorrID, 1, "mysql-sidecar-proxy", string(structs.ServiceKindConnectProxy)) checkEvent(t, got, mysqlProxyCorrID, 1, "mysql-sidecar-proxy", string(structs.ServiceKindConnectProxy))
}, },
func(t *testing.T, got cache.UpdateEvent) {
checkEvent(t, got, gatewayCorrID, 1, "gateway", string(structs.ServiceKindMeshGateway))
},
) )
}) })
} }
@ -741,6 +709,102 @@ func TestSubscriptionManager_ServerAddrs(t *testing.T) {
}, },
) )
}) })
testutil.RunStep(t, "flipped to peering through mesh gateways", func(t *testing.T) {
require.NoError(t, backend.store.EnsureConfigEntry(1, &structs.MeshConfigEntry{
Peering: &structs.PeeringMeshConfig{
PeerThroughMeshGateways: true,
},
}))
select {
case <-time.After(100 * time.Millisecond):
case <-subCh:
t.Fatal("expected to time out: no mesh gateways are registered")
}
})
testutil.RunStep(t, "registered and received a mesh gateway", func(t *testing.T) {
reg := structs.RegisterRequest{
ID: types.NodeID("b5489ca9-f5e9-4dba-a779-61fec4e8e364"),
Node: "gw-node",
Address: "1.2.3.4",
TaggedAddresses: map[string]string{
structs.TaggedAddressWAN: "172.217.22.14",
},
Service: &structs.NodeService{
ID: "mesh-gateway",
Service: "mesh-gateway",
Kind: structs.ServiceKindMeshGateway,
Port: 443,
TaggedAddresses: map[string]structs.ServiceAddress{
structs.TaggedAddressWAN: {Address: "154.238.12.252", Port: 8443},
},
},
}
require.NoError(t, backend.store.EnsureRegistration(2, &reg))
expectEvents(t, subCh,
func(t *testing.T, got cache.UpdateEvent) {
require.Equal(t, subServerAddrs, got.CorrelationID)
addrs, ok := got.Result.(*pbpeering.PeeringServerAddresses)
require.True(t, ok)
require.Equal(t, []string{"154.238.12.252:8443"}, addrs.GetAddresses())
},
)
})
testutil.RunStep(t, "registered and received a second mesh gateway", func(t *testing.T) {
reg := structs.RegisterRequest{
ID: types.NodeID("e4cc0af3-5c09-4ddf-94a9-5840e427bc45"),
Node: "gw-node-2",
Address: "1.2.3.5",
TaggedAddresses: map[string]string{
structs.TaggedAddressWAN: "172.217.22.15",
},
Service: &structs.NodeService{
ID: "mesh-gateway",
Service: "mesh-gateway",
Kind: structs.ServiceKindMeshGateway,
Port: 443,
},
}
require.NoError(t, backend.store.EnsureRegistration(3, &reg))
expectEvents(t, subCh,
func(t *testing.T, got cache.UpdateEvent) {
require.Equal(t, subServerAddrs, got.CorrelationID)
addrs, ok := got.Result.(*pbpeering.PeeringServerAddresses)
require.True(t, ok)
require.Equal(t, []string{"154.238.12.252:8443", "172.217.22.15:443"}, addrs.GetAddresses())
},
)
})
testutil.RunStep(t, "disabled peering through gateways and received server addresses", func(t *testing.T) {
require.NoError(t, backend.store.EnsureConfigEntry(4, &structs.MeshConfigEntry{
Peering: &structs.PeeringMeshConfig{
PeerThroughMeshGateways: false,
},
}))
expectEvents(t, subCh,
func(t *testing.T, got cache.UpdateEvent) {
require.Equal(t, subServerAddrs, got.CorrelationID)
addrs, ok := got.Result.(*pbpeering.PeeringServerAddresses)
require.True(t, ok)
// New subscriptions receive a snapshot from the event publisher.
// At the start of the test the handler registered a mock that only returns a single address.
require.Equal(t, []string{"198.18.0.1:8502"}, addrs.GetAddresses())
},
)
})
} }
type testSubscriptionBackend struct { type testSubscriptionBackend struct {

View File

@ -96,7 +96,9 @@ func (w *Waiter) Failures() int {
// Every call to Wait increments the failures count, so Reset must be called // Every call to Wait increments the failures count, so Reset must be called
// after Wait when there wasn't a failure. // after Wait when there wasn't a failure.
// //
// Wait will return ctx.Err() if the context is cancelled. // The only non-nil error that Wait returns will come from ctx.Err(),
// such as when the context is canceled. This makes it suitable for
// long-running routines that do not get re-initialized, such as replication.
func (w *Waiter) Wait(ctx context.Context) error { func (w *Waiter) Wait(ctx context.Context) error {
w.failures++ w.failures++
timer := time.NewTimer(w.delay()) timer := time.NewTimer(w.delay())