diff --git a/src/tribler/core/components/libtorrent/download_manager/download_manager.py b/src/tribler/core/components/libtorrent/download_manager/download_manager.py index 529553bbe13..d150bc092e5 100644 --- a/src/tribler/core/components/libtorrent/download_manager/download_manager.py +++ b/src/tribler/core/components/libtorrent/download_manager/download_manager.py @@ -586,8 +586,7 @@ def start_download(self, torrent_file=None, tdef=None, config: DownloadConfig = download = self.get_download(infohash) if download and infohash not in self.metainfo_requests: - new_trackers = list(set(tdef.get_trackers_as_single_tuple()) - - set(download.get_def().get_trackers_as_single_tuple())) + new_trackers = list(tdef.get_trackers() - download.get_def().get_trackers()) if new_trackers: self.update_trackers(tdef.get_infohash(), new_trackers) return download @@ -779,8 +778,8 @@ def update_trackers(self, infohash, trackers): download = self.get_download(infohash) if download: old_def = download.get_def() - old_trackers = old_def.get_trackers_as_single_tuple() - new_trackers = list(set(trackers) - set(old_trackers)) + old_trackers = old_def.get_trackers() + new_trackers = list(set(trackers) - old_trackers) all_trackers = list(old_trackers) + new_trackers if new_trackers: diff --git a/src/tribler/core/components/libtorrent/restapi/downloads_endpoint.py b/src/tribler/core/components/libtorrent/restapi/downloads_endpoint.py index 2fb61f35440..4b9950fd9b6 100644 --- a/src/tribler/core/components/libtorrent/restapi/downloads_endpoint.py +++ b/src/tribler/core/components/libtorrent/restapi/downloads_endpoint.py @@ -44,13 +44,13 @@ def _safe_extended_peer_info(ext_peer_info): """ # First see if we can use this as-is if not ext_peer_info: - return '' + return "" try: return ensure_unicode(ext_peer_info, "utf8") except UnicodeDecodeError: # We might have some special unicode characters in here - return ''.join([chr(c) for c in ext_peer_info]) + return "".join(map(chr, ext_peer_info)) def get_extended_status(tunnel_community, download) -> DownloadStatus: diff --git a/src/tribler/core/components/libtorrent/restapi/tests/test_downloads_endpoint.py b/src/tribler/core/components/libtorrent/restapi/tests/test_downloads_endpoint.py index cd2b8654456..756df417067 100644 --- a/src/tribler/core/components/libtorrent/restapi/tests/test_downloads_endpoint.py +++ b/src/tribler/core/components/libtorrent/restapi/tests/test_downloads_endpoint.py @@ -1,11 +1,13 @@ import collections import os +import unittest.mock from unittest.mock import Mock import pytest from aiohttp.web_app import Application from ipv8.util import fail, succeed +import tribler.core.components.libtorrent.restapi.downloads_endpoint as download_endpoint from tribler.core.components.libtorrent.download_manager.download_state import DownloadState from tribler.core.components.libtorrent.restapi.downloads_endpoint import DownloadsEndpoint, get_extended_status from tribler.core.components.restapi.rest.base_api_test import do_request @@ -145,6 +147,16 @@ def test_get_extended_status_circuits(mock_extended_status): assert mock_extended_status == DownloadStatus.CIRCUITS +@unittest.mock.patch("tribler.core.components.libtorrent.restapi.downloads_endpoint.ensure_unicode", + Mock(side_effect=UnicodeDecodeError("", b"", 0, 0, ""))) +def test_safe_extended_peer_info(): + """ + Test that we return the string mapped by `chr` in the case of `UnicodeDecodeError` + """ + extended_peer_info = download_endpoint._safe_extended_peer_info(b"abcd") # pylint: disable=protected-access + assert extended_peer_info == "abcd" + + async def test_get_downloads_if_checkpoints_are_not_loaded(mock_dlmgr, rest_api): mock_dlmgr.checkpoints_count = 10 mock_dlmgr.checkpoints_loaded = 5 diff --git a/src/tribler/core/components/libtorrent/tests/test_download_manager.py b/src/tribler/core/components/libtorrent/tests/test_download_manager.py index d01c050316c..2834c5668f6 100644 --- a/src/tribler/core/components/libtorrent/tests/test_download_manager.py +++ b/src/tribler/core/components/libtorrent/tests/test_download_manager.py @@ -263,7 +263,7 @@ def test_start_download_existing_download(fake_dlmgr): infohash = b'a' * 20 mock_download = MagicMock() - mock_download.get_def = lambda: MagicMock(get_trackers_as_single_tuple=lambda: ()) + mock_download.get_def = lambda: MagicMock(get_trackers=set) mock_ltsession = MagicMock() diff --git a/src/tribler/core/components/libtorrent/tests/test_torrent_def.py b/src/tribler/core/components/libtorrent/tests/test_torrent_def.py index 90def1736ab..009222abc06 100644 --- a/src/tribler/core/components/libtorrent/tests/test_torrent_def.py +++ b/src/tribler/core/components/libtorrent/tests/test_torrent_def.py @@ -1,4 +1,5 @@ import shutil +from unittest.mock import Mock import pytest from aiohttp import ClientResponseError @@ -154,11 +155,19 @@ def test_set_tracker_strip_slash(tdef): def test_set_tracker(tdef): - assert not tdef.get_trackers_as_single_tuple() + assert len(tdef.get_trackers()) == 0 tdef.set_tracker("http://tracker.org") - assert tdef.get_trackers_as_single_tuple() == ('http://tracker.org',) + assert tdef.get_trackers() == {'http://tracker.org'} +def test_get_trackers(tdef): + """ + Test that `get_trackers` returns flat set of trackers + """ + tdef.get_tracker_hierarchy = Mock(return_value=[["t1", "t2"], ["t3"], ["t4"]]) + trackers = tdef.get_trackers() + assert trackers == {"t1", "t2", "t3", "t4"} + def test_get_nr_pieces(tdef): """ Test getting the number of pieces from a TorrentDef @@ -207,13 +216,13 @@ def test_torrent_no_metainfo(): assert tdef.get_name_as_unicode() == VIDEO_FILE_NAME assert not tdef.get_files() assert tdef.get_files_with_length() == [] - assert not tdef.get_trackers_as_single_tuple() + assert len(tdef.get_trackers()) == 0 assert not tdef.is_private() assert tdef.get_name_utf8() == "video.avi" assert tdef.get_nr_pieces() == 0 torrent2 = TorrentDefNoMetainfo(b"12345678901234567890", VIDEO_FILE_NAME, "magnet:") - assert not torrent2.get_trackers_as_single_tuple() + assert len(torrent2.get_trackers()) == 0 def test_get_length(tdef): @@ -250,7 +259,20 @@ def test_get_name_as_unicode(tdef): tdef.metainfo = {b'info': {b'name': name_bytes}} assert tdef.get_name_as_unicode() == name_unicode tdef.metainfo = {b'info': {b'name': b'test\xff' + name_bytes}} - assert tdef.get_name_as_unicode() == 'test?????????????' + assert tdef.get_name_as_unicode() == 'test' + '?' * len(b'\xff' + name_bytes) + + +def test_filter_characters(tdef): + """ + Test `_filter_characters` sanitizes its input + """ + name_bytes = b"\xe8\xaf\xad\xe8\xa8\x80\xe5\xa4\x84\xe7\x90\x86" + name = name_bytes + name_sanitized = "?" * len(name) + assert tdef._filter_characters(name) == name_sanitized # pylint: disable=protected-access + name = b"test\xff" + name_bytes + name_sanitized = "test" + "?" * len(b"\xff" + name_bytes) + assert tdef._filter_characters(name) == name_sanitized # pylint: disable=protected-access def test_get_files_with_length(tdef): diff --git a/src/tribler/core/components/libtorrent/torrentdef.py b/src/tribler/core/components/libtorrent/torrentdef.py index 34ccbb2371b..952792bc1db 100644 --- a/src/tribler/core/components/libtorrent/torrentdef.py +++ b/src/tribler/core/components/libtorrent/torrentdef.py @@ -1,6 +1,7 @@ """ Author(s): Arno Bakker """ +import itertools import logging from hashlib import sha1 @@ -135,6 +136,25 @@ async def load_from_url(url): body = await response.read() return TorrentDef.load_from_memory(body) + def _filter_characters(self, name: bytes) -> str: + """ + Sanitize the names in path to unicode by replacing out all + characters that may -even remotely- cause problems with the '?' + character. + + :param name: the name to sanitize + :type name: bytes + :return: the sanitized string + :rtype: str + """ + def filter_character(char: int) -> str: + if 0 < char < 128: + return chr(char) + self._logger.debug("Bad character 0x%X", char) + return "?" + + return "".join(map(filter_character, name)) + def add_content(self, file_path): """ Add some content to the torrent file. @@ -179,21 +199,20 @@ def get_tracker_hierarchy(self): """ return self.torrent_parameters.get(b'announce-list', []) - def get_trackers_as_single_tuple(self): + def get_trackers(self) -> set: """ - Returns a flat tuple of all known trackers. + Returns a flat set of all known trackers. + + :return: all known trackers + :rtype: set """ if self.get_tracker_hierarchy(): - trackers = [] - for level in self.get_tracker_hierarchy(): - for tracker in level: - if tracker and tracker not in trackers: - trackers.append(tracker) - return tuple(trackers) + trackers = itertools.chain.from_iterable(self.get_tracker_hierarchy()) + return set(filter(None, trackers)) tracker = self.get_tracker() if tracker: - return tracker, - return () + return {tracker} + return set() def set_piece_length(self, piece_length): """ @@ -296,16 +315,7 @@ def get_name_as_unicode(self): # all characters that may -even remotely- cause problems # with the '?' character try: - def filter_characters(name): - def filter_character(char): - if 0 < char < 128: - return chr(char) - self._logger.debug("Bad character 0x%X", char) - return "?" - - return "".join([filter_character(char) for char in name]) - - return filter_characters(self.metainfo[b"info"][b"name"]) + return self._filter_characters(self.metainfo[b"info"][b"name"]) except UnicodeError: pass @@ -339,7 +349,7 @@ def _get_all_files_as_unicode_with_length(self): # We assume that it is correctly encoded and use # it normally try: - yield (Path(*[ensure_unicode(element, "UTF-8") for element in file_dict[b"path.utf-8"]]), + yield (Path(*(ensure_unicode(element, "UTF-8") for element in file_dict[b"path.utf-8"])), file_dict[b"length"]) continue except UnicodeError: @@ -351,7 +361,7 @@ def _get_all_files_as_unicode_with_length(self): if b"encoding" in self.metainfo: encoding = ensure_unicode(self.metainfo[b"encoding"], "utf8") try: - yield (Path(*[ensure_unicode(element, encoding) for element in file_dict[b"path"]]), + yield (Path(*(ensure_unicode(element, encoding) for element in file_dict[b"path"])), file_dict[b"length"]) continue except UnicodeError: @@ -366,7 +376,7 @@ def _get_all_files_as_unicode_with_length(self): # Try to convert the names in path to unicode, # assuming that it was encoded as utf-8 try: - yield (Path(*[ensure_unicode(element, "UTF-8") for element in file_dict[b"path"]]), + yield (Path(*(ensure_unicode(element, "UTF-8") for element in file_dict[b"path"])), file_dict[b"length"]) continue except UnicodeError: @@ -376,17 +386,7 @@ def _get_all_files_as_unicode_with_length(self): # replacing out all characters that may -even # remotely- cause problems with the '?' character try: - def filter_characters(name): - def filter_character(char): - if 0 < char < 128: - return chr(char) - self._logger.debug("Bad character 0x%X", char) - return "?" - - return "".join([filter_character(char) for char in name]) - - yield (Path(*[filter_characters(element) for element in file_dict[b"path"]]), - file_dict[b"length"]) + yield (Path(*map(self._filter_characters, file_dict[b"path"])), file_dict[b"length"]) continue except UnicodeError: pass @@ -517,11 +517,17 @@ def get_files(self, exts=None): def get_files_with_length(self, exts=None): return [] - def get_trackers_as_single_tuple(self): + def get_trackers(self) -> set: + """ + Returns a flat set of all known trackers. + + :return: all known trackers + :rtype: set + """ if self.url and self.url.startswith('magnet:'): - _, _, trs = parse_magnetlink(self.url) - return tuple(trs) - return () + trackers = parse_magnetlink(self.url)[2] + return set(trackers) + return set() def is_private(self): return False diff --git a/src/tribler/core/components/metadata_store/remote_query_community/payload_checker.py b/src/tribler/core/components/metadata_store/remote_query_community/payload_checker.py index 278f4ee6340..4792d7ca068 100644 --- a/src/tribler/core/components/metadata_store/remote_query_community/payload_checker.py +++ b/src/tribler/core/components/metadata_store/remote_query_community/payload_checker.py @@ -102,7 +102,7 @@ def reject_payload_with_offending_words(self): """ if is_forbidden( " ".join( - [getattr(self.payload, attr) for attr in ("title", "tags", "text") if hasattr(self.payload, attr)]) + getattr(self.payload, attr) for attr in ("title", "tags", "text") if hasattr(self.payload, attr)) ): return [] return CONTINUE