[core][experimental] Avoid serialization for data passed between two tasks on the same actor #45591

Merged 24 commits on Jun 13, 2024
6 changes: 3 additions & 3 deletions python/ray/dag/compiled_dag_node.py
@@ -43,7 +43,7 @@ def do_allocate_channel(

    Args:
        readers: The actor handles of the readers.
        buffer_size_bytes: The maximum size of messages in the channel.
Member Author

This function doesn't have a `buffer_size_bytes` argument.

        typ: The output type hint for the channel.

    Returns:
        The allocated channel.
@@ -129,7 +129,7 @@ def _exec_task(self, task: "ExecutableTask", idx: int) -> bool:
             True if we are done executing all tasks of this actor, False otherwise.
         """
         # TODO: for cases where output is passed as input to a task on
-        # the same actor, introduce a "LocalChannel" to avoid the overhead
+        # the same actor, introduce a "IntraProcessChannel" to avoid the overhead
         # of serialization/deserialization and synchronization.
         method = getattr(self, task.method_name)
         input_reader = self._input_readers[idx]
@@ -649,12 +649,12 @@ def _get_or_compile(

if isinstance(task.dag_node, ClassMethodNode):
    readers = [self.idx_to_task[idx] for idx in task.downstream_node_idxs]
    assert len(readers) == 1
Member Author

This assertion makes the following case fail, because actor `a` has two readers (actor `b` and itself):

a = Actor.remote(0)
b = Actor.remote(100)
with InputNode() as inp:
    dag = a.inc.bind(inp)
    dag = MultiOutputNode([a.inc.bind(dag), b.inc.bind(dag)])

compiled_dag = dag.experimental_compile()

Contributor

Yes, this is a known bug, cc @jackhumphries

I think this check got added when Jack was looking into supporting multi-node. I think you could remove it, and it should be okay as long as all of the remote readers are on the same node.

Member Author

> all of the remote readers are on the same node.

Why should all remote readers be on the same node?

Contributor

Right now that is required because all of the remote readers will share one channel. But with your PR, we can create multiple channels, one for each unique reader node.

Member Author

Got it. I will update my PR. Currently, it only creates one Channel for all remote readers (readers which are not on the same actor as the writer).

Contributor

Yes, this is currently something I need to work on. The hope is that this PR will make it easy at this point to add support for readers across multiple nodes, though you don't need to explicitly support that in this PR--I will take care of it once this PR is merged.

Member Author

Chatted with @jackhumphries offline. He will work on this part, so I will not include multi-node support in this PR.
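
The idea raised in this thread, one channel per unique reader node, can be sketched without Ray. This is only an illustration: `group_readers_by_node`, the reader names, and the node IDs are hypothetical, and the PR itself leaves multi-node support to follow-up work.

```python
from collections import defaultdict
from typing import Dict, List


def group_readers_by_node(reader_to_node: Dict[str, str]) -> Dict[str, List[str]]:
    """Group reader names by the node they run on (hypothetical helper)."""
    groups: Dict[str, List[str]] = defaultdict(list)
    for reader, node_id in reader_to_node.items():
        groups[node_id].append(reader)
    return dict(groups)


# Hypothetical placement: two readers share node-1, one reader is on node-2.
placement = {"actor_b": "node-1", "actor_c": "node-1", "actor_d": "node-2"}
channels_needed = group_readers_by_node(placement)
# Under this scheme, one channel would be allocated per key of `channels_needed`.
```

With a single shared channel, as in the PR at this point, every reader would have to fall into the same group.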


    def _get_node_id(self):
        return ray.get_runtime_context().get_node_id()

    if isinstance(readers[0].dag_node, MultiOutputNode):
        assert len(readers) == 1
        # This node is a multi-output node, which means that it will only be
        # read by the driver, not an actor. Thus, we handle this case by
        # setting `reader_handles` to `[None]`.
113 changes: 113 additions & 0 deletions python/ray/dag/tests/experimental/test_accelerated_dag.py
@@ -646,6 +646,119 @@ async def main():
    compiled_dag.teardown()


class TestCompositeChannel:
    def test_composite_channel_one_actor(self, ray_start_regular_shared):
        """
        In this test, there are three 'inc' tasks on the same Ray actor, chained
        together. Therefore, the DAG will look like this:

        Driver -> a.inc -> a.inc -> a.inc -> Driver

        All communication between the driver and the actor will be done through remote
        channels, i.e., shared memory channels. All communication between the actor
        tasks will be conducted through local channels, i.e., IntraProcessChannel in
        this case.

        To elaborate, all output channels of the actor DAG nodes will be
        CompositeChannel, and the first two will have a local channel, while the last
Contributor

Why do we use CompositeChannels here? Isn't there only one reader per channel?

Member Author

> Isn't there only one reader per channel?

Yes.

> Why do we use CompositeChannels here?

In this PR, the SharedMemoryType's create_channel function creates a CompositeChannel instead of a Channel. The CompositeChannel will determine whether to use only IntraProcessChannel, only Channel, or both IntraProcessChannel and Channel to transfer the data under the hood.
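
The three cases described in this reply (local only, remote only, or both) can be illustrated with a small stand-alone sketch. `plan_channels` and the string actor names are hypothetical and are not part of Ray's API; the sketch only assumes that a reader on the same actor as the writer counts as local and every other reader counts as remote.

```python
from typing import List, Tuple


def plan_channels(writer: str, readers: List[str]) -> Tuple[bool, bool]:
    """Return (needs_local, needs_remote) for a writer and its readers.

    Hypothetical helper: a reader on the same actor as the writer is "local";
    every other reader (another actor, or the driver) is "remote".
    """
    needs_local = any(r == writer for r in readers)
    needs_remote = any(r != writer for r in readers)
    return needs_local, needs_remote


# a.inc -> a.inc: only an intra-process channel is needed.
only_local = plan_channels("a", ["a"])
# a.inc -> b.inc: only a shared-memory channel is needed.
only_remote = plan_channels("a", ["b"])
# a.inc -> [a.inc, b.inc]: both kinds are needed.
both = plan_channels("a", ["a", "b"])
```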

        one will have a remote channel.
        """
        a = Actor.remote(0)
        with InputNode() as inp:
            dag = a.inc.bind(inp)
            dag = a.inc.bind(dag)
            dag = a.inc.bind(dag)

        compiled_dag = dag.experimental_compile()
        output_channel = compiled_dag.execute(1)
        result = output_channel.begin_read()
        assert result == 4
        output_channel.end_read()

        output_channel = compiled_dag.execute(2)
        result = output_channel.begin_read()
        assert result == 24
        output_channel.end_read()

        output_channel = compiled_dag.execute(3)
        result = output_channel.begin_read()
        assert result == 108
        output_channel.end_read()

        compiled_dag.teardown()

    def test_composite_channel_two_actors(self, ray_start_regular_shared):
        """
        In this test, there are three 'inc' tasks on the two Ray actors, chained
        together. Therefore, the DAG will look like this:

        Driver -> a.inc -> b.inc -> a.inc -> Driver

        All communication between the driver and actors will be done through remote
        channels. Also, all communication between the actor tasks will be conducted
        through remote channels, i.e., shared memory channel in this case because no
        consecutive tasks are on the same actor.
        """
        a = Actor.remote(0)
        b = Actor.remote(100)
        with InputNode() as inp:
            dag = a.inc.bind(inp)
            dag = b.inc.bind(dag)
            dag = a.inc.bind(dag)

        # a: 0+1 -> b: 100+1 -> a: 1+101
        compiled_dag = dag.experimental_compile()
        output_channel = compiled_dag.execute(1)
        result = output_channel.begin_read()
        assert result == 102
        output_channel.end_read()

        # a: 102+2 -> b: 101+104 -> a: 104+205
        output_channel = compiled_dag.execute(2)
        result = output_channel.begin_read()
        assert result == 309
        output_channel.end_read()

        # a: 309+3 -> b: 205+312 -> a: 312+517
        output_channel = compiled_dag.execute(3)
        result = output_channel.begin_read()
        assert result == 829
        output_channel.end_read()

        compiled_dag.teardown()

    def test_composite_channel_multi_output(self, ray_start_regular_shared):
        """
        Driver -> a.inc -> a.inc ---> Driver
                        |         |
                        -> b.inc -

        All communication in this DAG will be done through CompositeChannel.
        Under the hood, the communication between two `a.inc` tasks will
        be done through a local channel, i.e., IntraProcessChannel in this
        case, while the communication between `a.inc` and `b.inc` will be
        done through a shared memory channel.
        """
        a = Actor.remote(0)
        b = Actor.remote(100)
        with InputNode() as inp:
            dag = a.inc.bind(inp)
            dag = MultiOutputNode([a.inc.bind(dag), b.inc.bind(dag)])

        compiled_dag = dag.experimental_compile()
        output_channel = compiled_dag.execute(1)
        result = output_channel.begin_read()
        assert result == [2, 101]
        output_channel.end_read()

        output_channel = compiled_dag.execute(3)
        result = output_channel.begin_read()
        assert result == [10, 106]
        output_channel.end_read()

        compiled_dag.teardown()


if __name__ == "__main__":
    if os.environ.get("PARALLEL_CI"):
        sys.exit(pytest.main(["-n", "auto", "--boxed", "-vs", __file__]))
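
The expected values asserted in the three tests above can be checked by hand with a plain-Python stand-in for the actor. The `Actor` class is not shown in this diff, so `LocalActor` below is an assumption inferred from the inline comments in the two-actor test (for example `a: 0+1 -> b: 100+1 -> a: 1+101`): `inc` adds its input to the actor's state and returns the new state. No Ray is involved.

```python
class LocalActor:
    """Plain-Python stand-in for the test actor (assumed semantics)."""

    def __init__(self, init_value: int) -> None:
        self.i = init_value

    def inc(self, x: int) -> int:
        # Accumulate the input into the actor state and return the new state.
        self.i += x
        return self.i


def run_one_actor(inputs):
    # Driver -> a.inc -> a.inc -> a.inc -> Driver
    a = LocalActor(0)
    return [a.inc(a.inc(a.inc(x))) for x in inputs]


def run_two_actors(inputs):
    # Driver -> a.inc -> b.inc -> a.inc -> Driver
    a, b = LocalActor(0), LocalActor(100)
    return [a.inc(b.inc(a.inc(x))) for x in inputs]


def run_multi_output(inputs):
    # The first a.inc feeds both a second a.inc and b.inc.
    a, b = LocalActor(0), LocalActor(100)
    results = []
    for x in inputs:
        first = a.inc(x)
        results.append([a.inc(first), b.inc(first)])
    return results
```

Because the actor state persists across `execute` calls, each result depends on all earlier inputs, which is why the expected values grow so quickly.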
5 changes: 4 additions & 1 deletion python/ray/experimental/channel/__init__.py
@@ -9,7 +9,8 @@
     SynchronousWriter,
     WriterInterface,
 )
-from ray.experimental.channel.shared_memory_channel import Channel
+from ray.experimental.channel.intra_process_channel import IntraProcessChannel
+from ray.experimental.channel.shared_memory_channel import Channel, CompositeChannel
 from ray.experimental.channel.torch_tensor_nccl_channel import TorchTensorNcclChannel

__all__ = [
@@ -22,4 +23,6 @@
     "WriterInterface",
     "ChannelContext",
     "TorchTensorNcclChannel",
+    "IntraProcessChannel",
+    "CompositeChannel",
 ]
6 changes: 6 additions & 0 deletions python/ray/experimental/channel/common.py
@@ -175,9 +175,15 @@ def __init__(
         pass

     def ensure_registered_as_writer(self):
+        """
+        Check whether the process is a valid writer. This method must be idempotent.
+        """
         raise NotImplementedError

     def ensure_registered_as_reader(self):
+        """
+        Check whether the process is a valid reader. This method must be idempotent.
+        """
         raise NotImplementedError

     def write(self, value: Any) -> None:
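
The idempotency requirement stated in these docstrings means that calling a registration method repeatedly must have the same effect as calling it once. One common way to get that property is a guard flag, sketched below. `RegisteringChannel` and `registration_count` are hypothetical names used only for illustration; this is not how any Ray channel is implemented.

```python
class RegisteringChannel:
    """Hypothetical channel that tracks registration with a flag."""

    def __init__(self) -> None:
        self._writer_registered = False
        self.registration_count = 0  # counts how often real setup work ran

    def ensure_registered_as_writer(self) -> None:
        if self._writer_registered:
            return  # already registered: repeated calls are no-ops
        self.registration_count += 1
        self._writer_registered = True


chan = RegisteringChannel()
chan.ensure_registered_as_writer()
chan.ensure_registered_as_writer()  # second call does no extra work
```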
65 changes: 65 additions & 0 deletions python/ray/experimental/channel/intra_process_channel.py
@@ -0,0 +1,65 @@
import uuid
from typing import Any, Optional

import ray
from ray.experimental.channel import ChannelContext
from ray.experimental.channel.common import ChannelInterface
from ray.util.annotations import PublicAPI


@PublicAPI(stability="alpha")
class IntraProcessChannel(ChannelInterface):
    """
    IntraProcessChannel is a channel for communication between two tasks in the same
    worker process. It writes data directly to the worker's _SerializationContext
    and reads data from the _SerializationContext to avoid the serialization
    overhead and the need for reading/writing from shared memory.

    Args:
        actor_handle: The actor handle of the worker process.
    """

    def __init__(
        self,
        actor_handle: ray.actor.ActorHandle,
        _channel_id: Optional[str] = None,
    ):
        # TODO (kevin85421): Currently, if we don't pass `actor_handle` to
        # `IntraProcessChannel`, the actor will die due to the reference count of
        # `actor_handle` is 0. We should fix this issue in the future.
        self._actor_handle = actor_handle
Contributor

This should be fixed by #45699, btw.

        # Generate a unique ID for the channel. The writer and reader will use
        # this ID to store and retrieve data from the _SerializationContext.
        self._channel_id = _channel_id
        if self._channel_id is None:
            self._channel_id = str(uuid.uuid4())

    def ensure_registered_as_writer(self) -> None:
        pass
Contributor

Might be good to check here that ray.get_runtime_context().current_actor matches self._actor_handle?

Member Author

@kevin85421, Jun 6, 2024

I will remove self._actor_handle from the constructor after the ref count issue is fixed. It is not used in IntraProcessChannel anywhere.


    def ensure_registered_as_reader(self) -> None:
        pass

    def __reduce__(self):
        return IntraProcessChannel, (
            self._actor_handle,
            self._channel_id,
        )

    def write(self, value: Any):
        # Because both the reader and writer are in the same worker process,
        # we can directly store the data in the context instead of storing
        # it in the channel object. This removes the serialization overhead of `value`.
        ctx = ChannelContext.get_current().serialization_context
        ctx.set_data(self._channel_id, value)

    def begin_read(self) -> Any:
        ctx = ChannelContext.get_current().serialization_context
        return ctx.get_data(self._channel_id)

    def end_read(self):
        pass

    def close(self) -> None:
        ctx = ChannelContext.get_current().serialization_context
        ctx.reset_data(self._channel_id)
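
To make the benefit concrete, the sketch below contrasts a serialize-and-rebuild handoff with a same-process handoff through a plain dictionary. It uses the standard `pickle` module as a stand-in for whatever serializer a cross-process channel uses, and the `buffer` dict and the `"channel-0"` key are illustrative assumptions rather than Ray internals.

```python
import pickle

payload = {"weights": list(range(5))}

# Cross-process style handoff: the value is serialized and rebuilt.
copied = pickle.loads(pickle.dumps(payload))

# Same-process handoff: the writer stores a reference, the reader takes it.
buffer = {}
buffer["channel-0"] = payload       # write
received = buffer.pop("channel-0")  # read

same_object = received is payload   # no copy was made
copy_is_distinct = copied is not payload and copied == payload
```

One consequence worth keeping in mind: because the reader receives the very same object, any in-place mutation by the reader is visible to whoever else still holds a reference to it.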
23 changes: 22 additions & 1 deletion python/ray/experimental/channel/serialization_context.py
@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING, List, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Union

 if TYPE_CHECKING:
     import numpy as np
@@ -9,10 +9,31 @@ class _SerializationContext:
    def __init__(self):
        self.use_external_transport: bool = False
        self.tensors: List["torch.Tensor"] = []
        # Buffer for transferring data between tasks in the same worker process.
        # The key is the channel ID, and the value is the data. We don't use a
        # lock when reading/writing the buffer because a DAG node actor will only
        # execute one task at a time in `do_exec_tasks`. It will not execute multiple
        # Ray tasks on a single actor simultaneously.
        self.intra_process_channel_buffers: Dict[str, Any] = {}

    def set_use_external_transport(self, use_external_transport: bool) -> None:
        self.use_external_transport = use_external_transport

    def set_data(self, channel_id: str, value: Any) -> None:
        assert (
            channel_id not in self.intra_process_channel_buffers
        ), f"Channel {channel_id} already exists in the buffer."
        self.intra_process_channel_buffers[channel_id] = value

    def get_data(self, channel_id: str) -> Any:
        assert (
            channel_id in self.intra_process_channel_buffers
        ), f"Channel {channel_id} does not exist in the buffer."
        return self.intra_process_channel_buffers.pop(channel_id)

    def reset_data(self, channel_id: str) -> None:
        self.intra_process_channel_buffers.pop(channel_id, None)

    def reset_tensors(self, tensors: List["torch.Tensor"]) -> List["torch.Tensor"]:
        prev_tensors = self.tensors
        self.tensors = tensors
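
The buffer methods added in this hunk imply a simple contract: a value must be written before it is read, a read consumes the value, and a reset is safe on an empty slot. The stand-alone sketch below copies the logic of `set_data`, `get_data`, and `reset_data` into a hypothetical `BufferSketch` class (with shortened assertion messages) so the contract can be exercised without Ray.

```python
from typing import Any, Dict


class BufferSketch:
    """Stand-alone copy of the buffer methods shown in the diff."""

    def __init__(self) -> None:
        self.buffers: Dict[str, Any] = {}

    def set_data(self, channel_id: str, value: Any) -> None:
        assert channel_id not in self.buffers, "unread value would be overwritten"
        self.buffers[channel_id] = value

    def get_data(self, channel_id: str) -> Any:
        assert channel_id in self.buffers, "nothing has been written"
        return self.buffers.pop(channel_id)

    def reset_data(self, channel_id: str) -> None:
        self.buffers.pop(channel_id, None)


ctx = BufferSketch()
ctx.set_data("ch", 1)
first_read = ctx.get_data("ch")

try:
    ctx.get_data("ch")  # second read without a new write
    second_read_failed = False
except AssertionError:
    second_read_failed = True

ctx.reset_data("ch")  # safe even though the slot is already empty
```

Since `get_data` pops the entry, each written value can be read exactly once, which matches the single-reader usage described in the review thread. Note that the checks rely on `assert`, so they would be skipped if Python runs with the `-O` flag.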