diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py
index 8b6fc9c5bed6..ed4d8f0ac3b9 100644
--- a/python/ray/data/dataset.py
+++ b/python/ray/data/dataset.py
@@ -2640,8 +2640,6 @@ def write_datasource(
             ray_remote_args: Kwargs passed to ray.remote in the write tasks.
             write_args: Additional write args to pass to the datasource.
         """
-
-        ctx = DatasetContext.get_current()
         if ray_remote_args is None:
             ray_remote_args = {}
         path = write_args.get("path", None)
@@ -2655,37 +2653,71 @@ def write_datasource(
                 soft=False,
             )

-        blocks, metadata = zip(*self._plan.execute().get_blocks_with_metadata())
-
-        # TODO(ekl) remove this feature flag.
-        if "RAY_DATASET_FORCE_LOCAL_METADATA" in os.environ:
-            write_results: List[ObjectRef[WriteResult]] = datasource.do_write(
-                blocks, metadata, ray_remote_args=ray_remote_args, **write_args
-            )
-        else:
-            # Prepare write in a remote task so that in Ray client mode, we
-            # don't do metadata resolution from the client machine.
-            do_write = cached_remote_fn(_do_write, retry_exceptions=False, num_cpus=0)
-            write_results: List[ObjectRef[WriteResult]] = ray.get(
-                do_write.remote(
-                    datasource,
-                    ctx,
-                    blocks,
-                    metadata,
+        if hasattr(datasource, "write"):
+            # If the write operator succeeds, the resulting Dataset is a list of
+            # WriteResult (one element per write task). Otherwise, an error will
+            # be raised. The Datasource can handle execution outcomes with
+            # on_write_complete() and on_write_failed().
+            def transform(blocks: Iterable[Block], ctx, fn) -> Iterable[Block]:
+                return [[datasource.write(blocks, ctx, **write_args)]]
+
+            plan = self._plan.with_stage(
+                OneToOneStage(
+                    "write",
+                    transform,
+                    "tasks",
                     ray_remote_args,
-                    _wrap_arrow_serialization_workaround(write_args),
+                    fn=lambda x: x,
                 )
             )
+            try:
+                self._write_ds = Dataset(plan, self._epoch, self._lazy).fully_executed()
+                datasource.on_write_complete(
+                    ray.get(self._write_ds._plan.execute().get_blocks())
+                )
+            except Exception as e:
+                datasource.on_write_failed([], e)
+                raise
+        else:
+            ctx = DatasetContext.get_current()
+            blocks, metadata = zip(*self._plan.execute().get_blocks_with_metadata())

-        progress = ProgressBar("Write Progress", len(write_results))
-        try:
-            progress.block_until_complete(write_results)
-            datasource.on_write_complete(ray.get(write_results))
-        except Exception as e:
-            datasource.on_write_failed(write_results, e)
-            raise
-        finally:
-            progress.close()
+            # TODO(ekl) remove this feature flag.
+            if "RAY_DATASET_FORCE_LOCAL_METADATA" in os.environ:
+                write_results: List[ObjectRef[WriteResult]] = datasource.do_write(
+                    blocks, metadata, ray_remote_args=ray_remote_args, **write_args
+                )
+            else:
+                logger.warning(
+                    "Datasource.do_write() is deprecated in "
+                    "Ray 2.4 and will be removed in a future release. Use "
+                    "Datasource.write() instead."
+                )
+                # Prepare write in a remote task so that in Ray client mode, we
+                # don't do metadata resolution from the client machine.
+                do_write = cached_remote_fn(
+                    _do_write, retry_exceptions=False, num_cpus=0
+                )
+                write_results: List[ObjectRef[WriteResult]] = ray.get(
+                    do_write.remote(
+                        datasource,
+                        ctx,
+                        blocks,
+                        metadata,
+                        ray_remote_args,
+                        _wrap_arrow_serialization_workaround(write_args),
+                    )
+                )
+
+            progress = ProgressBar("Write Progress", len(write_results))
+            try:
+                progress.block_until_complete(write_results)
+                datasource.on_write_complete(ray.get(write_results))
+            except Exception as e:
+                datasource.on_write_failed(write_results, e)
+                raise
+            finally:
+                progress.close()

     def iterator(self) -> DatasetIterator:
         """Return a :class:`~ray.data.DatasetIterator` that
diff --git a/python/ray/data/datasource/datasource.py b/python/ray/data/datasource/datasource.py
index 98c0eb722f1c..0919ffba8c8d 100644
--- a/python/ray/data/datasource/datasource.py
+++ b/python/ray/data/datasource/datasource.py
@@ -6,6 +6,7 @@
 import ray
 from ray.data._internal.arrow_block import ArrowRow
 from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder
+from ray.data._internal.execution.interfaces import TaskContext
 from ray.data._internal.util import _check_pyarrow_version
 from ray.data.block import (
     Block,
@@ -31,7 +32,7 @@ class Datasource(Generic[T]):
     of how to implement readable and writable datasources.

     Datasource instances must be serializable, since ``create_reader()`` and
-    ``do_write()`` are called in remote tasks.
+    ``write()`` are called in remote tasks.
     """

     def create_reader(self, **read_args) -> "Reader[T]":
@@ -50,6 +51,25 @@ def prepare_read(self, parallelism: int, **read_args) -> List["ReadTask[T]"]:
         """Deprecated: Please implement create_reader() instead."""
         raise NotImplementedError

+    def write(
+        self,
+        blocks: Iterable[Block],
+        **write_args,
+    ) -> WriteResult:
+        """Write blocks out to the datasource. This is used by a single write task.
+
+        Args:
+            blocks: Iterable of data blocks.
+            write_args: Additional kwargs to pass to the datasource impl.
+
+        Returns:
+            The output of the write task.
+        """
+        raise NotImplementedError
+
+    @Deprecated(
+        message="do_write() is deprecated in Ray 2.4. Use write() instead", warning=True
+    )
     def do_write(
         self,
         blocks: List[ObjectRef[Block]],
@@ -319,35 +339,33 @@ def __init__(self):
             def write(self, block: Block) -> str:
                 block = BlockAccessor.for_block(block)
-                if not self.enabled:
-                    raise ValueError("disabled")
                 self.rows_written += block.num_rows()
                 return "ok"

             def get_rows_written(self):
                 return self.rows_written

-            def set_enabled(self, enabled):
-                self.enabled = enabled
-
         self.data_sink = DataSink.remote()
         self.num_ok = 0
         self.num_failed = 0
+        self.enabled = True

-    def do_write(
+    def write(
         self,
-        blocks: List[ObjectRef[Block]],
-        metadata: List[BlockMetadata],
-        ray_remote_args: Dict[str, Any],
+        blocks: Iterable[Block],
+        ctx: TaskContext,
         **write_args,
-    ) -> List[ObjectRef[WriteResult]]:
+    ) -> WriteResult:
         tasks = []
+        if not self.enabled:
+            raise ValueError("disabled")
         for b in blocks:
             tasks.append(self.data_sink.write.remote(b))
-        return tasks
+        ray.get(tasks)
+        return "ok"

     def on_write_complete(self, write_results: List[WriteResult]) -> None:
-        assert all(w == "ok" for w in write_results), write_results
+        assert all(w == ["ok"] for w in write_results), write_results
         self.num_ok += 1

     def on_write_failed(
diff --git a/python/ray/data/datasource/file_based_datasource.py b/python/ray/data/datasource/file_based_datasource.py
index e4cd15550ba3..2badee7158f6 100644
--- a/python/ray/data/datasource/file_based_datasource.py
+++ b/python/ray/data/datasource/file_based_datasource.py
@@ -16,9 +16,9 @@
 )

 from ray.data._internal.arrow_block import ArrowRow
-from ray.data._internal.block_list import BlockMetadata
+from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder
+from ray.data._internal.execution.interfaces import TaskContext
 from ray.data._internal.output_buffer import BlockOutputBuffer
-from ray.data._internal.remote_fn import cached_remote_fn
 from ray.data._internal.util import _check_pyarrow_version, _resolve_custom_scheme
 from ray.data.block import Block, BlockAccessor
 from ray.data.context import DatasetContext
@@ -60,7 +60,7 @@ def _get_write_path_for_block(
         *,
         filesystem: Optional["pyarrow.fs.FileSystem"] = None,
         dataset_uuid: Optional[str] = None,
-        block: Optional[ObjectRef[Block]] = None,
+        block: Optional[Block] = None,
         block_index: Optional[int] = None,
         file_format: Optional[str] = None,
     ) -> str:
@@ -77,7 +77,7 @@ def _get_write_path_for_block(
                 write a file out to the write path returned.
             dataset_uuid: Unique identifier for the dataset that this block
                 belongs to.
-            block: Object reference to the block to write.
+            block: The block to write.
             block_index: Ordered index of the block to write within its parent
                 dataset.
             file_format: File format string for the block that can be used as
@@ -94,7 +94,7 @@ def __call__(
         *,
         filesystem: Optional["pyarrow.fs.FileSystem"] = None,
         dataset_uuid: Optional[str] = None,
-        block: Optional[ObjectRef[Block]] = None,
+        block: Optional[Block] = None,
         block_index: Optional[int] = None,
         file_format: Optional[str] = None,
     ) -> str:
@@ -257,10 +257,10 @@ def _convert_block_to_tabular_block(
             "then you need to implement `_convert_block_to_tabular_block."
         )

-    def do_write(
+    def write(
         self,
-        blocks: List[ObjectRef[Block]],
-        metadata: List[BlockMetadata],
+        blocks: Iterable[Block],
+        ctx: TaskContext,
         path: str,
         dataset_uuid: str,
         filesystem: Optional["pyarrow.fs.FileSystem"] = None,
@@ -269,10 +269,9 @@ def do_write(
         block_path_provider: BlockWritePathProvider = DefaultBlockWritePathProvider(),
         write_args_fn: Callable[[], Dict[str, Any]] = lambda: {},
         _block_udf: Optional[Callable[[Block], Block]] = None,
-        ray_remote_args: Dict[str, Any] = None,
         **write_args,
-    ) -> List[ObjectRef[WriteResult]]:
-        """Creates and returns write tasks for a file-based datasource."""
+    ) -> WriteResult:
+        """Write blocks for a file-based datasource."""
         path, filesystem = _resolve_paths_and_filesystem(path, filesystem)
         path = path[0]
         if try_create_dir:
@@ -287,9 +286,6 @@ def do_write(
         if open_stream_args is None:
             open_stream_args = {}

-        if ray_remote_args is None:
-            ray_remote_args = {}
-
         def write_block(write_path: str, block: Block):
             logger.debug(f"Writing {write_path} file.")
             fs = filesystem
@@ -305,29 +301,30 @@ def write_block(write_path: str, block: Block):
                 writer_args_fn=write_args_fn,
                 **write_args,
             )
-
-        write_block = cached_remote_fn(write_block).options(**ray_remote_args)
+            # TODO: decide if we want to return a richer object when the task
+            # succeeds.
+            return "ok"

         file_format = self._FILE_EXTENSION
         if isinstance(file_format, list):
             file_format = file_format[0]

-        write_tasks = []
+        builder = DelegatingBlockBuilder()
+        for block in blocks:
+            builder.add_block(block)
+        block = builder.build()
+
         if not block_path_provider:
             block_path_provider = DefaultBlockWritePathProvider()
-        for block_idx, block in enumerate(blocks):
-            write_path = block_path_provider(
-                path,
-                filesystem=filesystem,
-                dataset_uuid=dataset_uuid,
-                block=block,
-                block_index=block_idx,
-                file_format=file_format,
-            )
-            write_task = write_block.remote(write_path, block)
-            write_tasks.append(write_task)
-
-        return write_tasks
+        write_path = block_path_provider(
+            path,
+            filesystem=filesystem,
+            dataset_uuid=dataset_uuid,
+            block=block,
+            block_index=ctx.task_idx,
+            file_format=file_format,
+        )
+        return write_block(write_path, block)

     def _write_block(
         self,
diff --git a/python/ray/data/datasource/mongo_datasource.py b/python/ray/data/datasource/mongo_datasource.py
index f1153271c532..ef35497bbe7f 100644
--- a/python/ray/data/datasource/mongo_datasource.py
+++ b/python/ray/data/datasource/mongo_datasource.py
@@ -1,5 +1,5 @@
 import logging
-from typing import Any, Dict, List, Optional, TYPE_CHECKING
+from typing import Dict, List, Optional, TYPE_CHECKING

 from ray.data.datasource.datasource import Datasource, Reader, ReadTask, WriteResult
 from ray.data.block import (
@@ -7,9 +7,10 @@
     BlockAccessor,
     BlockMetadata,
 )
-from ray.data._internal.remote_fn import cached_remote_fn
-from ray.types import ObjectRef
+from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder
+from ray.data._internal.execution.interfaces import TaskContext
 from ray.util.annotations import PublicAPI
+from typing import Iterable

 if TYPE_CHECKING:
     import pymongoarrow.api
@@ -37,15 +38,14 @@ class MongoDatasource(Datasource):
     def create_reader(self, **kwargs) -> Reader:
         return _MongoDatasourceReader(**kwargs)

-    def do_write(
+    def write(
         self,
-        blocks: List[ObjectRef[Block]],
-        metadata: List[BlockMetadata],
-        ray_remote_args: Optional[Dict[str, Any]],
+        blocks: Iterable[Block],
+        ctx: TaskContext,
         uri: str,
         database: str,
         collection: str,
-    ) -> List[ObjectRef[WriteResult]]:
+    ) -> WriteResult:
         import pymongo

         _validate_database_collection_exist(
@@ -59,15 +59,16 @@ def write_block(uri: str, database: str, collection: str, block: Block):
             client = pymongo.MongoClient(uri)
             write(client[database][collection], block)

-        if ray_remote_args is None:
-            ray_remote_args = {}
-
-        write_block = cached_remote_fn(write_block).options(**ray_remote_args)
-        write_tasks = []
+        builder = DelegatingBlockBuilder()
         for block in blocks:
-            write_task = write_block.remote(uri, database, collection, block)
-            write_tasks.append(write_task)
-        return write_tasks
+            builder.add_block(block)
+        block = builder.build()
+
+        write_block(uri, database, collection, block)
+
+        # TODO: decide if we want to return a richer object when the task
+        # succeeds.
+        return "ok"


 class _MongoDatasourceReader(Reader):
diff --git a/python/ray/data/tests/conftest.py b/python/ray/data/tests/conftest.py
index 0d1cb84deb92..2cf6f8abe2d3 100644
--- a/python/ray/data/tests/conftest.py
+++ b/python/ray/data/tests/conftest.py
@@ -162,7 +162,7 @@ def _get_write_path_for_block(
         block_index=None,
         file_format=None,
     ):
-        num_rows = BlockAccessor.for_block(ray.get(block)).num_rows()
+        num_rows = BlockAccessor.for_block(block).num_rows()
         suffix = (
             f"{block_index:06}_{num_rows:02}_{dataset_uuid}" f".test.{file_format}"
         )
diff --git a/python/ray/data/tests/test_dataset_formats.py b/python/ray/data/tests/test_dataset_formats.py
index efbbef88886f..94e07ad126a2 100644
--- a/python/ray/data/tests/test_dataset_formats.py
+++ b/python/ray/data/tests/test_dataset_formats.py
@@ -1,5 +1,5 @@
 import os
-from typing import Any, Dict, List, Union
+from typing import List, Union

 import pandas as pd
 import pyarrow as pa
@@ -13,7 +13,8 @@
 import ray
 from ray.data._internal.arrow_block import ArrowRow
-from ray.data.block import Block, BlockAccessor, BlockMetadata
+from ray.data._internal.execution.interfaces import TaskContext
+from ray.data.block import Block, BlockAccessor
 from ray.data.datasource import (
     Datasource,
     DummyOutputDatasource,
@@ -24,6 +25,7 @@
 from ray.data.tests.mock_http_server import *  # noqa
 from ray.tests.conftest import *  # noqa
 from ray.types import ObjectRef
+from typing import Iterable


 def maybe_pipeline(ds, enabled):
@@ -179,10 +181,10 @@ def test_write_datasource(ray_start_regular_shared, pipelined):
     assert output.num_failed == 0
     assert ray.get(output.data_sink.get_rows_written.remote()) == 10

-    ray.get(output.data_sink.set_enabled.remote(False))
-    ds = maybe_pipeline(ds0, pipelined)
+    output.enabled = False
+    ds = maybe_pipeline(ray.data.range(10, parallelism=2), pipelined)
     with pytest.raises(ValueError):
-        ds.write_datasource(output)
+        ds.write_datasource(output, ray_remote_args={"max_retries": 0})
     if pipelined:
         assert output.num_ok == 2
     else:
@@ -228,13 +230,10 @@ def __init__(self):
         class DataSink:
             def __init__(self):
                 self.rows_written = 0
-                self.enabled = True
                 self.node_ids = set()

             def write(self, node_id: str, block: Block) -> str:
                 block = BlockAccessor.for_block(block)
-                if not self.enabled:
-                    raise ValueError("disabled")
                 self.rows_written += block.num_rows()
                 self.node_ids.add(node_id)
                 return "ok"
@@ -245,34 +244,30 @@ def get_rows_written(self):
             def get_node_ids(self):
                 return self.node_ids

-            def set_enabled(self, enabled):
-                self.enabled = enabled
-
         self.data_sink = DataSink.remote()
         self.num_ok = 0
         self.num_failed = 0

-    def do_write(
+    def write(
         self,
-        blocks: List[ObjectRef[Block]],
-        metadata: List[BlockMetadata],
-        ray_remote_args: Dict[str, Any],
+        blocks: Iterable[Block],
+        ctx: TaskContext,
         **write_args,
-    ) -> List[ObjectRef[WriteResult]]:
+    ) -> WriteResult:
         data_sink = self.data_sink

-        @ray.remote
         def write(b):
             node_id = ray.get_runtime_context().get_node_id()
-            return ray.get(data_sink.write.remote(node_id, b))
+            return data_sink.write.remote(node_id, b)

         tasks = []
         for b in blocks:
-            tasks.append(write.options(**ray_remote_args).remote(b))
-        return tasks
+            tasks.append(write(b))
+        ray.get(tasks)
+        return "ok"

     def on_write_complete(self, write_results: List[WriteResult]) -> None:
-        assert all(w == "ok" for w in write_results), write_results
+        assert all(w == ["ok"] for w in write_results), write_results
         self.num_ok += 1

     def on_write_failed(
@@ -282,6 +277,7 @@


 def test_write_datasource_ray_remote_args(ray_start_cluster):
+    ray.shutdown()
     cluster = ray_start_cluster
     cluster.add_node(
         resources={"foo": 100},
@@ -299,7 +295,7 @@ def get_node_id():

     output = NodeLoggerOutputDatasource()
     ds = ray.data.range(100, parallelism=10)
-    # Pin write tasks to
+    # Pin write tasks to node with "bar" resource.
     ds.write_datasource(output, ray_remote_args={"resources": {"bar": 1}})
     assert output.num_ok == 1
     assert output.num_failed == 0
diff --git a/python/ray/data/tests/test_execution_optimizer.py b/python/ray/data/tests/test_execution_optimizer.py
index 83b2c9b6cd01..1cc091391b7f 100644
--- a/python/ray/data/tests/test_execution_optimizer.py
+++ b/python/ray/data/tests/test_execution_optimizer.py
@@ -1,6 +1,4 @@
-import os
 import pytest
-import pandas as pd

 import ray
 from ray.data._internal.execution.operators.map_operator import MapOperator
@@ -545,20 +543,23 @@ def test_sort_e2e(
     ds = ds.sort()
     assert ds.take_all() == list(range(100))

-    df = pd.DataFrame({"one": list(range(100)), "two": ["a"] * 100})
-    ds = ray.data.from_pandas([df])
-    path = os.path.join(local_path, "test_parquet_dir")
-    os.mkdir(path)
-    ds.write_parquet(path)
-
-    ds = ray.data.read_parquet(path)
-    ds = ds.random_shuffle()
-    ds1 = ds.sort("one")
-    ds2 = ds.sort("one", descending=True)
-    r1 = ds1.select_columns(["one"]).take_all()
-    r2 = ds2.select_columns(["one"]).take_all()
-    assert [d["one"] for d in r1] == list(range(100))
-    assert [d["one"] for d in r2] == list(reversed(range(100)))
+    # TODO: write_XXX and from_XXX are not supported yet in the new execution plan.
+    # Re-enable once supported.
+
+    # df = pd.DataFrame({"one": list(range(100)), "two": ["a"] * 100})
+    # ds = ray.data.from_pandas([df])
+    # path = os.path.join(local_path, "test_parquet_dir")
+    # os.mkdir(path)
+    # ds.write_parquet(path)
+
+    # ds = ray.data.read_parquet(path)
+    # ds = ds.random_shuffle()
+    # ds1 = ds.sort("one")
+    # ds2 = ds.sort("one", descending=True)
+    # r1 = ds1.select_columns(["one"]).take_all()
+    # r2 = ds2.select_columns(["one"]).take_all()
+    # assert [d["one"] for d in r1] == list(range(100))
+    # assert [d["one"] for d in r2] == list(reversed(range(100)))


 if __name__ == "__main__":
diff --git a/python/ray/data/tests/test_optimize.py b/python/ray/data/tests/test_optimize.py
index 97b62885e1b2..317167faefb1 100644
--- a/python/ray/data/tests/test_optimize.py
+++ b/python/ray/data/tests/test_optimize.py
@@ -349,6 +349,31 @@ def test_window_randomize_fusion(ray_start_regular_shared):
     assert "read->randomize_block_order->MapBatches(dummy_map)" in stats, stats


+def test_write_fusion(ray_start_regular_shared, tmp_path):
+    context = DatasetContext.get_current()
+    context.optimize_fuse_stages = True
+    context.optimize_fuse_read_stages = True
+    context.optimize_fuse_shuffle_stages = True
+
+    path = os.path.join(tmp_path, "out")
+    ds = ray.data.range(100).map_batches(lambda x: x)
+    ds.write_csv(path)
+    stats = ds._write_ds.stats()
+    assert "read->MapBatches()->write" in stats, stats
+
+    ds = (
+        ray.data.range(100)
+        .map_batches(lambda x: x)
+        .random_shuffle()
+        .map_batches(lambda x: x)
+    )
+    ds.write_csv(path)
+    stats = ds._write_ds.stats()
+    assert "read->MapBatches()" in stats, stats
+    assert "random_shuffle" in stats, stats
+    assert "MapBatches()->write" in stats, stats
+
+
 def test_optimize_fuse(ray_start_regular_shared):
     context = DatasetContext.get_current()
diff --git a/python/ray/data/tests/test_size_estimation.py b/python/ray/data/tests/test_size_estimation.py
index 35cdce44fed9..018a85394260 100644
--- a/python/ray/data/tests/test_size_estimation.py
+++ b/python/ray/data/tests/test_size_estimation.py
@@ -158,9 +158,16 @@ def test_split_read_parquet(ray_start_regular_shared, tmp_path):
     def gen(name):
         path = os.path.join(tmp_path, name)
-        ray.data.range(200000, parallelism=1).map(
-            lambda _: uuid.uuid4().hex
-        ).write_parquet(path)
+        ds = (
+            ray.data.range(200000, parallelism=1)
+            .map(lambda _: uuid.uuid4().hex)
+            .fully_executed()
+        )
+        # Fully execute the operations prior to write, because with
+        # parallelism=1 there is only one write task, and the write operator
+        # would then write everything to a single file, even though block
+        # splitting has created multiple blocks.
+        ds.write_parquet(path)
         return ray.data.read_parquet(path, parallelism=200)  # 20MiB
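
For illustration, a minimal sketch of how a user-defined datasource might implement the new per-task write() API introduced above. The class name, the untyped ctx argument, and the returned strings are hypothetical; only Datasource, Block, BlockAccessor, ctx.task_idx, and Dataset.write_datasource() are taken from this change.

from typing import Iterable, List

import ray
from ray.data.block import Block, BlockAccessor
from ray.data.datasource import Datasource
from ray.data.datasource.datasource import WriteResult


class RowCountDatasource(Datasource):
    """Hypothetical writable datasource: counts rows instead of persisting them."""

    def __init__(self):
        self.num_ok = 0
        self.num_failed = 0

    def write(self, blocks: Iterable[Block], ctx, **write_args) -> WriteResult:
        # Runs inside a single Ray write task. `blocks` are the blocks assigned
        # to this task; `ctx.task_idx` identifies the task (a TaskContext).
        total = sum(BlockAccessor.for_block(b).num_rows() for b in blocks)
        return f"task {ctx.task_idx} wrote {total} rows"

    def on_write_complete(self, write_results: List[WriteResult]) -> None:
        # Called on the driver with one entry per write task. With the
        # block-based write stage above, each entry arrives wrapped in a list,
        # e.g. ["task 0 wrote 25 rows"], which is why the tests now check
        # w == ["ok"] rather than w == "ok".
        self.num_ok += 1

    def on_write_failed(self, write_results, error: Exception) -> None:
        self.num_failed += 1


if __name__ == "__main__":
    sink = RowCountDatasource()
    ray.data.range(100, parallelism=4).write_datasource(sink)
    assert sink.num_ok == 1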
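
The DummyOutputDatasource and NodeLoggerOutputDatasource changes follow the same migration recipe, sketched here in isolation: do_write() scheduled its own remote tasks from the driver and returned their ObjectRefs, whereas write() runs inside a write task that Ray Data already launched, so ray_remote_args passed to write_datasource() now control that task instead. The helper push_to_external_system() and both function bodies below are invented for illustration.

from typing import Any, Dict, Iterable, List

import ray
from ray.data.block import Block
from ray.data.datasource.datasource import WriteResult
from ray.types import ObjectRef


def push_to_external_system(block: Block) -> None:
    """Stand-in for the actual write logic of a datasource."""


# Deprecated style: launch one remote task per block and hand the ObjectRefs
# back to Dataset.write_datasource(), which waits on them.
def do_write(
    blocks: List[ObjectRef[Block]],
    metadata: List[Any],
    ray_remote_args: Dict[str, Any],
    **write_args,
) -> List[ObjectRef[WriteResult]]:
    @ray.remote
    def write_one(block: Block) -> WriteResult:
        push_to_external_system(block)
        return "ok"

    return [write_one.options(**(ray_remote_args or {})).remote(b) for b in blocks]


# New style: already running inside a write task, so consume the blocks
# directly and return a single WriteResult for the whole task.
def write(blocks: Iterable[Block], ctx, **write_args) -> WriteResult:
    for block in blocks:
        push_to_external_system(block)
    return "ok"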
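
FileBasedDatasource.write() and MongoDatasource.write() both coalesce the blocks handed to one write task into a single block before producing one output file or one bulk insert. A standalone sketch of that pattern, assuming the Ray 2.4-era DelegatingBlockBuilder import used in the diff:

from typing import Iterable

import pyarrow as pa

from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder
from ray.data.block import Block, BlockAccessor


def combine_blocks(blocks: Iterable[Block]) -> Block:
    # Merge the blocks given to one write task into a single block, so the
    # task emits a single file (or a single bulk request) instead of one
    # output per input block.
    builder = DelegatingBlockBuilder()
    for block in blocks:
        builder.add_block(block)
    return builder.build()


if __name__ == "__main__":
    blocks = [pa.table({"x": [1, 2, 3]}), pa.table({"x": [4, 5]})]
    combined = combine_blocks(blocks)
    assert BlockAccessor.for_block(combined).num_rows() == 5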
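
Because BlockWritePathProvider implementations now receive the block itself rather than an ObjectRef (see the file_based_datasource.py and conftest.py hunks), a custom provider can inspect the data without calling ray.get(). A sketch under that assumption; the class name and naming scheme are made up and mirror the updated conftest.py helper.

import os
from typing import Optional

from ray.data.block import Block, BlockAccessor
from ray.data.datasource import BlockWritePathProvider


class RowCountWritePathProvider(BlockWritePathProvider):
    def _get_write_path_for_block(
        self,
        base_path: str,
        *,
        filesystem=None,
        dataset_uuid: Optional[str] = None,
        block: Optional[Block] = None,
        block_index: Optional[int] = None,
        file_format: Optional[str] = None,
    ) -> str:
        # `block` is the materialized block (e.g. a pyarrow.Table); no
        # ray.get() is needed anymore.
        num_rows = BlockAccessor.for_block(block).num_rows()
        return os.path.join(
            base_path, f"{block_index:06}_{num_rows:02}_{dataset_uuid}.{file_format}"
        )

Such a provider would be passed as, for example, ds.write_parquet(path, block_path_provider=RowCountWritePathProvider()).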