Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions python/ray/data/_internal/collections.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from typing import Dict, TypeVar

K = TypeVar("K")


def collapse_transitive_map(d: Dict[K, K]) -> Dict[K, K]:
    """Collapse transitive mappings in a dictionary.

    Given a mapping like ``{a: b, b: c, c: d}``, returns ``{a: d}``, removing
    the intermediate ``b -> c`` and ``c -> d`` entries.

    Only chain starting points are kept, i.e. keys that do NOT appear as a
    value of any entry. As a consequence, an entry whose key is also a value
    (including self-mappings and closed loops such as ``{a: b, b: a}``) is
    dropped from the result rather than reported.

    Args:
        d: Dictionary representing a mapping.

    Returns:
        A new dictionary with all transitive mappings collapsed, keeping only
        KV-pairs such that K and V are the starting and the terminal point of
        a chain.

    Raises:
        ValueError: If a chain that begins at a starting point runs into
            a cycle.

    Examples:
        >>> collapse_transitive_map({"a": "b", "b": "c", "c": "d"})
        {'a': 'd'}
        >>> collapse_transitive_map({"a": "b", "x": "y"})
        {'a': 'b', 'x': 'y'}
    """
    if not d:
        return {}

    collapsed = {}
    values_set = set(d.values())
    for k in d:
        # Skip keys that are in the value-set, meaning that they are an
        # interior part of a mapping chain (for ex, `b` in {a -> b, b -> c})
        if k in values_set:
            continue

        cur = k
        visited = {cur}

        # Follow the chain until we reach a key that's not in the mapping
        while cur in d:
            # Deliberately not named `next`, which would shadow the builtin
            successor = d[cur]
            if successor in visited:
                raise ValueError(f"Detected a cycle in the mapping {d}")
            visited.add(successor)
            cur = successor

        collapsed[k] = cur

    return collapsed
66 changes: 64 additions & 2 deletions python/ray/data/_internal/datasource/parquet_datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
_BATCH_SIZE_PRESERVING_STUB_COL_NAME,
ArrowBlockAccessor,
)
from ray.data._internal.collections import collapse_transitive_map
from ray.data._internal.progress_bar import ProgressBar
from ray.data._internal.remote_fn import cached_remote_fn
from ray.data._internal.util import (
Expand Down Expand Up @@ -275,6 +276,7 @@ def __init__(
self._block_udf = _block_udf
self._to_batches_kwargs = to_batch_kwargs
self._data_columns = data_columns
self._data_columns_rename_map = {}
self._partition_columns = partition_columns
self._read_schema = schema
self._file_schema = pq_ds.schema
Expand Down Expand Up @@ -379,6 +381,7 @@ def get_read_tasks(
to_batches_kwargs,
default_read_batch_size_rows,
data_columns,
data_columns_rename_map,
partition_columns,
read_schema,
include_paths,
Expand All @@ -388,6 +391,7 @@ def get_read_tasks(
self._to_batches_kwargs,
self._default_batch_size,
self._data_columns,
self._data_columns_rename_map,
self._partition_columns,
self._read_schema,
self._include_paths,
Expand All @@ -401,6 +405,7 @@ def get_read_tasks(
to_batches_kwargs,
default_read_batch_size_rows,
data_columns,
data_columns_rename_map,
partition_columns,
read_schema,
f,
Expand Down Expand Up @@ -437,9 +442,18 @@ def get_current_projection(self) -> Optional[List[str]]:

return (self._data_columns or []) + (self._partition_columns or [])

def apply_projection(self, columns: List[str]) -> "ParquetDatasource":
def apply_projection(
self,
columns: Optional[List[str]],
column_rename_map: Optional[Dict[str, str]],
) -> "ParquetDatasource":
clone = copy.copy(self)
clone._data_columns = columns

clone._data_columns = _combine_projection(self._data_columns, columns)
clone._data_columns_rename_map = _combine_rename_map(
self._data_columns_rename_map, column_rename_map
)

return clone

def _estimate_in_mem_size(self, fragments: List[_ParquetFragment]) -> int:
Expand All @@ -453,6 +467,7 @@ def read_fragments(
to_batches_kwargs: Dict[str, Any],
default_read_batch_size_rows: Optional[int],
data_columns: Optional[List[str]],
data_columns_rename_map: Optional[Dict[str, str]],
partition_columns: Optional[List[str]],
schema: Optional[Union[type, "pyarrow.lib.Schema"]],
fragments: List[_ParquetFragment],
Expand All @@ -475,6 +490,7 @@ def read_fragments(
fragment.original,
schema=schema,
data_columns=data_columns,
data_columns_rename_map=data_columns_rename_map,
partition_columns=partition_columns,
partitioning=partitioning,
include_path=include_paths,
Expand All @@ -497,6 +513,7 @@ def _read_batches_from(
*,
schema: "pyarrow.Schema",
data_columns: Optional[List[str]],
data_columns_rename_map: Optional[Dict[str, str]],
partition_columns: Optional[List[str]],
partitioning: Partitioning,
filter_expr: Optional["pyarrow.dataset.Expression"] = None,
Expand Down Expand Up @@ -562,6 +579,14 @@ def _read_batches_from(
_BATCH_SIZE_PRESERVING_STUB_COL_NAME, pa.nulls(table.num_rows)
)

if data_columns_rename_map is not None:
table = table.rename_columns(
[
data_columns_rename_map.get(col, col)
for col in table.schema.names
]
)

yield table

except pa.lib.ArrowInvalid as e:
Expand Down Expand Up @@ -825,6 +850,43 @@ def _add_partitions_to_table(
return table


def _combine_projection(
prev_projected_cols: Optional[List[str]], new_projected_cols: Optional[List[str]]
) -> Optional[List[str]]:
# NOTE: Null projection carries special meaning of all columns being selected
if prev_projected_cols is None:
return new_projected_cols
elif new_projected_cols is None:
# Retain original projection
return prev_projected_cols
else:
illegal_refs = [
col for col in new_projected_cols if col not in prev_projected_cols
]

if illegal_refs:
raise ValueError(
f"New projection {new_projected_cols} references non-existent columns "
f"(existing projection {prev_projected_cols})"
)

return new_projected_cols


def _combine_rename_map(
    prev_column_rename_map: Optional[Dict[str, str]],
    new_column_rename_map: Optional[Dict[str, str]],
) -> Dict[str, str]:
    """Merge two column rename maps and collapse the resulting rename chains.

    Args:
        prev_column_rename_map: Rename map already in effect; ``None`` or
            empty when there is none.
        new_column_rename_map: Rename map being applied on top; ``None`` or
            empty when there is none.

    Returns:
        The union of both maps (entries of the new map win for shared keys)
        passed through ``collapse_transitive_map``. An empty dict is returned
        when neither map has entries.
    """
    # Normalise missing maps to empty dicts so that `collapse_transitive_map`
    # always receives an actual dictionary rather than `None`
    combined = (prev_column_rename_map or {}) | (new_column_rename_map or {})

    return collapse_transitive_map(combined)
Comment on lines +876 to +887
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This function is missing a return type hint. Adding it would improve type safety and code clarity. Additionally, the logic for combining the dictionaries can be made more concise and robust against None values by using or {}.

def _combine_rename_map(
    prev_column_rename_map: Optional[Dict[str, str]],
    new_column_rename_map: Optional[Dict[str, str]],
) -> Dict[str, str]:
    combined = (prev_column_rename_map or {}) | (new_column_rename_map or {})
    return collapse_transitive_map(combined)



def _get_partition_columns_schema(
partitioning: Partitioning,
file_paths: List[str],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,5 +98,9 @@ def supports_projection_pushdown(self) -> bool:
def get_current_projection(self) -> Optional[List[str]]:
return None

def apply_projection(self, columns: Optional[List[str]]) -> LogicalOperator:
def apply_projection(
self,
columns: Optional[List[str]],
column_rename_map: Optional[Dict[str, str]],
) -> LogicalOperator:
return self
10 changes: 8 additions & 2 deletions python/ray/data/_internal/logical/operators/read_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,10 +158,16 @@ def supports_projection_pushdown(self) -> bool:
def get_current_projection(self) -> Optional[List[str]]:
return self._datasource.get_current_projection()

def apply_projection(self, columns: List[str]):
def apply_projection(
self,
columns: Optional[List[str]],
column_rename_map: Optional[Dict[str, str]],
) -> "Read":
clone = copy.copy(self)

projected_datasource = self._datasource.apply_projection(columns)
projected_datasource = self._datasource.apply_projection(
columns, column_rename_map
)
clone._datasource = projected_datasource
clone._datasource_or_legacy_reader = projected_datasource

Expand Down
89 changes: 65 additions & 24 deletions python/ray/data/_internal/logical/rules/projection_pushdown.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Optional, Set, Tuple
from typing import Dict, List, Optional, Set, Tuple

from ray.data._internal.logical.interfaces import (
LogicalOperator,
Expand All @@ -19,7 +19,7 @@
)


def _collect_referenced_columns(exprs: List[Expr]) -> Optional[Set[str]]:
def _collect_referenced_columns(exprs: List[Expr]) -> Optional[List[str]]:
"""
Extract all column names referenced by the given expressions.

Expand All @@ -37,7 +37,8 @@ def _collect_referenced_columns(exprs: List[Expr]) -> Optional[Set[str]]:
collector = _ColumnReferenceCollector()
for expr in exprs or []:
collector.visit(expr)
return collector.referenced_columns

return collector.get_column_refs()


def _extract_simple_rename(expr: Expr) -> Optional[Tuple[str, str]]:
Expand Down Expand Up @@ -118,9 +119,11 @@ def _validate_fusion(
if isinstance(expr, StarExpr):
continue

referenced_columns = _collect_referenced_columns([expr]) or set()
columns_from_original = referenced_columns - (
referenced_columns & upstream_output_columns
column_refs = _collect_referenced_columns([expr])
column_refs_set = set(column_refs or [])

columns_from_original = column_refs_set - (
column_refs_set & upstream_output_columns
)

# Validate accessibility
Expand Down Expand Up @@ -280,27 +283,65 @@ def _push_projection_into_read_op(cls, op: LogicalOperator) -> LogicalOperator:
# Step 2: Push projection into the data source if supported
input_op = current_project.input_dependency
if (
not current_project.has_star_expr() # Must be a selection, not additive
and isinstance(input_op, LogicalOperatorSupportsProjectionPushdown)
isinstance(input_op, LogicalOperatorSupportsProjectionPushdown)
and input_op.supports_projection_pushdown()
):
required_columns = _collect_referenced_columns(list(current_project.exprs))
if required_columns is not None: # None means star() was present
optimized_source = input_op.apply_projection(list(required_columns))

is_simple_selection = all(
isinstance(expr, ColumnExpr) for expr in current_project.exprs
if current_project.has_star_expr():
# If project has a star, then no projection is feasible
required_columns = None
else:
# Otherwise, collect required column for projection
required_columns = _collect_referenced_columns(current_project.exprs)

# Check if it's a simple projection that could be pushed into
# read as a whole
is_simple_projection = all(
_is_col_expr(expr)
for expr in current_project.exprs
if not isinstance(expr, StarExpr)
)

if is_simple_projection:
# NOTE: We can only rename output columns when it's a simple
# projection and Project operator is discarded (otherwise
# it might be holding expressions referencing attributes
# by their original names prior to renaming)
#
# TODO fix by instead rewriting exprs
output_column_rename_map = _collect_output_column_rename_map(
current_project.exprs
)

if is_simple_selection:
# Simple column selection: Read handles everything
return optimized_source
else:
# Has transformations: Keep Project on top of optimized Read
return Project(
optimized_source,
exprs=current_project.exprs,
ray_remote_args=current_project._ray_remote_args,
)
# Apply projection of columns to the read op
return input_op.apply_projection(
required_columns, output_column_rename_map
)
else:
# Otherwise just apply projection without renaming
projected_input_op = input_op.apply_projection(required_columns, None)

# Has transformations: Keep Project on top of optimized Read
return Project(
projected_input_op,
exprs=current_project.exprs,
ray_remote_args=current_project._ray_remote_args,
)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: Projection Pushdown Fails with star() and Renames

The projection pushdown logic now incorrectly attempts to push down projections that include star() expressions. This changes semantics, especially when star() is combined with column renames (e.g., [star(), col("a").alias("b")]). Instead of preserving all original columns and adding the renamed one, the original column is replaced, altering the output schema.

Fix in Cursor Fix in Web


return current_project


def _is_col_expr(expr: Expr) -> bool:
    """Tell whether ``expr`` is a column reference, optionally aliased.

    Args:
        expr: Expression to inspect.

    Returns:
        ``True`` if ``expr`` is a ``ColumnExpr``, or an ``AliasExpr`` whose
        wrapped expression is a ``ColumnExpr``; ``False`` otherwise.
    """
    # Bare column reference
    if isinstance(expr, ColumnExpr):
        return True

    # Alias is accepted only when it directly wraps a column reference
    if isinstance(expr, AliasExpr):
        return isinstance(expr.expr, ColumnExpr)

    return False


def _collect_output_column_rename_map(exprs: List[Expr]) -> Dict[str, str]:
    """Collect the renames expressed by aliased column references.

    Args:
        exprs: Expressions to scan.

    Returns:
        Dictionary mapping the name of the wrapped column to the alias name,
        with one entry for every ``AliasExpr`` that directly wraps a
        ``ColumnExpr``. All other expressions are ignored. When the same
        column name occurs more than once, the last alias wins.
    """
    rename_map = {}

    for candidate in exprs:
        if not isinstance(candidate, AliasExpr):
            continue

        aliased = candidate.expr
        if isinstance(aliased, ColumnExpr):
            rename_map[aliased.name] = candidate.name

    return rename_map
Loading