
Multi-GPU autograd error with Pytorch 0.4 #7092

Closed
erogol opened this issue Apr 30, 2018 · 35 comments
Labels
todo Not as important as medium or high priority tasks, but we will work on these.

Comments

erogol commented Apr 30, 2018

After updating to PyTorch 0.4, I get the following error when I try to train my model (https://github.com/mozilla/TTS) with multiple GPUs. I have no idea what it means, unfortunately; it may be a bug or just a problem I need some feedback on. Thanks.

Traceback (most recent call last):
  File "train.py", line 403, in <module>
    main(args)
  File "train.py", line 393, in main
    model, criterion, train_loader, optimizer, epoch)
  File "train.py", line 111, in train
    model.forward(text_input, mel_spec)
  File "/home/erogol/miniconda3/envs/pytorch4/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 114, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/home/erogol/miniconda3/envs/pytorch4/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 124, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/home/erogol/miniconda3/envs/pytorch4/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py", line 65, in parallel_apply
    raise output
  File "/home/erogol/miniconda3/envs/pytorch4/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py", line 41, in _worker
    output = module(*input, **kwargs)
  File "/home/erogol/miniconda3/envs/pytorch4/lib/python3.6/site-packages/torch/nn/modules/module.py", line 491, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/erogol/projects/TTS/models/tacotron.py", line 28, in forward
    encoder_outputs = self.encoder(inputs)
  File "/home/erogol/miniconda3/envs/pytorch4/lib/python3.6/site-packages/torch/nn/modules/module.py", line 491, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/erogol/projects/TTS/layers/tacotron.py", line 205, in forward
    return self.cbhg(inputs)
  File "/home/erogol/miniconda3/envs/pytorch4/lib/python3.6/site-packages/torch/nn/modules/module.py", line 491, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/erogol/projects/TTS/layers/tacotron.py", line 183, in forward
    outputs, _ = self.gru(x)
  File "/home/erogol/miniconda3/envs/pytorch4/lib/python3.6/site-packages/torch/nn/modules/module.py", line 491, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/erogol/miniconda3/envs/pytorch4/lib/python3.6/site-packages/torch/nn/modules/rnn.py", line 192, in forward
    output, hidden = func(input, self.all_weights, hx, batch_sizes)
  File "/home/erogol/miniconda3/envs/pytorch4/lib/python3.6/site-packages/torch/nn/_functions/rnn.py", line 323, in forward
    return func(input, *fargs, **fkwargs)
  File "/home/erogol/miniconda3/envs/pytorch4/lib/python3.6/site-packages/torch/nn/_functions/rnn.py", line 287, in forward
    dropout_ts)
RuntimeError: torch/csrc/autograd/variable.cpp:115: get_grad_fn: Assertion `output_nr == 0` failed.
ssnl (Collaborator) commented Apr 30, 2018

Urr. This seems like a bug. Can you try to come up with a minimal working example please? Thanks for reporting!

erogol (Author) commented Apr 30, 2018

The problem is flattening the parameters. If I don't call it, everything works, but with a warning suggesting that I should.

import torch

class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.gru = torch.nn.GRU(128, 64, 1, batch_first=True, bidirectional=True)
                                
    def forward(self, inp):
        self.gru.flatten_parameters()
        out, _ = self.gru(inp)
        return out
                        

net = Net()
inp = torch.rand(32, 8, 128, requires_grad=True)

net = torch.nn.DataParallel(net)

inp = inp.cuda()
net = net.cuda()
out = net.forward(inp)

ssnl (Collaborator) commented Apr 30, 2018

@erogol Thanks Eren. This is very helpful. We'll look into it.

zou3519 (Contributor) commented May 1, 2018

@erogol What was the last version of PyTorch on which this code worked? It crashes on 0.3.1 for me.

erogol (Author) commented May 1, 2018

@zou3519 0.3.0

adampolyak (Contributor)

Same bug when using the code here: https://github.com/NVIDIA/tacotron2 with PyTorch 0.4.

BangLiu commented May 10, 2018

I also have the same problem. Has the bug been fixed, or is there a workaround for now?

colesbury (Member)

@BangLiu the workaround is to remove the flatten_parameters() call from forward(). I'm not sure how that ever worked under DataParallel.
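For reference, a minimal sketch of that workaround based on the repro above: flatten_parameters() is called once after moving the model to the GPU rather than inside forward(). This is only a sketch; the non-contiguous-weights warning mentioned later in the thread may still appear on the replicas.

import torch

class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.gru = torch.nn.GRU(128, 64, 1, batch_first=True, bidirectional=True)

    def forward(self, inp):
        # No flatten_parameters() here; calling it inside forward() is what
        # trips the assertion under DataParallel.
        out, _ = self.gru(inp)
        return out

net = Net().cuda()
net.gru.flatten_parameters()          # compact the weights once, up front
net = torch.nn.DataParallel(net)

inp = torch.rand(32, 8, 128, requires_grad=True).cuda()
out = net(inp)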

erogol (Author) commented May 11, 2018

@colesbury do you think calling it in forward() makes any performance difference?

soumith (Member) commented May 11, 2018

@erogol yes, calling it in forward is detrimental to performance.

If you are using an RNN, it is better to use DistributedDataParallel (with 1 process per GPU) than DataParallel. It is faster, and you don't have to restructure your batching (your code reads as if it were using just 1 GPU).

See https://pytorch.org/docs/stable/distributed.html#launch-utility
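For illustration, a rough sketch of what that looks like with the launch utility, reusing the toy Net from the repro above. The nccl backend and the --local_rank handling here are assumptions based on the linked docs, not code from this thread.

# launched as: python -m torch.distributed.launch --nproc_per_node=<num_gpus> train.py
import argparse
import torch
import torch.distributed as dist

parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", type=int, default=0)   # filled in by the launcher
args = parser.parse_args()

torch.cuda.set_device(args.local_rank)
dist.init_process_group(backend="nccl", init_method="env://")

net = Net().cuda()          # Net as defined in the repro above
net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[args.local_rank])

inp = torch.rand(32, 8, 128, requires_grad=True).cuda()
out = net(inp)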

liqing-ustc

+1. I met the same bug.

soumith (Member) commented May 14, 2018

@liqing-ustc see #7092 (comment) for the answer.

zou3519 added the todo label May 14, 2018
wangkenpu commented May 28, 2018

I hit the same problem with PyTorch 0.4.0. Has this bug been solved yet?

I found that if I deleted self.gru.flatten_parameters() from the forward() function, everything worked well, but I got another warning:

UserWarning: RNN module weights are not part of single contiguous chunk of memory. This means they need to be compacted at every call, possibly greatly increasing memory usage. To compact weights again call flatten_parameters().

eriche2016

any good solution to this problem?

erogol (Author) commented May 28, 2018

The only suggested solution so far is #7092 (comment).

eriche2016

Can anyone give tips on how to use DistributedDataParallel? I am trying to use it as follows:

import torch

class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.gru = torch.nn.GRU(128, 64, 1, batch_first=True, bidirectional=True)
                                
    def forward(self, inp):
        # self.gru.flatten_parameters()
        out, _ = self.gru(inp)
        return out
                        

net = Net()
inp = torch.rand(32, 8, 128, requires_grad=True)

inp = inp.cuda()
net = net.cuda()
if False: 
    net = torch.nn.DataParallel(net, device_ids=range(2))
else: 
    net = torch.nn.parallel.DistributedDataParallel(net, device_ids=range(2))


out = net.forward(inp)

which gives the error:

AssertionError: collective only supported in process-group mode

eriche2016 commented May 29, 2018

@erogol @soumith

erogol (Author) commented May 29, 2018

@eriche2016 I had the same issue and then gave up trying. However, I think the forum is the right place to raise this.

apaszke (Contributor) commented May 29, 2018

You're not even initializing the distributed module. Use torch.distributed.init_process_group(...).
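Concretely, a minimal sketch of that fix applied to the snippet above, for a single process on one node; the gloo backend, address, port, world size, and rank here are placeholder assumptions.

import torch
import torch.distributed as dist

# Initialize the distributed backend before constructing DistributedDataParallel.
dist.init_process_group(backend="gloo",
                        init_method="tcp://127.0.0.1:23456",
                        world_size=1, rank=0)

net = Net().cuda()      # Net as defined in the snippet above
net = torch.nn.parallel.DistributedDataParallel(net, device_ids=list(range(2)))  # assumes 2 visible GPUs

inp = torch.rand(32, 8, 128, requires_grad=True).cuda()
out = net(inp)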

edorado93 added a commit to edorado93/Writing-editing-Network that referenced this issue Jun 25, 2018
flatten_parameters() issue with PyTorch 0.4
pytorch/pytorch#7092
PetrochukM commented Jun 26, 2018

Any chance applying distributed training will be as easy as calling a function like torch.nn.parallel.data_parallel? Thanks for your help!

Folks, by the way, here is a good example of applying distributed training in PyTorch: https://github.com/pytorch/examples/blob/master/imagenet/main.py

It's a bit cumbersome but doable.

soumith (Member) commented Jun 26, 2018

@erogol @PetrochukM @eriche2016 the distributed launcher page cleanly describes (in 4 steps) what you have to do to your code to make it use distributed.

https://pytorch.org/docs/stable/distributed.html#launch-utility

PetrochukM commented Jun 29, 2018

@soumith I read through the tutorial and the launch utility; the distributed API has many options that make it really powerful. Amazing!

Following up on this comment -- "If you are using RNN, it is better to use DistributedDataParallel (and 1 process per GPU) than using DataParallel." @soumith

For a single machine with multiple GPUs (2 to 8) running RNNs, what are the best settings?

  • Should we use tcp or the file system for communication?
  • I am assuming that it'd be best to have one process per GPU.
  • Do you recommend using torch.utils.data.distributed.DistributedSampler? In what cases would it be helpful?
  • Anything else?

Sorry, I do not have much experience with this, and after reading the article it was not obviously clear!

PetrochukM commented Jul 2, 2018

Okay... having run through this process, here are some things to watch out for.

I got distributed running and the error went away, but I would not recommend this approach, because:

  • Training was approx. 4x slower (1 vs. 4 minutes), using a single GRU layer running over sequences of length 900 with the same batch size of 32 split over 4 GPUs.
  • The codebase needs to be restructured into a master/worker script (i.e. the same script runs both the master and the worker processes).
  • Because there are multiple processes in a master/worker relationship, the log files become cluttered with the same messages from both the master and the workers. It also means that saving checkpoints, or any files for that matter, should be restricted to the master process.

Other notes:

  • Distributed should be faster for an RNN that is unrolled with a loop due to:

    Each process contains an independent Python interpreter, eliminating the extra interpreter overhead and “GIL-thrashing” that comes from driving several execution threads, model replicas, or GPUs from a single Python process. This is especially important for models that make heavy use of the Python runtime, including models with recurrent layers or many small components.

  • It's not clear why something similar to torch.nn.parallel.data_parallel cannot be added to PyTorch that launches processes instead of threads.

ailzhang (Contributor) commented Jul 5, 2018

Hi @PetrochukM, to answer your questions:

  • Since you are using a single machine with DistributedDataParallel, yes, 1 process per GPU is recommended. An alternative is to use 1 process for all 8 GPUs together with DataParallel.
  • For single-node DDP, tcp and file initialization are equivalent: with tcp you will always use 127.0.0.1 as the master address, which is certainly reachable from all processes, and the same holds for file initialization. You need to be more careful when training DDP across multiple machines: tcp requires the machines to be reachable from each other, and file initialization requires a shared file system among all nodes. (See the short sketch after this list.)
  • The codebase definitely does not have to be restructured as master/worker. An example is https://github.com/pytorch/examples/blob/master/imagenet/main.py. If you have to do that, something might be wrong. You are welcome to post a minimal repro of your script here so that we can help.
  • Yes, the log files can get messy if all processes write to the same file. This can be mitigated by either printing the rank at the beginning of each line, or letting each process write to a filename suffixed with its rank so that the logs are separated.
  • I'm not sure why you see 4x slower performance when scaling from 1 GPU to 4. Could you first confirm that you are not running over the dataset 4 times in the 4-GPU case? Or post a toy script so that I can take a look.

Hope these answer your questions above. Feel free to let me know if I'm missing something.
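A minimal sketch of the two single-node initialization options described above; the gloo backend, port, and file path are placeholder assumptions.

import os
import torch.distributed as dist

rank = int(os.environ.get("RANK", "0"))  # per-process rank, 0 .. world_size - 1

# Pick exactly one init_method; a process group can only be initialized once.

# Option 1: TCP initialization. On a single node, 127.0.0.1 is always
# reachable from every process.
dist.init_process_group(backend="gloo",
                        init_method="tcp://127.0.0.1:23456",
                        world_size=4, rank=rank)

# Option 2: shared-file initialization. The file must live on a filesystem
# visible to all processes, which is trivially true on one machine.
# dist.init_process_group(backend="gloo",
#                         init_method="file:///tmp/ddp_init",
#                         world_size=4, rank=rank)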

PetrochukM

@ailzhang Thanks for your thorough reply, learned something new :)

Responding to a couple of your points:

  • The code base needs to be master/worker because the same script is run multiple times; therefore, if an action should only be taken once, it should be assigned to the master process. Let me know if that does not make sense.
  • The performance comparison is between DataParallel and DistributedDataParallel. The batch size was 32; I checked, and both DistributedDataParallel and DataParallel split the batch into chunks of 8 across 4 GPUs. DataParallel performs 4x faster. When I get a free cycle, I'll post a minimal example.

ailzhang (Contributor) commented Jul 6, 2018

Hi @PetrochukM,

  • May I ask what operation needs to be run only once? Each DDP process is exactly the same as a single-GPU run on different batches, so running an operation only on the master and not on the others basically means you want to apply something to some batches but not to others. Please let me know if you have more context on this.
  • If I understand correctly, you are comparing DataParallel (1 process with 4 GPUs) at batch size 32 to 4 DistributedDataParallel processes (1 process per GPU) at batch size 8 each. That is a fair comparison. If that's already the case, it might be that the computation time is too small relative to the communication time, in which case you can simply increase the batch size per GPU to saturate the GPUs first and then compare. I could help if needed if you post a simple script here. Let me know!

PetrochukM

Hi!

  • For example, one operation is setting up an experiment folder, tensorboard, and preprocessing.
  • Okay, thanks! The memory usage is intended to be pretty low for this experiment!

soumith (Member) commented Jul 6, 2018

One simple way to execute particular code only on one particular "master" worker is to have a simple if conditioned on the rank of the process:

if args.rank == 0:
    print("foo")

It's a common pattern in MPI-style distributed code.

Aspire1Inspire2

I still have the same bug with multi-GPU on PyTorch 0.4.0; removing flatten_parameters() gives another warning about degraded performance.

Another problem with DataParallel is that the RNN hidden state is batched along the second dimension. The dim=1 argument of DataParallel only affects input/output slicing, ignoring that the RNN hidden state is batched along its second dimension; the batch_first=True argument of the RNN likewise only affects the input/output. Either way, DataParallel cannot handle the hidden state layout correctly, and the workaround for the user is tedious. Hopefully someone can fix this.
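One commonly suggested workaround, for reference, is to keep the batch dimension first at the DataParallel boundary and transpose the hidden state inside the wrapped module. This is only a sketch of that pattern, not an official fix, and the layer sizes are taken from the repro above.

import torch

class GRUWrapper(torch.nn.Module):
    def __init__(self):
        super(GRUWrapper, self).__init__()
        self.gru = torch.nn.GRU(128, 64, 1, batch_first=True)

    def forward(self, inp, hidden):
        # DataParallel scatters every argument along dim 0, so the hidden state
        # arrives as (batch, num_layers * num_directions, hidden_size);
        # transpose it back to the (layers, batch, hidden) layout the GRU expects.
        hidden = hidden.transpose(0, 1).contiguous()
        out, hidden = self.gru(inp, hidden)
        # Transpose again so the gather step concatenates along the batch dim.
        return out, hidden.transpose(0, 1).contiguous()

net = torch.nn.DataParallel(GRUWrapper()).cuda()
inp = torch.rand(32, 8, 128).cuda()
h0 = torch.zeros(32, 1, 64).cuda()    # hidden state kept batch-first at the boundary
out, hn = net(inp, h0)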

ailzhang (Contributor)

@Aspire1Inspire2 could you provide a simple script to repro the RNN bug and expected behavior?

Aspire1Inspire2

If you google "pytorch dataparallel rnn" or "pytorch dataparallel lstm", you will find several dozen examples of people complaining that the hidden state is not parallelizable. Here is a simple toy example:

https://stackoverflow.com/questions/44595338/how-to-parallelize-rnn-function-in-pytorch-with-dataparallel

To reiterate: neither dim=1 nor batch_first=True works. It is simply the hidden state that gets messed up by DataParallel.

Thank you for your response.

ailzhang self-assigned this Aug 7, 2018
kishwarshafin commented Aug 11, 2018

We have been having this problem since we switched to 0.4, and it's been a while now. Any chance we will get a solution soon?

Also, is there a performance comparison between models that call flatten_parameters() and those that don't?

ticlazau

@ALL, same here running MLperf/speech_recognition on a 4-GPU system:
RuntimeError: torch/csrc/autograd/variable.cpp:115: get_grad_fn: Assertion output_nr == 0 failed.

sniperwrb

Same bug when using the code here: https://github.com/NVIDIA/tacotron2 with PyTorch 0.4.

Same here

ngimel (Collaborator) commented Oct 30, 2020

Closing due to age

ngimel closed this as completed Oct 30, 2020