Commit

Add generic typing support for MemorySendChannel and MemoryReceiveChannel
jakkdl committed Jan 28, 2023
1 parent f2a71a0 commit f7df07b
Showing 1 changed file with 90 additions and 49 deletions.
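As a hedged illustration of what the change enables (not part of the diff; the producer/consumer helpers below are made up for this sketch), both endpoints can now be parametrized with the element type:

import trio


async def producer(send_channel: trio.abc.SendChannel[int]) -> None:
    async with send_channel:
        for i in range(3):
            await send_channel.send(i)  # only ints should type-check here


async def consumer(receive_channel: trio.abc.ReceiveChannel[int]) -> None:
    async with receive_channel:
        async for value in receive_channel:
            print(value + 1)  # value is inferred as int


async def main() -> None:
    # open_memory_channel is wrapped in generic_function, so the element
    # type can be supplied by subscripting at the call site.
    send_channel, receive_channel = trio.open_memory_channel[int](0)
    async with trio.open_nursery() as nursery:
        nursery.start_soon(producer, send_channel)
        nursery.start_soon(consumer, receive_channel)


trio.run(main)

With the channels parametrized like this, a type checker should flag something like `await send_channel.send("oops")` on an `int` channel.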
139 changes: 90 additions & 49 deletions trio/_channel.py
@@ -1,18 +1,46 @@
from collections import deque, OrderedDict
from math import inf

from types import TracebackType
from typing import (
Generic,
TypeVar,
Any,
Tuple,
Optional,
Type,
Deque,
OrderedDict as T_OrderedDict,
Set,
)

import attr
from outcome import Error, Value

from .abc import SendChannel, ReceiveChannel, Channel
from ._util import generic_function, NoPublicConstructor

import trio
from ._core import enable_ki_protection
from ._core import enable_ki_protection, Task, Abort

# A regular invariant generic type
T = TypeVar("T")

# The type of object produced by a ReceiveChannel (covariant because
# ReceiveChannel[Derived] can be passed to someone expecting
# ReceiveChannel[Base])
ReceiveType = TypeVar("ReceiveType", covariant=True)

# The type of object accepted by a SendChannel (contravariant because
# SendChannel[Base] can be passed to someone expecting
# SendChannel[Derived])
SendType = TypeVar("SendType", contravariant=True)


@generic_function
def open_memory_channel(max_buffer_size):
def open_memory_channel(
max_buffer_size: int,
) -> Tuple[SendChannel[T], ReceiveChannel[T]]:
"""Open a channel for passing objects between tasks within a process.
Memory channels are lightweight, cheap to allocate, and entirely
@@ -68,7 +96,7 @@ def open_memory_channel(max_buffer_size):
raise TypeError("max_buffer_size must be an integer or math.inf")
if max_buffer_size < 0:
raise ValueError("max_buffer_size must be >= 0")
state = MemoryChannelState(max_buffer_size)
state: MemoryChannelState[T] = MemoryChannelState(max_buffer_size)
return (
MemorySendChannel._create(state),
MemoryReceiveChannel._create(state),
@@ -77,27 +105,27 @@ def open_memory_channel(max_buffer_size):

@attr.s(frozen=True, slots=True)
class MemoryChannelStats:
current_buffer_used = attr.ib()
max_buffer_size = attr.ib()
open_send_channels = attr.ib()
open_receive_channels = attr.ib()
tasks_waiting_send = attr.ib()
tasks_waiting_receive = attr.ib()
current_buffer_used: int = attr.ib()
max_buffer_size: int = attr.ib()
open_send_channels: int = attr.ib()
open_receive_channels: int = attr.ib()
tasks_waiting_send: int = attr.ib()
tasks_waiting_receive: int = attr.ib()


@attr.s(slots=True)
class MemoryChannelState:
max_buffer_size = attr.ib()
data = attr.ib(factory=deque)
class MemoryChannelState(Generic[T]):
max_buffer_size: int = attr.ib()
data: Deque[T] = attr.ib(factory=deque)
# Counts of open endpoints using this state
open_send_channels = attr.ib(default=0)
open_receive_channels = attr.ib(default=0)
open_send_channels: int = attr.ib(default=0)
open_receive_channels: int = attr.ib(default=0)
# {task: value}
send_tasks = attr.ib(factory=OrderedDict)
send_tasks: T_OrderedDict[Task, T] = attr.ib(factory=OrderedDict)
# {task: None}
receive_tasks = attr.ib(factory=OrderedDict)
receive_tasks: T_OrderedDict[Task, None] = attr.ib(factory=OrderedDict)

def statistics(self):
def statistics(self) -> MemoryChannelStats:
return MemoryChannelStats(
current_buffer_used=len(self.data),
max_buffer_size=self.max_buffer_size,
@@ -109,28 +137,28 @@ def statistics(self):


@attr.s(eq=False, repr=False)
class MemorySendChannel(SendChannel, metaclass=NoPublicConstructor):
_state = attr.ib()
_closed = attr.ib(default=False)
class MemorySendChannel(SendChannel, Generic[SendType], metaclass=NoPublicConstructor):
_state: MemoryChannelState[SendType] = attr.ib()
_closed: bool = attr.ib(default=False)
# This is just the tasks waiting on *this* object. As compared to
# self._state.send_tasks, which includes tasks from this object and
# all clones.
_tasks = attr.ib(factory=set)
_tasks: Set[Task] = attr.ib(factory=set)

def __attrs_post_init__(self):
def __attrs_post_init__(self) -> None:
self._state.open_send_channels += 1

def __repr__(self):
def __repr__(self) -> str:
return "<send channel at {:#x}, using buffer at {:#x}>".format(
id(self), id(self._state)
)

def statistics(self):
def statistics(self) -> MemoryChannelStats:
# XX should we also report statistics specific to this object?
return self._state.statistics()

@enable_ki_protection
def send_nowait(self, value):
def send_nowait(self, value: SendType) -> None:
"""Like `~trio.abc.SendChannel.send`, but if the channel's buffer is
full, raises `WouldBlock` instead of blocking.
@@ -150,7 +178,7 @@ def send_nowait(self, value):
raise trio.WouldBlock

@enable_ki_protection
async def send(self, value):
async def send(self, value: SendType) -> None:
"""See `SendChannel.send <trio.abc.SendChannel.send>`.
Memory channels allow multiple tasks to call `send` at the same time.
@@ -170,15 +198,16 @@ async def send(self, value):
self._state.send_tasks[task] = value
task.custom_sleep_data = self

def abort_fn(_):
def abort_fn(_) -> Abort:
self._tasks.remove(task)
del self._state.send_tasks[task]
return trio.lowlevel.Abort.SUCCEEDED

await trio.lowlevel.wait_task_rescheduled(abort_fn)

# Return type must be stringified, use a TypeVar, or (py311+) use typing.Self
@enable_ki_protection
def clone(self):
def clone(self) -> "MemorySendChannel[SendType]":
"""Clone this send channel object.
This returns a new `MemorySendChannel` object, which acts as a
@@ -206,14 +235,19 @@ def clone(self):
raise trio.ClosedResourceError
return MemorySendChannel._create(self._state)

def __enter__(self):
def __enter__(self) -> "MemorySendChannel[SendType]":
return self

def __exit__(self, exc_type, exc_val, exc_tb):
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
self.close()

@enable_ki_protection
def close(self):
def close(self) -> None:
"""Close this send channel object synchronously.
All channel objects have an asynchronous `~.AsyncResource.aclose` method.
@@ -241,30 +275,32 @@ def close(self):
self._state.receive_tasks.clear()

@enable_ki_protection
async def aclose(self):
async def aclose(self) -> None:
self.close()
await trio.lowlevel.checkpoint()


@attr.s(eq=False, repr=False)
class MemoryReceiveChannel(ReceiveChannel, metaclass=NoPublicConstructor):
_state = attr.ib()
_closed = attr.ib(default=False)
_tasks = attr.ib(factory=set)

def __attrs_post_init__(self):
class MemoryReceiveChannel(
ReceiveChannel, Generic[ReceiveType], metaclass=NoPublicConstructor
):
_state: MemoryChannelState[ReceiveType] = attr.ib()
_closed: bool = attr.ib(default=False)
_tasks: Set[trio._core._run.Task] = attr.ib(factory=set)

def __attrs_post_init__(self) -> None:
self._state.open_receive_channels += 1

def statistics(self):
def statistics(self) -> MemoryChannelStats:
return self._state.statistics()

def __repr__(self):
def __repr__(self) -> str:
return "<receive channel at {:#x}, using buffer at {:#x}>".format(
id(self), id(self._state)
)

@enable_ki_protection
def receive_nowait(self):
def receive_nowait(self) -> ReceiveType:
"""Like `~trio.abc.ReceiveChannel.receive`, but if there's nothing
ready to receive, raises `WouldBlock` instead of blocking.
@@ -284,7 +320,7 @@ def receive_nowait(self):
raise trio.WouldBlock

@enable_ki_protection
async def receive(self):
async def receive(self) -> ReceiveType:
"""See `ReceiveChannel.receive <trio.abc.ReceiveChannel.receive>`.
Memory channels allow multiple tasks to call `receive` at the same
@@ -306,15 +342,15 @@ async def receive(self):
self._state.receive_tasks[task] = None
task.custom_sleep_data = self

def abort_fn(_):
def abort_fn(_) -> Abort:
self._tasks.remove(task)
del self._state.receive_tasks[task]
return trio.lowlevel.Abort.SUCCEEDED

return await trio.lowlevel.wait_task_rescheduled(abort_fn)
return await trio.lowlevel.wait_task_rescheduled(abort_fn) # type: ignore

@enable_ki_protection
def clone(self):
def clone(self) -> "MemoryReceiveChannel[ReceiveType]":
"""Clone this receive channel object.
This returns a new `MemoryReceiveChannel` object, which acts as a
@@ -345,14 +381,19 @@ def clone(self):
raise trio.ClosedResourceError
return MemoryReceiveChannel._create(self._state)

def __enter__(self):
def __enter__(self) -> "MemoryReceiveChannel[ReceiveType]":
return self

def __exit__(self, exc_type, exc_val, exc_tb):
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
self.close()

@enable_ki_protection
def close(self):
def close(self) -> None:
"""Close this receive channel object synchronously.
All channel objects have an asynchronous `~.AsyncResource.aclose` method.
@@ -381,6 +422,6 @@ def close(self):
self._state.data.clear()

@enable_ki_protection
async def aclose(self):
async def aclose(self) -> None:
self.close()
await trio.lowlevel.checkpoint()
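
The comment added above `clone` lists three ways to annotate its return type; the commit stringifies it. As a hedged sketch of the `typing.Self` alternative (assuming Python 3.11+, with a hypothetical class name, and not what the commit uses):

from typing import Self


class ExampleChannel:
    def clone(self) -> Self:
        # Self resolves to the concrete runtime class, so a subclass's
        # clone() is typed as returning that subclass without stringifying.
        return type(self)()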
