Skip to content

Commit c45298b

Browse files
committed
add a test for 1f1b without nccl
Signed-off-by: Kai-Hsun Chen <kaihsun@anyscale.com>
1 parent c98362b commit c45298b

File tree

1 file changed

+109
-1
lines changed

1 file changed

+109
-1
lines changed

python/ray/dag/tests/experimental/test_execution_schedule.py

+109-1
Original file line numberDiff line numberDiff line change
@@ -780,7 +780,7 @@ def test_two_actors_with_nccl(self, monkeypatch):
780780
graph[dag_idx_2_2][_DAGNodeOperationType.WRITE].operation,
781781
]
782782

783-
def test_simulate_pp_2workers_2batches_1f1b(self, monkeypatch):
783+
def test_simulate_pp_2workers_2batches_1f1b_with_nccl(self, monkeypatch):
784784
"""
785785
This test simulates a simple 1F1B pipeline parallelism for training with
786786
2 workers and 2 batches.
@@ -884,6 +884,114 @@ def test_simulate_pp_2workers_2batches_1f1b(self, monkeypatch):
884884
graph[dag_idx_2_4][_DAGNodeOperationType.WRITE].operation,
885885
]
886886

887+
def test_simulate_pp_2workers_2batches_1f1b_no_nccl(self, monkeypatch):
    """
    Simulate 1F1B pipeline-parallel training with 2 workers and 2 batches,
    where no DAG node uses NCCL.

    Per-worker task order (by `bind_index`):

        w1: fwd_b1 fwd_b2 bwd_b1 bwd_b2
        w2: fwd_b1 bwd_b1 fwd_b2 bwd_b2

    Since no NCCL operation is involved, every actor is expected to run
    the READ/COMPUTE/WRITE operations of a task with a smaller `bind_index`
    before any operation of a task with a larger `bind_index`.
    """
    monkeypatch.setattr(ActorHandle, "__init__", mock_actor_handle_init)

    worker_1 = ActorHandle("worker_1")
    worker_2 = ActorHandle("worker_2")

    # DAG indices of each worker's tasks, listed in bind order. The local
    # (per-actor) index of a task is its position in the list.
    w1_dag_ids = [1, 2, 3, 4]
    w2_dag_ids = [5, 6, 7, 8]

    # Build the graph nodes; the trailing `False` means "no NCCL operation".
    graph = {}
    for actor, dag_ids in ((worker_1, w1_dag_ids), (worker_2, w2_dag_ids)):
        for local_idx, dag_idx in enumerate(dag_ids):
            graph[dag_idx] = generate_dag_graph_nodes(
                local_idx, dag_idx, actor, False
            )

    # READ -> COMPUTE -> WRITE edges inside every task.
    for dag_idx in w1_dag_ids + w2_dag_ids:
        self.add_edge_between_read_compute_write(graph[dag_idx])

    # Data dependencies: each batch flows w1.fwd -> w2.fwd -> w2.bwd -> w1.bwd.
    data_edges = [
        # batch 1
        (w1_dag_ids[0], w2_dag_ids[0]),
        (w2_dag_ids[0], w2_dag_ids[1]),
        (w2_dag_ids[1], w1_dag_ids[2]),
        # batch 2
        (w1_dag_ids[1], w2_dag_ids[2]),
        (w2_dag_ids[2], w2_dag_ids[3]),
        (w2_dag_ids[3], w1_dag_ids[3]),
    ]
    for upstream, downstream in data_edges:
        self.add_data_dependeny(graph[upstream], graph[downstream])

    # Control dependencies: consecutive tasks on the same actor.
    for dag_ids in (w1_dag_ids, w2_dag_ids):
        for earlier, later in zip(dag_ids, dag_ids[1:]):
            self.add_control_dependency(graph[earlier], graph[later])

    def in_bind_order(dag_ids):
        # Expected schedule: READ, COMPUTE, WRITE of each task, task by task.
        op_types = (
            _DAGNodeOperationType.READ,
            _DAGNodeOperationType.COMPUTE,
            _DAGNodeOperationType.WRITE,
        )
        return [
            graph[dag_idx][op_type].operation
            for dag_idx in dag_ids
            for op_type in op_types
        ]

    actor_to_execution_schedule = _generate_actor_to_execution_schedule(graph)
    assert len(actor_to_execution_schedule) == 2
    assert len(actor_to_execution_schedule[worker_1]) == 12
    assert len(actor_to_execution_schedule[worker_2]) == 12
    assert actor_to_execution_schedule[worker_1] == in_bind_order(w1_dag_ids)
    # On worker_2 the relative order of the second task's WRITE and the third
    # task's READ matters: here WRITE comes first, which is different from
    # the case where there is an NCCL operation.
    assert actor_to_execution_schedule[worker_2] == in_bind_order(w2_dag_ids)
994+
887995

888996
if __name__ == "__main__":
889997
if os.environ.get("PARALLEL_CI"):

0 commit comments

Comments
 (0)