
Conversation

IlyasMoutawwakil
Contributor

This PR makes the ONNX SDPA export match the behavior of aten SDPA when a boolean attention mask is used. Before the fix, the following repro script:

import onnxruntime as ort
import torch


class ScaledDotProductAttention(torch.nn.Module):
    def forward(self, query, key, value, attn_mask):
        return torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask)


model = ScaledDotProductAttention()
attn_mask = torch.ones(2, 4, 8, 8).bool()  # boolean mask for attention
attn_mask[0, 0, 0, :] = False  # masking an entire row (padding token)
query = key = value = torch.randn(2, 4, 8, 16)
output = model(query, key, value, attn_mask)

torch.onnx.export(
    model,
    (query, key, value, attn_mask),
    "scaled_dot_product_attention.onnx",
    input_names=["query", "key", "value", "attn_mask"],
    output_names=["output"],
    opset_version=18,
    dynamo=True, # or False
)
ort_session = ort.InferenceSession("scaled_dot_product_attention.onnx")

np_inputs = {"query": query.numpy(), "key": key.numpy(), "value": value.numpy(), "attn_mask": attn_mask.numpy()}
onnx_outputs = ort_session.run(None, np_inputs)[0]

torch.testing.assert_close(output, torch.tensor(onnx_outputs), equal_nan=True)

fails the assertion because the ORT model outputs NaNs for the fully masked query row.
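For context: a query row whose boolean mask is entirely False becomes a row of -inf scores, and a plain Softmax over all -inf is 0/0, i.e. NaN, which is what ONNX Runtime computes here. aten's SDPA zeroes out those NaN weights instead (see pytorch/pytorch#103749), and that is the behavior this PR reproduces in the exported graph. A minimal sketch of where the NaNs come from, using only stock PyTorch:

import torch

# A fully masked row becomes all -inf before the softmax...
scores = torch.full((1, 4), float("-inf"))

# ...and softmax over a row of -inf is 0/0, i.e. NaN for every element.
print(torch.softmax(scores, dim=-1))  # tensor([[nan, nan, nan, nan]])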

@IlyasMoutawwakil
Contributor Author

@titaiwangms @justinchuby

@titaiwangms titaiwangms enabled auto-merge (squash) August 7, 2025 15:55

codecov bot commented Aug 7, 2025

Codecov Report

❌ Patch coverage is 0% with 1 line in your changes missing coverage. Please review.
✅ Project coverage is 69.81%. Comparing base (32f2196) to head (0068e40).
⚠️ Report is 3 commits behind head on main.

Files with missing lines                        Patch %   Lines
onnxscript/function_libs/torch_lib/ops/nn.py    0.00%     1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2479      +/-   ##
==========================================
- Coverage   69.81%   69.81%   -0.01%     
==========================================
  Files         209      209              
  Lines       25313    25314       +1     
  Branches     2525     2525              
==========================================
  Hits        17673    17673              
- Misses       6762     6763       +1     
  Partials      878      878              


@titaiwangms titaiwangms disabled auto-merge August 7, 2025 16:26
@titaiwangms titaiwangms merged commit ecb7677 into microsoft:main Aug 7, 2025
25 of 32 checks passed
@github-project-automation github-project-automation bot moved this from Todo to Done in ONNX Script Review Board Aug 7, 2025
titaiwangms added a commit that referenced this pull request Aug 8, 2025
@justinchuby
Collaborator

# This is because there's no safe/masked softmax implementation in ONNX, so we need to handle NaN values explicitly to match
# the behavior of PyTorch with boolean masks.
attn_weight = op.Where(op.IsNaN(attn_weight), zero, attn_weight)
attn_weight, _ = op.Dropout(attn_weight, dropout_p)

Collaborator

@titaiwangms we should probably conditionally skip this line (even though there is a rewrite rule already)

Contributor

OK

Collaborator

If you fix this, can you also please add a reference to pytorch/pytorch#103749 in the comments for the previous line fixing NaN?

Contributor

We skip when dropout_p is 0?
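For illustration, a minimal sketch of what the conditional skip could look like in the torch_lib function, assuming dropout_p is a plain Python float available at conversion time (hypothetical placement, not the merged code):

if dropout_p != 0.0:
    # Only emit a Dropout node when dropout is actually requested;
    # with dropout_p == 0 the node is a no-op and can be skipped.
    attn_weight, _ = op.Dropout(attn_weight, dropout_p)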

@justinchuby added the module: torchlib label (Related to the torch/aten function lib in development) Aug 8, 2025