From edc51bdce6a5ec40e11f7d98423e1e39f7c37eda Mon Sep 17 00:00:00 2001 From: jianoaix Date: Thu, 8 Dec 2022 23:20:23 +0000 Subject: [PATCH 01/27] Fix read_tfrecords_benchmark nightly test Signed-off-by: jianoaix --- release/nightly_tests/dataset/read_tfrecords_benchmark.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/release/nightly_tests/dataset/read_tfrecords_benchmark.py b/release/nightly_tests/dataset/read_tfrecords_benchmark.py index 45a5cea2d24b..bf189b450d8f 100644 --- a/release/nightly_tests/dataset/read_tfrecords_benchmark.py +++ b/release/nightly_tests/dataset/read_tfrecords_benchmark.py @@ -25,7 +25,7 @@ def generate_tfrecords_from_images( # Convert images from NumPy to bytes def images_to_bytes(batch): - images_as_bytes = [image.tobytes() for image in batch] + images_as_bytes = [image.tobytes() for image in batch.values()] return pa.table({"image": images_as_bytes}) ds = ds.map_batches(images_to_bytes, batch_format="numpy") From 253da6abbca65ce25591f2cf6187fe90bcee0e4c Mon Sep 17 00:00:00 2001 From: jianoaix Date: Fri, 27 Jan 2023 23:22:04 +0000 Subject: [PATCH 02/27] Make write an operator as part of the execution plan --- .../data/_internal/execution/legacy_compat.py | 14 ++-- .../operators/actor_pool_submitter.py | 11 ++- .../operators/all_to_all_operator.py | 6 +- .../execution/operators/map_operator_state.py | 8 +- .../execution/operators/map_task_submitter.py | 7 +- .../operators/task_pool_submitter.py | 4 +- python/ray/data/_internal/plan.py | 24 +++--- .../data/_internal/shuffle_and_partition.py | 2 +- python/ray/data/_internal/stage_impl.py | 10 ++- python/ray/data/dataset.py | 59 ++++++--------- python/ray/data/datasource/datasource.py | 31 +++++++- .../data/datasource/file_based_datasource.py | 74 ++++++++++++++++++- .../ray/data/datasource/mongo_datasource.py | 28 +++++++ python/ray/data/tests/conftest.py | 2 +- python/ray/data/tests/test_dataset_formats.py | 17 ++--- 15 files changed, 215 insertions(+), 82 deletions(-) diff --git a/python/ray/data/_internal/execution/legacy_compat.py b/python/ray/data/_internal/execution/legacy_compat.py index 6f4f0ff86199..c4e1e3d81a00 100644 --- a/python/ray/data/_internal/execution/legacy_compat.py +++ b/python/ray/data/_internal/execution/legacy_compat.py @@ -147,7 +147,7 @@ def _blocks_to_input_buffer(blocks: BlockList, owns_blocks: bool) -> PhysicalOpe for b in i.blocks: trace_allocation(b[0], "legacy_compat.blocks_to_input_buf[0]") - def do_read(blocks: Iterator[Block]) -> Iterator[Block]: + def do_read(blocks: Iterator[Block], task_idx: int) -> Iterator[Block]: for read_task in blocks: yield from read_task() @@ -214,8 +214,8 @@ def fn(item: Any) -> Any: fn_args += stage.fn_args fn_kwargs = stage.fn_kwargs or {} - def do_map(blocks: Iterator[Block]) -> Iterator[Block]: - yield from block_fn(blocks, *fn_args, **fn_kwargs) + def do_map(blocks: Iterator[Block], task_idx) -> Iterator[Block]: + yield from block_fn(blocks, task_idx, *fn_args, **fn_kwargs) return MapOperator( do_map, @@ -231,14 +231,18 @@ def do_map(blocks: Iterator[Block]) -> Iterator[Block]: remote_args = stage.ray_remote_args stage_name = stage.name - def bulk_fn(refs: List[RefBundle]) -> Tuple[List[RefBundle], StatsDict]: + def bulk_fn( + refs: List[RefBundle], task_idx: int + ) -> Tuple[List[RefBundle], StatsDict]: input_owned = all(b.owns_blocks for b in refs) if isinstance(stage, RandomizeBlocksStage): output_owned = input_owned # Passthrough ownership hack. 
else: output_owned = True block_list = _bundles_to_block_list(refs) - block_list, stats_dict = fn(block_list, input_owned, block_udf, remote_args) + block_list, stats_dict = fn( + block_list, task_idx, input_owned, block_udf, remote_args + ) output = _block_list_to_bundles(block_list, owns_blocks=output_owned) if not stats_dict: stats_dict = {stage_name: block_list.get_metadata()} diff --git a/python/ray/data/_internal/execution/operators/actor_pool_submitter.py b/python/ray/data/_internal/execution/operators/actor_pool_submitter.py index adcf95990b93..edf8c8706881 100644 --- a/python/ray/data/_internal/execution/operators/actor_pool_submitter.py +++ b/python/ray/data/_internal/execution/operators/actor_pool_submitter.py @@ -50,13 +50,13 @@ def start(self, options: ExecutionOptions): self._actor_pool.add_actor(cls_.remote()) def submit( - self, input_blocks: List[ObjectRef[Block]] + self, input_blocks: List[ObjectRef[Block]], task_idx ) -> ObjectRef[ObjectRefGenerator]: # Pick an actor from the pool. actor = self._actor_pool.pick_actor() # Submit the map task. ref = actor.submit.options(num_returns="dynamic").remote( - self._transform_fn_ref, *input_blocks + task_idx, self._transform_fn_ref, *input_blocks ) self._active_actors[ref] = actor return ref @@ -94,9 +94,12 @@ def ready(self): return "ok" def submit( - self, fn: Callable[[Iterator[Block]], Iterator[Block]], *blocks: Block + self, + task_idx: int, + fn: Callable[[Iterator[Block]], Iterator[Block]], + *blocks: Block, ) -> Iterator[Union[Block, List[BlockMetadata]]]: - yield from _map_task(fn, *blocks) + yield from _map_task(fn, task_idx, *blocks) class ActorPool: diff --git a/python/ray/data/_internal/execution/operators/all_to_all_operator.py b/python/ray/data/_internal/execution/operators/all_to_all_operator.py index 9a10adfee11d..4756d0a7f053 100644 --- a/python/ray/data/_internal/execution/operators/all_to_all_operator.py +++ b/python/ray/data/_internal/execution/operators/all_to_all_operator.py @@ -31,6 +31,7 @@ def __init__( name: The name of this operator. """ self._bulk_fn = bulk_fn + self._next_task_index = 0 self._num_outputs = num_outputs self._input_buffer: List[RefBundle] = [] self._output_buffer: List[RefBundle] = [] @@ -50,7 +51,10 @@ def add_input(self, refs: RefBundle, input_index: int) -> None: self._input_buffer.append(refs) def inputs_done(self) -> None: - self._output_buffer, self._stats = self._bulk_fn(self._input_buffer) + self._output_buffer, self._stats = self._bulk_fn( + self._input_buffer, self._next_task_index + ) + self._next_task_index += 1 self._input_buffer.clear() super().inputs_done() diff --git a/python/ray/data/_internal/execution/operators/map_operator_state.py b/python/ray/data/_internal/execution/operators/map_operator_state.py index 462d0024802a..5fc958b50c06 100644 --- a/python/ray/data/_internal/execution/operators/map_operator_state.py +++ b/python/ray/data/_internal/execution/operators/map_operator_state.py @@ -63,6 +63,9 @@ def __init__( raise ValueError(f"Unsupported execution strategy {compute_strategy}") self._task_submitter: MapTaskSubmitter = task_submitter + # Increment task index by one each time we submit a new task. + self._next_task_idx = 0 + # The temporary block bundle used to accumulate inputs until they meet the # min_rows_per_bundle requirement. 
self._block_bundle: Optional[RefBundle] = None @@ -182,7 +185,10 @@ def _create_task(self, bundle: RefBundle) -> None: # TODO fix for Ray client: https://github.com/ray-project/ray/issues/30458 if not DatasetContext.get_current().block_splitting_enabled: raise NotImplementedError("New backend requires block splitting") - ref: ObjectRef[ObjectRefGenerator] = self._task_submitter.submit(input_blocks) + ref: ObjectRef[ObjectRefGenerator] = self._task_submitter.submit( + input_blocks, self._next_task_idx + ) + self._next_task_idx += 1 task = _TaskState(bundle) self._tasks[ref] = task self._output_queue.notify_pending_task(task) diff --git a/python/ray/data/_internal/execution/operators/map_task_submitter.py b/python/ray/data/_internal/execution/operators/map_task_submitter.py index 88b17a97cedf..350e2e6137f4 100644 --- a/python/ray/data/_internal/execution/operators/map_task_submitter.py +++ b/python/ray/data/_internal/execution/operators/map_task_submitter.py @@ -51,7 +51,9 @@ def start(self, options: ExecutionOptions): @abstractmethod def submit( - self, input_blocks: List[ObjectRef[Block]] + self, + input_blocks: List[ObjectRef[Block]], + task_idx: int, ) -> Union[ ObjectRef[ObjectRefGenerator], Tuple[ObjectRef[Block], ObjectRef[BlockMetadata]] ]: @@ -93,6 +95,7 @@ def shutdown(self, task_refs: List[ObjectRef[Union[ObjectRefGenerator, Block]]]) def _map_task( fn: Callable[[Iterator[Block]], Iterator[Block]], + task_idx: int, *blocks: Block, ) -> Iterator[Union[Block, List[BlockMetadata]]]: """Remote function for a single operator task. @@ -108,7 +111,7 @@ def _map_task( """ output_metadata = [] stats = BlockExecStats.builder() - for b_out in fn(iter(blocks)): + for b_out in fn(iter(blocks), task_idx): # TODO(Clark): Add input file propagation from input blocks. m_out = BlockAccessor.for_block(b_out).get_metadata([], None) m_out.exec_stats = stats.build() diff --git a/python/ray/data/_internal/execution/operators/task_pool_submitter.py b/python/ray/data/_internal/execution/operators/task_pool_submitter.py index b850c22be71d..5d305d466598 100644 --- a/python/ray/data/_internal/execution/operators/task_pool_submitter.py +++ b/python/ray/data/_internal/execution/operators/task_pool_submitter.py @@ -15,12 +15,12 @@ class TaskPoolSubmitter(MapTaskSubmitter): """A task submitter for MapOperator that uses normal Ray tasks.""" def submit( - self, input_blocks: List[ObjectRef[Block]] + self, input_blocks: List[ObjectRef[Block]], task_idx: int ) -> ObjectRef[ObjectRefGenerator]: # Submit the task as a normal Ray task. 
map_task = cached_remote_fn(_map_task, num_returns="dynamic") return map_task.options(**self._ray_remote_args).remote( - self._transform_fn_ref, *input_blocks + self._transform_fn_ref, task_idx, *input_blocks ) def shutdown(self, task_refs: List[ObjectRef[Union[ObjectRefGenerator, Block]]]): diff --git a/python/ray/data/_internal/plan.py b/python/ray/data/_internal/plan.py index 3bc492e944e9..a1bc4ad2f01a 100644 --- a/python/ray/data/_internal/plan.py +++ b/python/ray/data/_internal/plan.py @@ -880,6 +880,7 @@ def fuse(self, prev: Stage): def block_fn( blocks: Iterable[Block], + task_idx: int, fn: UDF, *fn_args, **fn_kwargs, @@ -897,8 +898,8 @@ def block_fn( prev_fn_args = ( prev_fn_args if prev_fn_ is None else (prev_fn_,) + prev_fn_args ) - blocks = block_fn1(blocks, *prev_fn_args, **prev_fn_kwargs) - return block_fn2(blocks, *self_fn_args, **self_fn_kwargs) + blocks = block_fn1(blocks, task_idx, *prev_fn_args, **prev_fn_kwargs) + return block_fn2(blocks, task_idx, *self_fn_args, **self_fn_kwargs) return OneToOneStage( name, @@ -991,19 +992,22 @@ def fuse(self, prev: Stage): prev_block_fn = prev.block_fn if self.block_udf is None: - def block_udf(blocks: Iterable[Block]) -> Iterable[Block]: - yield from prev_block_fn(blocks, *prev_fn_args, **prev_fn_kwargs) + def block_udf(blocks: Iterable[Block], task_idx: int) -> Iterable[Block]: + yield from prev_block_fn( + blocks, task_idx, *prev_fn_args, **prev_fn_kwargs + ) else: self_block_udf = self.block_udf - def block_udf(blocks: Iterable[Block]) -> Iterable[Block]: + def block_udf(blocks: Iterable[Block], task_idx) -> Iterable[Block]: blocks = prev_block_fn( blocks, + task_idx, *prev_fn_args, **prev_fn_kwargs, ) - yield from self_block_udf(blocks) + yield from self_block_udf(blocks, task_idx) return AllToAllStage( name, self.num_blocks, self.fn, True, block_udf, prev.ray_remote_args @@ -1079,7 +1083,9 @@ def _rewrite_read_stage( ) @_adapt_for_multiple_blocks - def block_fn(read_fn: Callable[[], Iterator[Block]]) -> Iterator[Block]: + def block_fn( + read_fn: Callable[[], Iterator[Block]], task_idx: int + ) -> Iterator[Block]: for block in read_fn(): yield block @@ -1198,8 +1204,8 @@ def _adapt_for_multiple_blocks( fn: Callable[..., Iterable[Block]], ) -> Callable[..., Iterable[Block]]: @functools.wraps(fn) - def wrapper(blocks: Iterable[Block], *args, **kwargs): + def wrapper(blocks: Iterable[Block], task_idx, *args, **kwargs): for block in blocks: - yield from fn(block, *args, **kwargs) + yield from fn(block, task_idx, *args, **kwargs) return wrapper diff --git a/python/ray/data/_internal/shuffle_and_partition.py b/python/ray/data/_internal/shuffle_and_partition.py index 45ab69b4721f..f9287648cc98 100644 --- a/python/ray/data/_internal/shuffle_and_partition.py +++ b/python/ray/data/_internal/shuffle_and_partition.py @@ -37,7 +37,7 @@ def map( stats = BlockExecStats.builder() if block_udf: # TODO(ekl) note that this effectively disables block splitting. 
- blocks = list(block_udf([block])) + blocks = list(block_udf([block], idx)) if len(blocks) > 1: builder = BlockAccessor.for_block(blocks[0]).builder() for b in blocks: diff --git a/python/ray/data/_internal/stage_impl.py b/python/ray/data/_internal/stage_impl.py index 38bd467407c5..c29127851e9b 100644 --- a/python/ray/data/_internal/stage_impl.py +++ b/python/ray/data/_internal/stage_impl.py @@ -31,7 +31,11 @@ def __init__(self, num_blocks: int, shuffle: bool): if shuffle: def do_shuffle( - block_list, clear_input_blocks: bool, block_udf, remote_args + block_list, + task_idx: int, + clear_input_blocks: bool, + block_udf, + remote_args, ): if clear_input_blocks: blocks = block_list.copy() @@ -94,7 +98,9 @@ def __init__( output_num_blocks: Optional[int], remote_args: Optional[Dict[str, Any]] = None, ): - def do_shuffle(block_list, clear_input_blocks: bool, block_udf, remote_args): + def do_shuffle( + block_list, task_idx: int, clear_input_blocks: bool, block_udf, remote_args + ): num_blocks = block_list.executed_num_blocks() # Blocking. if num_blocks == 0: return block_list, {} diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index 7ab51669992b..4b09e1d2115b 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -1,7 +1,6 @@ import collections import itertools import logging -import os import sys import time import html @@ -69,7 +68,6 @@ ZipStage, SortStage, ) -from ray.data._internal.progress_bar import ProgressBar from ray.data._internal.remote_fn import cached_remote_fn from ray.data._internal.split import _split_at_index, _split_at_indices, _get_num_rows from ray.data._internal.stats import DatasetStats, DatasetStatsSummary @@ -108,7 +106,6 @@ WriteResult, ) from ray.data.datasource.file_based_datasource import ( - _unwrap_arrow_serialization_workaround, _wrap_arrow_serialization_workaround, ) from ray.data.random_access_dataset import RandomAccessDataset @@ -318,7 +315,7 @@ def map( context = DatasetContext.get_current() @_adapt_for_multiple_blocks - def transform(block: Block, fn: RowUDF[T, U]) -> Iterable[Block]: + def transform(block: Block, task_idx: int, fn: RowUDF[T, U]) -> Iterable[Block]: DatasetContext._set_current(context) output_buffer = BlockOutputBuffer(None, context.target_max_block_size) block = BlockAccessor.for_block(block) @@ -595,6 +592,7 @@ def map_batches( def transform( blocks: Iterable[Block], + task_idx: int, batch_fn: BatchUDF, *fn_args, **fn_kwargs, @@ -888,7 +886,7 @@ def flat_map( context = DatasetContext.get_current() @_adapt_for_multiple_blocks - def transform(block: Block, fn: RowUDF[T, U]) -> Iterable[Block]: + def transform(block: Block, task_idx, fn: RowUDF[T, U]) -> Iterable[Block]: DatasetContext._set_current(context) output_buffer = BlockOutputBuffer(None, context.target_max_block_size) block = BlockAccessor.for_block(block) @@ -968,7 +966,7 @@ def filter( context = DatasetContext.get_current() @_adapt_for_multiple_blocks - def transform(block: Block, fn: RowUDF[T, U]) -> Iterable[Block]: + def transform(block: Block, task_idx, fn: RowUDF[T, U]) -> Iterable[Block]: DatasetContext._set_current(context) block = BlockAccessor.for_block(block) builder = block.builder() @@ -2661,8 +2659,6 @@ def write_datasource( ray_remote_args: Kwargs passed to ray.remote in the write tasks. write_args: Additional write args to pass to the datasource. 
""" - - ctx = DatasetContext.get_current() if ray_remote_args is None: ray_remote_args = {} path = write_args.get("path", None) @@ -2676,37 +2672,26 @@ def write_datasource( soft=False, ) - blocks, metadata = zip(*self._plan.execute().get_blocks_with_metadata()) + def transform(blocks: Iterable[Block], task_idx, fn) -> []: + try: + datasource.sync_write(blocks, task_idx, **write_args) + datasource.on_write_complete([]) + except Exception as e: + datasource.on_write_failed([], e) + raise + return [] - # TODO(ekl) remove this feature flag. - if "RAY_DATASET_FORCE_LOCAL_METADATA" in os.environ: - write_results: List[ObjectRef[WriteResult]] = datasource.do_write( - blocks, metadata, ray_remote_args=ray_remote_args, **write_args - ) - else: - # Prepare write in a remote task so that in Ray client mode, we - # don't do metadata resolution from the client machine. - do_write = cached_remote_fn(_do_write, retry_exceptions=False, num_cpus=0) - write_results: List[ObjectRef[WriteResult]] = ray.get( - do_write.remote( - datasource, - ctx, - blocks, - metadata, - ray_remote_args, - _wrap_arrow_serialization_workaround(write_args), - ) + plan = self._plan.with_stage( + OneToOneStage( + "write", + transform, + "tasks", + ray_remote_args, + fn=lambda x: x, ) - - progress = ProgressBar("Write Progress", len(write_results)) - try: - progress.block_until_complete(write_results) - datasource.on_write_complete(ray.get(write_results)) - except Exception as e: - datasource.on_write_failed(write_results, e) - raise - finally: - progress.close() + ) + ds = Dataset(plan, self._epoch, self._lazy) + ds.fully_executed() def iterator(self) -> DatasetIterator: """Return a :class:`~ray.data.DatasetIterator` that diff --git a/python/ray/data/datasource/datasource.py b/python/ray/data/datasource/datasource.py index 98c0eb722f1c..773b69f88f0b 100644 --- a/python/ray/data/datasource/datasource.py +++ b/python/ray/data/datasource/datasource.py @@ -316,6 +316,8 @@ class DataSink: def __init__(self): self.rows_written = 0 self.enabled = True + self.num_ok = 0 + self.num_failed = 0 def write(self, block: Block) -> str: block = BlockAccessor.for_block(block) @@ -330,9 +332,30 @@ def get_rows_written(self): def set_enabled(self, enabled): self.enabled = enabled + def increment_ok(self): + self.num_ok += 1 + + def get_num_ok(self): + return self.num_ok + + def increment_failed(self): + self.num_failed += 1 + + def get_num_failed(self): + return self.num_failed + self.data_sink = DataSink.remote() - self.num_ok = 0 - self.num_failed = 0 + + def sync_write( + self, + blocks: Iterable[Block], + task_idx: int, + **write_args, + ): + tasks = [] + for b in blocks: + tasks.append(self.data_sink.write.remote(b)) + return ray.get(tasks) def do_write( self, @@ -348,12 +371,12 @@ def do_write( def on_write_complete(self, write_results: List[WriteResult]) -> None: assert all(w == "ok" for w in write_results), write_results - self.num_ok += 1 + self.data_sink.increment_ok.remote() def on_write_failed( self, write_results: List[ObjectRef[WriteResult]], error: Exception ) -> None: - self.num_failed += 1 + self.data_sink.increment_failed.remote() @DeveloperAPI diff --git a/python/ray/data/datasource/file_based_datasource.py b/python/ray/data/datasource/file_based_datasource.py index e4cd15550ba3..2352b30e4f27 100644 --- a/python/ray/data/datasource/file_based_datasource.py +++ b/python/ray/data/datasource/file_based_datasource.py @@ -17,6 +17,7 @@ from ray.data._internal.arrow_block import ArrowRow from ray.data._internal.block_list import 
BlockMetadata +from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder from ray.data._internal.output_buffer import BlockOutputBuffer from ray.data._internal.remote_fn import cached_remote_fn from ray.data._internal.util import _check_pyarrow_version, _resolve_custom_scheme @@ -60,7 +61,7 @@ def _get_write_path_for_block( *, filesystem: Optional["pyarrow.fs.FileSystem"] = None, dataset_uuid: Optional[str] = None, - block: Optional[ObjectRef[Block]] = None, + block: Optional[Block] = None, block_index: Optional[int] = None, file_format: Optional[str] = None, ) -> str: @@ -77,7 +78,7 @@ def _get_write_path_for_block( write a file out to the write path returned. dataset_uuid: Unique identifier for the dataset that this block belongs to. - block: Object reference to the block to write. + block: The block to write. block_index: Ordered index of the block to write within its parent dataset. file_format: File format string for the block that can be used as @@ -94,7 +95,7 @@ def __call__( *, filesystem: Optional["pyarrow.fs.FileSystem"] = None, dataset_uuid: Optional[str] = None, - block: Optional[ObjectRef[Block]] = None, + block: Optional[Block] = None, block_index: Optional[int] = None, file_format: Optional[str] = None, ) -> str: @@ -257,6 +258,73 @@ def _convert_block_to_tabular_block( "then you need to implement `_convert_block_to_tabular_block." ) + # This doesn't launch Ray tasks to write. + def sync_write( + self, + blocks: Iterable[Block], + task_idx: int, + path: str, + dataset_uuid: str, + filesystem: Optional["pyarrow.fs.FileSystem"] = None, + try_create_dir: bool = True, + open_stream_args: Optional[Dict[str, Any]] = None, + block_path_provider: BlockWritePathProvider = DefaultBlockWritePathProvider(), + write_args_fn: Callable[[], Dict[str, Any]] = lambda: {}, + _block_udf: Optional[Callable[[Block], Block]] = None, + **write_args, + ): + """Creates and returns write tasks for a file-based datasource.""" + path, filesystem = _resolve_paths_and_filesystem(path, filesystem) + path = path[0] + if try_create_dir: + # Arrow's S3FileSystem doesn't allow creating buckets by default, so we add + # a query arg enabling bucket creation if an S3 URI is provided. 
+ tmp = _add_creatable_buckets_param_if_s3_uri(path) + filesystem.create_dir(tmp, recursive=True) + filesystem = _wrap_s3_serialization_workaround(filesystem) + + _write_block_to_file = self._write_block + + if open_stream_args is None: + open_stream_args = {} + + def write_block(write_path: str, block: Block): + logger.debug(f"Writing {write_path} file.") + fs = filesystem + if isinstance(fs, _S3FileSystemWrapper): + fs = fs.unwrap() + if _block_udf is not None: + block = _block_udf(block) + + with fs.open_output_stream(write_path, **open_stream_args) as f: + _write_block_to_file( + f, + BlockAccessor.for_block(block), + writer_args_fn=write_args_fn, + **write_args, + ) + + file_format = self._FILE_EXTENSION + if isinstance(file_format, list): + file_format = file_format[0] + + builder = DelegatingBlockBuilder() + for block in blocks: + builder.add_block(block) + block = builder.build() + + if not block_path_provider: + block_path_provider = DefaultBlockWritePathProvider() + write_path = block_path_provider( + path, + filesystem=filesystem, + dataset_uuid=dataset_uuid, + block=block, + block_index=task_idx, + file_format=file_format, + ) + write_block(write_path, block) + def do_write( self, blocks: List[ObjectRef[Block]], diff --git a/python/ray/data/datasource/mongo_datasource.py b/python/ray/data/datasource/mongo_datasource.py index f1153271c532..18b6c71a863c 100644 --- a/python/ray/data/datasource/mongo_datasource.py +++ b/python/ray/data/datasource/mongo_datasource.py @@ -10,6 +10,7 @@ from ray.data._internal.remote_fn import cached_remote_fn from ray.types import ObjectRef from ray.util.annotations import PublicAPI +from typing import Iterable if TYPE_CHECKING: import pymongoarrow.api @@ -37,6 +38,33 @@ class MongoDatasource(Datasource): def create_reader(self, **kwargs) -> Reader: return _MongoDatasourceReader(**kwargs) + def sync_write( + self, + blocks: Iterable[Block], + task_idx: int, + uri: str, + database: str, + collection: str, + ) -> List[ObjectRef[WriteResult]]: + import pymongo + + _validate_database_collection_exist( + pymongo.MongoClient(uri), database, collection + ) + + def write_block(uri: str, database: str, collection: str, block: Block): + from pymongoarrow.api import write + + block = BlockAccessor.for_block(block).to_arrow() + client = pymongo.MongoClient(uri) + write(client[database][collection], block) + + write_tasks = [] + for block in blocks: + write_task = write_block(uri, database, collection, block) + write_tasks.append(write_task) + return write_tasks + def do_write( self, blocks: List[ObjectRef[Block]], diff --git a/python/ray/data/tests/conftest.py b/python/ray/data/tests/conftest.py index 7d013d9f84eb..edf9b28cf8d8 100644 --- a/python/ray/data/tests/conftest.py +++ b/python/ray/data/tests/conftest.py @@ -146,7 +146,7 @@ def _get_write_path_for_block( block_index=None, file_format=None, ): - num_rows = BlockAccessor.for_block(ray.get(block)).num_rows() + num_rows = BlockAccessor.for_block(block).num_rows() suffix = ( f"{block_index:06}_{num_rows:02}_{dataset_uuid}" f".test.{file_format}" ) diff --git a/python/ray/data/tests/test_dataset_formats.py b/python/ray/data/tests/test_dataset_formats.py index efbbef88886f..0cc5d818fdbc 100644 --- a/python/ray/data/tests/test_dataset_formats.py +++ b/python/ray/data/tests/test_dataset_formats.py @@ -172,23 +172,20 @@ def test_write_datasource(ray_start_regular_shared, pipelined): ds0 = ray.data.range(10, parallelism=2) ds = maybe_pipeline(ds0, pipelined) ds.write_datasource(output) - if pipelined: - assert 
output.num_ok == 2 - else: - assert output.num_ok == 1 - assert output.num_failed == 0 + assert ray.get(output.data_sink.get_num_ok.remote()) == 2 + assert ray.get(output.data_sink.get_num_failed.remote()) == 0 assert ray.get(output.data_sink.get_rows_written.remote()) == 10 ray.get(output.data_sink.set_enabled.remote(False)) ds = maybe_pipeline(ds0, pipelined) with pytest.raises(ValueError): - ds.write_datasource(output) + ds.write_datasource(output, ray_remote_args={"max_retries": 0}) + assert ray.get(output.data_sink.get_num_ok.remote()) == 2 + assert ray.get(output.data_sink.get_rows_written.remote()) == 10 if pipelined: - assert output.num_ok == 2 + assert ray.get(output.data_sink.get_num_failed.remote()) == 1 else: - assert output.num_ok == 1 - assert output.num_failed == 1 - assert ray.get(output.data_sink.get_rows_written.remote()) == 10 + assert ray.get(output.data_sink.get_num_failed.remote()) == 2 def test_from_tf(ray_start_regular_shared): From 514ec14756cbc02a06d54729e2fc95fad3a8f6bc Mon Sep 17 00:00:00 2001 From: jianoaix Date: Sat, 28 Jan 2023 00:03:19 +0000 Subject: [PATCH 03/27] fix / fix do_write --- python/ray/data/dataset.py | 18 +---- python/ray/data/datasource/datasource.py | 31 ++----- .../data/datasource/file_based_datasource.py | 81 +------------------ .../ray/data/datasource/mongo_datasource.py | 37 +-------- python/ray/data/tests/test_dataset_formats.py | 43 ++++++---- 5 files changed, 40 insertions(+), 170 deletions(-) diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index 4b09e1d2115b..480bc30a7a91 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -105,9 +105,6 @@ TFRecordDatasource, WriteResult, ) -from ray.data.datasource.file_based_datasource import ( - _wrap_arrow_serialization_workaround, -) from ray.data.random_access_dataset import RandomAccessDataset from ray.data.row import TableRow from ray.types import ObjectRef @@ -2674,7 +2671,7 @@ def write_datasource( def transform(blocks: Iterable[Block], task_idx, fn) -> []: try: - datasource.sync_write(blocks, task_idx, **write_args) + datasource.do_write(blocks, task_idx, **write_args) datasource.on_write_complete([]) except Exception as e: datasource.on_write_failed([], e) @@ -4414,16 +4411,3 @@ def _sliding_window(iterable: Iterable, n: int): for elem in it: window.append(elem) yield tuple(window) - - -def _do_write( - ds: Datasource, - ctx: DatasetContext, - blocks: List[Block], - meta: List[BlockMetadata], - ray_remote_args: Dict[str, Any], - write_args: Dict[str, Any], -) -> List[ObjectRef[WriteResult]]: - write_args = _unwrap_arrow_serialization_workaround(write_args) - DatasetContext._set_current(ctx) - return ds.do_write(blocks, meta, ray_remote_args=ray_remote_args, **write_args) diff --git a/python/ray/data/datasource/datasource.py b/python/ray/data/datasource/datasource.py index 773b69f88f0b..d274eb6c284c 100644 --- a/python/ray/data/datasource/datasource.py +++ b/python/ray/data/datasource/datasource.py @@ -1,5 +1,5 @@ import builtins -from typing import Any, Callable, Dict, Generic, Iterable, List, Optional, Tuple, Union +from typing import Any, Callable, Generic, Iterable, List, Optional, Tuple, Union import numpy as np @@ -52,22 +52,17 @@ def prepare_read(self, parallelism: int, **read_args) -> List["ReadTask[T]"]: def do_write( self, - blocks: List[ObjectRef[Block]], - metadata: List[BlockMetadata], - ray_remote_args: Dict[str, Any], + blocks: Iterable[Block], **write_args, - ) -> List[ObjectRef[WriteResult]]: - """Launch Ray tasks for writing 
blocks out to the datasource. + ) -> WriteResult: + """Write blocks out to the datasource. This is used by a single write task. Args: - blocks: List of data block references. It is recommended that one - write task be generated per block. - metadata: List of block metadata. - ray_remote_args: Kwargs passed to ray.remote in the write tasks. + blocks: List of data blocks. write_args: Additional kwargs to pass to the datasource impl. Returns: - A list of the output of the write tasks. + The output of the write tasks. """ raise NotImplementedError @@ -346,7 +341,7 @@ def get_num_failed(self): self.data_sink = DataSink.remote() - def sync_write( + def do_write( self, blocks: Iterable[Block], task_idx: int, @@ -357,18 +352,6 @@ def sync_write( tasks.append(self.data_sink.write.remote(b)) return ray.get(tasks) - def do_write( - self, - blocks: List[ObjectRef[Block]], - metadata: List[BlockMetadata], - ray_remote_args: Dict[str, Any], - **write_args, - ) -> List[ObjectRef[WriteResult]]: - tasks = [] - for b in blocks: - tasks.append(self.data_sink.write.remote(b)) - return tasks - def on_write_complete(self, write_results: List[WriteResult]) -> None: assert all(w == "ok" for w in write_results), write_results self.data_sink.increment_ok.remote() diff --git a/python/ray/data/datasource/file_based_datasource.py b/python/ray/data/datasource/file_based_datasource.py index 2352b30e4f27..72ea194992c5 100644 --- a/python/ray/data/datasource/file_based_datasource.py +++ b/python/ray/data/datasource/file_based_datasource.py @@ -16,14 +16,12 @@ ) from ray.data._internal.arrow_block import ArrowRow -from ray.data._internal.block_list import BlockMetadata from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder from ray.data._internal.output_buffer import BlockOutputBuffer -from ray.data._internal.remote_fn import cached_remote_fn from ray.data._internal.util import _check_pyarrow_version, _resolve_custom_scheme from ray.data.block import Block, BlockAccessor from ray.data.context import DatasetContext -from ray.data.datasource.datasource import Datasource, Reader, ReadTask, WriteResult +from ray.data.datasource.datasource import Datasource, Reader, ReadTask from ray.data.datasource.file_meta_provider import ( BaseFileMetadataProvider, DefaultFileMetadataProvider, @@ -258,8 +256,7 @@ def _convert_block_to_tabular_block( "then you need to implement `_convert_block_to_tabular_block." ) - # This doesn't launch Ray tasks to write. 
- def sync_write( + def do_write( self, blocks: Iterable[Block], task_idx: int, @@ -323,79 +320,7 @@ def write_block(write_path: str, block: Block): block_index=task_idx, file_format=file_format, ) - write_block(write_path, block) - - def do_write( - self, - blocks: List[ObjectRef[Block]], - metadata: List[BlockMetadata], - path: str, - dataset_uuid: str, - filesystem: Optional["pyarrow.fs.FileSystem"] = None, - try_create_dir: bool = True, - open_stream_args: Optional[Dict[str, Any]] = None, - block_path_provider: BlockWritePathProvider = DefaultBlockWritePathProvider(), - write_args_fn: Callable[[], Dict[str, Any]] = lambda: {}, - _block_udf: Optional[Callable[[Block], Block]] = None, - ray_remote_args: Dict[str, Any] = None, - **write_args, - ) -> List[ObjectRef[WriteResult]]: - """Creates and returns write tasks for a file-based datasource.""" - path, filesystem = _resolve_paths_and_filesystem(path, filesystem) - path = path[0] - if try_create_dir: - # Arrow's S3FileSystem doesn't allow creating buckets by default, so we add - # a query arg enabling bucket creation if an S3 URI is provided. - tmp = _add_creatable_buckets_param_if_s3_uri(path) - filesystem.create_dir(tmp, recursive=True) - filesystem = _wrap_s3_serialization_workaround(filesystem) - - _write_block_to_file = self._write_block - - if open_stream_args is None: - open_stream_args = {} - - if ray_remote_args is None: - ray_remote_args = {} - - def write_block(write_path: str, block: Block): - logger.debug(f"Writing {write_path} file.") - fs = filesystem - if isinstance(fs, _S3FileSystemWrapper): - fs = fs.unwrap() - if _block_udf is not None: - block = _block_udf(block) - - with fs.open_output_stream(write_path, **open_stream_args) as f: - _write_block_to_file( - f, - BlockAccessor.for_block(block), - writer_args_fn=write_args_fn, - **write_args, - ) - - write_block = cached_remote_fn(write_block).options(**ray_remote_args) - - file_format = self._FILE_EXTENSION - if isinstance(file_format, list): - file_format = file_format[0] - - write_tasks = [] - if not block_path_provider: - block_path_provider = DefaultBlockWritePathProvider() - for block_idx, block in enumerate(blocks): - write_path = block_path_provider( - path, - filesystem=filesystem, - dataset_uuid=dataset_uuid, - block=block, - block_index=block_idx, - file_format=file_format, - ) - write_task = write_block.remote(write_path, block) - write_tasks.append(write_task) - - return write_tasks + return write_block(write_path, block) def _write_block( self, diff --git a/python/ray/data/datasource/mongo_datasource.py b/python/ray/data/datasource/mongo_datasource.py index 18b6c71a863c..fa6286694d91 100644 --- a/python/ray/data/datasource/mongo_datasource.py +++ b/python/ray/data/datasource/mongo_datasource.py @@ -1,5 +1,5 @@ import logging -from typing import Any, Dict, List, Optional, TYPE_CHECKING +from typing import Dict, List, Optional, TYPE_CHECKING from ray.data.datasource.datasource import Datasource, Reader, ReadTask, WriteResult from ray.data.block import ( @@ -7,7 +7,6 @@ BlockAccessor, BlockMetadata, ) -from ray.data._internal.remote_fn import cached_remote_fn from ray.types import ObjectRef from ray.util.annotations import PublicAPI from typing import Iterable @@ -38,7 +37,7 @@ class MongoDatasource(Datasource): def create_reader(self, **kwargs) -> Reader: return _MongoDatasourceReader(**kwargs) - def sync_write( + def do_write( self, blocks: Iterable[Block], task_idx: int, @@ -65,38 +64,6 @@ def write_block(uri: str, database: str, collection: str, block: 
Block): write_tasks.append(write_task) return write_tasks - def do_write( - self, - blocks: List[ObjectRef[Block]], - metadata: List[BlockMetadata], - ray_remote_args: Optional[Dict[str, Any]], - uri: str, - database: str, - collection: str, - ) -> List[ObjectRef[WriteResult]]: - import pymongo - - _validate_database_collection_exist( - pymongo.MongoClient(uri), database, collection - ) - - def write_block(uri: str, database: str, collection: str, block: Block): - from pymongoarrow.api import write - - block = BlockAccessor.for_block(block).to_arrow() - client = pymongo.MongoClient(uri) - write(client[database][collection], block) - - if ray_remote_args is None: - ray_remote_args = {} - - write_block = cached_remote_fn(write_block).options(**ray_remote_args) - write_tasks = [] - for block in blocks: - write_task = write_block.remote(uri, database, collection, block) - write_tasks.append(write_task) - return write_tasks - class _MongoDatasourceReader(Reader): def __init__( diff --git a/python/ray/data/tests/test_dataset_formats.py b/python/ray/data/tests/test_dataset_formats.py index 0cc5d818fdbc..000069d9333f 100644 --- a/python/ray/data/tests/test_dataset_formats.py +++ b/python/ray/data/tests/test_dataset_formats.py @@ -1,5 +1,5 @@ import os -from typing import Any, Dict, List, Union +from typing import List, Union import pandas as pd import pyarrow as pa @@ -13,7 +13,7 @@ import ray from ray.data._internal.arrow_block import ArrowRow -from ray.data.block import Block, BlockAccessor, BlockMetadata +from ray.data.block import Block, BlockAccessor from ray.data.datasource import ( Datasource, DummyOutputDatasource, @@ -24,6 +24,7 @@ from ray.data.tests.mock_http_server import * # noqa from ray.tests.conftest import * # noqa from ray.types import ObjectRef +from typing import Iterable def maybe_pipeline(ds, enabled): @@ -227,6 +228,8 @@ def __init__(self): self.rows_written = 0 self.enabled = True self.node_ids = set() + self.num_ok = 0 + self.num_failed = 0 def write(self, node_id: str, block: Block) -> str: block = BlockAccessor.for_block(block) @@ -245,40 +248,48 @@ def get_node_ids(self): def set_enabled(self, enabled): self.enabled = enabled + def increment_ok(self): + self.num_ok += 1 + + def get_num_ok(self): + return self.num_ok + + def increment_failed(self): + self.num_failed += 1 + + def get_num_failed(self): + return self.num_failed + self.data_sink = DataSink.remote() - self.num_ok = 0 - self.num_failed = 0 def do_write( self, - blocks: List[ObjectRef[Block]], - metadata: List[BlockMetadata], - ray_remote_args: Dict[str, Any], + blocks: Iterable[Block], + task_idx: int, **write_args, - ) -> List[ObjectRef[WriteResult]]: + ) -> WriteResult: data_sink = self.data_sink - @ray.remote def write(b): node_id = ray.get_runtime_context().get_node_id() return ray.get(data_sink.write.remote(node_id, b)) - tasks = [] for b in blocks: - tasks.append(write.options(**ray_remote_args).remote(b)) - return tasks + result = write(b) + return result def on_write_complete(self, write_results: List[WriteResult]) -> None: assert all(w == "ok" for w in write_results), write_results - self.num_ok += 1 + self.data_sink.increment_ok.remote() def on_write_failed( self, write_results: List[ObjectRef[WriteResult]], error: Exception ) -> None: - self.num_failed += 1 + self.data_sink.increment_failed.remote() def test_write_datasource_ray_remote_args(ray_start_cluster): + ray.shutdown() cluster = ray_start_cluster cluster.add_node( resources={"foo": 100}, @@ -298,8 +309,8 @@ def get_node_id(): ds = 
ray.data.range(100, parallelism=10) # Pin write tasks to ds.write_datasource(output, ray_remote_args={"resources": {"bar": 1}}) - assert output.num_ok == 1 - assert output.num_failed == 0 + assert ray.get(output.data_sink.get_num_ok.remote()) == 10 + assert ray.get(output.data_sink.get_num_failed.remote()) == 0 assert ray.get(output.data_sink.get_rows_written.remote()) == 100 node_ids = ray.get(output.data_sink.get_node_ids.remote()) From b75bf499e6b189e40de99657b19f4e3d4cc2ac1f Mon Sep 17 00:00:00 2001 From: jianoaix Date: Sat, 28 Jan 2023 00:46:38 +0000 Subject: [PATCH 04/27] Fix the merge --- .../execution/operators/actor_pool_map_operator.py | 8 +++++--- .../data/_internal/execution/operators/map_operator.py | 3 ++- .../execution/operators/task_pool_map_operator.py | 4 +++- 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/python/ray/data/_internal/execution/operators/actor_pool_map_operator.py b/python/ray/data/_internal/execution/operators/actor_pool_map_operator.py index fcee783b55a0..95cdc194bb76 100644 --- a/python/ray/data/_internal/execution/operators/actor_pool_map_operator.py +++ b/python/ray/data/_internal/execution/operators/actor_pool_map_operator.py @@ -63,6 +63,7 @@ def __init__( self._cls = None # Whether no more submittable bundles will be added. self._inputs_done = False + self._next_task_idx = 0 def start(self, options: ExecutionOptions): super().start(options) @@ -101,8 +102,9 @@ def _dispatch_tasks(self): bundle = self._bundle_queue.popleft() input_blocks = [block for block, _ in bundle.blocks] ref = actor.submit.options(num_returns="dynamic").remote( - self._transform_fn_ref, *input_blocks + self._transform_fn_ref, *input_blocks, self._next_task_idx ) + self._next_task_idx += 1 task = _TaskState(bundle) self._tasks[ref] = (task, actor) self._handle_task_submitted(task) @@ -210,9 +212,9 @@ def ready(self): return "ok" def submit( - self, fn: Callable[[Iterator[Block]], Iterator[Block]], *blocks: Block + self, fn: Callable[[Iterator[Block]], Iterator[Block]], task_idx, *blocks: Block ) -> Iterator[Union[Block, List[BlockMetadata]]]: - yield from _map_task(fn, *blocks) + yield from _map_task(fn, task_idx, *blocks) class _ActorPool: diff --git a/python/ray/data/_internal/execution/operators/map_operator.py b/python/ray/data/_internal/execution/operators/map_operator.py index 0fe279cc05c2..774b5de8810f 100644 --- a/python/ray/data/_internal/execution/operators/map_operator.py +++ b/python/ray/data/_internal/execution/operators/map_operator.py @@ -324,6 +324,7 @@ def to_metrics_dict(self) -> Dict[str, int]: def _map_task( fn: Callable[[Iterator[Block]], Iterator[Block]], + task_idx: int, *blocks: Block, ) -> Iterator[Union[Block, List[BlockMetadata]]]: """Remote function for a single operator task. @@ -339,7 +340,7 @@ def _map_task( """ output_metadata = [] stats = BlockExecStats.builder() - for b_out in fn(iter(blocks)): + for b_out in fn(iter(blocks), task_idx): # TODO(Clark): Add input file propagation from input blocks. 
m_out = BlockAccessor.for_block(b_out).get_metadata([], None) m_out.exec_stats = stats.build() diff --git a/python/ray/data/_internal/execution/operators/task_pool_map_operator.py b/python/ray/data/_internal/execution/operators/task_pool_map_operator.py index 342d7e2b5f94..5e9c86af78a3 100644 --- a/python/ray/data/_internal/execution/operators/task_pool_map_operator.py +++ b/python/ray/data/_internal/execution/operators/task_pool_map_operator.py @@ -44,14 +44,16 @@ def __init__( transform_fn, input_op, name, min_rows_per_bundle, ray_remote_args ) self._tasks: Dict[ObjectRef[ObjectRefGenerator], _TaskState] = {} + self._next_task_idx = 0 def _add_bundled_input(self, bundle: RefBundle): # Submit the task as a normal Ray task. map_task = cached_remote_fn(_map_task, num_returns="dynamic") input_blocks = [block for block, _ in bundle.blocks] ref = map_task.options(**self._ray_remote_args).remote( - self._transform_fn_ref, *input_blocks + self._transform_fn_ref, self._next_task_idx, *input_blocks ) + self._next_task_idx += 1 task = _TaskState(bundle) self._tasks[ref] = task self._handle_task_submitted(task) From 6a282573e7c0c72eed143e7dfbc747f9ad599e05 Mon Sep 17 00:00:00 2001 From: jianoaix Date: Sat, 28 Jan 2023 00:56:44 +0000 Subject: [PATCH 05/27] fix arg passing --- .../_internal/execution/operators/actor_pool_map_operator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/data/_internal/execution/operators/actor_pool_map_operator.py b/python/ray/data/_internal/execution/operators/actor_pool_map_operator.py index 95cdc194bb76..e4d9e66b252a 100644 --- a/python/ray/data/_internal/execution/operators/actor_pool_map_operator.py +++ b/python/ray/data/_internal/execution/operators/actor_pool_map_operator.py @@ -102,7 +102,7 @@ def _dispatch_tasks(self): bundle = self._bundle_queue.popleft() input_blocks = [block for block, _ in bundle.blocks] ref = actor.submit.options(num_returns="dynamic").remote( - self._transform_fn_ref, *input_blocks, self._next_task_idx + self._transform_fn_ref, self._next_task_idx, *input_blocks ) self._next_task_idx += 1 task = _TaskState(bundle) From 3082e52896c189071357d95681b803a36483377f Mon Sep 17 00:00:00 2001 From: jianoaix Date: Sat, 28 Jan 2023 01:02:16 +0000 Subject: [PATCH 06/27] lint --- python/ray/data/dataset.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index 5088624973e3..8cbe895b27e6 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -103,7 +103,6 @@ ParquetDatasource, ReadTask, TFRecordDatasource, - WriteResult, ) from ray.data.random_access_dataset import RandomAccessDataset from ray.data.row import TableRow @@ -2677,7 +2676,7 @@ def write_datasource( soft=False, ) - def transform(blocks: Iterable[Block], task_idx, fn) -> []: + def transform(blocks: Iterable[Block], task_idx, fn): try: datasource.do_write(blocks, task_idx, **write_args) datasource.on_write_complete([]) From 5ddfdc04bd87ba3c775ab2a573efa72e58106730 Mon Sep 17 00:00:00 2001 From: jianoaix Date: Tue, 31 Jan 2023 16:36:11 +0000 Subject: [PATCH 07/27] Reconcile taskcontext --- python/ray/data/_internal/execution/legacy_compat.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/python/ray/data/_internal/execution/legacy_compat.py b/python/ray/data/_internal/execution/legacy_compat.py index e5e4b92157d1..5bc7e3435035 100644 --- a/python/ray/data/_internal/execution/legacy_compat.py +++ 
b/python/ray/data/_internal/execution/legacy_compat.py @@ -31,6 +31,7 @@ Executor, PhysicalOperator, RefBundle, + TaskContext, ) @@ -147,7 +148,7 @@ def _blocks_to_input_buffer(blocks: BlockList, owns_blocks: bool) -> PhysicalOpe for b in i.blocks: trace_allocation(b[0], "legacy_compat.blocks_to_input_buf[0]") - def do_read(blocks: Iterator[Block], task_idx: int) -> Iterator[Block]: + def do_read(blocks: Iterator[Block], ctx: TaskContext) -> Iterator[Block]: for read_task in blocks: yield from read_task() @@ -214,8 +215,8 @@ def fn(item: Any) -> Any: fn_args += stage.fn_args fn_kwargs = stage.fn_kwargs or {} - def do_map(blocks: Iterator[Block], task_idx) -> Iterator[Block]: - yield from block_fn(blocks, task_idx, *fn_args, **fn_kwargs) + def do_map(blocks: Iterator[Block], ctx: TaskContext) -> Iterator[Block]: + yield from block_fn(blocks, ctx, *fn_args, **fn_kwargs) return MapOperator.create( do_map, @@ -232,7 +233,7 @@ def do_map(blocks: Iterator[Block], task_idx) -> Iterator[Block]: stage_name = stage.name def bulk_fn( - refs: List[RefBundle], task_idx: int + refs: List[RefBundle], ctx: TaskContext ) -> Tuple[List[RefBundle], StatsDict]: input_owned = all(b.owns_blocks for b in refs) if isinstance(stage, RandomizeBlocksStage): @@ -241,7 +242,7 @@ def bulk_fn( output_owned = True block_list = _bundles_to_block_list(refs) block_list, stats_dict = fn( - block_list, task_idx, input_owned, block_udf, remote_args + block_list, ctx, input_owned, block_udf, remote_args ) output = _block_list_to_bundles(block_list, owns_blocks=output_owned) if not stats_dict: From 3843ff4552b5c5f62aa905d4f2fd3b8e076a622f Mon Sep 17 00:00:00 2001 From: jianoaix Date: Tue, 31 Jan 2023 16:38:47 +0000 Subject: [PATCH 08/27] Reconcile taskcontext continued --- .../operators/actor_pool_map_operator.py | 255 ++++++++++++++++-- .../operators/all_to_all_operator.py | 11 +- .../execution/operators/map_operator.py | 26 +- .../operators/task_pool_map_operator.py | 4 +- python/ray/data/_internal/plan.py | 25 +- .../data/_internal/shuffle_and_partition.py | 4 +- python/ray/data/_internal/stage_impl.py | 9 +- 7 files changed, 280 insertions(+), 54 deletions(-) diff --git a/python/ray/data/_internal/execution/operators/actor_pool_map_operator.py b/python/ray/data/_internal/execution/operators/actor_pool_map_operator.py index e4d9e66b252a..d5a2bafa7a6d 100644 --- a/python/ray/data/_internal/execution/operators/actor_pool_map_operator.py +++ b/python/ray/data/_internal/execution/operators/actor_pool_map_operator.py @@ -1,14 +1,17 @@ import collections +from dataclasses import dataclass from typing import Dict, Any, Iterator, Callable, List, Tuple, Union, Optional import ray from ray.data.block import Block, BlockMetadata from ray.data.context import DatasetContext, DEFAULT_SCHEDULING_STRATEGY +from ray.data._internal.compute import ActorPoolStrategy from ray.data._internal.execution.interfaces import ( RefBundle, ExecutionResources, ExecutionOptions, PhysicalOperator, + TaskContext, ) from ray.data._internal.execution.operators.map_operator import ( MapOperator, @@ -26,37 +29,39 @@ def __init__( self, transform_fn: Callable[[Iterator[Block]], Iterator[Block]], input_op: PhysicalOperator, + autoscaling_policy: "AutoscalingPolicy", name: str = "ActorPoolMap", min_rows_per_bundle: Optional[int] = None, ray_remote_args: Optional[Dict[str, Any]] = None, - pool_size: int = 1, ): """Create an ActorPoolMapOperator instance. Args: transform_fn: The function to apply to each ref bundle input. 
input_op: Operator generating input data for this op. + autoscaling_policy: A policy controlling when the actor pool should be + scaled up and scaled down. name: The name of this operator. min_rows_per_bundle: The number of rows to gather per batch passed to the transform_fn, or None to use the block size. Setting the batch size is important for the performance of GPU-accelerated transform functions. The actual rows passed may be less if the dataset is small. ray_remote_args: Customize the ray remote args for this op's tasks. - pool_size: The desired size of the actor pool. """ super().__init__( transform_fn, input_op, name, min_rows_per_bundle, ray_remote_args ) self._ray_remote_args = self._apply_default_remote_args(self._ray_remote_args) - self._pool_size = pool_size + # Create autoscaling policy from compute strategy. + self._autoscaling_policy = autoscaling_policy # A map from task output futures to task state and the actor on which its # running. self._tasks: Dict[ ObjectRef[ObjectRefGenerator], Tuple[_TaskState, ray.actor.ActorHandle] ] = {} # A pool of running actors on which we can execute mapper tasks. - self._actor_pool = _ActorPool() + self._actor_pool = _ActorPool(autoscaling_policy._config.max_tasks_in_flight) # A queue of bundles awaiting dispatch to actors. self._bundle_queue = collections.deque() # Cached actor class. @@ -65,12 +70,15 @@ def __init__( self._inputs_done = False self._next_task_idx = 0 + def internal_queue_size(self) -> int: + return len(self._bundle_queue) + def start(self, options: ExecutionOptions): super().start(options) # Create the actor workers and add them to the pool. self._cls = ray.remote(**self._ray_remote_args)(_MapWorker) - for _ in range(self._pool_size): + for _ in range(self._autoscaling_policy.min_workers): self._start_actor() def _start_actor(self): @@ -101,17 +109,46 @@ def _dispatch_tasks(self): # Submit the map task. bundle = self._bundle_queue.popleft() input_blocks = [block for block, _ in bundle.blocks] + ctx = TaskContext(task_idx=self._next_task_idx) ref = actor.submit.options(num_returns="dynamic").remote( - self._transform_fn_ref, self._next_task_idx, *input_blocks + self._transform_fn_ref, ctx, *input_blocks ) self._next_task_idx += 1 task = _TaskState(bundle) self._tasks[ref] = (task, actor) self._handle_task_submitted(task) + if self._bundle_queue: + # Try to scale up if work remains in the work queue. + self._scale_up_if_needed() + else: + # Only try to scale down if the work queue has been fully consumed. + self._scale_down_if_needed() + + def _scale_up_if_needed(self): + """Try to scale up the pool if the autoscaling policy allows it.""" + while self._autoscaling_policy.should_scale_up( + num_total_workers=self._actor_pool.num_total_actors(), + num_running_workers=self._actor_pool.num_running_actors(), + ): + self._start_actor() + + def _scale_down_if_needed(self): + """Try to scale down the pool if the autoscaling policy allows it.""" # Kill inactive workers if there's no more work to do. self._kill_inactive_workers_if_done() + while self._autoscaling_policy.should_scale_down( + num_total_workers=self._actor_pool.num_total_actors(), + num_idle_workers=self._actor_pool.num_idle_actors(), + ): + killed = self._actor_pool.kill_inactive_actor() + if not killed: + # This scaledown is best-effort, only killing an inactive worker if an + # inactive worker exists. If there are no inactive workers to kill, we + # break out of the scale-down loop. 
+ break + def notify_work_completed( self, ref: Union[ObjectRef[ObjectRefGenerator], ray.ObjectRef] ): @@ -143,9 +180,8 @@ def inputs_done(self): # once the bundle queue is exhausted. self._inputs_done = True - # Manually trigger inactive worker termination in case the bundle queue is - # alread exhausted. - self._kill_inactive_workers_if_done() + # Try to scale pool down. + self._scale_down_if_needed() def _kill_inactive_workers_if_done(self): if self._inputs_done and not self._bundle_queue: @@ -168,13 +204,14 @@ def num_active_work_refs(self) -> int: return len(self._tasks) def progress_str(self) -> str: - return ( - f"{self._actor_pool.num_running_actors()} " - f"({self._actor_pool.num_pending_actors()} pending)" - ) + base = f"{self._actor_pool.num_running_actors()} actors" + pending = self._actor_pool.num_pending_actors() + if pending: + base += f" ({pending} pending)" + return base def base_resource_usage(self) -> ExecutionResources: - min_workers = self._pool_size + min_workers = self._autoscaling_policy.min_workers return ExecutionResources( cpu=self._ray_remote_args.get("num_cpus", 0) * min_workers, gpu=self._ray_remote_args.get("num_gpus", 0) * min_workers, @@ -190,7 +227,22 @@ def current_resource_usage(self) -> ExecutionResources: ) def incremental_resource_usage(self) -> ExecutionResources: - return ExecutionResources(cpu=0, gpu=0) + # We would only have nonzero incremental CPU/GPU resources if a new task would + # require scale-up to run. + if self._autoscaling_policy.should_scale_up( + num_total_workers=self._actor_pool.num_total_actors(), + num_running_workers=self._actor_pool.num_running_actors(), + ): + # A new task would trigger scale-up, so we include the actor resouce + # requests in the incremental resources. + num_cpus = self._ray_remote_args.get("num_cpus", 0) + num_gpus = self._ray_remote_args.get("num_gpus", 0) + else: + # A new task wouldn't trigger scale-up, so we consider the incremental + # compute resources to be 0. + num_cpus = 0 + num_gpus = 0 + return ExecutionResources(cpu=num_cpus, gpu=num_gpus) @staticmethod def _apply_default_remote_args(ray_remote_args: Dict[str, Any]) -> Dict[str, Any]: @@ -212,9 +264,134 @@ def ready(self): return "ok" def submit( - self, fn: Callable[[Iterator[Block]], Iterator[Block]], task_idx, *blocks: Block + self, + fn: Callable[[Iterator[Block], TaskContext], Iterator[Block]], + ctx, + *blocks: Block, ) -> Iterator[Union[Block, List[BlockMetadata]]]: - yield from _map_task(fn, task_idx, *blocks) + yield from _map_task(fn, ctx, *blocks) + + +# TODO(Clark): Promote this to a public config once we deprecate the legacy compute +# strategies. +@dataclass +class AutoscalingConfig: + """Configuration for an autoscaling actor pool.""" + + # Minimum number of workers in the actor pool. + min_workers: int + # Maximum number of workers in the actor pool. + max_workers: int + # Maximum number of tasks that can be in flight for a single worker. + # TODO(Clark): Have this informed by the prefetch_batches configuration, once async + # prefetching has been ported to this new actor pool. + max_tasks_in_flight: int = 2 + # Minimum ratio of ready workers to the total number of workers. If the pool is + # above this ratio, it will be allowed to be scaled up. + ready_to_total_workers_ratio: float = 0.8 + # Maximum ratio of idle workers to the total number of workers. If the pool goes + # above this ratio, the pool will be scaled down. 
+ idle_to_total_workers_ratio: float = 0.5 + + def __post_init__(self): + if self.min_workers < 1: + raise ValueError("min_workers must be >= 1, got: ", self.min_workers) + if self.max_workers is not None and self.min_workers > self.max_workers: + raise ValueError( + "min_workers must be <= max_workers, got: ", + self.min_workers, + self.max_workers, + ) + if self.max_tasks_in_flight < 1: + raise ValueError( + "max_tasks_in_flight must be >= 1, got: ", + self.max_tasks_in_flight, + ) + + @classmethod + def from_compute_strategy(cls, compute_strategy: ActorPoolStrategy): + """Convert a legacy ActorPoolStrategy to an AutoscalingConfig.""" + # TODO(Clark): Remove this once the legacy compute strategies are deprecated. + assert isinstance(compute_strategy, ActorPoolStrategy) + return cls( + min_workers=compute_strategy.min_size, + max_workers=compute_strategy.max_size, + max_tasks_in_flight=compute_strategy.max_tasks_in_flight_per_actor, + ready_to_total_workers_ratio=compute_strategy.ready_to_total_workers_ratio, + ) + + +class AutoscalingPolicy: + """Autoscaling policy for an actor pool, determining when the pool should be scaled + up and when it should be scaled down. + """ + + def __init__(self, autoscaling_config: "AutoscalingConfig"): + self._config = autoscaling_config + + @property + def min_workers(self) -> int: + """The minimum number of actors that must be in the actor pool.""" + return self._config.min_workers + + @property + def max_workers(self) -> int: + """The maximum number of actors that can be added to the actor pool.""" + return self._config.max_workers + + def should_scale_up(self, num_total_workers: int, num_running_workers: int) -> bool: + """Whether the actor pool should scale up by adding a new actor. + + Args: + num_total_workers: Total number of workers in actor pool. + num_running_workers: Number of currently running workers in actor pool. + + Returns: + Whether the actor pool should be scaled up by one actor. + """ + # TODO(Clark): Replace the ready-to-total-ratio heuristic with a a work queue + # heuristic such that scale-up is only triggered if the current pool doesn't + # have enough worker slots to process the work queue. + # TODO(Clark): Use profiling of the bundle arrival rate, worker startup + # time, and task execution time to tailor the work queue heuristic to the + # running workload and observed Ray performance. E.g. this could be done via an + # augmented EMA using a queueing model + return ( + # 1. The actor pool will not exceed the configured maximum size. + num_total_workers < self._config.max_workers + # TODO(Clark): Remove this once we have a good work queue heuristic and our + # resource-based backpressure is working well. + # 2. At least 80% of the workers in the pool have already started. This will + # ensure that workers will be launched in parallel while bounding the worker + # pool to requesting 125% of the cluster's available resources. + and num_running_workers / num_total_workers + > self._config.ready_to_total_workers_ratio + ) + + def should_scale_down( + self, + num_total_workers: int, + num_idle_workers: int, + ) -> bool: + """Whether the actor pool should scale down by terminating an inactive actor. + + Args: + num_total_workers: Total number of workers in actor pool. + num_idle_workers: Number of currently idle workers in the actor pool. + + Returns: + Whether the actor pool should be scaled down by one actor. + """ + # TODO(Clark): Add an idleness timeout-based scale-down. 
+ # TODO(Clark): Make the idleness timeout dynamically determined by bundle + # arrival rate, worker startup time, and task execution time. + return ( + # 1. The actor pool will not go below the configured minimum size. + num_total_workers > self._config.min_workers + # 2. The actor pool contains more than 50% idle workers. + and num_idle_workers / num_total_workers + > self._config.idle_to_total_workers_ratio + ) class _ActorPool: @@ -225,7 +402,8 @@ class _ActorPool: actors when the operator is done submitting work to the pool. """ - def __init__(self): + def __init__(self, max_tasks_in_flight: int = float("inf")): + self._max_tasks_in_flight = max_tasks_in_flight # Number of tasks in flight per actor. self._num_tasks_in_flight: Dict[ray.actor.ActorHandle, int] = {} # Actors that are not yet ready (still pending creation). @@ -273,7 +451,8 @@ def pick_actor(self) -> Optional[ray.actor.ActorHandle]: """Provides the least heavily loaded running actor in the pool for task submission. - None will be returned if all actors are still pending. + None will be returned if all actors are either at capacity (according to + max_tasks_in_flight) or are still pending. """ if not self._num_tasks_in_flight: # Actor pool is empty or all actors are still pending. @@ -283,8 +462,12 @@ def pick_actor(self) -> Optional[ray.actor.ActorHandle]: self._num_tasks_in_flight.keys(), key=lambda actor: self._num_tasks_in_flight[actor], ) - self._num_tasks_in_flight[actor] += 1 - return actor + if self._num_tasks_in_flight[actor] >= self._max_tasks_in_flight: + # All actors are at capacity. + return None + else: + self._num_tasks_in_flight[actor] += 1 + return actor def return_actor(self, actor: ray.actor.ActorHandle): """Returns the provided actor to the pool.""" @@ -326,6 +509,36 @@ def num_active_actors(self) -> int: for num_tasks_in_flight in self._num_tasks_in_flight.values() ) + def kill_inactive_actor(self) -> bool: + """Kills a single pending or idle actor, if any actors are pending/idle. + + Returns whether an inactive actor was actually killed. + """ + # We prioritize killing pending actors over idle actors to reduce actor starting + # churn. + killed = self._maybe_kill_pending_actor() + if not killed: + # If no pending actor was killed, so kill actor. + killed = self._maybe_kill_idle_actor() + return killed + + def _maybe_kill_pending_actor(self) -> bool: + if self._pending_actors: + # At least one pending actor, so kill first one. + self._kill_pending_actor(next(iter(self._pending_actors.keys()))) + return True + # No pending actors, so indicate to the caller that no actors were killed. + return False + + def _maybe_kill_idle_actor(self) -> bool: + for actor, tasks_in_flight in self._num_tasks_in_flight.items(): + if tasks_in_flight == 0: + # At least one idle actor, so kill first one found. + self._kill_running_actor(actor) + return True + # No idle actors, so indicate to the caller that no actors were killed. + return False + def kill_all_inactive_actors(self): """Kills all currently inactive actors and ensures that all actors that become idle in the future will be eagerly killed. 
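
Note (editorial, not part of the patch): the scale-up/scale-down heuristics added in the actor pool diff above reduce to two ratio checks against the pool bounds. Below is a self-contained sketch of that logic using hypothetical stand-in names; it deliberately avoids importing the internal AutoscalingConfig/AutoscalingPolicy classes, whose module paths may change. The two ratio defaults mirror the values introduced in this patch, while the worker bounds are arbitrary example values.

    from dataclasses import dataclass

    @dataclass
    class PoolConfigSketch:
        # Stand-in for AutoscalingConfig: the two ratios mirror the defaults in the
        # diff above; min_workers/max_workers here are arbitrary example values.
        min_workers: int = 1
        max_workers: int = 4
        ready_to_total_workers_ratio: float = 0.8
        idle_to_total_workers_ratio: float = 0.5

    def should_scale_up(cfg: PoolConfigSketch, total: int, running: int) -> bool:
        # Grow only while under the cap and while most requested actors have already
        # started, so actors launch in parallel without overshooting the cluster.
        return (
            total < cfg.max_workers
            and running / total > cfg.ready_to_total_workers_ratio
        )

    def should_scale_down(cfg: PoolConfigSketch, total: int, idle: int) -> bool:
        # Shrink only while above the floor and while more than half the pool is idle.
        return (
            total > cfg.min_workers
            and idle / total > cfg.idle_to_total_workers_ratio
        )

    cfg = PoolConfigSketch()
    assert should_scale_up(cfg, total=2, running=2)      # 2 < 4 and 2/2 > 0.8
    assert not should_scale_up(cfg, total=4, running=4)  # already at max_workers
    assert should_scale_down(cfg, total=4, idle=3)       # 3/4 of the pool is idle
    assert not should_scale_down(cfg, total=1, idle=1)   # already at min_workers
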
diff --git a/python/ray/data/_internal/execution/operators/all_to_all_operator.py b/python/ray/data/_internal/execution/operators/all_to_all_operator.py index 4756d0a7f053..fdc1d850e711 100644 --- a/python/ray/data/_internal/execution/operators/all_to_all_operator.py +++ b/python/ray/data/_internal/execution/operators/all_to_all_operator.py @@ -1,9 +1,11 @@ -from typing import List, Callable, Optional, Tuple +from typing import List, Optional from ray.data._internal.stats import StatsDict from ray.data._internal.execution.interfaces import ( + AllToAllTransformFn, RefBundle, PhysicalOperator, + TaskContext, ) @@ -15,7 +17,7 @@ class AllToAllOperator(PhysicalOperator): def __init__( self, - bulk_fn: Callable[[List[RefBundle]], Tuple[List[RefBundle], StatsDict]], + bulk_fn: AllToAllTransformFn, input_op: PhysicalOperator, num_outputs: Optional[int] = None, name: str = "AllToAll", @@ -51,9 +53,8 @@ def add_input(self, refs: RefBundle, input_index: int) -> None: self._input_buffer.append(refs) def inputs_done(self) -> None: - self._output_buffer, self._stats = self._bulk_fn( - self._input_buffer, self._next_task_index - ) + ctx = TaskContext(task_idx=self._next_task_index) + self._output_buffer, self._stats = self._bulk_fn(self._input_buffer, ctx) self._next_task_index += 1 self._input_buffer.clear() super().inputs_done() diff --git a/python/ray/data/_internal/execution/operators/map_operator.py b/python/ray/data/_internal/execution/operators/map_operator.py index 774b5de8810f..545b5ed5058c 100644 --- a/python/ray/data/_internal/execution/operators/map_operator.py +++ b/python/ray/data/_internal/execution/operators/map_operator.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass import itertools -from typing import List, Iterator, Any, Dict, Callable, Optional, Union +from typing import List, Iterator, Any, Dict, Optional, Union import ray from ray.data.block import Block, BlockAccessor, BlockMetadata, BlockExecStats @@ -15,6 +15,8 @@ ExecutionOptions, ExecutionResources, PhysicalOperator, + TaskContext, + MapTransformFn, ) from ray.data._internal.memory_tracing import trace_allocation from ray.data._internal.stats import StatsDict @@ -32,7 +34,7 @@ class MapOperator(PhysicalOperator, ABC): def __init__( self, - transform_fn: Callable[[Iterator[Block]], Iterator[Block]], + transform_fn: MapTransformFn, input_op: PhysicalOperator, name: str, min_rows_per_bundle: Optional[int], @@ -62,7 +64,7 @@ def __init__( @classmethod def create( cls, - transform_fn: Callable[[Iterator[Block]], Iterator[Block]], + transform_fn: MapTransformFn, input_op: PhysicalOperator, name: str = "Map", # TODO(ekl): slim down ComputeStrategy to only specify the compute @@ -107,19 +109,21 @@ def create( elif isinstance(compute_strategy, ActorPoolStrategy): from ray.data._internal.execution.operators.actor_pool_map_operator import ( ActorPoolMapOperator, + AutoscalingConfig, + AutoscalingPolicy, ) - pool_size = compute_strategy.max_size - if pool_size == float("inf"): - # Use min_size if max_size is unbounded (default). 
- pool_size = compute_strategy.min_size + autoscaling_config = AutoscalingConfig.from_compute_strategy( + compute_strategy + ) + autoscaling_policy = AutoscalingPolicy(autoscaling_config) return ActorPoolMapOperator( transform_fn, input_op, + autoscaling_policy=autoscaling_policy, name=name, min_rows_per_bundle=min_rows_per_bundle, ray_remote_args=ray_remote_args, - pool_size=pool_size, ) else: raise ValueError(f"Unsupported execution strategy {compute_strategy}") @@ -323,8 +327,8 @@ def to_metrics_dict(self) -> Dict[str, int]: def _map_task( - fn: Callable[[Iterator[Block]], Iterator[Block]], - task_idx: int, + fn: MapTransformFn, + ctx: TaskContext, *blocks: Block, ) -> Iterator[Union[Block, List[BlockMetadata]]]: """Remote function for a single operator task. @@ -340,7 +344,7 @@ def _map_task( """ output_metadata = [] stats = BlockExecStats.builder() - for b_out in fn(iter(blocks), task_idx): + for b_out in fn(iter(blocks), ctx): # TODO(Clark): Add input file propagation from input blocks. m_out = BlockAccessor.for_block(b_out).get_metadata([], None) m_out.exec_stats = stats.build() diff --git a/python/ray/data/_internal/execution/operators/task_pool_map_operator.py b/python/ray/data/_internal/execution/operators/task_pool_map_operator.py index 5e9c86af78a3..2004667c05ad 100644 --- a/python/ray/data/_internal/execution/operators/task_pool_map_operator.py +++ b/python/ray/data/_internal/execution/operators/task_pool_map_operator.py @@ -6,6 +6,7 @@ RefBundle, ExecutionResources, PhysicalOperator, + TaskContext, ) from ray.data._internal.execution.operators.map_operator import ( MapOperator, @@ -50,8 +51,9 @@ def _add_bundled_input(self, bundle: RefBundle): # Submit the task as a normal Ray task. map_task = cached_remote_fn(_map_task, num_returns="dynamic") input_blocks = [block for block, _ in bundle.blocks] + ctx = TaskContext(task_idx=self._next_task_idx) ref = map_task.options(**self._ray_remote_args).remote( - self._transform_fn_ref, self._next_task_idx, *input_blocks + self._transform_fn_ref, ctx, *input_blocks ) self._next_task_idx += 1 task = _TaskState(bundle) diff --git a/python/ray/data/_internal/plan.py b/python/ray/data/_internal/plan.py index 52b40cccb133..c1c30078daf5 100644 --- a/python/ray/data/_internal/plan.py +++ b/python/ray/data/_internal/plan.py @@ -29,6 +29,7 @@ is_task_compute, ) from ray.data._internal.dataset_logger import DatasetLogger +from ray.data._internal.execution.interfaces import TaskContext from ray.data._internal.lazy_block_list import LazyBlockList from ray.data._internal.stats import DatasetStats, DatasetStatsSummary from ray.data.block import Block @@ -904,7 +905,7 @@ def fuse(self, prev: Stage): def block_fn( blocks: Iterable[Block], - task_idx: int, + ctx: TaskContext, fn: UDF, *fn_args, **fn_kwargs, @@ -922,8 +923,8 @@ def block_fn( prev_fn_args = ( prev_fn_args if prev_fn_ is None else (prev_fn_,) + prev_fn_args ) - blocks = block_fn1(blocks, task_idx, *prev_fn_args, **prev_fn_kwargs) - return block_fn2(blocks, task_idx, *self_fn_args, **self_fn_kwargs) + blocks = block_fn1(blocks, ctx, *prev_fn_args, **prev_fn_kwargs) + return block_fn2(blocks, ctx, *self_fn_args, **self_fn_kwargs) return OneToOneStage( name, @@ -1016,22 +1017,20 @@ def fuse(self, prev: Stage): prev_block_fn = prev.block_fn if self.block_udf is None: - def block_udf(blocks: Iterable[Block], task_idx: int) -> Iterable[Block]: - yield from prev_block_fn( - blocks, task_idx, *prev_fn_args, **prev_fn_kwargs - ) + def block_udf(blocks: Iterable[Block], ctx: TaskContext) -> 
Iterable[Block]: + yield from prev_block_fn(blocks, ctx, *prev_fn_args, **prev_fn_kwargs) else: self_block_udf = self.block_udf - def block_udf(blocks: Iterable[Block], task_idx) -> Iterable[Block]: + def block_udf(blocks: Iterable[Block], ctx: TaskContext) -> Iterable[Block]: blocks = prev_block_fn( blocks, - task_idx, + ctx, *prev_fn_args, **prev_fn_kwargs, ) - yield from self_block_udf(blocks, task_idx) + yield from self_block_udf(blocks, ctx) return AllToAllStage( name, self.num_blocks, self.fn, True, block_udf, prev.ray_remote_args @@ -1108,7 +1107,7 @@ def _rewrite_read_stage( @_adapt_for_multiple_blocks def block_fn( - read_fn: Callable[[], Iterator[Block]], task_idx: int + read_fn: Callable[[], Iterator[Block]], ctx: TaskContext ) -> Iterator[Block]: for block in read_fn(): yield block @@ -1228,8 +1227,8 @@ def _adapt_for_multiple_blocks( fn: Callable[..., Iterable[Block]], ) -> Callable[..., Iterable[Block]]: @functools.wraps(fn) - def wrapper(blocks: Iterable[Block], task_idx, *args, **kwargs): + def wrapper(blocks: Iterable[Block], ctx: TaskContext, *args, **kwargs): for block in blocks: - yield from fn(block, task_idx, *args, **kwargs) + yield from fn(block, ctx, *args, **kwargs) return wrapper diff --git a/python/ray/data/_internal/shuffle_and_partition.py b/python/ray/data/_internal/shuffle_and_partition.py index f9287648cc98..db278705993f 100644 --- a/python/ray/data/_internal/shuffle_and_partition.py +++ b/python/ray/data/_internal/shuffle_and_partition.py @@ -4,6 +4,7 @@ import numpy as np from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder +from ray.data._internal.execution.interfaces import TaskContext from ray.data._internal.push_based_shuffle import PushBasedShufflePlan from ray.data._internal.shuffle import ShuffleOp, SimpleShufflePlan from ray.data.block import Block, BlockAccessor, BlockExecStats, BlockMetadata @@ -36,8 +37,9 @@ def map( ) -> List[Union[BlockMetadata, Block]]: stats = BlockExecStats.builder() if block_udf: + ctx = TaskContext(task_idx=idx) # TODO(ekl) note that this effectively disables block splitting. - blocks = list(block_udf([block], idx)) + blocks = list(block_udf([block], ctx)) if len(blocks) > 1: builder = BlockAccessor.for_block(blocks[0]).builder() for b in blocks: diff --git a/python/ray/data/_internal/stage_impl.py b/python/ray/data/_internal/stage_impl.py index c29127851e9b..da3c236a0360 100644 --- a/python/ray/data/_internal/stage_impl.py +++ b/python/ray/data/_internal/stage_impl.py @@ -8,6 +8,7 @@ SimpleShufflePartitionOp, ) from ray.data._internal.block_list import BlockList +from ray.data._internal.execution.interfaces import TaskContext from ray.data._internal.remote_fn import cached_remote_fn from ray.data._internal.sort import sort_impl from ray.data.context import DatasetContext @@ -32,7 +33,7 @@ def __init__(self, num_blocks: int, shuffle: bool): def do_shuffle( block_list, - task_idx: int, + ctx: TaskContext, clear_input_blocks: bool, block_udf, remote_args, @@ -99,7 +100,11 @@ def __init__( remote_args: Optional[Dict[str, Any]] = None, ): def do_shuffle( - block_list, task_idx: int, clear_input_blocks: bool, block_udf, remote_args + block_list, + ctx: TaskContext, + clear_input_blocks: bool, + block_udf, + remote_args, ): num_blocks = block_list.executed_num_blocks() # Blocking. 
if num_blocks == 0: From 84a74f03e628cde9ce1b6470f0e60f406b7cf39f Mon Sep 17 00:00:00 2001 From: jianoaix Date: Tue, 31 Jan 2023 16:57:47 +0000 Subject: [PATCH 09/27] Use task context in write op --- python/ray/data/dataset.py | 4 ++-- python/ray/data/datasource/datasource.py | 3 ++- python/ray/data/datasource/file_based_datasource.py | 5 +++-- python/ray/data/datasource/mongo_datasource.py | 3 ++- python/ray/data/tests/test_dataset_formats.py | 3 ++- 5 files changed, 11 insertions(+), 7 deletions(-) diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index f6035881629a..097bc92bc307 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -2626,9 +2626,9 @@ def write_datasource( soft=False, ) - def transform(blocks: Iterable[Block], task_idx, fn): + def transform(blocks: Iterable[Block], ctx, fn): try: - datasource.do_write(blocks, task_idx, **write_args) + datasource.do_write(blocks, ctx, **write_args) datasource.on_write_complete([]) except Exception as e: datasource.on_write_failed([], e) diff --git a/python/ray/data/datasource/datasource.py b/python/ray/data/datasource/datasource.py index d274eb6c284c..e2de30ec2949 100644 --- a/python/ray/data/datasource/datasource.py +++ b/python/ray/data/datasource/datasource.py @@ -6,6 +6,7 @@ import ray from ray.data._internal.arrow_block import ArrowRow from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder +from ray.data._internal.execution.interfaces import TaskContext from ray.data._internal.util import _check_pyarrow_version from ray.data.block import ( Block, @@ -344,7 +345,7 @@ def get_num_failed(self): def do_write( self, blocks: Iterable[Block], - task_idx: int, + ctx: TaskContext, **write_args, ): tasks = [] diff --git a/python/ray/data/datasource/file_based_datasource.py b/python/ray/data/datasource/file_based_datasource.py index 72ea194992c5..47f4b7067134 100644 --- a/python/ray/data/datasource/file_based_datasource.py +++ b/python/ray/data/datasource/file_based_datasource.py @@ -17,6 +17,7 @@ from ray.data._internal.arrow_block import ArrowRow from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder +from ray.data._internal.execution.interfaces import TaskContext from ray.data._internal.output_buffer import BlockOutputBuffer from ray.data._internal.util import _check_pyarrow_version, _resolve_custom_scheme from ray.data.block import Block, BlockAccessor @@ -259,7 +260,7 @@ def _convert_block_to_tabular_block( def do_write( self, blocks: Iterable[Block], - task_idx: int, + ctx: TaskContext, path: str, dataset_uuid: str, filesystem: Optional["pyarrow.fs.FileSystem"] = None, @@ -317,7 +318,7 @@ def write_block(write_path: str, block: Block): filesystem=filesystem, dataset_uuid=dataset_uuid, block=block, - block_index=task_idx, + block_index=ctx.task_idx, file_format=file_format, ) return write_block(write_path, block) diff --git a/python/ray/data/datasource/mongo_datasource.py b/python/ray/data/datasource/mongo_datasource.py index fa6286694d91..9c4440d57b8e 100644 --- a/python/ray/data/datasource/mongo_datasource.py +++ b/python/ray/data/datasource/mongo_datasource.py @@ -7,6 +7,7 @@ BlockAccessor, BlockMetadata, ) +from ray.data._internal.execution.interfaces import TaskContext from ray.types import ObjectRef from ray.util.annotations import PublicAPI from typing import Iterable @@ -40,7 +41,7 @@ def create_reader(self, **kwargs) -> Reader: def do_write( self, blocks: Iterable[Block], - task_idx: int, + ctx: TaskContext, uri: str, database: str, 
collection: str, diff --git a/python/ray/data/tests/test_dataset_formats.py b/python/ray/data/tests/test_dataset_formats.py index 000069d9333f..aedc55d19045 100644 --- a/python/ray/data/tests/test_dataset_formats.py +++ b/python/ray/data/tests/test_dataset_formats.py @@ -13,6 +13,7 @@ import ray from ray.data._internal.arrow_block import ArrowRow +from ray.data._internal.execution.interfaces import TaskContext from ray.data.block import Block, BlockAccessor from ray.data.datasource import ( Datasource, @@ -265,7 +266,7 @@ def get_num_failed(self): def do_write( self, blocks: Iterable[Block], - task_idx: int, + ctx: TaskContext, **write_args, ) -> WriteResult: data_sink = self.data_sink From bb2a47451544cfe4afad5dfa1c4c53f891816038 Mon Sep 17 00:00:00 2001 From: jianoaix Date: Tue, 31 Jan 2023 21:44:30 +0000 Subject: [PATCH 10/27] fix test --- python/ray/data/tests/test_size_estimation.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/python/ray/data/tests/test_size_estimation.py b/python/ray/data/tests/test_size_estimation.py index 35cdce44fed9..018a85394260 100644 --- a/python/ray/data/tests/test_size_estimation.py +++ b/python/ray/data/tests/test_size_estimation.py @@ -158,9 +158,16 @@ def test_split_read_parquet(ray_start_regular_shared, tmp_path): def gen(name): path = os.path.join(tmp_path, name) - ray.data.range(200000, parallelism=1).map( - lambda _: uuid.uuid4().hex - ).write_parquet(path) + ds = ( + ray.data.range(200000, parallelism=1) + .map(lambda _: uuid.uuid4().hex) + .fully_executed() + ) + # Fully execute the operations prior to write, because with + # parallelism=1, there is only one task; so the write operator + # will only write to one file, even though there are multiple + # blocks created by block splitting. 
+ ds.write_parquet(path) return ray.data.read_parquet(path, parallelism=200) # 20MiB From ad5f7c7506b11d3125dc26d3c5a029760264588f Mon Sep 17 00:00:00 2001 From: jianoaix Date: Tue, 31 Jan 2023 22:07:33 +0000 Subject: [PATCH 11/27] feedback: backward compatibility --- python/ray/data/dataset.py | 2 +- python/ray/data/datasource/datasource.py | 26 +++++- .../data/datasource/file_based_datasource.py | 81 ++++++++++++++++++- .../ray/data/datasource/mongo_datasource.py | 37 ++++++++- 4 files changed, 138 insertions(+), 8 deletions(-) diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index 097bc92bc307..55e2d96c8c1b 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -2628,7 +2628,7 @@ def write_datasource( def transform(blocks: Iterable[Block], ctx, fn): try: - datasource.do_write(blocks, ctx, **write_args) + datasource.direct_write(blocks, ctx, **write_args) datasource.on_write_complete([]) except Exception as e: datasource.on_write_failed([], e) diff --git a/python/ray/data/datasource/datasource.py b/python/ray/data/datasource/datasource.py index e2de30ec2949..307cb3e941e1 100644 --- a/python/ray/data/datasource/datasource.py +++ b/python/ray/data/datasource/datasource.py @@ -1,5 +1,5 @@ import builtins -from typing import Any, Callable, Generic, Iterable, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, Generic, Iterable, List, Optional, Tuple, Union import numpy as np @@ -51,7 +51,7 @@ def prepare_read(self, parallelism: int, **read_args) -> List["ReadTask[T]"]: """Deprecated: Please implement create_reader() instead.""" raise NotImplementedError - def do_write( + def direct_write( self, blocks: Iterable[Block], **write_args, @@ -67,6 +67,28 @@ def do_write( """ raise NotImplementedError + @Deprecated + def do_write( + self, + blocks: List[ObjectRef[Block]], + metadata: List[BlockMetadata], + ray_remote_args: Dict[str, Any], + **write_args, + ) -> List[ObjectRef[WriteResult]]: + """Launch Ray tasks for writing blocks out to the datasource. + + Args: + blocks: List of data block references. It is recommended that one + write task be generated per block. + metadata: List of block metadata. + ray_remote_args: Kwargs passed to ray.remote in the write tasks. + write_args: Additional kwargs to pass to the datasource impl. + + Returns: + A list of the output of the write tasks. + """ + raise NotImplementedError + def on_write_complete(self, write_results: List[WriteResult], **kwargs) -> None: """Callback for when a write job completes. 
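
Note (editorial, not part of the patch): to make the new single-task write path concrete, here is a minimal sketch of a custom sink written against direct_write as it stands after this patch. The class name and printed output are illustrative only, TaskContext is imported from an internal module whose path is taken from the diff and may move, and later patches in this series adjust the write result handling and callbacks further, so treat this as an illustration of the interface rather than a supported recipe.

    from typing import Iterable

    import ray
    from ray.data.block import Block, BlockAccessor
    from ray.data.datasource import Datasource
    from ray.data._internal.execution.interfaces import TaskContext

    class RowCountDatasource(Datasource):
        # Hypothetical sink: counts rows per write task instead of persisting them.
        def direct_write(
            self, blocks: Iterable[Block], ctx: TaskContext, **write_args
        ) -> str:
            # Runs inside a single write task; ctx.task_idx identifies the task.
            n = sum(BlockAccessor.for_block(b).num_rows() for b in blocks)
            print(f"write task {ctx.task_idx} received {n} rows")
            return "ok"

        def on_write_complete(self, write_results, **kwargs) -> None:
            print("write finished:", write_results)

    ray.data.range(100, parallelism=4).write_datasource(RowCountDatasource())
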
diff --git a/python/ray/data/datasource/file_based_datasource.py b/python/ray/data/datasource/file_based_datasource.py index 47f4b7067134..b57ae6a7a32e 100644 --- a/python/ray/data/datasource/file_based_datasource.py +++ b/python/ray/data/datasource/file_based_datasource.py @@ -17,12 +17,14 @@ from ray.data._internal.arrow_block import ArrowRow from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder +from ray.data._internal.block_list import BlockMetadata from ray.data._internal.execution.interfaces import TaskContext from ray.data._internal.output_buffer import BlockOutputBuffer +from ray.data._internal.remote_fn import cached_remote_fn from ray.data._internal.util import _check_pyarrow_version, _resolve_custom_scheme from ray.data.block import Block, BlockAccessor from ray.data.context import DatasetContext -from ray.data.datasource.datasource import Datasource, Reader, ReadTask +from ray.data.datasource.datasource import Datasource, Reader, ReadTask, WriteResult from ray.data.datasource.file_meta_provider import ( BaseFileMetadataProvider, DefaultFileMetadataProvider, @@ -34,7 +36,7 @@ ) from ray.types import ObjectRef -from ray.util.annotations import DeveloperAPI, PublicAPI +from ray.util.annotations import Deprecated, DeveloperAPI, PublicAPI from ray._private.utils import _add_creatable_buckets_param_if_s3_uri if TYPE_CHECKING: @@ -257,7 +259,7 @@ def _convert_block_to_tabular_block( "then you need to implement `_convert_block_to_tabular_block." ) - def do_write( + def direct_write( self, blocks: Iterable[Block], ctx: TaskContext, @@ -323,6 +325,79 @@ def write_block(write_path: str, block: Block): ) return write_block(write_path, block) + @Deprecated + def do_write( + self, + blocks: List[ObjectRef[Block]], + metadata: List[BlockMetadata], + path: str, + dataset_uuid: str, + filesystem: Optional["pyarrow.fs.FileSystem"] = None, + try_create_dir: bool = True, + open_stream_args: Optional[Dict[str, Any]] = None, + block_path_provider: BlockWritePathProvider = DefaultBlockWritePathProvider(), + write_args_fn: Callable[[], Dict[str, Any]] = lambda: {}, + _block_udf: Optional[Callable[[Block], Block]] = None, + ray_remote_args: Dict[str, Any] = None, + **write_args, + ) -> List[ObjectRef[WriteResult]]: + """Creates and returns write tasks for a file-based datasource.""" + path, filesystem = _resolve_paths_and_filesystem(path, filesystem) + path = path[0] + if try_create_dir: + # Arrow's S3FileSystem doesn't allow creating buckets by default, so we add + # a query arg enabling bucket creation if an S3 URI is provided. 
+ tmp = _add_creatable_buckets_param_if_s3_uri(path) + filesystem.create_dir(tmp, recursive=True) + filesystem = _wrap_s3_serialization_workaround(filesystem) + + _write_block_to_file = self._write_block + + if open_stream_args is None: + open_stream_args = {} + + if ray_remote_args is None: + ray_remote_args = {} + + def write_block(write_path: str, block: Block): + logger.debug(f"Writing {write_path} file.") + fs = filesystem + if isinstance(fs, _S3FileSystemWrapper): + fs = fs.unwrap() + if _block_udf is not None: + block = _block_udf(block) + + with fs.open_output_stream(write_path, **open_stream_args) as f: + _write_block_to_file( + f, + BlockAccessor.for_block(block), + writer_args_fn=write_args_fn, + **write_args, + ) + + write_block = cached_remote_fn(write_block).options(**ray_remote_args) + + file_format = self._FILE_EXTENSION + if isinstance(file_format, list): + file_format = file_format[0] + + write_tasks = [] + if not block_path_provider: + block_path_provider = DefaultBlockWritePathProvider() + for block_idx, block in enumerate(blocks): + write_path = block_path_provider( + path, + filesystem=filesystem, + dataset_uuid=dataset_uuid, + block=block, + block_index=block_idx, + file_format=file_format, + ) + write_task = write_block.remote(write_path, block) + write_tasks.append(write_task) + + return write_tasks + def _write_block( self, f: "pyarrow.NativeFile", diff --git a/python/ray/data/datasource/mongo_datasource.py b/python/ray/data/datasource/mongo_datasource.py index 9c4440d57b8e..76168f9ab886 100644 --- a/python/ray/data/datasource/mongo_datasource.py +++ b/python/ray/data/datasource/mongo_datasource.py @@ -1,6 +1,7 @@ import logging -from typing import Dict, List, Optional, TYPE_CHECKING +from typing import Any, Dict, List, Optional, TYPE_CHECKING +from ray.data._internal.remote_fn import cached_remote_fn from ray.data.datasource.datasource import Datasource, Reader, ReadTask, WriteResult from ray.data.block import ( Block, @@ -38,7 +39,7 @@ class MongoDatasource(Datasource): def create_reader(self, **kwargs) -> Reader: return _MongoDatasourceReader(**kwargs) - def do_write( + def direct_write( self, blocks: Iterable[Block], ctx: TaskContext, @@ -65,6 +66,38 @@ def write_block(uri: str, database: str, collection: str, block: Block): write_tasks.append(write_task) return write_tasks + def do_write( + self, + blocks: List[ObjectRef[Block]], + metadata: List[BlockMetadata], + ray_remote_args: Optional[Dict[str, Any]], + uri: str, + database: str, + collection: str, + ) -> List[ObjectRef[WriteResult]]: + import pymongo + + _validate_database_collection_exist( + pymongo.MongoClient(uri), database, collection + ) + + def write_block(uri: str, database: str, collection: str, block: Block): + from pymongoarrow.api import write + + block = BlockAccessor.for_block(block).to_arrow() + client = pymongo.MongoClient(uri) + write(client[database][collection], block) + + if ray_remote_args is None: + ray_remote_args = {} + + write_block = cached_remote_fn(write_block).options(**ray_remote_args) + write_tasks = [] + for block in blocks: + write_task = write_block.remote(uri, database, collection, block) + write_tasks.append(write_task) + return write_tasks + class _MongoDatasourceReader(Reader): def __init__( From a77053ade777690daa79f5885facc969f6ab2a71 Mon Sep 17 00:00:00 2001 From: jianoaix Date: Tue, 31 Jan 2023 22:18:33 +0000 Subject: [PATCH 12/27] fix --- python/ray/data/datasource/datasource.py | 2 +- python/ray/data/datasource/mongo_datasource.py | 5 +++-- 
python/ray/data/tests/test_dataset_formats.py | 2 +- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/python/ray/data/datasource/datasource.py b/python/ray/data/datasource/datasource.py index 307cb3e941e1..7cf70a14401f 100644 --- a/python/ray/data/datasource/datasource.py +++ b/python/ray/data/datasource/datasource.py @@ -364,7 +364,7 @@ def get_num_failed(self): self.data_sink = DataSink.remote() - def do_write( + def direct_write( self, blocks: Iterable[Block], ctx: TaskContext, diff --git a/python/ray/data/datasource/mongo_datasource.py b/python/ray/data/datasource/mongo_datasource.py index 76168f9ab886..f45f220884ea 100644 --- a/python/ray/data/datasource/mongo_datasource.py +++ b/python/ray/data/datasource/mongo_datasource.py @@ -1,16 +1,16 @@ import logging from typing import Any, Dict, List, Optional, TYPE_CHECKING -from ray.data._internal.remote_fn import cached_remote_fn from ray.data.datasource.datasource import Datasource, Reader, ReadTask, WriteResult from ray.data.block import ( Block, BlockAccessor, BlockMetadata, ) +from ray.data._internal.remote_fn import cached_remote_fn from ray.data._internal.execution.interfaces import TaskContext from ray.types import ObjectRef -from ray.util.annotations import PublicAPI +from ray.util.annotations import Deprecated, PublicAPI from typing import Iterable if TYPE_CHECKING: @@ -66,6 +66,7 @@ def write_block(uri: str, database: str, collection: str, block: Block): write_tasks.append(write_task) return write_tasks + @Deprecated def do_write( self, blocks: List[ObjectRef[Block]], diff --git a/python/ray/data/tests/test_dataset_formats.py b/python/ray/data/tests/test_dataset_formats.py index aedc55d19045..6a0283d765a4 100644 --- a/python/ray/data/tests/test_dataset_formats.py +++ b/python/ray/data/tests/test_dataset_formats.py @@ -263,7 +263,7 @@ def get_num_failed(self): self.data_sink = DataSink.remote() - def do_write( + def direct_write( self, blocks: Iterable[Block], ctx: TaskContext, From 554171a31e96d50029349159a6dea2019ec9056b Mon Sep 17 00:00:00 2001 From: jianoaix Date: Tue, 31 Jan 2023 23:41:36 +0000 Subject: [PATCH 13/27] test write fusion --- python/ray/data/dataset.py | 5 ++-- python/ray/data/tests/test_dataset_formats.py | 2 +- python/ray/data/tests/test_optimize.py | 25 +++++++++++++++++++ 3 files changed, 28 insertions(+), 4 deletions(-) diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index 55e2d96c8c1b..2f47783b04d7 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -2635,7 +2635,7 @@ def transform(blocks: Iterable[Block], ctx, fn): raise return [] - plan = self._plan.with_stage( + self._plan = self._plan.with_stage( OneToOneStage( "write", transform, @@ -2644,8 +2644,7 @@ def transform(blocks: Iterable[Block], ctx, fn): fn=lambda x: x, ) ) - ds = Dataset(plan, self._epoch, self._lazy) - ds.fully_executed() + self.fully_executed() def iterator(self) -> DatasetIterator: """Return a :class:`~ray.data.DatasetIterator` that diff --git a/python/ray/data/tests/test_dataset_formats.py b/python/ray/data/tests/test_dataset_formats.py index 6a0283d765a4..805a8e60faf9 100644 --- a/python/ray/data/tests/test_dataset_formats.py +++ b/python/ray/data/tests/test_dataset_formats.py @@ -179,7 +179,7 @@ def test_write_datasource(ray_start_regular_shared, pipelined): assert ray.get(output.data_sink.get_rows_written.remote()) == 10 ray.get(output.data_sink.set_enabled.remote(False)) - ds = maybe_pipeline(ds0, pipelined) + ds = maybe_pipeline(ray.data.range(10, parallelism=2), pipelined) 
with pytest.raises(ValueError): ds.write_datasource(output, ray_remote_args={"max_retries": 0}) assert ray.get(output.data_sink.get_num_ok.remote()) == 2 diff --git a/python/ray/data/tests/test_optimize.py b/python/ray/data/tests/test_optimize.py index 97b62885e1b2..b5d9673dbc05 100644 --- a/python/ray/data/tests/test_optimize.py +++ b/python/ray/data/tests/test_optimize.py @@ -349,6 +349,31 @@ def test_window_randomize_fusion(ray_start_regular_shared): assert "read->randomize_block_order->MapBatches(dummy_map)" in stats, stats +def test_write_fusion(ray_start_regular_shared, tmp_path): + context = DatasetContext.get_current() + context.optimize_fuse_stages = True + context.optimize_fuse_read_stages = True + context.optimize_fuse_shuffle_stages = True + + path = os.path.join(tmp_path, "out") + ds = ray.data.range(100).map_batches(lambda x: x) + ds.write_csv(path) + stats = ds.stats() + assert "read->MapBatches()->write" in stats, stats + + ds = ( + ray.data.range(100) + .map_batches(lambda x: x) + .random_shuffle() + .map_batches(lambda x: x) + ) + ds.write_csv(path) + stats = ds.stats() + assert "read->MapBatches()" in stats, stats + assert "random_shuffle" in stats, stats + assert "MapBatches()->write" in stats, stats + + def test_optimize_fuse(ray_start_regular_shared): context = DatasetContext.get_current() From 1ba1b9f0b467fd6f0049bc92a655ab28c5fc9cd2 Mon Sep 17 00:00:00 2001 From: jianoaix Date: Wed, 1 Feb 2023 05:44:59 +0000 Subject: [PATCH 14/27] Result of write operator; datasource callbacks --- python/ray/data/dataset.py | 31 ++++++++---- python/ray/data/datasource/datasource.py | 28 ++++------- .../data/datasource/file_based_datasource.py | 20 +++++--- .../ray/data/datasource/mongo_datasource.py | 37 ++++++++------ python/ray/data/tests/test_dataset_formats.py | 50 +++++++++---------- 5 files changed, 89 insertions(+), 77 deletions(-) diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index 2f47783b04d7..3325a2a3d355 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -109,6 +109,7 @@ ParquetDatasource, ReadTask, TFRecordDatasource, + WriteResult, ) from ray.data.random_access_dataset import RandomAccessDataset from ray.data.row import TableRow @@ -2626,16 +2627,15 @@ def write_datasource( soft=False, ) - def transform(blocks: Iterable[Block], ctx, fn): - try: - datasource.direct_write(blocks, ctx, **write_args) - datasource.on_write_complete([]) - except Exception as e: - datasource.on_write_failed([], e) - raise - return [] - - self._plan = self._plan.with_stage( + # The resulting Dataset from a write operator is a "status Dataset" + # indicating the result status of each write task launched by the + # write operator. Specifically, it will contain N blocks, where + # each block has a single row (i.e. WriteResult, indicating the result + # status of the write task) and N is the number of write tasks. 
+ def transform(blocks: Iterable[Block], ctx, fn) -> List[List[WriteResult]]: + return [[datasource.direct_write(blocks, ctx, **write_args)]] + + plan = self._plan.with_stage( OneToOneStage( "write", transform, @@ -2644,7 +2644,16 @@ def transform(blocks: Iterable[Block], ctx, fn): fn=lambda x: x, ) ) - self.fully_executed() + ds = Dataset(plan, self._epoch, self._lazy) + ds = ds.fully_executed() + results = list(ds.iter_rows()) + if all(not isinstance(w, Exception) for w in results): + datasource.on_write_complete(results) + else: + datasource.on_write_failed(results, None) + for r in results: + if isinstance(r, Exception): + raise r def iterator(self) -> DatasetIterator: """Return a :class:`~ray.data.DatasetIterator` that diff --git a/python/ray/data/datasource/datasource.py b/python/ray/data/datasource/datasource.py index 7cf70a14401f..2595cd79a989 100644 --- a/python/ray/data/datasource/datasource.py +++ b/python/ray/data/datasource/datasource.py @@ -334,8 +334,6 @@ class DataSink: def __init__(self): self.rows_written = 0 self.enabled = True - self.num_ok = 0 - self.num_failed = 0 def write(self, block: Block) -> str: block = BlockAccessor.for_block(block) @@ -350,39 +348,33 @@ def get_rows_written(self): def set_enabled(self, enabled): self.enabled = enabled - def increment_ok(self): - self.num_ok += 1 - - def get_num_ok(self): - return self.num_ok - - def increment_failed(self): - self.num_failed += 1 - - def get_num_failed(self): - return self.num_failed - self.data_sink = DataSink.remote() + self.num_ok = 0 + self.num_failed = 0 def direct_write( self, blocks: Iterable[Block], ctx: TaskContext, **write_args, - ): + ) -> WriteResult: tasks = [] for b in blocks: tasks.append(self.data_sink.write.remote(b)) - return ray.get(tasks) + try: + ray.get(tasks) + return "ok" + except Exception as e: + return e def on_write_complete(self, write_results: List[WriteResult]) -> None: assert all(w == "ok" for w in write_results), write_results - self.data_sink.increment_ok.remote() + self.num_ok += 1 def on_write_failed( self, write_results: List[ObjectRef[WriteResult]], error: Exception ) -> None: - self.data_sink.increment_failed.remote() + self.num_failed += 1 @DeveloperAPI diff --git a/python/ray/data/datasource/file_based_datasource.py b/python/ray/data/datasource/file_based_datasource.py index b57ae6a7a32e..3618b5bc3365 100644 --- a/python/ray/data/datasource/file_based_datasource.py +++ b/python/ray/data/datasource/file_based_datasource.py @@ -272,7 +272,7 @@ def direct_write( write_args_fn: Callable[[], Dict[str, Any]] = lambda: {}, _block_udf: Optional[Callable[[Block], Block]] = None, **write_args, - ): + ) -> WriteResult: """Creates and returns write tasks for a file-based datasource.""" path, filesystem = _resolve_paths_and_filesystem(path, filesystem) path = path[0] @@ -297,12 +297,18 @@ def write_block(write_path: str, block: Block): block = _block_udf(block) with fs.open_output_stream(write_path, **open_stream_args) as f: - _write_block_to_file( - f, - BlockAccessor.for_block(block), - writer_args_fn=write_args_fn, - **write_args, - ) + try: + _write_block_to_file( + f, + BlockAccessor.for_block(block), + writer_args_fn=write_args_fn, + **write_args, + ) + # TODO: decide if we want to return richer object when the task + # succeeds. 
+ return "ok" + except Exception as e: + return e file_format = self._FILE_EXTENSION if isinstance(file_format, list): diff --git a/python/ray/data/datasource/mongo_datasource.py b/python/ray/data/datasource/mongo_datasource.py index f45f220884ea..953c9a0f8a89 100644 --- a/python/ray/data/datasource/mongo_datasource.py +++ b/python/ray/data/datasource/mongo_datasource.py @@ -7,6 +7,7 @@ BlockAccessor, BlockMetadata, ) +from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder from ray.data._internal.remote_fn import cached_remote_fn from ray.data._internal.execution.interfaces import TaskContext from ray.types import ObjectRef @@ -46,25 +47,33 @@ def direct_write( uri: str, database: str, collection: str, - ) -> List[ObjectRef[WriteResult]]: + ) -> WriteResult: import pymongo - _validate_database_collection_exist( - pymongo.MongoClient(uri), database, collection - ) + try: + _validate_database_collection_exist( + pymongo.MongoClient(uri), database, collection + ) - def write_block(uri: str, database: str, collection: str, block: Block): - from pymongoarrow.api import write + def write_block(uri: str, database: str, collection: str, block: Block): + from pymongoarrow.api import write - block = BlockAccessor.for_block(block).to_arrow() - client = pymongo.MongoClient(uri) - write(client[database][collection], block) + block = BlockAccessor.for_block(block).to_arrow() + client = pymongo.MongoClient(uri) + write(client[database][collection], block) - write_tasks = [] - for block in blocks: - write_task = write_block(uri, database, collection, block) - write_tasks.append(write_task) - return write_tasks + builder = DelegatingBlockBuilder() + for block in blocks: + builder.add_block(block) + block = builder.build() + + write_block(uri, database, collection, block) + + # TODO: decide if we want to return richer object when the task + # succeeds. 
+ return "ok" + except Exception as e: + return e @Deprecated def do_write( diff --git a/python/ray/data/tests/test_dataset_formats.py b/python/ray/data/tests/test_dataset_formats.py index 805a8e60faf9..5bcaa9b748ce 100644 --- a/python/ray/data/tests/test_dataset_formats.py +++ b/python/ray/data/tests/test_dataset_formats.py @@ -174,20 +174,23 @@ def test_write_datasource(ray_start_regular_shared, pipelined): ds0 = ray.data.range(10, parallelism=2) ds = maybe_pipeline(ds0, pipelined) ds.write_datasource(output) - assert ray.get(output.data_sink.get_num_ok.remote()) == 2 - assert ray.get(output.data_sink.get_num_failed.remote()) == 0 + if pipelined: + assert output.num_ok == 2 + else: + assert output.num_ok == 1 + assert output.num_failed == 0 assert ray.get(output.data_sink.get_rows_written.remote()) == 10 ray.get(output.data_sink.set_enabled.remote(False)) ds = maybe_pipeline(ray.data.range(10, parallelism=2), pipelined) with pytest.raises(ValueError): ds.write_datasource(output, ray_remote_args={"max_retries": 0}) - assert ray.get(output.data_sink.get_num_ok.remote()) == 2 - assert ray.get(output.data_sink.get_rows_written.remote()) == 10 if pipelined: - assert ray.get(output.data_sink.get_num_failed.remote()) == 1 + assert output.num_ok == 2 else: - assert ray.get(output.data_sink.get_num_failed.remote()) == 2 + assert output.num_ok == 1 + assert output.num_failed == 1 + assert ray.get(output.data_sink.get_rows_written.remote()) == 10 def test_from_tf(ray_start_regular_shared): @@ -229,8 +232,6 @@ def __init__(self): self.rows_written = 0 self.enabled = True self.node_ids = set() - self.num_ok = 0 - self.num_failed = 0 def write(self, node_id: str, block: Block) -> str: block = BlockAccessor.for_block(block) @@ -249,19 +250,9 @@ def get_node_ids(self): def set_enabled(self, enabled): self.enabled = enabled - def increment_ok(self): - self.num_ok += 1 - - def get_num_ok(self): - return self.num_ok - - def increment_failed(self): - self.num_failed += 1 - - def get_num_failed(self): - return self.num_failed - self.data_sink = DataSink.remote() + self.num_ok = 0 + self.num_failed = 0 def direct_write( self, @@ -273,20 +264,25 @@ def direct_write( def write(b): node_id = ray.get_runtime_context().get_node_id() - return ray.get(data_sink.write.remote(node_id, b)) + return data_sink.write.remote(node_id, b) + tasks = [] for b in blocks: - result = write(b) - return result + tasks.append(write(b)) + try: + ray.get(tasks) + return "ok" + except Exception as e: + return e def on_write_complete(self, write_results: List[WriteResult]) -> None: assert all(w == "ok" for w in write_results), write_results - self.data_sink.increment_ok.remote() + self.num_ok += 1 def on_write_failed( self, write_results: List[ObjectRef[WriteResult]], error: Exception ) -> None: - self.data_sink.increment_failed.remote() + self.num_failed += 1 def test_write_datasource_ray_remote_args(ray_start_cluster): @@ -310,8 +306,8 @@ def get_node_id(): ds = ray.data.range(100, parallelism=10) # Pin write tasks to ds.write_datasource(output, ray_remote_args={"resources": {"bar": 1}}) - assert ray.get(output.data_sink.get_num_ok.remote()) == 10 - assert ray.get(output.data_sink.get_num_failed.remote()) == 0 + assert output.num_ok == 1 + assert output.num_failed == 0 assert ray.get(output.data_sink.get_rows_written.remote()) == 100 node_ids = ray.get(output.data_sink.get_node_ids.remote()) From 5ced2467458a55963611a0eb322aa49bf381693b Mon Sep 17 00:00:00 2001 From: jianoaix Date: Fri, 3 Feb 2023 21:15:01 +0000 Subject: [PATCH 15/27] 
Handle an empty list on failure --- python/ray/data/dataset.py | 26 ++++++-------- python/ray/data/datasource/datasource.py | 21 +++++------ .../data/datasource/file_based_datasource.py | 21 +++++------ .../ray/data/datasource/mongo_datasource.py | 35 +++++++++---------- python/ray/data/tests/test_dataset_formats.py | 19 +++------- 5 files changed, 49 insertions(+), 73 deletions(-) diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index 3325a2a3d355..46d01173b576 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -2627,12 +2627,11 @@ def write_datasource( soft=False, ) - # The resulting Dataset from a write operator is a "status Dataset" - # indicating the result status of each write task launched by the - # write operator. Specifically, it will contain N blocks, where - # each block has a single row (i.e. WriteResult, indicating the result - # status of the write task) and N is the number of write tasks. - def transform(blocks: Iterable[Block], ctx, fn) -> List[List[WriteResult]]: + # If the write operator succeeds, the resulting Dataset is a list of + # WriteResult (one element per write task). Otherwise, an error will + # be raised. The Datasource can handle execution outcomes with the + # on_write_complete() and on_write_failed(). + def transform(blocks: Iterable[Block], ctx, fn) -> List[ObjectRef[WriteResult]]: return [[datasource.direct_write(blocks, ctx, **write_args)]] plan = self._plan.with_stage( @@ -2645,15 +2644,12 @@ def transform(blocks: Iterable[Block], ctx, fn) -> List[List[WriteResult]]: ) ) ds = Dataset(plan, self._epoch, self._lazy) - ds = ds.fully_executed() - results = list(ds.iter_rows()) - if all(not isinstance(w, Exception) for w in results): - datasource.on_write_complete(results) - else: - datasource.on_write_failed(results, None) - for r in results: - if isinstance(r, Exception): - raise r + try: + ds = ds.fully_executed() + datasource.on_write_complete(ds._plan.execute().get_blocks()) + except Exception as e: + datasource.on_write_failed([], e) + raise def iterator(self) -> DatasetIterator: """Return a :class:`~ray.data.DatasetIterator` that diff --git a/python/ray/data/datasource/datasource.py b/python/ray/data/datasource/datasource.py index 2595cd79a989..4499f5fe9410 100644 --- a/python/ray/data/datasource/datasource.py +++ b/python/ray/data/datasource/datasource.py @@ -337,20 +337,16 @@ def __init__(self): def write(self, block: Block) -> str: block = BlockAccessor.for_block(block) - if not self.enabled: - raise ValueError("disabled") self.rows_written += block.num_rows() return "ok" def get_rows_written(self): return self.rows_written - def set_enabled(self, enabled): - self.enabled = enabled - self.data_sink = DataSink.remote() self.num_ok = 0 self.num_failed = 0 + self.enabled = True def direct_write( self, @@ -359,16 +355,15 @@ def direct_write( **write_args, ) -> WriteResult: tasks = [] + if not self.enabled: + raise ValueError("disabled") for b in blocks: tasks.append(self.data_sink.write.remote(b)) - try: - ray.get(tasks) - return "ok" - except Exception as e: - return e - - def on_write_complete(self, write_results: List[WriteResult]) -> None: - assert all(w == "ok" for w in write_results), write_results + ray.get(tasks) + return "ok" + + def on_write_complete(self, write_results: List[ObjectRef[WriteResult]]) -> None: + assert all(ray.get(w) == ["ok"] for w in write_results), write_results self.num_ok += 1 def on_write_failed( diff --git a/python/ray/data/datasource/file_based_datasource.py 
b/python/ray/data/datasource/file_based_datasource.py index 3618b5bc3365..dd48574262ce 100644 --- a/python/ray/data/datasource/file_based_datasource.py +++ b/python/ray/data/datasource/file_based_datasource.py @@ -297,18 +297,15 @@ def write_block(write_path: str, block: Block): block = _block_udf(block) with fs.open_output_stream(write_path, **open_stream_args) as f: - try: - _write_block_to_file( - f, - BlockAccessor.for_block(block), - writer_args_fn=write_args_fn, - **write_args, - ) - # TODO: decide if we want to return richer object when the task - # succeeds. - return "ok" - except Exception as e: - return e + _write_block_to_file( + f, + BlockAccessor.for_block(block), + writer_args_fn=write_args_fn, + **write_args, + ) + # TODO: decide if we want to return richer object when the task + # succeeds. + return "ok" file_format = self._FILE_EXTENSION if isinstance(file_format, list): diff --git a/python/ray/data/datasource/mongo_datasource.py b/python/ray/data/datasource/mongo_datasource.py index 953c9a0f8a89..82bbb2324cf6 100644 --- a/python/ray/data/datasource/mongo_datasource.py +++ b/python/ray/data/datasource/mongo_datasource.py @@ -50,30 +50,27 @@ def direct_write( ) -> WriteResult: import pymongo - try: - _validate_database_collection_exist( - pymongo.MongoClient(uri), database, collection - ) + _validate_database_collection_exist( + pymongo.MongoClient(uri), database, collection + ) - def write_block(uri: str, database: str, collection: str, block: Block): - from pymongoarrow.api import write + def write_block(uri: str, database: str, collection: str, block: Block): + from pymongoarrow.api import write - block = BlockAccessor.for_block(block).to_arrow() - client = pymongo.MongoClient(uri) - write(client[database][collection], block) + block = BlockAccessor.for_block(block).to_arrow() + client = pymongo.MongoClient(uri) + write(client[database][collection], block) - builder = DelegatingBlockBuilder() - for block in blocks: - builder.add_block(block) - block = builder.build() + builder = DelegatingBlockBuilder() + for block in blocks: + builder.add_block(block) + block = builder.build() - write_block(uri, database, collection, block) + write_block(uri, database, collection, block) - # TODO: decide if we want to return richer object when the task - # succeeds. - return "ok" - except Exception as e: - return e + # TODO: decide if we want to return richer object when the task + # succeeds. 
+ return "ok" @Deprecated def do_write( diff --git a/python/ray/data/tests/test_dataset_formats.py b/python/ray/data/tests/test_dataset_formats.py index 5bcaa9b748ce..0965b5be4b29 100644 --- a/python/ray/data/tests/test_dataset_formats.py +++ b/python/ray/data/tests/test_dataset_formats.py @@ -181,7 +181,7 @@ def test_write_datasource(ray_start_regular_shared, pipelined): assert output.num_failed == 0 assert ray.get(output.data_sink.get_rows_written.remote()) == 10 - ray.get(output.data_sink.set_enabled.remote(False)) + output.enabled = False ds = maybe_pipeline(ray.data.range(10, parallelism=2), pipelined) with pytest.raises(ValueError): ds.write_datasource(output, ray_remote_args={"max_retries": 0}) @@ -230,13 +230,10 @@ def __init__(self): class DataSink: def __init__(self): self.rows_written = 0 - self.enabled = True self.node_ids = set() def write(self, node_id: str, block: Block) -> str: block = BlockAccessor.for_block(block) - if not self.enabled: - raise ValueError("disabled") self.rows_written += block.num_rows() self.node_ids.add(node_id) return "ok" @@ -247,9 +244,6 @@ def get_rows_written(self): def get_node_ids(self): return self.node_ids - def set_enabled(self, enabled): - self.enabled = enabled - self.data_sink = DataSink.remote() self.num_ok = 0 self.num_failed = 0 @@ -269,14 +263,11 @@ def write(b): tasks = [] for b in blocks: tasks.append(write(b)) - try: - ray.get(tasks) - return "ok" - except Exception as e: - return e + ray.get(tasks) + return "ok" def on_write_complete(self, write_results: List[WriteResult]) -> None: - assert all(w == "ok" for w in write_results), write_results + assert all(ray.get(w) == ["ok"] for w in write_results), write_results self.num_ok += 1 def on_write_failed( @@ -304,7 +295,7 @@ def get_node_id(): output = NodeLoggerOutputDatasource() ds = ray.data.range(100, parallelism=10) - # Pin write tasks to + # Pin write tasks to node with "bar" resource. 
ds.write_datasource(output, ray_remote_args={"resources": {"bar": 1}}) assert output.num_ok == 1 assert output.num_failed == 0 From 43eca29521dc31dc745a9455093c5bc8a8e20730 Mon Sep 17 00:00:00 2001 From: jianoaix Date: Fri, 3 Feb 2023 21:31:11 +0000 Subject: [PATCH 16/27] execute the plan in-place in write_datasource --- python/ray/data/dataset.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index 46d01173b576..c8b0b28da8bc 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -2634,7 +2634,7 @@ def write_datasource( def transform(blocks: Iterable[Block], ctx, fn) -> List[ObjectRef[WriteResult]]: return [[datasource.direct_write(blocks, ctx, **write_args)]] - plan = self._plan.with_stage( + self._plan = self._plan.with_stage( OneToOneStage( "write", transform, @@ -2643,10 +2643,9 @@ def transform(blocks: Iterable[Block], ctx, fn) -> List[ObjectRef[WriteResult]]: fn=lambda x: x, ) ) - ds = Dataset(plan, self._epoch, self._lazy) try: - ds = ds.fully_executed() - datasource.on_write_complete(ds._plan.execute().get_blocks()) + self._plan.execute(force_read=True) + datasource.on_write_complete(self._plan.execute().get_blocks()) except Exception as e: datasource.on_write_failed([], e) raise From f25d54bad4b98ccea9373297a56ffb0268be1eba Mon Sep 17 00:00:00 2001 From: jianoaix Date: Fri, 3 Feb 2023 21:53:58 +0000 Subject: [PATCH 17/27] Keep write_datasource semantics diff-neutral regarding the plan --- python/ray/data/dataset.py | 6 +++--- python/ray/data/tests/test_optimize.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index c8b0b28da8bc..307180738c9c 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -2634,7 +2634,7 @@ def write_datasource( def transform(blocks: Iterable[Block], ctx, fn) -> List[ObjectRef[WriteResult]]: return [[datasource.direct_write(blocks, ctx, **write_args)]] - self._plan = self._plan.with_stage( + plan = self._plan.with_stage( OneToOneStage( "write", transform, @@ -2644,8 +2644,8 @@ def transform(blocks: Iterable[Block], ctx, fn) -> List[ObjectRef[WriteResult]]: ) ) try: - self._plan.execute(force_read=True) - datasource.on_write_complete(self._plan.execute().get_blocks()) + self._write_ds = Dataset(plan, self._epoch, self._lazy).fully_executed() + datasource.on_write_complete(self._write_ds._plan.execute().get_blocks()) except Exception as e: datasource.on_write_failed([], e) raise diff --git a/python/ray/data/tests/test_optimize.py b/python/ray/data/tests/test_optimize.py index b5d9673dbc05..317167faefb1 100644 --- a/python/ray/data/tests/test_optimize.py +++ b/python/ray/data/tests/test_optimize.py @@ -358,7 +358,7 @@ def test_write_fusion(ray_start_regular_shared, tmp_path): path = os.path.join(tmp_path, "out") ds = ray.data.range(100).map_batches(lambda x: x) ds.write_csv(path) - stats = ds.stats() + stats = ds._write_ds.stats() assert "read->MapBatches()->write" in stats, stats ds = ( @@ -368,7 +368,7 @@ def test_write_fusion(ray_start_regular_shared, tmp_path): .map_batches(lambda x: x) ) ds.write_csv(path) - stats = ds.stats() + stats = ds._write_ds.stats() assert "read->MapBatches()" in stats, stats assert "random_shuffle" in stats, stats assert "MapBatches()->write" in stats, stats From 1d58e13ae75571817ba1dfdfab925d96deafde91 Mon Sep 17 00:00:00 2001 From: jianoaix Date: Fri, 3 Feb 2023 23:41:10 +0000 Subject: [PATCH 18/27] disable the write_XX in new 
optimizer: it's not supported yet --- .../data/tests/test_execution_optimizer.py | 33 ++++++++++--------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/python/ray/data/tests/test_execution_optimizer.py b/python/ray/data/tests/test_execution_optimizer.py index 83b2c9b6cd01..1cc091391b7f 100644 --- a/python/ray/data/tests/test_execution_optimizer.py +++ b/python/ray/data/tests/test_execution_optimizer.py @@ -1,6 +1,4 @@ -import os import pytest -import pandas as pd import ray from ray.data._internal.execution.operators.map_operator import MapOperator @@ -545,20 +543,23 @@ def test_sort_e2e( ds = ds.sort() assert ds.take_all() == list(range(100)) - df = pd.DataFrame({"one": list(range(100)), "two": ["a"] * 100}) - ds = ray.data.from_pandas([df]) - path = os.path.join(local_path, "test_parquet_dir") - os.mkdir(path) - ds.write_parquet(path) - - ds = ray.data.read_parquet(path) - ds = ds.random_shuffle() - ds1 = ds.sort("one") - ds2 = ds.sort("one", descending=True) - r1 = ds1.select_columns(["one"]).take_all() - r2 = ds2.select_columns(["one"]).take_all() - assert [d["one"] for d in r1] == list(range(100)) - assert [d["one"] for d in r2] == list(reversed(range(100))) + # TODO: write_XXX and from_XXX are not supported yet in new execution plan. + # Re-enable once supported. + + # df = pd.DataFrame({"one": list(range(100)), "two": ["a"] * 100}) + # ds = ray.data.from_pandas([df]) + # path = os.path.join(local_path, "test_parquet_dir") + # os.mkdir(path) + # ds.write_parquet(path) + + # ds = ray.data.read_parquet(path) + # ds = ds.random_shuffle() + # ds1 = ds.sort("one") + # ds2 = ds.sort("one", descending=True) + # r1 = ds1.select_columns(["one"]).take_all() + # r2 = ds2.select_columns(["one"]).take_all() + # assert [d["one"] for d in r1] == list(range(100)) + # assert [d["one"] for d in r2] == list(reversed(range(100))) if __name__ == "__main__": From d309dbd517d0609b53b1d223c908ae975865d3fa Mon Sep 17 00:00:00 2001 From: jianoaix Date: Fri, 3 Feb 2023 23:45:02 +0000 Subject: [PATCH 19/27] fix comment --- python/ray/data/datasource/file_based_datasource.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/data/datasource/file_based_datasource.py b/python/ray/data/datasource/file_based_datasource.py index dd48574262ce..ad52c513462d 100644 --- a/python/ray/data/datasource/file_based_datasource.py +++ b/python/ray/data/datasource/file_based_datasource.py @@ -273,7 +273,7 @@ def direct_write( _block_udf: Optional[Callable[[Block], Block]] = None, **write_args, ) -> WriteResult: - """Creates and returns write tasks for a file-based datasource.""" + """Write blocks for a file-based datasource.""" path, filesystem = _resolve_paths_and_filesystem(path, filesystem) path = path[0] if try_create_dir: From 21a50db850a36c00e1ee50a73bcc9413825d9814 Mon Sep 17 00:00:00 2001 From: jianoaix Date: Sat, 4 Feb 2023 00:13:10 +0000 Subject: [PATCH 20/27] refactor: do_write() calls direct_write() to reduce code duplication --- .../data/datasource/file_based_datasource.py | 60 +++++-------------- python/ray/data/tests/test_dataset.py | 10 ++++ 2 files changed, 26 insertions(+), 44 deletions(-) diff --git a/python/ray/data/datasource/file_based_datasource.py b/python/ray/data/datasource/file_based_datasource.py index ad52c513462d..cbb8fab6e58a 100644 --- a/python/ray/data/datasource/file_based_datasource.py +++ b/python/ray/data/datasource/file_based_datasource.py @@ -345,58 +345,30 @@ def do_write( **write_args, ) -> List[ObjectRef[WriteResult]]: """Creates and returns 
write tasks for a file-based datasource.""" - path, filesystem = _resolve_paths_and_filesystem(path, filesystem) - path = path[0] - if try_create_dir: - # Arrow's S3FileSystem doesn't allow creating buckets by default, so we add - # a query arg enabling bucket creation if an S3 URI is provided. - tmp = _add_creatable_buckets_param_if_s3_uri(path) - filesystem.create_dir(tmp, recursive=True) - filesystem = _wrap_s3_serialization_workaround(filesystem) - - _write_block_to_file = self._write_block - - if open_stream_args is None: - open_stream_args = {} - if ray_remote_args is None: ray_remote_args = {} - def write_block(write_path: str, block: Block): - logger.debug(f"Writing {write_path} file.") - fs = filesystem - if isinstance(fs, _S3FileSystemWrapper): - fs = fs.unwrap() - if _block_udf is not None: - block = _block_udf(block) - - with fs.open_output_stream(write_path, **open_stream_args) as f: - _write_block_to_file( - f, - BlockAccessor.for_block(block), - writer_args_fn=write_args_fn, - **write_args, - ) + def write_block(block_idx, block): + ctx = TaskContext(task_idx=block_idx) + return self.direct_write( + [block], + ctx, + path, + dataset_uuid, + filesystem, + try_create_dir, + open_stream_args, + block_path_provider, + write_args_fn, + _block_udf, + **write_args, + ) write_block = cached_remote_fn(write_block).options(**ray_remote_args) - file_format = self._FILE_EXTENSION - if isinstance(file_format, list): - file_format = file_format[0] - write_tasks = [] - if not block_path_provider: - block_path_provider = DefaultBlockWritePathProvider() for block_idx, block in enumerate(blocks): - write_path = block_path_provider( - path, - filesystem=filesystem, - dataset_uuid=dataset_uuid, - block=block, - block_index=block_idx, - file_format=file_format, - ) - write_task = write_block.remote(write_path, block) + write_task = write_block.remote(block_idx, block) write_tasks.append(write_task) return write_tasks diff --git a/python/ray/data/tests/test_dataset.py b/python/ray/data/tests/test_dataset.py index 99ace502471f..c3bd9693150d 100644 --- a/python/ray/data/tests/test_dataset.py +++ b/python/ray/data/tests/test_dataset.py @@ -5517,6 +5517,16 @@ def test_ragged_tensors(ray_start_regular_shared): ] +def test_legacy_do_write(ray_start_regular_shared, tmp_path): + out = CSVDatasource() + ds = ray.data.range(100).fully_executed() + blocks, metadata = zip(*ds._plan.execute().get_blocks_with_metadata()) + out.do_write(blocks, metadata, tmp_path, ds._uuid) + + ds = ray.data.read_csv(tmp_path) + assert [e["value"] for e in ds.take_all()] == list(range(100)) + + class LoggerWarningCalled(Exception): """Custom exception used in test_warning_execute_with_no_cpu() and test_nowarning_execute_with_cpu(). 
Raised when the `logger.warning` method From a84e27bf83498f0c28fe89f205326216b81e61d3 Mon Sep 17 00:00:00 2001 From: jianoaix Date: Sat, 4 Feb 2023 00:44:40 +0000 Subject: [PATCH 21/27] refactor: for mongo datasource do_write --- .../ray/data/datasource/mongo_datasource.py | 19 ++++-------- python/ray/data/tests/test_mongo_dataset.py | 29 +++++++++++++++++++ 2 files changed, 34 insertions(+), 14 deletions(-) diff --git a/python/ray/data/datasource/mongo_datasource.py b/python/ray/data/datasource/mongo_datasource.py index 82bbb2324cf6..f5c240cc2a87 100644 --- a/python/ray/data/datasource/mongo_datasource.py +++ b/python/ray/data/datasource/mongo_datasource.py @@ -82,26 +82,17 @@ def do_write( database: str, collection: str, ) -> List[ObjectRef[WriteResult]]: - import pymongo - - _validate_database_collection_exist( - pymongo.MongoClient(uri), database, collection - ) - - def write_block(uri: str, database: str, collection: str, block: Block): - from pymongoarrow.api import write - - block = BlockAccessor.for_block(block).to_arrow() - client = pymongo.MongoClient(uri) - write(client[database][collection], block) + def write_block(block_idx, block): + ctx = TaskContext(task_idx=block_idx) + return self.direct_write([block], ctx, uri, database, collection) if ray_remote_args is None: ray_remote_args = {} write_block = cached_remote_fn(write_block).options(**ray_remote_args) write_tasks = [] - for block in blocks: - write_task = write_block.remote(uri, database, collection, block) + for idx, block in enumerate(blocks): + write_task = write_block.remote(idx, block) write_tasks.append(write_task) return write_tasks diff --git a/python/ray/data/tests/test_mongo_dataset.py b/python/ray/data/tests/test_mongo_dataset.py index d87d83421003..e55571a07fc1 100644 --- a/python/ray/data/tests/test_mongo_dataset.py +++ b/python/ray/data/tests/test_mongo_dataset.py @@ -257,6 +257,35 @@ def test_mongo_datasource(ray_start_regular_shared, start_mongo): df[df["int_field"] < 3].equals(ds.drop_columns(["_id"]).to_pandas()) +def test_legacy_do_write(ray_start_regular_shared, start_mongo): + client, mongo_url = start_mongo + foo_db = "foo-db" + foo_collection = "foo-collection" + foo = client[foo_db][foo_collection] + foo.delete_many({}) + + docs = [{"float_field": 2.0 * val, "int_field": val} for val in range(5)] + foo.insert_many(docs) + + docs = [{"float_field": 2.0 * val, "int_field": val} for val in range(5, 10)] + df = pd.DataFrame(docs).astype({"int_field": "int32"}) + ds = ray.data.from_pandas(df) + + out = MongoDatasource() + blocks, metadata = zip(*ds._plan.execute().get_blocks_with_metadata()) + tasks = out.do_write(blocks, metadata, {}, mongo_url, foo_db, foo_collection) + ray.get(tasks) + + ds = ray.data.read_mongo( + uri=mongo_url, + database=foo_db, + collection=foo_collection, + ) + docs = [{"float_field": 2.0 * val, "int_field": val} for val in range(10)] + df = pd.DataFrame(docs).astype({"int_field": "int32"}) + assert df.equals(ds.drop_columns(["_id"]).to_pandas()) + + if __name__ == "__main__": import sys From 8879df0692136733541a6ec3b11b2985412c3c90 Mon Sep 17 00:00:00 2001 From: jianoaix Date: Tue, 7 Feb 2023 00:17:08 +0000 Subject: [PATCH 22/27] backward compatible --- python/ray/data/dataset.py | 101 +++++++++++++++++++++++++++++-------- 1 file changed, 80 insertions(+), 21 deletions(-) diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index 7552124c8f8e..7c09a3c9a73b 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -1,6 +1,7 @@ import 
collections import itertools import logging +import os import sys import time import html @@ -76,6 +77,7 @@ ZipStage, SortStage, ) +from ray.data._internal.progress_bar import ProgressBar from ray.data._internal.remote_fn import cached_remote_fn from ray.data._internal.split import _split_at_index, _split_at_indices, _get_num_rows from ray.data._internal.stats import DatasetStats, DatasetStatsSummary @@ -113,6 +115,10 @@ TFRecordDatasource, WriteResult, ) +from ray.data.datasource.file_based_datasource import ( + _unwrap_arrow_serialization_workaround, + _wrap_arrow_serialization_workaround, +) from ray.data.random_access_dataset import RandomAccessDataset from ray.data.row import TableRow from ray.types import ObjectRef @@ -2647,28 +2653,68 @@ def write_datasource( soft=False, ) - # If the write operator succeeds, the resulting Dataset is a list of - # WriteResult (one element per write task). Otherwise, an error will - # be raised. The Datasource can handle execution outcomes with the - # on_write_complete() and on_write_failed(). - def transform(blocks: Iterable[Block], ctx, fn) -> List[ObjectRef[WriteResult]]: - return [[datasource.direct_write(blocks, ctx, **write_args)]] - - plan = self._plan.with_stage( - OneToOneStage( - "write", - transform, - "tasks", - ray_remote_args, - fn=lambda x: x, + if hasattr(datasource, "direct_write"): + # If the write operator succeeds, the resulting Dataset is a list of + # WriteResult (one element per write task). Otherwise, an error will + # be raised. The Datasource can handle execution outcomes with the + # on_write_complete() and on_write_failed(). + def transform( + blocks: Iterable[Block], ctx, fn + ) -> Iterable[Block]: + return [[datasource.direct_write(blocks, ctx, **write_args)]] + + plan = self._plan.with_stage( + OneToOneStage( + "write", + transform, + "tasks", + ray_remote_args, + fn=lambda x: x, + ) ) - ) - try: - self._write_ds = Dataset(plan, self._epoch, self._lazy).fully_executed() - datasource.on_write_complete(self._write_ds._plan.execute().get_blocks()) - except Exception as e: - datasource.on_write_failed([], e) - raise + try: + self._write_ds = Dataset(plan, self._epoch, self._lazy).fully_executed() + datasource.on_write_complete( + self._write_ds._plan.execute().get_blocks() + ) + except Exception as e: + datasource.on_write_failed([], e) + raise + else: + ctx = DatasetContext.get_current() + blocks, metadata = zip(*self._plan.execute().get_blocks_with_metadata()) + + # TODO(ekl) remove this feature flag. + if "RAY_DATASET_FORCE_LOCAL_METADATA" in os.environ: + write_results: List[ObjectRef[WriteResult]] = datasource.do_write( + blocks, metadata, ray_remote_args=ray_remote_args, **write_args + ) + else: + # Prepare write in a remote task so that in Ray client mode, we + # don't do metadata resolution from the client machine. 
+ do_write = cached_remote_fn( + _do_write, retry_exceptions=False, num_cpus=0 + ) + write_results: List[ObjectRef[WriteResult]] = ray.get( + do_write.remote( + datasource, + ctx, + blocks, + metadata, + ray_remote_args, + _wrap_arrow_serialization_workaround(write_args), + ) + ) + + progress = ProgressBar("Write Progress", len(write_results)) + try: + progress.block_until_complete(write_results) + datasource.on_write_complete(ray.get(write_results)) + except Exception as e: + datasource.on_write_failed(write_results, e) + raise + finally: + progress.close() def iterator(self) -> DatasetIterator: """Return a :class:`~ray.data.DatasetIterator` that @@ -4391,3 +4437,16 @@ def _sliding_window(iterable: Iterable, n: int): for elem in it: window.append(elem) yield tuple(window) + + +def _do_write( + ds: Datasource, + ctx: DatasetContext, + blocks: List[Block], + meta: List[BlockMetadata], + ray_remote_args: Dict[str, Any], + write_args: Dict[str, Any], +) -> List[ObjectRef[WriteResult]]: + write_args = _unwrap_arrow_serialization_workaround(write_args) + DatasetContext._set_current(ctx) + return ds.do_write(blocks, meta, ray_remote_args=ray_remote_args, **write_args) From d6873e18cd09556141a4044fbe24dbcb328c36c7 Mon Sep 17 00:00:00 2001 From: jianoaix Date: Tue, 7 Feb 2023 00:23:47 +0000 Subject: [PATCH 23/27] rename: direct_write -> write --- python/ray/data/dataset.py | 4 +- python/ray/data/datasource/datasource.py | 4 +- .../data/datasource/file_based_datasource.py | 47 +------------------ .../ray/data/datasource/mongo_datasource.py | 26 +--------- python/ray/data/tests/test_dataset_formats.py | 2 +- 5 files changed, 7 insertions(+), 76 deletions(-) diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index 7c09a3c9a73b..b5e0ebeaba6d 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -2653,7 +2653,7 @@ def write_datasource( soft=False, ) - if hasattr(datasource, "direct_write"): + if hasattr(datasource, "write"): # If the write operator succeeds, the resulting Dataset is a list of # WriteResult (one element per write task). Otherwise, an error will # be raised. The Datasource can handle execution outcomes with the @@ -2661,7 +2661,7 @@ def write_datasource( def transform( blocks: Iterable[Block], ctx, fn ) -> Iterable[Block]: - return [[datasource.direct_write(blocks, ctx, **write_args)]] + return [[datasource.write(blocks, ctx, **write_args)]] plan = self._plan.with_stage( OneToOneStage( diff --git a/python/ray/data/datasource/datasource.py b/python/ray/data/datasource/datasource.py index 4499f5fe9410..4bc7f4efe74d 100644 --- a/python/ray/data/datasource/datasource.py +++ b/python/ray/data/datasource/datasource.py @@ -51,7 +51,7 @@ def prepare_read(self, parallelism: int, **read_args) -> List["ReadTask[T]"]: """Deprecated: Please implement create_reader() instead.""" raise NotImplementedError - def direct_write( + def write( self, blocks: Iterable[Block], **write_args, @@ -348,7 +348,7 @@ def get_rows_written(self): self.num_failed = 0 self.enabled = True - def direct_write( + def write( self, blocks: Iterable[Block], ctx: TaskContext, diff --git a/python/ray/data/datasource/file_based_datasource.py b/python/ray/data/datasource/file_based_datasource.py index cbb8fab6e58a..0c2cf63573a8 100644 --- a/python/ray/data/datasource/file_based_datasource.py +++ b/python/ray/data/datasource/file_based_datasource.py @@ -259,7 +259,7 @@ def _convert_block_to_tabular_block( "then you need to implement `_convert_block_to_tabular_block." 
) - def direct_write( + def write( self, blocks: Iterable[Block], ctx: TaskContext, @@ -328,51 +328,6 @@ def write_block(write_path: str, block: Block): ) return write_block(write_path, block) - @Deprecated - def do_write( - self, - blocks: List[ObjectRef[Block]], - metadata: List[BlockMetadata], - path: str, - dataset_uuid: str, - filesystem: Optional["pyarrow.fs.FileSystem"] = None, - try_create_dir: bool = True, - open_stream_args: Optional[Dict[str, Any]] = None, - block_path_provider: BlockWritePathProvider = DefaultBlockWritePathProvider(), - write_args_fn: Callable[[], Dict[str, Any]] = lambda: {}, - _block_udf: Optional[Callable[[Block], Block]] = None, - ray_remote_args: Dict[str, Any] = None, - **write_args, - ) -> List[ObjectRef[WriteResult]]: - """Creates and returns write tasks for a file-based datasource.""" - if ray_remote_args is None: - ray_remote_args = {} - - def write_block(block_idx, block): - ctx = TaskContext(task_idx=block_idx) - return self.direct_write( - [block], - ctx, - path, - dataset_uuid, - filesystem, - try_create_dir, - open_stream_args, - block_path_provider, - write_args_fn, - _block_udf, - **write_args, - ) - - write_block = cached_remote_fn(write_block).options(**ray_remote_args) - - write_tasks = [] - for block_idx, block in enumerate(blocks): - write_task = write_block.remote(block_idx, block) - write_tasks.append(write_task) - - return write_tasks - def _write_block( self, f: "pyarrow.NativeFile", diff --git a/python/ray/data/datasource/mongo_datasource.py b/python/ray/data/datasource/mongo_datasource.py index f5c240cc2a87..e19d108b1b1f 100644 --- a/python/ray/data/datasource/mongo_datasource.py +++ b/python/ray/data/datasource/mongo_datasource.py @@ -40,7 +40,7 @@ class MongoDatasource(Datasource): def create_reader(self, **kwargs) -> Reader: return _MongoDatasourceReader(**kwargs) - def direct_write( + def write( self, blocks: Iterable[Block], ctx: TaskContext, @@ -72,30 +72,6 @@ def write_block(uri: str, database: str, collection: str, block: Block): # succeeds. 
return "ok" - @Deprecated - def do_write( - self, - blocks: List[ObjectRef[Block]], - metadata: List[BlockMetadata], - ray_remote_args: Optional[Dict[str, Any]], - uri: str, - database: str, - collection: str, - ) -> List[ObjectRef[WriteResult]]: - def write_block(block_idx, block): - ctx = TaskContext(task_idx=block_idx) - return self.direct_write([block], ctx, uri, database, collection) - - if ray_remote_args is None: - ray_remote_args = {} - - write_block = cached_remote_fn(write_block).options(**ray_remote_args) - write_tasks = [] - for idx, block in enumerate(blocks): - write_task = write_block.remote(idx, block) - write_tasks.append(write_task) - return write_tasks - class _MongoDatasourceReader(Reader): def __init__( diff --git a/python/ray/data/tests/test_dataset_formats.py b/python/ray/data/tests/test_dataset_formats.py index 0965b5be4b29..f8f5364c584d 100644 --- a/python/ray/data/tests/test_dataset_formats.py +++ b/python/ray/data/tests/test_dataset_formats.py @@ -248,7 +248,7 @@ def get_node_ids(self): self.num_ok = 0 self.num_failed = 0 - def direct_write( + def write( self, blocks: Iterable[Block], ctx: TaskContext, From 10ef980a07848caf64034b551374683ca0e0b11d Mon Sep 17 00:00:00 2001 From: jianoaix Date: Tue, 7 Feb 2023 00:43:59 +0000 Subject: [PATCH 24/27] unnecessary test removed --- python/ray/data/tests/test_dataset.py | 11 -------- python/ray/data/tests/test_mongo_dataset.py | 29 --------------------- 2 files changed, 40 deletions(-) diff --git a/python/ray/data/tests/test_dataset.py b/python/ray/data/tests/test_dataset.py index c3bd9693150d..15fc7a8c8bcb 100644 --- a/python/ray/data/tests/test_dataset.py +++ b/python/ray/data/tests/test_dataset.py @@ -5516,17 +5516,6 @@ def test_ragged_tensors(ray_start_regular_shared): ArrowVariableShapedTensorType(dtype=new_type, ndim=3), ] - -def test_legacy_do_write(ray_start_regular_shared, tmp_path): - out = CSVDatasource() - ds = ray.data.range(100).fully_executed() - blocks, metadata = zip(*ds._plan.execute().get_blocks_with_metadata()) - out.do_write(blocks, metadata, tmp_path, ds._uuid) - - ds = ray.data.read_csv(tmp_path) - assert [e["value"] for e in ds.take_all()] == list(range(100)) - - class LoggerWarningCalled(Exception): """Custom exception used in test_warning_execute_with_no_cpu() and test_nowarning_execute_with_cpu(). 
Raised when the `logger.warning` method diff --git a/python/ray/data/tests/test_mongo_dataset.py b/python/ray/data/tests/test_mongo_dataset.py index e55571a07fc1..d87d83421003 100644 --- a/python/ray/data/tests/test_mongo_dataset.py +++ b/python/ray/data/tests/test_mongo_dataset.py @@ -257,35 +257,6 @@ def test_mongo_datasource(ray_start_regular_shared, start_mongo): df[df["int_field"] < 3].equals(ds.drop_columns(["_id"]).to_pandas()) -def test_legacy_do_write(ray_start_regular_shared, start_mongo): - client, mongo_url = start_mongo - foo_db = "foo-db" - foo_collection = "foo-collection" - foo = client[foo_db][foo_collection] - foo.delete_many({}) - - docs = [{"float_field": 2.0 * val, "int_field": val} for val in range(5)] - foo.insert_many(docs) - - docs = [{"float_field": 2.0 * val, "int_field": val} for val in range(5, 10)] - df = pd.DataFrame(docs).astype({"int_field": "int32"}) - ds = ray.data.from_pandas(df) - - out = MongoDatasource() - blocks, metadata = zip(*ds._plan.execute().get_blocks_with_metadata()) - tasks = out.do_write(blocks, metadata, {}, mongo_url, foo_db, foo_collection) - ray.get(tasks) - - ds = ray.data.read_mongo( - uri=mongo_url, - database=foo_db, - collection=foo_collection, - ) - docs = [{"float_field": 2.0 * val, "int_field": val} for val in range(10)] - df = pd.DataFrame(docs).astype({"int_field": "int32"}) - assert df.equals(ds.drop_columns(["_id"]).to_pandas()) - - if __name__ == "__main__": import sys From 48e94158d1e652fc1e9cd552321c81a3dff46bbd Mon Sep 17 00:00:00 2001 From: jianoaix Date: Tue, 7 Feb 2023 00:50:57 +0000 Subject: [PATCH 25/27] fix --- python/ray/data/dataset.py | 6 ++---- python/ray/data/datasource/datasource.py | 4 ++-- python/ray/data/datasource/file_based_datasource.py | 4 +--- python/ray/data/datasource/mongo_datasource.py | 6 ++---- python/ray/data/tests/test_dataset.py | 1 + python/ray/data/tests/test_dataset_formats.py | 2 +- 6 files changed, 9 insertions(+), 14 deletions(-) diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index b5e0ebeaba6d..22cccd0f46fd 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -2658,9 +2658,7 @@ def write_datasource( # WriteResult (one element per write task). Otherwise, an error will # be raised. The Datasource can handle execution outcomes with the # on_write_complete() and on_write_failed(). 
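# ---------------------------------------------------------------------------
# [Editor's note — illustrative sketch, not part of the patch above or below]
# The comment block ending here spells out the write-path contract: write()
# runs once per write task and returns a WriteResult, and the Datasource can
# react to the overall outcome via on_write_complete() / on_write_failed().
# Below is a minimal sketch of a custom Datasource against that contract,
# assuming the write(blocks, ctx, ...) signature and TaskContext introduced
# earlier in this series. `RowCountSink` and its print-based behavior are
# hypothetical, purely for illustration.
from typing import Iterable, List

from ray.data.block import Block, BlockAccessor
from ray.data.datasource import Datasource, WriteResult
from ray.data._internal.execution.interfaces import TaskContext


class RowCountSink(Datasource):
    """Hypothetical sink that only counts rows, to show the write-path hooks."""

    def write(
        self, blocks: Iterable[Block], ctx: TaskContext, **write_args
    ) -> WriteResult:
        # Runs inside a write task; `blocks` are the blocks assigned to this task.
        rows = sum(BlockAccessor.for_block(b).num_rows() for b in blocks)
        print(f"write task {ctx.task_idx} handled {rows} rows")
        return "ok"

    def on_write_complete(self, write_results: List[WriteResult]) -> None:
        # The write stage materializes each task's result as a one-element list,
        # so a successful run yields [["ok"], ["ok"], ...].
        assert all(result == ["ok"] for result in write_results), write_results

    def on_write_failed(self, write_results, error: Exception) -> None:
        print(f"write failed: {error}")


# Usage sketch: ray.data.range(100).write_datasource(RowCountSink())
# [End editor's note]
# ---------------------------------------------------------------------------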
- def transform( - blocks: Iterable[Block], ctx, fn - ) -> Iterable[Block]: + def transform(blocks: Iterable[Block], ctx, fn) -> Iterable[Block]: return [[datasource.write(blocks, ctx, **write_args)]] plan = self._plan.with_stage( @@ -2675,7 +2673,7 @@ def transform( try: self._write_ds = Dataset(plan, self._epoch, self._lazy).fully_executed() datasource.on_write_complete( - self._write_ds._plan.execute().get_blocks() + ray.get(self._write_ds._plan.execute().get_blocks()) ) except Exception as e: datasource.on_write_failed([], e) diff --git a/python/ray/data/datasource/datasource.py b/python/ray/data/datasource/datasource.py index 4bc7f4efe74d..f8b409af1294 100644 --- a/python/ray/data/datasource/datasource.py +++ b/python/ray/data/datasource/datasource.py @@ -362,8 +362,8 @@ def write( ray.get(tasks) return "ok" - def on_write_complete(self, write_results: List[ObjectRef[WriteResult]]) -> None: - assert all(ray.get(w) == ["ok"] for w in write_results), write_results + def on_write_complete(self, write_results: List[WriteResult]) -> None: + assert all(w == ["ok"] for w in write_results), write_results self.num_ok += 1 def on_write_failed( diff --git a/python/ray/data/datasource/file_based_datasource.py b/python/ray/data/datasource/file_based_datasource.py index 0c2cf63573a8..2badee7158f6 100644 --- a/python/ray/data/datasource/file_based_datasource.py +++ b/python/ray/data/datasource/file_based_datasource.py @@ -17,10 +17,8 @@ from ray.data._internal.arrow_block import ArrowRow from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder -from ray.data._internal.block_list import BlockMetadata from ray.data._internal.execution.interfaces import TaskContext from ray.data._internal.output_buffer import BlockOutputBuffer -from ray.data._internal.remote_fn import cached_remote_fn from ray.data._internal.util import _check_pyarrow_version, _resolve_custom_scheme from ray.data.block import Block, BlockAccessor from ray.data.context import DatasetContext @@ -36,7 +34,7 @@ ) from ray.types import ObjectRef -from ray.util.annotations import Deprecated, DeveloperAPI, PublicAPI +from ray.util.annotations import DeveloperAPI, PublicAPI from ray._private.utils import _add_creatable_buckets_param_if_s3_uri if TYPE_CHECKING: diff --git a/python/ray/data/datasource/mongo_datasource.py b/python/ray/data/datasource/mongo_datasource.py index e19d108b1b1f..ef35497bbe7f 100644 --- a/python/ray/data/datasource/mongo_datasource.py +++ b/python/ray/data/datasource/mongo_datasource.py @@ -1,5 +1,5 @@ import logging -from typing import Any, Dict, List, Optional, TYPE_CHECKING +from typing import Dict, List, Optional, TYPE_CHECKING from ray.data.datasource.datasource import Datasource, Reader, ReadTask, WriteResult from ray.data.block import ( @@ -8,10 +8,8 @@ BlockMetadata, ) from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder -from ray.data._internal.remote_fn import cached_remote_fn from ray.data._internal.execution.interfaces import TaskContext -from ray.types import ObjectRef -from ray.util.annotations import Deprecated, PublicAPI +from ray.util.annotations import PublicAPI from typing import Iterable if TYPE_CHECKING: diff --git a/python/ray/data/tests/test_dataset.py b/python/ray/data/tests/test_dataset.py index 15fc7a8c8bcb..99ace502471f 100644 --- a/python/ray/data/tests/test_dataset.py +++ b/python/ray/data/tests/test_dataset.py @@ -5516,6 +5516,7 @@ def test_ragged_tensors(ray_start_regular_shared): ArrowVariableShapedTensorType(dtype=new_type, ndim=3), ] + 
class LoggerWarningCalled(Exception): """Custom exception used in test_warning_execute_with_no_cpu() and test_nowarning_execute_with_cpu(). Raised when the `logger.warning` method diff --git a/python/ray/data/tests/test_dataset_formats.py b/python/ray/data/tests/test_dataset_formats.py index f8f5364c584d..94e07ad126a2 100644 --- a/python/ray/data/tests/test_dataset_formats.py +++ b/python/ray/data/tests/test_dataset_formats.py @@ -267,7 +267,7 @@ def write(b): return "ok" def on_write_complete(self, write_results: List[WriteResult]) -> None: - assert all(ray.get(w) == ["ok"] for w in write_results), write_results + assert all(w == ["ok"] for w in write_results), write_results self.num_ok += 1 def on_write_failed( From 87dc9252e971aca026dad9ec456f37af8f8dca06 Mon Sep 17 00:00:00 2001 From: jianoaix Date: Tue, 7 Feb 2023 16:21:23 +0000 Subject: [PATCH 26/27] deprecation message/logging --- python/ray/data/dataset.py | 4 ++++ python/ray/data/datasource/datasource.py | 8 +++++--- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index 22cccd0f46fd..41e447403e88 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -2684,6 +2684,10 @@ def transform(blocks: Iterable[Block], ctx, fn) -> Iterable[Block]: # TODO(ekl) remove this feature flag. if "RAY_DATASET_FORCE_LOCAL_METADATA" in os.environ: + logger.warning( + "RAY_DATASET_FORCE_LOCAL_METADATA is deprecated in " + "Ray 2.4 and will be removed in future release." + ) write_results: List[ObjectRef[WriteResult]] = datasource.do_write( blocks, metadata, ray_remote_args=ray_remote_args, **write_args ) diff --git a/python/ray/data/datasource/datasource.py b/python/ray/data/datasource/datasource.py index f8b409af1294..0919ffba8c8d 100644 --- a/python/ray/data/datasource/datasource.py +++ b/python/ray/data/datasource/datasource.py @@ -32,7 +32,7 @@ class Datasource(Generic[T]): of how to implement readable and writable datasources. Datasource instances must be serializable, since ``create_reader()`` and - ``do_write()`` are called in remote tasks. + ``write()`` are called in remote tasks. """ def create_reader(self, **read_args) -> "Reader[T]": @@ -63,11 +63,13 @@ def write( write_args: Additional kwargs to pass to the datasource impl. Returns: - The output of the write tasks. + The output of the write task. """ raise NotImplementedError - @Deprecated + @Deprecated( + message="do_write() is deprecated in Ray 2.4. Use write() instead", warning=True + ) def do_write( self, blocks: List[ObjectRef[Block]], From b77ca8dec7ef200d8df15ddcd82553a781e437d8 Mon Sep 17 00:00:00 2001 From: jianoaix Date: Tue, 7 Feb 2023 18:09:28 +0000 Subject: [PATCH 27/27] deprecation logging --- python/ray/data/dataset.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index 41e447403e88..ed4d8f0ac3b9 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -2684,14 +2684,15 @@ def transform(blocks: Iterable[Block], ctx, fn) -> Iterable[Block]: # TODO(ekl) remove this feature flag. if "RAY_DATASET_FORCE_LOCAL_METADATA" in os.environ: - logger.warning( - "RAY_DATASET_FORCE_LOCAL_METADATA is deprecated in " - "Ray 2.4 and will be removed in future release." 
- ) write_results: List[ObjectRef[WriteResult]] = datasource.do_write( blocks, metadata, ray_remote_args=ray_remote_args, **write_args ) else: + logger.warning( + "The Datasource.do_write() is deprecated in " + "Ray 2.4 and will be removed in future release. Use " + "Datasource.write() instead." + ) # Prepare write in a remote task so that in Ray client mode, we # don't do metadata resolution from the client machine. do_write = cached_remote_fn(