Summary of this patch (free text before the first "diff --git" header):
  * base_fx_graph_translator.py: add the `_full`, `_full_like` and `_ones` converters.
  * exported_program_translator.py: register `full.default`, `full_like.default`
    and `ones.default` in the convert map.
  * fx_translator.py: remove its own `_full` and `_ones` definitions (the same
    bodies are the ones added to base_fx_graph_translator.py above).
  * test_from_exported_to_cuda.py: add CUDA tests `test_full`, `test_full_like`
    and `test_ones`.

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index c9c6afd71a64..3018b0db771d 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1271,6 +1271,28 @@ def _fill(self, node: fx.Node) -> relax.Var:
         value = args[1] if isinstance(args[1], relax.Expr) else relax.const(args[1], dtype)
         return self.block_builder.emit(relax.op.full(x.struct_info.shape, value, dtype))
 
+    def _full(self, node: fx.Node) -> relax.Var:
+        import torch
+
+        args = self.retrieve_args(node)
+        size = relax.ShapeExpr(args[0] if isinstance(args[0], (list, tuple)) else (args[0],))
+        dtype = self._convert_data_type(
+            node.kwargs.get("dtype", torch.get_default_dtype()), self.env
+        )
+        value = args[1] if isinstance(args[1], relax.expr.Constant) else relax.const(args[1], dtype)
+        return self.block_builder.emit(
+            relax.op.full(
+                size,
+                value,
+                dtype,
+            )
+        )
+
+    def _full_like(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        fill_value = relax.const(node.args[1])
+        return self.block_builder.emit(relax.op.full_like(x, fill_value))
+
     def _index_select(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         dim = node.args[1]
@@ -1292,6 +1314,22 @@ def _new_ones(self, node: fx.Node) -> relax.Var:
             )
         )
 
+    def _ones(self, node: fx.Node) -> relax.Var:
+        import torch
+
+        args = self.retrieve_args(node)
+        size = relax.ShapeExpr(args[0] if isinstance(args[0], (list, tuple)) else (args[0],))
+        dtype = self._convert_data_type(
+            node.kwargs.get("dtype", torch.get_default_dtype()), self.env
+        )
+        return self.block_builder.emit(
+            relax.op.full(
+                size,
+                relax.const(1, dtype),
+                dtype,
+            )
+        )
+
     ########## DataType ##########
 
     def _to(self, node: fx.Node) -> relax.Var:
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 875ec3b83ea8..bcb8b6468f72 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -433,10 +433,13 @@ def create_convert_map(
             "empty.memory_format": self._empty,
             "empty_like.default": self._empty_like,
             "fill.Scalar": self._fill,
+            "full.default": self._full,
+            "full_like.default": self._full_like,
             "index_select.default": self._index_select,
             "lift_fresh_copy.default": self._to_copy,
             "new_ones.default": self._new_ones,
             "one_hot.default": self._one_hot,
+            "ones.default": self._ones,
             # datatype
             "to.dtype": self._to,
             "to.dtype_layout": self._to,
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index a5b50a7d1dce..f1b9a6d6e28c 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -468,23 +468,6 @@ def _inplace_fill(self, node: fx.Node) -> relax.Var:
         self.env[node.args[0]] = filled
         return filled
 
-    def _full(self, node: fx.Node) -> relax.Var:
-        import torch
-
-        args = self.retrieve_args(node)
-        size = relax.ShapeExpr(args[0] if isinstance(args[0], (list, tuple)) else (args[0],))
-        dtype = self._convert_data_type(
-            node.kwargs.get("dtype", torch.get_default_dtype()), self.env
-        )
-        value = args[1] if isinstance(args[1], relax.expr.Constant) else relax.const(args[1], dtype)
-        return self.block_builder.emit(
-            relax.op.full(
-                size,
-                value,
-                dtype,
-            )
-        )
-
     def _inplace_masked_fill(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         mask = self.env[node.args[1]]
@@ -527,22 +510,6 @@ def _masked_scatter(self, node: fx.Node) -> relax.Var:
         mask = self.block_builder.emit(relax.op.broadcast_to(mask, x.struct_info.shape))
         return self.block_builder.emit(relax.op.where(mask, gathered_source, x))
 
-    def _ones(self, node: fx.Node) -> relax.Var:
-        import torch
-
-        args = self.retrieve_args(node)
-        size = relax.ShapeExpr(args[0] if isinstance(args[0], (list, tuple)) else (args[0],))
-        dtype = self._convert_data_type(
-            node.kwargs.get("dtype", torch.get_default_dtype()), self.env
-        )
-        return self.block_builder.emit(
-            relax.op.full(
-                size,
-                relax.const(1, dtype),
-                dtype,
-            )
-        )
-
     def _one_hot(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         num_classes = node.args[1] if len(node.args) > 1 else node.kwargs.get("num_classes")
diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py
index 8405f48576d8..e92855885e35 100644
--- a/tests/python/relax/test_from_exported_to_cuda.py
+++ b/tests/python/relax/test_from_exported_to_cuda.py
@@ -63,6 +63,54 @@ def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, tar
     np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5, atol=1e-5)
 
 
+@tvm.testing.parametrize_targets("cuda")
+def test_full(target, dev):
+    class FullModel(nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x):
+            return torch.full((2, 3), 3.141592)
+
+    torch_module = FullModel().eval()
+
+    raw_data = np.random.rand(3, 3).astype("float32")
+
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
+@tvm.testing.parametrize_targets("cuda")
+def test_full_like(target, dev):
+    class FullLike(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.fill_value = 7.0
+
+        def forward(self, x):
+            return torch.full_like(x, self.fill_value)
+
+    torch_module = FullLike().eval()
+    raw_data = np.random.rand(2, 3).astype("float32")
+
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
+@tvm.testing.parametrize_targets("cuda")
+def test_ones(target, dev):
+    class FullModel(nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x):
+            return torch.ones((2, 3))
+
+    torch_module = FullModel().eval()
+
+    raw_data = np.random.rand(1, 1).astype("float32")
+
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
 @tvm.testing.parametrize_targets("cuda")
 def test_tensor_clamp(target, dev):
     class ClampBothTensor(torch.nn.Module):