Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix the signatures of @(async)contextmanager #12087

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
Open
14 changes: 13 additions & 1 deletion stdlib/contextlib.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ from abc import abstractmethod
from collections.abc import AsyncGenerator, AsyncIterator, Awaitable, Callable, Generator, Iterator
from types import TracebackType
from typing import IO, Any, Generic, Protocol, TypeVar, overload, runtime_checkable
from typing_extensions import ParamSpec, Self, TypeAlias
from typing_extensions import ParamSpec, Self, TypeAlias, deprecated

__all__ = [
"contextmanager",
Expand Down Expand Up @@ -75,6 +75,12 @@ class _GeneratorContextManager(AbstractContextManager[_T_co, bool | None], Conte
self, type: type[BaseException] | None, value: BaseException | None, traceback: TracebackType | None
) -> bool | None: ...

@overload
def contextmanager(func: Callable[_P, Generator[_T_co]]) -> Callable[_P, _GeneratorContextManager[_T_co]]: ...
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps the type here should be Generator[_T_co, None, object] instead? contextmanager does not actually care what the return type of the generator is. E.g. this:

@contextmanager
def f():
    yield
    return 1

is fine (though pointless).

@overload
@deprecated(
"Annotating the return type as `-> Iterator[Foo]` with `@contextmanager` is deprecated. Use `-> Generator[Foo]` instead."
)
def contextmanager(func: Callable[_P, Iterator[_T_co]]) -> Callable[_P, _GeneratorContextManager[_T_co]]: ...

if sys.version_info >= (3, 10):
Expand Down Expand Up @@ -107,6 +113,12 @@ else:
self, typ: type[BaseException] | None, value: BaseException | None, traceback: TracebackType | None
) -> bool | None: ...

@overload
def asynccontextmanager(func: Callable[_P, AsyncGenerator[_T_co]]) -> Callable[_P, _AsyncGeneratorContextManager[_T_co]]: ...
@overload
@deprecated(
"Annotating the return type as `-> AsyncIterator[Foo]` with `@contextmanager` is deprecated. Use `-> AsyncGenerator[Foo]` instead."
)
def asynccontextmanager(func: Callable[_P, AsyncIterator[_T_co]]) -> Callable[_P, _AsyncGeneratorContextManager[_T_co]]: ...

class _SupportsClose(Protocol):
Expand Down
Loading