Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Data] Fix exception in async map #47110

Merged
merged 5 commits into from
Aug 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 15 additions & 8 deletions python/ray/data/_internal/planner/plan_udf_map_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,11 +320,16 @@ def transform_fn(
output_batch_queue = queue.Queue()

async def process_batch(batch: DataBatch):
output_batch_iterator = await fn(batch)
# As soon as results become available from the async generator,
# put them into the result queue so they can be yielded.
async for output_batch in output_batch_iterator:
output_batch_queue.put(output_batch)
try:
output_batch_iterator = await fn(batch)
# As soon as results become available from the async generator,
# put them into the result queue so they can be yielded.
async for output_batch in output_batch_iterator:
output_batch_queue.put(output_batch)
except Exception as e:
output_batch_queue.put(
e
) # Put the exception into the queue to signal an error

async def process_all_batches():
loop = ray.data._map_actor_context.udf_map_asyncio_loop
Expand All @@ -348,6 +353,8 @@ async def process_all_batches():
# from the async generator, corresponding to a
# single row from the input batch.
out_batch = output_batch_queue.get()
if isinstance(out_batch, Exception):
raise out_batch
_validate_batch_output(out_batch)
yield out_batch

Expand Down Expand Up @@ -448,7 +455,7 @@ def _create_map_transformer_for_row_based_map_op(

def generate_map_rows_fn(
target_max_block_size: int,
) -> (Callable[[Iterator[Block], TaskContext, UserDefinedFunction], Iterator[Block]]):
) -> Callable[[Iterator[Block], TaskContext, UserDefinedFunction], Iterator[Block]]:
"""Generate function to apply the UDF to each record of blocks."""
context = DataContext.get_current()

Expand All @@ -468,7 +475,7 @@ def fn(

def generate_flat_map_fn(
target_max_block_size: int,
) -> (Callable[[Iterator[Block], TaskContext, UserDefinedFunction], Iterator[Block]]):
) -> Callable[[Iterator[Block], TaskContext, UserDefinedFunction], Iterator[Block]]:
"""Generate function to apply the UDF to each record of blocks,
and then flatten results.
"""
Expand All @@ -491,7 +498,7 @@ def fn(

def generate_filter_fn(
target_max_block_size: int,
) -> (Callable[[Iterator[Block], TaskContext, UserDefinedFunction], Iterator[Block]]):
) -> Callable[[Iterator[Block], TaskContext, UserDefinedFunction], Iterator[Block]]:
"""Generate function to apply the UDF to each record of blocks,
and filter out records that do not satisfy the given predicate.
"""
Expand Down
23 changes: 23 additions & 0 deletions python/ray/data/tests/test_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -1096,6 +1096,29 @@ async def __call__(self, batch):
assert output == expected_output, (output, expected_output)


def test_map_batches_async_exception_propagation(shutdown_only):
ray.shutdown()
ray.init(num_cpus=2)

class MyUDF:
def __init__(self):
pass

async def __call__(self, batch):
# This will trigger an assertion error.
assert False
yield batch

ds = ray.data.range(20)
ds = ds.map_batches(MyUDF, concurrency=2)

with pytest.raises(ray.exceptions.RayTaskError) as exc_info:
ds.materialize()

assert "AssertionError" in str(exc_info.value)
assert "assert False" in str(exc_info.value)


if __name__ == "__main__":
import sys

Expand Down
Loading