Metrics fail on DP and multiple GPU #4353

Closed
LittlePea13 opened this issue Oct 25, 2020 · 13 comments · Fixed by #4494
Labels: feature (Is an improvement or enhancement) · help wanted (Open to be worked on) · strategy: dp (removed in pl) (DataParallel)

Comments

@LittlePea13

LittlePea13 commented Oct 25, 2020

🐛 Bug

When using a metric such as Accuracy from pytorch_lightning.metrics on a machine with 4 GPUs in 'dp' mode, there is an error due to the metric being accumulated on different devices. In the case of Accuracy, in this line:
https://github.com/PyTorchLightning/pytorch-lightning/blob/c8ccec7a02c53ed38af6ef7193232426384eee4a/pytorch_lightning/metrics/classification/accuracy.py#L108

the arguments to torch.sum are on the device the metric is being called from, but self.correct is on a different one. The traceback is as follows:

    self.accuracy_val(y_hat, y)
  File "/home/***/.conda/envs/***/lib/python3.8/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/***/.conda/envs/***/lib/python3.8/site-packages/pytorch_lightning/metrics/metric.py", line 153, in forward
    self.update(*args, **kwargs)
  File "/home/***/.conda/envs/***/lib/python3.8/site-packages/pytorch_lightning/metrics/metric.py", line 199, in wrapped_func
    return update(*args, **kwargs)
  File "/home/***/.conda/envs/***/lib/python3.8/site-packages/pytorch_lightning/metrics/classification/accuracy.py", line 109, in update
    self.correct += torch.sum(preds == target)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1!

Please reproduce using the BoringModel and post here

https://colab.research.google.com/drive/1zcU1ADuHZj82clrBysv-EGfgqG7SxUhN#scrollTo=V7ELesz1kVQo

To Reproduce

The shared Colab will not be able to reproduce the bug, since it needs 'dp' on multiple GPUs, but it should give an idea of when the error occurs. So setting

        gpus=4,
        accelerator="dp",

in the Trainer and then using a metric should bring up the issue. I have tested it with Accuracy but other users in the Slack channel reported it for other metrics such as Precision or Recall.
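A minimal sketch of such a setup (not the BoringModel from the Colab; the model, the random dataset and the hyperparameters are placeholders, only the gpus=4 / accelerator="dp" combination and the metric call inside training_step matter):

import torch
from torch.utils.data import DataLoader, Dataset
import pytorch_lightning as pl


class RandomDataset(Dataset):
    def __len__(self):
        return 64

    def __getitem__(self, idx):
        return torch.randn(32), torch.randint(0, 2, (1,)).item()


class BoringClassifier(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        self.accuracy = pl.metrics.Accuracy()

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self.layer(x)
        # on multiple GPUs with dp this update raises the RuntimeError above:
        # the preds live on cuda:N while the metric state sits on another device
        self.accuracy(logits.argmax(dim=-1), y)
        return torch.nn.functional.cross_entropy(logits, y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


trainer = pl.Trainer(gpus=4, accelerator="dp", max_epochs=1)
trainer.fit(BoringClassifier(), DataLoader(RandomDataset(), batch_size=16))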

Expected behavior

The devices should be the same when the values are added together. I am not sure which approach would be correct; I have crudely worked around it with:

        self.correct += torch.sum(preds.cuda(self.correct.device.index) == target.cuda(self.correct.device.index))
        self.total += target.cuda(self.correct.device.index).numel()

in the case of Accuracy, but that is quite an ugly way of dealing with it.
Update: although this no longer produces the error, the accuracy is not computed properly, as the values get reset to 0 between steps for some reason.

Environment

  • CUDA:
    - GPU:
    - GeForce GTX 1080 Ti
    - GeForce GTX 1080 Ti
    - GeForce GTX 1080 Ti
    - GeForce GTX 1080 Ti
    - available: True
    - version: 10.2
  • Packages:
    - numpy: 1.19.2
    - pyTorch_debug: False
    - pyTorch_version: 1.6.0
    - pytorch-lightning: 1.0.3
    - tqdm: 4.50.2
  • System:
    - OS: Linux
    - architecture:
    - 64bit
    - ELF
    - processor:
    - python: 3.8.5
    - version: #1 SMP Debian 4.19.152-1 (2020-10-18)
@LittlePea13 LittlePea13 added bug Something isn't working help wanted Open to be worked on labels Oct 25, 2020
@SkafteNicki SkafteNicki added this to the 1.0.x milestone Oct 25, 2020
@marrrcin

If you do the metric updates in the <step name>_step_end() methods, it will work correctly.

@SkafteNicki
Member

@marrrcin could you provide a bit more detail, like how and where you explicitly call the metric (an example)?
I would like to get to the bottom of this.

@LittlePea13
Author

So, according to a conversation with him on Slack, if metric.forward() is called in the *_step_end() method (let's say self.acc(true, logits)) rather than in *_step(), the error no longer happens.

This probably has to do with the fact that the output values from the step methods are properly gathered across devices into *_step_end. Still, I am not sure this is the intended behavior of the metrics modules when using DP, since then they could only be called in training_step_end, validation_step_end and test_step_end.

Maybe @marrrcin can clarify when online, but that is what I understood.

@marrrcin

@SkafteNicki
So my workaround is the following.
In my LightningModule's __init__ I have:

        METRICS_SUFFIX = "_metrics"
        TRAINING = "train"
        TESTING = "test"
        VALIDATION = "val"
        metrics_factory = lambda: nn.ModuleList(
            [
                pl.metrics.Accuracy(),
                pl.metrics.Precision(hparams.num_classes),
                pl.metrics.Recall(hparams.num_classes),
                pl.metrics.Fbeta(hparams.num_classes),
            ]
        )
        self.metrics: nn.ModuleDict = nn.ModuleDict(
            {
                # you cannot have name `train` in ModuleDict, because nn.Module has function called `train`
                TRAINING + METRICS_SUFFIX: metrics_factory(),
                VALIDATION + METRICS_SUFFIX: metrics_factory(),
                TESTING + METRICS_SUFFIX: metrics_factory(),
            }
        )

Then, some utility functions:

    @staticmethod
    def _get_metric_name(m: pl.metrics.Metric):
        return m.__class__.__name__.lower()

    def update_metrics(
        self,
        step_name,
        pred_y: torch.Tensor,
        true_y: torch.Tensor,
        log_metrics=True,
    ):
        for metric in self.metrics[step_name + METRICS_SUFFIX]:
            m = metric(pred_y, true_y)
            if log_metrics:
                self.log(
                    f"{self._get_metric_name(metric)}/{step_name}",
                    m,
                    on_step=False,
                    on_epoch=True,
                )

Then, from *_step I return dicts with the keys loss, pred_y and true_y.
Lastly, in *_step_end I call the update:

    def training_step_end(self, outputs: dict) -> torch.Tensor:
        loss = outputs["loss"].mean()

        self.update_metrics(TRAINING, outputs["pred_y"], outputs["true_y"])
# etc...
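For reference (this fragment is not from the original comment), the matching *_step would look roughly like this; the forward pass and the loss are placeholders:

    def training_step(self, batch, batch_idx):
        x, true_y = batch
        logits = self(x)  # placeholder forward pass
        loss = torch.nn.functional.cross_entropy(logits, true_y)
        # return the raw per-device tensors; dp gathers them before training_step_end
        return {"loss": loss, "pred_y": logits.argmax(dim=-1), "true_y": true_y}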

@EspenHa
Contributor

EspenHa commented Oct 26, 2020

Seems to be a duplicate of / closely related to #4073; you might want to try the fix in #4138.

@SkafteNicki
Member

@EspenHa I agree that the error is the same, but this one happens long before we try to log the metrics. It actually has nothing to do with the rest of Lightning; it is a general incompatibility between Lightning metrics and DataParallel:

from pytorch_lightning import metrics
import torch
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

metric = metrics.Accuracy()
metric_dp = torch.nn.DataParallel(metric)
metric_dp.to(device)
pred = torch.randint(2, size=(10,))
target = torch.randint(2, size=(10,))
val = metric_dp(pred, target)

Even this small example fails with the same error. The core of the problem seems to be related to how we register the state as a buffer. As far as I understand, if we use self.register_buffer (which we do), PyTorch should move the states to the correct devices, but this does not seem to happen.
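For what it's worth, the mismatch can also be provoked without DataParallel at all, which matches the traceback above: the metric state stays on whatever device the metric was moved to, while the inputs arrive on another one. A sketch, assuming at least two visible GPUs:

from pytorch_lightning import metrics
import torch

metric = metrics.Accuracy().to("cuda:0")  # metric (and its registered state) moved to cuda:0
preds = torch.randint(2, size=(10,), device="cuda:1")   # inputs on a different device,
target = torch.randint(2, size=(10,), device="cuda:1")  # as dp's scatter would place them
metric(preds, target)  # raises the same "Expected all tensors to be on the same device" error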

@SkafteNicki
Member

Just a small update:
It seems that the pitfall is indeed that we use self.register_buffer for the internal states of the metrics. They also cause trouble in ddp mode, since in each forward pass the buffer on rank 0 overwrites the buffers on all other ranks, leading to wrong results. I am trying to come up with a solution for this at the moment.

That said, another problem with metrics in dp mode has come to my attention. Since dp creates and destroys replicas of the model on each forward call, the internal state of the metrics is destroyed before we have a chance to accumulate it over the different devices. Therefore, until we implement some kind of state maintenance in dp (PR: #1895), the only way forward right now is (thanks @marrrcin for the workaround; a minimal sketch follows the list):

  • return preds, target in <mode>_step (<mode> being training, val or test)
  • call the metric in <mode>_step_end
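A minimal sketch of that recipe (a hypothetical fragment, assuming self.val_acc = pl.metrics.Accuracy() was created in __init__; validation shown, training and test are analogous):

    def validation_step(self, batch, batch_idx):
        x, target = batch
        preds = self(x).argmax(dim=-1)
        # no metric update here; just return the per-device tensors for dp to gather
        return {"preds": preds, "target": target}

    def validation_step_end(self, outputs):
        # outputs from all dp replicas are gathered onto one device before this hook runs
        acc = self.val_acc(outputs["preds"], outputs["target"])
        self.log("val_acc", acc, on_step=False, on_epoch=True)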

@edenlightning
Contributor

@teddykoker @ananyahjha93 please take a look as well

@edenlightning
Contributor

This is currently not supported until we have stateful DP; see #1895.

@edenlightning edenlightning added feature Is an improvement or enhancement and removed bug Something isn't working labels Nov 2, 2020
@edenlightning
Contributor

@ananyahjha93 what should the next steps be? Should we change the docs?

@ananyahjha93
Contributor

@edenlightning we need to get the state maintenance for DP in before we tackle this. Even if it doesn't solve this completely, that PR has been lying around for a few months.

@mhamilton723

Really love PyTorch Lightning, but I am still having this issue. @ananyahjha93 do you know what the fix is? Thanks!

@blacksnail789521

I think the problem is initializing a torchmetrics object inside a plain dict.
With the following code, I get an error:

self.metrics = {
    "cross_entropy": nn.CrossEntropyLoss(),
    "accuracy": torchmetrics.Accuracy(task="multiclass", num_classes=10),
}

Instead, if I initialize the metrics explicitly, I don't get any error:

self.cross_entropy = nn.CrossEntropyLoss()
self.accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=10)
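A plain Python dict is invisible to nn.Module, so a metric stored in it is never registered as a submodule and never moved to the module's device. A sketch that keeps the dict structure, assuming it sits in a LightningModule's __init__, is to use nn.ModuleDict instead:

# nn.ModuleDict registers the metric as a submodule, so its internal state is
# moved to the right device along with the model; a plain dict hides it from nn.Module
self.metrics = nn.ModuleDict(
    {
        "accuracy": torchmetrics.Accuracy(task="multiclass", num_classes=10),
    }
)
self.cross_entropy = nn.CrossEntropyLoss()  # stateless by default, a plain attribute is fine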
