diff --git a/CHANGELOG.md b/CHANGELOG.md
index 646427e9fcf..e42adf6c0d1 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -52,6 +52,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 - Fixed restrictive dtype checking in `spearman_corrcoef` when used with autocast ([#1303](https://github.com/Lightning-AI/metrics/pull/1303))
 
+- Fixed bug in `MetricTracker.best_metric` when `return_step=False` ([#1306](https://github.com/Lightning-AI/metrics/pull/1306))
+
+
 ## [0.10.1] - 2022-10-21
 
 ### Fixed
diff --git a/src/torchmetrics/wrappers/tracker.py b/src/torchmetrics/wrappers/tracker.py
index 0b222dd91b1..25aaa19441a 100644
--- a/src/torchmetrics/wrappers/tracker.py
+++ b/src/torchmetrics/wrappers/tracker.py
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import warnings
 from copy import deepcopy
 from typing import Any, Dict, List, Tuple, Union
 
@@ -21,6 +20,7 @@
 
 from torchmetrics.collections import MetricCollection
 from torchmetrics.metric import Metric
+from torchmetrics.utilities.prints import rank_zero_warn
 
 
 class MetricTracker(ModuleList):
@@ -171,12 +171,12 @@ def best_metric(
         if isinstance(self._base_metric, Metric):
             fn = torch.max if self.maximize else torch.min
             try:
-                idx, best = fn(self.compute_all(), 0)
+                value, idx = fn(self.compute_all(), 0)
                 if return_step:
-                    return idx.item(), best.item()
-                return best.item()
+                    return value.item(), idx.item()
+                return value.item()
             except ValueError as error:
-                warnings.warn(
+                rank_zero_warn(
                     f"Encountered the following error when trying to get the best metric: {error}"
                     "this is probably due to the 'best' not being defined for this metric."
                     "Returning `None` instead.",
@@ -189,24 +189,24 @@ def best_metric(
         else:  # this is a metric collection
             res = self.compute_all()
             maximize = self.maximize if isinstance(self.maximize, list) else len(res) * [self.maximize]
-            idx, best = {}, {}
+            value, idx = {}, {}
             for i, (k, v) in enumerate(res.items()):
                 try:
                     fn = torch.max if maximize[i] else torch.min
                     out = fn(v, 0)
-                    idx[k], best[k] = out[0].item(), out[1].item()
+                    value[k], idx[k] = out[0].item(), out[1].item()
                 except ValueError as error:
-                    warnings.warn(
+                    rank_zero_warn(
                         f"Encountered the following error when trying to get the best metric for metric {k}:"
                         f"{error} this is probably due to the 'best' not being defined for this metric."
                         "Returning `None` instead.",
                         UserWarning,
                     )
-                    idx[k], best[k] = None, None
+                    value[k], idx[k] = None, None
 
         if return_step:
-            return idx, best
-        return best
+            return value, idx
+        return value
 
     def _check_for_increment(self, method: str) -> None:
         if not self._increment_called:
diff --git a/tests/unittests/wrappers/test_tracker.py b/tests/unittests/wrappers/test_tracker.py
index dab589b30d1..a112cfb9bae 100644
--- a/tests/unittests/wrappers/test_tracker.py
+++ b/tests/unittests/wrappers/test_tracker.py
@@ -126,6 +126,9 @@ def test_tracker(base_metric, metric_input, maximize):
         assert val != 0.0
         assert idx in list(range(5))
 
+        val2 = tracker.best_metric(return_step=False)
+        assert val == val2
+
 
 @pytest.mark.parametrize(
     "base_metric",
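
For context, the root cause of the `best_metric` fix: `torch.max` and `torch.min` with a `dim` argument return `(values, indices)` in that order, so the old unpacking `idx, best = fn(self.compute_all(), 0)` put the best values in `idx` and the step indices in `best`, and `return_step=False` then returned a step index instead of the best value. A minimal standalone sketch of that behavior (not part of the patch; the `history` tensor is illustrative only):

```python
import torch

# One metric value per tracked step, as compute_all() would produce.
history = torch.tensor([0.3, 0.9, 0.5])

# torch.max along a dimension returns (values, indices) -- in that order.
values, indices = torch.max(history, 0)
print(values.item())   # 0.9 -> the best value
print(indices.item())  # 1   -> the step at which it occurred

# The old code unpacked this as `idx, best = ...`, so `best.item()` was
# actually the step index; the patch renames the targets to `value, idx`
# and returns `value.item()` when return_step=False.
```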