Skip to content

Commit

Permalink
feat(share/eds) Add Accessor test infra (celestiaorg#3425)
Browse files Browse the repository at this point in the history
Adds generalised testing framework for file interface implementaions

Also adds non-inclusion proofs handling in RowNamespaceDataFromShares
and closes celestiaorg#3428
  • Loading branch information
walldiss committed Jul 6, 2024
1 parent 3dd1100 commit 1c42d99
Show file tree
Hide file tree
Showing 5 changed files with 352 additions and 26 deletions.
2 changes: 1 addition & 1 deletion share/new_eds/accessor.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ type Accessor interface {
AxisHalf(ctx context.Context, axisType rsmt2d.Axis, axisIdx int) (AxisHalf, error)
// RowNamespaceData returns data for the given namespace and row index.
RowNamespaceData(ctx context.Context, namespace share.Namespace, rowIdx int) (shwap.RowNamespaceData, error)
// Shares returns data shares extracted from the Accessor.
// Shares returns data (ODS) shares extracted from the Accessor.
Shares(ctx context.Context) ([]share.Share, error)
}

Expand Down
4 changes: 2 additions & 2 deletions share/new_eds/rsmt2d.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,10 @@ func (eds Rsmt2D) RowNamespaceData(
return shwap.RowNamespaceDataFromShares(shares, namespace, rowIdx)
}

// Shares returns data shares extracted from the EDS. It returns new copy of the shares each
// Shares returns data (ODS) shares extracted from the EDS. It returns new copy of the shares each
// time.
func (eds Rsmt2D) Shares(_ context.Context) ([]share.Share, error) {
return eds.ExtendedDataSquare.Flattened(), nil
return eds.ExtendedDataSquare.FlattenedODS(), nil
}

func getAxis(eds *rsmt2d.ExtendedDataSquare, axisType rsmt2d.Axis, axisIdx int) []share.Share {
Expand Down
23 changes: 10 additions & 13 deletions share/new_eds/rsmt2d_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package eds
import (
"context"
"testing"
"time"

"github.com/stretchr/testify/require"

Expand All @@ -13,22 +14,18 @@ import (
"github.com/celestiaorg/celestia-node/share/shwap"
)

func TestRsmt2dSample(t *testing.T) {
eds, root := randRsmt2dAccsessor(t, 8)

width := int(eds.Width())
for rowIdx := 0; rowIdx < width; rowIdx++ {
for colIdx := 0; colIdx < width; colIdx++ {
shr, err := eds.Sample(context.TODO(), rowIdx, colIdx)
require.NoError(t, err)

err = shr.Validate(root, rowIdx, colIdx)
require.NoError(t, err)
}
func TestMemFile(t *testing.T) {
odsSize := 8
newAccessor := func(eds *rsmt2d.ExtendedDataSquare) Accessor {
return &Rsmt2D{ExtendedDataSquare: eds}
}
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
t.Cleanup(cancel)

TestSuiteAccessor(ctx, t, newAccessor, odsSize)
}

func TestRsmt2dHalfRowFrom(t *testing.T) {
func TestRsmt2dHalfRow(t *testing.T) {
const odsSize = 8
eds, _ := randRsmt2dAccsessor(t, odsSize)

Expand Down
300 changes: 300 additions & 0 deletions share/new_eds/testing.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,300 @@
package eds

import (
"context"
"fmt"
"strconv"
"sync"
"testing"

"github.com/stretchr/testify/require"

"github.com/celestiaorg/nmt"
"github.com/celestiaorg/rsmt2d"

"github.com/celestiaorg/celestia-node/share"
"github.com/celestiaorg/celestia-node/share/eds/edstest"
"github.com/celestiaorg/celestia-node/share/sharetest"
"github.com/celestiaorg/celestia-node/share/shwap"
)

type createAccessor func(eds *rsmt2d.ExtendedDataSquare) Accessor

// TestSuiteAccessor runs a suite of tests for the given Accessor implementation.
func TestSuiteAccessor(
ctx context.Context,
t *testing.T,
createAccessor createAccessor,
odsSize int,
) {
t.Run("Sample", func(t *testing.T) {
testAccessorSample(ctx, t, createAccessor, odsSize)
})

t.Run("AxisHalf", func(t *testing.T) {
testAccessorAxisHalf(ctx, t, createAccessor, odsSize)
})

t.Run("RowNamespaceData", func(t *testing.T) {
testAccessorRowNamespaceData(ctx, t, createAccessor, odsSize)
})

t.Run("Shares", func(t *testing.T) {
testAccessorShares(ctx, t, createAccessor, odsSize)
})
}

func testAccessorSample(
ctx context.Context,
t *testing.T,
createAccessor createAccessor,
odsSize int,
) {
eds := edstest.RandEDS(t, odsSize)
fl := createAccessor(eds)

dah, err := share.NewRoot(eds)
require.NoError(t, err)

width := int(eds.Width())
t.Run("single thread", func(t *testing.T) {
for rowIdx := 0; rowIdx < width; rowIdx++ {
for colIdx := 0; colIdx < width; colIdx++ {
testSample(ctx, t, fl, dah, colIdx, rowIdx)
}
}
})

t.Run("parallel", func(t *testing.T) {
wg := sync.WaitGroup{}
for rowIdx := 0; rowIdx < width; rowIdx++ {
for colIdx := 0; colIdx < width; colIdx++ {
wg.Add(1)
go func(rowIdx, colIdx int) {
defer wg.Done()
testSample(ctx, t, fl, dah, rowIdx, colIdx)
}(rowIdx, colIdx)
}
}
wg.Wait()
})
}

func testSample(
ctx context.Context,
t *testing.T,
fl Accessor,
dah *share.Root,
rowIdx, colIdx int,
) {
shr, err := fl.Sample(ctx, rowIdx, colIdx)
require.NoError(t, err)

err = shr.Validate(dah, rowIdx, colIdx)
require.NoError(t, err)
}

func testAccessorRowNamespaceData(
ctx context.Context,
t *testing.T,
createAccessor createAccessor,
odsSize int,
) {
t.Run("included", func(t *testing.T) {
// generate EDS with random data and some Shares with the same namespace
sharesAmount := odsSize * odsSize
namespace := sharetest.RandV0Namespace()
// test with different amount of shares
for amount := 1; amount < sharesAmount; amount++ {
// select random amount of shares, but not less than 1
eds, dah := edstest.RandEDSWithNamespace(t, namespace, amount, odsSize)
f := createAccessor(eds)

var actualSharesAmount int
// loop over all rows and check that the amount of shares in the namespace is equal to the expected
// amount
for i, root := range dah.RowRoots {
rowData, err := f.RowNamespaceData(ctx, namespace, i)

// namespace is not included in the row, so there should be no shares
if namespace.IsOutsideRange(root, root) {
require.ErrorIs(t, err, shwap.ErrNamespaceOutsideRange)
require.Len(t, rowData.Shares, 0)
continue
}

actualSharesAmount += len(rowData.Shares)
require.NoError(t, err)
require.True(t, len(rowData.Shares) > 0)
err = rowData.Validate(dah, namespace, i)
require.NoError(t, err)
}

// check that the amount of shares in the namespace is equal to the expected amount
require.Equal(t, amount, actualSharesAmount)
}
})

t.Run("not included", func(t *testing.T) {
// generate EDS with random data and some Shares with the same namespace
eds := edstest.RandEDS(t, odsSize)
dah, err := share.NewRoot(eds)
require.NoError(t, err)

// loop over first half of the rows, because the second half is parity and does not contain
// namespaced shares
for i, root := range dah.RowRoots[:odsSize] {
// select namespace that within the range of root namespaces, but is not included
maxNs := nmt.MaxNamespace(root, share.NamespaceSize)
absentNs, err := share.Namespace(maxNs).AddInt(-1)
require.NoError(t, err)

f := createAccessor(eds)
rowData, err := f.RowNamespaceData(ctx, absentNs, i)
require.NoError(t, err)

// namespace is not included in the row, so there should be no shares
require.Len(t, rowData.Shares, 0)
require.True(t, rowData.Proof.IsOfAbsence())

err = rowData.Validate(dah, absentNs, i)
require.NoError(t, err)
}
})
}

func testAccessorAxisHalf(
ctx context.Context,
t *testing.T,
createAccessor createAccessor,
odsSize int,
) {
eds := edstest.RandEDS(t, odsSize)
fl := createAccessor(eds)

t.Run("single thread", func(t *testing.T) {
for _, axisType := range []rsmt2d.Axis{rsmt2d.Col, rsmt2d.Row} {
for axisIdx := 0; axisIdx < int(eds.Width()); axisIdx++ {
half, err := fl.AxisHalf(ctx, axisType, axisIdx)
require.NoError(t, err)
require.Len(t, half.Shares, odsSize)

var expected []share.Share
if half.IsParity {
expected = getAxis(eds, axisType, axisIdx)[odsSize:]
} else {
expected = getAxis(eds, axisType, axisIdx)[:odsSize]
}

require.Equal(t, expected, half.Shares)
}
}
})

t.Run("parallel", func(t *testing.T) {
wg := sync.WaitGroup{}
for _, axisType := range []rsmt2d.Axis{rsmt2d.Col, rsmt2d.Row} {
for i := 0; i < int(eds.Width()); i++ {
wg.Add(1)
go func(axisType rsmt2d.Axis, idx int) {
defer wg.Done()
half, err := fl.AxisHalf(ctx, axisType, idx)
require.NoError(t, err)
require.Len(t, half.Shares, odsSize)

var expected []share.Share
if half.IsParity {
expected = getAxis(eds, axisType, idx)[odsSize:]
} else {
expected = getAxis(eds, axisType, idx)[:odsSize]
}

require.Equal(t, expected, half.Shares)
}(axisType, i)
}
}
wg.Wait()
})
}

func testAccessorShares(
ctx context.Context,
t *testing.T,
createAccessor createAccessor,
odsSize int,
) {
eds := edstest.RandEDS(t, odsSize)
fl := createAccessor(eds)

shares, err := fl.Shares(ctx)
require.NoError(t, err)
expected := eds.FlattenedODS()
require.Equal(t, expected, shares)
}

func BenchGetHalfAxisFromAccessor(
ctx context.Context,
b *testing.B,
newAccessor func(size int) Accessor,
minOdsSize, maxOdsSize int,
) {
for size := minOdsSize; size <= maxOdsSize; size *= 2 {
f := newAccessor(size)

// loop over all possible axis types and quadrants
for _, axisType := range []rsmt2d.Axis{rsmt2d.Row, rsmt2d.Col} {
for _, squareHalf := range []int{0, 1} {
name := fmt.Sprintf("Size:%v/ProofType:%s/squareHalf:%s", size, axisType, strconv.Itoa(squareHalf))
b.Run(name, func(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := f.AxisHalf(ctx, axisType, f.Size(ctx)/2*(squareHalf))
require.NoError(b, err)
}
})
}
}
}
}

func BenchGetSampleFromAccessor(
ctx context.Context,
b *testing.B,
newAccessor func(size int) Accessor,
minOdsSize, maxOdsSize int,
) {
for size := minOdsSize; size <= maxOdsSize; size *= 2 {
f := newAccessor(size)

// loop over all possible axis types and quadrants
for _, q := range quadrants {
name := fmt.Sprintf("Size:%v/quadrant:%s", size, q)
b.Run(name, func(b *testing.B) {
rowIdx, colIdx := q.coordinates(f.Size(ctx))
// warm up cache
_, err := f.Sample(ctx, rowIdx, colIdx)
require.NoError(b, err, q.String())

b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := f.Sample(ctx, rowIdx, colIdx)
require.NoError(b, err)
}
})
}
}
}

type quadrant int

var quadrants = []quadrant{1, 2, 3, 4}

func (q quadrant) String() string {
return strconv.Itoa(int(q))
}

func (q quadrant) coordinates(edsSize int) (rowIdx, colIdx int) {
colIdx = edsSize/2*(int(q-1)%2) + 1
rowIdx = edsSize/2*(int(q-1)/2) + 1
return rowIdx, colIdx
}
Loading

0 comments on commit 1c42d99

Please sign in to comment.