Skip to content

Commit

Permalink
[data] add validation for shuffle arg (#47055)
Browse files Browse the repository at this point in the history
Add validation for the `shuffle` argument. Previously, an invalid value
(for example `True` or a misspelled `"file"`) was silently treated like the
default `None`, so no shuffling was performed and no error was reported.


Signed-off-by: Matthew Deng <matt@anyscale.com>
  • Loading branch information
matthewdeng authored Aug 13, 2024
1 parent 65c8952 commit e2e4076
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 0 deletions.
9 changes: 9 additions & 0 deletions python/ray/data/datasource/file_based_datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ def __init__(
"'file_extensions' field is set properly."
)

_validate_shuffle_arg(shuffle)
self._file_metadata_shuffler = None
if shuffle == "files":
self._file_metadata_shuffler = np.random.default_rng()
Expand Down Expand Up @@ -519,3 +520,11 @@ def _open_file_with_retry(
max_attempts=OPEN_FILE_MAX_ATTEMPTS,
max_backoff_s=OPEN_FILE_RETRY_MAX_BACKOFF_SECONDS,
)


def _validate_shuffle_arg(shuffle: Optional[str]) -> None:
if shuffle not in [None, "files"]:
raise ValueError(
f"Invalid value for 'shuffle': {shuffle}. "
"Valid values are None, 'files'."
)
10 changes: 10 additions & 0 deletions python/ray/data/read_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -755,6 +755,8 @@ def read_parquet(
:class:`~ray.data.Dataset` producing records read from the specified parquet
files.
"""
_validate_shuffle_arg(shuffle)

if meta_provider is None:
meta_provider = get_parquet_metadata_provider(override_num_blocks)
arrow_parquet_args = _resolve_parquet_args(
Expand Down Expand Up @@ -3157,3 +3159,11 @@ def _get_num_output_blocks(
elif override_num_blocks is not None:
parallelism = override_num_blocks
return parallelism


def _validate_shuffle_arg(shuffle: Optional[str]) -> None:
if shuffle not in [None, "files"]:
raise ValueError(
f"Invalid value for 'shuffle': {shuffle}. "
"Valid values are None, 'files'."
)
11 changes: 11 additions & 0 deletions python/ray/data/tests/test_file_based_datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,17 @@ def test_windows_path():
assert _is_local_windows_path("c:\\some\\where/mixed")


@pytest.mark.parametrize("shuffle", [True, False, "file"])
def test_invalid_shuffle_arg_raises_error(ray_start_regular_shared, shuffle):
    """Constructing a datasource with an unsupported ``shuffle`` must fail."""
    # Match on the message so the test only passes for the shuffle-validation
    # error, not for some other ValueError raised during construction.
    with pytest.raises(ValueError, match="Invalid value for 'shuffle'"):
        FileBasedDatasource("example://iris.csv", shuffle=shuffle)


@pytest.mark.parametrize("shuffle", [None, "files"])
def test_valid_shuffle_arg_does_not_raise_error(ray_start_regular_shared, shuffle):
    """Both supported ``shuffle`` values must be accepted by the constructor."""
    # The test passes as long as construction completes without raising.
    datasource_kwargs = {"shuffle": shuffle}
    FileBasedDatasource("example://iris.csv", **datasource_kwargs)


if __name__ == "__main__":
import sys

Expand Down
12 changes: 12 additions & 0 deletions python/ray/data/tests/test_parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -1179,6 +1179,18 @@ def test_write_num_rows_per_file(tmp_path, ray_start_regular_shared, num_rows_pe
assert len(table) == num_rows_per_file


@pytest.mark.parametrize("shuffle", [True, False, "file"])
def test_invalid_shuffle_arg_raises_error(ray_start_regular_shared, shuffle):
    """``read_parquet`` must reject an unsupported ``shuffle`` value."""
    # Match on the message so the test only passes for the shuffle-validation
    # error, not for some other ValueError raised while reading.
    with pytest.raises(ValueError, match="Invalid value for 'shuffle'"):
        ray.data.read_parquet("example://iris.parquet", shuffle=shuffle)


@pytest.mark.parametrize("shuffle", [None, "files"])
def test_valid_shuffle_arg_does_not_raise_error(ray_start_regular_shared, shuffle):
    """Both supported ``shuffle`` values must be accepted by ``read_parquet``."""
    # The test passes as long as the read call completes without raising.
    read_kwargs = {"shuffle": shuffle}
    ray.data.read_parquet("example://iris.parquet", **read_kwargs)


if __name__ == "__main__":
import sys

Expand Down

0 comments on commit e2e4076

Please sign in to comment.