Metric not moved to device #531

Closed
cowwoc opened this issue Sep 18, 2021 · 14 comments · Fixed by #542
Labels: help wanted (Extra attention is needed), question (Further information is requested)

Comments

@cowwoc

cowwoc commented Sep 18, 2021

🐛 Bug

Version 1.4.7

Per https://torchmetrics.readthedocs.io/en/latest/pages/overview.html#metrics-and-devices, if a metric is properly defined (identified as a child of a module), it is supposed to be moved automatically to the same device as the module. Unfortunately, in my project this does not occur.

When I run this code:

class MyModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.accuracy = Accuracy()

    def forward(self, input):
        print(f"self.device: {self.device}")
        print(f"self.accuracy.device: {self.accuracy.device}")

I get:

self.device: cuda:0
self.accuracy.device: cpu

Expected behavior

I expect the metric to be on cuda:0.

Environment

  • PyTorch Version (e.g., 1.0): 1.9.0+cu111
  • OS (e.g., Linux): Windows 10.0.19043.1237
  • How you installed PyTorch (conda, pip, source): pip
  • Python version: 3.9.7
  • CUDA/cuDNN version: 11.1
  • GPU models and configuration: GeForce RTX 3080
@cowwoc added the bug / fix (Something isn't working) and help wanted (Extra attention is needed) labels Sep 18, 2021
@github-actions

Hi! Thanks for your contribution, great first issue!

@cowwoc
Author

cowwoc commented Sep 18, 2021

Interestingly, any states registered inside the metric are moved to the right device, but self.device inside the metric evaluates to the wrong value. This matters because I have methods in my metric that create new tensors on self.device. I suspect other users will also expect self.device to match the device of the metric's state variables.

@Borda added the duplicate (This issue or pull request already exists) and question (Further information is requested) labels and removed the bug / fix (Something isn't working) label Sep 18, 2021
@Borda
Member

Borda commented Sep 18, 2021

see #340

@Borda Borda closed this as completed Sep 18, 2021
@cowwoc
Author

cowwoc commented Sep 18, 2021

@Borda The linked issue does not resolve this problem. I am already doing what it recommends.

The documentation claims that the metric's device will be updated, but it is not. I consider this a bug report not a question.

Either the documentation or the implementation is wrong. Please reopen this issue.

@Borda
Member

Borda commented Sep 18, 2021

I see; then the issue is in the docs. No metric is automatically moved unless you use it with PL logging... Mind sending a PR to fix it?

@Borda Borda reopened this Sep 18, 2021
@cowwoc
Author

cowwoc commented Sep 18, 2021

The example code found at https://torchmetrics.readthedocs.io/en/latest/pages/overview.html#metrics-and-devices does not say anything about having to use PL logging. Quoting the relevant parts:

when properly defined inside a Module or LightningModule the metric will be automatically moved to the same device as the module when using .to(device)

import torch
from torch import nn
from torchmetrics import Accuracy, MetricCollection

class MyModule(torch.nn.Module):
    def __init__(self):
        ...
        # valid ways metrics will be identified as child modules
        self.metric1 = Accuracy()
        self.metric2 = nn.ModuleList([Accuracy()])
        self.metric3 = nn.ModuleDict({'accuracy': Accuracy()})
        self.metric4 = MetricCollection([Accuracy()])  # torchmetrics built-in collection class

    def forward(self, batch):
        data, target = batch
        preds = self(data)
        ...
        val1 = self.metric1(preds, target)
        val2 = self.metric2[0](preds, target)
        val3 = self.metric3['accuracy'](preds, target)
        val4 = self.metric4(preds, target)

It sounds a bit odd that you have to use PL logging in order for a metric to get moved to the correct device... Can you point me to the relevant code in PL that moves the metric?

@SkafteNicki
Member

The problem seems to be that if you call .cuda on a parent module, it does not execute:

for m in self.modules():
    m.cuda()

but it instead calls self._apply which will call

for module in self.children():
    module._apply(fn)

this will move the metric states to the correct device, but metric.device is currently only updated when .cuda(), .cpu(), or .to() is called on the metric itself.
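The mechanism described above can be sketched with plain torch (no torchmetrics needed). The class names CachedDtypeModule and SyncedModule below are made up for illustration; a toy module caches a property that is refreshed only in an overridden conversion method, mirroring how Metric.device was refreshed only in .to()/.cuda()/.cpu(). A dtype conversion stands in for a device move so the sketch runs on CPU-only machines.

```python
import torch
from torch import nn

class CachedDtypeModule(nn.Module):
    """Toy stand-in for a Metric: a registered state buffer plus a cached
    property that is refreshed only in an overridden conversion method."""
    def __init__(self):
        super().__init__()
        self.register_buffer("state", torch.zeros(3))  # handled by _apply
        self.cached_dtype = torch.float32              # NOT handled by _apply

    def double(self):
        # Only runs when .double() is called on this module directly.
        self.cached_dtype = torch.float64
        return super().double()

class SyncedModule(CachedDtypeModule):
    def _apply(self, fn, *args, **kwargs):
        # _apply receives the conversion fn (e.g. t -> t.double()) even when
        # the call comes from a parent module, so probing fn with a dummy
        # tensor keeps the cached property in sync.
        self.cached_dtype = fn(torch.zeros(1)).dtype
        return super()._apply(fn, *args, **kwargs)

class Parent(nn.Module):
    def __init__(self, child):
        super().__init__()
        self.child = child

stale = Parent(CachedDtypeModule())
stale.double()  # recurses via child._apply(fn); child.double() never runs
print(stale.child.state.dtype)   # torch.float64 -> the buffer was converted
print(stale.child.cached_dtype)  # torch.float32 -> the cache went stale

fixed = Parent(SyncedModule())
fixed.double()
print(fixed.child.cached_dtype)  # torch.float64 -> kept in sync via _apply
```

Hooking _apply, as SyncedModule does, is the general direction a fix for this class of bug takes, since _apply is the one code path every conversion (direct or via a parent) goes through.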

@SkafteNicki removed the duplicate (This issue or pull request already exists) label Sep 21, 2021
@jlehrer1

jlehrer1 commented Aug 17, 2023

This is not resolved in 2.0.1. Setting up a LightningModule with an __init__ like

        self.metrics = {
            "train": {"accuracy": Accuracy(task="binary")},
            "val": {"accuracy": Accuracy(task="binary")},
        }

And logging with

    def training_step(self, batch, batch_idx):
        loss, probs = self(batch)
        self.log("train_loss", loss, on_epoch=True, on_step=True)
        for name, metric in self.metrics["train"].items():
            value = metric(probs, batch[1])
            self.log(f"train_{name}", value=value)

Gives the error

RuntimeError: Encountered different devices in metric calculation (see stacktrace for details). 

This could be due to the metric class not being on the same device as input. Instead of `metric=BinaryAccuracy(...)` try to do `metric=BinaryAccuracy(...).to(device)` where device corresponds to the device of the input.

Where the stacktrace errors on

  File "/home/user/micromamba/lib/python3.9/site-packages/torchmetrics/metric.py", line 390, in wrapped_func
    update(*args, **kwargs)
  File "/home/user/micromamba/lib/python3.9/site-packages/torchmetrics/classification/stat_scores.py", line 322, in update
    self._update_state(tp, fp, tn, fn)
  File "/home/user/micromamba/lib/python3.9/site-packages/torchmetrics/classification/stat_scores.py", line 70, in _update_state
    self.tp += tp
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

@Borda

@SkafteNicki
Member

@jlehrer1 in your case, it has to do with the initialization, which should use a ModuleDict:

        self.metrics = torch.nn.ModuleDict({
            "train": {"accuracy": Accuracy(task="binary")},
            "val": {"accuracy": Accuracy(task="binary")},
        })

You can read more about why here: https://torchmetrics.readthedocs.io/en/stable/pages/overview.html#metrics-and-devices

@amorehead
Contributor

@SkafteNicki, it doesn't seem like your latest code snippet works with PyTorch 2.0+, since a dict is not a subclass of nn.Module (that's the error PyTorch is raising for me).

@SkafteNicki
Member

Hi @amorehead, in my last example I do not refer to a standard dict but instead the special ModuleDict from torch which is a subclass of nn.Module:
https://pytorch.org/docs/stable/_modules/torch/nn/modules/container.html#ModuleDict

@amorehead
Contributor

Hi, @SkafteNicki. By the dict object in your code example, I meant the inner dict objects assigned to the "train" and "val" keys in your outer ModuleDict. When I try instantiating that structure as written, PyTorch complains that the inner dict must be a subclass of nn.Module. It does work if I wrap the inner dicts in another ModuleDict, though. I mention this in case anyone else runs into the same issue.

@SkafteNicki
Member

@amorehead thanks, and sorry for the confusion on my part, you are indeed correct :)
Thanks for clarifying this for anyone that stumbles on this issue.

@amorehead
Contributor

No worries! Thanks for the original suggestion. It reminded me to organize my torchmetrics more cleanly :)
