Deployment watcher takes state store

Alex Dadgar 2017-08-30 17:45:32 -07:00
parent 3a439f45a6
commit 590ff91bf3
19 changed files with 543 additions and 563 deletions

View file

@@ -243,7 +243,7 @@ func (d *Deployment) List(args *structs.DeploymentListRequest, reply *structs.De
 	}
 	reply.Deployments = deploys

-	// Use the last index that affected the jobs table
+	// Use the last index that affected the deployment table
 	index, err := state.Index("deployment")
 	if err != nil {
 		return err

View file

@@ -4,83 +4,6 @@ import (
 	"github.com/hashicorp/nomad/nomad/structs"
 )

-// deploymentWatcherStateShim is the shim that provides the state watching
-// methods. These should be set by the server and passed to the deployment
-// watcher.
-type deploymentWatcherStateShim struct {
-	// region is the region the server is a member of. It is used to
-	// auto-populate requests that do not have it set
-	region string
-
-	// evaluations returns the set of evaluations for the given job
-	evaluations func(args *structs.JobSpecificRequest, reply *structs.JobEvaluationsResponse) error
-
-	// allocations returns the set of allocations that are part of the
-	// deployment.
-	allocations func(args *structs.DeploymentSpecificRequest, reply *structs.AllocListResponse) error
-
-	// list is used to list all the deployments in the system
-	list func(args *structs.DeploymentListRequest, reply *structs.DeploymentListResponse) error
-
-	// GetDeployment is used to lookup a particular deployment.
-	getDeployment func(args *structs.DeploymentSpecificRequest, reply *structs.SingleDeploymentResponse) error
-
-	// getJobVersions is used to lookup the versions of a job. This is used when
-	// rolling back to find the latest stable job
-	getJobVersions func(args *structs.JobVersionsRequest, reply *structs.JobVersionsResponse) error
-
-	// getJob is used to lookup a particular job.
-	getJob func(args *structs.JobSpecificRequest, reply *structs.SingleJobResponse) error
-}
-
-func (d *deploymentWatcherStateShim) Evaluations(args *structs.JobSpecificRequest, reply *structs.JobEvaluationsResponse) error {
-	if args.Region == "" {
-		args.Region = d.region
-	}
-	return d.evaluations(args, reply)
-}
-
-func (d *deploymentWatcherStateShim) Allocations(args *structs.DeploymentSpecificRequest, reply *structs.AllocListResponse) error {
-	if args.Region == "" {
-		args.Region = d.region
-	}
-	return d.allocations(args, reply)
-}
-
-func (d *deploymentWatcherStateShim) List(args *structs.DeploymentListRequest, reply *structs.DeploymentListResponse) error {
-	if args.Region == "" {
-		args.Region = d.region
-	}
-	return d.list(args, reply)
-}
-
-func (d *deploymentWatcherStateShim) GetDeployment(args *structs.DeploymentSpecificRequest, reply *structs.SingleDeploymentResponse) error {
-	if args.Region == "" {
-		args.Region = d.region
-	}
-	return d.getDeployment(args, reply)
-}
-
-func (d *deploymentWatcherStateShim) GetJobVersions(args *structs.JobVersionsRequest, reply *structs.JobVersionsResponse) error {
-	if args.Region == "" {
-		args.Region = d.region
-	}
-	return d.getJobVersions(args, reply)
-}
-
-func (d *deploymentWatcherStateShim) GetJob(args *structs.JobSpecificRequest, reply *structs.SingleJobResponse) error {
-	if args.Region == "" {
-		args.Region = d.region
-	}
-	return d.getJob(args, reply)
-}
-
 // deploymentWatcherRaftShim is the shim that provides the state watching
 // methods. These should be set by the server and passed to the deployment
 // watcher.

View file

@@ -8,7 +8,9 @@ import (
 	"golang.org/x/time/rate"

+	memdb "github.com/hashicorp/go-memdb"
 	"github.com/hashicorp/nomad/helper"
+	"github.com/hashicorp/nomad/nomad/state"
 	"github.com/hashicorp/nomad/nomad/structs"
 )

@@ -49,9 +51,8 @@ type deploymentWatcher struct {
 	// deployment
 	deploymentTriggers

-	// DeploymentStateWatchers holds the methods required to watch objects for
-	// changes on behalf of the deployment
-	watchers DeploymentStateWatchers
+	// state is the state that is watched for state changes.
+	state *state.StateStore

 	// d is the deployment being watched
 	d *structs.Deployment

@@ -77,7 +78,7 @@ type deploymentWatcher struct {
 // newDeploymentWatcher returns a deployment watcher that is used to watch
 // deployments and trigger the scheduler as needed.
 func newDeploymentWatcher(parent context.Context, queryLimiter *rate.Limiter,
-	logger *log.Logger, watchers DeploymentStateWatchers, d *structs.Deployment,
+	logger *log.Logger, state *state.StateStore, d *structs.Deployment,
 	j *structs.Job, triggers deploymentTriggers) *deploymentWatcher {

 	ctx, exitFn := context.WithCancel(parent)

@@ -85,7 +86,7 @@ func newDeploymentWatcher(parent context.Context, queryLimiter *rate.Limiter,
 		queryLimiter:       queryLimiter,
 		d:                  d,
 		j:                  j,
-		watchers:           watchers,
+		state:              state,
 		deploymentTriggers: triggers,
 		logger:             logger,
 		ctx:                ctx,
@@ -116,15 +117,19 @@ func (w *deploymentWatcher) SetAllocHealth(
 	}

 	// Get the allocations for the deployment
-	args := &structs.DeploymentSpecificRequest{DeploymentID: req.DeploymentID}
-	var resp structs.AllocListResponse
-	if err := w.watchers.Allocations(args, &resp); err != nil {
+	snap, err := w.state.Snapshot()
+	if err != nil {
+		return err
+	}
+
+	allocs, err := snap.AllocsByDeployment(nil, req.DeploymentID)
+	if err != nil {
 		return err
 	}

 	// Determine if we should autorevert to an older job
 	desc := structs.DeploymentStatusDescriptionFailedAllocations
-	for _, alloc := range resp.Allocations {
+	for _, alloc := range allocs {
 		// Check that the alloc has been marked unhealthy
 		if _, ok := unhealthy[alloc.ID]; !ok {
 			continue

@@ -295,7 +300,7 @@ func (w *deploymentWatcher) watch() {
 		// Block getting all allocations that are part of the deployment using
 		// the last evaluation index. This will have us block waiting for
 		// something to change past what the scheduler has evaluated.
-		allocResp, err := w.getAllocs(allocIndex)
+		allocs, index, err := w.getAllocs(allocIndex)
 		if err != nil {
 			if err == context.Canceled || w.ctx.Err() == context.Canceled {
 				return

@@ -304,7 +309,7 @@ func (w *deploymentWatcher) watch() {
 			w.logger.Printf("[ERR] nomad.deployment_watcher: failed to retrieve allocations for deployment %q: %v", w.d.ID, err)
 			return
 		}
-		allocIndex = allocResp.Index
+		allocIndex = index

 		// Get the latest evaluation index
 		latestEval, err := w.latestEvalIndex()

@@ -320,7 +325,7 @@ func (w *deploymentWatcher) watch() {
 		// Create an evaluation trigger if there is any allocation whose
 		// deployment status has been updated past the latest eval index.
 		createEval, failDeployment, rollback := false, false, false
-		for _, alloc := range allocResp.Allocations {
+		for _, alloc := range allocs {
 			if alloc.DeploymentStatus == nil || alloc.DeploymentStatus.ModifyIndex <= latestEval {
 				continue
 			}

@@ -379,21 +384,25 @@ func (w *deploymentWatcher) watch() {
 			}
 		} else if createEval {
 			// Create an eval to push the deployment along
-			w.createEvalBatched(allocResp.Index)
+			w.createEvalBatched(index)
 		}
 	}
 }

 // latestStableJob returns the latest stable job. It may be nil if none exist
 func (w *deploymentWatcher) latestStableJob() (*structs.Job, error) {
-	args := &structs.JobVersionsRequest{JobID: w.d.JobID}
-	var resp structs.JobVersionsResponse
-	if err := w.watchers.GetJobVersions(args, &resp); err != nil {
+	snap, err := w.state.Snapshot()
+	if err != nil {
+		return nil, err
+	}
+
+	versions, err := snap.JobVersionsByID(nil, w.d.JobID)
+	if err != nil {
 		return nil, err
 	}

 	var stable *structs.Job
-	for _, job := range resp.Versions {
+	for _, job := range versions {
 		if job.Stable {
 			stable = job
 			break
@@ -454,27 +463,42 @@ func (w *deploymentWatcher) getDeploymentStatusUpdate(status, desc string) *stru
 // getAllocs retrieves the allocations that are part of the deployment blocking
 // at the given index.
-func (w *deploymentWatcher) getAllocs(index uint64) (*structs.AllocListResponse, error) {
-	// Build the request
-	args := &structs.DeploymentSpecificRequest{
-		DeploymentID: w.d.ID,
-		QueryOptions: structs.QueryOptions{
-			MinQueryIndex: index,
-		},
-	}
-	var resp structs.AllocListResponse
-
-	for resp.Index <= index {
-		if err := w.queryLimiter.Wait(w.ctx); err != nil {
-			return nil, err
-		}
-
-		if err := w.watchers.Allocations(args, &resp); err != nil {
-			return nil, err
-		}
-	}
-
-	return &resp, nil
+func (w *deploymentWatcher) getAllocs(index uint64) ([]*structs.AllocListStub, uint64, error) {
+	resp, index, err := w.state.BlockingQuery(w.getAllocsImpl, index, w.ctx)
+	if err != nil {
+		return nil, 0, err
+	}
+	if err := w.ctx.Err(); err != nil {
+		return nil, 0, err
+	}
+
+	return resp.([]*structs.AllocListStub), index, nil
+}
+
+// getAllocsImpl retrieves the allocations that are part of the deployment from the passed state store.
+func (w *deploymentWatcher) getAllocsImpl(ws memdb.WatchSet, state *state.StateStore) (interface{}, uint64, error) {
+	if err := w.queryLimiter.Wait(w.ctx); err != nil {
+		return nil, 0, err
+	}
+
+	// Capture all the allocations
+	allocs, err := state.AllocsByDeployment(ws, w.d.ID)
+	if err != nil {
+		return nil, 0, err
+	}
+
+	stubs := make([]*structs.AllocListStub, 0, len(allocs))
+	for _, alloc := range allocs {
+		stubs = append(stubs, alloc.Stub())
+	}
+
+	// Use the last index that affected the allocs table
+	index, err := state.Index("allocs")
+	if err != nil {
+		return nil, index, err
+	}
+
+	return stubs, index, nil
 }

 // latestEvalIndex returns the index of the last evaluation created for
@@ -485,22 +509,26 @@ func (w *deploymentWatcher) latestEvalIndex() (uint64, error) {
 		return 0, err
 	}

-	args := &structs.JobSpecificRequest{
-		JobID: w.d.JobID,
-	}
-	var resp structs.JobEvaluationsResponse
-	err := w.watchers.Evaluations(args, &resp)
+	snap, err := w.state.Snapshot()
 	if err != nil {
 		return 0, err
 	}

-	if len(resp.Evaluations) == 0 {
-		w.setLatestEval(resp.Index)
-		return resp.Index, nil
+	evals, err := snap.EvalsByJob(nil, w.d.JobID)
+	if err != nil {
+		return 0, err
+	}
+
+	if len(evals) == 0 {
+		idx, err := snap.Index("evals")
+		if err != nil {
+			w.setLatestEval(idx)
+		}
+		return idx, err
 	}

 	// Prefer using the snapshot index. Otherwise use the create index
-	e := resp.Evaluations[0]
+	e := evals[0]
 	if e.SnapshotIndex != 0 {
 		w.setLatestEval(e.SnapshotIndex)
 		return e.SnapshotIndex, nil

View file

@@ -9,6 +9,8 @@ import (
 	"golang.org/x/time/rate"

+	memdb "github.com/hashicorp/go-memdb"
+	"github.com/hashicorp/nomad/nomad/state"
 	"github.com/hashicorp/nomad/nomad/structs"
 )

@@ -49,30 +51,6 @@ type DeploymentRaftEndpoints interface {
 	UpdateDeploymentAllocHealth(req *structs.ApplyDeploymentAllocHealthRequest) (uint64, error)
 }

-// DeploymentStateWatchers are the set of functions required to watch objects on
-// behalf of a deployment
-type DeploymentStateWatchers interface {
-	// Evaluations returns the set of evaluations for the given job
-	Evaluations(args *structs.JobSpecificRequest, reply *structs.JobEvaluationsResponse) error
-
-	// Allocations returns the set of allocations that are part of the
-	// deployment.
-	Allocations(args *structs.DeploymentSpecificRequest, reply *structs.AllocListResponse) error
-
-	// List is used to list all the deployments in the system
-	List(args *structs.DeploymentListRequest, reply *structs.DeploymentListResponse) error
-
-	// GetDeployment is used to lookup a particular deployment.
-	GetDeployment(args *structs.DeploymentSpecificRequest, reply *structs.SingleDeploymentResponse) error
-
-	// GetJobVersions is used to lookup the versions of a job. This is used when
-	// rolling back to find the latest stable job
-	GetJobVersions(args *structs.JobVersionsRequest, reply *structs.JobVersionsResponse) error
-
-	// GetJob is used to lookup a particular job.
-	GetJob(args *structs.JobSpecificRequest, reply *structs.SingleJobResponse) error
-}
-
 // Watcher is used to watch deployments and their allocations created
 // by the scheduler and trigger the scheduler when allocation health
 // transistions.

@@ -91,9 +69,8 @@ type Watcher struct {
 	// deployments watcher
 	raft DeploymentRaftEndpoints

-	// stateWatchers is the set of functions required to watch a deployment for
-	// state changes
-	stateWatchers DeploymentStateWatchers
+	// state is the state that is watched for state changes.
+	state *state.StateStore

 	// watchers is the set of active watchers, one per deployment
 	watchers map[string]*deploymentWatcher

@@ -110,12 +87,11 @@ type Watcher struct {
 // NewDeploymentsWatcher returns a deployments watcher that is used to watch
 // deployments and trigger the scheduler as needed.
-func NewDeploymentsWatcher(logger *log.Logger, watchers DeploymentStateWatchers,
+func NewDeploymentsWatcher(logger *log.Logger,
 	raft DeploymentRaftEndpoints, stateQueriesPerSecond float64,
 	evalBatchDuration time.Duration) *Watcher {

 	return &Watcher{
-		stateWatchers:     watchers,
 		raft:              raft,
 		queryLimiter:      rate.NewLimiter(rate.Limit(stateQueriesPerSecond), 100),
 		evalBatchDuration: evalBatchDuration,
@@ -124,14 +100,22 @@ func NewDeploymentsWatcher(logger *log.Logger, watchers DeploymentStateWatchers,
 	}

 // SetEnabled is used to control if the watcher is enabled. The watcher
-// should only be enabled on the active leader.
-func (w *Watcher) SetEnabled(enabled bool) error {
+// should only be enabled on the active leader. When being enabled the state is
+// passed in as it is no longer valid once a leader election has taken place.
+func (w *Watcher) SetEnabled(enabled bool, state *state.StateStore) error {
 	w.l.Lock()
 	defer w.l.Unlock()

 	wasEnabled := w.enabled
 	w.enabled = enabled

+	if state != nil {
+		if wasEnabled && w.state != nil {
+			panic("we may have a blocking")
+		}
+		w.state = state
+	}
+
 	// Flush the state to create the necessary objects
 	w.flush()
@@ -166,7 +150,7 @@ func (w *Watcher) watchDeployments(ctx context.Context) {
 	dindex := uint64(1)
 	for {
 		// Block getting all deployments using the last deployment index.
-		resp, err := w.getDeploys(ctx, dindex)
+		deployments, idx, err := w.getDeploys(ctx, dindex)
 		if err != nil {
 			if err == context.Canceled || ctx.Err() == context.Canceled {
 				return

@@ -175,14 +159,12 @@ func (w *Watcher) watchDeployments(ctx context.Context) {
 			w.logger.Printf("[ERR] nomad.deployments_watcher: failed to retrieve deploylements: %v", err)
 		}

-		// Guard against npe
-		if resp == nil {
-			continue
-		}
+		// Update the latest index
+		dindex = idx

 		// Ensure we are tracking the things we should and not tracking what we
 		// shouldn't be
-		for _, d := range resp.Deployments {
+		for _, d := range deployments {
 			if d.Active() {
 				if err := w.add(d); err != nil {
 					w.logger.Printf("[ERR] nomad.deployments_watcher: failed to track deployment %q: %v", d.ID, err)
@@ -191,33 +173,47 @@ func (w *Watcher) watchDeployments(ctx context.Context) {
 				w.remove(d)
 			}
 		}
-
-		// Update the latest index
-		dindex = resp.Index
 	}
 }

 // getDeploys retrieves all deployments blocking at the given index.
-func (w *Watcher) getDeploys(ctx context.Context, index uint64) (*structs.DeploymentListResponse, error) {
-	// Build the request
-	args := &structs.DeploymentListRequest{
-		QueryOptions: structs.QueryOptions{
-			MinQueryIndex: index,
-		},
-	}
-	var resp structs.DeploymentListResponse
-
-	for resp.Index <= index {
-		if err := w.queryLimiter.Wait(ctx); err != nil {
-			return nil, err
-		}
-
-		if err := w.stateWatchers.List(args, &resp); err != nil {
-			return nil, err
-		}
-	}
-
-	return &resp, nil
+func (w *Watcher) getDeploys(ctx context.Context, minIndex uint64) ([]*structs.Deployment, uint64, error) {
+	resp, index, err := w.state.BlockingQuery(w.getDeploysImpl, minIndex, ctx)
+	if err != nil {
+		return nil, 0, err
+	}
+	if err := ctx.Err(); err != nil {
+		return nil, 0, err
+	}
+
+	return resp.([]*structs.Deployment), index, nil
+}
+
+// getDeploysImpl retrieves all deployments from the passed state store.
+func (w *Watcher) getDeploysImpl(ws memdb.WatchSet, state *state.StateStore) (interface{}, uint64, error) {
+	iter, err := state.Deployments(ws)
+	if err != nil {
+		return nil, 0, err
+	}
+
+	var deploys []*structs.Deployment
+	for {
+		raw := iter.Next()
+		if raw == nil {
+			break
+		}
+		deploy := raw.(*structs.Deployment)
+		deploys = append(deploys, deploy)
+	}
+
+	// Use the last index that affected the deployment table
+	index, err := state.Index("deployment")
+	if err != nil {
+		return nil, 0, err
+	}
+
+	return deploys, index, nil
 }

 // add adds a deployment to the watch list
@@ -246,18 +242,20 @@ func (w *Watcher) addLocked(d *structs.Deployment) (*deploymentWatcher, error) {
 	}

 	// Get the job the deployment is referencing
-	args := &structs.JobSpecificRequest{
-		JobID: d.JobID,
-	}
-	var resp structs.SingleJobResponse
-	if err := w.stateWatchers.GetJob(args, &resp); err != nil {
+	snap, err := w.state.Snapshot()
+	if err != nil {
 		return nil, err
 	}
-	if resp.Job == nil {
+
+	job, err := snap.JobByID(nil, d.JobID)
+	if err != nil {
+		return nil, err
+	}
+	if job == nil {
 		return nil, fmt.Errorf("deployment %q references unknown job %q", d.ID, d.JobID)
 	}

-	watcher := newDeploymentWatcher(w.ctx, w.queryLimiter, w.logger, w.stateWatchers, d, resp.Job, w)
+	watcher := newDeploymentWatcher(w.ctx, w.queryLimiter, w.logger, w.state, d, job, w)
 	w.watchers[d.ID] = watcher
 	return watcher, nil
 }

@@ -283,18 +281,21 @@ func (w *Watcher) remove(d *structs.Deployment) {
 // a watcher. If the deployment does not exist or is terminal an error is
 // returned.
 func (w *Watcher) forceAdd(dID string) (*deploymentWatcher, error) {
-	// Build the request
-	args := &structs.DeploymentSpecificRequest{DeploymentID: dID}
-	var resp structs.SingleDeploymentResponse
-	if err := w.stateWatchers.GetDeployment(args, &resp); err != nil {
+	snap, err := w.state.Snapshot()
+	if err != nil {
 		return nil, err
 	}
-	if resp.Deployment == nil {
+
+	deployment, err := snap.DeploymentByID(nil, dID)
+	if err != nil {
+		return nil, err
+	}
+	if deployment == nil {
 		return nil, fmt.Errorf("unknown deployment %q", dID)
 	}
-	return w.addLocked(resp.Deployment)
+	return w.addLocked(deployment)
 }

 // getOrCreateWatcher returns the deployment watcher for the given deployment ID.

View file

@@ -16,7 +16,7 @@ import (
 func testDeploymentWatcher(t *testing.T, qps float64, batchDur time.Duration) (*Watcher, *mockBackend) {
 	m := newMockBackend(t)
-	w := NewDeploymentsWatcher(testLogger(), m, m, qps, batchDur)
+	w := NewDeploymentsWatcher(testLogger(), m, qps, batchDur)
 	return w, m
 }

@@ -30,23 +30,11 @@ func TestWatcher_WatchDeployments(t *testing.T) {
 	assert := assert.New(t)
 	w, m := defaultTestDeploymentWatcher(t)

-	// Return no allocations or evals
-	m.On("Allocations", mocker.Anything, mocker.Anything).Return(nil).Run(func(args mocker.Arguments) {
-		reply := args.Get(1).(*structs.AllocListResponse)
-		reply.Index = m.nextIndex()
-	})
-	m.On("Evaluations", mocker.Anything, mocker.Anything).Return(nil).Run(func(args mocker.Arguments) {
-		reply := args.Get(1).(*structs.JobEvaluationsResponse)
-		reply.Index = m.nextIndex()
-	})
-
 	// Create three jobs
 	j1, j2, j3 := mock.Job(), mock.Job(), mock.Job()
-	jobs := map[string]*structs.Job{
-		j1.ID: j1,
-		j2.ID: j2,
-		j3.ID: j3,
-	}
+	assert.Nil(m.state.UpsertJob(100, j1))
+	assert.Nil(m.state.UpsertJob(101, j2))
+	assert.Nil(m.state.UpsertJob(102, j3))

 	// Create three deployments all running
 	d1, d2, d3 := mock.Deployment(), mock.Deployment(), mock.Deployment()

@@ -54,46 +42,27 @@ func TestWatcher_WatchDeployments(t *testing.T) {
 	d2.JobID = j2.ID
 	d3.JobID = j3.ID

-	m.On("GetJob", mocker.Anything, mocker.Anything).
-		Return(nil).Run(func(args mocker.Arguments) {
-		in := args.Get(0).(*structs.JobSpecificRequest)
-		reply := args.Get(1).(*structs.SingleJobResponse)
-		reply.Job = jobs[in.JobID]
-		reply.Index = reply.Job.ModifyIndex
-	})
-
-	// Set up the calls for retrieving deployments
-	m.On("List", mocker.Anything, mocker.Anything).Return(nil).Run(func(args mocker.Arguments) {
-		reply := args.Get(1).(*structs.DeploymentListResponse)
-		reply.Deployments = []*structs.Deployment{d1}
-		reply.Index = m.nextIndex()
-	}).Once()
+	// Upsert the first deployment
+	assert.Nil(m.state.UpsertDeployment(103, d1))

 	// Next list 3
 	block1 := make(chan time.Time)
-	m.On("List", mocker.Anything, mocker.Anything).Return(nil).Run(func(args mocker.Arguments) {
-		reply := args.Get(1).(*structs.DeploymentListResponse)
-		reply.Deployments = []*structs.Deployment{d1, d2, d3}
-		reply.Index = m.nextIndex()
-	}).Once().WaitUntil(block1)
+	go func() {
+		<-block1
+		assert.Nil(m.state.UpsertDeployment(104, d2))
+		assert.Nil(m.state.UpsertDeployment(105, d3))
+	}()

 	//// Next list 3 but have one be terminal
 	block2 := make(chan time.Time)
 	d3terminal := d3.Copy()
 	d3terminal.Status = structs.DeploymentStatusFailed
-	m.On("List", mocker.Anything, mocker.Anything).Return(nil).Run(func(args mocker.Arguments) {
-		reply := args.Get(1).(*structs.DeploymentListResponse)
-		reply.Deployments = []*structs.Deployment{d1, d2, d3terminal}
-		reply.Index = m.nextIndex()
-	}).WaitUntil(block2)
-
-	m.On("List", mocker.Anything, mocker.Anything).Return(nil).Run(func(args mocker.Arguments) {
-		reply := args.Get(1).(*structs.DeploymentListResponse)
-		reply.Deployments = []*structs.Deployment{d1, d2, d3terminal}
-		reply.Index = m.nextIndex()
-	})
-
-	w.SetEnabled(true)
+	go func() {
+		<-block2
+		assert.Nil(m.state.UpsertDeployment(106, d3terminal))
+	}()
+
+	w.SetEnabled(true, m.state)

 	testutil.WaitForResult(func() (bool, error) { return 1 == len(w.watchers), nil },
 		func(err error) { assert.Equal(1, len(w.watchers), "1 deployment returned") })
@@ -111,17 +80,7 @@ func TestWatcher_UnknownDeployment(t *testing.T) {
 	t.Parallel()
 	assert := assert.New(t)
 	w, m := defaultTestDeploymentWatcher(t)
-	w.SetEnabled(true)
-
-	// Set up the calls for retrieving deployments
-	m.On("List", mocker.Anything, mocker.Anything).Return(nil).Run(func(args mocker.Arguments) {
-		reply := args.Get(1).(*structs.DeploymentListResponse)
-		reply.Index = m.nextIndex()
-	})
-
-	m.On("GetDeployment", mocker.Anything, mocker.Anything).Return(nil).Run(func(args mocker.Arguments) {
-		reply := args.Get(1).(*structs.SingleDeploymentResponse)
-		reply.Index = m.nextIndex()
-	})
+	w.SetEnabled(true, m.state)

 	// The expected error is that it should be an unknown deployment
 	dID := structs.GenerateUUID()
@@ -181,16 +140,7 @@ func TestWatcher_SetAllocHealth_Unknown(t *testing.T) {
 	assert.Nil(m.state.UpsertJob(m.nextIndex(), j), "UpsertJob")
 	assert.Nil(m.state.UpsertDeployment(m.nextIndex(), d), "UpsertDeployment")

-	// Assert the following methods will be called
-	m.On("List", mocker.Anything, mocker.Anything).Return(nil).Run(m.listFromState)
-	m.On("Allocations", mocker.MatchedBy(matchDeploymentSpecificRequest(d.ID)),
-		mocker.Anything).Return(nil).Run(m.allocationsFromState)
-	m.On("Evaluations", mocker.MatchedBy(matchJobSpecificRequest(j.ID)),
-		mocker.Anything).Return(nil).Run(m.evaluationsFromState)
-	m.On("GetJob", mocker.MatchedBy(matchJobSpecificRequest(j.ID)),
-		mocker.Anything).Return(nil).Run(m.getJobFromState)
-
-	w.SetEnabled(true)
+	w.SetEnabled(true, m.state)

 	testutil.WaitForResult(func() (bool, error) { return 1 == len(w.watchers), nil },
 		func(err error) { assert.Equal(1, len(w.watchers), "Should have 1 deployment") })

@@ -233,16 +183,7 @@ func TestWatcher_SetAllocHealth_Healthy(t *testing.T) {
 	assert.Nil(m.state.UpsertDeployment(m.nextIndex(), d), "UpsertDeployment")
 	assert.Nil(m.state.UpsertAllocs(m.nextIndex(), []*structs.Allocation{a}), "UpsertAllocs")

-	// Assert the following methods will be called
-	m.On("List", mocker.Anything, mocker.Anything).Return(nil).Run(m.listFromState)
-	m.On("Allocations", mocker.MatchedBy(matchDeploymentSpecificRequest(d.ID)),
-		mocker.Anything).Return(nil).Run(m.allocationsFromState)
-	m.On("Evaluations", mocker.MatchedBy(matchJobSpecificRequest(j.ID)),
-		mocker.Anything).Return(nil).Run(m.evaluationsFromState)
-	m.On("GetJob", mocker.MatchedBy(matchJobSpecificRequest(j.ID)),
-		mocker.Anything).Return(nil).Run(m.getJobFromState)
-
-	w.SetEnabled(true)
+	w.SetEnabled(true, m.state)

 	testutil.WaitForResult(func() (bool, error) { return 1 == len(w.watchers), nil },
 		func(err error) { assert.Equal(1, len(w.watchers), "Should have 1 deployment") })

@@ -283,16 +224,7 @@ func TestWatcher_SetAllocHealth_Unhealthy(t *testing.T) {
 	assert.Nil(m.state.UpsertDeployment(m.nextIndex(), d), "UpsertDeployment")
 	assert.Nil(m.state.UpsertAllocs(m.nextIndex(), []*structs.Allocation{a}), "UpsertAllocs")

-	// Assert the following methods will be called
-	m.On("List", mocker.Anything, mocker.Anything).Return(nil).Run(m.listFromState)
-	m.On("Allocations", mocker.MatchedBy(matchDeploymentSpecificRequest(d.ID)),
-		mocker.Anything).Return(nil).Run(m.allocationsFromState)
-	m.On("Evaluations", mocker.MatchedBy(matchJobSpecificRequest(j.ID)),
-		mocker.Anything).Return(nil).Run(m.evaluationsFromState)
-	m.On("GetJob", mocker.MatchedBy(matchJobSpecificRequest(j.ID)),
-		mocker.Anything).Return(nil).Run(m.getJobFromState)
-
-	w.SetEnabled(true)
+	w.SetEnabled(true, m.state)

 	testutil.WaitForResult(func() (bool, error) { return 1 == len(w.watchers), nil },
 		func(err error) { assert.Equal(1, len(w.watchers), "Should have 1 deployment") })
@@ -350,18 +282,7 @@ func TestWatcher_SetAllocHealth_Unhealthy_Rollback(t *testing.T) {
 	j2.Stable = false
 	assert.Nil(m.state.UpsertJob(m.nextIndex(), j2), "UpsertJob2")

-	// Assert the following methods will be called
-	m.On("List", mocker.Anything, mocker.Anything).Return(nil).Run(m.listFromState)
-	m.On("Allocations", mocker.MatchedBy(matchDeploymentSpecificRequest(d.ID)),
-		mocker.Anything).Return(nil).Run(m.allocationsFromState)
-	m.On("Evaluations", mocker.MatchedBy(matchJobSpecificRequest(j.ID)),
-		mocker.Anything).Return(nil).Run(m.evaluationsFromState)
-	m.On("GetJob", mocker.MatchedBy(matchJobSpecificRequest(j.ID)),
-		mocker.Anything).Return(nil).Run(m.getJobFromState)
-	m.On("GetJobVersions", mocker.MatchedBy(matchJobVersionsRequest(j.ID)),
-		mocker.Anything).Return(nil).Run(m.getJobVersionsFromState)
-
-	w.SetEnabled(true)
+	w.SetEnabled(true, m.state)

 	testutil.WaitForResult(func() (bool, error) { return 1 == len(w.watchers), nil },
 		func(err error) { assert.Equal(1, len(w.watchers), "Should have 1 deployment") })

@@ -417,16 +338,7 @@ func TestWatcher_PromoteDeployment_HealthyCanaries(t *testing.T) {
 	assert.Nil(m.state.UpsertDeployment(m.nextIndex(), d), "UpsertDeployment")
 	assert.Nil(m.state.UpsertAllocs(m.nextIndex(), []*structs.Allocation{a}), "UpsertAllocs")

-	// Assert the following methods will be called
-	m.On("List", mocker.Anything, mocker.Anything).Return(nil).Run(m.listFromState)
-	m.On("Allocations", mocker.MatchedBy(matchDeploymentSpecificRequest(d.ID)),
-		mocker.Anything).Return(nil).Run(m.allocationsFromState)
-	m.On("Evaluations", mocker.MatchedBy(matchJobSpecificRequest(j.ID)),
-		mocker.Anything).Return(nil).Run(m.evaluationsFromState)
-	m.On("GetJob", mocker.MatchedBy(matchJobSpecificRequest(j.ID)),
-		mocker.Anything).Return(nil).Run(m.getJobFromState)
-
-	w.SetEnabled(true)
+	w.SetEnabled(true, m.state)

 	testutil.WaitForResult(func() (bool, error) { return 1 == len(w.watchers), nil },
 		func(err error) { assert.Equal(1, len(w.watchers), "Should have 1 deployment") })

@@ -473,16 +385,7 @@ func TestWatcher_PromoteDeployment_UnhealthyCanaries(t *testing.T) {
 	assert.Nil(m.state.UpsertDeployment(m.nextIndex(), d), "UpsertDeployment")
 	assert.Nil(m.state.UpsertAllocs(m.nextIndex(), []*structs.Allocation{a}), "UpsertAllocs")

-	// Assert the following methods will be called
-	m.On("List", mocker.Anything, mocker.Anything).Return(nil).Run(m.listFromState)
-	m.On("Allocations", mocker.MatchedBy(matchDeploymentSpecificRequest(d.ID)),
-		mocker.Anything).Return(nil).Run(m.allocationsFromState)
-	m.On("Evaluations", mocker.MatchedBy(matchJobSpecificRequest(j.ID)),
-		mocker.Anything).Return(nil).Run(m.evaluationsFromState)
-	m.On("GetJob", mocker.MatchedBy(matchJobSpecificRequest(j.ID)),
-		mocker.Anything).Return(nil).Run(m.getJobFromState)
-
-	w.SetEnabled(true)
+	w.SetEnabled(true, m.state)

 	testutil.WaitForResult(func() (bool, error) { return 1 == len(w.watchers), nil },
 		func(err error) { assert.Equal(1, len(w.watchers), "Should have 1 deployment") })

@@ -525,16 +428,7 @@ func TestWatcher_PauseDeployment_Pause_Running(t *testing.T) {
 	assert.Nil(m.state.UpsertJob(m.nextIndex(), j), "UpsertJob")
 	assert.Nil(m.state.UpsertDeployment(m.nextIndex(), d), "UpsertDeployment")

-	// Assert the following methods will be called
-	m.On("List", mocker.Anything, mocker.Anything).Return(nil).Run(m.listFromState)
-	m.On("Allocations", mocker.MatchedBy(matchDeploymentSpecificRequest(d.ID)),
-		mocker.Anything).Return(nil).Run(m.allocationsFromState)
-	m.On("Evaluations", mocker.MatchedBy(matchJobSpecificRequest(j.ID)),
-		mocker.Anything).Return(nil).Run(m.evaluationsFromState)
-	m.On("GetJob", mocker.MatchedBy(matchJobSpecificRequest(j.ID)),
-		mocker.Anything).Return(nil).Run(m.getJobFromState)
-
-	w.SetEnabled(true)
+	w.SetEnabled(true, m.state)

 	testutil.WaitForResult(func() (bool, error) { return 1 == len(w.watchers), nil },
 		func(err error) { assert.Equal(1, len(w.watchers), "Should have 1 deployment") })
@@ -574,16 +468,7 @@ func TestWatcher_PauseDeployment_Pause_Paused(t *testing.T) {
 	assert.Nil(m.state.UpsertJob(m.nextIndex(), j), "UpsertJob")
 	assert.Nil(m.state.UpsertDeployment(m.nextIndex(), d), "UpsertDeployment")

-	// Assert the following methods will be called
-	m.On("List", mocker.Anything, mocker.Anything).Return(nil).Run(m.listFromState)
-	m.On("Allocations", mocker.MatchedBy(matchDeploymentSpecificRequest(d.ID)),
-		mocker.Anything).Return(nil).Run(m.allocationsFromState)
-	m.On("Evaluations", mocker.MatchedBy(matchJobSpecificRequest(j.ID)),
-		mocker.Anything).Return(nil).Run(m.evaluationsFromState)
-	m.On("GetJob", mocker.MatchedBy(matchJobSpecificRequest(j.ID)),
-		mocker.Anything).Return(nil).Run(m.getJobFromState)
-
-	w.SetEnabled(true)
+	w.SetEnabled(true, m.state)

 	testutil.WaitForResult(func() (bool, error) { return 1 == len(w.watchers), nil },
 		func(err error) { assert.Equal(1, len(w.watchers), "Should have 1 deployment") })

@@ -623,16 +508,7 @@ func TestWatcher_PauseDeployment_Unpause_Paused(t *testing.T) {
 	assert.Nil(m.state.UpsertJob(m.nextIndex(), j), "UpsertJob")
 	assert.Nil(m.state.UpsertDeployment(m.nextIndex(), d), "UpsertDeployment")

-	// Assert the following methods will be called
-	m.On("List", mocker.Anything, mocker.Anything).Return(nil).Run(m.listFromState)
-	m.On("Allocations", mocker.MatchedBy(matchDeploymentSpecificRequest(d.ID)),
-		mocker.Anything).Return(nil).Run(m.allocationsFromState)
-	m.On("Evaluations", mocker.MatchedBy(matchJobSpecificRequest(j.ID)),
-		mocker.Anything).Return(nil).Run(m.evaluationsFromState)
-	m.On("GetJob", mocker.MatchedBy(matchJobSpecificRequest(j.ID)),
-		mocker.Anything).Return(nil).Run(m.getJobFromState)
-
-	w.SetEnabled(true)
+	w.SetEnabled(true, m.state)

 	testutil.WaitForResult(func() (bool, error) { return 1 == len(w.watchers), nil },
 		func(err error) { assert.Equal(1, len(w.watchers), "Should have 1 deployment") })

@@ -672,16 +548,7 @@ func TestWatcher_PauseDeployment_Unpause_Running(t *testing.T) {
 	assert.Nil(m.state.UpsertJob(m.nextIndex(), j), "UpsertJob")
 	assert.Nil(m.state.UpsertDeployment(m.nextIndex(), d), "UpsertDeployment")

-	// Assert the following methods will be called
-	m.On("List", mocker.Anything, mocker.Anything).Return(nil).Run(m.listFromState)
-	m.On("Allocations", mocker.MatchedBy(matchDeploymentSpecificRequest(d.ID)),
-		mocker.Anything).Return(nil).Run(m.allocationsFromState)
-	m.On("Evaluations", mocker.MatchedBy(matchJobSpecificRequest(j.ID)),
-		mocker.Anything).Return(nil).Run(m.evaluationsFromState)
-	m.On("GetJob", mocker.MatchedBy(matchJobSpecificRequest(j.ID)),
-		mocker.Anything).Return(nil).Run(m.getJobFromState)
-
-	w.SetEnabled(true)
+	w.SetEnabled(true, m.state)

 	testutil.WaitForResult(func() (bool, error) { return 1 == len(w.watchers), nil },
 		func(err error) { assert.Equal(1, len(w.watchers), "Should have 1 deployment") })

@@ -721,16 +588,7 @@ func TestWatcher_FailDeployment_Running(t *testing.T) {
 	assert.Nil(m.state.UpsertJob(m.nextIndex(), j), "UpsertJob")
 	assert.Nil(m.state.UpsertDeployment(m.nextIndex(), d), "UpsertDeployment")

-	// Assert the following methods will be called
-	m.On("List", mocker.Anything, mocker.Anything).Return(nil).Run(m.listFromState)
-	m.On("Allocations", mocker.MatchedBy(matchDeploymentSpecificRequest(d.ID)),
-		mocker.Anything).Return(nil).Run(m.allocationsFromState)
-	m.On("Evaluations", mocker.MatchedBy(matchJobSpecificRequest(j.ID)),
-		mocker.Anything).Return(nil).Run(m.evaluationsFromState)
-	m.On("GetJob", mocker.MatchedBy(matchJobSpecificRequest(j.ID)),
-		mocker.Anything).Return(nil).Run(m.getJobFromState)
-
-	w.SetEnabled(true)
+	w.SetEnabled(true, m.state)

 	testutil.WaitForResult(func() (bool, error) { return 1 == len(w.watchers), nil },
 		func(err error) { assert.Equal(1, len(w.watchers), "Should have 1 deployment") })
@@ -783,18 +641,7 @@ func TestDeploymentWatcher_Watch(t *testing.T) {
 	j2.Stable = false
 	assert.Nil(m.state.UpsertJob(m.nextIndex(), j2), "UpsertJob2")

-	// Assert the following methods will be called
-	m.On("List", mocker.Anything, mocker.Anything).Return(nil).Run(m.listFromState)
-	m.On("Allocations", mocker.MatchedBy(matchDeploymentSpecificRequest(d.ID)),
-		mocker.Anything).Return(nil).Run(m.allocationsFromState)
-	m.On("Evaluations", mocker.MatchedBy(matchJobSpecificRequest(j.ID)),
-		mocker.Anything).Return(nil).Run(m.evaluationsFromState)
-	m.On("GetJob", mocker.MatchedBy(matchJobSpecificRequest(j.ID)),
-		mocker.Anything).Return(nil).Run(m.getJobFromState)
-	m.On("GetJobVersions", mocker.MatchedBy(matchJobVersionsRequest(j.ID)),
-		mocker.Anything).Return(nil).Run(m.getJobVersionsFromState)
-
-	w.SetEnabled(true)
+	w.SetEnabled(true, m.state)

 	testutil.WaitForResult(func() (bool, error) { return 1 == len(w.watchers), nil },
 		func(err error) { assert.Equal(1, len(w.watchers), "Should have 1 deployment") })

@@ -912,30 +759,7 @@ func TestWatcher_BatchEvals(t *testing.T) {
 	assert.Nil(m.state.UpsertAllocs(m.nextIndex(), []*structs.Allocation{a1}), "UpsertAllocs")
 	assert.Nil(m.state.UpsertAllocs(m.nextIndex(), []*structs.Allocation{a2}), "UpsertAllocs")

-	// Assert the following methods will be called
-	m.On("List", mocker.Anything, mocker.Anything).Return(nil).Run(m.listFromState)
-	m.On("Allocations", mocker.MatchedBy(matchDeploymentSpecificRequest(d1.ID)),
-		mocker.Anything).Return(nil).Run(m.allocationsFromState)
-	m.On("Allocations", mocker.MatchedBy(matchDeploymentSpecificRequest(d2.ID)),
-		mocker.Anything).Return(nil).Run(m.allocationsFromState)
-	m.On("Evaluations", mocker.MatchedBy(matchJobSpecificRequest(j1.ID)),
-		mocker.Anything).Return(nil).Run(m.evaluationsFromState)
-	m.On("Evaluations", mocker.MatchedBy(matchJobSpecificRequest(j2.ID)),
-		mocker.Anything).Return(nil).Run(m.evaluationsFromState)
-	m.On("GetJob", mocker.MatchedBy(matchJobSpecificRequest(j1.ID)),
-		mocker.Anything).Return(nil).Run(m.getJobFromState)
-	m.On("GetJob", mocker.MatchedBy(matchJobSpecificRequest(j2.ID)),
-		mocker.Anything).Return(nil).Run(m.getJobFromState)
-	m.On("GetJobVersions", mocker.MatchedBy(matchJobVersionsRequest(j1.ID)),
-		mocker.Anything).Return(nil).Run(m.getJobVersionsFromState)
-	m.On("GetJobVersions", mocker.MatchedBy(matchJobVersionsRequest(j2.ID)),
-		mocker.Anything).Return(nil).Run(m.getJobVersionsFromState)
-
-	w.SetEnabled(true)
+	w.SetEnabled(true, m.state)

 	testutil.WaitForResult(func() (bool, error) { return 2 == len(w.watchers), nil },
 		func(err error) { assert.Equal(2, len(w.watchers), "Should have 2 deployment") })

View file

@@ -8,7 +8,6 @@ import (
 	"sync"
 	"testing"

-	memdb "github.com/hashicorp/go-memdb"
 	"github.com/hashicorp/nomad/nomad/state"
 	"github.com/hashicorp/nomad/nomad/structs"
 	mocker "github.com/stretchr/testify/mock"

@@ -256,114 +255,3 @@ func matchDeploymentAllocHealthRequest(c *matchDeploymentAllocHealthRequestConfi
 		return true
 	}
 }
-
-func (m *mockBackend) Evaluations(args *structs.JobSpecificRequest, reply *structs.JobEvaluationsResponse) error {
-	rargs := m.Called(args, reply)
-	return rargs.Error(0)
-}
-
-func (m *mockBackend) evaluationsFromState(in mocker.Arguments) {
-	args, reply := in.Get(0).(*structs.JobSpecificRequest), in.Get(1).(*structs.JobEvaluationsResponse)
-	ws := memdb.NewWatchSet()
-	evals, _ := m.state.EvalsByJob(ws, args.JobID)
-	reply.Evaluations = evals
-	reply.Index, _ = m.state.Index("evals")
-}
-
-func (m *mockBackend) Allocations(args *structs.DeploymentSpecificRequest, reply *structs.AllocListResponse) error {
-	rargs := m.Called(args, reply)
-	return rargs.Error(0)
-}
-
-func (m *mockBackend) allocationsFromState(in mocker.Arguments) {
-	args, reply := in.Get(0).(*structs.DeploymentSpecificRequest), in.Get(1).(*structs.AllocListResponse)
-	ws := memdb.NewWatchSet()
-	allocs, _ := m.state.AllocsByDeployment(ws, args.DeploymentID)
-
-	var stubs []*structs.AllocListStub
-	for _, a := range allocs {
-		stubs = append(stubs, a.Stub())
-	}
-
-	reply.Allocations = stubs
-	reply.Index, _ = m.state.Index("allocs")
-}
-
-func (m *mockBackend) List(args *structs.DeploymentListRequest, reply *structs.DeploymentListResponse) error {
-	rargs := m.Called(args, reply)
-	return rargs.Error(0)
-}
-
-func (m *mockBackend) listFromState(in mocker.Arguments) {
-	reply := in.Get(1).(*structs.DeploymentListResponse)
-	ws := memdb.NewWatchSet()
-	iter, _ := m.state.Deployments(ws)
-
-	var deploys []*structs.Deployment
-	for {
-		raw := iter.Next()
-		if raw == nil {
-			break
-		}
-
-		deploys = append(deploys, raw.(*structs.Deployment))
-	}
-
-	reply.Deployments = deploys
-	reply.Index, _ = m.state.Index("deployment")
-}
-
-func (m *mockBackend) GetDeployment(args *structs.DeploymentSpecificRequest, reply *structs.SingleDeploymentResponse) error {
-	rargs := m.Called(args, reply)
-	return rargs.Error(0)
-}
-
-func (m *mockBackend) GetJobVersions(args *structs.JobVersionsRequest, reply *structs.JobVersionsResponse) error {
-	rargs := m.Called(args, reply)
-	return rargs.Error(0)
-}
-
-func (m *mockBackend) getJobVersionsFromState(in mocker.Arguments) {
-	args, reply := in.Get(0).(*structs.JobVersionsRequest), in.Get(1).(*structs.JobVersionsResponse)
-	ws := memdb.NewWatchSet()
-	versions, _ := m.state.JobVersionsByID(ws, args.JobID)
-	reply.Versions = versions
-	reply.Index, _ = m.state.Index("jobs")
-}
-
-func (m *mockBackend) GetJob(args *structs.JobSpecificRequest, reply *structs.SingleJobResponse) error {
-	rargs := m.Called(args, reply)
-	return rargs.Error(0)
-}
-
-func (m *mockBackend) getJobFromState(in mocker.Arguments) {
-	args, reply := in.Get(0).(*structs.JobSpecificRequest), in.Get(1).(*structs.SingleJobResponse)
-	ws := memdb.NewWatchSet()
-	job, _ := m.state.JobByID(ws, args.JobID)
-	reply.Job = job
-	reply.Index, _ = m.state.Index("jobs")
-}
-
-// matchDeploymentSpecificRequest is used to match that a deployment specific
-// request is for the passed deployment id
-func matchDeploymentSpecificRequest(dID string) func(args *structs.DeploymentSpecificRequest) bool {
-	return func(args *structs.DeploymentSpecificRequest) bool {
-		return args.DeploymentID == dID
-	}
-}
-
-// matchJobSpecificRequest is used to match that a job specific
-// request is for the passed job id
-func matchJobSpecificRequest(jID string) func(args *structs.JobSpecificRequest) bool {
-	return func(args *structs.JobSpecificRequest) bool {
-		return args.JobID == jID
-	}
-}
-
-// matchJobVersionsRequest is used to match that a job version
-// request is for the passed job id
-func matchJobVersionsRequest(jID string) func(args *structs.JobVersionsRequest) bool {
-	return func(args *structs.JobVersionsRequest) bool {
-		return args.JobID == jID
-	}
-}

View file

@@ -131,7 +131,7 @@ func (s *Server) establishLeadership(stopCh chan struct{}) error {
 	s.blockedEvals.SetEnabled(true)

 	// Enable the deployment watcher, since we are now the leader
-	if err := s.deploymentWatcher.SetEnabled(true); err != nil {
+	if err := s.deploymentWatcher.SetEnabled(true, s.State()); err != nil {
 		return err
 	}

@@ -494,7 +494,7 @@ func (s *Server) revokeLeadership() error {
 	s.vault.SetActive(false)

 	// Disable the deployment watcher as it is only useful as a leader.
-	if err := s.deploymentWatcher.SetEnabled(false); err != nil {
+	if err := s.deploymentWatcher.SetEnabled(false, nil); err != nil {
 		return err
 	}

View file

@@ -1,6 +1,7 @@
 package nomad

 import (
+	"context"
 	"crypto/tls"
 	"fmt"
 	"io"

@@ -341,7 +342,9 @@ type blockingOptions struct {
 // blockingRPC is used for queries that need to wait for a
 // minimum index. This is used to block and wait for changes.
 func (s *Server) blockingRPC(opts *blockingOptions) error {
-	var timeout *time.Timer
+	var deadline time.Time
+	ctx := context.Background()
+	var cancel context.CancelFunc
 	var state *state.StateStore

 	// Fast path non-blocking

@@ -360,8 +363,9 @@ func (s *Server) blockingRPC(opts *blockingOptions) error {
 	opts.queryOpts.MaxQueryTime += lib.RandomStagger(opts.queryOpts.MaxQueryTime / jitterFraction)

 	// Setup a query timeout
-	timeout = time.NewTimer(opts.queryOpts.MaxQueryTime)
-	defer timeout.Stop()
+	deadline = time.Now().Add(opts.queryOpts.MaxQueryTime)
+	ctx, cancel = context.WithDeadline(context.Background(), deadline)
+	defer cancel()

 RUN_QUERY:
 	// Update the query meta data

@@ -393,7 +397,7 @@ RUN_QUERY:
 	// Check for minimum query time
 	if err == nil && opts.queryOpts.MinQueryIndex > 0 && opts.queryMeta.Index <= opts.queryOpts.MinQueryIndex {
-		if expired := ws.Watch(timeout.C); !expired {
+		if expired := ws.WatchCtx(ctx); !expired {
 			goto RUN_QUERY
 		}
 	}

View file

@@ -679,23 +679,15 @@ func (s *Server) setupConsulSyncer() error {
 // shim that provides the appropriate methods.
 func (s *Server) setupDeploymentWatcher() error {

-	// Create the shims
-	stateShim := &deploymentWatcherStateShim{
-		region:         s.Region(),
-		evaluations:    s.endpoints.Job.Evaluations,
-		allocations:    s.endpoints.Deployment.Allocations,
-		list:           s.endpoints.Deployment.List,
-		getDeployment:  s.endpoints.Deployment.GetDeployment,
-		getJobVersions: s.endpoints.Job.GetJobVersions,
-		getJob:         s.endpoints.Job.GetJob,
-	}
+	// Create the raft shim type to restrict the set of raft methods that can be
+	// made
 	raftShim := &deploymentWatcherRaftShim{
 		apply: s.raftApply,
 	}

 	// Create the deployment watcher
 	s.deploymentWatcher = deploymentwatcher.NewDeploymentsWatcher(
-		s.logger, stateShim, raftShim,
+		s.logger, raftShim,
 		deploymentwatcher.LimitStateQueriesPerSecond,
 		deploymentwatcher.CrossDeploymentEvalBatchDuration)
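For orientation, here is a minimal sketch of the wiring after this change: the watcher is constructed with only the logger and the raft endpoints, and the state store is handed over when the watcher is enabled on leadership. The helper name, package name, and logger setup are illustrative assumptions; only the NewDeploymentsWatcher and SetEnabled signatures come from the diff above.

package example

import (
	"log"
	"os"

	"github.com/hashicorp/nomad/nomad/deploymentwatcher"
	"github.com/hashicorp/nomad/nomad/state"
)

// enableOnLeadership is a hypothetical helper showing the new flow: no state
// watchers at construction time, the state store is supplied on SetEnabled.
func enableOnLeadership(raft deploymentwatcher.DeploymentRaftEndpoints, store *state.StateStore) (*deploymentwatcher.Watcher, error) {
	logger := log.New(os.Stderr, "", log.LstdFlags)

	w := deploymentwatcher.NewDeploymentsWatcher(
		logger, raft,
		deploymentwatcher.LimitStateQueriesPerSecond,
		deploymentwatcher.CrossDeploymentEvalBatchDuration)

	// The state store is only valid for the current leadership term, so it is
	// passed here rather than at construction; on losing leadership the server
	// calls SetEnabled(false, nil), as in the leader.go hunk above.
	if err := w.SetEnabled(true, store); err != nil {
		return nil, err
	}
	return w, nil
}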

View file

@@ -1,6 +1,7 @@
 package state

 import (
+	"context"
 	"fmt"
 	"io"
 	"log"

@@ -87,6 +88,48 @@ func (s *StateStore) Abandon() {
 	close(s.abandonCh)
 }

+// QueryFn is the definition of a function that can be used to implement a basic
+// blocking query against the state store.
+type QueryFn func(memdb.WatchSet, *StateStore) (resp interface{}, index uint64, err error)
+
+// BlockingQuery takes a query function and runs the function until the minimum
+// query index is met or until the passed context is cancelled.
+func (s *StateStore) BlockingQuery(query QueryFn, minIndex uint64, ctx context.Context) (
+	resp interface{}, index uint64, err error) {
+
+RUN_QUERY:
+	// We capture the state store and its abandon channel but pass a snapshot to
+	// the blocking query function. We operate on the snapshot to allow separate
+	// calls to the state store not all wrapped within the same transaction.
+	abandonCh := s.AbandonCh()
+	snap, _ := s.Snapshot()
+	stateSnap := &snap.StateStore
+
+	// We can skip all watch tracking if this isn't a blocking query.
+	var ws memdb.WatchSet
+	if minIndex > 0 {
+		ws = memdb.NewWatchSet()
+
+		// This channel will be closed if a snapshot is restored and the
+		// whole state store is abandoned.
+		ws.Add(abandonCh)
+	}
+
+	resp, index, err = query(ws, stateSnap)
+	if err != nil {
+		return nil, index, err
+	}
+
+	// We haven't reached the min-index yet.
+	if minIndex > 0 && index <= minIndex {
+		if expired := ws.WatchCtx(ctx); !expired {
+			goto RUN_QUERY
+		}
+	}
+
+	return resp, index, nil
+}
+
 // UpsertPlanResults is used to upsert the results of a plan.
 func (s *StateStore) UpsertPlanResults(index uint64, results *structs.ApplyPlanResultsRequest) error {
 	txn := s.db.Txn(true)
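As a usage illustration, the sketch below drives BlockingQuery the same way the deployments watcher does: the query function reads the deployment table through the watch set and returns that table's index, so the call blocks until the table moves past minIndex or the context ends. It is written as an extra test in this package so it can reuse testStateStore, mock, and the testify helpers already used by this commit's tests; the test name and timing values are illustrative assumptions.

func TestStateStore_Blocking_DeploymentsExample(t *testing.T) {
	store := testStateStore(t)

	// Iterating through the watch set registers the deployment table's watch
	// channels, which is what lets BlockingQuery wake up on a change.
	queryFn := func(ws memdb.WatchSet, s *StateStore) (interface{}, uint64, error) {
		iter, err := s.Deployments(ws)
		if err != nil {
			return nil, 0, err
		}
		n := 0
		for raw := iter.Next(); raw != nil; raw = iter.Next() {
			n++
		}
		idx, err := s.Index("deployment")
		return n, idx, err
	}

	// Write a deployment at index 2 shortly after the query starts blocking at index 1.
	d := mock.Deployment()
	time.AfterFunc(5*time.Millisecond, func() { store.UpsertDeployment(2, d) })

	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()

	resp, idx, err := store.BlockingQuery(queryFn, 1, ctx)
	assert.Nil(t, err)
	assert.EqualValues(t, 2, idx)
	assert.Equal(t, 1, resp.(int))
}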

View file

@@ -1,6 +1,7 @@
 package state

 import (
+	"context"
 	"fmt"
 	"os"
 	"reflect"

@@ -27,6 +28,69 @@ func testStateStore(t *testing.T) *StateStore {
 	return state
 }

+func TestStateStore_Blocking_Error(t *testing.T) {
+	expected := fmt.Errorf("test error")
+	errFn := func(memdb.WatchSet, *StateStore) (interface{}, uint64, error) {
+		return nil, 0, expected
+	}
+
+	state := testStateStore(t)
+	_, idx, err := state.BlockingQuery(errFn, 10, context.Background())
+	assert.EqualError(t, err, expected.Error())
+	assert.Zero(t, idx)
+}
+
+func TestStateStore_Blocking_Timeout(t *testing.T) {
+	noopFn := func(memdb.WatchSet, *StateStore) (interface{}, uint64, error) {
+		return nil, 5, nil
+	}
+
+	state := testStateStore(t)
+	timeout := time.Now().Add(10 * time.Millisecond)
+	deadlineCtx, cancel := context.WithDeadline(context.Background(), timeout)
+	defer cancel()
+
+	_, idx, err := state.BlockingQuery(noopFn, 10, deadlineCtx)
+	assert.Nil(t, err)
+	assert.EqualValues(t, 5, idx)
+	assert.WithinDuration(t, timeout, time.Now(), 5*time.Millisecond)
+}
+
+func TestStateStore_Blocking_MinQuery(t *testing.T) {
+	job := mock.Job()
+	count := 0
+	queryFn := func(ws memdb.WatchSet, s *StateStore) (interface{}, uint64, error) {
+		_, err := s.JobByID(ws, job.ID)
+		if err != nil {
+			return nil, 0, err
+		}
+
+		count++
+		if count == 1 {
+			return false, 5, nil
+		} else if count > 2 {
+			return false, 20, fmt.Errorf("called too many times")
+		}
+
+		return true, 11, nil
+	}
+
+	state := testStateStore(t)
+	timeout := time.Now().Add(10 * time.Millisecond)
+	deadlineCtx, cancel := context.WithDeadline(context.Background(), timeout)
+	defer cancel()
+
+	time.AfterFunc(5*time.Millisecond, func() {
+		state.UpsertJob(11, job)
+	})
+
+	resp, idx, err := state.BlockingQuery(queryFn, 10, deadlineCtx)
+	assert.Nil(t, err)
+	assert.Equal(t, 2, count)
+	assert.EqualValues(t, 11, idx)
+	assert.True(t, resp.(bool))
+}
+
 // This test checks that:
 // 1) The job is denormalized
 // 2) Allocations are created

View file

@@ -183,6 +183,31 @@ func (t *Txn) writeNode(n *Node, forLeafUpdate bool) *Node {
 	return nc
 }

+// Visit all the nodes in the tree under n, and add their mutateChannels to the transaction
+// Returns the size of the subtree visited
+func (t *Txn) trackChannelsAndCount(n *Node) int {
+	// Count only leaf nodes
+	leaves := 0
+	if n.leaf != nil {
+		leaves = 1
+	}
+	// Mark this node as being mutated.
+	if t.trackMutate {
+		t.trackChannel(n.mutateCh)
+	}
+
+	// Mark its leaf as being mutated, if appropriate.
+	if t.trackMutate && n.leaf != nil {
+		t.trackChannel(n.leaf.mutateCh)
+	}
+
+	// Recurse on the children
+	for _, e := range n.edges {
+		leaves += t.trackChannelsAndCount(e.node)
+	}
+	return leaves
+}
+
 // mergeChild is called to collapse the given node with its child. This is only
 // called when the given node is not a leaf and has a single edge.
 func (t *Txn) mergeChild(n *Node) {
@ -357,6 +382,56 @@ func (t *Txn) delete(parent, n *Node, search []byte) (*Node, *leafNode) {
return nc, leaf return nc, leaf
} }
// deletePrefix does a recursive delete of the subtree under the given prefix, returning the new node and the number of leaves removed
func (t *Txn) deletePrefix(parent, n *Node, search []byte) (*Node, int) {
// Check for key exhaustion
if len(search) == 0 {
nc := t.writeNode(n, true)
if n.isLeaf() {
nc.leaf = nil
}
nc.edges = nil
return nc, t.trackChannelsAndCount(n)
}
// Look for an edge
label := search[0]
idx, child := n.getEdge(label)
// Make sure that either the child node's prefix starts with the search term, or the search term starts with the child node's prefix.
// Both checks are needed so that we can delete prefixes that end partway along an edge and therefore don't correspond to any node in the tree.
if child == nil || (!bytes.HasPrefix(child.prefix, search) && !bytes.HasPrefix(search, child.prefix)) {
return nil, 0
}
// Consume the search prefix
if len(child.prefix) > len(search) {
search = []byte("")
} else {
search = search[len(child.prefix):]
}
newChild, numDeletions := t.deletePrefix(n, child, search)
if newChild == nil {
return nil, 0
}
// Copy this node. WATCH OUT - it's safe to pass "false" here because we
// will only ADD a leaf via t.mergeChild(nc) if there isn't one due to
// the !nc.isLeaf() check in the logic just below. This is pretty subtle,
// so be careful if you change any of the logic here.
nc := t.writeNode(n, false)
// Delete the edge if the new child is empty (it has no leaf and no edges)
if newChild.leaf == nil && len(newChild.edges) == 0 {
nc.delEdge(label)
if n != t.root && len(nc.edges) == 1 && !nc.isLeaf() {
t.mergeChild(nc)
}
} else {
nc.edges[idx].node = newChild
}
return nc, numDeletions
}
// Insert is used to add or update a given key. The return provides // Insert is used to add or update a given key. The return provides
// the previous value and a bool indicating if any was set. // the previous value and a bool indicating if any was set.
func (t *Txn) Insert(k []byte, v interface{}) (interface{}, bool) { func (t *Txn) Insert(k []byte, v interface{}) (interface{}, bool) {
@ -384,6 +459,19 @@ func (t *Txn) Delete(k []byte) (interface{}, bool) {
return nil, false return nil, false
} }
// DeletePrefix is used to delete an entire subtree that matches the prefix.
// It removes every node under that prefix and returns true if the prefix matched any nodes.
func (t *Txn) DeletePrefix(prefix []byte) bool {
newRoot, numDeletions := t.deletePrefix(nil, t.root, prefix)
if newRoot != nil {
t.root = newRoot
t.size = t.size - numDeletions
return true
}
return false
}
// Root returns the current root of the radix tree within this // Root returns the current root of the radix tree within this
// transaction. The root is not safe across insert and delete operations, // transaction. The root is not safe across insert and delete operations,
// but can be used to read the current state during a transaction. // but can be used to read the current state during a transaction.
@ -524,6 +612,14 @@ func (t *Tree) Delete(k []byte) (*Tree, interface{}, bool) {
return txn.Commit(), old, ok return txn.Commit(), old, ok
} }
// DeletePrefix is used to delete all nodes starting with a given prefix. Returns the new tree,
// and a bool indicating if the prefix matched any nodes
func (t *Tree) DeletePrefix(k []byte) (*Tree, bool) {
txn := t.Txn()
ok := txn.DeletePrefix(k)
return txn.Commit(), ok
}
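
A small usage sketch of the new prefix delete at the Tree level; the keys and values are made up for illustration:

```go
package main

import (
	"fmt"

	iradix "github.com/hashicorp/go-immutable-radix"
)

func main() {
	// Build a small tree with a shared "foo" prefix and one unrelated key.
	r := iradix.New()
	r, _, _ = r.Insert([]byte("foo/1"), 1)
	r, _, _ = r.Insert([]byte("foo/2"), 2)
	r, _, _ = r.Insert([]byte("bar/1"), 3)

	// Drop the whole "foo" subtree in one operation; ok reports whether the
	// prefix matched anything.
	r, ok := r.DeletePrefix([]byte("foo"))
	fmt.Println(ok, r.Len()) // true 1
}
```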
// Root returns the root node of the tree which can be used for richer // Root returns the root node of the tree which can be used for richer
// query operations. // query operations.
func (t *Tree) Root() *Node { func (t *Tree) Root() *Node {


@ -22,6 +22,11 @@ The database provides the following:
UUID can be efficiently compressed from strings into byte indexes for reduced UUID can be efficiently compressed from strings into byte indexes for reduced
storage requirements. storage requirements.
* Watches - Callers can populate a watch set as part of a query, which can be used to
detect when a later modification to the database affects the query results. This lets
callers easily watch for changes in the database in a very general way.
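
A minimal sketch of the watch feature described above, assuming a one-table schema and a hypothetical Item type; it uses FirstWatch to register the watch and the context-aware WatchCtx added further down in this change:

```go
package main

import (
	"context"
	"fmt"
	"time"

	memdb "github.com/hashicorp/go-memdb"
)

// Item is a made-up record type for the example.
type Item struct {
	ID string
}

func main() {
	// One table, indexed only by its string ID.
	schema := &memdb.DBSchema{
		Tables: map[string]*memdb.TableSchema{
			"item": {
				Name: "item",
				Indexes: map[string]*memdb.IndexSchema{
					"id": {
						Name:    "id",
						Unique:  true,
						Indexer: &memdb.StringFieldIndex{Field: "ID"},
					},
				},
			},
		},
	}
	db, err := memdb.NewMemDB(schema)
	if err != nil {
		panic(err)
	}

	// Read the (not yet existing) key and register a watch on the lookup.
	ws := memdb.NewWatchSet()
	read := db.Txn(false)
	watchCh, _, err := read.FirstWatch("item", "id", "foo")
	if err != nil {
		panic(err)
	}
	ws.Add(watchCh)
	read.Abort()

	// Insert the watched key from another goroutine shortly afterwards.
	go func() {
		time.Sleep(10 * time.Millisecond)
		write := db.Txn(true)
		write.Insert("item", &Item{ID: "foo"})
		write.Commit()
	}()

	// WatchCtx returns false when a watch fires and true on cancellation.
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()
	fmt.Println("timed out:", ws.WatchCtx(ctx))
}
```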
For the underlying immutable radix trees, see [go-immutable-radix](https://github.com/hashicorp/go-immutable-radix). For the underlying immutable radix trees, see [go-immutable-radix](https://github.com/hashicorp/go-immutable-radix).
Documentation Documentation


@ -305,20 +305,19 @@ func (u *UintFieldIndex) FromArgs(args ...interface{}) ([]byte, error) {
} }
// IsUintType returns whether the passed type is a type of uint and the number // IsUintType returns whether the passed type is a type of uint and the number
// of bytes the type is. To avoid platform specific sizes, the uint type returns // of bytes needed to encode the type.
// 8 bytes regardless of if it is smaller.
func IsUintType(k reflect.Kind) (size int, okay bool) { func IsUintType(k reflect.Kind) (size int, okay bool) {
switch k { switch k {
case reflect.Uint: case reflect.Uint:
return 8, true return binary.MaxVarintLen64, true
case reflect.Uint8: case reflect.Uint8:
return 1, true
case reflect.Uint16:
return 2, true return 2, true
case reflect.Uint16:
return binary.MaxVarintLen16, true
case reflect.Uint32: case reflect.Uint32:
return 4, true return binary.MaxVarintLen32, true
case reflect.Uint64: case reflect.Uint64:
return 8, true return binary.MaxVarintLen64, true
default: default:
return 0, false return 0, false
} }
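
The new return values are the worst-case lengths of an unsigned varint for each width (a maximal uint8 needs 2 bytes, a maximal uint64 needs 10). A standalone sketch, not taken from the library, demonstrating those bounds with the standard library:

```go
package main

import (
	"encoding/binary"
	"fmt"
	"math"
)

func main() {
	// Worst-case varint lengths per width; these match the sizes above.
	fmt.Println(binary.MaxVarintLen16, binary.MaxVarintLen32, binary.MaxVarintLen64) // 3 5 10

	// Encoding the maximum uint64 really does take the full 10 bytes.
	buf := make([]byte, binary.MaxVarintLen64)
	n := binary.PutUvarint(buf, math.MaxUint64)
	fmt.Println(n) // 10
}
```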


@ -76,7 +76,7 @@ func (db *MemDB) Snapshot() *MemDB {
func (db *MemDB) initialize() error { func (db *MemDB) initialize() error {
root := db.getRoot() root := db.getRoot()
for tName, tableSchema := range db.schema.Tables { for tName, tableSchema := range db.schema.Tables {
for iName, _ := range tableSchema.Indexes { for iName := range tableSchema.Indexes {
index := iradix.New() index := iradix.New()
path := indexPath(tName, iName) path := indexPath(tName, iName)
root, _, _ = root.Insert(path, index) root, _, _ = root.Insert(path, index)


@ -330,6 +330,96 @@ func (txn *Txn) Delete(table string, obj interface{}) error {
return nil return nil
} }
// DeletePrefix is used to delete an entire subtree based on a prefix.
// The given index must be a prefix index, and will be used to perform a scan and enumerate the set of objects to delete.
// These will be removed from all other indexes, and then a special prefix operation will delete the objects from the given index in an efficient subtree delete operation.
// This is useful when you have a very large number of objects indexed by the given index, along with a much smaller number of entries in the other indexes for those objects.
func (txn *Txn) DeletePrefix(table string, prefix_index string, prefix string) (bool, error) {
if !txn.write {
return false, fmt.Errorf("cannot delete in read-only transaction")
}
if !strings.HasSuffix(prefix_index, "_prefix") {
return false, fmt.Errorf("Index name for DeletePrefix must be a prefix index, Got %v ", prefix_index)
}
deletePrefixIndex := strings.TrimSuffix(prefix_index, "_prefix")
// Get an iterator over all of the keys with the given prefix.
entries, err := txn.Get(table, prefix_index, prefix)
if err != nil {
return false, fmt.Errorf("failed kvs lookup: %s", err)
}
// Get the table schema
tableSchema, ok := txn.db.schema.Tables[table]
if !ok {
return false, fmt.Errorf("invalid table '%s'", table)
}
foundAny := false
for entry := entries.Next(); entry != nil; entry = entries.Next() {
foundAny = true
// Get the primary ID of the object
idSchema := tableSchema.Indexes[id]
idIndexer := idSchema.Indexer.(SingleIndexer)
ok, idVal, err := idIndexer.FromObject(entry)
if err != nil {
return false, fmt.Errorf("failed to build primary index: %v", err)
}
if !ok {
return false, fmt.Errorf("object missing primary index")
}
// Remove the object from all the indexes except the given prefix index
for name, indexSchema := range tableSchema.Indexes {
if name == deletePrefixIndex {
continue
}
indexTxn := txn.writableIndex(table, name)
// Handle the update by deleting from the index first
var (
ok bool
vals [][]byte
err error
)
switch indexer := indexSchema.Indexer.(type) {
case SingleIndexer:
var val []byte
ok, val, err = indexer.FromObject(entry)
vals = [][]byte{val}
case MultiIndexer:
ok, vals, err = indexer.FromObject(entry)
}
if err != nil {
return false, fmt.Errorf("failed to build index '%s': %v", name, err)
}
if ok {
// Handle non-unique index by computing a unique index.
// This is done by appending the primary key which must
// be unique anyways.
for _, val := range vals {
if !indexSchema.Unique {
val = append(val, idVal...)
}
indexTxn.Delete(val)
}
}
}
}
if foundAny {
indexTxn := txn.writableIndex(table, deletePrefixIndex)
ok = indexTxn.DeletePrefix([]byte(prefix))
if !ok {
panic(fmt.Errorf("prefix %v matched some entries but DeletePrefix did not delete any", prefix))
}
return true, nil
}
return false, nil
}
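
A hypothetical end-to-end use of the new DeletePrefix, with a made-up Alloc type and two-index schema; note that the index argument must carry the "_prefix" suffix:

```go
package main

import (
	"fmt"

	memdb "github.com/hashicorp/go-memdb"
)

// Alloc is a made-up record with a primary ID and a secondary node index.
type Alloc struct {
	ID     string
	NodeID string
}

func main() {
	schema := &memdb.DBSchema{
		Tables: map[string]*memdb.TableSchema{
			"allocs": {
				Name: "allocs",
				Indexes: map[string]*memdb.IndexSchema{
					"id": {
						Name:    "id",
						Unique:  true,
						Indexer: &memdb.StringFieldIndex{Field: "ID"},
					},
					"node": {
						Name:    "node",
						Indexer: &memdb.StringFieldIndex{Field: "NodeID"},
					},
				},
			},
		},
	}
	db, err := memdb.NewMemDB(schema)
	if err != nil {
		panic(err)
	}

	txn := db.Txn(true)
	txn.Insert("allocs", &Alloc{ID: "job1-alloc1", NodeID: "n1"})
	txn.Insert("allocs", &Alloc{ID: "job1-alloc2", NodeID: "n2"})
	txn.Insert("allocs", &Alloc{ID: "job2-alloc1", NodeID: "n1"})

	// Delete everything whose primary key starts with "job1-"; the entries
	// are removed from the node index individually, then the id index is
	// pruned with a single subtree delete.
	ok, err := txn.DeletePrefix("allocs", "id_prefix", "job1-")
	fmt.Println(ok, err) // true <nil>
	txn.Commit()
}
```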
// DeleteAll is used to delete all the objects in a given table // DeleteAll is used to delete all the objects in a given table
// matching the constraints on the index // matching the constraints on the index
func (txn *Txn) DeleteAll(table, index string, args ...interface{}) (int, error) { func (txn *Txn) DeleteAll(table, index string, args ...interface{}) (int, error) {


@ -1,6 +1,9 @@
package memdb package memdb
import "time" import (
"context"
"time"
)
// WatchSet is a collection of watch channels. // WatchSet is a collection of watch channels.
type WatchSet map[<-chan struct{}]struct{} type WatchSet map[<-chan struct{}]struct{}
@ -46,6 +49,30 @@ func (w WatchSet) Watch(timeoutCh <-chan time.Time) bool {
return false return false
} }
// Create a context that gets cancelled when the timeout is triggered
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go func() {
select {
case <-timeoutCh:
cancel()
case <-ctx.Done():
}
}()
return w.WatchCtx(ctx)
}
// WatchCtx is used to wait for either the watch set to trigger or for the
// context to be cancelled. Returns true if the context is cancelled. Watch with
// a timeout channel can be mimicked by creating a context with a deadline.
// WatchCtx should be preferred over Watch.
func (w WatchSet) WatchCtx(ctx context.Context) bool {
if w == nil {
return false
}
if n := len(w); n <= aFew { if n := len(w); n <= aFew {
idx := 0 idx := 0
chunk := make([]<-chan struct{}, aFew) chunk := make([]<-chan struct{}, aFew)
@ -53,23 +80,18 @@ func (w WatchSet) Watch(timeoutCh <-chan time.Time) bool {
chunk[idx] = watchCh chunk[idx] = watchCh
idx++ idx++
} }
return watchFew(chunk, timeoutCh) return watchFew(ctx, chunk)
} else {
return w.watchMany(timeoutCh)
} }
return w.watchMany(ctx)
} }
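
As the comment above notes, a bounded Watch can be reproduced by handing WatchCtx a deadline context; a short sketch with a hypothetical waitAtMost helper:

```go
package memdbutil

import (
	"context"
	"time"

	memdb "github.com/hashicorp/go-memdb"
)

// waitAtMost blocks until something in the watch set fires or the duration
// elapses; it returns true on timeout, mirroring WatchSet.Watch's contract.
func waitAtMost(ws memdb.WatchSet, d time.Duration) bool {
	ctx, cancel := context.WithTimeout(context.Background(), d)
	defer cancel()
	return ws.WatchCtx(ctx) // equivalent to ws.Watch(time.After(d))
}
```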
// watchMany is used if there are many watchers. // watchMany is used if there are many watchers.
func (w WatchSet) watchMany(timeoutCh <-chan time.Time) bool { func (w WatchSet) watchMany(ctx context.Context) bool {
// Make a fake timeout channel we can feed into watchFew to cancel all
// the blocking goroutines.
doneCh := make(chan time.Time)
defer close(doneCh)
// Set up a goroutine for each watcher. // Set up a goroutine for each watcher.
triggerCh := make(chan struct{}, 1) triggerCh := make(chan struct{}, 1)
watcher := func(chunk []<-chan struct{}) { watcher := func(chunk []<-chan struct{}) {
if timeout := watchFew(chunk, doneCh); !timeout { if timeout := watchFew(ctx, chunk); !timeout {
select { select {
case triggerCh <- struct{}{}: case triggerCh <- struct{}{}:
default: default:
@ -102,7 +124,7 @@ func (w WatchSet) watchMany(timeoutCh <-chan time.Time) bool {
select { select {
case <-triggerCh: case <-triggerCh:
return false return false
case <-timeoutCh: case <-ctx.Done():
return true return true
} }
} }


@ -1,8 +1,9 @@
//go:generate sh -c "go run watch-gen/main.go >watch_few.go"
package memdb package memdb
//go:generate sh -c "go run watch-gen/main.go >watch_few.go"
import( import(
"time" "context"
) )
// aFew gives how many watchers this function is wired to support. You must // aFew gives how many watchers this function is wired to support. You must
@ -11,7 +12,7 @@ const aFew = 32
// watchFew is used if there are only a few watchers as a performance // watchFew is used if there are only a few watchers as a performance
// optimization. // optimization.
func watchFew(ch []<-chan struct{}, timeoutCh <-chan time.Time) bool { func watchFew(ctx context.Context, ch []<-chan struct{}) bool {
select { select {
case <-ch[0]: case <-ch[0]:
@ -110,7 +111,7 @@ func watchFew(ch []<-chan struct{}, timeoutCh <-chan time.Time) bool {
case <-ch[31]: case <-ch[31]:
return false return false
case <-timeoutCh: case <-ctx.Done():
return true return true
} }
} }

vendor/vendor.json (vendored)

@ -731,16 +731,16 @@
"revisionTime": "2017-07-16T17:45:23Z" "revisionTime": "2017-07-16T17:45:23Z"
}, },
{ {
"checksumSHA1": "zvmksNyW6g+Fd/bywd4vcn8rp+M=", "checksumSHA1": "Cas2nprG6pWzf05A2F/OlnjUu2Y=",
"path": "github.com/hashicorp/go-immutable-radix", "path": "github.com/hashicorp/go-immutable-radix",
"revision": "30664b879c9a771d8d50b137ab80ee0748cb2fcc", "revision": "8aac2701530899b64bdea735a1de8da899815220",
"revisionTime": "2017-02-14T02:52:36Z" "revisionTime": "2017-07-25T22:12:15Z"
}, },
{ {
"checksumSHA1": "KeH4FuTKuv3tqFOr3NpLQtL1jPs=", "checksumSHA1": "Q7MLoOLgXyvHBVmT/rvSeOhJo6c=",
"path": "github.com/hashicorp/go-memdb", "path": "github.com/hashicorp/go-memdb",
"revision": "ed59a4bb9146689d4b00d060b70b9e9648b523af", "revision": "f2dec88c7441ddf375eabd561b0a1584b67b8ce4",
"revisionTime": "2017-04-11T17:33:47Z" "revisionTime": "2017-08-30T23:01:53Z"
}, },
{ {
"path": "github.com/hashicorp/go-msgpack/codec", "path": "github.com/hashicorp/go-msgpack/codec",