Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TVMScript] Upstream IRModule parser from unity #14487

Merged
merged 3 commits into from
Apr 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions include/tvm/script/ir_builder/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,8 @@ class IRBuilder : public runtime::ObjectRef {
* \sa tvm::support::With
*/
static IRBuilder Current();
/*! \brief See if the current thread-local scope has an IRBuilder. */
static bool IsInScope();
/*!
* \brief Give a string name to the `obj`
* \tparam TObjectRef The type of the object to name.
Expand Down
14 changes: 11 additions & 3 deletions include/tvm/script/ir_builder/ir/frame.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,21 @@ namespace ir {
*/
class IRModuleFrameNode : public IRBuilderFrameNode {
public:
Array<GlobalVar> global_vars;
Array<BaseFunc> functions;
/*! \brief A map from string names to global variables that ensures global uniqueness. */
Map<String, GlobalVar> global_var_map;
/*!
* \brief A map from GlobalVar to all global functions.
* \note Only defined functions are in the map, while declared functions are not included.
*/
Map<GlobalVar, BaseFunc> functions;
/*! \brief IRModule's attributes. */
Map<String, ObjectRef> attrs;

void VisitAttrs(tvm::AttrVisitor* v) {
IRBuilderFrameNode::VisitAttrs(v);
v->Visit("global_vars", &global_vars);
v->Visit("global_vars", &global_var_map);
v->Visit("functions", &functions);
v->Visit("attrs", &attrs);
}

static constexpr const char* _type_key = "script.ir_builder.IRModuleFrame";
Expand Down
17 changes: 17 additions & 0 deletions include/tvm/script/ir_builder/ir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,23 @@ namespace ir {
*/
TVM_DLL IRModuleFrame IRModule();

/*!
* \brief Declare a Function without given the specific function implementation.
* \note It is usually used in cross-function call. And we can specify the function by `DefFunction`
* \param func_name The function unique name.
* \param func_signature A Function w/o body, which used to specify the function signature
* (i.e. func params and func return type/shape).
* \return The corresponding GlobalVar.
*/
TVM_DLL GlobalVar DeclFunction(const String& func_name, const BaseFunc& func_signature);

/*!
* \brief Define the function which is declared before.
* \param func_name The function unique name.
* \param func The given function implementation
*/
TVM_DLL void DefFunction(const String& func_name, const BaseFunc& func);

} // namespace ir
} // namespace ir_builder
} // namespace script
Expand Down
14 changes: 12 additions & 2 deletions python/tvm/ir/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class IRModule(Node, Scriptable):
Map of global var to BaseFunc
"""

def __init__(self, functions=None, type_definitions=None):
def __init__(self, functions=None, type_definitions=None, attrs=None):
if functions is None:
functions = {}
elif isinstance(functions, dict):
Expand All @@ -60,7 +60,17 @@ def __init__(self, functions=None, type_definitions=None):
raise TypeError("Expect type_definitions to be Dict[GlobalTypeVar, Type]")
mapped_type_defs[k] = v
type_definitions = mapped_type_defs
self.__init_handle_by_constructor__(_ffi_api.IRModule, functions, type_definitions)

attrs = None if not attrs else attrs
if attrs is not None:
attrs = ast.literal_eval(str(attrs))
attrs = tvm.ir.make_node("DictAttrs", **attrs)
self.__init_handle_by_constructor__(
_ffi_api.IRModule,
functions,
type_definitions,
attrs,
)

def __setitem__(self, var, val):
"""Add a mapping to the module.
Expand Down
17 changes: 15 additions & 2 deletions python/tvm/script/ir_builder/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,10 @@ def __enter__(self) -> "IRBuilderFrame":
_ffi_api.IRBuilderFrameEnter(self) # type: ignore[attr-defined] # pylint: disable=no-member
return self

def __exit__(self, ptype, value, trace) -> None: # pylint: disable=unused-argument
_ffi_api.IRBuilderFrameExit(self) # type: ignore[attr-defined] # pylint: disable=no-member
def __exit__(self, exc_type, exc_value, trace) -> None: # pylint: disable=unused-argument
if exc_type is None and exc_value is None:
# Do not execute `FrameExit` if the with scope exits because of exceptions
_ffi_api.IRBuilderFrameExit(self) # type: ignore[attr-defined] # pylint: disable=no-member

def add_callback(self, callback: Callable[[], None]) -> None:
"""Add a callback method invoked when exiting the with-scope.
Expand Down Expand Up @@ -136,6 +138,17 @@ def current() -> "IRBuilder":
"""
return _ffi_api.IRBuilderCurrent() # type: ignore[attr-defined] # pylint: disable=no-member

@staticmethod
def is_in_scope() -> bool:
"""See if the current thread-local scope has an IRBuilder.

Returns
-------
bool
Whether the current thread-local scope has an IRBuilder
"""
return _ffi_api.IRBuilderIsInScope() # type: ignore[attr-defined] # pylint: disable=no-member

def get(self) -> _Object:
"""Get the constructed IR."""
return _ffi_api.IRBuilderGet(self) # type: ignore[attr-defined] # pylint: disable=no-member
Expand Down
7 changes: 6 additions & 1 deletion python/tvm/script/ir_builder/ir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,9 @@
# under the License.
"""Package tvm.script.ir_builder.ir"""
from .frame import IRModuleFrame
from .ir import ir_module
from .ir import (
decl_function,
def_function,
ir_module,
module_attrs,
)
59 changes: 59 additions & 0 deletions python/tvm/script/ir_builder/ir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,68 @@
# under the License.
"""Package tvm.script.ir_builder.ir.ir"""

from typing import Dict

from tvm.runtime import Object as tvm_Object

from tvm.ir import BaseFunc, GlobalVar

from . import _ffi_api
from .frame import IRModuleFrame


def ir_module() -> IRModuleFrame:
"""Start a ir_module frame.
Returns
-------
frame: IRModuleFrame
The constructed frame.
"""
return _ffi_api.IRModule() # type: ignore[attr-defined] # pylint: disable=no-member


def decl_function(func_name: str, func_signature: BaseFunc) -> GlobalVar:
"""Declare a Function without given the specific function implementation.
Parameters
----------
func_name : str
The function unique name.

func_signature: Optional[BaseFunc]
A Function w/o body, which used to specify the function signature
(i.e. func params and func return type/shape).

Note
----
It is usually used in cross-function call. And we can specify the function by `DefFunction`
Returns
-------
gv : GlobalVar
The corresponding GlobalVar.
"""

return _ffi_api.DeclFunction( # type: ignore[attr-defined] # pylint: disable=no-member
func_name, func_signature
)


def def_function(func_name: str, func: BaseFunc) -> None:
"""Define the function which is declared before.
Parameters
----------
func_name : str
The function unique name.
func: BaseFunc
The given function implementation
"""
return _ffi_api.DefFunction(func_name, func) # type: ignore[attr-defined] # pylint: disable=no-member


def module_attrs(attrs: Dict[str, tvm_Object]) -> None:
"""Specify the attrs of the ir_module frame.
Parameters
----------
attrs: Dict[str, Object]
The module attrs.
"""
return _ffi_api.ModuleAttrs(attrs) # type: ignore[attr-defined] # pylint: disable=no-member
2 changes: 1 addition & 1 deletion python/tvm/script/parser/core/diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ def _emit(self, node: doc.AST, message: str, level: diagnostics.DiagnosticLevel)
level : diagnostics.DiagnosticLevel
The diagnostic level.
"""
lineno = node.lineno or self.source.start_line
lineno = node.lineno or 1
col_offset = node.col_offset or self.source.start_column
end_lineno = node.end_lineno or lineno
end_col_offset = node.end_col_offset or col_offset
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/script/parser/core/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def _visit(self, node: doc.AST) -> Any:
else:
value = self._eval_expr(node.__class__(**fields))
except Exception as e: # pylint: disable=broad-except,invalid-name
self.parser.report_error(node, str(e))
self.parser.report_error(node, e)
return self._add_intermediate_result(value)

def _eval_lambda(self, node: doc.Lambda) -> Any:
Expand Down
50 changes: 35 additions & 15 deletions python/tvm/script/parser/core/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ def context():
return context()


def _do_nothing(*args, **kwargs): # pylint: disable=unused-argument
pass


class VarTableFrame:
"""The variable table frame.
A frame of variable table stores the variables created in one block or scope.
Expand Down Expand Up @@ -260,6 +264,17 @@ def parse(self, extra_vars: Optional[Dict[str, Any]] = None) -> Any:
node = self.diag.source.as_ast()
self.visit(node)

def get_dispatch_token(self, node: doc.FunctionDef) -> str:
if not isinstance(node, doc.FunctionDef):
self.report_error(node, "Only can get dispatch token for function.")
if not node.decorator_list:
self.report_error(node, "Function must be decorated")
# TODO: only the last decorator is parsed
decorator = self.eval_expr(node.decorator_list[-1])
if not hasattr(decorator, "dispatch_token"):
self.report_error(node, "The parser does not understand the decorator")
return decorator.dispatch_token

def with_dispatch_token(self, token: str):
"""Add a new dispatching token as with statement.

Expand Down Expand Up @@ -389,6 +404,8 @@ def report_error(
# Only take the last line of the error message
if isinstance(err, TVMError):
msg = list(filter(None, str(err).split("\n")))[-1]
elif isinstance(err, KeyError):
msg = "KeyError: " + str(err)
else:
msg = str(err)
self.diag.error(node, msg)
Expand Down Expand Up @@ -458,30 +475,33 @@ def visit_tvm_annotation(self, node: doc.expr) -> Any:
"""
return _dispatch(self, "tvm_annotation")(self, node)

def visit_FunctionDef(self, node: doc.FunctionDef) -> Any: # pylint: disable=invalid-name
"""The general function definition visiting method.
def visit_FunctionDef(self, node: doc.FunctionDef) -> None: # pylint: disable=invalid-name
"""The general function definition visit method.

Parameters
----------
node : doc.FunctionDef
The doc AST function definition node.

Returns
-------
res : Any
The visiting result.
The doc FunctionDef node.
"""
if not node.decorator_list:
self.report_error(node, "Function must be decorated")
# TODO: only the last decorator is parsed
decorator = self.eval_expr(node.decorator_list[-1])
if not hasattr(decorator, "dispatch_token"):
self.report_error(node, "The parser does not understand the decorator")
token = decorator.dispatch_token
token = self.get_dispatch_token(node)
current_token = self.dispatch_tokens[-1]
func = dispatch.get(token=token, type_name="FunctionDef", default=None)
if func is None:
self.report_error(node, "The parser does not understand the decorator")
pre_func = dispatch.get(
token=current_token, type_name="pre_token_switch", default=_do_nothing
)
post_func = dispatch.get(
token=current_token, type_name="post_token_switch", default=_do_nothing
)
pre_func(self, node)
_dispatch_wrapper(func)(self, node)
post_func(self, node)

def visit_tvm_declare_function(self, node: doc.FunctionDef) -> None:
token = self.get_dispatch_token(node)
with self.with_dispatch_token(token):
_dispatch(self, "tvm_declare_function")(self, node)

def visit_ClassDef(self, node: doc.ClassDef) -> Any: # pylint: disable=invalid-name
"""The general class definition visiting method.
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/script/parser/ir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
# specific language governing permissions and limitations
# under the License.
"""The ir module parser"""

from ...ir_builder.ir import * # pylint: disable=redefined-builtin
from . import parser as _parser
from .entry import ir_module

__all__ = ["ir_module"]
__all__ = ["ir_module", "module_attrs"]
15 changes: 13 additions & 2 deletions python/tvm/script/parser/ir/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,20 @@ def _visit_class_def(self: Parser, node: doc.ClassDef) -> None:
node : doc.ClassDef
The doc AST class definition node.
"""

with self.var_table.with_frame():
with I.ir_module():
with self.with_dispatch_token("ir"):
self.visit_body(node.body)
for stmt in node.body:
if not isinstance(stmt, doc.FunctionDef):
self.visit(stmt)
for stmt in node.body:
if isinstance(stmt, doc.FunctionDef):
self.visit_tvm_declare_function(stmt)
with self.with_dispatch_token("ir"):
for stmt in node.body:
if isinstance(stmt, doc.FunctionDef):
self.visit(stmt)


@dispatch.register(token="ir", type_name="Assign")
Expand All @@ -53,7 +63,7 @@ def _visit_assign(_self: Parser, _node: doc.Assign) -> None:


@dispatch.register(token="ir", type_name="Expr")
def _visit_expr(_self: Parser, _node: doc.Expr) -> None:
def _visit_expr(self: Parser, node: doc.Expr) -> None:
"""The expression visiting method for ir module.

Parameters
Expand All @@ -64,6 +74,7 @@ def _visit_expr(_self: Parser, _node: doc.Expr) -> None:
node : doc.ClassDef
The doc AST expression node.
"""
self.eval_expr(node.value)


@dispatch.register(token="default", type_name="Assign")
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/script/parser/tir/entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def __getitem__(self, keys) -> Buffer:
return self(keys)
if len(keys) >= 2 and not isinstance(keys[1], str):
return self(keys)
return self(*keys) # pylint: disable=no-member # type: ignore
return self(*keys) # type: ignore[attr-defined] # pylint: disable=no-member


class PtrProxy:
Expand All @@ -93,7 +93,7 @@ class PtrProxy:
def __call__(self, dtype, storage_scope="global"):
if callable(dtype):
dtype = dtype().dtype
return ptr(dtype, storage_scope) # pylint: disable=no-member # type: ignore
return ptr(dtype, storage_scope) # type: ignore[attr-defined] # pylint: disable=no-member

@deprecated("T.Ptr[...]", "T.handle(...)")
def __getitem__(self, keys):
Expand Down
Loading