Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(traversal): add BFS order traversal #672

Merged
merged 1 commit into from
Jun 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions hathor/consensus/block_consensus.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,8 +432,8 @@ def remove_first_block_markers(self, block: Block) -> None:
assert block.storage is not None
storage = block.storage

from hathor.transaction.storage.traversal import BFSWalk
bfs = BFSWalk(storage, is_dag_verifications=True, is_left_to_right=False)
from hathor.transaction.storage.traversal import BFSTimestampWalk
bfs = BFSTimestampWalk(storage, is_dag_verifications=True, is_left_to_right=False)
for tx in bfs.run(block, skip_root=True):
if tx.is_block:
bfs.skip_neighbors(tx)
Expand Down Expand Up @@ -470,8 +470,8 @@ def _score_block_dfs(self, block: BaseTransaction, used: set[bytes],
score = sum_weights(score, x)

else:
from hathor.transaction.storage.traversal import BFSWalk
bfs = BFSWalk(storage, is_dag_verifications=True, is_left_to_right=False)
from hathor.transaction.storage.traversal import BFSTimestampWalk
bfs = BFSTimestampWalk(storage, is_dag_verifications=True, is_left_to_right=False)
for tx in bfs.run(parent, skip_root=False):
assert tx.hash is not None
assert not tx.is_block
Expand Down
9 changes: 5 additions & 4 deletions hathor/consensus/transaction_consensus.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ def remove_voided_by(self, tx: Transaction, voided_hash: bytes) -> bool:
""" Remove a hash from `meta.voided_by` and its descendants (both from verification DAG
and funds tree).
"""
from hathor.transaction.storage.traversal import BFSWalk
from hathor.transaction.storage.traversal import BFSTimestampWalk

assert tx.hash is not None
assert tx.storage is not None
Expand All @@ -347,7 +347,7 @@ def remove_voided_by(self, tx: Transaction, voided_hash: bytes) -> bool:

self.log.debug('remove_voided_by', tx=tx.hash_hex, voided_hash=voided_hash.hex())

bfs = BFSWalk(tx.storage, is_dag_funds=True, is_dag_verifications=True, is_left_to_right=True)
bfs = BFSTimestampWalk(tx.storage, is_dag_funds=True, is_dag_verifications=True, is_left_to_right=True)
check_list: list[BaseTransaction] = []
for tx2 in bfs.run(tx, skip_root=False):
assert tx2.storage is not None
Expand Down Expand Up @@ -404,8 +404,9 @@ def add_voided_by(self, tx: Transaction, voided_hash: bytes) -> bool:
# If tx is soft voided, we can only walk through the DAG of funds.
is_dag_verifications = False

from hathor.transaction.storage.traversal import BFSWalk
bfs = BFSWalk(tx.storage, is_dag_funds=True, is_dag_verifications=is_dag_verifications, is_left_to_right=True)
from hathor.transaction.storage.traversal import BFSTimestampWalk
bfs = BFSTimestampWalk(tx.storage, is_dag_funds=True, is_dag_verifications=is_dag_verifications,
is_left_to_right=True)
check_list: list[Transaction] = []
for tx2 in bfs.run(tx, skip_root=False):
assert tx2.storage is not None
Expand Down
4 changes: 2 additions & 2 deletions hathor/indexes/mempool_tips_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,8 +185,8 @@ def iter(self, tx_storage: 'TransactionStorage', max_timestamp: Optional[float]
yield from cast(Iterator[Transaction], it)

def iter_all(self, tx_storage: 'TransactionStorage') -> Iterator[Transaction]:
from hathor.transaction.storage.traversal import BFSWalk
bfs = BFSWalk(tx_storage, is_dag_verifications=True, is_left_to_right=False)
from hathor.transaction.storage.traversal import BFSTimestampWalk
bfs = BFSTimestampWalk(tx_storage, is_dag_verifications=True, is_left_to_right=False)
for tx in bfs.run(self.iter(tx_storage), skip_root=False):
assert isinstance(tx, Transaction)
if tx.get_metadata().first_block is not None:
Expand Down
4 changes: 2 additions & 2 deletions hathor/transaction/base_transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -951,8 +951,8 @@ def update_accumulated_weight(self, *, stop_value: float = inf, save_file: bool
# reduce the number of visits in the BFS. We need to specially handle when a transaction is not
# directly verified by a block.

from hathor.transaction.storage.traversal import BFSWalk
bfs_walk = BFSWalk(self.storage, is_dag_funds=True, is_dag_verifications=True, is_left_to_right=True)
from hathor.transaction.storage.traversal import BFSTimestampWalk
bfs_walk = BFSTimestampWalk(self.storage, is_dag_funds=True, is_dag_verifications=True, is_left_to_right=True)
for tx in bfs_walk.run(self, skip_root=True):
accumulated_weight = sum_weights(accumulated_weight, tx.weight)
if accumulated_weight > stop_value:
Expand Down
4 changes: 2 additions & 2 deletions hathor/transaction/storage/transaction_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -1026,10 +1026,10 @@ def iter_mempool_from_tx_tips(self) -> Iterator[Transaction]:

This method requires indexes to be enabled.
"""
from hathor.transaction.storage.traversal import BFSWalk
from hathor.transaction.storage.traversal import BFSTimestampWalk

root = self.iter_mempool_tips_from_tx_tips()
walk = BFSWalk(self, is_dag_funds=True, is_dag_verifications=True, is_left_to_right=False)
walk = BFSTimestampWalk(self, is_dag_funds=True, is_dag_verifications=True, is_left_to_right=False)
for tx in walk.run(root):
tx_meta = tx.get_metadata()
# XXX: skip blocks and tx-tips that have already been confirmed
Expand Down
79 changes: 61 additions & 18 deletions hathor/transaction/storage/traversal.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@

import heapq
from abc import ABC, abstractmethod
from collections import deque
from itertools import chain
from typing import TYPE_CHECKING, Any, Iterable, Iterator, Optional, Union
from typing import TYPE_CHECKING, Iterable, Iterator, Optional, Union

if TYPE_CHECKING:
from hathor.transaction import BaseTransaction # noqa: F401
from hathor.transaction.storage import TransactionStorage # noqa: F401
from hathor.types import VertexId


class HeapItem:
Expand All @@ -43,8 +45,7 @@ def __le__(self, other: 'HeapItem') -> bool:
class GenericWalk(ABC):
""" A helper class to walk on the DAG.
"""
seen: set[bytes]
to_visit: list[Any]
seen: set['VertexId']

def __init__(self, storage: 'TransactionStorage', *, is_dag_funds: bool = False,
is_dag_verifications: bool = False, is_left_to_right: bool = True):
Expand All @@ -58,7 +59,6 @@ def __init__(self, storage: 'TransactionStorage', *, is_dag_funds: bool = False,
"""
self.storage = storage
self.seen = set()
self.to_visit = []

self.is_dag_funds = is_dag_funds
self.is_dag_verifications = is_dag_verifications
Expand All @@ -79,26 +79,36 @@ def _pop_visit(self) -> 'BaseTransaction':
"""
raise NotImplementedError

def add_neighbors(self, tx: 'BaseTransaction') -> None:
""" Add neighbors of `tx` to be visited later according to the configuration.
@abstractmethod
def _is_empty(self) -> bool:
""" Return true if there aren't any txs left to be visited.
"""
raise NotImplementedError

def _get_iterator(self, tx: 'BaseTransaction', *, is_left_to_right: bool) -> Iterator['VertexId']:
meta = None
it: Iterator[bytes] = chain()
it: Iterator['VertexId'] = chain()

if self.is_dag_verifications:
if self.is_left_to_right:
if is_left_to_right:
meta = meta or tx.get_metadata()
it = chain(it, meta.children)
else:
it = chain(it, tx.parents)

if self.is_dag_funds:
if self.is_left_to_right:
if is_left_to_right:
meta = meta or tx.get_metadata()
it = chain(it, *meta.spent_outputs.values())
else:
it = chain(it, [txin.tx_id for txin in tx.inputs])

return it

def add_neighbors(self, tx: 'BaseTransaction') -> None:
""" Add neighbors of `tx` to be visited later according to the configuration.
"""
it = self._get_iterator(tx, is_left_to_right=self.is_left_to_right)
for _hash in it:
if _hash not in self.seen:
self.seen.add(_hash)
Expand Down Expand Up @@ -131,7 +141,7 @@ def run(self, root: Union['BaseTransaction', Iterable['BaseTransaction']], *,
else:
self.add_neighbors(root)

while self.to_visit:
while not self._is_empty():
tx = self._pop_visit()
assert tx.hash is not None
yield tx
Expand All @@ -142,16 +152,23 @@ def run(self, root: Union['BaseTransaction', Iterable['BaseTransaction']], *,
self._ignore_neighbors = None


class BFSWalk(GenericWalk):
""" A help to walk in the DAG using a BFS.
class BFSTimestampWalk(GenericWalk):
""" A help to walk in the DAG using a BFS that prioritizes by timestamp.
"""
to_visit: list[HeapItem]
_to_visit: list[HeapItem]

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._to_visit = []

def _is_empty(self) -> bool:
return not self._to_visit

def _push_visit(self, tx: 'BaseTransaction') -> None:
heapq.heappush(self.to_visit, HeapItem(tx, reverse=self._reverse_heap))
heapq.heappush(self._to_visit, HeapItem(tx, reverse=self._reverse_heap))

def _pop_visit(self) -> 'BaseTransaction':
item = heapq.heappop(self.to_visit)
item = heapq.heappop(self._to_visit)
tx = item.tx
# We can safely remove it because we are walking in topological order
# and it won't appear again in the future because this would be a cycle.
Expand All @@ -160,13 +177,39 @@ def _pop_visit(self) -> 'BaseTransaction':
return tx


class BFSOrderWalk(GenericWalk):
""" A help to walk in the DAG using a BFS.
"""
_to_visit: deque['BaseTransaction']

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._to_visit = deque()

def _is_empty(self) -> bool:
return not self._to_visit

def _push_visit(self, tx: 'BaseTransaction') -> None:
self._to_visit.append(tx)

def _pop_visit(self) -> 'BaseTransaction':
return self._to_visit.popleft()


class DFSWalk(GenericWalk):
""" A help to walk in the DAG using a DFS.
"""
to_visit: list['BaseTransaction']
_to_visit: list['BaseTransaction']

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._to_visit = []

def _is_empty(self) -> bool:
return not self._to_visit
glevco marked this conversation as resolved.
Show resolved Hide resolved

def _push_visit(self, tx: 'BaseTransaction') -> None:
self.to_visit.append(tx)
self._to_visit.append(tx)

def _pop_visit(self) -> 'BaseTransaction':
return self.to_visit.pop()
return self._to_visit.pop()
55 changes: 49 additions & 6 deletions tests/tx/test_traversal.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from math import inf

from hathor.transaction.storage.traversal import BFSWalk, DFSWalk
from hathor.transaction.storage.traversal import BFSOrderWalk, BFSTimestampWalk, DFSWalk
from tests import unittest
from tests.utils import add_blocks_unlock_reward, add_new_blocks, add_new_transactions, add_new_tx

Expand Down Expand Up @@ -86,9 +86,9 @@ def test_right_to_left(self):
self.assertTrue(seen_v.union(seen_f).issubset(seen_vf))


class BaseBFSWalkTestCase(_TraversalTestCase):
class BaseBFSTimestampWalkTestCase(_TraversalTestCase):
def gen_walk(self, **kwargs):
return BFSWalk(self.manager.tx_storage, **kwargs)
return BFSTimestampWalk(self.manager.tx_storage, **kwargs)

def _run_lr(self, walk, skip_root=True):
seen = set()
Expand All @@ -109,16 +109,59 @@ def _run_rl(self, walk):
return seen


class SyncV1BFSWalkTestCase(unittest.SyncV1Params, BaseBFSWalkTestCase):
class SyncV1BFSTimestampWalkTestCase(unittest.SyncV1Params, BaseBFSTimestampWalkTestCase):
__test__ = True


class SyncV2BFSWalkTestCase(unittest.SyncV2Params, BaseBFSWalkTestCase):
class SyncV2BFSTimestampWalkTestCase(unittest.SyncV2Params, BaseBFSTimestampWalkTestCase):
__test__ = True


class BaseBFSOrderWalkTestCase(_TraversalTestCase):
def gen_walk(self, **kwargs):
return BFSOrderWalk(self.manager.tx_storage, **kwargs)

def _run_lr(self, walk, skip_root=True):
seen = set()
distance = {}
distance[self.root_tx.hash] = 0
last_dist = 0
for tx in walk.run(self.root_tx, skip_root=skip_root):
seen.add(tx.hash)
it = walk._get_iterator(tx, is_left_to_right=False)
dist = 1 + min(distance.get(_hash, inf) for _hash in it)
self.assertIsInstance(dist, int)
distance[tx.hash] = dist
self.assertGreaterEqual(dist, last_dist)
last_dist = dist
return seen

def _run_rl(self, walk):
seen = set()
distance = {}
distance[self.root_tx.hash] = 0
last_dist = 0
for tx in walk.run(self.root_tx, skip_root=True):
seen.add(tx.hash)
it = walk._get_iterator(tx, is_left_to_right=True)
dist = 1 + min(distance.get(_hash, inf) for _hash in it)
self.assertIsInstance(dist, int)
distance[tx.hash] = dist
self.assertGreaterEqual(dist, last_dist)
last_dist = dist
return seen


class SyncV1BFSOrderWalkTestCase(unittest.SyncV1Params, BaseBFSOrderWalkTestCase):
__test__ = True


class SyncV2BFSOrderWalkTestCase(unittest.SyncV2Params, BaseBFSOrderWalkTestCase):
__test__ = True


# sync-bridge should behave like sync-v2
class SyncBridgeBFSWalkTestCase(unittest.SyncBridgeParams, SyncV2BFSWalkTestCase):
class SyncBridgeBFSOrderWalkTestCase(unittest.SyncBridgeParams, SyncV2BFSOrderWalkTestCase):
pass


Expand Down