
[T5] Enable naive Pipeline Parallelism training for T5 #22535

Merged 4 commits into huggingface:main on Apr 3, 2023

Conversation

younesbelkada (Contributor) commented Apr 3, 2023

What does this PR do?

Similarly to #22329, this PR enables training T5 models in a "Naive Pipeline Parallelism" setup. "Naive Pipeline Parallelism" simply means spreading the model across multiple GPUs and naively running the forward/backward passes, communicating activations and gradients between the GPUs.

Without this fix, users encounter device mismatch errors when training a model that has been loaded across multiple GPUs. The fix is to manually move the labels to the same device as lm_logits before computing the loss.

A simple snippet to reproduce the issue (this needs to be run in a multi-GPU environment):

import torch
from transformers import AutoModelForSeq2SeqLM

model_id = "google/flan-t5-base"

model = AutoModelForSeq2SeqLM.from_pretrained(model_id, device_map="balanced")
print(set(model.hf_device_map.values())) # >>> {0, 1}

dummy_input = torch.LongTensor([[1, 2, 3, 4, 5]])

# Without the fix, this raises a device mismatch RuntimeError
loss = model(input_ids=dummy_input, labels=dummy_input).loss
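The change described above (moving the labels onto the logits' device before the loss computation) can be sketched as follows. This is a minimal illustration, not the exact diff: `compute_loss` is a hypothetical helper standing in for the loss section of `T5ForConditionalGeneration.forward` shown in the traceback below.

```python
import torch
from torch.nn import CrossEntropyLoss


def compute_loss(lm_logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper illustrating the fix.

    In a naive pipeline-parallel setup, lm_logits may live on the last GPU
    (e.g. cuda:1) while labels were placed on the first (cuda:0). Moving the
    labels to lm_logits.device avoids the "Expected all tensors to be on the
    same device" RuntimeError raised by cross_entropy.
    """
    loss_fct = CrossEntropyLoss(ignore_index=-100)
    # The one-line fix: align the labels' device with the logits' device
    labels = labels.to(lm_logits.device)
    return loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
```

On a single device the `.to(...)` call is a no-op, so the fix is safe for non-parallel training as well.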

Error trace:

│   1746 │   │   loss = None                                                                       │
│   1747 │   │   if labels is not None:                                                            │
│   1748 │   │   │   loss_fct = CrossEntropyLoss(ignore_index=-100)                                │
│ ❱ 1749 │   │   │   loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))      │
│   1750 │   │   │   # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc  │
│   1751 │   │                                                                                     │
│   1752 │   │   if not return_dict:                                                               │
│                                                                                                  │
│ /home/younes_huggingface_co/miniconda3/envs/fix-test/lib/python3.9/site-packages/torch/nn/module │
│ s/module.py:1501 in _call_impl                                                                   │
│                                                                                                  │
│   1498 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1499 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1500 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1501 │   │   │   return forward_call(*args, **kwargs)                                          │
│   1502 │   │   # Do not call functions when jit is used                                          │
│   1503 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1504 │   │   backward_pre_hooks = []                                                           │
│                                                                                                  │
│ /home/younes_huggingface_co/miniconda3/envs/fix-test/lib/python3.9/site-packages/torch/nn/module │
│ s/loss.py:1174 in forward                                                                        │
│                                                                                                  │
│   1171 │   │   self.label_smoothing = label_smoothing                                            │
│   1172 │                                                                                         │
│   1173 │   def forward(self, input: Tensor, target: Tensor) -> Tensor:                           │
│ ❱ 1174 │   │   return F.cross_entropy(input, target, weight=self.weight,                         │
│   1175 │   │   │   │   │   │   │      ignore_index=self.ignore_index, reduction=self.reduction,  │
│   1176 │   │   │   │   │   │   │      label_smoothing=self.label_smoothing)                      │
│   1177                                                                                           │
│                                                                                                  │
│ /home/younes_huggingface_co/miniconda3/envs/fix-test/lib/python3.9/site-packages/torch/nn/functi │
│ onal.py:3029 in cross_entropy                                                                    │
│                                                                                                  │
│   3026 │   │   )                                                                                 │
│   3027 │   if size_average is not None or reduce is not None:                                    │
│   3028 │   │   reduction = _Reduction.legacy_get_string(size_average, reduce)                    │
│ ❱ 3029 │   return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(re  │
│   3030                                                                                           │
│   3031                                                                                           │
│   3032 def binary_cross_entropy(                                                                 │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0! (when checking argument for argument target in method wrapper_CUDA_nll_loss_forward)

cc @sgugger

Related issues:

huggingface/peft#242
huggingface/peft#205

@younesbelkada younesbelkada marked this pull request as ready for review April 3, 2023 15:17
@younesbelkada younesbelkada requested a review from sgugger April 3, 2023 15:17
HuggingFaceDocBuilderDev commented Apr 3, 2023

The documentation is not available anymore as the PR was closed or merged.

sgugger (Collaborator) left a comment

Thanks a lot!

@younesbelkada younesbelkada merged commit d7a4f5b into huggingface:main Apr 3, 2023
@younesbelkada younesbelkada deleted the fix-t5-pp branch April 3, 2023 15:55
raghavanone pushed a commit to raghavanone/transformers that referenced this pull request Apr 5, 2023
…#22535)

* enable PP for T5

* make fixup

* fix failing tests
novice03 pushed a commit to novice03/transformers that referenced this pull request Jun 23, 2023
…#22535)

* enable PP for T5

* make fixup

* fix failing tests