Skip to content
This repository has been archived by the owner on Aug 19, 2024. It is now read-only.

Commit

Permalink
Merge pull request #43 from qstokkink/fix_ruff_remainder
Browse files Browse the repository at this point in the history
Fixed ruff violations
  • Loading branch information
qstokkink authored May 8, 2024
2 parents fc2acbd + ef11c88 commit 2522daf
Show file tree
Hide file tree
Showing 24 changed files with 676 additions and 250 deletions.
122 changes: 109 additions & 13 deletions src/tribler/core/components.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,64 @@
from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING, cast, Type
from typing import TYPE_CHECKING, Type, cast

from ipv8.bootstrapping.bootstrapper_interface import Bootstrapper
from ipv8.bootstrapping.dispersy.bootstrapper import DispersyBootstrapper
from ipv8.community import Community
from ipv8.configuration import DISPERSY_BOOTSTRAPPER
from ipv8.loader import CommunityLauncher, set_in_session, overlay, kwargs, precondition, after, walk_strategy
from ipv8.overlay import SettingsClass, Overlay
from ipv8.peer import Peer
from ipv8.loader import CommunityLauncher, after, kwargs, overlay, precondition, set_in_session, walk_strategy
from ipv8.overlay import Overlay, SettingsClass
from ipv8.peerdiscovery.discovery import DiscoveryStrategy, RandomWalk

if TYPE_CHECKING:
from ipv8.bootstrapping.bootstrapper_interface import Bootstrapper
from ipv8.peer import Peer
from ipv8.types import IPv8

from tribler.core.session import Session


class BaseLauncher(CommunityLauncher):
"""
The base class for all Tribler Community launchers.
"""

def get_overlay_class(self) -> type[Community]:
"""
Overwrite this to return the correct Community type.
"""
raise NotImplementedError

def get_bootstrappers(self, session: Session) -> list[tuple[type[Bootstrapper], dict]]:
"""
Simply use the old Dispersy bootstrapper format.
"""
return [(DispersyBootstrapper, DISPERSY_BOOTSTRAPPER["init"])]

def get_walk_strategies(self) -> list[tuple[type[DiscoveryStrategy], dict, int]]:
"""
Adhere to the default walking behavior.
"""
return [(RandomWalk, {}, 20)]

def get_my_peer(self, ipv8: IPv8, session: Session) -> Peer:
"""
Get the default key.
"""
return ipv8.keys["anonymous id"]


class Component(Community):
"""
A glorified TaskManager. This should also really be a TaskManager.
def __init__(self, settings: SettingsClass):
I did not make this a TaskManager because I am lazy - Quinten (2024)
"""

def __init__(self, settings: SettingsClass) -> None:
"""
Create a new inert fake Community.
"""
settings.community_id = self.__class__.__name__.encode()
Overlay.__init__(self, settings)
self.cancel_pending_task("discover_lan_addresses")
Expand All @@ -47,11 +70,20 @@ def __init__(self, settings: SettingsClass):


class ComponentLauncher(CommunityLauncher):
"""
A launcher for components that simply need a TaskManager, not a full Community.
"""

def get_overlay_class(self) -> type[Community]:
"""
Create a fake Community.
"""
return cast(Type[Community], type(f"{self.__class__.__name__}", (Component,), {}))

def get_my_peer(self, ipv8: IPv8, session: Session) -> Peer:
"""
Our peer still uses the Tribler default key.
"""
return ipv8.keys["anonymous id"]


Expand All @@ -63,19 +95,32 @@ def get_my_peer(self, ipv8: IPv8, session: Session) -> Peer:
@overlay("tribler.core.content_discovery.community", "ContentDiscoveryCommunity")
@kwargs(metadata_store="session.mds", torrent_checker="session.torrent_checker", notifier="session.notifier")
class ContentDiscoveryComponent(BaseLauncher):
"""
Launch instructions for the content discovery community.
"""

def finalize(self, ipv8: IPv8, session: Session, community: Community) -> None:
"""
When we are done launching, register our REST API.
"""
from tribler.core.content_discovery.community import ContentDiscoveryCommunity
from tribler.core.content_discovery.restapi.search_endpoint import SearchEndpoint

session.rest_manager.add_endpoint(SearchEndpoint(community))
session.rest_manager.add_endpoint(SearchEndpoint(cast(ContentDiscoveryCommunity, community)))


@precondition('session.config.get("database/enabled")')
class DatabaseComponent(ComponentLauncher):
"""
Launch instructions for the database.
"""

def prepare(self, ipv8: IPv8, session: Session) -> None:
from tribler.core.database.tribler_database import TriblerDatabase
"""
Create the database instances we need for Tribler.
"""
from tribler.core.database.store import MetadataStore
from tribler.core.database.tribler_database import TriblerDatabase
from tribler.core.knowledge.rules.knowledge_rules_processor import KnowledgeRulesProcessor
from tribler.core.notifier import Notification

Expand All @@ -96,6 +141,9 @@ def prepare(self, ipv8: IPv8, session: Session) -> None:
session.notifier.add(Notification.torrent_metadata_added, session.mds.TorrentMetadata.add_ffa_from_dict)

def finalize(self, ipv8: IPv8, session: Session, community: Community) -> None:
"""
When we are done launching, register our REST API.
"""
from tribler.core.database.restapi.database_endpoint import DatabaseEndpoint

session.rest_manager.get_endpoint("/downloads").mds = session.mds
Expand All @@ -113,10 +161,17 @@ def finalize(self, ipv8: IPv8, session: Session, community: Community) -> None:
@overlay("tribler.core.knowledge.community", "KnowledgeCommunity")
@kwargs(db="session.db", key='session.ipv8.keys["secondary"].key')
class KnowledgeComponent(CommunityLauncher):
"""
Launch instructions for the knowledge community.
"""

def finalize(self, ipv8: IPv8, session: Session, community: Community) -> None:
from tribler.core.knowledge.rules.knowledge_rules_processor import KnowledgeRulesProcessor
"""
When we are done launching, register our REST API.
"""
from tribler.core.knowledge.community import KnowledgeCommunity
from tribler.core.knowledge.restapi.knowledge_endpoint import KnowledgeEndpoint
from tribler.core.knowledge.rules.knowledge_rules_processor import KnowledgeRulesProcessor

session.knowledge_processor = KnowledgeRulesProcessor(
notifier=session.notifier,
Expand All @@ -125,15 +180,21 @@ def finalize(self, ipv8: IPv8, session: Session, community: Community) -> None:
)
session.knowledge_processor.start()
session.rest_manager.get_endpoint("/metadata").tag_rules_processor = session.knowledge_processor
session.rest_manager.add_endpoint(KnowledgeEndpoint(session.db, community))
session.rest_manager.add_endpoint(KnowledgeEndpoint(session.db, cast(KnowledgeCommunity, community)))


@after("DatabaseComponent")
@precondition('session.config.get("rendezvous/enabled")')
@overlay("tribler.core.rendezvous.community", "RendezvousCommunity")
class RendezvousComponent(BaseLauncher):
"""
Launch instructions for the rendezvous community.
"""

def get_kwargs(self, session: object) -> dict:
def get_kwargs(self, session: Session) -> dict:
"""
Create and forward the rendezvous database for the Community.
"""
from tribler.core.rendezvous.database import RendezvousDatabase

out = super().get_kwargs(session)
Expand All @@ -142,16 +203,26 @@ def get_kwargs(self, session: object) -> dict:
return out

def finalize(self, ipv8: IPv8, session: Session, community: Community) -> None:
"""
Start listening to peer connections after starting.
"""
from tribler.core.rendezvous.community import RendezvousCommunity
from tribler.core.rendezvous.rendezvous_hook import RendezvousHook

rendezvous_hook = RendezvousHook(community.composition.database)
rendezvous_hook = RendezvousHook(cast(RendezvousCommunity, community).composition.database)
ipv8.network.add_peer_observer(rendezvous_hook)


@precondition('session.config.get("torrent_checker/enabled")')
class TorrentCheckerComponent(ComponentLauncher):
"""
Launch instructions for the torrent checker.
"""

def prepare(self, overlay_provider: IPv8, session: Session) -> None:
"""
Initialize the torrecht checker and the torrent manager.
"""
from tribler.core.torrent_checker.torrent_checker import TorrentChecker
from tribler.core.torrent_checker.tracker_manager import TrackerManager

Expand All @@ -165,6 +236,9 @@ def prepare(self, overlay_provider: IPv8, session: Session) -> None:
session.torrent_checker = torrent_checker

def finalize(self, ipv8: IPv8, session: Session, community: Community) -> None:
"""
When we are done launching, register our REST API.
"""
community.register_task("Start torrent checker", session.torrent_checker.initialize)
session.rest_manager.get_endpoint("/metadata").torrent_checker = session.torrent_checker

Expand All @@ -173,8 +247,14 @@ def finalize(self, ipv8: IPv8, session: Session, community: Community) -> None:
@precondition('session.config.get("dht_discovery/enabled")')
@overlay("ipv8.dht.discovery", "DHTDiscoveryCommunity")
class DHTDiscoveryComponent(BaseLauncher):
"""
Launch instructions for the DHT discovery community.
"""

def finalize(self, ipv8: IPv8, session: Session, community: Community) -> None:
"""
When we are done launching, register our REST API.
"""
session.rest_manager.get_endpoint("/ipv8").endpoints["/dht"].dht = community


Expand All @@ -183,8 +263,14 @@ def finalize(self, ipv8: IPv8, session: Session, community: Community) -> None:
@walk_strategy("tribler.core.tunnel.discovery", "GoldenRatioStrategy", -1)
@overlay("tribler.core.tunnel.community", "TriblerTunnelCommunity")
class TunnelComponent(BaseLauncher):
"""
Launch instructions for the tunnel community.
"""

def get_kwargs(self, session: Session) -> dict:
"""
Extend our community arguments with all necessary config settings and objects.
"""
from ipv8.dht.discovery import DHTDiscoveryCommunity
from ipv8.dht.provider import DHTCommunityProvider

Expand All @@ -202,16 +288,26 @@ def get_kwargs(self, session: Session) -> dict:
return out

def finalize(self, ipv8: IPv8, session: Session, community: Community) -> None:
"""
When we are done launching, register our REST API.
"""
session.rest_manager.get_endpoint("/downloads").tunnel_community = community
session.rest_manager.get_endpoint("/ipv8").endpoints["/tunnel"].tunnels = community


@after("ContentDiscoveryComponent", "TorrentCheckerComponent")
@precondition('session.config.get("user_activity/enabled")')
class UserActivityComponent(ComponentLauncher):
"""
Launch instructions for the user activity community.
"""

def finalize(self, ipv8: IPv8, session: Session, community: Community) -> None:
"""
When we are done launching, start listening for GUI events.
"""
from tribler.core.user_activity.manager import UserActivityManager

component = cast(Component, community)
max_query_history = session.config.get("user_activity/max_query_history")
community.settings.manager = UserActivityManager(community, session, max_query_history)
component.settings.manager = UserActivityManager(component, session, max_query_history)
11 changes: 6 additions & 5 deletions src/tribler/core/libtorrent/download_manager/download_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from copy import deepcopy
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Iterable, List

import libtorrent as lt
from configobj import ConfigObj
Expand Down Expand Up @@ -423,7 +423,7 @@ def get_download_rate_limit(self, hops: int = 0) -> int:
libtorrent_rate = self.get_session(hops=hops).download_rate_limit()
return self.reverse_convert_rate(rate=libtorrent_rate)

def process_alert(self, alert, hops: int = 0) -> None: # noqa: C901, PLR0912
def process_alert(self, alert: lt.alert, hops: int = 0) -> None: # noqa: C901, PLR0912
"""
Process a libtorrent alert.
"""
Expand Down Expand Up @@ -909,7 +909,8 @@ def update_trackers(self, infohash: bytes, trackers: list[str]) -> None:
download.set_def(new_def)
download.checkpoint()

def set_download_states_callback(self, user_callback, interval: float = 1.0) -> None:
def set_download_states_callback(self, user_callback: Callable[[list[DownloadState]], Awaitable[None] | None],
interval: float = 1.0) -> None:
"""
Set the download state callback. Remove any old callback if it's present.
Calls user_callback with a list of
Expand All @@ -925,7 +926,7 @@ def set_download_states_callback(self, user_callback, interval: float = 1.0) ->
logger.debug("Starting the download state callback with interval %f", interval)
self.replace_task("download_states_lc", self._invoke_states_cb, user_callback, interval=interval)

async def _invoke_states_cb(self, callback) -> None:
async def _invoke_states_cb(self, callback: Callable[[list[DownloadState]], Awaitable[None] | None]) -> None:
"""
Invoke the download states callback with a list of the download states.
"""
Expand Down Expand Up @@ -980,7 +981,7 @@ async def load_checkpoints(self) -> None:
self.all_checkpoints_are_loaded = True
self._logger.info("Checkpoints are loaded")

async def load_checkpoint(self, filename: str) -> bool:
async def load_checkpoint(self, filename: Path | str) -> bool:
"""
Load a checkpoint from a given file name.
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
)


def recursive_bytes(obj):
def recursive_bytes(obj): # noqa: ANN001, ANN201
"""
Converts any unicode strings within a Python data structure to bytes. Strings will be encoded using UTF-8.
Expand Down
Loading

0 comments on commit 2522daf

Please sign in to comment.