Skip to content

Commit

Permalink
feature(ORM): update ORMThreadPoolBase's API (#11)
Browse files Browse the repository at this point in the history
This PR updates the ORMThreadPoolBase's API: all of the SQL query APIs now return a concurrent.futures.Future object instead of blocking for the result.
  • Loading branch information
pga2rn authored Jul 21, 2024
1 parent 3f40db8 commit 1db0d93
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 65 deletions.
85 changes: 60 additions & 25 deletions src/simple_sqlite3_orm/_orm.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import sqlite3
import sys
import threading
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import Future, ThreadPoolExecutor
from functools import cached_property
from typing import (
TYPE_CHECKING,
Expand All @@ -27,7 +27,6 @@

from simple_sqlite3_orm._sqlite_spec import INSERT_OR, ORDER_DIRECTION
from simple_sqlite3_orm._table_spec import TableSpec, TableSpecType
from simple_sqlite3_orm._typing import copy_callable_typehint

_parameterized_orm_cache: WeakValueDictionary[
tuple[type[ORMBase], type[TableSpec]], type[ORMBase[Any]]
Expand Down Expand Up @@ -411,11 +410,6 @@ class ORMThreadPoolBase(ORMBase[TableSpecType]):
See https://www.sqlite.org/wal.html#concurrency for more details.
"""

@property
def _con(self) -> sqlite3.Connection:
"""Get thread-specific sqlite3 connection."""
return self._thread_id_cons[threading.get_native_id()]

def __init__(
self,
table_name: str,
Expand All @@ -441,6 +435,11 @@ def _thread_initializer():
thread_name_prefix=thread_name_prefix,
)

@property
def _con(self) -> sqlite3.Connection:
    """Get thread-specific sqlite3 connection.

    Looks up the connection keyed by the calling thread's native id;
    worker threads are expected to have registered a connection at pool
    startup. NOTE(review): raises KeyError if accessed from a thread the
    pool did not initialize — presumably only pool workers call this.
    """
    return self._thread_id_cons[threading.get_native_id()]

@property
@deprecated("orm_con is not available in thread pool ORM")
def orm_con(self):
Expand All @@ -466,18 +465,44 @@ def orm_pool_shutdown(self, *, wait=True, close_connections=True) -> None:

def orm_execute(
self, sql_stmt: str, params: tuple[Any, ...] | dict[str, Any] | None = None
) -> list[Any]:
return self._pool.submit(super().orm_execute, sql_stmt, params).result()
) -> Future[list[Any]]:
return self._pool.submit(super().orm_execute, sql_stmt, params)

orm_execute.__doc__ = ORMBase.orm_execute.__doc__

@copy_callable_typehint(ORMBase.orm_create_table)
def orm_create_table(self, *args, **kwargs):
self._pool.submit(super().orm_create_table, *args, **kwargs).result()
def orm_create_table(
    self,
    *,
    allow_existed: bool = False,
    strict: bool = False,
    without_rowid: bool = False,
) -> Future[None]:
    # Dispatch the blocking ORMBase.orm_create_table onto a pool worker
    # thread; the caller receives a Future to wait on instead of blocking.
    # All options are forwarded by keyword, unchanged.
    return self._pool.submit(
        super().orm_create_table,
        allow_existed=allow_existed,
        strict=strict,
        without_rowid=without_rowid,
    )

# Reuse the base-class docstring so help() stays accurate for this
# Future-returning wrapper.
orm_create_table.__doc__ = ORMBase.orm_create_table.__doc__

def orm_create_index(
    self,
    *,
    index_name: str,
    index_keys: tuple[str, ...],
    allow_existed: bool = False,
    unique: bool = False,
) -> Future[None]:
    # Submit the blocking ORMBase.orm_create_index to the worker pool and
    # return the Future; parameters are forwarded by keyword unchanged.
    return self._pool.submit(
        super().orm_create_index,
        index_name=index_name,
        index_keys=index_keys,
        allow_existed=allow_existed,
        unique=unique,
    )

@copy_callable_typehint(ORMBase.orm_create_index)
def orm_create_index(self, *args, **kwargs):
self._pool.submit(super().orm_create_index, *args, **kwargs).result()
orm_create_index.__doc__ = ORMBase.orm_create_index.__doc__

def orm_select_entries_gen(
self,
Expand Down Expand Up @@ -528,7 +553,7 @@ def orm_select_entries(
_order_by: tuple[str | tuple[str, Literal["ASC", "DESC"]], ...] | None = None,
_limit: int | None = None,
**col_values: Any,
) -> list[TableSpecType]:
) -> Future[list[TableSpecType]]:
"""Select multiple entries and return all the entries in a list."""

def _inner():
Expand All @@ -542,15 +567,21 @@ def _inner():
)
)

return self._pool.submit(_inner).result()
return self._pool.submit(_inner)

def orm_insert_entries(
    self, _in: Iterable[TableSpecType], *, or_option: INSERT_OR | None = None
) -> Future[int]:
    # Run the blocking base-class bulk insert on a pool worker; resolving
    # the Future yields an int — presumably the affected-row count (see
    # ORMBase.orm_insert_entries for the exact contract).
    return self._pool.submit(super().orm_insert_entries, _in, or_option=or_option)

# Propagate the base-class docstring onto the Future-returning wrapper.
orm_insert_entries.__doc__ = ORMBase.orm_insert_entries.__doc__

@copy_callable_typehint(ORMBase.orm_insert_entries)
def orm_insert_entries(self, *args, **kwargs):
return self._pool.submit(super().orm_insert_entries, *args, **kwargs).result()
def orm_insert_entry(
    self, _in: TableSpecType, *, or_option: INSERT_OR | None = None
) -> Future[int]:
    # Same pattern as orm_insert_entries: delegate the blocking
    # single-entry insert to a pool worker and hand back a Future[int].
    return self._pool.submit(super().orm_insert_entry, _in, or_option=or_option)

@copy_callable_typehint(ORMBase.orm_insert_entry)
def orm_insert_entry(self, *args, **kwargs):
return self._pool.submit(super().orm_insert_entry, *args, **kwargs).result()
orm_insert_entry.__doc__ = ORMBase.orm_insert_entry.__doc__

def orm_delete_entries(
self,
Expand All @@ -559,7 +590,7 @@ def orm_delete_entries(
_limit: int | None = None,
_returning_cols: tuple[str, ...] | None | Literal["*"] = None,
**cols_value: Any,
) -> int | list[TableSpecType]:
) -> Future[int | list[TableSpecType]]:
# NOTE(20240708): currently we don't support generator for delete with RETURNING statement
def _inner():
res = ORMBase.orm_delete_entries(
Expand All @@ -574,7 +605,9 @@ def _inner():
return res
return list(res)

return self._pool.submit(_inner).result()
return self._pool.submit(_inner)

orm_delete_entries.__doc__ = ORMBase.orm_delete_entries.__doc__


class AsyncORMThreadPoolBase(ORMThreadPoolBase[TableSpecType]):
Expand Down Expand Up @@ -733,12 +766,14 @@ async def orm_create_table(
self,
*,
allow_existed: bool = False,
strict: bool = False,
without_rowid: bool = False,
) -> None:
return await self._run_in_pool(
ORMBase.orm_create_table,
self,
allow_existed=allow_existed,
strict=strict,
without_rowid=without_rowid,
)

Expand Down
23 changes: 0 additions & 23 deletions src/simple_sqlite3_orm/_typing.py

This file was deleted.

5 changes: 1 addition & 4 deletions tests/test_e2e/test_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ async def async_pool(
)
yield pool
finally:
pool.orm_pool_shutdown()
pool.orm_pool_shutdown(wait=True, close_connections=True)

@pytest_asyncio.fixture(autouse=True, scope="class")
async def start_timer(self) -> tuple[asyncio.Task[None], asyncio.Event]:
Expand Down Expand Up @@ -163,6 +163,3 @@ async def test_check_timer(
_task, _event = start_timer
_event.set()
await _task

async def test_shutdown_pool(self, async_pool: SampleDBAsyncio):
async_pool.orm_pool_shutdown(wait=True, close_connections=True)
28 changes: 15 additions & 13 deletions tests/test_e2e/test_threadpool.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@

import logging
import sqlite3
from concurrent.futures import ThreadPoolExecutor
from typing import Callable
from concurrent.futures import Future, ThreadPoolExecutor
from functools import partial
from typing import Any, Callable

import pytest

Expand All @@ -23,6 +24,10 @@ class SampleDBConnectionPool(ORMThreadPoolBase[SampleTable]):
WORKER_NUM = 6


def _get_result(func: Callable[..., Future[Any]]):
return func().result()


class TestWithSampleDBAndThreadPool:

@pytest.fixture(autouse=True, scope="class")
Expand All @@ -35,11 +40,11 @@ def thread_pool(self, setup_con_factory: Callable[[], sqlite3.Connection]):
)
yield pool
finally:
pool.orm_pool_shutdown()
pool.orm_pool_shutdown(wait=True, close_connections=True)

def test_create_table(self, thread_pool: SampleDBConnectionPool):
logger.info("test create table")
thread_pool.orm_create_table(without_rowid=True)
thread_pool.orm_create_table(without_rowid=True).result()

def test_insert_entries_with_pool(
self,
Expand All @@ -54,7 +59,7 @@ def test_insert_entries_with_pool(
batched(setup_test_data.values(), TEST_INSERT_BATCH_SIZE),
start=1,
):
pool.submit(thread_pool.orm_insert_entries, entry)
pool.submit(_get_result, partial(thread_pool.orm_insert_entries, entry))

logger.info(
f"all insert tasks are dispatched: {_batch_count} batches with {TEST_INSERT_BATCH_SIZE=}"
Expand All @@ -74,7 +79,7 @@ def test_create_index(self, thread_pool: SampleDBConnectionPool):
index_name=INDEX_NAME,
index_keys=INDEX_KEYS,
unique=True,
)
).result()

def test_orm_execute(
self,
Expand All @@ -86,7 +91,7 @@ def test_orm_execute(
select_from=thread_pool.orm_table_name,
function="count",
)
res = thread_pool.orm_execute(sql_stmt)
res = thread_pool.orm_execute(sql_stmt).result()

assert res and res[0][0] == len(setup_test_data)

Expand All @@ -98,7 +103,7 @@ def test_lookup_entries(
_looked_up = thread_pool.orm_select_entries(
key_id=_entry.key_id,
prim_key_sha256hash=_entry.prim_key_sha256hash,
)
).result()
_looked_up = list(_looked_up)
assert len(_looked_up) == 1
assert _looked_up[0] == _entry
Expand All @@ -120,19 +125,16 @@ def test_delete_entries(
_res = thread_pool.orm_delete_entries(
key_id=entry.key_id,
prim_key_sha256hash=entry.prim_key_sha256hash,
)
).result()
assert _res == 1
else:
for entry in entries_to_remove:
_res = thread_pool.orm_delete_entries(
_returning_cols="*",
key_id=entry.key_id,
prim_key_sha256hash=entry.prim_key_sha256hash,
)
).result()
assert isinstance(_res, list)

assert len(_res) == 1
assert _res[0] == entry

def test_shutdown_pool(self, thread_pool: SampleDBConnectionPool):
thread_pool.orm_pool_shutdown(wait=True, close_connections=True)

1 comment on commit 1db0d93

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Coverage

Coverage Report
FileStmtsMissCoverMissing
src/simple_sqlite3_orm
   __init__.py70100% 
   _orm.py2322688%43, 45–46, 49–50, 148, 402, 447, 529, 531–532, 541–542, 544, 582, 605, 648, 679, 681–682, 691–692, 694, 729, 758, 813
   _sqlite_spec.py360100% 
   _table_spec.py1712187%45, 60, 69, 80, 84–85, 97, 109, 126, 187, 192–194, 270, 272, 274–275, 333–334, 443–444
   _types.py19194%27
   _utils.py51590%52, 65, 68, 78, 91
   utils.py971782%112, 119–120, 155–158, 204, 217, 237–241, 267, 271, 312
TOTAL6137088% 

Tests Skipped Failures Errors Time
60 0 💤 0 ❌ 0 🔥 2m 0s ⏱️

Please sign in to comment.