Skip to content

Commit

Permalink
Prevent iteration over metrics (#1320)
Browse files Browse the repository at this point in the history
* fix

* changelog

Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
  • Loading branch information
SkafteNicki and mergify[bot] authored Nov 8, 2022
1 parent 920fe0f commit 00bb1ed
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 0 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed bug in `Metrictracker.best_metric` when `return_step=False` ([#1306](https://github.com/Lightning-AI/metrics/pull/1306))


- Fixed bug to prevent users from going into a infinite loop if trying to iterate of a single metric ([#1320](https://github.com/Lightning-AI/metrics/pull/1320))


## [0.10.2] - 2022-10-31

### Changed
Expand Down
3 changes: 3 additions & 0 deletions src/torchmetrics/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -858,6 +858,9 @@ def __getitem__(self, idx: int) -> "Metric":
def __getnewargs__(self) -> Tuple:
return (Metric.__str__(self),)

def __iter__(self):
raise NotImplementedError("Metrics does not support iteration.")


def _neg(x: Tensor) -> Tensor:
return -torch.abs(x)
Expand Down
7 changes: 7 additions & 0 deletions tests/unittests/bases/test_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,3 +464,10 @@ def test_custom_availability_check_and_sync_fn():
acc.compute()
dummy_availability_check.assert_called_once()
assert dummy_dist_sync_fn.call_count == 4 # tp, fp, tn, fn


def test_no_iteration_allowed():
metric = DummyMetric()
with pytest.raises(NotImplementedError, match="Metrics does not support iteration."):
for m in metric:
continue

0 comments on commit 00bb1ed

Please sign in to comment.