[trainer] seq2seq doesn't handle mt5 correctly #9865
Comments
OK, I can reproduce the problem with just google/mt5-small and 2 gpus:
We will get it sorted out today.
OK, the problem had nothing to do with DeepSpeed; it's just a seq2seq oversight. The fix is:
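A minimal sketch of the likely shape of the fix, assuming the crash came from the `freeze_embeds` helper in the seq2seq example only special-casing the `t5` model type (see #9879 for the actual change):

```python
# Sketch only -- not necessarily the exact patch. The assumption is that
# freeze_embeds() dispatched on config.model_type and "mt5" fell through
# to the BART-style branch, which accesses attributes mt5 doesn't have.

def freeze_params(module):
    """Disable gradient updates for all parameters of a module."""
    for p in module.parameters():
        p.requires_grad = False

def freeze_embeds(model):
    """Freeze token embeddings; mt5 shares t5's layout, so treat it the same."""
    model_type = model.config.model_type
    if model_type in ("t5", "mt5"):            # previously: == "t5"
        freeze_params(model.shared)
        for block in (model.encoder, model.decoder):
            freeze_params(block.embed_tokens)
    else:
        freeze_params(model.model.shared)
        for block in (model.model.encoder, model.model.decoder):
            freeze_params(block.embed_positions)
            freeze_params(block.embed_tokens)
```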
Please let me know if you can manage to apply this fix. I will make a proper PR later, but it'll take some work, since I need to make a tiny mt5 model and add a test. You can just edit the file if you don't know how to apply a patch.
The fix should be merged shortly: #9879
I can solve the `--freeze_embeds` problem with this fix. As for questions 3 and 4, I noticed that the title of the issue has been edited. I don't know if these questions are caused by the model or by the seq2seq trainer. Maybe I should raise them in a new issue?
Oh, you wrote those items as steps to reproduce the problem, so I didn't know that they were issues that needed to (or could) be fixed. Once I discovered that the issue you posted was unrelated to DeepSpeed, I took the liberty of adjusting the subject. In general, yes, let's try to keep each issue separate; that makes it much easier to track things and not let them fall through the cracks. Back to your follow-up question, looking just at the params:
So the 2nd model is substantially larger, and if t5-3b fit tightly onto a 24GB card, it's not surprising that the larger model didn't. And in addition to the model params you also need to allocate memory for gradients, optimizer states, and activations.
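A rough back-of-the-envelope calculation, assuming mT5-xl has roughly 3.7B parameters and fp32 training states, shows why a 24GB card gets tight even before activations are counted:

```python
# Rough sketch; the 3.7e9 parameter count for mT5-xl is an assumption.
params = 3.7e9
fp32_bytes = 4

weights_gb = params * fp32_bytes / 2**30   # ~13.8 GB
grads_gb = weights_gb                      # fp32 gradients
adam_gb = 2 * weights_gb                   # Adam momentum + variance

print(f"weights ~{weights_gb:.1f} GB, grads ~{grads_gb:.1f} GB, optimizer ~{adam_gb:.1f} GB")
# ZeRO stage 2 shards the gradients and optimizer states across GPUs (and can
# offload them to CPU), but every GPU still holds the full set of weights.
```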
I tried mt5-xl on a 4x 40GB GPU setup and it worked, but it took ~29GB on each GPU, so there is the problem: you're 5GB short. The command I ran was:
You may try to tweak the buffer sizes in the DeepSpeed config. I'm working on a 2D parallelism solution that will combine pipe|model parallelism with ZeRO-DP (DeepSpeed), which should enable such feats with huge models, but it might take some time. The docs aren't quite there, so it takes a lot of trial and error to move forward. You may want to track PR #9765 for updates. Alternatively, when fairscale or DeepSpeed releases ZeRO stage 3, you shouldn't have a problem loading this model onto 4x 24GB GPUs. Currently the problem is that the model params are too big without stage 3; in stage 3 the params are partitioned too, problem solved.
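The buffer sizes in question are presumably the ZeRO bucket sizes in the DeepSpeed configuration; here is a sketch (expressed as a Python dict, with purely illustrative values) of the kind of knobs one might reduce:

```python
# Illustrative only: smaller bucket sizes lower peak GPU memory at the cost of
# some communication efficiency. The values below are not recommendations.
ds_config = {
    "zero_optimization": {
        "stage": 2,
        "cpu_offload": True,
        "allgather_bucket_size": 2e8,   # shrink from the default to free memory
        "reduce_bucket_size": 2e8,
    },
    "train_micro_batch_size_per_gpu": 1,
}
```

Once ZeRO stage 3 is available, this kind of tuning matters less for fitting the model, since the parameters themselves are partitioned across GPUs as well.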
That helps a lot! Thank you! I am also looking forward to ZeRO stage 3 and your pipe|model parallelism. I hope one day we can work on it. Thank you again!
Did you get
mT5-xl is actually quite a bit bigger than T5-3b, for two reasons:
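One concrete, checkable difference is the vocabulary size (roughly 250k tokens for mT5 vs roughly 32k for T5), which alone makes mT5's embedding matrices far larger. A small sketch to compare the two configs, assuming both checkpoints are fetched from the Hub:

```python
from transformers import AutoConfig

# Compare the model configs without downloading any weights; mT5's much larger
# vocabulary dominates the difference in embedding size.
for name in ("t5-3b", "google/mt5-xl"):
    cfg = AutoConfig.from_pretrained(name)
    print(name, "vocab_size:", cfg.vocab_size, "d_model:", cfg.d_model)
```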
@patil-suraj That's very helpful, thank you a lot! Now I understand that there are many differences between mT5-xl and T5-3b, and I will set up separate experiments for them in the future. By the way, do you have any plans to fix fp16 in mt5-large/xl?
Dear @patil-suraj, here you mentioned that you made mt5-small work with fp16? Since you did not mention this model, do you mind telling me how you made it work? I am having a hard time with mt5-small and fp16. Thanks a lot for your advice.
I have a similar error here:

```python
from transformers import T5TokenizerFast, MT5ForConditionalGeneration

tokenizer = T5TokenizerFast.from_pretrained('google/mt5-base')  # "google/mt5-base" "google/mt5-large" "google/mt5-xl"
model = MT5ForConditionalGeneration.from_pretrained('google/mt5-base', return_dict=True)

condition = "translate English to German: "
input = "My name is Azeem and I live in India"
# You can also use "translate English to French" and "translate English to Romanian"
input_ids = tokenizer(condition + input, return_tensors="pt").input_ids  # Batch size 1

outputs = model.generate(input_ids)
decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(decoded)
```

Stacktrace:
@stas00 any idea? I'm using HF master:
Environment info

- transformers version: 4.2.2

Who can help

@stas00, @patrickvonplaten, @patil-suraj

Information

Model I am using (MT5-xl, MT5-large):

The problem arises when using:
The tasks I am working on is:

To reproduce

Steps to reproduce the behavior:

- I used examples/seq2seq/finetune_trainer.py, which was originally used to reproduce the training of T5-3b on a single 3090. All processes are the same as in #8771, and it can reproduce the training of T5-3b (whether on a single card or on 2/4 cards).
- --freeze_embeds seems to bring bugs. I used 4x 3090. My script is:
  Here is my report:
- Without --freeze_embeds I tried to train MT5-xl again, but I got CUDA out of memory. My device is 4x 24GB 3090, with BS=1, ZeRO stage=2, and CPU_offload=true. I assume that T5-3b and MT5-xl should be in the same order of magnitude, and since I can do it with t5-3b, I think this should not happen.

Expected behavior