
[performance] ensure causal_mask is created directly on device #22378

Merged (6 commits, Mar 28, 2023)

Conversation

jeffra (Contributor, Author) commented Mar 25, 2023

What does this PR do?

@tjruwase and @tohtana discovered that causal_mask is currently being created on CPU and then moved to the GPU during the forward pass of OPT (and, we think, other models). This appears to cause a significant performance degradation in multi-GPU environments due to the parallel host-to-device copies going on. It's not 100% clear to us why this is so costly, but here is what we observe before and after this patch:

Before this patch, with OPT-125m on 8x A100s:
[image: wall-clock timing breakdown before the patch]

After the patch:
[image: wall-clock timing breakdown after the patch]

These numbers were gathered from a modified version of https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_clm.py with wall_clock_breakdown: true enabled in our DeepSpeed config.

One major complication we see in accepting this PR is that the two functions being modified are copied across many different models, and the make fix-copies script doesn't seem to handle all of them correctly for both _make_causal_mask and _prepare_decoder_attention_mask.
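
The essence of the change is to allocate the mask tensors with device= so they are created on the target device instead of on CPU and copied over later. A simplified before/after sketch (illustrative only; in the real model code the host-to-device copy happens later, in _prepare_decoder_attention_mask):

    import torch

    def make_causal_mask_old(tgt_len: int, dtype: torch.dtype, device: torch.device):
        # old pattern: tensors are allocated on CPU by default and moved to the GPU later
        mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min)
        mask_cond = torch.arange(mask.size(-1))
        mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
        return mask.to(dtype).to(device)  # host-to-device copy

    def make_causal_mask_new(tgt_len: int, dtype: torch.dtype, device: torch.device):
        # patched pattern: allocate directly on the target device, no extra copy
        mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
        mask_cond = torch.arange(mask.size(-1), device=device)
        mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
        return mask.to(dtype)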

Who can review?

Tagging @sgugger and @stas00 to help triage to the right people

stas00 changed the title from "ensure causal_mask is created directly on device" to "[performance] ensure causal_mask is created directly on device" on Mar 25, 2023
HuggingFaceDocBuilderDev commented Mar 25, 2023

The documentation is not available anymore as the PR was closed or merged.

ydshieh (Collaborator) commented Mar 25, 2023

cc @thomasw21 @NouamaneTazi since both of you are experts on this kind of thing - to see if you have any general opinion and/or if you would like to review this PR too.

ydshieh (Collaborator) commented Mar 25, 2023

@jeffra Would it be possible for you (and/or @tjruwase and @tohtana) to provide the script that finds/measures/profiles the running time for this issue 🙏? It would be super helpful for us to dig into this internally too.
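
In the meantime, a generic way to time a single step in isolation (a rough sketch, not the script used to gather the numbers above; assumes a CUDA device):

    import torch

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    start.record()
    # ... run one forward (or forward/backward) step of the model here ...
    end.record()
    torch.cuda.synchronize()  # wait for the recorded events before reading the timing
    print(f"step time: {start.elapsed_time(end):.2f} ms")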

NouamaneTazi (Member) left a comment

I'm a big supporter of removing CPU-GPU syncs, so I would very much like to see this merged! ⚡️

Comment on lines +74 to +75:

    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
    mask_cond = torch.arange(mask.size(-1), device=device)

Suggested change:

    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
    mask_cond = torch.arange(mask.size(-1), device=device)

It seems that torch.tensor(torch.finfo(torch.float32).min, device="cuda") requires a CPU-GPU sync.
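
One way to surface such syncs (a rough sketch; assumes PyTorch 1.10+, where torch.cuda.set_sync_debug_mode is available, and whether this particular scalar copy gets flagged can depend on the PyTorch version):

    import torch

    torch.cuda.set_sync_debug_mode("warn")  # warn whenever an op synchronizes with the GPU

    dtype = torch.float16
    # fill_value built as a CUDA tensor first (the pattern quoted above):
    a = torch.full((4, 4), torch.tensor(torch.finfo(dtype).min, device="cuda"), device="cuda")
    # fill_value passed as a plain Python scalar (the suggested change):
    b = torch.full((4, 4), torch.finfo(dtype).min, device="cuda")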

NouamaneTazi (Member) commented Mar 27, 2023

Otherwise, if a tensor is needed, we can also do:

    torch.cuda.FloatTensor([torch.finfo(dtype).min])  # no sync
    torch.ones(1, device=device) * torch.finfo(dtype).min  # no sync

jeffra (Contributor, Author) replied

Interesting, it would be nice to remove that additional CPU-GPU sync. However, I don't think this would work in cases where device is cpu or any other non-CUDA device.

NouamaneTazi (Member) replied

I just tried my suggestion:

    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)

and it does work on "cpu". I'm not sure why it wouldn't work? 🤔

A contributor replied

From the PyTorch docs (https://pytorch.org/docs/stable/generated/torch.full.html), torch.full accepts a scalar for fill_value. So I think what @NouamaneTazi is trying to convey is that you don't need to first cast -inf to a tensor and then fill; you can just fill with the scalar directly.
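
A minimal illustration of that point (sketch; the dtype and device values are arbitrary):

    import torch

    dtype, device = torch.float16, "cpu"  # the same call works with a CUDA device
    fill = torch.finfo(dtype).min         # plain Python float, no tensor wrapper needed
    mask = torch.full((4, 4), fill, dtype=dtype, device=device)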

sgugger (Collaborator) left a comment

LGTM, thanks a lot for the fix! Note that the same modification needs to be applied to BART (since OPT copies from BART) in order for all quality checks to pass.

ydshieh (Collaborator) commented Mar 27, 2023

LGTM, thanks a lot for the fix! Note that the same modification needs to be applied to BART (since OPT copies from BART) in order for all quality checks to pass.

FYI (@sgugger): @stas00 mentioned on Slack:

I tried to support Jeff by telling him how to make copies, but he found that many copies are either not tagged properly or the copied functions were completely renamed, and thus it's very difficult to make an automated transformers-wide fix

and in this PR description, the author(s) wrote:

One major complication we see in accepting this PR is that the two functions being modified are copied across many different models, and the make fix-copies script doesn't seem to handle all of them correctly for both _make_causal_mask and _prepare_decoder_attention_mask.

It's likely that they expect us to help with this part. I can help (I was waiting for approval of the fix in OPT, which is done now).

sgugger (Collaborator) commented Mar 27, 2023

I think just copying the same fix to BART and then applying make fix-copies is simple enough for this PR. Dealing with functions that are not copies or are named differently can indeed be done in followup PRs.
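
For reference, these copies are tracked with # Copied from comments that make fix-copies reads to propagate changes from the source function to its copies. A rough illustration of the tag (the OPT copy of BART's helper, as added in this PR; signature abbreviated):

    # Copied from transformers.models.bart.modeling_bart._make_causal_mask
    def _make_causal_mask(input_ids_shape, dtype, device, past_key_values_length=0):
        ...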

jeffra (Contributor, Author) commented Mar 27, 2023

Ok, I've updated the BART implementation and attempted to get make fix-copies to work for me, but I think I might be doing something wrong. Some of the original issues I saw are now fixed in other models (e.g., #22382 adds a # Copied from tag for llama). However, I am still seeing issues that I think come from the fix-up scripts getting confused by the function signature change of _make_causal_mask. Also, I added the # Copied from tag into OPT for _make_causal_mask, which I think was part of my previous issue.

Can someone try make fix-copies on their side with this? You should be able to push to my branch.

For example, here's the diff of src/transformers/models/xglm/modeling_xglm.py after applying make fix-copies on this branch; note that it does not add device as an argument to _make_causal_mask:

diff --git a/src/transformers/models/xglm/modeling_xglm.py b/src/transformers/models/xglm/modeling_xglm.py
index 8a1955793..59851bd85 100755
--- a/src/transformers/models/xglm/modeling_xglm.py
+++ b/src/transformers/models/xglm/modeling_xglm.py
@@ -119,13 +119,13 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
     Make causal mask used for bi-directional self-attention.
     """
     bsz, tgt_len = input_ids_shape
-    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
-    mask_cond = torch.arange(mask.size(-1))
+    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
+    mask_cond = torch.arange(mask.size(-1), device=device)
     mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
     mask = mask.to(dtype)

     if past_key_values_length > 0:
-        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1)
+        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
     return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)

It modifies all of these models, so ideally I don't want to edit these manually :)

        modified:   src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py
        modified:   src/transformers/models/biogpt/modeling_biogpt.py
        modified:   src/transformers/models/blenderbot/modeling_blenderbot.py
        modified:   src/transformers/models/blenderbot_small/modeling_blenderbot_small.py
        modified:   src/transformers/models/informer/modeling_informer.py
        modified:   src/transformers/models/llama/modeling_llama.py
        modified:   src/transformers/models/m2m_100/modeling_m2m_100.py
        modified:   src/transformers/models/marian/modeling_marian.py
        modified:   src/transformers/models/mbart/modeling_mbart.py
        modified:   src/transformers/models/mvp/modeling_mvp.py
        modified:   src/transformers/models/nllb_moe/modeling_nllb_moe.py
        modified:   src/transformers/models/pegasus/modeling_pegasus.py
        modified:   src/transformers/models/pegasus_x/modeling_pegasus_x.py
        modified:   src/transformers/models/plbart/modeling_plbart.py
        modified:   src/transformers/models/speech_to_text/modeling_speech_to_text.py
        modified:   src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py
        modified:   src/transformers/models/speecht5/modeling_speecht5.py
        modified:   src/transformers/models/time_series_transformer/modeling_time_series_transformer.py
        modified:   src/transformers/models/trocr/modeling_trocr.py
        modified:   src/transformers/models/whisper/modeling_whisper.py
        modified:   src/transformers/models/xglm/modeling_xglm.py

sgugger (Collaborator) commented Mar 27, 2023

Ah yes, make fix-copies does not change the signature of the function so that is indeed something to edit manually. If it's too much work I can try to push this to your branch tomorrow.
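
Concretely, the manual part is threading the new device parameter through the signature of each copy so the device-aware body can then be synced; a rough sketch of the edit (parameter placement and type hints illustrative):

    # before, in each copied model file:
    def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0):
        ...

    # after, matching the updated BART/OPT helper:
    def _make_causal_mask(
        input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
    ):
        ...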

jeffra (Contributor, Author) commented Mar 27, 2023

Ah yes, make fix-copies does not change the signature of the function so that is indeed something to edit manually. If it's too much work I can try to push this to your branch tomorrow.

Sounds good, I might have some time this afternoon for this. Otherwise feel free to do it :) Just wasn't sure if this was an expected issue with the copy scripts or not.

jeffra (Contributor, Author) commented Mar 28, 2023

Okay, all the models should be fixed now; make fixup is clean in my local tests.

sgugger (Collaborator) left a comment

Thanks for copy-pasting the signature changes manually!

@sgugger sgugger merged commit ae5fc2d into huggingface:main Mar 28, 2023
raghavanone pushed a commit to raghavanone/transformers that referenced this pull request (Apr 5, 2023):

[performance] ensure causal_mask is created directly on device (huggingface#22378)

* ensure causal_mask is created directly on device
* add copy tag to opt, update bart implementation
* add device to all _make_causal_mask copies
* formatting fixes
* more manual fixes due to unlinked versions of _prepare_decoder_attention_mask
xloem (Apr 9 and 10, 2023) and novice03 (Jun 23, 2023) pushed the same referenced commit to their forks.