Skip to content

Commit d2f68b9

Browse files
committed
Add fast math operations for CUDA: exp, exp10, log, log2, log10, tan, cos, and sin (tile-ai#865)
1 parent fa4fd0b commit d2f68b9

File tree

7 files changed

+634
-1
lines changed

7 files changed

+634
-1
lines changed

3rdparty/tvm

Submodule tvm updated from 0524f76 to 883e96b

src/op/builtin.cc

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,50 @@ DataType cuTensorMapType() { return DataType::UInt(8, 128); }
4040
TVM_REGISTER_OP("tl." #OpName) \
4141
.set_attr<TScriptPrinterName>("TScriptPrinterName", #OpName)
4242

43+
// fast math related op
44+
TIR_DEFINE_TL_BUILTIN(__exp)
45+
.set_num_inputs(1)
46+
.set_attr<TCallEffectKind>("TCallEffectKind",
47+
Integer(CallEffectKind::kOpaque));
48+
49+
50+
TIR_DEFINE_TL_BUILTIN(__exp10)
51+
.set_num_inputs(1)
52+
.set_attr<TCallEffectKind>("TCallEffectKind",
53+
Integer(CallEffectKind::kOpaque));
54+
55+
56+
TIR_DEFINE_TL_BUILTIN(__log)
57+
.set_num_inputs(1)
58+
.set_attr<TCallEffectKind>("TCallEffectKind",
59+
Integer(CallEffectKind::kOpaque));
60+
61+
TIR_DEFINE_TL_BUILTIN(__log2)
62+
.set_num_inputs(1)
63+
.set_attr<TCallEffectKind>("TCallEffectKind",
64+
Integer(CallEffectKind::kOpaque));
65+
66+
TIR_DEFINE_TL_BUILTIN(__log10)
67+
.set_num_inputs(1)
68+
.set_attr<TCallEffectKind>("TCallEffectKind",
69+
Integer(CallEffectKind::kOpaque));
70+
71+
TIR_DEFINE_TL_BUILTIN(__tan)
72+
.set_num_inputs(1)
73+
.set_attr<TCallEffectKind>("TCallEffectKind",
74+
Integer(CallEffectKind::kOpaque));
75+
76+
TIR_DEFINE_TL_BUILTIN(__cos)
77+
.set_num_inputs(1)
78+
.set_attr<TCallEffectKind>("TCallEffectKind",
79+
Integer(CallEffectKind::kOpaque));
80+
81+
82+
TIR_DEFINE_TL_BUILTIN(__sin)
83+
.set_num_inputs(1)
84+
.set_attr<TCallEffectKind>("TCallEffectKind",
85+
Integer(CallEffectKind::kOpaque));
86+
4387
TIR_DEFINE_TL_BUILTIN(create_list_of_mbarrier)
4488
.set_num_inputs(-1)
4589
.set_attr<TCallEffectKind>("TCallEffectKind",

src/op/builtin.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,17 @@ static constexpr const char *kDynamicAlignment = "tl.dynamic_alignment";
7575
*/
7676
DataType cuTensorMapType();
7777

78+
79+
// fast math related op
80+
TVM_DLL const Op &__exp();
81+
TVM_DLL const Op &__exp10();
82+
TVM_DLL const Op &__log();
83+
TVM_DLL const Op &__log2();
84+
TVM_DLL const Op &__log10();
85+
TVM_DLL const Op &__tan();
86+
TVM_DLL const Op &__cos();
87+
TVM_DLL const Op &__sin();
88+
7889
/*!
7990
* \brief tvm intrinsics for TMADescriptor creation for tiled load
8091
*

src/target/codegen_cuda.cc

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,80 @@ namespace tvm {
2121
namespace codegen {
2222
using namespace tvm::tl::codegen;
2323

24+
struct CUDAMath {
25+
std::string operator()(DataType t, std::string name) const {
26+
if (t.is_float()) {
27+
switch (t.bits()) {
28+
case 64:
29+
return name;
30+
case 32:
31+
return name + 'f';
32+
case 16: {
33+
if (name == "fabs") {
34+
return "__habs";
35+
} else if (name == "round") {
36+
return "hrint";
37+
} else {
38+
return "h" + name;
39+
}
40+
}
41+
default:
42+
return "";
43+
}
44+
} else if (t.is_bfloat16()) {
45+
if (name == "fabs") {
46+
return "__habs";
47+
} else if (name == "round") {
48+
return "hrint";
49+
} else {
50+
return "h" + name;
51+
}
52+
} else if (t.is_int() || t.is_uint()) {
53+
switch (t.bits()) {
54+
case 32:
55+
return "__" + name;
56+
case 64:
57+
return "__" + name + "ll";
58+
default:
59+
return "";
60+
}
61+
}
62+
return "";
63+
}
64+
};
65+
66+
67+
struct CUDAFastMath : public CUDAMath {
68+
std::string operator()(DataType t, std::string name) const {
69+
if (t.is_float() && t.bits() == 32) {
70+
return "__" + name + 'f';
71+
} else {
72+
return CUDAMath::operator()(t, name);
73+
}
74+
return "";
75+
}
76+
};
77+
78+
struct CUDAFastMathTan : public CUDAMath {
79+
std::string operator()(DataType t, std::string name) const {
80+
if (t.is_float()) {
81+
switch (t.bits()) {
82+
case 64:
83+
return name;
84+
// `__tanf` seems to produce some values too deviant from numpy tan
85+
// version. So, let's use just `tanf` instead.
86+
case 32:
87+
return name + 'f';
88+
case 16:
89+
return 'h' + name;
90+
default:
91+
return "";
92+
}
93+
}
94+
return "";
95+
}
96+
};
97+
2498
static std::string GetFP8Type(DataType type) {
2599
std::stringstream stream;
26100
int32_t lanes = type.lanes();
@@ -1628,6 +1702,38 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
16281702
op->args, true, os);
16291703
} else if (op->op.same_as(tl::tl_shuffle_elect())) {
16301704
os << "tl::tl_shuffle_elect<" << PrintExpr(op->args[0]) << ">()";
1705+
} else if (op->op.same_as(tl::__exp())) {
1706+
CUDAFastMath math_func;
1707+
std::string func_name = math_func(op->dtype, "exp");
1708+
os << func_name << "(" << PrintExpr(op->args[0]) << ")";
1709+
} else if (op->op.same_as(tl::__exp10())) {
1710+
CUDAFastMath math_func;
1711+
std::string func_name = math_func(op->dtype, "exp10");
1712+
os << func_name << "(" << PrintExpr(op->args[0]) << ")";
1713+
} else if (op->op.same_as(tl::__log())) {
1714+
CUDAFastMath math_func;
1715+
std::string func_name = math_func(op->dtype, "log");
1716+
os << func_name << "(" << PrintExpr(op->args[0]) << ")";
1717+
} else if (op->op.same_as(tl::__log2())) {
1718+
CUDAFastMath math_func;
1719+
std::string func_name = math_func(op->dtype, "log2");
1720+
os << func_name << "(" << PrintExpr(op->args[0]) << ")";
1721+
} else if (op->op.same_as(tl::__log10())) {
1722+
CUDAFastMath math_func;
1723+
std::string func_name = math_func(op->dtype, "log10");
1724+
os << func_name << "(" << PrintExpr(op->args[0]) << ")";
1725+
} else if (op->op.same_as(tl::__tan())) {
1726+
CUDAFastMath math_func;
1727+
std::string func_name = math_func(op->dtype, "tan");
1728+
os << func_name << "(" << PrintExpr(op->args[0]) << ")";
1729+
} else if (op->op.same_as(tl::__cos())) {
1730+
CUDAFastMath math_func;
1731+
std::string func_name = math_func(op->dtype, "cos");
1732+
os << func_name << "(" << PrintExpr(op->args[0]) << ")";
1733+
} else if (op->op.same_as(tl::__sin())) {
1734+
CUDAFastMath math_func;
1735+
std::string func_name = math_func(op->dtype, "sin");
1736+
os << func_name << "(" << PrintExpr(op->args[0]) << ")";
16311737
} else {
16321738
CodeGenC::VisitExpr_(op, os);
16331739
}

0 commit comments

Comments
 (0)