Skip to content

Commit d44ee95

Browse files
committed
[Feature] Add float32 to float8 conversion support in CUDA codegen
* Implemented handling for conversion from float32 to float8 (E4M3/E5M2) in the VisitExpr_ method.
* Added vectorized conversion support using __nv_cvt_float2_to_fp8x2 for float2 to fp8x2 transformations.
* Enhanced type handling for better compatibility with TileLang, particularly for float8 types.
1 parent 6debbb9 commit d44ee95

File tree

1 file changed

+14
-0
lines changed

1 file changed

+14
-0
lines changed

src/target/codegen_cuda.cc

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -920,6 +920,19 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
920920
}
921921
}
922922

923+
// Handle conversion from float32 to float8 (E4M3/E5M2)
924+
if (from_ty.is_float() && target_ty.is_float8()) {
925+
// FP32 -> FP8: Use __nv_cvt_float2_to_fp8x2 for vectorized conversion (float2 -> fp8x2)
926+
if (from_ty.lanes() == 2 && target_ty.lanes() == 2) {
927+
PrintIndent();
928+
stream << "*reinterpret_cast<__nv_fp8x2_storage_t*>(&(" << sret << ")) = __nv_cvt_float2_to_fp8x2(*reinterpret_cast<float2*>(&(" << src
929+
<< ")), __NV_SATFINITE, "
930+
<< (target_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2") << ");\n";
931+
os << sret;
932+
return;
933+
}
934+
}
935+
923936
// Handle bfloat16 special cases with supported ops
924937
bool used_bf16_op = false;
925938
if (from_ty.is_bfloat16() || target_ty.is_bfloat16()) {
@@ -970,6 +983,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
970983
}
971984
stream << " const &>(" << src << "));\n";
972985
stream << "#else\n";
986+
// bf16 cases don't need early return, as we use elementwise cast as fallback
973987
}
974988
}
975989

0 commit comments

Comments
 (0)