[core][experimental] Avoid serialization for data passed between two tasks on the same actor #45591
File 1: compiled DAG tests
@@ -16,6 +16,8 @@
    get_or_create_event_loop,
)

+from ray.experimental.channel.shared_memory_channel import MultiChannel

logger = logging.getLogger(__name__)
@@ -531,6 +533,66 @@ async def main():
    compiled_dag.teardown()


class TestMultiChannel:
Review comment: I'm a bit confused by these tests. I thought MultiChannel would need multiple nodes to read the same result? But here we are only testing chain DAGs where each node is read by exactly one other node, so why do we need MultiChannels? For example, I thought this would require a MultiChannel:

    with InputNode() as inp:
        dag = a.inc.bind(inp)
        dag = MultiOutputNode([a.inc.bind(dag), b.inc.bind(dag)])

Reply (author): Added a298811
    def test_multi_channel_one_actor(self, ray_start_regular_shared):
        a = Actor.remote(0)
        with InputNode() as inp:
            dag = a.inc.bind(inp)
            dag = a.inc.bind(dag)
            dag = a.inc.bind(dag)

        compiled_dag = dag.experimental_compile()
        assert len(compiled_dag.actor_to_tasks) == 1

        num_local_channels = 0
        num_remote_channels = 0
        for tasks in compiled_dag.actor_to_tasks.values():
            assert len(tasks) == 3
            for task in tasks:
                assert isinstance(task.output_channel, MultiChannel)
                if task.output_channel._local_channel:
                    num_local_channels += 1
                if task.output_channel._remote_channel:
                    num_remote_channels += 1
        assert num_local_channels == 2
        assert num_remote_channels == 1

        output_channel = compiled_dag.execute(1)
        result = output_channel.begin_read()
        assert result == 4
        output_channel.end_read()
        compiled_dag.teardown()

Review comment: Hmm, ideally unit tests should not assume a certain implementation or that certain internal fields are available to call. Can you perhaps do something like a mocked approach where we check how many Channel vs IntraProcessChannel constructors or methods are called by the end of the test? You can check out the
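One way to act on the suggestion above without poking at private fields is to spy on the constructors that MultiChannel.__init__ invokes (the reviewer's "IntraProcessChannel" presumably corresponds to what this PR names LocalChannel). The sketch below is not part of this PR: the test name and the idea of constructing MultiChannel directly with mocked inner channel classes are illustrative assumptions, and it reuses the Actor helper and ray_start_regular_shared fixture from the tests above.

from unittest import mock

from ray.experimental.channel import shared_memory_channel
from ray.experimental.channel.shared_memory_channel import MultiChannel


def test_multi_channel_constructor_counts(ray_start_regular_shared):
    a = Actor.remote(0)
    b = Actor.remote(100)

    # Replace the inner channel classes in the module where MultiChannel looks
    # them up, so that no real shared-memory buffers are allocated; we only
    # count which constructors MultiChannel.__init__ decides to call.
    with mock.patch.object(shared_memory_channel, "LocalChannel") as local_cls, \
            mock.patch.object(shared_memory_channel, "Channel") as remote_cls:
        # The only reader runs on the writer's actor -> local channel only.
        MultiChannel(a, [a])
        assert local_cls.call_count == 1
        assert remote_cls.call_count == 0

        # One reader on the writer's actor, one on another actor -> both.
        MultiChannel(a, [a, b])
        assert local_cls.call_count == 2
        assert remote_cls.call_count == 1

Because the spies live only in the test process, this exercises the routing logic of MultiChannel.__init__ itself rather than whatever channels experimental_compile() allocates inside the actor workers.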
    def test_multi_channel_two_actors(self, ray_start_regular_shared):
        a = Actor.remote(0)
        b = Actor.remote(100)
        with InputNode() as inp:
            dag = a.inc.bind(inp)
            dag = b.inc.bind(dag)
            dag = a.inc.bind(dag)

        compiled_dag = dag.experimental_compile()
        assert len(compiled_dag.actor_to_tasks) == 2

        a_tasks = compiled_dag.actor_to_tasks[a]
        assert len(a_tasks) == 2
        for task in a_tasks:
            assert isinstance(task.output_channel, MultiChannel)
            assert not task.output_channel._local_channel
            assert task.output_channel._remote_channel

        b_tasks = compiled_dag.actor_to_tasks[b]
        assert len(b_tasks) == 1
        assert isinstance(b_tasks[0].output_channel, MultiChannel)
        assert not b_tasks[0].output_channel._local_channel
        assert b_tasks[0].output_channel._remote_channel

        output_channel = compiled_dag.execute(1)
        result = output_channel.begin_read()
        assert result == 102
        output_channel.end_read()
        compiled_dag.teardown()


if __name__ == "__main__":
    if os.environ.get("PARALLEL_CI"):
        sys.exit(pytest.main(["-n", "auto", "--boxed", "-vs", __file__]))
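For the multi-reader shape raised in the first review comment above (the case the author added in a298811, which is not part of this excerpt), a test could look roughly like the following sketch. The expected values assume the same Actor.inc semantics as the tests above, MultiOutputNode is assumed to be importable in this test module, and the list-shaped result is an assumption.

def test_multi_channel_one_writer_two_reader_actors(ray_start_regular_shared):
    a = Actor.remote(0)
    b = Actor.remote(100)
    with InputNode() as inp:
        dag = a.inc.bind(inp)
        # The first a.inc output is read by a task on the same actor (local
        # channel) and by a task on actor b (shared-memory channel).
        dag = MultiOutputNode([a.inc.bind(dag), b.inc.bind(dag)])

    compiled_dag = dag.experimental_compile()
    output_channel = compiled_dag.execute(1)
    result = output_channel.begin_read()
    # a: 0 + 1 = 1, then 1 + 1 = 2; b: 100 + 1 = 101.
    assert result == [2, 101]
    output_channel.end_read()
    compiled_dag.teardown()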
File 2: ray.experimental.channel.local_channel (new file)
@@ -0,0 +1,47 @@
import uuid
from typing import Any

import ray
from ray.experimental.channel import ChannelContext
from ray.experimental.channel.common import ChannelInterface
from ray.util.annotations import PublicAPI


@PublicAPI(stability="alpha")
class LocalChannel(ChannelInterface):
    def __init__(
        self,
        actor_handle: ray.actor.ActorHandle,
    ):
        # TODO (kevin85421): Currently, if we don't pass `actor_handle` to
        # `LocalChannel`, the actor will die because the reference count of
        # `actor_handle` drops to 0. We should fix this issue in the future.
        self._actor_handle = actor_handle
        self.channel_id = str(uuid.uuid4())

Author comment: Without passing the actor handle, the actor will be killed due to the reference count. This is a Ray Core bug. I will fix it in a followup PR.

Reply: Nice, thanks. I think it's also fine if you want to just file an issue for this and we can address it later.

Reply: By the way, I saw what I think is the same issue and should have a fix for it soon.

    def ensure_registered_as_writer(self) -> None:
        pass

    def ensure_registered_as_reader(self) -> None:
        pass

    def __reduce__(self):
        return LocalChannel, (self._actor_handle,)

    def write(self, value: Any):
        # Because both the reader and writer are in the same worker process,
        # we can directly store the data in the context instead of storing it
        # in the channel object. This reduces the serialization overhead of `value`.
        ctx = ChannelContext.get_current().serialization_context
        ctx.set_data(self.channel_id, value)

Review comment: I assume this just writes an object reference to the serialization context? (or in the case of a primitive value, the entire primitive value is written?)

Author reply: I am not sure whether I understand this correctly or not. We don't serialize the data because the reader and writer are in the same actor process.

    def begin_read(self) -> Any:
        ctx = ChannelContext.get_current().serialization_context
        return ctx.get_data(self.channel_id)

    def end_read(self):
        pass

    def close(self) -> None:
        ctx = ChannelContext.get_current().serialization_context
        ctx.reset_data(self.channel_id)
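The set_data/get_data/reset_data methods used above are not shown in this excerpt. Conceptually, all LocalChannel needs from the serialization context is a per-process dict keyed by channel id; the class below is an illustrative stand-in (its name and exact shape are assumptions, not the actual serialization context implementation).

from typing import Any, Dict


class InProcessChannelStore:
    """Illustrative sketch of the per-process store LocalChannel relies on.

    Writer and reader live in the same worker process, so the written value
    is kept as a plain Python reference; nothing is serialized or copied.
    """

    def __init__(self) -> None:
        # channel_id -> value most recently written to that LocalChannel.
        self._data: Dict[str, Any] = {}

    def set_data(self, channel_id: str, value: Any) -> None:
        self._data[channel_id] = value

    def get_data(self, channel_id: str) -> Any:
        return self._data.get(channel_id)

    def reset_data(self, channel_id: str) -> None:
        self._data.pop(channel_id, None)

This also makes the trade-off visible: the reader gets back the very object the writer stored, so in-place mutation by either side is observable by the other, unlike the shared-memory Channel path.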
File 3: ray.experimental.channel.shared_memory_channel
@@ -5,6 +5,7 @@
import ray
from ray._raylet import SerializedObject
from ray.experimental.channel.common import ChannelInterface, ChannelOutputType
+from ray.experimental.channel.local_channel import LocalChannel
from ray.experimental.channel.torch_tensor_type import TorchTensorType
from ray.util.annotations import PublicAPI


@@ -128,7 +129,7 @@ def create_channel(
                cpu_data_typ=cpu_data_typ,
            )

-        return Channel(writer, readers)
+        return MultiChannel(writer, readers)

    def set_nccl_group_id(self, group_id: str) -> None:
        assert self.requires_nccl()
@@ -380,3 +381,75 @@ def close(self) -> None:
        if self.is_local_node(self._reader_node_id):
            self.ensure_registered_as_reader()
        self._worker.core_worker.experimental_channel_set_error(self._reader_ref)


@PublicAPI(stability="alpha")
class MultiChannel(ChannelInterface):
    def __init__(
        self,
        writer: Optional[ray.actor.ActorHandle],
        readers: List[Optional[ray.actor.ActorHandle]],
        local_channel: Optional[LocalChannel] = None,
        remote_channel: Optional[Channel] = None,
    ):
        self._writer = writer
        self._readers = readers
        self._local_channel = local_channel
        self._remote_channel = remote_channel

        remote_readers = []
        for reader in self._readers:
            if self._writer != reader:
                remote_readers.append(reader)
        # There are some local readers which are the same Ray actor as the writer.
        # Create a local channel for the writer and the local readers.
        if not self._local_channel and len(remote_readers) != len(self._readers):
            self._local_channel = LocalChannel(self._writer)
        # There are some remote readers which are not the same Ray actor as the writer.
        # Create a shared memory channel for the writer and the remote readers.
        if not self._remote_channel and len(remote_readers) != 0:
            self._remote_channel = Channel(self._writer, remote_readers)
        assert hasattr(self, "_local_channel") or hasattr(self, "_remote_channel")

Author comment: @stephanie-wang I used:

    def write(self, value):
        for ch in self.channels:
            ch.write(value)

Reply: Hmm, what if we support an interface like channels_dict, keyed by reader actor ID? Then each reader can retrieve its channel in begin_read(). A nice thing about this interface is that it also supports the case where readers are on different nodes (cc @jackhumphries).

    channel_dict: Dict[ray.ActorID, ChannelInterface]

Author reply: Updated af458bf

    def ensure_registered_as_writer(self) -> None:
        if self._remote_channel:
            self._remote_channel.ensure_registered_as_writer()

    def ensure_registered_as_reader(self) -> None:
        if self._remote_channel:
            self._remote_channel.ensure_registered_as_reader()

    def __reduce__(self):
        return MultiChannel, (
            self._writer,
            self._readers,
            self._local_channel,
            self._remote_channel,
        )

    def write(self, value: Any):
        if self._local_channel:
            self._local_channel.write(value)
        if self._remote_channel:
            self._remote_channel.write(value)

    def use_local_channel(self):
        current_actor_id = ray.get_runtime_context().get_actor_id()
        if not current_actor_id or not self._writer:
            # We are calling from the driver, or the writer is the driver.
            return False
        return current_actor_id == self._writer._actor_id.hex()

    def begin_read(self) -> Any:
        if self.use_local_channel():
            return self._local_channel.begin_read()
        else:
            return self._remote_channel.begin_read()

    def end_read(self):
        if not self.use_local_channel():
            self._remote_channel.end_read()

    def close(self) -> None:
        if not self.use_local_channel():
            self._remote_channel.close()
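The channels_dict interface proposed in the review thread above (and addressed in af458bf, which is not part of this excerpt) could look roughly like the sketch below. The class name, the use of hex actor-ID strings as keys, and the handling of the driver are illustrative assumptions; the reviewer's suggestion keys the dict by ray.ActorID.

from typing import Any, Dict, Optional

import ray
from ray.experimental.channel.common import ChannelInterface


class ChannelDictSketch:
    """Illustrative sketch only; a real version would implement ChannelInterface.

    Maps each reader's actor ID to the concrete channel it should read from,
    so begin_read()/end_read() reduce to a dictionary lookup.
    """

    def __init__(self, channel_dict: Dict[Optional[str], ChannelInterface]):
        # Key None stands for the driver, which has no actor ID.
        self._channel_dict = channel_dict

    def write(self, value: Any) -> None:
        # Several readers may share one underlying channel (e.g. a single
        # shared-memory channel for all remote readers), so write to each
        # distinct channel exactly once.
        for channel in {id(ch): ch for ch in self._channel_dict.values()}.values():
            channel.write(value)

    def _reader_channel(self) -> ChannelInterface:
        actor_id = ray.get_runtime_context().get_actor_id()  # None on the driver
        return self._channel_dict[actor_id]

    def begin_read(self) -> Any:
        return self._reader_channel().begin_read()

    def end_read(self) -> None:
        self._reader_channel().end_read()

Compared to MultiChannel's fixed local/remote pair, this generalizes naturally to readers spread across several actors and nodes, which is the point the reviewer raises.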
Review comment: This function doesn't have an argument buffer_size_bytes.