[Data] Try get size_bytes from metadata and consolidate metadata methods (#46862)

LogicalOperator and Datasource expose methods like schema() to make metadata efficiently available to Dataset APIs like Dataset.schema(). Currently, LogicalOperator exposes three such methods: num_rows(), schema(), and input_files().

This PR adds size_bytes, which was previously missing. To simplify the interface, it also consolidates the metadata methods into a single aggregate_output_metadata() method.
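
To illustrate the shape of the change, here is a minimal sketch of a hypothetical operator before and after the consolidation. The class names are made up; the five BlockMetadata fields (num_rows, size_bytes, schema, input_files, exec_stats) match the constructor calls in the diff below.

from typing import List, Optional

from ray.data.block import BlockMetadata


# Before: three separate metadata hooks, with no way to report size_bytes.
class ExampleOperatorBefore:
    def num_rows(self) -> Optional[int]: ...
    def schema(self): ...
    def input_files(self) -> Optional[List[str]]: ...


# After: one hook that returns an aggregate BlockMetadata; unknown fields stay None.
class ExampleOperatorAfter:
    def aggregate_output_metadata(self) -> BlockMetadata:
        return BlockMetadata(
            num_rows=1_000,    # known row count, or None
            size_bytes=8_000,  # newly exposed, or None
            schema=None,       # pyarrow schema, when known
            input_files=None,
            exec_stats=None,
        )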

Signed-off-by: Balaji Veeramani <balaji@anyscale.com>
bveeramani authored Jul 31, 2024
1 parent 9b03733 commit a13fc9a
Showing 12 changed files with 155 additions and 120 deletions.
16 changes: 0 additions & 16 deletions python/ray/data/_internal/datasource/parquet_datasource.py
@@ -454,22 +454,6 @@ def get_name(self):
def supports_distributed_reads(self) -> bool:
return self._supports_distributed_reads

def num_rows(self) -> Optional[int]:
# If there is a filter operation, the total row count is unknown.
if self._to_batches_kwargs.get("filter") is not None:
return None

if not self._metadata:
return None

return sum(metadata.num_rows for metadata in self._metadata)

def schema(self) -> "pyarrow.Schema":
return self._inferred_schema

def input_files(self) -> Optional[List[str]]:
return self._pq_paths


def _read_fragments(
block_udf,
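The methods removed above computed row counts from per-fragment Parquet metadata. For background, here is a small sketch (with a placeholder file path) of how the Parquet footer alone provides row counts and uncompressed byte sizes without scanning the data; as in the removed num_rows() above, a row filter makes the exact count unknowable from metadata, so a metadata-based answer should then fall back to None.

import pyarrow.parquet as pq

path = "example.parquet"  # placeholder path for illustration

# The Parquet footer already records row counts and (uncompressed) byte sizes,
# so both can be read without scanning the data itself.
file_meta = pq.ParquetFile(path).metadata
num_rows = file_meta.num_rows
size_bytes = sum(
    file_meta.row_group(i).total_byte_size
    for i in range(file_meta.num_row_groups)
)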
7 changes: 2 additions & 5 deletions python/ray/data/_internal/datasource/range_datasource.py
@@ -96,7 +96,7 @@ def make_blocks(
meta = BlockMetadata(
num_rows=count,
size_bytes=8 * count * element_size,
schema=copy(self.schema()),
schema=copy(self._schema()),
input_files=None,
exec_stats=None,
)
@@ -113,7 +113,7 @@ def make_blocks(
return read_tasks

@functools.cache
def schema(self):
def _schema(self):
if self._n == 0:
return None

@@ -137,6 +137,3 @@ def schema(self):
else:
raise ValueError("Unsupported block type", self._block_format)
return schema

def num_rows(self):
return self._n
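
Here schema() becomes the private helper _schema(), still cached with functools.cache. For reference, a small sketch (with an illustrative class, not the Ray datasource) of how functools.cache behaves on a zero-argument method: the value is computed once per instance, and the cache keeps a reference to that instance.

import functools


class CachedSchemaExample:
    def __init__(self, n: int) -> None:
        self._n = n

    @functools.cache
    def _schema(self):
        print("computing schema")  # runs only on the first call per instance
        return {"id": "int64"} if self._n > 0 else None


example = CachedSchemaExample(3)
example._schema()  # prints "computing schema"
example._schema()  # returns the cached value without recomputing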
31 changes: 9 additions & 22 deletions python/ray/data/_internal/logical/interfaces/logical_operator.py
@@ -1,10 +1,9 @@
from typing import TYPE_CHECKING, Iterator, List, Optional, Union
from typing import TYPE_CHECKING, Iterator, List, Optional

from .operator import Operator
from ray.data.block import BlockMetadata

if TYPE_CHECKING:
import pyarrow

from ray.data._internal.execution.interfaces import RefBundle


@@ -57,29 +56,17 @@ def output_dependencies(self) -> List["LogicalOperator"]:
def post_order_iter(self) -> Iterator["LogicalOperator"]:
return super().post_order_iter() # type: ignore

def schema(self) -> Optional[Union[type, "pyarrow.lib.Schema"]]:
"""The schema of operator outputs, or ``None`` if not known.
This method is used to get the dataset schema without performing actual
computation.
"""
def output_data(self) -> Optional[List["RefBundle"]]:
"""The output data of this operator, or ``None`` if not known."""
return None

def num_rows(self) -> Optional[int]:
"""The number of rows outputted by this operator, or ``None`` if not known.
def aggregate_output_metadata(self) -> BlockMetadata:
"""A ``BlockMetadata`` that represents the aggregate metadata of the outputs.
This method is used to count the number of rows in a dataset without performing
actual computation.
This method is used by methods like :meth:`~ray.data.Dataset.schema` to
efficiently return metadata.
"""
return None

def input_files(self) -> Optional[List[str]]:
"""The input files of this operator, or ``None`` if not known."""
return None

def output_data(self) -> Optional[List["RefBundle"]]:
"""The output data of this operator, or ``None`` if not known."""
return None
return BlockMetadata(None, None, None, None, None)

def is_lineage_serializable(self) -> bool:
"""Returns whether the lineage of this operator can be serialized.
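The base class now answers with a fully-unknown BlockMetadata, so callers need only one code path. A minimal sketch, using a hypothetical helper name rather than the actual Dataset.schema() implementation, of how a metadata-backed query could consult the consolidated method and fall back to execution when the field is unknown:

from ray.data.block import BlockMetadata


def schema_from_metadata(last_operator):
    """Hypothetical helper: answer a schema query from metadata when possible."""
    metadata: BlockMetadata = last_operator.aggregate_output_metadata()
    if metadata.schema is not None:
        return metadata.schema  # known without executing the plan
    # Unknown: a real caller would fall back to executing (part of) the plan.
    return None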
21 changes: 15 additions & 6 deletions python/ray/data/_internal/logical/operators/all_to_all_operator.py
@@ -5,6 +5,7 @@
from ray.data._internal.planner.exchange.shuffle_task_spec import ShuffleTaskSpec
from ray.data._internal.planner.exchange.sort_task_spec import SortKey, SortTaskSpec
from ray.data.aggregate import AggregateFn
from ray.data.block import BlockMetadata


class AbstractAllToAll(LogicalOperator):
@@ -50,13 +51,9 @@ def __init__(
)
self._seed = seed

def schema(self):
def aggregate_output_metadata(self) -> BlockMetadata:
assert len(self._input_dependencies) == 1, len(self._input_dependencies)
return self._input_dependencies[0].schema()

def num_rows(self):
assert len(self._input_dependencies) == 1, len(self._input_dependencies)
return self._input_dependencies[0].num_rows()
return self._input_dependencies[0].aggregate_output_metadata()


class RandomShuffle(AbstractAllToAll):
@@ -80,6 +77,10 @@ def __init__(
)
self._seed = seed

def aggregate_output_metadata(self) -> BlockMetadata:
assert len(self._input_dependencies) == 1, len(self._input_dependencies)
return self._input_dependencies[0].aggregate_output_metadata()


class Repartition(AbstractAllToAll):
"""Logical operator for repartition."""
@@ -107,6 +108,10 @@ def __init__(
)
self._shuffle = shuffle

def aggregate_output_metadata(self) -> BlockMetadata:
assert len(self._input_dependencies) == 1, len(self._input_dependencies)
return self._input_dependencies[0].aggregate_output_metadata()


class Sort(AbstractAllToAll):
"""Logical operator for sort."""
@@ -127,6 +132,10 @@ def __init__(
)
self._sort_key = sort_key

def aggregate_output_metadata(self) -> BlockMetadata:
assert len(self._input_dependencies) == 1, len(self._input_dependencies)
return self._input_dependencies[0].aggregate_output_metadata()


class Aggregate(AbstractAllToAll):
"""Logical operator for aggregate."""
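RandomShuffle, Repartition, and Sort each forward their single input's metadata unchanged, since reordering or repartitioning rows preserves row count, byte size, and schema; Aggregate gets no such override, presumably because aggregation changes the row count. One way to express the shared override, sketched here as a hypothetical mixin (the PR instead repeats the method in each class):

from ray.data.block import BlockMetadata


class _ForwardMetadataMixin:
    """Hypothetical mixin for operators that keep their input's rows intact."""

    def aggregate_output_metadata(self) -> BlockMetadata:
        # Exactly one upstream operator is expected, and its metadata still applies.
        assert len(self._input_dependencies) == 1, len(self._input_dependencies)
        return self._input_dependencies[0].aggregate_output_metadata()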
30 changes: 24 additions & 6 deletions python/ray/data/_internal/logical/operators/from_operators.py
@@ -1,4 +1,5 @@
import abc
import functools
from typing import TYPE_CHECKING, List, Optional, Union

from ray.data._internal.execution.interfaces import RefBundle
@@ -36,18 +37,35 @@ def __init__(
def input_data(self) -> List[RefBundle]:
return self._input_data

def schema(self):
metadata = [m for bundle in self._input_data for m in bundle.metadata]
return unify_block_metadata_schema(metadata)
def output_data(self) -> Optional[List[RefBundle]]:
return self._input_data

def num_rows(self):
@functools.cache
def aggregate_output_metadata(self) -> BlockMetadata:
return BlockMetadata(
num_rows=self._num_rows(),
size_bytes=self._size_bytes(),
schema=self._schema(),
input_files=None,
exec_stats=None,
)

def _num_rows(self):
if all(bundle.num_rows() is not None for bundle in self._input_data):
return sum(bundle.num_rows() for bundle in self._input_data)
else:
return None

def output_data(self) -> Optional[List[RefBundle]]:
return self._input_data
def _size_bytes(self):
metadata = [m for bundle in self._input_data for m in bundle.metadata]
if all(m.size_bytes is not None for m in metadata):
return sum(m.size_bytes for m in metadata)
else:
return None

def _schema(self):
metadata = [m for bundle in self._input_data for m in bundle.metadata]
return unify_block_metadata_schema(metadata)

def is_lineage_serializable(self) -> bool:
# This operator isn't serializable because it contains ObjectRefs.
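_num_rows() and _size_bytes() above follow the same rule: a total is only meaningful when every part is known. A small hypothetical helper that captures the pattern (the PR inlines it instead):

from typing import Iterable, Optional


def sum_or_none(values: Iterable[Optional[int]]) -> Optional[int]:
    """Return the sum, or None if any value is unknown."""
    values = list(values)
    if any(value is None for value in values):
        return None
    return sum(values)


assert sum_or_none([1, 2, 3]) == 6
assert sum_or_none([1, None, 3]) is None  # one unknown makes the total unknown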
40 changes: 30 additions & 10 deletions python/ray/data/_internal/logical/operators/input_data_operator.py
@@ -1,8 +1,10 @@
import functools
from typing import Callable, List, Optional

from ray.data._internal.execution.interfaces import RefBundle
from ray.data._internal.logical.interfaces import LogicalOperator
from ray.data._internal.util import unify_block_metadata_schema
from ray.data.block import BlockMetadata


class InputData(LogicalOperator):
@@ -26,25 +28,43 @@ def __init__(
self.input_data = input_data
self.input_data_factory = input_data_factory

def schema(self):
def output_data(self) -> Optional[List[RefBundle]]:
if self.input_data is None:
return None
return self.input_data

metadata = [m for bundle in self.input_data for m in bundle.metadata]
return unify_block_metadata_schema(metadata)

def num_rows(self):
@functools.cache
def aggregate_output_metadata(self) -> BlockMetadata:
if self.input_data is None:
return None
elif all(bundle.num_rows() is not None for bundle in self.input_data):
return BlockMetadata(None, None, None, None, None)

return BlockMetadata(
num_rows=self._num_rows(),
size_bytes=self._size_bytes(),
schema=self._schema(),
input_files=None,
exec_stats=None,
)

def _num_rows(self):
assert self.input_data is not None
if all(bundle.num_rows() is not None for bundle in self.input_data):
return sum(bundle.num_rows() for bundle in self.input_data)
else:
return None

def output_data(self) -> Optional[List[RefBundle]]:
if self.input_data is None:
def _size_bytes(self):
assert self.input_data is not None
metadata = [m for bundle in self.input_data for m in bundle.metadata]
if all(m.size_bytes is not None for m in metadata):
return sum(m.size_bytes for m in metadata)
else:
return None
return self.input_data

def _schema(self):
assert self.input_data is not None
metadata = [m for bundle in self.input_data for m in bundle.metadata]
return unify_block_metadata_schema(metadata)

def is_lineage_serializable(self) -> bool:
# This operator isn't serializable because it contains ObjectRefs.
22 changes: 18 additions & 4 deletions python/ray/data/_internal/logical/operators/one_to_one_operator.py
@@ -2,6 +2,7 @@
from typing import Optional

from ray.data._internal.logical.interfaces import LogicalOperator
from ray.data.block import BlockMetadata


class AbstractOneToOne(LogicalOperator):
@@ -53,14 +54,27 @@ def __init__(
def can_modify_num_rows(self) -> bool:
return True

def schema(self):
def aggregate_output_metadata(self) -> BlockMetadata:
return BlockMetadata(
num_rows=self._num_rows(),
size_bytes=None,
schema=self._schema(),
input_files=self._input_files(),
exec_stats=None,
)

def _schema(self):
assert len(self._input_dependencies) == 1, len(self._input_dependencies)
return self._input_dependencies[0].schema()
return self._input_dependencies[0].aggregate_output_metadata().schema

def num_rows(self):
def _num_rows(self):
assert len(self._input_dependencies) == 1, len(self._input_dependencies)
input_rows = self._input_dependencies[0].num_rows()
input_rows = self._input_dependencies[0].aggregate_output_metadata().num_rows
if input_rows is not None:
return min(input_rows, self._limit)
else:
return None

def _input_files(self):
assert len(self._input_dependencies) == 1, len(self._input_dependencies)
return self._input_dependencies[0].aggregate_output_metadata().input_files
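
This operator, which applies a row limit, can tighten the known row count but leaves size_bytes as None, presumably because bytes per row vary and a truncated row count doesn't translate into a reliable byte count. The row-count rule in isolation, as a small illustrative sketch:

from typing import Optional


def limited_num_rows(input_rows: Optional[int], limit: int) -> Optional[int]:
    """Row count after a limit; an unknown input count stays unknown."""
    if input_rows is None:
        return None
    # A limit can only shrink the count, never grow it.
    return min(input_rows, limit)


assert limited_num_rows(1_000, 10) == 10
assert limited_num_rows(5, 10) == 5
assert limited_num_rows(None, 10) is None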
53 changes: 46 additions & 7 deletions python/ray/data/_internal/logical/operators/read_operator.py
@@ -1,6 +1,9 @@
from typing import Any, Dict, List, Optional, Union
import functools
from typing import Any, Dict, Optional, Union

from ray.data._internal.logical.operators.map_operator import AbstractMap
from ray.data._internal.util import unify_block_metadata_schema
from ray.data.block import BlockMetadata
from ray.data.datasource.datasource import Datasource, Reader


@@ -43,11 +46,47 @@ def get_detected_parallelism(self) -> int:
"""
return self._detected_parallelism

def schema(self):
return self._datasource.schema()
@functools.cache
def aggregate_output_metadata(self) -> BlockMetadata:
"""A ``BlockMetadata`` that represents the aggregate metadata of the outputs.
def num_rows(self):
return self._datasource.num_rows()
This method gets metadata from the read tasks. It doesn't trigger any actual
execution.
"""
# Legacy datasources might not implement `get_read_tasks`.
if self._datasource.should_create_reader:
return BlockMetadata(None, None, None, None, None)

# HACK: Try to get a single read task to get the metadata.
read_tasks = self._datasource.get_read_tasks(1)
if len(read_tasks) == 0:
# If there are no read tasks, the dataset is probably empty.
return BlockMetadata(None, None, None, None, None)

# `get_read_tasks` isn't guaranteed to return exactly one read task.
metadata = [read_task.get_metadata() for read_task in read_tasks]

if all(meta.num_rows is not None for meta in metadata):
num_rows = sum(meta.num_rows for meta in metadata)
else:
num_rows = None

def input_files(self) -> Optional[List[str]]:
return self._datasource.input_files()
if all(meta.size_bytes is not None for meta in metadata):
size_bytes = sum(meta.size_bytes for meta in metadata)
else:
size_bytes = None

schema = unify_block_metadata_schema(metadata)

input_files = []
for meta in metadata:
if meta.input_files is not None:
input_files.extend(meta.input_files)

return BlockMetadata(
num_rows=num_rows,
size_bytes=size_bytes,
schema=schema,
input_files=input_files,
exec_stats=None,
)
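
End to end, these are the Dataset APIs that the aggregated metadata serves. A usage sketch; whether each call can be answered purely from metadata (without executing the plan) depends on the datasource and the operators involved.

import ray

ds = ray.data.range(1_000)

print(ds.schema())       # column types, served from metadata when available
print(ds.count())        # row count, served from metadata when available
print(ds.size_bytes())   # byte size, which this change wires up to metadata
print(ds.input_files())  # input files, if the data was read from files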