
ClasswiseWrapper and JaccardIndex confmat attribute error #2465

Closed
isaaccorley opened this issue Mar 21, 2024 · 2 comments
Labels
bug / fix Something isn't working help wanted Extra attention is needed

Comments


isaaccorley commented Mar 21, 2024

🐛 Bug

Wrapping JaccardIndex in ClasswiseWrapper inside a MetricCollection raises an AttributeError when iterating over the collection's named_children().

To Reproduce

from torchmetrics.classification import JaccardIndex
from torchmetrics import MetricCollection
from torchmetrics.wrappers import ClasswiseWrapper

metrics = MetricCollection(
    {
        "IoU": ClasswiseWrapper(
            JaccardIndex(
                task="multiclass", num_classes=2, average="none"
            ),
            labels=["cat", "dog"],
        ),
    },
    prefix="train_",
)

for name, module in metrics.named_children():
    print(name, module)

AttributeError                            Traceback (most recent call last)
Cell In[9], line 1
----> 1 for name, module in metrics.named_children():
      2     print(name, module)

File /opt/conda/envs/dbenv/lib/python3.10/site-packages/torch/nn/modules/module.py:2302, in Module.named_children(self)
   2300 memo = set()
   2301 for name, module in self._modules.items():
-> 2302     if module is not None and module not in memo:
   2303         memo.add(module)
   2304         yield name, module

File /opt/conda/envs/dbenv/lib/python3.10/site-packages/torchmetrics/metric.py:928, in Metric.__hash__(self)
    925 hash_vals = [self.__class__.__name__, id(self)]
    927 for key in self._defaults:
--> 928     val = getattr(self, key)
    929     # Special case: allow list values, so long
    930     # as their elements are hashable
    931     if hasattr(val, "__iter__") and not isinstance(val, Tensor):

File /opt/conda/envs/dbenv/lib/python3.10/site-packages/torchmetrics/wrappers/classwise.py:223, in ClasswiseWrapper.__getattr__(self, name)
    220 if name in ["tp", "fp", "fn", "tn"]:
    221     return getattr(self.metric, name)
--> 223 return super().__getattr__(name)

File /opt/conda/envs/dbenv/lib/python3.10/site-packages/torch/nn/modules/module.py:1688, in Module.__getattr__(self, name)
   1686     if name in modules:
   1687         return modules[name]
-> 1688 raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

AttributeError: 'ClasswiseWrapper' object has no attribute 'confmat'
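The traceback shows the call chain: named_children() checks each child against a memo set, which hashes it; Metric.__hash__ iterates over self._defaults and calls getattr(self, key) for every state name; and ClasswiseWrapper.__getattr__ forwards only tp, fp, fn and tn to the wrapped metric, so the lookup of confmat falls through to Module.__getattr__ and raises. The sketch below reproduces that chain with plain Python stand-ins. ToyMetric, ToyJaccard, ToyClasswiseWrapper and the simplified named_children are hypothetical classes written for illustration, not torchmetrics or PyTorch code, and the assumption that the wrapper's _defaults lists confmat is inferred from the traceback rather than taken from the library source.

```python
class ToyMetric:
    """Hypothetical stand-in for torchmetrics.Metric: state names live in _defaults."""

    def __init__(self):
        self._defaults = {}

    def add_state(self, name, default):
        self._defaults[name] = default
        setattr(self, name, default)

    def __hash__(self):
        # Same pattern as Metric.__hash__ in the traceback: touch every state by name.
        hash_vals = [self.__class__.__name__, id(self)]
        for key in self._defaults:
            hash_vals.append(getattr(self, key))
        return hash(tuple(hash_vals))


class ToyJaccard(ToyMetric):
    """Hypothetical confusion-matrix-based metric whose only state is `confmat`."""

    def __init__(self):
        super().__init__()
        self.add_state("confmat", 0)


class ToyClasswiseWrapper(ToyMetric):
    """Hypothetical wrapper that forwards only an allowlist of state names."""

    forwarded = ("tp", "fp", "fn", "tn")

    def __init__(self, metric):
        super().__init__()
        self.metric = metric
        # Assumption: the wrapper's _defaults lists the wrapped metric's states,
        # as the traceback implies for `confmat`.
        self._defaults = dict(metric._defaults)

    def __getattr__(self, name):
        # Only called when normal attribute lookup fails.
        if name in type(self).forwarded:
            return getattr(self.metric, name)
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")


def named_children(children):
    """Simplified Module.named_children: de-duplicate children with a set."""
    memo = set()
    for name, module in children.items():
        if module is not None and module not in memo:  # hashes `module`
            memo.add(module)
            yield name, module


children = {"IoU": ToyClasswiseWrapper(ToyJaccard())}
try:
    list(named_children(children))
except AttributeError as err:
    print(err)  # 'ToyClasswiseWrapper' object has no attribute 'confmat'
```

The unwrapped ToyJaccard hashes fine because confmat is a real attribute on it; only the wrapper, which reports the state name but cannot resolve it, fails.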

Expected behavior

Iterating over metrics.named_children() should list the wrapped metric without raising an error.

Environment

  • TorchMetrics: 1.3.2
  • Python: 3.10.11
  • PyTorch: 2.2.1
  • OS: Ubuntu

Additional context

I found this bug while training with PyTorch Lightning: during trainer setup, Lightning iterates over the named_children of the LightningModule, which raises this error.

I noticed that __getattr__ in torchmetrics/wrappers/classwise.py forwards only ["tp", "fp", "fn", "tn"] to the wrapped metric. Should confmat be added to this list?

    def __getattr__(self, name: str) -> Union[Tensor, "Module"]:
        """Get attribute from classwise wrapper."""
        # return state from self.metric
        if name in ["tp", "fp", "fn", "tn"]:
            return getattr(self.metric, name)

        return super().__getattr__(name)
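One way the suggested change could look is sketched below with plain Python stand-ins, not the actual fix merged in #2424. ToyMetric and ToyClasswiseWrapper are hypothetical classes that mimic the hashing pattern from the traceback; the only change relative to the snippet above is that confmat joins the forwarded names, so hashing the wrapper can resolve every state it reports.

```python
class ToyMetric:
    """Hypothetical stand-in for a metric whose states are listed in _defaults."""

    def __init__(self, **states):
        self._defaults = dict(states)
        for name, value in states.items():
            setattr(self, name, value)

    def __hash__(self):
        # Same pattern as Metric.__hash__ in the traceback: touch every state by name.
        hash_vals = [self.__class__.__name__, id(self)]
        for key in self._defaults:
            hash_vals.append(getattr(self, key))
        return hash(tuple(hash_vals))


class ToyClasswiseWrapper(ToyMetric):
    """Hypothetical wrapper that forwards state lookups to the wrapped metric."""

    def __init__(self, metric):
        super().__init__()
        self.metric = metric
        # Assumption: the wrapper's _defaults mirrors the wrapped metric's states.
        self._defaults = dict(metric._defaults)

    def __getattr__(self, name):
        # Suggested change: forward `confmat` as well as the stat-score states.
        if name in ("tp", "fp", "fn", "tn", "confmat"):
            return getattr(self.metric, name)
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")


wrapper = ToyClasswiseWrapper(ToyMetric(confmat=0))
memo = {wrapper}     # hashing the wrapper no longer raises
print(wrapper.confmat)  # 0
```

A hardcoded list has to grow every time a wrapped metric introduces a new state name; a more general variant would forward any name found in the wrapped metric's own _defaults, at the cost of guarding against recursion when self.metric is not yet set.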
isaaccorley added the "bug / fix" and "help wanted" labels on Mar 21, 2024

Hi! Thanks for your contribution, great first issue!


isaaccorley commented Mar 21, 2024

Closing because I see this is fixed in #2424. Would love to see this in a release!
