Merge pull request #2549 from jakkdl/memorychannel_generic_typing
Add generic typing support for Memory[Send/Receive]Channel
Zac-HD authored Feb 3, 2023
2 parents 35ba304 + 798e87a commit 333f9af
Showing 4 changed files with 123 additions and 52 deletions.
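
With this change, the channel element type can be stated at the call site and checked end to end: `open_memory_channel[int](5)` yields a `MemorySendChannel[int]` and a `MemoryReceiveChannel[int]`. A minimal sketch of the intended usage (the `main` function here is illustrative, not part of the diff):

import trio

async def main() -> None:
    # The subscript fixes the element type for both endpoints.
    send, receive = trio.open_memory_channel[int](5)
    async with send, receive:
        await send.send(42)  # OK
        value: int = await receive.receive()
        # await send.send("hi")  # rejected by mypy: str is not int

trio.run(main)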
166 changes: 115 additions & 51 deletions trio/_channel.py
@@ -1,18 +1,48 @@
from __future__ import annotations

from collections import deque, OrderedDict
from collections.abc import Callable
from math import inf

from types import TracebackType
from typing import (
Any,
Generic,
NoReturn,
TypeVar,
TYPE_CHECKING,
Tuple, # only needed for typechecking on <3.9
)

import attr
from outcome import Error, Value

from .abc import SendChannel, ReceiveChannel, Channel
from ._util import generic_function, NoPublicConstructor

import trio
from ._core import enable_ki_protection
from ._core import enable_ki_protection, Task, Abort, RaiseCancelT

# A regular invariant generic type
T = TypeVar("T")

# The type of object produced by a ReceiveChannel (covariant because
# ReceiveChannel[Derived] can be passed to someone expecting
# ReceiveChannel[Base])
ReceiveType = TypeVar("ReceiveType", covariant=True)

# The type of object accepted by a SendChannel (contravariant because
# SendChannel[Base] can be passed to someone expecting
# SendChannel[Derived])
SendType = TypeVar("SendType", contravariant=True)

# Temporary TypeVar needed until mypy release supports Self as a type
SelfT = TypeVar("SelfT")

@generic_function
def open_memory_channel(max_buffer_size):

def _open_memory_channel(
max_buffer_size: int,
) -> tuple[MemorySendChannel[T], MemoryReceiveChannel[T]]:
"""Open a channel for passing objects between tasks within a process.
Memory channels are lightweight, cheap to allocate, and entirely
@@ -68,36 +98,57 @@ def open_memory_channel(max_buffer_size):
raise TypeError("max_buffer_size must be an integer or math.inf")
if max_buffer_size < 0:
raise ValueError("max_buffer_size must be >= 0")
state = MemoryChannelState(max_buffer_size)
state: MemoryChannelState[T] = MemoryChannelState(max_buffer_size)
return (
MemorySendChannel._create(state),
MemoryReceiveChannel._create(state),
MemorySendChannel[T]._create(state),
MemoryReceiveChannel[T]._create(state),
)


# This workaround requires python3.9+, once older python versions are not supported
# or there's a better way of achieving type-checking on a generic factory function,
# it could replace the normal function header
if TYPE_CHECKING:
# written as a class so you can say open_memory_channel[int](5)
# Need to use Tuple instead of tuple due to CI check running on 3.8
class open_memory_channel(Tuple[MemorySendChannel[T], MemoryReceiveChannel[T]]):
def __new__( # type: ignore[misc] # "must return a subtype"
cls, max_buffer_size: int
) -> tuple[MemorySendChannel[T], MemoryReceiveChannel[T]]:
return _open_memory_channel(max_buffer_size)

def __init__(self, max_buffer_size: int):
...

else:
# apply the generic_function decorator to make open_memory_channel indexable
# so it's valid to say e.g. ``open_memory_channel[bytes](5)`` at runtime
open_memory_channel = generic_function(_open_memory_channel)


@attr.s(frozen=True, slots=True)
class MemoryChannelStats:
current_buffer_used = attr.ib()
max_buffer_size = attr.ib()
open_send_channels = attr.ib()
open_receive_channels = attr.ib()
tasks_waiting_send = attr.ib()
tasks_waiting_receive = attr.ib()
current_buffer_used: int = attr.ib()
max_buffer_size: int = attr.ib()
open_send_channels: int = attr.ib()
open_receive_channels: int = attr.ib()
tasks_waiting_send: int = attr.ib()
tasks_waiting_receive: int = attr.ib()


@attr.s(slots=True)
class MemoryChannelState:
max_buffer_size = attr.ib()
data = attr.ib(factory=deque)
class MemoryChannelState(Generic[T]):
max_buffer_size: int = attr.ib()
data: deque[T] = attr.ib(factory=deque)
# Counts of open endpoints using this state
open_send_channels = attr.ib(default=0)
open_receive_channels = attr.ib(default=0)
open_send_channels: int = attr.ib(default=0)
open_receive_channels: int = attr.ib(default=0)
# {task: value}
send_tasks = attr.ib(factory=OrderedDict)
send_tasks: OrderedDict[Task, T] = attr.ib(factory=OrderedDict)
# {task: None}
receive_tasks = attr.ib(factory=OrderedDict)
receive_tasks: OrderedDict[Task, None] = attr.ib(factory=OrderedDict)

def statistics(self):
def statistics(self) -> MemoryChannelStats:
return MemoryChannelStats(
current_buffer_used=len(self.data),
max_buffer_size=self.max_buffer_size,
@@ -109,28 +160,28 @@ def statistics(self):


@attr.s(eq=False, repr=False)
class MemorySendChannel(SendChannel, metaclass=NoPublicConstructor):
_state = attr.ib()
_closed = attr.ib(default=False)
class MemorySendChannel(SendChannel[SendType], metaclass=NoPublicConstructor):
_state: MemoryChannelState[SendType] = attr.ib()
_closed: bool = attr.ib(default=False)
# This is just the tasks waiting on *this* object. As compared to
# self._state.send_tasks, which includes tasks from this object and
# all clones.
_tasks = attr.ib(factory=set)
_tasks: set[Task] = attr.ib(factory=set)

def __attrs_post_init__(self):
def __attrs_post_init__(self) -> None:
self._state.open_send_channels += 1

def __repr__(self):
def __repr__(self) -> str:
return "<send channel at {:#x}, using buffer at {:#x}>".format(
id(self), id(self._state)
)

def statistics(self):
def statistics(self) -> MemoryChannelStats:
# XX should we also report statistics specific to this object?
return self._state.statistics()

@enable_ki_protection
def send_nowait(self, value):
def send_nowait(self, value: SendType) -> None:
"""Like `~trio.abc.SendChannel.send`, but if the channel's buffer is
full, raises `WouldBlock` instead of blocking.
@@ -150,7 +201,7 @@ def send_nowait(self, value):
raise trio.WouldBlock

@enable_ki_protection
async def send(self, value):
async def send(self, value: SendType) -> None:
"""See `SendChannel.send <trio.abc.SendChannel.send>`.
Memory channels allow multiple tasks to call `send` at the same time.
@@ -170,15 +221,16 @@ async def send(self, value):
self._state.send_tasks[task] = value
task.custom_sleep_data = self

def abort_fn(_):
def abort_fn(_: RaiseCancelT) -> Abort:
self._tasks.remove(task)
del self._state.send_tasks[task]
return trio.lowlevel.Abort.SUCCEEDED

await trio.lowlevel.wait_task_rescheduled(abort_fn)

# Return type must be stringified or use a TypeVar
@enable_ki_protection
def clone(self):
def clone(self) -> "MemorySendChannel[SendType]":
"""Clone this send channel object.
This returns a new `MemorySendChannel` object, which acts as a
@@ -206,14 +258,19 @@ def clone(self):
raise trio.ClosedResourceError
return MemorySendChannel._create(self._state)

def __enter__(self):
def __enter__(self: SelfT) -> SelfT:
return self

def __exit__(self, exc_type, exc_val, exc_tb):
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
self.close()

@enable_ki_protection
def close(self):
def close(self) -> None:
"""Close this send channel object synchronously.
All channel objects have an asynchronous `~.AsyncResource.aclose` method.
@@ -241,30 +298,30 @@ def close(self):
self._state.receive_tasks.clear()

@enable_ki_protection
async def aclose(self):
async def aclose(self) -> None:
self.close()
await trio.lowlevel.checkpoint()


@attr.s(eq=False, repr=False)
class MemoryReceiveChannel(ReceiveChannel, metaclass=NoPublicConstructor):
_state = attr.ib()
_closed = attr.ib(default=False)
_tasks = attr.ib(factory=set)
class MemoryReceiveChannel(ReceiveChannel[ReceiveType], metaclass=NoPublicConstructor):
_state: MemoryChannelState[ReceiveType] = attr.ib()
_closed: bool = attr.ib(default=False)
_tasks: set[trio._core._run.Task] = attr.ib(factory=set)

def __attrs_post_init__(self):
def __attrs_post_init__(self) -> None:
self._state.open_receive_channels += 1

def statistics(self):
def statistics(self) -> MemoryChannelStats:
return self._state.statistics()

def __repr__(self):
def __repr__(self) -> str:
return "<receive channel at {:#x}, using buffer at {:#x}>".format(
id(self), id(self._state)
)

@enable_ki_protection
def receive_nowait(self):
def receive_nowait(self) -> ReceiveType:
"""Like `~trio.abc.ReceiveChannel.receive`, but if there's nothing
ready to receive, raises `WouldBlock` instead of blocking.
@@ -284,7 +341,7 @@ def receive_nowait(self):
raise trio.WouldBlock

@enable_ki_protection
async def receive(self):
async def receive(self) -> ReceiveType:
"""See `ReceiveChannel.receive <trio.abc.ReceiveChannel.receive>`.
Memory channels allow multiple tasks to call `receive` at the same
@@ -306,15 +363,17 @@ async def receive(self):
self._state.receive_tasks[task] = None
task.custom_sleep_data = self

def abort_fn(_):
def abort_fn(_: RaiseCancelT) -> Abort:
self._tasks.remove(task)
del self._state.receive_tasks[task]
return trio.lowlevel.Abort.SUCCEEDED

return await trio.lowlevel.wait_task_rescheduled(abort_fn)
# Not strictly guaranteed to return ReceiveType, but will do so unless
# you intentionally reschedule with a bad value.
return await trio.lowlevel.wait_task_rescheduled(abort_fn) # type: ignore[no-any-return]

@enable_ki_protection
def clone(self):
def clone(self) -> "MemoryReceiveChannel[ReceiveType]":
"""Clone this receive channel object.
This returns a new `MemoryReceiveChannel` object, which acts as a
@@ -345,14 +404,19 @@ def clone(self):
raise trio.ClosedResourceError
return MemoryReceiveChannel._create(self._state)

def __enter__(self):
def __enter__(self: SelfT) -> SelfT:
return self

def __exit__(self, exc_type, exc_val, exc_tb):
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
self.close()

@enable_ki_protection
def close(self):
def close(self) -> None:
"""Close this receive channel object synchronously.
All channel objects have an asynchronous `~.AsyncResource.aclose` method.
@@ -381,6 +445,6 @@ def close(self):
self._state.data.clear()

@enable_ki_protection
async def aclose(self):
async def aclose(self) -> None:
self.close()
await trio.lowlevel.checkpoint()
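
The variance annotations above are what make the endpoints interoperate the way the TypeVar comments describe. A small sketch of what mypy now accepts, under the stated variance rules (`Base` and `Derived` are hypothetical names):

import trio
from trio.abc import ReceiveChannel, SendChannel

class Base: ...
class Derived(Base): ...

async def variance_demo() -> None:
    send, receive = trio.open_memory_channel[Derived](0)

    # ReceiveType is covariant: a producer of Derived values can stand in
    # wherever a producer of Base values is expected.
    base_receiver: ReceiveChannel[Base] = receive

    base_send, _ = trio.open_memory_channel[Base](0)
    # SendType is contravariant: a consumer accepting any Base can stand in
    # wherever a consumer of Derived values is expected.
    derived_sender: SendChannel[Derived] = base_send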
1 change: 1 addition & 0 deletions trio/_core/__init__.py
@@ -55,6 +55,7 @@
from ._traps import (
cancel_shielded_checkpoint,
Abort,
RaiseCancelT,
wait_task_rescheduled,
temporarily_detach_coroutine_object,
permanently_detach_coroutine_object,
7 changes: 6 additions & 1 deletion trio/_core/_traps.py
@@ -8,6 +8,7 @@

from . import _run

from typing import Callable, NoReturn, Any

# Helper for the bottommost 'yield'. You can't use 'yield' inside an async
# function, but you can inside a generator, and if you decorate your generator
@@ -64,7 +65,11 @@ class WaitTaskRescheduled:
abort_func = attr.ib()


async def wait_task_rescheduled(abort_func):
RaiseCancelT = Callable[[], NoReturn] # TypeAlias

# Should always return the type a Task "expects", unless you willfully reschedule it
# with a bad value.
async def wait_task_rescheduled(abort_func: Callable[[RaiseCancelT], Abort]) -> Any:
"""Put the current task to sleep, with cancellation support.
This is the lowest-level API for blocking in Trio. Every time a
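
With `RaiseCancelT` exported, an `abort_func` can be fully typed: it receives a zero-argument callable that raises `Cancelled`, and must return an `Abort`, exactly as the `abort_fn` closures in `_channel.py` now declare. A minimal sketch of the same pattern outside Trio's internals (the `waiters` registry and `wake_all` helper are hypothetical):

import outcome
import trio

waiters: dict[trio.lowlevel.Task, None] = {}

async def wait_for_wake() -> str:
    task = trio.lowlevel.current_task()
    waiters[task] = None

    def abort_fn(raise_cancel: trio.lowlevel.RaiseCancelT) -> trio.lowlevel.Abort:
        # Cancelled before being woken: deregister and let Trio deliver Cancelled.
        del waiters[task]
        return trio.lowlevel.Abort.SUCCEEDED

    # Whatever wake_all() passes to reschedule() is returned here; as the
    # comment in _traps.py notes, the return type is trusted, not enforced.
    return await trio.lowlevel.wait_task_rescheduled(abort_fn)

def wake_all(value: str) -> None:
    for task in list(waiters):
        del waiters[task]
        trio.lowlevel.reschedule(task, outcome.Value(value))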
1 change: 1 addition & 0 deletions trio/lowlevel.py
@@ -16,6 +16,7 @@
from ._core import (
cancel_shielded_checkpoint,
Abort,
RaiseCancelT,
wait_task_rescheduled,
enable_ki_protection,
disable_ki_protection,
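
The `TYPE_CHECKING` split used for `open_memory_channel` is a reusable pattern for any factory function that should be subscriptable like a generic class. A distilled sketch of the pattern (with a hypothetical `make_pair` factory; the runtime branch is a stand-in for Trio's private `generic_function` helper):

from __future__ import annotations

from typing import Tuple, TypeVar, TYPE_CHECKING

T = TypeVar("T")

def _make_pair(value: T) -> tuple[T, T]:
    return (value, value)

if TYPE_CHECKING:
    # Written as a class so type checkers accept make_pair[int](5):
    # __new__ forwards to the real factory and declares the result type.
    class make_pair(Tuple[T, T]):
        def __new__(  # type: ignore[misc]  # "must return a subtype"
            cls, value: T
        ) -> tuple[T, T]:
            return _make_pair(value)

        def __init__(self, value: T):
            ...

else:
    class make_pair:
        # Subscripting at runtime just returns the plain factory, so
        # make_pair[int](5) and make_pair(5) both end up calling _make_pair.
        def __class_getitem__(cls, item):
            return _make_pair

        def __new__(cls, value):
            return _make_pair(value)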
