@@ -780,7 +780,7 @@ def test_two_actors_with_nccl(self, monkeypatch):
780
780
graph [dag_idx_2_2 ][_DAGNodeOperationType .WRITE ].operation ,
781
781
]
782
782
783
- def test_simulate_pp_2workers_2batches_1f1b (self , monkeypatch ):
783
+ def test_simulate_pp_2workers_2batches_1f1b_with_nccl (self , monkeypatch ):
784
784
"""
785
785
This test simulates a simple 1F1B pipeline parallelism for training with
786
786
2 workers and 2 batches.
@@ -884,6 +884,114 @@ def test_simulate_pp_2workers_2batches_1f1b(self, monkeypatch):
884
884
graph [dag_idx_2_4 ][_DAGNodeOperationType .WRITE ].operation ,
885
885
]
886
886
887
def test_simulate_pp_2workers_2batches_1f1b_no_nccl(self, monkeypatch):
    """
    Simulate a simple 1F1B pipeline-parallel training step with 2 workers
    and 2 batches, without any NCCL operation.

    w1: fwd_b1 fwd_b2 bwd_b1 bwd_b2
    w2: fwd_b1 bwd_b1 fwd_b2 bwd_b2

    Since no NCCL operation is involved, every operation with a smaller
    `bind_index` is expected to run before the operations with a larger
    `bind_index` on the same actor.
    """
    monkeypatch.setattr(ActorHandle, "__init__", mock_actor_handle_init)

    worker_1 = ActorHandle("worker_1")
    worker_2 = ActorHandle("worker_2")

    # DAG indices of the tasks, named after the step they stand for in the
    # schedule drawn in the docstring. Each tuple below is listed in bind
    # order, so the position within a tuple is the task's local index.
    w1_fwd_b1, w1_fwd_b2, w1_bwd_b1, w1_bwd_b2 = 1, 2, 3, 4
    w2_fwd_b1, w2_bwd_b1, w2_fwd_b2, w2_bwd_b2 = 5, 6, 7, 8
    w1_order = (w1_fwd_b1, w1_fwd_b2, w1_bwd_b1, w1_bwd_b2)
    w2_order = (w2_fwd_b1, w2_bwd_b1, w2_fwd_b2, w2_bwd_b2)
    tasks_per_worker = ((worker_1, w1_order), (worker_2, w2_order))

    # No NCCL operation: the last argument is False for every task.
    graph = {}
    for worker, dag_indices in tasks_per_worker:
        for local_idx, dag_idx in enumerate(dag_indices):
            graph[dag_idx] = generate_dag_graph_nodes(
                local_idx, dag_idx, worker, False
            )

    # READ -> COMPUTE -> WRITE edges inside each task.
    for dag_idx in w1_order + w2_order:
        self.add_edge_between_read_compute_write(graph[dag_idx])

    # Data dependencies between tasks.
    data_edges = (
        (w1_fwd_b1, w2_fwd_b1),
        (w2_fwd_b1, w2_bwd_b1),
        (w2_bwd_b1, w1_bwd_b1),
        (w1_fwd_b2, w2_fwd_b2),
        (w2_fwd_b2, w2_bwd_b2),
        (w2_bwd_b2, w1_bwd_b2),
    )
    for upstream, downstream in data_edges:
        self.add_data_dependeny(graph[upstream], graph[downstream])

    # Control dependencies chain consecutive tasks on the same worker.
    for _, dag_indices in tasks_per_worker:
        for earlier, later in zip(dag_indices, dag_indices[1:]):
            self.add_control_dependency(graph[earlier], graph[later])

    actor_to_execution_schedule = _generate_actor_to_execution_schedule(graph)
    assert len(actor_to_execution_schedule) == 2
    assert len(actor_to_execution_schedule[worker_1]) == 12
    assert len(actor_to_execution_schedule[worker_2]) == 12

    op_types = (
        _DAGNodeOperationType.READ,
        _DAGNodeOperationType.COMPUTE,
        _DAGNodeOperationType.WRITE,
    )
    # Each worker runs its tasks strictly in bind order, and every task runs
    # READ, COMPUTE, WRITE back to back. For worker_2 this puts the WRITE of
    # dag index 6 before the READ of dag index 7; that order is important and
    # differs from the case where there is an NCCL operation.
    assert actor_to_execution_schedule[worker_1] == [
        graph[dag_idx][op_type].operation
        for dag_idx in w1_order
        for op_type in op_types
    ]
    assert actor_to_execution_schedule[worker_2] == [
        graph[dag_idx][op_type].operation
        for dag_idx in w2_order
        for op_type in op_types
    ]
994
+
887
995
888
996
if __name__ == "__main__" :
889
997
if os .environ .get ("PARALLEL_CI" ):
0 commit comments