Skip to content

Commit 5c25147

Browse files
author
nicunxiao
committed
Overload fp8 for implicit conversion
1 parent 900ae67 commit 5c25147

File tree

3 files changed

+29
-9
lines changed

3 files changed

+29
-9
lines changed

src/target/codegen_cuda.cc

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -953,23 +953,18 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
953953
}
954954
}
955955

956-
const char *convert_part =
957-
(from_ty.is_bfloat16() &&
958-
(target_ty.is_float8_e4m3() || target_ty.is_float8_e5m2()))
959-
? ")(half)("
960-
: ")(";
961-
962956
// Fallback: elementwise cast
963957
for (int i = 0, lanes = from_ty.lanes(); i < lanes; ++i) {
964958
std::ostringstream val;
965959
val << "(";
966960
PrintType(target_ty.element_of(), val);
967-
val << convert_part;
961+
val << ")(";
968962
PrintVecElemLoad(src, from_ty, i, val);
969963
val << ")";
970964
PrintVecElemStore(sret, target_ty, i, val.str());
971965
}
972966

967+
973968
if (used_bf16_op) {
974969
stream << "#endif\n";
975970
}

src/tl_templates/cuda/common.h

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@
1010
#include <cutlass/numeric_types.h>
1111
#include <math_constants.h>
1212

13+
#include <cutlass/float8.h>
14+
#include <cutlass/bfloat16.h>
15+
1316
using cutlass::bfloat16_t;
1417
using cutlass::half_t;
1518
using cutlass::tfloat32_t;
@@ -318,6 +321,27 @@ TL_DEVICE void increase_descriptor_offset(GmmaDescriptor &descriptor,
318321
descriptor.reg32_[0] += (offset >> 4);
319322
}
320323

324+
// Wrap the CUTLASS FP8 types and add an explicit constructor from __nv_bfloat16 (the value is converted to float first).
325+
// FP8 (e4m3) type in this namespace: publicly derives from
// cutlass::float_e4m3_t and adds one extra constructor taking __nv_bfloat16.
struct float_e4m3_t : public cutlass::float_e4m3_t {
326+
// Inherit the constructors of cutlass::float_e4m3_t.
using cutlass::float_e4m3_t::float_e4m3_t;
327+
CUTLASS_HOST_DEVICE
328+
// Declared explicitly because the user-provided constructor below suppresses
// the implicitly-declared default constructor.
float_e4m3_t() = default;
329+
330+
CUTLASS_HOST_DEVICE
331+
// Builds the value from a __nv_bfloat16 by casting it to float and delegating
// to the float_e4m3_t constructor chosen for a float argument (presumably the
// inherited CUTLASS float constructor -- confirm against cutlass/float8.h).
// NOTE(review): this constructor is `explicit`, so it is used by casts and
// direct-initialization only, not by implicit conversions.
explicit float_e4m3_t(__nv_bfloat16 x) : float_e4m3_t(static_cast<float>(x)) {
332+
}
333+
};
334+
335+
// FP8 (e5m2) type in this namespace: publicly derives from
// cutlass::float_e5m2_t and adds one extra constructor taking __nv_bfloat16.
struct float_e5m2_t : public cutlass::float_e5m2_t {
336+
// Inherit the constructors of cutlass::float_e5m2_t.
using cutlass::float_e5m2_t::float_e5m2_t;
337+
CUTLASS_HOST_DEVICE
338+
// Declared explicitly because the user-provided constructor below suppresses
// the implicitly-declared default constructor.
float_e5m2_t() = default;
339+
340+
CUTLASS_HOST_DEVICE
341+
// Builds the value from a __nv_bfloat16 by casting it to float and delegating
// to the float_e5m2_t constructor chosen for a float argument (presumably the
// inherited CUTLASS float constructor -- confirm against cutlass/float8.h).
// NOTE(review): this constructor is `explicit`, so it is used by casts and
// direct-initialization only, not by implicit conversions.
explicit float_e5m2_t(__nv_bfloat16 x) : float_e5m2_t(static_cast<float>(x)) {
342+
}
343+
};
344+
321345
} // namespace tl
322346

323347
namespace cutlass {

src/tl_templates/cuda/cuda_fp8.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@
22

33
#include <cuda_fp8.h>
44
#include <cute/numeric/numeric_types.hpp>
5+
#include "common.h"
56

6-
using fp8_e4_t = cute::float_e4m3_t;
7-
using fp8_e5_t = cute::float_e5m2_t;
7+
using fp8_e4_t = tl::float_e4m3_t;
8+
using fp8_e5_t = tl::float_e5m2_t;
89

910
struct __CUDA_ALIGN__(2) fp8_e4_2_t {
1011
fp8_e4_t x;

0 commit comments

Comments
 (0)