@@ -21,6 +21,80 @@ namespace tvm {
2121namespace codegen {
2222using namespace tvm ::tl::codegen;
2323
24+ struct CUDAMath {
25+ std::string operator ()(DataType t, std::string name) const {
26+ if (t.is_float ()) {
27+ switch (t.bits ()) {
28+ case 64 :
29+ return name;
30+ case 32 :
31+ return name + ' f' ;
32+ case 16 : {
33+ if (name == " fabs" ) {
34+ return " __habs" ;
35+ } else if (name == " round" ) {
36+ return " hrint" ;
37+ } else {
38+ return " h" + name;
39+ }
40+ }
41+ default :
42+ return " " ;
43+ }
44+ } else if (t.is_bfloat16 ()) {
45+ if (name == " fabs" ) {
46+ return " __habs" ;
47+ } else if (name == " round" ) {
48+ return " hrint" ;
49+ } else {
50+ return " h" + name;
51+ }
52+ } else if (t.is_int () || t.is_uint ()) {
53+ switch (t.bits ()) {
54+ case 32 :
55+ return " __" + name;
56+ case 64 :
57+ return " __" + name + " ll" ;
58+ default :
59+ return " " ;
60+ }
61+ }
62+ return " " ;
63+ }
64+ };
65+
66+
67+ struct CUDAFastMath : public CUDAMath {
68+ std::string operator ()(DataType t, std::string name) const {
69+ if (t.is_float () && t.bits () == 32 ) {
70+ return " __" + name + ' f' ;
71+ } else {
72+ return CUDAMath::operator ()(t, name);
73+ }
74+ return " " ;
75+ }
76+ };
77+
78+ struct CUDAFastMathTan : public CUDAMath {
79+ std::string operator ()(DataType t, std::string name) const {
80+ if (t.is_float ()) {
81+ switch (t.bits ()) {
82+ case 64 :
83+ return name;
84+ // `__tanf` seems to produce some values too deviant from numpy tan
85+ // version. So, let's use just `tanf` instead.
86+ case 32 :
87+ return name + ' f' ;
88+ case 16 :
89+ return ' h' + name;
90+ default :
91+ return " " ;
92+ }
93+ }
94+ return " " ;
95+ }
96+ };
97+
2498static std::string GetFP8Type (DataType type) {
2599 std::stringstream stream;
26100 int32_t lanes = type.lanes ();
@@ -1628,6 +1702,38 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
16281702 op->args , true , os);
16291703 } else if (op->op .same_as (tl::tl_shuffle_elect ())) {
16301704 os << " tl::tl_shuffle_elect<" << PrintExpr (op->args [0 ]) << " >()" ;
1705+ } else if (op->op .same_as (tl::__exp ())) {
1706+ CUDAFastMath math_func;
1707+ std::string func_name = math_func (op->dtype , " exp" );
1708+ os << func_name << " (" << PrintExpr (op->args [0 ]) << " )" ;
1709+ } else if (op->op .same_as (tl::__exp10 ())) {
1710+ CUDAFastMath math_func;
1711+ std::string func_name = math_func (op->dtype , " exp10" );
1712+ os << func_name << " (" << PrintExpr (op->args [0 ]) << " )" ;
1713+ } else if (op->op .same_as (tl::__log ())) {
1714+ CUDAFastMath math_func;
1715+ std::string func_name = math_func (op->dtype , " log" );
1716+ os << func_name << " (" << PrintExpr (op->args [0 ]) << " )" ;
1717+ } else if (op->op .same_as (tl::__log2 ())) {
1718+ CUDAFastMath math_func;
1719+ std::string func_name = math_func (op->dtype , " log2" );
1720+ os << func_name << " (" << PrintExpr (op->args [0 ]) << " )" ;
1721+ } else if (op->op .same_as (tl::__log10 ())) {
1722+ CUDAFastMath math_func;
1723+ std::string func_name = math_func (op->dtype , " log10" );
1724+ os << func_name << " (" << PrintExpr (op->args [0 ]) << " )" ;
1725+ } else if (op->op .same_as (tl::__tan ())) {
1726+ CUDAFastMath math_func;
1727+ std::string func_name = math_func (op->dtype , " tan" );
1728+ os << func_name << " (" << PrintExpr (op->args [0 ]) << " )" ;
1729+ } else if (op->op .same_as (tl::__cos ())) {
1730+ CUDAFastMath math_func;
1731+ std::string func_name = math_func (op->dtype , " cos" );
1732+ os << func_name << " (" << PrintExpr (op->args [0 ]) << " )" ;
1733+ } else if (op->op .same_as (tl::__sin ())) {
1734+ CUDAFastMath math_func;
1735+ std::string func_name = math_func (op->dtype , " sin" );
1736+ os << func_name << " (" << PrintExpr (op->args [0 ]) << " )" ;
16311737 } else {
16321738 CodeGenC::VisitExpr_ (op, os);
16331739 }
0 commit comments