Skip to content

Commit

Permalink
Add generic typing support for MemorySendChannel and MemoryReceiveCha…
Browse files Browse the repository at this point in the history
…nnel
  • Loading branch information
jakkdl committed Jan 28, 2023
1 parent f2a71a0 commit 9674e6b
Showing 1 changed file with 80 additions and 49 deletions.
129 changes: 80 additions & 49 deletions trio/_channel.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,36 @@
from collections import deque, OrderedDict
from math import inf

from types import TracebackType
from typing import Generic, TypeVar, Any, Tuple, Optional, Type

import attr
from outcome import Error, Value

from .abc import SendChannel, ReceiveChannel, Channel
from ._util import generic_function, NoPublicConstructor

import trio
from ._core import enable_ki_protection
from ._core import enable_ki_protection, Task, Abort

# A regular invariant generic type
T = TypeVar("T")

# The type of object produced by a ReceiveChannel (covariant because
# ReceiveChannel[Derived] can be passed to someone expecting
# ReceiveChannel[Base])
ReceiveType = TypeVar("ReceiveType", covariant=True)

# The type of object accepted by a SendChannel (contravariant because
# SendChannel[Base] can be passed to someone expecting
# SendChannel[Derived])
SendType = TypeVar("SendType", contravariant=True)


@generic_function
def open_memory_channel(max_buffer_size):
def open_memory_channel(
max_buffer_size: int,
) -> Tuple[SendChannel[T], ReceiveChannel[T]]:
"""Open a channel for passing objects between tasks within a process.
Memory channels are lightweight, cheap to allocate, and entirely
Expand Down Expand Up @@ -68,7 +86,7 @@ def open_memory_channel(max_buffer_size):
raise TypeError("max_buffer_size must be an integer or math.inf")
if max_buffer_size < 0:
raise ValueError("max_buffer_size must be >= 0")
state = MemoryChannelState(max_buffer_size)
state: MemoryChannelState[T] = MemoryChannelState(max_buffer_size)
return (
MemorySendChannel._create(state),
MemoryReceiveChannel._create(state),
Expand All @@ -77,27 +95,27 @@ def open_memory_channel(max_buffer_size):

@attr.s(frozen=True, slots=True)
class MemoryChannelStats:
current_buffer_used = attr.ib()
max_buffer_size = attr.ib()
open_send_channels = attr.ib()
open_receive_channels = attr.ib()
tasks_waiting_send = attr.ib()
tasks_waiting_receive = attr.ib()
current_buffer_used: int = attr.ib()
max_buffer_size: int = attr.ib()
open_send_channels: int = attr.ib()
open_receive_channels: int = attr.ib()
tasks_waiting_send: int = attr.ib()
tasks_waiting_receive: int = attr.ib()


@attr.s(slots=True)
class MemoryChannelState:
max_buffer_size = attr.ib()
data = attr.ib(factory=deque)
class MemoryChannelState(Generic[T]):
max_buffer_size: int = attr.ib()
data: deque[T] = attr.ib(factory=deque)
# Counts of open endpoints using this state
open_send_channels = attr.ib(default=0)
open_receive_channels = attr.ib(default=0)
open_send_channels: int = attr.ib(default=0)
open_receive_channels: int = attr.ib(default=0)
# {task: value}
send_tasks = attr.ib(factory=OrderedDict)
send_tasks: OrderedDict[Task, T] = attr.ib(factory=OrderedDict)
# {task: None}
receive_tasks = attr.ib(factory=OrderedDict)
receive_tasks: OrderedDict[Task, None] = attr.ib(factory=OrderedDict)

def statistics(self):
def statistics(self) -> MemoryChannelStats:
return MemoryChannelStats(
current_buffer_used=len(self.data),
max_buffer_size=self.max_buffer_size,
Expand All @@ -109,28 +127,28 @@ def statistics(self):


@attr.s(eq=False, repr=False)
class MemorySendChannel(SendChannel, metaclass=NoPublicConstructor):
_state = attr.ib()
_closed = attr.ib(default=False)
class MemorySendChannel(SendChannel, Generic[SendType], metaclass=NoPublicConstructor):
_state: MemoryChannelState[SendType] = attr.ib()
_closed: bool = attr.ib(default=False)
# This is just the tasks waiting on *this* object. As compared to
# self._state.send_tasks, which includes tasks from this object and
# all clones.
_tasks = attr.ib(factory=set)
_tasks: set[Task] = attr.ib(factory=set)

def __attrs_post_init__(self):
def __attrs_post_init__(self) -> None:
self._state.open_send_channels += 1

def __repr__(self):
def __repr__(self) -> str:
return "<send channel at {:#x}, using buffer at {:#x}>".format(
id(self), id(self._state)
)

def statistics(self):
def statistics(self) -> MemoryChannelStats:
# XX should we also report statistics specific to this object?
return self._state.statistics()

@enable_ki_protection
def send_nowait(self, value):
def send_nowait(self, value: SendType) -> None:
"""Like `~trio.abc.SendChannel.send`, but if the channel's buffer is
full, raises `WouldBlock` instead of blocking.
Expand All @@ -150,7 +168,7 @@ def send_nowait(self, value):
raise trio.WouldBlock

@enable_ki_protection
async def send(self, value):
async def send(self, value: SendType) -> None:
"""See `SendChannel.send <trio.abc.SendChannel.send>`.
Memory channels allow multiple tasks to call `send` at the same time.
Expand All @@ -170,15 +188,16 @@ async def send(self, value):
self._state.send_tasks[task] = value
task.custom_sleep_data = self

def abort_fn(_):
def abort_fn(_) -> Abort:
self._tasks.remove(task)
del self._state.send_tasks[task]
return trio.lowlevel.Abort.SUCCEEDED

await trio.lowlevel.wait_task_rescheduled(abort_fn)

# Return type must be stringified, use a TypeVar, or (py311+) use typing.Self
@enable_ki_protection
def clone(self):
def clone(self) -> "MemorySendChannel[SendType]":
"""Clone this send channel object.
This returns a new `MemorySendChannel` object, which acts as a
Expand Down Expand Up @@ -206,14 +225,19 @@ def clone(self):
raise trio.ClosedResourceError
return MemorySendChannel._create(self._state)

def __enter__(self):
def __enter__(self) -> "MemorySendChannel[SendType]":
return self

def __exit__(self, exc_type, exc_val, exc_tb):
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
self.close()

@enable_ki_protection
def close(self):
def close(self) -> None:
"""Close this send channel object synchronously.
All channel objects have an asynchronous `~.AsyncResource.aclose` method.
Expand Down Expand Up @@ -241,30 +265,32 @@ def close(self):
self._state.receive_tasks.clear()

@enable_ki_protection
async def aclose(self):
async def aclose(self) -> None:
self.close()
await trio.lowlevel.checkpoint()


@attr.s(eq=False, repr=False)
class MemoryReceiveChannel(ReceiveChannel, metaclass=NoPublicConstructor):
_state = attr.ib()
_closed = attr.ib(default=False)
_tasks = attr.ib(factory=set)

def __attrs_post_init__(self):
class MemoryReceiveChannel(
ReceiveChannel, Generic[ReceiveType], metaclass=NoPublicConstructor
):
_state: MemoryChannelState[ReceiveType] = attr.ib()
_closed: bool = attr.ib(default=False)
_tasks: set[trio._core._run.Task] = attr.ib(factory=set)

def __attrs_post_init__(self) -> None:
self._state.open_receive_channels += 1

def statistics(self):
def statistics(self) -> MemoryChannelStats:
return self._state.statistics()

def __repr__(self):
def __repr__(self) -> str:
return "<receive channel at {:#x}, using buffer at {:#x}>".format(
id(self), id(self._state)
)

@enable_ki_protection
def receive_nowait(self):
def receive_nowait(self) -> ReceiveType:
"""Like `~trio.abc.ReceiveChannel.receive`, but if there's nothing
ready to receive, raises `WouldBlock` instead of blocking.
Expand All @@ -284,7 +310,7 @@ def receive_nowait(self):
raise trio.WouldBlock

@enable_ki_protection
async def receive(self):
async def receive(self) -> ReceiveType:
"""See `ReceiveChannel.receive <trio.abc.ReceiveChannel.receive>`.
Memory channels allow multiple tasks to call `receive` at the same
Expand All @@ -306,15 +332,15 @@ async def receive(self):
self._state.receive_tasks[task] = None
task.custom_sleep_data = self

def abort_fn(_):
def abort_fn(_) -> Abort:
self._tasks.remove(task)
del self._state.receive_tasks[task]
return trio.lowlevel.Abort.SUCCEEDED

return await trio.lowlevel.wait_task_rescheduled(abort_fn)
return await trio.lowlevel.wait_task_rescheduled(abort_fn) # type: ignore

@enable_ki_protection
def clone(self):
def clone(self) -> "MemoryReceiveChannel[ReceiveType]":
"""Clone this receive channel object.
This returns a new `MemoryReceiveChannel` object, which acts as a
Expand Down Expand Up @@ -345,14 +371,19 @@ def clone(self):
raise trio.ClosedResourceError
return MemoryReceiveChannel._create(self._state)

def __enter__(self):
def __enter__(self) -> "MemoryReceiveChannel[ReceiveType]":
return self

def __exit__(self, exc_type, exc_val, exc_tb):
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
self.close()

@enable_ki_protection
def close(self):
def close(self) -> None:
"""Close this receive channel object synchronously.
All channel objects have an asynchronous `~.AsyncResource.aclose` method.
Expand Down Expand Up @@ -381,6 +412,6 @@ def close(self):
self._state.data.clear()

@enable_ki_protection
async def aclose(self):
async def aclose(self) -> None:
self.close()
await trio.lowlevel.checkpoint()

0 comments on commit 9674e6b

Please sign in to comment.