diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py index a6cbb18493..7ea058a951 100644 --- a/onnxscript/function_libs/torch_lib/ops/nn.py +++ b/onnxscript/function_libs/torch_lib/ops/nn.py @@ -2204,7 +2204,9 @@ def _get_upsample_align_corners_mode(align_corners: bool) -> str: @torch_op( ( "aten::upsample_bicubic2d", + "aten::upsample_bicubic2d_aa", "aten::upsample_bilinear2d", + "aten::upsample_bilinear2d_aa", "aten::upsample_nearest1d", "aten::upsample_nearest2d", "aten::upsample_nearest3d", @@ -2216,6 +2218,7 @@ def _aten_upsample_output_size( output_size: INT64, mode: str, coordinate_transformation_mode: str, + antialias: int = 0, ) -> TReal: self_shape = op.Shape(self) starts = op.Constant(value_ints=[0]) @@ -2230,6 +2233,7 @@ def _aten_upsample_output_size( mode=mode, coordinate_transformation_mode=coordinate_transformation_mode, nearest_mode="floor", + antialias=antialias, ) @@ -2273,6 +2277,28 @@ def aten_upsample_bicubic2d( ) +@torch_op("aten::_upsample_bicubic2d_aa", trace_only=True) +def aten__upsample_bicubic2d_aa( + self: TReal, + output_size: INT64, + align_corners: bool, + scales_h: Optional[float] = None, + scales_w: Optional[float] = None, +) -> TReal: + """_upsample_bicubic2d_aa(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor""" + + # NOTE: Based on experimentation, scales_h and scales_w are always ignored in PyTorch, + # unless when align_corners is True, in which case we do not know what is going on. + coordinate_transformation_mode = _get_upsample_align_corners_mode(align_corners) + return _aten_upsample_output_size( + self, + output_size, + mode="cubic", + coordinate_transformation_mode=coordinate_transformation_mode, + antialias=1, + ) + + @torch_op("aten::upsample_bicubic2d.vec", trace_only=True) def aten_upsample_bicubic2d_vec( self: TReal, @@ -2335,6 +2361,28 @@ def aten_upsample_bilinear2d( ) +@torch_op("aten::_upsample_bilinear2d_aa", trace_only=True) +def aten__upsample_bilinear2d_aa( + self: TReal, + output_size: INT64, + align_corners: bool, + scales_h: Optional[float] = None, + scales_w: Optional[float] = None, +) -> TReal: + """upsample_bilinear2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor""" + + # NOTE: Based on experimentation, scales_h and scales_w are always ignored in PyTorch, + # unless when align_corners is True, in which case we do not know what is going on. + coordinate_transformation_mode = _get_upsample_align_corners_mode(align_corners) + return _aten_upsample_output_size( + self, + output_size, + coordinate_transformation_mode=coordinate_transformation_mode, + mode="linear", + antialias=1, + ) + + @torch_op("aten::upsample_bilinear2d.vec", trace_only=True) def aten_upsample_bilinear2d_vec( self: TReal, diff --git a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py index 086264e9bf..f1b42dfc03 100644 --- a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py +++ b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py @@ -2232,6 +2232,13 @@ def __init__(self): sample_inputs_func=sample_inputs_upsample_2d, supports_out=False, ), + opinfo_core.OpInfo( + "ops.aten._upsample_bicubic2d_aa", + aten_name="_upsample_bicubic2d_aa", + dtypes=common_dtype.floating_types_and(torch.bfloat16), + sample_inputs_func=sample_inputs_upsample_2d, + supports_out=False, + ), opinfo_core.OpInfo( "ops.aten.upsample_bicubic2d.vec", aten_name="upsample_bicubic2d.vec", @@ -2239,6 +2246,13 @@ def __init__(self): sample_inputs_func=sample_inputs_upsample_2d_vec, supports_out=False, ), + opinfo_core.OpInfo( + "ops.aten._upsample_bilinear2d_aa", + aten_name="_upsample_bilinear2d_aa", + dtypes=common_dtype.floating_types_and(torch.bfloat16), + sample_inputs_func=sample_inputs_upsample_2d, + supports_out=False, + ), opinfo_core.OpInfo( "ops.aten.upsample_bilinear2d.default", aten_name="upsample_bilinear2d", diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py index 3cfd4a1629..430c2c691e 100644 --- a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py +++ b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py @@ -2100,6 +2100,13 @@ def _where_input_wrangler( and sample.kwargs.get("scales_h") is not None, reason="fixme: align_corners=False output mismatch when scales are provided", ), + TorchLibOpInfo( + "ops.aten._upsample_bilinear2d_aa", + nn_ops.aten__upsample_bilinear2d_aa, + trace_only=True, + # ONNX use different antialias method than PyTorch, so the result is different + compare_shape_only_for_output=(0,), + ), TorchLibOpInfo( "ops.aten.upsample_bilinear2d.vec", nn_ops.aten_upsample_bilinear2d_vec, @@ -2119,6 +2126,13 @@ def _where_input_wrangler( nn_ops.aten_upsample_bicubic2d_vec, trace_only=True, ), + TorchLibOpInfo( + "ops.aten._upsample_bicubic2d_aa", + nn_ops.aten__upsample_bicubic2d_aa, + trace_only=True, + # ONNX use different antialias method than PyTorch, so the result is different + compare_shape_only_for_output=(0,), + ), TorchLibOpInfo( "ops.aten.upsample_linear1d", nn_ops.aten_upsample_linear1d,