diff --git a/google/cloud/firestore_v1/async_transaction.py b/google/cloud/firestore_v1/async_transaction.py index be8668cd6..36509941e 100644 --- a/google/cloud/firestore_v1/async_transaction.py +++ b/google/cloud/firestore_v1/async_transaction.py @@ -15,7 +15,14 @@ """Helpers for applying Google Cloud Firestore changes in a transaction.""" from __future__ import annotations -from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Coroutine, Optional +from typing import ( + TYPE_CHECKING, + Any, + AsyncGenerator, + Awaitable, + Callable, + Optional, +) from google.api_core import exceptions, gapic_v1 from google.api_core import retry_async as retries @@ -37,11 +44,15 @@ # Types needed only for Type Hints if TYPE_CHECKING: # pragma: NO COVER import datetime + from typing_extensions import TypeVar, ParamSpec, Concatenate from google.cloud.firestore_v1.async_stream_generator import AsyncStreamGenerator from google.cloud.firestore_v1.base_document import DocumentSnapshot from google.cloud.firestore_v1.query_profile import ExplainOptions + T = TypeVar("T") + P = ParamSpec("P") + class AsyncTransaction(async_batch.AsyncWriteBatch, BaseTransaction): """Accumulate read-and-write operations to be sent in a transaction. @@ -253,12 +264,14 @@ class _AsyncTransactional(_BaseTransactional): A coroutine that should be run (and retried) in a transaction. """ - def __init__(self, to_wrap) -> None: + def __init__( + self, to_wrap: Callable[Concatenate[AsyncTransaction, P], Awaitable[T]] + ) -> None: super(_AsyncTransactional, self).__init__(to_wrap) async def _pre_commit( - self, transaction: AsyncTransaction, *args, **kwargs - ) -> Coroutine: + self, transaction: AsyncTransaction, *args: P.args, **kwargs: P.kwargs + ) -> T: """Begin transaction and call the wrapped coroutine. Args: @@ -271,7 +284,7 @@ async def _pre_commit( along to the wrapped coroutine. Returns: - Any: result of the wrapped coroutine. + T: result of the wrapped coroutine. Raises: Exception: Any failure caused by ``to_wrap``. @@ -286,12 +299,14 @@ async def _pre_commit( self.retry_id = self.current_id return await self.to_wrap(transaction, *args, **kwargs) - async def __call__(self, transaction, *args, **kwargs): + async def __call__( + self, transaction: AsyncTransaction, *args: P.args, **kwargs: P.kwargs + ) -> T: """Execute the wrapped callable within a transaction. Args: transaction - (:class:`~google.cloud.firestore_v1.transaction.Transaction`): + (:class:`~google.cloud.firestore_v1.async_transaction.AsyncTransaction`): A transaction to execute the callable within. args (Tuple[Any, ...]): The extra positional arguments to pass along to the wrapped callable. @@ -299,7 +314,7 @@ async def __call__(self, transaction, *args, **kwargs): along to the wrapped callable. Returns: - Any: The result of the wrapped callable. + T: The result of the wrapped callable. Raises: ValueError: If the transaction does not succeed in @@ -313,7 +328,7 @@ async def __call__(self, transaction, *args, **kwargs): try: for attempt in range(transaction._max_attempts): - result = await self._pre_commit(transaction, *args, **kwargs) + result: T = await self._pre_commit(transaction, *args, **kwargs) try: await transaction._commit() return result @@ -338,17 +353,17 @@ async def __call__(self, transaction, *args, **kwargs): def async_transactional( - to_wrap: Callable[[AsyncTransaction], Any] -) -> _AsyncTransactional: + to_wrap: Callable[Concatenate[AsyncTransaction, P], Awaitable[T]] +) -> Callable[Concatenate[AsyncTransaction, P], Awaitable[T]]: """Decorate a callable so that it runs in a transaction. Args: to_wrap - (Callable[[:class:`~google.cloud.firestore_v1.transaction.Transaction`, ...], Any]): + (Callable[[:class:`~google.cloud.firestore_v1.async_transaction.AsyncTransaction`, ...], Awaitable[Any]]): A callable that should be run (and retried) in a transaction. Returns: - Callable[[:class:`~google.cloud.firestore_v1.transaction.Transaction`, ...], Any]: + Callable[[:class:`~google.cloud.firestore_v1.transaction.Transaction`, ...], Awaitable[Any]]: the wrapped callable. """ return _AsyncTransactional(to_wrap) diff --git a/mypy.ini b/mypy.ini index 4505b4854..beaa679a8 100644 --- a/mypy.ini +++ b/mypy.ini @@ -1,3 +1,3 @@ [mypy] -python_version = 3.6 +python_version = 3.8 namespace_packages = True