diff --git a/python/ray/data/_internal/planner/sort.py b/python/ray/data/_internal/planner/sort.py index d5958cb47619..808f445ff8a9 100644 --- a/python/ray/data/_internal/planner/sort.py +++ b/python/ray/data/_internal/planner/sort.py @@ -45,7 +45,11 @@ def fn( num_outputs = num_mappers # Sample boundaries for sort key. - boundaries = SortTaskSpec.sample_boundaries(blocks, sort_key, num_outputs) + if not sort_key.boundaries: + boundaries = SortTaskSpec.sample_boundaries(blocks, sort_key, num_outputs) + else: + boundaries = [(b,) for b in sort_key.boundaries] + num_outputs = len(boundaries) + 1 _, ascending = sort_key.to_pandas_sort_args() if not ascending: boundaries.reverse() diff --git a/python/ray/data/_internal/sort.py b/python/ray/data/_internal/sort.py index 25fec3bf80d2..7eaf859a41c6 100644 --- a/python/ray/data/_internal/sort.py +++ b/python/ray/data/_internal/sort.py @@ -44,6 +44,7 @@ def __init__( self, key: Optional[Union[str, List[str]]] = None, descending: Union[bool, List[bool]] = False, + boundaries: Optional[list] = None, ): if key is None: key = [] @@ -64,6 +65,15 @@ def __init__( raise ValueError("Sorting with mixed key orders not supported yet.") self._columns = key self._descending = descending + if boundaries: + for item in boundaries: + if not isinstance(item, (int, float)): + raise ValueError( + "The type of items in boundaries must be int or float." + ) + boundaries = list(set(boundaries)) + boundaries.sort() + self._boundaries = boundaries def get_columns(self) -> List[str]: return self._columns @@ -94,6 +104,10 @@ def validate_schema(self, schema: Optional[Union[type, "pyarrow.lib.Schema"]]): "schema '{}'.".format(column, schema) ) + @property + def boundaries(self): + return self._boundaries + class _SortOp(ShuffleOp): @staticmethod @@ -209,7 +223,11 @@ def sort_impl( # Use same number of output partitions. num_reducers = num_mappers # TODO(swang): sample_boundaries could be fused with a previous stage. - boundaries = sample_boundaries(blocks_list, sort_key, num_reducers, ctx) + if not sort_key.boundaries: + boundaries = sample_boundaries(blocks_list, sort_key, num_reducers, ctx) + else: + boundaries = [(b,) for b in sort_key.boundaries] + num_reducers = len(boundaries) + 1 _, ascending = sort_key.to_pandas_sort_args() if not ascending: boundaries.reverse() diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index ad52edd1ab41..5395e30e0b5d 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -2294,6 +2294,7 @@ def sort( self, key: Union[str, List[str], None] = None, descending: Union[bool, List[bool]] = False, + boundaries: List[Union[int, float]] = None, ) -> "Dataset": """Sort the dataset by the specified key column or key function. @@ -2304,9 +2305,28 @@ def sort( Examples: >>> import ray - >>> ds = ray.data.range(100) - >>> ds.sort("id", descending=True).take(3) - [{'id': 99}, {'id': 98}, {'id': 97}] + >>> ds = ray.data.range(15) + >>> ds = ds.sort("id", descending=False, boundaries=[5, 10]) + >>> for df in ray.get(ds.to_pandas_refs()): + ... print(df) + id + 0 0 + 1 1 + 2 2 + 3 3 + 4 4 + id + 0 5 + 1 6 + 2 7 + 3 8 + 4 9 + id + 0 10 + 1 11 + 2 12 + 3 13 + 4 14 Time complexity: O(dataset size * log(dataset size / parallelism)) @@ -2314,12 +2334,19 @@ def sort( key: The column or a list of columns to sort by. descending: Whether to sort in descending order. Must be a boolean or a list of booleans matching the number of the columns. + boundaries: The list of values based on which to repartition the dataset. + For example, if the input boundary is [10,20], rows with values less + than 10 will be divided into the first block, rows with values greater + than or equal to 10 and less than 20 will be divided into the + second block, and rows with values greater than or equal to 20 + will be divided into the third block. If not provided, the + boundaries will be sampled from the input blocks. This feature + only supports numeric columns right now. Returns: A new, sorted :class:`Dataset`. """ - - sort_key = SortKey(key, descending) + sort_key = SortKey(key, descending, boundaries) plan = self._plan.with_stage(SortStage(self, sort_key)) logical_plan = self._logical_plan diff --git a/python/ray/data/tests/test_sort.py b/python/ray/data/tests/test_sort.py index cfd6753cad82..36e47e0e0780 100644 --- a/python/ray/data/tests/test_sort.py +++ b/python/ray/data/tests/test_sort.py @@ -19,6 +19,37 @@ from ray.tests.conftest import * # noqa +@pytest.mark.parametrize( + "descending,boundaries", + [ + (True, list(range(100, 1000, 200))), + (False, list(range(100, 1000, 200))), + (True, [1, 998]), + (False, [1, 998]), + # Test float. + (True, [501.5]), + (False, [501.5]), + ], +) +def test_sort_with_specified_boundaries(ray_start_regular, descending, boundaries): + num_items = 1000 + ds = ray.data.range(num_items) + ds = ds.sort("id", descending, boundaries).materialize() + + items = range(num_items) + boundaries = [0] + sorted([round(b) for b in boundaries]) + [num_items] + expected_blocks = [ + items[boundaries[i] : boundaries[i + 1]] for i in range(len(boundaries) - 1) + ] + if descending: + expected_blocks = [list(reversed(block)) for block in reversed(expected_blocks)] + + blocks = list(ds.iter_batches(batch_size=None)) + assert len(blocks) == len(expected_blocks) + for block, expected_block in zip(blocks, expected_blocks): + assert np.all(block["id"] == expected_block) + + def test_sort_simple(ray_start_regular, use_push_based_shuffle): num_items = 100 parallelism = 4 @@ -30,6 +61,7 @@ def test_sort_simple(ray_start_regular, use_push_based_shuffle): ) # Make sure we have rows in each block. assert len([n for n in ds.sort("item")._block_num_rows() if n > 0]) == parallelism + assert extract_values( "item", ds.sort("item", descending=True).take(num_items) ) == list(reversed(range(num_items)))