Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
Signed-off-by: kaihsun <kaihsun@anyscale.com>
  • Loading branch information
kevin85421 committed May 31, 2024
1 parent 7cd549b commit 97292b1
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 6 deletions.
7 changes: 7 additions & 0 deletions python/ray/experimental/channel/local_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,13 @@ def __init__(
self,
actor_handle: ray.actor.ActorHandle,
):
"""
LocalChannel is a channel for communication between two tasks in the same
worker process. It writes data directly to the worker's serialization context
and reads data from the serialization context to avoid the serialization
overhead and the need for reading/writing from shared memory.
"""

# TODO (kevin85421): Currently, if we don't pass `actor_handle` to
# `LocalChannel`, the actor will die due to the reference count of
# `actor_handle` is 0. We should fix this issue in the future.
Expand Down
26 changes: 20 additions & 6 deletions python/ray/experimental/channel/shared_memory_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,13 +389,19 @@ def __init__(
self,
writer: Optional[ray.actor.ActorHandle],
readers: List[Optional[ray.actor.ActorHandle]],
local_channel: Optional[LocalChannel] = None,
remote_channel: Optional[Channel] = None,
_local_channel: Optional[LocalChannel] = None,
_remote_channel: Optional[Channel] = None,
):
"""
Can be used to send data to different readers via different channels.
For example, if the reader is in the same worker process as the writer,
the data can be sent via LocalChannel. If the reader is in a different
worker process, the data can be sent via shared memory channel.
"""
self._writer = writer
self._readers = readers
self._local_channel = local_channel
self._remote_channel = remote_channel
self._local_channel = _local_channel
self._remote_channel = _remote_channel

remote_readers = []
for reader in self._readers:
Expand All @@ -412,10 +418,14 @@ def __init__(
assert hasattr(self, "_local_channel") or hasattr(self, "_remote_channel")

def ensure_registered_as_writer(self) -> None:
if self._local_channel:
self._local_channel.ensure_registered_as_writer()
if self._remote_channel:
self._remote_channel.ensure_registered_as_writer()

def ensure_registered_as_reader(self) -> None:
if self._local_channel:
self._local_channel.ensure_registered_as_reader()
if self._remote_channel:
self._remote_channel.ensure_registered_as_reader()

Expand Down Expand Up @@ -447,9 +457,13 @@ def begin_read(self) -> Any:
return self._remote_channel.begin_read()

def end_read(self):
if not self.use_local_channel():
if self.use_local_channel():
self._local_channel.end_read()
else:
self._remote_channel.end_read()

def close(self) -> None:
if not self.use_local_channel():
if self._local_channel:
self._local_channel.close()
if self._remote_channel:
self._remote_channel.close()

0 comments on commit 97292b1

Please sign in to comment.