From ec57e983f7e797568d6c12d460497892be439313 Mon Sep 17 00:00:00 2001 From: Jeff Mitchell Date: Wed, 26 Aug 2015 10:03:44 -0700 Subject: [PATCH] Don't allow duplicate x parts in Shamir. Add unit test for verification. --- shamir/shamir.go | 11 +++++++++-- shamir/shamir_test.go | 8 ++++++++ 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/shamir/shamir.go b/shamir/shamir.go index 5180202dd..da32f2c47 100644 --- a/shamir/shamir.go +++ b/shamir/shamir.go @@ -194,9 +194,16 @@ func Combine(parts [][]byte) ([]byte, error) { x_samples := make([]uint8, len(parts)) y_samples := make([]uint8, len(parts)) - // Set the x value for each sample + // Set the x value for each sample and ensure no x_sample values are the same, + // otherwise div() can be unhappy + checkMap := map[byte]bool{} for i, part := range parts { - x_samples[i] = part[firstPartLen-1] + samp := part[firstPartLen-1] + if exists := checkMap[samp]; exists { + return nil, fmt.Errorf("duplicte part detected") + } + checkMap[samp] = true + x_samples[i] = samp } // Reconstruct each byte diff --git a/shamir/shamir_test.go b/shamir/shamir_test.go index e2ee3cf44..d2b2e68d1 100644 --- a/shamir/shamir_test.go +++ b/shamir/shamir_test.go @@ -71,6 +71,14 @@ func TestCombine_invalid(t *testing.T) { if _, err := Combine(parts); err == nil { t.Fatalf("should err") } + + parts = [][]byte{ + []byte("foo"), + []byte("foo"), + } + if _, err := Combine(parts); err == nil { + t.Fatalf("should err") + } } func TestCombine(t *testing.T) {