From 3ed87ad2c1b4293b2cc0528f80624280e3a12138 Mon Sep 17 00:00:00 2001 From: tlopex <820958424@qq.com> Date: Mon, 22 Sep 2025 08:31:54 -0400 Subject: [PATCH 1/3] finish1 --- python/tvm/relax/base_py_module.py | 58 +++++++++++++++++++++--------- 1 file changed, 42 insertions(+), 16 deletions(-) diff --git a/python/tvm/relax/base_py_module.py b/python/tvm/relax/base_py_module.py index 7a790d28a720..53c1d3de3eb8 100644 --- a/python/tvm/relax/base_py_module.py +++ b/python/tvm/relax/base_py_module.py @@ -32,6 +32,13 @@ except ImportError: to_dlpack_legacy = None +try: + from tvm_ffi._optional_torch_c_dlpack import load_torch_c_dlpack_extension + + _faster_dlpack_extension = load_torch_c_dlpack_extension() +except ImportError: + _faster_dlpack_extension = None + class BasePyModule: """Base class that allows Python functions in IRModule with DLPack conversion. @@ -369,20 +376,29 @@ def _convert_pytorch_to_tvm( return self._convert_single_pytorch_to_tvm(tensors) def _convert_single_pytorch_to_tvm(self, tensor: Any) -> Tensor: - """Convert a single PyTorch tensor to TVM Tensor with robust fallbacks.""" + """Convert a single PyTorch tensor to TVM Tensor with faster DLPack converter.""" # pylint: disable=import-outside-toplevel import torch if isinstance(tensor, Tensor): return tensor if isinstance(tensor, torch.Tensor): - # 1. Try modern `torch.to_dlpack` (preferred for PyTorch >= 1.7) + # 1. Try faster C++ DLPack converter + if _faster_dlpack_extension is not None: + try: + dlpack = torch.to_dlpack(tensor) + return tvm.runtime.from_dlpack(dlpack) + except (AttributeError, ValueError): + pass # Fall through to the next method + + # 2. Try modern `torch.to_dlpack` (preferred for PyTorch >= 1.7) try: dlpack = torch.to_dlpack(tensor) return tvm.runtime.from_dlpack(dlpack) except (AttributeError, ValueError): pass # Fall through to the next method - # 2. Try legacy `torch.utils.dlpack.to_dlpack` + + # 3. Try legacy `torch.utils.dlpack.to_dlpack` if to_dlpack_legacy: try: dlpack = to_dlpack_legacy(tensor) @@ -392,7 +408,8 @@ def _convert_single_pytorch_to_tvm(self, tensor: Any) -> Tensor: f"Warning: Legacy DLPack conversion failed ({error_legacy}), " f"using numpy fallback." ) - # 3. If all DLPack methods fail, use numpy fallback + + # 4. If all DLPack methods fail, use numpy fallback numpy_array = tensor.detach().cpu().numpy() return tvm.runtime.tensor(numpy_array, device=self.device) @@ -406,28 +423,37 @@ def _convert_single_pytorch_to_tvm(self, tensor: Any) -> Tensor: ) from error def _convert_tvm_to_pytorch( - self, tvm_arrays: Union[Any, List[Any]] + self, tvm_tensors: Union[Any, List[Any]] ) -> Union["torch.Tensor", List["torch.Tensor"]]: """Convert TVM Tensors to PyTorch tensors using DLPack.""" - if isinstance(tvm_arrays, (list, tuple)): - return [self._convert_single_tvm_to_pytorch(arr) for arr in tvm_arrays] - return self._convert_single_tvm_to_pytorch(tvm_arrays) + if isinstance(tvm_tensors, (list, tuple)): + return [self._convert_single_tvm_to_pytorch(tensor) for tensor in tvm_tensors] + return self._convert_single_tvm_to_pytorch(tvm_tensors) - def _convert_single_tvm_to_pytorch(self, tvm_array: Any) -> "torch.Tensor": - """Convert a single TVM Tensor to PyTorch tensor using DLPack.""" + def _convert_single_tvm_to_pytorch(self, tvm_tensor: Any) -> "torch.Tensor": + """Convert a single TVM Tensor to PyTorch tensor using faster DLPack converter.""" # pylint: disable=import-outside-toplevel import torch - if isinstance(tvm_array, torch.Tensor): - return tvm_array - if not isinstance(tvm_array, Tensor): - return torch.tensor(tvm_array) + if isinstance(tvm_tensor, torch.Tensor): + return tvm_tensor + if not isinstance(tvm_tensor, Tensor): + return torch.tensor(tvm_tensor) + + # 1. Try faster C++ DLPack converter + if _faster_dlpack_extension is not None: + try: + return torch.from_dlpack(tvm_tensor) + except (AttributeError, ValueError): + pass # Fall through to the next method + + # 2. Try standard DLPack conversion try: - return torch.from_dlpack(tvm_array) + return torch.from_dlpack(tvm_tensor) # pylint: disable=broad-exception-caught except Exception as error: print(f"Warning: DLPack conversion from TVM failed ({error}), using numpy fallback") - numpy_array = tvm_array.numpy() + numpy_array = tvm_tensor.numpy() return torch.from_numpy(numpy_array) def get_function(self, name: str) -> Optional[PackedFunc]: From 75f9ea4d525e65a0d91be8a2faf508f6e312cc9c Mon Sep 17 00:00:00 2001 From: tlopex <820958424@qq.com> Date: Mon, 22 Sep 2025 08:45:14 -0400 Subject: [PATCH 2/3] finish2 --- python/tvm/relax/base_py_module.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/tvm/relax/base_py_module.py b/python/tvm/relax/base_py_module.py index 53c1d3de3eb8..41ef44fb300b 100644 --- a/python/tvm/relax/base_py_module.py +++ b/python/tvm/relax/base_py_module.py @@ -35,9 +35,9 @@ try: from tvm_ffi._optional_torch_c_dlpack import load_torch_c_dlpack_extension - _faster_dlpack_extension = load_torch_c_dlpack_extension() + _FASTER_DLPACK_EXTENSION = load_torch_c_dlpack_extension() except ImportError: - _faster_dlpack_extension = None + _FASTER_DLPACK_EXTENSION = None class BasePyModule: @@ -384,7 +384,7 @@ def _convert_single_pytorch_to_tvm(self, tensor: Any) -> Tensor: return tensor if isinstance(tensor, torch.Tensor): # 1. Try faster C++ DLPack converter - if _faster_dlpack_extension is not None: + if _FASTER_DLPACK_EXTENSION is not None: try: dlpack = torch.to_dlpack(tensor) return tvm.runtime.from_dlpack(dlpack) @@ -441,7 +441,7 @@ def _convert_single_tvm_to_pytorch(self, tvm_tensor: Any) -> "torch.Tensor": return torch.tensor(tvm_tensor) # 1. Try faster C++ DLPack converter - if _faster_dlpack_extension is not None: + if _FASTER_DLPACK_EXTENSION is not None: try: return torch.from_dlpack(tvm_tensor) except (AttributeError, ValueError): From 645bd43c49ded6bbc38d9595f1af941e76181400 Mon Sep 17 00:00:00 2001 From: tlopex <820958424@qq.com> Date: Mon, 22 Sep 2025 11:51:19 -0400 Subject: [PATCH 3/3] fix --- tests/python/relax/test_base_py_module.py | 2 +- .../relax/test_base_py_module_printer.py | 52 ++----------------- 2 files changed, 5 insertions(+), 49 deletions(-) diff --git a/tests/python/relax/test_base_py_module.py b/tests/python/relax/test_base_py_module.py index 19cc5c9eec6d..1f888991be1b 100644 --- a/tests/python/relax/test_base_py_module.py +++ b/tests/python/relax/test_base_py_module.py @@ -203,4 +203,4 @@ def my_softmax(tensor, dim): if __name__ == "__main__": - pytest.main([__file__]) + tvm.testing.main() diff --git a/tests/python/relax/test_base_py_module_printer.py b/tests/python/relax/test_base_py_module_printer.py index c9d23a746567..a64b3fed5aea 100644 --- a/tests/python/relax/test_base_py_module_printer.py +++ b/tests/python/relax/test_base_py_module_printer.py @@ -420,54 +420,6 @@ def safe_transform(data: T.handle, output: T.handle): Output[i] = 0.0 -if __name__ == "__main__": - # This allows the file to be run directly for debugging - # In normal pytest usage, these classes are automatically tested by TVMScript - print("All test modules defined successfully!") - print("TVMScript will automatically validate these modules during testing.") - - # Demo the printer functionality - print("\n" + "=" * 60) - print("DEMO: BasePyModule Printer Functionality") - print("=" * 60) - - # Test the printer with SimplePyFuncModule - try: - ir_mod = SimplePyFuncModule - device = tvm.cpu() - module = BasePyModule(ir_mod, device) - - print("\n1. Testing script() method:") - print("-" * 40) - script_output = module.script() - print(script_output[:500] + "..." if len(script_output) > 500 else script_output) - - print("\n2. Testing show() method:") - print("-" * 40) - module.show() - - print("\n3. Python functions found in pyfuncs:") - print("-" * 40) - if hasattr(ir_mod, "pyfuncs"): - for name, func in ir_mod.pyfuncs.items(): - print(f" - {name}: {func}") - else: - print(" No pyfuncs attribute found") - - except Exception as e: - print(f"Demo failed: {e}") - print("This is expected for testing-only TVMScript code.") - - # Run all tests using tvm.testing.main() - print("\n" + "=" * 60) - print("Running all tests with tvm.testing.main()...") - print("=" * 60) - - import tvm.testing - - tvm.testing.main() - - # Pytest test functions to verify the classes work correctly def test_simple_pyfunc_module_creation(): """Test that SimplePyFuncModule can be created.""" @@ -849,3 +801,7 @@ def mixed_computation(x: R.Tensor((10,), "float32")) -> R.Tensor((10,), "float32 # Use numpy for comparison since we have numpy arrays np.testing.assert_allclose(final_result_np, expected_np, rtol=1e-5, atol=1e-5) + + +if __name__ == "__main__": + tvm.testing.main()