Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion doc/source/data/api/expressions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,6 @@ instantiate them directly, but you may encounter them when working with expressi
Expr
ColumnExpr
LiteralExpr
BinaryExpr
BinaryExpr
UnaryExpr
UDFExpr
35 changes: 31 additions & 4 deletions python/ray/data/_expression_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,34 +16,56 @@
LiteralExpr,
Operation,
UDFExpr,
UnaryExpr,
)

_PANDAS_EXPR_OPS_MAP = {

def _pa_is_in(left: Any, right: Any) -> Any:
if not isinstance(right, (pa.Array, pa.ChunkedArray)):
right = pa.array(right.as_py() if isinstance(right, pa.Scalar) else right)
return pc.is_in(left, right)


_PANDAS_EXPR_OPS_MAP: Dict[Operation, Callable[..., Any]] = {
Operation.ADD: operator.add,
Operation.SUB: operator.sub,
Operation.MUL: operator.mul,
Operation.DIV: operator.truediv,
Operation.FLOORDIV: operator.floordiv,
Operation.GT: operator.gt,
Operation.LT: operator.lt,
Operation.GE: operator.ge,
Operation.LE: operator.le,
Operation.EQ: operator.eq,
Operation.NE: operator.ne,
Operation.AND: operator.and_,
Operation.OR: operator.or_,
Operation.NOT: operator.not_,
Operation.IS_NULL: pd.isna,
Operation.IS_NOT_NULL: pd.notna,
Operation.IN: lambda left, right: left.is_in(right),
Operation.NOT_IN: lambda left, right: ~left.is_in(right),
}

_ARROW_EXPR_OPS_MAP = {
_ARROW_EXPR_OPS_MAP: Dict[Operation, Callable[..., Any]] = {
Operation.ADD: pc.add,
Operation.SUB: pc.subtract,
Operation.MUL: pc.multiply,
Operation.DIV: pc.divide,
Operation.FLOORDIV: lambda left, right: pc.floor(pc.divide(left, right)),
Operation.GT: pc.greater,
Operation.LT: pc.less,
Operation.GE: pc.greater_equal,
Operation.LE: pc.less_equal,
Operation.EQ: pc.equal,
Operation.AND: pc.and_,
Operation.OR: pc.or_,
Operation.NE: pc.not_equal,
Operation.AND: pc.and_kleene,
Operation.OR: pc.or_kleene,
Operation.NOT: pc.invert,
Operation.IS_NULL: pc.is_null,
Operation.IS_NOT_NULL: pc.is_valid,
Operation.IN: _pa_is_in,
Operation.NOT_IN: lambda left, right: pc.invert(_pa_is_in(left, right)),
}


Expand All @@ -63,6 +85,10 @@ def _eval_expr_recursive(
_eval_expr_recursive(expr.left, batch, ops),
_eval_expr_recursive(expr.right, batch, ops),
)
if isinstance(expr, UnaryExpr):
# TODO: Use Visitor pattern here and store ops in shared state.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wdym by storing ops in a shared state?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

return ops[expr.op](_eval_expr_recursive(expr.operand, batch, ops))

if isinstance(expr, UDFExpr):
args = [_eval_expr_recursive(arg, batch, ops) for arg in expr.args]
kwargs = {
Expand All @@ -79,6 +105,7 @@ def _eval_expr_recursive(
)

return result

raise TypeError(f"Unsupported expression node: {type(expr).__name__}")


Expand Down
2 changes: 2 additions & 0 deletions python/ray/data/_internal/pandas_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,8 @@ def rename_columns(self, columns_rename: Dict[str, str]) -> "pandas.DataFrame":
def upsert_column(
self, column_name: str, column_data: BlockColumn
) -> "pandas.DataFrame":
import pyarrow

if isinstance(column_data, (pyarrow.Array, pyarrow.ChunkedArray)):
column_data = column_data.to_pandas()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,9 @@ def visit_Compare(self, node: ast.Compare) -> ds.Expression:

op = node.ops[0]
if isinstance(op, ast.In):
return left_expr.isin(comparators[0])
return left_expr.is_in(comparators[0])
elif isinstance(op, ast.NotIn):
return ~left_expr.isin(comparators[0])
return ~left_expr.is_in(comparators[0])
elif isinstance(op, ast.Eq):
return left_expr == comparators[0]
elif isinstance(op, ast.NotEq):
Expand Down Expand Up @@ -210,7 +210,7 @@ def visit_Call(self, node: ast.Call) -> ds.Expression:
nan_is_null=nan_is_null
),
"is_valid": lambda arg: arg.is_valid(),
"isin": lambda arg1, arg2: arg1.isin(arg2),
"is_in": lambda arg1, arg2: arg1.is_in(arg2),
}

if func_name in function_map:
Expand All @@ -224,11 +224,11 @@ def visit_Call(self, node: ast.Call) -> ds.Expression:
return function_map[func_name](args[0], args[1])
else:
raise ValueError("is_null function requires one or two arguments.")
# Handle the "isin" function with exactly two arguments
elif func_name == "isin" and len(args) != 2:
raise ValueError("isin function requires two arguments.")
# Handle the "is_in" function with exactly two arguments
elif func_name == "is_in" and len(args) != 2:
raise ValueError("is_in function requires two arguments.")
# Ensure the function has one argument (for functions like is_valid)
elif func_name != "isin" and len(args) != 1:
elif func_name != "is_in" and len(args) != 1:
raise ValueError(f"{func_name} function requires exactly one argument.")
# Call the corresponding function with the arguments
return function_map[func_name](*args)
Expand Down
87 changes: 86 additions & 1 deletion python/ray/data/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable, Dict, List
from typing import Any, Callable, Dict, List, Union

from ray.data.block import BatchColumn
from ray.data.datatype import DataType
Expand All @@ -23,26 +23,40 @@ class Operation(Enum):
SUB: Subtraction operation (-)
MUL: Multiplication operation (*)
DIV: Division operation (/)
FLOORDIV: Floor division operation (//)
Copy link
Contributor

@iamjustinhsu iamjustinhsu Sep 16, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is there modulo?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

GT: Greater than comparison (>)
LT: Less than comparison (<)
GE: Greater than or equal comparison (>=)
LE: Less than or equal comparison (<=)
EQ: Equality comparison (==)
NE: Not equal comparison (!=)
AND: Logical AND operation (&)
OR: Logical OR operation (|)
NOT: Logical NOT operation (~)
IS_NULL: Check if value is null
IS_NOT_NULL: Check if value is not null
IN: Check if value is in a list
NOT_IN: Check if value is not in a list
"""

ADD = "add"
SUB = "sub"
MUL = "mul"
DIV = "div"
FLOORDIV = "floordiv"
GT = "gt"
LT = "lt"
GE = "ge"
LE = "le"
EQ = "eq"
NE = "ne"
AND = "and"
OR = "or"
NOT = "not"
IS_NULL = "is_null"
IS_NOT_NULL = "is_not_null"
IN = "in"
NOT_IN = "not_in"


@DeveloperAPI(stability="alpha")
Expand Down Expand Up @@ -127,6 +141,14 @@ def __rtruediv__(self, other: Any) -> "Expr":
"""Reverse division operator (for literal / expr)."""
return LiteralExpr(other)._bin(self, Operation.DIV)

def __floordiv__(self, other: Any) -> "Expr":
"""Floor division operator (//)."""
return self._bin(other, Operation.FLOORDIV)

def __rfloordiv__(self, other: Any) -> "Expr":
"""Reverse floor division operator (for literal // expr)."""
return LiteralExpr(other)._bin(self, Operation.FLOORDIV)

# comparison
def __gt__(self, other: Any) -> "Expr":
"""Greater than operator (>)."""
Expand All @@ -148,6 +170,10 @@ def __eq__(self, other: Any) -> "Expr":
"""Equality operator (==)."""
return self._bin(other, Operation.EQ)

def __ne__(self, other: Any) -> "Expr":
"""Not equal operator (!=)."""
return self._bin(other, Operation.NE)

# boolean
def __and__(self, other: Any) -> "Expr":
"""Logical AND operator (&)."""
Expand All @@ -157,6 +183,31 @@ def __or__(self, other: Any) -> "Expr":
"""Logical OR operator (|)."""
return self._bin(other, Operation.OR)

def __invert__(self) -> "Expr":
"""Logical NOT operator (~)."""
return UnaryExpr(Operation.NOT, self)

# predicate methods
def is_null(self) -> "Expr":
"""Check if the expression value is null."""
return UnaryExpr(Operation.IS_NULL, self)

def is_not_null(self) -> "Expr":
"""Check if the expression value is not null."""
return UnaryExpr(Operation.IS_NOT_NULL, self)

def is_in(self, values: Union[List[Any], "Expr"]) -> "Expr":
"""Check if the expression value is in a list of values."""
if not isinstance(values, Expr):
values = LiteralExpr(values)
return self._bin(values, Operation.IN)

def not_in(self, values: Union[List[Any], "Expr"]) -> "Expr":
"""Check if the expression value is not in a list of values."""
if not isinstance(values, Expr):
values = LiteralExpr(values)
return self._bin(values, Operation.NOT_IN)


@DeveloperAPI(stability="alpha")
@dataclass(frozen=True, eq=False)
Expand Down Expand Up @@ -257,6 +308,39 @@ def structurally_equals(self, other: Any) -> bool:
)


@DeveloperAPI(stability="alpha")
@dataclass(frozen=True, eq=False)
class UnaryExpr(Expr):
"""Expression that represents a unary operation on a single expression.

This expression type represents an operation with one operand.
Common unary operations include logical NOT, IS NULL, IS NOT NULL, etc.

Args:
op: The operation to perform (from Operation enum)
operand: The operand expression

Example:
>>> from ray.data.expressions import col
>>> # Check if a column is null
>>> expr = col("age").is_null() # Creates UnaryExpr(IS_NULL, col("age"))
>>> # Logical not
>>> expr = ~(col("active")) # Creates UnaryExpr(NOT, col("active"))
"""

op: Operation
operand: Expr

data_type: DataType = field(init=False)

def structurally_equals(self, other: Any) -> bool:
return (
isinstance(other, UnaryExpr)
and self.op is other.op
and self.operand.structurally_equals(other.operand)
)


@DeveloperAPI(stability="alpha")
@dataclass(frozen=True, eq=False)
class UDFExpr(Expr):
Expand Down Expand Up @@ -517,6 +601,7 @@ def download(uri_column_name: str) -> DownloadExpr:
"ColumnExpr",
"LiteralExpr",
"BinaryExpr",
"UnaryExpr",
"UDFExpr",
"udf",
"DownloadExpr",
Expand Down
Loading