diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 2652b167e5c0..ae4c918900ec 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -23,7 +23,7 @@ import math
 from typing import Callable, Dict, Optional, Tuple, Union, List
 
 
-from tvm import relax
+from tvm import relax, tir
 
 
 class BaseFXGraphImporter(metaclass=abc.ABCMeta):
@@ -1164,6 +1164,85 @@ def _repeat(self, node: fx.Node) -> relax.Var:
         dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:]
         return self.block_builder.emit(relax.op.tile(x, dims))
 
+    def _roll(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        input_tensor = args[0]
+        shifts = args[1] if len(node.args) > 1 else node.kwargs.get("shifts", None)
+        dims = args[2] if len(node.args) > 2 else node.kwargs.get("dims", None)
+
+        # Get original shape
+        original_shape = self.shape_of(input_tensor)
+
+        def to_int(val):
+            if isinstance(val, tir.IntImm):
+                return int(val.value)
+            elif isinstance(val, int):
+                return val
+            elif hasattr(val, "__int__"):
+                return int(val)
+            raise TypeError(f"Unsupported type for shift/dim: {type(val)}")
+
+        def roll_single_dim(tensor: relax.Var, shift: int, dim: int) -> relax.Var:
+            shape = self.shape_of(tensor)
+
+            dim_size = shape.values[dim]
+            shift_val = to_int(shift)
+            dim_size_val = to_int(dim_size)
+            shift_mod = shift_val % dim_size_val
+            if shift_mod == 0:
+                return tensor
+
+            split_pos = dim_size_val - shift_mod
+            part1 = self.block_builder.emit(
+                relax.op.strided_slice(
+                    tensor,
+                    axes=[dim],
+                    begin=[0],
+                    end=[split_pos],
+                    strides=[1],
+                )
+            )
+            part2 = self.block_builder.emit(
+                relax.op.strided_slice(
+                    tensor,
+                    axes=[dim],
+                    begin=[split_pos],
+                    end=[dim_size_val],
+                    strides=[1],
+                )
+            )
+            return self.block_builder.emit(relax.op.concat([part2, part1], axis=dim))
+
+        # Handle dims=None (flatten -> roll -> reshape)
+        if dims is None:
+            flattened = self.block_builder.emit(relax.op.reshape(input_tensor, (-1,)))
+            shift_scalar = to_int(shifts[0] if isinstance(shifts, (list, tuple)) else shifts)
+            rolled = roll_single_dim(flattened, shift_scalar, 0)
+            return self.block_builder.emit(relax.op.reshape(rolled, original_shape))
+
+        # Normalize shifts and dims
+        if isinstance(shifts, (list, tuple)):
+            shifts = [to_int(s) for s in shifts]
+        else:
+            shifts = [to_int(shifts)]
+
+        if isinstance(dims, (list, tuple)):
+            dims = [to_int(d) for d in dims]
+        else:
+            dims = [to_int(dims)]
+
+        if len(shifts) != len(dims):
+            raise ValueError("shifts and dims must have the same length")
+
+        result = input_tensor
+        rank = len(original_shape.values)
+        for shift, dim in zip(shifts, dims):
+            if dim < 0:
+                dim += rank
+            result = roll_single_dim(result, shift, dim)
+
+        return result
+
     def _reshape(self, node: fx.Node) -> relax.Var:
         import torch  # type: ignore
 
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 5d4f3437b257..932607287571 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -423,6 +423,7 @@ def create_convert_map(
             "narrow.default": self._narrow,
             "permute.default": self._permute,
             "repeat.default": self._repeat,
+            "roll.default": self._roll,
             "select.int": self._select,
             "slice.Tensor": self._slice,
             "split.Tensor": self._split,
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index e6b1fdd223ea..5a34befb9296 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -750,6 +750,7 @@ def create_convert_map(
             "numel": self._numel,
             "permute": self._permute,
             "repeat": self._repeat,
+            "roll": self._roll,
             "reshape": self._reshape,
             "scatter": self._scatter,
             "select": self._select,
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 9259936dc223..80c0bd5fb4f5 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -2968,6 +2968,131 @@ def main(
     verify_model(ReshapeAs(), example_args, {}, expected1)
 
 
+def test_roll():
+    class Roll1(Module):
+        def forward(self, x):
+            return torch.roll(x, 1)
+
+    class Roll2(Module):
+        def forward(self, x):
+            return torch.roll(x, -1, 0)
+
+    class Roll3(Module):
+        def forward(self, x):
+            return torch.roll(x, shifts=(2, 1), dims=(0, 1))
+
+    # Test case 1: torch.roll(x, 1)
+    @I.ir_module
+    class Expected1:
+        @R.function
+        def main(x: R.Tensor((4, 2), dtype="int64")) -> R.Tuple(R.Tensor((4, 2), dtype="int64")):
+            with R.dataflow():
+                lv: R.Tensor((8,), dtype="int64") = R.reshape(x, R.shape([8]))
+                lv1: R.Tensor((7,), dtype="int64") = R.strided_slice(
+                    lv,
+                    axes=[0],
+                    begin=[R.prim_value(0)],
+                    end=[R.prim_value(7)],
+                    strides=[R.prim_value(1)],
+                    assume_inbound=False,
+                )
+                lv2: R.Tensor((1,), dtype="int64") = R.strided_slice(
+                    lv,
+                    axes=[0],
+                    begin=[R.prim_value(7)],
+                    end=[R.prim_value(8)],
+                    strides=[R.prim_value(1)],
+                    assume_inbound=False,
+                )
+                lv3: R.Tensor((8,), dtype="int64") = R.concat((lv2, lv1), axis=0)
+                lv4: R.Tensor((4, 2), dtype="int64") = R.reshape(lv3, R.shape([4, 2]))
+                gv: R.Tuple(R.Tensor((4, 2), dtype="int64")) = (lv4,)
+                R.output(gv)
+            return gv
+
+    # Test case 2: torch.roll(x, -1, 0)
+    @I.ir_module
+    class Expected2:
+        @R.function
+        def main(x: R.Tensor((4, 2), dtype="int64")) -> R.Tuple(R.Tensor((4, 2), dtype="int64")):
+            with R.dataflow():
+                lv: R.Tensor((1, 2), dtype="int64") = R.strided_slice(
+                    x,
+                    axes=[0],
+                    begin=[R.prim_value(0)],
+                    end=[R.prim_value(1)],
+                    strides=[R.prim_value(1)],
+                    assume_inbound=False,
+                )
+                lv1: R.Tensor((3, 2), dtype="int64") = R.strided_slice(
+                    x,
+                    axes=[0],
+                    begin=[R.prim_value(1)],
+                    end=[R.prim_value(4)],
+                    strides=[R.prim_value(1)],
+                    assume_inbound=False,
+                )
+                lv2: R.Tensor((4, 2), dtype="int64") = R.concat((lv1, lv), axis=0)
+                gv: R.Tuple(R.Tensor((4, 2), dtype="int64")) = (lv2,)
+                R.output(gv)
+            return gv
+
+    # Test case 3: torch.roll(x, shifts=(2,1), dims=(0,1))
+    @I.ir_module
+    class Expected3:
+        @R.function
+        def main(x: R.Tensor((4, 2), dtype="int64")) -> R.Tuple(R.Tensor((4, 2), dtype="int64")):
+            with R.dataflow():
+                # First roll along dim=0 with shift=2
+                lv: R.Tensor((2, 2), dtype="int64") = R.strided_slice(
+                    x,
+                    axes=[0],
+                    begin=[R.prim_value(0)],
+                    end=[R.prim_value(2)],
+                    strides=[R.prim_value(1)],
+                    assume_inbound=False,
+                )
+                lv1: R.Tensor((2, 2), dtype="int64") = R.strided_slice(
+                    x,
+                    axes=[0],
+                    begin=[R.prim_value(2)],
+                    end=[R.prim_value(4)],
+                    strides=[R.prim_value(1)],
+                    assume_inbound=False,
+                )
+                lv2: R.Tensor((4, 2), dtype="int64") = R.concat((lv1, lv), axis=0)
+
+                # Second roll along dim=1 with shift=1
+                lv3: R.Tensor((4, 1), dtype="int64") = R.strided_slice(
+                    lv2,
+                    axes=[1],
+                    begin=[R.prim_value(0)],
+                    end=[R.prim_value(1)],
+                    strides=[R.prim_value(1)],
+                    assume_inbound=False,
+                )
+                lv4: R.Tensor((4, 1), dtype="int64") = R.strided_slice(
+                    lv2,
+                    axes=[1],
+                    begin=[R.prim_value(1)],
+                    end=[R.prim_value(2)],
+                    strides=[R.prim_value(1)],
+                    assume_inbound=False,
+                )
+                lv5: R.Tensor((4, 2), dtype="int64") = R.concat((lv4, lv3), axis=1)
+                gv: R.Tuple(R.Tensor((4, 2), dtype="int64")) = (lv5,)
+                R.output(gv)
+            return gv
+
+    # Test inputs
+    example_input = torch.randint(0, 10, (4, 2), dtype=torch.int64)
+
+    # Run verification for each case
+    verify_model(Roll1(), (example_input,), {}, Expected1)
+    verify_model(Roll2(), (example_input,), {}, Expected2)
+    verify_model(Roll3(), (example_input,), {}, Expected3)
+
+
 def test_select_slice():
     class Slice1(Module):
         def forward(self, x):
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index 53c925e14ee6..c52255638072 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -3560,6 +3560,126 @@ def main(x: R.Tensor((1, 3), dtype="float32")) -> R.Tensor((4, 6), dtype="float3
     verify_model(Tile2(), [(torch.Size([1, 3]), "float32")], {}, expected2)
 
 
+def test_roll():
+    class Roll1(Module):
+        def forward(self, x):
+            return torch.roll(x, 1)
+
+    class Roll2(Module):
+        def forward(self, x):
+            return torch.roll(x, -1, 0)
+
+    class Roll3(Module):
+        def forward(self, x):
+            return torch.roll(x, shifts=(2, 1), dims=(0, 1))
+
+    # Test case 1: torch.roll(x, 1)
+    @I.ir_module
+    class Expected1:
+        @R.function
+        def main(inp_0: R.Tensor((4, 2), dtype="int64")) -> R.Tensor((4, 2), dtype="int64"):
+            with R.dataflow():
+                lv: R.Tensor((8,), dtype="int64") = R.reshape(inp_0, R.shape([8]))
+                lv1: R.Tensor((7,), dtype="int64") = R.strided_slice(
+                    lv,
+                    axes=[0],
+                    begin=[R.prim_value(0)],
+                    end=[R.prim_value(7)],
+                    strides=[R.prim_value(1)],
+                    assume_inbound=False,
+                )
+                lv2: R.Tensor((1,), dtype="int64") = R.strided_slice(
+                    lv,
+                    axes=[0],
+                    begin=[R.prim_value(7)],
+                    end=[R.prim_value(8)],
+                    strides=[R.prim_value(1)],
+                    assume_inbound=False,
+                )
+                lv3: R.Tensor((8,), dtype="int64") = R.concat((lv2, lv1), axis=0)
+                lv4: R.Tensor((4, 2), dtype="int64") = R.reshape(lv3, R.shape([4, 2]))
+                gv: R.Tensor((4, 2), dtype="int64") = lv4
+                R.output(gv)
+            return gv
+
+    # Test case 2: torch.roll(x, -1, 0)
+    @I.ir_module
+    class Expected2:
+        @R.function
+        def main(inp_0: R.Tensor((4, 2), dtype="int64")) -> R.Tensor((4, 2), dtype="int64"):
+            with R.dataflow():
+                lv: R.Tensor((1, 2), dtype="int64") = R.strided_slice(
+                    inp_0,
+                    axes=[0],
+                    begin=[R.prim_value(0)],
+                    end=[R.prim_value(1)],
+                    strides=[R.prim_value(1)],
+                    assume_inbound=False,
+                )
+                lv1: R.Tensor((3, 2), dtype="int64") = R.strided_slice(
+                    inp_0,
+                    axes=[0],
+                    begin=[R.prim_value(1)],
+                    end=[R.prim_value(4)],
+                    strides=[R.prim_value(1)],
+                    assume_inbound=False,
+                )
+                lv2: R.Tensor((4, 2), dtype="int64") = R.concat((lv1, lv), axis=0)
+                gv: R.Tensor((4, 2), dtype="int64") = lv2
+                R.output(gv)
+            return gv
+
+    # Test case 3: torch.roll(x, shifts=(2, 1), dims=(0, 1))
+    @I.ir_module
+    class Expected3:
+        @R.function
+        def main(inp_0: R.Tensor((4, 2), dtype="int64")) -> R.Tensor((4, 2), dtype="int64"):
+            with R.dataflow():
+                lv: R.Tensor((2, 2), dtype="int64") = R.strided_slice(
+                    inp_0,
+                    axes=[0],
+                    begin=[R.prim_value(0)],
+                    end=[R.prim_value(2)],
+                    strides=[R.prim_value(1)],
+                    assume_inbound=False,
+                )
+                lv1: R.Tensor((2, 2), dtype="int64") = R.strided_slice(
+                    inp_0,
+                    axes=[0],
+                    begin=[R.prim_value(2)],
+                    end=[R.prim_value(4)],
+                    strides=[R.prim_value(1)],
+                    assume_inbound=False,
+                )
+                lv2: R.Tensor((4, 2), dtype="int64") = R.concat((lv1, lv), axis=0)
+                lv3: R.Tensor((4, 1), dtype="int64") = R.strided_slice(
+                    lv2,
+                    axes=[1],
+                    begin=[R.prim_value(0)],
+                    end=[R.prim_value(1)],
+                    strides=[R.prim_value(1)],
+                    assume_inbound=False,
+                )
+                lv4: R.Tensor((4, 1), dtype="int64") = R.strided_slice(
+                    lv2,
+                    axes=[1],
+                    begin=[R.prim_value(1)],
+                    end=[R.prim_value(2)],
+                    strides=[R.prim_value(1)],
+                    assume_inbound=False,
+                )
+                lv5: R.Tensor((4, 2), dtype="int64") = R.concat((lv4, lv3), axis=1)
+                gv: R.Tensor((4, 2), dtype="int64") = lv5
+                R.output(gv)
+            return gv
+
+    input_info = [([4, 2], "int64")]
+
+    verify_model(Roll1(), input_info, {}, Expected1)
+    verify_model(Roll2(), input_info, {}, Expected2)
+    verify_model(Roll3(), input_info, {}, Expected3)
+
+
 def test_view():
     input_info = [([1, 2, 3, 4], "float32")]
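
For reference, the decomposition that `_roll` emits can be modeled in a few
lines of NumPy: a roll by `shift` along `dim` lowers to two strided slices
concatenated in swapped order, with Python's `%` operator normalizing
negative shifts into [0, dim_size). The sketch below is not part of the
patch, and the helper name `roll_via_slices` is illustrative only.

    import numpy as np

    def roll_via_slices(x: np.ndarray, shift: int, dim: int) -> np.ndarray:
        size = x.shape[dim]
        shift_mod = shift % size  # negative shifts wrap into [0, size)
        if shift_mod == 0:
            return x  # no-op roll, mirroring the early return in _roll
        split = size - shift_mod
        head = [slice(None)] * x.ndim  # part1: elements sliding toward the end
        tail = [slice(None)] * x.ndim  # part2: elements wrapping to the front
        head[dim] = slice(0, split)
        tail[dim] = slice(split, size)
        return np.concatenate([x[tuple(tail)], x[tuple(head)]], axis=dim)

    # Agrees with torch.roll/np.roll on the shapes used in the tests above.
    x = np.arange(8).reshape(4, 2)
    assert np.array_equal(roll_via_slices(x, 1, 0), np.roll(x, 1, axis=0))
    assert np.array_equal(roll_via_slices(x, -1, 0), np.roll(x, -1, axis=0))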