Skip to content

Commit

Permalink
[Datasets] Not change map_batches() UDF name in Dataset.__repr__ (r…
Browse files Browse the repository at this point in the history
…ay-project#32411)

This is to fix the Dataset.__repr__ issue in ray-project#32410, after we introduce function name in ray-project#31526. We should only make operator/stage name to be camel case.

Signed-off-by: Cheng Su <scnju13@gmail.com>
  • Loading branch information
c21 committed Feb 10, 2023
1 parent 0d92c42 commit 8077110
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 1 deletion.
9 changes: 8 additions & 1 deletion python/ray/data/_internal/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,14 @@ def get_plan_as_string(self) -> str:
# Get string representation of each stage in reverse order.
for stage in self._stages_after_snapshot[::-1]:
# Get name of each stage in camel case.
stage_name = capitalize(stage.name)
# The stage representation should be in "<stage-name>(...)" format,
# e.g. "MapBatches(my_udf)".
#
# TODO(chengsu): create a class to represent stage name to make it less
# fragile to parse.
stage_str = stage.name.split("(")
stage_str[0] = capitalize(stage_str[0])
stage_name = "(".join(stage_str)
if num_stages == 0:
plan_str += f"{stage_name}\n"
else:
Expand Down
10 changes: 10 additions & 0 deletions python/ray/data/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1556,6 +1556,16 @@ def test_dataset_repr(ray_start_regular_shared):
"Zip\n" "+- Dataset(num_blocks=10, num_rows=9, schema=<class 'int'>)"
)

def my_dummy_fn(x):
return x

ds = ray.data.range(10, parallelism=10)
ds = ds.map_batches(my_dummy_fn)
assert repr(ds) == (
"MapBatches(my_dummy_fn)\n"
"+- Dataset(num_blocks=10, num_rows=10, schema=<class 'int'>)"
)


@pytest.mark.parametrize("lazy", [False, True])
def test_limit(ray_start_regular_shared, lazy):
Expand Down

0 comments on commit 8077110

Please sign in to comment.