From d6eccaa50475cdcf06175c92b38419efc04258b6 Mon Sep 17 00:00:00 2001
From: Aniket Singh Yadav
Date: Sun, 26 Oct 2025 19:38:51 +0530
Subject: [PATCH 1/2] Fix attention mask to use float_lowest instead of -inf
 and add unit test for softmax NaN case

---
 onnxscript/function_libs/torch_lib/ops/nn.py | 6 +++++-
 tests/common/testutils.py                    | 7 +++++++
 2 files changed, 12 insertions(+), 1 deletion(-)

diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py
index 4f81cc7907..65bb2aa079 100644
--- a/onnxscript/function_libs/torch_lib/ops/nn.py
+++ b/onnxscript/function_libs/torch_lib/ops/nn.py
@@ -14,6 +14,7 @@
 
 from __future__ import annotations
 
+import numpy as np
 import math
 from typing import Optional, Sequence, Tuple, TypeVar, Union
 
@@ -2048,6 +2049,9 @@ def _aten_scaled_dot_product_attention_no_mask_onnx(
         attn_weight, _ = op.Dropout(attn_weight, dropout_p)
     return op.MatMul(attn_weight, value)
 
+def float_lowest(dtype):
+    """Returns the lowest representable value for the given numpy dtype."""
+    return np.finfo(np.dtype(dtype)).min
 
 def _aten_scaled_dot_product_attention_bool_mask_onnx(
     query: TFloat,
@@ -2078,7 +2082,7 @@ def _aten_scaled_dot_product_attention_bool_mask_onnx(
     key_transposed_scaled = op.Mul(key_transposed, op.Sqrt(scale))
     # Turn the Boolean mask to float: attn_mask.masked_fill(not attn_mask, -float('inf'))
     zero = op.Constant(value=ir.tensor(0.0, dtype=query.dtype))
-    neg_inf = op.Constant(value=ir.tensor(-float("inf"), dtype=query.dtype))
+    neg_inf = op.Constant(value=ir.tensor(float_lowest(query.dtype), dtype=query.dtype))
     attn_mask = op.Where(attn_mask, zero, neg_inf)
     attn_weight = op.Softmax(
         op.Add(op.MatMul(query_scaled, key_transposed_scaled), attn_mask),
diff --git a/tests/common/testutils.py b/tests/common/testutils.py
index 2a2697b240..1db673eab8 100644
--- a/tests/common/testutils.py
+++ b/tests/common/testutils.py
@@ -14,6 +14,7 @@
 import torch
 
 from onnxscript import optimizer
+from onnxscript.onnx_opset import opset18 as op
 from onnxscript.rewriter import onnxruntime as ort_rewriter
 from onnxscript.utils import evaluation_utils
 
@@ -101,3 +102,9 @@ def test_onnxruntime_rewrite(
                     f"Failed for model {model_name} and output {i} with rtol={rtol} and atol={atol}\n{e}"
                 )
                 raise
+
+def test_softmax_with_all_inf_mask():
+    # GH #2561
+    input = np.array([[-float("inf"), -float("inf")]], dtype=np.float32)
+    output = op.Softmax(input, axis=-1)
+    assert np.isnan(output).all(), "Softmax should return NaN when all inputs are -inf"

From 0d7c411d1e381322178f0e5fe252eecd54d4d25a Mon Sep 17 00:00:00 2001
From: Aniket Singh Yadav
Date: Thu, 13 Nov 2025 20:25:40 +0530
Subject: [PATCH 2/2] Remove float_lowest helper and softmax NaN test; use
 query.dtype.min directly

---
 onnxscript/function_libs/torch_lib/ops/nn.py | 6 +-----
 tests/common/testutils.py                    | 7 -------
 2 files changed, 1 insertion(+), 12 deletions(-)

diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py
index 65bb2aa079..6cce402ddf 100644
--- a/onnxscript/function_libs/torch_lib/ops/nn.py
+++ b/onnxscript/function_libs/torch_lib/ops/nn.py
@@ -14,7 +14,6 @@
 
 from __future__ import annotations
 
-import numpy as np
 import math
 from typing import Optional, Sequence, Tuple, TypeVar, Union
 
@@ -2049,9 +2048,6 @@ def _aten_scaled_dot_product_attention_no_mask_onnx(
         attn_weight, _ = op.Dropout(attn_weight, dropout_p)
     return op.MatMul(attn_weight, value)
 
-def float_lowest(dtype):
-    """Returns the lowest representable value for the given numpy dtype."""
-    return np.finfo(np.dtype(dtype)).min
 
 def _aten_scaled_dot_product_attention_bool_mask_onnx(
     query: TFloat,
@@ -2082,7 +2078,7 @@ def _aten_scaled_dot_product_attention_bool_mask_onnx(
     key_transposed_scaled = op.Mul(key_transposed, op.Sqrt(scale))
     # Turn the Boolean mask to float: attn_mask.masked_fill(not attn_mask, -float('inf'))
     zero = op.Constant(value=ir.tensor(0.0, dtype=query.dtype))
-    neg_inf = op.Constant(value=ir.tensor(float_lowest(query.dtype), dtype=query.dtype))
+    neg_inf = op.Constant(value=ir.tensor(query.dtype.min, dtype=query.dtype))
     attn_mask = op.Where(attn_mask, zero, neg_inf)
     attn_weight = op.Softmax(
         op.Add(op.MatMul(query_scaled, key_transposed_scaled), attn_mask),
diff --git a/tests/common/testutils.py b/tests/common/testutils.py
index 1db673eab8..2a2697b240 100644
--- a/tests/common/testutils.py
+++ b/tests/common/testutils.py
@@ -14,7 +14,6 @@
 import torch
 
 from onnxscript import optimizer
-from onnxscript.onnx_opset import opset18 as op
 from onnxscript.rewriter import onnxruntime as ort_rewriter
 from onnxscript.utils import evaluation_utils
 
@@ -102,9 +101,3 @@ def test_onnxruntime_rewrite(
                     f"Failed for model {model_name} and output {i} with rtol={rtol} and atol={atol}\n{e}"
                 )
                 raise
-
-def test_softmax_with_all_inf_mask():
-    # GH #2561
-    input = np.array([[-float("inf"), -float("inf")]], dtype=np.float32)
-    output = op.Softmax(input, axis=-1)
-    assert np.isnan(output).all(), "Softmax should return NaN when all inputs are -inf"
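
A minimal numpy sketch of the failure mode behind GH #2561 that these
patches address: with a fully masked row, a -inf fill makes softmax return
NaN, while the lowest finite float yields a uniform row. The softmax helper
below is hand-rolled for the demo using the usual max-subtraction trick; it
is an illustration, not the ONNX Runtime kernel, and the names softmax,
row_inf, and row_min are hypothetical.

    import numpy as np

    def softmax(x):
        # Max-subtraction for numerical stability. For a row filled with
        # -inf this computes (-inf) - (-inf) = nan, poisoning the whole row.
        z = x - np.max(x, axis=-1, keepdims=True)
        e = np.exp(z)
        return e / np.sum(e, axis=-1, keepdims=True)

    row_inf = np.array([[-np.inf, -np.inf]], dtype=np.float32)
    print(softmax(row_inf))  # [[nan nan]]

    # Filling with the lowest finite float32 keeps the row well defined:
    # min - min = 0 and exp(0) = 1, giving a uniform distribution.
    row_min = np.full((1, 2), np.finfo(np.float32).min, dtype=np.float32)
    print(softmax(row_min))  # [[0.5 0.5]]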