From 27a0989c1d1aa346fe599b3d93e673f4017eff1c Mon Sep 17 00:00:00 2001
From: Quinten Stokkink
Date: Mon, 29 Apr 2024 15:19:13 +0200
Subject: [PATCH] Fixed ruff violations in content discovery folder

---
 src/tribler/core/content_discovery/cache.py   |  31 ++-
 .../core/content_discovery/community.py       | 243 +++++++++++-------
 src/tribler/core/content_discovery/payload.py | 134 +++++++---
 .../restapi/search_endpoint.py                |  54 ++--
 src/tribler/core/database/store.py            |   5 +-
 .../core/content_discovery/test_cache.py      |   8 +-
 6 files changed, 323 insertions(+), 152 deletions(-)

diff --git a/src/tribler/core/content_discovery/cache.py b/src/tribler/core/content_discovery/cache.py
index 441053c372..fc9c9ded98 100644
--- a/src/tribler/core/content_discovery/cache.py
+++ b/src/tribler/core/content_discovery/cache.py
@@ -1,9 +1,29 @@
-from ipv8.requestcache import RandomNumberCache
+from __future__ import annotations
+
+from binascii import hexlify
+from typing import TYPE_CHECKING, Callable
+
+from ipv8.requestcache import RandomNumberCache, RequestCache
+from typing_extensions import Self
+
+if TYPE_CHECKING:
+    from ipv8.types import Peer
+
+    from tribler.core.database.store import ProcessingResult


 class SelectRequest(RandomNumberCache):
-    def __init__(self, request_cache, prefix, request_kwargs, peer, processing_callback=None, timeout_callback=None):
-        super().__init__(request_cache, prefix)
+    """
+    Keep track of the packets to a Peer during the answering of a select request.
+    """
+
+    def __init__(self, request_cache: RequestCache, request_kwargs: dict, peer: Peer,
+                 processing_callback: Callable[[Self, list[ProcessingResult]], None] | None = None,
+                 timeout_callback: Callable[[Self], None] | None = None) -> None:
+        """
+        Create a new select request cache.
+        """
+        super().__init__(request_cache, hexlify(peer.mid).decode())
         self.request_kwargs = request_kwargs
         # The callback to call on results of processing of the response payload
         self.processing_callback = processing_callback
@@ -17,6 +37,9 @@ def __init__(self, request_cache, prefix, request_kwargs, peer, processing_callb
         self.timeout_callback = timeout_callback

-    def on_timeout(self):
+    def on_timeout(self) -> None:
+        """
+        Call the timeout callback, if one is registered.
+        """
         if self.timeout_callback is not None:
             self.timeout_callback(self)
diff --git a/src/tribler/core/content_discovery/community.py b/src/tribler/core/content_discovery/community.py
index 3d892bb2c6..c3ac916129 100644
--- a/src/tribler/core/content_discovery/community.py
+++ b/src/tribler/core/content_discovery/community.py
@@ -5,31 +5,43 @@
 import sys
 import time
 import uuid
-from binascii import unhexlify, hexlify
+from binascii import hexlify, unhexlify
 from itertools import count
-from typing import Any, Dict, List, Optional, Set, Sequence
+from typing import TYPE_CHECKING, Any, Callable, Sequence

 from ipv8.community import Community, CommunitySettings
 from ipv8.lazy_community import lazy_wrapper
 from ipv8.requestcache import RequestCache
-from ipv8.types import Peer
 from pony.orm import OperationalError, db_session

 from tribler.core.content_discovery.cache import SelectRequest
-from tribler.core.content_discovery.payload import (PopularTorrentsRequest, RemoteSelectPayload,
-                                                    SelectResponsePayload, TorrentsHealthPayload,
-                                                    VersionRequest, VersionResponse)
+from tribler.core.content_discovery.payload import (
+    PopularTorrentsRequest,
+    RemoteSelectPayload,
+    SelectResponsePayload,
+    TorrentsHealthPayload,
+    VersionRequest,
+    VersionResponse,
+)
 from tribler.core.database.layers.knowledge import ResourceType
-from tribler.core.database.orm_bindings.torrent_metadata import LZ4_EMPTY_ARCHIVE, entries_to_chunk
-from tribler.core.database.store import MetadataStore, ObjState
-from tribler.core.database.tribler_database import TriblerDatabase
+from tribler.core.database.orm_bindings.torrent_metadata import LZ4_EMPTY_ARCHIVE, TorrentMetadata, entries_to_chunk
+from tribler.core.database.store import MetadataStore, ObjState, ProcessingResult
 from tribler.core.knowledge.community import is_valid_resource
 from tribler.core.notifier import Notification, Notifier
 from tribler.core.torrent_checker.dataclasses import HealthInfo
-from tribler.core.torrent_checker.torrent_checker import TorrentChecker
+
+if TYPE_CHECKING:
+    from ipv8.types import Peer
+
+    from tribler.core.database.tribler_database import TriblerDatabase
+    from tribler.core.torrent_checker.torrent_checker import TorrentChecker


 class ContentDiscoverySettings(CommunitySettings):
+    """
+    The settings for the content discovery community.
+    """
+
     random_torrent_interval: float = 5  # seconds
     random_torrent_count: int = 10
     max_query_peers: int = 20
@@ -37,7 +49,7 @@ class ContentDiscoverySettings(CommunitySettings):
     max_response_size: int = 100  # Max number of entries returned by SQL query
     binary_fields: Sequence[str] = ("infohash", "channel_pk")
-    deprecated_parameters: Sequence[str] = ('subscribed', 'attribute_ranges', 'complete_channel')
+    deprecated_parameters: Sequence[str] = ("subscribed", "attribute_ranges", "complete_channel")

     metadata_store: MetadataStore
     torrent_checker: TorrentChecker
@@ -48,19 +60,15 @@ class ContentDiscoverySettings(CommunitySettings):
 class ContentDiscoveryCommunity(Community):
     """
     Community for disseminating the content across the network.
-
-    Push:
-        - Every 5 seconds it gossips 10 random torrents to a random peer.
-    Pull:
-        - Every time it receives an introduction request, it sends a request
-          to return their popular torrents.
-
-    Gossiping is for checked torrents only.
""" - community_id = unhexlify('9aca62f878969c437da9844cba29a134917e1648') + + community_id = unhexlify("9aca62f878969c437da9844cba29a134917e1648") settings_class = ContentDiscoverySettings - def __init__(self, settings: ContentDiscoverySettings): + def __init__(self, settings: ContentDiscoverySettings) -> None: + """ + Create a new overlay for content discovery. + """ super().__init__(settings) self.composition = settings @@ -79,26 +87,35 @@ def __init__(self, settings: ContentDiscoverySettings): self.remote_queries_in_progress = 0 self.next_remote_query_num = count().__next__ # generator of sequential numbers, for logging & debug purposes - self.logger.info('Content Discovery Community initialized (peer mid %s)', hexlify(self.my_peer.mid)) + self.logger.info("Content Discovery Community initialized (peer mid %s)", hexlify(self.my_peer.mid)) self.register_task("gossip_random_torrents", self.gossip_random_torrents_health, interval=self.composition.random_torrent_interval) - async def unload(self): + async def unload(self) -> None: + """ + Shut down the request cache. + """ await self.request_cache.shutdown() await super().unload() - def sanitize_dict(self, parameters: dict[str, Any], decode=True) -> None: + def sanitize_dict(self, parameters: dict[str, Any], decode: bool = True) -> None: + """ + Convert the binary values in the given dictionary to (decode=True) and from (decode=False) hex format. + """ for field in self.composition.binary_fields: value = parameters.get(field) if value is not None: parameters[field] = unhexlify(value.encode()) if decode else hexlify(value.encode()).decode() - def sanitize_query(self, query_dict: Dict[str, Any], cap=100) -> Dict[str, Any]: + def sanitize_query(self, query_dict: dict[str, Any], cap: int = 100) -> dict[str, Any]: + """ + Convert the values in a query to the appropriate format and supply missing values. + """ sanitized_dict = dict(query_dict) # We impose a cap on max numbers of returned entries to prevent DDOS-like attacks - first = sanitized_dict.get("first", None) or 0 - last = sanitized_dict.get("last", None) + first = sanitized_dict.get("first") or 0 + last = sanitized_dict.get("last") last = last if (last is not None and last <= (first + cap)) else (first + cap) sanitized_dict.update({"first": first, "last": last}) @@ -107,7 +124,10 @@ def sanitize_query(self, query_dict: Dict[str, Any], cap=100) -> Dict[str, Any]: return sanitized_dict - def convert_to_json(self, parameters): + def convert_to_json(self, parameters: dict[str, Any]) -> str: + """ + Sanitize and dump the given dictionary to a string using JSON. + """ sanitized = dict(parameters) # Convert metadata_type to an int list if it is a string if "metadata_type" in sanitized and isinstance(sanitized["metadata_type"], str): @@ -120,7 +140,10 @@ def convert_to_json(self, parameters): return json.dumps(sanitized) - def get_alive_checked_torrents(self) -> List[HealthInfo]: + def get_alive_checked_torrents(self) -> list[HealthInfo]: + """ + Get torrents that we know have seeders AND leechers. + """ if not self.composition.torrent_checker: return [] @@ -128,7 +151,7 @@ def get_alive_checked_torrents(self) -> List[HealthInfo]: return [health for health in self.composition.torrent_checker.torrents_checked.values() if health.seeders > 0 and health.leechers >= 0] - def gossip_random_torrents_health(self): + def gossip_random_torrents_health(self) -> None: """ Gossip random torrent health information to another peer. 
""" @@ -142,10 +165,12 @@ def gossip_random_torrents_health(self): self.ez_send(p, PopularTorrentsRequest()) @lazy_wrapper(TorrentsHealthPayload) - async def on_torrents_health(self, peer, payload: TorrentsHealthPayload): - self.logger.debug(f"Received torrent health information for " - f"{len(payload.torrents_checked)} popular torrents and" - f" {len(payload.random_torrents)} random torrents") + async def on_torrents_health(self, peer: Peer, payload: TorrentsHealthPayload) -> None: + """ + Callback for when we receive torrent health. + """ + self.logger.debug("Received torrent health information for %d popular torrents" + " and %d random torrents", len(payload.torrents_checked), len(payload.random_torrents)) health_tuples = payload.random_torrents + payload.torrents_checked health_list = [HealthInfo(infohash, last_check=last_check, seeders=seeders, leechers=leechers) @@ -159,7 +184,10 @@ async def on_torrents_health(self, peer, payload: TorrentsHealthPayload): self.send_remote_select(peer=peer, infohash=infohash, last=1) @db_session - def process_torrents_health(self, health_list: List[HealthInfo]): + def process_torrents_health(self, health_list: list[HealthInfo]) -> set[bytes]: + """ + Get the infohashes that we did not know about before from the given health list. + """ infohashes_to_resolve = set() for health in health_list: added = self.composition.metadata_store.process_torrent_health(health) @@ -168,12 +196,18 @@ def process_torrents_health(self, health_list: List[HealthInfo]): return infohashes_to_resolve @lazy_wrapper(PopularTorrentsRequest) - async def on_popular_torrents_request(self, peer, payload): + async def on_popular_torrents_request(self, peer: Peer, payload: PopularTorrentsRequest) -> None: + """ + Callback for when we receive a request for popular torrents. + """ self.logger.debug("Received popular torrents health request") - popular_torrents = self.get_likely_popular_torrents() + popular_torrents = self.get_random_torrents() self.ez_send(peer, TorrentsHealthPayload.create({}, popular_torrents)) - def get_likely_popular_torrents(self) -> List[HealthInfo]: + def get_random_torrents(self) -> list[HealthInfo]: + """ + Get torrent health info for torrents that were alive, last we know of. + """ checked_and_alive = self.get_alive_checked_torrents() if not checked_and_alive: return [] @@ -181,27 +215,20 @@ def get_likely_popular_torrents(self) -> List[HealthInfo]: num_torrents_to_send = min(self.composition.random_torrent_count, len(checked_and_alive)) return random.sample(checked_and_alive, num_torrents_to_send) - def get_random_torrents(self) -> List[HealthInfo]: - checked_and_alive = list(self.get_alive_checked_torrents()) - if not checked_and_alive: - return [] - - num_torrents = len(checked_and_alive) - num_torrents_to_send = min(self.composition.random_torrent_count, num_torrents) - - random_torrents = random.sample(checked_and_alive, num_torrents_to_send) - return random_torrents - - def get_random_peers(self, sample_size=None): - # Randomly sample sample_size peers from the complete list of our peers + def get_random_peers(self, sample_size: int | None = None) -> list[Peer]: + """ + Randomly sample sample_size peers from the complete list of our peers. 
+        """
         all_peers = self.get_peers()
         return random.sample(all_peers, min(sample_size or len(all_peers), len(all_peers)))

-    def send_search_request(self, **kwargs):
-        # Send a remote query request to multiple random peers to search for some terms
+    def send_search_request(self, **kwargs) -> tuple[uuid.UUID, list[Peer]]:
+        """
+        Send a remote query request to multiple random peers to search for some terms.
+        """
         request_uuid = uuid.uuid4()

-        def notify_gui(request, processing_results):
+        def notify_gui(request: SelectRequest, processing_results: list[ProcessingResult]) -> None:
             results = [
                 r.md_obj.to_simple_dict()
                 for r in processing_results
@@ -224,49 +251,59 @@ def notify_gui(request, processing_results):
         return request_uuid, peers_to_query

     @lazy_wrapper(VersionRequest)
-    async def on_version_request(self, peer, _):
+    async def on_version_request(self, peer: Peer, _: VersionRequest) -> None:
+        """
+        Callback for when our Tribler version and Operating System is requested.
+        """
         version_response = VersionResponse("Tribler Experimental", sys.platform)
         self.ez_send(peer, version_response)

     @lazy_wrapper(VersionResponse)
-    async def on_version_response(self, peer, payload):
-        pass
-
-    def send_remote_select(self, peer, processing_callback=None, **kwargs):
-        request = SelectRequest(
-            self.request_cache,
-            hexlify(peer.mid).decode(),
-            kwargs,
-            peer,
-            processing_callback=processing_callback,
-            timeout_callback=self._on_query_timeout,
-        )
+    async def on_version_response(self, peer: Peer, payload: VersionResponse) -> None:
+        """
+        Callback for when we receive a Tribler version and Operating System of a peer.
+        """
+
+    def send_remote_select(self, peer: Peer,
+                           processing_callback: Callable[[SelectRequest, list[ProcessingResult]], None] | None = None,
+                           **kwargs) -> SelectRequest:
+        """
+        Query a peer using an SQL statement description (kwargs).
+        """
+        request = SelectRequest(self.request_cache, kwargs, peer, processing_callback, self._on_query_timeout)
         self.request_cache.add(request)

-        self.logger.debug(f"Select to {hexlify(peer.mid).decode()} with ({kwargs})")
+        self.logger.debug("Select to %s with (%s)", hexlify(peer.mid).decode(), str(kwargs))
         self.ez_send(peer, RemoteSelectPayload(request.number, self.convert_to_json(kwargs).encode()))
         return request

-    def should_limit_rate_for_query(self, sanitized_parameters: Dict[str, Any]) -> bool:
-        return 'txt_filter' in sanitized_parameters
+    def should_limit_rate_for_query(self, sanitized_parameters: dict[str, Any]) -> bool:
+        """
+        Don't allow too many queries with potentially heavy database load.
+        """
+        return "txt_filter" in sanitized_parameters

-    async def process_rpc_query_rate_limited(self, sanitized_parameters: Dict[str, Any]) -> List:
+    async def process_rpc_query_rate_limited(self, sanitized_parameters: dict[str, Any]) -> list:
+        """
+        Process the given query and return results.
+        """
         query_num = self.next_remote_query_num()
         if self.remote_queries_in_progress and self.should_limit_rate_for_query(sanitized_parameters):
-            self.logger.warning(f'Ignore remote query {query_num} as another one is already processing. '
-                                f'The ignored query: {sanitized_parameters}')
+            self.logger.warning("Ignore remote query %d as another one is already processing. "
The ignored query: %s", + query_num, sanitized_parameters) return [] - self.logger.info(f'Process remote query {query_num}: {sanitized_parameters}') + self.logger.info("Process remote query %d: %s", query_num, sanitized_parameters) self.remote_queries_in_progress += 1 t = time.time() try: return await self.process_rpc_query(sanitized_parameters) finally: self.remote_queries_in_progress -= 1 - self.logger.info(f'Remote query {query_num} processed in {time.time() - t} seconds: {sanitized_parameters}') + self.logger.info("Remote query %d processed in %f seconds: %s", + query_num, time.time() - t, sanitized_parameters) - async def process_rpc_query(self, sanitized_parameters: Dict[str, Any]) -> List: + async def process_rpc_query(self, sanitized_parameters: dict[str, Any]) -> list: """ Retrieve the result of a database query from a third party, encoded as raw JSON bytes (through `dumps`). @@ -276,32 +313,36 @@ async def process_rpc_query(self, sanitized_parameters: Dict[str, Any]) -> List: """ if self.composition.tribler_db: # tags should be extracted because `get_entries_threaded` doesn't expect them as a parameter - tags = sanitized_parameters.pop('tags', None) + tags = sanitized_parameters.pop("tags", None) infohash_set = self.composition.tribler_db.instance(self.search_for_tags, tags) if infohash_set: - sanitized_parameters['infohash_set'] = {bytes.fromhex(s) for s in infohash_set} + sanitized_parameters["infohash_set"] = {bytes.fromhex(s) for s in infohash_set} # exclude_deleted should be extracted because `get_entries_threaded` doesn't expect it as a parameter - sanitized_parameters.pop('exclude_deleted', None) + sanitized_parameters.pop("exclude_deleted", None) return await self.composition.metadata_store.get_entries_threaded(**sanitized_parameters) @db_session - def search_for_tags(self, tags: Optional[List[str]]) -> Optional[Set[str]]: + def search_for_tags(self, tags: list[str] | None) -> set[str] | None: + """ + Query our local database for the given tags. + """ if not tags or not self.composition.tribler_db: return None valid_tags = {tag for tag in tags if is_valid_resource(tag)} - result = self.composition.tribler_db.knowledge.get_subjects_intersection( + return self.composition.tribler_db.knowledge.get_subjects_intersection( subjects_type=ResourceType.TORRENT, objects=valid_tags, predicate=ResourceType.TAG, case_sensitive=False ) - return result - - def send_db_results(self, peer, request_payload_id, db_results): + def send_db_results(self, peer: Peer, request_payload_id: int, db_results: list[TorrentMetadata]) -> None: + """ + Send the given results to the given peer. + """ # Special case of empty results list - sending empty lz4 archive if len(db_results) == 0: self.ez_send(peer, SelectResponsePayload(request_payload_id, LZ4_EMPTY_ARCHIVE)) @@ -315,35 +356,43 @@ def send_db_results(self, peer, request_payload_id, db_results): self.ez_send(peer, payload) @lazy_wrapper(RemoteSelectPayload) - async def on_remote_select(self, peer, request_payload): + async def on_remote_select(self, peer: Peer, request_payload: RemoteSelectPayload) -> None: + """ + Callback for when another peer queries us. 
+        """
         try:
             sanitized_parameters = self.parse_parameters(request_payload.json)
             # Drop selects with deprecated queries
             if any(param in sanitized_parameters for param in self.composition.deprecated_parameters):
-                self.logger.warning(f"Remote select with deprecated parameters: {sanitized_parameters}")
+                self.logger.warning("Remote select with deprecated parameters: %s", str(sanitized_parameters))
                 self.ez_send(peer, SelectResponsePayload(request_payload.id, LZ4_EMPTY_ARCHIVE))
                 return
             db_results = await self.process_rpc_query_rate_limited(sanitized_parameters)
             self.send_db_results(peer, request_payload.id, db_results)
         except (OperationalError, TypeError, ValueError) as error:
-            self.logger.error(f"Remote select error: {error}. Request content: {request_payload.json!r}")
+            self.logger.exception("Remote select error: %s. Request content: %s",
+                                  str(error), repr(request_payload.json))

-    def parse_parameters(self, json_bytes: bytes) -> Dict[str, Any]:
+    def parse_parameters(self, json_bytes: bytes) -> dict[str, Any]:
+        """
+        Load a (JSON) dict from the given bytes and sanitize it to use as a database query.
+        """
         return self.sanitize_query(json.loads(json_bytes), self.composition.max_response_size)

     @lazy_wrapper(SelectResponsePayload)
-    async def on_remote_select_response(self, peer, response_payload):
+    async def on_remote_select_response(self, peer: Peer,
+                                        response_payload: SelectResponsePayload) -> list[ProcessingResult] | None:
         """
         Match the response that we received from the network to a query cache
         and process it by adding the corresponding entries to the MetadataStore database.

-        This processes both direct responses and pushback (updates) responses
+        This processes both direct responses and pushback (updates) responses.
         """
-        self.logger.debug(f"Response from {hexlify(peer.mid).decode()}")
+        self.logger.debug("Response from %s", hexlify(peer.mid).decode())

         request: SelectRequest | None = self.request_cache.get(hexlify(peer.mid).decode(), response_payload.id)
         if request is None:
-            return
+            return None

         # Check for limit on the number of packets per request
         if request.packets_limit > 1:
@@ -354,7 +403,7 @@ async def on_remote_select_response(self, peer, response_payload):
         processing_results = await self.composition.metadata_store.process_compressed_mdblob_threaded(
             response_payload.raw_blob
         )
-        self.logger.debug(f"Response result: {processing_results}")
+        self.logger.debug("Response result: %s", str(processing_results))

         if isinstance(request, SelectRequest) and request.processing_callback:
             request.processing_callback(request, processing_results)
@@ -365,7 +414,10 @@ async def on_remote_select_response(self, peer, response_payload):

         return processing_results

-    def _on_query_timeout(self, request_cache):
+    def _on_query_timeout(self, request_cache: SelectRequest) -> None:
+        """
+        Remove a peer if it failed to respond to our select request.
+        """
         if not request_cache.peer_responded:
             self.logger.debug(
                 "Remote query timeout, deleting peer: %s %s %s",
@@ -376,4 +428,7 @@ def _on_query_timeout(self, request_cache):
             self.network.remove_peer(request_cache.peer)

     def send_ping(self, peer: Peer) -> None:
+        """
+        Send a ping to a peer to keep it alive.
+        """
         self.send_introduction_request(peer)
diff --git a/src/tribler/core/content_discovery/payload.py b/src/tribler/core/content_discovery/payload.py
index e9947cc4fd..9bed398515 100644
--- a/src/tribler/core/content_discovery/payload.py
+++ b/src/tribler/core/content_discovery/payload.py
@@ -1,16 +1,23 @@
 from __future__ import annotations
-from typing import List
+
+from typing import TYPE_CHECKING

 from ipv8.messaging.lazy_payload import VariablePayload, vp_compile
 from ipv8.messaging.serialization import default_serializer
+from typing_extensions import Self

-from tribler.core.torrent_checker.dataclasses import HealthInfo
+if TYPE_CHECKING:
+    from tribler.core.torrent_checker.dataclasses import HealthInfo


 @vp_compile
 class TorrentInfoFormat(VariablePayload):
-    format_list = ['20s', 'I', 'I', 'Q']
-    names = ['infohash', 'seeders', 'leechers', 'timestamp']
+    """
+    For a given infohash at a given time: the known seeders and leechers.
+    """
+
+    format_list = ["20s", "I", "I", "Q"]
+    names = ["infohash", "seeders", "leechers", "timestamp"]
     length = 36

     infohash: bytes
@@ -18,89 +25,146 @@ class TorrentInfoFormat(VariablePayload):
     leechers: int
     timestamp: int

-    def to_tuple(self):
+    def to_tuple(self) -> tuple[bytes, int, int, int]:
+        """
+        Convert this payload to a tuple.
+        """
         return self.infohash, self.seeders, self.leechers, self.timestamp

     @classmethod
-    def from_list_bytes(cls, serialized):
+    def from_list_bytes(cls: type[Self], serialized: bytes) -> list[Self]:
+        """
+        Convert the given bytes to a list of this payload.
+        """
         return default_serializer.unpack_serializable_list([cls] * (len(serialized) // cls.length), serialized,
                                                            consume_all=False)[:-1]


 @vp_compile
 class TorrentsHealthPayload(VariablePayload):
+    """
+    A payload for lists of health information.
+
+    For backward compatibility, this payload includes two lists. Originally, one list was for random torrents and
+    one list was for torrents that we personally checked. Now, only one is used.
+    """
+
     msg_id = 1
-    format_list = ['I', 'I', 'varlenI', 'raw']  # Number of random torrents, number of torrents checked by you
-    names = ['random_torrents_length', 'torrents_checked_length', 'random_torrents', 'torrents_checked']
+    format_list = ["I", "I", "varlenI", "raw"]  # Number of random torrents, number of torrents checked by you
+    names = ["random_torrents_length", "torrents_checked_length", "random_torrents", "torrents_checked"]

     random_torrents_length: int
     torrents_checked_length: int
-    random_torrents: list[tuple]
-    torrents_checked: list[tuple]
+    random_torrents: list[tuple[bytes, int, int, int]]
+    torrents_checked: list[tuple[bytes, int, int, int]]

-    def fix_pack_random_torrents(self, value) -> bytes:
-        return b''.join(default_serializer.pack_serializable(TorrentInfoFormat(*sublist)) for sublist in value)
+    def fix_pack_random_torrents(self, value: list[tuple[bytes, int, int, int]]) -> bytes:
+        """
+        Convert the list of random torrent info tuples to bytes.
+        """
+        return b"".join(default_serializer.pack_serializable(TorrentInfoFormat(*sublist)) for sublist in value)

-    def fix_pack_torrents_checked(self, value) -> bytes:
-        return b''.join(default_serializer.pack_serializable(TorrentInfoFormat(*sublist)) for sublist in value)
+    def fix_pack_torrents_checked(self, value: list[tuple[bytes, int, int, int]]) -> bytes:
+        """
+        Convert the list of checked torrent info tuples to bytes.
+        """
+        return b"".join(default_serializer.pack_serializable(TorrentInfoFormat(*sublist)) for sublist in value)

     @classmethod
-    def fix_unpack_random_torrents(cls, value):
+    def fix_unpack_random_torrents(cls: type[Self], value: bytes) -> list[tuple[bytes, int, int, int]]:
+        """
+        Convert the raw data back to a list of random torrent info tuples.
+        """
         return [payload.to_tuple() for payload in TorrentInfoFormat.from_list_bytes(value)]

     @classmethod
-    def fix_unpack_torrents_checked(cls, value):
+    def fix_unpack_torrents_checked(cls: type[Self], value: bytes) -> list[tuple[bytes, int, int, int]]:
+        """
+        Convert the raw data back to a list of checked torrent info tuples.
+        """
         return [payload.to_tuple() for payload in TorrentInfoFormat.from_list_bytes(value)]

     @classmethod
-    def create(cls, random_torrents_checked: List[HealthInfo], popular_torrents_checked: List[HealthInfo]):
+    def create(cls: type[Self], random_torrents_checked: list[HealthInfo],
+               popular_torrents_checked: list[HealthInfo]) -> Self:
+        """
+        Create a payload from the given lists.
+        """
         random_torrent_tuples = [(health.infohash, health.seeders, health.leechers, health.last_check)
                                  for health in random_torrents_checked]
         popular_torrent_tuples = [(health.infohash, health.seeders, health.leechers, health.last_check)
                                   for health in popular_torrents_checked]
-        return cls(len(random_torrents_checked), len(popular_torrents_checked),
-                   random_torrent_tuples, popular_torrent_tuples)
+        return cls(len(random_torrents_checked), len(popular_torrents_checked), random_torrent_tuples,
+                   popular_torrent_tuples)


 @vp_compile
 class PopularTorrentsRequest(VariablePayload):
+    """
+    A request to be sent the health information of popular torrents.
+    """
+
     msg_id = 2


 @vp_compile
 class VersionRequest(VariablePayload):
+    """
+    A request for the Tribler version and Operating System of a peer.
+    """
+
     msg_id = 101


 @vp_compile
 class VersionResponse(VariablePayload):
+    """
+    A response to a request for Tribler version and OS.
+    """
+
     msg_id = 102
-    format_list = ['varlenI', 'varlenI']
-    names = ['version', 'platform']
+    format_list = ["varlenI", "varlenI"]
+    names = ["version", "platform"]

     version: str
     platform: str

-    def fix_pack_version(self, value):
-        return value.encode('utf-8')
+    def fix_pack_version(self, value: str) -> bytes:
+        """
+        Convert the (utf-8) Tribler version string to bytes.
+        """
+        return value.encode()

-    def fix_pack_platform(self, value):
-        return value.encode('utf-8')
+    def fix_pack_platform(self, value: str) -> bytes:
+        """
+        Convert the (utf-8) platform description string to bytes.
+        """
+        return value.encode()

     @classmethod
-    def fix_unpack_version(cls, value):
-        return value.decode('utf-8')
+    def fix_unpack_version(cls: type[Self], value: bytes) -> str:
+        """
+        Convert the packed Tribler version back to a string.
+        """
+        return value.decode()

     @classmethod
-    def fix_unpack_platform(cls, value):
-        return value.decode('utf-8')
+    def fix_unpack_platform(cls: type[Self], value: bytes) -> str:
+        """
+        Convert the packed platform description back to a string.
+        """
+        return value.decode()


 @vp_compile
 class RemoteSelectPayload(VariablePayload):
+    """
+    A payload to send SQL queries to other peers.
+    """
+
     msg_id = 201
-    format_list = ['I', 'varlenH']
-    names = ['id', 'json']
+    format_list = ["I", "varlenH"]
+    names = ["id", "json"]

     id: int
     json: bytes


 @vp_compile
 class SelectResponsePayload(VariablePayload):
+    """
+    A response to a select request.
+    """
+
     msg_id = 202
-    format_list = ['I', 'raw']
-    names = ['id', 'raw_blob']
+    format_list = ["I", "raw"]
+    names = ["id", "raw_blob"]

     id: int
     raw_blob: bytes
diff --git a/src/tribler/core/content_discovery/restapi/search_endpoint.py b/src/tribler/core/content_discovery/restapi/search_endpoint.py
index 1d737bad47..ef1aaaef0f 100644
--- a/src/tribler/core/content_discovery/restapi/search_endpoint.py
+++ b/src/tribler/core/content_discovery/restapi/search_endpoint.py
@@ -1,40 +1,57 @@
+from __future__ import annotations
+
 from binascii import hexlify, unhexlify
+from typing import TYPE_CHECKING

 from aiohttp import web
-from aiohttp.abc import Request
 from aiohttp_apispec import docs, querystring_schema
-from marshmallow.fields import Integer, List, String
-
 from ipv8.REST.schema import schema
-from multidict import MultiDictProxy
+from marshmallow.fields import Integer, List, String
+from typing_extensions import Self

-from tribler.core.content_discovery.community import ContentDiscoveryCommunity
 from tribler.core.database.restapi.schema import MetadataParameters
 from tribler.core.restapi.rest_endpoint import HTTP_BAD_REQUEST, MAX_REQUEST_SIZE, RESTEndpoint, RESTResponse

+if TYPE_CHECKING:
+    from aiohttp.abc import Request
+    from multidict import MultiMapping
+
+    from tribler.core.content_discovery.community import ContentDiscoveryCommunity
+

 class RemoteQueryParameters(MetadataParameters):
+    """
+    The REST API schema for requests to other peers.
+    """
+
     uuid = String()
-    channel_pk = String(description='Channel to query, must also define origin_id')
-    origin_id = Integer(default=None, description='Peer id to query, must also define channel_pk')
+    channel_pk = String(description="Channel to query, must also define origin_id")
+    origin_id = Integer(default=None, description="Peer id to query, must also define channel_pk")


 class SearchEndpoint(RESTEndpoint):
     """
     This endpoint is responsible for searching in channels and torrents present in the local Tribler database.
     """
-    path = '/search'
+
+    path = "/search"

     def __init__(self, content_discovery_community: ContentDiscoveryCommunity,
-                 middlewares=(),
-                 client_max_size=MAX_REQUEST_SIZE):
+                 middlewares: tuple = (),
+                 client_max_size: int = MAX_REQUEST_SIZE) -> None:
+        """
+        Create a new search endpoint.
+        """
         super().__init__(middlewares, client_max_size)
         self.content_discovery_community = content_discovery_community
-        self.app.add_routes([web.put('/remote', self.remote_search)])
+        self.app.add_routes([web.put("/remote", self.remote_search)])

     @classmethod
-    def sanitize_parameters(cls, parameters: MultiDictProxy[str]) -> dict:
+    def sanitize_parameters(cls: type[Self], parameters: MultiMapping[str]) -> dict:
+        """
+        Correct the human-readable parameters to be their respective correct type.
+        """
         sanitized = dict(parameters)
         if "max_rowid" in parameters:
             sanitized["max_rowid"] = int(parameters["max_rowid"])
@@ -45,13 +62,13 @@ def sanitize_parameters(cls, parameters: MultiDictProxy[str]) -> dict:
         return sanitized

     @docs(
-        tags=['Metadata'],
+        tags=["Metadata"],
         summary="Perform a search for a given query.",
         responses={
             200: {
-                'schema': schema(RemoteSearchResponse={'request_uuid': String(), 'peers': List(String())}),
+                "schema": schema(RemoteSearchResponse={"request_uuid": String(), "peers": List(String())}),
                 "examples": {
-                    'Success': {
+                    "Success": {
                         "request_uuid": "268560c0-3f28-4e6e-9d85-d5ccb0269693",
                         "peers": ["50e9a2ce646c373985a8e827e328830e053025c6",
                                   "107c84e5d9636c17b46c88c3ddb54842d80081b0"]
@@ -62,13 +79,16 @@ def sanitize_parameters(cls, parameters: MultiDictProxy[str]) -> dict:
     )
     @querystring_schema(RemoteQueryParameters)
     async def remote_search(self, request: Request) -> RESTResponse:
-        self._logger.info('Create remote search request')
+        """
+        Perform a search for a given query.
+        """
+        self._logger.info("Create remote search request")
         # Results are returned over the Events endpoint.
         try:
             sanitized = self.sanitize_parameters(request.query)
         except (ValueError, KeyError) as e:
             return RESTResponse({"error": f"Error processing request parameters: {e}"}, status=HTTP_BAD_REQUEST)
-        self._logger.info(f'Parameters: {sanitized}')
+        self._logger.info("Parameters: %s", str(sanitized))

         request_uuid, peers_list = self.content_discovery_community.send_search_request(**sanitized)
         peers_mid_list = [hexlify(p.mid).decode() for p in peers_list]
diff --git a/src/tribler/core/database/store.py b/src/tribler/core/database/store.py
index 95271bdef9..be1521e86e 100644
--- a/src/tribler/core/database/store.py
+++ b/src/tribler/core/database/store.py
@@ -17,6 +17,7 @@
 from pony.orm import Database, db_session, desc, left_join, raw_sql, select
 from pony.orm.dbproviders.sqlite import keep_exception

+from tribler.core.database.layers.layer import EntityImpl
 from tribler.core.database.orm_bindings import misc, torrent_metadata, tracker_state
 from tribler.core.database.orm_bindings import torrent_state as torrent_state_
 from tribler.core.database.orm_bindings.torrent_metadata import NULL_KEY_SUBST
@@ -64,8 +65,8 @@ class ProcessingResult:
     arguments for get_entries to query the sender back through Remote Query Community.
     """

-    md_obj: object = None
-    obj_state: object = None
+    md_obj: EntityImpl
+    obj_state: object
     missing_deps: list = field(default_factory=list)


diff --git a/src/tribler/test_unit/core/content_discovery/test_cache.py b/src/tribler/test_unit/core/content_discovery/test_cache.py
index c2ae5a1413..192ad0980e 100644
--- a/src/tribler/test_unit/core/content_discovery/test_cache.py
+++ b/src/tribler/test_unit/core/content_discovery/test_cache.py
@@ -1,5 +1,7 @@
 from asyncio import sleep

+from ipv8.keyvault.private.libnaclkey import LibNaCLSK
+from ipv8.peer import Peer
 from ipv8.requestcache import RequestCache
 from ipv8.test.base import TestBase

@@ -11,6 +13,8 @@ class TestSelectRequest(TestBase):
     Tests for the SelectRequest cache.
     """

+    FAKE_PEER = Peer(LibNaCLSK(b""))
+
     async def test_timeout_no_cb(self) -> None:
         """
         Test if a SelectRequest can time out without a callback set.
@@ -18,7 +22,7 @@ async def test_timeout_no_cb(self) -> None: request_cache = RequestCache() with request_cache.passthrough(): - cache = request_cache.add(SelectRequest(request_cache, "test", {}, None)) + cache = request_cache.add(SelectRequest(request_cache, {}, TestSelectRequest.FAKE_PEER)) await sleep(0) self.assertFalse(request_cache.has(cache.prefix, cache.number)) @@ -31,7 +35,7 @@ async def test_timeout_with_cb(self) -> None: callback_values = [] with request_cache.passthrough(): - cache = request_cache.add(SelectRequest(request_cache, "test", {}, None, + cache = request_cache.add(SelectRequest(request_cache, {}, TestSelectRequest.FAKE_PEER, timeout_callback=callback_values.append)) await sleep(0)