Skip to content

Commit

Permalink
Add write operator in new logical plan (ray-project#32440)
Browse files (browse the repository at this point in the history)
Signed-off-by: Edward Oakes <ed.nmi.oakes@gmail.com>
  • Loading branch information
jianoaix authored and edoakes committed Mar 22, 2023
1 parent d8f1b32 commit 597e566
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 9 deletions.
21 changes: 21 additions & 0 deletions python/ray/data/_internal/logical/operators/map_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
)
from ray.data.block import BatchUDF, RowUDF
from ray.data.context import DEFAULT_BATCH_SIZE
from ray.data.datasource import Datasource


if sys.version_info >= (3, 8):
Expand Down Expand Up @@ -119,6 +120,26 @@ def __init__(
)


class Write(AbstractMap):
    """Logical operator representing a write of the input to a datasource.

    The operator itself only records what is needed to perform the write
    (the target datasource and the keyword arguments for the write call).

    Args:
        input_op: The logical operator whose output is to be written.
        datasource: The datasource to write to; kept on ``_datasource``.
        ray_remote_args: Optional remote args, forwarded unchanged to the
            ``AbstractMap`` constructor.
        **write_args: Extra keyword arguments for the write; kept on
            ``_write_args``.
    """

    def __init__(
        self,
        input_op: LogicalOperator,
        datasource: Datasource,
        ray_remote_args: Optional[Dict[str, Any]] = None,
        **write_args,
    ):
        # The identity lambda is handed to the base class as ``fn``.
        # NOTE(review): presumably just a placeholder, since the actual write
        # function is derived from ``_datasource`` / ``_write_args`` -- confirm.
        super().__init__(
            "Write",
            input_op,
            fn=lambda x: x,
            ray_remote_args=ray_remote_args,
        )

        # Write target and its keyword arguments, stored for later use.
        self._datasource = datasource
        self._write_args = write_args


class Filter(AbstractMap):
"""Logical operator for filter."""

Expand Down
4 changes: 4 additions & 0 deletions python/ray/data/_internal/planner/plan_map_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@
FlatMap,
MapBatches,
MapRows,
Write,
)
from ray.data._internal.planner.filter import generate_filter_fn
from ray.data._internal.planner.flat_map import generate_flat_map_fn
from ray.data._internal.planner.map_batches import generate_map_batches_fn
from ray.data._internal.planner.map_rows import generate_map_rows_fn
from ray.data._internal.planner.write import generate_write_fn
from ray.data.block import Block, CallableClass


Expand All @@ -41,6 +43,8 @@ def _plan_map_op(op: AbstractMap, input_physical_dag: PhysicalOperator) -> MapOp
transform_fn = generate_flat_map_fn()
elif isinstance(op, Filter):
transform_fn = generate_filter_fn()
elif isinstance(op, Write):
transform_fn = generate_write_fn(op._datasource, **op._write_args)
else:
raise ValueError(f"Found unknown logical operator during planning: {op}")

Expand Down
18 changes: 18 additions & 0 deletions python/ray/data/_internal/planner/write.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from typing import Callable, Iterator

from ray.data._internal.execution.interfaces import TaskContext
from ray.data.block import Block, RowUDF
from ray.data.datasource import Datasource


def generate_write_fn(
    datasource: Datasource, **write_args
) -> Callable[[Iterator[Block], TaskContext, RowUDF], Iterator[Block]]:
    """Build the transform function that writes blocks to ``datasource``.

    If the write op succeeds, the resulting Dataset is a list of WriteResult
    (one element per write task). Otherwise, an error will be raised. The
    Datasource can handle execution outcomes with ``on_write_complete()`` and
    ``on_write_failed()``.

    Args:
        datasource: Object whose ``write(blocks, ctx, **write_args)`` method
            is invoked by the returned function.
        **write_args: Keyword arguments passed through to every
            ``datasource.write`` call.

    Returns:
        A function taking ``(blocks, ctx, fn)`` that performs the write and
        returns its result wrapped in a nested single-element list.
    """

    def _write_blocks(blocks: Iterator[Block], ctx, fn) -> Iterator[Block]:
        # ``fn`` is part of the expected call signature but is not used here.
        write_result = datasource.write(blocks, ctx, **write_args)
        # One-element list containing a one-element list with the result.
        return [[write_result]]

    return _write_blocks
26 changes: 17 additions & 9 deletions python/ray/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,13 @@
FlatMap,
MapRows,
MapBatches,
Write,
)
from ray.data._internal.planner.filter import generate_filter_fn
from ray.data._internal.planner.flat_map import generate_flat_map_fn
from ray.data._internal.planner.map_batches import generate_map_batches_fn
from ray.data._internal.planner.map_rows import generate_map_rows_fn
from ray.data._internal.planner.write import generate_write_fn
from ray.data.dataset_iterator import DatasetIterator
from ray.data._internal.block_batching import batch_block_refs
from ray.data._internal.block_list import BlockList
Expand Down Expand Up @@ -2686,24 +2688,30 @@ def write_datasource(
)

if hasattr(datasource, "write"):
# If the write operator succeeds, the resulting Dataset is a list of
# WriteResult (one element per write task). Otherwise, an error will
# be raised. The Datasource can handle execution outcomes with the
# on_write_complete() and on_write_failed().
def transform(blocks: Iterable[Block], ctx, fn) -> Iterable[Block]:
return [[datasource.write(blocks, ctx, **write_args)]]

plan = self._plan.with_stage(
OneToOneStage(
"write",
transform,
generate_write_fn(datasource, **write_args),
"tasks",
ray_remote_args,
fn=lambda x: x,
)
)

logical_plan = self._logical_plan
if logical_plan is not None:
write_op = Write(
logical_plan.dag,
datasource,
ray_remote_args=ray_remote_args,
**write_args,
)
logical_plan = LogicalPlan(write_op)

try:
self._write_ds = Dataset(plan, self._epoch, self._lazy).fully_executed()
self._write_ds = Dataset(
plan, self._epoch, self._lazy, logical_plan
).fully_executed()
datasource.on_write_complete(
ray.get(self._write_ds._plan.execute().get_blocks())
)
Expand Down
6 changes: 6 additions & 0 deletions python/ray/data/tests/test_execution_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,12 @@ def test_read_map_chain_operator_fusion_e2e(ray_start_regular_shared, enable_opt
assert name in ds.stats()


def test_write_operator(ray_start_regular_shared, enable_optimizer, tmp_path):
    """Writing a CSV should show up as a fused ``DoRead->Write`` stage in stats."""
    ds = ray.data.range(10, parallelism=2)
    ds.write_csv(tmp_path)

    # The dataset produced by the write is kept on ``_write_ds``.
    write_stats = ds._write_ds.stats()
    assert "DoRead->Write" in write_stats


def test_sort_operator(ray_start_regular_shared, enable_optimizer):
planner = Planner()
read_op = Read(ParquetDatasource())
Expand Down

0 comments on commit 597e566

Please sign in to comment.