diff --git a/stdlib/concurrent/futures/process.pyi b/stdlib/concurrent/futures/process.pyi index d3706a9c15a6..a1de3d679b23 100644 --- a/stdlib/concurrent/futures/process.pyi +++ b/stdlib/concurrent/futures/process.pyi @@ -19,8 +19,9 @@ _global_shutdown: bool class _ThreadWakeup: _closed: bool - _reader: Connection - _writer: Connection + # Any: Unused send and recv methods + _reader: Connection[Any, Any] + _writer: Connection[Any, Any] def close(self) -> None: ... def wakeup(self) -> None: ... def clear(self) -> None: ... diff --git a/stdlib/multiprocessing/connection.pyi b/stdlib/multiprocessing/connection.pyi index 7045a81b85be..9998239d3119 100644 --- a/stdlib/multiprocessing/connection.pyi +++ b/stdlib/multiprocessing/connection.pyi @@ -1,9 +1,9 @@ import socket import sys -import types -from _typeshed import ReadableBuffer +from _typeshed import Incomplete, ReadableBuffer from collections.abc import Iterable -from typing import Any, SupportsIndex +from types import TracebackType +from typing import Any, Generic, SupportsIndex, TypeVar from typing_extensions import Self, TypeAlias __all__ = ["Client", "Listener", "Pipe", "wait"] @@ -11,7 +11,11 @@ __all__ = ["Client", "Listener", "Pipe", "wait"] # https://docs.python.org/3/library/multiprocessing.html#address-formats _Address: TypeAlias = str | tuple[str, int] -class _ConnectionBase: +# Defaulting to Any to avoid forcing generics on a lot of pre-existing code +_SendT = TypeVar("_SendT", contravariant=True, default=Any) +_RecvT = TypeVar("_RecvT", covariant=True, default=Any) + +class _ConnectionBase(Generic[_SendT, _RecvT]): def __init__(self, handle: SupportsIndex, readable: bool = True, writable: bool = True) -> None: ... @property def closed(self) -> bool: ... # undocumented @@ -22,27 +26,27 @@ class _ConnectionBase: def fileno(self) -> int: ... def close(self) -> None: ... def send_bytes(self, buf: ReadableBuffer, offset: int = 0, size: int | None = None) -> None: ... - def send(self, obj: Any) -> None: ... + def send(self, obj: _SendT) -> None: ... def recv_bytes(self, maxlength: int | None = None) -> bytes: ... def recv_bytes_into(self, buf: Any, offset: int = 0) -> int: ... - def recv(self) -> Any: ... + def recv(self) -> _RecvT: ... def poll(self, timeout: float | None = 0.0) -> bool: ... def __enter__(self) -> Self: ... def __exit__( - self, exc_type: type[BaseException] | None, exc_value: BaseException | None, exc_tb: types.TracebackType | None + self, exc_type: type[BaseException] | None, exc_value: BaseException | None, exc_tb: TracebackType | None ) -> None: ... def __del__(self) -> None: ... -class Connection(_ConnectionBase): ... +class Connection(_ConnectionBase[_SendT, _RecvT]): ... if sys.platform == "win32": - class PipeConnection(_ConnectionBase): ... + class PipeConnection(_ConnectionBase[_SendT, _RecvT]): ... class Listener: def __init__( self, address: _Address | None = None, family: str | None = None, backlog: int = 1, authkey: bytes | None = None ) -> None: ... - def accept(self) -> Connection: ... + def accept(self) -> Connection[Incomplete, Incomplete]: ... def close(self) -> None: ... @property def address(self) -> _Address: ... @@ -50,26 +54,30 @@ class Listener: def last_accepted(self) -> _Address | None: ... def __enter__(self) -> Self: ... def __exit__( - self, exc_type: type[BaseException] | None, exc_value: BaseException | None, exc_tb: types.TracebackType | None + self, exc_type: type[BaseException] | None, exc_value: BaseException | None, exc_tb: TracebackType | None ) -> None: ... +# Any: send and recv methods unused if sys.version_info >= (3, 12): - def deliver_challenge(connection: Connection, authkey: bytes, digest_name: str = "sha256") -> None: ... + def deliver_challenge(connection: Connection[Any, Any], authkey: bytes, digest_name: str = "sha256") -> None: ... else: - def deliver_challenge(connection: Connection, authkey: bytes) -> None: ... + def deliver_challenge(connection: Connection[Any, Any], authkey: bytes) -> None: ... -def answer_challenge(connection: Connection, authkey: bytes) -> None: ... +def answer_challenge(connection: Connection[Any, Any], authkey: bytes) -> None: ... def wait( - object_list: Iterable[Connection | socket.socket | int], timeout: float | None = None -) -> list[Connection | socket.socket | int]: ... -def Client(address: _Address, family: str | None = None, authkey: bytes | None = None) -> Connection: ... + object_list: Iterable[Connection[_SendT, _RecvT] | socket.socket | int], timeout: float | None = None +) -> list[Connection[_SendT, _RecvT] | socket.socket | int]: ... +def Client(address: _Address, family: str | None = None, authkey: bytes | None = None) -> Connection[Any, Any]: ... # N.B. Keep this in sync with multiprocessing.context.BaseContext.Pipe. # _ConnectionBase is the common base class of Connection and PipeConnection # and can be used in cross-platform code. +# +# The two connections should have the same generic types but inverted (Connection[_T1, _T2], Connection[_T2, _T1]). +# However, TypeVars scoped entirely within a return annotation is unspecified in the spec. if sys.platform != "win32": - def Pipe(duplex: bool = True) -> tuple[Connection, Connection]: ... + def Pipe(duplex: bool = True) -> tuple[Connection[Any, Any], Connection[Any, Any]]: ... else: - def Pipe(duplex: bool = True) -> tuple[PipeConnection, PipeConnection]: ... + def Pipe(duplex: bool = True) -> tuple[PipeConnection[Any, Any], PipeConnection[Any, Any]]: ... diff --git a/stdlib/multiprocessing/context.pyi b/stdlib/multiprocessing/context.pyi index a3edaa463818..c1cbbce4f63d 100644 --- a/stdlib/multiprocessing/context.pyi +++ b/stdlib/multiprocessing/context.pyi @@ -46,10 +46,13 @@ class BaseContext: # N.B. Keep this in sync with multiprocessing.connection.Pipe. # _ConnectionBase is the common base class of Connection and PipeConnection # and can be used in cross-platform code. + # + # The two connections should have the same generic types but inverted (Connection[_T1, _T2], Connection[_T2, _T1]). + # However, TypeVars scoped entirely within a return annotation is unspecified in the spec. if sys.platform != "win32": - def Pipe(self, duplex: bool = True) -> tuple[Connection, Connection]: ... + def Pipe(self, duplex: bool = True) -> tuple[Connection[Any, Any], Connection[Any, Any]]: ... else: - def Pipe(self, duplex: bool = True) -> tuple[PipeConnection, PipeConnection]: ... + def Pipe(self, duplex: bool = True) -> tuple[PipeConnection[Any, Any], PipeConnection[Any, Any]]: ... def Barrier( self, parties: int, action: Callable[..., object] | None = None, timeout: float | None = None diff --git a/stdlib/multiprocessing/managers.pyi b/stdlib/multiprocessing/managers.pyi index 02b5c4bc8c67..71d87db1d4aa 100644 --- a/stdlib/multiprocessing/managers.pyi +++ b/stdlib/multiprocessing/managers.pyi @@ -1,7 +1,7 @@ import queue import sys import threading -from _typeshed import SupportsKeysAndGetItem, SupportsRichComparison, SupportsRichComparisonT +from _typeshed import Incomplete, SupportsKeysAndGetItem, SupportsRichComparison, SupportsRichComparisonT from collections.abc import Callable, Iterable, Iterator, Mapping, MutableMapping, MutableSequence, Sequence from types import TracebackType from typing import Any, AnyStr, ClassVar, Generic, SupportsIndex, TypeVar, overload @@ -125,7 +125,9 @@ class Server: self, registry: dict[str, tuple[Callable[..., Any], Any, Any, Any]], address: Any, authkey: bytes, serializer: str ) -> None: ... def serve_forever(self) -> None: ... - def accept_connection(self, c: Connection, name: str) -> None: ... + def accept_connection( + self, c: Connection[tuple[str, str | None], tuple[str, str, Iterable[Incomplete], Mapping[str, Incomplete]]], name: str + ) -> None: ... class BaseManager: if sys.version_info >= (3, 11): diff --git a/stdlib/multiprocessing/reduction.pyi b/stdlib/multiprocessing/reduction.pyi index 91532633e1b9..322a17145f5b 100644 --- a/stdlib/multiprocessing/reduction.pyi +++ b/stdlib/multiprocessing/reduction.pyi @@ -35,8 +35,8 @@ if sys.platform == "win32": handle: int, target_process: int | None = None, inheritable: bool = False, *, source_process: int | None = None ) -> int: ... def steal_handle(source_pid: int, handle: int) -> int: ... - def send_handle(conn: connection.PipeConnection, handle: int, destination_pid: int) -> None: ... - def recv_handle(conn: connection.PipeConnection) -> int: ... + def send_handle(conn: connection.PipeConnection[DupHandle, Any], handle: int, destination_pid: int) -> None: ... + def recv_handle(conn: connection.PipeConnection[Any, DupHandle]) -> int: ... class DupHandle: def __init__(self, handle: int, access: int, pid: int | None = None) -> None: ... diff --git a/test_cases/stdlib/multiprocessing/check_pipe_connections.py b/test_cases/stdlib/multiprocessing/check_pipe_connections.py new file mode 100644 index 000000000000..1d6266a0aabb --- /dev/null +++ b/test_cases/stdlib/multiprocessing/check_pipe_connections.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +import sys +from multiprocessing.connection import Pipe + +if sys.platform != "win32": + from multiprocessing.connection import Connection +else: + from multiprocessing.connection import PipeConnection as Connection + + +# Unfortunately, we cannot validate that both connections have the same, but inverted generic types, +# since TypeVars scoped entirely within a return annotation is unspecified in the spec. +# Pipe[str, int]() -> tuple[Connection[str, int], Connection[int, str]] + +a: Connection[str, int] +b: Connection[int, str] +a, b = Pipe() + +connections: tuple[Connection[str, int], Connection[int, str]] = Pipe() +a, b = connections + +a.send("test") +a.send(0) # type: ignore +test1: str = b.recv() +test2: int = b.recv() # type: ignore + +b.send("test") # type: ignore +b.send(0) +test3: str = a.recv() # type: ignore +test4: int = a.recv()