From bd0d81a74db6b157ff014446a9916654a5d97dee Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Wed, 8 Feb 2023 22:31:47 +0800 Subject: [PATCH 1/3] [TVMScript] IRModule TVMScript Parser. This PR adds the TVMScript parser/ir_builder support based on the blockbuilder. This commit contains the non-relax portions from https://github.com/apache/tvm/pull/13932. Co-authored-by: Ruihang Lai Co-authored-by: Junru Shao Co-authored-by: Tianqi Chen Co-authored-by: Yuchen Jin Co-authored-by: Steven S. Lyubomirsky Co-authored-by: Yong Wu --- include/tvm/script/ir_builder/ir/frame.h | 11 +++-- include/tvm/script/ir_builder/ir/ir.h | 17 +++++++ python/tvm/script/ir_builder/base.py | 6 ++- python/tvm/script/ir_builder/ir/__init__.py | 2 +- python/tvm/script/ir_builder/ir/ir.py | 45 ++++++++++++++++++ python/tvm/script/parser/core/diagnostics.py | 2 +- python/tvm/script/parser/core/evaluator.py | 2 +- python/tvm/script/parser/core/parser.py | 50 ++++++++++++++------ python/tvm/script/parser/ir/parser.py | 4 ++ python/tvm/script/parser/tir/entry.py | 4 +- python/tvm/script/parser/tir/parser.py | 26 ++++++++++ src/script/ir_builder/ir/frame.cc | 12 +++-- src/script/ir_builder/ir/ir.cc | 32 ++++++++++++- src/script/ir_builder/ir/utils.h | 49 +++++++++++++++++++ src/script/ir_builder/tir/frame.cc | 15 ++++-- src/script/ir_builder/tir/utils.h | 2 +- 16 files changed, 245 insertions(+), 34 deletions(-) create mode 100644 src/script/ir_builder/ir/utils.h diff --git a/include/tvm/script/ir_builder/ir/frame.h b/include/tvm/script/ir_builder/ir/frame.h index 887981ccffc8..dacfc361a6c7 100644 --- a/include/tvm/script/ir_builder/ir/frame.h +++ b/include/tvm/script/ir_builder/ir/frame.h @@ -38,12 +38,17 @@ namespace ir { */ class IRModuleFrameNode : public IRBuilderFrameNode { public: - Array global_vars; - Array functions; + /*! \brief A map from string names to global variables that ensures global uniqueness. */ + Map global_var_map; + /*! + * \brief A map from GlobalVar to all global functions. + * \note Only defined functions are in the map, while declared functions are not included. + */ + Map functions; void VisitAttrs(tvm::AttrVisitor* v) { IRBuilderFrameNode::VisitAttrs(v); - v->Visit("global_vars", &global_vars); + v->Visit("global_vars", &global_var_map); v->Visit("functions", &functions); } diff --git a/include/tvm/script/ir_builder/ir/ir.h b/include/tvm/script/ir_builder/ir/ir.h index f0e7cc6f5c2f..49bdcf60e6fb 100644 --- a/include/tvm/script/ir_builder/ir/ir.h +++ b/include/tvm/script/ir_builder/ir/ir.h @@ -37,6 +37,23 @@ namespace ir { */ TVM_DLL IRModuleFrame IRModule(); +/*! + * \brief Declare a Function without given the specific function implementation. + * \note It is usually used in cross-function call. And we can specify the function by `DefFunction` + * \param func_name The function unique name. + * \param func_signature A Function w/o body, which used to specify the function signature + * (i.e. func params and func return type/shape). + * \return The corresponding GlobalVar. + */ +TVM_DLL GlobalVar DeclFunction(const String& func_name, const BaseFunc& func_signature); + +/*! + * \brief Define the function which is declared before. + * \param func_name The function unique name. + * \param func The given function implementation + */ +TVM_DLL void DefFunction(const String& func_name, const BaseFunc& func); + } // namespace ir } // namespace ir_builder } // namespace script diff --git a/python/tvm/script/ir_builder/base.py b/python/tvm/script/ir_builder/base.py index 7aa33ee49c72..b35bbd0a7df5 100644 --- a/python/tvm/script/ir_builder/base.py +++ b/python/tvm/script/ir_builder/base.py @@ -64,8 +64,10 @@ def __enter__(self) -> "IRBuilderFrame": _ffi_api.IRBuilderFrameEnter(self) # type: ignore[attr-defined] # pylint: disable=no-member return self - def __exit__(self, ptype, value, trace) -> None: # pylint: disable=unused-argument - _ffi_api.IRBuilderFrameExit(self) # type: ignore[attr-defined] # pylint: disable=no-member + def __exit__(self, exc_type, exc_value, trace) -> None: # pylint: disable=unused-argument + if exc_type is None and exc_value is None: + # Do not execute `FrameExit` if the with scope exits because of exceptions + _ffi_api.IRBuilderFrameExit(self) # type: ignore[attr-defined] # pylint: disable=no-member def add_callback(self, callback: Callable[[], None]) -> None: """Add a callback method invoked when exiting the with-scope. diff --git a/python/tvm/script/ir_builder/ir/__init__.py b/python/tvm/script/ir_builder/ir/__init__.py index ebb9728737ad..946be263a779 100644 --- a/python/tvm/script/ir_builder/ir/__init__.py +++ b/python/tvm/script/ir_builder/ir/__init__.py @@ -16,4 +16,4 @@ # under the License. """Package tvm.script.ir_builder.ir""" from .frame import IRModuleFrame -from .ir import ir_module +from .ir import decl_function, def_function, ir_module diff --git a/python/tvm/script/ir_builder/ir/ir.py b/python/tvm/script/ir_builder/ir/ir.py index 213180463cb2..796d6f3aad04 100644 --- a/python/tvm/script/ir_builder/ir/ir.py +++ b/python/tvm/script/ir_builder/ir/ir.py @@ -16,9 +16,54 @@ # under the License. """Package tvm.script.ir_builder.ir.ir""" +from tvm.ir import BaseFunc, GlobalVar + from . import _ffi_api from .frame import IRModuleFrame def ir_module() -> IRModuleFrame: + """Start a ir_module frame. + Returns + ------- + frame: IRModuleFrame + The constructed frame. + """ return _ffi_api.IRModule() # type: ignore[attr-defined] # pylint: disable=no-member + + +def decl_function(func_name: str, func_signature: BaseFunc) -> GlobalVar: + """Declare a Function without given the specific function implementation. + Parameters + ---------- + func_name : str + The function unique name. + + func_signature: Optional[BaseFunc] + A Function w/o body, which used to specify the function signature + (i.e. func params and func return type/shape). + + Note + ---- + It is usually used in cross-function call. And we can specify the function by `DefFunction` + Returns + ------- + gv : GlobalVar + The corresponding GlobalVar. + """ + + return _ffi_api.DeclFunction( # type: ignore[attr-defined] # pylint: disable=no-member + func_name, func_signature + ) + + +def def_function(func_name: str, func: BaseFunc) -> None: + """Define the function which is declared before. + Parameters + ---------- + func_name : str + The function unique name. + func: BaseFunc + The given function implementation + """ + return _ffi_api.DefFunction(func_name, func) # type: ignore[attr-defined] # pylint: disable=no-member diff --git a/python/tvm/script/parser/core/diagnostics.py b/python/tvm/script/parser/core/diagnostics.py index ad7ae5034780..2767a97f6096 100644 --- a/python/tvm/script/parser/core/diagnostics.py +++ b/python/tvm/script/parser/core/diagnostics.py @@ -220,7 +220,7 @@ def _emit(self, node: doc.AST, message: str, level: diagnostics.DiagnosticLevel) level : diagnostics.DiagnosticLevel The diagnostic level. """ - lineno = node.lineno or self.source.start_line + lineno = node.lineno or 1 col_offset = node.col_offset or self.source.start_column end_lineno = node.end_lineno or lineno end_col_offset = node.end_col_offset or col_offset diff --git a/python/tvm/script/parser/core/evaluator.py b/python/tvm/script/parser/core/evaluator.py index 3a72a3c33106..075aedd89146 100644 --- a/python/tvm/script/parser/core/evaluator.py +++ b/python/tvm/script/parser/core/evaluator.py @@ -203,7 +203,7 @@ def _visit(self, node: doc.AST) -> Any: else: value = self._eval_expr(node.__class__(**fields)) except Exception as e: # pylint: disable=broad-except,invalid-name - self.parser.report_error(node, str(e)) + self.parser.report_error(node, e) return self._add_intermediate_result(value) def _eval_lambda(self, node: doc.Lambda) -> Any: diff --git a/python/tvm/script/parser/core/parser.py b/python/tvm/script/parser/core/parser.py index fdccabcd235d..837b7cce5d5e 100644 --- a/python/tvm/script/parser/core/parser.py +++ b/python/tvm/script/parser/core/parser.py @@ -60,6 +60,10 @@ def context(): return context() +def _do_nothing(*args, **kwargs): # pylint: disable=unused-argument + pass + + class VarTableFrame: """The variable table frame. A frame of variable table stores the variables created in one block or scope. @@ -260,6 +264,17 @@ def parse(self, extra_vars: Optional[Dict[str, Any]] = None) -> Any: node = self.diag.source.as_ast() self.visit(node) + def get_dispatch_token(self, node: doc.FunctionDef) -> str: + if not isinstance(node, doc.FunctionDef): + self.report_error(node, "Only can get dispatch token for function.") + if not node.decorator_list: + self.report_error(node, "Function must be decorated") + # TODO: only the last decorator is parsed + decorator = self.eval_expr(node.decorator_list[-1]) + if not hasattr(decorator, "dispatch_token"): + self.report_error(node, "The parser does not understand the decorator") + return decorator.dispatch_token + def with_dispatch_token(self, token: str): """Add a new dispatching token as with statement. @@ -389,6 +404,8 @@ def report_error( # Only take the last line of the error message if isinstance(err, TVMError): msg = list(filter(None, str(err).split("\n")))[-1] + elif isinstance(err, KeyError): + msg = "KeyError: " + str(err) else: msg = str(err) self.diag.error(node, msg) @@ -458,30 +475,33 @@ def visit_tvm_annotation(self, node: doc.expr) -> Any: """ return _dispatch(self, "tvm_annotation")(self, node) - def visit_FunctionDef(self, node: doc.FunctionDef) -> Any: # pylint: disable=invalid-name - """The general function definition visiting method. + def visit_FunctionDef(self, node: doc.FunctionDef) -> None: # pylint: disable=invalid-name + """The general function definition visit method. Parameters ---------- node : doc.FunctionDef - The doc AST function definition node. - - Returns - ------- - res : Any - The visiting result. + The doc FunctionDef node. """ - if not node.decorator_list: - self.report_error(node, "Function must be decorated") - # TODO: only the last decorator is parsed - decorator = self.eval_expr(node.decorator_list[-1]) - if not hasattr(decorator, "dispatch_token"): - self.report_error(node, "The parser does not understand the decorator") - token = decorator.dispatch_token + token = self.get_dispatch_token(node) + current_token = self.dispatch_tokens[-1] func = dispatch.get(token=token, type_name="FunctionDef", default=None) if func is None: self.report_error(node, "The parser does not understand the decorator") + pre_func = dispatch.get( + token=current_token, type_name="pre_token_switch", default=_do_nothing + ) + post_func = dispatch.get( + token=current_token, type_name="post_token_switch", default=_do_nothing + ) + pre_func(self, node) _dispatch_wrapper(func)(self, node) + post_func(self, node) + + def visit_tvm_declare_function(self, node: doc.FunctionDef) -> None: + token = self.get_dispatch_token(node) + with self.with_dispatch_token(token): + _dispatch(self, "tvm_declare_function")(self, node) def visit_ClassDef(self, node: doc.ClassDef) -> Any: # pylint: disable=invalid-name """The general class definition visiting method. diff --git a/python/tvm/script/parser/ir/parser.py b/python/tvm/script/parser/ir/parser.py index e0268412d284..13b3e298590f 100644 --- a/python/tvm/script/parser/ir/parser.py +++ b/python/tvm/script/parser/ir/parser.py @@ -32,8 +32,12 @@ def _visit_class_def(self: Parser, node: doc.ClassDef) -> None: node : doc.ClassDef The doc AST class definition node. """ + with self.var_table.with_frame(): with I.ir_module(): + for stmt in node.body: + if isinstance(stmt, doc.FunctionDef): + self.visit_tvm_declare_function(stmt) with self.with_dispatch_token("ir"): self.visit_body(node.body) diff --git a/python/tvm/script/parser/tir/entry.py b/python/tvm/script/parser/tir/entry.py index 411a7f8f3c83..649f817411f0 100644 --- a/python/tvm/script/parser/tir/entry.py +++ b/python/tvm/script/parser/tir/entry.py @@ -83,7 +83,7 @@ def __getitem__(self, keys) -> Buffer: return self(keys) if len(keys) >= 2 and not isinstance(keys[1], str): return self(keys) - return self(*keys) # pylint: disable=no-member # type: ignore + return self(*keys) # type: ignore[attr-defined] # pylint: disable=no-member class PtrProxy: @@ -93,7 +93,7 @@ class PtrProxy: def __call__(self, dtype, storage_scope="global"): if callable(dtype): dtype = dtype().dtype - return ptr(dtype, storage_scope) # pylint: disable=no-member # type: ignore + return ptr(dtype, storage_scope) # type: ignore[attr-defined] # pylint: disable=no-member @deprecated("T.Ptr[...]", "T.handle(...)") def __getitem__(self, keys): diff --git a/python/tvm/script/parser/tir/parser.py b/python/tvm/script/parser/tir/parser.py index 8a067267a352..63171f672289 100644 --- a/python/tvm/script/parser/tir/parser.py +++ b/python/tvm/script/parser/tir/parser.py @@ -24,6 +24,7 @@ from tvm.ir import PrimType from tvm.tir import Buffer, IterVar, PrimExpr, Var +from ...ir_builder import ir as I from ...ir_builder import tir as T from ...ir_builder.base import IRBuilder from ...ir_builder.base import IRBuilderFrame as Frame @@ -473,3 +474,28 @@ def visit_return(self: Parser, node: doc.Return) -> None: The doc AST return node. """ self.report_error(node, "Return is not allowed.") + + +@dispatch.register(token="tir", type_name="tvm_declare_function") +def visit_tvm_declare_function(self: Parser, node: doc.FunctionDef) -> None: + """The function declaration step for tir + + Parameters + ---------- + self : Parser + The visiting parser. + + node : doc.Return + The doc AST return node. + """ + + ret_type = None + if node.returns is not None: + ret_type = self.eval_expr(node.returns) + if callable(ret_type): + ret_type = PrimType(ret_type().dtype) + + # Only ret_type is needed for func_signature. + func_signature = tvm.tir.PrimFunc([], None, ret_type=ret_type) + global_var = I.decl_function(node.name, func_signature) + self.var_table.add(node.name, global_var) diff --git a/src/script/ir_builder/ir/frame.cc b/src/script/ir_builder/ir/frame.cc index a81c56922dff..addf12928435 100644 --- a/src/script/ir_builder/ir/frame.cc +++ b/src/script/ir_builder/ir/frame.cc @@ -26,11 +26,15 @@ namespace ir_builder { namespace ir { void IRModuleFrameNode::ExitWithScope() { - ICHECK_EQ(functions.size(), global_vars.size()); - int n = functions.size(); Map func_map; - for (int i = 0; i < n; ++i) { - func_map.Set(global_vars[i], functions[i]); + CHECK_EQ(functions.size(), global_var_map.size()) + << "All functions must be defined in the IRModule. Got " << global_var_map.size() + << "declared function(s), but only " << functions.size() << "defined function(s)."; + for (const auto& kv : functions) { + const GlobalVar& gv = kv.first; + const BaseFunc& func = kv.second; + CHECK(func.defined()) << "ValueError: function " << gv->name_hint << " is not defined"; + func_map.Set(gv, func); } IRBuilder builder = IRBuilder::Current(); ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set"; diff --git a/src/script/ir_builder/ir/ir.cc b/src/script/ir_builder/ir/ir.cc index a8cc452e4f0c..5764e90c8dd4 100644 --- a/src/script/ir_builder/ir/ir.cc +++ b/src/script/ir_builder/ir/ir.cc @@ -20,6 +20,8 @@ #include #include +#include "./utils.h" + namespace tvm { namespace script { namespace ir_builder { @@ -27,12 +29,40 @@ namespace ir { IRModuleFrame IRModule() { ObjectPtr n = make_object(); - n->global_vars.clear(); + n->global_var_map.clear(); n->functions.clear(); return IRModuleFrame(n); } +GlobalVar DeclFunction(const String& func_name, const BaseFunc& func_signature) { + IRModuleFrame frame = FindModuleFrame("I.DeclFunction"); + CHECK(!frame->global_var_map.count(func_name)) + << "ValueError: function " << func_name << " already exists"; + GlobalVar gv = GlobalVar(func_name); + CHECK(frame->functions.find(gv) == frame->functions.end()) + << "ValueError: function " << func_name << " has already been defined."; + frame->global_var_map.Set(func_name, gv); + if (func_signature.defined()) { + frame->functions.Set(gv, func_signature); + } + return gv; +} + +void DefFunction(const String& func_name, const BaseFunc& func) { + IRModuleFrame frame = FindModuleFrame("I.DefFunction"); + auto it = frame->global_var_map.find(func_name); + CHECK(it != frame->global_var_map.end()) + << "ValueError: function " << func_name << " does not exist, please declare it first."; + const GlobalVar& gv = (*it).second; + frame->functions.Set(gv, func); + if (func->checked_type_.defined()) { + gv->checked_type_ = func->checked_type_; + } +} + TVM_REGISTER_GLOBAL("script.ir_builder.ir.IRModule").set_body_typed(IRModule); +TVM_REGISTER_GLOBAL("script.ir_builder.ir.DeclFunction").set_body_typed(DeclFunction); +TVM_REGISTER_GLOBAL("script.ir_builder.ir.DefFunction").set_body_typed(DefFunction); } // namespace ir } // namespace ir_builder diff --git a/src/script/ir_builder/ir/utils.h b/src/script/ir_builder/ir/utils.h new file mode 100644 index 000000000000..58d5e53f7032 --- /dev/null +++ b/src/script/ir_builder/ir/utils.h @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_SCRIPT_IR_BUILDER_IR_UTILS_H_ +#define TVM_SCRIPT_IR_BUILDER_IR_UTILS_H_ + +#include + +namespace tvm { +namespace script { +namespace ir_builder { +namespace ir { + +inline IRModuleFrame FindModuleFrame(const String& method) { + IRBuilder builder = IRBuilder::Current(); + if (Optional frame = builder->FindFrame()) { + const Optional& last_module_frame = builder->GetLastFrame(); + if (last_module_frame.defined() && last_module_frame.value() == frame) { + return frame.value(); + } + } else { + LOG(FATAL) << "ValueError: IRModule frame not find. Please ensure '" << method + << "' is called under I.ir_module()"; + } + LOG(FATAL) << "ValueError: '" << method << "' must be called immediately under I.ir_module()"; + throw; +} + +} // namespace ir +} // namespace ir_builder +} // namespace script +} // namespace tvm + +#endif // TVM_SCRIPT_IR_BUILDER_IR_UTILS_H_ diff --git a/src/script/ir_builder/tir/frame.cc b/src/script/ir_builder/tir/frame.cc index 1e63201a40dd..dd8d3c2ed3f3 100644 --- a/src/script/ir_builder/tir/frame.cc +++ b/src/script/ir_builder/tir/frame.cc @@ -16,6 +16,7 @@ * specific language governing permissions and limitations * under the License. */ +#include #include #include @@ -41,9 +42,17 @@ void PrimFuncFrameNode::ExitWithScope() { ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set"; builder->result = func; } else if (Optional opt_frame = builder->FindFrame()) { - ir::IRModuleFrame frame = opt_frame.value(); - frame->global_vars.push_back(GlobalVar(name.value_or(""))); - frame->functions.push_back(func); + CHECK(name.defined()) << "ValueError: The function name must be defined before exiting the " + "function scope, if it's defined in a Module"; + const ir::IRModuleFrame& frame = opt_frame.value(); + const String& func_name = name.value_or(""); + if (!frame->global_var_map.count(func_name)) { + // Case. First time visiting the function. + ir::DeclFunction(func_name, func); + } + // Define the function. + // Note we do checks to disallow redefinition of functions inside the `DefFunction`. + ir::DefFunction(func_name, func); } else { LOG(FATAL) << "ValueError: Cannot find where to insert PrimFunc"; } diff --git a/src/script/ir_builder/tir/utils.h b/src/script/ir_builder/tir/utils.h index 7ccc132fa1fe..f3b547532cfd 100644 --- a/src/script/ir_builder/tir/utils.h +++ b/src/script/ir_builder/tir/utils.h @@ -87,7 +87,7 @@ inline PrimFuncFrame FindPrimFuncFrame(const String& method) { * \return The top frame of BlockFrame. */ inline BlockFrame FindBlockFrame(const String& method) { - if (Optional frame = IRBuilder::Current()->GetLastFrame()) { + if (Optional frame = IRBuilder::Current()->FindFrame()) { return frame.value(); } else if (Optional frame = IRBuilder::Current()->FindFrame()) { LOG(FATAL) << "ValueError: " << method << " must be called at the top of a T.block(). " From 78a4ce8d1859526ec3a2d947953334d44b9677b3 Mon Sep 17 00:00:00 2001 From: Hongyi Jin <3231950289@qq.com> Date: Sun, 26 Feb 2023 11:05:47 -0500 Subject: [PATCH 2/3] [TVMScript] Expose IRModule::attrs as I.module_attrs This is an upstreaming of the non-relax portions of https://github.com/apache/tvm/pull/14132, including a unit test specically to validate `I.module_attrs`. --- include/tvm/script/ir_builder/base.h | 2 ++ include/tvm/script/ir_builder/ir/frame.h | 3 +++ python/tvm/ir/module.py | 14 ++++++++++++-- python/tvm/script/ir_builder/base.py | 11 +++++++++++ python/tvm/script/ir_builder/ir/__init__.py | 7 ++++++- python/tvm/script/ir_builder/ir/ir.py | 14 ++++++++++++++ python/tvm/script/parser/ir/__init__.py | 4 ++-- python/tvm/script/parser/ir/parser.py | 11 +++++++++-- src/ir/module.cc | 6 ++---- src/script/ir_builder/base.cc | 6 ++++++ src/script/ir_builder/ir/frame.cc | 3 ++- src/script/ir_builder/ir/ir.cc | 12 ++++++++++++ src/script/printer/ir/ir.cc | 5 +++++ tests/python/unittest/test_tvmscript_roundtrip.py | 14 ++++++++++++++ 14 files changed, 100 insertions(+), 12 deletions(-) diff --git a/include/tvm/script/ir_builder/base.h b/include/tvm/script/ir_builder/base.h index 61ca3eb9f7eb..a00ea5768e23 100644 --- a/include/tvm/script/ir_builder/base.h +++ b/include/tvm/script/ir_builder/base.h @@ -237,6 +237,8 @@ class IRBuilder : public runtime::ObjectRef { * \sa tvm::support::With */ static IRBuilder Current(); + /*! \brief See if the current thread-local scope has an IRBuilder. */ + static bool IsInScope(); /*! * \brief Give a string name to the `obj` * \tparam TObjectRef The type of the object to name. diff --git a/include/tvm/script/ir_builder/ir/frame.h b/include/tvm/script/ir_builder/ir/frame.h index dacfc361a6c7..ed425cf61441 100644 --- a/include/tvm/script/ir_builder/ir/frame.h +++ b/include/tvm/script/ir_builder/ir/frame.h @@ -45,11 +45,14 @@ class IRModuleFrameNode : public IRBuilderFrameNode { * \note Only defined functions are in the map, while declared functions are not included. */ Map functions; + /*! \brief IRModule's attributes. */ + Map attrs; void VisitAttrs(tvm::AttrVisitor* v) { IRBuilderFrameNode::VisitAttrs(v); v->Visit("global_vars", &global_var_map); v->Visit("functions", &functions); + v->Visit("attrs", &attrs); } static constexpr const char* _type_key = "script.ir_builder.IRModuleFrame"; diff --git a/python/tvm/ir/module.py b/python/tvm/ir/module.py index 3daffb2640c5..232c70aa93d8 100644 --- a/python/tvm/ir/module.py +++ b/python/tvm/ir/module.py @@ -37,7 +37,7 @@ class IRModule(Node, Scriptable): Map of global var to BaseFunc """ - def __init__(self, functions=None, type_definitions=None): + def __init__(self, functions=None, type_definitions=None, attrs=None): if functions is None: functions = {} elif isinstance(functions, dict): @@ -60,7 +60,17 @@ def __init__(self, functions=None, type_definitions=None): raise TypeError("Expect type_definitions to be Dict[GlobalTypeVar, Type]") mapped_type_defs[k] = v type_definitions = mapped_type_defs - self.__init_handle_by_constructor__(_ffi_api.IRModule, functions, type_definitions) + + attrs = None if not attrs else attrs + if attrs is not None: + attrs = ast.literal_eval(str(attrs)) + attrs = tvm.ir.make_node("DictAttrs", **attrs) + self.__init_handle_by_constructor__( + _ffi_api.IRModule, + functions, + type_definitions, + attrs, + ) def __setitem__(self, var, val): """Add a mapping to the module. diff --git a/python/tvm/script/ir_builder/base.py b/python/tvm/script/ir_builder/base.py index b35bbd0a7df5..1d5d050444f7 100644 --- a/python/tvm/script/ir_builder/base.py +++ b/python/tvm/script/ir_builder/base.py @@ -138,6 +138,17 @@ def current() -> "IRBuilder": """ return _ffi_api.IRBuilderCurrent() # type: ignore[attr-defined] # pylint: disable=no-member + @staticmethod + def is_in_scope() -> bool: + """See if the current thread-local scope has an IRBuilder. + + Returns + ------- + bool + Whether the current thread-local scope has an IRBuilder + """ + return _ffi_api.IRBuilderIsInScope() # type: ignore[attr-defined] # pylint: disable=no-member + def get(self) -> _Object: """Get the constructed IR.""" return _ffi_api.IRBuilderGet(self) # type: ignore[attr-defined] # pylint: disable=no-member diff --git a/python/tvm/script/ir_builder/ir/__init__.py b/python/tvm/script/ir_builder/ir/__init__.py index 946be263a779..b796de8113f3 100644 --- a/python/tvm/script/ir_builder/ir/__init__.py +++ b/python/tvm/script/ir_builder/ir/__init__.py @@ -16,4 +16,9 @@ # under the License. """Package tvm.script.ir_builder.ir""" from .frame import IRModuleFrame -from .ir import decl_function, def_function, ir_module +from .ir import ( + decl_function, + def_function, + ir_module, + module_attrs, +) diff --git a/python/tvm/script/ir_builder/ir/ir.py b/python/tvm/script/ir_builder/ir/ir.py index 796d6f3aad04..c5276f8d136e 100644 --- a/python/tvm/script/ir_builder/ir/ir.py +++ b/python/tvm/script/ir_builder/ir/ir.py @@ -16,6 +16,10 @@ # under the License. """Package tvm.script.ir_builder.ir.ir""" +from typing import Dict + +from tvm.runtime import Object as tvm_Object + from tvm.ir import BaseFunc, GlobalVar from . import _ffi_api @@ -67,3 +71,13 @@ def def_function(func_name: str, func: BaseFunc) -> None: The given function implementation """ return _ffi_api.DefFunction(func_name, func) # type: ignore[attr-defined] # pylint: disable=no-member + + +def module_attrs(attrs: Dict[str, tvm_Object]) -> None: + """Specify the attrs of the ir_module frame. + Parameters + ---------- + attrs: Dict[str, Object] + The module attrs. + """ + return _ffi_api.ModuleAttrs(attrs) # type: ignore[attr-defined] # pylint: disable=no-member diff --git a/python/tvm/script/parser/ir/__init__.py b/python/tvm/script/parser/ir/__init__.py index fedd2f0a14a8..adda17601206 100644 --- a/python/tvm/script/parser/ir/__init__.py +++ b/python/tvm/script/parser/ir/__init__.py @@ -15,8 +15,8 @@ # specific language governing permissions and limitations # under the License. """The ir module parser""" - +from ...ir_builder.ir import * # pylint: disable=redefined-builtin from . import parser as _parser from .entry import ir_module -__all__ = ["ir_module"] +__all__ = ["ir_module", "module_attrs"] diff --git a/python/tvm/script/parser/ir/parser.py b/python/tvm/script/parser/ir/parser.py index 13b3e298590f..201c99074f20 100644 --- a/python/tvm/script/parser/ir/parser.py +++ b/python/tvm/script/parser/ir/parser.py @@ -35,11 +35,17 @@ def _visit_class_def(self: Parser, node: doc.ClassDef) -> None: with self.var_table.with_frame(): with I.ir_module(): + with self.with_dispatch_token("ir"): + for stmt in node.body: + if not isinstance(stmt, doc.FunctionDef): + self.visit(stmt) for stmt in node.body: if isinstance(stmt, doc.FunctionDef): self.visit_tvm_declare_function(stmt) with self.with_dispatch_token("ir"): - self.visit_body(node.body) + for stmt in node.body: + if isinstance(stmt, doc.FunctionDef): + self.visit(stmt) @dispatch.register(token="ir", type_name="Assign") @@ -57,7 +63,7 @@ def _visit_assign(_self: Parser, _node: doc.Assign) -> None: @dispatch.register(token="ir", type_name="Expr") -def _visit_expr(_self: Parser, _node: doc.Expr) -> None: +def _visit_expr(self: Parser, node: doc.Expr) -> None: """The expression visiting method for ir module. Parameters @@ -68,6 +74,7 @@ def _visit_expr(_self: Parser, _node: doc.Expr) -> None: node : doc.ClassDef The doc AST expression node. """ + self.eval_expr(node.value) @dispatch.register(token="default", type_name="Assign") diff --git a/src/ir/module.cc b/src/ir/module.cc index 7a973da29dfa..9d663f9a1a2f 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -382,10 +382,8 @@ IRModule IRModule::FromText(const String& text, const String& source_path) { TVM_REGISTER_NODE_TYPE(IRModuleNode); TVM_REGISTER_GLOBAL("ir.IRModule") - .set_body_typed([](tvm::Map funcs, - tvm::Map types) { - return IRModule(funcs, types, {}); - }); + .set_body_typed([](tvm::Map funcs, tvm::Map types, + tvm::DictAttrs attrs) { return IRModule(funcs, types, {}, {}, attrs); }); TVM_REGISTER_GLOBAL("ir.Module_Add") .set_body_typed([](IRModule mod, GlobalVar var, ObjectRef val, bool update) -> IRModule { diff --git a/src/script/ir_builder/base.cc b/src/script/ir_builder/base.cc index 8303efff4f20..879db4f3d713 100644 --- a/src/script/ir_builder/base.cc +++ b/src/script/ir_builder/base.cc @@ -77,6 +77,11 @@ IRBuilder IRBuilder::Current() { return stack->back(); } +bool IRBuilder::IsInScope() { + std::vector* stack = ThreadLocalBuilderStack(); + return !stack->empty(); +} + namespace details { Namer::FType& Namer::vtable() { @@ -106,6 +111,7 @@ TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilder").set_body_typed([]() { return TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderEnter").set_body_method(&IRBuilder::EnterWithScope); TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderExit").set_body_method(&IRBuilder::ExitWithScope); TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderCurrent").set_body_typed(IRBuilder::Current); +TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderIsInScope").set_body_typed(IRBuilder::IsInScope); TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderGet") .set_body_method(&IRBuilderNode::Get); TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderName").set_body_typed(IRBuilder::Name); diff --git a/src/script/ir_builder/ir/frame.cc b/src/script/ir_builder/ir/frame.cc index addf12928435..92470ec65342 100644 --- a/src/script/ir_builder/ir/frame.cc +++ b/src/script/ir_builder/ir/frame.cc @@ -38,7 +38,8 @@ void IRModuleFrameNode::ExitWithScope() { } IRBuilder builder = IRBuilder::Current(); ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set"; - builder->result = tvm::IRModule(func_map); + auto dict_attrs = attrs.empty() ? NullValue() : DictAttrs(attrs); + builder->result = tvm::IRModule(func_map, {}, {}, {}, dict_attrs); } TVM_REGISTER_NODE_TYPE(IRModuleFrameNode); diff --git a/src/script/ir_builder/ir/ir.cc b/src/script/ir_builder/ir/ir.cc index 5764e90c8dd4..0c34f85246c9 100644 --- a/src/script/ir_builder/ir/ir.cc +++ b/src/script/ir_builder/ir/ir.cc @@ -60,9 +60,21 @@ void DefFunction(const String& func_name, const BaseFunc& func) { } } +void ModuleAttrs(Map attrs) { + if (IRBuilder::IsInScope()) { + // TODO(hongyi): add comments to explain why we need to check if the module frame is in scope + IRModuleFrame frame = FindModuleFrame("I.ModuleAttr"); + if (!frame->attrs.empty()) { + LOG(FATAL) << "ValueError: Duplicate module attrs, previous one is:\n" << frame->attrs; + } + frame->attrs = attrs; + } +} + TVM_REGISTER_GLOBAL("script.ir_builder.ir.IRModule").set_body_typed(IRModule); TVM_REGISTER_GLOBAL("script.ir_builder.ir.DeclFunction").set_body_typed(DeclFunction); TVM_REGISTER_GLOBAL("script.ir_builder.ir.DefFunction").set_body_typed(DefFunction); +TVM_REGISTER_GLOBAL("script.ir_builder.ir.ModuleAttrs").set_body_typed(ModuleAttrs); } // namespace ir } // namespace ir_builder diff --git a/src/script/printer/ir/ir.cc b/src/script/printer/ir/ir.cc index 065cfe5168ad..1c751d40f2e7 100644 --- a/src/script/printer/ir/ir.cc +++ b/src/script/printer/ir/ir.cc @@ -64,6 +64,11 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) std::sort(functions.begin(), functions.end()); With f(d); (*f)->AddDispatchToken(d, "ir"); + if (mod->attrs.defined() && !mod->attrs->dict.empty()) { + (*f)->stmts.push_back( + ExprStmtDoc(IR(d, "module_attrs") // + ->Call({d->AsDoc(mod->attrs, p->Attr("attrs"))}))); + } for (const auto& entry : functions) { const GlobalVar& gv = entry.gv; const BaseFunc& func = entry.func; diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index bbc6dd45a83e..52d99550be92 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -3725,6 +3725,19 @@ def tir_packed_call(A: T.Buffer(16)): return tvm.tir.transform.LowerTVMBuiltin()(Module) +def ir_module_with_attrs(): + @I.ir_module + class Module: + I.module_attrs({"attr": 10}) + + @T.prim_func + def tir_func(A: T.Buffer(16, "int32"), B: T.Buffer(16, "int32")): + for i in range(16): + B[i] = A[i] + + return Module + + ir_generator = tvm.testing.parameter( launch_env_thread, opt_gemm_normalize, @@ -3791,6 +3804,7 @@ def tir_packed_call(A: T.Buffer(16)): if_then_else_var, tvm_shfl_builtins, tvm_struct_set_generated_in_cpp, + ir_module_with_attrs, ) From f845122e05388b49b5cb4ee4439cec572e62ac58 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 5 Apr 2023 09:31:21 -0500 Subject: [PATCH 3/3] Expose attrs argument of "ir.IRModule" to Rust bindings --- rust/tvm/src/ir/module.rs | 16 +++++++++++----- src/ir/module.cc | 16 +++++++++++++++- 2 files changed, 26 insertions(+), 6 deletions(-) diff --git a/rust/tvm/src/ir/module.rs b/rust/tvm/src/ir/module.rs index 8f71a8be2c7c..4cdca826ec16 100644 --- a/rust/tvm/src/ir/module.rs +++ b/rust/tvm/src/ir/module.rs @@ -28,7 +28,7 @@ use crate::runtime::array::Array; use crate::runtime::function::Result; use crate::runtime::map::Map; use crate::runtime::string::String as TVMString; -use crate::runtime::{external, IsObjectRef, Object}; +use crate::runtime::{external, IsObjectRef, Object, ObjectRef}; use super::expr::GlobalVar; use super::function::BaseFunc; @@ -62,7 +62,7 @@ external! { #[name("relay.parser.ParseExpr")] fn parse_expression(file_name: TVMString, source: TVMString) -> IRModule; #[name("ir.IRModule")] - fn module_new(funcs: Map, types: Map) -> IRModule; + fn module_new(funcs: Map, types: Map, attrs: Map) -> IRModule; // Module methods #[name("ir.Module_Add")] fn module_add(module: IRModule, type_name: GlobalVar, expr: BaseFunc, update: bool) -> IRModule; @@ -99,18 +99,24 @@ external! { // Note: we don't expose update here as update is going to be removed. impl IRModule { - pub fn new<'a, F, T>(funcs: F, types: T) -> Result + pub fn new<'a, F, T, A>(funcs: F, types: T, attrs: A) -> Result where F: IntoIterator, T: IntoIterator, + A: IntoIterator, { - module_new(Map::from_iter(funcs), Map::from_iter(types)) + module_new( + Map::from_iter(funcs), + Map::from_iter(types), + Map::from_iter(attrs), + ) } pub fn empty() -> Result { let funcs = HashMap::::new(); let types = HashMap::::new(); - IRModule::new(funcs.iter(), types.iter()) + let attrs = HashMap::::new(); + IRModule::new(funcs.iter(), types.iter(), attrs.iter()) } pub fn parse(file_name: N, source: S) -> Result diff --git a/src/ir/module.cc b/src/ir/module.cc index 9d663f9a1a2f..b891f6f61371 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -383,7 +383,21 @@ TVM_REGISTER_NODE_TYPE(IRModuleNode); TVM_REGISTER_GLOBAL("ir.IRModule") .set_body_typed([](tvm::Map funcs, tvm::Map types, - tvm::DictAttrs attrs) { return IRModule(funcs, types, {}, {}, attrs); }); + tvm::ObjectRef attrs) { + auto dict_attrs = [&attrs]() { + if (!attrs.defined()) { + return DictAttrs(); + } else if (auto* as_dict_attrs = attrs.as()) { + return GetRef(as_dict_attrs); + } else if (attrs.as()) { + return tvm::DictAttrs(Downcast>(attrs)); + } else { + LOG(FATAL) << "Expected attrs argument to be either DictAttrs or Map"; + } + }(); + + return IRModule(funcs, types, {}, {}, dict_attrs); + }); TVM_REGISTER_GLOBAL("ir.Module_Add") .set_body_typed([](IRModule mod, GlobalVar var, ObjectRef val, bool update) -> IRModule {