Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions python/tvm/relax/frontend/torch/base_fx_graph_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1416,6 +1416,13 @@ def _empty_like(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
return self.block_builder.emit(relax.op.zeros_like(x))

def _eye(self, node: fx.Node) -> relax.Var:
args = self.retrieve_args(node)
n = args[0]
m = args[1] if len(args) > 1 else n
dtype = self._convert_data_type(str(node.kwargs["dtype"]), self.env)
return self.block_builder.emit(relax.op.eye(n, m, dtype=dtype))

def _fill(self, node: fx.Node) -> relax.Var:
args = self.retrieve_args(node)
x = args[0]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,8 @@ def create_convert_map(
"clone.default": lambda node: self.env[node.args[0]],
"empty.memory_format": self._empty,
"empty_like.default": self._empty_like,
"eye.default": self._eye,
"eye.m": self._eye,
"fill.Scalar": self._fill,
"full.default": self._full,
"full_like.default": self._full_like,
Expand Down
40 changes: 40 additions & 0 deletions tests/python/relax/test_frontend_from_exported_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -4377,5 +4377,45 @@ def main(
verify_model(Narrow(), example_args, {}, Expected)


def test_eye():
class Eye1(Module):
def forward(self, input):
return torch.eye(3, 5, dtype=torch.float32)

@tvm.script.ir_module
class Expected1:
@R.function
def main(
input: R.Tensor((3, 5), dtype="float32")
) -> R.Tuple(R.Tensor((3, 5), dtype="float32")):
with R.dataflow():
lv: R.Tensor((3, 5), dtype="float32") = R.eye(3, 5, dtype="float32")
gv: R.Tuple(R.Tensor((3, 5), dtype="float32")) = (lv,)
R.output(gv)
return gv

class Eye2(Module):
def forward(self, input):
return torch.eye(5, dtype=torch.float32)

@tvm.script.ir_module
class Expected2:
@R.function
def main(
input: R.Tensor((5,), dtype="float32")
) -> R.Tuple(R.Tensor((5, 5), dtype="float32")):
with R.dataflow():
lv: R.Tensor((5, 5), dtype="float32") = R.eye(5, dtype="float32")
gv: R.Tuple(R.Tensor((5, 5), dtype="float32")) = (lv,)
R.output(gv)
return gv

example_args1 = (torch.randn(3, 5, dtype=torch.float32),)
verify_model(Eye1(), example_args1, {}, Expected1)

example_args2 = (torch.randn(5, dtype=torch.float32),)
verify_model(Eye2(), example_args2, {}, Expected2)


if __name__ == "__main__":
tvm.testing.main()