From bab7982f4fc843b41e84d2cec9c08dd19208f18f Mon Sep 17 00:00:00 2001 From: deivanayakisankaralingam Date: Fri, 9 May 2025 09:36:40 +0000 Subject: [PATCH 1/5] add upsample bicubic op support into torch frontend --- .../torch/exported_program_translator.py | 28 ++++++++++++++ python/tvm/relax/op/image/image.py | 2 +- python/tvm/topi/image/resize.py | 6 +-- .../test_frontend_from_exported_program.py | 30 +++++++++++++++ tests/python/relax/test_frontend_from_fx.py | 37 +++++++++++++++++++ 5 files changed, 99 insertions(+), 4 deletions(-) diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index dbe37b886017..1a4cd3360bc6 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -208,6 +208,33 @@ def _upsample_nearest2d(self, node: fx.node) -> relax.Var: align_corners=align_corners, ) + def _upsample_bicubic2d(self, node: fx.node) -> relax.Var: + x = self.env[node.args[0]] + size = node.args[1] if len(node.args) > 1 else node.kwargs.get("size", None) + + if size is not None: + scale_factor = None + align_corners = ( + node.args[2] if len(node.args) > 2 else node.kwargs.get("align_corners", None) + ) + else: + align_corners = ( + node.args[2] if len(node.args) > 2 else node.kwargs.get("align_corners", None) + ) + scale_arg = node.args[3] if len(node.args) > 3 else node.kwargs.get("scale_factor", 1) + if isinstance(scale_arg, (list, tuple)): + scale_factor = scale_arg[0] + else: + scale_factor = scale_arg + + return self._upsample_impl( + x, + size=size, + scale_factor=scale_factor, + method="cubic", + align_corners=align_corners, + ) + ########## Manipulation ########## def _narrow(self, node: fx.Node) -> relax.Var: @@ -426,6 +453,7 @@ def create_convert_map( "unbind.int": self._unbind, "upsample_bilinear2d.vec": self._upsample_bilinear2d, "upsample_nearest2d.vec": self._upsample_nearest2d, + "upsample_bicubic2d.vec": self._upsample_bicubic2d, # statistical "mean.dim": self._mean, "prod.default": self._prod, diff --git a/python/tvm/relax/op/image/image.py b/python/tvm/relax/op/image/image.py index e314e9b49af5..6bec22161dbc 100644 --- a/python/tvm/relax/op/image/image.py +++ b/python/tvm/relax/op/image/image.py @@ -35,7 +35,7 @@ def resize2d( method: str = "linear", coordinate_transformation_mode: str = "half_pixel", rounding_method: str = "round", - cubic_alpha: float = -0.5, + cubic_alpha: float = -0.75, cubic_exclude: int = 0, extrapolation_value: float = 0.0, out_dtype: Optional[Union[str, DataType]] = None, diff --git a/python/tvm/topi/image/resize.py b/python/tvm/topi/image/resize.py index 5cbc292adbd7..ad2c99fa3ac1 100644 --- a/python/tvm/topi/image/resize.py +++ b/python/tvm/topi/image/resize.py @@ -376,7 +376,7 @@ def resize1d( method="linear", coordinate_transformation_mode="half_pixel", rounding_method="", - bicubic_alpha=-0.5, + bicubic_alpha=-0.75, bicubic_exclude=0, extrapolation_value=0.0, out_dtype=None, @@ -748,7 +748,7 @@ def resize2d( method="linear", coordinate_transformation_mode="half_pixel", rounding_method="", - bicubic_alpha=-0.5, + bicubic_alpha=-0.75, bicubic_exclude=0, extrapolation_value=0.0, out_dtype=None, @@ -1217,7 +1217,7 @@ def resize3d( method="linear", coordinate_transformation_mode="half_pixel", rounding_method="", - bicubic_alpha=-0.5, + bicubic_alpha=-0.75, bicubic_exclude=0, extrapolation_value=0.0, out_dtype=None, diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index ef198d2f83f3..1ebb2ffdadd3 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -2973,9 +2973,39 @@ def main( R.output(gv) return gv + class InterpolateBicubic(Module): + def forward(self, input): + return torch.nn.functional.interpolate(input, (224, 224), mode="bicubic") + + @tvm.script.ir_module + class expected_bicubic: + @R.function + def main( + input: R.Tensor((1, 3, 112, 112), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 224, 224), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 224, 224), dtype="float32") = R.image.resize2d( + input, + R.shape([224, 224]), + roi=[T.float32(0.0), T.float32(0.0), T.float32(0.0), T.float32(0.0)], + layout="NCHW", + method="cubic", + coordinate_transformation_mode="half_pixel", + rounding_method="round", + cubic_alpha=-0.75, + cubic_exclude=0, + extrapolation_value=0.0, + out_dtype="void", + ) + gv: R.Tuple(R.Tensor((1, 3, 224, 224), dtype="float32")) = (lv,) + R.output(gv) + return gv + example_args = (torch.randn(1, 3, 112, 112, dtype=torch.float32),) verify_model(InterpolateBilinear(), example_args, {}, expected_bilinear) verify_model(InterpolateNearest(), example_args, {}, expected_nearest) + verify_model(InterpolateBicubic(), example_args, {}, expected_bicubic) def test_mean(): diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index 681474244ae8..b3bbd581df6b 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -3335,6 +3335,43 @@ def main( verify_model(Interpolate3(), input_info, {}, expected3) + class Interpolate4(Module): + def forward(self, input): + return torch.nn.functional.interpolate( + input, + size=None, + scale_factor=(2.0, 1.0), + mode="bicubic", + align_corners=False, + ) + + @tvm.script.ir_module + class expected4: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 20, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 20, 10), dtype="float32") = R.image.resize2d( + input_1, + (20, 10), + roi=[0.000000, 0.000000, 0.000000, 0.000000], + layout="NCHW", + method="cubic", + coordinate_transformation_mode="half_pixel", + rounding_method="round", + cubic_alpha=-0.75, + cubic_exclude=0, + extrapolation_value=0, + out_dtype="", + ) + gv: R.Tensor((1, 3, 20, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Interpolate4(), input_info, {}, expected4) + def test_addmm(): input_info = [ From 87a84d61287e5bdc58653b6533c38859a4a81d64 Mon Sep 17 00:00:00 2001 From: deivanayakisankaralingam Date: Fri, 9 May 2025 11:15:21 +0000 Subject: [PATCH 2/5] fix cubic alpha value for all interpolate func --- tests/python/relax/test_frontend_from_exported_program.py | 4 ++-- tests/python/relax/test_frontend_from_fx.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 1ebb2ffdadd3..6ef1eb8fa840 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -2935,7 +2935,7 @@ def main( method="linear", coordinate_transformation_mode="half_pixel", rounding_method="round", - cubic_alpha=-0.5, + cubic_alpha=-0.75, cubic_exclude=0, extrapolation_value=0.0, out_dtype="void", @@ -2964,7 +2964,7 @@ def main( method="nearest_neighbor", coordinate_transformation_mode="half_pixel", rounding_method="round", - cubic_alpha=-0.5, + cubic_alpha=-0.75, cubic_exclude=0, extrapolation_value=0.0, out_dtype="void", diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index b3bbd581df6b..0c9dca5c99a3 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -3250,7 +3250,7 @@ def main( method="nearest_neighbor", coordinate_transformation_mode="asymmetric", rounding_method="round", - cubic_alpha=-0.5, + cubic_alpha=-0.75, cubic_exclude=0, extrapolation_value=0, out_dtype="", @@ -3287,7 +3287,7 @@ def main( method="linear", coordinate_transformation_mode="half_pixel", rounding_method="round", - cubic_alpha=-0.5, + cubic_alpha=-0.75, cubic_exclude=0, extrapolation_value=0, out_dtype="", @@ -3324,7 +3324,7 @@ def main( method="linear", coordinate_transformation_mode="half_pixel", rounding_method="round", - cubic_alpha=-0.5, + cubic_alpha=-0.75, cubic_exclude=0, extrapolation_value=0, out_dtype="", From 7c9a1e168018fc51ea111a1125d9860022e5612e Mon Sep 17 00:00:00 2001 From: deivanayakisankaralingam Date: Fri, 9 May 2025 11:59:09 +0000 Subject: [PATCH 3/5] fix cubic alpha values in all test script --- tests/python/relax/test_frontend_nn_op.py | 2 +- tests/python/relax/test_transform_convert_layout.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/python/relax/test_frontend_nn_op.py b/tests/python/relax/test_frontend_nn_op.py index 483e48217d92..1af13f048770 100644 --- a/tests/python/relax/test_frontend_nn_op.py +++ b/tests/python/relax/test_frontend_nn_op.py @@ -305,7 +305,7 @@ def test( method="nearest_neighbor", coordinate_transformation_mode="asymmetric", rounding_method="round", - cubic_alpha=-0.5, + cubic_alpha=-0.75, cubic_exclude=0, extrapolation_value=0, out_dtype="void", diff --git a/tests/python/relax/test_transform_convert_layout.py b/tests/python/relax/test_transform_convert_layout.py index db4130f947d1..262e37b91b1b 100644 --- a/tests/python/relax/test_transform_convert_layout.py +++ b/tests/python/relax/test_transform_convert_layout.py @@ -1434,7 +1434,7 @@ def main( method="linear", coordinate_transformation_mode="half_pixel", rounding_method="round", - cubic_alpha=-0.5, + cubic_alpha=-0.75, cubic_exclude=0, extrapolation_value=0, out_dtype="void", @@ -1477,7 +1477,7 @@ def main( method="linear", coordinate_transformation_mode="half_pixel", rounding_method="round", - cubic_alpha=-0.5, + cubic_alpha=-0.75, cubic_exclude=0, extrapolation_value=0, out_dtype="void", From 1db0bdb2d47352d83c4946fcd273e1b7daf8f2c4 Mon Sep 17 00:00:00 2001 From: deivanayakisankaralingam Date: Fri, 9 May 2025 13:06:46 +0000 Subject: [PATCH 4/5] update the mapping code in frontend --- .../frontend/torch/exported_program_translator.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 1a4cd3360bc6..65bbc0c05ca0 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -211,16 +211,12 @@ def _upsample_nearest2d(self, node: fx.node) -> relax.Var: def _upsample_bicubic2d(self, node: fx.node) -> relax.Var: x = self.env[node.args[0]] size = node.args[1] if len(node.args) > 1 else node.kwargs.get("size", None) - + align_corners = ( + node.args[2] if len(node.args) > 2 else node.kwargs.get("align_corners", None) + ) if size is not None: scale_factor = None - align_corners = ( - node.args[2] if len(node.args) > 2 else node.kwargs.get("align_corners", None) - ) else: - align_corners = ( - node.args[2] if len(node.args) > 2 else node.kwargs.get("align_corners", None) - ) scale_arg = node.args[3] if len(node.args) > 3 else node.kwargs.get("scale_factor", 1) if isinstance(scale_arg, (list, tuple)): scale_factor = scale_arg[0] From b3d8b0dc70c3a9d81f55c240de4da647b28fc5ee Mon Sep 17 00:00:00 2001 From: deivanayakisankaralingam Date: Fri, 9 May 2025 13:37:57 +0000 Subject: [PATCH 5/5] fix lint issue --- python/tvm/relax/frontend/torch/exported_program_translator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 65bbc0c05ca0..49d642d59646 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -212,7 +212,7 @@ def _upsample_bicubic2d(self, node: fx.node) -> relax.Var: x = self.env[node.args[0]] size = node.args[1] if len(node.args) > 1 else node.kwargs.get("size", None) align_corners = ( - node.args[2] if len(node.args) > 2 else node.kwargs.get("align_corners", None) + node.args[2] if len(node.args) > 2 else node.kwargs.get("align_corners", None) ) if size is not None: scale_factor = None