Skip to content

Commit

Permalink
[Tune] Add repr for ResultGrid class (#31941)
Browse files Browse the repository at this point in the history
Add __repr__() for ResultGrid class and prettify __repr__() of Result class.

Signed-off-by: Yunxuan Xiao <yunxuanx@Yunxuans-MBP.local.meter>
Co-authored-by: Yunxuan Xiao <yunxuanx@Yunxuans-MBP.local.meter>
  • Loading branch information
woshiyyya and Yunxuan Xiao authored Feb 7, 2023
1 parent 9995599 commit cf95514
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 10 deletions.
23 changes: 19 additions & 4 deletions python/ray/air/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class Result:
log_dir: Optional[Path]
metrics_dataframe: Optional["pd.DataFrame"]
best_checkpoints: Optional[List[Tuple[Checkpoint, Dict[str, Any]]]]
_items_to_repr = ["metrics", "error", "log_dir"]
_items_to_repr = ["error", "metrics", "log_dir", "checkpoint"]

@property
def config(self) -> Optional[Dict[str, Any]]:
Expand All @@ -53,14 +53,29 @@ def config(self) -> Optional[Dict[str, Any]]:
return None
return self.metrics.get("config", None)

def __repr__(self):
def _repr(self, indent: int = 0) -> str:
"""Construct the representation with specified number of space indent."""
from ray.tune.result import AUTO_RESULT_KEYS

shown_attributes = {k: self.__dict__[k] for k in self._items_to_repr}
if self.error:
shown_attributes["error"] = type(self.error).__name__
else:
shown_attributes.pop("error")

if self.metrics:
shown_attributes["metrics"] = {
k: v for k, v in self.metrics.items() if k not in AUTO_RESULT_KEYS
}
kws = [f"{key}={value!r}" for key, value in shown_attributes.items()]
return "{}({})".format(type(self).__name__, ", ".join(kws))

cls_indent = " " * indent
kws_indent = " " * (indent + 2)

kws = [
f"{kws_indent}{key}={value!r}" for key, value in shown_attributes.items()
]
kws_repr = ",\n".join(kws)
return "{0}{1}(\n{2}\n{0})".format(cls_indent, type(self).__name__, kws_repr)

def __repr__(self) -> str:
return self._repr(indent=0)
16 changes: 11 additions & 5 deletions python/ray/tune/result_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@ def __init__(
experiment_analysis: ExperimentAnalysis,
):
self._experiment_analysis = experiment_analysis
self._results = [
self._trial_to_result(trial) for trial in self._experiment_analysis.trials
]

def get_best_result(
self,
Expand Down Expand Up @@ -178,13 +181,11 @@ def get_dataframe(
)

def __len__(self) -> int:
return len(self._experiment_analysis.trials)
return len(self._results)

def __getitem__(self, i: int) -> Result:
"""Returns the i'th result in the grid."""
return self._trial_to_result(
self._experiment_analysis.trials[i],
)
return self._results[i]

@property
def errors(self):
Expand Down Expand Up @@ -235,8 +236,13 @@ def _trial_to_result(self, trial: Trial) -> Result:
metrics_dataframe=self._experiment_analysis.trial_dataframes.get(
trial.logdir
)
if self._experiment_analysis
if self._experiment_analysis.trial_dataframes
else None,
best_checkpoints=best_checkpoints,
)
return result

def __repr__(self) -> str:
all_results_repr = [result._repr(indent=2) for result in self]
all_results_repr = ",\n".join(all_results_repr)
return f"ResultGrid<[\n{all_results_repr}\n]>"
57 changes: 56 additions & 1 deletion python/ray/tune/tests/test_result_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from ray.air._internal.checkpoint_manager import CheckpointStorage, _TrackedCheckpoint
from ray import air, tune
from ray.air import Checkpoint, session
from ray.air.result import Result
from ray.tune.registry import get_trainable_cls
from ray.tune.result_grid import ResultGrid
from ray.tune.experiment import Trial
Expand Down Expand Up @@ -142,7 +143,12 @@ def test_result_grid_future_checkpoint(ray_start_2_cpus, to_object):
)
trial.pickled_error_filename = None
trial.error_filename = None
result_grid = ResultGrid(None)

class MockExperimentAnalysis:
trials = []
trial_dataframes = None

result_grid = ResultGrid(MockExperimentAnalysis())

# Internal result grid conversion
result = result_grid._trial_to_result(trial)
Expand Down Expand Up @@ -224,6 +230,55 @@ def f(config):
assert not any(key in representation for key in AUTO_RESULT_KEYS)


def test_result_grid_repr():
class MockExperimentAnalysis:
trials = []

result_grid = ResultGrid(experiment_analysis=MockExperimentAnalysis())

result_grid._results = [
Result(
metrics={"loss": 1.0},
checkpoint=Checkpoint(data_dict={"weight": 1.0}),
log_dir=Path("./log_1"),
error=None,
metrics_dataframe=None,
best_checkpoints=None,
),
Result(
metrics={"loss": 2.0},
checkpoint=Checkpoint(data_dict={"weight": 2.0}),
log_dir=Path("./log_2"),
error=RuntimeError(),
metrics_dataframe=None,
best_checkpoints=None,
),
]

representation = result_grid.__repr__()

from ray.tune.result import AUTO_RESULT_KEYS

assert len(result_grid) == 2
assert not any(key in representation for key in AUTO_RESULT_KEYS)

expected_repr = """ResultGrid<[
Result(
metrics={'loss': 1.0},
log_dir=PosixPath('log_1'),
checkpoint=Checkpoint(data_dict={'weight': 1.0})
),
Result(
error='RuntimeError',
metrics={'loss': 2.0},
log_dir=PosixPath('log_2'),
checkpoint=Checkpoint(data_dict={'weight': 2.0})
)
]>"""

assert representation == expected_repr


def test_no_metric_mode(ray_start_2_cpus):
def f(config):
tune.report(x=1)
Expand Down
4 changes: 4 additions & 0 deletions python/ray/tune/tests/test_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,8 +343,12 @@ def test_tuner_api_kwargs(shutdown_only, params_expected):

caught_kwargs = {}

class MockExperimentAnalysis:
trials = []

def catch_kwargs(**kwargs):
caught_kwargs.update(kwargs)
return MockExperimentAnalysis()

with patch("ray.tune.impl.tuner_internal.run", catch_kwargs):
tuner.fit()
Expand Down

0 comments on commit cf95514

Please sign in to comment.