diff --git a/python/tvm/relax/parser.py b/python/tvm/relax/parser.py index 397eebe03140..594ce3b7d7e6 100644 --- a/python/tvm/relax/parser.py +++ b/python/tvm/relax/parser.py @@ -1,7 +1,7 @@ from __future__ import annotations import inspect -from typing import TypeVar, Generic, Union, Dict, List, Tuple, Optional +from typing import TypeVar, Generic, Union, Dict, List, Tuple, Optional, Callable from io import StringIO from enum import Enum @@ -837,11 +837,6 @@ def parse_call(self, expr: ast.Call) -> Union[tir.PrimExpr, rx.Expr]: op = self.transform_expr(expr.func_name) if op == SpecialOp.CALL_PACKED: - if len(expr.params) != 2: - self.report_error( - op.value + " takes an extern function name and a tuple of arguments", - expr.span, - ) extern_func = expr.params[0] if not (isinstance(extern_func, ast.Constant) and isinstance(extern_func.value, str)): self.report_error( @@ -849,7 +844,7 @@ def parse_call(self, expr: ast.Call) -> Union[tir.PrimExpr, rx.Expr]: extern_func.span, ) op = rx.ExternFunc(extern_func.value, self.to_tvm_span(extern_func.span)) - args = [self.transform_expr(expr.params[1])] + args = [self.transform_expr(arg) for arg in expr.params[1:]] elif isinstance(op, ArithmeticOp): args = [self.transform_expr(arg) for arg in expr.params] @@ -1045,7 +1040,7 @@ def transform_block(self, block: ast.Block) -> rx.SeqExpr: # self.tvm_diag_ctx.render() -def script(f) -> Union[rx.Function, tvm.IRModule]: +def script(f) -> Union[rx.Function, Callable[[], tvm.IRModule]]: """Parses the decorated Relax function or module (in Relax IR) to a Relax AST. Parameters @@ -1056,15 +1051,19 @@ def script(f) -> Union[rx.Function, tvm.IRModule]: Returns ------- Union[rx.Function, IRModule] - The parsed Relax function or IRModule. + The parsed Relax function or IRModule factory (which returns the parsed IRModule when + called). """ diag_ctx = tvm.script.diagnostics.TVMDiagnosticCtx() ast = synr.to_ast(f, diag_ctx) - return RelaxTransformer().do_transform(ast, diag_ctx) + mod = RelaxTransformer().do_transform(ast, diag_ctx) + if isinstance(mod, tvm.IRModule): + return lambda: mod + return mod def fromtext(source: str, source_name: str = "from_string") -> Union[rx.Function, tvm.IRModule]: - """Parses the given input string (in the Relax text format) to a Relax function or IRModule. + """Parses the given input string (in the Relax text format) to a Relax AST. Parameters ---------- @@ -1076,12 +1075,16 @@ def fromtext(source: str, source_name: str = "from_string") -> Union[rx.Function Returns ------- Union[rx.Function, IRModule] - The parsed Relax function or IRModule. + The parsed Relax function or IRModule factory (which returns the parsed IRModule when + called). """ # TODO(@altanh): actually use source_name somewhere? diag_ctx = tvm.script.diagnostics.TVMDiagnosticCtx() ast = synr.to_ast(source, diag_ctx) - return RelaxTransformer().do_transform(ast, diag_ctx) + mod = RelaxTransformer().do_transform(ast, diag_ctx) + if isinstance(mod, tvm.IRModule): + return lambda: mod + return mod def pretty_print(node): diff --git a/src/printer/relax_script_printer.cc b/src/printer/relax_script_printer.cc index 627d4da7602f..c48c85e0edcf 100644 --- a/src/printer/relax_script_printer.cc +++ b/src/printer/relax_script_printer.cc @@ -77,6 +77,7 @@ class RelaxScriptPrinter : public relax::IRFunctor, Doc VisitExpr_(const tir::FloorDivNode* op) override; Doc PrintIRModule(const IRModule& mod); + Doc PrintPrimFunc(const String& name, const tir::PrimFunc& func); Doc PrintIfStmt(const relax::Var& var, const relay::If& ite); Doc PrintFunctionDef(const Doc& name, const relax::Function& func); @@ -178,17 +179,19 @@ Doc RelaxScriptPrinter::VisitNode_(const relay::CallNode* op) { // TODO(@altanh): how to support when func cannot be printed as Python expr? // e.g. Function or If Doc doc; + std::vector args; if (op->op.as()) { - ICHECK_EQ(op->args.size(), 1) << "extern calls should only have one argument"; - doc << "relax.call_packed(" << Print(op->op) << ", " << Print(op->args[0]); + doc << "relax.call_packed"; + args.push_back(Print(op->op)); } else { - std::vector args; - for (const Expr& arg : op->args) { - args.push_back(Print(arg)); - } - doc << Print(op->op) << "(" << Doc::Concat(args, Doc::Text(", ")); + doc << Print(op->op); + } + + for (const Expr& arg : op->args) { + args.push_back(Print(arg)); } + doc << "(" << Doc::Concat(args, Doc::Text(", ")); std::vector attrs = PrintAttrs(op->attrs); if (!attrs.empty()) { @@ -260,12 +263,7 @@ Doc RelaxScriptPrinter::VisitNode_(const relax::VarBindingNode* op) { } else if (const relax::FunctionNode* func = op->value.as()) { return PrintFunctionDef(Print(op->var), GetRef(func)); } else if (const tir::PrimFuncNode* prim_func = op->value.as()) { - // we need the mod for TVMScriptPrinter to properly print the function name - maybe it's worth - // refactoring to avoid this? - tir::PrimFunc prim_func_ref = GetRef(prim_func); - IRModule mod; - mod->Add(relay::GlobalVar(op->var->name_hint()), prim_func_ref); - return tir::AsTVMScriptDoc(mod, false, prim_func_ref); + return PrintPrimFunc(op->var->name_hint(), GetRef(prim_func)); } else { Doc doc; doc << Print(op->var) << PrintVarAnnotation(op->var); @@ -426,11 +424,25 @@ Doc RelaxScriptPrinter::PrintIRModule(const IRModule& mod) { Doc doc; doc << "class Module:"; for (const std::pair& pr : mod->functions) { - doc << Doc::Indent(4, Doc::NewLine() << Print(pr.second)); + Doc func; + if (pr.second.as()) { + func = PrintPrimFunc(pr.first->name_hint, Downcast(pr.second)); + } else { + func = Print(pr.second); + } + doc << Doc::Indent(4, Doc::NewLine() << func); } return doc; } +Doc RelaxScriptPrinter::PrintPrimFunc(const String& name, const tir::PrimFunc& func) { + // we need the mod for TVMScriptPrinter to properly print the function name - maybe it's worth + // refactoring to avoid this? + IRModule mod; + mod->Add(relay::GlobalVar(name), func); + return tir::AsTVMScriptDoc(mod, false, func); +} + Doc RelaxScriptPrinter::PrintIfStmt(const relax::Var& var, const relay::If& ite) { const relax::SeqExprNode* true_branch = ite->true_branch.as(); const relax::SeqExprNode* false_branch = ite->false_branch.as(); diff --git a/tests/python/relax/test_parser.py b/tests/python/relax/test_parser.py index 5785e791a66d..93392281ee43 100644 --- a/tests/python/relax/test_parser.py +++ b/tests/python/relax/test_parser.py @@ -432,9 +432,9 @@ def my_matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None: def test_call_packed(): @rx.script - def f(x: Tensor[(3, 4), "float32"]): + def f(x: Tensor[(3, 3), "float32"]): # test that we can intro dim vars - z: Tensor[(n, m), "float32"] = relax.call_packed("contrib.my_matmul", (x, x), mp=False) + z: Tensor[(n, m), "float32"] = relax.call_packed("contrib.my_matmul", x, x, mp=False) return z x = f.params[0] @@ -444,7 +444,7 @@ def f(x: Tensor[(3, 4), "float32"]): assert isinstance(z_bind.value.op, rx.ExternFunc) assert z_bind.value.op.global_symbol == "contrib.my_matmul" assert "mp" in z_bind.value.attrs and z_bind.value.attrs["mp"] == False - assert structural_equal(z_bind.value.args, [rx.Tuple([x, x])]) + assert structural_equal(z_bind.value.args, [x, x]) def test_primexpr_arithmetic(): @@ -480,18 +480,36 @@ def f(x: Tensor): def test_class_irmodule(): @rx.script - class my_module: - def f(x: Tensor[(n, m), _]) -> Tensor: + class MyModule: + @tvm.script.tir + def my_matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, [128, 128]) + B = tir.match_buffer(b, [128, 128]) + C = tir.match_buffer(c, [128, 128]) + + with tir.block([128, 128, tir.reduce_axis(0, 128)], "update") as [vi, vj, vk]: + with tir.init(): + C[vi, vj] = tir.float32(0) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + + def f(x: Tensor[(n, n), _]) -> Tensor: return g(x) - def g(y: Tensor[(n, m), _]) -> Tensor: - return y + def g(y: Tensor[(n, n), _]) -> Tensor: + return relax.call_dps((n, n), my_matmul, (y, y)) + + def h(x, y, z): + _ = my_matmul(x, y, z) + return z + my_module = MyModule() assert isinstance(my_module, tvm.IRModule) var_f = my_module.get_global_var("f") var_g = my_module.get_global_var("g") + var_my_matmul = my_module.get_global_var("my_matmul") f = my_module[var_f] g = my_module[var_g] assert f.body.body.op == var_g + assert g.body.body.args[1] == var_my_matmul diff --git a/tests/python/relax/test_printer.py b/tests/python/relax/test_printer.py index daa7f9727ad7..aa6b3fd8c2b4 100644 --- a/tests/python/relax/test_printer.py +++ b/tests/python/relax/test_printer.py @@ -25,6 +25,8 @@ def check_roundtrip(f_pre): f_post = rx.parser.fromtext(rx.parser.astext(f_pre)) + if isinstance(f_pre, tvm.IRModule): + f_post = f_post() assert_structural_equal(f_pre, f_post, map_free_vars=True) @@ -140,9 +142,9 @@ def my_matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None: def test_call_packed(): @rx.script - def foo(x: Tensor[(3, 4), "float32"]): + def foo(x: Tensor[(3, 3), "float32"]): # test that we can intro dim vars - z: Tensor[(n, m), "float32"] = relax.call_packed("contrib.my_matmul", (x, x), mp=False) + z: Tensor[(n, m), "float32"] = relax.call_packed("contrib.my_matmul", x, x, mp=False) return z check_roundtrip(foo) @@ -169,11 +171,27 @@ def foo(x: Tensor): def test_class_irmodule(): @rx.script - class my_module: - def f(x: Tensor[(n, m), _]) -> Tensor: + class MyModule: + @tvm.script.tir + def my_matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, [128, 128]) + B = tir.match_buffer(b, [128, 128]) + C = tir.match_buffer(c, [128, 128]) + + with tir.block([128, 128, tir.reduce_axis(0, 128)], "update") as [vi, vj, vk]: + with tir.init(): + C[vi, vj] = tir.float32(0) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + + def f(x: Tensor[(n, n), _]) -> Tensor: return g(x) - def g(y: Tensor[(n, m), _]) -> Tensor: - return y + def g(y: Tensor[(n, n), _]) -> Tensor: + return relax.call_dps((n, n), my_matmul, (y, y)) + + def h(x, y, z): + _ = my_matmul(x, y, z) + return z - check_roundtrip(my_module) + mod = MyModule() + check_roundtrip(mod)