Fix torch.compile with fullgraph=True when attention_mask input is used #29211
Conversation
Let's consider using pytorch/pytorch#120400 if this is accepted and released.
I guess we don't have a choice?
```python
# Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
causal_mask = causal_mask.mul(~torch.all(causal_mask == causal_mask.min(), dim=-1, keepdim=True)).to(dtype)
```
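For context, a minimal standalone sketch of what this line does (shapes and dtype are assumed here, not taken from the PR): query rows of the additive mask that are entirely masked, e.g. left-padding rows, are zeroed out so they attend everywhere instead of producing NaNs in SDPA's memory-efficient kernel.

```python
import torch

# An additive causal mask of shape (batch, heads, query_len, kv_len) where
# masked slots hold the dtype min; on/below-diagonal slots are 0.
min_value = torch.finfo(torch.float32).min
causal_mask = torch.full((1, 1, 4, 4), min_value).triu(diagonal=1)

# Simulate left padding: the first query row becomes fully masked.
causal_mask[..., 0, :] = min_value

# Rows that are entirely min_value make F.scaled_dot_product_attention's
# memory-efficient path return NaNs (pytorch/pytorch#110213). Multiplying
# those rows by 0 lets them attend to all tokens instead.
fully_masked_rows = torch.all(causal_mask == causal_mask.min(), dim=-1, keepdim=True)
causal_mask = causal_mask.mul(~fully_masked_rows)

print(causal_mask[0, 0, 0])  # first row is now all zeros
```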
Kind of related to #29210
@ArthurZucker this would conflict but is unrelated
```python
is_tracing = (
    torch.jit.is_tracing()
    or isinstance(input_tensor, torch.fx.Proxy)
    or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
)
```
🤢
agree
The other choice would be:

```python
is_tracing = (
    torch.jit.is_tracing()
    or isinstance(input_tensor, torch.fx.Proxy)
    or torch._dynamo.is_fullgraph_tracing()
)
```

One other possibility is to always do …

One other possibility is to not use …

One other possibility is to move the causal mask logic outside of the modeling code.
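To illustrate that last option, a hypothetical sketch (the function name and signature are mine, not from this PR) of building the 4D mask outside the model and passing it in, so the forward pass needs no tracing check at all:

```python
import torch

def build_causal_mask(attention_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """Build a (batch, 1, seq_len, seq_len) additive causal mask outside the model."""
    bsz, seq_len = attention_mask.shape
    min_value = torch.finfo(dtype).min
    # Causal structure: position i may attend to positions <= i.
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    mask = torch.where(causal, 0.0, min_value).to(dtype)
    mask = mask[None, None, :, :].expand(bsz, 1, seq_len, seq_len).clone()
    # Fold in padding from the 2D attention_mask (0 = padded token).
    mask = mask.masked_fill(attention_mask[:, None, None, :] == 0, min_value)
    return mask

# The modeling code would then consume the precomputed mask directly, e.g.:
# out = model(input_ids, attention_mask=build_causal_mask(attention_mask, torch.float16))
```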
Alright, let's make sure CIs are green and benches are not slower!
Thanks for the fix!
@kwen2501 Thank you! So torch.export is not always using dynamo? Reading https://pytorch.org/docs/stable/export.html I thought it was!
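For reference, a minimal torch.export usage sketch (toy module, purely for illustration; it does not settle the dynamo question above):

```python
import torch

class Toy(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.relu(x) + 1

# torch.export.export captures a single full graph of the module ahead of time.
exported = torch.export.export(Toy(), (torch.randn(2, 3),))
print(exported.graph_module.graph)  # the captured FX graph
```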
Fix torch.compile with fullgraph=True when attention_mask input is used (#29211)

* fix torch.export.export for llama
* do not change doc title
* make fix copies
As per title.
Fixes #29190
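For completeness, a sketch of the usage this PR is meant to enable (the checkpoint name is only an example; any supported model works):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"  # example checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token  # Llama has no pad token by default
model = AutoModelForCausalLM.from_pretrained(model_id)

# fullgraph=True forbids graph breaks; this previously failed when an
# attention_mask input was passed (see #29190).
compiled = torch.compile(model, fullgraph=True)

inputs = tokenizer(["Hello", "A longer prompt"], padding=True, return_tensors="pt")
with torch.no_grad():
    out = compiled(**inputs)
```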