# Patch summary (descriptive text placed ahead of the first "diff --git"
# header; patch tools are expected to skip it -- confirm for your tooling).
#
# python/ray/data/_internal/planner/plan_udf_map_op.py
#   * process_batch: the `await fn(batch)` call and the `async for` loop that
#     feeds output_batch_queue are now wrapped in try/except. An `Exception`
#     raised there is put into output_batch_queue in place of a batch.
#     The handler is `except Exception`, so it catches `Exception`
#     subclasses only.
#   * Consumer of output_batch_queue: right after `output_batch_queue.get()`,
#     an item that is an `Exception` instance is re-raised before
#     `_validate_batch_output` is called; any other item is validated and
#     yielded exactly as before.
#   * generate_map_rows_fn, generate_flat_map_fn, generate_filter_fn: the
#     parentheses around the return annotation are removed; the annotated
#     type expression itself is unchanged.
#
# python/ray/data/tests/test_map.py
#   * Adds test_map_batches_async_exception_propagation. It defines a class
#     whose async `__call__` executes `assert False` before its `yield`,
#     runs it through `map_batches(MyUDF, concurrency=2)` over
#     `ray.data.range(20)`, and expects `ds.materialize()` to raise
#     `ray.exceptions.RayTaskError` whose string form contains both
#     "AssertionError" and "assert False".
#
# NOTE(review): this patch was recovered from a copy whose line breaks and
# leading indentation had been collapsed. Relative indentation below follows
# from Python nesting; the absolute indentation of context lines (12 spaces
# inside transform_fn, 4 spaces inside the test function) is an assumption --
# verify against the upstream files before applying.
diff --git a/python/ray/data/_internal/planner/plan_udf_map_op.py b/python/ray/data/_internal/planner/plan_udf_map_op.py
index d1be38f8ac52b..9d7c64c7c56cc 100644
--- a/python/ray/data/_internal/planner/plan_udf_map_op.py
+++ b/python/ray/data/_internal/planner/plan_udf_map_op.py
@@ -320,11 +320,16 @@ def transform_fn(
             output_batch_queue = queue.Queue()
 
             async def process_batch(batch: DataBatch):
-                output_batch_iterator = await fn(batch)
-                # As soon as results become available from the async generator,
-                # put them into the result queue so they can be yielded.
-                async for output_batch in output_batch_iterator:
-                    output_batch_queue.put(output_batch)
+                try:
+                    output_batch_iterator = await fn(batch)
+                    # As soon as results become available from the async generator,
+                    # put them into the result queue so they can be yielded.
+                    async for output_batch in output_batch_iterator:
+                        output_batch_queue.put(output_batch)
+                except Exception as e:
+                    output_batch_queue.put(
+                        e
+                    )  # Put the exception into the queue to signal an error
 
             async def process_all_batches():
                 loop = ray.data._map_actor_context.udf_map_asyncio_loop
@@ -348,6 +353,8 @@ async def process_all_batches():
                 # from the async generator, corresponding to a
                 # single row from the input batch.
                 out_batch = output_batch_queue.get()
+                if isinstance(out_batch, Exception):
+                    raise out_batch
                 _validate_batch_output(out_batch)
                 yield out_batch
 
@@ -448,7 +455,7 @@ def _create_map_transformer_for_row_based_map_op(
 
 def generate_map_rows_fn(
     target_max_block_size: int,
-) -> (Callable[[Iterator[Block], TaskContext, UserDefinedFunction], Iterator[Block]]):
+) -> Callable[[Iterator[Block], TaskContext, UserDefinedFunction], Iterator[Block]]:
     """Generate function to apply the UDF to each record of blocks."""
 
     context = DataContext.get_current()
@@ -468,7 +475,7 @@ def fn(
 
 def generate_flat_map_fn(
     target_max_block_size: int,
-) -> (Callable[[Iterator[Block], TaskContext, UserDefinedFunction], Iterator[Block]]):
+) -> Callable[[Iterator[Block], TaskContext, UserDefinedFunction], Iterator[Block]]:
     """Generate function to apply the UDF to each record of blocks,
     and then flatten results.
     """
@@ -491,7 +498,7 @@ def fn(
 
 def generate_filter_fn(
     target_max_block_size: int,
-) -> (Callable[[Iterator[Block], TaskContext, UserDefinedFunction], Iterator[Block]]):
+) -> Callable[[Iterator[Block], TaskContext, UserDefinedFunction], Iterator[Block]]:
     """Generate function to apply the UDF to each record of blocks,
     and filter out records that do not satisfy the given predicate.
     """
diff --git a/python/ray/data/tests/test_map.py b/python/ray/data/tests/test_map.py
index 18b36bfb8246a..00a0386ca2e59 100644
--- a/python/ray/data/tests/test_map.py
+++ b/python/ray/data/tests/test_map.py
@@ -1096,6 +1096,29 @@ async def __call__(self, batch):
     assert output == expected_output, (output, expected_output)
 
 
+def test_map_batches_async_exception_propagation(shutdown_only):
+    ray.shutdown()
+    ray.init(num_cpus=2)
+
+    class MyUDF:
+        def __init__(self):
+            pass
+
+        async def __call__(self, batch):
+            # This will trigger an assertion error.
+            assert False
+            yield batch
+
+    ds = ray.data.range(20)
+    ds = ds.map_batches(MyUDF, concurrency=2)
+
+    with pytest.raises(ray.exceptions.RayTaskError) as exc_info:
+        ds.materialize()
+
+    assert "AssertionError" in str(exc_info.value)
+    assert "assert False" in str(exc_info.value)
+
+
 if __name__ == "__main__":
     import sys
 