// Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: MPL-2.0 package fairshare import ( "fmt" "io/ioutil" "sync" log "github.com/hashicorp/go-hclog" uuid "github.com/hashicorp/go-uuid" "github.com/hashicorp/vault/sdk/helper/logging" ) // Job is an interface for jobs used with this job manager type Job interface { // Execute performs the work. // It should be synchronous if a cleanupFn is provided. Execute() error // OnFailure handles the error resulting from a failed Execute(). // It should be synchronous if a cleanupFn is provided. OnFailure(err error) } type ( initFn func() cleanupFn func() ) type wrappedJob struct { job Job init initFn cleanup cleanupFn } // worker represents a single worker in a pool type worker struct { name string jobCh <-chan wrappedJob quit chan struct{} logger log.Logger // waitgroup for testing stop functionality wg *sync.WaitGroup } // start starts the worker listening and working until the quit channel is closed func (w *worker) start() { w.wg.Add(1) go func() { for { select { case <-w.quit: w.wg.Done() return case wJob := <-w.jobCh: if wJob.init != nil { wJob.init() } err := wJob.job.Execute() if err != nil { wJob.job.OnFailure(err) } if wJob.cleanup != nil { wJob.cleanup() } } } }() } // dispatcher represents a worker pool type dispatcher struct { name string numWorkers int workers []worker jobCh chan wrappedJob onceStart sync.Once onceStop sync.Once quit chan struct{} logger log.Logger wg *sync.WaitGroup } // newDispatcher generates a new worker dispatcher and populates it with workers func newDispatcher(name string, numWorkers int, l log.Logger) *dispatcher { d := createDispatcher(name, numWorkers, l) d.init() return d } // dispatch dispatches a job to the worker pool, with optional initialization // and cleanup functions (useful for tracking job progress) func (d *dispatcher) dispatch(job Job, init initFn, cleanup cleanupFn) { wJob := wrappedJob{ init: init, job: job, cleanup: cleanup, } select { case d.jobCh <- wJob: case <-d.quit: d.logger.Info("shutting down during dispatch") } } // start starts all the workers listening on the job channel // this will only start the workers for this dispatch once func (d *dispatcher) start() { d.onceStart.Do(func() { d.logger.Trace("starting dispatcher") for _, w := range d.workers { worker := w worker.start() } }) } // stop stops the worker pool asynchronously func (d *dispatcher) stop() { d.onceStop.Do(func() { d.logger.Trace("terminating dispatcher") close(d.quit) }) } // createDispatcher generates a new Dispatcher object, but does not initialize the // worker pool func createDispatcher(name string, numWorkers int, l log.Logger) *dispatcher { if l == nil { l = logging.NewVaultLoggerWithWriter(ioutil.Discard, log.NoLevel) } if numWorkers <= 0 { numWorkers = 1 l.Warn("must have 1 or more workers. setting number of workers to 1") } if name == "" { guid, err := uuid.GenerateUUID() if err != nil { l.Warn("uuid generator failed, using 'no-uuid'", "err", err) guid = "no-uuid" } name = fmt.Sprintf("dispatcher-%s", guid) } var wg sync.WaitGroup d := dispatcher{ name: name, numWorkers: numWorkers, workers: make([]worker, 0), jobCh: make(chan wrappedJob), quit: make(chan struct{}), logger: l, wg: &wg, } d.logger.Trace("created dispatcher", "name", d.name, "num_workers", d.numWorkers) return &d } func (d *dispatcher) init() { for len(d.workers) < d.numWorkers { d.initializeWorker() } d.logger.Trace("initialized dispatcher", "num_workers", d.numWorkers) } // initializeWorker initializes and adds a new worker, with an optional name func (d *dispatcher) initializeWorker() { w := worker{ name: fmt.Sprint("worker-", len(d.workers)), jobCh: d.jobCh, quit: d.quit, logger: d.logger, wg: d.wg, } d.workers = append(d.workers, w) }