From 8b9c3eff19c951faf560ccf59b6143a5235d2219 Mon Sep 17 00:00:00 2001 From: drew2a Date: Mon, 7 Nov 2022 13:47:03 +0100 Subject: [PATCH 1/4] Extract the `run_threaded` function to the `pony_utils.py` --- scripts/seedbox/disseminator.py | 2 +- .../gigachannel_manager.py | 3 +- .../components/metadata_store/db/store.py | 72 +++++++++---------- .../metadata_store/db/tests/test_store.py | 25 ++++--- .../remote_query_community.py | 12 ++-- .../metadata_store/restapi/search_endpoint.py | 3 +- .../tests/test_channel_metadata.py | 2 +- .../community/popularity_community.py | 3 +- src/tribler/core/upgrade/db8_to_db10.py | 6 +- .../core/upgrade/tests/test_upgrader.py | 16 ++--- src/tribler/core/upgrade/upgrade.py | 14 ++-- src/tribler/core/utilities/pony_utils.py | 31 +++++++- 12 files changed, 107 insertions(+), 82 deletions(-) diff --git a/scripts/seedbox/disseminator.py b/scripts/seedbox/disseminator.py index ca50024fb60..5d51ed76468 100644 --- a/scripts/seedbox/disseminator.py +++ b/scripts/seedbox/disseminator.py @@ -159,7 +159,7 @@ def commit(self): def flush(self): _logger.debug('Flush') - self.community.mds._db.flush() # pylint: disable=protected-access + self.community.mds.db.flush() # pylint: disable=protected-access class Service(TinyTriblerService): diff --git a/src/tribler/core/components/gigachannel_manager/gigachannel_manager.py b/src/tribler/core/components/gigachannel_manager/gigachannel_manager.py index b9496d948ee..4bfb860ae6e 100644 --- a/src/tribler/core/components/gigachannel_manager/gigachannel_manager.py +++ b/src/tribler/core/components/gigachannel_manager/gigachannel_manager.py @@ -14,6 +14,7 @@ from tribler.core.components.metadata_store.db.serialization import CHANNEL_TORRENT from tribler.core.components.metadata_store.db.store import MetadataStore from tribler.core.utilities.notifier import Notifier +from tribler.core.utilities.pony_utils import run_threaded from tribler.core.utilities.simpledefs import DLSTATUS_SEEDING, NTFY from tribler.core.utilities.unicode import hexlify @@ -283,7 +284,7 @@ def _process_download(): mds.process_channel_dir(channel_dirname, channel.public_key, channel.id_, external_thread=True) try: - await mds.run_threaded(_process_download) + await run_threaded(mds.db, _process_download) except Exception as e: # pylint: disable=broad-except # pragma: no cover self._logger.error("Error when processing channel dir download: %s", e) diff --git a/src/tribler/core/components/metadata_store/db/store.py b/src/tribler/core/components/metadata_store/db/store.py index 78089c42a55..daeb264c151 100644 --- a/src/tribler/core/components/metadata_store/db/store.py +++ b/src/tribler/core/components/metadata_store/db/store.py @@ -50,7 +50,7 @@ from tribler.core.exceptions import InvalidSignatureException from tribler.core.utilities.notifier import Notifier from tribler.core.utilities.path_util import Path -from tribler.core.utilities.pony_utils import get_max, get_or_create +from tribler.core.utilities.pony_utils import get_max, get_or_create, run_threaded from tribler.core.utilities.search_utils import torrent_rank from tribler.core.utilities.unicode import hexlify from tribler.core.utilities.utilities import MEMORY_DB @@ -164,12 +164,12 @@ def __init__( # We have to dynamically define/init ORM-managed entities here to be able to support # multiple sessions in Tribler. ORM-managed classes are bound to the database instance # at definition. 
- self._db = orm.Database() + self.db = orm.Database() # This attribute is internally called by Pony on startup, though pylint cannot detect it # with the static analysis. # pylint: disable=unused-variable - @self._db.on_connect(provider='sqlite') + @self.db.on_connect(provider='sqlite') def on_connect(_, connection): cursor = connection.cursor() cursor.execute("PRAGMA journal_mode = WAL") @@ -189,31 +189,31 @@ def on_connect(_, connection): # pylint: enable=unused-variable - self.MiscData = misc.define_binding(self._db) + self.MiscData = misc.define_binding(self.db) - self.TrackerState = tracker_state.define_binding(self._db) - self.TorrentState = torrent_state.define_binding(self._db) + self.TrackerState = tracker_state.define_binding(self.db) + self.TorrentState = torrent_state.define_binding(self.db) - self.ChannelNode = channel_node.define_binding(self._db, logger=self._logger, key=my_key) + self.ChannelNode = channel_node.define_binding(self.db, logger=self._logger, key=my_key) - self.MetadataNode = metadata_node.define_binding(self._db) - self.CollectionNode = collection_node.define_binding(self._db) + self.MetadataNode = metadata_node.define_binding(self.db) + self.CollectionNode = collection_node.define_binding(self.db) self.TorrentMetadata = torrent_metadata.define_binding( - self._db, + self.db, notifier=notifier, tag_processor_version=tag_processor_version ) - self.ChannelMetadata = channel_metadata.define_binding(self._db) + self.ChannelMetadata = channel_metadata.define_binding(self.db) - self.JsonNode = json_node.define_binding(self._db, db_version) - self.ChannelDescription = channel_description.define_binding(self._db) + self.JsonNode = json_node.define_binding(self.db, db_version) + self.ChannelDescription = channel_description.define_binding(self.db) - self.BinaryNode = binary_node.define_binding(self._db, db_version) - self.ChannelThumbnail = channel_thumbnail.define_binding(self._db) + self.BinaryNode = binary_node.define_binding(self.db, db_version) + self.ChannelThumbnail = channel_thumbnail.define_binding(self.db) - self.ChannelVote = channel_vote.define_binding(self._db) - self.ChannelPeer = channel_peer.define_binding(self._db) - self.Vsids = vsids.define_binding(self._db) + self.ChannelVote = channel_vote.define_binding(self.db) + self.ChannelPeer = channel_peer.define_binding(self.db) + self.Vsids = vsids.define_binding(self.db) self.ChannelMetadata._channels_dir = channels_dir # pylint: disable=protected-access @@ -224,13 +224,13 @@ def on_connect(_, connection): create_db = not db_filename.is_file() db_path_string = str(db_filename) - self._db.bind(provider='sqlite', filename=db_path_string, create_db=create_db, timeout=120.0) - self._db.generate_mapping( + self.db.bind(provider='sqlite', filename=db_path_string, create_db=create_db, timeout=120.0) + self.db.generate_mapping( create_tables=create_db, check_tables=check_tables ) # Must be run out of session scope if create_db: with db_session(ddl=True): - self._db.execute(sql_create_fts_table) + self.db.execute(sql_create_fts_table) self.create_fts_triggers() self.create_torrentstate_triggers() self.create_partial_indexes() @@ -263,14 +263,14 @@ def get_value(self, key: str, default: Optional[str] = None) -> Optional[str]: return data.value if data else default def drop_indexes(self): - cursor = self._db.get_connection().cursor() + cursor = self.db.get_connection().cursor() cursor.execute("select name from sqlite_master where type='index' and name like 'idx_%'") for [index_name] in cursor.fetchall(): 
cursor.execute(f"drop index {index_name}") def get_objects_to_create(self): - connection = self._db.get_connection() - schema = self._db.schema + connection = self.db.get_connection() + schema = self.db.schema provider = schema.provider created_tables = set() result = [] @@ -284,28 +284,28 @@ def get_db_file_size(self): return 0 if self.db_path is MEMORY_DB else Path(self.db_path).size() def drop_fts_triggers(self): - cursor = self._db.get_connection().cursor() + cursor = self.db.get_connection().cursor() cursor.execute("select name from sqlite_master where type='trigger' and name like 'fts_%'") for [trigger_name] in cursor.fetchall(): cursor.execute(f"drop trigger {trigger_name}") def create_fts_triggers(self): - cursor = self._db.get_connection().cursor() + cursor = self.db.get_connection().cursor() cursor.execute(sql_add_fts_trigger_insert) cursor.execute(sql_add_fts_trigger_delete) cursor.execute(sql_add_fts_trigger_update) def fill_fts_index(self): - cursor = self._db.get_connection().cursor() + cursor = self.db.get_connection().cursor() cursor.execute("insert into FtsIndex(rowid, title) select rowid, title from ChannelNode") def create_torrentstate_triggers(self): - cursor = self._db.get_connection().cursor() + cursor = self.db.get_connection().cursor() cursor.execute(sql_add_torrentstate_trigger_after_insert) cursor.execute(sql_add_torrentstate_trigger_after_update) def create_partial_indexes(self): - cursor = self._db.get_connection().cursor() + cursor = self.db.get_connection().cursor() cursor.execute(sql_create_partial_index_channelnode_subscribed) cursor.execute(sql_create_partial_index_channelnode_metadata_type) @@ -332,13 +332,7 @@ def vote_bump(self, public_key, id_, voter_pk): def shutdown(self): self._shutting_down = True - self._db.disconnect() - - def disconnect_thread(self): - # Ugly workaround for closing threadpool connections - # Remark: maybe subclass ThreadPoolExecutor to handle this automatically? 
- if not isinstance(threading.current_thread(), threading._MainThread): # pylint: disable=W0212 - self._db.disconnect() + self.db.disconnect() @staticmethod def get_list_of_channel_blobs_to_process(dirname, start_timestamp): @@ -467,7 +461,7 @@ def process_mdblob_file(self, filepath, **kwargs): async def process_compressed_mdblob_threaded(self, compressed_data, **kwargs): try: - return await self.run_threaded(self.process_compressed_mdblob, compressed_data, **kwargs) + return await run_threaded(self.db, self.process_compressed_mdblob, compressed_data, **kwargs) except Exception as e: # pylint: disable=broad-except # pragma: no cover self._logger.warning("DB transaction error when tried to process compressed mdblob: %s", str(e)) return None @@ -787,7 +781,7 @@ def get_entries_query( return pony_query async def get_entries_threaded(self, **kwargs): - return await self.run_threaded(self.get_entries, **kwargs) + return await run_threaded(self.db, self.get_entries, **kwargs) @db_session def get_entries(self, first=1, last=None, **kwargs): @@ -838,7 +832,7 @@ def get_auto_complete_terms(self, text, max_terms, limit=10): suggestion_re = re.compile(suggestion_pattern, re.UNICODE) with db_session: - titles = self._db.select(""" + titles = self.db.select(""" cn.title FROM ChannelNode cn INNER JOIN FtsIndex ON cn.rowid = FtsIndex.rowid diff --git a/src/tribler/core/components/metadata_store/db/tests/test_store.py b/src/tribler/core/components/metadata_store/db/tests/test_store.py index 0fc378bbdc4..c3d83d9cc98 100644 --- a/src/tribler/core/components/metadata_store/db/tests/test_store.py +++ b/src/tribler/core/components/metadata_store/db/tests/test_store.py @@ -6,12 +6,10 @@ from datetime import datetime from unittest.mock import patch +import pytest from ipv8.keyvault.crypto import default_eccrypto - from pony.orm import db_session -import pytest - from tribler.core.components.metadata_store.db.orm_bindings.channel_metadata import ( CHANNEL_DIR_NAME_LENGTH, entries_to_chunk, @@ -29,8 +27,10 @@ from tribler.core.components.metadata_store.tests.test_channel_download import CHANNEL_METADATA_UPDATED from tribler.core.tests.tools.common import TESTS_DATA_DIR from tribler.core.utilities.path_util import Path +from tribler.core.utilities.pony_utils import run_threaded from tribler.core.utilities.utilities import random_infohash + # pylint: disable=protected-access,unused-argument @@ -269,13 +269,12 @@ def test_process_forbidden_payload(metadata_store): def test_process_payload(metadata_store): sender_key = default_eccrypto.generate_key("curve25519") for md_class in ( - metadata_store.ChannelMetadata, - metadata_store.TorrentMetadata, - metadata_store.CollectionNode, - metadata_store.ChannelDescription, - metadata_store.ChannelThumbnail, + metadata_store.ChannelMetadata, + metadata_store.TorrentMetadata, + metadata_store.CollectionNode, + metadata_store.ChannelDescription, + metadata_store.ChannelThumbnail, ): - node, node_payload, node_deleted_payload = get_payloads(md_class, sender_key) node_dict = node.to_dict() node.delete() @@ -333,8 +332,8 @@ def test_process_payload_with_known_channel_public_key(metadata_store): # Check accepting a payload with matching public key assert ( - metadata_store.process_payload(payload, channel_public_key=key1.pub().key_to_bin()[10:])[0].obj_state - == ObjState.NEW_OBJECT + metadata_store.process_payload(payload, channel_public_key=key1.pub().key_to_bin()[10:])[0].obj_state + == ObjState.NEW_OBJECT ) assert metadata_store.TorrentMetadata.get() @@ -465,8 +464,8 @@ def 
f1(a, b, *, c, d): return threading.get_ident() raise ThreadedTestException('test exception') - result = await metadata_store.run_threaded(f1, 1, 2, c=3, d=4) + result = await run_threaded(metadata_store.db, f1, 1, 2, c=3, d=4) assert result != thread_id with pytest.raises(ThreadedTestException, match='^test exception$'): - await metadata_store.run_threaded(f1, 1, 2, c=5, d=6) + await run_threaded(metadata_store.db, f1, 1, 2, c=5, d=6) diff --git a/src/tribler/core/components/metadata_store/remote_query_community/remote_query_community.py b/src/tribler/core/components/metadata_store/remote_query_community/remote_query_community.py index aa6fd817fef..ac4cef885f5 100644 --- a/src/tribler/core/components/metadata_store/remote_query_community/remote_query_community.py +++ b/src/tribler/core/components/metadata_store/remote_query_community/remote_query_community.py @@ -23,6 +23,7 @@ from tribler.core.components.metadata_store.utils import RequestTimeoutException from tribler.core.components.knowledge.community.knowledge_validator import is_valid_resource from tribler.core.components.knowledge.db.knowledge_db import ResourceType +from tribler.core.utilities.pony_utils import run_threaded from tribler.core.utilities.unicode import hexlify BINARY_FIELDS = ("infohash", "channel_pk") @@ -213,12 +214,13 @@ async def process_rpc_query(self, sanitized_parameters: Dict[str, Any]) -> List: :raises ValueError: if no JSON could be decoded. :raises pony.orm.dbapiprovider.OperationalError: if an illegal query was performed. """ - # tags should be extracted because `get_entries_threaded` doesn't expect them as a parameter - tags = sanitized_parameters.pop('tags', None) + if self.knowledge_db: + # tags should be extracted because `get_entries_threaded` doesn't expect them as a parameter + tags = sanitized_parameters.pop('tags', None) - infohash_set = await self.mds.run_threaded(self.search_for_tags, tags) - if infohash_set: - sanitized_parameters['infohash_set'] = {bytes.fromhex(s) for s in infohash_set} + infohash_set = await run_threaded(self.knowledge_db.instance, self.search_for_tags, tags) + if infohash_set: + sanitized_parameters['infohash_set'] = {bytes.fromhex(s) for s in infohash_set} return await self.mds.get_entries_threaded(**sanitized_parameters) diff --git a/src/tribler/core/components/metadata_store/restapi/search_endpoint.py b/src/tribler/core/components/metadata_store/restapi/search_endpoint.py index a7a1963871f..e814d9e843f 100644 --- a/src/tribler/core/components/metadata_store/restapi/search_endpoint.py +++ b/src/tribler/core/components/metadata_store/restapi/search_endpoint.py @@ -14,6 +14,7 @@ from tribler.core.components.metadata_store.restapi.metadata_schema import MetadataParameters, MetadataSchema from tribler.core.components.restapi.rest.rest_endpoint import HTTP_BAD_REQUEST, RESTResponse from tribler.core.components.knowledge.db.knowledge_db import ResourceType +from tribler.core.utilities.pony_utils import run_threaded from tribler.core.utilities.utilities import froze_it SNIPPETS_TO_SHOW = 3 # The number of snippets we return from the search results @@ -151,7 +152,7 @@ def search_db(): if infohash_set: sanitized['infohash_set'] = {bytes.fromhex(s) for s in infohash_set} - search_results, total, max_rowid = await mds.run_threaded(search_db) + search_results, total, max_rowid = await run_threaded(mds.db, search_db) except Exception as e: # pylint: disable=broad-except; # pragma: no cover self._logger.exception("Error while performing DB search: %s: %s", type(e).__name__, e) return 
RESTResponse(status=HTTP_BAD_REQUEST) diff --git a/src/tribler/core/components/metadata_store/tests/test_channel_metadata.py b/src/tribler/core/components/metadata_store/tests/test_channel_metadata.py index 7f369aeedc0..fcb2426472b 100644 --- a/src/tribler/core/components/metadata_store/tests/test_channel_metadata.py +++ b/src/tribler/core/components/metadata_store/tests/test_channel_metadata.py @@ -80,7 +80,7 @@ def mds_with_some_torrents_fixture(metadata_store): # torrent6 aaa zzz def save(): - metadata_store._db.flush() # pylint: disable=W0212 + metadata_store.db.flush() # pylint: disable=W0212 def new_channel(**kwargs): params = dict(subscribed=True, share=True, status=NEW, infohash=random_infohash()) diff --git a/src/tribler/core/components/popularity/community/popularity_community.py b/src/tribler/core/components/popularity/community/popularity_community.py index 8ee15b6f276..e47531d3940 100644 --- a/src/tribler/core/components/popularity/community/popularity_community.py +++ b/src/tribler/core/components/popularity/community/popularity_community.py @@ -9,6 +9,7 @@ from tribler.core.components.metadata_store.remote_query_community.remote_query_community import RemoteQueryCommunity from tribler.core.components.popularity.community.payload import TorrentsHealthPayload, PopularTorrentsRequest from tribler.core.components.popularity.community.version_community_mixin import VersionCommunityMixin +from tribler.core.utilities.pony_utils import run_threaded from tribler.core.utilities.unicode import hexlify from tribler.core.utilities.utilities import get_normally_distributed_positive_integers @@ -79,7 +80,7 @@ async def on_torrents_health(self, peer, payload): torrents = payload.random_torrents + payload.torrents_checked - for infohash in await self.mds.run_threaded(self.process_torrents_health, torrents): + for infohash in await run_threaded(self.mds.db, self.process_torrents_health, torrents): # Get a single result per infohash to avoid duplicates self.send_remote_select(peer=peer, infohash=infohash, last=1) diff --git a/src/tribler/core/upgrade/db8_to_db10.py b/src/tribler/core/upgrade/db8_to_db10.py index 3db159bd91a..ead9d0524cf 100644 --- a/src/tribler/core/upgrade/db8_to_db10.py +++ b/src/tribler/core/upgrade/db8_to_db10.py @@ -185,7 +185,7 @@ def index_callback_handler(): # Recreate table indexes with db_session(ddl=True): - connection = mds._db.get_connection() + connection = mds.db.get_connection() try: db_objects = mds.get_objects_to_create() index_total = len(db_objects) @@ -193,7 +193,7 @@ def index_callback_handler(): index_num = i t1 = now() connection.set_progress_handler(index_callback_handler, 5000) - obj.create(mds._db.schema.provider, connection) + obj.create(mds.db.schema.provider, connection) duration = now() - t1 self._logger.info(f"Upgrade: created {obj.name} in {duration:.2f} seconds") finally: @@ -216,7 +216,7 @@ def fts_callback_handler(): # Create FTS index with db_session(ddl=True): mds.create_fts_triggers() - connection = mds._db.get_connection() + connection = mds.db.get_connection() connection.set_progress_handler(fts_callback_handler, 5000) try: t = now() diff --git a/src/tribler/core/upgrade/tests/test_upgrader.py b/src/tribler/core/upgrade/tests/test_upgrader.py index 3b7b77390f0..692b2cd3243 100644 --- a/src/tribler/core/upgrade/tests/test_upgrader.py +++ b/src/tribler/core/upgrade/tests/test_upgrader.py @@ -63,7 +63,7 @@ def test_upgrade_pony_db_complete(upgrader, channels_dir, state_dir, trustchain_ upgrader.run() mds = MetadataStore(mds_path, 
channels_dir, trustchain_keypair) - db = mds._db # pylint: disable=protected-access + db = mds.db # pylint: disable=protected-access existing_indexes = [ 'idx_channelnode__metadata_type__partial', @@ -137,7 +137,7 @@ def test_upgrade_pony_10to11(upgrader, channels_dir, mds_path, trustchain_keypai mds = MetadataStore(mds_path, channels_dir, trustchain_keypair, check_tables=False, db_version=11) with db_session: # pylint: disable=protected-access - assert upgrader.column_exists_in_table(mds._db, 'TorrentState', 'self_checked') + assert upgrader.column_exists_in_table(mds.db, 'TorrentState', 'self_checked') assert mds.get_value("db_version") == '11' mds.shutdown() @@ -149,9 +149,9 @@ def test_upgrade_pony11to12(upgrader, channels_dir, mds_path, trustchain_keypair mds = MetadataStore(mds_path, channels_dir, trustchain_keypair, check_tables=False, db_version=11) with db_session: # pylint: disable=protected-access - assert upgrader.column_exists_in_table(mds._db, 'ChannelNode', 'json_text') - assert upgrader.column_exists_in_table(mds._db, 'ChannelNode', 'binary_data') - assert upgrader.column_exists_in_table(mds._db, 'ChannelNode', 'data_type') + assert upgrader.column_exists_in_table(mds.db, 'ChannelNode', 'json_text') + assert upgrader.column_exists_in_table(mds.db, 'ChannelNode', 'binary_data') + assert upgrader.column_exists_in_table(mds.db, 'ChannelNode', 'data_type') assert mds.get_value("db_version") == '12' mds.shutdown() @@ -166,7 +166,7 @@ def test_upgrade_pony13to14(upgrader: TriblerUpgrader, state_dir, channels_dir, mds = MetadataStore(mds_path, channels_dir, trustchain_keypair, check_tables=False) with db_session: - assert upgrader.column_exists_in_table(mds._db, 'ChannelNode', 'tag_processor_version') + assert upgrader.column_exists_in_table(mds.db, 'ChannelNode', 'tag_processor_version') assert mds.get_value('db_version') == '14' @@ -188,7 +188,7 @@ def _exists(db, table, column): return upgrader.column_exists_in_table(db, table, column) # The end result is the same as in the previous test - assert _exists(mds._db, 'ChannelNode', 'tag_processor_version') + assert _exists(mds.db, 'ChannelNode', 'tag_processor_version') assert _exists(tags.instance, 'TorrentTagOp', 'auto_generated') assert mds.get_value('db_version') == '14' @@ -199,7 +199,7 @@ def test_upgrade_pony12to13(upgrader, channels_dir, mds_path, trustchain_keypair upgrader.upgrade_pony_db_12to13() mds = MetadataStore(mds_path, channels_dir, trustchain_keypair, check_tables=False, db_version=12) - db = mds._db # pylint: disable=protected-access + db = mds.db # pylint: disable=protected-access existing_indexes = [ 'idx_channelnode__metadata_type__partial', diff --git a/src/tribler/core/upgrade/upgrade.py b/src/tribler/core/upgrade/upgrade.py index 199d8ae5e9a..56cee19a3ba 100644 --- a/src/tribler/core/upgrade/upgrade.py +++ b/src/tribler/core/upgrade/upgrade.py @@ -205,7 +205,7 @@ def do_upgrade_pony_db_12to13(self, mds): from_version = 12 to_version = 13 - db = mds._db # pylint: disable=protected-access + db = mds.db # pylint: disable=protected-access with db_session: db_version = mds.MiscData.get(name="db_version") @@ -255,8 +255,8 @@ def add_column(db, table_name, column_name, column_type): column_type='BOOLEAN') tags.instance.commit() - add_column(db=mds._db, table_name='ChannelNode', column_name='tag_processor_version', column_type='INT') - mds._db.commit() + add_column(db=mds.db, table_name='ChannelNode', column_name='tag_processor_version', column_type='INT') + mds.db.commit() mds.set_value(key='db_version', 
value=version.next) def do_upgrade_pony_db_11to12(self, mds): @@ -276,9 +276,9 @@ def do_upgrade_pony_db_11to12(self, mds): for column_name, datatype in new_columns: # pylint: disable=protected-access - if not self.column_exists_in_table(mds._db, table_name, column_name): + if not self.column_exists_in_table(mds.db, table_name, column_name): sql = f'ALTER TABLE {table_name} ADD {column_name} {datatype};' - mds._db.execute(sql) + mds.db.execute(sql) db_version = mds.MiscData.get(name="db_version") db_version.value = str(to_version) @@ -298,9 +298,9 @@ def do_upgrade_pony_db_10to11(self, mds): column_name = "self_checked" # pylint: disable=protected-access - if not self.column_exists_in_table(mds._db, table_name, column_name): + if not self.column_exists_in_table(mds.db, table_name, column_name): sql = f'ALTER TABLE {table_name} ADD {column_name} BOOLEAN default 0;' - mds._db.execute(sql) + mds.db.execute(sql) db_version = mds.MiscData.get(name="db_version") db_version.value = str(to_version) diff --git a/src/tribler/core/utilities/pony_utils.py b/src/tribler/core/utilities/pony_utils.py index 19238a59447..4e589049dd6 100644 --- a/src/tribler/core/utilities/pony_utils.py +++ b/src/tribler/core/utilities/pony_utils.py @@ -1,6 +1,8 @@ -from typing import Type +import threading +from asyncio import get_event_loop +from typing import Callable, Type -from pony.orm.core import Entity, select +from pony.orm.core import Database, Entity, select # pylint: disable=bad-staticmethod-argument @@ -29,3 +31,28 @@ def get_max(cls: Type[Entity], column_name='rowid') -> int: Returns: Max row ID or 0. """ return select(max(getattr(obj, column_name)) for obj in cls).get() or 0 + + +async def run_threaded(db: Database, func: Callable, *args, **kwargs): + """ Run `func` threaded and close DB connection at the end of the execution. + + Args: + db: the DB to be closed + func: the function to be executed threaded + *args: args for the function call + **kwargs: kwargs for the function call + + Returns: a result of the func call. + """ + + def wrapper(): + try: + return func(*args, **kwargs) + finally: + # @ichorid: this is a workaround for closing threadpool connections + # Remark: maybe subclass ThreadPoolExecutor to handle this automatically? 
+ is_main_thread = isinstance(threading.current_thread(), threading._MainThread) # pylint: disable=W0212 + if not is_main_thread: + db.disconnect() + + return await get_event_loop().run_in_executor(None, wrapper) From 6ad892a3d12fff4ff9130a565a1770542dbe5027 Mon Sep 17 00:00:00 2001 From: drew2a Date: Mon, 7 Nov 2022 14:03:15 +0100 Subject: [PATCH 2/4] Remove # pylint: disable=protected-access --- scripts/seedbox/disseminator.py | 2 +- .../metadata_store/tests/test_channel_metadata.py | 8 +++----- src/tribler/core/upgrade/tests/test_upgrader.py | 6 ++---- 3 files changed, 6 insertions(+), 10 deletions(-) diff --git a/scripts/seedbox/disseminator.py b/scripts/seedbox/disseminator.py index 5d51ed76468..76a896a940d 100644 --- a/scripts/seedbox/disseminator.py +++ b/scripts/seedbox/disseminator.py @@ -159,7 +159,7 @@ def commit(self): def flush(self): _logger.debug('Flush') - self.community.mds.db.flush() # pylint: disable=protected-access + self.community.mds.db.flush() class Service(TinyTriblerService): diff --git a/src/tribler/core/components/metadata_store/tests/test_channel_metadata.py b/src/tribler/core/components/metadata_store/tests/test_channel_metadata.py index fcb2426472b..375977c207e 100644 --- a/src/tribler/core/components/metadata_store/tests/test_channel_metadata.py +++ b/src/tribler/core/components/metadata_store/tests/test_channel_metadata.py @@ -5,14 +5,11 @@ from pathlib import Path from unittest.mock import Mock, patch +import pytest from ipv8.keyvault.crypto import default_eccrypto - from lz4.frame import LZ4FrameDecompressor - from pony.orm import ObjectNotFound, db_session -import pytest - from tribler.core.components.libtorrent.torrentdef import TorrentDef from tribler.core.components.metadata_store.db.orm_bindings.channel_metadata import ( CHANNEL_DIR_NAME_LENGTH, @@ -31,6 +28,7 @@ from tribler.core.utilities.simpledefs import CHANNEL_STATE from tribler.core.utilities.utilities import random_infohash + # pylint: disable=protected-access @@ -80,7 +78,7 @@ def mds_with_some_torrents_fixture(metadata_store): # torrent6 aaa zzz def save(): - metadata_store.db.flush() # pylint: disable=W0212 + metadata_store.db.flush() def new_channel(**kwargs): params = dict(subscribed=True, share=True, status=NEW, infohash=random_infohash()) diff --git a/src/tribler/core/upgrade/tests/test_upgrader.py b/src/tribler/core/upgrade/tests/test_upgrader.py index 692b2cd3243..809e9d4fd43 100644 --- a/src/tribler/core/upgrade/tests/test_upgrader.py +++ b/src/tribler/core/upgrade/tests/test_upgrader.py @@ -63,7 +63,7 @@ def test_upgrade_pony_db_complete(upgrader, channels_dir, state_dir, trustchain_ upgrader.run() mds = MetadataStore(mds_path, channels_dir, trustchain_keypair) - db = mds.db # pylint: disable=protected-access + db = mds.db existing_indexes = [ 'idx_channelnode__metadata_type__partial', @@ -136,7 +136,6 @@ def test_upgrade_pony_10to11(upgrader, channels_dir, mds_path, trustchain_keypai upgrader.upgrade_pony_db_10to11() mds = MetadataStore(mds_path, channels_dir, trustchain_keypair, check_tables=False, db_version=11) with db_session: - # pylint: disable=protected-access assert upgrader.column_exists_in_table(mds.db, 'TorrentState', 'self_checked') assert mds.get_value("db_version") == '11' mds.shutdown() @@ -148,7 +147,6 @@ def test_upgrade_pony11to12(upgrader, channels_dir, mds_path, trustchain_keypair upgrader.upgrade_pony_db_11to12() mds = MetadataStore(mds_path, channels_dir, trustchain_keypair, check_tables=False, db_version=11) with db_session: - # pylint: 
disable=protected-access assert upgrader.column_exists_in_table(mds.db, 'ChannelNode', 'json_text') assert upgrader.column_exists_in_table(mds.db, 'ChannelNode', 'binary_data') assert upgrader.column_exists_in_table(mds.db, 'ChannelNode', 'data_type') @@ -199,7 +197,7 @@ def test_upgrade_pony12to13(upgrader, channels_dir, mds_path, trustchain_keypair upgrader.upgrade_pony_db_12to13() mds = MetadataStore(mds_path, channels_dir, trustchain_keypair, check_tables=False, db_version=12) - db = mds.db # pylint: disable=protected-access + db = mds.db existing_indexes = [ 'idx_channelnode__metadata_type__partial', From 789012c349fb68206c54b48e547e91501e16ca6f Mon Sep 17 00:00:00 2001 From: drew2a Date: Mon, 7 Nov 2022 14:18:31 +0100 Subject: [PATCH 3/4] Remove `run_threaded` from `store.py` --- .../core/components/metadata_store/db/store.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/src/tribler/core/components/metadata_store/db/store.py b/src/tribler/core/components/metadata_store/db/store.py index daeb264c151..06440a3f9cc 100644 --- a/src/tribler/core/components/metadata_store/db/store.py +++ b/src/tribler/core/components/metadata_store/db/store.py @@ -1,13 +1,10 @@ import logging import re -import threading -from asyncio import get_event_loop from datetime import datetime, timedelta from time import sleep, time from typing import Optional, Union from lz4.frame import LZ4FrameDecompressor - from pony import orm from pony.orm import db_session, desc, left_join, raw_sql, select from pony.orm.dbproviders.sqlite import keep_exception @@ -55,7 +52,6 @@ from tribler.core.utilities.unicode import hexlify from tribler.core.utilities.utilities import MEMORY_DB - BETA_DB_VERSIONS = [0, 1, 2, 3, 4, 5] CURRENT_DB_VERSION = 14 @@ -245,15 +241,6 @@ def on_connect(_, connection): default_vsids = self.Vsids.create_default_vsids() self.ChannelMetadata.votes_scaling = default_vsids.max_val - async def run_threaded(self, func, *args, **kwargs): - def wrapper(): - try: - return func(*args, **kwargs) - finally: - self.disconnect_thread() - - return await get_event_loop().run_in_executor(None, wrapper) - def set_value(self, key: str, value: str): key_value = get_or_create(self.MiscData, name=key) key_value.value = value From 2d79443c55aa495f0e675da9b66ad3439487986d Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Mon, 7 Nov 2022 15:18:25 +0100 Subject: [PATCH 4/4] Add function description --- src/tribler/core/utilities/pony_utils.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/tribler/core/utilities/pony_utils.py b/src/tribler/core/utilities/pony_utils.py index 4e589049dd6..8317a330db1 100644 --- a/src/tribler/core/utilities/pony_utils.py +++ b/src/tribler/core/utilities/pony_utils.py @@ -34,7 +34,7 @@ def get_max(cls: Type[Entity], column_name='rowid') -> int: async def run_threaded(db: Database, func: Callable, *args, **kwargs): - """ Run `func` threaded and close DB connection at the end of the execution. + """Run `func` threaded and close DB connection at the end of the execution. Args: db: the DB to be closed @@ -43,6 +43,15 @@ async def run_threaded(db: Database, func: Callable, *args, **kwargs): **kwargs: kwargs for the function call Returns: a result of the func call. + + You should use `run_threaded` to wrap all functions that should be executed from a separate thread and work with + the database. The `run_threaded` function ensures that all database connections opened in worker threads are + properly closed before the Tribler shutdown. 
+ + The Asyncio `run_in_executor` method executes its argument in a separate worker thread. After the db_session is + over, PonyORM caches the connection to the database to re-use it again later in the same thread. It was previously + reported that some obscure problems could be observed during the Tribler shutdown if connections in the Tribler + worker threads are not closed properly. """ def wrapper():
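
A minimal usage sketch of the extracted helper (illustrative only, not part of the patch series): the `count_torrents` coroutine and its query are hypothetical, assuming nothing beyond the renamed `MetadataStore.db` attribute, the `TorrentMetadata` entity defined in `store.py`, and the `run_threaded(db, func, *args, **kwargs)` signature introduced above.

from pony.orm import db_session

from tribler.core.utilities.pony_utils import run_threaded


async def count_torrents(metadata_store) -> int:
    # Hypothetical caller: the wrapped callable opens its own db_session and runs in the
    # default executor; run_threaded closes the worker-thread connection once it returns,
    # matching the behaviour described in the docstring.
    @db_session
    def query() -> int:
        return metadata_store.TorrentMetadata.select().count()

    return await run_threaded(metadata_store.db, query)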