-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[TIR] Enable Host Func Attribute for PrimFunc #14020
Conversation
Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.
Generated by tvm-bot |
Hey Xiyou, My understanding of I probably don't have much context on this but was curious:
|
Hi, thanks for checking my PR this late! Very good questions! Let me share some of the context here. We are trying to support a dynamic shape operator on Cuda. This function is generated during a relax pass called
Apparently, it's supposed to be running on CPU, i.e, the host instead of the device. However, since this pass doesn't have access to the target information, when the function is generated it doesn't include the target in its attribute. Therefore, we would like to add an attribute to automatically bind it to the target host in For Q1, It does not fail because this pass is after |
I'm not sure if I'm missing anything, but I do think the tvm/src/tir/analysis/verify_memory.cc Lines 180 to 185 in d7253fb
|
Yes, it's available in |
Thanks @zxybazh . After looking at the discussions especially inputs from junru. I think it would be great to clarify that we want After explicit target being attached to the function then such attr is no longer necessary and can be a source of duplication. So we only need changes for BindTarget here and possibly a UT that pass only |
Thanks @zxybazh @tqchen for the clarification - this has been much clear to me now! Let's do the following change:
|
src/tir/transforms/primfunc_utils.cc
Outdated
@@ -30,6 +30,12 @@ namespace tir { | |||
namespace transform { | |||
transform::Pass BindTarget(Target target) { | |||
auto fpass = [target](tir::PrimFunc f, IRModule m, transform::PassContext ctx) { | |||
if (f->GetAttr<Integer>(tvm::tir::attr::kIsHostFunc) == 1) { | |||
return WithAttrs(std::move(f), Map<String, ObjectRef>{ | |||
{tvm::attr::kTarget, target->host.value_or(Target("llvm"))}, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure if this is the best option when target host is not available. This is my impression on the default target host.
Thanks for the careful review and discussion. I've removed duplicate changes and created a unittest that checks target and host func attribute. Please take another look :) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One final nit
@@ -30,6 +30,10 @@ namespace tir { | |||
namespace transform { | |||
transform::Pass BindTarget(Target target) { | |||
auto fpass = [target](tir::PrimFunc f, IRModule m, transform::PassContext ctx) { | |||
if (f->GetAttr<Integer>(tvm::tir::attr::kIsHostFunc) == 1) { | |||
return WithAttr(std::move(WithoutAttr(std::move(f), tvm::tir::attr::kIsHostFunc)), | |||
tvm::attr::kTarget, target->host.value_or(Target("llvm"))); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there a case where the target host is None
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, when we use target tags like nvidia/geforce-rtx-3070
the default target host is None
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
got it. that makes sense!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Feel free to merge it in once the CI is green!
Update to use the `tvm::tir::IsHostFunc` utility function, rather than the `kIsHostFunc` attribute. Per discussion on apache#14020, the `kIsHostFunct` attribute should only be used in `BindTarget`, and should not be re-introduced in `SplitHostDevice`.
Per discussion on apache#14020, the `kIsHostFunct` attribute should only be used in `BindTarget`, and should not be re-introduced in `SplitHostDevice`.
Update to use the `tvm::tir::IsHostFunc` utility function, rather than the `kIsHostFunc` attribute. Per discussion on apache#14020, the `kIsHostFunct` attribute should only be used in `BindTarget`, and should not be re-introduced in `SplitHostDevice`.
Per discussion on apache#14020, the `kIsHostFunct` attribute should only be used in `BindTarget`, and should not be re-introduced in `SplitHostDevice`.
Update to use the `tvm::tir::IsHostFunc` utility function, rather than the `kIsHostFunc` attribute. Per discussion on apache#14020, the `kIsHostFunct` attribute should only be used in `BindTarget`, and should not be re-introduced in `SplitHostDevice`.
Per discussion on apache#14020, the `kIsHostFunct` attribute should only be used in `BindTarget`, and should not be re-introduced in `SplitHostDevice`.
Update to use the `tvm::tir::IsHostFunc` utility function, rather than the `kIsHostFunc` attribute. Per discussion on apache#14020, the `kIsHostFunct` attribute should only be used in `BindTarget`, and should not be re-introduced in `SplitHostDevice`.
This PR refactors SplitHostDevice into three separate transformations. Previously, SplitHostDevice would replace device regions with a builtin::tvm_call_packed() node to replace the extracted region. After this PR, this process is performed in three separate steps. AnnotateDeviceRegion: Annotate the regions that should be executed on another target. SplitHostDevice: Extract the annotated region into an independent PrimFunc, with a GlobalVar to represent the call from into the new subroutine. LowerDeviceKernelLaunch: For any subroutine call where the caller and callee are on different devices, replace with a device kernel launch. * PR#14915 [TVMScript] Allow T.target("device", host="host") in TVMScript Prior to this commit, the `TargetNode::host` could be specified in TVMScript as part of the config dictionary, under the key `"host"`. However, this required all other device parameters to be explicitly specified, rather than using any of the short-hand string representations. This commit forwards the `host` argument from TVMScript's `T.target` method to `tvm.target.Target`, allowing both the device and host to be specified using the shorthand string representation. ```python @T.prim_func def before_this_commit(): T.func_attr( { "target": T.target( { "arch": "sm_86", "host": {"keys": ["cpu"], "kind": "llvm", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024, "tag": "", "thread_warp_size": 32, } ) } ) T.evaluate(0) @T.prim_func def after_this_commit(): T.func_attr({"target": T.target("cuda", host="llvm")}) T.evaluate(0) ``` * [Target] Added WithoutHost method * [TIR] SplitHostDevice, handle missing kGlobalSymbol Previously, the symbol name of the extracted compute kernel was defined based on the `kGlobalSymbol` attribute, which was required to be present. This commit updates `SplitHostDevice` to generate the symbol name using `kGlobalSymbol` if present, and to fall back to the name of the `tvm::GlobalVar` for internal functions. * [TIR] Refactor SplitHostDevice into three separate passes First pass, `AnnotateDeviceRegions`. This pass decides which portions of a PrimFunc should be run on the device, and annotates them with `kTarget` attribute, indicating which target should be used for later lowering steps. Second pass, `SplitHostDevice`. This pass extracts the annotated region into an independent PrimFunc. The `kTarget` attribute of the extracted kernel is defined by the `kTarget` annotation inserted by `AnnotateDeviceRegions`. The host function is marked by the `tvm::tir::attr::kIsHostFunc` attribute, allowing it to be recognized by later host-only lowering passes. Third pass, `LowerDeviceKernelLaunch`. This pass identifies subroutine calls that call into device kernels, and rewrites them into `T.tvm_call_packed`. * Add unit tests specifically for SplitHostDevice behavior * Added unit test specifically for AnnotateDeviceRegions * Added unit tests for LowerDeviceKernelLaunch * Minor cleanup, moved all kernel launch collection into one spot Previously, the SplitHostDevice pass added the `tir::attr::kKernelLaunchParams` attribute, and the LowerDeviceKernelLaunch pass filled in the values for it. This cleanup makes the kernel launch params be the sole responsibility of LowerDeviceKernelLaunch. * Updated unit tests for LowerWarpMemory * Updated unit tests for ThreadSync * Updated unit test for inject ptx async copy * [Bugfix] Avoid symbol conflicts in MakePackedAPI/MakeUnpackedAPI PRs #14913 and #14914 made analogous changes to `MakePackedAPI` and `MakeUnpackedAPI` to handle subroutine calls. Both PRs introduced the same symbol, `tvm::tir::SubroutineCallRewriter`, a local utility to update internal calls to a modified function. While each PR passed CI individually, and was therefore able to merge, having both changes caused a duplicate symbol. This commit updates `MakePackedAPI` and `MakeUnpackedAPI` to place their local utilities into anonymous namespaces, avoiding the conflict. * Maintain "tir.is_global_func" attr in device-side entry point * SplitHostDevice, update the host-side target to be the host * [TIR] Update LowerDeviceKernelLaunch to avoid kIsHostFunc Update to use the `tvm::tir::IsHostFunc` utility function, rather than the `kIsHostFunc` attribute. Per discussion on #14020, the `kIsHostFunct` attribute should only be used in `BindTarget`, and should not be re-introduced in `SplitHostDevice`. * Remove is_host_func from SplitHostDevice tests
This PR refactors SplitHostDevice into three separate transformations. Previously, SplitHostDevice would replace device regions with a builtin::tvm_call_packed() node to replace the extracted region. After this PR, this process is performed in three separate steps. AnnotateDeviceRegion: Annotate the regions that should be executed on another target. SplitHostDevice: Extract the annotated region into an independent PrimFunc, with a GlobalVar to represent the call from into the new subroutine. LowerDeviceKernelLaunch: For any subroutine call where the caller and callee are on different devices, replace with a device kernel launch. * PR#14915 [TVMScript] Allow T.target("device", host="host") in TVMScript Prior to this commit, the `TargetNode::host` could be specified in TVMScript as part of the config dictionary, under the key `"host"`. However, this required all other device parameters to be explicitly specified, rather than using any of the short-hand string representations. This commit forwards the `host` argument from TVMScript's `T.target` method to `tvm.target.Target`, allowing both the device and host to be specified using the shorthand string representation. ```python @T.prim_func def before_this_commit(): T.func_attr( { "target": T.target( { "arch": "sm_86", "host": {"keys": ["cpu"], "kind": "llvm", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024, "tag": "", "thread_warp_size": 32, } ) } ) T.evaluate(0) @T.prim_func def after_this_commit(): T.func_attr({"target": T.target("cuda", host="llvm")}) T.evaluate(0) ``` * [Target] Added WithoutHost method * [TIR] SplitHostDevice, handle missing kGlobalSymbol Previously, the symbol name of the extracted compute kernel was defined based on the `kGlobalSymbol` attribute, which was required to be present. This commit updates `SplitHostDevice` to generate the symbol name using `kGlobalSymbol` if present, and to fall back to the name of the `tvm::GlobalVar` for internal functions. * [TIR] Refactor SplitHostDevice into three separate passes First pass, `AnnotateDeviceRegions`. This pass decides which portions of a PrimFunc should be run on the device, and annotates them with `kTarget` attribute, indicating which target should be used for later lowering steps. Second pass, `SplitHostDevice`. This pass extracts the annotated region into an independent PrimFunc. The `kTarget` attribute of the extracted kernel is defined by the `kTarget` annotation inserted by `AnnotateDeviceRegions`. The host function is marked by the `tvm::tir::attr::kIsHostFunc` attribute, allowing it to be recognized by later host-only lowering passes. Third pass, `LowerDeviceKernelLaunch`. This pass identifies subroutine calls that call into device kernels, and rewrites them into `T.tvm_call_packed`. * Add unit tests specifically for SplitHostDevice behavior * Added unit test specifically for AnnotateDeviceRegions * Added unit tests for LowerDeviceKernelLaunch * Minor cleanup, moved all kernel launch collection into one spot Previously, the SplitHostDevice pass added the `tir::attr::kKernelLaunchParams` attribute, and the LowerDeviceKernelLaunch pass filled in the values for it. This cleanup makes the kernel launch params be the sole responsibility of LowerDeviceKernelLaunch. * Updated unit tests for LowerWarpMemory * Updated unit tests for ThreadSync * Updated unit test for inject ptx async copy * [Bugfix] Avoid symbol conflicts in MakePackedAPI/MakeUnpackedAPI PRs apache#14913 and apache#14914 made analogous changes to `MakePackedAPI` and `MakeUnpackedAPI` to handle subroutine calls. Both PRs introduced the same symbol, `tvm::tir::SubroutineCallRewriter`, a local utility to update internal calls to a modified function. While each PR passed CI individually, and was therefore able to merge, having both changes caused a duplicate symbol. This commit updates `MakePackedAPI` and `MakeUnpackedAPI` to place their local utilities into anonymous namespaces, avoiding the conflict. * Maintain "tir.is_global_func" attr in device-side entry point * SplitHostDevice, update the host-side target to be the host * [TIR] Update LowerDeviceKernelLaunch to avoid kIsHostFunc Update to use the `tvm::tir::IsHostFunc` utility function, rather than the `kIsHostFunc` attribute. Per discussion on apache#14020, the `kIsHostFunct` attribute should only be used in `BindTarget`, and should not be re-introduced in `SplitHostDevice`. * Remove is_host_func from SplitHostDevice tests
This PR enables a new attribute
kIsHostFunc
to ensure certrain prim func is run on CPU, for exampleshape_func
that computes shape information dynamically. With the new attribute, the primfunc will be skipped in verification pass and split host device pass. A unit test is added.CC: @sunggg @YuchenJin @tqchen