7 changes: 6 additions & 1 deletion python/ray/data/_internal/arrow_block.py
@@ -58,6 +58,7 @@


_MIN_PYARROW_VERSION_TO_NUMPY_ZERO_COPY_ONLY = parse_version("13.0.0")
_BATCH_SIZE_PRESERVING_STUB_COL_NAME = "__bsp_stub"


# Set the max chunk size in bytes for Arrow to Batches conversion in
@@ -221,7 +222,7 @@ def fill_column(self, name: str, value: Any) -> Block:

array = pyarrow.nulls(len(self._table), type=type)
array = pc.fill_null(array, value)
return self._table.append_column(name, array)
return self.upsert_column(name, array)

@classmethod
def from_bytes(cls, data: bytes) -> "ArrowBlockAccessor":
@@ -366,6 +367,10 @@ def select(self, columns: List[str]) -> "pyarrow.Table":
"Columns must be a list of column name strings when aggregating on "
f"Arrow blocks, but got: {columns}."
)
if len(columns) == 0:
# Applicable for count which does an empty projection.
# Pyarrow returns a table with 0 columns and num_rows rows.
return self.fill_column(_BATCH_SIZE_PRESERVING_STUB_COL_NAME, None)
return self._table.select(columns)

def rename_columns(self, columns_rename: Dict[str, str]) -> "pyarrow.Table":
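For context, a minimal sketch (not part of the diff) of why the stub column exists: PyArrow keeps num_rows on a zero-column Table, but a projection with no columns can lose its row count once blocks are turned into batches, so the accessor fills a single all-null stub column to carry it. Only pyarrow's own Table/Array APIs are used below; the sample data is made up.

import pyarrow as pa

_BATCH_SIZE_PRESERVING_STUB_COL_NAME = "__bsp_stub"

table = pa.table({"a": [1, 2, 3], "b": [4, 5, 6]})

# An empty projection (as issued by count()) keeps the row count on the Table
# itself but carries no columns.
empty = table.select([])
assert empty.num_columns == 0 and empty.num_rows == 3

# Appending an all-null stub column preserves that row count for downstream
# steps that rebuild batch sizes from column data.
stub = pa.nulls(len(table), type=pa.null())
with_stub = empty.append_column(_BATCH_SIZE_PRESERVING_STUB_COL_NAME, stub)
assert with_stub.num_columns == 1 and with_stub.num_rows == 3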
8 changes: 4 additions & 4 deletions python/ray/data/_internal/datasource/parquet_datasource.py
@@ -23,7 +23,10 @@

import ray
from ray._private.arrow_utils import get_pyarrow_version
from ray.data._internal.arrow_block import ArrowBlockAccessor
from ray.data._internal.arrow_block import (
_BATCH_SIZE_PRESERVING_STUB_COL_NAME,
ArrowBlockAccessor,
)
from ray.data._internal.progress_bar import ProgressBar
from ray.data._internal.remote_fn import cached_remote_fn
from ray.data._internal.util import (
@@ -104,9 +107,6 @@
PARQUET_ENCODING_RATIO_ESTIMATE_NUM_ROWS = 1024


_BATCH_SIZE_PRESERVING_STUB_COL_NAME = "__bsp_stub"


class _ParquetFragment:
"""This wrapper class is created to avoid utilizing `ParquetFileFragment` original
serialization protocol that actually does network RPCs during serialization
37 changes: 13 additions & 24 deletions python/ray/data/_internal/logical/operators/map_operator.py
@@ -7,7 +7,7 @@
from ray.data._internal.logical.interfaces import LogicalOperator
from ray.data._internal.logical.operators.one_to_one_operator import AbstractOneToOne
from ray.data.block import UserDefinedFunction
from ray.data.expressions import Expr
from ray.data.expressions import Expr, StarExpr
from ray.data.preprocessor import Preprocessor

logger = logging.getLogger(__name__)
@@ -268,16 +268,12 @@ def can_modify_num_rows(self) -> bool:


class Project(AbstractMap):
"""Logical operator for select_columns."""
"""Logical operator for all Projection Operations."""

def __init__(
self,
input_op: LogicalOperator,
cols: Optional[List[str]] = None,
cols_rename: Optional[Dict[str, str]] = None,
exprs: Optional[
Dict[str, "Expr"]
] = None, # TODO Remove cols and cols_rename and replace them with corresponding exprs
exprs: list["Expr"],
compute: Optional[ComputeStrategy] = None,
ray_remote_args: Optional[Dict[str, Any]] = None,
):
@@ -288,30 +284,23 @@ def __init__(
compute=compute,
)
self._batch_size = None
self._cols = cols
self._cols_rename = cols_rename
self._exprs = exprs
self._batch_format = "pyarrow"
self._zero_copy_batch = True

if exprs is not None:
# Validate that all values are expressions
for name, expr in exprs.items():
if not isinstance(expr, Expr):
raise TypeError(
f"Expected Expr for column '{name}', got {type(expr)}"
)
for expr in self._exprs:
if expr.name is None and not isinstance(expr, StarExpr):
raise TypeError(
"All Project expressions must be named (use .alias(name) or col(name)), "
"or be a star() expression."
)

@property
def cols(self) -> Optional[List[str]]:
return self._cols

@property
def cols_rename(self) -> Optional[Dict[str, str]]:
return self._cols_rename
def has_star_expr(self) -> bool:
"""Check if this projection contains a star() expression."""
return any(isinstance(expr, StarExpr) for expr in self._exprs)

@property
def exprs(self) -> Optional[Dict[str, "Expr"]]:
def exprs(self) -> List["Expr"]:
return self._exprs

def can_modify_num_rows(self) -> bool:
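As a rough usage sketch of the new expression-based Project contract (assuming the ray.data.expressions helpers col, lit, and star() plus the Expr.alias() method referenced in the error message above; the column names are made up), every expression in the list is either named or a star() pass-through:

from ray.data.expressions import StarExpr, col, lit, star

# Hypothetical projection: keep every existing column, rename one, add one derived column.
exprs = [
    star(),                                          # pass through all input columns
    col("price").alias("price_usd"),                 # rename via alias
    (col("price") * lit(1.1)).alias("price_taxed"),  # derived column, must be named
]

# Mirrors the constructor's validation: unnamed, non-star expressions raise TypeError.
for expr in exprs:
    assert isinstance(expr, StarExpr) or expr.name is not None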