Skip to content

Commit 12d4ede

Browse files
bveeramanigemini-code-assist[bot]
authored and committed
[Data] Add preemption test for DataOpTask and refactor test utilities (ray-project#57883)
This PR adds a test to verify that DataOpTask handles node failures correctly during execution. To enable this testing, callback seams are added to DataOpTask that allow tests to simulate preemption scenarios by killing and restarting nodes at specific points during task execution. ## Summary - Add callback seams (`block_ready_callback` and `metadata_ready_callback`) to `DataOpTask` for testing purposes - Add `has_finished` property to track task completion state - Create `create_stub_streaming_gen` helper function to simplify test setup - Refactor existing `DataOpTask` tests to use the new helper function - Add new parametrized test `test_on_data_ready_with_preemption` to verify behavior when nodes fail during execution ## Test plan - Existing tests pass with refactored code - New preemption test validates that `on_data_ready` handles node failures correctly by testing both block and metadata callback scenarios --------- Signed-off-by: Balaji Veeramani <bveeramani@berkeley.edu> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent 6c05a26 commit 12d4ede

File tree

2 files changed

+190
-41
lines changed

2 files changed

+190
-41
lines changed

python/ray/data/_internal/execution/interfaces/physical_operator.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,12 @@ def __init__(
101101
streaming_gen: ObjectRefGenerator,
102102
output_ready_callback: Callable[[RefBundle], None] = lambda bundle: None,
103103
task_done_callback: Callable[[Optional[Exception]], None] = lambda exc: None,
104+
block_ready_callback: Callable[
105+
[ray.ObjectRef[Block]], None
106+
] = lambda block_ref: None,
107+
metadata_ready_callback: Callable[
108+
[ray.ObjectRef[BlockMetadata]], None
109+
] = lambda metadata_ref: None,
104110
task_resource_bundle: Optional[ExecutionResources] = None,
105111
):
106112
"""Create a DataOpTask
@@ -110,6 +116,10 @@ def __init__(
110116
output_ready_callback: The callback to call when a new RefBundle is output
111117
from the generator.
112118
task_done_callback: The callback to call when the task is done.
119+
block_ready_callback: A callback that's invoked when a new block reference
120+
is ready. This is exposed as a seam for testing.
121+
metadata_ready_callback: A callback that's invoked when a new block metadata
122+
reference is ready. This is exposed as a seam for testing.
113123
task_resource_bundle: The execution resources of this task.
114124
"""
115125
super().__init__(task_index, task_resource_bundle)
@@ -120,6 +130,8 @@ def __init__(
120130
self._streaming_gen = streaming_gen
121131
self._output_ready_callback = output_ready_callback
122132
self._task_done_callback = task_done_callback
133+
self._block_ready_callback = block_ready_callback
134+
self._metadata_ready_callback = metadata_ready_callback
123135

124136
# If the generator hasn't produced block metadata yet, or if the block metadata
125137
# object isn't available after we get a reference, we need to store the pending
@@ -128,6 +140,8 @@ def __init__(
128140
self._pending_block_ref: ray.ObjectRef[Block] = ray.ObjectRef.nil()
129141
self._pending_meta_ref: ray.ObjectRef[BlockMetadata] = ray.ObjectRef.nil()
130142

143+
self._has_finished = False
144+
131145
def get_waitable(self) -> ObjectRefGenerator:
132146
return self._streaming_gen
133147

@@ -154,13 +168,16 @@ def on_data_ready(self, max_bytes_to_read: Optional[int]) -> int:
154168
)
155169
except StopIteration:
156170
self._task_done_callback(None)
171+
self._has_finished = True
157172
break
158173

159174
if self._pending_block_ref.is_nil():
160175
# The generator currently doesn't have new output.
161176
# And it's not stopped yet.
162177
break
163178

179+
self._block_ready_callback(self._pending_block_ref)
180+
164181
if self._pending_meta_ref.is_nil():
165182
try:
166183
self._pending_meta_ref = self._streaming_gen._next_sync(
@@ -178,13 +195,16 @@ def on_data_ready(self, max_bytes_to_read: Optional[int]) -> int:
178195
assert False, "Above ray.get should raise an exception."
179196
except Exception as ex:
180197
self._task_done_callback(ex)
198+
self._has_finished = True
181199
raise ex from None
182200

183201
if self._pending_meta_ref.is_nil():
184202
# We have a reference to the block but the metadata isn't ready
185203
# yet.
186204
break
187205

206+
self._metadata_ready_callback(self._pending_meta_ref)
207+
188208
try:
189209
# The timeout for `ray.get` includes the time required to ship the
190210
# block metadata to this node. So, if we set the timeout to 0, `ray.get`
@@ -220,6 +240,10 @@ def on_data_ready(self, max_bytes_to_read: Optional[int]) -> int:
220240

221241
return bytes_read
222242

243+
@property
def has_finished(self) -> bool:
    """Whether this task has finished.

    The flag starts out ``False`` and is set to ``True`` by ``on_data_ready``
    once the streaming generator is exhausted or the task fails with an
    exception.
    """
    return self._has_finished
246+
223247

224248
class MetadataOpTask(OpTask):
225249
"""Represents an OpTask that only handles metadata, instead of Block data."""

python/ray/data/tests/test_streaming_executor.py

Lines changed: 166 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import time
33
import unittest
44
from concurrent.futures import ThreadPoolExecutor
5-
from typing import Optional
5+
from typing import List, Literal, Optional, Union
66
from unittest.mock import MagicMock, patch
77

88
import numpy as np
@@ -973,12 +973,34 @@ def test_create_topology_metadata_with_sub_stages():
973973
assert sub_stage2.id.endswith("_sub_1")
974974

975975

976-
class TestDataOpTask:
977-
def test_on_data_ready_single_output(self, ray_start_regular_shared):
978-
@ray.remote
979-
def map_task():
976+
def create_stub_streaming_gen(
977+
block_nbytes: List[int], raise_exception: Optional[Exception] = None
978+
) -> ray.ObjectRefGenerator:
979+
"""Create a streaming generator for testing.
980+
981+
The streaming generator passed to the ``DataOpTask`` constructor must yield blocks
982+
then block metadata, and buffer the number of blocks specified by
983+
``_max_num_blocks_in_streaming_gen_buffer``. This function is a utility to create
984+
streaming generators that satisfy these requirements.
985+
986+
Args:
987+
block_nbytes: A list of the sizes of blocks yielded by the returned streaming
988+
generator.
989+
raise_exception: An exception that the streaming generator immediately raises.
990+
991+
Returns:
992+
A streaming generator that you can pass to ``DataOpTask``.
993+
"""
994+
995+
@ray.remote
996+
def stub_map_task():
997+
if raise_exception is not None:
998+
raise raise_exception
999+
1000+
for nbytes in block_nbytes:
1001+
# Create a block with a single row of the specified size.
9801002
builder = DelegatingBlockBuilder()
981-
builder.add_batch({"data": np.zeros((1, 128 * MiB), dtype=np.uint8)})
1003+
builder.add_batch({"data": np.zeros((1, nbytes), dtype=np.uint8)})
9821004
block = builder.build()
9831005
yield block
9841006

@@ -988,68 +1010,171 @@ def map_task():
9881010
block_metadata, schema=block_accessor.schema()
9891011
)
9901012

1013+
generator_backpressure_num_objects = (
1014+
ray.data.DataContext.get_current()._max_num_blocks_in_streaming_gen_buffer
1015+
* 2 # Multiply by two because we yield a metadata object for each block.
1016+
)
1017+
streaming_gen = stub_map_task.options(
1018+
_generator_backpressure_num_objects=generator_backpressure_num_objects
1019+
).remote()
1020+
1021+
return streaming_gen
1022+
1023+
1024+
@pytest.fixture
def ensure_block_metadata_stored_in_plasma(monkeypatch):
    """Disable Ray's small-object inlining for the duration of a test.

    Sets ``RAY_max_direct_call_object_size`` to ``"0"`` through ``monkeypatch``,
    so the variable is restored when the test finishes.
    """
    # Ray inlines small objects (including metadata) by storing them directly with
    # the object reference itself rather than in the remote node's object store.
    # Consequently, when the streaming executor calls `ray.get` on metadata from a
    # node that has died, the call succeeds because the inlined metadata is not
    # stored in the failed node's object store. To explicitly test the case where
    # metadata resides in the object store (and becomes unavailable when the node
    # dies), we disable inlining by setting the maximum inline size to 0. This
    # simulates scenarios where metadata is too large to inline, which can occur in
    # practice when schemas contain many fields.
    #
    # For context, see https://github.com/ray-project/ray/pull/56451.
    #
    # Environment variable values must be strings: `monkeypatch.setenv` warns and
    # converts implicitly when handed a non-str value, so pass "0" rather than 0.
    monkeypatch.setenv("RAY_max_direct_call_object_size", "0")
1038+
1039+
1040+
class TestDataOpTask:
1041+
def test_on_data_ready_single_output(self, ray_start_regular_shared):
1042+
streaming_gen = create_stub_streaming_gen(block_nbytes=[128 * MiB])
1043+
9911044
def verify_output(bundle):
992-
assert bundle.num_rows() == 1, bundle.num_rows()
9931045
assert bundle.size_bytes() == pytest.approx(128 * MiB), bundle.size_bytes()
9941046

995-
has_completed = False
1047+
data_op_task = DataOpTask(0, streaming_gen, output_ready_callback=verify_output)
9961048

997-
def verify_exception(exc: Optional[Exception]):
998-
nonlocal has_completed
1049+
bytes_read = 0
1050+
while not data_op_task.has_finished:
1051+
ray.wait([streaming_gen], fetch_local=False)
1052+
nbytes_read = data_op_task.on_data_ready(None)
1053+
bytes_read += nbytes_read
9991054

1000-
assert exc is None
1001-
has_completed = True
1055+
assert bytes_read == pytest.approx(128 * MiB)
10021056

1003-
generator_backpressure_num_objects = (
1004-
ray.data.DataContext.get_current()._max_num_blocks_in_streaming_gen_buffer
1005-
* 2 # Multiply by two because we yield a metadata object for each block.
1006-
)
1007-
streaming_gen = map_task.options(
1008-
_generator_backpressure_num_objects=generator_backpressure_num_objects
1009-
).remote()
1057+
def test_on_data_ready_multiple_outputs(self, ray_start_regular_shared):
1058+
streaming_gen = create_stub_streaming_gen(block_nbytes=[128 * MiB, 128 * MiB])
10101059

1011-
data_op_task = DataOpTask(
1012-
0,
1013-
streaming_gen,
1014-
output_ready_callback=verify_output,
1015-
task_done_callback=verify_exception,
1016-
)
1060+
def verify_output(bundle):
1061+
assert bundle.size_bytes() == pytest.approx(128 * MiB), bundle.size_bytes()
1062+
1063+
data_op_task = DataOpTask(0, streaming_gen, output_ready_callback=verify_output)
10171064

10181065
bytes_read = 0
1019-
while not has_completed:
1066+
while not data_op_task.has_finished:
10201067
ray.wait([streaming_gen], fetch_local=False)
1021-
bytes_read += data_op_task.on_data_ready(None)
1068+
nbytes_read = data_op_task.on_data_ready(None)
1069+
bytes_read += nbytes_read
10221070

1023-
assert bytes_read == pytest.approx(128 * MiB)
1071+
assert bytes_read == pytest.approx(256 * MiB)
10241072

10251073
def test_on_data_ready_exception(self, ray_start_regular_shared):
1026-
@ray.remote
1027-
def map_task():
1028-
assert False, "Block generation failed"
1029-
yield
1074+
streaming_gen = create_stub_streaming_gen(
1075+
block_nbytes=[128 * MiB],
1076+
raise_exception=AssertionError("Block generation failed"),
1077+
)
10301078

10311079
def verify_exception(exc: Optional[Exception]):
10321080
assert isinstance(exc, AssertionError)
10331081

1034-
generator_backpressure_num_objects = (
1035-
ray.data.DataContext.get_current()._max_num_blocks_in_streaming_gen_buffer
1036-
* 2 # Multiply by two because we yield a metadata object for each block.
1037-
)
1038-
streaming_gen = map_task.options(
1039-
_generator_backpressure_num_objects=generator_backpressure_num_objects
1040-
).remote()
1041-
10421082
data_op_task = DataOpTask(
10431083
0,
10441084
streaming_gen,
10451085
task_done_callback=verify_exception,
10461086
)
10471087

10481088
with pytest.raises(AssertionError, match="Block generation failed"):
1049-
while True:
1089+
while not data_op_task.has_finished:
10501090
ray.wait([streaming_gen], fetch_local=False)
10511091
data_op_task.on_data_ready(None)
10521092

1093+
@pytest.mark.parametrize(
    "preempt_on", ["block_ready_callback", "metadata_ready_callback"]
)
def test_on_data_ready_with_preemption_during_call(
    self,
    preempt_on: Union[
        Literal["block_ready_callback"], Literal["metadata_ready_callback"]
    ],
    ray_start_cluster_enabled,
    ensure_block_metadata_stored_in_plasma,
):
    """Test that ``on_data_ready`` works when a node dies during its execution.

    The worker node is removed and replaced from inside the ``DataOpTask``
    callback named by ``preempt_on``. The task must still run to completion and
    report the full 128 MiB block.

    Args:
        preempt_on: Name of the ``DataOpTask`` callback keyword argument in which
            the worker node is preempted.
        ray_start_cluster_enabled: Fixture providing the cluster to add nodes to.
        ensure_block_metadata_stored_in_plasma: Fixture that disables object
            inlining so block metadata is lost along with the node.
    """
    # Shut down Ray in case it's already initialized.
    ray.shutdown()

    # Create a cluster with a zero-CPU head node and a single worker node that
    # has 1 logical CPU.
    cluster = ray_start_cluster_enabled
    cluster.add_node(num_cpus=0)
    cluster.wait_for_nodes()
    ray.init()

    worker_node = cluster.add_node(num_cpus=1)
    cluster.wait_for_nodes()

    # Create a streaming generator that produces a single 128 MiB output block, and
    # configure the task so that it preempts the worker node in the specified
    # callback.
    streaming_gen = create_stub_streaming_gen(block_nbytes=[128 * MiB])

    has_preempted = False

    def remove_and_add_back_worker_node(_):
        # `on_data_ready` invokes the callback whenever it sees a pending
        # reference, which can happen more than once for the same reference if a
        # later step isn't ready yet. Only preempt on the first invocation so we
        # never try to remove the already-removed node again.
        nonlocal has_preempted
        if has_preempted:
            return
        has_preempted = True

        cluster.remove_node(worker_node)

        cluster.add_node(num_cpus=1)
        cluster.wait_for_nodes()

    data_op_task = DataOpTask(
        0, streaming_gen, **{preempt_on: remove_and_add_back_worker_node}
    )

    # Run the task to completion.
    bytes_read = 0
    while not data_op_task.has_finished:
        ray.wait([streaming_gen], fetch_local=False)
        bytes_read += data_op_task.on_data_ready(None)

    # Ensure that we read the expected amount of data. Since the streaming generator
    # yields a single 128 MiB block, we should read 128 MiB.
    assert bytes_read == pytest.approx(128 * MiB)
1140+
1141+
def test_on_data_ready_with_preemption_after_wait(
    self, ray_start_cluster_enabled, ensure_block_metadata_stored_in_plasma
):
    """Test ``on_data_ready`` when the worker node dies right after ``ray.wait``.

    The first ``on_data_ready`` call after the node is removed is expected to read
    nothing. Once a replacement worker node joins, driving the task to completion
    is expected to read the full 128 MiB block.
    """
    # Start from a clean slate in case Ray is already initialized.
    ray.shutdown()

    # Bring up a zero-CPU head node, connect to Ray, then add one worker node with
    # 1 logical CPU.
    cluster = ray_start_cluster_enabled
    _head_node = cluster.add_node(num_cpus=0)  # noqa: F841
    cluster.wait_for_nodes()
    ray.init()

    original_worker = cluster.add_node(num_cpus=1)
    cluster.wait_for_nodes()

    # The stub generator yields exactly one 128 MiB block.
    gen = create_stub_streaming_gen(block_nbytes=[128 * MiB])
    task = DataOpTask(0, gen)

    # Preempt the worker node as soon as the generator reports output.
    ray.wait([gen], fetch_local=False)
    cluster.remove_node(original_worker)

    # With the node gone the block shouldn't be available, so nothing is read.
    total_bytes_read = task.on_data_ready(None)
    assert total_bytes_read == 0

    # Bring a replacement worker node online and drive the task until it finishes.
    _replacement_worker = cluster.add_node(num_cpus=1)  # noqa: F841
    cluster.wait_for_nodes()
    while not task.has_finished:
        ray.wait([gen], fetch_local=False)
        total_bytes_read += task.on_data_ready(None)

    # The full 128 MiB block should have been read by now.
    assert total_bytes_read == pytest.approx(128 * MiB)
1177+
10531178

10541179
if __name__ == "__main__":
10551180
import sys

0 commit comments

Comments
 (0)