|
10 | 10 | from ray.experimental.channel.conftest import start_nccl_mock
|
11 | 11 | from ray.tests.conftest import * # noqa
|
12 | 12 | from ray.dag import InputNode, MultiOutputNode
|
| 13 | +from ray.dag.compiled_dag_node import DAGNodeOperationType |
13 | 14 |
|
14 | 15 | if sys.platform != "linux" and sys.platform != "darwin":
|
15 | 16 | pytest.skip("Skipping, requires Linux or Mac.", allow_module_level=True)
|
@@ -72,6 +73,45 @@ def test_simulate_pp_2workers_1f1b(ray_start_regular, monkeypatch):
|
72 | 73 | ]
|
73 | 74 | )
|
74 | 75 | compiled_graph = dag.experimental_compile()
|
| 76 | + |
| 77 | + w1_expected_schedule = [ |
| 78 | + (0, DAGNodeOperationType.READ), |
| 79 | + (0, DAGNodeOperationType.COMPUTE), |
| 80 | + (0, DAGNodeOperationType.WRITE), |
| 81 | + (1, DAGNodeOperationType.READ), |
| 82 | + (1, DAGNodeOperationType.COMPUTE), |
| 83 | + (1, DAGNodeOperationType.WRITE), |
| 84 | + (2, DAGNodeOperationType.READ), |
| 85 | + (2, DAGNodeOperationType.COMPUTE), |
| 86 | + (2, DAGNodeOperationType.WRITE), |
| 87 | + (3, DAGNodeOperationType.READ), |
| 88 | + (3, DAGNodeOperationType.COMPUTE), |
| 89 | + (3, DAGNodeOperationType.WRITE), |
| 90 | + ] |
| 91 | + w2_expected_schedule = [ |
| 92 | + (0, DAGNodeOperationType.READ), |
| 93 | + (0, DAGNodeOperationType.COMPUTE), |
| 94 | + (0, DAGNodeOperationType.WRITE), |
| 95 | + (1, DAGNodeOperationType.READ), |
| 96 | + (1, DAGNodeOperationType.COMPUTE), |
| 97 | + (2, DAGNodeOperationType.READ), |
| 98 | + (1, DAGNodeOperationType.WRITE), |
| 99 | + (2, DAGNodeOperationType.COMPUTE), |
| 100 | + (2, DAGNodeOperationType.WRITE), |
| 101 | + (3, DAGNodeOperationType.READ), |
| 102 | + (3, DAGNodeOperationType.COMPUTE), |
| 103 | + (3, DAGNodeOperationType.WRITE), |
| 104 | + ] |
| 105 | + w1_schedule = compiled_graph.actor_to_execution_schedule[w1] |
| 106 | + w2_schedule = compiled_graph.actor_to_execution_schedule[w2] |
| 107 | + |
| 108 | + for schedule, expected_schedule in zip( |
| 109 | + [w1_schedule, w2_schedule], [w1_expected_schedule, w2_expected_schedule] |
| 110 | + ): |
| 111 | + assert len(schedule) == len(expected_schedule) |
| 112 | + for i, operation in enumerate(schedule): |
| 113 | + assert operation.idx == expected_schedule[i][0] |
| 114 | + assert operation.type == expected_schedule[i][1] |
75 | 115 | compiled_graph.teardown()
|
76 | 116 |
|
77 | 117 |
|
|
0 commit comments