refactor(mypy): add stricter rules to unittest and utils
glevco committed Mar 22, 2024
1 parent 466550d commit 9159491
Showing 5 changed files with 128 additions and 57 deletions.
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -165,6 +165,8 @@ module = [
     "tests.p2p.*",
     "tests.pubsub.*",
     "tests.simulation.*",
+    "tests.unittest",
+    "tests.utils",
 ]
 disallow_untyped_defs = true

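For reference, disallow_untyped_defs = true makes mypy reject any function definition in the listed modules that is missing parameter or return annotations; that is what drives the signature changes in the files below. A minimal sketch of the flag's effect (function names are illustrative, not from this commit):

def scale(value, factor=2):  # error under disallow_untyped_defs: function is missing a type annotation
    return value * factor

def scale_typed(value: int, factor: int = 2) -> int:  # accepted: fully annotated
    return value * factor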
3 changes: 2 additions & 1 deletion tests/p2p/test_double_spending.py
@@ -4,6 +4,7 @@
 from hathor.manager import HathorManager
 from hathor.simulator.utils import add_new_blocks
 from hathor.transaction import Transaction
+from hathor.util import not_none
 from tests import unittest
 from tests.utils import add_blocks_unlock_reward, add_new_tx

@@ -23,7 +24,7 @@ def setUp(self) -> None:
     def _add_new_transactions(self, manager: HathorManager, num_txs: int) -> list[Transaction]:
         txs = []
         for _ in range(num_txs):
-            address = self.get_address(0)
+            address = not_none(self.get_address(0))
             value = self.rng.choice([5, 10, 15, 20])
             tx = add_new_tx(manager, address, value)
             txs.append(tx)
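The not_none call wrapped around get_address above is what satisfies mypy here: get_address evidently returns an Optional value, while add_new_tx wants a plain one. A sketch of the idea behind hathor.util.not_none, assuming a typical implementation (the real one may differ):

from typing import Optional, TypeVar

T = TypeVar('T')

def not_none(value: Optional[T]) -> T:
    # Fail fast if the value is missing; on success mypy narrows Optional[T] to T.
    assert value is not None
    return value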
2 changes: 1 addition & 1 deletion tests/tx/test_indexes2.py
@@ -64,7 +64,7 @@ def test_timestamp_index(self):
             # XXX: we verified they're the same, doesn't matter which we pick:
             idx = idx_memory
             hashes = hashes_memory
-            self.log.debug('indexes match', idx=idx, hashes=unittest.shorten_hash(hashes))
+            self.log.debug('indexes match', idx=idx, hashes=unittest.short_hashes(hashes))
             if idx is None:
                 break
             offset_variety.add(idx[1])
112 changes: 70 additions & 42 deletions tests/unittest.py
@@ -3,7 +3,7 @@
 import shutil
 import tempfile
 import time
-from typing import Iterator, Optional
+from typing import Callable, Collection, Iterable, Iterator, Optional
 from unittest import main as ut_main
 
 from structlog import get_logger
@@ -16,13 +16,17 @@
 from hathor.daa import DifficultyAdjustmentAlgorithm, TestMode
 from hathor.event import EventManager
 from hathor.event.storage import EventStorage
 from hathor.manager import HathorManager
 from hathor.p2p.peer_id import PeerId
+from hathor.p2p.sync_v1.agent import NodeSyncTimestamp
+from hathor.p2p.sync_v2.agent import NodeBlockSync
 from hathor.p2p.sync_version import SyncVersion
 from hathor.pubsub import PubSubManager
 from hathor.reactor import ReactorProtocol as Reactor, get_global_reactor
 from hathor.simulator.clock import MemoryReactorHeapClock
-from hathor.transaction import BaseTransaction
+from hathor.transaction import BaseTransaction, Block, Transaction
+from hathor.transaction.storage.transaction_storage import TransactionStorage
+from hathor.types import VertexId
 from hathor.util import Random, not_none
 from hathor.wallet import BaseWallet, HDWallet, Wallet
 from tests.test_memory_reactor_clock import TestMemoryReactorClock
@@ -33,9 +37,8 @@
 USE_MEMORY_STORAGE = os.environ.get('HATHOR_TEST_MEMORY_STORAGE', 'false').lower() == 'true'
 
 
-def shorten_hash(container):
-    container_type = type(container)
-    return container_type(h[-2:].hex() for h in container)
+def short_hashes(container: Collection[bytes]) -> Iterable[str]:
+    return map(lambda hash_bytes: hash_bytes[-2:].hex(), container)
 
 
 def _load_peer_id_pool(file_path: Optional[str] = None) -> Iterator[PeerId]:
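A quick usage sketch of the new short_hashes helper (hash values invented for illustration). It renders each hash as its last two bytes in hex for compact log output, and now returns a lazy iterable of strings instead of rebuilding the input container type as shorten_hash did:

hashes = [bytes.fromhex('00' * 30 + 'abcd'), bytes.fromhex('00' * 30 + '1234')]
print(list(short_hashes(hashes)))  # ['abcd', '1234']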
@@ -50,7 +53,7 @@ def _load_peer_id_pool(file_path: Optional[str] = None) -> Iterator[PeerId]:
         yield PeerId.create_from_json(peer_id_dict)
 
 
-def _get_default_peer_id_pool_filepath():
+def _get_default_peer_id_pool_filepath() -> str:
     this_file_path = os.path.dirname(__file__)
     file_name = 'peer_id_pool.json'
     file_path = os.path.join(this_file_path, file_name)
@@ -109,19 +112,19 @@ class TestCase(unittest.TestCase):
     use_memory_storage: bool = USE_MEMORY_STORAGE
     seed_config: Optional[int] = None
 
-    def setUp(self):
-        self.tmpdirs = []
+    def setUp(self) -> None:
+        self.tmpdirs: list[str] = []
         self.clock = TestMemoryReactorClock()
         self.clock.advance(time.time())
         self.log = logger.new()
         self.reset_peer_id_pool()
         self.seed = secrets.randbits(64) if self.seed_config is None else self.seed_config
         self.log.info('set seed', seed=self.seed)
         self.rng = Random(self.seed)
-        self._pending_cleanups = []
+        self._pending_cleanups: list[Callable] = []
         self._settings = get_global_settings()
 
-    def tearDown(self):
+    def tearDown(self) -> None:
         self.clean_tmpdirs()
         for fn in self._pending_cleanups:
             fn()
@@ -144,12 +147,12 @@ def get_random_peer_id_from_pool(self, pool: Optional[list[PeerId]] = None,
         pool.remove(peer_id)
         return peer_id
 
-    def mkdtemp(self):
+    def mkdtemp(self) -> str:
         tmpdir = tempfile.mkdtemp()
         self.tmpdirs.append(tmpdir)
         return tmpdir
 
-    def _create_test_wallet(self, unlocked=False):
+    def _create_test_wallet(self, unlocked: bool = False) -> Wallet:
         """ Generate a Wallet with a number of keypairs for testing
         :rtype: Wallet
         """
@@ -169,14 +172,14 @@ def get_builder(self, network: str) -> TestBuilder:
             .set_network(network)
         return builder
 
-    def create_peer_from_builder(self, builder, start_manager=True):
+    def create_peer_from_builder(self, builder: Builder, start_manager: bool = True) -> HathorManager:
         artifacts = builder.build()
         manager = artifacts.manager
 
         if artifacts.rocksdb_storage:
             self._pending_cleanups.append(artifacts.rocksdb_storage.close)
 
-        manager.avg_time_between_blocks = 0.0001
+        # manager.avg_time_between_blocks = 0.0001  # FIXME: This property is not defined. Fix this.
 
         if start_manager:
             manager.start()
@@ -277,7 +280,7 @@ def create_peer(  # type: ignore[no-untyped-def]
 
         return manager
 
-    def run_to_completion(self):
+    def run_to_completion(self) -> None:
         """ This will advance the test's clock until all calls scheduled are done.
         """
         for call in self.clock.getDelayedCalls():
@@ -300,7 +303,11 @@ def assertIsTopological(self, tx_sequence: Iterator[BaseTransaction], message: O
                 self.assertIn(dep, valid_deps, message)
             valid_deps.add(tx.hash)
 
-    def _syncVersionFlags(self, enable_sync_v1=None, enable_sync_v2=None):
+    def _syncVersionFlags(
+        self,
+        enable_sync_v1: bool | None = None,
+        enable_sync_v2: bool | None = None
+    ) -> tuple[bool, bool]:
         """Internal: use this to check and get the flags and optionally provide override values."""
         if enable_sync_v1 is None:
             assert hasattr(self, '_enable_sync_v1'), ('`_enable_sync_v1` has no default by design, either set one on '
@@ -313,19 +320,19 @@ def _syncVersionFlags(self, enable_sync_v1=None, enable_sync_v2=None):
         assert enable_sync_v1 or enable_sync_v2, 'enable at least one sync version'
         return enable_sync_v1, enable_sync_v2
 
-    def assertTipsEqual(self, manager1, manager2):
+    def assertTipsEqual(self, manager1: HathorManager, manager2: HathorManager) -> None:
         _, enable_sync_v2 = self._syncVersionFlags()
         if enable_sync_v2:
             self.assertTipsEqualSyncV2(manager1, manager2)
         else:
             self.assertTipsEqualSyncV1(manager1, manager2)
 
-    def assertTipsNotEqual(self, manager1, manager2):
+    def assertTipsNotEqual(self, manager1: HathorManager, manager2: HathorManager) -> None:
         s1 = set(manager1.tx_storage.get_all_tips())
         s2 = set(manager2.tx_storage.get_all_tips())
         self.assertNotEqual(s1, s2)
 
-    def assertTipsEqualSyncV1(self, manager1, manager2):
+    def assertTipsEqualSyncV1(self, manager1: HathorManager, manager2: HathorManager) -> None:
         # XXX: this is the original implementation of assertTipsEqual
         s1 = set(manager1.tx_storage.get_all_tips())
         s2 = set(manager2.tx_storage.get_all_tips())
@@ -335,39 +342,45 @@ def assertTipsEqualSyncV1(self, manager1, manager2):
         s2 = set(manager2.tx_storage.get_tx_tips())
         self.assertEqual(s1, s2)
 
-    def assertTipsEqualSyncV2(self, manager1, manager2, *, strict_sync_v2_indexes=True):
+    def assertTipsEqualSyncV2(
+        self,
+        manager1: HathorManager,
+        manager2: HathorManager,
+        *,
+        strict_sync_v2_indexes: bool = True
+    ) -> None:
         # tx tips
         if strict_sync_v2_indexes:
-            tips1 = manager1.tx_storage.indexes.mempool_tips.get()
-            tips2 = manager2.tx_storage.indexes.mempool_tips.get()
+            tips1 = not_none(not_none(manager1.tx_storage.indexes).mempool_tips).get()
+            tips2 = not_none(not_none(manager2.tx_storage.indexes).mempool_tips).get()
         else:
             tips1 = {tx.hash for tx in manager1.tx_storage.iter_mempool_tips_from_best_index()}
            tips2 = {tx.hash for tx in manager2.tx_storage.iter_mempool_tips_from_best_index()}
-        self.log.debug('tx tips1', len=len(tips1), list=shorten_hash(tips1))
-        self.log.debug('tx tips2', len=len(tips2), list=shorten_hash(tips2))
+        self.log.debug('tx tips1', len=len(tips1), list=short_hashes(tips1))
+        self.log.debug('tx tips2', len=len(tips2), list=short_hashes(tips2))
         self.assertEqual(tips1, tips2)
 
         # best block
         s1 = set(manager1.tx_storage.get_best_block_tips())
         s2 = set(manager2.tx_storage.get_best_block_tips())
-        self.log.debug('block tips1', len=len(s1), list=shorten_hash(s1))
-        self.log.debug('block tips2', len=len(s2), list=shorten_hash(s2))
+        self.log.debug('block tips1', len=len(s1), list=short_hashes(s1))
+        self.log.debug('block tips2', len=len(s2), list=short_hashes(s2))
         self.assertEqual(s1, s2)
 
         # best block (from height index)
-        b1 = manager1.tx_storage.indexes.height.get_tip()
-        b2 = manager2.tx_storage.indexes.height.get_tip()
+        b1 = not_none(manager1.tx_storage.indexes).height.get_tip()
+        b2 = not_none(manager2.tx_storage.indexes).height.get_tip()
         self.assertIn(b1, s2)
         self.assertIn(b2, s1)
 
-    def assertConsensusEqual(self, manager1, manager2):
+    def assertConsensusEqual(self, manager1: HathorManager, manager2: HathorManager) -> None:
         _, enable_sync_v2 = self._syncVersionFlags()
         if enable_sync_v2:
             self.assertConsensusEqualSyncV2(manager1, manager2)
         else:
             self.assertConsensusEqualSyncV1(manager1, manager2)
 
-    def assertConsensusEqualSyncV1(self, manager1, manager2):
+    def assertConsensusEqualSyncV1(self, manager1: HathorManager, manager2: HathorManager) -> None:
         self.assertEqual(manager1.tx_storage.get_vertices_count(), manager2.tx_storage.get_vertices_count())
         for tx1 in manager1.tx_storage.get_all_transactions():
             tx2 = manager2.tx_storage.get_transaction(tx1.hash)
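A note on the chained not_none calls in assertTipsEqualSyncV2 above: tx_storage.indexes is Optional, and so is its mempool_tips attribute, so each call strips one layer of None before .get() is reached. Unrolled, with illustrative intermediate names (the exact index types are assumptions, not taken from this diff):

indexes = not_none(manager1.tx_storage.indexes)  # narrows the Optional indexes manager
mempool_tips = not_none(indexes.mempool_tips)    # narrows the Optional mempool tips index
tips1 = mempool_tips.get()                       # statically safe after both checks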
@@ -381,12 +394,20 @@ def assertConsensusEqualSyncV1(self, manager1, manager2):
                 self.assertIsNone(tx2_meta.voided_by)
             else:
                 # If tx1 is voided, then tx2 must be voided.
+                assert tx1_meta.voided_by is not None
+                assert tx2_meta.voided_by is not None
                 self.assertGreaterEqual(len(tx1_meta.voided_by), 1)
                 self.assertGreaterEqual(len(tx2_meta.voided_by), 1)
                 # Hard verification
                 # self.assertEqual(tx1_meta.voided_by, tx2_meta.voided_by)
 
-    def assertConsensusEqualSyncV2(self, manager1, manager2, *, strict_sync_v2_indexes=True):
+    def assertConsensusEqualSyncV2(
+        self,
+        manager1: HathorManager,
+        manager2: HathorManager,
+        *,
+        strict_sync_v2_indexes: bool = True
+    ) -> None:
         # The current sync algorithm does not propagate voided blocks/txs
         # so the count might be different even though the consensus is equal
         # One peer might have voided txs that the other does not have
@@ -397,7 +418,9 @@ def assertConsensusEqualSyncV2(self, manager1, manager2, *, strict_sync_v2_index
         # the following is specific to sync-v2
 
         # helper function:
-        def get_all_executed_or_voided(tx_storage):
+        def get_all_executed_or_voided(
+            tx_storage: TransactionStorage
+        ) -> tuple[set[VertexId], set[VertexId], set[VertexId]]:
             """Get all txs separated into three sets: executed, voided, partial"""
             tx_executed = set()
             tx_voided = set()
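The rest of get_all_executed_or_voided is folded out of this view. A plausible completion of the loop, assuming the usual hathor-core metadata API (validation.is_fully_connected and voided_by; not shown in this diff), would classify each vertex like this:

for tx in tx_storage.get_all_transactions():
    tx_meta = tx.get_metadata()
    if not tx_meta.validation.is_fully_connected():
        tx_partial.add(tx.hash)    # still only partially validated
    elif tx_meta.voided_by:
        tx_voided.add(tx.hash)     # voided by some other vertex
    else:
        tx_executed.add(tx.hash)   # executed, i.e. accepted by consensus
return tx_executed, tx_voided, tx_partial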
@@ -424,14 +447,16 @@ def get_all_executed_or_voided(tx_storage):
         self.log.debug('node1 rest', len_voided=len(tx_voided1), len_partial=len(tx_partial1))
         self.log.debug('node2 rest', len_voided=len(tx_voided2), len_partial=len(tx_partial2))
 
-    def assertConsensusValid(self, manager):
+    def assertConsensusValid(self, manager: HathorManager) -> None:
         for tx in manager.tx_storage.get_all_transactions():
             if tx.is_block:
+                assert isinstance(tx, Block)
                 self.assertBlockConsensusValid(tx)
             else:
+                assert isinstance(tx, Transaction)
                 self.assertTransactionConsensusValid(tx)
 
-    def assertBlockConsensusValid(self, block):
+    def assertBlockConsensusValid(self, block: Block) -> None:
         self.assertTrue(block.is_block)
         if not block.parents:
             # Genesis
@@ -442,7 +467,8 @@ def assertBlockConsensusValid(self, block):
             parent_meta = parent.get_metadata()
             self.assertIsNone(parent_meta.voided_by)
 
-    def assertTransactionConsensusValid(self, tx):
+    def assertTransactionConsensusValid(self, tx: Transaction) -> None:
+        assert tx.storage is not None
         self.assertFalse(tx.is_block)
         meta = tx.get_metadata()
         if meta.voided_by and tx.hash in meta.voided_by:
@@ -462,38 +488,40 @@ def assertTransactionConsensusValid(self, tx):
             spent_meta = spent_tx.get_metadata()
 
             if spent_meta.voided_by is not None:
-                self.assertIsNotNone(meta.voided_by)
+                assert meta.voided_by is not None
                 self.assertTrue(spent_meta.voided_by)
                 self.assertTrue(meta.voided_by)
                 self.assertTrue(spent_meta.voided_by.issubset(meta.voided_by))
 
         for parent in tx.get_parents():
             parent_meta = parent.get_metadata()
             if parent_meta.voided_by is not None:
-                self.assertIsNotNone(meta.voided_by)
+                assert meta.voided_by is not None
                 self.assertTrue(parent_meta.voided_by)
                 self.assertTrue(meta.voided_by)
                 self.assertTrue(parent_meta.voided_by.issubset(meta.voided_by))
 
-    def assertSyncedProgress(self, node_sync):
+    def assertSyncedProgress(self, node_sync: NodeSyncTimestamp | NodeBlockSync) -> None:
         """Check "synced" status of p2p-manager, uses self._enable_sync_vX to choose which check to run."""
         enable_sync_v1, enable_sync_v2 = self._syncVersionFlags()
         if enable_sync_v2:
+            assert isinstance(node_sync, NodeBlockSync)
             self.assertV2SyncedProgress(node_sync)
         elif enable_sync_v1:
+            assert isinstance(node_sync, NodeSyncTimestamp)
             self.assertV1SyncedProgress(node_sync)
 
-    def assertV1SyncedProgress(self, node_sync):
+    def assertV1SyncedProgress(self, node_sync: NodeSyncTimestamp) -> None:
         self.assertEqual(node_sync.synced_timestamp, node_sync.peer_timestamp)
 
-    def assertV2SyncedProgress(self, node_sync):
+    def assertV2SyncedProgress(self, node_sync: NodeBlockSync) -> None:
         self.assertEqual(node_sync.synced_block, node_sync.peer_best_block)
 
-    def clean_tmpdirs(self):
+    def clean_tmpdirs(self) -> None:
         for tmpdir in self.tmpdirs:
             shutil.rmtree(tmpdir)
 
-    def clean_pending(self, required_to_quiesce=True):
+    def clean_pending(self, required_to_quiesce: bool = True) -> None:
         """
         This handy method cleans all pending tasks from the reactor.
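To make the subset checks in assertTransactionConsensusValid concrete: if a spent transaction or a parent is voided, the spending/child transaction must itself be voided by at least the same set of vertices. A tiny worked example with invented hashes:

spent_voided_by = {b'tx_a'}          # the spent tx was voided by tx_a
meta_voided_by = {b'tx_a', b'tx_b'}  # the spender carries tx_a's mark plus its own
assert spent_voided_by.issubset(meta_voided_by)  # the invariant the assertions enforce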
