Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[TIR] SplitHostDevice, handle subroutines (apache#14918)
This PR refactors SplitHostDevice into three separate transformations. Previously, SplitHostDevice would replace device regions with a builtin::tvm_call_packed() node to replace the extracted region. After this PR, this process is performed in three separate steps. AnnotateDeviceRegion: Annotate the regions that should be executed on another target. SplitHostDevice: Extract the annotated region into an independent PrimFunc, with a GlobalVar to represent the call from into the new subroutine. LowerDeviceKernelLaunch: For any subroutine call where the caller and callee are on different devices, replace with a device kernel launch. * PR#14915 [TVMScript] Allow T.target("device", host="host") in TVMScript Prior to this commit, the `TargetNode::host` could be specified in TVMScript as part of the config dictionary, under the key `"host"`. However, this required all other device parameters to be explicitly specified, rather than using any of the short-hand string representations. This commit forwards the `host` argument from TVMScript's `T.target` method to `tvm.target.Target`, allowing both the device and host to be specified using the shorthand string representation. ```python @T.prim_func def before_this_commit(): T.func_attr( { "target": T.target( { "arch": "sm_86", "host": {"keys": ["cpu"], "kind": "llvm", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024, "tag": "", "thread_warp_size": 32, } ) } ) T.evaluate(0) @T.prim_func def after_this_commit(): T.func_attr({"target": T.target("cuda", host="llvm")}) T.evaluate(0) ``` * [Target] Added WithoutHost method * [TIR] SplitHostDevice, handle missing kGlobalSymbol Previously, the symbol name of the extracted compute kernel was defined based on the `kGlobalSymbol` attribute, which was required to be present. This commit updates `SplitHostDevice` to generate the symbol name using `kGlobalSymbol` if present, and to fall back to the name of the `tvm::GlobalVar` for internal functions. * [TIR] Refactor SplitHostDevice into three separate passes First pass, `AnnotateDeviceRegions`. This pass decides which portions of a PrimFunc should be run on the device, and annotates them with `kTarget` attribute, indicating which target should be used for later lowering steps. Second pass, `SplitHostDevice`. This pass extracts the annotated region into an independent PrimFunc. The `kTarget` attribute of the extracted kernel is defined by the `kTarget` annotation inserted by `AnnotateDeviceRegions`. The host function is marked by the `tvm::tir::attr::kIsHostFunc` attribute, allowing it to be recognized by later host-only lowering passes. Third pass, `LowerDeviceKernelLaunch`. This pass identifies subroutine calls that call into device kernels, and rewrites them into `T.tvm_call_packed`. * Add unit tests specifically for SplitHostDevice behavior * Added unit test specifically for AnnotateDeviceRegions * Added unit tests for LowerDeviceKernelLaunch * Minor cleanup, moved all kernel launch collection into one spot Previously, the SplitHostDevice pass added the `tir::attr::kKernelLaunchParams` attribute, and the LowerDeviceKernelLaunch pass filled in the values for it. This cleanup makes the kernel launch params be the sole responsibility of LowerDeviceKernelLaunch. * Updated unit tests for LowerWarpMemory * Updated unit tests for ThreadSync * Updated unit test for inject ptx async copy * [Bugfix] Avoid symbol conflicts in MakePackedAPI/MakeUnpackedAPI PRs apache#14913 and apache#14914 made analogous changes to `MakePackedAPI` and `MakeUnpackedAPI` to handle subroutine calls. Both PRs introduced the same symbol, `tvm::tir::SubroutineCallRewriter`, a local utility to update internal calls to a modified function. While each PR passed CI individually, and was therefore able to merge, having both changes caused a duplicate symbol. This commit updates `MakePackedAPI` and `MakeUnpackedAPI` to place their local utilities into anonymous namespaces, avoiding the conflict. * Maintain "tir.is_global_func" attr in device-side entry point * SplitHostDevice, update the host-side target to be the host * [TIR] Update LowerDeviceKernelLaunch to avoid kIsHostFunc Update to use the `tvm::tir::IsHostFunc` utility function, rather than the `kIsHostFunc` attribute. Per discussion on apache#14020, the `kIsHostFunct` attribute should only be used in `BindTarget`, and should not be re-introduced in `SplitHostDevice`. * Remove is_host_func from SplitHostDevice tests
- Loading branch information