
Data Parallel bug (return outputs not being moved to same device) #4073

Closed
willprice opened this issue Oct 11, 2020 · 7 comments · Fixed by #4138
Labels
  • bug: Something isn't working
  • help wanted: Open to be worked on
  • logging: Related to the `LoggerConnector` and `log()`
  • priority: 0: High priority task
  • strategy: dp (removed in pl): DataParallel
  • waiting on author: Waiting on user action, correction, or update

Comments

@willprice (Contributor) commented Oct 11, 2020

🐛 Bug

Under backend='dp', Lightning doesn't handle reduction of the loss across multiple GPUs correctly. This is present in v0.10 through v1.0.0rc4.

To Reproduce

Code sample

import torch
import pytorch_lightning as ptl
from pytorch_lightning import LightningModule
from torch.utils.data import Dataset


class RandomDictDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        a = self.data[index]
        b = a + 2
        return {"a": a, "b": b}

    def __len__(self):
        return self.len


class RandomDictStringDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return {"id": str(index), "x": self.data[index]}

    def __len__(self):
        return self.len


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        """
        Testing PL Module
        Use as follows:
        - subclass
        - modify the behavior for what you want
        class TestModel(BaseTestModel):
            def training_step(...):
                # do your own thing
        or:
        model = BaseTestModel()
        model.training_epoch_end = None
        """
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def loss(self, batch, prediction):
        # An arbitrary loss to have a loss that updates the model weights during `Trainer.fit` calls
        return torch.nn.functional.cross_entropy(
            prediction,
            torch.ones(len(prediction), dtype=torch.long, device=prediction.device),
        )

    def training_step(self, batch, batch_idx):
        output = self.layer(batch)
        loss = self.loss(batch, output)
        self.log("loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        output = self.layer(batch)
        loss = self.loss(batch, output)
        self.log("loss", loss)
        return loss

    def test_step(self, batch, batch_idx):
        output = self.layer(batch)
        loss = self.loss(batch, output)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
        return [optimizer], [lr_scheduler]

    def train_dataloader(self):
        return torch.utils.data.DataLoader(RandomDataset(32, 64), batch_size=16)

    def val_dataloader(self):
        return torch.utils.data.DataLoader(RandomDataset(32, 64), batch_size=16)

    def test_dataloader(self):
        return torch.utils.data.DataLoader(RandomDataset(32, 64), batch_size=16)


def main():
    model = BoringModel()
    trainer = ptl.Trainer(
        distributed_backend="dp",
        gpus=4,
    )
    trainer.fit(model)


if __name__ == "__main__":
    main()

Running this produces the following output:

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

| Name  | Type   | Params
---------------------------------
0 | layer | Linear | 66
/home/user/.conda/envs/env/lib/python3.8/site-packages/pytorch_lightning/utilities/distributed.py:45: UserWarning: The dataloader, val dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 104 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
warnings.warn(*args, **kwargs)
Validation sanity check: 0it [00:00, ?it/s]
/home/user/.conda/envs/env/lib/python3.8/site-packages/torch/nn/parallel/_functions.py:61: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.
warnings.warn('Was asked to gather along dimension 0, but all '
/home/user/.conda/envs/env/lib/python3.8/site-packages/pytorch_lightning/utilities/distributed.py:45: UserWarning: The dataloader, train dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 104 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
warnings.warn(*args, **kwargs)
Epoch 1:  50%|████████████████████████████████████████████████████                                                                                      | 4/8 [00:00<00:00, 184.41it/s, loss=0.497, v_num=53]
Traceback (most recent call last):
File "dp_bug.py", line 118, in <module>
main()
File "dp_bug.py", line 114, in main
trainer.fit(model)
File "/home/user/.conda/envs/env/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 440, in fit
results = self.accelerator_backend.train()
File "/home/user/.conda/envs/env/lib/python3.8/site-packages/pytorch_lightning/accelerators/dp_accelerator.py", line 97, in train
results = self.train_or_test()
File "/home/user/.conda/envs/env/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 53, in train_or_test
results = self.trainer.train()
File "/home/user/.conda/envs/env/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 483, in train
self.train_loop.run_training_epoch()
File "/home/user/.conda/envs/env/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py", line 557, in run_training_epoch
self.trainer.run_evaluation(test_mode=False)
File "/home/user/.conda/envs/env/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 599, in run_evaluation
eval_loop_results = self.evaluation_loop.log_epoch_metrics(deprecated_eval_results, epoch_logs, test_mode)
File "/home/user/.conda/envs/env/lib/python3.8/site-packages/pytorch_lightning/trainer/evaluation_loop.py", line 210, in log_epoch_metrics
eval_loop_results = self.trainer.logger_connector.on_evaluation_epoch_end(
File "/home/user/.conda/envs/env/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/logger_connector.py", line 113, in on_evaluation_epoch_end
self._log_on_evaluation_epoch_end_metrics(epoch_logs)
File "/home/user/.conda/envs/env/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/logger_connector.py", line 178, in _log_on_evaluation_epoch_end_metrics
reduced_epoch_metrics = dl_metrics[0].__class__.reduce_on_epoch_end(dl_metrics)
File "/home/user/.conda/envs/env/lib/python3.8/site-packages/pytorch_lightning/core/step_result.py", line 433, in reduce_on_epoch_end
recursive_stack(result)
File "/home/user/.conda/envs/env/lib/python3.8/site-packages/pytorch_lightning/core/step_result.py", line 552, in recursive_stack
result[k] = collate_tensors(v)
File "/home/user/.conda/envs/env/lib/python3.8/site-packages/pytorch_lightning/core/step_result.py", line 574, in collate_tensors
return torch.stack(items)
RuntimeError: All input tensors must be on the same device. Received cuda:3 and cuda:1
Exception ignored in: <function tqdm.__del__ at 0x7fcf54050a60>
Traceback (most recent call last):
File "/home/user/.conda/envs/env/lib/python3.8/site-packages/tqdm/std.py", line 1087, in __del__
File "/home/user/.conda/envs/env/lib/python3.8/site-packages/tqdm/std.py", line 1294, in close
File "/home/user/.conda/envs/env/lib/python3.8/site-packages/tqdm/std.py", line 1472, in display
File "/home/user/.conda/envs/env/lib/python3.8/site-packages/tqdm/std.py", line 1090, in __repr__
File "/home/user/.conda/envs/env/lib/python3.8/site-packages/tqdm/std.py", line 1434, in format_dict
TypeError: cannot unpack non-iterable NoneType object

Note in particular the line:

RuntimeError: All input tensors must be on the same device. Received cuda:3 and cuda:1
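
For context, a minimal standalone sketch of that failure (assuming at least two visible CUDA devices; this is not code from Lightning itself): torch.stack refuses to collate tensors that live on different GPUs, which is what happens when the per-GPU step outputs are collected without first being moved to a common device.

import torch

# Minimal sketch of the collation failure, assuming >= 2 CUDA devices.
# Each DP replica returns its loss on its own GPU.
loss_gpu0 = torch.tensor(0.5, device="cuda:0")
loss_gpu1 = torch.tensor(0.7, device="cuda:1")

try:
    torch.stack([loss_gpu0, loss_gpu1])  # same failure mode as collate_tensors above
except RuntimeError as err:
    print(err)  # "All input tensors must be on the same device. ..."

# Moving every output onto one device before stacking avoids the error.
stacked = torch.stack([loss_gpu0, loss_gpu1.to(loss_gpu0.device)])
print(stacked)  # tensor([0.5000, 0.7000], device='cuda:0')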

Expected behavior

Training should run without error: the per-GPU step outputs should be gathered onto a single device and reduced to a scalar loss, as in v0.9.0.

Environment

  • PyTorch Version (e.g., 1.0): 1.6.0
  • OS (e.g., Linux): Ubuntu 18.04
  • How you installed PyTorch (conda, pip, source): conda
  • Build command you used (if compiling from source): N/A
  • Python version: 3.8.5
  • CUDA/cuDNN version: 11.0
  • GPU models and configuration: 8 GPU (RTX 2080Ti)
  • Any other relevant information:

Additional context

This works on v0.9.0:

import torch
import pytorch_lightning as ptl
from pytorch_lightning import LightningModule
from torch.utils.data import Dataset


class RandomDictDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        a = self.data[index]
        b = a + 2
        return {"a": a, "b": b}

    def __len__(self):
        return self.len


class RandomDictStringDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return {"id": str(index), "x": self.data[index]}

    def __len__(self):
        return self.len


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        """
        Testing PL Module
        Use as follows:
        - subclass
        - modify the behavior for what you want
        class TestModel(BaseTestModel):
            def training_step(...):
                # do your own thing
        or:
        model = BaseTestModel()
        model.training_epoch_end = None
        """
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def loss(self, batch, prediction):
        # An arbitrary loss to have a loss that updates the model weights during `Trainer.fit` calls
        return torch.nn.functional.cross_entropy(
            prediction,
            torch.ones(len(prediction), dtype=torch.long, device=prediction.device),
        )

    def training_step(self, batch, batch_idx):
        output = self.layer(batch)
        loss = self.loss(batch, output)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        output = self.layer(batch)
        loss = self.loss(batch, output)
        return {"val_loss": loss}

    def test_step(self, batch, batch_idx):
        output = self.layer(batch)
        loss = self.loss(batch, output)
        return {"test_loss": loss}

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
        return [optimizer], [lr_scheduler]

    def train_dataloader(self):
        return torch.utils.data.DataLoader(RandomDataset(32, 64), batch_size=16)

    def val_dataloader(self):
        return torch.utils.data.DataLoader(RandomDataset(32, 64), batch_size=16)

    def test_dataloader(self):
        return torch.utils.data.DataLoader(RandomDataset(32, 64), batch_size=16)


def main():
    model = BoringModel()
    trainer = ptl.Trainer(
        distributed_backend="dp",
        gpus=4,
        # log_every_n_steps=5,
        # flush_logs_every_n_steps=20,
        # benchmark=True,
        # gradient_clip_val=20,
    )
    trainer.fit(model)


if __name__ == "__main__":
    main()

but causes this error under v1.0.0rc4

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

| Name  | Type   | Params
---------------------------------
0 | layer | Linear | 66    
/home/user/.conda/envs/env/lib/python3.8/site-packages/pytorch_lightning/utilities/distributed.py:45: UserWarning: The dataloader, val dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 104 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
warnings.warn(*args, **kwargs)
/home/user/.conda/envs/env/lib/python3.8/site-packages/pytorch_lightning/utilities/distributed.py:45: UserWarning: The dataloader, train dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 104 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
warnings.warn(*args, **kwargs)
Epoch 0:   0%|                                                                                                                                          | 0/8 [00:00<?, ?it/s]
Traceback (most recent call last):
File "dp_bug.py", line 116, in <module>
main()
File "dp_bug.py", line 112, in main
trainer.fit(model)
File "/home/user/.conda/envs/env/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 440, in fit
results = self.accelerator_backend.train()
File "/home/user/.conda/envs/env/lib/python3.8/site-packages/pytorch_lightning/accelerators/dp_accelerator.py", line 97, in train
results = self.train_or_test()
File "/home/user/.conda/envs/env/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 53, in train_or_test
results = self.trainer.train()
File "/home/user/.conda/envs/env/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 483, in train
self.train_loop.run_training_epoch()
File "/home/user/.conda/envs/env/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py", line 529, in run_training_epoch
batch_output = self.run_training_batch(batch, batch_idx, dataloader_idx)
File "/home/user/.conda/envs/env/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py", line 661, in run_training_batch
opt_closure_result = self.training_step_and_backward(
File "/home/user/.conda/envs/env/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py", line 753, in training_step_and_backward
self.backward(result, optimizer, opt_idx)
File "/home/user/.conda/envs/env/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py", line 767, in backward
result.closure_loss = self.trainer.accelerator_backend.backward(
File "/home/user/.conda/envs/env/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 83, in backward
model.backward(closure_loss, optimizer, opt_idx)
File "/home/user/.conda/envs/env/lib/python3.8/site-packages/pytorch_lightning/core/lightning.py", line 1077, in backward
loss.backward()
File "/home/user/.conda/envs/env/lib/python3.8/site-packages/torch/tensor.py", line 185, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph)
File "/home/user/.conda/envs/env/lib/python3.8/site-packages/torch/autograd/__init__.py", line 121, in backward
grad_tensors = _make_grads(tensors, grad_tensors)
File "/home/user/.conda/envs/env/lib/python3.8/site-packages/torch/autograd/__init__.py", line 47, in _make_grads
raise RuntimeError("grad can be implicitly created only for scalar outputs")
RuntimeError: grad can be implicitly created only for scalar outputs
Exception ignored in: <function tqdm.__del__ at 0x7fed7b0c1a60>
Traceback (most recent call last):
File "/home/user/.conda/envs/env/lib/python3.8/site-packages/tqdm/std.py", line 1087, in __del__
File "/home/user/.conda/envs/env/lib/python3.8/site-packages/tqdm/std.py", line 1294, in close
File "/home/user/.conda/envs/env/lib/python3.8/site-packages/tqdm/std.py", line 1472, in display
File "/home/user/.conda/envs/env/lib/python3.8/site-packages/tqdm/std.py", line 1090, in __repr__
File "/home/user/.conda/envs/env/lib/python3.8/site-packages/tqdm/std.py", line 1434, in format_dict
TypeError: cannot unpack non-iterable NoneType object

@willprice added the bug and help wanted labels Oct 11, 2020
@awaelchli (Contributor) commented Oct 12, 2020

import torch

# this is what is happening in will's code:

prediction = torch.rand(8, 2, requires_grad=True)

# device 0 computes:
x = torch.nn.functional.cross_entropy(
    prediction,
    torch.ones(len(prediction), dtype=torch.long, device=prediction.device),
)

# device 1 computes:
y = torch.nn.functional.cross_entropy(
    prediction,
    torch.ones(len(prediction), dtype=torch.long, device=prediction.device),
)

# dp backend calls backward on stacked tensor
l = torch.stack((x, y))
l.backward()  # backward on a non-scalar

The PyTorch snippet above reproduces the problem and gives the same error as @willprice reported:

  File "/home/adrian/repositories/imagenet-optical-flow/asdf.py", line 19, in <module>
    l.backward()
  File "/home/adrian/bin/anaconda3/envs/lightning-0.7.1/lib/python3.7/site-packages/torch/tensor.py", line 185, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/home/adrian/bin/anaconda3/envs/lightning-0.7.1/lib/python3.7/site-packages/torch/autograd/__init__.py", line 121, in backward
    grad_tensors = _make_grads(tensors, grad_tensors)
  File "/home/adrian/bin/anaconda3/envs/lightning-0.7.1/lib/python3.7/site-packages/torch/autograd/__init__.py", line 47, in _make_grads
    raise RuntimeError("grad can be implicitly created only for scalar outputs")
RuntimeError: grad can be implicitly created only for scalar outputs

Conclusion: somewhere in the DP backend the losses get stacked and backward is called on a non-scalar tensor.
I have limited time right now, so I'm dropping this info here to pick it up later.
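
Continuing that snippet, a minimal sketch of what the reduction needs to do (illustrative only, not the fix that landed in #4138): once the stacked per-device losses are reduced to a scalar, backward() succeeds.

import torch

# Same setup as above: two per-replica losses stacked the way the dp backend does.
prediction = torch.rand(8, 2, requires_grad=True)
target = torch.ones(len(prediction), dtype=torch.long, device=prediction.device)

x = torch.nn.functional.cross_entropy(prediction, target)
y = torch.nn.functional.cross_entropy(prediction, target)

stacked = torch.stack((x, y))  # shape (2,): calling backward() on this raises the error above
stacked.mean().backward()      # reducing to a scalar first makes backward work
print(prediction.grad.shape)   # torch.Size([8, 2])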

@awaelchli added the priority: 0 label Oct 12, 2020
@justusschock (Member)

@willprice can you check the fix from #4138? For me this worked on the reproduction script.

@justusschock self-assigned this Oct 14, 2020
@edenlightning added the waiting on author label Oct 19, 2020
@edenlightning modified the milestones: 1.1, 1.0.3 Oct 19, 2020
@edenlightning modified the milestones: 1.0.x, 1.0.7 Nov 10, 2020
@Borda modified the milestones: 1.0.7, 1.0.x Nov 11, 2020
@edenlightning (Contributor)

@willprice friendly ping :)

willprice added a commit to epic-kitchens/C1-Action-Recognition-TSN-TRN-TSM that referenced this issue Nov 12, 2020
@willprice (Contributor, Author)

Hey @edenlightning, I can confirm that this is fixed for me on the reproduction script and my own codebase. Although I did run into this issue when testing out #4138 on my codebase (I have example_input_array set on my LightningModel).

@edenlightning removed this from the 1.0.x milestone Nov 13, 2020
@awaelchli (Contributor)

I think I rediscovered this bug in our examples:
python pl_examples/basic_examples/simple_image_classifier.py --gpus 2 --accelerator dp
I will try to help with the PR @justusschock still has open so we can also finish #4764.

@awaelchli reopened this Dec 3, 2020
@awaelchli added the logging label Dec 3, 2020
@awaelchli self-assigned this Dec 3, 2020
@MaveriQ commented Jan 7, 2021

Hi. My PL version is pytorch-lightning 1.1.3 (pyhd8ed1ab_0, conda-forge).

I noticed that I was getting the error `RuntimeError: grad can be implicitly created only for scalar outputs` when I used

    def training_step(self, batch, batch_idx):
        output = self.layer(batch)
        loss = self.loss(batch, output)
        return {"loss": loss}

And I did not get the error when I used

    def training_step(self, batch, batch_idx):
        output = self.layer(batch)
        loss = self.loss(batch, output)
        self.log("loss", loss)
        return loss

The only difference is whether I return a dictionary or just the loss value. I found it odd, so I'm noting it here in case it helps someone.

@feliyur commented Jan 13, 2021

I'm getting "grad can be implicitly created only for scalar outputs" with pytorch-lightning 1.1.4, using backend='dp'. I solved it by adding a training_step_end that explicitly aggregates the loss:

    def training_step_end(self, training_step_outputs):
        return {'loss': training_step_outputs['loss'].sum()}

Perhaps related: the training_step_outputs argument itself is a dictionary rather than the tuple (using 2 GPUs) I would expect from the documentation:

            def training_step_end(self, training_step_outputs):
                gpu_0_pred = training_step_outputs[0]['pred']
                gpu_1_pred = training_step_outputs[1]['pred']
                gpu_n_pred = training_step_outputs[n]['pred']

@MaveriQ's fix does not work with the return value of training_step_end, i.e. returning a tensor with more than one element brings back the original error.
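
A hedged variant of that workaround (a sketch only, not verified against the Lightning internals): reducing with mean() instead of sum() keeps the loss on the same scale as a single-GPU run, and a dimensionality check keeps the hook harmless when the loss is already a scalar.

    def training_step_end(self, training_step_outputs):
        # training_step_outputs is the dict-shaped output observed above under dp
        loss = training_step_outputs["loss"]
        if loss.dim() > 0:  # a per-GPU vector of losses
            loss = loss.mean()
        return {"loss": loss}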
