diff --git a/api/api.go b/api/api.go index ff06c5cc1..0a62b4f68 100644 --- a/api/api.go +++ b/api/api.go @@ -106,9 +106,25 @@ type QueryOptions struct { // a value from 0 to 5 (inclusive). RelayFactor uint8 - // Context (optional) is passed through to the underlying http request layer, can be used - // to set timeouts and deadlines as well as to cancel requests - Context context.Context + // ctx is an optional context pass through to the underlying HTTP + // request layer. Use Context() and WithContext() to manage this. + ctx context.Context +} + +func (o *QueryOptions) Context() context.Context { + if o != nil && o.ctx != nil { + return o.ctx + } + return context.Background() +} + +func (o *QueryOptions) WithContext(ctx context.Context) *QueryOptions { + o2 := new(QueryOptions) + if o != nil { + *o2 = *o + } + o2.ctx = ctx + return o2 } // WriteOptions are used to parameterize a write @@ -125,6 +141,26 @@ type WriteOptions struct { // relayed back to the sender through N other random nodes. Must be // a value from 0 to 5 (inclusive). RelayFactor uint8 + + // ctx is an optional context pass through to the underlying HTTP + // request layer. Use Context() and WithContext() to manage this. + ctx context.Context +} + +func (o *WriteOptions) Context() context.Context { + if o != nil && o.ctx != nil { + return o.ctx + } + return context.Background() +} + +func (o *WriteOptions) WithContext(ctx context.Context) *WriteOptions { + o2 := new(WriteOptions) + if o != nil { + *o2 = *o + } + o2.ctx = ctx + return o2 } // QueryMeta is used to return meta data about a query @@ -499,7 +535,7 @@ func (r *request) setQueryOptions(q *QueryOptions) { if q.RelayFactor != 0 { r.params.Set("relay-factor", strconv.Itoa(int(q.RelayFactor))) } - r.ctx = q.Context + r.ctx = q.ctx } // durToMsec converts a duration to a millisecond specified string. If the @@ -544,6 +580,7 @@ func (r *request) setWriteOptions(q *WriteOptions) { if q.RelayFactor != 0 { r.params.Set("relay-factor", strconv.Itoa(int(q.RelayFactor))) } + r.ctx = q.ctx } // toHTTP converts the request to an HTTP request diff --git a/api/session.go b/api/session.go index 6fcf00b1e..1613f11a6 100644 --- a/api/session.go +++ b/api/session.go @@ -146,6 +146,8 @@ func (s *Session) Renew(id string, q *WriteOptions) (*SessionEntry, *WriteMeta, // session until a doneCh is closed. This is meant to be used in a long running // goroutine to ensure a session stays valid. func (s *Session) RenewPeriodic(initialTTL string, id string, q *WriteOptions, doneCh <-chan struct{}) error { + ctx := q.Context() + ttl, err := time.ParseDuration(initialTTL) if err != nil { return err @@ -179,6 +181,11 @@ func (s *Session) RenewPeriodic(initialTTL string, id string, q *WriteOptions, d // Attempt a session destroy s.Destroy(id, q) return nil + + case <-ctx.Done(): + // Bail immediately since attempting the destroy would + // use the canceled context in q, which would just bail. + return ctx.Err() } } } diff --git a/api/session_test.go b/api/session_test.go index 468ccc331..0039bb2e3 100644 --- a/api/session_test.go +++ b/api/session_test.go @@ -1,6 +1,8 @@ package api import ( + "context" + "strings" "testing" "time" ) @@ -194,6 +196,82 @@ func TestAPI_SessionCreateDestroyRenewPeriodic(t *testing.T) { } } +func TestAPI_SessionRenewPeriodic_Cancel(t *testing.T) { + t.Parallel() + c, s := makeClient(t) + defer s.Stop() + + session := c.Session() + entry := &SessionEntry{ + Behavior: SessionBehaviorDelete, + TTL: "500s", // disable ttl + } + + t.Run("done channel", func(t *testing.T) { + id, _, err := session.Create(entry, nil) + if err != nil { + t.Fatalf("err: %v", err) + } + + errCh := make(chan error, 1) + doneCh := make(chan struct{}) + go func() { errCh <- session.RenewPeriodic("1s", id, nil, doneCh) }() + + close(doneCh) + + select { + case <-time.After(1 * time.Second): + t.Fatal("renewal loop didn't terminate") + case err = <-errCh: + if err != nil { + t.Fatalf("err: %v", err) + } + } + + sess, _, err := session.Info(id, nil) + if err != nil { + t.Fatalf("err: %v", err) + } + if sess != nil { + t.Fatalf("session was not expired") + } + }) + + t.Run("context", func(t *testing.T) { + id, _, err := session.Create(entry, nil) + if err != nil { + t.Fatalf("err: %v", err) + } + + ctx, cancel := context.WithCancel(context.Background()) + wo := new(WriteOptions).WithContext(ctx) + + errCh := make(chan error, 1) + go func() { errCh <- session.RenewPeriodic("1s", id, wo, nil) }() + + cancel() + + select { + case <-time.After(1 * time.Second): + t.Fatal("renewal loop didn't terminate") + case err = <-errCh: + if err == nil || !strings.Contains(err.Error(), "context canceled") { + t.Fatalf("err: %v", err) + } + } + + // See comment in session.go for why the session isn't removed + // in this case. + sess, _, err := session.Info(id, nil) + if err != nil { + t.Fatalf("err: %v", err) + } + if sess == nil { + t.Fatalf("session should not be expired") + } + }) +} + func TestAPI_SessionInfo(t *testing.T) { t.Parallel() c, s := makeClient(t) diff --git a/watch/funcs.go b/watch/funcs.go index 738343009..20265decc 100644 --- a/watch/funcs.go +++ b/watch/funcs.go @@ -234,6 +234,6 @@ func eventWatch(params map[string]interface{}) (WatcherFunc, error) { func makeQueryOptionsWithContext(p *Plan, stale bool) consulapi.QueryOptions { ctx, cancel := context.WithCancel(context.Background()) p.cancelFunc = cancel - opts := consulapi.QueryOptions{AllowStale: stale, WaitIndex: p.lastIndex, Context: ctx} - return opts + opts := consulapi.QueryOptions{AllowStale: stale, WaitIndex: p.lastIndex} + return *opts.WithContext(ctx) }