Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ADAG] Add visualization of compiled graphs #47958

Merged
merged 16 commits into from
Oct 24, 2024
58 changes: 58 additions & 0 deletions python/ray/dag/compiled_dag_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -1887,6 +1887,64 @@ async def execute_async(
self._execution_index += 1
return fut

def visualize(self, filename="compiled_dag", format="png", view=False):
Bye-legumes marked this conversation as resolved.
Show resolved Hide resolved
kevin85421 marked this conversation as resolved.
Show resolved Hide resolved
"""
Visualize the compiled DAG using Graphviz.
Bye-legumes marked this conversation as resolved.
Show resolved Hide resolved

Args:
filename: The name of the output file (without extension).
format: The format of the output file (e.g., 'png', 'pdf').
view: Whether to open the file with the default viewer.
"""
import graphviz
from ray.dag import (
InputAttributeNode,
InputNode,
)

dot = graphviz.Digraph(name="CompiledDAG", format=format)

# Add nodes with task information
for idx, task in self.idx_to_task.items():
dag_node = task.dag_node

# Initialize the label
label = f"Task {idx}\n"

# Handle different types of dag_node
if isinstance(dag_node, InputNode):
label += "InputNode"
elif isinstance(dag_node, InputAttributeNode):
label += f"InputAttributeNode[{dag_node.key}]"
elif hasattr(dag_node, "get_method_name"):
Bye-legumes marked this conversation as resolved.
Show resolved Hide resolved
method_name_attr = getattr(dag_node, "get_method_name")
if callable(method_name_attr):
method_name = method_name_attr()
else:
method_name = method_name_attr

# Get actor ID if applicable
actor_handle = dag_node._get_actor_handle()
if actor_handle:
Bye-legumes marked this conversation as resolved.
Show resolved Hide resolved
Bye-legumes marked this conversation as resolved.
Show resolved Hide resolved
actor_id = actor_handle._actor_id.hex()
label += f"Actor: {actor_id[:6]}...\nMethod: {method_name}"
else:
label += f"Method: {method_name}"
else:
label += type(dag_node).__name__
Bye-legumes marked this conversation as resolved.
Show resolved Hide resolved

# Add the node to the graph
dot.node(str(idx), label)

# Add edges based on downstream tasks
for idx, task in self.idx_to_task.items():
for downstream_idx, _ in task.downstream_task_idxs.items():
# You can also include edge labels with channel types or other info
dot.edge(str(idx), str(downstream_idx))
Bye-legumes marked this conversation as resolved.
Show resolved Hide resolved

# Render the graph to a file
dot.render(filename, view=view)

def teardown(self):
"""Teardown and cancel all actor tasks for this DAG. After this
function returns, the actors should be available to execute new tasks
Expand Down