 import time
 import unittest
 from concurrent.futures import ThreadPoolExecutor
-from typing import Optional
+from typing import List, Literal, Optional, Union
 from unittest.mock import MagicMock, patch
 
 import numpy as np
@@ -973,12 +973,34 @@ def test_create_topology_metadata_with_sub_stages():
     assert sub_stage2.id.endswith("_sub_1")
 
 
-class TestDataOpTask:
-    def test_on_data_ready_single_output(self, ray_start_regular_shared):
-        @ray.remote
-        def map_task():
+def create_stub_streaming_gen(
+    block_nbytes: List[int], raise_exception: Optional[Exception] = None
+) -> ray.ObjectRefGenerator:
+    """Create a streaming generator for testing.
+
+    The streaming generator passed to the ``DataOpTask`` constructor must yield each
+    block followed by its metadata, and buffer the number of blocks specified by
+    ``_max_num_blocks_in_streaming_gen_buffer``. This function is a utility to create
+    streaming generators that satisfy these requirements.
+
+    Args:
+        block_nbytes: A list of the sizes of blocks yielded by the returned streaming
+            generator.
+        raise_exception: An exception that the streaming generator immediately raises.
+
+    Returns:
+        A streaming generator that you can pass to ``DataOpTask``.
+    """
+
+    @ray.remote
+    def stub_map_task():
+        if raise_exception is not None:
+            raise raise_exception
+
+        for nbytes in block_nbytes:
+            # Create a block with a single row of the specified size.
             builder = DelegatingBlockBuilder()
-            builder.add_batch({"data": np.zeros((1, 128 * MiB), dtype=np.uint8)})
+            builder.add_batch({"data": np.zeros((1, nbytes), dtype=np.uint8)})
             block = builder.build()
             yield block
@@ -988,68 +1010,171 @@ def map_task():
                 block_metadata, schema=block_accessor.schema()
             )
 
+    generator_backpressure_num_objects = (
+        ray.data.DataContext.get_current()._max_num_blocks_in_streaming_gen_buffer
+        * 2  # Multiply by two because we yield a metadata object for each block.
+    )
+    streaming_gen = stub_map_task.options(
+        _generator_backpressure_num_objects=generator_backpressure_num_objects
+    ).remote()
+
+    return streaming_gen
+
+
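+# Illustrative sketch (not part of the helper's API, names mirror the tests below):
+# the tests consume this helper by wrapping the generator in a ``DataOpTask``, then
+# repeatedly waiting on the generator and consuming whatever outputs are ready:
+#
+#     streaming_gen = create_stub_streaming_gen(block_nbytes=[1 * MiB])
+#     data_op_task = DataOpTask(0, streaming_gen)
+#     while not data_op_task.has_finished:
+#         ray.wait([streaming_gen], fetch_local=False)
+#         data_op_task.on_data_ready(None)
+
+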
+@pytest.fixture
+def ensure_block_metadata_stored_in_plasma(monkeypatch):
+    # Ray inlines small objects (including metadata) by storing them directly with
+    # the object reference itself rather than in the remote node's object store.
+    # Consequently, when the streaming executor calls `ray.get` on metadata from a
+    # node that has died, the call succeeds because the inlined metadata is not
+    # stored in the failed node's object store. To explicitly test the case where
+    # metadata resides in the object store (and becomes unavailable when the node
+    # dies), we disable inlining by setting the maximum inline size to 0. This
+    # simulates scenarios where metadata is too large to inline, which can occur in
+    # practice when schemas contain many fields.
+    #
+    # For context, see https://github.com/ray-project/ray/pull/56451.
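+    # Note: `monkeypatch` automatically restores the original value of this
+    # environment variable when the test finishes.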
+    monkeypatch.setenv("RAY_max_direct_call_object_size", "0")
+
+
+class TestDataOpTask:
+    def test_on_data_ready_single_output(self, ray_start_regular_shared):
+        streaming_gen = create_stub_streaming_gen(block_nbytes=[128 * MiB])
+
         def verify_output(bundle):
-            assert bundle.num_rows() == 1, bundle.num_rows()
             assert bundle.size_bytes() == pytest.approx(128 * MiB), bundle.size_bytes()
 
-        has_completed = False
+        data_op_task = DataOpTask(0, streaming_gen, output_ready_callback=verify_output)
 
-        def verify_exception(exc: Optional[Exception]):
-            nonlocal has_completed
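+        # Drive the task to completion, consuming outputs as they become ready.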
+        bytes_read = 0
+        while not data_op_task.has_finished:
+            ray.wait([streaming_gen], fetch_local=False)
+            nbytes_read = data_op_task.on_data_ready(None)
+            bytes_read += nbytes_read
 
-            assert exc is None
-            has_completed = True
+        assert bytes_read == pytest.approx(128 * MiB)
 
-        generator_backpressure_num_objects = (
-            ray.data.DataContext.get_current()._max_num_blocks_in_streaming_gen_buffer
-            * 2  # Multiply by two because we yield a metadata object for each block.
-        )
-        streaming_gen = map_task.options(
-            _generator_backpressure_num_objects=generator_backpressure_num_objects
-        ).remote()
+    def test_on_data_ready_multiple_outputs(self, ray_start_regular_shared):
+        streaming_gen = create_stub_streaming_gen(block_nbytes=[128 * MiB, 128 * MiB])
 
-        data_op_task = DataOpTask(
-            0,
-            streaming_gen,
-            output_ready_callback=verify_output,
-            task_done_callback=verify_exception,
-        )
+        def verify_output(bundle):
+            assert bundle.size_bytes() == pytest.approx(128 * MiB), bundle.size_bytes()
+
+        data_op_task = DataOpTask(0, streaming_gen, output_ready_callback=verify_output)
 
         bytes_read = 0
-        while not has_completed:
+        while not data_op_task.has_finished:
             ray.wait([streaming_gen], fetch_local=False)
-            bytes_read += data_op_task.on_data_ready(None)
+            nbytes_read = data_op_task.on_data_ready(None)
+            bytes_read += nbytes_read
 
-        assert bytes_read == pytest.approx(128 * MiB)
+        assert bytes_read == pytest.approx(256 * MiB)
 
     def test_on_data_ready_exception(self, ray_start_regular_shared):
-        @ray.remote
-        def map_task():
-            assert False, "Block generation failed"
-            yield
+        streaming_gen = create_stub_streaming_gen(
+            block_nbytes=[128 * MiB],
+            raise_exception=AssertionError("Block generation failed"),
+        )
 
         def verify_exception(exc: Optional[Exception]):
             assert isinstance(exc, AssertionError)
 
-        generator_backpressure_num_objects = (
-            ray.data.DataContext.get_current()._max_num_blocks_in_streaming_gen_buffer
-            * 2  # Multiply by two because we yield a metadata object for each block.
-        )
-        streaming_gen = map_task.options(
-            _generator_backpressure_num_objects=generator_backpressure_num_objects
-        ).remote()
-
         data_op_task = DataOpTask(
             0,
             streaming_gen,
             task_done_callback=verify_exception,
         )
 
         with pytest.raises(AssertionError, match="Block generation failed"):
-            while True:
+            while not data_op_task.has_finished:
                 ray.wait([streaming_gen], fetch_local=False)
                 data_op_task.on_data_ready(None)
 
+    @pytest.mark.parametrize(
+        "preempt_on", ["block_ready_callback", "metadata_ready_callback"]
+    )
+    def test_on_data_ready_with_preemption_during_call(
+        self,
+        preempt_on: Union[
+            Literal["block_ready_callback"], Literal["metadata_ready_callback"]
+        ],
+        ray_start_cluster_enabled,
+        ensure_block_metadata_stored_in_plasma,
+    ):
+        """Test that ``on_data_ready`` works when a node dies during its execution."""
+        # Shut down Ray in case it's already initialized.
+        ray.shutdown()
+
+        # Create a single-worker-node cluster with 1 logical CPU.
+        cluster = ray_start_cluster_enabled
+        head_node = cluster.add_node(num_cpus=0)  # noqa: F841
+        cluster.wait_for_nodes()
+        ray.init()
+
+        worker_node = cluster.add_node(num_cpus=1)
+        cluster.wait_for_nodes()
+
+        # Create a streaming generator that produces a single 128 MiB output block, and
+        # configure it so that it preempts the worker node in the specified callback.
+        streaming_gen = create_stub_streaming_gen(block_nbytes=[128 * MiB])
+
+        def remove_and_add_back_worker_node(_):
+            cluster.remove_node(worker_node)
+
+            new_worker_node = cluster.add_node(num_cpus=1)  # noqa: F841
+            cluster.wait_for_nodes()
+
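+        # `preempt_on` selects which ``DataOpTask`` callback triggers the preemption:
+        # either `block_ready_callback` or `metadata_ready_callback`.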
+        data_op_task = DataOpTask(
+            0, streaming_gen, **{preempt_on: remove_and_add_back_worker_node}
+        )
+
+        # Run the task to completion.
+        bytes_read = 0
+        while not data_op_task.has_finished:
+            ray.wait([streaming_gen], fetch_local=False)
+            bytes_read += data_op_task.on_data_ready(None)
+
+        # Ensure that we read the expected amount of data. Since the streaming
+        # generator yields a single 128 MiB block, we should read 128 MiB.
+        assert bytes_read == pytest.approx(128 * MiB)
+
+    def test_on_data_ready_with_preemption_after_wait(
+        self, ray_start_cluster_enabled, ensure_block_metadata_stored_in_plasma
+    ):
+        # Shut down Ray in case it's already initialized.
+        ray.shutdown()
+
+        # Create a single-worker-node cluster with 1 logical CPU.
+        cluster = ray_start_cluster_enabled
+        head_node = cluster.add_node(num_cpus=0)  # noqa: F841
+        cluster.wait_for_nodes()
+        ray.init()
+
+        worker_node = cluster.add_node(num_cpus=1)
+        cluster.wait_for_nodes()
+
+        # Create a streaming generator that produces a single 128 MiB output block.
+        streaming_gen = create_stub_streaming_gen(block_nbytes=[128 * MiB])
+        data_op_task = DataOpTask(0, streaming_gen)
+
+        # Wait for the block to be ready, then remove the worker node.
+        ray.wait([streaming_gen], fetch_local=False)
+        cluster.remove_node(worker_node)
+
+        # The block shouldn't be available anymore, so we shouldn't read any data.
+        bytes_read = data_op_task.on_data_ready(None)
+        assert bytes_read == 0
+
+        # Re-add the worker node, and run the task to completion.
+        new_worker_node = cluster.add_node(num_cpus=1)  # noqa: F841
+        cluster.wait_for_nodes()
+        while not data_op_task.has_finished:
+            ray.wait([streaming_gen], fetch_local=False)
+            bytes_read += data_op_task.on_data_ready(None)
+
+        # We should now be able to read the 128 MiB block.
+        assert bytes_read == pytest.approx(128 * MiB)
+
 
 if __name__ == "__main__":
     import sys