diff --git a/python/ray/data/_expression_evaluator.py b/python/ray/data/_expression_evaluator.py
index b94b6541860c..6e177e909fa6 100644
--- a/python/ray/data/_expression_evaluator.py
+++ b/python/ray/data/_expression_evaluator.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 import operator
-from typing import Any, Callable, Dict
+from typing import Any, Callable, Dict, Union
 
 import numpy as np
 import pandas as pd
@@ -47,8 +47,67 @@ def _pa_is_in(left: Any, right: Any) -> Any:
     Operation.NOT_IN: lambda left, right: ~left.is_in(right),
 }
 
+
+def _is_pa_string_type(t: pa.DataType) -> bool:
+    return pa.types.is_string(t) or pa.types.is_large_string(t)
+
+
+def _is_pa_string_like(x: Union[pa.Array, pa.ChunkedArray]) -> bool:
+    t = x.type
+    if pa.types.is_dictionary(t):
+        t = t.value_type
+    return _is_pa_string_type(t)
+
+
+def _pa_decode_dict_string_array(x: Union[pa.Array, pa.ChunkedArray]) -> Any:
+    """Convert Arrow dictionary-encoded string arrays to regular string arrays.
+
+    Dictionary encoding stores strings as indices into a dictionary of unique values.
+    This function converts them back to regular string arrays for string operations.
+
+    Example:
+        # Input: pa.array(['a', 'b']).dictionary_encode()
+        # -- dictionary: ["a", "b"]
+        # -- indices: [0, 1]
+        # Output: regular string array ["a", "b"]
+    Args:
+        x: The input array to convert.
+    Returns:
+        The converted string array.
+    """
+    if pa.types.is_dictionary(x.type) and _is_pa_string_type(x.type.value_type):
+        return pc.cast(x, pa.string())
+    return x
+
+
+def _to_pa_string_input(x: Any) -> Any:
+    if isinstance(x, str):
+        return pa.scalar(x)
+    elif isinstance(x, (pa.Array, pa.ChunkedArray)) and _is_pa_string_like(x):
+        x = _pa_decode_dict_string_array(x)
+    else:
+        raise TypeError(f"Cannot concatenate value of type {type(x)} with a string")
+    return x
+
+
+def _pa_add_or_concat(left: Any, right: Any) -> Any:
+    # If either side is string-like, perform string concatenation.
+    if (
+        isinstance(left, str)
+        or isinstance(right, str)
+        or (isinstance(left, (pa.Array, pa.ChunkedArray)) and _is_pa_string_like(left))
+        or (
+            isinstance(right, (pa.Array, pa.ChunkedArray)) and _is_pa_string_like(right)
+        )
+    ):
+        left_input = _to_pa_string_input(left)
+        right_input = _to_pa_string_input(right)
+        return pc.binary_join_element_wise(left_input, right_input, "")
+    return pc.add(left, right)
+
+
 _ARROW_EXPR_OPS_MAP: Dict[Operation, Callable[..., Any]] = {
-    Operation.ADD: pc.add,
+    Operation.ADD: _pa_add_or_concat,
     Operation.SUB: pc.subtract,
     Operation.MUL: pc.multiply,
     Operation.DIV: pc.divide,
diff --git a/python/ray/data/tests/test_map.py b/python/ray/data/tests/test_map.py
index e90d027d242a..05e311aa4c5e 100644
--- a/python/ray/data/tests/test_map.py
+++ b/python/ray/data/tests/test_map.py
@@ -2710,6 +2710,106 @@ def invalid_int_return(x: pa.Array) -> int:
     assert "pandas.Series" in error_message and "numpy.ndarray" in error_message
 
 
+@pytest.mark.skipif(
+    get_pyarrow_version() < parse_version("20.0.0"),
+    reason="with_column requires PyArrow >= 20.0.0",
+)
+@pytest.mark.parametrize(
+    "scenario",
+    [
+        pytest.param(
+            {
+                "data": [
+                    {"name": "Alice"},
+                    {"name": "Bob"},
+                    {"name": "Charlie"},
+                ],
+                "expr_factory": lambda: col("name") + "_X",
+                "column_name": "name_with_suffix",
+                "expected": ["Alice_X", "Bob_X", "Charlie_X"],
+            },
+            id="string_col_plus_python_literal_rhs",
+        ),
+        pytest.param(
+            {
+                "data": [
+                    {"name": "Alice"},
+                    {"name": "Bob"},
+                    {"name": "Charlie"},
+                ],
+                "expr_factory": lambda: "_X" + col("name"),
+                "column_name": "name_with_prefix",
+                "expected": ["_XAlice", "_XBob", "_XCharlie"],
+            },
+            id="python_literal_lhs_plus_string_col",
+        ),
+        pytest.param(
+            {
+                "data": [
+                    {"first": "John", "last": "Doe"},
+                    {"first": "Jane", "last": "Smith"},
+                ],
+                "expr_factory": lambda: col("first") + col("last"),
+                "column_name": "full_name",
+                "expected": ["JohnDoe", "JaneSmith"],
+            },
+            id="string_col_plus_string_col",
+        ),
+        pytest.param(
+            {
+                "arrow_table": pa.table(
+                    {"name": pa.array(["Alice", "Bob"]).dictionary_encode()}
+                ),
+                "expr_factory": lambda: col("name") + "_X",
+                "column_name": "name_with_suffix",
+                "expected": ["Alice_X", "Bob_X"],
+            },
+            id="dict_encoded_string_col_plus_literal_rhs",
+        ),
+        pytest.param(
+            {
+                "data": [
+                    {"name": "Alice"},
+                    {"name": "Bob"},
+                ],
+                "expr_factory": lambda: col("name") + lit("_X"),
+                "column_name": "name_with_suffix",
+                "expected": ["Alice_X", "Bob_X"],
+            },
+            id="string_col_plus_lit_literal_rhs",
+        ),
+    ],
+)
+def test_with_column_string_concat_combinations(
+    ray_start_regular_shared,
+    scenario,
+):
+    if "arrow_table" in scenario:
+        ds = ray.data.from_arrow(scenario["arrow_table"])
+    else:
+        ds = ray.data.from_items(scenario["data"])
+
+    expr = scenario["expr_factory"]()
+    column_name = scenario["column_name"]
+
+    ds2 = ds.with_column(column_name, expr)
+    out = ds2.to_pandas()
+    assert out[column_name].tolist() == scenario["expected"]
+
+
+@pytest.mark.skipif(
+    get_pyarrow_version() < parse_version("20.0.0"),
+    reason="with_column requires PyArrow >= 20.0.0",
+)
+def test_with_column_string_concat_type_mismatch_raises(
+    ray_start_regular_shared,
+):
+    # int + string should raise a user-facing error
+    ds = ray.data.range(3)
+    with pytest.raises((RayTaskError, UserCodeException)):
+        ds.with_column("bad", col("id") + "_X").materialize()
+
+
 @pytest.mark.skipif(
     get_pyarrow_version() < parse_version("20.0.0"),
     reason="with_column requires PyArrow >= 20.0.0",