Skip to content

Commit

Permalink
adress comments, removing redundant Generic, fixing open_memory_chann…
Browse files Browse the repository at this point in the history
…el return type, and adding trio-typings class workaround to get type-checking on open_memory_channel by hiding it behind a TYPE_CHECKING guard
  • Loading branch information
jakkdl committed Feb 1, 2023
1 parent c653319 commit d5fd09f
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 11 deletions.
36 changes: 26 additions & 10 deletions trio/_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
Generic,
NoReturn,
TypeVar,
TYPE_CHECKING,
)

import attr
Expand Down Expand Up @@ -38,8 +39,7 @@
SelfT = TypeVar("SelfT")


@generic_function
def open_memory_channel(
def _open_memory_channel(
max_buffer_size: int,
) -> tuple[MemorySendChannel[T], MemoryReceiveChannel[T]]:
"""Open a channel for passing objects between tasks within a process.
Expand Down Expand Up @@ -99,11 +99,31 @@ def open_memory_channel(
raise ValueError("max_buffer_size must be >= 0")
state: MemoryChannelState[T] = MemoryChannelState(max_buffer_size)
return (
MemorySendChannel._create(state),
MemoryReceiveChannel._create(state),
MemorySendChannel[T]._create(state),
MemoryReceiveChannel[T]._create(state),
)


# This workaround requires python3.9+, once older python versions are not supported
# or there's a better way of achieving type-checking on a generic factory function,
# it could replace the normal function header
if TYPE_CHECKING:
# written as a class so you can say open_memory_channel[int](5)
class open_memory_channel(tuple[MemorySendChannel[T], MemoryReceiveChannel[T]]):
def __new__( # type: ignore[misc] # "must return a subtype"
cls, max_buffer_size: int
) -> tuple[MemorySendChannel[T], MemoryReceiveChannel[T]]:
return _open_memory_channel(max_buffer_size)

def __init__(self, max_buffer_size: int):
...

else:
# apply the generic_function decorator to make open_memory_channel indexable
# so it's valid to say e.g. ``open_memory_channel[bytes](5)`` at runtime
open_memory_channel = generic_function(_open_memory_channel)


@attr.s(frozen=True, slots=True)
class MemoryChannelStats:
current_buffer_used: int = attr.ib()
Expand Down Expand Up @@ -138,9 +158,7 @@ def statistics(self) -> MemoryChannelStats:


@attr.s(eq=False, repr=False)
class MemorySendChannel(
SendChannel[SendType], Generic[SendType], metaclass=NoPublicConstructor
):
class MemorySendChannel(SendChannel[SendType], metaclass=NoPublicConstructor):
_state: MemoryChannelState[SendType] = attr.ib()
_closed: bool = attr.ib(default=False)
# This is just the tasks waiting on *this* object. As compared to
Expand Down Expand Up @@ -284,9 +302,7 @@ async def aclose(self) -> None:


@attr.s(eq=False, repr=False)
class MemoryReceiveChannel(
ReceiveChannel[ReceiveType], Generic[ReceiveType], metaclass=NoPublicConstructor
):
class MemoryReceiveChannel(ReceiveChannel[ReceiveType], metaclass=NoPublicConstructor):
_state: MemoryChannelState[ReceiveType] = attr.ib()
_closed: bool = attr.ib(default=False)
_tasks: set[trio._core._run.Task] = attr.ib(factory=set)
Expand Down
3 changes: 2 additions & 1 deletion trio/_core/_traps.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ class Abort(enum.Enum):
class WaitTaskRescheduled:
abort_func = attr.ib()

RaiseCancelT = Callable[[], NoReturn] # TypeAlias

RaiseCancelT = Callable[[], NoReturn] # TypeAlias

# Can this function be retyped to return something other than Any?
async def wait_task_rescheduled(abort_func: Callable[[RaiseCancelT], Abort]) -> Any:
Expand Down

0 comments on commit d5fd09f

Please sign in to comment.