From 28039af1c0016046bea1872c66ac74d7e411b844 Mon Sep 17 00:00:00 2001 From: Faster Speeding Date: Mon, 17 May 2021 02:34:04 +0100 Subject: [PATCH] Fix return typing of AsyncEngine._trans_ctx.__aenter__ * StartableContext is made generic as to remove the assumption that it'll always be returning itself which is not the case for for it's behaviour in _trans_ctx --- sqlalchemy-stubs/ext/asyncio/base.pyi | 13 ++++++------- sqlalchemy-stubs/ext/asyncio/engine.pyi | 12 +++++++----- sqlalchemy-stubs/ext/asyncio/session.pyi | 2 +- 3 files changed, 14 insertions(+), 13 deletions(-) diff --git a/sqlalchemy-stubs/ext/asyncio/base.pyi b/sqlalchemy-stubs/ext/asyncio/base.pyi index a41b394..13d785e 100644 --- a/sqlalchemy-stubs/ext/asyncio/base.pyi +++ b/sqlalchemy-stubs/ext/asyncio/base.pyi @@ -1,17 +1,16 @@ import abc from typing import Any from typing import Generator +from typing import Generic from typing import TypeVar -_TStartableContext = TypeVar("_TStartableContext", bound=StartableContext) +_T = TypeVar("_T") -class StartableContext(abc.ABC, metaclass=abc.ABCMeta): +class StartableContext(abc.ABC, Generic[_T], metaclass=abc.ABCMeta): @abc.abstractmethod - async def start(self: _TStartableContext) -> _TStartableContext: ... - def __await__( - self: _TStartableContext, - ) -> Generator[Any, None, _TStartableContext]: ... - async def __aenter__(self: _TStartableContext) -> _TStartableContext: ... + async def start(self) -> _T: ... + def __await__(self) -> Generator[Any, None, _T]: ... + async def __aenter__(self) -> _T: ... @abc.abstractmethod async def __aexit__( self, type_: Any, value: Any, traceback: Any diff --git a/sqlalchemy-stubs/ext/asyncio/engine.pyi b/sqlalchemy-stubs/ext/asyncio/engine.pyi index bfe43d0..3e0112f 100644 --- a/sqlalchemy-stubs/ext/asyncio/engine.pyi +++ b/sqlalchemy-stubs/ext/asyncio/engine.pyi @@ -22,7 +22,9 @@ def create_async_engine(*arg: Any, **kw: Any) -> AsyncEngine: ... class AsyncConnectable: ... -class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable): +class AsyncConnection( + ProxyComparable, StartableContext["AsyncConnection"], AsyncConnectable +): # copied from future.Connection via create_proxy_methods @property def closed(self) -> bool: ... @@ -102,12 +104,12 @@ class AsyncEngine(ProxyComparable, AsyncConnectable): def update_execution_options(self, **opt: Any) -> None: ... def get_execution_options(self) -> Mapping[Any, Any]: ... # end copied - class _trans_ctx(StartableContext): + class _trans_ctx(StartableContext[AsyncConnection]): conn: AsyncConnection = ... def __init__(self, conn: AsyncConnection) -> None: ... transaction: Any = ... - async def start(self) -> AsyncConnection: ... # type: ignore[override] - def __await__(self) -> Generator[Any, None, AsyncConnection]: ... # type: ignore[override] + async def start(self) -> AsyncConnection: ... + def __await__(self) -> Generator[Any, None, AsyncConnection]: ... async def __aexit__( self, type_: Any, value: Any, traceback: Any ) -> None: ... @@ -119,7 +121,7 @@ class AsyncEngine(ProxyComparable, AsyncConnectable): def execution_options(self, **opt: Any) -> AsyncEngine: ... async def dispose(self) -> None: ... -class AsyncTransaction(ProxyComparable, StartableContext): +class AsyncTransaction(ProxyComparable, StartableContext["AsyncTransaction"]): connection: AsyncConnection = ... sync_transaction: Optional[Transaction] = ... nested: bool = ... diff --git a/sqlalchemy-stubs/ext/asyncio/session.pyi b/sqlalchemy-stubs/ext/asyncio/session.pyi index 6916576..47b1bd1 100644 --- a/sqlalchemy-stubs/ext/asyncio/session.pyi +++ b/sqlalchemy-stubs/ext/asyncio/session.pyi @@ -151,7 +151,7 @@ class _AsyncSessionContextManager: self, type_: Any, value: Any, traceback: Any ) -> None: ... -class AsyncSessionTransaction(StartableContext): +class AsyncSessionTransaction(StartableContext["AsyncSessionTransaction"]): session: AsyncSession = ... nested: bool = ... sync_transaction: Optional[Any] = ...