Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Task/trans ctx typing #128

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions sqlalchemy-stubs/ext/asyncio/base.pyi
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
import abc
from typing import Any
from typing import Generator
from typing import Generic
from typing import TypeVar

_TStartableContext = TypeVar("_TStartableContext", bound=StartableContext)
_T = TypeVar("_T")

class StartableContext(abc.ABC, metaclass=abc.ABCMeta):
class StartableContext(abc.ABC, Generic[_T], metaclass=abc.ABCMeta):
@abc.abstractmethod
async def start(self: _TStartableContext) -> _TStartableContext: ...
def __await__(
self: _TStartableContext,
) -> Generator[Any, None, _TStartableContext]: ...
async def __aenter__(self: _TStartableContext) -> _TStartableContext: ...
async def start(self) -> _T: ...
def __await__(self) -> Generator[Any, None, _T]: ...
async def __aenter__(self) -> _T: ...
@abc.abstractmethod
async def __aexit__(
self, type_: Any, value: Any, traceback: Any
Expand Down
12 changes: 7 additions & 5 deletions sqlalchemy-stubs/ext/asyncio/engine.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ def create_async_engine(*arg: Any, **kw: Any) -> AsyncEngine: ...

class AsyncConnectable: ...

class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
class AsyncConnection(
ProxyComparable, StartableContext["AsyncConnection"], AsyncConnectable
):
# copied from future.Connection via create_proxy_methods
@property
def closed(self) -> bool: ...
Expand Down Expand Up @@ -102,12 +104,12 @@ class AsyncEngine(ProxyComparable, AsyncConnectable):
def update_execution_options(self, **opt: Any) -> None: ...
def get_execution_options(self) -> Mapping[Any, Any]: ...
# end copied
class _trans_ctx(StartableContext):
class _trans_ctx(StartableContext[AsyncConnection]):
conn: AsyncConnection = ...
def __init__(self, conn: AsyncConnection) -> None: ...
transaction: Any = ...
async def start(self) -> AsyncConnection: ... # type: ignore[override]
def __await__(self) -> Generator[Any, None, AsyncConnection]: ... # type: ignore[override]
async def start(self) -> AsyncConnection: ...
def __await__(self) -> Generator[Any, None, AsyncConnection]: ...
async def __aexit__(
self, type_: Any, value: Any, traceback: Any
) -> None: ...
Expand All @@ -119,7 +121,7 @@ class AsyncEngine(ProxyComparable, AsyncConnectable):
def execution_options(self, **opt: Any) -> AsyncEngine: ...
async def dispose(self) -> None: ...

class AsyncTransaction(ProxyComparable, StartableContext):
class AsyncTransaction(ProxyComparable, StartableContext["AsyncTransaction"]):
connection: AsyncConnection = ...
sync_transaction: Optional[Transaction] = ...
nested: bool = ...
Expand Down
2 changes: 1 addition & 1 deletion sqlalchemy-stubs/ext/asyncio/session.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ class _AsyncSessionContextManager:
self, type_: Any, value: Any, traceback: Any
) -> None: ...

class AsyncSessionTransaction(StartableContext):
class AsyncSessionTransaction(StartableContext["AsyncSessionTransaction"]):
session: AsyncSession = ...
nested: bool = ...
sync_transaction: Optional[Any] = ...
Expand Down
11 changes: 11 additions & 0 deletions test/files/async_context_processor_ticket_109.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from sqlalchemy import literal
from sqlalchemy import select
from sqlalchemy.ext import asyncio


async def test() -> None:
database = asyncio.create_async_engine("", future=True)

trans_ctx = database.begin()
async with trans_ctx as connection:
await connection.execute(select(literal(1)))