TensorCore Support using Intrinsic #4136

Merged · 12 commits · Oct 24, 2019
58 changes: 58 additions & 0 deletions include/tvm/ir.h
@@ -1310,6 +1310,16 @@ constexpr const char* opengl_stage_scope = "opengl_stage_scope";
*/
constexpr const char* device_scope = "device_scope";

/*!
 * \brief Mark the shape of a TensorCore fragment
*/
constexpr const char* fragment_shape = "fragment_shape";

/*!
 * \brief Mark the layout of a TensorCore fragment
*/
constexpr const char* fragment_layout = "fragment_layout";

/*!
* \brief Check if attr_key is a pragma key extension
* \param attr_key The attr key to be compared
@@ -1552,6 +1562,54 @@ constexpr const char* tvm_global_barrier_kinit = "tvm_global_barrier_kinit";
* }
*/
constexpr const char* tvm_thread_allreduce = "tvm_thread_allreduce";
/*!
* \brief tvm intrinsic for tensor core load operators.
*
 * void tvm_load_matrix_sync(Var fragment, UIntImm m, UIntImm n, UIntImm k,
* Expr index, Expr buffer_ptr, Expr stride,
* StringImm layout) {
* // m, n, k are the shape of wmma fragment.
 * // The fragment layout (column-major or row-major) is determined by layout.
 * // The fragment must be in the 'wmma.matrix_a' or 'wmma.matrix_b' scope.
* nvcuda::wmma::load_matrix_sync(fragment[index], buffer_ptr, stride);
* }
*/
constexpr const char* tvm_load_matrix_sync = "tvm_load_matrix_sync";
/*!
* \brief tvm intrinsic for tensor core mma_sync operators.
*
* void tvm_mma_sync(Var fragment_d, Expr index_d,
* Var fragment_a, Expr index_a,
* Var fragment_b, Expr index_b,
* Var fragment_c, Expr index_c) {
* nvcuda::wmma::mma_sync(fragment_d[index_d], fragment_a[index_a],
* fragment_b[index_b], fragment_c[index_c]);
* }
*/
constexpr const char* tvm_mma_sync = "tvm_mma_sync";
/*!
* \brief tvm intrinsic for tensor core fill_fragment operators.
*
 * void tvm_fill_fragment(Var fragment, UIntImm m, UIntImm n, UIntImm k,
* Expr index, Expr value) {
* // m, n, k are the shape of wmma fragment
* // fragments must be in 'wmma.accumulator' scope.
* nvcuda::wmma::fill_fragment(fragment[index], value);
* }
*/
constexpr const char* tvm_fill_fragment = "tvm_fill_fragment";
/*!
* \brief tvm intrinsic for tensor core store operators.
*
 * void tvm_store_matrix_sync(Var fragment, UIntImm m, UIntImm n, UIntImm k,
* Expr index, Expr buffer_ptr, Expr stride,
* StringImm layout) {
* // m, n, k are the shape of wmma fragment
* // fragments must be in 'wmma.accumulator' scope.
* nvcuda::wmma::store_matrix_sync(fragment[index], buffer_ptr, stride, layout);
* }
*/
constexpr const char* tvm_store_matrix_sync = "tvm_store_matrix_sync";

} // namespace intrinsic

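For reference, here is a minimal CUDA sketch of what a lowered sequence of these four intrinsics corresponds to for a single 16x16x16 warp tile, following the nvcuda::wmma calls documented in the comments above. The kernel name, parameter names, and the row-major/column-major choices are illustrative only and are not part of this diff.

#include <cuda_fp16.h>
#include <mma.h>

// One warp computes a single 16x16 output tile: C = A * B, with A and B in
// half precision and the accumulator in float (requires sm_70 or later).
__global__ void wmma_tile_kernel(const half* a, const half* b, float* c,
                                 int lda, int ldb, int ldc) {
  nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, 16, 16, 16, half,
                         nvcuda::wmma::row_major> a_frag;
  nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half,
                         nvcuda::wmma::col_major> b_frag;
  nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 16, 16, 16, float> c_frag;
  nvcuda::wmma::fill_fragment(c_frag, 0.0f);               // tvm_fill_fragment
  nvcuda::wmma::load_matrix_sync(a_frag, a, lda);          // tvm_load_matrix_sync
  nvcuda::wmma::load_matrix_sync(b_frag, b, ldb);          // tvm_load_matrix_sync
  nvcuda::wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);  // tvm_mma_sync
  nvcuda::wmma::store_matrix_sync(c, c_frag, ldc,
                                  nvcuda::wmma::mem_row_major);  // tvm_store_matrix_sync
}
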
17 changes: 17 additions & 0 deletions include/tvm/ir_pass.h
@@ -513,6 +513,15 @@ LoweredFunc CombineContextCall(LoweredFunc f);
*/
LoweredFunc PointerValueTypeRewrite(LoweredFunc f);

/*!
 * \brief Lower attached storage access information on the device.
 * Run this pass after all storage access analysis has finished.
*
* \param func The device function to be lowered.
* \return Transformed function.
*/
LoweredFunc LowerDeviceStorageAccessInfo(LoweredFunc func);

/*!
* \brief Lower intrinsic function calls.
* \param f The device function to be lowered.
@@ -532,6 +541,14 @@ LoweredFunc LowerIntrin(LoweredFunc f, const std::string& target);
*/
LoweredFunc LowerCustomDatatypes(LoweredFunc f, const std::string& target);

/*!
 * \brief Infer the TensorCore fragment information using tensor intrinsics.
*
* \param f The device function to be lowered.
* \return Transformed function.
*/
LoweredFunc InferFragment(LoweredFunc f);

/*!
* \brief Verify if memory accesses are legal for a specific target device type.
*
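A minimal C++ sketch of how the two new passes are meant to be ordered on a device function: fragment inference runs after thread synchronization has been inserted, and the attached storage access information is lowered after the access analysis has finished, mirroring the pipeline changes in build_module below. The wrapper function ApplyTensorCorePasses is hypothetical and the host/device split is omitted.

#include <tvm/ir_pass.h>
#include <tvm/lowered_func.h>

// Hypothetical helper that applies the new passes to one device function in
// the order used by this PR.
tvm::LoweredFunc ApplyTensorCorePasses(tvm::LoweredFunc func) {
  using namespace tvm::ir;
  func = ThreadSync(func, "shared");          // existing pass, shown only for ordering
  func = InferFragment(func);                 // infer fragment shape/layout consumed by the CUDA codegen
  func = LowerDeviceStorageAccessInfo(func);  // lower the attached storage access info afterwards
  return func;
}
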
4 changes: 3 additions & 1 deletion python/tvm/build_module.py
@@ -413,7 +413,6 @@ def lower(sch,

# Phase 3
stmt = ir_pass.Simplify(stmt)
stmt = ir_pass.LowerStorageAccessInfo(stmt)
stmt = ir_pass.RemoveNoOp(stmt)
if not cfg.disable_select_rewriting:
stmt = ir_pass.RewriteUnsafeSelect(stmt)
@@ -465,6 +464,7 @@ def _build_for_device(flist, target, target_host):
func = ir_pass.ThreadSync(func, "global")
func = ir_pass.ThreadSync(func, "shared")
func = ir_pass.ThreadSync(func, "warp")
func = ir_pass.InferFragment(func)
warp_size = target.thread_warp_size
func = ir_pass.LowerThreadAllreduce(func, warp_size)
fsplits = [s for s in ir_pass.SplitHostDevice(func)]
@@ -494,6 +494,8 @@ def _build_for_device(flist, target, target_host):
assert not fdevice

target_host = _target.create(target_host)
fdevice = [ir_pass.LowerDeviceStorageAccessInfo(x) for x in fdevice]
fhost = [ir_pass.LowerDeviceStorageAccessInfo(x) for x in fhost]
fdevice = [ir_pass.LowerIntrin(x, target.target_name) for x in fdevice]
fhost = [ir_pass.LowerIntrin(x, target_host.target_name) for x in fhost]
fhost = [ir_pass.CombineContextCall(x) for x in fhost]
10 changes: 10 additions & 0 deletions src/api/api_pass.cc
@@ -118,6 +118,14 @@ TVM_REGISTER_API("ir_pass.PostOrderVisit")
});
});

TVM_REGISTER_API("ir_pass.LowerStorageAccess")
.set_body([](TVMArgs args, TVMRetValue *ret) {
LoweredFunc f = args[0];
auto n = make_node<LoweredFuncNode>(*f.operator->());
n->body = LowerStorageAccessInfo(f->body);
*ret = LoweredFunc(n);
});

// make from two arguments
#define REGISTER_PASS(PassName) \
TVM_REGISTER_API("ir_pass."#PassName) \
@@ -140,6 +148,7 @@ REGISTER_PASS(SplitHostDevice);
REGISTER_PASS(StorageRewrite);
REGISTER_PASS(CoProcSync);
REGISTER_PASS(LowerStorageAccessInfo);
REGISTER_PASS(LowerDeviceStorageAccessInfo);
REGISTER_PASS(InjectVirtualThread);
REGISTER_PASS(InjectPrefetch);
REGISTER_PASS(InjectDoubleBuffer);
@@ -161,5 +170,6 @@ REGISTER_PASS(DecorateDeviceScope);
REGISTER_PASS(InstrumentBoundCheckers);
REGISTER_PASS(VerifyCompactBuffer);
REGISTER_PASS(HoistIfThenElse);
REGISTER_PASS(InferFragment);
} // namespace ir
} // namespace tvm
3 changes: 2 additions & 1 deletion src/codegen/build_module.cc
@@ -422,7 +422,6 @@ Stmt BuildStmt(Schedule sch,

// Phase 2
stmt = ir::Simplify(stmt);
stmt = ir::LowerStorageAccessInfo(stmt);
stmt = ir::RemoveNoOp(stmt);

if (!(config->disable_select_rewriting))
@@ -517,13 +516,15 @@ Array<Array<LoweredFunc> > split_dev_host_funcs(const Array<LoweredFunc>& funcs,
for (size_t i = 0; i < fhost.size(); ++i) {
auto func = fhost[i];
func = ir::BindDeviceType(func, target->device_type);
func = ir::LowerDeviceStorageAccessInfo(func);
func = ir::LowerTVMBuiltin(func);
fhost.Set(i, func);
}

for (size_t i = 0; i < fhost.size(); ++i) {
auto func = fhost[i];
func = ir::LowerIntrin(func, target_host->target_name);
func = ir::LowerDeviceStorageAccessInfo(func);
func = ir::CombineContextCall(func);
fhost.Set(i, func);
}
167 changes: 165 additions & 2 deletions src/codegen/codegen_cuda.cc
@@ -74,6 +74,10 @@ std::string CodeGenCUDA::Finish() {
decl_stream << "#include <math_constants.h>\n";
}

if (need_mma_h_) {
decl_stream << "#include <mma.h>\n";
}

return CodeGenC::Finish();
}

@@ -102,14 +106,22 @@ void CodeGenCUDA::PrintType(Type t, std::ostream& os) { // NOLINT(*)
bool fail = false;
if (t.is_float()) {
switch (t.bits()) {
case 16: os << "half";
case 16:
enable_fp16_ = true;
if (lanes == 1) {
os << "half";
} else if (lanes <= 8) {
CHECK_EQ(lanes % 2, 0) << "only support even lane for half type";
os << "float" << lanes / 2;
} else {
fail = true;
}
break;
case 32: os << "float"; break;
case 64: os << "double"; break;
default: fail = true; break;
}
if (!fail && lanes == 1) return;
if (!fail && (lanes == 1 || t.bits() == 16)) return;
if (!fail && (lanes >= 2 && lanes <= 4)) {
os << lanes; return;
}
@@ -290,6 +302,113 @@ void CodeGenCUDA::PrintStorageScope(
}
}

void CodeGenCUDA::VisitExpr_(const Call *op, std::ostream& os) {
if (op->is_intrinsic(intrinsic::tvm_fill_fragment)) {
need_mma_h_ = true;
CHECK_EQ(op->args.size(), 6U);
os << "nvcuda::wmma::fill_fragment(";
this->PrintExpr(op->args[0], os);
os << "[";
this->PrintExpr(op->args[4], os);
os << "], ";
this->PrintExpr(op->args[5], os);
os << ")";
} else if (op->is_intrinsic(intrinsic::tvm_load_matrix_sync)) {
need_mma_h_ = true;
CHECK_EQ(op->args.size(), 8U);
os << "nvcuda::wmma::load_matrix_sync(";
this->PrintExpr(op->args[0], os);
os << "[";
this->PrintExpr(op->args[4], os);
os << "], ";
this->PrintExpr(op->args[5], os);
os << ", ";
this->PrintExpr(op->args[6], os);
os << ")";
} else if (op->is_intrinsic(intrinsic::tvm_store_matrix_sync)) {
need_mma_h_ = true;
CHECK_EQ(op->args.size(), 8U);
os << "nvcuda::wmma::store_matrix_sync(";
this->PrintExpr(op->args[5], os);
os << ", ";
this->PrintExpr(op->args[0], os);
os << "[";
this->PrintExpr(op->args[4], os);
os << "], ";
this->PrintExpr(op->args[6], os);
if (const StringImm *str = op->args[7].as<StringImm>()) {
os << ", nvcuda::wmma::mem_" << str->value;
} else {
LOG(FATAL) << "Invalid parameters";
}
os << ")";
} else if (op->is_intrinsic(intrinsic::tvm_mma_sync)) {
need_mma_h_ = true;
CHECK_EQ(op->args.size(), 8U);
os << "nvcuda::wmma::mma_sync(";
for (int i = 0; i < 4; ++i) {
this->PrintExpr(op->args[i * 2], os);
os << "[";
this->PrintExpr(op->args[i * 2 + 1], os);
os << "]" << ((i < 3) ? ", ": ")");
}
} else {
CodeGenC::VisitExpr_(op, os);
}
}

void CodeGenCUDA::VisitStmt_(const AttrStmt* op) {
if (op->attr_key == attr::fragment_shape) {
const Variable* buffer = op->node.as<Variable>();
const StringImm* shape_str = op->value.as<StringImm>();
fragment_shapes[buffer] = shape_str->value;
} else if (op->attr_key == attr::fragment_layout) {
const Variable* buffer = op->node.as<Variable>();
const StringImm* layout_str = op->value.as<StringImm>();
fragment_layouts[buffer] = layout_str->value;
}
CodeGenC::VisitStmt_(op);
}

void CodeGenCUDA::VisitStmt_(const Allocate* op) {
CHECK(!is_zero(op->condition));
std::string vid = AllocVarID(op->buffer_var.get());
if (op->new_expr.defined()) {
// Prefer global static allocation for the program
CHECK_EQ(op->free_function, "nop");
std::string new_data = PrintExpr(op->new_expr);
this->PrintIndent();
PrintType(op->type, stream);
stream << "* "<< vid << '=' << new_data << ";\n";
} else {
this->PrintIndent();
int32_t constant_size = op->constant_allocation_size();
CHECK_GT(constant_size, 0)
<< "Can only handle constant size stack allocation for now";
const Variable* buffer = op->buffer_var.as<Variable>();
std::string scope = alloc_storage_scope_.at(buffer);
if (scope.find("wmma.") == 0) {
if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") {
CHECK(op->type == Float(16) || op->type == Int(8) || op->type == UInt(8))
<< "Matrix_a and matrix_b only support half or char or unsigned char type for now";
} else {
CHECK(op->type == Float(16) || op->type == Float(32) || op->type == Int(32))
<< "Accumulator only support half, float and int type for now";
}
constant_size = GetWmmaFragmentSize(scope, buffer, constant_size);
PrintWmmaScope(scope, op->type, buffer, stream);
} else {
PrintStorageScope(scope, stream);
stream << ' ';
PrintType(op->type, stream);
}
stream << ' '<< vid << '['
<< constant_size << "];\n";
}
RegisterHandleType(op->buffer_var.get(), op->type);
this->PrintStmt(op->body);
}

void CodeGenCUDA::VisitStmt_(const Evaluate *op) {
if (is_const(op->value)) return;
const Call* call = op->value.as<Call>();
@@ -392,5 +511,49 @@ void CodeGenCUDA::VisitExpr_(const FloatImm *op, std::ostream& os) { // NOLINT(*
PrintConst(op, os, this);
}

void CodeGenCUDA::PrintWmmaScope(const std::string &scope, Type t,
const Variable* variable, std::ostream &os) {
std::stringstream type;
PrintType(t, type);
std::string shape_str = fragment_shapes[variable];
if (scope == "wmma.matrix_a") {
need_mma_h_ = true;
std::string layout_str = fragment_layouts[variable];
os << "nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, "
<< shape_str << ", " << type.str() << ", nvcuda::wmma::" << layout_str <<">";
} else if (scope == "wmma.matrix_b") {
need_mma_h_ = true;
std::string layout_str = fragment_layouts[variable];
os << "nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, "
<< shape_str << ", " << type.str() << ", nvcuda::wmma::" << layout_str <<">";
} else if (scope == "wmma.accumulator") {
need_mma_h_ = true;
os << "nvcuda::wmma::fragment<nvcuda::wmma::accumulator, "
<< shape_str << ", "<< type.str() << ">";
}
}

int32_t CodeGenCUDA::GetWmmaFragmentSize(const std::string &scope,
const Variable* variable, int32_t size) {
std::string shape_str = fragment_shapes[variable];
size_t m, n, k;
size_t last_pos = 0, pos = 0;
pos = shape_str.find(", ", last_pos);
m = std::stoi(shape_str.substr(last_pos, pos - last_pos));
last_pos = pos + 2;
pos = shape_str.find(", ", last_pos);
n = std::stoi(shape_str.substr(last_pos, pos - last_pos));
last_pos = pos + 2;
k = std::stoi(shape_str.substr(last_pos, shape_str.length() - last_pos));
if (scope == "wmma.matrix_a") {
return size / m / k;
} else if (scope == "wmma.matrix_b") {
return size / n / k;
} else if (scope == "wmma.accumulator") {
return size / m / n;
}
return 0;
}

} // namespace codegen
} // namespace tvm
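
As a concrete illustration of the new Allocate handling: assuming a hypothetical buffer in 'wmma.matrix_a' scope allocated with 512 half elements, a fragment_shape attribute of "16, 16, 16", and a fragment_layout of "row_major", GetWmmaFragmentSize returns 512 / 16 / 16 = 2, and the generated CUDA (with the mma.h include added by Finish()) would contain roughly the following; the buffer name A_frag is illustrative.

#include <mma.h>

// 512 half elements / (16 x 16 elements per matrix_a fragment) = 2 fragments.
nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, 16, 16, 16, half, nvcuda::wmma::row_major> A_frag[2];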