Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[core][experimental] Avoid serialization for data passed between two tasks on the same actor #45591

Merged
merged 24 commits into from
Jun 13, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/ray/dag/compiled_dag_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def do_allocate_channel(

Args:
readers: The actor handles of the readers.
buffer_size_bytes: The maximum size of messages in the channel.
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function doesn't have an argument buffer_size_bytes.

typ: The output type hint for the channel.

Returns:
The allocated channel.
Expand Down
62 changes: 62 additions & 0 deletions python/ray/dag/tests/experimental/test_accelerated_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
get_or_create_event_loop,
)

from ray.experimental.channel.shared_memory_channel import MultiChannel


logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -531,6 +533,66 @@ async def main():
compiled_dag.teardown()


class TestMultiChannel:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a bit confused by these tests. I thought MultiChannel would need multiple nodes to read the same result? But here we are only testing chain DAGs where each node is read by exactly one other node, so why do we need MultiChannels?

For example, I thought this would require a MultiChannel:

with InputNode() as inp:
  dag = a.inc.bind(inp)
  dag = MultiOutputNode([a.inc.bind(dag), b.inc.bind(dag)])

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a298811

def test_multi_channel_one_actor(self, ray_start_regular_shared):
a = Actor.remote(0)
with InputNode() as inp:
dag = a.inc.bind(inp)
dag = a.inc.bind(dag)
dag = a.inc.bind(dag)

compiled_dag = dag.experimental_compile()
assert len(compiled_dag.actor_to_tasks) == 1

num_local_channels = 0
num_remote_channels = 0
for tasks in compiled_dag.actor_to_tasks.values():
assert len(tasks) == 3
for task in tasks:
assert isinstance(task.output_channel, MultiChannel)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm ideally unit tests should not assume a certain implementation or that certain internal fields are available to call. Can you perhaps do something like a mocked approach where we check how many Channel vs IntraProcessChannel constructors or methods are called by the end of the test? You can check out the TracedChannel test util that I added in this open PR for an example of this.

if task.output_channel._local_channel:
num_local_channels += 1
if task.output_channel._remote_channel:
num_remote_channels += 1
assert num_local_channels == 2
kevin85421 marked this conversation as resolved.
Show resolved Hide resolved
assert num_remote_channels == 1
output_channel = compiled_dag.execute(1)
result = output_channel.begin_read()
assert result == 4
kevin85421 marked this conversation as resolved.
Show resolved Hide resolved
output_channel.end_read()
compiled_dag.teardown()

def test_multi_channel_two_actors(self, ray_start_regular_shared):
a = Actor.remote(0)
b = Actor.remote(100)
with InputNode() as inp:
dag = a.inc.bind(inp)
dag = b.inc.bind(dag)
dag = a.inc.bind(dag)

compiled_dag = dag.experimental_compile()
assert len(compiled_dag.actor_to_tasks) == 2

a_tasks = compiled_dag.actor_to_tasks[a]
assert len(a_tasks) == 2
for task in a_tasks:
assert isinstance(task.output_channel, MultiChannel)
assert not task.output_channel._local_channel
assert task.output_channel._remote_channel

b_tasks = compiled_dag.actor_to_tasks[b]
assert len(b_tasks) == 1
assert isinstance(b_tasks[0].output_channel, MultiChannel)
assert not b_tasks[0].output_channel._local_channel
assert b_tasks[0].output_channel._remote_channel

output_channel = compiled_dag.execute(1)
result = output_channel.begin_read()
assert result == 102
output_channel.end_read()
compiled_dag.teardown()


if __name__ == "__main__":
if os.environ.get("PARALLEL_CI"):
sys.exit(pytest.main(["-n", "auto", "--boxed", "-vs", __file__]))
Expand Down
54 changes: 54 additions & 0 deletions python/ray/experimental/channel/local_channel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import uuid
from typing import Any

import ray
from ray.experimental.channel import ChannelContext
from ray.experimental.channel.common import ChannelInterface
from ray.util.annotations import PublicAPI


@PublicAPI(stability="alpha")
class LocalChannel(ChannelInterface):
kevin85421 marked this conversation as resolved.
Show resolved Hide resolved
def __init__(
self,
actor_handle: ray.actor.ActorHandle,
):
"""
LocalChannel is a channel for communication between two tasks in the same
worker process. It writes data directly to the worker's serialization context
and reads data from the serialization context to avoid the serialization
overhead and the need for reading/writing from shared memory.
kevin85421 marked this conversation as resolved.
Show resolved Hide resolved
"""

# TODO (kevin85421): Currently, if we don't pass `actor_handle` to
# `LocalChannel`, the actor will die due to the reference count of
# `actor_handle` is 0. We should fix this issue in the future.
self._actor_handle = actor_handle
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Without passing the actor handle, the actor will be killed due to the reference count. This is a Ray Core bug. I will fix it in a followup PR.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice, thanks. I think it's also fine if you want to just file an issue for this and we can address it later.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By the way, I saw what I think is the same issue and should have a fix for it soon.

self.channel_id = str(uuid.uuid4())
kevin85421 marked this conversation as resolved.
Show resolved Hide resolved

def ensure_registered_as_writer(self) -> None:
pass

def ensure_registered_as_reader(self) -> None:
pass

def __reduce__(self):
return LocalChannel, (self._actor_handle,)

def write(self, value: Any):
# Because both the reader and writer are in the same worker process,
# we can directly store the data in the context instead of storing
# it in the channel object. This reduces the serialization overhead of `value`.
ctx = ChannelContext.get_current().serialization_context
ctx.set_data(self.channel_id, value)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I assume this just writes an object reference to the serialization context? (or in the case of a primitive value, the entire primitive value is written?)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure whether I understand this correctly or not. We don't serialize the data because the reader and writer are in the same actor process.


def begin_read(self) -> Any:
ctx = ChannelContext.get_current().serialization_context
return ctx.get_data(self.channel_id)

def end_read(self):
pass

def close(self) -> None:
ctx = ChannelContext.get_current().serialization_context
ctx.reset_data(self.channel_id)
12 changes: 11 additions & 1 deletion python/ray/experimental/channel/serialization_context.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, List, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

if TYPE_CHECKING:
import numpy as np
Expand All @@ -10,13 +10,23 @@ def __init__(self):
self.torch_device: Optional["torch.device"] = None
self.use_external_transport: bool = False
self.tensors: List["torch.Tensor"] = []
self.data: Optional[Dict[str, Any]] = {}

kevin85421 marked this conversation as resolved.
Show resolved Hide resolved
kevin85421 marked this conversation as resolved.
Show resolved Hide resolved
def set_use_external_transport(self, use_external_transport: bool) -> None:
self.use_external_transport = use_external_transport

def set_torch_device(self, torch_device: "torch.device") -> None:
self.torch_device = torch_device

def set_data(self, channel_id: str, value: Any) -> None:
self.data[channel_id] = value

kevin85421 marked this conversation as resolved.
Show resolved Hide resolved
def get_data(self, channel_id: str) -> Any:
return self.data.get(channel_id, None)

def reset_data(self, channel_id: str) -> Any:
self.data.pop(channel_id, None)
kevin85421 marked this conversation as resolved.
Show resolved Hide resolved

def reset_tensors(self, tensors: List["torch.Tensor"]) -> List["torch.Tensor"]:
prev_tensors = self.tensors
self.tensors = tensors
Expand Down
89 changes: 88 additions & 1 deletion python/ray/experimental/channel/shared_memory_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import ray
from ray._raylet import SerializedObject
from ray.experimental.channel.common import ChannelInterface, ChannelOutputType
from ray.experimental.channel.local_channel import LocalChannel
from ray.experimental.channel.torch_tensor_type import TorchTensorType
from ray.util.annotations import PublicAPI

Expand Down Expand Up @@ -128,7 +129,7 @@ def create_channel(
cpu_data_typ=cpu_data_typ,
)

return Channel(writer, readers)
return MultiChannel(writer, readers)

def set_nccl_group_id(self, group_id: str) -> None:
assert self.requires_nccl()
Expand Down Expand Up @@ -380,3 +381,89 @@ def close(self) -> None:
if self.is_local_node(self._reader_node_id):
self.ensure_registered_as_reader()
self._worker.core_worker.experimental_channel_set_error(self._reader_ref)


@PublicAPI(stability="alpha")
class MultiChannel(ChannelInterface):
kevin85421 marked this conversation as resolved.
Show resolved Hide resolved
def __init__(
self,
writer: Optional[ray.actor.ActorHandle],
readers: List[Optional[ray.actor.ActorHandle]],
_local_channel: Optional[LocalChannel] = None,
_remote_channel: Optional[Channel] = None,
):
"""
Can be used to send data to different readers via different channels.
For example, if the reader is in the same worker process as the writer,
the data can be sent via LocalChannel. If the reader is in a different
worker process, the data can be sent via shared memory channel.
kevin85421 marked this conversation as resolved.
Show resolved Hide resolved
"""
self._writer = writer
self._readers = readers
self._local_channel = _local_channel
self._remote_channel = _remote_channel

remote_readers = []
for reader in self._readers:
if self._writer != reader:
remote_readers.append(reader)
# There are some local readers which are the same Ray actor as the writer.
kevin85421 marked this conversation as resolved.
Show resolved Hide resolved
# Create a local channel for the writer and the local readers.
if not self._local_channel and len(remote_readers) != len(self._readers):
self._local_channel = LocalChannel(self._writer)
# There are some remote readers which are not the same Ray actor as the writer.
# Create a shared memory channel for the writer and the remote readers.
if not self._remote_channel and len(remote_readers) != 0:
self._remote_channel = Channel(self._writer, remote_readers)
assert hasattr(self, "_local_channel") or hasattr(self, "_remote_channel")

def ensure_registered_as_writer(self) -> None:
if self._local_channel:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LocalChannel's ensure_registered_as_writer() doesn't currently do anything. I am calling the function here to ensure we don't forget to update it should we make any changes to ensure_registered_as_writer() in LocalChannel in the future. Some following functions have the same reason.

self._local_channel.ensure_registered_as_writer()
if self._remote_channel:
self._remote_channel.ensure_registered_as_writer()

def ensure_registered_as_reader(self) -> None:
if self._local_channel:
self._local_channel.ensure_registered_as_reader()
if self._remote_channel:
self._remote_channel.ensure_registered_as_reader()

def __reduce__(self):
return MultiChannel, (
self._writer,
self._readers,
self._local_channel,
self._remote_channel,
)

def write(self, value: Any):
if self._local_channel:
self._local_channel.write(value)
if self._remote_channel:
self._remote_channel.write(value)

def use_local_channel(self):
kevin85421 marked this conversation as resolved.
Show resolved Hide resolved
current_actor_id = ray.get_runtime_context().get_actor_id()
if not current_actor_id or not self._writer:
# We are calling from the driver, or the writer is the driver.
return False
return current_actor_id == self._writer._actor_id.hex()

def begin_read(self) -> Any:
if self.use_local_channel():
return self._local_channel.begin_read()
else:
kevin85421 marked this conversation as resolved.
Show resolved Hide resolved
return self._remote_channel.begin_read()

def end_read(self):
if self.use_local_channel():
self._local_channel.end_read()
else:
kevin85421 marked this conversation as resolved.
Show resolved Hide resolved
self._remote_channel.end_read()

def close(self) -> None:
if self._local_channel:
self._local_channel.close()
if self._remote_channel:
self._remote_channel.close()
Loading