diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 8b771b5d2fb9..d5cad2381b49 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -994,7 +994,19 @@ def _transpose(self, node: fx.Node) -> relax.Var:
 
     ########## Creation ##########
 
+    def _detach(self, node: fx.Node) -> relax.Var:
+        # There is no way to implement detach() such that the output shares
+        # the same memory as the input. In-place operations are not supported
+        # by the translator, and therefore we just return the input tensor.
+        return self.env[node.args[0]]
+
+    def _copy_(self, node: fx.Node) -> relax.Var:
+        # Copies the source tensor into the destination tensor.
+        # In TVM, that means simply returning the source tensor.
+        return self.env[node.args[1]]
+
     def _to_copy(self, node: fx.Node) -> relax.Var:
+        # Returns a copy of the input tensor
         import torch  # type: ignore
 
         x = self.env[node.args[0]]
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index f3c0a6467640..4ff31ea1d772 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -295,6 +295,7 @@ def create_convert_map(
             # tensor manipulation
             "cat.default": self._cat,
             "concat.default": self._cat,
+            "copy_.default": self._copy_,
             "cumsum.default": self._cumsum,
             "expand.default": self._expand,
             "permute.default": self._permute,
@@ -313,6 +314,9 @@ def create_convert_map(
             "reshape.default": self._reshape,
             # tensor creation
             "_to_copy.default": self._to_copy,
+            "lift_fresh_copy.default": self._to_copy,
+            "detach.default": self._detach,
+            "detach_.default": self._detach,
             "arange.start": self._arange,
             "contiguous.default": lambda node: self.env[node.args[0]],  # no-op
             "clone.default": lambda node: self.env[node.args[0]],
diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py
index 69daab36a581..bd4bdcf61770 100644
--- a/tests/python/relax/test_from_exported_to_cuda.py
+++ b/tests/python/relax/test_from_exported_to_cuda.py
@@ -20,6 +20,7 @@
 import tvm.testing
 import numpy as np
 import torch
+from torch import nn
 from torch.export import export
 from tvm.relax.frontend.torch import from_exported_program
 from torch.nn import Softmax, Upsample
@@ -55,6 +56,24 @@ def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, tar
     np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5, atol=1e-5)
 
 
+@tvm.testing.parametrize_targets("cuda")
+def test_copy_(target, dev):
+    class CopyTester(nn.Module):
+        def __init__(self, size):
+            super().__init__()
+            self.register_buffer("buffer", torch.zeros(size))
+
+        def forward(self, x):
+            self.buffer.copy_(x)
+
+            return x * 3 + self.buffer * 5
+
+    size = (2, 2)
+    raw_data = np.random.rand(*size).astype(np.float32)
+    torch_module = CopyTester(size).eval()
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
 @tvm.testing.parametrize_targets("cuda")
 def test_upsample_with_size(target, dev):
     """
@@ -72,6 +91,19 @@ def test_upsample_with_size(target, dev):
     assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
 
 
+@tvm.testing.parametrize_targets("cuda")
+def test_detach_no_change(target, dev):
+    # In TVM, detach() is just the identity
+    class DetachTester(nn.Module):
+        def forward(self, x):
+            detached = x.detach()
+            return detached
+
+    raw_data = np.ones((2, 2)).astype(np.float32)
+    torch_module = DetachTester().eval()
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
 @tvm.testing.parametrize_targets("cuda")
 def test_upsample_with_scale_factor(target, dev):
     """
@@ -87,7 +119,6 @@ def test_upsample_with_scale_factor(target, dev):
     )
 
     raw_data = np.random.rand(batch_size, channels, height, width).astype("float32")
-
     assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
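
Not part of the patch, but a minimal sketch of how the new mappings can be exercised outside the CUDA test suite, for reviewers without a GPU. The `DetachAdd` module name, the input shape, and the use of `mod.show()` for inspection are illustrative choices, not anything the patch prescribes.

```python
import torch
from torch import nn
from torch.export import export
from tvm.relax.frontend.torch import from_exported_program


class DetachAdd(nn.Module):
    # Illustrative module: detach() should end up as a plain pass-through
    # of the input tensor in the translated Relax function.
    def forward(self, x):
        return x.detach() + 1.0


example_input = torch.randn(2, 2)
exported_program = export(DetachAdd().eval(), (example_input,))

# Convert the ExportedProgram to a Relax IRModule; detach.default,
# detach_.default, lift_fresh_copy.default and copy_.default are now
# recognized by the converter map.
mod = from_exported_program(exported_program)
mod.show()  # inspect the module; no detach-specific op should remain
```

The same reasoning backs `test_copy_` above: since the translator does not model aliasing, mapping `buffer.copy_(x)` to "return the source tensor" preserves the numerics that the test compares against PyTorch.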