-
Notifications
You must be signed in to change notification settings - Fork 7k
[Data] Fixing handling of renames in projection pushdown #58033
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,49 @@ | ||
| from typing import Dict, TypeVar | ||
|
|
||
| K = TypeVar("K") | ||
|
|
||
|
|
||
| def collapse_transitive_map(d: Dict[K, K]) -> Dict[K, K]: | ||
| """Collapse transitive mappings in a dictionary. Given a mapping like | ||
| {a: b, b: c, c: d}, returns {a: d}, removing intermediate b -> c, c -> d. | ||
|
|
||
| Only keeps mappings where the key is NOT a value in another mapping (i.e., chain starting points). | ||
|
|
||
| Args: | ||
| d: Dictionary representing a mapping | ||
|
|
||
| Returns: | ||
| Dictionary with all transitive mappings collapsed, keeping only KV-pairs, | ||
| such that K and V are starting and terminal point of a chain | ||
|
|
||
| Examples: | ||
| >>> collapse_transitive_map({"a": "b", "b": "c", "c": "d"}) | ||
| {'a': 'd'} | ||
| >>> collapse_transitive_map({"a": "b", "x": "y"}) | ||
| {'a': 'b', 'x': 'y'} | ||
| """ | ||
| if not d: | ||
| return {} | ||
|
|
||
| collapsed = {} | ||
| values_set = set(d.values()) | ||
| for k in d: | ||
| # Skip mappings that are in the value-set, meaning that they are | ||
| # part of the mapping chain (for ex, {a -> b, b -> c}) | ||
| if k in values_set: | ||
| continue | ||
|
|
||
| cur = k | ||
| visited = {cur} | ||
|
|
||
| # Follow the chain until we reach a key that's not in the mapping | ||
| while cur in d: | ||
| next = d[cur] | ||
| if next in visited: | ||
| raise ValueError(f"Detected a cycle in the mapping {d}") | ||
| visited.add(next) | ||
| cur = next | ||
|
|
||
| collapsed[k] = cur | ||
|
|
||
| return collapsed | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -27,6 +27,7 @@ | |
| _BATCH_SIZE_PRESERVING_STUB_COL_NAME, | ||
| ArrowBlockAccessor, | ||
| ) | ||
| from ray.data._internal.collections import collapse_transitive_map | ||
| from ray.data._internal.progress_bar import ProgressBar | ||
| from ray.data._internal.remote_fn import cached_remote_fn | ||
| from ray.data._internal.util import ( | ||
|
|
@@ -275,6 +276,7 @@ def __init__( | |
| self._block_udf = _block_udf | ||
| self._to_batches_kwargs = to_batch_kwargs | ||
| self._data_columns = data_columns | ||
| self._data_columns_rename_map = {} | ||
| self._partition_columns = partition_columns | ||
| self._read_schema = schema | ||
| self._file_schema = pq_ds.schema | ||
|
|
@@ -379,6 +381,7 @@ def get_read_tasks( | |
| to_batches_kwargs, | ||
| default_read_batch_size_rows, | ||
| data_columns, | ||
| data_columns_rename_map, | ||
| partition_columns, | ||
| read_schema, | ||
| include_paths, | ||
|
|
@@ -388,6 +391,7 @@ def get_read_tasks( | |
| self._to_batches_kwargs, | ||
| self._default_batch_size, | ||
| self._data_columns, | ||
| self._data_columns_rename_map, | ||
| self._partition_columns, | ||
| self._read_schema, | ||
| self._include_paths, | ||
|
|
@@ -401,6 +405,7 @@ def get_read_tasks( | |
| to_batches_kwargs, | ||
| default_read_batch_size_rows, | ||
| data_columns, | ||
| data_columns_rename_map, | ||
| partition_columns, | ||
| read_schema, | ||
| f, | ||
|
|
@@ -437,9 +442,18 @@ def get_current_projection(self) -> Optional[List[str]]: | |
|
|
||
| return (self._data_columns or []) + (self._partition_columns or []) | ||
|
|
||
| def apply_projection(self, columns: List[str]) -> "ParquetDatasource": | ||
| def apply_projection( | ||
| self, | ||
| columns: Optional[List[str]], | ||
| column_rename_map: Optional[Dict[str, str]], | ||
| ) -> "ParquetDatasource": | ||
| clone = copy.copy(self) | ||
| clone._data_columns = columns | ||
|
|
||
| clone._data_columns = _combine_projection(self._data_columns, columns) | ||
| clone._data_columns_rename_map = _combine_rename_map( | ||
| self._data_columns_rename_map, column_rename_map | ||
| ) | ||
|
|
||
| return clone | ||
|
|
||
| def _estimate_in_mem_size(self, fragments: List[_ParquetFragment]) -> int: | ||
|
|
@@ -453,6 +467,7 @@ def read_fragments( | |
| to_batches_kwargs: Dict[str, Any], | ||
| default_read_batch_size_rows: Optional[int], | ||
| data_columns: Optional[List[str]], | ||
| data_columns_rename_map: Optional[Dict[str, str]], | ||
| partition_columns: Optional[List[str]], | ||
| schema: Optional[Union[type, "pyarrow.lib.Schema"]], | ||
| fragments: List[_ParquetFragment], | ||
|
|
@@ -475,6 +490,7 @@ def read_fragments( | |
| fragment.original, | ||
| schema=schema, | ||
| data_columns=data_columns, | ||
| data_columns_rename_map=data_columns_rename_map, | ||
| partition_columns=partition_columns, | ||
| partitioning=partitioning, | ||
| include_path=include_paths, | ||
|
|
@@ -497,6 +513,7 @@ def _read_batches_from( | |
| *, | ||
| schema: "pyarrow.Schema", | ||
| data_columns: Optional[List[str]], | ||
| data_columns_rename_map: Optional[Dict[str, str]], | ||
| partition_columns: Optional[List[str]], | ||
| partitioning: Partitioning, | ||
| filter_expr: Optional["pyarrow.dataset.Expression"] = None, | ||
|
|
@@ -562,6 +579,14 @@ def _read_batches_from( | |
| _BATCH_SIZE_PRESERVING_STUB_COL_NAME, pa.nulls(table.num_rows) | ||
| ) | ||
|
|
||
| if data_columns_rename_map is not None: | ||
| table = table.rename_columns( | ||
| [ | ||
| data_columns_rename_map.get(col, col) | ||
| for col in table.schema.names | ||
| ] | ||
| ) | ||
|
|
||
| yield table | ||
|
|
||
| except pa.lib.ArrowInvalid as e: | ||
|
|
@@ -825,6 +850,43 @@ def _add_partitions_to_table( | |
| return table | ||
|
|
||
|
|
||
| def _combine_projection( | ||
| prev_projected_cols: Optional[List[str]], new_projected_cols: Optional[List[str]] | ||
| ) -> Optional[List[str]]: | ||
| # NOTE: Null projection carries special meaning of all columns being selected | ||
| if prev_projected_cols is None: | ||
| return new_projected_cols | ||
| elif new_projected_cols is None: | ||
| # Retain original projection | ||
| return prev_projected_cols | ||
| else: | ||
| illegal_refs = [ | ||
| col for col in new_projected_cols if col not in prev_projected_cols | ||
| ] | ||
|
|
||
| if illegal_refs: | ||
| raise ValueError( | ||
| f"New projection {new_projected_cols} references non-existent columns " | ||
| f"(existing projection {prev_projected_cols})" | ||
| ) | ||
|
|
||
| return new_projected_cols | ||
|
|
||
|
|
||
| def _combine_rename_map( | ||
| prev_column_rename_map: Optional[Dict[str, str]], | ||
| new_column_rename_map: Optional[Dict[str, str]], | ||
| ): | ||
| if not prev_column_rename_map: | ||
| combined = new_column_rename_map | ||
| elif not new_column_rename_map: | ||
| combined = prev_column_rename_map | ||
| else: | ||
| combined = prev_column_rename_map | new_column_rename_map | ||
|
|
||
| return collapse_transitive_map(combined) | ||
|
Comment on lines
+876
to
+887
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This function is missing a return type hint. Adding it would improve type safety and code clarity. Additionally, the logic for combining the dictionaries can be made more concise and robust against def _combine_rename_map(
prev_column_rename_map: Optional[Dict[str, str]],
new_column_rename_map: Optional[Dict[str, str]],
) -> Dict[str, str]:
combined = (prev_column_rename_map or {}) | (new_column_rename_map or {})
return collapse_transitive_map(combined) |
||
|
|
||
|
|
||
| def _get_partition_columns_schema( | ||
| partitioning: Partitioning, | ||
| file_paths: List[str], | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,4 @@ | ||
| from typing import List, Optional, Set, Tuple | ||
| from typing import Dict, List, Optional, Set, Tuple | ||
|
|
||
| from ray.data._internal.logical.interfaces import ( | ||
| LogicalOperator, | ||
|
|
@@ -19,7 +19,7 @@ | |
| ) | ||
|
|
||
|
|
||
| def _collect_referenced_columns(exprs: List[Expr]) -> Optional[Set[str]]: | ||
| def _collect_referenced_columns(exprs: List[Expr]) -> Optional[List[str]]: | ||
| """ | ||
| Extract all column names referenced by the given expressions. | ||
|
|
||
|
|
@@ -37,7 +37,8 @@ def _collect_referenced_columns(exprs: List[Expr]) -> Optional[Set[str]]: | |
| collector = _ColumnReferenceCollector() | ||
| for expr in exprs or []: | ||
| collector.visit(expr) | ||
| return collector.referenced_columns | ||
|
|
||
| return collector.get_column_refs() | ||
|
|
||
|
|
||
| def _extract_simple_rename(expr: Expr) -> Optional[Tuple[str, str]]: | ||
|
|
@@ -118,9 +119,11 @@ def _validate_fusion( | |
| if isinstance(expr, StarExpr): | ||
| continue | ||
|
|
||
| referenced_columns = _collect_referenced_columns([expr]) or set() | ||
| columns_from_original = referenced_columns - ( | ||
| referenced_columns & upstream_output_columns | ||
| column_refs = _collect_referenced_columns([expr]) | ||
| column_refs_set = set(column_refs or []) | ||
|
|
||
| columns_from_original = column_refs_set - ( | ||
| column_refs_set & upstream_output_columns | ||
| ) | ||
|
|
||
| # Validate accessibility | ||
|
|
@@ -280,27 +283,65 @@ def _push_projection_into_read_op(cls, op: LogicalOperator) -> LogicalOperator: | |
| # Step 2: Push projection into the data source if supported | ||
| input_op = current_project.input_dependency | ||
| if ( | ||
| not current_project.has_star_expr() # Must be a selection, not additive | ||
| and isinstance(input_op, LogicalOperatorSupportsProjectionPushdown) | ||
| isinstance(input_op, LogicalOperatorSupportsProjectionPushdown) | ||
| and input_op.supports_projection_pushdown() | ||
| ): | ||
| required_columns = _collect_referenced_columns(list(current_project.exprs)) | ||
| if required_columns is not None: # None means star() was present | ||
| optimized_source = input_op.apply_projection(list(required_columns)) | ||
|
|
||
| is_simple_selection = all( | ||
| isinstance(expr, ColumnExpr) for expr in current_project.exprs | ||
| if current_project.has_star_expr(): | ||
| # If project has a star, than no projection is feasible | ||
| required_columns = None | ||
| else: | ||
| # Otherwise, collect required column for projection | ||
| required_columns = _collect_referenced_columns(current_project.exprs) | ||
|
|
||
| # Check if it's a simple projection that could be pushed into | ||
| # read as a whole | ||
| is_simple_projection = all( | ||
| _is_col_expr(expr) | ||
| for expr in current_project.exprs | ||
| if not isinstance(expr, StarExpr) | ||
| ) | ||
|
|
||
| if is_simple_projection: | ||
| # NOTE: We only can rename output columns when it's a simple | ||
| # projection and Project operator is discarded (otherwise | ||
| # it might be holding expression referencing attributes | ||
| # by original their names prior to renaming) | ||
| # | ||
| # TODO fix by instead rewriting exprs | ||
| output_column_rename_map = _collect_output_column_rename_map( | ||
| current_project.exprs | ||
| ) | ||
|
|
||
| if is_simple_selection: | ||
| # Simple column selection: Read handles everything | ||
| return optimized_source | ||
| else: | ||
| # Has transformations: Keep Project on top of optimized Read | ||
| return Project( | ||
| optimized_source, | ||
| exprs=current_project.exprs, | ||
| ray_remote_args=current_project._ray_remote_args, | ||
| ) | ||
| # Apply projection of columns to the read op | ||
| return input_op.apply_projection( | ||
| required_columns, output_column_rename_map | ||
| ) | ||
| else: | ||
| # Otherwise just apply projection without renaming | ||
| projected_input_op = input_op.apply_projection(required_columns, None) | ||
|
|
||
| # Has transformations: Keep Project on top of optimized Read | ||
| return Project( | ||
| projected_input_op, | ||
| exprs=current_project.exprs, | ||
| ray_remote_args=current_project._ray_remote_args, | ||
| ) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Bug: Projection Pushdown Fails with
|
||
|
|
||
| return current_project | ||
|
|
||
|
|
||
| def _is_col_expr(expr: Expr) -> bool: | ||
| return isinstance(expr, ColumnExpr) or ( | ||
| isinstance(expr, AliasExpr) and isinstance(expr.expr, ColumnExpr) | ||
| ) | ||
|
|
||
|
|
||
| def _collect_output_column_rename_map(exprs: List[Expr]) -> Dict[str, str]: | ||
| # First, extract all potential rename pairs | ||
| rename_map = { | ||
| expr.expr.name: expr.name | ||
| for expr in exprs | ||
| if isinstance(expr, AliasExpr) and isinstance(expr.expr, ColumnExpr) | ||
| } | ||
|
|
||
| return rename_map | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The variable
nextshadows the built-in functionnext(). It's a good practice to avoid shadowing built-ins to prevent potential confusion and bugs. Consider renaming it to something more specific likenext_keyornext_val.