From 8077110597db438ac867112887cd439399ec0223 Mon Sep 17 00:00:00 2001 From: Cheng Su Date: Fri, 10 Feb 2023 10:35:53 -0800 Subject: [PATCH] [Datasets] Not change `map_batches()` UDF name in `Dataset.__repr__` (#32411) This fixes the `Dataset.__repr__` issue reported in #32410, which appeared after the function name was introduced in #31526. Only the operator/stage name should be converted to camel case; the UDF name must be left unchanged. Signed-off-by: Cheng Su --- python/ray/data/_internal/plan.py | 9 ++++++++- python/ray/data/tests/test_dataset.py | 10 ++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/python/ray/data/_internal/plan.py b/python/ray/data/_internal/plan.py index c1c30078daf5..ecd588adf745 100644 --- a/python/ray/data/_internal/plan.py +++ b/python/ray/data/_internal/plan.py @@ -182,7 +182,14 @@ def get_plan_as_string(self) -> str: # Get string representation of each stage in reverse order. for stage in self._stages_after_snapshot[::-1]: # Get name of each stage in camel case. - stage_name = capitalize(stage.name) + # The stage representation should be in "<stage-name>(...)" format, + # e.g. "MapBatches(my_udf)". + # + # TODO(chengsu): create a class to represent stage name to make it less + # fragile to parse. 
+ stage_str = stage.name.split("(") + stage_str[0] = capitalize(stage_str[0]) + stage_name = "(".join(stage_str) if num_stages == 0: plan_str += f"{stage_name}\n" else: diff --git a/python/ray/data/tests/test_dataset.py b/python/ray/data/tests/test_dataset.py index 4cd1d39aa5a8..aff8d0bb6c02 100644 --- a/python/ray/data/tests/test_dataset.py +++ b/python/ray/data/tests/test_dataset.py @@ -1556,6 +1556,16 @@ def test_dataset_repr(ray_start_regular_shared): "Zip\n" "+- Dataset(num_blocks=10, num_rows=9, schema=)" ) + def my_dummy_fn(x): + return x + + ds = ray.data.range(10, parallelism=10) + ds = ds.map_batches(my_dummy_fn) + assert repr(ds) == ( + "MapBatches(my_dummy_fn)\n" + "+- Dataset(num_blocks=10, num_rows=10, schema=)" + ) + @pytest.mark.parametrize("lazy", [False, True]) def test_limit(ray_start_regular_shared, lazy):