Skip to content

Commit a43b94c

Browse files
authored
Add _FLOATE4M3 and _FLOATE5M2 data type to GemmDataType. (#74757)
1 parent 04c0f50 commit a43b94c

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

paddle/phi/kernels/fusion/cutlass/cutlass_kernels/gemm_config_manager.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ enum GemmDataType {
3636
_NVBFLOAT16,
3737
_INT8,
3838
_INT4,
39+
_FLOATE4M3,
40+
_FLOATE5M2,
3941
};
4042

4143
enum GemmType {
@@ -55,6 +57,10 @@ constexpr GemmDataType getGemmDataType() {
5557
return GemmDataType::_INT8;
5658
} else if constexpr (std::is_same<T, cutlass::uint4b_t>::value) {
5759
return GemmDataType::_INT4;
60+
} else if constexpr (std::is_same<T, cutlass::float_e4m3_t>::value) {
61+
return GemmDataType::_FLOATE4M3;
62+
} else if constexpr (std::is_same<T, cutlass::float_e5m2_t>::value) {
63+
return GemmDataType::_FLOATE5M2;
5864
} else {
5965
static_assert(!std::is_same<T, T>::value,
6066
"Unsupported data type combination for GemmDataType.");

0 commit comments

Comments
 (0)