Skip to content

Commit

Permalink
pass nmt options instead of treeConstructorFn
Browse files Browse the repository at this point in the history
  • Loading branch information
walldiss committed Aug 4, 2023
1 parent 0613050 commit d411a34
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 26 deletions.
18 changes: 10 additions & 8 deletions core/eds.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ import (
"github.com/celestiaorg/celestia-app/pkg/appconsts"
"github.com/celestiaorg/celestia-app/pkg/shares"
"github.com/celestiaorg/celestia-app/pkg/square"
"github.com/celestiaorg/celestia-app/pkg/wrapper"
"github.com/celestiaorg/nmt"
"github.com/celestiaorg/rsmt2d"

"github.com/celestiaorg/celestia-node/share"
Expand All @@ -22,7 +24,7 @@ import (
// extendBlock extends the given block data, returning the resulting
// ExtendedDataSquare (EDS). If there are no transactions in the block,
// nil is returned in place of the eds.
func extendBlock(data types.Data, appVersion uint64, fn rsmt2d.TreeConstructorFn) (*rsmt2d.ExtendedDataSquare, error) {
func extendBlock(data types.Data, appVersion uint64, options ...nmt.Option) (*rsmt2d.ExtendedDataSquare, error) {
if app.IsEmptyBlock(data, appVersion) {
return nil, nil
}
Expand All @@ -32,21 +34,21 @@ func extendBlock(data types.Data, appVersion uint64, fn rsmt2d.TreeConstructorFn
if err != nil {
return nil, err
}

if data.SquareSize != uint64(square.Size(len(dataSquare))) {
panic("mismatch")
}
return extendShares(shares.ToBytes(dataSquare), fn)
return extendShares(shares.ToBytes(dataSquare), options...)
}

func extendShares(s [][]byte, fn rsmt2d.TreeConstructorFn) (*rsmt2d.ExtendedDataSquare, error) {
func extendShares(s [][]byte, options ...nmt.Option) (*rsmt2d.ExtendedDataSquare, error) {
// Check that the length of the square is a power of 2.
if !shares.IsPowerOfTwo(len(s)) {
return nil, fmt.Errorf("number of shares is not a power of 2: got %d", len(s))
}
// here we construct a tree
// Note: uses the nmt wrapper to construct the tree.
return rsmt2d.ComputeExtendedDataSquare(s, appconsts.DefaultCodec(), fn)
squareSize := square.Size(len(s))
return rsmt2d.ComputeExtendedDataSquare(s,
appconsts.DefaultCodec(),
wrapper.NewConstructor(uint64(squareSize),
options...))
}

// storeEDS will only store extended block if it is not empty and doesn't already exist.
Expand Down
5 changes: 2 additions & 3 deletions core/eds_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"github.com/celestiaorg/celestia-app/app"
"github.com/celestiaorg/celestia-app/pkg/appconsts"
"github.com/celestiaorg/celestia-app/pkg/da"
"github.com/celestiaorg/celestia-app/pkg/wrapper"

"github.com/celestiaorg/celestia-node/share"
)
Expand All @@ -24,7 +23,7 @@ func TestTrulyEmptySquare(t *testing.T) {
SquareSize: 1,
}

eds, err := extendBlock(data, appconsts.LatestVersion, wrapper.NewConstructor(data.SquareSize))
eds, err := extendBlock(data, appconsts.LatestVersion)
require.NoError(t, err)
assert.Nil(t, eds)
}
Expand All @@ -40,7 +39,7 @@ func TestEmptySquareWithZeroTxs(t *testing.T) {
Txs: []types.Tx{},
}

eds, err := extendBlock(data, appconsts.LatestVersion, wrapper.NewConstructor(data.SquareSize))
eds, err := extendBlock(data, appconsts.LatestVersion)
require.Nil(t, eds)
require.NoError(t, err)

Expand Down
9 changes: 2 additions & 7 deletions core/exchange.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (

"golang.org/x/sync/errgroup"

"github.com/celestiaorg/celestia-app/pkg/wrapper"
libhead "github.com/celestiaorg/go-header"
"github.com/celestiaorg/nmt"

Expand Down Expand Up @@ -109,9 +108,7 @@ func (ce *Exchange) Get(ctx context.Context, hash libhead.Hash) (*header.Extende

// extend block data
adder := ipld.NewProofsAdder(int(block.Data.SquareSize))
eds, err := extendBlock(block.Data, block.Header.Version.App,
wrapper.NewConstructor(block.Data.SquareSize,
nmt.NodeVisitor(adder.VisitFn())))
eds, err := extendBlock(block.Data, block.Header.Version.App, nmt.NodeVisitor(adder.VisitFn()))
if err != nil {
return nil, fmt.Errorf("extending block data for height %d: %w", &block.Height, err)
}
Expand Down Expand Up @@ -149,9 +146,7 @@ func (ce *Exchange) getExtendedHeaderByHeight(ctx context.Context, height *int64

// extend block data
adder := ipld.NewProofsAdder(int(b.Data.SquareSize))
eds, err := extendBlock(b.Data, b.Header.Version.App,
wrapper.NewConstructor(b.Data.SquareSize,
nmt.NodeVisitor(adder.VisitFn())))
eds, err := extendBlock(b.Data, b.Header.Version.App, nmt.NodeVisitor(adder.VisitFn()))
if err != nil {
return nil, fmt.Errorf("extending block data for height %d: %w", b.Header.Height, err)
}
Expand Down
5 changes: 1 addition & 4 deletions core/header_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@ import (
"github.com/stretchr/testify/require"
"github.com/tendermint/tendermint/libs/rand"

"github.com/celestiaorg/celestia-app/pkg/wrapper"

"github.com/celestiaorg/celestia-node/header"
"github.com/celestiaorg/celestia-node/header/headertest"
)
Expand All @@ -32,8 +30,7 @@ func TestMakeExtendedHeaderForEmptyBlock(t *testing.T) {
comm, val, err := fetcher.GetBlockInfo(ctx, &height)
require.NoError(t, err)

eds, err := extendBlock(b.Data, b.Header.Version.App,
wrapper.NewConstructor(b.Data.SquareSize))
eds, err := extendBlock(b.Data, b.Header.Version.App)
require.NoError(t, err)

headerExt, err := header.MakeExtendedHeader(ctx, &b.Header, comm, val, eds)
Expand Down
5 changes: 1 addition & 4 deletions core/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import (
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"

"github.com/celestiaorg/celestia-app/pkg/wrapper"
libhead "github.com/celestiaorg/go-header"
"github.com/celestiaorg/nmt"

Expand Down Expand Up @@ -154,9 +153,7 @@ func (cl *Listener) handleNewSignedBlock(ctx context.Context, b types.EventDataS
)
// extend block data
adder := ipld.NewProofsAdder(int(b.Data.SquareSize))
eds, err := extendBlock(b.Data, b.Header.Version.App,
wrapper.NewConstructor(b.Data.SquareSize,
nmt.NodeVisitor(adder.VisitFn())))
eds, err := extendBlock(b.Data, b.Header.Version.App, nmt.NodeVisitor(adder.VisitFn()))
if err != nil {
return fmt.Errorf("extending block data: %w", err)
}
Expand Down

0 comments on commit d411a34

Please sign in to comment.