Add trigger for doing reconciliation based on watch sets (#16052)
* Add trigger for doing reconciliation based on watch sets * update doc string * Fix my grammar fail
This commit is contained in:
parent
43c9eccf5a
commit
b376fd2151
|
@ -4,12 +4,14 @@ import (
|
|||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/consul/agent/consul/state"
|
||||
"github.com/hashicorp/consul/agent/consul/stream"
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/hashicorp/go-hclog"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
|
@ -39,6 +41,9 @@ type Controller interface {
|
|||
// Request retry rate limiter. This should only ever be called prior to
|
||||
// running Run.
|
||||
WithBackoff(base, max time.Duration) Controller
|
||||
// WithLogger sets the logger for the controller, it should be called prior to Start
|
||||
// being invoked.
|
||||
WithLogger(logger hclog.Logger) Controller
|
||||
// WithWorkers sets the number of worker goroutines used to process the queue
|
||||
// this defaults to 1 goroutine.
|
||||
WithWorkers(i int) Controller
|
||||
|
@ -46,6 +51,12 @@ type Controller interface {
|
|||
// implementation. This is most useful for testing. This should only ever be called
|
||||
// prior to running Run.
|
||||
WithQueueFactory(fn func(ctx context.Context, baseBackoff time.Duration, maxBackoff time.Duration) WorkQueue) Controller
|
||||
// AddTrigger allows for triggering a reconciliation request when a
|
||||
// triggering function returns, when the passed in context is canceled
|
||||
// the trigger must return
|
||||
AddTrigger(request Request, trigger func(ctx context.Context) error)
|
||||
// RemoveTrigger removes the triggering function associated with the Request object
|
||||
RemoveTrigger(request Request)
|
||||
}
|
||||
|
||||
var _ Controller = &controller{}
|
||||
|
@ -78,8 +89,27 @@ type controller struct {
|
|||
// publisher is the event publisher that should be subscribed to for any updates
|
||||
publisher state.EventPublisher
|
||||
|
||||
// waitOnce ensures we wait until the controller has started
|
||||
waitOnce sync.Once
|
||||
// started signals when the controller has started
|
||||
started chan struct{}
|
||||
|
||||
// group is the error group used in our main start up worker routines
|
||||
group *errgroup.Group
|
||||
// groupCtx is the context of the error group to use in spinning up our
|
||||
// worker routines
|
||||
groupCtx context.Context
|
||||
|
||||
// triggers is a map of cancel functions for out-of-band Request triggers
|
||||
triggers map[Request]func()
|
||||
// triggerMutex is used for accessing the above map
|
||||
triggerMutex sync.Mutex
|
||||
|
||||
// running ensures that we are only calling Run a single time
|
||||
running int32
|
||||
|
||||
// logger is the logger for the controller
|
||||
logger hclog.Logger
|
||||
}
|
||||
|
||||
// New returns a new Controller associated with the given state store and reconciler.
|
||||
|
@ -91,6 +121,9 @@ func New(publisher state.EventPublisher, reconciler Reconciler) Controller {
|
|||
baseBackoff: 5 * time.Millisecond,
|
||||
maxBackoff: 1000 * time.Second,
|
||||
makeQueue: RunWorkQueue,
|
||||
started: make(chan struct{}),
|
||||
triggers: make(map[Request]func()),
|
||||
logger: hclog.NewNullLogger(),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -130,6 +163,14 @@ func (c *controller) WithWorkers(i int) Controller {
|
|||
return c
|
||||
}
|
||||
|
||||
// WithLogger sets the internal logger for the controller.
|
||||
func (c *controller) WithLogger(logger hclog.Logger) Controller {
|
||||
c.ensureNotRunning()
|
||||
|
||||
c.logger = logger
|
||||
return c
|
||||
}
|
||||
|
||||
// WithQueueFactory changes the initialization method for the Controller's work
|
||||
// queue, this is predominantly just used for testing. This should only ever be called
|
||||
// prior to running Start.
|
||||
|
@ -157,15 +198,18 @@ func (c *controller) Run(ctx context.Context) error {
|
|||
panic("Run cannot be called more than once")
|
||||
}
|
||||
|
||||
group, groupCtx := errgroup.WithContext(ctx)
|
||||
c.group, c.groupCtx = errgroup.WithContext(ctx)
|
||||
|
||||
// set up our queue
|
||||
c.work = c.makeQueue(groupCtx, c.baseBackoff, c.maxBackoff)
|
||||
c.work = c.makeQueue(c.groupCtx, c.baseBackoff, c.maxBackoff)
|
||||
|
||||
// we can now add stuff to the queue from other contexts
|
||||
close(c.started)
|
||||
|
||||
for _, sub := range c.subscriptions {
|
||||
// store a reference for the closure
|
||||
sub := sub
|
||||
group.Go(func() error {
|
||||
c.group.Go(func() error {
|
||||
var index uint64
|
||||
|
||||
subscription, err := c.publisher.Subscribe(sub.request)
|
||||
|
@ -201,14 +245,14 @@ func (c *controller) Run(ctx context.Context) error {
|
|||
}
|
||||
|
||||
for i := 0; i < c.workers; i++ {
|
||||
group.Go(func() error {
|
||||
c.group.Go(func() error {
|
||||
for {
|
||||
request, shutdown := c.work.Get()
|
||||
if shutdown {
|
||||
// Stop working
|
||||
return nil
|
||||
}
|
||||
c.reconcileHandler(groupCtx, request)
|
||||
c.reconcileHandler(c.groupCtx, request)
|
||||
// Done is called here because it is required to be called
|
||||
// when we've finished processing each request
|
||||
c.work.Done(request)
|
||||
|
@ -216,10 +260,57 @@ func (c *controller) Run(ctx context.Context) error {
|
|||
})
|
||||
}
|
||||
|
||||
<-groupCtx.Done()
|
||||
<-c.groupCtx.Done()
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddTrigger allows for triggering a reconciliation request every time that the
|
||||
// triggering function returns, when the passed in context is canceled
|
||||
// the trigger must return
|
||||
func (c *controller) AddTrigger(request Request, trigger func(ctx context.Context) error) {
|
||||
c.wait()
|
||||
|
||||
ctx, cancel := context.WithCancel(c.groupCtx)
|
||||
|
||||
c.triggerMutex.Lock()
|
||||
oldCancel, ok := c.triggers[request]
|
||||
if ok {
|
||||
oldCancel()
|
||||
}
|
||||
c.triggers[request] = cancel
|
||||
c.triggerMutex.Unlock()
|
||||
|
||||
c.group.Go(func() error {
|
||||
if err := trigger(ctx); err != nil {
|
||||
c.logger.Error("error while running trigger, adding re-reconcilation anyway", "error", err)
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
default:
|
||||
c.work.Add(request)
|
||||
return nil
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// RemoveTrigger removes the triggering function associated with the Request object
|
||||
func (c *controller) RemoveTrigger(request Request) {
|
||||
c.triggerMutex.Lock()
|
||||
cancel, ok := c.triggers[request]
|
||||
if ok {
|
||||
cancel()
|
||||
delete(c.triggers, request)
|
||||
}
|
||||
c.triggerMutex.Unlock()
|
||||
}
|
||||
|
||||
func (c *controller) wait() {
|
||||
c.waitOnce.Do(func() {
|
||||
<-c.started
|
||||
})
|
||||
}
|
||||
|
||||
func (c *controller) processEvent(sub subscription, event stream.Event) error {
|
||||
switch payload := event.Payload.(type) {
|
||||
case state.EventPayloadConfigEntry:
|
||||
|
|
|
@ -12,6 +12,7 @@ import (
|
|||
"github.com/hashicorp/consul/agent/consul/stream"
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/hashicorp/go-hclog"
|
||||
"github.com/hashicorp/go-memdb"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
|
@ -415,3 +416,148 @@ func TestConfigEntrySubscriptions(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBasicController_Triggers(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
reconciler := newTestReconciler(true)
|
||||
|
||||
publisher := stream.NewEventPublisher(0)
|
||||
go publisher.Run(ctx)
|
||||
|
||||
controller := New(publisher, reconciler)
|
||||
|
||||
go func() {
|
||||
require.NoError(t, controller.Run(ctx))
|
||||
}()
|
||||
|
||||
ensureCalled := func(request chan Request, name string) bool {
|
||||
select {
|
||||
case req := <-request:
|
||||
require.Equal(t, structs.IngressGateway, req.Kind)
|
||||
require.Equal(t, name, req.Name)
|
||||
return true
|
||||
case <-time.After(10 * time.Millisecond):
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
request := Request{
|
||||
Kind: structs.IngressGateway,
|
||||
Name: "foo-1",
|
||||
}
|
||||
|
||||
triggerOneChan := make(chan struct{}, 3)
|
||||
triggerOne := func(ctx context.Context) error {
|
||||
select {
|
||||
case <-triggerOneChan:
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
}
|
||||
}
|
||||
controller.AddTrigger(request, triggerOne)
|
||||
require.False(t, ensureCalled(reconciler.received, "foo-1"))
|
||||
triggerOneChan <- struct{}{}
|
||||
reconciler.stepFor(10 * time.Millisecond)
|
||||
require.True(t, ensureCalled(reconciler.received, "foo-1"))
|
||||
|
||||
// do it again
|
||||
require.False(t, ensureCalled(reconciler.received, "foo-1"))
|
||||
controller.AddTrigger(request, triggerOne)
|
||||
triggerOneChan <- struct{}{}
|
||||
reconciler.stepFor(10 * time.Millisecond)
|
||||
require.True(t, ensureCalled(reconciler.received, "foo-1"))
|
||||
|
||||
// check with the overwritten trigger
|
||||
controller.AddTrigger(request, triggerOne)
|
||||
triggerTwoChan := make(chan struct{}, 2)
|
||||
triggerTwo := func(ctx context.Context) error {
|
||||
select {
|
||||
case <-triggerTwoChan:
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
}
|
||||
}
|
||||
controller.AddTrigger(request, triggerTwo)
|
||||
triggerOneChan <- struct{}{}
|
||||
reconciler.stepFor(10 * time.Millisecond)
|
||||
require.False(t, ensureCalled(reconciler.received, "foo-1"))
|
||||
triggerTwoChan <- struct{}{}
|
||||
reconciler.stepFor(10 * time.Millisecond)
|
||||
require.True(t, ensureCalled(reconciler.received, "foo-1"))
|
||||
|
||||
// remove the trigger and make sure we're not called again
|
||||
controller.RemoveTrigger(request)
|
||||
triggerTwoChan <- struct{}{}
|
||||
reconciler.stepFor(10 * time.Millisecond)
|
||||
require.False(t, ensureCalled(reconciler.received, "foo-1"))
|
||||
}
|
||||
|
||||
func TestDiscoveryChainController(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
reconciler := newTestReconciler(false)
|
||||
|
||||
publisher := stream.NewEventPublisher(1 * time.Millisecond)
|
||||
go publisher.Run(ctx)
|
||||
|
||||
// get the store through the FSM since the publisher handlers get registered through it
|
||||
store := fsm.NewFromDeps(fsm.Deps{
|
||||
Logger: hclog.New(nil),
|
||||
NewStateStore: func() *state.Store {
|
||||
return state.NewStateStoreWithEventPublisher(nil, publisher)
|
||||
},
|
||||
Publisher: publisher,
|
||||
}).State()
|
||||
|
||||
controller := New(publisher, reconciler)
|
||||
go controller.Subscribe(&stream.SubscribeRequest{
|
||||
Topic: state.EventTopicIngressGateway,
|
||||
Subject: stream.SubjectWildcard,
|
||||
}).WithWorkers(10).Run(ctx)
|
||||
|
||||
request := Request{
|
||||
Kind: structs.IngressGateway,
|
||||
Name: "foo-1",
|
||||
}
|
||||
|
||||
ensureCalled := func(request chan Request, name string) bool {
|
||||
select {
|
||||
case req := <-request:
|
||||
require.Equal(t, structs.IngressGateway, req.Kind)
|
||||
require.Equal(t, name, req.Name)
|
||||
return true
|
||||
case <-time.After(10 * time.Millisecond):
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
require.NoError(t, store.EnsureConfigEntry(1, &structs.IngressGatewayConfigEntry{
|
||||
Kind: structs.IngressGateway,
|
||||
Name: "foo-1",
|
||||
}))
|
||||
require.True(t, ensureCalled(reconciler.received, "foo-1"))
|
||||
|
||||
// create the trigger and something that changes in its upstream discovery chain and ensure that we've
|
||||
// fired the reconciler
|
||||
ws := memdb.NewWatchSet()
|
||||
ws.Add(store.AbandonCh())
|
||||
_, _, err := store.ReadDiscoveryChainConfigEntries(ws, "foo-2", nil)
|
||||
require.NoError(t, err)
|
||||
controller.AddTrigger(request, ws.WatchCtx)
|
||||
|
||||
require.False(t, ensureCalled(reconciler.received, "foo-1"))
|
||||
require.NoError(t, store.EnsureConfigEntry(1, &structs.ServiceResolverConfigEntry{
|
||||
Kind: structs.ServiceResolver,
|
||||
Name: "foo-2",
|
||||
}))
|
||||
require.True(t, ensureCalled(reconciler.received, "foo-1"))
|
||||
}
|
||||
|
|
|
@ -3,6 +3,7 @@ package controller
|
|||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type testReconciler struct {
|
||||
|
@ -43,6 +44,12 @@ func (r *testReconciler) setResponse(err error) {
|
|||
func (r *testReconciler) step() {
|
||||
r.stepChan <- struct{}{}
|
||||
}
|
||||
func (r *testReconciler) stepFor(duration time.Duration) {
|
||||
select {
|
||||
case r.stepChan <- struct{}{}:
|
||||
case <-time.After(duration):
|
||||
}
|
||||
}
|
||||
|
||||
func (r *testReconciler) stop() {
|
||||
close(r.stopChan)
|
||||
|
|
Loading…
Reference in New Issue