Skip to content

Commit

Permalink
[ROCm] Fix some ROCm codegen bugs (#15454)
Browse files Browse the repository at this point in the history
* rocm bug fix:Module hip should be either dso exportable or binary serializable

rocm bug fix: llvm.amdgcn.ds.bpermute Intrinsic has incorrect return type

rocm bug fix:ptr addrspace(3) @shmem Global is external, but doesn't have external or weak linkage

Co-authored-by: zhangxiao-stack <1244360827@qq.com>

* lint

---------

Co-authored-by: zhangxiao-stack <zhangqha@sugon.com>
Co-authored-by: zhangxiao-stack <1244360827@qq.com>
  • Loading branch information
3 people authored Aug 2, 2023
1 parent b77d659 commit bab295e
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 4 deletions.
4 changes: 3 additions & 1 deletion src/runtime/rocm/rocm_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,9 @@ class ROCMModuleNode : public runtime::ModuleNode {
}

const char* type_key() const final { return "hip"; }

int GetPropertyMask() const final {
return ModulePropertyMask::kBinarySerializable | ModulePropertyMask::kRunnable;
}
PackedFunc GetFunction(const String& name, const ObjectPtr<Object>& sptr_to_self) final;

void SaveToFile(const String& file_name, const String& format) final {
Expand Down
4 changes: 2 additions & 2 deletions src/target/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -702,8 +702,8 @@ llvm::GlobalVariable* CodeGenLLVM::AllocateSharedMemory(DataType dtype, size_t s
llvm::GlobalValue::LinkageTypes linkage) {
llvm::Type* type = llvm::ArrayType::get(DTypeToLLVMType(dtype), size);
llvm::GlobalVariable* global =
new llvm::GlobalVariable(*module_, type, false, linkage, nullptr, "shmem", nullptr,
llvm::GlobalValue::NotThreadLocal, shared_address_space);
new llvm::GlobalVariable(*module_, type, false, linkage, llvm::UndefValue::get(type), "shmem",
nullptr, llvm::GlobalValue::NotThreadLocal, shared_address_space);
#if TVM_LLVM_VERSION >= 100
global->setAlignment(llvm::Align(alignment));
#else
Expand Down
2 changes: 1 addition & 1 deletion src/tir/transforms/lower_thread_allreduce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -729,7 +729,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
// rocm only supports 32 bit operands for shuffling at the moment
if ((target_->kind->name == "rocm") &&
(std::any_of(types.begin(), types.end(), [](DataType ty) {
if (ty.is_vector()) return true;
if ((ty.is_vector()) || !ty.is_int()) return true;
return ty.bits() != 32;
}))) {
return false;
Expand Down

0 comments on commit bab295e

Please sign in to comment.