38 changes: 38 additions & 0 deletions python/tvm/relax/base_py_module.py
@@ -45,6 +45,14 @@ class BasePyModule:
Only IRModules that inherit from this class are allowed to contain Python functions.
"""

def __del__(self):
"""Clean up registered Python functions on module destruction."""
try:
clear_func = tvm.get_global_func("vm.builtin.clear_py_func_registry")
clear_func()
except (ValueError, AttributeError):
pass

def __init__(
self,
ir_mod: IRModule,
@@ -100,6 +108,7 @@ def _getattr_python_function(name: str) -> Any:
self._compile_functions()
self._wrap_tir_functions()
self._wrap_relax_functions()
self._register_python_functions()

def _collect_function_names(self):
"""Collect names of TIR and Relax functions from IRModule."""
@@ -177,6 +186,35 @@ def wrapper(*args, **kwargs):

setattr(self, func_name, _create_relax_wrapper(func_name))

def _register_python_functions(self):
"""Register Python functions with the VM runtime for call_py_func support."""
if not hasattr(self.ir_mod, "pyfuncs") or not self.ir_mod.pyfuncs:
return

try:
register_py_func = tvm.get_global_func("vm.builtin.register_py_func")
except ValueError:
return

for func_name, py_func in self.ir_mod.pyfuncs.items():

def create_py_func_wrapper(name, original_func):
def wrapper(*args, **kwargs):
converted_args = [self._convert_tvm_to_pytorch(arg) for arg in args]
converted_kwargs = {
k: self._convert_tvm_to_pytorch(v) for k, v in kwargs.items()
}

result = original_func(self, *converted_args, **converted_kwargs)

return self._convert_pytorch_to_tvm(result)

wrapper.__name__ = name
return wrapper

wrapped_func = create_py_func_wrapper(func_name, py_func)
register_py_func(func_name, wrapped_func)

def call_tir(self, tir_func, args, out_sinfo):
"""Call a TIR function with PyTorch tensors."""
# Try to get function name from different sources
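As a reference point, here is a minimal sketch of what the registration path above amounts to at the Python level. The function name `my_scale` and its body are made up for illustration; only the `vm.builtin.register_py_func` / `vm.builtin.get_py_func` globals it calls are added by this change, and the sketch assumes the FFI wraps a plain Python callable into an `ffi::Function` as usual.

import tvm

def my_scale(x):
    # Illustrative stand-in for an @I.pyfunc body; BasePyModule's wrapper would
    # normally convert arguments to torch.Tensor before this runs.
    return x * 2

register_py_func = tvm.get_global_func("vm.builtin.register_py_func")
get_py_func = tvm.get_global_func("vm.builtin.get_py_func")

register_py_func("my_scale", my_scale)      # what _register_python_functions does per pyfunc
assert get_py_func("my_scale") is not None  # the VM can now resolve the name at runtime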
1 change: 0 additions & 1 deletion src/relax/backend/vm/codegen_vm.cc
@@ -368,7 +368,6 @@ class CodeGenVM : public ExprFunctor<Instruction::Arg(const Expr&)> {

builder_->EmitCall(func, args, dst_reg);
}

void EmitNormalCall(const Call& call_node, RegName dst_reg) {
Instruction::Arg func = VisitExpr(call_node->op);
std::vector<Instruction::Arg> args = VisitArray(call_node->args);
20 changes: 20 additions & 0 deletions src/relax/backend/vm/lower_runtime_builtin.cc
@@ -24,6 +24,7 @@
#include <tvm/relax/analysis.h>
#include <tvm/relax/attrs/op.h>
#include <tvm/relax/backend.h>
#include <tvm/relax/expr.h>
#include <tvm/relax/expr_functor.h>
#include <tvm/relax/op_attr_types.h>
#include <tvm/relax/type.h>
@@ -52,6 +53,8 @@ class LowerRuntimeBuiltinMutator : public ExprMutator {
return ShapeOf(call);
} else if (call->op == tensor_to_shape_op_) {
return TensorToShape(call);
} else if (call->op == call_py_func_op_) {
return CallPyFunc(call);
} else if (call->op == to_vdevice_op_) {
return ToDevice(call);
} else if (call->op == make_closure_op_) {
@@ -139,6 +142,21 @@ class LowerRuntimeBuiltinMutator : public ExprMutator {
return Call(builtin_tensor_to_shape_, call_node->args, Attrs(), {GetStructInfo(call_node)});
}

Expr CallPyFunc(const Call& call_node) {
ICHECK(call_node->args.size() == 2);
ICHECK(call_node->struct_info_.defined());

// Create tuple with function name and arguments tuple
ffi::Array<Expr> tuple_fields;
tuple_fields.push_back(call_node->args[0]); // function name
tuple_fields.push_back(call_node->args[1]); // arguments tuple
auto combined_tuple = Tuple(tuple_fields);

// Direct call to vm.builtin.call_py_func
return Call(builtin_call_py_func_, {combined_tuple}, call_node->attrs, call_node->sinfo_args,
call_node->span);
}

Expr ToDevice(const Call& call_node) {
// TODO(yongwww): replace ToVDeviceAttrs with related Expr
ICHECK(call_node->args.size() == 1);
@@ -198,6 +216,7 @@ class LowerRuntimeBuiltinMutator : public ExprMutator {
const Op& reshape_op_ = Op::Get("relax.reshape");
const Op& shape_of_op_ = Op::Get("relax.shape_of");
const Op& tensor_to_shape_op_ = Op::Get("relax.tensor_to_shape");
const Op& call_py_func_op_ = Op::Get("relax.call_py_func");
const Op& to_vdevice_op_ = Op::Get("relax.to_vdevice");
const Op& make_closure_op_ = Op::Get("relax.make_closure");
const Op& invoke_closure_op_ = Op::Get("relax.invoke_closure");
@@ -216,6 +235,7 @@ class LowerRuntimeBuiltinMutator : public ExprMutator {
const ExternFunc builtin_reshape_{"vm.builtin.reshape"};
const ExternFunc builtin_shape_of_{"vm.builtin.shape_of"};
const ExternFunc builtin_tensor_to_shape_{"vm.builtin.tensor_to_shape"};
const ExternFunc builtin_call_py_func_{"vm.builtin.call_py_func"};
const ExternFunc builtin_to_device_{"vm.builtin.to_device"};
const ExternFunc builtin_make_closure_{"vm.builtin.make_closure"};
const ExternFunc builtin_invoke_closure_{"vm.builtin.invoke_closure"};
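For orientation, a sketch of the rewrite the CallPyFunc lowering above performs, expressed with the Python-side Relax expression API. The function name "my_func", the shapes, and the struct info are illustrative; this approximates the before/after shape of the call, it is not the pass itself.

from tvm import relax
from tvm.relax.expr import StringImm

x = relax.Var("x", relax.TensorStructInfo((5,), "float32"))
out_sinfo = relax.TensorStructInfo((5,), "float32")

# Before lowering: relax.call_py_func carries two arguments, the function
# name and the tuple of runtime arguments.
before = relax.op.call_py_func(StringImm("my_func"), (x,), out_sinfo=out_sinfo)

# After LowerRuntimeBuiltin: a direct extern call whose single argument packs
# the name and the original argument tuple into one combined tuple.
after = relax.Call(
    relax.ExternFunc("vm.builtin.call_py_func"),
    [relax.Tuple([StringImm("my_func"), relax.Tuple([x])])],
    sinfo_args=[out_sinfo],
)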
74 changes: 74 additions & 0 deletions src/runtime/vm/builtin.cc
@@ -34,6 +34,8 @@
#include <tvm/runtime/vm/bytecode.h>
#include <tvm/runtime/vm/vm.h>

#include <unordered_map>

namespace tvm {
namespace runtime {
namespace vm {
@@ -430,6 +432,78 @@ TVM_FFI_STATIC_INIT_BLOCK() {
});
}

//-------------------------------------
// Python function call support
//-------------------------------------

// Global registry for Python functions
static std::unordered_map<std::string, ffi::Function> py_func_registry;

/*!
* \brief Clear the Python function registry on shutdown
*/
void ClearPyFuncRegistry() { py_func_registry.clear(); }

/*!
* \brief Register a Python function for call_py_func
* \param name The function name
* \param func The Python function wrapped as ffi::Function
*/
void RegisterPyFunc(const std::string& name, ffi::Function func) { py_func_registry[name] = func; }

/*!
* \brief Get a registered Python function
* \param name The function name
* \return The Python function
*/
ffi::Function GetPyFunc(const std::string& name) {
auto it = py_func_registry.find(name);
if (it == py_func_registry.end()) {
LOG(FATAL) << "Python function '" << name << "' not found in registry";
}
return it->second;
}

/*!
* \brief Call a Python function from VM
* \param args The packed function arguments (tuple containing function name and arguments)
* \param rv The return value
*/
void CallPyFunc(ffi::PackedArgs args, ffi::Any* rv) {
// args[0] should be a tuple containing (func_name, args_tuple)
if (args.size() != 1) {
LOG(FATAL) << "vm.builtin.call_py_func expects exactly 1 argument (tuple)";
}

auto tuple_arg = args[0].cast<ffi::Array<ffi::Any>>();
if (tuple_arg.size() != 2) {
LOG(FATAL) << "vm.builtin.call_py_func tuple should contain (func_name, args)";
}

// Get function name
std::string func_name = tuple_arg[0].cast<ffi::String>();

// Get arguments tuple
auto func_args = tuple_arg[1].cast<ffi::Array<ffi::Any>>();

// Look up Python function in registry
ffi::Function py_func = GetPyFunc(func_name);

// Call the Python function with the arguments
std::vector<ffi::AnyView> py_args_vec(func_args.begin(), func_args.end());
ffi::PackedArgs py_args(py_args_vec.data(), py_args_vec.size());
py_func.CallPacked(py_args, rv);
}

TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef()
.def_packed("vm.builtin.call_py_func", CallPyFunc)
.def("vm.builtin.register_py_func", RegisterPyFunc)
.def("vm.builtin.get_py_func", GetPyFunc)
.def("vm.builtin.clear_py_func_registry", ClearPyFuncRegistry);
}
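A minimal Python-side sketch of the round trip these builtins enable. The function name `add_one` and the integer argument are made up for illustration, and the sketch assumes the FFI converts Python tuples to `ffi::Array` when crossing the boundary; BasePyModule normally performs the registration and argument packing itself.

import tvm

def add_one(x):
    # Registered callbacks receive the unpacked members of the argument tuple.
    return x + 1

register = tvm.get_global_func("vm.builtin.register_py_func")
call = tvm.get_global_func("vm.builtin.call_py_func")
clear = tvm.get_global_func("vm.builtin.clear_py_func_registry")

register("add_one", add_one)
result = call(("add_one", (41,)))  # single (name, args) tuple -> 42
clear()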

//-------------------------------------
// Builtin runtime operators.
//-------------------------------------
96 changes: 40 additions & 56 deletions tests/python/relax/test_base_py_module_printer.py
@@ -760,43 +760,54 @@ def test_python_functions_in_irmodule():
pytest.fail("pyfuncs attribute not found in IRModule")


def test_call_py_func_validation():
"""Test call_py_func validation and error handling."""
def test_call_py_func_with_base_py_module():
"""Test R.call_py_func with BasePyModule."""
import torch
import numpy as np
from tvm.relax.op import call_py_func
from tvm.relax.expr import StringImm
from tvm.relax import Var, TensorStructInfo

@I.ir_module
class ValidationTestModule(BasePyModule):
"""Test module for validation."""
# Test 1: Operator creation and basic properties
x = Var("x", TensorStructInfo((5,), "float32"))
y = Var("y", TensorStructInfo((5,), "float32"))

@I.pyfunc
def valid_func(self, x):
"""Valid Python function."""
return x * 2
call_expr = call_py_func(StringImm("test_func"), (x, y), out_sinfo=R.Tensor((5,), "float32"))

assert call_expr.op.name == "relax.call_py_func"
assert call_expr.args[0].value == "test_func"
assert len(call_expr.args) == 2

# Test 2: Compilation validation
try:
call_py_func(
"invalid",
(Var("x", TensorStructInfo((5,), "float32")),),
out_sinfo=R.Tensor((5,), "float32"),
)
assert False, "Should raise type error"
except Exception as e:
assert "Mismatched type" in str(e) or "Expected" in str(e)

# Test 3: Validation and error handling
@I.ir_module
class ValidationTestModule(BasePyModule):
@R.function
def test_invalid_call(x: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"):
# This should cause a validation error
result = R.call_py_func("non_existent_func", (x,), out_sinfo=R.Tensor((5,), "float32"))
return result

device = tvm.cpu()
module = ValidationTestModule(device)

# Test that calling non-existent function raises error
x = torch.randn(5, dtype=torch.float32)

with pytest.raises(ValueError, match="Python function 'non_existent_func' not found"):
module.call_py_func("non_existent_func", [x])


def test_call_py_func_in_relax_function():
"""Test using call_py_func within Relax functions."""
import torch

# Test 4: Using call_py_func within Relax functions
@I.ir_module
class RelaxCallPyFuncModule(BasePyModule):
"""Test module with call_py_func in Relax functions."""

@I.pyfunc
def torch_relu(self, x):
"""PyTorch ReLU implementation."""
@@ -809,9 +820,7 @@ def torch_softmax(self, x, dim=0):

@R.function
def mixed_computation(x: R.Tensor((10,), "float32")) -> R.Tensor((10,), "float32"):
# Use Python function for ReLU
relu_result = R.call_py_func("torch_relu", (x,), out_sinfo=R.Tensor((10,), "float32"))
# Use Python function for softmax
final_result = R.call_py_func(
"torch_softmax", (relu_result,), out_sinfo=R.Tensor((10,), "float32")
)
@@ -820,48 +829,23 @@ def mixed_computation(x: R.Tensor((10,), "float32")) -> R.Tensor((10,), "float32"):
device = tvm.cpu()
module = RelaxCallPyFuncModule(device)

# Test the mixed computation
x = torch.randn(10, dtype=torch.float32)

expected = torch.softmax(torch.relu(x), dim=0)

relu_result = module.call_py_func("torch_relu", [x])
final_result = module.call_py_func("torch_softmax", [relu_result])

assert torch.allclose(final_result, expected, atol=1e-5)


def test_call_py_func_operator_creation():
"""Test R.call_py_func operator creation and basic properties."""
from tvm.relax.op import call_py_func
from tvm.relax.expr import StringImm
from tvm.relax import Var, TensorStructInfo

# Create variables
x = Var("x", TensorStructInfo((5,), "float32"))
y = Var("y", TensorStructInfo((5,), "float32"))

# Create call_py_func call
call_expr = call_py_func(StringImm("test_func"), (x, y), out_sinfo=R.Tensor((5,), "float32"))

# Verify operator properties
assert call_expr.op.name == "relax.call_py_func"
assert call_expr.args[0].value == "test_func"
assert len(call_expr.args) == 2

# Convert to numpy for comparison
if isinstance(final_result, tvm.runtime.Tensor):
final_result_np = final_result.numpy()
else:
final_result_np = final_result

def test_call_py_func_compilation_validation():
"""Test call_py_func compilation validation."""
from tvm.relax.op import call_py_func
from tvm.relax import Var, TensorStructInfo
if isinstance(expected, torch.Tensor):
expected_np = expected.numpy()
else:
expected_np = expected

# Test operator parameter validation
try:
call_py_func(
"invalid",
(Var("x", TensorStructInfo((5,), "float32")),),
out_sinfo=R.Tensor((5,), "float32"),
)
assert False, "Should raise type error"
except Exception as e:
assert "Mismatched type" in str(e) or "Expected" in str(e)
# Use numpy for comparison since we have numpy arrays
np.testing.assert_allclose(final_result_np, expected_np, rtol=1e-5, atol=1e-5)