From de1acf4e8b87615af8cc42e304f49c3c9bffb74b Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Mon, 18 Sep 2023 19:54:33 -0400 Subject: [PATCH] [Codegen][ROCm] Mismatched Dtype of Workgroup/Workitem This PR fixes a ROCm codegen error that the dtype of `@llvm.amdgcn.workgroup.id*` and `@llvm.amdgcn.workitem.id.*` are always i32 when generating LLVM IR, even if it's marked as T.int64 in TIR. An example that triggers this issue: ```python @T.prim_func def encode_kernel(A: T.handle("float16", "global"), max_abs_value: T.handle("float16", "global"), v: T.int64): T.func_attr({"calling_conv": 2, "target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "tag": ""}, "keys": ["rocm", "gpu"], "kind": "rocm", "max_num_threads": 256, "max_shared_memory_per_block": 65536, "max_threads_per_block": 1024, "mcpu": "gfx1100", "mtriple": "amdgcn-amd-amdhsa-hcc", "tag": "", "thread_warp_size": 32}), "tir.is_global_func": T.bool(True), "tir.kernel_launch_params": ["blockIdx.x", "threadIdx.x"], "tir.noalias": T.bool(True)}) A_1 = T.decl_buffer((v * T.int64(8192),), "float16", data=A) max_abs_value_1 = T.decl_buffer((T.min(v, (v * T.int64(256) + T.int64(65535)) // T.int64(65536) * T.int64(256)) * T.int64(256),), "float16", data=max_abs_value) blockIdx_x = T.launch_thread("blockIdx.x", T.int64(256)) threadIdx_x = T.launch_thread("threadIdx.x", T.int64(256)) for i_j_fused_0, k in T.grid(T.shift_right(v + T.int64(255), T.int64(8)), T.int64(32)): if i_j_fused_0 * T.int64(256) + blockIdx_x - v < T.int64(0): if k == T.int64(0): max_abs_value_1[i_j_fused_0 * T.int64(65536) + blockIdx_x * T.int64(256) + threadIdx_x] = T.float16(-65504) max_abs_value_1[i_j_fused_0 * T.int64(65536) + blockIdx_x * T.int64(256) + threadIdx_x] = T.max(max_abs_value_1[i_j_fused_0 * T.int64(65536) + blockIdx_x * T.int64(256) + threadIdx_x], T.call_pure_extern("float16", "__ocml_fabs_f16", A_1[i_j_fused_0 * T.int64(2097152) + blockIdx_x * T.int64(8192) + threadIdx_x * T.int64(32) + k])) ``` --- src/target/llvm/codegen_amdgpu.cc | 3 ++- src/target/llvm/intrin_rule_rocm.cc | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/target/llvm/codegen_amdgpu.cc b/src/target/llvm/codegen_amdgpu.cc index d95f985fe63f..0ab4771f6c69 100644 --- a/src/target/llvm/codegen_amdgpu.cc +++ b/src/target/llvm/codegen_amdgpu.cc @@ -187,7 +187,8 @@ class CodeGenAMDGPU : public CodeGenLLVM { } } llvm::Function* f = llvm::Intrinsic::getDeclaration(module_.get(), intrin_id); - return builder_->CreateCall(f, {}); + llvm::Value* result = builder_->CreateCall(f, {}); + return this->CreateCast(DataType::Int(32), iv->var->dtype, result); } llvm::Value* CreateStorageSync(const CallNode* op) final { diff --git a/src/target/llvm/intrin_rule_rocm.cc b/src/target/llvm/intrin_rule_rocm.cc index 17baaf3e657a..d25126f5d828 100644 --- a/src/target/llvm/intrin_rule_rocm.cc +++ b/src/target/llvm/intrin_rule_rocm.cc @@ -89,7 +89,7 @@ inline PrimExpr DispatchShuffle(const PrimExpr& e) { index = self + delta; index = Select((self & (width - 1)) + delta >= width, self, index); } - PrimExpr res = Call(var.dtype(), builtin::call_pure_extern(), + PrimExpr res = Call(DataType::Int(32), builtin::call_pure_extern(), {StringImm("llvm.amdgcn.ds.bpermute"), index << 2, var}); return res; }