diff --git a/src/tribler/core/components/libtorrent/download_manager/download_config.py b/src/tribler/core/components/libtorrent/download_manager/download_config.py index fa532786507..5e355cfd21c 100644 --- a/src/tribler/core/components/libtorrent/download_manager/download_config.py +++ b/src/tribler/core/components/libtorrent/download_manager/download_config.py @@ -1,4 +1,5 @@ import base64 +from typing import Dict, Optional from configobj import ConfigObj from validate import Validator @@ -15,6 +16,20 @@ CONFIG_SPEC_PATH = get_lib_path() / 'components/libtorrent/download_manager' / SPEC_FILENAME NONPERSISTENT_DEFAULTS = {} + +def _from_dict(value: Dict) -> str: + binary = lt.bencode(value) + base64_bytes = base64.b64encode(binary) + return base64_bytes.decode('utf-8') + + +def _to_dict(value: str) -> Optional[Dict]: + binary = value.encode('utf-8') + # b'==' is added to avoid incorrect padding + base64_bytes = base64.b64decode(binary + b'==') + return bdecode_compat(base64_bytes) + + class DownloadConfig: def __init__(self, config=None, state_dir=None): self.config = config or ConfigObj(configspec=str(CONFIG_SPEC_PATH), default_encoding='utf8') @@ -150,17 +165,17 @@ def set_bootstrap_download(self, value): def get_bootstrap_download(self): return self.config['download_defaults']['bootstrap_download'] - def set_metainfo(self, metainfo): - self.config['state']['metainfo'] = base64.b64encode(lt.bencode(metainfo)).decode('utf-8') + def set_metainfo(self, metainfo: Dict): + self.config['state']['metainfo'] = _from_dict(metainfo) - def get_metainfo(self): - return bdecode_compat(base64.b64decode(self.config['state']['metainfo'].encode('utf-8'))) + def get_metainfo(self) -> Optional[Dict]: + return _to_dict(self.config['state']['metainfo']) - def set_engineresumedata(self, engineresumedata): - self.config['state']['engineresumedata'] = base64.b64encode(lt.bencode(engineresumedata)).decode('utf-8') + def set_engineresumedata(self, engineresumedata: Dict): + self.config['state']['engineresumedata'] = _from_dict(engineresumedata) - def get_engineresumedata(self): - return bdecode_compat(base64.b64decode(self.config['state']['engineresumedata'].encode('utf-8'))) + def get_engineresumedata(self) -> Optional[Dict]: + return _to_dict(self.config['state']['engineresumedata']) def get_default_dest_dir(): diff --git a/src/tribler/core/components/libtorrent/tests/test_download.py b/src/tribler/core/components/libtorrent/tests/test_download.py index 225d79395bd..698d72ba96b 100644 --- a/src/tribler/core/components/libtorrent/tests/test_download.py +++ b/src/tribler/core/components/libtorrent/tests/test_download.py @@ -49,7 +49,7 @@ async def test_save_resume(mock_handle, test_download, test_tdef): basename = hexlify(test_tdef.get_infohash()) + '.conf' filename = test_download.dlmgr.get_checkpoint_dir() / basename dcfg = DownloadConfig.load(str(filename)) - assert test_tdef.get_infohash(), dcfg.get_engineresumedata().get(b'info-hash') + assert test_tdef.get_infohash() == dcfg.get_engineresumedata().get(b'info-hash') def test_move_storage(mock_handle, test_download, test_tdef, test_tdef_no_metainfo): diff --git a/src/tribler/core/components/libtorrent/tests/test_download_config.py b/src/tribler/core/components/libtorrent/tests/test_download_config.py index 938bd571b05..f98bf78fa9c 100644 --- a/src/tribler/core/components/libtorrent/tests/test_download_config.py +++ b/src/tribler/core/components/libtorrent/tests/test_download_config.py @@ -1,13 +1,12 @@ from pathlib import Path -from configobj import ConfigObjError - import pytest +from configobj import ConfigObjError -from tribler.core.components.libtorrent.download_manager.download_config import DownloadConfig, get_default_dest_dir +from tribler.core.components.libtorrent.download_manager.download_config import DownloadConfig, _from_dict, _to_dict, \ + get_default_dest_dir from tribler.core.tests.tools.common import TESTS_DATA_DIR - CONFIG_FILES_DIR = TESTS_DATA_DIR / "config_files" @@ -71,3 +70,15 @@ def test_user_stopped(download_config): download_config.set_user_stopped(True) assert download_config.get_user_stopped() + + +def test_to_dict(): + d = {b'a': b'b'} + s = _from_dict(d) + assert d == _to_dict(s) + + +def test_avoid_incorrect_padding(): + assert {b'a': b'b'} == _to_dict('ZDE6YTE6YmU==') + assert {b'a': b'b'} == _to_dict('ZDE6YTE6YmU=') + assert {b'a': b'b'} == _to_dict('ZDE6YTE6YmU') diff --git a/src/tribler/core/utilities/utilities.py b/src/tribler/core/utilities/utilities.py index 5c70ebf797c..952f8384064 100644 --- a/src/tribler/core/utilities/utilities.py +++ b/src/tribler/core/utilities/utilities.py @@ -13,7 +13,7 @@ from base64 import b32decode from dataclasses import dataclass, field from functools import wraps -from typing import Set, Tuple +from typing import Dict, Optional, Set, Tuple from urllib.parse import parse_qsl, urlsplit from tribler.core.components.libtorrent.utils.libtorrent_helper import libtorrent as lt @@ -160,7 +160,7 @@ def is_hex_string(text): return False -def bdecode_compat(packet_buffer): +def bdecode_compat(packet_buffer: bytes) -> Optional[Dict]: """ Utility method to make libtorrent bdecode() with Python3 in the existing Tribler codebase. We should change this when Libtorrent wrapper is refactored. @@ -226,7 +226,7 @@ def extract_tags(text: str) -> Tuple[Set[str], str]: positions.extend(itertools.chain.from_iterable(m.regs)) positions.append(len(text)) - remaining_text = ''.join(text[positions[i] : positions[i + 1]] for i in range(0, len(positions) - 1, 2)) + remaining_text = ''.join(text[positions[i]: positions[i + 1]] for i in range(0, len(positions) - 1, 2)) return tags, remaining_text