Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature(ORM): update ORMThreadPoolBase's API #11

Merged
merged 9 commits into from
Jul 21, 2024
85 changes: 60 additions & 25 deletions src/simple_sqlite3_orm/_orm.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import sqlite3
import sys
import threading
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import Future, ThreadPoolExecutor
from functools import cached_property
from typing import (
TYPE_CHECKING,
Expand All @@ -27,7 +27,6 @@

from simple_sqlite3_orm._sqlite_spec import INSERT_OR, ORDER_DIRECTION
from simple_sqlite3_orm._table_spec import TableSpec, TableSpecType
from simple_sqlite3_orm._typing import copy_callable_typehint

_parameterized_orm_cache: WeakValueDictionary[
tuple[type[ORMBase], type[TableSpec]], type[ORMBase[Any]]
Expand Down Expand Up @@ -411,11 +410,6 @@ class ORMThreadPoolBase(ORMBase[TableSpecType]):
See https://www.sqlite.org/wal.html#concurrency for more details.
"""

@property
def _con(self) -> sqlite3.Connection:
"""Get thread-specific sqlite3 connection."""
return self._thread_id_cons[threading.get_native_id()]

def __init__(
self,
table_name: str,
Expand All @@ -441,6 +435,11 @@ def _thread_initializer():
thread_name_prefix=thread_name_prefix,
)

@property
def _con(self) -> sqlite3.Connection:
"""Get thread-specific sqlite3 connection."""
return self._thread_id_cons[threading.get_native_id()]

@property
@deprecated("orm_con is not available in thread pool ORM")
def orm_con(self):
Expand All @@ -466,18 +465,44 @@ def orm_pool_shutdown(self, *, wait=True, close_connections=True) -> None:

def orm_execute(
self, sql_stmt: str, params: tuple[Any, ...] | dict[str, Any] | None = None
) -> list[Any]:
return self._pool.submit(super().orm_execute, sql_stmt, params).result()
) -> Future[list[Any]]:
return self._pool.submit(super().orm_execute, sql_stmt, params)

orm_execute.__doc__ = ORMBase.orm_execute.__doc__

@copy_callable_typehint(ORMBase.orm_create_table)
def orm_create_table(self, *args, **kwargs):
self._pool.submit(super().orm_create_table, *args, **kwargs).result()
def orm_create_table(
self,
*,
allow_existed: bool = False,
strict: bool = False,
without_rowid: bool = False,
) -> Future[None]:
return self._pool.submit(
super().orm_create_table,
allow_existed=allow_existed,
strict=strict,
without_rowid=without_rowid,
)

orm_create_table.__doc__ = ORMBase.orm_create_table.__doc__

def orm_create_index(
self,
*,
index_name: str,
index_keys: tuple[str, ...],
allow_existed: bool = False,
unique: bool = False,
) -> Future[None]:
return self._pool.submit(
super().orm_create_index,
index_name=index_name,
index_keys=index_keys,
allow_existed=allow_existed,
unique=unique,
)

@copy_callable_typehint(ORMBase.orm_create_index)
def orm_create_index(self, *args, **kwargs):
self._pool.submit(super().orm_create_index, *args, **kwargs).result()
orm_create_index.__doc__ = ORMBase.orm_create_index.__doc__

def orm_select_entries_gen(
self,
Expand Down Expand Up @@ -528,7 +553,7 @@ def orm_select_entries(
_order_by: tuple[str | tuple[str, Literal["ASC", "DESC"]], ...] | None = None,
_limit: int | None = None,
**col_values: Any,
) -> list[TableSpecType]:
) -> Future[list[TableSpecType]]:
"""Select multiple entries and return all the entries in a list."""

def _inner():
Expand All @@ -542,15 +567,21 @@ def _inner():
)
)

return self._pool.submit(_inner).result()
return self._pool.submit(_inner)

def orm_insert_entries(
self, _in: Iterable[TableSpecType], *, or_option: INSERT_OR | None = None
) -> Future[int]:
return self._pool.submit(super().orm_insert_entries, _in, or_option=or_option)

orm_insert_entries.__doc__ = ORMBase.orm_insert_entries.__doc__

@copy_callable_typehint(ORMBase.orm_insert_entries)
def orm_insert_entries(self, *args, **kwargs):
return self._pool.submit(super().orm_insert_entries, *args, **kwargs).result()
def orm_insert_entry(
self, _in: TableSpecType, *, or_option: INSERT_OR | None = None
) -> Future[int]:
return self._pool.submit(super().orm_insert_entry, _in, or_option=or_option)

@copy_callable_typehint(ORMBase.orm_insert_entry)
def orm_insert_entry(self, *args, **kwargs):
return self._pool.submit(super().orm_insert_entry, *args, **kwargs).result()
orm_insert_entry.__doc__ = ORMBase.orm_insert_entry.__doc__

def orm_delete_entries(
self,
Expand All @@ -559,7 +590,7 @@ def orm_delete_entries(
_limit: int | None = None,
_returning_cols: tuple[str, ...] | None | Literal["*"] = None,
**cols_value: Any,
) -> int | list[TableSpecType]:
) -> Future[int | list[TableSpecType]]:
# NOTE(20240708): currently we don't support generator for delete with RETURNING statement
def _inner():
res = ORMBase.orm_delete_entries(
Expand All @@ -574,7 +605,9 @@ def _inner():
return res
return list(res)

return self._pool.submit(_inner).result()
return self._pool.submit(_inner)

orm_delete_entries.__doc__ = ORMBase.orm_delete_entries.__doc__


class AsyncORMThreadPoolBase(ORMThreadPoolBase[TableSpecType]):
Expand Down Expand Up @@ -733,12 +766,14 @@ async def orm_create_table(
self,
*,
allow_existed: bool = False,
strict: bool = False,
without_rowid: bool = False,
) -> None:
return await self._run_in_pool(
ORMBase.orm_create_table,
self,
allow_existed=allow_existed,
strict=strict,
without_rowid=without_rowid,
)

Expand Down
23 changes: 0 additions & 23 deletions src/simple_sqlite3_orm/_typing.py

This file was deleted.

5 changes: 1 addition & 4 deletions tests/test_e2e/test_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ async def async_pool(
)
yield pool
finally:
pool.orm_pool_shutdown()
pool.orm_pool_shutdown(wait=True, close_connections=True)

@pytest_asyncio.fixture(autouse=True, scope="class")
async def start_timer(self) -> tuple[asyncio.Task[None], asyncio.Event]:
Expand Down Expand Up @@ -163,6 +163,3 @@ async def test_check_timer(
_task, _event = start_timer
_event.set()
await _task

async def test_shutdown_pool(self, async_pool: SampleDBAsyncio):
async_pool.orm_pool_shutdown(wait=True, close_connections=True)
28 changes: 15 additions & 13 deletions tests/test_e2e/test_threadpool.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@

import logging
import sqlite3
from concurrent.futures import ThreadPoolExecutor
from typing import Callable
from concurrent.futures import Future, ThreadPoolExecutor
from functools import partial
from typing import Any, Callable

import pytest

Expand All @@ -23,6 +24,10 @@ class SampleDBConnectionPool(ORMThreadPoolBase[SampleTable]):
WORKER_NUM = 6


def _get_result(func: Callable[..., Future[Any]]):
return func().result()


class TestWithSampleDBAndThreadPool:

@pytest.fixture(autouse=True, scope="class")
Expand All @@ -35,11 +40,11 @@ def thread_pool(self, setup_con_factory: Callable[[], sqlite3.Connection]):
)
yield pool
finally:
pool.orm_pool_shutdown()
pool.orm_pool_shutdown(wait=True, close_connections=True)

def test_create_table(self, thread_pool: SampleDBConnectionPool):
logger.info("test create table")
thread_pool.orm_create_table(without_rowid=True)
thread_pool.orm_create_table(without_rowid=True).result()

def test_insert_entries_with_pool(
self,
Expand All @@ -54,7 +59,7 @@ def test_insert_entries_with_pool(
batched(setup_test_data.values(), TEST_INSERT_BATCH_SIZE),
start=1,
):
pool.submit(thread_pool.orm_insert_entries, entry)
pool.submit(_get_result, partial(thread_pool.orm_insert_entries, entry))

logger.info(
f"all insert tasks are dispatched: {_batch_count} batches with {TEST_INSERT_BATCH_SIZE=}"
Expand All @@ -74,7 +79,7 @@ def test_create_index(self, thread_pool: SampleDBConnectionPool):
index_name=INDEX_NAME,
index_keys=INDEX_KEYS,
unique=True,
)
).result()

def test_orm_execute(
self,
Expand All @@ -86,7 +91,7 @@ def test_orm_execute(
select_from=thread_pool.orm_table_name,
function="count",
)
res = thread_pool.orm_execute(sql_stmt)
res = thread_pool.orm_execute(sql_stmt).result()

assert res and res[0][0] == len(setup_test_data)

Expand All @@ -98,7 +103,7 @@ def test_lookup_entries(
_looked_up = thread_pool.orm_select_entries(
key_id=_entry.key_id,
prim_key_sha256hash=_entry.prim_key_sha256hash,
)
).result()
_looked_up = list(_looked_up)
assert len(_looked_up) == 1
assert _looked_up[0] == _entry
Expand All @@ -120,19 +125,16 @@ def test_delete_entries(
_res = thread_pool.orm_delete_entries(
key_id=entry.key_id,
prim_key_sha256hash=entry.prim_key_sha256hash,
)
).result()
assert _res == 1
else:
for entry in entries_to_remove:
_res = thread_pool.orm_delete_entries(
_returning_cols="*",
key_id=entry.key_id,
prim_key_sha256hash=entry.prim_key_sha256hash,
)
).result()
assert isinstance(_res, list)

assert len(_res) == 1
assert _res[0] == entry

def test_shutdown_pool(self, thread_pool: SampleDBConnectionPool):
thread_pool.orm_pool_shutdown(wait=True, close_connections=True)