Skip to content

Commit 54df435

Browse files
committed
address comments
Signed-off-by: Kai-Hsun Chen <kaihsun@anyscale.com>
1 parent b990685 commit 54df435

File tree

2 files changed

+39
-17
lines changed

2 files changed

+39
-17
lines changed

python/ray/dag/dag_node.py

+31-16
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from ray.util.annotations import DeveloperAPI
55
import copy
66

7+
from itertools import chain
8+
79
from typing import (
810
Optional,
911
Union,
@@ -60,28 +62,33 @@ def __init__(
6062
)
6163

6264
# The list of nodes that use this DAG node as input.
63-
self._downstream_nodes: List[DAGNode] = []
65+
self._downstream_nodes: List["DAGNode"] = []
6466
# The list of nodes that this DAG node uses as input.
67+
self._upstream_nodes: List["DAGNode"] = self._prepare_upstream_nodes()
68+
69+
# UUID that is not changed over copies of this node.
70+
self._stable_uuid = uuid.uuid4().hex
71+
# Cached values from last call to execute()
72+
self.cache_from_last_execute = {}
73+
74+
self._type_hint: Optional[ChannelOutputType] = ChannelOutputType()
75+
self.is_output_node = False
76+
77+
def _prepare_upstream_nodes(self) -> List["DAGNode"]:
78+
"""Retrieve upstream nodes and update their downstream dependencies."""
6579
scanner = _PyObjScanner()
66-
self._upstream_nodes: List[DAGNode] = scanner.find_nodes(
80+
upstream_nodes: List["DAGNode"] = scanner.find_nodes(
6781
[
6882
self._bound_args,
6983
self._bound_kwargs,
7084
self._bound_other_args_to_resolve,
7185
]
7286
)
7387
scanner.clear()
74-
# Update upstream dependencies.
75-
for upstream_node in self._upstream_nodes:
88+
# Update dependencies.
89+
for upstream_node in upstream_nodes:
7690
upstream_node._downstream_nodes.append(self)
77-
78-
# UUID that is not changed over copies of this node.
79-
self._stable_uuid = uuid.uuid4().hex
80-
# Cached values from last call to execute()
81-
self.cache_from_last_execute = {}
82-
83-
self._type_hint: Optional[ChannelOutputType] = ChannelOutputType()
84-
self.is_output_node = False
91+
return upstream_nodes
8592

8693
def with_type_hint(self, typ: ChannelOutputType):
8794
if typ.is_direct_return:
@@ -332,9 +339,11 @@ def __init__(self, fn):
332339
self.input_node_uuid = None
333340

334341
def __call__(self, node: "DAGNode"):
342+
from ray.dag.input_node import InputNode
343+
335344
if node._stable_uuid not in self.cache:
336345
self.cache[node._stable_uuid] = self.fn(node)
337-
if type(node).__name__ == "InputNode":
346+
if isinstance(node, InputNode):
338347
if not self.input_node_uuid:
339348
self.input_node_uuid = node._stable_uuid
340349
elif self.input_node_uuid != node._stable_uuid:
@@ -368,14 +377,20 @@ def bfs(self, fn: "Callable[[DAGNode], T]"):
368377
# in some invalid cases, some nodes may not be descendants of the
369378
# root. Therefore, we also add upstream nodes to the queue so that
370379
# a meaningful error message can be raised when the DAG is compiled.
371-
for node in node._downstream_nodes + node._upstream_nodes:
380+
for node in chain.from_iterable(
381+
[node._downstream_nodes, node._upstream_nodes]
382+
):
372383
if node not in visited:
373384
queue.append(node)
374385

375386
def _find_root(self) -> "DAGNode":
376-
"""Return the root node of the DAG."""
387+
"""
388+
Return the root node of the DAG. The root node must be an InputNode.
389+
"""
390+
from ray.dag.input_node import InputNode
391+
377392
root = self
378-
while type(root).__name__ != "InputNode":
393+
while not isinstance(root, InputNode):
379394
if len(root._upstream_nodes) == 0:
380395
raise ValueError("No InputNode found in the DAG.")
381396
root = root._upstream_nodes[0]

python/ray/dag/py_obj_scanner.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,14 @@ def reducer_override(self, obj):
7070
return super().reducer_override(obj)
7171

7272
def find_nodes(self, obj: Any) -> List[SourceType]:
73-
"""Find top-level DAGNodes."""
73+
"""
74+
Serialize `obj` and store all instances of `source_type` found in `_found`.
75+
76+
Args:
77+
obj: The object to scan for `source_type`.
78+
Returns:
79+
A list of all instances of `source_type` found in `obj`.
80+
"""
7481
assert (
7582
self._found is None
7683
), "find_nodes cannot be called twice on the same PyObjScanner instance."

0 commit comments

Comments
 (0)