diff --git a/python/tvm/relax/base_py_module.py b/python/tvm/relax/base_py_module.py
index 52f813dc6b6d..7a790d28a720 100644
--- a/python/tvm/relax/base_py_module.py
+++ b/python/tvm/relax/base_py_module.py
@@ -45,6 +45,14 @@ class BasePyModule:
     Only IRModules that inherit from this class are allowed to contain Python functions.
     """

+    def __del__(self):
+        """Clean up registered Python functions on module destruction."""
+        try:
+            clear_func = tvm.get_global_func("vm.builtin.clear_py_func_registry")
+            clear_func()
+        except (ValueError, AttributeError):
+            pass
+
     def __init__(
         self,
         ir_mod: IRModule,
@@ -100,6 +108,7 @@ def _getattr_python_function(name: str) -> Any:
         self._compile_functions()
         self._wrap_tir_functions()
         self._wrap_relax_functions()
+        self._register_python_functions()

     def _collect_function_names(self):
         """Collect names of TIR and Relax functions from IRModule."""
@@ -177,6 +186,35 @@ def wrapper(*args, **kwargs):

         setattr(self, func_name, _create_relax_wrapper(func_name))

+    def _register_python_functions(self):
+        """Register Python functions with the VM runtime for call_py_func support."""
+        if not hasattr(self.ir_mod, "pyfuncs") or not self.ir_mod.pyfuncs:
+            return
+
+        try:
+            register_py_func = tvm.get_global_func("vm.builtin.register_py_func")
+        except ValueError:
+            return
+
+        for func_name, py_func in self.ir_mod.pyfuncs.items():
+
+            def create_py_func_wrapper(name, original_func):
+                def wrapper(*args, **kwargs):
+                    converted_args = [self._convert_tvm_to_pytorch(arg) for arg in args]
+                    converted_kwargs = {
+                        k: self._convert_tvm_to_pytorch(v) for k, v in kwargs.items()
+                    }
+
+                    result = original_func(self, *converted_args, **converted_kwargs)
+
+                    return self._convert_pytorch_to_tvm(result)
+
+                wrapper.__name__ = name
+                return wrapper
+
+            wrapped_func = create_py_func_wrapper(func_name, py_func)
+            register_py_func(func_name, wrapped_func)
+
     def call_tir(self, tir_func, args, out_sinfo):
         """Call a TIR function with PyTorch tensors."""
         # Try to get function name from different sources
diff --git a/src/relax/backend/vm/codegen_vm.cc b/src/relax/backend/vm/codegen_vm.cc
index 96dac05cb63e..e2d9b5b068b7 100644
--- a/src/relax/backend/vm/codegen_vm.cc
+++ b/src/relax/backend/vm/codegen_vm.cc
@@ -368,7 +368,6 @@ class CodeGenVM : public ExprFunctor<Instruction::Arg(const Expr&)> {
     builder_->EmitCall(func, args, dst_reg);
   }

-
   void EmitNormalCall(const Call& call_node, RegName dst_reg) {
     Instruction::Arg func = VisitExpr(call_node->op);
     std::vector<Instruction::Arg> args = VisitArray(call_node->args);
diff --git a/src/relax/backend/vm/lower_runtime_builtin.cc b/src/relax/backend/vm/lower_runtime_builtin.cc
index d52155c615ac..71b8413e9889 100644
--- a/src/relax/backend/vm/lower_runtime_builtin.cc
+++ b/src/relax/backend/vm/lower_runtime_builtin.cc
@@ -24,6 +24,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -52,6 +53,8 @@ class LowerRuntimeBuiltinMutator : public ExprMutator {
       return ShapeOf(call);
     } else if (call->op == tensor_to_shape_op_) {
       return TensorToShape(call);
+    } else if (call->op == call_py_func_op_) {
+      return CallPyFunc(call);
     } else if (call->op == to_vdevice_op_) {
       return ToDevice(call);
     } else if (call->op == make_closure_op_) {
@@ -139,6 +142,21 @@ class LowerRuntimeBuiltinMutator : public ExprMutator {
     return Call(builtin_tensor_to_shape_, call_node->args, Attrs(), {GetStructInfo(call_node)});
   }

+  Expr CallPyFunc(const Call& call_node) {
+    ICHECK(call_node->args.size() == 2);
+    ICHECK(call_node->struct_info_.defined());
+
+    // Create tuple with function name and arguments tuple
+    ffi::Array<Expr> tuple_fields;
+    tuple_fields.push_back(call_node->args[0]);  // function name
+    tuple_fields.push_back(call_node->args[1]);  // arguments tuple
+    auto combined_tuple = Tuple(tuple_fields);
+
+    // Direct call to vm.builtin.call_py_func
+    return Call(builtin_call_py_func_, {combined_tuple}, call_node->attrs, call_node->sinfo_args,
+                call_node->span);
+  }
+
   Expr ToDevice(const Call& call_node) {
     // TODO(yongwww): replace ToVDeviceAttrs with related Expr
     ICHECK(call_node->args.size() == 1);
@@ -198,6 +216,7 @@ class LowerRuntimeBuiltinMutator : public ExprMutator {
   const Op& reshape_op_ = Op::Get("relax.reshape");
   const Op& shape_of_op_ = Op::Get("relax.shape_of");
   const Op& tensor_to_shape_op_ = Op::Get("relax.tensor_to_shape");
+  const Op& call_py_func_op_ = Op::Get("relax.call_py_func");
   const Op& to_vdevice_op_ = Op::Get("relax.to_vdevice");
   const Op& make_closure_op_ = Op::Get("relax.make_closure");
   const Op& invoke_closure_op_ = Op::Get("relax.invoke_closure");
@@ -216,6 +235,7 @@ class LowerRuntimeBuiltinMutator : public ExprMutator {
   const ExternFunc builtin_reshape_{"vm.builtin.reshape"};
   const ExternFunc builtin_shape_of_{"vm.builtin.shape_of"};
   const ExternFunc builtin_tensor_to_shape_{"vm.builtin.tensor_to_shape"};
+  const ExternFunc builtin_call_py_func_{"vm.builtin.call_py_func"};
   const ExternFunc builtin_to_device_{"vm.builtin.to_device"};
   const ExternFunc builtin_make_closure_{"vm.builtin.make_closure"};
   const ExternFunc builtin_invoke_closure_{"vm.builtin.invoke_closure"};
diff --git a/src/runtime/vm/builtin.cc b/src/runtime/vm/builtin.cc
index 362a7e4c89aa..41c011678ef3 100644
--- a/src/runtime/vm/builtin.cc
+++ b/src/runtime/vm/builtin.cc
@@ -34,6 +34,8 @@
 #include
 #include

+#include
+
 namespace tvm {
 namespace runtime {
 namespace vm {
@@ -430,6 +432,78 @@ TVM_FFI_STATIC_INIT_BLOCK() {
       });
 }

+//-------------------------------------
+// Python function call support
+//-------------------------------------
+
+// Global registry for Python functions
+static std::unordered_map<std::string, ffi::Function> py_func_registry;
+
+/*!
+ * \brief Clear the Python function registry on shutdown
+ */
+void ClearPyFuncRegistry() { py_func_registry.clear(); }
+
+/*!
+ * \brief Register a Python function for call_py_func
+ * \param name The function name
+ * \param func The Python function wrapped as ffi::Function
+ */
+void RegisterPyFunc(const std::string& name, ffi::Function func) { py_func_registry[name] = func; }
+
+/*!
+ * \brief Get a registered Python function
+ * \param name The function name
+ * \return The Python function
+ */
+ffi::Function GetPyFunc(const std::string& name) {
+  auto it = py_func_registry.find(name);
+  if (it == py_func_registry.end()) {
+    LOG(FATAL) << "Python function '" << name << "' not found in registry";
+  }
+  return it->second;
+}
+
+/*!
+ * \brief Call a Python function from VM
+ * \param args The packed function arguments (tuple containing function name and arguments)
+ * \param rv The return value
+ */
+void CallPyFunc(ffi::PackedArgs args, ffi::Any* rv) {
+  // args[0] should be a tuple containing (func_name, args_tuple)
+  if (args.size() != 1) {
+    LOG(FATAL) << "vm.builtin.call_py_func expects exactly 1 argument (tuple)";
+  }
+
+  auto tuple_arg = args[0].cast<ffi::Array<ffi::Any>>();
+  if (tuple_arg.size() != 2) {
+    LOG(FATAL) << "vm.builtin.call_py_func tuple should contain (func_name, args)";
+  }
+
+  // Get function name
+  std::string func_name = tuple_arg[0].cast<std::string>();
+
+  // Get arguments tuple
+  auto func_args = tuple_arg[1].cast<ffi::Array<ffi::Any>>();
+
+  // Look up Python function in registry
+  ffi::Function py_func = GetPyFunc(func_name);
+
+  // Call the Python function with the arguments
+  std::vector<ffi::AnyView> py_args_vec(func_args.begin(), func_args.end());
+  ffi::PackedArgs py_args(py_args_vec.data(), py_args_vec.size());
+  py_func.CallPacked(py_args, rv);
+}
+
+TVM_FFI_STATIC_INIT_BLOCK() {
+  namespace refl = tvm::ffi::reflection;
+  refl::GlobalDef()
+      .def_packed("vm.builtin.call_py_func", CallPyFunc)
+      .def("vm.builtin.register_py_func", RegisterPyFunc)
+      .def("vm.builtin.get_py_func", GetPyFunc)
+      .def("vm.builtin.clear_py_func_registry", ClearPyFuncRegistry);
+}
+
 //-------------------------------------
 // Builtin runtime operators.
 //-------------------------------------
diff --git a/tests/python/relax/test_base_py_module_printer.py b/tests/python/relax/test_base_py_module_printer.py
index 6e87174fda35..c9d23a746567 100644
--- a/tests/python/relax/test_base_py_module_printer.py
+++ b/tests/python/relax/test_base_py_module_printer.py
@@ -760,43 +760,54 @@ def test_python_functions_in_irmodule():
         pytest.fail("pyfuncs attribute not found in IRModule")


-def test_call_py_func_validation():
-    """Test call_py_func validation and error handling."""
+def test_call_py_func_with_base_py_module():
+    """Test R.call_py_func with BasePyModule."""
     import torch
+    import numpy as np
+    from tvm.relax.op import call_py_func
+    from tvm.relax.expr import StringImm
+    from tvm.relax import Var, TensorStructInfo

-    @I.ir_module
-    class ValidationTestModule(BasePyModule):
-        """Test module for validation."""
+    # Test 1: Operator creation and basic properties
+    x = Var("x", TensorStructInfo((5,), "float32"))
+    y = Var("y", TensorStructInfo((5,), "float32"))

-        @I.pyfunc
-        def valid_func(self, x):
-            """Valid Python function."""
-            return x * 2
+    call_expr = call_py_func(StringImm("test_func"), (x, y), out_sinfo=R.Tensor((5,), "float32"))
+    assert call_expr.op.name == "relax.call_py_func"
+    assert call_expr.args[0].value == "test_func"
+    assert len(call_expr.args) == 2
+
+    # Test 2: Compilation validation
+    try:
+        call_py_func(
+            "invalid",
+            (Var("x", TensorStructInfo((5,), "float32")),),
+            out_sinfo=R.Tensor((5,), "float32"),
+        )
+        assert False, "Should raise type error"
+    except Exception as e:
+        assert "Mismatched type" in str(e) or "Expected" in str(e)
+
+    # Test 3: Validation and error handling
+    @I.ir_module
+    class ValidationTestModule(BasePyModule):

         @R.function
         def test_invalid_call(x: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"):
-            # This should cause a validation error
             result = R.call_py_func("non_existent_func", (x,), out_sinfo=R.Tensor((5,), "float32"))
             return result

     device = tvm.cpu()
     module = ValidationTestModule(device)

-    # Test that calling non-existent function raises error
     x = torch.randn(5, dtype=torch.float32)

     with pytest.raises(ValueError, match="Python function 'non_existent_func' not found"):
         module.call_py_func("non_existent_func", [x])

-
-def test_call_py_func_in_relax_function():
-    """Test using call_py_func within Relax functions."""
-    import torch
-
+    # Test 4: Using call_py_func within Relax functions
     @I.ir_module
     class RelaxCallPyFuncModule(BasePyModule):
-        """Test module with call_py_func in Relax functions."""
-
         @I.pyfunc
         def torch_relu(self, x):
             """PyTorch ReLU implementation."""
@@ -809,9 +820,7 @@ def torch_softmax(self, x, dim=0):

         @R.function
         def mixed_computation(x: R.Tensor((10,), "float32")) -> R.Tensor((10,), "float32"):
-            # Use Python function for ReLU
             relu_result = R.call_py_func("torch_relu", (x,), out_sinfo=R.Tensor((10,), "float32"))
-            # Use Python function for softmax
             final_result = R.call_py_func(
                 "torch_softmax", (relu_result,), out_sinfo=R.Tensor((10,), "float32")
             )
@@ -820,7 +829,6 @@ def mixed_computation(x: R.Tensor((10,), "float32")) -> R.Tensor((10,), "float32

     device = tvm.cpu()
     module = RelaxCallPyFuncModule(device)

-    # Test the mixed computation
     x = torch.randn(10, dtype=torch.float32)
     expected = torch.softmax(torch.relu(x), dim=0)
@@ -828,40 +836,16 @@ def mixed_computation(x: R.Tensor((10,), "float32")) -> R.Tensor((10,), "float32
     relu_result = module.call_py_func("torch_relu", [x])
     final_result = module.call_py_func("torch_softmax", [relu_result])

-    assert torch.allclose(final_result, expected, atol=1e-5)
-
-
-def test_call_py_func_operator_creation():
-    """Test R.call_py_func operator creation and basic properties."""
-    from tvm.relax.op import call_py_func
-    from tvm.relax.expr import StringImm
-    from tvm.relax import Var, TensorStructInfo
-
-    # Create variables
-    x = Var("x", TensorStructInfo((5,), "float32"))
-    y = Var("y", TensorStructInfo((5,), "float32"))
-
-    # Create call_py_func call
-    call_expr = call_py_func(StringImm("test_func"), (x, y), out_sinfo=R.Tensor((5,), "float32"))
-
-    # Verify operator properties
-    assert call_expr.op.name == "relax.call_py_func"
-    assert call_expr.args[0].value == "test_func"
-    assert len(call_expr.args) == 2
-
+    # Convert to numpy for comparison
+    if isinstance(final_result, tvm.runtime.Tensor):
+        final_result_np = final_result.numpy()
+    else:
+        final_result_np = final_result

-def test_call_py_func_compilation_validation():
-    """Test call_py_func compilation validation."""
-    from tvm.relax.op import call_py_func
-    from tvm.relax import Var, TensorStructInfo
+    if isinstance(expected, torch.Tensor):
+        expected_np = expected.numpy()
+    else:
+        expected_np = expected

-    # Test operator parameter validation
-    try:
-        call_py_func(
-            "invalid",
-            (Var("x", TensorStructInfo((5,), "float32")),),
-            out_sinfo=R.Tensor((5,), "float32"),
-        )
-        assert False, "Should raise type error"
-    except Exception as e:
-        assert "Mismatched type" in str(e) or "Expected" in str(e)
+    # Use numpy for comparison since we have numpy arrays
+    np.testing.assert_allclose(final_result_np, expected_np, rtol=1e-5, atol=1e-5)
diff --git a/tests/python/relax/test_relax_operators.py b/tests/python/relax/test_relax_operators.py
index 8558f6e911b8..897082dd792f 100644
--- a/tests/python/relax/test_relax_operators.py
+++ b/tests/python/relax/test_relax_operators.py
@@ -409,6 +409,82 @@ def inplace_tuple(x: R.Tensor((3, 4), "float32"), y: R.Tensor((3, 4), "float32")
     assert (result[1].numpy() == sum).all()


+def test_op_call_py_func(exec_mode):
+    """Test R.call_py_func operator functionality."""
+    import torch
+
+    def torch_relu(x):
+        if isinstance(x, tvm.runtime.Tensor):
+            x_torch = torch.from_numpy(x.numpy())
+        elif hasattr(x, "asnumpy"):
+            x_torch = torch.from_numpy(x.asnumpy())
+        else:
+            x_np = np.array(x)
+            if isinstance(x_np, tvm.runtime.Tensor):
+                x_torch = torch.from_numpy(x_np.numpy())
+            elif len(x_np) > 0 and isinstance(x_np[0], tvm.runtime.Tensor):
+                x_torch = torch.from_numpy(np.array([t.numpy() for t in x_np]))
+                if x_torch.ndim > 1:
+                    x_torch = x_torch.flatten()
+            else:
+                x_torch = torch.from_numpy(x_np)
+        result = torch.relu(x_torch)
+        return tvm.runtime.tensor(result.numpy())
+
+    def torch_sigmoid(x):
+        if isinstance(x, tvm.runtime.Tensor):
+            x_torch = torch.from_numpy(x.numpy())
+        elif hasattr(x, "asnumpy"):
+            x_torch = torch.from_numpy(x.asnumpy())
+        else:
+            x_np = np.array(x)
+            if isinstance(x_np, tvm.runtime.Tensor):
+                x_torch = torch.from_numpy(x_np.numpy())
+            elif len(x_np) > 0 and isinstance(x_np[0], tvm.runtime.Tensor):
+                x_torch = torch.from_numpy(np.array([t.numpy() for t in x_np]))
+                if x_torch.ndim > 1:
+                    x_torch = x_torch.flatten()
+            else:
+                x_torch = torch.from_numpy(x_np)
+        result = torch.sigmoid(x_torch)
+        return tvm.runtime.tensor(result.numpy())
+
+    register_func = tvm.get_global_func("vm.builtin.register_py_func")
+    register_func("torch_relu", torch_relu)
+    register_func("torch_sigmoid", torch_sigmoid)
+
+    @tvm.script.ir_module
+    class CallPyFuncTest:
+        @R.function
+        def simple_call(x: R.Tensor((3,), "float32")):
+            result = R.call_py_func(R.str("torch_relu"), (x,), out_sinfo=R.Tensor((3,), "float32"))
+            return result
+
+        @R.function
+        def multiple_calls(x: R.Tensor((2,), "float32")):
+            y = R.call_py_func(R.str("torch_relu"), (x,), out_sinfo=R.Tensor((2,), "float32"))
+            z = R.call_py_func(R.str("torch_sigmoid"), (y,), out_sinfo=R.Tensor((2,), "float32"))
+            return z
+
+    np.random.seed(0)
+    x_data = np.array([-1.0, 0.0, 1.0], dtype=np.float32)
+    x_tvm = tvm.runtime.tensor(x_data)
+
+    result = run_cpu(CallPyFuncTest, "simple_call", x_tvm, exec_mode=exec_mode)
+    expected = np.maximum(x_data, 0.0)
+    assert (result.numpy() == expected).all()
+
+    y_data = np.array([-0.5, 0.5], dtype=np.float32)
+    y_tvm = tvm.runtime.tensor(y_data)
+
+    result2 = run_cpu(CallPyFuncTest, "multiple_calls", y_tvm, exec_mode=exec_mode)
+    expected2 = 1.0 / (1.0 + np.exp(-np.maximum(y_data, 0.0)))
+    assert (result2.numpy() == expected2).all()
+
+    clear_func = tvm.get_global_func("vm.builtin.clear_py_func_registry")
+    clear_func()
+
+
 def test_op_to_device(exec_mode):
     @tvm.script.ir_module
     class CallToDevice: