|
4 | 4 | from ray.util.annotations import DeveloperAPI
|
5 | 5 | import copy
|
6 | 6 |
|
| 7 | +from itertools import chain |
| 8 | + |
7 | 9 | from typing import (
|
8 | 10 | Optional,
|
9 | 11 | Union,
|
@@ -60,28 +62,33 @@ def __init__(
|
60 | 62 | )
|
61 | 63 |
|
62 | 64 | # The list of nodes that use this DAG node as input.
|
63 |
| - self._downstream_nodes: List[DAGNode] = [] |
| 65 | + self._downstream_nodes: List["DAGNode"] = [] |
64 | 66 | # The list of nodes that this DAG node uses as input.
|
| 67 | + self._upstream_nodes: List["DAGNode"] = self._prepare_upstream_nodes() |
| 68 | + |
| 69 | + # UUID that is not changed over copies of this node. |
| 70 | + self._stable_uuid = uuid.uuid4().hex |
| 71 | + # Cached values from last call to execute() |
| 72 | + self.cache_from_last_execute = {} |
| 73 | + |
| 74 | + self._type_hint: Optional[ChannelOutputType] = ChannelOutputType() |
| 75 | + self.is_output_node = False |
| 76 | + |
| 77 | + def _prepare_upstream_nodes(self) -> List["DAGNode"]: |
| 78 | + """Retrieve upstream nodes and update their downstream dependencies.""" |
65 | 79 | scanner = _PyObjScanner()
|
66 |
| - self._upstream_nodes: List[DAGNode] = scanner.find_nodes( |
| 80 | + upstream_nodes: List["DAGNode"] = scanner.find_nodes( |
67 | 81 | [
|
68 | 82 | self._bound_args,
|
69 | 83 | self._bound_kwargs,
|
70 | 84 | self._bound_other_args_to_resolve,
|
71 | 85 | ]
|
72 | 86 | )
|
73 | 87 | scanner.clear()
|
74 |
| - # Update upstream dependencies. |
75 |
| - for upstream_node in self._upstream_nodes: |
| 88 | + # Update dependencies. |
| 89 | + for upstream_node in upstream_nodes: |
76 | 90 | upstream_node._downstream_nodes.append(self)
|
77 |
| - |
78 |
| - # UUID that is not changed over copies of this node. |
79 |
| - self._stable_uuid = uuid.uuid4().hex |
80 |
| - # Cached values from last call to execute() |
81 |
| - self.cache_from_last_execute = {} |
82 |
| - |
83 |
| - self._type_hint: Optional[ChannelOutputType] = ChannelOutputType() |
84 |
| - self.is_output_node = False |
| 91 | + return upstream_nodes |
85 | 92 |
|
86 | 93 | def with_type_hint(self, typ: ChannelOutputType):
|
87 | 94 | if typ.is_direct_return:
|
@@ -332,9 +339,11 @@ def __init__(self, fn):
|
332 | 339 | self.input_node_uuid = None
|
333 | 340 |
|
334 | 341 | def __call__(self, node: "DAGNode"):
|
| 342 | + from ray.dag.input_node import InputNode |
| 343 | + |
335 | 344 | if node._stable_uuid not in self.cache:
|
336 | 345 | self.cache[node._stable_uuid] = self.fn(node)
|
337 |
| - if type(node).__name__ == "InputNode": |
| 346 | + if isinstance(node, InputNode): |
338 | 347 | if not self.input_node_uuid:
|
339 | 348 | self.input_node_uuid = node._stable_uuid
|
340 | 349 | elif self.input_node_uuid != node._stable_uuid:
|
@@ -368,14 +377,20 @@ def bfs(self, fn: "Callable[[DAGNode], T]"):
|
368 | 377 | # in some invalid cases, some nodes may not be descendants of the
|
369 | 378 | # root. Therefore, we also add upstream nodes to the queue so that
|
370 | 379 | # a meaningful error message can be raised when the DAG is compiled.
|
371 |
| - for node in node._downstream_nodes + node._upstream_nodes: |
| 380 | + for node in chain.from_iterable( |
| 381 | + [node._downstream_nodes, node._upstream_nodes] |
| 382 | + ): |
372 | 383 | if node not in visited:
|
373 | 384 | queue.append(node)
|
374 | 385 |
|
375 | 386 | def _find_root(self) -> "DAGNode":
|
376 |
| - """Return the root node of the DAG.""" |
| 387 | + """ |
| 388 | + Return the root node of the DAG. The root node must be an InputNode. |
| 389 | + """ |
| 390 | + from ray.dag.input_node import InputNode |
| 391 | + |
377 | 392 | root = self
|
378 |
| - while type(root).__name__ != "InputNode": |
| 393 | + while not isinstance(root, InputNode): |
379 | 394 | if len(root._upstream_nodes) == 0:
|
380 | 395 | raise ValueError("No InputNode found in the DAG.")
|
381 | 396 | root = root._upstream_nodes[0]
|
|
0 commit comments