1515"""Helpers for applying Google Cloud Firestore changes in a transaction."""
1616from __future__ import annotations
1717
18- from typing import TYPE_CHECKING , Any , AsyncGenerator , Callable , Coroutine , Optional
18+ from typing import (
19+ TYPE_CHECKING ,
20+ Any ,
21+ AsyncGenerator ,
22+ Awaitable ,
23+ Callable ,
24+ Optional ,
25+ )
1926
2027from google .api_core import exceptions , gapic_v1
2128from google .api_core import retry_async as retries
3744# Types needed only for Type Hints
3845if TYPE_CHECKING : # pragma: NO COVER
3946 import datetime
47+ from typing_extensions import TypeVar , ParamSpec , Concatenate
4048
4149 from google .cloud .firestore_v1 .async_stream_generator import AsyncStreamGenerator
4250 from google .cloud .firestore_v1 .base_document import DocumentSnapshot
4351 from google .cloud .firestore_v1 .query_profile import ExplainOptions
4452
53+ T = TypeVar ("T" )
54+ P = ParamSpec ("P" )
55+
4556
4657class AsyncTransaction (async_batch .AsyncWriteBatch , BaseTransaction ):
4758 """Accumulate read-and-write operations to be sent in a transaction.
@@ -253,12 +264,14 @@ class _AsyncTransactional(_BaseTransactional):
253264 A coroutine that should be run (and retried) in a transaction.
254265 """
255266
256- def __init__ (self , to_wrap ) -> None :
267+ def __init__ (
268+ self , to_wrap : Callable [Concatenate [AsyncTransaction , P ], Awaitable [T ]]
269+ ) -> None :
257270 super (_AsyncTransactional , self ).__init__ (to_wrap )
258271
259272 async def _pre_commit (
260- self , transaction : AsyncTransaction , * args , ** kwargs
261- ) -> Coroutine :
273+ self , transaction : AsyncTransaction , * args : P . args , ** kwargs : P . kwargs
274+ ) -> T :
262275 """Begin transaction and call the wrapped coroutine.
263276
264277 Args:
@@ -271,7 +284,7 @@ async def _pre_commit(
271284 along to the wrapped coroutine.
272285
273286 Returns:
274- Any : result of the wrapped coroutine.
287+ T : result of the wrapped coroutine.
275288
276289 Raises:
277290 Exception: Any failure caused by ``to_wrap``.
@@ -286,20 +299,22 @@ async def _pre_commit(
286299 self .retry_id = self .current_id
287300 return await self .to_wrap (transaction , * args , ** kwargs )
288301
289- async def __call__ (self , transaction , * args , ** kwargs ):
302+ async def __call__ (
303+ self , transaction : AsyncTransaction , * args : P .args , ** kwargs : P .kwargs
304+ ) -> T :
290305 """Execute the wrapped callable within a transaction.
291306
292307 Args:
293308 transaction
294- (:class:`~google.cloud.firestore_v1.transaction.Transaction `):
309+ (:class:`~google.cloud.firestore_v1.async_transaction.AsyncTransaction `):
295310 A transaction to execute the callable within.
296311 args (Tuple[Any, ...]): The extra positional arguments to pass
297312 along to the wrapped callable.
298313 kwargs (Dict[str, Any]): The extra keyword arguments to pass
299314 along to the wrapped callable.
300315
301316 Returns:
302- Any : The result of the wrapped callable.
317+ T : The result of the wrapped callable.
303318
304319 Raises:
305320 ValueError: If the transaction does not succeed in
@@ -313,7 +328,7 @@ async def __call__(self, transaction, *args, **kwargs):
313328
314329 try :
315330 for attempt in range (transaction ._max_attempts ):
316- result = await self ._pre_commit (transaction , * args , ** kwargs )
331+ result : T = await self ._pre_commit (transaction , * args , ** kwargs )
317332 try :
318333 await transaction ._commit ()
319334 return result
@@ -338,17 +353,17 @@ async def __call__(self, transaction, *args, **kwargs):
338353
339354
340355def async_transactional (
341- to_wrap : Callable [[AsyncTransaction ], Any ]
342- ) -> _AsyncTransactional :
356+ to_wrap : Callable [Concatenate [AsyncTransaction , P ], Awaitable [ T ] ]
357+ ) -> Callable [ Concatenate [ AsyncTransaction , P ], Awaitable [ T ]] :
343358 """Decorate a callable so that it runs in a transaction.
344359
345360 Args:
346361 to_wrap
347- (Callable[[:class:`~google.cloud.firestore_v1.transaction.Transaction `, ...], Any]):
362+ (Callable[[:class:`~google.cloud.firestore_v1.async_transaction.AsyncTransaction `, ...], Awaitable[ Any] ]):
348363 A callable that should be run (and retried) in a transaction.
349364
350365 Returns:
351- Callable[[:class:`~google.cloud.firestore_v1.transaction.Transaction`, ...], Any]:
366+ Callable[[:class:`~google.cloud.firestore_v1.transaction.Transaction`, ...], Awaitable[ Any] ]:
352367 the wrapped callable.
353368 """
354369 return _AsyncTransactional (to_wrap )
0 commit comments