Skip to content

Commit

Permalink
[TIR] Update LowerDeviceKernelLaunch to avoid kIsHostFunc
Browse files Browse the repository at this point in the history
Update to use the `tvm::tir::IsHostFunc` utility function, rather than
the `kIsHostFunc` attribute.  Per discussion on
apache#14020, the `kIsHostFunct` attribute
should only be used in `BindTarget`, and should not be re-introduced
in `SplitHostDevice`.
  • Loading branch information
Lunderberg committed May 24, 2023
1 parent 65ab408 commit 7fcc78f
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions src/tir/transforms/lower_device_kernel_launch.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,11 @@ struct KernelInfo {
*/
class DeviceInfoCollector : public StmtVisitor {
public:
static KernelInfo Collect(const GlobalVar& gvar, const PrimFuncNode* func) {
static KernelInfo Collect(const GlobalVar& gvar, const PrimFunc& func) {
DeviceInfoCollector collector;
collector.info_.target = [&]() -> Target {
auto target_attr = func->GetAttr<Target>(tvm::attr::kTarget).value();
bool is_host_func =
func->GetAttr<Bool>(tvm::tir::attr::kIsHostFunc).value_or(Bool(false))->value;
bool is_host_func = IsHostFunc(func).value_or(false);
if (is_host_func) {
if (auto target_host = target_attr->GetHost()) {
return target_host.value();
Expand Down Expand Up @@ -291,8 +290,8 @@ Pass LowerDeviceKernelLaunch() {
auto mutator = [&mod]() {
std::unordered_map<const GlobalVarNode*, KernelInfo> device_info_map;
for (const auto& [gvar, base_func] : mod->functions) {
if (auto* prim_func = base_func.as<PrimFuncNode>()) {
device_info_map[gvar.get()] = DeviceInfoCollector::Collect(gvar, prim_func);
if (auto prim_func = base_func.as<PrimFunc>()) {
device_info_map[gvar.get()] = DeviceInfoCollector::Collect(gvar, prim_func.value());
}
}
return DeviceKernelMutator(std::move(device_info_map));
Expand Down

0 comments on commit 7fcc78f

Please sign in to comment.