Commit ecb7677

Make onnx export SDPA match aten behavior (#2479)
This PR makes the ONNX export of SDPA match the behavior of aten SDPA when a boolean mask is used. Without this change, the following script fails the assertion because the ORT model outputs NaNs:

```python
import onnxruntime as ort
import torch


class ScaledDotProductAttention(torch.nn.Module):
    def forward(self, query, key, value, attn_mask):
        return torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask)


model = ScaledDotProductAttention()
attn_mask = torch.ones(2, 4, 8, 8).bool()  # boolean mask for attention
attn_mask[0, 0, 0, :] = False  # masking an entire row (padding token)
query = key = value = torch.randn(2, 4, 8, 16)
output = model(query, key, value, attn_mask)

torch.onnx.export(
    model,
    (query, key, value, attn_mask),
    "scaled_dot_product_attention.onnx",
    input_names=["query", "key", "value", "attn_mask"],
    output_names=["output"],
    opset_version=18,
    dynamo=True,  # or False
)

ort_session = ort.InferenceSession("scaled_dot_product_attention.onnx")
np_inputs = {
    "query": query.numpy(),
    "key": key.numpy(),
    "value": value.numpy(),
    "attn_mask": attn_mask.numpy(),
}
onnx_outputs = ort_session.run(None, np_inputs)[0]

torch.testing.assert_close(output, torch.tensor(onnx_outputs), equal_nan=True)
```
1 parent 32f2196 commit ecb7677

1 file changed: +5 -0 lines changed
onnxscript/function_libs/torch_lib/ops/nn.py

Lines changed: 5 additions & 0 deletions
@@ -2076,6 +2076,11 @@ def _aten_scaled_dot_product_attention_bool_mask_onnx(
         op.Add(op.MatMul(query_scaled, key_transposed_scaled), attn_mask),
         axis=-1,
     )
+    # When using scaled dot product attention with a boolean mask, the softmax operation might return NaN values
+    # due to the presence of -inf in an entire row (padding tokens), resulting in 0/0 (NaN) in the softmax output.
+    # This is because there's no safe/masked softmax imp in ONNX, so we need to handle NaN values explicitly to match
+    # the behavior of PyTorch with boolean masks.
+    attn_weight = op.Where(op.IsNaN(attn_weight), zero, attn_weight)
     attn_weight, _ = op.Dropout(attn_weight, dropout_p)
     return op.MatMul(attn_weight, value)
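
The effect of the added `op.Where(op.IsNaN(...))` line can be illustrated outside ONNX. The following is a minimal NumPy sketch, not the exporter's code: the helper names `softmax_last_axis` and `zero_out_nan` are hypothetical, and it assumes the boolean mask has already been turned into an additive mask of 0 (keep) and -inf (drop) before the softmax.

```python
import numpy as np


def softmax_last_axis(scores):
    """Plain softmax over the last axis, with no special handling of fully masked rows."""
    with np.errstate(invalid="ignore", divide="ignore"):
        # For a row that is entirely -inf, the row max is -inf and -inf - (-inf) is NaN.
        shifted = scores - scores.max(axis=-1, keepdims=True)
        exp = np.exp(shifted)
        return exp / exp.sum(axis=-1, keepdims=True)


def zero_out_nan(weights):
    """NumPy analogue of op.Where(op.IsNaN(weights), zero, weights)."""
    return np.where(np.isnan(weights), 0.0, weights)


# Two query rows attending over three keys (illustrative values only).
scores = np.array([[0.5, 1.0, -0.5], [0.2, 0.1, 0.3]])

# True = attend, False = masked. Row 0 is fully masked, like a padding token.
bool_mask = np.array([[False, False, False], [True, True, False]])

# Assumed conversion from a boolean mask to an additive mask.
additive_mask = np.where(bool_mask, 0.0, -np.inf)

raw = softmax_last_axis(scores + additive_mask)  # row 0 is all NaN
fixed = zero_out_nan(raw)  # row 0 becomes all zeros, row 1 is unchanged
```

Only the fully masked row is affected: partially masked rows already produce finite weights that sum to 1, so replacing NaN with zero leaves them untouched.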
