Skip to content

Commit

Permalink
[Tune] Add repr for ResultGrid class
Browse files Browse the repository at this point in the history
Signed-off-by: Yunxuan Xiao <yunxuanx@Yunxuans-MBP.local.meter>
  • Loading branch information
Yunxuan Xiao authored and Yunxuan Xiao committed Jan 27, 2023
1 parent fe65c3e commit fac8f78
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 2 deletions.
10 changes: 8 additions & 2 deletions python/ray/air/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class Result:
log_dir: Optional[Path]
metrics_dataframe: Optional["pd.DataFrame"]
best_checkpoints: Optional[List[Tuple[Checkpoint, Dict[str, Any]]]]
_items_to_repr = ["metrics", "error", "log_dir"]
_items_to_repr = ["error", "metrics", "log_dir", "checkpoint"]

@property
def config(self) -> Optional[Dict[str, Any]]:
Expand All @@ -57,10 +57,16 @@ def __repr__(self):
from ray.tune.result import AUTO_RESULT_KEYS

shown_attributes = {k: self.__dict__[k] for k in self._items_to_repr}
if self.error:
shown_attributes["error"] = type(self.error).__name__
else:
shown_attributes.pop("error")

if self.metrics:
shown_attributes["metrics"] = {
k: v for k, v in self.metrics.items() if k not in AUTO_RESULT_KEYS
}

kws = [f"{key}={value!r}" for key, value in shown_attributes.items()]
return "{}({})".format(type(self).__name__, ", ".join(kws))
kws_repr = " " + ",\n ".join(kws)
return "{}(\n{}\n)".format(type(self).__name__, kws_repr)
7 changes: 7 additions & 0 deletions python/ray/tune/result_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,3 +240,10 @@ def _trial_to_result(self, trial: Trial) -> Result:
best_checkpoints=best_checkpoints,
)
return result

def __repr__(self) -> str:
all_results_repr = ""
for result in self:
result_repr = " " + result.__repr__().replace("\n", "\n ")
all_results_repr += f"{result_repr},\n"
return f"[\n{all_results_repr}]"
33 changes: 33 additions & 0 deletions python/ray/tune/tests/test_result_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,39 @@ def f(config):
assert not any(key in representation for key in AUTO_RESULT_KEYS)


def test_result_grid_repr(ray_start_2_cpus, tmpdir):
from ray.air import session

def f(config):
if config["x"] == 1:
raise RuntimeError

metrics = {"loss": 1}
session.report(metrics, checkpoint=Checkpoint.from_dict(metrics))

tuner = tune.Tuner(
f,
run_config=air.RunConfig(
name="test_result_grid_repr",
local_dir=str(tmpdir / "test_result_grid_repr_results"),
),
param_space={"x": tune.grid_search([1, 2])},
)
result_grid = tuner.fit()

representation = result_grid.__repr__()

from ray.tune.result import AUTO_RESULT_KEYS

assert not any(key in representation for key in AUTO_RESULT_KEYS)

assert len(result_grid) == 2
assert representation.count("metrics=") == 2
assert representation.count("log_dir=") == 2
assert representation.count("checkpoint=") == 2
assert representation.count("error=") == 1 and "RuntimeError" in representation


def test_no_metric_mode(ray_start_2_cpus):
def f(config):
tune.report(x=1)
Expand Down

0 comments on commit fac8f78

Please sign in to comment.