Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add the boundary param for sort in ray.data.Dataset #41269

Merged
merged 17 commits into from
Jan 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion python/ray/data/_internal/planner/sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,11 @@ def fn(
num_outputs = num_mappers

# Sample boundaries for sort key.
boundaries = SortTaskSpec.sample_boundaries(blocks, sort_key, num_outputs)
if not sort_key.boundaries:
boundaries = SortTaskSpec.sample_boundaries(blocks, sort_key, num_outputs)
else:
boundaries = [(b,) for b in sort_key.boundaries]
num_outputs = len(boundaries) + 1
_, ascending = sort_key.to_pandas_sort_args()
if not ascending:
boundaries.reverse()
Expand Down
20 changes: 19 additions & 1 deletion python/ray/data/_internal/sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def __init__(
self,
key: Optional[Union[str, List[str]]] = None,
descending: Union[bool, List[bool]] = False,
boundaries: Optional[list] = None,
):
if key is None:
key = []
Expand All @@ -64,6 +65,15 @@ def __init__(
raise ValueError("Sorting with mixed key orders not supported yet.")
self._columns = key
self._descending = descending
if boundaries:
for item in boundaries:
if not isinstance(item, (int, float)):
raise ValueError(
"The type of items in boundaries must be int or float."
)
boundaries = list(set(boundaries))
boundaries.sort()
self._boundaries = boundaries

def get_columns(self) -> List[str]:
return self._columns
Expand Down Expand Up @@ -94,6 +104,10 @@ def validate_schema(self, schema: Optional[Union[type, "pyarrow.lib.Schema"]]):
"schema '{}'.".format(column, schema)
)

@property
def boundaries(self) -> Optional[list]:
"""Return the partition boundaries stored on this sort key.

When a non-empty list was passed to the constructor, this is that list
de-duplicated and sorted in ascending order; otherwise it is the value
that was passed in unchanged (``None`` by default).
"""
return self._boundaries


class _SortOp(ShuffleOp):
@staticmethod
Expand Down Expand Up @@ -209,7 +223,11 @@ def sort_impl(
# Use same number of output partitions.
num_reducers = num_mappers
# TODO(swang): sample_boundaries could be fused with a previous stage.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comments should stay next to the code they describe. In this case, I think the TODO should be moved into the `if` block.

boundaries = sample_boundaries(blocks_list, sort_key, num_reducers, ctx)
if not sort_key.boundaries:
boundaries = sample_boundaries(blocks_list, sort_key, num_reducers, ctx)
else:
boundaries = [(b,) for b in sort_key.boundaries]
num_reducers = len(boundaries) + 1
_, ascending = sort_key.to_pandas_sort_args()
if not ascending:
boundaries.reverse()
Expand Down
37 changes: 32 additions & 5 deletions python/ray/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2294,6 +2294,7 @@ def sort(
self,
key: Union[str, List[str], None] = None,
descending: Union[bool, List[bool]] = False,
boundaries: List[Union[int, float]] = None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will it work for non-numeric columns?

Copy link
Contributor Author

@veryhannibal veryhannibal Nov 30, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, this feature does not currently support non-numeric columns. In our use case, when we encounter a non-numeric column, we first convert it to a numeric type.
For example, for a non-numeric column, we compute the hash of each value and take it modulo 3, so the column's values become 0, 1, or 2. If the `boundaries` parameter is then set to [1, 2], the rows with values 0, 1, and 2 are placed into three separate blocks, respectively.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That sounds good for now; could you just update the docstring to say that this only supports numeric columns right now?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, I have updated the docstring to state that the `boundaries` parameter currently supports only numeric columns.😁😁😁

) -> "Dataset":
"""Sort the dataset by the specified key column or key function.

Expand All @@ -2304,22 +2305,48 @@ def sort(

Examples:
>>> import ray
>>> ds = ray.data.range(100)
>>> ds.sort("id", descending=True).take(3)
[{'id': 99}, {'id': 98}, {'id': 97}]
>>> ds = ray.data.range(15)
>>> ds = ds.sort("id", descending=False, boundaries=[5, 10])
>>> for df in ray.get(ds.to_pandas_refs()):
... print(df)
id
0 0
1 1
2 2
3 3
4 4
id
0 5
1 6
2 7
3 8
4 9
id
0 10
1 11
2 12
3 13
4 14

Time complexity: O(dataset size * log(dataset size / parallelism))

Args:
key: The column or a list of columns to sort by.
descending: Whether to sort in descending order. Must be a boolean or a list
of booleans matching the number of the columns.
boundaries: The list of values based on which to repartition the dataset.
For example, if the input boundary is [10,20], rows with values less
than 10 will be divided into the first block, rows with values greater
than or equal to 10 and less than 20 will be divided into the
second block, and rows with values greater than or equal to 20
will be divided into the third block. If not provided, the
boundaries will be sampled from the input blocks. This feature
only supports numeric columns right now.

Returns:
A new, sorted :class:`Dataset`.
"""

sort_key = SortKey(key, descending)
sort_key = SortKey(key, descending, boundaries)
plan = self._plan.with_stage(SortStage(self, sort_key))

logical_plan = self._logical_plan
Expand Down
32 changes: 32 additions & 0 deletions python/ray/data/tests/test_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,37 @@
from ray.tests.conftest import * # noqa


@pytest.mark.parametrize(
    "descending,boundaries",
    [
        (True, list(range(100, 1000, 200))),
        (False, list(range(100, 1000, 200))),
        (True, [1, 998]),
        (False, [1, 998]),
        # Test float.
        (True, [501.5]),
        (False, [501.5]),
    ],
)
def test_sort_with_specified_boundaries(ray_start_regular, descending, boundaries):
    """Check that user-supplied boundaries determine the output blocks of sort.

    Sorts ``range(1000)`` on ``"id"`` with explicit ``boundaries`` and asserts
    that the materialized dataset has exactly ``len(boundaries) + 1`` blocks,
    each holding the expected contiguous run of ids (reversed, and in reverse
    block order, when ``descending`` is true).
    """
    import math

    num_items = 1000
    ds = ray.data.range(num_items)
    ds = ds.sort("id", descending, boundaries).materialize()

    items = range(num_items)
    # Ids strictly below a boundary ``b`` fall into the lower block, so the
    # slice index for ``b`` is the number of ids < b, i.e. ``ceil(b)``.
    # ``round`` only matches this by accident for values such as 501.5
    # (round-half-to-even) and is wrong for e.g. 500.5 or 501.2.
    split_points = [0] + sorted(math.ceil(b) for b in boundaries) + [num_items]
    expected_blocks = [
        items[start:stop] for start, stop in zip(split_points, split_points[1:])
    ]
    if descending:
        expected_blocks = [list(reversed(block)) for block in reversed(expected_blocks)]

    blocks = list(ds.iter_batches(batch_size=None))
    assert len(blocks) == len(expected_blocks)
    for block, expected_block in zip(blocks, expected_blocks):
        assert np.all(block["id"] == expected_block)


def test_sort_simple(ray_start_regular, use_push_based_shuffle):
num_items = 100
parallelism = 4
Expand All @@ -30,6 +61,7 @@ def test_sort_simple(ray_start_regular, use_push_based_shuffle):
)
# Make sure we have rows in each block.
assert len([n for n in ds.sort("item")._block_num_rows() if n > 0]) == parallelism

assert extract_values(
"item", ds.sort("item", descending=True).take(num_items)
) == list(reversed(range(num_items)))
Expand Down
Loading