diff --git a/das/daser_test.go b/das/daser_test.go index ae43d28fbb..fac628418b 100644 --- a/das/daser_test.go +++ b/das/daser_test.go @@ -319,7 +319,15 @@ func (m getterStub) GetByHeight(_ context.Context, height uint64) (*header.Exten DAH: &header.DataAvailabilityHeader{RowsRoots: make([][]byte, 0)}}, nil } -func (m getterStub) GetRangeByHeight(ctx context.Context, from, to uint64) ([]*header.ExtendedHeader, error) { +func (m getterStub) GetRangeByHeight(ctx context.Context, from, amount uint64) ([]*header.ExtendedHeader, error) { + return nil, nil +} + +func (m getterStub) GetVerifiedRange( + context.Context, + *header.ExtendedHeader, + uint64, +) ([]*header.ExtendedHeader, error) { return nil, nil } diff --git a/header/core/exchange.go b/header/core/exchange.go index b3c6d917d0..9689f84f00 100644 --- a/header/core/exchange.go +++ b/header/core/exchange.go @@ -55,6 +55,23 @@ func (ce *Exchange) GetRangeByHeight(ctx context.Context, from, amount uint64) ( return headers, nil } +func (ce *Exchange) GetVerifiedRange(ctx context.Context, from *header.ExtendedHeader, amount uint64, +) ([]*header.ExtendedHeader, error) { + headers, err := ce.GetRangeByHeight(ctx, uint64(from.Height)+1, amount) + if err != nil { + return nil, err + } + + for _, h := range headers { + err := from.VerifyAdjacent(h) + if err != nil { + return nil, err + } + from = h + } + return headers, nil +} + func (ce *Exchange) Get(ctx context.Context, hash tmbytes.HexBytes) (*header.ExtendedHeader, error) { log.Debugw("requesting header", "hash", hash.String()) block, err := ce.fetcher.GetBlockByHash(ctx, hash) diff --git a/header/interface.go b/header/interface.go index 1273dea103..b6a439f4f1 100644 --- a/header/interface.go +++ b/header/interface.go @@ -123,8 +123,12 @@ type Getter interface { // GetByHeight returns the ExtendedHeader corresponding to the given block height. GetByHeight(context.Context, uint64) (*ExtendedHeader, error) - // GetRangeByHeight returns the given range [from:to) of ExtendedHeaders. - GetRangeByHeight(ctx context.Context, from, to uint64) ([]*ExtendedHeader, error) + // GetRangeByHeight returns the given range of ExtendedHeaders. + GetRangeByHeight(ctx context.Context, from, amount uint64) ([]*ExtendedHeader, error) + + // GetVerifiedRange requests the header range from the provided ExtendedHeader and + // verifies that the returned headers are adjacent to each other. + GetVerifiedRange(ctx context.Context, from *ExtendedHeader, amount uint64) ([]*ExtendedHeader, error) } // Head contains the behavior necessary for a component to retrieve diff --git a/header/local/exchange.go b/header/local/exchange.go index 9b284930a4..3bd978f483 100644 --- a/header/local/exchange.go +++ b/header/local/exchange.go @@ -43,6 +43,11 @@ func (l *Exchange) GetRangeByHeight(ctx context.Context, origin, amount uint64) return l.store.GetRangeByHeight(ctx, origin, origin+amount) } +func (l *Exchange) GetVerifiedRange(ctx context.Context, from *header.ExtendedHeader, amount uint64, +) ([]*header.ExtendedHeader, error) { + return l.store.GetVerifiedRange(ctx, from, uint64(from.Height)+amount) +} + func (l *Exchange) Get(ctx context.Context, hash bytes.HexBytes) (*header.ExtendedHeader, error) { return l.store.Get(ctx, hash) } diff --git a/header/p2p/exchange.go b/header/p2p/exchange.go index 721fc1500d..41cb4d8925 100644 --- a/header/p2p/exchange.go +++ b/header/p2p/exchange.go @@ -151,6 +151,31 @@ func (ex *Exchange) GetRangeByHeight(ctx context.Context, from, amount uint64) ( return session.getRangeByHeight(ctx, from, amount) } +// GetVerifiedRange performs a request for the given range of ExtendedHeaders to the network and ensures +// that returned headers are correct against the passed one. +func (ex *Exchange) GetVerifiedRange( + ctx context.Context, + from *header.ExtendedHeader, + amount uint64, +) ([]*header.ExtendedHeader, error) { + session := newSession(ex.ctx, ex.host, ex.peerTracker.peers(), ex.protocolID) + defer session.close() + + headers, err := session.getRangeByHeight(ctx, uint64(from.Height)+1, amount) + if err != nil { + return nil, err + } + + for _, h := range headers { + err := from.VerifyAdjacent(h) + if err != nil { + return nil, err + } + from = h + } + return headers, nil +} + // Get performs a request for the ExtendedHeader by the given hash corresponding // to the RawHeader. Note that the ExtendedHeader must be verified thereafter. func (ex *Exchange) Get(ctx context.Context, hash tmbytes.HexBytes) (*header.ExtendedHeader, error) { diff --git a/header/p2p/exchange_test.go b/header/p2p/exchange_test.go index 8dfb8c3139..e35a36d897 100644 --- a/header/p2p/exchange_test.go +++ b/header/p2p/exchange_test.go @@ -55,6 +55,25 @@ func TestExchange_RequestHeaders(t *testing.T) { } } +func TestExchange_RequestVerifiedHeaders(t *testing.T) { + hosts := createMocknet(t, 2) + exchg, store := createP2PExAndServer(t, hosts[0], hosts[1]) + // perform expected request + h := store.headers[1] + _, err := exchg.GetVerifiedRange(context.Background(), h, 3) + require.NoError(t, err) +} + +func TestExchange_RequestVerifiedHeadersFails(t *testing.T) { + hosts := createMocknet(t, 2) + exchg, store := createP2PExAndServer(t, hosts[0], hosts[1]) + store.headers[2] = store.headers[3] + // perform expected request + h := store.headers[1] + _, err := exchg.GetVerifiedRange(context.Background(), h, 3) + require.Error(t, err) +} + // TestExchange_RequestFullRangeHeaders requests max amount of headers // to verify how session will parallelize all requests. func TestExchange_RequestFullRangeHeaders(t *testing.T) { @@ -342,6 +361,14 @@ func (m *mockStore) GetRangeByHeight(ctx context.Context, from, to uint64) ([]*h return headers, nil } +func (m *mockStore) GetVerifiedRange( + ctx context.Context, + h *header.ExtendedHeader, + to uint64, +) ([]*header.ExtendedHeader, error) { + return m.GetRangeByHeight(ctx, uint64(h.Height)+1, to) +} + func (m *mockStore) Has(context.Context, tmbytes.HexBytes) (bool, error) { return false, nil } diff --git a/header/p2p/session.go b/header/p2p/session.go index 3f64f71173..74889c9d1c 100644 --- a/header/p2p/session.go +++ b/header/p2p/session.go @@ -210,7 +210,9 @@ func (s *session) processResponse(responses []*p2p_pb.ExtendedHeaderResponse) ([ } headers = append(headers, header) } - + if len(headers) == 0 { + return nil, header.ErrNotFound + } return headers, nil } diff --git a/header/store/store.go b/header/store/store.go index 9587fa79f7..42be50e955 100644 --- a/header/store/store.go +++ b/header/store/store.go @@ -245,6 +245,29 @@ func (s *Store) GetRangeByHeight(ctx context.Context, from, to uint64) ([]*heade return headers, nil } +func (s *Store) GetVerifiedRange( + ctx context.Context, + from *header.ExtendedHeader, + to uint64, +) ([]*header.ExtendedHeader, error) { + if uint64(from.Height) >= to { + return nil, fmt.Errorf("header/store: invalid range(%d,%d)", from.Height, to) + } + headers, err := s.GetRangeByHeight(ctx, uint64(from.Height)+1, to) + if err != nil { + return nil, err + } + + for _, h := range headers { + err := from.VerifyAdjacent(h) + if err != nil { + return nil, err + } + from = h + } + return headers, nil +} + func (s *Store) Has(ctx context.Context, hash tmbytes.HexBytes) (bool, error) { if ok := s.cache.Contains(hash.String()); ok { return ok, nil diff --git a/header/sync/sync_test.go b/header/sync/sync_test.go index 4d65a576da..4324925f56 100644 --- a/header/sync/sync_test.go +++ b/header/sync/sync_test.go @@ -212,6 +212,14 @@ func (e *exchangeCountingHead) GetByHeight(ctx context.Context, u uint64) (*head panic("implement me") } -func (e *exchangeCountingHead) GetRangeByHeight(c context.Context, from, to uint64) ([]*header.ExtendedHeader, error) { +func (e *exchangeCountingHead) GetRangeByHeight( + c context.Context, + from, amount uint64, +) ([]*header.ExtendedHeader, error) { + panic("implement me") +} + +func (e *exchangeCountingHead) GetVerifiedRange(c context.Context, from *header.ExtendedHeader, amount uint64, +) ([]*header.ExtendedHeader, error) { panic("implement me") }