Skip to content

Commit e448754

Browse files
author
nicunxiao
committed
fix: Reinterpret types to cute types in GEMM
1 parent 0aad651 commit e448754

File tree

3 files changed

+21
-10
lines changed

3 files changed

+21
-10
lines changed

src/tl_templates/cuda/common.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -322,8 +322,8 @@ TL_DEVICE void increase_descriptor_offset(GmmaDescriptor &descriptor,
322322
}
323323

324324
// and add the desired implicit conversion from bfloat16_t.
325-
struct float_e4m3_t : public cutlass::float_e4m3_t {
326-
using cutlass::float_e4m3_t::float_e4m3_t;
325+
struct float_e4m3_t : public cute::float_e4m3_t {
326+
using cute::float_e4m3_t::float_e4m3_t;
327327
CUTLASS_HOST_DEVICE
328328
float_e4m3_t() = default;
329329

@@ -332,8 +332,8 @@ struct float_e4m3_t : public cutlass::float_e4m3_t {
332332
: float_e4m3_t(static_cast<float>(x)) {}
333333
};
334334

335-
struct float_e5m2_t : public cutlass::float_e5m2_t {
336-
using cutlass::float_e5m2_t::float_e5m2_t;
335+
struct float_e5m2_t : public cute::float_e5m2_t {
336+
using cute::float_e5m2_t::float_e5m2_t;
337337
CUTLASS_HOST_DEVICE
338338
float_e5m2_t() = default;
339339

src/tl_templates/cuda/gemm_mma.h

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -257,18 +257,24 @@ struct OperandTraits<64, N, K, false, num_warp_n, leading_dim,
257257
using Copy = DefaultCopy;
258258
};
259259

260+
template<typename T> struct to_cute_type {using type = T;};
261+
template<> struct to_cute_type<tl::float_e4m3_t> {using type = cute::float_e4m3_t;};
262+
template<> struct to_cute_type<tl::float_e5m2_t> {using type = cute::float_e5m2_t;};
263+
260264
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
261265
bool trans_B, bool clear_accum, int lda, int ldb, int offset_a,
262266
int offset_b, typename A_type_raw, typename B_type_raw,
263267
typename C_type_raw>
264268
class GemmTensorOp {
265269
public:
270+
using A_type_cute = typename to_cute_type<A_type_raw>::type;
271+
using B_type_cute = typename to_cute_type<B_type_raw>::type;
266272
using A_type =
267-
typename std::conditional<std::is_same<A_type_raw, float>::value,
273+
typename std::conditional<std::is_same<A_type_cute, float>::value,
268274
tfloat32_t, A_type_raw>::type;
269275
using B_type =
270276
typename std::conditional<std::is_same<B_type_raw, float>::value,
271-
tfloat32_t, A_type_raw>::type;
277+
tfloat32_t, B_type_cute>::type;
272278
using C_type = C_type_raw;
273279

274280
using Instruction =

src/tl_templates/cuda/gemm_sm90.h

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,21 @@ using namespace SM90;
1515
namespace tl_wgmma {
1616

1717
using namespace cutlass::gemm::collective::detail; // ss_smem_selector
18+
template<typename T> struct to_cute_type {using type = T;};
19+
template<> struct to_cute_type<tl::float_e4m3_t> {using type = cute::float_e4m3_t;};
20+
template<> struct to_cute_type<tl::float_e5m2_t> {using type = cute::float_e5m2_t;};
1821

1922
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
2023
bool trans_B, bool clear_accum, typename A_type_raw,
2124
typename B_type_raw, typename C_type_raw>
2225
class GemmTensorOp {
2326
public:
24-
using A_type = conditional_t<std::is_same<A_type_raw, float>::value,
25-
tfloat32_t, A_type_raw>;
26-
using B_type = conditional_t<std::is_same<B_type_raw, float>::value,
27-
tfloat32_t, B_type_raw>;
27+
using A_type_cute = typename to_cute_type<A_type_raw>::type;
28+
using B_type_cute = typename to_cute_type<B_type_raw>::type;
29+
using A_type = conditional_t<std::is_same<A_type_cute, float>::value,
30+
tfloat32_t, A_type_cute>;
31+
using B_type = conditional_t<std::is_same<B_type_cute, float>::value,
32+
tfloat32_t, A_type_cute>;
2833
using C_type = C_type_raw;
2934

3035
static constexpr GMMA::Major GmmaMajorA =

0 commit comments

Comments
 (0)