From cf955141f90bb644a040c55bacb0d88801d62d5d Mon Sep 17 00:00:00 2001 From: Yunxuan Xiao Date: Tue, 7 Feb 2023 14:56:34 -0800 Subject: [PATCH] [Tune] Add repr for ResultGrid class (#31941) Add __repr__() for ResultGrid class and prettify __repr__() of Result class. Signed-off-by: Yunxuan Xiao Co-authored-by: Yunxuan Xiao --- python/ray/air/result.py | 23 +++++++-- python/ray/tune/result_grid.py | 16 +++++-- python/ray/tune/tests/test_result_grid.py | 57 ++++++++++++++++++++++- python/ray/tune/tests/test_tuner.py | 4 ++ 4 files changed, 90 insertions(+), 10 deletions(-) diff --git a/python/ray/air/result.py b/python/ray/air/result.py index 002a80b770f0..f3aaf9d57edf 100644 --- a/python/ray/air/result.py +++ b/python/ray/air/result.py @@ -44,7 +44,7 @@ class Result: log_dir: Optional[Path] metrics_dataframe: Optional["pd.DataFrame"] best_checkpoints: Optional[List[Tuple[Checkpoint, Dict[str, Any]]]] - _items_to_repr = ["metrics", "error", "log_dir"] + _items_to_repr = ["error", "metrics", "log_dir", "checkpoint"] @property def config(self) -> Optional[Dict[str, Any]]: @@ -53,14 +53,29 @@ def config(self) -> Optional[Dict[str, Any]]: return None return self.metrics.get("config", None) - def __repr__(self): + def _repr(self, indent: int = 0) -> str: + """Construct the representation with specified number of space indent.""" from ray.tune.result import AUTO_RESULT_KEYS shown_attributes = {k: self.__dict__[k] for k in self._items_to_repr} + if self.error: + shown_attributes["error"] = type(self.error).__name__ + else: + shown_attributes.pop("error") if self.metrics: shown_attributes["metrics"] = { k: v for k, v in self.metrics.items() if k not in AUTO_RESULT_KEYS } - kws = [f"{key}={value!r}" for key, value in shown_attributes.items()] - return "{}({})".format(type(self).__name__, ", ".join(kws)) + + cls_indent = " " * indent + kws_indent = " " * (indent + 2) + + kws = [ + f"{kws_indent}{key}={value!r}" for key, value in shown_attributes.items() + ] + kws_repr = ",\n".join(kws) + return "{0}{1}(\n{2}\n{0})".format(cls_indent, type(self).__name__, kws_repr) + + def __repr__(self) -> str: + return self._repr(indent=0) diff --git a/python/ray/tune/result_grid.py b/python/ray/tune/result_grid.py index 82fcb325890f..61e5ec9550be 100644 --- a/python/ray/tune/result_grid.py +++ b/python/ray/tune/result_grid.py @@ -69,6 +69,9 @@ def __init__( experiment_analysis: ExperimentAnalysis, ): self._experiment_analysis = experiment_analysis + self._results = [ + self._trial_to_result(trial) for trial in self._experiment_analysis.trials + ] def get_best_result( self, @@ -178,13 +181,11 @@ def get_dataframe( ) def __len__(self) -> int: - return len(self._experiment_analysis.trials) + return len(self._results) def __getitem__(self, i: int) -> Result: """Returns the i'th result in the grid.""" - return self._trial_to_result( - self._experiment_analysis.trials[i], - ) + return self._results[i] @property def errors(self): @@ -235,8 +236,13 @@ def _trial_to_result(self, trial: Trial) -> Result: metrics_dataframe=self._experiment_analysis.trial_dataframes.get( trial.logdir ) - if self._experiment_analysis + if self._experiment_analysis.trial_dataframes else None, best_checkpoints=best_checkpoints, ) return result + + def __repr__(self) -> str: + all_results_repr = [result._repr(indent=2) for result in self] + all_results_repr = ",\n".join(all_results_repr) + return f"ResultGrid<[\n{all_results_repr}\n]>" diff --git a/python/ray/tune/tests/test_result_grid.py b/python/ray/tune/tests/test_result_grid.py index d2873573ef96..81086370fb50 100644 --- a/python/ray/tune/tests/test_result_grid.py +++ b/python/ray/tune/tests/test_result_grid.py @@ -12,6 +12,7 @@ from ray.air._internal.checkpoint_manager import CheckpointStorage, _TrackedCheckpoint from ray import air, tune from ray.air import Checkpoint, session +from ray.air.result import Result from ray.tune.registry import get_trainable_cls from ray.tune.result_grid import ResultGrid from ray.tune.experiment import Trial @@ -142,7 +143,12 @@ def test_result_grid_future_checkpoint(ray_start_2_cpus, to_object): ) trial.pickled_error_filename = None trial.error_filename = None - result_grid = ResultGrid(None) + + class MockExperimentAnalysis: + trials = [] + trial_dataframes = None + + result_grid = ResultGrid(MockExperimentAnalysis()) # Internal result grid conversion result = result_grid._trial_to_result(trial) @@ -224,6 +230,55 @@ def f(config): assert not any(key in representation for key in AUTO_RESULT_KEYS) +def test_result_grid_repr(): + class MockExperimentAnalysis: + trials = [] + + result_grid = ResultGrid(experiment_analysis=MockExperimentAnalysis()) + + result_grid._results = [ + Result( + metrics={"loss": 1.0}, + checkpoint=Checkpoint(data_dict={"weight": 1.0}), + log_dir=Path("./log_1"), + error=None, + metrics_dataframe=None, + best_checkpoints=None, + ), + Result( + metrics={"loss": 2.0}, + checkpoint=Checkpoint(data_dict={"weight": 2.0}), + log_dir=Path("./log_2"), + error=RuntimeError(), + metrics_dataframe=None, + best_checkpoints=None, + ), + ] + + representation = result_grid.__repr__() + + from ray.tune.result import AUTO_RESULT_KEYS + + assert len(result_grid) == 2 + assert not any(key in representation for key in AUTO_RESULT_KEYS) + + expected_repr = """ResultGrid<[ + Result( + metrics={'loss': 1.0}, + log_dir=PosixPath('log_1'), + checkpoint=Checkpoint(data_dict={'weight': 1.0}) + ), + Result( + error='RuntimeError', + metrics={'loss': 2.0}, + log_dir=PosixPath('log_2'), + checkpoint=Checkpoint(data_dict={'weight': 2.0}) + ) +]>""" + + assert representation == expected_repr + + def test_no_metric_mode(ray_start_2_cpus): def f(config): tune.report(x=1) diff --git a/python/ray/tune/tests/test_tuner.py b/python/ray/tune/tests/test_tuner.py index d32a676e8b89..843186693f7a 100644 --- a/python/ray/tune/tests/test_tuner.py +++ b/python/ray/tune/tests/test_tuner.py @@ -343,8 +343,12 @@ def test_tuner_api_kwargs(shutdown_only, params_expected): caught_kwargs = {} + class MockExperimentAnalysis: + trials = [] + def catch_kwargs(**kwargs): caught_kwargs.update(kwargs) + return MockExperimentAnalysis() with patch("ray.tune.impl.tuner_internal.run", catch_kwargs): tuner.fit()