Skip to content

Commit bc90966

Browse files
committed
polish
1 parent b2ac0f9 commit bc90966

File tree

1 file changed

+40
-0
lines changed

1 file changed

+40
-0
lines changed

python/ray/dag/tests/experimental/test_execution_schedule.py

+40
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from ray.experimental.channel.conftest import start_nccl_mock
1111
from ray.tests.conftest import * # noqa
1212
from ray.dag import InputNode, MultiOutputNode
13+
from ray.dag.compiled_dag_node import DAGNodeOperationType
1314

1415
if sys.platform != "linux" and sys.platform != "darwin":
1516
pytest.skip("Skipping, requires Linux or Mac.", allow_module_level=True)
@@ -72,6 +73,45 @@ def test_simulate_pp_2workers_1f1b(ray_start_regular, monkeypatch):
7273
]
7374
)
7475
compiled_graph = dag.experimental_compile()
76+
77+
w1_expected_schedule = [
78+
(0, DAGNodeOperationType.READ),
79+
(0, DAGNodeOperationType.COMPUTE),
80+
(0, DAGNodeOperationType.WRITE),
81+
(1, DAGNodeOperationType.READ),
82+
(1, DAGNodeOperationType.COMPUTE),
83+
(1, DAGNodeOperationType.WRITE),
84+
(2, DAGNodeOperationType.READ),
85+
(2, DAGNodeOperationType.COMPUTE),
86+
(2, DAGNodeOperationType.WRITE),
87+
(3, DAGNodeOperationType.READ),
88+
(3, DAGNodeOperationType.COMPUTE),
89+
(3, DAGNodeOperationType.WRITE),
90+
]
91+
w2_expected_schedule = [
92+
(0, DAGNodeOperationType.READ),
93+
(0, DAGNodeOperationType.COMPUTE),
94+
(0, DAGNodeOperationType.WRITE),
95+
(1, DAGNodeOperationType.READ),
96+
(1, DAGNodeOperationType.COMPUTE),
97+
(2, DAGNodeOperationType.READ),
98+
(1, DAGNodeOperationType.WRITE),
99+
(2, DAGNodeOperationType.COMPUTE),
100+
(2, DAGNodeOperationType.WRITE),
101+
(3, DAGNodeOperationType.READ),
102+
(3, DAGNodeOperationType.COMPUTE),
103+
(3, DAGNodeOperationType.WRITE),
104+
]
105+
w1_schedule = compiled_graph.actor_to_execution_schedule[w1]
106+
w2_schedule = compiled_graph.actor_to_execution_schedule[w2]
107+
108+
for schedule, expected_schedule in zip(
109+
[w1_schedule, w2_schedule], [w1_expected_schedule, w2_expected_schedule]
110+
):
111+
assert len(schedule) == len(expected_schedule)
112+
for i, operation in enumerate(schedule):
113+
assert operation.idx == expected_schedule[i][0]
114+
assert operation.type == expected_schedule[i][1]
75115
compiled_graph.teardown()
76116

77117

0 commit comments

Comments
 (0)