Skip to content

Commit

Permalink
Improve Exception when metric_to_watch is wrong (#1437)
Browse files Browse the repository at this point in the history
* fix

* update

* improve phrasing
  • Loading branch information
Louis-Dupont authored Aug 31, 2023
1 parent f8723ca commit 1964250
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 3 deletions.
8 changes: 6 additions & 2 deletions src/super_gradients/training/sg_trainer/sg_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,11 +557,15 @@ def _init_monitored_items(self):

# make sure the metric_to_watch is an exact match
metric_titles = self.loss_logging_items_names + get_metrics_titles(self.valid_metrics)
metric_to_watch_idx = fuzzy_idx_in_list(self.metric_to_watch, metric_titles)
try:
metric_to_watch_idx = fuzzy_idx_in_list(self.metric_to_watch, metric_titles)
except IndexError:
raise ValueError(f"No match found for `metric_to_watch={self.metric_to_watch}`. Available metrics to monitor are: `{metric_titles}`.")

metric_to_watch = metric_titles[metric_to_watch_idx]
if metric_to_watch != self.metric_to_watch:
logger.warning(
f"No exact match found for `metric_to_watch={self.metric_to_watch}`. It should be one of {metric_titles}. \n"
f"No exact match found for `metric_to_watch={self.metric_to_watch}`. Available metrics to monitor are: `{metric_titles}`. \n"
f"`metric_to_watch={metric_to_watch} will be used instead.`"
)
self.metric_to_watch = metric_to_watch
Expand Down
7 changes: 6 additions & 1 deletion src/super_gradients/training/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,12 @@ def fuzzy_idx_in_list(name: str, lst: List[str]) -> int:
:param lst: List[str], the list as described above.
:return: int, index of name in lst in the matter discussed above.
"""
return [fuzzy_str(x) for x in lst].index(fuzzy_str(name))
fuzzy_name = fuzzy_str(name)
fuzzy_list = [fuzzy_str(x) for x in lst]
if fuzzy_name in fuzzy_list:
return fuzzy_list.index(fuzzy_name)
else:
raise IndexError(f"Value `{name}` not found in the list `{lst}`. Please check the spelling.")


def get_param(params, name, default_val=None):
Expand Down

0 comments on commit 1964250

Please sign in to comment.