
[Data] Try get size_bytes from metadata and consolidate metadata methods #46862

Merged
merged 13 commits on Jul 31, 2024
16 changes: 0 additions & 16 deletions python/ray/data/_internal/datasource/parquet_datasource.py
@@ -454,22 +454,6 @@ def get_name(self):
def supports_distributed_reads(self) -> bool:
return self._supports_distributed_reads

def num_rows(self) -> Optional[int]:
# If there is a filter operation, the total row count is unknown.
if self._to_batches_kwargs.get("filter") is not None:
return None

if not self._metadata:
return None

return sum(metadata.num_rows for metadata in self._metadata)

def schema(self) -> "pyarrow.Schema":
return self._inferred_schema

def input_files(self) -> Optional[List[str]]:
return self._pq_paths


def _read_fragments(
block_udf,
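For context on the removed num_rows: Parquet footers record exact per-file row counts, so an unfiltered count is available without reading any row data, but a filter's selectivity is unknown until the data is actually scanned, hence the None. A small pyarrow sketch (the file path is hypothetical):

import pyarrow.parquet as pq

# The row count comes from the file footer; no row data is read.
file_metadata = pq.ParquetFile("example.parquet").metadata
print(file_metadata.num_rows)        # exact, essentially free
print(file_metadata.num_row_groups)

# With a filter (e.g., passed via to_batches(filter=...)), the number of
# matching rows cannot be known from the footer alone, so the count must be None.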
7 changes: 2 additions & 5 deletions python/ray/data/_internal/datasource/range_datasource.py
@@ -96,7 +96,7 @@ def make_blocks(
meta = BlockMetadata(
num_rows=count,
size_bytes=8 * count * element_size,
schema=copy(self.schema()),
schema=copy(self._schema()),
input_files=None,
exec_stats=None,
)
@@ -113,7 +113,7 @@ def make_blocks(
return read_tasks

@functools.cache
def schema(self):
def _schema(self):
if self._n == 0:
return None

@@ -137,6 +137,3 @@ def schema(self):
else:
raise ValueError("Unsupported block type", self._block_format)
return schema

def num_rows(self):
return self._n
@@ -1,10 +1,9 @@
from typing import TYPE_CHECKING, Iterator, List, Optional, Union
from typing import TYPE_CHECKING, Iterator, List, Optional

from .operator import Operator
from ray.data.block import BlockMetadata

if TYPE_CHECKING:
import pyarrow

from ray.data._internal.execution.interfaces import RefBundle


@@ -57,29 +56,17 @@ def output_dependencies(self) -> List["LogicalOperator"]:
def post_order_iter(self) -> Iterator["LogicalOperator"]:
return super().post_order_iter() # type: ignore

    def schema(self) -> Optional[Union[type, "pyarrow.lib.Schema"]]:
        """The schema of operator outputs, or ``None`` if not known.

        This method is used to get the dataset schema without performing actual
        computation.
        """
        return None

    def num_rows(self) -> Optional[int]:
        """The number of rows outputted by this operator, or ``None`` if not known.

        This method is used to count the number of rows in a dataset without
        performing actual computation.
        """
        return None

    def input_files(self) -> Optional[List[str]]:
        """The input files of this operator, or ``None`` if not known."""
        return None

    def output_data(self) -> Optional[List["RefBundle"]]:
        """The output data of this operator, or ``None`` if not known."""
        return None

    def aggregate_output_metadata(self) -> BlockMetadata:
        """A ``BlockMetadata`` that represents the aggregate metadata of the outputs.

        This method is used by methods like :meth:`~ray.data.Dataset.schema` to
        efficiently return metadata.
        """
        return BlockMetadata(None, None, None, None, None)

def is_lineage_serializable(self) -> bool:
"""Returns whether the lineage of this operator can be serialized.
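To make the new interface concrete: callers that previously queried op.schema(), op.num_rows(), and op.input_files() separately now read a single aggregate object and branch on None fields. A minimal sketch, using only the BlockMetadata constructor shown in this diff (the operator call is commented out because it assumes an existing LogicalOperator instance):

from ray.data.block import BlockMetadata

# The base-class default: nothing is known about the outputs.
meta = BlockMetadata(
    num_rows=None, size_bytes=None, schema=None, input_files=None, exec_stats=None
)

# meta = op.aggregate_output_metadata()  # on a concrete operator
if meta.num_rows is not None:
    print(f"Row count known from metadata alone: {meta.num_rows}")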
@@ -5,6 +5,7 @@
from ray.data._internal.planner.exchange.shuffle_task_spec import ShuffleTaskSpec
from ray.data._internal.planner.exchange.sort_task_spec import SortKey, SortTaskSpec
from ray.data.aggregate import AggregateFn
from ray.data.block import BlockMetadata


class AbstractAllToAll(LogicalOperator):
@@ -50,13 +51,9 @@ def __init__(
)
self._seed = seed

    def schema(self):
        assert len(self._input_dependencies) == 1, len(self._input_dependencies)
        return self._input_dependencies[0].schema()

    def num_rows(self):
        assert len(self._input_dependencies) == 1, len(self._input_dependencies)
        return self._input_dependencies[0].num_rows()

    def aggregate_output_metadata(self) -> BlockMetadata:
        assert len(self._input_dependencies) == 1, len(self._input_dependencies)
        return self._input_dependencies[0].aggregate_output_metadata()


class RandomShuffle(AbstractAllToAll):
@@ -80,6 +77,10 @@ def __init__(
)
self._seed = seed

def aggregate_output_metadata(self) -> BlockMetadata:
assert len(self._input_dependencies) == 1, len(self._input_dependencies)
return self._input_dependencies[0].aggregate_output_metadata()


class Repartition(AbstractAllToAll):
"""Logical operator for repartition."""
@@ -107,6 +108,10 @@ def __init__(
)
self._shuffle = shuffle

def aggregate_output_metadata(self) -> BlockMetadata:
assert len(self._input_dependencies) == 1, len(self._input_dependencies)
return self._input_dependencies[0].aggregate_output_metadata()


class Sort(AbstractAllToAll):
"""Logical operator for sort."""
@@ -127,6 +132,10 @@ def __init__(
)
self._sort_key = sort_key

def aggregate_output_metadata(self) -> BlockMetadata:
assert len(self._input_dependencies) == 1, len(self._input_dependencies)
return self._input_dependencies[0].aggregate_output_metadata()


class Aggregate(AbstractAllToAll):
"""Logical operator for aggregate."""
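Note the pattern here: RandomShuffle, Repartition, and Sort preserve row count and schema, so each can delegate aggregate_output_metadata to its single input, while no such override is visible for Aggregate in this diff, which would leave it on the base class's "all unknown" default. A small illustration of the observable effect (assumes a running Ray installation; the metadata shortcut is sketched, not guaranteed by this diff):

import ray

ds = ray.data.range(100)

# A shuffle preserves the number of rows, so the count can flow through
# operator metadata rather than requiring a fresh computation.
assert ds.random_shuffle().count() == 100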
30 changes: 24 additions & 6 deletions python/ray/data/_internal/logical/operators/from_operators.py
@@ -1,4 +1,5 @@
import abc
import functools
from typing import TYPE_CHECKING, List, Optional, Union

from ray.data._internal.execution.interfaces import RefBundle
@@ -36,18 +37,35 @@ def __init__(
def input_data(self) -> List[RefBundle]:
return self._input_data

    def schema(self):
        metadata = [m for bundle in self._input_data for m in bundle.metadata]
        return unify_block_metadata_schema(metadata)

    def num_rows(self):
        if all(bundle.num_rows() is not None for bundle in self._input_data):
            return sum(bundle.num_rows() for bundle in self._input_data)
        else:
            return None

    def output_data(self) -> Optional[List[RefBundle]]:
        return self._input_data

    @functools.cache
    def aggregate_output_metadata(self) -> BlockMetadata:
        return BlockMetadata(
            num_rows=self._num_rows(),
            size_bytes=self._size_bytes(),
            schema=self._schema(),
            input_files=None,
            exec_stats=None,
        )

    def _num_rows(self):
        if all(bundle.num_rows() is not None for bundle in self._input_data):
            return sum(bundle.num_rows() for bundle in self._input_data)
        else:
            return None

    def _size_bytes(self):
        metadata = [m for bundle in self._input_data for m in bundle.metadata]
        if all(m.size_bytes is not None for m in metadata):
            return sum(m.size_bytes for m in metadata)
        else:
            return None

    def _schema(self):
        metadata = [m for bundle in self._input_data for m in bundle.metadata]
        return unify_block_metadata_schema(metadata)

def is_lineage_serializable(self) -> bool:
# This operator isn't serializable because it contains ObjectRefs.
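An aside on the @functools.cache decorator used above: applied to a zero-argument method, it memoizes per instance (the cache is keyed on self), so the aggregate metadata is computed at most once per operator. A self-contained illustration, not Ray code:

import functools

class Example:
    def __init__(self) -> None:
        self.computations = 0

    @functools.cache
    def aggregate(self) -> int:
        # The body runs once per instance; later calls hit the cache.
        self.computations += 1
        return 42

e = Example()
assert e.aggregate() == 42
assert e.aggregate() == 42
assert e.computations == 1

One trade-off to note: caching on a method stores a reference to self, which can keep operator instances alive for the lifetime of the cache.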
@@ -1,8 +1,10 @@
import functools
from typing import Callable, List, Optional

from ray.data._internal.execution.interfaces import RefBundle
from ray.data._internal.logical.interfaces import LogicalOperator
from ray.data._internal.util import unify_block_metadata_schema
from ray.data.block import BlockMetadata


class InputData(LogicalOperator):
@@ -26,25 +28,43 @@ def __init__(
self.input_data = input_data
self.input_data_factory = input_data_factory

    def schema(self):
        if self.input_data is None:
            return None

        metadata = [m for bundle in self.input_data for m in bundle.metadata]
        return unify_block_metadata_schema(metadata)

    def num_rows(self):
        if self.input_data is None:
            return None
        elif all(bundle.num_rows() is not None for bundle in self.input_data):
            return sum(bundle.num_rows() for bundle in self.input_data)
        else:
            return None

    def output_data(self) -> Optional[List[RefBundle]]:
        if self.input_data is None:
            return None
        return self.input_data

    @functools.cache
    def aggregate_output_metadata(self) -> BlockMetadata:
        if self.input_data is None:
            return BlockMetadata(None, None, None, None, None)

        return BlockMetadata(
            num_rows=self._num_rows(),
            size_bytes=self._size_bytes(),
            schema=self._schema(),
            input_files=None,
            exec_stats=None,
        )

    def _num_rows(self):
        assert self.input_data is not None
        if all(bundle.num_rows() is not None for bundle in self.input_data):
            return sum(bundle.num_rows() for bundle in self.input_data)
        else:
            return None
Comment on lines +51 to 52

Contributor:

nit: Not a big deal, especially since it elongates the code, but technically this is a double loop. It could be optimized into a single loop that exits early if we run into a None num_rows, something like:

sum_num_rows = 0
for bundle in self.input_data:
    if bundle.num_rows() is None:
        return None
    sum_num_rows += bundle.num_rows()
return sum_num_rows

The same could be applied in other similar places, e.g. the _size_bytes method and the methods in AbstractFrom. It may be worthwhile to put this logic in a shared utility function, for example as a static method on BlockMetadata.

Member (Author):

Makes sense, although I prefer keeping this as-is: we don't know that this is a performance issue, and I want to avoid premature optimization.
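For illustration, a minimal sketch of the shared helper the reviewer is proposing (the name, signature, and placement are hypothetical, not part of this PR):

from typing import Iterable, Optional

def sum_or_none(values: Iterable[Optional[int]]) -> Optional[int]:
    """Sum values in one pass, returning None if any value is unknown."""
    total = 0
    for value in values:
        if value is None:
            # A single unknown makes the aggregate unknown; exit early.
            return None
        total += value
    return total

With it, _num_rows could reduce to: return sum_or_none(bundle.num_rows() for bundle in self.input_data).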

    def _size_bytes(self):
        assert self.input_data is not None
        metadata = [m for bundle in self.input_data for m in bundle.metadata]
        if all(m.size_bytes is not None for m in metadata):
            return sum(m.size_bytes for m in metadata)
        else:
            return None

def _schema(self):
assert self.input_data is not None
metadata = [m for bundle in self.input_data for m in bundle.metadata]
return unify_block_metadata_schema(metadata)

def is_lineage_serializable(self) -> bool:
# This operator isn't serializable because it contains ObjectRefs.
@@ -2,6 +2,7 @@
from typing import Optional

from ray.data._internal.logical.interfaces import LogicalOperator
from ray.data.block import BlockMetadata


class AbstractOneToOne(LogicalOperator):
@@ -53,14 +54,27 @@ def __init__(
def can_modify_num_rows(self) -> bool:
return True

    def schema(self):
        assert len(self._input_dependencies) == 1, len(self._input_dependencies)
        return self._input_dependencies[0].schema()

    def num_rows(self):
        assert len(self._input_dependencies) == 1, len(self._input_dependencies)
        input_rows = self._input_dependencies[0].num_rows()
        if input_rows is not None:
            return min(input_rows, self._limit)
        else:
            return None

    def aggregate_output_metadata(self) -> BlockMetadata:
        return BlockMetadata(
            num_rows=self._num_rows(),
            size_bytes=None,
            schema=self._schema(),
            input_files=self._input_files(),
            exec_stats=None,
        )

    def _schema(self):
        assert len(self._input_dependencies) == 1, len(self._input_dependencies)
        return self._input_dependencies[0].aggregate_output_metadata().schema

    def _num_rows(self):
        assert len(self._input_dependencies) == 1, len(self._input_dependencies)
        input_rows = self._input_dependencies[0].aggregate_output_metadata().num_rows
        if input_rows is not None:
            return min(input_rows, self._limit)
        else:
            return None

def _input_files(self):
assert len(self._input_dependencies) == 1, len(self._input_dependencies)
return self._input_dependencies[0].aggregate_output_metadata().input_files
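A quick illustration of the Limit arithmetic above (plain Python, no Ray required):

input_rows = 1_000  # known from the upstream operator's aggregate metadata
limit = 100
assert min(input_rows, limit) == 100  # rows that survive the limit

# If the upstream count is unknown, the limited count is unknown too:
input_rows = None
num_rows = min(input_rows, limit) if input_rows is not None else None
assert num_rows is None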
53 changes: 46 additions & 7 deletions python/ray/data/_internal/logical/operators/read_operator.py
@@ -1,6 +1,9 @@
from typing import Any, Dict, List, Optional, Union
import functools
from typing import Any, Dict, Optional, Union

from ray.data._internal.logical.operators.map_operator import AbstractMap
from ray.data._internal.util import unify_block_metadata_schema
from ray.data.block import BlockMetadata
from ray.data.datasource.datasource import Datasource, Reader


@@ -43,11 +46,47 @@ def get_detected_parallelism(self) -> int:
"""
return self._detected_parallelism

    def schema(self):
        return self._datasource.schema()

    def num_rows(self):
        return self._datasource.num_rows()

    def input_files(self) -> Optional[List[str]]:
        return self._datasource.input_files()

    @functools.cache
    def aggregate_output_metadata(self) -> BlockMetadata:
        """A ``BlockMetadata`` that represents the aggregate metadata of the outputs.

        This method gets metadata from the read tasks. It doesn't trigger any actual
        execution.
        """
        # Legacy datasources might not implement `get_read_tasks`.
        if self._datasource.should_create_reader:
            return BlockMetadata(None, None, None, None, None)

        # HACK: Try to get a single read task to get the metadata.
        read_tasks = self._datasource.get_read_tasks(1)
        if len(read_tasks) == 0:
            # If there are no read tasks, the dataset is probably empty.
            return BlockMetadata(None, None, None, None, None)

        # `get_read_tasks` isn't guaranteed to return exactly one read task.
        metadata = [read_task.get_metadata() for read_task in read_tasks]

        if all(meta.num_rows is not None for meta in metadata):
            num_rows = sum(meta.num_rows for meta in metadata)
        else:
            num_rows = None

        if all(meta.size_bytes is not None for meta in metadata):
            size_bytes = sum(meta.size_bytes for meta in metadata)
        else:
            size_bytes = None

        schema = unify_block_metadata_schema(metadata)

        input_files = []
        for meta in metadata:
            if meta.input_files is not None:
                input_files.extend(meta.input_files)

        return BlockMetadata(
            num_rows=num_rows,
            size_bytes=size_bytes,
            schema=schema,
            input_files=input_files,
            exec_stats=None,
        )
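The practical effect of pulling metadata from read tasks, sketched below: metadata-backed operations can often be answered without executing the read. This is illustrative only; example://iris.parquet is a sample path bundled with Ray, and whether metadata suffices depends on the datasource:

import ray

ds = ray.data.read_parquet("example://iris.parquet")

# Where read-task metadata is available, these can resolve from metadata
# (schema, row counts, size) instead of launching the full read.
print(ds.schema())
print(ds.count())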