Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace list comprehensions with generator expressions and similar functions #7440

Merged
merged 7 commits into from
May 26, 2023
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -586,8 +586,7 @@ def start_download(self, torrent_file=None, tdef=None, config: DownloadConfig =
download = self.get_download(infohash)

if download and infohash not in self.metainfo_requests:
new_trackers = list(set(tdef.get_trackers_as_single_tuple()) -
set(download.get_def().get_trackers_as_single_tuple()))
new_trackers = list(tdef.get_trackers() - download.get_def().get_trackers())
if new_trackers:
self.update_trackers(tdef.get_infohash(), new_trackers)
return download
Expand Down Expand Up @@ -779,8 +778,8 @@ def update_trackers(self, infohash, trackers):
download = self.get_download(infohash)
if download:
old_def = download.get_def()
old_trackers = old_def.get_trackers_as_single_tuple()
new_trackers = list(set(trackers) - set(old_trackers))
old_trackers = old_def.get_trackers()
new_trackers = list(set(trackers) - old_trackers)
all_trackers = list(old_trackers) + new_trackers

if new_trackers:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,13 @@ def _safe_extended_peer_info(ext_peer_info):
"""
# First see if we can use this as-is
if not ext_peer_info:
return ''
return ""

try:
return ensure_unicode(ext_peer_info, "utf8")
except UnicodeDecodeError:
# We might have some special unicode characters in here
return ''.join([chr(c) for c in ext_peer_info])
return "".join(map(chr, ext_peer_info))


def get_extended_status(tunnel_community, download) -> DownloadStatus:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import collections
import os
import unittest.mock
from unittest.mock import Mock

import pytest
from aiohttp.web_app import Application
from ipv8.util import fail, succeed

import tribler.core.components.libtorrent.restapi.downloads_endpoint as download_endpoint
from tribler.core.components.libtorrent.download_manager.download_state import DownloadState
from tribler.core.components.libtorrent.restapi.downloads_endpoint import DownloadsEndpoint, get_extended_status
from tribler.core.components.restapi.rest.base_api_test import do_request
Expand Down Expand Up @@ -145,6 +147,16 @@ def test_get_extended_status_circuits(mock_extended_status):
assert mock_extended_status == DownloadStatus.CIRCUITS


@unittest.mock.patch("tribler.core.components.libtorrent.restapi.downloads_endpoint.ensure_unicode",
Mock(side_effect=UnicodeDecodeError("", b"", 0, 0, "")))
def test_safe_extended_peer_info():
"""
Test that we return the string mapped by `chr` in the case of `UnicodeDecodeError`
"""
extended_peer_info = download_endpoint._safe_extended_peer_info(b"abcd") # pylint: disable=protected-access
assert extended_peer_info == "abcd"


async def test_get_downloads_if_checkpoints_are_not_loaded(mock_dlmgr, rest_api):
mock_dlmgr.checkpoints_count = 10
mock_dlmgr.checkpoints_loaded = 5
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ def test_start_download_existing_download(fake_dlmgr):
infohash = b'a' * 20

mock_download = MagicMock()
mock_download.get_def = lambda: MagicMock(get_trackers_as_single_tuple=lambda: ())
mock_download.get_def = lambda: MagicMock(get_trackers=set)

mock_ltsession = MagicMock()

Expand Down
32 changes: 27 additions & 5 deletions src/tribler/core/components/libtorrent/tests/test_torrent_def.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import shutil
from unittest.mock import Mock

import pytest
from aiohttp import ClientResponseError
Expand Down Expand Up @@ -154,11 +155,19 @@ def test_set_tracker_strip_slash(tdef):


def test_set_tracker(tdef):
assert not tdef.get_trackers_as_single_tuple()
assert len(tdef.get_trackers()) == 0
tdef.set_tracker("http://tracker.org")
assert tdef.get_trackers_as_single_tuple() == ('http://tracker.org',)
assert tdef.get_trackers() == {'http://tracker.org'}


def test_get_trackers(tdef):
"""
Test that `get_trackers` returns flat set of trackers
"""
tdef.get_tracker_hierarchy = Mock(return_value=[["t1", "t2"], ["t3"], ["t4"]])
trackers = tdef.get_trackers()
assert trackers == {"t1", "t2", "t3", "t4"}

def test_get_nr_pieces(tdef):
"""
Test getting the number of pieces from a TorrentDef
Expand Down Expand Up @@ -207,13 +216,13 @@ def test_torrent_no_metainfo():
assert tdef.get_name_as_unicode() == VIDEO_FILE_NAME
assert not tdef.get_files()
assert tdef.get_files_with_length() == []
assert not tdef.get_trackers_as_single_tuple()
assert len(tdef.get_trackers()) == 0
assert not tdef.is_private()
assert tdef.get_name_utf8() == "video.avi"
assert tdef.get_nr_pieces() == 0

torrent2 = TorrentDefNoMetainfo(b"12345678901234567890", VIDEO_FILE_NAME, "magnet:")
assert not torrent2.get_trackers_as_single_tuple()
assert len(torrent2.get_trackers()) == 0


def test_get_length(tdef):
Expand Down Expand Up @@ -250,7 +259,20 @@ def test_get_name_as_unicode(tdef):
tdef.metainfo = {b'info': {b'name': name_bytes}}
assert tdef.get_name_as_unicode() == name_unicode
tdef.metainfo = {b'info': {b'name': b'test\xff' + name_bytes}}
assert tdef.get_name_as_unicode() == 'test?????????????'
assert tdef.get_name_as_unicode() == 'test' + '?' * len(b'\xff' + name_bytes)


def test_filter_characters(tdef):
"""
Test `_filter_characters` sanitizes its input
"""
name_bytes = b"\xe8\xaf\xad\xe8\xa8\x80\xe5\xa4\x84\xe7\x90\x86"
name = name_bytes
name_sanitized = "?" * len(name)
assert tdef._filter_characters(name) == name_sanitized # pylint: disable=protected-access
name = b"test\xff" + name_bytes
name_sanitized = "test" + "?" * len(b"\xff" + name_bytes)
assert tdef._filter_characters(name) == name_sanitized # pylint: disable=protected-access


def test_get_files_with_length(tdef):
Expand Down
82 changes: 44 additions & 38 deletions src/tribler/core/components/libtorrent/torrentdef.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Author(s): Arno Bakker
"""
import itertools
import logging
from hashlib import sha1

Expand Down Expand Up @@ -135,6 +136,25 @@ async def load_from_url(url):
body = await response.read()
return TorrentDef.load_from_memory(body)

def _filter_characters(self, name: bytes) -> str:
"""
Sanitize the names in path to unicode by replacing out all
characters that may -even remotely- cause problems with the '?'
character.

:param name: the name to sanitize
:type name: bytes
:return: the sanitized string
:rtype: str
"""
def filter_character(char: int) -> str:
if 0 < char < 128:
return chr(char)
self._logger.debug("Bad character 0x%X", char)
return "?"

return "".join(map(filter_character, name))

def add_content(self, file_path):
"""
Add some content to the torrent file.
Expand Down Expand Up @@ -179,21 +199,20 @@ def get_tracker_hierarchy(self):
"""
return self.torrent_parameters.get(b'announce-list', [])

def get_trackers_as_single_tuple(self):
def get_trackers(self) -> set:
"""
Returns a flat tuple of all known trackers.
Returns a flat set of all known trackers.

:return: all known trackers
:rtype: set
"""
if self.get_tracker_hierarchy():
trackers = []
for level in self.get_tracker_hierarchy():
for tracker in level:
if tracker and tracker not in trackers:
trackers.append(tracker)
return tuple(trackers)
trackers = itertools.chain.from_iterable(self.get_tracker_hierarchy())
return set(filter(None, trackers))
tracker = self.get_tracker()
if tracker:
return tracker,
return ()
return {tracker}
return set()

def set_piece_length(self, piece_length):
"""
Expand Down Expand Up @@ -296,16 +315,7 @@ def get_name_as_unicode(self):
# all characters that may -even remotely- cause problems
# with the '?' character
try:
def filter_characters(name):
def filter_character(char):
if 0 < char < 128:
return chr(char)
self._logger.debug("Bad character 0x%X", char)
return "?"

return "".join([filter_character(char) for char in name])

return filter_characters(self.metainfo[b"info"][b"name"])
return self._filter_characters(self.metainfo[b"info"][b"name"])
except UnicodeError:
pass

Expand Down Expand Up @@ -339,7 +349,7 @@ def _get_all_files_as_unicode_with_length(self):
# We assume that it is correctly encoded and use
# it normally
try:
yield (Path(*[ensure_unicode(element, "UTF-8") for element in file_dict[b"path.utf-8"]]),
yield (Path(*(ensure_unicode(element, "UTF-8") for element in file_dict[b"path.utf-8"])),
file_dict[b"length"])
continue
except UnicodeError:
Expand All @@ -351,7 +361,7 @@ def _get_all_files_as_unicode_with_length(self):
if b"encoding" in self.metainfo:
encoding = ensure_unicode(self.metainfo[b"encoding"], "utf8")
try:
yield (Path(*[ensure_unicode(element, encoding) for element in file_dict[b"path"]]),
yield (Path(*(ensure_unicode(element, encoding) for element in file_dict[b"path"])),
file_dict[b"length"])
continue
except UnicodeError:
Expand All @@ -366,7 +376,7 @@ def _get_all_files_as_unicode_with_length(self):
# Try to convert the names in path to unicode,
# assuming that it was encoded as utf-8
try:
yield (Path(*[ensure_unicode(element, "UTF-8") for element in file_dict[b"path"]]),
yield (Path(*(ensure_unicode(element, "UTF-8") for element in file_dict[b"path"])),
file_dict[b"length"])
continue
except UnicodeError:
Expand All @@ -376,17 +386,7 @@ def _get_all_files_as_unicode_with_length(self):
# replacing out all characters that may -even
# remotely- cause problems with the '?' character
try:
def filter_characters(name):
def filter_character(char):
if 0 < char < 128:
return chr(char)
self._logger.debug("Bad character 0x%X", char)
return "?"

return "".join([filter_character(char) for char in name])

yield (Path(*[filter_characters(element) for element in file_dict[b"path"]]),
file_dict[b"length"])
yield (Path(*map(self._filter_characters, file_dict[b"path"])), file_dict[b"length"])
continue
except UnicodeError:
pass
Expand Down Expand Up @@ -517,11 +517,17 @@ def get_files(self, exts=None):
def get_files_with_length(self, exts=None):
return []

def get_trackers_as_single_tuple(self):
def get_trackers(self) -> set:
"""
Returns a flat set of all known trackers.

:return: all known trackers
:rtype: set
"""
if self.url and self.url.startswith('magnet:'):
_, _, trs = parse_magnetlink(self.url)
return tuple(trs)
return ()
trackers = parse_magnetlink(self.url)[2]
return set(trackers)
return set()

def is_private(self):
return False
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def reject_payload_with_offending_words(self):
"""
if is_forbidden(
" ".join(
[getattr(self.payload, attr) for attr in ("title", "tags", "text") if hasattr(self.payload, attr)])
getattr(self.payload, attr) for attr in ("title", "tags", "text") if hasattr(self.payload, attr))
):
return []
return CONTINUE
Expand Down