[performance] ensure causal_mask is created directly on device #22378
Conversation
The documentation is not available anymore as the PR was closed or merged.
cc @thomasw21 @NouamaneTazi since both of you are experts on this kind of thing - to see if you have any general opinion and/or if you would like to review this PR too.
I'm a big supporter of removing CPU-GPU syncs, so I would very much like to see this merged! ⚡️
mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
mask_cond = torch.arange(mask.size(-1), device=device)
Suggested change:
mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
mask_cond = torch.arange(mask.size(-1), device=device)
It seems that torch.tensor(torch.finfo(torch.float32).min, device="cuda") requires a CPU-GPU sync.
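A minimal standalone sketch of the two constructions being compared here (shapes and names are illustrative, not the exact library code): the first materializes the fill value as a separate tensor before filling, while the second passes the Python scalar straight to torch.full so the mask is allocated directly on the target device.

```python
import torch

dtype = torch.float32
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tgt_len = 4  # illustrative sequence length

# Current pattern: the fill value is first wrapped in a tensor on `device`,
# which is the step suspected of forcing the extra CPU-GPU round trip.
mask_old = torch.full(
    (tgt_len, tgt_len),
    torch.tensor(torch.finfo(dtype).min, device=device),
    device=device,
)

# Suggested pattern: pass the Python scalar directly, so the mask is created
# on `device` without any intermediate tensor.
mask_new = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)

assert torch.equal(mask_old, mask_new)
```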
Otherwise, if a tensor is needed, we can also do:
torch.cuda.FloatTensor([torch.finfo(dtype).min]) # no sync
torch.ones(1, device=device) * torch.finfo(dtype).min # no sync
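A quick sketch checking those two alternatives (the torch.cuda.FloatTensor variant is guarded since it requires CUDA; the torch.ones variant works on any device):

```python
import torch

dtype = torch.float32
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
min_val = torch.finfo(dtype).min

# Works on any device: build a one-element tensor on the target device and scale it.
t1 = torch.ones(1, device=device) * min_val

if torch.cuda.is_available():
    # Legacy CUDA-only constructor; allocates directly on the current CUDA device.
    t2 = torch.cuda.FloatTensor([min_val])
    assert torch.equal(t1, t2)

print(t1)  # e.g. tensor([-3.4028e+38], device='cuda:0') when CUDA is available
```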
Interesting, it would be nice to remove that additional CPU-GPU sync. However, I don't think this would work in a case where device=cpu or any other non-CUDA device.
I just tried my suggestion: mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) and it does work on "cpu". I'm not sure why it wouldn't work? 🤔
From the PyTorch docs (https://pytorch.org/docs/stable/generated/torch.full.html), torch.full accepts a scalar for fill_value. So I think what @NouamaneTazi is trying to convey is that you don't need to first cast -inf to a tensor and then fill; you can just fill with -inf directly.
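A tiny illustration of that point (values are arbitrary): fill_value can be a plain Python scalar, and the same call works on CPU as well.

```python
import torch

dtype = torch.float32

# fill_value is a plain Python scalar; no intermediate tensor is needed,
# and the same call works on CPU as well as on CUDA devices.
mask = torch.full((3, 3), torch.finfo(dtype).min, device="cpu")
print(mask.dtype, mask.device)  # torch.float32 cpu
```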
LGTM, thanks a lot for the fix! Note that the same modification needs to be applied to BART (since OPT copies from BART) in order for all quality checks to pass.
FYI (@sgugger): @stas00 mentioned on Slack, and the author(s) note in this PR description, that the copied versions of these functions are a complication. It's likely that they expect us to help on this part. I can help (I was waiting for the approval of the fix in …).
I think just copying the same fix to BART and then applying make fix-copies should work.
Ok, I've updated the BART implementation and attempted to get make fix-copies to propagate the change, but it doesn't seem to apply it to all the copies correctly. Can someone try it on their end? For example, here's the diff I would expect for modeling_xglm.py:

diff --git a/src/transformers/models/xglm/modeling_xglm.py b/src/transformers/models/xglm/modeling_xglm.py
index 8a1955793..59851bd85 100755
--- a/src/transformers/models/xglm/modeling_xglm.py
+++ b/src/transformers/models/xglm/modeling_xglm.py
@@ -119,13 +119,13 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
     Make causal mask used for bi-directional self-attention.
     """
     bsz, tgt_len = input_ids_shape
-    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
-    mask_cond = torch.arange(mask.size(-1))
+    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
+    mask_cond = torch.arange(mask.size(-1), device=device)
     mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
     mask = mask.to(dtype)
     if past_key_values_length > 0:
-        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1)
+        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
     return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)

The same change applies to all of these models, so ideally I don't want to edit them manually :)
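For context, a rough sketch of how such a copied function is marked so that make fix-copies can keep it in sync; the module path and the exact post-fix signature below are illustrative assumptions, not taken verbatim from this PR.

```python
import torch

# Illustrative only: in the transformers codebase, a function copied from another
# model carries a "Copied from" marker, and `make fix-copies` re-syncs the body
# from the referenced source definition.
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(
    input_ids_shape: torch.Size,
    dtype: torch.dtype,
    device: torch.device,
    past_key_values_length: int = 0,
):
    ...
```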
Ah yes, …
Sounds good, I might have some time this afternoon for this. Otherwise feel free to do it :) Just wasn't sure if this was an expected issue with the copy scripts or not.
Okay, all the models should be fixed now.
Thanks for copy-pasting the signature changes manually!
[performance] ensure causal_mask is created directly on device (huggingface#22378)
* ensure causal_mask is created directly on device
* add copy tag to opt, update bart implementation
* add device to all _make_causal_mask copies
* formatting fixes
* more manual fixes due to unlinked versions of _prepare_decoder_attention_mask
What does this PR do?
@tjruwase and @tohtana discovered that causal_mask is currently being created on CPU and then moved to GPU during the forward pass of OPT (and, we think, other models). This appears to be causing a significant performance degradation in multi-GPU environments due to the parallel host-to-device copies going on. It's not 100% clear to us why this is so bad, but here is what we observe before and after this patch:
Before this patch w. OPT-125m on x8 A100s:
After the patch:
These numbers were gathered from a modified version of https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_clm.py, with wall_clock_breakdown: true turned on in our DeepSpeed config.

One major complication we see in accepting this PR is that the two functions being modified are copied across lots of different models, and the make fix-copies script doesn't seem to address all of them correctly across both _make_causal_mask and _prepare_decoder_attention_mask.
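For reference, a minimal sketch of a DeepSpeed config with that flag enabled; only wall_clock_breakdown comes from the description above, the other keys are placeholder values.

```python
# Illustrative DeepSpeed config dict; only `wall_clock_breakdown` comes from the
# description above, the remaining keys are placeholder values.
ds_config = {
    "train_micro_batch_size_per_gpu": 8,   # placeholder
    "gradient_accumulation_steps": 1,      # placeholder
    "fp16": {"enabled": True},             # placeholder
    "wall_clock_breakdown": True,          # report per-step forward/backward/optimizer timings
}
```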
Who can review?
Tagging @sgugger and @stas00 to help triage to the right people.