diff --git a/shamir/shamir.go b/shamir/shamir.go index 5180202dd..da32f2c47 100644 --- a/shamir/shamir.go +++ b/shamir/shamir.go @@ -194,9 +194,16 @@ func Combine(parts [][]byte) ([]byte, error) { x_samples := make([]uint8, len(parts)) y_samples := make([]uint8, len(parts)) - // Set the x value for each sample + // Set the x value for each sample and ensure no x_sample values are the same, + // otherwise div() can be unhappy + checkMap := map[byte]bool{} for i, part := range parts { - x_samples[i] = part[firstPartLen-1] + samp := part[firstPartLen-1] + if exists := checkMap[samp]; exists { + return nil, fmt.Errorf("duplicte part detected") + } + checkMap[samp] = true + x_samples[i] = samp } // Reconstruct each byte diff --git a/shamir/shamir_test.go b/shamir/shamir_test.go index e2ee3cf44..d2b2e68d1 100644 --- a/shamir/shamir_test.go +++ b/shamir/shamir_test.go @@ -71,6 +71,14 @@ func TestCombine_invalid(t *testing.T) { if _, err := Combine(parts); err == nil { t.Fatalf("should err") } + + parts = [][]byte{ + []byte("foo"), + []byte("foo"), + } + if _, err := Combine(parts); err == nil { + t.Fatalf("should err") + } } func TestCombine(t *testing.T) {