From 9d76f3f9f1b0f067869e67f0bd426035a332b28c Mon Sep 17 00:00:00 2001 From: Su YR Date: Tue, 5 Mar 2024 22:09:23 +0800 Subject: [PATCH] Feat: make `__getattr__` and `__setattr__` of ClasswiseWrapper more general (#2424) * fix: MetricCollection did not copy inner state of metric in ClasswiseWrapper when computing groups metrics Issue Link: https://github.com/Lightning-AI/torchmetrics/issues/2389 * fix: set _persistent and _reductions be same as internal metric * test: check metric state_dict wrapped in `ClasswiseWrapper` * refactor: make __getattr__ and __setattr__ of ClasswiseWrapper more general * chlog --------- Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com> Co-authored-by: Jirka --- CHANGELOG.md | 2 +- src/torchmetrics/wrappers/classwise.py | 22 +++++++++++----------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3f2aa69ba29..f970dbc9df2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,7 +20,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed -- +- Made `__getattr__` and `__setattr__` of `ClasswiseWrapper` more general ([#2424](https://github.com/Lightning-AI/torchmetrics/pull/2424)) ### Deprecated diff --git a/src/torchmetrics/wrappers/classwise.py b/src/torchmetrics/wrappers/classwise.py index 698d0f51848..0920118c919 100644 --- a/src/torchmetrics/wrappers/classwise.py +++ b/src/torchmetrics/wrappers/classwise.py @@ -216,19 +216,19 @@ def plot( def __getattr__(self, name: str) -> Union[Tensor, "Module"]: """Get attribute from classwise wrapper.""" - # return state from self.metric - if name in ["tp", "fp", "fn", "tn"]: - return getattr(self.metric, name) + if name == "metric" or (name in self.__dict__ and name not in self.metric.__dict__): + # we need this to prevent from infinite getattribute loop. + return super().__getattr__(name) - return super().__getattr__(name) + return getattr(self.metric, name) def __setattr__(self, name: str, value: Any) -> None: """Set attribute to classwise wrapper.""" - super().__setattr__(name, value) - if name == "metric": - self._defaults = self.metric._defaults - self._persistent = self.metric._persistent - self._reductions = self.metric._reductions - if hasattr(self, "metric") and name in ["tp", "fp", "fn", "tn", "_update_count", "_computed"]: - # update ``_update_count`` and ``_computed`` of internal metric to prevent warning. + if hasattr(self, "metric") and name in self.metric._defaults: setattr(self.metric, name, value) + else: + super().__setattr__(name, value) + if name == "metric": + self._defaults = self.metric._defaults + self._persistent = self.metric._persistent + self._reductions = self.metric._reductions