codegen llvm: move nvptx-specific intrinsic handling into codegen_nvptx
See discussion in apache#5600.

I'm also throwing in a pointer lifetime fix for the context held by
NVPTX because otherwise topi/tests/python/test_topi_softmax.py
would segfault for me. With this fix, I can also run resnet-18 on
the nvptx target in gpu_imagenet_bench.py.
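
For context: C++ destroys local variables in reverse order of declaration, so declaring the context's owning pointer before the codegen's guarantees the LLVMContext outlives the object holding a raw pointer to it. A minimal standalone sketch of the pattern (Context and CodeGen below are hypothetical stand-ins, not the TVM classes):

#include <iostream>
#include <memory>

struct Context {
  ~Context() { std::cout << "Context destroyed\n"; }
};

struct CodeGen {
  Context* ctx = nullptr;  // raw, non-owning pointer
  ~CodeGen() { std::cout << "CodeGen destroyed\n"; }  // might still touch *ctx
};

int main() {
  // Declared first => destroyed last: ctx outlives cg, so the raw
  // pointer inside CodeGen never dangles during teardown.
  std::unique_ptr<Context> ctx(new Context());
  std::unique_ptr<CodeGen> cg(new CodeGen());
  cg->ctx = ctx.get();
  return 0;  // prints "CodeGen destroyed", then "Context destroyed"
}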
t-vi committed Jun 4, 2020
1 parent 4347b41 commit a77ee86
Showing 3 changed files with 64 additions and 53 deletions.
52 changes: 0 additions & 52 deletions src/target/llvm/codegen_llvm.cc
@@ -736,40 +736,7 @@ llvm::Function* CodeGenLLVM::GetIntrinsicDecl(llvm::Intrinsic::ID id, llvm::Type
#endif // TVM_LLVM_VERSION
}

// Check if this is a warp shuffle intrinsic call and match its
// corresponding nvvm intrinsic. Return true if the match is successful.
static bool GetWarpShuffleIntrinsic(const CallNode* op, llvm::Intrinsic::ID* id) {
// Only 32-bit data types are supported.
if (op->dtype.is_vector() || op->dtype.bits() != 32) {
return false;
}

// Intrinsic lookup table.
// It is difficult to emit a _sync version that works on Pascal.
// We ignore the mask and only emit the non-sync version for nvptx.
llvm::Intrinsic::ID ids[] = {
llvm::Intrinsic::nvvm_shfl_idx_i32, llvm::Intrinsic::nvvm_shfl_idx_f32,
llvm::Intrinsic::nvvm_shfl_up_i32, llvm::Intrinsic::nvvm_shfl_up_f32,
llvm::Intrinsic::nvvm_shfl_down_i32, llvm::Intrinsic::nvvm_shfl_down_f32};

int offset = 0;
if (op->is_intrinsic(intrinsic::tvm_warp_shuffle)) {
offset = 0;
} else if (op->is_intrinsic(intrinsic::tvm_warp_shuffle_up)) {
offset = 2;
} else if (op->is_intrinsic(intrinsic::tvm_warp_shuffle_down)) {
offset = 4;
} else {
return false;
}

*id = ids[offset + op->dtype.is_float()];
return true;
}

llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) {
llvm::Intrinsic::ID id = llvm::Intrinsic::not_intrinsic;

if (op->is_intrinsic("llvm_intrin")) {
CHECK_GE(op->args.size(), 2U);
llvm::Intrinsic::ID id = static_cast<llvm::Intrinsic::ID>(Downcast<IntImm>(op->args[0])->value);
@@ -814,25 +781,6 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) {
}
} else if (op->is_intrinsic(intrinsic::tvm_storage_sync)) {
return CreateStorageSync(op);
} else if (GetWarpShuffleIntrinsic(op, &id)) {
std::vector<llvm::Value*> arg_value;
std::vector<llvm::Type*> arg_type;
// Ignore the first mask operand and drop the trailing
// redundant warp_size argument.
size_t n_args = op->args.size() - 1;
for (size_t i = 1; i < n_args; ++i) {
arg_value.push_back(MakeValue(op->args[i]));
arg_type.push_back(arg_value.back()->getType());
}
llvm::Type* return_type = arg_type[0];
llvm::Function* func = GetIntrinsicDecl(id, return_type, arg_type);
return builder_->CreateCall(func, arg_value);
} else if (op->is_intrinsic(intrinsic::tvm_warp_activemask)) {
// Only the nvptx target may keep this intrinsic at this point.
// PTX assembly: asm "activemask.b32 r1;"
auto fty = llvm::FunctionType::get(t_int32_, false);
auto val = llvm::InlineAsm::get(fty, "activemask.b32 %0", "=r", true);
return builder_->CreateCall(val);
} else if (op->is_intrinsic(intrinsic::tvm_address_of)) {
const LoadNode* l = op->args[0].as<LoadNode>();
CHECK(op->args.size() == 1 && l);
62 changes: 61 additions & 1 deletion src/target/llvm/codegen_nvptx.cc
@@ -170,6 +170,8 @@ class CodeGenNVPTX : public CodeGenLLVM {
CodeGenLLVM::Optimize();
}

llvm::Value* CreateIntrinsic(const CallNode* op) override;

protected:
void InitTarget(llvm::TargetMachine* tm) final {
// Maximum vector lane = float4
@@ -178,6 +180,62 @@ class CodeGenNVPTX : public CodeGenLLVM {
}
};

// Check if this is a warp shuffle intrinsic call and match its
// corresponding nvvm intrinsic. Return true if the match is successful.
static bool GetWarpShuffleIntrinsic(const CallNode* op, llvm::Intrinsic::ID* id) {
// Only 32-bit data types are supported.
if (op->dtype.is_vector() || op->dtype.bits() != 32) {
return false;
}

// Intrinsic lookup table.
// It is difficult to emit a _sync version that works on Pascal.
// We ignore the mask and only emit the non-sync version for nvptx.
llvm::Intrinsic::ID ids[] = {
llvm::Intrinsic::nvvm_shfl_idx_i32, llvm::Intrinsic::nvvm_shfl_idx_f32,
llvm::Intrinsic::nvvm_shfl_up_i32, llvm::Intrinsic::nvvm_shfl_up_f32,
llvm::Intrinsic::nvvm_shfl_down_i32, llvm::Intrinsic::nvvm_shfl_down_f32};

int offset = 0;
if (op->is_intrinsic(intrinsic::tvm_warp_shuffle)) {
offset = 0;
} else if (op->is_intrinsic(intrinsic::tvm_warp_shuffle_up)) {
offset = 2;
} else if (op->is_intrinsic(intrinsic::tvm_warp_shuffle_down)) {
offset = 4;
} else {
return false;
}

*id = ids[offset + op->dtype.is_float()];
return true;
}

llvm::Value* CodeGenNVPTX::CreateIntrinsic(const CallNode* op) {
llvm::Intrinsic::ID id = llvm::Intrinsic::not_intrinsic;
if (GetWarpShuffleIntrinsic(op, &id)) {
std::vector<llvm::Value*> arg_value;
std::vector<llvm::Type*> arg_type;
// Ignore the first mask operand and drop the trailing
// redundant warp_size argument.
size_t n_args = op->args.size() - 1;
for (size_t i = 1; i < n_args; ++i) {
arg_value.push_back(MakeValue(op->args[i]));
arg_type.push_back(arg_value.back()->getType());
}
llvm::Type* return_type = arg_type[0];
llvm::Function* func = GetIntrinsicDecl(id, return_type, arg_type);
return builder_->CreateCall(func, arg_value);
} else if (op->is_intrinsic(intrinsic::tvm_warp_activemask)) {
// Only the nvptx target may keep this intrinsic at this point.
// PTX assembly: asm "activemask.b32 r1;"
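// Note: "=r" constrains the result to a 32-bit register, and the trailing
// 'true' marks the asm as having side effects so LLVM does not drop the call.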
auto fty = llvm::FunctionType::get(t_int32_, false);
auto val = llvm::InlineAsm::get(fty, "activemask.b32 %0", "=r", true);
return builder_->CreateCall(val);
}
return CodeGenLLVM::CreateIntrinsic(op);
}

inline int DetectCUDAComputeVersion() {
TVMContext tvm_ctx;
tvm_ctx.device_type = kDLGPU;
@@ -204,8 +262,10 @@ runtime::Module BuildNVPTX(IRModule mod, std::string target) {
config << "-mtriple=nvptx64-nvidia-cuda -mcpu=sm_" << compute_ver
<< target.substr(5, target.length() - 5);
std::unique_ptr<llvm::TargetMachine> tm = GetLLVMTargetMachine(config.str());
std::unique_ptr<CodeGenNVPTX> cg(new CodeGenNVPTX());
std::unique_ptr<llvm::LLVMContext> ctx(new llvm::LLVMContext());
// careful: cg holds a raw (non-owning) pointer to ctx, so it must have
// a shorter lifetime than ctx.
std::unique_ptr<CodeGenNVPTX> cg(new CodeGenNVPTX());

cg->Init("TVMPTXModule", tm.get(), ctx.get(), false, false);

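The table lookup in GetWarpShuffleIntrinsic above lays the intrinsics out as {i32, f32} pairs, one pair per shuffle flavor: offset picks the flavor (idx/up/down) and dtype.is_float() picks the element type within the pair. A minimal standalone sketch of the same indexing scheme (the names are illustrative strings, not real LLVM intrinsic IDs):

#include <cassert>
#include <string>

enum class Shuffle { kIdx = 0, kUp = 1, kDown = 2 };

// Flat table of {i32, f32} pairs, one pair per shuffle flavor.
static const std::string kNames[] = {
    "shfl.idx.i32",  "shfl.idx.f32",
    "shfl.up.i32",   "shfl.up.f32",
    "shfl.down.i32", "shfl.down.f32",
};

std::string LookupShuffle(Shuffle kind, bool is_float) {
  int offset = static_cast<int>(kind) * 2;  // two table entries per flavor
  return kNames[offset + (is_float ? 1 : 0)];
}

int main() {
  assert(LookupShuffle(Shuffle::kDown, /*is_float=*/true) == "shfl.down.f32");
  assert(LookupShuffle(Shuffle::kIdx, /*is_float=*/false) == "shfl.idx.i32");
  return 0;
}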
3 changes: 3 additions & 0 deletions topi/python/topi/cuda/softmax.py
@@ -60,6 +60,9 @@ def schedule_softmax(outs):
def sched_warp_softmax():
if tgt.target_name == "nvptx":
return softmax.dtype == "float32" or softmax.dtype == "int32"
elif tgt.target_name != "cuda":
# this is used as the GPU schedule for other architectures, which may not support warp reductions
return False
return True

if len(softmax.shape) > 2:
