diff --git a/python/ray/dag/BUILD b/python/ray/dag/BUILD index 013cc584cd1dd..72c8b4b38b22f 100644 --- a/python/ray/dag/BUILD +++ b/python/ray/dag/BUILD @@ -105,6 +105,7 @@ py_test_module_list( "tests/experimental/test_detect_deadlock_dag.py", "tests/experimental/test_multi_node_dag.py", "tests/experimental/test_torch_tensor_dag.py", + "tests/experimental/test_execution_schedule.py", ], tags = [ "accelerated_dag", diff --git a/python/ray/dag/compiled_dag_node.py b/python/ray/dag/compiled_dag_node.py index f4c15060820b1..c2bff74298326 100644 --- a/python/ray/dag/compiled_dag_node.py +++ b/python/ray/dag/compiled_dag_node.py @@ -37,6 +37,14 @@ _destroy_nccl_group, ) +from ray.dag.dag_node_operation import ( + _DAGNodeOperation, + _DAGNodeOperationType, + _DAGOperationGraphNode, + _build_dag_node_operation_graph, + _generate_actor_to_execution_schedule, +) + from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy @@ -84,30 +92,29 @@ def do_allocate_channel( def do_exec_tasks( self, tasks: List["ExecutableTask"], + schedule: List[_DAGNodeOperation], ) -> None: - """Generic actor method to begin executing the tasks belonging to an actor. - This runs an infinite loop to run each task in turn (following the order specified - in the list): reading input channel(s), executing the given taks, and writing output - channel(s). It only exits if the actor dies or an exception is thrown. + """A generic actor method to begin executing the operations belonging to an + actor. This runs an infinite loop to execute each _DAGNodeOperation in the + order specified by the schedule. It exits only if the actor dies or an + exception is thrown. Args: tasks: the executable tasks corresponding to the actor methods. + schedule: A list of _DAGNodeOperation that should be executed in order. """ try: - self._input_readers = [] - self._output_writers = [] for task in tasks: - _prep_task(self, task) + task.prepare() done = False while True: if done: break - for idx, task in enumerate(tasks): - done = _exec_task(self, task, idx) + for operation in schedule: + done = tasks[operation.local_idx].exec_operation(self, operation.type) if done: break - except Exception: logging.exception("Compiled DAG task exited with exception") raise @@ -115,26 +122,8 @@ def do_exec_tasks( @DeveloperAPI def do_cancel_executable_tasks(self, tasks: List["ExecutableTask"]) -> None: - for idx in range(len(tasks)): - self._input_readers[idx].close() - self._output_writers[idx].close() - - -def _prep_task(self, task: "ExecutableTask") -> None: - """ - Prepare the task for execution. - """ - for typ_hint in task.input_type_hints: - typ_hint.register_custom_serializer() - task.output_type_hint.register_custom_serializer() - - input_reader: ReaderInterface = SynchronousReader(task.input_channels) - output_writer: WriterInterface = SynchronousWriter(task.output_channel) - self._input_readers.append(input_reader) - self._output_writers.append(output_writer) - - input_reader.start() - output_writer.start() + for task in tasks: + task.cancel() def _wrap_exception(exc): @@ -150,56 +139,6 @@ def _wrap_exception(exc): return wrapped -def _exec_task(self, task: "ExecutableTask", idx: int) -> bool: - """ - Execute the task. - Args: - task: The task to execute. - idx: The index of the task in the list of tasks of the actor. - Returns: - True if we are done executing all tasks of this actor, False otherwise. 
- """ - # TODO: for cases where output is passed as input to a task on - # the same actor, introduce a "IntraProcessChannel" to avoid the overhead - # of serialization/deserialization and synchronization. - method = getattr(self, task.method_name) - input_reader = self._input_readers[idx] - output_writer = self._output_writers[idx] - res = None - try: - res = input_reader.read() - except RayChannelError: - # Channel closed. Exit the loop. - return True - - try: - _process_return_vals(res, return_single_output=False) - except Exception as exc: - # Previous task raised an application-level exception. - # Propagate it and skip the actual task. We don't need to wrap the - # exception in a RayTaskError here because it has already been wrapped - # by the previous task. - output_writer.write(exc) - return False - - resolved_inputs = [] - for task_input in task.task_inputs: - resolved_inputs.append(task_input.resolve(res)) - - try: - output_val = method(*resolved_inputs, **task.resolved_kwargs) - except Exception as exc: - output_val = _wrap_exception(exc) - - try: - output_writer.write(output_val) - except RayChannelError: - # Channel closed. Exit the loop. - return True - - return False - - @DeveloperAPI class CompiledTask: """Wraps the normal Ray DAGNode with some metadata.""" @@ -354,6 +293,9 @@ def __init__( self.input_channels: List[ChannelInterface] = [] self.task_inputs: List[_ExecutableTaskInput] = [] self.resolved_kwargs: Dict[str, Any] = resolved_kwargs + # A unique index which can be used to index into `idx_to_task` to get + # the corresponding task. + self.dag_idx = task.idx # Reverse map for input_channels: maps an input channel to # its index in input_channels. @@ -386,6 +328,152 @@ def __init__( assert not isinstance(val, ChannelInterface) assert not isinstance(val, DAGInputAdapter) + # Input reader to read input data from upstream DAG nodes. + self.input_reader: ReaderInterface = SynchronousReader(self.input_channels) + # Output writer to write output data to downstream DAG nodes. + self.output_writer: WriterInterface = SynchronousWriter(self.output_channel) + # Store the intermediate result of a READ or COMPUTE operation. + # The result of a READ operation will be used by a COMPUTE operation, + # and the result of a COMPUTE operation will be used by a WRITE operation. + self._intermediate_buffer: Any = None + + def cancel(self): + """ + Close all the input channels and the output channel. The exact behavior + depends on the type of channel. Typically, it will release the resources + used by the channels. + """ + self.input_reader.close() + self.output_writer.close() + + def prepare(self): + """ + Prepare the task for execution. The `exec_operation` function can only + be called after `prepare` has been called. + """ + for typ_hint in self.input_type_hints: + typ_hint.register_custom_serializer() + self.output_type_hint.register_custom_serializer() + self.input_reader.start() + self.output_writer.start() + + def set_intermediate_buffer(self, data: Any): + """ + Store the intermediate result of a READ or COMPUTE operation. + + Args: + data: The intermediate result of a READ or COMPUTE operation. + """ + assert self._intermediate_buffer is None + self._intermediate_buffer = data + + def reset_intermediate_buffer(self) -> Any: + """ + Retrieve the intermediate result of a READ or COMPUTE operation, + and reset the intermediate buffer to None. + + Returns: + The intermediate result of a READ or COMPUTE operation. 
+ """ + data = self._intermediate_buffer + self._intermediate_buffer = None + return data + + def _read(self) -> bool: + """ + Read input data from upstream DAG nodes and cache the intermediate result. + + Returns: + True if system error occurs and exit the loop; otherwise, False. + """ + assert self._intermediate_buffer is None + exit = False + try: + input_data = self.input_reader.read() + self.set_intermediate_buffer(input_data) + except RayChannelError: + # Channel closed. Exit the loop. + exit = True + return exit + + def _compute(self, class_handle) -> bool: + """ + Retrieve the intermediate result from the READ operation and perform the + computation. Then, cache the new intermediate result. The caller must ensure + that the last operation executed is READ so that the function retrieves the + correct intermediate result. + + Args: + class_handle: An instance of the class to which the actor belongs. For + example, the type of `class_handle` is if the + actor belongs to the `class Worker` class. + + Returns: + True if system error occurs and exit the loop; otherwise, False. + """ + input_data = self.reset_intermediate_buffer() + method = getattr(class_handle, self.method_name) + try: + _process_return_vals(input_data, return_single_output=False) + except Exception as exc: + # Previous task raised an application-level exception. + # Propagate it and skip the actual task. We don't need to wrap the + # exception in a RayTaskError here because it has already been wrapped + # by the previous task. + self.set_intermediate_buffer(exc) + return False + + resolved_inputs = [] + for task_input in self.task_inputs: + resolved_inputs.append(task_input.resolve(input_data)) + + try: + output_val = method(*resolved_inputs, **self.resolved_kwargs) + except Exception as exc: + output_val = _wrap_exception(exc) + self.set_intermediate_buffer(output_val) + return False + + def _write(self) -> bool: + """ + Retrieve the intermediate result from the COMPUTE operation and write to its + downstream DAG nodes. The caller must ensure that the last operation executed + is COMPUTE so that the function retrieves the correct intermediate result. + + Returns: + True if system error occurs and exit the loop; otherwise, False. + """ + output_val = self.reset_intermediate_buffer() + exit = False + try: + self.output_writer.write(output_val) + except RayChannelError: + # Channel closed. Exit the loop. + exit = True + return exit + + def exec_operation(self, class_handle, op_type: _DAGNodeOperationType) -> bool: + """ + An ExecutableTask corresponds to a DAGNode. It consists of three + operations: READ, COMPUTE, and WRITE, which should be executed in + order to ensure that each operation can read the correct intermediate + result. + + Args: + class_handle: The handle of the class to which the actor belongs. + op_type: The type of the operation. Possible types are READ, + COMPUTE, and WRITE. + + Returns: + True if the next operation should not be executed; otherwise, False. + """ + if op_type == _DAGNodeOperationType.READ: + return self._read() + elif op_type == _DAGNodeOperationType.COMPUTE: + return self._compute(class_handle) + elif op_type == _DAGNodeOperationType.WRITE: + return self._write() + @DeveloperAPI class CompiledDAG: @@ -522,6 +610,11 @@ def __init__( self.actor_to_executable_tasks: Dict[ "ray.actor.ActorHandle", List["ExecutableTask"] ] = {} + # Mapping from the actor handle to the execution schedule which is a list + # of operations to be executed. 
+ self.actor_to_execution_schedule: Dict[ + "ray.actor.ActorHandle", List[_DAGNodeOperation] + ] = defaultdict(list) # Mapping from the actor handle to the node ID that the actor is on. self.actor_to_node_id: Dict["ray.actor.ActorHandle", str] = {} @@ -990,7 +1083,6 @@ def _get_or_compile( # Create executable tasks for each actor for actor_handle, tasks in self.actor_to_tasks.items(): executable_tasks = [] - worker_fn = None for task in tasks: resolved_args = [] has_at_least_one_channel_input = False @@ -1024,19 +1116,20 @@ def _get_or_compile( task.kwargs, ) executable_tasks.append(executable_task) - if worker_fn is None: - worker_fn = task.dag_node._get_remote_method("__ray_call__") # Sort executable tasks based on their bind index, i.e., submission order # so that they will be executed in that order. executable_tasks.sort(key=lambda task: task.bind_index) - self.actor_to_executable_tasks[actor_handle] = executable_tasks - # Assign the task with the correct input and output buffers. - self.worker_task_refs[ - task.dag_node._get_actor_handle() - ] = worker_fn.options(concurrency_group="_ray_system").remote( + + # Build an execution schedule for each actor + self.actor_to_execution_schedule = self._build_execution_schedule() + for actor_handle, executable_tasks in self.actor_to_executable_tasks.items(): + self.worker_task_refs[actor_handle] = actor_handle.__ray_call__.options( + concurrency_group="_ray_system" + ).remote( do_exec_tasks, executable_tasks, + self.actor_to_execution_schedule[actor_handle], ) self.dag_output_channels = [] @@ -1075,6 +1168,103 @@ def _get_or_compile( self._dag_submitter.start() self._dag_output_fetcher.start() + def _generate_dag_operation_graph_node( + self, + ) -> Dict["ray.actor.ActorHandle", List[List[_DAGOperationGraphNode]]]: + """ + Generate READ, COMPUTE, and WRITE operations for each DAG node. + + Returns: + A dictionary that maps an actor handle to a list of lists of + _DAGOperationGraphNode. For the same actor, the index of the + outer list corresponds to the index of the ExecutableTask in + the list of `executable_tasks` in `actor_to_executable_tasks`, + i.e. `local_idx`. In the inner list, the order of operations + is READ, COMPUTE, and WRITE. + + Example: + { + actor1: [ + [READ COMPUTE WRITE] # local_idx 0 + [READ COMPUTE WRITE] # local_idx 1 + ] + } + """ + assert self.idx_to_task + assert self.actor_to_executable_tasks + + actor_to_operation_nodes: Dict[ + "ray.actor.ActorHandle", List[List[_DAGOperationGraphNode]] + ] = defaultdict(list) + + for actor_handle, executable_tasks in self.actor_to_executable_tasks.items(): + for local_idx, exec_task in enumerate(executable_tasks): + # Divide a DAG node into three _DAGOperationGraphNodes: READ, COMPUTE, + # and WRITE. Each _DAGOperationGraphNode has a _DAGNodeOperation. 
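+ # All three nodes share the same `local_idx`, `dag_idx`, actor handle, and NCCL requirement; only the operation type differs.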
+ dag_idx = exec_task.dag_idx + dag_node = self.idx_to_task[dag_idx].dag_node + actor_handle = dag_node._get_actor_handle() + requires_nccl = dag_node.type_hint.requires_nccl() + + read_node = _DAGOperationGraphNode( + _DAGNodeOperation(local_idx, _DAGNodeOperationType.READ), + dag_idx, + actor_handle, + requires_nccl, + ) + compute_node = _DAGOperationGraphNode( + _DAGNodeOperation(local_idx, _DAGNodeOperationType.COMPUTE), + dag_idx, + actor_handle, + requires_nccl, + ) + write_node = _DAGOperationGraphNode( + _DAGNodeOperation(local_idx, _DAGNodeOperationType.WRITE), + dag_idx, + actor_handle, + requires_nccl, + ) + actor_to_operation_nodes[actor_handle].append( + [read_node, compute_node, write_node] + ) + return actor_to_operation_nodes + + def _build_execution_schedule(self): + """ + Generate an execution schedule for each actor. The schedule is a list of + _DAGNodeOperation. + + Step 1: Generate a DAG node operation graph. Refer to the functions + `_generate_dag_operation_graph_node` and `_build_dag_node_operation_graph` + for more details. + + Step 2: Topological sort + + It is possible to have multiple _DAGOperationGraphNodes with zero in-degree. + Refer to the function `_select_next_nodes` for the logic of selecting nodes. + + Then, put the selected nodes into the corresponding actors' schedules. + + The schedule should be intuitive to users, meaning that the execution should + perform operations in ascending order of `bind_index` as much as possible. + + [Example]: + + See `test_execution_schedule` for more examples. + + Returns: + actor_to_execution_schedule: A dictionary that maps an actor handle to + the execution schedule which is a list of operations to be executed. + """ + # Step 1: Build a graph of _DAGOperationGraphNode + actor_to_operation_nodes = self._generate_dag_operation_graph_node() + graph = _build_dag_node_operation_graph( + self.idx_to_task, actor_to_operation_nodes + ) + # Step 2: Generate an execution schedule for each actor using topological sort + actor_to_execution_schedule = _generate_actor_to_execution_schedule(graph) + return actor_to_execution_schedule + def _detect_deadlock(self) -> bool: """ Create a graph with the following 3 rules, and then use diff --git a/python/ray/dag/dag_node_operation.py b/python/ray/dag/dag_node_operation.py new file mode 100644 index 0000000000000..7492bc77a7b60 --- /dev/null +++ b/python/ray/dag/dag_node_operation.py @@ -0,0 +1,354 @@ +from functools import total_ordering +from enum import Enum +from typing import Set, Tuple, List, Dict +import ray +import heapq +from collections import defaultdict + + +class _DAGNodeOperationType(Enum): + """ + There are three types of operations that a DAG node can perform: + 1. READ: Read from an input channel. + 2. COMPUTE: Execute the method corresponding to the node. + 3. WRITE: Write to an output channel. + """ + + READ = "READ" + COMPUTE = "COMPUTE" + WRITE = "WRITE" + + +class _DAGNodeOperation: + def __init__( + self, + local_idx: int, + operation_type: _DAGNodeOperationType, + ): + """ + Args: + local_idx: The index of the task that this operation belongs to + in the actor's ExecutableTask list. The index is not the same + as bind_index because there may be more tasks bound to an actor + than tasks that appear in the current compiled DAG. + operation_type: The type of operation to perform. 
+ """ + self.local_idx = local_idx + self.type = operation_type + + +@total_ordering +class _DAGOperationGraphNode: + def __init__( + self, + operation: _DAGNodeOperation, + dag_idx: int, + actor_handle: "ray.actor.ActorHandle", + requires_nccl: bool, + ): + """ + _DAGOperationGraphNode represents a node in the DAG operation graph. + It contains information about the node's in-degree, out-degree, edges, + and the operation it performs. + + Args: + operation: The operation that this node performs. The operation + can be a READ, COMPUTE, or WRITE operation. + dag_idx: A unique index which can be used to index into + `CompiledDAG.idx_to_task` to get the corresponding task. + actor_handle: The actor handle to which this operation belongs. + requires_nccl: Whether this operation requires NCCL. + """ + self.operation = operation + self.dag_idx = dag_idx + self.actor_handle = actor_handle + self.requires_nccl = requires_nccl + # The in_edges and out_edges are sets of tuples. Each tuple contains + # an integer `dag_idx`, which can be used to index into `idx_to_task` + # to get the corresponding task, and a `_DAGNodeOperationType`, which can + # be READ, COMPUTE, or WRITE. + self.in_edges: Set[Tuple[int, _DAGNodeOperationType]] = set() + self.out_edges: Set[Tuple[int, _DAGNodeOperationType]] = set() + + @property + def in_degree(self) -> int: + return len(self.in_edges) + + def __lt__(self, other: "_DAGOperationGraphNode"): + """ + This function defines the order of the nodes in the priority queue used in + `_select_next_nodes`. The priority queue is a min-heap, so the node with + higher priority is considered "less than" the other node. + """ + # If two nodes belong to the same actor, select the one with + # the smaller `local_idx`. + if self.actor_handle == other.actor_handle: + return self.operation.local_idx < other.operation.local_idx + # If two nodes belong to different actors and one of them is an NCCL + # write node, select the one that is not an NCCL write node. + is_nccl_write = ( + self.operation.type == _DAGNodeOperationType.WRITE and self.requires_nccl + ) + other_is_nccl_write = ( + other.operation.type == _DAGNodeOperationType.WRITE and other.requires_nccl + ) + if is_nccl_write != other_is_nccl_write: + return not is_nccl_write + # If two nodes belong to different actors and both are either NCCL write + # nodes or neither are NCCL write nodes, select the one with the smaller + # `local_idx`. If they have the same `local_idx`, select the one with the + # smaller `dag_idx`. + if self.operation.local_idx != other.operation.local_idx: + return self.operation.local_idx < other.operation.local_idx + return self.dag_idx < other.dag_idx + + def __eq__(self, other: "_DAGOperationGraphNode"): + """ + Two operations are equal only when they have the same `local_idx` and `type` + and belong to the same actor. + """ + return ( + self.actor_handle == other.actor_handle + and self.operation.local_idx == other.operation.local_idx + and self.operation.type == other.operation.type + ) + + def __hash__(self): + """ + An operation is uniquely identified by its `dag_idx` and type. + """ + return hash((self.operation, self.dag_idx)) + + +def _add_edge(from_node: _DAGOperationGraphNode, to_node: _DAGOperationGraphNode): + """ + Add an edge from `from_node` to `to_node`. An edge is a tuple of + the operation's `dag_idx` and type. 
+ """ + from_node.out_edges.add((to_node.dag_idx, to_node.operation.type)) + to_node.in_edges.add((from_node.dag_idx, from_node.operation.type)) + + +def _select_next_nodes( + actor_to_candidates: Dict["ray._raylet.ActorID", List[_DAGOperationGraphNode]], + graph: Dict[int, Dict[_DAGNodeOperationType, _DAGOperationGraphNode]], +): + """ + This function selects the next nodes for topological sort to generate execution + schedule. If there are multiple candidate _DAGOperationGraphNodes, select the node + with the top priority based on the following rules: + + #1 If two candidate nodes belong to the same actor, select the one with + the smaller `local_idx`. + + #2 If two candidate nodes belong to different actors and both are either NCCL + write nodes or neither are NCCL write nodes, select the one with the smaller + `local_idx`. If they have the same `local_idx`, select the one with the + smaller `dag_idx`. + + #3 If two candidate nodes belong to different actors and one of them is an NCCL + write node, select the one that is not an NCCL write node. + + For the implementation details, we maintain a priority queue for each actor, + where the head of the priority queue is the node with the smallest `local_idx`. + + If the selected node is an NCCL write node, select all its immediately downstream + nodes, which are NCCL read nodes, regardless of whether the downstream nodes are + heads of their own priority queues. In that case, this function only removes the + NCCL write node, which is also the head of a priority queue. Other nodes will be + removed in the following iterations. The NCCL read nodes will be returned even + though they should not yet be in the candidate list. + + Args: + actor_to_candidates: A dictionary mapping an actor id to a list of + candidate nodes. The list is maintained as a priority queue, so + the head of the queue, i.e., `candidates[0]`, is the node with + the smallest `bind_index`. + graph: A dictionary mapping the index of a task to a dictionary of its + _DAGOperationGraphNodes for different operations. + + Returns: + A list of _DAGOperationGraphNodes to be placed into the corresponding + execution schedules. + """ + top_priority_node = None + next_nodes: List[_DAGOperationGraphNode] = [] + for _, candidates in actor_to_candidates.items(): + if len(candidates) == 0: + continue + if top_priority_node is None or candidates[0] < top_priority_node: + top_priority_node = candidates[0] + assert top_priority_node is not None + next_nodes.append( + heapq.heappop(actor_to_candidates[top_priority_node.actor_handle._actor_id]) + ) + + if not ( + top_priority_node.operation.type == _DAGNodeOperationType.WRITE + and top_priority_node.requires_nccl + ): + assert len(next_nodes) == 1 + return next_nodes + + # An NCCL write node is picked. NCCL is a blocking operation, so we need to pick all + # the corresponding NCCL read nodes to avoid a deadlock. 
+ for downstream_node_metadata in top_priority_node.out_edges: + dag_idx, op_type = downstream_node_metadata[0], downstream_node_metadata[1] + downstream_node = graph[dag_idx][op_type] + assert downstream_node.operation.type == _DAGNodeOperationType.READ + next_nodes.append(downstream_node) + assert len(next_nodes) == 1 + len(top_priority_node.out_edges) + return next_nodes + + +def _build_dag_node_operation_graph( + idx_to_task: Dict[int, "ray.dag.compiled_dag_node.CompiledTask"], + actor_to_operation_nodes: Dict[ + "ray.actor.ActorHandle", List[List[_DAGOperationGraphNode]] + ], +) -> Dict[int, Dict[_DAGNodeOperationType, _DAGOperationGraphNode]]: + """ + Generate a DAG node operation graph by adding edges based on the + following rules: + + #1 Add edges from READ to COMPUTE, and from COMPUTE to WRITE, which + belong to the same task. + #2 Add an edge from COMPUTE with bind_index i to COMPUTE with bind_index + i+1 if they belong to the same actor. + #3 Add an edge from WRITE of the writer task to READ of the reader task. + + This is the step one of building an execution schedule for each actor. + + Args: + idx_to_task: A dictionary that maps the `dag_idx` to the `CompiledTask`. + `CompiledTask` contains information about a DAGNode and its downstream + nodes. + + actor_to_operation_nodes: A dictionary that maps an actor handle to + a list of lists of _DAGOperationGraphNode. For the same actor, the + index of the outer list corresponds to the index of the ExecutableTask + in the list of `executable_tasks` in `actor_to_executable_tasks`. In + the inner list, the order of operations is READ, COMPUTE, and WRITE. + + Returns: + A graph where each node is a _DAGOperationGraphNode. The key is `dag_idx`, + the index to retrieve its task from `idx_to_task`, and the value is a + dictionary that maps the _DAGNodeOperationType (READ, COMPUTE, or WRITE) + to the corresponding _DAGOperationGraphNode + """ + assert idx_to_task + graph: Dict[int, Dict[_DAGNodeOperationType, _DAGOperationGraphNode]] = {} + + for _, operation_nodes_list in actor_to_operation_nodes.items(): + prev_compute_node = None + for operation_nodes in operation_nodes_list: + dag_idx = operation_nodes[0].dag_idx + read_node, compute_node, write_node = ( + operation_nodes[0], + operation_nodes[1], + operation_nodes[2], + ) + # Add edges from READ to COMPUTE, and from COMPUTE to WRITE, which + # belong to the same task. + _add_edge(read_node, compute_node) + _add_edge(compute_node, write_node) + # Add an edge from COMPUTE with `bind_index` i to COMPUTE with + # `bind_index` i+1 if they belong to the same actor. + if prev_compute_node is not None: + _add_edge(prev_compute_node, compute_node) + prev_compute_node = compute_node + assert dag_idx not in graph + graph[dag_idx] = { + _DAGNodeOperationType.READ: read_node, + _DAGNodeOperationType.COMPUTE: compute_node, + _DAGNodeOperationType.WRITE: write_node, + } + + # Import `ray.dag` here to avoid circular import. + from ray.dag import ClassMethodNode, MultiOutputNode + + # Add an edge from WRITE of the writer task to READ of the reader task. + for dag_idx, task in idx_to_task.items(): + if not isinstance(task.dag_node, ClassMethodNode): + # The graph is used to generate an execution schedule for each actor. + # The edge from the InputNode has no impact on the final execution + # schedule. 
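+ # Only `ClassMethodNode` tasks run on actors, so WRITE-to-READ edges are added only between them below.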
+ continue + for downstream_dag_idx in task.downstream_node_idxs: + downstream_dag_node = idx_to_task[downstream_dag_idx].dag_node + if isinstance(downstream_dag_node, MultiOutputNode): + continue + _add_edge( + graph[dag_idx][_DAGNodeOperationType.WRITE], + graph[downstream_dag_idx][_DAGNodeOperationType.READ], + ) + return graph + + +def _generate_actor_to_execution_schedule( + graph: Dict[int, Dict[_DAGNodeOperationType, _DAGOperationGraphNode]] +): + """ + Generate an execution schedule for each actor. The schedule is a list of + operations to be executed. The function uses a topological sort algorithm + to generate the schedule. + + Args: + graph: A graph where each node is a _DAGOperationGraphNode. The key is + `dag_idx`, the index to retrieve its task from `idx_to_task`, and + the value is a dictionary that maps the _DAGNodeOperationType (READ, + COMPUTE, or WRITE) to the corresponding _DAGOperationGraphNode. It is + generated by `_build_dag_node_operation_graph`. + + Returns: + actor_to_execution_schedule: A dictionary that maps an actor handle to + the execution schedule which is a list of operations to be executed. + """ + + # Mapping from the actor handle to the execution schedule which is a list + # of operations to be executed. + actor_to_execution_schedule: Dict[ + "ray.actor.ActorHandle", List[_DAGNodeOperation] + ] = defaultdict(list) + + # A dictionary mapping an actor id to a list of candidate nodes. The list + # is maintained as a priority queue, so the head of the queue, i.e., + # `candidates[0]`, is the node with the smallest `bind_index`. + actor_to_candidates: Dict[ + "ray._raylet.ActorID", List[_DAGOperationGraphNode] + ] = defaultdict(list) + for _, node_dict in graph.items(): + for _, node in node_dict.items(): + # A node with a zero in-degree edge means all of its dependencies + # have been satisfied, including both data and control dependencies. + # Therefore, it is a candidate for execution. + if node.in_degree == 0: + heapq.heappush(actor_to_candidates[node.actor_handle._actor_id], node) + + visited_nodes = set() + + # Use topological sort algorithm to generate the execution schedule. Each iteration + # pops a candidate node from `actor_to_candidates` and each DAG node consists of + # three operations: READ, COMPUTE, and WRITE. + for _ in range(len(graph) * 3): + # The function `_select_next_nodes` will pop a candidate node from + # `actor_to_candidates` and return a list of nodes that can be executed + # in the next step. If multiple nodes are returned, only the NCCL write + # node is popped in this iteration. 
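+ # Nodes returned this way are recorded in `visited_nodes` so they are not scheduled again when they later surface as candidates.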
+ nodes = _select_next_nodes(actor_to_candidates, graph) + for node in nodes: + if node in visited_nodes: + continue + actor_to_execution_schedule[node.actor_handle].append(node.operation) + visited_nodes.add(node) + for out_node_dag_idx, out_node_type in node.out_edges: + out_node = graph[out_node_dag_idx][out_node_type] + out_node.in_edges.remove((node.dag_idx, node.operation.type)) + if out_node.in_degree == 0: + heapq.heappush( + actor_to_candidates[out_node.actor_handle._actor_id], + out_node, + ) + for _, candidates in actor_to_candidates.items(): + assert len(candidates) == 0 + return actor_to_execution_schedule diff --git a/python/ray/dag/tests/experimental/test_accelerated_dag.py b/python/ray/dag/tests/experimental/test_accelerated_dag.py index 859ea7c7ab0cf..5707956689e51 100644 --- a/python/ray/dag/tests/experimental/test_accelerated_dag.py +++ b/python/ray/dag/tests/experimental/test_accelerated_dag.py @@ -788,7 +788,7 @@ def test_compiled_dag_ref_del(ray_start_regular): compiled_dag.teardown() -def test_dag_fault_tolerance_chain(ray_start_regular_shared): +def test_dag_fault_tolerance_chain(ray_start_regular): actors = [ Actor.remote(0, fail_after=100 if i == 0 else None, sys_exit=False) for i in range(4) @@ -831,7 +831,7 @@ def test_dag_fault_tolerance_chain(ray_start_regular_shared): compiled_dag.teardown() -def test_dag_fault_tolerance(ray_start_regular_shared): +def test_dag_fault_tolerance(ray_start_regular): actors = [ Actor.remote(0, fail_after=100 if i == 0 else None, sys_exit=False) for i in range(4) @@ -872,7 +872,7 @@ def test_dag_fault_tolerance(ray_start_regular_shared): compiled_dag.teardown() -def test_dag_fault_tolerance_sys_exit(ray_start_regular_shared): +def test_dag_fault_tolerance_sys_exit(ray_start_regular): actors = [ Actor.remote(0, fail_after=100 if i == 0 else None, sys_exit=True) for i in range(4) @@ -912,7 +912,7 @@ def test_dag_fault_tolerance_sys_exit(ray_start_regular_shared): compiled_dag.teardown() -def test_dag_teardown_while_running(ray_start_regular_shared): +def test_dag_teardown_while_running(ray_start_regular): a = Actor.remote(0) with InputNode() as inp: @@ -939,7 +939,7 @@ def test_dag_teardown_while_running(ray_start_regular_shared): @pytest.mark.parametrize("max_queue_size", [None, 2]) -def test_asyncio(ray_start_regular_shared, max_queue_size): +def test_asyncio(ray_start_regular, max_queue_size): a = Actor.remote(0) with InputNode() as i: dag = a.echo.bind(i) @@ -965,7 +965,7 @@ async def main(i): @pytest.mark.parametrize("max_queue_size", [None, 2]) -def test_asyncio_exceptions(ray_start_regular_shared, max_queue_size): +def test_asyncio_exceptions(ray_start_regular, max_queue_size): a = Actor.remote(0) with InputNode() as i: dag = a.inc.bind(i) @@ -1005,7 +1005,7 @@ async def main(): class TestCompositeChannel: - def test_composite_channel_one_actor(self, ray_start_regular_shared): + def test_composite_channel_one_actor(self, ray_start_regular): """ In this test, there are three 'inc' tasks on the same Ray actor, chained together. Therefore, the DAG will look like this: @@ -1039,7 +1039,7 @@ def test_composite_channel_one_actor(self, ray_start_regular_shared): compiled_dag.teardown() - def test_composite_channel_two_actors(self, ray_start_regular_shared): + def test_composite_channel_two_actors(self, ray_start_regular): """ In this test, there are three 'inc' tasks on the two Ray actors, chained together. 
Therefore, the DAG will look like this: @@ -1073,7 +1073,7 @@ def test_composite_channel_two_actors(self, ray_start_regular_shared): compiled_dag.teardown() - def test_composite_channel_multi_output(self, ray_start_regular_shared): + def test_composite_channel_multi_output(self, ray_start_regular): """ Driver -> a.inc -> a.inc ---> Driver | | @@ -1100,7 +1100,7 @@ def test_composite_channel_multi_output(self, ray_start_regular_shared): compiled_dag.teardown() - def test_intra_process_channel_with_multi_readers(self, ray_start_regular_shared): + def test_intra_process_channel_with_multi_readers(self, ray_start_regular): """ In this test, there are three 'echo' tasks on the same Ray actor. The DAG will look like this: @@ -1134,7 +1134,7 @@ def test_intra_process_channel_with_multi_readers(self, ray_start_regular_shared compiled_dag.teardown() -def test_simulate_pipeline_parallelism(ray_start_regular_shared): +def test_simulate_pipeline_parallelism(ray_start_regular): """ This pattern simulates the case of pipeline parallelism training, where `w0_input` reads data from the driver, and the fan-out tasks, `d00`, `d01`, and `d02`, use @@ -1209,7 +1209,7 @@ def read_input(self, input): output_dag.teardown() -def test_channel_read_after_close(ray_start_regular_shared): +def test_channel_read_after_close(ray_start_regular): # Tests that read to a channel after accelerated DAG teardown raises a # RayChannelError exception as the channel is closed (see issue #46284). @ray.remote @@ -1229,7 +1229,7 @@ def foo(self, arg): ray.get(ref) -def test_channel_write_after_close(ray_start_regular_shared): +def test_channel_write_after_close(ray_start_regular): # Tests that write to a channel after accelerated DAG teardown raises a # RayChannelError exception as the channel is closed. @ray.remote diff --git a/python/ray/dag/tests/experimental/test_execution_schedule.py b/python/ray/dag/tests/experimental/test_execution_schedule.py new file mode 100644 index 0000000000000..46d285b63890e --- /dev/null +++ b/python/ray/dag/tests/experimental/test_execution_schedule.py @@ -0,0 +1,1001 @@ +# coding: utf-8 +import os +import sys + +import pytest + +from ray.tests.conftest import * # noqa +from ray.dag import InputNode, MultiOutputNode, ClassMethodNode +from ray.dag.dag_node_operation import ( + _DAGNodeOperationType, + _DAGOperationGraphNode, + _DAGNodeOperation, + _select_next_nodes, + _build_dag_node_operation_graph, + _add_edge, + _generate_actor_to_execution_schedule, +) +from ray.dag.compiled_dag_node import CompiledTask +from typing import List, Dict, Tuple +from ray.actor import ActorHandle + +if sys.platform != "linux" and sys.platform != "darwin": + pytest.skip("Skipping, requires Linux or Mac.", allow_module_level=True) + + +def mock_actor_handle_init(self, actor_id: str): + self._ray_actor_id = actor_id + + +def mock_init(self): + pass + + +def generate_dag_graph_nodes(local_idx, dag_idx, actor_handle, requires_nccl): + graph_nodes = {} + for op_type in _DAGNodeOperationType: + graph_nodes[op_type] = _DAGOperationGraphNode( + _DAGNodeOperation(local_idx, op_type), + dag_idx, + actor_handle, + requires_nccl, + ) + return graph_nodes + + +class TestSelectNextNodes: + """ + Test whether `_select_next_nodes` function selects the next nodes for + topological sort to generate execution schedule correctly. + + dag_idx: Each DAG node has a unique global index. + local_idx: The DAG node's index in the actor's `executable_tasks` list. 
+ """ + + def test_two_candidates_on_same_actor(self, monkeypatch): + """ + Simulate the case where there are two candidates on the same actor. + The candidate with the smaller index in the `executable_tasks` list + should be selected. + + driver -> fake_actor.op -> fake_actor.op -> driver + + In the example above, both READ operations on the fake_actor have zero + in-degree. The operation with the smaller index in the executable_tasks + list should be selected first; therefore, the one on the left side will + be selected first. + """ + monkeypatch.setattr(ActorHandle, "__init__", mock_actor_handle_init) + fake_actor = ActorHandle("fake_actor") + # The DAG node has a global index of 1, and its index in the + # actor's `executable_tasks` list is 0. + dag_idx_1 = 1 + dag_node_1 = _DAGOperationGraphNode( + _DAGNodeOperation(0, _DAGNodeOperationType.READ), + dag_idx_1, + fake_actor, + False, + ) + # The DAG node has a global index of 2, and its index in the + # actor's `executable_tasks` list is 1. + dag_idx_2 = 2 + dag_node_2 = _DAGOperationGraphNode( + _DAGNodeOperation(1, _DAGNodeOperationType.READ), + dag_idx_2, + fake_actor, + False, + ) + mock_actor_to_candidates = { + fake_actor: [ + dag_node_1, + dag_node_2, + ], + } + next_nodes = _select_next_nodes(mock_actor_to_candidates, None) + assert len(next_nodes) == 1 + assert next_nodes[0] == dag_node_1 + + def test_only_one_nccl_write(self, monkeypatch): + """ + Simulate the case where there is only one candidate which is a NCCL + WRITE operation. In this case, `_select_next_nodes` should return both + the NCCL WRITE operation and the corresponding READ operation. + + driver -> fake_actor_1.op -> fake_actor_2.op -> driver + + In the example above, communication between fake_actor_1 and fake_actor_2 + is done using NCCL. The following test case simulates a scenario where the + READ and COMPUTE operations on fake_actor_1 have already been added to the + execution schedule. + """ + monkeypatch.setattr(ActorHandle, "__init__", mock_actor_handle_init) + fake_actor_1, dag_idx_1, local_idx_1 = ActorHandle("fake_actor_1"), 1, 0 + fake_actor_2, dag_idx_2, local_idx_2 = ActorHandle("fake_actor_2"), 2, 0 + mock_graph = { + dag_idx_1: generate_dag_graph_nodes( + local_idx_1, dag_idx_1, fake_actor_1, True + ), + dag_idx_2: generate_dag_graph_nodes( + local_idx_2, dag_idx_2, fake_actor_2, False + ), + } + del mock_graph[dag_idx_1][_DAGNodeOperationType.READ] + del mock_graph[dag_idx_1][_DAGNodeOperationType.COMPUTE] + + _add_edge( + mock_graph[dag_idx_1][_DAGNodeOperationType.WRITE], + mock_graph[dag_idx_2][_DAGNodeOperationType.READ], + ) + _add_edge( + mock_graph[dag_idx_2][_DAGNodeOperationType.READ], + mock_graph[dag_idx_2][_DAGNodeOperationType.COMPUTE], + ) + _add_edge( + mock_graph[dag_idx_2][_DAGNodeOperationType.COMPUTE], + mock_graph[dag_idx_2][_DAGNodeOperationType.WRITE], + ) + mock_actor_to_candidates = { + fake_actor_1: [mock_graph[dag_idx_1][_DAGNodeOperationType.WRITE]], + } + next_nodes = _select_next_nodes(mock_actor_to_candidates, mock_graph) + assert len(next_nodes) == 2 + assert next_nodes[0] == mock_graph[dag_idx_1][_DAGNodeOperationType.WRITE] + assert next_nodes[1] == mock_graph[dag_idx_2][_DAGNodeOperationType.READ] + + def test_two_nccl_writes(self, monkeypatch): + """ + Simulate a scenario where there are two candidates that are NCCL WRITE + operations. In this case, _select_next_nodes can choose either of the + two NCCL WRITE operations and their corresponding READ operations. 
+ + driver -> fake_actor_1.op -> fake_actor_2.op -> driver + | | + -> fake_actor_2.op -> fake_actor_1.op - + + In the example above, communication between fake_actor_1 and fake_actor_2 is + done using NCCL. The following test case simulates a scenario where the READ + and COMPUTE operations on both the DAG nodes with smaller bind_index on + fake_actor_1 and fake_actor_2 have already been added to the execution schedule. + """ + monkeypatch.setattr(ActorHandle, "__init__", mock_actor_handle_init) + + fake_actor_1 = ActorHandle("fake_actor_1") + dag_idx_1_0, local_idx_1_0 = 1, 0 + dag_idx_1_1, local_idx_1_1 = 3, 1 + fake_actor_2 = ActorHandle("fake_actor_2") + dag_idx_2_0, local_idx_2_0 = 2, 0 + dag_idx_2_1, local_idx_2_1 = 4, 1 + + # Run the test 10 times to ensure that the result of `_select_next_nodes` + # is deterministic. + for _ in range(20): + mock_graph = { + dag_idx_1_0: generate_dag_graph_nodes( + local_idx_1_0, dag_idx_1_0, fake_actor_1, True + ), + dag_idx_1_1: generate_dag_graph_nodes( + local_idx_1_1, dag_idx_1_1, fake_actor_1, False + ), + dag_idx_2_0: generate_dag_graph_nodes( + local_idx_2_0, dag_idx_2_0, fake_actor_2, True + ), + dag_idx_2_1: generate_dag_graph_nodes( + local_idx_2_1, dag_idx_2_1, fake_actor_2, False + ), + } + del mock_graph[dag_idx_1_0][_DAGNodeOperationType.READ] + del mock_graph[dag_idx_1_0][_DAGNodeOperationType.COMPUTE] + del mock_graph[dag_idx_2_0][_DAGNodeOperationType.READ] + del mock_graph[dag_idx_2_0][_DAGNodeOperationType.COMPUTE] + + _add_edge( + mock_graph[dag_idx_1_0][_DAGNodeOperationType.WRITE], + mock_graph[dag_idx_2_1][_DAGNodeOperationType.READ], + ) + _add_edge( + mock_graph[dag_idx_2_0][_DAGNodeOperationType.WRITE], + mock_graph[dag_idx_1_1][_DAGNodeOperationType.READ], + ) + _add_edge( + mock_graph[dag_idx_2_1][_DAGNodeOperationType.READ], + mock_graph[dag_idx_2_1][_DAGNodeOperationType.COMPUTE], + ) + _add_edge( + mock_graph[dag_idx_2_1][_DAGNodeOperationType.COMPUTE], + mock_graph[dag_idx_2_1][_DAGNodeOperationType.WRITE], + ) + _add_edge( + mock_graph[dag_idx_1_1][_DAGNodeOperationType.READ], + mock_graph[dag_idx_1_1][_DAGNodeOperationType.COMPUTE], + ) + _add_edge( + mock_graph[dag_idx_1_1][_DAGNodeOperationType.COMPUTE], + mock_graph[dag_idx_1_1][_DAGNodeOperationType.WRITE], + ) + mock_actor_to_candidates = { + fake_actor_1: [mock_graph[dag_idx_1_0][_DAGNodeOperationType.WRITE]], + fake_actor_2: [mock_graph[dag_idx_2_0][_DAGNodeOperationType.WRITE]], + } + + next_nodes = _select_next_nodes(mock_actor_to_candidates, mock_graph) + assert len(next_nodes) == 2 + assert next_nodes[0] == mock_graph[dag_idx_1_0][_DAGNodeOperationType.WRITE] + assert next_nodes[1] == mock_graph[dag_idx_2_1][_DAGNodeOperationType.READ] + + +class TestBuildDAGNodeOperationGraph: + """ + Test whether `_build_dag_node_operation_graph` function adds the correct + edges between the nodes in the operation graph base on the 3 rules mentioned + in the doc string of `_build_dag_node_operation_graph`. + """ + + def check_edges_between_read_compute_write( + self, + graph: Dict[int, Dict[_DAGNodeOperationType, _DAGOperationGraphNode]], + dag_idx: int, + expected_num_edges: List[Tuple[int, int]], + ): + """ + Check whether edges from READ to COMPUTE, and from COMPUTE to WRITE, + belonging to the same task are added. + + Args: + graph: The operation graph generated by `_build_dag_node_operation_graph`. + dag_idx: The global index of the task used to access the task in + `idx_to_task`. 
+ expected_num_edges: A list of tuples where each tuple contains the expected + number of in-edges and out-edges for READ, COMPUTE, and WRITE + operations. + """ + assert len(expected_num_edges) == 3 + assert len(graph[dag_idx]) == 3 + read_node = graph[dag_idx][_DAGNodeOperationType.READ] + compute_node = graph[dag_idx][_DAGNodeOperationType.COMPUTE] + write_node = graph[dag_idx][_DAGNodeOperationType.WRITE] + + for idx, node in enumerate([read_node, compute_node, write_node]): + assert node.in_degree == expected_num_edges[idx][0] + assert len(node.out_edges) == expected_num_edges[idx][1] + + assert (dag_idx, _DAGNodeOperationType.COMPUTE) in read_node.out_edges + assert (dag_idx, _DAGNodeOperationType.READ) in compute_node.in_edges + assert (dag_idx, _DAGNodeOperationType.WRITE) in compute_node.out_edges + assert (dag_idx, _DAGNodeOperationType.COMPUTE) in write_node.in_edges + + def check_edge_between_writer_and_reader( + self, + graph: Dict[int, Dict[_DAGNodeOperationType, _DAGOperationGraphNode]], + writer_dag_idx: int, + reader_dag_idx: int, + ): + """ + Check whether the edge from writer's WRITE to reader's READ operation is added. + + Args: + graph: The operation graph generated by `_build_dag_node_operation_graph`. + writer_dag_idx: The index of the task used to access the task + that the writer belongs to in `idx_to_task`. + reader_dag_idx: The index of the task used to access the task + that the reader belongs to in `idx_to_task`. + """ + write_node = graph[writer_dag_idx][_DAGNodeOperationType.WRITE] + read_node = graph[reader_dag_idx][_DAGNodeOperationType.READ] + + assert (reader_dag_idx, _DAGNodeOperationType.READ) in write_node.out_edges + assert (writer_dag_idx, _DAGNodeOperationType.WRITE) in read_node.in_edges + + def check_edge_between_compute_nodes( + self, + graph: Dict[int, Dict[_DAGNodeOperationType, _DAGOperationGraphNode]], + dag_idx_1: int, + dag_idx_2: int, + ): + """ + Check whether the edge from COMPUTE with `bind_index` i to COMPUTE with + `bind_index` i+1 if they belong to the same actor. + + Args: + graph: The operation graph generated by `_build_dag_node_operation_graph`. + dag_idx_1: The index of the task used to access the task in + `idx_to_task`. + dag_idx_2: The index of the task used to access the task in + `idx_to_task`. Note that both tasks belong to the same actor, and the + `bind_index` of the second task is equal to the `bind_index` of the + first task plus one. + """ + compute_node_1 = graph[dag_idx_1][_DAGNodeOperationType.COMPUTE] + compute_node_2 = graph[dag_idx_2][_DAGNodeOperationType.COMPUTE] + + assert (dag_idx_2, _DAGNodeOperationType.COMPUTE) in compute_node_1.out_edges + assert (dag_idx_1, _DAGNodeOperationType.COMPUTE) in compute_node_2.in_edges + + def test_edges_between_read_compute_write(self, monkeypatch): + """ + driver -> fake_actor.op -> driver + + This test case aims to verify whether the function correctly adds edges + between READ/COMPUTE and COMPUTE/WRITE operations on the same actor. 
+ """ + monkeypatch.setattr(ClassMethodNode, "__init__", mock_init) + monkeypatch.setattr(MultiOutputNode, "__init__", mock_init) + + idx_to_task = { + 0: CompiledTask(0, InputNode()), + 1: CompiledTask(1, ClassMethodNode()), + 2: CompiledTask(2, MultiOutputNode()), + } + + fake_actor = "fake_actor" + dag_idx = 1 + actor_to_operation_nodes = { + fake_actor: [ + list(generate_dag_graph_nodes(0, dag_idx, fake_actor, False).values()) + ] + } + graph = _build_dag_node_operation_graph(idx_to_task, actor_to_operation_nodes) + assert len(graph) == 1 + + self.check_edges_between_read_compute_write( + graph, dag_idx, [(0, 1), (1, 1), (1, 0)] + ) + + def test_edge_between_writer_and_reader(self, monkeypatch): + """ + driver -> fake_actor_1.op -> fake_actor_2.op -> driver + + This test case aims to verify whether the function correctly adds an edge + from the writer's WRITE operation to the reader's READ operation. + """ + monkeypatch.setattr(ClassMethodNode, "__init__", mock_init) + monkeypatch.setattr(MultiOutputNode, "__init__", mock_init) + + fake_actor_1, dag_idx_1 = "fake_actor_1", 1 + fake_actor_2, dag_idx_2 = "fake_actor_2", 2 + idx_to_task = { + 0: CompiledTask(0, InputNode()), + 1: CompiledTask(1, ClassMethodNode()), + 2: CompiledTask(2, ClassMethodNode()), + 3: CompiledTask(3, MultiOutputNode()), + } + idx_to_task[1].downstream_node_idxs = {2: fake_actor_2} + + actor_to_operation_nodes = { + fake_actor_1: [ + list( + generate_dag_graph_nodes(0, dag_idx_1, fake_actor_1, False).values() + ) + ], + fake_actor_2: [ + list( + generate_dag_graph_nodes(0, dag_idx_2, fake_actor_2, False).values() + ) + ], + } + graph = _build_dag_node_operation_graph(idx_to_task, actor_to_operation_nodes) + assert len(graph) == 2 + + self.check_edges_between_read_compute_write( + graph, dag_idx_1, [(0, 1), (1, 1), (1, 1)] + ) + self.check_edges_between_read_compute_write( + graph, dag_idx_2, [(1, 1), (1, 1), (1, 0)] + ) + self.check_edge_between_writer_and_reader(graph, dag_idx_1, dag_idx_2) + + def test_edge_between_compute_nodes(self, monkeypatch): + """ + driver -> fake_actor.op -> fake_actor.op -> driver + + This test case aims to verify whether the function correctly adds an edge + from the COMPUTE operation with `bind_index` i to the COMPUTE operation with + `bind_index` i+1 if they belong to the same actor. 
+ """ + monkeypatch.setattr(ClassMethodNode, "__init__", mock_init) + monkeypatch.setattr(MultiOutputNode, "__init__", mock_init) + + fake_actor = "fake_actor" + dag_idx_1, dag_idx_2 = 1, 2 + idx_to_task = { + 0: CompiledTask(0, InputNode()), + dag_idx_1: CompiledTask(dag_idx_1, ClassMethodNode()), + dag_idx_2: CompiledTask(dag_idx_2, ClassMethodNode()), + 3: CompiledTask(3, MultiOutputNode()), + } + idx_to_task[dag_idx_1].downstream_node_idxs = {dag_idx_2: fake_actor} + + actor_to_operation_nodes = { + fake_actor: [ + list( + generate_dag_graph_nodes(0, dag_idx_1, fake_actor, False).values() + ), + list( + generate_dag_graph_nodes(1, dag_idx_2, fake_actor, False).values() + ), + ], + } + graph = _build_dag_node_operation_graph(idx_to_task, actor_to_operation_nodes) + assert len(graph) == 2 + + self.check_edges_between_read_compute_write( + graph, dag_idx_1, [(0, 1), (1, 2), (1, 1)] + ) + self.check_edges_between_read_compute_write( + graph, dag_idx_2, [(1, 1), (2, 1), (1, 0)] + ) + self.check_edge_between_writer_and_reader(graph, dag_idx_1, dag_idx_2) + self.check_edge_between_compute_nodes(graph, dag_idx_1, dag_idx_2) + + def test_two_actors(self, monkeypatch): + """ + driver -> fake_actor_1.op -> fake_actor_2.op -> driver + | | + -> fake_actor_2.op -> fake_actor_1.op - + + This test includes two actors, each with two tasks. The + test case covers all three rules for adding edges between + operation nodes in the operation graph. + """ + monkeypatch.setattr(ClassMethodNode, "__init__", mock_init) + monkeypatch.setattr(MultiOutputNode, "__init__", mock_init) + + fake_actor_1, dag_idx_1, dag_idx_3 = "fake_actor_1", 1, 3 + fake_actor_2, dag_idx_2, dag_idx_4 = "fake_actor_2", 2, 4 + + idx_to_task = { + 0: CompiledTask(0, InputNode()), + dag_idx_1: CompiledTask(dag_idx_1, ClassMethodNode()), + dag_idx_2: CompiledTask(dag_idx_2, ClassMethodNode()), + dag_idx_3: CompiledTask(dag_idx_3, ClassMethodNode()), + dag_idx_4: CompiledTask(dag_idx_4, ClassMethodNode()), + 5: CompiledTask(5, MultiOutputNode()), + } + idx_to_task[dag_idx_1].downstream_node_idxs = {dag_idx_4: fake_actor_2} + idx_to_task[dag_idx_2].downstream_node_idxs = {dag_idx_3: fake_actor_1} + + actor_to_operation_nodes = { + fake_actor_1: [ + list( + generate_dag_graph_nodes(0, dag_idx_1, fake_actor_1, False).values() + ), + list( + generate_dag_graph_nodes(1, dag_idx_3, fake_actor_1, False).values() + ), + ], + fake_actor_2: [ + list( + generate_dag_graph_nodes(0, dag_idx_2, fake_actor_2, False).values() + ), + list( + generate_dag_graph_nodes(1, dag_idx_4, fake_actor_2, False).values() + ), + ], + } + graph = _build_dag_node_operation_graph(idx_to_task, actor_to_operation_nodes) + assert len(graph) == 4 + + self.check_edges_between_read_compute_write( + graph, dag_idx_1, [(0, 1), (1, 2), (1, 1)] + ) + self.check_edges_between_read_compute_write( + graph, dag_idx_2, [(0, 1), (1, 2), (1, 1)] + ) + self.check_edges_between_read_compute_write( + graph, dag_idx_3, [(1, 1), (2, 1), (1, 0)] + ) + self.check_edges_between_read_compute_write( + graph, dag_idx_4, [(1, 1), (2, 1), (1, 0)] + ) + self.check_edge_between_writer_and_reader(graph, dag_idx_1, dag_idx_4) + self.check_edge_between_writer_and_reader(graph, dag_idx_2, dag_idx_3) + + +class TestGenerateActorToExecutionSchedule: + """ + Test whether `_generate_actor_to_execution_schedule` function generates the + correct execution schedule for each actor. 
+ """ + + def add_edge_between_read_compute_write( + self, operations: Dict[_DAGNodeOperationType, _DAGOperationGraphNode] + ): + """ + Add edges between READ and COMPUTE, and between COMPUTE and WRITE operations + on the same actor. + + Args: + operations: A dictionary where the key is the operation type and the value + is the operation node. + """ + assert len(operations) == 3 + _add_edge( + operations[_DAGNodeOperationType.READ], + operations[_DAGNodeOperationType.COMPUTE], + ) + _add_edge( + operations[_DAGNodeOperationType.COMPUTE], + operations[_DAGNodeOperationType.WRITE], + ) + + def add_data_dependeny( + self, + writer_operations: Dict[_DAGNodeOperationType, _DAGOperationGraphNode], + reader_operations: Dict[_DAGNodeOperationType, _DAGOperationGraphNode], + ): + """ + Add a data dependency between the WRITE operation of the writer and the READ + operation of the reader. + + Args: + writer_operations: A dictionary where the key is the operation type and the + value is the operation node of the writer. + reader_operations: A dictionary where the key is the operation type and the + value is the operation node of the reader. + """ + _add_edge( + writer_operations[_DAGNodeOperationType.WRITE], + reader_operations[_DAGNodeOperationType.READ], + ) + + def add_control_dependency( + self, + operations_1: Dict[_DAGNodeOperationType, _DAGOperationGraphNode], + operations_2: Dict[_DAGNodeOperationType, _DAGOperationGraphNode], + ): + """ + Add a control dependency between the COMPUTE operation of the task with + bind_index i and the COMPUTE operation of the task with bind_index i+1 + on the same actor. + + Args: + operations_1: A dictionary where the key is the operation type and the value + is the operation node of the task with bind_index i. + operations_2: A dictionary where the key is the operation type and the value + is the operation node of the task with bind_index i+1. + """ + _add_edge( + operations_1[_DAGNodeOperationType.COMPUTE], + operations_2[_DAGNodeOperationType.COMPUTE], + ) + + def test_single_actor_1(self, monkeypatch): + """ + driver -> fake_actor.op (dag_idx_1) -> fake_actor.op (dag_idx_2) -> driver + + Test the case where there is only one actor and no NCCL operations. + Because there is no NCCL operation, all operations with smaller + `bind_index` should be executed before the operations with larger + `bind_index` on the same actor. 
+ """ + monkeypatch.setattr(ActorHandle, "__init__", mock_actor_handle_init) + + fake_actor = ActorHandle("fake_actor") + dag_idx_1, local_idx_1 = 1, 0 + dag_idx_2, local_idx_2 = 2, 1 + graph = { + dag_idx_1: generate_dag_graph_nodes( + local_idx_1, dag_idx_1, fake_actor, False + ), + dag_idx_2: generate_dag_graph_nodes( + local_idx_2, dag_idx_2, fake_actor, False + ), + } + self.add_edge_between_read_compute_write(graph[dag_idx_1]) + self.add_edge_between_read_compute_write(graph[dag_idx_2]) + self.add_data_dependeny(graph[dag_idx_1], graph[dag_idx_2]) + self.add_control_dependency(graph[dag_idx_1], graph[dag_idx_2]) + + actor_to_execution_schedule = _generate_actor_to_execution_schedule(graph) + assert len(actor_to_execution_schedule) == 1 + assert len(actor_to_execution_schedule[fake_actor]) == 6 + assert actor_to_execution_schedule[fake_actor] == [ + graph[dag_idx_1][_DAGNodeOperationType.READ].operation, + graph[dag_idx_1][_DAGNodeOperationType.COMPUTE].operation, + graph[dag_idx_1][_DAGNodeOperationType.WRITE].operation, + graph[dag_idx_2][_DAGNodeOperationType.READ].operation, + graph[dag_idx_2][_DAGNodeOperationType.COMPUTE].operation, + graph[dag_idx_2][_DAGNodeOperationType.WRITE].operation, + ] + + def test_single_actor_2(self, monkeypatch): + """ + driver -> fake_actor.op (dag_idx_1) -> fake_actor.op (dag_idx_2) -> driver + | | + -> fake_actor.op (dag_idx_3) - + + When the `dad_idx_1.WRITE` operation is picked, both `dag_idx_2.READ` and + `dag_idx_3.READ` operations should be zero in-degree. In this case, the one + with the smaller `bind_index` should be selected first. That is, + `dag_idx_2.READ` should be selected first. + """ + monkeypatch.setattr(ActorHandle, "__init__", mock_actor_handle_init) + + fake_actor = ActorHandle("fake_actor") + dag_idx_1, local_idx_1 = 1, 0 + dag_idx_2, local_idx_2 = 2, 1 + dag_idx_3, local_idx_3 = 3, 2 + + graph = { + dag_idx_1: generate_dag_graph_nodes( + local_idx_1, dag_idx_1, fake_actor, False + ), + dag_idx_2: generate_dag_graph_nodes( + local_idx_2, dag_idx_2, fake_actor, False + ), + dag_idx_3: generate_dag_graph_nodes( + local_idx_3, dag_idx_3, fake_actor, False + ), + } + self.add_edge_between_read_compute_write(graph[dag_idx_1]) + self.add_edge_between_read_compute_write(graph[dag_idx_2]) + self.add_edge_between_read_compute_write(graph[dag_idx_3]) + self.add_data_dependeny(graph[dag_idx_1], graph[dag_idx_2]) + self.add_data_dependeny(graph[dag_idx_1], graph[dag_idx_3]) + self.add_control_dependency(graph[dag_idx_1], graph[dag_idx_2]) + self.add_control_dependency(graph[dag_idx_2], graph[dag_idx_3]) + + actor_to_execution_schedule = _generate_actor_to_execution_schedule(graph) + assert len(actor_to_execution_schedule) == 1 + assert len(actor_to_execution_schedule[fake_actor]) == 9 + assert actor_to_execution_schedule[fake_actor] == [ + graph[dag_idx_1][_DAGNodeOperationType.READ].operation, + graph[dag_idx_1][_DAGNodeOperationType.COMPUTE].operation, + graph[dag_idx_1][_DAGNodeOperationType.WRITE].operation, + graph[dag_idx_2][_DAGNodeOperationType.READ].operation, + graph[dag_idx_2][_DAGNodeOperationType.COMPUTE].operation, + graph[dag_idx_2][_DAGNodeOperationType.WRITE].operation, + graph[dag_idx_3][_DAGNodeOperationType.READ].operation, + graph[dag_idx_3][_DAGNodeOperationType.COMPUTE].operation, + graph[dag_idx_3][_DAGNodeOperationType.WRITE].operation, + ] + + def test_two_actors_no_nccl(self, monkeypatch): + """ + driver -> actor_1.op (dag_idx_1_1) -> actor_2.op (dag_idx_2_2) -> driver + | | + -> actor_2.op (dag_idx_2_1) -> 
+
+        Test the case where there are two actors and no NCCL operations.
+        Because there is no NCCL operation, all operations with smaller
+        `bind_index` should be executed before the operations with larger
+        `bind_index` on the same actor.
+        """
+        monkeypatch.setattr(ActorHandle, "__init__", mock_actor_handle_init)
+
+        fake_actor_1 = ActorHandle("fake_actor_1")
+        dag_idx_1_1, local_idx_1_1 = 1, 0
+        dag_idx_1_2, local_idx_1_2 = 4, 1
+
+        fake_actor_2 = ActorHandle("fake_actor_2")
+        dag_idx_2_1, local_idx_2_1 = 2, 0
+        dag_idx_2_2, local_idx_2_2 = 3, 1
+
+        graph = {
+            dag_idx_1_1: generate_dag_graph_nodes(
+                local_idx_1_1, dag_idx_1_1, fake_actor_1, False
+            ),
+            dag_idx_2_1: generate_dag_graph_nodes(
+                local_idx_2_1, dag_idx_2_1, fake_actor_2, False
+            ),
+            dag_idx_2_2: generate_dag_graph_nodes(
+                local_idx_2_2, dag_idx_2_2, fake_actor_2, False
+            ),
+            dag_idx_1_2: generate_dag_graph_nodes(
+                local_idx_1_2, dag_idx_1_2, fake_actor_1, False
+            ),
+        }
+        self.add_edge_between_read_compute_write(graph[dag_idx_1_1])
+        self.add_edge_between_read_compute_write(graph[dag_idx_1_2])
+        self.add_edge_between_read_compute_write(graph[dag_idx_2_1])
+        self.add_edge_between_read_compute_write(graph[dag_idx_2_2])
+        self.add_data_dependeny(graph[dag_idx_1_1], graph[dag_idx_2_2])
+        self.add_data_dependeny(graph[dag_idx_2_1], graph[dag_idx_1_2])
+        self.add_control_dependency(graph[dag_idx_1_1], graph[dag_idx_1_2])
+        self.add_control_dependency(graph[dag_idx_2_1], graph[dag_idx_2_2])
+
+        actor_to_execution_schedule = _generate_actor_to_execution_schedule(graph)
+        assert len(actor_to_execution_schedule) == 2
+        assert len(actor_to_execution_schedule[fake_actor_1]) == 6
+        assert len(actor_to_execution_schedule[fake_actor_2]) == 6
+
+        assert actor_to_execution_schedule[fake_actor_1] == [
+            graph[dag_idx_1_1][_DAGNodeOperationType.READ].operation,
+            graph[dag_idx_1_1][_DAGNodeOperationType.COMPUTE].operation,
+            graph[dag_idx_1_1][_DAGNodeOperationType.WRITE].operation,
+            graph[dag_idx_1_2][_DAGNodeOperationType.READ].operation,
+            graph[dag_idx_1_2][_DAGNodeOperationType.COMPUTE].operation,
+            graph[dag_idx_1_2][_DAGNodeOperationType.WRITE].operation,
+        ]
+        assert actor_to_execution_schedule[fake_actor_2] == [
+            graph[dag_idx_2_1][_DAGNodeOperationType.READ].operation,
+            graph[dag_idx_2_1][_DAGNodeOperationType.COMPUTE].operation,
+            graph[dag_idx_2_1][_DAGNodeOperationType.WRITE].operation,
+            graph[dag_idx_2_2][_DAGNodeOperationType.READ].operation,
+            graph[dag_idx_2_2][_DAGNodeOperationType.COMPUTE].operation,
+            graph[dag_idx_2_2][_DAGNodeOperationType.WRITE].operation,
+        ]
+
+    def test_two_actors_with_nccl(self, monkeypatch):
+        """
+        driver -> actor_1.op (dag_idx_1_1) -> actor_2.op (dag_idx_2_2) -> driver
+           |                                                                |
+           -> actor_2.op (dag_idx_2_1) -> actor_1.op (dag_idx_1_2) -
+
+        In this test, the communication between fake_actor_1 and fake_actor_2 is
+        done using NCCL. When the `dag_idx_1_1.WRITE` operation is picked, the
+        `dag_idx_2_2.READ` operation is also added to the execution schedule
+        because of the NCCL operation.
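+
+        As a result, `dag_idx_2_2.READ` is scheduled on fake_actor_2 before
+        `dag_idx_2_1.WRITE`, even though `dag_idx_2_1` has the smaller
+        `bind_index`.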
+ """ + monkeypatch.setattr(ActorHandle, "__init__", mock_actor_handle_init) + + fake_actor_1 = ActorHandle("fake_actor_1") + dag_idx_1_1, local_idx_1_1 = 1, 0 + dag_idx_1_2, local_idx_1_2 = 4, 1 + + fake_actor_2 = ActorHandle("fake_actor_2") + dag_idx_2_1, local_idx_2_1 = 2, 0 + dag_idx_2_2, local_idx_2_2 = 3, 1 + + graph = { + dag_idx_1_1: generate_dag_graph_nodes( + local_idx_1_1, dag_idx_1_1, fake_actor_1, True + ), + dag_idx_2_1: generate_dag_graph_nodes( + local_idx_2_1, dag_idx_2_1, fake_actor_2, True + ), + dag_idx_2_2: generate_dag_graph_nodes( + local_idx_2_2, dag_idx_2_2, fake_actor_2, False + ), + dag_idx_1_2: generate_dag_graph_nodes( + local_idx_1_2, dag_idx_1_2, fake_actor_1, False + ), + } + self.add_edge_between_read_compute_write(graph[dag_idx_1_1]) + self.add_edge_between_read_compute_write(graph[dag_idx_1_2]) + self.add_edge_between_read_compute_write(graph[dag_idx_2_1]) + self.add_edge_between_read_compute_write(graph[dag_idx_2_2]) + self.add_data_dependeny(graph[dag_idx_1_1], graph[dag_idx_2_2]) + self.add_data_dependeny(graph[dag_idx_2_1], graph[dag_idx_1_2]) + self.add_control_dependency(graph[dag_idx_1_1], graph[dag_idx_1_2]) + self.add_control_dependency(graph[dag_idx_2_1], graph[dag_idx_2_2]) + + actor_to_execution_schedule = _generate_actor_to_execution_schedule(graph) + assert len(actor_to_execution_schedule) == 2 + assert len(actor_to_execution_schedule[fake_actor_1]) == 6 + assert len(actor_to_execution_schedule[fake_actor_2]) == 6 + + assert actor_to_execution_schedule[fake_actor_1] == [ + graph[dag_idx_1_1][_DAGNodeOperationType.READ].operation, + graph[dag_idx_1_1][_DAGNodeOperationType.COMPUTE].operation, + graph[dag_idx_1_1][_DAGNodeOperationType.WRITE].operation, + graph[dag_idx_1_2][_DAGNodeOperationType.READ].operation, + graph[dag_idx_1_2][_DAGNodeOperationType.COMPUTE].operation, + graph[dag_idx_1_2][_DAGNodeOperationType.WRITE].operation, + ] + assert actor_to_execution_schedule[fake_actor_2] == [ + graph[dag_idx_2_1][_DAGNodeOperationType.READ].operation, + graph[dag_idx_2_1][_DAGNodeOperationType.COMPUTE].operation, + # The order of `dag_idx_2_2.READ` and `dag_idx_2_2.COMPUTE` is important. + graph[dag_idx_2_2][_DAGNodeOperationType.READ].operation, + graph[dag_idx_2_1][_DAGNodeOperationType.WRITE].operation, + graph[dag_idx_2_2][_DAGNodeOperationType.COMPUTE].operation, + graph[dag_idx_2_2][_DAGNodeOperationType.WRITE].operation, + ] + + def test_simulate_pp_2workers_2batches_1f1b_with_nccl(self, monkeypatch): + """ + This test simulates a simple 1F1B pipeline parallelism for training with + 2 workers and 2 batches. + + w1: fwd_b1 fwd_b2 bwd_b1 bwd_b2 + w2: fwd_b1 bwd_b1 fwd_b2 bwd_b2 + + The communication between workers is done using NCCL. The communication + within the worker actor is done using IntraProcessChannel. 
+ """ + monkeypatch.setattr(ActorHandle, "__init__", mock_actor_handle_init) + + worker_1 = ActorHandle("worker_1") + dag_idx_1_1, local_idx_1_1 = 1, 0 + dag_idx_1_2, local_idx_1_2 = 2, 1 + dag_idx_1_3, local_idx_1_3 = 3, 2 + dag_idx_1_4, local_idx_1_4 = 4, 3 + worker_2 = ActorHandle("worker_2") + dag_idx_2_1, local_idx_2_1 = 5, 0 + dag_idx_2_2, local_idx_2_2 = 6, 1 + dag_idx_2_3, local_idx_2_3 = 7, 2 + dag_idx_2_4, local_idx_2_4 = 8, 3 + graph = { + dag_idx_1_1: generate_dag_graph_nodes( + local_idx_1_1, dag_idx_1_1, worker_1, True + ), + dag_idx_1_2: generate_dag_graph_nodes( + local_idx_1_2, dag_idx_1_2, worker_1, True + ), + dag_idx_1_3: generate_dag_graph_nodes( + local_idx_1_3, dag_idx_1_3, worker_1, False + ), + dag_idx_1_4: generate_dag_graph_nodes( + local_idx_1_4, dag_idx_1_4, worker_1, False + ), + dag_idx_2_1: generate_dag_graph_nodes( + local_idx_2_1, dag_idx_2_1, worker_2, False + ), + dag_idx_2_2: generate_dag_graph_nodes( + local_idx_2_2, dag_idx_2_2, worker_2, True + ), + dag_idx_2_3: generate_dag_graph_nodes( + local_idx_2_3, dag_idx_2_3, worker_2, False + ), + dag_idx_2_4: generate_dag_graph_nodes( + local_idx_2_4, dag_idx_2_4, worker_2, True + ), + } + self.add_edge_between_read_compute_write(graph[dag_idx_1_1]) + self.add_edge_between_read_compute_write(graph[dag_idx_1_2]) + self.add_edge_between_read_compute_write(graph[dag_idx_1_3]) + self.add_edge_between_read_compute_write(graph[dag_idx_1_4]) + self.add_edge_between_read_compute_write(graph[dag_idx_2_1]) + self.add_edge_between_read_compute_write(graph[dag_idx_2_2]) + self.add_edge_between_read_compute_write(graph[dag_idx_2_3]) + self.add_edge_between_read_compute_write(graph[dag_idx_2_4]) + self.add_data_dependeny(graph[dag_idx_1_1], graph[dag_idx_2_1]) + self.add_data_dependeny(graph[dag_idx_2_1], graph[dag_idx_2_2]) + self.add_data_dependeny(graph[dag_idx_2_2], graph[dag_idx_1_3]) + self.add_data_dependeny(graph[dag_idx_1_2], graph[dag_idx_2_3]) + self.add_data_dependeny(graph[dag_idx_2_3], graph[dag_idx_2_4]) + self.add_data_dependeny(graph[dag_idx_2_4], graph[dag_idx_1_4]) + self.add_control_dependency(graph[dag_idx_1_1], graph[dag_idx_1_2]) + self.add_control_dependency(graph[dag_idx_1_2], graph[dag_idx_1_3]) + self.add_control_dependency(graph[dag_idx_1_3], graph[dag_idx_1_4]) + self.add_control_dependency(graph[dag_idx_2_1], graph[dag_idx_2_2]) + self.add_control_dependency(graph[dag_idx_2_2], graph[dag_idx_2_3]) + self.add_control_dependency(graph[dag_idx_2_3], graph[dag_idx_2_4]) + + actor_to_execution_schedule = _generate_actor_to_execution_schedule(graph) + assert len(actor_to_execution_schedule) == 2 + assert len(actor_to_execution_schedule[worker_1]) == 12 + assert len(actor_to_execution_schedule[worker_2]) == 12 + assert actor_to_execution_schedule[worker_1] == [ + graph[dag_idx_1_1][_DAGNodeOperationType.READ].operation, + graph[dag_idx_1_1][_DAGNodeOperationType.COMPUTE].operation, + graph[dag_idx_1_1][_DAGNodeOperationType.WRITE].operation, + graph[dag_idx_1_2][_DAGNodeOperationType.READ].operation, + graph[dag_idx_1_2][_DAGNodeOperationType.COMPUTE].operation, + graph[dag_idx_1_2][_DAGNodeOperationType.WRITE].operation, + graph[dag_idx_1_3][_DAGNodeOperationType.READ].operation, + graph[dag_idx_1_3][_DAGNodeOperationType.COMPUTE].operation, + graph[dag_idx_1_3][_DAGNodeOperationType.WRITE].operation, + graph[dag_idx_1_4][_DAGNodeOperationType.READ].operation, + graph[dag_idx_1_4][_DAGNodeOperationType.COMPUTE].operation, + graph[dag_idx_1_4][_DAGNodeOperationType.WRITE].operation, + ] + assert 
actor_to_execution_schedule[worker_2] == [ + graph[dag_idx_2_1][_DAGNodeOperationType.READ].operation, + graph[dag_idx_2_1][_DAGNodeOperationType.COMPUTE].operation, + graph[dag_idx_2_1][_DAGNodeOperationType.WRITE].operation, + graph[dag_idx_2_2][_DAGNodeOperationType.READ].operation, + graph[dag_idx_2_2][_DAGNodeOperationType.COMPUTE].operation, + # The order of `dag_idx_2_3.READ` and `dag_idx_2_2.WRITE` is important. + graph[dag_idx_2_3][_DAGNodeOperationType.READ].operation, + graph[dag_idx_2_2][_DAGNodeOperationType.WRITE].operation, + graph[dag_idx_2_3][_DAGNodeOperationType.COMPUTE].operation, + graph[dag_idx_2_3][_DAGNodeOperationType.WRITE].operation, + graph[dag_idx_2_4][_DAGNodeOperationType.READ].operation, + graph[dag_idx_2_4][_DAGNodeOperationType.COMPUTE].operation, + graph[dag_idx_2_4][_DAGNodeOperationType.WRITE].operation, + ] + + def test_simulate_pp_2workers_2batches_1f1b_no_nccl(self, monkeypatch): + """ + This test simulates a simple 1F1B pipeline parallelism for training with + 2 workers and 2 batches. + + w1: fwd_b1 fwd_b2 bwd_b1 bwd_b2 + w2: fwd_b1 bwd_b1 fwd_b2 bwd_b2 + + Because there is no NCCL operation, all operations with smaller + `bind_index` should be executed before the operations with larger + `bind_index` on the same actor. + """ + monkeypatch.setattr(ActorHandle, "__init__", mock_actor_handle_init) + + worker_1 = ActorHandle("worker_1") + dag_idx_1_1, local_idx_1_1 = 1, 0 + dag_idx_1_2, local_idx_1_2 = 2, 1 + dag_idx_1_3, local_idx_1_3 = 3, 2 + dag_idx_1_4, local_idx_1_4 = 4, 3 + worker_2 = ActorHandle("worker_2") + dag_idx_2_1, local_idx_2_1 = 5, 0 + dag_idx_2_2, local_idx_2_2 = 6, 1 + dag_idx_2_3, local_idx_2_3 = 7, 2 + dag_idx_2_4, local_idx_2_4 = 8, 3 + + # No NCCL operation. + graph = { + dag_idx_1_1: generate_dag_graph_nodes( + local_idx_1_1, dag_idx_1_1, worker_1, False + ), + dag_idx_1_2: generate_dag_graph_nodes( + local_idx_1_2, dag_idx_1_2, worker_1, False + ), + dag_idx_1_3: generate_dag_graph_nodes( + local_idx_1_3, dag_idx_1_3, worker_1, False + ), + dag_idx_1_4: generate_dag_graph_nodes( + local_idx_1_4, dag_idx_1_4, worker_1, False + ), + dag_idx_2_1: generate_dag_graph_nodes( + local_idx_2_1, dag_idx_2_1, worker_2, False + ), + dag_idx_2_2: generate_dag_graph_nodes( + local_idx_2_2, dag_idx_2_2, worker_2, False + ), + dag_idx_2_3: generate_dag_graph_nodes( + local_idx_2_3, dag_idx_2_3, worker_2, False + ), + dag_idx_2_4: generate_dag_graph_nodes( + local_idx_2_4, dag_idx_2_4, worker_2, False + ), + } + self.add_edge_between_read_compute_write(graph[dag_idx_1_1]) + self.add_edge_between_read_compute_write(graph[dag_idx_1_2]) + self.add_edge_between_read_compute_write(graph[dag_idx_1_3]) + self.add_edge_between_read_compute_write(graph[dag_idx_1_4]) + self.add_edge_between_read_compute_write(graph[dag_idx_2_1]) + self.add_edge_between_read_compute_write(graph[dag_idx_2_2]) + self.add_edge_between_read_compute_write(graph[dag_idx_2_3]) + self.add_edge_between_read_compute_write(graph[dag_idx_2_4]) + self.add_data_dependeny(graph[dag_idx_1_1], graph[dag_idx_2_1]) + self.add_data_dependeny(graph[dag_idx_2_1], graph[dag_idx_2_2]) + self.add_data_dependeny(graph[dag_idx_2_2], graph[dag_idx_1_3]) + self.add_data_dependeny(graph[dag_idx_1_2], graph[dag_idx_2_3]) + self.add_data_dependeny(graph[dag_idx_2_3], graph[dag_idx_2_4]) + self.add_data_dependeny(graph[dag_idx_2_4], graph[dag_idx_1_4]) + self.add_control_dependency(graph[dag_idx_1_1], graph[dag_idx_1_2]) + self.add_control_dependency(graph[dag_idx_1_2], graph[dag_idx_1_3]) + 
self.add_control_dependency(graph[dag_idx_1_3], graph[dag_idx_1_4]) + self.add_control_dependency(graph[dag_idx_2_1], graph[dag_idx_2_2]) + self.add_control_dependency(graph[dag_idx_2_2], graph[dag_idx_2_3]) + self.add_control_dependency(graph[dag_idx_2_3], graph[dag_idx_2_4]) + + actor_to_execution_schedule = _generate_actor_to_execution_schedule(graph) + assert len(actor_to_execution_schedule) == 2 + assert len(actor_to_execution_schedule[worker_1]) == 12 + assert len(actor_to_execution_schedule[worker_2]) == 12 + assert actor_to_execution_schedule[worker_1] == [ + graph[dag_idx_1_1][_DAGNodeOperationType.READ].operation, + graph[dag_idx_1_1][_DAGNodeOperationType.COMPUTE].operation, + graph[dag_idx_1_1][_DAGNodeOperationType.WRITE].operation, + graph[dag_idx_1_2][_DAGNodeOperationType.READ].operation, + graph[dag_idx_1_2][_DAGNodeOperationType.COMPUTE].operation, + graph[dag_idx_1_2][_DAGNodeOperationType.WRITE].operation, + graph[dag_idx_1_3][_DAGNodeOperationType.READ].operation, + graph[dag_idx_1_3][_DAGNodeOperationType.COMPUTE].operation, + graph[dag_idx_1_3][_DAGNodeOperationType.WRITE].operation, + graph[dag_idx_1_4][_DAGNodeOperationType.READ].operation, + graph[dag_idx_1_4][_DAGNodeOperationType.COMPUTE].operation, + graph[dag_idx_1_4][_DAGNodeOperationType.WRITE].operation, + ] + assert actor_to_execution_schedule[worker_2] == [ + graph[dag_idx_2_1][_DAGNodeOperationType.READ].operation, + graph[dag_idx_2_1][_DAGNodeOperationType.COMPUTE].operation, + graph[dag_idx_2_1][_DAGNodeOperationType.WRITE].operation, + graph[dag_idx_2_2][_DAGNodeOperationType.READ].operation, + graph[dag_idx_2_2][_DAGNodeOperationType.COMPUTE].operation, + # The order of `dag_idx_2_3.READ` and `dag_idx_2_2.WRITE` is important. + # It is different from the case where there is an NCCL operation. 
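+            # Without NCCL, `dag_idx_2_2.WRITE` is scheduled before
+            # `dag_idx_2_3.READ` because the smaller `bind_index` runs first.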
+ graph[dag_idx_2_2][_DAGNodeOperationType.WRITE].operation, + graph[dag_idx_2_3][_DAGNodeOperationType.READ].operation, + graph[dag_idx_2_3][_DAGNodeOperationType.COMPUTE].operation, + graph[dag_idx_2_3][_DAGNodeOperationType.WRITE].operation, + graph[dag_idx_2_4][_DAGNodeOperationType.READ].operation, + graph[dag_idx_2_4][_DAGNodeOperationType.COMPUTE].operation, + graph[dag_idx_2_4][_DAGNodeOperationType.WRITE].operation, + ] + + +if __name__ == "__main__": + if os.environ.get("PARALLEL_CI"): + sys.exit(pytest.main(["-n", "auto", "--boxed", "-vs", __file__])) + else: + sys.exit(pytest.main(["-sv", __file__])) diff --git a/python/ray/dag/tests/experimental/test_execution_schedule_gpu.py b/python/ray/dag/tests/experimental/test_execution_schedule_gpu.py new file mode 100644 index 0000000000000..3b07b5f7615eb --- /dev/null +++ b/python/ray/dag/tests/experimental/test_execution_schedule_gpu.py @@ -0,0 +1,373 @@ +# coding: utf-8 +import os +import sys + +import pytest + +import ray +import ray.cluster_utils +from ray.experimental.channel.torch_tensor_type import TorchTensorType +from ray.tests.conftest import * # noqa +from ray.dag import InputNode, MultiOutputNode +from ray.dag.dag_node_operation import _DAGNodeOperationType +import torch +from typing import Optional +from ray.dag.compiled_dag_node import CompiledDAG + +if sys.platform != "linux" and sys.platform != "darwin": + pytest.skip("Skipping, requires Linux or Mac.", allow_module_level=True) + +USE_GPU = bool(os.environ.get("RAY_PYTEST_USE_GPU", 0)) + + +@ray.remote(num_cpus=0, num_gpus=1) +class Worker: + def __init__(self, rank: Optional[int] = None): + self.rank = rank + self.trace = [] + + def fwd(self, value): + self.trace.append(("FWD", self.rank)) + return value + + def bwd(self, value): + self.trace.append(("BWD", self.rank)) + return value + + def pop_trace(self): + trace = self.trace + self.trace = [] + return trace + + def read_input(self, input): + return input + + def no_op(self, value): + return value + + def no_op_two(self, value1, value2): + return value1, value2 + + +def generate_1f1b_dag( + num_workers: int, num_microbatches: int, num_lead_microbatches: int +) -> CompiledDAG: + workers = [Worker.remote(rank) for rank in range(num_workers)] + + with ray.dag.InputNode() as inp: + fwd_queues = [[] for _ in range(num_workers)] + bwd_queues = [[] for _ in range(num_workers)] + # Once a worker's counter reaches 0, it cannot execute another fwd until it + # executes a bwd first. + fwd_counter = [num_lead_microbatches - i for i in range(num_workers)] + # All of the done batches. + done = [] + + # FWD on worker 0. + input_data = workers[0].read_input.bind(inp) + for i in range(num_microbatches): + fwd_queues[0].append(input_data) + + while len(done) < num_microbatches: + for i, worker in enumerate(workers): + if fwd_counter[i] > 0 and fwd_queues[i]: + b = fwd_queues[i].pop(0) + b = worker.fwd.bind(b) + if i < num_workers - 1: + fwd_queues[i + 1].append(b) + # Use NCCL channel for communication between workers. + b.with_type_hint( + TorchTensorType(transport=TorchTensorType.NCCL) + ) + else: + bwd_queues[i].append(b) + fwd_counter[i] -= 1 + elif bwd_queues[i]: + b = bwd_queues[i].pop(0) + b = worker.bwd.bind(b) + if i > 0: + bwd_queues[i - 1].append(b) + # Use NCCL channel for communication between workers. 
+ b.with_type_hint( + TorchTensorType(transport=TorchTensorType.NCCL) + ) + else: + done.append(b) + fwd_counter[i] += 1 + dag = ray.dag.MultiOutputNode(done) + compiled_dag = dag.experimental_compile() + return compiled_dag + + +@pytest.mark.parametrize("ray_start_regular", [{"num_gpus": 2}], indirect=True) +def test_simulate_pp_2workers_2batches_1f1b(ray_start_regular, monkeypatch): + """ + This test simulates a simple 1F1B pipeline parallelism for training with + 2 workers and 2 batches. + + w1: fwd_b1 fwd_b2 bwd_b1 bwd_b2 + w2: fwd_b1 bwd_b1 fwd_b2 bwd_b2 + + The communication between workers is done using NCCL. The communication + within the worker actor is done using IntraProcessChannel. + """ + if not USE_GPU: + pytest.skip("NCCL tests require GPUs") + + monkeypatch.setattr(ray.dag.constants, "RAY_ADAG_ENABLE_DETECT_DEADLOCK", False) + + w1 = Worker.remote() + w2 = Worker.remote() + + with InputNode() as inp: + w1_input = w1.read_input.bind(inp) + batch_1 = w1.fwd.bind(w1_input) + batch_1.with_type_hint(TorchTensorType(transport=TorchTensorType.NCCL)) + batch_2 = w1.fwd.bind(w1_input) + batch_2.with_type_hint(TorchTensorType(transport=TorchTensorType.NCCL)) + batch_1 = w2.fwd.bind(batch_1) + batch_1 = w2.bwd.bind(batch_1) + batch_1.with_type_hint(TorchTensorType(transport=TorchTensorType.NCCL)) + batch_2 = w2.fwd.bind(batch_2) + batch_1 = w1.bwd.bind(batch_1) + batch_2 = w2.bwd.bind(batch_2) + batch_2.with_type_hint(TorchTensorType(transport=TorchTensorType.NCCL)) + batch_2 = w1.bwd.bind(batch_2) + dag = MultiOutputNode( + [ + batch_1, + batch_2, + ] + ) + compiled_dag = dag.experimental_compile() + + w1_expected_schedule = [ + (0, _DAGNodeOperationType.READ), + (0, _DAGNodeOperationType.COMPUTE), + (0, _DAGNodeOperationType.WRITE), + (1, _DAGNodeOperationType.READ), + (1, _DAGNodeOperationType.COMPUTE), + (1, _DAGNodeOperationType.WRITE), + (2, _DAGNodeOperationType.READ), + (2, _DAGNodeOperationType.COMPUTE), + (3, _DAGNodeOperationType.READ), + (2, _DAGNodeOperationType.WRITE), + (3, _DAGNodeOperationType.COMPUTE), + (3, _DAGNodeOperationType.WRITE), + (4, _DAGNodeOperationType.READ), + (4, _DAGNodeOperationType.COMPUTE), + (4, _DAGNodeOperationType.WRITE), + ] + w2_expected_schedule = [ + (0, _DAGNodeOperationType.READ), + (0, _DAGNodeOperationType.COMPUTE), + (0, _DAGNodeOperationType.WRITE), + (1, _DAGNodeOperationType.READ), + (1, _DAGNodeOperationType.COMPUTE), + (1, _DAGNodeOperationType.WRITE), + (2, _DAGNodeOperationType.READ), + (2, _DAGNodeOperationType.COMPUTE), + (2, _DAGNodeOperationType.WRITE), + (3, _DAGNodeOperationType.READ), + (3, _DAGNodeOperationType.COMPUTE), + (3, _DAGNodeOperationType.WRITE), + ] + w1_schedule = compiled_dag.actor_to_execution_schedule[w1] + w2_schedule = compiled_dag.actor_to_execution_schedule[w2] + + for schedule, expected_schedule in zip( + [w1_schedule, w2_schedule], [w1_expected_schedule, w2_expected_schedule] + ): + assert len(schedule) == len(expected_schedule) + for i, operation in enumerate(schedule): + assert operation.local_idx == expected_schedule[i][0] + assert operation.type == expected_schedule[i][1] + + tensor_cpu = torch.zeros(10, 10) + ref = compiled_dag.execute(tensor_cpu) + tensors = ray.get(ref) + tensor_cuda = tensor_cpu.to("cuda:0") + + assert len(tensors) == 2 + for t in tensors: + assert torch.equal(t, tensor_cuda) + + compiled_dag.teardown() + + +@pytest.mark.parametrize("ray_start_regular", [{"num_gpus": 4}], indirect=True) +def test_simulate_pp_4workers_8batches_1f1b(ray_start_regular, monkeypatch): + """ + This 
test simulates a 1F1B pipeline parallelism for training with + 4 workers and 8 batches. + """ + if not USE_GPU: + pytest.skip("NCCL tests require GPUs") + + monkeypatch.setattr(ray.dag.constants, "RAY_ADAG_ENABLE_DETECT_DEADLOCK", False) + + num_workers, num_microbatches, num_lead_microbatches = 4, 8, 4 + compiled_dag = generate_1f1b_dag( + num_workers, num_microbatches, num_lead_microbatches + ) + + tensor_cpu = torch.zeros(10, 10) + tensors = ray.get(compiled_dag.execute(tensor_cpu)) + tensor_cuda = tensor_cpu.to("cuda:0") + assert len(tensors) == num_microbatches + for t in tensors: + assert torch.equal(t, tensor_cuda) + compiled_dag.teardown() + + +@pytest.mark.parametrize("ray_start_regular", [{"num_gpus": 3}], indirect=True) +def test_three_actors_with_nccl_1(ray_start_regular): + """ + Driver -> a.no_op -> b.no_op -> a.no_op_two -> Driver + | | + -> c.no_op - + """ + if not USE_GPU: + pytest.skip("NCCL tests require GPUs") + + a = Worker.remote() + b = Worker.remote() + c = Worker.remote() + + with InputNode() as inp: + dag = a.no_op.bind(inp) + dag.with_type_hint(TorchTensorType(transport="nccl")) + branch1 = b.no_op.bind(dag) + branch1.with_type_hint(TorchTensorType(transport="nccl")) + branch2 = c.no_op.bind(dag) + branch2.with_type_hint(TorchTensorType(transport="nccl")) + dag = a.no_op_two.bind(branch1, branch2) + + compiled_dag = dag.experimental_compile() + + a_expected_schedule = [ + (0, _DAGNodeOperationType.READ), + (0, _DAGNodeOperationType.COMPUTE), + (0, _DAGNodeOperationType.WRITE), + (1, _DAGNodeOperationType.READ), + (1, _DAGNodeOperationType.COMPUTE), + (1, _DAGNodeOperationType.WRITE), + ] + b_expected_schedule = [ + (0, _DAGNodeOperationType.READ), + (0, _DAGNodeOperationType.COMPUTE), + (0, _DAGNodeOperationType.WRITE), + ] + c_expected_schedule = [ + (0, _DAGNodeOperationType.READ), + (0, _DAGNodeOperationType.COMPUTE), + (0, _DAGNodeOperationType.WRITE), + ] + a_schedule = compiled_dag.actor_to_execution_schedule[a] + b_schedule = compiled_dag.actor_to_execution_schedule[b] + c_schedule = compiled_dag.actor_to_execution_schedule[c] + + for schedule, expected_schedule in zip( + [a_schedule, b_schedule, c_schedule], + [a_expected_schedule, b_expected_schedule, c_expected_schedule], + ): + assert len(schedule) == len(expected_schedule) + for i, operation in enumerate(schedule): + assert operation.local_idx == expected_schedule[i][0] + assert operation.type == expected_schedule[i][1] + + tensor_cpu = torch.zeros(10, 10) + ref = compiled_dag.execute(tensor_cpu) + tensors = ray.get(ref) + tensor_cuda = tensor_cpu.to("cuda:0") + + assert len(tensors) == 2 + for t in tensors: + assert torch.equal(t, tensor_cuda) + + compiled_dag.teardown() + + +@pytest.mark.parametrize("ray_start_regular", [{"num_gpus": 3}], indirect=True) +def test_three_actors_with_nccl_2(ray_start_regular, monkeypatch): + if not USE_GPU: + pytest.skip("NCCL tests require GPUs") + + monkeypatch.setattr(ray.dag.constants, "RAY_ADAG_ENABLE_DETECT_DEADLOCK", False) + + a = Worker.remote() + b = Worker.remote() + c = Worker.remote() + + with InputNode() as inp: + branch1 = a.no_op.bind(inp) + branch1.with_type_hint(TorchTensorType(transport="nccl")) + branch2 = b.no_op.bind(inp) + branch2.with_type_hint(TorchTensorType(transport="nccl")) + branch3 = c.no_op.bind(inp) + branch3.with_type_hint(TorchTensorType(transport="nccl")) + dag = MultiOutputNode( + [ + a.no_op.bind(branch3), + b.no_op.bind(branch1), + c.no_op.bind(branch2), + ] + ) + + compiled_dag = dag.experimental_compile() + + a_expected_schedule 
= [ + (0, _DAGNodeOperationType.READ), + (0, _DAGNodeOperationType.COMPUTE), + (1, _DAGNodeOperationType.READ), + (0, _DAGNodeOperationType.WRITE), + (1, _DAGNodeOperationType.COMPUTE), + (1, _DAGNodeOperationType.WRITE), + ] + b_expected_schedule = [ + (0, _DAGNodeOperationType.READ), + (0, _DAGNodeOperationType.COMPUTE), + (1, _DAGNodeOperationType.READ), + (0, _DAGNodeOperationType.WRITE), + (1, _DAGNodeOperationType.COMPUTE), + (1, _DAGNodeOperationType.WRITE), + ] + c_expected_schedule = [ + (0, _DAGNodeOperationType.READ), + (0, _DAGNodeOperationType.COMPUTE), + (0, _DAGNodeOperationType.WRITE), + (1, _DAGNodeOperationType.READ), + (1, _DAGNodeOperationType.COMPUTE), + (1, _DAGNodeOperationType.WRITE), + ] + + a_schedule = compiled_dag.actor_to_execution_schedule[a] + b_schedule = compiled_dag.actor_to_execution_schedule[b] + c_schedule = compiled_dag.actor_to_execution_schedule[c] + + for schedule, expected_schedule in zip( + [a_schedule, b_schedule, c_schedule], + [a_expected_schedule, b_expected_schedule, c_expected_schedule], + ): + assert len(schedule) == len(expected_schedule) + for i, operation in enumerate(schedule): + assert operation.local_idx == expected_schedule[i][0] + assert operation.type == expected_schedule[i][1] + + tensor_cpu = torch.zeros(10, 10) + ref = compiled_dag.execute(tensor_cpu) + tensors = ray.get(ref) + tensor_cuda = tensor_cpu.to("cuda:0") + + assert len(tensors) == 3 + for t in tensors: + assert torch.equal(t, tensor_cuda) + + compiled_dag.teardown() + + +if __name__ == "__main__": + if os.environ.get("PARALLEL_CI"): + sys.exit(pytest.main(["-n", "auto", "--boxed", "-vs", __file__])) + else: + sys.exit(pytest.main(["-sv", __file__]))
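
The tests above all exercise the same core idea: each DAG task contributes READ, COMPUTE, and WRITE nodes to an operation graph, and the execution schedule for every actor is a topological order of that graph in which ties among ready nodes are broken toward the smaller `bind_index` (and READ before COMPUTE before WRITE). The sketch below illustrates that priority-driven topological sort. It is a simplified illustration only: the names (`OpType`, `OpNode`, `generate_schedules`) are invented for this example, and it omits the NCCL handling in Ray's actual `_generate_actor_to_execution_schedule`, which schedules an NCCL WRITE together with its peer READ.

```python
# Illustrative only: a priority-based topological sort over READ/COMPUTE/WRITE
# nodes, ignoring the NCCL-specific grouping done by the real scheduler.
import heapq
import itertools
from enum import Enum
from typing import Dict, List, Tuple


class OpType(Enum):
    READ = 0
    COMPUTE = 1
    WRITE = 2


class OpNode:
    """One READ/COMPUTE/WRITE operation of a task bound to an actor."""

    def __init__(self, actor: str, local_idx: int, op_type: OpType):
        self.actor = actor
        self.local_idx = local_idx  # bind_index of the task on its actor
        self.op_type = op_type
        self.out_edges: List["OpNode"] = []
        self.in_degree = 0


def add_edge(src: OpNode, dst: OpNode) -> None:
    src.out_edges.append(dst)
    dst.in_degree += 1


def generate_schedules(nodes: List[OpNode]) -> Dict[str, List[Tuple[int, str]]]:
    """Kahn's algorithm; among ready nodes, the smallest (local_idx, op_type)
    pair wins, i.e. earlier bind index first, then READ before COMPUTE before
    WRITE."""
    counter = itertools.count()  # tie-breaker so OpNodes are never compared
    heap = [
        (n.local_idx, n.op_type.value, next(counter), n)
        for n in nodes
        if n.in_degree == 0
    ]
    heapq.heapify(heap)
    schedules: Dict[str, List[Tuple[int, str]]] = {}
    while heap:
        _, _, _, node = heapq.heappop(heap)
        schedules.setdefault(node.actor, []).append(
            (node.local_idx, node.op_type.name)
        )
        for nxt in node.out_edges:
            nxt.in_degree -= 1
            if nxt.in_degree == 0:
                heapq.heappush(
                    heap, (nxt.local_idx, nxt.op_type.value, next(counter), nxt)
                )
    return schedules


# Two chained tasks on a single actor, mirroring test_single_actor_1.
ops = {(i, t): OpNode("fake_actor", i, t) for i in (0, 1) for t in OpType}
for i in (0, 1):
    add_edge(ops[(i, OpType.READ)], ops[(i, OpType.COMPUTE)])
    add_edge(ops[(i, OpType.COMPUTE)], ops[(i, OpType.WRITE)])
add_edge(ops[(0, OpType.WRITE)], ops[(1, OpType.READ)])       # data dependency
add_edge(ops[(0, OpType.COMPUTE)], ops[(1, OpType.COMPUTE)])  # control dependency

print(generate_schedules(list(ops.values()))["fake_actor"])
# [(0, 'READ'), (0, 'COMPUTE'), (0, 'WRITE'), (1, 'READ'), (1, 'COMPUTE'), (1, 'WRITE')]
```

Running the sketch prints the same order asserted in `test_single_actor_1`: the two tasks' READ/COMPUTE/WRITE operations fully serialized by bind index.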