Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: make __getattr__ and __setattr__ of ClasswiseWrapper more general #2424

Merged
merged 11 commits into from
Mar 5, 2024
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed

-
- Made `__getattr__` and `__setattr__` of `ClasswiseWrapper` more general ([#2424](https://github.com/Lightning-AI/torchmetrics/pull/2424))


### Deprecated
Expand Down
22 changes: 11 additions & 11 deletions src/torchmetrics/wrappers/classwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,19 +216,19 @@ def plot(

def __getattr__(self, name: str) -> Union[Tensor, "Module"]:
"""Get attribute from classwise wrapper."""
# return state from self.metric
if name in ["tp", "fp", "fn", "tn"]:
return getattr(self.metric, name)
if name == "metric" or (name in self.__dict__ and name not in self.metric.__dict__):
# we need this to prevent from infinite getattribute loop.
return super().__getattr__(name)

return super().__getattr__(name)
return getattr(self.metric, name)

def __setattr__(self, name: str, value: Any) -> None:
"""Set attribute to classwise wrapper."""
super().__setattr__(name, value)
if name == "metric":
self._defaults = self.metric._defaults
self._persistent = self.metric._persistent
self._reductions = self.metric._reductions
if hasattr(self, "metric") and name in ["tp", "fp", "fn", "tn", "_update_count", "_computed"]:
# update ``_update_count`` and ``_computed`` of internal metric to prevent warning.
if hasattr(self, "metric") and name in self.metric._defaults:
setattr(self.metric, name, value)
else:
super().__setattr__(name, value)
if name == "metric":
self._defaults = self.metric._defaults
self._persistent = self.metric._persistent
self._reductions = self.metric._reductions
Loading