Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add generic typing support for Memory[Send/Receive]Channel #2549

Merged
merged 1 commit into from
Feb 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 115 additions & 51 deletions trio/_channel.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,48 @@
from __future__ import annotations

from collections import deque, OrderedDict
from collections.abc import Callable
from math import inf

from types import TracebackType
from typing import (
Any,
Generic,
NoReturn,
TypeVar,
TYPE_CHECKING,
Tuple, # only needed for typechecking on <3.9
)

import attr
from outcome import Error, Value

from .abc import SendChannel, ReceiveChannel, Channel
from ._util import generic_function, NoPublicConstructor

import trio
from ._core import enable_ki_protection
from ._core import enable_ki_protection, Task, Abort, RaiseCancelT

# A regular invariant generic type
T = TypeVar("T")

# The type of object produced by a ReceiveChannel (covariant because
# ReceiveChannel[Derived] can be passed to someone expecting
# ReceiveChannel[Base])
ReceiveType = TypeVar("ReceiveType", covariant=True)

# The type of object accepted by a SendChannel (contravariant because
# SendChannel[Base] can be passed to someone expecting
# SendChannel[Derived])
SendType = TypeVar("SendType", contravariant=True)
A5rocks marked this conversation as resolved.
Show resolved Hide resolved

# Temporary TypeVar needed until mypy release supports Self as a type
SelfT = TypeVar("SelfT")

@generic_function
def open_memory_channel(max_buffer_size):

def _open_memory_channel(
max_buffer_size: int,
) -> tuple[MemorySendChannel[T], MemoryReceiveChannel[T]]:
"""Open a channel for passing objects between tasks within a process.

Memory channels are lightweight, cheap to allocate, and entirely
Expand Down Expand Up @@ -68,36 +98,57 @@ def open_memory_channel(max_buffer_size):
raise TypeError("max_buffer_size must be an integer or math.inf")
if max_buffer_size < 0:
raise ValueError("max_buffer_size must be >= 0")
state = MemoryChannelState(max_buffer_size)
state: MemoryChannelState[T] = MemoryChannelState(max_buffer_size)
return (
MemorySendChannel._create(state),
MemoryReceiveChannel._create(state),
MemorySendChannel[T]._create(state),
MemoryReceiveChannel[T]._create(state),
)


# This workaround requires python3.9+, once older python versions are not supported
# or there's a better way of achieving type-checking on a generic factory function,
# it could replace the normal function header
if TYPE_CHECKING:
# written as a class so you can say open_memory_channel[int](5)
# Need to use Tuple instead of tuple due to CI check running on 3.8
class open_memory_channel(Tuple[MemorySendChannel[T], MemoryReceiveChannel[T]]):
def __new__( # type: ignore[misc] # "must return a subtype"
cls, max_buffer_size: int
) -> tuple[MemorySendChannel[T], MemoryReceiveChannel[T]]:
return _open_memory_channel(max_buffer_size)

def __init__(self, max_buffer_size: int):
...

else:
# apply the generic_function decorator to make open_memory_channel indexable
# so it's valid to say e.g. ``open_memory_channel[bytes](5)`` at runtime
open_memory_channel = generic_function(_open_memory_channel)


@attr.s(frozen=True, slots=True)
class MemoryChannelStats:
current_buffer_used = attr.ib()
max_buffer_size = attr.ib()
open_send_channels = attr.ib()
open_receive_channels = attr.ib()
tasks_waiting_send = attr.ib()
tasks_waiting_receive = attr.ib()
current_buffer_used: int = attr.ib()
max_buffer_size: int = attr.ib()
open_send_channels: int = attr.ib()
open_receive_channels: int = attr.ib()
tasks_waiting_send: int = attr.ib()
tasks_waiting_receive: int = attr.ib()


@attr.s(slots=True)
class MemoryChannelState:
max_buffer_size = attr.ib()
data = attr.ib(factory=deque)
class MemoryChannelState(Generic[T]):
max_buffer_size: int = attr.ib()
data: deque[T] = attr.ib(factory=deque)
# Counts of open endpoints using this state
open_send_channels = attr.ib(default=0)
open_receive_channels = attr.ib(default=0)
open_send_channels: int = attr.ib(default=0)
open_receive_channels: int = attr.ib(default=0)
# {task: value}
send_tasks = attr.ib(factory=OrderedDict)
send_tasks: OrderedDict[Task, T] = attr.ib(factory=OrderedDict)
# {task: None}
receive_tasks = attr.ib(factory=OrderedDict)
receive_tasks: OrderedDict[Task, None] = attr.ib(factory=OrderedDict)

def statistics(self):
def statistics(self) -> MemoryChannelStats:
return MemoryChannelStats(
current_buffer_used=len(self.data),
max_buffer_size=self.max_buffer_size,
Expand All @@ -109,28 +160,28 @@ def statistics(self):


@attr.s(eq=False, repr=False)
class MemorySendChannel(SendChannel, metaclass=NoPublicConstructor):
_state = attr.ib()
_closed = attr.ib(default=False)
class MemorySendChannel(SendChannel[SendType], metaclass=NoPublicConstructor):
_state: MemoryChannelState[SendType] = attr.ib()
_closed: bool = attr.ib(default=False)
# This is just the tasks waiting on *this* object. As compared to
# self._state.send_tasks, which includes tasks from this object and
# all clones.
_tasks = attr.ib(factory=set)
_tasks: set[Task] = attr.ib(factory=set)

def __attrs_post_init__(self):
def __attrs_post_init__(self) -> None:
self._state.open_send_channels += 1

def __repr__(self):
def __repr__(self) -> str:
return "<send channel at {:#x}, using buffer at {:#x}>".format(
id(self), id(self._state)
)

def statistics(self):
def statistics(self) -> MemoryChannelStats:
# XX should we also report statistics specific to this object?
return self._state.statistics()

@enable_ki_protection
def send_nowait(self, value):
def send_nowait(self, value: SendType) -> None:
"""Like `~trio.abc.SendChannel.send`, but if the channel's buffer is
full, raises `WouldBlock` instead of blocking.

Expand All @@ -150,7 +201,7 @@ def send_nowait(self, value):
raise trio.WouldBlock

@enable_ki_protection
async def send(self, value):
async def send(self, value: SendType) -> None:
"""See `SendChannel.send <trio.abc.SendChannel.send>`.

Memory channels allow multiple tasks to call `send` at the same time.
Expand All @@ -170,15 +221,16 @@ async def send(self, value):
self._state.send_tasks[task] = value
task.custom_sleep_data = self

def abort_fn(_):
def abort_fn(_: RaiseCancelT) -> Abort:
self._tasks.remove(task)
del self._state.send_tasks[task]
return trio.lowlevel.Abort.SUCCEEDED

await trio.lowlevel.wait_task_rescheduled(abort_fn)

# Return type must be stringified or use a TypeVar
@enable_ki_protection
def clone(self):
def clone(self) -> "MemorySendChannel[SendType]":
"""Clone this send channel object.

This returns a new `MemorySendChannel` object, which acts as a
Expand Down Expand Up @@ -206,14 +258,19 @@ def clone(self):
raise trio.ClosedResourceError
return MemorySendChannel._create(self._state)

def __enter__(self):
def __enter__(self: SelfT) -> SelfT:
return self

def __exit__(self, exc_type, exc_val, exc_tb):
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
self.close()

@enable_ki_protection
def close(self):
def close(self) -> None:
"""Close this send channel object synchronously.

All channel objects have an asynchronous `~.AsyncResource.aclose` method.
Expand Down Expand Up @@ -241,30 +298,30 @@ def close(self):
self._state.receive_tasks.clear()

@enable_ki_protection
async def aclose(self):
async def aclose(self) -> None:
self.close()
await trio.lowlevel.checkpoint()


@attr.s(eq=False, repr=False)
class MemoryReceiveChannel(ReceiveChannel, metaclass=NoPublicConstructor):
_state = attr.ib()
_closed = attr.ib(default=False)
_tasks = attr.ib(factory=set)
class MemoryReceiveChannel(ReceiveChannel[ReceiveType], metaclass=NoPublicConstructor):
_state: MemoryChannelState[ReceiveType] = attr.ib()
_closed: bool = attr.ib(default=False)
_tasks: set[trio._core._run.Task] = attr.ib(factory=set)

def __attrs_post_init__(self):
def __attrs_post_init__(self) -> None:
self._state.open_receive_channels += 1

def statistics(self):
def statistics(self) -> MemoryChannelStats:
return self._state.statistics()

def __repr__(self):
def __repr__(self) -> str:
return "<receive channel at {:#x}, using buffer at {:#x}>".format(
id(self), id(self._state)
)

@enable_ki_protection
def receive_nowait(self):
def receive_nowait(self) -> ReceiveType:
"""Like `~trio.abc.ReceiveChannel.receive`, but if there's nothing
ready to receive, raises `WouldBlock` instead of blocking.

Expand All @@ -284,7 +341,7 @@ def receive_nowait(self):
raise trio.WouldBlock

@enable_ki_protection
async def receive(self):
async def receive(self) -> ReceiveType:
"""See `ReceiveChannel.receive <trio.abc.ReceiveChannel.receive>`.

Memory channels allow multiple tasks to call `receive` at the same
Expand All @@ -306,15 +363,17 @@ async def receive(self):
self._state.receive_tasks[task] = None
task.custom_sleep_data = self

def abort_fn(_):
def abort_fn(_: RaiseCancelT) -> Abort:
self._tasks.remove(task)
del self._state.receive_tasks[task]
return trio.lowlevel.Abort.SUCCEEDED

return await trio.lowlevel.wait_task_rescheduled(abort_fn)
# Not strictly guaranteed to return ReceiveType, but will do so unless
# you intentionally reschedule with a bad value.
return await trio.lowlevel.wait_task_rescheduled(abort_fn) # type: ignore[no-any-return]

@enable_ki_protection
def clone(self):
def clone(self) -> "MemoryReceiveChannel[ReceiveType]":
"""Clone this receive channel object.

This returns a new `MemoryReceiveChannel` object, which acts as a
Expand Down Expand Up @@ -345,14 +404,19 @@ def clone(self):
raise trio.ClosedResourceError
return MemoryReceiveChannel._create(self._state)

def __enter__(self):
def __enter__(self: SelfT) -> SelfT:
return self

def __exit__(self, exc_type, exc_val, exc_tb):
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
self.close()

@enable_ki_protection
def close(self):
def close(self) -> None:
"""Close this receive channel object synchronously.

All channel objects have an asynchronous `~.AsyncResource.aclose` method.
Expand Down Expand Up @@ -381,6 +445,6 @@ def close(self):
self._state.data.clear()

@enable_ki_protection
async def aclose(self):
async def aclose(self) -> None:
self.close()
await trio.lowlevel.checkpoint()
1 change: 1 addition & 0 deletions trio/_core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
from ._traps import (
cancel_shielded_checkpoint,
Abort,
RaiseCancelT,
wait_task_rescheduled,
temporarily_detach_coroutine_object,
permanently_detach_coroutine_object,
Expand Down
7 changes: 6 additions & 1 deletion trio/_core/_traps.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from . import _run

from typing import Callable, NoReturn, Any

# Helper for the bottommost 'yield'. You can't use 'yield' inside an async
# function, but you can inside a generator, and if you decorate your generator
Expand Down Expand Up @@ -64,7 +65,11 @@ class WaitTaskRescheduled:
abort_func = attr.ib()


async def wait_task_rescheduled(abort_func):
RaiseCancelT = Callable[[], NoReturn] # TypeAlias
A5rocks marked this conversation as resolved.
Show resolved Hide resolved

# Should always return the type a Task "expects", unless you willfully reschedule it
# with a bad value.
async def wait_task_rescheduled(abort_func: Callable[[RaiseCancelT], Abort]) -> Any:
"""Put the current task to sleep, with cancellation support.

This is the lowest-level API for blocking in Trio. Every time a
Expand Down
1 change: 1 addition & 0 deletions trio/lowlevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from ._core import (
cancel_shielded_checkpoint,
Abort,
RaiseCancelT,
jakkdl marked this conversation as resolved.
Show resolved Hide resolved
wait_task_rescheduled,
enable_ki_protection,
disable_ki_protection,
Expand Down