From 82b02e38004154e65beb2bf09b93b0ea8c22f833 Mon Sep 17 00:00:00 2001 From: pga2rn Date: Sun, 21 Jul 2024 18:03:21 +0900 Subject: [PATCH 1/8] orm.ThreadPoolBase: all APIs now return Future object instead --- src/simple_sqlite3_orm/_orm.py | 83 ++++++++++++++++++++++++---------- 1 file changed, 58 insertions(+), 25 deletions(-) diff --git a/src/simple_sqlite3_orm/_orm.py b/src/simple_sqlite3_orm/_orm.py index f37c1fa0..efdcc8d6 100644 --- a/src/simple_sqlite3_orm/_orm.py +++ b/src/simple_sqlite3_orm/_orm.py @@ -7,7 +7,7 @@ import sqlite3 import sys import threading -from concurrent.futures import ThreadPoolExecutor +from concurrent.futures import ThreadPoolExecutor, Future from functools import cached_property from typing import ( TYPE_CHECKING, @@ -27,7 +27,6 @@ from simple_sqlite3_orm._sqlite_spec import INSERT_OR, ORDER_DIRECTION from simple_sqlite3_orm._table_spec import TableSpec, TableSpecType -from simple_sqlite3_orm._typing import copy_callable_typehint _parameterized_orm_cache: WeakValueDictionary[ tuple[type[ORMBase], type[TableSpec]], type[ORMBase[Any]] @@ -411,11 +410,6 @@ class ORMThreadPoolBase(ORMBase[TableSpecType]): See https://www.sqlite.org/wal.html#concurrency for more details. """ - @property - def _con(self) -> sqlite3.Connection: - """Get thread-specific sqlite3 connection.""" - return self._thread_id_cons[threading.get_native_id()] - def __init__( self, table_name: str, @@ -441,6 +435,11 @@ def _thread_initializer(): thread_name_prefix=thread_name_prefix, ) + @property + def _con(self) -> sqlite3.Connection: + """Get thread-specific sqlite3 connection.""" + return self._thread_id_cons[threading.get_native_id()] + @property @deprecated("orm_con is not available in thread pool ORM") def orm_con(self): @@ -466,18 +465,44 @@ def orm_pool_shutdown(self, *, wait=True, close_connections=True) -> None: def orm_execute( self, sql_stmt: str, params: tuple[Any, ...] | dict[str, Any] | None = None - ) -> list[Any]: - return self._pool.submit(super().orm_execute, sql_stmt, params).result() + ) -> Future[list[Any]]: + return self._pool.submit(super().orm_execute, sql_stmt, params) orm_execute.__doc__ = ORMBase.orm_execute.__doc__ - @copy_callable_typehint(ORMBase.orm_create_table) - def orm_create_table(self, *args, **kwargs): - self._pool.submit(super().orm_create_table, *args, **kwargs).result() + def orm_create_table( + self, + *, + allow_existed: bool = False, + strict: bool = False, + without_rowid: bool = False, + ) -> Future[None]: + return self._pool.submit( + super().orm_create_table, + allow_existed=allow_existed, + strict=strict, + without_rowid=without_rowid, + ) - @copy_callable_typehint(ORMBase.orm_create_index) - def orm_create_index(self, *args, **kwargs): - self._pool.submit(super().orm_create_index, *args, **kwargs).result() + orm_create_table.__doc__ = ORMBase.orm_create_table.__doc__ + + def orm_create_index( + self, + *, + index_name: str, + index_keys: tuple[str, ...], + allow_existed: bool = False, + unique: bool = False, + ) -> Future[None]: + return self._pool.submit( + super().orm_create_index, + index_name=index_name, + index_keys=index_keys, + allow_existed=allow_existed, + unique=unique, + ) + + orm_create_index.__doc__ = ORMBase.orm_create_index.__doc__ def orm_select_entries_gen( self, @@ -528,7 +553,7 @@ def orm_select_entries( _order_by: tuple[str | tuple[str, Literal["ASC", "DESC"]], ...] | None = None, _limit: int | None = None, **col_values: Any, - ) -> list[TableSpecType]: + ) -> Future[list[TableSpecType]]: """Select multiple entries and return all the entries in a list.""" def _inner(): @@ -542,15 +567,21 @@ def _inner(): ) ) - return self._pool.submit(_inner).result() + return self._pool.submit(_inner) - @copy_callable_typehint(ORMBase.orm_insert_entries) - def orm_insert_entries(self, *args, **kwargs): - return self._pool.submit(super().orm_insert_entries, *args, **kwargs).result() + def orm_insert_entries( + self, _in: Iterable[TableSpecType], *, or_option: INSERT_OR | None = None + ) -> Future[int]: + return self._pool.submit(super().orm_insert_entries, _in, or_option=or_option) - @copy_callable_typehint(ORMBase.orm_insert_entry) - def orm_insert_entry(self, *args, **kwargs): - return self._pool.submit(super().orm_insert_entry, *args, **kwargs).result() + orm_insert_entries.__doc__ = ORMBase.orm_insert_entries.__doc__ + + def orm_insert_entry( + self, _in: TableSpecType, *, or_option: INSERT_OR | None = None + ) -> Future[int]: + return self._pool.submit(super().orm_insert_entry, _in, or_option=or_option) + + orm_insert_entry.__doc__ = ORMBase.orm_insert_entry.__doc__ def orm_delete_entries( self, @@ -559,7 +590,7 @@ def orm_delete_entries( _limit: int | None = None, _returning_cols: tuple[str, ...] | None | Literal["*"] = None, **cols_value: Any, - ) -> int | list[TableSpecType]: + ) -> Future[int | list[TableSpecType]]: # NOTE(20240708): currently we don't support generator for delete with RETURNING statement def _inner(): res = ORMBase.orm_delete_entries( @@ -574,7 +605,9 @@ def _inner(): return res return list(res) - return self._pool.submit(_inner).result() + return self._pool.submit(_inner) + + orm_delete_entries.__doc__ = ORMBase.orm_delete_entries.__doc__ class AsyncORMThreadPoolBase(ORMThreadPoolBase[TableSpecType]): From a1b73e3636b5e1db7077425b946d25e879f3b941 Mon Sep 17 00:00:00 2001 From: pga2rn Date: Sun, 21 Jul 2024 18:03:32 +0900 Subject: [PATCH 2/8] remove unused --- src/simple_sqlite3_orm/_typing.py | 23 ----------------------- 1 file changed, 23 deletions(-) delete mode 100644 src/simple_sqlite3_orm/_typing.py diff --git a/src/simple_sqlite3_orm/_typing.py b/src/simple_sqlite3_orm/_typing.py deleted file mode 100644 index c35f82b9..00000000 --- a/src/simple_sqlite3_orm/_typing.py +++ /dev/null @@ -1,23 +0,0 @@ -from __future__ import annotations - -from typing import Any, Callable - -from typing_extensions import ParamSpec - -P = ParamSpec("P") - - -def copy_callable_typehint(src: Callable[P, Any]): - """This helper function return a decorator that can type hint the target - function as the _source function. - - At runtime, this decorator actually does nothing, but just return the input function as it. - But the returned function will have the same type hint as the source function in ide. - It will not impact the runtime behavior of the decorated function. - """ - - def _decorator(target) -> Callable[P, Any]: - target.__doc__ = src.__doc__ - return target - - return _decorator From 0a8e13fd8046b8611f82ea39da10928e3755d418 Mon Sep 17 00:00:00 2001 From: pga2rn Date: Sun, 21 Jul 2024 18:07:36 +0900 Subject: [PATCH 3/8] test_threadpool: update accordingly --- tests/test_e2e/test_threadpool.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/tests/test_e2e/test_threadpool.py b/tests/test_e2e/test_threadpool.py index e48f014a..db0a0e35 100644 --- a/tests/test_e2e/test_threadpool.py +++ b/tests/test_e2e/test_threadpool.py @@ -35,11 +35,11 @@ def thread_pool(self, setup_con_factory: Callable[[], sqlite3.Connection]): ) yield pool finally: - pool.orm_pool_shutdown() + pool.orm_pool_shutdown(wait=True, close_connections=True) def test_create_table(self, thread_pool: SampleDBConnectionPool): logger.info("test create table") - thread_pool.orm_create_table(without_rowid=True) + thread_pool.orm_create_table(without_rowid=True).result() def test_insert_entries_with_pool( self, @@ -74,7 +74,7 @@ def test_create_index(self, thread_pool: SampleDBConnectionPool): index_name=INDEX_NAME, index_keys=INDEX_KEYS, unique=True, - ) + ).result() def test_orm_execute( self, @@ -86,7 +86,7 @@ def test_orm_execute( select_from=thread_pool.orm_table_name, function="count", ) - res = thread_pool.orm_execute(sql_stmt) + res = thread_pool.orm_execute(sql_stmt).result() assert res and res[0][0] == len(setup_test_data) @@ -98,7 +98,7 @@ def test_lookup_entries( _looked_up = thread_pool.orm_select_entries( key_id=_entry.key_id, prim_key_sha256hash=_entry.prim_key_sha256hash, - ) + ).result() _looked_up = list(_looked_up) assert len(_looked_up) == 1 assert _looked_up[0] == _entry @@ -120,7 +120,7 @@ def test_delete_entries( _res = thread_pool.orm_delete_entries( key_id=entry.key_id, prim_key_sha256hash=entry.prim_key_sha256hash, - ) + ).result() assert _res == 1 else: for entry in entries_to_remove: @@ -128,11 +128,8 @@ def test_delete_entries( _returning_cols="*", key_id=entry.key_id, prim_key_sha256hash=entry.prim_key_sha256hash, - ) + ).result() assert isinstance(_res, list) assert len(_res) == 1 assert _res[0] == entry - - def test_shutdown_pool(self, thread_pool: SampleDBConnectionPool): - thread_pool.orm_pool_shutdown(wait=True, close_connections=True) From 3ad09f918c6c8e50fa4ff86d8895ede8462a4ff3 Mon Sep 17 00:00:00 2001 From: pga2rn Date: Sun, 21 Jul 2024 18:07:54 +0900 Subject: [PATCH 4/8] test_asyncio: remove duplicated test case --- tests/test_e2e/test_asyncio.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tests/test_e2e/test_asyncio.py b/tests/test_e2e/test_asyncio.py index 334e73b9..4ff119d7 100644 --- a/tests/test_e2e/test_asyncio.py +++ b/tests/test_e2e/test_asyncio.py @@ -42,7 +42,7 @@ async def async_pool( ) yield pool finally: - pool.orm_pool_shutdown() + pool.orm_pool_shutdown(wait=True, close_connections=True) @pytest_asyncio.fixture(autouse=True, scope="class") async def start_timer(self) -> tuple[asyncio.Task[None], asyncio.Event]: @@ -163,6 +163,3 @@ async def test_check_timer( _task, _event = start_timer _event.set() await _task - - async def test_shutdown_pool(self, async_pool: SampleDBAsyncio): - async_pool.orm_pool_shutdown(wait=True, close_connections=True) From bc084249a46ea4bb39218790c7cae0d23db0a3b7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 21 Jul 2024 09:09:51 +0000 Subject: [PATCH 5/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/simple_sqlite3_orm/_orm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/simple_sqlite3_orm/_orm.py b/src/simple_sqlite3_orm/_orm.py index efdcc8d6..8b6c2382 100644 --- a/src/simple_sqlite3_orm/_orm.py +++ b/src/simple_sqlite3_orm/_orm.py @@ -7,7 +7,7 @@ import sqlite3 import sys import threading -from concurrent.futures import ThreadPoolExecutor, Future +from concurrent.futures import Future, ThreadPoolExecutor from functools import cached_property from typing import ( TYPE_CHECKING, From 6ec2fdec1f19a78755e85cc5742a3d892588f827 Mon Sep 17 00:00:00 2001 From: pga2rn Date: Sun, 21 Jul 2024 18:15:40 +0900 Subject: [PATCH 6/8] minor fix --- src/simple_sqlite3_orm/_orm.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/simple_sqlite3_orm/_orm.py b/src/simple_sqlite3_orm/_orm.py index 8b6c2382..7d3d7a96 100644 --- a/src/simple_sqlite3_orm/_orm.py +++ b/src/simple_sqlite3_orm/_orm.py @@ -766,12 +766,14 @@ async def orm_create_table( self, *, allow_existed: bool = False, + strict: bool = False, without_rowid: bool = False, ) -> None: return await self._run_in_pool( ORMBase.orm_create_table, self, allow_existed=allow_existed, + strict=strict, without_rowid=without_rowid, ) From cddcb85f782bbf2bb270bc4d5b05dc6eb98d1b06 Mon Sep 17 00:00:00 2001 From: pga2rn Date: Sun, 21 Jul 2024 18:25:56 +0900 Subject: [PATCH 7/8] minor fix to test threadpool --- tests/test_e2e/test_threadpool.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/tests/test_e2e/test_threadpool.py b/tests/test_e2e/test_threadpool.py index db0a0e35..350f76e5 100644 --- a/tests/test_e2e/test_threadpool.py +++ b/tests/test_e2e/test_threadpool.py @@ -2,8 +2,9 @@ import logging import sqlite3 -from concurrent.futures import ThreadPoolExecutor -from typing import Callable +from concurrent.futures import ThreadPoolExecutor, Future +from functools import partial +from typing import Any, Callable import pytest @@ -23,6 +24,10 @@ class SampleDBConnectionPool(ORMThreadPoolBase[SampleTable]): WORKER_NUM = 6 +def _get_result(func: Callable[..., Future[Any]]): + return func().result() + + class TestWithSampleDBAndThreadPool: @pytest.fixture(autouse=True, scope="class") @@ -54,7 +59,7 @@ def test_insert_entries_with_pool( batched(setup_test_data.values(), TEST_INSERT_BATCH_SIZE), start=1, ): - pool.submit(thread_pool.orm_insert_entries, entry) + pool.submit(_get_result, partial(thread_pool.orm_insert_entries, entry)) logger.info( f"all insert tasks are dispatched: {_batch_count} batches with {TEST_INSERT_BATCH_SIZE=}" From 877339e55653d18ba5d5aaccc20f165875cedcd0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 21 Jul 2024 09:26:09 +0000 Subject: [PATCH 8/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_e2e/test_threadpool.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_e2e/test_threadpool.py b/tests/test_e2e/test_threadpool.py index 350f76e5..4fc9cac6 100644 --- a/tests/test_e2e/test_threadpool.py +++ b/tests/test_e2e/test_threadpool.py @@ -2,7 +2,7 @@ import logging import sqlite3 -from concurrent.futures import ThreadPoolExecutor, Future +from concurrent.futures import Future, ThreadPoolExecutor from functools import partial from typing import Any, Callable