Skip to content

Commit

Permalink
refactor(mypy): add stricter rules to p2p tests
Browse files Browse the repository at this point in the history
  • Loading branch information
glevco committed Mar 19, 2024
1 parent 6e1595f commit ef291c0
Show file tree
Hide file tree
Showing 26 changed files with 348 additions and 313 deletions.
7 changes: 4 additions & 3 deletions hathor/builder/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
TransactionRocksDBStorage,
TransactionStorage,
)
from hathor.transaction.storage.transaction_storage import BaseTransactionStorage
from hathor.util import Random, get_environment_info, not_none
from hathor.verification.verification_service import VerificationService
from hathor.verification.vertex_verifiers import VertexVerifiers
Expand Down Expand Up @@ -131,7 +132,7 @@ def __init__(self) -> None:
self._tx_storage_cache_capacity: Optional[int] = None

self._indexes_manager: Optional[IndexesManager] = None
self._tx_storage: Optional[TransactionStorage] = None
self._tx_storage: Optional[BaseTransactionStorage] = None
self._event_storage: Optional[EventStorage] = None

self._reactor: Optional[Reactor] = None
Expand Down Expand Up @@ -393,7 +394,7 @@ def _get_or_create_indexes_manager(self) -> IndexesManager:

return self._indexes_manager

def _get_or_create_tx_storage(self) -> TransactionStorage:
def _get_or_create_tx_storage(self) -> BaseTransactionStorage:
indexes = self._get_or_create_indexes_manager()

if self._tx_storage is not None:
Expand Down Expand Up @@ -616,7 +617,7 @@ def enable_event_queue(self) -> 'Builder':
self._enable_event_queue = True
return self

def set_tx_storage(self, tx_storage: TransactionStorage) -> 'Builder':
def set_tx_storage(self, tx_storage: BaseTransactionStorage) -> 'Builder':
self.check_if_can_modify()
self._tx_storage = tx_storage
return self
Expand Down
4 changes: 2 additions & 2 deletions hathor/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@
from hathor.stratum import StratumFactory
from hathor.transaction import BaseTransaction, Block, MergeMinedBlock, Transaction, TxVersion, sum_weights
from hathor.transaction.exceptions import TxValidationError
from hathor.transaction.storage import TransactionStorage
from hathor.transaction.storage.exceptions import TransactionDoesNotExist
from hathor.transaction.storage.transaction_storage import BaseTransactionStorage
from hathor.transaction.storage.tx_allow_scope import TxAllowScope
from hathor.types import Address, VertexId
from hathor.util import EnvironmentInfo, LogDuration, Random, calculate_min_significant_weight, not_none
Expand Down Expand Up @@ -97,7 +97,7 @@ def __init__(self,
consensus_algorithm: ConsensusAlgorithm,
daa: DifficultyAdjustmentAlgorithm,
peer_id: PeerId,
tx_storage: TransactionStorage,
tx_storage: BaseTransactionStorage,
p2p_manager: ConnectionsManager,
event_manager: EventManager,
feature_service: FeatureService,
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ module = [
"tests.event.*",
"tests.execution_manager.*",
"tests.feature_activation.*",
# "tests.p2p.*",
"tests.p2p.*",
"tests.pubsub.*",
"tests.simulation.*",
]
Expand Down
6 changes: 4 additions & 2 deletions tests/p2p/netfilter/test_factory.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from unittest.mock import Mock

from twisted.internet.address import IPv4Address

from hathor.p2p.netfilter import get_table
Expand All @@ -10,7 +12,7 @@


class NetfilterFactoryTest(unittest.TestCase):
def test_factory(self):
def test_factory(self) -> None:
pre_conn = get_table('filter').get_chain('pre_conn')

match = NetfilterMatchIPAddress('192.168.0.1/32')
Expand All @@ -20,7 +22,7 @@ def test_factory(self):
builder = TestBuilder()
artifacts = builder.build()
wrapped_factory = artifacts.p2p_manager.server_factory
factory = NetfilterFactory(connections=None, wrappedFactory=wrapped_factory)
factory = NetfilterFactory(connections=Mock(), wrappedFactory=wrapped_factory)

ret = factory.buildProtocol(IPv4Address('TCP', '192.168.0.1', 1234))
self.assertIsNone(ret)
Expand Down
50 changes: 25 additions & 25 deletions tests/p2p/netfilter/test_match.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def match(self, context: 'NetfilterContext') -> bool:


class NetfilterMatchTest(unittest.TestCase):
def test_match_all(self):
def test_match_all(self) -> None:
matcher = NetfilterMatchAll()
context = NetfilterContext()
self.assertTrue(matcher.match(context))
Expand All @@ -31,7 +31,7 @@ def test_match_all(self):
json = matcher.to_json()
self.assertEqual(json['type'], 'NetfilterMatchAll')

def test_never_match(self):
def test_never_match(self) -> None:
matcher = NetfilterNeverMatch()
context = NetfilterContext()
self.assertFalse(matcher.match(context))
Expand All @@ -40,14 +40,14 @@ def test_never_match(self):
json = matcher.to_json()
self.assertEqual(json['type'], 'NetfilterNeverMatch')

def test_match_and_success(self):
def test_match_and_success(self) -> None:
m1 = NetfilterMatchAll()
m2 = NetfilterMatchAll()
matcher = NetfilterMatchAnd(m1, m2)
context = NetfilterContext()
self.assertTrue(matcher.match(context))

def test_match_and_fail_01(self):
def test_match_and_fail_01(self) -> None:
m1 = NetfilterNeverMatch()
m2 = NetfilterMatchAll()
matcher = NetfilterMatchAnd(m1, m2)
Expand All @@ -60,28 +60,28 @@ def test_match_and_fail_01(self):
self.assertEqual(json['match_params']['a']['type'], 'NetfilterNeverMatch')
self.assertEqual(json['match_params']['b']['type'], 'NetfilterMatchAll')

def test_match_and_fail_10(self):
def test_match_and_fail_10(self) -> None:
m1 = NetfilterMatchAll()
m2 = NetfilterNeverMatch()
matcher = NetfilterMatchAnd(m1, m2)
context = NetfilterContext()
self.assertFalse(matcher.match(context))

def test_match_and_fail_00(self):
def test_match_and_fail_00(self) -> None:
m1 = NetfilterNeverMatch()
m2 = NetfilterNeverMatch()
matcher = NetfilterMatchAnd(m1, m2)
context = NetfilterContext()
self.assertFalse(matcher.match(context))

def test_match_or_success_11(self):
def test_match_or_success_11(self) -> None:
m1 = NetfilterMatchAll()
m2 = NetfilterMatchAll()
matcher = NetfilterMatchOr(m1, m2)
context = NetfilterContext()
self.assertTrue(matcher.match(context))

def test_match_or_success_10(self):
def test_match_or_success_10(self) -> None:
m1 = NetfilterMatchAll()
m2 = NetfilterNeverMatch()
matcher = NetfilterMatchOr(m1, m2)
Expand All @@ -94,21 +94,21 @@ def test_match_or_success_10(self):
self.assertEqual(json['match_params']['a']['type'], 'NetfilterMatchAll')
self.assertEqual(json['match_params']['b']['type'], 'NetfilterNeverMatch')

def test_match_or_success_01(self):
def test_match_or_success_01(self) -> None:
m1 = NetfilterNeverMatch()
m2 = NetfilterMatchAll()
matcher = NetfilterMatchOr(m1, m2)
context = NetfilterContext()
self.assertTrue(matcher.match(context))

def test_match_or_fail_00(self):
def test_match_or_fail_00(self) -> None:
m1 = NetfilterNeverMatch()
m2 = NetfilterNeverMatch()
matcher = NetfilterMatchOr(m1, m2)
context = NetfilterContext()
self.assertFalse(matcher.match(context))

def test_match_ip_address_empty_context(self):
def test_match_ip_address_empty_context(self) -> None:
matcher = NetfilterMatchIPAddress('192.168.0.0/24')
context = NetfilterContext()
self.assertFalse(matcher.match(context))
Expand All @@ -118,7 +118,7 @@ def test_match_ip_address_empty_context(self):
self.assertEqual(json['type'], 'NetfilterMatchIPAddress')
self.assertEqual(json['match_params']['host'], '192.168.0.0/24')

def test_match_ip_address_ipv4_net(self):
def test_match_ip_address_ipv4_net(self) -> None:
matcher = NetfilterMatchIPAddress('192.168.0.0/24')
context = NetfilterContext(addr=IPv4Address('TCP', '192.168.0.10', 1234))
self.assertTrue(matcher.match(context))
Expand All @@ -129,7 +129,7 @@ def test_match_ip_address_ipv4_net(self):
context = NetfilterContext(addr=IPv4Address('TCP', '', 1234))
self.assertFalse(matcher.match(context))

def test_match_ip_address_ipv4_ip(self):
def test_match_ip_address_ipv4_ip(self) -> None:
matcher = NetfilterMatchIPAddress('192.168.0.1/32')
context = NetfilterContext(addr=IPv4Address('TCP', '192.168.0.1', 1234))
self.assertTrue(matcher.match(context))
Expand All @@ -138,24 +138,24 @@ def test_match_ip_address_ipv4_ip(self):
context = NetfilterContext(addr=IPv4Address('TCP', '', 1234))
self.assertFalse(matcher.match(context))

def test_match_ip_address_ipv4_hostname(self):
def test_match_ip_address_ipv4_hostname(self) -> None:
matcher = NetfilterMatchIPAddress('192.168.0.1/32')
context = NetfilterContext(addr=HostnameAddress('hathor.network', 80))
context = NetfilterContext(addr=HostnameAddress(b'hathor.network', 80))
self.assertFalse(matcher.match(context))

def test_match_ip_address_ipv4_unix(self):
def test_match_ip_address_ipv4_unix(self) -> None:
matcher = NetfilterMatchIPAddress('192.168.0.1/32')
context = NetfilterContext(addr=UNIXAddress('/unix.sock'))
self.assertFalse(matcher.match(context))

def test_match_ip_address_ipv4_ipv6(self):
def test_match_ip_address_ipv4_ipv6(self) -> None:
matcher = NetfilterMatchIPAddress('192.168.0.1/32')
context = NetfilterContext(addr=IPv6Address('TCP', '2001:db8::', 80))
self.assertFalse(matcher.match(context))
context = NetfilterContext(addr=IPv6Address('TCP', '', 80))
self.assertFalse(matcher.match(context))

def test_match_ip_address_ipv6_net(self):
def test_match_ip_address_ipv6_net(self) -> None:
matcher = NetfilterMatchIPAddress('2001:0db8:0:f101::/64')
context = NetfilterContext(addr=IPv6Address('TCP', '2001:db8::8a2e:370:7334', 1234))
self.assertFalse(matcher.match(context))
Expand All @@ -167,7 +167,7 @@ def test_match_ip_address_ipv6_net(self):
self.assertEqual(json['type'], 'NetfilterMatchIPAddress')
self.assertEqual(json['match_params']['host'], str(ip_network('2001:0db8:0:f101::/64')))

def test_match_ip_address_ipv6_ip(self):
def test_match_ip_address_ipv6_ip(self) -> None:
matcher = NetfilterMatchIPAddress('2001:0db8:0:f101::1/128')
context = NetfilterContext(addr=IPv6Address('TCP', '2001:db8:0:f101::1', 1234))
self.assertTrue(matcher.match(context))
Expand All @@ -176,22 +176,22 @@ def test_match_ip_address_ipv6_ip(self):
context = NetfilterContext(addr=IPv6Address('TCP', '2001:db8:0:f101:2::7334', 1234))
self.assertFalse(matcher.match(context))

def test_match_ip_address_ipv6_hostname(self):
def test_match_ip_address_ipv6_hostname(self) -> None:
matcher = NetfilterMatchIPAddress('2001:0db8:0:f101::1/128')
context = NetfilterContext(addr=HostnameAddress('hathor.network', 80))
context = NetfilterContext(addr=HostnameAddress(b'hathor.network', 80))
self.assertFalse(matcher.match(context))

def test_match_ip_address_ipv6_unix(self):
def test_match_ip_address_ipv6_unix(self) -> None:
matcher = NetfilterMatchIPAddress('2001:0db8:0:f101::1/128')
context = NetfilterContext(addr=UNIXAddress('/unix.sock'))
self.assertFalse(matcher.match(context))

def test_match_ip_address_ipv6_ipv4(self):
def test_match_ip_address_ipv6_ipv4(self) -> None:
matcher = NetfilterMatchIPAddress('2001:0db8:0:f101::1/128')
context = NetfilterContext(addr=IPv4Address('TCP', '192.168.0.1', 1234))
self.assertFalse(matcher.match(context))

def test_match_peer_id_empty_context(self):
def test_match_peer_id_empty_context(self) -> None:
matcher = NetfilterMatchPeerId('123')
context = NetfilterContext()
self.assertFalse(matcher.match(context))
Expand All @@ -200,7 +200,7 @@ def test_match_peer_id_empty_context(self):
class BaseNetfilterMatchTest(unittest.TestCase):
__test__ = False

def test_match_peer_id(self):
def test_match_peer_id(self) -> None:
network = 'testnet'
peer_id1 = PeerId()
peer_id2 = PeerId()
Expand Down
2 changes: 1 addition & 1 deletion tests/p2p/netfilter/test_match_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@


class NetfilterMatchRemoteTest(unittest.TestCase):
def test_match_ip(self):
def test_match_ip(self) -> None:
matcher = NetfilterMatchIPAddressRemoteURL('test', self.clock, 'http://localhost:8080')
context = NetfilterContext(addr=IPv4Address('TCP', '192.168.0.1', 1234))
self.assertFalse(matcher.match(context))
Expand Down
6 changes: 3 additions & 3 deletions tests/p2p/netfilter/test_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,17 @@


class NetfilterTableTest(unittest.TestCase):
def test_default_table_filter(self):
def test_default_table_filter(self) -> None:
tb_filter = get_table('filter')
tb_filter.get_chain('pre_conn')
tb_filter.get_chain('post_hello')
tb_filter.get_chain('post_peerid')

def test_default_table_not_exists(self):
def test_default_table_not_exists(self) -> None:
with self.assertRaises(KeyError):
get_table('do-not-exists')

def test_add_get_chain(self):
def test_add_get_chain(self) -> None:
mytable = NetfilterTable('mytable')
mychain = NetfilterChain('mychain', NetfilterAccept())
mytable.add_chain(mychain)
Expand Down
2 changes: 1 addition & 1 deletion tests/p2p/netfilter/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@


class NetfilterUtilsTest(unittest.TestCase):
def test_peer_id_blacklist(self):
def test_peer_id_blacklist(self) -> None:
post_peerid = get_table('filter').get_chain('post_peerid')

# Chain starts empty
Expand Down
13 changes: 11 additions & 2 deletions tests/p2p/test_capabilities.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from hathor.p2p.states import ReadyState
from hathor.p2p.sync_v1.agent import NodeSyncTimestamp
from hathor.p2p.sync_v2.agent import NodeBlockSync
from hathor.simulator import FakeConnection
from tests import unittest


class SyncV1HathorCapabilitiesTestCase(unittest.SyncV1Params, unittest.TestCase):
def test_capabilities(self):
def test_capabilities(self) -> None:
network = 'testnet'
manager1 = self.create_peer(network, capabilities=[self._settings.CAPABILITY_WHITELIST])
manager2 = self.create_peer(network, capabilities=[])
Expand All @@ -18,6 +19,8 @@ def test_capabilities(self):
self.clock.advance(0.1)

# Even if we don't have the capability we must connect because the whitelist url conf is None
assert isinstance(conn._proto1.state, ReadyState)
assert isinstance(conn._proto2.state, ReadyState)
self.assertEqual(conn._proto1.state.state_name, 'READY')
self.assertEqual(conn._proto2.state.state_name, 'READY')
self.assertIsInstance(conn._proto1.state.sync_agent, NodeSyncTimestamp)
Expand All @@ -33,14 +36,16 @@ def test_capabilities(self):
conn2.run_one_step(debug=True)
self.clock.advance(0.1)

assert isinstance(conn2._proto1.state, ReadyState)
assert isinstance(conn2._proto2.state, ReadyState)
self.assertEqual(conn2._proto1.state.state_name, 'READY')
self.assertEqual(conn2._proto2.state.state_name, 'READY')
self.assertIsInstance(conn2._proto1.state.sync_agent, NodeSyncTimestamp)
self.assertIsInstance(conn2._proto2.state.sync_agent, NodeSyncTimestamp)


class SyncV2HathorCapabilitiesTestCase(unittest.SyncV2Params, unittest.TestCase):
def test_capabilities(self):
def test_capabilities(self) -> None:
network = 'testnet'
manager1 = self.create_peer(network, capabilities=[self._settings.CAPABILITY_WHITELIST,
self._settings.CAPABILITY_SYNC_VERSION])
Expand All @@ -54,6 +59,8 @@ def test_capabilities(self):
self.clock.advance(0.1)

# Even if we don't have the capability we must connect because the whitelist url conf is None
assert isinstance(conn._proto1.state, ReadyState)
assert isinstance(conn._proto2.state, ReadyState)
self.assertEqual(conn._proto1.state.state_name, 'READY')
self.assertEqual(conn._proto2.state.state_name, 'READY')
self.assertIsInstance(conn._proto1.state.sync_agent, NodeBlockSync)
Expand All @@ -71,6 +78,8 @@ def test_capabilities(self):
conn2.run_one_step(debug=True)
self.clock.advance(0.1)

assert isinstance(conn2._proto1.state, ReadyState)
assert isinstance(conn2._proto2.state, ReadyState)
self.assertEqual(conn2._proto1.state.state_name, 'READY')
self.assertEqual(conn2._proto2.state.state_name, 'READY')
self.assertIsInstance(conn2._proto1.state.sync_agent, NodeBlockSync)
Expand Down
4 changes: 2 additions & 2 deletions tests/p2p/test_connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

class ConnectionsTest(unittest.TestCase):
@pytest.mark.skipif(sys.platform == 'win32', reason='run_server is very finicky on Windows')
def test_connections(self):
def test_connections(self) -> None:
process = run_server()
process2 = run_server(listen=8006, status=8086, bootstrap='tcp://127.0.0.1:8005')
process3 = run_server(listen=8007, status=8087, bootstrap='tcp://127.0.0.1:8005')
Expand All @@ -17,7 +17,7 @@ def test_connections(self):
process2.terminate()
process3.terminate()

def test_manager_connections(self):
def test_manager_connections(self) -> None:
manager = self.create_peer('testnet', enable_sync_v1=True, enable_sync_v2=False)

endpoint = 'tcp://127.0.0.1:8005'
Expand Down
Loading

0 comments on commit ef291c0

Please sign in to comment.