Skip to content

Commit

Permalink
Task/trans ctx typing (#128)
Browse files Browse the repository at this point in the history
* Fix return typing of AsyncEngine._trans_ctx.__aenter__
* StartableContext is made generic as to remove the assumption that it'll always be returning itself which is not the case for for it's behaviour in _trans_ctx

* tests: add testcase for #109

Co-authored-by: Faster Speeding <luke@lmbyrne.dev>
  • Loading branch information
MaicoTimmerman and FasterSpeeding authored Jun 25, 2021
1 parent 424b378 commit f761e14
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 13 deletions.
13 changes: 6 additions & 7 deletions sqlalchemy-stubs/ext/asyncio/base.pyi
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
import abc
from typing import Any
from typing import Generator
from typing import Generic
from typing import TypeVar

_TStartableContext = TypeVar("_TStartableContext", bound=StartableContext)
_T = TypeVar("_T")

class StartableContext(abc.ABC, metaclass=abc.ABCMeta):
class StartableContext(abc.ABC, Generic[_T], metaclass=abc.ABCMeta):
@abc.abstractmethod
async def start(self: _TStartableContext) -> _TStartableContext: ...
def __await__(
self: _TStartableContext,
) -> Generator[Any, None, _TStartableContext]: ...
async def __aenter__(self: _TStartableContext) -> _TStartableContext: ...
async def start(self) -> _T: ...
def __await__(self) -> Generator[Any, None, _T]: ...
async def __aenter__(self) -> _T: ...
@abc.abstractmethod
async def __aexit__(
self, type_: Any, value: Any, traceback: Any
Expand Down
12 changes: 7 additions & 5 deletions sqlalchemy-stubs/ext/asyncio/engine.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ def create_async_engine(*arg: Any, **kw: Any) -> AsyncEngine: ...

class AsyncConnectable: ...

class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
class AsyncConnection(
ProxyComparable, StartableContext["AsyncConnection"], AsyncConnectable
):
# copied from future.Connection via create_proxy_methods
@property
def closed(self) -> bool: ...
Expand Down Expand Up @@ -102,12 +104,12 @@ class AsyncEngine(ProxyComparable, AsyncConnectable):
def update_execution_options(self, **opt: Any) -> None: ...
def get_execution_options(self) -> Mapping[Any, Any]: ...
# end copied
class _trans_ctx(StartableContext):
class _trans_ctx(StartableContext[AsyncConnection]):
conn: AsyncConnection = ...
def __init__(self, conn: AsyncConnection) -> None: ...
transaction: Any = ...
async def start(self) -> AsyncConnection: ... # type: ignore[override]
def __await__(self) -> Generator[Any, None, AsyncConnection]: ... # type: ignore[override]
async def start(self) -> AsyncConnection: ...
def __await__(self) -> Generator[Any, None, AsyncConnection]: ...
async def __aexit__(
self, type_: Any, value: Any, traceback: Any
) -> None: ...
Expand All @@ -119,7 +121,7 @@ class AsyncEngine(ProxyComparable, AsyncConnectable):
def execution_options(self, **opt: Any) -> AsyncEngine: ...
async def dispose(self) -> None: ...

class AsyncTransaction(ProxyComparable, StartableContext):
class AsyncTransaction(ProxyComparable, StartableContext["AsyncTransaction"]):
connection: AsyncConnection = ...
sync_transaction: Optional[Transaction] = ...
nested: bool = ...
Expand Down
2 changes: 1 addition & 1 deletion sqlalchemy-stubs/ext/asyncio/session.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ class _AsyncSessionContextManager:
self, type_: Any, value: Any, traceback: Any
) -> None: ...

class AsyncSessionTransaction(StartableContext):
class AsyncSessionTransaction(StartableContext["AsyncSessionTransaction"]):
session: AsyncSession = ...
nested: bool = ...
sync_transaction: Optional[Any] = ...
Expand Down
11 changes: 11 additions & 0 deletions test/files/async_context_processor_ticket_109.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from sqlalchemy import literal
from sqlalchemy import select
from sqlalchemy.ext import asyncio


async def test() -> None:
database = asyncio.create_async_engine("", future=True)

trans_ctx = database.begin()
async with trans_ctx as connection:
await connection.execute(select(literal(1)))

0 comments on commit f761e14

Please sign in to comment.