diff --git a/api/session.go b/api/session.go index 574738127..36e99a389 100644 --- a/api/session.go +++ b/api/session.go @@ -1,6 +1,7 @@ package api import ( + "errors" "fmt" "time" ) @@ -16,6 +17,8 @@ const ( SessionBehaviorDelete = "delete" ) +var ErrSessionExpired = errors.New("session expired") + // SessionEntry represents a session in consul type SessionEntry struct { CreateIndex uint64 @@ -113,11 +116,26 @@ func (s *Session) Destroy(id string, q *WriteOptions) (*WriteMeta, error) { // Renew renews the TTL on a given session func (s *Session) Renew(id string, q *WriteOptions) (*SessionEntry, *WriteMeta, error) { - var entries []*SessionEntry - wm, err := s.c.write("/v1/session/renew/"+id, nil, &entries, q) + r := s.c.newRequest("PUT", "/v1/session/renew/"+id) + r.setWriteOptions(q) + rtt, resp, err := s.c.doRequest(r) if err != nil { return nil, nil, err } + defer resp.Body.Close() + + wm := &WriteMeta{RequestTime: rtt} + + if resp.StatusCode == 404 { + return nil, wm, nil + } else if resp.StatusCode != 200 { + return nil, nil, fmt.Errorf("Unexpected response code: %d", resp.StatusCode) + } + + var entries []*SessionEntry + if err := decodeBody(resp, &entries); err != nil { + return nil, nil, fmt.Errorf("Failed to read response: %v", err) + } if len(entries) > 0 { return entries[0], wm, nil } @@ -149,9 +167,7 @@ func (s *Session) RenewPeriodic(initialTTL string, id string, q *WriteOptions, d continue } if entry == nil { - waitDur = time.Second - lastErr = fmt.Errorf("No SessionEntry returned") - continue + return ErrSessionExpired } // Handle the server updating the TTL diff --git a/api/session_test.go b/api/session_test.go index c503c21a0..85bea228e 100644 --- a/api/session_test.go +++ b/api/session_test.go @@ -2,6 +2,7 @@ package api import ( "testing" + "time" ) func TestSession_CreateDestroy(t *testing.T) { @@ -85,6 +86,114 @@ func TestSession_CreateRenewDestroy(t *testing.T) { } } +func TestSession_CreateRenewDestroyRenew(t *testing.T) { + t.Parallel() + c, s := makeClient(t) + defer s.Stop() + + session := c.Session() + + entry := &SessionEntry{ + Behavior: SessionBehaviorDelete, + TTL: "500s", // disable ttl + } + + id, meta, err := session.Create(entry, nil) + if err != nil { + t.Fatalf("err: %v", err) + } + + if meta.RequestTime == 0 { + t.Fatalf("bad: %v", meta) + } + + if id == "" { + t.Fatalf("invalid: %v", id) + } + + // Extend right after create. Everything should be fine. + entry, _, err = session.Renew(id, nil) + if err != nil { + t.Fatalf("err: %v", err) + } + if entry == nil { + t.Fatal("session unexpectedly vanished") + } + + // Simulate TTL loss by manually destroying the session. + meta, err = session.Destroy(id, nil) + if err != nil { + t.Fatalf("err: %v", err) + } + + if meta.RequestTime == 0 { + t.Fatalf("bad: %v", meta) + } + + // Extend right after delete. The 404 should proxy as a nil. + entry, _, err = session.Renew(id, nil) + if err != nil { + t.Fatalf("err: %v", err) + } + if entry != nil { + t.Fatal("session still exists") + } +} + +func TestSession_CreateDestroyRenewPeriodic(t *testing.T) { + t.Parallel() + c, s := makeClient(t) + defer s.Stop() + + session := c.Session() + + entry := &SessionEntry{ + Behavior: SessionBehaviorDelete, + TTL: "500s", // disable ttl + } + + id, meta, err := session.Create(entry, nil) + if err != nil { + t.Fatalf("err: %v", err) + } + + if meta.RequestTime == 0 { + t.Fatalf("bad: %v", meta) + } + + if id == "" { + t.Fatalf("invalid: %v", id) + } + + // This only tests Create/Destroy/RenewPeriodic to avoid the more + // difficult case of testing all of the timing code. + + // Simulate TTL loss by manually destroying the session. + meta, err = session.Destroy(id, nil) + if err != nil { + t.Fatalf("err: %v", err) + } + + if meta.RequestTime == 0 { + t.Fatalf("bad: %v", meta) + } + + // Extend right after delete. The 404 should terminate the loop quickly and return ErrSessionExpired. + errCh := make(chan error, 1) + doneCh := make(chan struct{}) + go func() { errCh <- session.RenewPeriodic("1s", id, nil, doneCh) }() + defer close(doneCh) + + select { + case <-time.After(1 * time.Second): + t.Fatal("timedout: missing session did not terminate renewal loop") + case err = <-errCh: + if err != ErrSessionExpired { + t.Fatalf("err: %v", err) + } + } +} + func TestSession_Info(t *testing.T) { t.Parallel() c, s := makeClient(t)