Skip to content

Commit

Permalink
Feat: make __getattr__ and __setattr__ of ClasswiseWrapper more g…
Browse files Browse the repository at this point in the history
…eneral (#2424)

* fix: MetricCollection did not copy inner state of metric in ClasswiseWrapper when computing groups metrics

Issue Link: #2389

* fix: set _persistent and _reductions be same as internal metric

* test: check metric state_dict wrapped in `ClasswiseWrapper`

* refactor: make __getattr__ and __setattr__ of ClasswiseWrapper more general

* chlog

---------

Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
Co-authored-by: Jirka <jirka.borovec@seznam.cz>
  • Loading branch information
3 people authored Mar 5, 2024
1 parent 1951a06 commit 9d76f3f
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 12 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed

-
- Made `__getattr__` and `__setattr__` of `ClasswiseWrapper` more general ([#2424](https://github.com/Lightning-AI/torchmetrics/pull/2424))


### Deprecated
Expand Down
22 changes: 11 additions & 11 deletions src/torchmetrics/wrappers/classwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,19 +216,19 @@ def plot(

def __getattr__(self, name: str) -> Union[Tensor, "Module"]:
"""Get attribute from classwise wrapper."""
# return state from self.metric
if name in ["tp", "fp", "fn", "tn"]:
return getattr(self.metric, name)
if name == "metric" or (name in self.__dict__ and name not in self.metric.__dict__):
# we need this to prevent from infinite getattribute loop.
return super().__getattr__(name)

return super().__getattr__(name)
return getattr(self.metric, name)

def __setattr__(self, name: str, value: Any) -> None:
"""Set attribute to classwise wrapper."""
super().__setattr__(name, value)
if name == "metric":
self._defaults = self.metric._defaults
self._persistent = self.metric._persistent
self._reductions = self.metric._reductions
if hasattr(self, "metric") and name in ["tp", "fp", "fn", "tn", "_update_count", "_computed"]:
# update ``_update_count`` and ``_computed`` of internal metric to prevent warning.
if hasattr(self, "metric") and name in self.metric._defaults:
setattr(self.metric, name, value)
else:
super().__setattr__(name, value)
if name == "metric":
self._defaults = self.metric._defaults
self._persistent = self.metric._persistent
self._reductions = self.metric._reductions

0 comments on commit 9d76f3f

Please sign in to comment.