Skip to content

Commit

Permalink
refactor(share): GetShare -> GetSamples (#3905)
Browse files Browse the repository at this point in the history
  • Loading branch information
cristaloleg committed Nov 14, 2024
1 parent 7354df8 commit f823d6b
Show file tree
Hide file tree
Showing 23 changed files with 153 additions and 85 deletions.
30 changes: 21 additions & 9 deletions blob/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,11 @@ func TestBlobService_Get(t *testing.T) {
shareOffset := 0
for i := range blobs {
row, col := calculateIndex(len(h.DAH.RowRoots), blobs[i].index)
sh, err := service.shareGetter.GetShare(ctx, h, row, col)
idx, err := shwap.SampleIndexFromCoordinates(row, col, len(h.DAH.RowRoots))
require.NoError(t, err)
require.True(t, bytes.Equal(sh.ToBytes(), resultShares[shareOffset].ToBytes()),
smpls, err := service.shareGetter.GetSamples(ctx, h, []shwap.SampleIndex{idx})
require.NoError(t, err)
require.True(t, bytes.Equal(smpls[0].Share.ToBytes(), resultShares[shareOffset].ToBytes()),
fmt.Sprintf("issue on %d attempt. ROW:%d, COL: %d, blobIndex:%d", i, row, col, blobs[i].index),
)
shareOffset += libshare.SparseSharesNeeded(uint32(len(blobs[i].Data())))
Expand Down Expand Up @@ -487,10 +489,13 @@ func TestService_GetSingleBlobWithoutPadding(t *testing.T) {
h, err := service.headerGetter(ctx, 1)
require.NoError(t, err)
row, col := calculateIndex(len(h.DAH.RowRoots), newBlob.index)
sh, err := service.shareGetter.GetShare(ctx, h, row, col)
idx, err := shwap.SampleIndexFromCoordinates(row, col, len(h.DAH.RowRoots))
require.NoError(t, err)

smpls, err := service.shareGetter.GetSamples(ctx, h, []shwap.SampleIndex{idx})
require.NoError(t, err)

assert.Equal(t, sh, resultShares[0])
assert.Equal(t, smpls[0].Share, resultShares[0])
}

func TestService_Get(t *testing.T) {
Expand Down Expand Up @@ -521,10 +526,13 @@ func TestService_Get(t *testing.T) {
assert.Equal(t, b.Commitment, blob.Commitment)

row, col := calculateIndex(len(h.DAH.RowRoots), b.index)
sh, err := service.shareGetter.GetShare(ctx, h, row, col)
idx, err := shwap.SampleIndexFromCoordinates(row, col, len(h.DAH.RowRoots))
require.NoError(t, err)

assert.Equal(t, sh, resultShares[shareOffset], fmt.Sprintf("issue on %d attempt", i))
smpls, err := service.shareGetter.GetSamples(ctx, h, []shwap.SampleIndex{idx})
require.NoError(t, err)

assert.Equal(t, smpls[0].Share, resultShares[shareOffset], fmt.Sprintf("issue on %d attempt", i))
shareOffset += libshare.SparseSharesNeeded(uint32(len(blob.Data())))
}
}
Expand Down Expand Up @@ -580,10 +588,13 @@ func TestService_GetAllWithoutPadding(t *testing.T) {
require.True(t, blobs[i].compareCommitments(blob.Commitment))

row, col := calculateIndex(len(h.DAH.RowRoots), blob.index)
sh, err := service.shareGetter.GetShare(ctx, h, row, col)
idx, err := shwap.SampleIndexFromCoordinates(row, col, len(h.DAH.RowRoots))
require.NoError(t, err)

smpls, err := service.shareGetter.GetSamples(ctx, h, []shwap.SampleIndex{idx})
require.NoError(t, err)

assert.Equal(t, sh, resultShares[shareOffset])
assert.Equal(t, smpls[0].Share, resultShares[shareOffset])
shareOffset += libshare.SparseSharesNeeded(uint32(len(blob.Data())))
}
}
Expand Down Expand Up @@ -904,7 +915,8 @@ func createService(ctx context.Context, t testing.TB, shares []libshare.Share) *
})
shareGetter.EXPECT().GetSamples(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().
DoAndReturn(func(ctx context.Context, h *header.ExtendedHeader, indices []shwap.SampleIndex) ([]shwap.Sample, error) {
return smpls, nil
smpl, err := accessor.Sample(ctx, indices[0])
return []shwap.Sample{smpl}, err
})

// create header and put it into the store
Expand Down
20 changes: 20 additions & 0 deletions nodebuilder/share/share.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
libshare "github.com/celestiaorg/go-square/v2/share"
"github.com/celestiaorg/rsmt2d"

"github.com/celestiaorg/celestia-node/header"
headerServ "github.com/celestiaorg/celestia-node/nodebuilder/header"
"github.com/celestiaorg/celestia-node/share"
"github.com/celestiaorg/celestia-node/share/eds"
Expand Down Expand Up @@ -45,6 +46,8 @@ type Module interface {
SharesAvailable(ctx context.Context, height uint64) error
// GetShare gets a Share by coordinates in EDS.
GetShare(ctx context.Context, height uint64, row, col int) (libshare.Share, error)
// GetSamples gets sample for given indices.
GetSamples(ctx context.Context, header *header.ExtendedHeader, indices []shwap.SampleIndex) ([]shwap.Sample, error)
// GetEDS gets the full EDS identified by the given extended header.
GetEDS(ctx context.Context, height uint64) (*rsmt2d.ExtendedDataSquare, error)
// GetNamespaceData gets all shares from an EDS within the given namespace.
Expand All @@ -65,6 +68,11 @@ type API struct {
height uint64,
row, col int,
) (libshare.Share, error) `perm:"read"`
GetSamples func(
ctx context.Context,
header *header.ExtendedHeader,
indices []shwap.SampleIndex,
) ([]shwap.Sample, error) `perm:"read"`
GetEDS func(
ctx context.Context,
height uint64,
Expand All @@ -90,6 +98,12 @@ func (api *API) GetShare(ctx context.Context, height uint64, row, col int) (libs
return api.Internal.GetShare(ctx, height, row, col)
}

func (api *API) GetSamples(ctx context.Context, header *header.ExtendedHeader,
indices []shwap.SampleIndex,
) ([]shwap.Sample, error) {
return api.Internal.GetSamples(ctx, header, indices)
}

func (api *API) GetEDS(ctx context.Context, height uint64) (*rsmt2d.ExtendedDataSquare, error) {
return api.Internal.GetEDS(ctx, height)
}
Expand Down Expand Up @@ -132,6 +146,12 @@ func (m module) GetShare(ctx context.Context, height uint64, row, col int) (libs
return smpls[0].Share, nil
}

func (m module) GetSamples(ctx context.Context, header *header.ExtendedHeader,
indices []shwap.SampleIndex,
) ([]shwap.Sample, error) {
return m.getter.GetSamples(ctx, header, indices)
}

func (m module) GetEDS(ctx context.Context, height uint64) (*rsmt2d.ExtendedDataSquare, error) {
header, err := m.hs.GetByHeight(ctx, height)
if err != nil {
Expand Down
4 changes: 2 additions & 2 deletions share/availability/light/availability.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,10 +126,10 @@ func (la *ShareAvailability) SharesAvailable(ctx context.Context, header *header
}

smpls, err := la.getter.GetSamples(ctx, header, idxs)
if errors.Is(ctx.Err(), context.Canceled) {
if errors.Is(err, context.Canceled) {
// Availability did not complete due to context cancellation, return context error instead of
// share.ErrNotAvailable
return ctx.Err()
return err
}
if len(smpls) == 0 {
return share.ErrNotAvailable
Expand Down
11 changes: 4 additions & 7 deletions share/availability/light/availability_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,7 @@ func TestSharesAvailableSuccess(t *testing.T) {
acc := eds.Rsmt2D{ExtendedDataSquare: square}
smpls := make([]shwap.Sample, len(indices))
for i, idx := range indices {
rowIdx, colIdx, err := idx.Coordinates(len(hdr.DAH.RowRoots))
if err != nil {
return nil, err
}

smpl, err := acc.Sample(ctx, rowIdx, colIdx)
smpl, err := acc.Sample(ctx, idx)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -261,7 +256,9 @@ func (g onceGetter) checkOnce(t *testing.T) {
}
}

func (m onceGetter) GetSamples(_ context.Context, hdr *header.ExtendedHeader, indices []shwap.SampleIndex) ([]shwap.Sample, error) {
func (m onceGetter) GetSamples(_ context.Context, hdr *header.ExtendedHeader,
indices []shwap.SampleIndex,
) ([]shwap.Sample, error) {
m.Lock()
defer m.Unlock()

Expand Down
3 changes: 1 addition & 2 deletions share/eds/accessor.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@ type Accessor interface {
// Sample returns share and corresponding proof for row and column indices. Implementation can
// choose which axis to use for proof. Chosen axis for proof should be indicated in the returned
// Sample.
// TODO(@Wondertan): change to SampleIndex
Sample(ctx context.Context, rowIdx, colIdx int) (shwap.Sample, error)
Sample(ctx context.Context, idx shwap.SampleIndex) (shwap.Sample, error)
// AxisHalf returns half of shares axis of the given type and index. Side is determined by
// implementation. Implementations should indicate the side in the returned AxisHalf.
AxisHalf(ctx context.Context, axisType rsmt2d.Axis, axisIdx int) (AxisHalf, error)
Expand Down
4 changes: 2 additions & 2 deletions share/eds/close_once.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,11 @@ func (c *closeOnce) AxisRoots(ctx context.Context) (*share.AxisRoots, error) {
return c.f.AxisRoots(ctx)
}

func (c *closeOnce) Sample(ctx context.Context, rowIdx, colIdx int) (shwap.Sample, error) {
func (c *closeOnce) Sample(ctx context.Context, idx shwap.SampleIndex) (shwap.Sample, error) {
if c.closed.Load() {
return shwap.Sample{}, errAccessorClosed
}
return c.f.Sample(ctx, rowIdx, colIdx)
return c.f.Sample(ctx, idx)
}

func (c *closeOnce) AxisHalf(
Expand Down
6 changes: 3 additions & 3 deletions share/eds/close_once_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func TestWithClosedOnce(t *testing.T) {
stub := &stubEdsAccessorCloser{}
closedOnce := WithClosedOnce(stub)

_, err := closedOnce.Sample(ctx, 0, 0)
_, err := closedOnce.Sample(ctx, 0)
require.NoError(t, err)
_, err = closedOnce.AxisHalf(ctx, rsmt2d.Row, 0)
require.NoError(t, err)
Expand All @@ -33,7 +33,7 @@ func TestWithClosedOnce(t *testing.T) {
require.True(t, stub.closed)

// Ensure that the underlying file is not accessible after closing
_, err = closedOnce.Sample(ctx, 0, 0)
_, err = closedOnce.Sample(ctx, 0)
require.ErrorIs(t, err, errAccessorClosed)
_, err = closedOnce.AxisHalf(ctx, rsmt2d.Row, 0)
require.ErrorIs(t, err, errAccessorClosed)
Expand All @@ -59,7 +59,7 @@ func (s *stubEdsAccessorCloser) AxisRoots(context.Context) (*share.AxisRoots, er
return &share.AxisRoots{}, nil
}

func (s *stubEdsAccessorCloser) Sample(context.Context, int, int) (shwap.Sample, error) {
func (s *stubEdsAccessorCloser) Sample(context.Context, shwap.SampleIndex) (shwap.Sample, error) {
return shwap.Sample{}, nil
}

Expand Down
7 changes: 6 additions & 1 deletion share/eds/proofs_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,12 @@ func (c *proofsCache) AxisRoots(ctx context.Context) (*share.AxisRoots, error) {
return roots, nil
}

func (c *proofsCache) Sample(ctx context.Context, rowIdx, colIdx int) (shwap.Sample, error) {
func (c *proofsCache) Sample(ctx context.Context, idx shwap.SampleIndex) (shwap.Sample, error) {
rowIdx, colIdx, err := idx.Coordinates(c.Size(ctx))
if err != nil {
return shwap.Sample{}, err
}

axisType, axisIdx, shrIdx := rsmt2d.Row, rowIdx, colIdx
ax, err := c.axisWithProofs(ctx, axisType, axisIdx)
if err != nil {
Expand Down
11 changes: 8 additions & 3 deletions share/eds/rsmt2d.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,17 +46,22 @@ func (eds *Rsmt2D) AxisRoots(context.Context) (*share.AxisRoots, error) {
// Sample returns share and corresponding proof for row and column indices.
func (eds *Rsmt2D) Sample(
_ context.Context,
rowIdx, colIdx int,
idx shwap.SampleIndex,
) (shwap.Sample, error) {
return eds.SampleForProofAxis(rowIdx, colIdx, rsmt2d.Row)
return eds.SampleForProofAxis(idx, rsmt2d.Row)
}

// SampleForProofAxis samples a share from an Extended Data Square based on the provided
// row and column indices and proof axis. It returns a sample with the share and proof.
func (eds *Rsmt2D) SampleForProofAxis(
rowIdx, colIdx int,
idx shwap.SampleIndex,
proofType rsmt2d.Axis,
) (shwap.Sample, error) {
rowIdx, colIdx, err := idx.Coordinates(int(eds.Width()))
if err != nil {
return shwap.Sample{}, err
}

axisIdx, shrIdx := relativeIndexes(rowIdx, colIdx, proofType)
shares, err := getAxis(eds.ExtendedDataSquare, proofType, axisIdx)
if err != nil {
Expand Down
5 changes: 4 additions & 1 deletion share/eds/rsmt2d_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,10 @@ func TestRsmt2dSampleForProofAxis(t *testing.T) {
for _, proofType := range []rsmt2d.Axis{rsmt2d.Row, rsmt2d.Col} {
for rowIdx := 0; rowIdx < odsSize*2; rowIdx++ {
for colIdx := 0; colIdx < odsSize*2; colIdx++ {
sample, err := accessor.SampleForProofAxis(rowIdx, colIdx, proofType)
idx, err := shwap.SampleIndexFromCoordinates(rowIdx, colIdx, accessor.Size(context.Background()))
require.NoError(t, err)

sample, err := accessor.SampleForProofAxis(idx, proofType)
require.NoError(t, err)

want := eds.GetCell(uint(rowIdx), uint(colIdx))
Expand Down
30 changes: 20 additions & 10 deletions share/eds/testing.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,9 @@ func testAccessorSample(
// t.Parallel() this fails the test for some reason
for rowIdx := 0; rowIdx < width; rowIdx++ {
for colIdx := 0; colIdx < width; colIdx++ {
testSample(ctx, t, acc, roots, colIdx, rowIdx)
idx, err := shwap.SampleIndexFromCoordinates(rowIdx, colIdx, acc.Size(ctx))
require.NoError(t, err)
testSample(ctx, t, acc, roots, idx)
}
}
})
Expand All @@ -162,10 +164,12 @@ func testAccessorSample(
for rowIdx := 0; rowIdx < width; rowIdx++ {
for colIdx := 0; colIdx < width; colIdx++ {
wg.Add(1)
go func(rowIdx, colIdx int) {
idx, err := shwap.SampleIndexFromCoordinates(rowIdx, colIdx, acc.Size(ctx))
require.NoError(t, err)
go func(idx shwap.SampleIndex) {
defer wg.Done()
testSample(ctx, t, acc, roots, rowIdx, colIdx)
}(rowIdx, colIdx)
testSample(ctx, t, acc, roots, idx)
}(idx)
}
}
wg.Wait()
Expand All @@ -182,8 +186,8 @@ func testAccessorSample(
wg.Add(1)
go func() {
defer wg.Done()
rowIdx, colIdx := rand.IntN(width), rand.IntN(width) //nolint:gosec
testSample(ctx, t, acc, roots, rowIdx, colIdx)
idx := rand.IntN(int(eds.Width())) //nolint:gosec
testSample(ctx, t, acc, roots, shwap.SampleIndex(idx))
}()
}
wg.Wait()
Expand All @@ -195,9 +199,12 @@ func testSample(
t *testing.T,
acc Accessor,
roots *share.AxisRoots,
rowIdx, colIdx int,
idx shwap.SampleIndex,
) {
shr, err := acc.Sample(ctx, rowIdx, colIdx)
shr, err := acc.Sample(ctx, idx)
require.NoError(t, err)

rowIdx, colIdx, err := idx.Coordinates(acc.Size(ctx))
require.NoError(t, err)

err = shr.Verify(roots, rowIdx, colIdx)
Expand Down Expand Up @@ -444,13 +451,16 @@ func BenchGetSampleFromAccessor(
name := fmt.Sprintf("Size:%v/quadrant:%s", size, q)
b.Run(name, func(b *testing.B) {
rowIdx, colIdx := q.coordinates(acc.Size(ctx))
idx, err := shwap.SampleIndexFromCoordinates(rowIdx, colIdx, acc.Size(ctx))
require.NoError(b, err)

// warm up cache
_, err := acc.Sample(ctx, rowIdx, colIdx)
_, err = acc.Sample(ctx, idx)
require.NoError(b, err, q.String())

b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := acc.Sample(ctx, rowIdx, colIdx)
_, err := acc.Sample(ctx, idx)
require.NoError(b, err)
}
})
Expand Down
11 changes: 3 additions & 8 deletions share/eds/validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,17 +34,12 @@ func (f validation) Size(ctx context.Context) int {
return int(size)
}

func (f validation) Sample(ctx context.Context, rowIdx, colIdx int) (shwap.Sample, error) {
idx, err := shwap.SampleIndexFromCoordinates(rowIdx, colIdx, f.Size(ctx))
if err != nil {
return shwap.Sample{}, err
}

_, err = shwap.NewSampleID(1, idx, f.Size(ctx))
func (f validation) Sample(ctx context.Context, idx shwap.SampleIndex) (shwap.Sample, error) {
_, err := shwap.NewSampleID(1, idx, f.Size(ctx))
if err != nil {
return shwap.Sample{}, fmt.Errorf("sample validation: %w", err)
}
return f.Accessor.Sample(ctx, rowIdx, colIdx)
return f.Accessor.Sample(ctx, idx)
}

func (f validation) AxisHalf(ctx context.Context, axisType rsmt2d.Axis, axisIdx int) (AxisHalf, error) {
Expand Down
9 changes: 8 additions & 1 deletion share/eds/validation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,14 @@ func TestValidation_Sample(t *testing.T) {
accessor := &Rsmt2D{ExtendedDataSquare: randEDS}
validation := WithValidation(AccessorAndStreamer(accessor, nil))

_, err := validation.Sample(context.Background(), tt.rowIdx, tt.colIdx)
idx, err := shwap.SampleIndexFromCoordinates(tt.rowIdx, tt.colIdx, accessor.Size(context.Background()))
if tt.expectFail {
require.ErrorIs(t, err, shwap.ErrInvalidID, tt.name)
return
}
require.NoError(t, err, tt.name)

_, err = validation.Sample(context.Background(), idx)
if tt.expectFail {
require.ErrorIs(t, err, shwap.ErrInvalidID)
} else {
Expand Down
4 changes: 3 additions & 1 deletion share/shwap/getters/cascade.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ func NewCascadeGetter(getters []shwap.Getter) *CascadeGetter {
}

// GetSamples gets samples from any of registered shwap.Getters in cascading order.
func (cg *CascadeGetter) GetSamples(ctx context.Context, hdr *header.ExtendedHeader, indices []shwap.SampleIndex) ([]shwap.Sample, error) {
func (cg *CascadeGetter) GetSamples(ctx context.Context, hdr *header.ExtendedHeader,
indices []shwap.SampleIndex,
) ([]shwap.Sample, error) {
ctx, span := tracer.Start(ctx, "cascade/get-samples", trace.WithAttributes(
attribute.Int("amount", len(indices)),
))
Expand Down
Loading

0 comments on commit f823d6b

Please sign in to comment.