From 22f4a63ae29d511983c5dec88278625217bdf617 Mon Sep 17 00:00:00 2001 From: Alexey Kudinkin Date: Thu, 23 Oct 2025 02:57:51 -0400 Subject: [PATCH] [Data] Fixing handling of renames in projection pushdown (#58033) ## Description This change properly handles of pushing of the renaming projections into read ops (that support projections, like parquet reads). ## Related issues > Link related issues: "Fixes #1234", "Closes #1234", or "Related to #1234". ## Additional information > Optional: Add implementation details, API changes, usage examples, screenshots, etc. --------- Signed-off-by: Alexey Kudinkin --- python/ray/data/_internal/collections.py | 49 ++++++++++ .../datasource/parquet_datasource.py | 66 +++++++++++++- .../logical/interfaces/logical_operator.py | 6 +- .../logical/operators/read_operator.py | 10 ++- .../logical/rules/projection_pushdown.py | 89 ++++++++++++++----- .../plan_expression/expression_visitors.py | 35 ++++---- python/ray/data/datasource/datasource.py | 8 +- 7 files changed, 217 insertions(+), 46 deletions(-) create mode 100644 python/ray/data/_internal/collections.py diff --git a/python/ray/data/_internal/collections.py b/python/ray/data/_internal/collections.py new file mode 100644 index 000000000000..9888a669571e --- /dev/null +++ b/python/ray/data/_internal/collections.py @@ -0,0 +1,49 @@ +from typing import Dict, TypeVar + +K = TypeVar("K") + + +def collapse_transitive_map(d: Dict[K, K]) -> Dict[K, K]: + """Collapse transitive mappings in a dictionary. Given a mapping like + {a: b, b: c, c: d}, returns {a: d}, removing intermediate b -> c, c -> d. + + Only keeps mappings where the key is NOT a value in another mapping (i.e., chain starting points). + + Args: + d: Dictionary representing a mapping + + Returns: + Dictionary with all transitive mappings collapsed, keeping only KV-pairs, + such that K and V are starting and terminal point of a chain + + Examples: + >>> collapse_transitive_map({"a": "b", "b": "c", "c": "d"}) + {'a': 'd'} + >>> collapse_transitive_map({"a": "b", "x": "y"}) + {'a': 'b', 'x': 'y'} + """ + if not d: + return {} + + collapsed = {} + values_set = set(d.values()) + for k in d: + # Skip mappings that are in the value-set, meaning that they are + # part of the mapping chain (for ex, {a -> b, b -> c}) + if k in values_set: + continue + + cur = k + visited = {cur} + + # Follow the chain until we reach a key that's not in the mapping + while cur in d: + next = d[cur] + if next in visited: + raise ValueError(f"Detected a cycle in the mapping {d}") + visited.add(next) + cur = next + + collapsed[k] = cur + + return collapsed diff --git a/python/ray/data/_internal/datasource/parquet_datasource.py b/python/ray/data/_internal/datasource/parquet_datasource.py index b37be53c0338..074715dcd84a 100644 --- a/python/ray/data/_internal/datasource/parquet_datasource.py +++ b/python/ray/data/_internal/datasource/parquet_datasource.py @@ -27,6 +27,7 @@ _BATCH_SIZE_PRESERVING_STUB_COL_NAME, ArrowBlockAccessor, ) +from ray.data._internal.collections import collapse_transitive_map from ray.data._internal.progress_bar import ProgressBar from ray.data._internal.remote_fn import cached_remote_fn from ray.data._internal.util import ( @@ -275,6 +276,7 @@ def __init__( self._block_udf = _block_udf self._to_batches_kwargs = to_batch_kwargs self._data_columns = data_columns + self._data_columns_rename_map = {} self._partition_columns = partition_columns self._read_schema = schema self._file_schema = pq_ds.schema @@ -379,6 +381,7 @@ def get_read_tasks( to_batches_kwargs, default_read_batch_size_rows, data_columns, + data_columns_rename_map, partition_columns, read_schema, include_paths, @@ -388,6 +391,7 @@ def get_read_tasks( self._to_batches_kwargs, self._default_batch_size, self._data_columns, + self._data_columns_rename_map, self._partition_columns, self._read_schema, self._include_paths, @@ -401,6 +405,7 @@ def get_read_tasks( to_batches_kwargs, default_read_batch_size_rows, data_columns, + data_columns_rename_map, partition_columns, read_schema, f, @@ -437,9 +442,18 @@ def get_current_projection(self) -> Optional[List[str]]: return (self._data_columns or []) + (self._partition_columns or []) - def apply_projection(self, columns: List[str]) -> "ParquetDatasource": + def apply_projection( + self, + columns: Optional[List[str]], + column_rename_map: Optional[Dict[str, str]], + ) -> "ParquetDatasource": clone = copy.copy(self) - clone._data_columns = columns + + clone._data_columns = _combine_projection(self._data_columns, columns) + clone._data_columns_rename_map = _combine_rename_map( + self._data_columns_rename_map, column_rename_map + ) + return clone def _estimate_in_mem_size(self, fragments: List[_ParquetFragment]) -> int: @@ -453,6 +467,7 @@ def read_fragments( to_batches_kwargs: Dict[str, Any], default_read_batch_size_rows: Optional[int], data_columns: Optional[List[str]], + data_columns_rename_map: Optional[Dict[str, str]], partition_columns: Optional[List[str]], schema: Optional[Union[type, "pyarrow.lib.Schema"]], fragments: List[_ParquetFragment], @@ -475,6 +490,7 @@ def read_fragments( fragment.original, schema=schema, data_columns=data_columns, + data_columns_rename_map=data_columns_rename_map, partition_columns=partition_columns, partitioning=partitioning, include_path=include_paths, @@ -497,6 +513,7 @@ def _read_batches_from( *, schema: "pyarrow.Schema", data_columns: Optional[List[str]], + data_columns_rename_map: Optional[Dict[str, str]], partition_columns: Optional[List[str]], partitioning: Partitioning, filter_expr: Optional["pyarrow.dataset.Expression"] = None, @@ -562,6 +579,14 @@ def _read_batches_from( _BATCH_SIZE_PRESERVING_STUB_COL_NAME, pa.nulls(table.num_rows) ) + if data_columns_rename_map is not None: + table = table.rename_columns( + [ + data_columns_rename_map.get(col, col) + for col in table.schema.names + ] + ) + yield table except pa.lib.ArrowInvalid as e: @@ -825,6 +850,43 @@ def _add_partitions_to_table( return table +def _combine_projection( + prev_projected_cols: Optional[List[str]], new_projected_cols: Optional[List[str]] +) -> Optional[List[str]]: + # NOTE: Null projection carries special meaning of all columns being selected + if prev_projected_cols is None: + return new_projected_cols + elif new_projected_cols is None: + # Retain original projection + return prev_projected_cols + else: + illegal_refs = [ + col for col in new_projected_cols if col not in prev_projected_cols + ] + + if illegal_refs: + raise ValueError( + f"New projection {new_projected_cols} references non-existent columns " + f"(existing projection {prev_projected_cols})" + ) + + return new_projected_cols + + +def _combine_rename_map( + prev_column_rename_map: Optional[Dict[str, str]], + new_column_rename_map: Optional[Dict[str, str]], +): + if not prev_column_rename_map: + combined = new_column_rename_map + elif not new_column_rename_map: + combined = prev_column_rename_map + else: + combined = prev_column_rename_map | new_column_rename_map + + return collapse_transitive_map(combined) + + def _get_partition_columns_schema( partitioning: Partitioning, file_paths: List[str], diff --git a/python/ray/data/_internal/logical/interfaces/logical_operator.py b/python/ray/data/_internal/logical/interfaces/logical_operator.py index 037776017b4a..d7141af78987 100644 --- a/python/ray/data/_internal/logical/interfaces/logical_operator.py +++ b/python/ray/data/_internal/logical/interfaces/logical_operator.py @@ -98,5 +98,9 @@ def supports_projection_pushdown(self) -> bool: def get_current_projection(self) -> Optional[List[str]]: return None - def apply_projection(self, columns: Optional[List[str]]) -> LogicalOperator: + def apply_projection( + self, + columns: Optional[List[str]], + column_rename_map: Optional[Dict[str, str]], + ) -> LogicalOperator: return self diff --git a/python/ray/data/_internal/logical/operators/read_operator.py b/python/ray/data/_internal/logical/operators/read_operator.py index 38ee615ffe39..8561e7aff75c 100644 --- a/python/ray/data/_internal/logical/operators/read_operator.py +++ b/python/ray/data/_internal/logical/operators/read_operator.py @@ -158,10 +158,16 @@ def supports_projection_pushdown(self) -> bool: def get_current_projection(self) -> Optional[List[str]]: return self._datasource.get_current_projection() - def apply_projection(self, columns: List[str]): + def apply_projection( + self, + columns: Optional[List[str]], + column_rename_map: Optional[Dict[str, str]], + ) -> "Read": clone = copy.copy(self) - projected_datasource = self._datasource.apply_projection(columns) + projected_datasource = self._datasource.apply_projection( + columns, column_rename_map + ) clone._datasource = projected_datasource clone._datasource_or_legacy_reader = projected_datasource diff --git a/python/ray/data/_internal/logical/rules/projection_pushdown.py b/python/ray/data/_internal/logical/rules/projection_pushdown.py index 0cb7950a10d5..e85f8ba0eb88 100644 --- a/python/ray/data/_internal/logical/rules/projection_pushdown.py +++ b/python/ray/data/_internal/logical/rules/projection_pushdown.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Set, Tuple +from typing import Dict, List, Optional, Set, Tuple from ray.data._internal.logical.interfaces import ( LogicalOperator, @@ -19,7 +19,7 @@ ) -def _collect_referenced_columns(exprs: List[Expr]) -> Optional[Set[str]]: +def _collect_referenced_columns(exprs: List[Expr]) -> Optional[List[str]]: """ Extract all column names referenced by the given expressions. @@ -37,7 +37,8 @@ def _collect_referenced_columns(exprs: List[Expr]) -> Optional[Set[str]]: collector = _ColumnReferenceCollector() for expr in exprs or []: collector.visit(expr) - return collector.referenced_columns + + return collector.get_column_refs() def _extract_simple_rename(expr: Expr) -> Optional[Tuple[str, str]]: @@ -118,9 +119,11 @@ def _validate_fusion( if isinstance(expr, StarExpr): continue - referenced_columns = _collect_referenced_columns([expr]) or set() - columns_from_original = referenced_columns - ( - referenced_columns & upstream_output_columns + column_refs = _collect_referenced_columns([expr]) + column_refs_set = set(column_refs or []) + + columns_from_original = column_refs_set - ( + column_refs_set & upstream_output_columns ) # Validate accessibility @@ -280,27 +283,65 @@ def _push_projection_into_read_op(cls, op: LogicalOperator) -> LogicalOperator: # Step 2: Push projection into the data source if supported input_op = current_project.input_dependency if ( - not current_project.has_star_expr() # Must be a selection, not additive - and isinstance(input_op, LogicalOperatorSupportsProjectionPushdown) + isinstance(input_op, LogicalOperatorSupportsProjectionPushdown) and input_op.supports_projection_pushdown() ): - required_columns = _collect_referenced_columns(list(current_project.exprs)) - if required_columns is not None: # None means star() was present - optimized_source = input_op.apply_projection(list(required_columns)) - - is_simple_selection = all( - isinstance(expr, ColumnExpr) for expr in current_project.exprs + if current_project.has_star_expr(): + # If project has a star, than no projection is feasible + required_columns = None + else: + # Otherwise, collect required column for projection + required_columns = _collect_referenced_columns(current_project.exprs) + + # Check if it's a simple projection that could be pushed into + # read as a whole + is_simple_projection = all( + _is_col_expr(expr) + for expr in current_project.exprs + if not isinstance(expr, StarExpr) + ) + + if is_simple_projection: + # NOTE: We only can rename output columns when it's a simple + # projection and Project operator is discarded (otherwise + # it might be holding expression referencing attributes + # by original their names prior to renaming) + # + # TODO fix by instead rewriting exprs + output_column_rename_map = _collect_output_column_rename_map( + current_project.exprs ) - if is_simple_selection: - # Simple column selection: Read handles everything - return optimized_source - else: - # Has transformations: Keep Project on top of optimized Read - return Project( - optimized_source, - exprs=current_project.exprs, - ray_remote_args=current_project._ray_remote_args, - ) + # Apply projection of columns to the read op + return input_op.apply_projection( + required_columns, output_column_rename_map + ) + else: + # Otherwise just apply projection without renaming + projected_input_op = input_op.apply_projection(required_columns, None) + + # Has transformations: Keep Project on top of optimized Read + return Project( + projected_input_op, + exprs=current_project.exprs, + ray_remote_args=current_project._ray_remote_args, + ) return current_project + + +def _is_col_expr(expr: Expr) -> bool: + return isinstance(expr, ColumnExpr) or ( + isinstance(expr, AliasExpr) and isinstance(expr.expr, ColumnExpr) + ) + + +def _collect_output_column_rename_map(exprs: List[Expr]) -> Dict[str, str]: + # First, extract all potential rename pairs + rename_map = { + expr.expr.name: expr.name + for expr in exprs + if isinstance(expr, AliasExpr) and isinstance(expr.expr, ColumnExpr) + } + + return rename_map diff --git a/python/ray/data/_internal/planner/plan_expression/expression_visitors.py b/python/ray/data/_internal/planner/plan_expression/expression_visitors.py index 9162f417f3e5..9e2b2ae6cbb7 100644 --- a/python/ray/data/_internal/planner/plan_expression/expression_visitors.py +++ b/python/ray/data/_internal/planner/plan_expression/expression_visitors.py @@ -1,4 +1,4 @@ -from typing import Set, TypeVar +from typing import Dict, List, Set, TypeVar from ray.data.expressions import ( AliasExpr, @@ -42,6 +42,18 @@ def visit_udf(self, expr: "UDFExpr") -> None: for value in expr.kwargs.values(): super().visit(value) + def visit_literal(self, expr: LiteralExpr) -> None: + """Visit a literal expression (no columns to collect).""" + pass + + def visit_star(self, expr: StarExpr) -> None: + """Visit a star expression (no columns to collect).""" + pass + + def visit_download(self, expr: "Expr") -> None: + """Visit a download expression (no columns to collect).""" + pass + class _ColumnReferenceCollector(_ExprVisitorBase): """Visitor that collects all column references from expression trees. @@ -52,7 +64,12 @@ class _ColumnReferenceCollector(_ExprVisitorBase): def __init__(self): """Initialize with an empty set of referenced columns.""" - self.referenced_columns: Set[str] = set() + + # NOTE: We're using dict to maintain insertion ordering + self._col_refs: Dict[str, None] = dict() + + def get_column_refs(self) -> List[str]: + return list(self._col_refs.keys()) def visit_column(self, expr: ColumnExpr) -> None: """Visit a column expression and collect its name. @@ -63,7 +80,7 @@ def visit_column(self, expr: ColumnExpr) -> None: Returns: None (only collects columns as a side effect). """ - self.referenced_columns.add(expr.name) + self._col_refs[expr.name] = None def visit_alias(self, expr: AliasExpr) -> None: """Visit an alias expression and collect from its inner expression. @@ -76,18 +93,6 @@ def visit_alias(self, expr: AliasExpr) -> None: """ self.visit(expr.expr) - def visit_literal(self, expr: LiteralExpr) -> None: - """Visit a literal expression (no columns to collect).""" - pass - - def visit_star(self, expr: StarExpr) -> None: - """Visit a star expression (no columns to collect).""" - pass - - def visit_download(self, expr: "Expr") -> None: - """Visit a download expression (no columns to collect).""" - pass - class _ColumnRewriter(_ExprVisitor[Expr]): """Visitor that rewrites column references in expression trees. diff --git a/python/ray/data/datasource/datasource.py b/python/ray/data/datasource/datasource.py index b39a692e2bba..c217e3a2d69d 100644 --- a/python/ray/data/datasource/datasource.py +++ b/python/ray/data/datasource/datasource.py @@ -1,4 +1,4 @@ -from typing import Callable, Iterable, List, Optional +from typing import Callable, Dict, Iterable, List, Optional import numpy as np @@ -20,7 +20,11 @@ def get_current_projection(self) -> Optional[List[str]]: """Retrurns current projection""" return None - def apply_projection(self, columns: List[str]) -> "Datasource": + def apply_projection( + self, + columns: Optional[List[str]], + column_rename_map: Optional[Dict[str, str]], + ) -> "Datasource": return self