onnxscript/function_libs/torch_lib/ops (1 file changed: +5 -0 lines)

@@ -2076,6 +2076,11 @@ def _aten_scaled_dot_product_attention_bool_mask_onnx(
        op.Add(op.MatMul(query_scaled, key_transposed_scaled), attn_mask),
        axis=-1,
    )
+    # When using scaled dot product attention with a boolean mask, the softmax operation
+    # might return NaN values because -inf can fill an entire row (padding tokens),
+    # producing 0/0 (NaN) in the softmax output. ONNX has no safe/masked softmax
+    # implementation, so we handle NaN explicitly to match PyTorch's boolean-mask behavior.
+    attn_weight = op.Where(op.IsNaN(attn_weight), zero, attn_weight)
    attn_weight, _ = op.Dropout(attn_weight, dropout_p)
    return op.MatMul(attn_weight, value)
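For context, here is a minimal sketch (not part of the diff, and using plain PyTorch rather than onnxscript) that reproduces the NaN behavior the new comment describes; the tensor values are illustrative assumptions:

import torch

# A fully masked row (all -inf scores, e.g. a padding-only query) makes softmax
# compute exp(-inf) = 0 for every entry; the row then sums to 0 and 0/0 = NaN.
scores = torch.tensor([[-float("inf"), -float("inf"), -float("inf")],
                       [0.0, 1.0, 2.0]])
attn_weight = torch.softmax(scores, dim=-1)
print(attn_weight[0])  # tensor([nan, nan, nan])

# The patch applies the analogous repair with ONNX ops (IsNaN + Where): replace
# NaN with zero so a fully masked query contributes nothing to the output.
attn_weight = torch.where(torch.isnan(attn_weight),
                          torch.zeros_like(attn_weight),
                          attn_weight)
print(attn_weight[0])  # tensor([0., 0., 0.])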