Skip to content

Commit 4130e4d

Browse files
[Data] Fixing handling of renames in projection pushdown (#58033)
## Description This change properly handles of pushing of the renaming projections into read ops (that support projections, like parquet reads). ## Related issues > Link related issues: "Fixes #1234", "Closes #1234", or "Related to #1234". ## Additional information > Optional: Add implementation details, API changes, usage examples, screenshots, etc. --------- Signed-off-by: Alexey Kudinkin <ak@anyscale.com>
1 parent 42b58a4 commit 4130e4d

File tree

7 files changed

+217
-46
lines changed

7 files changed

+217
-46
lines changed
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
from typing import Dict, TypeVar
2+
3+
K = TypeVar("K")
4+
5+
6+
def collapse_transitive_map(d: Dict[K, K]) -> Dict[K, K]:
7+
"""Collapse transitive mappings in a dictionary. Given a mapping like
8+
{a: b, b: c, c: d}, returns {a: d}, removing intermediate b -> c, c -> d.
9+
10+
Only keeps mappings where the key is NOT a value in another mapping (i.e., chain starting points).
11+
12+
Args:
13+
d: Dictionary representing a mapping
14+
15+
Returns:
16+
Dictionary with all transitive mappings collapsed, keeping only KV-pairs,
17+
such that K and V are starting and terminal point of a chain
18+
19+
Examples:
20+
>>> collapse_transitive_map({"a": "b", "b": "c", "c": "d"})
21+
{'a': 'd'}
22+
>>> collapse_transitive_map({"a": "b", "x": "y"})
23+
{'a': 'b', 'x': 'y'}
24+
"""
25+
if not d:
26+
return {}
27+
28+
collapsed = {}
29+
values_set = set(d.values())
30+
for k in d:
31+
# Skip mappings that are in the value-set, meaning that they are
32+
# part of the mapping chain (for ex, {a -> b, b -> c})
33+
if k in values_set:
34+
continue
35+
36+
cur = k
37+
visited = {cur}
38+
39+
# Follow the chain until we reach a key that's not in the mapping
40+
while cur in d:
41+
next = d[cur]
42+
if next in visited:
43+
raise ValueError(f"Detected a cycle in the mapping {d}")
44+
visited.add(next)
45+
cur = next
46+
47+
collapsed[k] = cur
48+
49+
return collapsed

python/ray/data/_internal/datasource/parquet_datasource.py

Lines changed: 64 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
_BATCH_SIZE_PRESERVING_STUB_COL_NAME,
2828
ArrowBlockAccessor,
2929
)
30+
from ray.data._internal.collections import collapse_transitive_map
3031
from ray.data._internal.progress_bar import ProgressBar
3132
from ray.data._internal.remote_fn import cached_remote_fn
3233
from ray.data._internal.util import (
@@ -275,6 +276,7 @@ def __init__(
275276
self._block_udf = _block_udf
276277
self._to_batches_kwargs = to_batch_kwargs
277278
self._data_columns = data_columns
279+
self._data_columns_rename_map = {}
278280
self._partition_columns = partition_columns
279281
self._read_schema = schema
280282
self._file_schema = pq_ds.schema
@@ -379,6 +381,7 @@ def get_read_tasks(
379381
to_batches_kwargs,
380382
default_read_batch_size_rows,
381383
data_columns,
384+
data_columns_rename_map,
382385
partition_columns,
383386
read_schema,
384387
include_paths,
@@ -388,6 +391,7 @@ def get_read_tasks(
388391
self._to_batches_kwargs,
389392
self._default_batch_size,
390393
self._data_columns,
394+
self._data_columns_rename_map,
391395
self._partition_columns,
392396
self._read_schema,
393397
self._include_paths,
@@ -401,6 +405,7 @@ def get_read_tasks(
401405
to_batches_kwargs,
402406
default_read_batch_size_rows,
403407
data_columns,
408+
data_columns_rename_map,
404409
partition_columns,
405410
read_schema,
406411
f,
@@ -437,9 +442,18 @@ def get_current_projection(self) -> Optional[List[str]]:
437442

438443
return (self._data_columns or []) + (self._partition_columns or [])
439444

440-
def apply_projection(self, columns: List[str]) -> "ParquetDatasource":
445+
def apply_projection(
446+
self,
447+
columns: Optional[List[str]],
448+
column_rename_map: Optional[Dict[str, str]],
449+
) -> "ParquetDatasource":
441450
clone = copy.copy(self)
442-
clone._data_columns = columns
451+
452+
clone._data_columns = _combine_projection(self._data_columns, columns)
453+
clone._data_columns_rename_map = _combine_rename_map(
454+
self._data_columns_rename_map, column_rename_map
455+
)
456+
443457
return clone
444458

445459
def _estimate_in_mem_size(self, fragments: List[_ParquetFragment]) -> int:
@@ -453,6 +467,7 @@ def read_fragments(
453467
to_batches_kwargs: Dict[str, Any],
454468
default_read_batch_size_rows: Optional[int],
455469
data_columns: Optional[List[str]],
470+
data_columns_rename_map: Optional[Dict[str, str]],
456471
partition_columns: Optional[List[str]],
457472
schema: Optional[Union[type, "pyarrow.lib.Schema"]],
458473
fragments: List[_ParquetFragment],
@@ -475,6 +490,7 @@ def read_fragments(
475490
fragment.original,
476491
schema=schema,
477492
data_columns=data_columns,
493+
data_columns_rename_map=data_columns_rename_map,
478494
partition_columns=partition_columns,
479495
partitioning=partitioning,
480496
include_path=include_paths,
@@ -497,6 +513,7 @@ def _read_batches_from(
497513
*,
498514
schema: "pyarrow.Schema",
499515
data_columns: Optional[List[str]],
516+
data_columns_rename_map: Optional[Dict[str, str]],
500517
partition_columns: Optional[List[str]],
501518
partitioning: Partitioning,
502519
filter_expr: Optional["pyarrow.dataset.Expression"] = None,
@@ -562,6 +579,14 @@ def _read_batches_from(
562579
_BATCH_SIZE_PRESERVING_STUB_COL_NAME, pa.nulls(table.num_rows)
563580
)
564581

582+
if data_columns_rename_map is not None:
583+
table = table.rename_columns(
584+
[
585+
data_columns_rename_map.get(col, col)
586+
for col in table.schema.names
587+
]
588+
)
589+
565590
yield table
566591

567592
except pa.lib.ArrowInvalid as e:
@@ -825,6 +850,43 @@ def _add_partitions_to_table(
825850
return table
826851

827852

853+
def _combine_projection(
854+
prev_projected_cols: Optional[List[str]], new_projected_cols: Optional[List[str]]
855+
) -> Optional[List[str]]:
856+
# NOTE: Null projection carries special meaning of all columns being selected
857+
if prev_projected_cols is None:
858+
return new_projected_cols
859+
elif new_projected_cols is None:
860+
# Retain original projection
861+
return prev_projected_cols
862+
else:
863+
illegal_refs = [
864+
col for col in new_projected_cols if col not in prev_projected_cols
865+
]
866+
867+
if illegal_refs:
868+
raise ValueError(
869+
f"New projection {new_projected_cols} references non-existent columns "
870+
f"(existing projection {prev_projected_cols})"
871+
)
872+
873+
return new_projected_cols
874+
875+
876+
def _combine_rename_map(
877+
prev_column_rename_map: Optional[Dict[str, str]],
878+
new_column_rename_map: Optional[Dict[str, str]],
879+
):
880+
if not prev_column_rename_map:
881+
combined = new_column_rename_map
882+
elif not new_column_rename_map:
883+
combined = prev_column_rename_map
884+
else:
885+
combined = prev_column_rename_map | new_column_rename_map
886+
887+
return collapse_transitive_map(combined)
888+
889+
828890
def _get_partition_columns_schema(
829891
partitioning: Partitioning,
830892
file_paths: List[str],

python/ray/data/_internal/logical/interfaces/logical_operator.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,5 +98,9 @@ def supports_projection_pushdown(self) -> bool:
9898
def get_current_projection(self) -> Optional[List[str]]:
9999
return None
100100

101-
def apply_projection(self, columns: Optional[List[str]]) -> LogicalOperator:
101+
def apply_projection(
102+
self,
103+
columns: Optional[List[str]],
104+
column_rename_map: Optional[Dict[str, str]],
105+
) -> LogicalOperator:
102106
return self

python/ray/data/_internal/logical/operators/read_operator.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -158,10 +158,16 @@ def supports_projection_pushdown(self) -> bool:
158158
def get_current_projection(self) -> Optional[List[str]]:
159159
return self._datasource.get_current_projection()
160160

161-
def apply_projection(self, columns: List[str]):
161+
def apply_projection(
162+
self,
163+
columns: Optional[List[str]],
164+
column_rename_map: Optional[Dict[str, str]],
165+
) -> "Read":
162166
clone = copy.copy(self)
163167

164-
projected_datasource = self._datasource.apply_projection(columns)
168+
projected_datasource = self._datasource.apply_projection(
169+
columns, column_rename_map
170+
)
165171
clone._datasource = projected_datasource
166172
clone._datasource_or_legacy_reader = projected_datasource
167173

python/ray/data/_internal/logical/rules/projection_pushdown.py

Lines changed: 65 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import List, Optional, Set, Tuple
1+
from typing import Dict, List, Optional, Set, Tuple
22

33
from ray.data._internal.logical.interfaces import (
44
LogicalOperator,
@@ -19,7 +19,7 @@
1919
)
2020

2121

22-
def _collect_referenced_columns(exprs: List[Expr]) -> Optional[Set[str]]:
22+
def _collect_referenced_columns(exprs: List[Expr]) -> Optional[List[str]]:
2323
"""
2424
Extract all column names referenced by the given expressions.
2525
@@ -37,7 +37,8 @@ def _collect_referenced_columns(exprs: List[Expr]) -> Optional[Set[str]]:
3737
collector = _ColumnReferenceCollector()
3838
for expr in exprs or []:
3939
collector.visit(expr)
40-
return collector.referenced_columns
40+
41+
return collector.get_column_refs()
4142

4243

4344
def _extract_simple_rename(expr: Expr) -> Optional[Tuple[str, str]]:
@@ -118,9 +119,11 @@ def _validate_fusion(
118119
if isinstance(expr, StarExpr):
119120
continue
120121

121-
referenced_columns = _collect_referenced_columns([expr]) or set()
122-
columns_from_original = referenced_columns - (
123-
referenced_columns & upstream_output_columns
122+
column_refs = _collect_referenced_columns([expr])
123+
column_refs_set = set(column_refs or [])
124+
125+
columns_from_original = column_refs_set - (
126+
column_refs_set & upstream_output_columns
124127
)
125128

126129
# Validate accessibility
@@ -280,27 +283,65 @@ def _push_projection_into_read_op(cls, op: LogicalOperator) -> LogicalOperator:
280283
# Step 2: Push projection into the data source if supported
281284
input_op = current_project.input_dependency
282285
if (
283-
not current_project.has_star_expr() # Must be a selection, not additive
284-
and isinstance(input_op, LogicalOperatorSupportsProjectionPushdown)
286+
isinstance(input_op, LogicalOperatorSupportsProjectionPushdown)
285287
and input_op.supports_projection_pushdown()
286288
):
287-
required_columns = _collect_referenced_columns(list(current_project.exprs))
288-
if required_columns is not None: # None means star() was present
289-
optimized_source = input_op.apply_projection(list(required_columns))
290-
291-
is_simple_selection = all(
292-
isinstance(expr, ColumnExpr) for expr in current_project.exprs
289+
if current_project.has_star_expr():
290+
# If project has a star, than no projection is feasible
291+
required_columns = None
292+
else:
293+
# Otherwise, collect required column for projection
294+
required_columns = _collect_referenced_columns(current_project.exprs)
295+
296+
# Check if it's a simple projection that could be pushed into
297+
# read as a whole
298+
is_simple_projection = all(
299+
_is_col_expr(expr)
300+
for expr in current_project.exprs
301+
if not isinstance(expr, StarExpr)
302+
)
303+
304+
if is_simple_projection:
305+
# NOTE: We only can rename output columns when it's a simple
306+
# projection and Project operator is discarded (otherwise
307+
# it might be holding expression referencing attributes
308+
# by original their names prior to renaming)
309+
#
310+
# TODO fix by instead rewriting exprs
311+
output_column_rename_map = _collect_output_column_rename_map(
312+
current_project.exprs
293313
)
294314

295-
if is_simple_selection:
296-
# Simple column selection: Read handles everything
297-
return optimized_source
298-
else:
299-
# Has transformations: Keep Project on top of optimized Read
300-
return Project(
301-
optimized_source,
302-
exprs=current_project.exprs,
303-
ray_remote_args=current_project._ray_remote_args,
304-
)
315+
# Apply projection of columns to the read op
316+
return input_op.apply_projection(
317+
required_columns, output_column_rename_map
318+
)
319+
else:
320+
# Otherwise just apply projection without renaming
321+
projected_input_op = input_op.apply_projection(required_columns, None)
322+
323+
# Has transformations: Keep Project on top of optimized Read
324+
return Project(
325+
projected_input_op,
326+
exprs=current_project.exprs,
327+
ray_remote_args=current_project._ray_remote_args,
328+
)
305329

306330
return current_project
331+
332+
333+
def _is_col_expr(expr: Expr) -> bool:
334+
return isinstance(expr, ColumnExpr) or (
335+
isinstance(expr, AliasExpr) and isinstance(expr.expr, ColumnExpr)
336+
)
337+
338+
339+
def _collect_output_column_rename_map(exprs: List[Expr]) -> Dict[str, str]:
340+
# First, extract all potential rename pairs
341+
rename_map = {
342+
expr.expr.name: expr.name
343+
for expr in exprs
344+
if isinstance(expr, AliasExpr) and isinstance(expr.expr, ColumnExpr)
345+
}
346+
347+
return rename_map

0 commit comments

Comments
 (0)