Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIR] Implement TIR macros #15260

Merged
merged 14 commits into from
Jul 11, 2023
Prev Previous commit
Next Next commit
Implement macro hygiene
Krzysztof Parzyszek committed Jul 8, 2023

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
commit de6e34cbb6465bbea1df395e31441e0ba42b7293
77 changes: 63 additions & 14 deletions python/tvm/script/parser/tir/entry.py
Original file line number Diff line number Diff line change
@@ -61,29 +61,78 @@ def prim_func(func: Callable) -> Union[PrimFunc, Callable]:


class TIRMacro:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be specific to TIR, or should it apply to any dialect supported by TVMScript? Thinking that this would be quite useful on the unity branch as well, where a Relax method for an end-to-end model often contains many repeated elements. Implementing those as a macro would also allow Relax's shape propagation to resolve differently for each expansion of the macro (e.g. in a chain of convolutions).

If we want it to be more general, we could move the implementation over to the tvm.script.parser.ir namespace instead.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be I.macro then?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would, yes. Within TVMScript, the default set of global definitions is defined here. This provides both tvm.script.tir as T and tvm.script.ir as I.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we do this in a separate PR?

"""Representation of T.macro: consists of the doc.AST and the text of the source."""
"""Representation of T.macro."""

func: Callable

def __init__(self, source_ast: doc.AST, source_txt: str, closure_vars: Dict[str, Any]) -> None:
def __init__(
self,
source_ast: doc.AST,
source_txt: str,
closure_vars: Dict[str, Any],
func: Callable,
hygienic: bool,
) -> None:
self.source_ast = source_ast
self.source_txt = source_txt
self.closure_vars = closure_vars
self.func = None
self.func = func
self.hygienic = hygienic

def __repr__(self):
return self.source_txt


def macro(func: Callable) -> doc.AST:
obj = TIRMacro(*parse_macro(func, utils.inspect_function_capture(func)))
obj.__name__ = func.__name__
obj.func = func
# We don't need to explicitly store the return value anywhere.
# This function is a decorator, so the return value will replace
# the function definition (to which the decorator it is applied)
# in that function's name space.
return obj
def macro(*, hygienic: bool = True) -> Callable:
"""Decorator for macro definitions.

Parameters
----------
hygienic: bool
Specifies whether the macro is hygienic or not.
A macro is hygienic if all symbols used in the macro's body are resolved
to values from the location of the macro definition. A non-hygienic macro
will have its symbols resolved to values at the time of the macro's use.

Example:
```
import tvm
from tvm.script import tir as T

x_value = 128

@T.macro(hygienic=True)
def static_capture(A, B):
B[()] = A[x_value] ### x_value binds to 128

@T.macro(hygienic=False)
def dynamic_capture(A, B):
B[()] = A[x_value] ### x_value will bind at the time of use


@T.prim_func
def use1(A: T.Buffer((1024,), "int32"), B: T.Buffer((), "int32")) -> None:
for x_value in T.serial(10):
static_capture(A, B) ### Produces B[()] = A[128]

@T.prim_func
def use2(A: T.Buffer((1024,), "int32"), B: T.Buffer((), "int32")) -> None:
for x_value in T.serial(10):
dynamic_capture(A, B) ### Produces B[()] = A[x_value]
```
"""

def _decorator(func: Callable) -> TIRMacro:
source_ast, source_txt, closure_vars = parse_macro(
func, utils.inspect_function_capture(func)
)
obj = TIRMacro(source_ast, source_txt, closure_vars, func, hygienic)
obj.__name__ = func.__name__
# We don't need to explicitly store the return value anywhere.
# This function is a decorator, so the return value will replace
# the function definition (to which the decorator it is applied)
# in that function's name space.
return obj

return _decorator


# There is no dispatch_token for macro, because macro doesn't invoke parser.
27 changes: 23 additions & 4 deletions python/tvm/script/parser/tir/parser.py
Original file line number Diff line number Diff line change
@@ -31,6 +31,7 @@
from ...ir_builder.base import IRBuilder
from ...ir_builder.base import IRBuilderFrame as Frame
from .._core import Parser, dispatch, doc
from ..core.parser import VarTable
from .entry import TIRMacro


@@ -562,8 +563,26 @@ def find_macro_def(name: str, decl_list: doc.AST) -> Union[doc.FunctionDef, Any]
param_binding.apply_defaults()
local_vars = param_binding.arguments

with self.var_table.with_frame():
for k, v in local_vars.items():
self.var_table.add(k, v)
if callee.hygienic:
# If the macro was hygienic, construct new var_table with a single frame that
# contains the captured environment, and process the macro's body with that
# frame.
saved_var_table = self.var_table
self.var_table = VarTable()
with self.var_table.with_frame():
for k, v in callee.closure_vars.items():
self.var_table.add(k, v)
for k, v in local_vars.items():
self.var_table.add(k, v)

self.visit_body(macro_def.body)

self.var_table = saved_var_table

else:
# Otherwise, dynamically resolve symbols in the macro's body.
with self.var_table.with_frame():
for k, v in local_vars.items():
self.var_table.add(k, v)

self.visit_body(macro_def.body)
self.visit_body(macro_def.body)
42 changes: 41 additions & 1 deletion tests/python/unittest/test_tvmscript_parser_tir.py
Original file line number Diff line number Diff line change
@@ -71,7 +71,7 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
assert matmul.__name__ == "matmul"


def test_tir_macro():
def test_tir_macro_signature():
@T.macro
def assign(i, *args, t1, **kwargs):
vi, vj, vk = T.axis.remap("SSR", [i, args[0], args[1]])
@@ -99,5 +99,45 @@ def matmul_no_macro(a: T.handle, b: T.handle, c: T.handle) -> None:
tvm.ir.assert_structural_equal(matmul_no_macro, matmul_w_macro)


def test_tir_macro_hygienic():
x_value = 128

@T.macro(hygienic=True)
def static_capture(A, B):
B[()] = A[x_value]

@T.prim_func
def use_hygienic(A: T.Buffer((1024,), "int32"), B: T.Buffer((), "int32")) -> None:
for x_value in T.serial(10):
static_capture(A, B)

@T.prim_func
def expected_hygienic(A: T.Buffer((1024,), "int32"), B: T.Buffer((), "int32")) -> None:
for x_value in range(10):
B[()] = A[128]

tvm.ir.assert_structural_equal(use_hygienic, expected_hygienic)


def test_tir_macro_non_hygienic():
x_value = 128

@T.macro(hygienic=False)
def dynamic_capture(A, B):
B[()] = A[x_value]

@T.prim_func
def use_non_hygienic(A: T.Buffer((1024,), "int32"), B: T.Buffer((), "int32")) -> None:
for x_value in T.serial(10):
dynamic_capture(A, B)

@T.prim_func
def expected_non_hygienic(A: T.Buffer((1024,), "int32"), B: T.Buffer((), "int32")) -> None:
for x_value in range(10):
B[()] = A[x_value]

tvm.ir.assert_structural_equal(use_non_hygienic, expected_non_hygienic)


if __name__ == "__main__":
tvm.testing.main()