diff --git a/.github/workflows/unittests.yml b/.github/workflows/unittests.yml index 0cd932f393..a5337ff0d2 100644 --- a/.github/workflows/unittests.yml +++ b/.github/workflows/unittests.yml @@ -9,7 +9,7 @@ jobs: submodules: 'true' - uses: actions/setup-python@v5 with: - python-version: '3.8' + python-version: '3.9' cache: 'pip' - run: python -m pip install -r requirements.txt - name: Run unit tests @@ -24,7 +24,7 @@ jobs: submodules: 'true' - uses: actions/setup-python@v5 with: - python-version: '3.8' + python-version: '3.9' cache: 'pip' - uses: actions/cache/restore@v4 id: restore_cache @@ -50,7 +50,7 @@ jobs: submodules: 'true' - uses: actions/setup-python@v5 with: - python-version: '3.8' + python-version: '3.9' cache: 'pip' - run: python -m pip install -r requirements.txt - name: Run unit tests diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 0000000000..f095d4d3b9 --- /dev/null +++ b/mypy.ini @@ -0,0 +1,2 @@ +[mypy] +ignore_missing_imports = True diff --git a/src/tribler/core/components.py b/src/tribler/core/components.py index ceb20ecf94..85b7268eae 100644 --- a/src/tribler/core/components.py +++ b/src/tribler/core/components.py @@ -63,7 +63,7 @@ def __init__(self, settings: SettingsClass) -> None: Overlay.__init__(self, settings) self.cancel_pending_task("discover_lan_addresses") self.endpoint.remove_listener(self) - self.bootstrappers = [] + self.bootstrappers: list[Bootstrapper] = [] self.max_peers = 0 self._prefix = settings.community_id self.settings = settings @@ -124,13 +124,13 @@ def prepare(self, ipv8: IPv8, session: Session) -> None: from tribler.core.knowledge.rules.knowledge_rules_processor import KnowledgeRulesProcessor from tribler.core.notifier import Notification - db_path = Path(session.config.get("state_dir")) / "sqlite" / "tribler.db" - mds_path = Path(session.config.get("state_dir")) / "sqlite" / "metadata.db" + db_path = str(Path(session.config.get("state_dir")) / "sqlite" / "tribler.db") + mds_path = str(Path(session.config.get("state_dir")) / "sqlite" / "metadata.db") if session.config.get("memory_db"): db_path = ":memory:" mds_path = ":memory:" - session.db = TriblerDatabase(str(db_path)) + session.db = TriblerDatabase(db_path) session.mds = MetadataStore( mds_path, session.ipv8.keys["anonymous id"].key, diff --git a/src/tribler/core/content_discovery/restapi/search_endpoint.py b/src/tribler/core/content_discovery/restapi/search_endpoint.py index ef1aaaef0f..bb1c761619 100644 --- a/src/tribler/core/content_discovery/restapi/search_endpoint.py +++ b/src/tribler/core/content_discovery/restapi/search_endpoint.py @@ -1,7 +1,7 @@ from __future__ import annotations from binascii import hexlify, unhexlify -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, cast from aiohttp import web from aiohttp_apispec import docs, querystring_schema @@ -48,15 +48,15 @@ def __init__(self, self.app.add_routes([web.put("/remote", self.remote_search)]) @classmethod - def sanitize_parameters(cls: type[Self], parameters: MultiMapping[str]) -> dict: + def sanitize_parameters(cls: type[Self], parameters: MultiMapping[int | str]) -> dict: """ Correct the human-readable parameters to be their respective correct type. """ - sanitized = dict(parameters) + sanitized: dict = dict(parameters) if "max_rowid" in parameters: sanitized["max_rowid"] = int(parameters["max_rowid"]) if "channel_pk" in parameters: - sanitized["channel_pk"] = unhexlify(parameters["channel_pk"]) + sanitized["channel_pk"] = unhexlify(cast(str, parameters["channel_pk"])) if "origin_id" in parameters: sanitized["origin_id"] = int(parameters["origin_id"]) return sanitized diff --git a/src/tribler/core/database/layers/knowledge.py b/src/tribler/core/database/layers/knowledge.py index 2774da9b4c..7522b04209 100644 --- a/src/tribler/core/database/layers/knowledge.py +++ b/src/tribler/core/database/layers/knowledge.py @@ -70,8 +70,13 @@ def get(subject: Resource, object: Resource) -> Statement | None: ... # noqa: D def get_for_update(subject: Resource, object: Resource) -> Statement | None: ... # noqa: D102, A002 + class IterResource(type): # noqa: D101 + + def __iter__(cls) -> Iterator[Resource]: ... # noqa: D105 + + @dataclasses.dataclass - class Resource(EntityImpl): + class Resource(EntityImpl, metaclass=IterResource): """ Database type for a resources. """ @@ -275,7 +280,7 @@ def _get_resources(self, resource_type: ResourceType | None, name: str | None, c def get_statements(self, source_type: ResourceType | None, source_name: str | None, # noqa: PLR0913 statements_getter: Callable[[Entity], Entity], - target_condition: Callable[[], bool], condition: Callable[[], bool], + target_condition: Callable[[Statement], bool], condition: Callable[[Statement], bool], case_sensitive: bool, ) -> Iterator[Statement]: """ Get entities that satisfies the given condition. @@ -365,7 +370,7 @@ def _show_condition(s: Statement) -> bool: def get_objects(self, subject_type: ResourceType | None = None, subject: str | None = "", predicate: ResourceType | None = None, case_sensitive: bool = True, - condition: Callable[[], bool] | None = None) -> List[str]: + condition: Callable[[Statement], bool] | None = None) -> List[str]: """ Get objects that satisfy the given subject and predicate. @@ -438,14 +443,9 @@ def get_simple_statements(self, subject_type: ResourceType | None = None, subjec case_sensitive=case_sensitive, ) - statements = (SimpleStatement( - subject_type=s.subject.type, - subject=s.subject.name, - predicate=s.object.type, - object=s.object.name - ) for s in statements) - - return list(statements) + return [SimpleStatement(subject_type=s.subject.type, subject=s.subject.name, predicate=s.object.type, + object=s.object.name) + for s in statements] def get_suggestions(self, subject_type: ResourceType | None = None, subject: str | None = "", predicate: ResourceType | None = None, case_sensitive: bool = True) -> List[str]: @@ -470,7 +470,7 @@ def get_suggestions(self, subject_type: ResourceType | None = None, subject: str def get_subjects_intersection(self, objects: Set[str], predicate: ResourceType | None, - subjects_type: ResourceType | None = ResourceType.TORRENT, + subjects_type: ResourceType = ResourceType.TORRENT, case_sensitive: bool = True) -> Set[str]: """ Get all subjects that have a certain predicate. @@ -540,7 +540,7 @@ def _get_random_operations_by_condition(self, condition: Callable[[Entity], bool :param attempts: maximum attempt count for requesting the DB. :returns: a set of random operations """ - operations = set() + operations: set[Entity] = set() for _ in range(attempts): if len(operations) == count: return operations diff --git a/src/tribler/core/database/layers/user_activity.py b/src/tribler/core/database/layers/user_activity.py index da2270c3cb..2a97bbe78e 100644 --- a/src/tribler/core/database/layers/user_activity.py +++ b/src/tribler/core/database/layers/user_activity.py @@ -201,7 +201,7 @@ def get_random_query_aggregate(self, neighbors: int, # Option 2: aggregate results = self.Query.select(lambda q: q.query == random_selection.query)[:] num_results_div = 1/len(results) - preferences = {} + preferences: dict[bytes, float] = {} for query in results: for infohash_preference in query.infohashes: preferences[infohash_preference.infohash] = (preferences.get(infohash_preference.infohash, 0.0) diff --git a/src/tribler/core/database/orm_bindings/torrent_metadata.py b/src/tribler/core/database/orm_bindings/torrent_metadata.py index e73778279a..1c392b611c 100644 --- a/src/tribler/core/database/orm_bindings/torrent_metadata.py +++ b/src/tribler/core/database/orm_bindings/torrent_metadata.py @@ -266,7 +266,7 @@ def get_magnet(self) -> str: @classmethod @db_session - def add_ffa_from_dict(cls: type[Self], metadata: dict) -> Self: + def add_ffa_from_dict(cls: type[Self], metadata: dict) -> Self | None: # To produce a relatively unique id_ we take some bytes of the infohash and convert these to a number. # abs is necessary as the conversion can produce a negative value, and we do not support that. id_ = infohash_to_id(metadata["infohash"]) diff --git a/src/tribler/core/database/ranks.py b/src/tribler/core/database/ranks.py index de91a6531d..e1148b343f 100644 --- a/src/tribler/core/database/ranks.py +++ b/src/tribler/core/database/ranks.py @@ -88,9 +88,9 @@ def title_rank(query: str, title: str) -> float: :param title: a torrent name :return: the similarity of the title string to a query string as a float value in range [0, 1] """ - query = word_re.findall(query.lower()) - title = word_re.findall(title.lower()) - return calculate_rank(query, title) + pat_query = word_re.findall(query.lower()) + pat_title = word_re.findall(title.lower()) + return calculate_rank(pat_query, pat_title) # These coefficients are found empirically. Their exact values are not very important for a relative ranking of results @@ -125,13 +125,13 @@ def calculate_rank(query: list[str], title: list[str]) -> float: if not title: return 0.0 - title = deque(title) - total_error = 0 + q_title = deque(title) + total_error = 0.0 for i, word in enumerate(query): # The first word is more important than the second word, and so on word_weight = POSITION_COEFF / (POSITION_COEFF + i) - found, skipped = find_word_and_rotate_title(word, title) + found, skipped = find_word_and_rotate_title(word, q_title) if found: # if the query word is found in the title, add penalty for skipped words in title before it total_error += skipped * word_weight @@ -141,7 +141,7 @@ def calculate_rank(query: list[str], title: list[str]) -> float: # a small penalty for excess words in the title that was not mentioned in the search phrase remainder_weight = 1 / (REMAINDER_COEFF + len(query)) - remained_words_error = len(title) * remainder_weight + remained_words_error = len(q_title) * remainder_weight total_error += remained_words_error # a search rank should be between 1 and 0 diff --git a/src/tribler/core/database/restapi/database_endpoint.py b/src/tribler/core/database/restapi/database_endpoint.py index 38716528db..0954826584 100644 --- a/src/tribler/core/database/restapi/database_endpoint.py +++ b/src/tribler/core/database/restapi/database_endpoint.py @@ -9,7 +9,6 @@ from aiohttp import web from aiohttp_apispec import docs, querystring_schema -from ipv8.REST.base_endpoint import HTTP_BAD_REQUEST from ipv8.REST.schema import schema from marshmallow.fields import Boolean, Integer, String from pony.orm import db_session @@ -19,7 +18,13 @@ from tribler.core.database.restapi.schema import MetadataSchema, SearchMetadataParameters, TorrentSchema from tribler.core.database.serialization import REGULAR_TORRENT, SNIPPET from tribler.core.notifier import Notification -from tribler.core.restapi.rest_endpoint import MAX_REQUEST_SIZE, RESTEndpoint, RESTResponse +from tribler.core.restapi.rest_endpoint import ( + HTTP_BAD_REQUEST, + HTTP_NOT_FOUND, + MAX_REQUEST_SIZE, + RESTEndpoint, + RESTResponse, +) if typing.TYPE_CHECKING: from aiohttp.abc import Request @@ -97,14 +102,15 @@ def __init__(self, # noqa: PLR0913 @classmethod def sanitize_parameters(cls: type[Self], - parameters: MultiDictProxy | MultiMapping[str]) -> dict[str, str | float | set | None]: + parameters: MultiDictProxy | MultiMapping[str] + ) -> dict[str, str | float | list[str] | set[bytes] | bytes | None]: """ Sanitize the parameters for a request that fetches channels. """ - sanitized = { + sanitized: dict[str, str | float | list[str] | set[bytes] | bytes | None] = { "first": int(parameters.get("first", 1)), "last": int(parameters.get("last", 50)), - "sort_by": json2pony_columns.get(parameters.get("sort_by")), + "sort_by": json2pony_columns.get(parameters.get("sort_by", "")), "sort_desc": parse_bool(parameters.get("sort_desc", "true")), "txt_filter": parameters.get("txt_filter"), "hide_xxx": parse_bool(parameters.get("hide_xxx", "false")), @@ -233,7 +239,7 @@ async def get_popular_torrents(self, request: Request) -> RESTResponse: return RESTResponse(response_dict) - def build_snippets(self, search_results: list[dict]) -> list[dict]: + def build_snippets(self, tribler_db: TriblerDatabase, search_results: list[dict]) -> list[dict]: """ Build a list of snippets that bundle torrents describing the same content item. For each search result we determine the content item it is associated to and bundle it inside a snippet. @@ -241,12 +247,12 @@ def build_snippets(self, search_results: list[dict]) -> list[dict]: Within each snippet, we sort on torrent popularity, putting the torrent with the most seeders on top. Torrents bundled in a snippet are filtered out from the search results. """ - content_to_torrents: typing.Dict[str, list] = defaultdict(list) + content_to_torrents: dict[str, list] = defaultdict(list) for search_result in search_results: if "infohash" not in search_result: continue with db_session: - content_items: typing.List[str] = self.tribler_db.knowledge.get_objects( + content_items = tribler_db.knowledge.get_objects( subject_type=ResourceType.TORRENT, subject=search_result["infohash"], predicate=ResourceType.CONTENT_ITEM) @@ -262,7 +268,7 @@ def build_snippets(self, search_results: list[dict]) -> list[dict]: sorted_content_info = list(content_to_torrents.items()) sorted_content_info.sort(key=lambda x: x[1][0]["num_seeders"], reverse=True) - snippets: typing.List[typing.Dict] = [] + snippets: list[dict] = [] for content_info in sorted_content_info: content_id = content_info[0] torrents_in_snippet = content_to_torrents[content_id][:MAX_TORRENTS_IN_SNIPPETS] @@ -310,7 +316,7 @@ def build_snippets(self, search_results: list[dict]) -> list[dict]: }, ) @querystring_schema(SearchMetadataParameters) - async def local_search(self, request: Request) -> RESTResponse: + async def local_search(self, request: Request) -> RESTResponse: # noqa: C901 """ Perform a search for a given query. """ @@ -320,6 +326,9 @@ async def local_search(self, request: Request) -> RESTResponse: except (ValueError, KeyError): return RESTResponse({"error": "Error processing request parameters"}, status=HTTP_BAD_REQUEST) + if self.tribler_db is None: + return RESTResponse({"error": "Tribler DB not initialized"}, status=HTTP_NOT_FOUND) + include_total = request.query.get("include_total", "") mds: MetadataStore = self.mds @@ -345,7 +354,7 @@ def search_db() -> tuple[list[dict], int, int]: if tags: infohash_set = self.tribler_db.knowledge.get_subjects_intersection( subjects_type=ResourceType.TORRENT, - objects=set(tags), + objects=set(typing.cast(list[str], tags)), predicate=ResourceType.TAG, case_sensitive=False) if infohash_set: @@ -359,7 +368,7 @@ def search_db() -> tuple[list[dict], int, int]: self.add_statements_to_metadata_list(search_results) if sanitized["first"] == 1: # Only show a snippet on top - search_results = self.build_snippets(search_results) + search_results = self.build_snippets(self.tribler_db, search_results) response_dict = { "results": search_results, diff --git a/src/tribler/core/database/serialization.py b/src/tribler/core/database/serialization.py index 7f5ab9ca4d..d6a18f4e5b 100644 --- a/src/tribler/core/database/serialization.py +++ b/src/tribler/core/database/serialization.py @@ -119,8 +119,9 @@ def from_dict(cls: type[Self], **kwargs) -> Self: Create a payload from the given data (an unpacked dict). """ out = cls(**{key: value for key, value in kwargs.items() if key in cls.names}) - if kwargs.get("signature") is not None: - out.signature = kwargs.get("signature") + signature = kwargs.get("signature") + if signature is not None: + out.signature = signature return out def add_signature(self, key: PrivateKey) -> None: @@ -196,8 +197,8 @@ def get_magnet(self) -> str: """ Create a magnet link for this payload. """ - return (f"magnet:?xt=urn:btih:{hexlify(self.infohash).decode()}&dn={self.title.encode()}" - + (f"&tr={self.tracker_info.encode()}" if self.tracker_info else "")) + return (f"magnet:?xt=urn:btih:{hexlify(self.infohash).decode()}&dn={self.title}" + + (f"&tr={self.tracker_info}" if self.tracker_info else "")) @vp_compile diff --git a/src/tribler/core/database/store.py b/src/tribler/core/database/store.py index eb95d71c7e..7f65c101df 100644 --- a/src/tribler/core/database/store.py +++ b/src/tribler/core/database/store.py @@ -33,7 +33,6 @@ from tribler.core.torrent_checker.dataclasses import HealthInfo if TYPE_CHECKING: - from os import PathLike from sqlite3 import Connection from ipv8.types import PrivateKey @@ -133,7 +132,7 @@ class MetadataStore: def __init__( # noqa: PLR0913 self, - db_filename: PathLike, + db_filename: str, private_key: PrivateKey, disable_sync: bool = False, notifier: Notifier | None = None, @@ -241,9 +240,8 @@ def get_objects_to_create(self) -> list[Entity]: connection = self.db.get_connection() schema = self.db.schema provider = schema.provider - created_tables = set() return [db_object for table in schema.order_tables_to_create() - for db_object in table.get_objects_to_create(created_tables) + for db_object in table.get_objects_to_create(set()) if not db_object.exists(provider, connection)] def get_db_file_size(self) -> int: @@ -481,7 +479,7 @@ def get_num_torrents(self) -> int: """ return orm.count(self.TorrentMetadata.select(lambda g: g.metadata_type == REGULAR_TORRENT)) - def search_keyword(self, query: str, origin_id: int | None = None) -> list[TorrentMetadata]: + def search_keyword(self, query: str, origin_id: int | None = None) -> Query: """ Search for an FTS query, potentially restricted to a given origin id. @@ -585,10 +583,7 @@ def get_entries_query( # noqa: C901, PLR0912, PLR0913 pony_query = pony_query.where(lambda g: g.rowid <= max_rowid) if metadata_type is not None: - try: - pony_query = pony_query.where(lambda g: g.metadata_type in metadata_type) - except TypeError: - pony_query = pony_query.where(lambda g: g.metadata_type == metadata_type) + pony_query = pony_query.where(lambda g: g.metadata_type == metadata_type) pony_query = ( pony_query.where(public_key=(b"" if channel_pk == NULL_KEY_SUBST else channel_pk)) diff --git a/src/tribler/core/database/tribler_database.py b/src/tribler/core/database/tribler_database.py index 9153ecae71..ed6916952d 100644 --- a/src/tribler/core/database/tribler_database.py +++ b/src/tribler/core/database/tribler_database.py @@ -3,7 +3,7 @@ import logging import os from pathlib import Path -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, cast from pony import orm from pony.orm import Database, db_session @@ -117,7 +117,7 @@ def version(self) -> int: """ Get the database version. """ - return int(self.get_misc(key=self._SCHEME_VERSION_KEY, default="0")) + return int(cast(str, self.get_misc(key=self._SCHEME_VERSION_KEY, default="0"))) @version.setter def version(self, value: int) -> None: diff --git a/src/tribler/core/knowledge/operations_requests.py b/src/tribler/core/knowledge/operations_requests.py index 1de3534c42..3c1c2ea4e7 100644 --- a/src/tribler/core/knowledge/operations_requests.py +++ b/src/tribler/core/knowledge/operations_requests.py @@ -1,6 +1,10 @@ +from __future__ import annotations + from collections import defaultdict +from typing import TYPE_CHECKING -from ipv8.types import Peer +if TYPE_CHECKING: + from ipv8.types import Peer class PeerValidationError(ValueError): @@ -23,7 +27,7 @@ def __init__(self) -> None: """ Create a new dictionary to keep track of responses. """ - self.requests = defaultdict(int) + self.requests: dict[Peer, int] = defaultdict(int) def register_peer(self, peer: Peer, number_of_responses: int) -> None: """ diff --git a/src/tribler/core/knowledge/restapi/knowledge_endpoint.py b/src/tribler/core/knowledge/restapi/knowledge_endpoint.py index 62793e1dfd..ae2b82cdf3 100644 --- a/src/tribler/core/knowledge/restapi/knowledge_endpoint.py +++ b/src/tribler/core/knowledge/restapi/knowledge_endpoint.py @@ -51,7 +51,7 @@ def __init__(self, db: TriblerDatabase, community: KnowledgeCommunity) -> None: ) @staticmethod - def validate_infohash(infohash: bytes) -> tuple[bool, RESTResponse | None]: + def validate_infohash(infohash: str) -> tuple[bool, RESTResponse | None]: """ Check if the given bytes are a string of 40 HEX-character bytes. """ diff --git a/src/tribler/core/knowledge/rules/knowledge_rules_processor.py b/src/tribler/core/knowledge/rules/knowledge_rules_processor.py index ad37fc81c0..fe4465336c 100644 --- a/src/tribler/core/knowledge/rules/knowledge_rules_processor.py +++ b/src/tribler/core/knowledge/rules/knowledge_rules_processor.py @@ -67,7 +67,7 @@ def __init__(self, notifier: Notifier, db: TriblerDatabase, mds: MetadataStore, self.queue_batch_size = queue_batch_size self.queue_max_size = queue_max_size - self._last_warning_time = 0 + self._last_warning_time: float = 0 self._start_rowid_in_current_session = 0 self._start_time_in_current_session = 0 diff --git a/src/tribler/core/knowledge/rules/rules.py b/src/tribler/core/knowledge/rules/rules.py index eaf49ad292..53bcedde3f 100644 --- a/src/tribler/core/knowledge/rules/rules.py +++ b/src/tribler/core/knowledge/rules/rules.py @@ -3,7 +3,7 @@ import re from dataclasses import dataclass, field from re import Pattern -from typing import AnyStr, Callable, Generator, Sequence +from typing import Callable, Generator, Sequence from tribler.core.knowledge.community import is_valid_resource @@ -24,7 +24,7 @@ class Rule: A Linux distribution correction rule. """ - patterns: Sequence[Pattern[AnyStr]] = field(default_factory=list) + patterns: Sequence[Pattern[str]] = field(default_factory=list) actions: Sequence[Callable[[str], str]] = field(default_factory=list) @@ -92,7 +92,7 @@ def extract_tags(text: str, rules: RulesList | None = None) -> Generator[str, No text_set = next_text_set for action in rule.actions: - text_set = map(action, text_set) + text_set = {action(e) for e in text_set} yield from text_set diff --git a/src/tribler/core/libtorrent/download_manager/dht_health_manager.py b/src/tribler/core/libtorrent/download_manager/dht_health_manager.py index a3cb805e07..98429742d8 100644 --- a/src/tribler/core/libtorrent/download_manager/dht_health_manager.py +++ b/src/tribler/core/libtorrent/download_manager/dht_health_manager.py @@ -23,10 +23,10 @@ def __init__(self, lt_session: lt.session) -> None: :param lt_session: The session used to perform health lookups. """ TaskManager.__init__(self) - self.lookup_futures = {} # Map from binary infohash to future - self.bf_seeders = {} # Map from infohash to (final) seeders bloomfilter - self.bf_peers = {} # Map from infohash to (final) peers bloomfilter - self.outstanding = {} # Map from transaction_id to infohash + self.lookup_futures: dict[bytes, Future[HealthInfo]] = {} # Map from binary infohash to future + self.bf_seeders: dict[bytes, bytearray] = {} # Map from infohash to (final) seeders bloomfilter + self.bf_peers: dict[bytes, bytearray] = {} # Map from infohash to (final) peers bloomfilter + self.outstanding: dict[str, bytes] = {} # Map from transaction_id to infohash self.lt_session = lt_session def get_health(self, infohash: bytes, timeout: float = 15) -> Awaitable[HealthInfo]: @@ -39,7 +39,7 @@ def get_health(self, infohash: bytes, timeout: float = 15) -> Awaitable[HealthIn if infohash in self.lookup_futures: return self.lookup_futures[infohash] - lookup_future = Future() + lookup_future: Future[HealthInfo] = Future() self.lookup_futures[infohash] = lookup_future self.bf_seeders[infohash] = bytearray(256) self.bf_peers[infohash] = bytearray(256) @@ -100,8 +100,7 @@ def get_size_from_bloomfilter(bf: bytearray) -> int: def tobits(s: bytes) -> list[int]: result = [] - for c in s: - num = ord(c) if isinstance(c, str) else c + for num in s: bits = bin(num)[2:] bits = "00000000"[len(bits):] + bits result.extend([int(b) for b in bits]) diff --git a/src/tribler/core/libtorrent/download_manager/download.py b/src/tribler/core/libtorrent/download_manager/download.py index 9ecf374ad3..7111043a47 100644 --- a/src/tribler/core/libtorrent/download_manager/download.py +++ b/src/tribler/core/libtorrent/download_manager/download.py @@ -9,13 +9,13 @@ import base64 import itertools import logging -from asyncio import CancelledError, Future, get_running_loop, iscoroutine, sleep, wait_for +from asyncio import CancelledError, Future, get_running_loop, sleep, wait_for from binascii import hexlify from collections import defaultdict from contextlib import suppress from enum import Enum from pathlib import Path -from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Tuple, TypedDict +from typing import TYPE_CHECKING, Any, Awaitable, Callable, List, Tuple, TypedDict, cast import libtorrent as lt from bitarray import bitarray @@ -29,6 +29,7 @@ from tribler.core.libtorrent.torrentdef import TorrentDef, TorrentDefNoMetainfo from tribler.core.libtorrent.torrents import check_handle, get_info_from_handle, require_handle from tribler.core.notifier import Notification, Notifier +from tribler.tribler_config import TriblerConfigManager if TYPE_CHECKING: from tribler.core.libtorrent.download_manager.download_manager import DownloadManager @@ -92,10 +93,10 @@ class Download(TaskManager): def __init__(self, # noqa: PLR0913 tdef: TorrentDef, + download_manager: DownloadManager, config: DownloadConfig | None = None, notifier: Notifier | None = None, state_dir: Path | None = None, - download_manager: DownloadManager | None =None, checkpoint_disabled: bool = False, hidden: bool = False) -> None: """ @@ -116,10 +117,10 @@ def __init__(self, # noqa: PLR0913 self.error = None self.pause_after_next_hashcheck = False self.checkpoint_after_next_hashcheck = False - self.tracker_status = {} # {url: [num_peers, status_str]} + self.tracker_status: dict[str, tuple[int, str]] = {} # {url: (num_peers, status_str)} - self.futures: Dict[str, list[tuple[Future, Callable, Getter | None]]] = defaultdict(list) - self.alert_handlers = defaultdict(list) + self.futures: dict[str, list[tuple[Future, Callable, Getter | None]]] = defaultdict(list) + self.alert_handlers: dict[str, list[Callable[[lt.torrent_alert], None]]] = defaultdict(list) self.future_added = self.wait_for_alert("add_torrent_alert", lambda a: a.handle) self.future_removed = self.wait_for_alert("torrent_removed_alert") @@ -146,7 +147,11 @@ def __init__(self, # noqa: PLR0913 # With hidden True download will not be in GET/downloads set, as a result will not be shown in GUI self.hidden = hidden self.checkpoint_disabled = checkpoint_disabled - self.config = config or DownloadConfig.from_defaults(self.download_manager.config) + self.config: DownloadConfig = config + if config is None and self.download_manager is not None: + self.config = DownloadConfig.from_defaults(self.download_manager.config) + elif config is None: + self.config = DownloadConfig.from_defaults(TriblerConfigManager()) self._logger.debug("Setup: %s", hexlify(self.tdef.get_infohash()).decode()) @@ -194,7 +199,7 @@ def wait_for_alert(self, success_type: str, success_getter: Getter | None = None """ Create a future that fires when a certain alert is received. """ - future = Future() + future: Future[Any] = Future() if success_type: self.futures[success_type].append((future, future.set_result, success_getter)) if fail_type: @@ -215,7 +220,7 @@ def get_def(self) -> TorrentDef: """ return self.tdef - def get_handle(self) -> Awaitable[lt.torrent_handle]: + def get_handle(self) -> Future[lt.torrent_handle]: """ Returns a deferred that fires with a valid libtorrent download handle. """ @@ -312,7 +317,7 @@ def get_pieces_base64(self) -> bytes: """ Returns a base64 encoded bitmask of the pieces that we have. """ - binary_gen = (int(boolean) for boolean in self.handle.status().pieces) + binary_gen = (int(boolean) for boolean in cast(lt.torrent_handle, self.handle).status().pieces) bits = bitarray(binary_gen) return base64.b64encode(bits.tobytes()) @@ -413,7 +418,7 @@ def on_tracker_reply_alert(self, alert: lt.tracker_reply_alert) -> None: """ self._logger.info("On tracker reply alert: %s", repr(alert)) - self.tracker_status[alert.url] = [alert.num_peers, 'Working'] + self.tracker_status[alert.url] = (alert.num_peers, 'Working') def on_tracker_error_alert(self, alert: lt.tracker_error_alert) -> None: """ @@ -435,7 +440,7 @@ def on_tracker_error_alert(self, alert: lt.tracker_error_alert) -> None: status = "Not working" peers = 0 # If there is a tracker error, alert.num_peers is not available. So resetting peer count to zero. - self.tracker_status[url] = [peers, status] + self.tracker_status[url] = (peers, status) def on_tracker_warning_alert(self, alert: lt.tracker_warning_alert) -> None: """ @@ -446,14 +451,15 @@ def on_tracker_warning_alert(self, alert: lt.tracker_warning_alert) -> None: peers = self.tracker_status[alert.url][0] if alert.url in self.tracker_status else 0 status = "Warning: " + str(alert.message()) - self.tracker_status[alert.url] = [peers, status] + self.tracker_status[alert.url] = (peers, status) - @check_handle() + @check_handle(None) def on_metadata_received_alert(self, alert: lt.metadata_received_alert) -> None: # noqa: C901, PLR0912 """ Handle a metadata received alert. """ self._logger.info("On metadata received alert: %s", repr(alert)) + self.handle = cast(lt.torrent_handle, self.handle) torrent_info = get_info_from_handle(self.handle) if not torrent_info: @@ -542,19 +548,20 @@ def on_torrent_checked_alert(self, alert: lt.torrent_checked_alert) -> None: """ self._logger.info("On torrent checked alert: %s", repr(alert)) - if self.pause_after_next_hashcheck: + if self.pause_after_next_hashcheck and self.handle: self.pause_after_next_hashcheck = False self.handle.pause() if self.checkpoint_after_next_hashcheck: self.checkpoint_after_next_hashcheck = False self.checkpoint() - @check_handle() + @check_handle(None) def on_torrent_finished_alert(self, alert: lt.torrent_finished_alert) -> None: """ Handle a torrent finished alert. """ self._logger.info("On torrent finished alert: %s", repr(alert)) + self.handle = cast(lt.torrent_handle, self.handle) self.update_lt_status(self.handle.status()) self.checkpoint() downloaded = self.get_state().total_download @@ -579,14 +586,14 @@ def update_lt_status(self, lt_status: lt.torrent_status) -> None: (mode == "time" and state.get_seeding_time() >= seeding_time)): self.stop() - @check_handle() + @check_handle(None) def set_selected_files(self, selected_files: list[int] | None = None, prio: int = 4, force: bool = False) -> int | None: """ Set the selected files. If the selected files is None or empty, all files will be selected. """ if not force and self.stream is not None: - return + return None if not isinstance(self.tdef, TorrentDefNoMetainfo) and not self.get_share_mode(): if selected_files is None: selected_files = self.config.get_selected_files() @@ -597,7 +604,7 @@ def set_selected_files(self, selected_files: list[int] | None = None, prio: int total_files = self.tdef.torrent_info.num_files() if not selected_files: - selected_files = range(total_files) + selected_files = list(range(total_files)) def map_selected(index: int) -> int: file_instance = tree.find(Path(tree.file_storage.file_path(index))) @@ -608,6 +615,7 @@ def map_selected(index: int) -> int: return 0 self.set_file_priorities(list(map(map_selected, range(total_files)))) + return None @check_handle(False) def move_storage(self, new_dir: Path) -> bool: @@ -615,16 +623,18 @@ def move_storage(self, new_dir: Path) -> bool: Move the output files to a different location. """ if not isinstance(self.tdef, TorrentDefNoMetainfo): + self.handle = cast(lt.torrent_handle, self.handle) self.handle.move_storage(str(new_dir)) self.config.set_dest_dir(new_dir) return True - @check_handle() + @check_handle(None) def force_recheck(self) -> None: """ Force libtorrent to validate the files. """ if not isinstance(self.tdef, TorrentDefNoMetainfo): + self.handle = cast(lt.torrent_handle, self.handle) if self.get_state().get_status() == DownloadStatus.STOPPED: self.pause_after_next_hashcheck = True self.checkpoint_after_next_hashcheck = True @@ -666,7 +676,7 @@ def get_peer_list(self, include_have: bool = True) -> List[PeerDict | PeerDictHa extended_version = peer_info.client except UnicodeDecodeError: extended_version = "unknown" - peer_dict = { + peer_dict: PeerDict | PeerDictHave = cast(PeerDict, { "id": hexlify(peer_info.pid.to_bytes()).decode(), "extended_version": extended_version, "ip": peer_info.ip[0], @@ -690,8 +700,9 @@ def get_peer_list(self, include_have: bool = True) -> List[PeerDict | PeerDictHa "connection_type": peer_info.connection_type, "seed": bool(peer_info.flags & peer_info.seed), "upload_only": bool(peer_info.flags & peer_info.upload_only) - } + }) if include_have: + peer_dict = cast(PeerDictHave, peer_dict) peer_dict["have"] = peer_info.pieces peers.append(peer_dict) return peers @@ -728,12 +739,13 @@ def get_tracker_status(self) -> dict[str, tuple[int, str]]: """ Retrieve an overview of the trackers and their statuses. """ + self.handle = cast(lt.torrent_handle, self.handle) # Make sure all trackers are in the tracker_status dict try: for announce_entry in self.handle.trackers(): url = announce_entry["url"] if url not in self.tracker_status: - self.tracker_status[url] = [0, "Not contacted yet"] + self.tracker_status[url] = (0, "Not contacted yet") except UnicodeDecodeError: self._logger.warning("UnicodeDecodeError in get_tracker_status") @@ -760,21 +772,6 @@ def get_tracker_status(self) -> dict[str, tuple[int, str]]: result["[PeX]"] = (pex_peers, "Working") return result - def set_state_callback(self, usercallback: Callable[[DownloadState], float | Awaitable[float]]) -> Future: - """ - Fire a callback after a second and subsequently whenever the callback returns a value larger than zero. - """ - async def state_callback_loop() -> None: - if usercallback: - when = 1 - while when and not self.future_removed.done() and not self.download_manager.is_shutting_down(): - result = usercallback(self.get_state()) - when = (await result) if iscoroutine(result) else result - if when > 0.0 and not self.download_manager.is_shutting_down(): - await sleep(when) - - return self.register_anonymous_task("downloads_cb", state_callback_loop) - async def shutdown(self) -> None: """ Shut down the download. @@ -792,7 +789,7 @@ async def shutdown(self) -> None: self.futures.clear() await self.shutdown_task_manager() - def stop(self, user_stopped: bool | None = None) -> Future[None]: + def stop(self, user_stopped: bool | None = None) -> Awaitable[None]: """ Stop downloading the download. """ @@ -859,16 +856,17 @@ def set_def(self, tdef: TorrentDef) -> None: """ self.tdef = tdef - @check_handle() + @check_handle(None) def add_trackers(self, trackers: list[str]) -> None: """ Add the given trackers to the handle. """ + self.handle = cast(lt.torrent_handle, self.handle) if hasattr(self.handle, "add_tracker"): for tracker in trackers: self.handle.add_tracker({"url": tracker, "verified": False}) - @check_handle() + @check_handle(None) def get_magnet_link(self) -> str: """ Generate a magnet link for our download. @@ -882,6 +880,7 @@ def add_peer(self, addr: tuple[str, int]) -> None: :param addr: The (hostname_ip,port) tuple to connect to """ + self.handle = cast(lt.torrent_handle, self.handle) self.handle.connect_peer(addr, 0) @require_handle @@ -889,6 +888,7 @@ def set_priority(self, priority: int) -> None: """ Set the priority of this download. """ + self.handle = cast(lt.torrent_handle, self.handle) self.handle.set_priority(priority) @require_handle @@ -896,6 +896,7 @@ def set_max_upload_rate(self, value: int) -> None: """ Set the maximum upload rate of this download. """ + self.handle = cast(lt.torrent_handle, self.handle) self.handle.set_upload_limit(value * 1024) @require_handle @@ -903,6 +904,7 @@ def set_max_download_rate(self, value: int) -> None: """ Set the maximum download rate of this download. """ + self.handle = cast(lt.torrent_handle, self.handle) self.handle.set_download_limit(value * 1024) @require_handle @@ -910,6 +912,7 @@ def apply_ip_filter(self, enable: bool) -> None: """ Enable the IP filter on this download. """ + self.handle = cast(lt.torrent_handle, self.handle) self.handle.apply_ip_filter(enable) def get_share_mode(self) -> bool: @@ -923,6 +926,7 @@ def set_share_mode(self, share_mode: bool) -> None: """ Set whether this download is in sharing mode. """ + self.handle = cast(lt.torrent_handle, self.handle) self.config.set_share_mode(share_mode) self.handle.set_share_mode(share_mode) @@ -937,6 +941,7 @@ def set_upload_mode(self, upload_mode: bool) -> None: """ Set whether this download is in upload mode. """ + self.handle = cast(lt.torrent_handle, self.handle) self.config.set_upload_mode(upload_mode) self.handle.set_upload_mode(upload_mode) @@ -945,6 +950,7 @@ def force_dht_announce(self) -> None: """ Force announce thid download on the DHT. """ + self.handle = cast(lt.torrent_handle, self.handle) self.handle.force_dht_announce() @require_handle @@ -952,6 +958,7 @@ def set_sequential_download(self, enable: bool) -> None: """ Set this download to sequential download mode. """ + self.handle = cast(lt.torrent_handle, self.handle) self.handle.set_sequential_download(enable) @check_handle(None) @@ -959,6 +966,7 @@ def set_piece_priorities(self, piece_priorities: list[int]) -> None: """ Set the priority for all pieces in the download. """ + self.handle = cast(lt.torrent_handle, self.handle) self.handle.prioritize_pieces(piece_priorities) @check_handle([]) @@ -966,6 +974,7 @@ def get_piece_priorities(self) -> list[int]: """ Get the priorities of all pieces in the download. """ + self.handle = cast(lt.torrent_handle, self.handle) return self.handle.piece_priorities() @check_handle(None) @@ -973,12 +982,15 @@ def set_file_priorities(self, file_priorities: list[int]) -> None: """ Set the priority for all files in the download. """ + self.handle = cast(lt.torrent_handle, self.handle) self.handle.prioritize_files(file_priorities) + @check_handle(None) def set_file_priority(self, file_index: int, prio: int = 4) -> None: """ Set the priority for a particular file in the download. """ + self.handle = cast(lt.torrent_handle, self.handle) self.handle.file_priority(file_index, prio) @check_handle(None) @@ -986,6 +998,7 @@ def reset_piece_deadline(self, piece: int) -> None: """ Reset the deadline for the given piece. """ + self.handle = cast(lt.torrent_handle, self.handle) self.handle.reset_piece_deadline(piece) @check_handle(None) @@ -993,6 +1006,7 @@ def set_piece_deadline(self, piece: int, deadline: int, flags: int = 0) -> None: """ Set the deadline for a given piece. """ + self.handle = cast(lt.torrent_handle, self.handle) self.handle.set_piece_deadline(piece, deadline, flags) @check_handle([]) @@ -1000,6 +1014,7 @@ def get_file_priorities(self) -> list[int]: """ Get the priorities of all files in the download. """ + self.handle = cast(lt.torrent_handle, self.handle) return self.handle.file_priorities() def file_piece_range(self, file_path: Path) -> list[int]: @@ -1030,6 +1045,7 @@ def get_file_completion(self, path: Path) -> float: """ Calculate the completion of a given file or directory. """ + self.handle = cast(lt.torrent_handle, self.handle) total = 0 have = 0 for piece_index in self.file_piece_range(path): @@ -1071,6 +1087,7 @@ def set_selected_file_or_dir(self, path: Path, selected: bool) -> None: """ Set a single file or directory to be selected or not. """ + self.handle = cast(lt.torrent_handle, self.handle) tree = self.tdef.torrent_file_tree prio = 4 if selected else 0 for index in tree.set_selected(Path(path), selected): diff --git a/src/tribler/core/libtorrent/download_manager/download_config.py b/src/tribler/core/libtorrent/download_manager/download_config.py index bf3feddf5a..b5dcceed23 100644 --- a/src/tribler/core/libtorrent/download_manager/download_config.py +++ b/src/tribler/core/libtorrent/download_manager/download_config.py @@ -3,7 +3,7 @@ import base64 from io import StringIO from pathlib import Path -from typing import TYPE_CHECKING, Dict +from typing import TYPE_CHECKING, Any, Dict, Literal, TypedDict, cast, overload import libtorrent as lt from configobj import ConfigObj @@ -12,6 +12,56 @@ if TYPE_CHECKING: from tribler.tribler_config import TriblerConfigManager + + class DownloadConfigDefaultsSection(TypedDict): + """ + The default config settings for a download. + """ + + hops: int + selected_files: list[str] + selected_file_indexes: list[int] + safe_seeding: bool + user_stopped: bool + share_mode: bool + upload_mode: bool + time_added: int + bootstrap_download: bool + channel_download: bool + add_download_to_channel: bool + saveas: str | None + + + class StateConfigSection(TypedDict): + """ + The runtime state info of a download. + """ + + metainfo: str + engineresumedata: str + + + class DownloadConfigDict(dict): + """ + All config settings in the config file. + """ + + @overload # type: ignore[override] + def __getitem__(self, key: Literal["filename"]) -> str: ... + + @overload + def __getitem__(self, key: Literal["download_defaults"]) -> DownloadConfigDefaultsSection: ... + + @overload + def __getitem__(self, key: Literal["state"]) -> StateConfigSection: ... + + def __getitem__(self, key: str) -> Any: ... # noqa: D105 + + def write(self) -> None: ... # noqa: D102 +else: + DownloadConfigDict = ConfigObj + + SPEC_FILENAME = 'download_config.spec' SPEC_CONTENT = """[download_defaults] hops = integer(default=0) @@ -51,11 +101,11 @@ class DownloadConfig: A configuration belonging to a specific download. """ - def __init__(self, config: ConfigObj | None = None) -> None: + def __init__(self, config: ConfigObj) -> None: """ Create a download config from the given ConfigObj. """ - self.config = config + self.config: DownloadConfigDict = cast(DownloadConfigDict, config) @staticmethod def get_spec_file_name(settings: TriblerConfigManager) -> str: @@ -94,7 +144,7 @@ def write(self, filename: Path) -> None: """ Write the contents of this config to a file. """ - self.config.filename = Path(filename) + self.config["filename"] = str(filename) self.config.write() def set_dest_dir(self, path: Path | str) -> None: @@ -109,7 +159,7 @@ def get_dest_dir(self) -> Path: """ Gets the directory where to save this Download. """ - dest_dir = self.config["download_defaults"]["saveas"] + dest_dir = self.config["download_defaults"]["saveas"] or "" return Path(dest_dir) def set_hops(self, hops: int) -> None: diff --git a/src/tribler/core/libtorrent/download_manager/download_config.spec b/src/tribler/core/libtorrent/download_manager/download_config.spec deleted file mode 100644 index 5dc7beaf8d..0000000000 --- a/src/tribler/core/libtorrent/download_manager/download_config.spec +++ /dev/null @@ -1,17 +0,0 @@ -[download_defaults] -hops = integer(default=0) -selected_files = string_list(default=list()) -selected_file_indexes = int_list(default=list()) -safe_seeding = boolean(default=False) -user_stopped = boolean(default=False) -share_mode = boolean(default=False) -upload_mode = boolean(default=False) -time_added = integer(default=0) -bootstrap_download = boolean(default=False) -channel_download = boolean(default=False) -add_download_to_channel = boolean(default=False) -saveas = string(default=None) - -[state] -metainfo = string(default='ZGU=') -engineresumedata = string(default='ZGU=') diff --git a/src/tribler/core/libtorrent/download_manager/download_manager.py b/src/tribler/core/libtorrent/download_manager/download_manager.py index 3fdde4a203..e7e2a7f880 100644 --- a/src/tribler/core/libtorrent/download_manager/download_manager.py +++ b/src/tribler/core/libtorrent/download_manager/download_manager.py @@ -6,6 +6,7 @@ from __future__ import annotations import asyncio +import dataclasses import logging import os import time @@ -28,7 +29,7 @@ from tribler.core.libtorrent.download_manager.download import Download from tribler.core.libtorrent.download_manager.download_config import DownloadConfig from tribler.core.libtorrent.download_manager.download_state import DownloadState, DownloadStatus -from tribler.core.libtorrent.torrentdef import TorrentDef, TorrentDefNoMetainfo +from tribler.core.libtorrent.torrentdef import MetainfoDict, TorrentDef, TorrentDefNoMetainfo from tribler.core.libtorrent.uris import unshorten, url_to_path from tribler.core.notifier import Notification, Notifier @@ -56,6 +57,16 @@ logger = logging.getLogger(__name__) +@dataclasses.dataclass +class MetainfoLookup: + """ + A metainfo lookup download and the number of times it has been invoked. + """ + + download: Download + pending: int + + def encode_atp(atp: dict) -> dict: """ Encode the "Add Torrent Params" dictionary to only include bytes, instead of strings and Paths. @@ -82,8 +93,8 @@ def __init__(self, config: TriblerConfigManager, notifier: Notifier, self.config = config self.state_dir = Path(config.get("state_dir")) - self.ltsettings = {} # Stores a copy of the settings dict for each libtorrent session - self.ltsessions = {} + self.ltsettings: dict[lt.session, dict] = {} # Stores a copy of the settings dict for each libtorrent session + self.ltsessions: dict[int, lt.session] = {} self.dht_health_manager: DHTHealthManager | None = None self.listen_ports: dict[int, dict[str, int]] = defaultdict(dict) @@ -97,15 +108,16 @@ def __init__(self, config: TriblerConfigManager, notifier: Notifier, self.downloads: Dict[bytes, Download] = {} self.checkpoint_directory = (self.state_dir / "dlcheckpoints") - self.checkpoints_count = None + self.checkpoints_count = 0 self.checkpoints_loaded = 0 self.all_checkpoints_are_loaded = False - self.metadata_tmpdir = metadata_tmpdir or TemporaryDirectory(suffix="tribler_metainfo_tmpdir") + self.metadata_tmpdir: TemporaryDirectory | None = (metadata_tmpdir or + TemporaryDirectory(suffix="tribler_metainfo_tmpdir")) # Dictionary that maps infohashes to download instances. These include only downloads that have # been made specifically for fetching metainfo, and will be removed afterwards. - self.metainfo_requests = {} - self.metainfo_cache = {} # Dictionary that maps infohashes to cached metainfo items + self.metainfo_requests: dict[bytes, MetainfoLookup] = {} + self.metainfo_cache: dict[bytes, MetainfoDict] = {} # Dictionary that maps infohashes to cached metainfo items self.default_alert_mask = lt.alert.category_t.error_notification | lt.alert.category_t.status_notification | \ lt.alert.category_t.storage_notification | lt.alert.category_t.performance_warning | \ @@ -115,10 +127,10 @@ def __init__(self, config: TriblerConfigManager, notifier: Notifier, self.queued_write_bytes = -1 # Status of libtorrent session to indicate if it can safely close and no pending writes to disk exists. - self.lt_session_shutdown_ready = {} + self.lt_session_shutdown_ready: dict[int, bool] = {} self.dht_ready_task = None self.dht_readiness_timeout = config.get("libtorrent/dht_readiness_timeout") - self._last_states_list = [] + self._last_states_list: list[DownloadState] = [] def is_shutting_down(self) -> bool: """ @@ -267,13 +279,13 @@ def create_session(self, hops: int = 0) -> lt.session: # noqa: PLR0912, PLR0915 # Due to a bug in Libtorrent 0.16.18, the outgoing_port and num_outgoing_ports value should be set in # the settings dictionary logger.info("Creating a session") - settings = {"outgoing_port": 0, - "num_outgoing_ports": 1, - "allow_multiple_connections_per_ip": 0, - "enable_upnp": int(self.config.get("libtorrent/upnp")), - "enable_dht": int(self.config.get("libtorrent/dht")), - "enable_lsd": int(self.config.get("libtorrent/lsd")), - "enable_natpmp": int(self.config.get("libtorrent/natpmp"))} + settings: dict[str, str | float] = {"outgoing_port": 0, + "num_outgoing_ports": 1, + "allow_multiple_connections_per_ip": 0, + "enable_upnp": int(self.config.get("libtorrent/upnp")), + "enable_dht": int(self.config.get("libtorrent/dht")), + "enable_lsd": int(self.config.get("libtorrent/lsd")), + "enable_natpmp": int(self.config.get("libtorrent/natpmp"))} # Copy construct so we don't modify the default list extensions = list(DEFAULT_LT_EXTENSIONS) @@ -363,10 +375,8 @@ def set_proxy_settings(self, ltsession: lt.session, ptype: int, server: tuple[st """ Apply the proxy settings to a libtorrent session. This mechanism changed significantly in libtorrent 1.1.0. """ - settings = {} - settings["proxy_type"] = ptype - settings["proxy_hostnames"] = True - settings["proxy_peer_connections"] = True + settings: dict[str, str | float] = {"proxy_type": ptype, "proxy_hostnames": True, + "proxy_peer_connections": True} if server is not None: proxy_host = server[0] if proxy_host: @@ -476,7 +486,7 @@ def process_alert(self, alert: lt.alert, hops: int = 0) -> None: # noqa: C901, if self.session_stats_callback: self.session_stats_callback(alert) - elif alert_type == "dht_pkt_alert": + elif alert_type == "dht_pkt_alert" and self.dht_health_manager is not None: # Unfortunately, the Python bindings don't have a direction attribute. # So, we'll have to resort to using the string representation of the alert instead. incoming = str(alert).startswith("<==") @@ -507,7 +517,7 @@ def update_ip_filter(self, lt_session: lt.session, ip_addresses: Iterable[str]) ip_filter.add_rule(ip, ip, 0) lt_session.set_ip_filter(ip_filter) - async def get_metainfo(self, infohash: bytes, timeout: float = 7, hops: int | None = None, + async def get_metainfo(self, infohash: bytes, timeout: float = 7, hops: int | None = None, # noqa: C901, PLR0912 url: str | None = None, raise_errors: bool = False) -> dict | None: """ Lookup metainfo for a given infohash. The mechanism works by joining the swarm for the infohash connecting @@ -526,8 +536,8 @@ async def get_metainfo(self, infohash: bytes, timeout: float = 7, hops: int | No logger.info("Trying to fetch metainfo for %s", infohash_hex) if infohash in self.metainfo_requests: - download = self.metainfo_requests[infohash][0] - self.metainfo_requests[infohash][1] += 1 + download = self.metainfo_requests[infohash].download + self.metainfo_requests[infohash].pending += 1 elif infohash in self.downloads: download = self.downloads[infohash] else: @@ -535,7 +545,8 @@ async def get_metainfo(self, infohash: bytes, timeout: float = 7, hops: int | No dcfg = DownloadConfig.from_defaults(self.config) dcfg.set_hops(hops or self.config.get("libtorrent/download_defaults/number_hops")) dcfg.set_upload_mode(True) # Upload mode should prevent libtorrent from creating files - dcfg.set_dest_dir(self.metadata_tmpdir.name) + if self.metadata_tmpdir is not None: + dcfg.set_dest_dir(self.metadata_tmpdir.name) try: download = await self.start_download(tdef=tdef, config=dcfg, hidden=True, checkpoint_disabled=True) except TypeError as e: @@ -543,7 +554,7 @@ async def get_metainfo(self, infohash: bytes, timeout: float = 7, hops: int | No if raise_errors: raise return None - self.metainfo_requests[infohash] = [download, 1] + self.metainfo_requests[infohash] = MetainfoLookup(download, 1) try: metainfo = download.tdef.get_metainfo() or await wait_for(shield(download.future_metainfo), timeout) @@ -565,8 +576,8 @@ async def get_metainfo(self, infohash: bytes, timeout: float = 7, hops: int | No }) if infohash in self.metainfo_requests: - self.metainfo_requests[infohash][1] -= 1 - if self.metainfo_requests[infohash][1] <= 0: + self.metainfo_requests[infohash].pending -= 1 + if self.metainfo_requests[infohash].pending <= 0: await self.remove_download(download, remove_content=True) self.metainfo_requests.pop(infohash, None) @@ -695,7 +706,7 @@ async def start_download(self, torrent_file: str | None = None, tdef: TorrentDef logger.info("ATP: %s", str({k: v for k, v in atp.items() if k not in ["resume_data"]})) # Keep metainfo downloads in self.downloads for now because we will need to remove it later, # and removing the download at this point will stop us from receiving any further alerts. - if infohash not in self.metainfo_requests or self.metainfo_requests[infohash][0] == download: + if infohash not in self.metainfo_requests or self.metainfo_requests[infohash].download == download: logger.info("Metainfo is not requested or download is the first in the queue.") self.downloads[infohash] = download logger.info("Starting handle.") @@ -717,11 +728,11 @@ async def start_handle(self, download: Download, atp: dict) -> None: ltsession = self.get_session(download.config.get_hops()) infohash = download.get_def().get_infohash() - if infohash in self.metainfo_requests and self.metainfo_requests[infohash][0] != download: + if infohash in self.metainfo_requests and self.metainfo_requests[infohash].download != download: logger.info("Cancelling metainfo request(s) for infohash:%s", hexlify(infohash)) - metainfo_dl, _ = self.metainfo_requests.pop(infohash) # Leave the checkpoint. Any checkpoint that exists will belong to the download we are currently starting. - await self.remove_download(metainfo_dl, remove_content=True, remove_checkpoint=False) + await self.remove_download(self.metainfo_requests.pop(infohash).download, + remove_content=True, remove_checkpoint=False) self.downloads[infohash] = download known = {h.info_hash().to_bytes(): h for h in ltsession.get_torrents()} @@ -1093,11 +1104,13 @@ def get_libtorrent_proxy_settings(self) -> tuple[int, tuple[str, str] | None, tu """ Get the settings for the libtorrent proxy. """ - proxy_server = str(self.config.get("libtorrent/proxy_server")) - proxy_server = proxy_server.split(":") if proxy_server else None + setting_proxy_server = str(self.config.get("libtorrent/proxy_server")).split(":") + proxy_server = ((setting_proxy_server[0], setting_proxy_server[1]) + if setting_proxy_server and len(setting_proxy_server) == 2 else None) - proxy_auth = str(self.config.get("libtorrent/proxy_auth")) - proxy_auth = proxy_auth.split(":") if proxy_auth else None + setting_proxy_auth = str(self.config.get("libtorrent/proxy_auth")).split(":") + proxy_auth = ((setting_proxy_auth[0], setting_proxy_auth[1]) + if setting_proxy_auth and len(setting_proxy_auth) == 2 else None) return self.config.get("libtorrent/proxy_type"), proxy_server, proxy_auth diff --git a/src/tribler/core/libtorrent/download_manager/download_state.py b/src/tribler/core/libtorrent/download_manager/download_state.py index 5ac4c96014..655a1d977e 100644 --- a/src/tribler/core/libtorrent/download_manager/download_state.py +++ b/src/tribler/core/libtorrent/download_manager/download_state.py @@ -108,7 +108,7 @@ def get_status(self) -> DownloadStatus: return DownloadStatus.STOPPED_ON_ERROR return DownloadStatus.STOPPED - def get_error(self) -> str: + def get_error(self) -> str | None: """ Returns the Exception that caused the download to be moved to STOPPED_ON_ERROR status. diff --git a/src/tribler/core/libtorrent/download_manager/stream.py b/src/tribler/core/libtorrent/download_manager/stream.py index a4d850ba34..a9196e46ea 100644 --- a/src/tribler/core/libtorrent/download_manager/stream.py +++ b/src/tribler/core/libtorrent/download_manager/stream.py @@ -20,8 +20,11 @@ import logging from asyncio import sleep -from typing import TYPE_CHECKING, Generator +from io import BufferedReader +from pathlib import Path +from typing import TYPE_CHECKING, Generator, cast +import libtorrent from typing_extensions import Self from tribler.core.libtorrent.download_manager.download_state import DownloadStatus @@ -81,27 +84,27 @@ def __init__(self, download: Download) -> None: Create a stream for the given download. """ self._logger = logging.getLogger(self.__class__.__name__) - self.infohash = None - self.filename = None - self.filesize = None - self.enabledfiles = None - self.firstpiece = None - self.lastpiece = None - self.prebuffsize = None - self.destdir = None - self.piecelen = None - self.files = None - self.mapfile = None - self.prebuffpieces = [] - self.headerpieces = [] - self.footerpieces = [] + self.infohash: bytes | None = None + self.filename: Path | None = None + self.filesize: int | None = None + self.enabledfiles: list[int] | None = None + self.firstpiece: int | None = None + self.lastpiece: int | None = None + self.prebuffsize: int | None = None + self.destdir: Path | None = None + self.piecelen: int | None = None + self.files: list[tuple[Path, int]] | None = None + self.mapfile: libtorrent.peer_request | None = None + self.prebuffpieces: list[int] = [] + self.headerpieces: list[int] = [] + self.footerpieces: list[int] = [] # cursorpiecemap represents the pieces maintained by all available chunks. # Each chunk is identified by its startbyte # structure for cursorpieces is # <-------------------- dynamic buffer pieces --------------------> - # {int:startbyte: [bool:ispaused, list:piecestobuffer 'according to the cursor of the related chunk'] - self.cursorpiecemap = {} - self.fileindex = None + # {int:startbyte: (bool:ispaused, list:piecestobuffer 'according to the cursor of the related chunk') + self.cursorpiecemap: dict[int, tuple[bool, list[int]]] = {} + self.fileindex: int | None = None # when first initiate this instance does not have related callback ready, # this coro will be awaited when the stream is enabled. If never enabled, # this coro will be closed. @@ -144,6 +147,12 @@ async def enable(self, fileindex: int = 0, prebufpos: int | None = None) -> None if not self.infohash: await self.__prepare_coro + self.destdir = cast(Path, self.destdir) + self.piecelen = cast(int, self.piecelen) + self.files = cast(list[tuple[Path, int]], self.files) + self.infohash = cast(bytes, self.infohash) + self.mapfile = cast(libtorrent.peer_request, self.mapfile) + # if fileindex not available for torrent raise exception if fileindex >= len(self.files): raise NoAvailableStreamError @@ -266,6 +275,8 @@ def bytestopieces(self, bytes_begin: int, bytes_end: int) -> list[int]: """ Returns the pieces that represents the given byte range. """ + self.filesize = cast(int, self.filesize) # Ensured by ``check_vod`` + bytes_begin = min(self.filesize, bytes_begin) if bytes_begin >= 0 else self.filesize + bytes_begin bytes_end = min(self.filesize, bytes_end) if bytes_end > 0 else self.filesize + bytes_end @@ -280,6 +291,8 @@ def bytetopiece(self, byte_begin: int) -> int: """ Finds the piece position that begin_bytes is mapped to. """ + self.mapfile = cast(libtorrent.peer_request, self.mapfile) # Ensured by ``check_vod`` + return self.mapfile(self.fileindex, byte_begin, 0).piece @check_vod(0) @@ -308,6 +321,9 @@ def iterpieces(self, have: bool | None = None, consec: bool = False, :param consec: True: sequentially, False: all pieces :param startfrom: int: start form index, None: start from first piece """ + self.firstpiece = cast(int, self.firstpiece) # Ensured by ``check_vod`` + self.lastpiece = cast(int, self.lastpiece) # Ensured by ``check_vod`` + if have is not None: pieces_have = self.pieceshave for piece in range(self.firstpiece, self.lastpiece + 1): @@ -341,7 +357,7 @@ def _updateprio(piece: int, prio: int, deadline: int | None = None) -> None: self.__resetdeadline(piece) diffmap[piece] = f"{piece}:-:{curr_prio}->{prio}" - def _find_deadline(piece: int) -> tuple[int, int]: + def _find_deadline(piece: int) -> tuple[int, int] | tuple[None, None]: """ Find the cursor which has this piece closest to its start. Returns the deadline for the piece and the cursor startbyte. @@ -356,7 +372,9 @@ def _find_deadline(piece: int) -> tuple[int, int]: (deadline is None or cursorpieces.index(piece) < deadline): deadline = cursorpieces.index(piece) cursor = startbyte - return deadline, cursor + if cursor is not None and deadline is not None: + return deadline, cursor + return None, None # current priorities piecepriorities = self.__getpieceprios() @@ -364,7 +382,7 @@ def _find_deadline(piece: int) -> tuple[int, int]: # this case might happen when hop count is changing. return # a map holds the changes, used only for logging purposes - diffmap = {} + diffmap: dict[int, str] = {} # flag that holds if we are in static buffering phase of dynamic buffering staticbuff = False for piece in self.iterpieces(have=False): @@ -384,7 +402,7 @@ def _find_deadline(piece: int) -> tuple[int, int]: else: # dynamic buffering deadline, cursor = _find_deadline(piece) - if cursor is not None: + if cursor is not None and deadline is not None: if deadline < len(DEADLINE_PRIO_MAP): # get prio according to deadline _updateprio(piece, DEADLINE_PRIO_MAP[deadline], deadline) @@ -438,7 +456,7 @@ def __init__(self, stream: Stream, startpos: int = 0) -> None: if not stream.enabled: raise NotStreamingError self.stream = stream - self.file = None + self.file: BufferedReader | None = None self.startpos = startpos self.__seekpos = self.startpos @@ -468,9 +486,11 @@ async def open(self) -> None: """ Opens the file in the filesystem until its ready and seeks to the seekpos position. """ - while not self.stream.filename.exists(): + filename = cast(Path, self.stream.filename) # Ensured by ``NotStreamingError`` (in ``__init__``) + + while not filename.exists(): await sleep(1) - self.file = open(self.stream.filename, 'rb') # noqa: ASYNC101, SIM115 + self.file = open(filename, "rb") # noqa: ASYNC101, SIM115 self.file.seek(self.seekpos) @property @@ -515,7 +535,7 @@ def pause(self, force: bool = False) -> bool: Sets the chunk pieces to pause, if not forced, chunk is only paused if other chunks are not paused. """ if not self.ispaused and (self.shouldpause or force): - self.stream.cursorpiecemap[self.startpos][0] = True + self.stream.cursorpiecemap[self.startpos] = True, self.stream.cursorpiecemap[self.startpos][1] return True return False @@ -524,7 +544,7 @@ def resume(self, force: bool = False) -> bool: Sets the chunk pieces to resume, if not forced, chunk is only resume if other chunks are paused. """ if self.ispaused and (not self.shouldpause or force): - self.stream.cursorpiecemap[self.startpos][0] = False + self.stream.cursorpiecemap[self.startpos] = False, self.stream.cursorpiecemap[self.startpos][1] return True return False @@ -533,6 +553,9 @@ async def seek(self, positionbyte: int) -> list[int]: Seeks the stream to the related picece that represents the position byte. Also updates the dynamic buffer accordingly. """ + self.stream.prebuffsize = cast(int, self.stream.prebuffsize) # Ensured by ``NotStreamingError`` + self.stream.piecelen = cast(int, self.stream.piecelen) # Ensured by ``NotStreamingError`` + buffersize = 0 pospiece = self.stream.bytetopiece(positionbyte) pieces = [] @@ -544,7 +567,7 @@ async def seek(self, positionbyte: int) -> list[int]: else: break # update cursor piece that represents this chunk - self.stream.cursorpiecemap[self.startpos] = [self.ispaused, pieces] + self.stream.cursorpiecemap[self.startpos] = (self.ispaused, pieces) # update the torrent prios await self.stream.updateprios() # update the file cursor also @@ -581,6 +604,8 @@ async def read(self) -> bytes: self._logger.debug('Chunk %s: Got no bytes, file is closed', self.startpos) return b'' + self.file = cast(BufferedReader, self.file) # Ensured by ``self.isclosed`` + # wait until we download what we want, then read the localfile # experiment a garbage write mechanism here if the torrent read is too slow piece = self.stream.bytetopiece(self.seekpos) diff --git a/src/tribler/core/libtorrent/restapi/downloads_endpoint.py b/src/tribler/core/libtorrent/restapi/downloads_endpoint.py index 3728dacd49..0f208f7370 100644 --- a/src/tribler/core/libtorrent/restapi/downloads_endpoint.py +++ b/src/tribler/core/libtorrent/restapi/downloads_endpoint.py @@ -5,7 +5,7 @@ from binascii import hexlify, unhexlify from contextlib import suppress from pathlib import Path, PurePosixPath -from typing import TYPE_CHECKING, TypedDict +from typing import TYPE_CHECKING, TypedDict, cast import libtorrent as lt from aiohttp import web @@ -18,7 +18,7 @@ from tribler.core.libtorrent.download_manager.download_config import DownloadConfig from tribler.core.libtorrent.download_manager.download_manager import DownloadManager from tribler.core.libtorrent.download_manager.download_state import DOWNLOAD, UPLOAD, DownloadStatus -from tribler.core.libtorrent.download_manager.stream import STREAM_PAUSE_TIME, StreamChunk +from tribler.core.libtorrent.download_manager.stream import STREAM_PAUSE_TIME, Stream, StreamChunk from tribler.core.restapi.rest_endpoint import ( HTTP_BAD_REQUEST, HTTP_INTERNAL_SERVER_ERROR, @@ -129,14 +129,14 @@ def get_files_info_json(download: Download) -> list[JSONFilesInfo]: files_completion = dict(download.get_state().get_files_completion()) selected_files = download.config.get_selected_files() for file_index, (fn, size) in enumerate(download.get_def().get_files_with_length()): - files_json.append({ + files_json.append(cast(JSONFilesInfo, { "index": file_index, # We always return files in Posix format to make GUI independent of Core and simplify testing "name": str(PurePosixPath(fn)), "size": size, "included": (file_index in selected_files or not selected_files), "progress": files_completion.get(fn, 0.0) - }) + })) return files_json @staticmethod @@ -470,6 +470,8 @@ async def vod_response(self, download: Download, parameters: dict, request: Requ status=HTTP_BAD_REQUEST) if download.stream is None: download.add_stream() + download.stream = cast(Stream, download.stream) + if not download.stream.enabled or download.stream.fileindex != file_index: await wait_for(download.stream.enable(file_index, request.http_range.start or 0), 10) await download.stream.updateprios() @@ -594,7 +596,7 @@ async def get_torrent(self, request: Request) -> RESTResponse: return RESTResponse(lt.bencode(torrent), headers={ "content-type": "application/x-bittorrent", - "Content-Disposition": f"attachment; filename={hexlify(infohash)}.torrent" + "Content-Disposition": f"attachment; filename={hexlify(infohash).decode()}.torrent" }) @docs( @@ -689,6 +691,9 @@ async def collapse_tree_directory(self, request: Request) -> RESTResponse: params = request.query path = params.get("path") + if not path: + return RESTResponse({"error": "path parameter missing"}, status=HTTP_BAD_REQUEST) + download.tdef.torrent_file_tree.collapse(Path(path)) return RESTResponse({"path": path}) @@ -727,6 +732,9 @@ async def expand_tree_directory(self, request: Request) -> RESTResponse: params = request.query path = params.get("path") + if not path: + return RESTResponse({"error": "path parameter missing"}, status=HTTP_BAD_REQUEST) + download.tdef.torrent_file_tree.expand(Path(path)) return RESTResponse({"path": path}) @@ -763,6 +771,9 @@ async def select_tree_path(self, request: Request) -> RESTResponse: params = request.query path = params.get("path") + if not path: + return RESTResponse({"error": "path parameter missing"}, status=HTTP_BAD_REQUEST) + download.set_selected_file_or_dir(Path(path), True) return RESTResponse({}) @@ -799,6 +810,9 @@ async def deselect_tree_path(self, request: Request) -> RESTResponse: params = request.query path = params.get("path") + if not path: + return RESTResponse({"error": "path parameter missing"}, status=HTTP_BAD_REQUEST) + download.set_selected_file_or_dir(Path(path), False) return RESTResponse({}) @@ -818,7 +832,7 @@ def _get_extended_status(self, download: Download) -> DownloadStatus: if download.config.get_hops() == 0: return DownloadStatus.STOPPED - if self.tunnel_community.get_candidates(PEER_FLAG_EXIT_BT): + if self.tunnel_community and self.tunnel_community.get_candidates(PEER_FLAG_EXIT_BT): return DownloadStatus.CIRCUITS return DownloadStatus.EXIT_NODES @@ -881,6 +895,7 @@ async def stream(self, request: Request) -> web.StreamResponse: # noqa: C901 if download.stream is None: download.add_stream() + download.stream = cast(Stream, download.stream) await wait_for(download.stream.enable(file_index, None if start > 0 else 0), 10) stop = download.stream.filesize if http_range.stop is None else min(http_range.stop, download.stream.filesize) @@ -901,7 +916,7 @@ async def stream(self, request: Request) -> web.StreamResponse: # noqa: C901 bytes_todo = stop - start bytes_done = 0 self._logger.info("Got range request for %s-%s (%s bytes)", start, stop, bytes_todo) - while not request.transport.is_closing(): + while request.transport is not None and not request.transport.is_closing(): if chunk.seekpos >= download.stream.filesize: break data = await chunk.read() @@ -927,4 +942,4 @@ async def stream(self, request: Request) -> web.StreamResponse: # noqa: C901 # there is no need to keep sequenial buffer if there are other chunks waiting for prios if chunk.pause(): self._logger.debug("Stream %s-%s is paused, stopping sequential buffer", start, stop) - return response + return response diff --git a/src/tribler/core/libtorrent/restapi/libtorrent_endpoint.py b/src/tribler/core/libtorrent/restapi/libtorrent_endpoint.py index 6cd5ac8b5f..a96b11273b 100644 --- a/src/tribler/core/libtorrent/restapi/libtorrent_endpoint.py +++ b/src/tribler/core/libtorrent/restapi/libtorrent_endpoint.py @@ -1,16 +1,22 @@ +from __future__ import annotations + from asyncio import Future from binascii import hexlify +from typing import TYPE_CHECKING -import libtorrent from aiohttp import web -from aiohttp.abc import Request from aiohttp_apispec import docs from ipv8.REST.schema import schema from marshmallow.fields import Integer -from tribler.core.libtorrent.download_manager.download_manager import DownloadManager from tribler.core.restapi.rest_endpoint import RESTEndpoint, RESTResponse +if TYPE_CHECKING: + import libtorrent + from aiohttp.abc import Request + + from tribler.core.libtorrent.download_manager.download_manager import DownloadManager + class LibTorrentEndpoint(RESTEndpoint): """ @@ -89,7 +95,7 @@ async def get_libtorrent_session_info(self, request: Request) -> RESTResponse: """ Return Libtorrent session information. """ - session_stats = Future() + session_stats: Future[dict[str, int]] = Future() def on_session_stats_alert_received(alert: libtorrent.alert) -> None: if not session_stats.done(): diff --git a/src/tribler/core/libtorrent/restapi/torrentinfo_endpoint.py b/src/tribler/core/libtorrent/restapi/torrentinfo_endpoint.py index ccad2d069a..f8cb934d44 100644 --- a/src/tribler/core/libtorrent/restapi/torrentinfo_endpoint.py +++ b/src/tribler/core/libtorrent/restapi/torrentinfo_endpoint.py @@ -69,7 +69,7 @@ async def query_uri(uri: str, connector: BaseConnector | None = None, headers: L """ Retrieve the response for the given aiohttp context. """ - kwargs = {"headers": headers} + kwargs: dict = {"headers": headers} if timeout: # ClientSession uses a sentinel object for the default timeout. Therefore, it should only be specified if an # actual value has been passed to this function. @@ -121,18 +121,19 @@ async def get_torrent_info(self, request: Request) -> RESTResponse: # noqa: C90 """ params = request.query hops = params.get("hops") - uri = params.get("uri") - self._logger.info("URI: %s", uri) + i_hops = 0 + p_uri = params.get("uri") + self._logger.info("URI: %s", p_uri) if hops: try: - hops = int(hops) + i_hops = int(hops) except ValueError: return RESTResponse({"error": f"wrong value of 'hops' parameter: {hops}"}, status=HTTP_BAD_REQUEST) - if not uri: + if not p_uri: return RESTResponse({"error": "uri parameter missing"}, status=HTTP_BAD_REQUEST) - uri = await unshorten(uri) + uri = await unshorten(p_uri) scheme = URL(uri).scheme if scheme == "file": @@ -151,6 +152,11 @@ async def get_torrent_info(self, request: Request) -> RESTResponse: # noqa: C90 self._logger.warning("Error while querying http uri: %s", str(e)) return RESTResponse({"error": str(e)}, status=HTTP_INTERNAL_SERVER_ERROR) + if not isinstance(response, bytes): + self._logger.warning("Error while reading response from http uri: %s", repr(response)) + return RESTResponse({"error": "Error while reading response from http uri"}, + status=HTTP_INTERNAL_SERVER_ERROR) + if response.startswith(b'magnet'): try: try: @@ -165,7 +171,7 @@ async def get_torrent_info(self, request: Request) -> RESTResponse: # noqa: C90 status=HTTP_INTERNAL_SERVER_ERROR ) - metainfo = await self.download_manager.get_metainfo(infohash, timeout=10.0, hops=hops, + metainfo = await self.download_manager.get_metainfo(infohash, timeout=10.0, hops=i_hops, url=response.decode()) else: metainfo = lt.bdecode(response) @@ -184,7 +190,7 @@ async def get_torrent_info(self, request: Request) -> RESTResponse: # noqa: C90 {"error": f'Error while getting an infohash from magnet: {e.__class__.__name__}: {e}'}, status=HTTP_BAD_REQUEST ) - metainfo = await self.download_manager.get_metainfo(infohash, timeout=10.0, hops=hops, url=uri) + metainfo = await self.download_manager.get_metainfo(infohash, timeout=10.0, hops=i_hops, url=uri) else: return RESTResponse({"error": "invalid uri"}, status=HTTP_BAD_REQUEST) @@ -209,6 +215,6 @@ async def get_torrent_info(self, request: Request) -> RESTResponse: # noqa: C90 ready_for_unicode = recursive_unicode(encoded_metainfo, ignore_errors=True) json_dump = json.dumps(ready_for_unicode, ensure_ascii=False) - encoded_metainfo = hexlify(json_dump.encode()).decode() - return RESTResponse({"metainfo": encoded_metainfo, + + return RESTResponse({"metainfo": hexlify(json_dump.encode()).decode(), "download_exists": download and not download_is_metainfo_request}) diff --git a/src/tribler/core/libtorrent/torrent_file_tree.py b/src/tribler/core/libtorrent/torrent_file_tree.py index abf9f33b5b..ef36e7fe1f 100644 --- a/src/tribler/core/libtorrent/torrent_file_tree.py +++ b/src/tribler/core/libtorrent/torrent_file_tree.py @@ -5,11 +5,9 @@ from bisect import bisect from dataclasses import dataclass, field from pathlib import Path -from typing import TYPE_CHECKING, Dict, Generator, ItemsView, Sequence, cast +from typing import TYPE_CHECKING, Dict, Generator, ItemsView, cast if TYPE_CHECKING: - from collections import defaultdict - import libtorrent @@ -24,7 +22,7 @@ class Directory: A directory that contains other directories and files. """ - directories: defaultdict[str, TorrentFileTree.Directory] = field(default_factory=dict) + directories: dict[str, TorrentFileTree.Directory] = field(default_factory=dict) files: list[TorrentFileTree.File] = field(default_factory=list) collapsed: bool = True size: int = 0 @@ -84,7 +82,7 @@ def tostr(self, depth: int = 0) -> str: """ return "\t" * depth + f"File({self.index}, {self.name}, {self.size} bytes)" - def sort_key(self) -> Sequence[int | str]: + def sort_key(self) -> tuple[int | str, ...]: """ Sort File instances using natural sort based on their names, which SHOULD be unique. """ @@ -114,17 +112,21 @@ def __ge__(self, other: TorrentFileTree.File) -> bool: """ return self.sort_key() >= other.sort_key() - def __eq__(self, other: TorrentFileTree.File) -> bool: + def __eq__(self, other: object) -> bool: """ Python 3.8 quirk/shortcoming is that File needs to be a SupportsRichComparisonT (instead of using a key). """ - return self.sort_key() == other.sort_key() + if isinstance(other, TorrentFileTree.File): + return self.sort_key() == other.sort_key() + return False - def __ne__(self, other: TorrentFileTree.File) -> bool: + def __ne__(self, other: object) -> bool: """ Python 3.8 quirk/shortcoming is that File needs to be a SupportsRichComparisonT (instead of using a key). """ - return self.sort_key() != other.sort_key() + if isinstance(other, TorrentFileTree.File): + return self.sort_key() != other.sort_key() + return True def __init__(self, file_storage: libtorrent.file_storage) -> None: """ @@ -263,7 +265,7 @@ def find_next_directory(self, from_path: Path) -> tuple[Directory, Path] | None: from_parts = from_path.parts for i in range(1, len(from_parts) + 1): parent_path = Path(os.sep.join(from_parts[:-i])) - parent = self.find(parent_path) + parent = cast(TorrentFileTree.Directory, self.find(parent_path)) dir_in_parent = from_parts[-i] dir_indices = list(parent.directories.keys()) # Python 3 "quirk": dict keys() order is stable index_in_parent = dir_indices.index(dir_in_parent) @@ -298,7 +300,7 @@ def _view_up_after_files(self, number: int, fetch_path: Path) -> list[str]: Run up the tree to the next available directory (if it exists) and continue building a view. """ next_dir_desc = self.find_next_directory(fetch_path) - view = [] + view: list[str] = [] if next_dir_desc is None: return view @@ -352,7 +354,7 @@ def view(self, start_path: tuple[Directory, Path] | Path, number: int) -> list[s if fetch_directory.collapsed: return self._view_up_after_files(number, fetch_path) - view = [] + view: list[str] = [] if self.path_is_dir(element_path): # This is a directory: loop through its directories, then process its files. view, number = self._view_process_directories(number, fetch_directory.directories.items(), fetch_path) diff --git a/src/tribler/core/libtorrent/torrentdef.py b/src/tribler/core/libtorrent/torrentdef.py index a65a2c3915..f5f1423901 100644 --- a/src/tribler/core/libtorrent/torrentdef.py +++ b/src/tribler/core/libtorrent/torrentdef.py @@ -10,7 +10,7 @@ from functools import cached_property from hashlib import sha1 from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List +from typing import TYPE_CHECKING, Any, Generator, Iterable, Literal, cast, overload import aiohttp import libtorrent as lt @@ -23,6 +23,122 @@ from os import PathLike + class FileDict(dict): # noqa: D101 + + @overload # type: ignore[override] + def __getitem__(self, key: Literal[b"length"]) -> int: ... + + @overload + def __getitem__(self, key: Literal[b"path"]) -> list[bytes]: ... + + @overload + def __getitem__(self, key: Literal[b"path.utf-8"]) -> list[bytes] | None: ... + + def __getitem__(self, key: bytes) -> Any: ... # noqa: D105 + + + class InfoDict(dict): # noqa: D101 + + @overload # type: ignore[override] + def __getitem__(self, key: Literal[b"files"]) -> list[FileDict]: ... + + @overload + def __getitem__(self, key: Literal[b"length"]) -> int: ... + + @overload + def __getitem__(self, key: Literal[b"name"]) -> bytes: ... + + @overload + def __getitem__(self, key: Literal[b"name.utf-8"]) -> bytes: ... + + @overload + def __getitem__(self, key: Literal[b"piece length"]) -> int: ... + + @overload + def __getitem__(self, key: Literal[b"pieces"]) -> bytes: ... + + def __getitem__(self, key: bytes) -> Any: ... # noqa: D105 + + + class MetainfoDict(dict): # noqa: D101 + + @overload # type: ignore[override] + def __getitem__(self, key: Literal[b"announce"]) -> bytes | None: ... + + @overload + def __getitem__(self, key: Literal[b"announce-list"]) -> list[list[bytes]] | None: ... + + @overload + def __getitem__(self, key: Literal[b"comment"]) -> bytes | None: ... + + @overload + def __getitem__(self, key: Literal[b"created by"]) -> bytes | None: ... + + @overload + def __getitem__(self, key: Literal[b"creation date"]) -> bytes | None: ... + + @overload + def __getitem__(self, key: Literal[b"encoding"]) -> bytes | None: ... + + @overload + def __getitem__(self, key: Literal[b"info"]) -> InfoDict: ... + + @overload + def __getitem__(self, key: Literal[b"httpseeds"]) -> list[bytes] | None: ... + + @overload + def __getitem__(self, key: Literal[b"nodes"]) -> list[bytes] | None: ... + + @overload + def __getitem__(self, key: Literal[b"urllist"]) -> list[bytes] | None: ... + + def __getitem__(self, key: bytes) -> Any: ... # noqa: D105 + + + class TorrentParameters(dict): # noqa: D101 + + @overload # type: ignore[override] + def __getitem__(self, key: Literal[b"announce"]) -> bytes | None: ... + + @overload + def __getitem__(self, key: Literal[b"announce-list"]) -> list[list[bytes]] | None: ... + + @overload + def __getitem__(self, key: Literal[b"comment"]) -> bytes | None: ... + + @overload + def __getitem__(self, key: Literal[b"created by"]) -> bytes | None: ... + + @overload + def __getitem__(self, key: Literal[b"creation date"]) -> bytes | None: ... + + @overload + def __getitem__(self, key: Literal[b"encoding"]) -> bytes | None: ... + + @overload + def __getitem__(self, key: Literal[b"httpseeds"]) -> list[bytes] | None: ... + + @overload + def __getitem__(self, key: Literal[b"name"]) -> bytes | None: ... + + @overload + def __getitem__(self, key: Literal[b"name.utf-8"]) -> bytes | None: ... + + @overload + def __getitem__(self, key: Literal[b"nodes"]) -> list[bytes] | None: ... + + @overload + def __getitem__(self, key: Literal[b"urllist"]) -> list[bytes] | None: ... + + def __getitem__(self, key: bytes) -> Any: ... # noqa: D105 + +else: + FileDict = dict + InfoDict = dict + MetainfoDict = dict + TorrentParameters = dict + + def escape_as_utf8(string: bytes, encoding: str = "utf8") -> str: """ Make a string UTF-8 compliant, destroying characters if necessary. @@ -50,7 +166,7 @@ def pathlist2filename(pathlist: Iterable[bytes]) -> Path: return Path(*(x.decode() for x in pathlist)) -def get_length_from_metainfo(metainfo: dict, selectedfiles: set[Path]) -> int: +def get_length_from_metainfo(metainfo: MetainfoDict, selectedfiles: set[Path]) -> int: """ Loop through all files in a torrent and calculate the total size. """ @@ -75,7 +191,8 @@ class TorrentDef: It can be used to create new torrents, or analyze existing ones. """ - def __init__(self, metainfo: dict | None = None, torrent_parameters: dict[bytes, Any] | None = None, + def __init__(self, metainfo: MetainfoDict | None = None, + torrent_parameters: TorrentParameters | None = None, ignore_validation: bool = True) -> None: """ Create a new TorrentDef object, possibly based on existing data. @@ -85,17 +202,17 @@ def __init__(self, metainfo: dict | None = None, torrent_parameters: dict[bytes, :param ignore_validation: Whether we ignore the libtorrent validation. """ self._logger = logging.getLogger(self.__class__.__name__) - self.torrent_parameters = {} - self.metainfo = metainfo - self.files_list = [] - self.infohash = None - self._torrent_info = None + self.torrent_parameters: TorrentParameters = cast(TorrentParameters, {}) + self.metainfo: MetainfoDict | None = metainfo + self.files_list: list[Path] = [] + self.infohash: bytes | None = None + self._torrent_info: lt.torrent_info | None = None - if metainfo is not None: + if self.metainfo is not None: # First, make sure the passed metainfo is valid if not ignore_validation: try: - self._torrent_info = lt.torrent_info(metainfo) + self._torrent_info = lt.torrent_info(self.metainfo) self.infohash = self._torrent_info.info_hash() except RuntimeError as exc: raise ValueError from exc @@ -109,30 +226,34 @@ def __init__(self, metainfo: dict | None = None, torrent_parameters: dict[bytes, raise ValueError from exc self.copy_metainfo_to_torrent_parameters() - elif torrent_parameters: + elif torrent_parameters is not None: self.torrent_parameters.update(torrent_parameters) - def copy_metainfo_to_torrent_parameters(self) -> None: + def copy_metainfo_to_torrent_parameters(self) -> None: # noqa: C901 """ Populate the torrent_parameters dictionary with information from the metainfo. """ - for key in [ - b"comment", - b"created by", - b"creation date", - b"announce", - b"announce-list", - b"nodes", - b"httpseeds", - b"urllist", - ]: - if self.metainfo and key in self.metainfo: - self.torrent_parameters[key] = self.metainfo[key] - - infokeys = [b"name", b"piece length"] - for key in infokeys: - if self.metainfo and key in self.metainfo[b"info"]: - self.torrent_parameters[key] = self.metainfo[b"info"][key] + if self.metainfo is not None: + if b"comment" in self.metainfo: + self.torrent_parameters[b"comment"] = self.metainfo[b"comment"] + if b"created by" in self.metainfo: + self.torrent_parameters[b"created by"] = self.metainfo[b"created by"] + if b"creation date" in self.metainfo: + self.torrent_parameters[b"creation date"] = self.metainfo[b"creation date"] + if b"announce" in self.metainfo: + self.torrent_parameters[b"announce"] = self.metainfo[b"announce"] + if b"announce-list" in self.metainfo: + self.torrent_parameters[b"announce-list"] = self.metainfo[b"announce-list"] + if b"nodes" in self.metainfo: + self.torrent_parameters[b"nodes"] = self.metainfo[b"nodes"] + if b"httpseeds" in self.metainfo: + self.torrent_parameters[b"httpseeds"] = self.metainfo[b"httpseeds"] + if b"urllist" in self.metainfo: + self.torrent_parameters[b"urllist"] = self.metainfo[b"urllist"] + if b"name" in self.metainfo[b"info"]: + self.torrent_parameters[b"name"] = self.metainfo[b"info"][b"name"] + if b"piece length" in self.metainfo[b"info"]: + self.torrent_parameters[b"piece length"] = self.metainfo[b"info"][b"piece length"] @property def torrent_info(self) -> lt.torrent_info | None: @@ -166,7 +287,7 @@ def torrent_file_tree(self) -> TorrentFileTree: """ Construct a file tree from this torrent definition. """ - return TorrentFileTree.from_lt_file_storage(self.torrent_info.files()) + return TorrentFileTree.from_lt_file_storage(self.torrent_info.files()) # type: ignore[union-attr] @staticmethod def _threaded_load_job(filepath: str | bytes | PathLike) -> TorrentDef: @@ -204,7 +325,7 @@ def load_from_memory(bencoded_data: bytes) -> TorrentDef: return TorrentDef.load_from_dict(metainfo) @staticmethod - def load_from_dict(metainfo: Dict) -> TorrentDef: + def load_from_dict(metainfo: MetainfoDict) -> TorrentDef: """ Load a metainfo dictionary into a TorrentDef object. @@ -279,19 +400,19 @@ def set_tracker(self, url: str) -> None: url = url[:-1] self.torrent_parameters[b"announce"] = url - def get_tracker(self) -> str | None: + def get_tracker(self) -> bytes | None: """ Returns the torrent announce URL. """ return self.torrent_parameters.get(b"announce", None) - def get_tracker_hierarchy(self) -> list[list[str]]: + def get_tracker_hierarchy(self) -> list[list[bytes]]: """ Returns the hierarchy of trackers. """ return self.torrent_parameters.get(b"announce-list", []) - def get_trackers(self) -> set[str]: + def get_trackers(self) -> set[bytes]: """ Returns a flat set of all known trackers. @@ -333,27 +454,19 @@ def get_nr_pieces(self) -> int: return 0 return len(self.metainfo[b"info"][b"pieces"]) // 20 - def get_pieces(self) -> List: - """ - Returns the pieces. - """ - if not self.metainfo: - return [] - return self.metainfo[b"info"][b"pieces"][:] - def get_infohash(self) -> bytes | None: """ Returns the infohash of the torrent, if metainfo is provided. Might be None if no metainfo is provided. """ return self.infohash - def get_metainfo(self) -> dict: + def get_metainfo(self) -> MetainfoDict | None: """ Returns the metainfo of the torrent. Might be None if no metainfo is provided. """ return self.metainfo - def get_name(self) -> bytes: + def get_name(self) -> bytes | None: """ Returns the name as raw string of bytes. """ @@ -363,7 +476,7 @@ def get_name_utf8(self) -> str: """ Not all names are utf-8, attempt to construct it as utf-8 anyway. """ - return escape_as_utf8(self.get_name(), self.get_encoding()) + return escape_as_utf8(self.get_name() or b"", self.get_encoding()) def set_name(self, name: bytes) -> None: """ @@ -376,45 +489,30 @@ def set_name(self, name: bytes) -> None: def get_name_as_unicode(self) -> str: """ Returns the info['name'] field as Unicode string. - """ - if self.metainfo and b"name.utf-8" in self.metainfo[b"info"]: - # There is an utf-8 encoded name. We assume that it is - # correctly encoded and use it normally - try: - return self.metainfo[b"info"][b"name.utf-8"].decode() - except UnicodeError: - pass - - if self.metainfo and b"name" in self.metainfo[b"info"]: - # Try to use the 'encoding' field. If it exists, it - # should contain something like 'utf-8' - if "encoding" in self.metainfo: - try: - return self.metainfo[b"info"][b"name"].decode(self.metainfo[b"encoding"]) - except UnicodeError: - pass - except LookupError: - # Some encodings are not supported by python. For - # instance, the MBCS codec which is used by - # Windows is not supported (Jan 2010) - pass - - # Try to convert the names in path to unicode, assuming - # that it was encoded as utf-8 - try: - return self.metainfo[b"info"][b"name"].decode() - except UnicodeError: - pass - - # Convert the names in path to unicode by replacing out - # all characters that may -even remotely- cause problems - # with the '?' character - try: - return self._filter_characters(self.metainfo[b"info"][b"name"]) - except UnicodeError: - pass - - # We failed. Returning an empty string + + If there is an utf-8 encoded name, we assume that it is correctly encoded and use it normally. + Otherwise, if there is an encoding[1], we attempt to decode the (bytes) name. + Otherwise, we attempt to decode the (bytes) name as UTF-8. + Otherwise, we attempt to replace non-UTF-8 characters from the (bytes) name with "?". + If all of the above fails, this returns an empty string. + + [1] Some encodings are not supported by python. For instance, the MBCS codec which is used by Windows is not + supported (Jan 2010). + """ + if self.metainfo is not None: + if b"name.utf-8" in self.metainfo[b"info"]: + with suppress(UnicodeError): + return self.metainfo[b"info"][b"name.utf-8"].decode() + + if (name := self.metainfo[b"info"].get(b"name")) is not None: + if (encoding := self.metainfo.get(b"encoding")) is not None: + with suppress(UnicodeError), suppress(LookupError): + return name.decode(encoding.decode()) + with suppress(UnicodeError): + return name.decode() + with suppress(UnicodeError): + return self._filter_characters(name) + return "" def save(self, torrent_filepath: str | None = None) -> None: @@ -431,7 +529,7 @@ def save(self, torrent_filepath: str | None = None) -> None: self.copy_metainfo_to_torrent_parameters() self.infohash = torrent_dict['infohash'] - def _get_all_files_as_unicode_with_length(self) -> Iterator[Path, int]: # noqa: C901, PLR0912 + def _get_all_files_as_unicode_with_length(self) -> Generator[tuple[Path, int], None, None]: # noqa: C901, PLR0912 """ Get a generator for files in the torrent def. No filtering is possible and all tricks are allowed to obtain a unicode list of filenames. @@ -440,7 +538,7 @@ def _get_all_files_as_unicode_with_length(self) -> Iterator[Path, int]: # noqa: """ if self.metainfo and b"files" in self.metainfo[b"info"]: # Multi-file torrent - files = self.metainfo[b"info"][b"files"] + files = cast(FileDict, self.metainfo[b"info"][b"files"]) for file_dict in files: if b"path.utf-8" in file_dict: @@ -455,10 +553,9 @@ def _get_all_files_as_unicode_with_length(self) -> Iterator[Path, int]: # noqa: if b"path" in file_dict: # Try to use the 'encoding' field. If it exists, it should contain something like 'utf-8'. - if b"encoding" in self.metainfo: - encoding = self.metainfo[b"encoding"].decode() + if (encoding := self.metainfo.get(b"encoding")) is not None: try: - yield (Path(*(element.decode(encoding) for element in file_dict[b"path"])), + yield (Path(*(element.decode(encoding.decode()) for element in file_dict[b"path"])), file_dict[b"length"]) continue except UnicodeError: @@ -486,9 +583,9 @@ def _get_all_files_as_unicode_with_length(self) -> Iterator[Path, int]: # noqa: elif self.metainfo: # Single-file torrent - yield self.get_name_as_unicode(), self.metainfo[b"info"][b"length"] + yield Path(self.get_name_as_unicode()), self.metainfo[b"info"][b"length"] - def get_files_with_length(self, exts: str | None = None) -> list[tuple[Path, int]]: + def get_files_with_length(self, exts: set[str] | None = None) -> list[tuple[Path, int]]: """ The list of files in the torrent def. @@ -518,7 +615,7 @@ def get_length(self, selectedfiles: set[Path] | None = None) -> int: :return: A length (long) """ - if self.metainfo: + if self.metainfo and selectedfiles is not None: return get_length_from_metainfo(self.metainfo, selectedfiles) return 0 @@ -562,8 +659,7 @@ def get_index_of_file_in_files(self, file: str | None) -> int: for i in range(len(info[b"files"])): file_dict = info[b"files"][i] - intorrentpath = (pathlist2filename(file_dict[b"path.utf-8"]) if b"path.utf-8" in file_dict - else pathlist2filename(file_dict[b"path"])) + intorrentpath = pathlist2filename(file_dict.get(b"path.utf-8", file_dict[b"path"])) if intorrentpath == Path(file): return i @@ -586,11 +682,9 @@ def __init__(self, infohash: bytes, name: bytes, url: bytes | str | None = None) """ Create a new valid torrent def without metainfo. """ - torrent_parameters = { - b"name": name - } + torrent_parameters: TorrentParameters = cast(TorrentParameters, {b"name": name}) if url is not None: - torrent_parameters[b"urllist"] = [url] + torrent_parameters[b"urllist"] = [url if isinstance(url, bytes) else url.encode()] super().__init__(torrent_parameters=torrent_parameters) self.infohash = infohash diff --git a/src/tribler/core/libtorrent/torrents.py b/src/tribler/core/libtorrent/torrents.py index f7dce22581..d0b5bdfdf1 100644 --- a/src/tribler/core/libtorrent/torrents.py +++ b/src/tribler/core/libtorrent/torrents.py @@ -5,27 +5,33 @@ from contextlib import suppress from hashlib import sha1 from os.path import getsize -from typing import TYPE_CHECKING, Any, Dict, Iterable, TypedDict +from typing import TYPE_CHECKING, Callable, Iterable, ParamSpec, TypedDict, TypeVar import libtorrent as lt if TYPE_CHECKING: from pathlib import Path + from tribler.core.libtorrent.download_manager.download import Download + from tribler.core.libtorrent.download_manager.stream import Stream + from tribler.core.libtorrent.torrentdef import InfoDict + logger = logging.getLogger(__name__) +WrappedParams = ParamSpec("WrappedParams") +WrappedReturn = TypeVar("WrappedReturn") +Wrapped = Callable[WrappedParams, WrappedReturn] -def check_handle(default=None): +def check_handle(default: WrappedReturn) -> Wrapped: """ Return the libtorrent handle if it's available, else return the default value. Author(s): Egbert Bouman """ - def wrap(f): - def invoke_func(*args, **kwargs): - download = args[0] - if download.handle and download.handle.is_valid(): + def wrap(f: Wrapped) -> Wrapped: + def invoke_func(self: Download, *args: WrappedParams.args, **kwargs: WrappedParams.kwargs) -> WrappedReturn: + if self.handle and self.handle.is_valid(): return f(*args, **kwargs) return default @@ -34,21 +40,22 @@ def invoke_func(*args, **kwargs): return wrap -def require_handle(func): +def require_handle(func: Wrapped) -> Wrapped: """ Invoke the function once the handle is available. Returns a future that will fire once the function has completed. Author(s): Egbert Bouman """ - def invoke_func(*args, **kwargs): - result_future = Future() + def invoke_func(self: Download, *args: WrappedParams.args, + **kwargs: WrappedParams.kwargs) -> Future[WrappedReturn | None]: + result_future: Future[WrappedReturn | None] = Future() - def done_cb(fut): + def done_cb(fut: Future[lt.torrent_handle]) -> None: with suppress(CancelledError): handle = fut.result() - if fut.cancelled() or result_future.done() or handle != download.handle or not handle.is_valid(): + if fut.cancelled() or result_future.done() or handle != self.handle or not handle.is_valid(): logger.warning('Can not invoke function, handle is not valid or future is cancelled') result_future.set_result(None) return @@ -65,21 +72,20 @@ def done_cb(fut): else: result_future.set_result(result) - download = args[0] - handle_future = download.get_handle() + handle_future = self.get_handle() handle_future.add_done_callback(done_cb) return result_future return invoke_func -def check_vod(default=None): +def check_vod(default: WrappedReturn) -> Wrapped: """ Check if torrent is vod mode, else return default. """ - def wrap(f): - def invoke_func(self, *args, **kwargs): + def wrap(f: Wrapped) -> Wrapped: + def invoke_func(self: Stream, *args: WrappedParams.args, **kwargs: WrappedParams.kwargs) -> WrappedReturn: if self.enabled: return f(self, *args, **kwargs) return default @@ -121,7 +127,7 @@ class TorrentFileResult(TypedDict): infohash: bytes -def create_torrent_file(file_path_list: list[Path], params: Dict[bytes, Any], # noqa: C901 +def create_torrent_file(file_path_list: list[Path], params: InfoDict, # noqa: C901 torrent_filepath: str | None = None) -> TorrentFileResult: """ Create a torrent file from the given paths and parameters. diff --git a/src/tribler/core/libtorrent/trackers.py b/src/tribler/core/libtorrent/trackers.py index 09d43e4163..d15cefa7a8 100644 --- a/src/tribler/core/libtorrent/trackers.py +++ b/src/tribler/core/libtorrent/trackers.py @@ -121,6 +121,10 @@ def _parse_tracker_url(tracker_url: str) -> tuple[str, tuple[str, int], str]: scheme = parsed_url.scheme port = parsed_url.port + if host is None: + msg = f"Could not resolve hostname from {tracker_url}." + raise MalformedTrackerURLException(msg) + if scheme not in SUPPORTED_SCHEMES: msg = f"Unsupported tracker type ({scheme})." raise MalformedTrackerURLException(msg) @@ -136,7 +140,7 @@ def _parse_tracker_url(tracker_url: str) -> tuple[str, tuple[str, int], str]: if not port: port = DEFAULT_PORTS[scheme] - return scheme, (host, port), path + return scheme, (host, port or 0), path def add_url_params(url: str, params: dict) -> str: diff --git a/src/tribler/core/notifier.py b/src/tribler/core/notifier.py index 021995a312..afc125ed6d 100644 --- a/src/tribler/core/notifier.py +++ b/src/tribler/core/notifier.py @@ -3,7 +3,7 @@ import typing from collections import defaultdict from enum import Enum -from typing import Callable, Optional +from typing import Callable from ipv8.messaging.anonymization.tunnel import Circuit @@ -15,7 +15,7 @@ class Desc(typing.NamedTuple): name: str fields: list[str] - types: list[type] + types: list[tuple[type, ...] | type] class Notification(Enum): @@ -45,7 +45,7 @@ class Notification(Enum): [bytes, bytes, int]) torrent_metadata_added = Desc("torrent_metadata_added", ["metadata"], [dict]) new_torrent_metadata_created = Desc("new_torrent_metadata_created", ["infohash", "title"], - [Optional[bytes], Optional[str]]) + [(bytes, type(None)), (str, type(None))]) class Notifier: @@ -57,10 +57,10 @@ def __init__(self) -> None: """ Create a new notifier. """ - self.observers = defaultdict(list) - self.delegates = set() + self.observers: dict[Notification, list[Callable[..., None]]] = defaultdict(list) + self.delegates: set[Callable[..., None]] = set() - def add(self, topic: Notification, observer: Callable) -> None: + def add(self, topic: Notification, observer: Callable[..., None]) -> None: """ Add an observer for the given Notification type. """ @@ -70,13 +70,12 @@ def notify(self, topic: Notification | str, /, **kwargs) -> None: """ Notify all observers that have subscribed to the given topic. """ - if isinstance(topic, str): - topic = getattr(Notification, topic) - topic_name, args, types = topic.value + notification = getattr(Notification, topic) if isinstance(topic, str) else topic + topic_name, args, types = notification.value if set(args) ^ set(kwargs.keys()): message = f"{topic_name} expecting arguments {args} (of types {types}) but received {kwargs}" raise ValueError(message) - for observer in self.observers[topic]: + for observer in self.observers[notification]: observer(**kwargs) for delegate in self.delegates: - delegate(topic, **kwargs) + delegate(notification, **kwargs) diff --git a/src/tribler/core/rendezvous/orm_bindings/certificate.py b/src/tribler/core/rendezvous/orm_bindings/certificate.py index b1bd7a54f5..68b330c557 100644 --- a/src/tribler/core/rendezvous/orm_bindings/certificate.py +++ b/src/tribler/core/rendezvous/orm_bindings/certificate.py @@ -1,17 +1,21 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Iterator from pony.orm import Database, Required from typing_extensions import Self if TYPE_CHECKING: import dataclasses - from collections.abc import Iterable + + + class IterRendezvousCertificate(type): # noqa: D101 + + def __iter__(cls) -> Iterator[RendezvousCertificate]: ... # noqa: D105 @dataclasses.dataclass - class RendezvousCertificate(metaclass=Iterable): + class RendezvousCertificate(metaclass=IterRendezvousCertificate): """ The database type for rendezvous certificates. """ diff --git a/src/tribler/core/restapi/events_endpoint.py b/src/tribler/core/restapi/events_endpoint.py index 4056a4066a..917189896c 100644 --- a/src/tribler/core/restapi/events_endpoint.py +++ b/src/tribler/core/restapi/events_endpoint.py @@ -60,7 +60,7 @@ def __init__(self, notifier: Notifier, public_key: str | None = None) -> None: self.undelivered_error: Exception | None = None self.public_key = public_key self.notifier = notifier - self.queue = Queue() + self.queue: Queue[MessageDict] = Queue() self.register_task("Process queue", self.process_queue) notifier.add(Notification.circuit_removed, self.on_circuit_removed) @@ -103,7 +103,7 @@ def initial_message(self) -> MessageDict: """ return { "topic": Notification.events_start.value.name, - "kwargs": {"public_key": self.public_key, "version": "Tribler Experimental"} + "kwargs": {"public_key": self.public_key or "", "version": "Tribler Experimental"} } def error_message(self, reported_error: Exception) -> MessageDict: @@ -120,12 +120,11 @@ def encode_message(self, message: MessageDict) -> bytes: Use JSON to dump the given message to bytes. """ try: - message = json.dumps(message) + return b"data: " + json.dumps(message).encode() + b"\n\n" except (UnicodeDecodeError, TypeError) as e: # The message contains invalid characters; fix them self._logger.exception("Event contains non-unicode characters, dropping %s", repr(message)) return self.encode_message(self.error_message(e)) - return b"data: " + message.encode() + b"\n\n" def has_connection_to_gui(self) -> bool: """ diff --git a/src/tribler/core/restapi/rest_endpoint.py b/src/tribler/core/restapi/rest_endpoint.py index b6cf4b9e5f..ec9cd2ded7 100644 --- a/src/tribler/core/restapi/rest_endpoint.py +++ b/src/tribler/core/restapi/rest_endpoint.py @@ -65,7 +65,7 @@ class RESTResponse(web.Response): JSON-compatible response bodies are automatically converted to JSON type. """ - def __init__(self, body: dict | list | bytes | None = None, headers: dict | None = None, + def __init__(self, body: dict | list | bytes | str | None = None, headers: dict | None = None, content_type: str | None = None, status: int = 200, **kwargs) -> None: """ Create a new rest response. diff --git a/src/tribler/core/restapi/rest_manager.py b/src/tribler/core/restapi/rest_manager.py index dbc66972ff..112c87d9f0 100644 --- a/src/tribler/core/restapi/rest_manager.py +++ b/src/tribler/core/restapi/rest_manager.py @@ -3,8 +3,9 @@ import logging import ssl import traceback +from asyncio.base_events import Server from pathlib import Path -from typing import TYPE_CHECKING, Awaitable, Callable, TypeVar +from typing import TYPE_CHECKING, Awaitable, Callable, TypeVar, cast from aiohttp import web from aiohttp.web_exceptions import HTTPNotFound, HTTPRequestEntityTooLarge @@ -174,23 +175,23 @@ async def start(self) -> None: if self.config.get("api/http_enabled"): self._logger.info("Http enabled") - await self.start_http_site() + await self.start_http_site(self.runner) if self.config.get("api/https_enabled"): self._logger.info("Https enabled") - await self.start_https_site() + await self.start_https_site(self.runner) self._logger.info("Swagger docs: http://%s:%d/docs", self.http_host, self.config.get("api/http_port")) self._logger.info("Swagger JSON: http://%s:%d/docs/swagger.json", self.http_host, self.config.get("api/http_port")) - async def start_http_site(self) -> None: + async def start_http_site(self, runner: web.AppRunner) -> None: """ Start serving HTTP requests. """ api_port = max(self.config.get("api/http_port"), 0) # if the value in config is <0 we convert it to 0 - self.site = web.TCPSite(self.runner, self.http_host, api_port, shutdown_timeout=self.shutdown_timeout) + self.site = web.TCPSite(runner, self.http_host, api_port, shutdown_timeout=self.shutdown_timeout) self._logger.info("Starting HTTP REST API server on port %d...", api_port) try: @@ -201,12 +202,12 @@ async def start_http_site(self) -> None: raise if not api_port: - api_port = self.site._server.sockets[0].getsockname()[1] # noqa: SLF001 + api_port = cast(Server, self.site._server).sockets[0].getsockname()[1] # noqa: SLF001 self.set_api_port(api_port) self._logger.info("HTTP REST API server started on port %d", api_port) - async def start_https_site(self) -> None: + async def start_https_site(self, runner: web.AppRunner) -> None: """ Start serving HTTPS requests. """ @@ -214,7 +215,7 @@ async def start_https_site(self) -> None: ssl_context.load_cert_chain(Path(self.config.get("state_dir")) / "https_certfile") port = self.config.get("api/https_port") - self.site_https = web.TCPSite(self.runner, self.https_host, port, ssl_context=ssl_context) + self.site_https = web.TCPSite(runner, self.https_host, port, ssl_context=ssl_context) await self.site_https.start() self._logger.info("Started HTTPS REST API: %s", self.site_https.name) diff --git a/src/tribler/core/socks5/aiohttp_connector.py b/src/tribler/core/socks5/aiohttp_connector.py index 49c46b0cef..7a3b00b396 100644 --- a/src/tribler/core/socks5/aiohttp_connector.py +++ b/src/tribler/core/socks5/aiohttp_connector.py @@ -44,8 +44,10 @@ def __init__(self, proxy_addr: tuple, **kwargs) -> None: super().__init__(**kwargs) self.proxy_addr = proxy_addr - async def _wrap_create_connection(self, protocol_factory: Callable[[], Socks5ClientUDPConnection], host: str, - port: int, **kwargs) -> tuple[BaseTransport, Socks5ClientUDPConnection]: + async def _wrap_create_connection(self, # type: ignore[override] + protocol_factory: Callable[[], Socks5ClientUDPConnection], + host: str, port: int, + **kwargs) -> tuple[BaseTransport, Socks5ClientUDPConnection]: """ Create a transport and its associated connection. """ diff --git a/src/tribler/core/socks5/client.py b/src/tribler/core/socks5/client.py index 9b500c2bec..4ac61b90a1 100644 --- a/src/tribler/core/socks5/client.py +++ b/src/tribler/core/socks5/client.py @@ -3,8 +3,8 @@ import ipaddress import logging import socket -from asyncio import BaseTransport, DatagramProtocol, Protocol, Queue, get_event_loop -from typing import Callable +from asyncio import BaseTransport, DatagramProtocol, DatagramTransport, Protocol, Queue, WriteTransport, get_event_loop +from typing import Callable, cast from ipv8.messaging.interfaces.udp.endpoint import DomainAddress from ipv8.messaging.serialization import PackError @@ -39,7 +39,7 @@ def __init__(self, callback: Callable[[bytes, DomainAddress | tuple], None]) -> Create a new Socks5 udp connection. """ self.callback = callback - self.transport = None + self.transport: DatagramTransport | None = None self.proxy_udp_addr = None self.logger = logging.getLogger(self.__class__.__name__) @@ -47,7 +47,7 @@ def connection_made(self, transport: BaseTransport) -> None: """ Callback for when a transport is available. """ - self.transport = transport + self.transport = cast(DatagramTransport, transport) def datagram_received(self, data: bytes, _: tuple) -> None: """ @@ -64,6 +64,8 @@ def sendto(self, data: bytes, target_addr: DomainAddress | tuple) -> None: """ Attempt to send the given data to the given address. """ + if self.transport is None: + return try: ipaddress.IPv4Address(target_addr[0]) except ipaddress.AddressValueError: @@ -77,23 +79,23 @@ class Socks5Client(Protocol): This object represents a minimal Socks5 client. Both TCP and UDP are supported. """ - def __init__(self, proxy_addr: tuple, callback: Callable[[bytes], None]) -> None: + def __init__(self, proxy_addr: tuple, callback: Callable[[bytes, DomainAddress | tuple], None]) -> None: """ Create a client for the given proxy address and call the given callback with incoming data. """ self.proxy_addr = proxy_addr self.callback = callback - self.transport = None - self.connection = None - self.connected_to = None - self.queue = Queue(maxsize=1) + self.transport: WriteTransport | None = None + self.connection: Socks5ClientUDPConnection | None = None + self.connected_to: DomainAddress | tuple | None = None + self.queue: Queue[bytes] = Queue(maxsize=1) def data_received(self, data: bytes) -> None: """ Callback for when data comes in. Call our registered callback or save the incoming save for calling back later. """ if self.connected_to: - self.callback(data) + self.callback(data, self.connected_to) elif self.queue.empty(): self.queue.put_nowait(data) @@ -103,11 +105,11 @@ def connection_lost(self, _: Exception | None) -> None: """ self.transport = None - async def _send(self, data: bytes) -> None: + async def _send(self, data: bytes) -> bytes: """ Send data to the remote and wait for an answer. """ - self.transport.write(data) + cast(WriteTransport, self.transport).write(data) return await self.queue.get() async def _login(self) -> None: @@ -192,7 +194,8 @@ async def associate_udp(self) -> None: Login and associate with the proxy. """ if self.connected: - msg = f"Client already used for connecting to {self.connected_to[0]}:{self.connected_to[1]}" + connection = cast(tuple, self.connected_to) + msg = f"Client already used for connecting to {connection[0]}:{connection[1]}" raise Socks5Error(msg) if not self.associated: @@ -208,7 +211,7 @@ def sendto(self, data: bytes, target_addr: tuple) -> None: if not self.associated: msg = "Not associated yet. First call associate_udp." raise Socks5Error(msg) - self.connection.sendto(data, target_addr) + cast(Socks5ClientUDPConnection, self.connection).sendto(data, target_addr) async def connect_tcp(self, target_addr: tuple) -> None: """ @@ -231,4 +234,4 @@ def write(self, data: bytes) -> None: if not self.connected: msg = "Not connected yet. First call connect_tcp." raise Socks5Error(msg) - return self.transport.write(data) + cast(WriteTransport, self.transport).write(data) diff --git a/src/tribler/core/socks5/connection.py b/src/tribler/core/socks5/connection.py index dd6c17c02d..c7a4feb72a 100644 --- a/src/tribler/core/socks5/connection.py +++ b/src/tribler/core/socks5/connection.py @@ -1,8 +1,8 @@ from __future__ import annotations import logging -from asyncio import BaseTransport, Protocol, ensure_future -from typing import TYPE_CHECKING +from asyncio import BaseTransport, Protocol, WriteTransport, ensure_future +from typing import TYPE_CHECKING, cast from ipv8.messaging.serialization import PackError @@ -52,20 +52,18 @@ def __init__(self, socksserver: Socks5Server) -> None: self._logger = logging.getLogger(self.__class__.__name__) self._logger.setLevel(logging.WARNING) self.socksserver = socksserver - self.transport = None + self.transport: WriteTransport | None = None self.connect_to = None - self.udp_connection = None + self.udp_connection: RustUDPConnection | SocksUDPConnection | None = None self.state = ConnectionState.BEFORE_METHOD_REQUEST self.buffer = b"" - self.destinations = {} - def connection_made(self, transport: BaseTransport) -> None: """ Callback for when a connection is made. """ - self.transport = transport + self.transport = cast(WriteTransport, transport) def data_received(self, data: bytes) -> None: """ @@ -109,7 +107,7 @@ def _try_handshake(self) -> bool: # Only accept NO AUTH if request.version != SOCKS_VERSION or 0x00 not in request.methods: self._logger.error("Client has sent INVALID METHOD REQUEST") - self.buffer = "" + self.buffer = b"" self.close() return False @@ -118,7 +116,7 @@ def _try_handshake(self) -> bool: # Respond that we would like to use NO AUTHENTICATION (0x00) if self.state is not ConnectionState.CONNECTED: response = socks5_serializer.pack_serializable(MethodsResponse(SOCKS_VERSION, 0)) - self.transport.write(response) + cast(WriteTransport, self.transport).write(response) # We are connected now, the next incoming message will be a REQUEST self.state = ConnectionState.CONNECTED @@ -153,7 +151,7 @@ def _try_request(self) -> bool: elif request.cmd == REQ_CMD_BIND: payload = CommandResponse(SOCKS_VERSION, REP_SUCCEEDED, 0, ("127.0.0.1", 1081)) response = socks5_serializer.pack_serializable(payload) - self.transport.write(response) + cast(WriteTransport, self.transport).write(response) self.state = ConnectionState.PROXY_REQUEST_ACCEPTED elif request.cmd == REQ_CMD_CONNECT: @@ -161,7 +159,7 @@ def _try_request(self) -> bool: self.connect_to = request.destination payload = CommandResponse(SOCKS_VERSION, REP_SUCCEEDED, 0, ("127.0.0.1", 1081)) response = socks5_serializer.pack_serializable(payload) - self.transport.write(response) + cast(WriteTransport, self.transport).write(response) else: self.deny_request() @@ -176,7 +174,7 @@ def deny_request(self) -> None: payload = CommandResponse(SOCKS_VERSION, REP_COMMAND_NOT_SUPPORTED, 0, ("0.0.0.0", 0)) response = socks5_serializer.pack_serializable(payload) - self.transport.write(response) + cast(WriteTransport, self.transport).write(response) self._logger.error("DENYING SOCKS5 request") async def on_udp_associate_request(self, request: CommandRequest) -> None: @@ -191,13 +189,13 @@ async def on_udp_associate_request(self, request: CommandRequest) -> None: else: self.udp_connection = SocksUDPConnection(self, request.destination) await self.udp_connection.open() - ip, _ = self.transport.get_extra_info('sockname') + ip, _ = cast(WriteTransport, self.transport).get_extra_info('sockname') port = self.udp_connection.get_listen_port() self._logger.info("Accepting UDP ASSOCIATE request to %s:%d (BIND addr %s:%d)", ip, port, *request.destination) payload = CommandResponse(SOCKS_VERSION, REP_SUCCEEDED, 0, (ip, port)) response = socks5_serializer.pack_serializable(payload) - self.transport.write(response) + cast(WriteTransport, self.transport).write(response) def connection_lost(self, _: Exception | None) -> None: """ diff --git a/src/tribler/core/socks5/conversion.py b/src/tribler/core/socks5/conversion.py index eb0f824020..2b2b053abe 100644 --- a/src/tribler/core/socks5/conversion.py +++ b/src/tribler/core/socks5/conversion.py @@ -143,8 +143,7 @@ def unpack(self, data: bytes, offset: int, unpack_list: list, *args: Any) -> int offset += 1 host = "" try: - host = data[offset:offset + domain_length] - host = host.decode() + host = data[offset:offset + domain_length].decode() except UnicodeDecodeError as e: msg = f"Could not decode host {host}" raise InvalidAddressException(msg) from e diff --git a/src/tribler/core/socks5/server.py b/src/tribler/core/socks5/server.py index a6a9261c98..3ba5bc1bf9 100644 --- a/src/tribler/core/socks5/server.py +++ b/src/tribler/core/socks5/server.py @@ -7,6 +7,8 @@ from tribler.core.socks5.connection import Socks5Connection if TYPE_CHECKING: + from asyncio.base_events import Server + from ipv8_rust_tunnels.endpoint import RustEndpoint from tribler.core.tunnel.dispatcher import TunnelDispatcher @@ -26,7 +28,7 @@ def __init__(self, hops: int, port: int | None = None, output_stream: TunnelDisp self.hops = hops self.port = port self.output_stream = output_stream - self.server = None + self.server: Server | None = None self.sessions: List[Socks5Connection] = [] self.rust_endpoint = rust_endpoint @@ -40,7 +42,8 @@ def build_protocol() -> Socks5Connection: self.sessions.append(socks5connection) return socks5connection - self.server = await get_event_loop().create_server(build_protocol, "127.0.0.1", self.port) + self.server = await get_event_loop().create_server(build_protocol, "127.0.0.1", + self.port) # type: ignore[arg-type] server_socket = self.server.sockets[0] _, self.port = server_socket.getsockname()[:2] self._logger.info("Started SOCKS5 server on port %i", self.port) diff --git a/src/tribler/core/socks5/udp_connection.py b/src/tribler/core/socks5/udp_connection.py index aa748f8ca0..446c319147 100644 --- a/src/tribler/core/socks5/udp_connection.py +++ b/src/tribler/core/socks5/udp_connection.py @@ -1,8 +1,8 @@ from __future__ import annotations import logging -from asyncio import DatagramProtocol, get_event_loop -from typing import TYPE_CHECKING +from asyncio import DatagramProtocol, DatagramTransport, get_event_loop +from typing import TYPE_CHECKING, cast from ipv8.messaging.serialization import PackError @@ -26,7 +26,7 @@ def __init__(self, socksconnection: Socks5Connection, remote_udp_address: Domain """ self._logger = logging.getLogger(self.__class__.__name__) self.socksconnection = socksconnection - self.transport = None + self.transport: DatagramTransport | None = None self.remote_udp_address = remote_udp_address if remote_udp_address != ("0.0.0.0", 0) else None async def open(self) -> None: @@ -39,7 +39,10 @@ def get_listen_port(self) -> int: """ Retrieve the listen port for this protocol. """ - _, port = self.transport.get_extra_info("sockname") + if self.transport: + _, port = self.transport.get_extra_info("sockname") + else: + port = 0 return port def send_datagram(self, data: bytes) -> bool: @@ -47,15 +50,21 @@ def send_datagram(self, data: bytes) -> bool: Send a datagram to the known remote address. Returns False if there is no remote yet. """ if self.remote_udp_address: - self.transport.sendto(data, self.remote_udp_address) + cast(DatagramTransport, self.transport).sendto(data, self.remote_udp_address) return True self._logger.error("cannot send data, no clue where to send it to") return False - def datagram_received(self, data: bytes, source: tuple) -> bool: + def datagram_received(self, data: bytes, source: tuple) -> None: """ The callback for when data is handed to our protocol. """ + self.datagram_received(data, source) + + def cb_datagram_received(self, data: bytes, source: tuple) -> bool: + """ + The callback for when data is handed to our protocol and whether the handling succeeded. + """ # If remote_address was not set before, use first one if self.remote_udp_address is None: self.remote_udp_address = source @@ -98,7 +107,7 @@ def __init__(self, rust_endpoint: RustEndpoint, hops: int) -> None: """ self.rust_endpoint = rust_endpoint self.hops = hops - self.port = None + self.port: int | None = None self.logger = logging.getLogger(self.__class__.__name__) @property @@ -129,7 +138,7 @@ def get_listen_port(self) -> int: """ Get the claimed port for this connection. """ - return self.port + return self.port or 0 def close(self) -> None: """ diff --git a/src/tribler/core/torrent_checker/torrent_checker.py b/src/tribler/core/torrent_checker/torrent_checker.py index fa8f919a24..a17157f330 100644 --- a/src/tribler/core/torrent_checker/torrent_checker.py +++ b/src/tribler/core/torrent_checker/torrent_checker.py @@ -7,7 +7,7 @@ from asyncio import CancelledError, DatagramTransport from binascii import hexlify from collections import defaultdict -from typing import TYPE_CHECKING, Any, Dict, List, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Tuple, cast from ipv8.taskmanager import TaskManager from pony.orm import db_session, desc, select @@ -78,9 +78,9 @@ def __init__(self, # noqa: PLR0913 self.socks_listen_ports = socks_listen_ports self._should_stop = False - self.sessions = defaultdict(list) + self.sessions: dict[str, list[TrackerSession]] = defaultdict(list) self.socket_mgr = UdpSocketManager() - self.udp_transport = None + self.udp_transport: DatagramTransport | None = None # We keep track of the results of popular torrents checked by you. # The content_discovery community gossips this information around. @@ -368,7 +368,8 @@ def create_session_for_request(self, tracker_url: str, timeout: float = 20) -> T self._logger.warning("Dropping the request. Required amount of hops not reached. " "Required hops: %d. Actual hops: %d", required_hops, actual_hops) return None - proxy = ('127.0.0.1', self.socks_listen_ports[required_hops - 1]) if required_hops > 0 else None + listen_ports = cast(list[int], self.socks_listen_ports) # Guaranteed by check above + proxy = ('127.0.0.1', listen_ports[required_hops - 1]) if required_hops > 0 else None session = create_tracker_session(tracker_url, timeout, proxy, self.socket_mgr) self._logger.info("Tracker session has been created: %s", str(session)) self.sessions[tracker_url].append(session) diff --git a/src/tribler/core/torrent_checker/torrentchecker_session.py b/src/tribler/core/torrent_checker/torrentchecker_session.py index 349781ba30..6937ad786c 100644 --- a/src/tribler/core/torrent_checker/torrentchecker_session.py +++ b/src/tribler/core/torrent_checker/torrentchecker_session.py @@ -4,11 +4,10 @@ import random import socket import struct -import sys import time from abc import ABCMeta, abstractmethod -from asyncio import BaseTransport, DatagramProtocol, Future, TimeoutError, ensure_future, get_event_loop -from typing import TYPE_CHECKING, List +from asyncio import DatagramProtocol, Future, TimeoutError, ensure_future, get_event_loop +from typing import TYPE_CHECKING, List, cast import async_timeout import libtorrent as lt @@ -56,8 +55,8 @@ def __init__(self, tracker_type: str, tracker_url: str, tracker_address: tuple[s # if this is a nonempty string it starts with '/'. self.announce_page = announce_page self.timeout = timeout - self.infohash_list = [] - self.last_contact = None + self.infohash_list: list[bytes] = [] + self.last_contact = 0 # some flags self.is_initiated = False # you cannot add requests to a session if it has been initiated @@ -75,7 +74,7 @@ async def cleanup(self) -> None: Shutdown and invalidate. """ await self.shutdown_task_manager() - self.infohash_list = None + self.infohash_list = [] def has_infohash(self, infohash: bytes) -> bool: """ @@ -217,11 +216,11 @@ def __init__(self) -> None: Create a new UDP socket protocol for trackers. """ self._logger = logging.getLogger(self.__class__.__name__) - self.tracker_sessions = {} - self.transport = None - self.proxy_transports = {} + self.tracker_sessions: dict[int, Future[bytes]] = {} + self.transport: Socks5Client | None = None + self.proxy_transports: dict[tuple, Socks5Client] = {} - def connection_made(self, transport: BaseTransport) -> None: + def connection_made(self, transport: Socks5Client) -> None: """ Callback for when a connection is established. """ @@ -231,7 +230,7 @@ async def send_request(self, data: bytes, tracker_session: UdpTrackerSession) -> """ Send a request and wait for the answer. """ - transport = self.transport + transport: Socks5Client | None = self.transport proxy = tracker_session.proxy if proxy: @@ -241,6 +240,9 @@ async def send_request(self, data: bytes, tracker_session: UdpTrackerSession) -> if proxy not in self.proxy_transports: self.proxy_transports[proxy] = transport + if transport is None: + return RuntimeError("Unable to write without transport") + host = tracker_session.ip_address or tracker_session.tracker_address[0] try: transport.sendto(data, (host, tracker_session.port)) @@ -277,7 +279,7 @@ class UdpTrackerSession(TrackerSession): """ # A list of transaction IDs that have been used in order to avoid conflict. - _active_session_dict = {} + _active_session_dict: dict[UdpTrackerSession, int] = {} def __init__(self, tracker_url: str, tracker_address: tuple[str, int], announce_page: str, # noqa: PLR0913 timeout: float, proxy: tuple, socket_mgr: UdpSocketManager) -> None: @@ -306,7 +308,7 @@ def generate_transaction_id(self) -> None: while True: # make sure there is no duplicated transaction IDs transaction_id = random.randint(0, 2147483647) - if transaction_id not in UdpTrackerSession._active_session_dict.items(): + if transaction_id not in UdpTrackerSession._active_session_dict.values(): UdpTrackerSession._active_session_dict[self] = transaction_id self.transaction_id = transaction_id break @@ -373,7 +375,11 @@ async def connect(self) -> None: # Initiate the connection message = struct.pack("!qii", self._connection_id, self.action, self.transaction_id) - response = await self.socket_mgr.send_request(message, self) + raw_response = await self.socket_mgr.send_request(message, self) + + if isinstance(raw_response, Exception): + self.failed(msg=str(raw_response)) + response = cast(bytes, raw_response) # check message size if len(response) < 16: @@ -401,17 +407,14 @@ async def scrape(self) -> TrackerResponse: """ Parse the response of a tracker. """ - # pack and send the message - if sys.version_info.major > 2: - infohash_list = self.infohash_list - else: - infohash_list = [str(infohash) for infohash in self.infohash_list] - fmt = "!qii" + ("20s" * len(self.infohash_list)) - message = struct.pack(fmt, self._connection_id, self.action, self.transaction_id, *infohash_list) + message = struct.pack(fmt, self._connection_id, self.action, self.transaction_id, *self.infohash_list) # Send the scrape message - response = await self.socket_mgr.send_request(message, self) + raw_response = await self.socket_mgr.send_request(message, self) + if isinstance(raw_response, Exception): + self.failed(msg=str(raw_response)) + response = cast(bytes, raw_response) # check message size if len(response) < 8: @@ -498,7 +501,7 @@ async def connect_to_tracker(self) -> TrackerResponse: """ coros = [self.download_manager.dht_health_manager.get_health(infohash, timeout=self.timeout) for infohash in self.infohash_list] - results = [] + results: list[HealthInfo] = [] for coroutine in coros: local_results = [result for result in (await coroutine) if not isinstance(result, Exception)] results = [*results, *local_results] diff --git a/src/tribler/core/torrent_checker/tracker_manager.py b/src/tribler/core/torrent_checker/tracker_manager.py index 6f58cbb44e..38b596bd79 100644 --- a/src/tribler/core/torrent_checker/tracker_manager.py +++ b/src/tribler/core/torrent_checker/tracker_manager.py @@ -29,7 +29,7 @@ def __init__(self, state_dir: Path | None = None, metadata_store: MetadataStore self.state_dir = state_dir self.TrackerState = metadata_store.TrackerState - self.blacklist = [] + self.blacklist: list[str] = [] self.load_blacklist() def load_blacklist(self) -> None: @@ -38,7 +38,7 @@ def load_blacklist(self) -> None: Entries are newline separated and are supposed to be sanitized. """ - blacklist_file = (Path(self.state_dir) / "tracker_blacklist.txt").absolute() + blacklist_file = (Path(self.state_dir or ".") / "tracker_blacklist.txt").absolute() if blacklist_file.exists(): with open(blacklist_file) as blacklist_file_handle: # Note that get_uniformed_tracker_url will strip the newline at the end of .readlines() diff --git a/src/tribler/core/tunnel/caches.py b/src/tribler/core/tunnel/caches.py index 97c223308a..930813380f 100644 --- a/src/tribler/core/tunnel/caches.py +++ b/src/tribler/core/tunnel/caches.py @@ -21,8 +21,8 @@ def __init__(self, community: TriblerTunnelCommunity, circuit_id: int) -> None: """ super().__init__(community.request_cache, "http-request") self.circuit_id = circuit_id - self.response = {} - self.response_future = Future() + self.response: dict[int, bytes] = {} + self.response_future: Future[bytes] = Future() self.register_future(self.response_future) def add_response(self, payload: HTTPResponsePayload) -> bool: diff --git a/src/tribler/core/tunnel/community.py b/src/tribler/core/tunnel/community.py index bf8dcaadb7..06ad145a46 100644 --- a/src/tribler/core/tunnel/community.py +++ b/src/tribler/core/tunnel/community.py @@ -88,10 +88,10 @@ def __init__(self, settings: TriblerTunnelSettings) -> None: self.logger.info("Using %s with flags %s", self.endpoint.__class__.__name__, self.settings.peer_flags) - self.bittorrent_peers = {} + self.bittorrent_peers: dict[Download, set[tuple[str, int]]] = {} self.dispatcher = TunnelDispatcher(self) - self.download_states = {} - self.last_forced_announce = {} + self.download_states: dict[bytes, DownloadStatus] = {} + self.last_forced_announce: dict[bytes, float] = {} if settings.socks_servers: self.dispatcher.set_socks_servers(settings.socks_servers) @@ -253,7 +253,7 @@ def _ours_on_created_extended(self, circuit_id: int, payload: CreatedPayload | E # Re-add BitTorrent peers, if needed. self.readd_bittorrent_peers() - def on_raw_data(self, circuit: Circuit, origin: int, data: bytes) -> None: + def on_raw_data(self, circuit: Circuit, origin: tuple[str, int], data: bytes) -> None: """ Let our dispatcher know that we have incoming data. """ @@ -357,9 +357,11 @@ def update_ip_filter(self, info_hash: bytes) -> None: Set the IP filter setting for the given infohash. """ download = self.get_download(info_hash) - lt_session = self.settings.download_manager.get_session(download.config.get_hops()) - ip_addresses = [self.circuit_id_to_ip(c.circuit_id) for c in self.find_circuits(ctype=CIRCUIT_TYPE_RP_SEEDER)] - self.settings.download_manager.update_ip_filter(lt_session, ip_addresses) + if download is not None: + lt_session = self.settings.download_manager.get_session(download.config.get_hops()) + ip_addresses = [self.circuit_id_to_ip(c.circuit_id) + for c in self.find_circuits(ctype=CIRCUIT_TYPE_RP_SEEDER)] + self.settings.download_manager.update_ip_filter(lt_session, ip_addresses) def get_download(self, lookup_info_hash: bytes) -> Download | None: """ diff --git a/src/tribler/core/tunnel/discovery.py b/src/tribler/core/tunnel/discovery.py index 8989394394..0eeff6962d 100644 --- a/src/tribler/core/tunnel/discovery.py +++ b/src/tribler/core/tunnel/discovery.py @@ -8,6 +8,8 @@ from ipv8.peerdiscovery.discovery import DiscoveryStrategy if TYPE_CHECKING: + from ipv8.types import Peer + from tribler.core.tunnel.community import TriblerTunnelCommunity @@ -32,7 +34,7 @@ def __init__(self, overlay: TriblerTunnelCommunity, golden_ratio: float = 9 / 16 super().__init__(overlay) self.golden_ratio = golden_ratio self.target_peers = target_peers - self.intro_sent = {} + self.intro_sent: dict[Peer, float] = {} assert target_peers > 0 assert 0.0 <= golden_ratio <= 1.0 diff --git a/src/tribler/core/tunnel/dispatcher.py b/src/tribler/core/tunnel/dispatcher.py index 87706b2661..41ba765799 100644 --- a/src/tribler/core/tunnel/dispatcher.py +++ b/src/tribler/core/tunnel/dispatcher.py @@ -19,6 +19,8 @@ if TYPE_CHECKING: from asyncio import Future + from ipv8.messaging.interfaces.udp.endpoint import DomainAddress, UDPv4Address + from tribler.core.socks5.connection import Socks5Connection from tribler.core.socks5.server import Socks5Server from tribler.core.socks5.udp_connection import RustUDPConnection, SocksUDPConnection @@ -37,13 +39,13 @@ def __init__(self, tunnels: TriblerTunnelCommunity) -> None: """ super().__init__() self.tunnels = tunnels - self.socks_servers = [] + self.socks_servers: list[Socks5Server] = [] # Map to keep track of the circuits associated with each destination. - self.con_to_cir = defaultdict(dict) + self.con_to_cir: dict[Socks5Connection, dict[DomainAddress | UDPv4Address, Circuit]] = defaultdict(dict) # Map to keep track of the circuit id to UDP connection. - self.cid_to_con = {} + self.cid_to_con: dict[int, Socks5Connection] = {} self.register_task("check_connections", self.check_connections, interval=30) @@ -53,7 +55,7 @@ def set_socks_servers(self, socks_servers: list[Socks5Server]) -> None: """ self.socks_servers = socks_servers - def on_incoming_from_tunnel(self, community: TriblerTunnelCommunity, circuit: Circuit, origin: int, + def on_incoming_from_tunnel(self, community: TriblerTunnelCommunity, circuit: Circuit, origin: tuple[str, int], data: bytes) -> bool: """ We received some data from the tunnel community. Dispatch it to the right UDP SOCKS5 socket. diff --git a/src/tribler/core/user_activity/manager.py b/src/tribler/core/user_activity/manager.py index 929aba6649..d1209c4200 100644 --- a/src/tribler/core/user_activity/manager.py +++ b/src/tribler/core/user_activity/manager.py @@ -30,12 +30,9 @@ def __init__(self, task_manager: TaskManager, session: Session, max_query_histor self.infohash_to_queries: dict[InfoHash, list[str]] = defaultdict(list) self.queries: OrderedDict[str, typing.Set[InfoHash]] = OrderedDict() self.max_query_history = max_query_history - self.database_manager = None - self.torrent_checker = None - self.task_manager = task_manager - self.database_manager: UserActivityLayer = session.db.user_activity self.torrent_checker: TorrentChecker = session.torrent_checker + self.task_manager = task_manager # Hook events session.notifier.add(Notification.torrent_finished, self.on_torrent_finished) diff --git a/src/tribler/gui/tribler_window.py b/src/tribler/gui/tribler_window.py index 5c84c2841f..20211ce534 100644 --- a/src/tribler/gui/tribler_window.py +++ b/src/tribler/gui/tribler_window.py @@ -13,6 +13,7 @@ from PyQt5.QtWidgets import (QAction, QApplication, QCompleter, QFileDialog, QLineEdit, QListWidget, QMainWindow, QShortcut, QStyledItemDelegate, QSystemTrayIcon, QTreeWidget) +from tribler.core.knowledge.rules.rules import extract_tags from tribler.gui.app_manager import AppManager from tribler.gui.core_manager import CoreManager from tribler.gui.debug_window import DebugWindow @@ -747,7 +748,7 @@ def on_top_search_bar_return_pressed(self): if not query_text: return - query = Query(original_query=query_text) + query = Query(query_text, *extract_tags(query_text)) if self.search_results_page.search(query): self._logger.info(f'Do search for query: {query}') self.deselect_all_menu_buttons() diff --git a/src/tribler/test_unit/core/database/restapi/test_database_endpoint.py b/src/tribler/test_unit/core/database/restapi/test_database_endpoint.py index aeed637fbc..4db30a5dd1 100644 --- a/src/tribler/test_unit/core/database/restapi/test_database_endpoint.py +++ b/src/tribler/test_unit/core/database/restapi/test_database_endpoint.py @@ -242,7 +242,7 @@ def test_build_snippets_empty(self) -> None: """ endpoint = DatabaseEndpoint(None, None, None) - value = endpoint.build_snippets([]) + value = endpoint.build_snippets(None, []) self.assertEqual([], value) @@ -252,7 +252,7 @@ def test_build_snippets_one_empty(self) -> None: """ endpoint = DatabaseEndpoint(None, None, None) - value = endpoint.build_snippets([{}]) + value = endpoint.build_snippets(None, [{}]) self.assertEqual([{}], value) @@ -264,7 +264,7 @@ def test_build_snippets_one_filled_no_knowledge(self) -> None: endpoint.tribler_db = Mock(knowledge=Mock(get_objects=Mock(return_value=[]))) search_result = {"infohash": "AA"} - value = endpoint.build_snippets([search_result]) + value = endpoint.build_snippets(endpoint.tribler_db, [search_result]) self.assertEqual([search_result], value) @@ -276,7 +276,7 @@ def test_build_snippets_one_filled_with_knowledge(self) -> None: endpoint.tribler_db = Mock(knowledge=Mock(get_objects=Mock(return_value=["AA"]))) search_result = {"infohash": "AA", "num_seeders": 1} - value = endpoint.build_snippets([search_result]) + value = endpoint.build_snippets(endpoint.tribler_db, [search_result]) self.assertEqual(SNIPPET, value[0]["type"]) self.assertEqual("", value[0]["category"]) @@ -294,7 +294,7 @@ def test_build_snippets_two_filled_with_knowledge(self) -> None: endpoint.tribler_db = Mock(knowledge=Mock(get_objects=Mock(return_value=["AA", "BB"]))) search_result = {"infohash": "AA", "num_seeders": 1} - value = endpoint.build_snippets([search_result]) + value = endpoint.build_snippets(endpoint.tribler_db, [search_result]) for snippet_id in range(2): self.assertEqual(SNIPPET, value[snippet_id]["type"]) @@ -315,7 +315,7 @@ def test_build_snippets_max_filled_with_knowledge(self) -> None: endpoint.tribler_db = Mock(knowledge=Mock(get_objects=Mock(return_value=mock_results))) search_result = {"infohash": "AA", "num_seeders": 1} - value = endpoint.build_snippets([search_result]) + value = endpoint.build_snippets(endpoint.tribler_db, [search_result]) self.assertEqual(SNIPPETS_TO_SHOW, len(value)) @@ -336,6 +336,7 @@ async def test_local_search_errored_search(self) -> None: The exception here stems from the ``mds`` being set to ``None``. """ endpoint = DatabaseEndpoint(None, None, None) + endpoint.tribler_db = Mock() response = await endpoint.local_search(SearchLocalRequest({})) @@ -346,8 +347,10 @@ async def test_local_search_no_knowledge(self) -> None: Test if performing a local search without a tribler db set returns mds results. """ endpoint = DatabaseEndpoint(None, None, None) + endpoint.tribler_db = Mock() endpoint.mds = Mock(run_threaded=self.mds_run_now, get_total_count=Mock(), get_max_rowid=Mock(), - get_entries=Mock(return_value=[Mock(to_simple_dict=Mock(return_value={"test": "test"}))])) + get_entries=Mock(return_value=[Mock(to_simple_dict=Mock(return_value={"test": "test", + "type": -1}))])) response = await endpoint.local_search(SearchLocalRequest({})) response_body_json = await response_to_json(response) @@ -364,9 +367,11 @@ async def test_local_search_no_knowledge_include_total(self) -> None: Test if performing a local search with requested total, includes a total. """ endpoint = DatabaseEndpoint(None, None, None) + endpoint.tribler_db = Mock() endpoint.mds = Mock(run_threaded=self.mds_run_now, get_total_count=Mock(return_value=1), get_max_rowid=Mock(return_value=7), - get_entries=Mock(return_value=[Mock(to_simple_dict=Mock(return_value={"test": "test"}))])) + get_entries=Mock(return_value=[Mock(to_simple_dict=Mock(return_value={"test": "test", + "type": -1}))])) response = await endpoint.local_search(SearchLocalRequest({"include_total": "I would like this"})) response_body_json = await response_to_json(response) diff --git a/src/tribler/test_unit/core/database/test_serialization.py b/src/tribler/test_unit/core/database/test_serialization.py index b6dff2176a..9b23bafb09 100644 --- a/src/tribler/test_unit/core/database/test_serialization.py +++ b/src/tribler/test_unit/core/database/test_serialization.py @@ -130,7 +130,7 @@ def test_get_magnet(self) -> None: id_=7, origin_id=1337, timestamp=10, infohash=b"\x01" * 20, size=42, torrent_date=int2time(0), title="test", tags="tags", tracker_info="") - self.assertEqual("magnet:?xt=urn:btih:0101010101010101010101010101010101010101&dn=b'test'", + self.assertEqual("magnet:?xt=urn:btih:0101010101010101010101010101010101010101&dn=test", payload.get_magnet()) def test_auto_convert_torrent_date(self) -> None: diff --git a/src/tribler/test_unit/core/libtorrent/download_manager/test_download.py b/src/tribler/test_unit/core/libtorrent/download_manager/test_download.py index 96ba9cbe96..fcfd4a71cb 100644 --- a/src/tribler/test_unit/core/libtorrent/download_manager/test_download.py +++ b/src/tribler/test_unit/core/libtorrent/download_manager/test_download.py @@ -52,7 +52,7 @@ def test_download_get_magnet_link_no_handle(self) -> None: """ Test if a download without a handle does not have a magnet link. """ - download = Download(TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT), checkpoint_disabled=True, + download = Download(TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT), None, checkpoint_disabled=True, config=self.create_mock_download_config()) self.assertIsNone(download.get_magnet_link()) @@ -61,7 +61,7 @@ def test_download_get_atp(self) -> None: """ Test if the atp can be retrieved from a download. """ - download = Download(TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT), checkpoint_disabled=True, + download = Download(TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT), None, checkpoint_disabled=True, config=self.create_mock_download_config()) atp = download.get_atp() @@ -75,7 +75,7 @@ def test_download_resume(self) -> None: """ Test if a download can be resumed. """ - download = Download(TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT), checkpoint_disabled=True, + download = Download(TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT), None, checkpoint_disabled=True, config=self.create_mock_download_config()) download.handle = Mock(is_valid=Mock(return_value=True)) @@ -89,7 +89,7 @@ async def test_save_resume(self) -> None: """ Test if a download is resumed after fetching the save/resume data. """ - download = Download(TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT), checkpoint_disabled=True, + download = Download(TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT), None, checkpoint_disabled=True, config=self.create_mock_download_config()) download.futures["save_resume_data"] = succeed(True) @@ -105,7 +105,7 @@ def test_move_storage(self) -> None: """ Test if storage can be moved. """ - download = Download(TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT), checkpoint_disabled=True, + download = Download(TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT), None, checkpoint_disabled=True, config=self.create_mock_download_config()) download.handle = Mock(is_valid=Mock(return_value=True)) @@ -119,7 +119,7 @@ def test_move_storage_no_metainfo(self) -> None: """ Test if storage is not moved for torrents without metainfo. """ - download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name"), checkpoint_disabled=True, + download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name"), None, checkpoint_disabled=True, config=self.create_mock_download_config()) download.handle = Mock(is_valid=Mock(return_value=True)) @@ -133,7 +133,7 @@ async def test_save_checkpoint_disabled(self) -> None: """ Test if checkpoints are not saved if checkpointing is disabled. """ - download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name"), checkpoint_disabled=True, + download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name"), None, checkpoint_disabled=True, config=self.create_mock_download_config()) download.handle = Mock(is_valid=Mock(return_value=False)) @@ -145,7 +145,7 @@ async def test_save_checkpoint_handle_no_data(self) -> None: """ Test if checkpoints are not saved if the handle specifies that it does not need resume data. """ - download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name"), checkpoint_disabled=True, + download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name"), None, checkpoint_disabled=True, config=self.create_mock_download_config()) download.checkpoint_disabled = False download.handle = Mock(is_valid=Mock(return_value=True), need_save_resume_data=Mock(return_value=False)) @@ -159,7 +159,7 @@ async def test_save_checkpoint_no_handle_no_existing(self) -> None: Test if checkpoints are saved for torrents without a handle and no existing checkpoint file. """ alerts = [] - download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name"), checkpoint_disabled=True, + download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name"), None, checkpoint_disabled=True, config=self.create_mock_download_config()) download.checkpoint_disabled = False download.download_manager = Mock(get_checkpoint_dir=Mock(return_value=Path("foo"))) @@ -180,7 +180,7 @@ async def test_save_checkpoint_no_handle_existing(self) -> None: Test if existing checkpoints are not overwritten by checkpoints without data. """ alerts = [] - download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name"), checkpoint_disabled=True, + download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name"), None, checkpoint_disabled=True, config=self.create_mock_download_config()) download.checkpoint_disabled = False download.download_manager = Mock(get_checkpoint_dir=Mock(return_value=Path("foo"))) @@ -197,7 +197,7 @@ def test_selected_files_default(self) -> None: """ Test if the default selected files are no files. """ - download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name"), checkpoint_disabled=True, + download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name"), None, checkpoint_disabled=True, config=self.create_mock_download_config()) download.handle = Mock(file_priorities=Mock(return_value=[0, 0])) @@ -208,7 +208,7 @@ def test_selected_files_last(self) -> None: """ Test if the last selected file in a list of files gets correctly selected. """ - download = Download(TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT), checkpoint_disabled=True, + download = Download(TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT), None, checkpoint_disabled=True, config=self.create_mock_download_config()) download.handle = Mock(file_priorities=Mock(return_value=[0, 4])) @@ -221,7 +221,7 @@ def test_selected_files_first(self) -> None: """ Test if the first selected file in a list of files gets correctly selected. """ - download = Download(TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT), checkpoint_disabled=True, + download = Download(TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT), None, checkpoint_disabled=True, config=self.create_mock_download_config()) download.handle = Mock(file_priorities=Mock(return_value=[4, 0])) @@ -234,7 +234,7 @@ def test_selected_files_all(self) -> None: """ Test if all files can be selected. """ - download = Download(TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT), checkpoint_disabled=True, + download = Download(TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT), None, checkpoint_disabled=True, config=self.create_mock_download_config()) download.handle = Mock(file_priorities=Mock(return_value=[4, 4])) @@ -247,7 +247,7 @@ def test_selected_files_all_through_none(self) -> None: """ Test if all files can be selected by selecting None. """ - download = Download(TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT), checkpoint_disabled=True, + download = Download(TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT), None, checkpoint_disabled=True, config=self.create_mock_download_config()) download.handle = Mock(file_priorities=Mock(return_value=[4, 4])) @@ -260,7 +260,7 @@ def test_selected_files_all_through_empty_list(self) -> None: """ Test if all files can be selected by selecting an empty list. """ - download = Download(TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT), checkpoint_disabled=True, + download = Download(TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT), None, checkpoint_disabled=True, config=self.create_mock_download_config()) download.handle = Mock(file_priorities=Mock(return_value=[4, 4])) @@ -273,7 +273,7 @@ def test_get_share_mode_enabled(self) -> None: """ Test if we forward the enabled share mode when requested in the download. """ - download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name"), checkpoint_disabled=True, + download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name"), None, checkpoint_disabled=True, config=self.create_mock_download_config()) download.config.set_share_mode(True) @@ -283,7 +283,7 @@ def test_get_share_mode_disabled(self) -> None: """ Test if we forward the disabled share mode when requested in the download. """ - download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name"), checkpoint_disabled=True, + download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name"), None, checkpoint_disabled=True, config=self.create_mock_download_config()) download.config.set_share_mode(False) @@ -293,7 +293,7 @@ async def test_enable_share_mode(self) -> None: """ Test if the share mode can be enabled in a download. """ - download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name"), checkpoint_disabled=True, + download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name"), None, checkpoint_disabled=True, config=self.create_mock_download_config()) download.handle = Mock(is_valid=Mock(return_value=True)) @@ -307,7 +307,7 @@ async def test_disable_share_mode(self) -> None: """ Test if the share mode can be disabled in a download. """ - download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name"), checkpoint_disabled=True, + download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name"), None, checkpoint_disabled=True, config=self.create_mock_download_config()) download.handle = Mock(is_valid=Mock(return_value=True)) @@ -321,7 +321,7 @@ def test_get_num_connected_seeds_peers_no_handle(self) -> None: """ Test if connected peers and seeds are 0 if there is no handle. """ - download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name"), checkpoint_disabled=True, + download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name"), None, checkpoint_disabled=True, config=self.create_mock_download_config()) num_seeds, num_peers = download.get_num_connected_seeds_peers() @@ -333,7 +333,7 @@ def test_get_num_connected_seeds_peers(self) -> None: """ Test if connected peers and seeds are correctly returned. """ - download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name"), checkpoint_disabled=True, + download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name"), None, checkpoint_disabled=True, config=self.create_mock_download_config()) download.handle = Mock(is_valid=Mock(return_value=True), get_peer_info=Mock(return_value=[ Mock(flags=140347, seed=1024), @@ -350,7 +350,7 @@ async def test_set_priority(self) -> None: """ Test if setting the priority calls the right methods in download. """ - download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name"), checkpoint_disabled=True, + download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name"), None, checkpoint_disabled=True, config=self.create_mock_download_config()) download.handle = Mock(is_valid=Mock(return_value=True)) @@ -363,7 +363,7 @@ def test_add_trackers(self) -> None: """ Test if trackers are added to the libtorrent handle. """ - download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name"), checkpoint_disabled=True, + download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name"), None, checkpoint_disabled=True, config=self.create_mock_download_config()) download.handle = Mock(is_valid=Mock(return_value=True)) @@ -376,7 +376,7 @@ def test_process_error_alert(self) -> None: """ Test if error alerts are processed correctly. """ - download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name"), checkpoint_disabled=True, + download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name"), None, checkpoint_disabled=True, config=self.create_mock_download_config()) download.process_alert(Mock(msg=None, status_code=123, url="http://google.com", @@ -389,7 +389,7 @@ def test_process_error_alert_timeout(self) -> None: """ Test if timeout error alerts are processed correctly. """ - download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name"), checkpoint_disabled=True, + download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name"), None, checkpoint_disabled=True, config=self.create_mock_download_config()) download.process_alert(Mock(msg=None, status_code=0, url="http://google.com", @@ -402,7 +402,7 @@ def test_process_error_alert_not_working(self) -> None: """ Test if not working error alerts are processed correctly. """ - download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name"), checkpoint_disabled=True, + download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name"), None, checkpoint_disabled=True, config=self.create_mock_download_config()) download.process_alert(Mock(msg=None, status_code=-1, url="http://google.com", @@ -415,7 +415,7 @@ def test_tracker_warning_alert(self) -> None: """ Test if a tracking warning alert is processed correctly. """ - download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name"), checkpoint_disabled=True, + download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name"), None, checkpoint_disabled=True, config=self.create_mock_download_config()) download.process_alert(Mock(message=Mock(return_value="test"), url="http://google.com", @@ -429,7 +429,7 @@ async def test_on_metadata_received_alert(self) -> None: Test if the right operations happen when we receive metadata. """ tdef = TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT) - download = Download(tdef, checkpoint_disabled=True, config=self.create_mock_download_config()) + download = Download(tdef, None, checkpoint_disabled=True, config=self.create_mock_download_config()) download.handle = Mock(torrent_file=Mock(return_value=download.tdef.torrent_info), trackers=Mock(return_value=[{"url": "http://google.com"}]), get_peer_info=Mock(return_value=[Mock(progress=1)] * 42 + [Mock(progress=0)] * 7)) @@ -448,7 +448,7 @@ def test_on_metadata_received_alert_unicode_error_encode(self) -> None: Test if no exception is raised when the url is not unicode compatible. """ tdef = TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT) - download = Download(tdef, checkpoint_disabled=True, config=self.create_mock_download_config()) + download = Download(tdef, None, checkpoint_disabled=True, config=self.create_mock_download_config()) download.handle = Mock(trackers=Mock(return_value=[{"url": "\uD800"}]), torrent_file=Mock(return_value=download.tdef.torrent_info), get_peer_info=Mock(return_value=[])) @@ -467,7 +467,7 @@ def test_on_metadata_received_alert_unicode_error_decode(self) -> None: See: https://github.com/Tribler/tribler/issues/7223 """ tdef = TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT) - download = Download(tdef, checkpoint_disabled=True, config=self.create_mock_download_config()) + download = Download(tdef, None, checkpoint_disabled=True, config=self.create_mock_download_config()) download.handle = Mock(trackers=lambda: [{"url": b"\xFD".decode()}], torrent_file=Mock(return_value=download.tdef.torrent_info), get_peer_info=Mock(return_value=[])) @@ -484,7 +484,7 @@ def test_metadata_received_invalid_torrent_with_error(self) -> None: Test if no torrent def is loaded when a RuntimeError/ValueError occurs when parsing the metadata. """ tdef = TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT) - download = Download(tdef, checkpoint_disabled=True, config=self.create_mock_download_config()) + download = Download(tdef, None, checkpoint_disabled=True, config=self.create_mock_download_config()) download.handle = Mock(trackers=Mock(return_value=[]), torrent_file=Mock(return_value=Mock(metadata=Mock(return_value=b""))), get_peer_info=Mock(return_value=[])) @@ -498,7 +498,7 @@ def test_torrent_checked_alert_no_pause_no_checkpoint(self) -> None: """ Test if no pause or checkpoint happens if the download state is such. """ - download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name"), checkpoint_disabled=True, + download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name"), None, checkpoint_disabled=True, config=self.create_mock_download_config()) download.checkpoint_disabled = False download.handle = Mock(is_valid=Mock(return_value=True), need_save_resume_data=Mock(return_value=False)) @@ -516,7 +516,7 @@ def test_torrent_checked_alert_no_pause_checkpoint(self) -> None: """ Test if no pause but a checkpoint happens if the download state is such. """ - download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name"), checkpoint_disabled=True, + download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name"), None, checkpoint_disabled=True, config=self.create_mock_download_config()) download.checkpoint_disabled = False download.handle = Mock(is_valid=Mock(return_value=True), need_save_resume_data=Mock(return_value=False)) @@ -534,7 +534,7 @@ def test_torrent_checked_alert_pause_no_checkpoint(self) -> None: """ Test if a pause but no checkpoint happens if the download state is such. """ - download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name"), checkpoint_disabled=True, + download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name"), None, checkpoint_disabled=True, config=self.create_mock_download_config()) download.checkpoint_disabled = False download.handle = Mock(is_valid=Mock(return_value=True), need_save_resume_data=Mock(return_value=False)) @@ -552,7 +552,7 @@ def test_torrent_checked_alert_pause_checkpoint(self) -> None: """ Test if both a pause and a checkpoint happens if the download state is such. """ - download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name"), checkpoint_disabled=True, + download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name"), None, checkpoint_disabled=True, config=self.create_mock_download_config()) download.checkpoint_disabled = False download.handle = Mock(is_valid=Mock(return_value=True), need_save_resume_data=Mock(return_value=False)) @@ -570,18 +570,18 @@ def test_tracker_reply_alert(self) -> None: """ Test if the tracker status is extracted from a reply alert. """ - download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name"), checkpoint_disabled=True, + download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name"), None, checkpoint_disabled=True, config=self.create_mock_download_config()) download.on_tracker_reply_alert(Mock(url="http://google.com", num_peers=42)) - self.assertEqual([42, "Working"], download.tracker_status["http://google.com"]) + self.assertEqual((42, "Working"), download.tracker_status["http://google.com"]) def test_get_pieces_bitmask(self) -> None: """ Test if a correct pieces bitmask is returned when requested. """ - download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name"), checkpoint_disabled=True, + download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name"), None, checkpoint_disabled=True, config=self.create_mock_download_config()) download.handle = Mock(status=Mock(return_value=Mock(pieces=[True, False, True, False, False]))) @@ -591,7 +591,7 @@ async def test_resume_data_failed(self) -> None: """ Test if an error is raised when loading resume data failed. """ - download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name"), checkpoint_disabled=True, + download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name"), None, checkpoint_disabled=True, config=self.create_mock_download_config()) future = download.wait_for_alert("save_resume_data_alert", None, "save_resume_data_failed_alert", @@ -605,7 +605,7 @@ async def test_on_state_changed_apply_ip_filter(self) -> None: """ Test if the ip filter gets enabled when in torrent status seeding (5) when hops are not zero. """ - download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name"), checkpoint_disabled=True, + download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name"), None, checkpoint_disabled=True, config=self.create_mock_download_config()) download.config.set_hops(1) download.handle = Mock(is_valid=Mock(return_value=True)) @@ -619,7 +619,7 @@ async def test_on_state_changed_no_filter(self) -> None: """ Test if the ip filter does not get enabled when the hop count is zero. """ - download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name"), checkpoint_disabled=True, + download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name"), None, checkpoint_disabled=True, config=self.create_mock_download_config()) download.config.set_hops(0) download.handle = Mock(is_valid=Mock(return_value=True)) @@ -633,7 +633,7 @@ async def test_on_state_changed_not_seeding(self) -> None: """ Test if the ip filter does not get enabled when the hop count is zero. """ - download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name"), checkpoint_disabled=True, + download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name"), None, checkpoint_disabled=True, config=self.create_mock_download_config()) download.config.set_hops(1) download.handle = Mock(is_valid=Mock(return_value=True)) @@ -647,7 +647,7 @@ async def test_checkpoint_timeout(self) -> None: """ Testing whether making a checkpoint times out when we receive no alert from libtorrent. """ - download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name"), checkpoint_disabled=True, + download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name"), None, checkpoint_disabled=True, config=self.create_mock_download_config()) download.futures["save_resume_data"] = [Future()] @@ -662,7 +662,7 @@ def test_on_save_resume_data_alert_permission_denied(self) -> None: """ Test if permission error in writing the download config does not crash the save resume alert handler. """ - download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name"), checkpoint_disabled=True, + download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name"), None, checkpoint_disabled=True, config=PermissionErrorDownloadConfig(self.create_mock_download_config().config)) download.checkpoint_disabled = False download.download_manager = Mock(get_checkpoint_dir=Mock(return_value=Path(__file__).absolute().parent)) @@ -678,7 +678,7 @@ def test_get_tracker_status_unicode_decode_error(self) -> None: See: https://github.com/Tribler/tribler/issues/7036 """ - download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name"), checkpoint_disabled=True, + download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name"), None, checkpoint_disabled=True, config=self.create_mock_download_config()) download.download_manager = Mock(get_session=Mock(return_value=Mock(is_dht_running=Mock(return_value=False)))) download.handle = Mock(is_valid=Mock(return_value=True), @@ -695,7 +695,7 @@ def test_get_tracker_status_get_peer_info_error(self) -> None: """ Test if a tracker status is returned when getting peer info leads to a RuntimeError. """ - download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name"), checkpoint_disabled=True, + download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name"), None, checkpoint_disabled=True, config=self.create_mock_download_config()) download.download_manager = Mock(get_session=Mock(return_value=Mock(is_dht_running=Mock(return_value=True)))) download.handle = Mock(is_valid=Mock(return_value=True), get_peer_info=Mock(side_effect=RuntimeError), @@ -710,7 +710,7 @@ async def test_shutdown(self) -> None: """ Test if the shutdown method closes the stream and clears the futures dictionary. """ - download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name"), checkpoint_disabled=True, + download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name"), None, checkpoint_disabled=True, config=self.create_mock_download_config()) download.stream = Mock() @@ -723,7 +723,7 @@ def test_file_piece_range_flat(self) -> None: """ Test if the piece range of a single-file torrent is correctly determined. """ - download = Download(TorrentDef.load_from_memory(TORRENT_UBUNTU_FILE_CONTENT), checkpoint_disabled=True, + download = Download(TorrentDef.load_from_memory(TORRENT_UBUNTU_FILE_CONTENT), None, checkpoint_disabled=True, config=self.create_mock_download_config()) total_pieces = download.tdef.torrent_info.num_pieces() @@ -735,7 +735,7 @@ def test_file_piece_range_minifiles(self) -> None: """ Test if the piece range of a file is correctly determined if multiple files exist in the same piece. """ - download = Download(TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT), checkpoint_disabled=True, + download = Download(TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT), None, checkpoint_disabled=True, config=self.create_mock_download_config()) piece_range_a = download.file_piece_range(Path("torrent_create") / "abc" / "file2.txt") piece_range_b = download.file_piece_range(Path("torrent_create") / "abc" / "file3.txt") @@ -750,7 +750,7 @@ def test_file_piece_range_wide(self) -> None: tdef = TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT) tdef.metainfo[b"info"][b"files"][0][b"length"] = 60000 tdef.metainfo[b"info"][b"pieces"] = b'\x01' * 80 - download = Download(tdef, checkpoint_disabled=True, config=self.create_mock_download_config()) + download = Download(tdef, None, checkpoint_disabled=True, config=self.create_mock_download_config()) file1 = download.file_piece_range(Path("torrent_create") / "abc" / "file2.txt") other_indices = [download.file_piece_range(Path("torrent_create") / Path( @@ -765,7 +765,7 @@ def test_file_piece_range_nonexistent(self) -> None: """ Test if the piece range of a non-existent file is correctly determined. """ - download = Download(TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT), checkpoint_disabled=True, + download = Download(TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT), None, checkpoint_disabled=True, config=self.create_mock_download_config()) piece_range = download.file_piece_range(Path("I don't exist")) @@ -776,7 +776,7 @@ def test_file_completion_full(self) -> None: """ Test if a complete file shows 1.0 completion. """ - download = Download(TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT), checkpoint_disabled=True, + download = Download(TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT), None, checkpoint_disabled=True, config=self.create_mock_download_config()) download.handle = Mock(is_valid=Mock(return_value=True), have_piece=Mock(return_value=True)) @@ -786,7 +786,7 @@ def test_file_completion_nonexistent(self) -> None: """ Test if an unknown path (does not exist in a torrent) shows 1.0 completion. """ - download = Download(TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT), checkpoint_disabled=True, + download = Download(TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT), None, checkpoint_disabled=True, config=self.create_mock_download_config()) download.handle = Mock(is_valid=Mock(return_value=True)) @@ -796,7 +796,7 @@ def test_file_completion_directory(self) -> None: """ Test if a directory (does not exist in a torrent) shows 1.0 completion. """ - download = Download(TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT), checkpoint_disabled=True, + download = Download(TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT), None, checkpoint_disabled=True, config=self.create_mock_download_config()) download.handle = Mock(is_valid=Mock(return_value=True), have_piece=Mock(return_value=True)) @@ -806,7 +806,7 @@ def test_file_completion_nohandle(self) -> None: """ Test if a file shows 0.0 completion if the torrent handle is not valid. """ - download = Download(TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT), checkpoint_disabled=True, + download = Download(TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT), None, checkpoint_disabled=True, config=self.create_mock_download_config()) download.handle = Mock(is_valid=Mock(return_value=False), have_piece=Mock(return_value=True)) @@ -816,7 +816,7 @@ def test_file_completion_partial(self) -> None: """ Test if a file shows 0.0 completion if the torrent handle is not valid. """ - download = Download(TorrentDef.load_from_memory(TORRENT_UBUNTU_FILE_CONTENT), checkpoint_disabled=True, + download = Download(TorrentDef.load_from_memory(TORRENT_UBUNTU_FILE_CONTENT), None, checkpoint_disabled=True, config=self.create_mock_download_config()) total_pieces = download.tdef.torrent_info.num_pieces() expected = (total_pieces // 2) / total_pieces @@ -834,7 +834,7 @@ def test_file_length(self) -> None: """ Test if we can get the length of a file. """ - download = Download(TorrentDef.load_from_memory(TORRENT_UBUNTU_FILE_CONTENT), checkpoint_disabled=True, + download = Download(TorrentDef.load_from_memory(TORRENT_UBUNTU_FILE_CONTENT), None, checkpoint_disabled=True, config=self.create_mock_download_config()) self.assertEqual(1150844928, download.get_file_length(Path("ubuntu-15.04-desktop-amd64.iso"))) @@ -843,7 +843,7 @@ def test_file_length_two(self) -> None: """ Test if we can get the length of a file in a multi-file torrent. """ - download = Download(TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT), checkpoint_disabled=True, + download = Download(TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT), None, checkpoint_disabled=True, config=self.create_mock_download_config()) self.assertEqual(6, download.get_file_length(Path("torrent_create") / "abc" / "file2.txt")) @@ -853,7 +853,7 @@ def test_file_length_nonexistent(self) -> None: """ Test if the length of a non-existent file is 0. """ - download = Download(TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT), checkpoint_disabled=True, + download = Download(TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT), None, checkpoint_disabled=True, config=self.create_mock_download_config()) self.assertEqual(0, download.get_file_length(Path("I don't exist"))) @@ -862,7 +862,7 @@ def test_file_index_unloaded(self) -> None: """ Test if a non-existent path leads to the special unloaded index. """ - download = Download(TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT), checkpoint_disabled=True, + download = Download(TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT), None, checkpoint_disabled=True, config=self.create_mock_download_config()) self.assertEqual(IllegalFileIndex.unloaded.value, download.get_file_index(Path("I don't exist"))) @@ -871,7 +871,7 @@ def test_file_index_directory_collapsed(self) -> None: """ Test if a collapsed-dir path leads to the special collapsed dir index. """ - download = Download(TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT), checkpoint_disabled=True, + download = Download(TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT), None, checkpoint_disabled=True, config=self.create_mock_download_config()) self.assertEqual(IllegalFileIndex.collapsed_dir.value, download.get_file_index(Path("torrent_create"))) @@ -880,7 +880,7 @@ def test_file_index_directory_expanded(self) -> None: """ Test if an expanded-dir path leads to the special expanded dir index. """ - download = Download(TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT), checkpoint_disabled=True, + download = Download(TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT), None, checkpoint_disabled=True, config=self.create_mock_download_config()) download.tdef.torrent_file_tree.expand(Path("torrent_create")) @@ -890,7 +890,7 @@ def test_file_index_file(self) -> None: """ Test if we can get the index of a file. """ - download = Download(TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT), checkpoint_disabled=True, + download = Download(TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT), None, checkpoint_disabled=True, config=self.create_mock_download_config()) self.assertEqual(1, download.get_file_index(Path("torrent_create") / "abc" / "file3.txt")) @@ -899,7 +899,7 @@ def test_file_selected_nonexistent(self) -> None: """ Test if a non-existent file does not register as selected. """ - download = Download(TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT), checkpoint_disabled=True, + download = Download(TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT), None, checkpoint_disabled=True, config=self.create_mock_download_config()) self.assertFalse(download.is_file_selected(Path("I don't exist"))) @@ -908,7 +908,7 @@ def test_file_selected_realfile(self) -> None: """ Test if a file starts off as selected. """ - download = Download(TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT), checkpoint_disabled=True, + download = Download(TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT), None, checkpoint_disabled=True, config=self.create_mock_download_config()) self.assertTrue(download.is_file_selected(Path("torrent_create") / "abc" / "file3.txt")) @@ -917,7 +917,7 @@ def test_file_selected_directory(self) -> None: """ Test if a directory does not register as selected. """ - download = Download(TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT), checkpoint_disabled=True, + download = Download(TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT), None, checkpoint_disabled=True, config=self.create_mock_download_config()) self.assertFalse(download.is_file_selected(Path("torrent_create") / "abc")) @@ -927,7 +927,7 @@ def test_on_torrent_finished_alert(self) -> None: Test if the torrent_finished notification is called when the torrent finishes. """ callback = Mock() - download = Download(TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT), checkpoint_disabled=True, + download = Download(TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT), None, checkpoint_disabled=True, config=self.create_mock_download_config()) download.stream = Mock() download.handle = Mock(is_valid=Mock(return_value=True), status=Mock(return_value=Mock(total_download=7))) diff --git a/src/tribler/test_unit/core/libtorrent/download_manager/test_download_manager.py b/src/tribler/test_unit/core/libtorrent/download_manager/test_download_manager.py index 6a4535b284..5b00ddd8a4 100644 --- a/src/tribler/test_unit/core/libtorrent/download_manager/test_download_manager.py +++ b/src/tribler/test_unit/core/libtorrent/download_manager/test_download_manager.py @@ -14,7 +14,7 @@ import tribler from tribler.core.libtorrent.download_manager.download import Download from tribler.core.libtorrent.download_manager.download_config import SPEC_CONTENT, DownloadConfig -from tribler.core.libtorrent.download_manager.download_manager import DownloadManager +from tribler.core.libtorrent.download_manager.download_manager import DownloadManager, MetainfoLookup from tribler.core.libtorrent.download_manager.download_state import DownloadState from tribler.core.libtorrent.torrentdef import TorrentDef, TorrentDefNoMetainfo from tribler.core.notifier import Notifier @@ -71,7 +71,7 @@ async def test_get_metainfo_valid_metadata(self) -> None: """ Testing if the metainfo is retrieved when the handle has valid metadata immediately. """ - download = Download(TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT), + download = Download(TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT), None, checkpoint_disabled=True, config=DownloadConfig(ConfigObj(StringIO(SPEC_CONTENT)))) download.handle = Mock(is_valid=Mock(return_value=True)) config = DownloadConfig(ConfigObj(StringIO(SPEC_CONTENT))) @@ -95,7 +95,7 @@ async def test_get_metainfo_duplicate_request(self) -> None: """ Test if the same request is returned when invoking get_metainfo twice with the same infohash. """ - download = Download(TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT), + download = Download(TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT), None, checkpoint_disabled=True, config=DownloadConfig(ConfigObj(StringIO(SPEC_CONTENT)))) download.handle = Mock(is_valid=Mock(return_value=True)) config = DownloadConfig(ConfigObj(StringIO(SPEC_CONTENT))) @@ -120,7 +120,7 @@ async def test_get_metainfo_with_already_added_torrent(self) -> None: """ Test if metainfo can be fetched for a torrent which is already in session. """ - download = Download(TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT), + download = Download(TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT), None, checkpoint_disabled=True, config=DownloadConfig(ConfigObj(StringIO(SPEC_CONTENT)))) download.handle = Mock(is_valid=Mock(return_value=True)) self.manager.downloads[download.tdef.infohash] = download @@ -131,11 +131,11 @@ async def test_start_download_while_getting_metainfo(self) -> None: """ Test if a torrent can be added while a metainfo request is running. """ - info_download = Download(TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT), + info_download = Download(TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT), None, checkpoint_disabled=True, config=self.create_mock_download_config()) info_download.handle = Mock(is_valid=Mock(return_value=True)) self.manager.downloads[info_download.tdef.infohash] = info_download - self.manager.metainfo_requests[info_download.tdef.infohash] = [info_download, 1] + self.manager.metainfo_requests[info_download.tdef.infohash] = MetainfoLookup(info_download, 1) tdef = TorrentDefNoMetainfo(info_download.tdef.infohash, b"name", f"magnet:?xt=urn:btih:{hexlify(info_download.tdef.infohash).decode()}&") @@ -167,7 +167,7 @@ async def test_start_handle_wait_for_dht_timeout(self) -> None: """ Test if start handle waits no longer than the set timeout for the DHT to be ready. """ - download = Download(TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT), + download = Download(TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT), None, checkpoint_disabled=True, config=self.create_mock_download_config()) download.handle = Mock(is_valid=Mock(return_value=True)) self.manager.dht_ready_task = Future() @@ -179,7 +179,7 @@ async def test_start_handle_wait_for_dht(self) -> None: """ Test if start handle waits for the DHT to be ready. """ - download = Download(TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT), + download = Download(TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT), None, checkpoint_disabled=True, config=self.create_mock_download_config()) download.handle = Mock(is_valid=Mock(return_value=True)) self.manager.dht_ready_task = Future() @@ -230,7 +230,7 @@ async def test_start_download_existing_download(self) -> None: """ Test if torrents can be added when there is a pre-existing download. """ - download = Download(TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT), + download = Download(TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT), None, checkpoint_disabled=True, config=self.create_mock_download_config()) self.manager.downloads[download.tdef.infohash] = download @@ -375,7 +375,7 @@ async def test_readd_download_safe_seeding(self) -> None: """ config = self.create_mock_download_config() config.set_safe_seeding(True) - download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name", None), checkpoint_disabled=True, + download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name", None), None, checkpoint_disabled=True, config=config) download.futures["save_resume_data"] = succeed(True) download_state = DownloadState(download, Mock(state=4, paused=False), None) @@ -399,7 +399,7 @@ def test_get_downloads_by_name(self) -> None: """ Test if downloads can be retrieved by name. """ - download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name", None), checkpoint_disabled=True, + download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name", None), None, checkpoint_disabled=True, config=DownloadConfig(ConfigObj(StringIO(SPEC_CONTENT)))) self.manager.downloads = {b"\x01" * 20: download} @@ -433,7 +433,7 @@ def test_update_trackers(self) -> None: """ Test if trackers can be updated for an existing download. """ - download = Download(TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT), checkpoint_disabled=True, + download = Download(TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT), None, checkpoint_disabled=True, config=self.create_mock_download_config()) self.manager.downloads[download.tdef.infohash] = download @@ -446,7 +446,7 @@ def test_update_trackers_list(self) -> None: """ Test if multiple trackers are correctly added as an announce list instead of a the singular announce. """ - download = Download(TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT), checkpoint_disabled=True, + download = Download(TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT), None, checkpoint_disabled=True, config=self.create_mock_download_config()) self.manager.downloads[download.tdef.infohash] = download @@ -460,7 +460,7 @@ def test_update_trackers_list_append(self) -> None: """ Test if trackers can be updated in sequence. """ - download = Download(TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT), checkpoint_disabled=True, + download = Download(TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT), None, checkpoint_disabled=True, config=self.create_mock_download_config()) self.manager.downloads[download.tdef.infohash] = download diff --git a/src/tribler/test_unit/core/libtorrent/download_manager/test_download_state.py b/src/tribler/test_unit/core/libtorrent/download_manager/test_download_state.py index 9497d74846..9096a177c2 100644 --- a/src/tribler/test_unit/core/libtorrent/download_manager/test_download_state.py +++ b/src/tribler/test_unit/core/libtorrent/download_manager/test_download_state.py @@ -27,7 +27,7 @@ def test_initialize(self) -> None: """ Test if DownloadState gets properly initialized from a download without a status. """ - download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name", None), checkpoint_disabled=True, + download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name", None), None, checkpoint_disabled=True, config=DownloadConfig(ConfigObj(StringIO(SPEC_CONTENT)))) download_state = DownloadState(download, None, None) @@ -48,7 +48,7 @@ def test_initialize_with_status(self) -> None: """ Test if DownloadState gets properly initialized from a download with a status. """ - download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name", None), checkpoint_disabled=True, + download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name", None), None, checkpoint_disabled=True, config=DownloadConfig(ConfigObj(StringIO(SPEC_CONTENT)))) download_state = DownloadState(download, libtorrent.torrent_status(), None) @@ -67,7 +67,7 @@ def test_initialize_with_mocked_status(self) -> None: """ Test if DownloadState gets properly initialized from a download with a mocked status. """ - download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name", None), checkpoint_disabled=True, + download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name", None), None, checkpoint_disabled=True, config=DownloadConfig(ConfigObj(StringIO(SPEC_CONTENT)))) download.config.set_selected_files(["test"]) download_state = DownloadState(download, Mock(num_pieces=6, pieces=[1, 1, 1, 0, 0, 0], progress=0.75, @@ -124,7 +124,7 @@ def test_get_files_completion(self) -> None: Each file is 6 bytes, so a file progress of 3 bytes is 0.5 completion. """ - download = Download(TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT), + download = Download(TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT), None, checkpoint_disabled=True, config=DownloadConfig(ConfigObj(StringIO(SPEC_CONTENT)))) download.handle = Mock(is_valid=Mock(return_value=True), file_progress=Mock(return_value=[3] * 6)) download_state = DownloadState(download, Mock(), None) @@ -138,7 +138,7 @@ def test_get_files_completion_no_progress(self) -> None: """ Testing if file progress is not given if no file progress is available. """ - download = Download(TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT), + download = Download(TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT), None, checkpoint_disabled=True, config=DownloadConfig(ConfigObj(StringIO(SPEC_CONTENT)))) download.handle = Mock(is_valid=Mock(return_value=True), file_progress=Mock(return_value=[])) download_state = DownloadState(download, Mock(), None) @@ -149,7 +149,7 @@ def test_get_files_completion_zero_length_file(self) -> None: """ Testing if file progress is 100% for a file of 0 bytes. """ - download = Download(TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT), + download = Download(TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT), None, checkpoint_disabled=True, config=DownloadConfig(ConfigObj(StringIO(SPEC_CONTENT)))) for file_spec in download.tdef.metainfo[b"info"][b"files"]: file_spec[b"length"] = 0 @@ -163,7 +163,7 @@ def test_get_availability_incomplete(self) -> None: """ Testing if the right availability of a file is returned if another peer has no pieces. """ - download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name", None), checkpoint_disabled=True, + download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name", None), None, checkpoint_disabled=True, config=DownloadConfig(ConfigObj(StringIO(SPEC_CONTENT)))) download.handle = Mock(is_valid=Mock(return_value=True), file_progress=Mock(return_value=[]), get_peer_info=Mock(return_value=[Mock(**TestDownloadState.base_peer_info, @@ -176,7 +176,7 @@ def test_get_availability_complete(self) -> None: """ Testing if the right availability of a file is returned if another peer has all pieces. """ - download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name", None), checkpoint_disabled=True, + download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name", None), None, checkpoint_disabled=True, config=DownloadConfig(ConfigObj(StringIO(SPEC_CONTENT)))) download.handle = Mock(is_valid=Mock(return_value=True), file_progress=Mock(return_value=[]), get_peer_info=Mock(return_value=[Mock(**TestDownloadState.base_peer_info, @@ -189,7 +189,7 @@ def test_get_availability_mixed(self) -> None: """ Testing if the right availability of a file is returned if one peer is complete and the other is not. """ - download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name", None), checkpoint_disabled=True, + download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name", None), None, checkpoint_disabled=True, config=DownloadConfig(ConfigObj(StringIO(SPEC_CONTENT)))) download.handle = Mock(is_valid=Mock(return_value=True), file_progress=Mock(return_value=[]), get_peer_info=Mock(return_value=[Mock(**TestDownloadState.base_peer_info, @@ -206,7 +206,7 @@ def test_get_files_completion_semivalid_handle(self) -> None: This case mirrors https://github.com/Tribler/tribler/issues/6454 """ - download = Download(TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT), + download = Download(TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT), None, checkpoint_disabled=True, config=DownloadConfig(ConfigObj(StringIO(SPEC_CONTENT)))) download.handle = Mock(is_valid=Mock(return_value=True), file_progress=Mock(side_effect=RuntimeError("invalid torrent handle used"))) diff --git a/src/tribler/test_unit/core/libtorrent/download_manager/test_stream.py b/src/tribler/test_unit/core/libtorrent/download_manager/test_stream.py index dcf585c779..8d624ab095 100644 --- a/src/tribler/test_unit/core/libtorrent/download_manager/test_stream.py +++ b/src/tribler/test_unit/core/libtorrent/download_manager/test_stream.py @@ -199,7 +199,8 @@ def create_mock_download(self) -> Download: conf.validate(Validator()) config = DownloadConfig(conf) config.set_dest_dir(Path("")) - download = Download(TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT), config, checkpoint_disabled=True) + download = Download(TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT), None, config, + checkpoint_disabled=True) download.handle = Mock(is_valid=Mock(return_value=True), file_priorities=Mock(return_value=[0] * 6), torrent_file=Mock(return_value=download.tdef.torrent_info)) download.lt_status = Mock(state=3, paused=False, pieces=[]) diff --git a/src/tribler/test_unit/core/libtorrent/restapi/test_downloads_endpoint.py b/src/tribler/test_unit/core/libtorrent/restapi/test_downloads_endpoint.py index 8338075b8d..28481fb620 100644 --- a/src/tribler/test_unit/core/libtorrent/restapi/test_downloads_endpoint.py +++ b/src/tribler/test_unit/core/libtorrent/restapi/test_downloads_endpoint.py @@ -298,7 +298,8 @@ def create_mock_download(self) -> Download: conf.validate(Validator()) config = DownloadConfig(conf) config.set_dest_dir(Path("")) - return Download(TorrentDefNoMetainfo(b"\x01" * 20, b"test"), config, hidden=False, checkpoint_disabled=True) + return Download(TorrentDefNoMetainfo(b"\x01" * 20, b"test"), None, config, hidden=False, + checkpoint_disabled=True) async def test_get_downloads_unloaded(self) -> None: """ @@ -337,7 +338,7 @@ async def test_get_downloads_hidden_download(self) -> None: """ Test if an empty list is returned if there are only hidden downloads. """ - self.set_loaded_downloads([Download(TorrentDefNoMetainfo(b"\x01" * 20, b"test"), Mock(), + self.set_loaded_downloads([Download(TorrentDefNoMetainfo(b"\x01" * 20, b"test"), None, Mock(), hidden=True, checkpoint_disabled=True)]) response = await self.endpoint.get_downloads(GetDownloadsRequest({})) diff --git a/src/tribler/test_unit/core/libtorrent/test_torrentdef.py b/src/tribler/test_unit/core/libtorrent/test_torrentdef.py index 5456b5c984..ccb5dd9e88 100644 --- a/src/tribler/test_unit/core/libtorrent/test_torrentdef.py +++ b/src/tribler/test_unit/core/libtorrent/test_torrentdef.py @@ -332,7 +332,7 @@ def test_torrent_no_metainfo(self) -> None: self.assertEqual(b"12345678901234567890", tdef.get_infohash()) self.assertEqual(0, tdef.get_length()) self.assertIsNone(tdef.get_metainfo()) - self.assertEqual("http://google.com", tdef.get_url()) + self.assertEqual(b"http://google.com", tdef.get_url()) self.assertFalse(tdef.is_multifile_torrent()) self.assertEqual("ubuntu.torrent", tdef.get_name_as_unicode()) self.assertEqual([], tdef.get_files()) diff --git a/src/tribler/test_unit/core/libtorrent/test_torrents.py b/src/tribler/test_unit/core/libtorrent/test_torrents.py index 7847df78be..034b0df0d3 100644 --- a/src/tribler/test_unit/core/libtorrent/test_torrents.py +++ b/src/tribler/test_unit/core/libtorrent/test_torrents.py @@ -30,7 +30,7 @@ def test_check_handle_default_missing_handle(self) -> None: """ Test if the default value is returned for missing handles. """ - download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name", None), checkpoint_disabled=True, + download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name", None), None, checkpoint_disabled=True, config=DownloadConfig(ConfigObj(StringIO(SPEC_CONTENT)))) self.assertEqual("default", (check_handle("default")(Download.get_def)(download))) @@ -39,7 +39,7 @@ def test_check_handle_default_invalid_handle(self) -> None: """ Test if the default value is returned for invalid handles. """ - download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name", None), checkpoint_disabled=True, + download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name", None), None, checkpoint_disabled=True, config=DownloadConfig(ConfigObj(StringIO(SPEC_CONTENT)))) download.handle = Mock(is_valid=Mock(return_value=False)) @@ -49,7 +49,7 @@ def test_check_handle_default_valid_handle(self) -> None: """ Test if the given method is called for valid handles. """ - download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name", None), checkpoint_disabled=True, + download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name", None), None, checkpoint_disabled=True, config=DownloadConfig(ConfigObj(StringIO(SPEC_CONTENT)))) download.handle = Mock(is_valid=Mock(return_value=True)) @@ -59,7 +59,7 @@ async def test_require_handle_invalid_handle(self) -> None: """ Test if None is returned for invalid handles. """ - download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name", None), checkpoint_disabled=True, + download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name", None), None, checkpoint_disabled=True, config=DownloadConfig(ConfigObj(StringIO(SPEC_CONTENT)))) download.handle = Mock(is_valid=Mock(return_value=False)) @@ -71,7 +71,7 @@ async def test_require_handle_valid_handle(self) -> None: """ Test if the result of the given method is given for valid handles. """ - download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name", None), checkpoint_disabled=True, + download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name", None), None, checkpoint_disabled=True, config=DownloadConfig(ConfigObj(StringIO(SPEC_CONTENT)))) download.handle = Mock(is_valid=Mock(return_value=True)) @@ -89,7 +89,7 @@ def callback(_: Download) -> None: """ raise RuntimeError - download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name", None), checkpoint_disabled=True, + download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name", None), None, checkpoint_disabled=True, config=DownloadConfig(ConfigObj(StringIO(SPEC_CONTENT)))) download.handle = Mock(is_valid=Mock(return_value=True)) @@ -108,7 +108,7 @@ def callback(_: Download) -> None: """ raise ValueError - download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name", None), checkpoint_disabled=True, + download = Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name", None), None, checkpoint_disabled=True, config=DownloadConfig(ConfigObj(StringIO(SPEC_CONTENT)))) download.handle = Mock(is_valid=Mock(return_value=True)) @@ -120,7 +120,7 @@ async def test_check_vod_disabled(self) -> None: """ Test if the default value is returned for disabled vod mode. """ - stream = Stream(Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name", None), checkpoint_disabled=True, + stream = Stream(Download(TorrentDefNoMetainfo(b"\x01" * 20, b"name", None), None, checkpoint_disabled=True, config=DownloadConfig(ConfigObj(StringIO(SPEC_CONTENT))))) stream.close() @@ -132,7 +132,7 @@ async def test_check_vod_enabled(self) -> None: """ Test if the function result is returned for enabled vod mode. """ - download = Download(TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT), + download = Download(TorrentDef.load_from_memory(TORRENT_WITH_DIRS_CONTENT), None, checkpoint_disabled=True, config=DownloadConfig(ConfigObj(StringIO(SPEC_CONTENT)))) download.handle = Mock(is_valid=Mock(return_value=True), file_priorities=Mock(return_value=[0]), torrent_file=Mock(return_value=Mock(map_file=Mock(return_value=Mock(piece=0))))) diff --git a/src/tribler/test_unit/core/restapi/test_events_endpoint.py b/src/tribler/test_unit/core/restapi/test_events_endpoint.py index bf54e6d996..2798bf9163 100644 --- a/src/tribler/test_unit/core/restapi/test_events_endpoint.py +++ b/src/tribler/test_unit/core/restapi/test_events_endpoint.py @@ -103,7 +103,7 @@ async def test_establish_connection(self) -> None: self.assertEqual(200, response.status) self.assertEqual((b'data: {' b'"topic": "events_start", ' - b'"kwargs": {"public_key": null, "version": "Tribler Experimental"}' + b'"kwargs": {"public_key": "", "version": "Tribler Experimental"}' b'}\n\n'), request.payload_writer.captured[0]) async def test_establish_connection_with_error(self) -> None: diff --git a/src/tribler/test_unit/core/socks5/test_client.py b/src/tribler/test_unit/core/socks5/test_client.py index 817a03ec5d..e931b3c81c 100644 --- a/src/tribler/test_unit/core/socks5/test_client.py +++ b/src/tribler/test_unit/core/socks5/test_client.py @@ -46,24 +46,24 @@ def test_data_received_connected(self) -> None: """ Test if data is fed to the registered callback when a connection is open. """ - callback = [] - client = MockSocks5Client(None, callback.append) + callback = Mock() + client = MockSocks5Client(None, callback) client.connected_to = ("localhost", 80) client.data_received(b"test") - self.assertEqual(b"test", callback[0]) + self.assertEqual(call(b"test", ("localhost", 80)), callback.call_args) async def test_data_received_queue_unconnected(self) -> None: """ Test if data is put in a single-item queue when no connection is open. """ - callback = [] - client = MockSocks5Client(None, callback.append) + callback = Mock() + client = MockSocks5Client(None, callback) client.data_received(b"test") - self.assertEqual([], callback) + self.assertEqual(None, callback.call_args) self.assertTrue(client.queue.full()) self.assertEqual(b"test", await client.queue.get()) diff --git a/src/tribler/test_unit/core/socks5/test_udp_connection.py b/src/tribler/test_unit/core/socks5/test_udp_connection.py index 129ae61ac8..ae69ed2e67 100644 --- a/src/tribler/test_unit/core/socks5/test_udp_connection.py +++ b/src/tribler/test_unit/core/socks5/test_udp_connection.py @@ -74,7 +74,7 @@ async def test_datagram_received_first(self) -> None: connection = MockSocksUDPConnection(socks_connection, None) await connection.open() - value = connection.datagram_received(b"\x00\x00\x00\x03\tlocalhost\x0590x000", ("localhost", 1337)) + value = connection.cb_datagram_received(b"\x00\x00\x00\x03\tlocalhost\x0590x000", ("localhost", 1337)) udp_payload = socks_connection.socksserver.output_stream.on_socks5_udp_data.call_args.args[1] self.assertTrue(value) @@ -92,7 +92,7 @@ async def test_datagram_received_wrong_source(self) -> None: connection = MockSocksUDPConnection(socks_connection, ("localhost", 1337)) await connection.open() - value = connection.datagram_received(b"\x00\x00\x00\x03\tlocalhost\x0590x000", ("notlocalhost", 1337)) + value = connection.cb_datagram_received(b"\x00\x00\x00\x03\tlocalhost\x0590x000", ("notlocalhost", 1337)) self.assertFalse(value) @@ -104,7 +104,7 @@ async def test_datagram_received_garbage(self) -> None: connection = MockSocksUDPConnection(socks_connection, ("localhost", 1337)) await connection.open() - value = connection.datagram_received(b"\x00", ("localhost", 1337)) + value = connection.cb_datagram_received(b"\x00", ("localhost", 1337)) self.assertFalse(value) @@ -116,7 +116,7 @@ async def test_datagram_received_fragmented(self) -> None: connection = MockSocksUDPConnection(socks_connection, ("localhost", 1337)) await connection.open() - value = connection.datagram_received(b"\x00\x00\x01\x03\tlocalhost\x0590x000", ("localhost", 1337)) + value = connection.cb_datagram_received(b"\x00\x00\x01\x03\tlocalhost\x0590x000", ("localhost", 1337)) self.assertFalse(value)