diff --git a/core/exchange.go b/core/exchange.go index 61dd701bb4..6a6cba94f0 100644 --- a/core/exchange.go +++ b/core/exchange.go @@ -136,7 +136,7 @@ func (ce *Exchange) Get(ctx context.Context, hash libhead.Hash) (*header.Extende } // extend block data - adder := ipld.NewProofsAdder(int(block.Data.SquareSize)) + adder := ipld.NewProofsAdder(int(block.Data.SquareSize), false) defer adder.Purge() eds, err := extendBlock(block.Data, block.Header.Version.App, nmt.NodeVisitor(adder.VisitFn())) @@ -181,7 +181,7 @@ func (ce *Exchange) getExtendedHeaderByHeight(ctx context.Context, height *int64 log.Debugw("fetched signed block from core", "height", b.Header.Height) // extend block data - adder := ipld.NewProofsAdder(int(b.Data.SquareSize)) + adder := ipld.NewProofsAdder(int(b.Data.SquareSize), false) defer adder.Purge() eds, err := extendBlock(b.Data, b.Header.Version.App, nmt.NodeVisitor(adder.VisitFn())) diff --git a/core/listener.go b/core/listener.go index 5260067154..a9b87b6dd2 100644 --- a/core/listener.go +++ b/core/listener.go @@ -215,7 +215,7 @@ func (cl *Listener) handleNewSignedBlock(ctx context.Context, b types.EventDataS attribute.Int64("height", b.Header.Height), ) // extend block data - adder := ipld.NewProofsAdder(int(b.Data.SquareSize)) + adder := ipld.NewProofsAdder(int(b.Data.SquareSize), false) defer adder.Purge() eds, err := extendBlock(b.Data, b.Header.Version.App, nmt.NodeVisitor(adder.VisitFn())) diff --git a/header/headertest/fraud/testing.go b/header/headertest/fraud/testing.go index 66965cf178..1cc43ea019 100644 --- a/header/headertest/fraud/testing.go +++ b/header/headertest/fraud/testing.go @@ -58,7 +58,7 @@ func (f *FraudMaker) MakeExtendedHeader(odsSize int, edsStore *eds.Store) header hdr := *h if h.Height == f.height { - adder := ipld.NewProofsAdder(odsSize) + adder := ipld.NewProofsAdder(odsSize, false) square := edstest.RandByzantineEDS(f.t, odsSize, nmt.NodeVisitor(adder.VisitFn())) dah, err := da.NewDataAvailabilityHeader(square) require.NoError(f.t, err) diff --git a/nodebuilder/store_test.go b/nodebuilder/store_test.go index 17d15f8a6f..9a9ed443d4 100644 --- a/nodebuilder/store_test.go +++ b/nodebuilder/store_test.go @@ -84,7 +84,7 @@ func BenchmarkStore(b *testing.B) { b.StopTimer() b.ResetTimer() for i := 0; i < b.N; i++ { - adder := ipld.NewProofsAdder(size * 2) + adder := ipld.NewProofsAdder(size*2, false) shares := sharetest.RandShares(b, size*size) eds, err := rsmt2d.ComputeExtendedDataSquare( shares, diff --git a/share/availability/full/availability.go b/share/availability/full/availability.go index 4ea211cb1e..ff26404d45 100644 --- a/share/availability/full/availability.go +++ b/share/availability/full/availability.go @@ -77,7 +77,7 @@ func (fa *ShareAvailability) SharesAvailable(ctx context.Context, header *header return nil } - adder := ipld.NewProofsAdder(len(dah.RowRoots)) + adder := ipld.NewProofsAdder(len(dah.RowRoots), false) ctx = ipld.CtxWithProofsAdder(ctx, adder) defer adder.Purge() diff --git a/share/eds/byzantine/share_proof.go b/share/eds/byzantine/share_proof.go index d064656830..f802b486c4 100644 --- a/share/eds/byzantine/share_proof.go +++ b/share/eds/byzantine/share_proof.go @@ -3,10 +3,8 @@ package byzantine import ( "context" "errors" - "math" "github.com/ipfs/boxo/blockservice" - "github.com/ipfs/go-cid" logging "github.com/ipfs/go-log/v2" "github.com/celestiaorg/nmt" @@ -87,8 +85,7 @@ func GetShareWithProof( width := len(dah.RowRoots) // try row proofs root := dah.RowRoots[axisIdx] - rootCid := ipld.MustCidFromNamespacedSha256(root) - proof, err := getProofsAt(ctx, bGetter, rootCid, shrIdx, width) + proof, err := ipld.GetProof(ctx, bGetter, root, shrIdx, width) if err == nil { shareWithProof := &ShareWithProof{ Share: share, @@ -102,8 +99,7 @@ func GetShareWithProof( // try column proofs root = dah.ColumnRoots[shrIdx] - rootCid = ipld.MustCidFromNamespacedSha256(root) - proof, err = getProofsAt(ctx, bGetter, rootCid, axisIdx, width) + proof, err = ipld.GetProof(ctx, bGetter, root, axisIdx, width) if err != nil { return nil, err } @@ -118,28 +114,6 @@ func GetShareWithProof( return nil, errors.New("failed to collect proof") } -func getProofsAt( - ctx context.Context, - bGetter blockservice.BlockGetter, - root cid.Cid, - index, - total int, -) (nmt.Proof, error) { - proofPath := make([]cid.Cid, 0, int(math.Sqrt(float64(total)))) - proofPath, err := ipld.GetProof(ctx, bGetter, root, proofPath, index, total) - if err != nil { - return nmt.Proof{}, err - } - - rangeProofs := make([][]byte, 0, len(proofPath)) - for i := len(proofPath) - 1; i >= 0; i-- { - node := ipld.NamespacedSha256FromCID(proofPath[i]) - rangeProofs = append(rangeProofs, node) - } - - return nmt.NewInclusionProof(index, index+1, rangeProofs, true), nil -} - func ProtoToShare(protoShares []*pb.Share) []*ShareWithProof { shares := make([]*ShareWithProof, len(protoShares)) for i, share := range protoShares { diff --git a/share/eds/eds.go b/share/eds/eds.go index b8a332f275..c72d9df596 100644 --- a/share/eds/eds.go +++ b/share/eds/eds.go @@ -121,7 +121,7 @@ func getProofs(ctx context.Context, eds *rsmt2d.ExtendedDataSquare) (map[cid.Cid // this adder ignores leaves, so that they are not added to the store we iterate through in // writeProofs - adder := ipld.NewProofsAdder(odsWidth * 2) + adder := ipld.NewProofsAdder(odsWidth*2, false) defer adder.Purge() eds, err := rsmt2d.ImportExtendedDataSquare( diff --git a/share/ipld/get.go b/share/ipld/get.go index adf2ffa8c5..9cbbc23414 100644 --- a/share/ipld/get.go +++ b/share/ipld/get.go @@ -157,46 +157,6 @@ func GetLeaves(ctx context.Context, wg.Wait() } -// GetProof fetches and returns the leaf's Merkle Proof. -// It walks down the IPLD NMT tree until it reaches the leaf and returns collected proof -func GetProof( - ctx context.Context, - bGetter blockservice.BlockGetter, - root cid.Cid, - proof []cid.Cid, - leaf, total int, -) ([]cid.Cid, error) { - // request the node - nd, err := GetNode(ctx, bGetter, root) - if err != nil { - return nil, err - } - // look for links - lnks := nd.Links() - if len(lnks) == 0 { - p := make([]cid.Cid, len(proof)) - copy(p, proof) - return p, nil - } - - // route walk to appropriate children - total /= 2 // as we are using binary tree, every step decreases total leaves in a half - if leaf < total { - root = lnks[0].Cid // if target leave on the left, go with walk down the first children - proof = append(proof, lnks[1].Cid) - } else { - root, leaf = lnks[1].Cid, leaf-total // otherwise go down the second - proof, err = GetProof(ctx, bGetter, root, proof, leaf, total) - if err != nil { - return nil, err - } - return append(proof, lnks[0].Cid), nil - } - - // recursively walk down through selected children - return GetProof(ctx, bGetter, root, proof, leaf, total) -} - // chanGroup implements an atomic wait group, closing a jobs chan // when fully done. type chanGroup struct { diff --git a/share/ipld/namespace_data.go b/share/ipld/namespace_data.go index 5a6fd2abb4..32d1b5b369 100644 --- a/share/ipld/namespace_data.go +++ b/share/ipld/namespace_data.go @@ -14,10 +14,10 @@ import ( "github.com/celestiaorg/nmt" "github.com/celestiaorg/celestia-node/share" + "github.com/celestiaorg/celestia-node/share/shwap" ) -var ErrNamespaceOutsideRange = errors.New("share/ipld: " + - "target namespace is outside of namespace range for the given root") +var ErrNamespaceOutsideRange = shwap.ErrNamespaceOutsideRange // Option is the functional option that is applied to the NamespaceData instance // to configure data that needs to be stored. diff --git a/share/ipld/nmt_adder.go b/share/ipld/nmt_adder.go index 7ce52859b2..f5065df224 100644 --- a/share/ipld/nmt_adder.go +++ b/share/ipld/nmt_adder.go @@ -103,13 +103,15 @@ func BatchSize(squareSize int) int { // ProofsAdder is used to collect proof nodes, while traversing merkle tree type ProofsAdder struct { - lock sync.RWMutex - proofs map[cid.Cid][]byte + lock sync.RWMutex + collectShares bool + proofs map[cid.Cid][]byte } // NewProofsAdder creates new instance of ProofsAdder. -func NewProofsAdder(squareSize int) *ProofsAdder { +func NewProofsAdder(squareSize int, collectShares bool) *ProofsAdder { return &ProofsAdder{ + collectShares: collectShares, // preallocate map to fit all inner nodes for given square size proofs: make(map[cid.Cid][]byte, innerNodesAmount(squareSize)), } @@ -156,7 +158,7 @@ func (a *ProofsAdder) VisitFn() nmt.NodeVisitorFn { if len(a.proofs) > 0 { return nil } - return a.visitInnerNodes + return a.visitNodes } // Purge removed proofs from ProofsAdder allowing GC to collect the memory @@ -171,10 +173,13 @@ func (a *ProofsAdder) Purge() { a.proofs = nil } -func (a *ProofsAdder) visitInnerNodes(hash []byte, children ...[]byte) { +func (a *ProofsAdder) visitNodes(hash []byte, children ...[]byte) { switch len(children) { case 1: - break + if a.collectShares { + id := MustCidFromNamespacedSha256(hash) + a.addProof(id, children[0]) + } case 2: id := MustCidFromNamespacedSha256(hash) a.addProof(id, append(children[0], children[1]...)) diff --git a/share/ipld/proofs.go b/share/ipld/proofs.go new file mode 100644 index 0000000000..286e817d95 --- /dev/null +++ b/share/ipld/proofs.go @@ -0,0 +1,74 @@ +package ipld + +import ( + "context" + "math" + + "github.com/ipfs/boxo/blockservice" + "github.com/ipfs/go-cid" + + "github.com/celestiaorg/nmt" +) + +// GetProof fetches and returns the leaf's Merkle Proof. +// It walks down the IPLD NMT tree until it reaches the leaf and returns collected proof +func GetProof( + ctx context.Context, + bGetter blockservice.BlockGetter, + root []byte, + shareIdx, + total int, +) (nmt.Proof, error) { + rootCid := MustCidFromNamespacedSha256(root) + proofPath := make([]cid.Cid, 0, int(math.Sqrt(float64(total)))) + proofPath, err := getProof(ctx, bGetter, rootCid, proofPath, shareIdx, total) + if err != nil { + return nmt.Proof{}, err + } + + rangeProofs := make([][]byte, 0, len(proofPath)) + for i := len(proofPath) - 1; i >= 0; i-- { + node := NamespacedSha256FromCID(proofPath[i]) + rangeProofs = append(rangeProofs, node) + } + + return nmt.NewInclusionProof(shareIdx, shareIdx+1, rangeProofs, true), nil +} + +func getProof( + ctx context.Context, + bGetter blockservice.BlockGetter, + root cid.Cid, + proof []cid.Cid, + leaf, total int, +) ([]cid.Cid, error) { + // request the node + nd, err := GetNode(ctx, bGetter, root) + if err != nil { + return nil, err + } + // look for links + lnks := nd.Links() + if len(lnks) == 0 { + p := make([]cid.Cid, len(proof)) + copy(p, proof) + return p, nil + } + + // route walk to appropriate children + total /= 2 // as we are using binary tree, every step decreases total leaves in a half + if leaf < total { + root = lnks[0].Cid // if target leave on the left, go with walk down the first children + proof = append(proof, lnks[1].Cid) + } else { + root, leaf = lnks[1].Cid, leaf-total // otherwise go down the second + proof, err = getProof(ctx, bGetter, root, proof, leaf, total) + if err != nil { + return nil, err + } + return append(proof, lnks[0].Cid), nil + } + + // recursively walk down through selected children + return getProof(ctx, bGetter, root, proof, leaf, total) +} diff --git a/share/eds/byzantine/share_proof_test.go b/share/ipld/proofs_test.go similarity index 55% rename from share/eds/byzantine/share_proof_test.go rename to share/ipld/proofs_test.go index 170ba591e2..0b41b568b6 100644 --- a/share/eds/byzantine/share_proof_test.go +++ b/share/ipld/proofs_test.go @@ -1,4 +1,4 @@ -package byzantine +package ipld import ( "context" @@ -11,8 +11,8 @@ import ( "github.com/celestiaorg/rsmt2d" "github.com/celestiaorg/celestia-node/share" - "github.com/celestiaorg/celestia-node/share/ipld" "github.com/celestiaorg/celestia-node/share/sharetest" + "github.com/celestiaorg/celestia-node/share/shwap" ) func TestGetProof(t *testing.T) { @@ -20,10 +20,10 @@ func TestGetProof(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Second*2) defer cancel() - bServ := ipld.NewMemBlockservice() + bServ := NewMemBlockservice() shares := sharetest.RandShares(t, width*width) - in, err := ipld.AddShares(ctx, shares, bServ) + in, err := AddShares(ctx, shares, bServ) require.NoError(t, err) dah, err := da.NewDataAvailabilityHeader(in) @@ -38,25 +38,28 @@ func TestGetProof(t *testing.T) { roots = dah.ColumnRoots } for axisIdx := 0; axisIdx < width*2; axisIdx++ { - rootCid := ipld.MustCidFromNamespacedSha256(roots[axisIdx]) + root := roots[axisIdx] for shrIdx := 0; shrIdx < width*2; shrIdx++ { - proof, err := getProofsAt(ctx, bServ, rootCid, shrIdx, int(in.Width())) + proof, err := GetProof(ctx, bServ, root, shrIdx, int(in.Width())) require.NoError(t, err) - node, err := ipld.GetLeaf(ctx, bServ, rootCid, shrIdx, int(in.Width())) + rootCid := MustCidFromNamespacedSha256(root) + node, err := GetLeaf(ctx, bServ, rootCid, shrIdx, int(in.Width())) require.NoError(t, err) - inclusion := &ShareWithProof{ - Share: share.GetData(node.RawData()), - Proof: &proof, - Axis: proofType, + + sample := shwap.Sample{ + Share: share.GetData(node.RawData()), + Proof: &proof, + ProofType: proofType, } - require.True(t, inclusion.Validate(&dah, proofType, axisIdx, shrIdx)) - // swap axis indexes to test if validation still works against the orthogonal coordinate + var rowIdx, colIdx int switch proofType { case rsmt2d.Row: - require.True(t, inclusion.Validate(&dah, rsmt2d.Col, shrIdx, axisIdx)) + rowIdx, colIdx = axisIdx, shrIdx case rsmt2d.Col: - require.True(t, inclusion.Validate(&dah, rsmt2d.Row, shrIdx, axisIdx)) + rowIdx, colIdx = shrIdx, axisIdx } + err = sample.Validate(&dah, rowIdx, colIdx) + require.NoError(t, err) } } } diff --git a/share/new_eds/axis_half.go b/share/new_eds/axis_half.go index dede70ebbc..6b48676fe2 100644 --- a/share/new_eds/axis_half.go +++ b/share/new_eds/axis_half.go @@ -7,6 +7,8 @@ import ( "github.com/celestiaorg/celestia-node/share/shwap" ) +var codec = share.DefaultRSMT2DCodec() + // AxisHalf represents a half of data for a row or column in the EDS. type AxisHalf struct { Shares []share.Share @@ -37,18 +39,18 @@ func extendShares(original []share.Share) ([]share.Share, error) { return nil, fmt.Errorf("original shares are empty") } - codec := share.DefaultRSMT2DCodec() parity, err := codec.Encode(original) if err != nil { return nil, fmt.Errorf("encoding: %w", err) } - shares := make([]share.Share, len(original)*2) + + sqLen := len(original) * 2 + shares := make([]share.Share, sqLen) copy(shares, original) - copy(shares[len(original):], parity) + copy(shares[sqLen/2:], parity) return shares, nil } -// reconstructShares constructs full axis shares from parity half axis shares. func reconstructShares(parity []share.Share) ([]share.Share, error) { if len(parity) == 0 { return nil, fmt.Errorf("parity shares are empty") @@ -59,9 +61,7 @@ func reconstructShares(parity []share.Share) ([]share.Share, error) { for i := sqLen / 2; i < sqLen; i++ { shares[i] = parity[i-sqLen/2] } - - codec := share.DefaultRSMT2DCodec() - shares, err := codec.Decode(shares) + _, err := codec.Decode(shares) if err != nil { return nil, fmt.Errorf("reconstructing: %w", err) } diff --git a/share/new_eds/proofs_cache.go b/share/new_eds/proofs_cache.go new file mode 100644 index 0000000000..068b4c4d0b --- /dev/null +++ b/share/new_eds/proofs_cache.go @@ -0,0 +1,279 @@ +package eds + +import ( + "context" + "errors" + "fmt" + "sync" + "sync/atomic" + + "github.com/ipfs/boxo/blockservice" + blocks "github.com/ipfs/go-block-format" + "github.com/ipfs/go-cid" + + "github.com/celestiaorg/celestia-app/pkg/wrapper" + "github.com/celestiaorg/nmt" + "github.com/celestiaorg/rsmt2d" + + "github.com/celestiaorg/celestia-node/share" + "github.com/celestiaorg/celestia-node/share/ipld" + "github.com/celestiaorg/celestia-node/share/shwap" +) + +var _ Accessor = (*proofsCache)(nil) + +// proofsCache is eds accessor that caches proofs for rows and columns. It also caches extended +// axis Shares. It is used to speed up the process of building proofs for rows and columns, +// reducing the number of reads from the underlying accessor. +type proofsCache struct { + inner Accessor + + // lock protects axisCache + lock sync.RWMutex + // axisCache caches the axis Shares and proofs. Index in the slice corresponds to the axis type. + // The map key is the index of the axis. + axisCache []map[int]axisWithProofs + // size caches the size of the data square + size atomic.Int32 + // disableCache disables caching of rows for testing purposes + disableCache bool +} + +// axisWithProofs is used to cache the extended axis Shares and proofs. +type axisWithProofs struct { + half AxisHalf + // shares are the extended axis Shares + shares []share.Share + // root caches the root of the tree. It will be set only when proofs are calculated + root []byte + // proofs are stored in a blockservice.BlockGetter by their CID. It will be set only when proofs + // are calculated and will be used to get the proof for a specific share. BlockGetter is used to + // reuse ipld based proof generation logic, which traverses the tree from the root to the leafs and + // collects the nodes on the path. This is temporary and will be replaced with a more efficient + // proof caching mechanism in nmt package, once it is implemented. + proofs blockservice.BlockGetter +} + +// WithProofsCache creates a new eds accessor with caching of proofs for rows and columns. It is +// used to speed up the process of building proofs for rows and columns, reducing the number of +// reads from the underlying accessor. +func WithProofsCache(ac Accessor) Accessor { + rows := make(map[int]axisWithProofs) + cols := make(map[int]axisWithProofs) + axisCache := []map[int]axisWithProofs{rows, cols} + return &proofsCache{ + inner: ac, + axisCache: axisCache, + } +} + +func (c *proofsCache) Size(ctx context.Context) int { + size := c.size.Load() + if size == 0 { + size = int32(c.inner.Size(ctx)) + c.size.Store(size) + } + return int(size) +} + +func (c *proofsCache) Sample(ctx context.Context, rowIdx, colIdx int) (shwap.Sample, error) { + axisType, axisIdx, shrIdx := rsmt2d.Row, rowIdx, colIdx + ax, err := c.axisWithProofs(ctx, axisType, axisIdx) + if err != nil { + return shwap.Sample{}, err + } + + // build share proof from proofs cached for given axis + share := ax.shares[shrIdx] + proofs, err := ipld.GetProof(ctx, ax.proofs, ax.root, shrIdx, c.Size(ctx)) + if err != nil { + return shwap.Sample{}, fmt.Errorf("building proof from cache: %w", err) + } + + return shwap.Sample{ + Share: share, + Proof: &proofs, + ProofType: axisType, + }, nil +} + +func (c *proofsCache) axisWithProofs(ctx context.Context, axisType rsmt2d.Axis, axisIdx int) (axisWithProofs, error) { + // return axis with proofs from cache if possible + ax, ok := c.getAxisFromCache(axisType, axisIdx) + if ax.proofs != nil { + // return axis with proofs from cache, only if proofs are already calculated + return ax, nil + } + + if !ok { + // if shares are not in cache, read them from the inner accessor + shares, err := c.axisShares(ctx, axisType, axisIdx) + if err != nil { + return axisWithProofs{}, fmt.Errorf("get axis: %w", err) + } + ax.shares = shares + } + + // build proofs from Shares and cache them + adder := ipld.NewProofsAdder(c.Size(ctx), true) + tree := wrapper.NewErasuredNamespacedMerkleTree( + uint64(c.Size(ctx)/2), + uint(axisIdx), + nmt.NodeVisitor(adder.VisitFn()), + ) + for _, shr := range ax.shares { + err := tree.Push(shr) + if err != nil { + return axisWithProofs{}, fmt.Errorf("push shares: %w", err) + } + } + + // build the tree + root, err := tree.Root() + if err != nil { + return axisWithProofs{}, fmt.Errorf("calculating root: %w", err) + } + + ax.root = root + ax.proofs, err = newRowProofsGetter(adder.Proofs()) + if err != nil { + return axisWithProofs{}, fmt.Errorf("creating proof getter: %w", err) + } + + if !c.disableCache { + c.storeAxisInCache(axisType, axisIdx, ax) + } + return ax, nil +} + +func (c *proofsCache) AxisHalf(ctx context.Context, axisType rsmt2d.Axis, axisIdx int) (AxisHalf, error) { + // return axis from cache if possible + ax, ok := c.getAxisFromCache(axisType, axisIdx) + if ok { + return ax.half, nil + } + + // read axis from inner accessor if axis is in the first quadrant + half, err := c.inner.AxisHalf(ctx, axisType, axisIdx) + if err != nil { + return AxisHalf{}, fmt.Errorf("reading axis from inner accessor: %w", err) + } + + if !c.disableCache { + ax.half = half + c.storeAxisInCache(axisType, axisIdx, ax) + } + + return half, nil +} + +func (c *proofsCache) RowNamespaceData( + ctx context.Context, + namespace share.Namespace, + rowIdx int, +) (shwap.RowNamespaceData, error) { + ax, err := c.axisWithProofs(ctx, rsmt2d.Row, rowIdx) + if err != nil { + return shwap.RowNamespaceData{}, err + } + + row, proof, err := ipld.GetSharesByNamespace(ctx, ax.proofs, ax.root, namespace, c.Size(ctx)) + if err != nil { + return shwap.RowNamespaceData{}, fmt.Errorf("shares by namespace %s for row %v: %w", namespace.String(), rowIdx, err) + } + + return shwap.RowNamespaceData{ + Shares: row, + Proof: proof, + }, nil +} + +func (c *proofsCache) Shares(ctx context.Context) ([]share.Share, error) { + odsSize := c.Size(ctx) / 2 + shares := make([]share.Share, 0, odsSize*odsSize) + for i := 0; i < c.Size(ctx)/2; i++ { + ax, err := c.AxisHalf(ctx, rsmt2d.Row, i) + if err != nil { + return nil, err + } + + half := ax.Shares + if ax.IsParity { + shares, err = c.axisShares(ctx, rsmt2d.Row, i) + if err != nil { + return nil, err + } + half = shares[:odsSize] + } + + shares = append(shares, half...) + } + return shares, nil +} + +func (c *proofsCache) axisShares(ctx context.Context, axisType rsmt2d.Axis, axisIdx int) ([]share.Share, error) { + ax, ok := c.getAxisFromCache(axisType, axisIdx) + if ok && ax.shares != nil { + return ax.shares, nil + } + + if len(ax.half.Shares) == 0 { + half, err := c.AxisHalf(ctx, axisType, axisIdx) + if err != nil { + return nil, err + } + ax.half = half + } + + shares, err := ax.half.Extended() + if err != nil { + return nil, fmt.Errorf("extending shares: %w", err) + } + + if !c.disableCache { + ax.shares = shares + c.storeAxisInCache(axisType, axisIdx, ax) + } + return shares, nil +} + +func (c *proofsCache) storeAxisInCache(axisType rsmt2d.Axis, axisIdx int, axis axisWithProofs) { + c.lock.Lock() + defer c.lock.Unlock() + c.axisCache[axisType][axisIdx] = axis +} + +func (c *proofsCache) getAxisFromCache(axisType rsmt2d.Axis, axisIdx int) (axisWithProofs, bool) { + c.lock.RLock() + defer c.lock.RUnlock() + ax, ok := c.axisCache[axisType][axisIdx] + return ax, ok +} + +// rowProofsGetter implements blockservice.BlockGetter interface +type rowProofsGetter struct { + proofs map[cid.Cid]blocks.Block +} + +func newRowProofsGetter(rawProofs map[cid.Cid][]byte) (*rowProofsGetter, error) { + proofs := make(map[cid.Cid]blocks.Block, len(rawProofs)) + for k, v := range rawProofs { + b, err := blocks.NewBlockWithCid(v, k) + if err != nil { + return nil, err + } + proofs[k] = b + } + return &rowProofsGetter{proofs: proofs}, nil +} + +func (r rowProofsGetter) GetBlock(_ context.Context, c cid.Cid) (blocks.Block, error) { + if b, ok := r.proofs[c]; ok { + return b, nil + } + return nil, errors.New("block not found") +} + +func (r rowProofsGetter) GetBlocks(_ context.Context, _ []cid.Cid) <-chan blocks.Block { + panic("not implemented") +} diff --git a/share/new_eds/proofs_cache_test.go b/share/new_eds/proofs_cache_test.go new file mode 100644 index 0000000000..8b22af6e4f --- /dev/null +++ b/share/new_eds/proofs_cache_test.go @@ -0,0 +1,22 @@ +package eds + +import ( + "context" + "testing" + "time" + + "github.com/celestiaorg/rsmt2d" +) + +func TestCache(t *testing.T) { + size := 8 + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + t.Cleanup(cancel) + + withProofsCache := func(tb testing.TB, inner *rsmt2d.ExtendedDataSquare) Accessor { + accessor := &Rsmt2D{ExtendedDataSquare: inner} + return WithProofsCache(accessor) + } + + TestSuiteAccessor(ctx, t, withProofsCache, size) +}