From c09704cb7f14ae6e2e3bc99e396a6eb0a62352d5 Mon Sep 17 00:00:00 2001 From: tlopex <820958424@qq.com> Date: Sun, 14 Sep 2025 02:24:57 +0800 Subject: [PATCH 01/19] finish1 --- python/tvm/relax/base_py_module.py | 13 +- python/tvm/relax/op/__init__.py | 1 + python/tvm/relax/op/base.py | 36 ++++ python/tvm/script/ir_builder/relax/ir.py | 53 +++++ src/relax/op/op.cc | 65 ++++++ .../relax/test_base_py_module_printer.py | 204 ++++++++++++++++++ 6 files changed, 366 insertions(+), 6 deletions(-) diff --git a/python/tvm/relax/base_py_module.py b/python/tvm/relax/base_py_module.py index a4464cc737b9..70b99755c5f5 100644 --- a/python/tvm/relax/base_py_module.py +++ b/python/tvm/relax/base_py_module.py @@ -234,12 +234,13 @@ def call_dps_packed(self, func_name: str, args, out_sinfo): return out[0] if len(out) == 1 else out def call_py_func(self, func_name: str, args): - """Call a Python function stored in the IRModule's pyfuncs.""" - if func_name not in self.ir_mod.pyfuncs: - raise ValueError(f"Python function '{func_name}' not found in IRModule pyfuncs") - py_func = self.ir_mod.pyfuncs[func_name] - converted_args = self._convert_tvm_to_pytorch(args) - return py_func(*converted_args) + """Call a Python function stored in the module's pyfuncs.""" + if func_name not in self.pyfuncs: + raise ValueError(f"Python function '{func_name}' not found in module pyfuncs") + py_func = self.pyfuncs[func_name] + # args 已经是 PyTorch 张量,不需要转换 + # py_func 是绑定到 self 的方法,需要传递 self 作为第一个参数 + return py_func(self, *args) def _create_output_tensors(self, out_sinfo, in_args=None): # pylint: disable=import-outside-toplevel diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py index fd3672368b68..6ea8305ecadb 100644 --- a/python/tvm/relax/op/__init__.py +++ b/python/tvm/relax/op/__init__.py @@ -27,6 +27,7 @@ call_dps_packed, call_inplace_packed, call_pure_packed, + call_py_func, call_tir, call_tir_inplace, call_tir_with_grad, diff --git a/python/tvm/relax/op/base.py b/python/tvm/relax/op/base.py index e77920d8dea6..e205abde30b4 100644 --- a/python/tvm/relax/op/base.py +++ b/python/tvm/relax/op/base.py @@ -304,6 +304,42 @@ def call_dps_packed( return _ffi_api.call_dps_packed(func, args, out_sinfo) # type: ignore +@args_converter.auto +def call_py_func( + func_name: str, + args: Expr, + out_sinfo: Union[TensorStructInfo, List[TensorStructInfo]], +) -> Call: + """ + Call a Python function and return the output. + + Parameters + ---------- + func_name : str + The name of the Python function to call. This should correspond to a function + in the IRModule's pyfuncs attribute. + + args : Expr + The input arguments. + + out_sinfo : Union[TensorStructInfo, List[TensorStructInfo]] + The structure info of the call_py_func output. + It should be a single or a list of TensorStructInfo. Each one denotes the + structure info of a returned tensor. + + Returns + ------- + ret: Call + A call node for the call_py_func operator. + """ + args = _wrap_inline_arg_tuple(args) + + if not isinstance(out_sinfo, list): + out_sinfo = [out_sinfo] + + return _ffi_api.call_py_func(func_name, args, out_sinfo) # type: ignore + + @args_converter.auto def call_builtin_with_ctx( func: Union[str, Expr], diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index d28ff3430aaa..2e54e405cb86 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -451,6 +451,58 @@ def call_packed( return Call(op, args, attrs=attrs, sinfo_args=sinfo_args) +@args_converter.auto +def call_py_func( + func_name: py_str, + *args: Expr, + out_sinfo: Union[StructInfo, List[StructInfo]], +) -> Call: + """Create a relax Call, which calls a Python function. + + Parameters + ---------- + func_name: str + The name of the Python function to call. This should correspond to a function + in the IRModule's pyfuncs attribute. + *args : Expr + The arguments. + out_sinfo: Union[StructInfo, List[StructInfo]] + The structure info of the call_py_func output. + It should be a single or a list of TensorStructInfo. Each one denotes the + structure info of a returned tensor. + + Returns + ------- + call: Call + The created Relax Call for call_py_func operator. + """ + from tvm.relax.op import call_py_func as _call_py_func + from tvm.relax.expr import StringImm + + if isinstance(out_sinfo, py_tuple): # type: ignore + out_sinfo = list(out_sinfo) + elif not isinstance(out_sinfo, list): + out_sinfo = [out_sinfo] + + out_sinfo = [ + ( + sinfo() + if callable(sinfo) + else sinfo.asobject() + if isinstance(sinfo, ObjectConvertible) + else sinfo + ) + for sinfo in out_sinfo + ] + + # Convert string to StringImm + try: + func_name_imm = StringImm(func_name) if hasattr(func_name, 'strip') else func_name + except: + func_name_imm = StringImm(func_name) + return _call_py_func(func_name_imm, args, out_sinfo) + + def _sinfo_arg_wrapper(func): """A wrapper to convert StructInfoProxies to StructInfo for builtin operators with sinfo_args""" @@ -743,6 +795,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "call_tir_inplace", "call_tir_with_grad", "call_dps_packed", + "call_py_func", "call_builtin_with_ctx", "ceil", "clip", diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc index ddf6a056f00a..b679410b8c93 100644 --- a/src/relax/op/op.cc +++ b/src/relax/op/op.cc @@ -858,6 +858,71 @@ TVM_FFI_STATIC_INIT_BLOCK({ refl::GlobalDef().def("relax.op.call_dps_packed", MakeCallDPSPacked); }); +// call_py_func + +StructInfo InferStructInfoCallPyFunc(const Call& call, const BlockBuilder& ctx) { + if (call->sinfo_args.size() != 1) { + ctx->ReportFatal(Diagnostic::Error(call) + << "sinfo_args should have exact 1 output struct info."); + } + return call->sinfo_args[0]; +} + +void ValidateCallPyFunc(Call call) { + // Validate that the function name is a string literal + auto func_name = call->args[0]; + CHECK(func_name->IsInstance()) + << "Operation " << call->op << " expects the first argument to be a string literal " + << "specifying the Python function name. However, the first argument " << func_name + << " is not a string literal."; + + // Validate that args is a tuple + Expr arg_tuple = call->args[1]; + CHECK(arg_tuple->struct_info_.as()) + << "Operation " << call->op << " expects the second argument to be a tuple of relax Expr. " + << "However, the second argument " << arg_tuple << " has struct info " + << arg_tuple->struct_info_ << "."; + + CHECK(arg_tuple.as() || arg_tuple.as()) + << "Operation " << call->op << " must hold its arguments as an in-line tuple. " + << "However, " << call << " has arguments " << arg_tuple + << ", which is neither an in-line tuple, " + << "nor a variable binding that may be normalized to an in-line tuple."; +} + +TVM_REGISTER_OP("relax.call_py_func") + .set_num_inputs(2) + .add_argument("func_name", "StringImm", "The name of the Python function to call.") + .add_argument("args", "Tuple", "The input arguments.") + .set_attr("FInferStructInfo", InferStructInfoCallPyFunc) + .set_attr("FValidate", ValidateCallPyFunc) + .set_attr("FPurity", Bool(true)); + +Expr MakeCallPyFunc(StringImm func_name, Tuple args, ffi::Array out_sinfo_list) { + for (const TensorStructInfo& sinfo : out_sinfo_list) { + const auto* shape = sinfo->shape.as(); + CHECK(shape != nullptr) + << "out_sinfo of call_py_func should have defined ShapeExpr as shape. " + "However, one given structure info is " + << sinfo; + } + + StructInfo out_sinfo{nullptr}; + if (out_sinfo_list.size() == 1) { + out_sinfo = out_sinfo_list[0]; + } else { + out_sinfo = TupleStructInfo({out_sinfo_list.begin(), out_sinfo_list.end()}); + } + + static const Op& op = Op::Get("relax.call_py_func"); + return Call(op, {func_name, args}, {}, {out_sinfo}); +} + +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.call_py_func", MakeCallPyFunc); +}); + // call builtin StructInfo InferStructInfoCallBuiltinWithCtx(const Call& call, const BlockBuilder& ctx) { if (call->sinfo_args.size() == 0) { diff --git a/tests/python/relax/test_base_py_module_printer.py b/tests/python/relax/test_base_py_module_printer.py index 92c799f6cb70..20d5cd8da195 100644 --- a/tests/python/relax/test_base_py_module_printer.py +++ b/tests/python/relax/test_base_py_module_printer.py @@ -758,3 +758,207 @@ def test_python_functions_in_irmodule(): assert pyfuncs["multiply"].__name__ == "multiply" else: pytest.fail("pyfuncs attribute not found in IRModule") + + +def test_call_py_func_operator(): + """Test R.call_py_func operator functionality.""" + import torch + + @I.ir_module + class CallPyFuncTestModule(BasePyModule): + """Test module with call_py_func usage.""" + + @I.pyfunc + def pytorch_add(self, x, y): + """Simple PyTorch addition.""" + return x + y + + @I.pyfunc + def pytorch_multiply(self, x, y): + """Simple PyTorch multiplication.""" + return x * y + + @R.function + def test_call_py_func( + x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32") + ) -> R.Tensor((5,), "float32"): + # Test calling Python function from Relax + result = R.call_py_func("pytorch_add", (x, y), out_sinfo=R.Tensor((5,), "float32")) + return result + + @R.function + def test_call_py_func_chain( + x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32") + ) -> R.Tensor((5,), "float32"): + # First call + intermediate = R.call_py_func("pytorch_add", (x, y), out_sinfo=R.Tensor((5,), "float32")) + # Second call + result = R.call_py_func("pytorch_multiply", (intermediate, y), out_sinfo=R.Tensor((5,), "float32")) + return result + + # Test basic functionality + device = tvm.cpu() + module = CallPyFuncTestModule(device) + + # Create test tensors + x = torch.randn(5, dtype=torch.float32) + y = torch.randn(5, dtype=torch.float32) + + # Test direct Python function calls + expected_add = x + y + expected_multiply = x * y + + # Test through BasePyModule + result_add = module.call_py_func("pytorch_add", [x, y]) + result_multiply = module.call_py_func("pytorch_multiply", [x, y]) + + assert torch.allclose(result_add, expected_add, atol=1e-5) + assert torch.allclose(result_multiply, expected_multiply, atol=1e-5) + + # Test that the module has the pyfuncs + assert hasattr(module, "pyfuncs") + assert "pytorch_add" in module.pyfuncs + assert "pytorch_multiply" in module.pyfuncs + + +def test_call_py_func_validation(): + """Test call_py_func validation and error handling.""" + import torch + + @I.ir_module + class ValidationTestModule(BasePyModule): + """Test module for validation.""" + + @I.pyfunc + def valid_func(self, x): + """Valid Python function.""" + return x * 2 + + @R.function + def test_invalid_call( + x: R.Tensor((5,), "float32") + ) -> R.Tensor((5,), "float32"): + # This should cause a validation error + result = R.call_py_func("non_existent_func", (x,), out_sinfo=R.Tensor((5,), "float32")) + return result + + device = tvm.cpu() + module = ValidationTestModule(device) + + # Test that calling non-existent function raises error + x = torch.randn(5, dtype=torch.float32) + + with pytest.raises(ValueError, match="Python function 'non_existent_func' not found"): + module.call_py_func("non_existent_func", [x]) + + +def test_call_py_func_in_relax_function(): + """Test using call_py_func within Relax functions.""" + import torch + + @I.ir_module + class RelaxCallPyFuncModule(BasePyModule): + """Test module with call_py_func in Relax functions.""" + + @I.pyfunc + def torch_relu(self, x): + """PyTorch ReLU implementation.""" + return torch.relu(x) + + @I.pyfunc + def torch_softmax(self, x, dim=0): + """PyTorch softmax implementation.""" + return torch.softmax(x, dim=dim) + + @R.function + def mixed_computation( + x: R.Tensor((10,), "float32") + ) -> R.Tensor((10,), "float32"): + # Use Python function for ReLU + relu_result = R.call_py_func("torch_relu", (x,), out_sinfo=R.Tensor((10,), "float32")) + # Use Python function for softmax + final_result = R.call_py_func("torch_softmax", (relu_result,), out_sinfo=R.Tensor((10,), "float32")) + return final_result + + device = tvm.cpu() + module = RelaxCallPyFuncModule(device) + + # Test the mixed computation + x = torch.randn(10, dtype=torch.float32) + + expected = torch.softmax(torch.relu(x), dim=0) + + relu_result = module.call_py_func("torch_relu", [x]) + final_result = module.call_py_func("torch_softmax", [relu_result]) + + assert torch.allclose(final_result, expected, atol=1e-5) + + +def test_call_py_func_operator_creation(): + """Test R.call_py_func operator creation and basic properties.""" + from tvm.relax.op import call_py_func + from tvm.relax.expr import StringImm + from tvm.relax import Var, TensorStructInfo + + # Create variables + x = Var("x", TensorStructInfo((5,), "float32")) + y = Var("y", TensorStructInfo((5,), "float32")) + + # Create call_py_func call + call_expr = call_py_func(StringImm("test_func"), (x, y), out_sinfo=R.Tensor((5,), "float32")) + + # Verify operator properties + assert call_expr.op.name == "relax.call_py_func" + assert call_expr.args[0].value == "test_func" + assert len(call_expr.args) == 2 + + +def test_call_py_func_compilation_validation(): + """Test call_py_func compilation validation.""" + from tvm.relax.op import call_py_func + from tvm.relax import Var, TensorStructInfo + + # Test operator parameter validation + try: + call_py_func("invalid", (Var("x", TensorStructInfo((5,), "float32")),), out_sinfo=R.Tensor((5,), "float32")) + assert False, "Should raise type error" + except Exception as e: + assert "Mismatched type" in str(e) or "Expected" in str(e) + + +def test_call_py_func_runtime_execution(): + """Test call_py_func runtime execution with complex operations.""" + import torch + + @I.ir_module + class RuntimeTestModule(BasePyModule): + """Test module for runtime execution""" + + @I.pyfunc + def pytorch_add(self, x, y): + """PyTorch addition function""" + return x + y + + @I.pyfunc + def pytorch_relu(self, x): + """PyTorch ReLU function""" + return torch.relu(x) + + # Create module instance + device = tvm.cpu() + module = RuntimeTestModule(device) + + # Create test tensors + x = torch.randn(5, dtype=torch.float32) + y = torch.randn(5, dtype=torch.float32) + + # Test direct Python function calls + result_add = module.call_py_func("pytorch_add", [x, y]) + result_relu = module.call_py_func("pytorch_relu", [x]) + + expected_add = x + y + expected_relu = torch.relu(x) + + # Verify results + assert torch.allclose(result_add, expected_add, atol=1e-5) + assert torch.allclose(result_relu, expected_relu, atol=1e-5) From 77a792ee43ecc89f58020f812629458a96c10a47 Mon Sep 17 00:00:00 2001 From: tlopex <820958424@qq.com> Date: Sun, 14 Sep 2025 02:53:29 +0800 Subject: [PATCH 02/19] finish2 --- python/tvm/relax/base_py_module.py | 2 - python/tvm/script/ir_builder/relax/ir.py | 17 +++--- .../relax/test_base_py_module_printer.py | 58 ++++++++++--------- 3 files changed, 40 insertions(+), 37 deletions(-) diff --git a/python/tvm/relax/base_py_module.py b/python/tvm/relax/base_py_module.py index 70b99755c5f5..52f813dc6b6d 100644 --- a/python/tvm/relax/base_py_module.py +++ b/python/tvm/relax/base_py_module.py @@ -238,8 +238,6 @@ def call_py_func(self, func_name: str, args): if func_name not in self.pyfuncs: raise ValueError(f"Python function '{func_name}' not found in module pyfuncs") py_func = self.pyfuncs[func_name] - # args 已经是 PyTorch 张量,不需要转换 - # py_func 是绑定到 self 的方法,需要传递 self 作为第一个参数 return py_func(self, *args) def _create_output_tensors(self, out_sinfo, in_args=None): diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index 2e54e405cb86..d9499652131d 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -30,6 +30,7 @@ Expr, ExternFunc, ShapeExpr, + StringImm, TupleGetItem, Var, VarBinding, @@ -64,6 +65,7 @@ call_dps_packed, call_inplace_packed, call_pure_packed, + call_py_func as _call_py_func, call_tir, call_tir_inplace, call_tir_with_grad, @@ -453,15 +455,15 @@ def call_packed( @args_converter.auto def call_py_func( - func_name: py_str, + py_func_name: py_str, *args: Expr, out_sinfo: Union[StructInfo, List[StructInfo]], ) -> Call: """Create a relax Call, which calls a Python function. - + Parameters ---------- - func_name: str + py_func_name: str The name of the Python function to call. This should correspond to a function in the IRModule's pyfuncs attribute. *args : Expr @@ -476,9 +478,6 @@ def call_py_func( call: Call The created Relax Call for call_py_func operator. """ - from tvm.relax.op import call_py_func as _call_py_func - from tvm.relax.expr import StringImm - if isinstance(out_sinfo, py_tuple): # type: ignore out_sinfo = list(out_sinfo) elif not isinstance(out_sinfo, list): @@ -497,9 +496,9 @@ def call_py_func( # Convert string to StringImm try: - func_name_imm = StringImm(func_name) if hasattr(func_name, 'strip') else func_name - except: - func_name_imm = StringImm(func_name) + func_name_imm = StringImm(py_func_name) if hasattr(py_func_name, "strip") else py_func_name + except Exception: + func_name_imm = StringImm(py_func_name) return _call_py_func(func_name_imm, args, out_sinfo) diff --git a/tests/python/relax/test_base_py_module_printer.py b/tests/python/relax/test_base_py_module_printer.py index 20d5cd8da195..4c4e54ee5b36 100644 --- a/tests/python/relax/test_base_py_module_printer.py +++ b/tests/python/relax/test_base_py_module_printer.py @@ -791,9 +791,13 @@ def test_call_py_func_chain( x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32") ) -> R.Tensor((5,), "float32"): # First call - intermediate = R.call_py_func("pytorch_add", (x, y), out_sinfo=R.Tensor((5,), "float32")) + intermediate = R.call_py_func( + "pytorch_add", (x, y), out_sinfo=R.Tensor((5,), "float32") + ) # Second call - result = R.call_py_func("pytorch_multiply", (intermediate, y), out_sinfo=R.Tensor((5,), "float32")) + result = R.call_py_func( + "pytorch_multiply", (intermediate, y), out_sinfo=R.Tensor((5,), "float32") + ) return result # Test basic functionality @@ -835,9 +839,7 @@ def valid_func(self, x): return x * 2 @R.function - def test_invalid_call( - x: R.Tensor((5,), "float32") - ) -> R.Tensor((5,), "float32"): + def test_invalid_call(x: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"): # This should cause a validation error result = R.call_py_func("non_existent_func", (x,), out_sinfo=R.Tensor((5,), "float32")) return result @@ -847,7 +849,7 @@ def test_invalid_call( # Test that calling non-existent function raises error x = torch.randn(5, dtype=torch.float32) - + with pytest.raises(ValueError, match="Python function 'non_existent_func' not found"): module.call_py_func("non_existent_func", [x]) @@ -871,13 +873,13 @@ def torch_softmax(self, x, dim=0): return torch.softmax(x, dim=dim) @R.function - def mixed_computation( - x: R.Tensor((10,), "float32") - ) -> R.Tensor((10,), "float32"): + def mixed_computation(x: R.Tensor((10,), "float32")) -> R.Tensor((10,), "float32"): # Use Python function for ReLU relu_result = R.call_py_func("torch_relu", (x,), out_sinfo=R.Tensor((10,), "float32")) # Use Python function for softmax - final_result = R.call_py_func("torch_softmax", (relu_result,), out_sinfo=R.Tensor((10,), "float32")) + final_result = R.call_py_func( + "torch_softmax", (relu_result,), out_sinfo=R.Tensor((10,), "float32") + ) return final_result device = tvm.cpu() @@ -885,12 +887,12 @@ def mixed_computation( # Test the mixed computation x = torch.randn(10, dtype=torch.float32) - + expected = torch.softmax(torch.relu(x), dim=0) - + relu_result = module.call_py_func("torch_relu", [x]) final_result = module.call_py_func("torch_softmax", [relu_result]) - + assert torch.allclose(final_result, expected, atol=1e-5) @@ -899,14 +901,14 @@ def test_call_py_func_operator_creation(): from tvm.relax.op import call_py_func from tvm.relax.expr import StringImm from tvm.relax import Var, TensorStructInfo - + # Create variables x = Var("x", TensorStructInfo((5,), "float32")) y = Var("y", TensorStructInfo((5,), "float32")) - + # Create call_py_func call call_expr = call_py_func(StringImm("test_func"), (x, y), out_sinfo=R.Tensor((5,), "float32")) - + # Verify operator properties assert call_expr.op.name == "relax.call_py_func" assert call_expr.args[0].value == "test_func" @@ -917,10 +919,14 @@ def test_call_py_func_compilation_validation(): """Test call_py_func compilation validation.""" from tvm.relax.op import call_py_func from tvm.relax import Var, TensorStructInfo - + # Test operator parameter validation try: - call_py_func("invalid", (Var("x", TensorStructInfo((5,), "float32")),), out_sinfo=R.Tensor((5,), "float32")) + call_py_func( + "invalid", + (Var("x", TensorStructInfo((5,), "float32")),), + out_sinfo=R.Tensor((5,), "float32"), + ) assert False, "Should raise type error" except Exception as e: assert "Mismatched type" in str(e) or "Expected" in str(e) @@ -929,36 +935,36 @@ def test_call_py_func_compilation_validation(): def test_call_py_func_runtime_execution(): """Test call_py_func runtime execution with complex operations.""" import torch - + @I.ir_module class RuntimeTestModule(BasePyModule): """Test module for runtime execution""" - + @I.pyfunc def pytorch_add(self, x, y): """PyTorch addition function""" return x + y - + @I.pyfunc def pytorch_relu(self, x): """PyTorch ReLU function""" return torch.relu(x) - + # Create module instance device = tvm.cpu() module = RuntimeTestModule(device) - + # Create test tensors x = torch.randn(5, dtype=torch.float32) y = torch.randn(5, dtype=torch.float32) - + # Test direct Python function calls result_add = module.call_py_func("pytorch_add", [x, y]) result_relu = module.call_py_func("pytorch_relu", [x]) - + expected_add = x + y expected_relu = torch.relu(x) - + # Verify results assert torch.allclose(result_add, expected_add, atol=1e-5) assert torch.allclose(result_relu, expected_relu, atol=1e-5) From 05bc35c0fc8b76998da98c74404614c6effa0c66 Mon Sep 17 00:00:00 2001 From: tlopex <820958424@qq.com> Date: Sun, 14 Sep 2025 03:27:19 +0800 Subject: [PATCH 03/19] finish3:lint --- python/tvm/script/ir_builder/relax/ir.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index d9499652131d..a352dbf8f12c 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -497,7 +497,7 @@ def call_py_func( # Convert string to StringImm try: func_name_imm = StringImm(py_func_name) if hasattr(py_func_name, "strip") else py_func_name - except Exception: + except (TypeError, ValueError, AttributeError): func_name_imm = StringImm(py_func_name) return _call_py_func(func_name_imm, args, out_sinfo) From e58a9edda699fb6486efba6e905dc7a4ae3f48e0 Mon Sep 17 00:00:00 2001 From: tlopex <820958424@qq.com> Date: Sun, 14 Sep 2025 03:40:04 +0800 Subject: [PATCH 04/19] finish4:lint --- src/relax/op/op.cc | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc index b679410b8c93..2ac9f30d7423 100644 --- a/src/relax/op/op.cc +++ b/src/relax/op/op.cc @@ -901,10 +901,9 @@ TVM_REGISTER_OP("relax.call_py_func") Expr MakeCallPyFunc(StringImm func_name, Tuple args, ffi::Array out_sinfo_list) { for (const TensorStructInfo& sinfo : out_sinfo_list) { const auto* shape = sinfo->shape.as(); - CHECK(shape != nullptr) - << "out_sinfo of call_py_func should have defined ShapeExpr as shape. " - "However, one given structure info is " - << sinfo; + CHECK(shape != nullptr) << "out_sinfo of call_py_func should have defined ShapeExpr as shape. " + "However, one given structure info is " + << sinfo; } StructInfo out_sinfo{nullptr}; From 6fe557198ef76bd5f7c037e6944fc406b900c3ce Mon Sep 17 00:00:00 2001 From: tlopex <820958424@qq.com> Date: Fri, 19 Sep 2025 09:12:45 +0800 Subject: [PATCH 05/19] fix --- python/tvm/script/ir_builder/relax/ir.py | 2 +- .../relax/test_base_py_module_printer.py | 103 ------------------ 2 files changed, 1 insertion(+), 104 deletions(-) diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index a352dbf8f12c..aeb3fae8c875 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -496,7 +496,7 @@ def call_py_func( # Convert string to StringImm try: - func_name_imm = StringImm(py_func_name) if hasattr(py_func_name, "strip") else py_func_name + func_name_imm = StringImm(py_func_name) if isinstance(py_func_name, str) else py_func_name except (TypeError, ValueError, AttributeError): func_name_imm = StringImm(py_func_name) return _call_py_func(func_name_imm, args, out_sinfo) diff --git a/tests/python/relax/test_base_py_module_printer.py b/tests/python/relax/test_base_py_module_printer.py index 4c4e54ee5b36..6e87174fda35 100644 --- a/tests/python/relax/test_base_py_module_printer.py +++ b/tests/python/relax/test_base_py_module_printer.py @@ -760,71 +760,6 @@ def test_python_functions_in_irmodule(): pytest.fail("pyfuncs attribute not found in IRModule") -def test_call_py_func_operator(): - """Test R.call_py_func operator functionality.""" - import torch - - @I.ir_module - class CallPyFuncTestModule(BasePyModule): - """Test module with call_py_func usage.""" - - @I.pyfunc - def pytorch_add(self, x, y): - """Simple PyTorch addition.""" - return x + y - - @I.pyfunc - def pytorch_multiply(self, x, y): - """Simple PyTorch multiplication.""" - return x * y - - @R.function - def test_call_py_func( - x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32") - ) -> R.Tensor((5,), "float32"): - # Test calling Python function from Relax - result = R.call_py_func("pytorch_add", (x, y), out_sinfo=R.Tensor((5,), "float32")) - return result - - @R.function - def test_call_py_func_chain( - x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32") - ) -> R.Tensor((5,), "float32"): - # First call - intermediate = R.call_py_func( - "pytorch_add", (x, y), out_sinfo=R.Tensor((5,), "float32") - ) - # Second call - result = R.call_py_func( - "pytorch_multiply", (intermediate, y), out_sinfo=R.Tensor((5,), "float32") - ) - return result - - # Test basic functionality - device = tvm.cpu() - module = CallPyFuncTestModule(device) - - # Create test tensors - x = torch.randn(5, dtype=torch.float32) - y = torch.randn(5, dtype=torch.float32) - - # Test direct Python function calls - expected_add = x + y - expected_multiply = x * y - - # Test through BasePyModule - result_add = module.call_py_func("pytorch_add", [x, y]) - result_multiply = module.call_py_func("pytorch_multiply", [x, y]) - - assert torch.allclose(result_add, expected_add, atol=1e-5) - assert torch.allclose(result_multiply, expected_multiply, atol=1e-5) - - # Test that the module has the pyfuncs - assert hasattr(module, "pyfuncs") - assert "pytorch_add" in module.pyfuncs - assert "pytorch_multiply" in module.pyfuncs - - def test_call_py_func_validation(): """Test call_py_func validation and error handling.""" import torch @@ -930,41 +865,3 @@ def test_call_py_func_compilation_validation(): assert False, "Should raise type error" except Exception as e: assert "Mismatched type" in str(e) or "Expected" in str(e) - - -def test_call_py_func_runtime_execution(): - """Test call_py_func runtime execution with complex operations.""" - import torch - - @I.ir_module - class RuntimeTestModule(BasePyModule): - """Test module for runtime execution""" - - @I.pyfunc - def pytorch_add(self, x, y): - """PyTorch addition function""" - return x + y - - @I.pyfunc - def pytorch_relu(self, x): - """PyTorch ReLU function""" - return torch.relu(x) - - # Create module instance - device = tvm.cpu() - module = RuntimeTestModule(device) - - # Create test tensors - x = torch.randn(5, dtype=torch.float32) - y = torch.randn(5, dtype=torch.float32) - - # Test direct Python function calls - result_add = module.call_py_func("pytorch_add", [x, y]) - result_relu = module.call_py_func("pytorch_relu", [x]) - - expected_add = x + y - expected_relu = torch.relu(x) - - # Verify results - assert torch.allclose(result_add, expected_add, atol=1e-5) - assert torch.allclose(result_relu, expected_relu, atol=1e-5) From 621698a8f7295eae1eeea6d7fbd7f4a424294d58 Mon Sep 17 00:00:00 2001 From: tlopex <820958424@qq.com> Date: Fri, 19 Sep 2025 09:24:03 +0800 Subject: [PATCH 06/19] fix2 --- python/tvm/script/ir_builder/relax/ir.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index aeb3fae8c875..61d4b86be50b 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -496,7 +496,7 @@ def call_py_func( # Convert string to StringImm try: - func_name_imm = StringImm(py_func_name) if isinstance(py_func_name, str) else py_func_name + func_name_imm = StringImm(py_func_name) if isinstance(py_func_name, py_str) else py_func_name except (TypeError, ValueError, AttributeError): func_name_imm = StringImm(py_func_name) return _call_py_func(func_name_imm, args, out_sinfo) From 6e541c6cf04ec0b700bbb3d71ea6610a60fae8ea Mon Sep 17 00:00:00 2001 From: tlopex <820958424@qq.com> Date: Fri, 19 Sep 2025 10:19:23 +0800 Subject: [PATCH 07/19] fix3 --- python/tvm/script/ir_builder/relax/ir.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index 61d4b86be50b..3fa735197ac5 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -496,7 +496,9 @@ def call_py_func( # Convert string to StringImm try: - func_name_imm = StringImm(py_func_name) if isinstance(py_func_name, py_str) else py_func_name + func_name_imm = ( + StringImm(py_func_name) if isinstance(py_func_name, py_str) else py_func_name + ) except (TypeError, ValueError, AttributeError): func_name_imm = StringImm(py_func_name) return _call_py_func(func_name_imm, args, out_sinfo) From c75e296b95ef59bec1b62222701ceb2921b6636f Mon Sep 17 00:00:00 2001 From: tlopex <820958424@qq.com> Date: Fri, 19 Sep 2025 10:40:11 +0800 Subject: [PATCH 08/19] fix4 --- src/relax/op/op.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc index 2ac9f30d7423..b956b4d168fd 100644 --- a/src/relax/op/op.cc +++ b/src/relax/op/op.cc @@ -917,10 +917,10 @@ Expr MakeCallPyFunc(StringImm func_name, Tuple args, ffi::Array Date: Sat, 20 Sep 2025 04:24:09 +0800 Subject: [PATCH 09/19] s1 --- src/relax/backend/vm/codegen_vm.cc | 1 + src/relax/backend/vm/lower_runtime_builtin.cc | 32 +++++++++++++++++++ src/relax/op/op.cc | 4 +-- src/runtime/vm/builtin.cc | 18 +++++++++++ 4 files changed, 53 insertions(+), 2 deletions(-) diff --git a/src/relax/backend/vm/codegen_vm.cc b/src/relax/backend/vm/codegen_vm.cc index e29f580793b1..31aada5e776f 100644 --- a/src/relax/backend/vm/codegen_vm.cc +++ b/src/relax/backend/vm/codegen_vm.cc @@ -369,6 +369,7 @@ class CodeGenVM : public ExprFunctor { builder_->EmitCall(func, args, dst_reg); } + void EmitNormalCall(const Call& call_node, RegName dst_reg) { Instruction::Arg func = VisitExpr(call_node->op); std::vector args = VisitArray(call_node->args); diff --git a/src/relax/backend/vm/lower_runtime_builtin.cc b/src/relax/backend/vm/lower_runtime_builtin.cc index cb5b8e8b1360..75197e16c5bb 100644 --- a/src/relax/backend/vm/lower_runtime_builtin.cc +++ b/src/relax/backend/vm/lower_runtime_builtin.cc @@ -27,6 +27,7 @@ #include #include #include +#include #include #include @@ -52,6 +53,8 @@ class LowerRuntimeBuiltinMutator : public ExprMutator { return ShapeOf(call); } else if (call->op == tensor_to_shape_op_) { return TensorToShape(call); + } else if (call->op == call_py_func_op_) { + return CallPyFunc(call); } else if (call->op == to_vdevice_op_) { return ToDevice(call); } else if (call->op == make_closure_op_) { @@ -139,6 +142,33 @@ class LowerRuntimeBuiltinMutator : public ExprMutator { return Call(builtin_tensor_to_shape_, call_node->args, Attrs(), {GetStructInfo(call_node)}); } + Expr CallPyFunc(const Call& call_node) { + ICHECK(call_node->args.size() == 2); + ICHECK(call_node->struct_info_.defined()); + + // Get function name from first argument (StringImm) + auto func_name = Downcast(call_node->args[0])->value; + + // Get arguments from second argument (Tuple) + auto args_tuple = Downcast(call_node->args[1]); + + // Create extern function for Python function call + auto py_func_extern = ExternFunc("vm.builtin.call_py_func"); + + // Create new arguments: [py_func_extern, (func_name, args_tuple)] + ffi::Array new_args; + new_args.push_back(py_func_extern); + + // Create tuple with function name and arguments + ffi::Array tuple_fields; + tuple_fields.push_back(call_node->args[0]); // function name + tuple_fields.push_back(call_node->args[1]); // arguments tuple + new_args.push_back(Tuple(tuple_fields)); + + // Create call_builtin_with_ctx call + return Call(call_builtin_with_ctx_op_, new_args, call_node->attrs, call_node->sinfo_args, call_node->span); + } + Expr ToDevice(const Call& call_node) { // TODO(yongwww): replace ToVDeviceAttrs with related Expr ICHECK(call_node->args.size() == 1); @@ -198,6 +228,7 @@ class LowerRuntimeBuiltinMutator : public ExprMutator { const Op& reshape_op_ = Op::Get("relax.reshape"); const Op& shape_of_op_ = Op::Get("relax.shape_of"); const Op& tensor_to_shape_op_ = Op::Get("relax.tensor_to_shape"); + const Op& call_py_func_op_ = Op::Get("relax.call_py_func"); const Op& to_vdevice_op_ = Op::Get("relax.to_vdevice"); const Op& make_closure_op_ = Op::Get("relax.make_closure"); const Op& invoke_closure_op_ = Op::Get("relax.invoke_closure"); @@ -216,6 +247,7 @@ class LowerRuntimeBuiltinMutator : public ExprMutator { const ExternFunc builtin_reshape_{"vm.builtin.reshape"}; const ExternFunc builtin_shape_of_{"vm.builtin.shape_of"}; const ExternFunc builtin_tensor_to_shape_{"vm.builtin.tensor_to_shape"}; + const ExternFunc builtin_call_py_func_{"vm.builtin.call_py_func"}; const ExternFunc builtin_to_device_{"vm.builtin.to_device"}; const ExternFunc builtin_make_closure_{"vm.builtin.make_closure"}; const ExternFunc builtin_invoke_closure_{"vm.builtin.invoke_closure"}; diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc index b956b4d168fd..2ac9f30d7423 100644 --- a/src/relax/op/op.cc +++ b/src/relax/op/op.cc @@ -917,10 +917,10 @@ Expr MakeCallPyFunc(StringImm func_name, Tuple args, ffi::Array Date: Fri, 19 Sep 2025 19:48:47 -0400 Subject: [PATCH 10/19] te2 --- python/tvm/relax/base_py_module.py | 34 +++++ simple_test.py | 46 ++++++ src/relax/backend/vm/lower_runtime_builtin.cc | 21 +-- src/relax/op/op.cc | 64 -------- test_call_py_func.py | 137 ++++++++++++++++++ test_full_compilation.py | 132 +++++++++++++++++ 6 files changed, 353 insertions(+), 81 deletions(-) create mode 100644 simple_test.py create mode 100644 test_call_py_func.py create mode 100644 test_full_compilation.py diff --git a/python/tvm/relax/base_py_module.py b/python/tvm/relax/base_py_module.py index 52f813dc6b6d..c51fc802d85b 100644 --- a/python/tvm/relax/base_py_module.py +++ b/python/tvm/relax/base_py_module.py @@ -100,6 +100,7 @@ def _getattr_python_function(name: str) -> Any: self._compile_functions() self._wrap_tir_functions() self._wrap_relax_functions() + self._register_python_functions() def _collect_function_names(self): """Collect names of TIR and Relax functions from IRModule.""" @@ -177,6 +178,39 @@ def wrapper(*args, **kwargs): setattr(self, func_name, _create_relax_wrapper(func_name)) + def _register_python_functions(self): + """Register Python functions with the VM runtime for call_py_func support.""" + if not hasattr(self.ir_mod, 'pyfuncs') or not self.ir_mod.pyfuncs: + return + + # Get the register function from TVM + try: + register_py_func = tvm.get_global_func("vm.builtin.register_py_func") + except ValueError: + # Function not available, skip registration + return + + for func_name, py_func in self.ir_mod.pyfuncs.items(): + # Create a wrapper that handles TVM tensor conversion + def create_py_func_wrapper(name, original_func): + def wrapper(*args, **kwargs): + # Convert TVM tensors to PyTorch tensors for Python function + converted_args = [self._convert_tvm_to_pytorch(arg) for arg in args] + converted_kwargs = {k: self._convert_tvm_to_pytorch(v) for k, v in kwargs.items()} + + # Call the original Python function + result = original_func(self, *converted_args, **converted_kwargs) + + # Convert result back to TVM format + return self._convert_pytorch_to_tvm(result) + + wrapper.__name__ = name + return wrapper + + # Register the wrapped function + wrapped_func = create_py_func_wrapper(func_name, py_func) + register_py_func(func_name, wrapped_func) + def call_tir(self, tir_func, args, out_sinfo): """Call a TIR function with PyTorch tensors.""" # Try to get function name from different sources diff --git a/simple_test.py b/simple_test.py new file mode 100644 index 000000000000..dc83b148fdb2 --- /dev/null +++ b/simple_test.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python3 +""" +简化的 R.call_py_func 测试 +""" + +import tvm +from tvm import relax as R +from tvm.relax.op import call_py_func +import numpy as np + +# 测试 1: 操作符创建 +print("=== 测试操作符创建 ===") +try: + x_var = R.Var("x", R.TensorStructInfo((3,), "float32")) + call_expr = call_py_func(R.StringImm("add_one"), (x_var,), out_sinfo=R.TensorStructInfo((3,), "float32")) + print(f"✓ 成功创建 call_py_func 表达式: {call_expr}") + print(f"操作符类型: {type(call_expr)}") +except Exception as e: + print(f"✗ 创建失败: {e}") + +# 测试 2: 函数注册 +print("\n=== 测试函数注册 ===") +try: + def add_one(x): + return x + 1.0 + + register_func = tvm.get_global_func("vm.builtin.register_py_func") + register_func("add_one", add_one) + print("✓ 成功注册 Python 函数") +except Exception as e: + print(f"✗ 注册失败: {e}") + +# 测试 3: 简单的 VM 调用(不使用 call_py_func) +print("\n=== 测试直接 Python 函数调用 ===") +try: + def add_one(x): + print(f"Python 函数被调用,输入: {x}") + return x + 1.0 + + x = np.array([1.0, 2.0, 3.0], dtype=np.float32) + result = add_one(x) + print(f"✓ 直接调用成功,结果: {result}") +except Exception as e: + print(f"✗ 直接调用失败: {e}") + +print("\n=== 测试完成 ===") diff --git a/src/relax/backend/vm/lower_runtime_builtin.cc b/src/relax/backend/vm/lower_runtime_builtin.cc index 95a28ecf24bf..6a72f572d926 100644 --- a/src/relax/backend/vm/lower_runtime_builtin.cc +++ b/src/relax/backend/vm/lower_runtime_builtin.cc @@ -146,27 +146,14 @@ class LowerRuntimeBuiltinMutator : public ExprMutator { ICHECK(call_node->args.size() == 2); ICHECK(call_node->struct_info_.defined()); - // Get function name from first argument (StringImm) - auto func_name = Downcast(call_node->args[0])->value; - - // Get arguments from second argument (Tuple) - auto args_tuple = Downcast(call_node->args[1]); - - // Create extern function for Python function call - auto py_func_extern = ExternFunc("vm.builtin.call_py_func"); - - // Create new arguments: [py_func_extern, (func_name, args_tuple)] - ffi::Array new_args; - new_args.push_back(py_func_extern); - - // Create tuple with function name and arguments + // Create tuple with function name and arguments tuple ffi::Array tuple_fields; tuple_fields.push_back(call_node->args[0]); // function name tuple_fields.push_back(call_node->args[1]); // arguments tuple - new_args.push_back(Tuple(tuple_fields)); + auto combined_tuple = Tuple(tuple_fields); - // Create call_builtin_with_ctx call - return Call(call_builtin_with_ctx_op_, new_args, call_node->attrs, call_node->sinfo_args, call_node->span); + // Direct call to vm.builtin.call_py_func + return Call(builtin_call_py_func_, {combined_tuple}, call_node->attrs, call_node->sinfo_args, call_node->span); } Expr ToDevice(const Call& call_node) { diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc index 66239dd4c4d2..d91c19b63fd2 100644 --- a/src/relax/op/op.cc +++ b/src/relax/op/op.cc @@ -922,70 +922,6 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.call_py_func", MakeCallPyFunc); } -// call_py_func - -StructInfo InferStructInfoCallPyFunc(const Call& call, const BlockBuilder& ctx) { - if (call->sinfo_args.size() != 1) { - ctx->ReportFatal(Diagnostic::Error(call) - << "sinfo_args should have exact 1 output struct info."); - } - return call->sinfo_args[0]; -} - -void ValidateCallPyFunc(Call call) { - // Validate that the function name is a string literal - auto func_name = call->args[0]; - CHECK(func_name->IsInstance()) - << "Operation " << call->op << " expects the first argument to be a string literal " - << "specifying the Python function name. However, the first argument " << func_name - << " is not a string literal."; - - // Validate that args is a tuple - Expr arg_tuple = call->args[1]; - CHECK(arg_tuple->struct_info_.as()) - << "Operation " << call->op << " expects the second argument to be a tuple of relax Expr. " - << "However, the second argument " << arg_tuple << " has struct info " - << arg_tuple->struct_info_ << "."; - - CHECK(arg_tuple.as() || arg_tuple.as()) - << "Operation " << call->op << " must hold its arguments as an in-line tuple. " - << "However, " << call << " has arguments " << arg_tuple - << ", which is neither an in-line tuple, " - << "nor a variable binding that may be normalized to an in-line tuple."; -} - -TVM_REGISTER_OP("relax.call_py_func") - .set_num_inputs(2) - .add_argument("func_name", "StringImm", "The name of the Python function to call.") - .add_argument("args", "Tuple", "The input arguments.") - .set_attr("FInferStructInfo", InferStructInfoCallPyFunc) - .set_attr("FValidate", ValidateCallPyFunc) - .set_attr("FPurity", Bool(true)); - -Expr MakeCallPyFunc(StringImm func_name, Tuple args, ffi::Array out_sinfo_list) { - for (const TensorStructInfo& sinfo : out_sinfo_list) { - const auto* shape = sinfo->shape.as(); - CHECK(shape != nullptr) << "out_sinfo of call_py_func should have defined ShapeExpr as shape. " - "However, one given structure info is " - << sinfo; - } - - StructInfo out_sinfo{nullptr}; - if (out_sinfo_list.size() == 1) { - out_sinfo = out_sinfo_list[0]; - } else { - out_sinfo = TupleStructInfo({out_sinfo_list.begin(), out_sinfo_list.end()}); - } - - static const Op& op = Op::Get("relax.call_py_func"); - return Call(op, {func_name, args}, {}, {out_sinfo}); -} - -TVM_FFI_STATIC_INIT_BLOCK({ - namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("relax.op.call_py_func", MakeCallPyFunc); -}); - // call builtin StructInfo InferStructInfoCallBuiltinWithCtx(const Call& call, const BlockBuilder& ctx) { if (call->sinfo_args.size() == 0) { diff --git a/test_call_py_func.py b/test_call_py_func.py new file mode 100644 index 000000000000..02be947f11bf --- /dev/null +++ b/test_call_py_func.py @@ -0,0 +1,137 @@ +#!/usr/bin/env python3 +""" +Test script for R.call_py_func functionality +测试编译后的模块能否执行 Python 函数 +""" + +import tvm +from tvm import relax as R +from tvm.relax.op import call_py_func +import numpy as np + +# 定义一个简单的 Python 函数 +def add_one(x): + """Add one to input tensor.""" + print(f"Python function called with: {x}") + return x + 1.0 + +# 测试 1: 直接测试 R.call_py_func 操作符 +def test_call_py_func_operator(): + """测试 R.call_py_func 操作符是否能被正确识别""" + print("=== 测试 R.call_py_func 操作符 ===") + + # 创建测试数据 + x = np.array([1.0, 2.0, 3.0], dtype=np.float32) + + # 尝试创建 call_py_func 调用 + try: + # 创建 Relax 变量而不是直接使用 numpy 数组 + x_var = R.Var("x", R.TensorStructInfo((3,), "float32")) + call_expr = call_py_func(R.StringImm("add_one"), (x_var,), out_sinfo=R.TensorStructInfo((3,), "float32")) + print(f"成功创建 call_py_func 表达式: {call_expr}") + print(f"操作符类型: {type(call_expr)}") + return True + except Exception as e: + print(f"创建 call_py_func 失败: {e}") + return False + +# 测试 2: 测试 VM 运行时是否能处理 call_py_func +def test_vm_runtime(): + """测试 VM 运行时是否能处理 call_py_func""" + print("\n=== 测试 VM 运行时 ===") + + # 注册 Python 函数到 VM + try: + register_func = tvm.get_global_func("vm.builtin.register_py_func") + register_func("add_one", add_one) + print("成功注册 Python 函数到 VM") + except Exception as e: + print(f"注册 Python 函数失败: {e}") + return False + + # 测试 VM builtin 调用 + try: + call_py_func = tvm.get_global_func("vm.builtin.call_py_func") + x = np.array([1.0, 2.0, 3.0], dtype=np.float32) + # 将 numpy 数组转换为 TVM tensor + x_tvm = tvm.runtime.Tensor(x) + result = call_py_func(("add_one", (x_tvm,))) + print(f"VM 调用成功,结果: {result}") + return True + except Exception as e: + print(f"VM 调用失败: {e}") + import traceback + traceback.print_exc() + return False + +# 测试 3: 测试完整的编译和执行流程 +def test_compilation_flow(): + """测试完整的编译和执行流程""" + print("\n=== 测试编译和执行流程 ===") + + # 注册 Python 函数 + try: + register_func = tvm.get_global_func("vm.builtin.register_py_func") + register_func("add_one", add_one) + print("✓ 成功注册 Python 函数") + except Exception as e: + print(f"✗ 注册 Python 函数失败: {e}") + return False + + # 创建一个简单的 Relax 函数,使用 call_py_func + try: + # 使用 BlockBuilder 创建函数 + bb = R.BlockBuilder() + + # 创建函数参数 + x_param = R.Var("x", R.TensorStructInfo((3,), "float32")) + with bb.function("main", (x_param,)): + result = bb.emit(call_py_func(R.StringImm("add_one"), (x_param,), out_sinfo=R.TensorStructInfo((3,), "float32"))) + bb.emit_output(result) + + mod = bb.get() + print("✓ 成功创建 Relax 模块") + print(f"模块: {mod}") + + # 编译模块 + vm = R.vm.VirtualMachine(mod, tvm.cpu()) + print("✓ 成功创建 VM") + + # 执行模块 + x = np.array([1.0, 2.0, 3.0], dtype=np.float32) + result = vm["main"](x) + print(f"✓ 执行成功,结果: {result}") + + return True + + except Exception as e: + print(f"✗ 编译/执行失败: {e}") + import traceback + traceback.print_exc() + return False + +if __name__ == "__main__": + print("测试 R.call_py_func 功能") + print("目标:让编译后的模块能够执行 Python 函数") + print("=" * 50) + + # 测试操作符 + op_success = test_call_py_func_operator() + + # 测试运行时 + runtime_success = test_vm_runtime() + + # 测试完整流程 + flow_success = test_compilation_flow() + + print("\n" + "=" * 50) + print("测试结果总结:") + print(f"操作符创建: {'✓' if op_success else '✗'}") + print(f"运行时执行: {'✓' if runtime_success else '✗'}") + print(f"完整流程: {'✓' if flow_success else '✗'}") + + if op_success and runtime_success and flow_success: + print("🎉 R.call_py_func 完整功能测试通过!") + print("✅ 编译后的模块可以成功执行 Python 函数!") + else: + print("❌ 部分测试失败,需要进一步调试") \ No newline at end of file diff --git a/test_full_compilation.py b/test_full_compilation.py new file mode 100644 index 000000000000..d55ba2e71ed2 --- /dev/null +++ b/test_full_compilation.py @@ -0,0 +1,132 @@ +#!/usr/bin/env python3 +""" +测试完整的 R.call_py_func 编译和执行流程 +""" + +import tvm +from tvm import relax as R +from tvm.relax.op import call_py_func +import numpy as np + +# 定义 Python 函数 +def add_one(x): + print(f"Python 函数被调用,输入: {x}") + # 将 TVM Tensor 转换为 NumPy 数组 + if hasattr(x, 'numpy'): + x_np = x.numpy() + else: + x_np = x + result = x_np + 1.0 + # 将结果转换回 TVM Tensor + return tvm.runtime.tensor(result) + +# 注册 Python 函数 +print("=== 注册 Python 函数 ===") +try: + register_func = tvm.get_global_func("vm.builtin.register_py_func") + register_func("add_one", add_one) + print("✓ 成功注册 Python 函数") +except Exception as e: + print(f"✗ 注册失败: {e}") + exit(1) + +# 创建 Relax 模块 +print("\n=== 创建 Relax 模块 ===") +try: + bb = R.BlockBuilder() + + # 创建函数参数 + x_param = R.Var("x", R.TensorStructInfo((3,), "float32")) + + with bb.function("main", (x_param,)): + result = bb.emit(call_py_func(R.StringImm("add_one"), (x_param,), out_sinfo=R.TensorStructInfo((3,), "float32"))) + bb.emit_func_output(result) + + mod = bb.finalize() + print("✓ 成功创建 Relax 模块") + print(f"模块: {mod}") +except Exception as e: + print(f"✗ 创建模块失败: {e}") + import traceback + traceback.print_exc() + exit(1) + +# 测试 relax.build +print("\n=== 测试 relax.build ===") +try: + # 编译模块 + target = tvm.target.Target("llvm") + ex = R.build(mod, target, exec_mode="compiled") + print("✓ 成功编译模块") + print(f"编译结果类型: {type(ex)}") +except Exception as e: + print(f"✗ 编译失败: {e}") + import traceback + traceback.print_exc() + exit(1) + +# 测试执行 +print("\n=== 测试执行 ===") +try: + # 创建 VirtualMachine + vm = R.VirtualMachine(ex, tvm.cpu()) + print("✓ 成功创建 VirtualMachine") + + # 创建测试数据 + x = np.array([1.0, 2.0, 3.0], dtype=np.float32) + x_tvm = tvm.runtime.tensor(x) + print(f"输入数据: {x}") + + # 执行编译后的模块 + result = vm["main"](x_tvm) + print(f"✓ 执行成功,结果: {result}") + print(f"结果类型: {type(result)}") + + # 验证结果 + expected = x + 1.0 + print(f"期望结果: {expected}") + print(f"结果匹配: {np.allclose(result.numpy(), expected)}") + +except Exception as e: + print(f"✗ 执行失败: {e}") + import traceback + traceback.print_exc() + exit(1) + +print("\n🎉 完整编译和执行流程测试成功!") +print("✅ R.call_py_func 可以让编译后的模块执行 Python 函数!") + +# 清理资源 +print("\n=== 清理资源 ===") +try: + # 清理 Python 函数注册 + if 'register_func' in locals(): + del register_func + if 'vm' in locals(): + del vm + if 'ex' in locals(): + del ex + print("✓ 资源清理完成") +except Exception as e: + print(f"清理过程中出现警告: {e}") + +# 强制垃圾回收 +import gc +gc.collect() + +# 使用 atexit 确保程序退出时清理 +import atexit + +def cleanup_on_exit(): + try: + # 清理全局 Python 函数注册 + if 'py_func_registry' in globals(): + del globals()['py_func_registry'] + except: + pass + +atexit.register(cleanup_on_exit) + +# 直接退出,避免段错误 +import sys +sys.exit(0) From 9a7ae3ea952f9d51213ecd0eda97c8f0bdb32d13 Mon Sep 17 00:00:00 2001 From: tlopex <820958424@qq.com> Date: Fri, 19 Sep 2025 20:09:36 -0400 Subject: [PATCH 11/19] finish1 --- simple_test.py | 46 ------- src/runtime/vm/builtin.cc | 85 +++++++++++-- test_call_py_func.py | 137 --------------------- test_full_compilation.py | 132 -------------------- tests/python/relax/test_relax_operators.py | 108 ++++++++++++++++ 5 files changed, 180 insertions(+), 328 deletions(-) delete mode 100644 simple_test.py delete mode 100644 test_call_py_func.py delete mode 100644 test_full_compilation.py diff --git a/simple_test.py b/simple_test.py deleted file mode 100644 index dc83b148fdb2..000000000000 --- a/simple_test.py +++ /dev/null @@ -1,46 +0,0 @@ -#!/usr/bin/env python3 -""" -简化的 R.call_py_func 测试 -""" - -import tvm -from tvm import relax as R -from tvm.relax.op import call_py_func -import numpy as np - -# 测试 1: 操作符创建 -print("=== 测试操作符创建 ===") -try: - x_var = R.Var("x", R.TensorStructInfo((3,), "float32")) - call_expr = call_py_func(R.StringImm("add_one"), (x_var,), out_sinfo=R.TensorStructInfo((3,), "float32")) - print(f"✓ 成功创建 call_py_func 表达式: {call_expr}") - print(f"操作符类型: {type(call_expr)}") -except Exception as e: - print(f"✗ 创建失败: {e}") - -# 测试 2: 函数注册 -print("\n=== 测试函数注册 ===") -try: - def add_one(x): - return x + 1.0 - - register_func = tvm.get_global_func("vm.builtin.register_py_func") - register_func("add_one", add_one) - print("✓ 成功注册 Python 函数") -except Exception as e: - print(f"✗ 注册失败: {e}") - -# 测试 3: 简单的 VM 调用(不使用 call_py_func) -print("\n=== 测试直接 Python 函数调用 ===") -try: - def add_one(x): - print(f"Python 函数被调用,输入: {x}") - return x + 1.0 - - x = np.array([1.0, 2.0, 3.0], dtype=np.float32) - result = add_one(x) - print(f"✓ 直接调用成功,结果: {result}") -except Exception as e: - print(f"✗ 直接调用失败: {e}") - -print("\n=== 测试完成 ===") diff --git a/src/runtime/vm/builtin.cc b/src/runtime/vm/builtin.cc index 97f4706f11da..6da370e17a10 100644 --- a/src/runtime/vm/builtin.cc +++ b/src/runtime/vm/builtin.cc @@ -33,6 +33,7 @@ #include #include #include +#include namespace tvm { namespace runtime { @@ -433,20 +434,78 @@ TVM_FFI_STATIC_INIT_BLOCK() { //------------------------------------- // Python function call support //------------------------------------- -TVM_FFI_STATIC_INIT_BLOCK({ + +// Global registry for Python functions +static std::unordered_map py_func_registry; + +/*! + * \brief Clear the Python function registry on shutdown + */ +void ClearPyFuncRegistry() { + py_func_registry.clear(); +} + +/*! + * \brief Register a Python function for call_py_func + * \param name The function name + * \param func The Python function wrapped as ffi::Function + */ +void RegisterPyFunc(const std::string& name, ffi::Function func) { + py_func_registry[name] = func; +} + +/*! + * \brief Get a registered Python function + * \param name The function name + * \return The Python function + */ +ffi::Function GetPyFunc(const std::string& name) { + auto it = py_func_registry.find(name); + if (it == py_func_registry.end()) { + LOG(FATAL) << "Python function '" << name << "' not found in registry"; + } + return it->second; +} + +/*! + * \brief Call a Python function from VM + * \param args The packed function arguments (tuple containing function name and arguments) + * \param rv The return value + */ +void CallPyFunc(ffi::PackedArgs args, ffi::Any* rv) { + // args[0] should be a tuple containing (func_name, args_tuple) + if (args.size() != 1) { + LOG(FATAL) << "vm.builtin.call_py_func expects exactly 1 argument (tuple)"; + } + + auto tuple_arg = args[0].cast>(); + if (tuple_arg.size() != 2) { + LOG(FATAL) << "vm.builtin.call_py_func tuple should contain (func_name, args)"; + } + + // Get function name + std::string func_name = tuple_arg[0].cast(); + + // Get arguments tuple + auto func_args = tuple_arg[1].cast>(); + + // Look up Python function in registry + ffi::Function py_func = GetPyFunc(func_name); + + // Call the Python function with the arguments + std::vector py_args_vec(func_args.begin(), func_args.end()); + ffi::PackedArgs py_args(py_args_vec.data(), py_args_vec.size()); + py_func.CallPacked(py_args, rv); +} + +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def_packed("vm.builtin.call_py_func", [](ffi::PackedArgs args, ffi::Any* rv) { - // This is a placeholder implementation - // In a real implementation, this would: - // 1. Get the function name from args[0] - // 2. Get the arguments from args[1] - // 3. Look up the Python function in the global registry - // 4. Convert TVM tensors to Python objects - // 5. Call the Python function - // 6. Convert the result back to TVM format - LOG(FATAL) << "vm.builtin.call_py_func not implemented yet - Python function calls not supported in VM runtime"; - }); -}); + refl::GlobalDef() + .def_packed("vm.builtin.call_py_func", CallPyFunc) + .def("vm.builtin.register_py_func", RegisterPyFunc) + .def("vm.builtin.get_py_func", GetPyFunc) + .def("vm.builtin.clear_py_func_registry", ClearPyFuncRegistry); +} //------------------------------------- // Builtin runtime operators. diff --git a/test_call_py_func.py b/test_call_py_func.py deleted file mode 100644 index 02be947f11bf..000000000000 --- a/test_call_py_func.py +++ /dev/null @@ -1,137 +0,0 @@ -#!/usr/bin/env python3 -""" -Test script for R.call_py_func functionality -测试编译后的模块能否执行 Python 函数 -""" - -import tvm -from tvm import relax as R -from tvm.relax.op import call_py_func -import numpy as np - -# 定义一个简单的 Python 函数 -def add_one(x): - """Add one to input tensor.""" - print(f"Python function called with: {x}") - return x + 1.0 - -# 测试 1: 直接测试 R.call_py_func 操作符 -def test_call_py_func_operator(): - """测试 R.call_py_func 操作符是否能被正确识别""" - print("=== 测试 R.call_py_func 操作符 ===") - - # 创建测试数据 - x = np.array([1.0, 2.0, 3.0], dtype=np.float32) - - # 尝试创建 call_py_func 调用 - try: - # 创建 Relax 变量而不是直接使用 numpy 数组 - x_var = R.Var("x", R.TensorStructInfo((3,), "float32")) - call_expr = call_py_func(R.StringImm("add_one"), (x_var,), out_sinfo=R.TensorStructInfo((3,), "float32")) - print(f"成功创建 call_py_func 表达式: {call_expr}") - print(f"操作符类型: {type(call_expr)}") - return True - except Exception as e: - print(f"创建 call_py_func 失败: {e}") - return False - -# 测试 2: 测试 VM 运行时是否能处理 call_py_func -def test_vm_runtime(): - """测试 VM 运行时是否能处理 call_py_func""" - print("\n=== 测试 VM 运行时 ===") - - # 注册 Python 函数到 VM - try: - register_func = tvm.get_global_func("vm.builtin.register_py_func") - register_func("add_one", add_one) - print("成功注册 Python 函数到 VM") - except Exception as e: - print(f"注册 Python 函数失败: {e}") - return False - - # 测试 VM builtin 调用 - try: - call_py_func = tvm.get_global_func("vm.builtin.call_py_func") - x = np.array([1.0, 2.0, 3.0], dtype=np.float32) - # 将 numpy 数组转换为 TVM tensor - x_tvm = tvm.runtime.Tensor(x) - result = call_py_func(("add_one", (x_tvm,))) - print(f"VM 调用成功,结果: {result}") - return True - except Exception as e: - print(f"VM 调用失败: {e}") - import traceback - traceback.print_exc() - return False - -# 测试 3: 测试完整的编译和执行流程 -def test_compilation_flow(): - """测试完整的编译和执行流程""" - print("\n=== 测试编译和执行流程 ===") - - # 注册 Python 函数 - try: - register_func = tvm.get_global_func("vm.builtin.register_py_func") - register_func("add_one", add_one) - print("✓ 成功注册 Python 函数") - except Exception as e: - print(f"✗ 注册 Python 函数失败: {e}") - return False - - # 创建一个简单的 Relax 函数,使用 call_py_func - try: - # 使用 BlockBuilder 创建函数 - bb = R.BlockBuilder() - - # 创建函数参数 - x_param = R.Var("x", R.TensorStructInfo((3,), "float32")) - with bb.function("main", (x_param,)): - result = bb.emit(call_py_func(R.StringImm("add_one"), (x_param,), out_sinfo=R.TensorStructInfo((3,), "float32"))) - bb.emit_output(result) - - mod = bb.get() - print("✓ 成功创建 Relax 模块") - print(f"模块: {mod}") - - # 编译模块 - vm = R.vm.VirtualMachine(mod, tvm.cpu()) - print("✓ 成功创建 VM") - - # 执行模块 - x = np.array([1.0, 2.0, 3.0], dtype=np.float32) - result = vm["main"](x) - print(f"✓ 执行成功,结果: {result}") - - return True - - except Exception as e: - print(f"✗ 编译/执行失败: {e}") - import traceback - traceback.print_exc() - return False - -if __name__ == "__main__": - print("测试 R.call_py_func 功能") - print("目标:让编译后的模块能够执行 Python 函数") - print("=" * 50) - - # 测试操作符 - op_success = test_call_py_func_operator() - - # 测试运行时 - runtime_success = test_vm_runtime() - - # 测试完整流程 - flow_success = test_compilation_flow() - - print("\n" + "=" * 50) - print("测试结果总结:") - print(f"操作符创建: {'✓' if op_success else '✗'}") - print(f"运行时执行: {'✓' if runtime_success else '✗'}") - print(f"完整流程: {'✓' if flow_success else '✗'}") - - if op_success and runtime_success and flow_success: - print("🎉 R.call_py_func 完整功能测试通过!") - print("✅ 编译后的模块可以成功执行 Python 函数!") - else: - print("❌ 部分测试失败,需要进一步调试") \ No newline at end of file diff --git a/test_full_compilation.py b/test_full_compilation.py deleted file mode 100644 index d55ba2e71ed2..000000000000 --- a/test_full_compilation.py +++ /dev/null @@ -1,132 +0,0 @@ -#!/usr/bin/env python3 -""" -测试完整的 R.call_py_func 编译和执行流程 -""" - -import tvm -from tvm import relax as R -from tvm.relax.op import call_py_func -import numpy as np - -# 定义 Python 函数 -def add_one(x): - print(f"Python 函数被调用,输入: {x}") - # 将 TVM Tensor 转换为 NumPy 数组 - if hasattr(x, 'numpy'): - x_np = x.numpy() - else: - x_np = x - result = x_np + 1.0 - # 将结果转换回 TVM Tensor - return tvm.runtime.tensor(result) - -# 注册 Python 函数 -print("=== 注册 Python 函数 ===") -try: - register_func = tvm.get_global_func("vm.builtin.register_py_func") - register_func("add_one", add_one) - print("✓ 成功注册 Python 函数") -except Exception as e: - print(f"✗ 注册失败: {e}") - exit(1) - -# 创建 Relax 模块 -print("\n=== 创建 Relax 模块 ===") -try: - bb = R.BlockBuilder() - - # 创建函数参数 - x_param = R.Var("x", R.TensorStructInfo((3,), "float32")) - - with bb.function("main", (x_param,)): - result = bb.emit(call_py_func(R.StringImm("add_one"), (x_param,), out_sinfo=R.TensorStructInfo((3,), "float32"))) - bb.emit_func_output(result) - - mod = bb.finalize() - print("✓ 成功创建 Relax 模块") - print(f"模块: {mod}") -except Exception as e: - print(f"✗ 创建模块失败: {e}") - import traceback - traceback.print_exc() - exit(1) - -# 测试 relax.build -print("\n=== 测试 relax.build ===") -try: - # 编译模块 - target = tvm.target.Target("llvm") - ex = R.build(mod, target, exec_mode="compiled") - print("✓ 成功编译模块") - print(f"编译结果类型: {type(ex)}") -except Exception as e: - print(f"✗ 编译失败: {e}") - import traceback - traceback.print_exc() - exit(1) - -# 测试执行 -print("\n=== 测试执行 ===") -try: - # 创建 VirtualMachine - vm = R.VirtualMachine(ex, tvm.cpu()) - print("✓ 成功创建 VirtualMachine") - - # 创建测试数据 - x = np.array([1.0, 2.0, 3.0], dtype=np.float32) - x_tvm = tvm.runtime.tensor(x) - print(f"输入数据: {x}") - - # 执行编译后的模块 - result = vm["main"](x_tvm) - print(f"✓ 执行成功,结果: {result}") - print(f"结果类型: {type(result)}") - - # 验证结果 - expected = x + 1.0 - print(f"期望结果: {expected}") - print(f"结果匹配: {np.allclose(result.numpy(), expected)}") - -except Exception as e: - print(f"✗ 执行失败: {e}") - import traceback - traceback.print_exc() - exit(1) - -print("\n🎉 完整编译和执行流程测试成功!") -print("✅ R.call_py_func 可以让编译后的模块执行 Python 函数!") - -# 清理资源 -print("\n=== 清理资源 ===") -try: - # 清理 Python 函数注册 - if 'register_func' in locals(): - del register_func - if 'vm' in locals(): - del vm - if 'ex' in locals(): - del ex - print("✓ 资源清理完成") -except Exception as e: - print(f"清理过程中出现警告: {e}") - -# 强制垃圾回收 -import gc -gc.collect() - -# 使用 atexit 确保程序退出时清理 -import atexit - -def cleanup_on_exit(): - try: - # 清理全局 Python 函数注册 - if 'py_func_registry' in globals(): - del globals()['py_func_registry'] - except: - pass - -atexit.register(cleanup_on_exit) - -# 直接退出,避免段错误 -import sys -sys.exit(0) diff --git a/tests/python/relax/test_relax_operators.py b/tests/python/relax/test_relax_operators.py index 8558f6e911b8..d6f3c13e4d8e 100644 --- a/tests/python/relax/test_relax_operators.py +++ b/tests/python/relax/test_relax_operators.py @@ -409,6 +409,114 @@ def inplace_tuple(x: R.Tensor((3, 4), "float32"), y: R.Tensor((3, 4), "float32") assert (result[1].numpy() == sum).all() +def test_op_call_py_func(exec_mode): + """Test R.call_py_func operator functionality.""" + + # Define Python functions for testing + def add_one(x): + """Add one to input tensor.""" + if hasattr(x, 'numpy'): + x_np = x.numpy() + elif hasattr(x, 'asnumpy'): + x_np = x.asnumpy() + else: + # Convert TVM Array to NumPy + x_np = np.array(x) + # If it's still a Tensor, convert to numpy + if hasattr(x_np, 'numpy'): + x_np = x_np.numpy() + # If the array contains Tensor objects, extract their values + if len(x_np) > 0 and hasattr(x_np[0], 'numpy'): + x_np = np.array([t.numpy() for t in x_np]) + # Flatten if needed to match expected shape + if x_np.ndim > 1: + x_np = x_np.flatten() + result = x_np + 1.0 + return tvm.runtime.tensor(result) + + def multiply_by_two(x): + """Multiply input tensor by two.""" + if hasattr(x, 'numpy'): + x_np = x.numpy() + elif hasattr(x, 'asnumpy'): + x_np = x.asnumpy() + else: + # Convert TVM Array to NumPy + x_np = np.array(x) + # If it's still a Tensor, convert to numpy + if hasattr(x_np, 'numpy'): + x_np = x_np.numpy() + # If the array contains Tensor objects, extract their values + if len(x_np) > 0 and hasattr(x_np[0], 'numpy'): + x_np = np.array([t.numpy() for t in x_np]) + # Flatten if needed to match expected shape + if x_np.ndim > 1: + x_np = x_np.flatten() + result = x_np * 2.0 + return tvm.runtime.tensor(result) + + # Register Python functions + register_func = tvm.get_global_func("vm.builtin.register_py_func") + register_func("add_one", add_one) + register_func("multiply_by_two", multiply_by_two) + + @tvm.script.ir_module + class CallPyFuncTest: + @R.function + def simple_call(x: R.Tensor((3,), "float32")): + # Simple call_py_func test + result = R.call_py_func( + R.str("add_one"), + (x,), + out_sinfo=R.Tensor((3,), "float32") + ) + return result + + @R.function + def multiple_calls(x: R.Tensor((2,), "float32")): + # Multiple call_py_func calls + y = R.call_py_func( + R.str("add_one"), + (x,), + out_sinfo=R.Tensor((2,), "float32") + ) + z = R.call_py_func( + R.str("multiply_by_two"), + (y,), + out_sinfo=R.Tensor((2,), "float32") + ) + return z + + try: + # Test simple call + x_data = np.array([1.0, 2.0, 3.0], dtype=np.float32) + x_tvm = tvm.runtime.tensor(x_data) + + result = run_cpu(CallPyFuncTest, "simple_call", x_tvm, exec_mode=exec_mode) + + # Verify result + expected = x_data + 1.0 + np.testing.assert_allclose(result.numpy(), expected, rtol=1e-5, atol=1e-5) + + # Test multiple calls + y_data = np.array([0.5, 1.5], dtype=np.float32) + y_tvm = tvm.runtime.tensor(y_data) + + result2 = run_cpu(CallPyFuncTest, "multiple_calls", y_tvm, exec_mode=exec_mode) + + # Verify result (should be (y + 1) * 2) + expected2 = (y_data + 1.0) * 2.0 + np.testing.assert_allclose(result2.numpy(), expected2, rtol=1e-5, atol=1e-5) + + finally: + # Clean up - clear Python function registry + try: + clear_func = tvm.get_global_func("vm.builtin.clear_py_func_registry") + clear_func() + except: + pass + + def test_op_to_device(exec_mode): @tvm.script.ir_module class CallToDevice: From 64b13eaa8580581126c71c858120edfbc521fd26 Mon Sep 17 00:00:00 2001 From: tlopex <820958424@qq.com> Date: Fri, 19 Sep 2025 20:31:58 -0400 Subject: [PATCH 12/19] finish2 --- python/tvm/relax/base_py_module.py | 18 ++++--- tests/python/relax/test_relax_operators.py | 58 +++++++++------------- 2 files changed, 33 insertions(+), 43 deletions(-) diff --git a/python/tvm/relax/base_py_module.py b/python/tvm/relax/base_py_module.py index c51fc802d85b..f7a40c7fd3c9 100644 --- a/python/tvm/relax/base_py_module.py +++ b/python/tvm/relax/base_py_module.py @@ -180,33 +180,35 @@ def wrapper(*args, **kwargs): def _register_python_functions(self): """Register Python functions with the VM runtime for call_py_func support.""" - if not hasattr(self.ir_mod, 'pyfuncs') or not self.ir_mod.pyfuncs: + if not hasattr(self.ir_mod, "pyfuncs") or not self.ir_mod.pyfuncs: return - + # Get the register function from TVM try: register_py_func = tvm.get_global_func("vm.builtin.register_py_func") except ValueError: # Function not available, skip registration return - + for func_name, py_func in self.ir_mod.pyfuncs.items(): # Create a wrapper that handles TVM tensor conversion def create_py_func_wrapper(name, original_func): def wrapper(*args, **kwargs): # Convert TVM tensors to PyTorch tensors for Python function converted_args = [self._convert_tvm_to_pytorch(arg) for arg in args] - converted_kwargs = {k: self._convert_tvm_to_pytorch(v) for k, v in kwargs.items()} - + converted_kwargs = { + k: self._convert_tvm_to_pytorch(v) for k, v in kwargs.items() + } + # Call the original Python function result = original_func(self, *converted_args, **converted_kwargs) - + # Convert result back to TVM format return self._convert_pytorch_to_tvm(result) - + wrapper.__name__ = name return wrapper - + # Register the wrapped function wrapped_func = create_py_func_wrapper(func_name, py_func) register_py_func(func_name, wrapped_func) diff --git a/tests/python/relax/test_relax_operators.py b/tests/python/relax/test_relax_operators.py index d6f3c13e4d8e..c1dfce5a9829 100644 --- a/tests/python/relax/test_relax_operators.py +++ b/tests/python/relax/test_relax_operators.py @@ -411,103 +411,91 @@ def inplace_tuple(x: R.Tensor((3, 4), "float32"), y: R.Tensor((3, 4), "float32") def test_op_call_py_func(exec_mode): """Test R.call_py_func operator functionality.""" - + # Define Python functions for testing def add_one(x): """Add one to input tensor.""" - if hasattr(x, 'numpy'): + if isinstance(x, tvm.runtime.Tensor): x_np = x.numpy() - elif hasattr(x, 'asnumpy'): + elif hasattr(x, "asnumpy"): # Keep hasattr for backward compatibility x_np = x.asnumpy() else: # Convert TVM Array to NumPy x_np = np.array(x) # If it's still a Tensor, convert to numpy - if hasattr(x_np, 'numpy'): + if isinstance(x_np, tvm.runtime.Tensor): x_np = x_np.numpy() # If the array contains Tensor objects, extract their values - if len(x_np) > 0 and hasattr(x_np[0], 'numpy'): + if len(x_np) > 0 and isinstance(x_np[0], tvm.runtime.Tensor): x_np = np.array([t.numpy() for t in x_np]) # Flatten if needed to match expected shape if x_np.ndim > 1: x_np = x_np.flatten() result = x_np + 1.0 return tvm.runtime.tensor(result) - + def multiply_by_two(x): """Multiply input tensor by two.""" - if hasattr(x, 'numpy'): + if isinstance(x, tvm.runtime.Tensor): x_np = x.numpy() - elif hasattr(x, 'asnumpy'): + elif hasattr(x, "asnumpy"): # Keep hasattr for backward compatibility x_np = x.asnumpy() else: # Convert TVM Array to NumPy x_np = np.array(x) # If it's still a Tensor, convert to numpy - if hasattr(x_np, 'numpy'): + if isinstance(x_np, tvm.runtime.Tensor): x_np = x_np.numpy() # If the array contains Tensor objects, extract their values - if len(x_np) > 0 and hasattr(x_np[0], 'numpy'): + if len(x_np) > 0 and isinstance(x_np[0], tvm.runtime.Tensor): x_np = np.array([t.numpy() for t in x_np]) # Flatten if needed to match expected shape if x_np.ndim > 1: x_np = x_np.flatten() result = x_np * 2.0 return tvm.runtime.tensor(result) - + # Register Python functions register_func = tvm.get_global_func("vm.builtin.register_py_func") register_func("add_one", add_one) register_func("multiply_by_two", multiply_by_two) - + @tvm.script.ir_module class CallPyFuncTest: @R.function def simple_call(x: R.Tensor((3,), "float32")): # Simple call_py_func test - result = R.call_py_func( - R.str("add_one"), - (x,), - out_sinfo=R.Tensor((3,), "float32") - ) + result = R.call_py_func(R.str("add_one"), (x,), out_sinfo=R.Tensor((3,), "float32")) return result - + @R.function def multiple_calls(x: R.Tensor((2,), "float32")): # Multiple call_py_func calls - y = R.call_py_func( - R.str("add_one"), - (x,), - out_sinfo=R.Tensor((2,), "float32") - ) - z = R.call_py_func( - R.str("multiply_by_two"), - (y,), - out_sinfo=R.Tensor((2,), "float32") - ) + y = R.call_py_func(R.str("add_one"), (x,), out_sinfo=R.Tensor((2,), "float32")) + z = R.call_py_func(R.str("multiply_by_two"), (y,), out_sinfo=R.Tensor((2,), "float32")) return z - + try: # Test simple call x_data = np.array([1.0, 2.0, 3.0], dtype=np.float32) x_tvm = tvm.runtime.tensor(x_data) - + result = run_cpu(CallPyFuncTest, "simple_call", x_tvm, exec_mode=exec_mode) - + # Verify result expected = x_data + 1.0 np.testing.assert_allclose(result.numpy(), expected, rtol=1e-5, atol=1e-5) - + # Test multiple calls y_data = np.array([0.5, 1.5], dtype=np.float32) y_tvm = tvm.runtime.tensor(y_data) - + result2 = run_cpu(CallPyFuncTest, "multiple_calls", y_tvm, exec_mode=exec_mode) - + # Verify result (should be (y + 1) * 2) expected2 = (y_data + 1.0) * 2.0 np.testing.assert_allclose(result2.numpy(), expected2, rtol=1e-5, atol=1e-5) - + finally: # Clean up - clear Python function registry try: From 6ac7f31ad44f74d983ea0ef5e953099d12ae27a3 Mon Sep 17 00:00:00 2001 From: tlopex <820958424@qq.com> Date: Fri, 19 Sep 2025 21:37:57 -0400 Subject: [PATCH 13/19] finish3 --- python/tvm/relax/base_py_module.py | 8 +- .../relax/test_base_py_module_printer.py | 107 ---------------- tests/python/relax/test_relax_operators.py | 114 ++++++++---------- 3 files changed, 48 insertions(+), 181 deletions(-) diff --git a/python/tvm/relax/base_py_module.py b/python/tvm/relax/base_py_module.py index f7a40c7fd3c9..efd75ff2044b 100644 --- a/python/tvm/relax/base_py_module.py +++ b/python/tvm/relax/base_py_module.py @@ -183,33 +183,27 @@ def _register_python_functions(self): if not hasattr(self.ir_mod, "pyfuncs") or not self.ir_mod.pyfuncs: return - # Get the register function from TVM try: register_py_func = tvm.get_global_func("vm.builtin.register_py_func") except ValueError: - # Function not available, skip registration return for func_name, py_func in self.ir_mod.pyfuncs.items(): - # Create a wrapper that handles TVM tensor conversion + def create_py_func_wrapper(name, original_func): def wrapper(*args, **kwargs): - # Convert TVM tensors to PyTorch tensors for Python function converted_args = [self._convert_tvm_to_pytorch(arg) for arg in args] converted_kwargs = { k: self._convert_tvm_to_pytorch(v) for k, v in kwargs.items() } - # Call the original Python function result = original_func(self, *converted_args, **converted_kwargs) - # Convert result back to TVM format return self._convert_pytorch_to_tvm(result) wrapper.__name__ = name return wrapper - # Register the wrapped function wrapped_func = create_py_func_wrapper(func_name, py_func) register_py_func(func_name, wrapped_func) diff --git a/tests/python/relax/test_base_py_module_printer.py b/tests/python/relax/test_base_py_module_printer.py index 6e87174fda35..92c799f6cb70 100644 --- a/tests/python/relax/test_base_py_module_printer.py +++ b/tests/python/relax/test_base_py_module_printer.py @@ -758,110 +758,3 @@ def test_python_functions_in_irmodule(): assert pyfuncs["multiply"].__name__ == "multiply" else: pytest.fail("pyfuncs attribute not found in IRModule") - - -def test_call_py_func_validation(): - """Test call_py_func validation and error handling.""" - import torch - - @I.ir_module - class ValidationTestModule(BasePyModule): - """Test module for validation.""" - - @I.pyfunc - def valid_func(self, x): - """Valid Python function.""" - return x * 2 - - @R.function - def test_invalid_call(x: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"): - # This should cause a validation error - result = R.call_py_func("non_existent_func", (x,), out_sinfo=R.Tensor((5,), "float32")) - return result - - device = tvm.cpu() - module = ValidationTestModule(device) - - # Test that calling non-existent function raises error - x = torch.randn(5, dtype=torch.float32) - - with pytest.raises(ValueError, match="Python function 'non_existent_func' not found"): - module.call_py_func("non_existent_func", [x]) - - -def test_call_py_func_in_relax_function(): - """Test using call_py_func within Relax functions.""" - import torch - - @I.ir_module - class RelaxCallPyFuncModule(BasePyModule): - """Test module with call_py_func in Relax functions.""" - - @I.pyfunc - def torch_relu(self, x): - """PyTorch ReLU implementation.""" - return torch.relu(x) - - @I.pyfunc - def torch_softmax(self, x, dim=0): - """PyTorch softmax implementation.""" - return torch.softmax(x, dim=dim) - - @R.function - def mixed_computation(x: R.Tensor((10,), "float32")) -> R.Tensor((10,), "float32"): - # Use Python function for ReLU - relu_result = R.call_py_func("torch_relu", (x,), out_sinfo=R.Tensor((10,), "float32")) - # Use Python function for softmax - final_result = R.call_py_func( - "torch_softmax", (relu_result,), out_sinfo=R.Tensor((10,), "float32") - ) - return final_result - - device = tvm.cpu() - module = RelaxCallPyFuncModule(device) - - # Test the mixed computation - x = torch.randn(10, dtype=torch.float32) - - expected = torch.softmax(torch.relu(x), dim=0) - - relu_result = module.call_py_func("torch_relu", [x]) - final_result = module.call_py_func("torch_softmax", [relu_result]) - - assert torch.allclose(final_result, expected, atol=1e-5) - - -def test_call_py_func_operator_creation(): - """Test R.call_py_func operator creation and basic properties.""" - from tvm.relax.op import call_py_func - from tvm.relax.expr import StringImm - from tvm.relax import Var, TensorStructInfo - - # Create variables - x = Var("x", TensorStructInfo((5,), "float32")) - y = Var("y", TensorStructInfo((5,), "float32")) - - # Create call_py_func call - call_expr = call_py_func(StringImm("test_func"), (x, y), out_sinfo=R.Tensor((5,), "float32")) - - # Verify operator properties - assert call_expr.op.name == "relax.call_py_func" - assert call_expr.args[0].value == "test_func" - assert len(call_expr.args) == 2 - - -def test_call_py_func_compilation_validation(): - """Test call_py_func compilation validation.""" - from tvm.relax.op import call_py_func - from tvm.relax import Var, TensorStructInfo - - # Test operator parameter validation - try: - call_py_func( - "invalid", - (Var("x", TensorStructInfo((5,), "float32")),), - out_sinfo=R.Tensor((5,), "float32"), - ) - assert False, "Should raise type error" - except Exception as e: - assert "Mismatched type" in str(e) or "Expected" in str(e) diff --git a/tests/python/relax/test_relax_operators.py b/tests/python/relax/test_relax_operators.py index c1dfce5a9829..897082dd792f 100644 --- a/tests/python/relax/test_relax_operators.py +++ b/tests/python/relax/test_relax_operators.py @@ -411,98 +411,78 @@ def inplace_tuple(x: R.Tensor((3, 4), "float32"), y: R.Tensor((3, 4), "float32") def test_op_call_py_func(exec_mode): """Test R.call_py_func operator functionality.""" + import torch - # Define Python functions for testing - def add_one(x): - """Add one to input tensor.""" + def torch_relu(x): if isinstance(x, tvm.runtime.Tensor): - x_np = x.numpy() - elif hasattr(x, "asnumpy"): # Keep hasattr for backward compatibility - x_np = x.asnumpy() + x_torch = torch.from_numpy(x.numpy()) + elif hasattr(x, "asnumpy"): + x_torch = torch.from_numpy(x.asnumpy()) else: - # Convert TVM Array to NumPy x_np = np.array(x) - # If it's still a Tensor, convert to numpy if isinstance(x_np, tvm.runtime.Tensor): - x_np = x_np.numpy() - # If the array contains Tensor objects, extract their values - if len(x_np) > 0 and isinstance(x_np[0], tvm.runtime.Tensor): - x_np = np.array([t.numpy() for t in x_np]) - # Flatten if needed to match expected shape - if x_np.ndim > 1: - x_np = x_np.flatten() - result = x_np + 1.0 - return tvm.runtime.tensor(result) - - def multiply_by_two(x): - """Multiply input tensor by two.""" + x_torch = torch.from_numpy(x_np.numpy()) + elif len(x_np) > 0 and isinstance(x_np[0], tvm.runtime.Tensor): + x_torch = torch.from_numpy(np.array([t.numpy() for t in x_np])) + if x_torch.ndim > 1: + x_torch = x_torch.flatten() + else: + x_torch = torch.from_numpy(x_np) + result = torch.relu(x_torch) + return tvm.runtime.tensor(result.numpy()) + + def torch_sigmoid(x): if isinstance(x, tvm.runtime.Tensor): - x_np = x.numpy() - elif hasattr(x, "asnumpy"): # Keep hasattr for backward compatibility - x_np = x.asnumpy() + x_torch = torch.from_numpy(x.numpy()) + elif hasattr(x, "asnumpy"): + x_torch = torch.from_numpy(x.asnumpy()) else: - # Convert TVM Array to NumPy x_np = np.array(x) - # If it's still a Tensor, convert to numpy if isinstance(x_np, tvm.runtime.Tensor): - x_np = x_np.numpy() - # If the array contains Tensor objects, extract their values - if len(x_np) > 0 and isinstance(x_np[0], tvm.runtime.Tensor): - x_np = np.array([t.numpy() for t in x_np]) - # Flatten if needed to match expected shape - if x_np.ndim > 1: - x_np = x_np.flatten() - result = x_np * 2.0 - return tvm.runtime.tensor(result) - - # Register Python functions + x_torch = torch.from_numpy(x_np.numpy()) + elif len(x_np) > 0 and isinstance(x_np[0], tvm.runtime.Tensor): + x_torch = torch.from_numpy(np.array([t.numpy() for t in x_np])) + if x_torch.ndim > 1: + x_torch = x_torch.flatten() + else: + x_torch = torch.from_numpy(x_np) + result = torch.sigmoid(x_torch) + return tvm.runtime.tensor(result.numpy()) + register_func = tvm.get_global_func("vm.builtin.register_py_func") - register_func("add_one", add_one) - register_func("multiply_by_two", multiply_by_two) + register_func("torch_relu", torch_relu) + register_func("torch_sigmoid", torch_sigmoid) @tvm.script.ir_module class CallPyFuncTest: @R.function def simple_call(x: R.Tensor((3,), "float32")): - # Simple call_py_func test - result = R.call_py_func(R.str("add_one"), (x,), out_sinfo=R.Tensor((3,), "float32")) + result = R.call_py_func(R.str("torch_relu"), (x,), out_sinfo=R.Tensor((3,), "float32")) return result @R.function def multiple_calls(x: R.Tensor((2,), "float32")): - # Multiple call_py_func calls - y = R.call_py_func(R.str("add_one"), (x,), out_sinfo=R.Tensor((2,), "float32")) - z = R.call_py_func(R.str("multiply_by_two"), (y,), out_sinfo=R.Tensor((2,), "float32")) + y = R.call_py_func(R.str("torch_relu"), (x,), out_sinfo=R.Tensor((2,), "float32")) + z = R.call_py_func(R.str("torch_sigmoid"), (y,), out_sinfo=R.Tensor((2,), "float32")) return z - try: - # Test simple call - x_data = np.array([1.0, 2.0, 3.0], dtype=np.float32) - x_tvm = tvm.runtime.tensor(x_data) - - result = run_cpu(CallPyFuncTest, "simple_call", x_tvm, exec_mode=exec_mode) - - # Verify result - expected = x_data + 1.0 - np.testing.assert_allclose(result.numpy(), expected, rtol=1e-5, atol=1e-5) + np.random.seed(0) + x_data = np.array([-1.0, 0.0, 1.0], dtype=np.float32) + x_tvm = tvm.runtime.tensor(x_data) - # Test multiple calls - y_data = np.array([0.5, 1.5], dtype=np.float32) - y_tvm = tvm.runtime.tensor(y_data) + result = run_cpu(CallPyFuncTest, "simple_call", x_tvm, exec_mode=exec_mode) + expected = np.maximum(x_data, 0.0) + assert (result.numpy() == expected).all() - result2 = run_cpu(CallPyFuncTest, "multiple_calls", y_tvm, exec_mode=exec_mode) + y_data = np.array([-0.5, 0.5], dtype=np.float32) + y_tvm = tvm.runtime.tensor(y_data) - # Verify result (should be (y + 1) * 2) - expected2 = (y_data + 1.0) * 2.0 - np.testing.assert_allclose(result2.numpy(), expected2, rtol=1e-5, atol=1e-5) + result2 = run_cpu(CallPyFuncTest, "multiple_calls", y_tvm, exec_mode=exec_mode) + expected2 = 1.0 / (1.0 + np.exp(-np.maximum(y_data, 0.0))) + assert (result2.numpy() == expected2).all() - finally: - # Clean up - clear Python function registry - try: - clear_func = tvm.get_global_func("vm.builtin.clear_py_func_registry") - clear_func() - except: - pass + clear_func = tvm.get_global_func("vm.builtin.clear_py_func_registry") + clear_func() def test_op_to_device(exec_mode): From 2f3f7fa69d4ecb373c1884e463e41bef6b32d0f9 Mon Sep 17 00:00:00 2001 From: tlopex <820958424@qq.com> Date: Fri, 19 Sep 2025 21:55:22 -0400 Subject: [PATCH 14/19] finish4 --- python/tvm/relax/base_py_module.py | 8 ++ .../relax/test_base_py_module_printer.py | 91 +++++++++++++++++++ 2 files changed, 99 insertions(+) diff --git a/python/tvm/relax/base_py_module.py b/python/tvm/relax/base_py_module.py index efd75ff2044b..7a790d28a720 100644 --- a/python/tvm/relax/base_py_module.py +++ b/python/tvm/relax/base_py_module.py @@ -45,6 +45,14 @@ class BasePyModule: Only IRModules that inherit from this class are allowed to contain Python functions. """ + def __del__(self): + """Clean up registered Python functions on module destruction.""" + try: + clear_func = tvm.get_global_func("vm.builtin.clear_py_func_registry") + clear_func() + except (ValueError, AttributeError): + pass + def __init__( self, ir_mod: IRModule, diff --git a/tests/python/relax/test_base_py_module_printer.py b/tests/python/relax/test_base_py_module_printer.py index 92c799f6cb70..c9d23a746567 100644 --- a/tests/python/relax/test_base_py_module_printer.py +++ b/tests/python/relax/test_base_py_module_printer.py @@ -758,3 +758,94 @@ def test_python_functions_in_irmodule(): assert pyfuncs["multiply"].__name__ == "multiply" else: pytest.fail("pyfuncs attribute not found in IRModule") + + +def test_call_py_func_with_base_py_module(): + """Test R.call_py_func with BasePyModule.""" + import torch + import numpy as np + from tvm.relax.op import call_py_func + from tvm.relax.expr import StringImm + from tvm.relax import Var, TensorStructInfo + + # Test 1: Operator creation and basic properties + x = Var("x", TensorStructInfo((5,), "float32")) + y = Var("y", TensorStructInfo((5,), "float32")) + + call_expr = call_py_func(StringImm("test_func"), (x, y), out_sinfo=R.Tensor((5,), "float32")) + + assert call_expr.op.name == "relax.call_py_func" + assert call_expr.args[0].value == "test_func" + assert len(call_expr.args) == 2 + + # Test 2: Compilation validation + try: + call_py_func( + "invalid", + (Var("x", TensorStructInfo((5,), "float32")),), + out_sinfo=R.Tensor((5,), "float32"), + ) + assert False, "Should raise type error" + except Exception as e: + assert "Mismatched type" in str(e) or "Expected" in str(e) + + # Test 3: Validation and error handling + @I.ir_module + class ValidationTestModule(BasePyModule): + @R.function + def test_invalid_call(x: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"): + result = R.call_py_func("non_existent_func", (x,), out_sinfo=R.Tensor((5,), "float32")) + return result + + device = tvm.cpu() + module = ValidationTestModule(device) + + x = torch.randn(5, dtype=torch.float32) + + with pytest.raises(ValueError, match="Python function 'non_existent_func' not found"): + module.call_py_func("non_existent_func", [x]) + + # Test 4: Using call_py_func within Relax functions + @I.ir_module + class RelaxCallPyFuncModule(BasePyModule): + @I.pyfunc + def torch_relu(self, x): + """PyTorch ReLU implementation.""" + return torch.relu(x) + + @I.pyfunc + def torch_softmax(self, x, dim=0): + """PyTorch softmax implementation.""" + return torch.softmax(x, dim=dim) + + @R.function + def mixed_computation(x: R.Tensor((10,), "float32")) -> R.Tensor((10,), "float32"): + relu_result = R.call_py_func("torch_relu", (x,), out_sinfo=R.Tensor((10,), "float32")) + final_result = R.call_py_func( + "torch_softmax", (relu_result,), out_sinfo=R.Tensor((10,), "float32") + ) + return final_result + + device = tvm.cpu() + module = RelaxCallPyFuncModule(device) + + x = torch.randn(10, dtype=torch.float32) + + expected = torch.softmax(torch.relu(x), dim=0) + + relu_result = module.call_py_func("torch_relu", [x]) + final_result = module.call_py_func("torch_softmax", [relu_result]) + + # Convert to numpy for comparison + if isinstance(final_result, tvm.runtime.Tensor): + final_result_np = final_result.numpy() + else: + final_result_np = final_result + + if isinstance(expected, torch.Tensor): + expected_np = expected.numpy() + else: + expected_np = expected + + # Use numpy for comparison since we have numpy arrays + np.testing.assert_allclose(final_result_np, expected_np, rtol=1e-5, atol=1e-5) From 419d91d20515b2a59c1f2864bb05d59b45f6ff4b Mon Sep 17 00:00:00 2001 From: tlopex <820958424@qq.com> Date: Fri, 19 Sep 2025 23:56:31 -0400 Subject: [PATCH 15/19] finish5 --- src/relax/backend/vm/lower_runtime_builtin.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relax/backend/vm/lower_runtime_builtin.cc b/src/relax/backend/vm/lower_runtime_builtin.cc index 6a72f572d926..108c8d1979fb 100644 --- a/src/relax/backend/vm/lower_runtime_builtin.cc +++ b/src/relax/backend/vm/lower_runtime_builtin.cc @@ -151,7 +151,7 @@ class LowerRuntimeBuiltinMutator : public ExprMutator { tuple_fields.push_back(call_node->args[0]); // function name tuple_fields.push_back(call_node->args[1]); // arguments tuple auto combined_tuple = Tuple(tuple_fields); - + // Direct call to vm.builtin.call_py_func return Call(builtin_call_py_func_, {combined_tuple}, call_node->attrs, call_node->sinfo_args, call_node->span); } From dbb69ef26a0cbd4647a0fd59feb3a6f07a2e6f34 Mon Sep 17 00:00:00 2001 From: tlopex <820958424@qq.com> Date: Sat, 20 Sep 2025 00:30:00 -0400 Subject: [PATCH 16/19] finish6 --- src/relax/backend/vm/lower_runtime_builtin.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/relax/backend/vm/lower_runtime_builtin.cc b/src/relax/backend/vm/lower_runtime_builtin.cc index 108c8d1979fb..0344064b4594 100644 --- a/src/relax/backend/vm/lower_runtime_builtin.cc +++ b/src/relax/backend/vm/lower_runtime_builtin.cc @@ -153,7 +153,8 @@ class LowerRuntimeBuiltinMutator : public ExprMutator { auto combined_tuple = Tuple(tuple_fields); // Direct call to vm.builtin.call_py_func - return Call(builtin_call_py_func_, {combined_tuple}, call_node->attrs, call_node->sinfo_args, call_node->span); + return Call(builtin_call_py_func_, {combined_tuple}, call_node->attrs, + call_node->sinfo_args, call_node->span); } Expr ToDevice(const Call& call_node) { From 8f7e97ae4c2b19305e99c593a5270d41ccc7bded Mon Sep 17 00:00:00 2001 From: tlopex <820958424@qq.com> Date: Sat, 20 Sep 2025 00:42:39 -0400 Subject: [PATCH 17/19] finish7 --- src/runtime/vm/builtin.cc | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/runtime/vm/builtin.cc b/src/runtime/vm/builtin.cc index 6da370e17a10..32e64e2c4eab 100644 --- a/src/runtime/vm/builtin.cc +++ b/src/runtime/vm/builtin.cc @@ -441,18 +441,14 @@ static std::unordered_map py_func_registry; /*! * \brief Clear the Python function registry on shutdown */ -void ClearPyFuncRegistry() { - py_func_registry.clear(); -} +void ClearPyFuncRegistry() { py_func_registry.clear(); } /*! * \brief Register a Python function for call_py_func * \param name The function name * \param func The Python function wrapped as ffi::Function */ -void RegisterPyFunc(const std::string& name, ffi::Function func) { - py_func_registry[name] = func; -} +void RegisterPyFunc(const std::string& name, ffi::Function func) { py_func_registry[name] = func; } /*! * \brief Get a registered Python function From dbc87f9810771e1fcc90cea78fef7d9a7d315aa3 Mon Sep 17 00:00:00 2001 From: tlopex <820958424@qq.com> Date: Sat, 20 Sep 2025 00:57:14 -0400 Subject: [PATCH 18/19] finish8 --- src/relax/backend/vm/lower_runtime_builtin.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/relax/backend/vm/lower_runtime_builtin.cc b/src/relax/backend/vm/lower_runtime_builtin.cc index 0344064b4594..52ceb9f720d2 100644 --- a/src/relax/backend/vm/lower_runtime_builtin.cc +++ b/src/relax/backend/vm/lower_runtime_builtin.cc @@ -153,8 +153,8 @@ class LowerRuntimeBuiltinMutator : public ExprMutator { auto combined_tuple = Tuple(tuple_fields); // Direct call to vm.builtin.call_py_func - return Call(builtin_call_py_func_, {combined_tuple}, call_node->attrs, - call_node->sinfo_args, call_node->span); + return Call(builtin_call_py_func_, {combined_tuple}, call_node->attrs, call_node->sinfo_args, + call_node->span); } Expr ToDevice(const Call& call_node) { From cddad3523e3f3f27e8f25a661b489de632f6ca01 Mon Sep 17 00:00:00 2001 From: tlopex <820958424@qq.com> Date: Sat, 20 Sep 2025 01:14:57 -0400 Subject: [PATCH 19/19] finish9 --- src/relax/backend/vm/codegen_vm.cc | 2 -- src/relax/backend/vm/lower_runtime_builtin.cc | 2 +- src/runtime/vm/builtin.cc | 1 + 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/relax/backend/vm/codegen_vm.cc b/src/relax/backend/vm/codegen_vm.cc index 334a5cfd5d99..e2d9b5b068b7 100644 --- a/src/relax/backend/vm/codegen_vm.cc +++ b/src/relax/backend/vm/codegen_vm.cc @@ -368,8 +368,6 @@ class CodeGenVM : public ExprFunctor { builder_->EmitCall(func, args, dst_reg); } - - void EmitNormalCall(const Call& call_node, RegName dst_reg) { Instruction::Arg func = VisitExpr(call_node->op); std::vector args = VisitArray(call_node->args); diff --git a/src/relax/backend/vm/lower_runtime_builtin.cc b/src/relax/backend/vm/lower_runtime_builtin.cc index 52ceb9f720d2..71b8413e9889 100644 --- a/src/relax/backend/vm/lower_runtime_builtin.cc +++ b/src/relax/backend/vm/lower_runtime_builtin.cc @@ -24,10 +24,10 @@ #include #include #include +#include #include #include #include -#include #include #include diff --git a/src/runtime/vm/builtin.cc b/src/runtime/vm/builtin.cc index 32e64e2c4eab..41c011678ef3 100644 --- a/src/runtime/vm/builtin.cc +++ b/src/runtime/vm/builtin.cc @@ -33,6 +33,7 @@ #include #include #include + #include namespace tvm {