diff --git a/share/eds/retriever.go b/share/eds/retriever.go index c2966c3953..9389d8b6b2 100644 --- a/share/eds/retriever.go +++ b/share/eds/retriever.go @@ -72,7 +72,6 @@ func (r *Retriever) Retrieve(ctx context.Context, dah *da.DataAvailabilityHeader if err != nil { return nil, err } - defer ses.Close() // wait for a signal to start reconstruction // try until either success or context or bad data @@ -81,18 +80,23 @@ func (r *Retriever) Retrieve(ctx context.Context, dah *da.DataAvailabilityHeader case <-ses.Done(): eds, err := ses.Reconstruct(ctx) if err == nil { + ses.close(true) span.SetStatus(codes.Ok, "square-retrieved") return eds, nil } // check to ensure it is not a catastrophic ErrByzantine case, otherwise handle accordingly var errByz *rsmt2d.ErrByzantineData if errors.As(err, &errByz) { + // session should be closed before constructing the Byzantine error to allow constructor to access nmt proofs + // computed during the session + ses.close(false) span.RecordError(err) return nil, byzantine.NewErrByzantine(ctx, r.bServ, dah, errByz) } log.Warnw("not enough shares to reconstruct data square, requesting more...", "err", err) case <-ctx.Done(): + ses.close(false) return nil, ctx.Err() } } @@ -103,8 +107,9 @@ func (r *Retriever) Retrieve(ctx context.Context, dah *da.DataAvailabilityHeader // quadrant request retries. Also, provides an API // to reconstruct the block once enough shares are fetched. type retrievalSession struct { - dah *da.DataAvailabilityHeader - bget blockservice.BlockGetter + dah *da.DataAvailabilityHeader + bget blockservice.BlockGetter + adder *ipld.NmtNodeAdder // TODO(@Wondertan): Extract into a separate data structure // https://github.com/celestiaorg/rsmt2d/issues/135 @@ -123,15 +128,18 @@ type retrievalSession struct { func (r *Retriever) newSession(ctx context.Context, dah *da.DataAvailabilityHeader) (*retrievalSession, error) { size := len(dah.RowRoots) - treeFn := func(_ rsmt2d.Axis, index uint) rsmt2d.Tree { + adder := ipld.NewNmtNodeAdder(ctx, r.bServ, ipld.MaxSizeBatchOption(size)) + proofsVisitor := ipld.ProofsAdderFromCtx(ctx).VisitFn() + visitor := func(hash []byte, children ...[]byte) { // use proofs adder if provided, to cache collected proofs while recomputing the eds - var opts []nmt.Option - visitor := ipld.ProofsAdderFromCtx(ctx).VisitFn() - if visitor != nil { - opts = append(opts, nmt.NodeVisitor(visitor)) + if proofsVisitor != nil { + proofsVisitor(hash, children...) } + adder.Visit(hash, children...) + } - tree := wrapper.NewErasuredNamespacedMerkleTree(uint64(size)/2, index, opts...) + treeFn := func(_ rsmt2d.Axis, index uint) rsmt2d.Tree { + tree := wrapper.NewErasuredNamespacedMerkleTree(uint64(size)/2, index, nmt.NodeVisitor(visitor)) return &tree } @@ -143,6 +151,7 @@ func (r *Retriever) newSession(ctx context.Context, dah *da.DataAvailabilityHead ses := &retrievalSession{ dah: dah, bget: blockservice.NewSession(ctx, r.bServ), + adder: adder, squareQuadrants: newQuadrants(dah), squareCellsLks: make([][]sync.Mutex, size), squareSig: make(chan struct{}, 1), @@ -200,9 +209,16 @@ func (rs *retrievalSession) isReconstructed() bool { } } -func (rs *retrievalSession) Close() error { +func (rs *retrievalSession) close(success bool) { defer rs.span.End() - return nil + if success { + return + } + // commit intermediate nodes to the blockservice if failed to reconstruct + err := rs.adder.Commit() + if err != nil { + log.Warnw("failed to commit intermediate nodes", "err", err) + } } // request kicks off quadrants requests.