diff --git a/codec_test.go b/codec_test.go index 244afd0..e6b1f49 100644 --- a/codec_test.go +++ b/codec_test.go @@ -90,3 +90,30 @@ func generateMissingData(count int, codec Codec) [][]byte { return output } + +// testCodec is a codec that is used for testing purposes. +type testCodec struct{} + +func newTestCodec() Codec { + return &testCodec{} +} + +func (c *testCodec) Encode(chunk [][]byte) ([][]byte, error) { + return chunk, nil +} + +func (c *testCodec) Decode(chunk [][]byte) ([][]byte, error) { + return chunk, nil +} + +func (c *testCodec) MaxChunks() int { + return 0 +} + +func (c *testCodec) Name() string { + return "testCodec" +} + +func (c *testCodec) ValidateChunkSize(_ int) error { + return nil +} diff --git a/extendeddatasquare.go b/extendeddatasquare.go index c05c6e6..c1a74de 100644 --- a/extendeddatasquare.go +++ b/extendeddatasquare.go @@ -274,6 +274,52 @@ func (eds *ExtendedDataSquare) Width() uint { return eds.width } +// Flattened returns the extended data square as a flattened slice of bytes. +func (eds *ExtendedDataSquare) Flattened() [][]byte { + return eds.dataSquare.Flattened() +} + +// FlattenedODS returns the original data square as a flattened slice of bytes. +func (eds *ExtendedDataSquare) FlattenedODS() (flattened [][]byte) { + flattened = make([][]byte, eds.originalDataWidth*eds.originalDataWidth) + for i := uint(0); i < eds.originalDataWidth; i++ { + row := eds.Row(i) + for j := uint(0); j < eds.originalDataWidth; j++ { + flattened[(i*eds.originalDataWidth)+j] = row[j] + } + } + return flattened +} + +// Equals returns true if other is equal to eds. +func (eds *ExtendedDataSquare) Equals(other *ExtendedDataSquare) bool { + if eds.originalDataWidth != other.originalDataWidth { + return false + } + if eds.codec.Name() != other.codec.Name() { + return false + } + if eds.chunkSize != other.chunkSize { + return false + } + if eds.width != other.width { + return false + } + + for rowIndex := uint(0); rowIndex < eds.Width(); rowIndex++ { + edsRow := eds.Row(rowIndex) + otherRow := other.Row(rowIndex) + + for colIndex := 0; colIndex < len(edsRow); colIndex++ { + if !bytes.Equal(edsRow[colIndex], otherRow[colIndex]) { + return false + } + } + } + + return true +} + // validateEdsWidth returns an error if edsWidth is not a valid width for an // extended data square. func validateEdsWidth(edsWidth uint) error { diff --git a/extendeddatasquare_test.go b/extendeddatasquare_test.go index 2b47ab3..caccadb 100644 --- a/extendeddatasquare_test.go +++ b/extendeddatasquare_test.go @@ -74,7 +74,7 @@ func TestComputeExtendedDataSquare(t *testing.T) { func TestImportExtendedDataSquare(t *testing.T) { t.Run("is able to import an EDS", func(t *testing.T) { - eds := createExampleEds(t, ShardSize) + eds := createExampleEds(t, shareSize) got, err := ImportExtendedDataSquare(eds.Flattened(), NewLeoRSCodec(), NewDefaultTree) assert.NoError(t, err) assert.Equal(t, eds.Flattened(), got.Flattened()) @@ -293,6 +293,85 @@ func genRandDS(width int, chunkSize int) [][]byte { return ds } +// TestFlattened_EDS tests that eds.Flattened() returns all the shares in the +// EDS. This function has the `_EDS` suffix to avoid a name collision with the +// TestFlattened. +func TestFlattened_EDS(t *testing.T) { + example := createExampleEds(t, shareSize) + want := [][]byte{ + ones, twos, zeros, threes, + threes, fours, eights, fifteens, + twos, elevens, thirteens, fours, + zeros, thirteens, fives, eights, + } + + got := example.Flattened() + assert.Equal(t, want, got) +} + +func TestFlattenedODS(t *testing.T) { + example := createExampleEds(t, shareSize) + want := [][]byte{ + ones, twos, + threes, fours, + } + + got := example.FlattenedODS() + assert.Equal(t, want, got) +} + +func TestEquals(t *testing.T) { + t.Run("returns true for two equal EDS", func(t *testing.T) { + a := createExampleEds(t, shareSize) + b := createExampleEds(t, shareSize) + assert.True(t, a.Equals(b)) + }) + t.Run("returns false for two unequal EDS", func(t *testing.T) { + a := createExampleEds(t, shareSize) + + type testCase struct { + name string + other *ExtendedDataSquare + } + + unequalOriginalDataWidth := createExampleEds(t, shareSize) + unequalOriginalDataWidth.originalDataWidth = 1 + + unequalCodecs := createExampleEds(t, shareSize) + unequalCodecs.codec = newTestCodec() + + unequalChunkSize := createExampleEds(t, shareSize*2) + + unequalEds, err := ComputeExtendedDataSquare([][]byte{ones}, NewLeoRSCodec(), NewDefaultTree) + require.NoError(t, err) + + testCases := []testCase{ + { + name: "unequal original data width", + other: unequalOriginalDataWidth, + }, + { + name: "unequal codecs", + other: unequalCodecs, + }, + { + name: "unequal chunkSize", + other: unequalChunkSize, + }, + { + name: "unequalEds", + other: unequalEds, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + assert.False(t, a.Equals(tc.other)) + assert.False(t, reflect.DeepEqual(a, tc.other)) + }) + } + }) +} + func createExampleEds(t *testing.T, chunkSize int) (eds *ExtendedDataSquare) { ones := bytes.Repeat([]byte{1}, chunkSize) twos := bytes.Repeat([]byte{2}, chunkSize)