Skip to content

Commit

Permalink
[LLVM] Include LLVM headers in files that use them, not in llvm_commo…
Browse files Browse the repository at this point in the history
…n.h (#11888)

This is following the same principle we use everywhere else in TVM, that
is, every source file includes headers that it depends on. While including
unnecessary LLVM headers (which may happen by including llvm_common.h)
is not actively harmful, it makes the header dependencies much less trans-
parent.
  • Loading branch information
Krzysztof Parzyszek authored Jun 25, 2022
1 parent 98bf40f commit 59fb421
Show file tree
Hide file tree
Showing 19 changed files with 374 additions and 91 deletions.
41 changes: 33 additions & 8 deletions src/target/llvm/codegen_amdgpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,35 @@
*/
#ifdef TVM_LLVM_VERSION

#include <llvm/ADT/SmallString.h>
#include <llvm/IR/Attributes.h>
#include <llvm/IR/CallingConv.h>
#include <llvm/IR/Function.h>
#include <llvm/IR/GlobalValue.h>
#include <llvm/IR/Instructions.h>
#include <llvm/IR/Intrinsics.h>
#if TVM_LLVM_VERSION >= 100
#include <llvm/IR/IntrinsicsAMDGPU.h>
#endif
#include <llvm/IR/LegacyPassManager.h>
#include <llvm/IRReader/IRReader.h>
#if TVM_LLVM_VERSION >= 100
#include <llvm/Support/Alignment.h>
#endif
#include <llvm/Support/CodeGen.h>
#include <llvm/Support/SourceMgr.h>
#include <llvm/Support/raw_ostream.h>
#include <llvm/Target/TargetMachine.h>
#include <llvm/Transforms/IPO/PassManagerBuilder.h>
#include <llvm/Transforms/Utils/Cloning.h>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/registry.h>

#include "../../runtime/rocm/rocm_module.h"
#include "../build_common.h"
#include "codegen_llvm.h"
#include "llvm_common.h"

namespace tvm {
namespace codegen {
Expand Down Expand Up @@ -60,6 +82,9 @@ static inline int DetectROCMmaxThreadsPerBlock() {
// AMDGPU code generator.
class CodeGenAMDGPU : public CodeGenLLVM {
public:
CodeGenAMDGPU() = default;
virtual ~CodeGenAMDGPU() = default;

void AddFunction(const PrimFunc& f) final {
// add function as void return value
CodeGenLLVM::AddFunctionInternal(f, true);
Expand Down Expand Up @@ -128,17 +153,17 @@ class CodeGenAMDGPU : public CodeGenLLVM {
// Return the thread index via intrinsics.
llvm::Value* GetThreadIndex(const IterVar& iv) final {
runtime::ThreadScope ts = runtime::ThreadScope::Create(iv->thread_tag);
llvm::Intrinsic::ID intrin_id = ::llvm::Intrinsic::amdgcn_workitem_id_x;
llvm::Intrinsic::ID intrin_id = llvm::Intrinsic::amdgcn_workitem_id_x;
if (ts.rank == 1) {
switch (ts.dim_index) {
case 0:
intrin_id = ::llvm::Intrinsic::amdgcn_workitem_id_x;
intrin_id = llvm::Intrinsic::amdgcn_workitem_id_x;
break;
case 1:
intrin_id = ::llvm::Intrinsic::amdgcn_workitem_id_y;
intrin_id = llvm::Intrinsic::amdgcn_workitem_id_y;
break;
case 2:
intrin_id = ::llvm::Intrinsic::amdgcn_workitem_id_z;
intrin_id = llvm::Intrinsic::amdgcn_workitem_id_z;
break;
default:
LOG(FATAL) << "unknown workitem idx";
Expand All @@ -147,13 +172,13 @@ class CodeGenAMDGPU : public CodeGenLLVM {
ICHECK_EQ(ts.rank, 0);
switch (ts.dim_index) {
case 0:
intrin_id = ::llvm::Intrinsic::amdgcn_workgroup_id_x;
intrin_id = llvm::Intrinsic::amdgcn_workgroup_id_x;
break;
case 1:
intrin_id = ::llvm::Intrinsic::amdgcn_workgroup_id_y;
intrin_id = llvm::Intrinsic::amdgcn_workgroup_id_y;
break;
case 2:
intrin_id = ::llvm::Intrinsic::amdgcn_workgroup_id_z;
intrin_id = llvm::Intrinsic::amdgcn_workgroup_id_z;
break;
default:
LOG(FATAL) << "unknown workgroup idx";
Expand All @@ -169,7 +194,7 @@ class CodeGenAMDGPU : public CodeGenLLVM {
return nullptr;
} else if (sync == "shared") {
llvm::Function* f =
llvm::Intrinsic::getDeclaration(module_.get(), ::llvm::Intrinsic::amdgcn_s_barrier);
llvm::Intrinsic::getDeclaration(module_.get(), llvm::Intrinsic::amdgcn_s_barrier);
return builder_->CreateCall(f, {});
} else {
LOG(FATAL) << "Do not support sync " << sync;
Expand Down
14 changes: 11 additions & 3 deletions src/target/llvm/codegen_arm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,12 @@
*/
#ifdef TVM_LLVM_VERSION

#include <llvm/IR/Intrinsics.h>
#include <tvm/runtime/registry.h>
#if TVM_LLVM_VERSION >= 100
#include <llvm/IR/IntrinsicsARM.h>
#endif
#include <llvm/Target/TargetMachine.h>

#include "codegen_cpu.h"

Expand All @@ -34,6 +39,9 @@ namespace codegen {
// how to override behavior llvm code generator for specific target
class CodeGenARM final : public CodeGenCPU {
public:
CodeGenARM() = default;
virtual ~CodeGenARM() = default;

void InitTarget(llvm::TargetMachine* tm) final {
// set native vector bits.
native_vector_bits_ = 16 * 8;
Expand All @@ -48,7 +56,7 @@ class CodeGenARM final : public CodeGenCPU {
llvm::Value* CodeGenARM::CreateIntrinsic(const CallNode* op) {
if (op->op.same_as(builtin_call_llvm_intrin_) || op->op.same_as(builtin_call_llvm_pure_intrin_)) {
llvm::Intrinsic::ID id = static_cast<llvm::Intrinsic::ID>(Downcast<IntImm>(op->args[0])->value);
if (id == ::llvm::Intrinsic::ctpop) {
if (id == llvm::Intrinsic::ctpop) {
PrimExpr e = ARMPopcount(op);
return CodeGenCPU::CreateIntrinsic(e.as<CallNode>());
}
Expand All @@ -59,8 +67,8 @@ llvm::Value* CodeGenARM::CreateIntrinsic(const CallNode* op) {
PrimExpr CodeGenARM::ARMPopcount(const CallNode* call) {
using namespace tir;
const PrimExpr& e = call->args[2];
::llvm::Intrinsic::ID ctpop_id = ::llvm::Intrinsic::ctpop;
::llvm::Intrinsic::ID vpaddlu_id = ::llvm::Intrinsic::arm_neon_vpaddlu;
llvm::Intrinsic::ID ctpop_id = llvm::Intrinsic::ctpop;
llvm::Intrinsic::ID vpaddlu_id = llvm::Intrinsic::arm_neon_vpaddlu;

// Fallback to default llvm lowering rule if input type not a full vector or half vector length
int total_size = call->dtype.bits() * call->dtype.lanes();
Expand Down
26 changes: 26 additions & 0 deletions src/target/llvm/codegen_blob.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,38 @@
* \file codegen_blob.cc
*/
#ifdef TVM_LLVM_VERSION

#include "codegen_blob.h"

#include <llvm/ADT/SmallVector.h>
#include <llvm/ADT/StringRef.h>
#include <llvm/ADT/Triple.h>
#include <llvm/ADT/Twine.h>
#include <llvm/IR/BasicBlock.h>
#include <llvm/IR/Constants.h>
#include <llvm/IR/DerivedTypes.h>
#include <llvm/IR/Function.h>
#include <llvm/IR/GlobalVariable.h>
#include <llvm/IR/IRBuilder.h>
#include <llvm/IR/LLVMContext.h>
#include <llvm/IR/Metadata.h>
#include <llvm/IR/Module.h>
#include <llvm/IR/Type.h>
#include <llvm/IR/Value.h>
#if TVM_LLVM_VERSION >= 100
#include <llvm/Support/Alignment.h>
#endif
#include <llvm/Target/TargetMachine.h>
#include <llvm/Transforms/Utils/ModuleUtils.h>
#include <tvm/runtime/module.h>
#include <tvm/target/target.h>

#include <cstring>
#include <memory>
#include <string>
#include <utility>

#include "llvm_common.h"

namespace tvm {
namespace codegen {
Expand Down
8 changes: 6 additions & 2 deletions src/target/llvm/codegen_blob.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,16 @@
*/
#ifndef TVM_TARGET_LLVM_CODEGEN_BLOB_H_
#define TVM_TARGET_LLVM_CODEGEN_BLOB_H_

#ifdef TVM_LLVM_VERSION

#include <llvm/IR/LLVMContext.h>
#include <llvm/IR/Module.h>

#include <memory>
#include <string>
#include <utility>

#include "llvm_common.h"

namespace tvm {
namespace codegen {
/**
Expand All @@ -46,5 +49,6 @@ std::pair<std::unique_ptr<llvm::Module>, std::shared_ptr<llvm::LLVMContext>> Cod

} // namespace codegen
} // namespace tvm

#endif // LLVM_VERSION
#endif // TVM_TARGET_LLVM_CODEGEN_BLOB_H_
31 changes: 31 additions & 0 deletions src/target/llvm/codegen_cpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,31 @@

#include "codegen_cpu.h"

#include <llvm/ADT/SmallVector.h>
#include <llvm/ADT/StringRef.h>
#include <llvm/IR/Argument.h>
#include <llvm/IR/Attributes.h>
#include <llvm/IR/BasicBlock.h>
#include <llvm/IR/CallingConv.h>
#include <llvm/IR/Comdat.h>
#include <llvm/IR/Constants.h>
#include <llvm/IR/DIBuilder.h>
#include <llvm/IR/DebugInfoMetadata.h>
#include <llvm/IR/DebugLoc.h>
#include <llvm/IR/DerivedTypes.h>
#include <llvm/IR/Function.h>
#include <llvm/IR/GlobalVariable.h>
#include <llvm/IR/Instructions.h>
#include <llvm/IR/LLVMContext.h>
#include <llvm/IR/MDBuilder.h>
#include <llvm/IR/Metadata.h>
#include <llvm/IR/Module.h>
#if TVM_LLVM_VERSION >= 100
#include <llvm/Support/Alignment.h>
#endif
#include <llvm/Support/raw_ostream.h>
#include <llvm/Target/TargetMachine.h>
#include <llvm/Transforms/Utils/ModuleUtils.h>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/module.h>
#include <tvm/tir/analysis.h>
Expand All @@ -35,9 +60,15 @@

#include "../func_registry_generator.h"
#include "../metadata_utils.h"

namespace tvm {
namespace codegen {

// Make these non-inline because of std::unique_ptr. See comment in
// codegen_llvm.cc for more information.
CodeGenCPU::CodeGenCPU() = default;
CodeGenCPU::~CodeGenCPU() = default;

void CodeGenCPU::Init(const std::string& module_name, llvm::TargetMachine* tm,
llvm::LLVMContext* ctx, bool system_lib, bool dynamic_lookup,
bool target_c_runtime) {
Expand Down
22 changes: 22 additions & 0 deletions src/target/llvm/codegen_cpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,34 @@

#include "codegen_llvm.h"

namespace llvm {
class BasicBlock;
class Constant;
class DIBuilder;
class DIType;
class Function;
class FunctionType;
class GlobalVariable;
class LLVMContext;
class MDNode;
class StructType;
class TargetMachine;
class Type;
class Value;

// Used in std::unique_ptr
class Module;
} // namespace llvm

namespace tvm {
namespace codegen {

// CPU host code generation
class CodeGenCPU : public CodeGenLLVM {
public:
CodeGenCPU();
virtual ~CodeGenCPU();

void Init(const std::string& module_name, llvm::TargetMachine* tm, llvm::LLVMContext* ctx,
bool system_lib, bool dynamic_lookup, bool target_c_runtime) override;
void AddFunction(const PrimFunc& f) override;
Expand Down
32 changes: 26 additions & 6 deletions src/target/llvm/codegen_hexagon.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,33 @@

#if defined(TVM_LLVM_VERSION) && TVM_LLVM_VERSION >= 70

#include <llvm/ADT/ArrayRef.h>
#include <llvm/ADT/SmallString.h>
#include <llvm/ADT/StringRef.h>
#include <llvm/Bitcode/BitcodeWriter.h>
#include <llvm/IR/Constants.h>
#include <llvm/IR/DerivedTypes.h>
#include <llvm/IR/Function.h>
#include <llvm/IR/GlobalVariable.h>
#include <llvm/IR/Instructions.h>
#if TVM_LLVM_VERSION <= 90
#include <llvm/IR/Intrinsics.h>
#else
#include <llvm/IR/IntrinsicsHexagon.h>
#endif
#include <llvm/IR/LLVMContext.h>
#include <llvm/IR/LegacyPassManager.h>
#include <llvm/IR/MDBuilder.h>
#include <llvm/IR/Module.h>
#if TVM_LLVM_VERSION >= 100
#include <llvm/Support/Alignment.h>
#endif
#include <llvm/Support/CodeGen.h>
#include <llvm/Support/CommandLine.h>
#include <llvm/Support/FileSystem.h>
#include <llvm/Support/raw_ostream.h>
#include <llvm/Target/TargetMachine.h>
#include <llvm/Transforms/Utils/Cloning.h>
#include <tvm/runtime/module.h>
#include <tvm/target/codegen.h>
#include <tvm/tir/analysis.h>
Expand All @@ -42,6 +62,7 @@
#include "../../runtime/hexagon/hexagon_module.h"
#include "../build_common.h"
#include "codegen_cpu.h"
#include "llvm_common.h"

namespace tvm {
namespace codegen {
Expand Down Expand Up @@ -369,18 +390,17 @@ runtime::Module BuildHexagon(IRModule mod, Target target) {
else
llvm::WriteBitcodeToFile(m, os);
} else if (cgft == Asm || cgft == Obj) {
using namespace llvm;
#if TVM_LLVM_VERSION <= 90
auto ft = cgft == Asm ? TargetMachine::CodeGenFileType::CGFT_AssemblyFile
: TargetMachine::CodeGenFileType::CGFT_ObjectFile;
auto ft = cgft == Asm ? llvm::TargetMachine::CodeGenFileType::CGFT_AssemblyFile
: llvm::TargetMachine::CodeGenFileType::CGFT_ObjectFile;
#else
auto ft = cgft == Asm ? llvm::CGFT_AssemblyFile : llvm::CGFT_ObjectFile;
#endif

SmallString<16384> ss; // Will grow on demand.
llvm::SmallString<16384> ss; // Will grow on demand.
llvm::raw_svector_ostream os(ss);
std::unique_ptr<llvm::Module> cm = CloneModule(m);
legacy::PassManager pass;
std::unique_ptr<llvm::Module> cm = llvm::CloneModule(m);
llvm::legacy::PassManager pass;
ICHECK(tm->addPassesToEmitFile(pass, os, nullptr, ft) == 0) << "Cannot emit target code";
pass.run(*cm.get());
out.assign(ss.c_str(), ss.size());
Expand Down
Loading

0 comments on commit 59fb421

Please sign in to comment.