diff --git a/nomad/deployment_endpoint.go b/nomad/deployment_endpoint.go index 731c5e71f..c34eba8e9 100644 --- a/nomad/deployment_endpoint.go +++ b/nomad/deployment_endpoint.go @@ -243,7 +243,7 @@ func (d *Deployment) List(args *structs.DeploymentListRequest, reply *structs.De } reply.Deployments = deploys - // Use the last index that affected the jobs table + // Use the last index that affected the deployment table index, err := state.Index("deployment") if err != nil { return err diff --git a/nomad/deployment_watcher_shims.go b/nomad/deployment_watcher_shims.go index 2640720d2..c703feb7f 100644 --- a/nomad/deployment_watcher_shims.go +++ b/nomad/deployment_watcher_shims.go @@ -4,83 +4,6 @@ import ( "github.com/hashicorp/nomad/nomad/structs" ) -// deploymentWatcherStateShim is the shim that provides the state watching -// methods. These should be set by the server and passed to the deployment -// watcher. -type deploymentWatcherStateShim struct { - // region is the region the server is a member of. It is used to - // auto-populate requests that do not have it set - region string - - // evaluations returns the set of evaluations for the given job - evaluations func(args *structs.JobSpecificRequest, reply *structs.JobEvaluationsResponse) error - - // allocations returns the set of allocations that are part of the - // deployment. - allocations func(args *structs.DeploymentSpecificRequest, reply *structs.AllocListResponse) error - - // list is used to list all the deployments in the system - list func(args *structs.DeploymentListRequest, reply *structs.DeploymentListResponse) error - - // GetDeployment is used to lookup a particular deployment. - getDeployment func(args *structs.DeploymentSpecificRequest, reply *structs.SingleDeploymentResponse) error - - // getJobVersions is used to lookup the versions of a job. This is used when - // rolling back to find the latest stable job - getJobVersions func(args *structs.JobVersionsRequest, reply *structs.JobVersionsResponse) error - - // getJob is used to lookup a particular job. 
- getJob func(args *structs.JobSpecificRequest, reply *structs.SingleJobResponse) error -} - -func (d *deploymentWatcherStateShim) Evaluations(args *structs.JobSpecificRequest, reply *structs.JobEvaluationsResponse) error { - if args.Region == "" { - args.Region = d.region - } - - return d.evaluations(args, reply) -} - -func (d *deploymentWatcherStateShim) Allocations(args *structs.DeploymentSpecificRequest, reply *structs.AllocListResponse) error { - if args.Region == "" { - args.Region = d.region - } - - return d.allocations(args, reply) -} - -func (d *deploymentWatcherStateShim) List(args *structs.DeploymentListRequest, reply *structs.DeploymentListResponse) error { - if args.Region == "" { - args.Region = d.region - } - - return d.list(args, reply) -} - -func (d *deploymentWatcherStateShim) GetDeployment(args *structs.DeploymentSpecificRequest, reply *structs.SingleDeploymentResponse) error { - if args.Region == "" { - args.Region = d.region - } - - return d.getDeployment(args, reply) -} - -func (d *deploymentWatcherStateShim) GetJobVersions(args *structs.JobVersionsRequest, reply *structs.JobVersionsResponse) error { - if args.Region == "" { - args.Region = d.region - } - - return d.getJobVersions(args, reply) -} - -func (d *deploymentWatcherStateShim) GetJob(args *structs.JobSpecificRequest, reply *structs.SingleJobResponse) error { - if args.Region == "" { - args.Region = d.region - } - - return d.getJob(args, reply) -} - // deploymentWatcherRaftShim is the shim that provides the state watching // methods. These should be set by the server and passed to the deployment // watcher. diff --git a/nomad/deploymentwatcher/deployment_watcher.go b/nomad/deploymentwatcher/deployment_watcher.go index 2fdb05611..ee7e2aa3f 100644 --- a/nomad/deploymentwatcher/deployment_watcher.go +++ b/nomad/deploymentwatcher/deployment_watcher.go @@ -8,7 +8,9 @@ import ( "golang.org/x/time/rate" + memdb "github.com/hashicorp/go-memdb" "github.com/hashicorp/nomad/helper" + "github.com/hashicorp/nomad/nomad/state" "github.com/hashicorp/nomad/nomad/structs" ) @@ -49,9 +51,8 @@ type deploymentWatcher struct { // deployment deploymentTriggers - // DeploymentStateWatchers holds the methods required to watch objects for - // changes on behalf of the deployment - watchers DeploymentStateWatchers + // state is the state that is watched for state changes. + state *state.StateStore // d is the deployment being watched d *structs.Deployment @@ -77,7 +78,7 @@ type deploymentWatcher struct { // newDeploymentWatcher returns a deployment watcher that is used to watch // deployments and trigger the scheduler as needed. 
func newDeploymentWatcher(parent context.Context, queryLimiter *rate.Limiter, - logger *log.Logger, watchers DeploymentStateWatchers, d *structs.Deployment, + logger *log.Logger, state *state.StateStore, d *structs.Deployment, j *structs.Job, triggers deploymentTriggers) *deploymentWatcher { ctx, exitFn := context.WithCancel(parent) @@ -85,7 +86,7 @@ func newDeploymentWatcher(parent context.Context, queryLimiter *rate.Limiter, queryLimiter: queryLimiter, d: d, j: j, - watchers: watchers, + state: state, deploymentTriggers: triggers, logger: logger, ctx: ctx, @@ -116,15 +117,19 @@ func (w *deploymentWatcher) SetAllocHealth( } // Get the allocations for the deployment - args := &structs.DeploymentSpecificRequest{DeploymentID: req.DeploymentID} - var resp structs.AllocListResponse - if err := w.watchers.Allocations(args, &resp); err != nil { + snap, err := w.state.Snapshot() + if err != nil { + return err + } + + allocs, err := snap.AllocsByDeployment(nil, req.DeploymentID) + if err != nil { return err } // Determine if we should autorevert to an older job desc := structs.DeploymentStatusDescriptionFailedAllocations - for _, alloc := range resp.Allocations { + for _, alloc := range allocs { // Check that the alloc has been marked unhealthy if _, ok := unhealthy[alloc.ID]; !ok { continue @@ -295,7 +300,7 @@ func (w *deploymentWatcher) watch() { // Block getting all allocations that are part of the deployment using // the last evaluation index. This will have us block waiting for // something to change past what the scheduler has evaluated. - allocResp, err := w.getAllocs(allocIndex) + allocs, index, err := w.getAllocs(allocIndex) if err != nil { if err == context.Canceled || w.ctx.Err() == context.Canceled { return @@ -304,7 +309,7 @@ func (w *deploymentWatcher) watch() { w.logger.Printf("[ERR] nomad.deployment_watcher: failed to retrieve allocations for deployment %q: %v", w.d.ID, err) return } - allocIndex = allocResp.Index + allocIndex = index // Get the latest evaluation index latestEval, err := w.latestEvalIndex() @@ -320,7 +325,7 @@ func (w *deploymentWatcher) watch() { // Create an evaluation trigger if there is any allocation whose // deployment status has been updated past the latest eval index. createEval, failDeployment, rollback := false, false, false - for _, alloc := range allocResp.Allocations { + for _, alloc := range allocs { if alloc.DeploymentStatus == nil || alloc.DeploymentStatus.ModifyIndex <= latestEval { continue } @@ -379,21 +384,25 @@ func (w *deploymentWatcher) watch() { } } else if createEval { // Create an eval to push the deployment along - w.createEvalBatched(allocResp.Index) + w.createEvalBatched(index) } } } // latestStableJob returns the latest stable job. It may be nil if none exist func (w *deploymentWatcher) latestStableJob() (*structs.Job, error) { - args := &structs.JobVersionsRequest{JobID: w.d.JobID} - var resp structs.JobVersionsResponse - if err := w.watchers.GetJobVersions(args, &resp); err != nil { + snap, err := w.state.Snapshot() + if err != nil { + return nil, err + } + + versions, err := snap.JobVersionsByID(nil, w.d.JobID) + if err != nil { return nil, err } var stable *structs.Job - for _, job := range resp.Versions { + for _, job := range versions { if job.Stable { stable = job break @@ -454,27 +463,42 @@ func (w *deploymentWatcher) getDeploymentStatusUpdate(status, desc string) *stru // getAllocs retrieves the allocations that are part of the deployment blocking // at the given index. 
-func (w *deploymentWatcher) getAllocs(index uint64) (*structs.AllocListResponse, error) {
-	// Build the request
-	args := &structs.DeploymentSpecificRequest{
-		DeploymentID: w.d.ID,
-		QueryOptions: structs.QueryOptions{
-			MinQueryIndex: index,
-		},
+func (w *deploymentWatcher) getAllocs(index uint64) ([]*structs.AllocListStub, uint64, error) {
+	resp, index, err := w.state.BlockingQuery(w.getAllocsImpl, index, w.ctx)
+	if err != nil {
+		return nil, 0, err
 	}
-	var resp structs.AllocListResponse
-
-	for resp.Index <= index {
-		if err := w.queryLimiter.Wait(w.ctx); err != nil {
-			return nil, err
-		}
-
-		if err := w.watchers.Allocations(args, &resp); err != nil {
-			return nil, err
-		}
+	if err := w.ctx.Err(); err != nil {
+		return nil, 0, err
 	}
 
-	return &resp, nil
+	return resp.([]*structs.AllocListStub), index, nil
+}
+
+// getAllocsImpl retrieves the allocations for the deployment from the passed state store.
+func (w *deploymentWatcher) getAllocsImpl(ws memdb.WatchSet, state *state.StateStore) (interface{}, uint64, error) {
+	if err := w.queryLimiter.Wait(w.ctx); err != nil {
+		return nil, 0, err
+	}
+
+	// Capture all the allocations
+	allocs, err := state.AllocsByDeployment(ws, w.d.ID)
+	if err != nil {
+		return nil, 0, err
+	}
+
+	stubs := make([]*structs.AllocListStub, 0, len(allocs))
+	for _, alloc := range allocs {
+		stubs = append(stubs, alloc.Stub())
+	}
+
+	// Use the last index that affected the allocs table
+	index, err := state.Index("allocs")
+	if err != nil {
+		return nil, index, err
+	}
+
+	return stubs, index, nil
 }
 
 // latestEvalIndex returns the index of the last evaluation created for
@@ -485,22 +509,26 @@ func (w *deploymentWatcher) latestEvalIndex() (uint64, error) {
 		return 0, err
 	}
 
-	args := &structs.JobSpecificRequest{
-		JobID: w.d.JobID,
-	}
-	var resp structs.JobEvaluationsResponse
-	err := w.watchers.Evaluations(args, &resp)
+	snap, err := w.state.Snapshot()
 	if err != nil {
 		return 0, err
 	}
 
-	if len(resp.Evaluations) == 0 {
-		w.setLatestEval(resp.Index)
-		return resp.Index, nil
+	evals, err := snap.EvalsByJob(nil, w.d.JobID)
+	if err != nil {
+		return 0, err
+	}
+
+	if len(evals) == 0 {
+		idx, err := snap.Index("evals")
+		if err == nil {
+			w.setLatestEval(idx)
+		}
+		return idx, err
 	}
 
 	// Prefer using the snapshot index. Otherwise use the create index
-	e := resp.Evaluations[0]
+	e := evals[0]
 	if e.SnapshotIndex != 0 {
 		w.setLatestEval(e.SnapshotIndex)
 		return e.SnapshotIndex, nil
diff --git a/nomad/deploymentwatcher/deployments_watcher.go b/nomad/deploymentwatcher/deployments_watcher.go
index 0deaf34d9..e42c4a416 100644
--- a/nomad/deploymentwatcher/deployments_watcher.go
+++ b/nomad/deploymentwatcher/deployments_watcher.go
@@ -9,6 +9,8 @@ import (
 
 	"golang.org/x/time/rate"
 
+	memdb "github.com/hashicorp/go-memdb"
+	"github.com/hashicorp/nomad/nomad/state"
 	"github.com/hashicorp/nomad/nomad/structs"
 )
 
@@ -49,30 +51,6 @@ type DeploymentRaftEndpoints interface {
 	UpdateDeploymentAllocHealth(req *structs.ApplyDeploymentAllocHealthRequest) (uint64, error)
 }
 
-// DeploymentStateWatchers are the set of functions required to watch objects on
-// behalf of a deployment
-type DeploymentStateWatchers interface {
-	// Evaluations returns the set of evaluations for the given job
-	Evaluations(args *structs.JobSpecificRequest, reply *structs.JobEvaluationsResponse) error
-
-	// Allocations returns the set of allocations that are part of the
-	// deployment.
-	Allocations(args *structs.DeploymentSpecificRequest, reply *structs.AllocListResponse) error
-
-	// List is used to list all the deployments in the system
-	List(args *structs.DeploymentListRequest, reply *structs.DeploymentListResponse) error
-
-	// GetDeployment is used to lookup a particular deployment.
-	GetDeployment(args *structs.DeploymentSpecificRequest, reply *structs.SingleDeploymentResponse) error
-
-	// GetJobVersions is used to lookup the versions of a job. This is used when
-	// rolling back to find the latest stable job
-	GetJobVersions(args *structs.JobVersionsRequest, reply *structs.JobVersionsResponse) error
-
-	// GetJob is used to lookup a particular job.
-	GetJob(args *structs.JobSpecificRequest, reply *structs.SingleJobResponse) error
-}
-
 // Watcher is used to watch deployments and their allocations created
 // by the scheduler and trigger the scheduler when allocation health
 // transistions.
@@ -91,9 +69,8 @@ type Watcher struct {
 	// deployments watcher
 	raft DeploymentRaftEndpoints
 
-	// stateWatchers is the set of functions required to watch a deployment for
-	// state changes
-	stateWatchers DeploymentStateWatchers
+	// state is the state that is watched for state changes.
+	state *state.StateStore
 
 	// watchers is the set of active watchers, one per deployment
 	watchers map[string]*deploymentWatcher
@@ -110,12 +87,11 @@ type Watcher struct {
 
 // NewDeploymentsWatcher returns a deployments watcher that is used to watch
 // deployments and trigger the scheduler as needed.
-func NewDeploymentsWatcher(logger *log.Logger, watchers DeploymentStateWatchers,
+func NewDeploymentsWatcher(logger *log.Logger,
 	raft DeploymentRaftEndpoints, stateQueriesPerSecond float64,
 	evalBatchDuration time.Duration) *Watcher {
 
 	return &Watcher{
-		stateWatchers:     watchers,
 		raft:              raft,
 		queryLimiter:      rate.NewLimiter(rate.Limit(stateQueriesPerSecond), 100),
 		evalBatchDuration: evalBatchDuration,
@@ -124,14 +100,22 @@ func NewDeploymentsWatcher(logger *log.Logger, watchers DeploymentStateWatchers,
 }
 
 // SetEnabled is used to control if the watcher is enabled. The watcher
-// should only be enabled on the active leader.
-func (w *Watcher) SetEnabled(enabled bool) error {
+// should only be enabled on the active leader. When being enabled, the state is
+// passed in as previously captured state is no longer valid after a leader election.
+func (w *Watcher) SetEnabled(enabled bool, state *state.StateStore) error {
 	w.l.Lock()
 	defer w.l.Unlock()
 
 	wasEnabled := w.enabled
 	w.enabled = enabled
 
+	if state != nil {
+		if wasEnabled && w.state != nil {
+			panic("deployments watcher may still have queries blocking on the old state store")
+		}
+		w.state = state
+	}
+
 	// Flush the state to create the necessary objects
 	w.flush()
 
@@ -166,7 +150,7 @@ func (w *Watcher) watchDeployments(ctx context.Context) {
 	dindex := uint64(1)
 	for {
 		// Block getting all deployments using the last deployment index.
- resp, err := w.getDeploys(ctx, dindex) + deployments, idx, err := w.getDeploys(ctx, dindex) if err != nil { if err == context.Canceled || ctx.Err() == context.Canceled { return @@ -175,14 +159,12 @@ func (w *Watcher) watchDeployments(ctx context.Context) { w.logger.Printf("[ERR] nomad.deployments_watcher: failed to retrieve deploylements: %v", err) } - // Guard against npe - if resp == nil { - continue - } + // Update the latest index + dindex = idx // Ensure we are tracking the things we should and not tracking what we // shouldn't be - for _, d := range resp.Deployments { + for _, d := range deployments { if d.Active() { if err := w.add(d); err != nil { w.logger.Printf("[ERR] nomad.deployments_watcher: failed to track deployment %q: %v", d.ID, err) @@ -191,33 +173,47 @@ func (w *Watcher) watchDeployments(ctx context.Context) { w.remove(d) } } - - // Update the latest index - dindex = resp.Index } } // getDeploys retrieves all deployments blocking at the given index. -func (w *Watcher) getDeploys(ctx context.Context, index uint64) (*structs.DeploymentListResponse, error) { - // Build the request - args := &structs.DeploymentListRequest{ - QueryOptions: structs.QueryOptions{ - MinQueryIndex: index, - }, +func (w *Watcher) getDeploys(ctx context.Context, minIndex uint64) ([]*structs.Deployment, uint64, error) { + resp, index, err := w.state.BlockingQuery(w.getDeploysImpl, minIndex, ctx) + if err != nil { + return nil, 0, err } - var resp structs.DeploymentListResponse - - for resp.Index <= index { - if err := w.queryLimiter.Wait(ctx); err != nil { - return nil, err - } - - if err := w.stateWatchers.List(args, &resp); err != nil { - return nil, err - } + if err := ctx.Err(); err != nil { + return nil, 0, err } - return &resp, nil + return resp.([]*structs.Deployment), index, nil +} + +// getDeploysImpl retrieves all deployments from the passed state store. +func (w *Watcher) getDeploysImpl(ws memdb.WatchSet, state *state.StateStore) (interface{}, uint64, error) { + + iter, err := state.Deployments(ws) + if err != nil { + return nil, 0, err + } + + var deploys []*structs.Deployment + for { + raw := iter.Next() + if raw == nil { + break + } + deploy := raw.(*structs.Deployment) + deploys = append(deploys, deploy) + } + + // Use the last index that affected the deployment table + index, err := state.Index("deployment") + if err != nil { + return nil, 0, err + } + + return deploys, index, nil } // add adds a deployment to the watch list @@ -246,18 +242,20 @@ func (w *Watcher) addLocked(d *structs.Deployment) (*deploymentWatcher, error) { } // Get the job the deployment is referencing - args := &structs.JobSpecificRequest{ - JobID: d.JobID, - } - var resp structs.SingleJobResponse - if err := w.stateWatchers.GetJob(args, &resp); err != nil { + snap, err := w.state.Snapshot() + if err != nil { return nil, err } - if resp.Job == nil { + + job, err := snap.JobByID(nil, d.JobID) + if err != nil { + return nil, err + } + if job == nil { return nil, fmt.Errorf("deployment %q references unknown job %q", d.ID, d.JobID) } - watcher := newDeploymentWatcher(w.ctx, w.queryLimiter, w.logger, w.stateWatchers, d, resp.Job, w) + watcher := newDeploymentWatcher(w.ctx, w.queryLimiter, w.logger, w.state, d, job, w) w.watchers[d.ID] = watcher return watcher, nil } @@ -283,18 +281,21 @@ func (w *Watcher) remove(d *structs.Deployment) { // a watcher. If the deployment does not exist or is terminal an error is // returned. 
func (w *Watcher) forceAdd(dID string) (*deploymentWatcher, error) { - // Build the request - args := &structs.DeploymentSpecificRequest{DeploymentID: dID} - var resp structs.SingleDeploymentResponse - if err := w.stateWatchers.GetDeployment(args, &resp); err != nil { + snap, err := w.state.Snapshot() + if err != nil { return nil, err } - if resp.Deployment == nil { + deployment, err := snap.DeploymentByID(nil, dID) + if err != nil { + return nil, err + } + + if deployment == nil { return nil, fmt.Errorf("unknown deployment %q", dID) } - return w.addLocked(resp.Deployment) + return w.addLocked(deployment) } // getOrCreateWatcher returns the deployment watcher for the given deployment ID. diff --git a/nomad/deploymentwatcher/deployments_watcher_test.go b/nomad/deploymentwatcher/deployments_watcher_test.go index e62ae747a..5bc38c658 100644 --- a/nomad/deploymentwatcher/deployments_watcher_test.go +++ b/nomad/deploymentwatcher/deployments_watcher_test.go @@ -16,7 +16,7 @@ import ( func testDeploymentWatcher(t *testing.T, qps float64, batchDur time.Duration) (*Watcher, *mockBackend) { m := newMockBackend(t) - w := NewDeploymentsWatcher(testLogger(), m, m, qps, batchDur) + w := NewDeploymentsWatcher(testLogger(), m, qps, batchDur) return w, m } @@ -30,23 +30,11 @@ func TestWatcher_WatchDeployments(t *testing.T) { assert := assert.New(t) w, m := defaultTestDeploymentWatcher(t) - // Return no allocations or evals - m.On("Allocations", mocker.Anything, mocker.Anything).Return(nil).Run(func(args mocker.Arguments) { - reply := args.Get(1).(*structs.AllocListResponse) - reply.Index = m.nextIndex() - }) - m.On("Evaluations", mocker.Anything, mocker.Anything).Return(nil).Run(func(args mocker.Arguments) { - reply := args.Get(1).(*structs.JobEvaluationsResponse) - reply.Index = m.nextIndex() - }) - // Create three jobs j1, j2, j3 := mock.Job(), mock.Job(), mock.Job() - jobs := map[string]*structs.Job{ - j1.ID: j1, - j2.ID: j2, - j3.ID: j3, - } + assert.Nil(m.state.UpsertJob(100, j1)) + assert.Nil(m.state.UpsertJob(101, j2)) + assert.Nil(m.state.UpsertJob(102, j3)) // Create three deployments all running d1, d2, d3 := mock.Deployment(), mock.Deployment(), mock.Deployment() @@ -54,46 +42,27 @@ func TestWatcher_WatchDeployments(t *testing.T) { d2.JobID = j2.ID d3.JobID = j3.ID - m.On("GetJob", mocker.Anything, mocker.Anything). 
- Return(nil).Run(func(args mocker.Arguments) { - in := args.Get(0).(*structs.JobSpecificRequest) - reply := args.Get(1).(*structs.SingleJobResponse) - reply.Job = jobs[in.JobID] - reply.Index = reply.Job.ModifyIndex - }) - - // Set up the calls for retrieving deployments - m.On("List", mocker.Anything, mocker.Anything).Return(nil).Run(func(args mocker.Arguments) { - reply := args.Get(1).(*structs.DeploymentListResponse) - reply.Deployments = []*structs.Deployment{d1} - reply.Index = m.nextIndex() - }).Once() + // Upsert the first deployment + assert.Nil(m.state.UpsertDeployment(103, d1)) // Next list 3 block1 := make(chan time.Time) - m.On("List", mocker.Anything, mocker.Anything).Return(nil).Run(func(args mocker.Arguments) { - reply := args.Get(1).(*structs.DeploymentListResponse) - reply.Deployments = []*structs.Deployment{d1, d2, d3} - reply.Index = m.nextIndex() - }).Once().WaitUntil(block1) + go func() { + <-block1 + assert.Nil(m.state.UpsertDeployment(104, d2)) + assert.Nil(m.state.UpsertDeployment(105, d3)) + }() //// Next list 3 but have one be terminal block2 := make(chan time.Time) d3terminal := d3.Copy() d3terminal.Status = structs.DeploymentStatusFailed - m.On("List", mocker.Anything, mocker.Anything).Return(nil).Run(func(args mocker.Arguments) { - reply := args.Get(1).(*structs.DeploymentListResponse) - reply.Deployments = []*structs.Deployment{d1, d2, d3terminal} - reply.Index = m.nextIndex() - }).WaitUntil(block2) + go func() { + <-block2 + assert.Nil(m.state.UpsertDeployment(106, d3terminal)) + }() - m.On("List", mocker.Anything, mocker.Anything).Return(nil).Run(func(args mocker.Arguments) { - reply := args.Get(1).(*structs.DeploymentListResponse) - reply.Deployments = []*structs.Deployment{d1, d2, d3terminal} - reply.Index = m.nextIndex() - }) - - w.SetEnabled(true) + w.SetEnabled(true, m.state) testutil.WaitForResult(func() (bool, error) { return 1 == len(w.watchers), nil }, func(err error) { assert.Equal(1, len(w.watchers), "1 deployment returned") }) @@ -111,17 +80,7 @@ func TestWatcher_UnknownDeployment(t *testing.T) { t.Parallel() assert := assert.New(t) w, m := defaultTestDeploymentWatcher(t) - w.SetEnabled(true) - - // Set up the calls for retrieving deployments - m.On("List", mocker.Anything, mocker.Anything).Return(nil).Run(func(args mocker.Arguments) { - reply := args.Get(1).(*structs.DeploymentListResponse) - reply.Index = m.nextIndex() - }) - m.On("GetDeployment", mocker.Anything, mocker.Anything).Return(nil).Run(func(args mocker.Arguments) { - reply := args.Get(1).(*structs.SingleDeploymentResponse) - reply.Index = m.nextIndex() - }) + w.SetEnabled(true, m.state) // The expected error is that it should be an unknown deployment dID := structs.GenerateUUID() @@ -181,16 +140,7 @@ func TestWatcher_SetAllocHealth_Unknown(t *testing.T) { assert.Nil(m.state.UpsertJob(m.nextIndex(), j), "UpsertJob") assert.Nil(m.state.UpsertDeployment(m.nextIndex(), d), "UpsertDeployment") - // Assert the following methods will be called - m.On("List", mocker.Anything, mocker.Anything).Return(nil).Run(m.listFromState) - m.On("Allocations", mocker.MatchedBy(matchDeploymentSpecificRequest(d.ID)), - mocker.Anything).Return(nil).Run(m.allocationsFromState) - m.On("Evaluations", mocker.MatchedBy(matchJobSpecificRequest(j.ID)), - mocker.Anything).Return(nil).Run(m.evaluationsFromState) - m.On("GetJob", mocker.MatchedBy(matchJobSpecificRequest(j.ID)), - mocker.Anything).Return(nil).Run(m.getJobFromState) - - w.SetEnabled(true) + w.SetEnabled(true, m.state) testutil.WaitForResult(func() 
(bool, error) { return 1 == len(w.watchers), nil }, func(err error) { assert.Equal(1, len(w.watchers), "Should have 1 deployment") }) @@ -233,16 +183,7 @@ func TestWatcher_SetAllocHealth_Healthy(t *testing.T) { assert.Nil(m.state.UpsertDeployment(m.nextIndex(), d), "UpsertDeployment") assert.Nil(m.state.UpsertAllocs(m.nextIndex(), []*structs.Allocation{a}), "UpsertAllocs") - // Assert the following methods will be called - m.On("List", mocker.Anything, mocker.Anything).Return(nil).Run(m.listFromState) - m.On("Allocations", mocker.MatchedBy(matchDeploymentSpecificRequest(d.ID)), - mocker.Anything).Return(nil).Run(m.allocationsFromState) - m.On("Evaluations", mocker.MatchedBy(matchJobSpecificRequest(j.ID)), - mocker.Anything).Return(nil).Run(m.evaluationsFromState) - m.On("GetJob", mocker.MatchedBy(matchJobSpecificRequest(j.ID)), - mocker.Anything).Return(nil).Run(m.getJobFromState) - - w.SetEnabled(true) + w.SetEnabled(true, m.state) testutil.WaitForResult(func() (bool, error) { return 1 == len(w.watchers), nil }, func(err error) { assert.Equal(1, len(w.watchers), "Should have 1 deployment") }) @@ -283,16 +224,7 @@ func TestWatcher_SetAllocHealth_Unhealthy(t *testing.T) { assert.Nil(m.state.UpsertDeployment(m.nextIndex(), d), "UpsertDeployment") assert.Nil(m.state.UpsertAllocs(m.nextIndex(), []*structs.Allocation{a}), "UpsertAllocs") - // Assert the following methods will be called - m.On("List", mocker.Anything, mocker.Anything).Return(nil).Run(m.listFromState) - m.On("Allocations", mocker.MatchedBy(matchDeploymentSpecificRequest(d.ID)), - mocker.Anything).Return(nil).Run(m.allocationsFromState) - m.On("Evaluations", mocker.MatchedBy(matchJobSpecificRequest(j.ID)), - mocker.Anything).Return(nil).Run(m.evaluationsFromState) - m.On("GetJob", mocker.MatchedBy(matchJobSpecificRequest(j.ID)), - mocker.Anything).Return(nil).Run(m.getJobFromState) - - w.SetEnabled(true) + w.SetEnabled(true, m.state) testutil.WaitForResult(func() (bool, error) { return 1 == len(w.watchers), nil }, func(err error) { assert.Equal(1, len(w.watchers), "Should have 1 deployment") }) @@ -350,18 +282,7 @@ func TestWatcher_SetAllocHealth_Unhealthy_Rollback(t *testing.T) { j2.Stable = false assert.Nil(m.state.UpsertJob(m.nextIndex(), j2), "UpsertJob2") - // Assert the following methods will be called - m.On("List", mocker.Anything, mocker.Anything).Return(nil).Run(m.listFromState) - m.On("Allocations", mocker.MatchedBy(matchDeploymentSpecificRequest(d.ID)), - mocker.Anything).Return(nil).Run(m.allocationsFromState) - m.On("Evaluations", mocker.MatchedBy(matchJobSpecificRequest(j.ID)), - mocker.Anything).Return(nil).Run(m.evaluationsFromState) - m.On("GetJob", mocker.MatchedBy(matchJobSpecificRequest(j.ID)), - mocker.Anything).Return(nil).Run(m.getJobFromState) - m.On("GetJobVersions", mocker.MatchedBy(matchJobVersionsRequest(j.ID)), - mocker.Anything).Return(nil).Run(m.getJobVersionsFromState) - - w.SetEnabled(true) + w.SetEnabled(true, m.state) testutil.WaitForResult(func() (bool, error) { return 1 == len(w.watchers), nil }, func(err error) { assert.Equal(1, len(w.watchers), "Should have 1 deployment") }) @@ -417,16 +338,7 @@ func TestWatcher_PromoteDeployment_HealthyCanaries(t *testing.T) { assert.Nil(m.state.UpsertDeployment(m.nextIndex(), d), "UpsertDeployment") assert.Nil(m.state.UpsertAllocs(m.nextIndex(), []*structs.Allocation{a}), "UpsertAllocs") - // Assert the following methods will be called - m.On("List", mocker.Anything, mocker.Anything).Return(nil).Run(m.listFromState) - m.On("Allocations", 
mocker.MatchedBy(matchDeploymentSpecificRequest(d.ID)), - mocker.Anything).Return(nil).Run(m.allocationsFromState) - m.On("Evaluations", mocker.MatchedBy(matchJobSpecificRequest(j.ID)), - mocker.Anything).Return(nil).Run(m.evaluationsFromState) - m.On("GetJob", mocker.MatchedBy(matchJobSpecificRequest(j.ID)), - mocker.Anything).Return(nil).Run(m.getJobFromState) - - w.SetEnabled(true) + w.SetEnabled(true, m.state) testutil.WaitForResult(func() (bool, error) { return 1 == len(w.watchers), nil }, func(err error) { assert.Equal(1, len(w.watchers), "Should have 1 deployment") }) @@ -473,16 +385,7 @@ func TestWatcher_PromoteDeployment_UnhealthyCanaries(t *testing.T) { assert.Nil(m.state.UpsertDeployment(m.nextIndex(), d), "UpsertDeployment") assert.Nil(m.state.UpsertAllocs(m.nextIndex(), []*structs.Allocation{a}), "UpsertAllocs") - // Assert the following methods will be called - m.On("List", mocker.Anything, mocker.Anything).Return(nil).Run(m.listFromState) - m.On("Allocations", mocker.MatchedBy(matchDeploymentSpecificRequest(d.ID)), - mocker.Anything).Return(nil).Run(m.allocationsFromState) - m.On("Evaluations", mocker.MatchedBy(matchJobSpecificRequest(j.ID)), - mocker.Anything).Return(nil).Run(m.evaluationsFromState) - m.On("GetJob", mocker.MatchedBy(matchJobSpecificRequest(j.ID)), - mocker.Anything).Return(nil).Run(m.getJobFromState) - - w.SetEnabled(true) + w.SetEnabled(true, m.state) testutil.WaitForResult(func() (bool, error) { return 1 == len(w.watchers), nil }, func(err error) { assert.Equal(1, len(w.watchers), "Should have 1 deployment") }) @@ -525,16 +428,7 @@ func TestWatcher_PauseDeployment_Pause_Running(t *testing.T) { assert.Nil(m.state.UpsertJob(m.nextIndex(), j), "UpsertJob") assert.Nil(m.state.UpsertDeployment(m.nextIndex(), d), "UpsertDeployment") - // Assert the following methods will be called - m.On("List", mocker.Anything, mocker.Anything).Return(nil).Run(m.listFromState) - m.On("Allocations", mocker.MatchedBy(matchDeploymentSpecificRequest(d.ID)), - mocker.Anything).Return(nil).Run(m.allocationsFromState) - m.On("Evaluations", mocker.MatchedBy(matchJobSpecificRequest(j.ID)), - mocker.Anything).Return(nil).Run(m.evaluationsFromState) - m.On("GetJob", mocker.MatchedBy(matchJobSpecificRequest(j.ID)), - mocker.Anything).Return(nil).Run(m.getJobFromState) - - w.SetEnabled(true) + w.SetEnabled(true, m.state) testutil.WaitForResult(func() (bool, error) { return 1 == len(w.watchers), nil }, func(err error) { assert.Equal(1, len(w.watchers), "Should have 1 deployment") }) @@ -574,16 +468,7 @@ func TestWatcher_PauseDeployment_Pause_Paused(t *testing.T) { assert.Nil(m.state.UpsertJob(m.nextIndex(), j), "UpsertJob") assert.Nil(m.state.UpsertDeployment(m.nextIndex(), d), "UpsertDeployment") - // Assert the following methods will be called - m.On("List", mocker.Anything, mocker.Anything).Return(nil).Run(m.listFromState) - m.On("Allocations", mocker.MatchedBy(matchDeploymentSpecificRequest(d.ID)), - mocker.Anything).Return(nil).Run(m.allocationsFromState) - m.On("Evaluations", mocker.MatchedBy(matchJobSpecificRequest(j.ID)), - mocker.Anything).Return(nil).Run(m.evaluationsFromState) - m.On("GetJob", mocker.MatchedBy(matchJobSpecificRequest(j.ID)), - mocker.Anything).Return(nil).Run(m.getJobFromState) - - w.SetEnabled(true) + w.SetEnabled(true, m.state) testutil.WaitForResult(func() (bool, error) { return 1 == len(w.watchers), nil }, func(err error) { assert.Equal(1, len(w.watchers), "Should have 1 deployment") }) @@ -623,16 +508,7 @@ func TestWatcher_PauseDeployment_Unpause_Paused(t 
*testing.T) { assert.Nil(m.state.UpsertJob(m.nextIndex(), j), "UpsertJob") assert.Nil(m.state.UpsertDeployment(m.nextIndex(), d), "UpsertDeployment") - // Assert the following methods will be called - m.On("List", mocker.Anything, mocker.Anything).Return(nil).Run(m.listFromState) - m.On("Allocations", mocker.MatchedBy(matchDeploymentSpecificRequest(d.ID)), - mocker.Anything).Return(nil).Run(m.allocationsFromState) - m.On("Evaluations", mocker.MatchedBy(matchJobSpecificRequest(j.ID)), - mocker.Anything).Return(nil).Run(m.evaluationsFromState) - m.On("GetJob", mocker.MatchedBy(matchJobSpecificRequest(j.ID)), - mocker.Anything).Return(nil).Run(m.getJobFromState) - - w.SetEnabled(true) + w.SetEnabled(true, m.state) testutil.WaitForResult(func() (bool, error) { return 1 == len(w.watchers), nil }, func(err error) { assert.Equal(1, len(w.watchers), "Should have 1 deployment") }) @@ -672,16 +548,7 @@ func TestWatcher_PauseDeployment_Unpause_Running(t *testing.T) { assert.Nil(m.state.UpsertJob(m.nextIndex(), j), "UpsertJob") assert.Nil(m.state.UpsertDeployment(m.nextIndex(), d), "UpsertDeployment") - // Assert the following methods will be called - m.On("List", mocker.Anything, mocker.Anything).Return(nil).Run(m.listFromState) - m.On("Allocations", mocker.MatchedBy(matchDeploymentSpecificRequest(d.ID)), - mocker.Anything).Return(nil).Run(m.allocationsFromState) - m.On("Evaluations", mocker.MatchedBy(matchJobSpecificRequest(j.ID)), - mocker.Anything).Return(nil).Run(m.evaluationsFromState) - m.On("GetJob", mocker.MatchedBy(matchJobSpecificRequest(j.ID)), - mocker.Anything).Return(nil).Run(m.getJobFromState) - - w.SetEnabled(true) + w.SetEnabled(true, m.state) testutil.WaitForResult(func() (bool, error) { return 1 == len(w.watchers), nil }, func(err error) { assert.Equal(1, len(w.watchers), "Should have 1 deployment") }) @@ -721,16 +588,7 @@ func TestWatcher_FailDeployment_Running(t *testing.T) { assert.Nil(m.state.UpsertJob(m.nextIndex(), j), "UpsertJob") assert.Nil(m.state.UpsertDeployment(m.nextIndex(), d), "UpsertDeployment") - // Assert the following methods will be called - m.On("List", mocker.Anything, mocker.Anything).Return(nil).Run(m.listFromState) - m.On("Allocations", mocker.MatchedBy(matchDeploymentSpecificRequest(d.ID)), - mocker.Anything).Return(nil).Run(m.allocationsFromState) - m.On("Evaluations", mocker.MatchedBy(matchJobSpecificRequest(j.ID)), - mocker.Anything).Return(nil).Run(m.evaluationsFromState) - m.On("GetJob", mocker.MatchedBy(matchJobSpecificRequest(j.ID)), - mocker.Anything).Return(nil).Run(m.getJobFromState) - - w.SetEnabled(true) + w.SetEnabled(true, m.state) testutil.WaitForResult(func() (bool, error) { return 1 == len(w.watchers), nil }, func(err error) { assert.Equal(1, len(w.watchers), "Should have 1 deployment") }) @@ -783,18 +641,7 @@ func TestDeploymentWatcher_Watch(t *testing.T) { j2.Stable = false assert.Nil(m.state.UpsertJob(m.nextIndex(), j2), "UpsertJob2") - // Assert the following methods will be called - m.On("List", mocker.Anything, mocker.Anything).Return(nil).Run(m.listFromState) - m.On("Allocations", mocker.MatchedBy(matchDeploymentSpecificRequest(d.ID)), - mocker.Anything).Return(nil).Run(m.allocationsFromState) - m.On("Evaluations", mocker.MatchedBy(matchJobSpecificRequest(j.ID)), - mocker.Anything).Return(nil).Run(m.evaluationsFromState) - m.On("GetJob", mocker.MatchedBy(matchJobSpecificRequest(j.ID)), - mocker.Anything).Return(nil).Run(m.getJobFromState) - m.On("GetJobVersions", mocker.MatchedBy(matchJobVersionsRequest(j.ID)), - 
mocker.Anything).Return(nil).Run(m.getJobVersionsFromState) - - w.SetEnabled(true) + w.SetEnabled(true, m.state) testutil.WaitForResult(func() (bool, error) { return 1 == len(w.watchers), nil }, func(err error) { assert.Equal(1, len(w.watchers), "Should have 1 deployment") }) @@ -912,30 +759,7 @@ func TestWatcher_BatchEvals(t *testing.T) { assert.Nil(m.state.UpsertAllocs(m.nextIndex(), []*structs.Allocation{a1}), "UpsertAllocs") assert.Nil(m.state.UpsertAllocs(m.nextIndex(), []*structs.Allocation{a2}), "UpsertAllocs") - // Assert the following methods will be called - m.On("List", mocker.Anything, mocker.Anything).Return(nil).Run(m.listFromState) - - m.On("Allocations", mocker.MatchedBy(matchDeploymentSpecificRequest(d1.ID)), - mocker.Anything).Return(nil).Run(m.allocationsFromState) - m.On("Allocations", mocker.MatchedBy(matchDeploymentSpecificRequest(d2.ID)), - mocker.Anything).Return(nil).Run(m.allocationsFromState) - - m.On("Evaluations", mocker.MatchedBy(matchJobSpecificRequest(j1.ID)), - mocker.Anything).Return(nil).Run(m.evaluationsFromState) - m.On("Evaluations", mocker.MatchedBy(matchJobSpecificRequest(j2.ID)), - mocker.Anything).Return(nil).Run(m.evaluationsFromState) - - m.On("GetJob", mocker.MatchedBy(matchJobSpecificRequest(j1.ID)), - mocker.Anything).Return(nil).Run(m.getJobFromState) - m.On("GetJob", mocker.MatchedBy(matchJobSpecificRequest(j2.ID)), - mocker.Anything).Return(nil).Run(m.getJobFromState) - - m.On("GetJobVersions", mocker.MatchedBy(matchJobVersionsRequest(j1.ID)), - mocker.Anything).Return(nil).Run(m.getJobVersionsFromState) - m.On("GetJobVersions", mocker.MatchedBy(matchJobVersionsRequest(j2.ID)), - mocker.Anything).Return(nil).Run(m.getJobVersionsFromState) - - w.SetEnabled(true) + w.SetEnabled(true, m.state) testutil.WaitForResult(func() (bool, error) { return 2 == len(w.watchers), nil }, func(err error) { assert.Equal(2, len(w.watchers), "Should have 2 deployment") }) diff --git a/nomad/deploymentwatcher/testutil_test.go b/nomad/deploymentwatcher/testutil_test.go index 768a28717..06fbf542c 100644 --- a/nomad/deploymentwatcher/testutil_test.go +++ b/nomad/deploymentwatcher/testutil_test.go @@ -8,7 +8,6 @@ import ( "sync" "testing" - memdb "github.com/hashicorp/go-memdb" "github.com/hashicorp/nomad/nomad/state" "github.com/hashicorp/nomad/nomad/structs" mocker "github.com/stretchr/testify/mock" @@ -256,114 +255,3 @@ func matchDeploymentAllocHealthRequest(c *matchDeploymentAllocHealthRequestConfi return true } } - -func (m *mockBackend) Evaluations(args *structs.JobSpecificRequest, reply *structs.JobEvaluationsResponse) error { - rargs := m.Called(args, reply) - return rargs.Error(0) -} - -func (m *mockBackend) evaluationsFromState(in mocker.Arguments) { - args, reply := in.Get(0).(*structs.JobSpecificRequest), in.Get(1).(*structs.JobEvaluationsResponse) - ws := memdb.NewWatchSet() - evals, _ := m.state.EvalsByJob(ws, args.JobID) - reply.Evaluations = evals - reply.Index, _ = m.state.Index("evals") -} - -func (m *mockBackend) Allocations(args *structs.DeploymentSpecificRequest, reply *structs.AllocListResponse) error { - rargs := m.Called(args, reply) - return rargs.Error(0) -} - -func (m *mockBackend) allocationsFromState(in mocker.Arguments) { - args, reply := in.Get(0).(*structs.DeploymentSpecificRequest), in.Get(1).(*structs.AllocListResponse) - ws := memdb.NewWatchSet() - allocs, _ := m.state.AllocsByDeployment(ws, args.DeploymentID) - - var stubs []*structs.AllocListStub - for _, a := range allocs { - stubs = append(stubs, a.Stub()) - } - - 
reply.Allocations = stubs - reply.Index, _ = m.state.Index("allocs") -} - -func (m *mockBackend) List(args *structs.DeploymentListRequest, reply *structs.DeploymentListResponse) error { - rargs := m.Called(args, reply) - return rargs.Error(0) -} - -func (m *mockBackend) listFromState(in mocker.Arguments) { - reply := in.Get(1).(*structs.DeploymentListResponse) - ws := memdb.NewWatchSet() - iter, _ := m.state.Deployments(ws) - - var deploys []*structs.Deployment - for { - raw := iter.Next() - if raw == nil { - break - } - - deploys = append(deploys, raw.(*structs.Deployment)) - } - - reply.Deployments = deploys - reply.Index, _ = m.state.Index("deployment") -} - -func (m *mockBackend) GetDeployment(args *structs.DeploymentSpecificRequest, reply *structs.SingleDeploymentResponse) error { - rargs := m.Called(args, reply) - return rargs.Error(0) -} - -func (m *mockBackend) GetJobVersions(args *structs.JobVersionsRequest, reply *structs.JobVersionsResponse) error { - rargs := m.Called(args, reply) - return rargs.Error(0) -} - -func (m *mockBackend) getJobVersionsFromState(in mocker.Arguments) { - args, reply := in.Get(0).(*structs.JobVersionsRequest), in.Get(1).(*structs.JobVersionsResponse) - ws := memdb.NewWatchSet() - versions, _ := m.state.JobVersionsByID(ws, args.JobID) - reply.Versions = versions - reply.Index, _ = m.state.Index("jobs") -} - -func (m *mockBackend) GetJob(args *structs.JobSpecificRequest, reply *structs.SingleJobResponse) error { - rargs := m.Called(args, reply) - return rargs.Error(0) -} - -func (m *mockBackend) getJobFromState(in mocker.Arguments) { - args, reply := in.Get(0).(*structs.JobSpecificRequest), in.Get(1).(*structs.SingleJobResponse) - ws := memdb.NewWatchSet() - job, _ := m.state.JobByID(ws, args.JobID) - reply.Job = job - reply.Index, _ = m.state.Index("jobs") -} - -// matchDeploymentSpecificRequest is used to match that a deployment specific -// request is for the passed deployment id -func matchDeploymentSpecificRequest(dID string) func(args *structs.DeploymentSpecificRequest) bool { - return func(args *structs.DeploymentSpecificRequest) bool { - return args.DeploymentID == dID - } -} - -// matchJobSpecificRequest is used to match that a job specific -// request is for the passed job id -func matchJobSpecificRequest(jID string) func(args *structs.JobSpecificRequest) bool { - return func(args *structs.JobSpecificRequest) bool { - return args.JobID == jID - } -} - -// matchJobVersionsRequest is used to match that a job version -// request is for the passed job id -func matchJobVersionsRequest(jID string) func(args *structs.JobVersionsRequest) bool { - return func(args *structs.JobVersionsRequest) bool { - return args.JobID == jID - } -} diff --git a/nomad/leader.go b/nomad/leader.go index d54dfbbb2..75211d0d8 100644 --- a/nomad/leader.go +++ b/nomad/leader.go @@ -131,7 +131,7 @@ func (s *Server) establishLeadership(stopCh chan struct{}) error { s.blockedEvals.SetEnabled(true) // Enable the deployment watcher, since we are now the leader - if err := s.deploymentWatcher.SetEnabled(true); err != nil { + if err := s.deploymentWatcher.SetEnabled(true, s.State()); err != nil { return err } @@ -494,7 +494,7 @@ func (s *Server) revokeLeadership() error { s.vault.SetActive(false) // Disable the deployment watcher as it is only useful as a leader. 
- if err := s.deploymentWatcher.SetEnabled(false); err != nil { + if err := s.deploymentWatcher.SetEnabled(false, nil); err != nil { return err } diff --git a/nomad/rpc.go b/nomad/rpc.go index b0eeff515..5cbd82226 100644 --- a/nomad/rpc.go +++ b/nomad/rpc.go @@ -1,6 +1,7 @@ package nomad import ( + "context" "crypto/tls" "fmt" "io" @@ -341,7 +342,9 @@ type blockingOptions struct { // blockingRPC is used for queries that need to wait for a // minimum index. This is used to block and wait for changes. func (s *Server) blockingRPC(opts *blockingOptions) error { - var timeout *time.Timer + var deadline time.Time + ctx := context.Background() + var cancel context.CancelFunc var state *state.StateStore // Fast path non-blocking @@ -360,8 +363,9 @@ func (s *Server) blockingRPC(opts *blockingOptions) error { opts.queryOpts.MaxQueryTime += lib.RandomStagger(opts.queryOpts.MaxQueryTime / jitterFraction) // Setup a query timeout - timeout = time.NewTimer(opts.queryOpts.MaxQueryTime) - defer timeout.Stop() + deadline = time.Now().Add(opts.queryOpts.MaxQueryTime) + ctx, cancel = context.WithDeadline(context.Background(), deadline) + defer cancel() RUN_QUERY: // Update the query meta data @@ -393,7 +397,7 @@ RUN_QUERY: // Check for minimum query time if err == nil && opts.queryOpts.MinQueryIndex > 0 && opts.queryMeta.Index <= opts.queryOpts.MinQueryIndex { - if expired := ws.Watch(timeout.C); !expired { + if expired := ws.WatchCtx(ctx); !expired { goto RUN_QUERY } } diff --git a/nomad/server.go b/nomad/server.go index fbd8f728f..2f1ff12a8 100644 --- a/nomad/server.go +++ b/nomad/server.go @@ -679,23 +679,15 @@ func (s *Server) setupConsulSyncer() error { // shim that provides the appropriate methods. func (s *Server) setupDeploymentWatcher() error { - // Create the shims - stateShim := &deploymentWatcherStateShim{ - region: s.Region(), - evaluations: s.endpoints.Job.Evaluations, - allocations: s.endpoints.Deployment.Allocations, - list: s.endpoints.Deployment.List, - getDeployment: s.endpoints.Deployment.GetDeployment, - getJobVersions: s.endpoints.Job.GetJobVersions, - getJob: s.endpoints.Job.GetJob, - } + // Create the raft shim type to restrict the set of raft methods that can be + // made raftShim := &deploymentWatcherRaftShim{ apply: s.raftApply, } // Create the deployment watcher s.deploymentWatcher = deploymentwatcher.NewDeploymentsWatcher( - s.logger, stateShim, raftShim, + s.logger, raftShim, deploymentwatcher.LimitStateQueriesPerSecond, deploymentwatcher.CrossDeploymentEvalBatchDuration) diff --git a/nomad/state/state_store.go b/nomad/state/state_store.go index 6d8923ac4..0bef7d1aa 100644 --- a/nomad/state/state_store.go +++ b/nomad/state/state_store.go @@ -1,6 +1,7 @@ package state import ( + "context" "fmt" "io" "log" @@ -87,6 +88,48 @@ func (s *StateStore) Abandon() { close(s.abandonCh) } +// QueryFn is the definition of a function that can be used to implement a basic +// blocking query against the state store. +type QueryFn func(memdb.WatchSet, *StateStore) (resp interface{}, index uint64, err error) + +// BlockingQuery takes a query function and runs the function until the minimum +// query index is met or until the passed context is cancelled. +func (s *StateStore) BlockingQuery(query QueryFn, minIndex uint64, ctx context.Context) ( + resp interface{}, index uint64, err error) { + +RUN_QUERY: + // We capture the state store and its abandon channel but pass a snapshot to + // the blocking query function. 
We operate on the snapshot to allow separate + // calls to the state store not all wrapped within the same transaction. + abandonCh := s.AbandonCh() + snap, _ := s.Snapshot() + stateSnap := &snap.StateStore + + // We can skip all watch tracking if this isn't a blocking query. + var ws memdb.WatchSet + if minIndex > 0 { + ws = memdb.NewWatchSet() + + // This channel will be closed if a snapshot is restored and the + // whole state store is abandoned. + ws.Add(abandonCh) + } + + resp, index, err = query(ws, stateSnap) + if err != nil { + return nil, index, err + } + + // We haven't reached the min-index yet. + if minIndex > 0 && index <= minIndex { + if expired := ws.WatchCtx(ctx); !expired { + goto RUN_QUERY + } + } + + return resp, index, nil +} + // UpsertPlanResults is used to upsert the results of a plan. func (s *StateStore) UpsertPlanResults(index uint64, results *structs.ApplyPlanResultsRequest) error { txn := s.db.Txn(true) diff --git a/nomad/state/state_store_test.go b/nomad/state/state_store_test.go index c47e53b0d..71f1e9003 100644 --- a/nomad/state/state_store_test.go +++ b/nomad/state/state_store_test.go @@ -1,6 +1,7 @@ package state import ( + "context" "fmt" "os" "reflect" @@ -27,6 +28,69 @@ func testStateStore(t *testing.T) *StateStore { return state } +func TestStateStore_Blocking_Error(t *testing.T) { + expected := fmt.Errorf("test error") + errFn := func(memdb.WatchSet, *StateStore) (interface{}, uint64, error) { + return nil, 0, expected + } + + state := testStateStore(t) + _, idx, err := state.BlockingQuery(errFn, 10, context.Background()) + assert.EqualError(t, err, expected.Error()) + assert.Zero(t, idx) +} + +func TestStateStore_Blocking_Timeout(t *testing.T) { + noopFn := func(memdb.WatchSet, *StateStore) (interface{}, uint64, error) { + return nil, 5, nil + } + + state := testStateStore(t) + timeout := time.Now().Add(10 * time.Millisecond) + deadlineCtx, cancel := context.WithDeadline(context.Background(), timeout) + defer cancel() + + _, idx, err := state.BlockingQuery(noopFn, 10, deadlineCtx) + assert.Nil(t, err) + assert.EqualValues(t, 5, idx) + assert.WithinDuration(t, timeout, time.Now(), 5*time.Millisecond) +} + +func TestStateStore_Blocking_MinQuery(t *testing.T) { + job := mock.Job() + count := 0 + queryFn := func(ws memdb.WatchSet, s *StateStore) (interface{}, uint64, error) { + _, err := s.JobByID(ws, job.ID) + if err != nil { + return nil, 0, err + } + + count++ + if count == 1 { + return false, 5, nil + } else if count > 2 { + return false, 20, fmt.Errorf("called too many times") + } + + return true, 11, nil + } + + state := testStateStore(t) + timeout := time.Now().Add(10 * time.Millisecond) + deadlineCtx, cancel := context.WithDeadline(context.Background(), timeout) + defer cancel() + + time.AfterFunc(5*time.Millisecond, func() { + state.UpsertJob(11, job) + }) + + resp, idx, err := state.BlockingQuery(queryFn, 10, deadlineCtx) + assert.Nil(t, err) + assert.Equal(t, 2, count) + assert.EqualValues(t, 11, idx) + assert.True(t, resp.(bool)) +} + // This test checks that: // 1) The job is denormalized // 2) Allocations are created diff --git a/vendor/github.com/hashicorp/go-immutable-radix/iradix.go b/vendor/github.com/hashicorp/go-immutable-radix/iradix.go index 551ccbde7..c7172c406 100644 --- a/vendor/github.com/hashicorp/go-immutable-radix/iradix.go +++ b/vendor/github.com/hashicorp/go-immutable-radix/iradix.go @@ -183,6 +183,31 @@ func (t *Txn) writeNode(n *Node, forLeafUpdate bool) *Node { return nc } +// Visit all the nodes in the tree under n, and 
add their mutateChannels to the transaction +// Returns the size of the subtree visited +func (t *Txn) trackChannelsAndCount(n *Node) int { + // Count only leaf nodes + leaves := 0 + if n.leaf != nil { + leaves = 1 + } + // Mark this node as being mutated. + if t.trackMutate { + t.trackChannel(n.mutateCh) + } + + // Mark its leaf as being mutated, if appropriate. + if t.trackMutate && n.leaf != nil { + t.trackChannel(n.leaf.mutateCh) + } + + // Recurse on the children + for _, e := range n.edges { + leaves += t.trackChannelsAndCount(e.node) + } + return leaves +} + // mergeChild is called to collapse the given node with its child. This is only // called when the given node is not a leaf and has a single edge. func (t *Txn) mergeChild(n *Node) { @@ -357,6 +382,56 @@ func (t *Txn) delete(parent, n *Node, search []byte) (*Node, *leafNode) { return nc, leaf } +// delete does a recursive deletion +func (t *Txn) deletePrefix(parent, n *Node, search []byte) (*Node, int) { + // Check for key exhaustion + if len(search) == 0 { + nc := t.writeNode(n, true) + if n.isLeaf() { + nc.leaf = nil + } + nc.edges = nil + return nc, t.trackChannelsAndCount(n) + } + + // Look for an edge + label := search[0] + idx, child := n.getEdge(label) + // We make sure that either the child node's prefix starts with the search term, or the search term starts with the child node's prefix + // Need to do both so that we can delete prefixes that don't correspond to any node in the tree + if child == nil || (!bytes.HasPrefix(child.prefix, search) && !bytes.HasPrefix(search, child.prefix)) { + return nil, 0 + } + + // Consume the search prefix + if len(child.prefix) > len(search) { + search = []byte("") + } else { + search = search[len(child.prefix):] + } + newChild, numDeletions := t.deletePrefix(n, child, search) + if newChild == nil { + return nil, 0 + } + // Copy this node. WATCH OUT - it's safe to pass "false" here because we + // will only ADD a leaf via nc.mergeChild() if there isn't one due to + // the !nc.isLeaf() check in the logic just below. This is pretty subtle, + // so be careful if you change any of the logic here. + + nc := t.writeNode(n, false) + + // Delete the edge if the node has no edges + if newChild.leaf == nil && len(newChild.edges) == 0 { + nc.delEdge(label) + if n != t.root && len(nc.edges) == 1 && !nc.isLeaf() { + t.mergeChild(nc) + } + } else { + nc.edges[idx].node = newChild + } + return nc, numDeletions +} + // Insert is used to add or update a given key. The return provides // the previous value and a bool indicating if any was set. func (t *Txn) Insert(k []byte, v interface{}) (interface{}, bool) { @@ -384,6 +459,19 @@ func (t *Txn) Delete(k []byte) (interface{}, bool) { return nil, false } +// DeletePrefix is used to delete an entire subtree that matches the prefix +// This will delete all nodes under that prefix +func (t *Txn) DeletePrefix(prefix []byte) bool { + newRoot, numDeletions := t.deletePrefix(nil, t.root, prefix) + if newRoot != nil { + t.root = newRoot + t.size = t.size - numDeletions + return true + } + return false + +} + // Root returns the current root of the radix tree within this // transaction. The root is not safe across insert and delete operations, // but can be used to read the current state during a transaction. @@ -524,6 +612,14 @@ func (t *Tree) Delete(k []byte) (*Tree, interface{}, bool) { return txn.Commit(), old, ok } +// DeletePrefix is used to delete all nodes starting with a given prefix. 
Returns the new tree, +// and a bool indicating if the prefix matched any nodes +func (t *Tree) DeletePrefix(k []byte) (*Tree, bool) { + txn := t.Txn() + ok := txn.DeletePrefix(k) + return txn.Commit(), ok +} + // Root returns the root node of the tree which can be used for richer // query operations. func (t *Tree) Root() *Node { diff --git a/vendor/github.com/hashicorp/go-memdb/README.md b/vendor/github.com/hashicorp/go-memdb/README.md index 675044beb..4e051c81a 100644 --- a/vendor/github.com/hashicorp/go-memdb/README.md +++ b/vendor/github.com/hashicorp/go-memdb/README.md @@ -21,6 +21,11 @@ The database provides the following: a single field index, or more advanced compound field indexes. Certain types like UUID can be efficiently compressed from strings into byte indexes for reduced storage requirements. + +* Watches - Callers can populate a watch set as part of a query, which can be used to + detect when a modification has been made to the database which affects the query + results. This lets callers easily watch for changes in the database in a very general + way. For the underlying immutable radix trees, see [go-immutable-radix](https://github.com/hashicorp/go-immutable-radix). diff --git a/vendor/github.com/hashicorp/go-memdb/index.go b/vendor/github.com/hashicorp/go-memdb/index.go index a40312aa5..d1fb95146 100644 --- a/vendor/github.com/hashicorp/go-memdb/index.go +++ b/vendor/github.com/hashicorp/go-memdb/index.go @@ -305,20 +305,19 @@ func (u *UintFieldIndex) FromArgs(args ...interface{}) ([]byte, error) { } // IsUintType returns whether the passed type is a type of uint and the number -// of bytes the type is. To avoid platform specific sizes, the uint type returns -// 8 bytes regardless of if it is smaller. +// of bytes needed to encode the type. func IsUintType(k reflect.Kind) (size int, okay bool) { switch k { case reflect.Uint: - return 8, true + return binary.MaxVarintLen64, true case reflect.Uint8: - return 1, true - case reflect.Uint16: return 2, true + case reflect.Uint16: + return binary.MaxVarintLen16, true case reflect.Uint32: - return 4, true + return binary.MaxVarintLen32, true case reflect.Uint64: - return 8, true + return binary.MaxVarintLen64, true default: return 0, false } diff --git a/vendor/github.com/hashicorp/go-memdb/memdb.go b/vendor/github.com/hashicorp/go-memdb/memdb.go index 13817547b..9e9b98df5 100644 --- a/vendor/github.com/hashicorp/go-memdb/memdb.go +++ b/vendor/github.com/hashicorp/go-memdb/memdb.go @@ -13,8 +13,8 @@ import ( // on values. The database makes use of immutable radix trees to provide // transactions and MVCC. type MemDB struct { - schema *DBSchema - root unsafe.Pointer // *iradix.Tree underneath + schema *DBSchema + root unsafe.Pointer // *iradix.Tree underneath primary bool // There can only be a single writter at once @@ -30,8 +30,8 @@ func NewMemDB(schema *DBSchema) (*MemDB, error) { // Create the MemDB db := &MemDB{ - schema: schema, - root: unsafe.Pointer(iradix.New()), + schema: schema, + root: unsafe.Pointer(iradix.New()), primary: true, } if err := db.initialize(); err != nil { @@ -65,8 +65,8 @@ func (db *MemDB) Txn(write bool) *Txn { // operations to the existing DB. 
func (db *MemDB) Snapshot() *MemDB { clone := &MemDB{ - schema: db.schema, - root: unsafe.Pointer(db.getRoot()), + schema: db.schema, + root: unsafe.Pointer(db.getRoot()), primary: false, } return clone @@ -76,7 +76,7 @@ func (db *MemDB) Snapshot() *MemDB { func (db *MemDB) initialize() error { root := db.getRoot() for tName, tableSchema := range db.schema.Tables { - for iName, _ := range tableSchema.Indexes { + for iName := range tableSchema.Indexes { index := iradix.New() path := indexPath(tName, iName) root, _, _ = root.Insert(path, index) diff --git a/vendor/github.com/hashicorp/go-memdb/txn.go b/vendor/github.com/hashicorp/go-memdb/txn.go index 617070e40..c4273648e 100644 --- a/vendor/github.com/hashicorp/go-memdb/txn.go +++ b/vendor/github.com/hashicorp/go-memdb/txn.go @@ -330,6 +330,96 @@ func (txn *Txn) Delete(table string, obj interface{}) error { return nil } +// DeletePrefix is used to delete an entire subtree based on a prefix. +// The given index must be a prefix index, and will be used to perform a scan and enumerate the set of objects to delete. +// These will be removed from all other indexes, and then a special prefix operation will delete the objects from the given index in an efficient subtree delete operation. +// This is useful when you have a very large number of objects indexed by the given index, along with a much smaller number of entries in the other indexes for those objects. +func (txn *Txn) DeletePrefix(table string, prefix_index string, prefix string) (bool, error) { + if !txn.write { + return false, fmt.Errorf("cannot delete in read-only transaction") + } + + if !strings.HasSuffix(prefix_index, "_prefix") { + return false, fmt.Errorf("Index name for DeletePrefix must be a prefix index, Got %v ", prefix_index) + } + + deletePrefixIndex := strings.TrimSuffix(prefix_index, "_prefix") + + // Get an iterator over all of the keys with the given prefix. + entries, err := txn.Get(table, prefix_index, prefix) + if err != nil { + return false, fmt.Errorf("failed kvs lookup: %s", err) + } + // Get the table schema + tableSchema, ok := txn.db.schema.Tables[table] + if !ok { + return false, fmt.Errorf("invalid table '%s'", table) + } + + foundAny := false + for entry := entries.Next(); entry != nil; entry = entries.Next() { + if !foundAny { + foundAny = true + } + // Get the primary ID of the object + idSchema := tableSchema.Indexes[id] + idIndexer := idSchema.Indexer.(SingleIndexer) + ok, idVal, err := idIndexer.FromObject(entry) + if err != nil { + return false, fmt.Errorf("failed to build primary index: %v", err) + } + if !ok { + return false, fmt.Errorf("object missing primary index") + } + // Remove the object from all the indexes except the given prefix index + for name, indexSchema := range tableSchema.Indexes { + if name == deletePrefixIndex { + continue + } + indexTxn := txn.writableIndex(table, name) + + // Handle the update by deleting from the index first + var ( + ok bool + vals [][]byte + err error + ) + switch indexer := indexSchema.Indexer.(type) { + case SingleIndexer: + var val []byte + ok, val, err = indexer.FromObject(entry) + vals = [][]byte{val} + case MultiIndexer: + ok, vals, err = indexer.FromObject(entry) + } + if err != nil { + return false, fmt.Errorf("failed to build index '%s': %v", name, err) + } + + if ok { + // Handle non-unique index by computing a unique index. + // This is done by appending the primary key which must + // be unique anyways. + for _, val := range vals { + if !indexSchema.Unique { + val = append(val, idVal...) 
+ } + indexTxn.Delete(val) + } + } + } + } + if foundAny { + indexTxn := txn.writableIndex(table, deletePrefixIndex) + ok = indexTxn.DeletePrefix([]byte(prefix)) + if !ok { + panic(fmt.Errorf("prefix %v matched some entries but DeletePrefix did not delete any ", prefix)) + } + return true, nil + } + return false, nil +} + // DeleteAll is used to delete all the objects in a given table // matching the constraints on the index func (txn *Txn) DeleteAll(table, index string, args ...interface{}) (int, error) { diff --git a/vendor/github.com/hashicorp/go-memdb/watch.go b/vendor/github.com/hashicorp/go-memdb/watch.go index 7c4a3ba6e..2f6dcdb81 100644 --- a/vendor/github.com/hashicorp/go-memdb/watch.go +++ b/vendor/github.com/hashicorp/go-memdb/watch.go @@ -1,6 +1,9 @@ package memdb -import "time" +import ( + "context" + "time" +) // WatchSet is a collection of watch channels. type WatchSet map[<-chan struct{}]struct{} @@ -46,6 +49,30 @@ func (w WatchSet) Watch(timeoutCh <-chan time.Time) bool { return false } + // Create a context that gets cancelled when the timeout is triggered + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { + select { + case <-timeoutCh: + cancel() + case <-ctx.Done(): + } + }() + + return w.WatchCtx(ctx) +} + +// WatchCtx is used to wait for either the watch set to trigger or for the +// context to be cancelled. Returns true if the context is cancelled. Watch with +// a timeout channel can be mimicked by creating a context with a deadline. +// WatchCtx should be preferred over Watch. +func (w WatchSet) WatchCtx(ctx context.Context) bool { + if w == nil { + return false + } + if n := len(w); n <= aFew { idx := 0 chunk := make([]<-chan struct{}, aFew) @@ -53,23 +80,18 @@ func (w WatchSet) Watch(timeoutCh <-chan time.Time) bool { chunk[idx] = watchCh idx++ } - return watchFew(chunk, timeoutCh) - } else { - return w.watchMany(timeoutCh) + return watchFew(ctx, chunk) } + + return w.watchMany(ctx) } // watchMany is used if there are many watchers. -func (w WatchSet) watchMany(timeoutCh <-chan time.Time) bool { - // Make a fake timeout channel we can feed into watchFew to cancel all - // the blocking goroutines. - doneCh := make(chan time.Time) - defer close(doneCh) - +func (w WatchSet) watchMany(ctx context.Context) bool { // Set up a goroutine for each watcher. triggerCh := make(chan struct{}, 1) watcher := func(chunk []<-chan struct{}) { - if timeout := watchFew(chunk, doneCh); !timeout { + if timeout := watchFew(ctx, chunk); !timeout { select { case triggerCh <- struct{}{}: default: @@ -102,7 +124,7 @@ func (w WatchSet) watchMany(timeoutCh <-chan time.Time) bool { select { case <-triggerCh: return false - case <-timeoutCh: + case <-ctx.Done(): return true } } diff --git a/vendor/github.com/hashicorp/go-memdb/watch_few.go b/vendor/github.com/hashicorp/go-memdb/watch_few.go index f2bb19db1..cd06c619c 100644 --- a/vendor/github.com/hashicorp/go-memdb/watch_few.go +++ b/vendor/github.com/hashicorp/go-memdb/watch_few.go @@ -1,8 +1,9 @@ -//go:generate sh -c "go run watch-gen/main.go >watch_few.go" package memdb +//go:generate sh -c "go run watch-gen/main.go >watch_few.go" + import( - "time" + "context" ) // aFew gives how many watchers this function is wired to support. You must @@ -11,7 +12,7 @@ const aFew = 32 // watchFew is used if there are only a few watchers as a performance // optimization. 
-func watchFew(ch []<-chan struct{}, timeoutCh <-chan time.Time) bool { +func watchFew(ctx context.Context, ch []<-chan struct{}) bool { select { case <-ch[0]: @@ -110,7 +111,7 @@ func watchFew(ch []<-chan struct{}, timeoutCh <-chan time.Time) bool { case <-ch[31]: return false - case <-timeoutCh: + case <-ctx.Done(): return true } } diff --git a/vendor/vendor.json b/vendor/vendor.json index 025c95475..77abfd9cc 100644 --- a/vendor/vendor.json +++ b/vendor/vendor.json @@ -731,16 +731,16 @@ "revisionTime": "2017-07-16T17:45:23Z" }, { - "checksumSHA1": "zvmksNyW6g+Fd/bywd4vcn8rp+M=", + "checksumSHA1": "Cas2nprG6pWzf05A2F/OlnjUu2Y=", "path": "github.com/hashicorp/go-immutable-radix", - "revision": "30664b879c9a771d8d50b137ab80ee0748cb2fcc", - "revisionTime": "2017-02-14T02:52:36Z" + "revision": "8aac2701530899b64bdea735a1de8da899815220", + "revisionTime": "2017-07-25T22:12:15Z" }, { - "checksumSHA1": "KeH4FuTKuv3tqFOr3NpLQtL1jPs=", + "checksumSHA1": "Q7MLoOLgXyvHBVmT/rvSeOhJo6c=", "path": "github.com/hashicorp/go-memdb", - "revision": "ed59a4bb9146689d4b00d060b70b9e9648b523af", - "revisionTime": "2017-04-11T17:33:47Z" + "revision": "f2dec88c7441ddf375eabd561b0a1584b67b8ce4", + "revisionTime": "2017-08-30T23:01:53Z" }, { "path": "github.com/hashicorp/go-msgpack/codec",
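Usage sketch (not part of the patch): a minimal example of how a caller might drive the new StateStore.BlockingQuery helper added in nomad/state/state_store.go. The query function mirrors Watcher.getDeploysImpl from this change; the wrapper name listDeploymentsBlocking and its placement are illustrative assumptions, while every call it makes (Deployments, Index, BlockingQuery) comes from the patch itself.

package example

import (
	"context"

	memdb "github.com/hashicorp/go-memdb"

	"github.com/hashicorp/nomad/nomad/state"
	"github.com/hashicorp/nomad/nomad/structs"
)

// listDeploymentsBlocking returns all deployments once the deployment table
// advances past minIndex, or returns early if ctx is cancelled.
func listDeploymentsBlocking(ctx context.Context, store *state.StateStore, minIndex uint64) ([]*structs.Deployment, uint64, error) {
	queryFn := func(ws memdb.WatchSet, s *state.StateStore) (interface{}, uint64, error) {
		// Capture all deployments, registering their watch channels in ws so
		// BlockingQuery can wait for a change when the index has not advanced.
		iter, err := s.Deployments(ws)
		if err != nil {
			return nil, 0, err
		}

		var deploys []*structs.Deployment
		for raw := iter.Next(); raw != nil; raw = iter.Next() {
			deploys = append(deploys, raw.(*structs.Deployment))
		}

		// Use the last index that affected the deployment table.
		index, err := s.Index("deployment")
		if err != nil {
			return nil, 0, err
		}

		return deploys, index, nil
	}

	raw, index, err := store.BlockingQuery(queryFn, minIndex, ctx)
	if err != nil {
		return nil, 0, err
	}

	return raw.([]*structs.Deployment), index, nil
}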