[core][experimental] Avoid serialization for data passed between two tasks on the same actor #45591
@@ -43,7 +43,7 @@ def do_allocate_channel(

        Args:
            readers: The actor handles of the readers.
            buffer_size_bytes: The maximum size of messages in the channel.
            typ: The output type hint for the channel.

        Returns:
            The allocated channel.
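As a reading aid, here is a rough, hypothetical sketch of the contract the docstring above describes. The body and the `typ.create_channel` call are assumptions made for illustration, not the PR's actual implementation:

    import ray

    # Hypothetical sketch based only on the docstring above; the real
    # implementation lives in Ray's compiled-DAG actor code, and
    # `typ.create_channel` is an assumed helper on the type-hint object.
    def do_allocate_channel(self, readers, typ):
        # Allocate a channel whose writer is the current actor and whose
        # readers are the given actor handles; the type hint decides which
        # concrete channel kind (shared-memory, intra-process, ...) to build.
        channel = typ.create_channel(
            writer=ray.get_runtime_context().current_actor,
            readers=readers,
        )
        return channel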
@@ -129,7 +129,7 @@ def _exec_task(self, task: "ExecutableTask", idx: int) -> bool:
            True if we are done executing all tasks of this actor, False otherwise.
            """
            # TODO: for cases where output is passed as input to a task on
-           # the same actor, introduce a "LocalChannel" to avoid the overhead
+           # the same actor, introduce a "IntraProcessChannel" to avoid the overhead
            # of serialization/deserialization and synchronization.
            method = getattr(self, task.method_name)
            input_reader = self._input_readers[idx]
@@ -649,12 +649,12 @@ def _get_or_compile(

            if isinstance(task.dag_node, ClassMethodNode):
                readers = [self.idx_to_task[idx] for idx in task.downstream_node_idxs]
                assert len(readers) == 1

Review comment: This assertion will make the following case fail: the actor a has two readers (actor b and itself).

    a = Actor.remote(0)
    b = Actor.remote(100)
    with InputNode() as inp:
        dag = a.inc.bind(inp)
        dag = MultiOutputNode([a.inc.bind(dag), b.inc.bind(dag)])
    compiled_dag = dag.experimental_compile()

Reply: Yes, this is a known bug, cc @jackhumphries. I think this check got added when Jack was looking into supporting multi-node. I think you could remove it, and it should be okay as long as all of the remote readers are on the same node.

Reply: Why should all remote readers be on the same node?

Reply: Right now that is required because all of the remote readers will share one channel. But with your PR, we can create multiple channels, one for each unique reader node.

Reply: Got it. I will update my PR. Currently, it only creates one channel.

Reply: Yes, this is currently something I need to work on. The hope is that this PR will make it easy to add support for readers across multiple nodes, though you don't need to explicitly support that in this PR; I will take care of it once this PR is merged.

Reply: Chatted with @jackhumphries offline. He will work on this part. I will not include the multi-node support in this PR.
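To make the "one channel per unique reader node" idea from this thread concrete, here is a minimal sketch of how readers could be grouped so that readers co-located on a node share a single channel. The `get_node_id` and `create_channel` helpers are assumptions standing in for whatever the compiled-DAG backend actually uses; this is not the PR's code:

    from collections import defaultdict

    def create_channels_per_reader_node(writer, readers, get_node_id, create_channel):
        """Group readers by the node they live on and allocate one channel
        per unique node (hypothetical helpers, illustrative only)."""
        readers_by_node = defaultdict(list)
        for reader in readers:
            readers_by_node[get_node_id(reader)].append(reader)

        # One shared-memory channel per unique reader node; all readers on
        # that node read from the same channel.
        return {
            node_id: create_channel(writer, node_readers)
            for node_id, node_readers in readers_by_node.items()
        }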

            def _get_node_id(self):
                return ray.get_runtime_context().get_node_id()

            if isinstance(readers[0].dag_node, MultiOutputNode):
                assert len(readers) == 1
                # This node is a multi-output node, which means that it will only be
                # read by the driver, not an actor. Thus, we handle this case by
                # setting `reader_handles` to `[None]`.
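The branch above can be condensed into the following sketch. `resolve_reader_handles` is a hypothetical helper written for this explanation, and `_get_actor_handle` is assumed here to fetch the reader's owning actor; neither is asserted to match the PR's exact code:

    from ray.dag import MultiOutputNode

    def resolve_reader_handles(readers):
        # Hypothetical condensation of the branch above, not the PR's code.
        if isinstance(readers[0].dag_node, MultiOutputNode):
            # The output is consumed by the driver, not an actor, so there
            # is no actor handle to record; a None placeholder is used.
            return [None]
        # Otherwise every reader is an actor method call; record its handle.
        return [reader.dag_node._get_actor_handle() for reader in readers]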
@@ -646,6 +646,119 @@ async def main():
        compiled_dag.teardown()

class TestCompositeChannel:
    def test_composite_channel_one_actor(self, ray_start_regular_shared):
        """
        In this test, there are three 'inc' tasks on the same Ray actor, chained
        together. Therefore, the DAG will look like this:

        Driver -> a.inc -> a.inc -> a.inc -> Driver

        All communication between the driver and the actor will be done through remote
        channels, i.e., shared memory channels. All communication between the actor
        tasks will be conducted through local channels, i.e., IntraProcessChannel in
        this case.

        To elaborate, all output channels of the actor DAG nodes will be
        CompositeChannel, and the first two will have a local channel, while the last
        one will have a remote channel.
        """

Review comment: Why do we use CompositeChannels here? Isn't there only one reader per channel?

Reply: Yes. In this PR, the ...
        a = Actor.remote(0)
        with InputNode() as inp:
            dag = a.inc.bind(inp)
            dag = a.inc.bind(dag)
            dag = a.inc.bind(dag)

        compiled_dag = dag.experimental_compile()
        output_channel = compiled_dag.execute(1)
        result = output_channel.begin_read()
        assert result == 4
        output_channel.end_read()

        output_channel = compiled_dag.execute(2)
        result = output_channel.begin_read()
        assert result == 24
        output_channel.end_read()

        output_channel = compiled_dag.execute(3)
        result = output_channel.begin_read()
        assert result == 108
        output_channel.end_read()

        compiled_dag.teardown()
    def test_composite_channel_two_actors(self, ray_start_regular_shared):
        """
        In this test, there are three 'inc' tasks on two Ray actors, chained
        together. Therefore, the DAG will look like this:

        Driver -> a.inc -> b.inc -> a.inc -> Driver

        All communication between the driver and actors will be done through remote
        channels. Also, all communication between the actor tasks will be conducted
        through remote channels, i.e., shared memory channels in this case, because
        no consecutive tasks are on the same actor.
        """
        a = Actor.remote(0)
        b = Actor.remote(100)
        with InputNode() as inp:
            dag = a.inc.bind(inp)
            dag = b.inc.bind(dag)
            dag = a.inc.bind(dag)

        # a: 0+1 -> b: 100+1 -> a: 1+101
        compiled_dag = dag.experimental_compile()
        output_channel = compiled_dag.execute(1)
        result = output_channel.begin_read()
        assert result == 102
        output_channel.end_read()

        # a: 102+2 -> b: 101+104 -> a: 104+205
        output_channel = compiled_dag.execute(2)
        result = output_channel.begin_read()
        assert result == 309
        output_channel.end_read()

        # a: 309+3 -> b: 205+312 -> a: 312+517
        output_channel = compiled_dag.execute(3)
        result = output_channel.begin_read()
        assert result == 829
        output_channel.end_read()

        compiled_dag.teardown()
    def test_composite_channel_multi_output(self, ray_start_regular_shared):
        """
        Driver -> a.inc -> a.inc ---> Driver
                        |         |
                        -> b.inc -

        All communication in this DAG will be done through CompositeChannel.
        Under the hood, the communication between two `a.inc` tasks will
        be done through a local channel, i.e., IntraProcessChannel in this
        case, while the communication between `a.inc` and `b.inc` will be
        done through a shared memory channel.
        """
        a = Actor.remote(0)
        b = Actor.remote(100)
        with InputNode() as inp:
            dag = a.inc.bind(inp)
            dag = MultiOutputNode([a.inc.bind(dag), b.inc.bind(dag)])

        compiled_dag = dag.experimental_compile()
        output_channel = compiled_dag.execute(1)
        result = output_channel.begin_read()
        assert result == [2, 101]
        output_channel.end_read()

        output_channel = compiled_dag.execute(3)
        result = output_channel.begin_read()
        assert result == [10, 106]
        output_channel.end_read()

        compiled_dag.teardown()


if __name__ == "__main__":
    if os.environ.get("PARALLEL_CI"):
        sys.exit(pytest.main(["-n", "auto", "--boxed", "-vs", __file__]))
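The CompositeChannel behavior these tests describe can be pictured with a small model. The class below is an illustrative sketch written for this explanation; the constructor signature, the `reader_id` parameter, and the routing logic are assumptions, not the PR's actual implementation:

    from typing import Any, Dict

    class CompositeChannelSketch:
        """Illustrative model of a channel that routes a single writer's
        output to several readers, where each reader may sit behind a
        different underlying channel (e.g., IntraProcessChannel for readers
        in the writer's process, a shared-memory channel for remote ones)."""

        def __init__(self, channels_by_reader: Dict[str, Any]):
            # Maps a reader id to the underlying channel that reader uses.
            self._channels_by_reader = channels_by_reader

        def write(self, value: Any) -> None:
            # A single logical write fans out to every distinct channel.
            for channel in set(self._channels_by_reader.values()):
                channel.write(value)

        def begin_read(self, reader_id: str) -> Any:
            # Each reader reads only from the channel registered for it.
            return self._channels_by_reader[reader_id].begin_read()

        def end_read(self, reader_id: str) -> None:
            self._channels_by_reader[reader_id].end_read()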
@@ -0,0 +1,65 @@
import uuid
from typing import Any, Optional

import ray
from ray.experimental.channel import ChannelContext
from ray.experimental.channel.common import ChannelInterface
from ray.util.annotations import PublicAPI

@PublicAPI(stability="alpha")
class IntraProcessChannel(ChannelInterface):
    """
    IntraProcessChannel is a channel for communication between two tasks in the same
    worker process. It writes data directly to the worker's _SerializationContext
    and reads data from the _SerializationContext to avoid the serialization
    overhead and the need for reading/writing from shared memory.

    Args:
        actor_handle: The actor handle of the worker process.
    """

    def __init__(
        self,
        actor_handle: ray.actor.ActorHandle,
        _channel_id: Optional[str] = None,
    ):
        # TODO (kevin85421): Currently, if we don't pass `actor_handle` to
        # `IntraProcessChannel`, the actor will die because the reference
        # count of `actor_handle` is 0. We should fix this issue in the future.
        self._actor_handle = actor_handle

Review comment: This should be fixed by #45699, btw.
        # Generate a unique ID for the channel. The writer and reader will use
        # this ID to store and retrieve data from the _SerializationContext.
        self._channel_id = _channel_id
        if self._channel_id is None:
            self._channel_id = str(uuid.uuid4())

    def ensure_registered_as_writer(self) -> None:
        pass

Review comment: Might be good to check here that ...

Reply: I will remove ...

    def ensure_registered_as_reader(self) -> None:
        pass
    def __reduce__(self):
        return IntraProcessChannel, (
            self._actor_handle,
            self._channel_id,
        )

    def write(self, value: Any):
        # Because both the reader and writer are in the same worker process,
        # we can directly store the data in the context instead of storing
        # it in the channel object. This removes the serialization overhead
        # of `value`.
        ctx = ChannelContext.get_current().serialization_context
        ctx.set_data(self._channel_id, value)

    def begin_read(self) -> Any:
        ctx = ChannelContext.get_current().serialization_context
        return ctx.get_data(self._channel_id)

    def end_read(self):
        pass

    def close(self) -> None:
        ctx = ChannelContext.get_current().serialization_context
        ctx.reset_data(self._channel_id)
Review comment: This function doesn't have an argument `buffer_size_bytes`.
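For intuition, here is a minimal self-contained sketch of the set_data/get_data/reset_data contract that IntraProcessChannel relies on. The context class below is a stand-in written for this example, not Ray's actual _SerializationContext:

    from typing import Any, Dict

    class FakeSerializationContext:
        """Stand-in for the per-worker context IntraProcessChannel writes
        to; only the three methods used by the channel are modeled here."""

        def __init__(self) -> None:
            self._data: Dict[str, Any] = {}

        def set_data(self, channel_id: str, value: Any) -> None:
            # Store the value by reference; no serialization happens
            # because the reader lives in the same process.
            self._data[channel_id] = value

        def get_data(self, channel_id: str) -> Any:
            return self._data[channel_id]

        def reset_data(self, channel_id: str) -> None:
            self._data.pop(channel_id, None)


    ctx = FakeSerializationContext()
    ctx.set_data("channel-1", {"tensor": [1, 2, 3]})
    assert ctx.get_data("channel-1") == {"tensor": [1, 2, 3]}
    ctx.reset_data("channel-1")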