diff --git a/sdk/testutil/retry/retry.go b/sdk/testutil/retry/retry.go index 09f845abe..b59bd1c4c 100644 --- a/sdk/testutil/retry/retry.go +++ b/sdk/testutil/retry/retry.go @@ -120,15 +120,7 @@ func dedup(a []string) string { func run(r Retryer, t Failer, f func(r *R)) { t.Helper() rr := &R{} - fail := func() { - t.Helper() - out := dedup(rr.output) - if out != "" { - t.Log(out) - } - t.FailNow() - } - for r.NextOr(t, fail) { + for r.Continue() { func() { defer func() { if p := recover(); p != nil && p != runFailed { @@ -142,6 +134,12 @@ func run(r Retryer, t Failer, f func(r *R)) { } rr.fail = false } + + out := dedup(rr.output) + if out != "" { + t.Log(out) + } + t.FailNow() } // DefaultFailer provides default retry.Run() behavior for unit tests. @@ -162,9 +160,9 @@ func ThreeTimes() *Counter { // Retryer provides an interface for repeating operations // until they succeed or an exit condition is met. type Retryer interface { - // NextOr returns true if the operation should be repeated. - // Otherwise, it calls fail and returns false. - NextOr(t Failer, fail func()) bool + // NextOr returns true if the operation should be repeated, otherwise it + // returns false to indicate retrying should stop. + Continue() bool } // Counter repeats an operation a given number of @@ -176,10 +174,8 @@ type Counter struct { count int } -func (r *Counter) NextOr(t Failer, fail func()) bool { - t.Helper() +func (r *Counter) Continue() bool { if r.count == r.Count { - fail() return false } if r.count > 0 { @@ -200,14 +196,12 @@ type Timer struct { stop time.Time } -func (r *Timer) NextOr(t Failer, fail func()) bool { - t.Helper() +func (r *Timer) Continue() bool { if r.stop.IsZero() { r.stop = time.Now().Add(r.Timeout) return true } if time.Now().After(r.stop) { - fail() return false } time.Sleep(r.Wait) diff --git a/sdk/testutil/retry/retry_test.go b/sdk/testutil/retry/retry_test.go index f58f8eb92..95186374e 100644 --- a/sdk/testutil/retry/retry_test.go +++ b/sdk/testutil/retry/retry_test.go @@ -3,6 +3,8 @@ package retry import ( "testing" "time" + + "github.com/stretchr/testify/require" ) // delta defines the time band a test run should complete in. @@ -19,19 +21,15 @@ func TestRetryer(t *testing.T) { for _, tt := range tests { t.Run(tt.desc, func(t *testing.T) { - var iters, fails int - fail := func() { fails++ } + var iters int start := time.Now() - for tt.r.NextOr(t, fail) { + for tt.r.Continue() { iters++ } dur := time.Since(start) if got, want := iters, 3; got != want { t.Fatalf("got %d retries want %d", got, want) } - if got, want := fails, 1; got != want { - t.Fatalf("got %d FailNow calls want %d", got, want) - } // since the first iteration happens immediately // the retryer waits only twice for three iterations. // order of events: (true, (wait) true, (wait) true, false) @@ -41,3 +39,32 @@ func TestRetryer(t *testing.T) { }) } } + +func TestRunWith(t *testing.T) { + t.Run("calls FailNow after exceeding retries", func(t *testing.T) { + ft := &fakeT{} + iter := 0 + RunWith(&Counter{Count: 3, Wait: time.Millisecond}, ft, func(r *R) { + iter++ + r.FailNow() + }) + + require.Equal(t, 3, iter) + require.Equal(t, 1, ft.fails) + }) +} + +type fakeT struct { + fails int +} + +func (f *fakeT) Helper() {} + +func (f *fakeT) Log(args ...interface{}) { +} + +func (f *fakeT) FailNow() { + f.fails++ +} + +var _ Failer = &fakeT{}