diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py index bb767071e7..e9ee699aa3 100644 --- a/onnxscript/function_libs/torch_lib/ops/nn.py +++ b/onnxscript/function_libs/torch_lib/ops/nn.py @@ -2306,11 +2306,23 @@ def aten_upsample_bilinear2d( result = _aten_upsample_bilinear2d_output_size(self, output_size) else: assert scales_h is not None - assert scales_h == scales_w + assert scales_h == scales_w, f"scale_h({scales_h}) != scale_w({scales_w})" result = _aten_upsample_bilinear2d_scales(self, scales_h, scales_w) return result +@torch_op("aten::upsample_bilinear2d.vec", trace_only=True) +def aten_upsample_bilinear2d_vec( + self: TReal, + output_size: Optional[INT64] = None, + align_corners: bool = True, + scale_factors: Optional[Sequence[float]] = None, +) -> TReal: + scales_h = scale_factors[0] if scale_factors is not None else None + scales_w = scale_factors[1] if scale_factors is not None else None + return aten_upsample_bilinear2d(self, output_size, scales_h, scales_w, align_corners) + + @torch_op("aten::upsample_bilinear2d", private=True) def _aten_upsample_bilinear2d_output_size( self: TReal, diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py index 7aeba0d14b..cabc13268a 100644 --- a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py +++ b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py @@ -427,6 +427,18 @@ def _upsample_bilinear2d_input_wrangler( return args, kwargs +def _upsample_bilinear2d_vec_input_wrangler( + args: list[Any], kwargs: dict[str, Any] +) -> tuple[list[Any], dict[str, Any]]: + if "size" in kwargs: + args.append(np.array(kwargs["size"], dtype=np.int64)) + del kwargs["size"] # promote tensor type kwargs to args + if "scale_factor" in kwargs: + kwargs["scale_factors"] = [kwargs["scale_factor"]] * 2 + del kwargs["scale_factor"] # adapt the function signature + return args, kwargs + + def _upsample_input_wrangler( args: list[Any], kwargs: dict[str, Any] ) -> tuple[list[Any], dict[str, Any]]: @@ -2122,6 +2134,12 @@ def _where_input_wrangler( input_wrangler=_upsample_bilinear2d_input_wrangler, trace_only=True, ), + TorchLibOpInfo( + "nn.functional.upsample_bilinear2d", + nn_ops.aten_upsample_bilinear2d_vec, + input_wrangler=_upsample_bilinear2d_vec_input_wrangler, + trace_only=True, + ), TorchLibOpInfo( "ops.aten.upsample_bicubic2d", nn_ops.aten_upsample_bicubic2d,