diff --git a/python/tvm/relax/base_py_module.py b/python/tvm/relax/base_py_module.py
index a4464cc737b9..52f813dc6b6d 100644
--- a/python/tvm/relax/base_py_module.py
+++ b/python/tvm/relax/base_py_module.py
@@ -234,12 +234,11 @@ def call_dps_packed(self, func_name: str, args, out_sinfo):
         return out[0] if len(out) == 1 else out

     def call_py_func(self, func_name: str, args):
-        """Call a Python function stored in the IRModule's pyfuncs."""
-        if func_name not in self.ir_mod.pyfuncs:
-            raise ValueError(f"Python function '{func_name}' not found in IRModule pyfuncs")
-        py_func = self.ir_mod.pyfuncs[func_name]
-        converted_args = self._convert_tvm_to_pytorch(args)
-        return py_func(*converted_args)
+        """Call a Python function stored in the module's pyfuncs."""
+        if func_name not in self.pyfuncs:
+            raise ValueError(f"Python function '{func_name}' not found in module pyfuncs")
+        py_func = self.pyfuncs[func_name]
+        return py_func(self, *args)

     def _create_output_tensors(self, out_sinfo, in_args=None):
         # pylint: disable=import-outside-toplevel
diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py
index fd3672368b68..6ea8305ecadb 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/op/__init__.py
@@ -27,6 +27,7 @@
     call_dps_packed,
     call_inplace_packed,
     call_pure_packed,
+    call_py_func,
     call_tir,
     call_tir_inplace,
     call_tir_with_grad,
diff --git a/python/tvm/relax/op/base.py b/python/tvm/relax/op/base.py
index e77920d8dea6..e205abde30b4 100644
--- a/python/tvm/relax/op/base.py
+++ b/python/tvm/relax/op/base.py
@@ -304,6 +304,42 @@ def call_dps_packed(
     return _ffi_api.call_dps_packed(func, args, out_sinfo)  # type: ignore


+@args_converter.auto
+def call_py_func(
+    func_name: str,
+    args: Expr,
+    out_sinfo: Union[TensorStructInfo, List[TensorStructInfo]],
+) -> Call:
+    """
+    Call a Python function and return the output.
+
+    Parameters
+    ----------
+    func_name : str
+        The name of the Python function to call. This should correspond to a function
+        in the IRModule's pyfuncs attribute.
+
+    args : Expr
+        The input arguments.
+
+    out_sinfo : Union[TensorStructInfo, List[TensorStructInfo]]
+        The structure info of the call_py_func output.
+        It should be a single or a list of TensorStructInfo. Each one denotes the
+        structure info of a returned tensor.
+
+    Returns
+    -------
+    ret: Call
+        A call node for the call_py_func operator.
+    """
+    args = _wrap_inline_arg_tuple(args)
+
+    if not isinstance(out_sinfo, list):
+        out_sinfo = [out_sinfo]
+
+    return _ffi_api.call_py_func(func_name, args, out_sinfo)  # type: ignore
+
+
 @args_converter.auto
 def call_builtin_with_ctx(
     func: Union[str, Expr],
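
For reference, a minimal sketch of how the new `relax.op.call_py_func` builder is driven (it mirrors `test_call_py_func_operator_creation` below; the function name "my_pyfunc" is a placeholder, and the name must be wrapped in a `StringImm`):

```python
from tvm import relax
from tvm.relax.op import call_py_func

x = relax.Var("x", relax.TensorStructInfo((5,), "float32"))
# The builder wraps the positional args into an in-line relax Tuple and
# normalizes out_sinfo into a list before crossing the FFI boundary.
call = call_py_func(
    relax.StringImm("my_pyfunc"),
    (x,),
    out_sinfo=relax.TensorStructInfo((5,), "float32"),
)
assert call.op.name == "relax.call_py_func"
```
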
diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py
index d28ff3430aaa..3fa735197ac5 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -30,6 +30,7 @@
     Expr,
     ExternFunc,
     ShapeExpr,
+    StringImm,
     TupleGetItem,
     Var,
     VarBinding,
@@ -64,6 +65,7 @@
     call_dps_packed,
     call_inplace_packed,
     call_pure_packed,
+    call_py_func as _call_py_func,
     call_tir,
     call_tir_inplace,
     call_tir_with_grad,
@@ -451,6 +453,57 @@ def call_packed(
     return Call(op, args, attrs=attrs, sinfo_args=sinfo_args)


+@args_converter.auto
+def call_py_func(
+    py_func_name: py_str,
+    *args: Expr,
+    out_sinfo: Union[StructInfo, List[StructInfo]],
+) -> Call:
+    """Create a relax Call, which calls a Python function.
+
+    Parameters
+    ----------
+    py_func_name : str
+        The name of the Python function to call. This should correspond to a function
+        in the IRModule's pyfuncs attribute.
+
+    *args : Expr
+        The arguments.
+
+    out_sinfo : Union[StructInfo, List[StructInfo]]
+        The structure info of the call_py_func output.
+        It should be a single or a list of TensorStructInfo. Each one denotes the
+        structure info of a returned tensor.
+
+    Returns
+    -------
+    call: Call
+        The created Relax Call for the call_py_func operator.
+    """
+    if isinstance(out_sinfo, py_tuple):  # type: ignore
+        out_sinfo = list(out_sinfo)
+    elif not isinstance(out_sinfo, list):
+        out_sinfo = [out_sinfo]
+
+    out_sinfo = [
+        (
+            sinfo()
+            if callable(sinfo)
+            else sinfo.asobject()
+            if isinstance(sinfo, ObjectConvertible)
+            else sinfo
+        )
+        for sinfo in out_sinfo
+    ]
+
+    # Convert a plain string into a StringImm expression
+    func_name_imm = StringImm(py_func_name) if isinstance(py_func_name, py_str) else py_func_name
+    return _call_py_func(func_name_imm, args, out_sinfo)
+
+
 def _sinfo_arg_wrapper(func):
     """A wrapper to convert StructInfoProxies to StructInfo for builtin operators with sinfo_args"""
@@ -743,6 +796,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
     "call_tir_inplace",
     "call_tir_with_grad",
     "call_dps_packed",
+    "call_py_func",
     "call_builtin_with_ctx",
     "ceil",
     "clip",
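
To make the intent of the TVMScript builder concrete, here is a sketch of how `R.call_py_func` is meant to appear in user code (modeled on the tests added below; the `BasePyModule` import path and the `@I.pyfunc` decorator come from this PR and are assumed here):

```python
import torch
from tvm.relax import BasePyModule
from tvm.script import ir as I
from tvm.script import relax as R

@I.ir_module
class SketchModule(BasePyModule):
    @I.pyfunc
    def scale(self, x):
        # Plain PyTorch code; `self` is the module instance.
        return x * 2.0

    @R.function
    def main(x: R.Tensor((4,), "float32")) -> R.Tensor((4,), "float32"):
        # The string "scale" is converted to a StringImm by the builder above.
        y = R.call_py_func("scale", (x,), out_sinfo=R.Tensor((4,), "float32"))
        return y
```
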
" + "However, one given structure info is " + << sinfo; + } + + StructInfo out_sinfo{nullptr}; + if (out_sinfo_list.size() == 1) { + out_sinfo = out_sinfo_list[0]; + } else { + out_sinfo = TupleStructInfo({out_sinfo_list.begin(), out_sinfo_list.end()}); + } + + static const Op& op = Op::Get("relax.call_py_func"); + return Call(op, {func_name, args}, {}, {out_sinfo}); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.call_py_func", MakeCallPyFunc); +} + // call builtin StructInfo InferStructInfoCallBuiltinWithCtx(const Call& call, const BlockBuilder& ctx) { if (call->sinfo_args.size() == 0) { diff --git a/tests/python/relax/test_base_py_module_printer.py b/tests/python/relax/test_base_py_module_printer.py index 92c799f6cb70..6e87174fda35 100644 --- a/tests/python/relax/test_base_py_module_printer.py +++ b/tests/python/relax/test_base_py_module_printer.py @@ -758,3 +758,110 @@ def test_python_functions_in_irmodule(): assert pyfuncs["multiply"].__name__ == "multiply" else: pytest.fail("pyfuncs attribute not found in IRModule") + + +def test_call_py_func_validation(): + """Test call_py_func validation and error handling.""" + import torch + + @I.ir_module + class ValidationTestModule(BasePyModule): + """Test module for validation.""" + + @I.pyfunc + def valid_func(self, x): + """Valid Python function.""" + return x * 2 + + @R.function + def test_invalid_call(x: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"): + # This should cause a validation error + result = R.call_py_func("non_existent_func", (x,), out_sinfo=R.Tensor((5,), "float32")) + return result + + device = tvm.cpu() + module = ValidationTestModule(device) + + # Test that calling non-existent function raises error + x = torch.randn(5, dtype=torch.float32) + + with pytest.raises(ValueError, match="Python function 'non_existent_func' not found"): + module.call_py_func("non_existent_func", [x]) + + +def test_call_py_func_in_relax_function(): + """Test using call_py_func within Relax functions.""" + import torch + + @I.ir_module + class RelaxCallPyFuncModule(BasePyModule): + """Test module with call_py_func in Relax functions.""" + + @I.pyfunc + def torch_relu(self, x): + """PyTorch ReLU implementation.""" + return torch.relu(x) + + @I.pyfunc + def torch_softmax(self, x, dim=0): + """PyTorch softmax implementation.""" + return torch.softmax(x, dim=dim) + + @R.function + def mixed_computation(x: R.Tensor((10,), "float32")) -> R.Tensor((10,), "float32"): + # Use Python function for ReLU + relu_result = R.call_py_func("torch_relu", (x,), out_sinfo=R.Tensor((10,), "float32")) + # Use Python function for softmax + final_result = R.call_py_func( + "torch_softmax", (relu_result,), out_sinfo=R.Tensor((10,), "float32") + ) + return final_result + + device = tvm.cpu() + module = RelaxCallPyFuncModule(device) + + # Test the mixed computation + x = torch.randn(10, dtype=torch.float32) + + expected = torch.softmax(torch.relu(x), dim=0) + + relu_result = module.call_py_func("torch_relu", [x]) + final_result = module.call_py_func("torch_softmax", [relu_result]) + + assert torch.allclose(final_result, expected, atol=1e-5) + + +def test_call_py_func_operator_creation(): + """Test R.call_py_func operator creation and basic properties.""" + from tvm.relax.op import call_py_func + from tvm.relax.expr import StringImm + from tvm.relax import Var, TensorStructInfo + + # Create variables + x = Var("x", TensorStructInfo((5,), "float32")) + y = Var("y", TensorStructInfo((5,), "float32")) + + # 
diff --git a/tests/python/relax/test_base_py_module_printer.py b/tests/python/relax/test_base_py_module_printer.py
index 92c799f6cb70..6e87174fda35 100644
--- a/tests/python/relax/test_base_py_module_printer.py
+++ b/tests/python/relax/test_base_py_module_printer.py
@@ -758,3 +758,110 @@ def test_python_functions_in_irmodule():
         assert pyfuncs["multiply"].__name__ == "multiply"
     else:
         pytest.fail("pyfuncs attribute not found in IRModule")
+
+
+def test_call_py_func_validation():
+    """Test call_py_func validation and error handling."""
+    import torch
+
+    @I.ir_module
+    class ValidationTestModule(BasePyModule):
+        """Test module for validation."""
+
+        @I.pyfunc
+        def valid_func(self, x):
+            """Valid Python function."""
+            return x * 2
+
+        @R.function
+        def test_invalid_call(x: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"):
+            # This should cause a validation error
+            result = R.call_py_func(
+                "non_existent_func", (x,), out_sinfo=R.Tensor((5,), "float32")
+            )
+            return result
+
+    device = tvm.cpu()
+    module = ValidationTestModule(device)
+
+    # Calling a non-existent function should raise an error
+    x = torch.randn(5, dtype=torch.float32)
+
+    with pytest.raises(ValueError, match="Python function 'non_existent_func' not found"):
+        module.call_py_func("non_existent_func", [x])
+
+
+def test_call_py_func_in_relax_function():
+    """Test using call_py_func within Relax functions."""
+    import torch
+
+    @I.ir_module
+    class RelaxCallPyFuncModule(BasePyModule):
+        """Test module with call_py_func in Relax functions."""
+
+        @I.pyfunc
+        def torch_relu(self, x):
+            """PyTorch ReLU implementation."""
+            return torch.relu(x)
+
+        @I.pyfunc
+        def torch_softmax(self, x, dim=0):
+            """PyTorch softmax implementation."""
+            return torch.softmax(x, dim=dim)
+
+        @R.function
+        def mixed_computation(x: R.Tensor((10,), "float32")) -> R.Tensor((10,), "float32"):
+            # Use the Python function for ReLU
+            relu_result = R.call_py_func("torch_relu", (x,), out_sinfo=R.Tensor((10,), "float32"))
+            # Use the Python function for softmax
+            final_result = R.call_py_func(
+                "torch_softmax", (relu_result,), out_sinfo=R.Tensor((10,), "float32")
+            )
+            return final_result
+
+    device = tvm.cpu()
+    module = RelaxCallPyFuncModule(device)
+
+    # Test the mixed computation
+    x = torch.randn(10, dtype=torch.float32)
+
+    expected = torch.softmax(torch.relu(x), dim=0)
+
+    relu_result = module.call_py_func("torch_relu", [x])
+    final_result = module.call_py_func("torch_softmax", [relu_result])
+
+    assert torch.allclose(final_result, expected, atol=1e-5)
+
+
+def test_call_py_func_operator_creation():
+    """Test R.call_py_func operator creation and basic properties."""
+    from tvm.relax.op import call_py_func
+    from tvm.relax.expr import StringImm
+    from tvm.relax import Var, TensorStructInfo
+
+    # Create variables
+    x = Var("x", TensorStructInfo((5,), "float32"))
+    y = Var("y", TensorStructInfo((5,), "float32"))
+
+    # Create the call_py_func call
+    call_expr = call_py_func(StringImm("test_func"), (x, y), out_sinfo=R.Tensor((5,), "float32"))
+
+    # Verify operator properties
+    assert call_expr.op.name == "relax.call_py_func"
+    assert call_expr.args[0].value == "test_func"
+    assert len(call_expr.args) == 2
+
+
+def test_call_py_func_compilation_validation():
+    """Test call_py_func compilation validation."""
+    from tvm.relax.op import call_py_func
+    from tvm.relax import Var, TensorStructInfo
+
+    # A plain Python string is not a StringImm, so the call should be rejected
+    with pytest.raises(Exception, match="Mismatched type|Expected"):
+        call_py_func(
+            "invalid",
+            (Var("x", TensorStructInfo((5,), "float32")),),
+            out_sinfo=R.Tensor((5,), "float32"),
+        )
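
Finally, the runtime dispatch changed in `base_py_module.py` at the top of this diff reduces to plain Python: pyfuncs are looked up on the module instance itself and receive that instance as their first argument, with no TVM <-> PyTorch conversion in between. A self-contained sketch (`FakeModule` is hypothetical; the real class is `BasePyModule`):

```python
import torch

class FakeModule:
    """Mimics only the call_py_func dispatch from BasePyModule."""

    def __init__(self):
        self.pyfuncs = {"double": lambda self, x: x * 2}

    def call_py_func(self, func_name, args):
        if func_name not in self.pyfuncs:
            raise ValueError(f"Python function '{func_name}' not found in module pyfuncs")
        # py_func(self, *args): the module instance is passed through, and the
        # arguments are forwarded unchanged.
        return self.pyfuncs[func_name](self, *args)

m = FakeModule()
assert torch.equal(m.call_py_func("double", [torch.ones(3)]), torch.full((3,), 2.0))
```
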