63 changes: 61 additions & 2 deletions python/ray/data/_expression_evaluator.py
@@ -1,7 +1,7 @@
from __future__ import annotations

import operator
from typing import Any, Callable, Dict
from typing import Any, Callable, Dict, Union

import numpy as np
import pandas as pd
@@ -47,8 +47,67 @@ def _pa_is_in(left: Any, right: Any) -> Any:
Operation.NOT_IN: lambda left, right: ~left.is_in(right),
}


def _is_pa_string_type(t: pa.DataType) -> bool:
return pa.types.is_string(t) or pa.types.is_large_string(t)


def _is_pa_string_like(x: Union[pa.Array, pa.ChunkedArray]) -> bool:
t = x.type
if pa.types.is_dictionary(t):
t = t.value_type
return _is_pa_string_type(t)


def _pa_decode_dict_string_array(x: Union[pa.Array, pa.ChunkedArray]) -> Any:
"""Convert Arrow dictionary-encoded string arrays to regular string arrays.

Dictionary encoding stores strings as indices into a dictionary of unique values.
This function converts them back to regular string arrays for string operations.

Example:
# Input: pa.array(['a', 'b']).dictionary_encode()
# -- dictionary: ["a", "b"]
# -- indices: [0, 1]
        # Output: regular string array ["a", "b"]

    Args:
        x: The input array to convert.

    Returns:
        The converted string array.
"""
if pa.types.is_dictionary(x.type) and _is_pa_string_type(x.type.value_type):
return pc.cast(x, pa.string())
return x


def _to_pa_string_input(x: Any) -> Any:
    """Normalize an operand to an Arrow string scalar or array for concatenation."""
    if isinstance(x, str):
        return pa.scalar(x)
    elif isinstance(x, (pa.Array, pa.ChunkedArray)) and _is_pa_string_like(x):
        # Dictionary-encoded string columns are decoded to plain string arrays.
        return _pa_decode_dict_string_array(x)
    else:
        raise TypeError(
            f"Cannot use operand of type {type(x).__name__} in string concatenation."
        )


def _pa_add_or_concat(left: Any, right: Any) -> Any:
# If either side is string-like, perform string concatenation.
if (
isinstance(left, str)
or isinstance(right, str)
or (isinstance(left, (pa.Array, pa.ChunkedArray)) and _is_pa_string_like(left))
or (
isinstance(right, (pa.Array, pa.ChunkedArray)) and _is_pa_string_like(right)
)
):
left_input = _to_pa_string_input(left)
right_input = _to_pa_string_input(right)
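        # An empty separator makes binary_join_element_wise an element-wise concat.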
return pc.binary_join_element_wise(left_input, right_input, "")
return pc.add(left, right)


_ARROW_EXPR_OPS_MAP: Dict[Operation, Callable[..., Any]] = {
Operation.ADD: pc.add,
Operation.ADD: _pa_add_or_concat,
Operation.SUB: pc.subtract,
Operation.MUL: pc.multiply,
Operation.DIV: pc.divide,
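For reference, a minimal standalone sketch (not part of the diff) of the Arrow calls the new ADD path relies on; the names and suffix values below are made up purely for illustration:

import pyarrow as pa
import pyarrow.compute as pc

# Dictionary encoding stores strings as indices into a dictionary of unique values.
names = pa.array(["Alice", "Bob"]).dictionary_encode()

# Decode back to a plain string array, mirroring _pa_decode_dict_string_array.
decoded = pc.cast(names, pa.string())

# An empty separator turns binary_join_element_wise into element-wise concatenation.
result = pc.binary_join_element_wise(decoded, pa.scalar("_X"), "")
print(result.to_pylist())  # ['Alice_X', 'Bob_X']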
100 changes: 100 additions & 0 deletions python/ray/data/tests/test_map.py
@@ -2710,6 +2710,106 @@ def invalid_int_return(x: pa.Array) -> int:
assert "pandas.Series" in error_message and "numpy.ndarray" in error_message


@pytest.mark.skipif(
get_pyarrow_version() < parse_version("20.0.0"),
reason="with_column requires PyArrow >= 20.0.0",
)
@pytest.mark.parametrize(
"scenario",
[
pytest.param(
{
"data": [
{"name": "Alice"},
{"name": "Bob"},
{"name": "Charlie"},
],
"expr_factory": lambda: col("name") + "_X",
"column_name": "name_with_suffix",
"expected": ["Alice_X", "Bob_X", "Charlie_X"],
},
id="string_col_plus_python_literal_rhs",
),
pytest.param(
{
"data": [
{"name": "Alice"},
{"name": "Bob"},
{"name": "Charlie"},
],
"expr_factory": lambda: "_X" + col("name"),
"column_name": "name_with_prefix",
"expected": ["_XAlice", "_XBob", "_XCharlie"],
},
id="python_literal_lhs_plus_string_col",
),
pytest.param(
{
"data": [
{"first": "John", "last": "Doe"},
{"first": "Jane", "last": "Smith"},
],
"expr_factory": lambda: col("first") + col("last"),
"column_name": "full_name",
"expected": ["JohnDoe", "JaneSmith"],
},
id="string_col_plus_string_col",
),
pytest.param(
{
"arrow_table": pa.table(
{"name": pa.array(["Alice", "Bob"]).dictionary_encode()}
),
"expr_factory": lambda: col("name") + "_X",
"column_name": "name_with_suffix",
"expected": ["Alice_X", "Bob_X"],
},
id="dict_encoded_string_col_plus_literal_rhs",
),
pytest.param(
{
"data": [
{"name": "Alice"},
{"name": "Bob"},
],
"expr_factory": lambda: col("name") + lit("_X"),
"column_name": "name_with_suffix",
"expected": ["Alice_X", "Bob_X"],
},
id="string_col_plus_lit_literal_rhs",
),
],
)
def test_with_column_string_concat_combinations(
ray_start_regular_shared,
scenario,
):
if "arrow_table" in scenario:
ds = ray.data.from_arrow(scenario["arrow_table"])
else:
ds = ray.data.from_items(scenario["data"])

expr = scenario["expr_factory"]()
column_name = scenario["column_name"]

ds2 = ds.with_column(column_name, expr)
out = ds2.to_pandas()
assert out[column_name].tolist() == scenario["expected"]


@pytest.mark.skipif(
get_pyarrow_version() < parse_version("20.0.0"),
reason="with_column requires PyArrow >= 20.0.0",
)
def test_with_column_string_concat_type_mismatch_raises(
ray_start_regular_shared,
):
# int + string should raise a user-facing error
ds = ray.data.range(3)
with pytest.raises((RayTaskError, UserCodeException)):
ds.with_column("bad", col("id") + "_X").materialize()


@pytest.mark.skipif(
get_pyarrow_version() < parse_version("20.0.0"),
reason="with_column requires PyArrow >= 20.0.0",