Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Extract the run_threaded function to the pony_utils.py #7150

Merged
merged 4 commits into from
Nov 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion scripts/seedbox/disseminator.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def commit(self):
def flush(self):
_logger.debug('Flush')

self.community.mds._db.flush() # pylint: disable=protected-access
self.community.mds.db.flush()


class Service(TinyTriblerService):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from tribler.core.components.metadata_store.db.serialization import CHANNEL_TORRENT
from tribler.core.components.metadata_store.db.store import MetadataStore
from tribler.core.utilities.notifier import Notifier
from tribler.core.utilities.pony_utils import run_threaded
from tribler.core.utilities.simpledefs import DLSTATUS_SEEDING, NTFY
from tribler.core.utilities.unicode import hexlify

Expand Down Expand Up @@ -283,7 +284,7 @@ def _process_download():
mds.process_channel_dir(channel_dirname, channel.public_key, channel.id_, external_thread=True)

try:
await mds.run_threaded(_process_download)
await run_threaded(mds.db, _process_download)
except Exception as e: # pylint: disable=broad-except # pragma: no cover
self._logger.error("Error when processing channel dir download: %s", e)

Expand Down
85 changes: 33 additions & 52 deletions src/tribler/core/components/metadata_store/db/store.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
import logging
import re
import threading
from asyncio import get_event_loop
from datetime import datetime, timedelta
from time import sleep, time
from typing import Optional, Union

from lz4.frame import LZ4FrameDecompressor

from pony import orm
from pony.orm import db_session, desc, left_join, raw_sql, select
from pony.orm.dbproviders.sqlite import keep_exception
Expand Down Expand Up @@ -50,12 +47,11 @@
from tribler.core.exceptions import InvalidSignatureException
from tribler.core.utilities.notifier import Notifier
from tribler.core.utilities.path_util import Path
from tribler.core.utilities.pony_utils import get_max, get_or_create
from tribler.core.utilities.pony_utils import get_max, get_or_create, run_threaded
from tribler.core.utilities.search_utils import torrent_rank
from tribler.core.utilities.unicode import hexlify
from tribler.core.utilities.utilities import MEMORY_DB


BETA_DB_VERSIONS = [0, 1, 2, 3, 4, 5]
CURRENT_DB_VERSION = 14

Expand Down Expand Up @@ -164,12 +160,12 @@ def __init__(
# We have to dynamically define/init ORM-managed entities here to be able to support
# multiple sessions in Tribler. ORM-managed classes are bound to the database instance
# at definition.
self._db = orm.Database()
self.db = orm.Database()

# This attribute is internally called by Pony on startup, though pylint cannot detect it
# with the static analysis.
# pylint: disable=unused-variable
@self._db.on_connect(provider='sqlite')
@self.db.on_connect(provider='sqlite')
def on_connect(_, connection):
cursor = connection.cursor()
cursor.execute("PRAGMA journal_mode = WAL")
Expand All @@ -189,31 +185,31 @@ def on_connect(_, connection):

# pylint: enable=unused-variable

self.MiscData = misc.define_binding(self._db)
self.MiscData = misc.define_binding(self.db)

self.TrackerState = tracker_state.define_binding(self._db)
self.TorrentState = torrent_state.define_binding(self._db)
self.TrackerState = tracker_state.define_binding(self.db)
self.TorrentState = torrent_state.define_binding(self.db)

self.ChannelNode = channel_node.define_binding(self._db, logger=self._logger, key=my_key)
self.ChannelNode = channel_node.define_binding(self.db, logger=self._logger, key=my_key)

self.MetadataNode = metadata_node.define_binding(self._db)
self.CollectionNode = collection_node.define_binding(self._db)
self.MetadataNode = metadata_node.define_binding(self.db)
self.CollectionNode = collection_node.define_binding(self.db)
self.TorrentMetadata = torrent_metadata.define_binding(
self._db,
self.db,
notifier=notifier,
tag_processor_version=tag_processor_version
)
self.ChannelMetadata = channel_metadata.define_binding(self._db)
self.ChannelMetadata = channel_metadata.define_binding(self.db)

self.JsonNode = json_node.define_binding(self._db, db_version)
self.ChannelDescription = channel_description.define_binding(self._db)
self.JsonNode = json_node.define_binding(self.db, db_version)
self.ChannelDescription = channel_description.define_binding(self.db)

self.BinaryNode = binary_node.define_binding(self._db, db_version)
self.ChannelThumbnail = channel_thumbnail.define_binding(self._db)
self.BinaryNode = binary_node.define_binding(self.db, db_version)
self.ChannelThumbnail = channel_thumbnail.define_binding(self.db)

self.ChannelVote = channel_vote.define_binding(self._db)
self.ChannelPeer = channel_peer.define_binding(self._db)
self.Vsids = vsids.define_binding(self._db)
self.ChannelVote = channel_vote.define_binding(self.db)
self.ChannelPeer = channel_peer.define_binding(self.db)
self.Vsids = vsids.define_binding(self.db)

self.ChannelMetadata._channels_dir = channels_dir # pylint: disable=protected-access

Expand All @@ -224,13 +220,13 @@ def on_connect(_, connection):
create_db = not db_filename.is_file()
db_path_string = str(db_filename)

self._db.bind(provider='sqlite', filename=db_path_string, create_db=create_db, timeout=120.0)
self._db.generate_mapping(
self.db.bind(provider='sqlite', filename=db_path_string, create_db=create_db, timeout=120.0)
self.db.generate_mapping(
create_tables=create_db, check_tables=check_tables
) # Must be run out of session scope
if create_db:
with db_session(ddl=True):
self._db.execute(sql_create_fts_table)
self.db.execute(sql_create_fts_table)
self.create_fts_triggers()
self.create_torrentstate_triggers()
self.create_partial_indexes()
Expand All @@ -245,15 +241,6 @@ def on_connect(_, connection):
default_vsids = self.Vsids.create_default_vsids()
self.ChannelMetadata.votes_scaling = default_vsids.max_val

async def run_threaded(self, func, *args, **kwargs):
def wrapper():
try:
return func(*args, **kwargs)
finally:
self.disconnect_thread()

return await get_event_loop().run_in_executor(None, wrapper)

def set_value(self, key: str, value: str):
key_value = get_or_create(self.MiscData, name=key)
key_value.value = value
Expand All @@ -263,14 +250,14 @@ def get_value(self, key: str, default: Optional[str] = None) -> Optional[str]:
return data.value if data else default

def drop_indexes(self):
cursor = self._db.get_connection().cursor()
cursor = self.db.get_connection().cursor()
cursor.execute("select name from sqlite_master where type='index' and name like 'idx_%'")
for [index_name] in cursor.fetchall():
cursor.execute(f"drop index {index_name}")

def get_objects_to_create(self):
connection = self._db.get_connection()
schema = self._db.schema
connection = self.db.get_connection()
schema = self.db.schema
provider = schema.provider
created_tables = set()
result = []
Expand All @@ -284,28 +271,28 @@ def get_db_file_size(self):
return 0 if self.db_path is MEMORY_DB else Path(self.db_path).size()

def drop_fts_triggers(self):
cursor = self._db.get_connection().cursor()
cursor = self.db.get_connection().cursor()
cursor.execute("select name from sqlite_master where type='trigger' and name like 'fts_%'")
for [trigger_name] in cursor.fetchall():
cursor.execute(f"drop trigger {trigger_name}")

def create_fts_triggers(self):
cursor = self._db.get_connection().cursor()
cursor = self.db.get_connection().cursor()
cursor.execute(sql_add_fts_trigger_insert)
cursor.execute(sql_add_fts_trigger_delete)
cursor.execute(sql_add_fts_trigger_update)

def fill_fts_index(self):
cursor = self._db.get_connection().cursor()
cursor = self.db.get_connection().cursor()
cursor.execute("insert into FtsIndex(rowid, title) select rowid, title from ChannelNode")

def create_torrentstate_triggers(self):
cursor = self._db.get_connection().cursor()
cursor = self.db.get_connection().cursor()
cursor.execute(sql_add_torrentstate_trigger_after_insert)
cursor.execute(sql_add_torrentstate_trigger_after_update)

def create_partial_indexes(self):
cursor = self._db.get_connection().cursor()
cursor = self.db.get_connection().cursor()
cursor.execute(sql_create_partial_index_channelnode_subscribed)
cursor.execute(sql_create_partial_index_channelnode_metadata_type)

Expand All @@ -332,13 +319,7 @@ def vote_bump(self, public_key, id_, voter_pk):

def shutdown(self):
self._shutting_down = True
self._db.disconnect()

def disconnect_thread(self):
# Ugly workaround for closing threadpool connections
# Remark: maybe subclass ThreadPoolExecutor to handle this automatically?
if not isinstance(threading.current_thread(), threading._MainThread): # pylint: disable=W0212
self._db.disconnect()
self.db.disconnect()

@staticmethod
def get_list_of_channel_blobs_to_process(dirname, start_timestamp):
Expand Down Expand Up @@ -467,7 +448,7 @@ def process_mdblob_file(self, filepath, **kwargs):

async def process_compressed_mdblob_threaded(self, compressed_data, **kwargs):
try:
return await self.run_threaded(self.process_compressed_mdblob, compressed_data, **kwargs)
return await run_threaded(self.db, self.process_compressed_mdblob, compressed_data, **kwargs)
except Exception as e: # pylint: disable=broad-except # pragma: no cover
self._logger.warning("DB transaction error when tried to process compressed mdblob: %s", str(e))
return None
Expand Down Expand Up @@ -787,7 +768,7 @@ def get_entries_query(
return pony_query

async def get_entries_threaded(self, **kwargs):
return await self.run_threaded(self.get_entries, **kwargs)
return await run_threaded(self.db, self.get_entries, **kwargs)

@db_session
def get_entries(self, first=1, last=None, **kwargs):
Expand Down Expand Up @@ -838,7 +819,7 @@ def get_auto_complete_terms(self, text, max_terms, limit=10):
suggestion_re = re.compile(suggestion_pattern, re.UNICODE)

with db_session:
titles = self._db.select("""
titles = self.db.select("""
cn.title
FROM ChannelNode cn
INNER JOIN FtsIndex ON cn.rowid = FtsIndex.rowid
Expand Down
25 changes: 12 additions & 13 deletions src/tribler/core/components/metadata_store/db/tests/test_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,10 @@
from datetime import datetime
from unittest.mock import patch

import pytest
from ipv8.keyvault.crypto import default_eccrypto

from pony.orm import db_session

import pytest

from tribler.core.components.metadata_store.db.orm_bindings.channel_metadata import (
CHANNEL_DIR_NAME_LENGTH,
entries_to_chunk,
Expand All @@ -29,8 +27,10 @@
from tribler.core.components.metadata_store.tests.test_channel_download import CHANNEL_METADATA_UPDATED
from tribler.core.tests.tools.common import TESTS_DATA_DIR
from tribler.core.utilities.path_util import Path
from tribler.core.utilities.pony_utils import run_threaded
from tribler.core.utilities.utilities import random_infohash


# pylint: disable=protected-access,unused-argument


Expand Down Expand Up @@ -269,13 +269,12 @@ def test_process_forbidden_payload(metadata_store):
def test_process_payload(metadata_store):
sender_key = default_eccrypto.generate_key("curve25519")
for md_class in (
metadata_store.ChannelMetadata,
metadata_store.TorrentMetadata,
metadata_store.CollectionNode,
metadata_store.ChannelDescription,
metadata_store.ChannelThumbnail,
metadata_store.ChannelMetadata,
metadata_store.TorrentMetadata,
metadata_store.CollectionNode,
metadata_store.ChannelDescription,
metadata_store.ChannelThumbnail,
):

node, node_payload, node_deleted_payload = get_payloads(md_class, sender_key)
node_dict = node.to_dict()
node.delete()
Expand Down Expand Up @@ -333,8 +332,8 @@ def test_process_payload_with_known_channel_public_key(metadata_store):

# Check accepting a payload with matching public key
assert (
metadata_store.process_payload(payload, channel_public_key=key1.pub().key_to_bin()[10:])[0].obj_state
== ObjState.NEW_OBJECT
metadata_store.process_payload(payload, channel_public_key=key1.pub().key_to_bin()[10:])[0].obj_state
== ObjState.NEW_OBJECT
)
assert metadata_store.TorrentMetadata.get()

Expand Down Expand Up @@ -465,8 +464,8 @@ def f1(a, b, *, c, d):
return threading.get_ident()
raise ThreadedTestException('test exception')

result = await metadata_store.run_threaded(f1, 1, 2, c=3, d=4)
result = await run_threaded(metadata_store.db, f1, 1, 2, c=3, d=4)
assert result != thread_id

with pytest.raises(ThreadedTestException, match='^test exception$'):
await metadata_store.run_threaded(f1, 1, 2, c=5, d=6)
await run_threaded(metadata_store.db, f1, 1, 2, c=5, d=6)
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from tribler.core.components.metadata_store.utils import RequestTimeoutException
from tribler.core.components.knowledge.community.knowledge_validator import is_valid_resource
from tribler.core.components.knowledge.db.knowledge_db import ResourceType
from tribler.core.utilities.pony_utils import run_threaded
from tribler.core.utilities.unicode import hexlify

BINARY_FIELDS = ("infohash", "channel_pk")
Expand Down Expand Up @@ -213,12 +214,13 @@ async def process_rpc_query(self, sanitized_parameters: Dict[str, Any]) -> List:
:raises ValueError: if no JSON could be decoded.
:raises pony.orm.dbapiprovider.OperationalError: if an illegal query was performed.
"""
# tags should be extracted because `get_entries_threaded` doesn't expect them as a parameter
tags = sanitized_parameters.pop('tags', None)
if self.knowledge_db:
# tags should be extracted because `get_entries_threaded` doesn't expect them as a parameter
tags = sanitized_parameters.pop('tags', None)

infohash_set = await self.mds.run_threaded(self.search_for_tags, tags)
if infohash_set:
sanitized_parameters['infohash_set'] = {bytes.fromhex(s) for s in infohash_set}
infohash_set = await run_threaded(self.knowledge_db.instance, self.search_for_tags, tags)
if infohash_set:
sanitized_parameters['infohash_set'] = {bytes.fromhex(s) for s in infohash_set}

return await self.mds.get_entries_threaded(**sanitized_parameters)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from tribler.core.components.metadata_store.restapi.metadata_schema import MetadataParameters, MetadataSchema
from tribler.core.components.restapi.rest.rest_endpoint import HTTP_BAD_REQUEST, RESTResponse
from tribler.core.components.knowledge.db.knowledge_db import ResourceType
from tribler.core.utilities.pony_utils import run_threaded
from tribler.core.utilities.utilities import froze_it

SNIPPETS_TO_SHOW = 3 # The number of snippets we return from the search results
Expand Down Expand Up @@ -151,7 +152,7 @@ def search_db():
if infohash_set:
sanitized['infohash_set'] = {bytes.fromhex(s) for s in infohash_set}

search_results, total, max_rowid = await mds.run_threaded(search_db)
search_results, total, max_rowid = await run_threaded(mds.db, search_db)
except Exception as e: # pylint: disable=broad-except; # pragma: no cover
self._logger.exception("Error while performing DB search: %s: %s", type(e).__name__, e)
return RESTResponse(status=HTTP_BAD_REQUEST)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,11 @@
from pathlib import Path
from unittest.mock import Mock, patch

import pytest
from ipv8.keyvault.crypto import default_eccrypto

from lz4.frame import LZ4FrameDecompressor

from pony.orm import ObjectNotFound, db_session

import pytest

from tribler.core.components.libtorrent.torrentdef import TorrentDef
from tribler.core.components.metadata_store.db.orm_bindings.channel_metadata import (
CHANNEL_DIR_NAME_LENGTH,
Expand All @@ -31,6 +28,7 @@
from tribler.core.utilities.simpledefs import CHANNEL_STATE
from tribler.core.utilities.utilities import random_infohash


# pylint: disable=protected-access


Expand Down Expand Up @@ -80,7 +78,7 @@ def mds_with_some_torrents_fixture(metadata_store):
# torrent6 aaa zzz

def save():
metadata_store._db.flush() # pylint: disable=W0212
metadata_store.db.flush()

def new_channel(**kwargs):
params = dict(subscribed=True, share=True, status=NEW, infohash=random_infohash())
Expand Down
Loading