diff --git a/agent/consul/stats_fetcher_test.go b/agent/consul/stats_fetcher_test.go index 0bb5abb5e..a7829d46b 100644 --- a/agent/consul/stats_fetcher_test.go +++ b/agent/consul/stats_fetcher_test.go @@ -7,6 +7,7 @@ import ( "time" "github.com/hashicorp/consul/agent/metadata" + "github.com/hashicorp/consul/sdk/testutil/retry" "github.com/hashicorp/consul/testrpc" "github.com/hashicorp/consul/types" ) @@ -47,51 +48,55 @@ func TestStatsFetcher(t *testing.T) { // Do a normal fetch and make sure we get three responses. func() { - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - stats := s1.statsFetcher.Fetch(ctx, s1.LANMembers()) - if len(stats) != 3 { - t.Fatalf("bad: %#v", stats) - } - for id, stat := range stats { - switch types.NodeID(id) { - case s1.config.NodeID, s2.config.NodeID, s3.config.NodeID: - // OK - default: - t.Fatalf("bad: %s", id) + retry.Run(t, func(r *retry.R) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + stats := s1.statsFetcher.Fetch(ctx, s1.LANMembers()) + if len(stats) != 3 { + t.Fatalf("bad: %#v", stats) } + for id, stat := range stats { + switch types.NodeID(id) { + case s1.config.NodeID, s2.config.NodeID, s3.config.NodeID: + // OK + default: + t.Fatalf("bad: %s", id) + } - if stat == nil || stat.LastTerm == 0 { - t.Fatalf("bad: %#v", stat) + if stat == nil || stat.LastTerm == 0 { + t.Fatalf("bad: %#v", stat) + } } - } + }) }() // Fake an in-flight request to server 3 and make sure we don't fetch // from it. func() { - s1.statsFetcher.inflight[string(s3.config.NodeID)] = struct{}{} - defer delete(s1.statsFetcher.inflight, string(s3.config.NodeID)) + retry.Run(t, func(r *retry.R) { + s1.statsFetcher.inflight[string(s3.config.NodeID)] = struct{}{} + defer delete(s1.statsFetcher.inflight, string(s3.config.NodeID)) - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - stats := s1.statsFetcher.Fetch(ctx, s1.LANMembers()) - if len(stats) != 2 { - t.Fatalf("bad: %#v", stats) - } - for id, stat := range stats { - switch types.NodeID(id) { - case s1.config.NodeID, s2.config.NodeID: - // OK - case s3.config.NodeID: - t.Fatalf("bad") - default: - t.Fatalf("bad: %s", id) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + stats := s1.statsFetcher.Fetch(ctx, s1.LANMembers()) + if len(stats) != 2 { + t.Fatalf("bad: %#v", stats) } + for id, stat := range stats { + switch types.NodeID(id) { + case s1.config.NodeID, s2.config.NodeID: + // OK + case s3.config.NodeID: + t.Fatalf("bad") + default: + t.Fatalf("bad: %s", id) + } - if stat == nil || stat.LastTerm == 0 { - t.Fatalf("bad: %#v", stat) + if stat == nil || stat.LastTerm == 0 { + t.Fatalf("bad: %#v", stat) + } } - } + }) }() }