Skip to content

Commit

Permalink
Fix bootstrapper wrapper not being reset correctly (#2574)
Browse files Browse the repository at this point in the history
* fix + add to tests

* changelog

* update values

---------

Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
  • Loading branch information
3 people authored May 31, 2024
1 parent 744905c commit cce449b
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 1 deletion.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed class order of `panoptic_quality(..., return_per_class=True)` output ([#2548](https://github.com/Lightning-AI/torchmetrics/pull/2548))


- Fixed `BootstrapWrapper` not being reset correctly ([#2574](https://github.com/Lightning-AI/torchmetrics/pull/2574))


## [1.4.0] - 2024-05-03

### Added
Expand Down
2 changes: 1 addition & 1 deletion docs/source/pages/overview.rst
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,7 @@ of metrics e.g. computation of confidence intervals by resampling of input data.
.. testoutput::
:options: +NORMALIZE_WHITESPACE

{'mean': tensor(0.1476), 'std': tensor(0.0613)}
{'mean': tensor(0.1333), 'std': tensor(0.1554)}

You can see all implemented wrappers under the wrapper section of the API docs.

Expand Down
6 changes: 6 additions & 0 deletions src/torchmetrics/wrappers/bootstrapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,12 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
"""Use the original forward method of the base metric class."""
return super(WrapperMetric, self).forward(*args, **kwargs)

def reset(self) -> None:
"""Reset the state of the base metric."""
for m in self.metrics:
m.reset()
super().reset()

def plot(
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
) -> _PLOT_OUT_TYPE:
Expand Down
11 changes: 11 additions & 0 deletions tests/unittests/wrappers/test_bootstrapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,17 @@ def test_bootstrap(device, sampling_strategy, metric, ref_metric):
assert np.allclose(output["std"].cpu(), np.std(sk_scores, ddof=1))
assert np.allclose(output["raw"].cpu(), sk_scores)

# check that resetting works
bootstrapper.reset()

assert bootstrapper.update_count == 0
assert all(m.update_count == 0 for m in bootstrapper.metrics)
output = bootstrapper.compute()
if not isinstance(metric, MeanSquaredError):
assert output["mean"] == 0
assert output["std"] == 0
assert (output["raw"] == torch.zeros(10, device=device)).all()


@pytest.mark.parametrize("sampling_strategy", ["poisson", "multinomial"])
def test_low_sample_amount(sampling_strategy):
Expand Down

0 comments on commit cce449b

Please sign in to comment.