diff --git a/include/tvm/ir/op.h b/include/tvm/ir/op.h
index 6831700264516..eab2f9226bb40 100644
--- a/include/tvm/ir/op.h
+++ b/include/tvm/ir/op.h
@@ -146,6 +146,8 @@ class OpNode : public RelayExprNode {
   // Internal function to compute if it is primitive op
   bool IsPrimitiveOp_() const {
     const auto& fn_ty = this->op_type;
-    ICHECK(fn_ty.get() != nullptr) << "op_type of " << this->name << " is not registered";
+    if (!fn_ty.get()) {
+      return false;
+    }
     if (fn_ty->type_constraints.size() != 1) return false;
     const TypeRelationNode* rel = fn_ty->type_constraints[0].as<TypeRelationNode>();
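Context for this guard: Relax operators are registered without a Relay-style `op_type` function signature, so primitive-op queries on them used to trip the `ICHECK`; with the early return they simply report "not a primitive op". A minimal sketch of the situation (assuming the op below is registered by the Relax op registry in this branch):

```python
import tvm

# "relax.call_dps" carries no Relay type relation, so its op_type is unset;
# internal primitive-op checks on it now return False instead of crashing.
op = tvm.ir.Op.get("relax.call_dps")
```
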
diff --git a/python/tvm/relax/parser.py b/python/tvm/relax/parser.py
index d39ae8ba63579..397eebe031406 100644
--- a/python/tvm/relax/parser.py
+++ b/python/tvm/relax/parser.py
@@ -23,33 +23,6 @@
 import tvm.relax as rx
 
 
-def pretty_print(node):
-    """Prints the given Relax IR node in the Relax text format.
-
-    Parameters
-    ----------
-    node : Union[rx.Type, rx.Expr, rx.Binding, rx.BindingBlock]
-        The Relax IR node to print.
-    """
-    print(tvm.script._ffi_api.AsRelaxScript(node))
-
-
-def astext(node) -> str:
-    """Returns the Relax text format representation of the given Relax IR node.
-
-    Parameters
-    ----------
-    node : Union[rx.Type, rx.Expr, rx.Binding, rx.BindingBlock]
-        The Relax IR node to print.
-
-    Returns
-    -------
-    str
-        The text format representation of the given Relax IR node.
-    """
-    return tvm.script._ffi_api.AsRelaxScript(node)
-
-
 def _is_registered(op_name: str, op_set=None) -> bool:
     """Returns whether or not the given operator is registered.
 
@@ -130,10 +103,9 @@ class ArithmeticOp(Enum):
 
 
 class RelaxTransformer(Transformer):
-    def __init__(self, definition_scope):
+    def __init__(self):
         super().__init__()
-        self.definition_scope = definition_scope
-        self.module = {}
+        self.module = tvm.IRModule()
         self._scopes = [{}]  # str -> Var
         self._registered_ops = set(tvm.ir._ffi_api.ListOpNames())  # cached
 
@@ -415,7 +387,7 @@ def parse_primexpr(self, expr: ast.Expr, bind_free_vars: bool) -> tir.PrimExpr:
             self.report_error(f"unsupported dimension expression: {expr}", expr.span)
 
     def transform_module(self, mod: ast.Module) -> IRModule:
-        """Transforms the given synr Module to a Relax IRModule.
+        """Transforms the given synr Module to a Relax IRModule or Function.
 
         Parameters
         ----------
@@ -424,13 +396,28 @@ def transform_module(self, mod: ast.Module) -> IRModule:
 
         Returns
        -------
-        IRModule
-            The parsed Relax IRModule
+        Union[IRModule, Function]
+            The parsed Relax IRModule or Function
         """
-        for func_name in mod.funcs:
-            func = mod.funcs[func_name]
-            self.module[func_name] = self.transform_function(func, is_global=True)
-        return self.module
+        if len(mod.funcs) != 1:
+            self.report_error(
+                "the input must be either a single function or a single class", mod.span
+            )
+
+        (root_func,) = mod.funcs.values()
+
+        if isinstance(root_func, ast.Function):
+            return self.transform_function(root_func, is_global=True)
+        elif isinstance(root_func, ast.Class):
+            # add global vars to the root scope for resolving global function calls
+            for func_name in root_func.funcs:
+                self.scope[func_name] = relay.GlobalVar(func_name)
+            for func_name, func in root_func.funcs.items():
+                global_var = self.scope[func_name]
+                self.module[global_var] = self.transform_function(func, is_global=True)
+            return self.module
+        else:
+            self.report_error(f"unsupported input: {root_func}", root_func.span)
 
     def _parse_attrs_to_str(self, expr: ast.Attr) -> str:
         strs = []
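The two accepted inputs, as a usage sketch (names are illustrative; the class form mirrors `test_class_irmodule` in the tests below):

```python
import tvm
import tvm.relax as rx

@rx.script
def single_func(x: Tensor[(n, m), "float32"]) -> Tensor:  # parses to an rx.Function
    return add(x, x)

@rx.script
class TwoFuncs:  # parses to a tvm.IRModule with GlobalVars "f" and "g"
    def f(x: Tensor[(n, m), _]) -> Tensor:
        return g(x)

    def g(y: Tensor[(n, m), _]) -> Tensor:
        return y
```
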
@@ -804,6 +791,104 @@ def parse_dataflow(self, block: ast.Block) -> rx.DataflowBlock:
 
         return rx.DataflowBlock(bindings, self.to_tvm_span(block.span))
 
+    def parse_attr(self, expr: ast.Attr) -> rx.Expr:
+        """Parses the given synr Attr node to a Relax expression.
+
+        Parameters
+        ----------
+        expr : ast.Attr
+            The synr Attr node to be parsed.
+
+        Returns
+        -------
+        rx.Expr
+            The parsed expression.
+        """
+        if expr.field.name == "shape":
+            obj = self.transform_expr(expr.object)
+            attrs = tvm.ir.attrs.make_node("relay.attrs.ShapeOfAttrs", dtype="int32")
+            return relay.Call(
+                relay.op.get("shape_of"), [obj], attrs=attrs, span=self.to_tvm_span(expr.span)
+            )
+        else:
+            # assume it's a hierarchical op identifier (e.g. nn.softmax, relax.call_dps)
+            op_name = self._parse_attrs_to_str(expr)
+            # NOTE: at least for now, all special operators are namespaced
+            try:
+                return SpecialOp(op_name)
+            except ValueError:
+                # TODO(@altanh): maybe diagnostics here in case this fails?
+                return relay.op.get(op_name)
+
+    def parse_call(self, expr: ast.Call) -> Union[tir.PrimExpr, rx.Expr]:
+        """Parses the given synr Call node to a Relax expression or PrimExpr.
+
+        Parameters
+        ----------
+        expr : ast.Call
+            The synr Call node to be parsed.
+
+        Returns
+        -------
+        Union[tir.PrimExpr, rx.Expr]
+            The parsed expression. It will be a PrimExpr if expr is an arithmetic operation on
+            PrimExprs.
+        """
+        op = self.transform_expr(expr.func_name)
+
+        if op == SpecialOp.CALL_PACKED:
+            if len(expr.params) != 2:
+                self.report_error(
+                    op.value + " takes an extern function name and a tuple of arguments",
+                    expr.span,
+                )
+            extern_func = expr.params[0]
+            if not (isinstance(extern_func, ast.Constant) and isinstance(extern_func.value, str)):
+                self.report_error(
+                    "the first argument of " + op.value + " must be the extern function name",
+                    extern_func.span,
+                )
+            op = rx.ExternFunc(extern_func.value, self.to_tvm_span(extern_func.span))
+            args = [self.transform_expr(expr.params[1])]
+
+        elif isinstance(op, ArithmeticOp):
+            args = [self.transform_expr(arg) for arg in expr.params]
+            if all([isinstance(arg, tir.PrimExpr) for arg in args]):
+                return PRIMEXPR_ARITHMETIC_OP_MAP[op](*args, span=self.to_tvm_span(expr.span))
+            # otherwise it's just a normal Relax operator call
+            op = RELAX_ARITHMETIC_OP_MAP[op]
+
+        elif isinstance(op, tvm.ir.Op):
+            args = [self.transform_expr(arg) for arg in expr.params]
+            # check call arity eagerly
+            if op.num_inputs != -1 and len(args) != op.num_inputs:
+                self.report_error(
+                    f"{op.name} expects {op.num_inputs} arguments but got {len(args)}", expr.span
+                )
+            if op.name == "relax.call_dps" and isinstance(args[1], str):
+                # extern function call case: rewrite identifier to an ExternFunc
+                args[1] = rx.ExternFunc(args[1], self.to_tvm_span(expr.params[1].span))
+
+        elif isinstance(op, relay.Expr):
+            args = [self.transform_expr(arg) for arg in expr.params]
+
+        else:
+            self.report_error(f"unsupported function in call: {op}", expr.func_name.span)
+
+        # parse call attributes if applicable
+        if isinstance(op, rx.ExternFunc) or (isinstance(op, tvm.ir.Op) and op.attrs_type_key != ""):
+            attrs_type_key = "DictAttrs" if isinstance(op, rx.ExternFunc) else op.attrs_type_key
+            kwargs = {}
+            for key, val in expr.keyword_params.items():
+                assert isinstance(key, ast.Constant) and isinstance(key.value, str)
+                # TODO(@altanh): might need separate attribute parsing eventually
+                kwargs[key.value] = self.transform_expr(val)
+            attrs = tvm.ir.attrs.make_node(attrs_type_key, **kwargs)
+        else:
+            attrs = None
+
+        return relay.Call(op, args, attrs=attrs, span=self.to_tvm_span(expr.span))
+
     # Exprs:
     # - ArrayLiteral: unsupported for now?
     # - Attr: use for .shape, and intrinsic/special operator namespace
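Both extern-call spellings handled above, combined into one sketch (the packed-function names are placeholders taken from the new tests):

```python
@rx.script
def call_externs(x: Tensor[(3, 4), "float32"]):
    # relax.call_packed: the first argument names the packed function, and
    # keyword arguments become DictAttrs on the resulting call
    y = relax.call_packed("contrib.my_matmul", (x, x), mp=False)
    # relax.call_dps: a string literal in the function position is rewritten
    # to an rx.ExternFunc node
    z = relax.call_dps((10,), "my_extern", (x,))
    return z
```
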
@@ -827,65 +912,10 @@ def transform_expr(self, expr: ast.Expr) -> rx.Expr:
             The corresponding Relax expression
         """
         if isinstance(expr, ast.Attr):
-            if expr.field.name == "shape":
-                obj = self.transform_expr(expr.object)
-                attrs = tvm.ir.attrs.make_node("relay.attrs.ShapeOfAttrs", dtype="int32")
-                return relay.Call(
-                    relay.op.get("shape_of"), [obj], attrs=attrs, span=self.to_tvm_span(expr.span)
-                )
-            else:
-                # assume it's a hierarchical op identifier (e.g. nn.softmax, relax.call_dps)
-                op_name = self._parse_attrs_to_str(expr)
-                # NOTE: at least for now, all special operators are namespaced
-                try:
-                    return SpecialOp(op_name)
-                except ValueError:
-                    # TODO(@altanh): maybe diagnostics here in case this fails?
-                    return relay.op.get(op_name)
-
-        if isinstance(expr, ast.Call):
-            # TODO(@altanh): support parsing kwargs as attributes?
-            op = self.transform_expr(expr.func_name)
-            if op == SpecialOp.CALL_PACKED:
-                if len(expr.params) != 2:
-                    self.report_error(
-                        op.value + " takes an extern function name and a tuple of arguments",
-                        expr.span,
-                    )
-                extern_func = expr.params[0]
-                if not (
-                    isinstance(extern_func, ast.Constant) and isinstance(extern_func.value, str)
-                ):
-                    self.report_error(
-                        "the first argument of " + op.value + " must be the extern function name",
-                        extern_func.span,
-                    )
-                op = rx.ExternFunc(extern_func.value, self.to_tvm_span(extern_func.span))
-                args = [self.transform_expr(expr.params[1])]
-            elif isinstance(op, ArithmeticOp):
-                args = [self.transform_expr(arg) for arg in expr.params]
-                if all([isinstance(arg, tir.PrimExpr) for arg in args]):
-                    return PRIMEXPR_ARITHMETIC_OP_MAP[op](*args, span=self.to_tvm_span(expr.span))
-                # otherwise it's just a normal Relax operator call
-                op = RELAX_ARITHMETIC_OP_MAP[op]
-            elif isinstance(op, (tvm.ir.Op, relay.Expr)):
-                args = [self.transform_expr(arg) for arg in expr.params]
-            else:
-                self.report_error(f"unsupported function in call: {op}", expr.func_name.span)
+            return self.parse_attr(expr)
 
-            if isinstance(op, rx.ExternFunc) or (
-                isinstance(op, tvm.ir.Op) and op.attrs_type_key != ""
-            ):
-                attrs_type_key = "DictAttrs" if isinstance(op, rx.ExternFunc) else op.attrs_type_key
-                kwargs = {}
-                for key, val in expr.keyword_params.items():
-                    assert isinstance(key, ast.Constant) and isinstance(key.value, str)
-                    kwargs[key.value] = self.transform_expr(val)
-                attrs = tvm.ir.attrs.make_node(attrs_type_key, **kwargs)
-            else:
-                attrs = None
-            # TODO(@altanh): should we check for correct arity here eagerly, or defer to a pass?
-            return relay.Call(op, args, attrs=attrs, span=self.to_tvm_span(expr.span))
+        elif isinstance(expr, ast.Call):
+            return self.parse_call(expr)
 
         elif isinstance(expr, ast.Tuple):
             fields = [self.transform_expr(field) for field in expr.values]
@@ -1015,61 +1045,67 @@ def transform_block(self, block: ast.Block) -> rx.SeqExpr:
     # self.tvm_diag_ctx.render()
 
 
-# TODO(@altanh, @jroesch): revisit this?
-class RelaxDecoratedFn:
-    def __init__(self, fn_name, relax_module, diag_ctx):
-        self.fn_name = fn_name
-        self.module = relax_module
-        self.diag_ctx = diag_ctx
-
-    def __call__(self, *args):
-        pretty_print(self.module[self.fn_name])
-        # compiler = Compiler(self.diag_ctx, self.module, self.fn_name)
-        # compiled_f = compiler.compile(execute=True)
-        # # Actually compute needed buffer sizes.
-        # out = tvm.nd.array(np.random.rand(10).astype('float32'))
-        # compiled_f(*(list(args) + [out]))
-        # return out
-
-
-def script(f) -> RelaxDecoratedFn:
-    """Parses the decorated Relax function (in Relax IR) to a Relax AST.
+def script(f) -> Union[rx.Function, tvm.IRModule]:
+    """Parses the decorated Relax function or module (in Relax IR) to a Relax AST.
 
     Parameters
     ----------
-    f : function
-        The function to be parsed, written in the Relax IR
+    f : Union[function, class]
+        The function or class to be parsed, written in the Relax IR.
 
     Returns
     -------
-    RelaxDecoratedFn
-        The parsed Relax function
+    Union[rx.Function, IRModule]
+        The parsed Relax function or IRModule.
    """
-    # ir_module = tvm.IRModule({})
-    # diag_ctx = diagnostics.DiagnosticContext(ir_module, diagnostics.get_renderer())
     diag_ctx = tvm.script.diagnostics.TVMDiagnosticCtx()
     ast = synr.to_ast(f, diag_ctx)
-    definition_scope = inspect.getmodule(f)
-    module = RelaxTransformer(definition_scope).do_transform(ast, diag_ctx)
-    return RelaxDecoratedFn(f.__name__, module, diag_ctx)
+    return RelaxTransformer().do_transform(ast, diag_ctx)
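Callers that previously unwrapped the decorator's result now receive the parsed object directly, as the updated tests below show. A before/after sketch:

```python
@rx.script
def foo(x: Tensor[_, "float32"]):
    return x

# Previously: f = foo.module[foo.fn_name]
f = foo  # the decorator now returns the rx.Function itself
```
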
""" - # ir_module = tvm.IRModule({}) - # diag_ctx = diagnostics.DiagnosticContext(ir_module, diagnostics.get_renderer()) diag_ctx = tvm.script.diagnostics.TVMDiagnosticCtx() ast = synr.to_ast(f, diag_ctx) - definition_scope = inspect.getmodule(f) - module = RelaxTransformer(definition_scope).do_transform(ast, diag_ctx) - return RelaxDecoratedFn(f.__name__, module, diag_ctx) + return RelaxTransformer().do_transform(ast, diag_ctx) -def fromtext(source: str, source_name: str = "from_string"): - """Parses the given input string (in the Relax text format) to a Relax AST. +def fromtext(source: str, source_name: str = "from_string") -> Union[rx.Function, tvm.IRModule]: + """Parses the given input string (in the Relax text format) to a Relax function or IRModule. Parameters ---------- source : str - The input source string. + The input source string. It should be either a decorated Python class or function. source_name : str, optional A descriptive name for error reporting, by default "from_string". Returns ------- - Relax AST - The parsed Relax AST. + Union[rx.Function, IRModule] + The parsed Relax function or IRModule. """ + # TODO(@altanh): actually use source_name somewhere? diag_ctx = tvm.script.diagnostics.TVMDiagnosticCtx() ast = synr.to_ast(source, diag_ctx) - module = RelaxTransformer(None).do_transform(ast, diag_ctx) - return module + return RelaxTransformer().do_transform(ast, diag_ctx) + + +def pretty_print(node): + """Prints the given Relax IR node in the Relax text format. + + Parameters + ---------- + node : Union[rx.Type, rx.Expr, rx.Binding, rx.BindingBlock] + The Relax IR node to print. + """ + print(tvm.script._ffi_api.AsRelaxScript(node)) + + +def astext(node) -> str: + """Returns the Relax text format representation of the given Relax IR node. + + Parameters + ---------- + node : Union[rx.Type, rx.Expr, rx.Binding, rx.BindingBlock] + The Relax IR node to print. + + Returns + ------- + str + The text format representation of the given Relax IR node. 
+ """ + return tvm.script._ffi_api.AsRelaxScript(node) diff --git a/src/printer/relax_script_printer.cc b/src/printer/relax_script_printer.cc index 9a7b4a126b2e0..627d4da7602f9 100644 --- a/src/printer/relax_script_printer.cc +++ b/src/printer/relax_script_printer.cc @@ -76,6 +76,8 @@ class RelaxScriptPrinter : public relax::IRFunctor, Doc VisitExpr_(const tir::DivNode* op) override; Doc VisitExpr_(const tir::FloorDivNode* op) override; + Doc PrintIRModule(const IRModule& mod); + Doc PrintIfStmt(const relax::Var& var, const relay::If& ite); Doc PrintFunctionDef(const Doc& name, const relax::Function& func); @@ -135,7 +137,9 @@ class RelaxScriptPrinter : public relax::IRFunctor, }; Doc RelaxScriptPrinter::Print(const ObjectRef& node) { - if (node->IsInstance()) { + if (node->IsInstance()) { + return PrintIRModule(Downcast(node)); + } else if (node->IsInstance()) { return VisitType(Downcast(node)); } else if (node->IsInstance()) { return VisitExpr(Downcast(node)); @@ -418,6 +422,15 @@ Doc RelaxScriptPrinter::VisitAttr_(const tir::FloatImmNode* op) { return Doc::Text(std::to_string(op->value)); } +Doc RelaxScriptPrinter::PrintIRModule(const IRModule& mod) { + Doc doc; + doc << "class Module:"; + for (const std::pair& pr : mod->functions) { + doc << Doc::Indent(4, Doc::NewLine() << Print(pr.second)); + } + return doc; +} + Doc RelaxScriptPrinter::PrintIfStmt(const relax::Var& var, const relay::If& ite) { const relax::SeqExprNode* true_branch = ite->true_branch.as(); const relax::SeqExprNode* false_branch = ite->false_branch.as(); @@ -518,8 +531,8 @@ Doc RelaxScriptPrinter::GetUniqueName(std::string prefix, std::string fallback = } String AsRelaxScript(const ObjectRef& mod) { - ICHECK(mod->IsInstance()); - return "@tvm.script.relax\n" + RelaxScriptPrinter().Print(mod).str() + "\n"; + ICHECK(mod->IsInstance() || mod->IsInstance()); + return "@tvm.script.relax\n" + RelaxScriptPrinter().Print(mod).str(); } TVM_REGISTER_GLOBAL("script.AsRelaxScript").set_body_typed(AsRelaxScript); diff --git a/tests/python/relax/test_parser.py b/tests/python/relax/test_parser.py index 95df0ea9250e1..5785e791a66d1 100644 --- a/tests/python/relax/test_parser.py +++ b/tests/python/relax/test_parser.py @@ -26,10 +26,6 @@ # c.f. 
diff --git a/tests/python/relax/test_parser.py b/tests/python/relax/test_parser.py
index 95df0ea9250e1..5785e791a66d1 100644
--- a/tests/python/relax/test_parser.py
+++ b/tests/python/relax/test_parser.py
@@ -26,10 +26,6 @@
 # c.f. tests/python/unittest/test_tvmscript_error_report.py
 
 
-def rx_func(func):
-    return func.module[func.fn_name]
-
-
 def check_shape(e, s):
     if isinstance(e, rx.Expr):
         e = e.shape_
@@ -69,7 +65,7 @@ def check_call(call, op, args):
 
 def test_annotations():
     @rx.script
-    def foo(x: Tensor[(32, m), "float32"], y: Tensor[(m, k), "float32"]) -> Tensor:
+    def f(x: Tensor[(32, m), "float32"], y: Tensor[(m, k), "float32"]) -> Tensor:
         z: Tensor[(32, k), "float32"] = nn.matmul(x, y, units=None)
         w: Tensor[_, _] = multiply(z, z)
         q: Tensor[(_, _), _] = add(w, w)
@@ -77,7 +73,6 @@ def foo(x: Tensor[(32, m), "float32"], y: Tensor[(m, k), "float32"]) -> Tensor:
         sh: Shape = t.shape
         return t
 
-    f = rx_func(foo)
     x, y = f.params
     z_bind, w_bind, q_bind, t_bind, sh_bind = f.body.blocks[0].bindings
     z, mm = z_bind.var, z_bind.value
@@ -106,12 +101,11 @@ def foo(x: Tensor[(32, m), "float32"], y: Tensor[(m, k), "float32"]) -> Tensor:
 
 def test_match_shape():
     @rx.script
-    def foo(x: Tensor[_, "float32"]):
+    def f(x: Tensor[_, "float32"]):
         relax.match_shape(x.shape, (n, m))
         y: Tensor[(n, m), "float32"] = add(x, x)
         return x
 
-    f = rx_func(foo)
     match_sh = f.body.blocks[0].bindings[0]
     pattern, value = match_sh.pattern, match_sh.value
 
@@ -122,14 +116,14 @@ def foo(x: Tensor[_, "float32"]):
 @pytest.mark.xfail
 def test_dim_var_intro_fail():
     @rx.script
-    def foo(x: Tensor[_, _]):
+    def f(x: Tensor[_, _]):
         y: Tensor[(n, m), "float32"] = x
         return y
 
 
 def test_if():
     @rx.script
-    def foo(cond: Tensor[(), "bool"], x: Tensor[(1,), "float32"]):
+    def f(cond: Tensor[(), "bool"], x: Tensor[(1,), "float32"]):
         if cond:
             w = add(x, x)
             y = multiply(w, w)
@@ -138,7 +132,6 @@ def foo(cond: Tensor[(), "bool"], x: Tensor[(1,), "float32"]):
             y = add(w, w)
         return y
 
-    f = rx_func(foo)
     cond, x = f.params
     y_bind = f.body.blocks[0].bindings[0]
     y, ite = y_bind.var, y_bind.value
@@ -172,7 +165,7 @@ def foo(cond: Tensor[(), "bool"], x: Tensor[(1,), "float32"]):
 @pytest.mark.xfail
 def test_var_redefine_fail():
     @rx.script
-    def foo(x, y):
+    def f(x, y):
         z = add(x, y)
         y = z
         return y
@@ -181,7 +174,7 @@ def foo(x, y):
 @pytest.mark.xfail
 def test_var_redefine_fail_if():
     @rx.script
-    def foo(cond: Tensor[(), "bool"], x: Tensor[(1,), "float32"]):
+    def f(cond: Tensor[(), "bool"], x: Tensor[(1,), "float32"]):
         y = x
         if cond:
             w = add(x, x)
@@ -195,7 +188,7 @@ def foo(cond: Tensor[(), "bool"], x: Tensor[(1,), "float32"]):
 @pytest.mark.xfail
 def test_var_if_scoping_fail():
     @rx.script
-    def foo(cond: Tensor[(), "bool"], x: Tensor[(1,), "float32"]):
+    def f(cond: Tensor[(), "bool"], x: Tensor[(1,), "float32"]):
         if cond:
             w = add(x, x)
             y = multiply(w, w)
@@ -208,7 +201,7 @@ def foo(cond: Tensor[(), "bool"], x: Tensor[(1,), "float32"]):
 @pytest.mark.xfail
 def test_if_mismatch_var_fail():
     @rx.script
-    def foo(cond: Tensor[(), "bool"], x: Tensor[(1,), "float32"]):
+    def f(cond: Tensor[(), "bool"], x: Tensor[(1,), "float32"]):
         if cond:
             w = add(x, x)
             y = multiply(w, w)
@@ -221,18 +214,17 @@ def foo(cond: Tensor[(), "bool"], x: Tensor[(1,), "float32"]):
 @pytest.mark.xfail
 def test_unassigned_call_fail():
     @rx.script
-    def foo(x: Tensor[_, _]):
+    def f(x: Tensor[_, _]):
         add(x, x)
         return x
 
 
 def test_tuple():
     @rx.script
-    def foo(x: Tensor[_, _], y: Tensor[(32,), "float32"]):
+    def f(x: Tensor[_, _], y: Tensor[(32,), "float32"]):
         t: Tuple[Tensor[_, _], Tensor[(32,), "float32"]] = (x, y)
         return t
 
-    f = rx_func(foo)
     x, y = f.params
     t_bind = f.body.blocks[0].bindings[0]
     t, tup = t_bind.var, t_bind.value
@@ -253,14 +245,13 @@ def foo(x: Tensor[_, _], y: Tensor[(32,), "float32"]):
 
 def test_local_func():
     @rx.script
-    def foo(x: Tensor[_, _]):
+    def f(x: Tensor[_, _]):
         def bar(y: Tensor[_, _]):
             return y
 
         y = bar(x)  # tests local function variable scoping
         return y
 
-    f = rx_func(foo)
     bar_bind, y_bind = f.body.blocks[0].bindings
     bar, bar_fn = bar_bind.var, bar_bind.value
     bar_x = y_bind.value
@@ -273,7 +264,7 @@ def bar(y: Tensor[_, _]):
 
 def test_dataflow():
     @rx.script
-    def foo(x: Tensor[_, _]):
+    def f(x: Tensor[_, _]):
         with relax.dataflow():
             y = add(x, x)
             z = multiply(y, x)
@@ -282,7 +273,6 @@ def foo(x: Tensor[_, _]):
             w = subtract(z, y)
             relax.output(y, w)
         t = divide(y, w)
        return t
 
-    f = rx_func(foo)
     assert len(f.body.blocks) == 2
     df_block = f.body.blocks[0]
     y_bind, z_bind, w_bind = df_block.bindings
@@ -304,7 +294,7 @@ def foo(x: Tensor[_, _]):
 
 def test_dataflow_match_shape():
     @rx.script
-    def foo(x: Tensor[_, _]):
+    def f(x: Tensor[_, _]):
         with relax.dataflow():
             x2: Tensor[(n, m), _] = relax.match_shape(x, (n, m))
             y = add(x2, x2)
@@ -316,7 +306,6 @@ def foo(x: Tensor[_, _]):
         q: Tensor[(n, m), _] = add(t, x2)
         return q
 
-    f = rx_func(foo)
     x = f.params[0]
     df_block = f.body.blocks[0]
     x2_bind = df_block.bindings[0]
@@ -336,7 +325,7 @@ def foo(x: Tensor[_, _]):
 @pytest.mark.xfail
 def test_dataflow_scope_fail():
     @rx.script
-    def foo(x: Tensor[_, _]):
+    def f(x: Tensor[_, _]):
         with relax.dataflow():
             y = add(x, x)
             z = multiply(y, x)
@@ -349,7 +338,7 @@ def foo(x: Tensor[_, _]):
 @pytest.mark.xfail
 def test_dataflow_syntax_fail_pattern():
     @rx.script
-    def foo(x: Tensor[_, _]):
+    def f(x: Tensor[_, _]):
         with relax.dataflow() as df:
             y = add(x, x)
             z = multiply(y, x)
@@ -362,7 +351,7 @@ def foo(x: Tensor[_, _]):
 @pytest.mark.xfail
 def test_dataflow_syntax_fail_params():
     @rx.script
-    def foo(x: Tensor[_, _]):
+    def f(x: Tensor[_, _]):
         with relax.dataflow(x) as df:
             y = add(x, x)
             z = multiply(y, x)
@@ -375,7 +364,7 @@ def foo(x: Tensor[_, _]):
 @pytest.mark.xfail
 def test_dataflow_unbound_outputs():
     @rx.script
-    def foo(x: Tensor[_, _]):
+    def f(x: Tensor[_, _]):
         with relax.dataflow():
             y = add(x, x)
             z = multiply(y, x)
@@ -388,7 +377,7 @@ def foo(x: Tensor[_, _]):
 @pytest.mark.xfail
 def test_invalid_special_op_dataflow():
     @rx.script
-    def foo(x: Tensor):
+    def f(x: Tensor):
         y = add(x, x)
         z = relax.dataflow()
         return z
@@ -397,7 +386,7 @@ def foo(x: Tensor):
 @pytest.mark.xfail
 def test_invalid_special_op_output():
     @rx.script
-    def foo(x: Tensor):
+    def f(x: Tensor):
         y = add(x, x)
         z = relax.output(y)
         return z
@@ -406,13 +395,13 @@ def foo(x: Tensor):
 @pytest.mark.xfail
 def test_func_no_return_fail():
     @rx.script
-    def foo(x: Tensor[_, _]):
+    def f(x: Tensor[_, _]):
         y = add(x, x)
 
 
 def test_inline_tir():
     @rx.script
-    def foo(x: Tensor[(B, 128), "float32"], y: Tensor[(128, 128), "float32"]):
+    def f(x: Tensor[(B, 128), "float32"], y: Tensor[(128, 128), "float32"]):
         @tvm.script.tir
         def my_matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
             A = tir.match_buffer(a, [128, 128])
@@ -427,7 +416,6 @@ def my_matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
         z = relax.call_dps((B, 128), my_matmul, (x, y))
         return z
 
-    f = rx_func(foo)
     x, y = f.params
     B = x.shape_[0]
     mm_bind, z_bind = f.body.blocks[0].bindings
@@ -444,12 +432,11 @@ def my_matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
 
 def test_call_packed():
     @rx.script
-    def foo(x: Tensor[(3, 4), "float32"]):
+    def f(x: Tensor[(3, 4), "float32"]):
         # test that we can intro dim vars
         z: Tensor[(n, m), "float32"] = relax.call_packed("contrib.my_matmul", (x, x), mp=False)
         return z
 
-    f = rx_func(foo)
     x = f.params[0]
     (z_bind,) = f.body.blocks[0].bindings
 
     check_tensor_var(z_bind.var, ("n", "m"), "float32")
 
@@ -462,15 +449,49 @@ def foo(x: Tensor[(3, 4), "float32"]):
 
 def test_primexpr_arithmetic():
     @rx.script
-    def foo(x: Tensor[(n, m), "float32"]):
+    def f(x: Tensor[(n, m), "float32"]):
         z: Tensor[(n * m,), "float32"] = relax.call_packed("my_flatten", (x,))
         sh: Shape = (n + m, n // m)
         return z
 
-    f = rx_func(foo)
     x = f.params[0]
     n, m = x.shape_
     z_bind, sh_bind = f.body.blocks[0].bindings
 
     assert structural_equal(z_bind.var.shape_.values, [tir.Mul(n, m)])
     assert structural_equal(sh_bind.value.values, [tir.Add(n, m), tir.FloorDiv(n, m)])
+
+
+def test_call_dps_extern():
+    @rx.script
+    def f(x: Tensor):
+        z = relax.call_dps((10,), "my_extern", (x,))
+        return z
+
+    x = f.params[0]
+    (z_bind,) = f.body.blocks[0].bindings
+
+    check_call(
+        z_bind.value,
+        "relax.call_dps",
+        [rx.ShapeExpr([tir.IntImm("int32", 10)]), rx.ExternFunc("my_extern"), rx.Tuple([x])],
+    )
+
+
+def test_class_irmodule():
+    @rx.script
+    class my_module:
+        def f(x: Tensor[(n, m), _]) -> Tensor:
+            return g(x)
+
+        def g(y: Tensor[(n, m), _]) -> Tensor:
+            return y
+
+    assert isinstance(my_module, tvm.IRModule)
+
+    var_f = my_module.get_global_var("f")
+    var_g = my_module.get_global_var("g")
+    f = my_module[var_f]
+    g = my_module[var_g]
+
+    assert f.body.body.op == var_g
diff --git a/tests/python/relax/test_printer.py b/tests/python/relax/test_printer.py
index b99b0a9223d52..daa7f9727ad71 100644
--- a/tests/python/relax/test_printer.py
+++ b/tests/python/relax/test_printer.py
@@ -23,13 +23,8 @@
 from tvm.ir import structural_equal, assert_structural_equal
 
 
-def rx_func(func):
-    return func.module[func.fn_name]
-
-
-def check_roundtrip(fn):
-    f_pre = rx_func(fn)
-    f_post = rx.parser.fromtext(rx.parser.astext(f_pre))[fn.fn_name]
+def check_roundtrip(f_pre):
+    f_post = rx.parser.fromtext(rx.parser.astext(f_pre))
     assert_structural_equal(f_pre, f_post, map_free_vars=True)
 
 
@@ -161,3 +156,24 @@ def foo(x: Tensor[(n, m), "float32"]):
         return z
 
     check_roundtrip(foo)
+
+
+def test_call_dps_extern():
+    @rx.script
+    def foo(x: Tensor):
+        z = relax.call_dps((10,), "my_extern", (x,))
+        return z
+
+    check_roundtrip(foo)
+
+
+def test_class_irmodule():
+    @rx.script
+    class my_module:
+        def f(x: Tensor[(n, m), _]) -> Tensor:
+            return g(x)
+
+        def g(y: Tensor[(n, m), _]) -> Tensor:
+            return y
+
+    check_roundtrip(my_module)