Skip to content

Commit

Permalink
[TIR] SplitHostDevice, handle subroutines (apache#14918)
Browse files Browse the repository at this point in the history
This PR refactors SplitHostDevice into three separate transformations. Previously, SplitHostDevice would replace device regions with a builtin::tvm_call_packed() node to replace the extracted region. After this PR, this process is performed in three separate steps.

AnnotateDeviceRegion: Annotate the regions that should be executed on another target.
SplitHostDevice: Extract the annotated region into an independent PrimFunc, with a GlobalVar to represent the call from into the new subroutine.
LowerDeviceKernelLaunch: For any subroutine call where the caller and callee are on different devices, replace with a device kernel launch.

* PR#14915 [TVMScript] Allow T.target("device", host="host") in TVMScript

Prior to this commit, the `TargetNode::host` could be specified in
TVMScript as part of the config dictionary, under the key `"host"`.
However, this required all other device parameters to be explicitly
specified, rather than using any of the short-hand string
representations.  This commit forwards the `host` argument from TVMScript's
`T.target` method to `tvm.target.Target`, allowing both the device and
host to be specified using the shorthand string representation.

```python
@T.prim_func
def before_this_commit():
    T.func_attr(
        {
            "target": T.target(
                {
                    "arch": "sm_86",
                    "host": {"keys": ["cpu"], "kind": "llvm", "tag": ""},
                    "keys": ["cuda", "gpu"],
                    "kind": "cuda",
                    "max_num_threads": 1024,
                    "tag": "",
                    "thread_warp_size": 32,
                }
            )
        }
    )
    T.evaluate(0)

@T.prim_func
def after_this_commit():
    T.func_attr({"target": T.target("cuda", host="llvm")})
    T.evaluate(0)
```

* [Target] Added WithoutHost method

* [TIR] SplitHostDevice, handle missing kGlobalSymbol

Previously, the symbol name of the extracted compute kernel was
defined based on the `kGlobalSymbol` attribute, which was required to
be present.  This commit updates `SplitHostDevice` to generate the
symbol name using `kGlobalSymbol` if present, and to fall back to the
name of the `tvm::GlobalVar` for internal functions.

* [TIR] Refactor SplitHostDevice into three separate passes

First pass, `AnnotateDeviceRegions`.  This pass decides which portions
of a PrimFunc should be run on the device, and annotates them with
`kTarget` attribute, indicating which target should be used for later
lowering steps.

Second pass, `SplitHostDevice`.  This pass extracts the annotated
region into an independent PrimFunc.  The `kTarget` attribute of the
extracted kernel is defined by the `kTarget` annotation inserted by
`AnnotateDeviceRegions`.  The host function is marked by the
`tvm::tir::attr::kIsHostFunc` attribute, allowing it to be recognized
by later host-only lowering passes.

Third pass, `LowerDeviceKernelLaunch`.  This pass identifies
subroutine calls that call into device kernels, and rewrites them into
`T.tvm_call_packed`.

* Add unit tests specifically for SplitHostDevice behavior

* Added unit test specifically for AnnotateDeviceRegions

* Added unit tests for LowerDeviceKernelLaunch

* Minor cleanup, moved all kernel launch collection into one spot

Previously, the SplitHostDevice pass added the
`tir::attr::kKernelLaunchParams` attribute, and the
LowerDeviceKernelLaunch pass filled in the values for it.  This
cleanup makes the kernel launch params be the sole responsibility of
LowerDeviceKernelLaunch.

* Updated unit tests for LowerWarpMemory

* Updated unit tests for ThreadSync

* Updated unit test for inject ptx async copy

* [Bugfix] Avoid symbol conflicts in MakePackedAPI/MakeUnpackedAPI

PRs apache#14913 and
apache#14914 made analogous changes to
`MakePackedAPI` and `MakeUnpackedAPI` to handle subroutine calls.
Both PRs introduced the same symbol,
`tvm::tir::SubroutineCallRewriter`, a local utility to update internal
calls to a modified function.  While each PR passed CI individually,
and was therefore able to merge, having both changes caused a
duplicate symbol.

This commit updates `MakePackedAPI` and `MakeUnpackedAPI` to place
their local utilities into anonymous namespaces, avoiding the
conflict.

* Maintain "tir.is_global_func" attr in device-side entry point

* SplitHostDevice, update the host-side target to be the host

* [TIR] Update LowerDeviceKernelLaunch to avoid kIsHostFunc

Update to use the `tvm::tir::IsHostFunc` utility function, rather than
the `kIsHostFunc` attribute.  Per discussion on
apache#14020, the `kIsHostFunct` attribute
should only be used in `BindTarget`, and should not be re-introduced
in `SplitHostDevice`.

* Remove is_host_func from SplitHostDevice tests
  • Loading branch information
Lunderberg authored and mei-ye committed Jun 1, 2023
1 parent 527ba43 commit 260eb4b
Show file tree
Hide file tree
Showing 13 changed files with 908 additions and 239 deletions.
38 changes: 38 additions & 0 deletions include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -263,13 +263,51 @@ TVM_DLL Pass LowerCustomDatatypes();
*/
TVM_DLL Pass DecorateDeviceScope();

/*!
* \brief Annotate locations that should be run on the device
*
* Insert `AttrStmt` nodes specifying a target on which regions within
* the PrimFunc should be executed. Only modifies functions that have
* a `tvm::attr::kTarget` attribute, and where that target defines a
* host.
*
* \return The pass.
*/
TVM_DLL Pass AnnotateDeviceRegions();

/*!
* \brief Split the function into a host function and device functions.
*
* The resulting host-side function will keep the same
* `tvm::attr::kTarget` attribute (e.g. `T.target("cuda",
* host=T.target("llvm"))`). This ensures that `MakePackedAPI` knows
* which device type should be used for the input buffers.
*
* The resulting device-side function will
* have the host stripped from its target attribute
* (e.g. `T.target("cuda")`).
*
* \return The pass.
*/
TVM_DLL Pass SplitHostDevice();

/*!
* \brief Lower cross-device function calls.
*
* Prior to this pass, host to device calls are represented as
* subroutine calls, with environment parameters (e.g. env_thread)
* specified internally. The device function is an internal function,
* without a `tvm::attr::kGlobalSymbol` attribute.
*
* After this pass, host to device calls are represented as
* tvm_call_packed built-in. The device function is an
* externally-exposed function, with a non-empty
* `tvm::attr::kGlobalSymbol` attribute.
*
* \return The pass.
*/
TVM_DLL Pass LowerDeviceKernelLaunch();

/*!
* \brief skip assert stmt.
*
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/tir/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,7 @@ def call_tir(global_var: tvm.ir.GlobalVar, *args):
The call expression.
"""
assert isinstance(global_var, tvm.ir.GlobalVar)
return Call(dtype="handle", op=global_var, args=args)
return Call(dtype="void", op=global_var, args=args)


def start_profile_intrinsic(id):
Expand Down
38 changes: 38 additions & 0 deletions python/tvm/tir/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,22 @@ def MakeUnpackedAPI():
return _ffi_api.MakeUnpackedAPI() # type: ignore


def AnnotateDeviceRegions():
"""Annotate locations that should be run on the device
Insert `AttrStmt` nodes specifying a target on which regions
within the PrimFunc should be executed. Only modifies functions
that have a `tvm::attr::kTarget` attribute, and where that target
defines a host.
Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.AnnotateDeviceRegions() # type: ignore


def SplitHostDevice():
"""Split the function into a host function and device functions.
Expand All @@ -446,6 +462,28 @@ def SplitHostDevice():
return _ffi_api.SplitHostDevice() # type: ignore


def LowerDeviceKernelLaunch():
"""Lower cross-device function calls.
Prior to this pass, host to device calls are represented as
subroutine calls, with environment parameters (e.g. env_thread)
specified internally. The device function is an internal
function, without a `tvm::attr::kGlobalSymbol` attribute.
After this pass, host to device calls are represented as
tvm_call_packed built-in. The device function is an
externally-exposed function, with a non-empty
`tvm::attr::kGlobalSymbol` attribute.
Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.LowerDeviceKernelLaunch() # type: ignore


def DecorateDeviceScope():
"""Decorate all the function's body as device function.
Expand Down
3 changes: 3 additions & 0 deletions src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -587,7 +587,10 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target)
mixed_pass_list.push_back(tir::transform::MakePackedAPI());
}
mixed_pass_list.push_back(tir::transform::BF16StorageLegalize());

mixed_pass_list.push_back(tir::transform::AnnotateDeviceRegions());
mixed_pass_list.push_back(tir::transform::SplitHostDevice());
mixed_pass_list.push_back(tir::transform::LowerDeviceKernelLaunch());

return transform::Sequential(mixed_pass_list);
}
Expand Down
81 changes: 81 additions & 0 deletions src/tir/transforms/annotate_device_regions.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file annotate_device_regions.cc
* \brief Split device function from host.
*/
#include <tvm/ir/transform.h>
#include <tvm/runtime/registry.h>
#include <tvm/target/target.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

namespace tvm {
namespace tir {

class DeviceRegionAnnotater : public StmtMutator {
public:
explicit DeviceRegionAnnotater(Target device_target) : device_target_(device_target) {}

Stmt VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == tvm::attr::kTarget) {
// If a target attribute already exists, use it as-is.
return GetRef<Stmt>(op);
} else if (op->attr_key == attr::thread_extent || op->attr_key == attr::pipeline_exec_scope ||
op->attr_key == attr::device_scope) {
// These attributes are only allowed in device-side code, so
// they should be annotated with the function's default target.
Stmt body = GetRef<Stmt>(op);
return AttrStmt(device_target_, tvm::attr::kTarget, 0, body);
} else {
// All other annotations are ignored
return StmtMutator::VisitStmt_(op);
}
}

private:
Target device_target_;
};

namespace transform {

Pass AnnotateDeviceRegions() {
auto pass_func = [](PrimFunc func, IRModule mod, PassContext ctx) -> PrimFunc {
auto opt_target = func->GetAttr<Target>(tvm::attr::kTarget);
ICHECK(opt_target) << "AnnotateDeviceRegions: Require the target attribute";
Target target = opt_target.value();

if (target->GetHost()) {
DeviceRegionAnnotater mutator(target.WithoutHost());
func.CopyOnWrite()->body = mutator(func->body);
}
return func;
};

return CreatePrimFuncPass(pass_func, 0, "tir.AnnotateDeviceRegions", {});
}

TVM_REGISTER_GLOBAL("tir.transform.AnnotateDeviceRegions").set_body_typed(AnnotateDeviceRegions);

} // namespace transform
} // namespace tir
} // namespace tvm
Loading

0 comments on commit 260eb4b

Please sign in to comment.