Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 21 additions & 66 deletions python/ray/data/tests/test_download_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,9 @@ def test_download_expression_with_malformed_uris(self, tmp_path):

This tests that various malformed URIs are caught and return None
instead of crashing.

All of the URIs should be malformed in order to test the ZeroDivisionError
described in https://github.com/ray-project/ray/issues/58462.
"""
malformed_uris = [
f"local://{tmp_path}/nonexistent.txt", # File doesn't exist
Expand All @@ -324,78 +327,30 @@ def test_download_expression_with_malformed_uris(self, tmp_path):
for result in results:
assert result["bytes"] is None

def test_download_expression_all_size_estimations_fail(self):
"""Test download expression when all URI size estimations fail.

This tests the failed download does not cause division by zero error.
"""
# Create URIs that will fail size estimation (non-existent files)
# Using enough URIs to trigger size estimation sampling
invalid_uris = [
f"local:///nonexistent/path/file_{i}.txt"
for i in range(30) # More than INIT_SAMPLE_BATCH_SIZE (25)
]
def test_download_expression_mixed_valid_and_invalid_uris(self, tmp_path):
"""Test download expression when some but not all of the URIs are invalid."""
# Create one valid file
valid_file = tmp_path / "valid.txt"
valid_file.write_bytes(b"valid content")

table = pa.Table.from_arrays(
[pa.array(invalid_uris)],
names=["uri"],
# Create URIs: one valid and one non-existent file.
ds = ray.data.from_items(
[
{"uri": str(valid_file), "id": 0},
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure why would you change the format of the uri from f"local://{valid_file}" to str(valid_file)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's ok to work without local:// prefix in this case, but I assume it's good to follow the pattern of other tests in the same file

{"uri": str(tmp_path / "nonexistent.txt"), "id": 1},
]
)

ds = ray.data.from_arrow(table)
ds_with_downloads = ds.with_column("bytes", download("uri"))

# Should not crash with divide-by-zero error
# The PartitionActor should handle all failed size estimations gracefully
# and fall back to using the number of rows in the block as partition size
results = ds_with_downloads.take_all()

# All downloads should fail gracefully (return None)
assert len(results) == 30
for result in results:
assert result["bytes"] is None

def test_download_expression_mixed_valid_and_invalid_size_estimation(
self, tmp_path
):
"""Test download expression with mix of valid and invalid URIs for size estimation.

This tests that size estimation handles partial failures correctly.
"""
# Create some valid files
valid_files = []
for i in range(10):
file_path = tmp_path / f"valid_{i}.txt"
file_path.write_bytes(b"x" * 100) # 100 bytes each
valid_files.append(str(file_path))

# Mix valid and invalid URIs
mixed_uris = []
for i in range(30):
if i % 3 == 0 and i // 3 < len(valid_files):
# Every 3rd URI is valid (for first 10)
mixed_uris.append(f"local://{valid_files[i // 3]}")
else:
# Others are invalid
mixed_uris.append(f"local:///nonexistent/file_{i}.txt")
# Should not crash - failed downloads return None
results = sorted(ds_with_downloads.take_all(), key=lambda row: row["id"])
assert len(results) == 2

table = pa.Table.from_arrays(
[pa.array(mixed_uris)],
names=["uri"],
)
# First URI should succeed
assert results[0]["bytes"] == b"valid content"

ds = ray.data.from_arrow(table)
ds_with_downloads = ds.with_column("bytes", download("uri"))

# Should not crash - should handle mixed valid/invalid gracefully
results = ds_with_downloads.take_all()
assert len(results) == 30

# Verify valid URIs downloaded successfully
for i, result in enumerate(results):
if i % 3 == 0 and i // 3 < len(valid_files):
assert result["bytes"] == b"x" * 100
else:
assert result["bytes"] is None
# Second URI should fail gracefully (return None)
assert results[1]["bytes"] is None


class TestDownloadExpressionIntegration:
Expand Down