Skip to content

Commit

Permalink
[Data] Revert PR ray-project#48186 and use `functools.cached_property…
Browse files Browse the repository at this point in the history
…` instead (ray-project#48436)

Signed-off-by: Chi-Sheng Liu <chishengliu@chishengliu.com>
  • Loading branch information
MortalHappiness authored Nov 7, 2024
1 parent ac84104 commit 218bdd7
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 34 deletions.
14 changes: 4 additions & 10 deletions python/ray/data/_internal/datasource/range_datasource.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import builtins
import functools
from copy import copy
from typing import Iterable, List, Optional, Tuple

Expand All @@ -24,7 +25,6 @@ def __init__(
self._block_format = block_format
self._tensor_shape = tensor_shape
self._column_name = column_name
self._schema_cache = None

def estimate_inmemory_data_size(self) -> Optional[int]:
if self._block_format == "tensor":
Expand Down Expand Up @@ -96,7 +96,7 @@ def make_blocks(
meta = BlockMetadata(
num_rows=count,
size_bytes=8 * count * element_size,
schema=copy(self._get_schema()),
schema=copy(self._schema),
input_files=None,
exec_stats=None,
)
Expand All @@ -112,14 +112,8 @@ def make_blocks(

return read_tasks

def _get_schema(self):
"""Get the schema, using cached value if available."""
if self._schema_cache is None:
self._schema_cache = self._compute_schema()
return self._schema_cache

def _compute_schema(self):
"""Compute the schema without caching."""
@functools.cached_property
def _schema(self):
if self._n == 0:
return None

Expand Down
11 changes: 4 additions & 7 deletions python/ray/data/_internal/logical/operators/from_operators.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import abc
import functools
from typing import TYPE_CHECKING, List, Optional, Union

from ray.data._internal.execution.interfaces import RefBundle
Expand Down Expand Up @@ -31,7 +32,6 @@ def __init__(
RefBundle([(input_blocks[i], input_metadata[i])], owns_blocks=False)
for i in range(len(input_blocks))
]
self._output_metadata_cache = None

@property
def input_data(self) -> List[RefBundle]:
Expand All @@ -41,13 +41,10 @@ def output_data(self) -> Optional[List[RefBundle]]:
return self._input_data

def aggregate_output_metadata(self) -> BlockMetadata:
"""Get aggregated output metadata, using cache if available."""
if self._output_metadata_cache is None:
self._output_metadata_cache = self._compute_output_metadata()
return self._output_metadata_cache
return self._cached_output_metadata

def _compute_output_metadata(self) -> BlockMetadata:
"""Compute the output metadata without caching."""
@functools.cached_property
def _cached_output_metadata(self) -> BlockMetadata:
return BlockMetadata(
num_rows=self._num_rows(),
size_bytes=self._size_bytes(),
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import functools
from typing import Callable, List, Optional

from ray.data._internal.execution.interfaces import RefBundle
Expand Down Expand Up @@ -26,24 +27,20 @@ def __init__(
)
self.input_data = input_data
self.input_data_factory = input_data_factory
self._output_metadata_cache = None

def output_data(self) -> Optional[List[RefBundle]]:
if self.input_data is None:
return None
return self.input_data

def aggregate_output_metadata(self) -> BlockMetadata:
"""Get aggregated output metadata, using cache if available."""
return self._cached_output_metadata

@functools.cached_property
def _cached_output_metadata(self) -> BlockMetadata:
if self.input_data is None:
return BlockMetadata(None, None, None, None, None)

if self._output_metadata_cache is None:
self._output_metadata_cache = self._compute_output_metadata()
return self._output_metadata_cache

def _compute_output_metadata(self) -> BlockMetadata:
"""Compute the output metadata without caching."""
return BlockMetadata(
num_rows=self._num_rows(),
size_bytes=self._size_bytes(),
Expand Down
14 changes: 5 additions & 9 deletions python/ray/data/_internal/logical/operators/read_operator.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import functools
from typing import Any, Dict, Optional, Union

from ray.data._internal.logical.operators.map_operator import AbstractMap
Expand Down Expand Up @@ -31,7 +32,6 @@ def __init__(
self._mem_size = mem_size
self._concurrency = concurrency
self._detected_parallelism = None
self._output_metadata_cache = None

def set_detected_parallelism(self, parallelism: int):
"""
Expand All @@ -52,15 +52,11 @@ def aggregate_output_metadata(self) -> BlockMetadata:
This method gets metadata from the read tasks. It doesn't trigger any actual
execution.
"""
# Legacy datasources might not implement `get_read_tasks`.
if self._output_metadata_cache is not None:
return self._output_metadata_cache

self._output_metadata_cache = self._compute_output_metadata()
return self._output_metadata_cache
return self._cached_output_metadata

def _compute_output_metadata(self) -> BlockMetadata:
"""Compute the output metadata without caching."""
@functools.cached_property
def _cached_output_metadata(self) -> BlockMetadata:
# Legacy datasources might not implement `get_read_tasks`.
if self._datasource.should_create_reader:
return BlockMetadata(None, None, None, None, None)

Expand Down

0 comments on commit 218bdd7

Please sign in to comment.