diff --git a/google/cloud/firestore_v1/async_transaction.py b/google/cloud/firestore_v1/async_transaction.py index 038710929..b38cdfb71 100644 --- a/google/cloud/firestore_v1/async_transaction.py +++ b/google/cloud/firestore_v1/async_transaction.py @@ -15,7 +15,17 @@ """Helpers for applying Google Cloud Firestore changes in a transaction.""" from __future__ import annotations -from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Coroutine, Optional +from typing import ( + TYPE_CHECKING, + Any, + AsyncGenerator, + Awaitable, + Callable, + Coroutine, + Optional, + TypeVar, + Protocol, +) from google.api_core import exceptions, gapic_v1 from google.api_core import retry_async as retries @@ -41,6 +51,9 @@ from google.cloud.firestore_v1.query_profile import ExplainOptions +T = TypeVar("T", bound=Callable[..., Any]) + + class AsyncTransaction(async_batch.AsyncWriteBatch, BaseTransaction): """Accumulate read-and-write operations to be sent in a transaction. @@ -236,11 +249,13 @@ class _AsyncTransactional(_BaseTransactional): A coroutine that should be run (and retried) in a transaction. """ - def __init__(self, to_wrap) -> None: + def __init__( + self, to_wrap: Callable[..., Awaitable[T]] + ) -> None: super(_AsyncTransactional, self).__init__(to_wrap) async def _pre_commit( - self, transaction: AsyncTransaction, *args, **kwargs + self, transaction: AsyncTransaction, *args: Any, **kwargs: Any ) -> Coroutine: """Begin transaction and call the wrapped coroutine. @@ -254,7 +269,7 @@ async def _pre_commit( along to the wrapped coroutine. Returns: - Any: result of the wrapped coroutine. + T: result of the wrapped coroutine. Raises: Exception: Any failure caused by ``to_wrap``. @@ -269,12 +284,14 @@ async def _pre_commit( self.retry_id = self.current_id return await self.to_wrap(transaction, *args, **kwargs) - async def __call__(self, transaction, *args, **kwargs): + async def __call__( + self, transaction: AsyncTransaction, *args: Any, **kwargs: Any + ) -> T: """Execute the wrapped callable within a transaction. Args: transaction - (:class:`~google.cloud.firestore_v1.transaction.Transaction`): + (:class:`~google.cloud.firestore_v1.async_transaction.AsyncTransaction`): A transaction to execute the callable within. args (Tuple[Any, ...]): The extra positional arguments to pass along to the wrapped callable. @@ -282,7 +299,7 @@ async def __call__(self, transaction, *args, **kwargs): along to the wrapped callable. Returns: - Any: The result of the wrapped callable. + T: The result of the wrapped callable. Raises: ValueError: If the transaction does not succeed in @@ -320,14 +337,17 @@ async def __call__(self, transaction, *args, **kwargs): raise +class WithAsyncTransaction(Protocol[T]): + def __call__(self, transaction: AsyncTransaction, *args: Any, **kwargs: Any) -> Awaitable[T]: ... + def async_transactional( - to_wrap: Callable[[AsyncTransaction], Any] -) -> _AsyncTransactional: + to_wrap: Callable[..., Awaitable[T]] +) -> WithAsyncTransaction[T]: """Decorate a callable so that it runs in a transaction. Args: to_wrap - (Callable[[:class:`~google.cloud.firestore_v1.transaction.Transaction`, ...], Any]): + (Callable[[:class:`~google.cloud.firestore_v1.async_transaction.AsyncTransaction`, ...], Any]): A callable that should be run (and retried) in a transaction. Returns: