diff --git a/doc/source/data/api/expressions.rst b/doc/source/data/api/expressions.rst index 3e73a314b3bf..b5966c1636c2 100644 --- a/doc/source/data/api/expressions.rst +++ b/doc/source/data/api/expressions.rst @@ -35,4 +35,6 @@ instantiate them directly, but you may encounter them when working with expressi Expr ColumnExpr LiteralExpr - BinaryExpr \ No newline at end of file + BinaryExpr + UnaryExpr + UDFExpr \ No newline at end of file diff --git a/python/ray/data/_expression_evaluator.py b/python/ray/data/_expression_evaluator.py index 26642055aa2e..b94b6541860c 100644 --- a/python/ray/data/_expression_evaluator.py +++ b/python/ray/data/_expression_evaluator.py @@ -16,34 +16,56 @@ LiteralExpr, Operation, UDFExpr, + UnaryExpr, ) -_PANDAS_EXPR_OPS_MAP = { + +def _pa_is_in(left: Any, right: Any) -> Any: + if not isinstance(right, (pa.Array, pa.ChunkedArray)): + right = pa.array(right.as_py() if isinstance(right, pa.Scalar) else right) + return pc.is_in(left, right) + + +_PANDAS_EXPR_OPS_MAP: Dict[Operation, Callable[..., Any]] = { Operation.ADD: operator.add, Operation.SUB: operator.sub, Operation.MUL: operator.mul, Operation.DIV: operator.truediv, + Operation.FLOORDIV: operator.floordiv, Operation.GT: operator.gt, Operation.LT: operator.lt, Operation.GE: operator.ge, Operation.LE: operator.le, Operation.EQ: operator.eq, + Operation.NE: operator.ne, Operation.AND: operator.and_, Operation.OR: operator.or_, + Operation.NOT: operator.not_, + Operation.IS_NULL: pd.isna, + Operation.IS_NOT_NULL: pd.notna, + Operation.IN: lambda left, right: left.is_in(right), + Operation.NOT_IN: lambda left, right: ~left.is_in(right), } -_ARROW_EXPR_OPS_MAP = { +_ARROW_EXPR_OPS_MAP: Dict[Operation, Callable[..., Any]] = { Operation.ADD: pc.add, Operation.SUB: pc.subtract, Operation.MUL: pc.multiply, Operation.DIV: pc.divide, + Operation.FLOORDIV: lambda left, right: pc.floor(pc.divide(left, right)), Operation.GT: pc.greater, Operation.LT: pc.less, Operation.GE: pc.greater_equal, Operation.LE: pc.less_equal, Operation.EQ: pc.equal, - Operation.AND: pc.and_, - Operation.OR: pc.or_, + Operation.NE: pc.not_equal, + Operation.AND: pc.and_kleene, + Operation.OR: pc.or_kleene, + Operation.NOT: pc.invert, + Operation.IS_NULL: pc.is_null, + Operation.IS_NOT_NULL: pc.is_valid, + Operation.IN: _pa_is_in, + Operation.NOT_IN: lambda left, right: pc.invert(_pa_is_in(left, right)), } @@ -63,6 +85,10 @@ def _eval_expr_recursive( _eval_expr_recursive(expr.left, batch, ops), _eval_expr_recursive(expr.right, batch, ops), ) + if isinstance(expr, UnaryExpr): + # TODO: Use Visitor pattern here and store ops in shared state. + return ops[expr.op](_eval_expr_recursive(expr.operand, batch, ops)) + if isinstance(expr, UDFExpr): args = [_eval_expr_recursive(arg, batch, ops) for arg in expr.args] kwargs = { @@ -79,6 +105,7 @@ def _eval_expr_recursive( ) return result + raise TypeError(f"Unsupported expression node: {type(expr).__name__}") diff --git a/python/ray/data/_internal/pandas_block.py b/python/ray/data/_internal/pandas_block.py index 1c82b10cc5c5..ff08af3c0622 100644 --- a/python/ray/data/_internal/pandas_block.py +++ b/python/ray/data/_internal/pandas_block.py @@ -320,6 +320,8 @@ def rename_columns(self, columns_rename: Dict[str, str]) -> "pandas.DataFrame": def upsert_column( self, column_name: str, column_data: BlockColumn ) -> "pandas.DataFrame": + import pyarrow + if isinstance(column_data, (pyarrow.Array, pyarrow.ChunkedArray)): column_data = column_data.to_pandas() diff --git a/python/ray/data/_internal/planner/plan_expression/expression_evaluator.py b/python/ray/data/_internal/planner/plan_expression/expression_evaluator.py index 2a1adffca718..938c2a2d21fc 100644 --- a/python/ray/data/_internal/planner/plan_expression/expression_evaluator.py +++ b/python/ray/data/_internal/planner/plan_expression/expression_evaluator.py @@ -65,9 +65,9 @@ def visit_Compare(self, node: ast.Compare) -> ds.Expression: op = node.ops[0] if isinstance(op, ast.In): - return left_expr.isin(comparators[0]) + return left_expr.is_in(comparators[0]) elif isinstance(op, ast.NotIn): - return ~left_expr.isin(comparators[0]) + return ~left_expr.is_in(comparators[0]) elif isinstance(op, ast.Eq): return left_expr == comparators[0] elif isinstance(op, ast.NotEq): @@ -210,7 +210,7 @@ def visit_Call(self, node: ast.Call) -> ds.Expression: nan_is_null=nan_is_null ), "is_valid": lambda arg: arg.is_valid(), - "isin": lambda arg1, arg2: arg1.isin(arg2), + "is_in": lambda arg1, arg2: arg1.is_in(arg2), } if func_name in function_map: @@ -224,11 +224,11 @@ def visit_Call(self, node: ast.Call) -> ds.Expression: return function_map[func_name](args[0], args[1]) else: raise ValueError("is_null function requires one or two arguments.") - # Handle the "isin" function with exactly two arguments - elif func_name == "isin" and len(args) != 2: - raise ValueError("isin function requires two arguments.") + # Handle the "is_in" function with exactly two arguments + elif func_name == "is_in" and len(args) != 2: + raise ValueError("is_in function requires two arguments.") # Ensure the function has one argument (for functions like is_valid) - elif func_name != "isin" and len(args) != 1: + elif func_name != "is_in" and len(args) != 1: raise ValueError(f"{func_name} function requires exactly one argument.") # Call the corresponding function with the arguments return function_map[func_name](*args) diff --git a/python/ray/data/expressions.py b/python/ray/data/expressions.py index 94a799802068..2bdd358c0515 100644 --- a/python/ray/data/expressions.py +++ b/python/ray/data/expressions.py @@ -4,7 +4,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, field from enum import Enum -from typing import Any, Callable, Dict, List +from typing import Any, Callable, Dict, List, Union from ray.data.block import BatchColumn from ray.data.datatype import DataType @@ -23,26 +23,40 @@ class Operation(Enum): SUB: Subtraction operation (-) MUL: Multiplication operation (*) DIV: Division operation (/) + FLOORDIV: Floor division operation (//) GT: Greater than comparison (>) LT: Less than comparison (<) GE: Greater than or equal comparison (>=) LE: Less than or equal comparison (<=) EQ: Equality comparison (==) + NE: Not equal comparison (!=) AND: Logical AND operation (&) OR: Logical OR operation (|) + NOT: Logical NOT operation (~) + IS_NULL: Check if value is null + IS_NOT_NULL: Check if value is not null + IN: Check if value is in a list + NOT_IN: Check if value is not in a list """ ADD = "add" SUB = "sub" MUL = "mul" DIV = "div" + FLOORDIV = "floordiv" GT = "gt" LT = "lt" GE = "ge" LE = "le" EQ = "eq" + NE = "ne" AND = "and" OR = "or" + NOT = "not" + IS_NULL = "is_null" + IS_NOT_NULL = "is_not_null" + IN = "in" + NOT_IN = "not_in" @DeveloperAPI(stability="alpha") @@ -127,6 +141,14 @@ def __rtruediv__(self, other: Any) -> "Expr": """Reverse division operator (for literal / expr).""" return LiteralExpr(other)._bin(self, Operation.DIV) + def __floordiv__(self, other: Any) -> "Expr": + """Floor division operator (//).""" + return self._bin(other, Operation.FLOORDIV) + + def __rfloordiv__(self, other: Any) -> "Expr": + """Reverse floor division operator (for literal // expr).""" + return LiteralExpr(other)._bin(self, Operation.FLOORDIV) + # comparison def __gt__(self, other: Any) -> "Expr": """Greater than operator (>).""" @@ -148,6 +170,10 @@ def __eq__(self, other: Any) -> "Expr": """Equality operator (==).""" return self._bin(other, Operation.EQ) + def __ne__(self, other: Any) -> "Expr": + """Not equal operator (!=).""" + return self._bin(other, Operation.NE) + # boolean def __and__(self, other: Any) -> "Expr": """Logical AND operator (&).""" @@ -157,6 +183,31 @@ def __or__(self, other: Any) -> "Expr": """Logical OR operator (|).""" return self._bin(other, Operation.OR) + def __invert__(self) -> "Expr": + """Logical NOT operator (~).""" + return UnaryExpr(Operation.NOT, self) + + # predicate methods + def is_null(self) -> "Expr": + """Check if the expression value is null.""" + return UnaryExpr(Operation.IS_NULL, self) + + def is_not_null(self) -> "Expr": + """Check if the expression value is not null.""" + return UnaryExpr(Operation.IS_NOT_NULL, self) + + def is_in(self, values: Union[List[Any], "Expr"]) -> "Expr": + """Check if the expression value is in a list of values.""" + if not isinstance(values, Expr): + values = LiteralExpr(values) + return self._bin(values, Operation.IN) + + def not_in(self, values: Union[List[Any], "Expr"]) -> "Expr": + """Check if the expression value is not in a list of values.""" + if not isinstance(values, Expr): + values = LiteralExpr(values) + return self._bin(values, Operation.NOT_IN) + @DeveloperAPI(stability="alpha") @dataclass(frozen=True, eq=False) @@ -257,6 +308,39 @@ def structurally_equals(self, other: Any) -> bool: ) +@DeveloperAPI(stability="alpha") +@dataclass(frozen=True, eq=False) +class UnaryExpr(Expr): + """Expression that represents a unary operation on a single expression. + + This expression type represents an operation with one operand. + Common unary operations include logical NOT, IS NULL, IS NOT NULL, etc. + + Args: + op: The operation to perform (from Operation enum) + operand: The operand expression + + Example: + >>> from ray.data.expressions import col + >>> # Check if a column is null + >>> expr = col("age").is_null() # Creates UnaryExpr(IS_NULL, col("age")) + >>> # Logical not + >>> expr = ~(col("active")) # Creates UnaryExpr(NOT, col("active")) + """ + + op: Operation + operand: Expr + + data_type: DataType = field(init=False) + + def structurally_equals(self, other: Any) -> bool: + return ( + isinstance(other, UnaryExpr) + and self.op is other.op + and self.operand.structurally_equals(other.operand) + ) + + @DeveloperAPI(stability="alpha") @dataclass(frozen=True, eq=False) class UDFExpr(Expr): @@ -517,6 +601,7 @@ def download(uri_column_name: str) -> DownloadExpr: "ColumnExpr", "LiteralExpr", "BinaryExpr", + "UnaryExpr", "UDFExpr", "udf", "DownloadExpr", diff --git a/python/ray/data/tests/test_expressions.py b/python/ray/data/tests/test_expressions.py index ac4783bbb8a5..815ab4465352 100644 --- a/python/ray/data/tests/test_expressions.py +++ b/python/ray/data/tests/test_expressions.py @@ -1,6 +1,13 @@ import pytest -from ray.data.expressions import Expr, col, lit +from ray.data.expressions import ( + BinaryExpr, + Expr, + Operation, + UnaryExpr, + col, + lit, +) # Tuples of (expr1, expr2, expected_result) STRUCTURAL_EQUALITY_TEST_CASES = [ @@ -58,6 +65,148 @@ def test_operator_eq_is_not_structural_eq(): assert struct_eq_result is True +class TestUnaryExpressions: + """Test unary expression functionality.""" + + @pytest.mark.parametrize( + "expr, expected_op", + [ + (col("age").is_null(), Operation.IS_NULL), + (col("name").is_not_null(), Operation.IS_NOT_NULL), + (~col("active"), Operation.NOT), + ], + ids=["is_null", "is_not_null", "not"], + ) + def test_unary_operations(self, expr, expected_op): + """Test that unary operations create correct UnaryExpr.""" + assert isinstance(expr, UnaryExpr) + assert expr.op == expected_op + assert isinstance(expr.operand, Expr) + + def test_unary_structural_equality(self): + """Test structural equality for unary expressions.""" + # Same expressions should be equal + assert col("age").is_null().structurally_equals(col("age").is_null()) + assert ( + col("active").is_not_null().structurally_equals(col("active").is_not_null()) + ) + assert (~col("flag")).structurally_equals(~col("flag")) + + # Different operations should not be equal + assert not col("age").is_null().structurally_equals(col("age").is_not_null()) + + # Different operands should not be equal + assert not col("age").is_null().structurally_equals(col("name").is_null()) + + +class TestBinaryExpressions: + """Test enhanced binary expression functionality.""" + + @pytest.mark.parametrize( + "expr, expected_op", + [ + (col("age") != lit(25), Operation.NE), + (col("status").is_in(["active", "pending"]), Operation.IN), + (col("status").not_in(["inactive", "deleted"]), Operation.NOT_IN), + (col("a").is_in(col("b")), Operation.IN), + ], + ids=["not_equal", "is_in", "not_in", "is_in_amongst_cols"], + ) + def test_new_binary_operations(self, expr, expected_op): + """Test new binary operations.""" + assert isinstance(expr, BinaryExpr) + assert expr.op == expected_op + + def test_is_in_with_list(self): + """Test is_in with list of values.""" + expr = col("status").is_in(["active", "pending", "completed"]) + assert isinstance(expr, BinaryExpr) + assert expr.op == Operation.IN + # The right operand should be a LiteralExpr containing the list + assert expr.right.value == ["active", "pending", "completed"] + + def test_is_in_with_expr(self): + """Test is_in with expression.""" + values_expr = lit(["a", "b", "c"]) + expr = col("category").is_in(values_expr) + assert isinstance(expr, BinaryExpr) + assert expr.op == Operation.IN + assert expr.right == values_expr + + def test_is_in_amongst_cols(self): + """Test is_in with expression.""" + expr = col("a").is_in(col("b")) + assert isinstance(expr, BinaryExpr) + assert expr.op == Operation.IN + assert expr.right == col("b") + + +class TestBooleanExpressions: + """Test boolean expression functionality.""" + + @pytest.mark.parametrize( + "condition", + [ + col("age") > lit(18), + col("status") == lit("active"), + col("name").is_not_null(), + (col("age") >= lit(21)) & (col("country") == lit("USA")), + ], + ids=["simple_gt", "simple_eq", "is_not_null", "complex_and"], + ) + def test_boolean_expressions_directly(self, condition): + """Test that boolean expressions work directly.""" + assert isinstance(condition, Expr) + # Verify the expression structure based on type + if condition.op in [Operation.GT, Operation.EQ]: + assert isinstance(condition, BinaryExpr) + elif condition.op == Operation.IS_NOT_NULL: + assert isinstance(condition, UnaryExpr) + elif condition.op == Operation.AND: + assert isinstance(condition, BinaryExpr) + + def test_boolean_combination(self): + """Test combining boolean expressions with logical operators.""" + expr1 = col("age") > 18 + expr2 = col("status") == "active" + + # Test AND combination + combined_and = expr1 & expr2 + assert isinstance(combined_and, BinaryExpr) + assert combined_and.op == Operation.AND + + # Test OR combination + combined_or = expr1 | expr2 + assert isinstance(combined_or, BinaryExpr) + assert combined_or.op == Operation.OR + + # Test NOT operation + negated = ~expr1 + assert isinstance(negated, UnaryExpr) + assert negated.op == Operation.NOT + + def test_boolean_structural_equality(self): + """Test structural equality for boolean expressions.""" + expr1 = col("age") > 18 + expr2 = col("age") > 18 + expr3 = col("age") > 21 + + assert expr1.structurally_equals(expr2) + assert not expr1.structurally_equals(expr3) + + def test_complex_boolean_expressions(self): + """Test complex boolean expressions work correctly.""" + # Complex boolean expression + complex_expr = (col("age") >= 21) & (col("country") == "USA") + assert isinstance(complex_expr, BinaryExpr) + assert complex_expr.op == Operation.AND + + # Even more complex with OR and NOT + very_complex = ((col("age") > 21) | (col("status") == "VIP")) & ~col("banned") + assert isinstance(very_complex, BinaryExpr) + assert very_complex.op == Operation.AND + + if __name__ == "__main__": import sys diff --git a/python/ray/data/tests/test_map.py b/python/ray/data/tests/test_map.py index 439d8d991170..e7ae680ba95e 100644 --- a/python/ray/data/tests/test_map.py +++ b/python/ray/data/tests/test_map.py @@ -2711,6 +2711,472 @@ def invalid_int_return(x: pa.Array) -> int: assert "pandas.Series" in error_message and "numpy.ndarray" in error_message +@pytest.mark.skipif( + get_pyarrow_version() < parse_version("20.0.0"), + reason="with_column requires PyArrow >= 20.0.0", +) +@pytest.mark.parametrize( + "expression, expected_column_data, test_description", + [ + # Floor division operations + pytest.param( + col("id") // 2, + [0, 0, 1, 1, 2], # [0//2, 1//2, 2//2, 3//2, 4//2] + "floor_division_by_literal", + ), + pytest.param( + lit(10) // (col("id") + 2), + [5, 3, 2, 2, 1], # [10//(0+2), 10//(1+2), 10//(2+2), 10//(3+2), 10//(4+2)] + "literal_floor_division_by_expression", + ), + # Not equal operations + pytest.param( + col("id") != 2, + [True, True, False, True, True], # [0!=2, 1!=2, 2!=2, 3!=2, 4!=2] + "not_equal_operation", + ), + # Null checking operations + pytest.param( + col("id").is_null(), + [False, False, False, False, False], # None of the values are null + "is_null_operation", + ), + pytest.param( + col("id").is_not_null(), + [True, True, True, True, True], # All values are not null + "is_not_null_operation", + ), + # Logical NOT operations + pytest.param( + ~(col("id") == 2), + [True, True, False, True, True], # ~[0==2, 1==2, 2==2, 3==2, 4==2] + "logical_not_operation", + ), + ], +) +def test_with_column_floor_division_and_logical_operations( + ray_start_regular_shared, + expression, + expected_column_data, + test_description, +): + """Test floor division, not equal, null checks, and logical NOT operations with with_column.""" + ds = ray.data.range(5) + result_ds = ds.with_column("result", expression) + + # Convert to pandas and assert on the whole dataframe + result_df = result_ds.to_pandas() + expected_df = pd.DataFrame({"id": [0, 1, 2, 3, 4], "result": expected_column_data}) + + pd.testing.assert_frame_equal(result_df, expected_df, check_dtype=False) + + +@pytest.mark.skipif( + get_pyarrow_version() < parse_version("20.0.0"), + reason="with_column requires PyArrow >= 20.0.0", +) +@pytest.mark.parametrize( + "test_data, expression, expected_results, test_description", + [ + # Test with null values + pytest.param( + [{"value": 1}, {"value": None}, {"value": 3}], + col("value").is_null(), + [False, True, False], + "is_null_with_actual_nulls", + ), + pytest.param( + [{"value": 1}, {"value": None}, {"value": 3}], + col("value").is_not_null(), + [True, False, True], + "is_not_null_with_actual_nulls", + ), + # Test is_in operations + pytest.param( + [{"value": 1}, {"value": 2}, {"value": 3}], + col("value").is_in([1, 3]), + [True, False, True], + "isin_operation", + ), + pytest.param( + [{"value": 1}, {"value": 2}, {"value": 3}], + col("value").not_in([1, 3]), + [False, True, False], + "not_in_operation", + ), + # Test string operations + pytest.param( + [{"name": "Alice"}, {"name": "Bob"}, {"name": "Charlie"}], + col("name") == "Bob", + [False, True, False], + "string_equality", + ), + pytest.param( + [{"name": "Alice"}, {"name": "Bob"}, {"name": "Charlie"}], + col("name") != "Bob", + [True, False, True], + "string_not_equal", + ), + # Filter with string operations - accept engine's null propagation + pytest.param( + [ + {"name": "included"}, + {"name": "excluded"}, + {"name": None}, + ], + col("name").is_not_null() & (col("name") != "excluded"), + [True, False, False], + "string_filter", + ), + ], +) +def test_with_column_null_checks_and_membership_operations( + ray_start_regular_shared, + test_data, + expression, + expected_results, + test_description, + target_max_block_size_infinite_or_default, +): + """Test null checking, is_in/not_in membership operations, and string comparisons with with_column.""" + ds = ray.data.from_items(test_data) + result_ds = ds.with_column("result", expression) + + # Convert to pandas and assert on the whole dataframe + result_df = result_ds.to_pandas() + + # Create expected dataframe from test data + expected_data = {} + for key in test_data[0].keys(): + expected_data[key] = [row[key] for row in test_data] + expected_data["result"] = expected_results + + expected_df = pd.DataFrame(expected_data) + + pd.testing.assert_frame_equal(result_df, expected_df, check_dtype=False) + + +@pytest.mark.skipif( + get_pyarrow_version() < parse_version("20.0.0"), + reason="with_column requires PyArrow >= 20.0.0", +) +@pytest.mark.parametrize( + "expression_factory, expected_results, test_description", + [ + # Complex boolean expressions + pytest.param( + lambda: (col("age") > 18) & (col("country") == "USA"), + [ + True, + False, + False, + ], # [(25>18)&("USA"=="USA"), (17>18)&("Canada"=="USA"), (30>18)&("UK"=="USA")] + "complex_and_expression", + ), + pytest.param( + lambda: (col("age") < 18) | (col("country") == "USA"), + [ + True, + True, + False, + ], # [(25<18)|("USA"=="USA"), (17<18)|("Canada"=="USA"), (30<18)|("UK"=="USA")] + "complex_or_expression", + ), + pytest.param( + lambda: ~((col("age") < 25) & (col("country") != "USA")), + [ + True, + False, + True, + ], # ~[(25<25)&("USA"!="USA"), (17<25)&("Canada"!="USA"), (30<25)&("UK"!="USA")] + "complex_not_expression", + ), + # Age group calculation (common use case) + pytest.param( + lambda: col("age") // 10 * 10, + [20, 10, 30], # [25//10*10, 17//10*10, 30//10*10] + "age_group_calculation", + ), + # Eligibility flags + pytest.param( + lambda: (col("age") >= 21) + & (col("score") >= 10) + & col("active").is_not_null() + & (col("active") == lit(True)), + [ + True, + False, + False, + ], + "eligibility_flag", + ), + ], +) +def test_with_column_complex_boolean_expressions( + ray_start_regular_shared, + expression_factory, + expected_results, + test_description, + target_max_block_size_infinite_or_default, +): + """Test complex boolean expressions with AND, OR, NOT operations commonly used for filtering and flagging.""" + test_data = [ + {"age": 25, "country": "USA", "active": True, "score": 20}, + {"age": 17, "country": "Canada", "active": False, "score": 10}, + {"age": 30, "country": "UK", "active": None, "score": 20}, + ] + + ds = ray.data.from_items(test_data) + expression = expression_factory() + result_ds = ds.with_column("result", expression) + + # Convert to pandas and assert on the whole dataframe + result_df = result_ds.to_pandas() + expected_df = pd.DataFrame( + { + "age": [25, 17, 30], + "country": ["USA", "Canada", "UK"], + "active": [True, False, None], + "score": [20, 10, 20], + "result": expected_results, + } + ) + + pd.testing.assert_frame_equal(result_df, expected_df, check_dtype=False) + + +@pytest.mark.skipif( + get_pyarrow_version() < parse_version("20.0.0"), + reason="with_column requires PyArrow >= 20.0.0", +) +def test_with_column_chained_expression_operations( + ray_start_regular_shared, target_max_block_size_infinite_or_default +): + """Test chaining multiple expression operations together in a data transformation pipeline.""" + test_data = [ + {"age": 25, "salary": 50000, "active": True, "score": 20}, + {"age": 17, "salary": 0, "active": False, "score": 10}, + {"age": 35, "salary": 75000, "active": None, "score": 20}, + ] + + ds = ray.data.from_items(test_data) + + # Chain multiple operations + result_ds = ( + ds.with_column("is_adult", col("age") >= 18) + .with_column("age_group", (col("age") // 10) * 10) + .with_column("has_salary", col("salary") != 0) + .with_column( + "is_active_adult", (col("age") >= 18) & col("active").is_not_null() + ) + .with_column("salary_tier", (col("salary") // 25000) * 25000) + .with_column("score_tier", (col("score") // 20) * 20) + ) + + # Convert to pandas and assert on the whole dataframe + result_df = result_ds.to_pandas() + expected_df = pd.DataFrame( + { + "age": [25, 17, 35], + "salary": [50000, 0, 75000], + "active": [True, False, None], + "score": [20, 10, 20], # Add the missing score column + "is_adult": [True, False, True], + "age_group": [20, 10, 30], # age // 10 * 10 + "has_salary": [True, False, True], # salary != 0 + "is_active_adult": [ + True, + False, + False, + ], # (age >= 18) & (active is not null) + "salary_tier": [50000, 0, 75000], # salary // 25000 * 25000 + "score_tier": [20, 0, 20], # score // 20 * 20 + } + ) + + pd.testing.assert_frame_equal(result_df, expected_df, check_dtype=False) + + +@pytest.mark.skipif( + get_pyarrow_version() < parse_version("20.0.0"), + reason="with_column requires PyArrow >= 20.0.0", +) +@pytest.mark.parametrize( + "filter_expr, test_data, expected_flags, test_description", + [ + # Simple filter expressions + pytest.param( + col("age") >= 21, + [ + {"age": 20, "name": "Alice"}, + {"age": 21, "name": "Bob"}, + {"age": 25, "name": "Charlie"}, + ], + [False, True, True], + "age_filter", + ), + pytest.param( + col("score") > 50, + [ + {"score": 30, "status": "fail"}, + {"score": 50, "status": "pass"}, + {"score": 70, "status": "pass"}, + ], + [False, False, True], + "score_filter", + ), + # Complex filter with multiple conditions + pytest.param( + (col("age") >= 18) & col("active"), + [ + {"age": 17, "active": True}, + {"age": 18, "active": False}, + {"age": 25, "active": True}, + ], + [False, False, True], + "complex_and_filter", + ), + pytest.param( + (col("status") == "approved") | (col("priority") == "high"), + [ + {"status": "pending", "priority": "low"}, + {"status": "approved", "priority": "low"}, + {"status": "pending", "priority": "high"}, + ], + [False, True, True], + "complex_or_filter", + ), + # Filter with null handling + pytest.param( + col("value").is_not_null() & (col("value") > 0), + [ + {"value": None}, + {"value": -5}, + {"value": 10}, + ], + [ + False, + False, + True, + ], + "null_aware_filter", + ), + # Filter with string operations - reorder to check null first + pytest.param( + col("name").is_not_null() & (col("name") != "excluded"), + [ + {"name": "included"}, + {"name": "excluded"}, + {"name": None}, + ], + [True, False, False], + "string_filter", + ), + # Filter with membership operations + pytest.param( + col("category").is_in(["A", "B"]), + [ + {"category": "A"}, + {"category": "B"}, + {"category": "C"}, + {"category": "D"}, + ], + [True, True, False, False], + "membership_filter", + ), + # Nested filter expressions + pytest.param( + (col("score") >= 50) & (col("grade") != "F"), + [ + {"score": 45, "grade": "F"}, + {"score": 55, "grade": "D"}, + {"score": 75, "grade": "B"}, + {"score": 30, "grade": "F"}, + ], + [False, True, True, False], + "nested_filters", + ), + ], +) +def test_with_column_filter_expressions( + ray_start_regular_shared, + filter_expr, + test_data, + expected_flags, + test_description, +): + """Test filter() expression functionality with with_column for creating boolean flag columns.""" + ds = ray.data.from_items(test_data) + result_ds = ds.with_column("is_filtered", filter_expr) + + # Convert to pandas and verify the filter results + result_df = result_ds.to_pandas() + + # Build expected dataframe + expected_df = pd.DataFrame(test_data) + expected_df["is_filtered"] = expected_flags + + pd.testing.assert_frame_equal(result_df, expected_df, check_dtype=False) + + +@pytest.mark.skipif( + get_pyarrow_version() < parse_version("20.0.0"), + reason="with_column requires PyArrow >= 20.0.0", +) +def test_with_column_filter_in_pipeline(ray_start_regular_shared): + """Test filter() expressions used in a data processing pipeline with multiple transformations.""" + # Create test data for a sales analysis pipeline + test_data = [ + {"product": "A", "quantity": 10, "price": 100, "region": "North"}, + {"product": "B", "quantity": 5, "price": 200, "region": "South"}, + {"product": "C", "quantity": 20, "price": 50, "region": "North"}, + {"product": "D", "quantity": 15, "price": 75, "region": "East"}, + {"product": "E", "quantity": 3, "price": 300, "region": "West"}, + ] + + ds = ray.data.from_items(test_data) + + # Build a pipeline with multiple filter expressions + result_ds = ( + ds + # Calculate total revenue + .with_column("revenue", col("quantity") * col("price")) + # Flag high-value transactions + .with_column("is_high_value", col("revenue") >= 1000) + # Flag bulk orders + .with_column("is_bulk_order", col("quantity") >= 10) + # Flag premium products + .with_column("is_premium", col("price") >= 100) + # Create composite filter for special handling + .with_column( + "needs_special_handling", + (col("is_high_value")) | (col("is_bulk_order") & col("is_premium")), + ) + # Regional filter + .with_column("is_north_region", col("region") == "North") + ) + + # Convert to pandas and verify + result_df = result_ds.to_pandas() + + expected_df = pd.DataFrame( + { + "product": ["A", "B", "C", "D", "E"], + "quantity": [10, 5, 20, 15, 3], + "price": [100, 200, 50, 75, 300], + "region": ["North", "South", "North", "East", "West"], + "revenue": [1000, 1000, 1000, 1125, 900], + "is_high_value": [True, True, True, True, False], + "is_bulk_order": [True, False, True, True, False], + "is_premium": [True, True, False, False, True], + "needs_special_handling": [True, True, True, True, False], + "is_north_region": [True, False, True, False, False], + } + ) + + pd.testing.assert_frame_equal(result_df, expected_df, check_dtype=False) + + if __name__ == "__main__": import sys