This repository has been archived by the owner on Aug 19, 2024. It is now read-only.

Commit

Resolved typing issues
qstokkink committed May 13, 2024
1 parent 2522daf commit a99f650
Showing 62 changed files with 858 additions and 595 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/unittests.yml
@@ -9,7 +9,7 @@ jobs:
           submodules: 'true'
       - uses: actions/setup-python@v5
         with:
-          python-version: '3.8'
+          python-version: '3.9'
           cache: 'pip'
       - run: python -m pip install -r requirements.txt
       - name: Run unit tests
@@ -24,7 +24,7 @@ jobs:
           submodules: 'true'
       - uses: actions/setup-python@v5
         with:
-          python-version: '3.8'
+          python-version: '3.9'
           cache: 'pip'
       - uses: actions/cache/restore@v4
         id: restore_cache
@@ -50,7 +50,7 @@ jobs:
           submodules: 'true'
       - uses: actions/setup-python@v5
         with:
-          python-version: '3.8'
+          python-version: '3.9'
           cache: 'pip'
       - run: python -m pip install -r requirements.txt
       - name: Run unit tests
2 changes: 2 additions & 0 deletions mypy.ini
@@ -0,0 +1,2 @@
+[mypy]
+ignore_missing_imports = True
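
The new mypy.ini turns on ignore_missing_imports, which silences mypy's error for third-party packages that ship without type stubs. A minimal sketch of the effect (the package name here is hypothetical):

    # demo.py
    import untyped_thirdparty_lib  # hypothetical package without stubs or a py.typed marker

    # Without `ignore_missing_imports = True`, mypy reports:
    #   error: Cannot find implementation or library stub for module named "untyped_thirdparty_lib"
    # With the setting, the module is typed as `Any` and checking continues.
    value = untyped_thirdparty_lib.do_something()  # `value` is Any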
8 changes: 4 additions & 4 deletions src/tribler/core/components.py
@@ -63,7 +63,7 @@ def __init__(self, settings: SettingsClass) -> None:
         Overlay.__init__(self, settings)
         self.cancel_pending_task("discover_lan_addresses")
         self.endpoint.remove_listener(self)
-        self.bootstrappers = []
+        self.bootstrappers: list[Bootstrapper] = []
         self.max_peers = 0
         self._prefix = settings.community_id
         self.settings = settings
@@ -124,13 +124,13 @@ def prepare(self, ipv8: IPv8, session: Session) -> None:
         from tribler.core.knowledge.rules.knowledge_rules_processor import KnowledgeRulesProcessor
         from tribler.core.notifier import Notification

-        db_path = Path(session.config.get("state_dir")) / "sqlite" / "tribler.db"
-        mds_path = Path(session.config.get("state_dir")) / "sqlite" / "metadata.db"
+        db_path = str(Path(session.config.get("state_dir")) / "sqlite" / "tribler.db")
+        mds_path = str(Path(session.config.get("state_dir")) / "sqlite" / "metadata.db")
         if session.config.get("memory_db"):
             db_path = ":memory:"
             mds_path = ":memory:"

-        session.db = TriblerDatabase(str(db_path))
+        session.db = TriblerDatabase(db_path)
         session.mds = MetadataStore(
             mds_path,
             session.ipv8.keys["anonymous id"].key,
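
The path changes above follow a common mypy constraint: a name inferred as Path from its first assignment cannot later be rebound to the str sentinel ":memory:". Converting to str up front keeps the variable a single type. A minimal sketch of the pattern:

    from pathlib import Path

    def pick_db_path(state_dir: str, in_memory: bool) -> str:
        # str() at the binding keeps `db_path` a plain str throughout,
        # so the ":memory:" rebinding below type-checks.
        db_path = str(Path(state_dir) / "sqlite" / "tribler.db")
        if in_memory:
            db_path = ":memory:"
        return db_path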
8 changes: 4 additions & 4 deletions src/tribler/core/content_discovery/restapi/search_endpoint.py
@@ -1,7 +1,7 @@
 from __future__ import annotations

 from binascii import hexlify, unhexlify
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, cast

 from aiohttp import web
 from aiohttp_apispec import docs, querystring_schema
@@ -48,15 +48,15 @@ def __init__(self,
         self.app.add_routes([web.put("/remote", self.remote_search)])

     @classmethod
-    def sanitize_parameters(cls: type[Self], parameters: MultiMapping[str]) -> dict:
+    def sanitize_parameters(cls: type[Self], parameters: MultiMapping[int | str]) -> dict:
         """
         Correct the human-readable parameters to be their respective correct type.
         """
-        sanitized = dict(parameters)
+        sanitized: dict = dict(parameters)
         if "max_rowid" in parameters:
             sanitized["max_rowid"] = int(parameters["max_rowid"])
         if "channel_pk" in parameters:
-            sanitized["channel_pk"] = unhexlify(parameters["channel_pk"])
+            sanitized["channel_pk"] = unhexlify(cast(str, parameters["channel_pk"]))
         if "origin_id" in parameters:
             sanitized["origin_id"] = int(parameters["origin_id"])
         return sanitized
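
typing.cast has no runtime effect; it only tells the checker to treat a value typed int | str as str, which is what lets unhexlify (expecting str or bytes) accept the parameter. A standalone sketch:

    from binascii import unhexlify
    from typing import cast

    def decode_channel_pk(parameters: dict[str, int | str]) -> bytes:
        # cast() performs no conversion at runtime; it narrows the declared
        # type from `int | str` to `str` so unhexlify() type-checks.
        return unhexlify(cast(str, parameters["channel_pk"]))

    print(decode_channel_pk({"channel_pk": "deadbeef"}))  # b'\xde\xad\xbe\xef'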
26 changes: 13 additions & 13 deletions src/tribler/core/database/layers/knowledge.py
@@ -70,8 +70,13 @@ def get(subject: Resource, object: Resource) -> Statement | None: ... # noqa: D
     def get_for_update(subject: Resource, object: Resource) -> Statement | None: ... # noqa: D102, A002


+class IterResource(type): # noqa: D101
+
+    def __iter__(cls) -> Iterator[Resource]: ... # noqa: D105
+
+
 @dataclasses.dataclass
-class Resource(EntityImpl):
+class Resource(EntityImpl, metaclass=IterResource):
     """
     Database type for a resources.
     """
@@ -275,7 +280,7 @@ def _get_resources(self, resource_type: ResourceType | None, name: str | None, c

     def get_statements(self, source_type: ResourceType | None, source_name: str | None, # noqa: PLR0913
                        statements_getter: Callable[[Entity], Entity],
-                       target_condition: Callable[[], bool], condition: Callable[[], bool],
+                       target_condition: Callable[[Statement], bool], condition: Callable[[Statement], bool],
                        case_sensitive: bool, ) -> Iterator[Statement]:
         """
         Get entities that satisfies the given condition.
@@ -365,7 +370,7 @@ def _show_condition(s: Statement) -> bool:

     def get_objects(self, subject_type: ResourceType | None = None, subject: str | None = "",
                     predicate: ResourceType | None = None, case_sensitive: bool = True,
-                    condition: Callable[[], bool] | None = None) -> List[str]:
+                    condition: Callable[[Statement], bool] | None = None) -> List[str]:
         """
         Get objects that satisfy the given subject and predicate.
@@ -438,14 +443,9 @@ def get_simple_statements(self, subject_type: ResourceType | None = None, subjec
             case_sensitive=case_sensitive,
         )

-        statements = (SimpleStatement(
-            subject_type=s.subject.type,
-            subject=s.subject.name,
-            predicate=s.object.type,
-            object=s.object.name
-        ) for s in statements)
-
-        return list(statements)
+        return [SimpleStatement(subject_type=s.subject.type, subject=s.subject.name, predicate=s.object.type,
+                                object=s.object.name)
+                for s in statements]

     def get_suggestions(self, subject_type: ResourceType | None = None, subject: str | None = "",
                         predicate: ResourceType | None = None, case_sensitive: bool = True) -> List[str]:
@@ -470,7 +470,7 @@ def get_suggestions(self, subject_type: ResourceType | None = None, subject: str

     def get_subjects_intersection(self, objects: Set[str],
                                   predicate: ResourceType | None,
-                                  subjects_type: ResourceType | None = ResourceType.TORRENT,
+                                  subjects_type: ResourceType = ResourceType.TORRENT,
                                   case_sensitive: bool = True) -> Set[str]:
         """
         Get all subjects that have a certain predicate.
@@ -540,7 +540,7 @@ def _get_random_operations_by_condition(self, condition: Callable[[Entity], bool
         :param attempts: maximum attempt count for requesting the DB.
         :returns: a set of random operations
         """
-        operations = set()
+        operations: set[Entity] = set()
         for _ in range(attempts):
             if len(operations) == count:
                 return operations
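
The IterResource metaclass is a typing-only device: methods declared on a metaclass describe the class object rather than its instances, so giving the metaclass an __iter__ makes `for r in Resource` type-check, matching how Pony ORM entities appear in query generator expressions. A self-contained sketch of the idea with a hypothetical entity name (at runtime Pony's own metaclass supplies the real iteration):

    from __future__ import annotations

    from typing import Iterator

    class IterItem(type):
        # Declared on the metaclass, so it types iteration over the class
        # object itself: `for x in Item` yields Item instances.
        def __iter__(cls) -> Iterator[Item]: ...

    class Item(metaclass=IterItem):
        name: str

    # With Pony ORM, `select(i for i in Item)` iterates over the class object;
    # the stub above is what makes that construct acceptable to mypy.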
2 changes: 1 addition & 1 deletion src/tribler/core/database/layers/user_activity.py
@@ -201,7 +201,7 @@ def get_random_query_aggregate(self, neighbors: int,
         # Option 2: aggregate
         results = self.Query.select(lambda q: q.query == random_selection.query)[:]
         num_results_div = 1/len(results)
-        preferences = {}
+        preferences: dict[bytes, float] = {}
         for query in results:
             for infohash_preference in query.infohashes:
                 preferences[infohash_preference.infohash] = (preferences.get(infohash_preference.infohash, 0.0)
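
This is the same fix applied to bootstrappers and operations elsewhere in the commit: an empty literal gives mypy nothing to infer element types from, so the binding carries an explicit annotation. For instance:

    preferences: dict[bytes, float] = {}  # without the annotation the key and
                                          # value types of the dict are unknown
    preferences[b"\x00" * 20] = 0.5       # OK: bytes key, float value
    # preferences["abc"] = 1.0            # flagged by mypy: str is not bytes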
2 changes: 1 addition & 1 deletion src/tribler/core/database/orm_bindings/torrent_metadata.py
@@ -266,7 +266,7 @@ def get_magnet(self) -> str:

     @classmethod
     @db_session
-    def add_ffa_from_dict(cls: type[Self], metadata: dict) -> Self:
+    def add_ffa_from_dict(cls: type[Self], metadata: dict) -> Self | None:
         # To produce a relatively unique id_ we take some bytes of the infohash and convert these to a number.
         # abs is necessary as the conversion can produce a negative value, and we do not support that.
         id_ = infohash_to_id(metadata["infohash"])
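
Widening the return annotation to Self | None records that the method may return None instead of a metadata object, so call sites must narrow the result before use. A hypothetical call site:

    md = TorrentMetadata.add_ffa_from_dict(metadata)  # type: TorrentMetadata | None
    if md is not None:           # mypy flags attribute access without this check
        print(md.get_magnet())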
14 changes: 7 additions & 7 deletions src/tribler/core/database/ranks.py
@@ -88,9 +88,9 @@ def title_rank(query: str, title: str) -> float:
     :param title: a torrent name
     :return: the similarity of the title string to a query string as a float value in range [0, 1]
     """
-    query = word_re.findall(query.lower())
-    title = word_re.findall(title.lower())
-    return calculate_rank(query, title)
+    pat_query = word_re.findall(query.lower())
+    pat_title = word_re.findall(title.lower())
+    return calculate_rank(pat_query, pat_title)


 # These coefficients are found empirically. Their exact values are not very important for a relative ranking of results
@@ -125,13 +125,13 @@ def calculate_rank(query: list[str], title: list[str]) -> float:
     if not title:
         return 0.0

-    title = deque(title)
-    total_error = 0
+    q_title = deque(title)
+    total_error = 0.0
     for i, word in enumerate(query):
         # The first word is more important than the second word, and so on
         word_weight = POSITION_COEFF / (POSITION_COEFF + i)

-        found, skipped = find_word_and_rotate_title(word, title)
+        found, skipped = find_word_and_rotate_title(word, q_title)
         if found:
             # if the query word is found in the title, add penalty for skipped words in title before it
             total_error += skipped * word_weight
@@ -141,7 +141,7 @@ def calculate_rank(query: list[str], title: list[str]) -> float:

     # a small penalty for excess words in the title that was not mentioned in the search phrase
     remainder_weight = 1 / (REMAINDER_COEFF + len(query))
-    remained_words_error = len(title) * remainder_weight
+    remained_words_error = len(q_title) * remainder_weight
     total_error += remained_words_error

     # a search rank should be between 1 and 0
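
Both hunks in this file fix the same anti-pattern: rebinding a parameter (query: str) to a value of another type (list[str]), which mypy rejects because a name keeps one declared type per scope. Fresh names such as pat_query and q_title avoid the conflict. A minimal sketch (the rank computation here is a stand-in):

    import re

    word_re = re.compile(r"\w+")

    def title_rank(query: str, title: str) -> float:
        # query = word_re.findall(query.lower())  # error: list[str] vs str
        pat_query = word_re.findall(query.lower())  # fresh name, fresh type
        pat_title = word_re.findall(title.lower())
        return float(bool(set(pat_query) & set(pat_title)))  # stand-in rank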
33 changes: 21 additions & 12 deletions src/tribler/core/database/restapi/database_endpoint.py
@@ -9,7 +9,6 @@

 from aiohttp import web
 from aiohttp_apispec import docs, querystring_schema
-from ipv8.REST.base_endpoint import HTTP_BAD_REQUEST
 from ipv8.REST.schema import schema
 from marshmallow.fields import Boolean, Integer, String
 from pony.orm import db_session
@@ -19,7 +18,13 @@
 from tribler.core.database.restapi.schema import MetadataSchema, SearchMetadataParameters, TorrentSchema
 from tribler.core.database.serialization import REGULAR_TORRENT, SNIPPET
 from tribler.core.notifier import Notification
-from tribler.core.restapi.rest_endpoint import MAX_REQUEST_SIZE, RESTEndpoint, RESTResponse
+from tribler.core.restapi.rest_endpoint import (
+    HTTP_BAD_REQUEST,
+    HTTP_NOT_FOUND,
+    MAX_REQUEST_SIZE,
+    RESTEndpoint,
+    RESTResponse,
+)

 if typing.TYPE_CHECKING:
     from aiohttp.abc import Request
@@ -97,14 +102,15 @@ def __init__(self, # noqa: PLR0913

     @classmethod
     def sanitize_parameters(cls: type[Self],
-                            parameters: MultiDictProxy | MultiMapping[str]) -> dict[str, str | float | set | None]:
+                            parameters: MultiDictProxy | MultiMapping[str]
+                            ) -> dict[str, str | float | list[str] | set[bytes] | bytes | None]:
         """
         Sanitize the parameters for a request that fetches channels.
         """
-        sanitized = {
+        sanitized: dict[str, str | float | list[str] | set[bytes] | bytes | None] = {
             "first": int(parameters.get("first", 1)),
             "last": int(parameters.get("last", 50)),
-            "sort_by": json2pony_columns.get(parameters.get("sort_by")),
+            "sort_by": json2pony_columns.get(parameters.get("sort_by", "")),
             "sort_desc": parse_bool(parameters.get("sort_desc", "true")),
             "txt_filter": parameters.get("txt_filter"),
             "hide_xxx": parse_bool(parameters.get("hide_xxx", "false")),
@@ -233,20 +239,20 @@ async def get_popular_torrents(self, request: Request) -> RESTResponse:

         return RESTResponse(response_dict)

-    def build_snippets(self, search_results: list[dict]) -> list[dict]:
+    def build_snippets(self, tribler_db: TriblerDatabase, search_results: list[dict]) -> list[dict]:
         """
         Build a list of snippets that bundle torrents describing the same content item.
         For each search result we determine the content item it is associated to and bundle it inside a snippet.
         We sort the snippets based on the number of torrents inside the snippet.
         Within each snippet, we sort on torrent popularity, putting the torrent with the most seeders on top.
         Torrents bundled in a snippet are filtered out from the search results.
         """
-        content_to_torrents: typing.Dict[str, list] = defaultdict(list)
+        content_to_torrents: dict[str, list] = defaultdict(list)
         for search_result in search_results:
             if "infohash" not in search_result:
                 continue
             with db_session:
-                content_items: typing.List[str] = self.tribler_db.knowledge.get_objects(
+                content_items = tribler_db.knowledge.get_objects(
                     subject_type=ResourceType.TORRENT,
                     subject=search_result["infohash"],
                     predicate=ResourceType.CONTENT_ITEM)
@@ -262,7 +268,7 @@ def build_snippets(self, search_results: list[dict]) -> list[dict]:
         sorted_content_info = list(content_to_torrents.items())
         sorted_content_info.sort(key=lambda x: x[1][0]["num_seeders"], reverse=True)

-        snippets: typing.List[typing.Dict] = []
+        snippets: list[dict] = []
         for content_info in sorted_content_info:
             content_id = content_info[0]
             torrents_in_snippet = content_to_torrents[content_id][:MAX_TORRENTS_IN_SNIPPETS]
@@ -310,7 +316,7 @@ def build_snippets(self, search_results: list[dict]) -> list[dict]:
         },
     )
     @querystring_schema(SearchMetadataParameters)
-    async def local_search(self, request: Request) -> RESTResponse:
+    async def local_search(self, request: Request) -> RESTResponse: # noqa: C901
         """
         Perform a search for a given query.
         """
@@ -320,6 +326,9 @@ async def local_search(self, request: Request) -> RESTResponse:
         except (ValueError, KeyError):
             return RESTResponse({"error": "Error processing request parameters"}, status=HTTP_BAD_REQUEST)

+        if self.tribler_db is None:
+            return RESTResponse({"error": "Tribler DB not initialized"}, status=HTTP_NOT_FOUND)
+
         include_total = request.query.get("include_total", "")

         mds: MetadataStore = self.mds
@@ -345,7 +354,7 @@ def search_db() -> tuple[list[dict], int, int]:
             if tags:
                 infohash_set = self.tribler_db.knowledge.get_subjects_intersection(
                     subjects_type=ResourceType.TORRENT,
-                    objects=set(tags),
+                    objects=set(typing.cast(list[str], tags)),
                     predicate=ResourceType.TAG,
                     case_sensitive=False)
                 if infohash_set:
@@ -359,7 +368,7 @@ def search_db() -> tuple[list[dict], int, int]:
             self.add_statements_to_metadata_list(search_results)

             if sanitized["first"] == 1:  # Only show a snippet on top
-                search_results = self.build_snippets(search_results)
+                search_results = self.build_snippets(self.tribler_db, search_results)

             response_dict = {
                 "results": search_results,
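
Two changes cooperate here: the early None check narrows self.tribler_db from TriblerDatabase | None, and passing the narrowed value into build_snippets as a parameter means the helper never has to handle None itself. A condensed, runnable sketch of the flow (names simplified):

    from typing import Optional

    class TriblerDatabase: ...

    class DatabaseEndpoint:
        def __init__(self) -> None:
            self.tribler_db: Optional[TriblerDatabase] = None

        def build_snippets(self, tribler_db: TriblerDatabase, results: list) -> list:
            # Parameter is non-Optional: no None handling needed in here.
            return results

        def local_search(self) -> list:
            if self.tribler_db is None:
                raise RuntimeError("Tribler DB not initialized")  # early exit narrows the type
            # mypy now treats self.tribler_db as TriblerDatabase, so it can be
            # handed to the helper that requires a non-Optional database.
            return self.build_snippets(self.tribler_db, [])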
9 changes: 5 additions & 4 deletions src/tribler/core/database/serialization.py
@@ -119,8 +119,9 @@ def from_dict(cls: type[Self], **kwargs) -> Self:
         Create a payload from the given data (an unpacked dict).
         """
         out = cls(**{key: value for key, value in kwargs.items() if key in cls.names})
-        if kwargs.get("signature") is not None:
-            out.signature = kwargs.get("signature")
+        signature = kwargs.get("signature")
+        if signature is not None:
+            out.signature = signature
         return out

     def add_signature(self, key: PrivateKey) -> None:
@@ -196,8 +197,8 @@ def get_magnet(self) -> str:
         """
         Create a magnet link for this payload.
         """
-        return (f"magnet:?xt=urn:btih:{hexlify(self.infohash).decode()}&dn={self.title.encode()}"
-                + (f"&tr={self.tracker_info.encode()}" if self.tracker_info else ""))
+        return (f"magnet:?xt=urn:btih:{hexlify(self.infohash).decode()}&dn={self.title}"
+                + (f"&tr={self.tracker_info}" if self.tracker_info else ""))


 @vp_compile
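
The magnet-link hunk is a behavioral fix as much as a typing one: interpolating str.encode() into an f-string embeds the bytes repr, b'...', into the URL. A quick demonstration:

    title = "Sintel"
    print(f"&dn={title.encode()}")  # &dn=b'Sintel'  <- bytes repr leaks into the link
    print(f"&dn={title}")           # &dn=Sintel     <- intended output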
13 changes: 4 additions & 9 deletions src/tribler/core/database/store.py
@@ -33,7 +33,6 @@
 from tribler.core.torrent_checker.dataclasses import HealthInfo

 if TYPE_CHECKING:
-    from os import PathLike
     from sqlite3 import Connection

     from ipv8.types import PrivateKey
@@ -133,7 +132,7 @@ class MetadataStore:

     def __init__( # noqa: PLR0913
         self,
-        db_filename: PathLike,
+        db_filename: str,
         private_key: PrivateKey,
         disable_sync: bool = False,
         notifier: Notifier | None = None,
@@ -241,9 +240,8 @@ def get_objects_to_create(self) -> list[Entity]:
         connection = self.db.get_connection()
         schema = self.db.schema
         provider = schema.provider
-        created_tables = set()
         return [db_object for table in schema.order_tables_to_create()
-                for db_object in table.get_objects_to_create(created_tables)
+                for db_object in table.get_objects_to_create(set())
                 if not db_object.exists(provider, connection)]

     def get_db_file_size(self) -> int:
@@ -481,7 +479,7 @@ def get_num_torrents(self) -> int:
         """
         return orm.count(self.TorrentMetadata.select(lambda g: g.metadata_type == REGULAR_TORRENT))

-    def search_keyword(self, query: str, origin_id: int | None = None) -> list[TorrentMetadata]:
+    def search_keyword(self, query: str, origin_id: int | None = None) -> Query:
         """
         Search for an FTS query, potentially restricted to a given origin id.
@@ -585,10 +583,7 @@ def get_entries_query( # noqa: C901, PLR0912, PLR0913
         pony_query = pony_query.where(lambda g: g.rowid <= max_rowid)

         if metadata_type is not None:
-            try:
-                pony_query = pony_query.where(lambda g: g.metadata_type in metadata_type)
-            except TypeError:
-                pony_query = pony_query.where(lambda g: g.metadata_type == metadata_type)
+            pony_query = pony_query.where(lambda g: g.metadata_type == metadata_type)

         pony_query = (
             pony_query.where(public_key=(b"" if channel_pk == NULL_KEY_SUBST else channel_pk))
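
Annotating search_keyword as returning Pony's lazy Query, rather than list[TorrentMetadata], matches what select() actually produces and keeps call-site slicing type-correct. A rough, runnable sketch with a hypothetical entity:

    from pony import orm

    db = orm.Database()

    class Torrent(db.Entity):
        title = orm.Required(str)

    db.bind(provider="sqlite", filename=":memory:")
    db.generate_mapping(create_tables=True)

    def search_keyword(query: str) -> orm.core.Query:
        # select() yields a lazy Query that only hits the database when
        # iterated or sliced, so callers can still write search_keyword(q)[:50].
        return orm.select(t for t in Torrent if query in t.title)

    with orm.db_session:
        Torrent(title="big buck bunny")
        print(search_keyword("bunny")[:10])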
(The remaining changed files of the 62 in this commit are not shown here.)