prototest: fix early return condition in AssertElementsMatch (#17416)

This commit is contained in:
R.B. Boyer 2023-05-22 13:49:50 -05:00 committed by GitHub
parent 0477d15a5a
commit e1110ea82d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 127 additions and 12 deletions

View File

@ -550,7 +550,7 @@ func TestStreamResources_Server_StreamTracker(t *testing.T) {
it := incrementalTime{ it := incrementalTime{
base: time.Date(2000, time.January, 1, 0, 0, 0, 0, time.UTC), base: time.Date(2000, time.January, 1, 0, 0, 0, 0, time.UTC),
} }
waitUntil := it.FutureNow(6) waitUntil := it.FutureNow(7)
srv, store := newTestServer(t, nil) srv, store := newTestServer(t, nil)
srv.Tracker.setClock(it.Now) srv.Tracker.setClock(it.Now)
@ -1442,7 +1442,9 @@ func makeClient(t *testing.T, srv *testServer, peerID string) *MockClient {
// Note that server address may not come as an initial message // Note that server address may not come as an initial message
for _, resourceURL := range []string{ for _, resourceURL := range []string{
pbpeerstream.TypeURLExportedService, pbpeerstream.TypeURLExportedService,
pbpeerstream.TypeURLExportedServiceList,
pbpeerstream.TypeURLPeeringTrustBundle, pbpeerstream.TypeURLPeeringTrustBundle,
// only dialers request, which is why this is absent below
pbpeerstream.TypeURLPeeringServerAddresses, pbpeerstream.TypeURLPeeringServerAddresses,
} { } {
init := &pbpeerstream.ReplicationMessage{ init := &pbpeerstream.ReplicationMessage{
@ -1471,7 +1473,7 @@ func makeClient(t *testing.T, srv *testServer, peerID string) *MockClient {
{ {
Payload: &pbpeerstream.ReplicationMessage_Request_{ Payload: &pbpeerstream.ReplicationMessage_Request_{
Request: &pbpeerstream.ReplicationMessage_Request{ Request: &pbpeerstream.ReplicationMessage_Request{
ResourceURL: pbpeerstream.TypeURLPeeringTrustBundle, ResourceURL: pbpeerstream.TypeURLExportedServiceList,
// The PeerID field is only set for the messages coming FROM // The PeerID field is only set for the messages coming FROM
// the establishing side and are going to be empty from the // the establishing side and are going to be empty from the
// other side. // other side.
@ -1482,7 +1484,7 @@ func makeClient(t *testing.T, srv *testServer, peerID string) *MockClient {
{ {
Payload: &pbpeerstream.ReplicationMessage_Request_{ Payload: &pbpeerstream.ReplicationMessage_Request_{
Request: &pbpeerstream.ReplicationMessage_Request{ Request: &pbpeerstream.ReplicationMessage_Request{
ResourceURL: pbpeerstream.TypeURLPeeringServerAddresses, ResourceURL: pbpeerstream.TypeURLPeeringTrustBundle,
// The PeerID field is only set for the messages coming FROM // The PeerID field is only set for the messages coming FROM
// the establishing side and are going to be empty from the // the establishing side and are going to be empty from the
// other side. // other side.

View File

@ -93,10 +93,11 @@ func TestList_Many(t *testing.T) {
// Prevent test flakes if the generated names collide. // Prevent test flakes if the generated names collide.
artist.Id.Name = fmt.Sprintf("%s-%d", artist.Id.Name, i) artist.Id.Name = fmt.Sprintf("%s-%d", artist.Id.Name, i)
_, err = server.Backend.WriteCAS(tc.ctx, artist)
rsp, err := client.Write(tc.ctx, &pbresource.WriteRequest{Resource: artist})
require.NoError(t, err) require.NoError(t, err)
resources[i] = artist resources[i] = rsp.Resource
} }
rsp, err := client.List(tc.ctx, &pbresource.ListRequest{ rsp, err := client.List(tc.ctx, &pbresource.ListRequest{

View File

@ -32,14 +32,25 @@ func AssertDeepEqual(t TestingT, x, y interface{}, opts ...cmp.Option) {
func AssertElementsMatch[V any]( func AssertElementsMatch[V any](
t TestingT, listX, listY []V, opts ...cmp.Option, t TestingT, listX, listY []V, opts ...cmp.Option,
) { ) {
t.Helper() diff := diffElements(listX, listY, opts...)
if diff != "" {
t.Fatalf("assertion failed: slices do not have matching elements\n--- expected\n+++ actual\n%v", diff)
}
}
func diffElements[V any](
listX, listY []V, opts ...cmp.Option,
) string {
if len(listX) == 0 && len(listY) == 0 { if len(listX) == 0 && len(listY) == 0 {
return return ""
} }
opts = append(opts, protocmp.Transform()) opts = append(opts, protocmp.Transform())
if len(listX) != len(listY) {
return cmp.Diff(listX, listY, opts...)
}
// dump into a map keyed by sliceID // dump into a map keyed by sliceID
mapX := make(map[int]V) mapX := make(map[int]V)
for i, val := range listX { for i, val := range listX {
@ -63,8 +74,8 @@ func AssertElementsMatch[V any](
} }
} }
if len(outX) == len(outY) && len(listX) == len(listY) { if len(outX) == len(listX) && len(outY) == len(listY) {
return // matches return "" // matches
} }
// dump remainder into the slice so we can generate a useful error // dump remainder into the slice so we can generate a useful error
@ -75,9 +86,7 @@ func AssertElementsMatch[V any](
outY = append(outY, itemY) outY = append(outY, itemY)
} }
if diff := cmp.Diff(outX, outY, opts...); diff != "" { return cmp.Diff(outX, outY, opts...)
t.Fatalf("assertion failed: slices do not have matching elements\n--- expected\n+++ actual\n%v", diff)
}
} }
func AssertContainsElement[V any](t TestingT, list []V, element V, opts ...cmp.Option) { func AssertContainsElement[V any](t TestingT, list []V, element V, opts ...cmp.Option) {

View File

@ -0,0 +1,103 @@
package prototest
import (
"strconv"
"testing"
"github.com/stretchr/testify/require"
)
type wrap struct {
V int
O string
}
func (w *wrap) String() string {
return strconv.Itoa(w.V)
}
func (w *wrap) GoString() string {
return w.String()
}
func TestDiffElements_noProtobufs(t *testing.T) {
// NOTE: this test only tests non-protobuf slices initially
type testcase struct {
a, b []*wrap
notSame bool
}
run := func(t *testing.T, tc testcase) {
diff := diffElements(tc.a, tc.b)
if tc.notSame {
require.False(t, diff == "", "expected not to be the same")
} else {
require.True(t, diff == "", "expected to be the same")
}
}
w := func(v int) *wrap {
return &wrap{V: v}
}
cases := map[string]testcase{
"nil": {},
"empty": {a: []*wrap{}, b: []*wrap{}},
"nil and empty": {a: []*wrap{}, b: nil},
"ordered match": {
a: []*wrap{w(1), w(22), w(303), w(43004), w(-5)},
b: []*wrap{w(1), w(22), w(303), w(43004), w(-5)},
},
"permuted match": {
a: []*wrap{w(1), w(22), w(303), w(43004), w(-5)},
b: []*wrap{w(-5), w(43004), w(303), w(22), w(1)},
},
"duplicates": {
a: []*wrap{w(1), w(2), w(2), w(3)},
b: []*wrap{w(2), w(1), w(3), w(2)},
},
// no match
"1 vs nil": {
a: []*wrap{w(1)},
b: nil,
notSame: true,
},
"1 vs 2": {
a: []*wrap{w(1)},
b: []*wrap{w(2)},
notSame: true,
},
"1,2 vs 2,3": {
a: []*wrap{w(1), w(2)},
b: []*wrap{w(2), w(3)},
notSame: true,
},
"1,2 vs 1,2,3": {
a: []*wrap{w(1), w(2)},
b: []*wrap{w(1), w(2), w(3)},
notSame: true,
},
"duplicates omitted": {
a: []*wrap{w(1), w(2), w(2), w(3)},
b: []*wrap{w(1), w(3), w(2)},
notSame: true,
},
}
allCases := make(map[string]testcase)
for name, tc := range cases {
allCases[name] = tc
allCases[name+" (flipped)"] = testcase{
a: tc.b,
b: tc.a,
notSame: tc.notSame,
}
}
for name, tc := range allCases {
t.Run(name, func(t *testing.T) {
run(t, tc)
})
}
}