Skip to content

Commit 8281b05

Browse files
author
nicunxiao
committed
implement overloaded cast codegen for type conversion
1 parent cb907dd commit 8281b05

File tree

2 files changed

+4
-5
lines changed

2 files changed

+4
-5
lines changed

src/op/copy.cc

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -325,10 +325,6 @@ For CopyNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
325325

326326
PrimExpr value = BufferLoad(src, src_indices);
327327
if (src->dtype != dst->dtype) {
328-
// If dst is fp8 and src is bf16, first cast the loaded src value to fp32.
329-
if (src->dtype.is_bfloat16() && dst->dtype.is_float8_e4m3()) {
330-
value = Cast(DataType::Float(32), value);
331-
}
332328
value = Cast(dst->dtype, value);
333329
}
334330
if (src_predicate.defined())

src/target/codegen_cuda.cc

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -953,12 +953,15 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
953953
}
954954
}
955955

956+
const char *convert_part =
957+
(from_ty.is_bfloat16() || target_ty.is_float8_e4m3()) ? ")(half)(" : ")(";
958+
956959
// Fallback: elementwise cast
957960
for (int i = 0, lanes = from_ty.lanes(); i < lanes; ++i) {
958961
std::ostringstream val;
959962
val << "(";
960963
PrintType(target_ty.element_of(), val);
961-
val << ")(";
964+
val << convert_part;
962965
PrintVecElemLoad(src, from_ty, i, val);
963966
val << ")";
964967
PrintVecElemStore(sret, target_ty, i, val.str());

0 commit comments

Comments
 (0)