
Multi-gpu training still has issues #242

Closed · macabdul9 opened this issue Mar 31, 2023 · 12 comments
Labels: solved

macabdul9 commented Mar 31, 2023

With int8
Error: RuntimeError: module must have its parameters and buffers on device cuda:0 (device_ids[0]) but found one of them on device: cuda:1

Without int8
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0! (when checking argument for argument weight in method wrapper__native_layer_norm)

Even after #145

Details: I am training whisper-large-v2 and using pretty much everything from the example notebook here: https://github.com/huggingface/peft/blob/main/examples/int8_training/peft_bnb_whisper_large_v2_training.ipynb

Similar issue: #205

cc: @pacman100
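
For reference, the failing configuration from that notebook loads and wraps the model roughly as below. This is a minimal sketch; the model id and LoRA hyperparameters are assumptions based on the public notebook, not details given in this issue.

from transformers import WhisperForConditionalGeneration
from peft import LoraConfig, get_peft_model, prepare_model_for_int8_training

# load whisper-large-v2 in 8-bit, sharded across the visible GPUs by accelerate
model = WhisperForConditionalGeneration.from_pretrained(
    "openai/whisper-large-v2",  # assumed model id
    load_in_8bit=True,
    device_map="auto",
)
model = prepare_model_for_int8_training(model)

# assumed LoRA settings, roughly as in the example notebook
lora_config = LoraConfig(r=32, lora_alpha=64, target_modules=["q_proj", "v_proj"], lora_dropout=0.05, bias="none")
model = get_peft_model(model, lora_config)
# training this with the Trainer on multiple GPUs produces the device-mismatch errors above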

pacman100 (Contributor)

Looks like multi-GPU training with a naive pipeline using accelerate's device map fails for encoder-decoder models (#205 had T5, and this issue observes it for Whisper). @younesbelkada, any ideas on what might be happening?

pacman100 (Contributor) commented Mar 31, 2023

For Whisper multi-GPU naive pipeline parallelism using accelerate, peft, and the Trainer, the following changes are required:

1. Base model loading:
from transformers import WhisperForConditionalGeneration
from accelerate import dispatch_model

model = WhisperForConditionalGeneration.from_pretrained(model_name_or_path, load_in_8bit=True, device_map="auto")
device_map = model.hf_device_map.copy()

# required because `labels` are on the main execution device (0) while the output of `proj_out` is on another device.
# This leads to a device mismatch error when calculating cross-entropy between logits and labels.
# It won't arise during inference because `labels` aren't supplied then.
# Instead of changing the device of just one of the tied modules, this has to be done for all tied modules,
# otherwise the execution device of the remaining tied modules isn't changed.
device_map["model.decoder.embed_tokens"] = model._hf_hook.execution_device
device_map["model.decoder.embed_positions"] = model._hf_hook.execution_device
device_map["proj_out"] = model._hf_hook.execution_device
dispatch_model(model, device_map=device_map)

print(model.model.decoder.embed_tokens._hf_hook)
print(model.proj_out._hf_hook)
model.hf_device_map

output logs:

AlignDeviceHook(execution_device=0, offload=False, io_same_device=False, offload_buffers=False, place_submodules=True)
AlignDeviceHook(execution_device=0, offload=False, io_same_device=False, offload_buffers=False, place_submodules=True)
{'model.encoder.conv1': 0,
 'model.encoder.conv2': 0,
 'model.encoder.embed_positions': 0,
 'model.encoder.layers.0': 0,
 'model.encoder.layers.1': 0,
 'model.encoder.layers.2': 0,
 'model.encoder.layers.3': 0,
 'model.encoder.layers.4': 0,
 'model.encoder.layers.5': 0,
 'model.encoder.layers.6': 0,
 'model.encoder.layers.7': 0,
 'model.encoder.layers.8': 0,
 'model.encoder.layers.9': 0,
 'model.encoder.layers.10': 0,
 'model.encoder.layers.11': 0,
 'model.encoder.layers.12': 0,
 'model.encoder.layers.13': 0,
 'model.encoder.layers.14': 0,
 'model.encoder.layers.15': 0,
 'model.encoder.layers.16': 0,
 'model.encoder.layers.17': 0,
 'model.encoder.layers.18': 0,
 'model.encoder.layers.19': 1,
 'model.encoder.layers.20': 1,
 'model.encoder.layers.21': 1,
 'model.encoder.layers.22': 1,
 'model.encoder.layers.23': 1,
 'model.encoder.layers.24': 1,
 'model.encoder.layers.25': 1,
 'model.encoder.layers.26': 1,
 'model.encoder.layers.27': 1,
 'model.encoder.layers.28': 1,
 'model.encoder.layers.29': 1,
 'model.encoder.layers.30': 1,
 'model.encoder.layers.31': 1,
 'model.encoder.layer_norm': 1,
 'model.decoder.embed_tokens': 0,
 'proj_out': 0,
 'model.decoder.embed_positions': 0,
 'model.decoder.layers.0': 1,
 'model.decoder.layers.1': 1,
 'model.decoder.layers.2': 1,
 'model.decoder.layers.3': 1,
 'model.decoder.layers.4': 2,
 'model.decoder.layers.5': 2,
 'model.decoder.layers.6': 2,
 'model.decoder.layers.7': 2,
 'model.decoder.layers.8': 2,
 'model.decoder.layers.9': 2,
 'model.decoder.layers.10': 2,
 'model.decoder.layers.11': 2,
 'model.decoder.layers.12': 2,
 'model.decoder.layers.13': 2,
 'model.decoder.layers.14': 2,
 'model.decoder.layers.15': 2,
 'model.decoder.layers.16': 2,
 'model.decoder.layers.17': 2,
 'model.decoder.layers.18': 2,
 'model.decoder.layers.19': 2,
 'model.decoder.layers.20': 2,
 'model.decoder.layers.21': 3,
 'model.decoder.layers.22': 3,
 'model.decoder.layers.23': 3,
 'model.decoder.layers.24': 3,
 'model.decoder.layers.25': 3,
 'model.decoder.layers.26': 3,
 'model.decoder.layers.27': 3,
 'model.decoder.layers.28': 3,
 'model.decoder.layers.29': 3,
 'model.decoder.layers.30': 3,
 'model.decoder.layers.31': 3,
 'model.decoder.layer_norm': 3}
2. To stop the Trainer from using DataParallel, since this is naive pipeline parallelism (i.e. model parallelism):
setattr(model, 'model_parallel', True)
setattr(model, 'is_parallelizable', True)
3. It should now work properly.

[Screenshot 2023-03-31 at 1.07.15 PM]
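
Assuming `model` is the dispatched 8-bit model from step 1, already wrapped with PEFT/LoRA as in the example notebook, the remaining Trainer wiring would look roughly like the sketch below; the training-argument values and the `train_dataset`/`data_collator` names are placeholders, not values given in this thread.

from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

# step 2: mark the model as already split across GPUs (naive pipeline parallelism),
# so the Trainer does not wrap it in DataParallel
setattr(model, "model_parallel", True)
setattr(model, "is_parallelizable", True)

training_args = Seq2SeqTrainingArguments(
    output_dir="whisper-large-v2-lora",  # hypothetical output path
    per_device_train_batch_size=8,
    learning_rate=1e-3,
    num_train_epochs=3,
    fp16=True,
    remove_unused_columns=False,  # required with PEFT models, as in the notebook
    label_names=["labels"],
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,  # placeholder: preprocessed dataset from the notebook
    data_collator=data_collator,  # placeholder: padding collator from the notebook
)
trainer.train()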

younesbelkada (Contributor)

Hi @pacman100 @macabdul9,
please see #205 (comment); these changes were required to enable multi-GPU training with the Trainer.

pacman100 (Contributor)

Hello @younesbelkada, I did mention it in the comment above. Thank you for mentioning it on the other issue too:

To stop the Trainer from using DataParallel, since this is naive pipeline parallelism (i.e. model parallelism):
setattr(model, 'model_parallel', True)
setattr(model, 'is_parallelizable', True)

pacman100 added the solved label on Mar 31, 2023
younesbelkada (Contributor)

Awesome thanks a lot!

macabdul9 (Author)

It's still not working. @pacman100 can you share your working notebook if possible?

AttentionAllUNeed commented Apr 6, 2023

Hi, I have the same problem. Did you manage to solve it? @macabdul9

mn9891 commented Apr 7, 2023

Same here, any solutions?

cchen-dialpad

Not sure why, but I found that using DDP instead of DP works without this problem, i.e., python -m torch.distributed.launch --nproc_per_node=2 your_script.py
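
If going the DDP route, the key difference from the naive-pipeline setup above is that every process loads the full model onto its own GPU rather than sharding it with device_map="auto". A minimal sketch follows; the LOCAL_RANK-based placement is an assumption about how the script would be adapted, not something stated in this thread.

import os
from transformers import WhisperForConditionalGeneration

# torchrun sets LOCAL_RANK for each worker (older torch.distributed.launch may pass --local_rank instead)
local_rank = int(os.environ.get("LOCAL_RANK", 0))

# place the whole 8-bit model on this process's GPU instead of spreading it across devices
model = WhisperForConditionalGeneration.from_pretrained(
    "openai/whisper-large-v2",  # assumed model id
    load_in_8bit=True,
    device_map={"": local_rank},
)
# wrap with PEFT/LoRA and pass to the Trainer as usual; with one full replica per process,
# the Trainer uses DistributedDataParallel rather than naive pipeline parallelism.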

github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

rrva commented Aug 24, 2023

I have made PR #855 with some of the changes mentioned in this thread, but I still have issues with the lora_config's rank_pattern becoming None.

phineas-pta

Here's what I did to enable distributed data parallelism:

from accelerate import Accelerator

accelerator = Accelerator(…)
…
model = WhisperForConditionalGeneration.from_pretrained(…, device_map={"": accelerator.device})
…
training_args = Seq2SeqTrainingArguments(…, accelerator_config={"split_batches": True})
…
trainer.train()
accelerator.wait_for_everyone()
if accelerator.is_main_process:
	trainer.save_model()

Then launch with python -m accelerate.commands.launch --multi_gpu --num_machines 1 --num_processes 2 your_script.py if you have 2 GPUs, for example.
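
As a follow-up, the adapter saved by trainer.save_model() on the main process can be reloaded for inference roughly as below; the model id and output path are placeholders, not values from this thread.

from transformers import WhisperForConditionalGeneration
from peft import PeftModel

base_model = WhisperForConditionalGeneration.from_pretrained(
    "openai/whisper-large-v2",  # assumed base model id
    load_in_8bit=True,
    device_map="auto",
)
# "path/to/output_dir" is a placeholder for the directory trainer.save_model() wrote to
model = PeftModel.from_pretrained(base_model, "path/to/output_dir")
model.eval()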
