-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[TIR] Implement TIR macros #15260
[TIR] Implement TIR macros #15260
Conversation
This patch introduces two new symbols: `T.macro` and `T.insert`. `T.macro` is a decorator that, when applied to a function, turns the body of that function into a piece of TIR that can be inserted via `T.insert` into a PrimFunc. For example: ```python @T.macro def copy_backwards(dst, src, size): with T.block("backwards"): for i in T.serial(size): ai = T.axis.remap("S", [i]) T.reads(src[0:size]) T.writes(dst[0:size]) dst[ai] = src[size - ai - 1] @T.prim_func def foo_int32(A: T.Buffer((128,), "int32"), B: T.Buffer((128,), "int32")): T.insert(copy_backwards, A, B, 128) @T.prim_func def foo_int8(A: T.Buffer((128,), "int8"), B: T.Buffer((128,), "int8")): T.insert(copy_backwards, A, B, 128) ``` The above will generate two PrimFuncs that do the same backwards copy, but applied to buffers with different data types. Semantics: - Function that is decorated with @T.macro can have any parameters that follow Python syntax, i.e. positional, keyword, etc. Type annotations are not required, but are allowed. - The arguments to `T.insert` are macro name followed by the argument list. For `T.insert(arg1, arg2, arg3, ...)`, the values are substituted into the body of the macro as in the call `arg1(arg2, arg3, ...)`. The body with the substituted values is then inserted at the point where the `T.insert` is located.
Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.
Generated by tvm-bot |
3 similar comments
Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.
Generated by tvm-bot |
Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.
Generated by tvm-bot |
Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.
Generated by tvm-bot |
cc: @yzh119 |
One linter suggested something that the other didn't like...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I like it, just some questions on simplifying the implementation, which variables should be in-scope within the body of the macro, and whether macros should be provided for use in Relax as well as TIR.
I also liked your point from the discuss thread about the root tir::Block
being handled differently if trying to generate a Stmt
externally, and agree that it is better to re-parse the body of the macro in order to avoid those special cases.
# where the `T.insert` is located. | ||
|
||
|
||
class TIRMacro: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should this be specific to TIR, or should it apply to any dialect supported by TVMScript? Thinking that this would be quite useful on the unity branch as well, where a Relax method for an end-to-end model often contains many repeated elements. Implementing those as a macro would also allow Relax's shape propagation to resolve differently for each expansion of the macro (e.g. in a chain of convolutions).
If we want it to be more general, we could move the implementation over to the tvm.script.parser.ir
namespace instead.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would it be I.macro
then?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It would, yes. Within TVMScript, the default set of global definitions is defined here. This provides both tvm.script.tir
as T
and tvm.script.ir
as I
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we do this in a separate PR?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the review!
The previous commit inadvertently made T.macro (without parentheses) illegal, only abbreviated form allowed was T.macro(). Restore T.macro as a valid decorator use.
Do you have any more comments? @Lunderberg |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good, and thank you for making the changes!
I think the only open question was which namespace the macro support should go under. By placing it under the tvm.script.tir
namespace, we'll need to keep a backwards-compatibility import if/when we move it to tvm.script.ir
, unless done in the current PR. That said, that would be a pretty useful import anyways, regardless of the backwards compatibility, so I'm not worried there.
Didn't have the chance to review timely but this looks really awesome to me! |
@Lunderberg, @junrushao
Admittedly, I'm wondering if you have any better ideas. |
In Relax, I think this use case could be handled by a function that returns a For multiple return values, I can't remember if Relax's normalizer would remove trivial Tuple bindings, where the function return value is a tuple which gets immediately unpacked, but that would be a reasonable extension to add to it. |
For the purposes of the rest of the comment, I renamed There are some complications with macros returning values, specifically when the call to the macro is a part of a larger expression. The problem is that there are no specific parser visitors for subexpressions---the whole expression is evaluated by the evaluator all at once, so I can't do the same "AST injection" as I did for TIR (at least not without adding detailed visitors and capturing AST.Call nodes). For example,
The entire I'm inclined to follow the TIR route instead, but I wanted to consult with you in case this is a bad idea. |
I think I got it to work: import tvm
from tvm import relax
from tvm.script import relax as R
@R.macro
def alloc_and_shape(dtype: str):
alloc = R.builtin.alloc_tensor(R.shape([4, 4]), runtime_device_index=0, dtype=dtype)
shape = R.shape_of(alloc)
return shape
@R.function
def foo(x: R.Tensor((4, 4), "float32")):
shape = alloc_and_shape(dtype="float32")
return shape
print(alloc_and_shape)
print()
print(foo) Produces @R.macro
def alloc_and_shape(dtype: str):
alloc = R.builtin.alloc_tensor(R.shape([4, 4]), runtime_device_index=0, dtype=dtype)
shape = R.shape_of(alloc)
return shape
# from tvm.script import relax as R
@R.function
def foo(x: R.Tensor((4, 4), dtype="float32")) -> R.Shape([4, 4]):
alloc: R.Tensor((4, 4), dtype="float32") = R.builtin.alloc_tensor(R.shape([4, 4]), R.dtype("float32"), R.prim_value(0))
shape: R.Shape([4, 4]) = R.shape_of(alloc)
shape_1: R.Shape([4, 4]) = shape
return shape_1 This is still via the I'll wait for |
* [TIR] Implement TIR macros This patch introduces two new symbols: `T.macro` and `T.insert`. `T.macro` is a decorator that, when applied to a function, turns the body of that function into a piece of TIR that can be inserted via `T.insert` into a PrimFunc. For example: ```python @T.macro def copy_backwards(dst, src, size): with T.block("backwards"): for i in T.serial(size): ai = T.axis.remap("S", [i]) T.reads(src[0:size]) T.writes(dst[0:size]) dst[ai] = src[size - ai - 1] @T.prim_func def foo_int32(A: T.Buffer((128,), "int32"), B: T.Buffer((128,), "int32")): T.insert(copy_backwards, A, B, 128) @T.prim_func def foo_int8(A: T.Buffer((128,), "int8"), B: T.Buffer((128,), "int8")): T.insert(copy_backwards, A, B, 128) ``` The above will generate two PrimFuncs that do the same backwards copy, but applied to buffers with different data types. Semantics: - Function that is decorated with @T.macro can have any parameters that follow Python syntax, i.e. positional, keyword, etc. Type annotations are not required, but are allowed. - The arguments to `T.insert` are macro name followed by the argument list. For `T.insert(arg1, arg2, arg3, ...)`, the values are substituted into the body of the macro as in the call `arg1(arg2, arg3, ...)`. The body with the substituted values is then inserted at the point where the `T.insert` is located. * Fix linter * Fix linter again One linter suggested something that the other didn't like... * Get rid of T.insert, apply macro via function-call syntax * Store closure vars in TIRMacro * ast.parse always returns ast.Module, hence doc is doc.Module * Simplify `expand_macro`, capture environment variables * Implement macro hygiene * Fix linter * Make T.macro work same as T.macro() The previous commit inadvertently made T.macro (without parentheses) illegal, only abbreviated form allowed was T.macro(). Restore T.macro as a valid decorator use. * Edit comment: insertion -> expansion * Add import pytest * One more typo... * Remove stale testcase
* [TIR] Implement TIR macros This patch introduces two new symbols: `T.macro` and `T.insert`. `T.macro` is a decorator that, when applied to a function, turns the body of that function into a piece of TIR that can be inserted via `T.insert` into a PrimFunc. For example: ```python @T.macro def copy_backwards(dst, src, size): with T.block("backwards"): for i in T.serial(size): ai = T.axis.remap("S", [i]) T.reads(src[0:size]) T.writes(dst[0:size]) dst[ai] = src[size - ai - 1] @T.prim_func def foo_int32(A: T.Buffer((128,), "int32"), B: T.Buffer((128,), "int32")): T.insert(copy_backwards, A, B, 128) @T.prim_func def foo_int8(A: T.Buffer((128,), "int8"), B: T.Buffer((128,), "int8")): T.insert(copy_backwards, A, B, 128) ``` The above will generate two PrimFuncs that do the same backwards copy, but applied to buffers with different data types. Semantics: - Function that is decorated with @T.macro can have any parameters that follow Python syntax, i.e. positional, keyword, etc. Type annotations are not required, but are allowed. - The arguments to `T.insert` are macro name followed by the argument list. For `T.insert(arg1, arg2, arg3, ...)`, the values are substituted into the body of the macro as in the call `arg1(arg2, arg3, ...)`. The body with the substituted values is then inserted at the point where the `T.insert` is located. * Fix linter * Fix linter again One linter suggested something that the other didn't like... * Get rid of T.insert, apply macro via function-call syntax * Store closure vars in TIRMacro * ast.parse always returns ast.Module, hence doc is doc.Module * Simplify `expand_macro`, capture environment variables * Implement macro hygiene * Fix linter * Make T.macro work same as T.macro() The previous commit inadvertently made T.macro (without parentheses) illegal, only abbreviated form allowed was T.macro(). Restore T.macro as a valid decorator use. * Edit comment: insertion -> expansion * Add import pytest * One more typo... * Remove stale testcase
* [TIR] Implement TIR macros This patch introduces two new symbols: `T.macro` and `T.insert`. `T.macro` is a decorator that, when applied to a function, turns the body of that function into a piece of TIR that can be inserted via `T.insert` into a PrimFunc. For example: ```python @T.macro def copy_backwards(dst, src, size): with T.block("backwards"): for i in T.serial(size): ai = T.axis.remap("S", [i]) T.reads(src[0:size]) T.writes(dst[0:size]) dst[ai] = src[size - ai - 1] @T.prim_func def foo_int32(A: T.Buffer((128,), "int32"), B: T.Buffer((128,), "int32")): T.insert(copy_backwards, A, B, 128) @T.prim_func def foo_int8(A: T.Buffer((128,), "int8"), B: T.Buffer((128,), "int8")): T.insert(copy_backwards, A, B, 128) ``` The above will generate two PrimFuncs that do the same backwards copy, but applied to buffers with different data types. Semantics: - Function that is decorated with @T.macro can have any parameters that follow Python syntax, i.e. positional, keyword, etc. Type annotations are not required, but are allowed. - The arguments to `T.insert` are macro name followed by the argument list. For `T.insert(arg1, arg2, arg3, ...)`, the values are substituted into the body of the macro as in the call `arg1(arg2, arg3, ...)`. The body with the substituted values is then inserted at the point where the `T.insert` is located. * Fix linter * Fix linter again One linter suggested something that the other didn't like... * Get rid of T.insert, apply macro via function-call syntax * Store closure vars in TIRMacro * ast.parse always returns ast.Module, hence doc is doc.Module * Simplify `expand_macro`, capture environment variables * Implement macro hygiene * Fix linter * Make T.macro work same as T.macro() The previous commit inadvertently made T.macro (without parentheses) illegal, only abbreviated form allowed was T.macro(). Restore T.macro as a valid decorator use. * Edit comment: insertion -> expansion * Add import pytest * One more typo... * Remove stale testcase
This patch introduces new symbol:
T.macro
.T.macro
is a decorator that, when applied to a function, turns the body of that function into a piece of TIR that can be inserted into a PrimFunc by calling the macro the same way as a function.For example:
The above will generate two PrimFuncs that do the same backwards copy, but applied to buffers with different data types.
Semantics:
macro_name(arg1, arg2, arg3, ...)
, the values are substituted into the body of the macro, and the body with the substituted values is then inserted at the point where the call is located.hygienic
toT.macro
, e.g.T.macro(hygienic=False)
.