Skip to content

Commit 4f9a2cc

Browse files
[Data] [2/n] - Add new expr based project
Signed-off-by: Goutam <goutam@anyscale.com>
1 parent 194ddf8 commit 4f9a2cc

File tree

12 files changed

+1315
-365
lines changed

12 files changed

+1315
-365
lines changed

python/ray/data/_internal/arrow_block.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858

5959

6060
_MIN_PYARROW_VERSION_TO_NUMPY_ZERO_COPY_ONLY = parse_version("13.0.0")
61+
_BATCH_SIZE_PRESERVING_STUB_COL_NAME = "__bsp_stub"
6162

6263

6364
# Set the max chunk size in bytes for Arrow to Batches conversion in
@@ -221,7 +222,7 @@ def fill_column(self, name: str, value: Any) -> Block:
221222

222223
array = pyarrow.nulls(len(self._table), type=type)
223224
array = pc.fill_null(array, value)
224-
return self._table.append_column(name, array)
225+
return self.upsert_column(name, array)
225226

226227
@classmethod
227228
def from_bytes(cls, data: bytes) -> "ArrowBlockAccessor":
@@ -366,6 +367,10 @@ def select(self, columns: List[str]) -> "pyarrow.Table":
366367
"Columns must be a list of column name strings when aggregating on "
367368
f"Arrow blocks, but got: {columns}."
368369
)
370+
if len(columns) == 0:
371+
# Applies to count, which performs an empty projection.
372+
# Pyarrow returns a table with 0 columns and num_rows rows.
373+
return self.fill_column(_BATCH_SIZE_PRESERVING_STUB_COL_NAME, None)
369374
return self._table.select(columns)
370375

371376
def rename_columns(self, columns_rename: Dict[str, str]) -> "pyarrow.Table":

python/ray/data/_internal/datasource/parquet_datasource.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,10 @@
2323

2424
import ray
2525
from ray._private.arrow_utils import get_pyarrow_version
26-
from ray.data._internal.arrow_block import ArrowBlockAccessor
26+
from ray.data._internal.arrow_block import (
27+
_BATCH_SIZE_PRESERVING_STUB_COL_NAME,
28+
ArrowBlockAccessor,
29+
)
2730
from ray.data._internal.progress_bar import ProgressBar
2831
from ray.data._internal.remote_fn import cached_remote_fn
2932
from ray.data._internal.util import (
@@ -104,9 +107,6 @@
104107
PARQUET_ENCODING_RATIO_ESTIMATE_NUM_ROWS = 1024
105108

106109

107-
_BATCH_SIZE_PRESERVING_STUB_COL_NAME = "__bsp_stub"
108-
109-
110110
class _ParquetFragment:
111111
"""This wrapper class is created to avoid utilizing `ParquetFileFragment` original
112112
serialization protocol that actually does network RPCs during serialization

python/ray/data/_internal/logical/operators/map_operator.py

Lines changed: 13 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from ray.data._internal.logical.interfaces import LogicalOperator
88
from ray.data._internal.logical.operators.one_to_one_operator import AbstractOneToOne
99
from ray.data.block import UserDefinedFunction
10-
from ray.data.expressions import Expr
10+
from ray.data.expressions import Expr, StarExpr
1111
from ray.data.preprocessor import Preprocessor
1212

1313
logger = logging.getLogger(__name__)
@@ -268,16 +268,12 @@ def can_modify_num_rows(self) -> bool:
268268

269269

270270
class Project(AbstractMap):
271-
"""Logical operator for select_columns."""
271+
"""Logical operator for all Projection Operations."""
272272

273273
def __init__(
274274
self,
275275
input_op: LogicalOperator,
276-
cols: Optional[List[str]] = None,
277-
cols_rename: Optional[Dict[str, str]] = None,
278-
exprs: Optional[
279-
Dict[str, "Expr"]
280-
] = None, # TODO Remove cols and cols_rename and replace them with corresponding exprs
276+
exprs: list["Expr"],
281277
compute: Optional[ComputeStrategy] = None,
282278
ray_remote_args: Optional[Dict[str, Any]] = None,
283279
):
@@ -288,30 +284,23 @@ def __init__(
288284
compute=compute,
289285
)
290286
self._batch_size = None
291-
self._cols = cols
292-
self._cols_rename = cols_rename
293287
self._exprs = exprs
294288
self._batch_format = "pyarrow"
295289
self._zero_copy_batch = True
296290

297-
if exprs is not None:
298-
# Validate that all values are expressions
299-
for name, expr in exprs.items():
300-
if not isinstance(expr, Expr):
301-
raise TypeError(
302-
f"Expected Expr for column '{name}', got {type(expr)}"
303-
)
291+
for expr in self._exprs:
292+
if expr.name is None and not isinstance(expr, StarExpr):
293+
raise TypeError(
294+
"All Project expressions must be named (use .alias(name) or col(name)), "
295+
"or be a star() expression."
296+
)
304297

305-
@property
306-
def cols(self) -> Optional[List[str]]:
307-
return self._cols
308-
309-
@property
310-
def cols_rename(self) -> Optional[Dict[str, str]]:
311-
return self._cols_rename
298+
def has_star_expr(self) -> bool:
299+
"""Check if this projection contains a star() expression."""
300+
return any(isinstance(expr, StarExpr) for expr in self._exprs)
312301

313302
@property
314-
def exprs(self) -> Optional[Dict[str, "Expr"]]:
303+
def exprs(self) -> List["Expr"]:
315304
return self._exprs
316305

317306
def can_modify_num_rows(self) -> bool:

0 commit comments

Comments
 (0)