diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 890f925079e0..d99411bd5658 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -19,6 +19,7 @@
 # pylint: disable=import-outside-toplevel
 """Base class for PyTorch FX Graph importer."""
 import abc
+from functools import reduce
 import math
 from typing import Callable, Dict, Optional, Tuple, Union
 
@@ -1018,6 +1019,24 @@ def _expand_as(self, node: fx.Node) -> relax.Var:
         other_shape = self.shape_of(args[1])  # the shape of 'other'
         return self.block_builder.emit(relax.op.broadcast_to(data, other_shape))
 
+    def _flatten_impl(self, x, start_dim, end_dim) -> relax.Var:
+        shape = self.shape_of(x)
+        start_dim = start_dim if start_dim >= 0 else len(shape) + start_dim
+        end_dim = end_dim if end_dim >= 0 else len(shape) + end_dim
+        flattened = reduce(lambda x, y: x * y, [shape[i] for i in range(start_dim, end_dim + 1)])
+        new_shape = (
+            [shape[i] for i in range(0, start_dim)]
+            + [flattened]
+            + [shape[i] for i in range(end_dim + 1, len(shape))]
+        )
+        return self.block_builder.emit(relax.op.reshape(x, new_shape))
+
+    def _flatten(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        start_dim = node.args[1] if len(node.args) >= 2 else node.kwargs.get("start_dim", 0)
+        end_dim = node.args[2] if len(node.args) == 3 else node.kwargs.get("end_dim", -1)
+        return self._flatten_impl(x, start_dim, end_dim)
+
     def _flip(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         dims = node.args[1] if len(node.args) > 1 else node.kwargs.get("dims", None)
@@ -1233,6 +1252,21 @@ def _new_ones(self, node: fx.Node) -> relax.Var:
             )
         )
 
+    ########## DataType ##########
+
+    def _to(self, node: fx.Node) -> relax.Var:
+        import torch
+
+        x = self.env[node.args[0]]
+        if len(node.args) == 2:
+            if isinstance(node.args[1], torch.dtype):
+                dtype = BaseFXGraphImporter._convert_data_type(node.args[1], self.env)
+                return self.block_builder.emit(relax.op.astype(x, dtype))
+        elif "dtype" in node.kwargs:
+            dtype = BaseFXGraphImporter._convert_data_type(node.kwargs["dtype"], self.env)
+            return self.block_builder.emit(relax.op.astype(x, dtype))
+        return x
+
     ########## Others ##########
 
     def _getitem(self, node: fx.Node) -> relax.Var:
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 2e7c682aa34b..26121ecdea10 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -377,6 +377,7 @@ def create_convert_map(
             "cumprod.default": self._cumprod,
             "expand.default": self._expand,
             "expand_as.default": self._expand_as,
+            "flatten.using_ints": self._flatten,
             "flip.default": self._flip,
             "gather.default": self._gather,
             "permute.default": self._permute,
@@ -411,6 +412,9 @@
             "lift_fresh_copy.default": self._to_copy,
             "new_ones.default": self._new_ones,
             "one_hot.default": self._one_hot,
+            # datatype
+            "to.dtype": self._to,
+            "to.dtype_layout": self._to,
             # other
             "getitem": self._getitem,
         }
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 3ddf919c2ed1..e79c1dbc48fa 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -415,24 +415,6 @@ def _chunk(self, node: fx.Node) -> relax.Var:
         dim = node.args[2] if len(node.args) > 2 else node.kwargs.get("dim", 0)
         return self.block_builder.emit(relax.op.split(x, chunks, dim))
 
-    def _flatten_impl(self, x, start_dim, end_dim) -> relax.Var:
-        shape = self.shape_of(x)
-        start_dim = start_dim if start_dim >= 0 else len(shape) + start_dim
-        end_dim = end_dim if end_dim >= 0 else len(shape) + end_dim
-        flattened = reduce(lambda x, y: x * y, [shape[i] for i in range(start_dim, end_dim + 1)])
-        new_shape = (
-            [shape[i] for i in range(0, start_dim)]
-            + [flattened]
-            + [shape[i] for i in range(end_dim + 1, len(shape))]
-        )
-        return self.block_builder.emit(relax.op.reshape(x, new_shape))
-
-    def _flatten(self, node: fx.Node) -> relax.Var:
-        x = self.env[node.args[0]]
-        start_dim = node.args[1] if len(node.args) >= 2 else node.kwargs.get("start_dim", 0)
-        end_dim = node.args[2] if len(node.args) == 3 else node.kwargs.get("end_dim", -1)
-        return self._flatten_impl(x, start_dim, end_dim)
-
     def _flatten_module(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         module = self.named_modules[node.target]
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 2175f9aa391c..cc2f669d32e0 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -1021,10 +1021,6 @@ def main(
     verify_model(Min1(), example_args1, {}, expected_min1)
 
 
-@pytest.mark.skipif(
-    version.parse(torch_version) >= version.parse("2.6.0"),
-    reason="Tests not compatible with PyTorch >= 2.6",
-)
 def test_batchnorm2d():
     class BatchNorm2d(Module):
         def __init__(self):
@@ -2702,10 +2698,6 @@ def main(
     verify_model(Expand2(), example_args, {}, expected1)
 
 
-@pytest.mark.skipif(
-    version.parse(torch_version) >= version.parse("2.6.0"),
-    reason="Tests not compatible with PyTorch >= 2.6",
-)
 def test_flatten():
     class Flatten(Module):
         def __init__(self):
@@ -2907,10 +2899,6 @@ def main(
     verify_model(Slice2(), example_args, {}, expected2)
 
 
-@pytest.mark.skipif(
-    version.parse(torch_version) >= version.parse("2.6.0"),
-    reason="Tests not compatible with PyTorch >= 2.6",
-)
 def test_split():
     class Chunk(Module):
         def forward(self, input):
@@ -3340,10 +3328,6 @@ def main(
     verify_model(NewOnes(), example_args, {}, expected1)
 
 
-@pytest.mark.skipif(
-    version.parse(torch_version) >= version.parse("2.6.0"),
-    reason="Tests not compatible with PyTorch >= 2.6",
-)
 def test_to_copy():
     # float
     class ToFloat(Module):
@@ -3394,7 +3378,8 @@ def main(
         ) -> R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float32")):
             # block 0
             with R.dataflow():
-                gv: R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float32")) = (x,)
+                lv: R.Tensor((1, 2, 3, 4), dtype="float32") = R.astype(x, dtype="float32")
+                gv: R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float32")) = (lv,)
                 R.output(gv)
             return gv
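A minimal usage sketch of what this patch enables, kept separate from the diff itself: it runs a module that uses torch.flatten and Tensor.to(dtype) through the ExportedProgram frontend, exercising the new "flatten.using_ints" and "to.dtype" mappings. The FlattenCast module, shapes, and dtypes are illustrative assumptions, not code from the patch.

# Illustrative only; assumes a PyTorch build with torch.export available.
import torch
from torch import nn
from torch.export import export

from tvm.relax.frontend.torch import from_exported_program


class FlattenCast(nn.Module):
    def forward(self, x):
        # flatten(start_dim, end_dim) is imported via _flatten -> relax.op.reshape
        y = torch.flatten(x, 1, -1)
        # Tensor.to(dtype) is imported via _to ("to.dtype") -> relax.op.astype
        return y.to(torch.float16)


example_args = (torch.randn(2, 3, 4, dtype=torch.float32),)
exported_program = export(FlattenCast(), args=example_args)
mod = from_exported_program(exported_program)
mod.show()  # prints the imported Relax IRModule containing R.reshape and R.astype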