Make onnx export SDPA match aten behavior #2479
Conversation
Codecov Report ❌ Patch coverage is

Additional details and impacted files

@@            Coverage Diff             @@
##             main    #2479      +/-   ##
==========================================
- Coverage   69.81%   69.81%   -0.01%
==========================================
  Files         209      209
  Lines       25313    25314       +1
  Branches     2525     2525
==========================================
  Hits        17673    17673
- Misses       6762     6763       +1
  Partials      878      878

☔ View full report in Codecov by Sentry.
Do we need to update https://github.com/pytorch/pytorch/blob/main/torch/onnx/_internal/exporter/_torchlib/ops/nn.py as well, or improve the specs of the Attention op? @gramalingam @titaiwangms
# This is because there's no safe/masked softmax implementation in ONNX, so we need to handle NaN values
# explicitly to match the behavior of PyTorch with boolean masks.
attn_weight = op.Where(op.IsNaN(attn_weight), zero, attn_weight)
attn_weight, _ = op.Dropout(attn_weight, dropout_p)
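
For context, a minimal PyTorch sketch (not from the PR; shapes, values, and variable names are made up) of the failure mode the NaN handling addresses: when a boolean mask hides an entire query row, the decomposed softmax divides zero by zero and produces NaNs, which the op.Where/op.IsNaN line above then zeroes out.

import torch
import torch.nn.functional as F

q = torch.randn(1, 1, 2, 4)
k = torch.randn(1, 1, 2, 4)
v = torch.randn(1, 1, 2, 4)
# Boolean mask: True = attend, False = mask out; the second query row is fully masked.
mask = torch.tensor([[True, True], [False, False]])

# Decomposed attention, as an exported graph would compute it.
scores = (q @ k.transpose(-2, -1)) / (q.shape[-1] ** 0.5)
scores = scores.masked_fill(~mask, float("-inf"))
weights = scores.softmax(-1)          # the fully masked row becomes 0/0 -> NaN
weights = torch.nan_to_num(weights)   # same role as op.Where(op.IsNaN(...), zero, ...)
out_decomposed = weights @ v          # finite output, no NaNs propagated

# aten SDPA with the same inputs, which (per this PR) the export should now match.
out_aten = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
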
@titaiwangms we should probably conditionally skip this line (even though there is a rewrite rule already)
OK
If you fix this, can you also please add a reference to pytorch/pytorch#103749 in the comments for the previous line fixing NaN?
We skip when dropout_p is 0?
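
For reference, a minimal sketch of what the conditional skip could look like, assuming dropout_p is a Python float known at trace time and the onnxscript opset18 namespace used above; the helper name _maybe_dropout is hypothetical and not part of this PR.

from onnxscript import opset18 as op


def _maybe_dropout(attn_weight, dropout_p: float):
    # Only emit a Dropout node when dropout is actually requested; for
    # dropout_p == 0.0 Dropout is an identity, so skipping it avoids relying
    # on a later rewrite rule to remove the no-op node.
    if dropout_p != 0.0:
        attn_weight, _ = op.Dropout(attn_weight, dropout_p)
    return attn_weight
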
This PR makes ONNX SDPA export match the behavior of aten SDPA when a boolean mask is used. Without this change, the exported model fails the assertion against the aten reference because the ORT model outputs NaNs.