Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[core][compiled-graphs] Don't persist input_nodes in _CollectiveOperation to avoid wrong understanding about DAGs #48463

Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions python/ray/dag/collective_node.py
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we only scan certain fields for dag_node.py::_collect_upstream_nodes(), say not all dict entries of other_args_to_resolve, but exclude COLLECTIVE_OPERATION_KEY?

I think in any case we should mention in the docstring of dag_node.py::_collect_upstream_nodes() its assumptions. Currently the assumption is all nodes appear in the following are considered as upstream nodes:

                self._bound_args,
                self._bound_kwargs,
                self._bound_other_args_to_resolve,

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think in any case we should mention in the docstring of dag_node.py::_collect_upstream_nodes() its assumptions. Currently the assumption is all nodes appear in the following are considered as upstream nodes:

Agree on this. It sounds ok to not scan everything in other_args_to_resolve.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have considered skipping COLLECTIVE_OPERATION_KEY. My concern is that it seems a bit odd for a basic class (DAGNode) to implement logic from other classes built on top of it. Additionally, the code path applies to both DAG and ADAG. I am a bit worried about the complexity in the future if we add more and more ADAG-specific logic inside the shared code path. HDYT?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you two think skipping the key is better, I will update the PR.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or perhaps add a new field _bound_other_args_not_to_resolve and avoid scanning it?

Copy link
Contributor

@dengwxn dengwxn Nov 1, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the expected upstream nodes for a CollectiveOutputNode? Are they all the input nodes from all the actors, or simply the only one input node from the same actor?

This could potentially cause issues when compiling the graph.

What are the potential issues?

Copy link
Contributor

@ruisearch42 ruisearch42 Nov 1, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I chatted with Kaihsun a bit yesterday, but +1 to Weixin's question.

I think the key issue is what's the definition of "upstream nodes", especially in the special case of collectives mentioned above. This definition needs to make sense based on how we use them in DAG and ADAG. Once this is clarified, we know what should be the right thing to do. @kevin85421 Can we define that?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the expected upstream nodes for a CollectiveOutputNode? Are they all the input nodes from all the actors, or simply the only one input node from the same actor?
What are the potential issues?

I think for now the upstream nodes for a CollectiveOutputNode should be the args of the DAGNode so that DAG and ADAG can have the same understanding for the same graph.

For example, compiled_dag_node.py sets up the upstream/downstream relationship inside preprocess by treating args as a DAGNode's upstream nodes.

However, in dag_node.py, all DAGNodes inside self._bound_args, self._bound_kwargs, and self._bound_other_args_to_resolve are considered as the upstream nodes.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sync offline with @ruisearch42 : update the comments, and open an issue to track the progress #48520.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Our conclusion is introducing a new field only when we observe more and more issues are caused by the inconsistency.

Original file line number Diff line number Diff line change
Expand Up @@ -37,22 +37,21 @@ def __init__(
op: _CollectiveOp,
transport: Optional[Union[str, GPUCommunicator]] = None,
):
self._input_nodes: List[DAGNode] = input_nodes
if len(self._input_nodes) == 0:
if len(input_nodes) == 0:
raise ValueError("Expected input nodes for a collective operation")
if len(set(self._input_nodes)) != len(self._input_nodes):
if len(set(input_nodes)) != len(input_nodes):
raise ValueError("Expected unique input nodes for a collective operation")

self._actor_handles: List["ray.actor.ActorHandle"] = []
for input_node in self._input_nodes:
for input_node in input_nodes:
actor_handle = input_node._get_actor_handle()
if actor_handle is None:
raise ValueError("Expected an actor handle from the input node")
self._actor_handles.append(actor_handle)
if len(set(self._actor_handles)) != len(self._actor_handles):
invalid_input_nodes = [
input_node
for input_node in self._input_nodes
for input_node in input_nodes
if self._actor_handles.count(input_node._get_actor_handle()) > 1
]
raise ValueError(
Expand All @@ -76,7 +75,6 @@ def __init__(
def __str__(self) -> str:
return (
f"CollectiveGroup("
f"_input_nodes={self._input_nodes}, "
f"_actor_handles={self._actor_handles}, "
f"_op={self._op}, "
f"_type_hint={self._type_hint})"
Expand Down