81 changes: 80 additions & 1 deletion python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -23,7 +23,7 @@
import math
from typing import Callable, Dict, Optional, Tuple, Union, List

-from tvm import relax
+from tvm import relax, tir


class BaseFXGraphImporter(metaclass=abc.ABCMeta):
@@ -1164,6 +1164,85 @@ def _repeat(self, node: fx.Node) -> relax.Var:
dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:]
return self.block_builder.emit(relax.op.tile(x, dims))

def _roll(self, node: fx.Node) -> relax.Var:
"""Lower torch.roll: split each rolled dim at the shift point and concatenate the
two pieces in swapped order; when dims is None, flatten, roll, and restore the shape."""
args = self.retrieve_args(node)
input_tensor = args[0]
shifts = args[1] if len(node.args) > 1 else node.kwargs.get("shifts", None)
dims = args[2] if len(node.args) > 2 else node.kwargs.get("dims", None)

# Get original shape
original_shape = self.shape_of(input_tensor)

def to_int(val):
if isinstance(val, tir.IntImm):
return int(val.value)
elif isinstance(val, int):
return val
elif hasattr(val, "__int__"):
return int(val)
raise TypeError(f"Unsupported type for shift/dim: {type(val)}")

def roll_single_dim(tensor: relax.Var, shift: int, dim: int) -> relax.Var:
shape = self.shape_of(tensor)

dim_size = shape.values[dim]
shift_val = to_int(shift)
dim_size_val = to_int(dim_size)
shift_mod = shift_val % dim_size_val
if shift_mod == 0:
return tensor

split_pos = dim_size_val - shift_mod
part1 = self.block_builder.emit(
relax.op.strided_slice(
tensor,
axes=[dim],
begin=[0],
end=[split_pos],
strides=[1],
)
)
part2 = self.block_builder.emit(
relax.op.strided_slice(
tensor,
axes=[dim],
begin=[split_pos],
end=[dim_size_val],
strides=[1],
)
)
return self.block_builder.emit(relax.op.concat([part2, part1], axis=dim))

# Handle dims=None (flatten -> roll -> reshape)
if dims is None:
flattened = self.block_builder.emit(relax.op.reshape(input_tensor, (-1,)))
shift_scalar = to_int(shifts[0] if isinstance(shifts, (list, tuple)) else shifts)
rolled = roll_single_dim(flattened, shift_scalar, 0)
return self.block_builder.emit(relax.op.reshape(rolled, original_shape))

# Normalize shifts and dims
if isinstance(shifts, (list, tuple)):
shifts = [to_int(s) for s in shifts]
else:
shifts = [to_int(shifts)]

if isinstance(dims, (list, tuple)):
dims = [to_int(d) for d in dims]
else:
dims = [to_int(dims)]

if len(shifts) != len(dims):
raise ValueError("shifts and dims must have the same length")

result = input_tensor
rank = len(original_shape.values)
for shift, dim in zip(shifts, dims):
if dim < 0:
dim += rank
result = roll_single_dim(result, shift, dim)

return result
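
The split-and-concatenate identity that roll_single_dim relies on can be checked outside TVM: rolling by shift along a dim of size d is concat(x[d - shift % d :], x[: d - shift % d]). A minimal NumPy sketch of that identity (not part of the patch; roll_by_split is a name introduced here purely for illustration):

import numpy as np

def roll_by_split(x, shift, dim):
    # Split at d - shift % d and swap the two pieces, as roll_single_dim does.
    d = x.shape[dim]
    split = d - shift % d
    if split == d:  # shift is a multiple of the dim size: no-op
        return x
    part1 = np.take(x, range(0, split), axis=dim)
    part2 = np.take(x, range(split, d), axis=dim)
    return np.concatenate([part2, part1], axis=dim)

x = np.arange(8).reshape(4, 2)
assert (roll_by_split(x, 1, 0) == np.roll(x, 1, axis=0)).all()
assert (roll_by_split(x, -1, 0) == np.roll(x, -1, axis=0)).all()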

def _reshape(self, node: fx.Node) -> relax.Var:
import torch # type: ignore

1 change: 1 addition & 0 deletions python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -423,6 +423,7 @@ def create_convert_map(
"narrow.default": self._narrow,
"permute.default": self._permute,
"repeat.default": self._repeat,
"roll.default": self._roll,
"select.int": self._select,
"slice.Tensor": self._slice,
"split.Tensor": self._split,
1 change: 1 addition & 0 deletions python/tvm/relax/frontend/torch/fx_translator.py
@@ -750,6 +750,7 @@ def create_convert_map(
"numel": self._numel,
"permute": self._permute,
"repeat": self._repeat,
"roll": self._roll,
"reshape": self._reshape,
"scatter": self._scatter,
"select": self._select,
125 changes: 125 additions & 0 deletions tests/python/relax/test_frontend_from_exported_program.py
@@ -2968,6 +2968,131 @@ def main(
verify_model(ReshapeAs(), example_args, {}, expected1)


def test_roll():
class Roll1(Module):
def forward(self, x):
return torch.roll(x, 1)

class Roll2(Module):
def forward(self, x):
return torch.roll(x, -1, 0)

class Roll3(Module):
def forward(self, x):
return torch.roll(x, shifts=(2, 1), dims=(0, 1))

# Test case 1: torch.roll(x, 1)
@I.ir_module
class Expected1:
@R.function
def main(x: R.Tensor((4, 2), dtype="int64")) -> R.Tuple(R.Tensor((4, 2), dtype="int64")):
with R.dataflow():
lv: R.Tensor((8,), dtype="int64") = R.reshape(x, R.shape([8]))
lv1: R.Tensor((7,), dtype="int64") = R.strided_slice(
lv,
axes=[0],
begin=[R.prim_value(0)],
end=[R.prim_value(7)],
strides=[R.prim_value(1)],
assume_inbound=False,
)
lv2: R.Tensor((1,), dtype="int64") = R.strided_slice(
lv,
axes=[0],
begin=[R.prim_value(7)],
end=[R.prim_value(8)],
strides=[R.prim_value(1)],
assume_inbound=False,
)
lv3: R.Tensor((8,), dtype="int64") = R.concat((lv2, lv1), axis=0)
lv4: R.Tensor((4, 2), dtype="int64") = R.reshape(lv3, R.shape([4, 2]))
gv: R.Tuple(R.Tensor((4, 2), dtype="int64")) = (lv4,)
R.output(gv)
return gv

# Test case 2: torch.roll(x, -1, 0)
@I.ir_module
class Expected2:
@R.function
def main(x: R.Tensor((4, 2), dtype="int64")) -> R.Tuple(R.Tensor((4, 2), dtype="int64")):
with R.dataflow():
lv: R.Tensor((1, 2), dtype="int64") = R.strided_slice(
x,
axes=[0],
begin=[R.prim_value(0)],
end=[R.prim_value(1)],
strides=[R.prim_value(1)],
assume_inbound=False,
)
lv1: R.Tensor((3, 2), dtype="int64") = R.strided_slice(
x,
axes=[0],
begin=[R.prim_value(1)],
end=[R.prim_value(4)],
strides=[R.prim_value(1)],
assume_inbound=False,
)
lv2: R.Tensor((4, 2), dtype="int64") = R.concat((lv1, lv), axis=0)
gv: R.Tuple(R.Tensor((4, 2), dtype="int64")) = (lv2,)
R.output(gv)
return gv

# Test case 3: torch.roll(x, shifts=(2, 1), dims=(0, 1))
@I.ir_module
class Expected3:
@R.function
def main(x: R.Tensor((4, 2), dtype="int64")) -> R.Tuple(R.Tensor((4, 2), dtype="int64")):
with R.dataflow():
# First roll along dim=0 with shift=2
lv: R.Tensor((2, 2), dtype="int64") = R.strided_slice(
x,
axes=[0],
begin=[R.prim_value(0)],
end=[R.prim_value(2)],
strides=[R.prim_value(1)],
assume_inbound=False,
)
lv1: R.Tensor((2, 2), dtype="int64") = R.strided_slice(
x,
axes=[0],
begin=[R.prim_value(2)],
end=[R.prim_value(4)],
strides=[R.prim_value(1)],
assume_inbound=False,
)
lv2: R.Tensor((4, 2), dtype="int64") = R.concat((lv1, lv), axis=0)

# Second roll along dim=1 with shift=1
lv3: R.Tensor((4, 1), dtype="int64") = R.strided_slice(
lv2,
axes=[1],
begin=[R.prim_value(0)],
end=[R.prim_value(1)],
strides=[R.prim_value(1)],
assume_inbound=False,
)
lv4: R.Tensor((4, 1), dtype="int64") = R.strided_slice(
lv2,
axes=[1],
begin=[R.prim_value(1)],
end=[R.prim_value(2)],
strides=[R.prim_value(1)],
assume_inbound=False,
)
lv5: R.Tensor((4, 2), dtype="int64") = R.concat((lv4, lv3), axis=1)
gv: R.Tuple(R.Tensor((4, 2), dtype="int64")) = (lv5,)
R.output(gv)
return gv

# Test inputs
example_input = torch.randint(0, 10, (4, 2), dtype=torch.int64)

# Run verification for each case
verify_model(Roll1(), (example_input,), {}, Expected1)
verify_model(Roll2(), (example_input,), {}, Expected2)
verify_model(Roll3(), (example_input,), {}, Expected3)
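
For reference, Expected1's reshape/roll/reshape chain mirrors what torch.roll does when dims is omitted: flatten, shift by the scalar, and restore the original shape. A quick illustrative check (using an arange input rather than the random one above):

import torch

x = torch.arange(8, dtype=torch.int64).reshape(4, 2)
# dims omitted: flatten -> roll by 1 along dim 0 -> reshape back to (4, 2)
assert torch.equal(torch.roll(x, 1), x.reshape(-1).roll(1, 0).reshape(4, 2))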


def test_select_slice():
class Slice1(Module):
def forward(self, x):
120 changes: 120 additions & 0 deletions tests/python/relax/test_frontend_from_fx.py
@@ -3560,6 +3560,126 @@ def main(x: R.Tensor((1, 3), dtype="float32")) -> R.Tensor((4, 6), dtype="float32"):
verify_model(Tile2(), [(torch.Size([1, 3]), "float32")], {}, expected2)


def test_roll():
class Roll1(Module):
def forward(self, x):
return torch.roll(x, 1)

class Roll2(Module):
def forward(self, x):
return torch.roll(x, -1, 0)

class Roll3(Module):
def forward(self, x):
return torch.roll(x, shifts=(2, 1), dims=(0, 1))

# Test case 1: torch.roll(x, 1)
@I.ir_module
class Expected1:
@R.function
def main(inp_0: R.Tensor((4, 2), dtype="int64")) -> R.Tensor((4, 2), dtype="int64"):
with R.dataflow():
lv: R.Tensor((8,), dtype="int64") = R.reshape(inp_0, R.shape([8]))
lv1: R.Tensor((7,), dtype="int64") = R.strided_slice(
lv,
axes=[0],
begin=[R.prim_value(0)],
end=[R.prim_value(7)],
strides=[R.prim_value(1)],
assume_inbound=False,
)
lv2: R.Tensor((1,), dtype="int64") = R.strided_slice(
lv,
axes=[0],
begin=[R.prim_value(7)],
end=[R.prim_value(8)],
strides=[R.prim_value(1)],
assume_inbound=False,
)
lv3: R.Tensor((8,), dtype="int64") = R.concat((lv2, lv1), axis=0)
lv4: R.Tensor((4, 2), dtype="int64") = R.reshape(lv3, R.shape([4, 2]))
gv: R.Tensor((4, 2), dtype="int64") = lv4
R.output(gv)
return gv

# Test case 2: torch.roll(x, -1, 0)
@I.ir_module
class Expected2:
@R.function
def main(inp_0: R.Tensor((4, 2), dtype="int64")) -> R.Tensor((4, 2), dtype="int64"):
with R.dataflow():
lv: R.Tensor((1, 2), dtype="int64") = R.strided_slice(
inp_0,
axes=[0],
begin=[R.prim_value(0)],
end=[R.prim_value(1)],
strides=[R.prim_value(1)],
assume_inbound=False,
)
lv1: R.Tensor((3, 2), dtype="int64") = R.strided_slice(
inp_0,
axes=[0],
begin=[R.prim_value(1)],
end=[R.prim_value(4)],
strides=[R.prim_value(1)],
assume_inbound=False,
)
lv2: R.Tensor((4, 2), dtype="int64") = R.concat((lv1, lv), axis=0)
gv: R.Tensor((4, 2), dtype="int64") = lv2
R.output(gv)
return gv

# Test case 3: torch.roll(x, shifts=(2, 1), dims=(0, 1))
@I.ir_module
class Expected3:
@R.function
def main(inp_0: R.Tensor((4, 2), dtype="int64")) -> R.Tensor((4, 2), dtype="int64"):
with R.dataflow():
lv: R.Tensor((2, 2), dtype="int64") = R.strided_slice(
inp_0,
axes=[0],
begin=[R.prim_value(0)],
end=[R.prim_value(2)],
strides=[R.prim_value(1)],
assume_inbound=False,
)
lv1: R.Tensor((2, 2), dtype="int64") = R.strided_slice(
inp_0,
axes=[0],
begin=[R.prim_value(2)],
end=[R.prim_value(4)],
strides=[R.prim_value(1)],
assume_inbound=False,
)
lv2: R.Tensor((4, 2), dtype="int64") = R.concat((lv1, lv), axis=0)
lv3: R.Tensor((4, 1), dtype="int64") = R.strided_slice(
lv2,
axes=[1],
begin=[R.prim_value(0)],
end=[R.prim_value(1)],
strides=[R.prim_value(1)],
assume_inbound=False,
)
lv4: R.Tensor((4, 1), dtype="int64") = R.strided_slice(
lv2,
axes=[1],
begin=[R.prim_value(1)],
end=[R.prim_value(2)],
strides=[R.prim_value(1)],
assume_inbound=False,
)
lv5: R.Tensor((4, 2), dtype="int64") = R.concat((lv4, lv3), axis=1)
gv: R.Tensor((4, 2), dtype="int64") = lv5
R.output(gv)
return gv

input_info = [([4, 2], "int64")]

verify_model(Roll1(), input_info, {}, Expected1)
verify_model(Roll2(), input_info, {}, Expected2)
verify_model(Roll3(), input_info, {}, Expected3)
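
A quick sanity check, separate from the test itself, that a multi-axis roll decomposes into successive single-axis rolls; that decomposition is exactly the two strided_slice/concat rounds Expected3 encodes:

import torch

x = torch.randint(0, 10, (4, 2), dtype=torch.int64)
stepwise = torch.roll(torch.roll(x, 2, 0), 1, 1)  # dim 0 first, then dim 1, as in Expected3
assert torch.equal(torch.roll(x, shifts=(2, 1), dims=(0, 1)), stepwise)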


def test_view():
input_info = [([1, 2, 3, 4], "float32")]
