diff --git a/python/ray/data/_internal/arrow_block.py b/python/ray/data/_internal/arrow_block.py
index 3ea83b9f04cc..f12f89d8cceb 100644
--- a/python/ray/data/_internal/arrow_block.py
+++ b/python/ray/data/_internal/arrow_block.py
@@ -32,7 +32,7 @@
 )
 from ray.data._internal.row import TableRow
 from ray.data._internal.table_block import TableBlockAccessor, TableBlockBuilder
-from ray.data._internal.util import find_partitions
+from ray.data._internal.util import NULL_SENTINEL, find_partitions
 from ray.data.block import (
     Block,
     BlockAccessor,
@@ -500,7 +500,6 @@ def sort_and_partition(
         table = sort(self._table, sort_key)
         if len(boundaries) == 0:
             return [table]
-
         return find_partitions(table, boundaries, sort_key)
 
     def combine(self, sort_key: "SortKey", aggs: Tuple["AggregateFn"]) -> Block:
@@ -634,6 +633,11 @@ def key_fn(r):
             else:
                 return (0,)
 
+        # Replace Nones with NULL_SENTINEL to ensure safe sorting.
+        def key_fn_with_null_sentinel(r):
+            values = key_fn(r)
+            return [NULL_SENTINEL if v is None else v for v in values]
+
         # Handle blocks of different types.
         blocks = TableBlockAccessor.normalize_block_types(blocks, "arrow")
 
@@ -642,7 +646,7 @@ def key_fn(r):
                 ArrowBlockAccessor(block).iter_rows(public_row_format=False)
                 for block in blocks
             ],
-            key=key_fn,
+            key=key_fn_with_null_sentinel,
         )
         next_row = None
         builder = ArrowBlockBuilder()
diff --git a/python/ray/data/_internal/planner/exchange/sort_task_spec.py b/python/ray/data/_internal/planner/exchange/sort_task_spec.py
index 299e8793774f..827c4a2c7a51 100644
--- a/python/ray/data/_internal/planner/exchange/sort_task_spec.py
+++ b/python/ray/data/_internal/planner/exchange/sort_task_spec.py
@@ -7,6 +7,7 @@
 from ray.data._internal.progress_bar import ProgressBar
 from ray.data._internal.remote_fn import cached_remote_fn
 from ray.data._internal.table_block import TableBlockAccessor
+from ray.data._internal.util import NULL_SENTINEL
 from ray.data.block import Block, BlockAccessor, BlockExecStats, BlockMetadata
 from ray.types import ObjectRef
 
@@ -23,7 +24,7 @@ def __init__(
         self,
         key: Optional[Union[str, List[str]]] = None,
         descending: Union[bool, List[bool]] = False,
-        boundaries: Optional[list] = None,
+        boundaries: Optional[List[T]] = None,
     ):
         if key is None:
             key = []
@@ -195,7 +196,23 @@ def sample_boundaries(
         samples_table = builder.build()
         samples_dict = BlockAccessor.for_block(samples_table).to_numpy(columns=columns)
         # This zip does the transposition from list of column values to list of tuples.
-        samples_list = sorted(zip(*samples_dict.values()))
+        samples_list = list(zip(*samples_dict.values()))
+
+        def is_na(x):
+            # Check if x is None or NaN. Type casting to np.array first to avoid
+            # isnan failing on strings and other types.
+            if x is None:
+                return True
+            x = np.asarray(x)
+            if np.issubdtype(x.dtype, np.number):
+                return np.isnan(x)
+            return False
+
+        def key_fn_with_nones(sample):
+            return tuple(NULL_SENTINEL if is_na(x) else x for x in sample)
+
+        # Sort the list, but Nones should be NULL_SENTINEL to ensure safe sorting.
+        samples_list = sorted(samples_list, key=key_fn_with_nones)
 
         # Each boundary corresponds to a quantile of the data.
         quantile_indices = [
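Why the sample_boundaries change is needed: the old sorted(zip(...)) raises TypeError as soon as one sampled boundary tuple contains None, because Python refuses ordering comparisons between None and concrete values. A minimal illustrative sketch of the failure and the key-function fix; the _Greatest class and SENTINEL below are hypothetical stand-ins with the same ordering behavior as the _NullSentinel added to util.py in the next file, not Ray code:

    # Reproduce the failure mode of the old code.
    samples = [(3,), (None,), (1,)]
    try:
        sorted(samples)
    except TypeError as e:
        print(e)  # '<' not supported between instances of 'NoneType' and 'int'

    # Fix: map None to a value that compares greater than everything,
    # mirroring key_fn_with_nones in the hunk above.
    class _Greatest:
        def __lt__(self, other):
            return False

        def __gt__(self, other):
            return True

    SENTINEL = _Greatest()

    def key_fn_with_nones(sample):
        return tuple(SENTINEL if x is None else x for x in sample)

    print(sorted(samples, key=key_fn_with_nones))  # [(1,), (3,), (None,)]

Because the reflected comparison 1 < SENTINEL resolves through _Greatest.__gt__, null-bearing tuples always sort last instead of raising.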
diff --git a/python/ray/data/_internal/util.py b/python/ray/data/_internal/util.py
index 9696074fe66d..5e8c921c3733 100644
--- a/python/ray/data/_internal/util.py
+++ b/python/ray/data/_internal/util.py
@@ -55,6 +55,28 @@
 _pyarrow_dataset: LazyModule = None
 
 
+class _NullSentinel:
+    """Sentinel value that sorts greater than any other value."""
+
+    def __eq__(self, other):
+        return isinstance(other, _NullSentinel)
+
+    def __lt__(self, other):
+        return False
+
+    def __le__(self, other):
+        return isinstance(other, _NullSentinel)
+
+    def __gt__(self, other):
+        return True
+
+    def __ge__(self, other):
+        return True
+
+
+NULL_SENTINEL = _NullSentinel()
+
+
 def _lazy_import_pyarrow_dataset() -> LazyModule:
     global _pyarrow_dataset
     if _pyarrow_dataset is None:
@@ -723,6 +745,16 @@ def find_partition_index(
         col_vals = table[col_name].to_numpy()[left:right]
         desired_val = desired[i]
 
+        # Handle null values - replace them with sentinel values
+        if desired_val is None:
+            desired_val = NULL_SENTINEL
+
+        # Replace None/NaN values in col_vals with sentinel
+        null_mask = col_vals == None  # noqa: E711
+        if null_mask.any():
+            col_vals = col_vals.copy()  # Make a copy to avoid modifying original
+            col_vals[null_mask] = NULL_SENTINEL
+
         prevleft = left
         if descending is True:
             left = prevleft + (
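How the sentinel interacts with partition lookup: object-dtype NumPy arrays defer element comparisons to Python, so once Nones are masked to NULL_SENTINEL, binary search over a sorted column stays well defined with nulls ordered last. A standalone sketch, assuming find_partition_index narrows partitions with searchsorted-style comparisons (the class body is copied from the patch; the example data is made up):

    import numpy as np

    class _NullSentinel:
        """Sentinel value that sorts greater than any other value."""

        def __eq__(self, other):
            return isinstance(other, _NullSentinel)

        def __lt__(self, other):
            return False

        def __le__(self, other):
            return isinstance(other, _NullSentinel)

        def __gt__(self, other):
            return True

        def __ge__(self, other):
            return True

    NULL_SENTINEL = _NullSentinel()

    # A sorted object-dtype column with nulls last, as the sort produces.
    col_vals = np.array([1, 3, None, None], dtype=object)
    null_mask = col_vals == None  # noqa: E711 -- elementwise on object dtype
    col_vals = col_vals.copy()  # avoid mutating the original block data
    col_vals[null_mask] = NULL_SENTINEL

    print(np.searchsorted(col_vals, 2))  # 1 -> lands between 1 and 3
    print(np.searchsorted(col_vals, NULL_SENTINEL))  # 2 -> before existing nulls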
diff --git a/python/ray/data/tests/test_all_to_all.py b/python/ray/data/tests/test_all_to_all.py
index a7b6ec823a9c..cf0cb8b2b2e7 100644
--- a/python/ray/data/tests/test_all_to_all.py
+++ b/python/ray/data/tests/test_all_to_all.py
@@ -123,6 +123,65 @@ def test_unique(ray_start_regular_shared):
     assert mock_validate.call_args_list[0].args[0].names == ["b"]
 
 
+@pytest.mark.parametrize("batch_format", ["pandas", "pyarrow"])
+def test_unique_with_nulls(ray_start_regular_shared, batch_format):
+    ds = ray.data.from_items([3, 2, 3, 1, 2, 3, None])
+    assert set(ds.unique("item")) == {1, 2, 3, None}
+    assert len(ds.unique("item")) == 4
+
+    ds = ray.data.from_items(
+        [
+            {"a": 1, "b": 1},
+            {"a": 1, "b": 2},
+            {"a": 1, "b": None},
+            {"a": None, "b": 3},
+            {"a": None, "b": 4},
+        ]
+    )
+    assert set(ds.unique("a")) == {1, None}
+    assert len(ds.unique("a")) == 2
+    assert set(ds.unique("b")) == {1, 2, 3, 4, None}
+    assert len(ds.unique("b")) == 5
+
+    # Check with 3 columns
+    df = pd.DataFrame(
+        {
+            "col1": [1, 2, None, 3, None, 3, 2],
+            "col2": [None, 2, 2, 3, None, 3, 2],
+            "col3": [1, None, 2, None, None, None, 2],
+        }
+    )
+    # df["col"].unique() works fine, as expected
+    ds2 = ray.data.from_pandas(df)
+    ds2 = ds2.map_batches(lambda x: x, batch_format=batch_format)
+    assert set(ds2.unique("col1")) == {1, 2, 3, None}
+    assert len(ds2.unique("col1")) == 4
+    assert set(ds2.unique("col2")) == {2, 3, None}
+    assert len(ds2.unique("col2")) == 3
+    assert set(ds2.unique("col3")) == {1, 2, None}
+    assert len(ds2.unique("col3")) == 3
+
+    # Check with 3 columns and different dtypes
+    df = pd.DataFrame(
+        {
+            "col1": [1, 2, None, 3, None, 3, 2],
+            "col2": [None, 2, 2, 3, None, 3, 2],
+            "col3": [1, None, 2, None, None, None, 2],
+        }
+    )
+    df["col1"] = df["col1"].astype("Int64")
+    df["col2"] = df["col2"].astype("Float64")
+    df["col3"] = df["col3"].astype("string")
+    ds3 = ray.data.from_pandas(df)
+    ds3 = ds3.map_batches(lambda x: x, batch_format=batch_format)
+    assert set(ds3.unique("col1")) == {1, 2, 3, None}
+    assert len(ds3.unique("col1")) == 4
+    assert set(ds3.unique("col2")) == {2, 3, None}
+    assert len(ds3.unique("col2")) == 3
+    assert set(ds3.unique("col3")) == {"1.0", "2.0", None}
+    assert len(ds3.unique("col3")) == 3
+
+
 def test_grouped_dataset_repr(ray_start_regular_shared):
     ds = ray.data.from_items([{"key": "spam"}, {"key": "ham"}, {"key": "spam"}])
     assert repr(ds.groupby("key")) == f"GroupedData(dataset={ds!r}, key='key')"
diff --git a/python/ray/data/tests/test_util.py b/python/ray/data/tests/test_util.py
index b66a9bc5804f..e5bfed6154d0 100644
--- a/python/ray/data/tests/test_util.py
+++ b/python/ray/data/tests/test_util.py
@@ -14,6 +14,7 @@
 )
 from ray.data._internal.remote_fn import _make_hashable, cached_remote_fn
 from ray.data._internal.util import (
+    NULL_SENTINEL,
     _check_pyarrow_version,
     _split_list,
     iterate_with_retry,
@@ -35,6 +36,21 @@ def foo():
     assert cpu_only_foo != gpu_only_foo
 
 
+def test_null_sentinel():
+    """Check that NULL_SENTINEL sorts greater than any other value."""
+    assert NULL_SENTINEL > 1000
+    assert NULL_SENTINEL > "abc"
+    assert NULL_SENTINEL == NULL_SENTINEL
+    assert NULL_SENTINEL != 1000
+    assert NULL_SENTINEL != "abc"
+    assert not NULL_SENTINEL < 1000
+    assert not NULL_SENTINEL < "abc"
+    assert not NULL_SENTINEL <= 1000
+    assert not NULL_SENTINEL <= "abc"
+    assert NULL_SENTINEL >= 1000
+    assert NULL_SENTINEL >= "abc"
+
+
 def test_make_hashable():
     valid_args = {
         "int": 0,
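End-to-end effect of the patch: sort and unique now tolerate null keys instead of raising during boundary sampling or block merging. A hedged usage sketch; the expected outputs follow from the sentinel ordering (nulls compare greater than all concrete values, so they land last in ascending order) and are illustrative, not captured from a real run:

    import ray

    ds = ray.data.from_items([3, None, 1, 2])

    # Previously, shuffle-based sorting could raise TypeError when boundary
    # sampling or block merging compared None against ints.
    print(ds.sort("item").take_all())
    # expected: [{'item': 1}, {'item': 2}, {'item': 3}, {'item': None}]

    print(set(ds.unique("item")))  # expected: {1, 2, 3, None}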