Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIR] Implement TIR macros #15260

Merged
merged 14 commits into from
Jul 11, 2023
2 changes: 1 addition & 1 deletion python/tvm/script/parser/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,5 @@
# pylint: disable=unused-import
from .core import dispatch, doc, utils
from .core.dispatch import OpMethod, register_op
from .core.entry import parse
from .core.entry import parse, parse_macro
from .core.parser import Parser
8 changes: 8 additions & 0 deletions python/tvm/script/parser/core/entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,14 @@
from .parser import Parser


def parse_macro(program: Union[Any, str]) -> Any:
"""Generate the AST, and the source code for __repr__."""
# The AST will be converted into TIR at the time of insertion.
source = Source(program)
node = source.as_ast()
return node, source.source


def parse(program: Union[doc.AST, Any, str], extra_vars: Dict[str, Any] = None) -> Any:
"""Register a method for a operand type, AST operator node and operand index.

Expand Down
4 changes: 2 additions & 2 deletions python/tvm/script/parser/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,6 @@
# so most tvmscript won't trigger pylint error here.
prim_func = staticmethod
else:
from .entry import prim_func
from .entry import prim_func, macro, insert

__all__ = _tir.__all__ + ["Buffer", "Ptr", "prim_func"]
__all__ = _tir.__all__ + ["Buffer", "Ptr", "prim_func", "macro", "insert"]
46 changes: 44 additions & 2 deletions python/tvm/script/parser/tir/entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@
# under the License.
"""The entry point of TVM parser for tir."""
import inspect
from typing import Callable, Union
from typing import Any, Callable, Union

from tvm.ir.base import deprecated
from tvm.tir import Buffer, PrimFunc

from ...ir_builder.tir import buffer, ptr
from .._core import parse, utils
from .._core import doc, parse, parse_macro, utils


def prim_func(func: Callable) -> Union[PrimFunc, Callable]:
Expand Down Expand Up @@ -50,6 +50,48 @@ def prim_func(func: Callable) -> Union[PrimFunc, Callable]:
setattr(prim_func, "dispatch_token", "tir")


# Semantics of TIR macros:
# - Function that is decorated with @T.macro can have any parameters that
# follow Python syntax, i.e. positional, keyword, etc. Type annotations
# are not required, but are allowed.
# - The arguments to `T.insert` are: macro name (either as value, or as
# a string with the name), followed by the argument list.
# For `T.insert(arg1, arg2, arg3, ...)`, the values are substituted into
# the body of the macro as in the call `arg1(arg2, arg3, ...)`.
# The body with the substituted values is then inserted at the point
# where the `T.insert` is located.


class TIRMacro:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be specific to TIR, or should it apply to any dialect supported by TVMScript? Thinking that this would be quite useful on the unity branch as well, where a Relax method for an end-to-end model often contains many repeated elements. Implementing those as a macro would also allow Relax's shape propagation to resolve differently for each expansion of the macro (e.g. in a chain of convolutions).

If we want it to be more general, we could move the implementation over to the tvm.script.parser.ir namespace instead.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be I.macro then?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would, yes. Within TVMScript, the default set of global definitions is defined here. This provides both tvm.script.tir as T and tvm.script.ir as I.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we do this in a separate PR?

"""Representation of T.macro: consists of the doc.AST and the text of the source."""

def __init__(self, node, source):
self.doc = node
self.source = source

def __repr__(self):
return self.source


def macro(func: Callable) -> doc.AST:
obj = TIRMacro(*parse_macro(func))
setattr(obj, "__name__", func.__name__)
kparzysz-quic marked this conversation as resolved.
Show resolved Hide resolved
# We don't need to explicitly store the return value anywhere.
kparzysz-quic marked this conversation as resolved.
Show resolved Hide resolved
# This function is a decorator, so the return value will replace
# the function definition (to which the decorator it is applied)
# in that function's name space.
return obj
kparzysz-quic marked this conversation as resolved.
Show resolved Hide resolved


# There is no dispatch_token for macro, because macro doesn't invoke parser.


def insert(name: Union[str, doc.Name], *args, **kwargs) -> Any: # pylint: disable=unused-argument
"""Placeholder function, so that T.insert (i.e. macro insertion)
can be parsed without errors.
"""


class BufferProxy:
"""Buffer proxy class for constructing tir buffer."""

Expand Down
90 changes: 89 additions & 1 deletion python/tvm/script/parser/tir/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@
# under the License.
"""The base parser for tir"""

import ast
import contextlib
from functools import partial
from typing import Any
from typing import Any, Union

import tvm
from tvm.ir import GlobalVar, PrimType
Expand Down Expand Up @@ -427,6 +428,20 @@ def visit_expr_stmt(self: Parser, node: doc.Expr) -> None:
node : doc.Expr
The doc AST Expr node.
"""

def is_insert_macro(node: doc.Call) -> bool:
kparzysz-quic marked this conversation as resolved.
Show resolved Hide resolved
if not isinstance(node.func, doc.Attribute):
return False
attr = node.func
if not isinstance(attr.value, doc.Name):
return False
if attr.value.id != "T" or attr.attr != "insert":
kparzysz-quic marked this conversation as resolved.
Show resolved Hide resolved
return False
return True

if isinstance(node.value, doc.Call) and is_insert_macro(node.value):
return process_insert_macro(self, node.value)

res = self.eval_expr(node.value)
if res is None:
pass
Expand Down Expand Up @@ -528,3 +543,76 @@ def visit_tvm_declare_function(self: Parser, node: doc.FunctionDef) -> GlobalVar
# Only ret_type is needed for func_signature.
func_signature = tvm.tir.PrimFunc([], None, ret_type=ret_type)
return I.decl_function(node.name, func_signature)


def process_insert_macro(self: Parser, call: doc.Call) -> None:
"""Bind arguments to T.insert to the parameters of the macro, and pass the macro body
for further parsing.
"""

def find_macro_def(name: str, decl_list: doc.AST) -> Union[doc.FunctionDef, Any]:
for decl in decl_list:
if isinstance(decl, doc.FunctionDef) and decl.name == name:
return decl
return None

macro_name = call.args[0]

if not isinstance(macro_name, doc.Name):
self.report_error(call, "Invalid macro name in T.insert")
macro_name = macro_name.id

macro = self.var_table.get().get(macro_name)
if macro is None:
self.report_error(node, f"Undefined macro '{macro_name}'")

if isinstance(macro.doc, doc.Module):
macro_def = find_macro_def(macro_name, macro.doc.body)
kparzysz-quic marked this conversation as resolved.
Show resolved Hide resolved
elif not isinstance(macro.doc, doc.FunctionDef) or macro.doc.name != macro_name:
macro_def = None

if macro_def is None:
self.report_error(call, f"Undefined macro {macro_name}")

# `macro_def` is a FunctionDef of the macro.

# We have the AST for the macro definition, and the AST for the call. We need to
# substitute the actual arguments from the call for the parameters from the
# definition. To allow full flexibility of python, i.e. positional, unnamed, and
# keyword parameters, get the python interpreter to do the work: create and execute
# the following:
# ```
# def macro_name(...macro parameters...)
# return locals()
# tmp = macro_name(...arguments from the call...)
# ```
# Obtain the dictionary `tmp` resulting from the execution, and update the var_table
# with it.

# Construct the function with the macro's parameters, and returning locals().
macro_ast = doc.from_doc(macro_def)
kparzysz-quic marked this conversation as resolved.
Show resolved Hide resolved
macro_ast.body = [
ast.Return(value=ast.Call(func=ast.Name("locals", ctx=ast.Load()), args=[], keywords=[]))
]
macro_ast.decorator_list = []

# Construct the assignment with the call.
call_ast = doc.from_doc(call)
call_ast.func = ast.Name(macro_name, ctx=ast.Load())
call_ast.args = call_ast.args[1:]
tmp_name = "__tmp_param_eval_64e98b523301204b"
assign_ast = ast.Assign(targets=[ast.Name(tmp_name, ctx=ast.Store())], value=call_ast)

# Finalize and execute the module:
module_ast = ast.Module(body=[macro_ast, assign_ast], type_ignores=[])
module_ast = ast.fix_missing_locations(module_ast)
cmacro = compile(module_ast, filename="<tmp-string>", mode="exec")
local_vars = {}
exec(cmacro, self.var_table.get(), local_vars) # pylint: disable=exec-used
local_vars = local_vars[tmp_name]

with self.var_table.with_frame():
kparzysz-quic marked this conversation as resolved.
Show resolved Hide resolved
for k, v in local_vars.items():
self.var_table.add(k, v)

self.visit_body(macro_def.body)
28 changes: 28 additions & 0 deletions tests/python/unittest/test_tvmscript_parser_tir.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,5 +71,33 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
assert matmul.__name__ == "matmul"


def test_tir_macro():
@T.macro
def assign(i, *args, t1, **kwargs):
vi, vj, vk = T.axis.remap("SSR", [i, args[0], args[1]])
kwargs["t3"][vi, vj] = kwargs["t3"][vi, vj] + t1[vi, vk] * kwargs["t2"][vj, vk]

@T.prim_func
def matmul_w_macro(a: T.handle, b: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, [128, 128])
B = T.match_buffer(b, [128, 128])
C = T.match_buffer(c, [128, 128])
for i, j, k in T.grid(128, 128, 128):
with T.block("update"):
T.insert(assign, i, j, k, t1=A, t2=B, t3=C)

@T.prim_func
def matmul_no_macro(a: T.handle, b: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, [128, 128])
B = T.match_buffer(b, [128, 128])
C = T.match_buffer(c, [128, 128])
for i, j, k in T.grid(128, 128, 128):
with T.block("update"):
vi, vj, vk = T.axis.remap("SSR", [i, j, k])
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]

tvm.ir.assert_structural_equal(matmul_no_macro, matmul_w_macro)


if __name__ == "__main__":
tvm.testing.main()