From fecc24b5a7dbf369887ef098861851befc46b07c Mon Sep 17 00:00:00 2001 From: Daniel Townsend Date: Sat, 23 Dec 2023 19:23:49 +0000 Subject: [PATCH] improve `atomic` --- piccolo/engine/base.py | 28 ++++++++++++++++++++++++---- piccolo/engine/postgres.py | 3 ++- piccolo/engine/sqlite.py | 3 ++- 3 files changed, 28 insertions(+), 6 deletions(-) diff --git a/piccolo/engine/base.py b/piccolo/engine/base.py index ca46c7c26..4681df86e 100644 --- a/piccolo/engine/base.py +++ b/piccolo/engine/base.py @@ -15,7 +15,7 @@ from piccolo.utils.warnings import Level, colored_string, colored_warning if t.TYPE_CHECKING: # pragma: no cover - from piccolo.query.base import Query + from piccolo.query.base import DDL, Query logger = logging.getLogger(__name__) @@ -35,7 +35,7 @@ def validate_savepoint_name(savepoint_name: str) -> None: ) -class Batch: +class Batch(metaclass=ABCMeta): @abc.abstractmethod async def __aenter__(self, *args, **kwargs): ... @@ -53,14 +53,34 @@ async def __anext__(self) -> t.List[t.Dict]: ... -class BaseTransaction: +class BaseTransaction(metaclass=ABCMeta): + @abc.abstractmethod async def __aenter__(self, *args, **kwargs): ... + @abc.abstractmethod async def __aexit__(self, *args, **kwargs): ... +class BaseAtomic(metaclass=ABCMeta): + @abc.abstractmethod + def add(self, *query: t.Union[Query, DDL]): + ... + + @abc.abstractmethod + async def run(self): + ... + + @abc.abstractmethod + def run_sync(self): + ... + + @abc.abstractmethod + def __await__(self): + ... + + TransactionClass = t.TypeVar("TransactionClass", bound=BaseTransaction) @@ -125,7 +145,7 @@ def transaction(self) -> TransactionClass: pass @abstractmethod - def atomic(self): + def atomic(self) -> BaseAtomic: pass async def check_version(self): diff --git a/piccolo/engine/postgres.py b/piccolo/engine/postgres.py index 909eec3af..294b3a5a9 100644 --- a/piccolo/engine/postgres.py +++ b/piccolo/engine/postgres.py @@ -7,6 +7,7 @@ from typing_extensions import Self from piccolo.engine.base import ( + BaseAtomic, BaseTransaction, Batch, Engine, @@ -86,7 +87,7 @@ async def __aexit__(self, exception_type, exception, traceback): ############################################################################### -class Atomic: +class Atomic(BaseAtomic): """ This is useful if you want to build up a transaction programatically, by adding queries to it. diff --git a/piccolo/engine/sqlite.py b/piccolo/engine/sqlite.py index ab466be55..78f83316c 100644 --- a/piccolo/engine/sqlite.py +++ b/piccolo/engine/sqlite.py @@ -13,6 +13,7 @@ from typing_extensions import Self from piccolo.engine.base import ( + BaseAtomic, BaseTransaction, Batch, Engine, @@ -256,7 +257,7 @@ class TransactionType(enum.Enum): exclusive = "EXCLUSIVE" -class Atomic: +class Atomic(BaseAtomic): """ Usage: