diff --git a/python/tvm/script/parser/_core.py b/python/tvm/script/parser/_core.py index 4f5411dc368f..b7ba5ee4713f 100644 --- a/python/tvm/script/parser/_core.py +++ b/python/tvm/script/parser/_core.py @@ -18,5 +18,5 @@ # pylint: disable=unused-import from .core import dispatch, doc, utils from .core.dispatch import OpMethod, register_op -from .core.entry import parse +from .core.entry import parse, parse_macro from .core.parser import Parser diff --git a/python/tvm/script/parser/core/entry.py b/python/tvm/script/parser/core/entry.py index 5315c0f6755e..08a593d5d31b 100644 --- a/python/tvm/script/parser/core/entry.py +++ b/python/tvm/script/parser/core/entry.py @@ -25,6 +25,25 @@ from .parser import Parser +def _default_globals() -> Dict[str, Any]: + import tvm # pylint: disable=import-outside-toplevel + from tvm.script.parser import ir # pylint: disable=import-outside-toplevel + from tvm.script.parser import tir # pylint: disable=import-outside-toplevel + + extra_vars = {"tvm": tvm, "I": ir, "ir": ir, "T": tir, "tir": tir} + return extra_vars + + +def parse_macro(program: Union[Any, str], extra_vars: Dict[str, Any] = None) -> Any: + """Generate the AST, and the source code for __repr__.""" + # The AST will be converted into TIR at the time of expansion. + source = Source(program) + source_txt = source.source + source_ast = source.as_ast() + closure_vars = extra_vars or _default_globals() + return source_ast, source_txt, closure_vars + + def parse(program: Union[doc.AST, Any, str], extra_vars: Dict[str, Any] = None) -> Any: """Register a method for a operand type, AST operator node and operand index. @@ -42,17 +61,7 @@ def parse(program: Union[doc.AST, Any, str], extra_vars: Dict[str, Any] = None) The parsed TVMScript program. """ if extra_vars is None: - import tvm # pylint: disable=import-outside-toplevel - from tvm.script.parser import ir # pylint: disable=import-outside-toplevel - from tvm.script.parser import tir # pylint: disable=import-outside-toplevel - - extra_vars = { - "tvm": tvm, - "I": ir, - "ir": ir, - "T": tir, - "tir": tir, - } + extra_vars = _default_globals() ann = {} if inspect.isfunction(program): diff --git a/python/tvm/script/parser/tir/__init__.py b/python/tvm/script/parser/tir/__init__.py index ad16821a89a3..9d3fc1ec98da 100644 --- a/python/tvm/script/parser/tir/__init__.py +++ b/python/tvm/script/parser/tir/__init__.py @@ -30,6 +30,6 @@ # so most tvmscript won't trigger pylint error here. prim_func = staticmethod else: - from .entry import prim_func + from .entry import prim_func, macro -__all__ = _tir.__all__ + ["Buffer", "Ptr", "prim_func"] +__all__ = _tir.__all__ + ["Buffer", "Ptr", "prim_func", "macro"] diff --git a/python/tvm/script/parser/tir/entry.py b/python/tvm/script/parser/tir/entry.py index d5bff7a856d5..64b71d699f3d 100644 --- a/python/tvm/script/parser/tir/entry.py +++ b/python/tvm/script/parser/tir/entry.py @@ -16,13 +16,13 @@ # under the License. """The entry point of TVM parser for tir.""" import inspect -from typing import Callable, Union +from typing import Any, Callable, Dict, Union from tvm.ir.base import deprecated from tvm.tir import Buffer, PrimFunc from ...ir_builder.tir import buffer, ptr -from .._core import parse, utils +from .._core import doc, parse, parse_macro, utils def prim_func(func: Callable) -> Union[PrimFunc, Callable]: @@ -50,6 +50,101 @@ def prim_func(func: Callable) -> Union[PrimFunc, Callable]: setattr(prim_func, "dispatch_token", "tir") +# Semantics of TIR macros: +# - Function that is decorated with @T.macro can have any parameters that +# follow Python syntax, i.e. positional, keyword, etc. Type annotations +# are not required, but are allowed. +# - Macro use follows the same syntax as a function call. +# For `macro_name(arg1, arg2, arg3, ...)`, the values are substituted into +# the body of the macro, and the body with the substituted values is then +# inserted at the point where the call to the macro is located. + + +class TIRMacro: + """Representation of T.macro.""" + + def __init__( + self, + source_ast: doc.AST, + source_txt: str, + closure_vars: Dict[str, Any], + func: Callable, + hygienic: bool, + ) -> None: + self.source_ast = source_ast + self.source_txt = source_txt + self.closure_vars = closure_vars + self.func = func + self.hygienic = hygienic + + def __repr__(self): + return self.source_txt + + +def macro(*args, hygienic: bool = True) -> Callable: + """Decorator for macro definitions. + + Parameters + ---------- + hygienic: bool + Specifies whether the macro is hygienic or not. + A macro is hygienic if all symbols used in the macro's body are resolved + to values from the location of the macro definition. A non-hygienic macro + will have its symbols resolved to values at the time of the macro's use. + + Example: + ``` + import tvm + from tvm.script import tir as T + + x_value = 128 + + @T.macro(hygienic=True) + def static_capture(A, B): + B[()] = A[x_value] ### x_value binds to 128 + + @T.macro(hygienic=False) + def dynamic_capture(A, B): + B[()] = A[x_value] ### x_value will bind at the time of use + + + @T.prim_func + def use1(A: T.Buffer((1024,), "int32"), B: T.Buffer((), "int32")) -> None: + for x_value in T.serial(10): + static_capture(A, B) ### Produces B[()] = A[128] + + @T.prim_func + def use2(A: T.Buffer((1024,), "int32"), B: T.Buffer((), "int32")) -> None: + for x_value in T.serial(10): + dynamic_capture(A, B) ### Produces B[()] = A[x_value] + ``` + """ + + def _decorator(func: Callable) -> TIRMacro: + source_ast, source_txt, closure_vars = parse_macro( + func, utils.inspect_function_capture(func) + ) + obj = TIRMacro(source_ast, source_txt, closure_vars, func, hygienic) + obj.__name__ = func.__name__ + # We don't need to explicitly store the return value anywhere. + # This function is a decorator, so the return value will replace + # the function definition (to which the decorator it is applied) + # in that function's name space. + return obj + + if len(args) == 0: + return _decorator + if len(args) == 1 and inspect.isfunction(args[0]): + return _decorator(args[0]) + + raise ValueError( + "Invalid use of T.macro. Usage: @T.macro, @T.macro(), @T.macro(hygienic=[True|False])" + ) + + +# There is no dispatch_token for macro, because macro doesn't invoke parser. + + class BufferProxy: """Buffer proxy class for constructing tir buffer.""" diff --git a/python/tvm/script/parser/tir/parser.py b/python/tvm/script/parser/tir/parser.py index f81f9bd9ea78..67e14d0e9772 100644 --- a/python/tvm/script/parser/tir/parser.py +++ b/python/tvm/script/parser/tir/parser.py @@ -17,8 +17,9 @@ """The base parser for tir""" import contextlib +import inspect from functools import partial -from typing import Any +from typing import Any, Union import tvm from tvm.ir import GlobalVar, PrimType @@ -29,6 +30,8 @@ from ...ir_builder.base import IRBuilder from ...ir_builder.base import IRBuilderFrame as Frame from .._core import Parser, dispatch, doc +from ..core.parser import VarTable +from .entry import TIRMacro def bind_with_value(self: Parser, node: doc.expr, var_name: str, value: Any) -> Any: @@ -427,6 +430,12 @@ def visit_expr_stmt(self: Parser, node: doc.Expr) -> None: node : doc.Expr The doc AST Expr node. """ + + if isinstance(node.value, doc.Call): + callee = self.eval_expr(node.value.func) + if isinstance(callee, TIRMacro): + return expand_macro(self, callee, node.value) + res = self.eval_expr(node.value) if res is None: pass @@ -447,6 +456,7 @@ def visit_expr_stmt(self: Parser, node: doc.Expr) -> None: pass else: self.report_error(node, f"Parsing resulted in unexpected type {type(res)}") + return None # For pylint @dispatch.register(token="tir", type_name="If") @@ -528,3 +538,51 @@ def visit_tvm_declare_function(self: Parser, node: doc.FunctionDef) -> GlobalVar # Only ret_type is needed for func_signature. func_signature = tvm.tir.PrimFunc([], None, ret_type=ret_type) return I.decl_function(node.name, func_signature) + + +def expand_macro(self: Parser, callee: TIRMacro, call: doc.Call) -> None: + """Bind arguments to the macro invocation to the parameters in the macro definition, + and pass the macro body for further parsing. + """ + + assert isinstance(callee, TIRMacro), f"Unexpected macro type {type(callee)}" + + def find_macro_def(name: str, decl_list: doc.AST) -> Union[doc.FunctionDef, Any]: + for decl in decl_list: + if isinstance(decl, doc.FunctionDef) and decl.name == name: + return decl + return None + + macro_def = find_macro_def(callee.__name__, callee.source_ast.body) + assert macro_def is not None, f"Invalid macro AST for {callee.__name__}" + # `macro_def` is the FunctionDef of the macro. + + args = [self.eval_expr(arg) for arg in call.args] + kwargs = {kw.arg: self.eval_expr(kw.value) for kw in call.keywords} + param_binding = inspect.signature(callee.func).bind(*args, **kwargs) + param_binding.apply_defaults() + local_vars = param_binding.arguments + + if callee.hygienic: + # If the macro was hygienic, construct new var_table with a single frame that + # contains the captured environment, and process the macro's body with that + # frame. + saved_var_table = self.var_table + self.var_table = VarTable() + with self.var_table.with_frame(): + for k, v in callee.closure_vars.items(): + self.var_table.add(k, v) + for k, v in local_vars.items(): + self.var_table.add(k, v) + + self.visit_body(macro_def.body) + + self.var_table = saved_var_table + + else: + # Otherwise, dynamically resolve symbols in the macro's body. + with self.var_table.with_frame(): + for k, v in local_vars.items(): + self.var_table.add(k, v) + + self.visit_body(macro_def.body) diff --git a/tests/python/unittest/test_tvmscript_parser_tir.py b/tests/python/unittest/test_tvmscript_parser_tir.py index 31bf5cc10180..38d3e1474656 100644 --- a/tests/python/unittest/test_tvmscript_parser_tir.py +++ b/tests/python/unittest/test_tvmscript_parser_tir.py @@ -16,6 +16,7 @@ # under the License. """Unittests for tvm.script.parser.tir""" +import pytest import tvm.testing from tvm.script.parser import tir as T from tvm import ir, tir @@ -71,5 +72,111 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: assert matmul.__name__ == "matmul" +def test_tir_macro_decorator_signature(): + @T.prim_func + def evaluate0(): + T.evaluate(0) + + # Ok, no parentheses + @T.macro + def func1(): + T.evaluate(0) + + assert func1.hygienic + + @T.prim_func + def use1(): + func1() + + tvm.ir.assert_structural_equal(use1, evaluate0) + + # Ok, empty parentheses + @T.macro() + def func2(): + T.evaluate(0) + + assert func2.hygienic + + @T.prim_func + def use2(): + func2() + + tvm.ir.assert_structural_equal(use1, evaluate0) + + with pytest.raises(ValueError): + # Wrong: non-keyword argument + @T.macro(True) + def func3(): + T.evaluate() + + +def test_tir_macro_signature(): + @T.macro + def assign(i, *args, t1, **kwargs): + vi, vj, vk = T.axis.remap("SSR", [i, args[0], args[1]]) + kwargs["t3"][vi, vj] = kwargs["t3"][vi, vj] + t1[vi, vk] * kwargs["t2"][vj, vk] + + @T.prim_func + def matmul_w_macro(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [128, 128]) + B = T.match_buffer(b, [128, 128]) + C = T.match_buffer(c, [128, 128]) + for i, j, k in T.grid(128, 128, 128): + with T.block("update"): + assign(i, j, k, t1=A, t2=B, t3=C) + + @T.prim_func + def matmul_no_macro(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [128, 128]) + B = T.match_buffer(b, [128, 128]) + C = T.match_buffer(c, [128, 128]) + for i, j, k in T.grid(128, 128, 128): + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + + tvm.ir.assert_structural_equal(matmul_no_macro, matmul_w_macro) + + +def test_tir_macro_hygienic(): + x_value = 128 + + @T.macro(hygienic=True) + def static_capture(A, B): + B[()] = A[x_value] + + @T.prim_func + def use_hygienic(A: T.Buffer((1024,), "int32"), B: T.Buffer((), "int32")) -> None: + for x_value in T.serial(10): + static_capture(A, B) + + @T.prim_func + def expected_hygienic(A: T.Buffer((1024,), "int32"), B: T.Buffer((), "int32")) -> None: + for x_value in range(10): + B[()] = A[128] + + tvm.ir.assert_structural_equal(use_hygienic, expected_hygienic) + + +def test_tir_macro_non_hygienic(): + x_value = 128 + + @T.macro(hygienic=False) + def dynamic_capture(A, B): + B[()] = A[x_value] + + @T.prim_func + def use_non_hygienic(A: T.Buffer((1024,), "int32"), B: T.Buffer((), "int32")) -> None: + for x_value in T.serial(10): + dynamic_capture(A, B) + + @T.prim_func + def expected_non_hygienic(A: T.Buffer((1024,), "int32"), B: T.Buffer((), "int32")) -> None: + for x_value in range(10): + B[()] = A[x_value] + + tvm.ir.assert_structural_equal(use_non_hygienic, expected_non_hygienic) + + if __name__ == "__main__": tvm.testing.main()