Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

codegen llvm: move nvptx-specific intrinsic handling into codegen_nvptx #5726

Merged
merged 1 commit into from
Jun 4, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 0 additions & 52 deletions src/target/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -736,40 +736,7 @@ llvm::Function* CodeGenLLVM::GetIntrinsicDecl(llvm::Intrinsic::ID id, llvm::Type
#endif // TVM_LLVM_VERSION
}

// Check if this is a warp shuffle intrinsic call and match its
// corresponding nvvm intrinsic. Return true if the match is successful.
static bool GetWarpShuffleIntrinsic(const CallNode* op, llvm::Intrinsic::ID* id) {
// Only 32 bit data type is supported.
if (op->dtype.is_vector() || op->dtype.bits() != 32) {
return false;
}

// Intrinsic lookup table.
// It is difficult to emit _sync verion that works on Pascal.
// We ignore the mask and only emit the non-sync version for nvptx.
llvm::Intrinsic::ID ids[] = {
llvm::Intrinsic::nvvm_shfl_idx_i32, llvm::Intrinsic::nvvm_shfl_idx_f32,
llvm::Intrinsic::nvvm_shfl_up_i32, llvm::Intrinsic::nvvm_shfl_up_f32,
llvm::Intrinsic::nvvm_shfl_down_i32, llvm::Intrinsic::nvvm_shfl_down_f32};

int offset = 0;
if (op->is_intrinsic(intrinsic::tvm_warp_shuffle)) {
offset = 0;
} else if (op->is_intrinsic(intrinsic::tvm_warp_shuffle_up)) {
offset = 2;
} else if (op->is_intrinsic(intrinsic::tvm_warp_shuffle_down)) {
offset = 4;
} else {
return false;
}

*id = ids[offset + op->dtype.is_float()];
return true;
}

llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) {
llvm::Intrinsic::ID id = llvm::Intrinsic::not_intrinsic;

if (op->is_intrinsic("llvm_intrin")) {
CHECK_GE(op->args.size(), 2U);
llvm::Intrinsic::ID id = static_cast<llvm::Intrinsic::ID>(Downcast<IntImm>(op->args[0])->value);
Expand Down Expand Up @@ -814,25 +781,6 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) {
}
} else if (op->is_intrinsic(intrinsic::tvm_storage_sync)) {
return CreateStorageSync(op);
} else if (GetWarpShuffleIntrinsic(op, &id)) {
std::vector<llvm::Value*> arg_value;
std::vector<llvm::Type*> arg_type;
// Ignore the first mask operand and remove the last
// redundant warp_size.
size_t n_args = op->args.size() - 1;
for (size_t i = 1; i < n_args; ++i) {
arg_value.push_back(MakeValue(op->args[i]));
arg_type.push_back(arg_value.back()->getType());
}
llvm::Type* return_type = arg_type[0];
llvm::Function* func = GetIntrinsicDecl(id, return_type, arg_type);
return builder_->CreateCall(func, arg_value);
} else if (op->is_intrinsic(intrinsic::tvm_warp_activemask)) {
// Only nvptx target may keep this intrinsic at this point.
// PTX assembly: asm "activemask.b32 r1;"
auto fty = llvm::FunctionType::get(t_int32_, false);
auto val = llvm::InlineAsm::get(fty, "activemask.b32 %0", "=r", true);
return builder_->CreateCall(val);
} else if (op->is_intrinsic(intrinsic::tvm_address_of)) {
const LoadNode* l = op->args[0].as<LoadNode>();
CHECK(op->args.size() == 1 && l);
Expand Down
62 changes: 61 additions & 1 deletion src/target/llvm/codegen_nvptx.cc
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,8 @@ class CodeGenNVPTX : public CodeGenLLVM {
CodeGenLLVM::Optimize();
}

llvm::Value* CreateIntrinsic(const CallNode* op) override;

protected:
void InitTarget(llvm::TargetMachine* tm) final {
// Maximum vector lane = float4
Expand All @@ -178,6 +180,62 @@ class CodeGenNVPTX : public CodeGenLLVM {
}
};

// Check if this is a warp shuffle intrinsic call and match its
// corresponding nvvm intrinsic. Return true if the match is successful.
static bool GetWarpShuffleIntrinsic(const CallNode* op, llvm::Intrinsic::ID* id) {
// Only 32 bit data type is supported.
if (op->dtype.is_vector() || op->dtype.bits() != 32) {
return false;
}

// Intrinsic lookup table.
// It is difficult to emit _sync verion that works on Pascal.
// We ignore the mask and only emit the non-sync version for nvptx.
llvm::Intrinsic::ID ids[] = {
llvm::Intrinsic::nvvm_shfl_idx_i32, llvm::Intrinsic::nvvm_shfl_idx_f32,
llvm::Intrinsic::nvvm_shfl_up_i32, llvm::Intrinsic::nvvm_shfl_up_f32,
llvm::Intrinsic::nvvm_shfl_down_i32, llvm::Intrinsic::nvvm_shfl_down_f32};

int offset = 0;
if (op->is_intrinsic(intrinsic::tvm_warp_shuffle)) {
offset = 0;
} else if (op->is_intrinsic(intrinsic::tvm_warp_shuffle_up)) {
offset = 2;
} else if (op->is_intrinsic(intrinsic::tvm_warp_shuffle_down)) {
offset = 4;
} else {
return false;
}

*id = ids[offset + op->dtype.is_float()];
return true;
}

// Lower a TVM intrinsic call for the nvptx target.
//
// Warp shuffle calls that GetWarpShuffleIntrinsic recognises are emitted as a
// call to the matched nvvm shfl intrinsic, and tvm_warp_activemask is emitted
// as inline PTX assembly. Every other call is forwarded unchanged to
// CodeGenLLVM::CreateIntrinsic.
llvm::Value* CodeGenNVPTX::CreateIntrinsic(const CallNode* op) {
  llvm::Intrinsic::ID id = llvm::Intrinsic::not_intrinsic;
  if (GetWarpShuffleIntrinsic(op, &id)) {
    // The leading mask operand and the trailing warp_size operand are both
    // dropped, so at least one operand has to sit between them. Without this
    // check `size() - 1` could wrap around on an empty argument list and
    // `arg_type[0]` below could read from an empty vector.
    CHECK_GE(op->args.size(), 3U)
        << "warp shuffle intrinsic expects a mask, at least one operand and warp_size";
    // Ignore the first mask operand and remove the last
    // redundant warp_size.
    const size_t n_args = op->args.size() - 1;
    std::vector<llvm::Value*> arg_value;
    std::vector<llvm::Type*> arg_type;
    arg_value.reserve(n_args - 1);
    arg_type.reserve(n_args - 1);
    for (size_t i = 1; i < n_args; ++i) {
      arg_value.push_back(MakeValue(op->args[i]));
      arg_type.push_back(arg_value.back()->getType());
    }
    // The intrinsic is declared with the type of the first forwarded operand
    // as its return type.
    llvm::Type* return_type = arg_type[0];
    llvm::Function* func = GetIntrinsicDecl(id, return_type, arg_type);
    return builder_->CreateCall(func, arg_value);
  } else if (op->is_intrinsic(intrinsic::tvm_warp_activemask)) {
    // Only nvptx target may keep this intrinsic at this point.
    // PTX assembly: asm "activemask.b32 r1;"
    auto fty = llvm::FunctionType::get(t_int32_, false);
    auto val = llvm::InlineAsm::get(fty, "activemask.b32 %0", "=r", true);
    return builder_->CreateCall(val);
  }
  // Anything that is not nvptx-specific is handled by the generic codegen.
  return CodeGenLLVM::CreateIntrinsic(op);
}

inline int DetectCUDAComputeVersion() {
TVMContext tvm_ctx;
tvm_ctx.device_type = kDLGPU;
Expand All @@ -204,8 +262,10 @@ runtime::Module BuildNVPTX(IRModule mod, std::string target) {
config << "-mtriple=nvptx64-nvidia-cuda -mcpu=sm_" << compute_ver
<< target.substr(5, target.length() - 5);
std::unique_ptr<llvm::TargetMachine> tm = GetLLVMTargetMachine(config.str());
std::unique_ptr<CodeGenNVPTX> cg(new CodeGenNVPTX());
std::unique_ptr<llvm::LLVMContext> ctx(new llvm::LLVMContext());
// careful: cg will hold a naked pointer reference to ctx, so it should
// have a shorter lifetime than the ctx.
std::unique_ptr<CodeGenNVPTX> cg(new CodeGenNVPTX());

cg->Init("TVMPTXModule", tm.get(), ctx.get(), false, false);

Expand Down
3 changes: 3 additions & 0 deletions topi/python/topi/cuda/softmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ def schedule_softmax(outs):
def sched_warp_softmax():
    """Decide whether the warp-level softmax schedule should be used.

    Reads ``tgt`` and ``softmax`` from the enclosing scope.
    """
    target_name = tgt.target_name
    if target_name == "nvptx":
        # On nvptx only float32 and int32 outputs take the warp schedule.
        return softmax.dtype in ("float32", "int32")
    # this is used as the gpu schedule for other arches which may not have
    # warp reductions, so only cuda opts in here.
    return target_name == "cuda"

if len(softmax.shape) > 2:
Expand Down