Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[LLVM] Encapsulate LLVM target for use with LLVM libraries #11933

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions include/tvm/support/with.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,11 @@ class With {
/*! \brief destructor, leaves the scope of the context. */
~With() DMLC_THROW_EXCEPTION { ctx_.ExitWithScope(); }

ContextType* operator->() { return &ctx_; }
const ContextType* operator->() const { return &ctx_; }
ContextType& operator*() { return ctx_; }
const ContextType* operator*() const { return ctx_; }

private:
/*! \brief internal context type. */
ContextType ctx_;
Expand Down
30 changes: 12 additions & 18 deletions src/target/llvm/codegen_amdgpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
#include "../../runtime/rocm/rocm_module.h"
#include "../build_common.h"
#include "codegen_llvm.h"
#include "llvm_common.h"
#include "llvm_scope.h"

namespace tvm {
namespace codegen {
Expand Down Expand Up @@ -238,48 +238,41 @@ class CodeGenAMDGPU : public CodeGenLLVM {
}

protected:
void InitTarget(llvm::TargetMachine* tm) final {
void InitTarget() final {
// Maximum vector lane = float4
native_vector_bits_ = 4 * 32;
CodeGenLLVM::InitTarget(tm);
CodeGenLLVM::InitTarget();
}
};

runtime::Module BuildAMDGPU(IRModule mod, Target target) {
LLVMScope llvm_scope;

With<LLVMTarget> llvm_target(llvm_scope, target);
#if TVM_LLVM_VERSION < 90
LOG(FATAL) << "AMDGPU backend requires at least LLVM 9";
// Lower versions will crash when loading the bitcode, see
// issue #4087 for a discussion
#endif
InitializeLLVM();
std::unique_ptr<llvm::TargetMachine> tm = GetLLVMTargetMachine(target);
std::unique_ptr<llvm::LLVMContext> ctx(new llvm::LLVMContext());
// careful: cg will hold a naked pointer reference to ctx, so it should
// have a shorter lifetime than the ctx.
std::unique_ptr<CodeGenAMDGPU> cg(new CodeGenAMDGPU());

cg->Init("TVMAMDGPUModule", tm.get(), ctx.get(), false, false, false);
cg->Init("TVMAMDGPUModule", llvm_target.operator->(), false, false, false);

cg->AddFunctionsOrdered(mod->functions.begin(), mod->functions.end(), [](auto& kv) {
ICHECK(kv.second->template IsInstance<PrimFuncNode>())
<< "Can only lower IR Module with PrimFuncs";
return Downcast<PrimFunc>(kv.second);
});

llvm::TargetMachine* tm = llvm_target->GetOrCreateTargetMachine();
const auto* find_rocm_bitcodes = tvm::runtime::Registry::Get("tvm_callback_rocm_bitcode_path");
Array<runtime::String> bitcode_files = (*find_rocm_bitcodes)();

for (auto& bitcode_path : bitcode_files) {
std::string path = bitcode_path;
llvm::SMDiagnostic err;
std::unique_ptr<llvm::Module> mlib = llvm::parseIRFile(path, err, *ctx);
if (mlib.get() == nullptr) {
std::string msg(err.getMessage());
LOG(FATAL) << "Fail to load bitcode file " << path << "\n"
<< "line " << err.getLineNo() << ":" << msg;
}
mlib->setTargetTriple(tm->getTargetTriple().str());
std::unique_ptr<llvm::Module> mlib = llvm_scope.LoadIR(bitcode_path);
mlib->setTargetTriple(llvm_target->GetTargetTriple());
mlib->setDataLayout(tm->createDataLayout());

for (llvm::Function& f : mlib->functions()) {
f.addFnAttr(llvm::Attribute::AlwaysInline);
}
Expand Down Expand Up @@ -351,4 +344,5 @@ TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_rocm")

} // namespace codegen
} // namespace tvm

#endif // TVM_LLVM_VERSION
5 changes: 3 additions & 2 deletions src/target/llvm/codegen_arm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,10 @@ class CodeGenARM final : public CodeGenCPU {
CodeGenARM() = default;
virtual ~CodeGenARM() = default;

void InitTarget(llvm::TargetMachine* tm) final {
void InitTarget() final {
// set native vector bits.
native_vector_bits_ = 16 * 8;
CodeGenCPU::InitTarget(tm);
CodeGenCPU::InitTarget();
}
llvm::Value* CreateIntrinsic(const CallNode* op) override;

Expand Down Expand Up @@ -139,4 +139,5 @@ TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_arm")

} // namespace codegen
} // namespace tvm

#endif // TVM_LLVM_VERSION
24 changes: 10 additions & 14 deletions src/target/llvm/codegen_blob.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,25 +52,20 @@
#include <string>
#include <utility>

#include "llvm_common.h"
#include "llvm_scope.h"

namespace tvm {
namespace codegen {

std::pair<std::unique_ptr<llvm::Module>, std::shared_ptr<llvm::LLVMContext>> CodeGenBlob(
const std::string& data, bool system_lib, const std::string& llvm_target_string) {
InitializeLLVM();
Target target(llvm_target_string);
auto tm = GetLLVMTargetMachine(target);
auto triple = tm->getTargetTriple();
auto ctx = std::make_shared<llvm::LLVMContext>();
std::unique_ptr<llvm::Module> CodeGenBlob(const std::string& data, bool system_lib,
LLVMTarget* llvm_target) {
llvm::TargetMachine* tm = llvm_target->GetOrCreateTargetMachine();
const llvm::Triple& triple = tm->getTargetTriple();
llvm::LLVMContext* ctx = llvm_target->GetContext();
std::string module_name = "devc";
std::unique_ptr<llvm::Module> module(new llvm::Module(module_name, *ctx));
auto module = std::make_unique<llvm::Module>(module_name, *ctx);
module->setTargetTriple(triple.str());
// Store full target string in metadata, because flags such as -mfloat-abi must be preserved for
// ModulePackImportsToLLVM.
module->addModuleFlag(llvm::Module::ModFlagBehavior::Override, "tvm_target",
llvm::MDString::get(*ctx, LLVMTargetToString(target)));
llvm_target->SetTargetMetadata(module.get());
module->setDataLayout(tm->createDataLayout());
auto* blob_value = llvm::ConstantDataArray::getString(*ctx, data, false);
auto* tvm_dev_mblob = new llvm::GlobalVariable(
Expand Down Expand Up @@ -188,9 +183,10 @@ std::pair<std::unique_ptr<llvm::Module>, std::shared_ptr<llvm::LLVMContext>> Cod
ir_builder.CreateRetVoid();
}

return std::make_pair(std::move(module), ctx);
return module;
}

} // namespace codegen
} // namespace tvm

#endif // TVM_LLVM_VERSION
15 changes: 9 additions & 6 deletions src/target/llvm/codegen_blob.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,18 @@

#ifdef TVM_LLVM_VERSION

#include <llvm/IR/LLVMContext.h>
#include <llvm/IR/Module.h>

#include <memory>
#include <string>
#include <utility>

namespace llvm {
class Module;
}

namespace tvm {
namespace codegen {

class LLVMTarget;

/**
* \brief Code Generation of blob data
*
Expand All @@ -44,8 +47,8 @@ namespace codegen {
*
* \return LLVM module and LLVM context
*/
std::pair<std::unique_ptr<llvm::Module>, std::shared_ptr<llvm::LLVMContext>> CodeGenBlob(
const std::string& data, bool system_lib, const std::string& llvm_target_string);
std::unique_ptr<llvm::Module> CodeGenBlob(const std::string& data, bool system_lib,
LLVMTarget* llvm_target);

} // namespace codegen
} // namespace tvm
Expand Down
Loading