--fp16 causes an issue when running example scripts in distributed mode #4657

Closed
CMobley7 opened this issue May 28, 2020 · 11 comments · Fixed by #4728
Labels: PyTorch (Anything PyTorch)

Comments

@CMobley7

🐛 Bug

Information

Model I am using (Bert, XLNet ...): roberta-large
Language I am using the model on (English, Chinese ...): English

The problem arises when using:

  • the official example scripts

The tasks I am working on are:

  • Fine-tuning an LM with run_language_modeling.py and the SST-2 task with run_glue.py
  • my own dataset

To reproduce

If I run either of the following commands, I get the error included below. However, if I remove --fp16, everything works normally. Likewise, if I keep --fp16 but run non-distributed, everything works normally. So it appears there is an issue with running --fp16 in a distributed fashion. I haven't had an issue with this before, so I'm not sure what the problem is. Any ideas? Thanks in advance.

I installed apex in two different ways, but still get the same results.

# Install package required for fp16 computations
RUN git clone https://github.com/NVIDIA/apex.git \
    && cd apex \
    && python3 setup.py install --cuda_ext --cpp_ext

# Install package required for fp16 computations
RUN git clone https://github.com/NVIDIA/apex.git \
    && cd apex \
    && pip3 install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
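
As an optional sanity check that the fused extensions actually built (a hypothetical snippet, not part of the original report; apex_C and amp_C are the extension modules built by --cpp_ext and --cuda_ext):

    # If either import fails, apex was installed without the fused C++/CUDA
    # extensions and silently falls back to its pure-Python code paths.
    import apex_C   # built by --cpp_ext
    import amp_C    # built by --cuda_ext
    from apex import amp  # the amp frontend should import as well
    print("apex fused extensions are available")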
python3 -m torch.distributed.launch --nproc_per_node 2 run_language_modeling.py --output_dir=/ptcc/shared/lm_roberta_20200528_164228 --model_type=roberta --do_train --train_data_file=/ptcc/data/train.txt --do_eval --eval_data_file=/ptcc/data/test.txt --evaluate_during_training --per_gpu_train_batch_size=2 --per_gpu_eval_batch_size=2 --learning_rate=5e-06 --model_name_or_path=roberta-large --mlm --max_steps=120000 --warmup_steps=10000 --save_steps=12000 --seed=42 --fp16 --logging_dir=/ptcc/shared/roberta_20200528_164228_tf_logs
python3 -m torch.distributed.launch --nproc_per_node 2 run_glue.py --model_type roberta --task_name SST-2 --do_train --do_eval --evaluate_during_training --data_dir /ptcc/data/ --per_gpu_train_batch_size 2 --per_gpu_eval_batch_size 2 --learning_rate 1e-06 --output_dir clf_roberta_20200528_162937 --model_name_or_path /ptcc/shared/lm_roberta_20200528_113420 --num_train_epochs 2.0 --save_steps 1000 --seed 42 --fp16 --logging_dir=/ptcc/shared/roberta_20200528_162937_tf_logs
ptcc_1  | 05/28/2020 20:30:38 - INFO - transformers.trainer -     Starting fine-tuning.
Epoch:   0%|          | 0/2 [00:00<?, ?it/s]       Traceback (most recent call last):
ptcc_1  |   File "/ptcc/run_glue.py", line 228, in <module>
ptcc_1  |     main()
ptcc_1  |   File "/ptcc/run_glue.py", line 160, in main
ptcc_1  |     model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None
ptcc_1  |   File "/usr/local/lib/python3.6/dist-packages/transformers/trainer.py", line 470, in train
ptcc_1  |     tr_loss += self._training_step(model, inputs, optimizer)
ptcc_1  |   File "/usr/local/lib/python3.6/dist-packages/transformers/trainer.py", line 577, in _training_step
ptcc_1  |     scaled_loss.backward()
ptcc_1  |   File "/usr/lib/python3.6/contextlib.py", line 88, in __exit__
ptcc_1  |     next(self.gen)
ptcc_1  |   File "/usr/local/lib/python3.6/dist-packages/apex-0.1-py3.6-linux-x86_64.egg/apex/amp/handle.py", line 127, in scale_loss
ptcc_1  |     should_skip = False if delay_overflow_check else loss_scaler.update_scale()
ptcc_1  |   File "/usr/local/lib/python3.6/dist-packages/apex-0.1-py3.6-linux-x86_64.egg/apex/amp/scaler.py", line 200, in update_scale
ptcc_1  |     self._has_overflow = self._overflow_buf.item()
ptcc_1  | RuntimeError: CUDA error: an illegal memory access was encountered
ptcc_1  | /usr/local/lib/python3.6/dist-packages/torch/optim/lr_scheduler.py:114: UserWarning: Seems like `optimizer.step()` has been overridden after learning rate scheduler initialization. Please, make sure to call `optimizer.step()` before `lr_scheduler.step()`. See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
ptcc_1  |   "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)
ptcc_1  |                                                  terminate called after throwing an instance of 'c10::Error'
ptcc_1  |   what():  CUDA error: an illegal memory access was encountered (insert_events at /pytorch/c10/cuda/CUDACachingAllocator.cpp:771)
ptcc_1  | frame #0: c10::Error::Error(c10::SourceLocation, std::string const&) + 0x46 (0x7f69777f6536 in /usr/local/lib/python3.6/dist-packages/torch/lib/libc10.so)
ptcc_1  | frame #1: c10::cuda::CUDACachingAllocator::raw_delete(void*) + 0x7ae (0x7f6977a39fbe in /usr/local/lib/python3.6/dist-packages/torch/lib/libc10_cuda.so)
ptcc_1  | frame #2: c10::TensorImpl::release_resources() + 0x4d (0x7f69777e6abd in /usr/local/lib/python3.6/dist-packages/torch/lib/libc10.so)
ptcc_1  | frame #3: std::vector<c10d::Reducer::Bucket, std::allocator<c10d::Reducer::Bucket> >::~vector() + 0x1d9 (0x7f69c3926ef9 in /usr/local/lib/python3.6/dist-packages/torch/lib/libtorch_python.so)
ptcc_1  | frame #4: c10d::Reducer::~Reducer() + 0x23a (0x7f69c391c84a in /usr/local/lib/python3.6/dist-packages/torch/lib/libtorch_python.so)
ptcc_1  | frame #5: std::_Sp_counted_ptr<c10d::Reducer*, (__gnu_cxx::_Lock_policy)2>::_M_dispose() + 0x12 (0x7f69c38fb7c2 in /usr/local/lib/python3.6/dist-packages/torch/lib/libtorch_python.so)
ptcc_1  | frame #6: std::_Sp_counted_base<(__gnu_cxx::_Lock_policy)2>::_M_release() + 0x46 (0x7f69c32be466 in /usr/local/lib/python3.6/dist-packages/torch/lib/libtorch_python.so)
ptcc_1  | frame #7: <unknown function> + 0x87146b (0x7f69c38fc46b in /usr/local/lib/python3.6/dist-packages/torch/lib/libtorch_python.so)
ptcc_1  | frame #8: <unknown function> + 0x240500 (0x7f69c32cb500 in /usr/local/lib/python3.6/dist-packages/torch/lib/libtorch_python.so)
ptcc_1  | frame #9: <unknown function> + 0x24174e (0x7f69c32cc74e in /usr/local/lib/python3.6/dist-packages/torch/lib/libtorch_python.so)
ptcc_1  | frame #10: /usr/bin/python3() [0x572a27]
ptcc_1  | frame #11: /usr/bin/python3() [0x54eef2]
ptcc_1  | frame #12: /usr/bin/python3() [0x588948]
ptcc_1  | frame #13: /usr/bin/python3() [0x5ad438]
ptcc_1  | frame #14: /usr/bin/python3() [0x5ad44e]
ptcc_1  | frame #15: /usr/bin/python3() [0x5ad44e]
ptcc_1  | frame #16: /usr/bin/python3() [0x56b276]
ptcc_1  | frame #17: PyDict_SetItemString + 0x153 (0x5709f3 in /usr/bin/python3)
ptcc_1  | frame #18: PyImport_Cleanup + 0x76 (0x4f2fc6 in /usr/bin/python3)
ptcc_1  | frame #19: Py_FinalizeEx + 0x5e (0x637e2e in /usr/bin/python3)
ptcc_1  | frame #20: Py_Main + 0x395 (0x638e95 in /usr/bin/python3)
ptcc_1  | frame #21: main + 0xe0 (0x4b0d00 in /usr/bin/python3)
ptcc_1  | frame #22: __libc_start_main + 0xe7 (0x7f69e4727b97 in /lib/x86_64-linux-gnu/libc.so.6)
ptcc_1  | frame #23: _start + 0x2a (0x5b250a in /usr/bin/python3)

Environment info

  • transformers version: 2.10.0
  • Platform: Linux-5.3.0-26-generic-x86_64-with-Ubuntu-18.04-bionic
  • Python version: 3.6.9
  • PyTorch version (GPU?): 1.5.0 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Using GPU in script?: Y, 2 Tesla V100-SXM2
  • Using distributed or parallel set-up in script?: Y, 2 Tesla V100-SXM2
@CMobley7 (Author)

I've tried transformers 2.10.0 under CUDA 10.2 with PyTorch 1.5.0 and apex compiled for that environment, as well as under CUDA 10.1 with both PyTorch 1.5.0 and 1.4.1 and apex compiled for each of those, but I get essentially the same issue. Should I downgrade to a different version of transformers?

Epoch:   0%|          | 0/2 [00:00<?, ?it/s]      Traceback (most recent call last):
ptcc_1  |   File "/ptcc/run_language_modeling.py", line 281, in <module>
ptcc_1  |     main()
ptcc_1  |   File "/ptcc/run_language_modeling.py", line 245, in main
ptcc_1  |     trainer.train(model_path=model_path)
ptcc_1  |   File "/usr/local/lib/python3.6/dist-packages/transformers/trainer.py", line 470, in train
ptcc_1  |     tr_loss += self._training_step(model, inputs, optimizer)
ptcc_1  |   File "/usr/local/lib/python3.6/dist-packages/transformers/trainer.py", line 577, in _training_step
ptcc_1  |     scaled_loss.backward()
ptcc_1  |   File "/usr/lib/python3.6/contextlib.py", line 88, in __exit__
ptcc_1  |     next(self.gen)
ptcc_1  |   File "/usr/local/lib/python3.6/dist-packages/apex-0.1-py3.6-linux-x86_64.egg/apex/amp/handle.py", line 127, in scale_loss
ptcc_1  |     should_skip = False if delay_overflow_check else loss_scaler.update_scale()
ptcc_1  |   File "/usr/local/lib/python3.6/dist-packages/apex-0.1-py3.6-linux-x86_64.egg/apex/amp/scaler.py", line 200, in update_scale
ptcc_1  |     self._has_overflow = self._overflow_buf.item()
ptcc_1  | RuntimeError: CUDA error: an illegal memory access was encountered
ptcc_1  | Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 32768.0
ptcc_1  | /usr/local/lib/python3.6/dist-packages/torch/optim/lr_scheduler.py:114: UserWarning: Seems like `optimizer.step()` has been overridden after learning rate scheduler initialization. Please, make sure to call `optimizer.step()` before `lr_scheduler.step()`. See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
ptcc_1  |   "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)
ptcc_1  |                                                 terminate called after throwing an instance of 'c10::Error'
ptcc_1  |   what():  CUDA error: an illegal memory access was encountered (insert_events at /pytorch/c10/cuda/CUDACachingAllocator.cpp:771)
ptcc_1  | frame #0: c10::Error::Error(c10::SourceLocation, std::string const&) + 0x46 (0x7f2ededfd536 in /usr/local/lib/python3.6/dist-packages/torch/lib/libc10.so)
ptcc_1  | frame #1: c10::cuda::CUDACachingAllocator::raw_delete(void*) + 0x7ae (0x7f2edf040fbe in /usr/local/lib/python3.6/dist-packages/torch/lib/libc10_cuda.so)
ptcc_1  | frame #2: c10::TensorImpl::release_resources() + 0x4d (0x7f2edededabd in /usr/local/lib/python3.6/dist-packages/torch/lib/libc10.so)
ptcc_1  | frame #3: std::vector<c10d::Reducer::Bucket, std::allocator<c10d::Reducer::Bucket> >::~vector() + 0x1d9 (0x7f2f26356d99 in /usr/local/lib/python3.6/dist-packages/torch/lib/libtorch_python.so)
ptcc_1  | frame #4: c10d::Reducer::~Reducer() + 0x23a (0x7f2f2634c6ea in /usr/local/lib/python3.6/dist-packages/torch/lib/libtorch_python.so)
ptcc_1  | frame #5: std::_Sp_counted_ptr<c10d::Reducer*, (__gnu_cxx::_Lock_policy)2>::_M_dispose() + 0x12 (0x7f2f2632b662 in /usr/local/lib/python3.6/dist-packages/torch/lib/libtorch_python.so)
ptcc_1  | frame #6: std::_Sp_counted_base<(__gnu_cxx::_Lock_policy)2>::_M_release() + 0x46 (0x7f2f25cee306 in /usr/local/lib/python3.6/dist-packages/torch/lib/libtorch_python.so)
ptcc_1  | frame #7: <unknown function> + 0x87130b (0x7f2f2632c30b in /usr/local/lib/python3.6/dist-packages/torch/lib/libtorch_python.so)
ptcc_1  | frame #8: <unknown function> + 0x2403a0 (0x7f2f25cfb3a0 in /usr/local/lib/python3.6/dist-packages/torch/lib/libtorch_python.so)
ptcc_1  | frame #9: <unknown function> + 0x2415ee (0x7f2f25cfc5ee in /usr/local/lib/python3.6/dist-packages/torch/lib/libtorch_python.so)
ptcc_1  | frame #10: /usr/bin/python3() [0x572a27]
ptcc_1  | frame #11: /usr/bin/python3() [0x54eef2]
ptcc_1  | frame #12: /usr/bin/python3() [0x588948]
ptcc_1  | frame #13: /usr/bin/python3() [0x5ad438]
ptcc_1  | frame #14: /usr/bin/python3() [0x5ad44e]
ptcc_1  | frame #15: /usr/bin/python3() [0x5ad44e]
ptcc_1  | frame #16: /usr/bin/python3() [0x56b276]
ptcc_1  | frame #17: PyDict_SetItemString + 0x153 (0x5709f3 in /usr/bin/python3)
ptcc_1  | frame #18: PyImport_Cleanup + 0x76 (0x4f2fc6 in /usr/bin/python3)
ptcc_1  | frame #19: Py_FinalizeEx + 0x5e (0x637e2e in /usr/bin/python3)
ptcc_1  | frame #20: Py_Main + 0x395 (0x638e95 in /usr/bin/python3)
ptcc_1  | frame #21: main + 0xe0 (0x4b0d00 in /usr/bin/python3)
ptcc_1  | frame #22: __libc_start_main + 0xe7 (0x7f2f2b53cb97 in /lib/x86_64-linux-gnu/libc.so.6)
ptcc_1  | frame #23: _start + 0x2a (0x5b250a in /usr/bin/python3)

CMobley7 mentioned this issue May 29, 2020
@CMobley7 (Author)

I've also tried 3 different machines, all Ubuntu 18.04 but with different GPU sets: 2x Tesla V100-SXM2, 2x P100-SXM2, and 2x Tesla M40. I still get the same error.

@BramVanroy (Collaborator)

Can you install the repo from source and try again? There have been some issues with PyTorch upstream that Julien addressed here: #4300. So you can try with the latest master branch.

BramVanroy added the PyTorch label May 30, 2020
@CMobley7 (Author)

CMobley7 commented Jun 1, 2020

@BramVanroy, that merge request appears to have been merged prior to the v2.10.0 release. I've installed both v2.10.0 and master from source, and unfortunately I get the same error as above when I try to train a model in distributed mode with mixed precision.

@BramVanroy (Collaborator)

BramVanroy commented Jun 1, 2020

One thing I can think of that you can try is explicitly setting the current device for each process.

Can you try cloning the library, installing it in dev mode, and adding a line here:

model_args, data_args, training_args = parser.parse_args_into_dataclasses()
if data_args.eval_data_file is None and training_args.do_eval:

So that it looks like this:

    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    torch.cuda.set_device(training_args.device)
    if data_args.eval_data_file is None and training_args.do_eval:
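
For context, here is a minimal standalone sketch of the same idea (an illustration only, not the library's actual code), assuming just the --local_rank argument that torch.distributed.launch passes to each process:

    # Each DDP worker pins its own GPU before any CUDA allocations, so apex's
    # loss-scaler buffers and DDP's reducer buckets end up on the right device
    # instead of everything defaulting to cuda:0.
    import argparse
    import torch
    import torch.distributed as dist

    parser = argparse.ArgumentParser()
    parser.add_argument("--local_rank", type=int, default=-1)  # filled in by torch.distributed.launch
    args = parser.parse_args()

    if args.local_rank != -1:
        torch.cuda.set_device(args.local_rank)       # pin this process to its GPU
        dist.init_process_group(backend="nccl")      # NCCL backend for multi-GPU training
        device = torch.device("cuda", args.local_rank)
    else:
        device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")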

@CMobley7 (Author)

CMobley7 commented Jun 2, 2020

Thanks @BramVanroy, your suggestion worked. I really appreciate it.

@BramVanroy (Collaborator)

Re-opening so that we can close this in a PR.

BramVanroy reopened this Jun 3, 2020
@CMobley7 (Author)

@BramVanroy, while your suggestion works for multiple GPUs, I get the following error when trying to use a single GPU.

Traceback (most recent call last):
  File "/ptcc/run_language_modeling.py", line 283, in <module>
    main()
  File "/ptcc/run_language_modeling.py", line 136, in main
    torch.cuda.set_device(training_args.device)
  File "/usr/local/lib/python3.6/dist-packages/torch/cuda/__init__.py", line 243, in set_device
    device = _get_device_index(device)
  File "/usr/local/lib/python3.6/dist-packages/torch/cuda/_utils.py", line 34, in _get_device_index
    'or an integer, but got: '.format(device))
ValueError: Expected a cuda device with a specified index or an integer, but got:

and

Traceback (most recent call last):
  File "/ptcc/run_glue.py", line 230, in <module>
    main()
  File "/ptcc/run_glue.py", line 78, in main
    torch.cuda.set_device(training_args.device)
  File "/usr/local/lib/python3.6/dist-packages/torch/cuda/__init__.py", line 243, in set_device
    device = _get_device_index(device)
  File "/usr/local/lib/python3.6/dist-packages/torch/cuda/_utils.py", line 34, in _get_device_index
    'or an integer, but got: '.format(device))
ValueError: Expected a cuda device with a specified index or an integer, but got:
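
(The single-GPU failure is easy to reproduce in isolation; the following is a minimal reproduction under PyTorch 1.5, consistent with the tracebacks above, not something posted in the thread:)

    # torch.cuda.set_device() requires a device with an explicit index (or an int);
    # a bare "cuda" device, which is what a non-distributed run ends up passing,
    # raises the ValueError shown above.
    import torch

    device = torch.device("cuda")   # no index, as in the non-distributed case
    torch.cuda.set_device(device)   # ValueError: Expected a cuda device with a specified index ...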

@BramVanroy (Collaborator)

@CMobley7 Thanks for the update! I pushed another update to my PR; can you try that one out? When we are not using DDP (and local_rank is -1), we did not specify which GPU id to use. It's best to select that main device explicitly, so we now select it by index 0. (This will still work if you restrict devices with CUDA_VISIBLE_DEVICES; it will just select the first device visible in that environment.)
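
(A sketch of the device-selection logic just described, based on a reading of the comment rather than the actual code in #4728:)

    # Always hand torch.cuda.set_device() a device that carries an explicit index.
    import torch

    def pick_device(local_rank: int) -> torch.device:
        if not torch.cuda.is_available():
            return torch.device("cpu")
        if local_rank == -1:
            # Non-distributed run: explicitly take the first visible GPU; this still
            # respects CUDA_VISIBLE_DEVICES, since index 0 is the first visible device.
            return torch.device("cuda", 0)
        return torch.device("cuda", local_rank)      # one GPU per DDP worker

    device = pick_device(local_rank=-1)
    if device.type == "cuda":
        torch.cuda.set_device(device)                # ok: explicit index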

@CMobley7 (Author)

CMobley7 commented Jun 12, 2020

@BramVanroy, I can confirm that the changes made in #4728 successfully fix the apex issues with both a single GPU and multiple GPUs. I've tested on 3 different machines, all Ubuntu 18.04 but with different GPU sets: 2x Tesla V100-SXM2, 2x P100-SXM2, and 2x Tesla M40. Thanks for your help.

julien-c pushed a commit that referenced this issue Jun 15, 2020
* manually set device in trainer args

* check if current device is cuda before set_device

* Explicitly set GPU ID when using single GPU

This addresses #4657 (comment)
@julien-c (Member)

Thank you @CMobley7 for the extensive testing, this is very valuable.

And thanks @BramVanroy for fixing!
