diff --git a/core/eds.go b/core/eds.go index 8a202f5750..c0a000f95c 100644 --- a/core/eds.go +++ b/core/eds.go @@ -13,6 +13,8 @@ import ( "github.com/celestiaorg/celestia-app/pkg/appconsts" "github.com/celestiaorg/celestia-app/pkg/shares" "github.com/celestiaorg/celestia-app/pkg/square" + "github.com/celestiaorg/celestia-app/pkg/wrapper" + "github.com/celestiaorg/nmt" "github.com/celestiaorg/rsmt2d" "github.com/celestiaorg/celestia-node/share" @@ -22,7 +24,7 @@ import ( // extendBlock extends the given block data, returning the resulting // ExtendedDataSquare (EDS). If there are no transactions in the block, // nil is returned in place of the eds. -func extendBlock(data types.Data, appVersion uint64, fn rsmt2d.TreeConstructorFn) (*rsmt2d.ExtendedDataSquare, error) { +func extendBlock(data types.Data, appVersion uint64, options ...nmt.Option) (*rsmt2d.ExtendedDataSquare, error) { if app.IsEmptyBlock(data, appVersion) { return nil, nil } @@ -32,21 +34,21 @@ func extendBlock(data types.Data, appVersion uint64, fn rsmt2d.TreeConstructorFn if err != nil { return nil, err } - - if data.SquareSize != uint64(square.Size(len(dataSquare))) { - panic("mismatch") - } - return extendShares(shares.ToBytes(dataSquare), fn) + return extendShares(shares.ToBytes(dataSquare), options...) } -func extendShares(s [][]byte, fn rsmt2d.TreeConstructorFn) (*rsmt2d.ExtendedDataSquare, error) { +func extendShares(s [][]byte, options ...nmt.Option) (*rsmt2d.ExtendedDataSquare, error) { // Check that the length of the square is a power of 2. if !shares.IsPowerOfTwo(len(s)) { return nil, fmt.Errorf("number of shares is not a power of 2: got %d", len(s)) } // here we construct a tree // Note: uses the nmt wrapper to construct the tree. - return rsmt2d.ComputeExtendedDataSquare(s, appconsts.DefaultCodec(), fn) + squareSize := square.Size(len(s)) + return rsmt2d.ComputeExtendedDataSquare(s, + appconsts.DefaultCodec(), + wrapper.NewConstructor(uint64(squareSize), + options...)) } // storeEDS will only store extended block if it is not empty and doesn't already exist. diff --git a/core/eds_test.go b/core/eds_test.go index f92a798587..6a2026ee58 100644 --- a/core/eds_test.go +++ b/core/eds_test.go @@ -10,7 +10,6 @@ import ( "github.com/celestiaorg/celestia-app/app" "github.com/celestiaorg/celestia-app/pkg/appconsts" "github.com/celestiaorg/celestia-app/pkg/da" - "github.com/celestiaorg/celestia-app/pkg/wrapper" "github.com/celestiaorg/celestia-node/share" ) @@ -24,7 +23,7 @@ func TestTrulyEmptySquare(t *testing.T) { SquareSize: 1, } - eds, err := extendBlock(data, appconsts.LatestVersion, wrapper.NewConstructor(data.SquareSize)) + eds, err := extendBlock(data, appconsts.LatestVersion) require.NoError(t, err) assert.Nil(t, eds) } @@ -40,7 +39,7 @@ func TestEmptySquareWithZeroTxs(t *testing.T) { Txs: []types.Tx{}, } - eds, err := extendBlock(data, appconsts.LatestVersion, wrapper.NewConstructor(data.SquareSize)) + eds, err := extendBlock(data, appconsts.LatestVersion) require.Nil(t, eds) require.NoError(t, err) diff --git a/core/exchange.go b/core/exchange.go index b6ad9a4672..d93735e403 100644 --- a/core/exchange.go +++ b/core/exchange.go @@ -8,7 +8,6 @@ import ( "golang.org/x/sync/errgroup" - "github.com/celestiaorg/celestia-app/pkg/wrapper" libhead "github.com/celestiaorg/go-header" "github.com/celestiaorg/nmt" @@ -109,9 +108,7 @@ func (ce *Exchange) Get(ctx context.Context, hash libhead.Hash) (*header.Extende // extend block data adder := ipld.NewProofsAdder(int(block.Data.SquareSize)) - eds, err := extendBlock(block.Data, block.Header.Version.App, - wrapper.NewConstructor(block.Data.SquareSize, - nmt.NodeVisitor(adder.VisitFn()))) + eds, err := extendBlock(block.Data, block.Header.Version.App, nmt.NodeVisitor(adder.VisitFn())) if err != nil { return nil, fmt.Errorf("extending block data for height %d: %w", &block.Height, err) } @@ -149,9 +146,7 @@ func (ce *Exchange) getExtendedHeaderByHeight(ctx context.Context, height *int64 // extend block data adder := ipld.NewProofsAdder(int(b.Data.SquareSize)) - eds, err := extendBlock(b.Data, b.Header.Version.App, - wrapper.NewConstructor(b.Data.SquareSize, - nmt.NodeVisitor(adder.VisitFn()))) + eds, err := extendBlock(b.Data, b.Header.Version.App, nmt.NodeVisitor(adder.VisitFn())) if err != nil { return nil, fmt.Errorf("extending block data for height %d: %w", b.Header.Height, err) } diff --git a/core/header_test.go b/core/header_test.go index 92e6151de5..6315fbc143 100644 --- a/core/header_test.go +++ b/core/header_test.go @@ -8,8 +8,6 @@ import ( "github.com/stretchr/testify/require" "github.com/tendermint/tendermint/libs/rand" - "github.com/celestiaorg/celestia-app/pkg/wrapper" - "github.com/celestiaorg/celestia-node/header" "github.com/celestiaorg/celestia-node/header/headertest" ) @@ -32,8 +30,7 @@ func TestMakeExtendedHeaderForEmptyBlock(t *testing.T) { comm, val, err := fetcher.GetBlockInfo(ctx, &height) require.NoError(t, err) - eds, err := extendBlock(b.Data, b.Header.Version.App, - wrapper.NewConstructor(b.Data.SquareSize)) + eds, err := extendBlock(b.Data, b.Header.Version.App) require.NoError(t, err) headerExt, err := header.MakeExtendedHeader(ctx, &b.Header, comm, val, eds) diff --git a/core/listener.go b/core/listener.go index 1fac0bbedc..638484a241 100644 --- a/core/listener.go +++ b/core/listener.go @@ -11,7 +11,6 @@ import ( "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/attribute" - "github.com/celestiaorg/celestia-app/pkg/wrapper" libhead "github.com/celestiaorg/go-header" "github.com/celestiaorg/nmt" @@ -154,9 +153,7 @@ func (cl *Listener) handleNewSignedBlock(ctx context.Context, b types.EventDataS ) // extend block data adder := ipld.NewProofsAdder(int(b.Data.SquareSize)) - eds, err := extendBlock(b.Data, b.Header.Version.App, - wrapper.NewConstructor(b.Data.SquareSize, - nmt.NodeVisitor(adder.VisitFn()))) + eds, err := extendBlock(b.Data, b.Header.Version.App, nmt.NodeVisitor(adder.VisitFn())) if err != nil { return fmt.Errorf("extending block data: %w", err) }